From 2cb36c3277d15b049e9db58dd30be9f292d7264a Mon Sep 17 00:00:00 2001 From: tpp Date: Tue, 6 Aug 2024 19:58:16 -0500 Subject: [PATCH 01/35] Do not allow cardinality to go below 1 --- pkg/planner/cardinality/row_count_column.go | 6 +++--- pkg/planner/cardinality/row_count_index.go | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/planner/cardinality/row_count_column.go b/pkg/planner/cardinality/row_count_column.go index c0f19fa0d8acc..b3df50776d106 100644 --- a/pkg/planner/cardinality/row_count_column.go +++ b/pkg/planner/cardinality/row_count_column.go @@ -278,7 +278,7 @@ func GetColumnRowCount(sctx context.PlanContext, c *statistics.Column, ranges [] return 0, errors.Trace(err) } cnt -= lowCnt - cnt = mathutil.Clamp(cnt, 0, c.NotNullCount()) + cnt = mathutil.Clamp(cnt, 1, c.NotNullCount()) } if !rg.LowExclude && lowVal.IsNull() { cnt += float64(c.NullCount) @@ -291,7 +291,7 @@ func GetColumnRowCount(sctx context.PlanContext, c *statistics.Column, ranges [] cnt += highCnt } - cnt = mathutil.Clamp(cnt, 0, c.TotalRowCount()) + cnt = mathutil.Clamp(cnt, 1, c.TotalRowCount()) // If the current table row count has changed, we should scale the row count accordingly. increaseFactor := c.GetIncreaseFactor(realtimeRowCount) @@ -312,7 +312,7 @@ func GetColumnRowCount(sctx context.PlanContext, c *statistics.Column, ranges [] } rowCount += cnt } - rowCount = mathutil.Clamp(rowCount, 0, float64(realtimeRowCount)) + rowCount = mathutil.Clamp(rowCount, 1, float64(realtimeRowCount)) return rowCount, nil } diff --git a/pkg/planner/cardinality/row_count_index.go b/pkg/planner/cardinality/row_count_index.go index 8d42904d4796f..48861044725e4 100644 --- a/pkg/planner/cardinality/row_count_index.go +++ b/pkg/planner/cardinality/row_count_index.go @@ -350,7 +350,7 @@ func getIndexRowCountForStatsV2(sctx context.PlanContext, idx *statistics.Index, } totalCount += count } - totalCount = mathutil.Clamp(totalCount, 0, float64(realtimeRowCount)) + totalCount = mathutil.Clamp(totalCount, 1, float64(realtimeRowCount)) return totalCount, nil } From ffac1b9300594adf8de97519535e1e936e09efb5 Mon Sep 17 00:00:00 2001 From: tpp Date: Wed, 7 Aug 2024 08:53:47 -0500 Subject: [PATCH 02/35] testcase updates1 --- pkg/statistics/statistics_test.go | 2 +- pkg/statistics/testdata/integration_suite_out.json | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/statistics/statistics_test.go b/pkg/statistics/statistics_test.go index fe18ee80b0ee3..ca192db2f0d24 100644 --- a/pkg/statistics/statistics_test.go +++ b/pkg/statistics/statistics_test.go @@ -441,7 +441,7 @@ func SubTestIndexRanges() func(*testing.T) { ran[0].HighVal[0] = types.NewIntDatum(1000) count, err = GetRowCountByIndexRanges(ctx, &tbl.HistColl, 0, ran) require.NoError(t, err) - require.Equal(t, 0, int(count)) + require.Equal(t, 1, int(count)) } } diff --git a/pkg/statistics/testdata/integration_suite_out.json b/pkg/statistics/testdata/integration_suite_out.json index e2706b96dd588..14cdc66e90ce0 100644 --- a/pkg/statistics/testdata/integration_suite_out.json +++ b/pkg/statistics/testdata/integration_suite_out.json @@ -26,8 +26,8 @@ "└─IndexRangeScan_5 1.36 cop[tikv] table:exp_backoff, index:idx(a, b, c, d) range:[1 1 1 3,1 1 1 5], keep order:false" ], [ - "IndexReader_6 0.00 root index:IndexRangeScan_5", - "└─IndexRangeScan_5 0.00 cop[tikv] table:exp_backoff, index:idx(a, b, c, d) range:[1 1 1 3,1 1 1 5], keep order:false" + "IndexReader_6 1.00 root index:IndexRangeScan_5", + "└─IndexRangeScan_5 1.00 cop[tikv] 
table:exp_backoff, index:idx(a, b, c, d) range:[1 1 1 3,1 1 1 5], keep order:false" ] ] }, From 05b8e18418134de36ea50ecc8bfe913691e73a87 Mon Sep 17 00:00:00 2001 From: tpp Date: Wed, 7 Aug 2024 09:39:51 -0500 Subject: [PATCH 03/35] testcase updates2 --- .../handle/globalstats/global_stats_test.go | 6 +- .../partition/partition_boundaries.result | 500 +++++++++--------- 2 files changed, 253 insertions(+), 253 deletions(-) diff --git a/pkg/statistics/handle/globalstats/global_stats_test.go b/pkg/statistics/handle/globalstats/global_stats_test.go index 3272e36e6cce2..e446750ea1940 100644 --- a/pkg/statistics/handle/globalstats/global_stats_test.go +++ b/pkg/statistics/handle/globalstats/global_stats_test.go @@ -783,9 +783,9 @@ func TestGlobalStats(t *testing.T) { // Even if we have global-stats, we will not use it when the switch is set to `static`. tk.MustExec("set @@tidb_partition_prune_mode = 'static';") tk.MustQuery("explain format = 'brief' select a from t where a > 5").Check(testkit.Rows( - "PartitionUnion 4.00 root ", - "├─IndexReader 0.00 root index:IndexRangeScan", - "│ └─IndexRangeScan 0.00 cop[tikv] table:t, partition:p0, index:a(a) range:(5,+inf], keep order:false", + "PartitionUnion 5.00 root ", + "├─IndexReader 1.00 root index:IndexRangeScan", + "│ └─IndexRangeScan 1.00 cop[tikv] table:t, partition:p0, index:a(a) range:(5,+inf], keep order:false", "├─IndexReader 2.00 root index:IndexRangeScan", "│ └─IndexRangeScan 2.00 cop[tikv] table:t, partition:p1, index:a(a) range:(5,+inf], keep order:false", "└─IndexReader 2.00 root index:IndexRangeScan", diff --git a/tests/integrationtest/r/executor/partition/partition_boundaries.result b/tests/integrationtest/r/executor/partition/partition_boundaries.result index fbb2627e500ea..a3df0aebe0137 100644 --- a/tests/integrationtest/r/executor/partition/partition_boundaries.result +++ b/tests/integrationtest/r/executor/partition/partition_boundaries.result @@ -125,29 +125,29 @@ a b 1000002 1000002 Filler ... 
explain format='brief' SELECT * FROM t WHERE a = 3000000; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] eq(executor__partition__partition_boundaries.t.a, 3000000) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] eq(executor__partition__partition_boundaries.t.a, 3000000) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a = 3000000; a b explain format='brief' SELECT * FROM t WHERE a IN (3000000); id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] eq(executor__partition__partition_boundaries.t.a, 3000000) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] eq(executor__partition__partition_boundaries.t.a, 3000000) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a IN (3000000); a b explain format='brief' SELECT * FROM t WHERE a = 3000001; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] eq(executor__partition__partition_boundaries.t.a, 3000001) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] eq(executor__partition__partition_boundaries.t.a, 3000001) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a = 3000001; a b explain format='brief' SELECT * FROM t WHERE a IN (3000001); id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] eq(executor__partition__partition_boundaries.t.a, 3000001) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] eq(executor__partition__partition_boundaries.t.a, 3000001) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a IN (3000001); a b @@ -161,8 +161,8 @@ a b -2147483648 MIN_INT filler... explain format='brief' SELECT * FROM t WHERE a IN (-2147483647, -2147483646); id estRows task access object operator info -TableReader 0.00 root partition:p0 data:Selection -└─Selection 0.00 cop[tikv] in(executor__partition__partition_boundaries.t.a, -2147483647, -2147483646) +TableReader 1.00 root partition:p0 data:Selection +└─Selection 1.00 cop[tikv] in(executor__partition__partition_boundaries.t.a, -2147483647, -2147483646) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a IN (-2147483647, -2147483646); a b @@ -272,8 +272,8 @@ a b 2999999 2999999 Filler ... explain format='brief' SELECT * FROM t WHERE a IN (3000000, 3000001, 3000002); id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] in(executor__partition__partition_boundaries.t.a, 3000000, 3000001, 3000002) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] in(executor__partition__partition_boundaries.t.a, 3000000, 3000001, 3000002) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a IN (3000000, 3000001, 3000002); a b @@ -342,15 +342,15 @@ a b 6 6 Filler... 
explain format='brief' SELECT * FROM t WHERE 1 = 0 OR a = -1; id estRows task access object operator info -TableReader 0.00 root partition:p0 data:Selection -└─Selection 0.00 cop[tikv] or(0, eq(executor__partition__partition_boundaries.t.a, -1)) +TableReader 1.00 root partition:p0 data:Selection +└─Selection 1.00 cop[tikv] or(0, eq(executor__partition__partition_boundaries.t.a, -1)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE 1 = 0 OR a = -1; a b explain format='brief' SELECT * FROM t WHERE a != 0; id estRows task access object operator info -TableReader 6.00 root partition:all data:Selection -└─Selection 6.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, 0) +TableReader 7.00 root partition:all data:Selection +└─Selection 7.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, 0) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a != 0; a b @@ -362,8 +362,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0; id estRows task access object operator info -TableReader 6.00 root partition:all data:Selection -└─Selection 6.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0) +TableReader 7.00 root partition:all data:Selection +└─Selection 7.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0; a b @@ -375,8 +375,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a NOT IN (-2, -1, 0); id estRows task access object operator info -TableReader 6.00 root partition:all data:Selection -└─Selection 6.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0)) +TableReader 7.00 root partition:all data:Selection +└─Selection 7.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a NOT IN (-2, -1, 0); a b @@ -409,8 +409,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1; id estRows task access object operator info -TableReader 5.00 root partition:all data:Selection -└─Selection 5.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1) +TableReader 6.00 root partition:all data:Selection +└─Selection 6.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1; a b @@ -421,8 +421,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1); id estRows task access object operator info -TableReader 5.00 root partition:all data:Selection -└─Selection 5.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1)) +TableReader 6.00 root partition:all data:Selection +└─Selection 6.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1); a b @@ -455,8 +455,8 @@ a b 6 6 Filler... 
explain format='brief' SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2; id estRows task access object operator info -TableReader 4.00 root partition:all data:Selection -└─Selection 4.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2) +TableReader 5.00 root partition:all data:Selection +└─Selection 5.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2; a b @@ -466,8 +466,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2); id estRows task access object operator info -TableReader 4.00 root partition:all data:Selection -└─Selection 4.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2)) +TableReader 5.00 root partition:all data:Selection +└─Selection 5.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2); a b @@ -500,8 +500,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3; id estRows task access object operator info -TableReader 3.00 root partition:all data:Selection -└─Selection 3.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2), ne(executor__partition__partition_boundaries.t.a, 3) +TableReader 4.00 root partition:all data:Selection +└─Selection 4.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2), ne(executor__partition__partition_boundaries.t.a, 3) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3; a b @@ -510,8 +510,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3); id estRows task access object operator info -TableReader 3.00 root partition:all data:Selection -└─Selection 3.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2, 3)) +TableReader 4.00 root partition:all data:Selection +└─Selection 4.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2, 3)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3); a b @@ -544,8 +544,8 @@ a b 6 6 Filler... 
explain format='brief' SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4; id estRows task access object operator info -TableReader 2.00 root partition:all data:Selection -└─Selection 2.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2), ne(executor__partition__partition_boundaries.t.a, 3), ne(executor__partition__partition_boundaries.t.a, 4) +TableReader 3.00 root partition:all data:Selection +└─Selection 3.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2), ne(executor__partition__partition_boundaries.t.a, 3), ne(executor__partition__partition_boundaries.t.a, 4) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4; a b @@ -553,8 +553,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4); id estRows task access object operator info -TableReader 2.00 root partition:all data:Selection -└─Selection 2.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2, 3, 4)) +TableReader 3.00 root partition:all data:Selection +└─Selection 3.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2, 3, 4)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4); a b @@ -587,16 +587,16 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4 AND a != 5; id estRows task access object operator info -TableReader 1.00 root partition:all data:Selection -└─Selection 1.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2), ne(executor__partition__partition_boundaries.t.a, 3), ne(executor__partition__partition_boundaries.t.a, 4), ne(executor__partition__partition_boundaries.t.a, 5) +TableReader 2.00 root partition:all data:Selection +└─Selection 2.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2), ne(executor__partition__partition_boundaries.t.a, 3), ne(executor__partition__partition_boundaries.t.a, 4), ne(executor__partition__partition_boundaries.t.a, 5) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4 AND a != 5; a b 6 6 Filler... 
explain format='brief' SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4, 5); id estRows task access object operator info -TableReader 1.00 root partition:all data:Selection -└─Selection 1.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2, 3, 4, 5)) +TableReader 2.00 root partition:all data:Selection +└─Selection 2.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2, 3, 4, 5)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4, 5); a b @@ -616,8 +616,8 @@ a b 5 5 Filler... explain format='brief' SELECT * FROM t WHERE a != 6; id estRows task access object operator info -TableReader 6.00 root partition:all data:Selection -└─Selection 6.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, 6) +TableReader 7.00 root partition:all data:Selection +└─Selection 7.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, 6) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a != 6; a b @@ -629,15 +629,15 @@ a b 5 5 Filler... explain format='brief' SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4 AND a != 5 AND a != 6; id estRows task access object operator info -TableReader 0.00 root partition:all data:Selection -└─Selection 0.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2), ne(executor__partition__partition_boundaries.t.a, 3), ne(executor__partition__partition_boundaries.t.a, 4), ne(executor__partition__partition_boundaries.t.a, 5), ne(executor__partition__partition_boundaries.t.a, 6) +TableReader 2.00 root partition:all data:Selection +└─Selection 2.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2), ne(executor__partition__partition_boundaries.t.a, 3), ne(executor__partition__partition_boundaries.t.a, 4), ne(executor__partition__partition_boundaries.t.a, 5), ne(executor__partition__partition_boundaries.t.a, 6) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4 AND a != 5 AND a != 6; a b explain format='brief' SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4, 5, 6); id estRows task access object operator info -TableReader 0.00 root partition:all data:Selection -└─Selection 0.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2, 3, 4, 5, 6)) +TableReader 2.00 root partition:all data:Selection +└─Selection 2.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2, 3, 4, 5, 6)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4, 5, 6); a b @@ -671,15 +671,15 @@ a b 6 6 Filler... 
explain format='brief' SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4 AND a != 5 AND a != 6 AND a != 7; id estRows task access object operator info -TableReader 0.00 root partition:all data:Selection -└─Selection 0.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2), ne(executor__partition__partition_boundaries.t.a, 3), ne(executor__partition__partition_boundaries.t.a, 4), ne(executor__partition__partition_boundaries.t.a, 5), ne(executor__partition__partition_boundaries.t.a, 6), ne(executor__partition__partition_boundaries.t.a, 7) +TableReader 2.00 root partition:all data:Selection +└─Selection 2.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2), ne(executor__partition__partition_boundaries.t.a, 3), ne(executor__partition__partition_boundaries.t.a, 4), ne(executor__partition__partition_boundaries.t.a, 5), ne(executor__partition__partition_boundaries.t.a, 6), ne(executor__partition__partition_boundaries.t.a, 7) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4 AND a != 5 AND a != 6 AND a != 7; a b explain format='brief' SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4, 5, 6, 7); id estRows task access object operator info -TableReader 0.00 root partition:all data:Selection -└─Selection 0.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7)) +TableReader 2.00 root partition:all data:Selection +└─Selection 2.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4, 5, 6, 7); a b @@ -712,8 +712,8 @@ INSERT INTO t VALUES (-2147483648, 'MIN_INT filler...'), (0, '0 Filler...'); ANALYZE TABLE t all columns; explain format='brief' SELECT * FROM t WHERE a BETWEEN -2147483648 AND -2147483649; id estRows task access object operator info -TableReader 0.00 root partition:p0 data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, -2147483648), le(executor__partition__partition_boundaries.t.a, -2147483649) +TableReader 1.00 root partition:p0 data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, -2147483648), le(executor__partition__partition_boundaries.t.a, -2147483649) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN -2147483648 AND -2147483649; a b @@ -791,8 +791,8 @@ a b -2147483648 MIN_INT filler... 
explain format='brief' SELECT * FROM t WHERE a BETWEEN 0 AND -1; id estRows task access object operator info -TableReader 0.00 root partition:p0 data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 0), le(executor__partition__partition_boundaries.t.a, -1) +TableReader 1.00 root partition:p0 data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 0), le(executor__partition__partition_boundaries.t.a, -1) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 0 AND -1; a b @@ -885,8 +885,8 @@ a b 999999 999999 Filler ... explain format='brief' SELECT * FROM t WHERE a BETWEEN 999998 AND 999997; id estRows task access object operator info -TableReader 0.00 root partition:p0 data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 999998), le(executor__partition__partition_boundaries.t.a, 999997) +TableReader 1.00 root partition:p0 data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 999998), le(executor__partition__partition_boundaries.t.a, 999997) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 999998 AND 999997; a b @@ -997,8 +997,8 @@ a b 999999 999999 Filler ... explain format='brief' SELECT * FROM t WHERE a BETWEEN 999999 AND 999998; id estRows task access object operator info -TableReader 0.00 root partition:p0 data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 999999), le(executor__partition__partition_boundaries.t.a, 999998) +TableReader 1.00 root partition:p0 data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 999999), le(executor__partition__partition_boundaries.t.a, 999998) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 999999 AND 999998; a b @@ -1107,8 +1107,8 @@ a b 999999 999999 Filler ... explain format='brief' SELECT * FROM t WHERE a BETWEEN 1000000 AND 999999; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 1000000), le(executor__partition__partition_boundaries.t.a, 999999) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 1000000), le(executor__partition__partition_boundaries.t.a, 999999) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 1000000 AND 999999; a b @@ -1216,8 +1216,8 @@ a b 2000002 2000002 Filler ... explain format='brief' SELECT * FROM t WHERE a BETWEEN 1000001 AND 1000000; id estRows task access object operator info -TableReader 0.00 root partition:p1 data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 1000001), le(executor__partition__partition_boundaries.t.a, 1000000) +TableReader 1.00 root partition:p1 data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 1000001), le(executor__partition__partition_boundaries.t.a, 1000000) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 1000001 AND 1000000; a b @@ -1322,8 +1322,8 @@ a b 2000002 2000002 Filler ... 
explain format='brief' SELECT * FROM t WHERE a BETWEEN 1000002 AND 1000001; id estRows task access object operator info -TableReader 0.00 root partition:p1 data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 1000002), le(executor__partition__partition_boundaries.t.a, 1000001) +TableReader 1.00 root partition:p1 data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 1000002), le(executor__partition__partition_boundaries.t.a, 1000001) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 1000002 AND 1000001; a b @@ -1423,141 +1423,141 @@ a b 2000002 2000002 Filler ... explain format='brief' SELECT * FROM t WHERE a BETWEEN 3000000 AND 2999999; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 2999999) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 2999999) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 3000000 AND 2999999; a b explain format='brief' SELECT * FROM t WHERE a BETWEEN 3000000 AND 3000000; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 3000000) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 3000000) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 3000000 AND 3000000; a b explain format='brief' SELECT * FROM t WHERE a BETWEEN 3000000 AND 3000001; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 3000001) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 3000001) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 3000000 AND 3000001; a b explain format='brief' SELECT * FROM t WHERE a BETWEEN 3000000 AND 3000002; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 3000002) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 3000002) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 3000000 AND 3000002; a b explain format='brief' SELECT * FROM t WHERE a BETWEEN 3000000 AND 3000010; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 3000010) 
+TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 3000010) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 3000000 AND 3000010; a b explain format='brief' SELECT * FROM t WHERE a BETWEEN 3000000 AND 3999998; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 3999998) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 3999998) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 3000000 AND 3999998; a b explain format='brief' SELECT * FROM t WHERE a BETWEEN 3000000 AND 3999999; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 3999999) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 3999999) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 3000000 AND 3999999; a b explain format='brief' SELECT * FROM t WHERE a BETWEEN 3000000 AND 4000000; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 4000000) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 4000000) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 3000000 AND 4000000; a b explain format='brief' SELECT * FROM t WHERE a BETWEEN 3000000 AND 4000001; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 4000001) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 4000001) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 3000000 AND 4000001; a b explain format='brief' SELECT * FROM t WHERE a BETWEEN 3000000 AND 4000002; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 4000002) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 4000002) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 3000000 AND 4000002; a b explain format='brief' SELECT * FROM t WHERE a BETWEEN 3000001 AND 
3000000; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000001), le(executor__partition__partition_boundaries.t.a, 3000000) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000001), le(executor__partition__partition_boundaries.t.a, 3000000) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 3000001 AND 3000000; a b explain format='brief' SELECT * FROM t WHERE a BETWEEN 3000001 AND 3000001; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000001), le(executor__partition__partition_boundaries.t.a, 3000001) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000001), le(executor__partition__partition_boundaries.t.a, 3000001) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 3000001 AND 3000001; a b explain format='brief' SELECT * FROM t WHERE a BETWEEN 3000001 AND 3000002; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000001), le(executor__partition__partition_boundaries.t.a, 3000002) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000001), le(executor__partition__partition_boundaries.t.a, 3000002) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 3000001 AND 3000002; a b explain format='brief' SELECT * FROM t WHERE a BETWEEN 3000001 AND 3000003; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000001), le(executor__partition__partition_boundaries.t.a, 3000003) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000001), le(executor__partition__partition_boundaries.t.a, 3000003) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 3000001 AND 3000003; a b explain format='brief' SELECT * FROM t WHERE a BETWEEN 3000001 AND 3000011; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000001), le(executor__partition__partition_boundaries.t.a, 3000011) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000001), le(executor__partition__partition_boundaries.t.a, 3000011) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 3000001 AND 3000011; a b explain format='brief' SELECT * FROM t WHERE a BETWEEN 3000001 AND 3999999; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000001), le(executor__partition__partition_boundaries.t.a, 3999999) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 
3000001), le(executor__partition__partition_boundaries.t.a, 3999999) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 3000001 AND 3999999; a b explain format='brief' SELECT * FROM t WHERE a BETWEEN 3000001 AND 4000000; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000001), le(executor__partition__partition_boundaries.t.a, 4000000) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000001), le(executor__partition__partition_boundaries.t.a, 4000000) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 3000001 AND 4000000; a b explain format='brief' SELECT * FROM t WHERE a BETWEEN 3000001 AND 4000001; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000001), le(executor__partition__partition_boundaries.t.a, 4000001) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000001), le(executor__partition__partition_boundaries.t.a, 4000001) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 3000001 AND 4000001; a b explain format='brief' SELECT * FROM t WHERE a BETWEEN 3000001 AND 4000002; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000001), le(executor__partition__partition_boundaries.t.a, 4000002) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000001), le(executor__partition__partition_boundaries.t.a, 4000002) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 3000001 AND 4000002; a b explain format='brief' SELECT * FROM t WHERE a BETWEEN 3000001 AND 4000003; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000001), le(executor__partition__partition_boundaries.t.a, 4000003) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000001), le(executor__partition__partition_boundaries.t.a, 4000003) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 3000001 AND 4000003; a b @@ -1582,8 +1582,8 @@ INSERT INTO t VALUES (6, '6 Filler...'); ANALYZE TABLE t all columns; explain format='brief' SELECT * FROM t WHERE a BETWEEN 2 AND -1; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, -1) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, -1) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 2 AND -1; a b @@ -1601,8 +1601,8 @@ a b 4 4 Filler... 
explain format='brief' SELECT * FROM t WHERE a BETWEEN 2 AND 0; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 0) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 0) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 2 AND 0; a b @@ -1620,8 +1620,8 @@ a b 4 4 Filler... explain format='brief' SELECT * FROM t WHERE a BETWEEN 2 AND 1; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 1) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 1) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 2 AND 1; a b @@ -1703,8 +1703,8 @@ a b 5 5 Filler... explain format='brief' SELECT * FROM t WHERE a BETWEEN 5 AND 4; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 5), le(executor__partition__partition_boundaries.t.a, 4) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 5), le(executor__partition__partition_boundaries.t.a, 4) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 5 AND 4; a b @@ -1722,8 +1722,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a BETWEEN 6 AND 4; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 6), le(executor__partition__partition_boundaries.t.a, 4) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 6), le(executor__partition__partition_boundaries.t.a, 4) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 6 AND 4; a b @@ -1741,8 +1741,8 @@ a b 6 6 Filler... 
explain format='brief' SELECT * FROM t WHERE a BETWEEN 7 AND 4; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 7), le(executor__partition__partition_boundaries.t.a, 4) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 7), le(executor__partition__partition_boundaries.t.a, 4) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a BETWEEN 7 AND 4; a b @@ -1761,8 +1761,8 @@ INSERT INTO t VALUES (-2147483648, 'MIN_INT filler...'), (0, '0 Filler...'); ANALYZE TABLE t all columns; explain format='brief' SELECT * FROM t WHERE a < -2147483648; id estRows task access object operator info -TableReader 0.00 root partition:p0 data:Selection -└─Selection 0.00 cop[tikv] lt(executor__partition__partition_boundaries.t.a, -2147483648) +TableReader 1.00 root partition:p0 data:Selection +└─Selection 1.00 cop[tikv] lt(executor__partition__partition_boundaries.t.a, -2147483648) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a < -2147483648; a b @@ -2174,8 +2174,8 @@ a b 999999 999999 Filler ... explain format='brief' SELECT * FROM t WHERE a > 3000000; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 3000000) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 3000000) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a > 3000000; a b @@ -2202,8 +2202,8 @@ a b 999999 999999 Filler ... explain format='brief' SELECT * FROM t WHERE a >= 3000000; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a >= 3000000; a b @@ -2230,8 +2230,8 @@ a b 999999 999999 Filler ... explain format='brief' SELECT * FROM t WHERE a > 3000001; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 3000001) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 3000001) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a > 3000001; a b @@ -2258,8 +2258,8 @@ a b 999999 999999 Filler ... explain format='brief' SELECT * FROM t WHERE a >= 3000001; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000001) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000001) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a >= 3000001; a b @@ -3199,8 +3199,8 @@ a b 999999 999999 Filler ... 
explain format='brief' SELECT * FROM t WHERE a > 2999999; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2999999) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2999999) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a > 2999999; a b @@ -3243,22 +3243,22 @@ a b 2999999 2999999 Filler ... explain format='brief' SELECT * FROM t WHERE a > 2999999 AND a <= 3000001; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2999999), le(executor__partition__partition_boundaries.t.a, 3000001) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2999999), le(executor__partition__partition_boundaries.t.a, 3000001) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a > 2999999 AND a <= 3000001; a b explain format='brief' SELECT * FROM t WHERE a > 2999999 AND a < 3000001; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2999999), lt(executor__partition__partition_boundaries.t.a, 3000001) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2999999), lt(executor__partition__partition_boundaries.t.a, 3000001) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a > 2999999 AND a < 3000001; a b explain format='brief' SELECT * FROM t WHERE a > 2999999 AND a <= 3000001; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2999999), le(executor__partition__partition_boundaries.t.a, 3000001) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2999999), le(executor__partition__partition_boundaries.t.a, 3000001) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a > 2999999 AND a <= 3000001; a b @@ -3285,8 +3285,8 @@ a b 999999 999999 Filler ... explain format='brief' SELECT * FROM t WHERE a > 3000000; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 3000000) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 3000000) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a > 3000000; a b @@ -3313,36 +3313,36 @@ a b 999999 999999 Filler ... 
explain format='brief' SELECT * FROM t WHERE a >= 3000000; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a >= 3000000; a b explain format='brief' SELECT * FROM t WHERE a >= 3000000 AND a <= 3000002; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 3000002) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 3000002) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a >= 3000000 AND a <= 3000002; a b explain format='brief' SELECT * FROM t WHERE a > 3000000 AND a <= 3000002; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 3000002) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 3000002) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a > 3000000 AND a <= 3000002; a b explain format='brief' SELECT * FROM t WHERE a > 3000000 AND a < 3000002; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 3000000), lt(executor__partition__partition_boundaries.t.a, 3000002) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 3000000), lt(executor__partition__partition_boundaries.t.a, 3000002) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a > 3000000 AND a < 3000002; a b explain format='brief' SELECT * FROM t WHERE a > 3000000 AND a <= 3000002; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 3000002) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 3000000), le(executor__partition__partition_boundaries.t.a, 3000002) └─TableFullScan 14.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a > 3000000 AND a <= 3000002; a b @@ -3369,8 +3369,8 @@ INSERT INTO t VALUES (6, '6 Filler...'); ANALYZE TABLE t all columns; explain format='brief' SELECT * FROM t WHERE a < -1; id estRows task access object operator info -TableReader 0.00 root partition:p0 data:Selection -└─Selection 0.00 cop[tikv] lt(executor__partition__partition_boundaries.t.a, -1) +TableReader 1.00 root partition:p0 data:Selection +└─Selection 1.00 cop[tikv] lt(executor__partition__partition_boundaries.t.a, -1) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a < -1; a b 
@@ -3390,8 +3390,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a <= -1; id estRows task access object operator info -TableReader 0.00 root partition:p0 data:Selection -└─Selection 0.00 cop[tikv] le(executor__partition__partition_boundaries.t.a, -1) +TableReader 1.00 root partition:p0 data:Selection +└─Selection 1.00 cop[tikv] le(executor__partition__partition_boundaries.t.a, -1) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a <= -1; a b @@ -3425,15 +3425,15 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a > 2 AND a < -1; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, -1) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, -1) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a > 2 AND a < -1; a b explain format='brief' SELECT * FROM t WHERE NOT (a < 2 OR a > -1); id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] and(ge(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, -1)) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] and(ge(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, -1)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a < 2 OR a > -1); a b @@ -3467,15 +3467,15 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a >= 2 AND a < -1; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, -1) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, -1) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a >= 2 AND a < -1; a b explain format='brief' SELECT * FROM t WHERE NOT (a < 2 OR a >= -1); id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] and(ge(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, -1)) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] and(ge(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, -1)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a < 2 OR a >= -1); a b @@ -3509,15 +3509,15 @@ a b 6 6 Filler... 
explain format='brief' SELECT * FROM t WHERE a > 2 AND a <= -1; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, -1) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, -1) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a > 2 AND a <= -1; a b explain format='brief' SELECT * FROM t WHERE NOT (a <= 2 OR a > -1); id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] and(gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, -1)) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] and(gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, -1)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a <= 2 OR a > -1); a b @@ -3551,15 +3551,15 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a >= 2 AND a <= -1; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, -1) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, -1) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a >= 2 AND a <= -1; a b explain format='brief' SELECT * FROM t WHERE NOT (a <= 2 OR a >= -1); id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] and(gt(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, -1)) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] and(gt(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, -1)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a <= 2 OR a >= -1); a b @@ -3579,8 +3579,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a < 0; id estRows task access object operator info -TableReader 0.00 root partition:p0 data:Selection -└─Selection 0.00 cop[tikv] lt(executor__partition__partition_boundaries.t.a, 0) +TableReader 1.00 root partition:p0 data:Selection +└─Selection 1.00 cop[tikv] lt(executor__partition__partition_boundaries.t.a, 0) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a < 0; a b @@ -3635,15 +3635,15 @@ a b 6 6 Filler... 
explain format='brief' SELECT * FROM t WHERE a > 2 AND a < 0; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 0) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 0) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a > 2 AND a < 0; a b explain format='brief' SELECT * FROM t WHERE NOT (a < 2 OR a > 0); id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] and(ge(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 0)) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] and(ge(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 0)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a < 2 OR a > 0); a b @@ -3677,15 +3677,15 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a >= 2 AND a < 0; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 0) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 0) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a >= 2 AND a < 0; a b explain format='brief' SELECT * FROM t WHERE NOT (a < 2 OR a >= 0); id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] and(ge(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 0)) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] and(ge(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 0)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a < 2 OR a >= 0); a b @@ -3719,15 +3719,15 @@ a b 6 6 Filler... 
explain format='brief' SELECT * FROM t WHERE a > 2 AND a <= 0; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 0) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 0) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a > 2 AND a <= 0; a b explain format='brief' SELECT * FROM t WHERE NOT (a <= 2 OR a > 0); id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] and(gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 0)) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] and(gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 0)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a <= 2 OR a > 0); a b @@ -3761,15 +3761,15 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a >= 2 AND a <= 0; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 0) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 0) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a >= 2 AND a <= 0; a b explain format='brief' SELECT * FROM t WHERE NOT (a <= 2 OR a >= 0); id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] and(gt(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 0)) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] and(gt(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 0)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a <= 2 OR a >= 0); a b @@ -3845,15 +3845,15 @@ a b 6 6 Filler... 
explain format='brief' SELECT * FROM t WHERE a > 2 AND a < 1; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 1) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 1) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a > 2 AND a < 1; a b explain format='brief' SELECT * FROM t WHERE NOT (a < 2 OR a > 1); id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] and(ge(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 1)) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] and(ge(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 1)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a < 2 OR a > 1); a b @@ -3887,15 +3887,15 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a >= 2 AND a < 1; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 1) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 1) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a >= 2 AND a < 1; a b explain format='brief' SELECT * FROM t WHERE NOT (a < 2 OR a >= 1); id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] and(ge(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 1)) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] and(ge(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 1)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a < 2 OR a >= 1); a b @@ -3929,15 +3929,15 @@ a b 6 6 Filler... 
explain format='brief' SELECT * FROM t WHERE a > 2 AND a <= 1; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 1) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 1) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a > 2 AND a <= 1; a b explain format='brief' SELECT * FROM t WHERE NOT (a <= 2 OR a > 1); id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] and(gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 1)) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] and(gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 1)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a <= 2 OR a > 1); a b @@ -3971,15 +3971,15 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a >= 2 AND a <= 1; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 1) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 1) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a >= 2 AND a <= 1; a b explain format='brief' SELECT * FROM t WHERE NOT (a <= 2 OR a >= 1); id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] and(gt(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 1)) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] and(gt(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 1)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a <= 2 OR a >= 1); a b @@ -4054,8 +4054,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a > 2 AND a < 2; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 2) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 2) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a > 2 AND a < 2; a b @@ -4097,15 +4097,15 @@ a b 6 6 Filler... 
explain format='brief' SELECT * FROM t WHERE a >= 2 AND a < 2; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 2) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 2) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a >= 2 AND a < 2; a b explain format='brief' SELECT * FROM t WHERE NOT (a < 2 OR a >= 2); id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] and(ge(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 2)) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] and(ge(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 2)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a < 2 OR a >= 2); a b @@ -4139,15 +4139,15 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a > 2 AND a <= 2; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 2) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 2) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a > 2 AND a <= 2; a b explain format='brief' SELECT * FROM t WHERE NOT (a <= 2 OR a > 2); id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] and(gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 2)) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] and(gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 2)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a <= 2 OR a > 2); a b @@ -4189,8 +4189,8 @@ a b 2 2 Filler... explain format='brief' SELECT * FROM t WHERE NOT (a <= 2 OR a >= 2); id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] and(gt(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 2)) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] and(gt(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 2)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a <= 2 OR a >= 2); a b @@ -4263,8 +4263,8 @@ a b 6 6 Filler... 
explain format='brief' SELECT * FROM t WHERE a > 2 AND a < 3; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 3) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 3) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a > 2 AND a < 3; a b @@ -4348,16 +4348,16 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a > 2 AND a <= 3; id estRows task access object operator info -TableReader 1.00 root partition:p3 data:Selection -└─Selection 1.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 3) +TableReader 2.00 root partition:p3 data:Selection +└─Selection 2.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 3) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a > 2 AND a <= 3; a b 3 3 Filler... explain format='brief' SELECT * FROM t WHERE NOT (a <= 2 OR a > 3); id estRows task access object operator info -TableReader 1.00 root partition:p3 data:Selection -└─Selection 1.00 cop[tikv] and(gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 3)) +TableReader 2.00 root partition:p3 data:Selection +└─Selection 2.00 cop[tikv] and(gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 3)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a <= 2 OR a > 3); a b @@ -4400,8 +4400,8 @@ a b 3 3 Filler... explain format='brief' SELECT * FROM t WHERE NOT (a <= 2 OR a >= 3); id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] and(gt(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 3)) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] and(gt(executor__partition__partition_boundaries.t.a, 2), lt(executor__partition__partition_boundaries.t.a, 3)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a <= 2 OR a >= 3); a b @@ -4852,8 +4852,8 @@ a b 5 5 Filler... explain format='brief' SELECT * FROM t WHERE a > 6; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 6) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 6) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a > 6; a b @@ -4881,8 +4881,8 @@ a b 6 6 Filler... 
explain format='brief' SELECT * FROM t WHERE a < 2 OR a > 6; id estRows task access object operator info -TableReader 2.00 root partition:p0,p1 data:Selection -└─Selection 2.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 6)) +TableReader 3.00 root partition:p0,p1 data:Selection +└─Selection 3.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 6)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a < 2 OR a > 6; a b @@ -4965,8 +4965,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a <= 2 OR a > 6; id estRows task access object operator info -TableReader 3.00 root partition:p0,p1,p2 data:Selection -└─Selection 3.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 6)) +TableReader 4.00 root partition:p0,p1,p2 data:Selection +└─Selection 4.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 6)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a <= 2 OR a > 6; a b @@ -4997,8 +4997,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE NOT (a > 2 AND a <= 6); id estRows task access object operator info -TableReader 3.00 root partition:p0,p1,p2 data:Selection -└─Selection 3.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 6)) +TableReader 4.00 root partition:p0,p1,p2 data:Selection +└─Selection 4.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 6)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a > 2 AND a <= 6); a b @@ -5040,8 +5040,8 @@ a b 5 5 Filler... explain format='brief' SELECT * FROM t WHERE NOT (a >= 2 AND a <= 6); id estRows task access object operator info -TableReader 2.00 root partition:p0,p1 data:Selection -└─Selection 2.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 6)) +TableReader 3.00 root partition:p0,p1 data:Selection +└─Selection 3.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 6)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a >= 2 AND a <= 6); a b @@ -5063,8 +5063,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a > 7; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 7) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 7) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a > 7; a b @@ -5084,15 +5084,15 @@ a b 6 6 Filler... 
explain format='brief' SELECT * FROM t WHERE a >= 7; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 7) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] ge(executor__partition__partition_boundaries.t.a, 7) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a >= 7; a b explain format='brief' SELECT * FROM t WHERE a < 2 OR a > 7; id estRows task access object operator info -TableReader 2.00 root partition:p0,p1 data:Selection -└─Selection 2.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 7)) +TableReader 3.00 root partition:p0,p1 data:Selection +└─Selection 3.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 7)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a < 2 OR a > 7; a b @@ -5123,8 +5123,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE NOT (a > 2 AND a < 7); id estRows task access object operator info -TableReader 3.00 root partition:p0,p1,p2 data:Selection -└─Selection 3.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), ge(executor__partition__partition_boundaries.t.a, 7)) +TableReader 4.00 root partition:p0,p1,p2 data:Selection +└─Selection 4.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), ge(executor__partition__partition_boundaries.t.a, 7)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a > 2 AND a < 7); a b @@ -5133,8 +5133,8 @@ a b 2 2 Filler... explain format='brief' SELECT * FROM t WHERE a < 2 OR a >= 7; id estRows task access object operator info -TableReader 2.00 root partition:p0,p1 data:Selection -└─Selection 2.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), ge(executor__partition__partition_boundaries.t.a, 7)) +TableReader 3.00 root partition:p0,p1 data:Selection +└─Selection 3.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), ge(executor__partition__partition_boundaries.t.a, 7)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a < 2 OR a >= 7; a b @@ -5166,8 +5166,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE NOT (a >= 2 AND a < 7); id estRows task access object operator info -TableReader 2.00 root partition:p0,p1 data:Selection -└─Selection 2.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), ge(executor__partition__partition_boundaries.t.a, 7)) +TableReader 3.00 root partition:p0,p1 data:Selection +└─Selection 3.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), ge(executor__partition__partition_boundaries.t.a, 7)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a >= 2 AND a < 7); a b @@ -5175,8 +5175,8 @@ a b 1 1 Filler... 
explain format='brief' SELECT * FROM t WHERE a <= 2 OR a > 7; id estRows task access object operator info -TableReader 3.00 root partition:p0,p1,p2 data:Selection -└─Selection 3.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 7)) +TableReader 4.00 root partition:p0,p1,p2 data:Selection +└─Selection 4.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 7)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a <= 2 OR a > 7; a b @@ -5207,8 +5207,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE NOT (a > 2 AND a <= 7); id estRows task access object operator info -TableReader 3.00 root partition:p0,p1,p2 data:Selection -└─Selection 3.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 7)) +TableReader 4.00 root partition:p0,p1,p2 data:Selection +└─Selection 4.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 7)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a > 2 AND a <= 7); a b @@ -5217,8 +5217,8 @@ a b 2 2 Filler... explain format='brief' SELECT * FROM t WHERE a <= 2 OR a >= 7; id estRows task access object operator info -TableReader 3.00 root partition:p0,p1,p2 data:Selection -└─Selection 3.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), ge(executor__partition__partition_boundaries.t.a, 7)) +TableReader 4.00 root partition:p0,p1,p2 data:Selection +└─Selection 4.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), ge(executor__partition__partition_boundaries.t.a, 7)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a <= 2 OR a >= 7; a b @@ -5250,8 +5250,8 @@ a b 6 6 Filler... 
explain format='brief' SELECT * FROM t WHERE NOT (a >= 2 AND a <= 7); id estRows task access object operator info -TableReader 2.00 root partition:p0,p1 data:Selection -└─Selection 2.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 7)) +TableReader 3.00 root partition:p0,p1 data:Selection +└─Selection 3.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 7)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a >= 2 AND a <= 7); a b From 422e7926828a81fce6f7bbe6e03735a53af2704d Mon Sep 17 00:00:00 2001 From: tpp Date: Wed, 7 Aug 2024 11:10:32 -0500 Subject: [PATCH 04/35] testcase updates3 --- .../testdata/cardinality_suite_out.json | 196 +++++++++--------- pkg/statistics/histogram.go | 5 +- 2 files changed, 102 insertions(+), 99 deletions(-) diff --git a/pkg/planner/cardinality/testdata/cardinality_suite_out.json b/pkg/planner/cardinality/testdata/cardinality_suite_out.json index 4ea3d621a19e2..3f2ad3e2b597b 100644 --- a/pkg/planner/cardinality/testdata/cardinality_suite_out.json +++ b/pkg/planner/cardinality/testdata/cardinality_suite_out.json @@ -24,7 +24,7 @@ { "Start": 800, "End": 900, - "Count": 755.754166655054 + "Count": 739.254166655054 }, { "Start": 900, @@ -79,7 +79,7 @@ { "Start": 800, "End": 1000, - "Count": 1213.946869573942 + "Count": 1197.446869573942 }, { "Start": 900, @@ -104,7 +104,7 @@ { "Start": 200, "End": 400, - "Count": 1215.0288209899081 + "Count": 1221.0288209899081 }, { "Start": 200, @@ -142,8 +142,8 @@ { "SQL": "explain format = 'brief' select * from t where a < 300", "Result": [ - "TableReader 1000.00 root data:Selection", - "└─Selection 1000.00 cop[tikv] lt(test.t.a, 300)", + "TableReader 1000.67 root data:Selection", + "└─Selection 1000.67 cop[tikv] lt(test.t.a, 300)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, @@ -166,16 +166,16 @@ { "SQL": "explain format = 'brief' select * from t where a >= 900", "Result": [ - "TableReader 1000.00 root data:Selection", - "└─Selection 1000.00 cop[tikv] ge(test.t.a, 900)", + "TableReader 1000.67 root data:Selection", + "└─Selection 1000.67 cop[tikv] ge(test.t.a, 900)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, { "SQL": "explain format = 'brief' select * from t where a > 900", "Result": [ - "TableReader 1000.00 root data:Selection", - "└─Selection 1000.00 cop[tikv] gt(test.t.a, 900)", + "TableReader 1000.67 root data:Selection", + "└─Selection 1000.67 cop[tikv] gt(test.t.a, 900)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, @@ -206,32 +206,32 @@ { "SQL": "explain format = 'brief' select * from t where a > 900 and a < 1000", "Result": [ - "TableReader 458.19 root data:Selection", - "└─Selection 458.19 cop[tikv] gt(test.t.a, 900), lt(test.t.a, 1000)", + "TableReader 458.86 root data:Selection", + "└─Selection 458.86 cop[tikv] gt(test.t.a, 900), lt(test.t.a, 1000)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, { "SQL": "explain format = 'brief' select * from t where a > 900 and a < 1100", "Result": [ - "TableReader 832.77 root data:Selection", - "└─Selection 832.77 cop[tikv] gt(test.t.a, 900), lt(test.t.a, 1100)", + "TableReader 833.44 root data:Selection", + "└─Selection 833.44 cop[tikv] gt(test.t.a, 900), lt(test.t.a, 1100)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, { "SQL": "explain format = 'brief' select * from t where a > 200 and a < 300", 
"Result": [ - "TableReader 459.03 root data:Selection", - "└─Selection 459.03 cop[tikv] gt(test.t.a, 200), lt(test.t.a, 300)", + "TableReader 459.70 root data:Selection", + "└─Selection 459.70 cop[tikv] gt(test.t.a, 200), lt(test.t.a, 300)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, { "SQL": "explain format = 'brief' select * from t where a > 100 and a < 300", "Result": [ - "TableReader 834.45 root data:Selection", - "└─Selection 834.45 cop[tikv] gt(test.t.a, 100), lt(test.t.a, 300)", + "TableReader 835.11 root data:Selection", + "└─Selection 835.11 cop[tikv] gt(test.t.a, 100), lt(test.t.a, 300)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, @@ -258,8 +258,8 @@ { "SQL": "explain format = 'brief' select * from t where a > 900", "Result": [ - "TableReader 5.00 root data:Selection", - "└─Selection 5.00 cop[tikv] gt(test.t.a, 900)", + "TableReader 6.00 root data:Selection", + "└─Selection 6.00 cop[tikv] gt(test.t.a, 900)", " └─TableFullScan 3000.00 cop[tikv] table:t keep order:false" ] }, @@ -865,8 +865,8 @@ { "SQL": "explain format = 'brief' select * from t where a < 300", "Result": [ - "TableReader 1000.00 root partition:p0 data:Selection", - "└─Selection 1000.00 cop[tikv] lt(test.t.a, 300)", + "TableReader 1000.67 root partition:p0 data:Selection", + "└─Selection 1000.67 cop[tikv] lt(test.t.a, 300)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, @@ -889,16 +889,16 @@ { "SQL": "explain format = 'brief' select * from t where a >= 900", "Result": [ - "TableReader 1000.00 root partition:p3,p4 data:Selection", - "└─Selection 1000.00 cop[tikv] ge(test.t.a, 900)", + "TableReader 1000.67 root partition:p3,p4 data:Selection", + "└─Selection 1000.67 cop[tikv] ge(test.t.a, 900)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, { "SQL": "explain format = 'brief' select * from t where a > 900", "Result": [ - "TableReader 1000.00 root partition:p3,p4 data:Selection", - "└─Selection 1000.00 cop[tikv] gt(test.t.a, 900)", + "TableReader 1000.67 root partition:p3,p4 data:Selection", + "└─Selection 1000.67 cop[tikv] gt(test.t.a, 900)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, @@ -929,32 +929,32 @@ { "SQL": "explain format = 'brief' select * from t where a > 900 and a < 1000", "Result": [ - "TableReader 458.19 root partition:p3 data:Selection", - "└─Selection 458.19 cop[tikv] gt(test.t.a, 900), lt(test.t.a, 1000)", + "TableReader 458.86 root partition:p3 data:Selection", + "└─Selection 458.86 cop[tikv] gt(test.t.a, 900), lt(test.t.a, 1000)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, { "SQL": "explain format = 'brief' select * from t where a > 900 and a < 1100", "Result": [ - "TableReader 832.77 root partition:p3,p4 data:Selection", - "└─Selection 832.77 cop[tikv] gt(test.t.a, 900), lt(test.t.a, 1100)", + "TableReader 833.44 root partition:p3,p4 data:Selection", + "└─Selection 833.44 cop[tikv] gt(test.t.a, 900), lt(test.t.a, 1100)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, { "SQL": "explain format = 'brief' select * from t where a > 200 and a < 300", "Result": [ - "TableReader 459.03 root partition:p0 data:Selection", - "└─Selection 459.03 cop[tikv] gt(test.t.a, 200), lt(test.t.a, 300)", + "TableReader 459.70 root partition:p0 data:Selection", + "└─Selection 459.70 cop[tikv] gt(test.t.a, 200), lt(test.t.a, 300)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, { "SQL": "explain format = 'brief' select * from t where a > 100 and a < 
300", "Result": [ - "TableReader 834.45 root partition:p0 data:Selection", - "└─Selection 834.45 cop[tikv] gt(test.t.a, 100), lt(test.t.a, 300)", + "TableReader 835.11 root partition:p0 data:Selection", + "└─Selection 835.11 cop[tikv] gt(test.t.a, 100), lt(test.t.a, 300)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] } @@ -976,9 +976,9 @@ "Result": [ "Limit 1.00 root offset:0, count:1", "└─IndexLookUp 1.00 root ", - " ├─IndexFullScan(Build) 200.00 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + " ├─IndexFullScan(Build) 1.98 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 9950)", - " └─TableRowIDScan 200.00 cop[tikv] table:t keep order:false, stats:pseudo" + " └─TableRowIDScan 1.98 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { @@ -986,9 +986,9 @@ "Result": [ "Limit 1.00 root offset:0, count:1", "└─IndexLookUp 1.00 root ", - " ├─IndexFullScan(Build) 200.00 cop[tikv] table:t, index:ic(c) keep order:true, desc, stats:pseudo", + " ├─IndexFullScan(Build) 1.98 cop[tikv] table:t, index:ic(c) keep order:true, desc, stats:pseudo", " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 9950)", - " └─TableRowIDScan 200.00 cop[tikv] table:t keep order:false, stats:pseudo" + " └─TableRowIDScan 1.98 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { @@ -996,9 +996,9 @@ "Result": [ "Limit 1.00 root offset:0, count:1", "└─IndexLookUp 1.00 root ", - " ├─IndexFullScan(Build) 9.99 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + " ├─IndexFullScan(Build) 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 8999)", - " └─TableRowIDScan 9.99 cop[tikv] table:t keep order:false, stats:pseudo" + " └─TableRowIDScan 1.67 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { @@ -1006,9 +1006,9 @@ "Result": [ "Limit 1.00 root offset:0, count:1", "└─IndexLookUp 1.00 root ", - " ├─IndexFullScan(Build) 10.00 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + " ├─IndexFullScan(Build) 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 9000)", - " └─TableRowIDScan 10.00 cop[tikv] table:t keep order:false, stats:pseudo" + " └─TableRowIDScan 1.67 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { @@ -1016,9 +1016,9 @@ "Result": [ "Limit 1.00 root offset:0, count:1", "└─IndexLookUp 1.00 root ", - " ├─IndexFullScan(Build) 10.01 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + " ├─IndexFullScan(Build) 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 9001)", - " └─TableRowIDScan 10.01 cop[tikv] table:t keep order:false, stats:pseudo" + " └─TableRowIDScan 1.67 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { @@ -1027,7 +1027,7 @@ "IndexLookUp 1.00 root limit embedded(offset:0, count:1)", "├─Limit(Build) 1.00 cop[tikv] offset:0, count:1", "│ └─Selection 1.00 cop[tikv] lt(test.t.a, 10001)", - "│ └─IndexFullScan 10.00 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + "│ └─IndexFullScan 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", "└─TableRowIDScan(Probe) 1.00 cop[tikv] table:t keep order:false, stats:pseudo" ] }, @@ -1037,7 +1037,7 @@ "IndexLookUp 1.00 root limit embedded(offset:0, count:1)", "├─Limit(Build) 1.00 cop[tikv] offset:0, count:1", "│ └─Selection 1.00 cop[tikv] lt(test.t.a, 10000)", - "│ └─IndexFullScan 10.00 cop[tikv] 
table:t, index:ic(c) keep order:true, stats:pseudo", + "│ └─IndexFullScan 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", "└─TableRowIDScan(Probe) 1.00 cop[tikv] table:t keep order:false, stats:pseudo" ] }, @@ -1047,7 +1047,7 @@ "IndexLookUp 1.00 root limit embedded(offset:0, count:1)", "├─Limit(Build) 1.00 cop[tikv] offset:0, count:1", "│ └─Selection 1.00 cop[tikv] lt(test.t.a, 9999)", - "│ └─IndexFullScan 10.00 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + "│ └─IndexFullScan 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", "└─TableRowIDScan(Probe) 1.00 cop[tikv] table:t keep order:false, stats:pseudo" ] }, @@ -1068,7 +1068,7 @@ "└─TableReader 1.00 root data:Limit", " └─Limit 1.00 cop[tikv] offset:0, count:1", " └─Selection 1.00 cop[tikv] lt(test.t.c, 100)", - " └─TableRangeScan 100.00 cop[tikv] table:t range:[-inf,1000), keep order:false, stats:pseudo" + " └─TableRangeScan 1.96 cop[tikv] table:t range:[-inf,1000), keep order:false, stats:pseudo" ] }, { @@ -1078,21 +1078,21 @@ { "Query": "explain format = 'brief' select * from t where b >= 9950 order by c limit 1", "Result": [ - "TopN 1.00 root test.t.c, offset:0, count:1", + "Limit 1.00 root offset:0, count:1", "└─IndexLookUp 1.00 root ", - " ├─IndexRangeScan(Build) 500.00 cop[tikv] table:t, index:ib(b) range:[9950,+inf], keep order:false, stats:pseudo", - " └─TopN(Probe) 1.00 cop[tikv] test.t.c, offset:0, count:1", - " └─TableRowIDScan 500.00 cop[tikv] table:t keep order:false, stats:pseudo" + " ├─IndexFullScan(Build) 1.98 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 9950)", + " └─TableRowIDScan 1.98 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { "Query": "explain format = 'brief' select * from t where b >= 9950 order by c desc limit 1", "Result": [ - "TopN 1.00 root test.t.c:desc, offset:0, count:1", + "Limit 1.00 root offset:0, count:1", "└─IndexLookUp 1.00 root ", - " ├─IndexRangeScan(Build) 500.00 cop[tikv] table:t, index:ib(b) range:[9950,+inf], keep order:false, stats:pseudo", - " └─TopN(Probe) 1.00 cop[tikv] test.t.c:desc, offset:0, count:1", - " └─TableRowIDScan 500.00 cop[tikv] table:t keep order:false, stats:pseudo" + " ├─IndexFullScan(Build) 1.98 cop[tikv] table:t, index:ic(c) keep order:true, desc, stats:pseudo", + " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 9950)", + " └─TableRowIDScan 1.98 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { @@ -1100,9 +1100,9 @@ "Result": [ "Limit 1.00 root offset:0, count:1", "└─IndexLookUp 1.00 root ", - " ├─IndexFullScan(Build) 9.99 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + " ├─IndexFullScan(Build) 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 8999)", - " └─TableRowIDScan 9.99 cop[tikv] table:t keep order:false, stats:pseudo" + " └─TableRowIDScan 1.67 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { @@ -1110,19 +1110,19 @@ "Result": [ "Limit 1.00 root offset:0, count:1", "└─IndexLookUp 1.00 root ", - " ├─IndexFullScan(Build) 10.00 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + " ├─IndexFullScan(Build) 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 9000)", - " └─TableRowIDScan 10.00 cop[tikv] table:t keep order:false, stats:pseudo" + " └─TableRowIDScan 1.67 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { "Query": "explain format = 'brief' select * 
from t where b >= 9001 order by c limit 1", "Result": [ - "TopN 1.00 root test.t.c, offset:0, count:1", - "└─TableReader 1.00 root data:TopN", - " └─TopN 1.00 cop[tikv] test.t.c, offset:0, count:1", - " └─Selection 9990.00 cop[tikv] ge(test.t.b, 9001)", - " └─TableFullScan 100000.00 cop[tikv] table:t keep order:false, stats:pseudo" + "Limit 1.00 root offset:0, count:1", + "└─IndexLookUp 1.00 root ", + " ├─IndexFullScan(Build) 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 9001)", + " └─TableRowIDScan 1.67 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { @@ -1131,7 +1131,7 @@ "IndexLookUp 1.00 root limit embedded(offset:0, count:1)", "├─Limit(Build) 1.00 cop[tikv] offset:0, count:1", "│ └─Selection 1.00 cop[tikv] lt(test.t.a, 10001)", - "│ └─IndexFullScan 10.00 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + "│ └─IndexFullScan 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", "└─TableRowIDScan(Probe) 1.00 cop[tikv] table:t keep order:false, stats:pseudo" ] }, @@ -1141,17 +1141,18 @@ "IndexLookUp 1.00 root limit embedded(offset:0, count:1)", "├─Limit(Build) 1.00 cop[tikv] offset:0, count:1", "│ └─Selection 1.00 cop[tikv] lt(test.t.a, 10000)", - "│ └─IndexFullScan 10.00 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + "│ └─IndexFullScan 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", "└─TableRowIDScan(Probe) 1.00 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { "Query": "explain format = 'brief' select * from t where a < 9999 order by c limit 1", "Result": [ - "TopN 1.00 root test.t.c, offset:0, count:1", - "└─TableReader 1.00 root data:TopN", - " └─TopN 1.00 cop[tikv] test.t.c, offset:0, count:1", - " └─TableRangeScan 9999.00 cop[tikv] table:t range:[-inf,9999), keep order:false, stats:pseudo" + "IndexLookUp 1.00 root limit embedded(offset:0, count:1)", + "├─Limit(Build) 1.00 cop[tikv] offset:0, count:1", + "│ └─Selection 1.00 cop[tikv] lt(test.t.a, 9999)", + "│ └─IndexFullScan 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + "└─TableRowIDScan(Probe) 1.00 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { @@ -1170,20 +1171,19 @@ "Result": [ "Limit 1.00 root offset:0, count:1", "└─IndexLookUp 1.00 root ", - " ├─IndexRangeScan(Build) 500.00 cop[tikv] table:t, index:ic(c) range:[9950,+inf], keep order:true, stats:pseudo", + " ├─IndexRangeScan(Build) 1.98 cop[tikv] table:t, index:ic(c) range:[9950,+inf], keep order:true, stats:pseudo", " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 9950)", - " └─TableRowIDScan 500.00 cop[tikv] table:t keep order:false, stats:pseudo" + " └─TableRowIDScan 1.98 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { "Query": "explain format = 'brief' select * from t where b >= 9950 and c >= 9900 order by c limit 1", "Result": [ - "TopN 1.00 root test.t.c, offset:0, count:1", + "Limit 1.00 root offset:0, count:1", "└─IndexLookUp 1.00 root ", - " ├─IndexRangeScan(Build) 500.00 cop[tikv] table:t, index:ib(b) range:[9950,+inf], keep order:false, stats:pseudo", - " └─TopN(Probe) 1.00 cop[tikv] test.t.c, offset:0, count:1", - " └─Selection 5.00 cop[tikv] ge(test.t.c, 9900)", - " └─TableRowIDScan 500.00 cop[tikv] table:t keep order:false, stats:pseudo" + " ├─IndexRangeScan(Build) 1.98 cop[tikv] table:t, index:ic(c) range:[9900,+inf], keep order:true, stats:pseudo", + " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 9950)", + " └─TableRowIDScan 1.98 cop[tikv] table:t keep order:false, 
stats:pseudo" ] }, { @@ -1193,7 +1193,7 @@ "└─TableReader 1.00 root data:Limit", " └─Limit 1.00 cop[tikv] offset:0, count:1", " └─Selection 1.00 cop[tikv] lt(test.t.c, 100)", - " └─TableRangeScan 100.00 cop[tikv] table:t range:[-inf,1000), keep order:false, stats:pseudo" + " └─TableRangeScan 1.96 cop[tikv] table:t range:[-inf,1000), keep order:false, stats:pseudo" ] } ] @@ -2575,13 +2575,13 @@ "rowCount": 1294.0277777777778 }, { - "Result": 0 + "Result": 1294.0277777777778 } ] }, { "End estimate range": { - "RowCount": 0, + "RowCount": 1295.0277777777778, "Type": "Range" } } @@ -2589,7 +2589,7 @@ }, { "Name": "a", - "Result": 0 + "Result": 1295.0277777777778 } ] }, @@ -2873,13 +2873,13 @@ "rowCount": 1540 }, { - "Result": 0 + "Result": 1540 } ] }, { "End estimate range": { - "RowCount": 0, + "RowCount": 1540, "Type": "Range" } } @@ -2887,7 +2887,7 @@ }, { "End estimate range": { - "RowCount": 0, + "RowCount": 1540, "Type": "Range" } } @@ -2895,7 +2895,7 @@ }, { "Name": "iab", - "Result": 0 + "Result": 1540 } ] }, @@ -3447,11 +3447,11 @@ "Expressions": [ "lt(test.t.a, -1500)" ], - "Selectivity": 0, + "Selectivity": 0.5, "partial cover": false }, { - "Result": 0 + "Result": 0.0003246753246753247 } ] } @@ -3602,7 +3602,7 @@ "histR": 999, "lPercent": 0.5613744375005636, "rPercent": 0, - "rowCount": 555.760693125558 + "rowCount": 100 }, { "Result": 100 @@ -3611,7 +3611,7 @@ }, { "End estimate range": { - "RowCount": 100, + "RowCount": 101.03355704697987, "Type": "Range" } } @@ -3619,7 +3619,7 @@ }, { "Name": "a", - "Result": 100 + "Result": 101.03355704697987 } ] }, @@ -3940,7 +3940,7 @@ }, { "Name": "iab", - "Result": 0 + "Result": 1 } ] }, @@ -4093,11 +4093,11 @@ "Expressions": [ "lt(test.t.a, -1500)" ], - "Selectivity": 0, + "Selectivity": 0.0003246753246753247, "partial cover": false }, { - "Result": 0 + "Result": 1.9066503965832828e-7 } ] } @@ -4113,8 +4113,8 @@ "Result": [ "Projection 2000.00 root test.t.a, test.t.b, test.t.a, test.t.b", "└─IndexJoin 2000.00 root inner join, inner:IndexLookUp, outer key:test.t.a, inner key:test.t.b, equal cond:eq(test.t.a, test.t.b)", - " ├─TableReader(Build) 1000.00 root data:Selection", - " │ └─Selection 1000.00 cop[tikv] lt(test.t.a, 1), not(isnull(test.t.a))", + " ├─TableReader(Build) 251000.00 root data:Selection", + " │ └─Selection 251000.00 cop[tikv] lt(test.t.a, 1), not(isnull(test.t.a))", " │ └─TableFullScan 500000.00 cop[tikv] table:t keep order:false, stats:pseudo", " └─IndexLookUp(Probe) 2000.00 root ", " ├─Selection(Build) 1000000.00 cop[tikv] lt(test.t.b, 1), not(isnull(test.t.b))", @@ -4132,12 +4132,12 @@ "Result": [ "Projection 2000.00 root test.t.a, test.t.b, test.t.a, test.t.b", "└─IndexJoin 2000.00 root inner join, inner:IndexLookUp, outer key:test.t.a, inner key:test.t.b, equal cond:eq(test.t.a, test.t.b)", - " ├─TableReader(Build) 1000.00 root data:Selection", - " │ └─Selection 1000.00 cop[tikv] lt(test.t.a, 1), not(isnull(test.t.a))", + " ├─TableReader(Build) 251000.00 root data:Selection", + " │ └─Selection 251000.00 cop[tikv] lt(test.t.a, 1), not(isnull(test.t.a))", " │ └─TableFullScan 500000.00 cop[tikv] table:t keep order:false, stats:pseudo", " └─IndexLookUp(Probe) 2000.00 root ", " ├─Selection(Build) 1000000.00 cop[tikv] lt(test.t.b, 1), not(isnull(test.t.b))", - " │ └─IndexRangeScan 1000000.00 cop[tikv] table:t2, index:idx(b) range: decided by [eq(test.t.b, test.t.a)], keep order:false, stats:pseudo", + " │ └─IndexRangeScan 251000000.00 cop[tikv] table:t2, index:idx(b) range: decided by [eq(test.t.b, test.t.a)], keep 
order:false, stats:pseudo", " └─Selection(Probe) 2000.00 cop[tikv] eq(test.t.a, 0)", " └─TableRowIDScan 1000000.00 cop[tikv] table:t2 keep order:false, stats:pseudo" ] diff --git a/pkg/statistics/histogram.go b/pkg/statistics/histogram.go index 2b3f3225af8d1..4be9a5398b59e 100644 --- a/pkg/statistics/histogram.go +++ b/pkg/statistics/histogram.go @@ -1057,7 +1057,10 @@ func (hg *Histogram) OutOfRangeRowCount( } // Use modifyCount as a final bound - return min(rowCount, float64(modifyCount)) + if modifyCount > 0 && rowCount > float64(modifyCount) { + rowCount = float64(modifyCount) + } + return rowCount } // Copy deep copies the histogram. From 32b9f23ab278fa8343a0c04799359fc4a1c3c343 Mon Sep 17 00:00:00 2001 From: tpp Date: Wed, 7 Aug 2024 13:54:13 -0500 Subject: [PATCH 05/35] testcase updates4 --- pkg/planner/cardinality/row_count_column.go | 5 +- pkg/planner/cardinality/row_count_index.go | 1 + .../testdata/cardinality_suite_out.json | 74 +++++----- .../partition/partition_boundaries.result | 128 +++++++++--------- 4 files changed, 105 insertions(+), 103 deletions(-) diff --git a/pkg/planner/cardinality/row_count_column.go b/pkg/planner/cardinality/row_count_column.go index b3df50776d106..e38f8da4146c3 100644 --- a/pkg/planner/cardinality/row_count_column.go +++ b/pkg/planner/cardinality/row_count_column.go @@ -278,7 +278,7 @@ func GetColumnRowCount(sctx context.PlanContext, c *statistics.Column, ranges [] return 0, errors.Trace(err) } cnt -= lowCnt - cnt = mathutil.Clamp(cnt, 1, c.NotNullCount()) + cnt = mathutil.Clamp(cnt, 0, c.NotNullCount()) } if !rg.LowExclude && lowVal.IsNull() { cnt += float64(c.NullCount) @@ -291,7 +291,7 @@ func GetColumnRowCount(sctx context.PlanContext, c *statistics.Column, ranges [] cnt += highCnt } - cnt = mathutil.Clamp(cnt, 1, c.TotalRowCount()) + cnt = mathutil.Clamp(cnt, 0, c.TotalRowCount()) // If the current table row count has changed, we should scale the row count accordingly. 
increaseFactor := c.GetIncreaseFactor(realtimeRowCount) @@ -312,6 +312,7 @@ func GetColumnRowCount(sctx context.PlanContext, c *statistics.Column, ranges [] } rowCount += cnt } + // Don't allow the final result to go below 1 row rowCount = mathutil.Clamp(rowCount, 1, float64(realtimeRowCount)) return rowCount, nil } diff --git a/pkg/planner/cardinality/row_count_index.go b/pkg/planner/cardinality/row_count_index.go index 48861044725e4..bc0f3c3226090 100644 --- a/pkg/planner/cardinality/row_count_index.go +++ b/pkg/planner/cardinality/row_count_index.go @@ -350,6 +350,7 @@ func getIndexRowCountForStatsV2(sctx context.PlanContext, idx *statistics.Index, } totalCount += count } + // Don't allow the final result to go below 1 row totalCount = mathutil.Clamp(totalCount, 1, float64(realtimeRowCount)) return totalCount, nil } diff --git a/pkg/planner/cardinality/testdata/cardinality_suite_out.json b/pkg/planner/cardinality/testdata/cardinality_suite_out.json index 3f2ad3e2b597b..b982afe0a0095 100644 --- a/pkg/planner/cardinality/testdata/cardinality_suite_out.json +++ b/pkg/planner/cardinality/testdata/cardinality_suite_out.json @@ -24,7 +24,7 @@ { "Start": 800, "End": 900, - "Count": 739.254166655054 + "Count": 752.004166655054 }, { "Start": 900, @@ -79,7 +79,7 @@ { "Start": 800, "End": 1000, - "Count": 1197.446869573942 + "Count": 1210.196869573942 }, { "Start": 900, @@ -104,7 +104,7 @@ { "Start": 200, "End": 400, - "Count": 1221.0288209899081 + "Count": 1216.5288209899081 }, { "Start": 200, @@ -142,8 +142,8 @@ { "SQL": "explain format = 'brief' select * from t where a < 300", "Result": [ - "TableReader 1000.67 root data:Selection", - "└─Selection 1000.67 cop[tikv] lt(test.t.a, 300)", + "TableReader 1000.00 root data:Selection", + "└─Selection 1000.00 cop[tikv] lt(test.t.a, 300)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, @@ -166,16 +166,16 @@ { "SQL": "explain format = 'brief' select * from t where a >= 900", "Result": [ - "TableReader 1000.67 root data:Selection", - "└─Selection 1000.67 cop[tikv] ge(test.t.a, 900)", + "TableReader 1000.00 root data:Selection", + "└─Selection 1000.00 cop[tikv] ge(test.t.a, 900)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, { "SQL": "explain format = 'brief' select * from t where a > 900", "Result": [ - "TableReader 1000.67 root data:Selection", - "└─Selection 1000.67 cop[tikv] gt(test.t.a, 900)", + "TableReader 1000.00 root data:Selection", + "└─Selection 1000.00 cop[tikv] gt(test.t.a, 900)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, @@ -206,32 +206,32 @@ { "SQL": "explain format = 'brief' select * from t where a > 900 and a < 1000", "Result": [ - "TableReader 458.86 root data:Selection", - "└─Selection 458.86 cop[tikv] gt(test.t.a, 900), lt(test.t.a, 1000)", + "TableReader 458.19 root data:Selection", + "└─Selection 458.19 cop[tikv] gt(test.t.a, 900), lt(test.t.a, 1000)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, { "SQL": "explain format = 'brief' select * from t where a > 900 and a < 1100", "Result": [ - "TableReader 833.44 root data:Selection", - "└─Selection 833.44 cop[tikv] gt(test.t.a, 900), lt(test.t.a, 1100)", + "TableReader 832.77 root data:Selection", + "└─Selection 832.77 cop[tikv] gt(test.t.a, 900), lt(test.t.a, 1100)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, { "SQL": "explain format = 'brief' select * from t where a > 200 and a < 300", "Result": [ - "TableReader 459.70 root data:Selection", - "└─Selection 459.70 cop[tikv] 
gt(test.t.a, 200), lt(test.t.a, 300)", + "TableReader 459.03 root data:Selection", + "└─Selection 459.03 cop[tikv] gt(test.t.a, 200), lt(test.t.a, 300)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, { "SQL": "explain format = 'brief' select * from t where a > 100 and a < 300", "Result": [ - "TableReader 835.11 root data:Selection", - "└─Selection 835.11 cop[tikv] gt(test.t.a, 100), lt(test.t.a, 300)", + "TableReader 834.45 root data:Selection", + "└─Selection 834.45 cop[tikv] gt(test.t.a, 100), lt(test.t.a, 300)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, @@ -258,8 +258,8 @@ { "SQL": "explain format = 'brief' select * from t where a > 900", "Result": [ - "TableReader 6.00 root data:Selection", - "└─Selection 6.00 cop[tikv] gt(test.t.a, 900)", + "TableReader 5.00 root data:Selection", + "└─Selection 5.00 cop[tikv] gt(test.t.a, 900)", " └─TableFullScan 3000.00 cop[tikv] table:t keep order:false" ] }, @@ -865,8 +865,8 @@ { "SQL": "explain format = 'brief' select * from t where a < 300", "Result": [ - "TableReader 1000.67 root partition:p0 data:Selection", - "└─Selection 1000.67 cop[tikv] lt(test.t.a, 300)", + "TableReader 1000.00 root partition:p0 data:Selection", + "└─Selection 1000.00 cop[tikv] lt(test.t.a, 300)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, @@ -889,16 +889,16 @@ { "SQL": "explain format = 'brief' select * from t where a >= 900", "Result": [ - "TableReader 1000.67 root partition:p3,p4 data:Selection", - "└─Selection 1000.67 cop[tikv] ge(test.t.a, 900)", + "TableReader 1000.00 root partition:p3,p4 data:Selection", + "└─Selection 1000.00 cop[tikv] ge(test.t.a, 900)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, { "SQL": "explain format = 'brief' select * from t where a > 900", "Result": [ - "TableReader 1000.67 root partition:p3,p4 data:Selection", - "└─Selection 1000.67 cop[tikv] gt(test.t.a, 900)", + "TableReader 1000.00 root partition:p3,p4 data:Selection", + "└─Selection 1000.00 cop[tikv] gt(test.t.a, 900)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, @@ -929,32 +929,32 @@ { "SQL": "explain format = 'brief' select * from t where a > 900 and a < 1000", "Result": [ - "TableReader 458.86 root partition:p3 data:Selection", - "└─Selection 458.86 cop[tikv] gt(test.t.a, 900), lt(test.t.a, 1000)", + "TableReader 458.19 root partition:p3 data:Selection", + "└─Selection 458.19 cop[tikv] gt(test.t.a, 900), lt(test.t.a, 1000)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, { "SQL": "explain format = 'brief' select * from t where a > 900 and a < 1100", "Result": [ - "TableReader 833.44 root partition:p3,p4 data:Selection", - "└─Selection 833.44 cop[tikv] gt(test.t.a, 900), lt(test.t.a, 1100)", + "TableReader 832.77 root partition:p3,p4 data:Selection", + "└─Selection 832.77 cop[tikv] gt(test.t.a, 900), lt(test.t.a, 1100)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, { "SQL": "explain format = 'brief' select * from t where a > 200 and a < 300", "Result": [ - "TableReader 459.70 root partition:p0 data:Selection", - "└─Selection 459.70 cop[tikv] gt(test.t.a, 200), lt(test.t.a, 300)", + "TableReader 459.03 root partition:p0 data:Selection", + "└─Selection 459.03 cop[tikv] gt(test.t.a, 200), lt(test.t.a, 300)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] }, { "SQL": "explain format = 'brief' select * from t where a > 100 and a < 300", "Result": [ - "TableReader 835.11 root partition:p0 data:Selection", - 
"└─Selection 835.11 cop[tikv] gt(test.t.a, 100), lt(test.t.a, 300)", + "TableReader 834.45 root partition:p0 data:Selection", + "└─Selection 834.45 cop[tikv] gt(test.t.a, 100), lt(test.t.a, 300)", " └─TableFullScan 2000.00 cop[tikv] table:t keep order:false" ] } @@ -2581,7 +2581,7 @@ }, { "End estimate range": { - "RowCount": 1295.0277777777778, + "RowCount": 1294.0277777777778, "Type": "Range" } } @@ -2589,7 +2589,7 @@ }, { "Name": "a", - "Result": 1295.0277777777778 + "Result": 1294.0277777777778 } ] }, @@ -3611,7 +3611,7 @@ }, { "End estimate range": { - "RowCount": 101.03355704697987, + "RowCount": 100, "Type": "Range" } } @@ -3619,7 +3619,7 @@ }, { "Name": "a", - "Result": 101.03355704697987 + "Result": 100 } ] }, diff --git a/tests/integrationtest/r/executor/partition/partition_boundaries.result b/tests/integrationtest/r/executor/partition/partition_boundaries.result index a3df0aebe0137..5148ee5e34696 100644 --- a/tests/integrationtest/r/executor/partition/partition_boundaries.result +++ b/tests/integrationtest/r/executor/partition/partition_boundaries.result @@ -349,8 +349,8 @@ SELECT * FROM t WHERE 1 = 0 OR a = -1; a b explain format='brief' SELECT * FROM t WHERE a != 0; id estRows task access object operator info -TableReader 7.00 root partition:all data:Selection -└─Selection 7.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, 0) +TableReader 6.00 root partition:all data:Selection +└─Selection 6.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, 0) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a != 0; a b @@ -362,8 +362,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0; id estRows task access object operator info -TableReader 7.00 root partition:all data:Selection -└─Selection 7.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0) +TableReader 6.00 root partition:all data:Selection +└─Selection 6.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0; a b @@ -375,8 +375,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a NOT IN (-2, -1, 0); id estRows task access object operator info -TableReader 7.00 root partition:all data:Selection -└─Selection 7.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0)) +TableReader 6.00 root partition:all data:Selection +└─Selection 6.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a NOT IN (-2, -1, 0); a b @@ -409,8 +409,8 @@ a b 6 6 Filler... 
explain format='brief' SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1; id estRows task access object operator info -TableReader 6.00 root partition:all data:Selection -└─Selection 6.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1) +TableReader 5.00 root partition:all data:Selection +└─Selection 5.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1; a b @@ -421,8 +421,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1); id estRows task access object operator info -TableReader 6.00 root partition:all data:Selection -└─Selection 6.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1)) +TableReader 5.00 root partition:all data:Selection +└─Selection 5.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1); a b @@ -455,8 +455,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2; id estRows task access object operator info -TableReader 5.00 root partition:all data:Selection -└─Selection 5.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2) +TableReader 4.00 root partition:all data:Selection +└─Selection 4.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2; a b @@ -466,8 +466,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2); id estRows task access object operator info -TableReader 5.00 root partition:all data:Selection -└─Selection 5.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2)) +TableReader 4.00 root partition:all data:Selection +└─Selection 4.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2); a b @@ -500,8 +500,8 @@ a b 6 6 Filler... 
explain format='brief' SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3; id estRows task access object operator info -TableReader 4.00 root partition:all data:Selection -└─Selection 4.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2), ne(executor__partition__partition_boundaries.t.a, 3) +TableReader 3.00 root partition:all data:Selection +└─Selection 3.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2), ne(executor__partition__partition_boundaries.t.a, 3) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3; a b @@ -510,8 +510,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3); id estRows task access object operator info -TableReader 4.00 root partition:all data:Selection -└─Selection 4.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2, 3)) +TableReader 3.00 root partition:all data:Selection +└─Selection 3.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2, 3)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3); a b @@ -544,8 +544,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4; id estRows task access object operator info -TableReader 3.00 root partition:all data:Selection -└─Selection 3.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2), ne(executor__partition__partition_boundaries.t.a, 3), ne(executor__partition__partition_boundaries.t.a, 4) +TableReader 2.00 root partition:all data:Selection +└─Selection 2.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2), ne(executor__partition__partition_boundaries.t.a, 3), ne(executor__partition__partition_boundaries.t.a, 4) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4; a b @@ -553,8 +553,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4); id estRows task access object operator info -TableReader 3.00 root partition:all data:Selection -└─Selection 3.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2, 3, 4)) +TableReader 2.00 root partition:all data:Selection +└─Selection 2.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2, 3, 4)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4); a b @@ -587,16 +587,16 @@ a b 6 6 Filler... 
explain format='brief' SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4 AND a != 5; id estRows task access object operator info -TableReader 2.00 root partition:all data:Selection -└─Selection 2.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2), ne(executor__partition__partition_boundaries.t.a, 3), ne(executor__partition__partition_boundaries.t.a, 4), ne(executor__partition__partition_boundaries.t.a, 5) +TableReader 1.00 root partition:all data:Selection +└─Selection 1.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2), ne(executor__partition__partition_boundaries.t.a, 3), ne(executor__partition__partition_boundaries.t.a, 4), ne(executor__partition__partition_boundaries.t.a, 5) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4 AND a != 5; a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4, 5); id estRows task access object operator info -TableReader 2.00 root partition:all data:Selection -└─Selection 2.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2, 3, 4, 5)) +TableReader 1.00 root partition:all data:Selection +└─Selection 1.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2, 3, 4, 5)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4, 5); a b @@ -616,8 +616,8 @@ a b 5 5 Filler... explain format='brief' SELECT * FROM t WHERE a != 6; id estRows task access object operator info -TableReader 7.00 root partition:all data:Selection -└─Selection 7.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, 6) +TableReader 6.00 root partition:all data:Selection +└─Selection 6.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, 6) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a != 6; a b @@ -629,15 +629,15 @@ a b 5 5 Filler... 
explain format='brief' SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4 AND a != 5 AND a != 6; id estRows task access object operator info -TableReader 2.00 root partition:all data:Selection -└─Selection 2.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2), ne(executor__partition__partition_boundaries.t.a, 3), ne(executor__partition__partition_boundaries.t.a, 4), ne(executor__partition__partition_boundaries.t.a, 5), ne(executor__partition__partition_boundaries.t.a, 6) +TableReader 1.00 root partition:all data:Selection +└─Selection 1.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2), ne(executor__partition__partition_boundaries.t.a, 3), ne(executor__partition__partition_boundaries.t.a, 4), ne(executor__partition__partition_boundaries.t.a, 5), ne(executor__partition__partition_boundaries.t.a, 6) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4 AND a != 5 AND a != 6; a b explain format='brief' SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4, 5, 6); id estRows task access object operator info -TableReader 2.00 root partition:all data:Selection -└─Selection 2.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2, 3, 4, 5, 6)) +TableReader 1.00 root partition:all data:Selection +└─Selection 1.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2, 3, 4, 5, 6)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4, 5, 6); a b @@ -671,15 +671,15 @@ a b 6 6 Filler... 
explain format='brief' SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4 AND a != 5 AND a != 6 AND a != 7; id estRows task access object operator info -TableReader 2.00 root partition:all data:Selection -└─Selection 2.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2), ne(executor__partition__partition_boundaries.t.a, 3), ne(executor__partition__partition_boundaries.t.a, 4), ne(executor__partition__partition_boundaries.t.a, 5), ne(executor__partition__partition_boundaries.t.a, 6), ne(executor__partition__partition_boundaries.t.a, 7) +TableReader 1.00 root partition:all data:Selection +└─Selection 1.00 cop[tikv] ne(executor__partition__partition_boundaries.t.a, -1), ne(executor__partition__partition_boundaries.t.a, 0), ne(executor__partition__partition_boundaries.t.a, 1), ne(executor__partition__partition_boundaries.t.a, 2), ne(executor__partition__partition_boundaries.t.a, 3), ne(executor__partition__partition_boundaries.t.a, 4), ne(executor__partition__partition_boundaries.t.a, 5), ne(executor__partition__partition_boundaries.t.a, 6), ne(executor__partition__partition_boundaries.t.a, 7) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE 1 = 1 AND a != -1 AND a != 0 AND a != 1 AND a != 2 AND a != 3 AND a != 4 AND a != 5 AND a != 6 AND a != 7; a b explain format='brief' SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4, 5, 6, 7); id estRows task access object operator info -TableReader 2.00 root partition:all data:Selection -└─Selection 2.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7)) +TableReader 1.00 root partition:all data:Selection +└─Selection 1.00 cop[tikv] not(in(executor__partition__partition_boundaries.t.a, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a NOT IN (-2, -1, 0, 1, 2, 3, 4, 5, 6, 7); a b @@ -4348,16 +4348,16 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a > 2 AND a <= 3; id estRows task access object operator info -TableReader 2.00 root partition:p3 data:Selection -└─Selection 2.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 3) +TableReader 1.00 root partition:p3 data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 3) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a > 2 AND a <= 3; a b 3 3 Filler... explain format='brief' SELECT * FROM t WHERE NOT (a <= 2 OR a > 3); id estRows task access object operator info -TableReader 2.00 root partition:p3 data:Selection -└─Selection 2.00 cop[tikv] and(gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 3)) +TableReader 1.00 root partition:p3 data:Selection +└─Selection 1.00 cop[tikv] and(gt(executor__partition__partition_boundaries.t.a, 2), le(executor__partition__partition_boundaries.t.a, 3)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a <= 2 OR a > 3); a b @@ -4881,8 +4881,8 @@ a b 6 6 Filler... 
explain format='brief' SELECT * FROM t WHERE a < 2 OR a > 6; id estRows task access object operator info -TableReader 3.00 root partition:p0,p1 data:Selection -└─Selection 3.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 6)) +TableReader 2.00 root partition:p0,p1 data:Selection +└─Selection 2.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 6)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a < 2 OR a > 6; a b @@ -4965,8 +4965,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE a <= 2 OR a > 6; id estRows task access object operator info -TableReader 4.00 root partition:p0,p1,p2 data:Selection -└─Selection 4.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 6)) +TableReader 3.00 root partition:p0,p1,p2 data:Selection +└─Selection 3.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 6)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a <= 2 OR a > 6; a b @@ -4997,8 +4997,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE NOT (a > 2 AND a <= 6); id estRows task access object operator info -TableReader 4.00 root partition:p0,p1,p2 data:Selection -└─Selection 4.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 6)) +TableReader 3.00 root partition:p0,p1,p2 data:Selection +└─Selection 3.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 6)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a > 2 AND a <= 6); a b @@ -5040,8 +5040,8 @@ a b 5 5 Filler... explain format='brief' SELECT * FROM t WHERE NOT (a >= 2 AND a <= 6); id estRows task access object operator info -TableReader 3.00 root partition:p0,p1 data:Selection -└─Selection 3.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 6)) +TableReader 2.00 root partition:p0,p1 data:Selection +└─Selection 2.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 6)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a >= 2 AND a <= 6); a b @@ -5091,8 +5091,8 @@ SELECT * FROM t WHERE a >= 7; a b explain format='brief' SELECT * FROM t WHERE a < 2 OR a > 7; id estRows task access object operator info -TableReader 3.00 root partition:p0,p1 data:Selection -└─Selection 3.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 7)) +TableReader 2.00 root partition:p0,p1 data:Selection +└─Selection 2.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 7)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a < 2 OR a > 7; a b @@ -5123,8 +5123,8 @@ a b 6 6 Filler... 
explain format='brief' SELECT * FROM t WHERE NOT (a > 2 AND a < 7); id estRows task access object operator info -TableReader 4.00 root partition:p0,p1,p2 data:Selection -└─Selection 4.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), ge(executor__partition__partition_boundaries.t.a, 7)) +TableReader 3.00 root partition:p0,p1,p2 data:Selection +└─Selection 3.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), ge(executor__partition__partition_boundaries.t.a, 7)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a > 2 AND a < 7); a b @@ -5133,8 +5133,8 @@ a b 2 2 Filler... explain format='brief' SELECT * FROM t WHERE a < 2 OR a >= 7; id estRows task access object operator info -TableReader 3.00 root partition:p0,p1 data:Selection -└─Selection 3.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), ge(executor__partition__partition_boundaries.t.a, 7)) +TableReader 2.00 root partition:p0,p1 data:Selection +└─Selection 2.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), ge(executor__partition__partition_boundaries.t.a, 7)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a < 2 OR a >= 7; a b @@ -5166,8 +5166,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE NOT (a >= 2 AND a < 7); id estRows task access object operator info -TableReader 3.00 root partition:p0,p1 data:Selection -└─Selection 3.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), ge(executor__partition__partition_boundaries.t.a, 7)) +TableReader 2.00 root partition:p0,p1 data:Selection +└─Selection 2.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), ge(executor__partition__partition_boundaries.t.a, 7)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a >= 2 AND a < 7); a b @@ -5175,8 +5175,8 @@ a b 1 1 Filler... explain format='brief' SELECT * FROM t WHERE a <= 2 OR a > 7; id estRows task access object operator info -TableReader 4.00 root partition:p0,p1,p2 data:Selection -└─Selection 4.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 7)) +TableReader 3.00 root partition:p0,p1,p2 data:Selection +└─Selection 3.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 7)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a <= 2 OR a > 7; a b @@ -5207,8 +5207,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE NOT (a > 2 AND a <= 7); id estRows task access object operator info -TableReader 4.00 root partition:p0,p1,p2 data:Selection -└─Selection 4.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 7)) +TableReader 3.00 root partition:p0,p1,p2 data:Selection +└─Selection 3.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 7)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a > 2 AND a <= 7); a b @@ -5217,8 +5217,8 @@ a b 2 2 Filler... 
explain format='brief' SELECT * FROM t WHERE a <= 2 OR a >= 7; id estRows task access object operator info -TableReader 4.00 root partition:p0,p1,p2 data:Selection -└─Selection 4.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), ge(executor__partition__partition_boundaries.t.a, 7)) +TableReader 3.00 root partition:p0,p1,p2 data:Selection +└─Selection 3.00 cop[tikv] or(le(executor__partition__partition_boundaries.t.a, 2), ge(executor__partition__partition_boundaries.t.a, 7)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE a <= 2 OR a >= 7; a b @@ -5250,8 +5250,8 @@ a b 6 6 Filler... explain format='brief' SELECT * FROM t WHERE NOT (a >= 2 AND a <= 7); id estRows task access object operator info -TableReader 3.00 root partition:p0,p1 data:Selection -└─Selection 3.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 7)) +TableReader 2.00 root partition:p0,p1 data:Selection +└─Selection 2.00 cop[tikv] or(lt(executor__partition__partition_boundaries.t.a, 2), gt(executor__partition__partition_boundaries.t.a, 7)) └─TableFullScan 7.00 cop[tikv] table:t keep order:false SELECT * FROM t WHERE NOT (a >= 2 AND a <= 7); a b From 3549f99c2d863a2188cb8539a5b1b50fd1903d2f Mon Sep 17 00:00:00 2001 From: tpp Date: Wed, 7 Aug 2024 14:11:21 -0500 Subject: [PATCH 06/35] testcase updates5 --- br/pkg/backup/binding__failpoint_binding__.go | 14 + .../binding__failpoint_binding__.go | 14 + br/pkg/backup/prepare_snap/prepare.go | 6 +- .../prepare.go__failpoint_stash__ | 484 ++ br/pkg/backup/store.go | 32 +- br/pkg/backup/store.go__failpoint_stash__ | 307 + .../binding__failpoint_binding__.go | 14 + br/pkg/checkpoint/checkpoint.go | 22 +- .../checkpoint.go__failpoint_stash__ | 872 ++ .../checksum/binding__failpoint_binding__.go | 14 + br/pkg/checksum/executor.go | 4 +- .../checksum/executor.go__failpoint_stash__ | 419 + br/pkg/conn/binding__failpoint_binding__.go | 14 + br/pkg/conn/conn.go | 18 +- br/pkg/conn/conn.go__failpoint_stash__ | 457 ++ br/pkg/pdutil/binding__failpoint_binding__.go | 14 + br/pkg/pdutil/pd.go | 4 +- br/pkg/pdutil/pd.go__failpoint_stash__ | 782 ++ .../restore/binding__failpoint_binding__.go | 14 + .../binding__failpoint_binding__.go | 14 + br/pkg/restore/log_client/client.go | 30 +- .../log_client/client.go__failpoint_stash__ | 1689 ++++ br/pkg/restore/log_client/import_retry.go | 4 +- .../import_retry.go__failpoint_stash__ | 284 + br/pkg/restore/misc.go | 4 +- br/pkg/restore/misc.go__failpoint_stash__ | 157 + .../binding__failpoint_binding__.go | 14 + br/pkg/restore/snap_client/client.go | 10 +- .../snap_client/client.go__failpoint_stash__ | 1200 +++ br/pkg/restore/snap_client/context_manager.go | 4 +- .../context_manager.go__failpoint_stash__ | 290 + br/pkg/restore/snap_client/import.go | 8 +- .../snap_client/import.go__failpoint_stash__ | 846 ++ .../split/binding__failpoint_binding__.go | 14 + br/pkg/restore/split/client.go | 26 +- .../split/client.go__failpoint_stash__ | 1067 +++ br/pkg/restore/split/split.go | 4 +- .../restore/split/split.go__failpoint_stash__ | 352 + .../storage/binding__failpoint_binding__.go | 14 + br/pkg/storage/s3.go | 8 +- br/pkg/storage/s3.go__failpoint_stash__ | 1208 +++ br/pkg/streamhelper/advancer.go | 12 +- .../advancer.go__failpoint_stash__ | 735 ++ br/pkg/streamhelper/advancer_cliext.go | 8 +- .../advancer_cliext.go__failpoint_stash__ | 301 + .../binding__failpoint_binding__.go | 14 + br/pkg/streamhelper/flush_subscriber.go | 6 +- 
.../flush_subscriber.go__failpoint_stash__ | 373 + br/pkg/task/backup.go | 8 +- br/pkg/task/backup.go__failpoint_stash__ | 835 ++ br/pkg/task/binding__failpoint_binding__.go | 14 + .../operator/binding__failpoint_binding__.go | 14 + br/pkg/task/operator/cmd.go | 6 +- .../task/operator/cmd.go__failpoint_stash__ | 251 + br/pkg/task/restore.go | 4 +- br/pkg/task/restore.go__failpoint_stash__ | 1713 ++++ br/pkg/task/stream.go | 12 +- br/pkg/task/stream.go__failpoint_stash__ | 1846 +++++ br/pkg/utils/backoff.go | 4 +- br/pkg/utils/backoff.go__failpoint_stash__ | 323 + br/pkg/utils/binding__failpoint_binding__.go | 14 + br/pkg/utils/pprof.go | 4 +- br/pkg/utils/pprof.go__failpoint_stash__ | 69 + br/pkg/utils/register.go | 16 +- br/pkg/utils/register.go__failpoint_stash__ | 334 + br/pkg/utils/store_manager.go | 4 +- .../utils/store_manager.go__failpoint_stash__ | 264 + .../export/binding__failpoint_binding__.go | 14 + dumpling/export/config.go | 4 +- dumpling/export/config.go__failpoint_stash__ | 809 ++ dumpling/export/dump.go | 26 +- dumpling/export/dump.go__failpoint_stash__ | 1704 ++++ dumpling/export/sql.go | 12 +- dumpling/export/sql.go__failpoint_stash__ | 1643 ++++ dumpling/export/status.go | 4 +- dumpling/export/status.go__failpoint_stash__ | 144 + dumpling/export/writer_util.go | 16 +- .../export/writer_util.go__failpoint_stash__ | 674 ++ .../importer/binding__failpoint_binding__.go | 14 + lightning/pkg/importer/chunk_process.go | 14 +- .../chunk_process.go__failpoint_stash__ | 778 ++ lightning/pkg/importer/get_pre_info.go | 35 +- .../get_pre_info.go__failpoint_stash__ | 835 ++ lightning/pkg/importer/import.go | 30 +- .../pkg/importer/import.go__failpoint_stash__ | 2080 +++++ lightning/pkg/importer/table_import.go | 40 +- .../table_import.go__failpoint_stash__ | 1822 +++++ .../server/binding__failpoint_binding__.go | 14 + lightning/pkg/server/lightning.go | 40 +- .../server/lightning.go__failpoint_stash__ | 1152 +++ pkg/autoid_service/autoid.go | 6 +- .../autoid.go__failpoint_stash__ | 612 ++ .../binding__failpoint_binding__.go | 14 + pkg/bindinfo/binding__failpoint_binding__.go | 14 + pkg/bindinfo/global_handle.go | 4 +- .../global_handle.go__failpoint_stash__ | 745 ++ pkg/ddl/add_column.go | 8 +- pkg/ddl/add_column.go__failpoint_stash__ | 1288 +++ pkg/ddl/backfilling.go | 26 +- pkg/ddl/backfilling.go__failpoint_stash__ | 1124 +++ pkg/ddl/backfilling_dist_scheduler.go | 22 +- ...lling_dist_scheduler.go__failpoint_stash__ | 639 ++ pkg/ddl/backfilling_operators.go | 34 +- ...ackfilling_operators.go__failpoint_stash__ | 962 +++ pkg/ddl/backfilling_read_index.go | 2 +- ...ckfilling_read_index.go__failpoint_stash__ | 315 + pkg/ddl/binding__failpoint_binding__.go | 14 + pkg/ddl/cluster.go | 16 +- pkg/ddl/cluster.go__failpoint_stash__ | 902 ++ pkg/ddl/column.go | 22 +- pkg/ddl/column.go__failpoint_stash__ | 1320 +++ pkg/ddl/constraint.go | 12 +- pkg/ddl/constraint.go__failpoint_stash__ | 432 + pkg/ddl/create_table.go | 14 +- pkg/ddl/create_table.go__failpoint_stash__ | 1527 ++++ pkg/ddl/ddl.go | 12 +- pkg/ddl/ddl.go__failpoint_stash__ | 1421 ++++ pkg/ddl/ddl_tiflash_api.go | 32 +- pkg/ddl/ddl_tiflash_api.go__failpoint_stash__ | 608 ++ pkg/ddl/delete_range.go | 4 +- pkg/ddl/delete_range.go__failpoint_stash__ | 548 ++ pkg/ddl/executor.go | 32 +- pkg/ddl/executor.go__failpoint_stash__ | 6540 +++++++++++++++ pkg/ddl/index.go | 72 +- pkg/ddl/index.go__failpoint_stash__ | 2616 ++++++ pkg/ddl/index_cop.go | 8 +- pkg/ddl/index_cop.go__failpoint_stash__ | 392 + pkg/ddl/index_merge_tmp.go | 4 +- 
pkg/ddl/index_merge_tmp.go__failpoint_stash__ | 400 + pkg/ddl/ingest/backend.go | 16 +- pkg/ddl/ingest/backend.go__failpoint_stash__ | 343 + pkg/ddl/ingest/backend_mgr.go | 4 +- .../ingest/backend_mgr.go__failpoint_stash__ | 285 + .../ingest/binding__failpoint_binding__.go | 14 + pkg/ddl/ingest/checkpoint.go | 16 +- .../ingest/checkpoint.go__failpoint_stash__ | 509 ++ pkg/ddl/ingest/disk_root.go | 6 +- .../ingest/disk_root.go__failpoint_stash__ | 157 + pkg/ddl/ingest/env.go | 4 +- pkg/ddl/ingest/env.go__failpoint_stash__ | 189 + pkg/ddl/ingest/mock.go | 2 +- pkg/ddl/ingest/mock.go__failpoint_stash__ | 236 + pkg/ddl/job_scheduler.go | 34 +- pkg/ddl/job_scheduler.go__failpoint_stash__ | 837 ++ pkg/ddl/job_submitter.go | 18 +- pkg/ddl/job_submitter.go__failpoint_stash__ | 669 ++ pkg/ddl/job_worker.go | 32 +- pkg/ddl/job_worker.go__failpoint_stash__ | 1013 +++ pkg/ddl/mock.go | 20 +- pkg/ddl/mock.go__failpoint_stash__ | 260 + pkg/ddl/modify_column.go | 18 +- pkg/ddl/modify_column.go__failpoint_stash__ | 1318 +++ pkg/ddl/partition.go | 54 +- pkg/ddl/partition.go__failpoint_stash__ | 4922 +++++++++++ .../placement/binding__failpoint_binding__.go | 14 + pkg/ddl/placement/bundle.go | 4 +- .../placement/bundle.go__failpoint_stash__ | 712 ++ pkg/ddl/reorg.go | 18 +- pkg/ddl/reorg.go__failpoint_stash__ | 982 +++ pkg/ddl/rollingback.go | 6 +- pkg/ddl/rollingback.go__failpoint_stash__ | 629 ++ pkg/ddl/schema_version.go | 8 +- pkg/ddl/schema_version.go__failpoint_stash__ | 418 + .../session/binding__failpoint_binding__.go | 14 + pkg/ddl/session/session.go | 4 +- pkg/ddl/session/session.go__failpoint_stash__ | 137 + .../syncer/binding__failpoint_binding__.go | 14 + pkg/ddl/syncer/syncer.go | 8 +- pkg/ddl/syncer/syncer.go__failpoint_stash__ | 629 ++ pkg/ddl/table.go | 36 +- pkg/ddl/table.go__failpoint_stash__ | 1681 ++++ pkg/ddl/util/binding__failpoint_binding__.go | 14 + pkg/ddl/util/util.go | 4 +- pkg/ddl/util/util.go__failpoint_stash__ | 427 + pkg/distsql/binding__failpoint_binding__.go | 14 + pkg/distsql/request_builder.go | 8 +- .../request_builder.go__failpoint_stash__ | 862 ++ pkg/distsql/select_result.go | 6 +- .../select_result.go__failpoint_stash__ | 815 ++ .../scheduler/binding__failpoint_binding__.go | 14 + pkg/disttask/framework/scheduler/nodes.go | 4 +- .../scheduler/nodes.go__failpoint_stash__ | 191 + pkg/disttask/framework/scheduler/scheduler.go | 32 +- .../scheduler/scheduler.go__failpoint_stash__ | 624 ++ .../framework/scheduler/scheduler_manager.go | 14 +- .../scheduler_manager.go__failpoint_stash__ | 485 ++ .../storage/binding__failpoint_binding__.go | 14 + pkg/disttask/framework/storage/history.go | 4 +- .../storage/history.go__failpoint_stash__ | 96 + pkg/disttask/framework/storage/task_table.go | 8 +- .../storage/task_table.go__failpoint_stash__ | 809 ++ .../binding__failpoint_binding__.go | 14 + .../framework/taskexecutor/task_executor.go | 38 +- .../task_executor.go__failpoint_stash__ | 683 ++ .../binding__failpoint_binding__.go | 14 + pkg/disttask/importinto/planner.go | 14 +- .../importinto/planner.go__failpoint_stash__ | 528 ++ pkg/disttask/importinto/scheduler.go | 20 +- .../scheduler.go__failpoint_stash__ | 719 ++ pkg/disttask/importinto/subtask_executor.go | 14 +- .../subtask_executor.go__failpoint_stash__ | 142 + pkg/disttask/importinto/task_executor.go | 4 +- .../task_executor.go__failpoint_stash__ | 577 ++ pkg/domain/binding__failpoint_binding__.go | 14 + pkg/domain/domain.go | 38 +- pkg/domain/domain.go__failpoint_stash__ | 3239 ++++++++ pkg/domain/historical_stats.go | 
4 +- .../historical_stats.go__failpoint_stash__ | 98 + .../infosync/binding__failpoint_binding__.go | 14 + pkg/domain/infosync/info.go | 32 +- .../infosync/info.go__failpoint_stash__ | 1457 ++++ pkg/domain/infosync/tiflash_manager.go | 4 +- .../tiflash_manager.go__failpoint_stash__ | 893 ++ pkg/domain/plan_replayer_dump.go | 4 +- .../plan_replayer_dump.go__failpoint_stash__ | 923 +++ pkg/domain/runaway.go | 12 +- pkg/domain/runaway.go__failpoint_stash__ | 659 ++ pkg/executor/adapter.go | 40 +- pkg/executor/adapter.go__failpoint_stash__ | 2240 +++++ pkg/executor/aggregate/agg_hash_executor.go | 34 +- .../agg_hash_executor.go__failpoint_stash__ | 857 ++ .../aggregate/agg_hash_final_worker.go | 12 +- ...gg_hash_final_worker.go__failpoint_stash__ | 283 + .../aggregate/agg_hash_partial_worker.go | 14 +- ..._hash_partial_worker.go__failpoint_stash__ | 390 + pkg/executor/aggregate/agg_stream_executor.go | 14 +- .../agg_stream_executor.go__failpoint_stash__ | 237 + pkg/executor/aggregate/agg_util.go | 4 +- .../aggregate/agg_util.go__failpoint_stash__ | 312 + .../aggregate/binding__failpoint_binding__.go | 14 + pkg/executor/analyze.go | 12 +- pkg/executor/analyze.go__failpoint_stash__ | 619 ++ pkg/executor/analyze_col.go | 8 +- .../analyze_col.go__failpoint_stash__ | 494 ++ pkg/executor/analyze_col_v2.go | 20 +- .../analyze_col_v2.go__failpoint_stash__ | 885 ++ pkg/executor/analyze_idx.go | 14 +- .../analyze_idx.go__failpoint_stash__ | 344 + pkg/executor/batch_point_get.go | 6 +- .../batch_point_get.go__failpoint_stash__ | 528 ++ pkg/executor/binding__failpoint_binding__.go | 14 + pkg/executor/brie.go | 10 +- pkg/executor/brie.go__failpoint_stash__ | 826 ++ pkg/executor/builder.go | 42 +- pkg/executor/builder.go__failpoint_stash__ | 5659 +++++++++++++ pkg/executor/checksum.go | 2 +- pkg/executor/checksum.go__failpoint_stash__ | 336 + pkg/executor/compiler.go | 8 +- pkg/executor/compiler.go__failpoint_stash__ | 568 ++ pkg/executor/cte.go | 20 +- pkg/executor/cte.go__failpoint_stash__ | 770 ++ pkg/executor/executor.go | 6 +- pkg/executor/executor.go__failpoint_stash__ | 2673 ++++++ pkg/executor/import_into.go | 2 +- .../import_into.go__failpoint_stash__ | 344 + .../importer/binding__failpoint_binding__.go | 14 + pkg/executor/importer/job.go | 4 +- .../importer/job.go__failpoint_stash__ | 370 + pkg/executor/importer/table_import.go | 16 +- .../table_import.go__failpoint_stash__ | 983 +++ pkg/executor/index_merge_reader.go | 72 +- .../index_merge_reader.go__failpoint_stash__ | 2056 +++++ pkg/executor/infoschema_reader.go | 4 +- .../infoschema_reader.go__failpoint_stash__ | 3878 +++++++++ pkg/executor/inspection_result.go | 4 +- .../inspection_result.go__failpoint_stash__ | 1248 +++ .../binding__failpoint_binding__.go | 14 + .../calibrateresource/calibrate_resource.go | 12 +- .../calibrate_resource.go__failpoint_stash__ | 704 ++ .../exec/binding__failpoint_binding__.go | 14 + pkg/executor/internal/exec/executor.go | 12 +- .../exec/executor.go__failpoint_stash__ | 468 ++ .../mpp/binding__failpoint_binding__.go | 14 + .../internal/mpp/executor_with_retry.go | 16 +- .../executor_with_retry.go__failpoint_stash__ | 254 + .../internal/mpp/local_mpp_coordinator.go | 18 +- ...ocal_mpp_coordinator.go__failpoint_stash__ | 772 ++ .../pdhelper/binding__failpoint_binding__.go | 14 + pkg/executor/internal/pdhelper/pd.go | 4 +- .../pdhelper/pd.go__failpoint_stash__ | 128 + pkg/executor/join/base_join_probe.go | 4 +- .../base_join_probe.go__failpoint_stash__ | 593 ++ .../join/binding__failpoint_binding__.go | 14 + 
pkg/executor/join/hash_join_base.go | 26 +- .../join/hash_join_base.go__failpoint_stash__ | 379 + pkg/executor/join/hash_join_v1.go | 16 +- .../join/hash_join_v1.go__failpoint_stash__ | 1434 ++++ pkg/executor/join/hash_join_v2.go | 12 +- .../join/hash_join_v2.go__failpoint_stash__ | 943 +++ pkg/executor/join/index_lookup_hash_join.go | 34 +- ...dex_lookup_hash_join.go__failpoint_stash__ | 884 ++ pkg/executor/join/index_lookup_join.go | 14 +- .../index_lookup_join.go__failpoint_stash__ | 882 ++ pkg/executor/join/index_lookup_merge_join.go | 6 +- ...ex_lookup_merge_join.go__failpoint_stash__ | 743 ++ pkg/executor/join/merge_join.go | 10 +- .../join/merge_join.go__failpoint_stash__ | 420 + pkg/executor/load_data.go | 8 +- pkg/executor/load_data.go__failpoint_stash__ | 780 ++ pkg/executor/memtable_reader.go | 8 +- .../memtable_reader.go__failpoint_stash__ | 1009 +++ pkg/executor/metrics_reader.go | 12 +- .../metrics_reader.go__failpoint_stash__ | 365 + pkg/executor/parallel_apply.go | 8 +- .../parallel_apply.go__failpoint_stash__ | 405 + pkg/executor/point_get.go | 10 +- pkg/executor/point_get.go__failpoint_stash__ | 824 ++ pkg/executor/projection.go | 16 +- pkg/executor/projection.go__failpoint_stash__ | 501 ++ pkg/executor/shuffle.go | 12 +- pkg/executor/shuffle.go__failpoint_stash__ | 492 ++ pkg/executor/slow_query.go | 8 +- pkg/executor/slow_query.go__failpoint_stash__ | 1259 +++ .../sortexec/binding__failpoint_binding__.go | 14 + pkg/executor/sortexec/parallel_sort_worker.go | 8 +- ...parallel_sort_worker.go__failpoint_stash__ | 229 + pkg/executor/sortexec/sort.go | 16 +- .../sortexec/sort.go__failpoint_stash__ | 845 ++ pkg/executor/sortexec/sort_partition.go | 8 +- .../sort_partition.go__failpoint_stash__ | 367 + pkg/executor/sortexec/sort_util.go | 4 +- .../sortexec/sort_util.go__failpoint_stash__ | 124 + pkg/executor/sortexec/topn.go | 4 +- .../sortexec/topn.go__failpoint_stash__ | 647 ++ pkg/executor/sortexec/topn_worker.go | 4 +- .../topn_worker.go__failpoint_stash__ | 130 + pkg/executor/table_reader.go | 4 +- .../table_reader.go__failpoint_stash__ | 632 ++ .../unionexec/binding__failpoint_binding__.go | 14 + pkg/executor/unionexec/union.go | 12 +- .../unionexec/union.go__failpoint_stash__ | 232 + .../binding__failpoint_binding__.go | 14 + pkg/expression/aggregation/explain.go | 4 +- .../aggregation/explain.go__failpoint_stash__ | 80 + .../binding__failpoint_binding__.go | 14 + pkg/expression/builtin_json.go | 6 +- .../builtin_json.go__failpoint_stash__ | 1940 +++++ pkg/expression/builtin_time.go | 8 +- .../builtin_time.go__failpoint_stash__ | 6832 ++++++++++++++++ pkg/expression/expr_to_pb.go | 4 +- .../expr_to_pb.go__failpoint_stash__ | 319 + pkg/expression/helper.go | 6 +- pkg/expression/helper.go__failpoint_stash__ | 170 + pkg/expression/infer_pushdown.go | 14 +- .../infer_pushdown.go__failpoint_stash__ | 536 ++ pkg/expression/util.go | 6 +- pkg/expression/util.go__failpoint_stash__ | 1921 +++++ .../binding__failpoint_binding__.go | 14 + pkg/infoschema/builder.go | 6 +- pkg/infoschema/builder.go__failpoint_stash__ | 1040 +++ pkg/infoschema/infoschema_v2.go | 6 +- .../infoschema_v2.go__failpoint_stash__ | 1456 ++++ .../binding__failpoint_binding__.go | 14 + pkg/infoschema/perfschema/tables.go | 4 +- .../perfschema/tables.go__failpoint_stash__ | 415 + pkg/infoschema/sieve.go | 6 +- pkg/infoschema/sieve.go__failpoint_stash__ | 272 + pkg/infoschema/tables.go | 22 +- pkg/infoschema/tables.go__failpoint_stash__ | 2694 ++++++ pkg/kv/binding__failpoint_binding__.go | 14 + pkg/kv/txn.go | 
10 +- pkg/kv/txn.go__failpoint_stash__ | 247 + pkg/lightning/backend/backend.go | 4 +- .../backend/backend.go__failpoint_stash__ | 439 + .../backend/binding__failpoint_binding__.go | 14 + .../external/binding__failpoint_binding__.go | 14 + pkg/lightning/backend/external/byte_reader.go | 4 +- .../byte_reader.go__failpoint_stash__ | 351 + pkg/lightning/backend/external/engine.go | 4 +- .../external/engine.go__failpoint_stash__ | 732 ++ pkg/lightning/backend/external/merge_v2.go | 4 +- .../external/merge_v2.go__failpoint_stash__ | 183 + .../local/binding__failpoint_binding__.go | 14 + pkg/lightning/backend/local/checksum.go | 2 +- .../local/checksum.go__failpoint_stash__ | 517 ++ pkg/lightning/backend/local/engine.go | 12 +- .../local/engine.go__failpoint_stash__ | 1682 ++++ pkg/lightning/backend/local/engine_mgr.go | 4 +- .../local/engine_mgr.go__failpoint_stash__ | 658 ++ pkg/lightning/backend/local/local.go | 50 +- .../backend/local/local.go__failpoint_stash__ | 1754 ++++ pkg/lightning/backend/local/local_unix.go | 8 +- .../local/local_unix.go__failpoint_stash__ | 92 + pkg/lightning/backend/local/region_job.go | 36 +- .../local/region_job.go__failpoint_stash__ | 907 ++ .../tidb/binding__failpoint_binding__.go | 14 + pkg/lightning/backend/tidb/tidb.go | 15 +- .../backend/tidb/tidb.go__failpoint_stash__ | 956 +++ .../common/binding__failpoint_binding__.go | 14 + pkg/lightning/common/storage_unix.go | 6 +- .../common/storage_unix.go__failpoint_stash__ | 79 + pkg/lightning/common/storage_windows.go | 6 +- .../storage_windows.go__failpoint_stash__ | 56 + pkg/lightning/common/util.go | 8 +- .../common/util.go__failpoint_stash__ | 704 ++ .../mydump/binding__failpoint_binding__.go | 14 + pkg/lightning/mydump/loader.go | 8 +- .../mydump/loader.go__failpoint_stash__ | 868 ++ pkg/meta/autoid/autoid.go | 16 +- pkg/meta/autoid/autoid.go__failpoint_stash__ | 1351 +++ .../autoid/binding__failpoint_binding__.go | 14 + pkg/owner/binding__failpoint_binding__.go | 14 + pkg/owner/manager.go | 10 +- pkg/owner/manager.go__failpoint_stash__ | 486 ++ pkg/owner/mock.go | 6 +- pkg/owner/mock.go__failpoint_stash__ | 230 + .../ast/binding__failpoint_binding__.go | 14 + pkg/parser/ast/misc.go | 4 +- pkg/parser/ast/misc.go__failpoint_stash__ | 4209 ++++++++++ pkg/planner/binding__failpoint_binding__.go | 14 + .../binding__failpoint_binding__.go | 14 + pkg/planner/cardinality/row_count_index.go | 4 +- .../row_count_index.go__failpoint_stash__ | 568 ++ pkg/planner/cardinality/selectivity_test.go | 18 +- .../core/binding__failpoint_binding__.go | 14 + .../core/collect_column_stats_usage.go | 4 +- ...t_column_stats_usage.go__failpoint_stash__ | 456 ++ pkg/planner/core/debugtrace.go | 4 +- .../core/debugtrace.go__failpoint_stash__ | 261 + pkg/planner/core/encode.go | 8 +- pkg/planner/core/encode.go__failpoint_stash__ | 386 + pkg/planner/core/exhaust_physical_plans.go | 16 +- ...haust_physical_plans.go__failpoint_stash__ | 3004 +++++++ pkg/planner/core/find_best_task.go | 6 +- .../core/find_best_task.go__failpoint_stash__ | 2982 +++++++ pkg/planner/core/logical_join.go | 6 +- .../core/logical_join.go__failpoint_stash__ | 1672 ++++ pkg/planner/core/logical_plan_builder.go | 4 +- ...logical_plan_builder.go__failpoint_stash__ | 7284 +++++++++++++++++ pkg/planner/core/optimizer.go | 8 +- .../core/optimizer.go__failpoint_stash__ | 1222 +++ pkg/planner/core/rule_collect_plan_stats.go | 4 +- ...e_collect_plan_stats.go__failpoint_stash__ | 316 + pkg/planner/core/rule_eliminate_projection.go | 6 +- 
...eliminate_projection.go__failpoint_stash__ | 274 + .../core/rule_inject_extra_projection.go | 6 +- ...ect_extra_projection.go__failpoint_stash__ | 352 + pkg/planner/core/task.go | 4 +- pkg/planner/core/task.go__failpoint_stash__ | 2473 ++++++ pkg/planner/optimize.go | 18 +- pkg/planner/optimize.go__failpoint_stash__ | 631 ++ pkg/server/binding__failpoint_binding__.go | 14 + pkg/server/conn.go | 38 +- pkg/server/conn.go__failpoint_stash__ | 2748 +++++++ pkg/server/conn_stmt.go | 10 +- pkg/server/conn_stmt.go__failpoint_stash__ | 673 ++ .../handler/binding__failpoint_binding__.go | 14 + .../binding__failpoint_binding__.go | 14 + .../handler/extractorhandler/extractor.go | 6 +- .../extractor.go__failpoint_stash__ | 169 + pkg/server/handler/tikv_handler.go | 4 +- .../tikv_handler.go__failpoint_stash__ | 279 + pkg/server/http_status.go | 4 +- pkg/server/http_status.go__failpoint_stash__ | 613 ++ pkg/session/binding__failpoint_binding__.go | 14 + pkg/session/nontransactional.go | 4 +- .../nontransactional.go__failpoint_stash__ | 847 ++ pkg/session/session.go | 36 +- pkg/session/session.go__failpoint_stash__ | 4611 +++++++++++ pkg/session/sync_upgrade.go | 6 +- .../sync_upgrade.go__failpoint_stash__ | 163 + pkg/session/tidb.go | 6 +- pkg/session/tidb.go__failpoint_stash__ | 403 + pkg/session/txn.go | 26 +- pkg/session/txn.go__failpoint_stash__ | 778 ++ pkg/session/txnmanager.go | 4 +- pkg/session/txnmanager.go__failpoint_stash__ | 381 + .../binding__failpoint_binding__.go | 14 + pkg/sessionctx/sessionstates/session_token.go | 4 +- .../session_token.go__failpoint_stash__ | 336 + pkg/sessiontxn/isolation/base.go | 4 +- .../isolation/base.go__failpoint_stash__ | 747 ++ .../isolation/binding__failpoint_binding__.go | 14 + pkg/sessiontxn/isolation/readcommitted.go | 12 +- .../readcommitted.go__failpoint_stash__ | 360 + pkg/sessiontxn/isolation/repeatable_read.go | 4 +- .../repeatable_read.go__failpoint_stash__ | 309 + .../staleread/binding__failpoint_binding__.go | 14 + pkg/sessiontxn/staleread/util.go | 4 +- .../staleread/util.go__failpoint_stash__ | 97 + .../binding__failpoint_binding__.go | 14 + pkg/statistics/cmsketch.go | 10 +- pkg/statistics/cmsketch.go__failpoint_stash__ | 865 ++ .../handle/autoanalyze/autoanalyze.go | 16 +- .../autoanalyze.go__failpoint_stash__ | 843 ++ .../binding__failpoint_binding__.go | 14 + .../handle/binding__failpoint_binding__.go | 14 + pkg/statistics/handle/bootstrap.go | 4 +- .../handle/bootstrap.go__failpoint_stash__ | 815 ++ .../cache/binding__failpoint_binding__.go | 14 + pkg/statistics/handle/cache/statscache.go | 6 +- .../cache/statscache.go__failpoint_stash__ | 287 + .../binding__failpoint_binding__.go | 14 + .../handle/globalstats/global_stats_async.go | 36 +- .../global_stats_async.go__failpoint_stash__ | 542 ++ .../storage/binding__failpoint_binding__.go | 14 + pkg/statistics/handle/storage/read.go | 6 +- .../handle/storage/read.go__failpoint_stash__ | 759 ++ .../syncload/binding__failpoint_binding__.go | 14 + .../handle/syncload/stats_syncload.go | 12 +- .../stats_syncload.go__failpoint_stash__ | 574 ++ pkg/store/copr/batch_coprocessor.go | 12 +- .../batch_coprocessor.go__failpoint_stash__ | 1588 ++++ .../copr/binding__failpoint_binding__.go | 14 + pkg/store/copr/coprocessor.go | 76 +- .../copr/coprocessor.go__failpoint_stash__ | 2170 +++++ pkg/store/copr/mpp.go | 16 +- pkg/store/copr/mpp.go__failpoint_stash__ | 346 + .../txn/binding__failpoint_binding__.go | 14 + pkg/store/driver/txn/binlog.go | 4 +- .../driver/txn/binlog.go__failpoint_stash__ | 90 + 
pkg/store/driver/txn/txn_driver.go | 2 +- .../txn/txn_driver.go__failpoint_stash__ | 491 ++ .../gcworker/binding__failpoint_binding__.go | 14 + pkg/store/gcworker/gc_worker.go | 22 +- .../gcworker/gc_worker.go__failpoint_stash__ | 1759 ++++ .../unistore/binding__failpoint_binding__.go | 14 + .../binding__failpoint_binding__.go | 14 + .../unistore/cophandler/cop_handler.go | 6 +- .../cop_handler.go__failpoint_stash__ | 674 ++ pkg/store/mockstore/unistore/rpc.go | 150 +- .../unistore/rpc.go__failpoint_stash__ | 582 ++ .../binding__failpoint_binding__.go | 14 + pkg/table/contextimpl/table.go | 6 +- .../contextimpl/table.go__failpoint_stash__ | 200 + .../tables/binding__failpoint_binding__.go | 14 + pkg/table/tables/cache.go | 12 +- pkg/table/tables/cache.go__failpoint_stash__ | 355 + pkg/table/tables/mutation_checker.go | 6 +- .../mutation_checker.go__failpoint_stash__ | 556 ++ pkg/table/tables/tables.go | 18 +- pkg/table/tables/tables.go__failpoint_stash__ | 2100 +++++ .../ttlworker/binding__failpoint_binding__.go | 14 + pkg/ttl/ttlworker/config.go | 20 +- .../ttlworker/config.go__failpoint_stash__ | 70 + pkg/util/binding__failpoint_binding__.go | 14 + .../binding__failpoint_binding__.go | 14 + pkg/util/breakpoint/breakpoint.go | 4 +- .../breakpoint.go__failpoint_stash__ | 34 + .../cgroup/binding__failpoint_binding__.go | 14 + pkg/util/cgroup/cgroup_cpu_linux.go | 6 +- .../cgroup_cpu_linux.go__failpoint_stash__ | 100 + pkg/util/cgroup/cgroup_cpu_unsupport.go | 6 +- ...cgroup_cpu_unsupport.go__failpoint_stash__ | 55 + .../chunk/binding__failpoint_binding__.go | 14 + pkg/util/chunk/chunk_in_disk.go | 4 +- .../chunk/chunk_in_disk.go__failpoint_stash__ | 344 + pkg/util/chunk/row_container.go | 16 +- .../chunk/row_container.go__failpoint_stash__ | 691 ++ pkg/util/chunk/row_container_reader.go | 4 +- ...row_container_reader.go__failpoint_stash__ | 170 + .../codec/binding__failpoint_binding__.go | 14 + pkg/util/codec/decimal.go | 6 +- pkg/util/codec/decimal.go__failpoint_stash__ | 69 + pkg/util/cpu/binding__failpoint_binding__.go | 14 + pkg/util/cpu/cpu.go | 6 +- pkg/util/cpu/cpu.go__failpoint_stash__ | 132 + pkg/util/etcd.go | 12 +- pkg/util/etcd.go__failpoint_stash__ | 103 + .../gctuner/binding__failpoint_binding__.go | 14 + pkg/util/gctuner/memory_limit_tuner.go | 8 +- .../memory_limit_tuner.go__failpoint_stash__ | 190 + .../memory/binding__failpoint_binding__.go | 14 + pkg/util/memory/meminfo.go | 4 +- pkg/util/memory/meminfo.go__failpoint_stash__ | 215 + pkg/util/memory/memstats.go | 4 +- .../memory/memstats.go__failpoint_stash__ | 57 + .../replayer/binding__failpoint_binding__.go | 14 + pkg/util/replayer/replayer.go | 4 +- .../replayer/replayer.go__failpoint_stash__ | 87 + .../binding__failpoint_binding__.go | 14 + .../servermemorylimit/servermemorylimit.go | 4 +- .../servermemorylimit.go__failpoint_stash__ | 264 + pkg/util/session_pool.go | 4 +- pkg/util/session_pool.go__failpoint_stash__ | 113 + pkg/util/sli/binding__failpoint_binding__.go | 14 + pkg/util/sli/sli.go | 6 +- pkg/util/sli/sli.go__failpoint_stash__ | 120 + .../sqlkiller/binding__failpoint_binding__.go | 14 + pkg/util/sqlkiller/sqlkiller.go | 4 +- .../sqlkiller/sqlkiller.go__failpoint_stash__ | 136 + .../binding__failpoint_binding__.go | 14 + pkg/util/stmtsummary/statement_summary.go | 4 +- .../statement_summary.go__failpoint_stash__ | 1039 +++ .../topsql/binding__failpoint_binding__.go | 14 + .../reporter/binding__failpoint_binding__.go | 14 + pkg/util/topsql/reporter/pubsub.go | 2 +- .../reporter/pubsub.go__failpoint_stash__ | 
274 + pkg/util/topsql/reporter/reporter.go | 4 +- .../reporter/reporter.go__failpoint_stash__ | 333 + pkg/util/topsql/topsql.go | 8 +- pkg/util/topsql/topsql.go__failpoint_stash__ | 187 + 592 files changed, 218916 insertions(+), 1622 deletions(-) create mode 100644 br/pkg/backup/binding__failpoint_binding__.go create mode 100644 br/pkg/backup/prepare_snap/binding__failpoint_binding__.go create mode 100644 br/pkg/backup/prepare_snap/prepare.go__failpoint_stash__ create mode 100644 br/pkg/backup/store.go__failpoint_stash__ create mode 100644 br/pkg/checkpoint/binding__failpoint_binding__.go create mode 100644 br/pkg/checkpoint/checkpoint.go__failpoint_stash__ create mode 100644 br/pkg/checksum/binding__failpoint_binding__.go create mode 100644 br/pkg/checksum/executor.go__failpoint_stash__ create mode 100644 br/pkg/conn/binding__failpoint_binding__.go create mode 100644 br/pkg/conn/conn.go__failpoint_stash__ create mode 100644 br/pkg/pdutil/binding__failpoint_binding__.go create mode 100644 br/pkg/pdutil/pd.go__failpoint_stash__ create mode 100644 br/pkg/restore/binding__failpoint_binding__.go create mode 100644 br/pkg/restore/log_client/binding__failpoint_binding__.go create mode 100644 br/pkg/restore/log_client/client.go__failpoint_stash__ create mode 100644 br/pkg/restore/log_client/import_retry.go__failpoint_stash__ create mode 100644 br/pkg/restore/misc.go__failpoint_stash__ create mode 100644 br/pkg/restore/snap_client/binding__failpoint_binding__.go create mode 100644 br/pkg/restore/snap_client/client.go__failpoint_stash__ create mode 100644 br/pkg/restore/snap_client/context_manager.go__failpoint_stash__ create mode 100644 br/pkg/restore/snap_client/import.go__failpoint_stash__ create mode 100644 br/pkg/restore/split/binding__failpoint_binding__.go create mode 100644 br/pkg/restore/split/client.go__failpoint_stash__ create mode 100644 br/pkg/restore/split/split.go__failpoint_stash__ create mode 100644 br/pkg/storage/binding__failpoint_binding__.go create mode 100644 br/pkg/storage/s3.go__failpoint_stash__ create mode 100644 br/pkg/streamhelper/advancer.go__failpoint_stash__ create mode 100644 br/pkg/streamhelper/advancer_cliext.go__failpoint_stash__ create mode 100644 br/pkg/streamhelper/binding__failpoint_binding__.go create mode 100644 br/pkg/streamhelper/flush_subscriber.go__failpoint_stash__ create mode 100644 br/pkg/task/backup.go__failpoint_stash__ create mode 100644 br/pkg/task/binding__failpoint_binding__.go create mode 100644 br/pkg/task/operator/binding__failpoint_binding__.go create mode 100644 br/pkg/task/operator/cmd.go__failpoint_stash__ create mode 100644 br/pkg/task/restore.go__failpoint_stash__ create mode 100644 br/pkg/task/stream.go__failpoint_stash__ create mode 100644 br/pkg/utils/backoff.go__failpoint_stash__ create mode 100644 br/pkg/utils/binding__failpoint_binding__.go create mode 100644 br/pkg/utils/pprof.go__failpoint_stash__ create mode 100644 br/pkg/utils/register.go__failpoint_stash__ create mode 100644 br/pkg/utils/store_manager.go__failpoint_stash__ create mode 100644 dumpling/export/binding__failpoint_binding__.go create mode 100644 dumpling/export/config.go__failpoint_stash__ create mode 100644 dumpling/export/dump.go__failpoint_stash__ create mode 100644 dumpling/export/sql.go__failpoint_stash__ create mode 100644 dumpling/export/status.go__failpoint_stash__ create mode 100644 dumpling/export/writer_util.go__failpoint_stash__ create mode 100644 lightning/pkg/importer/binding__failpoint_binding__.go create mode 100644 
lightning/pkg/importer/chunk_process.go__failpoint_stash__ create mode 100644 lightning/pkg/importer/get_pre_info.go__failpoint_stash__ create mode 100644 lightning/pkg/importer/import.go__failpoint_stash__ create mode 100644 lightning/pkg/importer/table_import.go__failpoint_stash__ create mode 100644 lightning/pkg/server/binding__failpoint_binding__.go create mode 100644 lightning/pkg/server/lightning.go__failpoint_stash__ create mode 100644 pkg/autoid_service/autoid.go__failpoint_stash__ create mode 100644 pkg/autoid_service/binding__failpoint_binding__.go create mode 100644 pkg/bindinfo/binding__failpoint_binding__.go create mode 100644 pkg/bindinfo/global_handle.go__failpoint_stash__ create mode 100644 pkg/ddl/add_column.go__failpoint_stash__ create mode 100644 pkg/ddl/backfilling.go__failpoint_stash__ create mode 100644 pkg/ddl/backfilling_dist_scheduler.go__failpoint_stash__ create mode 100644 pkg/ddl/backfilling_operators.go__failpoint_stash__ create mode 100644 pkg/ddl/backfilling_read_index.go__failpoint_stash__ create mode 100644 pkg/ddl/binding__failpoint_binding__.go create mode 100644 pkg/ddl/cluster.go__failpoint_stash__ create mode 100644 pkg/ddl/column.go__failpoint_stash__ create mode 100644 pkg/ddl/constraint.go__failpoint_stash__ create mode 100644 pkg/ddl/create_table.go__failpoint_stash__ create mode 100644 pkg/ddl/ddl.go__failpoint_stash__ create mode 100644 pkg/ddl/ddl_tiflash_api.go__failpoint_stash__ create mode 100644 pkg/ddl/delete_range.go__failpoint_stash__ create mode 100644 pkg/ddl/executor.go__failpoint_stash__ create mode 100644 pkg/ddl/index.go__failpoint_stash__ create mode 100644 pkg/ddl/index_cop.go__failpoint_stash__ create mode 100644 pkg/ddl/index_merge_tmp.go__failpoint_stash__ create mode 100644 pkg/ddl/ingest/backend.go__failpoint_stash__ create mode 100644 pkg/ddl/ingest/backend_mgr.go__failpoint_stash__ create mode 100644 pkg/ddl/ingest/binding__failpoint_binding__.go create mode 100644 pkg/ddl/ingest/checkpoint.go__failpoint_stash__ create mode 100644 pkg/ddl/ingest/disk_root.go__failpoint_stash__ create mode 100644 pkg/ddl/ingest/env.go__failpoint_stash__ create mode 100644 pkg/ddl/ingest/mock.go__failpoint_stash__ create mode 100644 pkg/ddl/job_scheduler.go__failpoint_stash__ create mode 100644 pkg/ddl/job_submitter.go__failpoint_stash__ create mode 100644 pkg/ddl/job_worker.go__failpoint_stash__ create mode 100644 pkg/ddl/mock.go__failpoint_stash__ create mode 100644 pkg/ddl/modify_column.go__failpoint_stash__ create mode 100644 pkg/ddl/partition.go__failpoint_stash__ create mode 100644 pkg/ddl/placement/binding__failpoint_binding__.go create mode 100644 pkg/ddl/placement/bundle.go__failpoint_stash__ create mode 100644 pkg/ddl/reorg.go__failpoint_stash__ create mode 100644 pkg/ddl/rollingback.go__failpoint_stash__ create mode 100644 pkg/ddl/schema_version.go__failpoint_stash__ create mode 100644 pkg/ddl/session/binding__failpoint_binding__.go create mode 100644 pkg/ddl/session/session.go__failpoint_stash__ create mode 100644 pkg/ddl/syncer/binding__failpoint_binding__.go create mode 100644 pkg/ddl/syncer/syncer.go__failpoint_stash__ create mode 100644 pkg/ddl/table.go__failpoint_stash__ create mode 100644 pkg/ddl/util/binding__failpoint_binding__.go create mode 100644 pkg/ddl/util/util.go__failpoint_stash__ create mode 100644 pkg/distsql/binding__failpoint_binding__.go create mode 100644 pkg/distsql/request_builder.go__failpoint_stash__ create mode 100644 pkg/distsql/select_result.go__failpoint_stash__ create mode 100644 
pkg/disttask/framework/scheduler/binding__failpoint_binding__.go create mode 100644 pkg/disttask/framework/scheduler/nodes.go__failpoint_stash__ create mode 100644 pkg/disttask/framework/scheduler/scheduler.go__failpoint_stash__ create mode 100644 pkg/disttask/framework/scheduler/scheduler_manager.go__failpoint_stash__ create mode 100644 pkg/disttask/framework/storage/binding__failpoint_binding__.go create mode 100644 pkg/disttask/framework/storage/history.go__failpoint_stash__ create mode 100644 pkg/disttask/framework/storage/task_table.go__failpoint_stash__ create mode 100644 pkg/disttask/framework/taskexecutor/binding__failpoint_binding__.go create mode 100644 pkg/disttask/framework/taskexecutor/task_executor.go__failpoint_stash__ create mode 100644 pkg/disttask/importinto/binding__failpoint_binding__.go create mode 100644 pkg/disttask/importinto/planner.go__failpoint_stash__ create mode 100644 pkg/disttask/importinto/scheduler.go__failpoint_stash__ create mode 100644 pkg/disttask/importinto/subtask_executor.go__failpoint_stash__ create mode 100644 pkg/disttask/importinto/task_executor.go__failpoint_stash__ create mode 100644 pkg/domain/binding__failpoint_binding__.go create mode 100644 pkg/domain/domain.go__failpoint_stash__ create mode 100644 pkg/domain/historical_stats.go__failpoint_stash__ create mode 100644 pkg/domain/infosync/binding__failpoint_binding__.go create mode 100644 pkg/domain/infosync/info.go__failpoint_stash__ create mode 100644 pkg/domain/infosync/tiflash_manager.go__failpoint_stash__ create mode 100644 pkg/domain/plan_replayer_dump.go__failpoint_stash__ create mode 100644 pkg/domain/runaway.go__failpoint_stash__ create mode 100644 pkg/executor/adapter.go__failpoint_stash__ create mode 100644 pkg/executor/aggregate/agg_hash_executor.go__failpoint_stash__ create mode 100644 pkg/executor/aggregate/agg_hash_final_worker.go__failpoint_stash__ create mode 100644 pkg/executor/aggregate/agg_hash_partial_worker.go__failpoint_stash__ create mode 100644 pkg/executor/aggregate/agg_stream_executor.go__failpoint_stash__ create mode 100644 pkg/executor/aggregate/agg_util.go__failpoint_stash__ create mode 100644 pkg/executor/aggregate/binding__failpoint_binding__.go create mode 100644 pkg/executor/analyze.go__failpoint_stash__ create mode 100644 pkg/executor/analyze_col.go__failpoint_stash__ create mode 100644 pkg/executor/analyze_col_v2.go__failpoint_stash__ create mode 100644 pkg/executor/analyze_idx.go__failpoint_stash__ create mode 100644 pkg/executor/batch_point_get.go__failpoint_stash__ create mode 100644 pkg/executor/binding__failpoint_binding__.go create mode 100644 pkg/executor/brie.go__failpoint_stash__ create mode 100644 pkg/executor/builder.go__failpoint_stash__ create mode 100644 pkg/executor/checksum.go__failpoint_stash__ create mode 100644 pkg/executor/compiler.go__failpoint_stash__ create mode 100644 pkg/executor/cte.go__failpoint_stash__ create mode 100644 pkg/executor/executor.go__failpoint_stash__ create mode 100644 pkg/executor/import_into.go__failpoint_stash__ create mode 100644 pkg/executor/importer/binding__failpoint_binding__.go create mode 100644 pkg/executor/importer/job.go__failpoint_stash__ create mode 100644 pkg/executor/importer/table_import.go__failpoint_stash__ create mode 100644 pkg/executor/index_merge_reader.go__failpoint_stash__ create mode 100644 pkg/executor/infoschema_reader.go__failpoint_stash__ create mode 100644 pkg/executor/inspection_result.go__failpoint_stash__ create mode 100644 
pkg/executor/internal/calibrateresource/binding__failpoint_binding__.go create mode 100644 pkg/executor/internal/calibrateresource/calibrate_resource.go__failpoint_stash__ create mode 100644 pkg/executor/internal/exec/binding__failpoint_binding__.go create mode 100644 pkg/executor/internal/exec/executor.go__failpoint_stash__ create mode 100644 pkg/executor/internal/mpp/binding__failpoint_binding__.go create mode 100644 pkg/executor/internal/mpp/executor_with_retry.go__failpoint_stash__ create mode 100644 pkg/executor/internal/mpp/local_mpp_coordinator.go__failpoint_stash__ create mode 100644 pkg/executor/internal/pdhelper/binding__failpoint_binding__.go create mode 100644 pkg/executor/internal/pdhelper/pd.go__failpoint_stash__ create mode 100644 pkg/executor/join/base_join_probe.go__failpoint_stash__ create mode 100644 pkg/executor/join/binding__failpoint_binding__.go create mode 100644 pkg/executor/join/hash_join_base.go__failpoint_stash__ create mode 100644 pkg/executor/join/hash_join_v1.go__failpoint_stash__ create mode 100644 pkg/executor/join/hash_join_v2.go__failpoint_stash__ create mode 100644 pkg/executor/join/index_lookup_hash_join.go__failpoint_stash__ create mode 100644 pkg/executor/join/index_lookup_join.go__failpoint_stash__ create mode 100644 pkg/executor/join/index_lookup_merge_join.go__failpoint_stash__ create mode 100644 pkg/executor/join/merge_join.go__failpoint_stash__ create mode 100644 pkg/executor/load_data.go__failpoint_stash__ create mode 100644 pkg/executor/memtable_reader.go__failpoint_stash__ create mode 100644 pkg/executor/metrics_reader.go__failpoint_stash__ create mode 100644 pkg/executor/parallel_apply.go__failpoint_stash__ create mode 100644 pkg/executor/point_get.go__failpoint_stash__ create mode 100644 pkg/executor/projection.go__failpoint_stash__ create mode 100644 pkg/executor/shuffle.go__failpoint_stash__ create mode 100644 pkg/executor/slow_query.go__failpoint_stash__ create mode 100644 pkg/executor/sortexec/binding__failpoint_binding__.go create mode 100644 pkg/executor/sortexec/parallel_sort_worker.go__failpoint_stash__ create mode 100644 pkg/executor/sortexec/sort.go__failpoint_stash__ create mode 100644 pkg/executor/sortexec/sort_partition.go__failpoint_stash__ create mode 100644 pkg/executor/sortexec/sort_util.go__failpoint_stash__ create mode 100644 pkg/executor/sortexec/topn.go__failpoint_stash__ create mode 100644 pkg/executor/sortexec/topn_worker.go__failpoint_stash__ create mode 100644 pkg/executor/table_reader.go__failpoint_stash__ create mode 100644 pkg/executor/unionexec/binding__failpoint_binding__.go create mode 100644 pkg/executor/unionexec/union.go__failpoint_stash__ create mode 100644 pkg/expression/aggregation/binding__failpoint_binding__.go create mode 100644 pkg/expression/aggregation/explain.go__failpoint_stash__ create mode 100644 pkg/expression/binding__failpoint_binding__.go create mode 100644 pkg/expression/builtin_json.go__failpoint_stash__ create mode 100644 pkg/expression/builtin_time.go__failpoint_stash__ create mode 100644 pkg/expression/expr_to_pb.go__failpoint_stash__ create mode 100644 pkg/expression/helper.go__failpoint_stash__ create mode 100644 pkg/expression/infer_pushdown.go__failpoint_stash__ create mode 100644 pkg/expression/util.go__failpoint_stash__ create mode 100644 pkg/infoschema/binding__failpoint_binding__.go create mode 100644 pkg/infoschema/builder.go__failpoint_stash__ create mode 100644 pkg/infoschema/infoschema_v2.go__failpoint_stash__ create mode 100644 
pkg/infoschema/perfschema/binding__failpoint_binding__.go create mode 100644 pkg/infoschema/perfschema/tables.go__failpoint_stash__ create mode 100644 pkg/infoschema/sieve.go__failpoint_stash__ create mode 100644 pkg/infoschema/tables.go__failpoint_stash__ create mode 100644 pkg/kv/binding__failpoint_binding__.go create mode 100644 pkg/kv/txn.go__failpoint_stash__ create mode 100644 pkg/lightning/backend/backend.go__failpoint_stash__ create mode 100644 pkg/lightning/backend/binding__failpoint_binding__.go create mode 100644 pkg/lightning/backend/external/binding__failpoint_binding__.go create mode 100644 pkg/lightning/backend/external/byte_reader.go__failpoint_stash__ create mode 100644 pkg/lightning/backend/external/engine.go__failpoint_stash__ create mode 100644 pkg/lightning/backend/external/merge_v2.go__failpoint_stash__ create mode 100644 pkg/lightning/backend/local/binding__failpoint_binding__.go create mode 100644 pkg/lightning/backend/local/checksum.go__failpoint_stash__ create mode 100644 pkg/lightning/backend/local/engine.go__failpoint_stash__ create mode 100644 pkg/lightning/backend/local/engine_mgr.go__failpoint_stash__ create mode 100644 pkg/lightning/backend/local/local.go__failpoint_stash__ create mode 100644 pkg/lightning/backend/local/local_unix.go__failpoint_stash__ create mode 100644 pkg/lightning/backend/local/region_job.go__failpoint_stash__ create mode 100644 pkg/lightning/backend/tidb/binding__failpoint_binding__.go create mode 100644 pkg/lightning/backend/tidb/tidb.go__failpoint_stash__ create mode 100644 pkg/lightning/common/binding__failpoint_binding__.go create mode 100644 pkg/lightning/common/storage_unix.go__failpoint_stash__ create mode 100644 pkg/lightning/common/storage_windows.go__failpoint_stash__ create mode 100644 pkg/lightning/common/util.go__failpoint_stash__ create mode 100644 pkg/lightning/mydump/binding__failpoint_binding__.go create mode 100644 pkg/lightning/mydump/loader.go__failpoint_stash__ create mode 100644 pkg/meta/autoid/autoid.go__failpoint_stash__ create mode 100644 pkg/meta/autoid/binding__failpoint_binding__.go create mode 100644 pkg/owner/binding__failpoint_binding__.go create mode 100644 pkg/owner/manager.go__failpoint_stash__ create mode 100644 pkg/owner/mock.go__failpoint_stash__ create mode 100644 pkg/parser/ast/binding__failpoint_binding__.go create mode 100644 pkg/parser/ast/misc.go__failpoint_stash__ create mode 100644 pkg/planner/binding__failpoint_binding__.go create mode 100644 pkg/planner/cardinality/binding__failpoint_binding__.go create mode 100644 pkg/planner/cardinality/row_count_index.go__failpoint_stash__ create mode 100644 pkg/planner/core/binding__failpoint_binding__.go create mode 100644 pkg/planner/core/collect_column_stats_usage.go__failpoint_stash__ create mode 100644 pkg/planner/core/debugtrace.go__failpoint_stash__ create mode 100644 pkg/planner/core/encode.go__failpoint_stash__ create mode 100644 pkg/planner/core/exhaust_physical_plans.go__failpoint_stash__ create mode 100644 pkg/planner/core/find_best_task.go__failpoint_stash__ create mode 100644 pkg/planner/core/logical_join.go__failpoint_stash__ create mode 100644 pkg/planner/core/logical_plan_builder.go__failpoint_stash__ create mode 100644 pkg/planner/core/optimizer.go__failpoint_stash__ create mode 100644 pkg/planner/core/rule_collect_plan_stats.go__failpoint_stash__ create mode 100644 pkg/planner/core/rule_eliminate_projection.go__failpoint_stash__ create mode 100644 pkg/planner/core/rule_inject_extra_projection.go__failpoint_stash__ create mode 100644 
pkg/planner/core/task.go__failpoint_stash__ create mode 100644 pkg/planner/optimize.go__failpoint_stash__ create mode 100644 pkg/server/binding__failpoint_binding__.go create mode 100644 pkg/server/conn.go__failpoint_stash__ create mode 100644 pkg/server/conn_stmt.go__failpoint_stash__ create mode 100644 pkg/server/handler/binding__failpoint_binding__.go create mode 100644 pkg/server/handler/extractorhandler/binding__failpoint_binding__.go create mode 100644 pkg/server/handler/extractorhandler/extractor.go__failpoint_stash__ create mode 100644 pkg/server/handler/tikv_handler.go__failpoint_stash__ create mode 100644 pkg/server/http_status.go__failpoint_stash__ create mode 100644 pkg/session/binding__failpoint_binding__.go create mode 100644 pkg/session/nontransactional.go__failpoint_stash__ create mode 100644 pkg/session/session.go__failpoint_stash__ create mode 100644 pkg/session/sync_upgrade.go__failpoint_stash__ create mode 100644 pkg/session/tidb.go__failpoint_stash__ create mode 100644 pkg/session/txn.go__failpoint_stash__ create mode 100644 pkg/session/txnmanager.go__failpoint_stash__ create mode 100644 pkg/sessionctx/sessionstates/binding__failpoint_binding__.go create mode 100644 pkg/sessionctx/sessionstates/session_token.go__failpoint_stash__ create mode 100644 pkg/sessiontxn/isolation/base.go__failpoint_stash__ create mode 100644 pkg/sessiontxn/isolation/binding__failpoint_binding__.go create mode 100644 pkg/sessiontxn/isolation/readcommitted.go__failpoint_stash__ create mode 100644 pkg/sessiontxn/isolation/repeatable_read.go__failpoint_stash__ create mode 100644 pkg/sessiontxn/staleread/binding__failpoint_binding__.go create mode 100644 pkg/sessiontxn/staleread/util.go__failpoint_stash__ create mode 100644 pkg/statistics/binding__failpoint_binding__.go create mode 100644 pkg/statistics/cmsketch.go__failpoint_stash__ create mode 100644 pkg/statistics/handle/autoanalyze/autoanalyze.go__failpoint_stash__ create mode 100644 pkg/statistics/handle/autoanalyze/binding__failpoint_binding__.go create mode 100644 pkg/statistics/handle/binding__failpoint_binding__.go create mode 100644 pkg/statistics/handle/bootstrap.go__failpoint_stash__ create mode 100644 pkg/statistics/handle/cache/binding__failpoint_binding__.go create mode 100644 pkg/statistics/handle/cache/statscache.go__failpoint_stash__ create mode 100644 pkg/statistics/handle/globalstats/binding__failpoint_binding__.go create mode 100644 pkg/statistics/handle/globalstats/global_stats_async.go__failpoint_stash__ create mode 100644 pkg/statistics/handle/storage/binding__failpoint_binding__.go create mode 100644 pkg/statistics/handle/storage/read.go__failpoint_stash__ create mode 100644 pkg/statistics/handle/syncload/binding__failpoint_binding__.go create mode 100644 pkg/statistics/handle/syncload/stats_syncload.go__failpoint_stash__ create mode 100644 pkg/store/copr/batch_coprocessor.go__failpoint_stash__ create mode 100644 pkg/store/copr/binding__failpoint_binding__.go create mode 100644 pkg/store/copr/coprocessor.go__failpoint_stash__ create mode 100644 pkg/store/copr/mpp.go__failpoint_stash__ create mode 100644 pkg/store/driver/txn/binding__failpoint_binding__.go create mode 100644 pkg/store/driver/txn/binlog.go__failpoint_stash__ create mode 100644 pkg/store/driver/txn/txn_driver.go__failpoint_stash__ create mode 100644 pkg/store/gcworker/binding__failpoint_binding__.go create mode 100644 pkg/store/gcworker/gc_worker.go__failpoint_stash__ create mode 100644 pkg/store/mockstore/unistore/binding__failpoint_binding__.go create mode 
100644 pkg/store/mockstore/unistore/cophandler/binding__failpoint_binding__.go create mode 100644 pkg/store/mockstore/unistore/cophandler/cop_handler.go__failpoint_stash__ create mode 100644 pkg/store/mockstore/unistore/rpc.go__failpoint_stash__ create mode 100644 pkg/table/contextimpl/binding__failpoint_binding__.go create mode 100644 pkg/table/contextimpl/table.go__failpoint_stash__ create mode 100644 pkg/table/tables/binding__failpoint_binding__.go create mode 100644 pkg/table/tables/cache.go__failpoint_stash__ create mode 100644 pkg/table/tables/mutation_checker.go__failpoint_stash__ create mode 100644 pkg/table/tables/tables.go__failpoint_stash__ create mode 100644 pkg/ttl/ttlworker/binding__failpoint_binding__.go create mode 100644 pkg/ttl/ttlworker/config.go__failpoint_stash__ create mode 100644 pkg/util/binding__failpoint_binding__.go create mode 100644 pkg/util/breakpoint/binding__failpoint_binding__.go create mode 100644 pkg/util/breakpoint/breakpoint.go__failpoint_stash__ create mode 100644 pkg/util/cgroup/binding__failpoint_binding__.go create mode 100644 pkg/util/cgroup/cgroup_cpu_linux.go__failpoint_stash__ create mode 100644 pkg/util/cgroup/cgroup_cpu_unsupport.go__failpoint_stash__ create mode 100644 pkg/util/chunk/binding__failpoint_binding__.go create mode 100644 pkg/util/chunk/chunk_in_disk.go__failpoint_stash__ create mode 100644 pkg/util/chunk/row_container.go__failpoint_stash__ create mode 100644 pkg/util/chunk/row_container_reader.go__failpoint_stash__ create mode 100644 pkg/util/codec/binding__failpoint_binding__.go create mode 100644 pkg/util/codec/decimal.go__failpoint_stash__ create mode 100644 pkg/util/cpu/binding__failpoint_binding__.go create mode 100644 pkg/util/cpu/cpu.go__failpoint_stash__ create mode 100644 pkg/util/etcd.go__failpoint_stash__ create mode 100644 pkg/util/gctuner/binding__failpoint_binding__.go create mode 100644 pkg/util/gctuner/memory_limit_tuner.go__failpoint_stash__ create mode 100644 pkg/util/memory/binding__failpoint_binding__.go create mode 100644 pkg/util/memory/meminfo.go__failpoint_stash__ create mode 100644 pkg/util/memory/memstats.go__failpoint_stash__ create mode 100644 pkg/util/replayer/binding__failpoint_binding__.go create mode 100644 pkg/util/replayer/replayer.go__failpoint_stash__ create mode 100644 pkg/util/servermemorylimit/binding__failpoint_binding__.go create mode 100644 pkg/util/servermemorylimit/servermemorylimit.go__failpoint_stash__ create mode 100644 pkg/util/session_pool.go__failpoint_stash__ create mode 100644 pkg/util/sli/binding__failpoint_binding__.go create mode 100644 pkg/util/sli/sli.go__failpoint_stash__ create mode 100644 pkg/util/sqlkiller/binding__failpoint_binding__.go create mode 100644 pkg/util/sqlkiller/sqlkiller.go__failpoint_stash__ create mode 100644 pkg/util/stmtsummary/binding__failpoint_binding__.go create mode 100644 pkg/util/stmtsummary/statement_summary.go__failpoint_stash__ create mode 100644 pkg/util/topsql/binding__failpoint_binding__.go create mode 100644 pkg/util/topsql/reporter/binding__failpoint_binding__.go create mode 100644 pkg/util/topsql/reporter/pubsub.go__failpoint_stash__ create mode 100644 pkg/util/topsql/reporter/reporter.go__failpoint_stash__ create mode 100644 pkg/util/topsql/topsql.go__failpoint_stash__ diff --git a/br/pkg/backup/binding__failpoint_binding__.go b/br/pkg/backup/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..20f171c30d696 --- /dev/null +++ b/br/pkg/backup/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package backup + 
+import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/br/pkg/backup/prepare_snap/binding__failpoint_binding__.go b/br/pkg/backup/prepare_snap/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..f51555385eb65 --- /dev/null +++ b/br/pkg/backup/prepare_snap/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package preparesnap + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/br/pkg/backup/prepare_snap/prepare.go b/br/pkg/backup/prepare_snap/prepare.go index 2435fb4986bdc..8ea386eb5a2bc 100644 --- a/br/pkg/backup/prepare_snap/prepare.go +++ b/br/pkg/backup/prepare_snap/prepare.go @@ -454,9 +454,9 @@ func (p *Preparer) pushWaitApply(reqs pendingRequests, region Region) { // PrepareConnections prepares the connections for each store. // This will pause the admin commands for each store. func (p *Preparer) PrepareConnections(ctx context.Context) error { - failpoint.Inject("PrepareConnectionsErr", func() { - failpoint.Return(errors.New("mock PrepareConnectionsErr")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("PrepareConnectionsErr")); _err_ == nil { + return errors.New("mock PrepareConnectionsErr") + } log.Info("Preparing connections to stores.") stores, err := p.env.GetAllLiveStores(ctx) if err != nil { diff --git a/br/pkg/backup/prepare_snap/prepare.go__failpoint_stash__ b/br/pkg/backup/prepare_snap/prepare.go__failpoint_stash__ new file mode 100644 index 0000000000000..2435fb4986bdc --- /dev/null +++ b/br/pkg/backup/prepare_snap/prepare.go__failpoint_stash__ @@ -0,0 +1,484 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package preparesnap + +import ( + "bytes" + "context" + "fmt" + "time" + + "github.com/google/btree" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + brpb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/logutil" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "golang.org/x/sync/errgroup" +) + +const ( + /* The combination of defaultMaxRetry and defaultRetryBackoff limits + the whole procedure to about 5 min if there is a region always fail. + Also note that we are batching during retrying. Retrying many region + costs only one chance of retrying if they are batched. */ + + defaultMaxRetry = 60 + defaultRetryBackoff = 5 * time.Second + defaultLeaseDur = 120 * time.Second + + /* Give pd enough time to find the region. 
If we aren't able to fetch + the region, the whole procedure might be aborted. */ + + regionCacheMaxBackoffMs = 60000 +) + +type pendingRequests map[uint64]*brpb.PrepareSnapshotBackupRequest + +type rangeOrRegion struct { + // If it is a range, this should be zero. + id uint64 + startKey []byte + endKey []byte +} + +func (r rangeOrRegion) String() string { + rng := logutil.StringifyRangeOf(r.startKey, r.endKey) + if r.id == 0 { + return fmt.Sprintf("range%s", rng) + } + return fmt.Sprintf("region(id=%d, range=%s)", r.id, rng) +} + +func (r rangeOrRegion) compareWith(than rangeOrRegion) bool { + return bytes.Compare(r.startKey, than.startKey) < 0 +} + +type Preparer struct { + /* Environments. */ + env Env + + /* Internal Status. */ + inflightReqs map[uint64]metapb.Region + failed []rangeOrRegion + waitApplyDoneRegions btree.BTreeG[rangeOrRegion] + retryTime int + nextRetry *time.Timer + + /* Internal I/O. */ + eventChan chan event + clients map[uint64]*prepareStream + + /* Interface for caller. */ + waitApplyFinished bool + + /* Some configurations. They aren't thread safe. + You may need to configure them before starting the Preparer. */ + RetryBackoff time.Duration + RetryLimit int + LeaseDuration time.Duration + + /* Observers. Initialize them before starting.*/ + AfterConnectionsEstablished func() +} + +func New(env Env) *Preparer { + prep := &Preparer{ + env: env, + + inflightReqs: make(map[uint64]metapb.Region), + waitApplyDoneRegions: *btree.NewG(16, rangeOrRegion.compareWith), + eventChan: make(chan event, 128), + clients: make(map[uint64]*prepareStream), + + RetryBackoff: defaultRetryBackoff, + RetryLimit: defaultMaxRetry, + LeaseDuration: defaultLeaseDur, + } + return prep +} + +func (p *Preparer) MarshalLogObject(om zapcore.ObjectEncoder) error { + om.AddInt("inflight_requests", len(p.inflightReqs)) + reqs := 0 + for _, r := range p.inflightReqs { + om.AddString("simple_inflight_region", rangeOrRegion{id: r.Id, startKey: r.StartKey, endKey: r.EndKey}.String()) + reqs += 1 + if reqs > 3 { + break + } + } + om.AddInt("failed_requests", len(p.failed)) + failed := 0 + for _, r := range p.failed { + om.AddString("simple_failed_region", r.String()) + failed += 1 + if failed > 5 { + break + } + } + err := om.AddArray("connected_stores", zapcore.ArrayMarshalerFunc(func(ae zapcore.ArrayEncoder) error { + for id := range p.clients { + ae.AppendUint64(id) + } + return nil + })) + if err != nil { + return err + } + om.AddInt("retry_time", p.retryTime) + om.AddBool("wait_apply_finished", p.waitApplyFinished) + return nil +} + +// DriveLoopAndWaitPrepare drives the state machine and block the +// current goroutine until we are safe to start taking snapshot. +// +// After this invoked, you shouldn't share this `Preparer` with any other goroutines. +// +// After this the cluster will enter the land between normal and taking snapshot. +// This state will continue even this function returns, until `Finalize` invoked. +// Splitting, ingesting and conf changing will all be blocked. 
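Note on the failpoint rewrite shown above: the generated binding__failpoint_binding__.go files only provide the _curpkg_ helper, which prefixes a short failpoint name with the package's import path, and each failpoint.Inject(name, fn) marker is expanded into an explicit failpoint.Eval(_curpkg_(name)) guard (with failpoint.Return becoming a real return statement, as in the PrepareConnections hunk). The following is a minimal, self-contained sketch of the two equivalent forms; the package name "demo", the identifiers in the binding sketch, and the failpoint name "demo-failpoint" are hypothetical and are not part of this patch.

package demo

import (
	"fmt"
	"reflect"

	"github.com/pingcap/failpoint"
)

// Generated-style binding: _curpkg_ turns a short failpoint name into
// "<package import path>/<name>", mirroring binding__failpoint_binding__.go.
type failpointBinding struct{ pkgpath string }

var failpointBindingCache = &failpointBinding{}

func init() {
	failpointBindingCache.pkgpath = reflect.TypeOf(failpointBinding{}).PkgPath()
}

func _curpkg_(name string) string {
	return failpointBindingCache.pkgpath + "/" + name
}

// Hand-written marker form, as it looks before the rewrite.
func injectForm() {
	failpoint.Inject("demo-failpoint", func(val failpoint.Value) {
		fmt.Println("demo-failpoint fired:", val)
	})
}

// Rewritten form, matching the hunks in this patch: the marker becomes an
// explicit Eval guard whose body only runs when the failpoint is enabled.
func evalForm() {
	if val, _err_ := failpoint.Eval(_curpkg_("demo-failpoint")); _err_ == nil {
		fmt.Println("demo-failpoint fired:", val)
	}
}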
+func (p *Preparer) DriveLoopAndWaitPrepare(ctx context.Context) error { + logutil.CL(ctx).Info("Start drive the loop.", zap.Duration("retry_backoff", p.RetryBackoff), + zap.Int("retry_limit", p.RetryLimit), + zap.Duration("lease_duration", p.LeaseDuration)) + p.retryTime = 0 + if err := p.PrepareConnections(ctx); err != nil { + log.Error("failed to prepare connections", logutil.ShortError(err)) + return errors.Annotate(err, "failed to prepare connections") + } + if p.AfterConnectionsEstablished != nil { + p.AfterConnectionsEstablished() + } + if err := p.AdvanceState(ctx); err != nil { + log.Error("failed to check the progress of our work", logutil.ShortError(err)) + return errors.Annotate(err, "failed to begin step") + } + for !p.waitApplyFinished { + if err := p.WaitAndHandleNextEvent(ctx); err != nil { + log.Error("failed to wait and handle next event", logutil.ShortError(err)) + return errors.Annotate(err, "failed to step") + } + } + return nil +} + +// Finalize notify the cluster to go back to the normal mode. +// This will return an error if the cluster has already entered the normal mode when this is called. +func (p *Preparer) Finalize(ctx context.Context) error { + eg := new(errgroup.Group) + for id, cli := range p.clients { + cli := cli + id := id + eg.Go(func() error { + if err := cli.Finalize(ctx); err != nil { + return errors.Annotatef(err, "failed to finalize the prepare stream for %d", id) + } + return nil + }) + } + errCh := make(chan error, 1) + go func() { + if err := eg.Wait(); err != nil { + logutil.CL(ctx).Warn("failed to finalize some prepare streams.", logutil.ShortError(err)) + errCh <- err + return + } + logutil.CL(ctx).Info("all connections to store have shuted down.") + errCh <- nil + }() + for { + select { + case event, ok := <-p.eventChan: + if !ok { + return nil + } + if err := p.onEvent(ctx, event); err != nil { + return err + } + case err, ok := <-errCh: + if !ok { + panic("unreachable.") + } + if err != nil { + return err + } + // All streams are finialized, they shouldn't send more events to event chan. + close(p.eventChan) + case <-ctx.Done(): + return ctx.Err() + } + } +} + +func (p *Preparer) batchEvents(evts *[]event) { + for { + select { + case evt := <-p.eventChan: + *evts = append(*evts, evt) + default: + return + } + } +} + +// WaitAndHandleNextEvent is exported for test usage. +// This waits the next event (wait apply done, errors, etc..) of preparing. +// Generally `DriveLoopAndWaitPrepare` is all you need. 
+func (p *Preparer) WaitAndHandleNextEvent(ctx context.Context) error { + select { + case <-ctx.Done(): + logutil.CL(ctx).Warn("User canceled.", logutil.ShortError(ctx.Err())) + return ctx.Err() + case evt := <-p.eventChan: + logutil.CL(ctx).Debug("received event", zap.Stringer("event", evt)) + events := []event{evt} + p.batchEvents(&events) + for _, evt := range events { + err := p.onEvent(ctx, evt) + if err != nil { + return errors.Annotatef(err, "failed to handle event %v", evt) + } + } + return p.AdvanceState(ctx) + case <-p.retryChan(): + return p.workOnPendingRanges(ctx) + } +} + +func (p *Preparer) removePendingRequest(r *metapb.Region) bool { + r2, ok := p.inflightReqs[r.GetId()] + if !ok { + return false + } + matches := r2.GetRegionEpoch().GetVersion() == r.GetRegionEpoch().GetVersion() && + r2.GetRegionEpoch().GetConfVer() == r.GetRegionEpoch().GetConfVer() + if !matches { + return false + } + delete(p.inflightReqs, r.GetId()) + return true +} + +func (p *Preparer) onEvent(ctx context.Context, e event) error { + switch e.ty { + case eventMiscErr: + // Note: some of errors might be able to be retry. + // But for now it seems there isn't one. + return errors.Annotatef(e.err, "unrecoverable error at store %d", e.storeID) + case eventWaitApplyDone: + if !p.removePendingRequest(e.region) { + logutil.CL(ctx).Warn("received unmatched response, perhaps stale, drop it", zap.Stringer("region", e.region)) + return nil + } + r := rangeOrRegion{ + id: e.region.GetId(), + startKey: e.region.GetStartKey(), + endKey: e.region.GetEndKey(), + } + if e.err != nil { + logutil.CL(ctx).Warn("requesting a region failed.", zap.Uint64("store", e.storeID), logutil.ShortError(e.err)) + p.failed = append(p.failed, r) + if p.nextRetry != nil { + p.nextRetry.Stop() + } + // Reset the timer so we can collect more regions. + // Note: perhaps it is better to make a deadline heap or something + // so every region backoffs the same time. + p.nextRetry = time.NewTimer(p.RetryBackoff) + return nil + } + if item, ok := p.waitApplyDoneRegions.ReplaceOrInsert(r); ok { + logutil.CL(ctx).Warn("overlapping in success region", + zap.Stringer("old_region", item), + zap.Stringer("new_region", r)) + } + default: + return errors.Annotatef(unsupported(), "unsupported event type %d", e.ty) + } + + return nil +} + +func (p *Preparer) retryChan() <-chan time.Time { + if p.nextRetry == nil { + return nil + } + return p.nextRetry.C +} + +// AdvanceState is exported for test usage. +// This call will check whether now we are safe to forward the whole procedure. +// If we can, this will set `p.waitApplyFinished` to true. +// Generally `DriveLoopAndWaitPrepare` is all you need, you may not want to call this. 
+func (p *Preparer) AdvanceState(ctx context.Context) error { + logutil.CL(ctx).Info("Checking the progress of our work.", zap.Object("current", p)) + if len(p.inflightReqs) == 0 && len(p.failed) == 0 { + holes := p.checkHole() + if len(holes) == 0 { + p.waitApplyFinished = true + return nil + } + logutil.CL(ctx).Warn("It seems there are still some works to be done.", zap.Stringers("regions", holes)) + p.failed = holes + return p.workOnPendingRanges(ctx) + } + + return nil +} + +func (p *Preparer) checkHole() []rangeOrRegion { + log.Info("Start checking the hole.", zap.Int("len", p.waitApplyDoneRegions.Len())) + if p.waitApplyDoneRegions.Len() == 0 { + return []rangeOrRegion{{}} + } + + last := []byte("") + failed := []rangeOrRegion{} + p.waitApplyDoneRegions.Ascend(func(item rangeOrRegion) bool { + if bytes.Compare(last, item.startKey) < 0 { + failed = append(failed, rangeOrRegion{startKey: last, endKey: item.startKey}) + } + last = item.endKey + return true + }) + // Not the end key of key space. + if len(last) > 0 { + failed = append(failed, rangeOrRegion{ + startKey: last, + }) + } + return failed +} + +func (p *Preparer) workOnPendingRanges(ctx context.Context) error { + p.nextRetry = nil + if len(p.failed) == 0 { + return nil + } + p.retryTime += 1 + if p.retryTime > p.RetryLimit { + return retryLimitExceeded() + } + + logutil.CL(ctx).Info("retrying some ranges incomplete.", zap.Int("ranges", len(p.failed))) + preqs := pendingRequests{} + for _, r := range p.failed { + rs, err := p.env.LoadRegionsInKeyRange(ctx, r.startKey, r.endKey) + if err != nil { + return errors.Annotatef(err, "retrying range of %s: get region", logutil.StringifyRangeOf(r.startKey, r.endKey)) + } + logutil.CL(ctx).Info("loaded regions in range for retry.", zap.Int("regions", len(rs))) + for _, region := range rs { + p.pushWaitApply(preqs, region) + } + } + p.failed = nil + return p.sendWaitApply(ctx, preqs) +} + +func (p *Preparer) sendWaitApply(ctx context.Context, reqs pendingRequests) error { + logutil.CL(ctx).Info("about to send wait apply to stores", zap.Int("to-stores", len(reqs))) + for store, req := range reqs { + logutil.CL(ctx).Info("sending wait apply requests to store", zap.Uint64("store", store), zap.Int("regions", len(req.Regions))) + stream, err := p.streamOf(ctx, store) + if err != nil { + return errors.Annotatef(err, "failed to dial the store %d", store) + } + err = stream.cli.Send(req) + if err != nil { + return errors.Annotatef(err, "failed to send message to the store %d", store) + } + } + return nil +} + +func (p *Preparer) streamOf(ctx context.Context, storeID uint64) (*prepareStream, error) { + _, ok := p.clients[storeID] + if !ok { + log.Warn("stream of store found a store not established connection", zap.Uint64("store", storeID)) + cli, err := p.env.ConnectToStore(ctx, storeID) + if err != nil { + return nil, errors.Annotatef(err, "failed to dial store %d", storeID) + } + if err := p.createAndCacheStream(ctx, cli, storeID); err != nil { + return nil, errors.Annotatef(err, "failed to create and cache stream for store %d", storeID) + } + } + return p.clients[storeID], nil +} + +func (p *Preparer) createAndCacheStream(ctx context.Context, cli PrepareClient, storeID uint64) error { + if _, ok := p.clients[storeID]; ok { + return nil + } + + s := new(prepareStream) + s.storeID = storeID + s.output = p.eventChan + s.leaseDuration = p.LeaseDuration + err := s.InitConn(ctx, cli) + if err != nil { + return err + } + p.clients[storeID] = s + return nil +} + +func (p *Preparer) pushWaitApply(reqs 
pendingRequests, region Region) { + leader := region.GetLeaderStoreID() + if _, ok := reqs[leader]; !ok { + reqs[leader] = new(brpb.PrepareSnapshotBackupRequest) + reqs[leader].Ty = brpb.PrepareSnapshotBackupRequestType_WaitApply + } + reqs[leader].Regions = append(reqs[leader].Regions, region.GetMeta()) + p.inflightReqs[region.GetMeta().Id] = *region.GetMeta() +} + +// PrepareConnections prepares the connections for each store. +// This will pause the admin commands for each store. +func (p *Preparer) PrepareConnections(ctx context.Context) error { + failpoint.Inject("PrepareConnectionsErr", func() { + failpoint.Return(errors.New("mock PrepareConnectionsErr")) + }) + log.Info("Preparing connections to stores.") + stores, err := p.env.GetAllLiveStores(ctx) + if err != nil { + return errors.Annotate(err, "failed to get all live stores") + } + + log.Info("Start to initialize the connections.", zap.Int("stores", len(stores))) + clients := map[uint64]PrepareClient{} + for _, store := range stores { + cli, err := p.env.ConnectToStore(ctx, store.Id) + if err != nil { + return errors.Annotatef(err, "failed to dial the store %d", store.Id) + } + clients[store.Id] = cli + } + + for id, cli := range clients { + log.Info("Start to pause the admin commands.", zap.Uint64("store", id)) + if err := p.createAndCacheStream(ctx, cli, id); err != nil { + return errors.Annotatef(err, "failed to create and cache stream for store %d", id) + } + } + + return nil +} diff --git a/br/pkg/backup/store.go b/br/pkg/backup/store.go index 02f7166193918..48f935b03f4b0 100644 --- a/br/pkg/backup/store.go +++ b/br/pkg/backup/store.go @@ -63,7 +63,7 @@ func doSendBackup( req backuppb.BackupRequest, respFn func(*backuppb.BackupResponse) error, ) error { - failpoint.Inject("hint-backup-start", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("hint-backup-start")); _err_ == nil { logutil.CL(ctx).Info("failpoint hint-backup-start injected, " + "process will notify the shell.") if sigFile, ok := v.(string); ok { @@ -76,9 +76,9 @@ func doSendBackup( } } time.Sleep(3 * time.Second) - }) + } bCli, err := client.Backup(ctx, &req) - failpoint.Inject("reset-retryable-error", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("reset-retryable-error")); _err_ == nil { switch val.(string) { case "Unavailable": { @@ -91,13 +91,13 @@ func doSendBackup( err = status.Error(codes.Internal, "Internal error") } } - }) - failpoint.Inject("reset-not-retryable-error", func(val failpoint.Value) { + } + if val, _err_ := failpoint.Eval(_curpkg_("reset-not-retryable-error")); _err_ == nil { if val.(bool) { logutil.CL(ctx).Debug("failpoint reset-not-retryable-error injected.") err = status.Error(codes.Unknown, "Your server was haunted hence doesn't work, meow :3") } - }) + } if err != nil { return err } @@ -159,28 +159,28 @@ func startBackup( zap.Int("retry", retry), zap.Int("reqIndex", reqIndex)) return doSendBackup(ectx, backupCli, bkReq, func(resp *backuppb.BackupResponse) error { // Forward all responses (including error). 
- failpoint.Inject("backup-timeout-error", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("backup-timeout-error")); _err_ == nil { msg := val.(string) logutil.CL(ectx).Info("failpoint backup-timeout-error injected.", zap.String("msg", msg)) resp.Error = &backuppb.Error{ Msg: msg, } - }) - failpoint.Inject("backup-storage-error", func(val failpoint.Value) { + } + if val, _err_ := failpoint.Eval(_curpkg_("backup-storage-error")); _err_ == nil { msg := val.(string) logutil.CL(ectx).Debug("failpoint backup-storage-error injected.", zap.String("msg", msg)) resp.Error = &backuppb.Error{ Msg: msg, } - }) - failpoint.Inject("tikv-rw-error", func(val failpoint.Value) { + } + if val, _err_ := failpoint.Eval(_curpkg_("tikv-rw-error")); _err_ == nil { msg := val.(string) logutil.CL(ectx).Debug("failpoint tikv-rw-error injected.", zap.String("msg", msg)) resp.Error = &backuppb.Error{ Msg: msg, } - }) - failpoint.Inject("tikv-region-error", func(val failpoint.Value) { + } + if val, _err_ := failpoint.Eval(_curpkg_("tikv-region-error")); _err_ == nil { msg := val.(string) logutil.CL(ectx).Debug("failpoint tikv-region-error injected.", zap.String("msg", msg)) resp.Error = &backuppb.Error{ @@ -191,7 +191,7 @@ func startBackup( }, }, } - }) + } select { case <-ectx.Done(): return ectx.Err() @@ -247,12 +247,12 @@ func ObserveStoreChangesAsync(ctx context.Context, stateNotifier chan BackupRetr logutil.CL(ctx).Warn("failed to watch store changes at beginning, ignore it", zap.Error(err)) } tickInterval := 30 * time.Second - failpoint.Inject("backup-store-change-tick", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("backup-store-change-tick")); _err_ == nil { if val.(bool) { tickInterval = 100 * time.Millisecond } logutil.CL(ctx).Info("failpoint backup-store-change-tick injected.", zap.Duration("interval", tickInterval)) - }) + } tick := time.NewTicker(tickInterval) for { select { diff --git a/br/pkg/backup/store.go__failpoint_stash__ b/br/pkg/backup/store.go__failpoint_stash__ new file mode 100644 index 0000000000000..02f7166193918 --- /dev/null +++ b/br/pkg/backup/store.go__failpoint_stash__ @@ -0,0 +1,307 @@ +// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. 
+ +package backup + +import ( + "context" + "io" + "os" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/kvproto/pkg/errorpb" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/rtree" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/br/pkg/utils/storewatch" + tidbutil "github.com/pingcap/tidb/pkg/util" + pd "github.com/tikv/pd/client" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type BackupRetryPolicy struct { + One uint64 + All bool +} + +type BackupSender interface { + SendAsync( + ctx context.Context, + round uint64, + storeID uint64, + request backuppb.BackupRequest, + concurrency uint, + cli backuppb.BackupClient, + respCh chan *ResponseAndStore, + StateNotifier chan BackupRetryPolicy) +} + +type ResponseAndStore struct { + Resp *backuppb.BackupResponse + StoreID uint64 +} + +func (r ResponseAndStore) GetResponse() *backuppb.BackupResponse { + return r.Resp +} + +func (r ResponseAndStore) GetStoreID() uint64 { + return r.StoreID +} + +func doSendBackup( + ctx context.Context, + client backuppb.BackupClient, + req backuppb.BackupRequest, + respFn func(*backuppb.BackupResponse) error, +) error { + failpoint.Inject("hint-backup-start", func(v failpoint.Value) { + logutil.CL(ctx).Info("failpoint hint-backup-start injected, " + + "process will notify the shell.") + if sigFile, ok := v.(string); ok { + file, err := os.Create(sigFile) + if err != nil { + log.Warn("failed to create file for notifying, skipping notify", zap.Error(err)) + } + if file != nil { + file.Close() + } + } + time.Sleep(3 * time.Second) + }) + bCli, err := client.Backup(ctx, &req) + failpoint.Inject("reset-retryable-error", func(val failpoint.Value) { + switch val.(string) { + case "Unavailable": + { + logutil.CL(ctx).Debug("failpoint reset-retryable-error unavailable injected.") + err = status.Error(codes.Unavailable, "Unavailable error") + } + case "Internal": + { + logutil.CL(ctx).Debug("failpoint reset-retryable-error internal injected.") + err = status.Error(codes.Internal, "Internal error") + } + } + }) + failpoint.Inject("reset-not-retryable-error", func(val failpoint.Value) { + if val.(bool) { + logutil.CL(ctx).Debug("failpoint reset-not-retryable-error injected.") + err = status.Error(codes.Unknown, "Your server was haunted hence doesn't work, meow :3") + } + }) + if err != nil { + return err + } + defer func() { + _ = bCli.CloseSend() + }() + + for { + resp, err := bCli.Recv() + if err != nil { + if errors.Cause(err) == io.EOF { // nolint:errorlint + logutil.CL(ctx).Debug("backup streaming finish", + logutil.Key("backup-start-key", req.GetStartKey()), + logutil.Key("backup-end-key", req.GetEndKey())) + return nil + } + return err + } + // TODO: handle errors in the resp. 
+ logutil.CL(ctx).Debug("range backed up", + logutil.Key("small-range-start-key", resp.GetStartKey()), + logutil.Key("small-range-end-key", resp.GetEndKey()), + zap.Int("api-version", int(resp.ApiVersion))) + err = respFn(resp) + if err != nil { + return errors.Trace(err) + } + } +} + +func startBackup( + ctx context.Context, + storeID uint64, + backupReq backuppb.BackupRequest, + backupCli backuppb.BackupClient, + concurrency uint, + respCh chan *ResponseAndStore, +) error { + // this goroutine handle the response from a single store + select { + case <-ctx.Done(): + return ctx.Err() + default: + logutil.CL(ctx).Info("try backup", zap.Uint64("storeID", storeID)) + // Send backup request to the store. + // handle the backup response or internal error here. + // handle the store error(reboot or network partition) outside. + reqs := SplitBackupReqRanges(backupReq, concurrency) + pool := tidbutil.NewWorkerPool(concurrency, "store_backup") + eg, ectx := errgroup.WithContext(ctx) + for i, req := range reqs { + bkReq := req + reqIndex := i + pool.ApplyOnErrorGroup(eg, func() error { + retry := -1 + return utils.WithRetry(ectx, func() error { + retry += 1 + logutil.CL(ectx).Info("backup to store", zap.Uint64("storeID", storeID), + zap.Int("retry", retry), zap.Int("reqIndex", reqIndex)) + return doSendBackup(ectx, backupCli, bkReq, func(resp *backuppb.BackupResponse) error { + // Forward all responses (including error). + failpoint.Inject("backup-timeout-error", func(val failpoint.Value) { + msg := val.(string) + logutil.CL(ectx).Info("failpoint backup-timeout-error injected.", zap.String("msg", msg)) + resp.Error = &backuppb.Error{ + Msg: msg, + } + }) + failpoint.Inject("backup-storage-error", func(val failpoint.Value) { + msg := val.(string) + logutil.CL(ectx).Debug("failpoint backup-storage-error injected.", zap.String("msg", msg)) + resp.Error = &backuppb.Error{ + Msg: msg, + } + }) + failpoint.Inject("tikv-rw-error", func(val failpoint.Value) { + msg := val.(string) + logutil.CL(ectx).Debug("failpoint tikv-rw-error injected.", zap.String("msg", msg)) + resp.Error = &backuppb.Error{ + Msg: msg, + } + }) + failpoint.Inject("tikv-region-error", func(val failpoint.Value) { + msg := val.(string) + logutil.CL(ectx).Debug("failpoint tikv-region-error injected.", zap.String("msg", msg)) + resp.Error = &backuppb.Error{ + // Msg: msg, + Detail: &backuppb.Error_RegionError{ + RegionError: &errorpb.Error{ + Message: msg, + }, + }, + } + }) + select { + case <-ectx.Done(): + return ectx.Err() + case respCh <- &ResponseAndStore{ + Resp: resp, + StoreID: storeID, + }: + } + return nil + }) + }, utils.NewBackupSSTBackoffer()) + }) + } + return eg.Wait() + } +} + +func getBackupRanges(ranges []rtree.Range) []*kvrpcpb.KeyRange { + requestRanges := make([]*kvrpcpb.KeyRange, 0, len(ranges)) + for _, r := range ranges { + requestRanges = append(requestRanges, &kvrpcpb.KeyRange{ + StartKey: r.StartKey, + EndKey: r.EndKey, + }) + } + return requestRanges +} + +func ObserveStoreChangesAsync(ctx context.Context, stateNotifier chan BackupRetryPolicy, pdCli pd.Client) { + go func() { + sendAll := false + newJoinStoresMap := make(map[uint64]struct{}) + cb := storewatch.MakeCallback(storewatch.WithOnReboot(func(s *metapb.Store) { + sendAll = true + }), storewatch.WithOnDisconnect(func(s *metapb.Store) { + sendAll = true + }), storewatch.WithOnNewStoreRegistered(func(s *metapb.Store) { + // only backup for this store + newJoinStoresMap[s.Id] = struct{}{} + })) + + notifyFn := func(ctx context.Context, sendPolicy 
BackupRetryPolicy) { + select { + case <-ctx.Done(): + case stateNotifier <- sendPolicy: + } + } + + watcher := storewatch.New(pdCli, cb) + // make a first step, and make the state correct for next 30s check + err := watcher.Step(ctx) + if err != nil { + logutil.CL(ctx).Warn("failed to watch store changes at beginning, ignore it", zap.Error(err)) + } + tickInterval := 30 * time.Second + failpoint.Inject("backup-store-change-tick", func(val failpoint.Value) { + if val.(bool) { + tickInterval = 100 * time.Millisecond + } + logutil.CL(ctx).Info("failpoint backup-store-change-tick injected.", zap.Duration("interval", tickInterval)) + }) + tick := time.NewTicker(tickInterval) + for { + select { + case <-ctx.Done(): + return + case <-tick.C: + // reset the state + sendAll = false + clear(newJoinStoresMap) + logutil.CL(ctx).Info("check store changes every tick") + err := watcher.Step(ctx) + if err != nil { + logutil.CL(ctx).Warn("failed to watch store changes, ignore it", zap.Error(err)) + } + if sendAll { + logutil.CL(ctx).Info("detect some store(s) restarted or disconnected, notify with all stores") + notifyFn(ctx, BackupRetryPolicy{All: true}) + } else if len(newJoinStoresMap) > 0 { + for storeID := range newJoinStoresMap { + logutil.CL(ctx).Info("detect a new registered store, notify with this store", zap.Uint64("storeID", storeID)) + notifyFn(ctx, BackupRetryPolicy{One: storeID}) + } + } + } + } + }() +} + +func SplitBackupReqRanges(req backuppb.BackupRequest, count uint) []backuppb.BackupRequest { + rangeCount := len(req.SubRanges) + if rangeCount == 0 { + return []backuppb.BackupRequest{req} + } + splitRequests := make([]backuppb.BackupRequest, 0, count) + if count <= 1 { + // 0/1 means no need to split, just send one batch request + return []backuppb.BackupRequest{req} + } + splitStep := rangeCount / int(count) + if splitStep == 0 { + // splitStep should be at least 1 + // if count >= rangeCount, means no batch, split them all + splitStep = 1 + } + subRanges := req.SubRanges + for i := 0; i < rangeCount; i += splitStep { + splitReq := req + splitReq.SubRanges = subRanges[i:min(i+splitStep, rangeCount)] + splitRequests = append(splitRequests, splitReq) + } + return splitRequests +} diff --git a/br/pkg/checkpoint/binding__failpoint_binding__.go b/br/pkg/checkpoint/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..8590a12a60919 --- /dev/null +++ b/br/pkg/checkpoint/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package checkpoint + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/br/pkg/checkpoint/checkpoint.go b/br/pkg/checkpoint/checkpoint.go index 4b397a60e5eeb..ab438f2635665 100644 --- a/br/pkg/checkpoint/checkpoint.go +++ b/br/pkg/checkpoint/checkpoint.go @@ -391,7 +391,7 @@ func (r *CheckpointRunner[K, V]) startCheckpointMainLoop( tickDurationForChecksum, tickDurationForLock time.Duration, ) { - failpoint.Inject("checkpoint-more-quickly-flush", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("checkpoint-more-quickly-flush")); _err_ == nil { tickDurationForChecksum = 1 * time.Second tickDurationForFlush = 3 * time.Second if tickDurationForLock > 0 { @@ -402,7 +402,7 @@ func (r *CheckpointRunner[K, V]) startCheckpointMainLoop( zap.Duration("checksum", 
tickDurationForChecksum), zap.Duration("lock", tickDurationForLock), ) - }) + } r.wg.Add(1) checkpointLoop := func(ctx context.Context) { defer r.wg.Done() @@ -506,9 +506,9 @@ func (r *CheckpointRunner[K, V]) doChecksumFlush(ctx context.Context, checksumIt return errors.Annotatef(err, "failed to write file %s for checkpoint checksum", fname) } - failpoint.Inject("failed-after-checkpoint-flushes-checksum", func(_ failpoint.Value) { - failpoint.Return(errors.Errorf("failpoint: failed after checkpoint flushes checksum")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("failed-after-checkpoint-flushes-checksum")); _err_ == nil { + return errors.Errorf("failpoint: failed after checkpoint flushes checksum") + } return nil } @@ -570,9 +570,9 @@ func (r *CheckpointRunner[K, V]) doFlush(ctx context.Context, meta map[K]*RangeG } } - failpoint.Inject("failed-after-checkpoint-flushes", func(_ failpoint.Value) { - failpoint.Return(errors.Errorf("failpoint: failed after checkpoint flushes")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("failed-after-checkpoint-flushes")); _err_ == nil { + return errors.Errorf("failpoint: failed after checkpoint flushes") + } return nil } @@ -663,9 +663,9 @@ func (r *CheckpointRunner[K, V]) updateLock(ctx context.Context) error { return errors.Trace(err) } - failpoint.Inject("failed-after-checkpoint-updates-lock", func(_ failpoint.Value) { - failpoint.Return(errors.Errorf("failpoint: failed after checkpoint updates lock")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("failed-after-checkpoint-updates-lock")); _err_ == nil { + return errors.Errorf("failpoint: failed after checkpoint updates lock") + } return nil } diff --git a/br/pkg/checkpoint/checkpoint.go__failpoint_stash__ b/br/pkg/checkpoint/checkpoint.go__failpoint_stash__ new file mode 100644 index 0000000000000..4b397a60e5eeb --- /dev/null +++ b/br/pkg/checkpoint/checkpoint.go__failpoint_stash__ @@ -0,0 +1,872 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
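For reference, the guards produced by this rewrite (for example the checkpoint-more-quickly-flush failpoint in the checkpoint.go hunk above) are switched on at run time by enabling the fully qualified failpoint path, i.e. the package import path joined with the short name, which is exactly what _curpkg_ computes. A minimal test-style sketch follows; the import path github.com/pingcap/tidb/br/pkg/checkpoint is assumed from the file location and is not spelled out in the patch itself.

package checkpoint_test

import (
	"testing"

	"github.com/pingcap/failpoint"
)

// Sketch only: enables the checkpoint-more-quickly-flush failpoint seen in the
// hunk above for one test, so the guarded branch with the shorter flush,
// checksum and lock tick durations is taken.
func TestWithFastCheckpointFlush(t *testing.T) {
	// Assumed full path: package import path + "/" + short failpoint name.
	const fp = "github.com/pingcap/tidb/br/pkg/checkpoint/checkpoint-more-quickly-flush"

	// "return(true)" makes failpoint.Eval report the failpoint as active.
	if err := failpoint.Enable(fp, "return(true)"); err != nil {
		t.Fatal(err)
	}
	defer func() {
		if err := failpoint.Disable(fp); err != nil {
			t.Fatal(err)
		}
	}()

	// ... drive the CheckpointRunner here and observe the faster ticks ...
}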
+ +package checkpoint + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "math/rand" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/metautil" + "github.com/pingcap/tidb/br/pkg/rtree" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/br/pkg/summary" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/pkg/util" + "github.com/tikv/client-go/v2/oracle" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" +) + +const CheckpointDir = "checkpoints" + +type flushPosition struct { + CheckpointDataDir string + CheckpointChecksumDir string + CheckpointLockPath string +} + +const MaxChecksumTotalCost float64 = 60.0 + +const defaultTickDurationForFlush = 30 * time.Second + +const defaultTckDurationForChecksum = 5 * time.Second + +const defaultTickDurationForLock = 4 * time.Minute + +const lockTimeToLive = 5 * time.Minute + +type KeyType interface { + ~BackupKeyType | ~RestoreKeyType +} + +type RangeType struct { + *rtree.Range +} + +func (r RangeType) IdentKey() []byte { + return r.StartKey +} + +type ValueType interface { + IdentKey() []byte +} + +type CheckpointMessage[K KeyType, V ValueType] struct { + // start-key of the origin range + GroupKey K + + Group []V +} + +// A Checkpoint Range File is like this: +// +// CheckpointData +// +----------------+ RangeGroupData RangeGroup +// | DureTime | +--------------------------+ encrypted +--------------------+ +// | RangeGroupData-+---> | RangeGroupsEncriptedData-+----------> | GroupKey/TableID | +// | RangeGroupData | | Checksum | | Range | +// | ... | | CipherIv | | ... | +// | RangeGroupData | | Size | | Range | +// +----------------+ +--------------------------+ +--------------------+ +// +// For restore, because there is no group key, so there is only one RangeGroupData +// with multi-ranges in the ChecksumData. + +type RangeGroup[K KeyType, V ValueType] struct { + GroupKey K `json:"group-key"` + Group []V `json:"groups"` +} + +type RangeGroupData struct { + RangeGroupsEncriptedData []byte + Checksum []byte + CipherIv []byte + + Size int +} + +type CheckpointData struct { + DureTime time.Duration `json:"dure-time"` + RangeGroupMetas []*RangeGroupData `json:"range-group-metas"` +} + +// A Checkpoint Checksum File is like this: +// +// ChecksumInfo ChecksumItems ChecksumItem +// +------------+ +--------------+ +--------------+ +// | Content--+--> | ChecksumItem-+---> | TableID | +// | Checksum | | ChecksumItem | | Crc64xor | +// | DureTime | | ... 
| | TotalKvs | +// +------------+ | ChecksumItem | | TotalBytes | +// +--------------+ +--------------+ + +type ChecksumItem struct { + TableID int64 `json:"table-id"` + Crc64xor uint64 `json:"crc64-xor"` + TotalKvs uint64 `json:"total-kvs"` + TotalBytes uint64 `json:"total-bytes"` +} + +type ChecksumItems struct { + Items []*ChecksumItem `json:"checksum-items"` +} + +type ChecksumInfo struct { + Content []byte `json:"content"` + Checksum []byte `json:"checksum"` + DureTime time.Duration `json:"dure-time"` +} + +type GlobalTimer interface { + GetTS(context.Context) (int64, int64, error) +} + +type CheckpointRunner[K KeyType, V ValueType] struct { + flushPosition + lockId uint64 + + meta map[K]*RangeGroup[K, V] + checksum ChecksumItems + + valueMarshaler func(*RangeGroup[K, V]) ([]byte, error) + + storage storage.ExternalStorage + cipher *backuppb.CipherInfo + timer GlobalTimer + + appendCh chan *CheckpointMessage[K, V] + checksumCh chan *ChecksumItem + doneCh chan bool + metaCh chan map[K]*RangeGroup[K, V] + checksumMetaCh chan ChecksumItems + lockCh chan struct{} + errCh chan error + err error + errLock sync.RWMutex + + wg sync.WaitGroup +} + +func newCheckpointRunner[K KeyType, V ValueType]( + ctx context.Context, + storage storage.ExternalStorage, + cipher *backuppb.CipherInfo, + timer GlobalTimer, + f flushPosition, + vm func(*RangeGroup[K, V]) ([]byte, error), +) *CheckpointRunner[K, V] { + return &CheckpointRunner[K, V]{ + flushPosition: f, + + meta: make(map[K]*RangeGroup[K, V]), + checksum: ChecksumItems{Items: make([]*ChecksumItem, 0)}, + + valueMarshaler: vm, + + storage: storage, + cipher: cipher, + timer: timer, + + appendCh: make(chan *CheckpointMessage[K, V]), + checksumCh: make(chan *ChecksumItem), + doneCh: make(chan bool, 1), + metaCh: make(chan map[K]*RangeGroup[K, V]), + checksumMetaCh: make(chan ChecksumItems), + lockCh: make(chan struct{}), + errCh: make(chan error, 1), + err: nil, + } +} + +func (r *CheckpointRunner[K, V]) FlushChecksum( + ctx context.Context, + tableID int64, + crc64xor uint64, + totalKvs uint64, + totalBytes uint64, +) error { + checksumItem := &ChecksumItem{ + TableID: tableID, + Crc64xor: crc64xor, + TotalKvs: totalKvs, + TotalBytes: totalBytes, + } + return r.FlushChecksumItem(ctx, checksumItem) +} + +func (r *CheckpointRunner[K, V]) FlushChecksumItem( + ctx context.Context, + checksumItem *ChecksumItem, +) error { + select { + case <-ctx.Done(): + return errors.Annotatef(ctx.Err(), "failed to append checkpoint checksum item") + case err, ok := <-r.errCh: + if !ok { + r.errLock.RLock() + err = r.err + r.errLock.RUnlock() + return errors.Annotate(err, "[checkpoint] Checksum: failed to append checkpoint checksum item") + } + return err + case r.checksumCh <- checksumItem: + return nil + } +} + +func (r *CheckpointRunner[K, V]) Append( + ctx context.Context, + message *CheckpointMessage[K, V], +) error { + select { + case <-ctx.Done(): + return errors.Annotatef(ctx.Err(), "failed to append checkpoint message") + case err, ok := <-r.errCh: + if !ok { + r.errLock.RLock() + err = r.err + r.errLock.RUnlock() + return errors.Annotate(err, "[checkpoint] Append: failed to append checkpoint message") + } + return err + case r.appendCh <- message: + return nil + } +} + +// Note: Cannot be parallel with `Append` function +func (r *CheckpointRunner[K, V]) WaitForFinish(ctx context.Context, flush bool) { + if r.doneCh != nil { + select { + case r.doneCh <- flush: + + default: + log.Warn("not the first close the checkpoint runner", zap.String("category", 
"checkpoint")) + } + } + // wait the range flusher exit + r.wg.Wait() + // remove the checkpoint lock + if r.lockId > 0 { + err := r.storage.DeleteFile(ctx, r.CheckpointLockPath) + if err != nil { + log.Warn("failed to remove the checkpoint lock", zap.Error(err)) + } + } +} + +// Send the checksum to the flush goroutine, and reset the CheckpointRunner's checksum +func (r *CheckpointRunner[K, V]) flushChecksum(ctx context.Context, errCh chan error) error { + checksum := ChecksumItems{ + Items: r.checksum.Items, + } + r.checksum.Items = make([]*ChecksumItem, 0) + // do flush + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errCh: + return err + case r.checksumMetaCh <- checksum: + } + return nil +} + +// Send the meta to the flush goroutine, and reset the CheckpointRunner's meta +func (r *CheckpointRunner[K, V]) flushMeta(ctx context.Context, errCh chan error) error { + meta := r.meta + r.meta = make(map[K]*RangeGroup[K, V]) + // do flush + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errCh: + return err + case r.metaCh <- meta: + } + return nil +} + +func (r *CheckpointRunner[K, V]) setLock(ctx context.Context, errCh chan error) error { + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errCh: + return err + case r.lockCh <- struct{}{}: + } + return nil +} + +// start a goroutine to flush the meta, which is sent from `checkpoint looper`, to the external storage +func (r *CheckpointRunner[K, V]) startCheckpointFlushLoop(ctx context.Context, wg *sync.WaitGroup) chan error { + errCh := make(chan error, 1) + wg.Add(1) + flushWorker := func(ctx context.Context, errCh chan error) { + defer wg.Done() + for { + select { + case <-ctx.Done(): + if err := ctx.Err(); err != nil { + errCh <- err + } + return + case meta, ok := <-r.metaCh: + if !ok { + log.Info("stop checkpoint flush worker") + return + } + if err := r.doFlush(ctx, meta); err != nil { + errCh <- errors.Annotate(err, "failed to flush checkpoint data.") + return + } + case checksums, ok := <-r.checksumMetaCh: + if !ok { + log.Info("stop checkpoint flush worker") + return + } + if err := r.doChecksumFlush(ctx, checksums); err != nil { + errCh <- errors.Annotate(err, "failed to flush checkpoint checksum.") + return + } + case _, ok := <-r.lockCh: + if !ok { + log.Info("stop checkpoint flush worker") + return + } + if err := r.updateLock(ctx); err != nil { + errCh <- errors.Annotate(err, "failed to update checkpoint lock.") + return + } + } + } + } + + go flushWorker(ctx, errCh) + return errCh +} + +func (r *CheckpointRunner[K, V]) sendError(err error) { + select { + case r.errCh <- err: + log.Error("send the error", zap.String("category", "checkpoint"), zap.Error(err)) + r.errLock.Lock() + r.err = err + r.errLock.Unlock() + close(r.errCh) + default: + log.Error("errCh is blocked", logutil.ShortError(err)) + } +} + +func (r *CheckpointRunner[K, V]) startCheckpointMainLoop( + ctx context.Context, + tickDurationForFlush, + tickDurationForChecksum, + tickDurationForLock time.Duration, +) { + failpoint.Inject("checkpoint-more-quickly-flush", func(_ failpoint.Value) { + tickDurationForChecksum = 1 * time.Second + tickDurationForFlush = 3 * time.Second + if tickDurationForLock > 0 { + tickDurationForLock = 1 * time.Second + } + log.Info("adjust the tick duration for flush or lock", + zap.Duration("flush", tickDurationForFlush), + zap.Duration("checksum", tickDurationForChecksum), + zap.Duration("lock", tickDurationForLock), + ) + }) + r.wg.Add(1) + checkpointLoop := func(ctx context.Context) { + 
defer r.wg.Done() + cctx, cancel := context.WithCancel(ctx) + defer cancel() + var wg sync.WaitGroup + errCh := r.startCheckpointFlushLoop(cctx, &wg) + flushTicker := time.NewTicker(tickDurationForFlush) + defer flushTicker.Stop() + checksumTicker := time.NewTicker(tickDurationForChecksum) + defer checksumTicker.Stop() + // register time ticker, the lock ticker is optional + lockTicker := dispatcherTicker(tickDurationForLock) + defer lockTicker.Stop() + for { + select { + case <-ctx.Done(): + if err := ctx.Err(); err != nil { + r.sendError(err) + } + return + case <-lockTicker.Ch(): + if err := r.setLock(ctx, errCh); err != nil { + r.sendError(err) + return + } + case <-checksumTicker.C: + if err := r.flushChecksum(ctx, errCh); err != nil { + r.sendError(err) + return + } + case <-flushTicker.C: + if err := r.flushMeta(ctx, errCh); err != nil { + r.sendError(err) + return + } + case msg := <-r.appendCh: + groups, exist := r.meta[msg.GroupKey] + if !exist { + groups = &RangeGroup[K, V]{ + GroupKey: msg.GroupKey, + Group: make([]V, 0), + } + r.meta[msg.GroupKey] = groups + } + groups.Group = append(groups.Group, msg.Group...) + case msg := <-r.checksumCh: + r.checksum.Items = append(r.checksum.Items, msg) + case flush := <-r.doneCh: + log.Info("stop checkpoint runner") + if flush { + // NOTE: the exit step, don't send error any more. + if err := r.flushMeta(ctx, errCh); err != nil { + log.Error("failed to flush checkpoint meta", zap.Error(err)) + } else if err := r.flushChecksum(ctx, errCh); err != nil { + log.Error("failed to flush checkpoint checksum", zap.Error(err)) + } + } + // close the channel to flush worker + // and wait it to consumes all the metas + close(r.metaCh) + close(r.checksumMetaCh) + close(r.lockCh) + wg.Wait() + return + case err := <-errCh: + // pass flush worker's error back + r.sendError(err) + return + } + } + } + + go checkpointLoop(ctx) +} + +// flush the checksum to the external storage +func (r *CheckpointRunner[K, V]) doChecksumFlush(ctx context.Context, checksumItems ChecksumItems) error { + if len(checksumItems.Items) == 0 { + return nil + } + content, err := json.Marshal(checksumItems) + if err != nil { + return errors.Trace(err) + } + + checksum := sha256.Sum256(content) + checksumInfo := &ChecksumInfo{ + Content: content, + Checksum: checksum[:], + DureTime: summary.NowDureTime(), + } + + data, err := json.Marshal(checksumInfo) + if err != nil { + return errors.Trace(err) + } + + fname := fmt.Sprintf("%s/t%d_and__.cpt", r.CheckpointChecksumDir, checksumItems.Items[0].TableID) + if err = r.storage.WriteFile(ctx, fname, data); err != nil { + return errors.Annotatef(err, "failed to write file %s for checkpoint checksum", fname) + } + + failpoint.Inject("failed-after-checkpoint-flushes-checksum", func(_ failpoint.Value) { + failpoint.Return(errors.Errorf("failpoint: failed after checkpoint flushes checksum")) + }) + return nil +} + +// flush the meta to the external storage +func (r *CheckpointRunner[K, V]) doFlush(ctx context.Context, meta map[K]*RangeGroup[K, V]) error { + if len(meta) == 0 { + return nil + } + + checkpointData := &CheckpointData{ + DureTime: summary.NowDureTime(), + RangeGroupMetas: make([]*RangeGroupData, 0, len(meta)), + } + + var fname []byte = nil + + for _, group := range meta { + if len(group.Group) == 0 { + continue + } + + // use the first item's group-key and sub-range-key as the filename + if len(fname) == 0 { + fname = append([]byte(fmt.Sprint(group.GroupKey, '.', '.')), group.Group[0].IdentKey()...) 
+ } + + // Flush the metaFile to storage + content, err := r.valueMarshaler(group) + if err != nil { + return errors.Trace(err) + } + + encryptBuff, iv, err := metautil.Encrypt(content, r.cipher) + if err != nil { + return errors.Trace(err) + } + + checksum := sha256.Sum256(content) + + checkpointData.RangeGroupMetas = append(checkpointData.RangeGroupMetas, &RangeGroupData{ + RangeGroupsEncriptedData: encryptBuff, + Checksum: checksum[:], + Size: len(content), + CipherIv: iv, + }) + } + + if len(checkpointData.RangeGroupMetas) > 0 { + data, err := json.Marshal(checkpointData) + if err != nil { + return errors.Trace(err) + } + + checksum := sha256.Sum256(fname) + checksumEncoded := base64.URLEncoding.EncodeToString(checksum[:]) + path := fmt.Sprintf("%s/%s_%d.cpt", r.CheckpointDataDir, checksumEncoded, rand.Uint64()) + if err := r.storage.WriteFile(ctx, path, data); err != nil { + return errors.Trace(err) + } + } + + failpoint.Inject("failed-after-checkpoint-flushes", func(_ failpoint.Value) { + failpoint.Return(errors.Errorf("failpoint: failed after checkpoint flushes")) + }) + return nil +} + +type CheckpointLock struct { + LockId uint64 `json:"lock-id"` + ExpireAt int64 `json:"expire-at"` +} + +// get ts with retry +func (r *CheckpointRunner[K, V]) getTS(ctx context.Context) (int64, int64, error) { + var ( + p int64 = 0 + l int64 = 0 + retry int = 0 + ) + errRetry := utils.WithRetry(ctx, func() error { + var err error + p, l, err = r.timer.GetTS(ctx) + if err != nil { + retry++ + log.Info("failed to get ts", zap.Int("retry", retry), zap.Error(err)) + return err + } + + return nil + }, utils.NewPDReqBackoffer()) + + return p, l, errors.Trace(errRetry) +} + +// flush the lock to the external storage +func (r *CheckpointRunner[K, V]) flushLock(ctx context.Context, p int64) error { + lock := &CheckpointLock{ + LockId: r.lockId, + ExpireAt: p + lockTimeToLive.Milliseconds(), + } + log.Info("start to flush the checkpoint lock", zap.Int64("lock-at", p), + zap.Int64("expire-at", lock.ExpireAt)) + data, err := json.Marshal(lock) + if err != nil { + return errors.Trace(err) + } + + err = r.storage.WriteFile(ctx, r.CheckpointLockPath, data) + return errors.Trace(err) +} + +// check whether this lock belongs to this BR +func (r *CheckpointRunner[K, V]) checkLockFile(ctx context.Context, now int64) error { + data, err := r.storage.ReadFile(ctx, r.CheckpointLockPath) + if err != nil { + return errors.Trace(err) + } + lock := &CheckpointLock{} + err = json.Unmarshal(data, lock) + if err != nil { + return errors.Trace(err) + } + if lock.ExpireAt <= now { + if lock.LockId > r.lockId { + return errors.Errorf("There are another BR(%d) running after but setting lock before this one(%d). "+ + "Please check whether the BR is running. If not, you can retry.", lock.LockId, r.lockId) + } + if lock.LockId == r.lockId { + log.Warn("The lock has expired.", + zap.Int64("expire-at(ms)", lock.ExpireAt), zap.Int64("now(ms)", now)) + } + } else if lock.LockId != r.lockId { + return errors.Errorf("The existing lock will expire in %d seconds. "+ + "There may be another BR(%d) running. 
If not, you can wait for the lock to expire, "+ + "or delete the file `%s%s` manually.", + (lock.ExpireAt-now)/1000, lock.LockId, strings.TrimRight(r.storage.URI(), "/"), r.CheckpointLockPath) + } + + return nil +} + +// generate a new lock and flush the lock to the external storage +func (r *CheckpointRunner[K, V]) updateLock(ctx context.Context) error { + p, _, err := r.getTS(ctx) + if err != nil { + return errors.Trace(err) + } + if err = r.checkLockFile(ctx, p); err != nil { + return errors.Trace(err) + } + if err = r.flushLock(ctx, p); err != nil { + return errors.Trace(err) + } + + failpoint.Inject("failed-after-checkpoint-updates-lock", func(_ failpoint.Value) { + failpoint.Return(errors.Errorf("failpoint: failed after checkpoint updates lock")) + }) + + return nil +} + +// Attempt to initialize the lock. Need to stop the backup when there is an unexpired locks. +func (r *CheckpointRunner[K, V]) initialLock(ctx context.Context) error { + p, l, err := r.getTS(ctx) + if err != nil { + return errors.Trace(err) + } + r.lockId = oracle.ComposeTS(p, l) + exist, err := r.storage.FileExists(ctx, r.CheckpointLockPath) + if err != nil { + return errors.Trace(err) + } + if exist { + if err := r.checkLockFile(ctx, p); err != nil { + return errors.Trace(err) + } + } + if err = r.flushLock(ctx, p); err != nil { + return errors.Trace(err) + } + + // wait for 3 seconds to check whether the lock file is overwritten by another BR + time.Sleep(3 * time.Second) + err = r.checkLockFile(ctx, p) + return errors.Trace(err) +} + +// walk the whole checkpoint range files and retrieve the metadata of backed up/restored ranges +// and return the total time cost in the past executions +func walkCheckpointFile[K KeyType, V ValueType]( + ctx context.Context, + s storage.ExternalStorage, + cipher *backuppb.CipherInfo, + subDir string, + fn func(groupKey K, value V), +) (time.Duration, error) { + // records the total time cost in the past executions + var pastDureTime time.Duration = 0 + err := s.WalkDir(ctx, &storage.WalkOption{SubDir: subDir}, func(path string, size int64) error { + if strings.HasSuffix(path, ".cpt") { + content, err := s.ReadFile(ctx, path) + if err != nil { + return errors.Trace(err) + } + + checkpointData := &CheckpointData{} + if err = json.Unmarshal(content, checkpointData); err != nil { + log.Error("failed to unmarshal the checkpoint data info, skip it", zap.Error(err)) + return nil + } + + if checkpointData.DureTime > pastDureTime { + pastDureTime = checkpointData.DureTime + } + for _, meta := range checkpointData.RangeGroupMetas { + decryptContent, err := metautil.Decrypt(meta.RangeGroupsEncriptedData, cipher, meta.CipherIv) + if err != nil { + return errors.Trace(err) + } + + checksum := sha256.Sum256(decryptContent) + if !bytes.Equal(meta.Checksum, checksum[:]) { + log.Error("checkpoint checksum info's checksum mismatch, skip it", + zap.ByteString("expect", meta.Checksum), + zap.ByteString("got", checksum[:]), + ) + continue + } + + group := &RangeGroup[K, V]{} + if err = json.Unmarshal(decryptContent, group); err != nil { + return errors.Trace(err) + } + + for _, g := range group.Group { + fn(group.GroupKey, g) + } + } + } + return nil + }) + + return pastDureTime, errors.Trace(err) +} + +// load checkpoint meta data from external storage and unmarshal back +func loadCheckpointMeta[T any](ctx context.Context, s storage.ExternalStorage, path string, m *T) error { + data, err := s.ReadFile(ctx, path) + if err != nil { + return errors.Trace(err) + } + + err = json.Unmarshal(data, m) + 
return errors.Trace(err) +} + +// walk the whole checkpoint checksum files and retrieve checksum information of tables calculated +func loadCheckpointChecksum( + ctx context.Context, + s storage.ExternalStorage, + subDir string, +) (map[int64]*ChecksumItem, time.Duration, error) { + var pastDureTime time.Duration = 0 + checkpointChecksum := make(map[int64]*ChecksumItem) + err := s.WalkDir(ctx, &storage.WalkOption{SubDir: subDir}, func(path string, size int64) error { + data, err := s.ReadFile(ctx, path) + if err != nil { + return errors.Trace(err) + } + info := &ChecksumInfo{} + err = json.Unmarshal(data, info) + if err != nil { + log.Error("failed to unmarshal the checkpoint checksum info, skip it", zap.Error(err)) + return nil + } + + checksum := sha256.Sum256(info.Content) + if !bytes.Equal(info.Checksum, checksum[:]) { + log.Error("checkpoint checksum info's checksum mismatch, skip it", + zap.ByteString("expect", info.Checksum), + zap.ByteString("got", checksum[:]), + ) + return nil + } + + if info.DureTime > pastDureTime { + pastDureTime = info.DureTime + } + + items := &ChecksumItems{} + err = json.Unmarshal(info.Content, items) + if err != nil { + return errors.Trace(err) + } + + for _, c := range items.Items { + checkpointChecksum[c.TableID] = c + } + return nil + }) + return checkpointChecksum, pastDureTime, errors.Trace(err) +} + +func saveCheckpointMetadata[T any](ctx context.Context, s storage.ExternalStorage, meta *T, path string) error { + data, err := json.Marshal(meta) + if err != nil { + return errors.Trace(err) + } + + err = s.WriteFile(ctx, path, data) + return errors.Trace(err) +} + +func removeCheckpointData(ctx context.Context, s storage.ExternalStorage, subDir string) error { + var ( + // Generate one file every 30 seconds, so there are only 1200 files in 10 hours. 
+ removedFileNames = make([]string, 0, 1200) + + removeCnt int = 0 + removeSize int64 = 0 + ) + err := s.WalkDir(ctx, &storage.WalkOption{SubDir: subDir}, func(path string, size int64) error { + if !strings.HasSuffix(path, ".cpt") && !strings.HasSuffix(path, ".meta") && !strings.HasSuffix(path, ".lock") { + return nil + } + removedFileNames = append(removedFileNames, path) + removeCnt += 1 + removeSize += size + return nil + }) + if err != nil { + return errors.Trace(err) + } + log.Info("start to remove checkpoint data", + zap.String("checkpoint task", subDir), + zap.Int("remove-count", removeCnt), + zap.Int64("remove-size", removeSize), + ) + + maxFailedFilesNum := int64(16) + var failedFilesCount atomic.Int64 + pool := util.NewWorkerPool(4, "checkpoint remove worker") + eg, gCtx := errgroup.WithContext(ctx) + for _, filename := range removedFileNames { + name := filename + pool.ApplyOnErrorGroup(eg, func() error { + if err := s.DeleteFile(gCtx, name); err != nil { + log.Warn("failed to remove the file", zap.String("filename", name), zap.Error(err)) + if failedFilesCount.Add(1) >= maxFailedFilesNum { + return errors.Annotate(err, "failed to delete too many files") + } + } + return nil + }) + } + if err := eg.Wait(); err != nil { + return errors.Trace(err) + } + log.Info("all the checkpoint data has been removed", zap.String("checkpoint task", subDir)) + return nil +} diff --git a/br/pkg/checksum/binding__failpoint_binding__.go b/br/pkg/checksum/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..c63a7388ac3ba --- /dev/null +++ b/br/pkg/checksum/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package checksum + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/br/pkg/checksum/executor.go b/br/pkg/checksum/executor.go index 22f5a13d23a65..58c9ebc56a3cc 100644 --- a/br/pkg/checksum/executor.go +++ b/br/pkg/checksum/executor.go @@ -387,12 +387,12 @@ func (exec *Executor) Execute( vars.BackOffWeight = exec.backoffWeight } resp, err = sendChecksumRequest(ctx, client, req, vars) - failpoint.Inject("checksumRetryErr", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("checksumRetryErr")); _err_ == nil { // first time reach here. return error if val.(bool) { err = errors.New("inject checksum error") } - }) + } if err != nil { return errors.Trace(err) } diff --git a/br/pkg/checksum/executor.go__failpoint_stash__ b/br/pkg/checksum/executor.go__failpoint_stash__ new file mode 100644 index 0000000000000..22f5a13d23a65 --- /dev/null +++ b/br/pkg/checksum/executor.go__failpoint_stash__ @@ -0,0 +1,419 @@ +// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. 
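The binding__failpoint_binding__.go file and the Inject-to-Eval rewrite above are the generated form of the failpoint: _curpkg_ joins the package path with the failpoint name, and the untouched original source is kept in the __failpoint_stash__ file below. As a hedged sketch (the test name and the exact failpath string are assumptions derived from the generated binding, and the stock failpoint.Enable/Disable API is assumed), a test could trigger checksumRetryErr once like this:

package checksum_test

import (
	"testing"

	"github.com/pingcap/failpoint"
	"github.com/stretchr/testify/require"
)

func TestChecksumRetryOnInjectedError(t *testing.T) {
	// _curpkg_("checksumRetryErr") resolves to the package path plus the name.
	fp := "github.com/pingcap/tidb/br/pkg/checksum/checksumRetryErr"
	// "1*return(true)" makes Eval yield true exactly once, so only the first
	// sendChecksumRequest attempt fails and the retry loop is exercised.
	require.NoError(t, failpoint.Enable(fp, "1*return(true)"))
	defer func() {
		require.NoError(t, failpoint.Disable(fp))
	}()

	// ... build an Executor here and call Execute; only the first request
	// should observe the injected error.
}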
+ +package checksum + +import ( + "context" + + "github.com/gogo/protobuf/proto" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/metautil" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/pkg/distsql" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/util/ranger" + "github.com/pingcap/tipb/go-tipb" + "go.uber.org/zap" +) + +// ExecutorBuilder is used to build a "kv.Request". +type ExecutorBuilder struct { + table *model.TableInfo + ts uint64 + + oldTable *metautil.Table + + concurrency uint + backoffWeight int + + oldKeyspace []byte + newKeyspace []byte + + resourceGroupName string + explicitRequestSourceType string +} + +// NewExecutorBuilder returns a new executor builder. +func NewExecutorBuilder(table *model.TableInfo, ts uint64) *ExecutorBuilder { + return &ExecutorBuilder{ + table: table, + ts: ts, + + concurrency: variable.DefDistSQLScanConcurrency, + } +} + +// SetOldTable set a old table info to the builder. +func (builder *ExecutorBuilder) SetOldTable(oldTable *metautil.Table) *ExecutorBuilder { + builder.oldTable = oldTable + return builder +} + +// SetConcurrency set the concurrency of the checksum executing. +func (builder *ExecutorBuilder) SetConcurrency(conc uint) *ExecutorBuilder { + builder.concurrency = conc + return builder +} + +// SetBackoffWeight set the backoffWeight of the checksum executing. +func (builder *ExecutorBuilder) SetBackoffWeight(backoffWeight int) *ExecutorBuilder { + builder.backoffWeight = backoffWeight + return builder +} + +func (builder *ExecutorBuilder) SetOldKeyspace(keyspace []byte) *ExecutorBuilder { + builder.oldKeyspace = keyspace + return builder +} + +func (builder *ExecutorBuilder) SetNewKeyspace(keyspace []byte) *ExecutorBuilder { + builder.newKeyspace = keyspace + return builder +} + +func (builder *ExecutorBuilder) SetResourceGroupName(name string) *ExecutorBuilder { + builder.resourceGroupName = name + return builder +} + +func (builder *ExecutorBuilder) SetExplicitRequestSourceType(name string) *ExecutorBuilder { + builder.explicitRequestSourceType = name + return builder +} + +// Build builds a checksum executor. 
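//
// A rough usage sketch; the variables below (newTableInfo, backupTS, oldTable,
// concurrency, kvClient) are placeholders for illustration, not identifiers
// defined in this patch:
//
//	exec, err := NewExecutorBuilder(newTableInfo, backupTS).
//		SetOldTable(oldTable).
//		SetConcurrency(concurrency).
//		Build()
//	if err != nil {
//		return errors.Trace(err)
//	}
//	checksumResp, err := exec.Execute(ctx, kvClient, func() { /* advance progress */ })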
+func (builder *ExecutorBuilder) Build() (*Executor, error) { + reqs, err := buildChecksumRequest( + builder.table, + builder.oldTable, + builder.ts, + builder.concurrency, + builder.oldKeyspace, + builder.newKeyspace, + builder.resourceGroupName, + builder.explicitRequestSourceType, + ) + if err != nil { + return nil, errors.Trace(err) + } + return &Executor{reqs: reqs, backoffWeight: builder.backoffWeight}, nil +} + +func buildChecksumRequest( + newTable *model.TableInfo, + oldTable *metautil.Table, + startTS uint64, + concurrency uint, + oldKeyspace []byte, + newKeyspace []byte, + resourceGroupName, explicitRequestSourceType string, +) ([]*kv.Request, error) { + var partDefs []model.PartitionDefinition + if part := newTable.Partition; part != nil { + partDefs = part.Definitions + } + + reqs := make([]*kv.Request, 0, (len(newTable.Indices)+1)*(len(partDefs)+1)) + var oldTableID int64 + if oldTable != nil { + oldTableID = oldTable.Info.ID + } + rs, err := buildRequest(newTable, newTable.ID, oldTable, oldTableID, startTS, concurrency, + oldKeyspace, newKeyspace, resourceGroupName, explicitRequestSourceType) + if err != nil { + return nil, errors.Trace(err) + } + reqs = append(reqs, rs...) + + for _, partDef := range partDefs { + var oldPartID int64 + if oldTable != nil { + for _, oldPartDef := range oldTable.Info.Partition.Definitions { + if oldPartDef.Name == partDef.Name { + oldPartID = oldPartDef.ID + } + } + } + rs, err := buildRequest(newTable, partDef.ID, oldTable, oldPartID, startTS, concurrency, + oldKeyspace, newKeyspace, resourceGroupName, explicitRequestSourceType) + if err != nil { + return nil, errors.Trace(err) + } + reqs = append(reqs, rs...) + } + + return reqs, nil +} + +func buildRequest( + tableInfo *model.TableInfo, + tableID int64, + oldTable *metautil.Table, + oldTableID int64, + startTS uint64, + concurrency uint, + oldKeyspace []byte, + newKeyspace []byte, + resourceGroupName, explicitRequestSourceType string, +) ([]*kv.Request, error) { + reqs := make([]*kv.Request, 0) + req, err := buildTableRequest(tableInfo, tableID, oldTable, oldTableID, startTS, concurrency, + oldKeyspace, newKeyspace, resourceGroupName, explicitRequestSourceType) + if err != nil { + return nil, errors.Trace(err) + } + reqs = append(reqs, req) + + for _, indexInfo := range tableInfo.Indices { + if indexInfo.State != model.StatePublic { + continue + } + var oldIndexInfo *model.IndexInfo + if oldTable != nil { + for _, oldIndex := range oldTable.Info.Indices { + if oldIndex.Name == indexInfo.Name { + oldIndexInfo = oldIndex + break + } + } + if oldIndexInfo == nil { + log.Panic("index not found in origin table, "+ + "please check the restore table has the same index info with origin table", + zap.Int64("table id", tableID), + zap.Stringer("table name", tableInfo.Name), + zap.Int64("origin table id", oldTableID), + zap.Stringer("origin table name", oldTable.Info.Name), + zap.Stringer("index name", indexInfo.Name)) + } + } + req, err = buildIndexRequest( + tableID, indexInfo, oldTableID, oldIndexInfo, startTS, concurrency, + oldKeyspace, newKeyspace, resourceGroupName, explicitRequestSourceType) + if err != nil { + return nil, errors.Trace(err) + } + reqs = append(reqs, req) + } + + return reqs, nil +} + +func buildTableRequest( + tableInfo *model.TableInfo, + tableID int64, + oldTable *metautil.Table, + oldTableID int64, + startTS uint64, + concurrency uint, + oldKeyspace []byte, + newKeyspace []byte, + resourceGroupName, explicitRequestSourceType string, +) (*kv.Request, error) { + var rule 
*tipb.ChecksumRewriteRule + if oldTable != nil { + rule = &tipb.ChecksumRewriteRule{ + OldPrefix: append(append([]byte{}, oldKeyspace...), tablecodec.GenTableRecordPrefix(oldTableID)...), + NewPrefix: append(append([]byte{}, newKeyspace...), tablecodec.GenTableRecordPrefix(tableID)...), + } + } + + checksum := &tipb.ChecksumRequest{ + ScanOn: tipb.ChecksumScanOn_Table, + Algorithm: tipb.ChecksumAlgorithm_Crc64_Xor, + Rule: rule, + } + + var ranges []*ranger.Range + if tableInfo.IsCommonHandle { + ranges = ranger.FullNotNullRange() + } else { + ranges = ranger.FullIntRange(false) + } + + var builder distsql.RequestBuilder + // Use low priority to reducing impact to other requests. + builder.Request.Priority = kv.PriorityLow + return builder.SetHandleRanges(nil, tableID, tableInfo.IsCommonHandle, ranges). + SetStartTS(startTS). + SetChecksumRequest(checksum). + SetConcurrency(int(concurrency)). + SetResourceGroupName(resourceGroupName). + SetExplicitRequestSourceType(explicitRequestSourceType). + Build() +} + +func buildIndexRequest( + tableID int64, + indexInfo *model.IndexInfo, + oldTableID int64, + oldIndexInfo *model.IndexInfo, + startTS uint64, + concurrency uint, + oldKeyspace []byte, + newKeyspace []byte, + resourceGroupName, ExplicitRequestSourceType string, +) (*kv.Request, error) { + var rule *tipb.ChecksumRewriteRule + if oldIndexInfo != nil { + rule = &tipb.ChecksumRewriteRule{ + OldPrefix: append(append([]byte{}, oldKeyspace...), + tablecodec.EncodeTableIndexPrefix(oldTableID, oldIndexInfo.ID)...), + NewPrefix: append(append([]byte{}, newKeyspace...), + tablecodec.EncodeTableIndexPrefix(tableID, indexInfo.ID)...), + } + } + checksum := &tipb.ChecksumRequest{ + ScanOn: tipb.ChecksumScanOn_Index, + Algorithm: tipb.ChecksumAlgorithm_Crc64_Xor, + Rule: rule, + } + + ranges := ranger.FullRange() + + var builder distsql.RequestBuilder + // Use low priority to reducing impact to other requests. + builder.Request.Priority = kv.PriorityLow + return builder.SetIndexRanges(nil, tableID, indexInfo.ID, ranges). + SetStartTS(startTS). + SetChecksumRequest(checksum). + SetConcurrency(int(concurrency)). + SetResourceGroupName(resourceGroupName). + SetExplicitRequestSourceType(ExplicitRequestSourceType). + Build() +} + +func sendChecksumRequest( + ctx context.Context, client kv.Client, req *kv.Request, vars *kv.Variables, +) (resp *tipb.ChecksumResponse, err error) { + res, err := distsql.Checksum(ctx, client, req, vars) + if err != nil { + return nil, errors.Trace(err) + } + defer func() { + if err1 := res.Close(); err1 != nil { + err = err1 + } + }() + + resp = &tipb.ChecksumResponse{} + + for { + data, err := res.NextRaw(ctx) + if err != nil { + return nil, errors.Trace(err) + } + if data == nil { + break + } + checksum := &tipb.ChecksumResponse{} + if err = checksum.Unmarshal(data); err != nil { + return nil, errors.Trace(err) + } + updateChecksumResponse(resp, checksum) + } + + return resp, nil +} + +func updateChecksumResponse(resp, update *tipb.ChecksumResponse) { + resp.Checksum ^= update.Checksum + resp.TotalKvs += update.TotalKvs + resp.TotalBytes += update.TotalBytes +} + +// Executor is a checksum executor. +type Executor struct { + reqs []*kv.Request + backoffWeight int +} + +// Len returns the total number of checksum requests. +func (exec *Executor) Len() int { + return len(exec.reqs) +} + +// Each executes the function to each requests in the executor. 
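//
// For example (caller code assumed for illustration), the built requests can be
// inspected before execution:
//
//	err := exec.Each(func(req *kv.Request) error {
//		log.Info("checksum request built", zap.Int("priority", req.Priority))
//		return nil
//	})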
+func (exec *Executor) Each(f func(*kv.Request) error) error { + for _, req := range exec.reqs { + err := f(req) + if err != nil { + return errors.Trace(err) + } + } + return nil +} + +// RawRequests extracts the raw requests associated with this executor. +// This is mainly used for debugging only. +func (exec *Executor) RawRequests() ([]*tipb.ChecksumRequest, error) { + res := make([]*tipb.ChecksumRequest, 0, len(exec.reqs)) + for _, req := range exec.reqs { + rawReq := new(tipb.ChecksumRequest) + if err := proto.Unmarshal(req.Data, rawReq); err != nil { + return nil, errors.Trace(err) + } + res = append(res, rawReq) + } + return res, nil +} + +// Execute executes a checksum executor. +func (exec *Executor) Execute( + ctx context.Context, + client kv.Client, + updateFn func(), +) (*tipb.ChecksumResponse, error) { + checksumResp := &tipb.ChecksumResponse{} + checksumBackoffer := utils.InitialRetryState(utils.ChecksumRetryTime, + utils.ChecksumWaitInterval, utils.ChecksumMaxWaitInterval) + for _, req := range exec.reqs { + // Pointer to SessionVars.Killed + // Killed is a flag to indicate that this query is killed. + // + // It is useful in TiDB, however, it's a place holder in BR. + killed := uint32(0) + var ( + resp *tipb.ChecksumResponse + err error + ) + err = utils.WithRetry(ctx, func() error { + vars := kv.NewVariables(&killed) + if exec.backoffWeight > 0 { + vars.BackOffWeight = exec.backoffWeight + } + resp, err = sendChecksumRequest(ctx, client, req, vars) + failpoint.Inject("checksumRetryErr", func(val failpoint.Value) { + // first time reach here. return error + if val.(bool) { + err = errors.New("inject checksum error") + } + }) + if err != nil { + return errors.Trace(err) + } + return nil + }, &checksumBackoffer) + if err != nil { + return nil, errors.Trace(err) + } + updateChecksumResponse(checksumResp, resp) + updateFn() + } + return checksumResp, checkContextDone(ctx) +} + +// The coprocessor won't return the error if the context is done, +// so sometimes BR would get the incomplete result. +// checkContextDone makes sure the result is not affected by CONTEXT DONE. 
+func checkContextDone(ctx context.Context) error { + ctxErr := ctx.Err() + if ctxErr != nil { + return errors.Annotate(ctxErr, "context is cancelled by other error") + } + return nil +} diff --git a/br/pkg/conn/binding__failpoint_binding__.go b/br/pkg/conn/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..195eaae166265 --- /dev/null +++ b/br/pkg/conn/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package conn + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/br/pkg/conn/conn.go b/br/pkg/conn/conn.go index cdb81a011c8a5..29cb9dafa12ba 100644 --- a/br/pkg/conn/conn.go +++ b/br/pkg/conn/conn.go @@ -87,29 +87,29 @@ func GetAllTiKVStoresWithRetry(ctx context.Context, ctx, func() error { stores, err = util.GetAllTiKVStores(ctx, pdClient, storeBehavior) - failpoint.Inject("hint-GetAllTiKVStores-error", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("hint-GetAllTiKVStores-error")); _err_ == nil { logutil.CL(ctx).Debug("failpoint hint-GetAllTiKVStores-error injected.") if val.(bool) { err = status.Error(codes.Unknown, "Retryable error") - failpoint.Return(err) + return err } - }) + } - failpoint.Inject("hint-GetAllTiKVStores-grpc-cancel", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("hint-GetAllTiKVStores-grpc-cancel")); _err_ == nil { logutil.CL(ctx).Debug("failpoint hint-GetAllTiKVStores-grpc-cancel injected.") if val.(bool) { err = status.Error(codes.Canceled, "Cancel Retry") - failpoint.Return(err) + return err } - }) + } - failpoint.Inject("hint-GetAllTiKVStores-ctx-cancel", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("hint-GetAllTiKVStores-ctx-cancel")); _err_ == nil { logutil.CL(ctx).Debug("failpoint hint-GetAllTiKVStores-ctx-cancel injected.") if val.(bool) { err = context.Canceled - failpoint.Return(err) + return err } - }) + } return errors.Trace(err) }, diff --git a/br/pkg/conn/conn.go__failpoint_stash__ b/br/pkg/conn/conn.go__failpoint_stash__ new file mode 100644 index 0000000000000..cdb81a011c8a5 --- /dev/null +++ b/br/pkg/conn/conn.go__failpoint_stash__ @@ -0,0 +1,457 @@ +// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. 
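GetAllTiKVStoresWithRetry in the hunk above follows the package's standard retry shape: put the fallible call inside a closure and hand it to utils.WithRetry together with utils.NewPDReqBackoffer; GetConfigFromTiKV later in this file wraps an HTTP GET the same way. A minimal sketch of that shape, where the helper name fetchTiKVConfig and its arguments are assumptions for illustration:

package example

import (
	"context"
	"io"
	"net/http"

	"github.com/pingcap/errors"
	"github.com/pingcap/tidb/br/pkg/utils"
)

// fetchTiKVConfig retries a TiKV /config fetch using the PD request backoffer.
func fetchTiKVConfig(ctx context.Context, cli *http.Client, configAddr string) ([]byte, error) {
	var body []byte
	err := utils.WithRetry(ctx, func() error {
		resp, e := cli.Get(configAddr)
		if e != nil {
			return errors.Trace(e)
		}
		defer resp.Body.Close()
		// Keep the latest successful body; retryable errors trigger backoff.
		body, e = io.ReadAll(resp.Body)
		return errors.Trace(e)
	}, utils.NewPDReqBackoffer())
	return body, errors.Trace(err)
}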
+ +package conn + +import ( + "context" + "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strings" + + "github.com/docker/go-units" + "github.com/opentracing/opentracing-go" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + logbackup "github.com/pingcap/kvproto/pkg/logbackuppb" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/log" + kvconfig "github.com/pingcap/tidb/br/pkg/config" + "github.com/pingcap/tidb/br/pkg/conn/util" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/glue" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/pdutil" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/br/pkg/version" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/kv" + "github.com/tikv/client-go/v2/oracle" + "github.com/tikv/client-go/v2/tikv" + "github.com/tikv/client-go/v2/txnkv/txnlock" + pd "github.com/tikv/pd/client" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/status" +) + +const ( + // DefaultMergeRegionSizeBytes is the default region split size, 96MB. + // See https://github.com/tikv/tikv/blob/v4.0.8/components/raftstore/src/coprocessor/config.rs#L35-L38 + DefaultMergeRegionSizeBytes uint64 = 96 * units.MiB + + // DefaultMergeRegionKeyCount is the default region key count, 960000. + DefaultMergeRegionKeyCount uint64 = 960000 + + // DefaultImportNumGoroutines is the default number of threads for import. + // use 128 as default value, which is 8 times of the default value of tidb. + // we think is proper for IO-bound cases. + DefaultImportNumGoroutines uint = 128 +) + +type VersionCheckerType int + +const ( + // default version checker + NormalVersionChecker VersionCheckerType = iota + // version checker for PiTR + StreamVersionChecker +) + +// Mgr manages connections to a TiDB cluster. +type Mgr struct { + *pdutil.PdController + dom *domain.Domain + storage kv.Storage // Used to access SQL related interfaces. + tikvStore tikv.Storage // Used to access TiKV specific interfaces. 
+ ownsStorage bool + + *utils.StoreManager +} + +func GetAllTiKVStoresWithRetry(ctx context.Context, + pdClient util.StoreMeta, + storeBehavior util.StoreBehavior, +) ([]*metapb.Store, error) { + stores := make([]*metapb.Store, 0) + var err error + + errRetry := utils.WithRetry( + ctx, + func() error { + stores, err = util.GetAllTiKVStores(ctx, pdClient, storeBehavior) + failpoint.Inject("hint-GetAllTiKVStores-error", func(val failpoint.Value) { + logutil.CL(ctx).Debug("failpoint hint-GetAllTiKVStores-error injected.") + if val.(bool) { + err = status.Error(codes.Unknown, "Retryable error") + failpoint.Return(err) + } + }) + + failpoint.Inject("hint-GetAllTiKVStores-grpc-cancel", func(val failpoint.Value) { + logutil.CL(ctx).Debug("failpoint hint-GetAllTiKVStores-grpc-cancel injected.") + if val.(bool) { + err = status.Error(codes.Canceled, "Cancel Retry") + failpoint.Return(err) + } + }) + + failpoint.Inject("hint-GetAllTiKVStores-ctx-cancel", func(val failpoint.Value) { + logutil.CL(ctx).Debug("failpoint hint-GetAllTiKVStores-ctx-cancel injected.") + if val.(bool) { + err = context.Canceled + failpoint.Return(err) + } + }) + + return errors.Trace(err) + }, + utils.NewPDReqBackoffer(), + ) + + return stores, errors.Trace(errRetry) +} + +func checkStoresAlive(ctx context.Context, + pdclient pd.Client, + storeBehavior util.StoreBehavior) error { + // Check live tikv. + stores, err := util.GetAllTiKVStores(ctx, pdclient, storeBehavior) + if err != nil { + log.Error("failed to get store", zap.Error(err)) + return errors.Trace(err) + } + + liveStoreCount := 0 + for _, s := range stores { + if s.GetState() != metapb.StoreState_Up { + continue + } + liveStoreCount++ + } + log.Info("checked alive KV stores", zap.Int("aliveStores", liveStoreCount), zap.Int("totalStores", len(stores))) + return nil +} + +// NewMgr creates a new Mgr. +// +// Domain is optional for Backup, set `needDomain` to false to disable +// initializing Domain. +func NewMgr( + ctx context.Context, + g glue.Glue, + pdAddrs []string, + tlsConf *tls.Config, + securityOption pd.SecurityOption, + keepalive keepalive.ClientParameters, + storeBehavior util.StoreBehavior, + checkRequirements bool, + needDomain bool, + versionCheckerType VersionCheckerType, +) (*Mgr, error) { + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan("conn.NewMgr", opentracing.ChildOf(span.Context())) + defer span1.Finish() + ctx = opentracing.ContextWithSpan(ctx, span1) + } + + log.Info("new mgr", zap.Strings("pdAddrs", pdAddrs)) + + controller, err := pdutil.NewPdController(ctx, pdAddrs, tlsConf, securityOption) + if err != nil { + log.Error("failed to create pd controller", zap.Error(err)) + return nil, errors.Trace(err) + } + if checkRequirements { + var checker version.VerChecker + switch versionCheckerType { + case NormalVersionChecker: + checker = version.CheckVersionForBR + case StreamVersionChecker: + checker = version.CheckVersionForBRPiTR + default: + return nil, errors.Errorf("unknown command type, comman code is %d", versionCheckerType) + } + err = version.CheckClusterVersion(ctx, controller.GetPDClient(), checker) + if err != nil { + return nil, errors.Annotate(err, "running BR in incompatible version of cluster, "+ + "if you believe it's OK, use --check-requirements=false to skip.") + } + } + + err = checkStoresAlive(ctx, controller.GetPDClient(), storeBehavior) + if err != nil { + return nil, errors.Trace(err) + } + + // Disable GC because TiDB enables GC already. 
+ path := fmt.Sprintf( + "tikv://%s?disableGC=true&keyspaceName=%s", + strings.Join(pdAddrs, ","), config.GetGlobalKeyspaceName(), + ) + storage, err := g.Open(path, securityOption) + if err != nil { + return nil, errors.Trace(err) + } + + tikvStorage, ok := storage.(tikv.Storage) + if !ok { + return nil, berrors.ErrKVNotTiKV + } + + var dom *domain.Domain + if needDomain { + dom, err = g.GetDomain(storage) + if err != nil { + return nil, errors.Trace(err) + } + // we must check tidb(tikv version) any time after concurrent ddl feature implemented in v6.2. + // we will keep this check until 7.0, which allow the breaking changes. + // NOTE: must call it after domain created! + // FIXME: remove this check in v7.0 + err = version.CheckClusterVersion(ctx, controller.GetPDClient(), version.CheckVersionForDDL) + if err != nil { + return nil, errors.Annotate(err, "unable to check cluster version for ddl") + } + } + + mgr := &Mgr{ + PdController: controller, + storage: storage, + tikvStore: tikvStorage, + dom: dom, + ownsStorage: g.OwnsStorage(), + StoreManager: utils.NewStoreManager(controller.GetPDClient(), keepalive, tlsConf), + } + return mgr, nil +} + +// GetBackupClient get or create a backup client. +func (mgr *Mgr) GetBackupClient(ctx context.Context, storeID uint64) (backuppb.BackupClient, error) { + var cli backuppb.BackupClient + if err := mgr.WithConn(ctx, storeID, func(cc *grpc.ClientConn) { + cli = backuppb.NewBackupClient(cc) + }); err != nil { + return nil, err + } + return cli, nil +} + +func (mgr *Mgr) GetLogBackupClient(ctx context.Context, storeID uint64) (logbackup.LogBackupClient, error) { + var cli logbackup.LogBackupClient + if err := mgr.WithConn(ctx, storeID, func(cc *grpc.ClientConn) { + cli = logbackup.NewLogBackupClient(cc) + }); err != nil { + return nil, err + } + return cli, nil +} + +// GetStorage returns a kv storage. +func (mgr *Mgr) GetStorage() kv.Storage { + return mgr.storage +} + +// GetTLSConfig returns the tls config. +func (mgr *Mgr) GetTLSConfig() *tls.Config { + return mgr.StoreManager.TLSConfig() +} + +// GetStore gets the tikvStore. +func (mgr *Mgr) GetStore() tikv.Storage { + return mgr.tikvStore +} + +// GetLockResolver gets the LockResolver. +func (mgr *Mgr) GetLockResolver() *txnlock.LockResolver { + return mgr.tikvStore.GetLockResolver() +} + +// GetDomain returns a tikv storage. +func (mgr *Mgr) GetDomain() *domain.Domain { + return mgr.dom +} + +func (mgr *Mgr) Close() { + if mgr.StoreManager != nil { + mgr.StoreManager.Close() + } + // Gracefully shutdown domain so it does not affect other TiDB DDL. + // Must close domain before closing storage, otherwise it gets stuck forever. + if mgr.ownsStorage { + if mgr.dom != nil { + mgr.dom.Close() + } + tikv.StoreShuttingDown(1) + _ = mgr.storage.Close() + } + + mgr.PdController.Close() +} + +// GetTS gets current ts from pd. +func (mgr *Mgr) GetTS(ctx context.Context) (uint64, error) { + p, l, err := mgr.GetPDClient().GetTS(ctx) + if err != nil { + return 0, errors.Trace(err) + } + + return oracle.ComposeTS(p, l), nil +} + +// ProcessTiKVConfigs handle the tikv config for region split size, region split keys, and import goroutines in place. +// It retrieves the config from all alive tikv stores and returns the minimum values. +// If retrieving the config fails, it returns the default config values. 
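//
// A rough illustration of the intended call pattern; cfg and httpClient are
// assumed to be supplied by the caller and are not defined in this patch:
//
//	mgr.ProcessTiKVConfigs(ctx, cfg, httpClient)
//	// cfg.MergeRegionSize, cfg.MergeRegionKeyCount and cfg.ImportGoroutines now
//	// hold the user-specified values, the values derived from the live stores,
//	// or the defaults if retrieval failed.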
+func (mgr *Mgr) ProcessTiKVConfigs(ctx context.Context, cfg *kvconfig.KVConfig, client *http.Client) { + mergeRegionSize := cfg.MergeRegionSize + mergeRegionKeyCount := cfg.MergeRegionKeyCount + importGoroutines := cfg.ImportGoroutines + + if mergeRegionSize.Modified && mergeRegionKeyCount.Modified && importGoroutines.Modified { + log.Info("no need to retrieve the config from tikv if user has set the config") + return + } + + err := mgr.GetConfigFromTiKV(ctx, client, func(resp *http.Response) error { + respBytes, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + if !mergeRegionSize.Modified || !mergeRegionKeyCount.Modified { + size, keys, e := kvconfig.ParseMergeRegionSizeFromConfig(respBytes) + if e != nil { + log.Warn("Failed to parse region split size and keys from config", logutil.ShortError(e)) + return e + } + if mergeRegionKeyCount.Value == DefaultMergeRegionKeyCount || keys < mergeRegionKeyCount.Value { + mergeRegionSize.Value = size + mergeRegionKeyCount.Value = keys + } + } + if !importGoroutines.Modified { + threads, e := kvconfig.ParseImportThreadsFromConfig(respBytes) + if e != nil { + log.Warn("Failed to parse import num-threads from config", logutil.ShortError(e)) + return e + } + // We use 8 times the default value because it's an IO-bound case. + if importGoroutines.Value == DefaultImportNumGoroutines || (threads > 0 && threads*8 < importGoroutines.Value) { + importGoroutines.Value = threads * 8 + } + } + // replace the value + cfg.MergeRegionSize = mergeRegionSize + cfg.MergeRegionKeyCount = mergeRegionKeyCount + cfg.ImportGoroutines = importGoroutines + return nil + }) + + if err != nil { + log.Warn("Failed to get config from TiKV; using default", logutil.ShortError(err)) + } +} + +// IsLogBackupEnabled is used for br to check whether tikv has enabled log backup. +func (mgr *Mgr) IsLogBackupEnabled(ctx context.Context, client *http.Client) (bool, error) { + logbackupEnable := true + err := mgr.GetConfigFromTiKV(ctx, client, func(resp *http.Response) error { + respBytes, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + enable, err := kvconfig.ParseLogBackupEnableFromConfig(respBytes) + if err != nil { + log.Warn("Failed to parse log-backup enable from config", logutil.ShortError(err)) + return err + } + logbackupEnable = logbackupEnable && enable + return nil + }) + return logbackupEnable, errors.Trace(err) +} + +// GetConfigFromTiKV get configs from all alive tikv stores. +func (mgr *Mgr) GetConfigFromTiKV(ctx context.Context, cli *http.Client, fn func(*http.Response) error) error { + allStores, err := GetAllTiKVStoresWithRetry(ctx, mgr.GetPDClient(), util.SkipTiFlash) + if err != nil { + return errors.Trace(err) + } + + httpPrefix := "http://" + if mgr.GetTLSConfig() != nil { + httpPrefix = "https://" + } + + for _, store := range allStores { + if store.State != metapb.StoreState_Up { + continue + } + // we need make sure every available store support backup-stream otherwise we might lose data. 
+ // so check every store's config + addr, err := handleTiKVAddress(store, httpPrefix) + if err != nil { + return err + } + configAddr := fmt.Sprintf("%s/config", addr.String()) + + err = utils.WithRetry(ctx, func() error { + resp, e := cli.Get(configAddr) + if e != nil { + return e + } + defer resp.Body.Close() + err = fn(resp) + if err != nil { + return err + } + return nil + }, utils.NewPDReqBackoffer()) + if err != nil { + // if one store failed, break and return error + return err + } + } + return nil +} + +func handleTiKVAddress(store *metapb.Store, httpPrefix string) (*url.URL, error) { + statusAddr := store.GetStatusAddress() + nodeAddr := store.GetAddress() + if !strings.HasPrefix(statusAddr, "http") { + statusAddr = httpPrefix + statusAddr + } + if !strings.HasPrefix(nodeAddr, "http") { + nodeAddr = httpPrefix + nodeAddr + } + + statusUrl, err := url.Parse(statusAddr) + if err != nil { + return nil, err + } + nodeUrl, err := url.Parse(nodeAddr) + if err != nil { + return nil, err + } + + // we try status address as default + addr := statusUrl + // but in sometimes we may not get the correct status address from PD. + if statusUrl.Hostname() != nodeUrl.Hostname() { + // if not matched, we use the address as default, but change the port + addr.Host = net.JoinHostPort(nodeUrl.Hostname(), statusUrl.Port()) + log.Warn("store address and status address mismatch the host, we will use the store address as hostname", + zap.Uint64("store", store.Id), + zap.String("status address", statusAddr), + zap.String("node address", nodeAddr), + zap.Any("request address", statusUrl), + ) + } + return addr, nil +} diff --git a/br/pkg/pdutil/binding__failpoint_binding__.go b/br/pkg/pdutil/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..35b536059c6a3 --- /dev/null +++ b/br/pkg/pdutil/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package pdutil + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/br/pkg/pdutil/pd.go b/br/pkg/pdutil/pd.go index 31f20cd8af433..4e7ec9c0a614a 100644 --- a/br/pkg/pdutil/pd.go +++ b/br/pkg/pdutil/pd.go @@ -212,12 +212,12 @@ func parseVersion(versionStr string) *semver.Version { zap.String("version", versionStr), zap.Error(err)) version = &semver.Version{Major: 0, Minor: 0, Patch: 0} } - failpoint.Inject("PDEnabledPauseConfig", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("PDEnabledPauseConfig")); _err_ == nil { if val.(bool) { // test pause config is enable version = &semver.Version{Major: 5, Minor: 0, Patch: 0} } - }) + } return version } diff --git a/br/pkg/pdutil/pd.go__failpoint_stash__ b/br/pkg/pdutil/pd.go__failpoint_stash__ new file mode 100644 index 0000000000000..31f20cd8af433 --- /dev/null +++ b/br/pkg/pdutil/pd.go__failpoint_stash__ @@ -0,0 +1,782 @@ +// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. 
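The PDEnabledPauseConfig hunk above pins a version for which pause-by-config is enabled (pauseConfigVersion, v4.0.8, checked by isPauseConfigEnabled below). The pause/restore flow itself is driven through RemoveSchedulers and the UndoFunc it returns, defined further down in this file. A hedged sketch of how a caller typically brackets work with that pair; the wrapper function below is an assumption for illustration:

package example

import (
	"context"

	"github.com/pingcap/log"
	"github.com/pingcap/tidb/br/pkg/pdutil"
	"go.uber.org/zap"
)

// pauseSchedulersDuring pauses the heavy PD schedulers, runs fn, then restores them.
func pauseSchedulersDuring(ctx context.Context, p *pdutil.PdController, fn func(context.Context) error) error {
	undo, err := p.RemoveSchedulers(ctx)
	if err != nil {
		return err
	}
	defer func() {
		// Restore schedulers even if fn fails; pausing also times out on its own.
		if restoreErr := undo(ctx); restoreErr != nil {
			log.Warn("failed to restore PD schedulers", zap.Error(restoreErr))
		}
	}()
	return fn(ctx)
}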
+ +package pdutil + +import ( + "context" + "crypto/tls" + "encoding/hex" + "fmt" + "math" + "net/http" + "strings" + "time" + + "github.com/coreos/go-semver/semver" + "github.com/docker/go-units" + "github.com/google/uuid" + "github.com/opentracing/opentracing-go" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/pkg/util/codec" + pd "github.com/tikv/pd/client" + pdhttp "github.com/tikv/pd/client/http" + "github.com/tikv/pd/client/retry" + "go.uber.org/zap" + "google.golang.org/grpc" +) + +const ( + maxMsgSize = int(128 * units.MiB) // pd.ScanRegion may return a large response + pauseTimeout = 5 * time.Minute + // pd request retry time when connection fail + PDRequestRetryTime = 120 + // set max-pending-peer-count to a large value to avoid scatter region failed. + maxPendingPeerUnlimited uint64 = math.MaxInt32 +) + +// pauseConfigGenerator generate a config value according to store count and current value. +type pauseConfigGenerator func(int, any) any + +// zeroPauseConfig sets the config to 0. +func zeroPauseConfig(int, any) any { + return 0 +} + +// pauseConfigMulStores multiplies the existing value by +// number of stores. The value is limited to 40, as larger value +// may make the cluster unstable. +func pauseConfigMulStores(stores int, raw any) any { + rawCfg := raw.(float64) + return math.Min(40, rawCfg*float64(stores)) +} + +// pauseConfigFalse sets the config to "false". +func pauseConfigFalse(int, any) any { + return "false" +} + +// constConfigGeneratorBuilder build a pauseConfigGenerator based on a given const value. +func constConfigGeneratorBuilder(val any) pauseConfigGenerator { + return func(int, any) any { + return val + } +} + +// ClusterConfig represents a set of scheduler whose config have been modified +// along with their original config. +type ClusterConfig struct { + // Enable PD schedulers before restore + Schedulers []string `json:"schedulers"` + // Original scheudle configuration + ScheduleCfg map[string]any `json:"schedule_cfg"` +} + +type pauseSchedulerBody struct { + Delay int64 `json:"delay"` +} + +var ( + // in v4.0.8 version we can use pause configs + // see https://github.com/tikv/pd/pull/3088 + pauseConfigVersion = semver.Version{Major: 4, Minor: 0, Patch: 8} + + // After v6.1.0 version, we can pause schedulers by key range with TTL. + minVersionForRegionLabelTTL = semver.Version{Major: 6, Minor: 1, Patch: 0} + + // Schedulers represent region/leader schedulers which can impact on performance. + Schedulers = map[string]struct{}{ + "balance-leader-scheduler": {}, + "balance-hot-region-scheduler": {}, + "balance-region-scheduler": {}, + + "shuffle-leader-scheduler": {}, + "shuffle-region-scheduler": {}, + "shuffle-hot-region-scheduler": {}, + + "evict-slow-store-scheduler": {}, + } + expectPDCfgGenerators = map[string]pauseConfigGenerator{ + "merge-schedule-limit": zeroPauseConfig, + // TODO "leader-schedule-limit" and "region-schedule-limit" don't support ttl for now, + // but we still need set these config for compatible with old version. + // we need wait for https://github.com/tikv/pd/pull/3131 merged. 
+ // see details https://github.com/pingcap/br/pull/592#discussion_r522684325 + "leader-schedule-limit": pauseConfigMulStores, + "region-schedule-limit": pauseConfigMulStores, + "max-snapshot-count": pauseConfigMulStores, + "enable-location-replacement": pauseConfigFalse, + "max-pending-peer-count": constConfigGeneratorBuilder(maxPendingPeerUnlimited), + } + + // defaultPDCfg find by https://github.com/tikv/pd/blob/master/conf/config.toml. + // only use for debug command. + defaultPDCfg = map[string]any{ + "merge-schedule-limit": 8, + "leader-schedule-limit": 4, + "region-schedule-limit": 2048, + "enable-location-replacement": "true", + } +) + +// DefaultExpectPDCfgGenerators returns default pd config generators +func DefaultExpectPDCfgGenerators() map[string]pauseConfigGenerator { + clone := make(map[string]pauseConfigGenerator, len(expectPDCfgGenerators)) + for k := range expectPDCfgGenerators { + clone[k] = expectPDCfgGenerators[k] + } + return clone +} + +// PdController manage get/update config from pd. +type PdController struct { + pdClient pd.Client + pdHTTPCli pdhttp.Client + version *semver.Version + + // control the pause schedulers goroutine + schedulerPauseCh chan struct{} + // control the ttl of pausing schedulers + SchedulerPauseTTL time.Duration +} + +// NewPdController creates a new PdController. +func NewPdController( + ctx context.Context, + pdAddrs []string, + tlsConf *tls.Config, + securityOption pd.SecurityOption, +) (*PdController, error) { + maxCallMsgSize := []grpc.DialOption{ + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxMsgSize)), + grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(maxMsgSize)), + } + pdClient, err := pd.NewClientWithContext( + ctx, pdAddrs, securityOption, + pd.WithGRPCDialOptions(maxCallMsgSize...), + // If the time too short, we may scatter a region many times, because + // the interface `ScatterRegions` may time out. + pd.WithCustomTimeoutOption(60*time.Second), + ) + if err != nil { + log.Error("fail to create pd client", zap.Error(err)) + return nil, errors.Trace(err) + } + + pdHTTPCliConfig := make([]pdhttp.ClientOption, 0, 1) + if tlsConf != nil { + pdHTTPCliConfig = append(pdHTTPCliConfig, pdhttp.WithTLSConfig(tlsConf)) + } + pdHTTPCli := pdhttp.NewClientWithServiceDiscovery( + "br/lightning PD controller", + pdClient.GetServiceDiscovery(), + pdHTTPCliConfig..., + ).WithBackoffer(retry.InitialBackoffer(time.Second, time.Second, PDRequestRetryTime*time.Second)) + versionStr, err := pdHTTPCli.GetPDVersion(ctx) + if err != nil { + pdHTTPCli.Close() + pdClient.Close() + return nil, errors.Trace(err) + } + version := parseVersion(versionStr) + + return &PdController{ + pdClient: pdClient, + pdHTTPCli: pdHTTPCli, + version: version, + // We should make a buffered channel here otherwise when context canceled, + // gracefully shutdown will stick at resuming schedulers. 
+ schedulerPauseCh: make(chan struct{}, 1), + }, nil +} + +func NewPdControllerWithPDClient(pdClient pd.Client, pdHTTPCli pdhttp.Client, v *semver.Version) *PdController { + return &PdController{ + pdClient: pdClient, + pdHTTPCli: pdHTTPCli, + version: v, + schedulerPauseCh: make(chan struct{}, 1), + } +} + +func parseVersion(versionStr string) *semver.Version { + // we need trim space or semver will parse failed + v := strings.TrimSpace(versionStr) + v = strings.Trim(v, "\"") + v = strings.TrimPrefix(v, "v") + version, err := semver.NewVersion(v) + if err != nil { + log.Warn("fail back to v0.0.0 version", + zap.String("version", versionStr), zap.Error(err)) + version = &semver.Version{Major: 0, Minor: 0, Patch: 0} + } + failpoint.Inject("PDEnabledPauseConfig", func(val failpoint.Value) { + if val.(bool) { + // test pause config is enable + version = &semver.Version{Major: 5, Minor: 0, Patch: 0} + } + }) + return version +} + +func (p *PdController) isPauseConfigEnabled() bool { + return p.version.Compare(pauseConfigVersion) >= 0 +} + +// SetPDClient set pd addrs and cli for test. +func (p *PdController) SetPDClient(pdClient pd.Client) { + p.pdClient = pdClient +} + +// GetPDClient set pd addrs and cli for test. +func (p *PdController) GetPDClient() pd.Client { + return p.pdClient +} + +// GetPDHTTPClient returns the pd http client. +func (p *PdController) GetPDHTTPClient() pdhttp.Client { + return p.pdHTTPCli +} + +// GetClusterVersion returns the current cluster version. +func (p *PdController) GetClusterVersion(ctx context.Context) (string, error) { + v, err := p.pdHTTPCli.GetClusterVersion(ctx) + return v, errors.Trace(err) +} + +// GetRegionCount returns the region count in the specified range. +func (p *PdController) GetRegionCount(ctx context.Context, startKey, endKey []byte) (int, error) { + // TiKV reports region start/end keys to PD in memcomparable-format. + var start, end []byte + start = codec.EncodeBytes(nil, startKey) + if len(endKey) != 0 { // Empty end key means the max. + end = codec.EncodeBytes(nil, endKey) + } + status, err := p.pdHTTPCli.GetRegionStatusByKeyRange(ctx, pdhttp.NewKeyRange(start, end), true) + if err != nil { + return 0, errors.Trace(err) + } + return status.Count, nil +} + +// GetStoreInfo returns the info of store with the specified id. +func (p *PdController) GetStoreInfo(ctx context.Context, storeID uint64) (*pdhttp.StoreInfo, error) { + info, err := p.pdHTTPCli.GetStore(ctx, storeID) + return info, errors.Trace(err) +} + +func (p *PdController) doPauseSchedulers( + ctx context.Context, + schedulers []string, +) ([]string, error) { + // pause this scheduler with 300 seconds + delay := int64(p.ttlOfPausing().Seconds()) + removedSchedulers := make([]string, 0, len(schedulers)) + for _, scheduler := range schedulers { + err := p.pdHTTPCli.SetSchedulerDelay(ctx, scheduler, delay) + if err != nil { + return removedSchedulers, errors.Trace(err) + } + removedSchedulers = append(removedSchedulers, scheduler) + } + return removedSchedulers, nil +} + +func (p *PdController) pauseSchedulersAndConfigWith( + ctx context.Context, schedulers []string, + schedulerCfg map[string]any, +) ([]string, error) { + // first pause this scheduler, if the first time failed. we should return the error + // so put first time out of for loop. and in for loop we could ignore other failed pause. 
+ removedSchedulers, err := p.doPauseSchedulers(ctx, schedulers) + if err != nil { + log.Error("failed to pause scheduler at beginning", + zap.Strings("name", schedulers), zap.Error(err)) + return nil, errors.Trace(err) + } + log.Info("pause scheduler successful at beginning", zap.Strings("name", schedulers)) + if schedulerCfg != nil { + err = p.doPauseConfigs(ctx, schedulerCfg) + if err != nil { + log.Error("failed to pause config at beginning", + zap.Any("cfg", schedulerCfg), zap.Error(err)) + return nil, errors.Trace(err) + } + log.Info("pause configs successful at beginning", zap.Any("cfg", schedulerCfg)) + } + + go func() { + tick := time.NewTicker(p.ttlOfPausing() / 3) + defer tick.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-tick.C: + _, err := p.doPauseSchedulers(ctx, schedulers) + if err != nil { + log.Warn("pause scheduler failed, ignore it and wait next time pause", zap.Error(err)) + } + if schedulerCfg != nil { + err = p.doPauseConfigs(ctx, schedulerCfg) + if err != nil { + log.Warn("pause configs failed, ignore it and wait next time pause", zap.Error(err)) + } + } + log.Info("pause scheduler(configs)", zap.Strings("name", removedSchedulers), + zap.Any("cfg", schedulerCfg)) + case <-p.schedulerPauseCh: + log.Info("exit pause scheduler and configs successful") + return + } + } + }() + return removedSchedulers, nil +} + +// ResumeSchedulers resume pd scheduler. +func (p *PdController) ResumeSchedulers(ctx context.Context, schedulers []string) error { + return errors.Trace(p.resumeSchedulerWith(ctx, schedulers)) +} + +func (p *PdController) resumeSchedulerWith(ctx context.Context, schedulers []string) (err error) { + log.Info("resume scheduler", zap.Strings("schedulers", schedulers)) + p.schedulerPauseCh <- struct{}{} + + // 0 means stop pause. + delay := int64(0) + for _, scheduler := range schedulers { + err = p.pdHTTPCli.SetSchedulerDelay(ctx, scheduler, delay) + if err != nil { + log.Error("failed to resume scheduler after retry, you may reset this scheduler manually"+ + "or just wait this scheduler pause timeout", zap.String("scheduler", scheduler)) + } else { + log.Info("resume scheduler successful", zap.String("scheduler", scheduler)) + } + } + // no need to return error, because the pause will timeout. + return nil +} + +// ListSchedulers list all pd scheduler. +func (p *PdController) ListSchedulers(ctx context.Context) ([]string, error) { + s, err := p.pdHTTPCli.GetSchedulers(ctx) + return s, errors.Trace(err) +} + +// GetPDScheduleConfig returns PD schedule config value associated with the key. +// It returns nil if there is no such config item. +func (p *PdController) GetPDScheduleConfig(ctx context.Context) (map[string]any, error) { + cfg, err := p.pdHTTPCli.GetScheduleConfig(ctx) + return cfg, errors.Trace(err) +} + +// UpdatePDScheduleConfig updates PD schedule config value associated with the key. +func (p *PdController) UpdatePDScheduleConfig(ctx context.Context) error { + log.Info("update pd with default config", zap.Any("cfg", defaultPDCfg)) + return errors.Trace(p.doUpdatePDScheduleConfig(ctx, defaultPDCfg)) +} + +func (p *PdController) doUpdatePDScheduleConfig( + ctx context.Context, cfg map[string]any, ttlSeconds ...float64, +) error { + newCfg := make(map[string]any) + for k, v := range cfg { + // if we want use ttl, we need use config prefix first. + // which means cfg should transfer from "max-merge-region-keys" to "schedule.max-merge-region-keys". 
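+		// When ttlSeconds is supplied, PD is expected to treat the override as temporary and fall back to the persisted value once the TTL expires.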
+ sc := fmt.Sprintf("schedule.%s", k) + newCfg[sc] = v + } + + if err := p.pdHTTPCli.SetConfig(ctx, newCfg, ttlSeconds...); err != nil { + return errors.Annotatef( + berrors.ErrPDUpdateFailed, + "failed to update PD schedule config: %s", + err.Error(), + ) + } + return nil +} + +func (p *PdController) doPauseConfigs(ctx context.Context, cfg map[string]any) error { + // pause this scheduler with 300 seconds + return errors.Trace(p.doUpdatePDScheduleConfig(ctx, cfg, p.ttlOfPausing().Seconds())) +} + +func restoreSchedulers(ctx context.Context, pd *PdController, clusterCfg ClusterConfig, + configsNeedRestore map[string]pauseConfigGenerator) error { + if err := pd.ResumeSchedulers(ctx, clusterCfg.Schedulers); err != nil { + return errors.Annotate(err, "fail to add PD schedulers") + } + log.Info("restoring config", zap.Any("config", clusterCfg.ScheduleCfg)) + mergeCfg := make(map[string]any) + for cfgKey := range configsNeedRestore { + value := clusterCfg.ScheduleCfg[cfgKey] + if value == nil { + // Ignore non-exist config. + continue + } + mergeCfg[cfgKey] = value + } + + prefix := make([]float64, 0, 1) + if pd.isPauseConfigEnabled() { + // set config's ttl to zero, make temporary config invalid immediately. + prefix = append(prefix, 0) + } + // reset config with previous value. + if err := pd.doUpdatePDScheduleConfig(ctx, mergeCfg, prefix...); err != nil { + return errors.Annotate(err, "fail to update PD merge config") + } + return nil +} + +// MakeUndoFunctionByConfig return an UndoFunc based on specified ClusterConfig +func (p *PdController) MakeUndoFunctionByConfig(config ClusterConfig) UndoFunc { + return p.GenRestoreSchedulerFunc(config, expectPDCfgGenerators) +} + +// GenRestoreSchedulerFunc gen restore func +func (p *PdController) GenRestoreSchedulerFunc(config ClusterConfig, + configsNeedRestore map[string]pauseConfigGenerator) UndoFunc { + // todo: we only need config names, not a map[string]pauseConfigGenerator + restore := func(ctx context.Context) error { + return restoreSchedulers(ctx, p, config, configsNeedRestore) + } + return restore +} + +// RemoveSchedulers removes the schedulers that may slow down BR speed. +func (p *PdController) RemoveSchedulers(ctx context.Context) (undo UndoFunc, err error) { + undo = Nop + + origin, _, err1 := p.RemoveSchedulersWithOrigin(ctx) + if err1 != nil { + err = err1 + return + } + + undo = p.MakeUndoFunctionByConfig(ClusterConfig{Schedulers: origin.Schedulers, ScheduleCfg: origin.ScheduleCfg}) + return undo, errors.Trace(err) +} + +// RemoveSchedulersWithConfig removes the schedulers that may slow down BR speed. +func (p *PdController) RemoveSchedulersWithConfig( + ctx context.Context, +) (undo UndoFunc, config *ClusterConfig, err error) { + undo = Nop + + origin, _, err1 := p.RemoveSchedulersWithOrigin(ctx) + if err1 != nil { + err = err1 + return + } + + undo = p.MakeUndoFunctionByConfig(ClusterConfig{Schedulers: origin.Schedulers, ScheduleCfg: origin.ScheduleCfg}) + return undo, &origin, errors.Trace(err) +} + +// RemoveAllPDSchedulers pause pd scheduler during the snapshot backup and restore +func (p *PdController) RemoveAllPDSchedulers(ctx context.Context) (undo UndoFunc, err error) { + undo = Nop + + // during the backup, we shall stop all scheduler so that restore easy to implement + // during phase-2, pd is fresh and in recovering-mode(recovering-mark=true), there's no leader + // so there's no leader or region schedule initially. when phase-2 start force setting leaders, schedule may begin. 
+ // we don't want pd do any leader or region schedule during this time, so we set those params to 0 + // before we force setting leaders + const enableTiKVSplitRegion = "enable-tikv-split-region" + scheduleLimitParams := []string{ + "hot-region-schedule-limit", + "leader-schedule-limit", + "merge-schedule-limit", + "region-schedule-limit", + "replica-schedule-limit", + enableTiKVSplitRegion, + } + pdConfigGenerators := DefaultExpectPDCfgGenerators() + for _, param := range scheduleLimitParams { + if param == enableTiKVSplitRegion { + pdConfigGenerators[param] = func(int, any) any { return false } + } else { + pdConfigGenerators[param] = func(int, any) any { return 0 } + } + } + + oldPDConfig, _, err1 := p.RemoveSchedulersWithConfigGenerator(ctx, pdConfigGenerators) + if err1 != nil { + err = err1 + return + } + + undo = p.GenRestoreSchedulerFunc(oldPDConfig, pdConfigGenerators) + return undo, errors.Trace(err) +} + +// RemoveSchedulersWithOrigin pause and remove br related schedule configs and return the origin and modified configs +func (p *PdController) RemoveSchedulersWithOrigin(ctx context.Context) ( + origin ClusterConfig, + modified ClusterConfig, + err error, +) { + origin, modified, err = p.RemoveSchedulersWithConfigGenerator(ctx, expectPDCfgGenerators) + err = errors.Trace(err) + return +} + +// RemoveSchedulersWithConfigGenerator pause scheduler with custom config generator +func (p *PdController) RemoveSchedulersWithConfigGenerator( + ctx context.Context, + pdConfigGenerators map[string]pauseConfigGenerator, +) (origin ClusterConfig, modified ClusterConfig, err error) { + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan("PdController.RemoveSchedulers", + opentracing.ChildOf(span.Context())) + defer span1.Finish() + ctx = opentracing.ContextWithSpan(ctx, span1) + } + + originCfg := ClusterConfig{} + removedCfg := ClusterConfig{} + stores, err := p.pdClient.GetAllStores(ctx) + if err != nil { + return originCfg, removedCfg, errors.Trace(err) + } + scheduleCfg, err := p.GetPDScheduleConfig(ctx) + if err != nil { + return originCfg, removedCfg, errors.Trace(err) + } + disablePDCfg := make(map[string]any, len(pdConfigGenerators)) + originPDCfg := make(map[string]any, len(pdConfigGenerators)) + for cfgKey, cfgValFunc := range pdConfigGenerators { + value, ok := scheduleCfg[cfgKey] + if !ok { + // Ignore non-exist config. + continue + } + disablePDCfg[cfgKey] = cfgValFunc(len(stores), value) + originPDCfg[cfgKey] = value + } + originCfg.ScheduleCfg = originPDCfg + removedCfg.ScheduleCfg = disablePDCfg + + log.Debug("saved PD config", zap.Any("config", scheduleCfg)) + + // Remove default PD scheduler that may affect restore process. 
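+	// Only schedulers listed in the Schedulers set are paused; any other (user-added) schedulers keep running.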
+ existSchedulers, err := p.ListSchedulers(ctx) + if err != nil { + return originCfg, removedCfg, errors.Trace(err) + } + needRemoveSchedulers := make([]string, 0, len(existSchedulers)) + for _, s := range existSchedulers { + if _, ok := Schedulers[s]; ok { + needRemoveSchedulers = append(needRemoveSchedulers, s) + } + } + + removedSchedulers, err := p.doRemoveSchedulersWith(ctx, needRemoveSchedulers, disablePDCfg) + if err != nil { + return originCfg, removedCfg, errors.Trace(err) + } + + originCfg.Schedulers = removedSchedulers + removedCfg.Schedulers = removedSchedulers + + return originCfg, removedCfg, nil +} + +// RemoveSchedulersWithCfg removes pd schedulers and configs with specified ClusterConfig +func (p *PdController) RemoveSchedulersWithCfg(ctx context.Context, removeCfg ClusterConfig) error { + _, err := p.doRemoveSchedulersWith(ctx, removeCfg.Schedulers, removeCfg.ScheduleCfg) + return errors.Trace(err) +} + +func (p *PdController) doRemoveSchedulersWith( + ctx context.Context, + needRemoveSchedulers []string, + disablePDCfg map[string]any, +) ([]string, error) { + if !p.isPauseConfigEnabled() { + return nil, errors.Errorf("pd version %s not support pause config, please upgrade", p.version.String()) + } + // after 4.0.8 we can set these config with TTL + s, err := p.pauseSchedulersAndConfigWith(ctx, needRemoveSchedulers, disablePDCfg) + return s, errors.Trace(err) +} + +// GetMinResolvedTS get min-resolved-ts from pd +func (p *PdController) GetMinResolvedTS(ctx context.Context) (uint64, error) { + ts, _, err := p.pdHTTPCli.GetMinResolvedTSByStoresIDs(ctx, nil) + return ts, errors.Trace(err) +} + +// RecoverBaseAllocID recover base alloc id +func (p *PdController) RecoverBaseAllocID(ctx context.Context, id uint64) error { + return errors.Trace(p.pdHTTPCli.ResetBaseAllocID(ctx, id)) +} + +// ResetTS reset current ts of pd +func (p *PdController) ResetTS(ctx context.Context, ts uint64) error { + // reset-ts of PD will never set ts < current pd ts + // we set force-use-larger=true to allow ts > current pd ts + 24h(on default) + err := p.pdHTTPCli.ResetTS(ctx, ts, true) + if err == nil { + return nil + } + if strings.Contains(err.Error(), http.StatusText(http.StatusForbidden)) { + log.Info("reset-ts returns with status forbidden, ignore") + return nil + } + return errors.Trace(err) +} + +// MarkRecovering mark pd into recovering +func (p *PdController) MarkRecovering(ctx context.Context) error { + return errors.Trace(p.pdHTTPCli.SetSnapshotRecoveringMark(ctx)) +} + +// UnmarkRecovering unmark pd recovering +func (p *PdController) UnmarkRecovering(ctx context.Context) error { + return errors.Trace(p.pdHTTPCli.DeleteSnapshotRecoveringMark(ctx)) +} + +// RegionLabel is the label of a region. This struct is partially copied from +// https://github.com/tikv/pd/blob/783d060861cef37c38cbdcab9777fe95c17907fe/server/schedule/labeler/rules.go#L31. +type RegionLabel struct { + Key string `json:"key"` + Value string `json:"value"` + TTL string `json:"ttl,omitempty"` + StartAt string `json:"start_at,omitempty"` +} + +// LabelRule is the rule to assign labels to a region. This struct is partially copied from +// https://github.com/tikv/pd/blob/783d060861cef37c38cbdcab9777fe95c17907fe/server/schedule/labeler/rules.go#L41. +type LabelRule struct { + ID string `json:"id"` + Labels []RegionLabel `json:"labels"` + RuleType string `json:"rule_type"` + Data any `json:"data"` +} + +// KeyRangeRule contains the start key and end key of the LabelRule. 
This struct is partially copied from +// https://github.com/tikv/pd/blob/783d060861cef37c38cbdcab9777fe95c17907fe/server/schedule/labeler/rules.go#L62. +type KeyRangeRule struct { + StartKeyHex string `json:"start_key"` // hex format start key, for marshal/unmarshal + EndKeyHex string `json:"end_key"` // hex format end key, for marshal/unmarshal +} + +// PauseSchedulersByKeyRange will pause schedulers for regions in the specific key range. +// This function will spawn a goroutine to keep pausing schedulers periodically until the context is done. +// The return done channel is used to notify the caller that the background goroutine is exited. +func PauseSchedulersByKeyRange( + ctx context.Context, + pdHTTPCli pdhttp.Client, + startKey, endKey []byte, +) (done <-chan struct{}, err error) { + done, err = pauseSchedulerByKeyRangeWithTTL(ctx, pdHTTPCli, startKey, endKey, pauseTimeout) + // Wait for the rule to take effect because the PD operator is processed asynchronously. + // To synchronize this, checking the operator status may not be enough. For details, see + // https://github.com/pingcap/tidb/issues/49477. + // Let's use two times default value of `patrol-region-interval` from PD configuration. + <-time.After(20 * time.Millisecond) + return done, errors.Trace(err) +} + +func pauseSchedulerByKeyRangeWithTTL( + ctx context.Context, + pdHTTPCli pdhttp.Client, + startKey, endKey []byte, + ttl time.Duration, +) (<-chan struct{}, error) { + rule := &pdhttp.LabelRule{ + ID: uuid.New().String(), + Labels: []pdhttp.RegionLabel{{ + Key: "schedule", + Value: "deny", + TTL: ttl.String(), + }}, + RuleType: "key-range", + // Data should be a list of KeyRangeRule when rule type is key-range. + // See https://github.com/tikv/pd/blob/783d060861cef37c38cbdcab9777fe95c17907fe/server/schedule/labeler/rules.go#L169. + Data: []KeyRangeRule{{ + StartKeyHex: hex.EncodeToString(startKey), + EndKeyHex: hex.EncodeToString(endKey), + }}, + } + done := make(chan struct{}) + + if err := pdHTTPCli.SetRegionLabelRule(ctx, rule); err != nil { + close(done) + return nil, errors.Trace(err) + } + + go func() { + defer close(done) + ticker := time.NewTicker(ttl / 3) + defer ticker.Stop() + loop: + for { + select { + case <-ticker.C: + if err := pdHTTPCli.SetRegionLabelRule(ctx, rule); err != nil { + if berrors.IsContextCanceled(err) { + break loop + } + log.Warn("pause scheduler by key range failed, ignore it and wait next time pause", + zap.Error(err)) + } + case <-ctx.Done(): + break loop + } + } + // Use a new context to avoid the context is canceled by the caller. + recoverCtx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + // Set ttl to 0 to remove the rule. + rule.Labels[0].TTL = time.Duration(0).String() + deleteRule := &pdhttp.LabelRulePatch{DeleteRules: []string{rule.ID}} + if err := pdHTTPCli.PatchRegionLabelRules(recoverCtx, deleteRule); err != nil { + log.Warn("failed to delete region label rule, the rule will be removed after ttl expires", + zap.String("rule-id", rule.ID), zap.Duration("ttl", ttl), zap.Error(err)) + } + }() + return done, nil +} + +// CanPauseSchedulerByKeyRange returns whether the scheduler can be paused by key range. +func (p *PdController) CanPauseSchedulerByKeyRange() bool { + // We need ttl feature to ensure scheduler can recover from pause automatically. + return p.version.Compare(minVersionForRegionLabelTTL) >= 0 +} + +// Close closes the connection to pd. 
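+// Closing schedulerPauseCh also wakes any background pause goroutine so it stops refreshing the pause.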
+func (p *PdController) Close() { + p.pdClient.Close() + if p.pdHTTPCli != nil { + // nil in some unit tests + p.pdHTTPCli.Close() + } + if p.schedulerPauseCh != nil { + close(p.schedulerPauseCh) + } +} + +func (p *PdController) ttlOfPausing() time.Duration { + if p.SchedulerPauseTTL > 0 { + return p.SchedulerPauseTTL + } + return pauseTimeout +} + +// FetchPDVersion get pd version +func FetchPDVersion(ctx context.Context, pdHTTPCli pdhttp.Client) (*semver.Version, error) { + ver, err := pdHTTPCli.GetPDVersion(ctx) + if err != nil { + return nil, errors.Trace(err) + } + + return parseVersion(ver), nil +} diff --git a/br/pkg/restore/binding__failpoint_binding__.go b/br/pkg/restore/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..dab6e72f95323 --- /dev/null +++ b/br/pkg/restore/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package restore + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/br/pkg/restore/log_client/binding__failpoint_binding__.go b/br/pkg/restore/log_client/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..b17db4979c387 --- /dev/null +++ b/br/pkg/restore/log_client/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package logclient + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/br/pkg/restore/log_client/client.go b/br/pkg/restore/log_client/client.go index fb34e709ce4ec..18a14c76c10e3 100644 --- a/br/pkg/restore/log_client/client.go +++ b/br/pkg/restore/log_client/client.go @@ -828,9 +828,9 @@ func (rc *LogClient) RestoreMetaKVFiles( filesInDefaultCF = SortMetaKVFiles(filesInDefaultCF) filesInWriteCF = SortMetaKVFiles(filesInWriteCF) - failpoint.Inject("failed-before-id-maps-saved", func(_ failpoint.Value) { - failpoint.Return(errors.New("failpoint: failed before id maps saved")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("failed-before-id-maps-saved")); _err_ == nil { + return errors.New("failpoint: failed before id maps saved") + } log.Info("start to restore meta files", zap.Int("total files", len(files)), @@ -848,9 +848,9 @@ func (rc *LogClient) RestoreMetaKVFiles( return errors.Trace(err) } } - failpoint.Inject("failed-after-id-maps-saved", func(_ failpoint.Value) { - failpoint.Return(errors.New("failpoint: failed after id maps saved")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("failed-after-id-maps-saved")); _err_ == nil { + return errors.New("failpoint: failed after id maps saved") + } // run the rewrite and restore meta-kv into TiKV cluster. 
if err := RestoreMetaKVFilesWithBatchMethod( @@ -1087,18 +1087,18 @@ func (rc *LogClient) restoreMetaKvEntries( log.Debug("after rewrite entry", zap.Int("new-key-len", len(newEntry.Key)), zap.Int("new-value-len", len(entry.E.Value)), zap.ByteString("new-key", newEntry.Key)) - failpoint.Inject("failed-to-restore-metakv", func(_ failpoint.Value) { - failpoint.Return(0, 0, errors.Errorf("failpoint: failed to restore metakv")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("failed-to-restore-metakv")); _err_ == nil { + return 0, 0, errors.Errorf("failpoint: failed to restore metakv") + } if err := rc.rawKVClient.Put(ctx, newEntry.Key, newEntry.Value, entry.Ts); err != nil { return 0, 0, errors.Trace(err) } // for failpoint, we need to flush the cache in rawKVClient every time - failpoint.Inject("do-not-put-metakv-in-batch", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("do-not-put-metakv-in-batch")); _err_ == nil { if err := rc.rawKVClient.PutRest(ctx); err != nil { - failpoint.Return(0, 0, errors.Trace(err)) + return 0, 0, errors.Trace(err) } - }) + } kvCount++ size += uint64(len(newEntry.Key) + len(newEntry.Value)) } @@ -1397,11 +1397,11 @@ NEXTSQL: return errors.Trace(err) } } - failpoint.Inject("failed-before-create-ingest-index", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("failed-before-create-ingest-index")); _err_ == nil { if v != nil && v.(bool) { - failpoint.Return(errors.New("failed before create ingest index")) + return errors.New("failed before create ingest index") } - }) + } // create the repaired index when first execution or not found it if err := rc.se.ExecuteInternal(ctx, sql.AddSQL, sql.AddArgs...); err != nil { return errors.Trace(err) diff --git a/br/pkg/restore/log_client/client.go__failpoint_stash__ b/br/pkg/restore/log_client/client.go__failpoint_stash__ new file mode 100644 index 0000000000000..fb34e709ce4ec --- /dev/null +++ b/br/pkg/restore/log_client/client.go__failpoint_stash__ @@ -0,0 +1,1689 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
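+
+// NOTE: this __failpoint_stash__ file looks like the pristine copy of client.go kept by the
+// failpoint toolchain after the failpoint.Inject call sites were expanded in place; disabling
+// failpoints is expected to restore it over the rewritten file.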
+ +package logclient + +import ( + "cmp" + "context" + "crypto/tls" + "fmt" + "math" + "os" + "slices" + "strconv" + "strings" + "sync" + "time" + + "github.com/fatih/color" + "github.com/opentracing/opentracing-go" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/checkpoint" + "github.com/pingcap/tidb/br/pkg/checksum" + "github.com/pingcap/tidb/br/pkg/conn" + "github.com/pingcap/tidb/br/pkg/conn/util" + "github.com/pingcap/tidb/br/pkg/glue" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/metautil" + "github.com/pingcap/tidb/br/pkg/restore" + "github.com/pingcap/tidb/br/pkg/restore/ingestrec" + importclient "github.com/pingcap/tidb/br/pkg/restore/internal/import_client" + logsplit "github.com/pingcap/tidb/br/pkg/restore/internal/log_split" + "github.com/pingcap/tidb/br/pkg/restore/internal/rawkv" + "github.com/pingcap/tidb/br/pkg/restore/split" + "github.com/pingcap/tidb/br/pkg/restore/tiflashrec" + restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/br/pkg/stream" + "github.com/pingcap/tidb/br/pkg/summary" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/br/pkg/utils/iter" + "github.com/pingcap/tidb/br/pkg/version" + ddlutil "github.com/pingcap/tidb/pkg/ddl/util" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/parser/model" + tidbutil "github.com/pingcap/tidb/pkg/util" + filter "github.com/pingcap/tidb/pkg/util/table-filter" + "github.com/tikv/client-go/v2/config" + kvutil "github.com/tikv/client-go/v2/util" + pd "github.com/tikv/pd/client" + pdhttp "github.com/tikv/pd/client/http" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc/keepalive" +) + +const MetaKVBatchSize = 64 * 1024 * 1024 +const maxSplitKeysOnce = 10240 + +// rawKVBatchCount specifies the count of entries that the rawkv client puts into TiKV. +const rawKVBatchCount = 64 + +type LogClient struct { + cipher *backuppb.CipherInfo + pdClient pd.Client + pdHTTPClient pdhttp.Client + clusterID uint64 + dom *domain.Domain + tlsConf *tls.Config + keepaliveConf keepalive.ClientParameters + + rawKVClient *rawkv.RawKVBatchClient + storage storage.ExternalStorage + + se glue.Session + + // currentTS is used for rewrite meta kv when restore stream. + // Can not use `restoreTS` directly, because schema created in `full backup` maybe is new than `restoreTS`. + currentTS uint64 + + *LogFileManager + + workerPool *tidbutil.WorkerPool + fileImporter *LogFileImporter + + // the query to insert rows into table `gc_delete_range`, lack of ts. + deleteRangeQuery []*stream.PreDelRangeQuery + deleteRangeQueryCh chan *stream.PreDelRangeQuery + deleteRangeQueryWaitGroup sync.WaitGroup + + // checkpoint information for log restore + useCheckpoint bool +} + +// NewRestoreClient returns a new RestoreClient. +func NewRestoreClient( + pdClient pd.Client, + pdHTTPCli pdhttp.Client, + tlsConf *tls.Config, + keepaliveConf keepalive.ClientParameters, +) *LogClient { + return &LogClient{ + pdClient: pdClient, + pdHTTPClient: pdHTTPCli, + tlsConf: tlsConf, + keepaliveConf: keepaliveConf, + deleteRangeQuery: make([]*stream.PreDelRangeQuery, 0), + deleteRangeQueryCh: make(chan *stream.PreDelRangeQuery, 10), + } +} + +// Close a client. 
+func (rc *LogClient) Close() { + // close the connection, and it must be succeed when in SQL mode. + if rc.se != nil { + rc.se.Close() + } + + if rc.rawKVClient != nil { + rc.rawKVClient.Close() + } + + if err := rc.fileImporter.Close(); err != nil { + log.Warn("failed to close file improter") + } + + log.Info("Restore client closed") +} + +func (rc *LogClient) SetRawKVBatchClient( + ctx context.Context, + pdAddrs []string, + security config.Security, +) error { + rawkvClient, err := rawkv.NewRawkvClient(ctx, pdAddrs, security) + if err != nil { + return errors.Trace(err) + } + + rc.rawKVClient = rawkv.NewRawKVBatchClient(rawkvClient, rawKVBatchCount) + return nil +} + +func (rc *LogClient) SetCrypter(crypter *backuppb.CipherInfo) { + rc.cipher = crypter +} + +func (rc *LogClient) SetConcurrency(c uint) { + log.Info("download worker pool", zap.Uint("size", c)) + rc.workerPool = tidbutil.NewWorkerPool(c, "file") +} + +func (rc *LogClient) SetStorage(ctx context.Context, backend *backuppb.StorageBackend, opts *storage.ExternalStorageOptions) error { + var err error + rc.storage, err = storage.New(ctx, backend, opts) + if err != nil { + return errors.Trace(err) + } + return nil +} + +func (rc *LogClient) SetCurrentTS(ts uint64) { + rc.currentTS = ts +} + +// GetClusterID gets the cluster id from down-stream cluster. +func (rc *LogClient) GetClusterID(ctx context.Context) uint64 { + if rc.clusterID <= 0 { + rc.clusterID = rc.pdClient.GetClusterID(ctx) + } + return rc.clusterID +} + +func (rc *LogClient) GetDomain() *domain.Domain { + return rc.dom +} + +func (rc *LogClient) CleanUpKVFiles( + ctx context.Context, +) error { + // Current we only have v1 prefix. + // In the future, we can add more operation for this interface. + return rc.fileImporter.ClearFiles(ctx, rc.pdClient, "v1") +} + +func (rc *LogClient) StartCheckpointRunnerForLogRestore(ctx context.Context, taskName string) (*checkpoint.CheckpointRunner[checkpoint.LogRestoreKeyType, checkpoint.LogRestoreValueType], error) { + runner, err := checkpoint.StartCheckpointRunnerForLogRestore(ctx, rc.storage, rc.cipher, taskName) + return runner, errors.Trace(err) +} + +// Init create db connection and domain for storage. +func (rc *LogClient) Init(g glue.Glue, store kv.Storage) error { + var err error + rc.se, err = g.CreateSession(store) + if err != nil { + return errors.Trace(err) + } + + // Set SQL mode to None for avoiding SQL compatibility problem + err = rc.se.Execute(context.Background(), "set @@sql_mode=''") + if err != nil { + return errors.Trace(err) + } + + rc.dom, err = g.GetDomain(store) + if err != nil { + return errors.Trace(err) + } + + return nil +} + +func (rc *LogClient) InitClients(ctx context.Context, backend *backuppb.StorageBackend) { + stores, err := conn.GetAllTiKVStoresWithRetry(ctx, rc.pdClient, util.SkipTiFlash) + if err != nil { + log.Fatal("failed to get stores", zap.Error(err)) + } + + metaClient := split.NewClient(rc.pdClient, rc.pdHTTPClient, rc.tlsConf, maxSplitKeysOnce, len(stores)+1) + importCli := importclient.NewImportClient(metaClient, rc.tlsConf, rc.keepaliveConf) + rc.fileImporter = NewLogFileImporter(metaClient, importCli, backend) +} + +func (rc *LogClient) InitCheckpointMetadataForLogRestore(ctx context.Context, taskName string, gcRatio string) (string, error) { + rc.useCheckpoint = true + + // it shows that the user has modified gc-ratio, if `gcRatio` doesn't equal to "1.1". + // update the `gcRatio` for checkpoint metadata. 
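+	// An explicitly configured gc-ratio always wins; only the default value may be replaced by the ratio recorded in an existing checkpoint.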
+ if gcRatio == utils.DefaultGcRatioVal { + // if the checkpoint metadata exists in the external storage, the restore is not + // for the first time. + exists, err := checkpoint.ExistsRestoreCheckpoint(ctx, rc.storage, taskName) + if err != nil { + return "", errors.Trace(err) + } + + if exists { + // load the checkpoint since this is not the first time to restore + meta, err := checkpoint.LoadCheckpointMetadataForRestore(ctx, rc.storage, taskName) + if err != nil { + return "", errors.Trace(err) + } + + log.Info("reuse gc ratio from checkpoint metadata", zap.String("gc-ratio", gcRatio)) + return meta.GcRatio, nil + } + } + + // initialize the checkpoint metadata since it is the first time to restore. + log.Info("save gc ratio into checkpoint metadata", zap.String("gc-ratio", gcRatio)) + if err := checkpoint.SaveCheckpointMetadataForRestore(ctx, rc.storage, &checkpoint.CheckpointMetadataForRestore{ + GcRatio: gcRatio, + }, taskName); err != nil { + return gcRatio, errors.Trace(err) + } + + return gcRatio, nil +} + +func (rc *LogClient) InstallLogFileManager(ctx context.Context, startTS, restoreTS uint64, metadataDownloadBatchSize uint) error { + init := LogFileManagerInit{ + StartTS: startTS, + RestoreTS: restoreTS, + Storage: rc.storage, + + MetadataDownloadBatchSize: metadataDownloadBatchSize, + } + var err error + rc.LogFileManager, err = CreateLogFileManager(ctx, init) + if err != nil { + return err + } + return nil +} + +type FilesInRegion struct { + defaultSize uint64 + defaultKVCount int64 + writeSize uint64 + writeKVCount int64 + + defaultFiles []*LogDataFileInfo + writeFiles []*LogDataFileInfo + deleteFiles []*LogDataFileInfo +} + +type FilesInTable struct { + regionMapFiles map[int64]*FilesInRegion +} + +func ApplyKVFilesWithBatchMethod( + ctx context.Context, + logIter LogIter, + batchCount int, + batchSize uint64, + applyFunc func(files []*LogDataFileInfo, kvCount int64, size uint64), + applyWg *sync.WaitGroup, +) error { + var ( + tableMapFiles = make(map[int64]*FilesInTable) + tmpFiles = make([]*LogDataFileInfo, 0, batchCount) + tmpSize uint64 = 0 + tmpKVCount int64 = 0 + ) + for r := logIter.TryNext(ctx); !r.Finished; r = logIter.TryNext(ctx) { + if r.Err != nil { + return r.Err + } + + f := r.Item + if f.GetType() == backuppb.FileType_Put && f.GetLength() >= batchSize { + applyFunc([]*LogDataFileInfo{f}, f.GetNumberOfEntries(), f.GetLength()) + continue + } + + fit, exist := tableMapFiles[f.TableId] + if !exist { + fit = &FilesInTable{ + regionMapFiles: make(map[int64]*FilesInRegion), + } + tableMapFiles[f.TableId] = fit + } + fs, exist := fit.regionMapFiles[f.RegionId] + if !exist { + fs = &FilesInRegion{} + fit.regionMapFiles[f.RegionId] = fs + } + + if f.GetType() == backuppb.FileType_Delete { + if fs.defaultFiles == nil { + fs.deleteFiles = make([]*LogDataFileInfo, 0) + } + fs.deleteFiles = append(fs.deleteFiles, f) + } else { + if f.GetCf() == stream.DefaultCF { + if fs.defaultFiles == nil { + fs.defaultFiles = make([]*LogDataFileInfo, 0, batchCount) + } + fs.defaultFiles = append(fs.defaultFiles, f) + fs.defaultSize += f.Length + fs.defaultKVCount += f.GetNumberOfEntries() + if len(fs.defaultFiles) >= batchCount || fs.defaultSize >= batchSize { + applyFunc(fs.defaultFiles, fs.defaultKVCount, fs.defaultSize) + fs.defaultFiles = nil + fs.defaultSize = 0 + fs.defaultKVCount = 0 + } + } else { + if fs.writeFiles == nil { + fs.writeFiles = make([]*LogDataFileInfo, 0, batchCount) + } + fs.writeFiles = append(fs.writeFiles, f) + fs.writeSize += f.GetLength() + 
fs.writeKVCount += f.GetNumberOfEntries() + if len(fs.writeFiles) >= batchCount || fs.writeSize >= batchSize { + applyFunc(fs.writeFiles, fs.writeKVCount, fs.writeSize) + fs.writeFiles = nil + fs.writeSize = 0 + fs.writeKVCount = 0 + } + } + } + } + + for _, fwt := range tableMapFiles { + for _, fs := range fwt.regionMapFiles { + if len(fs.defaultFiles) > 0 { + applyFunc(fs.defaultFiles, fs.defaultKVCount, fs.defaultSize) + } + if len(fs.writeFiles) > 0 { + applyFunc(fs.writeFiles, fs.writeKVCount, fs.writeSize) + } + } + } + + applyWg.Wait() + for _, fwt := range tableMapFiles { + for _, fs := range fwt.regionMapFiles { + for _, d := range fs.deleteFiles { + tmpFiles = append(tmpFiles, d) + tmpSize += d.GetLength() + tmpKVCount += d.GetNumberOfEntries() + + if len(tmpFiles) >= batchCount || tmpSize >= batchSize { + applyFunc(tmpFiles, tmpKVCount, tmpSize) + tmpFiles = make([]*LogDataFileInfo, 0, batchCount) + tmpSize = 0 + tmpKVCount = 0 + } + } + if len(tmpFiles) > 0 { + applyFunc(tmpFiles, tmpKVCount, tmpSize) + tmpFiles = make([]*LogDataFileInfo, 0, batchCount) + tmpSize = 0 + tmpKVCount = 0 + } + } + } + + return nil +} + +func ApplyKVFilesWithSingelMethod( + ctx context.Context, + files LogIter, + applyFunc func(file []*LogDataFileInfo, kvCount int64, size uint64), + applyWg *sync.WaitGroup, +) error { + deleteKVFiles := make([]*LogDataFileInfo, 0) + + for r := files.TryNext(ctx); !r.Finished; r = files.TryNext(ctx) { + if r.Err != nil { + return r.Err + } + + f := r.Item + if f.GetType() == backuppb.FileType_Delete { + deleteKVFiles = append(deleteKVFiles, f) + continue + } + applyFunc([]*LogDataFileInfo{f}, f.GetNumberOfEntries(), f.GetLength()) + } + + applyWg.Wait() + log.Info("restore delete files", zap.Int("count", len(deleteKVFiles))) + for _, file := range deleteKVFiles { + f := file + applyFunc([]*LogDataFileInfo{f}, f.GetNumberOfEntries(), f.GetLength()) + } + + return nil +} + +func (rc *LogClient) RestoreKVFiles( + ctx context.Context, + rules map[int64]*restoreutils.RewriteRules, + idrules map[int64]int64, + logIter LogIter, + runner *checkpoint.CheckpointRunner[checkpoint.LogRestoreKeyType, checkpoint.LogRestoreValueType], + pitrBatchCount uint32, + pitrBatchSize uint32, + updateStats func(kvCount uint64, size uint64), + onProgress func(cnt int64), +) error { + var ( + err error + fileCount = 0 + start = time.Now() + supportBatch = version.CheckPITRSupportBatchKVFiles() + skipFile = 0 + ) + defer func() { + if err == nil { + elapsed := time.Since(start) + log.Info("Restore KV files", zap.Duration("take", elapsed)) + summary.CollectSuccessUnit("files", fileCount, elapsed) + } + }() + + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan("Client.RestoreKVFiles", opentracing.ChildOf(span.Context())) + defer span1.Finish() + ctx = opentracing.ContextWithSpan(ctx, span1) + } + + var applyWg sync.WaitGroup + eg, ectx := errgroup.WithContext(ctx) + applyFunc := func(files []*LogDataFileInfo, kvCount int64, size uint64) { + if len(files) == 0 { + return + } + // get rewrite rule from table id. + // because the tableID of files is the same. + rule, ok := rules[files[0].TableId] + if !ok { + // TODO handle new created table + // For this version we do not handle new created table after full backup. + // in next version we will perform rewrite and restore meta key to restore new created tables. + // so we can simply skip the file that doesn't have the rule here. 
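+			// Skipped files still advance the progress bar and the FileSkip counter so the overall accounting stays consistent.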
+ onProgress(int64(len(files))) + summary.CollectInt("FileSkip", len(files)) + log.Debug("skip file due to table id not matched", zap.Int64("table-id", files[0].TableId)) + skipFile += len(files) + } else { + applyWg.Add(1) + downstreamId := idrules[files[0].TableId] + rc.workerPool.ApplyOnErrorGroup(eg, func() (err error) { + fileStart := time.Now() + defer applyWg.Done() + defer func() { + onProgress(int64(len(files))) + updateStats(uint64(kvCount), size) + summary.CollectInt("File", len(files)) + + if err == nil { + filenames := make([]string, 0, len(files)) + if runner == nil { + for _, f := range files { + filenames = append(filenames, f.Path+", ") + } + } else { + for _, f := range files { + filenames = append(filenames, f.Path+", ") + if e := checkpoint.AppendRangeForLogRestore(ectx, runner, f.MetaDataGroupName, downstreamId, f.OffsetInMetaGroup, f.OffsetInMergedGroup); e != nil { + err = errors.Annotate(e, "failed to append checkpoint data") + break + } + } + } + log.Info("import files done", zap.Int("batch-count", len(files)), zap.Uint64("batch-size", size), + zap.Duration("take", time.Since(fileStart)), zap.Strings("files", filenames)) + } + }() + + return rc.fileImporter.ImportKVFiles(ectx, files, rule, rc.shiftStartTS, rc.startTS, rc.restoreTS, supportBatch) + }) + } + } + + rc.workerPool.ApplyOnErrorGroup(eg, func() error { + if supportBatch { + err = ApplyKVFilesWithBatchMethod(ectx, logIter, int(pitrBatchCount), uint64(pitrBatchSize), applyFunc, &applyWg) + } else { + err = ApplyKVFilesWithSingelMethod(ectx, logIter, applyFunc, &applyWg) + } + return errors.Trace(err) + }) + + if err = eg.Wait(); err != nil { + summary.CollectFailureUnit("file", err) + log.Error("restore files failed", zap.Error(err)) + } + + log.Info("total skip files due to table id not matched", zap.Int("count", skipFile)) + if skipFile > 0 { + log.Debug("table id in full backup storage", zap.Any("tables", rules)) + } + + return errors.Trace(err) +} + +func (rc *LogClient) initSchemasMap( + ctx context.Context, + clusterID uint64, + restoreTS uint64, +) ([]*backuppb.PitrDBMap, error) { + filename := metautil.PitrIDMapsFilename(clusterID, restoreTS) + exist, err := rc.storage.FileExists(ctx, filename) + if err != nil { + return nil, errors.Annotatef(err, "failed to check filename:%s ", filename) + } else if !exist { + log.Info("pitr id maps isn't existed", zap.String("file", filename)) + return nil, nil + } + + metaData, err := rc.storage.ReadFile(ctx, filename) + if err != nil { + return nil, errors.Trace(err) + } + backupMeta := &backuppb.BackupMeta{} + if err = backupMeta.Unmarshal(metaData); err != nil { + return nil, errors.Trace(err) + } + + return backupMeta.GetDbMaps(), nil +} + +func initFullBackupTables( + ctx context.Context, + s storage.ExternalStorage, + tableFilter filter.Filter, +) (map[int64]*metautil.Table, error) { + metaData, err := s.ReadFile(ctx, metautil.MetaFile) + if err != nil { + return nil, errors.Trace(err) + } + backupMeta := &backuppb.BackupMeta{} + if err = backupMeta.Unmarshal(metaData); err != nil { + return nil, errors.Trace(err) + } + + // read full backup databases to get map[table]table.Info + reader := metautil.NewMetaReader(backupMeta, s, nil) + + databases, err := metautil.LoadBackupTables(ctx, reader, false) + if err != nil { + return nil, errors.Trace(err) + } + + tables := make(map[int64]*metautil.Table) + for _, db := range databases { + dbName := db.Info.Name.O + if name, ok := utils.GetSysDBName(db.Info.Name); utils.IsSysDB(name) && ok { + dbName = name + } + + 
if !tableFilter.MatchSchema(dbName) { + continue + } + + for _, table := range db.Tables { + // check this db is empty. + if table.Info == nil { + tables[db.Info.ID] = table + continue + } + if !tableFilter.MatchTable(dbName, table.Info.Name.O) { + continue + } + tables[table.Info.ID] = table + } + } + + return tables, nil +} + +type FullBackupStorageConfig struct { + Backend *backuppb.StorageBackend + Opts *storage.ExternalStorageOptions +} + +type InitSchemaConfig struct { + // required + IsNewTask bool + TableFilter filter.Filter + + // optional + TiFlashRecorder *tiflashrec.TiFlashRecorder + FullBackupStorage *FullBackupStorageConfig +} + +const UnsafePITRLogRestoreStartBeforeAnyUpstreamUserDDL = "UNSAFE_PITR_LOG_RESTORE_START_BEFORE_ANY_UPSTREAM_USER_DDL" + +func (rc *LogClient) generateDBReplacesFromFullBackupStorage( + ctx context.Context, + cfg *InitSchemaConfig, +) (map[stream.UpstreamID]*stream.DBReplace, error) { + dbReplaces := make(map[stream.UpstreamID]*stream.DBReplace) + if cfg.FullBackupStorage == nil { + envVal, ok := os.LookupEnv(UnsafePITRLogRestoreStartBeforeAnyUpstreamUserDDL) + if ok && len(envVal) > 0 { + log.Info(fmt.Sprintf("the environment variable %s is active, skip loading the base schemas.", UnsafePITRLogRestoreStartBeforeAnyUpstreamUserDDL)) + return dbReplaces, nil + } + return nil, errors.Errorf("miss upstream table information at `start-ts`(%d) but the full backup path is not specified", rc.startTS) + } + s, err := storage.New(ctx, cfg.FullBackupStorage.Backend, cfg.FullBackupStorage.Opts) + if err != nil { + return nil, errors.Trace(err) + } + fullBackupTables, err := initFullBackupTables(ctx, s, cfg.TableFilter) + if err != nil { + return nil, errors.Trace(err) + } + for _, t := range fullBackupTables { + dbName, _ := utils.GetSysDBCIStrName(t.DB.Name) + newDBInfo, exist := rc.dom.InfoSchema().SchemaByName(dbName) + if !exist { + log.Info("db not existed", zap.String("dbname", dbName.String())) + continue + } + + dbReplace, exist := dbReplaces[t.DB.ID] + if !exist { + dbReplace = stream.NewDBReplace(t.DB.Name.O, newDBInfo.ID) + dbReplaces[t.DB.ID] = dbReplace + } + + if t.Info == nil { + // If the db is empty, skip it. + continue + } + newTableInfo, err := restore.GetTableSchema(rc.GetDomain(), dbName, t.Info.Name) + if err != nil { + log.Info("table not existed", zap.String("tablename", dbName.String()+"."+t.Info.Name.String())) + continue + } + + dbReplace.TableMap[t.Info.ID] = &stream.TableReplace{ + Name: newTableInfo.Name.O, + TableID: newTableInfo.ID, + PartitionMap: restoreutils.GetPartitionIDMap(newTableInfo, t.Info), + IndexMap: restoreutils.GetIndexIDMap(newTableInfo, t.Info), + } + } + return dbReplaces, nil +} + +// InitSchemasReplaceForDDL gets schemas information Mapping from old schemas to new schemas. +// It is used to rewrite meta kv-event. 
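+// The mapping is loaded from a previously saved pitr id map when possible; otherwise it is rebuilt from the cluster's info schema and the full backup storage.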
+func (rc *LogClient) InitSchemasReplaceForDDL( + ctx context.Context, + cfg *InitSchemaConfig, +) (*stream.SchemasReplace, error) { + var ( + err error + dbMaps []*backuppb.PitrDBMap + // the id map doesn't need to construct only when it is not the first execution + needConstructIdMap bool + + dbReplaces map[stream.UpstreamID]*stream.DBReplace + ) + + // not new task, load schemas map from external storage + if !cfg.IsNewTask { + log.Info("try to load pitr id maps") + needConstructIdMap = false + dbMaps, err = rc.initSchemasMap(ctx, rc.GetClusterID(ctx), rc.restoreTS) + if err != nil { + return nil, errors.Trace(err) + } + } + + // a new task, but without full snapshot restore, tries to load + // schemas map whose `restore-ts`` is the task's `start-ts`. + if len(dbMaps) <= 0 && cfg.FullBackupStorage == nil { + log.Info("try to load pitr id maps of the previous task", zap.Uint64("start-ts", rc.startTS)) + needConstructIdMap = true + dbMaps, err = rc.initSchemasMap(ctx, rc.GetClusterID(ctx), rc.startTS) + if err != nil { + return nil, errors.Trace(err) + } + existTiFlashTable := false + rc.dom.InfoSchema().ListTablesWithSpecialAttribute(func(tableInfo *model.TableInfo) bool { + if tableInfo.TiFlashReplica != nil && tableInfo.TiFlashReplica.Count > 0 { + existTiFlashTable = true + } + return false + }) + if existTiFlashTable { + return nil, errors.Errorf("exist table(s) have tiflash replica, please remove it before restore") + } + } + + if len(dbMaps) <= 0 { + log.Info("no id maps, build the table replaces from cluster and full backup schemas") + needConstructIdMap = true + dbReplaces, err = rc.generateDBReplacesFromFullBackupStorage(ctx, cfg) + if err != nil { + return nil, errors.Trace(err) + } + } else { + dbReplaces = stream.FromSchemaMaps(dbMaps) + } + + for oldDBID, dbReplace := range dbReplaces { + log.Info("replace info", func() []zapcore.Field { + fields := make([]zapcore.Field, 0, (len(dbReplace.TableMap)+1)*3) + fields = append(fields, + zap.String("dbName", dbReplace.Name), + zap.Int64("oldID", oldDBID), + zap.Int64("newID", dbReplace.DbID)) + for oldTableID, tableReplace := range dbReplace.TableMap { + fields = append(fields, + zap.String("table", tableReplace.Name), + zap.Int64("oldID", oldTableID), + zap.Int64("newID", tableReplace.TableID)) + } + return fields + }()...) + } + + rp := stream.NewSchemasReplace( + dbReplaces, needConstructIdMap, cfg.TiFlashRecorder, rc.currentTS, cfg.TableFilter, rc.GenGlobalID, rc.GenGlobalIDs, + rc.RecordDeleteRange) + return rp, nil +} + +func SortMetaKVFiles(files []*backuppb.DataFileInfo) []*backuppb.DataFileInfo { + slices.SortFunc(files, func(i, j *backuppb.DataFileInfo) int { + if c := cmp.Compare(i.GetMinTs(), j.GetMinTs()); c != 0 { + return c + } + if c := cmp.Compare(i.GetMaxTs(), j.GetMaxTs()); c != 0 { + return c + } + return cmp.Compare(i.GetResolvedTs(), j.GetResolvedTs()) + }) + return files +} + +// RestoreMetaKVFiles tries to restore files about meta kv-event from stream-backup. +func (rc *LogClient) RestoreMetaKVFiles( + ctx context.Context, + files []*backuppb.DataFileInfo, + schemasReplace *stream.SchemasReplace, + updateStats func(kvCount uint64, size uint64), + progressInc func(), +) error { + filesInWriteCF := make([]*backuppb.DataFileInfo, 0, len(files)) + filesInDefaultCF := make([]*backuppb.DataFileInfo, 0, len(files)) + + // The k-v events in default CF should be restored firstly. 
The reason is that: + // The error of transactions of meta could happen if restore write CF events successfully, + // but failed to restore default CF events. + for _, f := range files { + if f.Cf == stream.WriteCF { + filesInWriteCF = append(filesInWriteCF, f) + continue + } + if f.Type == backuppb.FileType_Delete { + // this should happen abnormally. + // only do some preventive checks here. + log.Warn("detected delete file of meta key, skip it", zap.Any("file", f)) + continue + } + if f.Cf == stream.DefaultCF { + filesInDefaultCF = append(filesInDefaultCF, f) + } + } + filesInDefaultCF = SortMetaKVFiles(filesInDefaultCF) + filesInWriteCF = SortMetaKVFiles(filesInWriteCF) + + failpoint.Inject("failed-before-id-maps-saved", func(_ failpoint.Value) { + failpoint.Return(errors.New("failpoint: failed before id maps saved")) + }) + + log.Info("start to restore meta files", + zap.Int("total files", len(files)), + zap.Int("default files", len(filesInDefaultCF)), + zap.Int("write files", len(filesInWriteCF))) + + if schemasReplace.NeedConstructIdMap() { + // Preconstruct the map and save it into external storage. + if err := rc.PreConstructAndSaveIDMap( + ctx, + filesInWriteCF, + filesInDefaultCF, + schemasReplace, + ); err != nil { + return errors.Trace(err) + } + } + failpoint.Inject("failed-after-id-maps-saved", func(_ failpoint.Value) { + failpoint.Return(errors.New("failpoint: failed after id maps saved")) + }) + + // run the rewrite and restore meta-kv into TiKV cluster. + if err := RestoreMetaKVFilesWithBatchMethod( + ctx, + filesInDefaultCF, + filesInWriteCF, + schemasReplace, + updateStats, + progressInc, + rc.RestoreBatchMetaKVFiles, + ); err != nil { + return errors.Trace(err) + } + + // Update global schema version and report all of TiDBs. + if err := rc.UpdateSchemaVersion(ctx); err != nil { + return errors.Trace(err) + } + return nil +} + +// PreConstructAndSaveIDMap constructs id mapping and save it. 
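+// It replays the write CF entries first, then the default CF entries, and persists the result so a retried task can reload the map instead of rebuilding it.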
+func (rc *LogClient) PreConstructAndSaveIDMap( + ctx context.Context, + fsInWriteCF, fsInDefaultCF []*backuppb.DataFileInfo, + sr *stream.SchemasReplace, +) error { + sr.SetPreConstructMapStatus() + + if err := rc.constructIDMap(ctx, fsInWriteCF, sr); err != nil { + return errors.Trace(err) + } + if err := rc.constructIDMap(ctx, fsInDefaultCF, sr); err != nil { + return errors.Trace(err) + } + + if err := rc.SaveIDMap(ctx, sr); err != nil { + return errors.Trace(err) + } + return nil +} + +func (rc *LogClient) constructIDMap( + ctx context.Context, + fs []*backuppb.DataFileInfo, + sr *stream.SchemasReplace, +) error { + for _, f := range fs { + entries, _, err := rc.ReadAllEntries(ctx, f, math.MaxUint64) + if err != nil { + return errors.Trace(err) + } + + for _, entry := range entries { + if _, err := sr.RewriteKvEntry(&entry.E, f.GetCf()); err != nil { + return errors.Trace(err) + } + } + } + return nil +} + +func RestoreMetaKVFilesWithBatchMethod( + ctx context.Context, + defaultFiles []*backuppb.DataFileInfo, + writeFiles []*backuppb.DataFileInfo, + schemasReplace *stream.SchemasReplace, + updateStats func(kvCount uint64, size uint64), + progressInc func(), + restoreBatch func( + ctx context.Context, + files []*backuppb.DataFileInfo, + schemasReplace *stream.SchemasReplace, + kvEntries []*KvEntryWithTS, + filterTS uint64, + updateStats func(kvCount uint64, size uint64), + progressInc func(), + cf string, + ) ([]*KvEntryWithTS, error), +) error { + // the average size of each KV is 2560 Bytes + // kvEntries is kvs left by the previous batch + const kvSize = 2560 + var ( + rangeMin uint64 + rangeMax uint64 + err error + + batchSize uint64 = 0 + defaultIdx int = 0 + writeIdx int = 0 + + defaultKvEntries = make([]*KvEntryWithTS, 0) + writeKvEntries = make([]*KvEntryWithTS, 0) + ) + // Set restoreKV to SchemaReplace. + schemasReplace.SetRestoreKVStatus() + + for i, f := range defaultFiles { + if i == 0 { + rangeMax = f.MaxTs + rangeMin = f.MinTs + batchSize = f.Length + } else { + if f.MinTs <= rangeMax && batchSize+f.Length <= MetaKVBatchSize { + rangeMin = min(rangeMin, f.MinTs) + rangeMax = max(rangeMax, f.MaxTs) + batchSize += f.Length + } else { + // Either f.MinTS > rangeMax or f.MinTs is the filterTs we need. + // So it is ok to pass f.MinTs as filterTs. + defaultKvEntries, err = restoreBatch(ctx, defaultFiles[defaultIdx:i], schemasReplace, defaultKvEntries, f.MinTs, updateStats, progressInc, stream.DefaultCF) + if err != nil { + return errors.Trace(err) + } + defaultIdx = i + rangeMin = f.MinTs + rangeMax = f.MaxTs + // the initial batch size is the size of left kvs and the current file length. 
+ batchSize = uint64(len(defaultKvEntries)*kvSize) + f.Length + + // restore writeCF kv to f.MinTs + var toWriteIdx int + for toWriteIdx = writeIdx; toWriteIdx < len(writeFiles); toWriteIdx++ { + if writeFiles[toWriteIdx].MinTs >= f.MinTs { + break + } + } + writeKvEntries, err = restoreBatch(ctx, writeFiles[writeIdx:toWriteIdx], schemasReplace, writeKvEntries, f.MinTs, updateStats, progressInc, stream.WriteCF) + if err != nil { + return errors.Trace(err) + } + writeIdx = toWriteIdx + } + } + } + + // restore the left meta kv files and entries + // Notice: restoreBatch needs to realize the parameter `files` and `kvEntries` might be empty + // Assert: defaultIdx <= len(defaultFiles) && writeIdx <= len(writeFiles) + _, err = restoreBatch(ctx, defaultFiles[defaultIdx:], schemasReplace, defaultKvEntries, math.MaxUint64, updateStats, progressInc, stream.DefaultCF) + if err != nil { + return errors.Trace(err) + } + _, err = restoreBatch(ctx, writeFiles[writeIdx:], schemasReplace, writeKvEntries, math.MaxUint64, updateStats, progressInc, stream.WriteCF) + if err != nil { + return errors.Trace(err) + } + + return nil +} + +func (rc *LogClient) RestoreBatchMetaKVFiles( + ctx context.Context, + files []*backuppb.DataFileInfo, + schemasReplace *stream.SchemasReplace, + kvEntries []*KvEntryWithTS, + filterTS uint64, + updateStats func(kvCount uint64, size uint64), + progressInc func(), + cf string, +) ([]*KvEntryWithTS, error) { + nextKvEntries := make([]*KvEntryWithTS, 0) + curKvEntries := make([]*KvEntryWithTS, 0) + if len(files) == 0 && len(kvEntries) == 0 { + return nextKvEntries, nil + } + + // filter the kv from kvEntries again. + for _, kv := range kvEntries { + if kv.Ts < filterTS { + curKvEntries = append(curKvEntries, kv) + } else { + nextKvEntries = append(nextKvEntries, kv) + } + } + + // read all of entries from files. + for _, f := range files { + es, nextEs, err := rc.ReadAllEntries(ctx, f, filterTS) + if err != nil { + return nextKvEntries, errors.Trace(err) + } + + curKvEntries = append(curKvEntries, es...) + nextKvEntries = append(nextKvEntries, nextEs...) + } + + // sort these entries. + slices.SortFunc(curKvEntries, func(i, j *KvEntryWithTS) int { + return cmp.Compare(i.Ts, j.Ts) + }) + + // restore these entries with rawPut() method. 
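+	// curKvEntries were sorted by ts above, so the raw puts replay the meta changes in commit order.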
+ kvCount, size, err := rc.restoreMetaKvEntries(ctx, schemasReplace, curKvEntries, cf) + if err != nil { + return nextKvEntries, errors.Trace(err) + } + + if schemasReplace.IsRestoreKVStatus() { + updateStats(kvCount, size) + for i := 0; i < len(files); i++ { + progressInc() + } + } + return nextKvEntries, nil +} + +func (rc *LogClient) restoreMetaKvEntries( + ctx context.Context, + sr *stream.SchemasReplace, + entries []*KvEntryWithTS, + columnFamily string, +) (uint64, uint64, error) { + var ( + kvCount uint64 + size uint64 + ) + + rc.rawKVClient.SetColumnFamily(columnFamily) + + for _, entry := range entries { + log.Debug("before rewrte entry", zap.Uint64("key-ts", entry.Ts), zap.Int("key-len", len(entry.E.Key)), + zap.Int("value-len", len(entry.E.Value)), zap.ByteString("key", entry.E.Key)) + + newEntry, err := sr.RewriteKvEntry(&entry.E, columnFamily) + if err != nil { + log.Error("rewrite txn entry failed", zap.Int("klen", len(entry.E.Key)), + logutil.Key("txn-key", entry.E.Key)) + return 0, 0, errors.Trace(err) + } else if newEntry == nil { + continue + } + log.Debug("after rewrite entry", zap.Int("new-key-len", len(newEntry.Key)), + zap.Int("new-value-len", len(entry.E.Value)), zap.ByteString("new-key", newEntry.Key)) + + failpoint.Inject("failed-to-restore-metakv", func(_ failpoint.Value) { + failpoint.Return(0, 0, errors.Errorf("failpoint: failed to restore metakv")) + }) + if err := rc.rawKVClient.Put(ctx, newEntry.Key, newEntry.Value, entry.Ts); err != nil { + return 0, 0, errors.Trace(err) + } + // for failpoint, we need to flush the cache in rawKVClient every time + failpoint.Inject("do-not-put-metakv-in-batch", func(_ failpoint.Value) { + if err := rc.rawKVClient.PutRest(ctx); err != nil { + failpoint.Return(0, 0, errors.Trace(err)) + } + }) + kvCount++ + size += uint64(len(newEntry.Key) + len(newEntry.Value)) + } + + return kvCount, size, rc.rawKVClient.PutRest(ctx) +} + +// GenGlobalID generates a global id by transaction way. +func (rc *LogClient) GenGlobalID(ctx context.Context) (int64, error) { + var id int64 + storage := rc.GetDomain().Store() + + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnBR) + err := kv.RunInNewTxn( + ctx, + storage, + true, + func(ctx context.Context, txn kv.Transaction) error { + var e error + t := meta.NewMeta(txn) + id, e = t.GenGlobalID() + return e + }) + + return id, err +} + +// GenGlobalIDs generates several global ids by transaction way. +func (rc *LogClient) GenGlobalIDs(ctx context.Context, n int) ([]int64, error) { + ids := make([]int64, 0) + storage := rc.GetDomain().Store() + + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnBR) + err := kv.RunInNewTxn( + ctx, + storage, + true, + func(ctx context.Context, txn kv.Transaction) error { + var e error + t := meta.NewMeta(txn) + ids, e = t.GenGlobalIDs(n) + return e + }) + + return ids, err +} + +// UpdateSchemaVersion updates schema version by transaction way. +func (rc *LogClient) UpdateSchemaVersion(ctx context.Context) error { + storage := rc.GetDomain().Store() + var schemaVersion int64 + + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnBR) + if err := kv.RunInNewTxn( + ctx, + storage, + true, + func(ctx context.Context, txn kv.Transaction) error { + t := meta.NewMeta(txn) + var e error + // To trigger full-reload instead of diff-reload, we need to increase the schema version + // by at least `domain.LoadSchemaDiffVersionGapThreshold`. 
+ schemaVersion, e = t.GenSchemaVersions(64 + domain.LoadSchemaDiffVersionGapThreshold) + if e != nil { + return e + } + // add the diff key so that the domain won't retry to reload the schemas with `schemaVersion` frequently. + return t.SetSchemaDiff(&model.SchemaDiff{ + Version: schemaVersion, + Type: model.ActionNone, + SchemaID: -1, + TableID: -1, + RegenerateSchemaMap: true, + }) + }, + ); err != nil { + return errors.Trace(err) + } + + log.Info("update global schema version", zap.Int64("global-schema-version", schemaVersion)) + + ver := strconv.FormatInt(schemaVersion, 10) + if err := ddlutil.PutKVToEtcd( + ctx, + rc.GetDomain().GetEtcdClient(), + math.MaxInt, + ddlutil.DDLGlobalSchemaVersion, + ver, + ); err != nil { + return errors.Annotatef(err, "failed to put global schema verson %v to etcd", ver) + } + + return nil +} + +func (rc *LogClient) WrapLogFilesIterWithSplitHelper(logIter LogIter, rules map[int64]*restoreutils.RewriteRules, g glue.Glue, store kv.Storage) (LogIter, error) { + se, err := g.CreateSession(store) + if err != nil { + return nil, errors.Trace(err) + } + execCtx := se.GetSessionCtx().GetRestrictedSQLExecutor() + splitSize, splitKeys := utils.GetRegionSplitInfo(execCtx) + log.Info("get split threshold from tikv config", zap.Uint64("split-size", splitSize), zap.Int64("split-keys", splitKeys)) + client := split.NewClient(rc.pdClient, rc.pdHTTPClient, rc.tlsConf, maxSplitKeysOnce, 3) + return NewLogFilesIterWithSplitHelper(logIter, rules, client, splitSize, splitKeys), nil +} + +func (rc *LogClient) generateKvFilesSkipMap(ctx context.Context, downstreamIdset map[int64]struct{}, taskName string) (*LogFilesSkipMap, error) { + skipMap := NewLogFilesSkipMap() + t, err := checkpoint.WalkCheckpointFileForRestore(ctx, rc.storage, rc.cipher, taskName, func(groupKey checkpoint.LogRestoreKeyType, off checkpoint.LogRestoreValueMarshaled) { + for tableID, foffs := range off.Foffs { + // filter out the checkpoint data of dropped table + if _, exists := downstreamIdset[tableID]; exists { + for _, foff := range foffs { + skipMap.Insert(groupKey, off.Goff, foff) + } + } + } + }) + if err != nil { + return nil, errors.Trace(err) + } + summary.AdjustStartTimeToEarlierTime(t) + return skipMap, nil +} + +func (rc *LogClient) WrapLogFilesIterWithCheckpoint( + ctx context.Context, + logIter LogIter, + downstreamIdset map[int64]struct{}, + taskName string, + updateStats func(kvCount, size uint64), + onProgress func(), +) (LogIter, error) { + skipMap, err := rc.generateKvFilesSkipMap(ctx, downstreamIdset, taskName) + if err != nil { + return nil, errors.Trace(err) + } + return iter.FilterOut(logIter, func(d *LogDataFileInfo) bool { + if skipMap.NeedSkip(d.MetaDataGroupName, d.OffsetInMetaGroup, d.OffsetInMergedGroup) { + onProgress() + updateStats(uint64(d.NumberOfEntries), d.Length) + return true + } + return false + }), nil +} + +const ( + alterTableDropIndexSQL = "ALTER TABLE %n.%n DROP INDEX %n" + alterTableAddIndexFormat = "ALTER TABLE %%n.%%n ADD INDEX %%n(%s)" + alterTableAddUniqueIndexFormat = "ALTER TABLE %%n.%%n ADD UNIQUE KEY %%n(%s)" + alterTableAddPrimaryFormat = "ALTER TABLE %%n.%%n ADD PRIMARY KEY (%s) NONCLUSTERED" +) + +func (rc *LogClient) generateRepairIngestIndexSQLs( + ctx context.Context, + ingestRecorder *ingestrec.IngestRecorder, + taskName string, +) ([]checkpoint.CheckpointIngestIndexRepairSQL, bool, error) { + var sqls []checkpoint.CheckpointIngestIndexRepairSQL + if rc.useCheckpoint { + exists, err := checkpoint.ExistsCheckpointIngestIndexRepairSQLs(ctx, 
rc.storage, taskName) + if err != nil { + return sqls, false, errors.Trace(err) + } + if exists { + checkpointSQLs, err := checkpoint.LoadCheckpointIngestIndexRepairSQLs(ctx, rc.storage, taskName) + if err != nil { + return sqls, false, errors.Trace(err) + } + sqls = checkpointSQLs.SQLs + log.Info("load ingest index repair sqls from checkpoint", zap.String("category", "ingest"), zap.Reflect("sqls", sqls)) + return sqls, true, nil + } + } + + if err := ingestRecorder.UpdateIndexInfo(rc.dom.InfoSchema()); err != nil { + return sqls, false, errors.Trace(err) + } + if err := ingestRecorder.Iterate(func(_, indexID int64, info *ingestrec.IngestIndexInfo) error { + var ( + addSQL strings.Builder + addArgs []any = make([]any, 0, 5+len(info.ColumnArgs)) + ) + if info.IsPrimary { + addSQL.WriteString(fmt.Sprintf(alterTableAddPrimaryFormat, info.ColumnList)) + addArgs = append(addArgs, info.SchemaName.O, info.TableName.O) + addArgs = append(addArgs, info.ColumnArgs...) + } else if info.IndexInfo.Unique { + addSQL.WriteString(fmt.Sprintf(alterTableAddUniqueIndexFormat, info.ColumnList)) + addArgs = append(addArgs, info.SchemaName.O, info.TableName.O, info.IndexInfo.Name.O) + addArgs = append(addArgs, info.ColumnArgs...) + } else { + addSQL.WriteString(fmt.Sprintf(alterTableAddIndexFormat, info.ColumnList)) + addArgs = append(addArgs, info.SchemaName.O, info.TableName.O, info.IndexInfo.Name.O) + addArgs = append(addArgs, info.ColumnArgs...) + } + // USING BTREE/HASH/RTREE + indexTypeStr := info.IndexInfo.Tp.String() + if len(indexTypeStr) > 0 { + addSQL.WriteString(" USING ") + addSQL.WriteString(indexTypeStr) + } + + // COMMENT [...] + if len(info.IndexInfo.Comment) > 0 { + addSQL.WriteString(" COMMENT %?") + addArgs = append(addArgs, info.IndexInfo.Comment) + } + + if info.IndexInfo.Invisible { + addSQL.WriteString(" INVISIBLE") + } else { + addSQL.WriteString(" VISIBLE") + } + + sqls = append(sqls, checkpoint.CheckpointIngestIndexRepairSQL{ + IndexID: indexID, + SchemaName: info.SchemaName, + TableName: info.TableName, + IndexName: info.IndexInfo.Name.O, + AddSQL: addSQL.String(), + AddArgs: addArgs, + }) + + return nil + }); err != nil { + return sqls, false, errors.Trace(err) + } + + if rc.useCheckpoint && len(sqls) > 0 { + if err := checkpoint.SaveCheckpointIngestIndexRepairSQLs(ctx, rc.storage, &checkpoint.CheckpointIngestIndexRepairSQLs{ + SQLs: sqls, + }, taskName); err != nil { + return sqls, false, errors.Trace(err) + } + } + return sqls, false, nil +} + +// RepairIngestIndex drops the indexes from IngestRecorder and re-add them. +func (rc *LogClient) RepairIngestIndex(ctx context.Context, ingestRecorder *ingestrec.IngestRecorder, g glue.Glue, taskName string) error { + sqls, fromCheckpoint, err := rc.generateRepairIngestIndexSQLs(ctx, ingestRecorder, taskName) + if err != nil { + return errors.Trace(err) + } + + info := rc.dom.InfoSchema() + console := glue.GetConsole(g) +NEXTSQL: + for _, sql := range sqls { + progressTitle := fmt.Sprintf("repair ingest index %s for table %s.%s", sql.IndexName, sql.SchemaName, sql.TableName) + + tableInfo, err := info.TableByName(ctx, sql.SchemaName, sql.TableName) + if err != nil { + return errors.Trace(err) + } + oldIndexIDFound := false + if fromCheckpoint { + for _, idx := range tableInfo.Indices() { + indexInfo := idx.Meta() + if indexInfo.ID == sql.IndexID { + // the original index id is not dropped + oldIndexIDFound = true + break + } + // what if index's state is not public? 
+				if indexInfo.Name.O == sql.IndexName {
+					// found an index with the same name but a different index id,
+					// which means the repaired index has already been created
+					if _, err := fmt.Fprintf(console.Out(), "%s ... %s\n", progressTitle, color.HiGreenString("SKIPPED DUE TO CHECKPOINT MODE")); err != nil {
+						return errors.Trace(err)
+					}
+					continue NEXTSQL
+				}
+			}
+		}
+
+		if err := func(sql checkpoint.CheckpointIngestIndexRepairSQL) error {
+			w := console.StartProgressBar(progressTitle, glue.OnlyOneTask)
+			defer w.Close()
+
+			// TODO: When TiDB supports DROP and CREATE of the same-named index in one SQL,
+			// the checkpoint for the ingest recorder can be removed and the SQL can be used directly:
+			// ALTER TABLE db.tbl DROP INDEX `i_1`, ADD INDEX `i_1` ...
+			//
+			// This SQL is compatible with checkpoint: if one ingest index has been recreated by
+			// the SQL, the index's id would be another one. In the next retry execution, BR cannot
+			// find the ingest index's dropped id, so BR regards it as an index dropped by the
+			// restored metakv and then skips repairing it.
+
+			// drop the index only on the first execution, or when the old index id has not been dropped yet
+			if !fromCheckpoint || oldIndexIDFound {
+				if err := rc.se.ExecuteInternal(ctx, alterTableDropIndexSQL, sql.SchemaName.O, sql.TableName.O, sql.IndexName); err != nil {
+					return errors.Trace(err)
+				}
+			}
+			failpoint.Inject("failed-before-create-ingest-index", func(v failpoint.Value) {
+				if v != nil && v.(bool) {
+					failpoint.Return(errors.New("failed before create ingest index"))
+				}
+			})
+			// create the repaired index on the first execution, or when it was not found above
+			if err := rc.se.ExecuteInternal(ctx, sql.AddSQL, sql.AddArgs...); err != nil {
+				return errors.Trace(err)
+			}
+			w.Inc()
+			if err := w.Wait(ctx); err != nil {
+				return errors.Trace(err)
+			}
+			return nil
+		}(sql); err != nil {
+			return errors.Trace(err)
+		}
+	}
+
+	return nil
+}
+
+func (rc *LogClient) RecordDeleteRange(sql *stream.PreDelRangeQuery) {
+	rc.deleteRangeQueryCh <- sql
+}
+
+// use a channel to save the delete-range query so that it is thread-safe.
+func (rc *LogClient) RunGCRowsLoader(ctx context.Context) {
+	rc.deleteRangeQueryWaitGroup.Add(1)
+
+	go func() {
+		defer rc.deleteRangeQueryWaitGroup.Done()
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case query, ok := <-rc.deleteRangeQueryCh:
+				if !ok {
+					return
+				}
+				rc.deleteRangeQuery = append(rc.deleteRangeQuery, query)
+			}
+		}
+	}()
+}
+
+// InsertGCRows inserts the queries into table `gc_delete_range`
+func (rc *LogClient) InsertGCRows(ctx context.Context) error {
+	close(rc.deleteRangeQueryCh)
+	rc.deleteRangeQueryWaitGroup.Wait()
+	ts, err := restore.GetTSWithRetry(ctx, rc.pdClient)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	jobIDMap := make(map[int64]int64)
+	for _, query := range rc.deleteRangeQuery {
+		paramsList := make([]any, 0, len(query.ParamsList)*5)
+		for _, params := range query.ParamsList {
+			newJobID, exists := jobIDMap[params.JobID]
+			if !exists {
+				newJobID, err = rc.GenGlobalID(ctx)
+				if err != nil {
+					return errors.Trace(err)
+				}
+				jobIDMap[params.JobID] = newJobID
+			}
+			log.Info("insert into the delete range",
+				zap.Int64("jobID", newJobID),
+				zap.Int64("elemID", params.ElemID),
+				zap.String("startKey", params.StartKey),
+				zap.String("endKey", params.EndKey),
+				zap.Uint64("ts", ts))
+			// (job_id, elem_id, start_key, end_key, ts)
+			paramsList = append(paramsList, newJobID, params.ElemID, params.StartKey, params.EndKey, ts)
+		}
+		if len(paramsList) > 0 {
+			// trim the trailing ',' from query.Sql if present;
+			// this happens when the rewrite rule for the last table id does not exist
+			sql := strings.TrimSuffix(query.Sql, ",")
+			if err := rc.se.ExecuteInternal(ctx, sql, paramsList...); err != nil {
+				return errors.Trace(err)
+			}
+		}
+	}
+	return nil
+}
+
+// only for unit test
+func (rc *LogClient) GetGCRows() []*stream.PreDelRangeQuery {
+	close(rc.deleteRangeQueryCh)
+	rc.deleteRangeQueryWaitGroup.Wait()
+	return rc.deleteRangeQuery
+}
+
+// SaveIDMap saves the id mapping information.
+func (rc *LogClient) SaveIDMap(
+	ctx context.Context,
+	sr *stream.SchemasReplace,
+) error {
+	idMaps := sr.TidySchemaMaps()
+	clusterID := rc.GetClusterID(ctx)
+	metaFileName := metautil.PitrIDMapsFilename(clusterID, rc.restoreTS)
+	metaWriter := metautil.NewMetaWriter(rc.storage, metautil.MetaFileSize, false, metaFileName, nil)
+	metaWriter.Update(func(m *backuppb.BackupMeta) {
+		// save log startTS to backupmeta file
+		m.ClusterId = clusterID
+		m.DbMaps = idMaps
+	})
+
+	if err := metaWriter.FlushBackupMeta(ctx); err != nil {
+		return errors.Trace(err)
+	}
+	if rc.useCheckpoint {
+		var items map[int64]model.TiFlashReplicaInfo
+		if sr.TiflashRecorder != nil {
+			items = sr.TiflashRecorder.GetItems()
+		}
+		log.Info("save checkpoint task info with InLogRestoreAndIdMapPersist status")
+		if err := checkpoint.SaveCheckpointTaskInfoForLogRestore(ctx, rc.storage, &checkpoint.CheckpointTaskInfoForLogRestore{
+			Progress:     checkpoint.InLogRestoreAndIdMapPersist,
+			StartTS:      rc.startTS,
+			RestoreTS:    rc.restoreTS,
+			RewriteTS:    rc.currentTS,
+			TiFlashItems: items,
+		}, rc.GetClusterID(ctx)); err != nil {
+			return errors.Trace(err)
+		}
+	}
+	return nil
+}
+
+// called by failpoint, only used for test:
+// it prints the checksum result into the log, and
+// the auto-test script records it to compare with another
+// cluster's checksum.
+func (rc *LogClient) FailpointDoChecksumForLogRestore( + ctx context.Context, + kvClient kv.Client, + pdClient pd.Client, + idrules map[int64]int64, + rewriteRules map[int64]*restoreutils.RewriteRules, +) (finalErr error) { + startTS, err := restore.GetTSWithRetry(ctx, rc.pdClient) + if err != nil { + return errors.Trace(err) + } + // set gc safepoint for checksum + sp := utils.BRServiceSafePoint{ + BackupTS: startTS, + TTL: utils.DefaultBRGCSafePointTTL, + ID: utils.MakeSafePointID(), + } + cctx, gcSafePointKeeperCancel := context.WithCancel(ctx) + defer func() { + log.Info("start to remove gc-safepoint keeper") + // close the gc safe point keeper at first + gcSafePointKeeperCancel() + // set the ttl to 0 to remove the gc-safe-point + sp.TTL = 0 + if err := utils.UpdateServiceSafePoint(ctx, pdClient, sp); err != nil { + log.Warn("failed to update service safe point, backup may fail if gc triggered", + zap.Error(err), + ) + } + log.Info("finish removing gc-safepoint keeper") + }() + err = utils.StartServiceSafePointKeeper(cctx, pdClient, sp) + if err != nil { + return errors.Trace(err) + } + + eg, ectx := errgroup.WithContext(ctx) + pool := tidbutil.NewWorkerPool(4, "checksum for log restore") + infoSchema := rc.GetDomain().InfoSchema() + // downstream id -> upstream id + reidRules := make(map[int64]int64) + for upstreamID, downstreamID := range idrules { + reidRules[downstreamID] = upstreamID + } + for upstreamID, downstreamID := range idrules { + newTable, ok := infoSchema.TableByID(downstreamID) + if !ok { + // a dropped table + continue + } + rewriteRule, ok := rewriteRules[upstreamID] + if !ok { + continue + } + newTableInfo := newTable.Meta() + var definitions []model.PartitionDefinition + if newTableInfo.Partition != nil { + for _, def := range newTableInfo.Partition.Definitions { + upid, ok := reidRules[def.ID] + if !ok { + log.Panic("no rewrite rule for parition table id", zap.Int64("id", def.ID)) + } + definitions = append(definitions, model.PartitionDefinition{ + ID: upid, + }) + } + } + oldPartition := &model.PartitionInfo{ + Definitions: definitions, + } + oldTable := &metautil.Table{ + Info: &model.TableInfo{ + ID: upstreamID, + Indices: newTableInfo.Indices, + Partition: oldPartition, + }, + } + pool.ApplyOnErrorGroup(eg, func() error { + exe, err := checksum.NewExecutorBuilder(newTableInfo, startTS). + SetOldTable(oldTable). + SetConcurrency(4). + SetOldKeyspace(rewriteRule.OldKeyspace). + SetNewKeyspace(rewriteRule.NewKeyspace). + SetExplicitRequestSourceType(kvutil.ExplicitTypeBR). 
+ Build() + if err != nil { + return errors.Trace(err) + } + checksumResp, err := exe.Execute(ectx, kvClient, func() {}) + if err != nil { + return errors.Trace(err) + } + // print to log so that the test script can get final checksum + log.Info("failpoint checksum completed", + zap.String("table-name", newTableInfo.Name.O), + zap.Int64("upstream-id", oldTable.Info.ID), + zap.Uint64("checksum", checksumResp.Checksum), + zap.Uint64("total-kvs", checksumResp.TotalKvs), + zap.Uint64("total-bytes", checksumResp.TotalBytes), + ) + return nil + }) + } + + return eg.Wait() +} + +type LogFilesIterWithSplitHelper struct { + iter LogIter + helper *logsplit.LogSplitHelper + buffer []*LogDataFileInfo + next int +} + +const SplitFilesBufferSize = 4096 + +func NewLogFilesIterWithSplitHelper(iter LogIter, rules map[int64]*restoreutils.RewriteRules, client split.SplitClient, splitSize uint64, splitKeys int64) LogIter { + return &LogFilesIterWithSplitHelper{ + iter: iter, + helper: logsplit.NewLogSplitHelper(rules, client, splitSize, splitKeys), + buffer: nil, + next: 0, + } +} + +func (splitIter *LogFilesIterWithSplitHelper) TryNext(ctx context.Context) iter.IterResult[*LogDataFileInfo] { + if splitIter.next >= len(splitIter.buffer) { + splitIter.buffer = make([]*LogDataFileInfo, 0, SplitFilesBufferSize) + for r := splitIter.iter.TryNext(ctx); !r.Finished; r = splitIter.iter.TryNext(ctx) { + if r.Err != nil { + return r + } + f := r.Item + splitIter.helper.Merge(f.DataFileInfo) + splitIter.buffer = append(splitIter.buffer, f) + if len(splitIter.buffer) >= SplitFilesBufferSize { + break + } + } + splitIter.next = 0 + if len(splitIter.buffer) == 0 { + return iter.Done[*LogDataFileInfo]() + } + log.Info("start to split the regions") + startTime := time.Now() + if err := splitIter.helper.Split(ctx); err != nil { + return iter.Throw[*LogDataFileInfo](errors.Trace(err)) + } + log.Info("end to split the regions", zap.Duration("takes", time.Since(startTime))) + } + + res := iter.Emit(splitIter.buffer[splitIter.next]) + splitIter.next += 1 + return res +} diff --git a/br/pkg/restore/log_client/import_retry.go b/br/pkg/restore/log_client/import_retry.go index 93f454d6252e5..ab7d60c7e98b1 100644 --- a/br/pkg/restore/log_client/import_retry.go +++ b/br/pkg/restore/log_client/import_retry.go @@ -92,12 +92,12 @@ func (o *OverRegionsInRangeController) handleInRegionError(ctx context.Context, if strings.Contains(result.StoreError.GetMessage(), "memory is limited") { sleepDuration := 15 * time.Second - failpoint.Inject("hint-memory-is-limited", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("hint-memory-is-limited")); _err_ == nil { if val.(bool) { logutil.CL(ctx).Debug("failpoint hint-memory-is-limited injected.") sleepDuration = 100 * time.Microsecond } - }) + } time.Sleep(sleepDuration) return true } diff --git a/br/pkg/restore/log_client/import_retry.go__failpoint_stash__ b/br/pkg/restore/log_client/import_retry.go__failpoint_stash__ new file mode 100644 index 0000000000000..93f454d6252e5 --- /dev/null +++ b/br/pkg/restore/log_client/import_retry.go__failpoint_stash__ @@ -0,0 +1,284 @@ +// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. 
+ +package logclient + +import ( + "context" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/errorpb" + "github.com/pingcap/kvproto/pkg/import_sstpb" + "github.com/pingcap/kvproto/pkg/metapb" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/restore/split" + restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/tikv/client-go/v2/kv" + "go.uber.org/multierr" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type RegionFunc func(ctx context.Context, r *split.RegionInfo) RPCResult + +type OverRegionsInRangeController struct { + start []byte + end []byte + metaClient split.SplitClient + + errors error + rs *utils.RetryState +} + +// OverRegionsInRange creates a controller that cloud be used to scan regions in a range and +// apply a function over these regions. +// You can then call the `Run` method for applying some functions. +func OverRegionsInRange(start, end []byte, metaClient split.SplitClient, retryStatus *utils.RetryState) OverRegionsInRangeController { + // IMPORTANT: we record the start/end key with TimeStamp. + // but scanRegion will drop the TimeStamp and the end key is exclusive. + // if we do not use PrefixNextKey. we might scan fewer regions than we expected. + // and finally cause the data lost. + end = restoreutils.TruncateTS(end) + end = kv.PrefixNextKey(end) + + return OverRegionsInRangeController{ + start: start, + end: end, + metaClient: metaClient, + rs: retryStatus, + } +} + +func (o *OverRegionsInRangeController) onError(_ context.Context, result RPCResult, region *split.RegionInfo) { + o.errors = multierr.Append(o.errors, errors.Annotatef(&result, "execute over region %v failed", region.Region)) + // TODO: Maybe handle some of region errors like `epoch not match`? +} + +func (o *OverRegionsInRangeController) tryFindLeader(ctx context.Context, region *split.RegionInfo) (*metapb.Peer, error) { + var leader *metapb.Peer + failed := false + leaderRs := utils.InitialRetryState(4, 5*time.Second, 10*time.Second) + err := utils.WithRetry(ctx, func() error { + r, err := o.metaClient.GetRegionByID(ctx, region.Region.Id) + if err != nil { + return err + } + if !split.CheckRegionEpoch(r, region) { + failed = true + return nil + } + if r.Leader != nil { + leader = r.Leader + return nil + } + return errors.Annotatef(berrors.ErrPDLeaderNotFound, "there is no leader for region %d", region.Region.Id) + }, &leaderRs) + if failed { + return nil, errors.Annotatef(berrors.ErrKVEpochNotMatch, "the current epoch of %s is changed", region) + } + if err != nil { + return nil, err + } + return leader, nil +} + +// handleInRegionError handles the error happens internal in the region. Update the region info, and perform a suitable backoff. 
+func (o *OverRegionsInRangeController) handleInRegionError(ctx context.Context, result RPCResult, region *split.RegionInfo) (cont bool) { + if result.StoreError.GetServerIsBusy() != nil { + if strings.Contains(result.StoreError.GetMessage(), "memory is limited") { + sleepDuration := 15 * time.Second + + failpoint.Inject("hint-memory-is-limited", func(val failpoint.Value) { + if val.(bool) { + logutil.CL(ctx).Debug("failpoint hint-memory-is-limited injected.") + sleepDuration = 100 * time.Microsecond + } + }) + time.Sleep(sleepDuration) + return true + } + } + + if nl := result.StoreError.GetNotLeader(); nl != nil { + if nl.Leader != nil { + region.Leader = nl.Leader + // try the new leader immediately. + return true + } + // we retry manually, simply record the retry event. + time.Sleep(o.rs.ExponentialBackoff()) + // There may not be leader, waiting... + leader, err := o.tryFindLeader(ctx, region) + if err != nil { + // Leave the region info unchanged, let it retry then. + logutil.CL(ctx).Warn("failed to find leader", logutil.Region(region.Region), logutil.ShortError(err)) + return false + } + region.Leader = leader + return true + } + // For other errors, like `ServerIsBusy`, `RegionIsNotInitialized`, just trivially backoff. + time.Sleep(o.rs.ExponentialBackoff()) + return true +} + +func (o *OverRegionsInRangeController) prepareLogCtx(ctx context.Context) context.Context { + lctx := logutil.ContextWithField( + ctx, + logutil.Key("startKey", o.start), + logutil.Key("endKey", o.end), + ) + return lctx +} + +// Run executes the `regionFunc` over the regions in `o.start` and `o.end`. +// It would retry the errors according to the `rpcResponse`. +func (o *OverRegionsInRangeController) Run(ctx context.Context, f RegionFunc) error { + return o.runOverRegions(o.prepareLogCtx(ctx), f) +} + +func (o *OverRegionsInRangeController) runOverRegions(ctx context.Context, f RegionFunc) error { + if !o.rs.ShouldRetry() { + return o.errors + } + + // Scan regions covered by the file range + regionInfos, errScanRegion := split.PaginateScanRegion( + ctx, o.metaClient, o.start, o.end, split.ScanRegionPaginationLimit) + if errScanRegion != nil { + return errors.Trace(errScanRegion) + } + + for _, region := range regionInfos { + cont, err := o.runInRegion(ctx, f, region) + if err != nil { + return err + } + if !cont { + return nil + } + } + return nil +} + +// runInRegion executes the function in the region, and returns `cont = false` if no need for trying for next region. +func (o *OverRegionsInRangeController) runInRegion(ctx context.Context, f RegionFunc, region *split.RegionInfo) (cont bool, err error) { + if !o.rs.ShouldRetry() { + return false, o.errors + } + result := f(ctx, region) + + if !result.OK() { + o.onError(ctx, result, region) + switch result.StrategyForRetry() { + case StrategyGiveUp: + logutil.CL(ctx).Warn("unexpected error, should stop to retry", logutil.ShortError(&result), logutil.Region(region.Region)) + return false, o.errors + case StrategyFromThisRegion: + logutil.CL(ctx).Warn("retry for region", logutil.Region(region.Region), logutil.ShortError(&result)) + if !o.handleInRegionError(ctx, result, region) { + return false, o.runOverRegions(ctx, f) + } + return o.runInRegion(ctx, f, region) + case StrategyFromStart: + logutil.CL(ctx).Warn("retry for execution over regions", logutil.ShortError(&result)) + // TODO: make a backoffer considering more about the error info, + // instead of ingore the result and retry. 
+ time.Sleep(o.rs.ExponentialBackoff()) + return false, o.runOverRegions(ctx, f) + } + } + return true, nil +} + +// RPCResult is the result after executing some RPCs to TiKV. +type RPCResult struct { + Err error + + ImportError string + StoreError *errorpb.Error +} + +func RPCResultFromPBError(err *import_sstpb.Error) RPCResult { + return RPCResult{ + ImportError: err.GetMessage(), + StoreError: err.GetStoreError(), + } +} + +func RPCResultFromError(err error) RPCResult { + return RPCResult{ + Err: err, + } +} + +func RPCResultOK() RPCResult { + return RPCResult{} +} + +type RetryStrategy int + +const ( + StrategyGiveUp RetryStrategy = iota + StrategyFromThisRegion + StrategyFromStart +) + +func (r *RPCResult) StrategyForRetry() RetryStrategy { + if r.Err != nil { + return r.StrategyForRetryGoError() + } + return r.StrategyForRetryStoreError() +} + +func (r *RPCResult) StrategyForRetryStoreError() RetryStrategy { + if r.StoreError == nil && r.ImportError == "" { + return StrategyGiveUp + } + + if r.StoreError.GetServerIsBusy() != nil || + r.StoreError.GetRegionNotInitialized() != nil || + r.StoreError.GetNotLeader() != nil || + r.StoreError.GetServerIsBusy() != nil { + return StrategyFromThisRegion + } + + return StrategyFromStart +} + +func (r *RPCResult) StrategyForRetryGoError() RetryStrategy { + if r.Err == nil { + return StrategyGiveUp + } + + // we should unwrap the error or we cannot get the write gRPC status. + if gRPCErr, ok := status.FromError(errors.Cause(r.Err)); ok { + switch gRPCErr.Code() { + case codes.Unavailable, codes.Aborted, codes.ResourceExhausted, codes.DeadlineExceeded: + return StrategyFromThisRegion + } + } + + return StrategyGiveUp +} + +func (r *RPCResult) Error() string { + if r.Err != nil { + return r.Err.Error() + } + if r.StoreError != nil { + return r.StoreError.GetMessage() + } + if r.ImportError != "" { + return r.ImportError + } + return "BUG(There is no error but reported as error)" +} + +func (r *RPCResult) OK() bool { + return r.Err == nil && r.ImportError == "" && r.StoreError == nil +} diff --git a/br/pkg/restore/misc.go b/br/pkg/restore/misc.go index 62d7fbc32fdb4..469b4ea7b9cca 100644 --- a/br/pkg/restore/misc.go +++ b/br/pkg/restore/misc.go @@ -137,11 +137,11 @@ func GetTSWithRetry(ctx context.Context, pdClient pd.Client) (uint64, error) { err := utils.WithRetry(ctx, func() error { startTS, getTSErr = GetTS(ctx, pdClient) - failpoint.Inject("get-ts-error", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("get-ts-error")); _err_ == nil { if val.(bool) && retry < 3 { getTSErr = errors.Errorf("rpc error: code = Unknown desc = [PD:tso:ErrGenerateTimestamp]generate timestamp failed, requested pd is not leader of cluster") } - }) + } retry++ if getTSErr != nil { diff --git a/br/pkg/restore/misc.go__failpoint_stash__ b/br/pkg/restore/misc.go__failpoint_stash__ new file mode 100644 index 0000000000000..62d7fbc32fdb4 --- /dev/null +++ b/br/pkg/restore/misc.go__failpoint_stash__ @@ -0,0 +1,157 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package restore + +import ( + "context" + "fmt" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/parser/model" + tidbutil "github.com/pingcap/tidb/pkg/util" + "github.com/tikv/client-go/v2/oracle" + pd "github.com/tikv/pd/client" + "go.uber.org/zap" +) + +// deprecated parameter +type Granularity string + +const ( + FineGrained Granularity = "fine-grained" + CoarseGrained Granularity = "coarse-grained" +) + +type UniqueTableName struct { + DB string + Table string +} + +func TransferBoolToValue(enable bool) string { + if enable { + return "ON" + } + return "OFF" +} + +// GetTableSchema returns the schema of a table from TiDB. +func GetTableSchema( + dom *domain.Domain, + dbName model.CIStr, + tableName model.CIStr, +) (*model.TableInfo, error) { + info := dom.InfoSchema() + table, err := info.TableByName(context.Background(), dbName, tableName) + if err != nil { + return nil, errors.Trace(err) + } + return table.Meta(), nil +} + +const maxUserTablesNum = 10 + +// AssertUserDBsEmpty check whether user dbs exist in the cluster +func AssertUserDBsEmpty(dom *domain.Domain) error { + databases := dom.InfoSchema().AllSchemas() + m := meta.NewSnapshotMeta(dom.Store().GetSnapshot(kv.MaxVersion)) + userTables := make([]string, 0, maxUserTablesNum+1) + appendTables := func(dbName, tableName string) bool { + if len(userTables) >= maxUserTablesNum { + userTables = append(userTables, "...") + return true + } + userTables = append(userTables, fmt.Sprintf("%s.%s", dbName, tableName)) + return false + } +LISTDBS: + for _, db := range databases { + dbName := db.Name.L + if tidbutil.IsMemOrSysDB(dbName) { + continue + } + tables, err := m.ListSimpleTables(db.ID) + if err != nil { + return errors.Annotatef(err, "failed to iterator tables of database[id=%d]", db.ID) + } + if len(tables) == 0 { + // tidb create test db on fresh cluster + // if it's empty we don't take it as user db + if dbName != "test" { + if appendTables(db.Name.O, "") { + break LISTDBS + } + } + continue + } + for _, table := range tables { + if appendTables(db.Name.O, table.Name.O) { + break LISTDBS + } + } + } + if len(userTables) > 0 { + return errors.Annotate(berrors.ErrRestoreNotFreshCluster, + "user db/tables: "+strings.Join(userTables, ", ")) + } + return nil +} + +// GetTS gets a new timestamp from PD. +func GetTS(ctx context.Context, pdClient pd.Client) (uint64, error) { + p, l, err := pdClient.GetTS(ctx) + if err != nil { + return 0, errors.Trace(err) + } + restoreTS := oracle.ComposeTS(p, l) + return restoreTS, nil +} + +// GetTSWithRetry gets a new timestamp with retry from PD. 
+func GetTSWithRetry(ctx context.Context, pdClient pd.Client) (uint64, error) { + var ( + startTS uint64 + getTSErr error + retry uint + ) + + err := utils.WithRetry(ctx, func() error { + startTS, getTSErr = GetTS(ctx, pdClient) + failpoint.Inject("get-ts-error", func(val failpoint.Value) { + if val.(bool) && retry < 3 { + getTSErr = errors.Errorf("rpc error: code = Unknown desc = [PD:tso:ErrGenerateTimestamp]generate timestamp failed, requested pd is not leader of cluster") + } + }) + + retry++ + if getTSErr != nil { + log.Warn("failed to get TS, retry it", zap.Uint("retry time", retry), logutil.ShortError(getTSErr)) + } + return getTSErr + }, utils.NewPDReqBackoffer()) + + if err != nil { + log.Error("failed to get TS", zap.Error(err)) + } + return startTS, errors.Trace(err) +} diff --git a/br/pkg/restore/snap_client/binding__failpoint_binding__.go b/br/pkg/restore/snap_client/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..777c78dc6633b --- /dev/null +++ b/br/pkg/restore/snap_client/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package snapclient + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/br/pkg/restore/snap_client/client.go b/br/pkg/restore/snap_client/client.go index 061fad6388016..d1fb0f18c227f 100644 --- a/br/pkg/restore/snap_client/client.go +++ b/br/pkg/restore/snap_client/client.go @@ -726,11 +726,11 @@ func (rc *SnapClient) createTablesInWorkerPool(ctx context.Context, tables []*me workers.ApplyWithIDInErrorGroup(eg, func(id uint64) error { db := rc.dbPool[id%uint64(len(rc.dbPool))] cts, err := rc.createTables(ectx, db, tableSlice, newTS) // ddl job for [lastSent:i) - failpoint.Inject("restore-createtables-error", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("restore-createtables-error")); _err_ == nil { if val.(bool) { err = errors.New("sample error without extra message") } - }) + } if err != nil { log.Error("create tables fail", zap.Error(err)) return err @@ -825,9 +825,9 @@ func (rc *SnapClient) IsFullClusterRestore() bool { // IsFull returns whether this backup is full. func (rc *SnapClient) IsFull() bool { - failpoint.Inject("mock-incr-backup-data", func() { - failpoint.Return(false) - }) + if _, _err_ := failpoint.Eval(_curpkg_("mock-incr-backup-data")); _err_ == nil { + return false + } return !rc.IsIncremental() } diff --git a/br/pkg/restore/snap_client/client.go__failpoint_stash__ b/br/pkg/restore/snap_client/client.go__failpoint_stash__ new file mode 100644 index 0000000000000..061fad6388016 --- /dev/null +++ b/br/pkg/restore/snap_client/client.go__failpoint_stash__ @@ -0,0 +1,1200 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package snapclient + +import ( + "bytes" + "cmp" + "context" + "crypto/tls" + "encoding/json" + "fmt" + "slices" + "strings" + "sync" + "time" + + "github.com/opentracing/opentracing-go" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/checkpoint" + "github.com/pingcap/tidb/br/pkg/checksum" + "github.com/pingcap/tidb/br/pkg/conn" + "github.com/pingcap/tidb/br/pkg/conn/util" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/glue" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/metautil" + "github.com/pingcap/tidb/br/pkg/pdutil" + "github.com/pingcap/tidb/br/pkg/restore" + importclient "github.com/pingcap/tidb/br/pkg/restore/internal/import_client" + tidallocdb "github.com/pingcap/tidb/br/pkg/restore/internal/prealloc_db" + tidalloc "github.com/pingcap/tidb/br/pkg/restore/internal/prealloc_table_id" + internalutils "github.com/pingcap/tidb/br/pkg/restore/internal/utils" + "github.com/pingcap/tidb/br/pkg/restore/split" + restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" + "github.com/pingcap/tidb/br/pkg/rtree" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/br/pkg/summary" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/br/pkg/version" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/parser/model" + tidbutil "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/redact" + kvutil "github.com/tikv/client-go/v2/util" + pd "github.com/tikv/pd/client" + pdhttp "github.com/tikv/pd/client/http" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc/keepalive" +) + +const ( + strictPlacementPolicyMode = "STRICT" + ignorePlacementPolicyMode = "IGNORE" + + defaultDDLConcurrency = 16 + maxSplitKeysOnce = 10240 +) + +const minBatchDdlSize = 1 + +type SnapClient struct { + // Tool clients used by SnapClient + fileImporter *SnapFileImporter + pdClient pd.Client + pdHTTPClient pdhttp.Client + + // User configurable parameters + cipher *backuppb.CipherInfo + concurrencyPerStore uint + keepaliveConf keepalive.ClientParameters + rateLimit uint64 + tlsConf *tls.Config + + switchCh chan struct{} + + storeCount int + supportPolicy bool + workerPool *tidbutil.WorkerPool + + noSchema bool + hasSpeedLimited bool + + databases map[string]*metautil.Database + ddlJobs []*model.Job + + // store tables need to rebase info like auto id and random id and so on after create table + rebasedTablesMap map[restore.UniqueTableName]bool + + backupMeta *backuppb.BackupMeta + + // TODO Remove this field or replace it with a []*DB, + // since https://github.com/pingcap/br/pull/377 needs more DBs to speed up DDL execution. + // And for now, we must inject a pool of DBs to `Client.GoCreateTables`, otherwise there would be a race condition. + // This is dirty: why we need DBs from different sources? + // By replace it with a []*DB, we can remove the dirty parameter of `Client.GoCreateTable`, + // along with them in some private functions. + // Before you do it, you can firstly read discussions at + // https://github.com/pingcap/br/pull/377#discussion_r446594501, + // this probably isn't as easy as it seems like (however, not hard, too :D) + db *tidallocdb.DB + + // use db pool to speed up restoration in BR binary mode. 
+ dbPool []*tidallocdb.DB + + dom *domain.Domain + + // correspond to --tidb-placement-mode config. + // STRICT(default) means policy related SQL can be executed in tidb. + // IGNORE means policy related SQL will be ignored. + policyMode string + + // policy name -> policy info + policyMap *sync.Map + + batchDdlSize uint + + // if fullClusterRestore = true: + // - if there's system tables in the backup(backup data since br 5.1.0), the cluster should be a fresh cluster + // without user database or table. and system tables about privileges is restored together with user data. + // - if there no system tables in the backup(backup data from br < 5.1.0), restore all user data just like + // previous version did. + // if fullClusterRestore = false, restore all user data just like previous version did. + // fullClusterRestore = true when there is no explicit filter setting, and it's full restore or point command + // with a full backup data. + // todo: maybe change to an enum + // this feature is controlled by flag with-sys-table + fullClusterRestore bool + + // see RestoreCommonConfig.WithSysTable + withSysTable bool + + // the rewrite mode of the downloaded SST files in TiKV. + rewriteMode RewriteMode + + // checkpoint information for snapshot restore + checkpointRunner *checkpoint.CheckpointRunner[checkpoint.RestoreKeyType, checkpoint.RestoreValueType] + checkpointChecksum map[int64]*checkpoint.ChecksumItem +} + +// NewRestoreClient returns a new RestoreClient. +func NewRestoreClient( + pdClient pd.Client, + pdHTTPCli pdhttp.Client, + tlsConf *tls.Config, + keepaliveConf keepalive.ClientParameters, +) *SnapClient { + return &SnapClient{ + pdClient: pdClient, + pdHTTPClient: pdHTTPCli, + tlsConf: tlsConf, + keepaliveConf: keepaliveConf, + switchCh: make(chan struct{}), + } +} + +func (rc *SnapClient) closeConn() { + // rc.db can be nil in raw kv mode. + if rc.db != nil { + rc.db.Close() + } + for _, db := range rc.dbPool { + db.Close() + } +} + +// Close a client. +func (rc *SnapClient) Close() { + // close the connection, and it must be succeed when in SQL mode. + rc.closeConn() + + if err := rc.fileImporter.Close(); err != nil { + log.Warn("failed to close file improter") + } + + log.Info("Restore client closed") +} + +func (rc *SnapClient) SetRateLimit(rateLimit uint64) { + rc.rateLimit = rateLimit +} + +func (rc *SnapClient) SetCrypter(crypter *backuppb.CipherInfo) { + rc.cipher = crypter +} + +// GetClusterID gets the cluster id from down-stream cluster. +func (rc *SnapClient) GetClusterID(ctx context.Context) uint64 { + return rc.pdClient.GetClusterID(ctx) +} + +func (rc *SnapClient) GetDomain() *domain.Domain { + return rc.dom +} + +// GetTLSConfig returns the tls config. +func (rc *SnapClient) GetTLSConfig() *tls.Config { + return rc.tlsConf +} + +// GetSupportPolicy tells whether target tidb support placement policy. +func (rc *SnapClient) GetSupportPolicy() bool { + return rc.supportPolicy +} + +func (rc *SnapClient) updateConcurrency() { + // we believe 32 is large enough for download worker pool. + // it won't reach the limit if sst files distribute evenly. + // when restore memory usage is still too high, we should reduce concurrencyPerStore + // to sarifice some speed to reduce memory usage. + count := uint(rc.storeCount) * rc.concurrencyPerStore * 32 + log.Info("download coarse worker pool", zap.Uint("size", count)) + rc.workerPool = tidbutil.NewWorkerPool(count, "file") +} + +// SetConcurrencyPerStore sets the concurrency of download files for each store. 
+func (rc *SnapClient) SetConcurrencyPerStore(c uint) { + log.Info("per-store download worker pool", zap.Uint("size", c)) + rc.concurrencyPerStore = c +} + +func (rc *SnapClient) SetBatchDdlSize(batchDdlsize uint) { + rc.batchDdlSize = batchDdlsize +} + +func (rc *SnapClient) GetBatchDdlSize() uint { + return rc.batchDdlSize +} + +func (rc *SnapClient) SetWithSysTable(withSysTable bool) { + rc.withSysTable = withSysTable +} + +// TODO: remove this check and return RewriteModeKeyspace +func (rc *SnapClient) SetRewriteMode(ctx context.Context) { + if err := version.CheckClusterVersion(ctx, rc.pdClient, version.CheckVersionForKeyspaceBR); err != nil { + log.Warn("Keyspace BR is not supported in this cluster, fallback to legacy restore", zap.Error(err)) + rc.rewriteMode = RewriteModeLegacy + } else { + rc.rewriteMode = RewriteModeKeyspace + } +} + +func (rc *SnapClient) GetRewriteMode() RewriteMode { + return rc.rewriteMode +} + +// SetPlacementPolicyMode to policy mode. +func (rc *SnapClient) SetPlacementPolicyMode(withPlacementPolicy string) { + switch strings.ToUpper(withPlacementPolicy) { + case strictPlacementPolicyMode: + rc.policyMode = strictPlacementPolicyMode + case ignorePlacementPolicyMode: + rc.policyMode = ignorePlacementPolicyMode + default: + rc.policyMode = strictPlacementPolicyMode + } + log.Info("set placement policy mode", zap.String("mode", rc.policyMode)) +} + +// AllocTableIDs would pre-allocate the table's origin ID if exists, so that the TiKV doesn't need to rewrite the key in +// the download stage. +func (rc *SnapClient) AllocTableIDs(ctx context.Context, tables []*metautil.Table) error { + preallocedTableIDs := tidalloc.New(tables) + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnBR) + err := kv.RunInNewTxn(ctx, rc.GetDomain().Store(), true, func(_ context.Context, txn kv.Transaction) error { + return preallocedTableIDs.Alloc(meta.NewMeta(txn)) + }) + if err != nil { + return err + } + + log.Info("registering the table IDs", zap.Stringer("ids", preallocedTableIDs)) + for i := range rc.dbPool { + rc.dbPool[i].RegisterPreallocatedIDs(preallocedTableIDs) + } + if rc.db != nil { + rc.db.RegisterPreallocatedIDs(preallocedTableIDs) + } + return nil +} + +// InitCheckpoint initialize the checkpoint status for the cluster. If the cluster is +// restored for the first time, it will initialize the checkpoint metadata. Otherwrise, +// it will load checkpoint metadata and checkpoint ranges/checksum from the external +// storage. +func (rc *SnapClient) InitCheckpoint( + ctx context.Context, + s storage.ExternalStorage, + taskName string, + config *pdutil.ClusterConfig, + checkpointFirstRun bool, +) (map[int64]map[string]struct{}, *pdutil.ClusterConfig, error) { + var ( + // checkpoint sets distinguished by range key + checkpointSetWithTableID = make(map[int64]map[string]struct{}) + + checkpointClusterConfig *pdutil.ClusterConfig + + err error + ) + + if !checkpointFirstRun { + // load the checkpoint since this is not the first time to restore + meta, err := checkpoint.LoadCheckpointMetadataForRestore(ctx, s, taskName) + if err != nil { + return checkpointSetWithTableID, nil, errors.Trace(err) + } + + // The schedulers config is nil, so the restore-schedulers operation is just nil. + // Then the undo function would use the result undo of `remove schedulers` operation, + // instead of that in checkpoint meta. 
+ if meta.SchedulersConfig != nil { + checkpointClusterConfig = meta.SchedulersConfig + } + + // t1 is the latest time the checkpoint ranges persisted to the external storage. + t1, err := checkpoint.WalkCheckpointFileForRestore(ctx, s, rc.cipher, taskName, func(tableID int64, rangeKey checkpoint.RestoreValueType) { + checkpointSet, exists := checkpointSetWithTableID[tableID] + if !exists { + checkpointSet = make(map[string]struct{}) + checkpointSetWithTableID[tableID] = checkpointSet + } + checkpointSet[rangeKey.RangeKey] = struct{}{} + }) + if err != nil { + return checkpointSetWithTableID, nil, errors.Trace(err) + } + // t2 is the latest time the checkpoint checksum persisted to the external storage. + checkpointChecksum, t2, err := checkpoint.LoadCheckpointChecksumForRestore(ctx, s, taskName) + if err != nil { + return checkpointSetWithTableID, nil, errors.Trace(err) + } + rc.checkpointChecksum = checkpointChecksum + // use the later time to adjust the summary elapsed time. + if t1 > t2 { + summary.AdjustStartTimeToEarlierTime(t1) + } else { + summary.AdjustStartTimeToEarlierTime(t2) + } + } else { + // initialize the checkpoint metadata since it is the first time to restore. + meta := &checkpoint.CheckpointMetadataForRestore{} + // a nil config means undo function + if config != nil { + meta.SchedulersConfig = &pdutil.ClusterConfig{Schedulers: config.Schedulers, ScheduleCfg: config.ScheduleCfg} + } + if err = checkpoint.SaveCheckpointMetadataForRestore(ctx, s, meta, taskName); err != nil { + return checkpointSetWithTableID, nil, errors.Trace(err) + } + } + + rc.checkpointRunner, err = checkpoint.StartCheckpointRunnerForRestore(ctx, s, rc.cipher, taskName) + return checkpointSetWithTableID, checkpointClusterConfig, errors.Trace(err) +} + +func (rc *SnapClient) WaitForFinishCheckpoint(ctx context.Context, flush bool) { + if rc.checkpointRunner != nil { + rc.checkpointRunner.WaitForFinish(ctx, flush) + } +} + +// makeDBPool makes a session pool with specficated size by sessionFactory. +func makeDBPool(size uint, dbFactory func() (*tidallocdb.DB, error)) ([]*tidallocdb.DB, error) { + dbPool := make([]*tidallocdb.DB, 0, size) + for i := uint(0); i < size; i++ { + db, e := dbFactory() + if e != nil { + return dbPool, e + } + if db != nil { + dbPool = append(dbPool, db) + } + } + return dbPool, nil +} + +// Init create db connection and domain for storage. +func (rc *SnapClient) Init(g glue.Glue, store kv.Storage) error { + // setDB must happen after set PolicyMode. + // we will use policyMode to set session variables. + var err error + rc.db, rc.supportPolicy, err = tidallocdb.NewDB(g, store, rc.policyMode) + if err != nil { + return errors.Trace(err) + } + rc.dom, err = g.GetDomain(store) + if err != nil { + return errors.Trace(err) + } + + // init backupMeta only for passing unit test + if rc.backupMeta == nil { + rc.backupMeta = new(backuppb.BackupMeta) + } + + // There are different ways to create session between in binary and in SQL. + // + // Maybe allow user modify the DDL concurrency isn't necessary, + // because executing DDL is really I/O bound (or, algorithm bound?), + // and we cost most of time at waiting DDL jobs be enqueued. + // So these jobs won't be faster or slower when machine become faster or slower, + // hence make it a fixed value would be fine. 
+ rc.dbPool, err = makeDBPool(defaultDDLConcurrency, func() (*tidallocdb.DB, error) { + db, _, err := tidallocdb.NewDB(g, store, rc.policyMode) + return db, err + }) + if err != nil { + log.Warn("create session pool failed, we will send DDLs only by created sessions", + zap.Error(err), + zap.Int("sessionCount", len(rc.dbPool)), + ) + } + return errors.Trace(err) +} + +func (rc *SnapClient) initClients(ctx context.Context, backend *backuppb.StorageBackend, isRawKvMode bool, isTxnKvMode bool) error { + stores, err := conn.GetAllTiKVStoresWithRetry(ctx, rc.pdClient, util.SkipTiFlash) + if err != nil { + return errors.Annotate(err, "failed to get stores") + } + rc.storeCount = len(stores) + rc.updateConcurrency() + + var splitClientOpts []split.ClientOptionalParameter + if isRawKvMode { + splitClientOpts = append(splitClientOpts, split.WithRawKV()) + } + metaClient := split.NewClient(rc.pdClient, rc.pdHTTPClient, rc.tlsConf, maxSplitKeysOnce, rc.storeCount+1, splitClientOpts...) + importCli := importclient.NewImportClient(metaClient, rc.tlsConf, rc.keepaliveConf) + rc.fileImporter, err = NewSnapFileImporter(ctx, metaClient, importCli, backend, isRawKvMode, isTxnKvMode, stores, rc.rewriteMode, rc.concurrencyPerStore) + return errors.Trace(err) +} + +func (rc *SnapClient) needLoadSchemas(backupMeta *backuppb.BackupMeta) bool { + return !(backupMeta.IsRawKv || backupMeta.IsTxnKv) +} + +// InitBackupMeta loads schemas from BackupMeta to initialize RestoreClient. +func (rc *SnapClient) InitBackupMeta( + c context.Context, + backupMeta *backuppb.BackupMeta, + backend *backuppb.StorageBackend, + reader *metautil.MetaReader, + loadStats bool) error { + if rc.needLoadSchemas(backupMeta) { + databases, err := metautil.LoadBackupTables(c, reader, loadStats) + if err != nil { + return errors.Trace(err) + } + rc.databases = databases + + var ddlJobs []*model.Job + // ddls is the bytes of json.Marshal + ddls, err := reader.ReadDDLs(c) + if err != nil { + return errors.Trace(err) + } + if len(ddls) != 0 { + err = json.Unmarshal(ddls, &ddlJobs) + if err != nil { + return errors.Trace(err) + } + } + rc.ddlJobs = ddlJobs + } + rc.backupMeta = backupMeta + log.Info("load backupmeta", zap.Int("databases", len(rc.databases)), zap.Int("jobs", len(rc.ddlJobs))) + + return rc.initClients(c, backend, backupMeta.IsRawKv, backupMeta.IsTxnKv) +} + +// IsRawKvMode checks whether the backup data is in raw kv format, in which case transactional recover is forbidden. +func (rc *SnapClient) IsRawKvMode() bool { + return rc.backupMeta.IsRawKv +} + +// GetFilesInRawRange gets all files that are in the given range or intersects with the given range. +func (rc *SnapClient) GetFilesInRawRange(startKey []byte, endKey []byte, cf string) ([]*backuppb.File, error) { + if !rc.IsRawKvMode() { + return nil, errors.Annotate(berrors.ErrRestoreModeMismatch, "the backup data is not in raw kv mode") + } + + for _, rawRange := range rc.backupMeta.RawRanges { + // First check whether the given range is backup-ed. If not, we cannot perform the restore. + if rawRange.Cf != cf { + continue + } + + if (len(rawRange.EndKey) > 0 && bytes.Compare(startKey, rawRange.EndKey) >= 0) || + (len(endKey) > 0 && bytes.Compare(rawRange.StartKey, endKey) >= 0) { + // The restoring range is totally out of the current range. Skip it. + continue + } + + if bytes.Compare(startKey, rawRange.StartKey) < 0 || + utils.CompareEndKey(endKey, rawRange.EndKey) > 0 { + // Only partial of the restoring range is in the current backup-ed range. 
So the given range can't be fully + // restored. + return nil, errors.Annotatef(berrors.ErrRestoreRangeMismatch, + "the given range to restore [%s, %s) is not fully covered by the range that was backed up [%s, %s)", + redact.Key(startKey), redact.Key(endKey), redact.Key(rawRange.StartKey), redact.Key(rawRange.EndKey), + ) + } + + // We have found the range that contains the given range. Find all necessary files. + files := make([]*backuppb.File, 0) + + for _, file := range rc.backupMeta.Files { + if file.Cf != cf { + continue + } + + if len(file.EndKey) > 0 && bytes.Compare(file.EndKey, startKey) < 0 { + // The file is before the range to be restored. + continue + } + if len(endKey) > 0 && bytes.Compare(endKey, file.StartKey) <= 0 { + // The file is after the range to be restored. + // The specified endKey is exclusive, so when it equals to a file's startKey, the file is still skipped. + continue + } + + files = append(files, file) + } + + // There should be at most one backed up range that covers the restoring range. + return files, nil + } + + return nil, errors.Annotate(berrors.ErrRestoreRangeMismatch, "no backup data in the range") +} + +// ResetTS resets the timestamp of PD to a bigger value. +func (rc *SnapClient) ResetTS(ctx context.Context, pdCtrl *pdutil.PdController) error { + restoreTS := rc.backupMeta.GetEndVersion() + log.Info("reset pd timestamp", zap.Uint64("ts", restoreTS)) + return utils.WithRetry(ctx, func() error { + return pdCtrl.ResetTS(ctx, restoreTS) + }, utils.NewPDReqBackoffer()) +} + +// GetDatabases returns all databases. +func (rc *SnapClient) GetDatabases() []*metautil.Database { + dbs := make([]*metautil.Database, 0, len(rc.databases)) + for _, db := range rc.databases { + dbs = append(dbs, db) + } + return dbs +} + +// HasBackedUpSysDB whether we have backed up system tables +// br backs system tables up since 5.1.0 +func (rc *SnapClient) HasBackedUpSysDB() bool { + sysDBs := []string{"mysql", "sys"} + for _, db := range sysDBs { + temporaryDB := utils.TemporaryDBName(db) + _, backedUp := rc.databases[temporaryDB.O] + if backedUp { + return true + } + } + return false +} + +// GetPlacementPolicies returns policies. +func (rc *SnapClient) GetPlacementPolicies() (*sync.Map, error) { + policies := &sync.Map{} + for _, p := range rc.backupMeta.Policies { + policyInfo := &model.PolicyInfo{} + err := json.Unmarshal(p.Info, policyInfo) + if err != nil { + return nil, errors.Trace(err) + } + policies.Store(policyInfo.Name.L, policyInfo) + } + return policies, nil +} + +// GetDDLJobs returns ddl jobs. +func (rc *SnapClient) GetDDLJobs() []*model.Job { + return rc.ddlJobs +} + +// SetPolicyMap set policyMap. +func (rc *SnapClient) SetPolicyMap(p *sync.Map) { + rc.policyMap = p +} + +// CreatePolicies creates all policies in full restore. +func (rc *SnapClient) CreatePolicies(ctx context.Context, policyMap *sync.Map) error { + var err error + policyMap.Range(func(key, value any) bool { + e := rc.db.CreatePlacementPolicy(ctx, value.(*model.PolicyInfo)) + if e != nil { + err = e + return false + } + return true + }) + return err +} + +// CreateDatabases creates databases. If the client has the db pool, it would create it. 
+func (rc *SnapClient) CreateDatabases(ctx context.Context, dbs []*metautil.Database) error { + if rc.IsSkipCreateSQL() { + log.Info("skip create database") + return nil + } + + if len(rc.dbPool) == 0 { + log.Info("create databases sequentially") + for _, db := range dbs { + err := rc.db.CreateDatabase(ctx, db.Info, rc.supportPolicy, rc.policyMap) + if err != nil { + return errors.Trace(err) + } + } + return nil + } + + log.Info("create databases in db pool", zap.Int("pool size", len(rc.dbPool))) + eg, ectx := errgroup.WithContext(ctx) + workers := tidbutil.NewWorkerPool(uint(len(rc.dbPool)), "DB DDL workers") + for _, db_ := range dbs { + db := db_ + workers.ApplyWithIDInErrorGroup(eg, func(id uint64) error { + conn := rc.dbPool[id%uint64(len(rc.dbPool))] + return conn.CreateDatabase(ectx, db.Info, rc.supportPolicy, rc.policyMap) + }) + } + return eg.Wait() +} + +// generateRebasedTables generate a map[UniqueTableName]bool to represent tables that haven't updated table info. +// there are two situations: +// 1. tables that already exists in the restored cluster. +// 2. tables that are created by executing ddl jobs. +// so, only tables in incremental restoration will be added to the map +func (rc *SnapClient) generateRebasedTables(tables []*metautil.Table) { + if !rc.IsIncremental() { + // in full restoration, all tables are created by Session.CreateTable, and all tables' info is updated. + rc.rebasedTablesMap = make(map[restore.UniqueTableName]bool) + return + } + + rc.rebasedTablesMap = make(map[restore.UniqueTableName]bool, len(tables)) + for _, table := range tables { + rc.rebasedTablesMap[restore.UniqueTableName{DB: table.DB.Name.String(), Table: table.Info.Name.String()}] = true + } +} + +// getRebasedTables returns tables that may need to be rebase auto increment id or auto random id +func (rc *SnapClient) getRebasedTables() map[restore.UniqueTableName]bool { + return rc.rebasedTablesMap +} + +func (rc *SnapClient) createTables( + ctx context.Context, + db *tidallocdb.DB, + tables []*metautil.Table, + newTS uint64, +) ([]CreatedTable, error) { + log.Info("client to create tables") + if rc.IsSkipCreateSQL() { + log.Info("skip create table and alter autoIncID") + } else { + err := db.CreateTables(ctx, tables, rc.getRebasedTables(), rc.supportPolicy, rc.policyMap) + if err != nil { + return nil, errors.Trace(err) + } + } + cts := make([]CreatedTable, 0, len(tables)) + for _, table := range tables { + newTableInfo, err := restore.GetTableSchema(rc.dom, table.DB.Name, table.Info.Name) + if err != nil { + return nil, errors.Trace(err) + } + if newTableInfo.IsCommonHandle != table.Info.IsCommonHandle { + return nil, errors.Annotatef(berrors.ErrRestoreModeMismatch, + "Clustered index option mismatch. 
Restored cluster's @@tidb_enable_clustered_index should be %v (backup table = %v, created table = %v).", + restore.TransferBoolToValue(table.Info.IsCommonHandle), + table.Info.IsCommonHandle, + newTableInfo.IsCommonHandle) + } + rules := restoreutils.GetRewriteRules(newTableInfo, table.Info, newTS, true) + ct := CreatedTable{ + RewriteRule: rules, + Table: newTableInfo, + OldTable: table, + } + log.Debug("new created tables", zap.Any("table", ct)) + cts = append(cts, ct) + } + return cts, nil +} + +func (rc *SnapClient) createTablesInWorkerPool(ctx context.Context, tables []*metautil.Table, newTS uint64, outCh chan<- CreatedTable) error { + eg, ectx := errgroup.WithContext(ctx) + rater := logutil.TraceRateOver(logutil.MetricTableCreatedCounter) + workers := tidbutil.NewWorkerPool(uint(len(rc.dbPool)), "Create Tables Worker") + numOfTables := len(tables) + + for lastSent := 0; lastSent < numOfTables; lastSent += int(rc.batchDdlSize) { + end := min(lastSent+int(rc.batchDdlSize), len(tables)) + log.Info("create tables", zap.Int("table start", lastSent), zap.Int("table end", end)) + + tableSlice := tables[lastSent:end] + workers.ApplyWithIDInErrorGroup(eg, func(id uint64) error { + db := rc.dbPool[id%uint64(len(rc.dbPool))] + cts, err := rc.createTables(ectx, db, tableSlice, newTS) // ddl job for [lastSent:i) + failpoint.Inject("restore-createtables-error", func(val failpoint.Value) { + if val.(bool) { + err = errors.New("sample error without extra message") + } + }) + if err != nil { + log.Error("create tables fail", zap.Error(err)) + return err + } + for _, ct := range cts { + log.Debug("table created and send to next", + zap.Int("output chan size", len(outCh)), + zap.Stringer("table", ct.OldTable.Info.Name), + zap.Stringer("database", ct.OldTable.DB.Name)) + outCh <- ct + rater.Inc() + rater.L().Info("table created", + zap.Stringer("table", ct.OldTable.Info.Name), + zap.Stringer("database", ct.OldTable.DB.Name)) + } + return err + }) + } + return eg.Wait() +} + +func (rc *SnapClient) createTable( + ctx context.Context, + db *tidallocdb.DB, + table *metautil.Table, + newTS uint64, +) (CreatedTable, error) { + if rc.IsSkipCreateSQL() { + log.Info("skip create table and alter autoIncID", zap.Stringer("table", table.Info.Name)) + } else { + err := db.CreateTable(ctx, table, rc.getRebasedTables(), rc.supportPolicy, rc.policyMap) + if err != nil { + return CreatedTable{}, errors.Trace(err) + } + } + newTableInfo, err := restore.GetTableSchema(rc.dom, table.DB.Name, table.Info.Name) + if err != nil { + return CreatedTable{}, errors.Trace(err) + } + if newTableInfo.IsCommonHandle != table.Info.IsCommonHandle { + return CreatedTable{}, errors.Annotatef(berrors.ErrRestoreModeMismatch, + "Clustered index option mismatch. 
Restored cluster's @@tidb_enable_clustered_index should be %v (backup table = %v, created table = %v).", + restore.TransferBoolToValue(table.Info.IsCommonHandle), + table.Info.IsCommonHandle, + newTableInfo.IsCommonHandle) + } + rules := restoreutils.GetRewriteRules(newTableInfo, table.Info, newTS, true) + et := CreatedTable{ + RewriteRule: rules, + Table: newTableInfo, + OldTable: table, + } + return et, nil +} + +func (rc *SnapClient) createTablesWithSoleDB(ctx context.Context, + createOneTable func(ctx context.Context, db *tidallocdb.DB, t *metautil.Table) error, + tables []*metautil.Table) error { + for _, t := range tables { + if err := createOneTable(ctx, rc.db, t); err != nil { + return errors.Trace(err) + } + } + return nil +} + +func (rc *SnapClient) createTablesWithDBPool(ctx context.Context, + createOneTable func(ctx context.Context, db *tidallocdb.DB, t *metautil.Table) error, + tables []*metautil.Table) error { + eg, ectx := errgroup.WithContext(ctx) + workers := tidbutil.NewWorkerPool(uint(len(rc.dbPool)), "DDL workers") + for _, t := range tables { + table := t + workers.ApplyWithIDInErrorGroup(eg, func(id uint64) error { + db := rc.dbPool[id%uint64(len(rc.dbPool))] + return createOneTable(ectx, db, table) + }) + } + return eg.Wait() +} + +// InitFullClusterRestore init fullClusterRestore and set SkipGrantTable as needed +func (rc *SnapClient) InitFullClusterRestore(explicitFilter bool) { + rc.fullClusterRestore = !explicitFilter && rc.IsFull() + + log.Info("full cluster restore", zap.Bool("value", rc.fullClusterRestore)) +} + +func (rc *SnapClient) IsFullClusterRestore() bool { + return rc.fullClusterRestore +} + +// IsFull returns whether this backup is full. +func (rc *SnapClient) IsFull() bool { + failpoint.Inject("mock-incr-backup-data", func() { + failpoint.Return(false) + }) + return !rc.IsIncremental() +} + +// IsIncremental returns whether this backup is incremental. +func (rc *SnapClient) IsIncremental() bool { + return !(rc.backupMeta.StartVersion == rc.backupMeta.EndVersion || + rc.backupMeta.StartVersion == 0) +} + +// NeedCheckFreshCluster is every time. except restore from a checkpoint or user has not set filter argument. +func (rc *SnapClient) NeedCheckFreshCluster(ExplicitFilter bool, firstRun bool) bool { + return rc.IsFull() && !ExplicitFilter && firstRun +} + +// EnableSkipCreateSQL sets switch of skip create schema and tables. +func (rc *SnapClient) EnableSkipCreateSQL() { + rc.noSchema = true +} + +// IsSkipCreateSQL returns whether we need skip create schema and tables in restore. +func (rc *SnapClient) IsSkipCreateSQL() bool { + return rc.noSchema +} + +// CheckTargetClusterFresh check whether the target cluster is fresh or not +// if there's no user dbs or tables, we take it as a fresh cluster, although +// user may have created some users or made other changes. +func (rc *SnapClient) CheckTargetClusterFresh(ctx context.Context) error { + log.Info("checking whether target cluster is fresh") + return restore.AssertUserDBsEmpty(rc.dom) +} + +// ExecDDLs executes the queries of the ddl jobs. +func (rc *SnapClient) ExecDDLs(ctx context.Context, ddlJobs []*model.Job) error { + // Sort the ddl jobs by schema version in ascending order. 
+ slices.SortFunc(ddlJobs, func(i, j *model.Job) int { + return cmp.Compare(i.BinlogInfo.SchemaVersion, j.BinlogInfo.SchemaVersion) + }) + + for _, job := range ddlJobs { + err := rc.db.ExecDDL(ctx, job) + if err != nil { + return errors.Trace(err) + } + log.Info("execute ddl query", + zap.String("db", job.SchemaName), + zap.String("query", job.Query), + zap.Int64("historySchemaVersion", job.BinlogInfo.SchemaVersion)) + } + return nil +} + +func (rc *SnapClient) ResetSpeedLimit(ctx context.Context) error { + rc.hasSpeedLimited = false + err := rc.setSpeedLimit(ctx, 0) + if err != nil { + return errors.Trace(err) + } + return nil +} + +func (rc *SnapClient) setSpeedLimit(ctx context.Context, rateLimit uint64) error { + if !rc.hasSpeedLimited { + stores, err := util.GetAllTiKVStores(ctx, rc.pdClient, util.SkipTiFlash) + if err != nil { + return errors.Trace(err) + } + + eg, ectx := errgroup.WithContext(ctx) + for _, store := range stores { + if err := ectx.Err(); err != nil { + return errors.Trace(err) + } + + finalStore := store + rc.workerPool.ApplyOnErrorGroup(eg, + func() error { + err := rc.fileImporter.SetDownloadSpeedLimit(ectx, finalStore.GetId(), rateLimit) + if err != nil { + return errors.Trace(err) + } + return nil + }) + } + + if err := eg.Wait(); err != nil { + return errors.Trace(err) + } + rc.hasSpeedLimited = true + } + return nil +} + +func getFileRangeKey(f string) string { + // the backup date file pattern is `{store_id}_{region_id}_{epoch_version}_{key}_{ts}_{cf}.sst` + // so we need to compare with out the `_{cf}.sst` suffix + idx := strings.LastIndex(f, "_") + if idx < 0 { + panic(fmt.Sprintf("invalid backup data file name: '%s'", f)) + } + + return f[:idx] +} + +// isFilesBelongToSameRange check whether two files are belong to the same range with different cf. +func isFilesBelongToSameRange(f1, f2 string) bool { + return getFileRangeKey(f1) == getFileRangeKey(f2) +} + +func drainFilesByRange(files []*backuppb.File) ([]*backuppb.File, []*backuppb.File) { + if len(files) == 0 { + return nil, nil + } + idx := 1 + for idx < len(files) { + if !isFilesBelongToSameRange(files[idx-1].Name, files[idx].Name) { + break + } + idx++ + } + + return files[:idx], files[idx:] +} + +// RestoreSSTFiles tries to restore the files. 
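The getFileRangeKey/drainFilesByRange helpers above group the default and write CF SSTs of one key range into a single import batch, which the RestoreSSTFiles function defined next consumes range by range. Below is a minimal, self-contained sketch of that grouping rule, using hypothetical file names that merely follow the documented {store_id}_{region_id}_{epoch_version}_{key}_{ts}_{cf}.sst pattern; it is an illustration, not BR code.

package main

import (
	"fmt"
	"strings"
)

// rangeKey drops the trailing "_{cf}.sst" suffix so that the default and
// write CF files of the same range compare equal, like getFileRangeKey above.
func rangeKey(name string) string {
	return name[:strings.LastIndex(name, "_")]
}

func main() {
	files := []string{
		"1_2_3_abc_400_default.sst", // hypothetical names: same range...
		"1_2_3_abc_400_write.sst",   // ...different column family
		"1_2_3_def_401_default.sst", // a new range starts here
	}
	fmt.Println(rangeKey(files[0]) == rangeKey(files[1])) // true: batched together
	fmt.Println(rangeKey(files[1]) == rangeKey(files[2])) // false: drained as a new batch
}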
+func (rc *SnapClient) RestoreSSTFiles( + ctx context.Context, + tableIDWithFiles []TableIDWithFiles, + updateCh glue.Progress, +) (err error) { + start := time.Now() + fileCount := 0 + defer func() { + elapsed := time.Since(start) + if err == nil { + log.Info("Restore files", zap.Duration("take", elapsed)) + summary.CollectSuccessUnit("files", fileCount, elapsed) + } + }() + + log.Debug("start to restore files", zap.Int("files", fileCount)) + + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan("Client.RestoreSSTFiles", opentracing.ChildOf(span.Context())) + defer span1.Finish() + ctx = opentracing.ContextWithSpan(ctx, span1) + } + + eg, ectx := errgroup.WithContext(ctx) + err = rc.setSpeedLimit(ctx, rc.rateLimit) + if err != nil { + return errors.Trace(err) + } + + var rangeFiles []*backuppb.File + var leftFiles []*backuppb.File +LOOPFORTABLE: + for _, tableIDWithFile := range tableIDWithFiles { + tableID := tableIDWithFile.TableID + files := tableIDWithFile.Files + rules := tableIDWithFile.RewriteRules + fileCount += len(files) + for rangeFiles, leftFiles = drainFilesByRange(files); len(rangeFiles) != 0; rangeFiles, leftFiles = drainFilesByRange(leftFiles) { + if ectx.Err() != nil { + log.Warn("Restoring encountered error and already stopped, give up remained files.", + zap.Int("remained", len(leftFiles)), + logutil.ShortError(ectx.Err())) + // We will fetch the error from the errgroup then (If there were). + // Also note if the parent context has been canceled or something, + // breaking here directly is also a reasonable behavior. + break LOOPFORTABLE + } + filesReplica := rangeFiles + rc.fileImporter.WaitUntilUnblock() + rc.workerPool.ApplyOnErrorGroup(eg, func() (restoreErr error) { + fileStart := time.Now() + defer func() { + if restoreErr == nil { + log.Info("import files done", logutil.Files(filesReplica), + zap.Duration("take", time.Since(fileStart))) + updateCh.Inc() + } + }() + if importErr := rc.fileImporter.ImportSSTFiles(ectx, filesReplica, rules, rc.cipher, rc.dom.Store().GetCodec().GetAPIVersion()); importErr != nil { + return errors.Trace(importErr) + } + + // the data of this range has been import done + if rc.checkpointRunner != nil && len(filesReplica) > 0 { + rangeKey := getFileRangeKey(filesReplica[0].Name) + // The checkpoint range shows this ranges of kvs has been restored into + // the table corresponding to the table-id. + if err := checkpoint.AppendRangesForRestore(ectx, rc.checkpointRunner, tableID, rangeKey); err != nil { + return errors.Trace(err) + } + } + return nil + }) + } + } + + if err := eg.Wait(); err != nil { + summary.CollectFailureUnit("file", err) + log.Error( + "restore files failed", + zap.Error(err), + ) + return errors.Trace(err) + } + // Once the parent context canceled and there is no task running in the errgroup, + // we may break the for loop without error in the errgroup. (Will this happen?) + // At that time, return the error in the context here. 
+ return ctx.Err() +} + +func (rc *SnapClient) execChecksum( + ctx context.Context, + tbl *CreatedTable, + kvClient kv.Client, + concurrency uint, +) error { + logger := log.L().With( + zap.String("db", tbl.OldTable.DB.Name.O), + zap.String("table", tbl.OldTable.Info.Name.O), + ) + + if tbl.OldTable.NoChecksum() { + logger.Warn("table has no checksum, skipping checksum") + return nil + } + + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan("Client.execChecksum", opentracing.ChildOf(span.Context())) + defer span1.Finish() + ctx = opentracing.ContextWithSpan(ctx, span1) + } + + item, exists := rc.checkpointChecksum[tbl.Table.ID] + if !exists { + startTS, err := restore.GetTSWithRetry(ctx, rc.pdClient) + if err != nil { + return errors.Trace(err) + } + exe, err := checksum.NewExecutorBuilder(tbl.Table, startTS). + SetOldTable(tbl.OldTable). + SetConcurrency(concurrency). + SetOldKeyspace(tbl.RewriteRule.OldKeyspace). + SetNewKeyspace(tbl.RewriteRule.NewKeyspace). + SetExplicitRequestSourceType(kvutil.ExplicitTypeBR). + Build() + if err != nil { + return errors.Trace(err) + } + checksumResp, err := exe.Execute(ctx, kvClient, func() { + // TODO: update progress here. + }) + if err != nil { + return errors.Trace(err) + } + item = &checkpoint.ChecksumItem{ + TableID: tbl.Table.ID, + Crc64xor: checksumResp.Checksum, + TotalKvs: checksumResp.TotalKvs, + TotalBytes: checksumResp.TotalBytes, + } + if rc.checkpointRunner != nil { + err = rc.checkpointRunner.FlushChecksumItem(ctx, item) + if err != nil { + return errors.Trace(err) + } + } + } + table := tbl.OldTable + if item.Crc64xor != table.Crc64Xor || + item.TotalKvs != table.TotalKvs || + item.TotalBytes != table.TotalBytes { + logger.Error("failed in validate checksum", + zap.Uint64("origin tidb crc64", table.Crc64Xor), + zap.Uint64("calculated crc64", item.Crc64xor), + zap.Uint64("origin tidb total kvs", table.TotalKvs), + zap.Uint64("calculated total kvs", item.TotalKvs), + zap.Uint64("origin tidb total bytes", table.TotalBytes), + zap.Uint64("calculated total bytes", item.TotalBytes), + ) + return errors.Annotate(berrors.ErrRestoreChecksumMismatch, "failed to validate checksum") + } + logger.Info("success in validate checksum") + return nil +} + +func (rc *SnapClient) WaitForFilesRestored(ctx context.Context, files []*backuppb.File, updateCh glue.Progress) error { + errCh := make(chan error, len(files)) + eg, ectx := errgroup.WithContext(ctx) + defer close(errCh) + + for _, file := range files { + fileReplica := file + rc.workerPool.ApplyOnErrorGroup(eg, + func() error { + defer func() { + log.Info("import sst files done", logutil.Files(files)) + updateCh.Inc() + }() + return rc.fileImporter.ImportSSTFiles(ectx, []*backuppb.File{fileReplica}, restoreutils.EmptyRewriteRule(), rc.cipher, rc.backupMeta.ApiVersion) + }) + } + if err := eg.Wait(); err != nil { + return errors.Trace(err) + } + return nil +} + +// RestoreRaw tries to restore raw keys in the specified range. 
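WaitForFilesRestored above, like CreateDatabases and createTablesInWorkerPool earlier, fans work out through BR's worker pool on top of an errgroup and surfaces the first error. A stripped-down sketch of the same idiom using plain golang.org/x/sync/errgroup follows; restoreOne is a hypothetical stand-in for the per-file import call, not a BR API.

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

// restoreAll runs one task per file on a bounded errgroup and returns the
// first error encountered, similar in spirit to the worker-pool helpers above.
func restoreAll(ctx context.Context, files []string, restoreOne func(context.Context, string) error) error {
	eg, ectx := errgroup.WithContext(ctx)
	eg.SetLimit(4) // cap concurrency, playing the role of the worker pool
	for _, f := range files {
		f := f // copy the loop variable, mirroring the fileReplica/db copies above
		eg.Go(func() error {
			return restoreOne(ectx, f)
		})
	}
	return eg.Wait()
}

func main() {
	err := restoreAll(context.Background(), []string{"a.sst", "b.sst"},
		func(_ context.Context, name string) error {
			fmt.Println("restored", name)
			return nil
		})
	fmt.Println("err:", err)
}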
+func (rc *SnapClient) RestoreRaw( + ctx context.Context, startKey []byte, endKey []byte, files []*backuppb.File, updateCh glue.Progress, +) error { + start := time.Now() + defer func() { + elapsed := time.Since(start) + log.Info("Restore Raw", + logutil.Key("startKey", startKey), + logutil.Key("endKey", endKey), + zap.Duration("take", elapsed)) + }() + err := rc.fileImporter.SetRawRange(startKey, endKey) + if err != nil { + return errors.Trace(err) + } + + err = rc.WaitForFilesRestored(ctx, files, updateCh) + if err != nil { + return errors.Trace(err) + } + log.Info( + "finish to restore raw range", + logutil.Key("startKey", startKey), + logutil.Key("endKey", endKey), + ) + return nil +} + +// SplitRanges implements TiKVRestorer. It splits region by +// data range after rewrite. +func (rc *SnapClient) SplitRanges( + ctx context.Context, + ranges []rtree.Range, + updateCh glue.Progress, + isRawKv bool, +) error { + splitClientOpts := make([]split.ClientOptionalParameter, 0, 2) + splitClientOpts = append(splitClientOpts, split.WithOnSplit(func(keys [][]byte) { + for range keys { + updateCh.Inc() + } + })) + if isRawKv { + splitClientOpts = append(splitClientOpts, split.WithRawKV()) + } + + splitter := internalutils.NewRegionSplitter(split.NewClient( + rc.pdClient, + rc.pdHTTPClient, + rc.tlsConf, + maxSplitKeysOnce, + rc.storeCount+1, + splitClientOpts..., + )) + + return splitter.ExecuteSplit(ctx, ranges) +} diff --git a/br/pkg/restore/snap_client/context_manager.go b/br/pkg/restore/snap_client/context_manager.go index 294f774630db6..1c5a2569a6a00 100644 --- a/br/pkg/restore/snap_client/context_manager.go +++ b/br/pkg/restore/snap_client/context_manager.go @@ -242,10 +242,10 @@ func (manager *brContextManager) waitPlacementSchedule(ctx context.Context, tabl } log.Info("start waiting placement schedule") ticker := time.NewTicker(time.Second * 10) - failpoint.Inject("wait-placement-schedule-quicker-ticker", func() { + if _, _err_ := failpoint.Eval(_curpkg_("wait-placement-schedule-quicker-ticker")); _err_ == nil { ticker.Stop() ticker = time.NewTicker(time.Millisecond * 500) - }) + } defer ticker.Stop() for { select { diff --git a/br/pkg/restore/snap_client/context_manager.go__failpoint_stash__ b/br/pkg/restore/snap_client/context_manager.go__failpoint_stash__ new file mode 100644 index 0000000000000..294f774630db6 --- /dev/null +++ b/br/pkg/restore/snap_client/context_manager.go__failpoint_stash__ @@ -0,0 +1,290 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package snapclient + +import ( + "context" + "crypto/tls" + "encoding/hex" + "fmt" + "strconv" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/conn" + "github.com/pingcap/tidb/br/pkg/conn/util" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/restore/split" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/util/codec" + pd "github.com/tikv/pd/client" + pdhttp "github.com/tikv/pd/client/http" + "go.uber.org/zap" +) + +// ContextManager is the struct to manage a TiKV 'context' for restore. +// Batcher will call Enter when any table should be restore on batch, +// so you can do some prepare work here(e.g. set placement rules for online restore). +type ContextManager interface { + // Enter make some tables 'enter' this context(a.k.a., prepare for restore). + Enter(ctx context.Context, tables []CreatedTable) error + // Leave make some tables 'leave' this context(a.k.a., restore is done, do some post-works). + Leave(ctx context.Context, tables []CreatedTable) error + // Close closes the context manager, sometimes when the manager is 'killed' and should do some cleanup + // it would be call. + Close(ctx context.Context) +} + +// NewBRContextManager makes a BR context manager, that is, +// set placement rules for online restore when enter(see ), +// unset them when leave. +func NewBRContextManager(ctx context.Context, pdClient pd.Client, pdHTTPCli pdhttp.Client, tlsConf *tls.Config, isOnline bool) (ContextManager, error) { + manager := &brContextManager{ + // toolClient reuse the split.SplitClient to do miscellaneous things. It doesn't + // call split related functions so set the arguments to arbitrary values. + toolClient: split.NewClient(pdClient, pdHTTPCli, tlsConf, maxSplitKeysOnce, 3), + isOnline: isOnline, + + hasTable: make(map[int64]CreatedTable), + } + + err := manager.loadRestoreStores(ctx, pdClient) + return manager, errors.Trace(err) +} + +type brContextManager struct { + toolClient split.SplitClient + restoreStores []uint64 + isOnline bool + + // This 'set' of table ID allow us to handle each table just once. 
+ hasTable map[int64]CreatedTable + mu sync.Mutex +} + +func (manager *brContextManager) Close(ctx context.Context) { + tbls := make([]*model.TableInfo, 0, len(manager.hasTable)) + for _, tbl := range manager.hasTable { + tbls = append(tbls, tbl.Table) + } + manager.splitPostWork(ctx, tbls) +} + +func (manager *brContextManager) Enter(ctx context.Context, tables []CreatedTable) error { + placementRuleTables := make([]*model.TableInfo, 0, len(tables)) + manager.mu.Lock() + defer manager.mu.Unlock() + + for _, tbl := range tables { + if _, ok := manager.hasTable[tbl.Table.ID]; !ok { + placementRuleTables = append(placementRuleTables, tbl.Table) + } + manager.hasTable[tbl.Table.ID] = tbl + } + + return manager.splitPrepareWork(ctx, placementRuleTables) +} + +func (manager *brContextManager) Leave(ctx context.Context, tables []CreatedTable) error { + manager.mu.Lock() + defer manager.mu.Unlock() + placementRuleTables := make([]*model.TableInfo, 0, len(tables)) + + for _, table := range tables { + placementRuleTables = append(placementRuleTables, table.Table) + } + + manager.splitPostWork(ctx, placementRuleTables) + log.Info("restore table done", zapTables(tables)) + for _, tbl := range placementRuleTables { + delete(manager.hasTable, tbl.ID) + } + return nil +} + +func (manager *brContextManager) splitPostWork(ctx context.Context, tables []*model.TableInfo) { + err := manager.resetPlacementRules(ctx, tables) + if err != nil { + log.Warn("reset placement rules failed", zap.Error(err)) + return + } +} + +func (manager *brContextManager) splitPrepareWork(ctx context.Context, tables []*model.TableInfo) error { + err := manager.setupPlacementRules(ctx, tables) + if err != nil { + log.Error("setup placement rules failed", zap.Error(err)) + return errors.Trace(err) + } + + err = manager.waitPlacementSchedule(ctx, tables) + if err != nil { + log.Error("wait placement schedule failed", zap.Error(err)) + return errors.Trace(err) + } + return nil +} + +const ( + restoreLabelKey = "exclusive" + restoreLabelValue = "restore" +) + +// loadRestoreStores loads the stores used to restore data. This function is called only when is online. +func (manager *brContextManager) loadRestoreStores(ctx context.Context, pdClient util.StoreMeta) error { + if !manager.isOnline { + return nil + } + stores, err := conn.GetAllTiKVStoresWithRetry(ctx, pdClient, util.SkipTiFlash) + if err != nil { + return errors.Trace(err) + } + for _, s := range stores { + if s.GetState() != metapb.StoreState_Up { + continue + } + for _, l := range s.GetLabels() { + if l.GetKey() == restoreLabelKey && l.GetValue() == restoreLabelValue { + manager.restoreStores = append(manager.restoreStores, s.GetId()) + break + } + } + } + log.Info("load restore stores", zap.Uint64s("store-ids", manager.restoreStores)) + return nil +} + +// SetupPlacementRules sets rules for the tables' regions. 
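Enter above records every table ID in hasTable under the mutex and only prepares placement rules for IDs it has not seen before, so re-entering a batch is harmless. A condensed sketch of that bookkeeping, with the BR-specific types replaced by plain IDs, is shown below as an illustration only.

package main

import (
	"fmt"
	"sync"
)

type tracker struct {
	mu   sync.Mutex
	seen map[int64]struct{}
}

// enter remembers each ID and returns only the ones seen for the first time,
// mirroring how Enter collects placementRuleTables above.
func (t *tracker) enter(ids []int64) []int64 {
	t.mu.Lock()
	defer t.mu.Unlock()
	firstSeen := make([]int64, 0, len(ids))
	for _, id := range ids {
		if _, ok := t.seen[id]; !ok {
			firstSeen = append(firstSeen, id)
		}
		t.seen[id] = struct{}{}
	}
	return firstSeen
}

func main() {
	t := &tracker{seen: map[int64]struct{}{}}
	fmt.Println(t.enter([]int64{1, 2})) // [1 2]
	fmt.Println(t.enter([]int64{2, 3})) // [3]
}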
+func (manager *brContextManager) setupPlacementRules(ctx context.Context, tables []*model.TableInfo) error { + if !manager.isOnline || len(manager.restoreStores) == 0 { + return nil + } + log.Info("start setting placement rules") + rule, err := manager.toolClient.GetPlacementRule(ctx, "pd", "default") + if err != nil { + return errors.Trace(err) + } + rule.Index = 100 + rule.Override = true + rule.LabelConstraints = append(rule.LabelConstraints, pdhttp.LabelConstraint{ + Key: restoreLabelKey, + Op: "in", + Values: []string{restoreLabelValue}, + }) + for _, t := range tables { + rule.ID = getRuleID(t.ID) + rule.StartKeyHex = hex.EncodeToString(codec.EncodeBytes([]byte{}, tablecodec.EncodeTablePrefix(t.ID))) + rule.EndKeyHex = hex.EncodeToString(codec.EncodeBytes([]byte{}, tablecodec.EncodeTablePrefix(t.ID+1))) + err = manager.toolClient.SetPlacementRule(ctx, rule) + if err != nil { + return errors.Trace(err) + } + } + log.Info("finish setting placement rules") + return nil +} + +func (manager *brContextManager) checkRegions(ctx context.Context, tables []*model.TableInfo) (bool, string, error) { + for i, t := range tables { + start := codec.EncodeBytes([]byte{}, tablecodec.EncodeTablePrefix(t.ID)) + end := codec.EncodeBytes([]byte{}, tablecodec.EncodeTablePrefix(t.ID+1)) + ok, regionProgress, err := manager.checkRange(ctx, start, end) + if err != nil { + return false, "", errors.Trace(err) + } + if !ok { + return false, fmt.Sprintf("table %v/%v, %s", i, len(tables), regionProgress), nil + } + } + return true, "", nil +} + +func (manager *brContextManager) checkRange(ctx context.Context, start, end []byte) (bool, string, error) { + regions, err := manager.toolClient.ScanRegions(ctx, start, end, -1) + if err != nil { + return false, "", errors.Trace(err) + } + for i, r := range regions { + NEXT_PEER: + for _, p := range r.Region.GetPeers() { + for _, storeID := range manager.restoreStores { + if p.GetStoreId() == storeID { + continue NEXT_PEER + } + } + return false, fmt.Sprintf("region %v/%v", i, len(regions)), nil + } + } + return true, "", nil +} + +// waitPlacementSchedule waits PD to move tables to restore stores. +func (manager *brContextManager) waitPlacementSchedule(ctx context.Context, tables []*model.TableInfo) error { + if !manager.isOnline || len(manager.restoreStores) == 0 { + return nil + } + log.Info("start waiting placement schedule") + ticker := time.NewTicker(time.Second * 10) + failpoint.Inject("wait-placement-schedule-quicker-ticker", func() { + ticker.Stop() + ticker = time.NewTicker(time.Millisecond * 500) + }) + defer ticker.Stop() + for { + select { + case <-ticker.C: + ok, progress, err := manager.checkRegions(ctx, tables) + if err != nil { + return errors.Trace(err) + } + if ok { + log.Info("finish waiting placement schedule") + return nil + } + log.Info("placement schedule progress: " + progress) + case <-ctx.Done(): + return ctx.Err() + } + } +} + +func getRuleID(tableID int64) string { + return "restore-t" + strconv.FormatInt(tableID, 10) +} + +// resetPlacementRules removes placement rules for tables. 
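checkRange above accepts a region only when every peer lives on one of the restore stores, using a labeled continue to move on to the next peer as soon as a match is found. The same check in isolation, as a small sketch with plain store IDs:

package main

import "fmt"

// allPeersOnStores reports whether every peer's store is in restoreStores.
func allPeersOnStores(peerStores []uint64, restoreStores []uint64) bool {
NEXT_PEER:
	for _, p := range peerStores {
		for _, s := range restoreStores {
			if p == s {
				continue NEXT_PEER
			}
		}
		return false // this peer sits on a non-restore store
	}
	return true
}

func main() {
	fmt.Println(allPeersOnStores([]uint64{1, 2}, []uint64{1, 2, 3})) // true
	fmt.Println(allPeersOnStores([]uint64{1, 4}, []uint64{1, 2, 3})) // false
}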
+func (manager *brContextManager) resetPlacementRules(ctx context.Context, tables []*model.TableInfo) error { + if !manager.isOnline || len(manager.restoreStores) == 0 { + return nil + } + log.Info("start resetting placement rules") + var failedTables []int64 + for _, t := range tables { + err := manager.toolClient.DeletePlacementRule(ctx, "pd", getRuleID(t.ID)) + if err != nil { + log.Info("failed to delete placement rule for table", zap.Int64("table-id", t.ID)) + failedTables = append(failedTables, t.ID) + } + } + if len(failedTables) > 0 { + return errors.Annotatef(berrors.ErrPDInvalidResponse, "failed to delete placement rules for tables %v", failedTables) + } + return nil +} diff --git a/br/pkg/restore/snap_client/import.go b/br/pkg/restore/snap_client/import.go index cdab5a678628a..1a27d67011066 100644 --- a/br/pkg/restore/snap_client/import.go +++ b/br/pkg/restore/snap_client/import.go @@ -489,15 +489,15 @@ func (importer *SnapFileImporter) download( downloadMetas, e = importer.downloadSST(ctx, regionInfo, files, rewriteRules, cipher, apiVersion) } - failpoint.Inject("restore-storage-error", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("restore-storage-error")); _err_ == nil { msg := val.(string) log.Debug("failpoint restore-storage-error injected.", zap.String("msg", msg)) e = errors.Annotate(e, msg) - }) - failpoint.Inject("restore-gRPC-error", func(_ failpoint.Value) { + } + if _, _err_ := failpoint.Eval(_curpkg_("restore-gRPC-error")); _err_ == nil { log.Warn("the connection to TiKV has been cut by a neko, meow :3") e = status.Error(codes.Unavailable, "the connection to TiKV has been cut by a neko, meow :3") - }) + } if isDecryptSstErr(e) { log.Info("fail to decrypt when download sst, try again with no-crypt", logutil.Files(files)) if importer.kvMode == Raw || importer.kvMode == Txn { diff --git a/br/pkg/restore/snap_client/import.go__failpoint_stash__ b/br/pkg/restore/snap_client/import.go__failpoint_stash__ new file mode 100644 index 0000000000000..cdab5a678628a --- /dev/null +++ b/br/pkg/restore/snap_client/import.go__failpoint_stash__ @@ -0,0 +1,846 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
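The import.go hunk above rewrites failpoint.Inject markers into the generated failpoint.Eval(_curpkg_(...)) form and keeps the original file in a __failpoint_stash__ copy. The following small sketch shows how such an evaluated failpoint behaves at runtime; the failpoint path below is hypothetical, and the activation expression follows the failpoint return(...) grammar.

package main

import (
	"fmt"

	"github.com/pingcap/failpoint"
)

// fpPath is a hypothetical failpoint path. Real paths are the package import
// path plus the failpoint name, which is what the generated _curpkg_ helper
// in binding__failpoint_binding__.go builds.
const fpPath = "github.com/example/demo/restore-storage-error"

func download() error {
	// Mirrors the generated form above: evaluate the failpoint and, only if
	// it has been enabled, act on the injected value.
	if val, err := failpoint.Eval(fpPath); err == nil {
		return fmt.Errorf("injected storage error: %v", val)
	}
	return nil
}

func main() {
	fmt.Println("before enable:", download()) // <nil>
	if err := failpoint.Enable(fpPath, `return("oh no")`); err != nil {
		panic(err)
	}
	defer func() { _ = failpoint.Disable(fpPath) }()
	fmt.Println("after enable:", download()) // injected storage error: oh no
}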
+ +package snapclient + +import ( + "bytes" + "context" + "fmt" + "math/rand" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/kvproto/pkg/import_sstpb" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/log" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/logutil" + importclient "github.com/pingcap/tidb/br/pkg/restore/internal/import_client" + "github.com/pingcap/tidb/br/pkg/restore/split" + restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" + "github.com/pingcap/tidb/br/pkg/summary" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/util/codec" + kvutil "github.com/tikv/client-go/v2/util" + "go.uber.org/multierr" + "go.uber.org/zap" + "golang.org/x/exp/maps" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type KvMode int + +const ( + TiDB KvMode = iota + Raw + Txn +) + +const ( + // Todo: make it configable + gRPCTimeOut = 25 * time.Minute +) + +// RewriteMode is a mode flag that tells the TiKV how to handle the rewrite rules. +type RewriteMode int + +const ( + // RewriteModeLegacy means no rewrite rule is applied. + RewriteModeLegacy RewriteMode = iota + + // RewriteModeKeyspace means the rewrite rule could be applied to keyspace. + RewriteModeKeyspace +) + +type storeTokenChannelMap struct { + sync.RWMutex + tokens map[uint64]chan struct{} +} + +func (s *storeTokenChannelMap) acquireTokenCh(storeID uint64, bufferSize uint) chan struct{} { + s.RLock() + tokenCh, ok := s.tokens[storeID] + // handle the case that the store is new-scaled in the cluster + if !ok { + s.RUnlock() + s.Lock() + // Notice: worker channel can't replaced, because it is still used after unlock. 
+ if tokenCh, ok = s.tokens[storeID]; !ok { + tokenCh = utils.BuildWorkerTokenChannel(bufferSize) + s.tokens[storeID] = tokenCh + } + s.Unlock() + } else { + s.RUnlock() + } + return tokenCh +} + +func (s *storeTokenChannelMap) ShouldBlock() bool { + s.RLock() + defer s.RUnlock() + if len(s.tokens) == 0 { + // never block if there is no store worker pool + return false + } + for _, pool := range s.tokens { + if len(pool) > 0 { + // At least one store worker pool has available worker + return false + } + } + return true +} + +func newStoreTokenChannelMap(stores []*metapb.Store, bufferSize uint) *storeTokenChannelMap { + storeTokenChannelMap := &storeTokenChannelMap{ + sync.RWMutex{}, + make(map[uint64]chan struct{}), + } + if bufferSize == 0 { + return storeTokenChannelMap + } + for _, store := range stores { + ch := utils.BuildWorkerTokenChannel(bufferSize) + storeTokenChannelMap.tokens[store.Id] = ch + } + return storeTokenChannelMap +} + +type SnapFileImporter struct { + metaClient split.SplitClient + importClient importclient.ImporterClient + backend *backuppb.StorageBackend + + downloadTokensMap *storeTokenChannelMap + ingestTokensMap *storeTokenChannelMap + + concurrencyPerStore uint + + kvMode KvMode + rawStartKey []byte + rawEndKey []byte + rewriteMode RewriteMode + + cacheKey string + cond *sync.Cond +} + +func NewSnapFileImporter( + ctx context.Context, + metaClient split.SplitClient, + importClient importclient.ImporterClient, + backend *backuppb.StorageBackend, + isRawKvMode bool, + isTxnKvMode bool, + tikvStores []*metapb.Store, + rewriteMode RewriteMode, + concurrencyPerStore uint, +) (*SnapFileImporter, error) { + kvMode := TiDB + if isRawKvMode { + kvMode = Raw + } + if isTxnKvMode { + kvMode = Txn + } + + fileImporter := &SnapFileImporter{ + metaClient: metaClient, + backend: backend, + importClient: importClient, + downloadTokensMap: newStoreTokenChannelMap(tikvStores, concurrencyPerStore), + ingestTokensMap: newStoreTokenChannelMap(tikvStores, concurrencyPerStore), + kvMode: kvMode, + rewriteMode: rewriteMode, + cacheKey: fmt.Sprintf("BR-%s-%d", time.Now().Format("20060102150405"), rand.Int63()), + concurrencyPerStore: concurrencyPerStore, + cond: sync.NewCond(new(sync.Mutex)), + } + + err := fileImporter.checkMultiIngestSupport(ctx, tikvStores) + return fileImporter, errors.Trace(err) +} + +func (importer *SnapFileImporter) WaitUntilUnblock() { + importer.cond.L.Lock() + for importer.ShouldBlock() { + // wait for download worker notified + importer.cond.Wait() + } + importer.cond.L.Unlock() +} + +func (importer *SnapFileImporter) ShouldBlock() bool { + if importer != nil { + return importer.downloadTokensMap.ShouldBlock() || importer.ingestTokensMap.ShouldBlock() + } + return false +} + +func (importer *SnapFileImporter) releaseToken(tokenCh chan struct{}) { + tokenCh <- struct{}{} + // finish the task, notify the main goroutine to continue + importer.cond.L.Lock() + importer.cond.Signal() + importer.cond.L.Unlock() +} + +func (importer *SnapFileImporter) Close() error { + if importer != nil && importer.importClient != nil { + return importer.importClient.CloseGrpcClient() + } + return nil +} + +func (importer *SnapFileImporter) SetDownloadSpeedLimit(ctx context.Context, storeID, rateLimit uint64) error { + req := &import_sstpb.SetDownloadSpeedLimitRequest{ + SpeedLimit: rateLimit, + } + _, err := importer.importClient.SetDownloadSpeedLimit(ctx, storeID, req) + return errors.Trace(err) +} + +// checkMultiIngestSupport checks whether all stores support multi-ingest +func 
(importer *SnapFileImporter) checkMultiIngestSupport(ctx context.Context, tikvStores []*metapb.Store) error { + storeIDs := make([]uint64, 0, len(tikvStores)) + for _, s := range tikvStores { + if s.State != metapb.StoreState_Up { + continue + } + storeIDs = append(storeIDs, s.Id) + } + + if err := importer.importClient.CheckMultiIngestSupport(ctx, storeIDs); err != nil { + return errors.Trace(err) + } + return nil +} + +// SetRawRange sets the range to be restored in raw kv mode. +func (importer *SnapFileImporter) SetRawRange(startKey, endKey []byte) error { + if importer.kvMode != Raw { + return errors.Annotate(berrors.ErrRestoreModeMismatch, "file importer is not in raw kv mode") + } + importer.rawStartKey = startKey + importer.rawEndKey = endKey + return nil +} + +func getKeyRangeByMode(mode KvMode) func(f *backuppb.File, rules *restoreutils.RewriteRules) ([]byte, []byte, error) { + switch mode { + case Raw: + return func(f *backuppb.File, rules *restoreutils.RewriteRules) ([]byte, []byte, error) { + return f.GetStartKey(), f.GetEndKey(), nil + } + case Txn: + return func(f *backuppb.File, rules *restoreutils.RewriteRules) ([]byte, []byte, error) { + start, end := f.GetStartKey(), f.GetEndKey() + if len(start) != 0 { + start = codec.EncodeBytes([]byte{}, f.GetStartKey()) + } + if len(end) != 0 { + end = codec.EncodeBytes([]byte{}, f.GetEndKey()) + } + return start, end, nil + } + default: + return func(f *backuppb.File, rules *restoreutils.RewriteRules) ([]byte, []byte, error) { + return restoreutils.GetRewriteRawKeys(f, rules) + } + } +} + +// getKeyRangeForFiles gets the maximum range on files. +func (importer *SnapFileImporter) getKeyRangeForFiles( + files []*backuppb.File, + rewriteRules *restoreutils.RewriteRules, +) ([]byte, []byte, error) { + var ( + startKey, endKey []byte + start, end []byte + err error + ) + getRangeFn := getKeyRangeByMode(importer.kvMode) + for _, f := range files { + start, end, err = getRangeFn(f, rewriteRules) + if err != nil { + return nil, nil, errors.Trace(err) + } + if len(startKey) == 0 || bytes.Compare(start, startKey) < 0 { + startKey = start + } + if len(endKey) == 0 || bytes.Compare(endKey, end) < 0 { + endKey = end + } + } + + log.Debug("rewrite file keys", logutil.Files(files), + logutil.Key("startKey", startKey), logutil.Key("endKey", endKey)) + return startKey, endKey, nil +} + +// ImportSSTFiles tries to import a file. +// All rules must contain encoded keys. +func (importer *SnapFileImporter) ImportSSTFiles( + ctx context.Context, + files []*backuppb.File, + rewriteRules *restoreutils.RewriteRules, + cipher *backuppb.CipherInfo, + apiVersion kvrpcpb.APIVersion, +) error { + start := time.Now() + log.Debug("import file", logutil.Files(files)) + + // Rewrite the start key and end key of file to scan regions + startKey, endKey, err := importer.getKeyRangeForFiles(files, rewriteRules) + if err != nil { + return errors.Trace(err) + } + + err = utils.WithRetry(ctx, func() error { + // Scan regions covered by the file range + regionInfos, errScanRegion := split.PaginateScanRegion( + ctx, importer.metaClient, startKey, endKey, split.ScanRegionPaginationLimit) + if errScanRegion != nil { + return errors.Trace(errScanRegion) + } + + log.Debug("scan regions", logutil.Files(files), zap.Int("count", len(regionInfos))) + // Try to download and ingest the file in every region + regionLoop: + for _, regionInfo := range regionInfos { + info := regionInfo + // Try to download file. 
+ downloadMetas, errDownload := importer.download(ctx, info, files, rewriteRules, cipher, apiVersion) + if errDownload != nil { + for _, e := range multierr.Errors(errDownload) { + switch errors.Cause(e) { // nolint:errorlint + case berrors.ErrKVRewriteRuleNotFound, berrors.ErrKVRangeIsEmpty: + // Skip this region + log.Warn("download file skipped", + logutil.Files(files), + logutil.Region(info.Region), + logutil.Key("startKey", startKey), + logutil.Key("endKey", endKey), + logutil.Key("file-simple-start", files[0].StartKey), + logutil.Key("file-simple-end", files[0].EndKey), + logutil.ShortError(e)) + continue regionLoop + } + } + log.Warn("download file failed, retry later", + logutil.Files(files), + logutil.Region(info.Region), + logutil.Key("startKey", startKey), + logutil.Key("endKey", endKey), + logutil.ShortError(errDownload)) + return errors.Trace(errDownload) + } + log.Debug("download file done", + zap.String("file-sample", files[0].Name), zap.Stringer("take", time.Since(start)), + logutil.Key("start", files[0].StartKey), logutil.Key("end", files[0].EndKey)) + start = time.Now() + if errIngest := importer.ingest(ctx, files, info, downloadMetas); errIngest != nil { + log.Warn("ingest file failed, retry later", + logutil.Files(files), + logutil.SSTMetas(downloadMetas), + logutil.Region(info.Region), + zap.Error(errIngest)) + return errors.Trace(errIngest) + } + log.Debug("ingest file done", zap.String("file-sample", files[0].Name), zap.Stringer("take", time.Since(start))) + } + + for _, f := range files { + summary.CollectSuccessUnit(summary.TotalKV, 1, f.TotalKvs) + summary.CollectSuccessUnit(summary.TotalBytes, 1, f.TotalBytes) + } + return nil + }, utils.NewImportSSTBackoffer()) + if err != nil { + log.Error("import sst file failed after retry, stop the whole progress", logutil.Files(files), zap.Error(err)) + return errors.Trace(err) + } + return nil +} + +// getSSTMetaFromFile compares the keys in file, region and rewrite rules, then returns a sst conn. +// The range of the returned sst meta is [regionRule.NewKeyPrefix, append(regionRule.NewKeyPrefix, 0xff)]. +func getSSTMetaFromFile( + id []byte, + file *backuppb.File, + region *metapb.Region, + regionRule *import_sstpb.RewriteRule, + rewriteMode RewriteMode, +) (meta *import_sstpb.SSTMeta, err error) { + r := *region + // If the rewrite mode is for keyspace, then the region bound should be decoded. + if rewriteMode == RewriteModeKeyspace { + if len(region.GetStartKey()) > 0 { + _, r.StartKey, err = codec.DecodeBytes(region.GetStartKey(), nil) + if err != nil { + return + } + } + if len(region.GetEndKey()) > 0 { + _, r.EndKey, err = codec.DecodeBytes(region.GetEndKey(), nil) + if err != nil { + return + } + } + } + + // Get the column family of the file by the file name. + var cfName string + if strings.Contains(file.GetName(), restoreutils.DefaultCFName) { + cfName = restoreutils.DefaultCFName + } else if strings.Contains(file.GetName(), restoreutils.WriteCFName) { + cfName = restoreutils.WriteCFName + } + // Find the overlapped part between the file and the region. + // Here we rewrites the keys to compare with the keys of the region. 
+ rangeStart := regionRule.GetNewKeyPrefix() + // rangeStart = max(rangeStart, region.StartKey) + if bytes.Compare(rangeStart, r.GetStartKey()) < 0 { + rangeStart = r.GetStartKey() + } + + // Append 10 * 0xff to make sure rangeEnd cover all file key + // If choose to regionRule.NewKeyPrefix + 1, it may cause WrongPrefix here + // https://github.com/tikv/tikv/blob/970a9bf2a9ea782a455ae579ad237aaf6cb1daec/ + // components/sst_importer/src/sst_importer.rs#L221 + suffix := []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff} + rangeEnd := append(append([]byte{}, regionRule.GetNewKeyPrefix()...), suffix...) + // rangeEnd = min(rangeEnd, region.EndKey) + if len(r.GetEndKey()) > 0 && bytes.Compare(rangeEnd, r.GetEndKey()) > 0 { + rangeEnd = r.GetEndKey() + } + + if bytes.Compare(rangeStart, rangeEnd) > 0 { + log.Panic("range start exceed range end", + logutil.File(file), + logutil.Key("startKey", rangeStart), + logutil.Key("endKey", rangeEnd)) + } + + log.Debug("get sstMeta", + logutil.Region(region), + logutil.File(file), + logutil.Key("startKey", rangeStart), + logutil.Key("endKey", rangeEnd)) + + return &import_sstpb.SSTMeta{ + Uuid: id, + CfName: cfName, + Range: &import_sstpb.Range{ + Start: rangeStart, + End: rangeEnd, + }, + Length: file.GetSize_(), + RegionId: region.GetId(), + RegionEpoch: region.GetRegionEpoch(), + CipherIv: file.GetCipherIv(), + }, nil +} + +// a new way to download ssts files +// 1. download write + default sst files at peer level. +// 2. control the download concurrency per store. +func (importer *SnapFileImporter) download( + ctx context.Context, + regionInfo *split.RegionInfo, + files []*backuppb.File, + rewriteRules *restoreutils.RewriteRules, + cipher *backuppb.CipherInfo, + apiVersion kvrpcpb.APIVersion, +) ([]*import_sstpb.SSTMeta, error) { + var ( + downloadMetas = make([]*import_sstpb.SSTMeta, 0, len(files)) + ) + errDownload := utils.WithRetry(ctx, func() error { + var e error + // we treat Txn kv file as Raw kv file. 
because we don't have table id to decode + if importer.kvMode == Raw || importer.kvMode == Txn { + downloadMetas, e = importer.downloadRawKVSST(ctx, regionInfo, files, cipher, apiVersion) + } else { + downloadMetas, e = importer.downloadSST(ctx, regionInfo, files, rewriteRules, cipher, apiVersion) + } + + failpoint.Inject("restore-storage-error", func(val failpoint.Value) { + msg := val.(string) + log.Debug("failpoint restore-storage-error injected.", zap.String("msg", msg)) + e = errors.Annotate(e, msg) + }) + failpoint.Inject("restore-gRPC-error", func(_ failpoint.Value) { + log.Warn("the connection to TiKV has been cut by a neko, meow :3") + e = status.Error(codes.Unavailable, "the connection to TiKV has been cut by a neko, meow :3") + }) + if isDecryptSstErr(e) { + log.Info("fail to decrypt when download sst, try again with no-crypt", logutil.Files(files)) + if importer.kvMode == Raw || importer.kvMode == Txn { + downloadMetas, e = importer.downloadRawKVSST(ctx, regionInfo, files, nil, apiVersion) + } else { + downloadMetas, e = importer.downloadSST(ctx, regionInfo, files, rewriteRules, nil, apiVersion) + } + } + if e != nil { + return errors.Trace(e) + } + + return nil + }, utils.NewDownloadSSTBackoffer()) + + return downloadMetas, errDownload +} + +func (importer *SnapFileImporter) buildDownloadRequest( + file *backuppb.File, + rewriteRules *restoreutils.RewriteRules, + regionInfo *split.RegionInfo, + cipher *backuppb.CipherInfo, +) (*import_sstpb.DownloadRequest, import_sstpb.SSTMeta, error) { + uid := uuid.New() + id := uid[:] + // Get the rewrite rule for the file. + fileRule := restoreutils.FindMatchedRewriteRule(file, rewriteRules) + if fileRule == nil { + return nil, import_sstpb.SSTMeta{}, errors.Trace(berrors.ErrKVRewriteRuleNotFound) + } + + // For the legacy version of TiKV, we need to encode the key prefix, since in the legacy + // version, the TiKV will rewrite the key with the encoded prefix without decoding the keys in + // the SST file. For the new version of TiKV that support keyspace rewrite, we don't need to + // encode the key prefix. The TiKV will decode the keys in the SST file and rewrite the keys + // with the plain prefix and encode the keys before writing to SST. 
+ + // for the keyspace rewrite mode + rule := *fileRule + // for the legacy rewrite mode + if importer.rewriteMode == RewriteModeLegacy { + rule.OldKeyPrefix = restoreutils.EncodeKeyPrefix(fileRule.GetOldKeyPrefix()) + rule.NewKeyPrefix = restoreutils.EncodeKeyPrefix(fileRule.GetNewKeyPrefix()) + } + + sstMeta, err := getSSTMetaFromFile(id, file, regionInfo.Region, &rule, importer.rewriteMode) + if err != nil { + return nil, import_sstpb.SSTMeta{}, err + } + + req := &import_sstpb.DownloadRequest{ + Sst: *sstMeta, + StorageBackend: importer.backend, + Name: file.GetName(), + RewriteRule: rule, + CipherInfo: cipher, + StorageCacheId: importer.cacheKey, + // For the older version of TiDB, the request type will be default to `import_sstpb.RequestType_Legacy` + RequestType: import_sstpb.DownloadRequestType_Keyspace, + Context: &kvrpcpb.Context{ + ResourceControlContext: &kvrpcpb.ResourceControlContext{ + ResourceGroupName: "", // TODO, + }, + RequestSource: kvutil.BuildRequestSource(true, kv.InternalTxnBR, kvutil.ExplicitTypeBR), + }, + } + return req, *sstMeta, nil +} + +func (importer *SnapFileImporter) downloadSST( + ctx context.Context, + regionInfo *split.RegionInfo, + files []*backuppb.File, + rewriteRules *restoreutils.RewriteRules, + cipher *backuppb.CipherInfo, + apiVersion kvrpcpb.APIVersion, +) ([]*import_sstpb.SSTMeta, error) { + var mu sync.Mutex + downloadMetasMap := make(map[string]import_sstpb.SSTMeta) + resultMetasMap := make(map[string]*import_sstpb.SSTMeta) + downloadReqsMap := make(map[string]*import_sstpb.DownloadRequest) + for _, file := range files { + req, sstMeta, err := importer.buildDownloadRequest(file, rewriteRules, regionInfo, cipher) + if err != nil { + return nil, errors.Trace(err) + } + sstMeta.ApiVersion = apiVersion + downloadMetasMap[file.Name] = sstMeta + downloadReqsMap[file.Name] = req + } + + eg, ectx := errgroup.WithContext(ctx) + for _, p := range regionInfo.Region.GetPeers() { + peer := p + eg.Go(func() error { + tokenCh := importer.downloadTokensMap.acquireTokenCh(peer.GetStoreId(), importer.concurrencyPerStore) + select { + case <-ectx.Done(): + return ectx.Err() + case <-tokenCh: + } + defer func() { + importer.releaseToken(tokenCh) + }() + for _, file := range files { + req, ok := downloadReqsMap[file.Name] + if !ok { + return errors.New("not found file key for download request") + } + var err error + var resp *import_sstpb.DownloadResponse + resp, err = utils.WithRetryV2(ectx, utils.NewDownloadSSTBackoffer(), func(ctx context.Context) (*import_sstpb.DownloadResponse, error) { + dctx, cancel := context.WithTimeout(ctx, gRPCTimeOut) + defer cancel() + return importer.importClient.DownloadSST(dctx, peer.GetStoreId(), req) + }) + if err != nil { + return errors.Trace(err) + } + if resp.GetError() != nil { + return errors.Annotate(berrors.ErrKVDownloadFailed, resp.GetError().GetMessage()) + } + if resp.GetIsEmpty() { + return errors.Trace(berrors.ErrKVRangeIsEmpty) + } + + mu.Lock() + sstMeta, ok := downloadMetasMap[file.Name] + if !ok { + mu.Unlock() + return errors.Errorf("not found file %s for download sstMeta", file.Name) + } + sstMeta.Range = &import_sstpb.Range{ + Start: restoreutils.TruncateTS(resp.Range.GetStart()), + End: restoreutils.TruncateTS(resp.Range.GetEnd()), + } + resultMetasMap[file.Name] = &sstMeta + mu.Unlock() + + log.Debug("download from peer", + logutil.Region(regionInfo.Region), + logutil.File(file), + logutil.Peer(peer), + logutil.Key("resp-range-start", resp.Range.Start), + logutil.Key("resp-range-end", resp.Range.End), + 
zap.Bool("resp-isempty", resp.IsEmpty), + zap.Uint32("resp-crc32", resp.Crc32), + zap.Int("len files", len(files)), + ) + } + return nil + }) + } + if err := eg.Wait(); err != nil { + return nil, err + } + return maps.Values(resultMetasMap), nil +} + +func (importer *SnapFileImporter) downloadRawKVSST( + ctx context.Context, + regionInfo *split.RegionInfo, + files []*backuppb.File, + cipher *backuppb.CipherInfo, + apiVersion kvrpcpb.APIVersion, +) ([]*import_sstpb.SSTMeta, error) { + downloadMetas := make([]*import_sstpb.SSTMeta, 0, len(files)) + for _, file := range files { + uid := uuid.New() + id := uid[:] + // Empty rule + var rule import_sstpb.RewriteRule + sstMeta, err := getSSTMetaFromFile(id, file, regionInfo.Region, &rule, RewriteModeLegacy) + if err != nil { + return nil, err + } + + // Cut the SST file's range to fit in the restoring range. + if bytes.Compare(importer.rawStartKey, sstMeta.Range.GetStart()) > 0 { + sstMeta.Range.Start = importer.rawStartKey + } + if len(importer.rawEndKey) > 0 && + (len(sstMeta.Range.GetEnd()) == 0 || bytes.Compare(importer.rawEndKey, sstMeta.Range.GetEnd()) <= 0) { + sstMeta.Range.End = importer.rawEndKey + sstMeta.EndKeyExclusive = true + } + if bytes.Compare(sstMeta.Range.GetStart(), sstMeta.Range.GetEnd()) > 0 { + return nil, errors.Trace(berrors.ErrKVRangeIsEmpty) + } + + req := &import_sstpb.DownloadRequest{ + Sst: *sstMeta, + StorageBackend: importer.backend, + Name: file.GetName(), + RewriteRule: rule, + IsRawKv: true, + CipherInfo: cipher, + StorageCacheId: importer.cacheKey, + } + log.Debug("download SST", logutil.SSTMeta(sstMeta), logutil.Region(regionInfo.Region)) + + var atomicResp atomic.Pointer[import_sstpb.DownloadResponse] + eg, ectx := errgroup.WithContext(ctx) + for _, p := range regionInfo.Region.GetPeers() { + peer := p + eg.Go(func() error { + resp, err := importer.importClient.DownloadSST(ectx, peer.GetStoreId(), req) + if err != nil { + return errors.Trace(err) + } + if resp.GetError() != nil { + return errors.Annotate(berrors.ErrKVDownloadFailed, resp.GetError().GetMessage()) + } + if resp.GetIsEmpty() { + return errors.Trace(berrors.ErrKVRangeIsEmpty) + } + + atomicResp.Store(resp) + return nil + }) + } + + if err := eg.Wait(); err != nil { + return nil, err + } + + downloadResp := atomicResp.Load() + sstMeta.Range.Start = downloadResp.Range.GetStart() + sstMeta.Range.End = downloadResp.Range.GetEnd() + sstMeta.ApiVersion = apiVersion + downloadMetas = append(downloadMetas, sstMeta) + } + return downloadMetas, nil +} + +func (importer *SnapFileImporter) ingest( + ctx context.Context, + files []*backuppb.File, + info *split.RegionInfo, + downloadMetas []*import_sstpb.SSTMeta, +) error { + tokenCh := importer.ingestTokensMap.acquireTokenCh(info.Leader.GetStoreId(), importer.concurrencyPerStore) + select { + case <-ctx.Done(): + return ctx.Err() + case <-tokenCh: + } + defer func() { + importer.releaseToken(tokenCh) + }() + for { + ingestResp, errIngest := importer.ingestSSTs(ctx, downloadMetas, info) + if errIngest != nil { + return errors.Trace(errIngest) + } + + errPb := ingestResp.GetError() + switch { + case errPb == nil: + return nil + case errPb.NotLeader != nil: + // If error is `NotLeader`, update the region info and retry + var newInfo *split.RegionInfo + if newLeader := errPb.GetNotLeader().GetLeader(); newLeader != nil { + newInfo = &split.RegionInfo{ + Leader: newLeader, + Region: info.Region, + } + } else { + for { + // Slow path, get region from PD + newInfo, errIngest = importer.metaClient.GetRegion( + ctx, 
info.Region.GetStartKey()) + if errIngest != nil { + return errors.Trace(errIngest) + } + if newInfo != nil { + break + } + // do not get region info, wait a second and GetRegion() again. + log.Warn("ingest get region by key return nil", logutil.Region(info.Region), + logutil.Files(files), + logutil.SSTMetas(downloadMetas), + ) + time.Sleep(time.Second) + } + } + + if !split.CheckRegionEpoch(newInfo, info) { + return errors.Trace(berrors.ErrKVEpochNotMatch) + } + log.Debug("ingest sst returns not leader error, retry it", + logutil.Files(files), + logutil.SSTMetas(downloadMetas), + logutil.Region(info.Region), + zap.Stringer("newLeader", newInfo.Leader)) + info = newInfo + case errPb.EpochNotMatch != nil: + // TODO handle epoch not match error + // 1. retry download if needed + // 2. retry ingest + return errors.Trace(berrors.ErrKVEpochNotMatch) + case errPb.KeyNotInRegion != nil: + return errors.Trace(berrors.ErrKVKeyNotInRegion) + default: + // Other errors like `ServerIsBusy`, `RegionNotFound`, etc. should be retryable + return errors.Annotatef(berrors.ErrKVIngestFailed, "ingest error %s", errPb) + } + } +} + +func (importer *SnapFileImporter) ingestSSTs( + ctx context.Context, + sstMetas []*import_sstpb.SSTMeta, + regionInfo *split.RegionInfo, +) (*import_sstpb.IngestResponse, error) { + leader := regionInfo.Leader + if leader == nil { + return nil, errors.Annotatef(berrors.ErrPDLeaderNotFound, + "region id %d has no leader", regionInfo.Region.Id) + } + reqCtx := &kvrpcpb.Context{ + RegionId: regionInfo.Region.GetId(), + RegionEpoch: regionInfo.Region.GetRegionEpoch(), + Peer: leader, + ResourceControlContext: &kvrpcpb.ResourceControlContext{ + ResourceGroupName: "", // TODO, + }, + RequestSource: kvutil.BuildRequestSource(true, kv.InternalTxnBR, kvutil.ExplicitTypeBR), + } + + req := &import_sstpb.MultiIngestRequest{ + Context: reqCtx, + Ssts: sstMetas, + } + log.Debug("ingest SSTs", logutil.SSTMetas(sstMetas), logutil.Leader(leader)) + resp, err := importer.importClient.MultiIngest(ctx, leader.GetStoreId(), req) + return resp, errors.Trace(err) +} + +func isDecryptSstErr(err error) bool { + return err != nil && + strings.Contains(err.Error(), "Engine Engine") && + strings.Contains(err.Error(), "Corruption: Bad table magic number") +} diff --git a/br/pkg/restore/split/binding__failpoint_binding__.go b/br/pkg/restore/split/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..bb941977d20db --- /dev/null +++ b/br/pkg/restore/split/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package split + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/br/pkg/restore/split/client.go b/br/pkg/restore/split/client.go index 226ecc58360e0..c6c54cf0f46e6 100644 --- a/br/pkg/restore/split/client.go +++ b/br/pkg/restore/split/client.go @@ -277,7 +277,7 @@ func splitRegionWithFailpoint( keys [][]byte, isRawKv bool, ) (*kvrpcpb.SplitRegionResponse, error) { - failpoint.Inject("not-leader-error", func(injectNewLeader failpoint.Value) { + if injectNewLeader, _err_ := failpoint.Eval(_curpkg_("not-leader-error")); _err_ == nil { log.Debug("failpoint not-leader-error injected.") resp := &kvrpcpb.SplitRegionResponse{ RegionError: &errorpb.Error{ @@ -289,16 +289,16 @@ func splitRegionWithFailpoint( if 
injectNewLeader.(bool) { resp.RegionError.NotLeader.Leader = regionInfo.Leader } - failpoint.Return(resp, nil) - }) - failpoint.Inject("somewhat-retryable-error", func() { + return resp, nil + } + if _, _err_ := failpoint.Eval(_curpkg_("somewhat-retryable-error")); _err_ == nil { log.Debug("failpoint somewhat-retryable-error injected.") - failpoint.Return(&kvrpcpb.SplitRegionResponse{ + return &kvrpcpb.SplitRegionResponse{ RegionError: &errorpb.Error{ ServerIsBusy: &errorpb.ServerIsBusy{}, }, - }, nil) - }) + }, nil + } return client.SplitRegion(ctx, &kvrpcpb.SplitRegionRequest{ Context: &kvrpcpb.Context{ RegionId: regionInfo.Region.Id, @@ -646,9 +646,9 @@ func (bo *splitBackoffer) Attempt() int { } func (c *pdClient) SplitWaitAndScatter(ctx context.Context, region *RegionInfo, keys [][]byte) ([]*RegionInfo, error) { - failpoint.Inject("failToSplit", func(_ failpoint.Value) { - failpoint.Return(nil, errors.New("retryable error")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("failToSplit")); _err_ == nil { + return nil, errors.New("retryable error") + } if len(keys) == 0 { return []*RegionInfo{region}, nil } @@ -764,10 +764,10 @@ func (c *pdClient) GetOperator(ctx context.Context, regionID uint64) (*pdpb.GetO } func (c *pdClient) ScanRegions(ctx context.Context, key, endKey []byte, limit int) ([]*RegionInfo, error) { - failpoint.Inject("no-leader-error", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("no-leader-error")); _err_ == nil { logutil.CL(ctx).Debug("failpoint no-leader-error injected.") - failpoint.Return(nil, status.Error(codes.Unavailable, "not leader")) - }) + return nil, status.Error(codes.Unavailable, "not leader") + } //nolint:staticcheck regions, err := c.client.ScanRegions(ctx, key, endKey, limit) diff --git a/br/pkg/restore/split/client.go__failpoint_stash__ b/br/pkg/restore/split/client.go__failpoint_stash__ new file mode 100644 index 0000000000000..226ecc58360e0 --- /dev/null +++ b/br/pkg/restore/split/client.go__failpoint_stash__ @@ -0,0 +1,1067 @@ +// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. + +package split + +import ( + "bytes" + "context" + "crypto/tls" + "slices" + "strconv" + "strings" + "sync" + "time" + + "github.com/docker/go-units" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/errorpb" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/kvproto/pkg/tikvpb" + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/conn/util" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/config" + brlog "github.com/pingcap/tidb/pkg/lightning/log" + tidbutil "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/intest" + pd "github.com/tikv/pd/client" + pdhttp "github.com/tikv/pd/client/http" + "go.uber.org/multierr" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" +) + +const ( + splitRegionMaxRetryTime = 4 +) + +var ( + // the max total key size in a split region batch. 
+ // our threshold should be smaller than TiKV's raft max entry size(default is 8MB). + maxBatchSplitSize = 6 * units.MiB +) + +// SplitClient is an external client used by RegionSplitter. +type SplitClient interface { + // GetStore gets a store by a store id. + GetStore(ctx context.Context, storeID uint64) (*metapb.Store, error) + // GetRegion gets a region which includes a specified key. + GetRegion(ctx context.Context, key []byte) (*RegionInfo, error) + // GetRegionByID gets a region by a region id. + GetRegionByID(ctx context.Context, regionID uint64) (*RegionInfo, error) + // SplitKeysAndScatter splits the related regions of the keys and scatters the + // new regions. It returns the new regions that need to be called with + // WaitRegionsScattered. + SplitKeysAndScatter(ctx context.Context, sortedSplitKeys [][]byte) ([]*RegionInfo, error) + + // SplitWaitAndScatter splits a region from a batch of keys, waits for the split + // is finished, and scatters the new regions. It will return the original region, + // new regions and error. The input keys should not be encoded. + // + // The split step has a few retry times. If it meets error, the error is returned + // directly. + // + // The split waiting step has a backoff retry logic, if split has made progress, + // it will not increase the retry counter. Otherwise, it will retry for about 1h. + // If the retry is timeout, it will log a warning and continue. + // + // The scatter step has a few retry times. If it meets error, it will log a + // warning and continue. + // TODO(lance6716): remove this function in interface after BR uses SplitKeysAndScatter. + SplitWaitAndScatter(ctx context.Context, region *RegionInfo, keys [][]byte) ([]*RegionInfo, error) + // GetOperator gets the status of operator of the specified region. + GetOperator(ctx context.Context, regionID uint64) (*pdpb.GetOperatorResponse, error) + // ScanRegions gets a list of regions, starts from the region that contains key. + // Limit limits the maximum number of regions returned. + ScanRegions(ctx context.Context, key, endKey []byte, limit int) ([]*RegionInfo, error) + // GetPlacementRule loads a placement rule from PD. + GetPlacementRule(ctx context.Context, groupID, ruleID string) (*pdhttp.Rule, error) + // SetPlacementRule insert or update a placement rule to PD. + SetPlacementRule(ctx context.Context, rule *pdhttp.Rule) error + // DeletePlacementRule removes a placement rule from PD. + DeletePlacementRule(ctx context.Context, groupID, ruleID string) error + // SetStoresLabel add or update specified label of stores. If labelValue + // is empty, it clears the label. + SetStoresLabel(ctx context.Context, stores []uint64, labelKey, labelValue string) error + // WaitRegionsScattered waits for an already started scatter region action to + // finish. Internally it will backoff and retry at the maximum internal of 2 + // seconds. If the scatter makes progress during the retry, it will not decrease + // the retry counter. If there's always no progress, it will retry for about 1h. + // Caller can set the context timeout to control the max waiting time. + // + // The first return value is always the number of regions that are not finished + // scattering no matter what the error is. + WaitRegionsScattered(ctx context.Context, regionInfos []*RegionInfo) (notFinished int, err error) +} + +// pdClient is a wrapper of pd client, can be used by RegionSplitter. 
+type pdClient struct { + mu sync.Mutex + client pd.Client + httpCli pdhttp.Client + tlsConf *tls.Config + storeCache map[uint64]*metapb.Store + + // FIXME when config changed during the lifetime of pdClient, + // this may mislead the scatter. + needScatterVal bool + needScatterInit sync.Once + + isRawKv bool + onSplit func(key [][]byte) + splitConcurrency int + splitBatchKeyCnt int +} + +type ClientOptionalParameter func(*pdClient) + +// WithRawKV sets the client to use raw kv mode. +func WithRawKV() ClientOptionalParameter { + return func(c *pdClient) { + c.isRawKv = true + } +} + +// WithOnSplit sets a callback function to be called after each split. +func WithOnSplit(onSplit func(key [][]byte)) ClientOptionalParameter { + return func(c *pdClient) { + c.onSplit = onSplit + } +} + +// NewClient creates a SplitClient. +// +// splitBatchKeyCnt controls how many keys are sent to TiKV in a batch in split +// region API. splitConcurrency controls how many regions are split concurrently. +func NewClient( + client pd.Client, + httpCli pdhttp.Client, + tlsConf *tls.Config, + splitBatchKeyCnt int, + splitConcurrency int, + opts ...ClientOptionalParameter, +) SplitClient { + cli := &pdClient{ + client: client, + httpCli: httpCli, + tlsConf: tlsConf, + storeCache: make(map[uint64]*metapb.Store), + splitBatchKeyCnt: splitBatchKeyCnt, + splitConcurrency: splitConcurrency, + } + for _, opt := range opts { + opt(cli) + } + return cli +} + +func (c *pdClient) needScatter(ctx context.Context) bool { + c.needScatterInit.Do(func() { + var err error + c.needScatterVal, err = c.checkNeedScatter(ctx) + if err != nil { + log.Warn( + "failed to check whether need to scatter, use permissive strategy: always scatter", + logutil.ShortError(err)) + c.needScatterVal = true + } + if !c.needScatterVal { + log.Info("skipping scatter because the replica number isn't less than store count.") + } + }) + return c.needScatterVal +} + +func (c *pdClient) scatterRegions(ctx context.Context, newRegions []*RegionInfo) error { + log.Info("scatter regions", zap.Int("regions", len(newRegions))) + // the retry is for the temporary network errors during sending request. + return utils.WithRetry(ctx, func() error { + err := c.tryScatterRegions(ctx, newRegions) + if isUnsupportedError(err) { + log.Warn("batch scatter isn't supported, rollback to old method", logutil.ShortError(err)) + c.scatterRegionsSequentially( + ctx, newRegions, + // backoff about 6s, or we give up scattering this region. 
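+				// (Sanity check on "about 6s": with Attempts: 7 and BaseBackoff: 100ms,
+				// exponentialBackoff below yields 100, 200, 400, 800, 1600, 3200 ms,
+				// roughly 6.3s of total sleep before the attempts run out.)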
+ &ExponentialBackoffer{ + Attempts: 7, + BaseBackoff: 100 * time.Millisecond, + }) + return nil + } + return err + }, &ExponentialBackoffer{Attempts: 3, BaseBackoff: 500 * time.Millisecond}) +} + +func (c *pdClient) tryScatterRegions(ctx context.Context, regionInfo []*RegionInfo) error { + regionsID := make([]uint64, 0, len(regionInfo)) + for _, v := range regionInfo { + regionsID = append(regionsID, v.Region.Id) + log.Debug("scattering regions", logutil.Key("start", v.Region.StartKey), + logutil.Key("end", v.Region.EndKey), + zap.Uint64("id", v.Region.Id)) + } + resp, err := c.client.ScatterRegions(ctx, regionsID, pd.WithSkipStoreLimit()) + if err != nil { + return err + } + if pbErr := resp.GetHeader().GetError(); pbErr.GetType() != pdpb.ErrorType_OK { + return errors.Annotatef(berrors.ErrPDInvalidResponse, + "pd returns error during batch scattering: %s", pbErr) + } + return nil +} + +func (c *pdClient) GetStore(ctx context.Context, storeID uint64) (*metapb.Store, error) { + c.mu.Lock() + defer c.mu.Unlock() + store, ok := c.storeCache[storeID] + if ok { + return store, nil + } + store, err := c.client.GetStore(ctx, storeID) + if err != nil { + return nil, errors.Trace(err) + } + c.storeCache[storeID] = store + return store, nil +} + +func (c *pdClient) GetRegion(ctx context.Context, key []byte) (*RegionInfo, error) { + region, err := c.client.GetRegion(ctx, key) + if err != nil { + return nil, errors.Trace(err) + } + if region == nil { + return nil, nil + } + return &RegionInfo{ + Region: region.Meta, + Leader: region.Leader, + }, nil +} + +func (c *pdClient) GetRegionByID(ctx context.Context, regionID uint64) (*RegionInfo, error) { + region, err := c.client.GetRegionByID(ctx, regionID) + if err != nil { + return nil, errors.Trace(err) + } + if region == nil { + return nil, nil + } + return &RegionInfo{ + Region: region.Meta, + Leader: region.Leader, + PendingPeers: region.PendingPeers, + DownPeers: region.DownPeers, + }, nil +} + +func splitRegionWithFailpoint( + ctx context.Context, + regionInfo *RegionInfo, + peer *metapb.Peer, + client tikvpb.TikvClient, + keys [][]byte, + isRawKv bool, +) (*kvrpcpb.SplitRegionResponse, error) { + failpoint.Inject("not-leader-error", func(injectNewLeader failpoint.Value) { + log.Debug("failpoint not-leader-error injected.") + resp := &kvrpcpb.SplitRegionResponse{ + RegionError: &errorpb.Error{ + NotLeader: &errorpb.NotLeader{ + RegionId: regionInfo.Region.Id, + }, + }, + } + if injectNewLeader.(bool) { + resp.RegionError.NotLeader.Leader = regionInfo.Leader + } + failpoint.Return(resp, nil) + }) + failpoint.Inject("somewhat-retryable-error", func() { + log.Debug("failpoint somewhat-retryable-error injected.") + failpoint.Return(&kvrpcpb.SplitRegionResponse{ + RegionError: &errorpb.Error{ + ServerIsBusy: &errorpb.ServerIsBusy{}, + }, + }, nil) + }) + return client.SplitRegion(ctx, &kvrpcpb.SplitRegionRequest{ + Context: &kvrpcpb.Context{ + RegionId: regionInfo.Region.Id, + RegionEpoch: regionInfo.Region.RegionEpoch, + Peer: peer, + }, + SplitKeys: keys, + IsRawKv: isRawKv, + }) +} + +func (c *pdClient) sendSplitRegionRequest( + ctx context.Context, regionInfo *RegionInfo, keys [][]byte, +) (*kvrpcpb.SplitRegionResponse, error) { + var splitErrors error + for i := 0; i < splitRegionMaxRetryTime; i++ { + retry, result, err := sendSplitRegionRequest(ctx, c, regionInfo, keys, &splitErrors, i) + if retry { + continue + } + if err != nil { + return nil, multierr.Append(splitErrors, err) + } + if result != nil { + return result, nil + } + return nil, 
errors.Trace(splitErrors) + } + return nil, errors.Trace(splitErrors) +} + +func sendSplitRegionRequest( + ctx context.Context, + c *pdClient, + regionInfo *RegionInfo, + keys [][]byte, + splitErrors *error, + retry int, +) (bool, *kvrpcpb.SplitRegionResponse, error) { + if intest.InTest { + mockCli, ok := c.client.(*MockPDClientForSplit) + if ok { + return mockCli.SplitRegion(regionInfo, keys, c.isRawKv) + } + } + var peer *metapb.Peer + // scanRegions may return empty Leader in https://github.com/tikv/pd/blob/v4.0.8/server/grpc_service.go#L524 + // so wee also need check Leader.Id != 0 + if regionInfo.Leader != nil && regionInfo.Leader.Id != 0 { + peer = regionInfo.Leader + } else { + if len(regionInfo.Region.Peers) == 0 { + return false, nil, + errors.Annotatef(berrors.ErrRestoreNoPeer, "region[%d] doesn't have any peer", + regionInfo.Region.GetId()) + } + peer = regionInfo.Region.Peers[0] + } + storeID := peer.GetStoreId() + store, err := c.GetStore(ctx, storeID) + if err != nil { + return false, nil, err + } + opt := grpc.WithTransportCredentials(insecure.NewCredentials()) + if c.tlsConf != nil { + opt = grpc.WithTransportCredentials(credentials.NewTLS(c.tlsConf)) + } + conn, err := grpc.Dial(store.GetAddress(), opt, + config.DefaultGrpcKeepaliveParams) + if err != nil { + return false, nil, err + } + defer conn.Close() + client := tikvpb.NewTikvClient(conn) + resp, err := splitRegionWithFailpoint(ctx, regionInfo, peer, client, keys, c.isRawKv) + if err != nil { + return false, nil, err + } + if resp.RegionError != nil { + log.Warn("fail to split region", + logutil.Region(regionInfo.Region), + logutil.Keys(keys), + zap.Stringer("regionErr", resp.RegionError)) + *splitErrors = multierr.Append(*splitErrors, + errors.Annotatef(berrors.ErrRestoreSplitFailed, "split region failed: err=%v", resp.RegionError)) + if nl := resp.RegionError.NotLeader; nl != nil { + if leader := nl.GetLeader(); leader != nil { + regionInfo.Leader = leader + } else { + newRegionInfo, findLeaderErr := c.GetRegionByID(ctx, nl.RegionId) + if findLeaderErr != nil { + return false, nil, findLeaderErr + } + if !CheckRegionEpoch(newRegionInfo, regionInfo) { + return false, nil, berrors.ErrKVEpochNotMatch + } + log.Info("find new leader", zap.Uint64("new leader", newRegionInfo.Leader.Id)) + regionInfo = newRegionInfo + } + log.Info("split region meet not leader error, retrying", + zap.Int("retry times", retry), + zap.Uint64("regionID", regionInfo.Region.Id), + zap.Any("new leader", regionInfo.Leader), + ) + return true, nil, nil + } + // TODO: we don't handle RegionNotMatch and RegionNotFound here, + // because I think we don't have enough information to retry. + // But maybe we can handle them here by some information the error itself provides. + if resp.RegionError.ServerIsBusy != nil || + resp.RegionError.StaleCommand != nil { + log.Warn("a error occurs on split region", + zap.Int("retry times", retry), + zap.Uint64("regionID", regionInfo.Region.Id), + zap.String("error", resp.RegionError.Message), + zap.Any("error verbose", resp.RegionError), + ) + return true, nil, nil + } + return false, nil, nil + } + return false, resp, nil +} + +// batchSplitRegionsWithOrigin calls the batch split region API and groups the +// returned regions into two groups: the region with the same ID as the origin, +// and the other regions. The former does not need to be scattered while the +// latter need to be scattered. 
+// +// Depending on the TiKV configuration right-derive-when-split, the origin region +// can be the first return region or the last return region. +func (c *pdClient) batchSplitRegionsWithOrigin( + ctx context.Context, regionInfo *RegionInfo, keys [][]byte, +) (*RegionInfo, []*RegionInfo, error) { + resp, err := c.sendSplitRegionRequest(ctx, regionInfo, keys) + if err != nil { + return nil, nil, errors.Trace(err) + } + + regions := resp.GetRegions() + newRegionInfos := make([]*RegionInfo, 0, len(regions)) + var originRegion *RegionInfo + for _, region := range regions { + var leader *metapb.Peer + + // Assume the leaders will be at the same store. + if regionInfo.Leader != nil { + for _, p := range region.GetPeers() { + if p.GetStoreId() == regionInfo.Leader.GetStoreId() { + leader = p + break + } + } + } + // original region + if region.GetId() == regionInfo.Region.GetId() { + originRegion = &RegionInfo{ + Region: region, + Leader: leader, + } + continue + } + newRegionInfos = append(newRegionInfos, &RegionInfo{ + Region: region, + Leader: leader, + }) + } + return originRegion, newRegionInfos, nil +} + +func (c *pdClient) waitRegionsSplit(ctx context.Context, newRegions []*RegionInfo) error { + backoffer := NewBackoffMayNotCountBackoffer() + needRecheck := make([]*RegionInfo, 0, len(newRegions)) + return utils.WithRetryReturnLastErr(ctx, func() error { + needRecheck = needRecheck[:0] + + for _, r := range newRegions { + regionID := r.Region.GetId() + + ok, err := c.hasHealthyRegion(ctx, regionID) + if !ok || err != nil { + if err != nil { + brlog.FromContext(ctx).Warn( + "wait for split failed", + zap.Uint64("regionID", regionID), + zap.Error(err), + ) + } + needRecheck = append(needRecheck, r) + } + } + + if len(needRecheck) == 0 { + return nil + } + + backoffErr := ErrBackoff + // if made progress in this round, don't increase the retryCnt + if len(needRecheck) < len(newRegions) { + backoffErr = ErrBackoffAndDontCount + } + newRegions = slices.Clone(needRecheck) + + return errors.Annotatef( + backoffErr, + "WaitRegionsSplit not finished, needRecheck: %d, the first unfinished region: %s", + len(needRecheck), needRecheck[0].Region.String(), + ) + }, backoffer) +} + +func (c *pdClient) hasHealthyRegion(ctx context.Context, regionID uint64) (bool, error) { + regionInfo, err := c.GetRegionByID(ctx, regionID) + if err != nil { + return false, errors.Trace(err) + } + // the region hasn't get ready. + if regionInfo == nil { + return false, nil + } + + // check whether the region is healthy and report. + // TODO: the log may be too verbose. we should use Prometheus metrics once it get ready for BR. + for _, peer := range regionInfo.PendingPeers { + log.Debug("unhealthy region detected", logutil.Peer(peer), zap.String("type", "pending")) + } + for _, peer := range regionInfo.DownPeers { + log.Debug("unhealthy region detected", logutil.Peer(peer), zap.String("type", "down")) + } + // we ignore down peers for they are (normally) hard to be fixed in reasonable time. + // (or once there is a peer down, we may get stuck at waiting region get ready.) + return len(regionInfo.PendingPeers) == 0, nil +} + +func (c *pdClient) SplitKeysAndScatter(ctx context.Context, sortedSplitKeys [][]byte) ([]*RegionInfo, error) { + if len(sortedSplitKeys) == 0 { + return nil, nil + } + // we need to find the regions that contain the split keys. However, the scan + // region API accepts a key range [start, end) where end key is exclusive, and if + // sortedSplitKeys length is 1, scan region may return empty result. 
So we + // increase the end key a bit. If the end key is on the region boundaries, it + // will be skipped by getSplitKeysOfRegions. + scanStart := codec.EncodeBytesExt(nil, sortedSplitKeys[0], c.isRawKv) + lastKey := kv.Key(sortedSplitKeys[len(sortedSplitKeys)-1]) + if len(lastKey) > 0 { + lastKey = lastKey.Next() + } + scanEnd := codec.EncodeBytesExt(nil, lastKey, c.isRawKv) + + // mu protects ret, retrySplitKeys, lastSplitErr + mu := sync.Mutex{} + ret := make([]*RegionInfo, 0, len(sortedSplitKeys)+1) + retrySplitKeys := make([][]byte, 0, len(sortedSplitKeys)) + var lastSplitErr error + + err := utils.WithRetryReturnLastErr(ctx, func() error { + ret = ret[:0] + + if len(retrySplitKeys) > 0 { + scanStart = codec.EncodeBytesExt(nil, retrySplitKeys[0], c.isRawKv) + lastKey2 := kv.Key(retrySplitKeys[len(retrySplitKeys)-1]) + scanEnd = codec.EncodeBytesExt(nil, lastKey2.Next(), c.isRawKv) + } + regions, err := PaginateScanRegion(ctx, c, scanStart, scanEnd, ScanRegionPaginationLimit) + if err != nil { + return err + } + log.Info("paginate scan regions", + zap.Int("count", len(regions)), + logutil.Key("start", scanStart), + logutil.Key("end", scanEnd)) + + allSplitKeys := sortedSplitKeys + if len(retrySplitKeys) > 0 { + allSplitKeys = retrySplitKeys + retrySplitKeys = retrySplitKeys[:0] + } + splitKeyMap := getSplitKeysOfRegions(allSplitKeys, regions, c.isRawKv) + workerPool := tidbutil.NewWorkerPool(uint(c.splitConcurrency), "split keys") + eg, eCtx := errgroup.WithContext(ctx) + for region, splitKeys := range splitKeyMap { + region := region + splitKeys := splitKeys + workerPool.ApplyOnErrorGroup(eg, func() error { + // TODO(lance6716): add error handling to retry from scan or retry from split + newRegions, err2 := c.SplitWaitAndScatter(eCtx, region, splitKeys) + if err2 != nil { + if common.IsContextCanceledError(err2) { + return err2 + } + log.Warn("split and scatter region meet error, will retry", + zap.Uint64("region_id", region.Region.Id), + zap.Error(err2)) + mu.Lock() + retrySplitKeys = append(retrySplitKeys, splitKeys...) + lastSplitErr = err2 + mu.Unlock() + return nil + } + + if len(newRegions) != len(splitKeys) { + log.Warn("split key count and new region count mismatch", + zap.Int("new region count", len(newRegions)), + zap.Int("split key count", len(splitKeys))) + } + mu.Lock() + ret = append(ret, newRegions...) 
+ mu.Unlock() + return nil + }) + } + if err2 := eg.Wait(); err2 != nil { + return err2 + } + if len(retrySplitKeys) == 0 { + return nil + } + slices.SortFunc(retrySplitKeys, bytes.Compare) + return lastSplitErr + }, newSplitBackoffer()) + return ret, errors.Trace(err) +} + +type splitBackoffer struct { + state utils.RetryState +} + +func newSplitBackoffer() *splitBackoffer { + return &splitBackoffer{ + state: utils.InitialRetryState(SplitRetryTimes, SplitRetryInterval, SplitMaxRetryInterval), + } +} + +func (bo *splitBackoffer) NextBackoff(err error) time.Duration { + if berrors.ErrInvalidRange.Equal(err) { + bo.state.GiveUp() + return 0 + } + return bo.state.ExponentialBackoff() +} + +func (bo *splitBackoffer) Attempt() int { + return bo.state.Attempt() +} + +func (c *pdClient) SplitWaitAndScatter(ctx context.Context, region *RegionInfo, keys [][]byte) ([]*RegionInfo, error) { + failpoint.Inject("failToSplit", func(_ failpoint.Value) { + failpoint.Return(nil, errors.New("retryable error")) + }) + if len(keys) == 0 { + return []*RegionInfo{region}, nil + } + + var ( + start, end = 0, 0 + batchSize = 0 + newRegions = make([]*RegionInfo, 0, len(keys)) + ) + + for end <= len(keys) { + if end == len(keys) || + batchSize+len(keys[end]) > maxBatchSplitSize || + end-start >= c.splitBatchKeyCnt { + // split, wait and scatter for this batch + originRegion, newRegionsOfBatch, err := c.batchSplitRegionsWithOrigin(ctx, region, keys[start:end]) + if err != nil { + return nil, errors.Trace(err) + } + err = c.waitRegionsSplit(ctx, newRegionsOfBatch) + if err != nil { + brlog.FromContext(ctx).Warn( + "wait regions split failed, will continue anyway", + zap.Error(err), + ) + } + if err = ctx.Err(); err != nil { + return nil, errors.Trace(err) + } + err = c.scatterRegions(ctx, newRegionsOfBatch) + if err != nil { + brlog.FromContext(ctx).Warn( + "scatter regions failed, will continue anyway", + zap.Error(err), + ) + } + if c.onSplit != nil { + c.onSplit(keys[start:end]) + } + + // the region with the max start key is the region need to be further split, + // depending on the origin region is the first region or last region, we need to + // compare the origin region and the last one of new regions. + lastNewRegion := newRegionsOfBatch[len(newRegionsOfBatch)-1] + if bytes.Compare(originRegion.Region.StartKey, lastNewRegion.Region.StartKey) < 0 { + region = lastNewRegion + } else { + region = originRegion + } + newRegions = append(newRegions, newRegionsOfBatch...) 
+ batchSize = 0 + start = end + } + + if end < len(keys) { + batchSize += len(keys[end]) + } + end++ + } + + return newRegions, errors.Trace(ctx.Err()) +} + +func (c *pdClient) getStoreCount(ctx context.Context) (int, error) { + stores, err := util.GetAllTiKVStores(ctx, c.client, util.SkipTiFlash) + if err != nil { + return 0, err + } + return len(stores), err +} + +func (c *pdClient) getMaxReplica(ctx context.Context) (int, error) { + resp, err := c.httpCli.GetReplicateConfig(ctx) + if err != nil { + return 0, errors.Trace(err) + } + key := "max-replicas" + val, ok := resp[key] + if !ok { + return 0, errors.Errorf("key %s not found in response %v", key, resp) + } + return int(val.(float64)), nil +} + +func (c *pdClient) checkNeedScatter(ctx context.Context) (bool, error) { + storeCount, err := c.getStoreCount(ctx) + if err != nil { + return false, err + } + maxReplica, err := c.getMaxReplica(ctx) + if err != nil { + return false, err + } + log.Info("checking whether need to scatter", zap.Int("store", storeCount), zap.Int("max-replica", maxReplica)) + // Skipping scatter may lead to leader unbalanced, + // currently, we skip scatter only when: + // 1. max-replica > store-count (Probably a misconfigured or playground cluster.) + // 2. store-count == 1 (No meaning for scattering.) + // We can still omit scatter when `max-replica == store-count`, if we create a BalanceLeader operator here, + // however, there isn't evidence for transform leader is much faster than scattering empty regions. + return storeCount >= maxReplica && storeCount > 1, nil +} + +func (c *pdClient) scatterRegion(ctx context.Context, regionInfo *RegionInfo) error { + if !c.needScatter(ctx) { + return nil + } + return c.client.ScatterRegion(ctx, regionInfo.Region.GetId()) +} + +func (c *pdClient) GetOperator(ctx context.Context, regionID uint64) (*pdpb.GetOperatorResponse, error) { + return c.client.GetOperator(ctx, regionID) +} + +func (c *pdClient) ScanRegions(ctx context.Context, key, endKey []byte, limit int) ([]*RegionInfo, error) { + failpoint.Inject("no-leader-error", func(_ failpoint.Value) { + logutil.CL(ctx).Debug("failpoint no-leader-error injected.") + failpoint.Return(nil, status.Error(codes.Unavailable, "not leader")) + }) + + //nolint:staticcheck + regions, err := c.client.ScanRegions(ctx, key, endKey, limit) + if err != nil { + return nil, errors.Trace(err) + } + regionInfos := make([]*RegionInfo, 0, len(regions)) + for _, region := range regions { + regionInfos = append(regionInfos, &RegionInfo{ + Region: region.Meta, + Leader: region.Leader, + }) + } + return regionInfos, nil +} + +func (c *pdClient) GetPlacementRule(ctx context.Context, groupID, ruleID string) (*pdhttp.Rule, error) { + resp, err := c.httpCli.GetPlacementRule(ctx, groupID, ruleID) + return resp, errors.Trace(err) +} + +func (c *pdClient) SetPlacementRule(ctx context.Context, rule *pdhttp.Rule) error { + return c.httpCli.SetPlacementRule(ctx, rule) +} + +func (c *pdClient) DeletePlacementRule(ctx context.Context, groupID, ruleID string) error { + return c.httpCli.DeletePlacementRule(ctx, groupID, ruleID) +} + +func (c *pdClient) SetStoresLabel( + ctx context.Context, stores []uint64, labelKey, labelValue string, +) error { + m := map[string]string{labelKey: labelValue} + for _, id := range stores { + err := c.httpCli.SetStoreLabels(ctx, int64(id), m) + if err != nil { + return errors.Trace(err) + } + } + return nil +} + +func (c *pdClient) scatterRegionsSequentially(ctx context.Context, newRegions []*RegionInfo, backoffer utils.Backoffer) 
{ + newRegionSet := make(map[uint64]*RegionInfo, len(newRegions)) + for _, newRegion := range newRegions { + newRegionSet[newRegion.Region.Id] = newRegion + } + + if err := utils.WithRetry(ctx, func() error { + log.Info("trying to scatter regions...", zap.Int("remain", len(newRegionSet))) + var errs error + for _, region := range newRegionSet { + err := c.scatterRegion(ctx, region) + if err == nil { + // it is safe according to the Go language spec. + delete(newRegionSet, region.Region.Id) + } else if !PdErrorCanRetry(err) { + log.Warn("scatter meet error cannot be retried, skipping", + logutil.ShortError(err), + logutil.Region(region.Region), + ) + delete(newRegionSet, region.Region.Id) + } + errs = multierr.Append(errs, err) + } + return errs + }, backoffer); err != nil { + log.Warn("Some regions haven't been scattered because errors.", + zap.Int("count", len(newRegionSet)), + // if all region are failed to scatter, the short error might also be verbose... + logutil.ShortError(err), + logutil.AbbreviatedArray("failed-regions", newRegionSet, func(i any) []string { + m := i.(map[uint64]*RegionInfo) + result := make([]string, 0, len(m)) + for id := range m { + result = append(result, strconv.Itoa(int(id))) + } + return result + }), + ) + } +} + +func (c *pdClient) isScatterRegionFinished( + ctx context.Context, + regionID uint64, +) (scatterDone bool, needRescatter bool, scatterErr error) { + resp, err := c.GetOperator(ctx, regionID) + if err != nil { + if common.IsRetryableError(err) { + // retry in the next cycle + return false, false, nil + } + return false, false, errors.Trace(err) + } + return isScatterRegionFinished(resp) +} + +func (c *pdClient) WaitRegionsScattered(ctx context.Context, regions []*RegionInfo) (int, error) { + var ( + backoffer = NewBackoffMayNotCountBackoffer() + retryCnt = -1 + needRescatter = make([]*RegionInfo, 0, len(regions)) + needRecheck = make([]*RegionInfo, 0, len(regions)) + ) + + err := utils.WithRetryReturnLastErr(ctx, func() error { + retryCnt++ + loggedInThisRound := false + needRecheck = needRecheck[:0] + needRescatter = needRescatter[:0] + + for i, region := range regions { + regionID := region.Region.GetId() + + if retryCnt > 10 && !loggedInThisRound { + loggedInThisRound = true + resp, err := c.GetOperator(ctx, regionID) + brlog.FromContext(ctx).Info( + "retried many times to wait for scattering regions, checking operator", + zap.Int("retryCnt", retryCnt), + zap.Uint64("firstRegionID", regionID), + zap.Stringer("response", resp), + zap.Error(err), + ) + } + + ok, rescatter, err := c.isScatterRegionFinished(ctx, regionID) + if err != nil { + if !common.IsRetryableError(err) { + brlog.FromContext(ctx).Warn( + "wait for scatter region encountered non-retryable error", + logutil.Region(region.Region), + zap.Error(err), + ) + needRecheck = append(needRecheck, regions[i:]...) 
+ return err + } + // if meet retryable error, recheck this region in next round + brlog.FromContext(ctx).Warn( + "wait for scatter region encountered error, will retry again", + logutil.Region(region.Region), + zap.Error(err), + ) + needRecheck = append(needRecheck, region) + continue + } + + if ok { + continue + } + // not finished scattered, check again in next round + needRecheck = append(needRecheck, region) + + if rescatter { + needRescatter = append(needRescatter, region) + } + } + + if len(needRecheck) == 0 { + return nil + } + + backoffErr := ErrBackoff + // if made progress in this round, don't increase the retryCnt + if len(needRecheck) < len(regions) { + backoffErr = ErrBackoffAndDontCount + } + + regions = slices.Clone(needRecheck) + + if len(needRescatter) > 0 { + scatterErr := c.scatterRegions(ctx, needRescatter) + if scatterErr != nil { + if !common.IsRetryableError(scatterErr) { + return scatterErr + } + + return errors.Annotate(backoffErr, scatterErr.Error()) + } + } + return errors.Annotatef( + backoffErr, + "scatter region not finished, retryCnt: %d, needRecheck: %d, needRescatter: %d, the first unfinished region: %s", + retryCnt, len(needRecheck), len(needRescatter), needRecheck[0].Region.String(), + ) + }, backoffer) + + return len(needRecheck), err +} + +// isScatterRegionFinished checks whether the scatter region operator is +// finished. +func isScatterRegionFinished(resp *pdpb.GetOperatorResponse) ( + scatterDone bool, + needRescatter bool, + scatterErr error, +) { + // Heartbeat may not be sent to PD + if respErr := resp.GetHeader().GetError(); respErr != nil { + if respErr.GetType() == pdpb.ErrorType_REGION_NOT_FOUND { + return true, false, nil + } + return false, false, errors.Annotatef( + berrors.ErrPDInvalidResponse, + "get operator error: %s, error message: %s", + respErr.GetType(), + respErr.GetMessage(), + ) + } + // that 'scatter-operator' has finished + if string(resp.GetDesc()) != "scatter-region" { + return true, false, nil + } + switch resp.GetStatus() { + case pdpb.OperatorStatus_SUCCESS: + return true, false, nil + case pdpb.OperatorStatus_RUNNING: + return false, false, nil + default: + return false, true, nil + } +} + +// CheckRegionEpoch check region epoch. +func CheckRegionEpoch(_new, _old *RegionInfo) bool { + return _new.Region.GetId() == _old.Region.GetId() && + _new.Region.GetRegionEpoch().GetVersion() == _old.Region.GetRegionEpoch().GetVersion() && + _new.Region.GetRegionEpoch().GetConfVer() == _old.Region.GetRegionEpoch().GetConfVer() +} + +// ExponentialBackoffer trivially retry any errors it meets. +// It's useful when the caller has handled the errors but +// only want to a more semantic backoff implementation. +type ExponentialBackoffer struct { + Attempts int + BaseBackoff time.Duration +} + +func (b *ExponentialBackoffer) exponentialBackoff() time.Duration { + bo := b.BaseBackoff + b.Attempts-- + if b.Attempts == 0 { + return 0 + } + b.BaseBackoff *= 2 + return bo +} + +// PdErrorCanRetry when pd error retry. +func PdErrorCanRetry(err error) bool { + // There are 3 type of reason that PD would reject a `scatter` request: + // (1) region %d has no leader + // (2) region %d is hot + // (3) region %d is not fully replicated + // + // (2) shouldn't happen in a recently splitted region. + // (1) and (3) might happen, and should be retried. 
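+	// For example (a sketch of a typical rejection, not a verbatim PD message), an
+	// error whose status message is "region 42 is not fully replicated" is treated
+	// as retryable by the substring checks below, while "region 42 is hot" is not.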
+ grpcErr := status.Convert(err) + if grpcErr == nil { + return false + } + return strings.Contains(grpcErr.Message(), "is not fully replicated") || + strings.Contains(grpcErr.Message(), "has no leader") +} + +// NextBackoff returns a duration to wait before retrying again. +func (b *ExponentialBackoffer) NextBackoff(error) time.Duration { + // trivially exponential back off, because we have handled the error at upper level. + return b.exponentialBackoff() +} + +// Attempt returns the remain attempt times +func (b *ExponentialBackoffer) Attempt() int { + return b.Attempts +} + +// isUnsupportedError checks whether we should fallback to ScatterRegion API when meeting the error. +func isUnsupportedError(err error) bool { + s, ok := status.FromError(errors.Cause(err)) + if !ok { + // Not a gRPC error. Something other went wrong. + return false + } + // In two conditions, we fallback to ScatterRegion: + // (1) If the RPC endpoint returns UNIMPLEMENTED. (This is just for making test cases not be so magic.) + // (2) If the Message is "region 0 not found": + // In fact, PD reuses the gRPC endpoint `ScatterRegion` for the batch version of scattering. + // When the request contains the field `regionIDs`, it would use the batch version, + // Otherwise, it uses the old version and scatter the region with `regionID` in the request. + // When facing 4.x, BR(which uses v5.x PD clients and call `ScatterRegions`!) would set `regionIDs` + // which would be ignored by protocol buffers, and leave the `regionID` be zero. + // Then the older version of PD would try to search the region with ID 0. + // (Then it consistently fails, and returns "region 0 not found".) + return s.Code() == codes.Unimplemented || + strings.Contains(s.Message(), "region 0 not found") +} diff --git a/br/pkg/restore/split/split.go b/br/pkg/restore/split/split.go index ce6faa90b209c..fef8899ab10a9 100644 --- a/br/pkg/restore/split/split.go +++ b/br/pkg/restore/split/split.go @@ -233,11 +233,11 @@ func (b *WaitRegionOnlineBackoffer) NextBackoff(err error) time.Duration { // it needs more time to wait splitting the regions that contains data in PITR. // 2s * 150 delayTime := b.Stat.ExponentialBackoff() - failpoint.Inject("hint-scan-region-backoff", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("hint-scan-region-backoff")); _err_ == nil { if val.(bool) { delayTime = time.Microsecond } - }) + } return delayTime } b.Stat.GiveUp() diff --git a/br/pkg/restore/split/split.go__failpoint_stash__ b/br/pkg/restore/split/split.go__failpoint_stash__ new file mode 100644 index 0000000000000..ce6faa90b209c --- /dev/null +++ b/br/pkg/restore/split/split.go__failpoint_stash__ @@ -0,0 +1,352 @@ +// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. + +package split + +import ( + "bytes" + "context" + "encoding/hex" + goerrors "errors" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/redact" + "go.uber.org/zap" +) + +var ( + WaitRegionOnlineAttemptTimes = config.DefaultRegionCheckBackoffLimit + SplitRetryTimes = 150 +) + +// Constants for split retry machinery. 
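+// These values feed newSplitBackoffer in client.go: the retry loop of
+// SplitKeysAndScatter backs off exponentially from SplitRetryInterval (50ms) up to
+// SplitMaxRetryInterval (4s), for at most SplitRetryTimes (150, declared above) attempts.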
+const ( + SplitRetryInterval = 50 * time.Millisecond + SplitMaxRetryInterval = 4 * time.Second + + // it takes 30 minutes to scatter regions when each TiKV has 400k regions + ScatterWaitUpperInterval = 30 * time.Minute + + ScanRegionPaginationLimit = 128 +) + +func checkRegionConsistency(startKey, endKey []byte, regions []*RegionInfo) error { + // current pd can't guarantee the consistency of returned regions + if len(regions) == 0 { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, "scan region return empty result, startKey: %s, endKey: %s", + redact.Key(startKey), redact.Key(endKey)) + } + + if bytes.Compare(regions[0].Region.StartKey, startKey) > 0 { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, + "first region %d's startKey(%s) > startKey(%s), region epoch: %s", + regions[0].Region.Id, + redact.Key(regions[0].Region.StartKey), redact.Key(startKey), + regions[0].Region.RegionEpoch.String()) + } else if len(regions[len(regions)-1].Region.EndKey) != 0 && + bytes.Compare(regions[len(regions)-1].Region.EndKey, endKey) < 0 { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, + "last region %d's endKey(%s) < endKey(%s), region epoch: %s", + regions[len(regions)-1].Region.Id, + redact.Key(regions[len(regions)-1].Region.EndKey), redact.Key(endKey), + regions[len(regions)-1].Region.RegionEpoch.String()) + } + + cur := regions[0] + if cur.Leader == nil { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, + "region %d's leader is nil", cur.Region.Id) + } + if cur.Leader.StoreId == 0 { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, + "region %d's leader's store id is 0", cur.Region.Id) + } + for _, r := range regions[1:] { + if r.Leader == nil { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, + "region %d's leader is nil", r.Region.Id) + } + if r.Leader.StoreId == 0 { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, + "region %d's leader's store id is 0", r.Region.Id) + } + if !bytes.Equal(cur.Region.EndKey, r.Region.StartKey) { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, + "region %d's endKey not equal to next region %d's startKey, endKey: %s, startKey: %s, region epoch: %s %s", + cur.Region.Id, r.Region.Id, + redact.Key(cur.Region.EndKey), redact.Key(r.Region.StartKey), + cur.Region.RegionEpoch.String(), r.Region.RegionEpoch.String()) + } + cur = r + } + + return nil +} + +// PaginateScanRegion scan regions with a limit pagination and return all regions +// at once. The returned regions are continuous and cover the key range. If not, +// or meet errors, it will retry internally. +func PaginateScanRegion( + ctx context.Context, client SplitClient, startKey, endKey []byte, limit int, +) ([]*RegionInfo, error) { + if len(endKey) != 0 && bytes.Compare(startKey, endKey) > 0 { + return nil, errors.Annotatef(berrors.ErrInvalidRange, "startKey > endKey, startKey: %s, endkey: %s", + hex.EncodeToString(startKey), hex.EncodeToString(endKey)) + } + + var ( + lastRegions []*RegionInfo + err error + backoffer = NewWaitRegionOnlineBackoffer() + ) + _ = utils.WithRetry(ctx, func() error { + regions := make([]*RegionInfo, 0, 16) + scanStartKey := startKey + for { + var batch []*RegionInfo + batch, err = client.ScanRegions(ctx, scanStartKey, endKey, limit) + if err != nil { + err = errors.Annotatef(berrors.ErrPDBatchScanRegion.Wrap(err), "scan regions from start-key:%s, err: %s", + redact.Key(scanStartKey), err.Error()) + return err + } + regions = append(regions, batch...) 
+ if len(batch) < limit { + // No more region + break + } + scanStartKey = batch[len(batch)-1].Region.GetEndKey() + if len(scanStartKey) == 0 || + (len(endKey) > 0 && bytes.Compare(scanStartKey, endKey) >= 0) { + // All key space have scanned + break + } + } + // if the number of regions changed, we can infer TiKV side really + // made some progress so don't increase the retry times. + if len(regions) != len(lastRegions) { + backoffer.Stat.ReduceRetry() + } + lastRegions = regions + + if err = checkRegionConsistency(startKey, endKey, regions); err != nil { + log.Warn("failed to scan region, retrying", + logutil.ShortError(err), + zap.Int("regionLength", len(regions))) + return err + } + return nil + }, backoffer) + + return lastRegions, err +} + +// checkPartRegionConsistency only checks the continuity of regions and the first region consistency. +func checkPartRegionConsistency(startKey, endKey []byte, regions []*RegionInfo) error { + // current pd can't guarantee the consistency of returned regions + if len(regions) == 0 { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, + "scan region return empty result, startKey: %s, endKey: %s", + redact.Key(startKey), redact.Key(endKey)) + } + + if bytes.Compare(regions[0].Region.StartKey, startKey) > 0 { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, + "first region's startKey > startKey, startKey: %s, regionStartKey: %s", + redact.Key(startKey), redact.Key(regions[0].Region.StartKey)) + } + + cur := regions[0] + for _, r := range regions[1:] { + if !bytes.Equal(cur.Region.EndKey, r.Region.StartKey) { + return errors.Annotatef(berrors.ErrPDBatchScanRegion, + "region endKey not equal to next region startKey, endKey: %s, startKey: %s", + redact.Key(cur.Region.EndKey), redact.Key(r.Region.StartKey)) + } + cur = r + } + + return nil +} + +func ScanRegionsWithRetry( + ctx context.Context, client SplitClient, startKey, endKey []byte, limit int, +) ([]*RegionInfo, error) { + if len(endKey) != 0 && bytes.Compare(startKey, endKey) > 0 { + return nil, errors.Annotatef(berrors.ErrInvalidRange, "startKey > endKey, startKey: %s, endkey: %s", + hex.EncodeToString(startKey), hex.EncodeToString(endKey)) + } + + var regions []*RegionInfo + var err error + // we don't need to return multierr. since there only 3 times retry. + // in most case 3 times retry have the same error. so we just return the last error. + // actually we'd better remove all multierr in br/lightning. + // because it's not easy to check multierr equals normal error. + // see https://github.com/pingcap/tidb/issues/33419. + _ = utils.WithRetry(ctx, func() error { + regions, err = client.ScanRegions(ctx, startKey, endKey, limit) + if err != nil { + err = errors.Annotatef(berrors.ErrPDBatchScanRegion, "scan regions from start-key:%s, err: %s", + redact.Key(startKey), err.Error()) + return err + } + + if err = checkPartRegionConsistency(startKey, endKey, regions); err != nil { + log.Warn("failed to scan region, retrying", logutil.ShortError(err)) + return err + } + + return nil + }, NewWaitRegionOnlineBackoffer()) + + return regions, err +} + +type WaitRegionOnlineBackoffer struct { + Stat utils.RetryState +} + +// NewWaitRegionOnlineBackoffer create a backoff to wait region online. 
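+//
+// It is meant to be passed as the backoffer argument of utils.WithRetry, as
+// PaginateScanRegion above does. A minimal sketch:
+//
+//	err := utils.WithRetry(ctx, func() error { /* scan and validate regions */ return nil }, NewWaitRegionOnlineBackoffer())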
+func NewWaitRegionOnlineBackoffer() *WaitRegionOnlineBackoffer { + return &WaitRegionOnlineBackoffer{ + Stat: utils.InitialRetryState( + WaitRegionOnlineAttemptTimes, + time.Millisecond*10, + time.Second*2, + ), + } +} + +// NextBackoff returns a duration to wait before retrying again +func (b *WaitRegionOnlineBackoffer) NextBackoff(err error) time.Duration { + // TODO(lance6716): why we only backoff when the error is ErrPDBatchScanRegion? + var perr *errors.Error + if goerrors.As(err, &perr) && berrors.ErrPDBatchScanRegion.ID() == perr.ID() { + // it needs more time to wait splitting the regions that contains data in PITR. + // 2s * 150 + delayTime := b.Stat.ExponentialBackoff() + failpoint.Inject("hint-scan-region-backoff", func(val failpoint.Value) { + if val.(bool) { + delayTime = time.Microsecond + } + }) + return delayTime + } + b.Stat.GiveUp() + return 0 +} + +// Attempt returns the remain attempt times +func (b *WaitRegionOnlineBackoffer) Attempt() int { + return b.Stat.Attempt() +} + +// BackoffMayNotCountBackoffer is a backoffer but it may not increase the retry +// counter. It should be used with ErrBackoff or ErrBackoffAndDontCount. +type BackoffMayNotCountBackoffer struct { + state utils.RetryState +} + +var ( + ErrBackoff = errors.New("found backoff error") + ErrBackoffAndDontCount = errors.New("found backoff error but don't count") +) + +// NewBackoffMayNotCountBackoffer creates a new backoffer that may backoff or retry. +// +// TODO: currently it has the same usage as NewWaitRegionOnlineBackoffer so we +// don't expose its inner settings. +func NewBackoffMayNotCountBackoffer() *BackoffMayNotCountBackoffer { + return &BackoffMayNotCountBackoffer{ + state: utils.InitialRetryState( + WaitRegionOnlineAttemptTimes, + time.Millisecond*10, + time.Second*2, + ), + } +} + +// NextBackoff implements utils.Backoffer. For BackoffMayNotCountBackoffer, only +// ErrBackoff and ErrBackoffAndDontCount is meaningful. +func (b *BackoffMayNotCountBackoffer) NextBackoff(err error) time.Duration { + if errors.ErrorEqual(err, ErrBackoff) { + return b.state.ExponentialBackoff() + } + if errors.ErrorEqual(err, ErrBackoffAndDontCount) { + delay := b.state.ExponentialBackoff() + b.state.ReduceRetry() + return delay + } + b.state.GiveUp() + return 0 +} + +// Attempt implements utils.Backoffer. +func (b *BackoffMayNotCountBackoffer) Attempt() int { + return b.state.Attempt() +} + +// getSplitKeysOfRegions checks every input key is necessary to split region on +// it. Returns a map from region to split keys belongs to it. +// +// The key will be skipped if it's the region boundary. +// +// prerequisite: +// - sortedKeys are sorted in ascending order. +// - sortedRegions are continuous and sorted in ascending order by start key. +// - sortedRegions can cover all keys in sortedKeys. +// PaginateScanRegion should satisfy the above prerequisites. +func getSplitKeysOfRegions( + sortedKeys [][]byte, + sortedRegions []*RegionInfo, + isRawKV bool, +) map[*RegionInfo][][]byte { + splitKeyMap := make(map[*RegionInfo][][]byte, len(sortedRegions)) + curKeyIndex := 0 + splitKey := codec.EncodeBytesExt(nil, sortedKeys[curKeyIndex], isRawKV) + + for _, region := range sortedRegions { + for { + if len(sortedKeys[curKeyIndex]) == 0 { + // should not happen? + goto nextKey + } + // If splitKey is the boundary of the region, don't need to split on it. + if bytes.Equal(splitKey, region.Region.GetStartKey()) { + goto nextKey + } + // If splitKey is not in this region, we should move to the next region. 
+ if !region.ContainsInterior(splitKey) { + break + } + + splitKeyMap[region] = append(splitKeyMap[region], sortedKeys[curKeyIndex]) + + nextKey: + curKeyIndex++ + if curKeyIndex >= len(sortedKeys) { + return splitKeyMap + } + splitKey = codec.EncodeBytesExt(nil, sortedKeys[curKeyIndex], isRawKV) + } + } + lastKey := sortedKeys[len(sortedKeys)-1] + endOfLastRegion := sortedRegions[len(sortedRegions)-1].Region.GetEndKey() + if !bytes.Equal(lastKey, endOfLastRegion) { + log.Error("in getSplitKeysOfRegions, regions don't cover all keys", + zap.String("firstKey", hex.EncodeToString(sortedKeys[0])), + zap.String("lastKey", hex.EncodeToString(lastKey)), + zap.String("firstRegionStartKey", hex.EncodeToString(sortedRegions[0].Region.GetStartKey())), + zap.String("lastRegionEndKey", hex.EncodeToString(endOfLastRegion)), + ) + } + return splitKeyMap +} diff --git a/br/pkg/storage/binding__failpoint_binding__.go b/br/pkg/storage/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..a1a747a15d57f --- /dev/null +++ b/br/pkg/storage/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package storage + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/br/pkg/storage/s3.go b/br/pkg/storage/s3.go index 3987512b2a0a2..0755077968f39 100644 --- a/br/pkg/storage/s3.go +++ b/br/pkg/storage/s3.go @@ -590,10 +590,10 @@ func (rs *S3Storage) ReadFile(ctx context.Context, file string) ([]byte, error) // close the body of response since data has been already read out result.Body.Close() // for unit test - failpoint.Inject("read-s3-body-failed", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("read-s3-body-failed")); _err_ == nil { log.Info("original error", zap.Error(readErr)) readErr = errors.Errorf("read: connection reset by peer") - }) + } if readErr != nil { if isDeadlineExceedError(readErr) || isCancelError(readErr) { return nil, errors.Annotatef(readErr, "failed to read body from get object result, file info: input.bucket='%s', input.key='%s', retryCnt='%d'", @@ -1169,12 +1169,12 @@ func isConnectionRefusedError(err error) bool { func (rl retryerWithLog) ShouldRetry(r *request.Request) bool { // for unit test - failpoint.Inject("replace-error-to-connection-reset-by-peer", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("replace-error-to-connection-reset-by-peer")); _err_ == nil { log.Info("original error", zap.Error(r.Error)) if r.Error != nil { r.Error = errors.New("read tcp *.*.*.*:*->*.*.*.*:*: read: connection reset by peer") } - }) + } if r.HTTPRequest.URL.Host == ec2MetaAddress && (isDeadlineExceedError(r.Error) || isConnectionResetError(r.Error)) { // fast fail for unreachable linklocal address in EC2 containers. log.Warn("failed to get EC2 metadata. skipping.", logutil.ShortError(r.Error)) diff --git a/br/pkg/storage/s3.go__failpoint_stash__ b/br/pkg/storage/s3.go__failpoint_stash__ new file mode 100644 index 0000000000000..3987512b2a0a2 --- /dev/null +++ b/br/pkg/storage/s3.go__failpoint_stash__ @@ -0,0 +1,1208 @@ +// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. 
+ +package storage + +import ( + "bytes" + "context" + "fmt" + "io" + "net/url" + "path" + "regexp" + "strconv" + "strings" + "sync" + "time" + + alicred "github.com/aliyun/alibaba-cloud-sdk-go/sdk/auth/credentials" + aliproviders "github.com/aliyun/alibaba-cloud-sdk-go/sdk/auth/credentials/providers" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/client" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/credentials/stscreds" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3iface" + "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/google/uuid" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/log" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/pkg/util/prefetch" + "github.com/spf13/pflag" + "go.uber.org/zap" +) + +var hardcodedS3ChunkSize = 5 * 1024 * 1024 + +const ( + s3EndpointOption = "s3.endpoint" + s3RegionOption = "s3.region" + s3StorageClassOption = "s3.storage-class" + s3SseOption = "s3.sse" + s3SseKmsKeyIDOption = "s3.sse-kms-key-id" + s3ACLOption = "s3.acl" + s3ProviderOption = "s3.provider" + s3RoleARNOption = "s3.role-arn" + s3ExternalIDOption = "s3.external-id" + notFound = "NotFound" + // number of retries to make of operations. + maxRetries = 7 + // max number of retries when meets error + maxErrorRetries = 3 + ec2MetaAddress = "169.254.169.254" + + // the maximum number of byte to read for seek. + maxSkipOffsetByRead = 1 << 16 // 64KB + + defaultRegion = "us-east-1" + // to check the cloud type by endpoint tag. + domainAliyun = "aliyuncs.com" +) + +var permissionCheckFn = map[Permission]func(context.Context, s3iface.S3API, *backuppb.S3) error{ + AccessBuckets: s3BucketExistenceCheck, + ListObjects: listObjectsCheck, + GetObject: getObjectCheck, + PutAndDeleteObject: PutAndDeleteObjectCheck, +} + +// WriteBufferSize is the size of the buffer used for writing. (64K may be a better choice) +var WriteBufferSize = 5 * 1024 * 1024 + +// S3Storage defines some standard operations for BR/Lightning on the S3 storage. +// It implements the `ExternalStorage` interface. +type S3Storage struct { + svc s3iface.S3API + options *backuppb.S3 +} + +// GetS3APIHandle gets the handle to the S3 API. +func (rs *S3Storage) GetS3APIHandle() s3iface.S3API { + return rs.svc +} + +// GetOptions gets the external storage operations for the S3. +func (rs *S3Storage) GetOptions() *backuppb.S3 { + return rs.options +} + +// S3Uploader does multi-part upload to s3. +type S3Uploader struct { + svc s3iface.S3API + createOutput *s3.CreateMultipartUploadOutput + completeParts []*s3.CompletedPart +} + +// UploadPart update partial data to s3, we should call CreateMultipartUpload to start it, +// and call CompleteMultipartUpload to finish it. 
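+//
+// A minimal lifecycle sketch (illustrative; bucket/key names are placeholders and
+// svc is any s3iface.S3API, e.g. the handle held by S3Storage):
+//
+//	out, err := svc.CreateMultipartUploadWithContext(ctx, &s3.CreateMultipartUploadInput{
+//		Bucket: aws.String("my-bucket"),
+//		Key:    aws.String("my-key"),
+//	})
+//	// handle err ...
+//	u := &S3Uploader{svc: svc, createOutput: out}
+//	_, err = u.Write(ctx, part1) // every part except the last should be at least 5 MiB
+//	_, err = u.Write(ctx, part2)
+//	err = u.Close(ctx) // sends CompleteMultipartUpload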
+func (u *S3Uploader) Write(ctx context.Context, data []byte) (int, error) { + partInput := &s3.UploadPartInput{ + Body: bytes.NewReader(data), + Bucket: u.createOutput.Bucket, + Key: u.createOutput.Key, + PartNumber: aws.Int64(int64(len(u.completeParts) + 1)), + UploadId: u.createOutput.UploadId, + ContentLength: aws.Int64(int64(len(data))), + } + + uploadResult, err := u.svc.UploadPartWithContext(ctx, partInput) + if err != nil { + return 0, errors.Trace(err) + } + u.completeParts = append(u.completeParts, &s3.CompletedPart{ + ETag: uploadResult.ETag, + PartNumber: partInput.PartNumber, + }) + return len(data), nil +} + +// Close complete multi upload request. +func (u *S3Uploader) Close(ctx context.Context) error { + completeInput := &s3.CompleteMultipartUploadInput{ + Bucket: u.createOutput.Bucket, + Key: u.createOutput.Key, + UploadId: u.createOutput.UploadId, + MultipartUpload: &s3.CompletedMultipartUpload{ + Parts: u.completeParts, + }, + } + _, err := u.svc.CompleteMultipartUploadWithContext(ctx, completeInput) + return errors.Trace(err) +} + +// S3BackendOptions contains options for s3 storage. +type S3BackendOptions struct { + Endpoint string `json:"endpoint" toml:"endpoint"` + Region string `json:"region" toml:"region"` + StorageClass string `json:"storage-class" toml:"storage-class"` + Sse string `json:"sse" toml:"sse"` + SseKmsKeyID string `json:"sse-kms-key-id" toml:"sse-kms-key-id"` + ACL string `json:"acl" toml:"acl"` + AccessKey string `json:"access-key" toml:"access-key"` + SecretAccessKey string `json:"secret-access-key" toml:"secret-access-key"` + SessionToken string `json:"session-token" toml:"session-token"` + Provider string `json:"provider" toml:"provider"` + ForcePathStyle bool `json:"force-path-style" toml:"force-path-style"` + UseAccelerateEndpoint bool `json:"use-accelerate-endpoint" toml:"use-accelerate-endpoint"` + RoleARN string `json:"role-arn" toml:"role-arn"` + ExternalID string `json:"external-id" toml:"external-id"` + ObjectLockEnabled bool `json:"object-lock-enabled" toml:"object-lock-enabled"` +} + +// Apply apply s3 options on backuppb.S3. +func (options *S3BackendOptions) Apply(s3 *backuppb.S3) error { + if options.Endpoint != "" { + u, err := url.Parse(options.Endpoint) + if err != nil { + return errors.Trace(err) + } + if u.Scheme == "" { + return errors.Errorf("scheme not found in endpoint") + } + if u.Host == "" { + return errors.Errorf("host not found in endpoint") + } + } + // In some cases, we need to set ForcePathStyle to false. 
+ // Refer to: https://rclone.org/s3/#s3-force-path-style + if options.Provider == "alibaba" || options.Provider == "netease" || + options.UseAccelerateEndpoint { + options.ForcePathStyle = false + } + if options.AccessKey == "" && options.SecretAccessKey != "" { + return errors.Annotate(berrors.ErrStorageInvalidConfig, "access_key not found") + } + if options.AccessKey != "" && options.SecretAccessKey == "" { + return errors.Annotate(berrors.ErrStorageInvalidConfig, "secret_access_key not found") + } + + s3.Endpoint = strings.TrimSuffix(options.Endpoint, "/") + s3.Region = options.Region + // StorageClass, SSE and ACL are acceptable to be empty + s3.StorageClass = options.StorageClass + s3.Sse = options.Sse + s3.SseKmsKeyId = options.SseKmsKeyID + s3.Acl = options.ACL + s3.AccessKey = options.AccessKey + s3.SecretAccessKey = options.SecretAccessKey + s3.SessionToken = options.SessionToken + s3.ForcePathStyle = options.ForcePathStyle + s3.RoleArn = options.RoleARN + s3.ExternalId = options.ExternalID + s3.Provider = options.Provider + return nil +} + +// defineS3Flags defines the command line flags for S3BackendOptions. +func defineS3Flags(flags *pflag.FlagSet) { + // TODO: remove experimental tag if it's stable + flags.String(s3EndpointOption, "", + "(experimental) Set the S3 endpoint URL, please specify the http or https scheme explicitly") + flags.String(s3RegionOption, "", "(experimental) Set the S3 region, e.g. us-east-1") + flags.String(s3StorageClassOption, "", "(experimental) Set the S3 storage class, e.g. STANDARD") + flags.String(s3SseOption, "", "Set S3 server-side encryption, e.g. aws:kms") + flags.String(s3SseKmsKeyIDOption, "", "KMS CMK key id to use with S3 server-side encryption."+ + "Leave empty to use S3 owned key.") + flags.String(s3ACLOption, "", "(experimental) Set the S3 canned ACLs, e.g. authenticated-read") + flags.String(s3ProviderOption, "", "(experimental) Set the S3 provider, e.g. aws, alibaba, ceph") + flags.String(s3RoleARNOption, "", "(experimental) Set the ARN of the IAM role to assume when accessing AWS S3") + flags.String(s3ExternalIDOption, "", "(experimental) Set the external ID when assuming the role to access AWS S3") +} + +// parseFromFlags parse S3BackendOptions from command line flags. +func (options *S3BackendOptions) parseFromFlags(flags *pflag.FlagSet) error { + var err error + options.Endpoint, err = flags.GetString(s3EndpointOption) + if err != nil { + return errors.Trace(err) + } + options.Endpoint = strings.TrimSuffix(options.Endpoint, "/") + options.Region, err = flags.GetString(s3RegionOption) + if err != nil { + return errors.Trace(err) + } + options.Sse, err = flags.GetString(s3SseOption) + if err != nil { + return errors.Trace(err) + } + options.SseKmsKeyID, err = flags.GetString(s3SseKmsKeyIDOption) + if err != nil { + return errors.Trace(err) + } + options.ACL, err = flags.GetString(s3ACLOption) + if err != nil { + return errors.Trace(err) + } + options.StorageClass, err = flags.GetString(s3StorageClassOption) + if err != nil { + return errors.Trace(err) + } + options.ForcePathStyle = true + options.Provider, err = flags.GetString(s3ProviderOption) + if err != nil { + return errors.Trace(err) + } + options.RoleARN, err = flags.GetString(s3RoleARNOption) + if err != nil { + return errors.Trace(err) + } + options.ExternalID, err = flags.GetString(s3ExternalIDOption) + if err != nil { + return errors.Trace(err) + } + return nil +} + +// NewS3StorageForTest creates a new S3Storage for testing only. 
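+//
+// For example (sketch; the mock client and option values are placeholders):
+//
+//	store := NewS3StorageForTest(mockS3API, &backuppb.S3{Bucket: "test-bucket", Prefix: "prefix/"})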
+func NewS3StorageForTest(svc s3iface.S3API, options *backuppb.S3) *S3Storage { + return &S3Storage{ + svc: svc, + options: options, + } +} + +// auto access without ak / sk. +func autoNewCred(qs *backuppb.S3) (cred *credentials.Credentials, err error) { + if qs.AccessKey != "" && qs.SecretAccessKey != "" { + return credentials.NewStaticCredentials(qs.AccessKey, qs.SecretAccessKey, qs.SessionToken), nil + } + endpoint := qs.Endpoint + // if endpoint is empty,return no error and run default(aws) follow. + if endpoint == "" { + return nil, nil + } + // if it Contains 'aliyuncs', fetch the sts token. + if strings.Contains(endpoint, domainAliyun) { + return createOssRAMCred() + } + // other case ,return no error and run default(aws) follow. + return nil, nil +} + +func createOssRAMCred() (*credentials.Credentials, error) { + cred, err := aliproviders.NewInstanceMetadataProvider().Retrieve() + if err != nil { + log.Warn("failed to get aliyun ram credential", zap.Error(err)) + return nil, nil + } + var aliCred, ok = cred.(*alicred.StsTokenCredential) + if !ok { + return nil, errors.Errorf("invalid credential type %T", cred) + } + newCred := credentials.NewChainCredentials([]credentials.Provider{ + &credentials.EnvProvider{}, + &credentials.SharedCredentialsProvider{}, + &credentials.StaticProvider{Value: credentials.Value{AccessKeyID: aliCred.AccessKeyId, SecretAccessKey: aliCred.AccessKeySecret, SessionToken: aliCred.AccessKeyStsToken, ProviderName: ""}}, + }) + if _, err := newCred.Get(); err != nil { + return nil, errors.Trace(err) + } + return newCred, nil +} + +// NewS3Storage initialize a new s3 storage for metadata. +func NewS3Storage(ctx context.Context, backend *backuppb.S3, opts *ExternalStorageOptions) (obj *S3Storage, errRet error) { + qs := *backend + awsConfig := aws.NewConfig(). + WithS3ForcePathStyle(qs.ForcePathStyle). 
+ WithCredentialsChainVerboseErrors(true) + if qs.Region == "" { + awsConfig.WithRegion(defaultRegion) + } else { + awsConfig.WithRegion(qs.Region) + } + + if opts.S3Retryer != nil { + request.WithRetryer(awsConfig, opts.S3Retryer) + } else { + request.WithRetryer(awsConfig, defaultS3Retryer()) + } + + if qs.Endpoint != "" { + awsConfig.WithEndpoint(qs.Endpoint) + } + if opts.HTTPClient != nil { + awsConfig.WithHTTPClient(opts.HTTPClient) + } + cred, err := autoNewCred(&qs) + if err != nil { + return nil, errors.Trace(err) + } + if cred != nil { + awsConfig.WithCredentials(cred) + } + // awsConfig.WithLogLevel(aws.LogDebugWithSigning) + awsSessionOpts := session.Options{ + Config: *awsConfig, + } + ses, err := session.NewSessionWithOptions(awsSessionOpts) + if err != nil { + return nil, errors.Trace(err) + } + + if !opts.SendCredentials { + // Clear the credentials if exists so that they will not be sent to TiKV + backend.AccessKey = "" + backend.SecretAccessKey = "" + backend.SessionToken = "" + } else if ses.Config.Credentials != nil { + if qs.AccessKey == "" || qs.SecretAccessKey == "" { + v, cerr := ses.Config.Credentials.Get() + if cerr != nil { + return nil, errors.Trace(cerr) + } + backend.AccessKey = v.AccessKeyID + backend.SecretAccessKey = v.SecretAccessKey + backend.SessionToken = v.SessionToken + } + } + + s3CliConfigs := []*aws.Config{} + // if role ARN and external ID are provided, try to get the credential using this way + if len(qs.RoleArn) > 0 { + creds := stscreds.NewCredentials(ses, qs.RoleArn, func(p *stscreds.AssumeRoleProvider) { + if len(qs.ExternalId) > 0 { + p.ExternalID = &qs.ExternalId + } + }) + s3CliConfigs = append(s3CliConfigs, + aws.NewConfig().WithCredentials(creds), + ) + } + c := s3.New(ses, s3CliConfigs...) + + var region string + if len(qs.Provider) == 0 || qs.Provider == "aws" { + confCred := ses.Config.Credentials + setCredOpt := func(req *request.Request) { + // s3manager.GetBucketRegionWithClient will set credential anonymous, which works with s3. + // we need reassign credential to be compatible with minio authentication. + if confCred != nil { + req.Config.Credentials = confCred + } + // s3manager.GetBucketRegionWithClient use path style addressing default. + // we need set S3ForcePathStyle by our config if we set endpoint. + if qs.Endpoint != "" { + req.Config.S3ForcePathStyle = ses.Config.S3ForcePathStyle + } + } + region, err = s3manager.GetBucketRegionWithClient(ctx, c, qs.Bucket, setCredOpt) + if err != nil { + return nil, errors.Annotatef(err, "failed to get region of bucket %s", qs.Bucket) + } + } else { + // for other s3 compatible provider like ovh storage didn't return the region correctlly + // so we cannot automatically get the bucket region. just fallback to manually region setting. + region = qs.Region + } + + if qs.Region != region { + if qs.Region != "" { + return nil, errors.Trace(fmt.Errorf("s3 bucket and region are not matched, bucket=%s, input region=%s, real region=%s", + qs.Bucket, qs.Region, region)) + } + + qs.Region = region + backend.Region = region + if region != defaultRegion { + s3CliConfigs = append(s3CliConfigs, aws.NewConfig().WithRegion(region)) + c = s3.New(ses, s3CliConfigs...) 
+ } + } + log.Info("succeed to get bucket region from s3", zap.String("bucket region", region)) + + if len(qs.Prefix) > 0 && !strings.HasSuffix(qs.Prefix, "/") { + qs.Prefix += "/" + } + + for _, p := range opts.CheckPermissions { + err := permissionCheckFn[p](ctx, c, &qs) + if err != nil { + return nil, errors.Annotatef(berrors.ErrStorageInvalidPermission, "check permission %s failed due to %v", p, err) + } + } + + s3Storage := &S3Storage{ + svc: c, + options: &qs, + } + if opts.CheckS3ObjectLockOptions { + backend.ObjectLockEnabled = s3Storage.IsObjectLockEnabled() + } + return s3Storage, nil +} + +// s3BucketExistenceCheck checks if a bucket exists. +func s3BucketExistenceCheck(_ context.Context, svc s3iface.S3API, qs *backuppb.S3) error { + input := &s3.HeadBucketInput{ + Bucket: aws.String(qs.Bucket), + } + _, err := svc.HeadBucket(input) + return errors.Trace(err) +} + +// listObjectsCheck checks the permission of listObjects +func listObjectsCheck(_ context.Context, svc s3iface.S3API, qs *backuppb.S3) error { + input := &s3.ListObjectsInput{ + Bucket: aws.String(qs.Bucket), + Prefix: aws.String(qs.Prefix), + MaxKeys: aws.Int64(1), + } + _, err := svc.ListObjects(input) + if err != nil { + return errors.Trace(err) + } + return nil +} + +// getObjectCheck checks the permission of getObject +func getObjectCheck(_ context.Context, svc s3iface.S3API, qs *backuppb.S3) error { + input := &s3.GetObjectInput{ + Bucket: aws.String(qs.Bucket), + Key: aws.String("not-exists"), + } + _, err := svc.GetObject(input) + if aerr, ok := err.(awserr.Error); ok { + if aerr.Code() == "NoSuchKey" { + // if key not exists and we reach this error, that + // means we have the correct permission to GetObject + // other we will get another error + return nil + } + return errors.Trace(err) + } + return nil +} + +// PutAndDeleteObjectCheck checks the permission of putObject +// S3 API doesn't provide a way to check the permission, we have to put an +// object to check the permission. +// exported for testing. +func PutAndDeleteObjectCheck(ctx context.Context, svc s3iface.S3API, options *backuppb.S3) (err error) { + file := fmt.Sprintf("access-check/%s", uuid.New().String()) + defer func() { + // we always delete the object used for permission check, + // even on error, since the object might be created successfully even + // when it returns an error. 
+ input := &s3.DeleteObjectInput{ + Bucket: aws.String(options.Bucket), + Key: aws.String(options.Prefix + file), + } + _, err2 := svc.DeleteObjectWithContext(ctx, input) + if aerr, ok := err2.(awserr.Error); ok { + if aerr.Code() != "NoSuchKey" { + log.Warn("failed to delete object used for permission check", + zap.String("bucket", options.Bucket), + zap.String("key", *input.Key), zap.Error(err2)) + } + } + if err == nil { + err = errors.Trace(err2) + } + }() + // when no permission, aws returns err with code "AccessDenied" + input := buildPutObjectInput(options, file, []byte("check")) + _, err = svc.PutObjectWithContext(ctx, input) + return errors.Trace(err) +} + +func (rs *S3Storage) IsObjectLockEnabled() bool { + input := &s3.GetObjectLockConfigurationInput{ + Bucket: aws.String(rs.options.Bucket), + } + resp, err := rs.svc.GetObjectLockConfiguration(input) + if err != nil { + log.Warn("failed to check object lock for bucket", zap.String("bucket", rs.options.Bucket), zap.Error(err)) + return false + } + if resp != nil && resp.ObjectLockConfiguration != nil { + if s3.ObjectLockEnabledEnabled == aws.StringValue(resp.ObjectLockConfiguration.ObjectLockEnabled) { + return true + } + } + return false +} + +func buildPutObjectInput(options *backuppb.S3, file string, data []byte) *s3.PutObjectInput { + input := &s3.PutObjectInput{ + Body: aws.ReadSeekCloser(bytes.NewReader(data)), + Bucket: aws.String(options.Bucket), + Key: aws.String(options.Prefix + file), + } + if options.Acl != "" { + input = input.SetACL(options.Acl) + } + if options.Sse != "" { + input = input.SetServerSideEncryption(options.Sse) + } + if options.SseKmsKeyId != "" { + input = input.SetSSEKMSKeyId(options.SseKmsKeyId) + } + if options.StorageClass != "" { + input = input.SetStorageClass(options.StorageClass) + } + return input +} + +// WriteFile writes data to a file to storage. +func (rs *S3Storage) WriteFile(ctx context.Context, file string, data []byte) error { + input := buildPutObjectInput(rs.options, file, data) + // we don't need to calculate contentMD5 if s3 object lock enabled. + // since aws-go-sdk already did it in #computeBodyHashes + // https://github.com/aws/aws-sdk-go/blob/bcb2cf3fc2263c8c28b3119b07d2dbb44d7c93a0/service/s3/body_hash.go#L30 + _, err := rs.svc.PutObjectWithContext(ctx, input) + if err != nil { + return errors.Trace(err) + } + hinput := &s3.HeadObjectInput{ + Bucket: aws.String(rs.options.Bucket), + Key: aws.String(rs.options.Prefix + file), + } + err = rs.svc.WaitUntilObjectExistsWithContext(ctx, hinput) + return errors.Trace(err) +} + +// ReadFile reads the file from the storage and returns the contents. 
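+// Transient failures while reading the response body (for example a
+// connection reset) are retried up to maxErrorRetries times, while
+// deadline-exceeded and canceled contexts are returned immediately.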
+func (rs *S3Storage) ReadFile(ctx context.Context, file string) ([]byte, error) { + var ( + data []byte + readErr error + ) + for retryCnt := 0; retryCnt < maxErrorRetries; retryCnt += 1 { + input := &s3.GetObjectInput{ + Bucket: aws.String(rs.options.Bucket), + Key: aws.String(rs.options.Prefix + file), + } + result, err := rs.svc.GetObjectWithContext(ctx, input) + if err != nil { + return nil, errors.Annotatef(err, + "failed to read s3 file, file info: input.bucket='%s', input.key='%s'", + *input.Bucket, *input.Key) + } + data, readErr = io.ReadAll(result.Body) + // close the body of response since data has been already read out + result.Body.Close() + // for unit test + failpoint.Inject("read-s3-body-failed", func(_ failpoint.Value) { + log.Info("original error", zap.Error(readErr)) + readErr = errors.Errorf("read: connection reset by peer") + }) + if readErr != nil { + if isDeadlineExceedError(readErr) || isCancelError(readErr) { + return nil, errors.Annotatef(readErr, "failed to read body from get object result, file info: input.bucket='%s', input.key='%s', retryCnt='%d'", + *input.Bucket, *input.Key, retryCnt) + } + continue + } + return data, nil + } + // retry too much, should be failed + return nil, errors.Annotatef(readErr, "failed to read body from get object result (retry too much), file info: input.bucket='%s', input.key='%s'", + rs.options.Bucket, rs.options.Prefix+file) +} + +// DeleteFile delete the file in s3 storage +func (rs *S3Storage) DeleteFile(ctx context.Context, file string) error { + input := &s3.DeleteObjectInput{ + Bucket: aws.String(rs.options.Bucket), + Key: aws.String(rs.options.Prefix + file), + } + + _, err := rs.svc.DeleteObjectWithContext(ctx, input) + return errors.Trace(err) +} + +// s3DeleteObjectsLimit is the upper limit of objects in a delete request. +// See https://docs.aws.amazon.com/sdk-for-go/api/service/s3/#S3.DeleteObjects. +const s3DeleteObjectsLimit = 1000 + +// DeleteFiles delete the files in batch in s3 storage. +func (rs *S3Storage) DeleteFiles(ctx context.Context, files []string) error { + for len(files) > 0 { + batch := files + if len(batch) > s3DeleteObjectsLimit { + batch = batch[:s3DeleteObjectsLimit] + } + objects := make([]*s3.ObjectIdentifier, 0, len(batch)) + for _, file := range batch { + objects = append(objects, &s3.ObjectIdentifier{ + Key: aws.String(rs.options.Prefix + file), + }) + } + input := &s3.DeleteObjectsInput{ + Bucket: aws.String(rs.options.Bucket), + Delete: &s3.Delete{ + Objects: objects, + Quiet: aws.Bool(false), + }, + } + _, err := rs.svc.DeleteObjectsWithContext(ctx, input) + if err != nil { + return errors.Trace(err) + } + files = files[len(batch):] + } + return nil +} + +// FileExists check if file exists on s3 storage. +func (rs *S3Storage) FileExists(ctx context.Context, file string) (bool, error) { + input := &s3.HeadObjectInput{ + Bucket: aws.String(rs.options.Bucket), + Key: aws.String(rs.options.Prefix + file), + } + + _, err := rs.svc.HeadObjectWithContext(ctx, input) + if err != nil { + if aerr, ok := errors.Cause(err).(awserr.Error); ok { // nolint:errorlint + switch aerr.Code() { + case s3.ErrCodeNoSuchBucket, s3.ErrCodeNoSuchKey, notFound: + return false, nil + } + } + return false, errors.Trace(err) + } + return true, nil +} + +// WalkDir traverse all the files in a dir. +// +// fn is the function called for each regular file visited by WalkDir. 
+// The first argument is the file path that can be used in `Open` +// function; the second argument is the size in byte of the file determined +// by path. +func (rs *S3Storage) WalkDir(ctx context.Context, opt *WalkOption, fn func(string, int64) error) error { + if opt == nil { + opt = &WalkOption{} + } + prefix := path.Join(rs.options.Prefix, opt.SubDir) + if len(prefix) > 0 && !strings.HasSuffix(prefix, "/") { + prefix += "/" + } + + if len(opt.ObjPrefix) != 0 { + prefix += opt.ObjPrefix + } + + maxKeys := int64(1000) + if opt.ListCount > 0 { + maxKeys = opt.ListCount + } + req := &s3.ListObjectsInput{ + Bucket: aws.String(rs.options.Bucket), + Prefix: aws.String(prefix), + MaxKeys: aws.Int64(maxKeys), + } + + for { + // FIXME: We can't use ListObjectsV2, it is not universally supported. + // (Ceph RGW supported ListObjectsV2 since v15.1.0, released 2020 Jan 30th) + // (as of 2020, DigitalOcean Spaces still does not support V2 - https://developers.digitalocean.com/documentation/spaces/#list-bucket-contents) + res, err := rs.svc.ListObjectsWithContext(ctx, req) + if err != nil { + return errors.Trace(err) + } + for _, r := range res.Contents { + // https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListObjects.html#AmazonS3-ListObjects-response-NextMarker - + // + // `res.NextMarker` is populated only if we specify req.Delimiter. + // Aliyun OSS and minio will populate NextMarker no matter what, + // but this documented behavior does apply to AWS S3: + // + // "If response does not include the NextMarker and it is truncated, + // you can use the value of the last Key in the response as the marker + // in the subsequent request to get the next set of object keys." + req.Marker = r.Key + + // when walk on specify directory, the result include storage.Prefix, + // which can not be reuse in other API(Open/Read) directly. + // so we use TrimPrefix to filter Prefix for next Open/Read. + path := strings.TrimPrefix(*r.Key, rs.options.Prefix) + // trim the prefix '/' to ensure that the path returned is consistent with the local storage + path = strings.TrimPrefix(path, "/") + itemSize := *r.Size + + // filter out s3's empty directory items + if itemSize <= 0 && strings.HasSuffix(path, "/") { + log.Info("this path is an empty directory and cannot be opened in S3. Skip it", zap.String("path", path)) + continue + } + if err = fn(path, itemSize); err != nil { + return errors.Trace(err) + } + } + if !aws.BoolValue(res.IsTruncated) { + break + } + } + + return nil +} + +// URI returns s3:///. +func (rs *S3Storage) URI() string { + return "s3://" + rs.options.Bucket + "/" + rs.options.Prefix +} + +// Open a Reader by file path. +func (rs *S3Storage) Open(ctx context.Context, path string, o *ReaderOption) (ExternalFileReader, error) { + start := int64(0) + end := int64(0) + prefetchSize := 0 + if o != nil { + if o.StartOffset != nil { + start = *o.StartOffset + } + if o.EndOffset != nil { + end = *o.EndOffset + } + prefetchSize = o.PrefetchSize + } + reader, r, err := rs.open(ctx, path, start, end) + if err != nil { + return nil, errors.Trace(err) + } + if prefetchSize > 0 { + reader = prefetch.NewReader(reader, o.PrefetchSize) + } + return &s3ObjectReader{ + storage: rs, + name: path, + reader: reader, + ctx: ctx, + rangeInfo: r, + prefetchSize: prefetchSize, + }, nil +} + +// RangeInfo represents the an HTTP Content-Range header value +// of the form `bytes [Start]-[End]/[Size]`. +type RangeInfo struct { + // Start is the absolute position of the first byte of the byte range, + // starting from 0. 
+ Start int64 + // End is the absolute position of the last byte of the byte range. This end + // offset is inclusive, e.g. if the Size is 1000, the maximum value of End + // would be 999. + End int64 + // Size is the total size of the original file. + Size int64 +} + +// if endOffset > startOffset, should return reader for bytes in [startOffset, endOffset). +func (rs *S3Storage) open( + ctx context.Context, + path string, + startOffset, endOffset int64, +) (io.ReadCloser, RangeInfo, error) { + input := &s3.GetObjectInput{ + Bucket: aws.String(rs.options.Bucket), + Key: aws.String(rs.options.Prefix + path), + } + + // If we just open part of the object, we set `Range` in the request. + // If we meant to open the whole object, not just a part of it, + // we do not pass the range in the request, + // so that even if the object is empty, we can still get the response without errors. + // Then this behavior is similar to openning an empty file in local file system. + isFullRangeRequest := false + var rangeOffset *string + switch { + case endOffset > startOffset: + // s3 endOffset is inclusive + rangeOffset = aws.String(fmt.Sprintf("bytes=%d-%d", startOffset, endOffset-1)) + case startOffset == 0: + // openning the whole object, no need to fill the `Range` field in the request + isFullRangeRequest = true + default: + rangeOffset = aws.String(fmt.Sprintf("bytes=%d-", startOffset)) + } + input.Range = rangeOffset + result, err := rs.svc.GetObjectWithContext(ctx, input) + if err != nil { + return nil, RangeInfo{}, errors.Trace(err) + } + + var r RangeInfo + // Those requests without a `Range` will have no `ContentRange` in the response, + // In this case, we'll parse the `ContentLength` field instead. + if isFullRangeRequest { + // We must ensure the `ContentLengh` has data even if for empty objects, + // otherwise we have no places to get the object size + if result.ContentLength == nil { + return nil, RangeInfo{}, errors.Annotatef(berrors.ErrStorageUnknown, "open file '%s' failed. The S3 object has no content length", path) + } + objectSize := *(result.ContentLength) + r = RangeInfo{ + Start: 0, + End: objectSize - 1, + Size: objectSize, + } + } else { + r, err = ParseRangeInfo(result.ContentRange) + if err != nil { + return nil, RangeInfo{}, errors.Trace(err) + } + } + + if startOffset != r.Start || (endOffset != 0 && endOffset != r.End+1) { + return nil, r, errors.Annotatef(berrors.ErrStorageUnknown, "open file '%s' failed, expected range: %s, got: %v", + path, *rangeOffset, result.ContentRange) + } + + return result.Body, r, nil +} + +var contentRangeRegex = regexp.MustCompile(`bytes (\d+)-(\d+)/(\d+)$`) + +// ParseRangeInfo parses the Content-Range header and returns the offsets. 
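+// For example, a Content-Range value of "bytes 0-999/1000" yields
+// RangeInfo{Start: 0, End: 999, Size: 1000}.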
+func ParseRangeInfo(info *string) (ri RangeInfo, err error) { + if info == nil || len(*info) == 0 { + err = errors.Annotate(berrors.ErrStorageUnknown, "ContentRange is empty") + return + } + subMatches := contentRangeRegex.FindStringSubmatch(*info) + if len(subMatches) != 4 { + err = errors.Annotatef(berrors.ErrStorageUnknown, "invalid content range: '%s'", *info) + return + } + + ri.Start, err = strconv.ParseInt(subMatches[1], 10, 64) + if err != nil { + err = errors.Annotatef(err, "invalid start offset value '%s' in ContentRange '%s'", subMatches[1], *info) + return + } + ri.End, err = strconv.ParseInt(subMatches[2], 10, 64) + if err != nil { + err = errors.Annotatef(err, "invalid end offset value '%s' in ContentRange '%s'", subMatches[2], *info) + return + } + ri.Size, err = strconv.ParseInt(subMatches[3], 10, 64) + if err != nil { + err = errors.Annotatef(err, "invalid size size value '%s' in ContentRange '%s'", subMatches[3], *info) + return + } + return +} + +// s3ObjectReader wrap GetObjectOutput.Body and add the `Seek` method. +type s3ObjectReader struct { + storage *S3Storage + name string + reader io.ReadCloser + pos int64 + rangeInfo RangeInfo + // reader context used for implement `io.Seek` + // currently, lightning depends on package `xitongsys/parquet-go` to read parquet file and it needs `io.Seeker` + // See: https://github.com/xitongsys/parquet-go/blob/207a3cee75900b2b95213627409b7bac0f190bb3/source/source.go#L9-L10 + ctx context.Context + prefetchSize int +} + +// Read implement the io.Reader interface. +func (r *s3ObjectReader) Read(p []byte) (n int, err error) { + retryCnt := 0 + maxCnt := r.rangeInfo.End + 1 - r.pos + if maxCnt == 0 { + return 0, io.EOF + } + if maxCnt > int64(len(p)) { + maxCnt = int64(len(p)) + } + n, err = r.reader.Read(p[:maxCnt]) + // TODO: maybe we should use !errors.Is(err, io.EOF) here to avoid error lint, but currently, pingcap/errors + // doesn't implement this method yet. + for err != nil && errors.Cause(err) != io.EOF && retryCnt < maxErrorRetries { //nolint:errorlint + log.L().Warn( + "read s3 object failed, will retry", + zap.String("file", r.name), + zap.Int("retryCnt", retryCnt), + zap.Error(err), + ) + // if can retry, reopen a new reader and try read again + end := r.rangeInfo.End + 1 + if end == r.rangeInfo.Size { + end = 0 + } + _ = r.reader.Close() + + newReader, _, err1 := r.storage.open(r.ctx, r.name, r.pos, end) + if err1 != nil { + log.Warn("open new s3 reader failed", zap.String("file", r.name), zap.Error(err1)) + return + } + r.reader = newReader + if r.prefetchSize > 0 { + r.reader = prefetch.NewReader(r.reader, r.prefetchSize) + } + retryCnt++ + n, err = r.reader.Read(p[:maxCnt]) + } + + r.pos += int64(n) + return +} + +// Close implement the io.Closer interface. +func (r *s3ObjectReader) Close() error { + return r.reader.Close() +} + +// Seek implement the io.Seeker interface. +// +// Currently, tidb-lightning depends on this method to read parquet file for s3 storage. 
+func (r *s3ObjectReader) Seek(offset int64, whence int) (int64, error) { + var realOffset int64 + switch whence { + case io.SeekStart: + realOffset = offset + case io.SeekCurrent: + realOffset = r.pos + offset + case io.SeekEnd: + realOffset = r.rangeInfo.Size + offset + default: + return 0, errors.Annotatef(berrors.ErrStorageUnknown, "Seek: invalid whence '%d'", whence) + } + if realOffset < 0 { + return 0, errors.Annotatef(berrors.ErrStorageUnknown, "Seek in '%s': invalid offset to seek '%d'.", r.name, realOffset) + } + + if realOffset == r.pos { + return realOffset, nil + } else if realOffset >= r.rangeInfo.Size { + // See: https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35 + // because s3's GetObject interface doesn't allow get a range that matches zero length data, + // so if the position is out of range, we need to always return io.EOF after the seek operation. + + // close current read and open a new one which target offset + if err := r.reader.Close(); err != nil { + log.L().Warn("close s3 reader failed, will ignore this error", logutil.ShortError(err)) + } + + r.reader = io.NopCloser(bytes.NewReader(nil)) + r.pos = r.rangeInfo.Size + return r.pos, nil + } + + // if seek ahead no more than 64k, we discard these data + if realOffset > r.pos && realOffset-r.pos <= maxSkipOffsetByRead { + _, err := io.CopyN(io.Discard, r, realOffset-r.pos) + if err != nil { + return r.pos, errors.Trace(err) + } + return realOffset, nil + } + + // close current read and open a new one which target offset + err := r.reader.Close() + if err != nil { + return 0, errors.Trace(err) + } + + newReader, info, err := r.storage.open(r.ctx, r.name, realOffset, 0) + if err != nil { + return 0, errors.Trace(err) + } + r.reader = newReader + if r.prefetchSize > 0 { + r.reader = prefetch.NewReader(r.reader, r.prefetchSize) + } + r.rangeInfo = info + r.pos = realOffset + return realOffset, nil +} + +func (r *s3ObjectReader) GetFileSize() (int64, error) { + return r.rangeInfo.Size, nil +} + +// createUploader create multi upload request. +func (rs *S3Storage) createUploader(ctx context.Context, name string) (ExternalFileWriter, error) { + input := &s3.CreateMultipartUploadInput{ + Bucket: aws.String(rs.options.Bucket), + Key: aws.String(rs.options.Prefix + name), + } + if rs.options.Acl != "" { + input = input.SetACL(rs.options.Acl) + } + if rs.options.Sse != "" { + input = input.SetServerSideEncryption(rs.options.Sse) + } + if rs.options.SseKmsKeyId != "" { + input = input.SetSSEKMSKeyId(rs.options.SseKmsKeyId) + } + if rs.options.StorageClass != "" { + input = input.SetStorageClass(rs.options.StorageClass) + } + + resp, err := rs.svc.CreateMultipartUploadWithContext(ctx, input) + if err != nil { + return nil, errors.Trace(err) + } + return &S3Uploader{ + svc: rs.svc, + createOutput: resp, + completeParts: make([]*s3.CompletedPart, 0, 128), + }, nil +} + +type s3ObjectWriter struct { + wd *io.PipeWriter + wg *sync.WaitGroup + err error +} + +// Write implement the io.Writer interface. +func (s *s3ObjectWriter) Write(_ context.Context, p []byte) (int, error) { + return s.wd.Write(p) +} + +// Close implement the io.Closer interface. +func (s *s3ObjectWriter) Close(_ context.Context) error { + err := s.wd.Close() + if err != nil { + return err + } + s.wg.Wait() + return s.err +} + +// Create creates multi upload request. 
+func (rs *S3Storage) Create(ctx context.Context, name string, option *WriterOption) (ExternalFileWriter, error) { + var uploader ExternalFileWriter + var err error + if option == nil || option.Concurrency <= 1 { + uploader, err = rs.createUploader(ctx, name) + if err != nil { + return nil, err + } + } else { + up := s3manager.NewUploaderWithClient(rs.svc, func(u *s3manager.Uploader) { + u.PartSize = option.PartSize + u.Concurrency = option.Concurrency + u.BufferProvider = s3manager.NewBufferedReadSeekerWriteToPool(option.Concurrency * hardcodedS3ChunkSize) + }) + rd, wd := io.Pipe() + upParams := &s3manager.UploadInput{ + Bucket: aws.String(rs.options.Bucket), + Key: aws.String(rs.options.Prefix + name), + Body: rd, + } + s3Writer := &s3ObjectWriter{wd: wd, wg: &sync.WaitGroup{}} + s3Writer.wg.Add(1) + go func() { + _, err := up.UploadWithContext(ctx, upParams) + // like a channel we only let sender close the pipe in happy path + if err != nil { + log.Warn("upload to s3 failed", zap.String("filename", name), zap.Error(err)) + _ = rd.CloseWithError(err) + } + s3Writer.err = err + s3Writer.wg.Done() + }() + uploader = s3Writer + } + bufSize := WriteBufferSize + if option != nil && option.PartSize > 0 { + bufSize = int(option.PartSize) + } + uploaderWriter := newBufferedWriter(uploader, bufSize, NoCompression) + return uploaderWriter, nil +} + +// Rename implements ExternalStorage interface. +func (rs *S3Storage) Rename(ctx context.Context, oldFileName, newFileName string) error { + content, err := rs.ReadFile(ctx, oldFileName) + if err != nil { + return errors.Trace(err) + } + err = rs.WriteFile(ctx, newFileName, content) + if err != nil { + return errors.Trace(err) + } + if err = rs.DeleteFile(ctx, oldFileName); err != nil { + return errors.Trace(err) + } + return nil +} + +// Close implements ExternalStorage interface. +func (*S3Storage) Close() {} + +// retryerWithLog wrappes the client.DefaultRetryer, and logging when retry triggered. +type retryerWithLog struct { + client.DefaultRetryer +} + +func isCancelError(err error) bool { + return strings.Contains(err.Error(), "context canceled") +} + +func isDeadlineExceedError(err error) bool { + // TODO find a better way. + // Known challenges: + // + // If we want to unwrap the r.Error: + // 1. the err should be an awserr.Error (let it be awsErr) + // 2. awsErr.OrigErr() should be an *url.Error (let it be urlErr). + // 3. urlErr.Err should be a http.httpError (which is private). + // + // If we want to reterive the error from the request context: + // The error of context in the HTTPRequest (i.e. r.HTTPRequest.Context().Err() ) is nil. + return strings.Contains(err.Error(), "context deadline exceeded") +} + +func isConnectionResetError(err error) bool { + return strings.Contains(err.Error(), "read: connection reset") +} + +func isConnectionRefusedError(err error) bool { + return strings.Contains(err.Error(), "connection refused") +} + +func (rl retryerWithLog) ShouldRetry(r *request.Request) bool { + // for unit test + failpoint.Inject("replace-error-to-connection-reset-by-peer", func(_ failpoint.Value) { + log.Info("original error", zap.Error(r.Error)) + if r.Error != nil { + r.Error = errors.New("read tcp *.*.*.*:*->*.*.*.*:*: read: connection reset by peer") + } + }) + if r.HTTPRequest.URL.Host == ec2MetaAddress && (isDeadlineExceedError(r.Error) || isConnectionResetError(r.Error)) { + // fast fail for unreachable linklocal address in EC2 containers. + log.Warn("failed to get EC2 metadata. 
skipping.", logutil.ShortError(r.Error)) + return false + } + if isConnectionResetError(r.Error) { + return true + } + if isConnectionRefusedError(r.Error) { + return false + } + return rl.DefaultRetryer.ShouldRetry(r) +} + +func (rl retryerWithLog) RetryRules(r *request.Request) time.Duration { + backoffTime := rl.DefaultRetryer.RetryRules(r) + if backoffTime > 0 { + log.Warn("failed to request s3, retrying", zap.Error(r.Error), zap.Duration("backoff", backoffTime)) + } + return backoffTime +} + +func defaultS3Retryer() request.Retryer { + return retryerWithLog{ + DefaultRetryer: client.DefaultRetryer{ + NumMaxRetries: maxRetries, + MinRetryDelay: 1 * time.Second, + MinThrottleDelay: 2 * time.Second, + }, + } +} diff --git a/br/pkg/streamhelper/advancer.go b/br/pkg/streamhelper/advancer.go index 6d477994ef07f..1b21c8da19e59 100644 --- a/br/pkg/streamhelper/advancer.go +++ b/br/pkg/streamhelper/advancer.go @@ -138,9 +138,9 @@ func (c *checkpoint) equal(o *checkpoint) bool { // we should try to resolve lock for the range // to keep the RPO in 5 min. func (c *checkpoint) needResolveLocks() bool { - failpoint.Inject("NeedResolveLocks", func(val failpoint.Value) { - failpoint.Return(val.(bool)) - }) + if val, _err_ := failpoint.Eval(_curpkg_("NeedResolveLocks")); _err_ == nil { + return val.(bool) + } return time.Since(c.resolveLockTime) > 3*time.Minute } @@ -532,7 +532,7 @@ func (c *CheckpointAdvancer) SpawnSubscriptionHandler(ctx context.Context) { if !ok { return } - failpoint.Inject("subscription-handler-loop", func() {}) + failpoint.Eval(_curpkg_("subscription-handler-loop")) c.WithCheckpoints(func(vsf *spans.ValueSortedFull) { if vsf == nil { log.Warn("Span tree not found, perhaps stale event of removed tasks.", @@ -555,7 +555,7 @@ func (c *CheckpointAdvancer) subscribeTick(ctx context.Context) error { if c.subscriber == nil { return nil } - failpoint.Inject("get_subscriber", nil) + failpoint.Eval(_curpkg_("get_subscriber")) if err := c.subscriber.UpdateStoreTopology(ctx); err != nil { log.Warn("Error when updating store topology.", zap.String("category", "log backup advancer"), logutil.ShortError(err)) @@ -684,7 +684,7 @@ func (c *CheckpointAdvancer) asyncResolveLocksForRanges(ctx context.Context, tar // run in another goroutine // do not block main tick here go func() { - failpoint.Inject("AsyncResolveLocks", func() {}) + failpoint.Eval(_curpkg_("AsyncResolveLocks")) handler := func(ctx context.Context, r tikvstore.KeyRange) (rangetask.TaskStat, error) { // we will scan all locks and try to resolve them by check txn status. return tikv.ResolveLocksForRange( diff --git a/br/pkg/streamhelper/advancer.go__failpoint_stash__ b/br/pkg/streamhelper/advancer.go__failpoint_stash__ new file mode 100644 index 0000000000000..6d477994ef07f --- /dev/null +++ b/br/pkg/streamhelper/advancer.go__failpoint_stash__ @@ -0,0 +1,735 @@ +// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. 
+ +package streamhelper + +import ( + "bytes" + "context" + "fmt" + "math" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/streamhelper/config" + "github.com/pingcap/tidb/br/pkg/streamhelper/spans" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/util" + tikvstore "github.com/tikv/client-go/v2/kv" + "github.com/tikv/client-go/v2/oracle" + "github.com/tikv/client-go/v2/tikv" + "github.com/tikv/client-go/v2/txnkv/rangetask" + "go.uber.org/multierr" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" +) + +// CheckpointAdvancer is the central node for advancing the checkpoint of log backup. +// It's a part of "checkpoint v3". +// Generally, it scan the regions in the task range, collect checkpoints from tikvs. +/* + ┌──────┐ + ┌────►│ TiKV │ + │ └──────┘ + │ + │ + ┌──────────┐GetLastFlushTSOfRegion│ ┌──────┐ + │ Advancer ├──────────────────────┼────►│ TiKV │ + └────┬─────┘ │ └──────┘ + │ │ + │ │ + │ │ ┌──────┐ + │ └────►│ TiKV │ + │ └──────┘ + │ + │ UploadCheckpointV3 ┌──────────────────┐ + └─────────────────────►│ PD │ + └──────────────────┘ +*/ +type CheckpointAdvancer struct { + env Env + + // The concurrency accessed task: + // both by the task listener and ticking. + task *backuppb.StreamBackupTaskInfo + taskRange []kv.KeyRange + taskMu sync.Mutex + + // the read-only config. + // once tick begin, this should not be changed for now. + cfg config.Config + + // the cached last checkpoint. + // if no progress, this cache can help us don't to send useless requests. + lastCheckpoint *checkpoint + lastCheckpointMu sync.Mutex + inResolvingLock atomic.Bool + isPaused atomic.Bool + + checkpoints *spans.ValueSortedFull + checkpointsMu sync.Mutex + + subscriber *FlushSubscriber + subscriberMu sync.Mutex +} + +// HasTask returns whether the advancer has been bound to a task. +func (c *CheckpointAdvancer) HasTask() bool { + c.taskMu.Lock() + defer c.taskMu.Unlock() + + return c.task != nil +} + +// HasSubscriber returns whether the advancer is associated with a subscriber. +func (c *CheckpointAdvancer) HasSubscribion() bool { + c.subscriberMu.Lock() + defer c.subscriberMu.Unlock() + + return c.subscriber != nil && len(c.subscriber.subscriptions) > 0 +} + +// checkpoint represents the TS with specific range. +// it's only used in advancer.go. +type checkpoint struct { + StartKey []byte + EndKey []byte + TS uint64 + + // It's better to use PD timestamp in future, for now + // use local time to decide the time to resolve lock is ok. + resolveLockTime time.Time +} + +func newCheckpointWithTS(ts uint64) *checkpoint { + return &checkpoint{ + TS: ts, + resolveLockTime: time.Now(), + } +} + +func NewCheckpointWithSpan(s spans.Valued) *checkpoint { + return &checkpoint{ + StartKey: s.Key.StartKey, + EndKey: s.Key.EndKey, + TS: s.Value, + resolveLockTime: time.Now(), + } +} + +func (c *checkpoint) safeTS() uint64 { + return c.TS - 1 +} + +func (c *checkpoint) equal(o *checkpoint) bool { + return bytes.Equal(c.StartKey, o.StartKey) && + bytes.Equal(c.EndKey, o.EndKey) && c.TS == o.TS +} + +// if a checkpoint stay in a time too long(3 min) +// we should try to resolve lock for the range +// to keep the RPO in 5 min. 
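+// (Tests can override this decision through the "NeedResolveLocks" failpoint.)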
+func (c *checkpoint) needResolveLocks() bool { + failpoint.Inject("NeedResolveLocks", func(val failpoint.Value) { + failpoint.Return(val.(bool)) + }) + return time.Since(c.resolveLockTime) > 3*time.Minute +} + +// NewCheckpointAdvancer creates a checkpoint advancer with the env. +func NewCheckpointAdvancer(env Env) *CheckpointAdvancer { + return &CheckpointAdvancer{ + env: env, + cfg: config.Default(), + } +} + +// UpdateConfig updates the config for the advancer. +// Note this should be called before starting the loop, because there isn't locks, +// TODO: support updating config when advancer starts working. +// (Maybe by applying changes at begin of ticking, and add locks.) +func (c *CheckpointAdvancer) UpdateConfig(newConf config.Config) { + c.cfg = newConf +} + +// UpdateConfigWith updates the config by modifying the current config. +func (c *CheckpointAdvancer) UpdateConfigWith(f func(*config.Config)) { + cfg := c.cfg + f(&cfg) + c.UpdateConfig(cfg) +} + +// UpdateLastCheckpoint modify the checkpoint in ticking. +func (c *CheckpointAdvancer) UpdateLastCheckpoint(p *checkpoint) { + c.lastCheckpointMu.Lock() + c.lastCheckpoint = p + c.lastCheckpointMu.Unlock() +} + +// Config returns the current config. +func (c *CheckpointAdvancer) Config() config.Config { + return c.cfg +} + +// GetInResolvingLock only used for test. +func (c *CheckpointAdvancer) GetInResolvingLock() bool { + return c.inResolvingLock.Load() +} + +// GetCheckpointInRange scans the regions in the range, +// collect them to the collector. +func (c *CheckpointAdvancer) GetCheckpointInRange(ctx context.Context, start, end []byte, + collector *clusterCollector) error { + log.Debug("scanning range", logutil.Key("start", start), logutil.Key("end", end)) + iter := IterateRegion(c.env, start, end) + for !iter.Done() { + rs, err := iter.Next(ctx) + if err != nil { + return err + } + log.Debug("scan region", zap.Int("len", len(rs))) + for _, r := range rs { + err := collector.CollectRegion(r) + if err != nil { + log.Warn("meet error during getting checkpoint", logutil.ShortError(err)) + return err + } + } + } + return nil +} + +func (c *CheckpointAdvancer) recordTimeCost(message string, fields ...zap.Field) func() { + now := time.Now() + label := strings.ReplaceAll(message, " ", "-") + return func() { + cost := time.Since(now) + fields = append(fields, zap.Stringer("take", cost)) + metrics.AdvancerTickDuration.WithLabelValues(label).Observe(cost.Seconds()) + log.Debug(message, fields...) + } +} + +// tryAdvance tries to advance the checkpoint ts of a set of ranges which shares the same checkpoint. 
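+// It collapses the input ranges, intersects them with the task's key ranges,
+// and collects region checkpoints concurrently via a worker pool, merging each
+// successful result into c.checkpoints.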
+func (c *CheckpointAdvancer) tryAdvance(ctx context.Context, length int, + getRange func(int) kv.KeyRange) (err error) { + defer c.recordTimeCost("try advance", zap.Int("len", length))() + defer utils.PanicToErr(&err) + + ranges := spans.Collapse(length, getRange) + workers := util.NewWorkerPool(uint(config.DefaultMaxConcurrencyAdvance)*4, "sub ranges") + eg, cx := errgroup.WithContext(ctx) + collector := NewClusterCollector(ctx, c.env) + collector.SetOnSuccessHook(func(u uint64, kr kv.KeyRange) { + c.checkpointsMu.Lock() + defer c.checkpointsMu.Unlock() + c.checkpoints.Merge(spans.Valued{Key: kr, Value: u}) + }) + clampedRanges := utils.IntersectAll(ranges, utils.CloneSlice(c.taskRange)) + for _, r := range clampedRanges { + r := r + workers.ApplyOnErrorGroup(eg, func() (e error) { + defer c.recordTimeCost("get regions in range")() + defer utils.PanicToErr(&e) + return c.GetCheckpointInRange(cx, r.StartKey, r.EndKey, collector) + }) + } + err = eg.Wait() + if err != nil { + return err + } + + _, err = collector.Finish(ctx) + if err != nil { + return err + } + return nil +} + +func tsoBefore(n time.Duration) uint64 { + now := time.Now() + return oracle.ComposeTS(now.UnixMilli()-n.Milliseconds(), 0) +} + +func tsoAfter(ts uint64, n time.Duration) uint64 { + return oracle.GoTimeToTS(oracle.GetTimeFromTS(ts).Add(n)) +} + +func (c *CheckpointAdvancer) WithCheckpoints(f func(*spans.ValueSortedFull)) { + c.checkpointsMu.Lock() + defer c.checkpointsMu.Unlock() + + f(c.checkpoints) +} + +// only used for test +func (c *CheckpointAdvancer) NewCheckpoints(cps *spans.ValueSortedFull) { + c.checkpoints = cps +} + +func (c *CheckpointAdvancer) fetchRegionHint(ctx context.Context, startKey []byte) string { + region, err := locateKeyOfRegion(ctx, c.env, startKey) + if err != nil { + return errors.Annotate(err, "failed to fetch region").Error() + } + r := region.Region + l := region.Leader + prs := []int{} + for _, p := range r.GetPeers() { + prs = append(prs, int(p.StoreId)) + } + metrics.LogBackupCurrentLastRegionID.Set(float64(r.Id)) + metrics.LogBackupCurrentLastRegionLeaderStoreID.Set(float64(l.StoreId)) + return fmt.Sprintf("ID=%d,Leader=%d,ConfVer=%d,Version=%d,Peers=%v,RealRange=%s", + r.GetId(), l.GetStoreId(), r.GetRegionEpoch().GetConfVer(), r.GetRegionEpoch().GetVersion(), + prs, logutil.StringifyRangeOf(r.GetStartKey(), r.GetEndKey())) +} + +func (c *CheckpointAdvancer) CalculateGlobalCheckpointLight(ctx context.Context, + threshold time.Duration) (spans.Valued, error) { + var targets []spans.Valued + var minValue spans.Valued + thresholdTso := tsoBefore(threshold) + c.WithCheckpoints(func(vsf *spans.ValueSortedFull) { + vsf.TraverseValuesLessThan(thresholdTso, func(v spans.Valued) bool { + targets = append(targets, v) + return true + }) + minValue = vsf.Min() + }) + sctx, cancel := context.WithTimeout(ctx, time.Second) + // Always fetch the hint and update the metrics. 
+ hint := c.fetchRegionHint(sctx, minValue.Key.StartKey) + logger := log.Debug + if minValue.Value < thresholdTso { + logger = log.Info + } + logger("current last region", zap.String("category", "log backup advancer hint"), + zap.Stringer("min", minValue), zap.Int("for-polling", len(targets)), + zap.String("min-ts", oracle.GetTimeFromTS(minValue.Value).Format(time.RFC3339)), + zap.String("region-hint", hint), + ) + cancel() + if len(targets) == 0 { + return minValue, nil + } + err := c.tryAdvance(ctx, len(targets), func(i int) kv.KeyRange { return targets[i].Key }) + if err != nil { + return minValue, err + } + return minValue, nil +} + +func (c *CheckpointAdvancer) consumeAllTask(ctx context.Context, ch <-chan TaskEvent) error { + for { + select { + case e, ok := <-ch: + if !ok { + return nil + } + log.Info("meet task event", zap.Stringer("event", &e)) + if err := c.onTaskEvent(ctx, e); err != nil { + if errors.Cause(e.Err) != context.Canceled { + log.Error("listen task meet error, would reopen.", logutil.ShortError(err)) + return err + } + return nil + } + default: + return nil + } + } +} + +// beginListenTaskChange bootstraps the initial task set, +// and returns a channel respecting the change of tasks. +func (c *CheckpointAdvancer) beginListenTaskChange(ctx context.Context) (<-chan TaskEvent, error) { + ch := make(chan TaskEvent, 1024) + if err := c.env.Begin(ctx, ch); err != nil { + return nil, err + } + err := c.consumeAllTask(ctx, ch) + if err != nil { + return nil, err + } + return ch, nil +} + +// StartTaskListener starts the task listener for the advancer. +// When no task detected, advancer would do nothing, please call this before begin the tick loop. +func (c *CheckpointAdvancer) StartTaskListener(ctx context.Context) { + cx, cancel := context.WithCancel(ctx) + var ch <-chan TaskEvent + for { + if cx.Err() != nil { + // make linter happy. 
+ cancel() + return + } + var err error + ch, err = c.beginListenTaskChange(cx) + if err == nil { + break + } + log.Warn("failed to begin listening, retrying...", logutil.ShortError(err)) + time.Sleep(c.cfg.BackoffTime) + } + + go func() { + defer cancel() + for { + select { + case <-ctx.Done(): + return + case e, ok := <-ch: + if !ok { + log.Info("Task watcher exits due to stream ends.", zap.String("category", "log backup advancer")) + return + } + log.Info("Meet task event", zap.String("category", "log backup advancer"), zap.Stringer("event", &e)) + if err := c.onTaskEvent(ctx, e); err != nil { + if errors.Cause(e.Err) != context.Canceled { + log.Error("listen task meet error, would reopen.", logutil.ShortError(err)) + time.AfterFunc(c.cfg.BackoffTime, func() { c.StartTaskListener(ctx) }) + } + log.Info("Task watcher exits due to some error.", zap.String("category", "log backup advancer"), + logutil.ShortError(err)) + return + } + } + } + }() +} + +func (c *CheckpointAdvancer) setCheckpoints(cps *spans.ValueSortedFull) { + c.checkpointsMu.Lock() + c.checkpoints = cps + c.checkpointsMu.Unlock() +} + +func (c *CheckpointAdvancer) onTaskEvent(ctx context.Context, e TaskEvent) error { + c.taskMu.Lock() + defer c.taskMu.Unlock() + switch e.Type { + case EventAdd: + utils.LogBackupTaskCountInc() + c.task = e.Info + c.taskRange = spans.Collapse(len(e.Ranges), func(i int) kv.KeyRange { return e.Ranges[i] }) + c.setCheckpoints(spans.Sorted(spans.NewFullWith(e.Ranges, 0))) + globalCheckpointTs, err := c.env.GetGlobalCheckpointForTask(ctx, e.Name) + if err != nil { + log.Error("failed to get global checkpoint, skipping.", logutil.ShortError(err)) + return err + } + if globalCheckpointTs < c.task.StartTs { + globalCheckpointTs = c.task.StartTs + } + log.Info("get global checkpoint", zap.Uint64("checkpoint", globalCheckpointTs)) + c.lastCheckpoint = newCheckpointWithTS(globalCheckpointTs) + p, err := c.env.BlockGCUntil(ctx, globalCheckpointTs) + if err != nil { + log.Warn("failed to upload service GC safepoint, skipping.", logutil.ShortError(err)) + } + log.Info("added event", zap.Stringer("task", e.Info), + zap.Stringer("ranges", logutil.StringifyKeys(c.taskRange)), zap.Uint64("current-checkpoint", p)) + case EventDel: + utils.LogBackupTaskCountDec() + c.task = nil + c.isPaused.Store(false) + c.taskRange = nil + // This would be synced by `taskMu`, perhaps we'd better rename that to `tickMu`. + // Do the null check because some of test cases won't equip the advancer with subscriber. + if c.subscriber != nil { + c.subscriber.Clear() + } + c.setCheckpoints(nil) + if err := c.env.ClearV3GlobalCheckpointForTask(ctx, e.Name); err != nil { + log.Warn("failed to clear global checkpoint", logutil.ShortError(err)) + } + if err := c.env.UnblockGC(ctx); err != nil { + log.Warn("failed to remove service GC safepoint", logutil.ShortError(err)) + } + metrics.LastCheckpoint.DeleteLabelValues(e.Name) + case EventPause: + if c.task.GetName() == e.Name { + c.isPaused.Store(true) + } + case EventResume: + if c.task.GetName() == e.Name { + c.isPaused.Store(false) + } + case EventErr: + return e.Err + } + return nil +} + +func (c *CheckpointAdvancer) setCheckpoint(ctx context.Context, s spans.Valued) bool { + cp := NewCheckpointWithSpan(s) + if cp.TS < c.lastCheckpoint.TS { + log.Warn("failed to update global checkpoint: stale", + zap.Uint64("old", c.lastCheckpoint.TS), zap.Uint64("new", cp.TS)) + return false + } + // Need resolve lock for different range and same TS + // so check the range and TS here. 
+ if cp.equal(c.lastCheckpoint) { + return false + } + c.UpdateLastCheckpoint(cp) + metrics.LastCheckpoint.WithLabelValues(c.task.GetName()).Set(float64(c.lastCheckpoint.TS)) + return true +} + +// advanceCheckpointBy advances the checkpoint by a checkpoint getter function. +func (c *CheckpointAdvancer) advanceCheckpointBy(ctx context.Context, + getCheckpoint func(context.Context) (spans.Valued, error)) error { + start := time.Now() + cp, err := getCheckpoint(ctx) + if err != nil { + return err + } + + if c.setCheckpoint(ctx, cp) { + log.Info("uploading checkpoint for task", + zap.Stringer("checkpoint", oracle.GetTimeFromTS(cp.Value)), + zap.Uint64("checkpoint", cp.Value), + zap.String("task", c.task.Name), + zap.Stringer("take", time.Since(start))) + } + return nil +} + +func (c *CheckpointAdvancer) stopSubscriber() { + c.subscriberMu.Lock() + defer c.subscriberMu.Unlock() + c.subscriber.Drop() + c.subscriber = nil +} + +func (c *CheckpointAdvancer) SpawnSubscriptionHandler(ctx context.Context) { + c.subscriberMu.Lock() + defer c.subscriberMu.Unlock() + c.subscriber = NewSubscriber(c.env, c.env, WithMasterContext(ctx)) + es := c.subscriber.Events() + log.Info("Subscription handler spawned.", zap.String("category", "log backup subscription manager")) + + go func() { + defer utils.CatchAndLogPanic() + for { + select { + case <-ctx.Done(): + return + case event, ok := <-es: + if !ok { + return + } + failpoint.Inject("subscription-handler-loop", func() {}) + c.WithCheckpoints(func(vsf *spans.ValueSortedFull) { + if vsf == nil { + log.Warn("Span tree not found, perhaps stale event of removed tasks.", + zap.String("category", "log backup subscription manager")) + return + } + log.Debug("Accepting region flush event.", + zap.Stringer("range", logutil.StringifyRange(event.Key)), + zap.Uint64("checkpoint", event.Value)) + vsf.Merge(event) + }) + } + } + }() +} + +func (c *CheckpointAdvancer) subscribeTick(ctx context.Context) error { + c.subscriberMu.Lock() + defer c.subscriberMu.Unlock() + if c.subscriber == nil { + return nil + } + failpoint.Inject("get_subscriber", nil) + if err := c.subscriber.UpdateStoreTopology(ctx); err != nil { + log.Warn("Error when updating store topology.", + zap.String("category", "log backup advancer"), logutil.ShortError(err)) + } + c.subscriber.HandleErrors(ctx) + return c.subscriber.PendingErrors() +} + +func (c *CheckpointAdvancer) isCheckpointLagged(ctx context.Context) (bool, error) { + if c.cfg.CheckPointLagLimit <= 0 { + return false, nil + } + + now, err := c.env.FetchCurrentTS(ctx) + if err != nil { + return false, err + } + + lagDuration := oracle.GetTimeFromTS(now).Sub(oracle.GetTimeFromTS(c.lastCheckpoint.TS)) + if lagDuration > c.cfg.CheckPointLagLimit { + log.Warn("checkpoint lag is too large", zap.String("category", "log backup advancer"), + zap.Stringer("lag", lagDuration)) + return true, nil + } + return false, nil +} + +func (c *CheckpointAdvancer) importantTick(ctx context.Context) error { + c.checkpointsMu.Lock() + c.setCheckpoint(ctx, c.checkpoints.Min()) + c.checkpointsMu.Unlock() + if err := c.env.UploadV3GlobalCheckpointForTask(ctx, c.task.Name, c.lastCheckpoint.TS); err != nil { + return errors.Annotate(err, "failed to upload global checkpoint") + } + isLagged, err := c.isCheckpointLagged(ctx) + if err != nil { + return errors.Annotate(err, "failed to check timestamp") + } + if isLagged { + err := c.env.PauseTask(ctx, c.task.Name) + if err != nil { + return errors.Annotate(err, "failed to pause task") + } + return 
errors.Annotate(errors.Errorf("check point lagged too large"), "check point lagged too large") + } + p, err := c.env.BlockGCUntil(ctx, c.lastCheckpoint.safeTS()) + if err != nil { + return errors.Annotatef(err, + "failed to update service GC safe point, current checkpoint is %d, target checkpoint is %d", + c.lastCheckpoint.safeTS(), p) + } + if p <= c.lastCheckpoint.safeTS() { + log.Info("updated log backup GC safe point.", + zap.Uint64("checkpoint", p), zap.Uint64("target", c.lastCheckpoint.safeTS())) + } + if p > c.lastCheckpoint.safeTS() { + log.Warn("update log backup GC safe point failed: stale.", + zap.Uint64("checkpoint", p), zap.Uint64("target", c.lastCheckpoint.safeTS())) + } + return nil +} + +func (c *CheckpointAdvancer) optionalTick(cx context.Context) error { + // lastCheckpoint is not increased too long enough. + // assume the cluster has expired locks for whatever reasons. + var targets []spans.Valued + if c.lastCheckpoint != nil && c.lastCheckpoint.needResolveLocks() && c.inResolvingLock.CompareAndSwap(false, true) { + c.WithCheckpoints(func(vsf *spans.ValueSortedFull) { + // when get locks here. assume these locks are not belong to same txn, + // but these locks' start ts are close to 1 minute. try resolve these locks at one time + vsf.TraverseValuesLessThan(tsoAfter(c.lastCheckpoint.TS, time.Minute), func(v spans.Valued) bool { + targets = append(targets, v) + return true + }) + }) + if len(targets) != 0 { + log.Info("Advancer starts to resolve locks", zap.Int("targets", len(targets))) + // use new context here to avoid timeout + ctx := context.Background() + c.asyncResolveLocksForRanges(ctx, targets) + } else { + // don't forget set state back + c.inResolvingLock.Store(false) + } + } + threshold := c.Config().GetDefaultStartPollThreshold() + if err := c.subscribeTick(cx); err != nil { + log.Warn("Subscriber meet error, would polling the checkpoint.", zap.String("category", "log backup advancer"), + logutil.ShortError(err)) + threshold = c.Config().GetSubscriberErrorStartPollThreshold() + } + + return c.advanceCheckpointBy(cx, func(cx context.Context) (spans.Valued, error) { + return c.CalculateGlobalCheckpointLight(cx, threshold) + }) +} + +func (c *CheckpointAdvancer) tick(ctx context.Context) error { + c.taskMu.Lock() + defer c.taskMu.Unlock() + if c.task == nil || c.isPaused.Load() { + log.Debug("No tasks yet, skipping advancing.") + return nil + } + + var errs error + + cx, cancel := context.WithTimeout(ctx, c.Config().TickTimeout()) + defer cancel() + err := c.optionalTick(cx) + if err != nil { + log.Warn("option tick failed.", zap.String("category", "log backup advancer"), logutil.ShortError(err)) + errs = multierr.Append(errs, err) + } + + err = c.importantTick(ctx) + if err != nil { + log.Warn("important tick failed.", zap.String("category", "log backup advancer"), logutil.ShortError(err)) + errs = multierr.Append(errs, err) + } + + return errs +} + +func (c *CheckpointAdvancer) asyncResolveLocksForRanges(ctx context.Context, targets []spans.Valued) { + // run in another goroutine + // do not block main tick here + go func() { + failpoint.Inject("AsyncResolveLocks", func() {}) + handler := func(ctx context.Context, r tikvstore.KeyRange) (rangetask.TaskStat, error) { + // we will scan all locks and try to resolve them by check txn status. 
+ return tikv.ResolveLocksForRange( + ctx, c.env, math.MaxUint64, r.StartKey, r.EndKey, tikv.NewGcResolveLockMaxBackoffer, tikv.GCScanLockLimit) + } + workerPool := util.NewWorkerPool(uint(config.DefaultMaxConcurrencyAdvance), "advancer resolve locks") + var wg sync.WaitGroup + for _, r := range targets { + targetRange := r + wg.Add(1) + workerPool.Apply(func() { + defer wg.Done() + // Run resolve lock on the whole TiKV cluster. + // it will use startKey/endKey to scan region in PD. + // but regionCache already has a codecPDClient. so just use decode key here. + // and it almost only include one region here. so set concurrency to 1. + runner := rangetask.NewRangeTaskRunner("advancer-resolve-locks-runner", + c.env.GetStore(), 1, handler) + err := runner.RunOnRange(ctx, targetRange.Key.StartKey, targetRange.Key.EndKey) + if err != nil { + // wait for next tick + log.Warn("resolve locks failed, wait for next tick", zap.String("category", "advancer"), + zap.String("uuid", "log backup advancer"), + zap.Error(err)) + } + }) + } + wg.Wait() + log.Info("finish resolve locks for checkpoint", zap.String("category", "advancer"), + zap.String("uuid", "log backup advancer"), + logutil.Key("StartKey", c.lastCheckpoint.StartKey), + logutil.Key("EndKey", c.lastCheckpoint.EndKey), + zap.Int("targets", len(targets))) + c.lastCheckpointMu.Lock() + c.lastCheckpoint.resolveLockTime = time.Now() + c.lastCheckpointMu.Unlock() + c.inResolvingLock.Store(false) + }() +} + +func (c *CheckpointAdvancer) TEST_registerCallbackForSubscriptions(f func()) int { + cnt := 0 + for _, sub := range c.subscriber.subscriptions { + sub.onDaemonExit = f + cnt += 1 + } + return cnt +} diff --git a/br/pkg/streamhelper/advancer_cliext.go b/br/pkg/streamhelper/advancer_cliext.go index 1411c306c3abd..f283120549451 100644 --- a/br/pkg/streamhelper/advancer_cliext.go +++ b/br/pkg/streamhelper/advancer_cliext.go @@ -183,10 +183,10 @@ func (t AdvancerExt) startListen(ctx context.Context, rev int64, ch chan<- TaskE for { select { case resp, ok := <-taskCh: - failpoint.Inject("advancer_close_channel", func() { + if _, _err_ := failpoint.Eval(_curpkg_("advancer_close_channel")); _err_ == nil { // We cannot really close the channel, just simulating it. ok = false - }) + } if !ok { ch <- errorEvent(io.EOF) return @@ -195,10 +195,10 @@ func (t AdvancerExt) startListen(ctx context.Context, rev int64, ch chan<- TaskE return } case resp, ok := <-pauseCh: - failpoint.Inject("advancer_close_pause_channel", func() { + if _, _err_ := failpoint.Eval(_curpkg_("advancer_close_pause_channel")); _err_ == nil { // We cannot really close the channel, just simulating it. ok = false - }) + } if !ok { ch <- errorEvent(io.EOF) return diff --git a/br/pkg/streamhelper/advancer_cliext.go__failpoint_stash__ b/br/pkg/streamhelper/advancer_cliext.go__failpoint_stash__ new file mode 100644 index 0000000000000..1411c306c3abd --- /dev/null +++ b/br/pkg/streamhelper/advancer_cliext.go__failpoint_stash__ @@ -0,0 +1,301 @@ +// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. 
+ +package streamhelper + +import ( + "bytes" + "context" + "encoding/binary" + "fmt" + "io" + "strings" + + "github.com/golang/protobuf/proto" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/log" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/util/redact" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" +) + +type EventType int + +const ( + EventAdd EventType = iota + EventDel + EventErr + EventPause + EventResume +) + +func (t EventType) String() string { + switch t { + case EventAdd: + return "Add" + case EventDel: + return "Del" + case EventErr: + return "Err" + case EventPause: + return "Pause" + case EventResume: + return "Resume" + } + return "Unknown" +} + +type TaskEvent struct { + Type EventType + Name string + Info *backuppb.StreamBackupTaskInfo + Ranges []kv.KeyRange + Err error +} + +func (t *TaskEvent) String() string { + if t.Err != nil { + return fmt.Sprintf("%s(%s, err = %s)", t.Type, t.Name, t.Err) + } + return fmt.Sprintf("%s(%s)", t.Type, t.Name) +} + +type AdvancerExt struct { + MetaDataClient +} + +func errorEvent(err error) TaskEvent { + return TaskEvent{ + Type: EventErr, + Err: err, + } +} + +func (t AdvancerExt) toTaskEvent(ctx context.Context, event *clientv3.Event) (TaskEvent, error) { + te := TaskEvent{} + var prefix string + + if bytes.HasPrefix(event.Kv.Key, []byte(PrefixOfTask())) { + prefix = PrefixOfTask() + te.Name = strings.TrimPrefix(string(event.Kv.Key), prefix) + } else if bytes.HasPrefix(event.Kv.Key, []byte(PrefixOfPause())) { + prefix = PrefixOfPause() + te.Name = strings.TrimPrefix(string(event.Kv.Key), prefix) + } else { + return TaskEvent{}, + errors.Annotatef(berrors.ErrInvalidArgument, "the path isn't a task/pause path (%s)", + string(event.Kv.Key)) + } + + switch { + case event.Type == clientv3.EventTypePut && prefix == PrefixOfTask(): + te.Type = EventAdd + case event.Type == clientv3.EventTypeDelete && prefix == PrefixOfTask(): + te.Type = EventDel + case event.Type == clientv3.EventTypePut && prefix == PrefixOfPause(): + te.Type = EventPause + case event.Type == clientv3.EventTypeDelete && prefix == PrefixOfPause(): + te.Type = EventResume + default: + return TaskEvent{}, + errors.Annotatef(berrors.ErrInvalidArgument, + "invalid event type or prefix: type=%s, prefix=%s", event.Type, prefix) + } + + te.Info = new(backuppb.StreamBackupTaskInfo) + if err := proto.Unmarshal(event.Kv.Value, te.Info); err != nil { + return TaskEvent{}, err + } + + var err error + te.Ranges, err = t.MetaDataClient.TaskByInfo(*te.Info).Ranges(ctx) + if err != nil { + return TaskEvent{}, err + } + + return te, nil +} + +func (t AdvancerExt) eventFromWatch(ctx context.Context, resp clientv3.WatchResponse) ([]TaskEvent, error) { + result := make([]TaskEvent, 0, len(resp.Events)) + if err := resp.Err(); err != nil { + return nil, err + } + for _, event := range resp.Events { + te, err := t.toTaskEvent(ctx, event) + if err != nil { + te.Type = EventErr + te.Err = err + } + result = append(result, te) + } + return result, nil +} + +func (t AdvancerExt) startListen(ctx context.Context, rev int64, ch chan<- TaskEvent) { + taskCh := t.Client.Watcher.Watch(ctx, PrefixOfTask(), clientv3.WithPrefix(), clientv3.WithRev(rev)) + pauseCh := t.Client.Watcher.Watch(ctx, PrefixOfPause(), clientv3.WithPrefix(), clientv3.WithRev(rev)) + + // inner function def + handleResponse := 
func(resp clientv3.WatchResponse) bool { + events, err := t.eventFromWatch(ctx, resp) + if err != nil { + log.Warn("Meet error during receiving the task event.", + zap.String("category", "log backup advancer"), logutil.ShortError(err)) + ch <- errorEvent(err) + return false + } + for _, event := range events { + ch <- event + } + return true + } + + // inner function def + collectRemaining := func() { + log.Info("Start collecting remaining events in the channel.", zap.String("category", "log backup advancer"), + zap.Int("remained", len(taskCh))) + defer log.Info("Finish collecting remaining events in the channel.", zap.String("category", "log backup advancer")) + for { + if taskCh == nil && pauseCh == nil { + return + } + + select { + case resp, ok := <-taskCh: + if !ok || !handleResponse(resp) { + taskCh = nil + } + case resp, ok := <-pauseCh: + if !ok || !handleResponse(resp) { + pauseCh = nil + } + } + } + } + + go func() { + defer close(ch) + for { + select { + case resp, ok := <-taskCh: + failpoint.Inject("advancer_close_channel", func() { + // We cannot really close the channel, just simulating it. + ok = false + }) + if !ok { + ch <- errorEvent(io.EOF) + return + } + if !handleResponse(resp) { + return + } + case resp, ok := <-pauseCh: + failpoint.Inject("advancer_close_pause_channel", func() { + // We cannot really close the channel, just simulating it. + ok = false + }) + if !ok { + ch <- errorEvent(io.EOF) + return + } + if !handleResponse(resp) { + return + } + case <-ctx.Done(): + collectRemaining() + ch <- errorEvent(ctx.Err()) + return + } + } + }() +} + +func (t AdvancerExt) getFullTasksAsEvent(ctx context.Context) ([]TaskEvent, int64, error) { + tasks, rev, err := t.GetAllTasksWithRevision(ctx) + if err != nil { + return nil, 0, err + } + events := make([]TaskEvent, 0, len(tasks)) + for _, task := range tasks { + ranges, err := task.Ranges(ctx) + if err != nil { + return nil, 0, err + } + te := TaskEvent{ + Type: EventAdd, + Name: task.Info.Name, + Info: &(task.Info), + Ranges: ranges, + } + events = append(events, te) + } + return events, rev, nil +} + +func (t AdvancerExt) Begin(ctx context.Context, ch chan<- TaskEvent) error { + initialTasks, rev, err := t.getFullTasksAsEvent(ctx) + if err != nil { + return err + } + // Note: maybe `go` here so we won't block? 
+ for _, task := range initialTasks { + ch <- task + } + t.startListen(ctx, rev+1, ch) + return nil +} + +func (t AdvancerExt) GetGlobalCheckpointForTask(ctx context.Context, taskName string) (uint64, error) { + key := GlobalCheckpointOf(taskName) + resp, err := t.KV.Get(ctx, key) + if err != nil { + return 0, err + } + + if len(resp.Kvs) == 0 { + return 0, nil + } + + firstKV := resp.Kvs[0] + value := firstKV.Value + if len(value) != 8 { + return 0, errors.Annotatef(berrors.ErrPiTRMalformedMetadata, + "the global checkpoint isn't 64bits (it is %d bytes, value = %s)", + len(value), + redact.Key(value)) + } + + return binary.BigEndian.Uint64(value), nil +} + +func (t AdvancerExt) UploadV3GlobalCheckpointForTask(ctx context.Context, taskName string, checkpoint uint64) error { + key := GlobalCheckpointOf(taskName) + value := string(encodeUint64(checkpoint)) + oldValue, err := t.GetGlobalCheckpointForTask(ctx, taskName) + if err != nil { + return err + } + + if checkpoint < oldValue { + log.Warn("skipping upload global checkpoint", zap.String("category", "log backup advancer"), + zap.Uint64("old", oldValue), zap.Uint64("new", checkpoint)) + return nil + } + + _, err = t.KV.Put(ctx, key, value) + if err != nil { + return err + } + return nil +} + +func (t AdvancerExt) ClearV3GlobalCheckpointForTask(ctx context.Context, taskName string) error { + key := GlobalCheckpointOf(taskName) + _, err := t.KV.Delete(ctx, key) + return err +} diff --git a/br/pkg/streamhelper/binding__failpoint_binding__.go b/br/pkg/streamhelper/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..0872efd5448a7 --- /dev/null +++ b/br/pkg/streamhelper/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package streamhelper + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/br/pkg/streamhelper/flush_subscriber.go b/br/pkg/streamhelper/flush_subscriber.go index f1310b0212372..8d32adb576749 100644 --- a/br/pkg/streamhelper/flush_subscriber.go +++ b/br/pkg/streamhelper/flush_subscriber.go @@ -102,10 +102,10 @@ func (f *FlushSubscriber) UpdateStoreTopology(ctx context.Context) error { // Clear clears all the subscriptions. func (f *FlushSubscriber) Clear() { timeout := clearSubscriberTimeOut - failpoint.Inject("FlushSubscriber.Clear.timeoutMs", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("FlushSubscriber.Clear.timeoutMs")); _err_ == nil { //nolint:durationcheck timeout = time.Duration(v.(int)) * time.Millisecond - }) + } log.Info("Clearing.", zap.String("category", "log backup flush subscriber"), zap.Duration("timeout", timeout)) @@ -302,7 +302,7 @@ func (s *subscription) listenOver(ctx context.Context, cli eventStream) { logutil.Key("event", m.EndKey), logutil.ShortError(err)) continue } - failpoint.Inject("subscription.listenOver.aboutToSend", func() {}) + failpoint.Eval(_curpkg_("subscription.listenOver.aboutToSend")) evt := spans.Valued{ Key: spans.Span{ diff --git a/br/pkg/streamhelper/flush_subscriber.go__failpoint_stash__ b/br/pkg/streamhelper/flush_subscriber.go__failpoint_stash__ new file mode 100644 index 0000000000000..f1310b0212372 --- /dev/null +++ b/br/pkg/streamhelper/flush_subscriber.go__failpoint_stash__ @@ -0,0 +1,373 @@ +// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. 
+ +package streamhelper + +import ( + "context" + "io" + "strconv" + "sync" + "time" + + "github.com/google/uuid" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + logbackup "github.com/pingcap/kvproto/pkg/logbackuppb" + "github.com/pingcap/log" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/streamhelper/spans" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/util/codec" + "go.uber.org/multierr" + "go.uber.org/zap" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +const ( + // clearSubscriberTimeOut is the timeout for clearing the subscriber. + clearSubscriberTimeOut = 1 * time.Minute +) + +// FlushSubscriber maintains the state of subscribing to the cluster. +type FlushSubscriber struct { + dialer LogBackupService + cluster TiKVClusterMeta + + // Current connections. + subscriptions map[uint64]*subscription + // The output channel. + eventsTunnel chan spans.Valued + // The background context for subscribes. + masterCtx context.Context +} + +// SubscriberConfig is a config which cloud be applied into the subscriber. +type SubscriberConfig func(*FlushSubscriber) + +// WithMasterContext sets the "master context" for the subscriber, +// that context would be the "background" context for every subtasks created by the subscription manager. +func WithMasterContext(ctx context.Context) SubscriberConfig { + return func(fs *FlushSubscriber) { fs.masterCtx = ctx } +} + +// NewSubscriber creates a new subscriber via the environment and optional configs. +func NewSubscriber(dialer LogBackupService, cluster TiKVClusterMeta, config ...SubscriberConfig) *FlushSubscriber { + subs := &FlushSubscriber{ + dialer: dialer, + cluster: cluster, + + subscriptions: map[uint64]*subscription{}, + eventsTunnel: make(chan spans.Valued, 1024), + masterCtx: context.Background(), + } + + for _, c := range config { + c(subs) + } + + return subs +} + +// UpdateStoreTopology fetches the current store topology and try to adapt the subscription state with it. +func (f *FlushSubscriber) UpdateStoreTopology(ctx context.Context) error { + stores, err := f.cluster.Stores(ctx) + if err != nil { + return errors.Annotate(err, "failed to get store list") + } + + storeSet := map[uint64]struct{}{} + for _, store := range stores { + sub, ok := f.subscriptions[store.ID] + if !ok { + f.addSubscription(ctx, store) + f.subscriptions[store.ID].connect(f.masterCtx, f.dialer) + } else if sub.storeBootAt != store.BootAt { + sub.storeBootAt = store.BootAt + sub.connect(f.masterCtx, f.dialer) + } + storeSet[store.ID] = struct{}{} + } + + for id := range f.subscriptions { + _, ok := storeSet[id] + if !ok { + f.removeSubscription(ctx, id) + } + } + return nil +} + +// Clear clears all the subscriptions. +func (f *FlushSubscriber) Clear() { + timeout := clearSubscriberTimeOut + failpoint.Inject("FlushSubscriber.Clear.timeoutMs", func(v failpoint.Value) { + //nolint:durationcheck + timeout = time.Duration(v.(int)) * time.Millisecond + }) + log.Info("Clearing.", + zap.String("category", "log backup flush subscriber"), + zap.Duration("timeout", timeout)) + cx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + for id := range f.subscriptions { + f.removeSubscription(cx, id) + } +} + +// Drop terminates the lifetime of the subscriber. +// This subscriber would be no more usable. 
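+// It clears every live store subscription and then closes the events channel, so the subscriber cannot be reused.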
+func (f *FlushSubscriber) Drop() { + f.Clear() + close(f.eventsTunnel) +} + +// HandleErrors execute the handlers over all pending errors. +// Note that the handler may cannot handle the pending errors, at that time, +// you can fetch the errors via `PendingErrors` call. +func (f *FlushSubscriber) HandleErrors(ctx context.Context) { + for id, sub := range f.subscriptions { + err := sub.loadError() + if err != nil { + retry := f.canBeRetried(err) + log.Warn("Meet error.", zap.String("category", "log backup flush subscriber"), + logutil.ShortError(err), zap.Bool("can-retry?", retry), zap.Uint64("store", id)) + if retry { + sub.connect(f.masterCtx, f.dialer) + } + } + } +} + +// Events returns the output channel of the events. +func (f *FlushSubscriber) Events() <-chan spans.Valued { + return f.eventsTunnel +} + +type eventStream = logbackup.LogBackup_SubscribeFlushEventClient + +type joinHandle <-chan struct{} + +func (jh joinHandle) Wait(ctx context.Context) { + select { + case <-jh: + case <-ctx.Done(): + log.Warn("join handle timed out.", zap.StackSkip("caller", 1)) + } +} + +func spawnJoinable(f func()) joinHandle { + c := make(chan struct{}) + go func() { + defer close(c) + f() + }() + return c +} + +// subscription is the state of subscription of one store. +// initially, it is IDLE, where cancel == nil. +// once `connect` called, it goto CONNECTED, where cancel != nil and err == nil. +// once some error (both foreground or background) happens, it goto ERROR, where err != nil. +type subscription struct { + // the handle to cancel the worker goroutine. + cancel context.CancelFunc + // the handle to wait until the worker goroutine exits. + background joinHandle + errMu sync.Mutex + err error + + // Immutable state. + storeID uint64 + // We record start bootstrap time and once a store restarts + // we need to try reconnect even there is a error cannot be retry. + storeBootAt uint64 + output chan<- spans.Valued + + onDaemonExit func() +} + +func (s *subscription) emitError(err error) { + s.errMu.Lock() + defer s.errMu.Unlock() + + s.err = err +} + +func (s *subscription) loadError() error { + s.errMu.Lock() + defer s.errMu.Unlock() + + return s.err +} + +func (s *subscription) clearError() { + s.errMu.Lock() + defer s.errMu.Unlock() + + s.err = nil +} + +func newSubscription(toStore Store, output chan<- spans.Valued) *subscription { + return &subscription{ + storeID: toStore.ID, + storeBootAt: toStore.BootAt, + output: output, + } +} + +func (s *subscription) connect(ctx context.Context, dialer LogBackupService) { + err := s.doConnect(ctx, dialer) + if err != nil { + s.emitError(err) + } +} + +func (s *subscription) doConnect(ctx context.Context, dialer LogBackupService) error { + log.Info("Adding subscription.", zap.String("category", "log backup subscription manager"), + zap.Uint64("store", s.storeID), zap.Uint64("boot", s.storeBootAt)) + // We should shutdown the background task firstly. + // Once it yields some error during shuting down, the error won't be brought to next run. 
+ s.close(ctx) + s.clearError() + + c, err := dialer.GetLogBackupClient(ctx, s.storeID) + if err != nil { + return errors.Annotate(err, "failed to get log backup client") + } + cx, cancel := context.WithCancel(ctx) + cli, err := c.SubscribeFlushEvent(cx, &logbackup.SubscribeFlushEventRequest{ + ClientId: uuid.NewString(), + }) + if err != nil { + cancel() + _ = dialer.ClearCache(ctx, s.storeID) + return errors.Annotate(err, "failed to subscribe events") + } + lcx := logutil.ContextWithField(cx, zap.Uint64("store-id", s.storeID), + zap.String("category", "log backup flush subscriber")) + s.cancel = cancel + s.background = spawnJoinable(func() { s.listenOver(lcx, cli) }) + return nil +} + +func (s *subscription) close(ctx context.Context) { + if s.cancel != nil { + s.cancel() + s.background.Wait(ctx) + } + // HACK: don't close the internal channel here, + // because it is a ever-sharing channel. +} + +func (s *subscription) listenOver(ctx context.Context, cli eventStream) { + storeID := s.storeID + logutil.CL(ctx).Info("Listen starting.", zap.Uint64("store", storeID)) + defer func() { + if s.onDaemonExit != nil { + s.onDaemonExit() + } + + if pData := recover(); pData != nil { + log.Warn("Subscriber paniked.", zap.Uint64("store", storeID), zap.Any("panic-data", pData), zap.Stack("stack")) + s.emitError(errors.Annotatef(berrors.ErrUnknown, "panic during executing: %v", pData)) + } + }() + for { + // Shall we use RecvMsg for better performance? + // Note that the spans.Full requires the input slice be immutable. + msg, err := cli.Recv() + if err != nil { + logutil.CL(ctx).Info("Listen stopped.", + zap.Uint64("store", storeID), logutil.ShortError(err)) + if err == io.EOF || err == context.Canceled || status.Code(err) == codes.Canceled { + return + } + s.emitError(errors.Annotatef(err, "while receiving from store id %d", storeID)) + return + } + + log.Debug("Sending events.", zap.Int("size", len(msg.Events))) + for _, m := range msg.Events { + start, err := decodeKey(m.StartKey) + if err != nil { + logutil.CL(ctx).Warn("start key not encoded, skipping", + logutil.Key("event", m.StartKey), logutil.ShortError(err)) + continue + } + end, err := decodeKey(m.EndKey) + if err != nil { + logutil.CL(ctx).Warn("end key not encoded, skipping", + logutil.Key("event", m.EndKey), logutil.ShortError(err)) + continue + } + failpoint.Inject("subscription.listenOver.aboutToSend", func() {}) + + evt := spans.Valued{ + Key: spans.Span{ + StartKey: start, + EndKey: end, + }, + Value: m.Checkpoint, + } + select { + case s.output <- evt: + case <-ctx.Done(): + logutil.CL(ctx).Warn("Context canceled while sending events.", + zap.Uint64("store", storeID)) + return + } + } + metrics.RegionCheckpointSubscriptionEvent.WithLabelValues( + strconv.Itoa(int(storeID))).Observe(float64(len(msg.Events))) + } +} + +func (f *FlushSubscriber) addSubscription(ctx context.Context, toStore Store) { + f.subscriptions[toStore.ID] = newSubscription(toStore, f.eventsTunnel) +} + +func (f *FlushSubscriber) removeSubscription(ctx context.Context, toStore uint64) { + subs, ok := f.subscriptions[toStore] + if ok { + log.Info("Removing subscription.", zap.String("category", "log backup subscription manager"), + zap.Uint64("store", toStore)) + subs.close(ctx) + delete(f.subscriptions, toStore) + } +} + +// decodeKey decodes the key from TiKV, because the region range is encoded in TiKV. +func decodeKey(key []byte) ([]byte, error) { + if len(key) == 0 { + return key, nil + } + // Ignore the timestamp... 
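+	// codec.DecodeBytes returns (remaining, decoded, err); only the decoded key bytes are kept and the remaining encoded suffix is discarded.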
+ _, data, err := codec.DecodeBytes(key, nil) + if err != nil { + return key, err + } + return data, err +} + +func (f *FlushSubscriber) canBeRetried(err error) bool { + for _, e := range multierr.Errors(errors.Cause(err)) { + s := status.Convert(e) + // Is there any other error cannot be retried? + if s.Code() == codes.Unimplemented { + return false + } + } + return true +} + +func (f *FlushSubscriber) PendingErrors() error { + var allErr error + for _, s := range f.subscriptions { + if err := s.loadError(); err != nil { + allErr = multierr.Append(allErr, errors.Annotatef(err, "store %d has error", s.storeID)) + } + } + return allErr +} diff --git a/br/pkg/task/backup.go b/br/pkg/task/backup.go index a04f14f8b519a..2acb33e8aa6c8 100644 --- a/br/pkg/task/backup.go +++ b/br/pkg/task/backup.go @@ -631,7 +631,7 @@ func RunBackup(c context.Context, g glue.Glue, cmdName string, cfg *BackupConfig progressCount := uint64(0) progressCallBack := func() { updateCh.Inc() - failpoint.Inject("progress-call-back", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("progress-call-back")); _err_ == nil { log.Info("failpoint progress-call-back injected") atomic.AddUint64(&progressCount, 1) if fileName, ok := v.(string); ok { @@ -645,7 +645,7 @@ func RunBackup(c context.Context, g glue.Glue, cmdName string, cfg *BackupConfig log.Warn("failed to write data to file", zap.Error(err)) } } - }) + } } if cfg.UseCheckpoint { @@ -668,7 +668,7 @@ func RunBackup(c context.Context, g glue.Glue, cmdName string, cfg *BackupConfig }() } - failpoint.Inject("s3-outage-during-writing-file", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("s3-outage-during-writing-file")); _err_ == nil { log.Info("failpoint s3-outage-during-writing-file injected, " + "process will sleep for 5s and notify the shell to kill s3 service.") if sigFile, ok := v.(string); ok { @@ -681,7 +681,7 @@ func RunBackup(c context.Context, g glue.Glue, cmdName string, cfg *BackupConfig } } time.Sleep(5 * time.Second) - }) + } metawriter.StartWriteMetasAsync(ctx, metautil.AppendDataFile) err = client.BackupRanges(ctx, ranges, req, uint(cfg.Concurrency), cfg.ReplicaReadLabel, metawriter, progressCallBack) diff --git a/br/pkg/task/backup.go__failpoint_stash__ b/br/pkg/task/backup.go__failpoint_stash__ new file mode 100644 index 0000000000000..a04f14f8b519a --- /dev/null +++ b/br/pkg/task/backup.go__failpoint_stash__ @@ -0,0 +1,835 @@ +// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. 
+ +package task + +import ( + "context" + "crypto/sha256" + "encoding/json" + "fmt" + "os" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/docker/go-units" + "github.com/opentracing/opentracing-go" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/backup" + "github.com/pingcap/tidb/br/pkg/checkpoint" + "github.com/pingcap/tidb/br/pkg/checksum" + "github.com/pingcap/tidb/br/pkg/conn" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/glue" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/metautil" + "github.com/pingcap/tidb/br/pkg/rtree" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/br/pkg/summary" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/br/pkg/version" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/statistics/handle" + "github.com/pingcap/tidb/pkg/types" + "github.com/spf13/pflag" + "github.com/tikv/client-go/v2/oracle" + kvutil "github.com/tikv/client-go/v2/util" + "go.uber.org/multierr" + "go.uber.org/zap" +) + +const ( + flagBackupTimeago = "timeago" + flagBackupTS = "backupts" + flagLastBackupTS = "lastbackupts" + flagCompressionType = "compression" + flagCompressionLevel = "compression-level" + flagRemoveSchedulers = "remove-schedulers" + flagIgnoreStats = "ignore-stats" + flagUseBackupMetaV2 = "use-backupmeta-v2" + flagUseCheckpoint = "use-checkpoint" + flagKeyspaceName = "keyspace-name" + flagReplicaReadLabel = "replica-read-label" + flagTableConcurrency = "table-concurrency" + + flagGCTTL = "gcttl" + + defaultBackupConcurrency = 4 + maxBackupConcurrency = 256 +) + +const ( + FullBackupCmd = "Full Backup" + DBBackupCmd = "Database Backup" + TableBackupCmd = "Table Backup" + RawBackupCmd = "Raw Backup" + TxnBackupCmd = "Txn Backup" + EBSBackupCmd = "EBS Backup" +) + +// CompressionConfig is the configuration for sst file compression. +type CompressionConfig struct { + CompressionType backuppb.CompressionType `json:"compression-type" toml:"compression-type"` + CompressionLevel int32 `json:"compression-level" toml:"compression-level"` +} + +// BackupConfig is the configuration specific for backup tasks. 
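+// It embeds the common Config and adds backup-only settings such as the backup/last-backup timestamps, GC TTL, the checkpoint switch and SST compression options.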
+type BackupConfig struct { + Config + + TimeAgo time.Duration `json:"time-ago" toml:"time-ago"` + BackupTS uint64 `json:"backup-ts" toml:"backup-ts"` + LastBackupTS uint64 `json:"last-backup-ts" toml:"last-backup-ts"` + GCTTL int64 `json:"gc-ttl" toml:"gc-ttl"` + RemoveSchedulers bool `json:"remove-schedulers" toml:"remove-schedulers"` + IgnoreStats bool `json:"ignore-stats" toml:"ignore-stats"` + UseBackupMetaV2 bool `json:"use-backupmeta-v2"` + UseCheckpoint bool `json:"use-checkpoint" toml:"use-checkpoint"` + ReplicaReadLabel map[string]string `json:"replica-read-label" toml:"replica-read-label"` + TableConcurrency uint `json:"table-concurrency" toml:"table-concurrency"` + CompressionConfig + + // for ebs-based backup + FullBackupType FullBackupType `json:"full-backup-type" toml:"full-backup-type"` + VolumeFile string `json:"volume-file" toml:"volume-file"` + SkipAWS bool `json:"skip-aws" toml:"skip-aws"` + CloudAPIConcurrency uint `json:"cloud-api-concurrency" toml:"cloud-api-concurrency"` + ProgressFile string `json:"progress-file" toml:"progress-file"` + SkipPauseGCAndScheduler bool `json:"skip-pause-gc-and-scheduler" toml:"skip-pause-gc-and-scheduler"` +} + +// DefineBackupFlags defines common flags for the backup command. +func DefineBackupFlags(flags *pflag.FlagSet) { + flags.Duration( + flagBackupTimeago, 0, + "The history version of the backup task, e.g. 1m, 1h. Do not exceed GCSafePoint") + + // TODO: remove experimental tag if it's stable + flags.Uint64(flagLastBackupTS, 0, "(experimental) the last time backup ts,"+ + " use for incremental backup, support TSO only") + flags.String(flagBackupTS, "", "the backup ts support TSO or datetime,"+ + " e.g. '400036290571534337', '2018-05-11 01:42:23'") + flags.Int64(flagGCTTL, utils.DefaultBRGCSafePointTTL, "the TTL (in seconds) that PD holds for BR's GC safepoint") + flags.String(flagCompressionType, "zstd", + "backup sst file compression algorithm, value can be one of 'lz4|zstd|snappy'") + flags.Int32(flagCompressionLevel, 0, "compression level used for sst file compression") + + flags.Uint32(flagConcurrency, 4, "The size of a BR thread pool that executes tasks, "+ + "One task represents one table range (or one index range) according to the backup schemas. If there is one table with one index."+ + "there will be two tasks to back up this table. This value should increase if you need to back up lots of tables or indices.") + + flags.Uint(flagTableConcurrency, backup.DefaultSchemaConcurrency, "The size of a BR thread pool used for backup table metas, "+ + "including tableInfo/checksum and stats.") + + flags.Bool(flagRemoveSchedulers, false, + "disable the balance, shuffle and region-merge schedulers in PD to speed up backup") + // This flag can impact the online cluster, so hide it in case of abuse. + _ = flags.MarkHidden(flagRemoveSchedulers) + + // Disable stats by default. + // TODO: we need a better way to backup/restore stats. + flags.Bool(flagIgnoreStats, true, "ignore backup stats") + + flags.Bool(flagUseBackupMetaV2, true, + "use backup meta v2 to store meta info") + + flags.String(flagKeyspaceName, "", "keyspace name for backup") + // This flag will change the structure of backupmeta. + // we must make sure the old three version of br can parse the v2 meta to keep compatibility. + // so this flag should set to false for three version by default. + // for example: + // if we put this feature in v4.0.14, then v4.0.14 br can parse v2 meta + // but will generate v1 meta due to this flag is false. 
the behaviour is as same as v4.0.15, v4.0.16. + // finally v4.0.17 will set this flag to true, and generate v2 meta. + // + // the version currently is v7.4.0, the flag can be set to true as default value. + // _ = flags.MarkHidden(flagUseBackupMetaV2) + + flags.Bool(flagUseCheckpoint, true, "use checkpoint mode") + _ = flags.MarkHidden(flagUseCheckpoint) + + flags.String(flagReplicaReadLabel, "", "specify the label of the stores to be used for backup, e.g. 'label_key:label_value'") +} + +// ParseFromFlags parses the backup-related flags from the flag set. +func (cfg *BackupConfig) ParseFromFlags(flags *pflag.FlagSet) error { + timeAgo, err := flags.GetDuration(flagBackupTimeago) + if err != nil { + return errors.Trace(err) + } + if timeAgo < 0 { + return errors.Annotate(berrors.ErrInvalidArgument, "negative timeago is not allowed") + } + cfg.TimeAgo = timeAgo + cfg.LastBackupTS, err = flags.GetUint64(flagLastBackupTS) + if err != nil { + return errors.Trace(err) + } + backupTS, err := flags.GetString(flagBackupTS) + if err != nil { + return errors.Trace(err) + } + cfg.BackupTS, err = ParseTSString(backupTS, false) + if err != nil { + return errors.Trace(err) + } + cfg.UseBackupMetaV2, err = flags.GetBool(flagUseBackupMetaV2) + if err != nil { + return errors.Trace(err) + } + cfg.UseCheckpoint, err = flags.GetBool(flagUseCheckpoint) + if err != nil { + return errors.Trace(err) + } + if cfg.LastBackupTS > 0 { + // TODO: compatible with incremental backup + cfg.UseCheckpoint = false + log.Info("since incremental backup is used, turn off checkpoint mode") + } + gcTTL, err := flags.GetInt64(flagGCTTL) + if err != nil { + return errors.Trace(err) + } + cfg.GCTTL = gcTTL + cfg.Concurrency, err = flags.GetUint32(flagConcurrency) + if err != nil { + return errors.Trace(err) + } + if cfg.TableConcurrency, err = flags.GetUint(flagTableConcurrency); err != nil { + return errors.Trace(err) + } + + compressionCfg, err := parseCompressionFlags(flags) + if err != nil { + return errors.Trace(err) + } + cfg.CompressionConfig = *compressionCfg + + if err = cfg.Config.ParseFromFlags(flags); err != nil { + return errors.Trace(err) + } + cfg.RemoveSchedulers, err = flags.GetBool(flagRemoveSchedulers) + if err != nil { + return errors.Trace(err) + } + cfg.IgnoreStats, err = flags.GetBool(flagIgnoreStats) + if err != nil { + return errors.Trace(err) + } + cfg.KeyspaceName, err = flags.GetString(flagKeyspaceName) + if err != nil { + return errors.Trace(err) + } + + if flags.Lookup(flagFullBackupType) != nil { + // for backup full + fullBackupType, err := flags.GetString(flagFullBackupType) + if err != nil { + return errors.Trace(err) + } + if !FullBackupType(fullBackupType).Valid() { + return errors.New("invalid full backup type") + } + cfg.FullBackupType = FullBackupType(fullBackupType) + cfg.SkipAWS, err = flags.GetBool(flagSkipAWS) + if err != nil { + return errors.Trace(err) + } + cfg.CloudAPIConcurrency, err = flags.GetUint(flagCloudAPIConcurrency) + if err != nil { + return errors.Trace(err) + } + cfg.VolumeFile, err = flags.GetString(flagBackupVolumeFile) + if err != nil { + return errors.Trace(err) + } + cfg.ProgressFile, err = flags.GetString(flagProgressFile) + if err != nil { + return errors.Trace(err) + } + cfg.SkipPauseGCAndScheduler, err = flags.GetBool(flagOperatorPausedGCAndSchedulers) + if err != nil { + return errors.Trace(err) + } + } + + cfg.ReplicaReadLabel, err = parseReplicaReadLabelFlag(flags) + if err != nil { + return errors.Trace(err) + } + + return nil +} + +// parseCompressionFlags 
parses the backup-related flags from the flag set. +func parseCompressionFlags(flags *pflag.FlagSet) (*CompressionConfig, error) { + compressionStr, err := flags.GetString(flagCompressionType) + if err != nil { + return nil, errors.Trace(err) + } + compressionType, err := parseCompressionType(compressionStr) + if err != nil { + return nil, errors.Trace(err) + } + level, err := flags.GetInt32(flagCompressionLevel) + if err != nil { + return nil, errors.Trace(err) + } + return &CompressionConfig{ + CompressionLevel: level, + CompressionType: compressionType, + }, nil +} + +// Adjust is use for BR(binary) and BR in TiDB. +// When new config was add and not included in parser. +// we should set proper value in this function. +// so that both binary and TiDB will use same default value. +func (cfg *BackupConfig) Adjust() { + cfg.adjust() + usingDefaultConcurrency := false + if cfg.Config.Concurrency == 0 { + cfg.Config.Concurrency = defaultBackupConcurrency + usingDefaultConcurrency = true + } + if cfg.Config.Concurrency > maxBackupConcurrency { + cfg.Config.Concurrency = maxBackupConcurrency + } + if cfg.RateLimit != unlimited { + // TiKV limits the upload rate by each backup request. + // When the backup requests are sent concurrently, + // the ratelimit couldn't work as intended. + // Degenerating to sequentially sending backup requests to avoid this. + if !usingDefaultConcurrency { + logutil.WarnTerm("setting `--ratelimit` and `--concurrency` at the same time, "+ + "ignoring `--concurrency`: `--ratelimit` forces sequential (i.e. concurrency = 1) backup", + zap.String("ratelimit", units.HumanSize(float64(cfg.RateLimit))+"/s"), + zap.Uint32("concurrency-specified", cfg.Config.Concurrency)) + } + cfg.Config.Concurrency = 1 + } + + if cfg.GCTTL == 0 { + cfg.GCTTL = utils.DefaultBRGCSafePointTTL + } + // Use zstd as default + if cfg.CompressionType == backuppb.CompressionType_UNKNOWN { + cfg.CompressionType = backuppb.CompressionType_ZSTD + } + if cfg.CloudAPIConcurrency == 0 { + cfg.CloudAPIConcurrency = defaultCloudAPIConcurrency + } +} + +type immutableBackupConfig struct { + LastBackupTS uint64 `json:"last-backup-ts"` + IgnoreStats bool `json:"ignore-stats"` + UseCheckpoint bool `json:"use-checkpoint"` + + storage.BackendOptions + Storage string `json:"storage"` + PD []string `json:"pd"` + SendCreds bool `json:"send-credentials-to-tikv"` + NoCreds bool `json:"no-credentials"` + FilterStr []string `json:"filter-strings"` + CipherInfo backuppb.CipherInfo `json:"cipher"` + KeyspaceName string `json:"keyspace-name"` +} + +// a rough hash for checkpoint checker +func (cfg *BackupConfig) Hash() ([]byte, error) { + config := &immutableBackupConfig{ + LastBackupTS: cfg.LastBackupTS, + IgnoreStats: cfg.IgnoreStats, + UseCheckpoint: cfg.UseCheckpoint, + + BackendOptions: cfg.BackendOptions, + Storage: cfg.Storage, + PD: cfg.PD, + SendCreds: cfg.SendCreds, + NoCreds: cfg.NoCreds, + FilterStr: cfg.FilterStr, + CipherInfo: cfg.CipherInfo, + KeyspaceName: cfg.KeyspaceName, + } + data, err := json.Marshal(config) + if err != nil { + return nil, errors.Trace(err) + } + hash := sha256.Sum256(data) + + return hash[:], nil +} + +func isFullBackup(cmdName string) bool { + return cmdName == FullBackupCmd +} + +// RunBackup starts a backup task inside the current goroutine. 
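+// It blocks until the backup finishes or fails, and flushes the backup meta to the configured external storage before returning.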
+func RunBackup(c context.Context, g glue.Glue, cmdName string, cfg *BackupConfig) error { + cfg.Adjust() + config.UpdateGlobal(func(conf *config.Config) { + conf.KeyspaceName = cfg.KeyspaceName + }) + + defer summary.Summary(cmdName) + ctx, cancel := context.WithCancel(c) + defer cancel() + + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan("task.RunBackup", opentracing.ChildOf(span.Context())) + defer span1.Finish() + ctx = opentracing.ContextWithSpan(ctx, span1) + } + + u, err := storage.ParseBackend(cfg.Storage, &cfg.BackendOptions) + if err != nil { + return errors.Trace(err) + } + // if use noop as external storage, turn off the checkpoint mode + if u.GetNoop() != nil { + log.Info("since noop external storage is used, turn off checkpoint mode") + cfg.UseCheckpoint = false + } + skipStats := cfg.IgnoreStats + // For backup, Domain is not needed if user ignores stats. + // Domain loads all table info into memory. By skipping Domain, we save + // lots of memory (about 500MB for 40K 40 fields YCSB tables). + needDomain := !skipStats + mgr, err := NewMgr(ctx, g, cfg.PD, cfg.TLS, GetKeepalive(&cfg.Config), cfg.CheckRequirements, needDomain, conn.NormalVersionChecker) + if err != nil { + return errors.Trace(err) + } + defer mgr.Close() + // after version check, check the cluster whether support checkpoint mode + if cfg.UseCheckpoint { + err = version.CheckCheckpointSupport() + if err != nil { + log.Warn("unable to use checkpoint mode, fall back to normal mode", zap.Error(err)) + cfg.UseCheckpoint = false + } + } + var statsHandle *handle.Handle + if !skipStats { + statsHandle = mgr.GetDomain().StatsHandle() + } + + var newCollationEnable string + err = g.UseOneShotSession(mgr.GetStorage(), !needDomain, func(se glue.Session) error { + newCollationEnable, err = se.GetGlobalVariable(utils.GetTidbNewCollationEnabled()) + if err != nil { + return errors.Trace(err) + } + log.Info(fmt.Sprintf("get %s config from mysql.tidb table", utils.TidbNewCollationEnabled), + zap.String(utils.GetTidbNewCollationEnabled(), newCollationEnable)) + return nil + }) + if err != nil { + return errors.Trace(err) + } + + client := backup.NewBackupClient(ctx, mgr) + + // set cipher only for checkpoint + client.SetCipher(&cfg.CipherInfo) + + opts := storage.ExternalStorageOptions{ + NoCredentials: cfg.NoCreds, + SendCredentials: cfg.SendCreds, + CheckS3ObjectLockOptions: true, + } + if err = client.SetStorageAndCheckNotInUse(ctx, u, &opts); err != nil { + return errors.Trace(err) + } + // if checkpoint mode is unused at this time but there is checkpoint meta, + // CheckCheckpoint will stop backing up + cfgHash, err := cfg.Hash() + if err != nil { + return errors.Trace(err) + } + err = client.CheckCheckpoint(cfgHash) + if err != nil { + return errors.Trace(err) + } + err = client.SetLockFile(ctx) + if err != nil { + return errors.Trace(err) + } + // if use checkpoint and gcTTL is the default value + // update gcttl to checkpoint's default gc ttl + if cfg.UseCheckpoint && cfg.GCTTL == utils.DefaultBRGCSafePointTTL { + cfg.GCTTL = utils.DefaultCheckpointGCSafePointTTL + log.Info("use checkpoint's default GC TTL", zap.Int64("GC TTL", cfg.GCTTL)) + } + client.SetGCTTL(cfg.GCTTL) + + backupTS, err := client.GetTS(ctx, cfg.TimeAgo, cfg.BackupTS) + if err != nil { + return errors.Trace(err) + } + g.Record("BackupTS", backupTS) + safePointID := client.GetSafePointID() + sp := utils.BRServiceSafePoint{ + BackupTS: backupTS, + TTL: client.GetGCTTL(), + ID: 
safePointID, + } + + // use lastBackupTS as safePoint if exists + isIncrementalBackup := cfg.LastBackupTS > 0 + if isIncrementalBackup { + sp.BackupTS = cfg.LastBackupTS + } + + log.Info("current backup safePoint job", zap.Object("safePoint", sp)) + cctx, gcSafePointKeeperCancel := context.WithCancel(ctx) + gcSafePointKeeperRemovable := false + defer func() { + // don't reset the gc-safe-point if checkpoint mode is used and backup is not finished + if cfg.UseCheckpoint && !gcSafePointKeeperRemovable { + log.Info("skip removing gc-safepoint keeper for next retry", zap.String("gc-id", sp.ID)) + return + } + log.Info("start to remove gc-safepoint keeper") + // close the gc safe point keeper at first + gcSafePointKeeperCancel() + // set the ttl to 0 to remove the gc-safe-point + sp.TTL = 0 + if err := utils.UpdateServiceSafePoint(ctx, mgr.GetPDClient(), sp); err != nil { + log.Warn("failed to update service safe point, backup may fail if gc triggered", + zap.Error(err), + ) + } + log.Info("finish removing gc-safepoint keeper") + }() + err = utils.StartServiceSafePointKeeper(cctx, mgr.GetPDClient(), sp) + if err != nil { + return errors.Trace(err) + } + + if cfg.RemoveSchedulers { + log.Debug("removing some PD schedulers") + restore, e := mgr.RemoveSchedulers(ctx) + defer func() { + if ctx.Err() != nil { + log.Warn("context canceled, doing clean work with background context") + ctx = context.Background() + } + if restoreE := restore(ctx); restoreE != nil { + log.Warn("failed to restore removed schedulers, you may need to restore them manually", zap.Error(restoreE)) + } + }() + if e != nil { + return errors.Trace(err) + } + } + + req := backuppb.BackupRequest{ + ClusterId: client.GetClusterID(), + StartVersion: cfg.LastBackupTS, + EndVersion: backupTS, + RateLimit: cfg.RateLimit, + StorageBackend: client.GetStorageBackend(), + Concurrency: defaultBackupConcurrency, + CompressionType: cfg.CompressionType, + CompressionLevel: cfg.CompressionLevel, + CipherInfo: &cfg.CipherInfo, + ReplicaRead: len(cfg.ReplicaReadLabel) != 0, + Context: &kvrpcpb.Context{ + ResourceControlContext: &kvrpcpb.ResourceControlContext{ + ResourceGroupName: "", // TODO, + }, + RequestSource: kvutil.BuildRequestSource(true, kv.InternalTxnBR, kvutil.ExplicitTypeBR), + }, + } + brVersion := g.GetVersion() + clusterVersion, err := mgr.GetClusterVersion(ctx) + if err != nil { + return errors.Trace(err) + } + + ranges, schemas, policies, err := client.BuildBackupRangeAndSchema(mgr.GetStorage(), cfg.TableFilter, backupTS, isFullBackup(cmdName)) + if err != nil { + return errors.Trace(err) + } + + // Metafile size should be less than 64MB. + metawriter := metautil.NewMetaWriter(client.GetStorage(), + metautil.MetaFileSize, cfg.UseBackupMetaV2, "", &cfg.CipherInfo) + // Hack way to update backupmeta. 
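+	// Copy the cluster/version metadata (start and end versions, cluster id and version, BR version, collation setting, API version) onto the in-memory BackupMeta; it is flushed to the external storage at the end of the run.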
+ metawriter.Update(func(m *backuppb.BackupMeta) { + m.StartVersion = req.StartVersion + m.EndVersion = req.EndVersion + m.IsRawKv = req.IsRawKv + m.ClusterId = req.ClusterId + m.ClusterVersion = clusterVersion + m.BrVersion = brVersion + m.NewCollationsEnabled = newCollationEnable + m.ApiVersion = mgr.GetStorage().GetCodec().GetAPIVersion() + }) + + log.Info("get placement policies", zap.Int("count", len(policies))) + if len(policies) != 0 { + metawriter.Update(func(m *backuppb.BackupMeta) { + m.Policies = policies + }) + } + + // nothing to backup + if len(ranges) == 0 { + pdAddress := strings.Join(cfg.PD, ",") + log.Warn("Nothing to backup, maybe connected to cluster for restoring", + zap.String("PD address", pdAddress)) + + err = metawriter.FlushBackupMeta(ctx) + if err == nil { + summary.SetSuccessStatus(true) + } + return err + } + + if isIncrementalBackup { + if backupTS <= cfg.LastBackupTS { + log.Error("LastBackupTS is larger or equal to current TS") + return errors.Annotate(berrors.ErrInvalidArgument, "LastBackupTS is larger or equal to current TS") + } + err = utils.CheckGCSafePoint(ctx, mgr.GetPDClient(), cfg.LastBackupTS) + if err != nil { + log.Error("Check gc safepoint for last backup ts failed", zap.Error(err)) + return errors.Trace(err) + } + + metawriter.StartWriteMetasAsync(ctx, metautil.AppendDDL) + err = backup.WriteBackupDDLJobs(metawriter, g, mgr.GetStorage(), cfg.LastBackupTS, backupTS, needDomain) + if err != nil { + return errors.Trace(err) + } + if err = metawriter.FinishWriteMetas(ctx, metautil.AppendDDL); err != nil { + return errors.Trace(err) + } + } + + summary.CollectInt("backup total ranges", len(ranges)) + + approximateRegions, err := getRegionCountOfRanges(ctx, mgr, ranges) + if err != nil { + return errors.Trace(err) + } + // Redirect to log if there is no log file to avoid unreadable output. 
+ updateCh := g.StartProgress( + ctx, cmdName, int64(approximateRegions), !cfg.LogProgress) + summary.CollectInt("backup total regions", approximateRegions) + + progressCount := uint64(0) + progressCallBack := func() { + updateCh.Inc() + failpoint.Inject("progress-call-back", func(v failpoint.Value) { + log.Info("failpoint progress-call-back injected") + atomic.AddUint64(&progressCount, 1) + if fileName, ok := v.(string); ok { + f, osErr := os.OpenFile(fileName, os.O_CREATE|os.O_WRONLY, os.ModePerm) + if osErr != nil { + log.Warn("failed to create file", zap.Error(osErr)) + } + msg := []byte(fmt.Sprintf("region:%d\n", atomic.LoadUint64(&progressCount))) + _, err = f.Write(msg) + if err != nil { + log.Warn("failed to write data to file", zap.Error(err)) + } + } + }) + } + + if cfg.UseCheckpoint { + if err = client.StartCheckpointRunner(ctx, cfgHash, backupTS, ranges, safePointID, progressCallBack); err != nil { + return errors.Trace(err) + } + defer func() { + if !gcSafePointKeeperRemovable { + log.Info("wait for flush checkpoint...") + client.WaitForFinishCheckpoint(ctx, true) + } else { + log.Info("start to remove checkpoint data for backup") + client.WaitForFinishCheckpoint(ctx, false) + if removeErr := checkpoint.RemoveCheckpointDataForBackup(ctx, client.GetStorage()); removeErr != nil { + log.Warn("failed to remove checkpoint data for backup", zap.Error(removeErr)) + } else { + log.Info("the checkpoint data for backup is removed.") + } + } + }() + } + + failpoint.Inject("s3-outage-during-writing-file", func(v failpoint.Value) { + log.Info("failpoint s3-outage-during-writing-file injected, " + + "process will sleep for 5s and notify the shell to kill s3 service.") + if sigFile, ok := v.(string); ok { + file, err := os.Create(sigFile) + if err != nil { + log.Warn("failed to create file for notifying, skipping notify", zap.Error(err)) + } + if file != nil { + file.Close() + } + } + time.Sleep(5 * time.Second) + }) + + metawriter.StartWriteMetasAsync(ctx, metautil.AppendDataFile) + err = client.BackupRanges(ctx, ranges, req, uint(cfg.Concurrency), cfg.ReplicaReadLabel, metawriter, progressCallBack) + if err != nil { + return errors.Trace(err) + } + // Backup has finished + updateCh.Close() + + err = metawriter.FinishWriteMetas(ctx, metautil.AppendDataFile) + if err != nil { + return errors.Trace(err) + } + + skipChecksum := !cfg.Checksum || isIncrementalBackup + checksumProgress := int64(schemas.Len()) + if skipChecksum { + checksumProgress = 1 + if isIncrementalBackup { + // Since we don't support checksum for incremental data, fast checksum should be skipped. + log.Info("Skip fast checksum in incremental backup") + } else { + // When user specified not to calculate checksum, don't calculate checksum. + log.Info("Skip fast checksum") + } + } + updateCh = g.StartProgress(ctx, "Checksum", checksumProgress, !cfg.LogProgress) + schemasConcurrency := min(cfg.TableConcurrency, uint(schemas.Len())) + + err = schemas.BackupSchemas( + ctx, metawriter, client.GetCheckpointRunner(), mgr.GetStorage(), statsHandle, backupTS, schemasConcurrency, cfg.ChecksumConcurrency, skipChecksum, updateCh) + if err != nil { + return errors.Trace(err) + } + + err = metawriter.FlushBackupMeta(ctx) + if err != nil { + return errors.Trace(err) + } + // Since backupmeta is flushed on the external storage, + // we can remove the gc safepoint keeper + gcSafePointKeeperRemovable = true + + // Checksum has finished, close checksum progress. 
+ updateCh.Close() + + if !skipChecksum { + // Check if checksum from files matches checksum from coprocessor. + err = checksum.FastChecksum(ctx, metawriter.Backupmeta(), client.GetStorage(), &cfg.CipherInfo) + if err != nil { + return errors.Trace(err) + } + } + archiveSize := metawriter.ArchiveSize() + g.Record(summary.BackupDataSize, archiveSize) + //backup from tidb will fetch a general Size issue https://github.com/pingcap/tidb/issues/27247 + g.Record("Size", archiveSize) + // Set task summary to success status. + summary.SetSuccessStatus(true) + return nil +} + +func getRegionCountOfRanges( + ctx context.Context, + mgr *conn.Mgr, + ranges []rtree.Range, +) (int, error) { + // The number of regions need to backup + approximateRegions := 0 + for _, r := range ranges { + regionCount, err := mgr.GetRegionCount(ctx, r.StartKey, r.EndKey) + if err != nil { + return 0, errors.Trace(err) + } + approximateRegions += regionCount + } + return approximateRegions, nil +} + +// ParseTSString port from tidb setSnapshotTS. +func ParseTSString(ts string, tzCheck bool) (uint64, error) { + if len(ts) == 0 { + return 0, nil + } + if tso, err := strconv.ParseUint(ts, 10, 64); err == nil { + return tso, nil + } + + loc := time.Local + sc := stmtctx.NewStmtCtxWithTimeZone(loc) + if tzCheck { + tzIdx, _, _, _, _ := types.GetTimezone(ts) + if tzIdx < 0 { + return 0, errors.Errorf("must set timezone when using datetime format ts, e.g. '2018-05-11 01:42:23+0800'") + } + } + t, err := types.ParseTime(sc.TypeCtx(), ts, mysql.TypeTimestamp, types.MaxFsp) + if err != nil { + return 0, errors.Trace(err) + } + t1, err := t.GoTime(loc) + if err != nil { + return 0, errors.Trace(err) + } + return oracle.GoTimeToTS(t1), nil +} + +func DefaultBackupConfig() BackupConfig { + fs := pflag.NewFlagSet("dummy", pflag.ContinueOnError) + DefineCommonFlags(fs) + DefineBackupFlags(fs) + cfg := BackupConfig{} + err := multierr.Combine( + cfg.ParseFromFlags(fs), + cfg.Config.ParseFromFlags(fs), + ) + if err != nil { + log.Panic("infallible operation failed.", zap.Error(err)) + } + return cfg +} + +func parseCompressionType(s string) (backuppb.CompressionType, error) { + var ct backuppb.CompressionType + switch s { + case "lz4": + ct = backuppb.CompressionType_LZ4 + case "snappy": + ct = backuppb.CompressionType_SNAPPY + case "zstd": + ct = backuppb.CompressionType_ZSTD + default: + return backuppb.CompressionType_UNKNOWN, errors.Annotatef(berrors.ErrInvalidArgument, "invalid compression type '%s'", s) + } + return ct, nil +} + +func parseReplicaReadLabelFlag(flags *pflag.FlagSet) (map[string]string, error) { + replicaReadLabelStr, err := flags.GetString(flagReplicaReadLabel) + if err != nil { + return nil, errors.Trace(err) + } + if replicaReadLabelStr == "" { + return nil, nil + } + kv := strings.Split(replicaReadLabelStr, ":") + if len(kv) != 2 { + return nil, errors.Annotatef(berrors.ErrInvalidArgument, "invalid replica read label '%s'", replicaReadLabelStr) + } + return map[string]string{kv[0]: kv[1]}, nil +} diff --git a/br/pkg/task/binding__failpoint_binding__.go b/br/pkg/task/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..ecd81d7d48d63 --- /dev/null +++ b/br/pkg/task/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package task + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) 
string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/br/pkg/task/operator/binding__failpoint_binding__.go b/br/pkg/task/operator/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..fc7dd0d4f3fa7 --- /dev/null +++ b/br/pkg/task/operator/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package operator + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/br/pkg/task/operator/cmd.go b/br/pkg/task/operator/cmd.go index cbe5c3ac2442b..5ae421460737c 100644 --- a/br/pkg/task/operator/cmd.go +++ b/br/pkg/task/operator/cmd.go @@ -146,9 +146,9 @@ func AdaptEnvForSnapshotBackup(ctx context.Context, cfg *PauseGcConfig) error { }) cx.run(func() error { return pauseAdminAndWaitApply(cx, initChan) }) go func() { - failpoint.Inject("SkipReadyHint", func() { - failpoint.Return() - }) + if _, _err_ := failpoint.Eval(_curpkg_("SkipReadyHint")); _err_ == nil { + return + } cx.rdGrp.Wait() if cfg.OnAllReady != nil { cfg.OnAllReady() diff --git a/br/pkg/task/operator/cmd.go__failpoint_stash__ b/br/pkg/task/operator/cmd.go__failpoint_stash__ new file mode 100644 index 0000000000000..cbe5c3ac2442b --- /dev/null +++ b/br/pkg/task/operator/cmd.go__failpoint_stash__ @@ -0,0 +1,251 @@ +// Copyright 2023 PingCAP, Inc. Licensed under Apache-2.0. + +package operator + +import ( + "context" + "crypto/tls" + "runtime/debug" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + preparesnap "github.com/pingcap/tidb/br/pkg/backup/prepare_snap" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/pdutil" + "github.com/pingcap/tidb/br/pkg/task" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/tikv/client-go/v2/tikv" + "go.uber.org/multierr" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc/keepalive" +) + +func dialPD(ctx context.Context, cfg *task.Config) (*pdutil.PdController, error) { + var tc *tls.Config + if cfg.TLS.IsEnabled() { + var err error + tc, err = cfg.TLS.ToTLSConfig() + if err != nil { + return nil, err + } + } + mgr, err := pdutil.NewPdController(ctx, cfg.PD, tc, cfg.TLS.ToPDSecurityOption()) + if err != nil { + return nil, err + } + return mgr, nil +} + +func (cx *AdaptEnvForSnapshotBackupContext) cleanUpWith(f func(ctx context.Context)) { + cx.cleanUpWithRetErr(nil, func(ctx context.Context) error { f(ctx); return nil }) +} + +func (cx *AdaptEnvForSnapshotBackupContext) cleanUpWithRetErr(errOut *error, f func(ctx context.Context) error) { + ctx, cancel := context.WithTimeout(context.Background(), cx.cfg.TTL) + defer cancel() + err := f(ctx) + if errOut != nil { + *errOut = multierr.Combine(*errOut, err) + } +} + +func (cx *AdaptEnvForSnapshotBackupContext) run(f func() error) { + cx.rdGrp.Add(1) + buf := debug.Stack() + cx.runGrp.Go(func() error { + err := f() + if err != nil { + log.Error("A task failed.", zap.Error(err), zap.ByteString("task-created-at", buf)) + } + return err + }) +} + +type AdaptEnvForSnapshotBackupContext struct { + context.Context + + pdMgr *pdutil.PdController + kvMgr *utils.StoreManager + cfg PauseGcConfig + + rdGrp sync.WaitGroup + runGrp *errgroup.Group +} + +func (cx 
*AdaptEnvForSnapshotBackupContext) Close() { + cx.pdMgr.Close() + cx.kvMgr.Close() +} + +func (cx *AdaptEnvForSnapshotBackupContext) GetBackOffer(operation string) utils.Backoffer { + state := utils.InitialRetryState(64, 1*time.Second, 10*time.Second) + bo := utils.GiveUpRetryOn(&state, berrors.ErrPossibleInconsistency) + bo = utils.VerboseRetry(bo, logutil.CL(cx).With(zap.String("operation", operation))) + return bo +} + +func (cx *AdaptEnvForSnapshotBackupContext) ReadyL(name string, notes ...zap.Field) { + logutil.CL(cx).Info("Stage ready.", append(notes, zap.String("component", name))...) + cx.rdGrp.Done() +} + +func hintAllReady() { + // Hacking: some version of operators using the follow two logs to check whether we are ready... + log.Info("Schedulers are paused.") + log.Info("GC is paused.") + log.Info("All ready.") +} + +// AdaptEnvForSnapshotBackup blocks the current goroutine and pause the GC safepoint and remove the scheduler by the config. +// This function will block until the context being canceled. +func AdaptEnvForSnapshotBackup(ctx context.Context, cfg *PauseGcConfig) error { + utils.DumpGoroutineWhenExit.Store(true) + mgr, err := dialPD(ctx, &cfg.Config) + if err != nil { + return errors.Annotate(err, "failed to dial PD") + } + mgr.SchedulerPauseTTL = cfg.TTL + var tconf *tls.Config + if cfg.TLS.IsEnabled() { + tconf, err = cfg.TLS.ToTLSConfig() + if err != nil { + return errors.Annotate(err, "invalid tls config") + } + } + kvMgr := utils.NewStoreManager(mgr.GetPDClient(), keepalive.ClientParameters{ + Time: cfg.Config.GRPCKeepaliveTime, + Timeout: cfg.Config.GRPCKeepaliveTimeout, + }, tconf) + eg, ectx := errgroup.WithContext(ctx) + cx := &AdaptEnvForSnapshotBackupContext{ + Context: logutil.ContextWithField(ectx, zap.String("tag", "br_operator")), + pdMgr: mgr, + kvMgr: kvMgr, + cfg: *cfg, + rdGrp: sync.WaitGroup{}, + runGrp: eg, + } + defer cx.Close() + + initChan := make(chan struct{}) + cx.run(func() error { return pauseGCKeeper(cx) }) + cx.run(func() error { + log.Info("Pause scheduler waiting all connections established.") + select { + case <-initChan: + case <-cx.Done(): + return cx.Err() + } + log.Info("Pause scheduler noticed connections established.") + return pauseSchedulerKeeper(cx) + }) + cx.run(func() error { return pauseAdminAndWaitApply(cx, initChan) }) + go func() { + failpoint.Inject("SkipReadyHint", func() { + failpoint.Return() + }) + cx.rdGrp.Wait() + if cfg.OnAllReady != nil { + cfg.OnAllReady() + } + utils.DumpGoroutineWhenExit.Store(false) + hintAllReady() + }() + defer func() { + if cfg.OnExit != nil { + cfg.OnExit() + } + }() + + return eg.Wait() +} + +func pauseAdminAndWaitApply(cx *AdaptEnvForSnapshotBackupContext, afterConnectionsEstablished chan<- struct{}) error { + env := preparesnap.CliEnv{ + Cache: tikv.NewRegionCache(cx.pdMgr.GetPDClient()), + Mgr: cx.kvMgr, + } + defer env.Cache.Close() + retryEnv := preparesnap.RetryAndSplitRequestEnv{Env: env} + begin := time.Now() + prep := preparesnap.New(retryEnv) + prep.LeaseDuration = cx.cfg.TTL + prep.AfterConnectionsEstablished = func() { + log.Info("All connections are stablished.") + close(afterConnectionsEstablished) + } + + defer cx.cleanUpWith(func(ctx context.Context) { + if err := prep.Finalize(ctx); err != nil { + logutil.CL(ctx).Warn("failed to finalize the prepare stream", logutil.ShortError(err)) + } + }) + + // We must use our own context here, or once we are cleaning up the client will be invalid. 
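+	// Deriving myCtx from context.Background() rather than cx keeps the prepare stream alive for the deferred cleanup after cx is canceled.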
+ myCtx := logutil.ContextWithField(context.Background(), zap.String("category", "pause_admin_and_wait_apply")) + if err := prep.DriveLoopAndWaitPrepare(myCtx); err != nil { + return err + } + + cx.ReadyL("pause_admin_and_wait_apply", zap.Stringer("take", time.Since(begin))) + <-cx.Done() + return nil +} + +func pauseGCKeeper(cx *AdaptEnvForSnapshotBackupContext) (err error) { + // Note: should we remove the service safepoint as soon as this exits? + sp := utils.BRServiceSafePoint{ + ID: utils.MakeSafePointID(), + TTL: int64(cx.cfg.TTL.Seconds()), + BackupTS: cx.cfg.SafePoint, + } + if sp.BackupTS == 0 { + rts, err := cx.pdMgr.GetMinResolvedTS(cx) + if err != nil { + return err + } + logutil.CL(cx).Info("No service safepoint provided, using the minimal resolved TS.", zap.Uint64("min-resolved-ts", rts)) + sp.BackupTS = rts + } + err = utils.StartServiceSafePointKeeper(cx, cx.pdMgr.GetPDClient(), sp) + if err != nil { + return err + } + cx.ReadyL("pause_gc", zap.Object("safepoint", sp)) + defer cx.cleanUpWithRetErr(&err, func(ctx context.Context) error { + cancelSP := utils.BRServiceSafePoint{ + ID: sp.ID, + TTL: 0, + } + return utils.UpdateServiceSafePoint(ctx, cx.pdMgr.GetPDClient(), cancelSP) + }) + // Note: in fact we can directly return here. + // But the name `keeper` implies once the function exits, + // the GC should be resume, so let's block here. + <-cx.Done() + return nil +} + +func pauseSchedulerKeeper(ctx *AdaptEnvForSnapshotBackupContext) error { + undo, err := ctx.pdMgr.RemoveAllPDSchedulers(ctx) + if undo != nil { + defer ctx.cleanUpWith(func(ctx context.Context) { + if err := undo(ctx); err != nil { + log.Warn("failed to restore pd scheduler.", logutil.ShortError(err)) + } + }) + } + if err != nil { + return err + } + ctx.ReadyL("pause_scheduler") + // Wait until the context canceled. + // So we can properly do the clean up work. + <-ctx.Done() + return nil +} diff --git a/br/pkg/task/restore.go b/br/pkg/task/restore.go index 8bc6383be78b6..392a8b005b858 100644 --- a/br/pkg/task/restore.go +++ b/br/pkg/task/restore.go @@ -1113,10 +1113,10 @@ func runRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf // Restore sst files in batch. batchSize := mathutil.MaxInt - failpoint.Inject("small-batch-size", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("small-batch-size")); _err_ == nil { log.Info("failpoint small batch size is on", zap.Int("size", v.(int))) batchSize = v.(int) - }) + } // Split/Scatter + Download/Ingest progressLen := int64(rangeSize + len(files)) diff --git a/br/pkg/task/restore.go__failpoint_stash__ b/br/pkg/task/restore.go__failpoint_stash__ new file mode 100644 index 0000000000000..8bc6383be78b6 --- /dev/null +++ b/br/pkg/task/restore.go__failpoint_stash__ @@ -0,0 +1,1713 @@ +// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. 
+ +package task + +import ( + "cmp" + "context" + "fmt" + "slices" + "strings" + "time" + + "github.com/docker/go-units" + "github.com/google/uuid" + "github.com/opentracing/opentracing-go" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/checkpoint" + pconfig "github.com/pingcap/tidb/br/pkg/config" + "github.com/pingcap/tidb/br/pkg/conn" + connutil "github.com/pingcap/tidb/br/pkg/conn/util" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/glue" + "github.com/pingcap/tidb/br/pkg/httputil" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/metautil" + "github.com/pingcap/tidb/br/pkg/restore" + snapclient "github.com/pingcap/tidb/br/pkg/restore/snap_client" + "github.com/pingcap/tidb/br/pkg/restore/tiflashrec" + "github.com/pingcap/tidb/br/pkg/summary" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/br/pkg/version" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/collate" + "github.com/pingcap/tidb/pkg/util/engine" + "github.com/pingcap/tidb/pkg/util/mathutil" + "github.com/spf13/cobra" + "github.com/spf13/pflag" + "github.com/tikv/client-go/v2/tikv" + pd "github.com/tikv/pd/client" + "github.com/tikv/pd/client/http" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/multierr" + "go.uber.org/zap" +) + +const ( + flagOnline = "online" + flagNoSchema = "no-schema" + flagLoadStats = "load-stats" + flagGranularity = "granularity" + flagConcurrencyPerStore = "tikv-max-restore-concurrency" + flagAllowPITRFromIncremental = "allow-pitr-from-incremental" + + // FlagMergeRegionSizeBytes is the flag name of merge small regions by size + FlagMergeRegionSizeBytes = "merge-region-size-bytes" + // FlagMergeRegionKeyCount is the flag name of merge small regions by key count + FlagMergeRegionKeyCount = "merge-region-key-count" + // FlagPDConcurrency controls concurrency pd-relative operations like split & scatter. + FlagPDConcurrency = "pd-concurrency" + // FlagStatsConcurrency controls concurrency to restore statistic. + FlagStatsConcurrency = "stats-concurrency" + // FlagBatchFlushInterval controls after how long the restore batch would be auto sended. + FlagBatchFlushInterval = "batch-flush-interval" + // FlagDdlBatchSize controls batch ddl size to create a batch of tables + FlagDdlBatchSize = "ddl-batch-size" + // FlagWithPlacementPolicy corresponds to tidb config with-tidb-placement-mode + // current only support STRICT or IGNORE, the default is STRICT according to tidb. + FlagWithPlacementPolicy = "with-tidb-placement-mode" + // FlagKeyspaceName corresponds to tidb config keyspace-name + FlagKeyspaceName = "keyspace-name" + + // FlagWaitTiFlashReady represents whether wait tiflash replica ready after table restored and checksumed. + FlagWaitTiFlashReady = "wait-tiflash-ready" + + // FlagStreamStartTS and FlagStreamRestoreTS is used for log restore timestamp range. + FlagStreamStartTS = "start-ts" + FlagStreamRestoreTS = "restored-ts" + // FlagStreamFullBackupStorage is used for log restore, represents the full backup storage. + FlagStreamFullBackupStorage = "full-backup-storage" + // FlagPiTRBatchCount and FlagPiTRBatchSize are used for restore log with batch method. 
+ FlagPiTRBatchCount = "pitr-batch-count" + FlagPiTRBatchSize = "pitr-batch-size" + FlagPiTRConcurrency = "pitr-concurrency" + + FlagResetSysUsers = "reset-sys-users" + + defaultPiTRBatchCount = 8 + defaultPiTRBatchSize = 16 * 1024 * 1024 + defaultRestoreConcurrency = 128 + defaultPiTRConcurrency = 16 + defaultPDConcurrency = 1 + defaultStatsConcurrency = 12 + defaultBatchFlushInterval = 16 * time.Second + defaultFlagDdlBatchSize = 128 + resetSpeedLimitRetryTimes = 3 + maxRestoreBatchSizeLimit = 10240 +) + +const ( + FullRestoreCmd = "Full Restore" + DBRestoreCmd = "DataBase Restore" + TableRestoreCmd = "Table Restore" + PointRestoreCmd = "Point Restore" + RawRestoreCmd = "Raw Restore" + TxnRestoreCmd = "Txn Restore" +) + +// RestoreCommonConfig is the common configuration for all BR restore tasks. +type RestoreCommonConfig struct { + Online bool `json:"online" toml:"online"` + Granularity string `json:"granularity" toml:"granularity"` + ConcurrencyPerStore pconfig.ConfigTerm[uint] `json:"tikv-max-restore-concurrency" toml:"tikv-max-restore-concurrency"` + + // MergeSmallRegionSizeBytes is the threshold of merging small regions (Default 96MB, region split size). + // MergeSmallRegionKeyCount is the threshold of merging smalle regions (Default 960_000, region split key count). + // See https://github.com/tikv/tikv/blob/v4.0.8/components/raftstore/src/coprocessor/config.rs#L35-L38 + MergeSmallRegionSizeBytes pconfig.ConfigTerm[uint64] `json:"merge-region-size-bytes" toml:"merge-region-size-bytes"` + MergeSmallRegionKeyCount pconfig.ConfigTerm[uint64] `json:"merge-region-key-count" toml:"merge-region-key-count"` + + // determines whether enable restore sys table on default, see fullClusterRestore in restore/client.go + WithSysTable bool `json:"with-sys-table" toml:"with-sys-table"` + + ResetSysUsers []string `json:"reset-sys-users" toml:"reset-sys-users"` +} + +// adjust adjusts the abnormal config value in the current config. +// useful when not starting BR from CLI (e.g. from BRIE in SQL). +func (cfg *RestoreCommonConfig) adjust() { + if !cfg.MergeSmallRegionKeyCount.Modified { + cfg.MergeSmallRegionKeyCount.Value = conn.DefaultMergeRegionKeyCount + } + if !cfg.MergeSmallRegionSizeBytes.Modified { + cfg.MergeSmallRegionSizeBytes.Value = conn.DefaultMergeRegionSizeBytes + } + if len(cfg.Granularity) == 0 { + cfg.Granularity = string(restore.CoarseGrained) + } + if !cfg.ConcurrencyPerStore.Modified { + cfg.ConcurrencyPerStore.Value = conn.DefaultImportNumGoroutines + } +} + +// DefineRestoreCommonFlags defines common flags for the restore command. 
+func DefineRestoreCommonFlags(flags *pflag.FlagSet) { + // TODO remove experimental tag if it's stable + flags.Bool(flagOnline, false, "(experimental) Whether online when restore") + flags.String(flagGranularity, string(restore.CoarseGrained), "(deprecated) Whether split & scatter regions using fine-grained way during restore") + flags.Uint(flagConcurrencyPerStore, 128, "The size of thread pool on each store that executes tasks") + flags.Uint32(flagConcurrency, 128, "(deprecated) The size of thread pool on BR that executes tasks, "+ + "where each task restores one SST file to TiKV") + flags.Uint64(FlagMergeRegionSizeBytes, conn.DefaultMergeRegionSizeBytes, + "the threshold of merging small regions (Default 96MB, region split size)") + flags.Uint64(FlagMergeRegionKeyCount, conn.DefaultMergeRegionKeyCount, + "the threshold of merging small regions (Default 960_000, region split key count)") + flags.Uint(FlagPDConcurrency, defaultPDConcurrency, + "concurrency pd-relative operations like split & scatter.") + flags.Uint(FlagStatsConcurrency, defaultStatsConcurrency, + "concurrency to restore statistic") + flags.Duration(FlagBatchFlushInterval, defaultBatchFlushInterval, + "after how long a restore batch would be auto sent.") + flags.Uint(FlagDdlBatchSize, defaultFlagDdlBatchSize, + "batch size for ddl to create a batch of tables once.") + flags.Bool(flagWithSysTable, true, "whether restore system privilege tables on default setting") + flags.StringArrayP(FlagResetSysUsers, "", []string{"cloud_admin", "root"}, "whether reset these users after restoration") + flags.Bool(flagUseFSR, false, "whether enable FSR for AWS snapshots") + + _ = flags.MarkHidden(FlagResetSysUsers) + _ = flags.MarkHidden(FlagMergeRegionSizeBytes) + _ = flags.MarkHidden(FlagMergeRegionKeyCount) + _ = flags.MarkHidden(FlagPDConcurrency) + _ = flags.MarkHidden(FlagStatsConcurrency) + _ = flags.MarkHidden(FlagBatchFlushInterval) + _ = flags.MarkHidden(FlagDdlBatchSize) +} + +// ParseFromFlags parses the config from the flag set. +func (cfg *RestoreCommonConfig) ParseFromFlags(flags *pflag.FlagSet) error { + var err error + cfg.Online, err = flags.GetBool(flagOnline) + if err != nil { + return errors.Trace(err) + } + cfg.Granularity, err = flags.GetString(flagGranularity) + if err != nil { + return errors.Trace(err) + } + cfg.ConcurrencyPerStore.Value, err = flags.GetUint(flagConcurrencyPerStore) + if err != nil { + return errors.Trace(err) + } + cfg.ConcurrencyPerStore.Modified = flags.Changed(flagConcurrencyPerStore) + + cfg.MergeSmallRegionKeyCount.Value, err = flags.GetUint64(FlagMergeRegionKeyCount) + if err != nil { + return errors.Trace(err) + } + cfg.MergeSmallRegionKeyCount.Modified = flags.Changed(FlagMergeRegionKeyCount) + + cfg.MergeSmallRegionSizeBytes.Value, err = flags.GetUint64(FlagMergeRegionSizeBytes) + if err != nil { + return errors.Trace(err) + } + cfg.MergeSmallRegionSizeBytes.Modified = flags.Changed(FlagMergeRegionSizeBytes) + + if flags.Lookup(flagWithSysTable) != nil { + cfg.WithSysTable, err = flags.GetBool(flagWithSysTable) + if err != nil { + return errors.Trace(err) + } + } + cfg.ResetSysUsers, err = flags.GetStringArray(FlagResetSysUsers) + if err != nil { + return errors.Trace(err) + } + return errors.Trace(err) +} + +// RestoreConfig is the configuration specific for restore tasks. 
+type RestoreConfig struct { + Config + RestoreCommonConfig + + NoSchema bool `json:"no-schema" toml:"no-schema"` + LoadStats bool `json:"load-stats" toml:"load-stats"` + PDConcurrency uint `json:"pd-concurrency" toml:"pd-concurrency"` + StatsConcurrency uint `json:"stats-concurrency" toml:"stats-concurrency"` + BatchFlushInterval time.Duration `json:"batch-flush-interval" toml:"batch-flush-interval"` + // DdlBatchSize use to define the size of batch ddl to create tables + DdlBatchSize uint `json:"ddl-batch-size" toml:"ddl-batch-size"` + + WithPlacementPolicy string `json:"with-tidb-placement-mode" toml:"with-tidb-placement-mode"` + + // FullBackupStorage is used to run `restore full` before `restore log`. + // if it is empty, directly take restoring log justly. + FullBackupStorage string `json:"full-backup-storage" toml:"full-backup-storage"` + + // AllowPITRFromIncremental indicates whether this restore should enter a compatibility mode for incremental restore. + // In this restore mode, the restore will not perform timestamp rewrite on the incremental data. + AllowPITRFromIncremental bool `json:"allow-pitr-from-incremental" toml:"allow-pitr-from-incremental"` + + // [startTs, RestoreTS] is used to `restore log` from StartTS to RestoreTS. + StartTS uint64 `json:"start-ts" toml:"start-ts"` + RestoreTS uint64 `json:"restore-ts" toml:"restore-ts"` + tiflashRecorder *tiflashrec.TiFlashRecorder `json:"-" toml:"-"` + PitrBatchCount uint32 `json:"pitr-batch-count" toml:"pitr-batch-count"` + PitrBatchSize uint32 `json:"pitr-batch-size" toml:"pitr-batch-size"` + PitrConcurrency uint32 `json:"-" toml:"-"` + + UseCheckpoint bool `json:"use-checkpoint" toml:"use-checkpoint"` + checkpointSnapshotRestoreTaskName string `json:"-" toml:"-"` + checkpointLogRestoreTaskName string `json:"-" toml:"-"` + checkpointTaskInfoClusterID uint64 `json:"-" toml:"-"` + WaitTiflashReady bool `json:"wait-tiflash-ready" toml:"wait-tiflash-ready"` + + // for ebs-based restore + FullBackupType FullBackupType `json:"full-backup-type" toml:"full-backup-type"` + Prepare bool `json:"prepare" toml:"prepare"` + OutputFile string `json:"output-file" toml:"output-file"` + SkipAWS bool `json:"skip-aws" toml:"skip-aws"` + CloudAPIConcurrency uint `json:"cloud-api-concurrency" toml:"cloud-api-concurrency"` + VolumeType pconfig.EBSVolumeType `json:"volume-type" toml:"volume-type"` + VolumeIOPS int64 `json:"volume-iops" toml:"volume-iops"` + VolumeThroughput int64 `json:"volume-throughput" toml:"volume-throughput"` + VolumeEncrypted bool `json:"volume-encrypted" toml:"volume-encrypted"` + ProgressFile string `json:"progress-file" toml:"progress-file"` + TargetAZ string `json:"target-az" toml:"target-az"` + UseFSR bool `json:"use-fsr" toml:"use-fsr"` +} + +// DefineRestoreFlags defines common flags for the restore tidb command. 
+func DefineRestoreFlags(flags *pflag.FlagSet) { + flags.Bool(flagNoSchema, false, "skip creating schemas and tables, reuse existing empty ones") + flags.Bool(flagLoadStats, true, "Run load stats at end of snapshot restore task") + // Do not expose this flag + _ = flags.MarkHidden(flagNoSchema) + flags.String(FlagWithPlacementPolicy, "STRICT", "correspond to tidb global/session variable with-tidb-placement-mode") + flags.String(FlagKeyspaceName, "", "correspond to tidb config keyspace-name") + + flags.Bool(flagUseCheckpoint, true, "use checkpoint mode") + _ = flags.MarkHidden(flagUseCheckpoint) + + flags.Bool(FlagWaitTiFlashReady, false, "whether wait tiflash replica ready if tiflash exists") + flags.Bool(flagAllowPITRFromIncremental, true, "whether make incremental restore compatible with later log restore"+ + " default is true, the incremental restore will not perform rewrite on the incremental data"+ + " meanwhile the incremental restore will not allow to restore 3 backfilled type ddl jobs,"+ + " these ddl jobs are Add index, Modify column and Reorganize partition") + + DefineRestoreCommonFlags(flags) +} + +// DefineStreamRestoreFlags defines for the restore log command. +func DefineStreamRestoreFlags(command *cobra.Command) { + command.Flags().String(FlagStreamStartTS, "", "the start timestamp which log restore from.\n"+ + "support TSO or datetime, e.g. '400036290571534337' or '2018-05-11 01:42:23+0800'") + command.Flags().String(FlagStreamRestoreTS, "", "the point of restore, used for log restore.\n"+ + "support TSO or datetime, e.g. '400036290571534337' or '2018-05-11 01:42:23+0800'") + command.Flags().String(FlagStreamFullBackupStorage, "", "specify the backup full storage. "+ + "fill it if want restore full backup before restore log.") + command.Flags().Uint32(FlagPiTRBatchCount, defaultPiTRBatchCount, "specify the batch count to restore log.") + command.Flags().Uint32(FlagPiTRBatchSize, defaultPiTRBatchSize, "specify the batch size to retore log.") + command.Flags().Uint32(FlagPiTRConcurrency, defaultPiTRConcurrency, "specify the concurrency to restore log.") +} + +// ParseStreamRestoreFlags parses the `restore stream` flags from the flag set. +func (cfg *RestoreConfig) ParseStreamRestoreFlags(flags *pflag.FlagSet) error { + tsString, err := flags.GetString(FlagStreamStartTS) + if err != nil { + return errors.Trace(err) + } + if cfg.StartTS, err = ParseTSString(tsString, true); err != nil { + return errors.Trace(err) + } + tsString, err = flags.GetString(FlagStreamRestoreTS) + if err != nil { + return errors.Trace(err) + } + if cfg.RestoreTS, err = ParseTSString(tsString, true); err != nil { + return errors.Trace(err) + } + + if cfg.FullBackupStorage, err = flags.GetString(FlagStreamFullBackupStorage); err != nil { + return errors.Trace(err) + } + + if cfg.StartTS > 0 && len(cfg.FullBackupStorage) > 0 { + return errors.Annotatef(berrors.ErrInvalidArgument, "%v and %v are mutually exclusive", + FlagStreamStartTS, FlagStreamFullBackupStorage) + } + + if cfg.PitrBatchCount, err = flags.GetUint32(FlagPiTRBatchCount); err != nil { + return errors.Trace(err) + } + if cfg.PitrBatchSize, err = flags.GetUint32(FlagPiTRBatchSize); err != nil { + return errors.Trace(err) + } + if cfg.PitrConcurrency, err = flags.GetUint32(FlagPiTRConcurrency); err != nil { + return errors.Trace(err) + } + return nil +} + +// ParseFromFlags parses the restore-related flags from the flag set. 
+func (cfg *RestoreConfig) ParseFromFlags(flags *pflag.FlagSet) error { + var err error + cfg.NoSchema, err = flags.GetBool(flagNoSchema) + if err != nil { + return errors.Trace(err) + } + cfg.LoadStats, err = flags.GetBool(flagLoadStats) + if err != nil { + return errors.Trace(err) + } + err = cfg.Config.ParseFromFlags(flags) + if err != nil { + return errors.Trace(err) + } + err = cfg.RestoreCommonConfig.ParseFromFlags(flags) + if err != nil { + return errors.Trace(err) + } + cfg.Concurrency, err = flags.GetUint32(flagConcurrency) + if err != nil { + return errors.Trace(err) + } + if cfg.Config.Concurrency == 0 { + cfg.Config.Concurrency = defaultRestoreConcurrency + } + cfg.PDConcurrency, err = flags.GetUint(FlagPDConcurrency) + if err != nil { + return errors.Annotatef(err, "failed to get flag %s", FlagPDConcurrency) + } + cfg.StatsConcurrency, err = flags.GetUint(FlagStatsConcurrency) + if err != nil { + return errors.Annotatef(err, "failed to get flag %s", FlagStatsConcurrency) + } + cfg.BatchFlushInterval, err = flags.GetDuration(FlagBatchFlushInterval) + if err != nil { + return errors.Annotatef(err, "failed to get flag %s", FlagBatchFlushInterval) + } + + cfg.DdlBatchSize, err = flags.GetUint(FlagDdlBatchSize) + if err != nil { + return errors.Annotatef(err, "failed to get flag %s", FlagDdlBatchSize) + } + cfg.WithPlacementPolicy, err = flags.GetString(FlagWithPlacementPolicy) + if err != nil { + return errors.Annotatef(err, "failed to get flag %s", FlagWithPlacementPolicy) + } + cfg.KeyspaceName, err = flags.GetString(FlagKeyspaceName) + if err != nil { + return errors.Annotatef(err, "failed to get flag %s", FlagKeyspaceName) + } + cfg.UseCheckpoint, err = flags.GetBool(flagUseCheckpoint) + if err != nil { + return errors.Annotatef(err, "failed to get flag %s", flagUseCheckpoint) + } + + cfg.WaitTiflashReady, err = flags.GetBool(FlagWaitTiFlashReady) + if err != nil { + return errors.Annotatef(err, "failed to get flag %s", FlagWaitTiFlashReady) + } + + cfg.AllowPITRFromIncremental, err = flags.GetBool(flagAllowPITRFromIncremental) + if err != nil { + return errors.Annotatef(err, "failed to get flag %s", flagAllowPITRFromIncremental) + } + + if flags.Lookup(flagFullBackupType) != nil { + // for restore full only + fullBackupType, err := flags.GetString(flagFullBackupType) + if err != nil { + return errors.Trace(err) + } + if !FullBackupType(fullBackupType).Valid() { + return errors.New("invalid full backup type") + } + cfg.FullBackupType = FullBackupType(fullBackupType) + cfg.Prepare, err = flags.GetBool(flagPrepare) + if err != nil { + return errors.Trace(err) + } + cfg.SkipAWS, err = flags.GetBool(flagSkipAWS) + if err != nil { + return errors.Trace(err) + } + cfg.CloudAPIConcurrency, err = flags.GetUint(flagCloudAPIConcurrency) + if err != nil { + return errors.Trace(err) + } + cfg.OutputFile, err = flags.GetString(flagOutputMetaFile) + if err != nil { + return errors.Trace(err) + } + volumeType, err := flags.GetString(flagVolumeType) + if err != nil { + return errors.Trace(err) + } + cfg.VolumeType = pconfig.EBSVolumeType(volumeType) + if !cfg.VolumeType.Valid() { + return errors.New("invalid volume type: " + volumeType) + } + if cfg.VolumeIOPS, err = flags.GetInt64(flagVolumeIOPS); err != nil { + return errors.Trace(err) + } + if cfg.VolumeThroughput, err = flags.GetInt64(flagVolumeThroughput); err != nil { + return errors.Trace(err) + } + if cfg.VolumeEncrypted, err = flags.GetBool(flagVolumeEncrypted); err != nil { + return errors.Trace(err) + } + + cfg.ProgressFile, err = 
flags.GetString(flagProgressFile) + if err != nil { + return errors.Trace(err) + } + + cfg.TargetAZ, err = flags.GetString(flagTargetAZ) + if err != nil { + return errors.Trace(err) + } + + cfg.UseFSR, err = flags.GetBool(flagUseFSR) + if err != nil { + return errors.Trace(err) + } + + // iops: gp3 [3,000-16,000]; io1/io2 [100-32,000] + // throughput: gp3 [125, 1000]; io1/io2 cannot set throughput + // io1 and io2 volumes support up to 64,000 IOPS only on Instances built on the Nitro System. + // Other instance families support performance up to 32,000 IOPS. + // https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_CreateVolume.html + // todo: check lower/upper bound + } + + return nil +} + +// Adjust is use for BR(binary) and BR in TiDB. +// When new config was added and not included in parser. +// we should set proper value in this function. +// so that both binary and TiDB will use same default value. +func (cfg *RestoreConfig) Adjust() { + cfg.Config.adjust() + cfg.RestoreCommonConfig.adjust() + + if cfg.Config.Concurrency == 0 { + cfg.Config.Concurrency = defaultRestoreConcurrency + } + if cfg.Config.SwitchModeInterval == 0 { + cfg.Config.SwitchModeInterval = defaultSwitchInterval + } + if cfg.PDConcurrency == 0 { + cfg.PDConcurrency = defaultPDConcurrency + } + if cfg.StatsConcurrency == 0 { + cfg.StatsConcurrency = defaultStatsConcurrency + } + if cfg.BatchFlushInterval == 0 { + cfg.BatchFlushInterval = defaultBatchFlushInterval + } + if cfg.DdlBatchSize == 0 { + cfg.DdlBatchSize = defaultFlagDdlBatchSize + } + if cfg.CloudAPIConcurrency == 0 { + cfg.CloudAPIConcurrency = defaultCloudAPIConcurrency + } +} + +func (cfg *RestoreConfig) adjustRestoreConfigForStreamRestore() { + if cfg.PitrConcurrency == 0 { + cfg.PitrConcurrency = defaultPiTRConcurrency + } + if cfg.PitrBatchCount == 0 { + cfg.PitrBatchCount = defaultPiTRBatchCount + } + if cfg.PitrBatchSize == 0 { + cfg.PitrBatchSize = defaultPiTRBatchSize + } + // another goroutine is used to iterate the backup file + cfg.PitrConcurrency += 1 + log.Info("set restore kv files concurrency", zap.Int("concurrency", int(cfg.PitrConcurrency))) + cfg.Config.Concurrency = cfg.PitrConcurrency +} + +// generateLogRestoreTaskName generates the log restore taskName for checkpoint +func (cfg *RestoreConfig) generateLogRestoreTaskName(clusterID, startTS, restoreTs uint64) string { + cfg.checkpointTaskInfoClusterID = clusterID + cfg.checkpointLogRestoreTaskName = fmt.Sprintf("%d/%d.%d", clusterID, startTS, restoreTs) + return cfg.checkpointLogRestoreTaskName +} + +// generateSnapshotRestoreTaskName generates the snapshot restore taskName for checkpoint +func (cfg *RestoreConfig) generateSnapshotRestoreTaskName(clusterID uint64) string { + cfg.checkpointSnapshotRestoreTaskName = fmt.Sprint(clusterID) + return cfg.checkpointSnapshotRestoreTaskName +} + +func configureRestoreClient(ctx context.Context, client *snapclient.SnapClient, cfg *RestoreConfig) error { + client.SetRateLimit(cfg.RateLimit) + client.SetCrypter(&cfg.CipherInfo) + if cfg.NoSchema { + client.EnableSkipCreateSQL() + } + client.SetBatchDdlSize(cfg.DdlBatchSize) + client.SetPlacementPolicyMode(cfg.WithPlacementPolicy) + client.SetWithSysTable(cfg.WithSysTable) + client.SetRewriteMode(ctx) + return nil +} + +func CheckNewCollationEnable( + backupNewCollationEnable string, + g glue.Glue, + storage kv.Storage, + CheckRequirements bool, +) (bool, error) { + se, err := g.CreateSession(storage) + if err != nil { + return false, errors.Trace(err) + } + + newCollationEnable, err := 
se.GetGlobalVariable(utils.GetTidbNewCollationEnabled()) + if err != nil { + return false, errors.Trace(err) + } + // collate.newCollationEnabled is set to 1 when the collate package is initialized, + // so we need to modify this value according to the config of the cluster + // before using the collate package. + enabled := newCollationEnable == "True" + // modify collate.newCollationEnabled according to the config of the cluster + collate.SetNewCollationEnabledForTest(enabled) + log.Info(fmt.Sprintf("set %s", utils.TidbNewCollationEnabled), zap.Bool("new_collation_enabled", enabled)) + + if backupNewCollationEnable == "" { + if CheckRequirements { + return enabled, errors.Annotatef(berrors.ErrUnknown, + "the value '%s' not found in backupmeta. "+ + "you can use \"SELECT VARIABLE_VALUE FROM mysql.tidb WHERE VARIABLE_NAME='%s';\" to manually check the config. "+ + "if you ensure the value '%s' in backup cluster is as same as restore cluster, use --check-requirements=false to skip this check", + utils.TidbNewCollationEnabled, utils.TidbNewCollationEnabled, utils.TidbNewCollationEnabled) + } + log.Warn(fmt.Sprintf("the config '%s' is not in backupmeta", utils.TidbNewCollationEnabled)) + return enabled, nil + } + + if !strings.EqualFold(backupNewCollationEnable, newCollationEnable) { + return enabled, errors.Annotatef(berrors.ErrUnknown, + "the config '%s' not match, upstream:%v, downstream: %v", + utils.TidbNewCollationEnabled, backupNewCollationEnable, newCollationEnable) + } + + return enabled, nil +} + +// CheckRestoreDBAndTable is used to check whether the restore dbs or tables have been backup +func CheckRestoreDBAndTable(schemas []*metautil.Database, cfg *RestoreConfig) error { + if len(cfg.Schemas) == 0 && len(cfg.Tables) == 0 { + return nil + } + schemasMap := make(map[string]struct{}) + tablesMap := make(map[string]struct{}) + for _, db := range schemas { + dbName := db.Info.Name.L + if dbCIStrName, ok := utils.GetSysDBCIStrName(db.Info.Name); utils.IsSysDB(dbCIStrName.O) && ok { + dbName = dbCIStrName.L + } + schemasMap[utils.EncloseName(dbName)] = struct{}{} + for _, table := range db.Tables { + if table.Info == nil { + // we may back up empty database. 
+ continue + } + tablesMap[utils.EncloseDBAndTable(dbName, table.Info.Name.L)] = struct{}{} + } + } + restoreSchemas := cfg.Schemas + restoreTables := cfg.Tables + for schema := range restoreSchemas { + schemaLName := strings.ToLower(schema) + if _, ok := schemasMap[schemaLName]; !ok { + return errors.Annotatef(berrors.ErrUndefinedRestoreDbOrTable, + "[database: %v] has not been backup, please ensure you has input a correct database name", schema) + } + } + for table := range restoreTables { + tableLName := strings.ToLower(table) + if _, ok := tablesMap[tableLName]; !ok { + return errors.Annotatef(berrors.ErrUndefinedRestoreDbOrTable, + "[table: %v] has not been backup, please ensure you has input a correct table name", table) + } + } + return nil +} + +func isFullRestore(cmdName string) bool { + return cmdName == FullRestoreCmd +} + +// IsStreamRestore checks the command is `restore point` +func IsStreamRestore(cmdName string) bool { + return cmdName == PointRestoreCmd +} + +func registerTaskToPD(ctx context.Context, etcdCLI *clientv3.Client) (closeF func(context.Context) error, err error) { + register := utils.NewTaskRegister(etcdCLI, utils.RegisterRestore, fmt.Sprintf("restore-%s", uuid.New())) + err = register.RegisterTask(ctx) + return register.Close, errors.Trace(err) +} + +func removeCheckpointDataForSnapshotRestore(ctx context.Context, storageName string, taskName string, config *Config) error { + _, s, err := GetStorage(ctx, storageName, config) + if err != nil { + return errors.Trace(err) + } + return errors.Trace(checkpoint.RemoveCheckpointDataForRestore(ctx, s, taskName)) +} + +func removeCheckpointDataForLogRestore(ctx context.Context, storageName string, taskName string, clusterID uint64, config *Config) error { + _, s, err := GetStorage(ctx, storageName, config) + if err != nil { + return errors.Trace(err) + } + return errors.Trace(checkpoint.RemoveCheckpointDataForLogRestore(ctx, s, taskName, clusterID)) +} + +func DefaultRestoreConfig() RestoreConfig { + fs := pflag.NewFlagSet("dummy", pflag.ContinueOnError) + DefineCommonFlags(fs) + DefineRestoreFlags(fs) + cfg := RestoreConfig{} + err := multierr.Combine( + cfg.ParseFromFlags(fs), + cfg.RestoreCommonConfig.ParseFromFlags(fs), + cfg.Config.ParseFromFlags(fs), + ) + if err != nil { + log.Panic("infallible failed.", zap.Error(err)) + } + + return cfg +} + +// RunRestore starts a restore task inside the current goroutine. 
+func RunRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConfig) error { + etcdCLI, err := dialEtcdWithCfg(c, cfg.Config) + if err != nil { + return err + } + defer func() { + if err := etcdCLI.Close(); err != nil { + log.Error("failed to close the etcd client", zap.Error(err)) + } + }() + if err := checkTaskExists(c, cfg, etcdCLI); err != nil { + return errors.Annotate(err, "failed to check task exists") + } + closeF, err := registerTaskToPD(c, etcdCLI) + if err != nil { + return errors.Annotate(err, "failed to register task to pd") + } + defer func() { + _ = closeF(c) + }() + + config.UpdateGlobal(func(conf *config.Config) { + conf.KeyspaceName = cfg.KeyspaceName + }) + + var restoreError error + if IsStreamRestore(cmdName) { + restoreError = RunStreamRestore(c, g, cmdName, cfg) + } else { + restoreError = runRestore(c, g, cmdName, cfg, nil) + } + if restoreError != nil { + return errors.Trace(restoreError) + } + // Clear the checkpoint data + if cfg.UseCheckpoint { + if len(cfg.checkpointLogRestoreTaskName) > 0 { + log.Info("start to remove checkpoint data for log restore") + err = removeCheckpointDataForLogRestore(c, cfg.Config.Storage, cfg.checkpointLogRestoreTaskName, cfg.checkpointTaskInfoClusterID, &cfg.Config) + if err != nil { + log.Warn("failed to remove checkpoint data for log restore", zap.Error(err)) + } + } + if len(cfg.checkpointSnapshotRestoreTaskName) > 0 { + log.Info("start to remove checkpoint data for snapshot restore.") + var storage string + if IsStreamRestore(cmdName) { + storage = cfg.FullBackupStorage + } else { + storage = cfg.Config.Storage + } + err = removeCheckpointDataForSnapshotRestore(c, storage, cfg.checkpointSnapshotRestoreTaskName, &cfg.Config) + if err != nil { + log.Warn("failed to remove checkpoint data for snapshot restore", zap.Error(err)) + } + } + log.Info("all the checkpoint data is removed.") + } + return nil +} + +func runRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConfig, checkInfo *PiTRTaskInfo) error { + cfg.Adjust() + defer summary.Summary(cmdName) + ctx, cancel := context.WithCancel(c) + defer cancel() + + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan("task.RunRestore", opentracing.ChildOf(span.Context())) + defer span1.Finish() + ctx = opentracing.ContextWithSpan(ctx, span1) + } + + // Restore needs domain to do DDL. + needDomain := true + keepaliveCfg := GetKeepalive(&cfg.Config) + mgr, err := NewMgr(ctx, g, cfg.PD, cfg.TLS, keepaliveCfg, cfg.CheckRequirements, needDomain, conn.NormalVersionChecker) + if err != nil { + return errors.Trace(err) + } + defer mgr.Close() + codec := mgr.GetStorage().GetCodec() + + // need retrieve these configs from tikv if not set in command. + kvConfigs := &pconfig.KVConfig{ + ImportGoroutines: cfg.ConcurrencyPerStore, + MergeRegionSize: cfg.MergeSmallRegionSizeBytes, + MergeRegionKeyCount: cfg.MergeSmallRegionKeyCount, + } + + // according to https://github.com/pingcap/tidb/issues/34167. + // we should get the real config from tikv to adapt the dynamic region. + httpCli := httputil.NewClient(mgr.GetTLSConfig()) + mgr.ProcessTiKVConfigs(ctx, kvConfigs, httpCli) + + keepaliveCfg.PermitWithoutStream = true + client := snapclient.NewRestoreClient(mgr.GetPDClient(), mgr.GetPDHTTPClient(), mgr.GetTLSConfig(), keepaliveCfg) + // using tikv config to set the concurrency-per-store for client. 
+ client.SetConcurrencyPerStore(kvConfigs.ImportGoroutines.Value) + err = configureRestoreClient(ctx, client, cfg) + if err != nil { + return errors.Trace(err) + } + // Init DB connection sessions + err = client.Init(g, mgr.GetStorage()) + defer client.Close() + + if err != nil { + return errors.Trace(err) + } + u, s, backupMeta, err := ReadBackupMeta(ctx, metautil.MetaFile, &cfg.Config) + if err != nil { + return errors.Trace(err) + } + if cfg.CheckRequirements { + err := checkIncompatibleChangefeed(ctx, backupMeta.EndVersion, mgr.GetDomain().GetEtcdClient()) + log.Info("Checking incompatible TiCDC changefeeds before restoring.", + logutil.ShortError(err), zap.Uint64("restore-ts", backupMeta.EndVersion)) + if err != nil { + return errors.Trace(err) + } + } + + backupVersion := version.NormalizeBackupVersion(backupMeta.ClusterVersion) + if cfg.CheckRequirements && backupVersion != nil { + if versionErr := version.CheckClusterVersion(ctx, mgr.GetPDClient(), version.CheckVersionForBackup(backupVersion)); versionErr != nil { + return errors.Trace(versionErr) + } + } + if _, err = CheckNewCollationEnable(backupMeta.GetNewCollationsEnabled(), g, mgr.GetStorage(), cfg.CheckRequirements); err != nil { + return errors.Trace(err) + } + + reader := metautil.NewMetaReader(backupMeta, s, &cfg.CipherInfo) + if err = client.InitBackupMeta(c, backupMeta, u, reader, cfg.LoadStats); err != nil { + return errors.Trace(err) + } + + if client.IsRawKvMode() { + return errors.Annotate(berrors.ErrRestoreModeMismatch, "cannot do transactional restore from raw kv data") + } + if err = CheckRestoreDBAndTable(client.GetDatabases(), cfg); err != nil { + return err + } + files, tables, dbs := filterRestoreFiles(client, cfg) + if len(dbs) == 0 && len(tables) != 0 { + return errors.Annotate(berrors.ErrRestoreInvalidBackup, "contain tables but no databases") + } + + if cfg.CheckRequirements { + if err := checkDiskSpace(ctx, mgr, files, tables); err != nil { + return errors.Trace(err) + } + } + + archiveSize := reader.ArchiveSize(ctx, files) + g.Record(summary.RestoreDataSize, archiveSize) + //restore from tidb will fetch a general Size issue https://github.com/pingcap/tidb/issues/27247 + g.Record("Size", archiveSize) + restoreTS, err := restore.GetTSWithRetry(ctx, mgr.GetPDClient()) + if err != nil { + return errors.Trace(err) + } + + // for full + log restore. should check the cluster is empty. + if client.IsFull() && checkInfo != nil && checkInfo.FullRestoreCheckErr != nil { + return checkInfo.FullRestoreCheckErr + } + + if client.IsIncremental() { + // don't support checkpoint for the ddl restore + log.Info("the incremental snapshot restore doesn't support checkpoint mode, so unuse checkpoint.") + cfg.UseCheckpoint = false + } + + importModeSwitcher := restore.NewImportModeSwitcher(mgr.GetPDClient(), cfg.Config.SwitchModeInterval, mgr.GetTLSConfig()) + restoreSchedulers, schedulersConfig, err := restore.RestorePreWork(ctx, mgr, importModeSwitcher, cfg.Online, true) + if err != nil { + return errors.Trace(err) + } + + schedulersRemovable := false + defer func() { + // don't reset pd scheduler if checkpoint mode is used and restored is not finished + if cfg.UseCheckpoint && !schedulersRemovable { + log.Info("skip removing pd schehduler for next retry") + return + } + log.Info("start to remove the pd scheduler") + // run the post-work to avoid being stuck in the import + // mode or emptied schedulers. 
+ restore.RestorePostWork(ctx, importModeSwitcher, restoreSchedulers, cfg.Online) + log.Info("finish removing pd scheduler") + }() + + var checkpointTaskName string + var checkpointFirstRun bool = true + if cfg.UseCheckpoint { + checkpointTaskName = cfg.generateSnapshotRestoreTaskName(client.GetClusterID(ctx)) + // if the checkpoint metadata exists in the external storage, the restore is not + // for the first time. + existsCheckpointMetadata, err := checkpoint.ExistsRestoreCheckpoint(ctx, s, checkpointTaskName) + if err != nil { + return errors.Trace(err) + } + checkpointFirstRun = !existsCheckpointMetadata + } + + if isFullRestore(cmdName) { + if client.NeedCheckFreshCluster(cfg.ExplicitFilter, checkpointFirstRun) { + if err = client.CheckTargetClusterFresh(ctx); err != nil { + return errors.Trace(err) + } + } + // todo: move this check into InitFullClusterRestore, we should move restore config into a separate package + // to avoid import cycle problem which we won't do it in this pr, then refactor this + // + // if it's point restore and reached here, then cmdName=FullRestoreCmd and len(cfg.FullBackupStorage) > 0 + if cfg.WithSysTable { + client.InitFullClusterRestore(cfg.ExplicitFilter) + } + } + + if client.IsFullClusterRestore() && client.HasBackedUpSysDB() { + if err = snapclient.CheckSysTableCompatibility(mgr.GetDomain(), tables); err != nil { + return errors.Trace(err) + } + } + + // reload or register the checkpoint + var checkpointSetWithTableID map[int64]map[string]struct{} + if cfg.UseCheckpoint { + sets, restoreSchedulersConfigFromCheckpoint, err := client.InitCheckpoint(ctx, s, checkpointTaskName, schedulersConfig, checkpointFirstRun) + if err != nil { + return errors.Trace(err) + } + if restoreSchedulersConfigFromCheckpoint != nil { + restoreSchedulers = mgr.MakeUndoFunctionByConfig(*restoreSchedulersConfigFromCheckpoint) + } + checkpointSetWithTableID = sets + + defer func() { + // need to flush the whole checkpoint data so that br can quickly jump to + // the log kv restore step when the next retry. + log.Info("wait for flush checkpoint...") + client.WaitForFinishCheckpoint(ctx, len(cfg.FullBackupStorage) > 0 || !schedulersRemovable) + }() + } + + sp := utils.BRServiceSafePoint{ + BackupTS: restoreTS, + TTL: utils.DefaultBRGCSafePointTTL, + ID: utils.MakeSafePointID(), + } + g.Record("BackupTS", backupMeta.EndVersion) + g.Record("RestoreTS", restoreTS) + cctx, gcSafePointKeeperCancel := context.WithCancel(ctx) + defer func() { + log.Info("start to remove gc-safepoint keeper") + // close the gc safe point keeper at first + gcSafePointKeeperCancel() + // set the ttl to 0 to remove the gc-safe-point + sp.TTL = 0 + if err := utils.UpdateServiceSafePoint(ctx, mgr.GetPDClient(), sp); err != nil { + log.Warn("failed to update service safe point, backup may fail if gc triggered", + zap.Error(err), + ) + } + log.Info("finish removing gc-safepoint keeper") + }() + // restore checksum will check safe point with its start ts, see details at + // https://github.com/pingcap/tidb/blob/180c02127105bed73712050594da6ead4d70a85f/store/tikv/kv.go#L186-L190 + // so, we should keep the safe point unchangeable. to avoid GC life time is shorter than transaction duration. 
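+	// StartServiceSafePointKeeper is expected to keep sp refreshed in the background until cctx is
+	// canceled; the deferred cleanup above then removes the safe point by re-registering it with TTL = 0.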
+ err = utils.StartServiceSafePointKeeper(cctx, mgr.GetPDClient(), sp) + if err != nil { + return errors.Trace(err) + } + + ddlJobs := FilterDDLJobs(client.GetDDLJobs(), tables) + ddlJobs = FilterDDLJobByRules(ddlJobs, DDLJobBlockListRule) + if cfg.AllowPITRFromIncremental { + err = CheckDDLJobByRules(ddlJobs, DDLJobLogIncrementalCompactBlockListRule) + if err != nil { + return errors.Trace(err) + } + } + + err = PreCheckTableTiFlashReplica(ctx, mgr.GetPDClient(), tables, cfg.tiflashRecorder) + if err != nil { + return errors.Trace(err) + } + + err = PreCheckTableClusterIndex(tables, ddlJobs, mgr.GetDomain()) + if err != nil { + return errors.Trace(err) + } + + // pre-set TiDB config for restore + restoreDBConfig := enableTiDBConfig() + defer restoreDBConfig() + + if client.GetSupportPolicy() { + // create policy if backupMeta has policies. + policies, err := client.GetPlacementPolicies() + if err != nil { + return errors.Trace(err) + } + if isFullRestore(cmdName) { + // we should restore all policies during full restoration. + err = client.CreatePolicies(ctx, policies) + if err != nil { + return errors.Trace(err) + } + } else { + client.SetPolicyMap(policies) + } + } + + // preallocate the table id, because any ddl job or database creation also allocates the global ID + err = client.AllocTableIDs(ctx, tables) + if err != nil { + return errors.Trace(err) + } + + // execute DDL first + err = client.ExecDDLs(ctx, ddlJobs) + if err != nil { + return errors.Trace(err) + } + + // nothing to restore, maybe only ddl changes in incremental restore + if len(dbs) == 0 && len(tables) == 0 { + log.Info("nothing to restore, all databases and tables are filtered out") + // even nothing to restore, we show a success message since there is no failure. + summary.SetSuccessStatus(true) + return nil + } + + if err = client.CreateDatabases(ctx, dbs); err != nil { + return errors.Trace(err) + } + + var newTS uint64 + if client.IsIncremental() { + if !cfg.AllowPITRFromIncremental { + // we need to get the new ts after execDDL + // or backfilled data in upstream may not be covered by + // the new ts. + // see https://github.com/pingcap/tidb/issues/54426 + newTS, err = restore.GetTSWithRetry(ctx, mgr.GetPDClient()) + if err != nil { + return errors.Trace(err) + } + } + } + // We make bigger errCh so we won't block on multi-part failed. + errCh := make(chan error, 32) + + tableStream := client.GoCreateTables(ctx, tables, newTS, errCh) + + if len(files) == 0 { + log.Info("no files, empty databases and tables are restored") + summary.SetSuccessStatus(true) + // don't return immediately, wait all pipeline done. + } else { + oldKeyspace, _, err := tikv.DecodeKey(files[0].GetStartKey(), backupMeta.ApiVersion) + if err != nil { + return errors.Trace(err) + } + newKeyspace := codec.GetKeyspace() + + // If the API V2 data occurs in the restore process, the cluster must + // support the keyspace rewrite mode. + if (len(oldKeyspace) > 0 || len(newKeyspace) > 0) && client.GetRewriteMode() == snapclient.RewriteModeLegacy { + return errors.Annotate(berrors.ErrRestoreModeMismatch, "cluster only supports legacy rewrite mode") + } + + // Hijack the tableStream and rewrite the rewrite rules. 
+ tableStream = util.ChanMap(tableStream, func(t snapclient.CreatedTable) snapclient.CreatedTable { + // Set the keyspace info for the checksum requests + t.RewriteRule.OldKeyspace = oldKeyspace + t.RewriteRule.NewKeyspace = newKeyspace + + for _, rule := range t.RewriteRule.Data { + rule.OldKeyPrefix = append(append([]byte{}, oldKeyspace...), rule.OldKeyPrefix...) + rule.NewKeyPrefix = codec.EncodeKey(rule.NewKeyPrefix) + } + return t + }) + } + + if cfg.tiflashRecorder != nil { + tableStream = util.ChanMap(tableStream, func(t snapclient.CreatedTable) snapclient.CreatedTable { + if cfg.tiflashRecorder != nil { + cfg.tiflashRecorder.Rewrite(t.OldTable.Info.ID, t.Table.ID) + } + return t + }) + } + + // Block on creating tables before restore starts. since create table is no longer a heavy operation any more. + tableStream = client.GoBlockCreateTablesPipeline(ctx, maxRestoreBatchSizeLimit, tableStream) + + tableFileMap := MapTableToFiles(files) + log.Debug("mapped table to files", zap.Any("result map", tableFileMap)) + + rangeStream := client.GoValidateFileRanges( + ctx, tableStream, tableFileMap, kvConfigs.MergeRegionSize.Value, kvConfigs.MergeRegionKeyCount.Value, errCh) + + rangeSize := EstimateRangeSize(files) + summary.CollectInt("restore ranges", rangeSize) + log.Info("range and file prepared", zap.Int("file count", len(files)), zap.Int("range count", rangeSize)) + + // Do not reset timestamp if we are doing incremental restore, because + // we are not allowed to decrease timestamp. + if !client.IsIncremental() { + if err = client.ResetTS(ctx, mgr.PdController); err != nil { + log.Error("reset pd TS failed", zap.Error(err)) + return errors.Trace(err) + } + } + + // Restore sst files in batch. + batchSize := mathutil.MaxInt + failpoint.Inject("small-batch-size", func(v failpoint.Value) { + log.Info("failpoint small batch size is on", zap.Int("size", v.(int))) + batchSize = v.(int) + }) + + // Split/Scatter + Download/Ingest + progressLen := int64(rangeSize + len(files)) + if cfg.Checksum { + progressLen += int64(len(tables)) + } + if cfg.WaitTiflashReady { + progressLen += int64(len(tables)) + } + // Redirect to log if there is no log file to avoid unreadable output. 
+ updateCh := g.StartProgress( + ctx, + cmdName, + progressLen, + !cfg.LogProgress) + defer updateCh.Close() + sender, err := snapclient.NewTiKVSender(ctx, client, updateCh, cfg.PDConcurrency) + if err != nil { + return errors.Trace(err) + } + manager, err := snapclient.NewBRContextManager(ctx, mgr.GetPDClient(), mgr.GetPDHTTPClient(), mgr.GetTLSConfig(), cfg.Online) + if err != nil { + return errors.Trace(err) + } + batcher, afterTableRestoredCh := snapclient.NewBatcher(ctx, sender, manager, errCh, updateCh) + batcher.SetCheckpoint(checkpointSetWithTableID) + batcher.SetThreshold(batchSize) + batcher.EnableAutoCommit(ctx, cfg.BatchFlushInterval) + go restoreTableStream(ctx, rangeStream, batcher, errCh) + + var finish <-chan struct{} + postHandleCh := afterTableRestoredCh + + // pipeline checksum + if cfg.Checksum { + postHandleCh = client.GoValidateChecksum( + ctx, postHandleCh, mgr.GetStorage().GetClient(), errCh, updateCh, cfg.ChecksumConcurrency) + } + + // pipeline update meta and load stats + postHandleCh = client.GoUpdateMetaAndLoadStats(ctx, s, postHandleCh, errCh, cfg.StatsConcurrency, cfg.LoadStats) + + // pipeline wait Tiflash synced + if cfg.WaitTiflashReady { + postHandleCh = client.GoWaitTiFlashReady(ctx, postHandleCh, updateCh, errCh) + } + + finish = dropToBlackhole(ctx, postHandleCh, errCh) + + // Reset speed limit. ResetSpeedLimit must be called after client.InitBackupMeta has been called. + defer func() { + var resetErr error + // In future we may need a mechanism to set speed limit in ttl. like what we do in switchmode. TODO + for retry := 0; retry < resetSpeedLimitRetryTimes; retry++ { + resetErr = client.ResetSpeedLimit(ctx) + if resetErr != nil { + log.Warn("failed to reset speed limit, retry it", + zap.Int("retry time", retry), logutil.ShortError(resetErr)) + time.Sleep(time.Duration(retry+3) * time.Second) + continue + } + break + } + if resetErr != nil { + log.Error("failed to reset speed limit, please reset it manually", zap.Error(resetErr)) + } + }() + + select { + case err = <-errCh: + err = multierr.Append(err, multierr.Combine(Exhaust(errCh)...)) + case <-finish: + } + + // If any error happened, return now. + if err != nil { + return errors.Trace(err) + } + + // The cost of rename user table / replace into system table wouldn't be so high. + // So leave it out of the pipeline for easier implementation. + err = client.RestoreSystemSchemas(ctx, cfg.TableFilter) + if err != nil { + return errors.Trace(err) + } + + schedulersRemovable = true + + // Set task summary to success status. 
+ summary.SetSuccessStatus(true) + return nil +} + +func getMaxReplica(ctx context.Context, mgr *conn.Mgr) (cnt uint64, err error) { + var resp map[string]any + err = utils.WithRetry(ctx, func() error { + resp, err = mgr.GetPDHTTPClient().GetReplicateConfig(ctx) + return err + }, utils.NewPDReqBackoffer()) + if err != nil { + return 0, errors.Trace(err) + } + + key := "max-replicas" + val, ok := resp[key] + if !ok { + return 0, errors.Errorf("key %s not found in response %v", key, resp) + } + return uint64(val.(float64)), nil +} + +func getStores(ctx context.Context, mgr *conn.Mgr) (stores *http.StoresInfo, err error) { + err = utils.WithRetry(ctx, func() error { + stores, err = mgr.GetPDHTTPClient().GetStores(ctx) + return err + }, utils.NewPDReqBackoffer()) + if err != nil { + return nil, errors.Trace(err) + } + return stores, nil +} + +func EstimateTikvUsage(files []*backuppb.File, replicaCnt uint64, storeCnt uint64) uint64 { + if storeCnt == 0 { + return 0 + } + if replicaCnt > storeCnt { + replicaCnt = storeCnt + } + totalSize := uint64(0) + for _, file := range files { + totalSize += file.GetSize_() + } + log.Info("estimate tikv usage", zap.Uint64("total size", totalSize), zap.Uint64("replicaCnt", replicaCnt), zap.Uint64("store count", storeCnt)) + return totalSize * replicaCnt / storeCnt +} + +func EstimateTiflashUsage(tables []*metautil.Table, storeCnt uint64) uint64 { + if storeCnt == 0 { + return 0 + } + tiflashTotal := uint64(0) + for _, table := range tables { + if table.Info.TiFlashReplica == nil || table.Info.TiFlashReplica.Count <= 0 { + continue + } + tableBytes := uint64(0) + for _, file := range table.Files { + tableBytes += file.GetSize_() + } + tiflashTotal += tableBytes * table.Info.TiFlashReplica.Count + } + log.Info("estimate tiflash usage", zap.Uint64("total size", tiflashTotal), zap.Uint64("store count", storeCnt)) + return tiflashTotal / storeCnt +} + +func CheckStoreSpace(necessary uint64, store *http.StoreInfo) error { + available, err := units.RAMInBytes(store.Status.Available) + if err != nil { + return errors.Annotatef(berrors.ErrPDInvalidResponse, "store %d has invalid available space %s", store.Store.ID, store.Status.Available) + } + if available <= 0 { + return errors.Annotatef(berrors.ErrPDInvalidResponse, "store %d has invalid available space %s", store.Store.ID, store.Status.Available) + } + if uint64(available) < necessary { + return errors.Annotatef(berrors.ErrKVDiskFull, "store %d has no space left on device, available %s, necessary %s", + store.Store.ID, units.BytesSize(float64(available)), units.BytesSize(float64(necessary))) + } + return nil +} + +func checkDiskSpace(ctx context.Context, mgr *conn.Mgr, files []*backuppb.File, tables []*metautil.Table) error { + maxReplica, err := getMaxReplica(ctx, mgr) + if err != nil { + return errors.Trace(err) + } + stores, err := getStores(ctx, mgr) + if err != nil { + return errors.Trace(err) + } + + var tikvCnt, tiflashCnt uint64 = 0, 0 + for i := range stores.Stores { + store := &stores.Stores[i] + if engine.IsTiFlashHTTPResp(&store.Store) { + tiflashCnt += 1 + continue + } + tikvCnt += 1 + } + + // We won't need to restore more than 1800 PB data at one time, right? 
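+	// preserve pads the estimated usage by `ratio` using integer math (base * uint64(ratio*10) / 10);
+	// for bases above 1000 PB it returns the base unchanged so the uint64 multiplication cannot overflow.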
+ preserve := func(base uint64, ratio float32) uint64 { + if base > 1000*units.PB { + return base + } + return base * uint64(ratio*10) / 10 + } + tikvUsage := preserve(EstimateTikvUsage(files, maxReplica, tikvCnt), 1.1) + tiflashUsage := preserve(EstimateTiflashUsage(tables, tiflashCnt), 1.4) + log.Info("preserved disk space", zap.Uint64("tikv", tikvUsage), zap.Uint64("tiflash", tiflashUsage)) + + err = utils.WithRetry(ctx, func() error { + stores, err = getStores(ctx, mgr) + if err != nil { + return errors.Trace(err) + } + for _, store := range stores.Stores { + if engine.IsTiFlashHTTPResp(&store.Store) { + if err := CheckStoreSpace(tiflashUsage, &store); err != nil { + return errors.Trace(err) + } + continue + } + if err := CheckStoreSpace(tikvUsage, &store); err != nil { + return errors.Trace(err) + } + } + return nil + }, utils.NewDiskCheckBackoffer()) + if err != nil { + return errors.Trace(err) + } + return nil +} + +// Exhaust drains all remaining errors in the channel, into a slice of errors. +func Exhaust(ec <-chan error) []error { + out := make([]error, 0, len(ec)) + for { + select { + case err := <-ec: + out = append(out, err) + default: + // errCh will NEVER be closed(ya see, it has multi sender-part), + // so we just consume the current backlog of this channel, then return. + return out + } + } +} + +// EstimateRangeSize estimates the total range count by file. +func EstimateRangeSize(files []*backuppb.File) int { + result := 0 + for _, f := range files { + if strings.HasSuffix(f.GetName(), "_write.sst") { + result++ + } + } + return result +} + +// MapTableToFiles makes a map that mapping table ID to its backup files. +// aware that one file can and only can hold one table. +func MapTableToFiles(files []*backuppb.File) map[int64][]*backuppb.File { + result := map[int64][]*backuppb.File{} + for _, file := range files { + tableID := tablecodec.DecodeTableID(file.GetStartKey()) + tableEndID := tablecodec.DecodeTableID(file.GetEndKey()) + if tableID != tableEndID { + log.Panic("key range spread between many files.", + zap.String("file name", file.Name), + logutil.Key("startKey", file.StartKey), + logutil.Key("endKey", file.EndKey)) + } + if tableID == 0 { + log.Panic("invalid table key of file", + zap.String("file name", file.Name), + logutil.Key("startKey", file.StartKey), + logutil.Key("endKey", file.EndKey)) + } + result[tableID] = append(result[tableID], file) + } + return result +} + +// dropToBlackhole drop all incoming tables into black hole, +// i.e. don't execute checksum, just increase the process anyhow. +func dropToBlackhole( + ctx context.Context, + inCh <-chan *snapclient.CreatedTable, + errCh chan<- error, +) <-chan struct{} { + outCh := make(chan struct{}, 1) + go func() { + defer func() { + close(outCh) + }() + for { + select { + case <-ctx.Done(): + errCh <- ctx.Err() + return + case _, ok := <-inCh: + if !ok { + return + } + } + } + }() + return outCh +} + +// filterRestoreFiles filters tables that can't be processed after applying cfg.TableFilter.MatchTable. +// if the db has no table that can be processed, the db will be filtered too. 
+func filterRestoreFiles( + client *snapclient.SnapClient, + cfg *RestoreConfig, +) (files []*backuppb.File, tables []*metautil.Table, dbs []*metautil.Database) { + for _, db := range client.GetDatabases() { + dbName := db.Info.Name.O + if name, ok := utils.GetSysDBName(db.Info.Name); utils.IsSysDB(name) && ok { + dbName = name + } + if !cfg.TableFilter.MatchSchema(dbName) { + continue + } + dbs = append(dbs, db) + for _, table := range db.Tables { + if table.Info == nil || !cfg.TableFilter.MatchTable(dbName, table.Info.Name.O) { + continue + } + files = append(files, table.Files...) + tables = append(tables, table) + } + } + return +} + +// enableTiDBConfig tweaks some of configs of TiDB to make the restore progress go well. +// return a function that could restore the config to origin. +func enableTiDBConfig() func() { + restoreConfig := config.RestoreFunc() + config.UpdateGlobal(func(conf *config.Config) { + // set max-index-length before execute DDLs and create tables + // we set this value to max(3072*4), otherwise we might not restore table + // when upstream and downstream both set this value greater than default(3072) + conf.MaxIndexLength = config.DefMaxOfMaxIndexLength + log.Warn("set max-index-length to max(3072*4) to skip check index length in DDL") + conf.IndexLimit = config.DefMaxOfIndexLimit + log.Warn("set index-limit to max(64*8) to skip check index count in DDL") + conf.TableColumnCountLimit = config.DefMaxOfTableColumnCountLimit + log.Warn("set table-column-count to max(4096) to skip check column count in DDL") + }) + return restoreConfig +} + +// restoreTableStream blocks current goroutine and restore a stream of tables, +// by send tables to batcher. +func restoreTableStream( + ctx context.Context, + inputCh <-chan snapclient.TableWithRange, + batcher *snapclient.Batcher, + errCh chan<- error, +) { + oldTableCount := 0 + defer func() { + // when things done, we must clean pending requests. + batcher.Close() + log.Info("doing postwork", + zap.Int("table count", oldTableCount), + ) + }() + + for { + select { + case <-ctx.Done(): + errCh <- ctx.Err() + return + case t, ok := <-inputCh: + if !ok { + return + } + oldTableCount += 1 + + batcher.Add(t) + } + } +} + +func getTiFlashNodeCount(ctx context.Context, pdClient pd.Client) (uint64, error) { + tiFlashStores, err := conn.GetAllTiKVStoresWithRetry(ctx, pdClient, connutil.TiFlashOnly) + if err != nil { + return 0, errors.Trace(err) + } + return uint64(len(tiFlashStores)), nil +} + +// PreCheckTableTiFlashReplica checks whether TiFlash replica is less than TiFlash node. +func PreCheckTableTiFlashReplica( + ctx context.Context, + pdClient pd.Client, + tables []*metautil.Table, + recorder *tiflashrec.TiFlashRecorder, +) error { + tiFlashStoreCount, err := getTiFlashNodeCount(ctx, pdClient) + if err != nil { + return err + } + for _, table := range tables { + if table.Info.TiFlashReplica != nil { + // we should not set available to true. because we cannot guarantee the raft log lag of tiflash when restore finished. + // just let tiflash ticker set it by checking lag of all related regions. 
+ table.Info.TiFlashReplica.Available = false + table.Info.TiFlashReplica.AvailablePartitionIDs = nil + if recorder != nil { + recorder.AddTable(table.Info.ID, *table.Info.TiFlashReplica) + log.Info("record tiflash replica for table, to reset it by ddl later", + zap.Stringer("db", table.DB.Name), + zap.Stringer("table", table.Info.Name), + ) + table.Info.TiFlashReplica = nil + } else if table.Info.TiFlashReplica.Count > tiFlashStoreCount { + // we cannot satisfy TiFlash replica in restore cluster. so we should + // set TiFlashReplica to unavailable in tableInfo, to avoid TiDB cannot sense TiFlash and make plan to TiFlash + // see details at https://github.com/pingcap/br/issues/931 + // TODO maybe set table.Info.TiFlashReplica.Count to tiFlashStoreCount, but we need more tests about it. + log.Warn("table does not satisfy tiflash replica requirements, set tiflash replcia to unavailable", + zap.Stringer("db", table.DB.Name), + zap.Stringer("table", table.Info.Name), + zap.Uint64("expect tiflash replica", table.Info.TiFlashReplica.Count), + zap.Uint64("actual tiflash store", tiFlashStoreCount), + ) + table.Info.TiFlashReplica = nil + } + } + } + return nil +} + +// PreCheckTableClusterIndex checks whether backup tables and existed tables have different cluster index options。 +func PreCheckTableClusterIndex( + tables []*metautil.Table, + ddlJobs []*model.Job, + dom *domain.Domain, +) error { + for _, table := range tables { + oldTableInfo, err := restore.GetTableSchema(dom, table.DB.Name, table.Info.Name) + // table exists in database + if err == nil { + if table.Info.IsCommonHandle != oldTableInfo.IsCommonHandle { + return errors.Annotatef(berrors.ErrRestoreModeMismatch, + "Clustered index option mismatch. Restored cluster's @@tidb_enable_clustered_index should be %v (backup table = %v, created table = %v).", + restore.TransferBoolToValue(table.Info.IsCommonHandle), + table.Info.IsCommonHandle, + oldTableInfo.IsCommonHandle) + } + } + } + for _, job := range ddlJobs { + if job.Type == model.ActionCreateTable { + tableInfo := job.BinlogInfo.TableInfo + if tableInfo != nil { + oldTableInfo, err := restore.GetTableSchema(dom, model.NewCIStr(job.SchemaName), tableInfo.Name) + // table exists in database + if err == nil { + if tableInfo.IsCommonHandle != oldTableInfo.IsCommonHandle { + return errors.Annotatef(berrors.ErrRestoreModeMismatch, + "Clustered index option mismatch. Restored cluster's @@tidb_enable_clustered_index should be %v (backup table = %v, created table = %v).", + restore.TransferBoolToValue(tableInfo.IsCommonHandle), + tableInfo.IsCommonHandle, + oldTableInfo.IsCommonHandle) + } + } + } + } + } + return nil +} + +func getDatabases(tables []*metautil.Table) (dbs []*model.DBInfo) { + dbIDs := make(map[int64]bool) + for _, table := range tables { + if !dbIDs[table.DB.ID] { + dbs = append(dbs, table.DB) + dbIDs[table.DB.ID] = true + } + } + return +} + +// FilterDDLJobs filters ddl jobs. +func FilterDDLJobs(allDDLJobs []*model.Job, tables []*metautil.Table) (ddlJobs []*model.Job) { + // Sort the ddl jobs by schema version in descending order. + slices.SortFunc(allDDLJobs, func(i, j *model.Job) int { + return cmp.Compare(j.BinlogInfo.SchemaVersion, i.BinlogInfo.SchemaVersion) + }) + dbs := getDatabases(tables) + for _, db := range dbs { + // These maps is for solving some corner case. + // e.g. 
let "t=2" indicates that the id of database "t" is 2, if the ddl execution sequence is: + // rename "a" to "b"(a=1) -> drop "b"(b=1) -> create "b"(b=2) -> rename "b" to "a"(a=2) + // Which we cannot find the "create" DDL by name and id directly. + // To cover †his case, we must find all names and ids the database/table ever had. + dbIDs := make(map[int64]bool) + dbIDs[db.ID] = true + dbNames := make(map[string]bool) + dbNames[db.Name.String()] = true + for _, job := range allDDLJobs { + if job.BinlogInfo.DBInfo != nil { + if dbIDs[job.SchemaID] || dbNames[job.BinlogInfo.DBInfo.Name.String()] { + ddlJobs = append(ddlJobs, job) + // The the jobs executed with the old id, like the step 2 in the example above. + dbIDs[job.SchemaID] = true + // For the jobs executed after rename, like the step 3 in the example above. + dbNames[job.BinlogInfo.DBInfo.Name.String()] = true + } + } + } + } + + for _, table := range tables { + tableIDs := make(map[int64]bool) + tableIDs[table.Info.ID] = true + tableNames := make(map[restore.UniqueTableName]bool) + name := restore.UniqueTableName{DB: table.DB.Name.String(), Table: table.Info.Name.String()} + tableNames[name] = true + for _, job := range allDDLJobs { + if job.BinlogInfo.TableInfo != nil { + name = restore.UniqueTableName{DB: job.SchemaName, Table: job.BinlogInfo.TableInfo.Name.String()} + if tableIDs[job.TableID] || tableNames[name] { + ddlJobs = append(ddlJobs, job) + tableIDs[job.TableID] = true + // For truncate table, the id may be changed + tableIDs[job.BinlogInfo.TableInfo.ID] = true + tableNames[name] = true + } + } + } + } + return ddlJobs +} + +// CheckDDLJobByRules if one of rules returns true, the job in srcDDLJobs will be filtered. +func CheckDDLJobByRules(srcDDLJobs []*model.Job, rules ...DDLJobFilterRule) error { + for _, ddlJob := range srcDDLJobs { + for _, rule := range rules { + if rule(ddlJob) { + return errors.Annotatef(berrors.ErrRestoreModeMismatch, "DDL job %s is not allowed in incremental restore"+ + " when --allow-pitr-from-incremental enabled", ddlJob.String()) + } + } + } + return nil +} + +// FilterDDLJobByRules if one of rules returns true, the job in srcDDLJobs will be filtered. +func FilterDDLJobByRules(srcDDLJobs []*model.Job, rules ...DDLJobFilterRule) (dstDDLJobs []*model.Job) { + dstDDLJobs = make([]*model.Job, 0, len(srcDDLJobs)) + for _, ddlJob := range srcDDLJobs { + passed := true + for _, rule := range rules { + if rule(ddlJob) { + passed = false + break + } + } + + if passed { + dstDDLJobs = append(dstDDLJobs, ddlJob) + } + } + + return +} + +type DDLJobFilterRule func(ddlJob *model.Job) bool + +var incrementalRestoreActionBlockList = map[model.ActionType]struct{}{ + model.ActionSetTiFlashReplica: {}, + model.ActionUpdateTiFlashReplicaStatus: {}, + model.ActionLockTable: {}, + model.ActionUnlockTable: {}, +} + +var logIncrementalRestoreCompactibleBlockList = map[model.ActionType]struct{}{ + model.ActionAddIndex: {}, + model.ActionModifyColumn: {}, + model.ActionReorganizePartition: {}, +} + +// DDLJobBlockListRule rule for filter ddl job with type in block list. 
+func DDLJobBlockListRule(ddlJob *model.Job) bool { + return checkIsInActions(ddlJob.Type, incrementalRestoreActionBlockList) +} + +func DDLJobLogIncrementalCompactBlockListRule(ddlJob *model.Job) bool { + return checkIsInActions(ddlJob.Type, logIncrementalRestoreCompactibleBlockList) +} + +func checkIsInActions(action model.ActionType, actions map[model.ActionType]struct{}) bool { + _, ok := actions[action] + return ok +} diff --git a/br/pkg/task/stream.go b/br/pkg/task/stream.go index 29e3177df7e0c..cf46760af7677 100644 --- a/br/pkg/task/stream.go +++ b/br/pkg/task/stream.go @@ -1174,9 +1174,9 @@ func RunStreamRestore( return errors.Trace(err) } - failpoint.Inject("failed-before-full-restore", func(_ failpoint.Value) { - failpoint.Return(errors.New("failpoint: failed before full restore")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("failed-before-full-restore")); _err_ == nil { + return errors.New("failpoint: failed before full restore") + } recorder := tiflashrec.New() cfg.tiflashRecorder = recorder @@ -1469,11 +1469,11 @@ func restoreStream( } } - failpoint.Inject("do-checksum-with-rewrite-rules", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("do-checksum-with-rewrite-rules")); _err_ == nil { if err := client.FailpointDoChecksumForLogRestore(ctx, mgr.GetStorage().GetClient(), mgr.GetPDClient(), idrules, rewriteRules); err != nil { - failpoint.Return(errors.Annotate(err, "failed to do checksum")) + return errors.Annotate(err, "failed to do checksum") } - }) + } gcDisabledRestorable = true diff --git a/br/pkg/task/stream.go__failpoint_stash__ b/br/pkg/task/stream.go__failpoint_stash__ new file mode 100644 index 0000000000000..29e3177df7e0c --- /dev/null +++ b/br/pkg/task/stream.go__failpoint_stash__ @@ -0,0 +1,1846 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
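+
+// Note: files with the __failpoint_stash__ suffix keep the original,
+// pre-rewrite source, presumably produced by the failpoint tooling
+// (failpoint-ctl) when failpoints are enabled. The rewrite replaces a marker
+// such as
+//
+//	failpoint.Inject("name", func(_ failpoint.Value) { ... })
+//
+// with the expanded form seen in the hunks above:
+//
+//	if _, _err_ := failpoint.Eval(_curpkg_("name")); _err_ == nil { ... }
+//
+// so that the original file can be put back when failpoints are disabled.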
+ +package task + +import ( + "bytes" + "context" + "encoding/binary" + "fmt" + "math" + "net/http" + "slices" + "strings" + "sync" + "time" + + "github.com/docker/go-units" + "github.com/fatih/color" + "github.com/opentracing/opentracing-go" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/backup" + "github.com/pingcap/tidb/br/pkg/checkpoint" + "github.com/pingcap/tidb/br/pkg/conn" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/glue" + "github.com/pingcap/tidb/br/pkg/httputil" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/metautil" + "github.com/pingcap/tidb/br/pkg/restore" + "github.com/pingcap/tidb/br/pkg/restore/ingestrec" + logclient "github.com/pingcap/tidb/br/pkg/restore/log_client" + "github.com/pingcap/tidb/br/pkg/restore/tiflashrec" + restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/br/pkg/stream" + "github.com/pingcap/tidb/br/pkg/streamhelper" + advancercfg "github.com/pingcap/tidb/br/pkg/streamhelper/config" + "github.com/pingcap/tidb/br/pkg/streamhelper/daemon" + "github.com/pingcap/tidb/br/pkg/summary" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/util/cdcutil" + "github.com/spf13/pflag" + "github.com/tikv/client-go/v2/oracle" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" +) + +const ( + flagYes = "yes" + flagUntil = "until" + flagStreamJSONOutput = "json" + flagStreamTaskName = "task-name" + flagStreamStartTS = "start-ts" + flagStreamEndTS = "end-ts" + flagGCSafePointTTS = "gc-ttl" + + truncateLockPath = "truncating.lock" + hintOnTruncateLock = "There might be another truncate task running, or a truncate task that didn't exit properly. " + + "You may check the metadata and continue by wait other task finish or manually delete the lock file " + truncateLockPath + " at the external storage." +) + +var ( + StreamStart = "log start" + StreamStop = "log stop" + StreamPause = "log pause" + StreamResume = "log resume" + StreamStatus = "log status" + StreamTruncate = "log truncate" + StreamMetadata = "log metadata" + StreamCtl = "log advancer" + + skipSummaryCommandList = map[string]struct{}{ + StreamStatus: {}, + StreamTruncate: {}, + } + + streamShiftDuration = time.Hour +) + +var StreamCommandMap = map[string]func(c context.Context, g glue.Glue, cmdName string, cfg *StreamConfig) error{ + StreamStart: RunStreamStart, + StreamStop: RunStreamStop, + StreamPause: RunStreamPause, + StreamResume: RunStreamResume, + StreamStatus: RunStreamStatus, + StreamTruncate: RunStreamTruncate, + StreamMetadata: RunStreamMetadata, + StreamCtl: RunStreamAdvancer, +} + +// StreamConfig specifies the configure about backup stream +type StreamConfig struct { + Config + + TaskName string `json:"task-name" toml:"task-name"` + + // StartTS usually equals the tso of full-backup, but user can reset it + StartTS uint64 `json:"start-ts" toml:"start-ts"` + EndTS uint64 `json:"end-ts" toml:"end-ts"` + // SafePointTTL ensures TiKV can scan entries not being GC at [startTS, currentTS] + SafePointTTL int64 `json:"safe-point-ttl" toml:"safe-point-ttl"` + + // Spec for the command `truncate`, we should truncate the until when? 
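+	// That is, Until is the TS up to which `br log truncate` removes log backup data.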
+ Until uint64 `json:"until" toml:"until"` + DryRun bool `json:"dry-run" toml:"dry-run"` + SkipPrompt bool `json:"skip-prompt" toml:"skip-prompt"` + + // Spec for the command `status`. + JSONOutput bool `json:"json-output" toml:"json-output"` + + // Spec for the command `advancer`. + AdvancerCfg advancercfg.Config `json:"advancer-config" toml:"advancer-config"` +} + +func (cfg *StreamConfig) makeStorage(ctx context.Context) (storage.ExternalStorage, error) { + u, err := storage.ParseBackend(cfg.Storage, &cfg.BackendOptions) + if err != nil { + return nil, errors.Trace(err) + } + opts := getExternalStorageOptions(&cfg.Config, u) + storage, err := storage.New(ctx, u, &opts) + if err != nil { + return nil, errors.Trace(err) + } + return storage, nil +} + +// DefineStreamStartFlags defines flags used for `stream start` +func DefineStreamStartFlags(flags *pflag.FlagSet) { + DefineStreamCommonFlags(flags) + + flags.String(flagStreamStartTS, "", + "usually equals last full backupTS, used for backup log. Default value is current ts.\n"+ + "support TSO or datetime, e.g. '400036290571534337' or '2018-05-11 01:42:23+0800'.") + // 999999999999999999 means 2090-11-18 22:07:45 + flags.String(flagStreamEndTS, "999999999999999999", "end ts, indicate stopping observe after endTS"+ + "support TSO or datetime") + _ = flags.MarkHidden(flagStreamEndTS) + flags.Int64(flagGCSafePointTTS, utils.DefaultStreamStartSafePointTTL, + "the TTL (in seconds) that PD holds for BR's GC safepoint") + _ = flags.MarkHidden(flagGCSafePointTTS) +} + +func DefineStreamPauseFlags(flags *pflag.FlagSet) { + DefineStreamCommonFlags(flags) + flags.Int64(flagGCSafePointTTS, utils.DefaultStreamPauseSafePointTTL, + "the TTL (in seconds) that PD holds for BR's GC safepoint") +} + +// DefineStreamCommonFlags define common flags for `stream task` +func DefineStreamCommonFlags(flags *pflag.FlagSet) { + flags.String(flagStreamTaskName, "", "The task name for the backup log task.") +} + +func DefineStreamStatusCommonFlags(flags *pflag.FlagSet) { + flags.String(flagStreamTaskName, stream.WildCard, + "The task name for backup stream log. If default, get status of all of tasks", + ) + flags.Bool(flagStreamJSONOutput, false, + "Print JSON as the output.", + ) +} + +func DefineStreamTruncateLogFlags(flags *pflag.FlagSet) { + flags.String(flagUntil, "", "Remove all backup data until this TS."+ + "(support TSO or datetime, e.g. 
'400036290571534337' or '2018-05-11 01:42:23+0800'.)") + flags.Bool(flagDryRun, false, "Run the command but don't really delete the files.") + flags.BoolP(flagYes, "y", false, "Skip all prompts and always execute the command.") +} + +func (cfg *StreamConfig) ParseStreamStatusFromFlags(flags *pflag.FlagSet) error { + var err error + cfg.JSONOutput, err = flags.GetBool(flagStreamJSONOutput) + if err != nil { + return errors.Trace(err) + } + + if err = cfg.ParseStreamCommonFromFlags(flags); err != nil { + return errors.Trace(err) + } + + return nil +} + +func (cfg *StreamConfig) ParseStreamTruncateFromFlags(flags *pflag.FlagSet) error { + tsString, err := flags.GetString(flagUntil) + if err != nil { + return errors.Trace(err) + } + if cfg.Until, err = ParseTSString(tsString, true); err != nil { + return errors.Trace(err) + } + if cfg.SkipPrompt, err = flags.GetBool(flagYes); err != nil { + return errors.Trace(err) + } + if cfg.DryRun, err = flags.GetBool(flagDryRun); err != nil { + return errors.Trace(err) + } + return nil +} + +// ParseStreamStartFromFlags parse parameters for `stream start` +func (cfg *StreamConfig) ParseStreamStartFromFlags(flags *pflag.FlagSet) error { + err := cfg.ParseStreamCommonFromFlags(flags) + if err != nil { + return errors.Trace(err) + } + + tsString, err := flags.GetString(flagStreamStartTS) + if err != nil { + return errors.Trace(err) + } + + if cfg.StartTS, err = ParseTSString(tsString, true); err != nil { + return errors.Trace(err) + } + + tsString, err = flags.GetString(flagStreamEndTS) + if err != nil { + return errors.Trace(err) + } + + if cfg.EndTS, err = ParseTSString(tsString, true); err != nil { + return errors.Trace(err) + } + + if cfg.SafePointTTL, err = flags.GetInt64(flagGCSafePointTTS); err != nil { + return errors.Trace(err) + } + + if cfg.SafePointTTL <= 0 { + cfg.SafePointTTL = utils.DefaultStreamStartSafePointTTL + } + + return nil +} + +// ParseStreamPauseFromFlags parse parameters for `stream pause` +func (cfg *StreamConfig) ParseStreamPauseFromFlags(flags *pflag.FlagSet) error { + err := cfg.ParseStreamCommonFromFlags(flags) + if err != nil { + return errors.Trace(err) + } + + if cfg.SafePointTTL, err = flags.GetInt64(flagGCSafePointTTS); err != nil { + return errors.Trace(err) + } + if cfg.SafePointTTL <= 0 { + cfg.SafePointTTL = utils.DefaultStreamPauseSafePointTTL + } + return nil +} + +// ParseStreamCommonFromFlags parse parameters for `stream task` +func (cfg *StreamConfig) ParseStreamCommonFromFlags(flags *pflag.FlagSet) error { + var err error + + cfg.TaskName, err = flags.GetString(flagStreamTaskName) + if err != nil { + return errors.Trace(err) + } + + if len(cfg.TaskName) <= 0 { + return errors.Annotate(berrors.ErrInvalidArgument, "Miss parameters task-name") + } + return nil +} + +type streamMgr struct { + cfg *StreamConfig + mgr *conn.Mgr + bc *backup.Client + httpCli *http.Client +} + +func NewStreamMgr(ctx context.Context, cfg *StreamConfig, g glue.Glue, isStreamStart bool) (*streamMgr, error) { + mgr, err := NewMgr(ctx, g, cfg.PD, cfg.TLS, GetKeepalive(&cfg.Config), + cfg.CheckRequirements, false, conn.StreamVersionChecker) + if err != nil { + return nil, errors.Trace(err) + } + defer func() { + if err != nil { + mgr.Close() + } + }() + + // just stream start need Storage + s := &streamMgr{ + cfg: cfg, + mgr: mgr, + } + if isStreamStart { + client := backup.NewBackupClient(ctx, mgr) + + backend, err := storage.ParseBackend(cfg.Storage, &cfg.BackendOptions) + if err != nil { + return nil, errors.Trace(err) + } + + opts := 
storage.ExternalStorageOptions{ + NoCredentials: cfg.NoCreds, + SendCredentials: cfg.SendCreds, + CheckS3ObjectLockOptions: true, + } + if err = client.SetStorage(ctx, backend, &opts); err != nil { + return nil, errors.Trace(err) + } + s.bc = client + + // create http client to do some requirements check. + s.httpCli = httputil.NewClient(mgr.GetTLSConfig()) + } + return s, nil +} + +func (s *streamMgr) close() { + s.mgr.Close() +} + +func (s *streamMgr) checkLock(ctx context.Context) (bool, error) { + return s.bc.GetStorage().FileExists(ctx, metautil.LockFile) +} + +func (s *streamMgr) setLock(ctx context.Context) error { + return s.bc.SetLockFile(ctx) +} + +// adjustAndCheckStartTS checks that startTS should be smaller than currentTS, +// and endTS is larger than currentTS. +func (s *streamMgr) adjustAndCheckStartTS(ctx context.Context) error { + currentTS, err := s.mgr.GetTS(ctx) + if err != nil { + return errors.Trace(err) + } + // set currentTS to startTS as a default value + if s.cfg.StartTS == 0 { + s.cfg.StartTS = currentTS + } + + if currentTS < s.cfg.StartTS { + return errors.Annotatef(berrors.ErrInvalidArgument, + "invalid timestamps, startTS %d should be smaller than currentTS %d", + s.cfg.StartTS, currentTS) + } + if s.cfg.EndTS <= currentTS { + return errors.Annotatef(berrors.ErrInvalidArgument, + "invalid timestamps, endTS %d should be larger than currentTS %d", + s.cfg.EndTS, currentTS) + } + + return nil +} + +// checkImportTaskRunning checks whether there is any import task running. +func (s *streamMgr) checkImportTaskRunning(ctx context.Context, etcdCLI *clientv3.Client) error { + list, err := utils.GetImportTasksFrom(ctx, etcdCLI) + if err != nil { + return errors.Trace(err) + } + if !list.Empty() { + return errors.Errorf("There are some lightning/restore tasks running: %s"+ + "please stop or wait finishing at first. "+ + "If the lightning/restore task is forced to terminate by system, "+ + "please wait for ttl to decrease to 0.", list.MessageToUser()) + } + return nil +} + +// setGCSafePoint sets the server safe point to PD. +func (s *streamMgr) setGCSafePoint(ctx context.Context, sp utils.BRServiceSafePoint) error { + err := utils.CheckGCSafePoint(ctx, s.mgr.GetPDClient(), sp.BackupTS) + if err != nil { + return errors.Annotatef(err, + "failed to check gc safePoint, ts %v", sp.BackupTS) + } + + err = utils.UpdateServiceSafePoint(ctx, s.mgr.GetPDClient(), sp) + if err != nil { + return errors.Trace(err) + } + + log.Info("set stream safePoint", zap.Object("safePoint", sp)) + return nil +} + +func (s *streamMgr) buildObserveRanges() ([]kv.KeyRange, error) { + dRanges, err := stream.BuildObserveDataRanges( + s.mgr.GetStorage(), + s.cfg.FilterStr, + s.cfg.TableFilter, + s.cfg.StartTS, + ) + if err != nil { + return nil, errors.Trace(err) + } + + mRange := stream.BuildObserveMetaRange() + rs := append([]kv.KeyRange{*mRange}, dRanges...) 
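+	// Sort the observed ranges (the meta range plus all data ranges) by start
+	// key, which keeps the registered range list deterministic.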
+ slices.SortFunc(rs, func(i, j kv.KeyRange) int { + return bytes.Compare(i.StartKey, j.StartKey) + }) + + return rs, nil +} + +func (s *streamMgr) backupFullSchemas(ctx context.Context) error { + clusterVersion, err := s.mgr.GetClusterVersion(ctx) + if err != nil { + return errors.Trace(err) + } + + metaWriter := metautil.NewMetaWriter(s.bc.GetStorage(), metautil.MetaFileSize, true, metautil.MetaFile, nil) + metaWriter.Update(func(m *backuppb.BackupMeta) { + // save log startTS to backupmeta file + m.StartVersion = s.cfg.StartTS + m.ClusterId = s.bc.GetClusterID() + m.ClusterVersion = clusterVersion + }) + + if err = metaWriter.FlushBackupMeta(ctx); err != nil { + return errors.Trace(err) + } + return nil +} + +func (s *streamMgr) checkStreamStartEnable(ctx context.Context) error { + supportStream, err := s.mgr.IsLogBackupEnabled(ctx, s.httpCli) + if err != nil { + return errors.Trace(err) + } + if !supportStream { + return errors.New("Unable to create task about log-backup. " + + "please set TiKV config `log-backup.enable` to true and restart TiKVs.") + } + + return nil +} + +type RestoreFunc func(string) error + +// KeepGcDisabled keeps GC disabled and return a function that used to gc enabled. +// gc.ratio-threshold = "-1.0", which represents disable gc in TiKV. +func KeepGcDisabled(g glue.Glue, store kv.Storage) (RestoreFunc, string, error) { + se, err := g.CreateSession(store) + if err != nil { + return nil, "", errors.Trace(err) + } + + execCtx := se.GetSessionCtx().GetRestrictedSQLExecutor() + oldRatio, err := utils.GetGcRatio(execCtx) + if err != nil { + return nil, "", errors.Trace(err) + } + + newRatio := "-1.0" + err = utils.SetGcRatio(execCtx, newRatio) + if err != nil { + return nil, "", errors.Trace(err) + } + + // If the oldRatio is negative, which is not normal status. + // It should set default value "1.1" after PiTR finished. 
+ if strings.HasPrefix(oldRatio, "-") { + oldRatio = utils.DefaultGcRatioVal + } + + return func(ratio string) error { + return utils.SetGcRatio(execCtx, ratio) + }, oldRatio, nil +} + +// RunStreamCommand run all kinds of `stream task` +func RunStreamCommand( + ctx context.Context, + g glue.Glue, + cmdName string, + cfg *StreamConfig, +) error { + cfg.Config.adjust() + defer func() { + if _, ok := skipSummaryCommandList[cmdName]; !ok { + summary.Summary(cmdName) + } + }() + commandFn, exist := StreamCommandMap[cmdName] + if !exist { + return errors.Annotatef(berrors.ErrInvalidArgument, "invalid command %s", cmdName) + } + + if err := commandFn(ctx, g, cmdName, cfg); err != nil { + log.Error("failed to stream", zap.String("command", cmdName), zap.Error(err)) + summary.SetSuccessStatus(false) + summary.CollectFailureUnit(cmdName, err) + return err + } + summary.SetSuccessStatus(true) + return nil +} + +// RunStreamStart specifies starting a stream task +func RunStreamStart( + c context.Context, + g glue.Glue, + cmdName string, + cfg *StreamConfig, +) error { + ctx, cancelFn := context.WithCancel(c) + defer cancelFn() + + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan("task.RunStreamStart", opentracing.ChildOf(span.Context())) + defer span1.Finish() + ctx = opentracing.ContextWithSpan(ctx, span1) + } + + streamMgr, err := NewStreamMgr(ctx, cfg, g, true) + if err != nil { + return errors.Trace(err) + } + defer streamMgr.close() + + if err = streamMgr.checkStreamStartEnable(ctx); err != nil { + return errors.Trace(err) + } + if err = streamMgr.adjustAndCheckStartTS(ctx); err != nil { + return errors.Trace(err) + } + + etcdCLI, err := dialEtcdWithCfg(ctx, cfg.Config) + if err != nil { + return errors.Trace(err) + } + cli := streamhelper.NewMetaDataClient(etcdCLI) + defer func() { + if closeErr := cli.Close(); closeErr != nil { + log.Warn("failed to close etcd client", zap.Error(closeErr)) + } + }() + if err = streamMgr.checkImportTaskRunning(ctx, cli.Client); err != nil { + return errors.Trace(err) + } + // It supports single stream log task currently. + if count, err := cli.GetTaskCount(ctx); err != nil { + return errors.Trace(err) + } else if count > 0 { + return errors.Annotate(berrors.ErrStreamLogTaskExist, "It supports single stream log task currently") + } + + exist, err := streamMgr.checkLock(ctx) + if err != nil { + return errors.Trace(err) + } + // exist is true, which represents restart a stream task. Or create a new stream task. 
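+	// When restarting, StartTS is advanced to the max TS already recorded in
+	// the log storage; when creating, the lock file is written and the current
+	// schemas are backed up before the task is registered.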
+ if exist { + logInfo, err := getLogRange(ctx, &cfg.Config) + if err != nil { + return errors.Trace(err) + } + if logInfo.clusterID > 0 && logInfo.clusterID != streamMgr.bc.GetClusterID() { + return errors.Annotatef(berrors.ErrInvalidArgument, + "the stream log files from cluster ID:%v and current cluster ID:%v ", + logInfo.clusterID, streamMgr.bc.GetClusterID()) + } + + cfg.StartTS = logInfo.logMaxTS + if err = streamMgr.setGCSafePoint( + ctx, + utils.BRServiceSafePoint{ + ID: utils.MakeSafePointID(), + TTL: cfg.SafePointTTL, + BackupTS: cfg.StartTS, + }, + ); err != nil { + return errors.Trace(err) + } + } else { + if err = streamMgr.setGCSafePoint( + ctx, + utils.BRServiceSafePoint{ + ID: utils.MakeSafePointID(), + TTL: cfg.SafePointTTL, + BackupTS: cfg.StartTS, + }, + ); err != nil { + return errors.Trace(err) + } + if err = streamMgr.setLock(ctx); err != nil { + return errors.Trace(err) + } + if err = streamMgr.backupFullSchemas(ctx); err != nil { + return errors.Trace(err) + } + } + + ranges, err := streamMgr.buildObserveRanges() + if err != nil { + return errors.Trace(err) + } else if len(ranges) == 0 { + // nothing to backup + pdAddress := strings.Join(cfg.PD, ",") + log.Warn("Nothing to observe, maybe connected to cluster for restoring", + zap.String("PD address", pdAddress)) + return errors.Annotate(berrors.ErrInvalidArgument, "nothing need to observe") + } + + ti := streamhelper.TaskInfo{ + PBInfo: backuppb.StreamBackupTaskInfo{ + Storage: streamMgr.bc.GetStorageBackend(), + StartTs: cfg.StartTS, + EndTs: cfg.EndTS, + Name: cfg.TaskName, + TableFilter: cfg.FilterStr, + CompressionType: backuppb.CompressionType_ZSTD, + }, + Ranges: ranges, + Pausing: false, + } + if err = cli.PutTask(ctx, ti); err != nil { + return errors.Trace(err) + } + summary.Log(cmdName, ti.ZapTaskInfo()...) 
+ return nil +} + +func RunStreamMetadata( + c context.Context, + g glue.Glue, + cmdName string, + cfg *StreamConfig, +) error { + ctx, cancelFn := context.WithCancel(c) + defer cancelFn() + + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan( + "task.RunStreamCheckLog", + opentracing.ChildOf(span.Context()), + ) + defer span1.Finish() + ctx = opentracing.ContextWithSpan(ctx, span1) + } + + logInfo, err := getLogRange(ctx, &cfg.Config) + if err != nil { + return errors.Trace(err) + } + + logMinDate := stream.FormatDate(oracle.GetTimeFromTS(logInfo.logMinTS)) + logMaxDate := stream.FormatDate(oracle.GetTimeFromTS(logInfo.logMaxTS)) + summary.Log(cmdName, zap.Uint64("log-min-ts", logInfo.logMinTS), + zap.String("log-min-date", logMinDate), + zap.Uint64("log-max-ts", logInfo.logMaxTS), + zap.String("log-max-date", logMaxDate), + ) + return nil +} + +// RunStreamStop specifies stoping a stream task +func RunStreamStop( + c context.Context, + g glue.Glue, + cmdName string, + cfg *StreamConfig, +) error { + ctx, cancelFn := context.WithCancel(c) + defer cancelFn() + + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan( + "task.RunStreamStop", + opentracing.ChildOf(span.Context()), + ) + defer span1.Finish() + ctx = opentracing.ContextWithSpan(ctx, span1) + } + + streamMgr, err := NewStreamMgr(ctx, cfg, g, false) + if err != nil { + return errors.Trace(err) + } + defer streamMgr.close() + + etcdCLI, err := dialEtcdWithCfg(ctx, cfg.Config) + if err != nil { + return errors.Trace(err) + } + cli := streamhelper.NewMetaDataClient(etcdCLI) + defer func() { + if closeErr := cli.Close(); closeErr != nil { + log.Warn("failed to close etcd client", zap.Error(closeErr)) + } + }() + // to add backoff + ti, err := cli.GetTask(ctx, cfg.TaskName) + if err != nil { + return errors.Trace(err) + } + + if err = cli.DeleteTask(ctx, cfg.TaskName); err != nil { + return errors.Trace(err) + } + + if err := streamMgr.setGCSafePoint(ctx, + utils.BRServiceSafePoint{ + ID: buildPauseSafePointName(ti.Info.Name), + TTL: 0, // 0 means remove this service safe point. + BackupTS: math.MaxUint64, + }, + ); err != nil { + log.Warn("failed to remove safe point", zap.String("error", err.Error())) + } + + summary.Log(cmdName, logutil.StreamBackupTaskInfo(&ti.Info)) + return nil +} + +// RunStreamPause specifies pausing a stream task. 
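+// While the task is paused, a service safe point pinned at the task's global
+// checkpoint keeps GC from collecting data the task still needs; the safe
+// point's TTL comes from the `--gc-ttl` flag.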
+func RunStreamPause( + c context.Context, + g glue.Glue, + cmdName string, + cfg *StreamConfig, +) error { + ctx, cancelFn := context.WithCancel(c) + defer cancelFn() + + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan( + "task.RunStreamPause", + opentracing.ChildOf(span.Context()), + ) + defer span1.Finish() + ctx = opentracing.ContextWithSpan(ctx, span1) + } + + streamMgr, err := NewStreamMgr(ctx, cfg, g, false) + if err != nil { + return errors.Trace(err) + } + defer streamMgr.close() + + etcdCLI, err := dialEtcdWithCfg(ctx, cfg.Config) + if err != nil { + return errors.Trace(err) + } + cli := streamhelper.NewMetaDataClient(etcdCLI) + defer func() { + if closeErr := cli.Close(); closeErr != nil { + log.Warn("failed to close etcd client", zap.Error(closeErr)) + } + }() + // to add backoff + ti, isPaused, err := cli.GetTaskWithPauseStatus(ctx, cfg.TaskName) + if err != nil { + return errors.Trace(err) + } else if isPaused { + return errors.Annotatef(berrors.ErrKVUnknown, "The task %s is paused already.", cfg.TaskName) + } + + globalCheckPointTS, err := ti.GetGlobalCheckPointTS(ctx) + if err != nil { + return errors.Trace(err) + } + if err = streamMgr.setGCSafePoint( + ctx, + utils.BRServiceSafePoint{ + ID: buildPauseSafePointName(ti.Info.Name), + TTL: cfg.SafePointTTL, + BackupTS: globalCheckPointTS, + }, + ); err != nil { + return errors.Trace(err) + } + + err = cli.PauseTask(ctx, cfg.TaskName) + if err != nil { + return errors.Trace(err) + } + + summary.Log(cmdName, logutil.StreamBackupTaskInfo(&ti.Info)) + return nil +} + +// RunStreamResume specifies resuming a stream task. +func RunStreamResume( + c context.Context, + g glue.Glue, + cmdName string, + cfg *StreamConfig, +) error { + ctx, cancelFn := context.WithCancel(c) + defer cancelFn() + + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan( + "task.RunStreamResume", + opentracing.ChildOf(span.Context()), + ) + defer span1.Finish() + ctx = opentracing.ContextWithSpan(ctx, span1) + } + + streamMgr, err := NewStreamMgr(ctx, cfg, g, false) + if err != nil { + return errors.Trace(err) + } + defer streamMgr.close() + + etcdCLI, err := dialEtcdWithCfg(ctx, cfg.Config) + if err != nil { + return errors.Trace(err) + } + cli := streamhelper.NewMetaDataClient(etcdCLI) + defer func() { + if closeErr := cli.Close(); closeErr != nil { + log.Warn("failed to close etcd client", zap.Error(closeErr)) + } + }() + // to add backoff + ti, isPaused, err := cli.GetTaskWithPauseStatus(ctx, cfg.TaskName) + if err != nil { + return errors.Trace(err) + } else if !isPaused { + return errors.Annotatef(berrors.ErrKVUnknown, + "The task %s is active already.", cfg.TaskName) + } + + globalCheckPointTS, err := ti.GetGlobalCheckPointTS(ctx) + if err != nil { + return errors.Trace(err) + } + err = utils.CheckGCSafePoint(ctx, streamMgr.mgr.GetPDClient(), globalCheckPointTS) + if err != nil { + return errors.Annotatef(err, "the global checkpoint ts: %v(%s) has been gc. 
", + globalCheckPointTS, oracle.GetTimeFromTS(globalCheckPointTS)) + } + + err = cli.ResumeTask(ctx, cfg.TaskName) + if err != nil { + return errors.Trace(err) + } + + err = cli.CleanLastErrorOfTask(ctx, cfg.TaskName) + if err != nil { + return err + } + + if err := streamMgr.setGCSafePoint(ctx, + utils.BRServiceSafePoint{ + ID: buildPauseSafePointName(ti.Info.Name), + TTL: utils.DefaultStreamStartSafePointTTL, + BackupTS: globalCheckPointTS, + }, + ); err != nil { + log.Warn("failed to remove safe point", + zap.Uint64("safe-point", globalCheckPointTS), zap.String("error", err.Error())) + } + + summary.Log(cmdName, logutil.StreamBackupTaskInfo(&ti.Info)) + return nil +} + +func RunStreamAdvancer(c context.Context, g glue.Glue, cmdName string, cfg *StreamConfig) error { + ctx, cancel := context.WithCancel(c) + defer cancel() + mgr, err := NewMgr(ctx, g, cfg.PD, cfg.TLS, GetKeepalive(&cfg.Config), + cfg.CheckRequirements, false, conn.StreamVersionChecker) + if err != nil { + return err + } + + etcdCLI, err := dialEtcdWithCfg(ctx, cfg.Config) + if err != nil { + return err + } + env := streamhelper.CliEnv(mgr.StoreManager, mgr.GetStore(), etcdCLI) + advancer := streamhelper.NewCheckpointAdvancer(env) + advancer.UpdateConfig(cfg.AdvancerCfg) + advancerd := daemon.New(advancer, streamhelper.OwnerManagerForLogBackup(ctx, etcdCLI), cfg.AdvancerCfg.TickDuration) + loop, err := advancerd.Begin(ctx) + if err != nil { + return err + } + loop() + return nil +} + +func checkConfigForStatus(pd []string) error { + if len(pd) == 0 { + return errors.Annotatef(berrors.ErrInvalidArgument, + "the command needs access to PD, please specify `-u` or `--pd`") + } + + return nil +} + +// makeStatusController makes the status controller via some config. +// this should better be in the `stream` package but it is impossible because of cyclic requirements. 
+func makeStatusController(ctx context.Context, cfg *StreamConfig, g glue.Glue) (*stream.StatusController, error) { + console := glue.GetConsole(g) + etcdCLI, err := dialEtcdWithCfg(ctx, cfg.Config) + if err != nil { + return nil, err + } + cli := streamhelper.NewMetaDataClient(etcdCLI) + var printer stream.TaskPrinter + if !cfg.JSONOutput { + printer = stream.PrintTaskByTable(console) + } else { + printer = stream.PrintTaskWithJSON(console) + } + mgr, err := NewMgr(ctx, g, cfg.PD, cfg.TLS, GetKeepalive(&cfg.Config), + cfg.CheckRequirements, false, conn.StreamVersionChecker) + if err != nil { + return nil, err + } + return stream.NewStatusController(cli, mgr, printer), nil +} + +// RunStreamStatus get status for a specific stream task +func RunStreamStatus( + c context.Context, + g glue.Glue, + cmdName string, + cfg *StreamConfig, +) error { + ctx, cancelFn := context.WithCancel(c) + defer cancelFn() + + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan( + "task.RunStreamStatus", + opentracing.ChildOf(span.Context()), + ) + defer span1.Finish() + ctx = opentracing.ContextWithSpan(ctx, span1) + } + + if err := checkConfigForStatus(cfg.PD); err != nil { + return err + } + ctl, err := makeStatusController(ctx, cfg, g) + if err != nil { + return err + } + + defer func() { + if closeErr := ctl.Close(); closeErr != nil { + log.Warn("failed to close etcd client", zap.Error(closeErr)) + } + }() + return ctl.PrintStatusOfTask(ctx, cfg.TaskName) +} + +// RunStreamTruncate truncates the log that belong to (0, until-ts) +func RunStreamTruncate(c context.Context, g glue.Glue, cmdName string, cfg *StreamConfig) (err error) { + console := glue.GetConsole(g) + em := color.New(color.Bold).SprintFunc() + warn := color.New(color.Bold, color.FgHiRed).SprintFunc() + formatTS := func(ts uint64) string { + return oracle.GetTimeFromTS(ts).Format("2006-01-02 15:04:05.0000") + } + if cfg.Until == 0 { + return errors.Annotatef(berrors.ErrInvalidArgument, "please provide the `--until` ts") + } + + ctx, cancelFn := context.WithCancel(c) + defer cancelFn() + + extStorage, err := cfg.makeStorage(ctx) + if err != nil { + return err + } + if err := storage.TryLockRemote(ctx, extStorage, truncateLockPath, hintOnTruncateLock); err != nil { + return err + } + defer utils.WithCleanUp(&err, 10*time.Second, func(ctx context.Context) error { + return storage.UnlockRemote(ctx, extStorage, truncateLockPath) + }) + + sp, err := stream.GetTSFromFile(ctx, extStorage, stream.TruncateSafePointFileName) + if err != nil { + return err + } + + if cfg.Until < sp { + console.Println("According to the log, you have truncated backup data before", em(formatTS(sp))) + if !cfg.SkipPrompt && !console.PromptBool("Continue? ") { + return nil + } + } + + readMetaDone := console.ShowTask("Reading Metadata... 
", glue.WithTimeCost()) + metas := stream.StreamMetadataSet{ + MetadataDownloadBatchSize: cfg.MetadataDownloadBatchSize, + Helper: stream.NewMetadataHelper(), + DryRun: cfg.DryRun, + } + shiftUntilTS, err := metas.LoadUntilAndCalculateShiftTS(ctx, extStorage, cfg.Until) + if err != nil { + return err + } + readMetaDone() + + var ( + fileCount int = 0 + kvCount int64 = 0 + totalSize uint64 = 0 + ) + + metas.IterateFilesFullyBefore(shiftUntilTS, func(d *stream.FileGroupInfo) (shouldBreak bool) { + fileCount++ + totalSize += d.Length + kvCount += d.KVCount + return + }) + console.Printf("We are going to remove %s files, until %s.\n", + em(fileCount), + em(formatTS(cfg.Until)), + ) + if !cfg.SkipPrompt && !console.PromptBool(warn("Sure? ")) { + return nil + } + + if cfg.Until > sp && !cfg.DryRun { + if err := stream.SetTSToFile( + ctx, extStorage, cfg.Until, stream.TruncateSafePointFileName); err != nil { + return err + } + } + + // begin to remove + p := console.StartProgressBar( + "Clearing Data Files and Metadata", fileCount, + glue.WithTimeCost(), + glue.WithConstExtraField("kv-count", kvCount), + glue.WithConstExtraField("kv-size", fmt.Sprintf("%d(%s)", totalSize, units.HumanSize(float64(totalSize)))), + ) + defer p.Close() + + notDeleted, err := metas.RemoveDataFilesAndUpdateMetadataInBatch(ctx, shiftUntilTS, extStorage, p.IncBy) + if err != nil { + return err + } + + if err := p.Wait(ctx); err != nil { + return err + } + + if len(notDeleted) > 0 { + const keepFirstNFailure = 16 + console.Println("Files below are not deleted due to error, you may clear it manually, check log for detail error:") + console.Println("- Total", em(len(notDeleted)), "items.") + if len(notDeleted) > keepFirstNFailure { + console.Println("-", em(len(notDeleted)-keepFirstNFailure), "items omitted.") + // TODO: maybe don't add them at the very first. + notDeleted = notDeleted[:keepFirstNFailure] + } + for _, f := range notDeleted { + console.Println(f) + } + } + + return nil +} + +// checkTaskExists checks whether there is a log backup task running. +// If so, return an error. +func checkTaskExists(ctx context.Context, cfg *RestoreConfig, etcdCLI *clientv3.Client) error { + if err := checkConfigForStatus(cfg.PD); err != nil { + return err + } + + cli := streamhelper.NewMetaDataClient(etcdCLI) + // check log backup task + tasks, err := cli.GetAllTasks(ctx) + if err != nil { + return err + } + if len(tasks) > 0 { + return errors.Errorf("log backup task is running: %s, "+ + "please stop the task before restore, and after PITR operation finished, "+ + "create log-backup task again and create a full backup on this cluster", tasks[0].Info.Name) + } + + return nil +} + +func checkIncompatibleChangefeed(ctx context.Context, backupTS uint64, etcdCLI *clientv3.Client) error { + nameSet, err := cdcutil.GetIncompatibleChangefeedsWithSafeTS(ctx, etcdCLI, backupTS) + if err != nil { + return err + } + if !nameSet.Empty() { + return errors.Errorf("%splease remove changefeed(s) before restore", nameSet.MessageToUser()) + } + return nil +} + +// RunStreamRestore restores stream log. 
+func RunStreamRestore( + c context.Context, + g glue.Glue, + cmdName string, + cfg *RestoreConfig, +) (err error) { + ctx, cancelFn := context.WithCancel(c) + defer cancelFn() + + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan("task.RunStreamRestore", opentracing.ChildOf(span.Context())) + defer span1.Finish() + ctx = opentracing.ContextWithSpan(ctx, span1) + } + _, s, err := GetStorage(ctx, cfg.Config.Storage, &cfg.Config) + if err != nil { + return errors.Trace(err) + } + logInfo, err := getLogRangeWithStorage(ctx, s) + if err != nil { + return errors.Trace(err) + } + if cfg.RestoreTS == 0 { + cfg.RestoreTS = logInfo.logMaxTS + } + + if len(cfg.FullBackupStorage) > 0 { + startTS, fullClusterID, err := getFullBackupTS(ctx, cfg) + if err != nil { + return errors.Trace(err) + } + if logInfo.clusterID > 0 && fullClusterID > 0 && logInfo.clusterID != fullClusterID { + return errors.Annotatef(berrors.ErrInvalidArgument, + "the full snapshot(from cluster ID:%v) and log(from cluster ID:%v) come from different cluster.", + fullClusterID, logInfo.clusterID) + } + + cfg.StartTS = startTS + if cfg.StartTS < logInfo.logMinTS { + return errors.Annotatef(berrors.ErrInvalidArgument, + "it has gap between full backup ts:%d(%s) and log backup ts:%d(%s). ", + cfg.StartTS, oracle.GetTimeFromTS(cfg.StartTS), + logInfo.logMinTS, oracle.GetTimeFromTS(logInfo.logMinTS)) + } + } + + log.Info("start restore on point", + zap.Uint64("restore-from", cfg.StartTS), zap.Uint64("restore-to", cfg.RestoreTS), + zap.Uint64("log-min-ts", logInfo.logMinTS), zap.Uint64("log-max-ts", logInfo.logMaxTS)) + if err := checkLogRange(cfg.StartTS, cfg.RestoreTS, logInfo.logMinTS, logInfo.logMaxTS); err != nil { + return errors.Trace(err) + } + + checkInfo, err := checkPiTRTaskInfo(ctx, g, s, cfg) + if err != nil { + return errors.Trace(err) + } + + failpoint.Inject("failed-before-full-restore", func(_ failpoint.Value) { + failpoint.Return(errors.New("failpoint: failed before full restore")) + }) + + recorder := tiflashrec.New() + cfg.tiflashRecorder = recorder + // restore full snapshot. + if checkInfo.NeedFullRestore { + logStorage := cfg.Config.Storage + cfg.Config.Storage = cfg.FullBackupStorage + // TiFlash replica is restored to down-stream on 'pitr' currently. + if err = runRestore(ctx, g, FullRestoreCmd, cfg, checkInfo); err != nil { + return errors.Trace(err) + } + cfg.Config.Storage = logStorage + } else if len(cfg.FullBackupStorage) > 0 { + skipMsg := []byte(fmt.Sprintf("%s command is skipped due to checkpoint mode for restore\n", FullRestoreCmd)) + if _, err := glue.GetConsole(g).Out().Write(skipMsg); err != nil { + return errors.Trace(err) + } + if checkInfo.CheckpointInfo != nil && checkInfo.CheckpointInfo.TiFlashItems != nil { + log.Info("load tiflash records of snapshot restore from checkpoint") + if err != nil { + return errors.Trace(err) + } + cfg.tiflashRecorder.Load(checkInfo.CheckpointInfo.TiFlashItems) + } + } + // restore log. 
+ cfg.adjustRestoreConfigForStreamRestore() + if err := restoreStream(ctx, g, cfg, checkInfo.CheckpointInfo); err != nil { + return errors.Trace(err) + } + return nil +} + +// RunStreamRestore start restore job +func restoreStream( + c context.Context, + g glue.Glue, + cfg *RestoreConfig, + taskInfo *checkpoint.CheckpointTaskInfoForLogRestore, +) (err error) { + var ( + totalKVCount uint64 + totalSize uint64 + checkpointTotalKVCount uint64 + checkpointTotalSize uint64 + currentTS uint64 + mu sync.Mutex + startTime = time.Now() + ) + defer func() { + if err != nil { + summary.Log("restore log failed summary", zap.Error(err)) + } else { + totalDureTime := time.Since(startTime) + summary.Log("restore log success summary", zap.Duration("total-take", totalDureTime), + zap.Uint64("source-start-point", cfg.StartTS), + zap.Uint64("source-end-point", cfg.RestoreTS), + zap.Uint64("target-end-point", currentTS), + zap.String("source-start", stream.FormatDate(oracle.GetTimeFromTS(cfg.StartTS))), + zap.String("source-end", stream.FormatDate(oracle.GetTimeFromTS(cfg.RestoreTS))), + zap.String("target-end", stream.FormatDate(oracle.GetTimeFromTS(currentTS))), + zap.Uint64("total-kv-count", totalKVCount), + zap.Uint64("skipped-kv-count-by-checkpoint", checkpointTotalKVCount), + zap.String("total-size", units.HumanSize(float64(totalSize))), + zap.String("skipped-size-by-checkpoint", units.HumanSize(float64(checkpointTotalSize))), + zap.String("average-speed", units.HumanSize(float64(totalSize)/totalDureTime.Seconds())+"/s"), + ) + } + }() + + ctx, cancelFn := context.WithCancel(c) + defer cancelFn() + + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + span1 := span.Tracer().StartSpan( + "restoreStream", + opentracing.ChildOf(span.Context()), + ) + defer span1.Finish() + ctx = opentracing.ContextWithSpan(ctx, span1) + } + + mgr, err := NewMgr(ctx, g, cfg.PD, cfg.TLS, GetKeepalive(&cfg.Config), + cfg.CheckRequirements, true, conn.StreamVersionChecker) + if err != nil { + return errors.Trace(err) + } + defer mgr.Close() + + client, err := createRestoreClient(ctx, g, cfg, mgr) + if err != nil { + return errors.Annotate(err, "failed to create restore client") + } + defer client.Close() + + if taskInfo != nil && taskInfo.RewriteTS > 0 { + // reuse the task's rewrite ts + log.Info("reuse the task's rewrite ts", zap.Uint64("rewrite-ts", taskInfo.RewriteTS)) + currentTS = taskInfo.RewriteTS + } else { + currentTS, err = restore.GetTSWithRetry(ctx, mgr.GetPDClient()) + if err != nil { + return errors.Trace(err) + } + } + client.SetCurrentTS(currentTS) + + importModeSwitcher := restore.NewImportModeSwitcher(mgr.GetPDClient(), cfg.Config.SwitchModeInterval, mgr.GetTLSConfig()) + restoreSchedulers, _, err := restore.RestorePreWork(ctx, mgr, importModeSwitcher, cfg.Online, false) + if err != nil { + return errors.Trace(err) + } + // Always run the post-work even on error, so we don't stuck in the import + // mode or emptied schedulers + defer restore.RestorePostWork(ctx, importModeSwitcher, restoreSchedulers, cfg.Online) + + // It need disable GC in TiKV when PiTR. + // because the process of PITR is concurrent and kv events isn't sorted by tso. 
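+	// KeepGcDisabled sets gc.ratio-threshold to "-1.0" and returns a closure
+	// that restores the previous ratio (or the default "1.1" if the old value
+	// was already negative) once the restore has finished.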
+ restoreGc, oldRatio, err := KeepGcDisabled(g, mgr.GetStorage()) + if err != nil { + return errors.Trace(err) + } + gcDisabledRestorable := false + defer func() { + // don't restore the gc-ratio-threshold if checkpoint mode is used and restored is not finished + if cfg.UseCheckpoint && !gcDisabledRestorable { + log.Info("skip restore the gc-ratio-threshold for next retry") + return + } + + log.Info("start to restore gc", zap.String("ratio", oldRatio)) + if err := restoreGc(oldRatio); err != nil { + log.Error("failed to set gc enabled", zap.Error(err)) + } + log.Info("finish restoring gc") + }() + + var taskName string + var checkpointRunner *checkpoint.CheckpointRunner[checkpoint.LogRestoreKeyType, checkpoint.LogRestoreValueType] + if cfg.UseCheckpoint { + taskName = cfg.generateLogRestoreTaskName(client.GetClusterID(ctx), cfg.StartTS, cfg.RestoreTS) + oldRatioFromCheckpoint, err := client.InitCheckpointMetadataForLogRestore(ctx, taskName, oldRatio) + if err != nil { + return errors.Trace(err) + } + oldRatio = oldRatioFromCheckpoint + + checkpointRunner, err = client.StartCheckpointRunnerForLogRestore(ctx, taskName) + if err != nil { + return errors.Trace(err) + } + defer func() { + log.Info("wait for flush checkpoint...") + checkpointRunner.WaitForFinish(ctx, !gcDisabledRestorable) + }() + } + + err = client.InstallLogFileManager(ctx, cfg.StartTS, cfg.RestoreTS, cfg.MetadataDownloadBatchSize) + if err != nil { + return err + } + + // get full backup meta storage to generate rewrite rules. + fullBackupStorage, err := parseFullBackupTablesStorage(cfg) + if err != nil { + return errors.Trace(err) + } + // load the id maps only when the checkpoint mode is used and not the first execution + newTask := true + if taskInfo != nil && taskInfo.Progress == checkpoint.InLogRestoreAndIdMapPersist { + newTask = false + } + // get the schemas ID replace information. + schemasReplace, err := client.InitSchemasReplaceForDDL(ctx, &logclient.InitSchemaConfig{ + IsNewTask: newTask, + TableFilter: cfg.TableFilter, + TiFlashRecorder: cfg.tiflashRecorder, + FullBackupStorage: fullBackupStorage, + }) + if err != nil { + return errors.Trace(err) + } + schemasReplace.AfterTableRewritten = func(deleted bool, tableInfo *model.TableInfo) { + // When the table replica changed to 0, the tiflash replica might be set to `nil`. + // We should remove the table if we meet. + if deleted || tableInfo.TiFlashReplica == nil { + cfg.tiflashRecorder.DelTable(tableInfo.ID) + return + } + cfg.tiflashRecorder.AddTable(tableInfo.ID, *tableInfo.TiFlashReplica) + // Remove the replica firstly. Let's restore them at the end. 
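+		// The recorded replicas are re-created after the log restore through the
+		// ALTER TABLE statements generated by tiflashRecorder.GenerateAlterTableDDLs.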
+ tableInfo.TiFlashReplica = nil + } + + updateStats := func(kvCount uint64, size uint64) { + mu.Lock() + defer mu.Unlock() + totalKVCount += kvCount + totalSize += size + } + dataFileCount := 0 + ddlFiles, err := client.LoadDDLFilesAndCountDMLFiles(ctx, &dataFileCount) + if err != nil { + return err + } + pm := g.StartProgress(ctx, "Restore Meta Files", int64(len(ddlFiles)), !cfg.LogProgress) + if err = withProgress(pm, func(p glue.Progress) error { + client.RunGCRowsLoader(ctx) + return client.RestoreMetaKVFiles(ctx, ddlFiles, schemasReplace, updateStats, p.Inc) + }); err != nil { + return errors.Annotate(err, "failed to restore meta files") + } + + rewriteRules := initRewriteRules(schemasReplace) + + ingestRecorder := schemasReplace.GetIngestRecorder() + if err := rangeFilterFromIngestRecorder(ingestRecorder, rewriteRules); err != nil { + return errors.Trace(err) + } + + // generate the upstream->downstream id maps for checkpoint + idrules := make(map[int64]int64) + downstreamIdset := make(map[int64]struct{}) + for upstreamId, rule := range rewriteRules { + downstreamId := restoreutils.GetRewriteTableID(upstreamId, rule) + idrules[upstreamId] = downstreamId + downstreamIdset[downstreamId] = struct{}{} + } + + logFilesIter, err := client.LoadDMLFiles(ctx) + if err != nil { + return errors.Trace(err) + } + pd := g.StartProgress(ctx, "Restore KV Files", int64(dataFileCount), !cfg.LogProgress) + err = withProgress(pd, func(p glue.Progress) error { + if cfg.UseCheckpoint { + updateStatsWithCheckpoint := func(kvCount, size uint64) { + mu.Lock() + defer mu.Unlock() + totalKVCount += kvCount + totalSize += size + checkpointTotalKVCount += kvCount + checkpointTotalSize += size + } + logFilesIter, err = client.WrapLogFilesIterWithCheckpoint(ctx, logFilesIter, downstreamIdset, taskName, updateStatsWithCheckpoint, p.Inc) + if err != nil { + return errors.Trace(err) + } + } + logFilesIterWithSplit, err := client.WrapLogFilesIterWithSplitHelper(logFilesIter, rewriteRules, g, mgr.GetStorage()) + if err != nil { + return errors.Trace(err) + } + + return client.RestoreKVFiles(ctx, rewriteRules, idrules, logFilesIterWithSplit, checkpointRunner, cfg.PitrBatchCount, cfg.PitrBatchSize, updateStats, p.IncBy) + }) + if err != nil { + return errors.Annotate(err, "failed to restore kv files") + } + + if err = client.CleanUpKVFiles(ctx); err != nil { + return errors.Annotate(err, "failed to clean up") + } + + if err = client.InsertGCRows(ctx); err != nil { + return errors.Annotate(err, "failed to insert rows into gc_delete_range") + } + + if err = client.RepairIngestIndex(ctx, ingestRecorder, g, taskName); err != nil { + return errors.Annotate(err, "failed to repair ingest index") + } + + if cfg.tiflashRecorder != nil { + sqls := cfg.tiflashRecorder.GenerateAlterTableDDLs(mgr.GetDomain().InfoSchema()) + log.Info("Generating SQLs for restoring TiFlash Replica", + zap.Strings("sqls", sqls)) + err = g.UseOneShotSession(mgr.GetStorage(), false, func(se glue.Session) error { + for _, sql := range sqls { + if errExec := se.ExecuteInternal(ctx, sql); errExec != nil { + logutil.WarnTerm("Failed to restore tiflash replica config, you may execute the sql restore it manually.", + logutil.ShortError(errExec), + zap.String("sql", sql), + ) + } + } + return nil + }) + if err != nil { + return err + } + } + + failpoint.Inject("do-checksum-with-rewrite-rules", func(_ failpoint.Value) { + if err := client.FailpointDoChecksumForLogRestore(ctx, mgr.GetStorage().GetClient(), mgr.GetPDClient(), idrules, rewriteRules); err != nil { + 
failpoint.Return(errors.Annotate(err, "failed to do checksum")) + } + }) + + gcDisabledRestorable = true + + return nil +} + +func createRestoreClient(ctx context.Context, g glue.Glue, cfg *RestoreConfig, mgr *conn.Mgr) (*logclient.LogClient, error) { + var err error + keepaliveCfg := GetKeepalive(&cfg.Config) + keepaliveCfg.PermitWithoutStream = true + client := logclient.NewRestoreClient(mgr.GetPDClient(), mgr.GetPDHTTPClient(), mgr.GetTLSConfig(), keepaliveCfg) + err = client.Init(g, mgr.GetStorage()) + if err != nil { + return nil, errors.Trace(err) + } + defer func() { + if err != nil { + client.Close() + } + }() + + u, err := storage.ParseBackend(cfg.Storage, &cfg.BackendOptions) + if err != nil { + return nil, errors.Trace(err) + } + + opts := getExternalStorageOptions(&cfg.Config, u) + if err = client.SetStorage(ctx, u, &opts); err != nil { + return nil, errors.Trace(err) + } + client.SetCrypter(&cfg.CipherInfo) + client.SetConcurrency(uint(cfg.Concurrency)) + client.InitClients(ctx, u) + + err = client.SetRawKVBatchClient(ctx, cfg.PD, cfg.TLS.ToKVSecurity()) + if err != nil { + return nil, errors.Trace(err) + } + + return client, nil +} + +// rangeFilterFromIngestRecorder rewrites the table id of items in the ingestRecorder +// TODO: need to implement the range filter out feature +func rangeFilterFromIngestRecorder(recorder *ingestrec.IngestRecorder, rewriteRules map[int64]*restoreutils.RewriteRules) error { + err := recorder.RewriteTableID(func(tableID int64) (int64, bool, error) { + rewriteRule, exists := rewriteRules[tableID] + if !exists { + // since the table's files will be skipped restoring, here also skips. + return 0, true, nil + } + newTableID := restoreutils.GetRewriteTableID(tableID, rewriteRule) + if newTableID == 0 { + return 0, false, errors.Errorf("newTableID is 0, tableID: %d", tableID) + } + return newTableID, false, nil + }) + return errors.Trace(err) +} + +func getExternalStorageOptions(cfg *Config, u *backuppb.StorageBackend) storage.ExternalStorageOptions { + var httpClient *http.Client + if u.GetGcs() == nil { + httpClient = storage.GetDefaultHttpClient(cfg.MetadataDownloadBatchSize) + } + return storage.ExternalStorageOptions{ + NoCredentials: cfg.NoCreds, + SendCredentials: cfg.SendCreds, + HTTPClient: httpClient, + } +} + +func checkLogRange(restoreFrom, restoreTo, logMinTS, logMaxTS uint64) error { + // serveral ts constraint: + // logMinTS <= restoreFrom <= restoreTo <= logMaxTS + if logMinTS > restoreFrom || restoreFrom > restoreTo || restoreTo > logMaxTS { + return errors.Annotatef(berrors.ErrInvalidArgument, + "restore log from %d(%s) to %d(%s), "+ + " but the current existed log from %d(%s) to %d(%s)", + restoreFrom, oracle.GetTimeFromTS(restoreFrom), + restoreTo, oracle.GetTimeFromTS(restoreTo), + logMinTS, oracle.GetTimeFromTS(logMinTS), + logMaxTS, oracle.GetTimeFromTS(logMaxTS), + ) + } + return nil +} + +// withProgress execute some logic with the progress, and close it once the execution done. +func withProgress(p glue.Progress, cc func(p glue.Progress) error) error { + defer p.Close() + return cc(p) +} + +type backupLogInfo struct { + logMaxTS uint64 + logMinTS uint64 + clusterID uint64 +} + +// getLogRange gets the log-min-ts and log-max-ts of starting log backup. 
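+// logMinTS is max(the backupmeta start version, the truncate safe point) and
+// logMaxTS is max(logMinTS, the largest global checkpoint recorded in storage):
+//
+//	logMinTS = max(startVersion, truncateTS)
+//	logMaxTS = max(logMinTS, globalCheckpointTS)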
+func getLogRange( + ctx context.Context, + cfg *Config, +) (backupLogInfo, error) { + _, s, err := GetStorage(ctx, cfg.Storage, cfg) + if err != nil { + return backupLogInfo{}, errors.Trace(err) + } + return getLogRangeWithStorage(ctx, s) +} + +func getLogRangeWithStorage( + ctx context.Context, + s storage.ExternalStorage, +) (backupLogInfo, error) { + // logStartTS: Get log start ts from backupmeta file. + metaData, err := s.ReadFile(ctx, metautil.MetaFile) + if err != nil { + return backupLogInfo{}, errors.Trace(err) + } + backupMeta := &backuppb.BackupMeta{} + if err = backupMeta.Unmarshal(metaData); err != nil { + return backupLogInfo{}, errors.Trace(err) + } + // endVersion > 0 represents that the storage has been used for `br backup` + if backupMeta.GetEndVersion() > 0 { + return backupLogInfo{}, errors.Annotate(berrors.ErrStorageUnknown, + "the storage has been used for full backup") + } + logStartTS := backupMeta.GetStartVersion() + + // truncateTS: get log truncate ts from TruncateSafePointFileName. + // If truncateTS equals 0, which represents the stream log has never been truncated. + truncateTS, err := stream.GetTSFromFile(ctx, s, stream.TruncateSafePointFileName) + if err != nil { + return backupLogInfo{}, errors.Trace(err) + } + logMinTS := max(logStartTS, truncateTS) + + // get max global resolved ts from metas. + logMaxTS, err := getGlobalCheckpointFromStorage(ctx, s) + if err != nil { + return backupLogInfo{}, errors.Trace(err) + } + logMaxTS = max(logMinTS, logMaxTS) + + return backupLogInfo{ + logMaxTS: logMaxTS, + logMinTS: logMinTS, + clusterID: backupMeta.ClusterId, + }, nil +} + +func getGlobalCheckpointFromStorage(ctx context.Context, s storage.ExternalStorage) (uint64, error) { + var globalCheckPointTS uint64 = 0 + opt := storage.WalkOption{SubDir: stream.GetStreamBackupGlobalCheckpointPrefix()} + err := s.WalkDir(ctx, &opt, func(path string, size int64) error { + if !strings.HasSuffix(path, ".ts") { + return nil + } + + buff, err := s.ReadFile(ctx, path) + if err != nil { + return errors.Trace(err) + } + ts := binary.LittleEndian.Uint64(buff) + globalCheckPointTS = max(ts, globalCheckPointTS) + return nil + }) + return globalCheckPointTS, errors.Trace(err) +} + +// getFullBackupTS gets the snapshot-ts of full bakcup +func getFullBackupTS( + ctx context.Context, + cfg *RestoreConfig, +) (uint64, uint64, error) { + _, s, err := GetStorage(ctx, cfg.FullBackupStorage, &cfg.Config) + if err != nil { + return 0, 0, errors.Trace(err) + } + + metaData, err := s.ReadFile(ctx, metautil.MetaFile) + if err != nil { + return 0, 0, errors.Trace(err) + } + + backupmeta := &backuppb.BackupMeta{} + if err = backupmeta.Unmarshal(metaData); err != nil { + return 0, 0, errors.Trace(err) + } + + return backupmeta.GetEndVersion(), backupmeta.GetClusterId(), nil +} + +func parseFullBackupTablesStorage( + cfg *RestoreConfig, +) (*logclient.FullBackupStorageConfig, error) { + if len(cfg.FullBackupStorage) == 0 { + log.Info("the full backup path is not specified, so BR will try to get id maps") + return nil, nil + } + u, err := storage.ParseBackend(cfg.FullBackupStorage, &cfg.BackendOptions) + if err != nil { + return nil, errors.Trace(err) + } + return &logclient.FullBackupStorageConfig{ + Backend: u, + Opts: storageOpts(&cfg.Config), + }, nil +} + +func initRewriteRules(schemasReplace *stream.SchemasReplace) map[int64]*restoreutils.RewriteRules { + rules := make(map[int64]*restoreutils.RewriteRules) + filter := schemasReplace.TableFilter + + for _, dbReplace := range 
schemasReplace.DbMap { + if utils.IsSysDB(dbReplace.Name) || !filter.MatchSchema(dbReplace.Name) { + continue + } + + for oldTableID, tableReplace := range dbReplace.TableMap { + if !filter.MatchTable(dbReplace.Name, tableReplace.Name) { + continue + } + + if _, exist := rules[oldTableID]; !exist { + log.Info("add rewrite rule", + zap.String("tableName", dbReplace.Name+"."+tableReplace.Name), + zap.Int64("oldID", oldTableID), zap.Int64("newID", tableReplace.TableID)) + rules[oldTableID] = restoreutils.GetRewriteRuleOfTable( + oldTableID, tableReplace.TableID, 0, tableReplace.IndexMap, false) + } + + for oldID, newID := range tableReplace.PartitionMap { + if _, exist := rules[oldID]; !exist { + log.Info("add rewrite rule", + zap.String("tableName", dbReplace.Name+"."+tableReplace.Name), + zap.Int64("oldID", oldID), zap.Int64("newID", newID)) + rules[oldID] = restoreutils.GetRewriteRuleOfTable(oldID, newID, 0, tableReplace.IndexMap, false) + } + } + } + } + return rules +} + +// ShiftTS gets a smaller shiftTS than startTS. +// It has a safe duration between shiftTS and startTS for trasaction. +func ShiftTS(startTS uint64) uint64 { + physical := oracle.ExtractPhysical(startTS) + logical := oracle.ExtractLogical(startTS) + + shiftPhysical := physical - streamShiftDuration.Milliseconds() + if shiftPhysical < 0 { + return 0 + } + return oracle.ComposeTS(shiftPhysical, logical) +} + +func buildPauseSafePointName(taskName string) string { + return fmt.Sprintf("%s_pause_safepoint", taskName) +} + +func checkPiTRRequirements(mgr *conn.Mgr) error { + return restore.AssertUserDBsEmpty(mgr.GetDomain()) +} + +type PiTRTaskInfo struct { + CheckpointInfo *checkpoint.CheckpointTaskInfoForLogRestore + NeedFullRestore bool + FullRestoreCheckErr error +} + +func checkPiTRTaskInfo( + ctx context.Context, + g glue.Glue, + s storage.ExternalStorage, + cfg *RestoreConfig, +) (*PiTRTaskInfo, error) { + var ( + doFullRestore = (len(cfg.FullBackupStorage) > 0) + curTaskInfo *checkpoint.CheckpointTaskInfoForLogRestore + errTaskMsg string + ) + checkInfo := &PiTRTaskInfo{} + + mgr, err := NewMgr(ctx, g, cfg.PD, cfg.TLS, GetKeepalive(&cfg.Config), + cfg.CheckRequirements, true, conn.StreamVersionChecker) + if err != nil { + return checkInfo, errors.Trace(err) + } + defer mgr.Close() + + clusterID := mgr.GetPDClient().GetClusterID(ctx) + if cfg.UseCheckpoint { + exists, err := checkpoint.ExistsCheckpointTaskInfo(ctx, s, clusterID) + if err != nil { + return checkInfo, errors.Trace(err) + } + if exists { + curTaskInfo, err = checkpoint.LoadCheckpointTaskInfoForLogRestore(ctx, s, clusterID) + if err != nil { + return checkInfo, errors.Trace(err) + } + // TODO: check whether user has manually modified the cluster(ddl). If so, regard the behavior + // as restore from scratch. (update `curTaskInfo.RewriteTs` to 0 as an uninitial value) + + // The task info is written to external storage without status `InSnapshotRestore` only when + // id-maps is persist into external storage, so there is no need to do snapshot restore again. 
+ if curTaskInfo.StartTS == cfg.StartTS && curTaskInfo.RestoreTS == cfg.RestoreTS { + // the same task, check whether skip snapshot restore + doFullRestore = doFullRestore && (curTaskInfo.Progress == checkpoint.InSnapshotRestore) + // update the snapshot restore task name to clean up in final + if !doFullRestore && (len(cfg.FullBackupStorage) > 0) { + _ = cfg.generateSnapshotRestoreTaskName(clusterID) + } + log.Info("the same task", zap.Bool("skip-snapshot-restore", !doFullRestore)) + } else { + // not the same task, so overwrite the taskInfo with a new task + log.Info("not the same task, start to restore from scratch") + errTaskMsg = fmt.Sprintf( + "a new task [start-ts=%d] [restored-ts=%d] while the last task info: [start-ts=%d] [restored-ts=%d] [skip-snapshot-restore=%t]", + cfg.StartTS, cfg.RestoreTS, curTaskInfo.StartTS, curTaskInfo.RestoreTS, curTaskInfo.Progress == checkpoint.InLogRestoreAndIdMapPersist) + + curTaskInfo = nil + } + } + } + checkInfo.CheckpointInfo = curTaskInfo + checkInfo.NeedFullRestore = doFullRestore + // restore full snapshot precheck. + if doFullRestore { + if !(cfg.UseCheckpoint && curTaskInfo != nil) { + // Only when use checkpoint and not the first execution, + // skip checking requirements. + log.Info("check pitr requirements for the first execution") + if err := checkPiTRRequirements(mgr); err != nil { + if len(errTaskMsg) > 0 { + err = errors.Annotatef(err, "The current restore task is regarded as %s. "+ + "If you ensure that no changes have been made to the cluster since the last execution, "+ + "you can adjust the `start-ts` or `restored-ts` to continue with the previous execution. "+ + "Otherwise, if you want to restore from scratch, please clean the cluster at first", errTaskMsg) + } + // delay cluster checks after we get the backupmeta. + // for the case that the restore inc + log backup, + // we can still restore them. + checkInfo.FullRestoreCheckErr = err + return checkInfo, nil + } + } + } + + // persist the new task info + if cfg.UseCheckpoint && curTaskInfo == nil { + log.Info("save checkpoint task info with `InSnapshotRestore` status") + if err := checkpoint.SaveCheckpointTaskInfoForLogRestore(ctx, s, &checkpoint.CheckpointTaskInfoForLogRestore{ + Progress: checkpoint.InSnapshotRestore, + StartTS: cfg.StartTS, + RestoreTS: cfg.RestoreTS, + // updated in the stage of `InLogRestoreAndIdMapPersist` + RewriteTS: 0, + TiFlashItems: nil, + }, clusterID); err != nil { + return checkInfo, errors.Trace(err) + } + } + return checkInfo, nil +} diff --git a/br/pkg/utils/backoff.go b/br/pkg/utils/backoff.go index 385ed4319a06a..e8093d86cfa35 100644 --- a/br/pkg/utils/backoff.go +++ b/br/pkg/utils/backoff.go @@ -268,9 +268,9 @@ func (bo *pdReqBackoffer) NextBackoff(err error) time.Duration { } } - failpoint.Inject("set-attempt-to-one", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("set-attempt-to-one")); _err_ == nil { bo.attempt = 1 - }) + } if bo.delayTime > bo.maxDelayTime { return bo.maxDelayTime } diff --git a/br/pkg/utils/backoff.go__failpoint_stash__ b/br/pkg/utils/backoff.go__failpoint_stash__ new file mode 100644 index 0000000000000..385ed4319a06a --- /dev/null +++ b/br/pkg/utils/backoff.go__failpoint_stash__ @@ -0,0 +1,323 @@ +// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. 
+ +package utils + +import ( + "context" + "database/sql" + "io" + "math" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "go.uber.org/multierr" + "go.uber.org/zap" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +const ( + // importSSTRetryTimes specifies the retry time. Its longest time is about 90s-100s. + importSSTRetryTimes = 16 + importSSTWaitInterval = 40 * time.Millisecond + importSSTMaxWaitInterval = 10 * time.Second + + downloadSSTRetryTimes = 8 + downloadSSTWaitInterval = 1 * time.Second + downloadSSTMaxWaitInterval = 4 * time.Second + + backupSSTRetryTimes = 5 + backupSSTWaitInterval = 2 * time.Second + backupSSTMaxWaitInterval = 3 * time.Second + + resetTSRetryTime = 32 + resetTSWaitInterval = 50 * time.Millisecond + resetTSMaxWaitInterval = 2 * time.Second + + resetTSRetryTimeExt = 600 + resetTSWaitIntervalExt = 500 * time.Millisecond + resetTSMaxWaitIntervalExt = 300 * time.Second + + // region heartbeat are 10 seconds by default, if some region has 2 heartbeat missing (15 seconds), it appear to be a network issue between PD and TiKV. + FlashbackRetryTime = 3 + FlashbackWaitInterval = 3 * time.Second + FlashbackMaxWaitInterval = 15 * time.Second + + ChecksumRetryTime = 8 + ChecksumWaitInterval = 1 * time.Second + ChecksumMaxWaitInterval = 30 * time.Second + + gRPC_Cancel = "the client connection is closing" +) + +// At least, there are two possible cancel() call, +// one from go context, another from gRPC, here we retry when gRPC cancel with connection closing +func isGRPCCancel(err error) bool { + if s, ok := status.FromError(err); ok { + if strings.Contains(s.Message(), gRPC_Cancel) { + return true + } + } + return false +} + +// ConstantBackoff is a backoffer that retry forever until success. +type ConstantBackoff time.Duration + +// NextBackoff returns a duration to wait before retrying again +func (c ConstantBackoff) NextBackoff(err error) time.Duration { + return time.Duration(c) +} + +// Attempt returns the remain attempt times +func (c ConstantBackoff) Attempt() int { + // A large enough value. Also still safe for arithmetic operations (won't easily overflow). + return math.MaxInt16 +} + +// RetryState is the mutable state needed for retrying. +// It likes the `utils.Backoffer`, but more fundamental: +// this only control the backoff time and knows nothing about what error happens. +// NOTE: Maybe also implement the backoffer via this. +type RetryState struct { + maxRetry int + retryTimes int + + maxBackoff time.Duration + nextBackoff time.Duration +} + +// InitialRetryState make the initial state for retrying. +func InitialRetryState(maxRetryTimes int, initialBackoff, maxBackoff time.Duration) RetryState { + return RetryState{ + maxRetry: maxRetryTimes, + maxBackoff: maxBackoff, + nextBackoff: initialBackoff, + } +} + +// Whether in the current state we can retry. +func (rs *RetryState) ShouldRetry() bool { + return rs.retryTimes < rs.maxRetry +} + +// Get the exponential backoff durion and transform the state. +func (rs *RetryState) ExponentialBackoff() time.Duration { + rs.retryTimes++ + backoff := rs.nextBackoff + rs.nextBackoff *= 2 + if rs.nextBackoff > rs.maxBackoff { + rs.nextBackoff = rs.maxBackoff + } + return backoff +} + +func (rs *RetryState) GiveUp() { + rs.retryTimes = rs.maxRetry +} + +// ReduceRetry reduces retry times for 1. 
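
RetryState above is a capped exponential backoff: each ExponentialBackoff call doubles the wait up to maxBackoff and burns one attempt. A minimal usage sketch, assuming it sits in the same utils package (so context and time come from the surrounding imports) and that op is some caller-supplied fallible operation:

// Sketch only: retry op() with a capped exponential backoff, assuming this
// lives in the same package as RetryState.
func retryWithState(ctx context.Context, op func() error) error {
    // 5 attempts, starting at 100ms and capped at 2s between attempts.
    rs := InitialRetryState(5, 100*time.Millisecond, 2*time.Second)
    var lastErr error
    for rs.ShouldRetry() {
        if lastErr = op(); lastErr == nil {
            return nil
        }
        select {
        case <-ctx.Done():
            return ctx.Err()
        case <-time.After(rs.ExponentialBackoff()):
        }
    }
    return lastErr
}
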
+func (rs *RetryState) ReduceRetry() { + rs.retryTimes-- +} + +// Attempt implements the `Backoffer`. +// TODO: Maybe use this to replace the `exponentialBackoffer` (which is nearly homomorphic to this)? +func (rs *RetryState) Attempt() int { + return rs.maxRetry - rs.retryTimes +} + +// NextBackoff implements the `Backoffer`. +func (rs *RetryState) NextBackoff(error) time.Duration { + return rs.ExponentialBackoff() +} + +type importerBackoffer struct { + attempt int + delayTime time.Duration + maxDelayTime time.Duration + errContext *ErrorContext +} + +// NewBackoffer creates a new controller regulating a truncated exponential backoff. +func NewBackoffer(attempt int, delayTime, maxDelayTime time.Duration, errContext *ErrorContext) Backoffer { + return &importerBackoffer{ + attempt: attempt, + delayTime: delayTime, + maxDelayTime: maxDelayTime, + errContext: errContext, + } +} + +func NewImportSSTBackoffer() Backoffer { + errContext := NewErrorContext("import sst", 3) + return NewBackoffer(importSSTRetryTimes, importSSTWaitInterval, importSSTMaxWaitInterval, errContext) +} + +func NewDownloadSSTBackoffer() Backoffer { + errContext := NewErrorContext("download sst", 3) + return NewBackoffer(downloadSSTRetryTimes, downloadSSTWaitInterval, downloadSSTMaxWaitInterval, errContext) +} + +func NewBackupSSTBackoffer() Backoffer { + errContext := NewErrorContext("backup sst", 3) + return NewBackoffer(backupSSTRetryTimes, backupSSTWaitInterval, backupSSTMaxWaitInterval, errContext) +} + +func (bo *importerBackoffer) NextBackoff(err error) time.Duration { + // we don't care storeID here. + errs := multierr.Errors(err) + lastErr := errs[len(errs)-1] + res := HandleUnknownBackupError(lastErr.Error(), 0, bo.errContext) + if res.Strategy == StrategyRetry { + bo.delayTime = 2 * bo.delayTime + bo.attempt-- + } else { + e := errors.Cause(lastErr) + switch e { // nolint:errorlint + case berrors.ErrKVEpochNotMatch, berrors.ErrKVDownloadFailed, berrors.ErrKVIngestFailed, berrors.ErrPDLeaderNotFound: + bo.delayTime = 2 * bo.delayTime + bo.attempt-- + case berrors.ErrKVRangeIsEmpty, berrors.ErrKVRewriteRuleNotFound: + // Expected error, finish the operation + bo.delayTime = 0 + bo.attempt = 0 + default: + switch status.Code(e) { + case codes.Unavailable, codes.Aborted, codes.DeadlineExceeded, codes.ResourceExhausted, codes.Internal: + bo.delayTime = 2 * bo.delayTime + bo.attempt-- + case codes.Canceled: + if isGRPCCancel(lastErr) { + bo.delayTime = 2 * bo.delayTime + bo.attempt-- + } else { + bo.delayTime = 0 + bo.attempt = 0 + } + default: + // Unexpected error + bo.delayTime = 0 + bo.attempt = 0 + log.Warn("unexpected error, stop retrying", zap.Error(err)) + } + } + } + if bo.delayTime > bo.maxDelayTime { + return bo.maxDelayTime + } + return bo.delayTime +} + +func (bo *importerBackoffer) Attempt() int { + return bo.attempt +} + +type pdReqBackoffer struct { + attempt int + delayTime time.Duration + maxDelayTime time.Duration +} + +func NewPDReqBackoffer() Backoffer { + return &pdReqBackoffer{ + attempt: resetTSRetryTime, + delayTime: resetTSWaitInterval, + maxDelayTime: resetTSMaxWaitInterval, + } +} + +func NewPDReqBackofferExt() Backoffer { + return &pdReqBackoffer{ + attempt: resetTSRetryTimeExt, + delayTime: resetTSWaitIntervalExt, + maxDelayTime: resetTSMaxWaitIntervalExt, + } +} + +func (bo *pdReqBackoffer) NextBackoff(err error) time.Duration { + // bo.delayTime = 2 * bo.delayTime + // bo.attempt-- + e := errors.Cause(err) + switch e { // nolint:errorlint + case nil, context.Canceled, 
context.DeadlineExceeded, sql.ErrNoRows: + // Excepted error, finish the operation + bo.delayTime = 0 + bo.attempt = 0 + case berrors.ErrRestoreTotalKVMismatch, io.EOF: + bo.delayTime = 2 * bo.delayTime + bo.attempt-- + default: + // If the connection timeout, pd client would cancel the context, and return grpc context cancel error. + // So make the codes.Canceled retryable too. + // It's OK to retry the grpc context cancel error, because the parent context cancel returns context.Canceled. + // For example, cancel the `ectx` and then pdClient.GetTS(ectx) returns context.Canceled instead of grpc context canceled. + switch status.Code(e) { + case codes.DeadlineExceeded, codes.Canceled, codes.NotFound, codes.AlreadyExists, codes.PermissionDenied, codes.ResourceExhausted, codes.Aborted, codes.OutOfRange, codes.Unavailable, codes.DataLoss, codes.Unknown: + bo.delayTime = 2 * bo.delayTime + bo.attempt-- + default: + // Unexcepted error + bo.delayTime = 0 + bo.attempt = 0 + log.Warn("unexcepted error, stop to retry", zap.Error(err)) + } + } + + failpoint.Inject("set-attempt-to-one", func(_ failpoint.Value) { + bo.attempt = 1 + }) + if bo.delayTime > bo.maxDelayTime { + return bo.maxDelayTime + } + return bo.delayTime +} + +func (bo *pdReqBackoffer) Attempt() int { + return bo.attempt +} + +type DiskCheckBackoffer struct { + attempt int + delayTime time.Duration + maxDelayTime time.Duration +} + +func NewDiskCheckBackoffer() Backoffer { + return &DiskCheckBackoffer{ + attempt: resetTSRetryTime, + delayTime: resetTSWaitInterval, + maxDelayTime: resetTSMaxWaitInterval, + } +} + +func (bo *DiskCheckBackoffer) NextBackoff(err error) time.Duration { + e := errors.Cause(err) + switch e { // nolint:errorlint + case nil, context.Canceled, context.DeadlineExceeded, berrors.ErrKVDiskFull: + bo.delayTime = 0 + bo.attempt = 0 + case berrors.ErrPDInvalidResponse: + bo.delayTime = 2 * bo.delayTime + bo.attempt-- + default: + bo.delayTime = 2 * bo.delayTime + if bo.attempt > 5 { + bo.attempt = 5 + } + bo.attempt-- + } + + if bo.delayTime > bo.maxDelayTime { + return bo.maxDelayTime + } + return bo.delayTime +} + +func (bo *DiskCheckBackoffer) Attempt() int { + return bo.attempt +} diff --git a/br/pkg/utils/binding__failpoint_binding__.go b/br/pkg/utils/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..0a24d8976f30c --- /dev/null +++ b/br/pkg/utils/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package utils + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/br/pkg/utils/pprof.go b/br/pkg/utils/pprof.go index c2e5ad63c8e5a..66e5c5e57d14f 100644 --- a/br/pkg/utils/pprof.go +++ b/br/pkg/utils/pprof.go @@ -33,11 +33,11 @@ func listen(statusAddr string) (net.Listener, error) { log.Warn("Try to start pprof when it has been started, nothing will happen", zap.String("address", startedPProf)) return nil, errors.Annotate(berrors.ErrUnknown, "try to start pprof when it has been started at "+startedPProf) } - failpoint.Inject("determined-pprof-port", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("determined-pprof-port")); _err_ == nil { port := v.(int) statusAddr = fmt.Sprintf(":%d", port) log.Info("injecting failpoint, pprof will start at determined port", zap.Int("port", port)) - }) 
+ } listener, err := net.Listen("tcp", statusAddr) if err != nil { log.Warn("failed to start pprof", zap.String("addr", statusAddr), zap.Error(err)) diff --git a/br/pkg/utils/pprof.go__failpoint_stash__ b/br/pkg/utils/pprof.go__failpoint_stash__ new file mode 100644 index 0000000000000..c2e5ad63c8e5a --- /dev/null +++ b/br/pkg/utils/pprof.go__failpoint_stash__ @@ -0,0 +1,69 @@ +// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. + +package utils + +import ( + "fmt" + "net" //nolint:goimports + // #nosec + // register HTTP handler for /debug/pprof + "net/http" + // For pprof + _ "net/http/pprof" // #nosec G108 + "os" + "sync" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + berrors "github.com/pingcap/tidb/br/pkg/errors" + tidbutils "github.com/pingcap/tidb/pkg/util" + "go.uber.org/zap" +) + +var ( + startedPProf = "" + mu sync.Mutex +) + +func listen(statusAddr string) (net.Listener, error) { + mu.Lock() + defer mu.Unlock() + if startedPProf != "" { + log.Warn("Try to start pprof when it has been started, nothing will happen", zap.String("address", startedPProf)) + return nil, errors.Annotate(berrors.ErrUnknown, "try to start pprof when it has been started at "+startedPProf) + } + failpoint.Inject("determined-pprof-port", func(v failpoint.Value) { + port := v.(int) + statusAddr = fmt.Sprintf(":%d", port) + log.Info("injecting failpoint, pprof will start at determined port", zap.Int("port", port)) + }) + listener, err := net.Listen("tcp", statusAddr) + if err != nil { + log.Warn("failed to start pprof", zap.String("addr", statusAddr), zap.Error(err)) + return nil, errors.Trace(err) + } + startedPProf = listener.Addr().String() + log.Info("bound pprof to addr", zap.String("addr", startedPProf)) + _, _ = fmt.Fprintf(os.Stderr, "bound pprof to addr %s\n", startedPProf) + return listener, nil +} + +// StartPProfListener forks a new goroutine listening on specified port and provide pprof info. 
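
The listen helper above refuses to start a second pprof endpoint and binds to whatever status address was configured; the failpoint only pins a deterministic port for tests. A stripped-down sketch of the same idea as a standalone program, letting the OS pick a free port with ":0":

package main

import (
    "fmt"
    "net"
    "net/http"
    _ "net/http/pprof" // registers /debug/pprof handlers on http.DefaultServeMux
)

func main() {
    // ":0" lets the OS pick a free port; the diff above uses a failpoint
    // to force a deterministic port in tests instead.
    ln, err := net.Listen("tcp", ":0")
    if err != nil {
        panic(err)
    }
    fmt.Println("pprof listening on", ln.Addr())
    // Passing a nil handler serves DefaultServeMux, where pprof registered.
    if err := http.Serve(ln, nil); err != nil {
        fmt.Println("pprof server stopped:", err)
    }
}
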
+func StartPProfListener(statusAddr string, wrapper *tidbutils.TLS) error { + listener, err := listen(statusAddr) + if err != nil { + return err + } + + go func() { + if e := http.Serve(wrapper.WrapListener(listener), nil); e != nil { + log.Warn("failed to serve pprof", zap.String("addr", startedPProf), zap.Error(e)) + mu.Lock() + startedPProf = "" + mu.Unlock() + return + } + }() + return nil +} diff --git a/br/pkg/utils/register.go b/br/pkg/utils/register.go index 95a102ae68d26..3aeb4a9b21343 100644 --- a/br/pkg/utils/register.go +++ b/br/pkg/utils/register.go @@ -195,17 +195,17 @@ func (tr *taskRegister) keepaliveLoop(ctx context.Context, ch <-chan *clientv3.L if timeLeftThreshold < minTimeLeftThreshold { timeLeftThreshold = minTimeLeftThreshold } - failpoint.Inject("brie-task-register-always-grant", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("brie-task-register-always-grant")); _err_ == nil { timeLeftThreshold = tr.ttl - }) + } for { CONSUMERESP: for { - failpoint.Inject("brie-task-register-keepalive-stop", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("brie-task-register-keepalive-stop")); _err_ == nil { if _, err = tr.client.Lease.Revoke(ctx, tr.curLeaseID); err != nil { log.Warn("brie-task-register-keepalive-stop", zap.Error(err)) } - }) + } select { case <-ctx.Done(): return @@ -223,9 +223,9 @@ func (tr *taskRegister) keepaliveLoop(ctx context.Context, ch <-chan *clientv3.L timeGap := time.Since(lastUpdateTime) if tr.ttl-timeGap <= timeLeftThreshold { lease, err := tr.grant(ctx) - failpoint.Inject("brie-task-register-failed-to-grant", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("brie-task-register-failed-to-grant")); _err_ == nil { err = errors.New("failpoint-error") - }) + } if err != nil { select { case <-ctx.Done(): @@ -243,9 +243,9 @@ func (tr *taskRegister) keepaliveLoop(ctx context.Context, ch <-chan *clientv3.L if needReputKV { // if the lease has expired, need to put the key again _, err := tr.client.KV.Put(ctx, tr.key, "", clientv3.WithLease(tr.curLeaseID)) - failpoint.Inject("brie-task-register-failed-to-reput", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("brie-task-register-failed-to-reput")); _err_ == nil { err = errors.New("failpoint-error") - }) + } if err != nil { select { case <-ctx.Done(): diff --git a/br/pkg/utils/register.go__failpoint_stash__ b/br/pkg/utils/register.go__failpoint_stash__ new file mode 100644 index 0000000000000..95a102ae68d26 --- /dev/null +++ b/br/pkg/utils/register.go__failpoint_stash__ @@ -0,0 +1,334 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
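
register.go (diffed above, original stashed below) keeps a task key alive in etcd by attaching it to a lease and refreshing that lease in the background. A compressed sketch of the same clientv3 lease pattern, assuming a local etcd endpoint at 127.0.0.1:2379 and a placeholder task name:

package main

import (
    "context"
    "log"
    "time"

    clientv3 "go.etcd.io/etcd/client/v3"
)

func main() {
    cli, err := clientv3.New(clientv3.Config{
        Endpoints:   []string{"127.0.0.1:2379"}, // assumed local etcd
        DialTimeout: 3 * time.Second,
    })
    if err != nil {
        log.Fatal(err)
    }
    defer cli.Close()

    ctx := context.Background()
    // Grant a 60s lease and attach the task key to it, mirroring RegisterTask.
    lease, err := cli.Grant(ctx, 60)
    if err != nil {
        log.Fatal(err)
    }
    if _, err := cli.Put(ctx, "/tidb/brie/import/restore/demo-task", "",
        clientv3.WithLease(lease.ID)); err != nil {
        log.Fatal(err)
    }
    // KeepAlive refreshes the lease periodically; if this channel closes,
    // register.go re-grants the lease and re-puts the key instead of exiting.
    ch, err := cli.KeepAlive(ctx, lease.ID)
    if err != nil {
        log.Fatal(err)
    }
    for resp := range ch {
        log.Printf("lease %x refreshed, ttl=%ds", resp.ID, resp.TTL)
    }
    // On shutdown the real code revokes the lease so the key disappears at once.
    _, _ = cli.Revoke(context.Background(), lease.ID)
}
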
+ +package utils + +import ( + "context" + "fmt" + "path" + "strings" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" +) + +// RegisterTaskType for the sub-prefix path for key +type RegisterTaskType int + +const ( + RegisterRestore RegisterTaskType = iota + RegisterLightning + RegisterImportInto +) + +func (tp RegisterTaskType) String() string { + switch tp { + case RegisterRestore: + return "restore" + case RegisterLightning: + return "lightning" + case RegisterImportInto: + return "import-into" + } + return "default" +} + +// The key format should be {RegisterImportTaskPrefix}/{RegisterTaskType}/{taskName} +const ( + // RegisterImportTaskPrefix is the prefix of the key for task register + // todo: remove "/import" suffix, it's confusing to have a key like "/tidb/brie/import/restore/restore-xxx" + RegisterImportTaskPrefix = "/tidb/brie/import" + + RegisterRetryInternal = 10 * time.Second + defaultTaskRegisterTTL = 3 * time.Minute // 3 minutes +) + +// TaskRegister can register the task to PD with a lease. +type TaskRegister interface { + // Close closes the background task if using RegisterTask + // and revoke the lease. + // NOTE: we don't close the etcd client here, call should do it. + Close(ctx context.Context) (err error) + // RegisterTask firstly put its key to PD with a lease, + // and start to keepalive the lease in the background. + // DO NOT mix calls to RegisterTask and RegisterTaskOnce. + RegisterTask(c context.Context) error + // RegisterTaskOnce put its key to PD with a lease if the key does not exist, + // else we refresh the lease. + // you have to call this method periodically to keep the lease alive. + // DO NOT mix calls to RegisterTask and RegisterTaskOnce. 
+ RegisterTaskOnce(ctx context.Context) error +} + +type taskRegister struct { + client *clientv3.Client + ttl time.Duration + secondTTL int64 + key string + + // leaseID used to revoke the lease + curLeaseID clientv3.LeaseID + wg sync.WaitGroup + cancel context.CancelFunc +} + +// NewTaskRegisterWithTTL build a TaskRegister with key format {RegisterTaskPrefix}/{RegisterTaskType}/{taskName} +func NewTaskRegisterWithTTL(client *clientv3.Client, ttl time.Duration, tp RegisterTaskType, taskName string) TaskRegister { + return &taskRegister{ + client: client, + ttl: ttl, + secondTTL: int64(ttl / time.Second), + key: path.Join(RegisterImportTaskPrefix, tp.String(), taskName), + + curLeaseID: clientv3.NoLease, + } +} + +// NewTaskRegister build a TaskRegister with key format {RegisterTaskPrefix}/{RegisterTaskType}/{taskName} +func NewTaskRegister(client *clientv3.Client, tp RegisterTaskType, taskName string) TaskRegister { + return NewTaskRegisterWithTTL(client, defaultTaskRegisterTTL, tp, taskName) +} + +// Close implements the TaskRegister interface +func (tr *taskRegister) Close(ctx context.Context) (err error) { + // not needed if using RegisterTaskOnce + if tr.cancel != nil { + tr.cancel() + } + tr.wg.Wait() + if tr.curLeaseID != clientv3.NoLease { + _, err = tr.client.Lease.Revoke(ctx, tr.curLeaseID) + if err != nil { + log.Warn("failed to revoke the lease", zap.Error(err), zap.Int64("lease-id", int64(tr.curLeaseID))) + } + } + return err +} + +func (tr *taskRegister) grant(ctx context.Context) (*clientv3.LeaseGrantResponse, error) { + lease, err := tr.client.Lease.Grant(ctx, tr.secondTTL) + if err != nil { + return nil, err + } + if len(lease.Error) > 0 { + return nil, errors.New(lease.Error) + } + return lease, nil +} + +// RegisterTaskOnce implements the TaskRegister interface +func (tr *taskRegister) RegisterTaskOnce(ctx context.Context) error { + resp, err := tr.client.Get(ctx, tr.key) + if err != nil { + return errors.Trace(err) + } + if len(resp.Kvs) == 0 { + lease, err2 := tr.grant(ctx) + if err2 != nil { + return errors.Annotatef(err2, "failed grant a lease") + } + tr.curLeaseID = lease.ID + _, err2 = tr.client.KV.Put(ctx, tr.key, "", clientv3.WithLease(lease.ID)) + if err2 != nil { + return errors.Trace(err2) + } + } else { + // if the task is run distributively, like IMPORT INTO, we should refresh the lease ID, + // in case the owner changed during the registration, and the new owner create the key. 
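
From a caller's point of view the interface above boils down to register, work, close. A sketch under the assumption that it runs inside the same utils package (so context, clientv3, log and zap come from the existing imports) and with a placeholder task name:

// Sketch only: drive TaskRegister for a restore task; etcdCLI is assumed
// to be a pre-built client and "restore-demo" is a placeholder name.
func registerRestoreTask(ctx context.Context, etcdCLI *clientv3.Client) error {
    reg := NewTaskRegister(etcdCLI, RegisterRestore, "restore-demo")
    if err := reg.RegisterTask(ctx); err != nil {
        return err
    }
    // Close cancels the keepalive goroutine and revokes the lease, so the
    // key under /tidb/brie/import/restore/ disappears immediately.
    defer func() {
        if err := reg.Close(ctx); err != nil {
            log.Warn("failed to close task register", zap.Error(err))
        }
    }()

    // ... long-running restore work guarded by the registration ...
    return nil
}
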
+ tr.curLeaseID = clientv3.LeaseID(resp.Kvs[0].Lease) + _, err2 := tr.client.Lease.KeepAliveOnce(ctx, tr.curLeaseID) + if err2 != nil { + return errors.Trace(err2) + } + } + return nil +} + +// RegisterTask implements the TaskRegister interface +func (tr *taskRegister) RegisterTask(c context.Context) error { + cctx, cancel := context.WithCancel(c) + tr.cancel = cancel + lease, err := tr.grant(cctx) + if err != nil { + return errors.Annotatef(err, "failed grant a lease") + } + tr.curLeaseID = lease.ID + _, err = tr.client.KV.Put(cctx, tr.key, "", clientv3.WithLease(lease.ID)) + if err != nil { + return errors.Trace(err) + } + + // KeepAlive interval equals to ttl/3 + respCh, err := tr.client.Lease.KeepAlive(cctx, lease.ID) + if err != nil { + return errors.Trace(err) + } + tr.wg.Add(1) + go tr.keepaliveLoop(cctx, respCh) + return nil +} + +func (tr *taskRegister) keepaliveLoop(ctx context.Context, ch <-chan *clientv3.LeaseKeepAliveResponse) { + defer tr.wg.Done() + const minTimeLeftThreshold time.Duration = 20 * time.Second + var ( + timeLeftThreshold time.Duration = tr.ttl / 4 + lastUpdateTime time.Time = time.Now() + err error + ) + if timeLeftThreshold < minTimeLeftThreshold { + timeLeftThreshold = minTimeLeftThreshold + } + failpoint.Inject("brie-task-register-always-grant", func(_ failpoint.Value) { + timeLeftThreshold = tr.ttl + }) + for { + CONSUMERESP: + for { + failpoint.Inject("brie-task-register-keepalive-stop", func(_ failpoint.Value) { + if _, err = tr.client.Lease.Revoke(ctx, tr.curLeaseID); err != nil { + log.Warn("brie-task-register-keepalive-stop", zap.Error(err)) + } + }) + select { + case <-ctx.Done(): + return + case _, ok := <-ch: + if !ok { + break CONSUMERESP + } + lastUpdateTime = time.Now() + } + } + log.Warn("the keepalive channel is closed, try to recreate it") + needReputKV := false + RECREATE: + for { + timeGap := time.Since(lastUpdateTime) + if tr.ttl-timeGap <= timeLeftThreshold { + lease, err := tr.grant(ctx) + failpoint.Inject("brie-task-register-failed-to-grant", func(_ failpoint.Value) { + err = errors.New("failpoint-error") + }) + if err != nil { + select { + case <-ctx.Done(): + return + default: + } + log.Warn("failed to grant lease", zap.Error(err)) + time.Sleep(RegisterRetryInternal) + continue + } + tr.curLeaseID = lease.ID + lastUpdateTime = time.Now() + needReputKV = true + } + if needReputKV { + // if the lease has expired, need to put the key again + _, err := tr.client.KV.Put(ctx, tr.key, "", clientv3.WithLease(tr.curLeaseID)) + failpoint.Inject("brie-task-register-failed-to-reput", func(_ failpoint.Value) { + err = errors.New("failpoint-error") + }) + if err != nil { + select { + case <-ctx.Done(): + return + default: + } + log.Warn("failed to put new kv", zap.Error(err)) + time.Sleep(RegisterRetryInternal) + continue + } + needReputKV = false + } + // recreate keepalive + ch, err = tr.client.Lease.KeepAlive(ctx, tr.curLeaseID) + if err != nil { + select { + case <-ctx.Done(): + return + default: + } + log.Warn("failed to create new kv", zap.Error(err)) + time.Sleep(RegisterRetryInternal) + continue + } + + break RECREATE + } + } +} + +// RegisterTask saves the task's information +type RegisterTask struct { + Key string + LeaseID int64 + TTL int64 +} + +// MessageToUser marshal the task to user message +func (task RegisterTask) MessageToUser() string { + return fmt.Sprintf("[ key: %s, lease-id: %x, ttl: %ds ]", task.Key, task.LeaseID, task.TTL) +} + +type RegisterTasksList struct { + Tasks []RegisterTask +} + +func (list RegisterTasksList) 
MessageToUser() string { + var tasksMsgBuf strings.Builder + for _, task := range list.Tasks { + tasksMsgBuf.WriteString(task.MessageToUser()) + tasksMsgBuf.WriteString(", ") + } + return tasksMsgBuf.String() +} + +func (list RegisterTasksList) Empty() bool { + return len(list.Tasks) == 0 +} + +// GetImportTasksFrom try to get all the import tasks with prefix `RegisterTaskPrefix` +func GetImportTasksFrom(ctx context.Context, client *clientv3.Client) (RegisterTasksList, error) { + resp, err := client.KV.Get(ctx, RegisterImportTaskPrefix, clientv3.WithPrefix()) + if err != nil { + return RegisterTasksList{}, errors.Trace(err) + } + + list := RegisterTasksList{ + Tasks: make([]RegisterTask, 0, len(resp.Kvs)), + } + for _, kv := range resp.Kvs { + leaseResp, err := client.Lease.TimeToLive(ctx, clientv3.LeaseID(kv.Lease)) + if err != nil { + return list, errors.Annotatef(err, "failed to get time-to-live of lease: %x", kv.Lease) + } + // the lease has expired + if leaseResp.TTL <= 0 { + continue + } + list.Tasks = append(list.Tasks, RegisterTask{ + Key: string(kv.Key), + LeaseID: kv.Lease, + TTL: leaseResp.TTL, + }) + } + return list, nil +} diff --git a/br/pkg/utils/store_manager.go b/br/pkg/utils/store_manager.go index 73e7e3fbb7a07..cbddafd5bfe70 100644 --- a/br/pkg/utils/store_manager.go +++ b/br/pkg/utils/store_manager.go @@ -119,7 +119,7 @@ func (mgr *StoreManager) PDClient() pd.Client { } func (mgr *StoreManager) getGrpcConnLocked(ctx context.Context, storeID uint64) (*grpc.ClientConn, error) { - failpoint.Inject("hint-get-backup-client", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("hint-get-backup-client")); _err_ == nil { log.Info("failpoint hint-get-backup-client injected, "+ "process will notify the shell.", zap.Uint64("store", storeID)) if sigFile, ok := v.(string); ok { @@ -132,7 +132,7 @@ func (mgr *StoreManager) getGrpcConnLocked(ctx context.Context, storeID uint64) } } time.Sleep(3 * time.Second) - }) + } store, err := mgr.pdClient.GetStore(ctx, storeID) if err != nil { return nil, errors.Trace(err) diff --git a/br/pkg/utils/store_manager.go__failpoint_stash__ b/br/pkg/utils/store_manager.go__failpoint_stash__ new file mode 100644 index 0000000000000..73e7e3fbb7a07 --- /dev/null +++ b/br/pkg/utils/store_manager.go__failpoint_stash__ @@ -0,0 +1,264 @@ +// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. + +package utils + +import ( + "context" + "crypto/tls" + "os" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/log" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/logutil" + pd "github.com/tikv/pd/client" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/backoff" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/keepalive" +) + +const ( + dialTimeout = 30 * time.Second + resetRetryTimes = 3 +) + +// Pool is a lazy pool of gRPC channels. +// When `Get` called, it lazily allocates new connection if connection not full. +// If it's full, then it will return allocated channels round-robin. +type Pool struct { + mu sync.Mutex + + conns []*grpc.ClientConn + next int + cap int + newConn func(ctx context.Context) (*grpc.ClientConn, error) +} + +func (p *Pool) takeConns() (conns []*grpc.ClientConn) { + p.mu.Lock() + defer p.mu.Unlock() + p.conns, conns = nil, p.conns + p.next = 0 + return conns +} + +// Close closes the conn pool. 
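
Pool above dials lazily until it holds cap connections and then serves them round-robin from Get. A usage sketch with a bare-bones dial function, assuming the surrounding utils package imports (context, time, grpc, credentials/insecure); the address, capacity and dial options are placeholders, and the TLS/keepalive handling done by getGrpcConnLocked elsewhere in this file is omitted:

// Sketch only: a four-connection pool toward a single placeholder address.
func newStorePool(addr string) *Pool {
    return NewConnPool(4, func(ctx context.Context) (*grpc.ClientConn, error) {
        dialCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
        defer cancel()
        return grpc.DialContext(dialCtx, addr,
            grpc.WithTransportCredentials(insecure.NewCredentials()),
            grpc.WithBlock(),
        )
    })
}

Once the pool is full, Get cycles through the existing connections; Close drops them all and logs any close error.
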
+func (p *Pool) Close() { + for _, c := range p.takeConns() { + if err := c.Close(); err != nil { + log.Warn("failed to close clientConn", zap.String("target", c.Target()), zap.Error(err)) + } + } +} + +// Get tries to get an existing connection from the pool, or make a new one if the pool not full. +func (p *Pool) Get(ctx context.Context) (*grpc.ClientConn, error) { + p.mu.Lock() + defer p.mu.Unlock() + if len(p.conns) < p.cap { + c, err := p.newConn(ctx) + if err != nil { + return nil, err + } + p.conns = append(p.conns, c) + return c, nil + } + + conn := p.conns[p.next] + p.next = (p.next + 1) % p.cap + return conn, nil +} + +// NewConnPool creates a new Pool by the specified conn factory function and capacity. +func NewConnPool(capacity int, newConn func(ctx context.Context) (*grpc.ClientConn, error)) *Pool { + return &Pool{ + cap: capacity, + conns: make([]*grpc.ClientConn, 0, capacity), + newConn: newConn, + + mu: sync.Mutex{}, + } +} + +type StoreManager struct { + pdClient pd.Client + grpcClis struct { + mu sync.Mutex + clis map[uint64]*grpc.ClientConn + } + keepalive keepalive.ClientParameters + tlsConf *tls.Config +} + +func (mgr *StoreManager) GetKeepalive() keepalive.ClientParameters { + return mgr.keepalive +} + +// NewStoreManager create a new manager for gRPC connections to stores. +func NewStoreManager(pdCli pd.Client, kl keepalive.ClientParameters, tlsConf *tls.Config) *StoreManager { + return &StoreManager{ + pdClient: pdCli, + grpcClis: struct { + mu sync.Mutex + clis map[uint64]*grpc.ClientConn + }{clis: make(map[uint64]*grpc.ClientConn)}, + keepalive: kl, + tlsConf: tlsConf, + } +} + +func (mgr *StoreManager) PDClient() pd.Client { + return mgr.pdClient +} + +func (mgr *StoreManager) getGrpcConnLocked(ctx context.Context, storeID uint64) (*grpc.ClientConn, error) { + failpoint.Inject("hint-get-backup-client", func(v failpoint.Value) { + log.Info("failpoint hint-get-backup-client injected, "+ + "process will notify the shell.", zap.Uint64("store", storeID)) + if sigFile, ok := v.(string); ok { + file, err := os.Create(sigFile) + if err != nil { + log.Warn("failed to create file for notifying, skipping notify", zap.Error(err)) + } + if file != nil { + file.Close() + } + } + time.Sleep(3 * time.Second) + }) + store, err := mgr.pdClient.GetStore(ctx, storeID) + if err != nil { + return nil, errors.Trace(err) + } + opt := grpc.WithTransportCredentials(insecure.NewCredentials()) + if mgr.tlsConf != nil { + opt = grpc.WithTransportCredentials(credentials.NewTLS(mgr.tlsConf)) + } + ctx, cancel := context.WithTimeout(ctx, dialTimeout) + bfConf := backoff.DefaultConfig + bfConf.MaxDelay = time.Second * 3 + addr := store.GetPeerAddress() + if addr == "" { + addr = store.GetAddress() + } + log.Info("StoreManager: dialing to store.", zap.String("address", addr), zap.Uint64("store-id", storeID)) + conn, err := grpc.DialContext( + ctx, + addr, + opt, + grpc.WithBlock(), + grpc.WithConnectParams(grpc.ConnectParams{Backoff: bfConf}), + grpc.WithKeepaliveParams(mgr.keepalive), + ) + cancel() + if err != nil { + return nil, berrors.ErrFailedToConnect.Wrap(err).GenWithStack("failed to make connection to store %d", storeID) + } + return conn, nil +} + +func (mgr *StoreManager) RemoveConn(ctx context.Context, storeID uint64) error { + if ctx.Err() != nil { + return errors.Trace(ctx.Err()) + } + + mgr.grpcClis.mu.Lock() + defer mgr.grpcClis.mu.Unlock() + + if conn, ok := mgr.grpcClis.clis[storeID]; ok { + // Find a cached backup client. 
+ err := conn.Close() + if err != nil { + log.Warn("close backup connection failed, ignore it", zap.Uint64("storeID", storeID)) + } + delete(mgr.grpcClis.clis, storeID) + return nil + } + return nil +} + +func (mgr *StoreManager) TryWithConn(ctx context.Context, storeID uint64, f func(*grpc.ClientConn) error) error { + if ctx.Err() != nil { + return errors.Trace(ctx.Err()) + } + + mgr.grpcClis.mu.Lock() + defer mgr.grpcClis.mu.Unlock() + + if conn, ok := mgr.grpcClis.clis[storeID]; ok { + // Find a cached backup client. + return f(conn) + } + + conn, err := mgr.getGrpcConnLocked(ctx, storeID) + if err != nil { + return errors.Trace(err) + } + // Cache the conn. + mgr.grpcClis.clis[storeID] = conn + return f(conn) +} + +func (mgr *StoreManager) WithConn(ctx context.Context, storeID uint64, f func(*grpc.ClientConn)) error { + return mgr.TryWithConn(ctx, storeID, func(cc *grpc.ClientConn) error { f(cc); return nil }) +} + +// ResetBackupClient reset the connection for backup client. +func (mgr *StoreManager) ResetBackupClient(ctx context.Context, storeID uint64) (backuppb.BackupClient, error) { + var ( + conn *grpc.ClientConn + err error + ) + err = mgr.RemoveConn(ctx, storeID) + if err != nil { + return nil, errors.Trace(err) + } + + mgr.grpcClis.mu.Lock() + defer mgr.grpcClis.mu.Unlock() + + for retry := 0; retry < resetRetryTimes; retry++ { + conn, err = mgr.getGrpcConnLocked(ctx, storeID) + if err != nil { + log.Warn("failed to reset grpc connection, retry it", + zap.Int("retry time", retry), logutil.ShortError(err)) + time.Sleep(time.Duration(retry+3) * time.Second) + continue + } + mgr.grpcClis.clis[storeID] = conn + break + } + if err != nil { + return nil, errors.Trace(err) + } + return backuppb.NewBackupClient(conn), nil +} + +// Close closes all client in Mgr. 
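
TryWithConn and WithConn above cache at most one gRPC connection per store and hand it to a callback, so callers never manage connections directly. A sketch of wrapping that in a backup-client helper, assuming the surrounding utils package and its imports; the callback body is left to the caller:

// Sketch only: run f against a BackupClient backed by the cached per-store
// connection. Repeated calls for the same storeID reuse the same conn.
func withBackupClient(ctx context.Context, mgr *StoreManager, storeID uint64,
    f func(backuppb.BackupClient) error) error {
    return mgr.TryWithConn(ctx, storeID, func(cc *grpc.ClientConn) error {
        return f(backuppb.NewBackupClient(cc))
    })
}
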
+func (mgr *StoreManager) Close() { + if mgr == nil { + return + } + mgr.grpcClis.mu.Lock() + for _, cli := range mgr.grpcClis.clis { + err := cli.Close() + if err != nil { + log.Error("fail to close Mgr", zap.Error(err)) + } + } + mgr.grpcClis.mu.Unlock() +} + +func (mgr *StoreManager) TLSConfig() *tls.Config { + if mgr == nil { + return nil + } + return mgr.tlsConf +} diff --git a/dumpling/export/binding__failpoint_binding__.go b/dumpling/export/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..e39dec4835192 --- /dev/null +++ b/dumpling/export/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package export + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/dumpling/export/config.go b/dumpling/export/config.go index 52337ec732601..a799184ae43c1 100644 --- a/dumpling/export/config.go +++ b/dumpling/export/config.go @@ -291,11 +291,11 @@ func (conf *Config) GetDriverConfig(db string) *mysql.Config { if conf.AllowCleartextPasswords { driverCfg.AllowCleartextPasswords = true } - failpoint.Inject("SetWaitTimeout", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("SetWaitTimeout")); _err_ == nil { driverCfg.Params = map[string]string{ "wait_timeout": strconv.Itoa(val.(int)), } - }) + } return driverCfg } diff --git a/dumpling/export/config.go__failpoint_stash__ b/dumpling/export/config.go__failpoint_stash__ new file mode 100644 index 0000000000000..52337ec732601 --- /dev/null +++ b/dumpling/export/config.go__failpoint_stash__ @@ -0,0 +1,809 @@ +// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. 
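
The generated binding__failpoint_binding__.go files above resolve the package import path once via reflection, so _curpkg_("SetWaitTimeout") yields a fully qualified failpoint name. A tiny standalone sketch of that trick; it is not part of the generated file, and the package path it prints depends on where it is compiled:

package main // standalone illustration, not the generated binding file

import (
    "fmt"
    "reflect"
)

type pkgPathProbe struct{}

// pkgPath is resolved once via reflection; inside the real dumpling/export
// package, PkgPath() would return that package's import path instead.
var pkgPath = reflect.TypeOf(pkgPathProbe{}).PkgPath()

// qualify mirrors _curpkg_: prefix a short failpoint name with the package
// path so failpoint.Enable/Eval can address it unambiguously.
func qualify(name string) string {
    return pkgPath + "/" + name
}

func main() {
    fmt.Println(qualify("SetWaitTimeout")) // prints "main/SetWaitTimeout" here
}
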
+ +package export + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "net" + "strconv" + "strings" + "text/template" + "time" + + "github.com/coreos/go-semver/semver" + "github.com/docker/go-units" + "github.com/go-sql-driver/mysql" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/br/pkg/version" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/promutil" + filter "github.com/pingcap/tidb/pkg/util/table-filter" + "github.com/prometheus/client_golang/prometheus" + "github.com/spf13/pflag" + "go.uber.org/atomic" + "go.uber.org/zap" +) + +const ( + flagDatabase = "database" + flagTablesList = "tables-list" + flagHost = "host" + flagUser = "user" + flagPort = "port" + flagPassword = "password" + flagAllowCleartextPasswords = "allow-cleartext-passwords" + flagThreads = "threads" + flagFilesize = "filesize" + flagStatementSize = "statement-size" + flagOutput = "output" + flagLoglevel = "loglevel" + flagLogfile = "logfile" + flagLogfmt = "logfmt" + flagConsistency = "consistency" + flagSnapshot = "snapshot" + flagNoViews = "no-views" + flagNoSequences = "no-sequences" + flagSortByPk = "order-by-primary-key" + flagStatusAddr = "status-addr" + flagRows = "rows" + flagWhere = "where" + flagEscapeBackslash = "escape-backslash" + flagFiletype = "filetype" + flagNoHeader = "no-header" + flagNoSchemas = "no-schemas" + flagNoData = "no-data" + flagCsvNullValue = "csv-null-value" + flagSQL = "sql" + flagFilter = "filter" + flagCaseSensitive = "case-sensitive" + flagDumpEmptyDatabase = "dump-empty-database" + flagTidbMemQuotaQuery = "tidb-mem-quota-query" + flagCA = "ca" + flagCert = "cert" + flagKey = "key" + flagCsvSeparator = "csv-separator" + flagCsvDelimiter = "csv-delimiter" + flagCsvLineTerminator = "csv-line-terminator" + flagOutputFilenameTemplate = "output-filename-template" + flagCompleteInsert = "complete-insert" + flagParams = "params" + flagReadTimeout = "read-timeout" + flagTransactionalConsistency = "transactional-consistency" + flagCompress = "compress" + flagCsvOutputDialect = "csv-output-dialect" + + // FlagHelp represents the help flag + FlagHelp = "help" +) + +// CSVDialect is the dialect of the CSV output for compatible with different import target +type CSVDialect int + +const ( + // CSVDialectDefault is the default dialect, which is MySQL/MariaDB/TiDB etc. + CSVDialectDefault CSVDialect = iota + // CSVDialectSnowflake is the dialect of Snowflake + CSVDialectSnowflake + // CSVDialectRedshift is the dialect of Redshift + CSVDialectRedshift + // CSVDialectBigQuery is the dialect of BigQuery + CSVDialectBigQuery +) + +// BinaryFormat is the format of binary data +// Three standard formats are supported: UTF8, HEX and Base64 now. +type BinaryFormat int + +const ( + // BinaryFormatUTF8 is the default format, format binary data as UTF8 string + BinaryFormatUTF8 BinaryFormat = iota + // BinaryFormatHEX format binary data as HEX string, e.g. 12ABCD + BinaryFormatHEX + // BinaryFormatBase64 format binary data as Base64 string, e.g. 
123qwer== + BinaryFormatBase64 +) + +// DialectBinaryFormatMap is the map of dialect and binary format +var DialectBinaryFormatMap = map[CSVDialect]BinaryFormat{ + CSVDialectDefault: BinaryFormatUTF8, + CSVDialectSnowflake: BinaryFormatHEX, + CSVDialectRedshift: BinaryFormatHEX, + CSVDialectBigQuery: BinaryFormatBase64, +} + +// Config is the dump config for dumpling +type Config struct { + storage.BackendOptions + + SpecifiedTables bool + AllowCleartextPasswords bool + SortByPk bool + NoViews bool + NoSequences bool + NoHeader bool + NoSchemas bool + NoData bool + CompleteInsert bool + TransactionalConsistency bool + EscapeBackslash bool + DumpEmptyDatabase bool + PosAfterConnect bool + CompressType storage.CompressType + + Host string + Port int + Threads int + User string + Password string `json:"-"` + Security struct { + TLS *tls.Config `json:"-"` + CAPath string + CertPath string + KeyPath string + SSLCABytes []byte `json:"-"` + SSLCertBytes []byte `json:"-"` + SSLKeyBytes []byte `json:"-"` + } + + LogLevel string + LogFile string + LogFormat string + OutputDirPath string + StatusAddr string + Snapshot string + Consistency string + CsvNullValue string + SQL string + CsvSeparator string + CsvDelimiter string + CsvLineTerminator string + Databases []string + + TableFilter filter.Filter `json:"-"` + Where string + FileType string + ServerInfo version.ServerInfo + Logger *zap.Logger `json:"-"` + OutputFileTemplate *template.Template `json:"-"` + Rows uint64 + ReadTimeout time.Duration + TiDBMemQuotaQuery uint64 + FileSize uint64 + StatementSize uint64 + SessionParams map[string]any + Tables DatabaseTables + CollationCompatible string + CsvOutputDialect CSVDialect + + Labels prometheus.Labels `json:"-"` + PromFactory promutil.Factory `json:"-"` + PromRegistry promutil.Registry `json:"-"` + ExtStorage storage.ExternalStorage `json:"-"` + MinTLSVersion uint16 `json:"-"` + + IOTotalBytes *atomic.Uint64 + Net string +} + +// ServerInfoUnknown is the unknown database type to dumpling +var ServerInfoUnknown = version.ServerInfo{ + ServerType: version.ServerTypeUnknown, + ServerVersion: nil, +} + +// DefaultConfig returns the default export Config for dumpling +func DefaultConfig() *Config { + allFilter, _ := filter.Parse([]string{"*.*"}) + return &Config{ + Databases: nil, + Host: "127.0.0.1", + User: "root", + Port: 3306, + Password: "", + Threads: 4, + Logger: nil, + StatusAddr: ":8281", + FileSize: UnspecifiedSize, + StatementSize: DefaultStatementSize, + OutputDirPath: ".", + ServerInfo: ServerInfoUnknown, + SortByPk: true, + Tables: nil, + Snapshot: "", + Consistency: ConsistencyTypeAuto, + NoViews: true, + NoSequences: true, + Rows: UnspecifiedSize, + Where: "", + EscapeBackslash: true, + FileType: "", + NoHeader: false, + NoSchemas: false, + NoData: false, + CsvNullValue: "\\N", + SQL: "", + TableFilter: allFilter, + DumpEmptyDatabase: true, + CsvDelimiter: "\"", + CsvSeparator: ",", + CsvLineTerminator: "\r\n", + SessionParams: make(map[string]any), + OutputFileTemplate: DefaultOutputFileTemplate, + PosAfterConnect: false, + CollationCompatible: LooseCollationCompatible, + CsvOutputDialect: CSVDialectDefault, + SpecifiedTables: false, + PromFactory: promutil.NewDefaultFactory(), + PromRegistry: promutil.NewDefaultRegistry(), + TransactionalConsistency: true, + } +} + +// String returns dumpling's config in json format +func (conf *Config) String() string { + cfg, err := json.Marshal(conf) + if err != nil && conf.Logger != nil { + conf.Logger.Error("fail to marshal config to json", 
zap.Error(err)) + } + return string(cfg) +} + +// GetDriverConfig returns the MySQL driver config from Config. +func (conf *Config) GetDriverConfig(db string) *mysql.Config { + driverCfg := mysql.NewConfig() + // maxAllowedPacket=0 can be used to automatically fetch the max_allowed_packet variable from server on every connection. + // https://github.com/go-sql-driver/mysql#maxallowedpacket + hostPort := net.JoinHostPort(conf.Host, strconv.Itoa(conf.Port)) + driverCfg.User = conf.User + driverCfg.Passwd = conf.Password + driverCfg.Net = "tcp" + if conf.Net != "" { + driverCfg.Net = conf.Net + } + driverCfg.Addr = hostPort + driverCfg.DBName = db + driverCfg.Collation = "utf8mb4_general_ci" + driverCfg.ReadTimeout = conf.ReadTimeout + driverCfg.WriteTimeout = 30 * time.Second + driverCfg.InterpolateParams = true + driverCfg.MaxAllowedPacket = 0 + if conf.Security.TLS != nil { + driverCfg.TLS = conf.Security.TLS + } else { + // Use TLS first. + driverCfg.AllowFallbackToPlaintext = true + minTLSVersion := uint16(tls.VersionTLS12) + if conf.MinTLSVersion != 0 { + minTLSVersion = conf.MinTLSVersion + } + /* #nosec G402 */ + driverCfg.TLS = &tls.Config{ + InsecureSkipVerify: true, + MinVersion: minTLSVersion, + NextProtos: []string{"h2", "http/1.1"}, // specify `h2` to let Go use HTTP/2. + } + } + if conf.AllowCleartextPasswords { + driverCfg.AllowCleartextPasswords = true + } + failpoint.Inject("SetWaitTimeout", func(val failpoint.Value) { + driverCfg.Params = map[string]string{ + "wait_timeout": strconv.Itoa(val.(int)), + } + }) + return driverCfg +} + +func timestampDirName() string { + return fmt.Sprintf("./export-%s", time.Now().Format(time.RFC3339)) +} + +// DefineFlags defines flags of dumpling's configuration +func (*Config) DefineFlags(flags *pflag.FlagSet) { + storage.DefineFlags(flags) + flags.StringSliceP(flagDatabase, "B", nil, "Databases to dump") + flags.StringSliceP(flagTablesList, "T", nil, "Comma delimited table list to dump; must be qualified table names") + flags.StringP(flagHost, "h", "127.0.0.1", "The host to connect to") + flags.StringP(flagUser, "u", "root", "Username with privileges to run the dump") + flags.IntP(flagPort, "P", 4000, "TCP/IP port to connect to") + flags.StringP(flagPassword, "p", "", "User password") + flags.Bool(flagAllowCleartextPasswords, false, "Allow passwords to be sent in cleartext (warning: don't use without TLS)") + flags.IntP(flagThreads, "t", 4, "Number of goroutines to use, default 4") + flags.StringP(flagFilesize, "F", "", "The approximate size of output file") + flags.Uint64P(flagStatementSize, "s", DefaultStatementSize, "Attempted size of INSERT statement in bytes") + flags.StringP(flagOutput, "o", timestampDirName(), "Output directory") + flags.String(flagLoglevel, "info", "Log level: {debug|info|warn|error|dpanic|panic|fatal}") + flags.StringP(flagLogfile, "L", "", "Log file `path`, leave empty to write to console") + flags.String(flagLogfmt, "text", "Log `format`: {text|json}") + flags.String(flagConsistency, ConsistencyTypeAuto, "Consistency level during dumping: {auto|none|flush|lock|snapshot}") + flags.String(flagSnapshot, "", "Snapshot position (uint64 or MySQL style string timestamp). 
Valid only when consistency=snapshot") + flags.BoolP(flagNoViews, "W", true, "Do not dump views") + flags.Bool(flagNoSequences, true, "Do not dump sequences") + flags.Bool(flagSortByPk, true, "Sort dump results by primary key through order by sql") + flags.String(flagStatusAddr, ":8281", "dumpling API server and pprof addr") + flags.Uint64P(flagRows, "r", UnspecifiedSize, "If specified, dumpling will split table into chunks and concurrently dump them to different files to improve efficiency. For TiDB v3.0+, specify this will make dumpling split table with each file one TiDB region(no matter how many rows is).\n"+ + "If not specified, dumpling will dump table without inner-concurrency which could be relatively slow. default unlimited") + flags.String(flagWhere, "", "Dump only selected records") + flags.Bool(flagEscapeBackslash, true, "use backslash to escape special characters") + flags.String(flagFiletype, "", "The type of export file (sql/csv)") + flags.Bool(flagNoHeader, false, "whether not to dump CSV table header") + flags.BoolP(flagNoSchemas, "m", false, "Do not dump table schemas with the data") + flags.BoolP(flagNoData, "d", false, "Do not dump table data") + flags.String(flagCsvNullValue, "\\N", "The null value used when export to csv") + flags.StringP(flagSQL, "S", "", "Dump data with given sql. This argument doesn't support concurrent dump") + _ = flags.MarkHidden(flagSQL) + flags.StringSliceP(flagFilter, "f", []string{"*.*", DefaultTableFilter}, "filter to select which tables to dump") + flags.Bool(flagCaseSensitive, false, "whether the filter should be case-sensitive") + flags.Bool(flagDumpEmptyDatabase, true, "whether to dump empty database") + flags.Uint64(flagTidbMemQuotaQuery, UnspecifiedSize, "The maximum memory limit for a single SQL statement, in bytes.") + flags.String(flagCA, "", "The path name to the certificate authority file for TLS connection") + flags.String(flagCert, "", "The path name to the client certificate file for TLS connection") + flags.String(flagKey, "", "The path name to the client private key file for TLS connection") + flags.String(flagCsvSeparator, ",", "The separator for csv files, default ','") + flags.String(flagCsvDelimiter, "\"", "The delimiter for values in csv files, default '\"'") + flags.String(flagCsvLineTerminator, "\r\n", "The line terminator for csv files, default '\\r\\n'") + flags.String(flagOutputFilenameTemplate, "", "The output filename template (without file extension)") + flags.Bool(flagCompleteInsert, false, "Use complete INSERT statements that include column names") + flags.StringToString(flagParams, nil, `Extra session variables used while dumping, accepted format: --params "character_set_client=latin1,character_set_connection=latin1"`) + flags.Bool(FlagHelp, false, "Print help message and quit") + flags.Duration(flagReadTimeout, 15*time.Minute, "I/O read timeout for db connection.") + _ = flags.MarkHidden(flagReadTimeout) + flags.Bool(flagTransactionalConsistency, true, "Only support transactional consistency") + _ = flags.MarkHidden(flagTransactionalConsistency) + flags.StringP(flagCompress, "c", "", "Compress output file type, support 'gzip', 'snappy', 'zstd', 'no-compression' now") + flags.String(flagCsvOutputDialect, "", "The dialect of output CSV file, support 'snowflake', 'redshift', 'bigquery' now") +} + +// ParseFromFlags parses dumpling's export.Config from flags +// nolint: gocyclo +func (conf *Config) ParseFromFlags(flags *pflag.FlagSet) error { + var err error + conf.Databases, err = 
flags.GetStringSlice(flagDatabase) + if err != nil { + return errors.Trace(err) + } + conf.Host, err = flags.GetString(flagHost) + if err != nil { + return errors.Trace(err) + } + conf.User, err = flags.GetString(flagUser) + if err != nil { + return errors.Trace(err) + } + conf.Port, err = flags.GetInt(flagPort) + if err != nil { + return errors.Trace(err) + } + conf.Password, err = flags.GetString(flagPassword) + if err != nil { + return errors.Trace(err) + } + conf.AllowCleartextPasswords, err = flags.GetBool(flagAllowCleartextPasswords) + if err != nil { + return errors.Trace(err) + } + conf.Threads, err = flags.GetInt(flagThreads) + if err != nil { + return errors.Trace(err) + } + conf.StatementSize, err = flags.GetUint64(flagStatementSize) + if err != nil { + return errors.Trace(err) + } + conf.OutputDirPath, err = flags.GetString(flagOutput) + if err != nil { + return errors.Trace(err) + } + conf.LogLevel, err = flags.GetString(flagLoglevel) + if err != nil { + return errors.Trace(err) + } + conf.LogFile, err = flags.GetString(flagLogfile) + if err != nil { + return errors.Trace(err) + } + conf.LogFormat, err = flags.GetString(flagLogfmt) + if err != nil { + return errors.Trace(err) + } + conf.Consistency, err = flags.GetString(flagConsistency) + if err != nil { + return errors.Trace(err) + } + conf.Snapshot, err = flags.GetString(flagSnapshot) + if err != nil { + return errors.Trace(err) + } + conf.NoViews, err = flags.GetBool(flagNoViews) + if err != nil { + return errors.Trace(err) + } + conf.NoSequences, err = flags.GetBool(flagNoSequences) + if err != nil { + return errors.Trace(err) + } + conf.SortByPk, err = flags.GetBool(flagSortByPk) + if err != nil { + return errors.Trace(err) + } + conf.StatusAddr, err = flags.GetString(flagStatusAddr) + if err != nil { + return errors.Trace(err) + } + conf.Rows, err = flags.GetUint64(flagRows) + if err != nil { + return errors.Trace(err) + } + conf.Where, err = flags.GetString(flagWhere) + if err != nil { + return errors.Trace(err) + } + conf.EscapeBackslash, err = flags.GetBool(flagEscapeBackslash) + if err != nil { + return errors.Trace(err) + } + conf.FileType, err = flags.GetString(flagFiletype) + if err != nil { + return errors.Trace(err) + } + conf.NoHeader, err = flags.GetBool(flagNoHeader) + if err != nil { + return errors.Trace(err) + } + conf.NoSchemas, err = flags.GetBool(flagNoSchemas) + if err != nil { + return errors.Trace(err) + } + conf.NoData, err = flags.GetBool(flagNoData) + if err != nil { + return errors.Trace(err) + } + conf.CsvNullValue, err = flags.GetString(flagCsvNullValue) + if err != nil { + return errors.Trace(err) + } + conf.SQL, err = flags.GetString(flagSQL) + if err != nil { + return errors.Trace(err) + } + conf.DumpEmptyDatabase, err = flags.GetBool(flagDumpEmptyDatabase) + if err != nil { + return errors.Trace(err) + } + conf.Security.CAPath, err = flags.GetString(flagCA) + if err != nil { + return errors.Trace(err) + } + conf.Security.CertPath, err = flags.GetString(flagCert) + if err != nil { + return errors.Trace(err) + } + conf.Security.KeyPath, err = flags.GetString(flagKey) + if err != nil { + return errors.Trace(err) + } + conf.CsvSeparator, err = flags.GetString(flagCsvSeparator) + if err != nil { + return errors.Trace(err) + } + conf.CsvDelimiter, err = flags.GetString(flagCsvDelimiter) + if err != nil { + return errors.Trace(err) + } + conf.CsvLineTerminator, err = flags.GetString(flagCsvLineTerminator) + if err != nil { + return errors.Trace(err) + } + conf.CompleteInsert, err = 
flags.GetBool(flagCompleteInsert) + if err != nil { + return errors.Trace(err) + } + conf.ReadTimeout, err = flags.GetDuration(flagReadTimeout) + if err != nil { + return errors.Trace(err) + } + conf.TransactionalConsistency, err = flags.GetBool(flagTransactionalConsistency) + if err != nil { + return errors.Trace(err) + } + conf.TiDBMemQuotaQuery, err = flags.GetUint64(flagTidbMemQuotaQuery) + if err != nil { + return errors.Trace(err) + } + + if conf.Threads <= 0 { + return errors.Errorf("--threads is set to %d. It should be greater than 0", conf.Threads) + } + if len(conf.CsvSeparator) == 0 { + return errors.New("--csv-separator is set to \"\". It must not be an empty string") + } + + if conf.SessionParams == nil { + conf.SessionParams = make(map[string]any) + } + + tablesList, err := flags.GetStringSlice(flagTablesList) + if err != nil { + return errors.Trace(err) + } + fileSizeStr, err := flags.GetString(flagFilesize) + if err != nil { + return errors.Trace(err) + } + filters, err := flags.GetStringSlice(flagFilter) + if err != nil { + return errors.Trace(err) + } + caseSensitive, err := flags.GetBool(flagCaseSensitive) + if err != nil { + return errors.Trace(err) + } + outputFilenameFormat, err := flags.GetString(flagOutputFilenameTemplate) + if err != nil { + return errors.Trace(err) + } + params, err := flags.GetStringToString(flagParams) + if err != nil { + return errors.Trace(err) + } + + conf.SpecifiedTables = len(tablesList) > 0 + conf.Tables, err = GetConfTables(tablesList) + if err != nil { + return errors.Trace(err) + } + + conf.TableFilter, err = ParseTableFilter(tablesList, filters) + if err != nil { + return errors.Errorf("failed to parse filter: %s", err) + } + + if !caseSensitive { + conf.TableFilter = filter.CaseInsensitive(conf.TableFilter) + } + + conf.FileSize, err = ParseFileSize(fileSizeStr) + if err != nil { + return errors.Trace(err) + } + + if outputFilenameFormat == "" && conf.SQL != "" { + outputFilenameFormat = DefaultAnonymousOutputFileTemplateText + } + tmpl, err := ParseOutputFileTemplate(outputFilenameFormat) + if err != nil { + return errors.Errorf("failed to parse output filename template (--output-filename-template '%s')", outputFilenameFormat) + } + conf.OutputFileTemplate = tmpl + + compressType, err := flags.GetString(flagCompress) + if err != nil { + return errors.Trace(err) + } + conf.CompressType, err = ParseCompressType(compressType) + if err != nil { + return errors.Trace(err) + } + + dialect, err := flags.GetString(flagCsvOutputDialect) + if err != nil { + return errors.Trace(err) + } + if dialect != "" && conf.FileType != "csv" { + return errors.Errorf("%s is only supported when dumping whole table to csv, not compatible with %s", flagCsvOutputDialect, conf.FileType) + } + conf.CsvOutputDialect, err = ParseOutputDialect(dialect) + if err != nil { + return errors.Trace(err) + } + + for k, v := range params { + conf.SessionParams[k] = v + } + + err = conf.BackendOptions.ParseFromFlags(pflag.CommandLine) + if err != nil { + return errors.Trace(err) + } + + return nil +} + +// ParseFileSize parses file size from tables-list and filter arguments +func ParseFileSize(fileSizeStr string) (uint64, error) { + if len(fileSizeStr) == 0 { + return UnspecifiedSize, nil + } else if fileSizeMB, err := strconv.ParseUint(fileSizeStr, 10, 64); err == nil { + fmt.Printf("Warning: -F without unit is not recommended, try using `-F '%dMiB'` in the future\n", fileSizeMB) + return fileSizeMB * units.MiB, nil + } else if size, err := units.RAMInBytes(fileSizeStr); 
err == nil { + return uint64(size), nil + } + return 0, errors.Errorf("failed to parse filesize (-F '%s')", fileSizeStr) +} + +// ParseTableFilter parses table filter from tables-list and filter arguments +func ParseTableFilter(tablesList, filters []string) (filter.Filter, error) { + if len(tablesList) == 0 { + return filter.Parse(filters) + } + + // only parse -T when -f is default value. otherwise bail out. + if !sameStringArray(filters, []string{"*.*", DefaultTableFilter}) { + return nil, errors.New("cannot pass --tables-list and --filter together") + } + + tableNames := make([]filter.Table, 0, len(tablesList)) + for _, table := range tablesList { + parts := strings.SplitN(table, ".", 2) + if len(parts) < 2 { + return nil, errors.Errorf("--tables-list only accepts qualified table names, but `%s` lacks a dot", table) + } + tableNames = append(tableNames, filter.Table{Schema: parts[0], Name: parts[1]}) + } + + return filter.NewTablesFilter(tableNames...), nil +} + +// GetConfTables parses tables from tables-list and filter arguments +func GetConfTables(tablesList []string) (DatabaseTables, error) { + dbTables := DatabaseTables{} + var ( + tablename string + avgRowLength uint64 + ) + avgRowLength = 0 + for _, tablename = range tablesList { + parts := strings.SplitN(tablename, ".", 2) + if len(parts) < 2 { + return nil, errors.Errorf("--tables-list only accepts qualified table names, but `%s` lacks a dot", tablename) + } + dbName := parts[0] + tbName := parts[1] + dbTables[dbName] = append(dbTables[dbName], &TableInfo{tbName, avgRowLength, TableTypeBase}) + } + return dbTables, nil +} + +// ParseCompressType parses compressType string to storage.CompressType +func ParseCompressType(compressType string) (storage.CompressType, error) { + switch compressType { + case "", "no-compression": + return storage.NoCompression, nil + case "gzip", "gz": + return storage.Gzip, nil + case "snappy": + return storage.Snappy, nil + case "zstd", "zst": + return storage.Zstd, nil + default: + return storage.NoCompression, errors.Errorf("unknown compress type %s", compressType) + } +} + +// ParseOutputDialect parses output dialect string to Dialect +func ParseOutputDialect(outputDialect string) (CSVDialect, error) { + switch outputDialect { + case "", "default": + return CSVDialectDefault, nil + case "snowflake": + return CSVDialectSnowflake, nil + case "redshift": + return CSVDialectRedshift, nil + case "bigquery": + return CSVDialectBigQuery, nil + default: + return CSVDialectDefault, errors.Errorf("unknown output dialect %s", outputDialect) + } +} + +func (conf *Config) createExternalStorage(ctx context.Context) (storage.ExternalStorage, error) { + if conf.ExtStorage != nil { + return conf.ExtStorage, nil + } + b, err := storage.ParseBackend(conf.OutputDirPath, &conf.BackendOptions) + if err != nil { + return nil, errors.Trace(err) + } + + // TODO: support setting httpClient with certification later + return storage.New(ctx, b, &storage.ExternalStorageOptions{}) +} + +const ( + // UnspecifiedSize means the filesize/statement-size is unspecified + UnspecifiedSize = 0 + // DefaultStatementSize is the default statement size + DefaultStatementSize = 1000000 + // TiDBMemQuotaQueryName is the session variable TiDBMemQuotaQuery's name in TiDB + TiDBMemQuotaQueryName = "tidb_mem_quota_query" + // DefaultTableFilter is the default exclude table filter. 
It will exclude all system databases + DefaultTableFilter = "!/^(mysql|sys|INFORMATION_SCHEMA|PERFORMANCE_SCHEMA|METRICS_SCHEMA|INSPECTION_SCHEMA)$/.*" + + defaultTaskChannelCapacity = 128 + defaultDumpGCSafePointTTL = 5 * 60 + defaultEtcdDialTimeOut = 3 * time.Second + + // LooseCollationCompatible is used in DM, represents a collation setting for best compatibility. + LooseCollationCompatible = "loose" + // StrictCollationCompatible is used in DM, represents a collation setting for correctness. + StrictCollationCompatible = "strict" + + dumplingServiceSafePointPrefix = "dumpling" +) + +var ( + decodeRegionVersion = semver.New("3.0.0") + gcSafePointVersion = semver.New("4.0.0") + tableSampleVersion = semver.New("5.0.0-nightly") +) + +func adjustConfig(conf *Config, fns ...func(*Config) error) error { + for _, f := range fns { + err := f(conf) + if err != nil { + return err + } + } + return nil +} + +func buildTLSConfig(conf *Config) error { + tlsConfig, err := util.NewTLSConfig( + util.WithCAPath(conf.Security.CAPath), + util.WithCertAndKeyPath(conf.Security.CertPath, conf.Security.KeyPath), + util.WithCAContent(conf.Security.SSLCABytes), + util.WithCertAndKeyContent(conf.Security.SSLCertBytes, conf.Security.SSLKeyBytes), + util.WithMinTLSVersion(conf.MinTLSVersion), + ) + if err != nil { + return errors.Trace(err) + } + conf.Security.TLS = tlsConfig + return nil +} + +func validateSpecifiedSQL(conf *Config) error { + if conf.SQL != "" && conf.Where != "" { + return errors.New("can't specify both --sql and --where at the same time. Please try to combine them into --sql") + } + return nil +} + +func adjustFileFormat(conf *Config) error { + conf.FileType = strings.ToLower(conf.FileType) + switch conf.FileType { + case "": + if conf.SQL != "" { + conf.FileType = FileFormatCSVString + } else { + conf.FileType = FileFormatSQLTextString + } + case FileFormatSQLTextString: + if conf.SQL != "" { + return errors.Errorf("unsupported config.FileType '%s' when we specify --sql, please unset --filetype or set it to 'csv'", conf.FileType) + } + case FileFormatCSVString: + default: + return errors.Errorf("unknown config.FileType '%s'", conf.FileType) + } + return nil +} + +func matchMysqlBugversion(info version.ServerInfo) bool { + // if 8.0.3 <= mysql8 version < 8.0.23 + // FLUSH TABLES WITH READ LOCK could block other sessions from executing SHOW TABLE STATUS. 
+ // see more in https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-23.html + if info.ServerType != version.ServerTypeMySQL { + return false + } + currentVersion := info.ServerVersion + bugVersionStart := semver.New("8.0.2") + bugVersionEnd := semver.New("8.0.23") + return bugVersionStart.LessThan(*currentVersion) && currentVersion.LessThan(*bugVersionEnd) +} diff --git a/dumpling/export/dump.go b/dumpling/export/dump.go index 0d94c814eaa77..0d8627108e089 100644 --- a/dumpling/export/dump.go +++ b/dumpling/export/dump.go @@ -72,7 +72,7 @@ type Dumper struct { // NewDumper returns a new Dumper func NewDumper(ctx context.Context, conf *Config) (*Dumper, error) { - failpoint.Inject("setExtStorage", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("setExtStorage")); _err_ == nil { path := val.(string) b, err := storage.ParseBackend(path, nil) if err != nil { @@ -83,7 +83,7 @@ func NewDumper(ctx context.Context, conf *Config) (*Dumper, error) { panic(err) } conf.ExtStorage = s - }) + } tctx, cancelFn := tcontext.Background().WithContext(ctx).WithCancel() d := &Dumper{ @@ -111,7 +111,7 @@ func NewDumper(ctx context.Context, conf *Config) (*Dumper, error) { if err != nil { return nil, err } - failpoint.Inject("SetIOTotalBytes", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("SetIOTotalBytes")); _err_ == nil { d.conf.IOTotalBytes = gatomic.NewUint64(0) d.conf.Net = uuid.New().String() go func() { @@ -120,7 +120,7 @@ func NewDumper(ctx context.Context, conf *Config) (*Dumper, error) { d.tctx.L().Logger.Info("IOTotalBytes", zap.Uint64("IOTotalBytes", d.conf.IOTotalBytes.Load())) } }() - }) + } err = runSteps(d, initLogger, @@ -257,9 +257,9 @@ func (d *Dumper) Dump() (dumpErr error) { } chanSize := defaultTaskChannelCapacity - failpoint.Inject("SmallDumpChanSize", func() { + if _, _err_ := failpoint.Eval(_curpkg_("SmallDumpChanSize")); _err_ == nil { chanSize = 1 - }) + } taskIn, taskOut := infiniteChan[Task]() // todo: refine metrics AddGauge(d.metrics.taskChannelCapacity, float64(chanSize)) @@ -280,7 +280,7 @@ func (d *Dumper) Dump() (dumpErr error) { } } // Inject consistency failpoint test after we release the table lock - failpoint.Inject("ConsistencyCheck", nil) + failpoint.Eval(_curpkg_("ConsistencyCheck")) if conf.PosAfterConnect { // record again, to provide a location to exit safe mode for DM @@ -300,7 +300,7 @@ func (d *Dumper) Dump() (dumpErr error) { tableDataStartTime := time.Now() - failpoint.Inject("PrintTiDBMemQuotaQuery", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("PrintTiDBMemQuotaQuery")); _err_ == nil { row := d.dbHandle.QueryRowContext(tctx, "select @@tidb_mem_quota_query;") var s string err = row.Scan(&s) @@ -309,7 +309,7 @@ func (d *Dumper) Dump() (dumpErr error) { } else { fmt.Printf("tidb_mem_quota_query == %s\n", s) } - }) + } baseConn := newBaseConn(metaConn, true, rebuildMetaConn) if conf.SQL == "" { @@ -321,10 +321,10 @@ func (d *Dumper) Dump() (dumpErr error) { } d.metrics.progressReady.Store(true) close(taskIn) - failpoint.Inject("EnableLogProgress", func() { + if _, _err_ := failpoint.Eval(_curpkg_("EnableLogProgress")); _err_ == nil { time.Sleep(1 * time.Second) tctx.L().Debug("progress ready, sleep 1s") - }) + } _ = baseConn.DBConn.Close() if err := wg.Wait(); err != nil { summary.CollectFailureUnit("dump table data", err) @@ -357,10 +357,10 @@ func (d *Dumper) startWriters(tctx *tcontext.Context, wg *errgroup.Group, taskCh // tctx.L().Debug("finished dumping table data", // 
zap.String("database", td.Meta.DatabaseName()), // zap.String("table", td.Meta.TableName())) - failpoint.Inject("EnableLogProgress", func() { + if _, _err_ := failpoint.Eval(_curpkg_("EnableLogProgress")); _err_ == nil { time.Sleep(1 * time.Second) tctx.L().Debug("EnableLogProgress, sleep 1s") - }) + } } }) writer.setFinishTaskCallBack(func(task Task) { diff --git a/dumpling/export/dump.go__failpoint_stash__ b/dumpling/export/dump.go__failpoint_stash__ new file mode 100644 index 0000000000000..0d94c814eaa77 --- /dev/null +++ b/dumpling/export/dump.go__failpoint_stash__ @@ -0,0 +1,1704 @@ +// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. + +package export + +import ( + "bytes" + "context" + "database/sql" + "database/sql/driver" + "encoding/hex" + "fmt" + "math/big" + "net" + "slices" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/coreos/go-semver/semver" + // import mysql driver + "github.com/go-sql-driver/mysql" + "github.com/google/uuid" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + pclog "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/br/pkg/summary" + "github.com/pingcap/tidb/br/pkg/version" + "github.com/pingcap/tidb/dumpling/cli" + tcontext "github.com/pingcap/tidb/dumpling/context" + "github.com/pingcap/tidb/dumpling/log" + infoschema "github.com/pingcap/tidb/pkg/infoschema/context" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/format" + "github.com/pingcap/tidb/pkg/store/helper" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/codec" + pd "github.com/tikv/pd/client" + gatomic "go.uber.org/atomic" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" +) + +var openDBFunc = openDB + +var errEmptyHandleVals = errors.New("empty handleVals for TiDB table") + +// After TiDB v6.2.0 we always enable tidb_enable_paging by default. 
+// see https://docs.pingcap.com/zh/tidb/dev/system-variables#tidb_enable_paging-%E4%BB%8E-v540-%E7%89%88%E6%9C%AC%E5%BC%80%E5%A7%8B%E5%BC%95%E5%85%A5 +var enablePagingVersion = semver.New("6.2.0") + +// Dumper is the dump progress structure +type Dumper struct { + tctx *tcontext.Context + cancelCtx context.CancelFunc + conf *Config + metrics *metrics + + extStore storage.ExternalStorage + dbHandle *sql.DB + + tidbPDClientForGC pd.Client + selectTiDBTableRegionFunc func(tctx *tcontext.Context, conn *BaseConn, meta TableMeta) (pkFields []string, pkVals [][]string, err error) + totalTables int64 + charsetAndDefaultCollationMap map[string]string + + speedRecorder *SpeedRecorder +} + +// NewDumper returns a new Dumper +func NewDumper(ctx context.Context, conf *Config) (*Dumper, error) { + failpoint.Inject("setExtStorage", func(val failpoint.Value) { + path := val.(string) + b, err := storage.ParseBackend(path, nil) + if err != nil { + panic(err) + } + s, err := storage.New(context.Background(), b, &storage.ExternalStorageOptions{}) + if err != nil { + panic(err) + } + conf.ExtStorage = s + }) + + tctx, cancelFn := tcontext.Background().WithContext(ctx).WithCancel() + d := &Dumper{ + tctx: tctx, + conf: conf, + cancelCtx: cancelFn, + selectTiDBTableRegionFunc: selectTiDBTableRegion, + speedRecorder: NewSpeedRecorder(), + } + + var err error + + d.metrics = newMetrics(conf.PromFactory, conf.Labels) + d.metrics.registerTo(conf.PromRegistry) + defer func() { + if err != nil { + d.metrics.unregisterFrom(conf.PromRegistry) + } + }() + + err = adjustConfig(conf, + buildTLSConfig, + validateSpecifiedSQL, + adjustFileFormat) + if err != nil { + return nil, err + } + failpoint.Inject("SetIOTotalBytes", func(_ failpoint.Value) { + d.conf.IOTotalBytes = gatomic.NewUint64(0) + d.conf.Net = uuid.New().String() + go func() { + for { + time.Sleep(10 * time.Millisecond) + d.tctx.L().Logger.Info("IOTotalBytes", zap.Uint64("IOTotalBytes", d.conf.IOTotalBytes.Load())) + } + }() + }) + + err = runSteps(d, + initLogger, + createExternalStore, + startHTTPService, + openSQLDB, + detectServerInfo, + resolveAutoConsistency, + + validateResolveAutoConsistency, + tidbSetPDClientForGC, + tidbGetSnapshot, + tidbStartGCSavepointUpdateService, + + setSessionParam) + return d, err +} + +// Dump dumps table from database +// nolint: gocyclo +func (d *Dumper) Dump() (dumpErr error) { + initColTypeRowReceiverMap() + var ( + conn *sql.Conn + err error + conCtrl ConsistencyController + ) + tctx, conf, pool := d.tctx, d.conf, d.dbHandle + tctx.L().Info("begin to run Dump", zap.Stringer("conf", conf)) + m := newGlobalMetadata(tctx, d.extStore, conf.Snapshot) + repeatableRead := needRepeatableRead(conf.ServerInfo.ServerType, conf.Consistency) + defer func() { + if dumpErr == nil { + _ = m.writeGlobalMetaData() + } + }() + + // for consistency lock, we should get table list at first to generate the lock tables SQL + if conf.Consistency == ConsistencyTypeLock { + conn, err = createConnWithConsistency(tctx, pool, repeatableRead) + if err != nil { + return errors.Trace(err) + } + if err = prepareTableListToDump(tctx, conf, conn); err != nil { + _ = conn.Close() + return err + } + _ = conn.Close() + } + + conCtrl, err = NewConsistencyController(tctx, conf, pool) + if err != nil { + return err + } + if err = conCtrl.Setup(tctx); err != nil { + return errors.Trace(err) + } + // To avoid lock is not released + defer func() { + err = conCtrl.TearDown(tctx) + if err != nil { + tctx.L().Warn("fail to tear down consistency controller", 
zap.Error(err)) + } + }() + + metaConn, err := createConnWithConsistency(tctx, pool, repeatableRead) + if err != nil { + return err + } + defer func() { + _ = metaConn.Close() + }() + m.recordStartTime(time.Now()) + // for consistency lock, we can write snapshot info after all tables are locked. + // the binlog pos may changed because there is still possible write between we lock tables and write master status. + // but for the locked tables doing replication that starts from metadata is safe. + // for consistency flush, record snapshot after whole tables are locked. The recorded meta info is exactly the locked snapshot. + // for consistency snapshot, we should use the snapshot that we get/set at first in metadata. TiDB will assure the snapshot of TSO. + // for consistency none, the binlog pos in metadata might be earlier than dumped data. We need to enable safe-mode to assure data safety. + err = m.recordGlobalMetaData(metaConn, conf.ServerInfo.ServerType, false) + if err != nil { + tctx.L().Info("get global metadata failed", log.ShortError(err)) + } + + if d.conf.CollationCompatible == StrictCollationCompatible { + //init charset and default collation map + d.charsetAndDefaultCollationMap, err = GetCharsetAndDefaultCollation(tctx.Context, metaConn) + if err != nil { + return err + } + } + + // for other consistencies, we should get table list after consistency is set up and GlobalMetaData is cached + if conf.Consistency != ConsistencyTypeLock { + if err = prepareTableListToDump(tctx, conf, metaConn); err != nil { + return err + } + } + if err = d.renewSelectTableRegionFuncForLowerTiDB(tctx); err != nil { + tctx.L().Info("cannot update select table region info for TiDB", log.ShortError(err)) + } + + atomic.StoreInt64(&d.totalTables, int64(calculateTableCount(conf.Tables))) + + rebuildMetaConn := func(conn *sql.Conn, updateMeta bool) (*sql.Conn, error) { + _ = conn.Raw(func(any) error { + // return an `ErrBadConn` to ensure close the connection, but do not put it back to the pool. + // if we choose to use `Close`, it will always put the connection back to the pool. + return driver.ErrBadConn + }) + + newConn, err1 := createConnWithConsistency(tctx, pool, repeatableRead) + if err1 != nil { + return conn, errors.Trace(err1) + } + conn = newConn + // renew the master status after connection. dm can't close safe-mode until dm reaches current pos + if updateMeta && conf.PosAfterConnect { + err1 = m.recordGlobalMetaData(conn, conf.ServerInfo.ServerType, true) + if err1 != nil { + return conn, errors.Trace(err1) + } + } + return conn, nil + } + + rebuildConn := func(conn *sql.Conn, updateMeta bool) (*sql.Conn, error) { + // make sure that the lock connection is still alive + err1 := conCtrl.PingContext(tctx) + if err1 != nil { + return conn, errors.Trace(err1) + } + return rebuildMetaConn(conn, updateMeta) + } + + chanSize := defaultTaskChannelCapacity + failpoint.Inject("SmallDumpChanSize", func() { + chanSize = 1 + }) + taskIn, taskOut := infiniteChan[Task]() + // todo: refine metrics + AddGauge(d.metrics.taskChannelCapacity, float64(chanSize)) + wg, writingCtx := errgroup.WithContext(tctx) + writerCtx := tctx.WithContext(writingCtx) + writers, tearDownWriters, err := d.startWriters(writerCtx, wg, taskOut, rebuildConn) + if err != nil { + return err + } + defer tearDownWriters() + + if conf.TransactionalConsistency { + if conf.Consistency == ConsistencyTypeFlush || conf.Consistency == ConsistencyTypeLock { + tctx.L().Info("All the dumping transactions have started. 
Start to unlock tables") + } + if err = conCtrl.TearDown(tctx); err != nil { + return errors.Trace(err) + } + } + // Inject consistency failpoint test after we release the table lock + failpoint.Inject("ConsistencyCheck", nil) + + if conf.PosAfterConnect { + // record again, to provide a location to exit safe mode for DM + err = m.recordGlobalMetaData(metaConn, conf.ServerInfo.ServerType, true) + if err != nil { + tctx.L().Info("get global metadata (after connection pool established) failed", log.ShortError(err)) + } + } + + summary.SetLogCollector(summary.NewLogCollector(tctx.L().Info)) + summary.SetUnit(summary.BackupUnit) + defer summary.Summary(summary.BackupUnit) + + logProgressCtx, logProgressCancel := tctx.WithCancel() + go d.runLogProgress(logProgressCtx) + defer logProgressCancel() + + tableDataStartTime := time.Now() + + failpoint.Inject("PrintTiDBMemQuotaQuery", func(_ failpoint.Value) { + row := d.dbHandle.QueryRowContext(tctx, "select @@tidb_mem_quota_query;") + var s string + err = row.Scan(&s) + if err != nil { + fmt.Println(errors.Trace(err)) + } else { + fmt.Printf("tidb_mem_quota_query == %s\n", s) + } + }) + baseConn := newBaseConn(metaConn, true, rebuildMetaConn) + + if conf.SQL == "" { + if err = d.dumpDatabases(writerCtx, baseConn, taskIn); err != nil && !errors.ErrorEqual(err, context.Canceled) { + return err + } + } else { + d.dumpSQL(writerCtx, baseConn, taskIn) + } + d.metrics.progressReady.Store(true) + close(taskIn) + failpoint.Inject("EnableLogProgress", func() { + time.Sleep(1 * time.Second) + tctx.L().Debug("progress ready, sleep 1s") + }) + _ = baseConn.DBConn.Close() + if err := wg.Wait(); err != nil { + summary.CollectFailureUnit("dump table data", err) + return errors.Trace(err) + } + summary.CollectSuccessUnit("dump cost", countTotalTask(writers), time.Since(tableDataStartTime)) + + summary.SetSuccessStatus(true) + m.recordFinishTime(time.Now()) + return nil +} + +func (d *Dumper) startWriters(tctx *tcontext.Context, wg *errgroup.Group, taskChan <-chan Task, + rebuildConnFn func(*sql.Conn, bool) (*sql.Conn, error)) ([]*Writer, func(), error) { + conf, pool := d.conf, d.dbHandle + writers := make([]*Writer, conf.Threads) + for i := 0; i < conf.Threads; i++ { + conn, err := createConnWithConsistency(tctx, pool, needRepeatableRead(conf.ServerInfo.ServerType, conf.Consistency)) + if err != nil { + return nil, func() {}, err + } + writer := NewWriter(tctx, int64(i), conf, conn, d.extStore, d.metrics) + writer.rebuildConnFn = rebuildConnFn + writer.setFinishTableCallBack(func(task Task) { + if _, ok := task.(*TaskTableData); ok { + IncCounter(d.metrics.finishedTablesCounter) + // FIXME: actually finishing the last chunk doesn't means this table is 'finished'. + // We can call this table is 'finished' if all its chunks are finished. + // Comment this log now to avoid ambiguity. 
+ // tctx.L().Debug("finished dumping table data", + // zap.String("database", td.Meta.DatabaseName()), + // zap.String("table", td.Meta.TableName())) + failpoint.Inject("EnableLogProgress", func() { + time.Sleep(1 * time.Second) + tctx.L().Debug("EnableLogProgress, sleep 1s") + }) + } + }) + writer.setFinishTaskCallBack(func(task Task) { + IncGauge(d.metrics.taskChannelCapacity) + if td, ok := task.(*TaskTableData); ok { + d.metrics.completedChunks.Add(1) + tctx.L().Debug("finish dumping table data task", + zap.String("database", td.Meta.DatabaseName()), + zap.String("table", td.Meta.TableName()), + zap.Int("chunkIdx", td.ChunkIndex)) + } + }) + wg.Go(func() error { + return writer.run(taskChan) + }) + writers[i] = writer + } + tearDown := func() { + for _, w := range writers { + _ = w.conn.Close() + } + } + return writers, tearDown, nil +} + +func (d *Dumper) dumpDatabases(tctx *tcontext.Context, metaConn *BaseConn, taskChan chan<- Task) error { + conf := d.conf + allTables := conf.Tables + + // policy should be created before database + // placement policy in other server type can be different, so we only handle the tidb server + if conf.ServerInfo.ServerType == version.ServerTypeTiDB { + policyNames, err := ListAllPlacementPolicyNames(tctx, metaConn) + if err != nil { + errCause := errors.Cause(err) + if mysqlErr, ok := errCause.(*mysql.MySQLError); ok && mysqlErr.Number == ErrNoSuchTable { + // some old tidb version and other server type doesn't support placement rules, we can skip it. + tctx.L().Debug("cannot dump placement policy, maybe the server doesn't support it", log.ShortError(err)) + } else { + tctx.L().Warn("fail to dump placement policy: ", log.ShortError(err)) + } + } + for _, policy := range policyNames { + createPolicySQL, err := ShowCreatePlacementPolicy(tctx, metaConn, policy) + if err != nil { + return errors.Trace(err) + } + wrappedCreatePolicySQL := fmt.Sprintf("/*T![placement] %s */", createPolicySQL) + task := NewTaskPolicyMeta(policy, wrappedCreatePolicySQL) + ctxDone := d.sendTaskToChan(tctx, task, taskChan) + if ctxDone { + return tctx.Err() + } + } + } + + parser1 := parser.New() + for dbName, tables := range allTables { + if !conf.NoSchemas { + createDatabaseSQL, err := ShowCreateDatabase(tctx, metaConn, dbName) + if err != nil { + return errors.Trace(err) + } + + // adjust db collation + createDatabaseSQL, err = adjustDatabaseCollation(tctx, d.conf.CollationCompatible, parser1, createDatabaseSQL, d.charsetAndDefaultCollationMap) + if err != nil { + return errors.Trace(err) + } + + task := NewTaskDatabaseMeta(dbName, createDatabaseSQL) + ctxDone := d.sendTaskToChan(tctx, task, taskChan) + if ctxDone { + return tctx.Err() + } + } + + for _, table := range tables { + tctx.L().Debug("start dumping table...", zap.String("database", dbName), + zap.String("table", table.Name)) + meta, err := dumpTableMeta(tctx, conf, metaConn, dbName, table) + if err != nil { + return errors.Trace(err) + } + + if !conf.NoSchemas { + switch table.Type { + case TableTypeView: + task := NewTaskViewMeta(dbName, table.Name, meta.ShowCreateTable(), meta.ShowCreateView()) + ctxDone := d.sendTaskToChan(tctx, task, taskChan) + if ctxDone { + return tctx.Err() + } + case TableTypeSequence: + task := NewTaskSequenceMeta(dbName, table.Name, meta.ShowCreateTable()) + ctxDone := d.sendTaskToChan(tctx, task, taskChan) + if ctxDone { + return tctx.Err() + } + default: + // adjust table collation + newCreateSQL, err := adjustTableCollation(tctx, d.conf.CollationCompatible, parser1, 
meta.ShowCreateTable(), d.charsetAndDefaultCollationMap) + if err != nil { + return errors.Trace(err) + } + meta.(*tableMeta).showCreateTable = newCreateSQL + + task := NewTaskTableMeta(dbName, table.Name, meta.ShowCreateTable()) + ctxDone := d.sendTaskToChan(tctx, task, taskChan) + if ctxDone { + return tctx.Err() + } + } + } + if table.Type == TableTypeBase { + err = d.dumpTableData(tctx, metaConn, meta, taskChan) + if err != nil { + return errors.Trace(err) + } + } + } + } + return nil +} + +// adjustDatabaseCollation adjusts db collation and return new create sql and collation +func adjustDatabaseCollation(tctx *tcontext.Context, collationCompatible string, parser *parser.Parser, originSQL string, charsetAndDefaultCollationMap map[string]string) (string, error) { + if collationCompatible != StrictCollationCompatible { + return originSQL, nil + } + stmt, err := parser.ParseOneStmt(originSQL, "", "") + if err != nil { + tctx.L().Warn("parse create database error, maybe tidb parser doesn't support it", zap.String("originSQL", originSQL), log.ShortError(err)) + return originSQL, nil + } + createStmt, ok := stmt.(*ast.CreateDatabaseStmt) + if !ok { + return originSQL, nil + } + var charset string + for _, createOption := range createStmt.Options { + // already have 'Collation' + if createOption.Tp == ast.DatabaseOptionCollate { + return originSQL, nil + } + if createOption.Tp == ast.DatabaseOptionCharset { + charset = createOption.Value + } + } + // get db collation + collation, ok := charsetAndDefaultCollationMap[strings.ToLower(charset)] + if !ok { + tctx.L().Warn("not found database charset default collation.", zap.String("originSQL", originSQL), zap.String("charset", strings.ToLower(charset))) + return originSQL, nil + } + // add collation + createStmt.Options = append(createStmt.Options, &ast.DatabaseOption{Tp: ast.DatabaseOptionCollate, Value: collation}) + // rewrite sql + var b []byte + bf := bytes.NewBuffer(b) + err = createStmt.Restore(&format.RestoreCtx{ + Flags: format.DefaultRestoreFlags | format.RestoreTiDBSpecialComment, + In: bf, + }) + if err != nil { + return "", errors.Trace(err) + } + return bf.String(), nil +} + +// adjustTableCollation adjusts table collation +func adjustTableCollation(tctx *tcontext.Context, collationCompatible string, parser *parser.Parser, originSQL string, charsetAndDefaultCollationMap map[string]string) (string, error) { + if collationCompatible != StrictCollationCompatible { + return originSQL, nil + } + stmt, err := parser.ParseOneStmt(originSQL, "", "") + if err != nil { + tctx.L().Warn("parse create table error, maybe tidb parser doesn't support it", zap.String("originSQL", originSQL), log.ShortError(err)) + return originSQL, nil + } + createStmt, ok := stmt.(*ast.CreateTableStmt) + if !ok { + return originSQL, nil + } + var charset string + var collation string + for _, createOption := range createStmt.Options { + // already have 'Collation' + if createOption.Tp == ast.TableOptionCollate { + collation = createOption.StrValue + break + } + if createOption.Tp == ast.TableOptionCharset { + charset = createOption.StrValue + } + } + + if collation == "" && charset != "" { + collation, ok := charsetAndDefaultCollationMap[strings.ToLower(charset)] + if !ok { + tctx.L().Warn("not found table charset default collation.", zap.String("originSQL", originSQL), zap.String("charset", strings.ToLower(charset))) + return originSQL, nil + } + + // add collation + createStmt.Options = append(createStmt.Options, &ast.TableOption{Tp: ast.TableOptionCollate, 
StrValue: collation}) + } + + // adjust columns collation + adjustColumnsCollation(tctx, createStmt, charsetAndDefaultCollationMap) + + // rewrite sql + var b []byte + bf := bytes.NewBuffer(b) + err = createStmt.Restore(&format.RestoreCtx{ + Flags: format.DefaultRestoreFlags | format.RestoreTiDBSpecialComment, + In: bf, + }) + if err != nil { + return "", errors.Trace(err) + } + return bf.String(), nil +} + +// adjustColumnsCollation adds column's collation. +func adjustColumnsCollation(tctx *tcontext.Context, createStmt *ast.CreateTableStmt, charsetAndDefaultCollationMap map[string]string) { +ColumnLoop: + for _, col := range createStmt.Cols { + for _, options := range col.Options { + // already have 'Collation' + if options.Tp == ast.ColumnOptionCollate { + continue ColumnLoop + } + } + fieldType := col.Tp + if fieldType.GetCollate() != "" { + continue + } + if fieldType.GetCharset() != "" { + // just have charset + collation, ok := charsetAndDefaultCollationMap[strings.ToLower(fieldType.GetCharset())] + if !ok { + tctx.L().Warn("not found charset default collation for column.", zap.String("table", createStmt.Table.Name.String()), zap.String("column", col.Name.String()), zap.String("charset", strings.ToLower(fieldType.GetCharset()))) + continue + } + fieldType.SetCollate(collation) + } + } +} + +func (d *Dumper) dumpTableData(tctx *tcontext.Context, conn *BaseConn, meta TableMeta, taskChan chan<- Task) error { + conf := d.conf + if conf.NoData { + return nil + } + + // Update total rows + fieldName, _ := pickupPossibleField(tctx, meta, conn) + c := estimateCount(tctx, meta.DatabaseName(), meta.TableName(), conn, fieldName, conf) + AddCounter(d.metrics.estimateTotalRowsCounter, float64(c)) + + if conf.Rows == UnspecifiedSize { + return d.sequentialDumpTable(tctx, conn, meta, taskChan) + } + return d.concurrentDumpTable(tctx, conn, meta, taskChan) +} + +func (d *Dumper) buildConcatTask(tctx *tcontext.Context, conn *BaseConn, meta TableMeta) (*TaskTableData, error) { + tableChan := make(chan Task, 128) + errCh := make(chan error, 1) + go func() { + // adjust rows to suitable rows for this table + d.conf.Rows = GetSuitableRows(meta.AvgRowLength()) + err := d.concurrentDumpTable(tctx, conn, meta, tableChan) + d.conf.Rows = UnspecifiedSize + if err != nil { + errCh <- err + } else { + close(errCh) + } + }() + tableDataArr := make([]*tableData, 0) + handleSubTask := func(task Task) { + tableTask, ok := task.(*TaskTableData) + if !ok { + tctx.L().Warn("unexpected task when splitting table chunks", zap.String("task", tableTask.Brief())) + return + } + tableDataInst, ok := tableTask.Data.(*tableData) + if !ok { + tctx.L().Warn("unexpected task.Data when splitting table chunks", zap.String("task", tableTask.Brief())) + return + } + tableDataArr = append(tableDataArr, tableDataInst) + d.metrics.totalChunks.Dec() + } + for { + select { + case err, ok := <-errCh: + if !ok { + // make sure all the subtasks in tableChan are handled + for len(tableChan) > 0 { + task := <-tableChan + handleSubTask(task) + } + if len(tableDataArr) <= 1 { + return nil, nil + } + queries := make([]string, 0, len(tableDataArr)) + colLen := tableDataArr[0].colLen + for _, tableDataInst := range tableDataArr { + queries = append(queries, tableDataInst.query) + if colLen != tableDataInst.colLen { + tctx.L().Warn("colLen varies for same table", + zap.Int("oldColLen", colLen), + zap.String("oldQuery", queries[0]), + zap.Int("newColLen", tableDataInst.colLen), + zap.String("newQuery", tableDataInst.query)) + return nil, nil + } + } 
+ return d.newTaskTableData(meta, newMultiQueriesChunk(queries, colLen), 0, 1), nil + } + return nil, err + case task := <-tableChan: + handleSubTask(task) + } + } +} + +func (d *Dumper) dumpWholeTableDirectly(tctx *tcontext.Context, meta TableMeta, taskChan chan<- Task, partition, orderByClause string, currentChunk, totalChunks int) error { + conf := d.conf + tableIR := SelectAllFromTable(conf, meta, partition, orderByClause) + task := d.newTaskTableData(meta, tableIR, currentChunk, totalChunks) + ctxDone := d.sendTaskToChan(tctx, task, taskChan) + if ctxDone { + return tctx.Err() + } + return nil +} + +func (d *Dumper) sequentialDumpTable(tctx *tcontext.Context, conn *BaseConn, meta TableMeta, taskChan chan<- Task) error { + conf := d.conf + if conf.ServerInfo.ServerType == version.ServerTypeTiDB { + task, err := d.buildConcatTask(tctx, conn, meta) + if err != nil { + return errors.Trace(err) + } + if task != nil { + ctxDone := d.sendTaskToChan(tctx, task, taskChan) + if ctxDone { + return tctx.Err() + } + return nil + } + tctx.L().Info("didn't build tidb concat sqls, will select all from table now", + zap.String("database", meta.DatabaseName()), + zap.String("table", meta.TableName())) + } + orderByClause, err := buildOrderByClause(tctx, conf, conn, meta.DatabaseName(), meta.TableName(), meta.HasImplicitRowID()) + if err != nil { + return err + } + return d.dumpWholeTableDirectly(tctx, meta, taskChan, "", orderByClause, 0, 1) +} + +// concurrentDumpTable tries to split table into several chunks to dump +func (d *Dumper) concurrentDumpTable(tctx *tcontext.Context, conn *BaseConn, meta TableMeta, taskChan chan<- Task) error { + conf := d.conf + db, tbl := meta.DatabaseName(), meta.TableName() + if conf.ServerInfo.ServerType == version.ServerTypeTiDB && + conf.ServerInfo.ServerVersion != nil && + (conf.ServerInfo.ServerVersion.Compare(*tableSampleVersion) >= 0 || + (conf.ServerInfo.HasTiKV && conf.ServerInfo.ServerVersion.Compare(*decodeRegionVersion) >= 0)) { + err := d.concurrentDumpTiDBTables(tctx, conn, meta, taskChan) + // don't retry on context error and successful tasks + if err2 := errors.Cause(err); err2 == nil || err2 == context.DeadlineExceeded || err2 == context.Canceled { + return err + } else if err2 != errEmptyHandleVals { + tctx.L().Info("fallback to concurrent dump tables using rows due to some problem. This won't influence the whole dump process", + zap.String("database", db), zap.String("table", tbl), log.ShortError(err)) + } + } + + orderByClause, err := buildOrderByClause(tctx, conf, conn, db, tbl, meta.HasImplicitRowID()) + if err != nil { + return err + } + + field, err := pickupPossibleField(tctx, meta, conn) + if err != nil || field == "" { + // skip split chunk logic if not found proper field + tctx.L().Info("fallback to sequential dump due to no proper field. This won't influence the whole dump process", + zap.String("database", db), zap.String("table", tbl), log.ShortError(err)) + return d.dumpWholeTableDirectly(tctx, meta, taskChan, "", orderByClause, 0, 1) + } + + count := estimateCount(d.tctx, db, tbl, conn, field, conf) + tctx.L().Info("get estimated rows count", + zap.String("database", db), + zap.String("table", tbl), + zap.Uint64("estimateCount", count)) + if count < conf.Rows { + // skip chunk logic if estimates are low + tctx.L().Info("fallback to sequential dump due to estimate count < rows. 
This won't influence the whole dump process", + zap.Uint64("estimate count", count), + zap.Uint64("conf.rows", conf.Rows), + zap.String("database", db), + zap.String("table", tbl)) + return d.dumpWholeTableDirectly(tctx, meta, taskChan, "", orderByClause, 0, 1) + } + + min, max, err := d.selectMinAndMaxIntValue(tctx, conn, db, tbl, field) + if err != nil { + tctx.L().Info("fallback to sequential dump due to cannot get bounding values. This won't influence the whole dump process", + log.ShortError(err)) + return d.dumpWholeTableDirectly(tctx, meta, taskChan, "", orderByClause, 0, 1) + } + tctx.L().Debug("get int bounding values", + zap.String("lower", min.String()), + zap.String("upper", max.String())) + + // every chunk would have eventual adjustments + estimatedChunks := count / conf.Rows + estimatedStep := new(big.Int).Sub(max, min).Uint64()/estimatedChunks + 1 + bigEstimatedStep := new(big.Int).SetUint64(estimatedStep) + cutoff := new(big.Int).Set(min) + totalChunks := estimatedChunks + if estimatedStep == 1 { + totalChunks = new(big.Int).Sub(max, min).Uint64() + 1 + } + + selectField, selectLen := meta.SelectedField(), meta.SelectedLen() + + chunkIndex := 0 + nullValueCondition := "" + if conf.Where == "" { + nullValueCondition = fmt.Sprintf("`%s` IS NULL OR ", escapeString(field)) + } + for max.Cmp(cutoff) >= 0 { + nextCutOff := new(big.Int).Add(cutoff, bigEstimatedStep) + where := fmt.Sprintf("%s(`%s` >= %d AND `%s` < %d)", nullValueCondition, escapeString(field), cutoff, escapeString(field), nextCutOff) + query := buildSelectQuery(db, tbl, selectField, "", buildWhereCondition(conf, where), orderByClause) + if len(nullValueCondition) > 0 { + nullValueCondition = "" + } + task := d.newTaskTableData(meta, newTableData(query, selectLen, false), chunkIndex, int(totalChunks)) + ctxDone := d.sendTaskToChan(tctx, task, taskChan) + if ctxDone { + return tctx.Err() + } + cutoff = nextCutOff + chunkIndex++ + } + return nil +} + +func (d *Dumper) sendTaskToChan(tctx *tcontext.Context, task Task, taskChan chan<- Task) (ctxDone bool) { + select { + case <-tctx.Done(): + return true + case taskChan <- task: + tctx.L().Debug("send task to writer", + zap.String("task", task.Brief())) + DecGauge(d.metrics.taskChannelCapacity) + return false + } +} + +func (d *Dumper) selectMinAndMaxIntValue(tctx *tcontext.Context, conn *BaseConn, db, tbl, field string) (*big.Int, *big.Int, error) { + conf, zero := d.conf, &big.Int{} + query := fmt.Sprintf("SELECT MIN(`%s`),MAX(`%s`) FROM `%s`.`%s`", + escapeString(field), escapeString(field), escapeString(db), escapeString(tbl)) + if conf.Where != "" { + query = fmt.Sprintf("%s WHERE %s", query, conf.Where) + } + tctx.L().Debug("split chunks", zap.String("query", query)) + + var smin sql.NullString + var smax sql.NullString + err := conn.QuerySQL(tctx, func(rows *sql.Rows) error { + err := rows.Scan(&smin, &smax) + rows.Close() + return err + }, func() {}, query) + if err != nil { + return zero, zero, errors.Annotatef(err, "can't get min/max values to split chunks, query: %s", query) + } + if !smax.Valid || !smin.Valid { + // found no data + return zero, zero, errors.Errorf("no invalid min/max value found in query %s", query) + } + + max := new(big.Int) + min := new(big.Int) + var ok bool + if max, ok = max.SetString(smax.String, 10); !ok { + return zero, zero, errors.Errorf("fail to convert max value %s in query %s", smax.String, query) + } + if min, ok = min.SetString(smin.String, 10); !ok { + return zero, zero, errors.Errorf("fail to convert min value %s in query 
%s", smin.String, query) + } + return min, max, nil +} + +func (d *Dumper) concurrentDumpTiDBTables(tctx *tcontext.Context, conn *BaseConn, meta TableMeta, taskChan chan<- Task) error { + db, tbl := meta.DatabaseName(), meta.TableName() + + var ( + handleColNames []string + handleVals [][]string + err error + ) + // for TiDB v5.0+, we can use table sample directly + if d.conf.ServerInfo.ServerVersion.Compare(*tableSampleVersion) >= 0 { + tctx.L().Debug("dumping TiDB tables with TABLESAMPLE", + zap.String("database", db), zap.String("table", tbl)) + handleColNames, handleVals, err = selectTiDBTableSample(tctx, conn, meta) + } else { + // for TiDB v3.0+, we can use table region decode in TiDB directly + tctx.L().Debug("dumping TiDB tables with TABLE REGIONS", + zap.String("database", db), zap.String("table", tbl)) + var partitions []string + if d.conf.ServerInfo.ServerVersion.Compare(*gcSafePointVersion) >= 0 { + partitions, err = GetPartitionNames(tctx, conn, db, tbl) + } + if err == nil { + if len(partitions) != 0 { + return d.concurrentDumpTiDBPartitionTables(tctx, conn, meta, taskChan, partitions) + } + handleColNames, handleVals, err = d.selectTiDBTableRegionFunc(tctx, conn, meta) + } + } + if err != nil { + return err + } + return d.sendConcurrentDumpTiDBTasks(tctx, meta, taskChan, handleColNames, handleVals, "", 0, len(handleVals)+1) +} + +func (d *Dumper) concurrentDumpTiDBPartitionTables(tctx *tcontext.Context, conn *BaseConn, meta TableMeta, taskChan chan<- Task, partitions []string) error { + db, tbl := meta.DatabaseName(), meta.TableName() + tctx.L().Debug("dumping TiDB tables with TABLE REGIONS for partition table", + zap.String("database", db), zap.String("table", tbl), zap.Strings("partitions", partitions)) + + startChunkIdx := 0 + totalChunk := 0 + cachedHandleVals := make([][][]string, len(partitions)) + + handleColNames, _, err := selectTiDBRowKeyFields(tctx, conn, meta, checkTiDBTableRegionPkFields) + if err != nil { + return err + } + // cache handleVals here to calculate the total chunks + for i, partition := range partitions { + handleVals, err := selectTiDBPartitionRegion(tctx, conn, db, tbl, partition) + if err != nil { + return err + } + totalChunk += len(handleVals) + 1 + cachedHandleVals[i] = handleVals + } + for i, partition := range partitions { + err := d.sendConcurrentDumpTiDBTasks(tctx, meta, taskChan, handleColNames, cachedHandleVals[i], partition, startChunkIdx, totalChunk) + if err != nil { + return err + } + startChunkIdx += len(cachedHandleVals[i]) + 1 + } + return nil +} + +func (d *Dumper) sendConcurrentDumpTiDBTasks(tctx *tcontext.Context, + meta TableMeta, taskChan chan<- Task, + handleColNames []string, handleVals [][]string, partition string, + startChunkIdx, totalChunk int) error { + db, tbl := meta.DatabaseName(), meta.TableName() + if len(handleVals) == 0 { + if partition == "" { + // return error to make outside function try using rows method to dump data + return errors.Annotatef(errEmptyHandleVals, "table: `%s`.`%s`", escapeString(db), escapeString(tbl)) + } + return d.dumpWholeTableDirectly(tctx, meta, taskChan, partition, buildOrderByClauseString(handleColNames), startChunkIdx, totalChunk) + } + conf := d.conf + selectField, selectLen := meta.SelectedField(), meta.SelectedLen() + where := buildWhereClauses(handleColNames, handleVals) + orderByClause := buildOrderByClauseString(handleColNames) + + for i, w := range where { + query := buildSelectQuery(db, tbl, selectField, partition, buildWhereCondition(conf, w), orderByClause) + task := 
d.newTaskTableData(meta, newTableData(query, selectLen, false), i+startChunkIdx, totalChunk) + ctxDone := d.sendTaskToChan(tctx, task, taskChan) + if ctxDone { + return tctx.Err() + } + } + return nil +} + +// L returns real logger +func (d *Dumper) L() log.Logger { + return d.tctx.L() +} + +func selectTiDBTableSample(tctx *tcontext.Context, conn *BaseConn, meta TableMeta) (pkFields []string, pkVals [][]string, err error) { + pkFields, pkColTypes, err := selectTiDBRowKeyFields(tctx, conn, meta, nil) + if err != nil { + return nil, nil, errors.Trace(err) + } + + query := buildTiDBTableSampleQuery(pkFields, meta.DatabaseName(), meta.TableName()) + pkValNum := len(pkFields) + var iter SQLRowIter + rowRec := MakeRowReceiver(pkColTypes) + buf := new(bytes.Buffer) + + err = conn.QuerySQL(tctx, func(rows *sql.Rows) error { + if iter == nil { + iter = &rowIter{ + rows: rows, + args: make([]any, pkValNum), + } + } + err = iter.Decode(rowRec) + if err != nil { + return errors.Trace(err) + } + pkValRow := make([]string, 0, pkValNum) + for _, rec := range rowRec.receivers { + rec.WriteToBuffer(buf, true) + pkValRow = append(pkValRow, buf.String()) + buf.Reset() + } + pkVals = append(pkVals, pkValRow) + return nil + }, func() { + if iter != nil { + _ = iter.Close() + iter = nil + } + rowRec = MakeRowReceiver(pkColTypes) + pkVals = pkVals[:0] + buf.Reset() + }, query) + if err == nil && iter != nil && iter.Error() != nil { + err = iter.Error() + } + + return pkFields, pkVals, err +} + +func buildTiDBTableSampleQuery(pkFields []string, dbName, tblName string) string { + template := "SELECT %s FROM `%s`.`%s` TABLESAMPLE REGIONS() ORDER BY %s" + quotaPk := make([]string, len(pkFields)) + for i, s := range pkFields { + quotaPk[i] = fmt.Sprintf("`%s`", escapeString(s)) + } + pks := strings.Join(quotaPk, ",") + return fmt.Sprintf(template, pks, escapeString(dbName), escapeString(tblName), pks) +} + +func selectTiDBRowKeyFields(tctx *tcontext.Context, conn *BaseConn, meta TableMeta, checkPkFields func([]string, []string) error) (pkFields, pkColTypes []string, err error) { + if meta.HasImplicitRowID() { + pkFields, pkColTypes = []string{"_tidb_rowid"}, []string{"BIGINT"} + } else { + pkFields, pkColTypes, err = GetPrimaryKeyAndColumnTypes(tctx, conn, meta) + if err == nil { + if checkPkFields != nil { + err = checkPkFields(pkFields, pkColTypes) + } + } + } + return +} + +func checkTiDBTableRegionPkFields(pkFields, pkColTypes []string) (err error) { + if len(pkFields) != 1 || len(pkColTypes) != 1 { + err = errors.Errorf("unsupported primary key for selectTableRegion. pkFields: [%s], pkColTypes: [%s]", strings.Join(pkFields, ", "), strings.Join(pkColTypes, ", ")) + return + } + if _, ok := dataTypeInt[pkColTypes[0]]; !ok { + err = errors.Errorf("unsupported primary key type for selectTableRegion. pkFields: [%s], pkColTypes: [%s]", strings.Join(pkFields, ", "), strings.Join(pkColTypes, ", ")) + } + return +} + +func selectTiDBTableRegion(tctx *tcontext.Context, conn *BaseConn, meta TableMeta) (pkFields []string, pkVals [][]string, err error) { + pkFields, _, err = selectTiDBRowKeyFields(tctx, conn, meta, checkTiDBTableRegionPkFields) + if err != nil { + return + } + + var ( + startKey, decodedKey sql.NullString + rowID = -1 + ) + const ( + tableRegionSQL = "SELECT START_KEY,tidb_decode_key(START_KEY) from INFORMATION_SCHEMA.TIKV_REGION_STATUS s WHERE s.DB_NAME = ? AND s.TABLE_NAME = ? 
AND IS_INDEX = 0 ORDER BY START_KEY;" + tidbRowID = "_tidb_rowid=" + ) + dbName, tableName := meta.DatabaseName(), meta.TableName() + logger := tctx.L().With(zap.String("database", dbName), zap.String("table", tableName)) + err = conn.QuerySQL(tctx, func(rows *sql.Rows) error { + rowID++ + err = rows.Scan(&startKey, &decodedKey) + if err != nil { + return errors.Trace(err) + } + // first region's start key has no use. It may come from another table or might be invalid + if rowID == 0 { + return nil + } + if !startKey.Valid { + logger.Debug("meet invalid start key", zap.Int("rowID", rowID)) + return nil + } + if !decodedKey.Valid { + logger.Debug("meet invalid decoded start key", zap.Int("rowID", rowID), zap.String("startKey", startKey.String)) + return nil + } + pkVal, err2 := extractTiDBRowIDFromDecodedKey(tidbRowID, decodedKey.String) + if err2 != nil { + logger.Debug("cannot extract pkVal from decoded start key", + zap.Int("rowID", rowID), zap.String("startKey", startKey.String), zap.String("decodedKey", decodedKey.String), log.ShortError(err2)) + } else { + pkVals = append(pkVals, []string{pkVal}) + } + return nil + }, func() { + pkFields = pkFields[:0] + pkVals = pkVals[:0] + }, tableRegionSQL, dbName, tableName) + + return pkFields, pkVals, errors.Trace(err) +} + +func selectTiDBPartitionRegion(tctx *tcontext.Context, conn *BaseConn, dbName, tableName, partition string) (pkVals [][]string, err error) { + var startKeys [][]string + const ( + partitionRegionSQL = "SHOW TABLE `%s`.`%s` PARTITION(`%s`) REGIONS" + regionRowKey = "r_" + ) + logger := tctx.L().With(zap.String("database", dbName), zap.String("table", tableName), zap.String("partition", partition)) + startKeys, err = conn.QuerySQLWithColumns(tctx, []string{"START_KEY"}, fmt.Sprintf(partitionRegionSQL, escapeString(dbName), escapeString(tableName), escapeString(partition))) + if err != nil { + return + } + for rowID, startKey := range startKeys { + if rowID == 0 || len(startKey) != 1 { + continue + } + pkVal, err2 := extractTiDBRowIDFromDecodedKey(regionRowKey, startKey[0]) + if err2 != nil { + logger.Debug("show table region start key doesn't have rowID", + zap.Int("rowID", rowID), zap.String("startKey", startKey[0]), zap.Error(err2)) + } else { + pkVals = append(pkVals, []string{pkVal}) + } + } + + return pkVals, nil +} + +func extractTiDBRowIDFromDecodedKey(indexField, key string) (string, error) { + if p := strings.Index(key, indexField); p != -1 { + p += len(indexField) + return key[p:], nil + } + return "", errors.Errorf("decoded key %s doesn't have %s field", key, indexField) +} + +func getListTableTypeByConf(conf *Config) listTableType { + // use listTableByShowTableStatus by default because it has better performance + listType := listTableByShowTableStatus + if conf.Consistency == ConsistencyTypeLock { + // for consistency lock, we need to build the tables to dump as soon as possible + listType = listTableByInfoSchema + } else if conf.Consistency == ConsistencyTypeFlush && matchMysqlBugversion(conf.ServerInfo) { + // For some buggy versions of mysql, we need a workaround to get a list of table names. 
+ listType = listTableByShowFullTables + } + return listType +} + +func prepareTableListToDump(tctx *tcontext.Context, conf *Config, db *sql.Conn) error { + if conf.SQL != "" { + return nil + } + + ifSeqExists, err := CheckIfSeqExists(db) + if err != nil { + return err + } + var listType listTableType + if ifSeqExists { + listType = listTableByShowFullTables + } else { + listType = getListTableTypeByConf(conf) + } + + if conf.SpecifiedTables { + return updateSpecifiedTablesMeta(tctx, db, conf.Tables, listType) + } + databases, err := prepareDumpingDatabases(tctx, conf, db) + if err != nil { + return err + } + + tableTypes := []TableType{TableTypeBase} + if !conf.NoViews { + tableTypes = append(tableTypes, TableTypeView) + } + if !conf.NoSequences { + tableTypes = append(tableTypes, TableTypeSequence) + } + + conf.Tables, err = ListAllDatabasesTables(tctx, db, databases, listType, tableTypes...) + if err != nil { + return err + } + + filterTables(tctx, conf) + return nil +} + +func dumpTableMeta(tctx *tcontext.Context, conf *Config, conn *BaseConn, db string, table *TableInfo) (TableMeta, error) { + tbl := table.Name + selectField, selectLen, err := buildSelectField(tctx, conn, db, tbl, conf.CompleteInsert) + if err != nil { + return nil, err + } + var ( + colTypes []*sql.ColumnType + hasImplicitRowID bool + ) + if conf.ServerInfo.ServerType == version.ServerTypeTiDB { + hasImplicitRowID, err = SelectTiDBRowID(tctx, conn, db, tbl) + if err != nil { + tctx.L().Info("check implicit rowID failed", zap.String("database", db), zap.String("table", tbl), log.ShortError(err)) + } + } + + // If all columns are generated + if table.Type == TableTypeBase { + if selectField == "" { + colTypes, err = GetColumnTypes(tctx, conn, "*", db, tbl) + } else { + colTypes, err = GetColumnTypes(tctx, conn, selectField, db, tbl) + } + } + if err != nil { + return nil, err + } + + meta := &tableMeta{ + avgRowLength: table.AvgRowLength, + database: db, + table: tbl, + colTypes: colTypes, + selectedField: selectField, + selectedLen: selectLen, + hasImplicitRowID: hasImplicitRowID, + specCmts: getSpecialComments(conf.ServerInfo.ServerType), + } + + if conf.NoSchemas { + return meta, nil + } + switch table.Type { + case TableTypeView: + viewName := table.Name + createTableSQL, createViewSQL, err1 := ShowCreateView(tctx, conn, db, viewName) + if err1 != nil { + return meta, err1 + } + meta.showCreateTable = createTableSQL + meta.showCreateView = createViewSQL + return meta, nil + case TableTypeSequence: + sequenceName := table.Name + createSequenceSQL, err2 := ShowCreateSequence(tctx, conn, db, sequenceName, conf) + if err2 != nil { + return meta, err2 + } + meta.showCreateTable = createSequenceSQL + return meta, nil + } + + createTableSQL, err := ShowCreateTable(tctx, conn, db, tbl) + if err != nil { + return nil, err + } + meta.showCreateTable = createTableSQL + return meta, nil +} + +func (d *Dumper) dumpSQL(tctx *tcontext.Context, metaConn *BaseConn, taskChan chan<- Task) { + conf := d.conf + meta := &tableMeta{} + data := newTableData(conf.SQL, 0, true) + task := d.newTaskTableData(meta, data, 0, 1) + c := detectEstimateRows(tctx, metaConn, fmt.Sprintf("EXPLAIN %s", conf.SQL), []string{"rows", "estRows", "count"}) + AddCounter(d.metrics.estimateTotalRowsCounter, float64(c)) + atomic.StoreInt64(&d.totalTables, int64(1)) + d.sendTaskToChan(tctx, task, taskChan) +} + +func canRebuildConn(consistency string, trxConsistencyOnly bool) bool { + switch consistency { + case ConsistencyTypeLock, ConsistencyTypeFlush: + return 
!trxConsistencyOnly + case ConsistencyTypeSnapshot, ConsistencyTypeNone: + return true + default: + return false + } +} + +// Close closes a Dumper and stop dumping immediately +func (d *Dumper) Close() error { + d.cancelCtx() + d.metrics.unregisterFrom(d.conf.PromRegistry) + if d.dbHandle != nil { + return d.dbHandle.Close() + } + return nil +} + +func runSteps(d *Dumper, steps ...func(*Dumper) error) error { + for _, st := range steps { + err := st(d) + if err != nil { + return err + } + } + return nil +} + +func initLogger(d *Dumper) error { + conf := d.conf + var ( + logger log.Logger + err error + props *pclog.ZapProperties + ) + // conf.Logger != nil means dumpling is used as a library + if conf.Logger != nil { + logger = log.NewAppLogger(conf.Logger) + } else { + logger, props, err = log.InitAppLogger(&log.Config{ + Level: conf.LogLevel, + File: conf.LogFile, + Format: conf.LogFormat, + }) + if err != nil { + return errors.Trace(err) + } + pclog.ReplaceGlobals(logger.Logger, props) + cli.LogLongVersion(logger) + } + d.tctx = d.tctx.WithLogger(logger) + return nil +} + +// createExternalStore is an initialization step of Dumper. +func createExternalStore(d *Dumper) error { + tctx, conf := d.tctx, d.conf + extStore, err := conf.createExternalStorage(tctx) + if err != nil { + return errors.Trace(err) + } + d.extStore = extStore + return nil +} + +// startHTTPService is an initialization step of Dumper. +func startHTTPService(d *Dumper) error { + conf := d.conf + if conf.StatusAddr != "" { + go func() { + err := startDumplingService(d.tctx, conf.StatusAddr) + if err != nil { + d.L().Info("meet error when stopping dumpling http service", log.ShortError(err)) + } + }() + } + return nil +} + +// openSQLDB is an initialization step of Dumper. +func openSQLDB(d *Dumper) error { + if d.conf.IOTotalBytes != nil { + mysql.RegisterDialContext(d.conf.Net, func(ctx context.Context, addr string) (net.Conn, error) { + dial := &net.Dialer{} + conn, err := dial.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, err + } + tcpConn := conn.(*net.TCPConn) + // try https://github.com/go-sql-driver/mysql/blob/bcc459a906419e2890a50fc2c99ea6dd927a88f2/connector.go#L56-L64 + err = tcpConn.SetKeepAlive(true) + if err != nil { + d.tctx.L().Logger.Warn("fail to keep alive", zap.Error(err)) + } + return util.NewTCPConnWithIOCounter(tcpConn, d.conf.IOTotalBytes), nil + }) + } + conf := d.conf + c, err := mysql.NewConnector(conf.GetDriverConfig("")) + if err != nil { + return errors.Trace(err) + } + d.dbHandle = sql.OpenDB(c) + return nil +} + +// detectServerInfo is an initialization step of Dumper. +func detectServerInfo(d *Dumper) error { + db, conf := d.dbHandle, d.conf + versionStr, err := version.FetchVersion(d.tctx.Context, db) + if err != nil { + conf.ServerInfo = ServerInfoUnknown + return err + } + conf.ServerInfo = version.ParseServerInfo(versionStr) + return nil +} + +// resolveAutoConsistency is an initialization step of Dumper. 
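// Illustrative sketch (not from the patched sources): runSteps and adjustConfig above
// simply apply an ordered list of func(*Dumper) error / func(*Config) error steps and
// abort on the first failure. The self-contained analogue below imitates that shape;
// the names pipelineState, step and runAll are hypothetical stand-ins.
package main

import (
	"errors"
	"fmt"
)

type pipelineState struct{ loggerReady bool }

// step mirrors the func(*Dumper) error signature consumed by runSteps.
type step func(*pipelineState) error

// runAll applies the steps in order and stops at the first error,
// matching the behaviour of runSteps/adjustConfig.
func runAll(s *pipelineState, steps ...step) error {
	for _, st := range steps {
		if err := st(s); err != nil {
			return err
		}
	}
	return nil
}

func main() {
	s := &pipelineState{}
	err := runAll(s,
		func(s *pipelineState) error { s.loggerReady = true; return nil }, // analogous to initLogger
		func(s *pipelineState) error { // analogous to a later step that depends on earlier ones
			if !s.loggerReady {
				return errors.New("logger not initialized")
			}
			return nil
		},
	)
	fmt.Println("pipeline result:", err) // <nil> when every step succeeds
}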
+func resolveAutoConsistency(d *Dumper) error { + conf := d.conf + if conf.Consistency != ConsistencyTypeAuto { + return nil + } + switch conf.ServerInfo.ServerType { + case version.ServerTypeTiDB: + conf.Consistency = ConsistencyTypeSnapshot + case version.ServerTypeMySQL, version.ServerTypeMariaDB: + conf.Consistency = ConsistencyTypeFlush + default: + conf.Consistency = ConsistencyTypeNone + } + + if conf.Consistency == ConsistencyTypeFlush { + timeout := time.Second * 5 + ctx, cancel := context.WithTimeout(d.tctx.Context, timeout) + defer cancel() + + // probe if upstream has enough privilege to FLUSH TABLE WITH READ LOCK + conn, err := d.dbHandle.Conn(ctx) + if err != nil { + return errors.New("failed to get connection from db pool after 5 seconds") + } + //nolint: errcheck + defer conn.Close() + + err = FlushTableWithReadLock(d.tctx, conn) + //nolint: errcheck + defer UnlockTables(d.tctx, conn) + if err != nil { + // fallback to ConsistencyTypeLock + d.tctx.L().Warn("error when use FLUSH TABLE WITH READ LOCK, fallback to LOCK TABLES", + zap.Error(err)) + conf.Consistency = ConsistencyTypeLock + } + } + return nil +} + +func validateResolveAutoConsistency(d *Dumper) error { + conf := d.conf + if conf.Consistency != ConsistencyTypeSnapshot && conf.Snapshot != "" { + return errors.Errorf("can't specify --snapshot when --consistency isn't snapshot, resolved consistency: %s", conf.Consistency) + } + return nil +} + +// tidbSetPDClientForGC is an initialization step of Dumper. +func tidbSetPDClientForGC(d *Dumper) error { + tctx, si, pool := d.tctx, d.conf.ServerInfo, d.dbHandle + if si.ServerType != version.ServerTypeTiDB || + si.ServerVersion == nil || + si.ServerVersion.Compare(*gcSafePointVersion) < 0 { + return nil + } + pdAddrs, err := GetPdAddrs(tctx, pool) + if err != nil { + tctx.L().Info("meet some problem while fetching pd addrs. This won't affect dump process", log.ShortError(err)) + return nil + } + if len(pdAddrs) > 0 { + doPdGC, err := checkSameCluster(tctx, pool, pdAddrs) + if err != nil { + tctx.L().Info("meet error while check whether fetched pd addr and TiDB belong to one cluster. This won't affect dump process", log.ShortError(err), zap.Strings("pdAddrs", pdAddrs)) + } else if doPdGC { + pdClient, err := pd.NewClientWithContext(tctx, pdAddrs, pd.SecurityOption{}) + if err != nil { + tctx.L().Info("create pd client to control GC failed. This won't affect dump process", log.ShortError(err), zap.Strings("pdAddrs", pdAddrs)) + } + d.tidbPDClientForGC = pdClient + } + } + return nil +} + +// tidbGetSnapshot is an initialization step of Dumper. +func tidbGetSnapshot(d *Dumper) error { + conf, doPdGC := d.conf, d.tidbPDClientForGC != nil + consistency := conf.Consistency + pool, tctx := d.dbHandle, d.tctx + snapshotConsistency := consistency == "snapshot" + if conf.Snapshot == "" && (doPdGC || snapshotConsistency) { + conn, err := pool.Conn(tctx) + if err != nil { + tctx.L().Warn("fail to open connection to get snapshot from TiDB", log.ShortError(err)) + // for consistency snapshot, we must get a snapshot here, or we will dump inconsistent data, but for other consistency we can ignore this error. + if !snapshotConsistency { + err = nil + } + return err + } + snapshot, err := getSnapshot(conn) + _ = conn.Close() + if err != nil { + tctx.L().Warn("fail to get snapshot from TiDB", log.ShortError(err)) + // for consistency snapshot, we must get a snapshot here, or we will dump inconsistent data, but for other consistency we can ignore this error. 
+ if !snapshotConsistency { + err = nil + } + return err + } + conf.Snapshot = snapshot + } + return nil +} + +// tidbStartGCSavepointUpdateService is an initialization step of Dumper. +func tidbStartGCSavepointUpdateService(d *Dumper) error { + tctx, pool, conf := d.tctx, d.dbHandle, d.conf + snapshot, si := conf.Snapshot, conf.ServerInfo + if d.tidbPDClientForGC != nil { + snapshotTS, err := parseSnapshotToTSO(pool, snapshot) + if err != nil { + return err + } + go updateServiceSafePoint(tctx, d.tidbPDClientForGC, defaultDumpGCSafePointTTL, snapshotTS) + } else if si.ServerType == version.ServerTypeTiDB { + tctx.L().Warn("If the amount of data to dump is large, criteria: (data more than 60GB or dumped time more than 10 minutes)\n" + + "you'd better adjust the tikv_gc_life_time to avoid export failure due to TiDB GC during the dump process.\n" + + "Before dumping: run sql `update mysql.tidb set VARIABLE_VALUE = '720h' where VARIABLE_NAME = 'tikv_gc_life_time';` in tidb.\n" + + "After dumping: run sql `update mysql.tidb set VARIABLE_VALUE = '10m' where VARIABLE_NAME = 'tikv_gc_life_time';` in tidb.\n") + } + return nil +} + +func updateServiceSafePoint(tctx *tcontext.Context, pdClient pd.Client, ttl int64, snapshotTS uint64) { + updateInterval := time.Duration(ttl/2) * time.Second + tick := time.NewTicker(updateInterval) + dumplingServiceSafePointID := fmt.Sprintf("%s_%d", dumplingServiceSafePointPrefix, time.Now().UnixNano()) + tctx.L().Info("generate dumpling gc safePoint id", zap.String("id", dumplingServiceSafePointID)) + + for { + tctx.L().Debug("update PD safePoint limit with ttl", + zap.Uint64("safePoint", snapshotTS), + zap.Int64("ttl", ttl)) + for retryCnt := 0; retryCnt <= 10; retryCnt++ { + _, err := pdClient.UpdateServiceGCSafePoint(tctx, dumplingServiceSafePointID, ttl, snapshotTS) + if err == nil { + break + } + tctx.L().Debug("update PD safePoint failed", log.ShortError(err), zap.Int("retryTime", retryCnt)) + select { + case <-tctx.Done(): + return + case <-time.After(time.Second): + } + } + select { + case <-tctx.Done(): + return + case <-tick.C: + } + } +} + +// setDefaultSessionParams is a step to set default params for session params. +func setDefaultSessionParams(si version.ServerInfo, sessionParams map[string]any) { + defaultSessionParams := map[string]any{} + if si.ServerType == version.ServerTypeTiDB && si.HasTiKV && si.ServerVersion.Compare(*enablePagingVersion) >= 0 { + defaultSessionParams["tidb_enable_paging"] = "ON" + } + for k, v := range defaultSessionParams { + if _, ok := sessionParams[k]; !ok { + sessionParams[k] = v + } + } +} + +// setSessionParam is an initialization step of Dumper. +func setSessionParam(d *Dumper) error { + conf, pool := d.conf, d.dbHandle + si := conf.ServerInfo + consistency, snapshot := conf.Consistency, conf.Snapshot + sessionParam := conf.SessionParams + if si.ServerType == version.ServerTypeTiDB && conf.TiDBMemQuotaQuery != UnspecifiedSize { + sessionParam[TiDBMemQuotaQueryName] = conf.TiDBMemQuotaQuery + } + var err error + if snapshot != "" { + if si.ServerType != version.ServerTypeTiDB { + return errors.New("snapshot consistency is not supported for this server") + } + if consistency == ConsistencyTypeSnapshot { + conf.ServerInfo.HasTiKV, err = CheckTiDBWithTiKV(pool) + if err != nil { + d.L().Info("cannot check whether TiDB has TiKV, will apply tidb_snapshot by default. 
This won't affect dump process", log.ShortError(err)) + } + if conf.ServerInfo.HasTiKV { + sessionParam[snapshotVar] = snapshot + } + } + } + if d.dbHandle, err = resetDBWithSessionParams(d.tctx, pool, conf.GetDriverConfig(""), conf.SessionParams); err != nil { + return errors.Trace(err) + } + return nil +} + +func openDB(cfg *mysql.Config) (*sql.DB, error) { + c, err := mysql.NewConnector(cfg) + if err != nil { + return nil, errors.Trace(err) + } + return sql.OpenDB(c), nil +} + +func (d *Dumper) renewSelectTableRegionFuncForLowerTiDB(tctx *tcontext.Context) error { + conf := d.conf + if !(conf.ServerInfo.ServerType == version.ServerTypeTiDB && conf.ServerInfo.ServerVersion != nil && conf.ServerInfo.HasTiKV && + conf.ServerInfo.ServerVersion.Compare(*decodeRegionVersion) >= 0 && + conf.ServerInfo.ServerVersion.Compare(*gcSafePointVersion) < 0) { + tctx.L().Debug("no need to build region info because database is not TiDB 3.x") + return nil + } + // for TiDB v3.0+, the original selectTiDBTableRegionFunc will always fail, + // because TiDB v3.0 doesn't have `tidb_decode_key` function nor `DB_NAME`,`TABLE_NAME` columns in `INFORMATION_SCHEMA.TIKV_REGION_STATUS`. + // reference: https://github.com/pingcap/tidb/blob/c497d5c/dumpling/export/dump.go#L775 + // To avoid this function continuously returning errors and confusing users because we fail to init this function at first, + // selectTiDBTableRegionFunc is set to always return an ignorable error at first. + d.selectTiDBTableRegionFunc = func(_ *tcontext.Context, _ *BaseConn, meta TableMeta) (pkFields []string, pkVals [][]string, err error) { + return nil, nil, errors.Annotatef(errEmptyHandleVals, "table: `%s`.`%s`", escapeString(meta.DatabaseName()), escapeString(meta.TableName())) + } + dbHandle, err := openDBFunc(conf.GetDriverConfig("")) + if err != nil { + return errors.Trace(err) + } + defer func() { + _ = dbHandle.Close() + }() + conn, err := dbHandle.Conn(tctx) + if err != nil { + return errors.Trace(err) + } + defer func() { + _ = conn.Close() + }() + dbInfos, err := GetDBInfo(conn, DatabaseTablesToMap(conf.Tables)) + if err != nil { + return errors.Trace(err) + } + regionsInfo, err := GetRegionInfos(conn) + if err != nil { + return errors.Trace(err) + } + tikvHelper := &helper.Helper{} + tableInfos := tikvHelper.GetRegionsTableInfo(regionsInfo, infoschema.DBInfoAsInfoSchema(dbInfos), nil) + + tableInfoMap := make(map[string]map[string][]int64, len(conf.Tables)) + for _, region := range regionsInfo.Regions { + tableList := tableInfos[region.ID] + for _, table := range tableList { + db, tbl := table.DB.Name.O, table.Table.Name.O + if _, ok := tableInfoMap[db]; !ok { + tableInfoMap[db] = make(map[string][]int64, len(conf.Tables[db])) + } + + key, err := hex.DecodeString(region.StartKey) + if err != nil { + d.L().Debug("invalid region start key", log.ShortError(err), zap.String("key", region.StartKey)) + continue + } + // Auto decode byte if needed. + _, bs, err := codec.DecodeBytes(key, nil) + if err == nil { + key = bs + } + // Try to decode it as a record key. 
+ tableID, handle, err := tablecodec.DecodeRecordKey(key) + if err != nil { + d.L().Debug("cannot decode region start key", log.ShortError(err), zap.String("key", region.StartKey), zap.Int64("tableID", tableID)) + continue + } + if handle.IsInt() { + tableInfoMap[db][tbl] = append(tableInfoMap[db][tbl], handle.IntValue()) + } else { + d.L().Debug("not an int handle", log.ShortError(err), zap.Stringer("handle", handle)) + } + } + } + for _, tbInfos := range tableInfoMap { + for _, tbInfoLoop := range tbInfos { + // make sure tbInfo is only used in this loop + tbInfo := tbInfoLoop + slices.Sort(tbInfo) + } + } + + d.selectTiDBTableRegionFunc = func(tctx *tcontext.Context, conn *BaseConn, meta TableMeta) (pkFields []string, pkVals [][]string, err error) { + pkFields, _, err = selectTiDBRowKeyFields(tctx, conn, meta, checkTiDBTableRegionPkFields) + if err != nil { + return + } + dbName, tableName := meta.DatabaseName(), meta.TableName() + if tbInfos, ok := tableInfoMap[dbName]; ok { + if tbInfo, ok := tbInfos[tableName]; ok { + pkVals = make([][]string, len(tbInfo)) + for i, val := range tbInfo { + pkVals[i] = []string{strconv.FormatInt(val, 10)} + } + } + } + return + } + + return nil +} + +func (d *Dumper) newTaskTableData(meta TableMeta, data TableDataIR, currentChunk, totalChunks int) *TaskTableData { + d.metrics.totalChunks.Add(1) + return NewTaskTableData(meta, data, currentChunk, totalChunks) +} diff --git a/dumpling/export/sql.go b/dumpling/export/sql.go index 690ef65fe054f..e0c8fb1682d7f 100644 --- a/dumpling/export/sql.go +++ b/dumpling/export/sql.go @@ -609,9 +609,9 @@ func GetColumnTypes(tctx *tcontext.Context, db *BaseConn, fields, database, tabl if err == nil { err = rows.Close() } - failpoint.Inject("ChaosBrokenMetaConn", func(_ failpoint.Value) { - failpoint.Return(errors.New("connection is closed")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("ChaosBrokenMetaConn")); _err_ == nil { + return errors.New("connection is closed") + } return errors.Annotatef(err, "sql: %s", query) }, func() { colTypes = nil @@ -992,9 +992,9 @@ func resetDBWithSessionParams(tctx *tcontext.Context, db *sql.DB, cfg *mysql.Con } cfg.Params[k] = s } - failpoint.Inject("SkipResetDB", func(_ failpoint.Value) { - failpoint.Return(db, nil) - }) + if _, _err_ := failpoint.Eval(_curpkg_("SkipResetDB")); _err_ == nil { + return db, nil + } db.Close() c, err := mysql.NewConnector(cfg) diff --git a/dumpling/export/sql.go__failpoint_stash__ b/dumpling/export/sql.go__failpoint_stash__ new file mode 100644 index 0000000000000..690ef65fe054f --- /dev/null +++ b/dumpling/export/sql.go__failpoint_stash__ @@ -0,0 +1,1643 @@ +// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. 
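The sql.go hunks above replace the failpoint.Inject marker calls with explicit failpoint.Eval guards, which is the expansion failpoint-ctl normally performs before a build; the pre-rewrite file is preserved verbatim in the __failpoint_stash__ copy that follows. As an editorial illustration only (not part of the patch), here is a minimal sketch of the before/after pattern on a hypothetical failpoint named "MyFailpoint"; the _curpkg_ helper that prefixes the package import path is assumed to be generated alongside, as in this patch, and is not defined here:

package export

import (
	"github.com/pingcap/errors"
	"github.com/pingcap/failpoint"
)

// doWork shows the marker form kept in checked-in sources: failpoint.Inject is
// a no-op unless the tool rewrites it, and failpoint.Return marks the early return.
func doWork() error {
	failpoint.Inject("MyFailpoint", func(_ failpoint.Value) {
		failpoint.Return(errors.New("injected error"))
	})
	return nil
}

// doWorkRewritten shows the expanded form matching the hunks above: the guarded
// body runs only while the failpoint has been enabled at runtime.
func doWorkRewritten() error {
	if _, _err_ := failpoint.Eval(_curpkg_("MyFailpoint")); _err_ == nil {
		return errors.New("injected error")
	}
	return nil
}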
+ +package export + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "fmt" + "io" + "math" + "strconv" + "strings" + + "github.com/go-sql-driver/mysql" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/version" + tcontext "github.com/pingcap/tidb/dumpling/context" + "github.com/pingcap/tidb/dumpling/log" + dbconfig "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/parser/model" + pd "github.com/tikv/pd/client/http" + "go.uber.org/multierr" + "go.uber.org/zap" +) + +const ( + orderByTiDBRowID = "ORDER BY `_tidb_rowid`" + snapshotVar = "tidb_snapshot" +) + +type listTableType int + +const ( + listTableByInfoSchema listTableType = iota + listTableByShowFullTables + listTableByShowTableStatus +) + +// ShowDatabases shows the databases of a database server. +func ShowDatabases(db *sql.Conn) ([]string, error) { + var res oneStrColumnTable + if err := simpleQuery(db, "SHOW DATABASES", res.handleOneRow); err != nil { + return nil, err + } + return res.data, nil +} + +// ShowTables shows the tables of a database, the caller should use the correct database. +func ShowTables(db *sql.Conn) ([]string, error) { + var res oneStrColumnTable + if err := simpleQuery(db, "SHOW TABLES", res.handleOneRow); err != nil { + return nil, err + } + return res.data, nil +} + +// ShowCreateDatabase constructs the create database SQL for a specified database +// returns (createDatabaseSQL, error) +func ShowCreateDatabase(tctx *tcontext.Context, db *BaseConn, database string) (string, error) { + var oneRow [2]string + handleOneRow := func(rows *sql.Rows) error { + return rows.Scan(&oneRow[0], &oneRow[1]) + } + query := fmt.Sprintf("SHOW CREATE DATABASE `%s`", escapeString(database)) + err := db.QuerySQL(tctx, handleOneRow, func() { + oneRow[0], oneRow[1] = "", "" + }, query) + if multiErrs := multierr.Errors(err); len(multiErrs) > 0 { + for _, multiErr := range multiErrs { + if mysqlErr, ok := errors.Cause(multiErr).(*mysql.MySQLError); ok { + // Falling back to simple create statement for MemSQL/SingleStore, because of this: + // ERROR 1706 (HY000): Feature 'SHOW CREATE DATABASE' is not supported by MemSQL. 
+ if strings.Contains(mysqlErr.Error(), "SHOW CREATE DATABASE") { + return fmt.Sprintf("CREATE DATABASE `%s`", escapeString(database)), nil + } + } + } + } + return oneRow[1], err +} + +// ShowCreateTable constructs the create table SQL for a specified table +// returns (createTableSQL, error) +func ShowCreateTable(tctx *tcontext.Context, db *BaseConn, database, table string) (string, error) { + var oneRow [2]string + handleOneRow := func(rows *sql.Rows) error { + return rows.Scan(&oneRow[0], &oneRow[1]) + } + query := fmt.Sprintf("SHOW CREATE TABLE `%s`.`%s`", escapeString(database), escapeString(table)) + err := db.QuerySQL(tctx, handleOneRow, func() { + oneRow[0], oneRow[1] = "", "" + }, query) + if err != nil { + return "", err + } + return oneRow[1], nil +} + +// ShowCreatePlacementPolicy constructs the create policy SQL for a specified table +// returns (createPolicySQL, error) +func ShowCreatePlacementPolicy(tctx *tcontext.Context, db *BaseConn, policy string) (string, error) { + var oneRow [2]string + handleOneRow := func(rows *sql.Rows) error { + return rows.Scan(&oneRow[0], &oneRow[1]) + } + query := fmt.Sprintf("SHOW CREATE PLACEMENT POLICY `%s`", escapeString(policy)) + err := db.QuerySQL(tctx, handleOneRow, func() { + oneRow[0], oneRow[1] = "", "" + }, query) + return oneRow[1], err +} + +// ShowCreateView constructs the create view SQL for a specified view +// returns (createFakeTableSQL, createViewSQL, error) +func ShowCreateView(tctx *tcontext.Context, db *BaseConn, database, view string) (createFakeTableSQL string, createRealViewSQL string, err error) { + var fieldNames []string + handleFieldRow := func(rows *sql.Rows) error { + var oneRow [6]sql.NullString + scanErr := rows.Scan(&oneRow[0], &oneRow[1], &oneRow[2], &oneRow[3], &oneRow[4], &oneRow[5]) + if scanErr != nil { + return errors.Trace(scanErr) + } + if oneRow[0].Valid { + fieldNames = append(fieldNames, fmt.Sprintf("`%s` int", escapeString(oneRow[0].String))) + } + return nil + } + var oneRow [4]string + handleOneRow := func(rows *sql.Rows) error { + return rows.Scan(&oneRow[0], &oneRow[1], &oneRow[2], &oneRow[3]) + } + var createTableSQL, createViewSQL strings.Builder + + // Build createTableSQL + query := fmt.Sprintf("SHOW FIELDS FROM `%s`.`%s`", escapeString(database), escapeString(view)) + err = db.QuerySQL(tctx, handleFieldRow, func() { + fieldNames = []string{} + }, query) + if err != nil { + return "", "", err + } + fmt.Fprintf(&createTableSQL, "CREATE TABLE `%s`(\n", escapeString(view)) + createTableSQL.WriteString(strings.Join(fieldNames, ",\n")) + createTableSQL.WriteString("\n)ENGINE=MyISAM;\n") + + // Build createViewSQL + fmt.Fprintf(&createViewSQL, "DROP TABLE IF EXISTS `%s`;\n", escapeString(view)) + fmt.Fprintf(&createViewSQL, "DROP VIEW IF EXISTS `%s`;\n", escapeString(view)) + query = fmt.Sprintf("SHOW CREATE VIEW `%s`.`%s`", escapeString(database), escapeString(view)) + err = db.QuerySQL(tctx, handleOneRow, func() { + for i := range oneRow { + oneRow[i] = "" + } + }, query) + if err != nil { + return "", "", err + } + // The result for `show create view` SQL + // mysql> show create view v1; + // +------+-------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------------------+ + // | View | Create View | character_set_client | collation_connection | + // 
+------+-------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------------------+ + // | v1 | CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`localhost` SQL SECURITY DEFINER VIEW `v1` (`a`) AS SELECT `t`.`a` AS `a` FROM `test`.`t` | utf8 | utf8_general_ci | + // +------+-------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------------------+ + SetCharset(&createViewSQL, oneRow[2], oneRow[3]) + createViewSQL.WriteString(oneRow[1]) + createViewSQL.WriteString(";\n") + RestoreCharset(&createViewSQL) + + return createTableSQL.String(), createViewSQL.String(), nil +} + +// ShowCreateSequence constructs the create sequence SQL for a specified sequence +// returns (createSequenceSQL, error) +func ShowCreateSequence(tctx *tcontext.Context, db *BaseConn, database, sequence string, conf *Config) (string, error) { + var oneRow [2]string + handleOneRow := func(rows *sql.Rows) error { + return rows.Scan(&oneRow[0], &oneRow[1]) + } + var ( + createSequenceSQL strings.Builder + nextNotCachedValue int64 + ) + query := fmt.Sprintf("SHOW CREATE SEQUENCE `%s`.`%s`", escapeString(database), escapeString(sequence)) + err := db.QuerySQL(tctx, handleOneRow, func() { + oneRow[0], oneRow[1] = "", "" + }, query) + if err != nil { + return "", err + } + createSequenceSQL.WriteString(oneRow[1]) + createSequenceSQL.WriteString(";\n") + + switch conf.ServerInfo.ServerType { + case version.ServerTypeTiDB: + // Get next not allocated auto increment id of the whole cluster + query := fmt.Sprintf("SHOW TABLE `%s`.`%s` NEXT_ROW_ID", escapeString(database), escapeString(sequence)) + results, err := db.QuerySQLWithColumns(tctx, []string{"NEXT_GLOBAL_ROW_ID", "ID_TYPE"}, query) + if err != nil { + return "", err + } + for _, oneRow := range results { + nextGlobalRowID, idType := oneRow[0], oneRow[1] + if idType == "SEQUENCE" { + nextNotCachedValue, _ = strconv.ParseInt(nextGlobalRowID, 10, 64) + } + } + fmt.Fprintf(&createSequenceSQL, "SELECT SETVAL(`%s`,%d);\n", escapeString(sequence), nextNotCachedValue) + case version.ServerTypeMariaDB: + var oneRow1 string + handleOneRow1 := func(rows *sql.Rows) error { + return rows.Scan(&oneRow1) + } + query := fmt.Sprintf("SELECT NEXT_NOT_CACHED_VALUE FROM `%s`.`%s`", escapeString(database), escapeString(sequence)) + err := db.QuerySQL(tctx, handleOneRow1, func() { + oneRow1 = "" + }, query) + if err != nil { + return "", err + } + nextNotCachedValue, _ = strconv.ParseInt(oneRow1, 10, 64) + fmt.Fprintf(&createSequenceSQL, "SELECT SETVAL(`%s`,%d);\n", escapeString(sequence), nextNotCachedValue) + } + return createSequenceSQL.String(), nil +} + +// SetCharset builds the set charset SQLs +func SetCharset(w *strings.Builder, characterSet, collationConnection string) { + w.WriteString("SET @PREV_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT;\n") + w.WriteString("SET @PREV_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS;\n") + w.WriteString("SET @PREV_COLLATION_CONNECTION=@@COLLATION_CONNECTION;\n") + + fmt.Fprintf(w, "SET character_set_client = %s;\n", characterSet) + fmt.Fprintf(w, "SET character_set_results = %s;\n", characterSet) + fmt.Fprintf(w, "SET collation_connection = %s;\n", collationConnection) +} + +// RestoreCharset builds the restore charset SQLs +func RestoreCharset(w io.StringWriter) { + _, _ = w.WriteString("SET character_set_client = @PREV_CHARACTER_SET_CLIENT;\n") + _, _ 
= w.WriteString("SET character_set_results = @PREV_CHARACTER_SET_RESULTS;\n") + _, _ = w.WriteString("SET collation_connection = @PREV_COLLATION_CONNECTION;\n") +} + +// updateSpecifiedTablesMeta updates DatabaseTables with correct table type and avg row size. +func updateSpecifiedTablesMeta(tctx *tcontext.Context, db *sql.Conn, dbTables DatabaseTables, listType listTableType) error { + var ( + schema, table, tableTypeStr string + tableType TableType + avgRowLength uint64 + err error + ) + switch listType { + case listTableByInfoSchema: + dbNames := make([]string, 0, len(dbTables)) + for db := range dbTables { + dbNames = append(dbNames, fmt.Sprintf("'%s'", db)) + } + query := fmt.Sprintf("SELECT TABLE_SCHEMA,TABLE_NAME,TABLE_TYPE,AVG_ROW_LENGTH FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA IN (%s)", strings.Join(dbNames, ",")) + if err := simpleQueryWithArgs(tctx, db, func(rows *sql.Rows) error { + var ( + sqlAvgRowLength sql.NullInt64 + err2 error + ) + if err2 = rows.Scan(&schema, &table, &tableTypeStr, &sqlAvgRowLength); err != nil { + return errors.Trace(err2) + } + + tbls, ok := dbTables[schema] + if !ok { + return nil + } + for _, tbl := range tbls { + if tbl.Name == table { + tableType, err2 = ParseTableType(tableTypeStr) + if err2 != nil { + return errors.Trace(err2) + } + if sqlAvgRowLength.Valid { + avgRowLength = uint64(sqlAvgRowLength.Int64) + } else { + avgRowLength = 0 + } + tbl.Type = tableType + tbl.AvgRowLength = avgRowLength + } + } + return nil + }, query); err != nil { + return errors.Annotatef(err, "sql: %s", query) + } + return nil + case listTableByShowFullTables: + for schema, tbls := range dbTables { + query := fmt.Sprintf("SHOW FULL TABLES FROM `%s`", + escapeString(schema)) + if err := simpleQueryWithArgs(tctx, db, func(rows *sql.Rows) error { + var err2 error + if err2 = rows.Scan(&table, &tableTypeStr); err != nil { + return errors.Trace(err2) + } + for _, tbl := range tbls { + if tbl.Name == table { + tableType, err2 = ParseTableType(tableTypeStr) + if err2 != nil { + return errors.Trace(err2) + } + tbl.Type = tableType + } + } + return nil + }, query); err != nil { + return errors.Annotatef(err, "sql: %s", query) + } + } + return nil + default: + const queryTemplate = "SHOW TABLE STATUS FROM `%s`" + for schema, tbls := range dbTables { + query := fmt.Sprintf(queryTemplate, escapeString(schema)) + rows, err := db.QueryContext(tctx, query) + if err != nil { + return errors.Annotatef(err, "sql: %s", query) + } + results, err := GetSpecifiedColumnValuesAndClose(rows, "NAME", "ENGINE", "AVG_ROW_LENGTH", "COMMENT") + if err != nil { + return errors.Annotatef(err, "sql: %s", query) + } + for _, oneRow := range results { + table, engine, avgRowLengthStr, comment := oneRow[0], oneRow[1], oneRow[2], oneRow[3] + for _, tbl := range tbls { + if tbl.Name == table { + if avgRowLengthStr != "" { + avgRowLength, err = strconv.ParseUint(avgRowLengthStr, 10, 64) + if err != nil { + return errors.Annotatef(err, "sql: %s", query) + } + } else { + avgRowLength = 0 + } + tbl.AvgRowLength = avgRowLength + tableType = TableTypeBase + if engine == "" && (comment == "" || comment == TableTypeViewStr) { + tableType = TableTypeView + } else if engine == "" { + tctx.L().Warn("invalid table without engine found", zap.String("database", schema), zap.String("table", table)) + continue + } + tbl.Type = tableType + } + } + } + } + return nil + } +} + +// ListAllDatabasesTables lists all the databases and tables from the database +// listTableByInfoSchema list tables by table 
information_schema in MySQL +// listTableByShowTableStatus has better performance than listTableByInfoSchema +// listTableByShowFullTables is used in mysql8 version [8.0.3,8.0.23), more details can be found in the comments of func matchMysqlBugversion +func ListAllDatabasesTables(tctx *tcontext.Context, db *sql.Conn, databaseNames []string, + listType listTableType, tableTypes ...TableType) (DatabaseTables, error) { // revive:disable-line:flag-parameter + dbTables := DatabaseTables{} + var ( + schema, table, tableTypeStr string + tableType TableType + avgRowLength uint64 + err error + ) + + tableTypeConditions := make([]string, len(tableTypes)) + for i, tableType := range tableTypes { + tableTypeConditions[i] = fmt.Sprintf("TABLE_TYPE='%s'", tableType) + } + switch listType { + case listTableByInfoSchema: + query := fmt.Sprintf("SELECT TABLE_SCHEMA,TABLE_NAME,TABLE_TYPE,AVG_ROW_LENGTH FROM INFORMATION_SCHEMA.TABLES WHERE %s", strings.Join(tableTypeConditions, " OR ")) + for _, schema := range databaseNames { + dbTables[schema] = make([]*TableInfo, 0) + } + if err = simpleQueryWithArgs(tctx, db, func(rows *sql.Rows) error { + var ( + sqlAvgRowLength sql.NullInt64 + err2 error + ) + if err2 = rows.Scan(&schema, &table, &tableTypeStr, &sqlAvgRowLength); err != nil { + return errors.Trace(err2) + } + tableType, err2 = ParseTableType(tableTypeStr) + if err2 != nil { + return errors.Trace(err2) + } + + if sqlAvgRowLength.Valid { + avgRowLength = uint64(sqlAvgRowLength.Int64) + } else { + avgRowLength = 0 + } + // only append tables to schemas in databaseNames + if _, ok := dbTables[schema]; ok { + dbTables[schema] = append(dbTables[schema], &TableInfo{table, avgRowLength, tableType}) + } + return nil + }, query); err != nil { + return nil, errors.Annotatef(err, "sql: %s", query) + } + case listTableByShowFullTables: + for _, schema = range databaseNames { + dbTables[schema] = make([]*TableInfo, 0) + query := fmt.Sprintf("SHOW FULL TABLES FROM `%s` WHERE %s", + escapeString(schema), strings.Join(tableTypeConditions, " OR ")) + if err = simpleQueryWithArgs(tctx, db, func(rows *sql.Rows) error { + var err2 error + if err2 = rows.Scan(&table, &tableTypeStr); err != nil { + return errors.Trace(err2) + } + tableType, err2 = ParseTableType(tableTypeStr) + if err2 != nil { + return errors.Trace(err2) + } + avgRowLength = 0 // can't get avgRowLength from the result of `show full tables` so hardcode to 0 here + dbTables[schema] = append(dbTables[schema], &TableInfo{table, avgRowLength, tableType}) + return nil + }, query); err != nil { + return nil, errors.Annotatef(err, "sql: %s", query) + } + } + default: + const queryTemplate = "SHOW TABLE STATUS FROM `%s`" + selectedTableType := make(map[TableType]struct{}) + for _, tableType = range tableTypes { + selectedTableType[tableType] = struct{}{} + } + for _, schema = range databaseNames { + dbTables[schema] = make([]*TableInfo, 0) + query := fmt.Sprintf(queryTemplate, escapeString(schema)) + rows, err := db.QueryContext(tctx, query) + if err != nil { + return nil, errors.Annotatef(err, "sql: %s", query) + } + results, err := GetSpecifiedColumnValuesAndClose(rows, "NAME", "ENGINE", "AVG_ROW_LENGTH", "COMMENT") + if err != nil { + return nil, errors.Annotatef(err, "sql: %s", query) + } + for _, oneRow := range results { + table, engine, avgRowLengthStr, comment := oneRow[0], oneRow[1], oneRow[2], oneRow[3] + if avgRowLengthStr != "" { + avgRowLength, err = strconv.ParseUint(avgRowLengthStr, 10, 64) + if err != nil { + return nil, errors.Annotatef(err, "sql: 
%s", query) + } + } else { + avgRowLength = 0 + } + tableType = TableTypeBase + if engine == "" && (comment == "" || comment == TableTypeViewStr) { + tableType = TableTypeView + } else if engine == "" { + tctx.L().Warn("invalid table without engine found", zap.String("database", schema), zap.String("table", table)) + continue + } + if _, ok := selectedTableType[tableType]; !ok { + continue + } + dbTables[schema] = append(dbTables[schema], &TableInfo{table, avgRowLength, tableType}) + } + } + } + return dbTables, nil +} + +// ListAllPlacementPolicyNames returns all placement policy names. +func ListAllPlacementPolicyNames(tctx *tcontext.Context, db *BaseConn) ([]string, error) { + var policyList []string + var policy string + const query = "select distinct policy_name from information_schema.placement_policies where policy_name is not null;" + err := db.QuerySQL(tctx, func(rows *sql.Rows) error { + err := rows.Scan(&policy) + if err != nil { + return errors.Trace(err) + } + policyList = append(policyList, policy) + return nil + }, func() { + policyList = policyList[:0] + }, query) + return policyList, errors.Annotatef(err, "sql: %s", query) +} + +// SelectVersion gets the version information from the database server +func SelectVersion(db *sql.DB) (string, error) { + var versionInfo string + const query = "SELECT version()" + row := db.QueryRow(query) + err := row.Scan(&versionInfo) + if err != nil { + return "", errors.Annotatef(err, "sql: %s", query) + } + return versionInfo, nil +} + +// SelectAllFromTable dumps data serialized from a specified table +func SelectAllFromTable(conf *Config, meta TableMeta, partition, orderByClause string) TableDataIR { + database, table := meta.DatabaseName(), meta.TableName() + selectedField, selectLen := meta.SelectedField(), meta.SelectedLen() + query := buildSelectQuery(database, table, selectedField, partition, buildWhereCondition(conf, ""), orderByClause) + + return &tableData{ + query: query, + colLen: selectLen, + } +} + +func buildSelectQuery(database, table, fields, partition, where, orderByClause string) string { + var query strings.Builder + query.WriteString("SELECT ") + if fields == "" { + // If all of the columns are generated, + // we need to make sure the query is valid. 
+ fields = "''" + } + query.WriteString(fields) + query.WriteString(" FROM `") + query.WriteString(escapeString(database)) + query.WriteString("`.`") + query.WriteString(escapeString(table)) + query.WriteByte('`') + if partition != "" { + query.WriteString(" PARTITION(`") + query.WriteString(escapeString(partition)) + query.WriteString("`)") + } + + if where != "" { + query.WriteString(" ") + query.WriteString(where) + } + + if orderByClause != "" { + query.WriteString(" ") + query.WriteString(orderByClause) + } + + return query.String() +} + +func buildOrderByClause(tctx *tcontext.Context, conf *Config, db *BaseConn, database, table string, hasImplicitRowID bool) (string, error) { // revive:disable-line:flag-parameter + if !conf.SortByPk { + return "", nil + } + if hasImplicitRowID { + return orderByTiDBRowID, nil + } + cols, err := GetPrimaryKeyColumns(tctx, db, database, table) + if err != nil { + return "", errors.Trace(err) + } + return buildOrderByClauseString(cols), nil +} + +// SelectTiDBRowID checks whether this table has _tidb_rowid column +func SelectTiDBRowID(tctx *tcontext.Context, db *BaseConn, database, table string) (bool, error) { + tiDBRowIDQuery := fmt.Sprintf("SELECT _tidb_rowid from `%s`.`%s` LIMIT 1", escapeString(database), escapeString(table)) + hasImplictRowID := false + err := db.ExecSQL(tctx, func(_ sql.Result, err error) error { + if err != nil { + hasImplictRowID = false + errMsg := strings.ToLower(err.Error()) + if strings.Contains(errMsg, fmt.Sprintf("%d", errno.ErrBadField)) { + return nil + } + return errors.Annotatef(err, "sql: %s", tiDBRowIDQuery) + } + hasImplictRowID = true + return nil + }, tiDBRowIDQuery) + return hasImplictRowID, err +} + +// GetSuitableRows gets suitable rows for each table +func GetSuitableRows(avgRowLength uint64) uint64 { + const ( + defaultRows = 200000 + maxRows = 1000000 + bytesPerFile = 128 * 1024 * 1024 // 128MB per file by default + ) + if avgRowLength == 0 { + return defaultRows + } + estimateRows := bytesPerFile / avgRowLength + if estimateRows > maxRows { + return maxRows + } + return estimateRows +} + +// GetColumnTypes gets *sql.ColumnTypes from a specified table +func GetColumnTypes(tctx *tcontext.Context, db *BaseConn, fields, database, table string) ([]*sql.ColumnType, error) { + query := fmt.Sprintf("SELECT %s FROM `%s`.`%s` LIMIT 1", fields, escapeString(database), escapeString(table)) + var colTypes []*sql.ColumnType + err := db.QuerySQL(tctx, func(rows *sql.Rows) error { + var err error + colTypes, err = rows.ColumnTypes() + if err == nil { + err = rows.Close() + } + failpoint.Inject("ChaosBrokenMetaConn", func(_ failpoint.Value) { + failpoint.Return(errors.New("connection is closed")) + }) + return errors.Annotatef(err, "sql: %s", query) + }, func() { + colTypes = nil + }, query) + if err != nil { + return nil, err + } + return colTypes, nil +} + +// GetPrimaryKeyAndColumnTypes gets all primary columns and their types in ordinal order +func GetPrimaryKeyAndColumnTypes(tctx *tcontext.Context, conn *BaseConn, meta TableMeta) ([]string, []string, error) { + var ( + colNames, colTypes []string + err error + ) + colNames, err = GetPrimaryKeyColumns(tctx, conn, meta.DatabaseName(), meta.TableName()) + if err != nil { + return nil, nil, err + } + colName2Type := string2Map(meta.ColumnNames(), meta.ColumnTypes()) + colTypes = make([]string, len(colNames)) + for i, colName := range colNames { + colTypes[i] = colName2Type[colName] + } + return colNames, colTypes, nil +} + +// GetPrimaryKeyColumns gets all primary columns 
in ordinal order +func GetPrimaryKeyColumns(tctx *tcontext.Context, db *BaseConn, database, table string) ([]string, error) { + priKeyColsQuery := fmt.Sprintf("SHOW INDEX FROM `%s`.`%s`", escapeString(database), escapeString(table)) + results, err := db.QuerySQLWithColumns(tctx, []string{"KEY_NAME", "COLUMN_NAME"}, priKeyColsQuery) + if err != nil { + return nil, err + } + + cols := make([]string, 0, len(results)) + for _, oneRow := range results { + keyName, columnName := oneRow[0], oneRow[1] + if keyName == "PRIMARY" { + cols = append(cols, columnName) + } + } + return cols, nil +} + +// getNumericIndex picks up indices according to the following priority: +// primary key > unique key with the smallest count > key with the max cardinality +// primary key with multi cols is before unique key with single col because we will sort result by primary keys +func getNumericIndex(tctx *tcontext.Context, db *BaseConn, meta TableMeta) (string, error) { + database, table := meta.DatabaseName(), meta.TableName() + colName2Type := string2Map(meta.ColumnNames(), meta.ColumnTypes()) + keyQuery := fmt.Sprintf("SHOW INDEX FROM `%s`.`%s`", escapeString(database), escapeString(table)) + results, err := db.QuerySQLWithColumns(tctx, []string{"NON_UNIQUE", "SEQ_IN_INDEX", "KEY_NAME", "COLUMN_NAME", "CARDINALITY"}, keyQuery) + if err != nil { + return "", err + } + type keyColumnPair struct { + colName string + count uint64 + } + var ( + uniqueKeyMap = map[string]keyColumnPair{} // unique key name -> key column name, unique key columns count + keyColumn string + maxCardinality int64 = -1 + ) + + // check primary key first, then unique key + for _, oneRow := range results { + nonUnique, seqInIndex, keyName, colName, cardinality := oneRow[0], oneRow[1], oneRow[2], oneRow[3], oneRow[4] + // only try pick the first column, because the second column of pk/uk in where condition will trigger a full table scan + if seqInIndex != "1" { + if pair, ok := uniqueKeyMap[keyName]; ok { + seqInIndexInt, err := strconv.ParseUint(seqInIndex, 10, 64) + if err == nil && seqInIndexInt > pair.count { + uniqueKeyMap[keyName] = keyColumnPair{pair.colName, seqInIndexInt} + } + } + continue + } + _, numberColumn := dataTypeInt[colName2Type[colName]] + if numberColumn { + switch { + case keyName == "PRIMARY": + return colName, nil + case nonUnique == "0": + uniqueKeyMap[keyName] = keyColumnPair{colName, 1} + // pick index column with max cardinality when there is no unique index + case len(uniqueKeyMap) == 0: + cardinalityInt, err := strconv.ParseInt(cardinality, 10, 64) + if err == nil && cardinalityInt > maxCardinality { + keyColumn = colName + maxCardinality = cardinalityInt + } + } + } + } + if len(uniqueKeyMap) > 0 { + var ( + minCols uint64 = math.MaxUint64 + uniqueKeyColumn string + ) + for _, pair := range uniqueKeyMap { + if pair.count < minCols { + uniqueKeyColumn = pair.colName + minCols = pair.count + } + } + return uniqueKeyColumn, nil + } + return keyColumn, nil +} + +// FlushTableWithReadLock flush tables with read lock +func FlushTableWithReadLock(ctx context.Context, db *sql.Conn) error { + const ftwrlQuery = "FLUSH TABLES WITH READ LOCK" + _, err := db.ExecContext(ctx, ftwrlQuery) + return errors.Annotatef(err, "sql: %s", ftwrlQuery) +} + +// LockTables locks table with read lock +func LockTables(ctx context.Context, db *sql.Conn, database, table string) error { + lockTableQuery := fmt.Sprintf("LOCK TABLES `%s`.`%s` READ", escapeString(database), escapeString(table)) + _, err := db.ExecContext(ctx, lockTableQuery) + 
return errors.Annotatef(err, "sql: %s", lockTableQuery) +} + +// UnlockTables unlocks all tables' lock +func UnlockTables(ctx context.Context, db *sql.Conn) error { + const unlockTableQuery = "UNLOCK TABLES" + _, err := db.ExecContext(ctx, unlockTableQuery) + return errors.Annotatef(err, "sql: %s", unlockTableQuery) +} + +// ShowMasterStatus get SHOW MASTER STATUS result from database +func ShowMasterStatus(db *sql.Conn) ([]string, error) { + var oneRow []string + handleOneRow := func(rows *sql.Rows) error { + cols, err := rows.Columns() + if err != nil { + return errors.Trace(err) + } + fieldNum := len(cols) + oneRow = make([]string, fieldNum) + addr := make([]any, fieldNum) + for i := range oneRow { + addr[i] = &oneRow[i] + } + return rows.Scan(addr...) + } + const showMasterStatusQuery = "SHOW MASTER STATUS" + err := simpleQuery(db, showMasterStatusQuery, handleOneRow) + if err != nil { + return nil, errors.Annotatef(err, "sql: %s", showMasterStatusQuery) + } + return oneRow, nil +} + +// GetSpecifiedColumnValueAndClose get columns' values whose name is equal to columnName and close the given rows +func GetSpecifiedColumnValueAndClose(rows *sql.Rows, columnName string) ([]string, error) { + if rows == nil { + return []string{}, nil + } + defer rows.Close() + var strs []string + columns, _ := rows.Columns() + addr := make([]any, len(columns)) + oneRow := make([]sql.NullString, len(columns)) + fieldIndex := -1 + for i, col := range columns { + if strings.EqualFold(col, columnName) { + fieldIndex = i + } + addr[i] = &oneRow[i] + } + if fieldIndex == -1 { + return strs, nil + } + for rows.Next() { + err := rows.Scan(addr...) + if err != nil { + return strs, errors.Trace(err) + } + if oneRow[fieldIndex].Valid { + strs = append(strs, oneRow[fieldIndex].String) + } + } + return strs, errors.Trace(rows.Err()) +} + +// GetSpecifiedColumnValuesAndClose get columns' values whose name is equal to columnName +func GetSpecifiedColumnValuesAndClose(rows *sql.Rows, columnName ...string) ([][]string, error) { + if rows == nil { + return [][]string{}, nil + } + defer rows.Close() + var strs [][]string + columns, err := rows.Columns() + if err != nil { + return strs, errors.Trace(err) + } + addr := make([]any, len(columns)) + oneRow := make([]sql.NullString, len(columns)) + fieldIndexMp := make(map[int]int) + for i, col := range columns { + addr[i] = &oneRow[i] + for j, name := range columnName { + if strings.EqualFold(col, name) { + fieldIndexMp[i] = j + } + } + } + if len(fieldIndexMp) == 0 { + return strs, nil + } + for rows.Next() { + err := rows.Scan(addr...) 
+ if err != nil { + return strs, errors.Trace(err) + } + written := false + tmpStr := make([]string, len(columnName)) + for colPos, namePos := range fieldIndexMp { + if oneRow[colPos].Valid { + written = true + tmpStr[namePos] = oneRow[colPos].String + } + } + if written { + strs = append(strs, tmpStr) + } + } + return strs, errors.Trace(rows.Err()) +} + +// GetPdAddrs gets PD address from TiDB +func GetPdAddrs(tctx *tcontext.Context, db *sql.DB) ([]string, error) { + const query = "SELECT * FROM information_schema.cluster_info where type = 'pd';" + rows, err := db.QueryContext(tctx, query) + if err != nil { + return []string{}, errors.Annotatef(err, "sql: %s", query) + } + pdAddrs, err := GetSpecifiedColumnValueAndClose(rows, "STATUS_ADDRESS") + return pdAddrs, errors.Annotatef(err, "sql: %s", query) +} + +// GetTiDBDDLIDs gets DDL IDs from TiDB +func GetTiDBDDLIDs(tctx *tcontext.Context, db *sql.DB) ([]string, error) { + const query = "SELECT * FROM information_schema.tidb_servers_info;" + rows, err := db.QueryContext(tctx, query) + if err != nil { + return []string{}, errors.Annotatef(err, "sql: %s", query) + } + ddlIDs, err := GetSpecifiedColumnValueAndClose(rows, "DDL_ID") + return ddlIDs, errors.Annotatef(err, "sql: %s", query) +} + +// getTiDBConfig gets tidb config from TiDB server +// @@tidb_config details doc https://docs.pingcap.com/tidb/stable/system-variables#tidb_config +// this variable exists at least from v2.0.0, so this works in most existing tidb instances +func getTiDBConfig(db *sql.Conn) (dbconfig.Config, error) { + const query = "SELECT @@tidb_config;" + var ( + tidbConfig dbconfig.Config + tidbConfigBytes []byte + ) + row := db.QueryRowContext(context.Background(), query) + err := row.Scan(&tidbConfigBytes) + if err != nil { + return tidbConfig, errors.Annotatef(err, "sql: %s", query) + } + err = json.Unmarshal(tidbConfigBytes, &tidbConfig) + return tidbConfig, errors.Annotatef(err, "sql: %s", query) +} + +// CheckTiDBWithTiKV use sql to check whether current TiDB has TiKV +func CheckTiDBWithTiKV(db *sql.DB) (bool, error) { + conn, err := db.Conn(context.Background()) + if err == nil { + defer func() { + _ = conn.Close() + }() + tidbConfig, err := getTiDBConfig(conn) + if err == nil { + return tidbConfig.Store == "tikv", nil + } + } + var count int + const query = "SELECT COUNT(1) as c FROM MYSQL.TiDB WHERE VARIABLE_NAME='tikv_gc_safe_point'" + row := db.QueryRow(query) + err = row.Scan(&count) + if err != nil { + // still return true here. 
Because sometimes users may not have privileges for MySQL.TiDB database + // In most production cases TiDB has TiKV + return true, errors.Annotatef(err, "sql: %s", query) + } + return count > 0, nil +} + +// CheckIfSeqExists use sql to check whether sequence exists +func CheckIfSeqExists(db *sql.Conn) (bool, error) { + var count int + const query = "SELECT COUNT(1) as c FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='SEQUENCE'" + row := db.QueryRowContext(context.Background(), query) + err := row.Scan(&count) + if err != nil { + return false, errors.Annotatef(err, "sql: %s", query) + } + + return count > 0, nil +} + +// CheckTiDBEnableTableLock use sql variable to check whether current TiDB has TiKV +func CheckTiDBEnableTableLock(db *sql.Conn) (bool, error) { + tidbConfig, err := getTiDBConfig(db) + if err != nil { + return false, err + } + return tidbConfig.EnableTableLock, nil +} + +func getSnapshot(db *sql.Conn) (string, error) { + str, err := ShowMasterStatus(db) + if err != nil { + return "", err + } + return str[snapshotFieldIndex], nil +} + +func isUnknownSystemVariableErr(err error) bool { + return strings.Contains(err.Error(), "Unknown system variable") +} + +// resetDBWithSessionParams will return a new sql.DB as a replacement for input `db` with new session parameters. +// If returned error is nil, the input `db` will be closed. +func resetDBWithSessionParams(tctx *tcontext.Context, db *sql.DB, cfg *mysql.Config, params map[string]any) (*sql.DB, error) { + support := make(map[string]any) + for k, v := range params { + var pv any + if str, ok := v.(string); ok { + if pvi, err := strconv.ParseInt(str, 10, 64); err == nil { + pv = pvi + } else if pvf, err := strconv.ParseFloat(str, 64); err == nil { + pv = pvf + } else { + pv = str + } + } else { + pv = v + } + s := fmt.Sprintf("SET SESSION %s = ?", k) + _, err := db.ExecContext(tctx, s, pv) + if err != nil { + if k == snapshotVar { + err = errors.Annotate(err, "fail to set snapshot for tidb, please set --consistency=none/--consistency=lock or fix snapshot problem") + } else if isUnknownSystemVariableErr(err) { + tctx.L().Info("session variable is not supported by db", zap.String("variable", k), zap.Reflect("value", v)) + continue + } + return nil, errors.Trace(err) + } + + support[k] = pv + } + + if cfg.Params == nil { + cfg.Params = make(map[string]string) + } + + for k, v := range support { + var s string + // Wrap string with quote to handle string with space. 
For example, '2020-10-20 13:41:40' + // For --params argument, quote doesn't matter because it doesn't affect the actual value + if str, ok := v.(string); ok { + s = wrapStringWith(str, "'") + } else { + s = fmt.Sprintf("%v", v) + } + cfg.Params[k] = s + } + failpoint.Inject("SkipResetDB", func(_ failpoint.Value) { + failpoint.Return(db, nil) + }) + + db.Close() + c, err := mysql.NewConnector(cfg) + if err != nil { + return nil, errors.Trace(err) + } + newDB := sql.OpenDB(c) + // ping to make sure all session parameters are set correctly + err = newDB.PingContext(tctx) + if err != nil { + newDB.Close() + } + return newDB, nil +} + +func createConnWithConsistency(ctx context.Context, db *sql.DB, repeatableRead bool) (*sql.Conn, error) { + conn, err := db.Conn(ctx) + if err != nil { + return nil, errors.Trace(err) + } + var query string + if repeatableRead { + query = "SET SESSION TRANSACTION ISOLATION LEVEL REPEATABLE READ" + _, err = conn.ExecContext(ctx, query) + if err != nil { + return nil, errors.Annotatef(err, "sql: %s", query) + } + } + query = "START TRANSACTION /*!40108 WITH CONSISTENT SNAPSHOT */" + _, err = conn.ExecContext(ctx, query) + if err != nil { + // Some MySQL Compatible databases like Vitess and MemSQL/SingleStore + // are newer than 4.1.8 (the version comment) but don't actually support + // `WITH CONSISTENT SNAPSHOT`. So retry without that if the statement fails. + query = "START TRANSACTION" + _, err = conn.ExecContext(ctx, query) + if err != nil { + return nil, errors.Annotatef(err, "sql: %s", query) + } + } + return conn, nil +} + +// buildSelectField returns the selecting fields' string(joined by comma(`,`)), +// and the number of writable fields. +func buildSelectField(tctx *tcontext.Context, db *BaseConn, dbName, tableName string, completeInsert bool) (string, int, error) { // revive:disable-line:flag-parameter + query := fmt.Sprintf("SHOW COLUMNS FROM `%s`.`%s`", escapeString(dbName), escapeString(tableName)) + results, err := db.QuerySQLWithColumns(tctx, []string{"FIELD", "EXTRA"}, query) + if err != nil { + return "", 0, err + } + availableFields := make([]string, 0) + hasGenerateColumn := false + for _, oneRow := range results { + fieldName, extra := oneRow[0], oneRow[1] + switch extra { + case "STORED GENERATED", "VIRTUAL GENERATED": + hasGenerateColumn = true + continue + } + availableFields = append(availableFields, wrapBackTicks(escapeString(fieldName))) + } + if completeInsert || hasGenerateColumn { + return strings.Join(availableFields, ","), len(availableFields), nil + } + return "*", len(availableFields), nil +} + +func buildWhereClauses(handleColNames []string, handleVals [][]string) []string { + if len(handleColNames) == 0 || len(handleVals) == 0 { + return nil + } + quotaCols := make([]string, len(handleColNames)) + for i, s := range handleColNames { + quotaCols[i] = fmt.Sprintf("`%s`", escapeString(s)) + } + where := make([]string, 0, len(handleVals)+1) + buf := &bytes.Buffer{} + buildCompareClause(buf, quotaCols, handleVals[0], less, false) + where = append(where, buf.String()) + buf.Reset() + for i := 1; i < len(handleVals); i++ { + low, up := handleVals[i-1], handleVals[i] + buildBetweenClause(buf, quotaCols, low, up) + where = append(where, buf.String()) + buf.Reset() + } + buildCompareClause(buf, quotaCols, handleVals[len(handleVals)-1], greater, true) + where = append(where, buf.String()) + buf.Reset() + return where +} + +// return greater than TableRangeScan where clause +// the result doesn't contain brackets +const ( + greater = '>' + 
less = '<' + equal = '=' +) + +// buildCompareClause build clause with specified bounds. Usually we will use the following two conditions: +// (compare, writeEqual) == (less, false), return quotaCols < bound clause. In other words, (-inf, bound) +// (compare, writeEqual) == (greater, true), return quotaCols >= bound clause. In other words, [bound, +inf) +func buildCompareClause(buf *bytes.Buffer, quotaCols []string, bound []string, compare byte, writeEqual bool) { // revive:disable-line:flag-parameter + for i, col := range quotaCols { + if i > 0 { + buf.WriteString("or(") + } + for j := 0; j < i; j++ { + buf.WriteString(quotaCols[j]) + buf.WriteByte(equal) + buf.WriteString(bound[j]) + buf.WriteString(" and ") + } + buf.WriteString(col) + buf.WriteByte(compare) + if writeEqual && i == len(quotaCols)-1 { + buf.WriteByte(equal) + } + buf.WriteString(bound[i]) + if i > 0 { + buf.WriteByte(')') + } else if i != len(quotaCols)-1 { + buf.WriteByte(' ') + } + } +} + +// getCommonLength returns the common length of low and up +func getCommonLength(low []string, up []string) int { + for i := range low { + if low[i] != up[i] { + return i + } + } + return len(low) +} + +// buildBetweenClause build clause in a specified table range. +// the result where clause will be low <= quotaCols < up. In other words, [low, up) +func buildBetweenClause(buf *bytes.Buffer, quotaCols []string, low []string, up []string) { + singleBetween := func(writeEqual bool) { + buf.WriteString(quotaCols[0]) + buf.WriteByte(greater) + if writeEqual { + buf.WriteByte(equal) + } + buf.WriteString(low[0]) + buf.WriteString(" and ") + buf.WriteString(quotaCols[0]) + buf.WriteByte(less) + buf.WriteString(up[0]) + } + // handle special cases with common prefix + commonLen := getCommonLength(low, up) + if commonLen > 0 { + // unexpected case for low == up, return empty result + if commonLen == len(low) { + buf.WriteString("false") + return + } + for i := 0; i < commonLen; i++ { + if i > 0 { + buf.WriteString(" and ") + } + buf.WriteString(quotaCols[i]) + buf.WriteByte(equal) + buf.WriteString(low[i]) + } + buf.WriteString(" and(") + defer buf.WriteByte(')') + quotaCols = quotaCols[commonLen:] + low = low[commonLen:] + up = up[commonLen:] + } + + // handle special cases with only one column + if len(quotaCols) == 1 { + singleBetween(true) + return + } + buf.WriteByte('(') + singleBetween(false) + buf.WriteString(")or(") + buf.WriteString(quotaCols[0]) + buf.WriteByte(equal) + buf.WriteString(low[0]) + buf.WriteString(" and(") + buildCompareClause(buf, quotaCols[1:], low[1:], greater, true) + buf.WriteString("))or(") + buf.WriteString(quotaCols[0]) + buf.WriteByte(equal) + buf.WriteString(up[0]) + buf.WriteString(" and(") + buildCompareClause(buf, quotaCols[1:], up[1:], less, false) + buf.WriteString("))") +} + +func buildOrderByClauseString(handleColNames []string) string { + if len(handleColNames) == 0 { + return "" + } + separator := "," + quotaCols := make([]string, len(handleColNames)) + for i, col := range handleColNames { + quotaCols[i] = fmt.Sprintf("`%s`", escapeString(col)) + } + return fmt.Sprintf("ORDER BY %s", strings.Join(quotaCols, separator)) +} + +func buildLockTablesSQL(allTables DatabaseTables, blockList map[string]map[string]any) string { + // ,``.`` READ has 11 bytes, "LOCK TABLE" has 10 bytes + estimatedCap := len(allTables)*11 + 10 + s := bytes.NewBuffer(make([]byte, 0, estimatedCap)) + n := false + for dbName, tables := range allTables { + escapedDBName := escapeString(dbName) + for _, table := range tables { + // 
Lock views will lock related tables. However, we won't dump data only the create sql of view, so we needn't lock view here. + // Besides, mydumper also only lock base table here. https://github.com/maxbube/mydumper/blob/1fabdf87e3007e5934227b504ad673ba3697946c/mydumper.c#L1568 + if table.Type != TableTypeBase { + continue + } + if blockTable, ok := blockList[dbName]; ok { + if _, ok := blockTable[table.Name]; ok { + continue + } + } + if !n { + fmt.Fprintf(s, "LOCK TABLES `%s`.`%s` READ", escapedDBName, escapeString(table.Name)) + n = true + } else { + fmt.Fprintf(s, ",`%s`.`%s` READ", escapedDBName, escapeString(table.Name)) + } + } + } + return s.String() +} + +type oneStrColumnTable struct { + data []string +} + +func (o *oneStrColumnTable) handleOneRow(rows *sql.Rows) error { + var str string + if err := rows.Scan(&str); err != nil { + return errors.Trace(err) + } + o.data = append(o.data, str) + return nil +} + +func simpleQuery(conn *sql.Conn, query string, handleOneRow func(*sql.Rows) error) error { + return simpleQueryWithArgs(context.Background(), conn, handleOneRow, query) +} + +func simpleQueryWithArgs(ctx context.Context, conn *sql.Conn, handleOneRow func(*sql.Rows) error, query string, args ...any) error { + var ( + rows *sql.Rows + err error + ) + if len(args) > 0 { + rows, err = conn.QueryContext(ctx, query, args...) + } else { + rows, err = conn.QueryContext(ctx, query) + } + if err != nil { + return errors.Annotatef(err, "sql: %s, args: %s", query, args) + } + defer rows.Close() + + for rows.Next() { + if err := handleOneRow(rows); err != nil { + rows.Close() + return errors.Annotatef(err, "sql: %s, args: %s", query, args) + } + } + return errors.Annotatef(rows.Err(), "sql: %s, args: %s", query, args) +} + +func pickupPossibleField(tctx *tcontext.Context, meta TableMeta, db *BaseConn) (string, error) { + // try using _tidb_rowid first + if meta.HasImplicitRowID() { + return "_tidb_rowid", nil + } + // try to use pk or uk + fieldName, err := getNumericIndex(tctx, db, meta) + if err != nil { + return "", err + } + + // if fieldName == "", there is no proper index + return fieldName, nil +} + +func estimateCount(tctx *tcontext.Context, dbName, tableName string, db *BaseConn, field string, conf *Config) uint64 { + var query string + if strings.TrimSpace(field) == "*" || strings.TrimSpace(field) == "" { + query = fmt.Sprintf("EXPLAIN SELECT * FROM `%s`.`%s`", escapeString(dbName), escapeString(tableName)) + } else { + query = fmt.Sprintf("EXPLAIN SELECT `%s` FROM `%s`.`%s`", escapeString(field), escapeString(dbName), escapeString(tableName)) + } + + if conf.Where != "" { + query += " WHERE " + query += conf.Where + } + + estRows := detectEstimateRows(tctx, db, query, []string{"rows", "estRows", "count"}) + /* tidb results field name is estRows (before 4.0.0-beta.2: count) + +-----------------------+----------+-----------+---------------------------------------------------------+ + | id | estRows | task | access object | operator info | + +-----------------------+----------+-----------+---------------------------------------------------------+ + | tablereader_5 | 10000.00 | root | | data:tablefullscan_4 | + | └─tablefullscan_4 | 10000.00 | cop[tikv] | table:a | table:a, keep order:false, stats:pseudo | + +-----------------------+----------+-----------+---------------------------------------------------------- + + mariadb result field name is rows + +------+-------------+---------+-------+---------------+------+---------+------+----------+-------------+ + | id | select_type | 
table | type | possible_keys | key | key_len | ref | rows | Extra | + +------+-------------+---------+-------+---------------+------+---------+------+----------+-------------+ + | 1 | SIMPLE | sbtest1 | index | NULL | k_1 | 4 | NULL | 15000049 | Using index | + +------+-------------+---------+-------+---------------+------+---------+------+----------+-------------+ + + mysql result field name is rows + +----+-------------+-------+------------+-------+---------------+-----------+---------+------+------+----------+-------------+ + | id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra | + +----+-------------+-------+------------+-------+---------------+-----------+---------+------+------+----------+-------------+ + | 1 | SIMPLE | t1 | NULL | index | NULL | multi_col | 10 | NULL | 5 | 100.00 | Using index | + +----+-------------+-------+------------+-------+---------------+-----------+---------+------+------+----------+-------------+ + */ + if estRows > 0 { + return estRows + } + return 0 +} + +func detectEstimateRows(tctx *tcontext.Context, db *BaseConn, query string, fieldNames []string) uint64 { + var ( + fieldIndex int + oneRow []sql.NullString + ) + err := db.QuerySQL(tctx, func(rows *sql.Rows) error { + columns, err := rows.Columns() + if err != nil { + return errors.Trace(err) + } + addr := make([]any, len(columns)) + oneRow = make([]sql.NullString, len(columns)) + fieldIndex = -1 + found: + for i := range oneRow { + for _, fieldName := range fieldNames { + if strings.EqualFold(columns[i], fieldName) { + fieldIndex = i + break found + } + } + } + if fieldIndex == -1 { + rows.Close() + return nil + } + + for i := range oneRow { + addr[i] = &oneRow[i] + } + return rows.Scan(addr...) + }, func() {}, query) + if err != nil || fieldIndex == -1 { + tctx.L().Info("can't estimate rows from db", + zap.String("query", query), zap.Int("fieldIndex", fieldIndex), log.ShortError(err)) + return 0 + } + + estRows, err := strconv.ParseFloat(oneRow[fieldIndex].String, 64) + if err != nil { + tctx.L().Info("can't get parse estimate rows from db", + zap.String("query", query), zap.String("estRows", oneRow[fieldIndex].String), log.ShortError(err)) + return 0 + } + return uint64(estRows) +} + +func parseSnapshotToTSO(pool *sql.DB, snapshot string) (uint64, error) { + snapshotTS, err := strconv.ParseUint(snapshot, 10, 64) + if err == nil { + return snapshotTS, nil + } + var tso sql.NullInt64 + query := "SELECT unix_timestamp(?)" + row := pool.QueryRow(query, snapshot) + err = row.Scan(&tso) + if err != nil { + return 0, errors.Annotatef(err, "sql: %s", strings.ReplaceAll(query, "?", fmt.Sprintf(`"%s"`, snapshot))) + } + if !tso.Valid { + return 0, errors.Errorf("snapshot %s format not supported. 
please use tso or '2006-01-02 15:04:05' format time", snapshot) + } + return (uint64(tso.Int64) << 18) * 1000, nil +} + +func buildWhereCondition(conf *Config, where string) string { + var query strings.Builder + separator := "WHERE" + leftBracket := " " + rightBracket := " " + if conf.Where != "" && where != "" { + leftBracket = " (" + rightBracket = ") " + } + if conf.Where != "" { + query.WriteString(separator) + query.WriteString(leftBracket) + query.WriteString(conf.Where) + query.WriteString(rightBracket) + separator = "AND" + } + if where != "" { + query.WriteString(separator) + query.WriteString(leftBracket) + query.WriteString(where) + query.WriteString(rightBracket) + } + return query.String() +} + +func escapeString(s string) string { + return strings.ReplaceAll(s, "`", "``") +} + +// GetPartitionNames get partition names from a specified table +func GetPartitionNames(tctx *tcontext.Context, db *BaseConn, schema, table string) (partitions []string, err error) { + partitions = make([]string, 0) + var partitionName sql.NullString + err = db.QuerySQL(tctx, func(rows *sql.Rows) error { + err := rows.Scan(&partitionName) + if err != nil { + return errors.Trace(err) + } + if partitionName.Valid { + partitions = append(partitions, partitionName.String) + } + return nil + }, func() { + partitions = partitions[:0] + }, "SELECT PARTITION_NAME from INFORMATION_SCHEMA.PARTITIONS WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?", schema, table) + return +} + +// GetPartitionTableIDs get partition tableIDs through histograms. +// SHOW STATS_HISTOGRAMS has db_name,table_name,partition_name but doesn't have partition id +// mysql.stats_histograms has partition_id but doesn't have db_name,table_name,partition_name +// So we combine the results from these two sqls to get partition ids for each table +// If UPDATE_TIME,DISTINCT_COUNT are equal, we assume these two records can represent one line. +// If histograms are not accurate or (UPDATE_TIME,DISTINCT_COUNT) has duplicate data, it's still fine. +// Because the possibility is low and the effect is that we will select more than one regions in one time, +// this will not affect the correctness of the dumping data and will not affect the memory usage much. +// This method is tricky, but no better way is found. 
+// Because TiDB v3.0.0's information_schema.partition table doesn't have partition name or partition id info +// return (dbName -> tbName -> partitionName -> partitionID, error) +func GetPartitionTableIDs(db *sql.Conn, tables map[string]map[string]struct{}) (map[string]map[string]map[string]int64, error) { + const ( + showStatsHistogramsSQL = "SHOW STATS_HISTOGRAMS" + selectStatsHistogramsSQL = "SELECT TABLE_ID,FROM_UNIXTIME(VERSION DIV 262144 DIV 1000,'%Y-%m-%d %H:%i:%s') AS UPDATE_TIME,DISTINCT_COUNT FROM mysql.stats_histograms" + ) + partitionIDs := make(map[string]map[string]map[string]int64, len(tables)) + rows, err := db.QueryContext(context.Background(), showStatsHistogramsSQL) + if err != nil { + return nil, errors.Annotatef(err, "sql: %s", showStatsHistogramsSQL) + } + results, err := GetSpecifiedColumnValuesAndClose(rows, "DB_NAME", "TABLE_NAME", "PARTITION_NAME", "UPDATE_TIME", "DISTINCT_COUNT") + if err != nil { + return nil, errors.Annotatef(err, "sql: %s", showStatsHistogramsSQL) + } + type partitionInfo struct { + dbName, tbName, partitionName string + } + saveMap := make(map[string]map[string]partitionInfo) + for _, oneRow := range results { + dbName, tbName, partitionName, updateTime, distinctCount := oneRow[0], oneRow[1], oneRow[2], oneRow[3], oneRow[4] + if len(partitionName) == 0 { + continue + } + if tbm, ok := tables[dbName]; ok { + if _, ok = tbm[tbName]; ok { + if _, ok = saveMap[updateTime]; !ok { + saveMap[updateTime] = make(map[string]partitionInfo) + } + saveMap[updateTime][distinctCount] = partitionInfo{ + dbName: dbName, + tbName: tbName, + partitionName: partitionName, + } + } + } + } + if len(saveMap) == 0 { + return map[string]map[string]map[string]int64{}, nil + } + err = simpleQuery(db, selectStatsHistogramsSQL, func(rows *sql.Rows) error { + var ( + tableID int64 + updateTime, distinctCount string + ) + err2 := rows.Scan(&tableID, &updateTime, &distinctCount) + if err2 != nil { + return errors.Trace(err2) + } + if mpt, ok := saveMap[updateTime]; ok { + if partition, ok := mpt[distinctCount]; ok { + dbName, tbName, partitionName := partition.dbName, partition.tbName, partition.partitionName + if _, ok := partitionIDs[dbName]; !ok { + partitionIDs[dbName] = make(map[string]map[string]int64) + } + if _, ok := partitionIDs[dbName][tbName]; !ok { + partitionIDs[dbName][tbName] = make(map[string]int64) + } + partitionIDs[dbName][tbName][partitionName] = tableID + } + } + return nil + }) + return partitionIDs, err +} + +// GetDBInfo get model.DBInfos from database sql interface. 
+// We need table_id to check whether a region belongs to this table +func GetDBInfo(db *sql.Conn, tables map[string]map[string]struct{}) ([]*model.DBInfo, error) { + const tableIDSQL = "SELECT TABLE_SCHEMA,TABLE_NAME,TIDB_TABLE_ID FROM INFORMATION_SCHEMA.TABLES ORDER BY TABLE_SCHEMA" + + schemas := make([]*model.DBInfo, 0, len(tables)) + var ( + tableSchema, tableName string + tidbTableID int64 + ) + partitionIDs, err := GetPartitionTableIDs(db, tables) + if err != nil { + return nil, err + } + err = simpleQuery(db, tableIDSQL, func(rows *sql.Rows) error { + err2 := rows.Scan(&tableSchema, &tableName, &tidbTableID) + if err2 != nil { + return errors.Trace(err2) + } + if tbm, ok := tables[tableSchema]; !ok { + return nil + } else if _, ok = tbm[tableName]; !ok { + return nil + } + last := len(schemas) - 1 + if last < 0 || schemas[last].Name.O != tableSchema { + dbInfo := &model.DBInfo{Name: model.CIStr{O: tableSchema}} + dbInfo.Deprecated.Tables = make([]*model.TableInfo, 0, len(tables[tableSchema])) + schemas = append(schemas, dbInfo) + last++ + } + var partition *model.PartitionInfo + if tbm, ok := partitionIDs[tableSchema]; ok { + if ptm, ok := tbm[tableName]; ok { + partition = &model.PartitionInfo{Definitions: make([]model.PartitionDefinition, 0, len(ptm))} + for partitionName, partitionID := range ptm { + partition.Definitions = append(partition.Definitions, model.PartitionDefinition{ + ID: partitionID, + Name: model.CIStr{O: partitionName}, + }) + } + } + } + schemas[last].Deprecated.Tables = append(schemas[last].Deprecated.Tables, &model.TableInfo{ + ID: tidbTableID, + Name: model.CIStr{O: tableName}, + Partition: partition, + }) + return nil + }) + return schemas, err +} + +// GetRegionInfos get region info including regionID, start key, end key from database sql interface. +// start key, end key includes information to help split table +func GetRegionInfos(db *sql.Conn) (*pd.RegionsInfo, error) { + const tableRegionSQL = "SELECT REGION_ID,START_KEY,END_KEY FROM INFORMATION_SCHEMA.TIKV_REGION_STATUS ORDER BY START_KEY;" + var ( + regionID int64 + startKey, endKey string + ) + regionsInfo := &pd.RegionsInfo{Regions: make([]pd.RegionInfo, 0)} + err := simpleQuery(db, tableRegionSQL, func(rows *sql.Rows) error { + err := rows.Scan(®ionID, &startKey, &endKey) + if err != nil { + return errors.Trace(err) + } + regionsInfo.Regions = append(regionsInfo.Regions, pd.RegionInfo{ + ID: regionID, + StartKey: startKey, + EndKey: endKey, + }) + return nil + }) + return regionsInfo, err +} + +// GetCharsetAndDefaultCollation gets charset and default collation map. +func GetCharsetAndDefaultCollation(ctx context.Context, db *sql.Conn) (map[string]string, error) { + charsetAndDefaultCollation := make(map[string]string) + query := "SHOW CHARACTER SET" + + // Show an example. 
+ /* + mysql> SHOW CHARACTER SET; + +----------+---------------------------------+---------------------+--------+ + | Charset | Description | Default collation | Maxlen | + +----------+---------------------------------+---------------------+--------+ + | armscii8 | ARMSCII-8 Armenian | armscii8_general_ci | 1 | + | ascii | US ASCII | ascii_general_ci | 1 | + | big5 | Big5 Traditional Chinese | big5_chinese_ci | 2 | + | binary | Binary pseudo charset | binary | 1 | + | cp1250 | Windows Central European | cp1250_general_ci | 1 | + | cp1251 | Windows Cyrillic | cp1251_general_ci | 1 | + +----------+---------------------------------+---------------------+--------+ + */ + + rows, err := db.QueryContext(ctx, query) + if err != nil { + return nil, errors.Annotatef(err, "sql: %s", query) + } + + defer rows.Close() + for rows.Next() { + var charset, description, collation string + var maxlen int + if scanErr := rows.Scan(&charset, &description, &collation, &maxlen); scanErr != nil { + return nil, errors.Annotatef(err, "sql: %s", query) + } + charsetAndDefaultCollation[strings.ToLower(charset)] = collation + } + if err = rows.Close(); err != nil { + return nil, errors.Annotatef(err, "sql: %s", query) + } + if err = rows.Err(); err != nil { + return nil, errors.Annotatef(err, "sql: %s", query) + } + return charsetAndDefaultCollation, err +} diff --git a/dumpling/export/status.go b/dumpling/export/status.go index 0a861f4c40677..9c67964f69602 100644 --- a/dumpling/export/status.go +++ b/dumpling/export/status.go @@ -18,11 +18,11 @@ const logProgressTick = 2 * time.Minute func (d *Dumper) runLogProgress(tctx *tcontext.Context) { logProgressTicker := time.NewTicker(logProgressTick) - failpoint.Inject("EnableLogProgress", func() { + if _, _err_ := failpoint.Eval(_curpkg_("EnableLogProgress")); _err_ == nil { logProgressTicker.Stop() logProgressTicker = time.NewTicker(time.Duration(1) * time.Second) tctx.L().Debug("EnableLogProgress") - }) + } lastCheckpoint := time.Now() lastBytes := float64(0) defer logProgressTicker.Stop() diff --git a/dumpling/export/status.go__failpoint_stash__ b/dumpling/export/status.go__failpoint_stash__ new file mode 100644 index 0000000000000..0a861f4c40677 --- /dev/null +++ b/dumpling/export/status.go__failpoint_stash__ @@ -0,0 +1,144 @@ +// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. 
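+
+// This stash file keeps the original failpoint.Inject-based source of status.go; the
+// rewritten status.go in the hunk above calls failpoint.Eval through _curpkg_ instead
+// (the binding__failpoint_binding__.go files added elsewhere in this patch supply _curpkg_).
+// A minimal sketch of the rewrite pattern, using the EnableLogProgress failpoint shown
+// earlier; the tooling that performs the rewrite is assumed and not shown here:
+//
+//	// marker form (kept in this stash file)
+//	failpoint.Inject("EnableLogProgress", func() {
+//		logProgressTicker.Stop()
+//	})
+//
+//	// generated form (in the rewritten status.go)
+//	if _, _err_ := failpoint.Eval(_curpkg_("EnableLogProgress")); _err_ == nil {
+//		logProgressTicker.Stop()
+//	}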
+ +package export + +import ( + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/docker/go-units" + "github.com/pingcap/failpoint" + tcontext "github.com/pingcap/tidb/dumpling/context" + "go.uber.org/zap" +) + +const logProgressTick = 2 * time.Minute + +func (d *Dumper) runLogProgress(tctx *tcontext.Context) { + logProgressTicker := time.NewTicker(logProgressTick) + failpoint.Inject("EnableLogProgress", func() { + logProgressTicker.Stop() + logProgressTicker = time.NewTicker(time.Duration(1) * time.Second) + tctx.L().Debug("EnableLogProgress") + }) + lastCheckpoint := time.Now() + lastBytes := float64(0) + defer logProgressTicker.Stop() + for { + select { + case <-tctx.Done(): + tctx.L().Debug("stopping log progress") + return + case <-logProgressTicker.C: + nanoseconds := float64(time.Since(lastCheckpoint).Nanoseconds()) + s := d.GetStatus() + tctx.L().Info("progress", + zap.String("tables", fmt.Sprintf("%.0f/%.0f (%.1f%%)", s.CompletedTables, float64(d.totalTables), s.CompletedTables/float64(d.totalTables)*100)), + zap.String("finished rows", fmt.Sprintf("%.0f", s.FinishedRows)), + zap.String("estimate total rows", fmt.Sprintf("%.0f", s.EstimateTotalRows)), + zap.String("finished size", units.HumanSize(s.FinishedBytes)), + zap.Float64("average speed(MiB/s)", (s.FinishedBytes-lastBytes)/(1048576e-9*nanoseconds)), + zap.Float64("recent speed bps", s.CurrentSpeedBPS), + zap.String("chunks progress", s.Progress), + ) + + lastCheckpoint = time.Now() + lastBytes = s.FinishedBytes + } + } +} + +// DumpStatus is the status of dumping. +type DumpStatus struct { + CompletedTables float64 + FinishedBytes float64 + FinishedRows float64 + EstimateTotalRows float64 + TotalTables int64 + CurrentSpeedBPS float64 + Progress string +} + +// GetStatus returns the status of dumping by reading metrics. +func (d *Dumper) GetStatus() *DumpStatus { + ret := &DumpStatus{} + ret.TotalTables = atomic.LoadInt64(&d.totalTables) + ret.CompletedTables = ReadCounter(d.metrics.finishedTablesCounter) + ret.FinishedBytes = ReadGauge(d.metrics.finishedSizeGauge) + ret.FinishedRows = ReadGauge(d.metrics.finishedRowsGauge) + ret.EstimateTotalRows = ReadCounter(d.metrics.estimateTotalRowsCounter) + ret.CurrentSpeedBPS = d.speedRecorder.GetSpeed(ret.FinishedBytes) + if d.metrics.progressReady.Load() { + // chunks will be zero when upstream has no data + if d.metrics.totalChunks.Load() == 0 { + ret.Progress = "100 %" + return ret + } + progress := float64(d.metrics.completedChunks.Load()) / float64(d.metrics.totalChunks.Load()) + if progress > 1 { + ret.Progress = "100 %" + d.L().Warn("completedChunks is greater than totalChunks", zap.Int64("completedChunks", d.metrics.completedChunks.Load()), zap.Int64("totalChunks", d.metrics.totalChunks.Load())) + } else { + ret.Progress = fmt.Sprintf("%5.2f %%", progress*100) + } + } + return ret +} + +func calculateTableCount(m DatabaseTables) int { + cnt := 0 + for _, tables := range m { + for _, table := range tables { + if table.Type == TableTypeBase { + cnt++ + } + } + } + return cnt +} + +// SpeedRecorder record the finished bytes and calculate its speed. +type SpeedRecorder struct { + mu sync.Mutex + lastFinished float64 + lastUpdateTime time.Time + speedBPS float64 +} + +// NewSpeedRecorder new a SpeedRecorder. +func NewSpeedRecorder() *SpeedRecorder { + return &SpeedRecorder{ + lastUpdateTime: time.Now(), + } +} + +// GetSpeed calculate status speed. 
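+// A worked example with made-up numbers: if lastFinished is 100 MiB, finished is 160 MiB,
+// and 30 seconds have elapsed since lastUpdateTime, the method returns
+// (160 MiB - 100 MiB) / 30 s = 2 MiB/s (as a plain bytes-per-second float) and caches it in speedBPS.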
+func (s *SpeedRecorder) GetSpeed(finished float64) float64 { + s.mu.Lock() + defer s.mu.Unlock() + + if finished <= s.lastFinished { + // for finished bytes does not get forwarded, use old speed to avoid + // display zero. We may find better strategy in future. + return s.speedBPS + } + + now := time.Now() + elapsed := now.Sub(s.lastUpdateTime).Seconds() + if elapsed == 0 { + // if time is short, return last speed + return s.speedBPS + } + currentSpeed := (finished - s.lastFinished) / elapsed + if currentSpeed == 0 { + currentSpeed = 1 + } + + s.lastFinished = finished + s.lastUpdateTime = now + s.speedBPS = currentSpeed + + return currentSpeed +} diff --git a/dumpling/export/writer_util.go b/dumpling/export/writer_util.go index e7ed2de2e611a..d6a0c3c69d609 100644 --- a/dumpling/export/writer_util.go +++ b/dumpling/export/writer_util.go @@ -239,10 +239,10 @@ func WriteInsert( } counter++ wp.AddFileSize(uint64(bf.Len()-lastBfSize) + 2) // 2 is for ",\n" and ";\n" - failpoint.Inject("ChaosBrokenWriterConn", func(_ failpoint.Value) { - failpoint.Return(0, errors.New("connection is closed")) - }) - failpoint.Inject("AtEveryRow", nil) + if _, _err_ := failpoint.Eval(_curpkg_("ChaosBrokenWriterConn")); _err_ == nil { + return 0, errors.New("connection is closed") + } + failpoint.Eval(_curpkg_("AtEveryRow")) fileRowIter.Next() shouldSwitch := wp.ShouldSwitchStatement() @@ -464,9 +464,9 @@ func buildFileWriter(tctx *tcontext.Context, s storage.ExternalStorage, fileName tctx.L().Debug("opened file", zap.String("path", fullPath)) tearDownRoutine := func(ctx context.Context) error { err := writer.Close(ctx) - failpoint.Inject("FailToCloseMetaFile", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("FailToCloseMetaFile")); _err_ == nil { err = errors.New("injected error: fail to close meta file") - }) + } if err == nil { return nil } @@ -507,9 +507,9 @@ func buildInterceptFileWriter(pCtx *tcontext.Context, s storage.ExternalStorage, } pCtx.L().Debug("tear down lazy file writer...", zap.String("path", fullPath)) err := writer.Close(ctx) - failpoint.Inject("FailToCloseDataFile", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("FailToCloseDataFile")); _err_ == nil { err = errors.New("injected error: fail to close data file") - }) + } if err != nil { pCtx.L().Warn("fail to close file", zap.String("path", fullPath), diff --git a/dumpling/export/writer_util.go__failpoint_stash__ b/dumpling/export/writer_util.go__failpoint_stash__ new file mode 100644 index 0000000000000..e7ed2de2e611a --- /dev/null +++ b/dumpling/export/writer_util.go__failpoint_stash__ @@ -0,0 +1,674 @@ +// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. 
+ +package export + +import ( + "bytes" + "context" + "fmt" + "io" + "strings" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/br/pkg/summary" + tcontext "github.com/pingcap/tidb/dumpling/context" + "github.com/pingcap/tidb/dumpling/log" + "github.com/prometheus/client_golang/prometheus" + "go.uber.org/zap" +) + +const lengthLimit = 1048576 + +var pool = sync.Pool{New: func() any { + return &bytes.Buffer{} +}} + +type writerPipe struct { + input chan *bytes.Buffer + closed chan struct{} + errCh chan error + metrics *metrics + labels prometheus.Labels + + finishedFileSize uint64 + currentFileSize uint64 + currentStatementSize uint64 + + fileSizeLimit uint64 + statementSizeLimit uint64 + + w storage.ExternalFileWriter +} + +func newWriterPipe( + w storage.ExternalFileWriter, + fileSizeLimit, + statementSizeLimit uint64, + metrics *metrics, + labels prometheus.Labels, +) *writerPipe { + return &writerPipe{ + input: make(chan *bytes.Buffer, 8), + closed: make(chan struct{}), + errCh: make(chan error, 1), + w: w, + metrics: metrics, + labels: labels, + + currentFileSize: 0, + currentStatementSize: 0, + fileSizeLimit: fileSizeLimit, + statementSizeLimit: statementSizeLimit, + } +} + +func (b *writerPipe) Run(tctx *tcontext.Context) { + defer close(b.closed) + var errOccurs bool + receiveChunkTime := time.Now() + for { + select { + case s, ok := <-b.input: + if !ok { + return + } + if errOccurs { + continue + } + ObserveHistogram(b.metrics.receiveWriteChunkTimeHistogram, time.Since(receiveChunkTime).Seconds()) + receiveChunkTime = time.Now() + err := writeBytes(tctx, b.w, s.Bytes()) + ObserveHistogram(b.metrics.writeTimeHistogram, time.Since(receiveChunkTime).Seconds()) + AddGauge(b.metrics.finishedSizeGauge, float64(s.Len())) + b.finishedFileSize += uint64(s.Len()) + s.Reset() + pool.Put(s) + if err != nil { + errOccurs = true + b.errCh <- err + } + receiveChunkTime = time.Now() + case <-tctx.Done(): + return + } + } +} + +func (b *writerPipe) AddFileSize(fileSize uint64) { + b.currentFileSize += fileSize + b.currentStatementSize += fileSize +} + +func (b *writerPipe) Error() error { + select { + case err := <-b.errCh: + return err + default: + return nil + } +} + +func (b *writerPipe) ShouldSwitchFile() bool { + return b.fileSizeLimit != UnspecifiedSize && b.currentFileSize >= b.fileSizeLimit +} + +func (b *writerPipe) ShouldSwitchStatement() bool { + return (b.fileSizeLimit != UnspecifiedSize && b.currentFileSize >= b.fileSizeLimit) || + (b.statementSizeLimit != UnspecifiedSize && b.currentStatementSize >= b.statementSizeLimit) +} + +// WriteMeta writes MetaIR to a storage.ExternalFileWriter +func WriteMeta(tctx *tcontext.Context, meta MetaIR, w storage.ExternalFileWriter) error { + tctx.L().Debug("start dumping meta data", zap.String("target", meta.TargetName())) + + specCmtIter := meta.SpecialComments() + for specCmtIter.HasNext() { + if err := write(tctx, w, fmt.Sprintf("%s\n", specCmtIter.Next())); err != nil { + return err + } + } + + if err := write(tctx, w, meta.MetaSQL()); err != nil { + return err + } + + tctx.L().Debug("finish dumping meta data", zap.String("target", meta.TargetName())) + return nil +} + +// WriteInsert writes TableDataIR to a storage.ExternalFileWriter in sql type +func WriteInsert( + pCtx *tcontext.Context, + cfg *Config, + meta TableMeta, + tblIR TableDataIR, + w storage.ExternalFileWriter, + metrics *metrics, +) (n uint64, err error) { + fileRowIter := 
tblIR.Rows() + if !fileRowIter.HasNext() { + return 0, fileRowIter.Error() + } + + bf := pool.Get().(*bytes.Buffer) + if bfCap := bf.Cap(); bfCap < lengthLimit { + bf.Grow(lengthLimit - bfCap) + } + + wp := newWriterPipe(w, cfg.FileSize, cfg.StatementSize, metrics, cfg.Labels) + + // use context.Background here to make sure writerPipe can deplete all the chunks in pipeline + ctx, cancel := tcontext.Background().WithLogger(pCtx.L()).WithCancel() + var wg sync.WaitGroup + wg.Add(1) + go func() { + wp.Run(ctx) + wg.Done() + }() + defer func() { + cancel() + wg.Wait() + }() + + specCmtIter := meta.SpecialComments() + for specCmtIter.HasNext() { + bf.WriteString(specCmtIter.Next()) + bf.WriteByte('\n') + } + wp.currentFileSize += uint64(bf.Len()) + + var ( + insertStatementPrefix string + row = MakeRowReceiver(meta.ColumnTypes()) + counter uint64 + lastCounter uint64 + escapeBackslash = cfg.EscapeBackslash + ) + + defer func() { + if err != nil { + pCtx.L().Warn("fail to dumping table(chunk), will revert some metrics and start a retry if possible", + zap.String("database", meta.DatabaseName()), + zap.String("table", meta.TableName()), + zap.Uint64("finished rows", lastCounter), + zap.Uint64("finished size", wp.finishedFileSize), + log.ShortError(err)) + SubGauge(metrics.finishedRowsGauge, float64(lastCounter)) + SubGauge(metrics.finishedSizeGauge, float64(wp.finishedFileSize)) + } else { + pCtx.L().Debug("finish dumping table(chunk)", + zap.String("database", meta.DatabaseName()), + zap.String("table", meta.TableName()), + zap.Uint64("finished rows", counter), + zap.Uint64("finished size", wp.finishedFileSize)) + summary.CollectSuccessUnit(summary.TotalBytes, 1, wp.finishedFileSize) + summary.CollectSuccessUnit("total rows", 1, counter) + } + }() + + selectedField := meta.SelectedField() + + // if has generated column + if selectedField != "" && selectedField != "*" { + insertStatementPrefix = fmt.Sprintf("INSERT INTO %s (%s) VALUES\n", + wrapBackTicks(escapeString(meta.TableName())), selectedField) + } else { + insertStatementPrefix = fmt.Sprintf("INSERT INTO %s VALUES\n", + wrapBackTicks(escapeString(meta.TableName()))) + } + insertStatementPrefixLen := uint64(len(insertStatementPrefix)) + + for fileRowIter.HasNext() { + wp.currentStatementSize = 0 + bf.WriteString(insertStatementPrefix) + wp.AddFileSize(insertStatementPrefixLen) + + for fileRowIter.HasNext() { + lastBfSize := bf.Len() + if selectedField != "" { + if err = fileRowIter.Decode(row); err != nil { + return counter, errors.Trace(err) + } + row.WriteToBuffer(bf, escapeBackslash) + } else { + bf.WriteString("()") + } + counter++ + wp.AddFileSize(uint64(bf.Len()-lastBfSize) + 2) // 2 is for ",\n" and ";\n" + failpoint.Inject("ChaosBrokenWriterConn", func(_ failpoint.Value) { + failpoint.Return(0, errors.New("connection is closed")) + }) + failpoint.Inject("AtEveryRow", nil) + + fileRowIter.Next() + shouldSwitch := wp.ShouldSwitchStatement() + if fileRowIter.HasNext() && !shouldSwitch { + bf.WriteString(",\n") + } else { + bf.WriteString(";\n") + } + if bf.Len() >= lengthLimit { + select { + case <-pCtx.Done(): + return counter, pCtx.Err() + case err = <-wp.errCh: + return counter, err + case wp.input <- bf: + bf = pool.Get().(*bytes.Buffer) + if bfCap := bf.Cap(); bfCap < lengthLimit { + bf.Grow(lengthLimit - bfCap) + } + AddGauge(metrics.finishedRowsGauge, float64(counter-lastCounter)) + lastCounter = counter + } + } + + if shouldSwitch { + break + } + } + if wp.ShouldSwitchFile() { + break + } + } + if bf.Len() > 0 { + wp.input <- 
bf + } + close(wp.input) + <-wp.closed + AddGauge(metrics.finishedRowsGauge, float64(counter-lastCounter)) + lastCounter = counter + if err = fileRowIter.Error(); err != nil { + return counter, errors.Trace(err) + } + return counter, wp.Error() +} + +// WriteInsertInCsv writes TableDataIR to a storage.ExternalFileWriter in csv type +func WriteInsertInCsv( + pCtx *tcontext.Context, + cfg *Config, + meta TableMeta, + tblIR TableDataIR, + w storage.ExternalFileWriter, + metrics *metrics, +) (n uint64, err error) { + fileRowIter := tblIR.Rows() + if !fileRowIter.HasNext() { + return 0, fileRowIter.Error() + } + + bf := pool.Get().(*bytes.Buffer) + if bfCap := bf.Cap(); bfCap < lengthLimit { + bf.Grow(lengthLimit - bfCap) + } + + wp := newWriterPipe(w, cfg.FileSize, UnspecifiedSize, metrics, cfg.Labels) + opt := &csvOption{ + nullValue: cfg.CsvNullValue, + separator: []byte(cfg.CsvSeparator), + delimiter: []byte(cfg.CsvDelimiter), + lineTerminator: []byte(cfg.CsvLineTerminator), + binaryFormat: DialectBinaryFormatMap[cfg.CsvOutputDialect], + } + + // use context.Background here to make sure writerPipe can deplete all the chunks in pipeline + ctx, cancel := tcontext.Background().WithLogger(pCtx.L()).WithCancel() + var wg sync.WaitGroup + wg.Add(1) + go func() { + wp.Run(ctx) + wg.Done() + }() + defer func() { + cancel() + wg.Wait() + }() + + var ( + row = MakeRowReceiver(meta.ColumnTypes()) + counter uint64 + lastCounter uint64 + escapeBackslash = cfg.EscapeBackslash + selectedFields = meta.SelectedField() + ) + + defer func() { + if err != nil { + pCtx.L().Warn("fail to dumping table(chunk), will revert some metrics and start a retry if possible", + zap.String("database", meta.DatabaseName()), + zap.String("table", meta.TableName()), + zap.Uint64("finished rows", lastCounter), + zap.Uint64("finished size", wp.finishedFileSize), + log.ShortError(err)) + SubGauge(metrics.finishedRowsGauge, float64(lastCounter)) + SubGauge(metrics.finishedSizeGauge, float64(wp.finishedFileSize)) + } else { + pCtx.L().Debug("finish dumping table(chunk)", + zap.String("database", meta.DatabaseName()), + zap.String("table", meta.TableName()), + zap.Uint64("finished rows", counter), + zap.Uint64("finished size", wp.finishedFileSize)) + summary.CollectSuccessUnit(summary.TotalBytes, 1, wp.finishedFileSize) + summary.CollectSuccessUnit("total rows", 1, counter) + } + }() + + if !cfg.NoHeader && len(meta.ColumnNames()) != 0 && selectedFields != "" { + for i, col := range meta.ColumnNames() { + bf.Write(opt.delimiter) + escapeCSV([]byte(col), bf, escapeBackslash, opt) + bf.Write(opt.delimiter) + if i != len(meta.ColumnTypes())-1 { + bf.Write(opt.separator) + } + } + bf.Write(opt.lineTerminator) + } + wp.currentFileSize += uint64(bf.Len()) + + for fileRowIter.HasNext() { + lastBfSize := bf.Len() + if selectedFields != "" { + if err = fileRowIter.Decode(row); err != nil { + return counter, errors.Trace(err) + } + row.WriteToBufferInCsv(bf, escapeBackslash, opt) + } + counter++ + wp.currentFileSize += uint64(bf.Len()-lastBfSize) + 1 // 1 is for "\n" + + bf.Write(opt.lineTerminator) + if bf.Len() >= lengthLimit { + select { + case <-pCtx.Done(): + return counter, pCtx.Err() + case err = <-wp.errCh: + return counter, err + case wp.input <- bf: + bf = pool.Get().(*bytes.Buffer) + if bfCap := bf.Cap(); bfCap < lengthLimit { + bf.Grow(lengthLimit - bfCap) + } + AddGauge(metrics.finishedRowsGauge, float64(counter-lastCounter)) + lastCounter = counter + } + } + + fileRowIter.Next() + if wp.ShouldSwitchFile() { + break + } + } + + 
if bf.Len() > 0 { + wp.input <- bf + } + close(wp.input) + <-wp.closed + AddGauge(metrics.finishedRowsGauge, float64(counter-lastCounter)) + lastCounter = counter + if err = fileRowIter.Error(); err != nil { + return counter, errors.Trace(err) + } + return counter, wp.Error() +} + +func write(tctx *tcontext.Context, writer storage.ExternalFileWriter, str string) error { + _, err := writer.Write(tctx, []byte(str)) + if err != nil { + // str might be very long, only output the first 200 chars + outputLength := len(str) + if outputLength >= 200 { + outputLength = 200 + } + tctx.L().Warn("fail to write", + zap.String("heading 200 characters", str[:outputLength]), + zap.Error(err)) + } + return errors.Trace(err) +} + +func writeBytes(tctx *tcontext.Context, writer storage.ExternalFileWriter, p []byte) error { + _, err := writer.Write(tctx, p) + if err != nil { + // str might be very long, only output the first 200 chars + outputLength := len(p) + if outputLength >= 200 { + outputLength = 200 + } + tctx.L().Warn("fail to write", + zap.ByteString("heading 200 characters", p[:outputLength]), + zap.Error(err)) + if strings.Contains(err.Error(), "Part number must be an integer between 1 and 10000") { + err = errors.Annotate(err, "workaround: dump file exceeding 50GB, please specify -F=256MB -r=200000 to avoid this problem") + } + } + return errors.Trace(err) +} + +func buildFileWriter(tctx *tcontext.Context, s storage.ExternalStorage, fileName string, compressType storage.CompressType) (storage.ExternalFileWriter, func(ctx context.Context) error, error) { + fileName += compressFileSuffix(compressType) + fullPath := s.URI() + "/" + fileName + writer, err := storage.WithCompression(s, compressType, storage.DecompressConfig{}).Create(tctx, fileName, nil) + if err != nil { + tctx.L().Warn("fail to open file", + zap.String("path", fullPath), + zap.Error(err)) + return nil, nil, errors.Trace(err) + } + tctx.L().Debug("opened file", zap.String("path", fullPath)) + tearDownRoutine := func(ctx context.Context) error { + err := writer.Close(ctx) + failpoint.Inject("FailToCloseMetaFile", func(_ failpoint.Value) { + err = errors.New("injected error: fail to close meta file") + }) + if err == nil { + return nil + } + err = errors.Trace(err) + tctx.L().Warn("fail to close file", + zap.String("path", fullPath), + zap.Error(err)) + return err + } + return writer, tearDownRoutine, nil +} + +func buildInterceptFileWriter(pCtx *tcontext.Context, s storage.ExternalStorage, fileName string, compressType storage.CompressType) (storage.ExternalFileWriter, func(context.Context) error) { + fileName += compressFileSuffix(compressType) + var writer storage.ExternalFileWriter + fullPath := s.URI() + "/" + fileName + fileWriter := &InterceptFileWriter{} + initRoutine := func() error { + // use separated context pCtx here to make sure context used in ExternalFile won't be canceled before close, + // which will cause a context canceled error when closing gcs's Writer + w, err := storage.WithCompression(s, compressType, storage.DecompressConfig{}).Create(pCtx, fileName, nil) + if err != nil { + pCtx.L().Warn("fail to open file", + zap.String("path", fullPath), + zap.Error(err)) + return newWriterError(err) + } + writer = w + pCtx.L().Debug("opened file", zap.String("path", fullPath)) + fileWriter.ExternalFileWriter = writer + return nil + } + fileWriter.initRoutine = initRoutine + + tearDownRoutine := func(ctx context.Context) error { + if writer == nil { + return nil + } + pCtx.L().Debug("tear down lazy file writer...", 
zap.String("path", fullPath)) + err := writer.Close(ctx) + failpoint.Inject("FailToCloseDataFile", func(_ failpoint.Value) { + err = errors.New("injected error: fail to close data file") + }) + if err != nil { + pCtx.L().Warn("fail to close file", + zap.String("path", fullPath), + zap.Error(err)) + } + return err + } + return fileWriter, tearDownRoutine +} + +// LazyStringWriter is an interceptor of io.StringWriter, +// will lazily create file the first time StringWriter need to write something. +type LazyStringWriter struct { + initRoutine func() error + sync.Once + io.StringWriter + err error +} + +// WriteString implements io.StringWriter. It check whether writer has written something and init a file at first time +func (l *LazyStringWriter) WriteString(str string) (int, error) { + l.Do(func() { l.err = l.initRoutine() }) + if l.err != nil { + return 0, errors.Errorf("open file error: %s", l.err.Error()) + } + return l.StringWriter.WriteString(str) +} + +type writerError struct { + error +} + +func (e *writerError) Error() string { + return e.error.Error() +} + +func newWriterError(err error) error { + if err == nil { + return nil + } + return &writerError{error: err} +} + +// InterceptFileWriter is an interceptor of os.File, +// tracking whether a StringWriter has written something. +type InterceptFileWriter struct { + storage.ExternalFileWriter + sync.Once + SomethingIsWritten bool + + initRoutine func() error + err error +} + +// Write implements storage.ExternalFileWriter.Write. It check whether writer has written something and init a file at first time +func (w *InterceptFileWriter) Write(ctx context.Context, p []byte) (int, error) { + w.Do(func() { w.err = w.initRoutine() }) + if len(p) > 0 { + w.SomethingIsWritten = true + } + if w.err != nil { + return 0, errors.Annotate(w.err, "open file error") + } + n, err := w.ExternalFileWriter.Write(ctx, p) + return n, newWriterError(err) +} + +// Close closes the InterceptFileWriter +func (w *InterceptFileWriter) Close(ctx context.Context) error { + return w.ExternalFileWriter.Close(ctx) +} + +func wrapBackTicks(identifier string) string { + if !strings.HasPrefix(identifier, "`") && !strings.HasSuffix(identifier, "`") { + return wrapStringWith(identifier, "`") + } + return identifier +} + +func wrapStringWith(str string, wrapper string) string { + return fmt.Sprintf("%s%s%s", wrapper, str, wrapper) +} + +func compressFileSuffix(compressType storage.CompressType) string { + switch compressType { + case storage.NoCompression: + return "" + case storage.Gzip: + return ".gz" + case storage.Snappy: + return ".snappy" + case storage.Zstd: + return ".zst" + default: + return "" + } +} + +// FileFormat is the format that output to file. Currently we support SQL text and CSV file format. +type FileFormat int32 + +const ( + // FileFormatUnknown indicates the given file type is unknown + FileFormatUnknown FileFormat = iota + // FileFormatSQLText indicates the given file type is sql type + FileFormatSQLText + // FileFormatCSV indicates the given file type is csv type + FileFormatCSV +) + +const ( + // FileFormatSQLTextString indicates the string/suffix of sql type file + FileFormatSQLTextString = "sql" + // FileFormatCSVString indicates the string/suffix of csv type file + FileFormatCSVString = "csv" +) + +// String implement Stringer.String method. 
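+// For example, FileFormatSQLText.String() returns "SQL", FileFormatCSV.String() returns "CSV",
+// and any other value falls back to "unknown".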
+func (f FileFormat) String() string { + switch f { + case FileFormatSQLText: + return strings.ToUpper(FileFormatSQLTextString) + case FileFormatCSV: + return strings.ToUpper(FileFormatCSVString) + default: + return "unknown" + } +} + +// Extension returns the extension for specific format. +// +// text -> "sql" +// csv -> "csv" +func (f FileFormat) Extension() string { + switch f { + case FileFormatSQLText: + return FileFormatSQLTextString + case FileFormatCSV: + return FileFormatCSVString + default: + return "unknown_format" + } +} + +// WriteInsert writes TableDataIR to a storage.ExternalFileWriter in sql/csv type +func (f FileFormat) WriteInsert( + pCtx *tcontext.Context, + cfg *Config, + meta TableMeta, + tblIR TableDataIR, + w storage.ExternalFileWriter, + metrics *metrics, +) (uint64, error) { + switch f { + case FileFormatSQLText: + return WriteInsert(pCtx, cfg, meta, tblIR, w, metrics) + case FileFormatCSV: + return WriteInsertInCsv(pCtx, cfg, meta, tblIR, w, metrics) + default: + return 0, errors.Errorf("unknown file format") + } +} diff --git a/lightning/pkg/importer/binding__failpoint_binding__.go b/lightning/pkg/importer/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..62be625525776 --- /dev/null +++ b/lightning/pkg/importer/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package importer + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/lightning/pkg/importer/chunk_process.go b/lightning/pkg/importer/chunk_process.go index 802e7c15d4b6c..fff2ef14cefd1 100644 --- a/lightning/pkg/importer/chunk_process.go +++ b/lightning/pkg/importer/chunk_process.go @@ -480,9 +480,9 @@ func (cr *chunkProcessor) encodeLoop( kvPacket = append(kvPacket, deliveredKVs{kvs: kvs, columns: filteredColumns, offset: newOffset, rowID: rowID, realOffset: newScannedOffset}) kvSize += kvs.Size() - failpoint.Inject("mock-kv-size", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mock-kv-size")); _err_ == nil { kvSize += uint64(val.(int)) - }) + } // pebble cannot allow > 4.0G kv in one batch. // we will meet pebble panic when import sql file and each kv has the size larger than 4G / maxKvPairsCnt. // so add this check. @@ -735,25 +735,25 @@ func (cr *chunkProcessor) deliverLoop( // No need to save checkpoint if nothing was delivered. dataSynced = cr.maybeSaveCheckpoint(rc, t, engineID, cr.chunk, dataEngine, indexEngine) } - failpoint.Inject("SlowDownWriteRows", func() { + if _, _err_ := failpoint.Eval(_curpkg_("SlowDownWriteRows")); _err_ == nil { deliverLogger.Warn("Slowed down write rows") finished := rc.status.FinishedFileSize.Load() total := rc.status.TotalFileSize.Load() deliverLogger.Warn("PrintStatus Failpoint", zap.Int64("finished", finished), zap.Int64("total", total)) - }) - failpoint.Inject("FailAfterWriteRows", nil) + } + failpoint.Eval(_curpkg_("FailAfterWriteRows")) // TODO: for local backend, we may save checkpoint more frequently, e.g. after written // 10GB kv pairs to data engine, we can do a flush for both data & index engine, then we // can safely update current checkpoint. 
- failpoint.Inject("LocalBackendSaveCheckpoint", func() { + if _, _err_ := failpoint.Eval(_curpkg_("LocalBackendSaveCheckpoint")); _err_ == nil { if !isLocalBackend(rc.cfg) && (dataChecksum.SumKVS() != 0 || indexChecksum.SumKVS() != 0) { // No need to save checkpoint if nothing was delivered. saveCheckpoint(rc, t, engineID, cr.chunk) } - }) + } } return diff --git a/lightning/pkg/importer/chunk_process.go__failpoint_stash__ b/lightning/pkg/importer/chunk_process.go__failpoint_stash__ new file mode 100644 index 0000000000000..802e7c15d4b6c --- /dev/null +++ b/lightning/pkg/importer/chunk_process.go__failpoint_stash__ @@ -0,0 +1,778 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package importer + +import ( + "bytes" + "context" + "io" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/pkg/keyspace" + "github.com/pingcap/tidb/pkg/lightning/backend" + "github.com/pingcap/tidb/pkg/lightning/backend/encode" + "github.com/pingcap/tidb/pkg/lightning/backend/kv" + "github.com/pingcap/tidb/pkg/lightning/backend/tidb" + "github.com/pingcap/tidb/pkg/lightning/checkpoints" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/lightning/metric" + "github.com/pingcap/tidb/pkg/lightning/mydump" + verify "github.com/pingcap/tidb/pkg/lightning/verification" + "github.com/pingcap/tidb/pkg/lightning/worker" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/store/driver/txn" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/extsort" + "go.uber.org/zap" +) + +// chunkProcessor process data chunk +// for local backend it encodes and writes KV to local disk +// for tidb backend it transforms data into sql and executes them. 
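+//
+// Internally, process() wires two loops together through a channel of KV packets:
+// encodeLoop parses rows from the chunk and encodes them into KV pairs, while deliverLoop
+// drains the channel, appends the pairs to the data/index engine writers, and periodically
+// saves the chunk checkpoint.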
+type chunkProcessor struct { + parser mydump.Parser + index int + chunk *checkpoints.ChunkCheckpoint +} + +func newChunkProcessor( + ctx context.Context, + index int, + cfg *config.Config, + chunk *checkpoints.ChunkCheckpoint, + ioWorkers *worker.Pool, + store storage.ExternalStorage, + tableInfo *model.TableInfo, +) (*chunkProcessor, error) { + parser, err := openParser(ctx, cfg, chunk, ioWorkers, store, tableInfo) + if err != nil { + return nil, err + } + return &chunkProcessor{ + parser: parser, + index: index, + chunk: chunk, + }, nil +} + +func openParser( + ctx context.Context, + cfg *config.Config, + chunk *checkpoints.ChunkCheckpoint, + ioWorkers *worker.Pool, + store storage.ExternalStorage, + tblInfo *model.TableInfo, +) (mydump.Parser, error) { + blockBufSize := int64(cfg.Mydumper.ReadBlockSize) + reader, err := mydump.OpenReader(ctx, &chunk.FileMeta, store, storage.DecompressConfig{ + ZStdDecodeConcurrency: 1, + }) + if err != nil { + return nil, err + } + + var parser mydump.Parser + switch chunk.FileMeta.Type { + case mydump.SourceTypeCSV: + hasHeader := cfg.Mydumper.CSV.Header && chunk.Chunk.Offset == 0 + // Create a utf8mb4 convertor to encode and decode data with the charset of CSV files. + charsetConvertor, err := mydump.NewCharsetConvertor(cfg.Mydumper.DataCharacterSet, cfg.Mydumper.DataInvalidCharReplace) + if err != nil { + return nil, err + } + parser, err = mydump.NewCSVParser(ctx, &cfg.Mydumper.CSV, reader, blockBufSize, ioWorkers, hasHeader, charsetConvertor) + if err != nil { + return nil, err + } + case mydump.SourceTypeSQL: + parser = mydump.NewChunkParser(ctx, cfg.TiDB.SQLMode, reader, blockBufSize, ioWorkers) + case mydump.SourceTypeParquet: + parser, err = mydump.NewParquetParser(ctx, store, reader, chunk.FileMeta.Path) + if err != nil { + return nil, err + } + default: + return nil, errors.Errorf("file '%s' with unknown source type '%s'", chunk.Key.Path, chunk.FileMeta.Type.String()) + } + + if chunk.FileMeta.Compression == mydump.CompressionNone { + if err = parser.SetPos(chunk.Chunk.Offset, chunk.Chunk.PrevRowIDMax); err != nil { + _ = parser.Close() + return nil, err + } + } else { + if err = mydump.ReadUntil(parser, chunk.Chunk.Offset); err != nil { + _ = parser.Close() + return nil, err + } + parser.SetRowID(chunk.Chunk.PrevRowIDMax) + } + if len(chunk.ColumnPermutation) > 0 { + parser.SetColumns(getColumnNames(tblInfo, chunk.ColumnPermutation)) + } + + return parser, nil +} + +func getColumnNames(tableInfo *model.TableInfo, permutation []int) []string { + colIndexes := make([]int, 0, len(permutation)) + for i := 0; i < len(permutation); i++ { + colIndexes = append(colIndexes, -1) + } + colCnt := 0 + for i, p := range permutation { + if p >= 0 { + colIndexes[p] = i + colCnt++ + } + } + + names := make([]string, 0, colCnt) + for _, idx := range colIndexes { + // skip columns with index -1 + if idx >= 0 { + // original fields contains _tidb_rowid field + if idx == len(tableInfo.Columns) { + names = append(names, model.ExtraHandleName.O) + } else { + names = append(names, tableInfo.Columns[idx].Name.O) + } + } + } + return names +} + +func (cr *chunkProcessor) process( + ctx context.Context, + t *TableImporter, + engineID int32, + dataEngine, indexEngine backend.EngineWriter, + rc *Controller, +) error { + logger := t.logger.With( + zap.Int32("engineNumber", engineID), + zap.Int("fileIndex", cr.index), + zap.Stringer("path", &cr.chunk.Key), + ) + // Create the encoder. 
+ kvEncoder, err := rc.encBuilder.NewEncoder(ctx, &encode.EncodingConfig{ + SessionOptions: encode.SessionOptions{ + SQLMode: rc.cfg.TiDB.SQLMode, + Timestamp: cr.chunk.Timestamp, + SysVars: rc.sysVars, + // use chunk.PrevRowIDMax as the auto random seed, so it can stay the same value after recover from checkpoint. + AutoRandomSeed: cr.chunk.Chunk.PrevRowIDMax, + }, + Path: cr.chunk.Key.Path, + Table: t.encTable, + Logger: logger, + }) + if err != nil { + return err + } + defer kvEncoder.Close() + + kvsCh := make(chan []deliveredKVs, maxKVQueueSize) + deliverCompleteCh := make(chan deliverResult) + + go func() { + defer close(deliverCompleteCh) + dur, err := cr.deliverLoop(ctx, kvsCh, t, engineID, dataEngine, indexEngine, rc) + select { + case <-ctx.Done(): + case deliverCompleteCh <- deliverResult{dur, err}: + } + }() + + logTask := logger.Begin(zap.InfoLevel, "restore file") + + readTotalDur, encodeTotalDur, encodeErr := cr.encodeLoop( + ctx, + kvsCh, + t, + logger, + kvEncoder, + deliverCompleteCh, + rc, + ) + var deliverErr error + select { + case deliverResult, ok := <-deliverCompleteCh: + if ok { + logTask.End(zap.ErrorLevel, deliverResult.err, + zap.Duration("readDur", readTotalDur), + zap.Duration("encodeDur", encodeTotalDur), + zap.Duration("deliverDur", deliverResult.totalDur), + zap.Object("checksum", &cr.chunk.Checksum), + ) + deliverErr = deliverResult.err + } else { + // else, this must cause by ctx cancel + deliverErr = ctx.Err() + } + case <-ctx.Done(): + deliverErr = ctx.Err() + } + return errors.Trace(firstErr(encodeErr, deliverErr)) +} + +//nolint:nakedret // TODO: refactor +func (cr *chunkProcessor) encodeLoop( + ctx context.Context, + kvsCh chan<- []deliveredKVs, + t *TableImporter, + logger log.Logger, + kvEncoder encode.Encoder, + deliverCompleteCh <-chan deliverResult, + rc *Controller, +) (readTotalDur time.Duration, encodeTotalDur time.Duration, err error) { + defer close(kvsCh) + + // when AddIndexBySQL, we use all PK and UK to run pre-deduplication, and then we + // strip almost all secondary index to run encodeLoop. In encodeLoop when we meet + // a duplicated row marked by pre-deduplication, we need original table structure + // to generate the duplicate error message, so here create a new encoder with + // original table structure. + originalTableEncoder := kvEncoder + if rc.cfg.TikvImporter.AddIndexBySQL { + encTable, err := tables.TableFromMeta(t.alloc, t.tableInfo.Desired) + if err != nil { + return 0, 0, errors.Trace(err) + } + + originalTableEncoder, err = rc.encBuilder.NewEncoder(ctx, &encode.EncodingConfig{ + SessionOptions: encode.SessionOptions{ + SQLMode: rc.cfg.TiDB.SQLMode, + Timestamp: cr.chunk.Timestamp, + SysVars: rc.sysVars, + // use chunk.PrevRowIDMax as the auto random seed, so it can stay the same value after recover from checkpoint. 
+ AutoRandomSeed: cr.chunk.Chunk.PrevRowIDMax, + }, + Path: cr.chunk.Key.Path, + Table: encTable, + Logger: logger, + }) + if err != nil { + return 0, 0, errors.Trace(err) + } + defer originalTableEncoder.Close() + } + + send := func(kvs []deliveredKVs) error { + select { + case kvsCh <- kvs: + return nil + case <-ctx.Done(): + return ctx.Err() + case deliverResult, ok := <-deliverCompleteCh: + if deliverResult.err == nil && !ok { + deliverResult.err = ctx.Err() + } + if deliverResult.err == nil { + deliverResult.err = errors.New("unexpected premature fulfillment") + logger.DPanic("unexpected: deliverCompleteCh prematurely fulfilled with no error", zap.Bool("chIsOpen", ok)) + } + return errors.Trace(deliverResult.err) + } + } + + pauser, maxKvPairsCnt := rc.pauser, rc.cfg.TikvImporter.MaxKVPairs + initializedColumns, reachEOF := false, false + // filteredColumns is column names that excluded ignored columns + // WARN: this might be not correct when different SQL statements contains different fields, + // but since ColumnPermutation also depends on the hypothesis that the columns in one source file is the same + // so this should be ok. + var ( + filteredColumns []string + extendVals []types.Datum + ) + ignoreColumns, err1 := rc.cfg.Mydumper.IgnoreColumns.GetIgnoreColumns(t.dbInfo.Name, t.tableInfo.Core.Name.O, rc.cfg.Mydumper.CaseSensitive) + if err1 != nil { + err = err1 + return + } + + var dupIgnoreRowsIter extsort.Iterator + if t.dupIgnoreRows != nil { + dupIgnoreRowsIter, err = t.dupIgnoreRows.NewIterator(ctx) + if err != nil { + return 0, 0, err + } + defer func() { + _ = dupIgnoreRowsIter.Close() + }() + } + + for !reachEOF { + if err = pauser.Wait(ctx); err != nil { + return + } + offset, _ := cr.parser.Pos() + if offset >= cr.chunk.Chunk.EndOffset { + break + } + + var readDur, encodeDur time.Duration + canDeliver := false + kvPacket := make([]deliveredKVs, 0, maxKvPairsCnt) + curOffset := offset + var newOffset, rowID, newScannedOffset int64 + var scannedOffset int64 = -1 + var kvSize uint64 + var scannedOffsetErr error + outLoop: + for !canDeliver { + readDurStart := time.Now() + err = cr.parser.ReadRow() + columnNames := cr.parser.Columns() + newOffset, rowID = cr.parser.Pos() + if cr.chunk.FileMeta.Compression != mydump.CompressionNone || cr.chunk.FileMeta.Type == mydump.SourceTypeParquet { + newScannedOffset, scannedOffsetErr = cr.parser.ScannedPos() + if scannedOffsetErr != nil { + logger.Warn("fail to get data engine ScannedPos, progress may not be accurate", + log.ShortError(scannedOffsetErr), zap.String("file", cr.chunk.FileMeta.Path)) + } + if scannedOffset == -1 { + scannedOffset = newScannedOffset + } + } + + switch errors.Cause(err) { + case nil: + if !initializedColumns { + if len(cr.chunk.ColumnPermutation) == 0 { + if err = t.initializeColumns(columnNames, cr.chunk); err != nil { + return + } + } + filteredColumns = columnNames + ignoreColsMap := ignoreColumns.ColumnsMap() + if len(ignoreColsMap) > 0 || len(cr.chunk.FileMeta.ExtendData.Columns) > 0 { + filteredColumns, extendVals = filterColumns(columnNames, cr.chunk.FileMeta.ExtendData, ignoreColsMap, t.tableInfo.Core) + } + lastRow := cr.parser.LastRow() + lastRowLen := len(lastRow.Row) + extendColsMap := make(map[string]int) + for i, c := range cr.chunk.FileMeta.ExtendData.Columns { + extendColsMap[c] = lastRowLen + i + } + for i, col := range t.tableInfo.Core.Columns { + if p, ok := extendColsMap[col.Name.O]; ok { + cr.chunk.ColumnPermutation[i] = p + } + } + initializedColumns = true + + if dupIgnoreRowsIter 
!= nil { + dupIgnoreRowsIter.Seek(common.EncodeIntRowID(lastRow.RowID)) + } + } + case io.EOF: + reachEOF = true + break outLoop + default: + err = common.ErrEncodeKV.Wrap(err).GenWithStackByArgs(&cr.chunk.Key, newOffset) + return + } + readDur += time.Since(readDurStart) + encodeDurStart := time.Now() + lastRow := cr.parser.LastRow() + lastRow.Row = append(lastRow.Row, extendVals...) + + // Skip duplicated rows. + if dupIgnoreRowsIter != nil { + rowIDKey := common.EncodeIntRowID(lastRow.RowID) + isDupIgnored := false + dupDetectLoop: + for dupIgnoreRowsIter.Valid() { + switch bytes.Compare(rowIDKey, dupIgnoreRowsIter.UnsafeKey()) { + case 0: + isDupIgnored = true + break dupDetectLoop + case 1: + dupIgnoreRowsIter.Next() + case -1: + break dupDetectLoop + } + } + if dupIgnoreRowsIter.Error() != nil { + err = dupIgnoreRowsIter.Error() + return + } + if isDupIgnored { + cr.parser.RecycleRow(lastRow) + lastOffset := curOffset + curOffset = newOffset + + if rc.errorMgr.ConflictRecordsRemain() <= 0 { + continue + } + + dupMsg := cr.getDuplicateMessage( + originalTableEncoder, + lastRow, + lastOffset, + dupIgnoreRowsIter.UnsafeValue(), + t.tableInfo.Desired, + logger, + ) + rowText := tidb.EncodeRowForRecord(ctx, t.encTable, rc.cfg.TiDB.SQLMode, lastRow.Row, cr.chunk.ColumnPermutation) + err = rc.errorMgr.RecordDuplicate( + ctx, + logger, + t.tableName, + cr.chunk.Key.Path, + newOffset, + dupMsg, + lastRow.RowID, + rowText, + ) + if err != nil { + return 0, 0, err + } + continue + } + } + + // sql -> kv + kvs, encodeErr := kvEncoder.Encode(lastRow.Row, lastRow.RowID, cr.chunk.ColumnPermutation, curOffset) + encodeDur += time.Since(encodeDurStart) + + hasIgnoredEncodeErr := false + if encodeErr != nil { + rowText := tidb.EncodeRowForRecord(ctx, t.encTable, rc.cfg.TiDB.SQLMode, lastRow.Row, cr.chunk.ColumnPermutation) + encodeErr = rc.errorMgr.RecordTypeError(ctx, logger, t.tableName, cr.chunk.Key.Path, newOffset, rowText, encodeErr) + if encodeErr != nil { + err = common.ErrEncodeKV.Wrap(encodeErr).GenWithStackByArgs(&cr.chunk.Key, newOffset) + } + hasIgnoredEncodeErr = true + } + cr.parser.RecycleRow(lastRow) + curOffset = newOffset + + if err != nil { + return + } + if hasIgnoredEncodeErr { + continue + } + + kvPacket = append(kvPacket, deliveredKVs{kvs: kvs, columns: filteredColumns, offset: newOffset, + rowID: rowID, realOffset: newScannedOffset}) + kvSize += kvs.Size() + failpoint.Inject("mock-kv-size", func(val failpoint.Value) { + kvSize += uint64(val.(int)) + }) + // pebble cannot allow > 4.0G kv in one batch. + // we will meet pebble panic when import sql file and each kv has the size larger than 4G / maxKvPairsCnt. + // so add this check. 
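+			// The packet is flushed once the accumulated kvSize reaches minDeliverBytes,
+			// the packet holds maxKvPairsCnt rows, or the chunk's end offset is reached.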
+ if kvSize >= minDeliverBytes || len(kvPacket) >= maxKvPairsCnt || newOffset == cr.chunk.Chunk.EndOffset { + canDeliver = true + kvSize = 0 + } + } + encodeTotalDur += encodeDur + readTotalDur += readDur + if m, ok := metric.FromContext(ctx); ok { + m.RowEncodeSecondsHistogram.Observe(encodeDur.Seconds()) + m.RowReadSecondsHistogram.Observe(readDur.Seconds()) + if cr.chunk.FileMeta.Type == mydump.SourceTypeParquet { + m.RowReadBytesHistogram.Observe(float64(newScannedOffset - scannedOffset)) + } else { + m.RowReadBytesHistogram.Observe(float64(newOffset - offset)) + } + } + + if len(kvPacket) != 0 { + deliverKvStart := time.Now() + if err = send(kvPacket); err != nil { + return + } + if m, ok := metric.FromContext(ctx); ok { + m.RowKVDeliverSecondsHistogram.Observe(time.Since(deliverKvStart).Seconds()) + } + } + } + + err = send([]deliveredKVs{{offset: cr.chunk.Chunk.EndOffset, realOffset: cr.chunk.FileMeta.FileSize}}) + return +} + +// getDuplicateMessage gets the duplicate message like a SQL error. When it meets +// internal error, the error message will be returned instead of the duplicate message. +// If the index is not found (which is not expected), an empty string will be returned. +func (cr *chunkProcessor) getDuplicateMessage( + kvEncoder encode.Encoder, + lastRow mydump.Row, + lastOffset int64, + encodedIdxID []byte, + tableInfo *model.TableInfo, + logger log.Logger, +) string { + _, idxID, err := codec.DecodeVarint(encodedIdxID) + if err != nil { + return err.Error() + } + kvs, err := kvEncoder.Encode(lastRow.Row, lastRow.RowID, cr.chunk.ColumnPermutation, lastOffset) + if err != nil { + return err.Error() + } + + if idxID == conflictOnHandle { + for _, kv := range kvs.(*kv.Pairs).Pairs { + if tablecodec.IsRecordKey(kv.Key) { + dupErr := txn.ExtractKeyExistsErrFromHandle(kv.Key, kv.Val, tableInfo) + return dupErr.Error() + } + } + // should not happen + logger.Warn("fail to find conflict record key", + zap.String("file", cr.chunk.FileMeta.Path), + zap.Any("row", lastRow.Row)) + } else { + for _, kv := range kvs.(*kv.Pairs).Pairs { + _, decodedIdxID, isRecordKey, err := tablecodec.DecodeKeyHead(kv.Key) + if err != nil { + return err.Error() + } + if !isRecordKey && decodedIdxID == idxID { + dupErr := txn.ExtractKeyExistsErrFromIndex(kv.Key, kv.Val, tableInfo, idxID) + return dupErr.Error() + } + } + // should not happen + logger.Warn("fail to find conflict index key", + zap.String("file", cr.chunk.FileMeta.Path), + zap.Int64("idxID", idxID), + zap.Any("row", lastRow.Row)) + } + return "" +} + +//nolint:nakedret // TODO: refactor +func (cr *chunkProcessor) deliverLoop( + ctx context.Context, + kvsCh <-chan []deliveredKVs, + t *TableImporter, + engineID int32, + dataEngine, indexEngine backend.EngineWriter, + rc *Controller, +) (deliverTotalDur time.Duration, err error) { + deliverLogger := t.logger.With( + zap.Int32("engineNumber", engineID), + zap.Int("fileIndex", cr.index), + zap.Stringer("path", &cr.chunk.Key), + zap.String("task", "deliver"), + ) + // Fetch enough KV pairs from the source. 
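+	// Pairs are accumulated until the combined size tracked by the data and index checksums
+	// reaches minDeliverBytes, then written to the engine writers in one batch under the
+	// disk-quota read lock.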
+ dataKVs := rc.encBuilder.MakeEmptyRows() + indexKVs := rc.encBuilder.MakeEmptyRows() + + dataSynced := true + hasMoreKVs := true + var startRealOffset, currRealOffset int64 // save to 0 at first + + keyspace := keyspace.CodecV1.GetKeyspace() + if t.kvStore != nil { + keyspace = t.kvStore.GetCodec().GetKeyspace() + } + for hasMoreKVs { + var ( + dataChecksum = verify.NewKVChecksumWithKeyspace(keyspace) + indexChecksum = verify.NewKVChecksumWithKeyspace(keyspace) + ) + var columns []string + var kvPacket []deliveredKVs + // init these two field as checkpoint current value, so even if there are no kv pairs delivered, + // chunk checkpoint should stay the same + startOffset := cr.chunk.Chunk.Offset + currOffset := startOffset + startRealOffset = cr.chunk.Chunk.RealOffset + currRealOffset = startRealOffset + rowID := cr.chunk.Chunk.PrevRowIDMax + + populate: + for dataChecksum.SumSize()+indexChecksum.SumSize() < minDeliverBytes { + select { + case kvPacket = <-kvsCh: + if len(kvPacket) == 0 { + hasMoreKVs = false + break populate + } + for _, p := range kvPacket { + if p.kvs == nil { + // This is the last message. + currOffset = p.offset + currRealOffset = p.realOffset + hasMoreKVs = false + break populate + } + p.kvs.ClassifyAndAppend(&dataKVs, dataChecksum, &indexKVs, indexChecksum) + columns = p.columns + currOffset = p.offset + currRealOffset = p.realOffset + rowID = p.rowID + } + case <-ctx.Done(): + err = ctx.Err() + return + } + } + + err = func() error { + // We use `TryRLock` with sleep here to avoid blocking current goroutine during importing when disk-quota is + // triggered, so that we can save chunkCheckpoint as soon as possible after `FlushEngine` is called. + // This implementation may not be very elegant or even completely correct, but it is currently a relatively + // simple and effective solution. + for !rc.diskQuotaLock.TryRLock() { + // try to update chunk checkpoint, this can help save checkpoint after importing when disk-quota is triggered + if !dataSynced { + dataSynced = cr.maybeSaveCheckpoint(rc, t, engineID, cr.chunk, dataEngine, indexEngine) + } + time.Sleep(time.Millisecond) + } + defer rc.diskQuotaLock.RUnlock() + + // Write KVs into the engine + start := time.Now() + + if err = dataEngine.AppendRows(ctx, columns, dataKVs); err != nil { + if !common.IsContextCanceledError(err) { + deliverLogger.Error("write to data engine failed", log.ShortError(err)) + } + + return errors.Trace(err) + } + if err = indexEngine.AppendRows(ctx, columns, indexKVs); err != nil { + if !common.IsContextCanceledError(err) { + deliverLogger.Error("write to index engine failed", log.ShortError(err)) + } + return errors.Trace(err) + } + + if m, ok := metric.FromContext(ctx); ok { + deliverDur := time.Since(start) + deliverTotalDur += deliverDur + m.BlockDeliverSecondsHistogram.Observe(deliverDur.Seconds()) + m.BlockDeliverBytesHistogram.WithLabelValues(metric.BlockDeliverKindData).Observe(float64(dataChecksum.SumSize())) + m.BlockDeliverBytesHistogram.WithLabelValues(metric.BlockDeliverKindIndex).Observe(float64(indexChecksum.SumSize())) + m.BlockDeliverKVPairsHistogram.WithLabelValues(metric.BlockDeliverKindData).Observe(float64(dataChecksum.SumKVS())) + m.BlockDeliverKVPairsHistogram.WithLabelValues(metric.BlockDeliverKindIndex).Observe(float64(indexChecksum.SumKVS())) + } + return nil + }() + if err != nil { + return + } + dataSynced = false + + dataKVs = dataKVs.Clear() + indexKVs = indexKVs.Clear() + + // Update the table, and save a checkpoint. 
+ // (the write to the importer is effective immediately, thus update these here) + // No need to apply a lock since this is the only thread updating `cr.chunk.**`. + // In local mode, we should write these checkpoints after engine flushed. + lastOffset := cr.chunk.Chunk.Offset + cr.chunk.Checksum.Add(dataChecksum) + cr.chunk.Checksum.Add(indexChecksum) + cr.chunk.Chunk.Offset = currOffset + cr.chunk.Chunk.RealOffset = currRealOffset + cr.chunk.Chunk.PrevRowIDMax = rowID + + if m, ok := metric.FromContext(ctx); ok { + // value of currOffset comes from parser.pos which increase monotonically. the init value of parser.pos + // comes from chunk.Chunk.Offset. so it shouldn't happen that currOffset - startOffset < 0. + // but we met it one time, but cannot reproduce it now, we add this check to make code more robust + // TODO: reproduce and find the root cause and fix it completely + var lowOffset, highOffset int64 + if cr.chunk.FileMeta.Compression != mydump.CompressionNone { + lowOffset, highOffset = startRealOffset, currRealOffset + } else { + lowOffset, highOffset = startOffset, currOffset + } + delta := highOffset - lowOffset + if delta >= 0 { + if cr.chunk.FileMeta.Type == mydump.SourceTypeParquet { + if currRealOffset > startRealOffset { + m.BytesCounter.WithLabelValues(metric.StateRestored).Add(float64(currRealOffset - startRealOffset)) + } + m.RowsCounter.WithLabelValues(metric.StateRestored, t.tableName).Add(float64(delta)) + } else { + m.BytesCounter.WithLabelValues(metric.StateRestored).Add(float64(delta)) + m.RowsCounter.WithLabelValues(metric.StateRestored, t.tableName).Add(float64(dataChecksum.SumKVS())) + } + if rc.status != nil && rc.status.backend == config.BackendTiDB { + rc.status.FinishedFileSize.Add(delta) + } + } else { + deliverLogger.Error("offset go back", zap.Int64("curr", highOffset), + zap.Int64("start", lowOffset)) + } + } + + if currOffset > lastOffset || dataChecksum.SumKVS() != 0 || indexChecksum.SumKVS() != 0 { + // No need to save checkpoint if nothing was delivered. + dataSynced = cr.maybeSaveCheckpoint(rc, t, engineID, cr.chunk, dataEngine, indexEngine) + } + failpoint.Inject("SlowDownWriteRows", func() { + deliverLogger.Warn("Slowed down write rows") + finished := rc.status.FinishedFileSize.Load() + total := rc.status.TotalFileSize.Load() + deliverLogger.Warn("PrintStatus Failpoint", + zap.Int64("finished", finished), + zap.Int64("total", total)) + }) + failpoint.Inject("FailAfterWriteRows", nil) + // TODO: for local backend, we may save checkpoint more frequently, e.g. after written + // 10GB kv pairs to data engine, we can do a flush for both data & index engine, then we + // can safely update current checkpoint. + + failpoint.Inject("LocalBackendSaveCheckpoint", func() { + if !isLocalBackend(rc.cfg) && (dataChecksum.SumKVS() != 0 || indexChecksum.SumKVS() != 0) { + // No need to save checkpoint if nothing was delivered. 
+ saveCheckpoint(rc, t, engineID, cr.chunk) + } + }) + } + + return +} + +func (*chunkProcessor) maybeSaveCheckpoint( + rc *Controller, + t *TableImporter, + engineID int32, + chunk *checkpoints.ChunkCheckpoint, + data, index backend.EngineWriter, +) bool { + if data.IsSynced() && index.IsSynced() { + saveCheckpoint(rc, t, engineID, chunk) + return true + } + return false +} + +func (cr *chunkProcessor) close() { + _ = cr.parser.Close() +} diff --git a/lightning/pkg/importer/get_pre_info.go b/lightning/pkg/importer/get_pre_info.go index 5e34c6bf36186..64880d9c52750 100644 --- a/lightning/pkg/importer/get_pre_info.go +++ b/lightning/pkg/importer/get_pre_info.go @@ -187,9 +187,9 @@ func (g *TargetInfoGetterImpl) CheckVersionRequirements(ctx context.Context) err // It tries to select the row count from the target DB. func (g *TargetInfoGetterImpl) IsTableEmpty(ctx context.Context, schemaName string, tableName string) (*bool, error) { var result bool - failpoint.Inject("CheckTableEmptyFailed", func() { - failpoint.Return(nil, errors.New("mock error")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("CheckTableEmptyFailed")); _err_ == nil { + return nil, errors.New("mock error") + } exec := common.SQLWithRetry{ DB: g.db, Logger: log.FromContext(ctx), @@ -365,19 +365,18 @@ func (p *PreImportInfoGetterImpl) GetAllTableStructures(ctx context.Context, opt func (p *PreImportInfoGetterImpl) getTableStructuresByFileMeta(ctx context.Context, dbSrcFileMeta *mydump.MDDatabaseMeta, getPreInfoCfg *ropts.GetPreInfoConfig) ([]*model.TableInfo, error) { dbName := dbSrcFileMeta.Name - failpoint.Inject( - "getTableStructuresByFileMeta_BeforeFetchRemoteTableModels", - func(v failpoint.Value) { - fmt.Println("failpoint: getTableStructuresByFileMeta_BeforeFetchRemoteTableModels") - const defaultMilliSeconds int = 5000 - sleepMilliSeconds, ok := v.(int) - if !ok || sleepMilliSeconds <= 0 || sleepMilliSeconds > 30000 { - sleepMilliSeconds = defaultMilliSeconds - } - //nolint: errcheck - failpoint.Enable("github.com/pingcap/tidb/pkg/lightning/backend/tidb/FetchRemoteTableModels_BeforeFetchTableAutoIDInfos", fmt.Sprintf("sleep(%d)", sleepMilliSeconds)) - }, - ) + if v, _err_ := failpoint.Eval(_curpkg_("getTableStructuresByFileMeta_BeforeFetchRemoteTableModels")); _err_ == nil { + + fmt.Println("failpoint: getTableStructuresByFileMeta_BeforeFetchRemoteTableModels") + const defaultMilliSeconds int = 5000 + sleepMilliSeconds, ok := v.(int) + if !ok || sleepMilliSeconds <= 0 || sleepMilliSeconds > 30000 { + sleepMilliSeconds = defaultMilliSeconds + } + //nolint: errcheck + failpoint.Enable("github.com/pingcap/tidb/pkg/lightning/backend/tidb/FetchRemoteTableModels_BeforeFetchTableAutoIDInfos", fmt.Sprintf("sleep(%d)", sleepMilliSeconds)) + + } currentTableInfosFromDB, err := p.targetInfoGetter.FetchRemoteTableModels(ctx, dbName) if err != nil { if getPreInfoCfg != nil && getPreInfoCfg.IgnoreDBNotExist { @@ -758,9 +757,9 @@ outloop: rowSize += uint64(lastRow.Length) parser.RecycleRow(lastRow) - failpoint.Inject("mock-kv-size", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mock-kv-size")); _err_ == nil { kvSize += uint64(val.(int)) - }) + } if rowSize > maxSampleDataSize || rowCount > maxSampleRowCount { break } diff --git a/lightning/pkg/importer/get_pre_info.go__failpoint_stash__ b/lightning/pkg/importer/get_pre_info.go__failpoint_stash__ new file mode 100644 index 0000000000000..5e34c6bf36186 --- /dev/null +++ b/lightning/pkg/importer/get_pre_info.go__failpoint_stash__ @@ -0,0 +1,835 @@ +// 
Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package importer + +import ( + "bytes" + "context" + "database/sql" + "fmt" + "io" + "strings" + + mysql_sql_driver "github.com/go-sql-driver/mysql" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/storage" + ropts "github.com/pingcap/tidb/lightning/pkg/importer/opts" + "github.com/pingcap/tidb/pkg/ddl" + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/lightning/backend" + "github.com/pingcap/tidb/pkg/lightning/backend/encode" + "github.com/pingcap/tidb/pkg/lightning/backend/kv" + "github.com/pingcap/tidb/pkg/lightning/backend/local" + "github.com/pingcap/tidb/pkg/lightning/backend/tidb" + "github.com/pingcap/tidb/pkg/lightning/checkpoints" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/lightning/errormanager" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/lightning/mydump" + "github.com/pingcap/tidb/pkg/lightning/verification" + "github.com/pingcap/tidb/pkg/lightning/worker" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/model" + _ "github.com/pingcap/tidb/pkg/planner/core" // to setup expression.EvalAstExpr. Otherwise we cannot parse the default value + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/mock" + pdhttp "github.com/tikv/pd/client/http" + "go.uber.org/zap" + "golang.org/x/exp/maps" +) + +// compressionRatio is the tikv/tiflash's compression ratio +const compressionRatio = float64(1) / 3 + +// EstimateSourceDataSizeResult is the object for estimated data size result. +type EstimateSourceDataSizeResult struct { + // SizeWithIndex is the tikv size with the index. + SizeWithIndex int64 + // SizeWithoutIndex is the tikv size without the index. + SizeWithoutIndex int64 + // HasUnsortedBigTables indicates whether the source data has unsorted big tables or not. + HasUnsortedBigTables bool + // TiFlashSize is the size of tiflash. + TiFlashSize int64 +} + +// PreImportInfoGetter defines the operations to get information from sources and target. +// These information are used in the preparation of the import ( like precheck ). +type PreImportInfoGetter interface { + TargetInfoGetter + // GetAllTableStructures gets all the table structures with the information from both the source and the target. + GetAllTableStructures(ctx context.Context, opts ...ropts.GetPreInfoOption) (map[string]*checkpoints.TidbDBInfo, error) + // ReadFirstNRowsByTableName reads the first N rows of data of an importing source table. + ReadFirstNRowsByTableName(ctx context.Context, schemaName string, tableName string, n int) (cols []string, rows [][]types.Datum, err error) + // ReadFirstNRowsByFileMeta reads the first N rows of an data file. 
+ ReadFirstNRowsByFileMeta(ctx context.Context, dataFileMeta mydump.SourceFileMeta, n int) (cols []string, rows [][]types.Datum, err error) + // EstimateSourceDataSize estimates the datasize to generate during the import as well as some other sub-informaiton. + // It will return: + // * the estimated data size to generate during the import, + // which might include some extra index data to generate besides the source file data + // * the total data size of all the source files, + // * whether there are some unsorted big tables + EstimateSourceDataSize(ctx context.Context, opts ...ropts.GetPreInfoOption) (*EstimateSourceDataSizeResult, error) +} + +// TargetInfoGetter defines the operations to get information from target. +type TargetInfoGetter interface { + // FetchRemoteDBModels fetches the database structures from the remote target. + FetchRemoteDBModels(ctx context.Context) ([]*model.DBInfo, error) + // FetchRemoteTableModels fetches the table structures from the remote target. + FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) + // CheckVersionRequirements performs the check whether the target satisfies the version requirements. + CheckVersionRequirements(ctx context.Context) error + // IsTableEmpty checks whether the specified table on the target DB contains data or not. + IsTableEmpty(ctx context.Context, schemaName string, tableName string) (*bool, error) + // GetTargetSysVariablesForImport gets some important systam variables for importing on the target. + GetTargetSysVariablesForImport(ctx context.Context, opts ...ropts.GetPreInfoOption) map[string]string + // GetMaxReplica gets the max-replica from replication config on the target. + GetMaxReplica(ctx context.Context) (uint64, error) + // GetStorageInfo gets the storage information on the target. + GetStorageInfo(ctx context.Context) (*pdhttp.StoresInfo, error) + // GetEmptyRegionsInfo gets the region information of all the empty regions on the target. + GetEmptyRegionsInfo(ctx context.Context) (*pdhttp.RegionsInfo, error) +} + +type preInfoGetterKey string + +const ( + preInfoGetterKeyDBMetas preInfoGetterKey = "PRE_INFO_GETTER/DB_METAS" +) + +// WithPreInfoGetterDBMetas returns a new context with the specified dbMetas. +func WithPreInfoGetterDBMetas(ctx context.Context, dbMetas []*mydump.MDDatabaseMeta) context.Context { + return context.WithValue(ctx, preInfoGetterKeyDBMetas, dbMetas) +} + +// TargetInfoGetterImpl implements the operations to get information from the target. +type TargetInfoGetterImpl struct { + cfg *config.Config + db *sql.DB + backend backend.TargetInfoGetter + pdHTTPCli pdhttp.Client +} + +// NewTargetInfoGetterImpl creates a TargetInfoGetterImpl object. +func NewTargetInfoGetterImpl( + cfg *config.Config, + targetDB *sql.DB, + pdHTTPCli pdhttp.Client, +) (*TargetInfoGetterImpl, error) { + tls, err := cfg.ToTLS() + if err != nil { + return nil, errors.Trace(err) + } + var backendTargetInfoGetter backend.TargetInfoGetter + switch cfg.TikvImporter.Backend { + case config.BackendTiDB: + backendTargetInfoGetter = tidb.NewTargetInfoGetter(targetDB) + case config.BackendLocal: + backendTargetInfoGetter = local.NewTargetInfoGetter(tls, targetDB, pdHTTPCli) + default: + return nil, common.ErrUnknownBackend.GenWithStackByArgs(cfg.TikvImporter.Backend) + } + return &TargetInfoGetterImpl{ + cfg: cfg, + db: targetDB, + backend: backendTargetInfoGetter, + pdHTTPCli: pdHTTPCli, + }, nil +} + +// FetchRemoteDBModels implements TargetInfoGetter. 
+func (g *TargetInfoGetterImpl) FetchRemoteDBModels(ctx context.Context) ([]*model.DBInfo, error) { + return g.backend.FetchRemoteDBModels(ctx) +} + +// FetchRemoteTableModels fetches the table structures from the remote target. +// It implements the TargetInfoGetter interface. +func (g *TargetInfoGetterImpl) FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) { + return g.backend.FetchRemoteTableModels(ctx, schemaName) +} + +// CheckVersionRequirements performs the check whether the target satisfies the version requirements. +// It implements the TargetInfoGetter interface. +// Mydump database metas are retrieved from the context. +func (g *TargetInfoGetterImpl) CheckVersionRequirements(ctx context.Context) error { + var dbMetas []*mydump.MDDatabaseMeta + dbmetasVal := ctx.Value(preInfoGetterKeyDBMetas) + if dbmetasVal != nil { + if m, ok := dbmetasVal.([]*mydump.MDDatabaseMeta); ok { + dbMetas = m + } + } + return g.backend.CheckRequirements(ctx, &backend.CheckCtx{ + DBMetas: dbMetas, + }) +} + +// IsTableEmpty checks whether the specified table on the target DB contains data or not. +// It implements the TargetInfoGetter interface. +// It tries to select the row count from the target DB. +func (g *TargetInfoGetterImpl) IsTableEmpty(ctx context.Context, schemaName string, tableName string) (*bool, error) { + var result bool + failpoint.Inject("CheckTableEmptyFailed", func() { + failpoint.Return(nil, errors.New("mock error")) + }) + exec := common.SQLWithRetry{ + DB: g.db, + Logger: log.FromContext(ctx), + } + var dump int + err := exec.QueryRow(ctx, "check table empty", + // Here we use the `USE INDEX()` hint to skip fetch the record from index. + // In Lightning, if previous importing is halted half-way, it is possible that + // the data is partially imported, but the index data has not been imported. + // In this situation, if no hint is added, the SQL executor might fetch the record from index, + // which is empty. This will result in missing check. + common.SprintfWithIdentifiers("SELECT 1 FROM %s.%s USE INDEX() LIMIT 1", schemaName, tableName), + &dump, + ) + + isNoSuchTableErr := false + rootErr := errors.Cause(err) + if mysqlErr, ok := rootErr.(*mysql_sql_driver.MySQLError); ok && mysqlErr.Number == errno.ErrNoSuchTable { + isNoSuchTableErr = true + } + switch { + case isNoSuchTableErr: + result = true + case errors.ErrorEqual(err, sql.ErrNoRows): + result = true + case err != nil: + return nil, errors.Trace(err) + default: + result = false + } + return &result, nil +} + +// GetTargetSysVariablesForImport gets some important system variables for importing on the target. +// It implements the TargetInfoGetter interface. +// It uses the SQL to fetch sys variables from the target. +func (g *TargetInfoGetterImpl) GetTargetSysVariablesForImport(ctx context.Context, _ ...ropts.GetPreInfoOption) map[string]string { + sysVars := ObtainImportantVariables(ctx, g.db, !isTiDBBackend(g.cfg)) + // override by manually set vars + maps.Copy(sysVars, g.cfg.TiDB.Vars) + return sysVars +} + +// GetMaxReplica implements the TargetInfoGetter interface. +func (g *TargetInfoGetterImpl) GetMaxReplica(ctx context.Context) (uint64, error) { + cfg, err := g.pdHTTPCli.GetReplicateConfig(ctx) + if err != nil { + return 0, errors.Trace(err) + } + val := cfg["max-replicas"].(float64) + return uint64(val), nil +} + +// GetStorageInfo gets the storage information on the target. +// It implements the TargetInfoGetter interface. 
+// It uses the PD interface through TLS to get the information. +func (g *TargetInfoGetterImpl) GetStorageInfo(ctx context.Context) (*pdhttp.StoresInfo, error) { + return g.pdHTTPCli.GetStores(ctx) +} + +// GetEmptyRegionsInfo gets the region information of all the empty regions on the target. +// It implements the TargetInfoGetter interface. +// It uses the PD interface through TLS to get the information. +func (g *TargetInfoGetterImpl) GetEmptyRegionsInfo(ctx context.Context) (*pdhttp.RegionsInfo, error) { + return g.pdHTTPCli.GetEmptyRegions(ctx) +} + +// PreImportInfoGetterImpl implements the operations to get information used in importing preparation. +type PreImportInfoGetterImpl struct { + cfg *config.Config + getPreInfoCfg *ropts.GetPreInfoConfig + srcStorage storage.ExternalStorage + ioWorkers *worker.Pool + encBuilder encode.EncodingBuilder + targetInfoGetter TargetInfoGetter + + dbMetas []*mydump.MDDatabaseMeta + mdDBMetaMap map[string]*mydump.MDDatabaseMeta + mdDBTableMetaMap map[string]map[string]*mydump.MDTableMeta + + dbInfosCache map[string]*checkpoints.TidbDBInfo + sysVarsCache map[string]string + estimatedSizeCache *EstimateSourceDataSizeResult +} + +// NewPreImportInfoGetter creates a PreImportInfoGetterImpl object. +func NewPreImportInfoGetter( + cfg *config.Config, + dbMetas []*mydump.MDDatabaseMeta, + srcStorage storage.ExternalStorage, + targetInfoGetter TargetInfoGetter, + ioWorkers *worker.Pool, + encBuilder encode.EncodingBuilder, + opts ...ropts.GetPreInfoOption, +) (*PreImportInfoGetterImpl, error) { + if ioWorkers == nil { + ioWorkers = worker.NewPool(context.Background(), cfg.App.IOConcurrency, "pre_info_getter_io") + } + if encBuilder == nil { + switch cfg.TikvImporter.Backend { + case config.BackendTiDB: + encBuilder = tidb.NewEncodingBuilder() + case config.BackendLocal: + encBuilder = local.NewEncodingBuilder(context.Background()) + default: + return nil, common.ErrUnknownBackend.GenWithStackByArgs(cfg.TikvImporter.Backend) + } + } + + getPreInfoCfg := ropts.NewDefaultGetPreInfoConfig() + for _, o := range opts { + o(getPreInfoCfg) + } + result := &PreImportInfoGetterImpl{ + cfg: cfg, + getPreInfoCfg: getPreInfoCfg, + dbMetas: dbMetas, + srcStorage: srcStorage, + ioWorkers: ioWorkers, + encBuilder: encBuilder, + targetInfoGetter: targetInfoGetter, + } + result.Init() + return result, nil +} + +// Init initializes some internal data and states for PreImportInfoGetterImpl. +func (p *PreImportInfoGetterImpl) Init() { + mdDBMetaMap := make(map[string]*mydump.MDDatabaseMeta) + mdDBTableMetaMap := make(map[string]map[string]*mydump.MDTableMeta) + for _, dbMeta := range p.dbMetas { + dbName := dbMeta.Name + mdDBMetaMap[dbName] = dbMeta + mdTableMetaMap, ok := mdDBTableMetaMap[dbName] + if !ok { + mdTableMetaMap = make(map[string]*mydump.MDTableMeta) + mdDBTableMetaMap[dbName] = mdTableMetaMap + } + for _, tblMeta := range dbMeta.Tables { + tblName := tblMeta.Name + mdTableMetaMap[tblName] = tblMeta + } + } + p.mdDBMetaMap = mdDBMetaMap + p.mdDBTableMetaMap = mdDBTableMetaMap +} + +// GetAllTableStructures gets all the table structures with the information from both the source and the target. +// It implements the PreImportInfoGetter interface. +// It has a caching mechanism: the table structures will be obtained from the source only once. 
+func (p *PreImportInfoGetterImpl) GetAllTableStructures(ctx context.Context, opts ...ropts.GetPreInfoOption) (map[string]*checkpoints.TidbDBInfo, error) { + var ( + dbInfos map[string]*checkpoints.TidbDBInfo + err error + ) + getPreInfoCfg := p.getPreInfoCfg.Clone() + for _, o := range opts { + o(getPreInfoCfg) + } + dbInfos = p.dbInfosCache + if dbInfos != nil && !getPreInfoCfg.ForceReloadCache { + return dbInfos, nil + } + dbInfos, err = LoadSchemaInfo(ctx, p.dbMetas, func(ctx context.Context, dbName string) ([]*model.TableInfo, error) { + return p.getTableStructuresByFileMeta(ctx, p.mdDBMetaMap[dbName], getPreInfoCfg) + }) + if err != nil { + return nil, errors.Trace(err) + } + p.dbInfosCache = dbInfos + return dbInfos, nil +} + +func (p *PreImportInfoGetterImpl) getTableStructuresByFileMeta(ctx context.Context, dbSrcFileMeta *mydump.MDDatabaseMeta, getPreInfoCfg *ropts.GetPreInfoConfig) ([]*model.TableInfo, error) { + dbName := dbSrcFileMeta.Name + failpoint.Inject( + "getTableStructuresByFileMeta_BeforeFetchRemoteTableModels", + func(v failpoint.Value) { + fmt.Println("failpoint: getTableStructuresByFileMeta_BeforeFetchRemoteTableModels") + const defaultMilliSeconds int = 5000 + sleepMilliSeconds, ok := v.(int) + if !ok || sleepMilliSeconds <= 0 || sleepMilliSeconds > 30000 { + sleepMilliSeconds = defaultMilliSeconds + } + //nolint: errcheck + failpoint.Enable("github.com/pingcap/tidb/pkg/lightning/backend/tidb/FetchRemoteTableModels_BeforeFetchTableAutoIDInfos", fmt.Sprintf("sleep(%d)", sleepMilliSeconds)) + }, + ) + currentTableInfosFromDB, err := p.targetInfoGetter.FetchRemoteTableModels(ctx, dbName) + if err != nil { + if getPreInfoCfg != nil && getPreInfoCfg.IgnoreDBNotExist { + dbNotExistErr := dbterror.ClassSchema.NewStd(errno.ErrBadDB).FastGenByArgs(dbName) + // The returned error is an error showing get info request error, + // and attaches the detailed error response as a string. + // So we cannot get the error chain and use error comparison, + // and instead, we use the string comparison on error messages. + if strings.Contains(err.Error(), dbNotExistErr.Error()) { + log.L().Warn("DB not exists. 
But ignore it", zap.Error(err)) + goto get_struct_from_src + } + } + return nil, errors.Trace(err) + } +get_struct_from_src: + currentTableInfosMap := make(map[string]*model.TableInfo) + for _, tblInfo := range currentTableInfosFromDB { + currentTableInfosMap[tblInfo.Name.L] = tblInfo + } + resultInfos := make([]*model.TableInfo, len(dbSrcFileMeta.Tables)) + for i, tableFileMeta := range dbSrcFileMeta.Tables { + if curTblInfo, ok := currentTableInfosMap[strings.ToLower(tableFileMeta.Name)]; ok { + resultInfos[i] = curTblInfo + continue + } + createTblSQL, err := tableFileMeta.GetSchema(ctx, p.srcStorage) + if err != nil { + return nil, errors.Annotatef(err, "get create table statement from schema file error: %s", tableFileMeta.Name) + } + theTableInfo, err := newTableInfo(createTblSQL, 0) + log.L().Info("generate table info from SQL", zap.Error(err), zap.String("sql", createTblSQL), zap.String("table_name", tableFileMeta.Name), zap.String("db_name", dbSrcFileMeta.Name)) + if err != nil { + errMsg := "generate table info from SQL error" + log.L().Error(errMsg, zap.Error(err), zap.String("sql", createTblSQL), zap.String("table_name", tableFileMeta.Name)) + return nil, errors.Annotatef(err, "%s: %s", errMsg, tableFileMeta.Name) + } + resultInfos[i] = theTableInfo + } + return resultInfos, nil +} + +func newTableInfo(createTblSQL string, tableID int64) (*model.TableInfo, error) { + parser := parser.New() + astNode, err := parser.ParseOneStmt(createTblSQL, "", "") + if err != nil { + errMsg := "parse sql statement error" + log.L().Error(errMsg, zap.Error(err), zap.String("sql", createTblSQL)) + return nil, errors.Trace(err) + } + sctx := mock.NewContext() + createTableStmt, ok := astNode.(*ast.CreateTableStmt) + if !ok { + return nil, errors.New("cannot transfer the parsed SQL as an CREATE TABLE statement") + } + info, err := ddl.MockTableInfo(sctx, createTableStmt, tableID) + if err != nil { + return nil, errors.Trace(err) + } + info.State = model.StatePublic + return info, nil +} + +// ReadFirstNRowsByTableName reads the first N rows of data of an importing source table. +// It implements the PreImportInfoGetter interface. +func (p *PreImportInfoGetterImpl) ReadFirstNRowsByTableName(ctx context.Context, schemaName string, tableName string, n int) ([]string, [][]types.Datum, error) { + mdTableMetaMap, ok := p.mdDBTableMetaMap[schemaName] + if !ok { + return nil, nil, errors.Errorf("cannot find the schema: %s", schemaName) + } + mdTableMeta, ok := mdTableMetaMap[tableName] + if !ok { + return nil, nil, errors.Errorf("cannot find the table: %s.%s", schemaName, tableName) + } + if len(mdTableMeta.DataFiles) <= 0 { + return nil, [][]types.Datum{}, nil + } + return p.ReadFirstNRowsByFileMeta(ctx, mdTableMeta.DataFiles[0].FileMeta, n) +} + +// ReadFirstNRowsByFileMeta reads the first N rows of an data file. +// It implements the PreImportInfoGetter interface. +func (p *PreImportInfoGetterImpl) ReadFirstNRowsByFileMeta(ctx context.Context, dataFileMeta mydump.SourceFileMeta, n int) ([]string, [][]types.Datum, error) { + reader, err := mydump.OpenReader(ctx, &dataFileMeta, p.srcStorage, storage.DecompressConfig{ + ZStdDecodeConcurrency: 1, + }) + if err != nil { + return nil, nil, errors.Trace(err) + } + + var parser mydump.Parser + blockBufSize := int64(p.cfg.Mydumper.ReadBlockSize) + switch dataFileMeta.Type { + case mydump.SourceTypeCSV: + hasHeader := p.cfg.Mydumper.CSV.Header + // Create a utf8mb4 convertor to encode and decode data with the charset of CSV files. 
+ charsetConvertor, err := mydump.NewCharsetConvertor(p.cfg.Mydumper.DataCharacterSet, p.cfg.Mydumper.DataInvalidCharReplace) + if err != nil { + return nil, nil, errors.Trace(err) + } + parser, err = mydump.NewCSVParser(ctx, &p.cfg.Mydumper.CSV, reader, blockBufSize, p.ioWorkers, hasHeader, charsetConvertor) + if err != nil { + return nil, nil, errors.Trace(err) + } + case mydump.SourceTypeSQL: + parser = mydump.NewChunkParser(ctx, p.cfg.TiDB.SQLMode, reader, blockBufSize, p.ioWorkers) + case mydump.SourceTypeParquet: + parser, err = mydump.NewParquetParser(ctx, p.srcStorage, reader, dataFileMeta.Path) + if err != nil { + return nil, nil, errors.Trace(err) + } + default: + panic(fmt.Sprintf("unknown file type '%s'", dataFileMeta.Type)) + } + //nolint: errcheck + defer parser.Close() + + rows := [][]types.Datum{} + for i := 0; i < n; i++ { + err := parser.ReadRow() + if err != nil { + if errors.Cause(err) != io.EOF { + return nil, nil, errors.Trace(err) + } + break + } + lastRowDatums := append([]types.Datum{}, parser.LastRow().Row...) + rows = append(rows, lastRowDatums) + } + return parser.Columns(), rows, nil +} + +// EstimateSourceDataSize estimates the datasize to generate during the import as well as some other sub-informaiton. +// It implements the PreImportInfoGetter interface. +// It has a cache mechanism. The estimated size will only calculated once. +// The caching behavior can be changed by appending the `ForceReloadCache(true)` option. +func (p *PreImportInfoGetterImpl) EstimateSourceDataSize(ctx context.Context, opts ...ropts.GetPreInfoOption) (*EstimateSourceDataSizeResult, error) { + var result *EstimateSourceDataSizeResult + + getPreInfoCfg := p.getPreInfoCfg.Clone() + for _, o := range opts { + o(getPreInfoCfg) + } + result = p.estimatedSizeCache + if result != nil && !getPreInfoCfg.ForceReloadCache { + return result, nil + } + + var ( + sizeWithIndex = int64(0) + tiflashSize = int64(0) + sourceTotalSize = int64(0) + tableCount = 0 + unSortedBigTableCount = 0 + errMgr = errormanager.New(nil, p.cfg, log.FromContext(ctx)) + ) + + dbInfos, err := p.GetAllTableStructures(ctx) + if err != nil { + return nil, errors.Trace(err) + } + sysVars := p.GetTargetSysVariablesForImport(ctx) + for _, db := range p.dbMetas { + info, ok := dbInfos[db.Name] + if !ok { + continue + } + for _, tbl := range db.Tables { + sourceTotalSize += tbl.TotalSize + tableInfo, ok := info.Tables[tbl.Name] + if ok { + tableSize := tbl.TotalSize + // Do not sample small table because there may a large number of small table and it will take a long + // time to sample data for all of them. 
+ if isTiDBBackend(p.cfg) || tbl.TotalSize < int64(config.SplitRegionSize) { + tbl.IndexRatio = 1.0 + tbl.IsRowOrdered = false + } else { + sampledIndexRatio, isRowOrderedFromSample, err := p.sampleDataFromTable(ctx, db.Name, tbl, tableInfo.Core, errMgr, sysVars) + if err != nil { + return nil, errors.Trace(err) + } + tbl.IndexRatio = sampledIndexRatio + tbl.IsRowOrdered = isRowOrderedFromSample + + tableSize = int64(float64(tbl.TotalSize) * tbl.IndexRatio) + + if tbl.TotalSize > int64(config.DefaultBatchSize)*2 && !tbl.IsRowOrdered { + unSortedBigTableCount++ + } + } + + sizeWithIndex += tableSize + if tableInfo.Core.TiFlashReplica != nil && tableInfo.Core.TiFlashReplica.Available { + tiflashSize += tableSize * int64(tableInfo.Core.TiFlashReplica.Count) + } + tableCount++ + } + } + } + + if isLocalBackend(p.cfg) { + sizeWithIndex = int64(float64(sizeWithIndex) * compressionRatio) + tiflashSize = int64(float64(tiflashSize) * compressionRatio) + } + + result = &EstimateSourceDataSizeResult{ + SizeWithIndex: sizeWithIndex, + SizeWithoutIndex: sourceTotalSize, + HasUnsortedBigTables: (unSortedBigTableCount > 0), + TiFlashSize: tiflashSize, + } + p.estimatedSizeCache = result + return result, nil +} + +// sampleDataFromTable samples the source data file to get the extra data ratio for the index +// It returns: +// * the extra data ratio with index size accounted +// * is the sample data ordered by row +func (p *PreImportInfoGetterImpl) sampleDataFromTable( + ctx context.Context, + dbName string, + tableMeta *mydump.MDTableMeta, + tableInfo *model.TableInfo, + errMgr *errormanager.ErrorManager, + sysVars map[string]string, +) (float64, bool, error) { + resultIndexRatio := 1.0 + isRowOrdered := false + if len(tableMeta.DataFiles) == 0 { + return resultIndexRatio, isRowOrdered, nil + } + sampleFile := tableMeta.DataFiles[0].FileMeta + reader, err := mydump.OpenReader(ctx, &sampleFile, p.srcStorage, storage.DecompressConfig{ + ZStdDecodeConcurrency: 1, + }) + if err != nil { + return 0.0, false, errors.Trace(err) + } + idAlloc := kv.NewPanickingAllocators(tableInfo.SepAutoInc(), 0) + tbl, err := tables.TableFromMeta(idAlloc, tableInfo) + if err != nil { + return 0.0, false, errors.Trace(err) + } + logger := log.FromContext(ctx).With(zap.String("table", tableMeta.Name)) + kvEncoder, err := p.encBuilder.NewEncoder(ctx, &encode.EncodingConfig{ + SessionOptions: encode.SessionOptions{ + SQLMode: p.cfg.TiDB.SQLMode, + Timestamp: 0, + SysVars: sysVars, + AutoRandomSeed: 0, + }, + Table: tbl, + Logger: logger, + }) + if err != nil { + return 0.0, false, errors.Trace(err) + } + blockBufSize := int64(p.cfg.Mydumper.ReadBlockSize) + + var parser mydump.Parser + switch tableMeta.DataFiles[0].FileMeta.Type { + case mydump.SourceTypeCSV: + hasHeader := p.cfg.Mydumper.CSV.Header + // Create a utf8mb4 convertor to encode and decode data with the charset of CSV files. 
+ charsetConvertor, err := mydump.NewCharsetConvertor(p.cfg.Mydumper.DataCharacterSet, p.cfg.Mydumper.DataInvalidCharReplace) + if err != nil { + return 0.0, false, errors.Trace(err) + } + parser, err = mydump.NewCSVParser(ctx, &p.cfg.Mydumper.CSV, reader, blockBufSize, p.ioWorkers, hasHeader, charsetConvertor) + if err != nil { + return 0.0, false, errors.Trace(err) + } + case mydump.SourceTypeSQL: + parser = mydump.NewChunkParser(ctx, p.cfg.TiDB.SQLMode, reader, blockBufSize, p.ioWorkers) + case mydump.SourceTypeParquet: + parser, err = mydump.NewParquetParser(ctx, p.srcStorage, reader, sampleFile.Path) + if err != nil { + return 0.0, false, errors.Trace(err) + } + default: + panic(fmt.Sprintf("file '%s' with unknown source type '%s'", sampleFile.Path, sampleFile.Type.String())) + } + //nolint: errcheck + defer parser.Close() + logger.Begin(zap.InfoLevel, "sample file") + igCols, err := p.cfg.Mydumper.IgnoreColumns.GetIgnoreColumns(dbName, tableMeta.Name, p.cfg.Mydumper.CaseSensitive) + if err != nil { + return 0.0, false, errors.Trace(err) + } + + initializedColumns := false + var ( + columnPermutation []int + kvSize uint64 + rowSize uint64 + extendVals []types.Datum + ) + rowCount := 0 + dataKVs := p.encBuilder.MakeEmptyRows() + indexKVs := p.encBuilder.MakeEmptyRows() + lastKey := make([]byte, 0) + isRowOrdered = true +outloop: + for { + offset, _ := parser.Pos() + err = parser.ReadRow() + columnNames := parser.Columns() + + switch errors.Cause(err) { + case nil: + if !initializedColumns { + ignoreColsMap := igCols.ColumnsMap() + if len(columnPermutation) == 0 { + columnPermutation, err = createColumnPermutation( + columnNames, + ignoreColsMap, + tableInfo, + log.FromContext(ctx)) + if err != nil { + return 0.0, false, errors.Trace(err) + } + } + if len(sampleFile.ExtendData.Columns) > 0 { + _, extendVals = filterColumns(columnNames, sampleFile.ExtendData, ignoreColsMap, tableInfo) + } + initializedColumns = true + lastRow := parser.LastRow() + lastRowLen := len(lastRow.Row) + extendColsMap := make(map[string]int) + for i, c := range sampleFile.ExtendData.Columns { + extendColsMap[c] = lastRowLen + i + } + for i, col := range tableInfo.Columns { + if p, ok := extendColsMap[col.Name.O]; ok { + columnPermutation[i] = p + } + } + } + case io.EOF: + break outloop + default: + err = errors.Annotatef(err, "in file offset %d", offset) + return 0.0, false, errors.Trace(err) + } + lastRow := parser.LastRow() + rowCount++ + lastRow.Row = append(lastRow.Row, extendVals...) 
+ + var dataChecksum, indexChecksum verification.KVChecksum + kvs, encodeErr := kvEncoder.Encode(lastRow.Row, lastRow.RowID, columnPermutation, offset) + if encodeErr != nil { + encodeErr = errMgr.RecordTypeError(ctx, log.FromContext(ctx), tableInfo.Name.O, sampleFile.Path, offset, + "" /* use a empty string here because we don't actually record */, encodeErr) + if encodeErr != nil { + return 0.0, false, errors.Annotatef(encodeErr, "in file at offset %d", offset) + } + if rowCount < maxSampleRowCount { + continue + } + break + } + if isRowOrdered { + kvs.ClassifyAndAppend(&dataKVs, &dataChecksum, &indexKVs, &indexChecksum) + for _, kv := range kv.Rows2KvPairs(dataKVs) { + if len(lastKey) == 0 { + lastKey = kv.Key + } else if bytes.Compare(lastKey, kv.Key) > 0 { + isRowOrdered = false + break + } + } + dataKVs = dataKVs.Clear() + indexKVs = indexKVs.Clear() + } + kvSize += kvs.Size() + rowSize += uint64(lastRow.Length) + parser.RecycleRow(lastRow) + + failpoint.Inject("mock-kv-size", func(val failpoint.Value) { + kvSize += uint64(val.(int)) + }) + if rowSize > maxSampleDataSize || rowCount > maxSampleRowCount { + break + } + } + + if rowSize > 0 && kvSize > rowSize { + resultIndexRatio = float64(kvSize) / float64(rowSize) + } + log.FromContext(ctx).Info("Sample source data", zap.String("table", tableMeta.Name), zap.Float64("IndexRatio", tableMeta.IndexRatio), zap.Bool("IsSourceOrder", tableMeta.IsRowOrdered)) + return resultIndexRatio, isRowOrdered, nil +} + +// GetMaxReplica implements the PreImportInfoGetter interface. +func (p *PreImportInfoGetterImpl) GetMaxReplica(ctx context.Context) (uint64, error) { + return p.targetInfoGetter.GetMaxReplica(ctx) +} + +// GetStorageInfo gets the storage information on the target. +// It implements the PreImportInfoGetter interface. +func (p *PreImportInfoGetterImpl) GetStorageInfo(ctx context.Context) (*pdhttp.StoresInfo, error) { + return p.targetInfoGetter.GetStorageInfo(ctx) +} + +// GetEmptyRegionsInfo gets the region information of all the empty regions on the target. +// It implements the PreImportInfoGetter interface. +func (p *PreImportInfoGetterImpl) GetEmptyRegionsInfo(ctx context.Context) (*pdhttp.RegionsInfo, error) { + return p.targetInfoGetter.GetEmptyRegionsInfo(ctx) +} + +// IsTableEmpty checks whether the specified table on the target DB contains data or not. +// It implements the PreImportInfoGetter interface. +func (p *PreImportInfoGetterImpl) IsTableEmpty(ctx context.Context, schemaName string, tableName string) (*bool, error) { + return p.targetInfoGetter.IsTableEmpty(ctx, schemaName, tableName) +} + +// FetchRemoteDBModels fetches the database structures from the remote target. +// It implements the PreImportInfoGetter interface. +func (p *PreImportInfoGetterImpl) FetchRemoteDBModels(ctx context.Context) ([]*model.DBInfo, error) { + return p.targetInfoGetter.FetchRemoteDBModels(ctx) +} + +// FetchRemoteTableModels fetches the table structures from the remote target. +// It implements the PreImportInfoGetter interface. +func (p *PreImportInfoGetterImpl) FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) { + return p.targetInfoGetter.FetchRemoteTableModels(ctx, schemaName) +} + +// CheckVersionRequirements performs the check whether the target satisfies the version requirements. +// It implements the PreImportInfoGetter interface. +// Mydump database metas are retrieved from the context. 
+func (p *PreImportInfoGetterImpl) CheckVersionRequirements(ctx context.Context) error { + return p.targetInfoGetter.CheckVersionRequirements(ctx) +} + +// GetTargetSysVariablesForImport gets some important systam variables for importing on the target. +// It implements the PreImportInfoGetter interface. +// It has caching mechanism. +func (p *PreImportInfoGetterImpl) GetTargetSysVariablesForImport(ctx context.Context, opts ...ropts.GetPreInfoOption) map[string]string { + var sysVars map[string]string + + getPreInfoCfg := p.getPreInfoCfg.Clone() + for _, o := range opts { + o(getPreInfoCfg) + } + sysVars = p.sysVarsCache + if sysVars != nil && !getPreInfoCfg.ForceReloadCache { + return sysVars + } + sysVars = p.targetInfoGetter.GetTargetSysVariablesForImport(ctx) + p.sysVarsCache = sysVars + return sysVars +} diff --git a/lightning/pkg/importer/import.go b/lightning/pkg/importer/import.go index 04218f2f1102e..60a1a3059c780 100644 --- a/lightning/pkg/importer/import.go +++ b/lightning/pkg/importer/import.go @@ -133,9 +133,9 @@ var DeliverPauser = common.NewPauser() // nolint:gochecknoinits // TODO: refactor func init() { - failpoint.Inject("SetMinDeliverBytes", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("SetMinDeliverBytes")); _err_ == nil { minDeliverBytes = uint64(v.(int)) - }) + } } type saveCp struct { @@ -541,7 +541,7 @@ func (rc *Controller) Close() { // Run starts the restore task. func (rc *Controller) Run(ctx context.Context) error { - failpoint.Inject("beforeRun", func() {}) + failpoint.Eval(_curpkg_("beforeRun")) opts := []func(context.Context) error{ rc.setGlobalVariables, @@ -636,10 +636,10 @@ func (rc *Controller) initCheckpoint(ctx context.Context) error { if err != nil { return common.ErrInitCheckpoint.Wrap(err).GenWithStackByArgs() } - failpoint.Inject("InitializeCheckpointExit", func() { + if _, _err_ := failpoint.Eval(_curpkg_("InitializeCheckpointExit")); _err_ == nil { log.FromContext(ctx).Warn("exit triggered", zap.String("failpoint", "InitializeCheckpointExit")) os.Exit(0) - }) + } if err := rc.loadDesiredTableInfos(ctx); err != nil { return err } @@ -864,7 +864,7 @@ func (rc *Controller) listenCheckpointUpdates(logger log.Logger) { lock.Unlock() //nolint:scopelint // This would be either INLINED or ERASED, at compile time. - failpoint.Inject("SlowDownCheckpointUpdate", func() {}) + failpoint.Eval(_curpkg_("SlowDownCheckpointUpdate")) if len(cpd) > 0 { err := rc.checkpointsDB.Update(rc.taskCtx, cpd) @@ -897,25 +897,25 @@ func (rc *Controller) listenCheckpointUpdates(logger log.Logger) { lock.Unlock() //nolint:scopelint // This would be either INLINED or ERASED, at compile time. - failpoint.Inject("FailIfImportedChunk", func() { + if _, _err_ := failpoint.Eval(_curpkg_("FailIfImportedChunk")); _err_ == nil { if merger, ok := scp.merger.(*checkpoints.ChunkCheckpointMerger); ok && merger.Pos >= merger.EndOffset { rc.checkpointsWg.Done() rc.checkpointsWg.Wait() panic("forcing failure due to FailIfImportedChunk") } - }) + } //nolint:scopelint // This would be either INLINED or ERASED, at compile time. 
- failpoint.Inject("FailIfStatusBecomes", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("FailIfStatusBecomes")); _err_ == nil { if merger, ok := scp.merger.(*checkpoints.StatusCheckpointMerger); ok && merger.EngineID >= 0 && int(merger.Status) == val.(int) { rc.checkpointsWg.Done() rc.checkpointsWg.Wait() panic("forcing failure due to FailIfStatusBecomes") } - }) + } //nolint:scopelint // This would be either INLINED or ERASED, at compile time. - failpoint.Inject("FailIfIndexEngineImported", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("FailIfIndexEngineImported")); _err_ == nil { if merger, ok := scp.merger.(*checkpoints.StatusCheckpointMerger); ok && merger.EngineID == checkpoints.WholeTableEngineID && merger.Status == checkpoints.CheckpointStatusIndexImported && val.(int) > 0 { @@ -923,10 +923,10 @@ func (rc *Controller) listenCheckpointUpdates(logger log.Logger) { rc.checkpointsWg.Wait() panic("forcing failure due to FailIfIndexEngineImported") } - }) + } //nolint:scopelint // This would be either INLINED or ERASED, at compile time. - failpoint.Inject("KillIfImportedChunk", func() { + if _, _err_ := failpoint.Eval(_curpkg_("KillIfImportedChunk")); _err_ == nil { if merger, ok := scp.merger.(*checkpoints.ChunkCheckpointMerger); ok && merger.Pos >= merger.EndOffset { rc.checkpointsWg.Done() rc.checkpointsWg.Wait() @@ -938,9 +938,9 @@ func (rc *Controller) listenCheckpointUpdates(logger log.Logger) { scp.waitCh <- context.Canceled } } - failpoint.Return() + return } - }) + } } // Don't put this statement in defer function at the beginning. failpoint function may call it manually. rc.checkpointsWg.Done() diff --git a/lightning/pkg/importer/import.go__failpoint_stash__ b/lightning/pkg/importer/import.go__failpoint_stash__ new file mode 100644 index 0000000000000..04218f2f1102e --- /dev/null +++ b/lightning/pkg/importer/import.go__failpoint_stash__ @@ -0,0 +1,2080 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package importer + +import ( + "context" + "database/sql" + "fmt" + "math" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/coreos/go-semver/semver" + "github.com/docker/go-units" + "github.com/google/uuid" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/metapb" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/pdutil" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/br/pkg/utils" + "github.com/pingcap/tidb/br/pkg/version" + "github.com/pingcap/tidb/br/pkg/version/build" + "github.com/pingcap/tidb/lightning/pkg/web" + tidbconfig "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/distsql" + "github.com/pingcap/tidb/pkg/keyspace" + tidbkv "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/lightning/backend" + "github.com/pingcap/tidb/pkg/lightning/backend/encode" + "github.com/pingcap/tidb/pkg/lightning/backend/local" + "github.com/pingcap/tidb/pkg/lightning/backend/tidb" + "github.com/pingcap/tidb/pkg/lightning/checkpoints" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/lightning/errormanager" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/lightning/metric" + "github.com/pingcap/tidb/pkg/lightning/mydump" + "github.com/pingcap/tidb/pkg/lightning/tikv" + "github.com/pingcap/tidb/pkg/lightning/worker" + "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/session" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/store/driver" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/collate" + "github.com/pingcap/tidb/pkg/util/etcd" + regexprrouter "github.com/pingcap/tidb/pkg/util/regexpr-router" + "github.com/pingcap/tidb/pkg/util/set" + "github.com/prometheus/client_golang/prometheus" + tikvconfig "github.com/tikv/client-go/v2/config" + kvutil "github.com/tikv/client-go/v2/util" + pd "github.com/tikv/pd/client" + pdhttp "github.com/tikv/pd/client/http" + "github.com/tikv/pd/client/retry" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/atomic" + "go.uber.org/multierr" + "go.uber.org/zap" +) + +// compact levels +const ( + FullLevelCompact = -1 + Level1Compact = 1 +) + +const ( + compactStateIdle int32 = iota + compactStateDoing +) + +// task related table names and create table statements. 
+const ( + TaskMetaTableName = "task_meta_v2" + TableMetaTableName = "table_meta" + // CreateTableMetadataTable stores the per-table sub jobs information used by TiDB Lightning + CreateTableMetadataTable = `CREATE TABLE IF NOT EXISTS %s.%s ( + task_id BIGINT(20) UNSIGNED, + table_id BIGINT(64) NOT NULL, + table_name VARCHAR(64) NOT NULL, + row_id_base BIGINT(20) NOT NULL DEFAULT 0, + row_id_max BIGINT(20) NOT NULL DEFAULT 0, + total_kvs_base BIGINT(20) UNSIGNED NOT NULL DEFAULT 0, + total_bytes_base BIGINT(20) UNSIGNED NOT NULL DEFAULT 0, + checksum_base BIGINT(20) UNSIGNED NOT NULL DEFAULT 0, + total_kvs BIGINT(20) UNSIGNED NOT NULL DEFAULT 0, + total_bytes BIGINT(20) UNSIGNED NOT NULL DEFAULT 0, + checksum BIGINT(20) UNSIGNED NOT NULL DEFAULT 0, + status VARCHAR(32) NOT NULL, + has_duplicates BOOL NOT NULL DEFAULT 0, + PRIMARY KEY (table_id, task_id) + );` + // CreateTaskMetaTable stores the pre-lightning metadata used by TiDB Lightning + CreateTaskMetaTable = `CREATE TABLE IF NOT EXISTS %s.%s ( + task_id BIGINT(20) UNSIGNED NOT NULL, + pd_cfgs VARCHAR(2048) NOT NULL DEFAULT '', + status VARCHAR(32) NOT NULL, + state TINYINT(1) NOT NULL DEFAULT 0 COMMENT '0: normal, 1: exited before finish', + tikv_source_bytes BIGINT(20) UNSIGNED NOT NULL DEFAULT 0, + tiflash_source_bytes BIGINT(20) UNSIGNED NOT NULL DEFAULT 0, + tikv_avail BIGINT(20) UNSIGNED NOT NULL DEFAULT 0, + tiflash_avail BIGINT(20) UNSIGNED NOT NULL DEFAULT 0, + PRIMARY KEY (task_id) + );` +) + +var ( + minTiKVVersionForConflictStrategy = *semver.New("5.2.0") + maxTiKVVersionForConflictStrategy = version.NextMajorVersion() +) + +// DeliverPauser is a shared pauser to pause progress to (*chunkProcessor).encodeLoop +var DeliverPauser = common.NewPauser() + +// nolint:gochecknoinits // TODO: refactor +func init() { + failpoint.Inject("SetMinDeliverBytes", func(v failpoint.Value) { + minDeliverBytes = uint64(v.(int)) + }) +} + +type saveCp struct { + tableName string + merger checkpoints.TableCheckpointMerger + waitCh chan<- error +} + +type errorSummary struct { + status checkpoints.CheckpointStatus + err error +} + +type errorSummaries struct { + sync.Mutex + logger log.Logger + summary map[string]errorSummary +} + +// makeErrorSummaries returns an initialized errorSummaries instance +func makeErrorSummaries(logger log.Logger) errorSummaries { + return errorSummaries{ + logger: logger, + summary: make(map[string]errorSummary), + } +} + +func (es *errorSummaries) emitLog() { + es.Lock() + defer es.Unlock() + + if errorCount := len(es.summary); errorCount > 0 { + logger := es.logger + logger.Error("tables failed to be imported", zap.Int("count", errorCount)) + for tableName, errorSummary := range es.summary { + logger.Error("-", + zap.String("table", tableName), + zap.String("status", errorSummary.status.MetricName()), + log.ShortError(errorSummary.err), + ) + } + } +} + +func (es *errorSummaries) record(tableName string, err error, status checkpoints.CheckpointStatus) { + es.Lock() + defer es.Unlock() + es.summary[tableName] = errorSummary{status: status, err: err} +} + +const ( + diskQuotaStateIdle int32 = iota + diskQuotaStateChecking + diskQuotaStateImporting +) + +// Controller controls the whole import process. 
+type Controller struct { + taskCtx context.Context + cfg *config.Config + dbMetas []*mydump.MDDatabaseMeta + dbInfos map[string]*checkpoints.TidbDBInfo + tableWorkers *worker.Pool + indexWorkers *worker.Pool + regionWorkers *worker.Pool + ioWorkers *worker.Pool + checksumWorks *worker.Pool + pauser *common.Pauser + engineMgr backend.EngineManager + backend backend.Backend + db *sql.DB + pdCli pd.Client + pdHTTPCli pdhttp.Client + + sysVars map[string]string + tls *common.TLS + checkTemplate Template + + errorSummaries errorSummaries + + checkpointsDB checkpoints.DB + saveCpCh chan saveCp + checkpointsWg sync.WaitGroup + + closedEngineLimit *worker.Pool + addIndexLimit *worker.Pool + + store storage.ExternalStorage + ownStore bool + metaMgrBuilder metaMgrBuilder + errorMgr *errormanager.ErrorManager + taskMgr taskMetaMgr + + diskQuotaLock sync.RWMutex + diskQuotaState atomic.Int32 + compactState atomic.Int32 + status *LightningStatus + dupIndicator *atomic.Bool + + preInfoGetter PreImportInfoGetter + precheckItemBuilder *PrecheckItemBuilder + encBuilder encode.EncodingBuilder + tikvModeSwitcher local.TiKVModeSwitcher + + keyspaceName string + resourceGroupName string + taskType string +} + +// LightningStatus provides the finished bytes and total bytes of the current task. +// It should keep the value after restart from checkpoint. +// When it is tidb backend, FinishedFileSize can be counted after chunk data is +// restored to tidb. When it is local backend it's counted after whole engine is +// imported. +// TotalFileSize may be an estimated value, so when the task is finished, it may +// not equal to FinishedFileSize. +type LightningStatus struct { + backend string + FinishedFileSize atomic.Int64 + TotalFileSize atomic.Int64 +} + +// ControllerParam contains many parameters for creating a Controller. +type ControllerParam struct { + // databases that dumper created + DBMetas []*mydump.MDDatabaseMeta + // a pointer to status to report it to caller + Status *LightningStatus + // storage interface to read the dump data + DumpFileStorage storage.ExternalStorage + // true if DumpFileStorage is created by lightning. In some cases where lightning is a library, the framework may pass an DumpFileStorage + OwnExtStorage bool + // used by lightning server mode to pause tasks + Pauser *common.Pauser + // DB is a connection pool to TiDB + DB *sql.DB + // storage interface to write file checkpoints + CheckpointStorage storage.ExternalStorage + // when CheckpointStorage is not nil, save file checkpoint to it with this name + CheckpointName string + // DupIndicator can expose the duplicate detection result to the caller + DupIndicator *atomic.Bool + // Keyspace name + KeyspaceName string + // ResourceGroup name for current TiDB user + ResourceGroupName string + // TaskType is the source component name use for background task control. + TaskType string +} + +// NewImportController creates a new Controller instance. +func NewImportController( + ctx context.Context, + cfg *config.Config, + param *ControllerParam, +) (*Controller, error) { + param.Pauser = DeliverPauser + return NewImportControllerWithPauser(ctx, cfg, param) +} + +// NewImportControllerWithPauser creates a new Controller instance with a pauser. 
+func NewImportControllerWithPauser( + ctx context.Context, + cfg *config.Config, + p *ControllerParam, +) (*Controller, error) { + tls, err := cfg.ToTLS() + if err != nil { + return nil, err + } + + var cpdb checkpoints.DB + // if CheckpointStorage is set, we should use given ExternalStorage to create checkpoints. + if p.CheckpointStorage != nil { + cpdb, err = checkpoints.NewFileCheckpointsDBWithExstorageFileName(ctx, p.CheckpointStorage.URI(), p.CheckpointStorage, p.CheckpointName) + if err != nil { + return nil, common.ErrOpenCheckpoint.Wrap(err).GenWithStackByArgs() + } + } else { + cpdb, err = checkpoints.OpenCheckpointsDB(ctx, cfg) + if err != nil { + if berrors.Is(err, common.ErrUnknownCheckpointDriver) { + return nil, err + } + return nil, common.ErrOpenCheckpoint.Wrap(err).GenWithStackByArgs() + } + } + + taskCp, err := cpdb.TaskCheckpoint(ctx) + if err != nil { + return nil, common.ErrReadCheckpoint.Wrap(err).GenWithStack("get task checkpoint failed") + } + if err := verifyCheckpoint(cfg, taskCp); err != nil { + return nil, errors.Trace(err) + } + // reuse task id to reuse task meta correctly. + if taskCp != nil { + cfg.TaskID = taskCp.TaskID + } + + db := p.DB + errorMgr := errormanager.New(db, cfg, log.FromContext(ctx)) + if err := errorMgr.Init(ctx); err != nil { + return nil, common.ErrInitErrManager.Wrap(err).GenWithStackByArgs() + } + + var encodingBuilder encode.EncodingBuilder + var backendObj backend.Backend + var pdCli pd.Client + var pdHTTPCli pdhttp.Client + switch cfg.TikvImporter.Backend { + case config.BackendTiDB: + encodingBuilder = tidb.NewEncodingBuilder() + backendObj = tidb.NewTiDBBackend(ctx, db, cfg, errorMgr) + case config.BackendLocal: + var rLimit local.RlimT + rLimit, err = local.GetSystemRLimit() + if err != nil { + return nil, err + } + maxOpenFiles := int(rLimit / local.RlimT(cfg.App.TableConcurrency)) + // check overflow + if maxOpenFiles < 0 { + maxOpenFiles = math.MaxInt32 + } + + addrs := strings.Split(cfg.TiDB.PdAddr, ",") + pdCli, err = pd.NewClientWithContext(ctx, addrs, tls.ToPDSecurityOption()) + if err != nil { + return nil, errors.Trace(err) + } + pdHTTPCli = pdhttp.NewClientWithServiceDiscovery( + "lightning", + pdCli.GetServiceDiscovery(), + pdhttp.WithTLSConfig(tls.TLSConfig()), + ).WithBackoffer(retry.InitialBackoffer(time.Second, time.Second, pdutil.PDRequestRetryTime*time.Second)) + + if isLocalBackend(cfg) && cfg.Conflict.Strategy != config.NoneOnDup { + if err := tikv.CheckTiKVVersion(ctx, pdHTTPCli, minTiKVVersionForConflictStrategy, maxTiKVVersionForConflictStrategy); err != nil { + if !berrors.Is(err, berrors.ErrVersionMismatch) { + return nil, common.ErrCheckKVVersion.Wrap(err).GenWithStackByArgs() + } + log.FromContext(ctx).Warn("TiKV version doesn't support conflict strategy. The resolution algorithm will fall back to 'none'", zap.Error(err)) + cfg.Conflict.Strategy = config.NoneOnDup + } + } + + initGlobalConfig(tls.ToTiKVSecurityConfig()) + + encodingBuilder = local.NewEncodingBuilder(ctx) + + // get resource group name. 
+ exec := common.SQLWithRetry{ + DB: db, + Logger: log.FromContext(ctx), + } + if err := exec.QueryRow(ctx, "", "select current_resource_group();", &p.ResourceGroupName); err != nil { + if common.IsFunctionNotExistErr(err, "current_resource_group") { + log.FromContext(ctx).Warn("current_resource_group() not supported, ignore this error", zap.Error(err)) + } + } + + taskType, err := common.GetExplicitRequestSourceTypeFromDB(ctx, db) + if err != nil { + return nil, errors.Annotatef(err, "get system variable '%s' failed", variable.TiDBExplicitRequestSourceType) + } + if taskType == "" { + taskType = kvutil.ExplicitTypeLightning + } + p.TaskType = taskType + + // TODO: we should not need to check config here. + // Instead, we should perform the following during switch mode: + // 1. for each tikv, try to switch mode without any ranges. + // 2. if it returns normally, it means the store is using a raft-v1 engine. + // 3. if it returns the `partitioned-raft-kv only support switch mode with range set` error, + // it means the store is a raft-v2 engine and we will include the ranges from now on. + isRaftKV2, err := common.IsRaftKV2(ctx, db) + if err != nil { + log.FromContext(ctx).Warn("check isRaftKV2 failed", zap.Error(err)) + } + var raftKV2SwitchModeDuration time.Duration + if isRaftKV2 { + raftKV2SwitchModeDuration = cfg.Cron.SwitchMode.Duration + } + backendConfig := local.NewBackendConfig(cfg, maxOpenFiles, p.KeyspaceName, p.ResourceGroupName, p.TaskType, raftKV2SwitchModeDuration) + backendObj, err = local.NewBackend(ctx, tls, backendConfig, pdCli.GetServiceDiscovery()) + if err != nil { + return nil, common.NormalizeOrWrapErr(common.ErrUnknown, err) + } + err = verifyLocalFile(ctx, cpdb, cfg.TikvImporter.SortedKVDir) + if err != nil { + return nil, err + } + default: + return nil, common.ErrUnknownBackend.GenWithStackByArgs(cfg.TikvImporter.Backend) + } + p.Status.backend = cfg.TikvImporter.Backend + + var metaBuilder metaMgrBuilder + isSSTImport := cfg.TikvImporter.Backend == config.BackendLocal + switch { + case isSSTImport && cfg.TikvImporter.ParallelImport: + metaBuilder = &dbMetaMgrBuilder{ + db: db, + taskID: cfg.TaskID, + schema: cfg.App.MetaSchemaName, + needChecksum: cfg.PostRestore.Checksum != config.OpLevelOff, + } + case isSSTImport: + metaBuilder = singleMgrBuilder{ + taskID: cfg.TaskID, + } + default: + metaBuilder = noopMetaMgrBuilder{} + } + + var wrapper backend.TargetInfoGetter + if cfg.TikvImporter.Backend == config.BackendLocal { + wrapper = local.NewTargetInfoGetter(tls, db, pdHTTPCli) + } else { + wrapper = tidb.NewTargetInfoGetter(db) + } + ioWorkers := worker.NewPool(ctx, cfg.App.IOConcurrency, "io") + targetInfoGetter := &TargetInfoGetterImpl{ + cfg: cfg, + db: db, + backend: wrapper, + pdHTTPCli: pdHTTPCli, + } + preInfoGetter, err := NewPreImportInfoGetter( + cfg, + p.DBMetas, + p.DumpFileStorage, + targetInfoGetter, + ioWorkers, + encodingBuilder, + ) + if err != nil { + return nil, errors.Trace(err) + } + + preCheckBuilder := NewPrecheckItemBuilder( + cfg, p.DBMetas, preInfoGetter, cpdb, pdHTTPCli, + ) + + rc := &Controller{ + taskCtx: ctx, + cfg: cfg, + dbMetas: p.DBMetas, + tableWorkers: nil, + indexWorkers: nil, + regionWorkers: worker.NewPool(ctx, cfg.App.RegionConcurrency, "region"), + ioWorkers: ioWorkers, + checksumWorks: worker.NewPool(ctx, cfg.TiDB.ChecksumTableConcurrency, "checksum"), + pauser: p.Pauser, + engineMgr: backend.MakeEngineManager(backendObj), + backend: backendObj, + pdCli: pdCli, + pdHTTPCli: pdHTTPCli, + db: db, + sysVars: 
common.DefaultImportantVariables, + tls: tls, + checkTemplate: NewSimpleTemplate(), + + errorSummaries: makeErrorSummaries(log.FromContext(ctx)), + checkpointsDB: cpdb, + saveCpCh: make(chan saveCp), + closedEngineLimit: worker.NewPool(ctx, cfg.App.TableConcurrency*2, "closed-engine"), + // Currently, TiDB add index acceration doesn't support multiple tables simultaneously. + // So we use a single worker to ensure at most one table is adding index at the same time. + addIndexLimit: worker.NewPool(ctx, 1, "add-index"), + + store: p.DumpFileStorage, + ownStore: p.OwnExtStorage, + metaMgrBuilder: metaBuilder, + errorMgr: errorMgr, + status: p.Status, + taskMgr: nil, + dupIndicator: p.DupIndicator, + + preInfoGetter: preInfoGetter, + precheckItemBuilder: preCheckBuilder, + encBuilder: encodingBuilder, + tikvModeSwitcher: local.NewTiKVModeSwitcher(tls.TLSConfig(), pdHTTPCli, log.FromContext(ctx).Logger), + + keyspaceName: p.KeyspaceName, + resourceGroupName: p.ResourceGroupName, + taskType: p.TaskType, + } + + return rc, nil +} + +// Close closes the controller. +func (rc *Controller) Close() { + rc.backend.Close() + _ = rc.db.Close() + if rc.pdCli != nil { + rc.pdCli.Close() + } +} + +// Run starts the restore task. +func (rc *Controller) Run(ctx context.Context) error { + failpoint.Inject("beforeRun", func() {}) + + opts := []func(context.Context) error{ + rc.setGlobalVariables, + rc.restoreSchema, + rc.preCheckRequirements, + rc.initCheckpoint, + rc.importTables, + rc.fullCompact, + rc.cleanCheckpoints, + } + + task := log.FromContext(ctx).Begin(zap.InfoLevel, "the whole procedure") + + var err error + finished := false +outside: + for i, process := range opts { + err = process(ctx) + if i == len(opts)-1 { + finished = true + } + logger := task.With(zap.Int("step", i), log.ShortError(err)) + + switch { + case err == nil: + case log.IsContextCanceledError(err): + logger.Info("task canceled") + break outside + default: + logger.Error("run failed") + break outside // ps : not continue + } + } + + // if process is cancelled, should make sure checkpoints are written to db. + if !finished { + rc.waitCheckpointFinish() + } + + task.End(zap.ErrorLevel, err) + rc.errorMgr.LogErrorDetails() + rc.errorSummaries.emitLog() + + return errors.Trace(err) +} + +func (rc *Controller) restoreSchema(ctx context.Context) error { + // create table with schema file + // we can handle the duplicated created with createIfNotExist statement + // and we will check the schema in TiDB is valid with the datafile in DataCheck later. + logger := log.FromContext(ctx) + concurrency := min(rc.cfg.App.RegionConcurrency, 8) + // sql.DB is a connection pool, we set it to concurrency + 1(for job generator) + // to reuse connections, as we might call db.Conn/conn.Close many times. + // there's no API to get sql.DB.MaxIdleConns, so we revert to its default which is 2 + rc.db.SetMaxIdleConns(concurrency + 1) + defer rc.db.SetMaxIdleConns(2) + schemaImp := mydump.NewSchemaImporter(logger, rc.cfg.TiDB.SQLMode, rc.db, rc.store, concurrency) + err := schemaImp.Run(ctx, rc.dbMetas) + if err != nil { + return err + } + + dbInfos, err := rc.preInfoGetter.GetAllTableStructures(ctx) + if err != nil { + return errors.Trace(err) + } + // For local backend, we need DBInfo.ID to operate the global autoid allocator. 
+ if isLocalBackend(rc.cfg) { + dbs, err := tikv.FetchRemoteDBModelsFromTLS(ctx, rc.tls) + if err != nil { + return errors.Trace(err) + } + dbIDs := make(map[string]int64) + for _, db := range dbs { + dbIDs[db.Name.L] = db.ID + } + for _, dbInfo := range dbInfos { + dbInfo.ID = dbIDs[strings.ToLower(dbInfo.Name)] + } + } + rc.dbInfos = dbInfos + rc.sysVars = rc.preInfoGetter.GetTargetSysVariablesForImport(ctx) + + return nil +} + +// initCheckpoint initializes all tables' checkpoint data +func (rc *Controller) initCheckpoint(ctx context.Context) error { + // Load new checkpoints + err := rc.checkpointsDB.Initialize(ctx, rc.cfg, rc.dbInfos) + if err != nil { + return common.ErrInitCheckpoint.Wrap(err).GenWithStackByArgs() + } + failpoint.Inject("InitializeCheckpointExit", func() { + log.FromContext(ctx).Warn("exit triggered", zap.String("failpoint", "InitializeCheckpointExit")) + os.Exit(0) + }) + if err := rc.loadDesiredTableInfos(ctx); err != nil { + return err + } + + rc.checkpointsWg.Add(1) // checkpointsWg will be done in `rc.listenCheckpointUpdates` + go rc.listenCheckpointUpdates(log.FromContext(ctx)) + + // Estimate the number of chunks for progress reporting + return rc.estimateChunkCountIntoMetrics(ctx) +} + +func (rc *Controller) loadDesiredTableInfos(ctx context.Context) error { + for _, dbInfo := range rc.dbInfos { + for _, tableInfo := range dbInfo.Tables { + cp, err := rc.checkpointsDB.Get(ctx, common.UniqueTable(dbInfo.Name, tableInfo.Name)) + if err != nil { + return err + } + // If checkpoint is disabled, cp.TableInfo will be nil. + // In this case, we just use current tableInfo as desired tableInfo. + if cp.TableInfo != nil { + tableInfo.Desired = cp.TableInfo + } + } + } + return nil +} + +// verifyCheckpoint check whether previous task checkpoint is compatible with task config +func verifyCheckpoint(cfg *config.Config, taskCp *checkpoints.TaskCheckpoint) error { + if taskCp == nil { + return nil + } + // always check the backend value even with 'check-requirements = false' + retryUsage := "destroy all checkpoints" + if cfg.Checkpoint.Driver == config.CheckpointDriverFile { + retryUsage = fmt.Sprintf("delete the file '%s'", cfg.Checkpoint.DSN) + } + retryUsage += " and remove all restored tables and try again" + + if cfg.TikvImporter.Backend != taskCp.Backend { + return common.ErrInvalidCheckpoint.GenWithStack("config 'tikv-importer.backend' value '%s' different from checkpoint value '%s', please %s", cfg.TikvImporter.Backend, taskCp.Backend, retryUsage) + } + + if cfg.App.CheckRequirements { + if build.ReleaseVersion != taskCp.LightningVer { + var displayVer string + if len(taskCp.LightningVer) != 0 { + displayVer = fmt.Sprintf("at '%s'", taskCp.LightningVer) + } else { + displayVer = "before v4.0.6/v3.0.19" + } + return common.ErrInvalidCheckpoint.GenWithStack("lightning version is '%s', but checkpoint was created %s, please %s", build.ReleaseVersion, displayVer, retryUsage) + } + + errorFmt := "config '%s' value '%s' different from checkpoint value '%s'. 
You may set 'check-requirements = false' to skip this check or " + retryUsage + if cfg.Mydumper.SourceDir != taskCp.SourceDir { + return common.ErrInvalidCheckpoint.GenWithStack(errorFmt, "mydumper.data-source-dir", cfg.Mydumper.SourceDir, taskCp.SourceDir) + } + + if cfg.TikvImporter.Backend == config.BackendLocal && cfg.TikvImporter.SortedKVDir != taskCp.SortedKVDir { + return common.ErrInvalidCheckpoint.GenWithStack(errorFmt, "mydumper.sorted-kv-dir", cfg.TikvImporter.SortedKVDir, taskCp.SortedKVDir) + } + } + + return nil +} + +// for local backend, we should check if local SST exists in disk, otherwise we'll lost data +func verifyLocalFile(ctx context.Context, cpdb checkpoints.DB, dir string) error { + targetTables, err := cpdb.GetLocalStoringTables(ctx) + if err != nil { + return errors.Trace(err) + } + for tableName, engineIDs := range targetTables { + for _, engineID := range engineIDs { + _, eID := backend.MakeUUID(tableName, int64(engineID)) + file := local.Engine{UUID: eID} + err := file.Exist(dir) + if err != nil { + log.FromContext(ctx).Error("can't find local file", + zap.String("table name", tableName), + zap.Int32("engine ID", engineID)) + if os.IsNotExist(err) { + err = common.ErrCheckLocalFile.GenWithStackByArgs(tableName, dir) + } else { + err = common.ErrCheckLocalFile.Wrap(err).GenWithStackByArgs(tableName, dir) + } + return err + } + } + } + return nil +} + +func (rc *Controller) estimateChunkCountIntoMetrics(ctx context.Context) error { + estimatedChunkCount := 0.0 + estimatedEngineCnt := int64(0) + for _, dbMeta := range rc.dbMetas { + for _, tableMeta := range dbMeta.Tables { + batchSize := mydump.CalculateBatchSize(float64(rc.cfg.Mydumper.BatchSize), + tableMeta.IsRowOrdered, float64(tableMeta.TotalSize)) + tableName := common.UniqueTable(dbMeta.Name, tableMeta.Name) + dbCp, err := rc.checkpointsDB.Get(ctx, tableName) + if err != nil { + return errors.Trace(err) + } + + fileChunks := make(map[string]float64) + for engineID, eCp := range dbCp.Engines { + if eCp.Status < checkpoints.CheckpointStatusImported { + estimatedEngineCnt++ + } + if engineID == common.IndexEngineID { + continue + } + for _, c := range eCp.Chunks { + if _, ok := fileChunks[c.Key.Path]; !ok { + fileChunks[c.Key.Path] = 0.0 + } + remainChunkCnt := float64(c.UnfinishedSize()) / float64(c.TotalSize()) + fileChunks[c.Key.Path] += remainChunkCnt + } + } + // estimate engines count if engine cp is empty + if len(dbCp.Engines) == 0 { + estimatedEngineCnt += ((tableMeta.TotalSize + int64(batchSize) - 1) / int64(batchSize)) + 1 + } + for _, fileMeta := range tableMeta.DataFiles { + if cnt, ok := fileChunks[fileMeta.FileMeta.Path]; ok { + estimatedChunkCount += cnt + continue + } + if fileMeta.FileMeta.Type == mydump.SourceTypeCSV { + cfg := rc.cfg.Mydumper + if fileMeta.FileMeta.FileSize > int64(cfg.MaxRegionSize) && cfg.StrictFormat && + !cfg.CSV.Header && fileMeta.FileMeta.Compression == mydump.CompressionNone { + estimatedChunkCount += math.Round(float64(fileMeta.FileMeta.FileSize) / float64(cfg.MaxRegionSize)) + } else { + estimatedChunkCount++ + } + } else { + estimatedChunkCount++ + } + } + } + } + if m, ok := metric.FromContext(ctx); ok { + m.ChunkCounter.WithLabelValues(metric.ChunkStateEstimated).Add(estimatedChunkCount) + m.ProcessedEngineCounter.WithLabelValues(metric.ChunkStateEstimated, metric.TableResultSuccess). 
+ Add(float64(estimatedEngineCnt)) + } + return nil +} + +func firstErr(errors ...error) error { + for _, err := range errors { + if err != nil { + return err + } + } + return nil +} + +func (rc *Controller) saveStatusCheckpoint(ctx context.Context, tableName string, engineID int32, err error, statusIfSucceed checkpoints.CheckpointStatus) error { + merger := &checkpoints.StatusCheckpointMerger{Status: statusIfSucceed, EngineID: engineID} + + logger := log.FromContext(ctx).With(zap.String("table", tableName), zap.Int32("engine_id", engineID), + zap.String("new_status", statusIfSucceed.MetricName()), zap.Error(err)) + logger.Debug("update checkpoint") + + switch { + case err == nil: + case utils.MessageIsRetryableStorageError(err.Error()), common.IsContextCanceledError(err): + // recoverable error, should not be recorded in checkpoint + // which will prevent lightning from automatically recovering + return nil + default: + // unrecoverable error + merger.SetInvalid() + rc.errorSummaries.record(tableName, err, statusIfSucceed) + } + + if m, ok := metric.FromContext(ctx); ok { + if engineID == checkpoints.WholeTableEngineID { + m.RecordTableCount(statusIfSucceed.MetricName(), err) + } else { + m.RecordEngineCount(statusIfSucceed.MetricName(), err) + } + } + + waitCh := make(chan error, 1) + rc.saveCpCh <- saveCp{tableName: tableName, merger: merger, waitCh: waitCh} + + select { + case saveCpErr := <-waitCh: + if saveCpErr != nil { + logger.Error("failed to save status checkpoint", log.ShortError(saveCpErr)) + } + return saveCpErr + case <-ctx.Done(): + return ctx.Err() + } +} + +// listenCheckpointUpdates will combine several checkpoints together to reduce database load. +func (rc *Controller) listenCheckpointUpdates(logger log.Logger) { + var lock sync.Mutex + coalesed := make(map[string]*checkpoints.TableCheckpointDiff) + var waiters []chan<- error + + hasCheckpoint := make(chan struct{}, 1) + defer close(hasCheckpoint) + + go func() { + for range hasCheckpoint { + lock.Lock() + cpd := coalesed + coalesed = make(map[string]*checkpoints.TableCheckpointDiff) + ws := waiters + waiters = nil + lock.Unlock() + + //nolint:scopelint // This would be either INLINED or ERASED, at compile time. + failpoint.Inject("SlowDownCheckpointUpdate", func() {}) + + if len(cpd) > 0 { + err := rc.checkpointsDB.Update(rc.taskCtx, cpd) + for _, w := range ws { + w <- common.NormalizeOrWrapErr(common.ErrUpdateCheckpoint, err) + } + web.BroadcastCheckpointDiff(cpd) + } + rc.checkpointsWg.Done() + } + }() + + for scp := range rc.saveCpCh { + lock.Lock() + cpd, ok := coalesed[scp.tableName] + if !ok { + cpd = checkpoints.NewTableCheckpointDiff() + coalesed[scp.tableName] = cpd + } + scp.merger.MergeInto(cpd) + if scp.waitCh != nil { + waiters = append(waiters, scp.waitCh) + } + + if len(hasCheckpoint) == 0 { + rc.checkpointsWg.Add(1) + hasCheckpoint <- struct{}{} + } + + lock.Unlock() + + //nolint:scopelint // This would be either INLINED or ERASED, at compile time. + failpoint.Inject("FailIfImportedChunk", func() { + if merger, ok := scp.merger.(*checkpoints.ChunkCheckpointMerger); ok && merger.Pos >= merger.EndOffset { + rc.checkpointsWg.Done() + rc.checkpointsWg.Wait() + panic("forcing failure due to FailIfImportedChunk") + } + }) + + //nolint:scopelint // This would be either INLINED or ERASED, at compile time. 
+ failpoint.Inject("FailIfStatusBecomes", func(val failpoint.Value) { + if merger, ok := scp.merger.(*checkpoints.StatusCheckpointMerger); ok && merger.EngineID >= 0 && int(merger.Status) == val.(int) { + rc.checkpointsWg.Done() + rc.checkpointsWg.Wait() + panic("forcing failure due to FailIfStatusBecomes") + } + }) + + //nolint:scopelint // This would be either INLINED or ERASED, at compile time. + failpoint.Inject("FailIfIndexEngineImported", func(val failpoint.Value) { + if merger, ok := scp.merger.(*checkpoints.StatusCheckpointMerger); ok && + merger.EngineID == checkpoints.WholeTableEngineID && + merger.Status == checkpoints.CheckpointStatusIndexImported && val.(int) > 0 { + rc.checkpointsWg.Done() + rc.checkpointsWg.Wait() + panic("forcing failure due to FailIfIndexEngineImported") + } + }) + + //nolint:scopelint // This would be either INLINED or ERASED, at compile time. + failpoint.Inject("KillIfImportedChunk", func() { + if merger, ok := scp.merger.(*checkpoints.ChunkCheckpointMerger); ok && merger.Pos >= merger.EndOffset { + rc.checkpointsWg.Done() + rc.checkpointsWg.Wait() + if err := common.KillMySelf(); err != nil { + logger.Warn("KillMySelf() failed to kill itself", log.ShortError(err)) + } + for scp := range rc.saveCpCh { + if scp.waitCh != nil { + scp.waitCh <- context.Canceled + } + } + failpoint.Return() + } + }) + } + // Don't put this statement in defer function at the beginning. failpoint function may call it manually. + rc.checkpointsWg.Done() +} + +// buildRunPeriodicActionAndCancelFunc build the runPeriodicAction func and a cancel func +func (rc *Controller) buildRunPeriodicActionAndCancelFunc(ctx context.Context, stop <-chan struct{}) (func(), func(bool)) { + cancelFuncs := make([]func(bool), 0) + closeFuncs := make([]func(), 0) + // a nil channel blocks forever. + // if the cron duration is zero we use the nil channel to skip the action. + var logProgressChan <-chan time.Time + if rc.cfg.Cron.LogProgress.Duration > 0 { + logProgressTicker := time.NewTicker(rc.cfg.Cron.LogProgress.Duration) + closeFuncs = append(closeFuncs, func() { + logProgressTicker.Stop() + }) + logProgressChan = logProgressTicker.C + } + + var switchModeChan <-chan time.Time + // tidb backend don't need to switch tikv to import mode + if isLocalBackend(rc.cfg) && rc.cfg.Cron.SwitchMode.Duration > 0 { + switchModeTicker := time.NewTicker(rc.cfg.Cron.SwitchMode.Duration) + cancelFuncs = append(cancelFuncs, func(bool) { switchModeTicker.Stop() }) + cancelFuncs = append(cancelFuncs, func(do bool) { + if do { + rc.tikvModeSwitcher.ToNormalMode(ctx) + } + }) + switchModeChan = switchModeTicker.C + } + + var checkQuotaChan <-chan time.Time + // only local storage has disk quota concern. 
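+ // For non-local backends (or a zero check interval) checkQuotaChan stays nil, so its case in
+ // the select below never fires.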
+ if rc.cfg.TikvImporter.Backend == config.BackendLocal && rc.cfg.Cron.CheckDiskQuota.Duration > 0 { + checkQuotaTicker := time.NewTicker(rc.cfg.Cron.CheckDiskQuota.Duration) + cancelFuncs = append(cancelFuncs, func(bool) { checkQuotaTicker.Stop() }) + checkQuotaChan = checkQuotaTicker.C + } + + return func() { + defer func() { + for _, f := range closeFuncs { + f() + } + }() + if rc.cfg.Cron.SwitchMode.Duration > 0 && isLocalBackend(rc.cfg) { + rc.tikvModeSwitcher.ToImportMode(ctx) + } + start := time.Now() + for { + select { + case <-ctx.Done(): + log.FromContext(ctx).Warn("stopping periodic actions", log.ShortError(ctx.Err())) + return + case <-stop: + log.FromContext(ctx).Info("everything imported, stopping periodic actions") + return + + case <-switchModeChan: + // periodically switch to import mode, as requested by TiKV 3.0 + // TiKV will switch back to normal mode if we didn't call this again within 10 minutes + rc.tikvModeSwitcher.ToImportMode(ctx) + + case <-logProgressChan: + metrics, ok := metric.FromContext(ctx) + if !ok { + log.FromContext(ctx).Warn("couldn't find metrics from context, skip log progress") + continue + } + // log the current progress periodically, so OPS will know that we're still working + nanoseconds := float64(time.Since(start).Nanoseconds()) + totalRestoreBytes := metric.ReadCounter(metrics.BytesCounter.WithLabelValues(metric.StateTotalRestore)) + restoredBytes := metric.ReadCounter(metrics.BytesCounter.WithLabelValues(metric.StateRestored)) + totalRowsToRestore := metric.ReadAllCounters(metrics.RowsCounter.MetricVec, prometheus.Labels{"state": metric.StateTotalRestore}) + restoredRows := metric.ReadAllCounters(metrics.RowsCounter.MetricVec, prometheus.Labels{"state": metric.StateRestored}) + // the estimated chunk is not accurate(likely under estimated), but the actual count is not accurate + // before the last table start, so use the bigger of the two should be a workaround + estimated := metric.ReadCounter(metrics.ChunkCounter.WithLabelValues(metric.ChunkStateEstimated)) + pending := metric.ReadCounter(metrics.ChunkCounter.WithLabelValues(metric.ChunkStatePending)) + if estimated < pending { + estimated = pending + } + finished := metric.ReadCounter(metrics.ChunkCounter.WithLabelValues(metric.ChunkStateFinished)) + totalTables := metric.ReadCounter(metrics.TableCounter.WithLabelValues(metric.TableStatePending, metric.TableResultSuccess)) + completedTables := metric.ReadCounter(metrics.TableCounter.WithLabelValues(metric.TableStateCompleted, metric.TableResultSuccess)) + bytesRead := metric.ReadHistogramSum(metrics.RowReadBytesHistogram) + engineEstimated := metric.ReadCounter(metrics.ProcessedEngineCounter.WithLabelValues(metric.ChunkStateEstimated, metric.TableResultSuccess)) + enginePending := metric.ReadCounter(metrics.ProcessedEngineCounter.WithLabelValues(metric.ChunkStatePending, metric.TableResultSuccess)) + if engineEstimated < enginePending { + engineEstimated = enginePending + } + engineFinished := metric.ReadCounter(metrics.ProcessedEngineCounter.WithLabelValues(metric.TableStateImported, metric.TableResultSuccess)) + bytesWritten := metric.ReadCounter(metrics.BytesCounter.WithLabelValues(metric.StateRestoreWritten)) + bytesImported := metric.ReadCounter(metrics.BytesCounter.WithLabelValues(metric.StateImported)) + + var state string + var remaining zap.Field + switch { + case finished >= estimated: + if engineFinished < engineEstimated { + state = "importing" + } else { + state = "post-processing" + } + case finished > 0: + state = "writing" 
+ default: + state = "preparing" + } + + // lightning restore is separated into restore engine and import engine, they are both parallelized + // and pipelined between engines, so we can only weight the progress of those 2 phase to get the + // total progress. + // + // for local & importer backend: + // in most case import engine is faster since there's little computations, but inside one engine + // restore and import is serialized, the progress of those two will not differ too much, and + // import engine determines the end time of the whole restore, so we average them for now. + // the result progress may fall behind the real progress if import is faster. + // + // for tidb backend, we do nothing during import engine, so we use restore engine progress as the + // total progress. + restoreBytesField := zap.Skip() + importBytesField := zap.Skip() + restoreRowsField := zap.Skip() + remaining = zap.Skip() + totalPercent := 0.0 + if restoredBytes > 0 || restoredRows > 0 { + var restorePercent float64 + if totalRowsToRestore > 0 { + restorePercent = math.Min(restoredRows/totalRowsToRestore, 1.0) + restoreRowsField = zap.String("restore-rows", fmt.Sprintf("%.0f/%.0f", + restoredRows, totalRowsToRestore)) + } else { + restorePercent = math.Min(restoredBytes/totalRestoreBytes, 1.0) + restoreRowsField = zap.String("restore-rows", fmt.Sprintf("%.0f/%.0f(estimated)", + restoredRows, restoredRows/restorePercent)) + } + metrics.ProgressGauge.WithLabelValues(metric.ProgressPhaseRestore).Set(restorePercent) + if rc.cfg.TikvImporter.Backend != config.BackendTiDB { + var importPercent float64 + if bytesWritten > 0 { + // estimate total import bytes from written bytes + // when importPercent = 1, totalImportBytes = bytesWritten, but there's case + // bytesImported may bigger or smaller than bytesWritten such as when deduplicate + // we calculate progress using engines then use the bigger one in case bytesImported is + // smaller. + totalImportBytes := bytesWritten / restorePercent + biggerPercent := math.Max(bytesImported/totalImportBytes, engineFinished/engineEstimated) + importPercent = math.Min(biggerPercent, 1.0) + importBytesField = zap.String("import-bytes", fmt.Sprintf("%s/%s(estimated)", + units.BytesSize(bytesImported), units.BytesSize(totalImportBytes))) + } + metrics.ProgressGauge.WithLabelValues(metric.ProgressPhaseImport).Set(importPercent) + totalPercent = (restorePercent + importPercent) / 2 + } else { + totalPercent = restorePercent + } + if totalPercent < 1.0 { + remainNanoseconds := (1.0 - totalPercent) / totalPercent * nanoseconds + remaining = zap.Duration("remaining", time.Duration(remainNanoseconds).Round(time.Second)) + } + restoreBytesField = zap.String("restore-bytes", fmt.Sprintf("%s/%s", + units.BytesSize(restoredBytes), units.BytesSize(totalRestoreBytes))) + } + metrics.ProgressGauge.WithLabelValues(metric.ProgressPhaseTotal).Set(totalPercent) + + formatPercent := func(num, denom float64) string { + if denom > 0 { + return fmt.Sprintf(" (%.1f%%)", num/denom*100) + } + return "" + } + + // avoid output bytes speed if there are no unfinished chunks + encodeSpeedField := zap.Skip() + if bytesRead > 0 { + encodeSpeedField = zap.Float64("encode speed(MiB/s)", bytesRead/(1048576e-9*nanoseconds)) + } + + // Note: a speed of 28 MiB/s roughly corresponds to 100 GiB/hour. 
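+ // For reference, with restorePercent = 0.8 and importPercent = 0.6 on the local backend the
+ // weighting above gives totalPercent = 0.7, and the remaining-time estimate scales the elapsed
+ // time by (1 - 0.7) / 0.7.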
+ log.FromContext(ctx).Info("progress", + zap.String("total", fmt.Sprintf("%.1f%%", totalPercent*100)), + // zap.String("files", fmt.Sprintf("%.0f/%.0f (%.1f%%)", finished, estimated, finished/estimated*100)), + zap.String("tables", fmt.Sprintf("%.0f/%.0f%s", completedTables, totalTables, formatPercent(completedTables, totalTables))), + zap.String("chunks", fmt.Sprintf("%.0f/%.0f%s", finished, estimated, formatPercent(finished, estimated))), + zap.String("engines", fmt.Sprintf("%.f/%.f%s", engineFinished, engineEstimated, formatPercent(engineFinished, engineEstimated))), + restoreBytesField, restoreRowsField, importBytesField, + encodeSpeedField, + zap.String("state", state), + remaining, + ) + + case <-checkQuotaChan: + // verify the total space occupied by sorted-kv-dir is below the quota, + // otherwise we perform an emergency import. + rc.enforceDiskQuota(ctx) + } + } + }, func(do bool) { + log.FromContext(ctx).Info("cancel periodic actions", zap.Bool("do", do)) + for _, f := range cancelFuncs { + f(do) + } + } +} + +func (rc *Controller) buildTablesRanges() []tidbkv.KeyRange { + var keyRanges []tidbkv.KeyRange + for _, dbInfo := range rc.dbInfos { + for _, tableInfo := range dbInfo.Tables { + if ranges, err := distsql.BuildTableRanges(tableInfo.Core); err == nil { + keyRanges = append(keyRanges, ranges...) + } + } + } + return keyRanges +} + +type checksumManagerKeyType struct{} + +var checksumManagerKey checksumManagerKeyType + +const ( + pauseGCTTLForDupeRes = time.Hour + pauseGCIntervalForDupeRes = time.Minute +) + +func (rc *Controller) keepPauseGCForDupeRes(ctx context.Context) (<-chan struct{}, error) { + tlsOpt := rc.tls.ToPDSecurityOption() + addrs := strings.Split(rc.cfg.TiDB.PdAddr, ",") + pdCli, err := pd.NewClientWithContext(ctx, addrs, tlsOpt) + if err != nil { + return nil, errors.Trace(err) + } + + serviceID := "lightning-duplicate-resolution-" + uuid.New().String() + ttl := int64(pauseGCTTLForDupeRes / time.Second) + + var ( + safePoint uint64 + paused bool + ) + // Try to get the minimum safe point across all services as our GC safe point. 
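+ // The first UpdateServiceGCSafePoint call registers a placeholder value of 1 and returns the
+ // current minimum across all services; the second call re-registers this service at that
+ // minimum. If the returned minimum has not advanced past it, the pause took effect; otherwise
+ // another attempt is made after a short sleep.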
+ for i := 0; i < 10; i++ { + if i > 0 { + time.Sleep(time.Second * 3) + } + minSafePoint, err := pdCli.UpdateServiceGCSafePoint(ctx, serviceID, ttl, 1) + if err != nil { + pdCli.Close() + return nil, errors.Trace(err) + } + newMinSafePoint, err := pdCli.UpdateServiceGCSafePoint(ctx, serviceID, ttl, minSafePoint) + if err != nil { + pdCli.Close() + return nil, errors.Trace(err) + } + if newMinSafePoint <= minSafePoint { + safePoint = minSafePoint + paused = true + break + } + log.FromContext(ctx).Warn( + "Failed to register GC safe point because the current minimum safe point is newer"+ + " than what we assume, will retry newMinSafePoint next time", + zap.Uint64("minSafePoint", minSafePoint), + zap.Uint64("newMinSafePoint", newMinSafePoint), + ) + } + if !paused { + pdCli.Close() + return nil, common.ErrPauseGC.GenWithStack("failed to pause GC for duplicate resolution after all retries") + } + + exitCh := make(chan struct{}) + go func(safePoint uint64) { + defer pdCli.Close() + defer close(exitCh) + ticker := time.NewTicker(pauseGCIntervalForDupeRes) + defer ticker.Stop() + for { + select { + case <-ticker.C: + minSafePoint, err := pdCli.UpdateServiceGCSafePoint(ctx, serviceID, ttl, safePoint) + if err != nil { + log.FromContext(ctx).Warn("Failed to register GC safe point", zap.Error(err)) + continue + } + if minSafePoint > safePoint { + log.FromContext(ctx).Warn("The current minimum safe point is newer than what we hold, duplicate records are at"+ + "risk of being GC and not detectable", + zap.Uint64("safePoint", safePoint), + zap.Uint64("minSafePoint", minSafePoint), + ) + safePoint = minSafePoint + } + case <-ctx.Done(): + stopCtx, cancelFunc := context.WithTimeout(context.Background(), time.Second*5) + if _, err := pdCli.UpdateServiceGCSafePoint(stopCtx, serviceID, 0, safePoint); err != nil { + log.FromContext(ctx).Warn("Failed to reset safe point ttl to zero", zap.Error(err)) + } + // just make compiler happy + cancelFunc() + return + } + } + }(safePoint) + return exitCh, nil +} + +func (rc *Controller) importTables(ctx context.Context) (finalErr error) { + // output error summary + defer rc.outputErrorSummary() + + if isLocalBackend(rc.cfg) && rc.cfg.Conflict.Strategy != config.NoneOnDup { + subCtx, cancel := context.WithCancel(ctx) + exitCh, err := rc.keepPauseGCForDupeRes(subCtx) + if err != nil { + cancel() + return errors.Trace(err) + } + defer func() { + cancel() + <-exitCh + }() + } + + logTask := log.FromContext(ctx).Begin(zap.InfoLevel, "restore all tables data") + if rc.tableWorkers == nil { + rc.tableWorkers = worker.NewPool(ctx, rc.cfg.App.TableConcurrency, "table") + } + if rc.indexWorkers == nil { + rc.indexWorkers = worker.NewPool(ctx, rc.cfg.App.IndexConcurrency, "index") + } + + // for local backend, we should disable some pd scheduler and change some settings, to + // make split region and ingest sst more stable + // because importer backend is mostly use for v3.x cluster which doesn't support these api, + // so we also don't do this for import backend + finishSchedulers := func() { + if rc.taskMgr != nil { + rc.taskMgr.Close() + } + } + // if one lightning failed abnormally, and can't determine whether it needs to switch back, + // we do not do switch back automatically + switchBack := false + cleanup := false + postProgress := func() error { return nil } + var kvStore tidbkv.Storage + var etcdCli *clientv3.Client + + if isLocalBackend(rc.cfg) { + var ( + restoreFn pdutil.UndoFunc + err error + ) + + if rc.cfg.TikvImporter.PausePDSchedulerScope == 
config.PausePDSchedulerScopeGlobal { + logTask.Info("pause pd scheduler of global scope") + + restoreFn, err = rc.taskMgr.CheckAndPausePdSchedulers(ctx) + if err != nil { + return errors.Trace(err) + } + } + + finishSchedulers = func() { + taskFinished := finalErr == nil + // use context.Background to make sure this restore function can still be executed even if ctx is canceled + restoreCtx := context.Background() + needSwitchBack, needCleanup, err := rc.taskMgr.CheckAndFinishRestore(restoreCtx, taskFinished) + if err != nil { + logTask.Warn("check restore pd schedulers failed", zap.Error(err)) + return + } + switchBack = needSwitchBack + cleanup = needCleanup + + if needSwitchBack && restoreFn != nil { + logTask.Info("add back PD leader®ion schedulers") + if restoreE := restoreFn(restoreCtx); restoreE != nil { + logTask.Warn("failed to restore removed schedulers, you may need to restore them manually", zap.Error(restoreE)) + } + } + + if rc.taskMgr != nil { + rc.taskMgr.Close() + } + } + + // Disable GC because TiDB enables GC already. + urlsWithScheme := rc.pdCli.GetServiceDiscovery().GetServiceURLs() + // remove URL scheme + urlsWithoutScheme := make([]string, 0, len(urlsWithScheme)) + for _, u := range urlsWithScheme { + u = strings.TrimPrefix(u, "http://") + u = strings.TrimPrefix(u, "https://") + urlsWithoutScheme = append(urlsWithoutScheme, u) + } + kvStore, err = driver.TiKVDriver{}.OpenWithOptions( + fmt.Sprintf( + "tikv://%s?disableGC=true&keyspaceName=%s", + strings.Join(urlsWithoutScheme, ","), rc.keyspaceName, + ), + driver.WithSecurity(rc.tls.ToTiKVSecurityConfig()), + ) + if err != nil { + return errors.Trace(err) + } + etcdCli, err := clientv3.New(clientv3.Config{ + Endpoints: urlsWithScheme, + AutoSyncInterval: 30 * time.Second, + TLS: rc.tls.TLSConfig(), + }) + if err != nil { + return errors.Trace(err) + } + etcd.SetEtcdCliByNamespace(etcdCli, keyspace.MakeKeyspaceEtcdNamespace(kvStore.GetCodec())) + + manager, err := NewChecksumManager(ctx, rc, kvStore) + if err != nil { + return errors.Trace(err) + } + ctx = context.WithValue(ctx, &checksumManagerKey, manager) + + undo, err := rc.registerTaskToPD(ctx) + if err != nil { + return errors.Trace(err) + } + defer undo() + } + + type task struct { + tr *TableImporter + cp *checkpoints.TableCheckpoint + } + + totalTables := 0 + for _, dbMeta := range rc.dbMetas { + totalTables += len(dbMeta.Tables) + } + postProcessTaskChan := make(chan task, totalTables) + + var wg sync.WaitGroup + var restoreErr common.OnceError + + stopPeriodicActions := make(chan struct{}) + + periodicActions, cancelFunc := rc.buildRunPeriodicActionAndCancelFunc(ctx, stopPeriodicActions) + go periodicActions() + + defer close(stopPeriodicActions) + + defer func() { + finishSchedulers() + cancelFunc(switchBack) + + if err := postProgress(); err != nil { + logTask.End(zap.ErrorLevel, err) + finalErr = err + return + } + logTask.End(zap.ErrorLevel, nil) + // clean up task metas + if cleanup { + logTask.Info("cleanup task metas") + if cleanupErr := rc.taskMgr.Cleanup(context.Background()); cleanupErr != nil { + logTask.Warn("failed to clean task metas, you may need to restore them manually", zap.Error(cleanupErr)) + } + // cleanup table meta and schema db if needed. 
+ if err := rc.taskMgr.CleanupAllMetas(context.Background()); err != nil { + logTask.Warn("failed to clean table task metas, you may need to restore them manually", zap.Error(err)) + } + } + if kvStore != nil { + if err := kvStore.Close(); err != nil { + logTask.Warn("failed to close kv store", zap.Error(err)) + } + } + if etcdCli != nil { + if err := etcdCli.Close(); err != nil { + logTask.Warn("failed to close etcd client", zap.Error(err)) + } + } + }() + + taskCh := make(chan task, rc.cfg.App.IndexConcurrency) + defer close(taskCh) + + for i := 0; i < rc.cfg.App.IndexConcurrency; i++ { + go func() { + for task := range taskCh { + tableLogTask := task.tr.logger.Begin(zap.InfoLevel, "restore table") + web.BroadcastTableCheckpoint(task.tr.tableName, task.cp) + + needPostProcess, err := task.tr.importTable(ctx, rc, task.cp) + if err != nil && !common.IsContextCanceledError(err) { + task.tr.logger.Error("failed to import table", zap.Error(err)) + } + + err = common.NormalizeOrWrapErr(common.ErrRestoreTable, err, task.tr.tableName) + tableLogTask.End(zap.ErrorLevel, err) + web.BroadcastError(task.tr.tableName, err) + if m, ok := metric.FromContext(ctx); ok { + m.RecordTableCount(metric.TableStateCompleted, err) + } + restoreErr.Set(err) + if needPostProcess { + postProcessTaskChan <- task + } + wg.Done() + } + }() + } + + var allTasks []task + var totalDataSizeToRestore int64 + for _, dbMeta := range rc.dbMetas { + dbInfo := rc.dbInfos[dbMeta.Name] + for _, tableMeta := range dbMeta.Tables { + tableInfo := dbInfo.Tables[tableMeta.Name] + tableName := common.UniqueTable(dbInfo.Name, tableInfo.Name) + cp, err := rc.checkpointsDB.Get(ctx, tableName) + if err != nil { + return errors.Trace(err) + } + if cp.Status < checkpoints.CheckpointStatusAllWritten && len(tableMeta.DataFiles) == 0 { + continue + } + igCols, err := rc.cfg.Mydumper.IgnoreColumns.GetIgnoreColumns(dbInfo.Name, tableInfo.Name, rc.cfg.Mydumper.CaseSensitive) + if err != nil { + return errors.Trace(err) + } + tr, err := NewTableImporter(tableName, tableMeta, dbInfo, tableInfo, cp, igCols.ColumnsMap(), kvStore, etcdCli, log.FromContext(ctx)) + if err != nil { + return errors.Trace(err) + } + + allTasks = append(allTasks, task{tr: tr, cp: cp}) + + if len(cp.Engines) == 0 { + for i, fi := range tableMeta.DataFiles { + totalDataSizeToRestore += fi.FileMeta.FileSize + if fi.FileMeta.Type == mydump.SourceTypeParquet { + numberRows, err := mydump.ReadParquetFileRowCountByFile(ctx, rc.store, fi.FileMeta) + if err != nil { + return errors.Trace(err) + } + if m, ok := metric.FromContext(ctx); ok { + m.RowsCounter.WithLabelValues(metric.StateTotalRestore, tableName).Add(float64(numberRows)) + } + fi.FileMeta.Rows = numberRows + tableMeta.DataFiles[i] = fi + } + } + } else { + for _, eng := range cp.Engines { + for _, chunk := range eng.Chunks { + // for parquet files filesize is more accurate, we can calculate correct unfinished bytes unless + // we set up the reader, so we directly use filesize here + if chunk.FileMeta.Type == mydump.SourceTypeParquet { + totalDataSizeToRestore += chunk.FileMeta.FileSize + if m, ok := metric.FromContext(ctx); ok { + m.RowsCounter.WithLabelValues(metric.StateTotalRestore, tableName).Add(float64(chunk.UnfinishedSize())) + } + } else { + totalDataSizeToRestore += chunk.UnfinishedSize() + } + } + } + } + } + } + + if m, ok := metric.FromContext(ctx); ok { + m.BytesCounter.WithLabelValues(metric.StateTotalRestore).Add(float64(totalDataSizeToRestore)) + } + + for i := range allTasks { + wg.Add(1) + select { + case 
taskCh <- allTasks[i]: + case <-ctx.Done(): + return ctx.Err() + } + } + + wg.Wait() + // if context is done, should return directly + select { + case <-ctx.Done(): + err := restoreErr.Get() + if err == nil { + err = ctx.Err() + } + logTask.End(zap.ErrorLevel, err) + return err + default: + } + + postProgress = func() error { + close(postProcessTaskChan) + // otherwise, we should run all tasks in the post-process task chan + for i := 0; i < rc.cfg.App.TableConcurrency; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for task := range postProcessTaskChan { + metaMgr := rc.metaMgrBuilder.TableMetaMgr(task.tr) + // force all the remain post-process tasks to be executed + _, err2 := task.tr.postProcess(ctx, rc, task.cp, true, metaMgr) + restoreErr.Set(err2) + } + }() + } + wg.Wait() + return restoreErr.Get() + } + + return nil +} + +func (rc *Controller) registerTaskToPD(ctx context.Context) (undo func(), _ error) { + etcdCli, err := dialEtcdWithCfg(ctx, rc.cfg, rc.pdCli.GetServiceDiscovery().GetServiceURLs()) + if err != nil { + return nil, errors.Trace(err) + } + + register := utils.NewTaskRegister(etcdCli, utils.RegisterLightning, fmt.Sprintf("lightning-%s", uuid.New())) + + undo = func() { + closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := register.Close(closeCtx); err != nil { + log.L().Warn("failed to unregister task", zap.Error(err)) + } + if err := etcdCli.Close(); err != nil { + log.L().Warn("failed to close etcd client", zap.Error(err)) + } + } + if err := register.RegisterTask(ctx); err != nil { + undo() + return nil, errors.Trace(err) + } + return undo, nil +} + +func addExtendDataForCheckpoint( + ctx context.Context, + cfg *config.Config, + cp *checkpoints.TableCheckpoint, +) error { + if len(cfg.Routes) == 0 { + return nil + } + hasExtractor := false + for _, route := range cfg.Routes { + hasExtractor = hasExtractor || route.TableExtractor != nil || route.SchemaExtractor != nil || route.SourceExtractor != nil + if hasExtractor { + break + } + } + if !hasExtractor { + return nil + } + + // Use default file router directly because fileRouter and router are not compatible + fileRouter, err := mydump.NewDefaultFileRouter(log.FromContext(ctx)) + if err != nil { + return err + } + var router *regexprrouter.RouteTable + router, err = regexprrouter.NewRegExprRouter(cfg.Mydumper.CaseSensitive, cfg.Routes) + if err != nil { + return err + } + for _, engine := range cp.Engines { + for _, chunk := range engine.Chunks { + _, file := filepath.Split(chunk.FileMeta.Path) + var res *mydump.RouteResult + res, err = fileRouter.Route(file) + if err != nil { + return err + } + extendCols, extendData := router.FetchExtendColumn(res.Schema, res.Name, cfg.Mydumper.SourceID) + chunk.FileMeta.ExtendData = mydump.ExtendColumnData{ + Columns: extendCols, + Values: extendData, + } + } + } + return nil +} + +func (rc *Controller) outputErrorSummary() { + if rc.errorMgr.HasError() { + fmt.Println(rc.errorMgr.Output()) + } +} + +// do full compaction for the whole data. +func (rc *Controller) fullCompact(ctx context.Context) error { + if !rc.cfg.PostRestore.Compact { + log.FromContext(ctx).Info("skip full compaction") + return nil + } + + // wait until any existing level-1 compact to complete first. 
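+ // The CompareAndSwap below only succeeds once compactState is back to compactStateIdle, i.e.
+ // once any in-flight compaction has finished, and it also claims the state for this run.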
+ task := log.FromContext(ctx).Begin(zap.InfoLevel, "wait for completion of existing level 1 compaction") + for !rc.compactState.CompareAndSwap(compactStateIdle, compactStateDoing) { + time.Sleep(100 * time.Millisecond) + } + task.End(zap.ErrorLevel, nil) + + return errors.Trace(rc.doCompact(ctx, FullLevelCompact)) +} + +func (rc *Controller) doCompact(ctx context.Context, level int32) error { + return tikv.ForAllStores( + ctx, + rc.pdHTTPCli, + metapb.StoreState_Offline, + func(c context.Context, store *pdhttp.MetaStore) error { + return tikv.Compact(c, rc.tls, store.Address, level, rc.resourceGroupName) + }, + ) +} + +func (rc *Controller) enforceDiskQuota(ctx context.Context) { + if !rc.diskQuotaState.CompareAndSwap(diskQuotaStateIdle, diskQuotaStateChecking) { + // do not run multiple the disk quota check / import simultaneously. + // (we execute the lock check in background to avoid blocking the cron thread) + return + } + + localBackend := rc.backend.(*local.Backend) + go func() { + // locker is assigned when we detect the disk quota is exceeded. + // before the disk quota is confirmed exceeded, we keep the diskQuotaLock + // unlocked to avoid periodically interrupting the writer threads. + var locker sync.Locker + defer func() { + rc.diskQuotaState.Store(diskQuotaStateIdle) + if locker != nil { + locker.Unlock() + } + }() + + isRetrying := false + + for { + // sleep for a cycle if we are retrying because there is nothing new to import. + if isRetrying { + select { + case <-ctx.Done(): + return + case <-time.After(rc.cfg.Cron.CheckDiskQuota.Duration): + } + } else { + isRetrying = true + } + + quota := int64(rc.cfg.TikvImporter.DiskQuota) + largeEngines, inProgressLargeEngines, totalDiskSize, totalMemSize := local.CheckDiskQuota(localBackend, quota) + if m, ok := metric.FromContext(ctx); ok { + m.LocalStorageUsageBytesGauge.WithLabelValues("disk").Set(float64(totalDiskSize)) + m.LocalStorageUsageBytesGauge.WithLabelValues("mem").Set(float64(totalMemSize)) + } + + logger := log.FromContext(ctx).With( + zap.Int64("diskSize", totalDiskSize), + zap.Int64("memSize", totalMemSize), + zap.Int64("quota", quota), + zap.Int("largeEnginesCount", len(largeEngines)), + zap.Int("inProgressLargeEnginesCount", inProgressLargeEngines)) + + if len(largeEngines) == 0 && inProgressLargeEngines == 0 { + logger.Debug("disk quota respected") + return + } + + if locker == nil { + // blocks all writers when we detected disk quota being exceeded. + rc.diskQuotaLock.Lock() + locker = &rc.diskQuotaLock + } + + logger.Warn("disk quota exceeded") + if len(largeEngines) == 0 { + logger.Warn("all large engines are already importing, keep blocking all writes") + continue + } + + // flush all engines so that checkpoints can be updated. + if err := rc.backend.FlushAllEngines(ctx); err != nil { + logger.Error("flush engine for disk quota failed, check again later", log.ShortError(err)) + return + } + + // at this point, all engines are synchronized on disk. + // we then import every large engines one by one and complete. + // if any engine failed to import, we just try again next time, since the data are still intact. + rc.diskQuotaState.Store(diskQuotaStateImporting) + task := logger.Begin(zap.WarnLevel, "importing large engines for disk quota") + var importErr error + for _, engine := range largeEngines { + // Use a larger split region size to avoid split the same region by many times. 
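+ // Both the split-size and split-keys thresholds passed below are scaled by
+ // MaxSplitRegionSizeRatio for this emergency import.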
+ if err := localBackend.UnsafeImportAndReset(ctx, engine, int64(config.SplitRegionSize)*int64(config.MaxSplitRegionSizeRatio), int64(config.SplitRegionKeys)*int64(config.MaxSplitRegionSizeRatio)); err != nil { + importErr = multierr.Append(importErr, err) + } + } + task.End(zap.ErrorLevel, importErr) + return + } + }() +} + +func (rc *Controller) setGlobalVariables(ctx context.Context) error { + // skip for tidb backend to be compatible with MySQL + if isTiDBBackend(rc.cfg) { + return nil + } + // set new collation flag base on tidb config + enabled, err := ObtainNewCollationEnabled(ctx, rc.db) + if err != nil { + return err + } + // we should enable/disable new collation here since in server mode, tidb config + // may be different in different tasks + collate.SetNewCollationEnabledForTest(enabled) + log.FromContext(ctx).Info(session.TidbNewCollationEnabled, zap.Bool("enabled", enabled)) + + return nil +} + +func (rc *Controller) waitCheckpointFinish() { + // wait checkpoint process finish so that we can do cleanup safely + close(rc.saveCpCh) + rc.checkpointsWg.Wait() +} + +func (rc *Controller) cleanCheckpoints(ctx context.Context) error { + rc.waitCheckpointFinish() + + if !rc.cfg.Checkpoint.Enable { + return nil + } + + logger := log.FromContext(ctx).With( + zap.Stringer("keepAfterSuccess", rc.cfg.Checkpoint.KeepAfterSuccess), + zap.Int64("taskID", rc.cfg.TaskID), + ) + + task := logger.Begin(zap.InfoLevel, "clean checkpoints") + var err error + switch rc.cfg.Checkpoint.KeepAfterSuccess { + case config.CheckpointRename: + err = rc.checkpointsDB.MoveCheckpoints(ctx, rc.cfg.TaskID) + case config.CheckpointRemove: + err = rc.checkpointsDB.RemoveCheckpoint(ctx, "all") + } + task.End(zap.ErrorLevel, err) + if err != nil { + return common.ErrCleanCheckpoint.Wrap(err).GenWithStackByArgs() + } + return nil +} + +func isLocalBackend(cfg *config.Config) bool { + return cfg.TikvImporter.Backend == config.BackendLocal +} + +func isTiDBBackend(cfg *config.Config) bool { + return cfg.TikvImporter.Backend == config.BackendTiDB +} + +// preCheckRequirements checks +// 1. Cluster resource +// 2. Local node resource +// 3. Cluster region +// 4. Lightning configuration +// before restore tables start. +func (rc *Controller) preCheckRequirements(ctx context.Context) error { + if err := rc.DataCheck(ctx); err != nil { + return errors.Trace(err) + } + + if rc.cfg.App.CheckRequirements { + if err := rc.ClusterIsAvailable(ctx); err != nil { + return errors.Trace(err) + } + + if rc.ownStore { + if err := rc.StoragePermission(ctx); err != nil { + return errors.Trace(err) + } + } + } + + if err := rc.metaMgrBuilder.Init(ctx); err != nil { + return common.ErrInitMetaManager.Wrap(err).GenWithStackByArgs() + } + taskExist := false + + // We still need to sample source data even if this task has existed, because we need to judge whether the + // source is in order as row key to decide how to sort local data. + estimatedSizeResult, err := rc.preInfoGetter.EstimateSourceDataSize(ctx) + if err != nil { + return common.ErrCheckDataSource.Wrap(err).GenWithStackByArgs() + } + estimatedDataSizeWithIndex := estimatedSizeResult.SizeWithIndex + estimatedTiflashDataSize := estimatedSizeResult.TiFlashSize + + // Do not import with too large concurrency because these data may be all unsorted. 
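+ // When big unsorted tables are present, table concurrency is capped at index concurrency below
+ // to limit how many of them are encoded and sorted locally at the same time.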
+ if estimatedSizeResult.HasUnsortedBigTables { + if rc.cfg.App.TableConcurrency > rc.cfg.App.IndexConcurrency { + rc.cfg.App.TableConcurrency = rc.cfg.App.IndexConcurrency + } + } + if rc.status != nil { + rc.status.TotalFileSize.Store(estimatedSizeResult.SizeWithoutIndex) + } + if isLocalBackend(rc.cfg) { + pdAddrs := rc.pdCli.GetServiceDiscovery().GetServiceURLs() + pdController, err := pdutil.NewPdController( + ctx, pdAddrs, rc.tls.TLSConfig(), rc.tls.ToPDSecurityOption(), + ) + if err != nil { + return common.NormalizeOrWrapErr(common.ErrCreatePDClient, err) + } + + // PdController will be closed when `taskMetaMgr` closes. + rc.taskMgr = rc.metaMgrBuilder.TaskMetaMgr(pdController) + taskExist, err = rc.taskMgr.CheckTaskExist(ctx) + if err != nil { + return common.ErrMetaMgrUnknown.Wrap(err).GenWithStackByArgs() + } + if !taskExist { + if err = rc.taskMgr.InitTask(ctx, estimatedDataSizeWithIndex, estimatedTiflashDataSize); err != nil { + return common.ErrMetaMgrUnknown.Wrap(err).GenWithStackByArgs() + } + } + if rc.cfg.TikvImporter.PausePDSchedulerScope == config.PausePDSchedulerScopeTable && + !rc.taskMgr.CanPauseSchedulerByKeyRange() { + return errors.New("target cluster don't support pause-pd-scheduler-scope=table, the minimal version required is 6.1.0") + } + if rc.cfg.App.CheckRequirements { + needCheck := true + if rc.cfg.Checkpoint.Enable { + taskCheckpoints, err := rc.checkpointsDB.TaskCheckpoint(ctx) + if err != nil { + return common.ErrReadCheckpoint.Wrap(err).GenWithStack("get task checkpoint failed") + } + // If task checkpoint is initialized, it means check has been performed before. + // We don't need and shouldn't check again, because lightning may have already imported some data. + needCheck = taskCheckpoints == nil + } + if needCheck { + err = rc.localResource(ctx) + if err != nil { + return common.ErrCheckLocalResource.Wrap(err).GenWithStackByArgs() + } + if err := rc.clusterResource(ctx); err != nil { + if err1 := rc.taskMgr.CleanupTask(ctx); err1 != nil { + log.FromContext(ctx).Warn("cleanup task failed", zap.Error(err1)) + return common.ErrMetaMgrUnknown.Wrap(err).GenWithStackByArgs() + } + } + if err := rc.checkClusterRegion(ctx); err != nil { + return common.ErrCheckClusterRegion.Wrap(err).GenWithStackByArgs() + } + } + // even if checkpoint exists, we still need to make sure CDC/PiTR task is not running. + if err := rc.checkCDCPiTR(ctx); err != nil { + return common.ErrCheckCDCPiTR.Wrap(err).GenWithStackByArgs() + } + } + } + + if rc.cfg.App.CheckRequirements { + fmt.Println(rc.checkTemplate.Output()) + } + if !rc.checkTemplate.Success() { + if !taskExist && rc.taskMgr != nil { + err := rc.taskMgr.CleanupTask(ctx) + if err != nil { + log.FromContext(ctx).Warn("cleanup task failed", zap.Error(err)) + } + } + return common.ErrPreCheckFailed.GenWithStackByArgs(rc.checkTemplate.FailedMsg()) + } + return nil +} + +// DataCheck checks the data schema which needs #rc.restoreSchema finished. 
+func (rc *Controller) DataCheck(ctx context.Context) error { + if rc.cfg.App.CheckRequirements { + if err := rc.HasLargeCSV(ctx); err != nil { + return errors.Trace(err) + } + } + + if err := rc.checkCheckpoints(ctx); err != nil { + return errors.Trace(err) + } + + if rc.cfg.App.CheckRequirements { + if err := rc.checkSourceSchema(ctx); err != nil { + return errors.Trace(err) + } + } + + if err := rc.checkTableEmpty(ctx); err != nil { + return common.ErrCheckTableEmpty.Wrap(err).GenWithStackByArgs() + } + if err := rc.checkCSVHeader(ctx); err != nil { + return common.ErrCheckCSVHeader.Wrap(err).GenWithStackByArgs() + } + + return nil +} + +var ( + maxKVQueueSize = 32 // Cache at most this number of rows before blocking the encode loop + minDeliverBytes uint64 = 96 * units.KiB // 96 KB (data + index). batch at least this amount of bytes to reduce number of messages +) + +type deliveredKVs struct { + kvs encode.Row // if kvs is nil, this indicated we've got the last message. + columns []string + offset int64 + rowID int64 + + realOffset int64 // indicates file reader's current position, only used for compressed files +} + +type deliverResult struct { + totalDur time.Duration + err error +} + +func saveCheckpoint(rc *Controller, t *TableImporter, engineID int32, chunk *checkpoints.ChunkCheckpoint) { + // We need to update the AllocBase every time we've finished a file. + // The AllocBase is determined by the maximum of the "handle" (_tidb_rowid + // or integer primary key), which can only be obtained by reading all data. + + var base int64 + if t.tableInfo.Core.ContainsAutoRandomBits() { + base = t.alloc.Get(autoid.AutoRandomType).Base() + 1 + } else { + base = t.alloc.Get(autoid.RowIDAllocType).Base() + 1 + } + rc.saveCpCh <- saveCp{ + tableName: t.tableName, + merger: &checkpoints.RebaseCheckpointMerger{ + AllocBase: base, + }, + } + rc.saveCpCh <- saveCp{ + tableName: t.tableName, + merger: &checkpoints.ChunkCheckpointMerger{ + EngineID: engineID, + Key: chunk.Key, + Checksum: chunk.Checksum, + Pos: chunk.Chunk.Offset, + RowID: chunk.Chunk.PrevRowIDMax, + ColumnPermutation: chunk.ColumnPermutation, + EndOffset: chunk.Chunk.EndOffset, + }, + } +} + +// filterColumns filter columns and extend columns. +// It accepts: +// - columnsNames, header in the data files; +// - extendData, extendData fetched through data file name, that is to say, table info; +// - ignoreColsMap, columns to be ignored when we import; +// - tableInfo, tableInfo of the target table; +// It returns: +// - filteredColumns, columns of the original data to import. +// - extendValueDatums, extended Data to import. +// The data we import will use filteredColumns as columns, use (parser.LastRow+extendValueDatums) as data +// ColumnPermutation will be modified to make sure the correspondence relationship is correct. +// if len(columnsNames) > 0, it means users has specified each field definition, we can just use users +func filterColumns(columnNames []string, extendData mydump.ExtendColumnData, ignoreColsMap map[string]struct{}, tableInfo *model.TableInfo) ([]string, []types.Datum) { + extendCols, extendVals := extendData.Columns, extendData.Values + extendColsSet := set.NewStringSet(extendCols...) 
+ filteredColumns := make([]string, 0, len(columnNames)) + if len(columnNames) > 0 { + if len(ignoreColsMap) > 0 { + for _, c := range columnNames { + _, ok := ignoreColsMap[c] + if !ok { + filteredColumns = append(filteredColumns, c) + } + } + } else { + filteredColumns = columnNames + } + } else if len(ignoreColsMap) > 0 || len(extendCols) > 0 { + // init column names by table schema + // after filtered out some columns, we must explicitly set the columns for TiDB backend + for _, col := range tableInfo.Columns { + _, ok := ignoreColsMap[col.Name.L] + // ignore all extend row values specified by users + if !col.Hidden && !ok && !extendColsSet.Exist(col.Name.O) { + filteredColumns = append(filteredColumns, col.Name.O) + } + } + } + extendValueDatums := make([]types.Datum, 0) + filteredColumns = append(filteredColumns, extendCols...) + for _, extendVal := range extendVals { + extendValueDatums = append(extendValueDatums, types.NewStringDatum(extendVal)) + } + return filteredColumns, extendValueDatums +} + +// check store liveness of tikv client-go requires GlobalConfig to work correctly, so we need to init it, +// else tikv will report SSL error when tls is enabled. +// and the SSL error seems affects normal logic of newer TiKV version, and cause the error "tikv: region is unavailable" +// during checksum. +// todo: DM relay on lightning physical mode too, but client-go doesn't support passing TLS data as bytes, +func initGlobalConfig(secCfg tikvconfig.Security) { + if secCfg.ClusterSSLCA != "" || secCfg.ClusterSSLCert != "" { + conf := tidbconfig.GetGlobalConfig() + conf.Security.ClusterSSLCA = secCfg.ClusterSSLCA + conf.Security.ClusterSSLCert = secCfg.ClusterSSLCert + conf.Security.ClusterSSLKey = secCfg.ClusterSSLKey + tidbconfig.StoreGlobalConfig(conf) + } +} diff --git a/lightning/pkg/importer/table_import.go b/lightning/pkg/importer/table_import.go index ccc0fcc088b3b..be00b6c6250e5 100644 --- a/lightning/pkg/importer/table_import.go +++ b/lightning/pkg/importer/table_import.go @@ -235,9 +235,9 @@ func (tr *TableImporter) importTable( } } - failpoint.Inject("FailAfterDuplicateDetection", func() { + if _, _err_ := failpoint.Eval(_curpkg_("FailAfterDuplicateDetection")); _err_ == nil { panic("forcing failure after duplicate detection") - }) + } } // 3. 
Drop indexes if add-index-by-sql is enabled @@ -279,9 +279,9 @@ func (tr *TableImporter) populateChunks(ctx context.Context, rc *Controller, cp tableRegions, err := mydump.MakeTableRegions(ctx, divideConfig) if err == nil { timestamp := time.Now().Unix() - failpoint.Inject("PopulateChunkTimestamp", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("PopulateChunkTimestamp")); _err_ == nil { timestamp = int64(v.(int)) - }) + } for _, region := range tableRegions { engine, found := cp.Engines[region.EngineID] if !found { @@ -579,13 +579,13 @@ func (tr *TableImporter) importEngines(pCtx context.Context, rc *Controller, cp if cp.Status < checkpoints.CheckpointStatusIndexImported { var err error if indexEngineCp.Status < checkpoints.CheckpointStatusImported { - failpoint.Inject("FailBeforeStartImportingIndexEngine", func() { + if _, _err_ := failpoint.Eval(_curpkg_("FailBeforeStartImportingIndexEngine")); _err_ == nil { errMsg := "fail before importing index KV data" tr.logger.Warn(errMsg) - failpoint.Return(errors.New(errMsg)) - }) + return errors.New(errMsg) + } err = tr.importKV(ctx, closedIndexEngine, rc) - failpoint.Inject("FailBeforeIndexEngineImported", func() { + if _, _err_ := failpoint.Eval(_curpkg_("FailBeforeIndexEngineImported")); _err_ == nil { finished := rc.status.FinishedFileSize.Load() total := rc.status.TotalFileSize.Load() tr.logger.Warn("print lightning status", @@ -593,7 +593,7 @@ func (tr *TableImporter) importEngines(pCtx context.Context, rc *Controller, cp zap.Int64("total", total), zap.Bool("equal", finished == total)) panic("forcing failure due to FailBeforeIndexEngineImported") - }) + } } saveCpErr := rc.saveStatusCheckpoint(ctx, tr.tableName, checkpoints.WholeTableEngineID, err, checkpoints.CheckpointStatusIndexImported) @@ -722,11 +722,11 @@ ChunkLoop: } checkFlushLock.Unlock() - failpoint.Inject("orphanWriterGoRoutine", func() { + if _, _err_ := failpoint.Eval(_curpkg_("orphanWriterGoRoutine")); _err_ == nil { if chunkIndex > 0 { <-pCtx.Done() } - }) + } select { case <-pCtx.Done(): @@ -1038,12 +1038,12 @@ func (tr *TableImporter) postProcess( } hasDupe = hasLocalDupe } - failpoint.Inject("SlowDownCheckDupe", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("SlowDownCheckDupe")); _err_ == nil { sec := v.(int) tr.logger.Warn("start to sleep several seconds before checking other dupe", zap.Int("seconds", sec)) time.Sleep(time.Duration(sec) * time.Second) - }) + } otherHasDupe, needRemoteDupe, baseTotalChecksum, err := metaMgr.CheckAndUpdateLocalChecksum(ctx, &localChecksum, hasDupe) if err != nil { @@ -1087,11 +1087,11 @@ func (tr *TableImporter) postProcess( var remoteChecksum *local.RemoteChecksum remoteChecksum, err = DoChecksum(ctx, tr.tableInfo) - failpoint.Inject("checksum-error", func() { + if _, _err_ := failpoint.Eval(_curpkg_("checksum-error")); _err_ == nil { tr.logger.Info("failpoint checksum-error injected.") remoteChecksum = nil err = status.Error(codes.Unknown, "Checksum meets error.") - }) + } if err != nil { if rc.cfg.PostRestore.Checksum != config.OpLevelOptional { return false, errors.Trace(err) @@ -1367,7 +1367,7 @@ func (tr *TableImporter) importKV( m.ImportSecondsHistogram.Observe(dur.Seconds()) } - failpoint.Inject("SlowDownImport", func() {}) + failpoint.Eval(_curpkg_("SlowDownImport")) return nil } @@ -1544,17 +1544,17 @@ func (*TableImporter) executeDDL( resultCh <- s.Exec(ctx, "add index", ddl) }() - failpoint.Inject("AddIndexCrash", func() { + if _, _err_ := failpoint.Eval(_curpkg_("AddIndexCrash")); 
_err_ == nil { _ = common.KillMySelf() - }) + } var ddlErr error for { select { case ddlErr = <-resultCh: - failpoint.Inject("AddIndexFail", func() { + if _, _err_ := failpoint.Eval(_curpkg_("AddIndexFail")); _err_ == nil { ddlErr = errors.New("injected error") - }) + } if ddlErr == nil { return nil } diff --git a/lightning/pkg/importer/table_import.go__failpoint_stash__ b/lightning/pkg/importer/table_import.go__failpoint_stash__ new file mode 100644 index 0000000000000..ccc0fcc088b3b --- /dev/null +++ b/lightning/pkg/importer/table_import.go__failpoint_stash__ @@ -0,0 +1,1822 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package importer + +import ( + "cmp" + "context" + "database/sql" + "encoding/hex" + "fmt" + "path/filepath" + "slices" + "strings" + "sync" + "time" + + dmysql "github.com/go-sql-driver/mysql" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/br/pkg/version" + "github.com/pingcap/tidb/lightning/pkg/web" + "github.com/pingcap/tidb/pkg/errno" + tidbkv "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/lightning/backend" + "github.com/pingcap/tidb/pkg/lightning/backend/encode" + "github.com/pingcap/tidb/pkg/lightning/backend/kv" + "github.com/pingcap/tidb/pkg/lightning/backend/local" + "github.com/pingcap/tidb/pkg/lightning/checkpoints" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/lightning/metric" + "github.com/pingcap/tidb/pkg/lightning/mydump" + verify "github.com/pingcap/tidb/pkg/lightning/verification" + "github.com/pingcap/tidb/pkg/lightning/worker" + "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/extsort" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/multierr" + "go.uber.org/zap" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// TableImporter is a helper struct to import a table. +type TableImporter struct { + // The unique table name in the form "`db`.`tbl`". + tableName string + dbInfo *checkpoints.TidbDBInfo + tableInfo *checkpoints.TidbTableInfo + tableMeta *mydump.MDTableMeta + encTable table.Table + alloc autoid.Allocators + logger log.Logger + kvStore tidbkv.Storage + etcdCli *clientv3.Client + autoidCli *autoid.ClientDiscover + + // dupIgnoreRows tracks the rowIDs of rows that are duplicated and should be ignored. + dupIgnoreRows extsort.ExternalSorter + + ignoreColumns map[string]struct{} +} + +// NewTableImporter creates a new TableImporter. 
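+// The returned importer carries per-table ID allocators rebuilt from the checkpointed AllocBase
+// and an autoid client discovered through the given etcd client.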
+func NewTableImporter( + tableName string, + tableMeta *mydump.MDTableMeta, + dbInfo *checkpoints.TidbDBInfo, + tableInfo *checkpoints.TidbTableInfo, + cp *checkpoints.TableCheckpoint, + ignoreColumns map[string]struct{}, + kvStore tidbkv.Storage, + etcdCli *clientv3.Client, + logger log.Logger, +) (*TableImporter, error) { + idAlloc := kv.NewPanickingAllocators(tableInfo.Core.SepAutoInc(), cp.AllocBase) + tbl, err := tables.TableFromMeta(idAlloc, tableInfo.Core) + if err != nil { + return nil, errors.Annotatef(err, "failed to tables.TableFromMeta %s", tableName) + } + autoidCli := autoid.NewClientDiscover(etcdCli) + + return &TableImporter{ + tableName: tableName, + dbInfo: dbInfo, + tableInfo: tableInfo, + tableMeta: tableMeta, + encTable: tbl, + alloc: idAlloc, + kvStore: kvStore, + etcdCli: etcdCli, + autoidCli: autoidCli, + logger: logger.With(zap.String("table", tableName)), + ignoreColumns: ignoreColumns, + }, nil +} + +func (tr *TableImporter) importTable( + ctx context.Context, + rc *Controller, + cp *checkpoints.TableCheckpoint, +) (bool, error) { + // 1. Load the table info. + select { + case <-ctx.Done(): + return false, ctx.Err() + default: + } + + metaMgr := rc.metaMgrBuilder.TableMetaMgr(tr) + // no need to do anything if the chunks are already populated + if len(cp.Engines) > 0 { + tr.logger.Info("reusing engines and files info from checkpoint", + zap.Int("enginesCnt", len(cp.Engines)), + zap.Int("filesCnt", cp.CountChunks()), + ) + err := addExtendDataForCheckpoint(ctx, rc.cfg, cp) + if err != nil { + return false, errors.Trace(err) + } + } else if cp.Status < checkpoints.CheckpointStatusAllWritten { + if err := tr.populateChunks(ctx, rc, cp); err != nil { + return false, errors.Trace(err) + } + + // fetch the max chunk row_id max value as the global max row_id + rowIDMax := int64(0) + for _, engine := range cp.Engines { + if len(engine.Chunks) > 0 && engine.Chunks[len(engine.Chunks)-1].Chunk.RowIDMax > rowIDMax { + rowIDMax = engine.Chunks[len(engine.Chunks)-1].Chunk.RowIDMax + } + } + versionStr, err := version.FetchVersion(ctx, rc.db) + if err != nil { + return false, errors.Trace(err) + } + + versionInfo := version.ParseServerInfo(versionStr) + + // "show table next_row_id" is only available after tidb v4.0.0 + if versionInfo.ServerVersion.Major >= 4 && isLocalBackend(rc.cfg) { + // first, insert a new-line into meta table + if err = metaMgr.InitTableMeta(ctx); err != nil { + return false, err + } + + checksum, rowIDBase, err := metaMgr.AllocTableRowIDs(ctx, rowIDMax) + if err != nil { + return false, err + } + tr.RebaseChunkRowIDs(cp, rowIDBase) + + if checksum != nil { + if cp.Checksum != *checksum { + cp.Checksum = *checksum + rc.saveCpCh <- saveCp{ + tableName: tr.tableName, + merger: &checkpoints.TableChecksumMerger{ + Checksum: cp.Checksum, + }, + } + } + tr.logger.Info("checksum before restore table", zap.Object("checksum", &cp.Checksum)) + } + } + if err := rc.checkpointsDB.InsertEngineCheckpoints(ctx, tr.tableName, cp.Engines); err != nil { + return false, errors.Trace(err) + } + web.BroadcastTableCheckpoint(tr.tableName, cp) + + // rebase the allocator so it exceeds the number of rows. 
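+ // Either the AUTO_RANDOM or the row-ID/AUTO_INCREMENT allocator is rebased here, so IDs handed
+ // out after the import cannot collide with the imported rows.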
+ if tr.tableInfo.Core.ContainsAutoRandomBits() { + cp.AllocBase = max(cp.AllocBase, tr.tableInfo.Core.AutoRandID) + if err := tr.alloc.Get(autoid.AutoRandomType).Rebase(context.Background(), cp.AllocBase, false); err != nil { + return false, err + } + } else { + cp.AllocBase = max(cp.AllocBase, tr.tableInfo.Core.AutoIncID) + if err := tr.alloc.Get(autoid.RowIDAllocType).Rebase(context.Background(), cp.AllocBase, false); err != nil { + return false, err + } + } + rc.saveCpCh <- saveCp{ + tableName: tr.tableName, + merger: &checkpoints.RebaseCheckpointMerger{ + AllocBase: cp.AllocBase, + }, + } + } + + // 2. Do duplicate detection if needed + if isLocalBackend(rc.cfg) && rc.cfg.Conflict.PrecheckConflictBeforeImport && rc.cfg.Conflict.Strategy != config.NoneOnDup { + _, uuid := backend.MakeUUID(tr.tableName, common.IndexEngineID) + workingDir := filepath.Join(rc.cfg.TikvImporter.SortedKVDir, uuid.String()+local.DupDetectDirSuffix) + resultDir := filepath.Join(rc.cfg.TikvImporter.SortedKVDir, uuid.String()+local.DupResultDirSuffix) + + dupIgnoreRows, err := extsort.OpenDiskSorter(resultDir, &extsort.DiskSorterOptions{ + Concurrency: rc.cfg.App.RegionConcurrency, + }) + if err != nil { + return false, errors.Trace(err) + } + tr.dupIgnoreRows = dupIgnoreRows + + if cp.Status < checkpoints.CheckpointStatusDupDetected { + err := tr.preDeduplicate(ctx, rc, cp, workingDir) + saveCpErr := rc.saveStatusCheckpoint(ctx, tr.tableName, checkpoints.WholeTableEngineID, err, checkpoints.CheckpointStatusDupDetected) + if err := firstErr(err, saveCpErr); err != nil { + return false, errors.Trace(err) + } + } + + if !dupIgnoreRows.IsSorted() { + if err := dupIgnoreRows.Sort(ctx); err != nil { + return false, errors.Trace(err) + } + } + + failpoint.Inject("FailAfterDuplicateDetection", func() { + panic("forcing failure after duplicate detection") + }) + } + + // 3. Drop indexes if add-index-by-sql is enabled + if cp.Status < checkpoints.CheckpointStatusIndexDropped && isLocalBackend(rc.cfg) && rc.cfg.TikvImporter.AddIndexBySQL { + err := tr.dropIndexes(ctx, rc.db) + saveCpErr := rc.saveStatusCheckpoint(ctx, tr.tableName, checkpoints.WholeTableEngineID, err, checkpoints.CheckpointStatusIndexDropped) + if err := firstErr(err, saveCpErr); err != nil { + return false, errors.Trace(err) + } + } + + // 4. Restore engines (if still needed) + err := tr.importEngines(ctx, rc, cp) + if err != nil { + return false, errors.Trace(err) + } + + err = metaMgr.UpdateTableStatus(ctx, metaStatusRestoreFinished) + if err != nil { + return false, errors.Trace(err) + } + + // 5. Post-process. With the last parameter set to false, we can allow delay analyze execute latter + return tr.postProcess(ctx, rc, cp, false /* force-analyze */, metaMgr) +} + +// Close implements the Importer interface. 
+func (tr *TableImporter) Close() { + tr.encTable = nil + if tr.dupIgnoreRows != nil { + _ = tr.dupIgnoreRows.Close() + } + tr.logger.Info("restore done") +} + +func (tr *TableImporter) populateChunks(ctx context.Context, rc *Controller, cp *checkpoints.TableCheckpoint) error { + task := tr.logger.Begin(zap.InfoLevel, "load engines and files") + divideConfig := mydump.NewDataDivideConfig(rc.cfg, len(tr.tableInfo.Core.Columns), rc.ioWorkers, rc.store, tr.tableMeta) + tableRegions, err := mydump.MakeTableRegions(ctx, divideConfig) + if err == nil { + timestamp := time.Now().Unix() + failpoint.Inject("PopulateChunkTimestamp", func(v failpoint.Value) { + timestamp = int64(v.(int)) + }) + for _, region := range tableRegions { + engine, found := cp.Engines[region.EngineID] + if !found { + engine = &checkpoints.EngineCheckpoint{ + Status: checkpoints.CheckpointStatusLoaded, + } + cp.Engines[region.EngineID] = engine + } + ccp := &checkpoints.ChunkCheckpoint{ + Key: checkpoints.ChunkCheckpointKey{ + Path: region.FileMeta.Path, + Offset: region.Chunk.Offset, + }, + FileMeta: region.FileMeta, + ColumnPermutation: nil, + Chunk: region.Chunk, + Timestamp: timestamp, + } + if len(region.Chunk.Columns) > 0 { + perms, err := parseColumnPermutations( + tr.tableInfo.Core, + region.Chunk.Columns, + tr.ignoreColumns, + log.FromContext(ctx)) + if err != nil { + return errors.Trace(err) + } + ccp.ColumnPermutation = perms + } + engine.Chunks = append(engine.Chunks, ccp) + } + + // Add index engine checkpoint + cp.Engines[common.IndexEngineID] = &checkpoints.EngineCheckpoint{Status: checkpoints.CheckpointStatusLoaded} + } + task.End(zap.ErrorLevel, err, + zap.Int("enginesCnt", len(cp.Engines)), + zap.Int("filesCnt", len(tableRegions)), + ) + return err +} + +// AutoIDRequirement implements autoid.Requirement. +var _ autoid.Requirement = &TableImporter{} + +// Store implements the autoid.Requirement interface. +func (tr *TableImporter) Store() tidbkv.Storage { + return tr.kvStore +} + +// AutoIDClient implements the autoid.Requirement interface. +func (tr *TableImporter) AutoIDClient() *autoid.ClientDiscover { + return tr.autoidCli +} + +// RebaseChunkRowIDs rebase the row id of the chunks. +func (*TableImporter) RebaseChunkRowIDs(cp *checkpoints.TableCheckpoint, rowIDBase int64) { + if rowIDBase == 0 { + return + } + for _, engine := range cp.Engines { + for _, chunk := range engine.Chunks { + chunk.Chunk.PrevRowIDMax += rowIDBase + chunk.Chunk.RowIDMax += rowIDBase + } + } +} + +// initializeColumns computes the "column permutation" for an INSERT INTO +// statement. Suppose a table has columns (a, b, c, d) in canonical order, and +// we execute `INSERT INTO (d, b, a) VALUES ...`, we will need to remap the +// columns as: +// +// - column `a` is at position 2 +// - column `b` is at position 1 +// - column `c` is missing +// - column `d` is at position 0 +// +// The column permutation of (d, b, a) is set to be [2, 1, -1, 0]. +// +// The argument `columns` _must_ be in lower case. 
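A minimal standalone sketch (not part of the patch) of the permutation rule described in the comment above, reduced to plain string slices; it deliberately ignores ignore-columns, generated columns, and the trailing _tidb_rowid slot that createColumnPermutation and parseColumnPermutations handle in the real code.

package main

import "fmt"

// columnPermutation returns, for each table column, the position of that
// column inside the data file, or -1 if the file does not provide it.
func columnPermutation(tableCols, fileCols []string) []int {
	filePos := make(map[string]int, len(fileCols))
	for i, c := range fileCols {
		filePos[c] = i
	}
	perm := make([]int, 0, len(tableCols))
	for _, c := range tableCols {
		if i, ok := filePos[c]; ok {
			perm = append(perm, i)
		} else {
			perm = append(perm, -1) // missing column: later filled with its default value
		}
	}
	return perm
}

func main() {
	// Matches the worked example in the comment above:
	// table (a, b, c, d), file provides (d, b, a) -> [2 1 -1 0].
	fmt.Println(columnPermutation(
		[]string{"a", "b", "c", "d"},
		[]string{"d", "b", "a"},
	))
}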
+func (tr *TableImporter) initializeColumns(columns []string, ccp *checkpoints.ChunkCheckpoint) error { + colPerm, err := createColumnPermutation(columns, tr.ignoreColumns, tr.tableInfo.Core, tr.logger) + if err != nil { + return err + } + ccp.ColumnPermutation = colPerm + return nil +} + +func createColumnPermutation( + columns []string, + ignoreColumns map[string]struct{}, + tableInfo *model.TableInfo, + logger log.Logger, +) ([]int, error) { + var colPerm []int + if len(columns) == 0 { + colPerm = make([]int, 0, len(tableInfo.Columns)+1) + shouldIncludeRowID := common.TableHasAutoRowID(tableInfo) + + // no provided columns, so use identity permutation. + for i, col := range tableInfo.Columns { + idx := i + if _, ok := ignoreColumns[col.Name.L]; ok { + idx = -1 + } else if col.IsGenerated() { + idx = -1 + } + colPerm = append(colPerm, idx) + } + if shouldIncludeRowID { + colPerm = append(colPerm, -1) + } + } else { + var err error + colPerm, err = parseColumnPermutations(tableInfo, columns, ignoreColumns, logger) + if err != nil { + return nil, errors.Trace(err) + } + } + return colPerm, nil +} + +func (tr *TableImporter) importEngines(pCtx context.Context, rc *Controller, cp *checkpoints.TableCheckpoint) error { + indexEngineCp := cp.Engines[common.IndexEngineID] + if indexEngineCp == nil { + tr.logger.Error("fail to importEngines because indexengine is nil") + return common.ErrCheckpointNotFound.GenWithStack("table %v index engine checkpoint not found", tr.tableName) + } + + ctx, cancel := context.WithCancel(pCtx) + defer cancel() + + // The table checkpoint status set to `CheckpointStatusIndexImported` only if + // both all data engines and the index engine had been imported to TiKV. + // But persist index engine checkpoint status and table checkpoint status are + // not an atomic operation, so `cp.Status < CheckpointStatusIndexImported` + // but `indexEngineCp.Status == CheckpointStatusImported` could happen + // when kill lightning after saving index engine checkpoint status before saving + // table checkpoint status. + var closedIndexEngine *backend.ClosedEngine + var restoreErr error + // if index-engine checkpoint is lower than `CheckpointStatusClosed`, there must be + // data-engines that need to be restore or import. Otherwise, all data-engines should + // be finished already. + + handleDataEngineThisRun := false + idxEngineCfg := &backend.EngineConfig{ + TableInfo: tr.tableInfo, + } + if indexEngineCp.Status < checkpoints.CheckpointStatusClosed { + handleDataEngineThisRun = true + indexWorker := rc.indexWorkers.Apply() + defer rc.indexWorkers.Recycle(indexWorker) + + if rc.cfg.TikvImporter.Backend == config.BackendLocal { + // for index engine, the estimate factor is non-clustered index count + idxCnt := len(tr.tableInfo.Core.Indices) + if !common.TableHasAutoRowID(tr.tableInfo.Core) { + idxCnt-- + } + threshold := local.EstimateCompactionThreshold(tr.tableMeta.DataFiles, cp, int64(idxCnt)) + idxEngineCfg.Local = backend.LocalEngineConfig{ + Compact: threshold > 0, + CompactConcurrency: 4, + CompactThreshold: threshold, + BlockSize: int(rc.cfg.TikvImporter.BlockSize), + } + } + // import backend can't reopen engine if engine is closed, so + // only open index engine if any data engines don't finish writing. 
+ var indexEngine *backend.OpenedEngine + var err error + for engineID, engine := range cp.Engines { + if engineID == common.IndexEngineID { + continue + } + if engine.Status < checkpoints.CheckpointStatusAllWritten { + indexEngine, err = rc.engineMgr.OpenEngine(ctx, idxEngineCfg, tr.tableName, common.IndexEngineID) + if err != nil { + return errors.Trace(err) + } + break + } + } + + logTask := tr.logger.Begin(zap.InfoLevel, "import whole table") + var wg sync.WaitGroup + var engineErr common.OnceError + setError := func(err error) { + engineErr.Set(err) + // cancel this context to fail fast + cancel() + } + + type engineCheckpoint struct { + engineID int32 + checkpoint *checkpoints.EngineCheckpoint + } + allEngines := make([]engineCheckpoint, 0, len(cp.Engines)) + for engineID, engine := range cp.Engines { + allEngines = append(allEngines, engineCheckpoint{engineID: engineID, checkpoint: engine}) + } + slices.SortFunc(allEngines, func(i, j engineCheckpoint) int { return cmp.Compare(i.engineID, j.engineID) }) + + for _, ecp := range allEngines { + engineID := ecp.engineID + engine := ecp.checkpoint + select { + case <-ctx.Done(): + // Set engineErr and break this for loop to wait all the sub-routines done before return. + // Directly return may cause panic because caller will close the pebble db but some sub routines + // are still reading from or writing to the pebble db. + engineErr.Set(ctx.Err()) + default: + } + if engineErr.Get() != nil { + break + } + + // Should skip index engine + if engineID < 0 { + continue + } + + if engine.Status < checkpoints.CheckpointStatusImported { + wg.Add(1) + + // If the number of chunks is small, it means that this engine may be finished in a few times. + // We do not limit it in TableConcurrency + restoreWorker := rc.tableWorkers.Apply() + go func(w *worker.Worker, eid int32, ecp *checkpoints.EngineCheckpoint) { + defer wg.Done() + engineLogTask := tr.logger.With(zap.Int32("engineNumber", eid)).Begin(zap.InfoLevel, "restore engine") + dataClosedEngine, err := tr.preprocessEngine(ctx, rc, indexEngine, eid, ecp) + engineLogTask.End(zap.ErrorLevel, err) + rc.tableWorkers.Recycle(w) + if err == nil { + dataWorker := rc.closedEngineLimit.Apply() + defer rc.closedEngineLimit.Recycle(dataWorker) + err = tr.importEngine(ctx, dataClosedEngine, rc, ecp) + if rc.status != nil && rc.status.backend == config.BackendLocal { + for _, chunk := range ecp.Chunks { + rc.status.FinishedFileSize.Add(chunk.TotalSize()) + } + } + } + if err != nil { + setError(err) + } + }(restoreWorker, engineID, engine) + } else { + for _, chunk := range engine.Chunks { + rc.status.FinishedFileSize.Add(chunk.TotalSize()) + } + } + } + + wg.Wait() + + restoreErr = engineErr.Get() + logTask.End(zap.ErrorLevel, restoreErr) + if restoreErr != nil { + return errors.Trace(restoreErr) + } + + if indexEngine != nil { + closedIndexEngine, restoreErr = indexEngine.Close(ctx) + } else { + closedIndexEngine, restoreErr = rc.engineMgr.UnsafeCloseEngine(ctx, idxEngineCfg, tr.tableName, common.IndexEngineID) + } + + if err = rc.saveStatusCheckpoint(ctx, tr.tableName, common.IndexEngineID, restoreErr, checkpoints.CheckpointStatusClosed); err != nil { + return errors.Trace(firstErr(restoreErr, err)) + } + } else if indexEngineCp.Status == checkpoints.CheckpointStatusClosed { + // If index engine file has been closed but not imported only if context cancel occurred + // when `importKV()` execution, so `UnsafeCloseEngine` and continue import it. 
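The engineErr/setError pairing above records only the first failure and cancels the shared context so the remaining engine goroutines stop early. A simplified standalone sketch of that pattern follows; onceError is a hypothetical stand-in for common.OnceError, not the real type.

package main

import (
	"context"
	"errors"
	"fmt"
	"sync"
)

// onceError keeps only the first error handed to Set.
type onceError struct {
	mu  sync.Mutex
	err error
}

func (o *onceError) Set(err error) {
	o.mu.Lock()
	defer o.mu.Unlock()
	if o.err == nil {
		o.err = err
	}
}

func (o *onceError) Get() error {
	o.mu.Lock()
	defer o.mu.Unlock()
	return o.err
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	var firstErr onceError
	setError := func(err error) {
		firstErr.Set(err)
		cancel() // fail fast: the other workers see ctx.Done() and stop
	}

	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			select {
			case <-ctx.Done():
				return // another worker already failed, do no more work
			default:
			}
			if i == 2 {
				setError(errors.New("worker 2 failed"))
			}
		}(i)
	}
	wg.Wait()
	fmt.Println(firstErr.Get()) // worker 2 failed
}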
+ closedIndexEngine, restoreErr = rc.engineMgr.UnsafeCloseEngine(ctx, idxEngineCfg, tr.tableName, common.IndexEngineID) + } + if restoreErr != nil { + return errors.Trace(restoreErr) + } + + // if data engine is handled in previous run and we continue importing from checkpoint + if !handleDataEngineThisRun { + for _, engine := range cp.Engines { + for _, chunk := range engine.Chunks { + rc.status.FinishedFileSize.Add(chunk.Chunk.EndOffset - chunk.Key.Offset) + } + } + } + + if cp.Status < checkpoints.CheckpointStatusIndexImported { + var err error + if indexEngineCp.Status < checkpoints.CheckpointStatusImported { + failpoint.Inject("FailBeforeStartImportingIndexEngine", func() { + errMsg := "fail before importing index KV data" + tr.logger.Warn(errMsg) + failpoint.Return(errors.New(errMsg)) + }) + err = tr.importKV(ctx, closedIndexEngine, rc) + failpoint.Inject("FailBeforeIndexEngineImported", func() { + finished := rc.status.FinishedFileSize.Load() + total := rc.status.TotalFileSize.Load() + tr.logger.Warn("print lightning status", + zap.Int64("finished", finished), + zap.Int64("total", total), + zap.Bool("equal", finished == total)) + panic("forcing failure due to FailBeforeIndexEngineImported") + }) + } + + saveCpErr := rc.saveStatusCheckpoint(ctx, tr.tableName, checkpoints.WholeTableEngineID, err, checkpoints.CheckpointStatusIndexImported) + if err = firstErr(err, saveCpErr); err != nil { + return errors.Trace(err) + } + } + return nil +} + +// preprocessEngine do some preprocess work +// for local backend, it do local sort, for tidb backend it transforms data into sql and execute +// TODO: it's not a correct name for tidb backend, since there's no post-process for it +// TODO: after separate local/tidb backend more clearly, rename it. +func (tr *TableImporter) preprocessEngine( + pCtx context.Context, + rc *Controller, + indexEngine *backend.OpenedEngine, + engineID int32, + cp *checkpoints.EngineCheckpoint, +) (*backend.ClosedEngine, error) { + ctx, cancel := context.WithCancel(pCtx) + defer cancel() + // all data has finished written, we can close the engine directly. + if cp.Status >= checkpoints.CheckpointStatusAllWritten { + engineCfg := &backend.EngineConfig{ + TableInfo: tr.tableInfo, + } + closedEngine, err := rc.engineMgr.UnsafeCloseEngine(ctx, engineCfg, tr.tableName, engineID) + // If any error occurred, recycle worker immediately + if err != nil { + return closedEngine, errors.Trace(err) + } + if rc.status != nil && rc.status.backend == config.BackendTiDB { + for _, chunk := range cp.Chunks { + rc.status.FinishedFileSize.Add(chunk.Chunk.EndOffset - chunk.Key.Offset) + } + } + return closedEngine, nil + } + + // if the key are ordered, LocalWrite can optimize the writing. + // table has auto-incremented _tidb_rowid must satisfy following restrictions: + // - clustered index disable and primary key is not number + // - no auto random bits (auto random or shard row id) + // - no partition table + // - no explicit _tidb_rowid field (At this time we can't determine if the source file contains _tidb_rowid field, + // so we will do this check in LocalWriter when the first row is received.) 
+ hasAutoIncrementAutoID := common.TableHasAutoRowID(tr.tableInfo.Core) && + tr.tableInfo.Core.AutoRandomBits == 0 && tr.tableInfo.Core.ShardRowIDBits == 0 && + tr.tableInfo.Core.Partition == nil + dataWriterCfg := &backend.LocalWriterConfig{} + dataWriterCfg.Local.IsKVSorted = hasAutoIncrementAutoID + dataWriterCfg.TiDB.TableName = tr.tableName + + logTask := tr.logger.With(zap.Int32("engineNumber", engineID)).Begin(zap.InfoLevel, "encode kv data and write") + dataEngineCfg := &backend.EngineConfig{ + TableInfo: tr.tableInfo, + } + if !tr.tableMeta.IsRowOrdered { + dataEngineCfg.Local.Compact = true + dataEngineCfg.Local.CompactConcurrency = 4 + dataEngineCfg.Local.CompactThreshold = local.CompactionUpperThreshold + } + dataEngine, err := rc.engineMgr.OpenEngine(ctx, dataEngineCfg, tr.tableName, engineID) + if err != nil { + return nil, errors.Trace(err) + } + + var wg sync.WaitGroup + var chunkErr common.OnceError + + type chunkFlushStatus struct { + dataStatus backend.ChunkFlushStatus + indexStatus backend.ChunkFlushStatus + chunkCp *checkpoints.ChunkCheckpoint + } + + // chunks that are finished writing, but checkpoints are not finished due to flush not finished. + var checkFlushLock sync.Mutex + flushPendingChunks := make([]chunkFlushStatus, 0, 16) + + chunkCpChan := make(chan *checkpoints.ChunkCheckpoint, 16) + go func() { + for { + select { + case cp, ok := <-chunkCpChan: + if !ok { + return + } + saveCheckpoint(rc, tr, engineID, cp) + case <-ctx.Done(): + return + } + } + }() + + setError := func(err error) { + chunkErr.Set(err) + cancel() + } + + metrics, _ := metric.FromContext(ctx) + + // Restore table data +ChunkLoop: + for chunkIndex, chunk := range cp.Chunks { + if rc.status != nil && rc.status.backend == config.BackendTiDB { + rc.status.FinishedFileSize.Add(chunk.Chunk.Offset - chunk.Key.Offset) + } + if chunk.Chunk.Offset >= chunk.Chunk.EndOffset { + continue + } + + checkFlushLock.Lock() + finished := 0 + for _, c := range flushPendingChunks { + if !(c.indexStatus.Flushed() && c.dataStatus.Flushed()) { + break + } + chunkCpChan <- c.chunkCp + finished++ + } + if finished > 0 { + flushPendingChunks = flushPendingChunks[finished:] + } + checkFlushLock.Unlock() + + failpoint.Inject("orphanWriterGoRoutine", func() { + if chunkIndex > 0 { + <-pCtx.Done() + } + }) + + select { + case <-pCtx.Done(): + break ChunkLoop + default: + } + + if chunkErr.Get() != nil { + break + } + + // Flows : + // 1. read mydump file + // 2. sql -> kvs + // 3. load kvs data (into kv deliver server) + // 4. 
flush kvs data (into tikv node) + var remainChunkCnt float64 + if chunk.Chunk.Offset < chunk.Chunk.EndOffset { + remainChunkCnt = float64(chunk.UnfinishedSize()) / float64(chunk.TotalSize()) + if metrics != nil { + metrics.ChunkCounter.WithLabelValues(metric.ChunkStatePending).Add(remainChunkCnt) + } + } + + dataWriter, err := dataEngine.LocalWriter(ctx, dataWriterCfg) + if err != nil { + setError(err) + break + } + + writerCfg := &backend.LocalWriterConfig{} + writerCfg.TiDB.TableName = tr.tableName + indexWriter, err := indexEngine.LocalWriter(ctx, writerCfg) + if err != nil { + _, _ = dataWriter.Close(ctx) + setError(err) + break + } + cr, err := newChunkProcessor(ctx, chunkIndex, rc.cfg, chunk, rc.ioWorkers, rc.store, tr.tableInfo.Core) + if err != nil { + setError(err) + break + } + + if chunk.FileMeta.Type == mydump.SourceTypeParquet { + // TODO: use the compressed size of the chunk to conduct memory control + if _, err = getChunkCompressedSizeForParquet(ctx, chunk, rc.store); err != nil { + return nil, errors.Trace(err) + } + } + + restoreWorker := rc.regionWorkers.Apply() + wg.Add(1) + go func(w *worker.Worker, cr *chunkProcessor) { + // Restore a chunk. + defer func() { + cr.close() + wg.Done() + rc.regionWorkers.Recycle(w) + }() + if metrics != nil { + metrics.ChunkCounter.WithLabelValues(metric.ChunkStateRunning).Add(remainChunkCnt) + } + err := cr.process(ctx, tr, engineID, dataWriter, indexWriter, rc) + var dataFlushStatus, indexFlushStaus backend.ChunkFlushStatus + if err == nil { + dataFlushStatus, err = dataWriter.Close(ctx) + } + if err == nil { + indexFlushStaus, err = indexWriter.Close(ctx) + } + if err == nil { + if metrics != nil { + metrics.ChunkCounter.WithLabelValues(metric.ChunkStateFinished).Add(remainChunkCnt) + metrics.BytesCounter.WithLabelValues(metric.StateRestoreWritten).Add(float64(cr.chunk.Checksum.SumSize())) + } + if dataFlushStatus != nil && indexFlushStaus != nil { + if dataFlushStatus.Flushed() && indexFlushStaus.Flushed() { + saveCheckpoint(rc, tr, engineID, cr.chunk) + } else { + checkFlushLock.Lock() + flushPendingChunks = append(flushPendingChunks, chunkFlushStatus{ + dataStatus: dataFlushStatus, + indexStatus: indexFlushStaus, + chunkCp: cr.chunk, + }) + checkFlushLock.Unlock() + } + } + } else { + if metrics != nil { + metrics.ChunkCounter.WithLabelValues(metric.ChunkStateFailed).Add(remainChunkCnt) + } + setError(err) + } + }(restoreWorker, cr) + } + + wg.Wait() + select { + case <-pCtx.Done(): + return nil, pCtx.Err() + default: + } + + // Report some statistics into the log for debugging. + totalKVSize := uint64(0) + totalSQLSize := int64(0) + logKeyName := "read(bytes)" + for _, chunk := range cp.Chunks { + totalKVSize += chunk.Checksum.SumSize() + totalSQLSize += chunk.UnfinishedSize() + if chunk.FileMeta.Type == mydump.SourceTypeParquet { + logKeyName = "read(rows)" + } + } + + err = chunkErr.Get() + logTask.End(zap.ErrorLevel, err, + zap.Int64(logKeyName, totalSQLSize), + zap.Uint64("written", totalKVSize), + ) + + trySavePendingChunks := func(context.Context) error { + checkFlushLock.Lock() + cnt := 0 + for _, chunk := range flushPendingChunks { + if !(chunk.dataStatus.Flushed() && chunk.indexStatus.Flushed()) { + break + } + saveCheckpoint(rc, tr, engineID, chunk.chunkCp) + cnt++ + } + flushPendingChunks = flushPendingChunks[cnt:] + checkFlushLock.Unlock() + return nil + } + + // in local mode, this check-point make no sense, because we don't do flush now, + // so there may be data lose if exit at here. 
So we don't write this checkpoint + // here like other mode. + if !isLocalBackend(rc.cfg) { + if saveCpErr := rc.saveStatusCheckpoint(ctx, tr.tableName, engineID, err, checkpoints.CheckpointStatusAllWritten); saveCpErr != nil { + return nil, errors.Trace(firstErr(err, saveCpErr)) + } + } + if err != nil { + // if process is canceled, we should flush all chunk checkpoints for local backend + if isLocalBackend(rc.cfg) && common.IsContextCanceledError(err) { + // ctx is canceled, so to avoid Close engine failed, we use `context.Background()` here + if _, err2 := dataEngine.Close(context.Background()); err2 != nil { + log.FromContext(ctx).Warn("flush all chunk checkpoints failed before manually exits", zap.Error(err2)) + return nil, errors.Trace(err) + } + if err2 := trySavePendingChunks(context.Background()); err2 != nil { + log.FromContext(ctx).Warn("flush all chunk checkpoints failed before manually exits", zap.Error(err2)) + } + } + return nil, errors.Trace(err) + } + + closedDataEngine, err := dataEngine.Close(ctx) + // For local backend, if checkpoint is enabled, we must flush index engine to avoid data loss. + // this flush action impact up to 10% of the performance, so we only do it if necessary. + if err == nil && rc.cfg.Checkpoint.Enable && isLocalBackend(rc.cfg) { + if err = indexEngine.Flush(ctx); err != nil { + return nil, errors.Trace(err) + } + if err = trySavePendingChunks(ctx); err != nil { + return nil, errors.Trace(err) + } + } + saveCpErr := rc.saveStatusCheckpoint(ctx, tr.tableName, engineID, err, checkpoints.CheckpointStatusClosed) + if err = firstErr(err, saveCpErr); err != nil { + // If any error occurred, recycle worker immediately + return nil, errors.Trace(err) + } + return closedDataEngine, nil +} + +func (tr *TableImporter) importEngine( + ctx context.Context, + closedEngine *backend.ClosedEngine, + rc *Controller, + cp *checkpoints.EngineCheckpoint, +) error { + if cp.Status >= checkpoints.CheckpointStatusImported { + return nil + } + + // 1. calling import + if err := tr.importKV(ctx, closedEngine, rc); err != nil { + return errors.Trace(err) + } + + // 2. perform a level-1 compact if idling. + if rc.cfg.PostRestore.Level1Compact && rc.compactState.CompareAndSwap(compactStateIdle, compactStateDoing) { + go func() { + // we ignore level-1 compact failure since it is not fatal. + // no need log the error, it is done in (*Importer).Compact already. + _ = rc.doCompact(ctx, Level1Compact) + rc.compactState.Store(compactStateIdle) + }() + } + + return nil +} + +// postProcess execute rebase-auto-id/checksum/analyze according to the task config. +// +// if the parameter forcePostProcess to true, postProcess force run checksum and analyze even if the +// post-process-at-last config is true. And if this two phases are skipped, the first return value will be true. +func (tr *TableImporter) postProcess( + ctx context.Context, + rc *Controller, + cp *checkpoints.TableCheckpoint, + forcePostProcess bool, + metaMgr tableMetaMgr, +) (bool, error) { + if !rc.backend.ShouldPostProcess() { + return false, nil + } + + // alter table set auto_increment + if cp.Status < checkpoints.CheckpointStatusAlteredAutoInc { + tblInfo := tr.tableInfo.Core + var err error + // TODO why we have to rebase id for tidb backend??? remove it later. 
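importEngine above starts a level-1 compaction only when it can move compactState from idle to doing, which guarantees at most one background compaction at a time. A standalone sketch of that compare-and-swap guard using sync/atomic; the names and printed values here are illustrative, not taken from the patch.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

const (
	compactStateIdle int32 = iota
	compactStateDoing
)

func main() {
	var state atomic.Int32 // zero value is compactStateIdle
	var wg sync.WaitGroup

	tryCompact := func(id int) {
		// Only the caller that wins the CAS starts a compaction; everyone
		// else returns immediately instead of queueing a second one.
		if !state.CompareAndSwap(compactStateIdle, compactStateDoing) {
			fmt.Printf("trigger %d skipped: compaction already running\n", id)
			return
		}
		wg.Add(1)
		go func() {
			defer wg.Done()
			defer state.Store(compactStateIdle) // release the slot when done
			fmt.Printf("trigger %d started the level-1 compaction\n", id)
		}()
	}

	tryCompact(1)
	tryCompact(2) // skipped unless trigger 1 has already finished
	wg.Wait()
}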
+ if tblInfo.ContainsAutoRandomBits() { + ft := &common.GetAutoRandomColumn(tblInfo).FieldType + shardFmt := autoid.NewShardIDFormat(ft, tblInfo.AutoRandomBits, tblInfo.AutoRandomRangeBits) + maxCap := shardFmt.IncrementalBitsCapacity() + err = AlterAutoRandom(ctx, rc.db, tr.tableName, uint64(tr.alloc.Get(autoid.AutoRandomType).Base())+1, maxCap) + } else if common.TableHasAutoRowID(tblInfo) || tblInfo.GetAutoIncrementColInfo() != nil { + if isLocalBackend(rc.cfg) { + // for TiDB version >= 6.5.0, a table might have separate allocators for auto_increment column and _tidb_rowid, + // especially when a table has auto_increment non-clustered PK, it will use both allocators. + // And in this case, ALTER TABLE xxx AUTO_INCREMENT = xxx only works on the allocator of auto_increment column, + // not for allocator of _tidb_rowid. + // So we need to rebase IDs for those 2 allocators explicitly. + err = common.RebaseTableAllocators(ctx, map[autoid.AllocatorType]int64{ + autoid.RowIDAllocType: tr.alloc.Get(autoid.RowIDAllocType).Base(), + autoid.AutoIncrementType: tr.alloc.Get(autoid.AutoIncrementType).Base(), + }, tr, tr.dbInfo.ID, tr.tableInfo.Core) + } else { + // only alter auto increment id iff table contains auto-increment column or generated handle. + // ALTER TABLE xxx AUTO_INCREMENT = yyy has a bad naming. + // if a table has implicit _tidb_rowid column & tbl.SepAutoID=false, then it works on _tidb_rowid + // allocator, even if the table has NO auto-increment column. + newBase := uint64(tr.alloc.Get(autoid.RowIDAllocType).Base()) + 1 + err = AlterAutoIncrement(ctx, rc.db, tr.tableName, newBase) + } + } + saveCpErr := rc.saveStatusCheckpoint(ctx, tr.tableName, checkpoints.WholeTableEngineID, err, checkpoints.CheckpointStatusAlteredAutoInc) + if err = firstErr(err, saveCpErr); err != nil { + return false, errors.Trace(err) + } + cp.Status = checkpoints.CheckpointStatusAlteredAutoInc + } + + // tidb backend don't need checksum & analyze + if rc.cfg.PostRestore.Checksum == config.OpLevelOff && rc.cfg.PostRestore.Analyze == config.OpLevelOff { + tr.logger.Debug("skip checksum & analyze, either because not supported by this backend or manually disabled") + err := rc.saveStatusCheckpoint(ctx, tr.tableName, checkpoints.WholeTableEngineID, nil, checkpoints.CheckpointStatusAnalyzeSkipped) + return false, errors.Trace(err) + } + + if !forcePostProcess && rc.cfg.PostRestore.PostProcessAtLast { + return true, nil + } + + w := rc.checksumWorks.Apply() + defer rc.checksumWorks.Recycle(w) + + shouldSkipAnalyze := false + estimatedModifyCnt := 100_000_000 + if cp.Status < checkpoints.CheckpointStatusChecksumSkipped { + // 4. do table checksum + var localChecksum verify.KVChecksum + for _, engine := range cp.Engines { + for _, chunk := range engine.Chunks { + localChecksum.Add(&chunk.Checksum) + } + } + indexNum := len(tr.tableInfo.Core.Indices) + if common.TableHasAutoRowID(tr.tableInfo.Core) { + indexNum++ + } + estimatedModifyCnt = int(localChecksum.SumKVS()) / (1 + indexNum) + tr.logger.Info("local checksum", zap.Object("checksum", &localChecksum)) + + // 4.5. do duplicate detection. + // if we came here, it must be a local backend. + // todo: remove this cast after we refactor the backend interface. Physical mode is so different, we shouldn't + // try to abstract it with logical mode. 
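The estimatedModifyCnt computation above relies on every imported row producing one data KV plus one KV per index, so dividing SumKVS() by (1 + indexNum) approximates the number of rows written. A tiny worked example with hypothetical numbers:

package main

import "fmt"

func main() {
	const (
		totalKVs = 3_000_000 // hypothetical SumKVS() of the local checksum
		indexNum = 2         // e.g. two indexes counted the same way as the code above
	)
	// One data KV per row plus one KV per index => rows ≈ totalKVs / (1 + indexNum).
	estimatedRows := totalKVs / (1 + indexNum)
	fmt.Println(estimatedRows) // 1000000, used as the modify_count estimate
}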
+ localBackend := rc.backend.(*local.Backend) + dupeController := localBackend.GetDupeController(rc.cfg.TikvImporter.RangeConcurrency*2, rc.errorMgr) + hasDupe := false + if rc.cfg.Conflict.Strategy != config.NoneOnDup { + opts := &encode.SessionOptions{ + SQLMode: mysql.ModeStrictAllTables, + SysVars: rc.sysVars, + } + var err error + hasLocalDupe, err := dupeController.CollectLocalDuplicateRows(ctx, tr.encTable, tr.tableName, opts, rc.cfg.Conflict.Strategy) + if err != nil { + tr.logger.Error("collect local duplicate keys failed", log.ShortError(err)) + return false, errors.Trace(err) + } + hasDupe = hasLocalDupe + } + failpoint.Inject("SlowDownCheckDupe", func(v failpoint.Value) { + sec := v.(int) + tr.logger.Warn("start to sleep several seconds before checking other dupe", + zap.Int("seconds", sec)) + time.Sleep(time.Duration(sec) * time.Second) + }) + + otherHasDupe, needRemoteDupe, baseTotalChecksum, err := metaMgr.CheckAndUpdateLocalChecksum(ctx, &localChecksum, hasDupe) + if err != nil { + return false, errors.Trace(err) + } + needChecksum := !otherHasDupe && needRemoteDupe + hasDupe = hasDupe || otherHasDupe + + if needRemoteDupe && rc.cfg.Conflict.Strategy != config.NoneOnDup { + opts := &encode.SessionOptions{ + SQLMode: mysql.ModeStrictAllTables, + SysVars: rc.sysVars, + } + hasRemoteDupe, e := dupeController.CollectRemoteDuplicateRows(ctx, tr.encTable, tr.tableName, opts, rc.cfg.Conflict.Strategy) + if e != nil { + tr.logger.Error("collect remote duplicate keys failed", log.ShortError(e)) + return false, errors.Trace(e) + } + hasDupe = hasDupe || hasRemoteDupe + + if hasDupe { + if err = dupeController.ResolveDuplicateRows(ctx, tr.encTable, tr.tableName, rc.cfg.Conflict.Strategy); err != nil { + tr.logger.Error("resolve remote duplicate keys failed", log.ShortError(err)) + return false, errors.Trace(err) + } + } + } + + if rc.dupIndicator != nil { + tr.logger.Debug("set dupIndicator", zap.Bool("has-duplicate", hasDupe)) + rc.dupIndicator.CompareAndSwap(false, hasDupe) + } + + nextStage := checkpoints.CheckpointStatusChecksummed + if rc.cfg.PostRestore.Checksum != config.OpLevelOff && !hasDupe && needChecksum { + if cp.Checksum.SumKVS() > 0 || baseTotalChecksum.SumKVS() > 0 { + localChecksum.Add(&cp.Checksum) + localChecksum.Add(baseTotalChecksum) + tr.logger.Info("merged local checksum", zap.Object("checksum", &localChecksum)) + } + + var remoteChecksum *local.RemoteChecksum + remoteChecksum, err = DoChecksum(ctx, tr.tableInfo) + failpoint.Inject("checksum-error", func() { + tr.logger.Info("failpoint checksum-error injected.") + remoteChecksum = nil + err = status.Error(codes.Unknown, "Checksum meets error.") + }) + if err != nil { + if rc.cfg.PostRestore.Checksum != config.OpLevelOptional { + return false, errors.Trace(err) + } + tr.logger.Warn("do checksum failed, will skip this error and go on", log.ShortError(err)) + err = nil + } + if remoteChecksum != nil { + err = tr.compareChecksum(remoteChecksum, localChecksum) + // with post restore level 'optional', we will skip checksum error + if rc.cfg.PostRestore.Checksum == config.OpLevelOptional { + if err != nil { + tr.logger.Warn("compare checksum failed, will skip this error and go on", log.ShortError(err)) + err = nil + } + } + } + } else { + switch { + case rc.cfg.PostRestore.Checksum == config.OpLevelOff: + tr.logger.Info("skip checksum because the checksum option is off") + case hasDupe: + tr.logger.Info("skip checksum&analyze because duplicates were detected") + shouldSkipAnalyze = true + case !needChecksum: + 
tr.logger.Info("skip checksum&analyze because other lightning instance will do this") + shouldSkipAnalyze = true + } + err = nil + nextStage = checkpoints.CheckpointStatusChecksumSkipped + } + + // Don't call FinishTable when other lightning will calculate checksum. + if err == nil && needChecksum { + err = metaMgr.FinishTable(ctx) + } + + saveCpErr := rc.saveStatusCheckpoint(ctx, tr.tableName, checkpoints.WholeTableEngineID, err, nextStage) + if err = firstErr(err, saveCpErr); err != nil { + return false, errors.Trace(err) + } + cp.Status = nextStage + } + + if cp.Status < checkpoints.CheckpointStatusIndexAdded { + var err error + if rc.cfg.TikvImporter.AddIndexBySQL { + w := rc.addIndexLimit.Apply() + err = tr.addIndexes(ctx, rc.db) + rc.addIndexLimit.Recycle(w) + // Analyze will be automatically triggered after indexes are added by SQL. We can skip manual analyze. + shouldSkipAnalyze = true + } + saveCpErr := rc.saveStatusCheckpoint(ctx, tr.tableName, checkpoints.WholeTableEngineID, err, checkpoints.CheckpointStatusIndexAdded) + if err = firstErr(err, saveCpErr); err != nil { + return false, errors.Trace(err) + } + cp.Status = checkpoints.CheckpointStatusIndexAdded + } + + // do table analyze + if cp.Status < checkpoints.CheckpointStatusAnalyzeSkipped { + switch { + case shouldSkipAnalyze || rc.cfg.PostRestore.Analyze == config.OpLevelOff: + if !shouldSkipAnalyze { + updateStatsMeta(ctx, rc.db, tr.tableInfo.ID, estimatedModifyCnt) + } + tr.logger.Info("skip analyze") + if err := rc.saveStatusCheckpoint(ctx, tr.tableName, checkpoints.WholeTableEngineID, nil, checkpoints.CheckpointStatusAnalyzeSkipped); err != nil { + return false, errors.Trace(err) + } + cp.Status = checkpoints.CheckpointStatusAnalyzeSkipped + case forcePostProcess || !rc.cfg.PostRestore.PostProcessAtLast: + err := tr.analyzeTable(ctx, rc.db) + // witch post restore level 'optional', we will skip analyze error + if rc.cfg.PostRestore.Analyze == config.OpLevelOptional { + if err != nil { + tr.logger.Warn("analyze table failed, will skip this error and go on", log.ShortError(err)) + err = nil + } + } + saveCpErr := rc.saveStatusCheckpoint(ctx, tr.tableName, checkpoints.WholeTableEngineID, err, checkpoints.CheckpointStatusAnalyzed) + if err = firstErr(err, saveCpErr); err != nil { + return false, errors.Trace(err) + } + cp.Status = checkpoints.CheckpointStatusAnalyzed + } + } + + return true, nil +} + +func getChunkCompressedSizeForParquet( + ctx context.Context, + chunk *checkpoints.ChunkCheckpoint, + store storage.ExternalStorage, +) (int64, error) { + reader, err := mydump.OpenReader(ctx, &chunk.FileMeta, store, storage.DecompressConfig{ + ZStdDecodeConcurrency: 1, + }) + if err != nil { + return 0, errors.Trace(err) + } + parser, err := mydump.NewParquetParser(ctx, store, reader, chunk.FileMeta.Path) + if err != nil { + _ = reader.Close() + return 0, errors.Trace(err) + } + //nolint: errcheck + defer parser.Close() + err = parser.Reader.ReadFooter() + if err != nil { + return 0, errors.Trace(err) + } + rowGroups := parser.Reader.Footer.GetRowGroups() + var maxRowGroupSize int64 + for _, rowGroup := range rowGroups { + var rowGroupSize int64 + columnChunks := rowGroup.GetColumns() + for _, columnChunk := range columnChunks { + columnChunkSize := columnChunk.MetaData.GetTotalCompressedSize() + rowGroupSize += columnChunkSize + } + maxRowGroupSize = max(maxRowGroupSize, rowGroupSize) + } + return maxRowGroupSize, nil +} + +func updateStatsMeta(ctx context.Context, db *sql.DB, tableID int64, count int) { + s := 
common.SQLWithRetry{ + DB: db, + Logger: log.FromContext(ctx).With(zap.Int64("tableID", tableID)), + } + err := s.Transact(ctx, "update stats_meta", func(ctx context.Context, tx *sql.Tx) error { + rs, err := tx.ExecContext(ctx, ` +update mysql.stats_meta + set modify_count = ?, + count = ?, + version = @@tidb_current_ts + where table_id = ?; +`, count, count, tableID) + if err != nil { + return errors.Trace(err) + } + affected, err := rs.RowsAffected() + if err != nil { + return errors.Trace(err) + } + if affected == 0 { + return errors.Errorf("record with table_id %d not found", tableID) + } + return nil + }) + if err != nil { + s.Logger.Warn("failed to update stats_meta", zap.Error(err)) + } +} + +func parseColumnPermutations( + tableInfo *model.TableInfo, + columns []string, + ignoreColumns map[string]struct{}, + logger log.Logger, +) ([]int, error) { + colPerm := make([]int, 0, len(tableInfo.Columns)+1) + + columnMap := make(map[string]int) + for i, column := range columns { + columnMap[column] = i + } + + tableColumnMap := make(map[string]int) + for i, col := range tableInfo.Columns { + tableColumnMap[col.Name.L] = i + } + + // check if there are some unknown columns + var unknownCols []string + for _, c := range columns { + if _, ok := tableColumnMap[c]; !ok && c != model.ExtraHandleName.L { + if _, ignore := ignoreColumns[c]; !ignore { + unknownCols = append(unknownCols, c) + } + } + } + + if len(unknownCols) > 0 { + return colPerm, common.ErrUnknownColumns.GenWithStackByArgs(strings.Join(unknownCols, ","), tableInfo.Name) + } + + for _, colInfo := range tableInfo.Columns { + if i, ok := columnMap[colInfo.Name.L]; ok { + if _, ignore := ignoreColumns[colInfo.Name.L]; !ignore { + colPerm = append(colPerm, i) + } else { + logger.Debug("column ignored by user requirements", + zap.Stringer("table", tableInfo.Name), + zap.String("colName", colInfo.Name.O), + zap.Stringer("colType", &colInfo.FieldType), + ) + colPerm = append(colPerm, -1) + } + } else { + if len(colInfo.GeneratedExprString) == 0 { + logger.Warn("column missing from data file, going to fill with default value", + zap.Stringer("table", tableInfo.Name), + zap.String("colName", colInfo.Name.O), + zap.Stringer("colType", &colInfo.FieldType), + ) + } + colPerm = append(colPerm, -1) + } + } + // append _tidb_rowid column + rowIDIdx := -1 + if i, ok := columnMap[model.ExtraHandleName.L]; ok { + if _, ignored := ignoreColumns[model.ExtraHandleName.L]; !ignored { + rowIDIdx = i + } + } + // FIXME: the schema info for tidb backend is not complete, so always add the _tidb_rowid field. + // Other logic should ignore this extra field if not needed. 
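A standalone, library-style sketch (not part of the patch) of the stats_meta update shown above, expressed with plain database/sql; the real code wraps the same statement in common.SQLWithRetry.Transact, and the transaction handling here is an assumption for illustration.

// Package statsmeta exists only to make this sketch self-contained.
package statsmeta

import (
	"context"
	"database/sql"
	"fmt"
)

// updateStatsMeta mirrors the UPDATE issued above: set count/modify_count and
// bump the stats version, failing if the table has no stats_meta row yet.
func updateStatsMeta(ctx context.Context, db *sql.DB, tableID int64, count int) error {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	defer func() { _ = tx.Rollback() }() // no-op once Commit has succeeded

	res, err := tx.ExecContext(ctx,
		`UPDATE mysql.stats_meta
		    SET modify_count = ?, count = ?, version = @@tidb_current_ts
		  WHERE table_id = ?`, count, count, tableID)
	if err != nil {
		return err
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if affected == 0 {
		return fmt.Errorf("record with table_id %d not found", tableID)
	}
	return tx.Commit()
}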
+ colPerm = append(colPerm, rowIDIdx) + + return colPerm, nil +} + +func (tr *TableImporter) importKV( + ctx context.Context, + closedEngine *backend.ClosedEngine, + rc *Controller, +) error { + task := closedEngine.Logger().Begin(zap.InfoLevel, "import and cleanup engine") + regionSplitSize := int64(rc.cfg.TikvImporter.RegionSplitSize) + regionSplitKeys := int64(rc.cfg.TikvImporter.RegionSplitKeys) + + if regionSplitSize == 0 && rc.taskMgr != nil { + regionSplitSize = int64(config.SplitRegionSize) + if err := rc.taskMgr.CheckTasksExclusively(ctx, func(tasks []taskMeta) ([]taskMeta, error) { + if len(tasks) > 0 { + regionSplitSize = int64(config.SplitRegionSize) * int64(min(len(tasks), config.MaxSplitRegionSizeRatio)) + } + return nil, nil + }); err != nil { + return errors.Trace(err) + } + } + if regionSplitKeys == 0 { + if regionSplitSize > int64(config.SplitRegionSize) { + regionSplitKeys = int64(float64(regionSplitSize) / float64(config.SplitRegionSize) * float64(config.SplitRegionKeys)) + } else { + regionSplitKeys = int64(config.SplitRegionKeys) + } + } + err := closedEngine.Import(ctx, regionSplitSize, regionSplitKeys) + if common.ErrFoundDuplicateKeys.Equal(err) { + err = local.ConvertToErrFoundConflictRecords(err, tr.encTable) + } + saveCpErr := rc.saveStatusCheckpoint(ctx, tr.tableName, closedEngine.GetID(), err, checkpoints.CheckpointStatusImported) + // Don't clean up when save checkpoint failed, because we will verifyLocalFile and import engine again after restart. + if err == nil && saveCpErr == nil { + err = multierr.Append(err, closedEngine.Cleanup(ctx)) + } + err = firstErr(err, saveCpErr) + + dur := task.End(zap.ErrorLevel, err) + + if err != nil { + return errors.Trace(err) + } + + if m, ok := metric.FromContext(ctx); ok { + m.ImportSecondsHistogram.Observe(dur.Seconds()) + } + + failpoint.Inject("SlowDownImport", func() {}) + + return nil +} + +// do checksum for each table. 
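A standalone sketch of the three-way comparison that compareChecksum below performs: an import passes only if the CRC64-XOR sum, the KV count, and the byte count reported by TiDB all match the checksum accumulated locally while encoding. The concrete numbers here are hypothetical.

package main

import "fmt"

type checksum struct {
	sum        uint64 // CRC64-XOR over all KV pairs
	totalKVs   uint64
	totalBytes uint64
}

func match(remote, local checksum) bool {
	return remote.sum == local.sum &&
		remote.totalKVs == local.totalKVs &&
		remote.totalBytes == local.totalBytes
}

func main() {
	local := checksum{sum: 0xdeadbeef, totalKVs: 3_000_000, totalBytes: 512 << 20}
	remote := local // a successful import reproduces all three numbers exactly
	fmt.Println(match(remote, local)) // true; any single mismatch fails the table
}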
+func (tr *TableImporter) compareChecksum(remoteChecksum *local.RemoteChecksum, localChecksum verify.KVChecksum) error { + if remoteChecksum.Checksum != localChecksum.Sum() || + remoteChecksum.TotalKVs != localChecksum.SumKVS() || + remoteChecksum.TotalBytes != localChecksum.SumSize() { + return common.ErrChecksumMismatch.GenWithStackByArgs( + remoteChecksum.Checksum, localChecksum.Sum(), + remoteChecksum.TotalKVs, localChecksum.SumKVS(), + remoteChecksum.TotalBytes, localChecksum.SumSize(), + ) + } + + tr.logger.Info("checksum pass", zap.Object("local", &localChecksum)) + return nil +} + +func (tr *TableImporter) analyzeTable(ctx context.Context, db *sql.DB) error { + task := tr.logger.Begin(zap.InfoLevel, "analyze") + exec := common.SQLWithRetry{ + DB: db, + Logger: tr.logger, + } + err := exec.Exec(ctx, "analyze table", "ANALYZE TABLE "+tr.tableName) + task.End(zap.ErrorLevel, err) + return err +} + +func (tr *TableImporter) dropIndexes(ctx context.Context, db *sql.DB) error { + logger := log.FromContext(ctx).With(zap.String("table", tr.tableName)) + + tblInfo := tr.tableInfo + remainIndexes, dropIndexes := common.GetDropIndexInfos(tblInfo.Core) + for _, idxInfo := range dropIndexes { + sqlStr := common.BuildDropIndexSQL(tblInfo.DB, tblInfo.Name, idxInfo) + + logger.Info("drop index", zap.String("sql", sqlStr)) + + s := common.SQLWithRetry{ + DB: db, + Logger: logger, + } + if err := s.Exec(ctx, "drop index", sqlStr); err != nil { + if merr, ok := errors.Cause(err).(*dmysql.MySQLError); ok { + switch merr.Number { + case errno.ErrCantDropFieldOrKey, errno.ErrDropIndexNeededInForeignKey: + remainIndexes = append(remainIndexes, idxInfo) + logger.Info("can't drop index, skip", zap.String("index", idxInfo.Name.O), zap.Error(err)) + continue + } + } + return common.ErrDropIndexFailed.Wrap(err).GenWithStackByArgs(common.EscapeIdentifier(idxInfo.Name.O), tr.tableName) + } + } + if len(remainIndexes) < len(tblInfo.Core.Indices) { + // Must clone (*model.TableInfo) before modifying it, since it may be referenced in other place. + tblInfo.Core = tblInfo.Core.Clone() + tblInfo.Core.Indices = remainIndexes + + // Rebuild encTable. + encTable, err := tables.TableFromMeta(tr.alloc, tblInfo.Core) + if err != nil { + return errors.Trace(err) + } + tr.encTable = encTable + } + return nil +} + +func (tr *TableImporter) addIndexes(ctx context.Context, db *sql.DB) (retErr error) { + const progressStep = "add-index" + task := tr.logger.Begin(zap.InfoLevel, "add indexes") + defer func() { + task.End(zap.ErrorLevel, retErr) + }() + + tblInfo := tr.tableInfo + tableName := tr.tableName + + singleSQL, multiSQLs := common.BuildAddIndexSQL(tableName, tblInfo.Core, tblInfo.Desired) + if len(multiSQLs) == 0 { + return nil + } + + logger := log.FromContext(ctx).With(zap.String("table", tableName)) + + defer func() { + if retErr == nil { + web.BroadcastTableProgress(tr.tableName, progressStep, 1) + } else if !log.IsContextCanceledError(retErr) { + // Try to strip the prefix of the error message. + // e.g "add index failed: Error 1062 ..." -> "Error 1062 ..." 
+ cause := errors.Cause(retErr) + if cause == nil { + cause = retErr + } + retErr = common.ErrAddIndexFailed.GenWithStack( + "add index failed on table %s: %v, you can add index manually by the following SQL: %s", + tableName, cause, singleSQL) + } + }() + + var totalRows int + if m, ok := metric.FromContext(ctx); ok { + totalRows = int(metric.ReadCounter(m.RowsCounter.WithLabelValues(metric.StateRestored, tableName))) + } + + // Try to add all indexes in one statement. + err := tr.executeDDL(ctx, db, singleSQL, func(status *ddlStatus) { + if totalRows > 0 { + progress := float64(status.rowCount) / float64(totalRows*len(multiSQLs)) + if progress > 1 { + progress = 1 + } + web.BroadcastTableProgress(tableName, progressStep, progress) + logger.Info("add index progress", zap.String("progress", fmt.Sprintf("%.1f%%", progress*100))) + } + }) + if err == nil { + return nil + } + if !common.IsDupKeyError(err) { + return err + } + if len(multiSQLs) == 1 { + return nil + } + logger.Warn("cannot add all indexes in one statement, try to add them one by one", zap.Strings("sqls", multiSQLs), zap.Error(err)) + + baseProgress := float64(0) + for _, ddl := range multiSQLs { + err := tr.executeDDL(ctx, db, ddl, func(status *ddlStatus) { + if totalRows > 0 { + p := float64(status.rowCount) / float64(totalRows) + progress := baseProgress + p/float64(len(multiSQLs)) + web.BroadcastTableProgress(tableName, progressStep, progress) + logger.Info("add index progress", zap.String("progress", fmt.Sprintf("%.1f%%", progress*100))) + } + }) + if err != nil && !common.IsDupKeyError(err) { + return err + } + baseProgress += 1.0 / float64(len(multiSQLs)) + web.BroadcastTableProgress(tableName, progressStep, baseProgress) + } + return nil +} + +func (*TableImporter) executeDDL( + ctx context.Context, + db *sql.DB, + ddl string, + updateProgress func(status *ddlStatus), +) error { + logger := log.FromContext(ctx).With(zap.String("ddl", ddl)) + logger.Info("execute ddl") + + s := common.SQLWithRetry{ + DB: db, + Logger: logger, + } + + var currentTS int64 + if err := s.QueryRow(ctx, "", "SELECT UNIX_TIMESTAMP()", ¤tTS); err != nil { + currentTS = time.Now().Unix() + logger.Warn("failed to query current timestamp, use current time instead", zap.Int64("currentTS", currentTS), zap.Error(err)) + } + + resultCh := make(chan error, 1) + go func() { + resultCh <- s.Exec(ctx, "add index", ddl) + }() + + failpoint.Inject("AddIndexCrash", func() { + _ = common.KillMySelf() + }) + + var ddlErr error + for { + select { + case ddlErr = <-resultCh: + failpoint.Inject("AddIndexFail", func() { + ddlErr = errors.New("injected error") + }) + if ddlErr == nil { + return nil + } + if log.IsContextCanceledError(ddlErr) { + return ddlErr + } + if isDeterminedError(ddlErr) { + return ddlErr + } + logger.Warn("failed to execute ddl, try to query ddl status", zap.Error(ddlErr)) + case <-time.After(getDDLStatusInterval): + } + + var status *ddlStatus + err := common.Retry("query ddl status", logger, func() error { + var err error + status, err = getDDLStatus(ctx, db, ddl, time.Unix(currentTS, 0)) + return err + }) + if err != nil || status == nil { + logger.Warn("failed to query ddl status", zap.Error(err)) + if ddlErr != nil { + return ddlErr + } + continue + } + updateProgress(status) + + if ddlErr != nil { + switch state := status.state; state { + case model.JobStateDone, model.JobStateSynced: + logger.Info("ddl job is finished", zap.Stringer("state", state)) + return nil + case model.JobStateRunning, model.JobStateQueueing, model.JobStateNone: 
+ logger.Info("ddl job is running", zap.Stringer("state", state)) + default: + logger.Warn("ddl job is canceled or rollbacked", zap.Stringer("state", state)) + return ddlErr + } + } + } +} + +func isDeterminedError(err error) bool { + if merr, ok := errors.Cause(err).(*dmysql.MySQLError); ok { + switch merr.Number { + case errno.ErrDupKeyName, errno.ErrMultiplePriKey, errno.ErrDupUnique, errno.ErrDupEntry: + return true + } + } + return false +} + +const ( + getDDLStatusInterval = time.Minute + // Limit the number of jobs to query. Large limit may result in empty result. See https://github.com/pingcap/tidb/issues/42298. + // A new TiDB cluster has at least 40 jobs in the history queue, so 30 is a reasonable value. + getDDLStatusMaxJobs = 30 +) + +type ddlStatus struct { + state model.JobState + rowCount int64 +} + +func getDDLStatus( + ctx context.Context, + db *sql.DB, + query string, + minCreateTime time.Time, +) (*ddlStatus, error) { + jobID, err := getDDLJobIDByQuery(ctx, db, query) + if err != nil || jobID == 0 { + return nil, err + } + rows, err := db.QueryContext(ctx, fmt.Sprintf("ADMIN SHOW DDL JOBS %d WHERE job_id = %d", getDDLStatusMaxJobs, jobID)) + if err != nil { + return nil, errors.Trace(err) + } + defer rows.Close() + + cols, err := rows.Columns() + if err != nil { + return nil, errors.Trace(err) + } + + var ( + rowCount int64 + state string + createTimeStr sql.NullString + ) + dest := make([]any, len(cols)) + for i, col := range cols { + switch strings.ToLower(col) { + case "row_count": + dest[i] = &rowCount + case "state": + dest[i] = &state + case "create_time": + dest[i] = &createTimeStr + default: + var anyStr sql.NullString + dest[i] = &anyStr + } + } + status := &ddlStatus{} + + for rows.Next() { + if err := rows.Scan(dest...); err != nil { + return nil, errors.Trace(err) + } + status.rowCount += rowCount + // subjob doesn't have create_time, ignore it. + if !createTimeStr.Valid || createTimeStr.String == "" { + continue + } + createTime, err := time.Parse(time.DateTime, createTimeStr.String) + if err != nil { + return nil, errors.Trace(err) + } + // The job is not created by the current task, ignore it. 
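executeDDL above runs the ALTER statement in its own goroutine and wakes up every getDDLStatusInterval to report progress from ADMIN SHOW DDL JOBS. A simplified standalone sketch of that polling shape; unlike the real loop it does not re-check the job state after an ambiguous error, and checkProgress is a hypothetical placeholder for the status query.

package main

import (
	"fmt"
	"time"
)

// runWithProgress runs fn in its own goroutine and invokes checkProgress at a
// fixed interval until fn returns, mirroring the select loop in executeDDL.
func runWithProgress(fn func() error, checkProgress func(), interval time.Duration) error {
	resultCh := make(chan error, 1)
	go func() { resultCh <- fn() }()
	for {
		select {
		case err := <-resultCh:
			return err // the DDL statement finished
		case <-time.After(interval):
			checkProgress() // the real code queries ADMIN SHOW DDL JOBS here
		}
	}
}

func main() {
	err := runWithProgress(
		func() error { time.Sleep(120 * time.Millisecond); return nil }, // stands in for the ADD INDEX statement
		func() { fmt.Println("ddl still running, report row_count / progress") },
		50*time.Millisecond,
	)
	fmt.Println("ddl finished:", err)
}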
+ if createTime.Before(minCreateTime) { + return nil, nil + } + status.state = model.StrToJobState(state) + } + return status, errors.Trace(rows.Err()) +} + +func getDDLJobIDByQuery(ctx context.Context, db *sql.DB, wantQuery string) (int64, error) { + rows, err := db.QueryContext(ctx, fmt.Sprintf("ADMIN SHOW DDL JOB QUERIES LIMIT %d", getDDLStatusMaxJobs)) + if err != nil { + return 0, errors.Trace(err) + } + defer rows.Close() + + for rows.Next() { + var ( + jobID int64 + query string + ) + if err := rows.Scan(&jobID, &query); err != nil { + return 0, errors.Trace(err) + } + if query == wantQuery { + return jobID, errors.Trace(rows.Err()) + } + } + return 0, errors.Trace(rows.Err()) +} + +func (tr *TableImporter) preDeduplicate( + ctx context.Context, + rc *Controller, + cp *checkpoints.TableCheckpoint, + workingDir string, +) error { + d := &dupDetector{ + tr: tr, + rc: rc, + cp: cp, + logger: tr.logger, + } + originalErr := d.run(ctx, workingDir, tr.dupIgnoreRows) + if originalErr == nil { + return nil + } + + if !ErrDuplicateKey.Equal(originalErr) { + return errors.Trace(originalErr) + } + + var ( + idxName string + oneConflictMsg, otherConflictMsg string + ) + + // provide a more friendly error message + + dupErr := errors.Cause(originalErr).(*errors.Error) + conflictIdxID := dupErr.Args()[0].(int64) + if conflictIdxID == conflictOnHandle { + idxName = "PRIMARY" + } else { + for _, idxInfo := range tr.tableInfo.Core.Indices { + if idxInfo.ID == conflictIdxID { + idxName = idxInfo.Name.O + break + } + } + } + if idxName == "" { + tr.logger.Error("cannot find index name", zap.Int64("conflictIdxID", conflictIdxID)) + return errors.Trace(originalErr) + } + if !rc.cfg.Checkpoint.Enable { + err := errors.Errorf("duplicate key in table %s caused by index `%s`, but because checkpoint is off we can't have more details", + tr.tableName, idxName) + rc.errorMgr.RecordDuplicateOnce( + ctx, tr.logger, tr.tableName, "", -1, err.Error(), -1, "", + ) + return err + } + conflictEncodedRowIDs := dupErr.Args()[1].([][]byte) + if len(conflictEncodedRowIDs) < 2 { + tr.logger.Error("invalid conflictEncodedRowIDs", zap.Int("len", len(conflictEncodedRowIDs))) + return errors.Trace(originalErr) + } + rowID := make([]int64, 2) + var err error + _, rowID[0], err = codec.DecodeComparableVarint(conflictEncodedRowIDs[0]) + if err != nil { + rowIDHex := hex.EncodeToString(conflictEncodedRowIDs[0]) + tr.logger.Error("failed to decode rowID", + zap.String("rowID", rowIDHex), + zap.Error(err)) + return errors.Trace(originalErr) + } + _, rowID[1], err = codec.DecodeComparableVarint(conflictEncodedRowIDs[1]) + if err != nil { + rowIDHex := hex.EncodeToString(conflictEncodedRowIDs[1]) + tr.logger.Error("failed to decode rowID", + zap.String("rowID", rowIDHex), + zap.Error(err)) + return errors.Trace(originalErr) + } + + tableCp, err := rc.checkpointsDB.Get(ctx, tr.tableName) + if err != nil { + tr.logger.Error("failed to get table checkpoint", zap.Error(err)) + return errors.Trace(err) + } + var ( + secondConflictPath string + ) + for _, engineCp := range tableCp.Engines { + for _, chunkCp := range engineCp.Chunks { + if chunkCp.Chunk.PrevRowIDMax <= rowID[0] && rowID[0] < chunkCp.Chunk.RowIDMax { + oneConflictMsg = fmt.Sprintf("row %d counting from offset %d in file %s", + rowID[0]-chunkCp.Chunk.PrevRowIDMax, + chunkCp.Chunk.Offset, + chunkCp.FileMeta.Path) + } + if chunkCp.Chunk.PrevRowIDMax <= rowID[1] && rowID[1] < chunkCp.Chunk.RowIDMax { + secondConflictPath = chunkCp.FileMeta.Path + otherConflictMsg = 
fmt.Sprintf("row %d counting from offset %d in file %s", + rowID[1]-chunkCp.Chunk.PrevRowIDMax, + chunkCp.Chunk.Offset, + chunkCp.FileMeta.Path) + } + } + } + if oneConflictMsg == "" || otherConflictMsg == "" { + tr.logger.Error("cannot find conflict rows by rowID", + zap.Int64("rowID[0]", rowID[0]), + zap.Int64("rowID[1]", rowID[1])) + return errors.Trace(originalErr) + } + err = errors.Errorf("duplicate entry for key '%s', a pair of conflicting rows are (%s, %s)", + idxName, oneConflictMsg, otherConflictMsg) + rc.errorMgr.RecordDuplicateOnce( + ctx, tr.logger, tr.tableName, secondConflictPath, -1, err.Error(), rowID[1], "", + ) + return err +} diff --git a/lightning/pkg/server/binding__failpoint_binding__.go b/lightning/pkg/server/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..884841332390a --- /dev/null +++ b/lightning/pkg/server/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package server + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/lightning/pkg/server/lightning.go b/lightning/pkg/server/lightning.go index 0c20abf10c824..465538267411f 100644 --- a/lightning/pkg/server/lightning.go +++ b/lightning/pkg/server/lightning.go @@ -237,12 +237,12 @@ func (l *Lightning) goServe(statusAddr string, realAddrWriter io.Writer) error { mux.HandleFunc("/debug/pprof/trace", pprof.Trace) // Enable failpoint http API for testing. - failpoint.Inject("EnableTestAPI", func() { + if _, _err_ := failpoint.Eval(_curpkg_("EnableTestAPI")); _err_ == nil { mux.HandleFunc("/fail/", func(w http.ResponseWriter, r *http.Request) { r.URL.Path = strings.TrimPrefix(r.URL.Path, "/fail") new(failpoint.HttpHandler).ServeHTTP(w, r) }) - }) + } handleTasks := http.StripPrefix("/tasks", http.HandlerFunc(l.handleTask)) mux.Handle("/tasks", httpHandleWrapper(handleTasks.ServeHTTP)) @@ -329,7 +329,7 @@ func (l *Lightning) RunOnceWithOptions(taskCtx context.Context, taskCfg *config. opt(o) } - failpoint.Inject("setExtStorage", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("setExtStorage")); _err_ == nil { path := val.(string) b, err := storage.ParseBackend(path, nil) if err != nil { @@ -341,11 +341,11 @@ func (l *Lightning) RunOnceWithOptions(taskCtx context.Context, taskCfg *config. } o.dumpFileStorage = s o.checkpointStorage = s - }) - failpoint.Inject("setCheckpointName", func(val failpoint.Value) { + } + if val, _err_ := failpoint.Eval(_curpkg_("setCheckpointName")); _err_ == nil { file := val.(string) o.checkpointName = file - }) + } if o.dumpFileStorage != nil { // we don't use it, set a value to pass Adjust @@ -357,11 +357,11 @@ func (l *Lightning) RunOnceWithOptions(taskCtx context.Context, taskCfg *config. 
} taskCfg.TaskID = time.Now().UnixNano() - failpoint.Inject("SetTaskID", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("SetTaskID")); _err_ == nil { taskCfg.TaskID = int64(val.(int)) - }) + } - failpoint.Inject("SetIOTotalBytes", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("SetIOTotalBytes")); _err_ == nil { o.logger.Info("set io total bytes") taskCfg.TiDB.IOTotalBytes = atomic.NewUint64(0) taskCfg.TiDB.UUID = uuid.New().String() @@ -371,7 +371,7 @@ func (l *Lightning) RunOnceWithOptions(taskCtx context.Context, taskCfg *config. log.L().Info("IOTotalBytes", zap.Uint64("IOTotalBytes", taskCfg.TiDB.IOTotalBytes.Load())) } }() - }) + } if taskCfg.TiDB.IOTotalBytes != nil { o.logger.Info("found IO total bytes counter") mysql.RegisterDialContext(taskCfg.TiDB.UUID, func(ctx context.Context, addr string) (net.Conn, error) { @@ -463,7 +463,7 @@ func (l *Lightning) run(taskCtx context.Context, taskCfg *config.Config, o *opti web.BroadcastEndTask(err) }() - failpoint.Inject("SkipRunTask", func() { + if _, _err_ := failpoint.Eval(_curpkg_("SkipRunTask")); _err_ == nil { if notifyCh, ok := l.ctx.Value(taskRunNotifyKey).(chan struct{}); ok { select { case notifyCh <- struct{}{}: @@ -474,13 +474,13 @@ func (l *Lightning) run(taskCtx context.Context, taskCfg *config.Config, o *opti select { case recorder <- taskCfg: case <-ctx.Done(): - failpoint.Return(ctx.Err()) + return ctx.Err() } } - failpoint.Return(nil) - }) + return nil + } - failpoint.Inject("SetCertExpiredSoon", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("SetCertExpiredSoon")); _err_ == nil { rootKeyPath := val.(string) rootCaPath := taskCfg.Security.CAPath keyPath := taskCfg.Security.KeyPath @@ -488,9 +488,9 @@ func (l *Lightning) run(taskCtx context.Context, taskCfg *config.Config, o *opti if err := updateCertExpiry(rootKeyPath, rootCaPath, keyPath, certPath, time.Second*10); err != nil { panic(err) } - }) + } - failpoint.Inject("PrintStatus", func() { + if _, _err_ := failpoint.Eval(_curpkg_("PrintStatus")); _err_ == nil { defer func() { finished, total := l.Status() o.logger.Warn("PrintStatus Failpoint", @@ -498,7 +498,7 @@ func (l *Lightning) run(taskCtx context.Context, taskCfg *config.Config, o *opti zap.Int64("total", total), zap.Bool("equal", finished == total)) }() - }) + } if err := taskCfg.TiDB.Security.BuildTLSConfig(); err != nil { return common.ErrInvalidTLSConfig.Wrap(err) @@ -602,10 +602,10 @@ func (l *Lightning) run(taskCtx context.Context, taskCfg *config.Config, o *opti return errors.Trace(err) } - failpoint.Inject("orphanWriterGoRoutine", func() { + if _, _err_ := failpoint.Eval(_curpkg_("orphanWriterGoRoutine")); _err_ == nil { // don't exit too quickly to expose panic defer time.Sleep(time.Second * 10) - }) + } defer procedure.Close() err = procedure.Run(ctx) diff --git a/lightning/pkg/server/lightning.go__failpoint_stash__ b/lightning/pkg/server/lightning.go__failpoint_stash__ new file mode 100644 index 0000000000000..0c20abf10c824 --- /dev/null +++ b/lightning/pkg/server/lightning.go__failpoint_stash__ @@ -0,0 +1,1152 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "cmp" + "compress/gzip" + "context" + "crypto/ecdsa" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "database/sql" + "encoding/json" + "encoding/pem" + "fmt" + "io" + "net" + "net/http" + "net/http/pprof" + "os" + "slices" + "strconv" + "strings" + "sync" + "time" + + "github.com/go-sql-driver/mysql" + "github.com/google/uuid" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/import_sstpb" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/tidb/br/pkg/restore/split" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/br/pkg/version/build" + "github.com/pingcap/tidb/lightning/pkg/importer" + "github.com/pingcap/tidb/lightning/pkg/web" + _ "github.com/pingcap/tidb/pkg/expression" // get rid of `import cycle`: just init expression.RewriteAstExpr,and called at package `backend.kv`. + "github.com/pingcap/tidb/pkg/lightning/backend/local" + "github.com/pingcap/tidb/pkg/lightning/checkpoints" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/lightning/metric" + "github.com/pingcap/tidb/pkg/lightning/mydump" + "github.com/pingcap/tidb/pkg/lightning/tikv" + _ "github.com/pingcap/tidb/pkg/planner/core" // init expression.EvalSimpleAst related function + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/promutil" + "github.com/pingcap/tidb/pkg/util/redact" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/collectors" + "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/shurcooL/httpgzip" + pdhttp "github.com/tikv/pd/client/http" + "go.uber.org/atomic" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +// Lightning is the main struct of the lightning package. +type Lightning struct { + globalCfg *config.GlobalConfig + globalTLS *common.TLS + // taskCfgs is the list of task configurations enqueued in the server mode + taskCfgs *config.List + ctx context.Context + shutdown context.CancelFunc // for whole lightning context + server http.Server + serverAddr net.Addr + serverLock sync.Mutex + status importer.LightningStatus + + promFactory promutil.Factory + promRegistry promutil.Registry + metrics *metric.Metrics + + cancelLock sync.Mutex + curTask *config.Config + cancel context.CancelFunc // for per task context, which maybe different from lightning context + + taskCanceled bool +} + +func initEnv(cfg *config.GlobalConfig) error { + if cfg.App.Config.File == "" { + return nil + } + return log.InitLogger(&cfg.App.Config, cfg.TiDB.LogLevel) +} + +// New creates a new Lightning instance. 
+func New(globalCfg *config.GlobalConfig) *Lightning { + if err := initEnv(globalCfg); err != nil { + fmt.Println("Failed to initialize environment:", err) + os.Exit(1) + } + + tls, err := common.NewTLS( + globalCfg.Security.CAPath, + globalCfg.Security.CertPath, + globalCfg.Security.KeyPath, + globalCfg.App.StatusAddr, + globalCfg.Security.CABytes, + globalCfg.Security.CertBytes, + globalCfg.Security.KeyBytes, + ) + if err != nil { + log.L().Fatal("failed to load TLS certificates", zap.Error(err)) + } + + redact.InitRedact(globalCfg.Security.RedactInfoLog) + + promFactory := promutil.NewDefaultFactory() + promRegistry := promutil.NewDefaultRegistry() + ctx, shutdown := context.WithCancel(context.Background()) + return &Lightning{ + globalCfg: globalCfg, + globalTLS: tls, + ctx: ctx, + shutdown: shutdown, + promFactory: promFactory, + promRegistry: promRegistry, + } +} + +// GoServe starts the HTTP server in a goroutine. The server will be closed +func (l *Lightning) GoServe() error { + handleSigUsr1(func() { + l.serverLock.Lock() + statusAddr := l.globalCfg.App.StatusAddr + shouldStartServer := len(statusAddr) == 0 + if shouldStartServer { + l.globalCfg.App.StatusAddr = ":" + } + l.serverLock.Unlock() + + if shouldStartServer { + // open a random port and start the server if SIGUSR1 is received. + if err := l.goServe(":", os.Stderr); err != nil { + log.L().Warn("failed to start HTTP server", log.ShortError(err)) + } + } else { + // just prints the server address if it is already started. + log.L().Info("already started HTTP server", zap.Stringer("address", l.serverAddr)) + } + }) + + l.serverLock.Lock() + statusAddr := l.globalCfg.App.StatusAddr + l.serverLock.Unlock() + + if len(statusAddr) == 0 { + return nil + } + return l.goServe(statusAddr, io.Discard) +} + +// TODO: maybe handle http request using gin +type loggingResponseWriter struct { + http.ResponseWriter + statusCode int + body string +} + +func newLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter { + return &loggingResponseWriter{ResponseWriter: w, statusCode: http.StatusOK} +} + +// WriteHeader implements http.ResponseWriter. +func (lrw *loggingResponseWriter) WriteHeader(code int) { + lrw.statusCode = code + lrw.ResponseWriter.WriteHeader(code) +} + +// Write implements http.ResponseWriter. +func (lrw *loggingResponseWriter) Write(d []byte) (int, error) { + // keep first part of the response for logging, max 1K + if lrw.body == "" && len(d) > 0 { + length := len(d) + if length > 1024 { + length = 1024 + } + lrw.body = string(d[:length]) + } + return lrw.ResponseWriter.Write(d) +} + +func httpHandleWrapper(h http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + logger := log.L().With(zap.String("method", r.Method), zap.Stringer("url", r.URL)). 
+ Begin(zapcore.InfoLevel, "process http request") + + newWriter := newLoggingResponseWriter(w) + h.ServeHTTP(newWriter, r) + + bodyField := zap.Skip() + if newWriter.Header().Get("Content-Encoding") != "gzip" { + bodyField = zap.String("body", newWriter.body) + } + logger.End(zapcore.InfoLevel, nil, zap.Int("status", newWriter.statusCode), bodyField) + } +} + +func (l *Lightning) goServe(statusAddr string, realAddrWriter io.Writer) error { + mux := http.NewServeMux() + mux.Handle("/", http.RedirectHandler("/web/", http.StatusFound)) + + registry := l.promRegistry + registry.MustRegister(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{})) + registry.MustRegister(collectors.NewGoCollector()) + if gatherer, ok := registry.(prometheus.Gatherer); ok { + handler := promhttp.InstrumentMetricHandler( + registry, promhttp.HandlerFor(gatherer, promhttp.HandlerOpts{}), + ) + mux.Handle("/metrics", handler) + } + + mux.HandleFunc("/debug/pprof/", pprof.Index) + mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) + mux.HandleFunc("/debug/pprof/profile", pprof.Profile) + mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) + mux.HandleFunc("/debug/pprof/trace", pprof.Trace) + + // Enable failpoint http API for testing. + failpoint.Inject("EnableTestAPI", func() { + mux.HandleFunc("/fail/", func(w http.ResponseWriter, r *http.Request) { + r.URL.Path = strings.TrimPrefix(r.URL.Path, "/fail") + new(failpoint.HttpHandler).ServeHTTP(w, r) + }) + }) + + handleTasks := http.StripPrefix("/tasks", http.HandlerFunc(l.handleTask)) + mux.Handle("/tasks", httpHandleWrapper(handleTasks.ServeHTTP)) + mux.Handle("/tasks/", httpHandleWrapper(handleTasks.ServeHTTP)) + mux.HandleFunc("/progress/task", httpHandleWrapper(handleProgressTask)) + mux.HandleFunc("/progress/table", httpHandleWrapper(handleProgressTable)) + mux.HandleFunc("/pause", httpHandleWrapper(handlePause)) + mux.HandleFunc("/resume", httpHandleWrapper(handleResume)) + mux.HandleFunc("/loglevel", httpHandleWrapper(handleLogLevel)) + + mux.Handle("/web/", http.StripPrefix("/web", httpgzip.FileServer(web.Res, httpgzip.FileServerOptions{ + IndexHTML: true, + ServeError: func(w http.ResponseWriter, req *http.Request, err error) { + if os.IsNotExist(err) && !strings.Contains(req.URL.Path, ".") { + http.Redirect(w, req, "/web/", http.StatusFound) + } else { + httpgzip.NonSpecific(w, req, err) + } + }, + }))) + + listener, err := net.Listen("tcp", statusAddr) + if err != nil { + return err + } + l.serverAddr = listener.Addr() + log.L().Info("starting HTTP server", zap.Stringer("address", l.serverAddr)) + fmt.Fprintln(realAddrWriter, "started HTTP server on", l.serverAddr) + l.server.Handler = mux + listener = l.globalTLS.WrapListener(listener) + + go func() { + err := l.server.Serve(listener) + log.L().Info("stopped HTTP server", log.ShortError(err)) + }() + return nil +} + +// RunServer is used by binary lightning to start a HTTP server to receive import tasks. 
+func (l *Lightning) RunServer() error { + l.serverLock.Lock() + l.taskCfgs = config.NewConfigList() + l.serverLock.Unlock() + log.L().Info( + "Lightning server is running, post to /tasks to start an import task", + zap.Stringer("address", l.serverAddr), + ) + + for { + task, err := l.taskCfgs.Pop(l.ctx) + if err != nil { + return err + } + o := &options{ + promFactory: l.promFactory, + promRegistry: l.promRegistry, + logger: log.L(), + } + err = l.run(context.Background(), task, o) + if err != nil && !common.IsContextCanceledError(err) { + importer.DeliverPauser.Pause() // force pause the progress on error + log.L().Error("tidb lightning encountered error", zap.Error(err)) + } + } +} + +// RunOnceWithOptions is used by binary lightning and host when using lightning as a library. +// - for binary lightning, taskCtx could be context.Background which means taskCtx wouldn't be canceled directly by its +// cancel function, but only by Lightning.Stop or HTTP DELETE using l.cancel. No need to set Options +// - for lightning as a library, taskCtx could be a meaningful context that get canceled outside, and there Options may +// be used: +// - WithGlue: set a caller implemented glue. Otherwise, lightning will use a default glue later. +// - WithDumpFileStorage: caller has opened an external storage for lightning. Otherwise, lightning will open a +// storage by config +// - WithCheckpointStorage: caller has opened an external storage for lightning and want to save checkpoint +// in it. Otherwise, lightning will save checkpoint by the Checkpoint.DSN in config +func (l *Lightning) RunOnceWithOptions(taskCtx context.Context, taskCfg *config.Config, opts ...Option) error { + o := &options{ + promFactory: l.promFactory, + promRegistry: l.promRegistry, + logger: log.L(), + } + for _, opt := range opts { + opt(o) + } + + failpoint.Inject("setExtStorage", func(val failpoint.Value) { + path := val.(string) + b, err := storage.ParseBackend(path, nil) + if err != nil { + panic(err) + } + s, err := storage.New(context.Background(), b, &storage.ExternalStorageOptions{}) + if err != nil { + panic(err) + } + o.dumpFileStorage = s + o.checkpointStorage = s + }) + failpoint.Inject("setCheckpointName", func(val failpoint.Value) { + file := val.(string) + o.checkpointName = file + }) + + if o.dumpFileStorage != nil { + // we don't use it, set a value to pass Adjust + taskCfg.Mydumper.SourceDir = "noop://" + } + + if err := taskCfg.Adjust(taskCtx); err != nil { + return err + } + + taskCfg.TaskID = time.Now().UnixNano() + failpoint.Inject("SetTaskID", func(val failpoint.Value) { + taskCfg.TaskID = int64(val.(int)) + }) + + failpoint.Inject("SetIOTotalBytes", func(_ failpoint.Value) { + o.logger.Info("set io total bytes") + taskCfg.TiDB.IOTotalBytes = atomic.NewUint64(0) + taskCfg.TiDB.UUID = uuid.New().String() + go func() { + for { + time.Sleep(time.Millisecond * 10) + log.L().Info("IOTotalBytes", zap.Uint64("IOTotalBytes", taskCfg.TiDB.IOTotalBytes.Load())) + } + }() + }) + if taskCfg.TiDB.IOTotalBytes != nil { + o.logger.Info("found IO total bytes counter") + mysql.RegisterDialContext(taskCfg.TiDB.UUID, func(ctx context.Context, addr string) (net.Conn, error) { + o.logger.Debug("connection with IO bytes counter") + d := &net.Dialer{} + conn, err := d.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, err + } + tcpConn := conn.(*net.TCPConn) + // try https://github.com/go-sql-driver/mysql/blob/bcc459a906419e2890a50fc2c99ea6dd927a88f2/connector.go#L56-L64 + err = tcpConn.SetKeepAlive(true) + if err != nil 
{ + o.logger.Warn("set TCP keep alive failed", zap.Error(err)) + } + return util.NewTCPConnWithIOCounter(tcpConn, taskCfg.TiDB.IOTotalBytes), nil + }) + } + + return l.run(taskCtx, taskCfg, o) +} + +var ( + taskRunNotifyKey = "taskRunNotifyKey" + taskCfgRecorderKey = "taskCfgRecorderKey" +) + +func getKeyspaceName(db *sql.DB) (string, error) { + if db == nil { + return "", nil + } + + rows, err := db.Query("show config where Type = 'tidb' and name = 'keyspace-name'") + if err != nil { + return "", err + } + //nolint: errcheck + defer rows.Close() + + var ( + _type string + _instance string + _name string + value string + ) + if rows.Next() { + err = rows.Scan(&_type, &_instance, &_name, &value) + if err != nil { + return "", err + } + } + + return value, rows.Err() +} + +func (l *Lightning) run(taskCtx context.Context, taskCfg *config.Config, o *options) (err error) { + build.LogInfo(build.Lightning) + o.logger.Info("cfg", zap.Stringer("cfg", taskCfg)) + + logutil.LogEnvVariables() + + if split.WaitRegionOnlineAttemptTimes != taskCfg.TikvImporter.RegionCheckBackoffLimit { + // it will cause data race if lightning is used as a library, but this is a + // hidden config so we ignore that case + split.WaitRegionOnlineAttemptTimes = taskCfg.TikvImporter.RegionCheckBackoffLimit + } + + metrics := metric.NewMetrics(o.promFactory) + metrics.RegisterTo(o.promRegistry) + defer func() { + metrics.UnregisterFrom(o.promRegistry) + }() + l.metrics = metrics + + ctx := metric.WithMetric(taskCtx, metrics) + ctx = log.NewContext(ctx, o.logger) + ctx, cancel := context.WithCancel(ctx) + l.cancelLock.Lock() + l.cancel = cancel + l.curTask = taskCfg + l.cancelLock.Unlock() + web.BroadcastStartTask() + + defer func() { + cancel() + l.cancelLock.Lock() + l.cancel = nil + l.cancelLock.Unlock() + web.BroadcastEndTask(err) + }() + + failpoint.Inject("SkipRunTask", func() { + if notifyCh, ok := l.ctx.Value(taskRunNotifyKey).(chan struct{}); ok { + select { + case notifyCh <- struct{}{}: + default: + } + } + if recorder, ok := l.ctx.Value(taskCfgRecorderKey).(chan *config.Config); ok { + select { + case recorder <- taskCfg: + case <-ctx.Done(): + failpoint.Return(ctx.Err()) + } + } + failpoint.Return(nil) + }) + + failpoint.Inject("SetCertExpiredSoon", func(val failpoint.Value) { + rootKeyPath := val.(string) + rootCaPath := taskCfg.Security.CAPath + keyPath := taskCfg.Security.KeyPath + certPath := taskCfg.Security.CertPath + if err := updateCertExpiry(rootKeyPath, rootCaPath, keyPath, certPath, time.Second*10); err != nil { + panic(err) + } + }) + + failpoint.Inject("PrintStatus", func() { + defer func() { + finished, total := l.Status() + o.logger.Warn("PrintStatus Failpoint", + zap.Int64("finished", finished), + zap.Int64("total", total), + zap.Bool("equal", finished == total)) + }() + }) + + if err := taskCfg.TiDB.Security.BuildTLSConfig(); err != nil { + return common.ErrInvalidTLSConfig.Wrap(err) + } + + s := o.dumpFileStorage + if s == nil { + u, err := storage.ParseBackend(taskCfg.Mydumper.SourceDir, nil) + if err != nil { + return common.NormalizeError(err) + } + s, err = storage.New(ctx, u, &storage.ExternalStorageOptions{}) + if err != nil { + return common.NormalizeError(err) + } + } + + // return expectedErr means at least meet one file + expectedErr := errors.New("Stop Iter") + walkErr := s.WalkDir(ctx, &storage.WalkOption{ListCount: 1}, func(string, int64) error { + // return an error when meet the first regular file to break the walk loop + return expectedErr + }) + if !errors.ErrorEqual(walkErr, 
expectedErr) { + if walkErr == nil { + return common.ErrEmptySourceDir.GenWithStackByArgs(taskCfg.Mydumper.SourceDir) + } + return common.NormalizeOrWrapErr(common.ErrStorageUnknown, walkErr) + } + + loadTask := o.logger.Begin(zap.InfoLevel, "load data source") + var mdl *mydump.MDLoader + mdl, err = mydump.NewLoaderWithStore(ctx, mydump.NewLoaderCfg(taskCfg), s) + loadTask.End(zap.ErrorLevel, err) + if err != nil { + return errors.Trace(err) + } + err = checkSystemRequirement(taskCfg, mdl.GetDatabases()) + if err != nil { + o.logger.Error("check system requirements failed", zap.Error(err)) + return common.ErrSystemRequirementNotMet.Wrap(err).GenWithStackByArgs() + } + // check table schema conflicts + err = checkSchemaConflict(taskCfg, mdl.GetDatabases()) + if err != nil { + o.logger.Error("checkpoint schema conflicts with data files", zap.Error(err)) + return errors.Trace(err) + } + + dbMetas := mdl.GetDatabases() + web.BroadcastInitProgress(dbMetas) + + // db is only not nil in unit test + db := o.db + if db == nil { + // initiation of default db should be after BuildTLSConfig + db, err = importer.DBFromConfig(ctx, taskCfg.TiDB) + if err != nil { + return common.ErrDBConnect.Wrap(err) + } + } + + var keyspaceName string + if taskCfg.TikvImporter.Backend == config.BackendLocal { + keyspaceName = taskCfg.TikvImporter.KeyspaceName + if keyspaceName == "" { + keyspaceName, err = getKeyspaceName(db) + if err != nil && common.IsAccessDeniedNeedConfigPrivilegeError(err) { + // if the cluster is not multitenant we don't really need to know about the keyspace. + // since the doc does not say we require CONFIG privilege, + // spelling out the Access Denied error just confuses the users. + // hide such allowed errors unless log level is DEBUG. + o.logger.Info("keyspace is unspecified and target user has no config privilege, assuming dedicated cluster") + if o.logger.Level() > zapcore.DebugLevel { + err = nil + } + } + if err != nil { + o.logger.Warn("unable to get keyspace name, lightning will use empty keyspace name", zap.Error(err)) + } + } + o.logger.Info("acquired keyspace name", zap.String("keyspaceName", keyspaceName)) + } + + param := &importer.ControllerParam{ + DBMetas: dbMetas, + Status: &l.status, + DumpFileStorage: s, + OwnExtStorage: o.dumpFileStorage == nil, + DB: db, + CheckpointStorage: o.checkpointStorage, + CheckpointName: o.checkpointName, + DupIndicator: o.dupIndicator, + KeyspaceName: keyspaceName, + } + + var procedure *importer.Controller + procedure, err = importer.NewImportController(ctx, taskCfg, param) + if err != nil { + o.logger.Error("restore failed", log.ShortError(err)) + return errors.Trace(err) + } + + failpoint.Inject("orphanWriterGoRoutine", func() { + // don't exit too quickly to expose panic + defer time.Sleep(time.Second * 10) + }) + defer procedure.Close() + + err = procedure.Run(ctx) + return errors.Trace(err) +} + +// Stop stops the lightning server. +func (l *Lightning) Stop() { + l.cancelLock.Lock() + if l.cancel != nil { + l.taskCanceled = true + l.cancel() + } + l.cancelLock.Unlock() + if err := l.server.Shutdown(l.ctx); err != nil { + log.L().Warn("failed to shutdown HTTP server", log.ShortError(err)) + } + l.shutdown() +} + +// TaskCanceled return whether the current task is canceled. +func (l *Lightning) TaskCanceled() bool { + l.cancelLock.Lock() + defer l.cancelLock.Unlock() + return l.taskCanceled +} + +// Status return the sum size of file which has been imported to TiKV and the total size of source file. 
+func (l *Lightning) Status() (finished int64, total int64) { + finished = l.status.FinishedFileSize.Load() + total = l.status.TotalFileSize.Load() + return +} + +// Metrics returns the metrics of lightning. +// it's inited during `run`, so might return nil. +func (l *Lightning) Metrics() *metric.Metrics { + return l.metrics +} + +func writeJSONError(w http.ResponseWriter, code int, prefix string, err error) { + type errorResponse struct { + Error string `json:"error"` + } + + w.WriteHeader(code) + + if err != nil { + prefix += ": " + err.Error() + } + _ = json.NewEncoder(w).Encode(errorResponse{Error: prefix}) +} + +func parseTaskID(req *http.Request) (int64, string, error) { + path := strings.TrimPrefix(req.URL.Path, "/") + taskIDString := path + verb := "" + if i := strings.IndexByte(path, '/'); i >= 0 { + taskIDString = path[:i] + verb = path[i+1:] + } + + taskID, err := strconv.ParseInt(taskIDString, 10, 64) + if err != nil { + return 0, "", err + } + + return taskID, verb, nil +} + +func (l *Lightning) handleTask(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Content-Type", "application/json") + + switch req.Method { + case http.MethodGet: + taskID, _, err := parseTaskID(req) + // golint tells us to refactor this with switch stmt. + // However switch stmt doesn't support init-statements, + // hence if we follow it things might be worse. + // Anyway, this chain of if-else isn't unacceptable. + //nolint:gocritic + if e, ok := err.(*strconv.NumError); ok && e.Num == "" { + l.handleGetTask(w) + } else if err == nil { + l.handleGetOneTask(w, req, taskID) + } else { + writeJSONError(w, http.StatusBadRequest, "invalid task ID", err) + } + case http.MethodPost: + l.handlePostTask(w, req) + case http.MethodDelete: + l.handleDeleteOneTask(w, req) + case http.MethodPatch: + l.handlePatchOneTask(w, req) + default: + w.Header().Set("Allow", http.MethodGet+", "+http.MethodPost+", "+http.MethodDelete+", "+http.MethodPatch) + writeJSONError(w, http.StatusMethodNotAllowed, "only GET, POST, DELETE and PATCH are allowed", nil) + } +} + +func (l *Lightning) handleGetTask(w http.ResponseWriter) { + var response struct { + Current *int64 `json:"current"` + QueuedIDs []int64 `json:"queue"` + } + l.serverLock.Lock() + if l.taskCfgs != nil { + response.QueuedIDs = l.taskCfgs.AllIDs() + } else { + response.QueuedIDs = []int64{} + } + l.serverLock.Unlock() + + l.cancelLock.Lock() + if l.cancel != nil && l.curTask != nil { + response.Current = new(int64) + *response.Current = l.curTask.TaskID + } + l.cancelLock.Unlock() + + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(response) +} + +func (l *Lightning) handleGetOneTask(w http.ResponseWriter, req *http.Request, taskID int64) { + var task *config.Config + + l.cancelLock.Lock() + if l.curTask != nil && l.curTask.TaskID == taskID { + task = l.curTask + } + l.cancelLock.Unlock() + + if task == nil && l.taskCfgs != nil { + task, _ = l.taskCfgs.Get(taskID) + } + + if task == nil { + writeJSONError(w, http.StatusNotFound, "task ID not found", nil) + return + } + + json, err := json.Marshal(task) + if err != nil { + writeJSONError(w, http.StatusInternalServerError, "unable to serialize task", err) + return + } + + writeBytesCompressed(w, req, json) +} + +func (l *Lightning) handlePostTask(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Cache-Control", "no-store") + l.serverLock.Lock() + defer l.serverLock.Unlock() + if l.taskCfgs == nil { + // l.taskCfgs is non-nil only if Lightning is started with RunServer(). 
+ // Without the server mode this pointer is default to be nil. + writeJSONError(w, http.StatusNotImplemented, "server-mode not enabled", nil) + return + } + + type taskResponse struct { + ID int64 `json:"id"` + } + + data, err := io.ReadAll(req.Body) + if err != nil { + writeJSONError(w, http.StatusBadRequest, "cannot read request", err) + return + } + log.L().Info("received task config") + + cfg := config.NewConfig() + if err = cfg.LoadFromGlobal(l.globalCfg); err != nil { + writeJSONError(w, http.StatusInternalServerError, "cannot restore from global config", err) + return + } + if err = cfg.LoadFromTOML(data); err != nil { + writeJSONError(w, http.StatusBadRequest, "cannot parse task (must be TOML)", err) + return + } + if err = cfg.Adjust(l.ctx); err != nil { + writeJSONError(w, http.StatusBadRequest, "invalid task configuration", err) + return + } + + l.taskCfgs.Push(cfg) + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(taskResponse{ID: cfg.TaskID}) +} + +func (l *Lightning) handleDeleteOneTask(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Content-Type", "application/json") + + taskID, _, err := parseTaskID(req) + if err != nil { + writeJSONError(w, http.StatusBadRequest, "invalid task ID", err) + return + } + + var cancel context.CancelFunc + cancelSuccess := false + + l.cancelLock.Lock() + if l.cancel != nil && l.curTask != nil && l.curTask.TaskID == taskID { + cancel = l.cancel + l.cancel = nil + } + l.cancelLock.Unlock() + + if cancel != nil { + cancel() + cancelSuccess = true + } else if l.taskCfgs != nil { + cancelSuccess = l.taskCfgs.Remove(taskID) + } + + log.L().Info("canceled task", zap.Int64("taskID", taskID), zap.Bool("success", cancelSuccess)) + + if cancelSuccess { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("{}")) + } else { + writeJSONError(w, http.StatusNotFound, "task ID not found", nil) + } +} + +func (l *Lightning) handlePatchOneTask(w http.ResponseWriter, req *http.Request) { + if l.taskCfgs == nil { + writeJSONError(w, http.StatusNotImplemented, "server-mode not enabled", nil) + return + } + + taskID, verb, err := parseTaskID(req) + if err != nil { + writeJSONError(w, http.StatusBadRequest, "invalid task ID", err) + return + } + + moveSuccess := false + switch verb { + case "front": + moveSuccess = l.taskCfgs.MoveToFront(taskID) + case "back": + moveSuccess = l.taskCfgs.MoveToBack(taskID) + default: + writeJSONError(w, http.StatusBadRequest, "unknown patch action", nil) + return + } + + if moveSuccess { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("{}")) + } else { + writeJSONError(w, http.StatusNotFound, "task ID not found", nil) + } +} + +func writeBytesCompressed(w http.ResponseWriter, req *http.Request, b []byte) { + if !strings.Contains(req.Header.Get("Accept-Encoding"), "gzip") { + _, _ = w.Write(b) + return + } + + w.Header().Set("Content-Encoding", "gzip") + w.WriteHeader(http.StatusOK) + gw, _ := gzip.NewWriterLevel(w, gzip.BestSpeed) + _, _ = gw.Write(b) + _ = gw.Close() +} + +func handleProgressTask(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Content-Type", "application/json") + res, err := web.MarshalTaskProgress() + if err == nil { + writeBytesCompressed(w, req, res) + } else { + w.WriteHeader(http.StatusInternalServerError) + _ = json.NewEncoder(w).Encode(err.Error()) + } +} + +func handleProgressTable(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Content-Type", "application/json") + tableName := req.URL.Query().Get("t") + res, err := 
web.MarshalTableCheckpoints(tableName) + if err == nil { + writeBytesCompressed(w, req, res) + } else { + if errors.IsNotFound(err) { + w.WriteHeader(http.StatusNotFound) + } else { + w.WriteHeader(http.StatusInternalServerError) + } + _ = json.NewEncoder(w).Encode(err.Error()) + } +} + +func handlePause(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Content-Type", "application/json") + + switch req.Method { + case http.MethodGet: + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{"paused":%v}`, importer.DeliverPauser.IsPaused()) + + case http.MethodPut: + w.WriteHeader(http.StatusOK) + importer.DeliverPauser.Pause() + log.L().Info("progress paused") + _, _ = w.Write([]byte("{}")) + + default: + w.Header().Set("Allow", http.MethodGet+", "+http.MethodPut) + writeJSONError(w, http.StatusMethodNotAllowed, "only GET and PUT are allowed", nil) + } +} + +func handleResume(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Content-Type", "application/json") + + switch req.Method { + case http.MethodPut: + w.WriteHeader(http.StatusOK) + importer.DeliverPauser.Resume() + log.L().Info("progress resumed") + _, _ = w.Write([]byte("{}")) + + default: + w.Header().Set("Allow", http.MethodPut) + writeJSONError(w, http.StatusMethodNotAllowed, "only PUT is allowed", nil) + } +} + +func handleLogLevel(w http.ResponseWriter, req *http.Request) { + w.Header().Set("Content-Type", "application/json") + + var logLevel struct { + Level zapcore.Level `json:"level"` + } + + switch req.Method { + case http.MethodGet: + logLevel.Level = log.Level() + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(logLevel) + + case http.MethodPut, http.MethodPost: + if err := json.NewDecoder(req.Body).Decode(&logLevel); err != nil { + writeJSONError(w, http.StatusBadRequest, "invalid log level", err) + return + } + oldLevel := log.SetLevel(zapcore.InfoLevel) + log.L().Info("changed log level. No effects if task has specified its logger", + zap.Stringer("old", oldLevel), + zap.Stringer("new", logLevel.Level)) + log.SetLevel(logLevel.Level) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("{}")) + + default: + w.Header().Set("Allow", http.MethodGet+", "+http.MethodPut+", "+http.MethodPost) + writeJSONError(w, http.StatusMethodNotAllowed, "only GET, PUT and POST are allowed", nil) + } +} + +func checkSystemRequirement(cfg *config.Config, dbsMeta []*mydump.MDDatabaseMeta) error { + // in local mode, we need to read&write a lot of L0 sst files, so we need to check system max open files limit + if cfg.TikvImporter.Backend == config.BackendLocal { + // estimate max open files = {top N(TableConcurrency) table sizes} / {MemoryTableSize} + tableTotalSizes := make([]int64, 0) + for _, dbs := range dbsMeta { + for _, tb := range dbs.Tables { + tableTotalSizes = append(tableTotalSizes, tb.TotalSize) + } + } + slices.SortFunc(tableTotalSizes, func(i, j int64) int { + return cmp.Compare(j, i) + }) + topNTotalSize := int64(0) + for i := 0; i < len(tableTotalSizes) && i < cfg.App.TableConcurrency; i++ { + topNTotalSize += tableTotalSizes[i] + } + + // region-concurrency: number of LocalWriters writing SST files. + // 2*totalSize/memCacheSize: number of Pebble MemCache files. + maxDBFiles := topNTotalSize / int64(cfg.TikvImporter.LocalWriterMemCacheSize) * 2 + // the pebble db and all import routine need upto maxDBFiles fds for read and write. 
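+	// A rough worked example with hypothetical numbers (not the config defaults): if the
+	// top-N tables total 64 GiB and local-writer-mem-cache-size is 128 MiB, then
+	// maxDBFiles = 64Gi/128Mi*2 = 1024; with range-concurrency = 16 the next line yields
+	// maxOpenDBFiles = 1024*(1+16) = 17408, and region-concurrency is added on top before
+	// the total is checked against the process file-descriptor limit.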
+ maxOpenDBFiles := maxDBFiles * (1 + int64(cfg.TikvImporter.RangeConcurrency)) + estimateMaxFiles := local.RlimT(cfg.App.RegionConcurrency) + local.RlimT(maxOpenDBFiles) + if err := local.VerifyRLimit(estimateMaxFiles); err != nil { + return err + } + } + + return nil +} + +// checkSchemaConflict return error if checkpoint table scheme is conflict with data files +func checkSchemaConflict(cfg *config.Config, dbsMeta []*mydump.MDDatabaseMeta) error { + if cfg.Checkpoint.Enable && cfg.Checkpoint.Driver == config.CheckpointDriverMySQL { + for _, db := range dbsMeta { + if db.Name == cfg.Checkpoint.Schema { + for _, tb := range db.Tables { + if checkpoints.IsCheckpointTable(tb.Name) { + return common.ErrCheckpointSchemaConflict.GenWithStack("checkpoint table `%s`.`%s` conflict with data files. Please change the `checkpoint.schema` config or set `checkpoint.driver` to \"file\" instead", db.Name, tb.Name) + } + } + } + } + } + return nil +} + +// CheckpointRemove removes the checkpoint of the given table. +func CheckpointRemove(ctx context.Context, cfg *config.Config, tableName string) error { + cpdb, err := checkpoints.OpenCheckpointsDB(ctx, cfg) + if err != nil { + return errors.Trace(err) + } + //nolint: errcheck + defer cpdb.Close() + + // try to remove the metadata first. + taskCp, err := cpdb.TaskCheckpoint(ctx) + if err != nil { + return errors.Trace(err) + } + // a empty id means this task is not inited, we needn't further check metas. + if taskCp != nil && taskCp.TaskID != 0 { + // try to clean up table metas if exists + if err = CleanupMetas(ctx, cfg, tableName); err != nil { + return errors.Trace(err) + } + } + + return errors.Trace(cpdb.RemoveCheckpoint(ctx, tableName)) +} + +// CleanupMetas removes the table metas of the given table. +func CleanupMetas(ctx context.Context, cfg *config.Config, tableName string) error { + if tableName == "all" { + tableName = "" + } + // try to clean up table metas if exists + db, err := importer.DBFromConfig(ctx, cfg.TiDB) + if err != nil { + return errors.Trace(err) + } + + tableMetaExist, err := common.TableExists(ctx, db, cfg.App.MetaSchemaName, importer.TableMetaTableName) + if err != nil { + return errors.Trace(err) + } + if tableMetaExist { + metaTableName := common.UniqueTable(cfg.App.MetaSchemaName, importer.TableMetaTableName) + if err = importer.RemoveTableMetaByTableName(ctx, db, metaTableName, tableName); err != nil { + return errors.Trace(err) + } + } + + exist, err := common.TableExists(ctx, db, cfg.App.MetaSchemaName, importer.TaskMetaTableName) + if err != nil || !exist { + return errors.Trace(err) + } + return errors.Trace(importer.MaybeCleanupAllMetas(ctx, log.L(), db, cfg.App.MetaSchemaName, tableMetaExist)) +} + +// SwitchMode switches the mode of the TiKV cluster. +func SwitchMode(ctx context.Context, cli pdhttp.Client, tls *tls.Config, mode string, ranges ...*import_sstpb.Range) error { + var m import_sstpb.SwitchMode + switch mode { + case config.ImportMode: + m = import_sstpb.SwitchMode_Import + case config.NormalMode: + m = import_sstpb.SwitchMode_Normal + default: + return errors.Errorf("invalid mode %s, must use %s or %s", mode, config.ImportMode, config.NormalMode) + } + + return tikv.ForAllStores( + ctx, + cli, + metapb.StoreState_Offline, + func(c context.Context, store *pdhttp.MetaStore) error { + return tikv.SwitchMode(c, tls, store.Address, m, ranges...) 
+ }, + ) +} + +func updateCertExpiry(rootKeyPath, rootCaPath, keyPath, certPath string, expiry time.Duration) error { + rootKey, err := parsePrivateKey(rootKeyPath) + if err != nil { + return err + } + rootCaPem, err := os.ReadFile(rootCaPath) + if err != nil { + return err + } + rootCaDer, _ := pem.Decode(rootCaPem) + rootCa, err := x509.ParseCertificate(rootCaDer.Bytes) + if err != nil { + return err + } + key, err := parsePrivateKey(keyPath) + if err != nil { + return err + } + certPem, err := os.ReadFile(certPath) + if err != nil { + panic(err) + } + certDer, _ := pem.Decode(certPem) + cert, err := x509.ParseCertificate(certDer.Bytes) + if err != nil { + return err + } + cert.NotBefore = time.Now() + cert.NotAfter = time.Now().Add(expiry) + derBytes, err := x509.CreateCertificate(rand.Reader, cert, rootCa, &key.PublicKey, rootKey) + if err != nil { + return err + } + return os.WriteFile(certPath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}), 0o600) +} + +func parsePrivateKey(keyPath string) (*ecdsa.PrivateKey, error) { + keyPemBlock, err := os.ReadFile(keyPath) + if err != nil { + return nil, err + } + var keyDERBlock *pem.Block + for { + keyDERBlock, keyPemBlock = pem.Decode(keyPemBlock) + if keyDERBlock == nil { + return nil, errors.New("failed to find PEM block with type ending in \"PRIVATE KEY\"") + } + if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") { + break + } + } + return x509.ParseECPrivateKey(keyDERBlock.Bytes) +} diff --git a/pkg/autoid_service/autoid.go b/pkg/autoid_service/autoid.go index dbbe7b8ee3353..528712b638642 100644 --- a/pkg/autoid_service/autoid.go +++ b/pkg/autoid_service/autoid.go @@ -462,11 +462,11 @@ func (s *Service) allocAutoID(ctx context.Context, req *autoid.AutoIDRequest) (* return nil, errors.New("not leader") } - failpoint.Inject("mockErr", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockErr")); _err_ == nil { if val.(bool) { - failpoint.Return(nil, errors.New("mock reload failed")) + return nil, errors.New("mock reload failed") } - }) + } val := s.getAlloc(req.DbID, req.TblID, req.IsUnsigned) val.Lock() diff --git a/pkg/autoid_service/autoid.go__failpoint_stash__ b/pkg/autoid_service/autoid.go__failpoint_stash__ new file mode 100644 index 0000000000000..dbbe7b8ee3353 --- /dev/null +++ b/pkg/autoid_service/autoid.go__failpoint_stash__ @@ -0,0 +1,612 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package autoid + +import ( + "context" + "crypto/tls" + "math" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/autoid" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/keyspace" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + autoid1 "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/owner" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/util/etcd" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/mathutil" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/keepalive" +) + +var ( + errAutoincReadFailed = errors.New("auto increment action failed") +) + +const ( + autoIDLeaderPath = "tidb/autoid/leader" +) + +type autoIDKey struct { + dbID int64 + tblID int64 +} + +type autoIDValue struct { + sync.Mutex + base int64 + end int64 + isUnsigned bool + token chan struct{} +} + +func (alloc *autoIDValue) alloc4Unsigned(ctx context.Context, store kv.Storage, dbID, tblID int64, isUnsigned bool, + n uint64, increment, offset int64) (min int64, max int64, err error) { + // Check offset rebase if necessary. + if uint64(offset-1) > uint64(alloc.base) { + if err := alloc.rebase4Unsigned(ctx, store, dbID, tblID, uint64(offset-1)); err != nil { + return 0, 0, err + } + } + // calcNeededBatchSize calculates the total batch size needed. + n1 := calcNeededBatchSize(alloc.base, int64(n), increment, offset, isUnsigned) + + // The local rest is not enough for alloc. + if uint64(alloc.base)+uint64(n1) > uint64(alloc.end) || alloc.base == 0 { + var newBase, newEnd int64 + nextStep := int64(batch) + fromBase := alloc.base + + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnMeta) + err := kv.RunInNewTxn(ctx, store, true, func(_ context.Context, txn kv.Transaction) error { + idAcc := meta.NewMeta(txn).GetAutoIDAccessors(dbID, tblID).IncrementID(model.TableInfoVersion5) + var err1 error + newBase, err1 = idAcc.Get() + if err1 != nil { + return err1 + } + // calcNeededBatchSize calculates the total batch size needed on new base. + if alloc.base == 0 || newBase != alloc.end { + alloc.base = newBase + alloc.end = newBase + n1 = calcNeededBatchSize(newBase, int64(n), increment, offset, isUnsigned) + } + + // Although the step is customized by user, we still need to make sure nextStep is big enough for insert batch. + if nextStep < n1 { + nextStep = n1 + } + tmpStep := int64(mathutil.Min(math.MaxUint64-uint64(newBase), uint64(nextStep))) + // The global rest is not enough for alloc. + if tmpStep < n1 { + return errAutoincReadFailed + } + newEnd, err1 = idAcc.Inc(tmpStep) + return err1 + }) + if err != nil { + return 0, 0, err + } + if uint64(newBase) == math.MaxUint64 { + return 0, 0, errAutoincReadFailed + } + logutil.BgLogger().Info("alloc4Unsigned from", + zap.String("category", "autoid service"), + zap.Int64("dbID", dbID), + zap.Int64("tblID", tblID), + zap.Int64("from base", fromBase), + zap.Int64("from end", alloc.end), + zap.Int64("to base", newBase), + zap.Int64("to end", newEnd)) + alloc.end = newEnd + } + min = alloc.base + // Use uint64 n directly. 
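+	// The addition is done in uint64 so that unsigned bases beyond math.MaxInt64, which
+	// are stored in the int64 field via their bit pattern, advance correctly instead of
+	// overflowing.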
+ alloc.base = int64(uint64(alloc.base) + uint64(n1)) + return min, alloc.base, nil +} + +func (alloc *autoIDValue) alloc4Signed(ctx context.Context, + store kv.Storage, + dbID, tblID int64, + isUnsigned bool, + n uint64, increment, offset int64) (min int64, max int64, err error) { + // Check offset rebase if necessary. + if offset-1 > alloc.base { + if err := alloc.rebase4Signed(ctx, store, dbID, tblID, offset-1); err != nil { + return 0, 0, err + } + } + // calcNeededBatchSize calculates the total batch size needed. + n1 := calcNeededBatchSize(alloc.base, int64(n), increment, offset, isUnsigned) + + // Condition alloc.base+N1 > alloc.end will overflow when alloc.base + N1 > MaxInt64. So need this. + if math.MaxInt64-alloc.base <= n1 { + return 0, 0, errAutoincReadFailed + } + + // The local rest is not enough for allocN. + // If alloc.base is 0, the alloc may not be initialized, force fetch from remote. + if alloc.base+n1 > alloc.end || alloc.base == 0 { + var newBase, newEnd int64 + nextStep := int64(batch) + fromBase := alloc.base + + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnMeta) + err := kv.RunInNewTxn(ctx, store, true, func(_ context.Context, txn kv.Transaction) error { + idAcc := meta.NewMeta(txn).GetAutoIDAccessors(dbID, tblID).IncrementID(model.TableInfoVersion5) + var err1 error + newBase, err1 = idAcc.Get() + if err1 != nil { + return err1 + } + // calcNeededBatchSize calculates the total batch size needed on global base. + // alloc.base == 0 means uninitialized + // newBase != alloc.end means something abnormal, maybe transaction conflict and retry? + if alloc.base == 0 || newBase != alloc.end { + alloc.base = newBase + alloc.end = newBase + n1 = calcNeededBatchSize(newBase, int64(n), increment, offset, isUnsigned) + } + // Although the step is customized by user, we still need to make sure nextStep is big enough for insert batch. + if nextStep < n1 { + nextStep = n1 + } + tmpStep := mathutil.Min(math.MaxInt64-newBase, nextStep) + // The global rest is not enough for alloc. + if tmpStep < n1 { + return errAutoincReadFailed + } + newEnd, err1 = idAcc.Inc(tmpStep) + return err1 + }) + if err != nil { + return 0, 0, err + } + if newBase == math.MaxInt64 { + return 0, 0, errAutoincReadFailed + } + logutil.BgLogger().Info("alloc4Signed from", + zap.String("category", "autoid service"), + zap.Int64("dbID", dbID), + zap.Int64("tblID", tblID), + zap.Int64("from base", fromBase), + zap.Int64("from end", alloc.end), + zap.Int64("to base", newBase), + zap.Int64("to end", newEnd)) + alloc.end = newEnd + } + min = alloc.base + alloc.base += n1 + return min, alloc.base, nil +} + +func (alloc *autoIDValue) rebase4Unsigned(ctx context.Context, + store kv.Storage, + dbID, tblID int64, + requiredBase uint64) error { + // Satisfied by alloc.base, nothing to do. + if requiredBase <= uint64(alloc.base) { + return nil + } + // Satisfied by alloc.end, need to update alloc.base. 
+ if requiredBase > uint64(alloc.base) && requiredBase <= uint64(alloc.end) { + alloc.base = int64(requiredBase) + return nil + } + + var newBase, newEnd uint64 + var oldValue int64 + startTime := time.Now() + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnMeta) + err := kv.RunInNewTxn(ctx, store, true, func(_ context.Context, txn kv.Transaction) error { + idAcc := meta.NewMeta(txn).GetAutoIDAccessors(dbID, tblID).IncrementID(model.TableInfoVersion5) + currentEnd, err1 := idAcc.Get() + if err1 != nil { + return err1 + } + oldValue = currentEnd + uCurrentEnd := uint64(currentEnd) + newBase = mathutil.Max(uCurrentEnd, requiredBase) + newEnd = mathutil.Min(math.MaxUint64-uint64(batch), newBase) + uint64(batch) + _, err1 = idAcc.Inc(int64(newEnd - uCurrentEnd)) + return err1 + }) + metrics.AutoIDHistogram.WithLabelValues(metrics.TableAutoIDRebase, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) + if err != nil { + return err + } + + logutil.BgLogger().Info("rebase4Unsigned from", + zap.String("category", "autoid service"), + zap.Int64("dbID", dbID), + zap.Int64("tblID", tblID), + zap.Int64("from", oldValue), + zap.Uint64("to", newEnd)) + alloc.base, alloc.end = int64(newBase), int64(newEnd) + return nil +} + +func (alloc *autoIDValue) rebase4Signed(ctx context.Context, store kv.Storage, dbID, tblID int64, requiredBase int64) error { + // Satisfied by alloc.base, nothing to do. + if requiredBase <= alloc.base { + return nil + } + // Satisfied by alloc.end, need to update alloc.base. + if requiredBase > alloc.base && requiredBase <= alloc.end { + alloc.base = requiredBase + return nil + } + + var oldValue, newBase, newEnd int64 + startTime := time.Now() + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnMeta) + err := kv.RunInNewTxn(ctx, store, true, func(_ context.Context, txn kv.Transaction) error { + idAcc := meta.NewMeta(txn).GetAutoIDAccessors(dbID, tblID).IncrementID(model.TableInfoVersion5) + currentEnd, err1 := idAcc.Get() + if err1 != nil { + return err1 + } + oldValue = currentEnd + newBase = mathutil.Max(currentEnd, requiredBase) + newEnd = mathutil.Min(math.MaxInt64-batch, newBase) + batch + _, err1 = idAcc.Inc(newEnd - currentEnd) + return err1 + }) + metrics.AutoIDHistogram.WithLabelValues(metrics.TableAutoIDRebase, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) + if err != nil { + return err + } + + logutil.BgLogger().Info("rebase4Signed from", + zap.Int64("dbID", dbID), + zap.Int64("tblID", tblID), + zap.Int64("from", oldValue), + zap.Int64("to", newEnd), + zap.String("category", "autoid service")) + alloc.base, alloc.end = newBase, newEnd + return nil +} + +// Service implement the grpc AutoIDAlloc service, defined in kvproto/pkg/autoid. +type Service struct { + autoIDLock sync.Mutex + autoIDMap map[autoIDKey]*autoIDValue + + leaderShip owner.Manager + store kv.Storage +} + +// New return a Service instance. 
+func New(selfAddr string, etcdAddr []string, store kv.Storage, tlsConfig *tls.Config) *Service { + cfg := config.GetGlobalConfig() + etcdLogCfg := zap.NewProductionConfig() + + cli, err := clientv3.New(clientv3.Config{ + LogConfig: &etcdLogCfg, + Endpoints: etcdAddr, + AutoSyncInterval: 30 * time.Second, + DialTimeout: 5 * time.Second, + DialOptions: []grpc.DialOption{ + grpc.WithBackoffMaxDelay(time.Second * 3), + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: time.Duration(cfg.TiKVClient.GrpcKeepAliveTime) * time.Second, + Timeout: time.Duration(cfg.TiKVClient.GrpcKeepAliveTimeout) * time.Second, + }), + }, + TLS: tlsConfig, + }) + if store.GetCodec().GetKeyspace() != nil { + etcd.SetEtcdCliByNamespace(cli, keyspace.MakeKeyspaceEtcdNamespaceSlash(store.GetCodec())) + } + if err != nil { + panic(err) + } + return newWithCli(selfAddr, cli, store) +} + +func newWithCli(selfAddr string, cli *clientv3.Client, store kv.Storage) *Service { + l := owner.NewOwnerManager(context.Background(), cli, "autoid", selfAddr, autoIDLeaderPath) + service := &Service{ + autoIDMap: make(map[autoIDKey]*autoIDValue), + leaderShip: l, + store: store, + } + l.SetListener(&ownerListener{ + Service: service, + selfAddr: selfAddr, + }) + // 10 means that autoid service's etcd lease is 10s. + err := l.CampaignOwner(10) + if err != nil { + panic(err) + } + + return service +} + +type mockClient struct { + Service +} + +func (m *mockClient) AllocAutoID(ctx context.Context, in *autoid.AutoIDRequest, _ ...grpc.CallOption) (*autoid.AutoIDResponse, error) { + return m.Service.AllocAutoID(ctx, in) +} + +func (m *mockClient) Rebase(ctx context.Context, in *autoid.RebaseRequest, _ ...grpc.CallOption) (*autoid.RebaseResponse, error) { + return m.Service.Rebase(ctx, in) +} + +var global = make(map[string]*mockClient) + +// MockForTest is used for testing, the UT test and unistore use this. +func MockForTest(store kv.Storage) autoid.AutoIDAllocClient { + uuid := store.UUID() + ret, ok := global[uuid] + if !ok { + ret = &mockClient{ + Service{ + autoIDMap: make(map[autoIDKey]*autoIDValue), + leaderShip: nil, + store: store, + }, + } + global[uuid] = ret + } + return ret +} + +// Close closes the Service and clean up resource. +func (s *Service) Close() { + if s.leaderShip != nil && s.leaderShip.IsOwner() { + s.leaderShip.Cancel() + } +} + +// seekToFirstAutoIDSigned seeks to the next valid signed position. +func seekToFirstAutoIDSigned(base, increment, offset int64) int64 { + nr := (base + increment - offset) / increment + nr = nr*increment + offset + return nr +} + +// seekToFirstAutoIDUnSigned seeks to the next valid unsigned position. +func seekToFirstAutoIDUnSigned(base, increment, offset uint64) uint64 { + nr := (base + increment - offset) / increment + nr = nr*increment + offset + return nr +} + +func calcNeededBatchSize(base, n, increment, offset int64, isUnsigned bool) int64 { + if increment == 1 { + return n + } + if isUnsigned { + // SeekToFirstAutoIDUnSigned seeks to the next unsigned valid position. + nr := seekToFirstAutoIDUnSigned(uint64(base), uint64(increment), uint64(offset)) + // calculate the total batch size needed. + nr += (uint64(n) - 1) * uint64(increment) + return int64(nr - uint64(base)) + } + nr := seekToFirstAutoIDSigned(base, increment, offset) + // calculate the total batch size needed. + nr += (n - 1) * increment + return nr - base +} + +const batch = 4000 + +// AllocAutoID implements gRPC AutoIDAlloc interface. 
+func (s *Service) AllocAutoID(ctx context.Context, req *autoid.AutoIDRequest) (*autoid.AutoIDResponse, error) { + serviceKeyspaceID := uint32(s.store.GetCodec().GetKeyspaceID()) + if req.KeyspaceID != serviceKeyspaceID { + logutil.BgLogger().Info("Current service is not request keyspace leader.", zap.Uint32("req-keyspace-id", req.KeyspaceID), zap.Uint32("service-keyspace-id", serviceKeyspaceID)) + return nil, errors.Trace(errors.New("not leader")) + } + var res *autoid.AutoIDResponse + for { + var err error + res, err = s.allocAutoID(ctx, req) + if err != nil { + return nil, errors.Trace(err) + } + if res != nil { + break + } + } + return res, nil +} + +func (s *Service) getAlloc(dbID, tblID int64, isUnsigned bool) *autoIDValue { + key := autoIDKey{dbID: dbID, tblID: tblID} + s.autoIDLock.Lock() + defer s.autoIDLock.Unlock() + + val, ok := s.autoIDMap[key] + if !ok { + val = &autoIDValue{ + isUnsigned: isUnsigned, + token: make(chan struct{}, 1), + } + s.autoIDMap[key] = val + } + + return val +} + +func (s *Service) allocAutoID(ctx context.Context, req *autoid.AutoIDRequest) (*autoid.AutoIDResponse, error) { + if s.leaderShip != nil && !s.leaderShip.IsOwner() { + logutil.BgLogger().Info("Alloc AutoID fail, not leader", zap.String("category", "autoid service")) + return nil, errors.New("not leader") + } + + failpoint.Inject("mockErr", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(nil, errors.New("mock reload failed")) + } + }) + + val := s.getAlloc(req.DbID, req.TblID, req.IsUnsigned) + val.Lock() + defer val.Unlock() + + if req.N == 0 { + if val.base != 0 { + return &autoid.AutoIDResponse{ + Min: val.base, + Max: val.base, + }, nil + } + // This item is not initialized, get the data from remote. + var currentEnd int64 + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnMeta) + err := kv.RunInNewTxn(ctx, s.store, true, func(_ context.Context, txn kv.Transaction) error { + idAcc := meta.NewMeta(txn).GetAutoIDAccessors(req.DbID, req.TblID).IncrementID(model.TableInfoVersion5) + var err1 error + currentEnd, err1 = idAcc.Get() + if err1 != nil { + return err1 + } + val.base = currentEnd + val.end = currentEnd + return nil + }) + if err != nil { + return &autoid.AutoIDResponse{Errmsg: []byte(err.Error())}, nil + } + return &autoid.AutoIDResponse{ + Min: currentEnd, + Max: currentEnd, + }, nil + } + + var min, max int64 + var err error + if req.IsUnsigned { + min, max, err = val.alloc4Unsigned(ctx, s.store, req.DbID, req.TblID, req.IsUnsigned, req.N, req.Increment, req.Offset) + } else { + min, max, err = val.alloc4Signed(ctx, s.store, req.DbID, req.TblID, req.IsUnsigned, req.N, req.Increment, req.Offset) + } + + if err != nil { + return &autoid.AutoIDResponse{Errmsg: []byte(err.Error())}, nil + } + return &autoid.AutoIDResponse{ + Min: min, + Max: max, + }, nil +} + +func (alloc *autoIDValue) forceRebase(ctx context.Context, store kv.Storage, dbID, tblID, requiredBase int64, isUnsigned bool) error { + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnMeta) + var oldValue int64 + err := kv.RunInNewTxn(ctx, store, true, func(_ context.Context, txn kv.Transaction) error { + idAcc := meta.NewMeta(txn).GetAutoIDAccessors(dbID, tblID).IncrementID(model.TableInfoVersion5) + currentEnd, err1 := idAcc.Get() + if err1 != nil { + return err1 + } + oldValue = currentEnd + var step int64 + if !isUnsigned { + step = requiredBase - currentEnd + } else { + uRequiredBase, uCurrentEnd := uint64(requiredBase), uint64(currentEnd) + step = int64(uRequiredBase - uCurrentEnd) + } + _, err1 = 
idAcc.Inc(step) + return err1 + }) + if err != nil { + return err + } + logutil.BgLogger().Info("forceRebase from", + zap.Int64("dbID", dbID), + zap.Int64("tblID", tblID), + zap.Int64("from", oldValue), + zap.Int64("to", requiredBase), + zap.Bool("isUnsigned", isUnsigned), + zap.String("category", "autoid service")) + alloc.base, alloc.end = requiredBase, requiredBase + return nil +} + +// Rebase implements gRPC AutoIDAlloc interface. +// req.N = 0 is handled specially, it is used to return the current auto ID value. +func (s *Service) Rebase(ctx context.Context, req *autoid.RebaseRequest) (*autoid.RebaseResponse, error) { + if s.leaderShip != nil && !s.leaderShip.IsOwner() { + logutil.BgLogger().Info("Rebase() fail, not leader", zap.String("category", "autoid service")) + return nil, errors.New("not leader") + } + + val := s.getAlloc(req.DbID, req.TblID, req.IsUnsigned) + val.Lock() + defer val.Unlock() + + if req.Force { + err := val.forceRebase(ctx, s.store, req.DbID, req.TblID, req.Base, req.IsUnsigned) + if err != nil { + return &autoid.RebaseResponse{Errmsg: []byte(err.Error())}, nil + } + } + + var err error + if req.IsUnsigned { + err = val.rebase4Unsigned(ctx, s.store, req.DbID, req.TblID, uint64(req.Base)) + } else { + err = val.rebase4Signed(ctx, s.store, req.DbID, req.TblID, req.Base) + } + if err != nil { + return &autoid.RebaseResponse{Errmsg: []byte(err.Error())}, nil + } + return &autoid.RebaseResponse{}, nil +} + +type ownerListener struct { + *Service + selfAddr string +} + +var _ owner.Listener = (*ownerListener)(nil) + +func (l *ownerListener) OnBecomeOwner() { + // Reset the map to avoid a case that a node lose leadership and regain it, then + // improperly use the stale map to serve the autoid requests. + // See https://github.com/pingcap/tidb/issues/52600 + l.autoIDLock.Lock() + clear(l.autoIDMap) + l.autoIDLock.Unlock() + + logutil.BgLogger().Info("leader change of autoid service, this node become owner", + zap.String("addr", l.selfAddr), + zap.String("category", "autoid service")) +} + +func (*ownerListener) OnRetireOwner() { +} + +func init() { + autoid1.MockForTest = MockForTest +} diff --git a/pkg/autoid_service/binding__failpoint_binding__.go b/pkg/autoid_service/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..2c1025c7f434f --- /dev/null +++ b/pkg/autoid_service/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package autoid + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/bindinfo/binding__failpoint_binding__.go b/pkg/bindinfo/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..331f3106c1860 --- /dev/null +++ b/pkg/bindinfo/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package bindinfo + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/bindinfo/global_handle.go b/pkg/bindinfo/global_handle.go index db027b6c5d2f7..8e23689fd5c58 100644 --- a/pkg/bindinfo/global_handle.go +++ b/pkg/bindinfo/global_handle.go @@ -716,9 
+716,9 @@ func (h *globalBindingHandle) LoadBindingsFromStorage(sctx sessionctx.Context, s } func (h *globalBindingHandle) loadBindingsFromStorageInternal(sqlDigest string) (any, error) { - failpoint.Inject("load_bindings_from_storage_internal_timeout", func() { + if _, _err_ := failpoint.Eval(_curpkg_("load_bindings_from_storage_internal_timeout")); _err_ == nil { time.Sleep(time.Second) - }) + } var bindings Bindings selectStmt := fmt.Sprintf("SELECT original_sql, bind_sql, default_db, status, create_time, update_time, charset, collation, source, sql_digest, plan_digest FROM mysql.bind_info where sql_digest = '%s'", sqlDigest) err := h.callWithSCtx(false, func(sctx sessionctx.Context) error { diff --git a/pkg/bindinfo/global_handle.go__failpoint_stash__ b/pkg/bindinfo/global_handle.go__failpoint_stash__ new file mode 100644 index 0000000000000..db027b6c5d2f7 --- /dev/null +++ b/pkg/bindinfo/global_handle.go__failpoint_stash__ @@ -0,0 +1,745 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bindinfo + +import ( + "context" + "fmt" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/bindinfo/internal/logutil" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/format" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/types" + driver "github.com/pingcap/tidb/pkg/types/parser_driver" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/hint" + utilparser "github.com/pingcap/tidb/pkg/util/parser" + "go.uber.org/zap" + "golang.org/x/sync/singleflight" +) + +// GlobalBindingHandle is used to handle all global sql bind operations. +type GlobalBindingHandle interface { + // Methods for create, get, drop global sql bindings. + + // MatchGlobalBinding returns the matched binding for this statement. + MatchGlobalBinding(sctx sessionctx.Context, fuzzyDigest string, tableNames []*ast.TableName) (matchedBinding Binding, isMatched bool) + + // GetAllGlobalBindings returns all bind records in cache. + GetAllGlobalBindings() (bindings Bindings) + + // CreateGlobalBinding creates a Bindings to the storage and the cache. + // It replaces all the exists bindings for the same normalized SQL. + CreateGlobalBinding(sctx sessionctx.Context, binding Binding) (err error) + + // DropGlobalBinding drop Bindings to the storage and Bindings int the cache. + DropGlobalBinding(sqlDigest string) (deletedRows uint64, err error) + + // SetGlobalBindingStatus set a Bindings's status to the storage and bind cache. 
+ SetGlobalBindingStatus(newStatus, sqlDigest string) (ok bool, err error) + + // AddInvalidGlobalBinding adds Bindings which needs to be deleted into invalidBindingCache. + AddInvalidGlobalBinding(invalidBinding Binding) + + // DropInvalidGlobalBinding executes the drop Bindings tasks. + DropInvalidGlobalBinding() + + // Methods for load and clear global sql bindings. + + // Reset is to reset the BindHandle and clean old info. + Reset() + + // LoadFromStorageToCache loads global bindings from storage to the memory cache. + LoadFromStorageToCache(fullLoad bool) (err error) + + // GCGlobalBinding physically removes the deleted bind records in mysql.bind_info. + GCGlobalBinding() (err error) + + // Methods for memory control. + + // Size returns the size of bind info cache. + Size() int + + // SetBindingCacheCapacity reset the capacity for the bindingCache. + SetBindingCacheCapacity(capacity int64) + + // GetMemUsage returns the memory usage for the bind cache. + GetMemUsage() (memUsage int64) + + // GetMemCapacity returns the memory capacity for the bind cache. + GetMemCapacity() (memCapacity int64) + + // Clear resets the bind handle. It is only used for test. + Clear() + + // FlushGlobalBindings flushes the Bindings in temp maps to storage and loads them into cache. + FlushGlobalBindings() error + + // Methods for Auto Capture. + + // CaptureBaselines is used to automatically capture plan baselines. + CaptureBaselines() + + variable.Statistics +} + +// globalBindingHandle is used to handle all global sql bind operations. +type globalBindingHandle struct { + sPool util.SessionPool + + fuzzyBindingCache atomic.Value + + // lastTaskTime records the last update time for the global sql bind cache. + // This value is used to avoid reload duplicated bindings from storage. + lastUpdateTime atomic.Value + + // invalidBindings indicates the invalid bindings found during querying. + // A binding will be deleted from this map, after 2 bind-lease, after it is dropped from the kv. + invalidBindings *invalidBindingCache + + // syncBindingSingleflight is used to synchronize the execution of `LoadFromStorageToCache` method. + syncBindingSingleflight singleflight.Group +} + +// Lease influences the duration of loading bind info and handling invalid bind. +var Lease = 3 * time.Second + +const ( + // OwnerKey is the bindinfo owner path that is saved to etcd. + OwnerKey = "/tidb/bindinfo/owner" + // Prompt is the prompt for bindinfo owner manager. + Prompt = "bindinfo" + // BuiltinPseudoSQL4BindLock is used to simulate LOCK TABLE for mysql.bind_info. + BuiltinPseudoSQL4BindLock = "builtin_pseudo_sql_for_bind_lock" + + // LockBindInfoSQL simulates LOCK TABLE by updating a same row in each pessimistic transaction. + LockBindInfoSQL = `UPDATE mysql.bind_info SET source= 'builtin' WHERE original_sql= 'builtin_pseudo_sql_for_bind_lock'` + + // StmtRemoveDuplicatedPseudoBinding is used to remove duplicated pseudo binding. + // After using BR to sync bind_info between two clusters, the pseudo binding may be duplicated, and + // BR use this statement to remove duplicated rows, and this SQL should only be executed by BR. + StmtRemoveDuplicatedPseudoBinding = `DELETE FROM mysql.bind_info + WHERE original_sql='builtin_pseudo_sql_for_bind_lock' AND + _tidb_rowid NOT IN ( -- keep one arbitrary pseudo binding + SELECT _tidb_rowid FROM mysql.bind_info WHERE original_sql='builtin_pseudo_sql_for_bind_lock' limit 1)` +) + +// NewGlobalBindingHandle creates a new GlobalBindingHandle. 
+func NewGlobalBindingHandle(sPool util.SessionPool) GlobalBindingHandle { + handle := &globalBindingHandle{sPool: sPool} + handle.Reset() + return handle +} + +func (h *globalBindingHandle) getCache() FuzzyBindingCache { + return h.fuzzyBindingCache.Load().(FuzzyBindingCache) +} + +func (h *globalBindingHandle) setCache(c FuzzyBindingCache) { + // TODO: update the global cache in-place instead of replacing it and remove this function. + h.fuzzyBindingCache.Store(c) +} + +// Reset is to reset the BindHandle and clean old info. +func (h *globalBindingHandle) Reset() { + h.lastUpdateTime.Store(types.ZeroTimestamp) + h.invalidBindings = newInvalidBindingCache() + h.setCache(newFuzzyBindingCache(h.LoadBindingsFromStorage)) + variable.RegisterStatistics(h) +} + +func (h *globalBindingHandle) getLastUpdateTime() types.Time { + return h.lastUpdateTime.Load().(types.Time) +} + +func (h *globalBindingHandle) setLastUpdateTime(t types.Time) { + h.lastUpdateTime.Store(t) +} + +// LoadFromStorageToCache loads bindings from the storage into the cache. +func (h *globalBindingHandle) LoadFromStorageToCache(fullLoad bool) (err error) { + var lastUpdateTime types.Time + var timeCondition string + var newCache FuzzyBindingCache + if fullLoad { + lastUpdateTime = types.ZeroTimestamp + timeCondition = "" + newCache = newFuzzyBindingCache(h.LoadBindingsFromStorage) + } else { + lastUpdateTime = h.getLastUpdateTime() + timeCondition = fmt.Sprintf("WHERE update_time>'%s'", lastUpdateTime.String()) + newCache, err = h.getCache().Copy() + if err != nil { + return err + } + } + + selectStmt := fmt.Sprintf(`SELECT original_sql, bind_sql, default_db, status, create_time, + update_time, charset, collation, source, sql_digest, plan_digest FROM mysql.bind_info + %s ORDER BY update_time, create_time`, timeCondition) + + return h.callWithSCtx(false, func(sctx sessionctx.Context) error { + rows, _, err := execRows(sctx, selectStmt) + if err != nil { + return err + } + + defer func() { + h.setLastUpdateTime(lastUpdateTime) + h.setCache(newCache) + + metrics.BindingCacheMemUsage.Set(float64(h.GetMemUsage())) + metrics.BindingCacheMemLimit.Set(float64(h.GetMemCapacity())) + metrics.BindingCacheNumBindings.Set(float64(h.Size())) + }() + + for _, row := range rows { + // Skip the builtin record which is designed for binding synchronization. + if row.GetString(0) == BuiltinPseudoSQL4BindLock { + continue + } + sqlDigest, binding, err := newBinding(sctx, row) + + // Update lastUpdateTime to the newest one. + // Even if this one is an invalid bind. + if binding.UpdateTime.Compare(lastUpdateTime) > 0 { + lastUpdateTime = binding.UpdateTime + } + + if err != nil { + logutil.BindLogger().Warn("failed to generate bind record from data row", zap.Error(err)) + continue + } + + oldBinding := newCache.GetBinding(sqlDigest) + newBinding := removeDeletedBindings(merge(oldBinding, []Binding{binding})) + if len(newBinding) > 0 { + err = newCache.SetBinding(sqlDigest, newBinding) + if err != nil { + // When the memory capacity of bing_cache is not enough, + // there will be some memory-related errors in multiple places. + // Only needs to be handled once. + logutil.BindLogger().Warn("BindHandle.Update", zap.Error(err)) + } + } else { + newCache.RemoveBinding(sqlDigest) + } + } + return nil + }) +} + +// CreateGlobalBinding creates a Bindings to the storage and the cache. +// It replaces all the exists bindings for the same normalized SQL. 
+func (h *globalBindingHandle) CreateGlobalBinding(sctx sessionctx.Context, binding Binding) (err error) { + if err := prepareHints(sctx, &binding); err != nil { + return err + } + defer func() { + if err == nil { + err = h.LoadFromStorageToCache(false) + } + }() + + return h.callWithSCtx(true, func(sctx sessionctx.Context) error { + // Lock mysql.bind_info to synchronize with CreateBinding / AddBinding / DropBinding on other tidb instances. + if err = lockBindInfoTable(sctx); err != nil { + return err + } + + now := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 3) + + updateTs := now.String() + _, err = exec(sctx, `UPDATE mysql.bind_info SET status = %?, update_time = %? WHERE original_sql = %? AND update_time < %?`, + deleted, updateTs, binding.OriginalSQL, updateTs) + if err != nil { + return err + } + + binding.CreateTime = now + binding.UpdateTime = now + + // Insert the Bindings to the storage. + _, err = exec(sctx, `INSERT INTO mysql.bind_info VALUES (%?,%?, %?, %?, %?, %?, %?, %?, %?, %?, %?)`, + binding.OriginalSQL, + binding.BindSQL, + strings.ToLower(binding.Db), + binding.Status, + binding.CreateTime.String(), + binding.UpdateTime.String(), + binding.Charset, + binding.Collation, + binding.Source, + binding.SQLDigest, + binding.PlanDigest, + ) + if err != nil { + return err + } + return nil + }) +} + +// dropGlobalBinding drops a Bindings to the storage and Bindings int the cache. +func (h *globalBindingHandle) dropGlobalBinding(sqlDigest string) (deletedRows uint64, err error) { + err = h.callWithSCtx(false, func(sctx sessionctx.Context) error { + // Lock mysql.bind_info to synchronize with CreateBinding / AddBinding / DropBinding on other tidb instances. + if err = lockBindInfoTable(sctx); err != nil { + return err + } + + updateTs := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 3).String() + + _, err = exec(sctx, `UPDATE mysql.bind_info SET status = %?, update_time = %? WHERE sql_digest = %? AND update_time < %? AND status != %?`, + deleted, updateTs, sqlDigest, updateTs, deleted) + if err != nil { + return err + } + deletedRows = sctx.GetSessionVars().StmtCtx.AffectedRows() + return nil + }) + return +} + +// DropGlobalBinding drop Bindings to the storage and Bindings int the cache. +func (h *globalBindingHandle) DropGlobalBinding(sqlDigest string) (deletedRows uint64, err error) { + if sqlDigest == "" { + return 0, errors.New("sql digest is empty") + } + defer func() { + if err == nil { + err = h.LoadFromStorageToCache(false) + } + }() + return h.dropGlobalBinding(sqlDigest) +} + +// SetGlobalBindingStatus set a Bindings's status to the storage and bind cache. +func (h *globalBindingHandle) SetGlobalBindingStatus(newStatus, sqlDigest string) (ok bool, err error) { + var ( + updateTs types.Time + oldStatus0, oldStatus1 string + ) + if newStatus == Disabled { + // For compatibility reasons, when we need to 'set binding disabled for ', + // we need to consider both the 'enabled' and 'using' status. + oldStatus0 = Using + oldStatus1 = Enabled + } else if newStatus == Enabled { + // In order to unify the code, two identical old statuses are set. + oldStatus0 = Disabled + oldStatus1 = Disabled + } + + defer func() { + if err == nil { + err = h.LoadFromStorageToCache(false) + } + }() + + err = h.callWithSCtx(true, func(sctx sessionctx.Context) error { + // Lock mysql.bind_info to synchronize with SetBindingStatus on other tidb instances. 
+ if err = lockBindInfoTable(sctx); err != nil { + return err + } + + updateTs = types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 3) + updateTsStr := updateTs.String() + + _, err = exec(sctx, `UPDATE mysql.bind_info SET status = %?, update_time = %? WHERE sql_digest = %? AND update_time < %? AND status IN (%?, %?)`, + newStatus, updateTsStr, sqlDigest, updateTsStr, oldStatus0, oldStatus1) + return err + }) + return +} + +// GCGlobalBinding physically removes the deleted bind records in mysql.bind_info. +func (h *globalBindingHandle) GCGlobalBinding() (err error) { + return h.callWithSCtx(true, func(sctx sessionctx.Context) error { + // Lock mysql.bind_info to synchronize with CreateBinding / AddBinding / DropBinding on other tidb instances. + if err = lockBindInfoTable(sctx); err != nil { + return err + } + + // To make sure that all the deleted bind records have been acknowledged to all tidb, + // we only garbage collect those records with update_time before 10 leases. + updateTime := time.Now().Add(-(10 * Lease)) + updateTimeStr := types.NewTime(types.FromGoTime(updateTime), mysql.TypeTimestamp, 3).String() + _, err = exec(sctx, `DELETE FROM mysql.bind_info WHERE status = 'deleted' and update_time < %?`, updateTimeStr) + return err + }) +} + +// lockBindInfoTable simulates `LOCK TABLE mysql.bind_info WRITE` by acquiring a pessimistic lock on a +// special builtin row of mysql.bind_info. Note that this function must be called with h.sctx.Lock() held. +// We can replace this implementation to normal `LOCK TABLE mysql.bind_info WRITE` if that feature is +// generally available later. +// This lock would enforce the CREATE / DROP GLOBAL BINDING statements to be executed sequentially, +// even if they come from different tidb instances. +func lockBindInfoTable(sctx sessionctx.Context) error { + // h.sctx already locked. + _, err := exec(sctx, LockBindInfoSQL) + return err +} + +// invalidBindingCache is used to store invalid bindings temporarily. +type invalidBindingCache struct { + mu sync.RWMutex + m map[string]Binding // key: sqlDigest +} + +func newInvalidBindingCache() *invalidBindingCache { + return &invalidBindingCache{ + m: make(map[string]Binding), + } +} + +func (c *invalidBindingCache) add(binding Binding) { + c.mu.Lock() + defer c.mu.Unlock() + c.m[binding.SQLDigest] = binding +} + +func (c *invalidBindingCache) getAll() Bindings { + c.mu.Lock() + defer c.mu.Unlock() + bindings := make(Bindings, 0, len(c.m)) + for _, binding := range c.m { + bindings = append(bindings, binding) + } + return bindings +} + +func (c *invalidBindingCache) reset() { + c.mu.Lock() + defer c.mu.Unlock() + c.m = make(map[string]Binding) +} + +// DropInvalidGlobalBinding executes the drop Bindings tasks. +func (h *globalBindingHandle) DropInvalidGlobalBinding() { + defer func() { + if err := h.LoadFromStorageToCache(false); err != nil { + logutil.BindLogger().Warn("drop invalid global binding error", zap.Error(err)) + } + }() + + invalidBindings := h.invalidBindings.getAll() + h.invalidBindings.reset() + for _, invalidBinding := range invalidBindings { + if _, err := h.dropGlobalBinding(invalidBinding.SQLDigest); err != nil { + logutil.BindLogger().Debug("flush bind record failed", zap.Error(err)) + } + } +} + +// AddInvalidGlobalBinding adds Bindings which needs to be deleted into invalidBindings. +func (h *globalBindingHandle) AddInvalidGlobalBinding(invalidBinding Binding) { + h.invalidBindings.add(invalidBinding) +} + +// Size returns the size of bind info cache. 
+func (h *globalBindingHandle) Size() int { + size := len(h.getCache().GetAllBindings()) + return size +} + +// MatchGlobalBinding returns the matched binding for this statement. +func (h *globalBindingHandle) MatchGlobalBinding(sctx sessionctx.Context, fuzzyDigest string, tableNames []*ast.TableName) (matchedBinding Binding, isMatched bool) { + return h.getCache().FuzzyMatchingBinding(sctx, fuzzyDigest, tableNames) +} + +// GetAllGlobalBindings returns all bind records in cache. +func (h *globalBindingHandle) GetAllGlobalBindings() (bindings Bindings) { + return h.getCache().GetAllBindings() +} + +// SetBindingCacheCapacity reset the capacity for the bindingCache. +// It will not affect already cached Bindings. +func (h *globalBindingHandle) SetBindingCacheCapacity(capacity int64) { + h.getCache().SetMemCapacity(capacity) +} + +// GetMemUsage returns the memory usage for the bind cache. +func (h *globalBindingHandle) GetMemUsage() (memUsage int64) { + return h.getCache().GetMemUsage() +} + +// GetMemCapacity returns the memory capacity for the bind cache. +func (h *globalBindingHandle) GetMemCapacity() (memCapacity int64) { + return h.getCache().GetMemCapacity() +} + +// newBinding builds Bindings from a tuple in storage. +func newBinding(sctx sessionctx.Context, row chunk.Row) (string, Binding, error) { + status := row.GetString(3) + // For compatibility, the 'Using' status binding will be converted to the 'Enabled' status binding. + if status == Using { + status = Enabled + } + binding := Binding{ + OriginalSQL: row.GetString(0), + Db: strings.ToLower(row.GetString(2)), + BindSQL: row.GetString(1), + Status: status, + CreateTime: row.GetTime(4), + UpdateTime: row.GetTime(5), + Charset: row.GetString(6), + Collation: row.GetString(7), + Source: row.GetString(8), + SQLDigest: row.GetString(9), + PlanDigest: row.GetString(10), + } + sqlDigest := parser.DigestNormalized(binding.OriginalSQL) + err := prepareHints(sctx, &binding) + sctx.GetSessionVars().CurrentDB = binding.Db + return sqlDigest.String(), binding, err +} + +func getHintsForSQL(sctx sessionctx.Context, sql string) (string, error) { + origVals := sctx.GetSessionVars().UsePlanBaselines + sctx.GetSessionVars().UsePlanBaselines = false + + // Usually passing a sprintf to ExecuteInternal is not recommended, but in this case + // it is safe because ExecuteInternal does not permit MultiStatement execution. Thus, + // the statement won't be able to "break out" from EXPLAIN. + rs, err := exec(sctx, fmt.Sprintf("EXPLAIN FORMAT='hint' %s", sql)) + sctx.GetSessionVars().UsePlanBaselines = origVals + if rs != nil { + defer func() { + // Audit log is collected in Close(), set InRestrictedSQL to avoid 'create sql binding' been recorded as 'explain'. + origin := sctx.GetSessionVars().InRestrictedSQL + sctx.GetSessionVars().InRestrictedSQL = true + terror.Call(rs.Close) + sctx.GetSessionVars().InRestrictedSQL = origin + }() + } + if err != nil { + return "", err + } + chk := rs.NewChunk(nil) + err = rs.Next(context.TODO(), chk) + if err != nil { + return "", err + } + return chk.GetRow(0).GetString(0), nil +} + +// GenerateBindingSQL generates binding sqls from stmt node and plan hints. +func GenerateBindingSQL(stmtNode ast.StmtNode, planHint string, skipCheckIfHasParam bool, defaultDB string) string { + // If would be nil for very simple cases such as point get, we do not need to evolve for them. 
+	if planHint == "" {
+		return ""
+	}
+	if !skipCheckIfHasParam {
+		paramChecker := &paramMarkerChecker{}
+		stmtNode.Accept(paramChecker)
+		// We need to evolve on current sql, but we cannot restore values for paramMarkers yet,
+		// so just ignore them now.
+		if paramChecker.hasParamMarker {
+			return ""
+		}
+	}
+	// We need to evolve plan based on the current sql, not the original sql which may have different parameters.
+	// So here we would remove the hint and inject the current best plan hint.
+	hint.BindHint(stmtNode, &hint.HintsSet{})
+	bindSQL := utilparser.RestoreWithDefaultDB(stmtNode, defaultDB, "")
+	if bindSQL == "" {
+		return ""
+	}
+	switch n := stmtNode.(type) {
+	case *ast.DeleteStmt:
+		deleteIdx := strings.Index(bindSQL, "DELETE")
+		// Remove possible `explain` prefix.
+		bindSQL = bindSQL[deleteIdx:]
+		return strings.Replace(bindSQL, "DELETE", fmt.Sprintf("DELETE /*+ %s*/", planHint), 1)
+	case *ast.UpdateStmt:
+		updateIdx := strings.Index(bindSQL, "UPDATE")
+		// Remove possible `explain` prefix.
+		bindSQL = bindSQL[updateIdx:]
+		return strings.Replace(bindSQL, "UPDATE", fmt.Sprintf("UPDATE /*+ %s*/", planHint), 1)
+	case *ast.SelectStmt:
+		var selectIdx int
+		if n.With != nil {
+			var withSb strings.Builder
+			withIdx := strings.Index(bindSQL, "WITH")
+			restoreCtx := format.NewRestoreCtx(format.RestoreStringSingleQuotes|format.RestoreSpacesAroundBinaryOperation|format.RestoreStringWithoutCharset|format.RestoreNameBackQuotes, &withSb)
+			restoreCtx.DefaultDB = defaultDB
+			if err := n.With.Restore(restoreCtx); err != nil {
+				logutil.BindLogger().Debug("restore SQL failed", zap.Error(err))
+				return ""
+			}
+			withEnd := withIdx + len(withSb.String())
+			tmp := strings.Replace(bindSQL[withEnd:], "SELECT", fmt.Sprintf("SELECT /*+ %s*/", planHint), 1)
+			return strings.Join([]string{bindSQL[withIdx:withEnd], tmp}, "")
+		}
+		selectIdx = strings.Index(bindSQL, "SELECT")
+		// Remove possible `explain` prefix.
+		bindSQL = bindSQL[selectIdx:]
+		return strings.Replace(bindSQL, "SELECT", fmt.Sprintf("SELECT /*+ %s*/", planHint), 1)
+	case *ast.InsertStmt:
+		insertIdx := int(0)
+		if n.IsReplace {
+			insertIdx = strings.Index(bindSQL, "REPLACE")
+		} else {
+			insertIdx = strings.Index(bindSQL, "INSERT")
+		}
+		// Remove possible `explain` prefix.
+		bindSQL = bindSQL[insertIdx:]
+		return strings.Replace(bindSQL, "SELECT", fmt.Sprintf("SELECT /*+ %s*/", planHint), 1)
+	}
+	logutil.BindLogger().Debug("unexpected statement type when generating bind SQL", zap.Any("statement", stmtNode))
+	return ""
+}
+
+type paramMarkerChecker struct {
+	hasParamMarker bool
+}
+
+func (e *paramMarkerChecker) Enter(in ast.Node) (ast.Node, bool) {
+	if _, ok := in.(*driver.ParamMarkerExpr); ok {
+		e.hasParamMarker = true
+		return in, true
+	}
+	return in, false
+}
+
+func (*paramMarkerChecker) Leave(in ast.Node) (ast.Node, bool) {
+	return in, true
+}
+
+// Clear resets the bind handle. It is only used for test.
+func (h *globalBindingHandle) Clear() {
+	h.setCache(newFuzzyBindingCache(h.LoadBindingsFromStorage))
+	h.setLastUpdateTime(types.ZeroTimestamp)
+	h.invalidBindings.reset()
+}
+
+// FlushGlobalBindings flushes the Bindings in temp maps to storage and loads them into cache.
+func (h *globalBindingHandle) FlushGlobalBindings() error { + h.DropInvalidGlobalBinding() + return h.LoadFromStorageToCache(false) +} + +func (h *globalBindingHandle) callWithSCtx(wrapTxn bool, f func(sctx sessionctx.Context) error) (err error) { + resource, err := h.sPool.Get() + if err != nil { + return err + } + defer func() { + if err == nil { // only recycle when no error + h.sPool.Put(resource) + } + }() + sctx := resource.(sessionctx.Context) + if wrapTxn { + if _, err = exec(sctx, "BEGIN PESSIMISTIC"); err != nil { + return + } + defer func() { + if err == nil { + _, err = exec(sctx, "COMMIT") + } else { + _, err1 := exec(sctx, "ROLLBACK") + terror.Log(errors.Trace(err1)) + } + }() + } + + err = f(sctx) + return +} + +var ( + lastPlanBindingUpdateTime = "last_plan_binding_update_time" +) + +// GetScope gets the status variables scope. +func (*globalBindingHandle) GetScope(_ string) variable.ScopeFlag { + return variable.ScopeSession +} + +// Stats returns the server statistics. +func (h *globalBindingHandle) Stats(_ *variable.SessionVars) (map[string]any, error) { + m := make(map[string]any) + m[lastPlanBindingUpdateTime] = h.getLastUpdateTime().String() + return m, nil +} + +// LoadBindingsFromStorageToCache loads global bindings from storage to the memory cache. +func (h *globalBindingHandle) LoadBindingsFromStorage(sctx sessionctx.Context, sqlDigest string) (Bindings, error) { + if sqlDigest == "" { + return nil, nil + } + timeout := time.Duration(sctx.GetSessionVars().LoadBindingTimeout) * time.Millisecond + resultChan := h.syncBindingSingleflight.DoChan(sqlDigest, func() (any, error) { + return h.loadBindingsFromStorageInternal(sqlDigest) + }) + select { + case result := <-resultChan: + if result.Err != nil { + return nil, result.Err + } + bindings := result.Val + if bindings == nil { + return nil, nil + } + return bindings.(Bindings), nil + case <-time.After(timeout): + return nil, errors.New("load bindings from storage timeout") + } +} + +func (h *globalBindingHandle) loadBindingsFromStorageInternal(sqlDigest string) (any, error) { + failpoint.Inject("load_bindings_from_storage_internal_timeout", func() { + time.Sleep(time.Second) + }) + var bindings Bindings + selectStmt := fmt.Sprintf("SELECT original_sql, bind_sql, default_db, status, create_time, update_time, charset, collation, source, sql_digest, plan_digest FROM mysql.bind_info where sql_digest = '%s'", sqlDigest) + err := h.callWithSCtx(false, func(sctx sessionctx.Context) error { + rows, _, err := execRows(sctx, selectStmt) + if err != nil { + return err + } + bindings = make([]Binding, 0, len(rows)) + for _, row := range rows { + // Skip the builtin record which is designed for binding synchronization. 
+ if row.GetString(0) == BuiltinPseudoSQL4BindLock { + continue + } + _, binding, err := newBinding(sctx, row) + if err != nil { + logutil.BindLogger().Warn("failed to generate bind record from data row", zap.Error(err)) + continue + } + bindings = append(bindings, binding) + } + return nil + }) + return bindings, err +} diff --git a/pkg/ddl/add_column.go b/pkg/ddl/add_column.go index 54f519b7731af..3f95a7de11901 100644 --- a/pkg/ddl/add_column.go +++ b/pkg/ddl/add_column.go @@ -60,12 +60,12 @@ func onAddColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) return ver, nil } - failpoint.Inject("errorBeforeDecodeArgs", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("errorBeforeDecodeArgs")); _err_ == nil { //nolint:forcetypeassert if val.(bool) { - failpoint.Return(ver, errors.New("occur an error before decode args")) + return ver, errors.New("occur an error before decode args") } - }) + } tblInfo, columnInfo, colFromArgs, pos, ifNotExists, err := checkAddColumn(t, job) if err != nil { @@ -117,7 +117,7 @@ func onAddColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) case model.StateWriteReorganization: // reorganization -> public // Adjust table column offset. - failpoint.InjectCall("onAddColumnStateWriteReorg") + failpoint.Call(_curpkg_("onAddColumnStateWriteReorg")) offset, err := LocateOffsetToMove(columnInfo.Offset, pos, tblInfo) if err != nil { return ver, errors.Trace(err) diff --git a/pkg/ddl/add_column.go__failpoint_stash__ b/pkg/ddl/add_column.go__failpoint_stash__ new file mode 100644 index 0000000000000..54f519b7731af --- /dev/null +++ b/pkg/ddl/add_column.go__failpoint_stash__ @@ -0,0 +1,1288 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ddl + +import ( + "fmt" + "strconv" + "strings" + "time" + "unicode/utf8" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl/logutil" + "github.com/pingcap/tidb/pkg/errctx" + "github.com/pingcap/tidb/pkg/expression" + exprctx "github.com/pingcap/tidb/pkg/expression/context" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/format" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + field_types "github.com/pingcap/tidb/pkg/parser/types" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + statsutil "github.com/pingcap/tidb/pkg/statistics/handle/util" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/types" + driver "github.com/pingcap/tidb/pkg/types/parser_driver" + "github.com/pingcap/tidb/pkg/util/collate" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/hack" + "go.uber.org/zap" +) + +func onAddColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + // Handle the rolling back job. + if job.IsRollingback() { + ver, err = onDropColumn(d, t, job) + if err != nil { + return ver, errors.Trace(err) + } + return ver, nil + } + + failpoint.Inject("errorBeforeDecodeArgs", func(val failpoint.Value) { + //nolint:forcetypeassert + if val.(bool) { + failpoint.Return(ver, errors.New("occur an error before decode args")) + } + }) + + tblInfo, columnInfo, colFromArgs, pos, ifNotExists, err := checkAddColumn(t, job) + if err != nil { + if ifNotExists && infoschema.ErrColumnExists.Equal(err) { + job.Warning = toTError(err) + job.State = model.JobStateDone + return ver, nil + } + return ver, errors.Trace(err) + } + if columnInfo == nil { + columnInfo = InitAndAddColumnToTable(tblInfo, colFromArgs) + logutil.DDLLogger().Info("run add column job", zap.Stringer("job", job), zap.Reflect("columnInfo", *columnInfo)) + if err = checkAddColumnTooManyColumns(len(tblInfo.Columns)); err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + } + + originalState := columnInfo.State + switch columnInfo.State { + case model.StateNone: + // none -> delete only + columnInfo.State = model.StateDeleteOnly + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, originalState != columnInfo.State) + if err != nil { + return ver, errors.Trace(err) + } + job.SchemaState = model.StateDeleteOnly + case model.StateDeleteOnly: + // delete only -> write only + columnInfo.State = model.StateWriteOnly + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != columnInfo.State) + if err != nil { + return ver, errors.Trace(err) + } + // Update the job state when all affairs done. + job.SchemaState = model.StateWriteOnly + case model.StateWriteOnly: + // write only -> reorganization + columnInfo.State = model.StateWriteReorganization + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != columnInfo.State) + if err != nil { + return ver, errors.Trace(err) + } + // Update the job state when all affairs done. 
+ job.SchemaState = model.StateWriteReorganization + job.MarkNonRevertible() + case model.StateWriteReorganization: + // reorganization -> public + // Adjust table column offset. + failpoint.InjectCall("onAddColumnStateWriteReorg") + offset, err := LocateOffsetToMove(columnInfo.Offset, pos, tblInfo) + if err != nil { + return ver, errors.Trace(err) + } + tblInfo.MoveColumnInfo(columnInfo.Offset, offset) + columnInfo.State = model.StatePublic + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != columnInfo.State) + if err != nil { + return ver, errors.Trace(err) + } + + // Finish this job. + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + addColumnEvent := statsutil.NewAddColumnEvent( + job.SchemaID, + tblInfo, + []*model.ColumnInfo{columnInfo}, + ) + asyncNotifyEvent(d, addColumnEvent) + default: + err = dbterror.ErrInvalidDDLState.GenWithStackByArgs("column", columnInfo.State) + } + + return ver, errors.Trace(err) +} + +func checkAndCreateNewColumn(ctx sessionctx.Context, ti ast.Ident, schema *model.DBInfo, spec *ast.AlterTableSpec, t table.Table, specNewColumn *ast.ColumnDef) (*table.Column, error) { + err := checkUnsupportedColumnConstraint(specNewColumn, ti) + if err != nil { + return nil, errors.Trace(err) + } + + colName := specNewColumn.Name.Name.O + // Check whether added column has existed. + col := table.FindCol(t.Cols(), colName) + if col != nil { + err = infoschema.ErrColumnExists.GenWithStackByArgs(colName) + if spec.IfNotExists { + ctx.GetSessionVars().StmtCtx.AppendNote(err) + return nil, nil + } + return nil, err + } + if err = checkColumnAttributes(colName, specNewColumn.Tp); err != nil { + return nil, errors.Trace(err) + } + if utf8.RuneCountInString(colName) > mysql.MaxColumnNameLength { + return nil, dbterror.ErrTooLongIdent.GenWithStackByArgs(colName) + } + + return CreateNewColumn(ctx, schema, spec, t, specNewColumn) +} + +func checkUnsupportedColumnConstraint(col *ast.ColumnDef, ti ast.Ident) error { + for _, constraint := range col.Options { + switch constraint.Tp { + case ast.ColumnOptionAutoIncrement: + return dbterror.ErrUnsupportedAddColumn.GenWithStack("unsupported add column '%s' constraint AUTO_INCREMENT when altering '%s.%s'", col.Name, ti.Schema, ti.Name) + case ast.ColumnOptionPrimaryKey: + return dbterror.ErrUnsupportedAddColumn.GenWithStack("unsupported add column '%s' constraint PRIMARY KEY when altering '%s.%s'", col.Name, ti.Schema, ti.Name) + case ast.ColumnOptionUniqKey: + return dbterror.ErrUnsupportedAddColumn.GenWithStack("unsupported add column '%s' constraint UNIQUE KEY when altering '%s.%s'", col.Name, ti.Schema, ti.Name) + case ast.ColumnOptionAutoRandom: + errMsg := fmt.Sprintf(autoid.AutoRandomAlterAddColumn, col.Name, ti.Schema, ti.Name) + return dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(errMsg) + } + } + + return nil +} + +// CreateNewColumn creates a new column according to the column information. +func CreateNewColumn(ctx sessionctx.Context, schema *model.DBInfo, spec *ast.AlterTableSpec, t table.Table, specNewColumn *ast.ColumnDef) (*table.Column, error) { + // If new column is a generated column, do validation. + // NOTE: we do check whether the column refers other generated + // columns occurring later in a table, but we don't handle the col offset. 
+ for _, option := range specNewColumn.Options { + if option.Tp == ast.ColumnOptionGenerated { + if err := checkIllegalFn4Generated(specNewColumn.Name.Name.L, typeColumn, option.Expr); err != nil { + return nil, errors.Trace(err) + } + + if option.Stored { + return nil, dbterror.ErrUnsupportedOnGeneratedColumn.GenWithStackByArgs("Adding generated stored column through ALTER TABLE") + } + + _, dependColNames, err := findDependedColumnNames(schema.Name, t.Meta().Name, specNewColumn) + if err != nil { + return nil, errors.Trace(err) + } + if !ctx.GetSessionVars().EnableAutoIncrementInGenerated { + if err := checkAutoIncrementRef(specNewColumn.Name.Name.L, dependColNames, t.Meta()); err != nil { + return nil, errors.Trace(err) + } + } + duplicateColNames := make(map[string]struct{}, len(dependColNames)) + for k := range dependColNames { + duplicateColNames[k] = struct{}{} + } + cols := t.Cols() + + if err := checkDependedColExist(dependColNames, cols); err != nil { + return nil, errors.Trace(err) + } + + if err := verifyColumnGenerationSingle(duplicateColNames, cols, spec.Position); err != nil { + return nil, errors.Trace(err) + } + } + // Specially, since sequence has been supported, if a newly added column has a + // sequence nextval function as it's default value option, it won't fill the + // known rows with specific sequence next value under current add column logic. + // More explanation can refer: TestSequenceDefaultLogic's comment in sequence_test.go + if option.Tp == ast.ColumnOptionDefaultValue { + if f, ok := option.Expr.(*ast.FuncCallExpr); ok { + switch f.FnName.L { + case ast.NextVal: + if _, err := getSequenceDefaultValue(option); err != nil { + return nil, errors.Trace(err) + } + return nil, errors.Trace(dbterror.ErrAddColumnWithSequenceAsDefault.GenWithStackByArgs(specNewColumn.Name.Name.O)) + case ast.Rand, ast.UUID, ast.UUIDToBin, ast.Replace, ast.Upper: + return nil, errors.Trace(dbterror.ErrBinlogUnsafeSystemFunction.GenWithStackByArgs()) + } + } + } + } + + tableCharset, tableCollate, err := ResolveCharsetCollation(ctx.GetSessionVars(), + ast.CharsetOpt{Chs: t.Meta().Charset, Col: t.Meta().Collate}, + ast.CharsetOpt{Chs: schema.Charset, Col: schema.Collate}, + ) + if err != nil { + return nil, errors.Trace(err) + } + // Ignore table constraints now, they will be checked later. + // We use length(t.Cols()) as the default offset firstly, we will change the column's offset later. + col, _, err := buildColumnAndConstraint( + ctx, + len(t.Cols()), + specNewColumn, + nil, + tableCharset, + tableCollate, + ) + if err != nil { + return nil, errors.Trace(err) + } + + originDefVal, err := generateOriginDefaultValue(col.ToInfo(), ctx) + if err != nil { + return nil, errors.Trace(err) + } + + err = col.SetOriginDefaultValue(originDefVal) + return col, err +} + +// buildColumnAndConstraint builds table.Column and ast.Constraint from the parameters. +// outPriKeyConstraint is the primary key constraint out of column definition. For example: +// `create table t1 (id int , age int, primary key(id));` +func buildColumnAndConstraint( + ctx sessionctx.Context, + offset int, + colDef *ast.ColumnDef, + outPriKeyConstraint *ast.Constraint, + tblCharset string, + tblCollate string, +) (*table.Column, []*ast.Constraint, error) { + if colName := colDef.Name.Name.L; colName == model.ExtraHandleName.L { + return nil, nil, dbterror.ErrWrongColumnName.GenWithStackByArgs(colName) + } + + // specifiedCollate refers to the last collate specified in colDef.Options. 
+ chs, coll, err := getCharsetAndCollateInColumnDef(ctx.GetSessionVars(), colDef) + if err != nil { + return nil, nil, errors.Trace(err) + } + chs, coll, err = ResolveCharsetCollation(ctx.GetSessionVars(), + ast.CharsetOpt{Chs: chs, Col: coll}, + ast.CharsetOpt{Chs: tblCharset, Col: tblCollate}, + ) + chs, coll = OverwriteCollationWithBinaryFlag(ctx.GetSessionVars(), colDef, chs, coll) + if err != nil { + return nil, nil, errors.Trace(err) + } + + if err := setCharsetCollationFlenDecimal(colDef.Tp, colDef.Name.Name.O, chs, coll, ctx.GetSessionVars()); err != nil { + return nil, nil, errors.Trace(err) + } + decodeEnumSetBinaryLiteralToUTF8(colDef.Tp, chs) + col, cts, err := columnDefToCol(ctx, offset, colDef, outPriKeyConstraint) + if err != nil { + return nil, nil, errors.Trace(err) + } + return col, cts, nil +} + +// getCharsetAndCollateInColumnDef will iterate collate in the options, validate it by checking the charset +// of column definition. If there's no collate in the option, the default collate of column's charset will be used. +func getCharsetAndCollateInColumnDef(sessVars *variable.SessionVars, def *ast.ColumnDef) (chs, coll string, err error) { + chs = def.Tp.GetCharset() + coll = def.Tp.GetCollate() + if chs != "" && coll == "" { + if coll, err = GetDefaultCollation(sessVars, chs); err != nil { + return "", "", errors.Trace(err) + } + } + for _, opt := range def.Options { + if opt.Tp == ast.ColumnOptionCollate { + info, err := collate.GetCollationByName(opt.StrValue) + if err != nil { + return "", "", errors.Trace(err) + } + if chs == "" { + chs = info.CharsetName + } else if chs != info.CharsetName { + return "", "", dbterror.ErrCollationCharsetMismatch.GenWithStackByArgs(info.Name, chs) + } + coll = info.Name + } + } + return +} + +// OverwriteCollationWithBinaryFlag is used to handle the case like +// +// CREATE TABLE t (a VARCHAR(255) BINARY) CHARSET utf8 COLLATE utf8_general_ci; +// +// The 'BINARY' sets the column collation to *_bin according to the table charset. +func OverwriteCollationWithBinaryFlag(sessVars *variable.SessionVars, colDef *ast.ColumnDef, chs, coll string) (newChs string, newColl string) { + ignoreBinFlag := colDef.Tp.GetCharset() != "" && (colDef.Tp.GetCollate() != "" || containsColumnOption(colDef, ast.ColumnOptionCollate)) + if ignoreBinFlag { + return chs, coll + } + needOverwriteBinColl := types.IsString(colDef.Tp.GetType()) && mysql.HasBinaryFlag(colDef.Tp.GetFlag()) + if needOverwriteBinColl { + newColl, err := GetDefaultCollation(sessVars, chs) + if err != nil { + return chs, coll + } + return chs, newColl + } + return chs, coll +} + +func setCharsetCollationFlenDecimal(tp *types.FieldType, colName, colCharset, colCollate string, sessVars *variable.SessionVars) error { + var err error + if typesNeedCharset(tp.GetType()) { + tp.SetCharset(colCharset) + tp.SetCollate(colCollate) + } else { + tp.SetCharset(charset.CharsetBin) + tp.SetCollate(charset.CharsetBin) + } + + // Use default value for flen or decimal when they are unspecified. + defaultFlen, defaultDecimal := mysql.GetDefaultFieldLengthAndDecimal(tp.GetType()) + if tp.GetDecimal() == types.UnspecifiedLength { + tp.SetDecimal(defaultDecimal) + } + if tp.GetFlen() == types.UnspecifiedLength { + tp.SetFlen(defaultFlen) + if mysql.HasUnsignedFlag(tp.GetFlag()) && tp.GetType() != mysql.TypeLonglong && mysql.IsIntegerType(tp.GetType()) { + // Issue #4684: the flen of unsigned integer(except bigint) is 1 digit shorter than signed integer + // because it has no prefix "+" or "-" character. 
+ tp.SetFlen(tp.GetFlen() - 1) + } + } else { + // Adjust the field type for blob/text types if the flen is set. + if err = adjustBlobTypesFlen(tp, colCharset); err != nil { + return err + } + } + return checkTooBigFieldLengthAndTryAutoConvert(tp, colName, sessVars) +} + +func decodeEnumSetBinaryLiteralToUTF8(tp *types.FieldType, chs string) { + if tp.GetType() != mysql.TypeEnum && tp.GetType() != mysql.TypeSet { + return + } + enc := charset.FindEncoding(chs) + for i, elem := range tp.GetElems() { + if !tp.GetElemIsBinaryLit(i) { + continue + } + s, err := enc.Transform(nil, hack.Slice(elem), charset.OpDecodeReplace) + if err != nil { + logutil.DDLLogger().Warn("decode enum binary literal to utf-8 failed", zap.Error(err)) + } + tp.SetElem(i, string(hack.String(s))) + } + tp.CleanElemIsBinaryLit() +} + +func typesNeedCharset(tp byte) bool { + switch tp { + case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, + mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, + mysql.TypeEnum, mysql.TypeSet: + return true + } + return false +} + +// checkTooBigFieldLengthAndTryAutoConvert will check whether the field length is too big +// in non-strict mode and varchar column. If it is, will try to adjust to blob or text, see issue #30328 +func checkTooBigFieldLengthAndTryAutoConvert(tp *types.FieldType, colName string, sessVars *variable.SessionVars) error { + if sessVars != nil && !sessVars.SQLMode.HasStrictMode() && tp.GetType() == mysql.TypeVarchar { + err := types.IsVarcharTooBigFieldLength(tp.GetFlen(), colName, tp.GetCharset()) + if err != nil && terror.ErrorEqual(types.ErrTooBigFieldLength, err) { + tp.SetType(mysql.TypeBlob) + if err = adjustBlobTypesFlen(tp, tp.GetCharset()); err != nil { + return err + } + if tp.GetCharset() == charset.CharsetBin { + sessVars.StmtCtx.AppendWarning(dbterror.ErrAutoConvert.FastGenByArgs(colName, "VARBINARY", "BLOB")) + } else { + sessVars.StmtCtx.AppendWarning(dbterror.ErrAutoConvert.FastGenByArgs(colName, "VARCHAR", "TEXT")) + } + } + } + return nil +} + +// columnDefToCol converts ColumnDef to Col and TableConstraints. +// outPriKeyConstraint is the primary key constraint out of column definition. such as: create table t1 (id int , age int, primary key(id)); +func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, outPriKeyConstraint *ast.Constraint) (*table.Column, []*ast.Constraint, error) { + var constraints = make([]*ast.Constraint, 0) + col := table.ToColumn(&model.ColumnInfo{ + Offset: offset, + Name: colDef.Name.Name, + FieldType: *colDef.Tp, + // TODO: remove this version field after there is no old version. + Version: model.CurrLatestColumnInfoVersion, + }) + + if !isExplicitTimeStamp() { + // Check and set TimestampFlag, OnUpdateNowFlag and NotNullFlag. 
+ if col.GetType() == mysql.TypeTimestamp { + col.AddFlag(mysql.TimestampFlag | mysql.OnUpdateNowFlag | mysql.NotNullFlag) + } + } + var err error + setOnUpdateNow := false + hasDefaultValue := false + hasNullFlag := false + if colDef.Options != nil { + length := types.UnspecifiedLength + + keys := []*ast.IndexPartSpecification{ + { + Column: colDef.Name, + Length: length, + }, + } + + var sb strings.Builder + restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | + format.RestoreSpacesAroundBinaryOperation | format.RestoreWithoutSchemaName | format.RestoreWithoutTableName + restoreCtx := format.NewRestoreCtx(restoreFlags, &sb) + + for _, v := range colDef.Options { + switch v.Tp { + case ast.ColumnOptionNotNull: + col.AddFlag(mysql.NotNullFlag) + case ast.ColumnOptionNull: + col.DelFlag(mysql.NotNullFlag) + removeOnUpdateNowFlag(col) + hasNullFlag = true + case ast.ColumnOptionAutoIncrement: + col.AddFlag(mysql.AutoIncrementFlag | mysql.NotNullFlag) + case ast.ColumnOptionPrimaryKey: + // Check PriKeyFlag first to avoid extra duplicate constraints. + if col.GetFlag()&mysql.PriKeyFlag == 0 { + constraint := &ast.Constraint{Tp: ast.ConstraintPrimaryKey, Keys: keys, + Option: &ast.IndexOption{PrimaryKeyTp: v.PrimaryKeyTp}} + constraints = append(constraints, constraint) + col.AddFlag(mysql.PriKeyFlag) + // Add NotNullFlag early so that processColumnFlags() can see it. + col.AddFlag(mysql.NotNullFlag) + } + case ast.ColumnOptionUniqKey: + // Check UniqueFlag first to avoid extra duplicate constraints. + if col.GetFlag()&mysql.UniqueFlag == 0 { + constraint := &ast.Constraint{Tp: ast.ConstraintUniqKey, Keys: keys} + constraints = append(constraints, constraint) + col.AddFlag(mysql.UniqueKeyFlag) + } + case ast.ColumnOptionDefaultValue: + hasDefaultValue, err = SetDefaultValue(ctx, col, v) + if err != nil { + return nil, nil, errors.Trace(err) + } + removeOnUpdateNowFlag(col) + case ast.ColumnOptionOnUpdate: + // TODO: Support other time functions. + if !(col.GetType() == mysql.TypeTimestamp || col.GetType() == mysql.TypeDatetime) { + return nil, nil, dbterror.ErrInvalidOnUpdate.GenWithStackByArgs(col.Name) + } + if !expression.IsValidCurrentTimestampExpr(v.Expr, colDef.Tp) { + return nil, nil, dbterror.ErrInvalidOnUpdate.GenWithStackByArgs(col.Name) + } + col.AddFlag(mysql.OnUpdateNowFlag) + setOnUpdateNow = true + case ast.ColumnOptionComment: + err := setColumnComment(ctx, col, v) + if err != nil { + return nil, nil, errors.Trace(err) + } + case ast.ColumnOptionGenerated: + sb.Reset() + err = v.Expr.Restore(restoreCtx) + if err != nil { + return nil, nil, errors.Trace(err) + } + col.GeneratedExprString = sb.String() + col.GeneratedStored = v.Stored + _, dependColNames, err := findDependedColumnNames(model.NewCIStr(""), model.NewCIStr(""), colDef) + if err != nil { + return nil, nil, errors.Trace(err) + } + col.Dependences = dependColNames + case ast.ColumnOptionCollate: + if field_types.HasCharset(colDef.Tp) { + col.FieldType.SetCollate(v.StrValue) + } + case ast.ColumnOptionFulltext: + ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrTableCantHandleFt.FastGenByArgs()) + case ast.ColumnOptionCheck: + if !variable.EnableCheckConstraint.Load() { + ctx.GetSessionVars().StmtCtx.AppendWarning(errCheckConstraintIsOff) + } else { + // Check the column CHECK constraint dependency lazily, after fill all the name. + // Extract column constraint from column option. 
+ constraint := &ast.Constraint{ + Tp: ast.ConstraintCheck, + Expr: v.Expr, + Enforced: v.Enforced, + Name: v.ConstraintName, + InColumn: true, + InColumnName: colDef.Name.Name.O, + } + constraints = append(constraints, constraint) + } + } + } + } + + if err = processAndCheckDefaultValueAndColumn(ctx, col, outPriKeyConstraint, hasDefaultValue, setOnUpdateNow, hasNullFlag); err != nil { + return nil, nil, errors.Trace(err) + } + return col, constraints, nil +} + +// isExplicitTimeStamp is used to check if explicit_defaults_for_timestamp is on or off. +// Check out this link for more details. +// https://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_explicit_defaults_for_timestamp +func isExplicitTimeStamp() bool { + // TODO: implement the behavior as MySQL when explicit_defaults_for_timestamp = off, then this function could return false. + return true +} + +// SetDefaultValue sets the default value of the column. +func SetDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) (hasDefaultValue bool, err error) { + var value any + var isSeqExpr bool + value, isSeqExpr, err = getDefaultValue( + exprctx.CtxWithHandleTruncateErrLevel(ctx.GetExprCtx(), errctx.LevelError), + col, option, + ) + if err != nil { + return false, errors.Trace(err) + } + if isSeqExpr { + if err := checkSequenceDefaultValue(col); err != nil { + return false, errors.Trace(err) + } + col.DefaultIsExpr = isSeqExpr + } + + // When the default value is expression, we skip check and convert. + if !col.DefaultIsExpr { + if hasDefaultValue, value, err = checkColumnDefaultValue(ctx.GetExprCtx(), col, value); err != nil { + return hasDefaultValue, errors.Trace(err) + } + value, err = convertTimestampDefaultValToUTC(ctx, value, col) + if err != nil { + return hasDefaultValue, errors.Trace(err) + } + } else { + hasDefaultValue = true + } + err = setDefaultValueWithBinaryPadding(col, value) + if err != nil { + return hasDefaultValue, errors.Trace(err) + } + return hasDefaultValue, nil +} + +// getFuncCallDefaultValue gets the default column value of function-call expression. +func getFuncCallDefaultValue(col *table.Column, option *ast.ColumnOption, expr *ast.FuncCallExpr) (any, bool, error) { + switch expr.FnName.L { + case ast.CurrentTimestamp, ast.CurrentDate: // CURRENT_TIMESTAMP() and CURRENT_DATE() + tp, fsp := col.FieldType.GetType(), col.FieldType.GetDecimal() + if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime { + defaultFsp := 0 + if len(expr.Args) == 1 { + if val := expr.Args[0].(*driver.ValueExpr); val != nil { + defaultFsp = int(val.GetInt64()) + } + } + if defaultFsp != fsp { + return nil, false, dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) + } + } + return nil, false, nil + case ast.NextVal: + // handle default next value of sequence. 
(keep the expr string) + str, err := getSequenceDefaultValue(option) + if err != nil { + return nil, false, errors.Trace(err) + } + return str, true, nil + case ast.Rand, ast.UUID, ast.UUIDToBin: // RAND(), UUID() and UUID_TO_BIN() + if err := expression.VerifyArgsWrapper(expr.FnName.L, len(expr.Args)); err != nil { + return nil, false, errors.Trace(err) + } + str, err := restoreFuncCall(expr) + if err != nil { + return nil, false, errors.Trace(err) + } + col.DefaultIsExpr = true + return str, false, nil + case ast.DateFormat: // DATE_FORMAT() + if err := expression.VerifyArgsWrapper(expr.FnName.L, len(expr.Args)); err != nil { + return nil, false, errors.Trace(err) + } + // Support DATE_FORMAT(NOW(),'%Y-%m'), DATE_FORMAT(NOW(),'%Y-%m-%d'), + // DATE_FORMAT(NOW(),'%Y-%m-%d %H.%i.%s'), DATE_FORMAT(NOW(),'%Y-%m-%d %H:%i:%s'). + nowFunc, ok := expr.Args[0].(*ast.FuncCallExpr) + if ok && nowFunc.FnName.L == ast.Now { + if err := expression.VerifyArgsWrapper(nowFunc.FnName.L, len(nowFunc.Args)); err != nil { + return nil, false, errors.Trace(err) + } + valExpr, isValue := expr.Args[1].(ast.ValueExpr) + if !isValue || (valExpr.GetString() != "%Y-%m" && valExpr.GetString() != "%Y-%m-%d" && + valExpr.GetString() != "%Y-%m-%d %H.%i.%s" && valExpr.GetString() != "%Y-%m-%d %H:%i:%s") { + return nil, false, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String(), valExpr) + } + str, err := restoreFuncCall(expr) + if err != nil { + return nil, false, errors.Trace(err) + } + col.DefaultIsExpr = true + return str, false, nil + } + return nil, false, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String(), + fmt.Sprintf("%s with disallowed args", expr.FnName.String())) + case ast.Replace: + if err := expression.VerifyArgsWrapper(expr.FnName.L, len(expr.Args)); err != nil { + return nil, false, errors.Trace(err) + } + funcCall := expr.Args[0] + // Support REPLACE(CONVERT(UPPER(UUID()) USING UTF8MB4), '-', '')) + if convertFunc, ok := funcCall.(*ast.FuncCallExpr); ok && convertFunc.FnName.L == ast.Convert { + if err := expression.VerifyArgsWrapper(convertFunc.FnName.L, len(convertFunc.Args)); err != nil { + return nil, false, errors.Trace(err) + } + funcCall = convertFunc.Args[0] + } + // Support REPLACE(UPPER(UUID()), '-', ''). + if upperFunc, ok := funcCall.(*ast.FuncCallExpr); ok && upperFunc.FnName.L == ast.Upper { + if err := expression.VerifyArgsWrapper(upperFunc.FnName.L, len(upperFunc.Args)); err != nil { + return nil, false, errors.Trace(err) + } + if uuidFunc, ok := upperFunc.Args[0].(*ast.FuncCallExpr); ok && uuidFunc.FnName.L == ast.UUID { + if err := expression.VerifyArgsWrapper(uuidFunc.FnName.L, len(uuidFunc.Args)); err != nil { + return nil, false, errors.Trace(err) + } + str, err := restoreFuncCall(expr) + if err != nil { + return nil, false, errors.Trace(err) + } + col.DefaultIsExpr = true + return str, false, nil + } + } + return nil, false, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String(), + fmt.Sprintf("%s with disallowed args", expr.FnName.String())) + case ast.Upper: + if err := expression.VerifyArgsWrapper(expr.FnName.L, len(expr.Args)); err != nil { + return nil, false, errors.Trace(err) + } + // Support UPPER(SUBSTRING_INDEX(USER(), '@', 1)). 
+ if substringIndexFunc, ok := expr.Args[0].(*ast.FuncCallExpr); ok && substringIndexFunc.FnName.L == ast.SubstringIndex { + if err := expression.VerifyArgsWrapper(substringIndexFunc.FnName.L, len(substringIndexFunc.Args)); err != nil { + return nil, false, errors.Trace(err) + } + if userFunc, ok := substringIndexFunc.Args[0].(*ast.FuncCallExpr); ok && userFunc.FnName.L == ast.User { + if err := expression.VerifyArgsWrapper(userFunc.FnName.L, len(userFunc.Args)); err != nil { + return nil, false, errors.Trace(err) + } + valExpr, isValue := substringIndexFunc.Args[1].(ast.ValueExpr) + if !isValue || valExpr.GetString() != "@" { + return nil, false, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String(), valExpr) + } + str, err := restoreFuncCall(expr) + if err != nil { + return nil, false, errors.Trace(err) + } + col.DefaultIsExpr = true + return str, false, nil + } + } + return nil, false, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String(), + fmt.Sprintf("%s with disallowed args", expr.FnName.String())) + case ast.StrToDate: // STR_TO_DATE() + if err := expression.VerifyArgsWrapper(expr.FnName.L, len(expr.Args)); err != nil { + return nil, false, errors.Trace(err) + } + // Support STR_TO_DATE('1980-01-01', '%Y-%m-%d'). + if _, ok1 := expr.Args[0].(ast.ValueExpr); ok1 { + if _, ok2 := expr.Args[1].(ast.ValueExpr); ok2 { + str, err := restoreFuncCall(expr) + if err != nil { + return nil, false, errors.Trace(err) + } + col.DefaultIsExpr = true + return str, false, nil + } + } + return nil, false, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String(), + fmt.Sprintf("%s with disallowed args", expr.FnName.String())) + case ast.JSONObject, ast.JSONArray, ast.JSONQuote: // JSON_OBJECT(), JSON_ARRAY(), JSON_QUOTE() + if err := expression.VerifyArgsWrapper(expr.FnName.L, len(expr.Args)); err != nil { + return nil, false, errors.Trace(err) + } + str, err := restoreFuncCall(expr) + if err != nil { + return nil, false, errors.Trace(err) + } + col.DefaultIsExpr = true + return str, false, nil + + default: + return nil, false, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String(), expr.FnName.String()) + } +} + +// getDefaultValue will get the default value for column. +// 1: get the expr restored string for the column which uses sequence next value as default value. +// 2: get specific default value for the other column. +func getDefaultValue(ctx exprctx.BuildContext, col *table.Column, option *ast.ColumnOption) (any, bool, error) { + // handle default value with function call + tp, fsp := col.FieldType.GetType(), col.FieldType.GetDecimal() + if x, ok := option.Expr.(*ast.FuncCallExpr); ok { + val, isSeqExpr, err := getFuncCallDefaultValue(col, option, x) + if val != nil || isSeqExpr || err != nil { + return val, isSeqExpr, err + } + // If the function call is ast.CurrentTimestamp, it needs to be continuously processed. + } + + if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime || tp == mysql.TypeDate { + vd, err := expression.GetTimeValue(ctx, option.Expr, tp, fsp, nil) + value := vd.GetValue() + if err != nil { + return nil, false, dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) + } + + // Value is nil means `default null`. + if value == nil { + return nil, false, nil + } + + // If value is types.Time, convert it to string. 
+ if vv, ok := value.(types.Time); ok { + return vv.String(), false, nil + } + + return value, false, nil + } + + // evaluate the non-function-call expr to a certain value. + v, err := expression.EvalSimpleAst(ctx, option.Expr) + if err != nil { + return nil, false, errors.Trace(err) + } + + if v.IsNull() { + return nil, false, nil + } + + if v.Kind() == types.KindBinaryLiteral || v.Kind() == types.KindMysqlBit { + if types.IsTypeBlob(tp) || tp == mysql.TypeJSON { + // BLOB/TEXT/JSON column cannot have a default value. + // Skip the unnecessary decode procedure. + return v.GetString(), false, err + } + if tp == mysql.TypeBit || tp == mysql.TypeString || tp == mysql.TypeVarchar || + tp == mysql.TypeVarString || tp == mysql.TypeEnum || tp == mysql.TypeSet { + // For BinaryLiteral or bit fields, we decode the default value to utf8 string. + str, err := v.GetBinaryStringDecoded(types.StrictFlags, col.GetCharset()) + if err != nil { + // Overwrite the decoding error with invalid default value error. + err = dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) + } + return str, false, err + } + // For other kind of fields (e.g. INT), we supply its integer as string value. + value, err := v.GetBinaryLiteral().ToInt(ctx.GetEvalCtx().TypeCtx()) + if err != nil { + return nil, false, err + } + return strconv.FormatUint(value, 10), false, nil + } + + switch tp { + case mysql.TypeSet: + val, err := getSetDefaultValue(v, col) + return val, false, err + case mysql.TypeEnum: + val, err := getEnumDefaultValue(v, col) + return val, false, err + case mysql.TypeDuration, mysql.TypeDate: + if v, err = v.ConvertTo(ctx.GetEvalCtx().TypeCtx(), &col.FieldType); err != nil { + return "", false, errors.Trace(err) + } + case mysql.TypeBit: + if v.Kind() == types.KindInt64 || v.Kind() == types.KindUint64 { + // For BIT fields, convert int into BinaryLiteral. + return types.NewBinaryLiteralFromUint(v.GetUint64(), -1).ToString(), false, nil + } + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeFloat, mysql.TypeDouble: + // For these types, convert it to standard format firstly. + // like integer fields, convert it into integer string literals. like convert "1.25" into "1" and "2.8" into "3". + // if raise a error, we will use original expression. We will handle it in check phase + if temp, err := v.ConvertTo(ctx.GetEvalCtx().TypeCtx(), &col.FieldType); err == nil { + v = temp + } + } + + val, err := v.ToString() + return val, false, err +} + +func getSequenceDefaultValue(c *ast.ColumnOption) (expr string, err error) { + var sb strings.Builder + restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | + format.RestoreSpacesAroundBinaryOperation + restoreCtx := format.NewRestoreCtx(restoreFlags, &sb) + if err := c.Expr.Restore(restoreCtx); err != nil { + return "", err + } + return sb.String(), nil +} + +func setDefaultValueWithBinaryPadding(col *table.Column, value any) error { + err := col.SetDefaultValue(value) + if err != nil { + return err + } + // https://dev.mysql.com/doc/refman/8.0/en/binary-varbinary.html + // Set the default value for binary type should append the paddings. 
+ if value != nil { + if col.GetType() == mysql.TypeString && types.IsBinaryStr(&col.FieldType) && len(value.(string)) < col.GetFlen() { + padding := make([]byte, col.GetFlen()-len(value.(string))) + col.DefaultValue = string(append([]byte(col.DefaultValue.(string)), padding...)) + } + } + return nil +} + +func setColumnComment(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) error { + value, err := expression.EvalSimpleAst(ctx.GetExprCtx(), option.Expr) + if err != nil { + return errors.Trace(err) + } + if col.Comment, err = value.ToString(); err != nil { + return errors.Trace(err) + } + + sessionVars := ctx.GetSessionVars() + col.Comment, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, col.Name.L, &col.Comment, dbterror.ErrTooLongFieldComment) + return errors.Trace(err) +} + +func processAndCheckDefaultValueAndColumn(ctx sessionctx.Context, col *table.Column, + outPriKeyConstraint *ast.Constraint, hasDefaultValue, setOnUpdateNow, hasNullFlag bool) error { + processDefaultValue(col, hasDefaultValue, setOnUpdateNow) + processColumnFlags(col) + + err := checkPriKeyConstraint(col, hasDefaultValue, hasNullFlag, outPriKeyConstraint) + if err != nil { + return errors.Trace(err) + } + if err = checkColumnValueConstraint(col, col.GetCollate()); err != nil { + return errors.Trace(err) + } + if err = checkDefaultValue(ctx.GetExprCtx(), col, hasDefaultValue); err != nil { + return errors.Trace(err) + } + if err = checkColumnFieldLength(col); err != nil { + return errors.Trace(err) + } + return nil +} + +func restoreFuncCall(expr *ast.FuncCallExpr) (string, error) { + var sb strings.Builder + restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | + format.RestoreSpacesAroundBinaryOperation + restoreCtx := format.NewRestoreCtx(restoreFlags, &sb) + if err := expr.Restore(restoreCtx); err != nil { + return "", err + } + return sb.String(), nil +} + +// getSetDefaultValue gets the default value for the set type. See https://dev.mysql.com/doc/refman/5.7/en/set.html. +func getSetDefaultValue(v types.Datum, col *table.Column) (string, error) { + if v.Kind() == types.KindInt64 { + setCnt := len(col.GetElems()) + maxLimit := int64(1<<uint(setCnt) - 1) + val := v.GetInt64() + if val < 1 || val > maxLimit { + return "", dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) + } + setVal, err := types.ParseSetValue(col.GetElems(), uint64(val)) + if err != nil { + return "", errors.Trace(err) + } + v.SetMysqlSet(setVal, col.GetCollate()) + return v.ToString() + } + + str, err := v.ToString() + if err != nil { + return "", errors.Trace(err) + } + if str == "" { + return str, nil + } + setVal, err := types.ParseSetName(col.GetElems(), str, col.GetCollate()) + if err != nil { + return "", dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) + } + v.SetMysqlSet(setVal, col.GetCollate()) + + return v.ToString() +} + +// getEnumDefaultValue gets the default value for the enum type. See https://dev.mysql.com/doc/refman/5.7/en/enum.html.
+func getEnumDefaultValue(v types.Datum, col *table.Column) (string, error) { + if v.Kind() == types.KindInt64 { + val := v.GetInt64() + if val < 1 || val > int64(len(col.GetElems())) { + return "", dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) + } + enumVal, err := types.ParseEnumValue(col.GetElems(), uint64(val)) + if err != nil { + return "", errors.Trace(err) + } + v.SetMysqlEnum(enumVal, col.GetCollate()) + return v.ToString() + } + str, err := v.ToString() + if err != nil { + return "", errors.Trace(err) + } + // Ref: https://dev.mysql.com/doc/refman/8.0/en/enum.html + // Trailing spaces are automatically deleted from ENUM member values in the table definition when a table is created. + str = strings.TrimRight(str, " ") + enumVal, err := types.ParseEnumName(col.GetElems(), str, col.GetCollate()) + if err != nil { + return "", dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) + } + v.SetMysqlEnum(enumVal, col.GetCollate()) + + return v.ToString() +} + +func removeOnUpdateNowFlag(c *table.Column) { + // For timestamp Col, if it is set null or default value, + // OnUpdateNowFlag should be removed. + if mysql.HasTimestampFlag(c.GetFlag()) { + c.DelFlag(mysql.OnUpdateNowFlag) + } +} + +func processDefaultValue(c *table.Column, hasDefaultValue bool, setOnUpdateNow bool) { + setTimestampDefaultValue(c, hasDefaultValue, setOnUpdateNow) + + setYearDefaultValue(c, hasDefaultValue) + + // Set `NoDefaultValueFlag` if this field doesn't have a default value and + // it is `not null` and not an `AUTO_INCREMENT` field or `TIMESTAMP` field. + setNoDefaultValueFlag(c, hasDefaultValue) +} + +func setYearDefaultValue(c *table.Column, hasDefaultValue bool) { + if hasDefaultValue { + return + } + + if c.GetType() == mysql.TypeYear && mysql.HasNotNullFlag(c.GetFlag()) { + if err := c.SetDefaultValue("0000"); err != nil { + logutil.DDLLogger().Error("set default value failed", zap.Error(err)) + } + } +} + +func setTimestampDefaultValue(c *table.Column, hasDefaultValue bool, setOnUpdateNow bool) { + if hasDefaultValue { + return + } + + // For timestamp Col, if is not set default value or not set null, use current timestamp. + if mysql.HasTimestampFlag(c.GetFlag()) && mysql.HasNotNullFlag(c.GetFlag()) { + if setOnUpdateNow { + if err := c.SetDefaultValue(types.ZeroDatetimeStr); err != nil { + logutil.DDLLogger().Error("set default value failed", zap.Error(err)) + } + } else { + if err := c.SetDefaultValue(strings.ToUpper(ast.CurrentTimestamp)); err != nil { + logutil.DDLLogger().Error("set default value failed", zap.Error(err)) + } + } + } +} + +func setNoDefaultValueFlag(c *table.Column, hasDefaultValue bool) { + if hasDefaultValue { + return + } + + if !mysql.HasNotNullFlag(c.GetFlag()) { + return + } + + // Check if it is an `AUTO_INCREMENT` field or `TIMESTAMP` field. 
+ if !mysql.HasAutoIncrementFlag(c.GetFlag()) && !mysql.HasTimestampFlag(c.GetFlag()) { + c.AddFlag(mysql.NoDefaultValueFlag) + } +} + +func checkDefaultValue(ctx exprctx.BuildContext, c *table.Column, hasDefaultValue bool) (err error) { + if !hasDefaultValue { + return nil + } + + if c.GetDefaultValue() != nil { + if c.DefaultIsExpr { + if mysql.HasAutoIncrementFlag(c.GetFlag()) { + return types.ErrInvalidDefault.GenWithStackByArgs(c.Name) + } + return nil + } + _, err = table.GetColDefaultValue( + exprctx.CtxWithHandleTruncateErrLevel(ctx, errctx.LevelError), + c.ToInfo(), + ) + if err != nil { + return types.ErrInvalidDefault.GenWithStackByArgs(c.Name) + } + return nil + } + // Primary key default null is invalid. + if mysql.HasPriKeyFlag(c.GetFlag()) { + return dbterror.ErrPrimaryCantHaveNull + } + + // Set not null but default null is invalid. + if mysql.HasNotNullFlag(c.GetFlag()) { + return types.ErrInvalidDefault.GenWithStackByArgs(c.Name) + } + + return nil +} + +func checkColumnFieldLength(col *table.Column) error { + if col.GetType() == mysql.TypeVarchar { + if err := types.IsVarcharTooBigFieldLength(col.GetFlen(), col.Name.O, col.GetCharset()); err != nil { + return errors.Trace(err) + } + } + + return nil +} + +// checkPriKeyConstraint check all parts of a PRIMARY KEY must be NOT NULL +func checkPriKeyConstraint(col *table.Column, hasDefaultValue, hasNullFlag bool, outPriKeyConstraint *ast.Constraint) error { + // Primary key should not be null. + if mysql.HasPriKeyFlag(col.GetFlag()) && hasDefaultValue && col.GetDefaultValue() == nil { + return types.ErrInvalidDefault.GenWithStackByArgs(col.Name) + } + // Set primary key flag for outer primary key constraint. + // Such as: create table t1 (id int , age int, primary key(id)) + if !mysql.HasPriKeyFlag(col.GetFlag()) && outPriKeyConstraint != nil { + for _, key := range outPriKeyConstraint.Keys { + if key.Expr == nil && key.Column.Name.L != col.Name.L { + continue + } + col.AddFlag(mysql.PriKeyFlag) + break + } + } + // Primary key should not be null. + if mysql.HasPriKeyFlag(col.GetFlag()) && hasNullFlag { + return dbterror.ErrPrimaryCantHaveNull + } + return nil +} + +func checkColumnValueConstraint(col *table.Column, collation string) error { + if col.GetType() != mysql.TypeEnum && col.GetType() != mysql.TypeSet { + return nil + } + valueMap := make(map[string]bool, len(col.GetElems())) + ctor := collate.GetCollator(collation) + enumLengthLimit := config.GetGlobalConfig().EnableEnumLengthLimit + desc, err := charset.GetCharsetInfo(col.GetCharset()) + if err != nil { + return errors.Trace(err) + } + for i := range col.GetElems() { + val := string(ctor.Key(col.GetElems()[i])) + // According to MySQL 8.0 Refman: + // The maximum supported length of an individual ENUM element is M <= 255 and (M x w) <= 1020, + // where M is the element literal length and w is the number of bytes required for the maximum-length character in the character set. + // See https://dev.mysql.com/doc/refman/8.0/en/string-type-syntax.html for more details. + if enumLengthLimit && (len(val) > 255 || len(val)*desc.Maxlen > 1020) { + return dbterror.ErrTooLongValueForType.GenWithStackByArgs(col.Name) + } + if _, ok := valueMap[val]; ok { + tpStr := "ENUM" + if col.GetType() == mysql.TypeSet { + tpStr = "SET" + } + return types.ErrDuplicatedValueInType.GenWithStackByArgs(col.Name, col.GetElems()[i], tpStr) + } + valueMap[val] = true + } + return nil +} + +// checkColumnDefaultValue checks the default value of the column. 
+// In non-strict SQL mode, if the default value of the column is an empty string, the default value can be ignored. +// In strict SQL mode, TEXT/BLOB/JSON can't have not null default values. +// In NO_ZERO_DATE SQL mode, TIMESTAMP/DATE/DATETIME type can't have zero date like '0000-00-00' or '0000-00-00 00:00:00'. +func checkColumnDefaultValue(ctx exprctx.BuildContext, col *table.Column, value any) (bool, any, error) { + hasDefaultValue := true + if value != nil && (col.GetType() == mysql.TypeJSON || + col.GetType() == mysql.TypeTinyBlob || col.GetType() == mysql.TypeMediumBlob || + col.GetType() == mysql.TypeLongBlob || col.GetType() == mysql.TypeBlob) { + // In non-strict SQL mode. + if !ctx.GetEvalCtx().SQLMode().HasStrictMode() && value == "" { + if col.GetType() == mysql.TypeBlob || col.GetType() == mysql.TypeLongBlob { + // The TEXT/BLOB default value can be ignored. + hasDefaultValue = false + } + // In non-strict SQL mode, if the column type is json and the default value is null, it is initialized to an empty array. + if col.GetType() == mysql.TypeJSON { + value = `null` + } + ctx.GetEvalCtx().AppendWarning(dbterror.ErrBlobCantHaveDefault.FastGenByArgs(col.Name.O)) + return hasDefaultValue, value, nil + } + // In strict SQL mode or default value is not an empty string. + return hasDefaultValue, value, dbterror.ErrBlobCantHaveDefault.GenWithStackByArgs(col.Name.O) + } + if value != nil && ctx.GetEvalCtx().SQLMode().HasNoZeroDateMode() && + ctx.GetEvalCtx().SQLMode().HasStrictMode() && types.IsTypeTime(col.GetType()) { + if vv, ok := value.(string); ok { + timeValue, err := expression.GetTimeValue(ctx, vv, col.GetType(), col.GetDecimal(), nil) + if err != nil { + return hasDefaultValue, value, errors.Trace(err) + } + if timeValue.GetMysqlTime().CoreTime() == types.ZeroCoreTime { + return hasDefaultValue, value, types.ErrInvalidDefault.GenWithStackByArgs(col.Name.O) + } + } + } + return hasDefaultValue, value, nil +} + +func checkSequenceDefaultValue(col *table.Column) error { + if mysql.IsIntegerType(col.GetType()) { + return nil + } + return dbterror.ErrColumnTypeUnsupportedNextValue.GenWithStackByArgs(col.ColumnInfo.Name.O) +} + +func convertTimestampDefaultValToUTC(ctx sessionctx.Context, defaultVal any, col *table.Column) (any, error) { + if defaultVal == nil || col.GetType() != mysql.TypeTimestamp { + return defaultVal, nil + } + if vv, ok := defaultVal.(string); ok { + if vv != types.ZeroDatetimeStr && !strings.EqualFold(vv, ast.CurrentTimestamp) { + t, err := types.ParseTime(ctx.GetSessionVars().StmtCtx.TypeCtx(), vv, col.GetType(), col.GetDecimal()) + if err != nil { + return defaultVal, errors.Trace(err) + } + err = t.ConvertTimeZone(ctx.GetSessionVars().Location(), time.UTC) + if err != nil { + return defaultVal, errors.Trace(err) + } + defaultVal = t.String() + } + } + return defaultVal, nil +} + +// processColumnFlags is used by columnDefToCol and processColumnOptions. It is intended to unify behaviors on `create/add` and `modify/change` statements. Check tidb#issue#19342. +func processColumnFlags(col *table.Column) { + if col.FieldType.EvalType().IsStringKind() { + if col.GetCharset() == charset.CharsetBin { + col.AddFlag(mysql.BinaryFlag) + } else { + col.DelFlag(mysql.BinaryFlag) + } + } + if col.GetType() == mysql.TypeBit { + // For BIT field, it's charset is binary but does not have binary flag. 
+ col.DelFlag(mysql.BinaryFlag) + col.AddFlag(mysql.UnsignedFlag) + } + if col.GetType() == mysql.TypeYear { + // For Year field, it's charset is binary but does not have binary flag. + col.DelFlag(mysql.BinaryFlag) + col.AddFlag(mysql.ZerofillFlag) + } + + // If you specify ZEROFILL for a numeric column, MySQL automatically adds the UNSIGNED attribute to the column. + // See https://dev.mysql.com/doc/refman/5.7/en/numeric-type-overview.html for more details. + // But some types like bit and year, won't show its unsigned flag in `show create table`. + if mysql.HasZerofillFlag(col.GetFlag()) { + col.AddFlag(mysql.UnsignedFlag) + } +} + +func adjustBlobTypesFlen(tp *types.FieldType, colCharset string) error { + cs, err := charset.GetCharsetInfo(colCharset) + // when we meet the unsupported charset, we do not adjust. + if err != nil { + return err + } + l := tp.GetFlen() * cs.Maxlen + if tp.GetType() == mysql.TypeBlob { + if l <= tinyBlobMaxLength { + logutil.DDLLogger().Info(fmt.Sprintf("Automatically convert BLOB(%d) to TINYBLOB", tp.GetFlen())) + tp.SetFlen(tinyBlobMaxLength) + tp.SetType(mysql.TypeTinyBlob) + } else if l <= blobMaxLength { + tp.SetFlen(blobMaxLength) + } else if l <= mediumBlobMaxLength { + logutil.DDLLogger().Info(fmt.Sprintf("Automatically convert BLOB(%d) to MEDIUMBLOB", tp.GetFlen())) + tp.SetFlen(mediumBlobMaxLength) + tp.SetType(mysql.TypeMediumBlob) + } else if l <= longBlobMaxLength { + logutil.DDLLogger().Info(fmt.Sprintf("Automatically convert BLOB(%d) to LONGBLOB", tp.GetFlen())) + tp.SetFlen(longBlobMaxLength) + tp.SetType(mysql.TypeLongBlob) + } + } + return nil +} diff --git a/pkg/ddl/backfilling.go b/pkg/ddl/backfilling.go index 4fbed0422726d..1f2d3dc37fd3c 100644 --- a/pkg/ddl/backfilling.go +++ b/pkg/ddl/backfilling.go @@ -397,22 +397,22 @@ func (w *backfillWorker) run(d *ddlCtx, bf backfiller, job *model.Job) { d.setDDLLabelForTopSQL(job.ID, job.Query) logger.Debug("backfill worker got task", zap.Int("workerID", w.GetCtx().id), zap.Stringer("task", task)) - failpoint.Inject("mockBackfillRunErr", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockBackfillRunErr")); _err_ == nil { if w.GetCtx().id == 0 { result := &backfillResult{taskID: task.id, addedCount: 0, nextKey: nil, err: errors.Errorf("mock backfill error")} w.sendResult(result) - failpoint.Continue() + continue } - }) + } - failpoint.Inject("mockHighLoadForAddIndex", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockHighLoadForAddIndex")); _err_ == nil { sqlPrefixes := []string{"alter"} topsql.MockHighCPULoad(job.Query, sqlPrefixes, 5) - }) + } - failpoint.Inject("mockBackfillSlow", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockBackfillSlow")); _err_ == nil { time.Sleep(100 * time.Millisecond) - }) + } // Change the batch size dynamically. 
w.GetCtx().batchCnt = int(variable.GetDDLReorgBatchSize()) @@ -828,12 +828,12 @@ func (dc *ddlCtx) writePhysicalTableRecord( return errors.Trace(err) } - failpoint.Inject("MockCaseWhenParseFailure", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("MockCaseWhenParseFailure")); _err_ == nil { //nolint:forcetypeassert if val.(bool) { - failpoint.Return(errors.New("job.ErrCount:" + strconv.Itoa(int(reorgInfo.Job.ErrorCount)) + ", mock unknown type: ast.whenClause.")) + return errors.New("job.ErrCount:" + strconv.Itoa(int(reorgInfo.Job.ErrorCount)) + ", mock unknown type: ast.whenClause.") } - }) + } if bfWorkerType == typeAddIndexWorker && reorgInfo.ReorgMeta.ReorgTp == model.ReorgTypeLitMerge { return dc.runAddIndexInLocalIngestMode(ctx, sessPool, t, reorgInfo) } @@ -954,13 +954,13 @@ func injectCheckBackfillWorkerNum(curWorkerSize int, isMergeWorker bool) error { if isMergeWorker { return nil } - failpoint.Inject("checkBackfillWorkerNum", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("checkBackfillWorkerNum")); _err_ == nil { //nolint:forcetypeassert if val.(bool) { num := int(atomic.LoadInt32(&TestCheckWorkerNumber)) if num != 0 { if num != curWorkerSize { - failpoint.Return(errors.Errorf("expected backfill worker num: %v, actual record num: %v", num, curWorkerSize)) + return errors.Errorf("expected backfill worker num: %v, actual record num: %v", num, curWorkerSize) } var wg sync.WaitGroup wg.Add(1) @@ -968,7 +968,7 @@ func injectCheckBackfillWorkerNum(curWorkerSize int, isMergeWorker bool) error { wg.Wait() } } - }) + } return nil } diff --git a/pkg/ddl/backfilling.go__failpoint_stash__ b/pkg/ddl/backfilling.go__failpoint_stash__ new file mode 100644 index 0000000000000..4fbed0422726d --- /dev/null +++ b/pkg/ddl/backfilling.go__failpoint_stash__ @@ -0,0 +1,1124 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ddl + +import ( + "bytes" + "context" + "encoding/hex" + "fmt" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/ddl/ingest" + "github.com/pingcap/tidb/pkg/ddl/logutil" + sess "github.com/pingcap/tidb/pkg/ddl/session" + ddlutil "github.com/pingcap/tidb/pkg/ddl/util" + "github.com/pingcap/tidb/pkg/disttask/operator" + "github.com/pingcap/tidb/pkg/expression" + exprctx "github.com/pingcap/tidb/pkg/expression/context" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/util" + contextutil "github.com/pingcap/tidb/pkg/util/context" + "github.com/pingcap/tidb/pkg/util/dbterror" + decoder "github.com/pingcap/tidb/pkg/util/rowDecoder" + "github.com/pingcap/tidb/pkg/util/topsql" + "github.com/prometheus/client_golang/prometheus" + "github.com/tikv/client-go/v2/tikv" + kvutil "github.com/tikv/client-go/v2/util" + "go.uber.org/zap" +) + +type backfillerType byte + +const ( + typeAddIndexWorker backfillerType = iota + typeUpdateColumnWorker + typeCleanUpIndexWorker + typeAddIndexMergeTmpWorker + typeReorgPartitionWorker + + typeCount +) + +// BackupFillerTypeCount represents the count of ddl jobs that need to do backfill. +func BackupFillerTypeCount() int { + return int(typeCount) +} + +func (bT backfillerType) String() string { + switch bT { + case typeAddIndexWorker: + return "add index" + case typeUpdateColumnWorker: + return "update column" + case typeCleanUpIndexWorker: + return "clean up index" + case typeAddIndexMergeTmpWorker: + return "merge temporary index" + case typeReorgPartitionWorker: + return "reorganize partition" + default: + return "unknown" + } +} + +// By now the DDL jobs that need backfilling include: +// 1: add-index +// 2: modify-column-type +// 3: clean-up global index +// 4: reorganize partition +// +// They all have a write reorganization state to back fill data into the rows existed. +// Backfilling is time consuming, to accelerate this process, TiDB has built some sub +// workers to do this in the DDL owner node. +// +// DDL owner thread (also see comments before runReorgJob func) +// ^ +// | (reorgCtx.doneCh) +// | +// worker master +// ^ (waitTaskResults) +// | +// | +// v (sendRangeTask) +// +--------------------+---------+---------+------------------+--------------+ +// | | | | | +// backfillworker1 backfillworker2 backfillworker3 backfillworker4 ... +// +// The worker master is responsible for scaling the backfilling workers according to the +// system variable "tidb_ddl_reorg_worker_cnt". Essentially, reorg job is mainly based +// on the [start, end] range of the table to backfill data. We did not do it all at once, +// there were several ddl rounds. +// +// [start1---end1 start2---end2 start3---end3 start4---end4 ... ... ] +// | | | | | | | | +// +-------+ +-------+ +-------+ +-------+ ... ... +// | | | | +// bfworker1 bfworker2 bfworker3 bfworker4 ... ... +// | | | | | | +// +---------------- (round1)----------------+ +--(round2)--+ +// +// The main range [start, end] will be split into small ranges. +// Each small range corresponds to a region and it will be delivered to a backfillworker. 
+// Each worker can only be assigned with one range at one round, those remaining ranges +// will be cached until all the backfill workers have had their previous range jobs done. +// +// [ region start --------------------- region end ] +// | +// v +// [ batch ] [ batch ] [ batch ] [ batch ] ... +// | | | | +// v v v v +// (a kv txn) -> -> -> +// +// For a single range, backfill worker doesn't backfill all the data in one kv transaction. +// Instead, it is divided into batches, each time a kv transaction completes the backfilling +// of a partial batch. + +// backfillTaskContext is the context of the batch adding indices or updating column values. +// After finishing the batch adding indices or updating column values, result in backfillTaskContext will be merged into backfillResult. +type backfillTaskContext struct { + nextKey kv.Key + done bool + addedCount int + scanCount int + warnings map[errors.ErrorID]*terror.Error + warningsCount map[errors.ErrorID]int64 + finishTS uint64 +} + +type backfillCtx struct { + id int + *ddlCtx + sessCtx sessionctx.Context + warnings contextutil.WarnHandlerExt + loc *time.Location + exprCtx exprctx.BuildContext + tblCtx table.MutateContext + schemaName string + table table.Table + batchCnt int + jobContext *JobContext + metricCounter prometheus.Counter +} + +func newBackfillCtx(id int, rInfo *reorgInfo, + schemaName string, tbl table.Table, jobCtx *JobContext, label string, isDistributed bool) (*backfillCtx, error) { + sessCtx, err := newSessCtx(rInfo.d.store, rInfo.ReorgMeta) + if err != nil { + return nil, err + } + + if isDistributed { + id = int(backfillContextID.Add(1)) + } + + exprCtx := sessCtx.GetExprCtx() + return &backfillCtx{ + id: id, + ddlCtx: rInfo.d, + sessCtx: sessCtx, + warnings: sessCtx.GetSessionVars().StmtCtx.WarnHandler, + exprCtx: exprCtx, + tblCtx: sessCtx.GetTableCtx(), + loc: exprCtx.GetEvalCtx().Location(), + schemaName: schemaName, + table: tbl, + batchCnt: int(variable.GetDDLReorgBatchSize()), + jobContext: jobCtx, + metricCounter: metrics.BackfillTotalCounter.WithLabelValues( + metrics.GenerateReorgLabel(label, schemaName, tbl.Meta().Name.String())), + }, nil +} + +func updateTxnEntrySizeLimitIfNeeded(txn kv.Transaction) { + if entrySizeLimit := variable.TxnEntrySizeLimit.Load(); entrySizeLimit > 0 { + txn.SetOption(kv.SizeLimits, kv.TxnSizeLimits{ + Entry: entrySizeLimit, + Total: kv.TxnTotalSizeLimit.Load(), + }) + } +} + +type backfiller interface { + BackfillData(handleRange reorgBackfillTask) (taskCtx backfillTaskContext, err error) + AddMetricInfo(float64) + GetCtx() *backfillCtx + String() string +} + +type backfillResult struct { + taskID int + addedCount int + scanCount int + totalCount int + nextKey kv.Key + err error +} + +type reorgBackfillTask struct { + physicalTable table.PhysicalTable + + // TODO: Remove the following fields after remove the function of run. + id int + startKey kv.Key + endKey kv.Key + jobID int64 + sqlQuery string + priority int +} + +func (r *reorgBackfillTask) getJobID() int64 { + return r.jobID +} + +func (r *reorgBackfillTask) String() string { + pID := r.physicalTable.GetPhysicalID() + start := hex.EncodeToString(r.startKey) + end := hex.EncodeToString(r.endKey) + jobID := r.getJobID() + return fmt.Sprintf("taskID: %d, physicalTableID: %d, range: [%s, %s), jobID: %d", r.id, pID, start, end, jobID) +} + +// mergeBackfillCtxToResult merge partial result in taskCtx into result. 
+func mergeBackfillCtxToResult(taskCtx *backfillTaskContext, result *backfillResult) { + result.nextKey = taskCtx.nextKey + result.addedCount += taskCtx.addedCount + result.scanCount += taskCtx.scanCount +} + +type backfillWorker struct { + backfiller + taskCh chan *reorgBackfillTask + resultCh chan *backfillResult + ctx context.Context + cancel func() + wg *sync.WaitGroup +} + +func newBackfillWorker(ctx context.Context, bf backfiller) *backfillWorker { + bfCtx, cancel := context.WithCancel(ctx) + return &backfillWorker{ + backfiller: bf, + taskCh: make(chan *reorgBackfillTask, 1), + resultCh: make(chan *backfillResult, 1), + ctx: bfCtx, + cancel: cancel, + } +} + +func (w *backfillWorker) String() string { + return fmt.Sprintf("backfill-worker %d, tp %s", w.GetCtx().id, w.backfiller.String()) +} + +func (w *backfillWorker) Close() { + if w.cancel != nil { + w.cancel() + w.cancel = nil + } +} + +func closeBackfillWorkers(workers []*backfillWorker) { + for _, worker := range workers { + worker.Close() + } +} + +// ResultCounterForTest is used for test. +var ResultCounterForTest *atomic.Int32 + +// handleBackfillTask backfills range [task.startHandle, task.endHandle) handle's index to table. +func (w *backfillWorker) handleBackfillTask(d *ddlCtx, task *reorgBackfillTask, bf backfiller) *backfillResult { + handleRange := *task + result := &backfillResult{ + taskID: task.id, + err: nil, + addedCount: 0, + nextKey: handleRange.startKey, + } + lastLogCount := 0 + lastLogTime := time.Now() + startTime := lastLogTime + jobID := task.getJobID() + rc := d.getReorgCtx(jobID) + + for { + // Give job chance to be canceled or paused, if we not check it here, + // we will never cancel the job once there is panic in bf.BackfillData. + // Because reorgRecordTask may run a long time, + // we should check whether this ddl job is still runnable. + err := d.isReorgRunnable(jobID, false) + if err != nil { + result.err = err + return result + } + + taskCtx, err := bf.BackfillData(handleRange) + if err != nil { + result.err = err + return result + } + + bf.AddMetricInfo(float64(taskCtx.addedCount)) + mergeBackfillCtxToResult(&taskCtx, result) + + // Although `handleRange` is for data in one region, but back fill worker still split it into many + // small reorg batch size slices and reorg them in many different kv txn. + // If a task failed, it may contained some committed small kv txn which has already finished the + // small range reorganization. + // In the next round of reorganization, the target handle range may overlap with last committed + // small ranges. This will cause the `redo` action in reorganization. + // So for added count and warnings collection, it is recommended to collect the statistics in every + // successfully committed small ranges rather than fetching it in the total result. 
+ rc.increaseRowCount(int64(taskCtx.addedCount)) + rc.mergeWarnings(taskCtx.warnings, taskCtx.warningsCount) + + if num := result.scanCount - lastLogCount; num >= 90000 { + lastLogCount = result.scanCount + logutil.DDLLogger().Info("backfill worker back fill index", zap.Stringer("worker", w), + zap.Int("addedCount", result.addedCount), zap.Int("scanCount", result.scanCount), + zap.String("next key", hex.EncodeToString(taskCtx.nextKey)), + zap.Float64("speed(rows/s)", float64(num)/time.Since(lastLogTime).Seconds())) + lastLogTime = time.Now() + } + + handleRange.startKey = taskCtx.nextKey + if taskCtx.done { + break + } + } + logutil.DDLLogger().Info("backfill worker finish task", + zap.Stringer("worker", w), zap.Stringer("task", task), + zap.Int("added count", result.addedCount), + zap.Int("scan count", result.scanCount), + zap.String("next key", hex.EncodeToString(result.nextKey)), + zap.Stringer("take time", time.Since(startTime))) + if ResultCounterForTest != nil && result.err == nil { + ResultCounterForTest.Add(1) + } + return result +} + +func (w *backfillWorker) sendResult(result *backfillResult) { + select { + case <-w.ctx.Done(): + case w.resultCh <- result: + } +} + +func (w *backfillWorker) run(d *ddlCtx, bf backfiller, job *model.Job) { + logger := logutil.DDLLogger().With(zap.Stringer("worker", w), zap.Int64("jobID", job.ID)) + var ( + curTaskID int + task *reorgBackfillTask + ok bool + ) + + defer w.wg.Done() + defer util.Recover(metrics.LabelDDL, "backfillWorker.run", func() { + w.sendResult(&backfillResult{taskID: curTaskID, err: dbterror.ErrReorgPanic}) + }, false) + for { + select { + case <-w.ctx.Done(): + logger.Info("backfill worker exit on context done") + return + case task, ok = <-w.taskCh: + } + if !ok { + logger.Info("backfill worker exit") + return + } + curTaskID = task.id + d.setDDLLabelForTopSQL(job.ID, job.Query) + + logger.Debug("backfill worker got task", zap.Int("workerID", w.GetCtx().id), zap.Stringer("task", task)) + failpoint.Inject("mockBackfillRunErr", func() { + if w.GetCtx().id == 0 { + result := &backfillResult{taskID: task.id, addedCount: 0, nextKey: nil, err: errors.Errorf("mock backfill error")} + w.sendResult(result) + failpoint.Continue() + } + }) + + failpoint.Inject("mockHighLoadForAddIndex", func() { + sqlPrefixes := []string{"alter"} + topsql.MockHighCPULoad(job.Query, sqlPrefixes, 5) + }) + + failpoint.Inject("mockBackfillSlow", func() { + time.Sleep(100 * time.Millisecond) + }) + + // Change the batch size dynamically. + w.GetCtx().batchCnt = int(variable.GetDDLReorgBatchSize()) + result := w.handleBackfillTask(d, task, bf) + w.sendResult(result) + + if result.err != nil { + logger.Info("backfill worker exit on error", + zap.Error(result.err)) + return + } + } +} + +// loadTableRanges load table key ranges from PD between given start key and end key. +// It returns up to `limit` ranges. +func loadTableRanges( + ctx context.Context, + t table.PhysicalTable, + store kv.Storage, + startKey, endKey kv.Key, + limit int, +) ([]kv.KeyRange, error) { + if len(startKey) == 0 && len(endKey) == 0 { + logutil.DDLLogger().Info("load noop table range", + zap.Int64("physicalTableID", t.GetPhysicalID())) + return []kv.KeyRange{}, nil + } + + s, ok := store.(tikv.Storage) + if !ok { + // Only support split ranges in tikv.Storage now. 
+ logutil.DDLLogger().Info("load table ranges failed, unsupported storage", + zap.String("storage", fmt.Sprintf("%T", store)), + zap.Int64("physicalTableID", t.GetPhysicalID())) + return []kv.KeyRange{{StartKey: startKey, EndKey: endKey}}, nil + } + + rc := s.GetRegionCache() + maxSleep := 10000 // ms + bo := tikv.NewBackofferWithVars(ctx, maxSleep, nil) + var ranges []kv.KeyRange + err := util.RunWithRetry(util.DefaultMaxRetries, util.RetryInterval, func() (bool, error) { + logutil.DDLLogger().Info("load table ranges from PD", + zap.Int64("physicalTableID", t.GetPhysicalID()), + zap.String("start key", hex.EncodeToString(startKey)), + zap.String("end key", hex.EncodeToString(endKey))) + rs, err := rc.BatchLoadRegionsWithKeyRange(bo, startKey, endKey, limit) + if err != nil { + return false, errors.Trace(err) + } + ranges = make([]kv.KeyRange, 0, len(rs)) + for _, r := range rs { + ranges = append(ranges, kv.KeyRange{StartKey: r.StartKey(), EndKey: r.EndKey()}) + } + err = validateAndFillRanges(ranges, startKey, endKey) + if err != nil { + return true, errors.Trace(err) + } + return false, nil + }) + if err != nil { + return nil, errors.Trace(err) + } + logutil.DDLLogger().Info("load table ranges from PD done", + zap.Int64("physicalTableID", t.GetPhysicalID()), + zap.String("range start", hex.EncodeToString(ranges[0].StartKey)), + zap.String("range end", hex.EncodeToString(ranges[len(ranges)-1].EndKey)), + zap.Int("range count", len(ranges))) + return ranges, nil +} + +func validateAndFillRanges(ranges []kv.KeyRange, startKey, endKey []byte) error { + if len(ranges) == 0 { + errMsg := fmt.Sprintf("cannot find region in range [%s, %s]", + hex.EncodeToString(startKey), hex.EncodeToString(endKey)) + return dbterror.ErrInvalidSplitRegionRanges.GenWithStackByArgs(errMsg) + } + for i, r := range ranges { + if i == 0 { + s := r.StartKey + if len(s) == 0 || bytes.Compare(s, startKey) < 0 { + ranges[i].StartKey = startKey + } + } + if i == len(ranges)-1 { + e := r.EndKey + if len(e) == 0 || bytes.Compare(e, endKey) > 0 { + ranges[i].EndKey = endKey + } + } + if len(ranges[i].StartKey) == 0 || len(ranges[i].EndKey) == 0 { + return errors.Errorf("get empty start/end key in the middle of ranges") + } + if i > 0 && !bytes.Equal(ranges[i-1].EndKey, ranges[i].StartKey) { + return errors.Errorf("ranges are not continuous") + } + } + return nil +} + +func getBatchTasks( + t table.Table, + reorgInfo *reorgInfo, + kvRanges []kv.KeyRange, + taskIDAlloc *taskIDAllocator, + bfWorkerTp backfillerType, +) []*reorgBackfillTask { + batchTasks := make([]*reorgBackfillTask, 0, len(kvRanges)) + //nolint:forcetypeassert + phyTbl := t.(table.PhysicalTable) + for _, r := range kvRanges { + taskID := taskIDAlloc.alloc() + startKey := r.StartKey + endKey := r.EndKey + endKey = getActualEndKey(t, reorgInfo, bfWorkerTp, startKey, endKey, taskID) + task := &reorgBackfillTask{ + id: taskID, + jobID: reorgInfo.Job.ID, + physicalTable: phyTbl, + priority: reorgInfo.Priority, + startKey: startKey, + endKey: endKey, + } + batchTasks = append(batchTasks, task) + } + return batchTasks +} + +func getActualEndKey( + t table.Table, + reorgInfo *reorgInfo, + bfTp backfillerType, + rangeStart, rangeEnd kv.Key, + taskID int, +) kv.Key { + job := reorgInfo.Job + //nolint:forcetypeassert + phyTbl := t.(table.PhysicalTable) + + if bfTp == typeAddIndexMergeTmpWorker { + // Temp Index data does not grow infinitely, we can return the whole range + // and IndexMergeTmpWorker should still be finished in a bounded time. 
+ return rangeEnd + } + if bfTp == typeAddIndexWorker && job.ReorgMeta.ReorgTp == model.ReorgTypeLitMerge { + // Ingest worker uses coprocessor to read table data. It is fast enough, + // we don't need to get the actual end key of this range. + return rangeEnd + } + + // Otherwise to avoid the future data written to key range of [backfillChunkEndKey, rangeEnd) and + // backfill worker can't catch up, we shrink the end key to the actual written key for now. + jobCtx := reorgInfo.NewJobContext() + + actualEndKey, err := GetRangeEndKey(jobCtx, reorgInfo.d.store, job.Priority, t.RecordPrefix(), rangeStart, rangeEnd) + if err != nil { + logutil.DDLLogger().Info("get backfill range task, get reverse key failed", zap.Error(err)) + return rangeEnd + } + logutil.DDLLogger().Info("get backfill range task, change end key", + zap.Int("id", taskID), + zap.Int64("pTbl", phyTbl.GetPhysicalID()), + zap.String("end key", hex.EncodeToString(rangeEnd)), + zap.String("current end key", hex.EncodeToString(actualEndKey))) + return actualEndKey +} + +// sendTasks sends tasks to workers, and returns remaining kvRanges that is not handled. +func sendTasks( + scheduler backfillScheduler, + t table.PhysicalTable, + kvRanges []kv.KeyRange, + reorgInfo *reorgInfo, + taskIDAlloc *taskIDAllocator, + bfWorkerTp backfillerType, +) error { + batchTasks := getBatchTasks(t, reorgInfo, kvRanges, taskIDAlloc, bfWorkerTp) + for _, task := range batchTasks { + if err := scheduler.sendTask(task); err != nil { + return errors.Trace(err) + } + } + return nil +} + +var ( + // TestCheckWorkerNumCh use for test adjust backfill worker. + TestCheckWorkerNumCh = make(chan *sync.WaitGroup) + // TestCheckWorkerNumber use for test adjust backfill worker. + TestCheckWorkerNumber = int32(variable.DefTiDBDDLReorgWorkerCount) + // TestCheckReorgTimeout is used to mock timeout when reorg data. + TestCheckReorgTimeout = int32(0) +) + +func loadDDLReorgVars(ctx context.Context, sessPool *sess.Pool) error { + // Get sessionctx from context resource pool. + sCtx, err := sessPool.Get() + if err != nil { + return errors.Trace(err) + } + defer sessPool.Put(sCtx) + return ddlutil.LoadDDLReorgVars(ctx, sCtx) +} + +func makeupDecodeColMap(dbName model.CIStr, t table.Table) (map[int64]decoder.Column, error) { + writableColInfos := make([]*model.ColumnInfo, 0, len(t.WritableCols())) + for _, col := range t.WritableCols() { + writableColInfos = append(writableColInfos, col.ColumnInfo) + } + exprCols, _, err := expression.ColumnInfos2ColumnsAndNames(newReorgExprCtx(), dbName, t.Meta().Name, writableColInfos, t.Meta()) + if err != nil { + return nil, err + } + mockSchema := expression.NewSchema(exprCols...) + + decodeColMap := decoder.BuildFullDecodeColMap(t.WritableCols(), mockSchema) + + return decodeColMap, nil +} + +var backfillTaskChanSize = 128 + +// SetBackfillTaskChanSizeForTest is only used for test. +func SetBackfillTaskChanSizeForTest(n int) { + backfillTaskChanSize = n +} + +func (dc *ddlCtx) runAddIndexInLocalIngestMode( + ctx context.Context, + sessPool *sess.Pool, + t table.PhysicalTable, + reorgInfo *reorgInfo, +) error { + // TODO(tangenta): support adjust worker count dynamically. 
+ if err := dc.isReorgRunnable(reorgInfo.Job.ID, false); err != nil { + return errors.Trace(err) + } + job := reorgInfo.Job + opCtx := NewLocalOperatorCtx(ctx, job.ID) + idxCnt := len(reorgInfo.elements) + indexIDs := make([]int64, 0, idxCnt) + indexInfos := make([]*model.IndexInfo, 0, idxCnt) + uniques := make([]bool, 0, idxCnt) + hasUnique := false + for _, e := range reorgInfo.elements { + indexIDs = append(indexIDs, e.ID) + indexInfo := model.FindIndexInfoByID(t.Meta().Indices, e.ID) + if indexInfo == nil { + logutil.DDLIngestLogger().Warn("index info not found", + zap.Int64("jobID", job.ID), + zap.Int64("tableID", t.Meta().ID), + zap.Int64("indexID", e.ID)) + return errors.Errorf("index info not found: %d", e.ID) + } + indexInfos = append(indexInfos, indexInfo) + uniques = append(uniques, indexInfo.Unique) + hasUnique = hasUnique || indexInfo.Unique + } + + //nolint: forcetypeassert + discovery := dc.store.(tikv.Storage).GetRegionCache().PDClient().GetServiceDiscovery() + bcCtx, err := ingest.LitBackCtxMgr.Register( + ctx, job.ID, hasUnique, dc.etcdCli, discovery, job.ReorgMeta.ResourceGroupName) + if err != nil { + return errors.Trace(err) + } + defer ingest.LitBackCtxMgr.Unregister(job.ID) + sctx, err := sessPool.Get() + if err != nil { + return errors.Trace(err) + } + defer sessPool.Put(sctx) + + cpMgr, err := ingest.NewCheckpointManager( + ctx, + sessPool, + reorgInfo.PhysicalTableID, + job.ID, + indexIDs, + ingest.LitBackCtxMgr.EncodeJobSortPath(job.ID), + dc.store.(kv.StorageWithPD).GetPDClient(), + ) + if err != nil { + logutil.DDLIngestLogger().Warn("create checkpoint manager failed", + zap.Int64("jobID", job.ID), + zap.Error(err)) + } else { + defer cpMgr.Close() + bcCtx.AttachCheckpointManager(cpMgr) + } + + reorgCtx := dc.getReorgCtx(reorgInfo.Job.ID) + rowCntListener := &localRowCntListener{ + prevPhysicalRowCnt: reorgCtx.getRowCount(), + reorgCtx: dc.getReorgCtx(reorgInfo.Job.ID), + counter: metrics.BackfillTotalCounter.WithLabelValues( + metrics.GenerateReorgLabel("add_idx_rate", job.SchemaName, job.TableName)), + } + + avgRowSize := estimateTableRowSize(ctx, dc.store, sctx.GetRestrictedSQLExecutor(), t) + concurrency := int(variable.GetDDLReorgWorkerCounter()) + + engines, err := bcCtx.Register(indexIDs, uniques, t) + if err != nil { + logutil.DDLIngestLogger().Error("cannot register new engine", + zap.Int64("jobID", job.ID), + zap.Error(err), + zap.Int64s("index IDs", indexIDs)) + return errors.Trace(err) + } + + pipe, err := NewAddIndexIngestPipeline( + opCtx, + dc.store, + sessPool, + bcCtx, + engines, + job.ID, + t, + indexInfos, + reorgInfo.StartKey, + reorgInfo.EndKey, + job.ReorgMeta, + avgRowSize, + concurrency, + cpMgr, + rowCntListener, + ) + if err != nil { + return err + } + err = executeAndClosePipeline(opCtx, pipe) + if err != nil { + err1 := bcCtx.FinishAndUnregisterEngines(ingest.OptCloseEngines) + if err1 != nil { + logutil.DDLIngestLogger().Error("unregister engine failed", + zap.Int64("jobID", job.ID), + zap.Error(err1), + zap.Int64s("index IDs", indexIDs)) + } + return err + } + if cpMgr != nil { + cpMgr.AdvanceWatermark(true, true) + } + return bcCtx.FinishAndUnregisterEngines(ingest.OptCleanData | ingest.OptCheckDup) +} + +func executeAndClosePipeline(ctx *OperatorCtx, pipe *operator.AsyncPipeline) error { + err := pipe.Execute() + if err != nil { + return err + } + err = pipe.Close() + if opErr := ctx.OperatorErr(); opErr != nil { + return opErr + } + return err +} + +type localRowCntListener struct { + EmptyRowCntListener + reorgCtx *reorgCtx + 
counter prometheus.Counter + + // prevPhysicalRowCnt records the row count from previous physical tables (partitions). + prevPhysicalRowCnt int64 + // curPhysicalRowCnt records the row count of current physical table. + curPhysicalRowCnt struct { + cnt int64 + mu sync.Mutex + } +} + +func (s *localRowCntListener) Written(rowCnt int) { + s.curPhysicalRowCnt.mu.Lock() + s.curPhysicalRowCnt.cnt += int64(rowCnt) + s.reorgCtx.setRowCount(s.prevPhysicalRowCnt + s.curPhysicalRowCnt.cnt) + s.curPhysicalRowCnt.mu.Unlock() + s.counter.Add(float64(rowCnt)) +} + +func (s *localRowCntListener) SetTotal(total int) { + s.reorgCtx.setRowCount(s.prevPhysicalRowCnt + int64(total)) +} + +// writePhysicalTableRecord handles the "add index" or "modify/change column" reorganization state for a non-partitioned table or a partition. +// For a partitioned table, it should be handled partition by partition. +// +// How to "add index" or "update column value" in reorganization state? +// Concurrently process the @@tidb_ddl_reorg_worker_cnt tasks. Each task deals with a handle range of the index/row record. +// The handle range is split from PD regions now. Each worker deal with a region table key range one time. +// Each handle range by estimation, concurrent processing needs to perform after the handle range has been acquired. +// The operation flow is as follows: +// 1. Open numbers of defaultWorkers goroutines. +// 2. Split table key range from PD regions. +// 3. Send tasks to running workers by workers's task channel. Each task deals with a region key ranges. +// 4. Wait all these running tasks finished, then continue to step 3, until all tasks is done. +// +// The above operations are completed in a transaction. +// Finally, update the concurrent processing of the total number of rows, and store the completed handle value. 
+func (dc *ddlCtx) writePhysicalTableRecord( + ctx context.Context, + sessPool *sess.Pool, + t table.PhysicalTable, + bfWorkerType backfillerType, + reorgInfo *reorgInfo, +) error { + startKey, endKey := reorgInfo.StartKey, reorgInfo.EndKey + + if err := dc.isReorgRunnable(reorgInfo.Job.ID, false); err != nil { + return errors.Trace(err) + } + + failpoint.Inject("MockCaseWhenParseFailure", func(val failpoint.Value) { + //nolint:forcetypeassert + if val.(bool) { + failpoint.Return(errors.New("job.ErrCount:" + strconv.Itoa(int(reorgInfo.Job.ErrorCount)) + ", mock unknown type: ast.whenClause.")) + } + }) + if bfWorkerType == typeAddIndexWorker && reorgInfo.ReorgMeta.ReorgTp == model.ReorgTypeLitMerge { + return dc.runAddIndexInLocalIngestMode(ctx, sessPool, t, reorgInfo) + } + + jc := reorgInfo.NewJobContext() + + eg, egCtx := util.NewErrorGroupWithRecoverWithCtx(ctx) + + scheduler, err := newTxnBackfillScheduler(egCtx, reorgInfo, sessPool, bfWorkerType, t, jc) + if err != nil { + return errors.Trace(err) + } + defer scheduler.close(true) + + err = scheduler.setupWorkers() + if err != nil { + return errors.Trace(err) + } + + // process result goroutine + eg.Go(func() error { + totalAddedCount := reorgInfo.Job.GetRowCount() + keeper := newDoneTaskKeeper(startKey) + cnt := 0 + + for { + select { + case <-egCtx.Done(): + return egCtx.Err() + case result, ok := <-scheduler.resultChan(): + if !ok { + logutil.DDLLogger().Info("backfill workers successfully processed", + zap.Stringer("element", reorgInfo.currElement), + zap.Int64("total added count", totalAddedCount), + zap.String("start key", hex.EncodeToString(startKey))) + return nil + } + cnt++ + + if result.err != nil { + logutil.DDLLogger().Warn("backfill worker failed", + zap.Int64("job ID", reorgInfo.ID), + zap.Int64("total added count", totalAddedCount), + zap.String("start key", hex.EncodeToString(startKey)), + zap.String("result next key", hex.EncodeToString(result.nextKey)), + zap.Error(result.err)) + return result.err + } + + if result.totalCount > 0 { + totalAddedCount = int64(result.totalCount) + } else { + totalAddedCount += int64(result.addedCount) + } + dc.getReorgCtx(reorgInfo.Job.ID).setRowCount(totalAddedCount) + + keeper.updateNextKey(result.taskID, result.nextKey) + + if cnt%(scheduler.currentWorkerSize()*4) == 0 { + err2 := reorgInfo.UpdateReorgMeta(keeper.nextKey, sessPool) + if err2 != nil { + logutil.DDLLogger().Warn("update reorg meta failed", + zap.Int64("job ID", reorgInfo.ID), + zap.Error(err2)) + } + // We try to adjust the worker size regularly to reduce + // the overhead of loading the DDL related global variables. + err2 = scheduler.adjustWorkerSize() + if err2 != nil { + logutil.DDLLogger().Warn("cannot adjust backfill worker size", + zap.Int64("job ID", reorgInfo.ID), + zap.Error(err2)) + } + } + } + } + }) + + // generate task goroutine + eg.Go(func() error { + // we will modify the startKey in this goroutine, so copy them to avoid race. 
+ start, end := startKey, endKey + taskIDAlloc := newTaskIDAllocator() + for { + kvRanges, err2 := loadTableRanges(egCtx, t, dc.store, start, end, backfillTaskChanSize) + if err2 != nil { + return errors.Trace(err2) + } + if len(kvRanges) == 0 { + break + } + logutil.DDLLogger().Info("start backfill workers to reorg record", + zap.Stringer("type", bfWorkerType), + zap.Int("workerCnt", scheduler.currentWorkerSize()), + zap.Int("regionCnt", len(kvRanges)), + zap.String("startKey", hex.EncodeToString(start)), + zap.String("endKey", hex.EncodeToString(end))) + + err2 = sendTasks(scheduler, t, kvRanges, reorgInfo, taskIDAlloc, bfWorkerType) + if err2 != nil { + return errors.Trace(err2) + } + + start = kvRanges[len(kvRanges)-1].EndKey + if start.Cmp(end) >= 0 { + break + } + } + + scheduler.close(false) + return nil + }) + + return eg.Wait() +} + +func injectCheckBackfillWorkerNum(curWorkerSize int, isMergeWorker bool) error { + if isMergeWorker { + return nil + } + failpoint.Inject("checkBackfillWorkerNum", func(val failpoint.Value) { + //nolint:forcetypeassert + if val.(bool) { + num := int(atomic.LoadInt32(&TestCheckWorkerNumber)) + if num != 0 { + if num != curWorkerSize { + failpoint.Return(errors.Errorf("expected backfill worker num: %v, actual record num: %v", num, curWorkerSize)) + } + var wg sync.WaitGroup + wg.Add(1) + TestCheckWorkerNumCh <- &wg + wg.Wait() + } + } + }) + return nil +} + +// recordIterFunc is used for low-level record iteration. +type recordIterFunc func(h kv.Handle, rowKey kv.Key, rawRecord []byte) (more bool, err error) + +func iterateSnapshotKeys(ctx *JobContext, store kv.Storage, priority int, keyPrefix kv.Key, version uint64, + startKey kv.Key, endKey kv.Key, fn recordIterFunc) error { + isRecord := tablecodec.IsRecordKey(keyPrefix.Next()) + var firstKey kv.Key + if startKey == nil { + firstKey = keyPrefix + } else { + firstKey = startKey + } + + var upperBound kv.Key + if endKey == nil { + upperBound = keyPrefix.PrefixNext() + } else { + upperBound = endKey.PrefixNext() + } + + ver := kv.Version{Ver: version} + snap := store.GetSnapshot(ver) + snap.SetOption(kv.Priority, priority) + snap.SetOption(kv.RequestSourceInternal, true) + snap.SetOption(kv.RequestSourceType, ctx.ddlJobSourceType()) + snap.SetOption(kv.ExplicitRequestSourceType, kvutil.ExplicitTypeDDL) + if tagger := ctx.getResourceGroupTaggerForTopSQL(); tagger != nil { + snap.SetOption(kv.ResourceGroupTagger, tagger) + } + snap.SetOption(kv.ResourceGroupName, ctx.resourceGroupName) + + it, err := snap.Iter(firstKey, upperBound) + if err != nil { + return errors.Trace(err) + } + defer it.Close() + + for it.Valid() { + if !it.Key().HasPrefix(keyPrefix) { + break + } + + var handle kv.Handle + if isRecord { + handle, err = tablecodec.DecodeRowKey(it.Key()) + if err != nil { + return errors.Trace(err) + } + } + + more, err := fn(handle, it.Key(), it.Value()) + if !more || err != nil { + return errors.Trace(err) + } + + err = kv.NextUntil(it, util.RowKeyPrefixFilter(it.Key())) + if err != nil { + if kv.ErrNotExist.Equal(err) { + break + } + return errors.Trace(err) + } + } + + return nil +} + +// GetRangeEndKey gets the actual end key for the range of [startKey, endKey). 
+func GetRangeEndKey(ctx *JobContext, store kv.Storage, priority int, keyPrefix kv.Key, startKey, endKey kv.Key) (kv.Key, error) { + snap := store.GetSnapshot(kv.MaxVersion) + snap.SetOption(kv.Priority, priority) + if tagger := ctx.getResourceGroupTaggerForTopSQL(); tagger != nil { + snap.SetOption(kv.ResourceGroupTagger, tagger) + } + snap.SetOption(kv.ResourceGroupName, ctx.resourceGroupName) + snap.SetOption(kv.RequestSourceInternal, true) + snap.SetOption(kv.RequestSourceType, ctx.ddlJobSourceType()) + snap.SetOption(kv.ExplicitRequestSourceType, kvutil.ExplicitTypeDDL) + it, err := snap.IterReverse(endKey, nil) + if err != nil { + return nil, errors.Trace(err) + } + defer it.Close() + + if !it.Valid() || !it.Key().HasPrefix(keyPrefix) { + return startKey.Next(), nil + } + if it.Key().Cmp(startKey) < 0 { + return startKey.Next(), nil + } + + return it.Key().Next(), nil +} + +func mergeWarningsAndWarningsCount(partWarnings, totalWarnings map[errors.ErrorID]*terror.Error, partWarningsCount, totalWarningsCount map[errors.ErrorID]int64) (map[errors.ErrorID]*terror.Error, map[errors.ErrorID]int64) { + for _, warn := range partWarnings { + if _, ok := totalWarningsCount[warn.ID()]; ok { + totalWarningsCount[warn.ID()] += partWarningsCount[warn.ID()] + } else { + totalWarningsCount[warn.ID()] = partWarningsCount[warn.ID()] + totalWarnings[warn.ID()] = warn + } + } + return totalWarnings, totalWarningsCount +} + +func logSlowOperations(elapsed time.Duration, slowMsg string, threshold uint32) { + if threshold == 0 { + threshold = atomic.LoadUint32(&variable.DDLSlowOprThreshold) + } + + if elapsed >= time.Duration(threshold)*time.Millisecond { + logutil.DDLLogger().Info("slow operations", + zap.Duration("takeTimes", elapsed), + zap.String("msg", slowMsg)) + } +} + +// doneTaskKeeper keeps the done tasks and update the latest next key. 
+type doneTaskKeeper struct { + doneTaskNextKey map[int]kv.Key + current int + nextKey kv.Key +} + +func newDoneTaskKeeper(start kv.Key) *doneTaskKeeper { + return &doneTaskKeeper{ + doneTaskNextKey: make(map[int]kv.Key), + current: 0, + nextKey: start, + } +} + +func (n *doneTaskKeeper) updateNextKey(doneTaskID int, next kv.Key) { + if doneTaskID == n.current { + n.current++ + n.nextKey = next + for { + nKey, ok := n.doneTaskNextKey[n.current] + if !ok { + break + } + delete(n.doneTaskNextKey, n.current) + n.current++ + n.nextKey = nKey + } + return + } + n.doneTaskNextKey[doneTaskID] = next +} diff --git a/pkg/ddl/backfilling_dist_scheduler.go b/pkg/ddl/backfilling_dist_scheduler.go index 55abe4225da12..b74cbfa6f0cae 100644 --- a/pkg/ddl/backfilling_dist_scheduler.go +++ b/pkg/ddl/backfilling_dist_scheduler.go @@ -105,15 +105,15 @@ func (sch *BackfillingSchedulerExt) OnNextSubtasksBatch( return generateMergePlan(taskHandle, task, logger) case proto.BackfillStepWriteAndIngest: if sch.GlobalSort { - failpoint.Inject("mockWriteIngest", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockWriteIngest")); _err_ == nil { m := &BackfillSubTaskMeta{ MetaGroups: []*external.SortedKVMeta{}, } metaBytes, _ := json.Marshal(m) metaArr := make([][]byte, 0, 16) metaArr = append(metaArr, metaBytes) - failpoint.Return(metaArr, nil) - }) + return metaArr, nil + } return generateGlobalSortIngestPlan( ctx, sch.d.store.(kv.StorageWithPD), @@ -148,9 +148,9 @@ func (sch *BackfillingSchedulerExt) GetNextStep(task *proto.TaskBase) proto.Step } func skipMergeSort(stats []external.MultipleFilesStat) bool { - failpoint.Inject("forceMergeSort", func() { - failpoint.Return(false) - }) + if _, _err_ := failpoint.Eval(_curpkg_("forceMergeSort")); _err_ == nil { + return false + } return external.GetMaxOverlappingTotal(stats) <= external.MergeSortOverlapThreshold } @@ -330,9 +330,9 @@ func generateNonPartitionPlan( } func calculateRegionBatch(totalRegionCnt int, instanceCnt int, useLocalDisk bool) int { - failpoint.Inject("mockRegionBatch", func(val failpoint.Value) { - failpoint.Return(val.(int)) - }) + if val, _err_ := failpoint.Eval(_curpkg_("mockRegionBatch")); _err_ == nil { + return val.(int) + } var regionBatch int avgTasksPerInstance := (totalRegionCnt + instanceCnt - 1) / instanceCnt // ceiling if useLocalDisk { @@ -427,10 +427,10 @@ func splitSubtaskMetaForOneKVMetaGroup( return nil, err } ts := oracle.ComposeTS(p, l) - failpoint.Inject("mockTSForGlobalSort", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockTSForGlobalSort")); _err_ == nil { i := val.(int) ts = uint64(i) - }) + } splitter, err := getRangeSplitter( ctx, store, cloudStorageURI, int64(kvMeta.TotalKVSize), instanceCnt, kvMeta.MultipleFilesStats, logger) if err != nil { diff --git a/pkg/ddl/backfilling_dist_scheduler.go__failpoint_stash__ b/pkg/ddl/backfilling_dist_scheduler.go__failpoint_stash__ new file mode 100644 index 0000000000000..55abe4225da12 --- /dev/null +++ b/pkg/ddl/backfilling_dist_scheduler.go__failpoint_stash__ @@ -0,0 +1,639 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl + +import ( + "bytes" + "context" + "encoding/hex" + "encoding/json" + "math" + "sort" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/pkg/ddl/ingest" + "github.com/pingcap/tidb/pkg/ddl/logutil" + "github.com/pingcap/tidb/pkg/disttask/framework/handle" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" + diststorage "github.com/pingcap/tidb/pkg/disttask/framework/storage" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/lightning/backend/external" + "github.com/pingcap/tidb/pkg/lightning/backend/local" + "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/store/helper" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/util/backoff" + "github.com/tikv/client-go/v2/oracle" + "github.com/tikv/client-go/v2/tikv" + "go.uber.org/zap" +) + +// BackfillingSchedulerExt is an extension of litBackfillScheduler, exported for test. +type BackfillingSchedulerExt struct { + d *ddl + GlobalSort bool +} + +// NewBackfillingSchedulerExt creates a new backfillingSchedulerExt, only used for test now. +func NewBackfillingSchedulerExt(d DDL) (scheduler.Extension, error) { + ddl, ok := d.(*ddl) + if !ok { + return nil, errors.New("The getDDL result should be the type of *ddl") + } + return &BackfillingSchedulerExt{ + d: ddl, + }, nil +} + +var _ scheduler.Extension = (*BackfillingSchedulerExt)(nil) + +// OnTick implements scheduler.Extension interface. +func (*BackfillingSchedulerExt) OnTick(_ context.Context, _ *proto.Task) { +} + +// OnNextSubtasksBatch generate batch of next step's plan. +func (sch *BackfillingSchedulerExt) OnNextSubtasksBatch( + ctx context.Context, + taskHandle diststorage.TaskHandle, + task *proto.Task, + execIDs []string, + nextStep proto.Step, +) (subtaskMeta [][]byte, err error) { + logger := logutil.DDLLogger().With( + zap.Stringer("type", task.Type), + zap.Int64("task-id", task.ID), + zap.String("curr-step", proto.Step2Str(task.Type, task.Step)), + zap.String("next-step", proto.Step2Str(task.Type, nextStep)), + ) + var backfillMeta BackfillTaskMeta + if err := json.Unmarshal(task.Meta, &backfillMeta); err != nil { + return nil, err + } + job := &backfillMeta.Job + tblInfo, err := getTblInfo(ctx, sch.d, job) + if err != nil { + return nil, err + } + logger.Info("on next subtasks batch") + + // TODO: use planner. 
+ switch nextStep { + case proto.BackfillStepReadIndex: + if tblInfo.Partition != nil { + return generatePartitionPlan(tblInfo) + } + return generateNonPartitionPlan(ctx, sch.d, tblInfo, job, sch.GlobalSort, len(execIDs)) + case proto.BackfillStepMergeSort: + return generateMergePlan(taskHandle, task, logger) + case proto.BackfillStepWriteAndIngest: + if sch.GlobalSort { + failpoint.Inject("mockWriteIngest", func() { + m := &BackfillSubTaskMeta{ + MetaGroups: []*external.SortedKVMeta{}, + } + metaBytes, _ := json.Marshal(m) + metaArr := make([][]byte, 0, 16) + metaArr = append(metaArr, metaBytes) + failpoint.Return(metaArr, nil) + }) + return generateGlobalSortIngestPlan( + ctx, + sch.d.store.(kv.StorageWithPD), + taskHandle, + task, + backfillMeta.CloudStorageURI, + logger) + } + return nil, nil + default: + return nil, nil + } +} + +// GetNextStep implements scheduler.Extension interface. +func (sch *BackfillingSchedulerExt) GetNextStep(task *proto.TaskBase) proto.Step { + switch task.Step { + case proto.StepInit: + return proto.BackfillStepReadIndex + case proto.BackfillStepReadIndex: + if sch.GlobalSort { + return proto.BackfillStepMergeSort + } + return proto.StepDone + case proto.BackfillStepMergeSort: + return proto.BackfillStepWriteAndIngest + case proto.BackfillStepWriteAndIngest: + return proto.StepDone + default: + return proto.StepDone + } +} + +func skipMergeSort(stats []external.MultipleFilesStat) bool { + failpoint.Inject("forceMergeSort", func() { + failpoint.Return(false) + }) + return external.GetMaxOverlappingTotal(stats) <= external.MergeSortOverlapThreshold +} + +// OnDone implements scheduler.Extension interface. +func (*BackfillingSchedulerExt) OnDone(_ context.Context, _ diststorage.TaskHandle, _ *proto.Task) error { + return nil +} + +// GetEligibleInstances implements scheduler.Extension interface. +func (*BackfillingSchedulerExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]string, error) { + return nil, nil +} + +// IsRetryableErr implements scheduler.Extension.IsRetryableErr interface. +func (*BackfillingSchedulerExt) IsRetryableErr(error) bool { + return true +} + +// LitBackfillScheduler wraps BaseScheduler. +type LitBackfillScheduler struct { + *scheduler.BaseScheduler + d *ddl +} + +func newLitBackfillScheduler(ctx context.Context, d *ddl, task *proto.Task, param scheduler.Param) scheduler.Scheduler { + sch := LitBackfillScheduler{ + d: d, + BaseScheduler: scheduler.NewBaseScheduler(ctx, task, param), + } + return &sch +} + +// Init implements BaseScheduler interface. +func (sch *LitBackfillScheduler) Init() (err error) { + taskMeta := &BackfillTaskMeta{} + if err = json.Unmarshal(sch.BaseScheduler.GetTask().Meta, taskMeta); err != nil { + return errors.Annotate(err, "unmarshal task meta failed") + } + sch.BaseScheduler.Extension = &BackfillingSchedulerExt{ + d: sch.d, + GlobalSort: len(taskMeta.CloudStorageURI) > 0} + return sch.BaseScheduler.Init() +} + +// Close implements BaseScheduler interface. 
+func (sch *LitBackfillScheduler) Close() { + sch.BaseScheduler.Close() +} + +func getTblInfo(ctx context.Context, d *ddl, job *model.Job) (tblInfo *model.TableInfo, err error) { + err = kv.RunInNewTxn(ctx, d.store, true, func(_ context.Context, txn kv.Transaction) error { + tblInfo, err = meta.NewMeta(txn).GetTable(job.SchemaID, job.TableID) + return err + }) + if err != nil { + return nil, err + } + + return tblInfo, nil +} + +func generatePartitionPlan(tblInfo *model.TableInfo) (metas [][]byte, err error) { + defs := tblInfo.Partition.Definitions + physicalIDs := make([]int64, len(defs)) + for i := range defs { + physicalIDs[i] = defs[i].ID + } + + subTaskMetas := make([][]byte, 0, len(physicalIDs)) + for _, physicalID := range physicalIDs { + subTaskMeta := &BackfillSubTaskMeta{ + PhysicalTableID: physicalID, + } + + metaBytes, err := json.Marshal(subTaskMeta) + if err != nil { + return nil, err + } + + subTaskMetas = append(subTaskMetas, metaBytes) + } + return subTaskMetas, nil +} + +const ( + scanRegionBackoffBase = 200 * time.Millisecond + scanRegionBackoffMax = 2 * time.Second +) + +func generateNonPartitionPlan( + ctx context.Context, + d *ddl, + tblInfo *model.TableInfo, + job *model.Job, + useCloud bool, + instanceCnt int, +) (metas [][]byte, err error) { + tbl, err := getTable(d.ddlCtx.getAutoIDRequirement(), job.SchemaID, tblInfo) + if err != nil { + return nil, err + } + ver, err := getValidCurrentVersion(d.store) + if err != nil { + return nil, errors.Trace(err) + } + + startKey, endKey, err := getTableRange(d.jobContext(job.ID, job.ReorgMeta), d.ddlCtx, tbl.(table.PhysicalTable), ver.Ver, job.Priority) + if startKey == nil && endKey == nil { + // Empty table. + return nil, nil + } + if err != nil { + return nil, errors.Trace(err) + } + + subTaskMetas := make([][]byte, 0, 4) + backoffer := backoff.NewExponential(scanRegionBackoffBase, 2, scanRegionBackoffMax) + err = handle.RunWithRetry(ctx, 8, backoffer, logutil.DDLLogger(), func(_ context.Context) (bool, error) { + regionCache := d.store.(helper.Storage).GetRegionCache() + recordRegionMetas, err := regionCache.LoadRegionsInKeyRange(tikv.NewBackofferWithVars(context.Background(), 20000, nil), startKey, endKey) + + if err != nil { + return false, err + } + sort.Slice(recordRegionMetas, func(i, j int) bool { + return bytes.Compare(recordRegionMetas[i].StartKey(), recordRegionMetas[j].StartKey()) < 0 + }) + + // Check if regions are continuous. 
+ shouldRetry := false + cur := recordRegionMetas[0] + for _, m := range recordRegionMetas[1:] { + if !bytes.Equal(cur.EndKey(), m.StartKey()) { + shouldRetry = true + break + } + cur = m + } + + if shouldRetry { + return true, nil + } + + regionBatch := calculateRegionBatch(len(recordRegionMetas), instanceCnt, !useCloud) + + for i := 0; i < len(recordRegionMetas); i += regionBatch { + end := i + regionBatch + if end > len(recordRegionMetas) { + end = len(recordRegionMetas) + } + batch := recordRegionMetas[i:end] + subTaskMeta := &BackfillSubTaskMeta{ + RowStart: batch[0].StartKey(), + RowEnd: batch[len(batch)-1].EndKey(), + } + if i == 0 { + subTaskMeta.RowStart = startKey + } + if end == len(recordRegionMetas) { + subTaskMeta.RowEnd = endKey + } + metaBytes, err := json.Marshal(subTaskMeta) + if err != nil { + return false, err + } + subTaskMetas = append(subTaskMetas, metaBytes) + } + return false, nil + }) + if err != nil { + return nil, errors.Trace(err) + } + if len(subTaskMetas) == 0 { + return nil, errors.Errorf("regions are not continuous") + } + return subTaskMetas, nil +} + +func calculateRegionBatch(totalRegionCnt int, instanceCnt int, useLocalDisk bool) int { + failpoint.Inject("mockRegionBatch", func(val failpoint.Value) { + failpoint.Return(val.(int)) + }) + var regionBatch int + avgTasksPerInstance := (totalRegionCnt + instanceCnt - 1) / instanceCnt // ceiling + if useLocalDisk { + regionBatch = avgTasksPerInstance + } else { + // For cloud storage, each subtask should contain no more than 4000 regions. + regionBatch = min(4000, avgTasksPerInstance) + } + regionBatch = max(regionBatch, 1) + return regionBatch +} + +func generateGlobalSortIngestPlan( + ctx context.Context, + store kv.StorageWithPD, + taskHandle diststorage.TaskHandle, + task *proto.Task, + cloudStorageURI string, + logger *zap.Logger, +) ([][]byte, error) { + var ( + kvMetaGroups []*external.SortedKVMeta + eleIDs []int64 + ) + for _, step := range []proto.Step{proto.BackfillStepMergeSort, proto.BackfillStepReadIndex} { + hasSubtasks := false + err := forEachBackfillSubtaskMeta(taskHandle, task.ID, step, func(subtask *BackfillSubTaskMeta) { + hasSubtasks = true + if kvMetaGroups == nil { + kvMetaGroups = make([]*external.SortedKVMeta, len(subtask.MetaGroups)) + eleIDs = subtask.EleIDs + } + for i, cur := range subtask.MetaGroups { + if kvMetaGroups[i] == nil { + kvMetaGroups[i] = &external.SortedKVMeta{} + } + kvMetaGroups[i].Merge(cur) + } + }) + if err != nil { + return nil, err + } + if hasSubtasks { + break + } + // If there is no subtask for merge sort step, + // it means the merge sort step is skipped. + } + + instanceIDs, err := scheduler.GetLiveExecIDs(ctx) + if err != nil { + return nil, err + } + iCnt := int64(len(instanceIDs)) + metaArr := make([][]byte, 0, 16) + for i, g := range kvMetaGroups { + if g == nil { + logger.Error("meet empty kv group when getting subtask summary", + zap.Int64("taskID", task.ID)) + return nil, errors.Errorf("subtask kv group %d is empty", i) + } + eleID := int64(0) + // in case the subtask metadata is written by an old version of TiDB. + if i < len(eleIDs) { + eleID = eleIDs[i] + } + newMeta, err := splitSubtaskMetaForOneKVMetaGroup(ctx, store, g, eleID, cloudStorageURI, iCnt, logger) + if err != nil { + return nil, errors.Trace(err) + } + metaArr = append(metaArr, newMeta...) 
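As an aside on the plan splitting above, the batching arithmetic in calculateRegionBatch reduces to a ceiling division clamped by the storage mode. A standalone sketch with a few example inputs (failpoint hook omitted; uses the Go 1.21 min and max builtins):

package main

import "fmt"

// Mirrors the calculateRegionBatch logic shown above, without the failpoint hook.
func calculateRegionBatch(totalRegionCnt, instanceCnt int, useLocalDisk bool) int {
	avgTasksPerInstance := (totalRegionCnt + instanceCnt - 1) / instanceCnt // ceiling division
	regionBatch := avgTasksPerInstance
	if !useLocalDisk {
		// For cloud storage, each subtask should contain no more than 4000 regions.
		regionBatch = min(4000, avgTasksPerInstance)
	}
	return max(regionBatch, 1)
}

func main() {
	fmt.Println(calculateRegionBatch(10, 3, true))      // 4: ceil(10/3)
	fmt.Println(calculateRegionBatch(100000, 4, false)) // 4000: capped for cloud storage
	fmt.Println(calculateRegionBatch(0, 4, true))       // 1: never below one region per batch
}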
+ } + return metaArr, nil +} + +func splitSubtaskMetaForOneKVMetaGroup( + ctx context.Context, + store kv.StorageWithPD, + kvMeta *external.SortedKVMeta, + eleID int64, + cloudStorageURI string, + instanceCnt int64, + logger *zap.Logger, +) (metaArr [][]byte, err error) { + if len(kvMeta.StartKey) == 0 && len(kvMeta.EndKey) == 0 { + // Skip global sort for empty table. + return nil, nil + } + pdCli := store.GetPDClient() + p, l, err := pdCli.GetTS(ctx) + if err != nil { + return nil, err + } + ts := oracle.ComposeTS(p, l) + failpoint.Inject("mockTSForGlobalSort", func(val failpoint.Value) { + i := val.(int) + ts = uint64(i) + }) + splitter, err := getRangeSplitter( + ctx, store, cloudStorageURI, int64(kvMeta.TotalKVSize), instanceCnt, kvMeta.MultipleFilesStats, logger) + if err != nil { + return nil, err + } + defer func() { + err := splitter.Close() + if err != nil { + logger.Error("failed to close range splitter", zap.Error(err)) + } + }() + + startKey := kvMeta.StartKey + var endKey kv.Key + for { + endKeyOfGroup, dataFiles, statFiles, rangeSplitKeys, err := splitter.SplitOneRangesGroup() + if err != nil { + return nil, err + } + if len(endKeyOfGroup) == 0 { + endKey = kvMeta.EndKey + } else { + endKey = kv.Key(endKeyOfGroup).Clone() + } + logger.Info("split subtask range", + zap.String("startKey", hex.EncodeToString(startKey)), + zap.String("endKey", hex.EncodeToString(endKey))) + + if bytes.Compare(startKey, endKey) >= 0 { + return nil, errors.Errorf("invalid range, startKey: %s, endKey: %s", + hex.EncodeToString(startKey), hex.EncodeToString(endKey)) + } + m := &BackfillSubTaskMeta{ + MetaGroups: []*external.SortedKVMeta{{ + StartKey: startKey, + EndKey: endKey, + TotalKVSize: kvMeta.TotalKVSize / uint64(instanceCnt), + }}, + DataFiles: dataFiles, + StatFiles: statFiles, + RangeSplitKeys: rangeSplitKeys, + TS: ts, + } + if eleID > 0 { + m.EleIDs = []int64{eleID} + } + metaBytes, err := json.Marshal(m) + if err != nil { + return nil, err + } + metaArr = append(metaArr, metaBytes) + if len(endKeyOfGroup) == 0 { + break + } + startKey = endKey + } + return metaArr, nil +} + +func generateMergePlan( + taskHandle diststorage.TaskHandle, + task *proto.Task, + logger *zap.Logger, +) ([][]byte, error) { + // check data files overlaps, + // if data files overlaps too much, we need a merge step. + var ( + multiStatsGroup [][]external.MultipleFilesStat + kvMetaGroups []*external.SortedKVMeta + eleIDs []int64 + ) + err := forEachBackfillSubtaskMeta(taskHandle, task.ID, proto.BackfillStepReadIndex, + func(subtask *BackfillSubTaskMeta) { + if kvMetaGroups == nil { + kvMetaGroups = make([]*external.SortedKVMeta, len(subtask.MetaGroups)) + multiStatsGroup = make([][]external.MultipleFilesStat, len(subtask.MetaGroups)) + eleIDs = subtask.EleIDs + } + for i, g := range subtask.MetaGroups { + if kvMetaGroups[i] == nil { + kvMetaGroups[i] = &external.SortedKVMeta{} + multiStatsGroup[i] = make([]external.MultipleFilesStat, 0, 100) + } + kvMetaGroups[i].Merge(g) + multiStatsGroup[i] = append(multiStatsGroup[i], g.MultipleFilesStats...) 
+ } + }) + if err != nil { + return nil, err + } + + allSkip := true + for _, multiStats := range multiStatsGroup { + if !skipMergeSort(multiStats) { + allSkip = false + break + } + } + if allSkip { + logger.Info("skip merge sort") + return nil, nil + } + + metaArr := make([][]byte, 0, 16) + for i, g := range kvMetaGroups { + dataFiles := make([]string, 0, 1000) + if g == nil { + logger.Error("meet empty kv group when getting subtask summary", + zap.Int64("taskID", task.ID)) + return nil, errors.Errorf("subtask kv group %d is empty", i) + } + for _, m := range g.MultipleFilesStats { + for _, filePair := range m.Filenames { + dataFiles = append(dataFiles, filePair[0]) + } + } + var eleID []int64 + if i < len(eleIDs) { + eleID = []int64{eleIDs[i]} + } + start := 0 + step := external.MergeSortFileCountStep + for start < len(dataFiles) { + end := start + step + if end > len(dataFiles) { + end = len(dataFiles) + } + m := &BackfillSubTaskMeta{ + DataFiles: dataFiles[start:end], + EleIDs: eleID, + } + metaBytes, err := json.Marshal(m) + if err != nil { + return nil, err + } + metaArr = append(metaArr, metaBytes) + + start = end + } + } + return metaArr, nil +} + +func getRangeSplitter( + ctx context.Context, + store kv.StorageWithPD, + cloudStorageURI string, + totalSize int64, + instanceCnt int64, + multiFileStat []external.MultipleFilesStat, + logger *zap.Logger, +) (*external.RangeSplitter, error) { + backend, err := storage.ParseBackend(cloudStorageURI, nil) + if err != nil { + return nil, err + } + extStore, err := storage.NewWithDefaultOpt(ctx, backend) + if err != nil { + return nil, err + } + + rangeGroupSize := totalSize / instanceCnt + rangeGroupKeys := int64(math.MaxInt64) + + var maxSizePerRange = int64(config.SplitRegionSize) + var maxKeysPerRange = int64(config.SplitRegionKeys) + if store != nil { + pdCli := store.GetPDClient() + tls, err := ingest.NewDDLTLS() + if err == nil { + size, keys, err := local.GetRegionSplitSizeKeys(ctx, pdCli, tls) + if err == nil { + maxSizePerRange = max(maxSizePerRange, size) + maxKeysPerRange = max(maxKeysPerRange, keys) + } else { + logger.Warn("fail to get region split keys and size", zap.Error(err)) + } + } else { + logger.Warn("fail to get region split keys and size", zap.Error(err)) + } + } + + return external.NewRangeSplitter(ctx, multiFileStat, extStore, + rangeGroupSize, rangeGroupKeys, maxSizePerRange, maxKeysPerRange) +} + +func forEachBackfillSubtaskMeta( + taskHandle diststorage.TaskHandle, + gTaskID int64, + step proto.Step, + fn func(subtask *BackfillSubTaskMeta), +) error { + subTaskMetas, err := taskHandle.GetPreviousSubtaskMetas(gTaskID, step) + if err != nil { + return errors.Trace(err) + } + for _, subTaskMeta := range subTaskMetas { + subtask, err := decodeBackfillSubTaskMeta(subTaskMeta) + if err != nil { + logutil.DDLLogger().Error("unmarshal error", zap.Error(err)) + return errors.Trace(err) + } + fn(subtask) + } + return nil +} diff --git a/pkg/ddl/backfilling_operators.go b/pkg/ddl/backfilling_operators.go index 5ae5554ef290b..a2d87df69a74f 100644 --- a/pkg/ddl/backfilling_operators.go +++ b/pkg/ddl/backfilling_operators.go @@ -243,9 +243,9 @@ func NewWriteIndexToExternalStoragePipeline( } memCap := resource.Mem.Capacity() memSizePerIndex := uint64(memCap / int64(writerCnt*2*len(idxInfos))) - failpoint.Inject("mockWriterMemSize", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockWriterMemSize")); _err_ == nil { memSizePerIndex = 1 * size.GB - }) + } srcOp := NewTableScanTaskSource(ctx, store, tbl, startKey, endKey, nil) 
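The subtask fan-out in generateMergePlan above is fixed-size slicing of the collected data files, external.MergeSortFileCountStep files per subtask. A minimal standalone sketch of that loop (the step value 2 below is only for illustration):

package main

import "fmt"

// batchFiles splits files into consecutive groups of at most step entries,
// mirroring the start/end loop in generateMergePlan.
func batchFiles(files []string, step int) [][]string {
	batches := make([][]string, 0, (len(files)+step-1)/step)
	for start := 0; start < len(files); start += step {
		end := min(start+step, len(files))
		batches = append(batches, files[start:end])
	}
	return batches
}

func main() {
	files := []string{"f1", "f2", "f3", "f4", "f5"}
	fmt.Println(batchFiles(files, 2)) // [[f1 f2] [f3 f4] [f5]]
}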
scanOp := NewTableScanOperator(ctx, sessPool, copCtx, srcChkPool, readerCnt, nil) @@ -492,9 +492,9 @@ func (w *tableScanWorker) HandleTask(task TableScanTask, sender func(IndexRecord w.ctx.onError(dbterror.ErrReorgPanic) }, false) - failpoint.Inject("injectPanicForTableScan", func() { + if _, _err_ := failpoint.Eval(_curpkg_("injectPanicForTableScan")); _err_ == nil { panic("mock panic") - }) + } if w.se == nil { sessCtx, err := w.sessPool.Get() if err != nil { @@ -519,10 +519,10 @@ func (w *tableScanWorker) scanRecords(task TableScanTask, sender func(IndexRecor var idxResult IndexRecordChunk err := wrapInBeginRollback(w.se, func(startTS uint64) error { - failpoint.Inject("mockScanRecordError", func() { - failpoint.Return(errors.New("mock scan record error")) - }) - failpoint.InjectCall("scanRecordExec") + if _, _err_ := failpoint.Eval(_curpkg_("mockScanRecordError")); _err_ == nil { + return errors.New("mock scan record error") + } + failpoint.Call(_curpkg_("scanRecordExec")) rs, err := buildTableScan(w.ctx, w.copCtx.GetBase(), startTS, task.Start, task.End) if err != nil { return err @@ -789,9 +789,9 @@ type indexIngestBaseWorker struct { } func (w *indexIngestBaseWorker) HandleTask(rs IndexRecordChunk) (IndexWriteResult, error) { - failpoint.Inject("injectPanicForIndexIngest", func() { + if _, _err_ := failpoint.Eval(_curpkg_("injectPanicForIndexIngest")); _err_ == nil { panic("mock panic") - }) + } result := IndexWriteResult{ ID: rs.ID, @@ -851,10 +851,10 @@ func (w *indexIngestBaseWorker) Close() { // WriteChunk will write index records to lightning engine. func (w *indexIngestBaseWorker) WriteChunk(rs *IndexRecordChunk) (count int, nextKey kv.Key, err error) { - failpoint.Inject("mockWriteLocalError", func(_ failpoint.Value) { - failpoint.Return(0, nil, errors.New("mock write local error")) - }) - failpoint.InjectCall("writeLocalExec", rs.Done) + if _, _err_ := failpoint.Eval(_curpkg_("mockWriteLocalError")); _err_ == nil { + return 0, nil, errors.New("mock write local error") + } + failpoint.Call(_curpkg_("writeLocalExec"), rs.Done) oprStartTime := time.Now() vars := w.se.GetSessionVars() @@ -934,9 +934,9 @@ func (s *indexWriteResultSink) flush() error { if s.backendCtx == nil { return nil } - failpoint.Inject("mockFlushError", func(_ failpoint.Value) { - failpoint.Return(errors.New("mock flush error")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("mockFlushError")); _err_ == nil { + return errors.New("mock flush error") + } flushed, imported, err := s.backendCtx.Flush(ingest.FlushModeForceFlushAndImport) if s.cpMgr != nil { // Try to advance watermark even if there is an error. diff --git a/pkg/ddl/backfilling_operators.go__failpoint_stash__ b/pkg/ddl/backfilling_operators.go__failpoint_stash__ new file mode 100644 index 0000000000000..5ae5554ef290b --- /dev/null +++ b/pkg/ddl/backfilling_operators.go__failpoint_stash__ @@ -0,0 +1,962 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ddl + +import ( + "context" + "encoding/hex" + "fmt" + "path" + "strconv" + "sync/atomic" + "time" + + "github.com/docker/go-units" + "github.com/google/uuid" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/pkg/ddl/copr" + "github.com/pingcap/tidb/pkg/ddl/ingest" + "github.com/pingcap/tidb/pkg/ddl/session" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/disttask/operator" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/lightning/backend/external" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/resourcemanager/pool/workerpool" + "github.com/pingcap/tidb/pkg/resourcemanager/util" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/tablecodec" + tidbutil "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/size" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" +) + +var ( + _ operator.Operator = (*TableScanTaskSource)(nil) + _ operator.WithSink[TableScanTask] = (*TableScanTaskSource)(nil) + + _ operator.WithSource[TableScanTask] = (*TableScanOperator)(nil) + _ operator.Operator = (*TableScanOperator)(nil) + _ operator.WithSink[IndexRecordChunk] = (*TableScanOperator)(nil) + + _ operator.WithSource[IndexRecordChunk] = (*IndexIngestOperator)(nil) + _ operator.Operator = (*IndexIngestOperator)(nil) + _ operator.WithSink[IndexWriteResult] = (*IndexIngestOperator)(nil) + + _ operator.WithSource[IndexWriteResult] = (*indexWriteResultSink)(nil) + _ operator.Operator = (*indexWriteResultSink)(nil) +) + +type opSessPool interface { + Get() (sessionctx.Context, error) + Put(sessionctx.Context) +} + +// OperatorCtx is the context for AddIndexIngestPipeline. +// This is used to cancel the pipeline and collect errors. +type OperatorCtx struct { + context.Context + cancel context.CancelFunc + err atomic.Pointer[error] +} + +// NewDistTaskOperatorCtx is used for adding index with dist framework. +func NewDistTaskOperatorCtx(ctx context.Context, taskID, subtaskID int64) *OperatorCtx { + opCtx, cancel := context.WithCancel(ctx) + opCtx = logutil.WithFields(opCtx, zap.Int64("task-id", taskID), zap.Int64("subtask-id", subtaskID)) + return &OperatorCtx{ + Context: opCtx, + cancel: cancel, + } +} + +// NewLocalOperatorCtx is used for adding index with local ingest mode. +func NewLocalOperatorCtx(ctx context.Context, jobID int64) *OperatorCtx { + opCtx, cancel := context.WithCancel(ctx) + opCtx = logutil.WithFields(opCtx, zap.Int64("jobID", jobID)) + return &OperatorCtx{ + Context: opCtx, + cancel: cancel, + } +} + +func (ctx *OperatorCtx) onError(err error) { + tracedErr := errors.Trace(err) + ctx.cancel() + ctx.err.CompareAndSwap(nil, &tracedErr) +} + +// Cancel cancels the pipeline. +func (ctx *OperatorCtx) Cancel() { + ctx.cancel() +} + +// OperatorErr returns the error of the operator. 
+func (ctx *OperatorCtx) OperatorErr() error { + err := ctx.err.Load() + if err == nil { + return nil + } + return *err +} + +var ( + _ RowCountListener = (*EmptyRowCntListener)(nil) + _ RowCountListener = (*distTaskRowCntListener)(nil) + _ RowCountListener = (*localRowCntListener)(nil) +) + +// RowCountListener is invoked when some index records are flushed to disk or imported to TiKV. +type RowCountListener interface { + Written(rowCnt int) + SetTotal(total int) +} + +// EmptyRowCntListener implements a noop RowCountListener. +type EmptyRowCntListener struct{} + +// Written implements RowCountListener. +func (*EmptyRowCntListener) Written(_ int) {} + +// SetTotal implements RowCountListener. +func (*EmptyRowCntListener) SetTotal(_ int) {} + +// NewAddIndexIngestPipeline creates a pipeline for adding index in ingest mode. +func NewAddIndexIngestPipeline( + ctx *OperatorCtx, + store kv.Storage, + sessPool opSessPool, + backendCtx ingest.BackendCtx, + engines []ingest.Engine, + jobID int64, + tbl table.PhysicalTable, + idxInfos []*model.IndexInfo, + startKey, endKey kv.Key, + reorgMeta *model.DDLReorgMeta, + avgRowSize int, + concurrency int, + cpMgr *ingest.CheckpointManager, + rowCntListener RowCountListener, +) (*operator.AsyncPipeline, error) { + indexes := make([]table.Index, 0, len(idxInfos)) + for _, idxInfo := range idxInfos { + index := tables.NewIndex(tbl.GetPhysicalID(), tbl.Meta(), idxInfo) + indexes = append(indexes, index) + } + reqSrc := getDDLRequestSource(model.ActionAddIndex) + copCtx, err := NewReorgCopContext(store, reorgMeta, tbl.Meta(), idxInfos, reqSrc) + if err != nil { + return nil, err + } + poolSize := copReadChunkPoolSize() + srcChkPool := make(chan *chunk.Chunk, poolSize) + for i := 0; i < poolSize; i++ { + srcChkPool <- chunk.NewChunkWithCapacity(copCtx.GetBase().FieldTypes, copReadBatchSize()) + } + readerCnt, writerCnt := expectedIngestWorkerCnt(concurrency, avgRowSize) + + srcOp := NewTableScanTaskSource(ctx, store, tbl, startKey, endKey, cpMgr) + scanOp := NewTableScanOperator(ctx, sessPool, copCtx, srcChkPool, readerCnt, cpMgr) + ingestOp := NewIndexIngestOperator(ctx, copCtx, backendCtx, sessPool, + tbl, indexes, engines, srcChkPool, writerCnt, reorgMeta, cpMgr, rowCntListener) + sinkOp := newIndexWriteResultSink(ctx, backendCtx, tbl, indexes, cpMgr, rowCntListener) + + operator.Compose[TableScanTask](srcOp, scanOp) + operator.Compose[IndexRecordChunk](scanOp, ingestOp) + operator.Compose[IndexWriteResult](ingestOp, sinkOp) + + logutil.Logger(ctx).Info("build add index local storage operators", + zap.Int64("jobID", jobID), + zap.Int("avgRowSize", avgRowSize), + zap.Int("reader", readerCnt), + zap.Int("writer", writerCnt)) + + return operator.NewAsyncPipeline( + srcOp, scanOp, ingestOp, sinkOp, + ), nil +} + +// NewWriteIndexToExternalStoragePipeline creates a pipeline for writing index to external storage. 
+func NewWriteIndexToExternalStoragePipeline( + ctx *OperatorCtx, + store kv.Storage, + extStoreURI string, + sessPool opSessPool, + jobID, subtaskID int64, + tbl table.PhysicalTable, + idxInfos []*model.IndexInfo, + startKey, endKey kv.Key, + onClose external.OnCloseFunc, + reorgMeta *model.DDLReorgMeta, + avgRowSize int, + concurrency int, + resource *proto.StepResource, + rowCntListener RowCountListener, +) (*operator.AsyncPipeline, error) { + indexes := make([]table.Index, 0, len(idxInfos)) + for _, idxInfo := range idxInfos { + index := tables.NewIndex(tbl.GetPhysicalID(), tbl.Meta(), idxInfo) + indexes = append(indexes, index) + } + reqSrc := getDDLRequestSource(model.ActionAddIndex) + copCtx, err := NewReorgCopContext(store, reorgMeta, tbl.Meta(), idxInfos, reqSrc) + if err != nil { + return nil, err + } + poolSize := copReadChunkPoolSize() + srcChkPool := make(chan *chunk.Chunk, poolSize) + for i := 0; i < poolSize; i++ { + srcChkPool <- chunk.NewChunkWithCapacity(copCtx.GetBase().FieldTypes, copReadBatchSize()) + } + readerCnt, writerCnt := expectedIngestWorkerCnt(concurrency, avgRowSize) + + backend, err := storage.ParseBackend(extStoreURI, nil) + if err != nil { + return nil, err + } + extStore, err := storage.NewWithDefaultOpt(ctx, backend) + if err != nil { + return nil, err + } + memCap := resource.Mem.Capacity() + memSizePerIndex := uint64(memCap / int64(writerCnt*2*len(idxInfos))) + failpoint.Inject("mockWriterMemSize", func() { + memSizePerIndex = 1 * size.GB + }) + + srcOp := NewTableScanTaskSource(ctx, store, tbl, startKey, endKey, nil) + scanOp := NewTableScanOperator(ctx, sessPool, copCtx, srcChkPool, readerCnt, nil) + writeOp := NewWriteExternalStoreOperator( + ctx, copCtx, sessPool, jobID, subtaskID, tbl, indexes, extStore, srcChkPool, writerCnt, onClose, memSizePerIndex, reorgMeta) + sinkOp := newIndexWriteResultSink(ctx, nil, tbl, indexes, nil, rowCntListener) + + operator.Compose[TableScanTask](srcOp, scanOp) + operator.Compose[IndexRecordChunk](scanOp, writeOp) + operator.Compose[IndexWriteResult](writeOp, sinkOp) + + logutil.Logger(ctx).Info("build add index cloud storage operators", + zap.Int64("jobID", jobID), + zap.String("memCap", units.BytesSize(float64(memCap))), + zap.String("memSizePerIdx", units.BytesSize(float64(memSizePerIndex))), + zap.Int("avgRowSize", avgRowSize), + zap.Int("reader", readerCnt), + zap.Int("writer", writerCnt)) + + return operator.NewAsyncPipeline( + srcOp, scanOp, writeOp, sinkOp, + ), nil +} + +// TableScanTask contains the start key and the end key of a region. +type TableScanTask struct { + ID int + Start kv.Key + End kv.Key +} + +// String implement fmt.Stringer interface. +func (t TableScanTask) String() string { + return fmt.Sprintf("TableScanTask: id=%d, startKey=%s, endKey=%s", + t.ID, hex.EncodeToString(t.Start), hex.EncodeToString(t.End)) +} + +// IndexRecordChunk contains one of the chunk read from corresponding TableScanTask. +type IndexRecordChunk struct { + ID int + Chunk *chunk.Chunk + Err error + Done bool +} + +// TableScanTaskSource produces TableScanTask by splitting table records into ranges. +type TableScanTaskSource struct { + ctx context.Context + + errGroup errgroup.Group + sink operator.DataChannel[TableScanTask] + + tbl table.PhysicalTable + store kv.Storage + startKey kv.Key + endKey kv.Key + + // only used in local ingest + cpMgr *ingest.CheckpointManager +} + +// NewTableScanTaskSource creates a new TableScanTaskSource. 
+func NewTableScanTaskSource( + ctx context.Context, + store kv.Storage, + physicalTable table.PhysicalTable, + startKey kv.Key, + endKey kv.Key, + cpMgr *ingest.CheckpointManager, +) *TableScanTaskSource { + return &TableScanTaskSource{ + ctx: ctx, + errGroup: errgroup.Group{}, + tbl: physicalTable, + store: store, + startKey: startKey, + endKey: endKey, + cpMgr: cpMgr, + } +} + +// SetSink implements WithSink interface. +func (src *TableScanTaskSource) SetSink(sink operator.DataChannel[TableScanTask]) { + src.sink = sink +} + +// Open implements Operator interface. +func (src *TableScanTaskSource) Open() error { + src.errGroup.Go(src.generateTasks) + return nil +} + +// adjustStartKey adjusts the start key so that we can skip the ranges that have been processed +// according to the information of checkpoint manager. +func (src *TableScanTaskSource) adjustStartKey(start, end kv.Key) (adjusted kv.Key, done bool) { + if src.cpMgr == nil { + return start, false + } + cpKey := src.cpMgr.LastProcessedKey() + if len(cpKey) == 0 { + return start, false + } + if cpKey.Cmp(start) < 0 || cpKey.Cmp(end) > 0 { + logutil.Logger(src.ctx).Error("invalid checkpoint key", + zap.String("last_process_key", hex.EncodeToString(cpKey)), + zap.String("start", hex.EncodeToString(start)), + zap.String("end", hex.EncodeToString(end)), + ) + if intest.InTest { + panic("invalid checkpoint key") + } + return start, false + } + if cpKey.Cmp(end) == 0 { + return cpKey, true + } + return cpKey.Next(), false +} + +func (src *TableScanTaskSource) generateTasks() error { + taskIDAlloc := newTaskIDAllocator() + defer src.sink.Finish() + + startKey, done := src.adjustStartKey(src.startKey, src.endKey) + if done { + // All table data are done. + return nil + } + for { + kvRanges, err := loadTableRanges( + src.ctx, + src.tbl, + src.store, + startKey, + src.endKey, + backfillTaskChanSize, + ) + if err != nil { + return err + } + if len(kvRanges) == 0 { + break + } + + batchTasks := src.getBatchTableScanTask(kvRanges, taskIDAlloc) + for _, task := range batchTasks { + select { + case <-src.ctx.Done(): + return src.ctx.Err() + case src.sink.Channel() <- task: + } + } + startKey = kvRanges[len(kvRanges)-1].EndKey + if startKey.Cmp(src.endKey) >= 0 { + break + } + } + return nil +} + +func (src *TableScanTaskSource) getBatchTableScanTask( + kvRanges []kv.KeyRange, + taskIDAlloc *taskIDAllocator, +) []TableScanTask { + batchTasks := make([]TableScanTask, 0, len(kvRanges)) + prefix := src.tbl.RecordPrefix() + // Build reorg tasks. + for _, keyRange := range kvRanges { + taskID := taskIDAlloc.alloc() + startKey := keyRange.StartKey + if len(startKey) == 0 { + startKey = prefix + } + endKey := keyRange.EndKey + if len(endKey) == 0 { + endKey = prefix.PrefixNext() + } + + task := TableScanTask{ + ID: taskID, + Start: startKey, + End: endKey, + } + batchTasks = append(batchTasks, task) + } + return batchTasks +} + +// Close implements Operator interface. +func (src *TableScanTaskSource) Close() error { + return src.errGroup.Wait() +} + +// String implements fmt.Stringer interface. +func (*TableScanTaskSource) String() string { + return "TableScanTaskSource" +} + +// TableScanOperator scans table records in given key ranges from kv store. +type TableScanOperator struct { + *operator.AsyncOperator[TableScanTask, IndexRecordChunk] +} + +// NewTableScanOperator creates a new TableScanOperator. 
+func NewTableScanOperator( + ctx *OperatorCtx, + sessPool opSessPool, + copCtx copr.CopContext, + srcChkPool chan *chunk.Chunk, + concurrency int, + cpMgr *ingest.CheckpointManager, +) *TableScanOperator { + pool := workerpool.NewWorkerPool( + "TableScanOperator", + util.DDL, + concurrency, + func() workerpool.Worker[TableScanTask, IndexRecordChunk] { + return &tableScanWorker{ + ctx: ctx, + copCtx: copCtx, + sessPool: sessPool, + se: nil, + srcChkPool: srcChkPool, + cpMgr: cpMgr, + } + }) + return &TableScanOperator{ + AsyncOperator: operator.NewAsyncOperator[TableScanTask, IndexRecordChunk](ctx, pool), + } +} + +type tableScanWorker struct { + ctx *OperatorCtx + copCtx copr.CopContext + sessPool opSessPool + se *session.Session + srcChkPool chan *chunk.Chunk + + cpMgr *ingest.CheckpointManager +} + +func (w *tableScanWorker) HandleTask(task TableScanTask, sender func(IndexRecordChunk)) { + defer tidbutil.Recover(metrics.LblAddIndex, "handleTableScanTaskWithRecover", func() { + w.ctx.onError(dbterror.ErrReorgPanic) + }, false) + + failpoint.Inject("injectPanicForTableScan", func() { + panic("mock panic") + }) + if w.se == nil { + sessCtx, err := w.sessPool.Get() + if err != nil { + logutil.Logger(w.ctx).Error("tableScanWorker get session from pool failed", zap.Error(err)) + w.ctx.onError(err) + return + } + w.se = session.NewSession(sessCtx) + } + w.scanRecords(task, sender) +} + +func (w *tableScanWorker) Close() { + if w.se != nil { + w.sessPool.Put(w.se.Context) + } +} + +func (w *tableScanWorker) scanRecords(task TableScanTask, sender func(IndexRecordChunk)) { + logutil.Logger(w.ctx).Info("start a table scan task", + zap.Int("id", task.ID), zap.Stringer("task", task)) + + var idxResult IndexRecordChunk + err := wrapInBeginRollback(w.se, func(startTS uint64) error { + failpoint.Inject("mockScanRecordError", func() { + failpoint.Return(errors.New("mock scan record error")) + }) + failpoint.InjectCall("scanRecordExec") + rs, err := buildTableScan(w.ctx, w.copCtx.GetBase(), startTS, task.Start, task.End) + if err != nil { + return err + } + if w.cpMgr != nil { + w.cpMgr.Register(task.ID, task.End) + } + var done bool + for !done { + srcChk := w.getChunk() + done, err = fetchTableScanResult(w.ctx, w.copCtx.GetBase(), rs, srcChk) + if err != nil || w.ctx.Err() != nil { + w.recycleChunk(srcChk) + terror.Call(rs.Close) + return err + } + idxResult = IndexRecordChunk{ID: task.ID, Chunk: srcChk, Done: done} + if w.cpMgr != nil { + w.cpMgr.UpdateTotalKeys(task.ID, srcChk.NumRows(), done) + } + sender(idxResult) + } + return rs.Close() + }) + if err != nil { + w.ctx.onError(err) + } +} + +func (w *tableScanWorker) getChunk() *chunk.Chunk { + chk := <-w.srcChkPool + newCap := copReadBatchSize() + if chk.Capacity() != newCap { + chk = chunk.NewChunkWithCapacity(w.copCtx.GetBase().FieldTypes, newCap) + } + chk.Reset() + return chk +} + +func (w *tableScanWorker) recycleChunk(chk *chunk.Chunk) { + w.srcChkPool <- chk +} + +// WriteExternalStoreOperator writes index records to external storage. +type WriteExternalStoreOperator struct { + *operator.AsyncOperator[IndexRecordChunk, IndexWriteResult] +} + +// NewWriteExternalStoreOperator creates a new WriteExternalStoreOperator. 
+func NewWriteExternalStoreOperator( + ctx *OperatorCtx, + copCtx copr.CopContext, + sessPool opSessPool, + jobID int64, + subtaskID int64, + tbl table.PhysicalTable, + indexes []table.Index, + store storage.ExternalStorage, + srcChunkPool chan *chunk.Chunk, + concurrency int, + onClose external.OnCloseFunc, + memoryQuota uint64, + reorgMeta *model.DDLReorgMeta, +) *WriteExternalStoreOperator { + // due to multi-schema-change, we may merge processing multiple indexes into one + // local backend. + hasUnique := false + for _, index := range indexes { + if index.Meta().Unique { + hasUnique = true + break + } + } + + pool := workerpool.NewWorkerPool( + "WriteExternalStoreOperator", + util.DDL, + concurrency, + func() workerpool.Worker[IndexRecordChunk, IndexWriteResult] { + writers := make([]ingest.Writer, 0, len(indexes)) + for i := range indexes { + builder := external.NewWriterBuilder(). + SetOnCloseFunc(onClose). + SetKeyDuplicationEncoding(hasUnique). + SetMemorySizeLimit(memoryQuota). + SetGroupOffset(i) + writerID := uuid.New().String() + prefix := path.Join(strconv.Itoa(int(jobID)), strconv.Itoa(int(subtaskID))) + writer := builder.Build(store, prefix, writerID) + writers = append(writers, writer) + } + + return &indexIngestExternalWorker{ + indexIngestBaseWorker: indexIngestBaseWorker{ + ctx: ctx, + tbl: tbl, + indexes: indexes, + copCtx: copCtx, + se: nil, + sessPool: sessPool, + writers: writers, + srcChunkPool: srcChunkPool, + reorgMeta: reorgMeta, + }, + } + }) + return &WriteExternalStoreOperator{ + AsyncOperator: operator.NewAsyncOperator[IndexRecordChunk, IndexWriteResult](ctx, pool), + } +} + +// IndexWriteResult contains the result of writing index records to ingest engine. +type IndexWriteResult struct { + ID int + Added int + Total int + Next kv.Key +} + +// IndexIngestOperator writes index records to ingest engine. +type IndexIngestOperator struct { + *operator.AsyncOperator[IndexRecordChunk, IndexWriteResult] +} + +// NewIndexIngestOperator creates a new IndexIngestOperator. 
+func NewIndexIngestOperator( + ctx *OperatorCtx, + copCtx copr.CopContext, + backendCtx ingest.BackendCtx, + sessPool opSessPool, + tbl table.PhysicalTable, + indexes []table.Index, + engines []ingest.Engine, + srcChunkPool chan *chunk.Chunk, + concurrency int, + reorgMeta *model.DDLReorgMeta, + cpMgr *ingest.CheckpointManager, + rowCntListener RowCountListener, +) *IndexIngestOperator { + writerCfg := getLocalWriterConfig(len(indexes), concurrency) + + var writerIDAlloc atomic.Int32 + pool := workerpool.NewWorkerPool( + "indexIngestOperator", + util.DDL, + concurrency, + func() workerpool.Worker[IndexRecordChunk, IndexWriteResult] { + writers := make([]ingest.Writer, 0, len(indexes)) + for i := range indexes { + writerID := int(writerIDAlloc.Add(1)) + writer, err := engines[i].CreateWriter(writerID, writerCfg) + if err != nil { + logutil.Logger(ctx).Error("create index ingest worker failed", zap.Error(err)) + ctx.onError(err) + return nil + } + writers = append(writers, writer) + } + + indexIDs := make([]int64, len(indexes)) + for i := 0; i < len(indexes); i++ { + indexIDs[i] = indexes[i].Meta().ID + } + return &indexIngestLocalWorker{ + indexIngestBaseWorker: indexIngestBaseWorker{ + ctx: ctx, + tbl: tbl, + indexes: indexes, + copCtx: copCtx, + + se: nil, + sessPool: sessPool, + writers: writers, + srcChunkPool: srcChunkPool, + reorgMeta: reorgMeta, + }, + indexIDs: indexIDs, + backendCtx: backendCtx, + rowCntListener: rowCntListener, + cpMgr: cpMgr, + } + }) + return &IndexIngestOperator{ + AsyncOperator: operator.NewAsyncOperator[IndexRecordChunk, IndexWriteResult](ctx, pool), + } +} + +type indexIngestExternalWorker struct { + indexIngestBaseWorker +} + +func (w *indexIngestExternalWorker) HandleTask(ck IndexRecordChunk, send func(IndexWriteResult)) { + defer tidbutil.Recover(metrics.LblAddIndex, "indexIngestExternalWorkerRecover", func() { + w.ctx.onError(dbterror.ErrReorgPanic) + }, false) + defer func() { + if ck.Chunk != nil { + w.srcChunkPool <- ck.Chunk + } + }() + rs, err := w.indexIngestBaseWorker.HandleTask(ck) + if err != nil { + w.ctx.onError(err) + return + } + send(rs) +} + +type indexIngestLocalWorker struct { + indexIngestBaseWorker + indexIDs []int64 + backendCtx ingest.BackendCtx + rowCntListener RowCountListener + cpMgr *ingest.CheckpointManager +} + +func (w *indexIngestLocalWorker) HandleTask(ck IndexRecordChunk, send func(IndexWriteResult)) { + defer tidbutil.Recover(metrics.LblAddIndex, "indexIngestLocalWorkerRecover", func() { + w.ctx.onError(dbterror.ErrReorgPanic) + }, false) + defer func() { + if ck.Chunk != nil { + w.srcChunkPool <- ck.Chunk + } + }() + rs, err := w.indexIngestBaseWorker.HandleTask(ck) + if err != nil { + w.ctx.onError(err) + return + } + if rs.Added == 0 { + return + } + w.rowCntListener.Written(rs.Added) + flushed, imported, err := w.backendCtx.Flush(ingest.FlushModeAuto) + if err != nil { + w.ctx.onError(err) + return + } + if w.cpMgr != nil { + totalCnt, nextKey := w.cpMgr.Status() + rs.Total = totalCnt + rs.Next = nextKey + w.cpMgr.UpdateWrittenKeys(ck.ID, rs.Added) + w.cpMgr.AdvanceWatermark(flushed, imported) + } + send(rs) +} + +type indexIngestBaseWorker struct { + ctx *OperatorCtx + + tbl table.PhysicalTable + indexes []table.Index + reorgMeta *model.DDLReorgMeta + + copCtx copr.CopContext + sessPool opSessPool + se *session.Session + restore func(sessionctx.Context) + + writers []ingest.Writer + srcChunkPool chan *chunk.Chunk +} + +func (w *indexIngestBaseWorker) HandleTask(rs IndexRecordChunk) (IndexWriteResult, error) { + 
failpoint.Inject("injectPanicForIndexIngest", func() { + panic("mock panic") + }) + + result := IndexWriteResult{ + ID: rs.ID, + } + w.initSessCtx() + count, nextKey, err := w.WriteChunk(&rs) + if err != nil { + w.ctx.onError(err) + return result, err + } + if count == 0 { + logutil.Logger(w.ctx).Info("finish a index ingest task", zap.Int("id", rs.ID)) + return result, nil + } + result.Added = count + result.Next = nextKey + if ResultCounterForTest != nil { + ResultCounterForTest.Add(1) + } + return result, nil +} + +func (w *indexIngestBaseWorker) initSessCtx() { + if w.se == nil { + sessCtx, err := w.sessPool.Get() + if err != nil { + w.ctx.onError(err) + return + } + w.restore = restoreSessCtx(sessCtx) + if err := initSessCtx(sessCtx, w.reorgMeta); err != nil { + w.ctx.onError(err) + return + } + w.se = session.NewSession(sessCtx) + } +} + +func (w *indexIngestBaseWorker) Close() { + // TODO(lance6716): unify the real write action for engineInfo and external + // writer. + for _, writer := range w.writers { + ew, ok := writer.(*external.Writer) + if !ok { + break + } + err := ew.Close(w.ctx) + if err != nil { + w.ctx.onError(err) + } + } + if w.se != nil { + w.restore(w.se.Context) + w.sessPool.Put(w.se.Context) + } +} + +// WriteChunk will write index records to lightning engine. +func (w *indexIngestBaseWorker) WriteChunk(rs *IndexRecordChunk) (count int, nextKey kv.Key, err error) { + failpoint.Inject("mockWriteLocalError", func(_ failpoint.Value) { + failpoint.Return(0, nil, errors.New("mock write local error")) + }) + failpoint.InjectCall("writeLocalExec", rs.Done) + + oprStartTime := time.Now() + vars := w.se.GetSessionVars() + sc := vars.StmtCtx + cnt, lastHandle, err := writeChunkToLocal(w.ctx, w.writers, w.indexes, w.copCtx, sc.TimeZone(), sc.ErrCtx(), vars.GetWriteStmtBufs(), rs.Chunk) + if err != nil || cnt == 0 { + return 0, nil, err + } + logSlowOperations(time.Since(oprStartTime), "writeChunkToLocal", 3000) + nextKey = tablecodec.EncodeRecordKey(w.tbl.RecordPrefix(), lastHandle) + return cnt, nextKey, nil +} + +type indexWriteResultSink struct { + ctx *OperatorCtx + backendCtx ingest.BackendCtx + tbl table.PhysicalTable + indexes []table.Index + + cpMgr *ingest.CheckpointManager + rowCntListener RowCountListener + + errGroup errgroup.Group + source operator.DataChannel[IndexWriteResult] +} + +func newIndexWriteResultSink( + ctx *OperatorCtx, + backendCtx ingest.BackendCtx, + tbl table.PhysicalTable, + indexes []table.Index, + cpMgr *ingest.CheckpointManager, + rowCntListener RowCountListener, +) *indexWriteResultSink { + return &indexWriteResultSink{ + ctx: ctx, + backendCtx: backendCtx, + tbl: tbl, + indexes: indexes, + errGroup: errgroup.Group{}, + cpMgr: cpMgr, + rowCntListener: rowCntListener, + } +} + +func (s *indexWriteResultSink) SetSource(source operator.DataChannel[IndexWriteResult]) { + s.source = source +} + +func (s *indexWriteResultSink) Open() error { + s.errGroup.Go(s.collectResult) + return nil +} + +func (s *indexWriteResultSink) collectResult() error { + for { + select { + case <-s.ctx.Done(): + return s.ctx.Err() + case _, ok := <-s.source.Channel(): + if !ok { + err := s.flush() + if err != nil { + s.ctx.onError(err) + } + if s.cpMgr != nil { + total, _ := s.cpMgr.Status() + s.rowCntListener.SetTotal(total) + } + return err + } + } + } +} + +func (s *indexWriteResultSink) flush() error { + if s.backendCtx == nil { + return nil + } + failpoint.Inject("mockFlushError", func(_ failpoint.Value) { + failpoint.Return(errors.New("mock flush error")) + }) + 
flushed, imported, err := s.backendCtx.Flush(ingest.FlushModeForceFlushAndImport) + if s.cpMgr != nil { + // Try to advance watermark even if there is an error. + s.cpMgr.AdvanceWatermark(flushed, imported) + } + if err != nil { + msg := "flush error" + if flushed { + msg = "import error" + } + logutil.Logger(s.ctx).Error(msg, zap.String("category", "ddl"), zap.Error(err)) + return err + } + return nil +} + +func (s *indexWriteResultSink) Close() error { + return s.errGroup.Wait() +} + +func (*indexWriteResultSink) String() string { + return "indexWriteResultSink" +} diff --git a/pkg/ddl/backfilling_read_index.go b/pkg/ddl/backfilling_read_index.go index 0c5ce7a62a538..41e0d3cc77ab5 100644 --- a/pkg/ddl/backfilling_read_index.go +++ b/pkg/ddl/backfilling_read_index.go @@ -148,7 +148,7 @@ func (r *readIndexExecutor) Cleanup(ctx context.Context) error { } func (r *readIndexExecutor) OnFinished(ctx context.Context, subtask *proto.Subtask) error { - failpoint.InjectCall("mockDMLExecutionAddIndexSubTaskFinish") + failpoint.Call(_curpkg_("mockDMLExecutionAddIndexSubTaskFinish")) if len(r.cloudStorageURI) == 0 { return nil } diff --git a/pkg/ddl/backfilling_read_index.go__failpoint_stash__ b/pkg/ddl/backfilling_read_index.go__failpoint_stash__ new file mode 100644 index 0000000000000..0c5ce7a62a538 --- /dev/null +++ b/pkg/ddl/backfilling_read_index.go__failpoint_stash__ @@ -0,0 +1,315 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
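The collectResult loop of indexWriteResultSink above follows a common sink shape: drain the result channel until the upstream closes it, then flush exactly once. A minimal standalone sketch of that shape (the int result type and the flush callback are placeholders):

package main

import (
	"context"
	"fmt"
)

// collect drains results until the channel is closed, honoring context
// cancellation, and runs the final flush only after the source is exhausted.
func collect(ctx context.Context, results <-chan int, flush func() error) error {
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case _, ok := <-results:
			if !ok {
				return flush()
			}
			// per-result bookkeeping would go here
		}
	}
}

func main() {
	ch := make(chan int, 2)
	ch <- 1
	ch <- 2
	close(ch)
	err := collect(context.Background(), ch, func() error {
		fmt.Println("flush once after the channel is closed")
		return nil
	})
	fmt.Println("err:", err)
}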
+ +package ddl + +import ( + "context" + "encoding/hex" + "encoding/json" + "sync" + "sync/atomic" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/ddl/ingest" + "github.com/pingcap/tidb/pkg/ddl/logutil" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor/execute" + "github.com/pingcap/tidb/pkg/disttask/operator" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/lightning/backend/external" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/table" + tidblogutil "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/prometheus/client_golang/prometheus" + "go.uber.org/zap" +) + +type readIndexExecutor struct { + execute.StepExecFrameworkInfo + d *ddl + job *model.Job + indexes []*model.IndexInfo + ptbl table.PhysicalTable + jc *JobContext + + avgRowSize int + cloudStorageURI string + + bc ingest.BackendCtx + curRowCount *atomic.Int64 + + subtaskSummary sync.Map // subtaskID => readIndexSummary +} + +type readIndexSummary struct { + metaGroups []*external.SortedKVMeta + mu sync.Mutex +} + +func newReadIndexExecutor( + d *ddl, + job *model.Job, + indexes []*model.IndexInfo, + ptbl table.PhysicalTable, + jc *JobContext, + bcGetter func() (ingest.BackendCtx, error), + cloudStorageURI string, + avgRowSize int, +) (*readIndexExecutor, error) { + bc, err := bcGetter() + if err != nil { + return nil, err + } + return &readIndexExecutor{ + d: d, + job: job, + indexes: indexes, + ptbl: ptbl, + jc: jc, + bc: bc, + cloudStorageURI: cloudStorageURI, + avgRowSize: avgRowSize, + curRowCount: &atomic.Int64{}, + }, nil +} + +func (*readIndexExecutor) Init(_ context.Context) error { + logutil.DDLLogger().Info("read index executor init subtask exec env") + return nil +} + +func (r *readIndexExecutor) RunSubtask(ctx context.Context, subtask *proto.Subtask) error { + logutil.DDLLogger().Info("read index executor run subtask", + zap.Bool("use cloud", len(r.cloudStorageURI) > 0)) + + r.subtaskSummary.Store(subtask.ID, &readIndexSummary{ + metaGroups: make([]*external.SortedKVMeta, len(r.indexes)), + }) + + sm, err := decodeBackfillSubTaskMeta(subtask.Meta) + if err != nil { + return err + } + + opCtx := NewDistTaskOperatorCtx(ctx, subtask.TaskID, subtask.ID) + defer opCtx.Cancel() + r.curRowCount.Store(0) + + if len(r.cloudStorageURI) > 0 { + pipe, err := r.buildExternalStorePipeline(opCtx, subtask.ID, sm, subtask.Concurrency) + if err != nil { + return err + } + return executeAndClosePipeline(opCtx, pipe) + } + + pipe, err := r.buildLocalStorePipeline(opCtx, sm, subtask.Concurrency) + if err != nil { + return err + } + err = executeAndClosePipeline(opCtx, pipe) + if err != nil { + // For dist task local based ingest, checkpoint is unsupported. + // If there is an error we should keep local sort dir clean. 
+ err1 := r.bc.FinishAndUnregisterEngines(ingest.OptCleanData) + if err1 != nil { + logutil.DDLLogger().Warn("read index executor unregister engine failed", zap.Error(err1)) + } + return err + } + return r.bc.FinishAndUnregisterEngines(ingest.OptCleanData | ingest.OptCheckDup) +} + +func (r *readIndexExecutor) RealtimeSummary() *execute.SubtaskSummary { + return &execute.SubtaskSummary{ + RowCount: r.curRowCount.Load(), + } +} + +func (r *readIndexExecutor) Cleanup(ctx context.Context) error { + tidblogutil.Logger(ctx).Info("read index executor cleanup subtask exec env") + // cleanup backend context + ingest.LitBackCtxMgr.Unregister(r.job.ID) + return nil +} + +func (r *readIndexExecutor) OnFinished(ctx context.Context, subtask *proto.Subtask) error { + failpoint.InjectCall("mockDMLExecutionAddIndexSubTaskFinish") + if len(r.cloudStorageURI) == 0 { + return nil + } + // Rewrite the subtask meta to record statistics. + sm, err := decodeBackfillSubTaskMeta(subtask.Meta) + if err != nil { + return err + } + sum, _ := r.subtaskSummary.LoadAndDelete(subtask.ID) + s := sum.(*readIndexSummary) + sm.MetaGroups = s.metaGroups + sm.EleIDs = make([]int64, 0, len(r.indexes)) + for _, index := range r.indexes { + sm.EleIDs = append(sm.EleIDs, index.ID) + } + + all := external.SortedKVMeta{} + for _, g := range s.metaGroups { + all.Merge(g) + } + tidblogutil.Logger(ctx).Info("get key boundary on subtask finished", + zap.String("start", hex.EncodeToString(all.StartKey)), + zap.String("end", hex.EncodeToString(all.EndKey)), + zap.Int("fileCount", len(all.MultipleFilesStats)), + zap.Uint64("totalKVSize", all.TotalKVSize)) + + meta, err := json.Marshal(sm) + if err != nil { + return err + } + subtask.Meta = meta + return nil +} + +func (r *readIndexExecutor) getTableStartEndKey(sm *BackfillSubTaskMeta) ( + start, end kv.Key, tbl table.PhysicalTable, err error) { + currentVer, err1 := getValidCurrentVersion(r.d.store) + if err1 != nil { + return nil, nil, nil, errors.Trace(err1) + } + if parTbl, ok := r.ptbl.(table.PartitionedTable); ok { + pid := sm.PhysicalTableID + start, end, err = getTableRange(r.jc, r.d.ddlCtx, parTbl.GetPartition(pid), currentVer.Ver, r.job.Priority) + if err != nil { + logutil.DDLLogger().Error("get table range error", + zap.Error(err)) + return nil, nil, nil, err + } + tbl = parTbl.GetPartition(pid) + } else { + start, end = sm.RowStart, sm.RowEnd + tbl = r.ptbl + } + return start, end, tbl, nil +} + +func (r *readIndexExecutor) buildLocalStorePipeline( + opCtx *OperatorCtx, + sm *BackfillSubTaskMeta, + concurrency int, +) (*operator.AsyncPipeline, error) { + start, end, tbl, err := r.getTableStartEndKey(sm) + if err != nil { + return nil, err + } + d := r.d + indexIDs := make([]int64, 0, len(r.indexes)) + uniques := make([]bool, 0, len(r.indexes)) + for _, index := range r.indexes { + indexIDs = append(indexIDs, index.ID) + uniques = append(uniques, index.Unique) + } + engines, err := r.bc.Register(indexIDs, uniques, r.ptbl) + if err != nil { + tidblogutil.Logger(opCtx).Error("cannot register new engine", + zap.Error(err), + zap.Int64("job ID", r.job.ID), + zap.Int64s("index IDs", indexIDs)) + return nil, err + } + rowCntListener := newDistTaskRowCntListener(r.curRowCount, r.job.SchemaName, tbl.Meta().Name.O) + return NewAddIndexIngestPipeline( + opCtx, + d.store, + d.sessPool, + r.bc, + engines, + r.job.ID, + tbl, + r.indexes, + start, + end, + r.job.ReorgMeta, + r.avgRowSize, + concurrency, + nil, + rowCntListener, + ) +} + +func (r *readIndexExecutor) buildExternalStorePipeline( 
+ opCtx *OperatorCtx, + subtaskID int64, + sm *BackfillSubTaskMeta, + concurrency int, +) (*operator.AsyncPipeline, error) { + start, end, tbl, err := r.getTableStartEndKey(sm) + if err != nil { + return nil, err + } + + d := r.d + onClose := func(summary *external.WriterSummary) { + sum, _ := r.subtaskSummary.Load(subtaskID) + s := sum.(*readIndexSummary) + s.mu.Lock() + kvMeta := s.metaGroups[summary.GroupOffset] + if kvMeta == nil { + kvMeta = &external.SortedKVMeta{} + s.metaGroups[summary.GroupOffset] = kvMeta + } + kvMeta.MergeSummary(summary) + s.mu.Unlock() + } + rowCntListener := newDistTaskRowCntListener(r.curRowCount, r.job.SchemaName, tbl.Meta().Name.O) + return NewWriteIndexToExternalStoragePipeline( + opCtx, + d.store, + r.cloudStorageURI, + r.d.sessPool, + r.job.ID, + subtaskID, + tbl, + r.indexes, + start, + end, + onClose, + r.job.ReorgMeta, + r.avgRowSize, + concurrency, + r.GetResource(), + rowCntListener, + ) +} + +type distTaskRowCntListener struct { + EmptyRowCntListener + totalRowCount *atomic.Int64 + counter prometheus.Counter +} + +func newDistTaskRowCntListener(totalRowCnt *atomic.Int64, dbName, tblName string) *distTaskRowCntListener { + counter := metrics.BackfillTotalCounter.WithLabelValues( + metrics.GenerateReorgLabel("add_idx_rate", dbName, tblName)) + return &distTaskRowCntListener{ + totalRowCount: totalRowCnt, + counter: counter, + } +} + +func (d *distTaskRowCntListener) Written(rowCnt int) { + d.totalRowCount.Add(int64(rowCnt)) + d.counter.Add(float64(rowCnt)) +} diff --git a/pkg/ddl/binding__failpoint_binding__.go b/pkg/ddl/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..d1a12631ba250 --- /dev/null +++ b/pkg/ddl/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package ddl + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/ddl/cluster.go b/pkg/ddl/cluster.go index b866023636484..7386b26d98e58 100644 --- a/pkg/ddl/cluster.go +++ b/pkg/ddl/cluster.go @@ -104,10 +104,10 @@ func recoverPDSchedule(ctx context.Context, pdScheduleParam map[string]any) erro func getStoreGlobalMinSafeTS(s kv.Storage) time.Time { minSafeTS := s.GetMinSafeTS(kv.GlobalTxnScope) // Inject mocked SafeTS for test. 
- failpoint.Inject("injectSafeTS", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("injectSafeTS")); _err_ == nil { injectTS := val.(int) minSafeTS = uint64(injectTS) - }) + } return oracle.GetTimeFromTS(minSafeTS) } @@ -129,10 +129,10 @@ func ValidateFlashbackTS(ctx context.Context, sctx sessionctx.Context, flashBack } flashbackGetMinSafeTimeTimeout := time.Minute - failpoint.Inject("changeFlashbackGetMinSafeTimeTimeout", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("changeFlashbackGetMinSafeTimeTimeout")); _err_ == nil { t := val.(int) flashbackGetMinSafeTimeTimeout = time.Duration(t) - }) + } start := time.Now() minSafeTime := getStoreGlobalMinSafeTS(sctx.GetStore()) @@ -535,14 +535,14 @@ func SendPrepareFlashbackToVersionRPC( if err != nil { return taskStat, err } - failpoint.Inject("mockPrepareMeetsEpochNotMatch", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockPrepareMeetsEpochNotMatch")); _err_ == nil { if val.(bool) && bo.ErrorsNum() == 0 { regionErr = &errorpb.Error{ Message: "stale epoch", EpochNotMatch: &errorpb.EpochNotMatch{}, } } - }) + } if regionErr != nil { err = bo.Backoff(tikv.BoRegionMiss(), errors.New(regionErr.String())) if err != nil { @@ -702,11 +702,11 @@ func splitRegionsByKeyRanges(ctx context.Context, d *ddlCtx, keyRanges []kv.KeyR // 4. phase 2, send flashback RPC, do flashback jobs. func (w *worker) onFlashbackCluster(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { inFlashbackTest := false - failpoint.Inject("mockFlashbackTest", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockFlashbackTest")); _err_ == nil { if val.(bool) { inFlashbackTest = true } - }) + } // TODO: Support flashback in unistore. if d.store.Name() != "TiKV" && !inFlashbackTest { job.State = model.JobStateCancelled diff --git a/pkg/ddl/cluster.go__failpoint_stash__ b/pkg/ddl/cluster.go__failpoint_stash__ new file mode 100644 index 0000000000000..b866023636484 --- /dev/null +++ b/pkg/ddl/cluster.go__failpoint_stash__ @@ -0,0 +1,902 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
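The generated binding file earlier in this patch resolves the package path once through reflection and uses it to namespace every failpoint, which is why these failpoints are enabled with fully qualified names such as github.com/pingcap/tidb/pkg/ddl/injectSafeTS. A standalone sketch of that mechanism (run from a main package it prints main/injectSafeTS; inside pkg/ddl the prefix would be the full import path):

package main

import (
	"fmt"
	"reflect"
)

// Mirrors the generated binding: the package path is derived once from a
// type defined in the package and cached.
type failpointBinding struct{ pkgpath string }

var binding = failpointBinding{
	pkgpath: reflect.TypeOf(failpointBinding{}).PkgPath(),
}

func curpkg(name string) string { return binding.pkgpath + "/" + name }

func main() {
	fmt.Println(curpkg("injectSafeTS"))
}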
+ +package ddl + +import ( + "bytes" + "cmp" + "context" + "encoding/hex" + "fmt" + "slices" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/errorpb" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/tidb/pkg/ddl/logutil" + sess "github.com/pingcap/tidb/pkg/ddl/session" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + statsutil "github.com/pingcap/tidb/pkg/statistics/handle/util" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/filter" + "github.com/pingcap/tidb/pkg/util/gcutil" + tikvstore "github.com/tikv/client-go/v2/kv" + "github.com/tikv/client-go/v2/oracle" + "github.com/tikv/client-go/v2/tikv" + "github.com/tikv/client-go/v2/tikvrpc" + "github.com/tikv/client-go/v2/txnkv/rangetask" + "go.uber.org/atomic" + "go.uber.org/zap" +) + +var pdScheduleKey = []string{ + "merge-schedule-limit", +} + +const ( + flashbackMaxBackoff = 1800000 // 1800s + flashbackTimeout = 3 * time.Minute // 3min +) + +const ( + pdScheduleArgsOffset = 1 + iota + gcEnabledOffset + autoAnalyzeOffset + readOnlyOffset + totalLockedRegionsOffset + startTSOffset + commitTSOffset + ttlJobEnableOffSet + keyRangesOffset +) + +func closePDSchedule(ctx context.Context) error { + closeMap := make(map[string]any) + for _, key := range pdScheduleKey { + closeMap[key] = 0 + } + return infosync.SetPDScheduleConfig(ctx, closeMap) +} + +func savePDSchedule(ctx context.Context, job *model.Job) error { + retValue, err := infosync.GetPDScheduleConfig(ctx) + if err != nil { + return err + } + saveValue := make(map[string]any) + for _, key := range pdScheduleKey { + saveValue[key] = retValue[key] + } + job.Args[pdScheduleArgsOffset] = &saveValue + return nil +} + +func recoverPDSchedule(ctx context.Context, pdScheduleParam map[string]any) error { + if pdScheduleParam == nil { + return nil + } + return infosync.SetPDScheduleConfig(ctx, pdScheduleParam) +} + +func getStoreGlobalMinSafeTS(s kv.Storage) time.Time { + minSafeTS := s.GetMinSafeTS(kv.GlobalTxnScope) + // Inject mocked SafeTS for test. + failpoint.Inject("injectSafeTS", func(val failpoint.Value) { + injectTS := val.(int) + minSafeTS = uint64(injectTS) + }) + return oracle.GetTimeFromTS(minSafeTS) +} + +// ValidateFlashbackTS validates that flashBackTS in range [gcSafePoint, currentTS). +func ValidateFlashbackTS(ctx context.Context, sctx sessionctx.Context, flashBackTS uint64) error { + currentTS, err := sctx.GetStore().GetOracle().GetStaleTimestamp(ctx, oracle.GlobalTxnScope, 0) + // If we fail to calculate currentTS from local time, fallback to get a timestamp from PD. 
+ if err != nil { + metrics.ValidateReadTSFromPDCount.Inc() + currentVer, err := sctx.GetStore().CurrentVersion(oracle.GlobalTxnScope) + if err != nil { + return errors.Errorf("fail to validate flashback timestamp: %v", err) + } + currentTS = currentVer.Ver + } + oracleFlashbackTS := oracle.GetTimeFromTS(flashBackTS) + if oracleFlashbackTS.After(oracle.GetTimeFromTS(currentTS)) { + return errors.Errorf("cannot set flashback timestamp to future time") + } + + flashbackGetMinSafeTimeTimeout := time.Minute + failpoint.Inject("changeFlashbackGetMinSafeTimeTimeout", func(val failpoint.Value) { + t := val.(int) + flashbackGetMinSafeTimeTimeout = time.Duration(t) + }) + + start := time.Now() + minSafeTime := getStoreGlobalMinSafeTS(sctx.GetStore()) + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + for oracleFlashbackTS.After(minSafeTime) { + if time.Since(start) >= flashbackGetMinSafeTimeTimeout { + return errors.Errorf("cannot set flashback timestamp after min-resolved-ts(%s)", minSafeTime) + } + select { + case <-ticker.C: + minSafeTime = getStoreGlobalMinSafeTS(sctx.GetStore()) + case <-ctx.Done(): + return ctx.Err() + } + } + + gcSafePoint, err := gcutil.GetGCSafePoint(sctx) + if err != nil { + return err + } + + return gcutil.ValidateSnapshotWithGCSafePoint(flashBackTS, gcSafePoint) +} + +func getTiDBTTLJobEnable(sess sessionctx.Context) (string, error) { + val, err := sess.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.TiDBTTLJobEnable) + if err != nil { + return "", errors.Trace(err) + } + return val, nil +} + +func setTiDBTTLJobEnable(ctx context.Context, sess sessionctx.Context, value string) error { + return sess.GetSessionVars().GlobalVarsAccessor.SetGlobalSysVar(ctx, variable.TiDBTTLJobEnable, value) +} + +func setTiDBEnableAutoAnalyze(ctx context.Context, sess sessionctx.Context, value string) error { + return sess.GetSessionVars().GlobalVarsAccessor.SetGlobalSysVar(ctx, variable.TiDBEnableAutoAnalyze, value) +} + +func getTiDBEnableAutoAnalyze(sess sessionctx.Context) (string, error) { + val, err := sess.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.TiDBEnableAutoAnalyze) + if err != nil { + return "", errors.Trace(err) + } + return val, nil +} + +func setTiDBSuperReadOnly(ctx context.Context, sess sessionctx.Context, value string) error { + return sess.GetSessionVars().GlobalVarsAccessor.SetGlobalSysVar(ctx, variable.TiDBSuperReadOnly, value) +} + +func getTiDBSuperReadOnly(sess sessionctx.Context) (string, error) { + val, err := sess.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.TiDBSuperReadOnly) + if err != nil { + return "", errors.Trace(err) + } + return val, nil +} + +func isFlashbackSupportedDDLAction(action model.ActionType) bool { + switch action { + case model.ActionSetTiFlashReplica, model.ActionUpdateTiFlashReplicaStatus, model.ActionAlterPlacementPolicy, + model.ActionAlterTablePlacement, model.ActionAlterTablePartitionPlacement, model.ActionCreatePlacementPolicy, + model.ActionDropPlacementPolicy, model.ActionModifySchemaDefaultPlacement, + model.ActionAlterTableAttributes, model.ActionAlterTablePartitionAttributes: + return false + default: + return true + } +} + +func checkSystemSchemaID(t *meta.Meta, schemaID int64, flashbackTSString string) error { + if schemaID <= 0 { + return nil + } + dbInfo, err := t.GetDatabase(schemaID) + if err != nil || dbInfo == nil { + return errors.Trace(err) + } + if filter.IsSystemSchema(dbInfo.Name.L) { + return errors.Errorf("Detected modified system table during [%s, now), 
can't do flashback", flashbackTSString) + } + return nil +} + +func checkAndSetFlashbackClusterInfo(ctx context.Context, se sessionctx.Context, d *ddlCtx, t *meta.Meta, job *model.Job, flashbackTS uint64) (err error) { + if err = ValidateFlashbackTS(ctx, se, flashbackTS); err != nil { + return err + } + + if err = gcutil.DisableGC(se); err != nil { + return err + } + if err = closePDSchedule(ctx); err != nil { + return err + } + if err = setTiDBEnableAutoAnalyze(ctx, se, variable.Off); err != nil { + return err + } + if err = setTiDBSuperReadOnly(ctx, se, variable.On); err != nil { + return err + } + if err = setTiDBTTLJobEnable(ctx, se, variable.Off); err != nil { + return err + } + + nowSchemaVersion, err := t.GetSchemaVersion() + if err != nil { + return errors.Trace(err) + } + + flashbackSnapshotMeta := meta.NewSnapshotMeta(d.store.GetSnapshot(kv.NewVersion(flashbackTS))) + flashbackSchemaVersion, err := flashbackSnapshotMeta.GetSchemaVersion() + if err != nil { + return errors.Trace(err) + } + + flashbackTSString := oracle.GetTimeFromTS(flashbackTS).Format(types.TimeFSPFormat) + + // Check if there is an upgrade during [flashbackTS, now) + sql := fmt.Sprintf("select VARIABLE_VALUE from mysql.tidb as of timestamp '%s' where VARIABLE_NAME='tidb_server_version'", flashbackTSString) + rows, err := sess.NewSession(se).Execute(ctx, sql, "check_tidb_server_version") + if err != nil || len(rows) == 0 { + return errors.Errorf("Get history `tidb_server_version` failed, can't do flashback") + } + sql = fmt.Sprintf("select 1 from mysql.tidb where VARIABLE_NAME='tidb_server_version' and VARIABLE_VALUE=%s", rows[0].GetString(0)) + rows, err = sess.NewSession(se).Execute(ctx, sql, "check_tidb_server_version") + if err != nil { + return errors.Trace(err) + } + if len(rows) == 0 { + return errors.Errorf("Detected TiDB upgrade during [%s, now), can't do flashback", flashbackTSString) + } + + // Check is there a DDL task at flashbackTS. + sql = fmt.Sprintf("select count(*) from mysql.%s as of timestamp '%s'", JobTable, flashbackTSString) + rows, err = sess.NewSession(se).Execute(ctx, sql, "check_history_job") + if err != nil || len(rows) == 0 { + return errors.Errorf("Get history ddl jobs failed, can't do flashback") + } + if rows[0].GetInt64(0) != 0 { + return errors.Errorf("Detected another DDL job at %s, can't do flashback", flashbackTSString) + } + + // If flashbackSchemaVersion not same as nowSchemaVersion, we should check all schema diffs during [flashbackTs, now). + for i := flashbackSchemaVersion + 1; i <= nowSchemaVersion; i++ { + diff, err := t.GetSchemaDiff(i) + if err != nil { + return errors.Trace(err) + } + if diff == nil { + continue + } + if !isFlashbackSupportedDDLAction(diff.Type) { + return errors.Errorf("Detected unsupported DDL job type(%s) during [%s, now), can't do flashback", diff.Type.String(), flashbackTSString) + } + err = checkSystemSchemaID(flashbackSnapshotMeta, diff.SchemaID, flashbackTSString) + if err != nil { + return errors.Trace(err) + } + } + + jobs, err := GetAllDDLJobs(se) + if err != nil { + return errors.Trace(err) + } + // Other ddl jobs in queue, return error. 
+ if len(jobs) != 1 { + var otherJob *model.Job + for _, j := range jobs { + if j.ID != job.ID { + otherJob = j + break + } + } + return errors.Errorf("have other ddl jobs(jobID: %d) in queue, can't do flashback", otherJob.ID) + } + return nil +} + +func addToSlice(schema string, tableName string, tableID int64, flashbackIDs []int64) []int64 { + if filter.IsSystemSchema(schema) && !strings.HasPrefix(tableName, "stats_") && tableName != "gc_delete_range" { + flashbackIDs = append(flashbackIDs, tableID) + } + return flashbackIDs +} + +// getTableDataKeyRanges get keyRanges by `flashbackIDs`. +// This func will return all flashback table data key ranges. +func getTableDataKeyRanges(nonFlashbackTableIDs []int64) []kv.KeyRange { + var keyRanges []kv.KeyRange + + nonFlashbackTableIDs = append(nonFlashbackTableIDs, -1) + + slices.SortFunc(nonFlashbackTableIDs, func(a, b int64) int { + return cmp.Compare(a, b) + }) + + for i := 1; i < len(nonFlashbackTableIDs); i++ { + keyRanges = append(keyRanges, kv.KeyRange{ + StartKey: tablecodec.EncodeTablePrefix(nonFlashbackTableIDs[i-1] + 1), + EndKey: tablecodec.EncodeTablePrefix(nonFlashbackTableIDs[i]), + }) + } + + // Add all other key ranges. + keyRanges = append(keyRanges, kv.KeyRange{ + StartKey: tablecodec.EncodeTablePrefix(nonFlashbackTableIDs[len(nonFlashbackTableIDs)-1] + 1), + EndKey: tablecodec.EncodeTablePrefix(meta.MaxGlobalID), + }) + + return keyRanges +} + +type keyRangeMayExclude struct { + r kv.KeyRange + exclude bool +} + +// appendContinuousKeyRanges merges not exclude continuous key ranges and appends +// to given []kv.KeyRange, assuming the gap between key ranges has no data. +// +// Precondition: schemaKeyRanges is sorted by start key. schemaKeyRanges are +// non-overlapping. +func appendContinuousKeyRanges(result []kv.KeyRange, schemaKeyRanges []keyRangeMayExclude) []kv.KeyRange { + var ( + continuousStart, continuousEnd kv.Key + ) + + for _, r := range schemaKeyRanges { + if r.exclude { + if continuousStart != nil { + result = append(result, kv.KeyRange{ + StartKey: continuousStart, + EndKey: continuousEnd, + }) + continuousStart = nil + } + continue + } + + if continuousStart == nil { + continuousStart = r.r.StartKey + } + continuousEnd = r.r.EndKey + } + + if continuousStart != nil { + result = append(result, kv.KeyRange{ + StartKey: continuousStart, + EndKey: continuousEnd, + }) + } + return result +} + +// getFlashbackKeyRanges get keyRanges for flashback cluster. +// It contains all non system table key ranges and meta data key ranges. +// The time complexity is O(nlogn). +func getFlashbackKeyRanges(ctx context.Context, sess sessionctx.Context, flashbackTS uint64) ([]kv.KeyRange, error) { + is := sess.GetDomainInfoSchema().(infoschema.InfoSchema) + schemas := is.AllSchemas() + + // The semantic of keyRanges(output). + keyRanges := make([]kv.KeyRange, 0) + + // get snapshot schema IDs. 
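// Illustrative sketch (standalone, not part of this patch): the interval-complement
// idea behind getTableDataKeyRanges above, reduced to plain int64 boundaries so it
// runs without TiDB's tablecodec/meta packages. IDs in skipIDs (tables excluded from
// flashback) are cut out; everything between them is returned as half-open
// [Start, End) ranges, with maxID standing in for meta.MaxGlobalID.
package main

import (
	"fmt"
	"slices"
)

type idRange struct{ Start, End int64 }

func complementRanges(skipIDs []int64, maxID int64) []idRange {
	ids := append([]int64{-1}, skipIDs...) // -1 sentinel so the first range starts at 0
	slices.Sort(ids)
	out := make([]idRange, 0, len(ids))
	for i := 1; i < len(ids); i++ {
		out = append(out, idRange{Start: ids[i-1] + 1, End: ids[i]})
	}
	// Tail range: everything above the largest skipped ID.
	out = append(out, idRange{Start: ids[len(ids)-1] + 1, End: maxID})
	return out
}

func main() {
	// Skip tables 20, 21 and 50; cover everything else up to maxID=100.
	fmt.Println(complementRanges([]int64{20, 21, 50}, 100))
	// [{0 20} {21 21} {22 50} {51 100}]
	// Consecutive skipped IDs produce an empty {21 21} range, which is harmless:
	// like in the original, a range from prefix(21) to prefix(21) scans nothing.
}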
+ flashbackSnapshotMeta := meta.NewSnapshotMeta(sess.GetStore().GetSnapshot(kv.NewVersion(flashbackTS))) + snapshotSchemas, err := flashbackSnapshotMeta.ListDatabases() + if err != nil { + return nil, errors.Trace(err) + } + + schemaIDs := make(map[int64]struct{}) + excludeSchemaIDs := make(map[int64]struct{}) + for _, schema := range schemas { + if filter.IsSystemSchema(schema.Name.L) { + excludeSchemaIDs[schema.ID] = struct{}{} + } else { + schemaIDs[schema.ID] = struct{}{} + } + } + for _, schema := range snapshotSchemas { + if filter.IsSystemSchema(schema.Name.L) { + excludeSchemaIDs[schema.ID] = struct{}{} + } else { + schemaIDs[schema.ID] = struct{}{} + } + } + + schemaKeyRanges := make([]keyRangeMayExclude, 0, len(schemaIDs)+len(excludeSchemaIDs)) + for schemaID := range schemaIDs { + metaStartKey := tablecodec.EncodeMetaKeyPrefix(meta.DBkey(schemaID)) + metaEndKey := tablecodec.EncodeMetaKeyPrefix(meta.DBkey(schemaID + 1)) + schemaKeyRanges = append(schemaKeyRanges, keyRangeMayExclude{ + r: kv.KeyRange{ + StartKey: metaStartKey, + EndKey: metaEndKey, + }, + exclude: false, + }) + } + for schemaID := range excludeSchemaIDs { + metaStartKey := tablecodec.EncodeMetaKeyPrefix(meta.DBkey(schemaID)) + metaEndKey := tablecodec.EncodeMetaKeyPrefix(meta.DBkey(schemaID + 1)) + schemaKeyRanges = append(schemaKeyRanges, keyRangeMayExclude{ + r: kv.KeyRange{ + StartKey: metaStartKey, + EndKey: metaEndKey, + }, + exclude: true, + }) + } + + slices.SortFunc(schemaKeyRanges, func(a, b keyRangeMayExclude) int { + return bytes.Compare(a.r.StartKey, b.r.StartKey) + }) + + keyRanges = appendContinuousKeyRanges(keyRanges, schemaKeyRanges) + + startKey := tablecodec.EncodeMetaKeyPrefix([]byte("DBs")) + keyRanges = append(keyRanges, kv.KeyRange{ + StartKey: startKey, + EndKey: startKey.PrefixNext(), + }) + + var nonFlashbackTableIDs []int64 + for _, db := range schemas { + tbls, err2 := is.SchemaTableInfos(ctx, db.Name) + if err2 != nil { + return nil, errors.Trace(err2) + } + for _, table := range tbls { + if !table.IsBaseTable() || table.ID > meta.MaxGlobalID { + continue + } + nonFlashbackTableIDs = addToSlice(db.Name.L, table.Name.L, table.ID, nonFlashbackTableIDs) + if table.Partition != nil { + for _, partition := range table.Partition.Definitions { + nonFlashbackTableIDs = addToSlice(db.Name.L, table.Name.L, partition.ID, nonFlashbackTableIDs) + } + } + } + } + + return append(keyRanges, getTableDataKeyRanges(nonFlashbackTableIDs)...), nil +} + +// SendPrepareFlashbackToVersionRPC prepares regions for flashback, the purpose is to put region into flashback state which region stop write +// Function also be called by BR for volume snapshot backup and restore +func SendPrepareFlashbackToVersionRPC( + ctx context.Context, + s tikv.Storage, + flashbackTS, startTS uint64, + r tikvstore.KeyRange, +) (rangetask.TaskStat, error) { + startKey, rangeEndKey := r.StartKey, r.EndKey + var taskStat rangetask.TaskStat + bo := tikv.NewBackoffer(ctx, flashbackMaxBackoff) + for { + select { + case <-ctx.Done(): + return taskStat, errors.WithStack(ctx.Err()) + default: + } + + if len(rangeEndKey) > 0 && bytes.Compare(startKey, rangeEndKey) >= 0 { + break + } + + loc, err := s.GetRegionCache().LocateKey(bo, startKey) + if err != nil { + return taskStat, err + } + + endKey := loc.EndKey + isLast := len(endKey) == 0 || (len(rangeEndKey) > 0 && bytes.Compare(endKey, rangeEndKey) >= 0) + // If it is the last region. 
+ if isLast { + endKey = rangeEndKey + } + + logutil.DDLLogger().Info("send prepare flashback request", zap.Uint64("region_id", loc.Region.GetID()), + zap.String("start_key", hex.EncodeToString(startKey)), zap.String("end_key", hex.EncodeToString(endKey))) + + req := tikvrpc.NewRequest(tikvrpc.CmdPrepareFlashbackToVersion, &kvrpcpb.PrepareFlashbackToVersionRequest{ + StartKey: startKey, + EndKey: endKey, + StartTs: startTS, + Version: flashbackTS, + }) + + resp, err := s.SendReq(bo, req, loc.Region, flashbackTimeout) + if err != nil { + return taskStat, err + } + regionErr, err := resp.GetRegionError() + if err != nil { + return taskStat, err + } + failpoint.Inject("mockPrepareMeetsEpochNotMatch", func(val failpoint.Value) { + if val.(bool) && bo.ErrorsNum() == 0 { + regionErr = &errorpb.Error{ + Message: "stale epoch", + EpochNotMatch: &errorpb.EpochNotMatch{}, + } + } + }) + if regionErr != nil { + err = bo.Backoff(tikv.BoRegionMiss(), errors.New(regionErr.String())) + if err != nil { + return taskStat, err + } + continue + } + if resp.Resp == nil { + logutil.DDLLogger().Warn("prepare flashback miss resp body", zap.Uint64("region_id", loc.Region.GetID())) + err = bo.Backoff(tikv.BoTiKVRPC(), errors.New("prepare flashback rpc miss resp body")) + if err != nil { + return taskStat, err + } + continue + } + prepareFlashbackToVersionResp := resp.Resp.(*kvrpcpb.PrepareFlashbackToVersionResponse) + if err := prepareFlashbackToVersionResp.GetError(); err != "" { + boErr := bo.Backoff(tikv.BoTiKVRPC(), errors.New(err)) + if boErr != nil { + return taskStat, boErr + } + continue + } + taskStat.CompletedRegions++ + if isLast { + break + } + bo = tikv.NewBackoffer(ctx, flashbackMaxBackoff) + startKey = endKey + } + return taskStat, nil +} + +// SendFlashbackToVersionRPC flashback the MVCC key to the version +// Function also be called by BR for volume snapshot backup and restore +func SendFlashbackToVersionRPC( + ctx context.Context, + s tikv.Storage, + version uint64, + startTS, commitTS uint64, + r tikvstore.KeyRange, +) (rangetask.TaskStat, error) { + startKey, rangeEndKey := r.StartKey, r.EndKey + var taskStat rangetask.TaskStat + bo := tikv.NewBackoffer(ctx, flashbackMaxBackoff) + for { + select { + case <-ctx.Done(): + return taskStat, errors.WithStack(ctx.Err()) + default: + } + + if len(rangeEndKey) > 0 && bytes.Compare(startKey, rangeEndKey) >= 0 { + break + } + + loc, err := s.GetRegionCache().LocateKey(bo, startKey) + if err != nil { + return taskStat, err + } + + endKey := loc.EndKey + isLast := len(endKey) == 0 || (len(rangeEndKey) > 0 && bytes.Compare(endKey, rangeEndKey) >= 0) + // If it is the last region. 
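// Illustrative sketch (standalone, not part of this patch): the region-by-region
// driving loop shared by SendPrepareFlashbackToVersionRPC and
// SendFlashbackToVersionRPC above, abstracted behind a hypothetical locator so it
// runs without a TiKV cluster. The real code additionally backs off on region
// errors and re-locates regions, which is omitted here.
package main

import (
	"bytes"
	"fmt"
)

type region struct{ start, end []byte } // empty end means the last region

type regionLocator interface {
	locate(key []byte) region
}

func forEachRegion(loc regionLocator, startKey, rangeEndKey []byte, do func(start, end []byte) error) (int, error) {
	completed := 0
	for {
		if len(rangeEndKey) > 0 && bytes.Compare(startKey, rangeEndKey) >= 0 {
			break
		}
		r := loc.locate(startKey)
		endKey := r.end
		isLast := len(endKey) == 0 || (len(rangeEndKey) > 0 && bytes.Compare(endKey, rangeEndKey) >= 0)
		if isLast {
			endKey = rangeEndKey // clamp the last region to the task's end key
		}
		if err := do(startKey, endKey); err != nil {
			return completed, err
		}
		completed++
		if isLast {
			break
		}
		startKey = endKey // resume from the end of the region just handled
	}
	return completed, nil
}

// staticRegions fakes a region cache with sorted split keys.
type staticRegions [][]byte

func (s staticRegions) locate(key []byte) region {
	for _, split := range s {
		if bytes.Compare(key, split) < 0 {
			return region{start: key, end: split}
		}
	}
	return region{start: key} // last region: no end key
}

func main() {
	loc := staticRegions{[]byte("b"), []byte("d"), []byte("f")}
	n, _ := forEachRegion(loc, []byte("a"), []byte("e"), func(start, end []byte) error {
		fmt.Printf("flashback range [%q, %q)\n", start, end)
		return nil
	})
	fmt.Println("completed regions:", n) // 3
}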
+ if isLast { + endKey = rangeEndKey + } + + logutil.DDLLogger().Info("send flashback request", zap.Uint64("region_id", loc.Region.GetID()), + zap.String("start_key", hex.EncodeToString(startKey)), zap.String("end_key", hex.EncodeToString(endKey))) + + req := tikvrpc.NewRequest(tikvrpc.CmdFlashbackToVersion, &kvrpcpb.FlashbackToVersionRequest{ + Version: version, + StartKey: startKey, + EndKey: endKey, + StartTs: startTS, + CommitTs: commitTS, + }) + + resp, err := s.SendReq(bo, req, loc.Region, flashbackTimeout) + if err != nil { + logutil.DDLLogger().Warn("send request meets error", zap.Uint64("region_id", loc.Region.GetID()), zap.Error(err)) + if err.Error() != fmt.Sprintf("region %d is not prepared for the flashback", loc.Region.GetID()) { + return taskStat, err + } + } else { + regionErr, err := resp.GetRegionError() + if err != nil { + return taskStat, err + } + if regionErr != nil { + err = bo.Backoff(tikv.BoRegionMiss(), errors.New(regionErr.String())) + if err != nil { + return taskStat, err + } + continue + } + if resp.Resp == nil { + logutil.DDLLogger().Warn("flashback miss resp body", zap.Uint64("region_id", loc.Region.GetID())) + err = bo.Backoff(tikv.BoTiKVRPC(), errors.New("flashback rpc miss resp body")) + if err != nil { + return taskStat, err + } + continue + } + flashbackToVersionResp := resp.Resp.(*kvrpcpb.FlashbackToVersionResponse) + if respErr := flashbackToVersionResp.GetError(); respErr != "" { + boErr := bo.Backoff(tikv.BoTiKVRPC(), errors.New(respErr)) + if boErr != nil { + return taskStat, boErr + } + continue + } + } + taskStat.CompletedRegions++ + if isLast { + break + } + bo = tikv.NewBackoffer(ctx, flashbackMaxBackoff) + startKey = endKey + } + return taskStat, nil +} + +func flashbackToVersion( + ctx context.Context, + d *ddlCtx, + handler rangetask.TaskHandler, + startKey []byte, endKey []byte, +) (err error) { + return rangetask.NewRangeTaskRunner( + "flashback-to-version-runner", + d.store.(tikv.Storage), + int(variable.GetDDLFlashbackConcurrency()), + handler, + ).RunOnRange(ctx, startKey, endKey) +} + +func splitRegionsByKeyRanges(ctx context.Context, d *ddlCtx, keyRanges []kv.KeyRange) { + if s, ok := d.store.(kv.SplittableStore); ok { + for _, keys := range keyRanges { + for { + // tableID is useless when scatter == false + _, err := s.SplitRegions(ctx, [][]byte{keys.StartKey, keys.EndKey}, false, nil) + if err == nil { + break + } + } + } + } +} + +// A Flashback has 4 different stages. +// 1. before lock flashbackClusterJobID, check clusterJobID and lock it. +// 2. before flashback start, check timestamp, disable GC and close PD schedule, get flashback key ranges. +// 3. phase 1, lock flashback key ranges. +// 4. phase 2, send flashback RPC, do flashback jobs. +func (w *worker) onFlashbackCluster(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + inFlashbackTest := false + failpoint.Inject("mockFlashbackTest", func(val failpoint.Value) { + if val.(bool) { + inFlashbackTest = true + } + }) + // TODO: Support flashback in unistore. 
+ if d.store.Name() != "TiKV" && !inFlashbackTest { + job.State = model.JobStateCancelled + return ver, errors.Errorf("Not support flashback cluster in non-TiKV env") + } + + var flashbackTS, lockedRegions, startTS, commitTS uint64 + var pdScheduleValue map[string]any + var autoAnalyzeValue, readOnlyValue, ttlJobEnableValue string + var gcEnabledValue bool + var keyRanges []kv.KeyRange + if err := job.DecodeArgs(&flashbackTS, &pdScheduleValue, &gcEnabledValue, &autoAnalyzeValue, &readOnlyValue, &lockedRegions, &startTS, &commitTS, &ttlJobEnableValue, &keyRanges); err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + var totalRegions, completedRegions atomic.Uint64 + totalRegions.Store(lockedRegions) + + sess, err := w.sessPool.Get() + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + defer w.sessPool.Put(sess) + + switch job.SchemaState { + // Stage 1, check and set FlashbackClusterJobID, and update job args. + case model.StateNone: + if err = savePDSchedule(w.ctx, job); err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + gcEnableValue, err := gcutil.CheckGCEnable(sess) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + job.Args[gcEnabledOffset] = &gcEnableValue + autoAnalyzeValue, err = getTiDBEnableAutoAnalyze(sess) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + job.Args[autoAnalyzeOffset] = &autoAnalyzeValue + readOnlyValue, err = getTiDBSuperReadOnly(sess) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + job.Args[readOnlyOffset] = &readOnlyValue + ttlJobEnableValue, err = getTiDBTTLJobEnable(sess) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + job.Args[ttlJobEnableOffSet] = &ttlJobEnableValue + job.SchemaState = model.StateDeleteOnly + return ver, nil + // Stage 2, check flashbackTS, close GC and PD schedule, get flashback key ranges. + case model.StateDeleteOnly: + if err = checkAndSetFlashbackClusterInfo(w.ctx, sess, d, t, job, flashbackTS); err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + // We should get startTS here to avoid lost startTS when TiDB crashed during send prepare flashback RPC. + startTS, err = d.store.GetOracle().GetTimestamp(w.ctx, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + job.Args[startTSOffset] = startTS + keyRanges, err = getFlashbackKeyRanges(w.ctx, sess, flashbackTS) + if err != nil { + return ver, errors.Trace(err) + } + job.Args[keyRangesOffset] = keyRanges + job.SchemaState = model.StateWriteOnly + return updateSchemaVersion(d, t, job) + // Stage 3, lock related key ranges. + case model.StateWriteOnly: + // TODO: Support flashback in unistore. + if inFlashbackTest { + job.SchemaState = model.StateWriteReorganization + return updateSchemaVersion(d, t, job) + } + // Split region by keyRanges, make sure no unrelated key ranges be locked. 
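// Illustrative sketch (standalone, not part of this patch): the resumable
// state-machine shape used by onFlashbackCluster above. Each call performs exactly
// one stage, persists the new state, and returns, so a crashed DDL owner can resume
// from the last persisted stage. Stage names mirror the schema states in the patch;
// the persist callback is a stand-in for writing the job back to the DDL job table.
package main

import "fmt"

type stage int

const (
	stageNone stage = iota
	stageDeleteOnly
	stageWriteOnly
	stageWriteReorg
	stageDone
)

type job struct{ state stage }

// runOneStage mirrors the switch in onFlashbackCluster: do the work for the current
// stage, advance the state, and return so the caller can persist it.
func runOneStage(j *job, persist func(*job)) {
	switch j.state {
	case stageNone:
		// save settings (PD schedule, GC, read-only, TTL) into the job args
		j.state = stageDeleteOnly
	case stageDeleteOnly:
		// validate the flashback TS, disable GC, compute key ranges
		j.state = stageWriteOnly
	case stageWriteOnly:
		// phase 1: lock all key ranges via prepare-flashback RPCs
		j.state = stageWriteReorg
	case stageWriteReorg:
		// phase 2: send flashback RPCs, then finish the job
		j.state = stageDone
	}
	persist(j)
}

func main() {
	j := &job{}
	for j.state != stageDone {
		runOneStage(j, func(j *job) { fmt.Println("persisted state:", j.state) })
	}
}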
+ splitRegionsByKeyRanges(w.ctx, d, keyRanges) + totalRegions.Store(0) + for _, r := range keyRanges { + if err = flashbackToVersion(w.ctx, d, + func(ctx context.Context, r tikvstore.KeyRange) (rangetask.TaskStat, error) { + stats, err := SendPrepareFlashbackToVersionRPC(ctx, d.store.(tikv.Storage), flashbackTS, startTS, r) + totalRegions.Add(uint64(stats.CompletedRegions)) + return stats, err + }, r.StartKey, r.EndKey); err != nil { + logutil.DDLLogger().Warn("Get error when do flashback", zap.Error(err)) + return ver, err + } + } + job.Args[totalLockedRegionsOffset] = totalRegions.Load() + + // We should get commitTS here to avoid lost commitTS when TiDB crashed during send flashback RPC. + commitTS, err = d.store.GetOracle().GetTimestamp(w.ctx, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) + if err != nil { + return ver, errors.Trace(err) + } + job.Args[commitTSOffset] = commitTS + job.SchemaState = model.StateWriteReorganization + return ver, nil + // Stage 4, get key ranges and send flashback RPC. + case model.StateWriteReorganization: + // TODO: Support flashback in unistore. + if inFlashbackTest { + asyncNotifyEvent(d, statsutil.NewFlashbackClusterEvent()) + job.State = model.JobStateDone + job.SchemaState = model.StatePublic + return ver, nil + } + + for _, r := range keyRanges { + if err = flashbackToVersion(w.ctx, d, + func(ctx context.Context, r tikvstore.KeyRange) (rangetask.TaskStat, error) { + // Use same startTS as prepare phase to simulate 1PC txn. + stats, err := SendFlashbackToVersionRPC(ctx, d.store.(tikv.Storage), flashbackTS, startTS, commitTS, r) + completedRegions.Add(uint64(stats.CompletedRegions)) + logutil.DDLLogger().Info("flashback cluster stats", + zap.Uint64("complete regions", completedRegions.Load()), + zap.Uint64("total regions", totalRegions.Load()), + zap.Error(err)) + return stats, err + }, r.StartKey, r.EndKey); err != nil { + logutil.DDLLogger().Warn("Get error when do flashback", zap.Error(err)) + return ver, errors.Trace(err) + } + } + + asyncNotifyEvent(d, statsutil.NewFlashbackClusterEvent()) + job.State = model.JobStateDone + job.SchemaState = model.StatePublic + return updateSchemaVersion(d, t, job) + } + return ver, nil +} + +func finishFlashbackCluster(w *worker, job *model.Job) error { + // Didn't do anything during flashback, return directly + if job.SchemaState == model.StateNone { + return nil + } + + var flashbackTS, lockedRegions, startTS, commitTS uint64 + var pdScheduleValue map[string]any + var autoAnalyzeValue, readOnlyValue, ttlJobEnableValue string + var gcEnabled bool + + if err := job.DecodeArgs(&flashbackTS, &pdScheduleValue, &gcEnabled, &autoAnalyzeValue, &readOnlyValue, &lockedRegions, &startTS, &commitTS, &ttlJobEnableValue); err != nil { + return errors.Trace(err) + } + sess, err := w.sessPool.Get() + if err != nil { + return errors.Trace(err) + } + defer w.sessPool.Put(sess) + + err = kv.RunInNewTxn(w.ctx, w.store, true, func(context.Context, kv.Transaction) error { + if err = recoverPDSchedule(w.ctx, pdScheduleValue); err != nil { + return err + } + if gcEnabled { + if err = gcutil.EnableGC(sess); err != nil { + return err + } + } + if err = setTiDBSuperReadOnly(w.ctx, sess, readOnlyValue); err != nil { + return err + } + + if job.IsCancelled() { + // only restore `tidb_ttl_job_enable` when flashback failed + if err = setTiDBTTLJobEnable(w.ctx, sess, ttlJobEnableValue); err != nil { + return err + } + } + + return setTiDBEnableAutoAnalyze(w.ctx, sess, autoAnalyzeValue) + }) + if err != nil { + return err + } + + return 
nil +} diff --git a/pkg/ddl/column.go b/pkg/ddl/column.go index 5b889c8f73464..4b5e656d0389a 100644 --- a/pkg/ddl/column.go +++ b/pkg/ddl/column.go @@ -172,7 +172,7 @@ func onDropColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) } case model.StateWriteOnly: // write only -> delete only - failpoint.InjectCall("onDropColumnStateWriteOnly") + failpoint.Call(_curpkg_("onDropColumnStateWriteOnly")) colInfo.State = model.StateDeleteOnly tblInfo.MoveColumnInfo(colInfo.Offset, len(tblInfo.Columns)-1) if len(idxInfos) > 0 { @@ -511,7 +511,7 @@ var TestReorgGoroutineRunning = make(chan any) // updateCurrentElement update the current element for reorgInfo. func (w *worker) updateCurrentElement(t table.Table, reorgInfo *reorgInfo) error { - failpoint.Inject("mockInfiniteReorgLogic", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockInfiniteReorgLogic")); _err_ == nil { //nolint:forcetypeassert if val.(bool) { a := new(any) @@ -520,11 +520,11 @@ func (w *worker) updateCurrentElement(t table.Table, reorgInfo *reorgInfo) error time.Sleep(30 * time.Millisecond) if w.isReorgCancelled(reorgInfo.Job.ID) { // Job is cancelled. So it can't be done. - failpoint.Return(dbterror.ErrCancelledDDLJob) + return dbterror.ErrCancelledDDLJob } } } - }) + } // TODO: Support partition tables. if bytes.Equal(reorgInfo.currElement.TypeKey, meta.ColumnElementKey) { //nolint:forcetypeassert @@ -631,11 +631,11 @@ func newUpdateColumnWorker(id int, t table.PhysicalTable, decodeColMap map[int64 } } rowDecoder := decoder.NewRowDecoder(t, t.WritableCols(), decodeColMap) - failpoint.Inject("forceRowLevelChecksumOnUpdateColumnBackfill", func() { + if _, _err_ := failpoint.Eval(_curpkg_("forceRowLevelChecksumOnUpdateColumnBackfill")); _err_ == nil { orig := variable.EnableRowLevelChecksum.Load() defer variable.EnableRowLevelChecksum.Store(orig) variable.EnableRowLevelChecksum.Store(true) - }) + } return &updateColumnWorker{ backfillCtx: bCtx, oldColInfo: oldCol, @@ -757,14 +757,14 @@ func (w *updateColumnWorker) getRowRecord(handle kv.Handle, recordKey []byte, ra recordWarning = errors.Cause(w.reformatErrors(warn[0].Err)).(*terror.Error) } - failpoint.Inject("MockReorgTimeoutInOneRegion", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("MockReorgTimeoutInOneRegion")); _err_ == nil { //nolint:forcetypeassert if val.(bool) { if handle.IntValue() == 3000 && atomic.CompareAndSwapInt32(&TestCheckReorgTimeout, 0, 1) { - failpoint.Return(errors.Trace(dbterror.ErrWaitReorgTimeout)) + return errors.Trace(dbterror.ErrWaitReorgTimeout) } } - }) + } w.rowMap[w.newColInfo.ID] = newColVal _, err = w.rowDecoder.EvalRemainedExprColumnMap(w.exprCtx, w.rowMap) @@ -1158,12 +1158,12 @@ func modifyColsFromNull2NotNull(w *worker, dbInfo *model.DBInfo, tblInfo *model. defer w.sessPool.Put(sctx) skipCheck := false - failpoint.Inject("skipMockContextDoExec", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("skipMockContextDoExec")); _err_ == nil { //nolint:forcetypeassert if val.(bool) { skipCheck = true } - }) + } if !skipCheck { // If there is a null value inserted, it cannot be modified and needs to be rollback. err = checkForNullValue(w.ctx, sctx, isDataTruncated, dbInfo.Name, tblInfo.Name, newCol, cols...) diff --git a/pkg/ddl/column.go__failpoint_stash__ b/pkg/ddl/column.go__failpoint_stash__ new file mode 100644 index 0000000000000..5b889c8f73464 --- /dev/null +++ b/pkg/ddl/column.go__failpoint_stash__ @@ -0,0 +1,1320 @@ +// Copyright 2015 PingCAP, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl + +import ( + "bytes" + "context" + "encoding/hex" + "fmt" + "math/bits" + "strings" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl/logutil" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + contextutil "github.com/pingcap/tidb/pkg/util/context" + "github.com/pingcap/tidb/pkg/util/dbterror" + decoder "github.com/pingcap/tidb/pkg/util/rowDecoder" + "github.com/pingcap/tidb/pkg/util/rowcodec" + kvutil "github.com/tikv/client-go/v2/util" + "go.uber.org/zap" +) + +// InitAndAddColumnToTable initializes the ColumnInfo in-place and adds it to the table. +func InitAndAddColumnToTable(tblInfo *model.TableInfo, colInfo *model.ColumnInfo) *model.ColumnInfo { + cols := tblInfo.Columns + colInfo.ID = AllocateColumnID(tblInfo) + colInfo.State = model.StateNone + // To support add column asynchronous, we should mark its offset as the last column. + // So that we can use origin column offset to get value from row. + colInfo.Offset = len(cols) + // Append the column info to the end of the tblInfo.Columns. + // It will reorder to the right offset in "Columns" when it state change to public. + tblInfo.Columns = append(cols, colInfo) + return colInfo +} + +func checkAddColumn(t *meta.Meta, job *model.Job) (*model.TableInfo, *model.ColumnInfo, *model.ColumnInfo, + *ast.ColumnPosition, bool /* ifNotExists */, error) { + schemaID := job.SchemaID + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) + if err != nil { + return nil, nil, nil, nil, false, errors.Trace(err) + } + col := &model.ColumnInfo{} + pos := &ast.ColumnPosition{} + offset := 0 + ifNotExists := false + err = job.DecodeArgs(col, pos, &offset, &ifNotExists) + if err != nil { + job.State = model.JobStateCancelled + return nil, nil, nil, nil, false, errors.Trace(err) + } + + columnInfo := model.FindColumnInfo(tblInfo.Columns, col.Name.L) + if columnInfo != nil { + if columnInfo.State == model.StatePublic { + // We already have a column with the same column name. 
+ job.State = model.JobStateCancelled + return nil, nil, nil, nil, ifNotExists, infoschema.ErrColumnExists.GenWithStackByArgs(col.Name) + } + } + + err = CheckAfterPositionExists(tblInfo, pos) + if err != nil { + job.State = model.JobStateCancelled + return nil, nil, nil, nil, false, infoschema.ErrColumnExists.GenWithStackByArgs(col.Name) + } + + return tblInfo, columnInfo, col, pos, false, nil +} + +// CheckAfterPositionExists makes sure the column specified in AFTER clause is exists. +// For example, ALTER TABLE t ADD COLUMN c3 INT AFTER c1. +func CheckAfterPositionExists(tblInfo *model.TableInfo, pos *ast.ColumnPosition) error { + if pos != nil && pos.Tp == ast.ColumnPositionAfter { + c := model.FindColumnInfo(tblInfo.Columns, pos.RelativeColumn.Name.L) + if c == nil { + return infoschema.ErrColumnNotExists.GenWithStackByArgs(pos.RelativeColumn, tblInfo.Name) + } + } + return nil +} + +func setIndicesState(indexInfos []*model.IndexInfo, state model.SchemaState) { + for _, indexInfo := range indexInfos { + indexInfo.State = state + } +} + +func checkDropColumnForStatePublic(colInfo *model.ColumnInfo) (err error) { + // When the dropping column has not-null flag and it hasn't the default value, we can backfill the column value like "add column". + // NOTE: If the state of StateWriteOnly can be rollbacked, we'd better reconsider the original default value. + // And we need consider the column without not-null flag. + if colInfo.GetOriginDefaultValue() == nil && mysql.HasNotNullFlag(colInfo.GetFlag()) { + // If the column is timestamp default current_timestamp, and DDL owner is new version TiDB that set column.Version to 1, + // then old TiDB update record in the column write only stage will uses the wrong default value of the dropping column. + // Because new version of the column default value is UTC time, but old version TiDB will think the default value is the time in system timezone. + // But currently will be ok, because we can't cancel the drop column job when the job is running, + // so the column will be dropped succeed and client will never see the wrong default value of the dropped column. + // More info about this problem, see PR#9115. + originDefVal, err := generateOriginDefaultValue(colInfo, nil) + if err != nil { + return err + } + return colInfo.SetOriginDefaultValue(originDefVal) + } + return nil +} + +func onDropColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + tblInfo, colInfo, idxInfos, ifExists, err := checkDropColumn(d, t, job) + if err != nil { + if ifExists && dbterror.ErrCantDropFieldOrKey.Equal(err) { + // Convert the "not exists" error to a warning. + job.Warning = toTError(err) + job.State = model.JobStateDone + return ver, nil + } + return ver, errors.Trace(err) + } + if job.MultiSchemaInfo != nil && !job.IsRollingback() && job.MultiSchemaInfo.Revertible { + job.MarkNonRevertible() + job.SchemaState = colInfo.State + // Store the mark and enter the next DDL handling loop. 
+ return updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, false) + } + + originalState := colInfo.State + switch colInfo.State { + case model.StatePublic: + // public -> write only + colInfo.State = model.StateWriteOnly + setIndicesState(idxInfos, model.StateWriteOnly) + tblInfo.MoveColumnInfo(colInfo.Offset, len(tblInfo.Columns)-1) + err = checkDropColumnForStatePublic(colInfo) + if err != nil { + return ver, errors.Trace(err) + } + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, originalState != colInfo.State) + if err != nil { + return ver, errors.Trace(err) + } + case model.StateWriteOnly: + // write only -> delete only + failpoint.InjectCall("onDropColumnStateWriteOnly") + colInfo.State = model.StateDeleteOnly + tblInfo.MoveColumnInfo(colInfo.Offset, len(tblInfo.Columns)-1) + if len(idxInfos) > 0 { + newIndices := make([]*model.IndexInfo, 0, len(tblInfo.Indices)) + for _, idx := range tblInfo.Indices { + if !indexInfoContains(idx.ID, idxInfos) { + newIndices = append(newIndices, idx) + } + } + tblInfo.Indices = newIndices + } + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != colInfo.State) + if err != nil { + return ver, errors.Trace(err) + } + job.Args = append(job.Args, indexInfosToIDList(idxInfos)) + case model.StateDeleteOnly: + // delete only -> reorganization + colInfo.State = model.StateDeleteReorganization + tblInfo.MoveColumnInfo(colInfo.Offset, len(tblInfo.Columns)-1) + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != colInfo.State) + if err != nil { + return ver, errors.Trace(err) + } + case model.StateDeleteReorganization: + // reorganization -> absent + // All reorganization jobs are done, drop this column. + tblInfo.MoveColumnInfo(colInfo.Offset, len(tblInfo.Columns)-1) + tblInfo.Columns = tblInfo.Columns[:len(tblInfo.Columns)-1] + colInfo.State = model.StateNone + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != colInfo.State) + if err != nil { + return ver, errors.Trace(err) + } + + // Finish this job. + if job.IsRollingback() { + job.FinishTableJob(model.JobStateRollbackDone, model.StateNone, ver, tblInfo) + } else { + // We should set related index IDs for job + job.FinishTableJob(model.JobStateDone, model.StateNone, ver, tblInfo) + job.Args = append(job.Args, getPartitionIDs(tblInfo)) + } + default: + return ver, errors.Trace(dbterror.ErrInvalidDDLJob.GenWithStackByArgs("table", tblInfo.State)) + } + job.SchemaState = colInfo.State + return ver, errors.Trace(err) +} + +func checkDropColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (*model.TableInfo, *model.ColumnInfo, []*model.IndexInfo, bool /* ifExists */, error) { + schemaID := job.SchemaID + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) + if err != nil { + return nil, nil, nil, false, errors.Trace(err) + } + + var colName model.CIStr + var ifExists bool + // indexIDs is used to make sure we don't truncate args when decoding the rawArgs. 
+ var indexIDs []int64 + err = job.DecodeArgs(&colName, &ifExists, &indexIDs) + if err != nil { + job.State = model.JobStateCancelled + return nil, nil, nil, false, errors.Trace(err) + } + + colInfo := model.FindColumnInfo(tblInfo.Columns, colName.L) + if colInfo == nil || colInfo.Hidden { + job.State = model.JobStateCancelled + return nil, nil, nil, ifExists, dbterror.ErrCantDropFieldOrKey.GenWithStack("column %s doesn't exist", colName) + } + if err = isDroppableColumn(tblInfo, colName); err != nil { + job.State = model.JobStateCancelled + return nil, nil, nil, false, errors.Trace(err) + } + if err = checkDropColumnWithForeignKeyConstraintInOwner(d, t, job, tblInfo, colName.L); err != nil { + return nil, nil, nil, false, errors.Trace(err) + } + if err = checkDropColumnWithTTLConfig(tblInfo, colName.L); err != nil { + return nil, nil, nil, false, errors.Trace(err) + } + idxInfos := listIndicesWithColumn(colName.L, tblInfo.Indices) + return tblInfo, colInfo, idxInfos, false, nil +} + +func isDroppableColumn(tblInfo *model.TableInfo, colName model.CIStr) error { + if ok, dep, isHidden := hasDependentByGeneratedColumn(tblInfo, colName); ok { + if isHidden { + return dbterror.ErrDependentByFunctionalIndex.GenWithStackByArgs(dep) + } + return dbterror.ErrDependentByGeneratedColumn.GenWithStackByArgs(dep) + } + + if len(tblInfo.Columns) == 1 { + return dbterror.ErrCantRemoveAllFields.GenWithStack("can't drop only column %s in table %s", + colName, tblInfo.Name) + } + // We only support dropping column with single-value none Primary Key index covered now. + err := isColumnCanDropWithIndex(colName.L, tblInfo.Indices) + if err != nil { + return err + } + err = IsColumnDroppableWithCheckConstraint(colName, tblInfo) + if err != nil { + return err + } + return nil +} + +func onSetDefaultValue(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + newCol := &model.ColumnInfo{} + err := job.DecodeArgs(newCol) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + return updateColumnDefaultValue(d, t, job, newCol, &newCol.Name) +} + +func setIdxIDName(idxInfo *model.IndexInfo, newID int64, newName model.CIStr) { + idxInfo.ID = newID + idxInfo.Name = newName +} + +// SetIdxColNameOffset sets index column name and offset from changing ColumnInfo. 
+func SetIdxColNameOffset(idxCol *model.IndexColumn, changingCol *model.ColumnInfo) { + idxCol.Name = changingCol.Name + idxCol.Offset = changingCol.Offset + canPrefix := types.IsTypePrefixable(changingCol.GetType()) + if !canPrefix || (changingCol.GetFlen() <= idxCol.Length) { + idxCol.Length = types.UnspecifiedLength + } +} + +func removeChangingColAndIdxs(tblInfo *model.TableInfo, changingColID int64) { + restIdx := tblInfo.Indices[:0] + for _, idx := range tblInfo.Indices { + if !idx.HasColumnInIndexColumns(tblInfo, changingColID) { + restIdx = append(restIdx, idx) + } + } + tblInfo.Indices = restIdx + + restCols := tblInfo.Columns[:0] + for _, c := range tblInfo.Columns { + if c.ID != changingColID { + restCols = append(restCols, c) + } + } + tblInfo.Columns = restCols +} + +func replaceOldColumn(tblInfo *model.TableInfo, oldCol, changingCol *model.ColumnInfo, + newName model.CIStr) *model.ColumnInfo { + tblInfo.MoveColumnInfo(changingCol.Offset, len(tblInfo.Columns)-1) + changingCol = updateChangingCol(changingCol, newName, oldCol.Offset) + tblInfo.Columns[oldCol.Offset] = changingCol + tblInfo.Columns = tblInfo.Columns[:len(tblInfo.Columns)-1] + return changingCol +} + +func replaceOldIndexes(tblInfo *model.TableInfo, changingIdxs []*model.IndexInfo) { + // Remove the changing indexes. + for i, idx := range tblInfo.Indices { + for _, cIdx := range changingIdxs { + if cIdx.ID == idx.ID { + tblInfo.Indices[i] = nil + break + } + } + } + tmp := tblInfo.Indices[:0] + for _, idx := range tblInfo.Indices { + if idx != nil { + tmp = append(tmp, idx) + } + } + tblInfo.Indices = tmp + // Replace the old indexes with changing indexes. + for _, cIdx := range changingIdxs { + // The index name should be changed from '_Idx$_name' to 'name'. + idxName := getChangingIndexOriginName(cIdx) + for i, idx := range tblInfo.Indices { + if strings.EqualFold(idxName, idx.Name.O) { + cIdx.Name = model.NewCIStr(idxName) + tblInfo.Indices[i] = cIdx + break + } + } + } +} + +// updateNewIdxColsNameOffset updates the name&offset of the index column. +func updateNewIdxColsNameOffset(changingIdxs []*model.IndexInfo, + oldName model.CIStr, changingCol *model.ColumnInfo) { + for _, idx := range changingIdxs { + for _, col := range idx.Columns { + if col.Name.L == oldName.L { + SetIdxColNameOffset(col, changingCol) + } + } + } +} + +// filterIndexesToRemove filters out the indexes that can be removed. +func filterIndexesToRemove(changingIdxs []*model.IndexInfo, colName model.CIStr, tblInfo *model.TableInfo) []*model.IndexInfo { + indexesToRemove := make([]*model.IndexInfo, 0, len(changingIdxs)) + for _, idx := range changingIdxs { + var hasOtherChangingCol bool + for _, col := range idx.Columns { + if col.Name.L == colName.L { + continue // ignore the current modifying column. + } + if !hasOtherChangingCol { + hasOtherChangingCol = tblInfo.Columns[col.Offset].ChangeStateInfo != nil + } + } + // For the indexes that still contains other changing column, skip removing it now. + // We leave the removal work to the last modify column job. + if !hasOtherChangingCol { + indexesToRemove = append(indexesToRemove, idx) + } + } + return indexesToRemove +} + +func updateChangingCol(col *model.ColumnInfo, newName model.CIStr, newOffset int) *model.ColumnInfo { + col.Name = newName + col.ChangeStateInfo = nil + col.Offset = newOffset + // After changing the column, the column's type is change, so it needs to set OriginDefaultValue back + // so that there is no error in getting the default value from OriginDefaultValue. 
+ // Besides, nil data that was not backfilled in the "add column" is backfilled after the column is changed. + // So it can set OriginDefaultValue to nil. + col.OriginDefaultValue = nil + return col +} + +func buildRelatedIndexInfos(tblInfo *model.TableInfo, colID int64) []*model.IndexInfo { + var indexInfos []*model.IndexInfo + for _, idx := range tblInfo.Indices { + if idx.HasColumnInIndexColumns(tblInfo, colID) { + indexInfos = append(indexInfos, idx) + } + } + return indexInfos +} + +func buildRelatedIndexIDs(tblInfo *model.TableInfo, colID int64) []int64 { + var oldIdxIDs []int64 + for _, idx := range tblInfo.Indices { + if idx.HasColumnInIndexColumns(tblInfo, colID) { + oldIdxIDs = append(oldIdxIDs, idx.ID) + } + } + return oldIdxIDs +} + +// LocateOffsetToMove returns the offset of the column to move. +func LocateOffsetToMove(currentOffset int, pos *ast.ColumnPosition, tblInfo *model.TableInfo) (destOffset int, err error) { + if pos == nil { + return currentOffset, nil + } + // Get column offset. + switch pos.Tp { + case ast.ColumnPositionFirst: + return 0, nil + case ast.ColumnPositionAfter: + c := model.FindColumnInfo(tblInfo.Columns, pos.RelativeColumn.Name.L) + if c == nil || c.State != model.StatePublic { + return 0, infoschema.ErrColumnNotExists.GenWithStackByArgs(pos.RelativeColumn, tblInfo.Name) + } + if currentOffset <= c.Offset { + return c.Offset, nil + } + return c.Offset + 1, nil + case ast.ColumnPositionNone: + return currentOffset, nil + default: + return 0, errors.Errorf("unknown column position type") + } +} + +// BuildElements is exported for testing. +func BuildElements(changingCol *model.ColumnInfo, changingIdxs []*model.IndexInfo) []*meta.Element { + elements := make([]*meta.Element, 0, len(changingIdxs)+1) + elements = append(elements, &meta.Element{ID: changingCol.ID, TypeKey: meta.ColumnElementKey}) + for _, idx := range changingIdxs { + elements = append(elements, &meta.Element{ID: idx.ID, TypeKey: meta.IndexElementKey}) + } + return elements +} + +func (w *worker) updatePhysicalTableRow(t table.Table, reorgInfo *reorgInfo) error { + logutil.DDLLogger().Info("start to update table row", zap.Stringer("job", reorgInfo.Job), zap.Stringer("reorgInfo", reorgInfo)) + if tbl, ok := t.(table.PartitionedTable); ok { + done := false + for !done { + p := tbl.GetPartition(reorgInfo.PhysicalTableID) + if p == nil { + return dbterror.ErrCancelledDDLJob.GenWithStack("Can not find partition id %d for table %d", reorgInfo.PhysicalTableID, t.Meta().ID) + } + workType := typeReorgPartitionWorker + switch reorgInfo.Job.Type { + case model.ActionReorganizePartition, + model.ActionRemovePartitioning, + model.ActionAlterTablePartitioning: + // Expected + default: + // workType = typeUpdateColumnWorker + // TODO: Support Modify Column on partitioned table + // https://github.com/pingcap/tidb/issues/38297 + return dbterror.ErrCancelledDDLJob.GenWithStack("Modify Column on partitioned table / typeUpdateColumnWorker not yet supported.") + } + err := w.writePhysicalTableRecord(w.ctx, w.sessPool, p, workType, reorgInfo) + if err != nil { + return err + } + done, err = updateReorgInfo(w.sessPool, tbl, reorgInfo) + if err != nil { + return errors.Trace(err) + } + } + return nil + } + if tbl, ok := t.(table.PhysicalTable); ok { + return w.writePhysicalTableRecord(w.ctx, w.sessPool, tbl, typeUpdateColumnWorker, reorgInfo) + } + return dbterror.ErrCancelledDDLJob.GenWithStack("internal error for phys tbl id: %d tbl id: %d", reorgInfo.PhysicalTableID, t.Meta().ID) +} + +// 
TestReorgGoroutineRunning is only used in test to indicate the reorg goroutine has been started. +var TestReorgGoroutineRunning = make(chan any) + +// updateCurrentElement update the current element for reorgInfo. +func (w *worker) updateCurrentElement(t table.Table, reorgInfo *reorgInfo) error { + failpoint.Inject("mockInfiniteReorgLogic", func(val failpoint.Value) { + //nolint:forcetypeassert + if val.(bool) { + a := new(any) + TestReorgGoroutineRunning <- a + for { + time.Sleep(30 * time.Millisecond) + if w.isReorgCancelled(reorgInfo.Job.ID) { + // Job is cancelled. So it can't be done. + failpoint.Return(dbterror.ErrCancelledDDLJob) + } + } + } + }) + // TODO: Support partition tables. + if bytes.Equal(reorgInfo.currElement.TypeKey, meta.ColumnElementKey) { + //nolint:forcetypeassert + err := w.updatePhysicalTableRow(t.(table.PhysicalTable), reorgInfo) + if err != nil { + return errors.Trace(err) + } + } + + if _, ok := t.(table.PartitionedTable); ok { + // TODO: remove when modify column of partitioned table is supported + // https://github.com/pingcap/tidb/issues/38297 + return dbterror.ErrCancelledDDLJob.GenWithStack("Modify Column on partitioned table / typeUpdateColumnWorker not yet supported.") + } + // Get the original start handle and end handle. + currentVer, err := getValidCurrentVersion(reorgInfo.d.store) + if err != nil { + return errors.Trace(err) + } + //nolint:forcetypeassert + originalStartHandle, originalEndHandle, err := getTableRange(reorgInfo.NewJobContext(), reorgInfo.d, t.(table.PhysicalTable), currentVer.Ver, reorgInfo.Job.Priority) + if err != nil { + return errors.Trace(err) + } + + startElementOffset := 0 + startElementOffsetToResetHandle := -1 + // This backfill job starts with backfilling index data, whose index ID is currElement.ID. + if bytes.Equal(reorgInfo.currElement.TypeKey, meta.IndexElementKey) { + for i, element := range reorgInfo.elements[1:] { + if reorgInfo.currElement.ID == element.ID { + startElementOffset = i + startElementOffsetToResetHandle = i + break + } + } + } + + for i := startElementOffset; i < len(reorgInfo.elements[1:]); i++ { + // This backfill job has been exited during processing. At that time, the element is reorgInfo.elements[i+1] and handle range is [reorgInfo.StartHandle, reorgInfo.EndHandle]. + // Then the handle range of the rest elements' is [originalStartHandle, originalEndHandle]. + if i == startElementOffsetToResetHandle+1 { + reorgInfo.StartKey, reorgInfo.EndKey = originalStartHandle, originalEndHandle + } + + // Update the element in the reorgInfo for updating the reorg meta below. + reorgInfo.currElement = reorgInfo.elements[i+1] + // Write the reorg info to store so the whole reorganize process can recover from panic. + err := reorgInfo.UpdateReorgMeta(reorgInfo.StartKey, w.sessPool) + logutil.DDLLogger().Info("update column and indexes", + zap.Int64("job ID", reorgInfo.Job.ID), + zap.Stringer("element", reorgInfo.currElement), + zap.String("start key", hex.EncodeToString(reorgInfo.StartKey)), + zap.String("end key", hex.EncodeToString(reorgInfo.EndKey))) + if err != nil { + return errors.Trace(err) + } + err = w.addTableIndex(t, reorgInfo) + if err != nil { + return errors.Trace(err) + } + } + return nil +} + +type updateColumnWorker struct { + *backfillCtx + oldColInfo *model.ColumnInfo + newColInfo *model.ColumnInfo + + // The following attributes are used to reduce memory allocation. 
+ rowRecords []*rowRecord + rowDecoder *decoder.RowDecoder + + rowMap map[int64]types.Datum + + checksumNeeded bool +} + +func newUpdateColumnWorker(id int, t table.PhysicalTable, decodeColMap map[int64]decoder.Column, reorgInfo *reorgInfo, jc *JobContext) (*updateColumnWorker, error) { + bCtx, err := newBackfillCtx(id, reorgInfo, reorgInfo.SchemaName, t, jc, "update_col_rate", false) + if err != nil { + return nil, err + } + + sessCtx := bCtx.sessCtx + sessCtx.GetSessionVars().StmtCtx.SetTypeFlags( + sessCtx.GetSessionVars().StmtCtx.TypeFlags(). + WithIgnoreZeroDateErr(!reorgInfo.ReorgMeta.SQLMode.HasStrictMode())) + bCtx.exprCtx = bCtx.sessCtx.GetExprCtx() + bCtx.tblCtx = bCtx.sessCtx.GetTableCtx() + + if !bytes.Equal(reorgInfo.currElement.TypeKey, meta.ColumnElementKey) { + logutil.DDLLogger().Error("Element type for updateColumnWorker incorrect", zap.String("jobQuery", reorgInfo.Query), + zap.Stringer("reorgInfo", reorgInfo)) + return nil, nil + } + var oldCol, newCol *model.ColumnInfo + for _, col := range t.WritableCols() { + if col.ID == reorgInfo.currElement.ID { + newCol = col.ColumnInfo + oldCol = table.FindCol(t.Cols(), getChangingColumnOriginName(newCol)).ColumnInfo + break + } + } + rowDecoder := decoder.NewRowDecoder(t, t.WritableCols(), decodeColMap) + failpoint.Inject("forceRowLevelChecksumOnUpdateColumnBackfill", func() { + orig := variable.EnableRowLevelChecksum.Load() + defer variable.EnableRowLevelChecksum.Store(orig) + variable.EnableRowLevelChecksum.Store(true) + }) + return &updateColumnWorker{ + backfillCtx: bCtx, + oldColInfo: oldCol, + newColInfo: newCol, + rowDecoder: rowDecoder, + rowMap: make(map[int64]types.Datum, len(decodeColMap)), + checksumNeeded: variable.EnableRowLevelChecksum.Load(), + }, nil +} + +func (w *updateColumnWorker) AddMetricInfo(cnt float64) { + w.metricCounter.Add(cnt) +} + +func (*updateColumnWorker) String() string { + return typeUpdateColumnWorker.String() +} + +func (w *updateColumnWorker) GetCtx() *backfillCtx { + return w.backfillCtx +} + +type rowRecord struct { + key []byte // It's used to lock a record. Record it to reduce the encoding time. + vals []byte // It's the record. + warning *terror.Error // It's used to record the cast warning of a record. +} + +// getNextHandleKey gets next handle of entry that we are going to process. +func getNextHandleKey(taskRange reorgBackfillTask, + taskDone bool, lastAccessedHandle kv.Key) (nextHandle kv.Key) { + if !taskDone { + // The task is not done. So we need to pick the last processed entry's handle and add one. + return lastAccessedHandle.Next() + } + + return taskRange.endKey.Next() +} + +func (w *updateColumnWorker) fetchRowColVals(txn kv.Transaction, taskRange reorgBackfillTask) ([]*rowRecord, kv.Key, bool, error) { + w.rowRecords = w.rowRecords[:0] + startTime := time.Now() + + // taskDone means that the added handle is out of taskRange.endHandle. 
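// Illustrative sketch (standalone, not part of this patch): the batching contract of
// fetchRowColVals above, reduced to a sorted in-memory key list. A batch stops either
// at batchCnt records or at the task's end key; the caller gets the records, the key
// to resume from, and whether the task range is exhausted. Snapshot iteration and row
// decoding from the real worker are omitted.
package main

import (
	"bytes"
	"fmt"
)

func fetchBatch(keys [][]byte, startKey, endKey []byte, batchCnt int) (batch [][]byte, nextKey []byte, taskDone bool) {
	for _, k := range keys {
		if bytes.Compare(k, startKey) < 0 {
			continue // before the task range
		}
		if bytes.Compare(k, endKey) > 0 {
			taskDone = true // past the task range
			break
		}
		if len(batch) >= batchCnt {
			break // batch is full; resume later from nextKey
		}
		batch = append(batch, k)
		nextKey = k
	}
	if len(batch) == 0 {
		taskDone = true
	}
	if taskDone {
		return batch, append(endKey, 0), true // resume past the end key, like endKey.Next()
	}
	return batch, append(nextKey, 0), false // resume just after the last handled key
}

func main() {
	keys := [][]byte{[]byte("a"), []byte("b"), []byte("c"), []byte("d")}
	batch, next, done := fetchBatch(keys, []byte("a"), []byte("z"), 2)
	fmt.Printf("%d %q %v\n", len(batch), next, done) // 2 "b\x00" false
}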
+ taskDone := false + var lastAccessedHandle kv.Key + oprStartTime := startTime + err := iterateSnapshotKeys(w.jobContext, w.ddlCtx.store, taskRange.priority, taskRange.physicalTable.RecordPrefix(), + txn.StartTS(), taskRange.startKey, taskRange.endKey, func(handle kv.Handle, recordKey kv.Key, rawRow []byte) (bool, error) { + oprEndTime := time.Now() + logSlowOperations(oprEndTime.Sub(oprStartTime), "iterateSnapshotKeys in updateColumnWorker fetchRowColVals", 0) + oprStartTime = oprEndTime + + taskDone = recordKey.Cmp(taskRange.endKey) >= 0 + + if taskDone || len(w.rowRecords) >= w.batchCnt { + return false, nil + } + + if err1 := w.getRowRecord(handle, recordKey, rawRow); err1 != nil { + return false, errors.Trace(err1) + } + lastAccessedHandle = recordKey + if recordKey.Cmp(taskRange.endKey) == 0 { + taskDone = true + return false, nil + } + return true, nil + }) + + if len(w.rowRecords) == 0 { + taskDone = true + } + + logutil.DDLLogger().Debug("txn fetches handle info", + zap.Uint64("txnStartTS", txn.StartTS()), + zap.String("taskRange", taskRange.String()), + zap.Duration("takeTime", time.Since(startTime))) + return w.rowRecords, getNextHandleKey(taskRange, taskDone, lastAccessedHandle), taskDone, errors.Trace(err) +} + +func (w *updateColumnWorker) getRowRecord(handle kv.Handle, recordKey []byte, rawRow []byte) error { + sysTZ := w.loc + _, err := w.rowDecoder.DecodeTheExistedColumnMap(w.exprCtx, handle, rawRow, sysTZ, w.rowMap) + if err != nil { + return errors.Trace(dbterror.ErrCantDecodeRecord.GenWithStackByArgs("column", err)) + } + + if _, ok := w.rowMap[w.newColInfo.ID]; ok { + // The column is already added by update or insert statement, skip it. + w.cleanRowMap() + return nil + } + + var recordWarning *terror.Error + // Since every updateColumnWorker handle their own work individually, we can cache warning in statement context when casting datum. + oldWarn := w.warnings.GetWarnings() + if oldWarn == nil { + oldWarn = []contextutil.SQLWarn{} + } else { + oldWarn = oldWarn[:0] + } + w.warnings.SetWarnings(oldWarn) + val := w.rowMap[w.oldColInfo.ID] + col := w.newColInfo + if val.Kind() == types.KindNull && col.FieldType.GetType() == mysql.TypeTimestamp && mysql.HasNotNullFlag(col.GetFlag()) { + if v, err := expression.GetTimeCurrentTimestamp(w.exprCtx.GetEvalCtx(), col.GetType(), col.GetDecimal()); err == nil { + // convert null value to timestamp should be substituted with current timestamp if NOT_NULL flag is set. 
+ w.rowMap[w.oldColInfo.ID] = v + } + } + newColVal, err := table.CastColumnValue(w.exprCtx, w.rowMap[w.oldColInfo.ID], w.newColInfo, false, false) + if err != nil { + return w.reformatErrors(err) + } + warn := w.warnings.GetWarnings() + if len(warn) != 0 { + //nolint:forcetypeassert + recordWarning = errors.Cause(w.reformatErrors(warn[0].Err)).(*terror.Error) + } + + failpoint.Inject("MockReorgTimeoutInOneRegion", func(val failpoint.Value) { + //nolint:forcetypeassert + if val.(bool) { + if handle.IntValue() == 3000 && atomic.CompareAndSwapInt32(&TestCheckReorgTimeout, 0, 1) { + failpoint.Return(errors.Trace(dbterror.ErrWaitReorgTimeout)) + } + } + }) + + w.rowMap[w.newColInfo.ID] = newColVal + _, err = w.rowDecoder.EvalRemainedExprColumnMap(w.exprCtx, w.rowMap) + if err != nil { + return errors.Trace(err) + } + newColumnIDs := make([]int64, 0, len(w.rowMap)) + newRow := make([]types.Datum, 0, len(w.rowMap)) + for colID, val := range w.rowMap { + newColumnIDs = append(newColumnIDs, colID) + newRow = append(newRow, val) + } + rd := w.tblCtx.GetRowEncodingConfig().RowEncoder + ec := w.exprCtx.GetEvalCtx().ErrCtx() + var checksum rowcodec.Checksum + if w.checksumNeeded { + checksum = rowcodec.RawChecksum{Key: recordKey} + } + newRowVal, err := tablecodec.EncodeRow(sysTZ, newRow, newColumnIDs, nil, nil, checksum, rd) + err = ec.HandleError(err) + if err != nil { + return errors.Trace(err) + } + + w.rowRecords = append(w.rowRecords, &rowRecord{key: recordKey, vals: newRowVal, warning: recordWarning}) + w.cleanRowMap() + return nil +} + +// reformatErrors casted error because `convertTo` function couldn't package column name and datum value for some errors. +func (w *updateColumnWorker) reformatErrors(err error) error { + // Since row count is not precious in concurrent reorganization, here we substitute row count with datum value. + if types.ErrTruncated.Equal(err) || types.ErrDataTooLong.Equal(err) { + dStr := datumToStringNoErr(w.rowMap[w.oldColInfo.ID]) + err = types.ErrTruncated.GenWithStack("Data truncated for column '%s', value is '%s'", w.oldColInfo.Name, dStr) + } + + if types.ErrWarnDataOutOfRange.Equal(err) { + dStr := datumToStringNoErr(w.rowMap[w.oldColInfo.ID]) + err = types.ErrWarnDataOutOfRange.GenWithStack("Out of range value for column '%s', the value is '%s'", w.oldColInfo.Name, dStr) + } + return err +} + +func datumToStringNoErr(d types.Datum) string { + if v, err := d.ToString(); err == nil { + return v + } + return fmt.Sprintf("%v", d.GetValue()) +} + +func (w *updateColumnWorker) cleanRowMap() { + for id := range w.rowMap { + delete(w.rowMap, id) + } +} + +// BackfillData will backfill the table record in a transaction. A lock corresponds to a rowKey if the value of rowKey is changed. +func (w *updateColumnWorker) BackfillData(handleRange reorgBackfillTask) (taskCtx backfillTaskContext, errInTxn error) { + oprStartTime := time.Now() + ctx := kv.WithInternalSourceAndTaskType(context.Background(), w.jobContext.ddlJobSourceType(), kvutil.ExplicitTypeDDL) + errInTxn = kv.RunInNewTxn(ctx, w.ddlCtx.store, true, func(_ context.Context, txn kv.Transaction) error { + taskCtx.addedCount = 0 + taskCtx.scanCount = 0 + updateTxnEntrySizeLimitIfNeeded(txn) + + // Because TiCDC do not want this kind of change, + // so we set the lossy DDL reorg txn source to 1 to + // avoid TiCDC to replicate this kind of change. 
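+ // Read any existing txn source value, merge in the lossy-DDL column reorg flag, then write it back to the transaction.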
+ var txnSource uint64 + if val := txn.GetOption(kv.TxnSource); val != nil { + txnSource, _ = val.(uint64) + } + err := kv.SetLossyDDLReorgSource(&txnSource, kv.LossyDDLColumnReorgSource) + if err != nil { + return errors.Trace(err) + } + txn.SetOption(kv.TxnSource, txnSource) + + txn.SetOption(kv.Priority, handleRange.priority) + if tagger := w.GetCtx().getResourceGroupTaggerForTopSQL(handleRange.getJobID()); tagger != nil { + txn.SetOption(kv.ResourceGroupTagger, tagger) + } + txn.SetOption(kv.ResourceGroupName, w.jobContext.resourceGroupName) + + rowRecords, nextKey, taskDone, err := w.fetchRowColVals(txn, handleRange) + if err != nil { + return errors.Trace(err) + } + taskCtx.nextKey = nextKey + taskCtx.done = taskDone + + // Optimize for few warnings! + warningsMap := make(map[errors.ErrorID]*terror.Error, 2) + warningsCountMap := make(map[errors.ErrorID]int64, 2) + for _, rowRecord := range rowRecords { + taskCtx.scanCount++ + + err = txn.Set(rowRecord.key, rowRecord.vals) + if err != nil { + return errors.Trace(err) + } + taskCtx.addedCount++ + if rowRecord.warning != nil { + if _, ok := warningsCountMap[rowRecord.warning.ID()]; ok { + warningsCountMap[rowRecord.warning.ID()]++ + } else { + warningsCountMap[rowRecord.warning.ID()] = 1 + warningsMap[rowRecord.warning.ID()] = rowRecord.warning + } + } + } + + // Collect the warnings. + taskCtx.warnings, taskCtx.warningsCount = warningsMap, warningsCountMap + + return nil + }) + logSlowOperations(time.Since(oprStartTime), "BackfillData", 3000) + + return +} + +func updateChangingObjState(changingCol *model.ColumnInfo, changingIdxs []*model.IndexInfo, schemaState model.SchemaState) { + changingCol.State = schemaState + for _, idx := range changingIdxs { + idx.State = schemaState + } +} + +func checkAndApplyAutoRandomBits(d *ddlCtx, m *meta.Meta, dbInfo *model.DBInfo, tblInfo *model.TableInfo, + oldCol *model.ColumnInfo, newCol *model.ColumnInfo, newAutoRandBits uint64) error { + if newAutoRandBits == 0 { + return nil + } + idAcc := m.GetAutoIDAccessors(dbInfo.ID, tblInfo.ID) + err := checkNewAutoRandomBits(idAcc, oldCol, newCol, newAutoRandBits, tblInfo.AutoRandomRangeBits, tblInfo.SepAutoInc()) + if err != nil { + return err + } + return applyNewAutoRandomBits(d, m, dbInfo, tblInfo, oldCol, newAutoRandBits) +} + +// checkNewAutoRandomBits checks whether the new auto_random bits number can cause overflow. +func checkNewAutoRandomBits(idAccessors meta.AutoIDAccessors, oldCol *model.ColumnInfo, + newCol *model.ColumnInfo, newShardBits, newRangeBits uint64, sepAutoInc bool) error { + shardFmt := autoid.NewShardIDFormat(&newCol.FieldType, newShardBits, newRangeBits) + + idAcc := idAccessors.RandomID() + convertedFromAutoInc := mysql.HasAutoIncrementFlag(oldCol.GetFlag()) + if convertedFromAutoInc { + if sepAutoInc { + idAcc = idAccessors.IncrementID(model.TableInfoVersion5) + } else { + idAcc = idAccessors.RowID() + } + } + // Generate a new auto ID first to prevent concurrent update in DML. + _, err := idAcc.Inc(1) + if err != nil { + return err + } + currentIncBitsVal, err := idAcc.Get() + if err != nil { + return err + } + // Find the max number of available shard bits by + // counting leading zeros in current inc part of auto_random ID. 
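+ // If the incremental part already occupies more bits than the new layout allows, existing IDs could not be represented, so the change is rejected.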
+ usedBits := uint64(64 - bits.LeadingZeros64(uint64(currentIncBitsVal))) + if usedBits > shardFmt.IncrementalBits { + overflowCnt := usedBits - shardFmt.IncrementalBits + errMsg := fmt.Sprintf(autoid.AutoRandomOverflowErrMsg, newShardBits-overflowCnt, newShardBits, oldCol.Name.O) + return dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(errMsg) + } + return nil +} + +func (d *ddlCtx) getAutoIDRequirement() autoid.Requirement { + return &asAutoIDRequirement{ + store: d.store, + autoidCli: d.autoidCli, + } +} + +type asAutoIDRequirement struct { + store kv.Storage + autoidCli *autoid.ClientDiscover +} + +var _ autoid.Requirement = &asAutoIDRequirement{} + +func (r *asAutoIDRequirement) Store() kv.Storage { + return r.store +} + +func (r *asAutoIDRequirement) AutoIDClient() *autoid.ClientDiscover { + return r.autoidCli +} + +// applyNewAutoRandomBits set auto_random bits to TableInfo and +// migrate auto_increment ID to auto_random ID if possible. +func applyNewAutoRandomBits(d *ddlCtx, m *meta.Meta, dbInfo *model.DBInfo, + tblInfo *model.TableInfo, oldCol *model.ColumnInfo, newAutoRandBits uint64) error { + tblInfo.AutoRandomBits = newAutoRandBits + needMigrateFromAutoIncToAutoRand := mysql.HasAutoIncrementFlag(oldCol.GetFlag()) + if !needMigrateFromAutoIncToAutoRand { + return nil + } + autoRandAlloc := autoid.NewAllocatorsFromTblInfo(d.getAutoIDRequirement(), dbInfo.ID, tblInfo).Get(autoid.AutoRandomType) + if autoRandAlloc == nil { + errMsg := fmt.Sprintf(autoid.AutoRandomAllocatorNotFound, dbInfo.Name.O, tblInfo.Name.O) + return dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(errMsg) + } + idAcc := m.GetAutoIDAccessors(dbInfo.ID, tblInfo.ID).RowID() + nextAutoIncID, err := idAcc.Get() + if err != nil { + return errors.Trace(err) + } + err = autoRandAlloc.Rebase(context.Background(), nextAutoIncID, false) + if err != nil { + return errors.Trace(err) + } + if err := idAcc.Del(); err != nil { + return errors.Trace(err) + } + return nil +} + +// checkForNullValue ensure there are no null values of the column of this table. +// `isDataTruncated` indicates whether the new field and the old field type are the same, in order to be compatible with mysql. +func checkForNullValue(ctx context.Context, sctx sessionctx.Context, isDataTruncated bool, schema, table model.CIStr, newCol *model.ColumnInfo, oldCols ...*model.ColumnInfo) error { + needCheckNullValue := false + for _, oldCol := range oldCols { + if oldCol.GetType() != mysql.TypeTimestamp && newCol.GetType() == mysql.TypeTimestamp { + // special case for convert null value of non-timestamp type to timestamp type, null value will be substituted with current timestamp. + continue + } + needCheckNullValue = true + } + if !needCheckNullValue { + return nil + } + var buf strings.Builder + buf.WriteString("select 1 from %n.%n where ") + paramsList := make([]any, 0, 2+len(oldCols)) + paramsList = append(paramsList, schema.L, table.L) + for i, col := range oldCols { + if i == 0 { + buf.WriteString("%n is null") + paramsList = append(paramsList, col.Name.L) + } else { + buf.WriteString(" or %n is null") + paramsList = append(paramsList, col.Name.L) + } + } + buf.WriteString(" limit 1") + //nolint:forcetypeassert + rows, _, err := sctx.GetRestrictedSQLExecutor().ExecRestrictedSQL(ctx, nil, buf.String(), paramsList...) 
+ if err != nil { + return errors.Trace(err) + } + rowCount := len(rows) + if rowCount != 0 { + if isDataTruncated { + return dbterror.ErrWarnDataTruncated.GenWithStackByArgs(newCol.Name.L, rowCount) + } + return dbterror.ErrInvalidUseOfNull + } + return nil +} + +func updateColumnDefaultValue(d *ddlCtx, t *meta.Meta, job *model.Job, newCol *model.ColumnInfo, oldColName *model.CIStr) (ver int64, _ error) { + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) + if err != nil { + return ver, errors.Trace(err) + } + + if job.MultiSchemaInfo != nil && job.MultiSchemaInfo.Revertible { + job.MarkNonRevertible() + // Store the mark and enter the next DDL handling loop. + return updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, false) + } + + oldCol := model.FindColumnInfo(tblInfo.Columns, oldColName.L) + if oldCol == nil || oldCol.State != model.StatePublic { + job.State = model.JobStateCancelled + return ver, infoschema.ErrColumnNotExists.GenWithStackByArgs(newCol.Name, tblInfo.Name) + } + + if hasDefaultValue, _, err := checkColumnDefaultValue(newReorgExprCtx(), table.ToColumn(oldCol.Clone()), newCol.DefaultValue); err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } else if !hasDefaultValue { + job.State = model.JobStateCancelled + return ver, dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(newCol.Name) + } + + // The newCol's offset may be the value of the old schema version, so we can't use newCol directly. + oldCol.DefaultValue = newCol.DefaultValue + oldCol.DefaultValueBit = newCol.DefaultValueBit + oldCol.DefaultIsExpr = newCol.DefaultIsExpr + if mysql.HasNoDefaultValueFlag(newCol.GetFlag()) { + oldCol.AddFlag(mysql.NoDefaultValueFlag) + } else { + oldCol.DelFlag(mysql.NoDefaultValueFlag) + err = checkDefaultValue(newReorgExprCtx(), table.ToColumn(oldCol), true) + if err != nil { + job.State = model.JobStateCancelled + return ver, err + } + } + + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + return ver, nil +} + +func isColumnWithIndex(colName string, indices []*model.IndexInfo) bool { + for _, indexInfo := range indices { + for _, col := range indexInfo.Columns { + if col.Name.L == colName { + return true + } + } + } + return false +} + +func isColumnCanDropWithIndex(colName string, indices []*model.IndexInfo) error { + for _, indexInfo := range indices { + if indexInfo.Primary || len(indexInfo.Columns) > 1 { + for _, col := range indexInfo.Columns { + if col.Name.L == colName { + return dbterror.ErrCantDropColWithIndex.GenWithStack("can't drop column %s with composite index covered or Primary Key covered now", colName) + } + } + } + } + return nil +} + +func listIndicesWithColumn(colName string, indices []*model.IndexInfo) []*model.IndexInfo { + ret := make([]*model.IndexInfo, 0) + for _, indexInfo := range indices { + if len(indexInfo.Columns) == 1 && colName == indexInfo.Columns[0].Name.L { + ret = append(ret, indexInfo) + } + } + return ret +} + +// GetColumnForeignKeyInfo returns the wanted foreign key info +func GetColumnForeignKeyInfo(colName string, fkInfos []*model.FKInfo) *model.FKInfo { + for _, fkInfo := range fkInfos { + for _, col := range fkInfo.Cols { + if col.L == colName { + return fkInfo + } + } + } + return nil +} + +// AllocateColumnID allocates next column ID from TableInfo. 
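+// It increments tblInfo.MaxColumnID and returns the new value, so allocated column IDs grow monotonically.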
+func AllocateColumnID(tblInfo *model.TableInfo) int64 { + tblInfo.MaxColumnID++ + return tblInfo.MaxColumnID +} + +func checkAddColumnTooManyColumns(colNum int) error { + if uint32(colNum) > atomic.LoadUint32(&config.GetGlobalConfig().TableColumnCountLimit) { + return dbterror.ErrTooManyFields + } + return nil +} + +// modifyColsFromNull2NotNull modifies the type definitions of 'null' to 'not null'. +// Introduce the `mysql.PreventNullInsertFlag` flag to prevent users from inserting or updating null values. +func modifyColsFromNull2NotNull(w *worker, dbInfo *model.DBInfo, tblInfo *model.TableInfo, cols []*model.ColumnInfo, newCol *model.ColumnInfo, isDataTruncated bool) error { + // Get sessionctx from context resource pool. + var sctx sessionctx.Context + sctx, err := w.sessPool.Get() + if err != nil { + return errors.Trace(err) + } + defer w.sessPool.Put(sctx) + + skipCheck := false + failpoint.Inject("skipMockContextDoExec", func(val failpoint.Value) { + //nolint:forcetypeassert + if val.(bool) { + skipCheck = true + } + }) + if !skipCheck { + // If there is a null value inserted, it cannot be modified and needs to be rollback. + err = checkForNullValue(w.ctx, sctx, isDataTruncated, dbInfo.Name, tblInfo.Name, newCol, cols...) + if err != nil { + return errors.Trace(err) + } + } + + // Prevent this field from inserting null values. + for _, col := range cols { + col.AddFlag(mysql.PreventNullInsertFlag) + } + return nil +} + +func generateOriginDefaultValue(col *model.ColumnInfo, ctx sessionctx.Context) (any, error) { + var err error + odValue := col.GetDefaultValue() + if odValue == nil && mysql.HasNotNullFlag(col.GetFlag()) || + // It's for drop column and modify column. + (col.DefaultIsExpr && odValue != strings.ToUpper(ast.CurrentTimestamp) && ctx == nil) { + switch col.GetType() { + // Just use enum field's first element for OriginDefaultValue. + case mysql.TypeEnum: + defEnum, verr := types.ParseEnumValue(col.GetElems(), 1) + if verr != nil { + return nil, errors.Trace(verr) + } + defVal := types.NewCollateMysqlEnumDatum(defEnum, col.GetCollate()) + return defVal.ToString() + default: + zeroVal := table.GetZeroValue(col) + odValue, err = zeroVal.ToString() + if err != nil { + return nil, errors.Trace(err) + } + } + } + + if odValue == strings.ToUpper(ast.CurrentTimestamp) { + var t time.Time + if ctx == nil { + t = time.Now() + } else { + t, _ = expression.GetStmtTimestamp(ctx.GetExprCtx().GetEvalCtx()) + } + if col.GetType() == mysql.TypeTimestamp { + odValue = types.NewTime(types.FromGoTime(t.UTC()), col.GetType(), col.GetDecimal()).String() + } else if col.GetType() == mysql.TypeDatetime { + odValue = types.NewTime(types.FromGoTime(t), col.GetType(), col.GetDecimal()).String() + } + return odValue, nil + } + + if col.DefaultIsExpr && ctx != nil { + valStr, ok := odValue.(string) + if !ok { + return nil, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String()) + } + oldValue := strings.ToLower(valStr) + // It's checked in getFuncCallDefaultValue. 
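+ // Only default expressions built on DATE_FORMAT(NOW(), ...) or STR_TO_DATE are accepted here; anything else is rejected as binlog-unsafe.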
+ if !strings.Contains(oldValue, fmt.Sprintf("%s(%s(),", ast.DateFormat, ast.Now)) && + !strings.Contains(oldValue, ast.StrToDate) { + return nil, errors.Trace(dbterror.ErrBinlogUnsafeSystemFunction) + } + + defVal, err := table.GetColDefaultValue(ctx.GetExprCtx(), col) + if err != nil { + return nil, errors.Trace(err) + } + odValue, err = defVal.ToString() + if err != nil { + return nil, errors.Trace(err) + } + } + return odValue, nil +} + +func indexInfoContains(idxID int64, idxInfos []*model.IndexInfo) bool { + for _, idxInfo := range idxInfos { + if idxID == idxInfo.ID { + return true + } + } + return false +} + +func indexInfosToIDList(idxInfos []*model.IndexInfo) []int64 { + ids := make([]int64, 0, len(idxInfos)) + for _, idxInfo := range idxInfos { + ids = append(ids, idxInfo.ID) + } + return ids +} + +func genChangingColumnUniqueName(tblInfo *model.TableInfo, oldCol *model.ColumnInfo) string { + suffix := 0 + newColumnNamePrefix := fmt.Sprintf("%s%s", changingColumnPrefix, oldCol.Name.O) + newColumnLowerName := fmt.Sprintf("%s_%d", strings.ToLower(newColumnNamePrefix), suffix) + // Check whether the new column name is used. + columnNameMap := make(map[string]bool, len(tblInfo.Columns)) + for _, col := range tblInfo.Columns { + columnNameMap[col.Name.L] = true + } + for columnNameMap[newColumnLowerName] { + suffix++ + newColumnLowerName = fmt.Sprintf("%s_%d", strings.ToLower(newColumnNamePrefix), suffix) + } + return fmt.Sprintf("%s_%d", newColumnNamePrefix, suffix) +} + +func genChangingIndexUniqueName(tblInfo *model.TableInfo, idxInfo *model.IndexInfo) string { + suffix := 0 + newIndexNamePrefix := fmt.Sprintf("%s%s", changingIndexPrefix, idxInfo.Name.O) + newIndexLowerName := fmt.Sprintf("%s_%d", strings.ToLower(newIndexNamePrefix), suffix) + // Check whether the new index name is used. + indexNameMap := make(map[string]bool, len(tblInfo.Indices)) + for _, idx := range tblInfo.Indices { + indexNameMap[idx.Name.L] = true + } + for indexNameMap[newIndexLowerName] { + suffix++ + newIndexLowerName = fmt.Sprintf("%s_%d", strings.ToLower(newIndexNamePrefix), suffix) + } + return fmt.Sprintf("%s_%d", newIndexNamePrefix, suffix) +} + +func getChangingIndexOriginName(changingIdx *model.IndexInfo) string { + idxName := strings.TrimPrefix(changingIdx.Name.O, changingIndexPrefix) + // Since the unique idxName may contain the suffix number (indexName_num), better trim the suffix. 
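+ // Cut at the last '_' to drop the numeric suffix; names without '_' are returned unchanged.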
+ var pos int + if pos = strings.LastIndex(idxName, "_"); pos == -1 { + return idxName + } + return idxName[:pos] +} + +func getChangingColumnOriginName(changingColumn *model.ColumnInfo) string { + columnName := strings.TrimPrefix(changingColumn.Name.O, changingColumnPrefix) + var pos int + if pos = strings.LastIndex(columnName, "_"); pos == -1 { + return columnName + } + return columnName[:pos] +} + +func getExpressionIndexOriginName(expressionIdx *model.ColumnInfo) string { + columnName := strings.TrimPrefix(expressionIdx.Name.O, expressionIndexPrefix+"_") + var pos int + if pos = strings.LastIndex(columnName, "_"); pos == -1 { + return columnName + } + return columnName[:pos] +} diff --git a/pkg/ddl/constraint.go b/pkg/ddl/constraint.go index 572c66439889a..01364e895e641 100644 --- a/pkg/ddl/constraint.go +++ b/pkg/ddl/constraint.go @@ -37,11 +37,11 @@ func (w *worker) onAddCheckConstraint(d *ddlCtx, t *meta.Meta, job *model.Job) ( return rollingBackAddConstraint(d, t, job) } - failpoint.Inject("errorBeforeDecodeArgs", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("errorBeforeDecodeArgs")); _err_ == nil { if val.(bool) { - failpoint.Return(ver, errors.New("occur an error before decode args")) + return ver, errors.New("occur an error before decode args") } - }) + } dbInfo, tblInfo, constraintInfoInMeta, constraintInfoInJob, err := checkAddCheckConstraint(t, job) if err != nil { @@ -355,11 +355,11 @@ func findDependentColsInExpr(expr ast.ExprNode) map[string]struct{} { func (w *worker) verifyRemainRecordsForCheckConstraint(dbInfo *model.DBInfo, tableInfo *model.TableInfo, constr *model.ConstraintInfo) error { // Inject a fail-point to skip the remaining records check. - failpoint.Inject("mockVerifyRemainDataSuccess", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockVerifyRemainDataSuccess")); _err_ == nil { if val.(bool) { - failpoint.Return(nil) + return nil } - }) + } // Get sessionctx from ddl context resource pool in ddl worker. var sctx sessionctx.Context sctx, err := w.sessPool.Get() diff --git a/pkg/ddl/constraint.go__failpoint_stash__ b/pkg/ddl/constraint.go__failpoint_stash__ new file mode 100644 index 0000000000000..572c66439889a --- /dev/null +++ b/pkg/ddl/constraint.go__failpoint_stash__ @@ -0,0 +1,432 @@ +// Copyright 2023-2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl + +import ( + "fmt" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/format" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/util/dbterror" +) + +func (w *worker) onAddCheckConstraint(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + // Handle the rolling back job. 
+ if job.IsRollingback() { + return rollingBackAddConstraint(d, t, job) + } + + failpoint.Inject("errorBeforeDecodeArgs", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(ver, errors.New("occur an error before decode args")) + } + }) + + dbInfo, tblInfo, constraintInfoInMeta, constraintInfoInJob, err := checkAddCheckConstraint(t, job) + if err != nil { + return ver, errors.Trace(err) + } + if constraintInfoInMeta == nil { + // It's first time to run add constraint job, so there is no constraint info in meta. + // Use the raw constraint info from job directly and modify table info here. + constraintInfoInJob.ID = allocateConstraintID(tblInfo) + // Reset constraint name according to real-time constraints name at this point. + constrNames := map[string]bool{} + for _, constr := range tblInfo.Constraints { + constrNames[constr.Name.L] = true + } + setNameForConstraintInfo(tblInfo.Name.L, constrNames, []*model.ConstraintInfo{constraintInfoInJob}) + // Double check the constraint dependency. + existedColsMap := make(map[string]struct{}) + cols := tblInfo.Columns + for _, v := range cols { + if v.State == model.StatePublic { + existedColsMap[v.Name.L] = struct{}{} + } + } + dependedCols := constraintInfoInJob.ConstraintCols + for _, k := range dependedCols { + if _, ok := existedColsMap[k.L]; !ok { + // The table constraint depended on a non-existed column. + return ver, dbterror.ErrTableCheckConstraintReferUnknown.GenWithStackByArgs(constraintInfoInJob.Name, k) + } + } + + tblInfo.Constraints = append(tblInfo.Constraints, constraintInfoInJob) + constraintInfoInMeta = constraintInfoInJob + } + + // If not enforced, add it directly. + if !constraintInfoInMeta.Enforced { + constraintInfoInMeta.State = model.StatePublic + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + // Finish this job. + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + return ver, nil + } + + switch constraintInfoInMeta.State { + case model.StateNone: + job.SchemaState = model.StateWriteOnly + constraintInfoInMeta.State = model.StateWriteOnly + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) + case model.StateWriteOnly: + job.SchemaState = model.StateWriteReorganization + constraintInfoInMeta.State = model.StateWriteReorganization + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) + case model.StateWriteReorganization: + err = w.verifyRemainRecordsForCheckConstraint(dbInfo, tblInfo, constraintInfoInMeta) + if err != nil { + if dbterror.ErrCheckConstraintIsViolated.Equal(err) { + job.State = model.JobStateRollingback + } + return ver, errors.Trace(err) + } + constraintInfoInMeta.State = model.StatePublic + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + // Finish this job. 
+ job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + default: + err = dbterror.ErrInvalidDDLState.GenWithStackByArgs("constraint", constraintInfoInMeta.State) + } + + return ver, errors.Trace(err) +} + +func checkAddCheckConstraint(t *meta.Meta, job *model.Job) (*model.DBInfo, *model.TableInfo, *model.ConstraintInfo, *model.ConstraintInfo, error) { + schemaID := job.SchemaID + dbInfo, err := t.GetDatabase(job.SchemaID) + if err != nil { + return nil, nil, nil, nil, errors.Trace(err) + } + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) + if err != nil { + return nil, nil, nil, nil, errors.Trace(err) + } + constraintInfo1 := &model.ConstraintInfo{} + err = job.DecodeArgs(constraintInfo1) + if err != nil { + job.State = model.JobStateCancelled + return nil, nil, nil, nil, errors.Trace(err) + } + // do the double-check with constraint existence. + constraintInfo2 := tblInfo.FindConstraintInfoByName(constraintInfo1.Name.L) + if constraintInfo2 != nil { + if constraintInfo2.State == model.StatePublic { + // We already have a constraint with the same constraint name. + job.State = model.JobStateCancelled + return nil, nil, nil, nil, infoschema.ErrColumnExists.GenWithStackByArgs(constraintInfo1.Name) + } + // if not, that means constraint was in intermediate state. + } + + err = checkConstraintNamesNotExists(t, schemaID, []*model.ConstraintInfo{constraintInfo1}) + if err != nil { + job.State = model.JobStateCancelled + return nil, nil, nil, nil, err + } + + return dbInfo, tblInfo, constraintInfo2, constraintInfo1, nil +} + +// onDropCheckConstraint can be called from two case: +// 1: rollback in add constraint.(in rollback function the job.args will be changed) +// 2: user drop constraint ddl. +func onDropCheckConstraint(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + tblInfo, constraintInfo, err := checkDropCheckConstraint(t, job) + if err != nil { + return ver, errors.Trace(err) + } + + switch constraintInfo.State { + case model.StatePublic: + job.SchemaState = model.StateWriteOnly + constraintInfo.State = model.StateWriteOnly + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) + case model.StateWriteOnly: + // write only state constraint will still take effect to check the newly inserted data. + // So the dependent column shouldn't be dropped even in this intermediate state. + constraintInfo.State = model.StateNone + // remove the constraint from tableInfo. + for i, constr := range tblInfo.Constraints { + if constr.Name.L == constraintInfo.Name.L { + tblInfo.Constraints = append(tblInfo.Constraints[0:i], tblInfo.Constraints[i+1:]...) + } + } + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + job.FinishTableJob(model.JobStateDone, model.StateNone, ver, tblInfo) + default: + err = dbterror.ErrInvalidDDLJob.GenWithStackByArgs("constraint", tblInfo.State) + } + return ver, errors.Trace(err) +} + +func checkDropCheckConstraint(t *meta.Meta, job *model.Job) (*model.TableInfo, *model.ConstraintInfo, error) { + schemaID := job.SchemaID + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) + if err != nil { + return nil, nil, errors.Trace(err) + } + + var constrName model.CIStr + err = job.DecodeArgs(&constrName) + if err != nil { + job.State = model.JobStateCancelled + return nil, nil, errors.Trace(err) + } + + // double check with constraint existence. 
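+ // A missing constraint means there is nothing to drop, so the job is cancelled with ErrConstraintNotFound.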
+ constraintInfo := tblInfo.FindConstraintInfoByName(constrName.L) + if constraintInfo == nil { + job.State = model.JobStateCancelled + return nil, nil, dbterror.ErrConstraintNotFound.GenWithStackByArgs(constrName) + } + return tblInfo, constraintInfo, nil +} + +func (w *worker) onAlterCheckConstraint(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + dbInfo, tblInfo, constraintInfo, enforced, err := checkAlterCheckConstraint(t, job) + if err != nil { + return ver, errors.Trace(err) + } + + if job.IsRollingback() { + return rollingBackAlterConstraint(d, t, job) + } + + // Current State is desired. + if constraintInfo.State == model.StatePublic && constraintInfo.Enforced == enforced { + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + return + } + + // enforced will fetch table data and check the constraint. + if enforced { + switch constraintInfo.State { + case model.StatePublic: + job.SchemaState = model.StateWriteReorganization + constraintInfo.State = model.StateWriteReorganization + constraintInfo.Enforced = enforced + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) + case model.StateWriteReorganization: + job.SchemaState = model.StateWriteOnly + constraintInfo.State = model.StateWriteOnly + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) + case model.StateWriteOnly: + err = w.verifyRemainRecordsForCheckConstraint(dbInfo, tblInfo, constraintInfo) + if err != nil { + if dbterror.ErrCheckConstraintIsViolated.Equal(err) { + job.State = model.JobStateRollingback + } + return ver, errors.Trace(err) + } + constraintInfo.State = model.StatePublic + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + } + } else { + constraintInfo.Enforced = enforced + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) + if err != nil { + // update version and tableInfo error will cause retry. + return ver, errors.Trace(err) + } + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + } + return ver, err +} + +func checkAlterCheckConstraint(t *meta.Meta, job *model.Job) (*model.DBInfo, *model.TableInfo, *model.ConstraintInfo, bool, error) { + schemaID := job.SchemaID + dbInfo, err := t.GetDatabase(job.SchemaID) + if err != nil { + return nil, nil, nil, false, errors.Trace(err) + } + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) + if err != nil { + return nil, nil, nil, false, errors.Trace(err) + } + + var ( + enforced bool + constrName model.CIStr + ) + err = job.DecodeArgs(&constrName, &enforced) + if err != nil { + job.State = model.JobStateCancelled + return nil, nil, nil, false, errors.Trace(err) + } + // do the double check with constraint existence. 
+ constraintInfo := tblInfo.FindConstraintInfoByName(constrName.L) + if constraintInfo == nil { + job.State = model.JobStateCancelled + return nil, nil, nil, false, dbterror.ErrConstraintNotFound.GenWithStackByArgs(constrName) + } + return dbInfo, tblInfo, constraintInfo, enforced, nil +} + +func allocateConstraintID(tblInfo *model.TableInfo) int64 { + tblInfo.MaxConstraintID++ + return tblInfo.MaxConstraintID +} + +func buildConstraintInfo(tblInfo *model.TableInfo, dependedCols []model.CIStr, constr *ast.Constraint, state model.SchemaState) (*model.ConstraintInfo, error) { + constraintName := model.NewCIStr(constr.Name) + if err := checkTooLongConstraint(constraintName); err != nil { + return nil, errors.Trace(err) + } + + // Restore check constraint expression to string. + var sb strings.Builder + restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | + format.RestoreSpacesAroundBinaryOperation | format.RestoreWithoutSchemaName | format.RestoreWithoutTableName + restoreCtx := format.NewRestoreCtx(restoreFlags, &sb) + + sb.Reset() + err := constr.Expr.Restore(restoreCtx) + if err != nil { + return nil, errors.Trace(err) + } + + // Create constraint info. + constraintInfo := &model.ConstraintInfo{ + Name: constraintName, + Table: tblInfo.Name, + ConstraintCols: dependedCols, + ExprString: sb.String(), + Enforced: constr.Enforced, + InColumn: constr.InColumn, + State: state, + } + + return constraintInfo, nil +} + +func checkTooLongConstraint(constr model.CIStr) error { + if len(constr.L) > mysql.MaxConstraintIdentifierLen { + return dbterror.ErrTooLongIdent.GenWithStackByArgs(constr) + } + return nil +} + +// findDependentColsInExpr returns a set of string, which indicates +// the names of the columns that are dependent by exprNode. +func findDependentColsInExpr(expr ast.ExprNode) map[string]struct{} { + colNames := FindColumnNamesInExpr(expr) + colsMap := make(map[string]struct{}, len(colNames)) + for _, depCol := range colNames { + colsMap[depCol.Name.L] = struct{}{} + } + return colsMap +} + +func (w *worker) verifyRemainRecordsForCheckConstraint(dbInfo *model.DBInfo, tableInfo *model.TableInfo, constr *model.ConstraintInfo) error { + // Inject a fail-point to skip the remaining records check. + failpoint.Inject("mockVerifyRemainDataSuccess", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(nil) + } + }) + // Get sessionctx from ddl context resource pool in ddl worker. + var sctx sessionctx.Context + sctx, err := w.sessPool.Get() + if err != nil { + return errors.Trace(err) + } + defer w.sessPool.Put(sctx) + + // If there is any row can't pass the check expression, the add constraint action will error. + // It's no need to construct expression node out and pull the chunk rows through it. Here we + // can let the check expression restored string as the filter in where clause directly. + // Prepare internal SQL to fetch data from physical table under this filter. 
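+ // A row matching NOT <constraint expression> violates the constraint, so finding a single such row is enough to fail the DDL.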
+ sql := fmt.Sprintf("select 1 from `%s`.`%s` where not %s limit 1", dbInfo.Name.L, tableInfo.Name.L, constr.ExprString) + ctx := kv.WithInternalSourceType(w.ctx, kv.InternalTxnDDL) + rows, _, err := sctx.GetRestrictedSQLExecutor().ExecRestrictedSQL(ctx, nil, sql) + if err != nil { + return errors.Trace(err) + } + rowCount := len(rows) + if rowCount != 0 { + return dbterror.ErrCheckConstraintIsViolated.GenWithStackByArgs(constr.Name.L) + } + return nil +} + +func setNameForConstraintInfo(tableLowerName string, namesMap map[string]bool, infos []*model.ConstraintInfo) { + cnt := 1 + constraintPrefix := tableLowerName + "_chk_" + for _, constrInfo := range infos { + if constrInfo.Name.O == "" { + constrName := fmt.Sprintf("%s%d", constraintPrefix, cnt) + for { + // loop until find constrName that haven't been used. + if !namesMap[constrName] { + namesMap[constrName] = true + break + } + cnt++ + constrName = fmt.Sprintf("%s%d", constraintPrefix, cnt) + } + constrInfo.Name = model.NewCIStr(constrName) + } + } +} + +// IsColumnDroppableWithCheckConstraint check whether the column in check-constraint whose dependent col is more than 1 +func IsColumnDroppableWithCheckConstraint(col model.CIStr, tblInfo *model.TableInfo) error { + for _, cons := range tblInfo.Constraints { + if len(cons.ConstraintCols) > 1 { + for _, colName := range cons.ConstraintCols { + if colName.L == col.L { + return dbterror.ErrCantDropColWithCheckConstraint.GenWithStackByArgs(cons.Name, col) + } + } + } + } + return nil +} + +// IsColumnRenameableWithCheckConstraint check whether the column is referenced in check-constraint +func IsColumnRenameableWithCheckConstraint(col model.CIStr, tblInfo *model.TableInfo) error { + for _, cons := range tblInfo.Constraints { + for _, colName := range cons.ConstraintCols { + if colName.L == col.L { + return dbterror.ErrCantDropColWithCheckConstraint.GenWithStackByArgs(cons.Name, col) + } + } + } + return nil +} diff --git a/pkg/ddl/create_table.go b/pkg/ddl/create_table.go index 68d686ddc35dc..cc6ac85b7ddd7 100644 --- a/pkg/ddl/create_table.go +++ b/pkg/ddl/create_table.go @@ -96,11 +96,11 @@ func createTable(d *ddlCtx, t *meta.Meta, job *model.Job, fkCheck bool) (*model. return tbInfo, errors.Trace(err) } - failpoint.Inject("checkOwnerCheckAllVersionsWaitTime", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("checkOwnerCheckAllVersionsWaitTime")); _err_ == nil { if val.(bool) { - failpoint.Return(tbInfo, errors.New("mock create table error")) + return tbInfo, errors.New("mock create table error") } - }) + } // build table & partition bundles if any. if err = checkAllTablePlacementPoliciesExistAndCancelNonExistJob(t, job, tbInfo); err != nil { @@ -149,11 +149,11 @@ func createTable(d *ddlCtx, t *meta.Meta, job *model.Job, fkCheck bool) (*model. 
} func onCreateTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - failpoint.Inject("mockExceedErrorLimit", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockExceedErrorLimit")); _err_ == nil { if val.(bool) { - failpoint.Return(ver, errors.New("mock do job error")) + return ver, errors.New("mock do job error") } - }) + } // just decode, createTable will use it as Args[0] tbInfo := &model.TableInfo{} @@ -306,7 +306,7 @@ func onCreateView(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) if infoschema.ErrTableNotExists.Equal(err) { err = nil } - failpoint.InjectCall("onDDLCreateView", job) + failpoint.Call(_curpkg_("onDDLCreateView"), job) if err != nil { if infoschema.ErrDatabaseNotExists.Equal(err) { job.State = model.JobStateCancelled diff --git a/pkg/ddl/create_table.go__failpoint_stash__ b/pkg/ddl/create_table.go__failpoint_stash__ new file mode 100644 index 0000000000000..68d686ddc35dc --- /dev/null +++ b/pkg/ddl/create_table.go__failpoint_stash__ @@ -0,0 +1,1527 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl + +import ( + "context" + "fmt" + "math" + "strings" + "sync/atomic" + "unicode/utf8" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl/logutil" + "github.com/pingcap/tidb/pkg/ddl/placement" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/format" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + field_types "github.com/pingcap/tidb/pkg/parser/types" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + statsutil "github.com/pingcap/tidb/pkg/statistics/handle/util" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/types" + driver "github.com/pingcap/tidb/pkg/types/parser_driver" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/mock" + "github.com/pingcap/tidb/pkg/util/set" + "go.uber.org/zap" +) + +// DANGER: it is an internal function used by onCreateTable and onCreateTables, for reusing code. Be careful. +// 1. it expects the argument of job has been deserialized. +// 2. it won't call updateSchemaVersion, FinishTableJob and asyncNotifyEvent. 
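+// See onCreateTable and onCreateTables for how callers bump the schema version and finish the job afterwards.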
+func createTable(d *ddlCtx, t *meta.Meta, job *model.Job, fkCheck bool) (*model.TableInfo, error) { + schemaID := job.SchemaID + tbInfo := job.Args[0].(*model.TableInfo) + + tbInfo.State = model.StateNone + err := checkTableNotExists(d, schemaID, tbInfo.Name.L) + if err != nil { + if infoschema.ErrDatabaseNotExists.Equal(err) || infoschema.ErrTableExists.Equal(err) { + job.State = model.JobStateCancelled + } + return tbInfo, errors.Trace(err) + } + + err = checkConstraintNamesNotExists(t, schemaID, tbInfo.Constraints) + if err != nil { + if infoschema.ErrCheckConstraintDupName.Equal(err) { + job.State = model.JobStateCancelled + } + return tbInfo, errors.Trace(err) + } + + retryable, err := checkTableForeignKeyValidInOwner(d, t, job, tbInfo, fkCheck) + if err != nil { + if !retryable { + job.State = model.JobStateCancelled + } + return tbInfo, errors.Trace(err) + } + // Allocate foreign key ID. + for _, fkInfo := range tbInfo.ForeignKeys { + fkInfo.ID = allocateFKIndexID(tbInfo) + fkInfo.State = model.StatePublic + } + switch tbInfo.State { + case model.StateNone: + // none -> public + tbInfo.State = model.StatePublic + tbInfo.UpdateTS = t.StartTS + err = createTableOrViewWithCheck(t, job, schemaID, tbInfo) + if err != nil { + return tbInfo, errors.Trace(err) + } + + failpoint.Inject("checkOwnerCheckAllVersionsWaitTime", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(tbInfo, errors.New("mock create table error")) + } + }) + + // build table & partition bundles if any. + if err = checkAllTablePlacementPoliciesExistAndCancelNonExistJob(t, job, tbInfo); err != nil { + return tbInfo, errors.Trace(err) + } + + if tbInfo.TiFlashReplica != nil { + replicaInfo := tbInfo.TiFlashReplica + if pi := tbInfo.GetPartitionInfo(); pi != nil { + logutil.DDLLogger().Info("Set TiFlash replica pd rule for partitioned table when creating", zap.Int64("tableID", tbInfo.ID)) + if e := infosync.ConfigureTiFlashPDForPartitions(false, &pi.Definitions, replicaInfo.Count, &replicaInfo.LocationLabels, tbInfo.ID); e != nil { + job.State = model.JobStateCancelled + return tbInfo, errors.Trace(e) + } + // Partitions that in adding mid-state. They have high priorities, so we should set accordingly pd rules. + if e := infosync.ConfigureTiFlashPDForPartitions(true, &pi.AddingDefinitions, replicaInfo.Count, &replicaInfo.LocationLabels, tbInfo.ID); e != nil { + job.State = model.JobStateCancelled + return tbInfo, errors.Trace(e) + } + } else { + logutil.DDLLogger().Info("Set TiFlash replica pd rule when creating", zap.Int64("tableID", tbInfo.ID)) + if e := infosync.ConfigureTiFlashPDForTable(tbInfo.ID, replicaInfo.Count, &replicaInfo.LocationLabels); e != nil { + job.State = model.JobStateCancelled + return tbInfo, errors.Trace(e) + } + } + } + + bundles, err := placement.NewFullTableBundles(t, tbInfo) + if err != nil { + job.State = model.JobStateCancelled + return tbInfo, errors.Trace(err) + } + + // Send the placement bundle to PD. 
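+ // If PD cannot be notified, the job is cancelled rather than leaving a table without its placement rules.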
+ err = infosync.PutRuleBundlesWithDefaultRetry(context.TODO(), bundles) + if err != nil { + job.State = model.JobStateCancelled + return tbInfo, errors.Wrapf(err, "failed to notify PD the placement rules") + } + + return tbInfo, nil + default: + return tbInfo, dbterror.ErrInvalidDDLState.GenWithStackByArgs("table", tbInfo.State) + } +} + +func onCreateTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + failpoint.Inject("mockExceedErrorLimit", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(ver, errors.New("mock do job error")) + } + }) + + // just decode, createTable will use it as Args[0] + tbInfo := &model.TableInfo{} + fkCheck := false + if err := job.DecodeArgs(tbInfo, &fkCheck); err != nil { + // Invalid arguments, cancel this job. + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + if len(tbInfo.ForeignKeys) > 0 { + return createTableWithForeignKeys(d, t, job, tbInfo, fkCheck) + } + + tbInfo, err := createTable(d, t, job, fkCheck) + if err != nil { + return ver, errors.Trace(err) + } + + ver, err = updateSchemaVersion(d, t, job) + if err != nil { + return ver, errors.Trace(err) + } + + // Finish this job. + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tbInfo) + createTableEvent := statsutil.NewCreateTableEvent( + job.SchemaID, + tbInfo, + ) + asyncNotifyEvent(d, createTableEvent) + return ver, errors.Trace(err) +} + +func createTableWithForeignKeys(d *ddlCtx, t *meta.Meta, job *model.Job, tbInfo *model.TableInfo, fkCheck bool) (ver int64, err error) { + switch tbInfo.State { + case model.StateNone, model.StatePublic: + // create table in non-public or public state. The function `createTable` will always reset + // the `tbInfo.State` with `model.StateNone`, so it's fine to just call the `createTable` with + // public state. + // when `br` restores table, the state of `tbInfo` will be public. + tbInfo, err = createTable(d, t, job, fkCheck) + if err != nil { + return ver, errors.Trace(err) + } + tbInfo.State = model.StateWriteOnly + ver, err = updateVersionAndTableInfo(d, t, job, tbInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + job.SchemaState = model.StateWriteOnly + case model.StateWriteOnly: + tbInfo.State = model.StatePublic + ver, err = updateVersionAndTableInfo(d, t, job, tbInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tbInfo) + createTableEvent := statsutil.NewCreateTableEvent( + job.SchemaID, + tbInfo, + ) + asyncNotifyEvent(d, createTableEvent) + return ver, nil + default: + return ver, errors.Trace(dbterror.ErrInvalidDDLJob.GenWithStackByArgs("table", tbInfo.State)) + } + return ver, errors.Trace(err) +} + +func onCreateTables(d *ddlCtx, t *meta.Meta, job *model.Job) (int64, error) { + var ver int64 + + var args []*model.TableInfo + fkCheck := false + err := job.DecodeArgs(&args, &fkCheck) + if err != nil { + // Invalid arguments, cancel this job. 
+ job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + // We don't construct jobs for every table, but only tableInfo + // The following loop creates a stub job for every table + // + // it clones a stub job from the ActionCreateTables job + stubJob := job.Clone() + stubJob.Args = make([]any, 1) + for i := range args { + stubJob.TableID = args[i].ID + stubJob.Args[0] = args[i] + if args[i].Sequence != nil { + err := createSequenceWithCheck(t, stubJob, args[i]) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + } else { + tbInfo, err := createTable(d, t, stubJob, fkCheck) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + args[i] = tbInfo + } + } + + ver, err = updateSchemaVersion(d, t, job) + if err != nil { + return ver, errors.Trace(err) + } + + job.State = model.JobStateDone + job.SchemaState = model.StatePublic + job.BinlogInfo.SetTableInfos(ver, args) + + for i := range args { + createTableEvent := statsutil.NewCreateTableEvent( + job.SchemaID, + args[i], + ) + asyncNotifyEvent(d, createTableEvent) + } + + return ver, errors.Trace(err) +} + +func createTableOrViewWithCheck(t *meta.Meta, job *model.Job, schemaID int64, tbInfo *model.TableInfo) error { + err := checkTableInfoValid(tbInfo) + if err != nil { + job.State = model.JobStateCancelled + return errors.Trace(err) + } + return t.CreateTableOrView(schemaID, tbInfo) +} + +func onCreateView(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + schemaID := job.SchemaID + tbInfo := &model.TableInfo{} + var orReplace bool + var _placeholder int64 // oldTblInfoID + if err := job.DecodeArgs(tbInfo, &orReplace, &_placeholder); err != nil { + // Invalid arguments, cancel this job. + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + tbInfo.State = model.StateNone + + oldTableID, err := findTableIDByName(d, t, schemaID, tbInfo.Name.L) + if infoschema.ErrTableNotExists.Equal(err) { + err = nil + } + failpoint.InjectCall("onDDLCreateView", job) + if err != nil { + if infoschema.ErrDatabaseNotExists.Equal(err) { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } else if !infoschema.ErrTableExists.Equal(err) { + return ver, errors.Trace(err) + } + if !orReplace { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + } + ver, err = updateSchemaVersion(d, t, job) + if err != nil { + return ver, errors.Trace(err) + } + switch tbInfo.State { + case model.StateNone: + // none -> public + tbInfo.State = model.StatePublic + tbInfo.UpdateTS = t.StartTS + if oldTableID > 0 && orReplace { + err = t.DropTableOrView(schemaID, oldTableID) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + err = t.GetAutoIDAccessors(schemaID, oldTableID).Del() + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + } + err = createTableOrViewWithCheck(t, job, schemaID, tbInfo) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + // Finish this job. + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tbInfo) + return ver, nil + default: + return ver, dbterror.ErrInvalidDDLState.GenWithStackByArgs("table", tbInfo.State) + } +} + +func findTableIDByName(d *ddlCtx, t *meta.Meta, schemaID int64, tableName string) (int64, error) { + // Try to use memory schema info to check first. 
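+ // Use the cached infoschema only when its version matches the current schema version; otherwise fall back to the store.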
+ currVer, err := t.GetSchemaVersion() + if err != nil { + return 0, err + } + is := d.infoCache.GetLatest() + if is != nil && is.SchemaMetaVersion() == currVer { + return findTableIDFromInfoSchema(is, schemaID, tableName) + } + + return findTableIDFromStore(t, schemaID, tableName) +} + +func findTableIDFromInfoSchema(is infoschema.InfoSchema, schemaID int64, tableName string) (int64, error) { + schema, ok := is.SchemaByID(schemaID) + if !ok { + return 0, infoschema.ErrDatabaseNotExists.GenWithStackByArgs("") + } + tbl, err := is.TableByName(context.Background(), schema.Name, model.NewCIStr(tableName)) + if err != nil { + return 0, err + } + return tbl.Meta().ID, nil +} + +func findTableIDFromStore(t *meta.Meta, schemaID int64, tableName string) (int64, error) { + tbls, err := t.ListSimpleTables(schemaID) + if err != nil { + if meta.ErrDBNotExists.Equal(err) { + return 0, infoschema.ErrDatabaseNotExists.GenWithStackByArgs("") + } + return 0, errors.Trace(err) + } + for _, tbl := range tbls { + if tbl.Name.L == tableName { + return tbl.ID, nil + } + } + return 0, infoschema.ErrTableNotExists.FastGenByArgs(tableName) +} + +// BuildTableInfoFromAST builds model.TableInfo from a SQL statement. +// Note: TableID and PartitionID are left as uninitialized value. +func BuildTableInfoFromAST(s *ast.CreateTableStmt) (*model.TableInfo, error) { + return buildTableInfoWithCheck(mock.NewContext(), s, mysql.DefaultCharset, "", nil) +} + +// buildTableInfoWithCheck builds model.TableInfo from a SQL statement. +// Note: TableID and PartitionIDs are left as uninitialized value. +func buildTableInfoWithCheck(ctx sessionctx.Context, s *ast.CreateTableStmt, dbCharset, dbCollate string, placementPolicyRef *model.PolicyRefInfo) (*model.TableInfo, error) { + tbInfo, err := BuildTableInfoWithStmt(ctx, s, dbCharset, dbCollate, placementPolicyRef) + if err != nil { + return nil, err + } + // Fix issue 17952 which will cause partition range expr can't be parsed as Int. + // checkTableInfoValidWithStmt will do the constant fold the partition expression first, + // then checkTableInfoValidExtra will pass the tableInfo check successfully. + if err = checkTableInfoValidWithStmt(ctx, tbInfo, s); err != nil { + return nil, err + } + if err = checkTableInfoValidExtra(tbInfo); err != nil { + return nil, err + } + return tbInfo, nil +} + +// CheckTableInfoValidWithStmt exposes checkTableInfoValidWithStmt to SchemaTracker. Maybe one day we can delete it. +func CheckTableInfoValidWithStmt(ctx sessionctx.Context, tbInfo *model.TableInfo, s *ast.CreateTableStmt) (err error) { + return checkTableInfoValidWithStmt(ctx, tbInfo, s) +} + +func checkTableInfoValidWithStmt(ctx sessionctx.Context, tbInfo *model.TableInfo, s *ast.CreateTableStmt) (err error) { + // All of these rely on the AST structure of expressions, which were + // lost in the model (got serialized into strings). + if err := checkGeneratedColumn(ctx, s.Table.Schema, tbInfo.Name, s.Cols); err != nil { + return errors.Trace(err) + } + + // Check if table has a primary key if required. 
+ if !ctx.GetSessionVars().InRestrictedSQL && ctx.GetSessionVars().PrimaryKeyRequired && len(tbInfo.GetPkName().String()) == 0 { + return infoschema.ErrTableWithoutPrimaryKey + } + if tbInfo.Partition != nil { + if err := checkPartitionDefinitionConstraints(ctx, tbInfo); err != nil { + return errors.Trace(err) + } + if s.Partition != nil { + if err := checkPartitionFuncType(ctx, s.Partition.Expr, s.Table.Schema.O, tbInfo); err != nil { + return errors.Trace(err) + } + if err := checkPartitioningKeysConstraints(ctx, s, tbInfo); err != nil { + return errors.Trace(err) + } + } + } + if tbInfo.TTLInfo != nil { + if err := checkTTLInfoValid(ctx, s.Table.Schema, tbInfo); err != nil { + return errors.Trace(err) + } + } + + return nil +} + +func checkGeneratedColumn(ctx sessionctx.Context, schemaName model.CIStr, tableName model.CIStr, colDefs []*ast.ColumnDef) error { + var colName2Generation = make(map[string]columnGenerationInDDL, len(colDefs)) + var exists bool + var autoIncrementColumn string + for i, colDef := range colDefs { + for _, option := range colDef.Options { + if option.Tp == ast.ColumnOptionGenerated { + if err := checkIllegalFn4Generated(colDef.Name.Name.L, typeColumn, option.Expr); err != nil { + return errors.Trace(err) + } + } + } + if containsColumnOption(colDef, ast.ColumnOptionAutoIncrement) { + exists, autoIncrementColumn = true, colDef.Name.Name.L + } + generated, depCols, err := findDependedColumnNames(schemaName, tableName, colDef) + if err != nil { + return errors.Trace(err) + } + if !generated { + colName2Generation[colDef.Name.Name.L] = columnGenerationInDDL{ + position: i, + generated: false, + } + } else { + colName2Generation[colDef.Name.Name.L] = columnGenerationInDDL{ + position: i, + generated: true, + dependences: depCols, + } + } + } + + // Check whether the generated column refers to any auto-increment columns + if exists { + if !ctx.GetSessionVars().EnableAutoIncrementInGenerated { + for colName, generated := range colName2Generation { + if _, found := generated.dependences[autoIncrementColumn]; found { + return dbterror.ErrGeneratedColumnRefAutoInc.GenWithStackByArgs(colName) + } + } + } + } + + for _, colDef := range colDefs { + colName := colDef.Name.Name.L + if err := verifyColumnGeneration(colName2Generation, colName); err != nil { + return errors.Trace(err) + } + } + return nil +} + +// checkTableInfoValidExtra is like checkTableInfoValid, but also assumes the +// table info comes from untrusted source and performs further checks such as +// name length and column count. +// (checkTableInfoValid is also used in repairing objects which don't perform +// these checks. Perhaps the two functions should be merged together regardless?) 
+func checkTableInfoValidExtra(tbInfo *model.TableInfo) error { + if err := checkTooLongTable(tbInfo.Name); err != nil { + return err + } + + if err := checkDuplicateColumn(tbInfo.Columns); err != nil { + return err + } + if err := checkTooLongColumns(tbInfo.Columns); err != nil { + return err + } + if err := checkTooManyColumns(tbInfo.Columns); err != nil { + return errors.Trace(err) + } + if err := checkTooManyIndexes(tbInfo.Indices); err != nil { + return errors.Trace(err) + } + if err := checkColumnsAttributes(tbInfo.Columns); err != nil { + return errors.Trace(err) + } + + // FIXME: perform checkConstraintNames + if err := checkCharsetAndCollation(tbInfo.Charset, tbInfo.Collate); err != nil { + return errors.Trace(err) + } + + oldState := tbInfo.State + tbInfo.State = model.StatePublic + err := checkTableInfoValid(tbInfo) + tbInfo.State = oldState + return err +} + +// checkTableInfoValid uses to check table info valid. This is used to validate table info. +func checkTableInfoValid(tblInfo *model.TableInfo) error { + _, err := tables.TableFromMeta(autoid.NewAllocators(false), tblInfo) + if err != nil { + return err + } + return checkInvisibleIndexOnPK(tblInfo) +} + +func checkDuplicateColumn(cols []*model.ColumnInfo) error { + colNames := set.StringSet{} + for _, col := range cols { + colName := col.Name + if colNames.Exist(colName.L) { + return infoschema.ErrColumnExists.GenWithStackByArgs(colName.O) + } + colNames.Insert(colName.L) + } + return nil +} + +func checkTooLongColumns(cols []*model.ColumnInfo) error { + for _, col := range cols { + if err := checkTooLongColumn(col.Name); err != nil { + return err + } + } + return nil +} + +func checkTooManyColumns(colDefs []*model.ColumnInfo) error { + if uint32(len(colDefs)) > atomic.LoadUint32(&config.GetGlobalConfig().TableColumnCountLimit) { + return dbterror.ErrTooManyFields + } + return nil +} + +func checkTooManyIndexes(idxDefs []*model.IndexInfo) error { + if len(idxDefs) > config.GetGlobalConfig().IndexLimit { + return dbterror.ErrTooManyKeys.GenWithStackByArgs(config.GetGlobalConfig().IndexLimit) + } + return nil +} + +// checkColumnsAttributes checks attributes for multiple columns. +func checkColumnsAttributes(colDefs []*model.ColumnInfo) error { + for _, colDef := range colDefs { + if err := checkColumnAttributes(colDef.Name.O, &colDef.FieldType); err != nil { + return errors.Trace(err) + } + } + return nil +} + +// checkColumnAttributes check attributes for single column. +func checkColumnAttributes(colName string, tp *types.FieldType) error { + switch tp.GetType() { + case mysql.TypeNewDecimal, mysql.TypeDouble, mysql.TypeFloat: + if tp.GetFlen() < tp.GetDecimal() { + return types.ErrMBiggerThanD.GenWithStackByArgs(colName) + } + case mysql.TypeDatetime, mysql.TypeDuration, mysql.TypeTimestamp: + if tp.GetDecimal() != types.UnspecifiedFsp && (tp.GetDecimal() < types.MinFsp || tp.GetDecimal() > types.MaxFsp) { + return types.ErrTooBigPrecision.GenWithStackByArgs(tp.GetDecimal(), colName, types.MaxFsp) + } + } + return nil +} + +// BuildSessionTemporaryTableInfo builds model.TableInfo from a SQL statement. 
+func BuildSessionTemporaryTableInfo(ctx sessionctx.Context, is infoschema.InfoSchema, s *ast.CreateTableStmt, dbCharset, dbCollate string, placementPolicyRef *model.PolicyRefInfo) (*model.TableInfo, error) { + ident := ast.Ident{Schema: s.Table.Schema, Name: s.Table.Name} + //build tableInfo + var tbInfo *model.TableInfo + var referTbl table.Table + var err error + if s.ReferTable != nil { + referIdent := ast.Ident{Schema: s.ReferTable.Schema, Name: s.ReferTable.Name} + _, ok := is.SchemaByName(referIdent.Schema) + if !ok { + return nil, infoschema.ErrTableNotExists.GenWithStackByArgs(referIdent.Schema, referIdent.Name) + } + referTbl, err = is.TableByName(context.Background(), referIdent.Schema, referIdent.Name) + if err != nil { + return nil, infoschema.ErrTableNotExists.GenWithStackByArgs(referIdent.Schema, referIdent.Name) + } + tbInfo, err = BuildTableInfoWithLike(ctx, ident, referTbl.Meta(), s) + } else { + tbInfo, err = buildTableInfoWithCheck(ctx, s, dbCharset, dbCollate, placementPolicyRef) + } + return tbInfo, err +} + +// BuildTableInfoWithStmt builds model.TableInfo from a SQL statement without validity check +func BuildTableInfoWithStmt(ctx sessionctx.Context, s *ast.CreateTableStmt, dbCharset, dbCollate string, placementPolicyRef *model.PolicyRefInfo) (*model.TableInfo, error) { + colDefs := s.Cols + tableCharset, tableCollate, err := GetCharsetAndCollateInTableOption(ctx.GetSessionVars(), 0, s.Options) + if err != nil { + return nil, errors.Trace(err) + } + tableCharset, tableCollate, err = ResolveCharsetCollation(ctx.GetSessionVars(), + ast.CharsetOpt{Chs: tableCharset, Col: tableCollate}, + ast.CharsetOpt{Chs: dbCharset, Col: dbCollate}, + ) + if err != nil { + return nil, errors.Trace(err) + } + + // The column charset haven't been resolved here. + cols, newConstraints, err := buildColumnsAndConstraints(ctx, colDefs, s.Constraints, tableCharset, tableCollate) + if err != nil { + return nil, errors.Trace(err) + } + err = checkConstraintNames(s.Table.Name, newConstraints) + if err != nil { + return nil, errors.Trace(err) + } + + var tbInfo *model.TableInfo + tbInfo, err = BuildTableInfo(ctx, s.Table.Name, cols, newConstraints, tableCharset, tableCollate) + if err != nil { + return nil, errors.Trace(err) + } + if err = setTemporaryType(ctx, tbInfo, s); err != nil { + return nil, errors.Trace(err) + } + + if err = setTableAutoRandomBits(ctx, tbInfo, colDefs); err != nil { + return nil, errors.Trace(err) + } + + if err = handleTableOptions(s.Options, tbInfo); err != nil { + return nil, errors.Trace(err) + } + + sessionVars := ctx.GetSessionVars() + if _, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, tbInfo.Name.L, &tbInfo.Comment, dbterror.ErrTooLongTableComment); err != nil { + return nil, errors.Trace(err) + } + + if tbInfo.TempTableType == model.TempTableNone && tbInfo.PlacementPolicyRef == nil && placementPolicyRef != nil { + // Set the defaults from Schema. Note: they are mutual exclusive! 
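The charset/collation handling in BuildTableInfoWithStmt is a "first explicit setting wins" chain: options on the CREATE TABLE statement are consulted before the database defaults. A rough standalone sketch of that precedence rule (the real ResolveCharsetCollation also derives a default collation for a bare charset and validates the pair; everything below is simplified and the names are hypothetical):

package ddlsketch

// charsetOpt mirrors the idea of ast.CharsetOpt: an optional charset/collation pair.
type charsetOpt struct {
	Chs string
	Col string
}

// resolveCharsetCollation keeps the first explicitly set charset and collation,
// scanning from the most specific option (table) to the least specific (database).
func resolveCharsetCollation(opts ...charsetOpt) (chs, col string) {
	for _, o := range opts {
		if chs == "" {
			chs = o.Chs
		}
		if col == "" {
			col = o.Col
		}
	}
	return chs, col
}

Called as resolveCharsetCollation(tableOpt, dbOpt), an empty table-level setting falls through to the database default, matching the order of the two ast.CharsetOpt arguments above.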
+ tbInfo.PlacementPolicyRef = placementPolicyRef + } + + // After handleTableOptions, so the partitions can get defaults from Table level + err = buildTablePartitionInfo(ctx, s.Partition, tbInfo) + if err != nil { + return nil, errors.Trace(err) + } + + return tbInfo, nil +} + +func setTableAutoRandomBits(ctx sessionctx.Context, tbInfo *model.TableInfo, colDefs []*ast.ColumnDef) error { + for _, col := range colDefs { + if containsColumnOption(col, ast.ColumnOptionAutoRandom) { + if col.Tp.GetType() != mysql.TypeLonglong { + return dbterror.ErrInvalidAutoRandom.GenWithStackByArgs( + fmt.Sprintf(autoid.AutoRandomOnNonBigIntColumn, types.TypeStr(col.Tp.GetType()))) + } + switch { + case tbInfo.PKIsHandle: + if tbInfo.GetPkName().L != col.Name.Name.L { + errMsg := fmt.Sprintf(autoid.AutoRandomMustFirstColumnInPK, col.Name.Name.O) + return dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(errMsg) + } + case tbInfo.IsCommonHandle: + pk := tables.FindPrimaryIndex(tbInfo) + if pk == nil { + return dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomNoClusteredPKErrMsg) + } + if col.Name.Name.L != pk.Columns[0].Name.L { + errMsg := fmt.Sprintf(autoid.AutoRandomMustFirstColumnInPK, col.Name.Name.O) + return dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(errMsg) + } + default: + return dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomNoClusteredPKErrMsg) + } + + if containsColumnOption(col, ast.ColumnOptionAutoIncrement) { + return dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomIncompatibleWithAutoIncErrMsg) + } + if containsColumnOption(col, ast.ColumnOptionDefaultValue) { + return dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomIncompatibleWithDefaultValueErrMsg) + } + + shardBits, rangeBits, err := extractAutoRandomBitsFromColDef(col) + if err != nil { + return errors.Trace(err) + } + tbInfo.AutoRandomBits = shardBits + tbInfo.AutoRandomRangeBits = rangeBits + + shardFmt := autoid.NewShardIDFormat(col.Tp, shardBits, rangeBits) + if shardFmt.IncrementalBits < autoid.AutoRandomIncBitsMin { + return dbterror.ErrInvalidAutoRandom.FastGenByArgs(autoid.AutoRandomIncrementalBitsTooSmall) + } + msg := fmt.Sprintf(autoid.AutoRandomAvailableAllocTimesNote, shardFmt.IncrementalBitsCapacity()) + ctx.GetSessionVars().StmtCtx.AppendNote(errors.NewNoStackError(msg)) + } + } + return nil +} + +func containsColumnOption(colDef *ast.ColumnDef, opTp ast.ColumnOptionType) bool { + for _, option := range colDef.Options { + if option.Tp == opTp { + return true + } + } + return false +} + +func extractAutoRandomBitsFromColDef(colDef *ast.ColumnDef) (shardBits, rangeBits uint64, err error) { + for _, op := range colDef.Options { + if op.Tp == ast.ColumnOptionAutoRandom { + shardBits, err = autoid.AutoRandomShardBitsNormalize(op.AutoRandOpt.ShardBits, colDef.Name.Name.O) + if err != nil { + return 0, 0, err + } + rangeBits, err = autoid.AutoRandomRangeBitsNormalize(op.AutoRandOpt.RangeBits) + if err != nil { + return 0, 0, err + } + return shardBits, rangeBits, nil + } + } + return 0, 0, nil +} + +// handleTableOptions updates tableInfo according to table options. +func handleTableOptions(options []*ast.TableOption, tbInfo *model.TableInfo) error { + var ttlOptionsHandled bool + + for _, op := range options { + switch op.Tp { + case ast.TableOptionAutoIncrement: + tbInfo.AutoIncID = int64(op.UintValue) + case ast.TableOptionAutoIdCache: + if op.UintValue > uint64(math.MaxInt64) { + // TODO: Refine this error. 
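The AUTO_RANDOM validation above comes down to a bit budget on a BIGINT key: the range bits bound the usable width, the shard bits are carved out of the top, and whatever remains must stay large enough for the incremental part. A back-of-the-envelope sketch of that arithmetic (the real layout is produced by autoid.NewShardIDFormat, which also accounts for the sign bit of signed columns; the minimum below is an assumed constant for illustration):

package ddlsketch

import "errors"

// minIncrementalBits is an assumed lower bound used only for illustration;
// the real limit lives in the autoid package.
const minIncrementalBits = 27

// incrementalCapacity estimates how many sequential IDs fit after reserving
// shardBits of a rangeBits-wide key (rangeBits is at most 64).
func incrementalCapacity(shardBits, rangeBits uint64) (uint64, error) {
	if shardBits == 0 || shardBits >= rangeBits {
		return 0, errors.New("shard bits leave no room for the incremental part")
	}
	incBits := rangeBits - shardBits
	if incBits < minIncrementalBits {
		return 0, errors.New("auto_random incremental bits too small")
	}
	return uint64(1) << incBits, nil
}

With 5 shard bits in a 64-bit range this sketch leaves 59 incremental bits; the real format gives up one more bit for the sign on a signed column before reporting the capacity in the NOTE above.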
+ return errors.New("table option auto_id_cache overflows int64") + } + tbInfo.AutoIdCache = int64(op.UintValue) + case ast.TableOptionAutoRandomBase: + tbInfo.AutoRandID = int64(op.UintValue) + case ast.TableOptionComment: + tbInfo.Comment = op.StrValue + case ast.TableOptionCompression: + tbInfo.Compression = op.StrValue + case ast.TableOptionShardRowID: + if op.UintValue > 0 && tbInfo.HasClusteredIndex() { + return dbterror.ErrUnsupportedShardRowIDBits + } + tbInfo.ShardRowIDBits = op.UintValue + if tbInfo.ShardRowIDBits > shardRowIDBitsMax { + tbInfo.ShardRowIDBits = shardRowIDBitsMax + } + tbInfo.MaxShardRowIDBits = tbInfo.ShardRowIDBits + case ast.TableOptionPreSplitRegion: + if tbInfo.TempTableType != model.TempTableNone { + return errors.Trace(dbterror.ErrOptOnTemporaryTable.GenWithStackByArgs("pre split regions")) + } + tbInfo.PreSplitRegions = op.UintValue + case ast.TableOptionCharset, ast.TableOptionCollate: + // We don't handle charset and collate here since they're handled in `GetCharsetAndCollateInTableOption`. + case ast.TableOptionPlacementPolicy: + tbInfo.PlacementPolicyRef = &model.PolicyRefInfo{ + Name: model.NewCIStr(op.StrValue), + } + case ast.TableOptionTTL, ast.TableOptionTTLEnable, ast.TableOptionTTLJobInterval: + if ttlOptionsHandled { + continue + } + + ttlInfo, ttlEnable, ttlJobInterval, err := getTTLInfoInOptions(options) + if err != nil { + return err + } + // It's impossible that `ttlInfo` and `ttlEnable` are all nil, because we have met this option. + // After exclude the situation `ttlInfo == nil && ttlEnable != nil`, we could say `ttlInfo != nil` + if ttlInfo == nil { + if ttlEnable != nil { + return errors.Trace(dbterror.ErrSetTTLOptionForNonTTLTable.FastGenByArgs("TTL_ENABLE")) + } + if ttlJobInterval != nil { + return errors.Trace(dbterror.ErrSetTTLOptionForNonTTLTable.FastGenByArgs("TTL_JOB_INTERVAL")) + } + } + + tbInfo.TTLInfo = ttlInfo + ttlOptionsHandled = true + } + } + shardingBits := shardingBits(tbInfo) + if tbInfo.PreSplitRegions > shardingBits { + tbInfo.PreSplitRegions = shardingBits + } + return nil +} + +func setTemporaryType(_ sessionctx.Context, tbInfo *model.TableInfo, s *ast.CreateTableStmt) error { + switch s.TemporaryKeyword { + case ast.TemporaryGlobal: + tbInfo.TempTableType = model.TempTableGlobal + // "create global temporary table ... on commit preserve rows" + if !s.OnCommitDelete { + return errors.Trace(dbterror.ErrUnsupportedOnCommitPreserve) + } + case ast.TemporaryLocal: + tbInfo.TempTableType = model.TempTableLocal + default: + tbInfo.TempTableType = model.TempTableNone + } + return nil +} + +func buildColumnsAndConstraints( + ctx sessionctx.Context, + colDefs []*ast.ColumnDef, + constraints []*ast.Constraint, + tblCharset string, + tblCollate string, +) ([]*table.Column, []*ast.Constraint, error) { + // outPriKeyConstraint is the primary key constraint out of column definition. 
such as: create table t1 (id int , age int, primary key(id)); + var outPriKeyConstraint *ast.Constraint + for _, v := range constraints { + if v.Tp == ast.ConstraintPrimaryKey { + outPriKeyConstraint = v + break + } + } + cols := make([]*table.Column, 0, len(colDefs)) + colMap := make(map[string]*table.Column, len(colDefs)) + + for i, colDef := range colDefs { + if field_types.TiDBStrictIntegerDisplayWidth { + switch colDef.Tp.GetType() { + case mysql.TypeTiny: + // No warning for BOOL-like tinyint(1) + if colDef.Tp.GetFlen() != types.UnspecifiedLength && colDef.Tp.GetFlen() != 1 { + ctx.GetSessionVars().StmtCtx.AppendWarning( + dbterror.ErrWarnDeprecatedIntegerDisplayWidth.FastGenByArgs(), + ) + } + case mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: + if colDef.Tp.GetFlen() != types.UnspecifiedLength { + ctx.GetSessionVars().StmtCtx.AppendWarning( + dbterror.ErrWarnDeprecatedIntegerDisplayWidth.FastGenByArgs(), + ) + } + } + } + col, cts, err := buildColumnAndConstraint(ctx, i, colDef, outPriKeyConstraint, tblCharset, tblCollate) + if err != nil { + return nil, nil, errors.Trace(err) + } + col.State = model.StatePublic + if mysql.HasZerofillFlag(col.GetFlag()) { + ctx.GetSessionVars().StmtCtx.AppendWarning( + dbterror.ErrWarnDeprecatedZerofill.FastGenByArgs(), + ) + } + constraints = append(constraints, cts...) + cols = append(cols, col) + colMap[colDef.Name.Name.L] = col + } + // Traverse table Constraints and set col.flag. + for _, v := range constraints { + setColumnFlagWithConstraint(colMap, v) + } + return cols, constraints, nil +} + +func setEmptyConstraintName(namesMap map[string]bool, constr *ast.Constraint) { + if constr.Name == "" && len(constr.Keys) > 0 { + var colName string + for _, keyPart := range constr.Keys { + if keyPart.Expr != nil { + colName = "expression_index" + } + } + if colName == "" { + colName = constr.Keys[0].Column.Name.O + } + constrName := colName + i := 2 + if strings.EqualFold(constrName, mysql.PrimaryKeyName) { + constrName = fmt.Sprintf("%s_%d", constrName, 2) + i = 3 + } + for namesMap[constrName] { + // We loop forever until we find constrName that haven't been used. + constrName = fmt.Sprintf("%s_%d", colName, i) + i++ + } + constr.Name = constrName + namesMap[constrName] = true + } +} + +func checkConstraintNames(tableName model.CIStr, constraints []*ast.Constraint) error { + constrNames := map[string]bool{} + fkNames := map[string]bool{} + + // Check not empty constraint name whether is duplicated. + for _, constr := range constraints { + if constr.Tp == ast.ConstraintForeignKey { + err := checkDuplicateConstraint(fkNames, constr.Name, constr.Tp) + if err != nil { + return errors.Trace(err) + } + } else { + err := checkDuplicateConstraint(constrNames, constr.Name, constr.Tp) + if err != nil { + return errors.Trace(err) + } + } + } + + // Set empty constraint names. + checkConstraints := make([]*ast.Constraint, 0, len(constraints)) + for _, constr := range constraints { + if constr.Tp != ast.ConstraintForeignKey { + setEmptyConstraintName(constrNames, constr) + } + if constr.Tp == ast.ConstraintCheck { + checkConstraints = append(checkConstraints, constr) + } + } + // Set check constraint name under its order. 
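setEmptyConstraintName is a small uniquification loop: start from the first column's name (or "expression_index" for an expression key) and append an increasing numeric suffix until the name is unused. A stripped-down sketch of that loop (names are illustrative):

package ddlsketch

import "fmt"

// uniqueName returns base if it is unused, otherwise base_2, base_3, ...,
// and records the chosen name so later calls keep producing fresh ones.
func uniqueName(used map[string]bool, base string) string {
	name := base
	for i := 2; used[name]; i++ {
		name = fmt.Sprintf("%s_%d", base, i)
	}
	used[name] = true
	return name
}

Calling uniqueName(m, "a") twice yields "a" and then "a_2", which is the behaviour the anonymous-index naming above relies on (with the extra twist that a name colliding with PRIMARY skips straight to the _2 suffix).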
+ if len(checkConstraints) > 0 { + setEmptyCheckConstraintName(tableName.L, constrNames, checkConstraints) + } + return nil +} + +func checkDuplicateConstraint(namesMap map[string]bool, name string, constraintType ast.ConstraintType) error { + if name == "" { + return nil + } + nameLower := strings.ToLower(name) + if namesMap[nameLower] { + switch constraintType { + case ast.ConstraintForeignKey: + return dbterror.ErrFkDupName.GenWithStackByArgs(name) + case ast.ConstraintCheck: + return dbterror.ErrCheckConstraintDupName.GenWithStackByArgs(name) + default: + return dbterror.ErrDupKeyName.GenWithStackByArgs(name) + } + } + namesMap[nameLower] = true + return nil +} + +func setEmptyCheckConstraintName(tableLowerName string, namesMap map[string]bool, constrs []*ast.Constraint) { + cnt := 1 + constraintPrefix := tableLowerName + "_chk_" + for _, constr := range constrs { + if constr.Name == "" { + constrName := fmt.Sprintf("%s%d", constraintPrefix, cnt) + for { + // loop until find constrName that haven't been used. + if !namesMap[constrName] { + namesMap[constrName] = true + break + } + cnt++ + constrName = fmt.Sprintf("%s%d", constraintPrefix, cnt) + } + constr.Name = constrName + } + } +} + +func setColumnFlagWithConstraint(colMap map[string]*table.Column, v *ast.Constraint) { + switch v.Tp { + case ast.ConstraintPrimaryKey: + for _, key := range v.Keys { + if key.Expr != nil { + continue + } + c, ok := colMap[key.Column.Name.L] + if !ok { + continue + } + c.AddFlag(mysql.PriKeyFlag) + // Primary key can not be NULL. + c.AddFlag(mysql.NotNullFlag) + setNoDefaultValueFlag(c, c.DefaultValue != nil) + } + case ast.ConstraintUniq, ast.ConstraintUniqIndex, ast.ConstraintUniqKey: + for i, key := range v.Keys { + if key.Expr != nil { + continue + } + c, ok := colMap[key.Column.Name.L] + if !ok { + continue + } + if i == 0 { + // Only the first column can be set + // if unique index has multi columns, + // the flag should be MultipleKeyFlag. + // See https://dev.mysql.com/doc/refman/5.7/en/show-columns.html + if len(v.Keys) > 1 { + c.AddFlag(mysql.MultipleKeyFlag) + } else { + c.AddFlag(mysql.UniqueKeyFlag) + } + } + } + case ast.ConstraintKey, ast.ConstraintIndex: + for i, key := range v.Keys { + if key.Expr != nil { + continue + } + c, ok := colMap[key.Column.Name.L] + if !ok { + continue + } + if i == 0 { + // Only the first column can be set. + c.AddFlag(mysql.MultipleKeyFlag) + } + } + } +} + +// BuildTableInfoWithLike builds a new table info according to CREATE TABLE ... LIKE statement. +func BuildTableInfoWithLike(ctx sessionctx.Context, ident ast.Ident, referTblInfo *model.TableInfo, s *ast.CreateTableStmt) (*model.TableInfo, error) { + // Check the referred table is a real table object. + if referTblInfo.IsSequence() || referTblInfo.IsView() { + return nil, dbterror.ErrWrongObject.GenWithStackByArgs(ident.Schema, referTblInfo.Name, "BASE TABLE") + } + tblInfo := *referTblInfo + if err := setTemporaryType(ctx, &tblInfo, s); err != nil { + return nil, errors.Trace(err) + } + // Check non-public column and adjust column offset. + newColumns := referTblInfo.Cols() + newIndices := make([]*model.IndexInfo, 0, len(tblInfo.Indices)) + for _, idx := range tblInfo.Indices { + if idx.State == model.StatePublic { + newIndices = append(newIndices, idx) + } + } + tblInfo.Columns = newColumns + tblInfo.Indices = newIndices + tblInfo.Name = ident.Name + tblInfo.AutoIncID = 0 + tblInfo.ForeignKeys = nil + // Ignore TiFlash replicas for temporary tables. 
+ if s.TemporaryKeyword != ast.TemporaryNone { + tblInfo.TiFlashReplica = nil + } else if tblInfo.TiFlashReplica != nil { + replica := *tblInfo.TiFlashReplica + // Keep the tiflash replica setting, remove the replica available status. + replica.AvailablePartitionIDs = nil + replica.Available = false + tblInfo.TiFlashReplica = &replica + } + if referTblInfo.Partition != nil { + pi := *referTblInfo.Partition + pi.Definitions = make([]model.PartitionDefinition, len(referTblInfo.Partition.Definitions)) + copy(pi.Definitions, referTblInfo.Partition.Definitions) + tblInfo.Partition = &pi + } + + if referTblInfo.TTLInfo != nil { + tblInfo.TTLInfo = referTblInfo.TTLInfo.Clone() + } + renameCheckConstraint(&tblInfo) + return &tblInfo, nil +} + +func renameCheckConstraint(tblInfo *model.TableInfo) { + for _, cons := range tblInfo.Constraints { + cons.Name = model.NewCIStr("") + cons.Table = tblInfo.Name + } + setNameForConstraintInfo(tblInfo.Name.L, map[string]bool{}, tblInfo.Constraints) +} + +// BuildTableInfo creates a TableInfo. +func BuildTableInfo( + ctx sessionctx.Context, + tableName model.CIStr, + cols []*table.Column, + constraints []*ast.Constraint, + charset string, + collate string, +) (tbInfo *model.TableInfo, err error) { + tbInfo = &model.TableInfo{ + Name: tableName, + Version: model.CurrLatestTableInfoVersion, + Charset: charset, + Collate: collate, + } + tblColumns := make([]*table.Column, 0, len(cols)) + existedColsMap := make(map[string]struct{}, len(cols)) + for _, v := range cols { + v.ID = AllocateColumnID(tbInfo) + tbInfo.Columns = append(tbInfo.Columns, v.ToInfo()) + tblColumns = append(tblColumns, table.ToColumn(v.ToInfo())) + existedColsMap[v.Name.L] = struct{}{} + } + foreignKeyID := tbInfo.MaxForeignKeyID + for _, constr := range constraints { + // Build hidden columns if necessary. + hiddenCols, err := buildHiddenColumnInfoWithCheck(ctx, constr.Keys, model.NewCIStr(constr.Name), tbInfo, tblColumns) + if err != nil { + return nil, err + } + for _, hiddenCol := range hiddenCols { + hiddenCol.State = model.StatePublic + hiddenCol.ID = AllocateColumnID(tbInfo) + hiddenCol.Offset = len(tbInfo.Columns) + tbInfo.Columns = append(tbInfo.Columns, hiddenCol) + tblColumns = append(tblColumns, table.ToColumn(hiddenCol)) + } + // Check clustered on non-primary key. + if constr.Option != nil && constr.Option.PrimaryKeyTp != model.PrimaryKeyTypeDefault && + constr.Tp != ast.ConstraintPrimaryKey { + return nil, dbterror.ErrUnsupportedClusteredSecondaryKey + } + if constr.Tp == ast.ConstraintForeignKey { + var fkName model.CIStr + foreignKeyID++ + if constr.Name != "" { + fkName = model.NewCIStr(constr.Name) + } else { + fkName = model.NewCIStr(fmt.Sprintf("fk_%d", foreignKeyID)) + } + if model.FindFKInfoByName(tbInfo.ForeignKeys, fkName.L) != nil { + return nil, infoschema.ErrCannotAddForeign + } + fk, err := buildFKInfo(fkName, constr.Keys, constr.Refer, cols) + if err != nil { + return nil, err + } + fk.State = model.StatePublic + + tbInfo.ForeignKeys = append(tbInfo.ForeignKeys, fk) + continue + } + if constr.Tp == ast.ConstraintPrimaryKey { + lastCol, err := CheckPKOnGeneratedColumn(tbInfo, constr.Keys) + if err != nil { + return nil, err + } + isSingleIntPK := isSingleIntPK(constr, lastCol) + if ShouldBuildClusteredIndex(ctx, constr.Option, isSingleIntPK) { + if isSingleIntPK { + tbInfo.PKIsHandle = true + } else { + tbInfo.IsCommonHandle = true + tbInfo.CommonHandleVersion = 1 + } + } + if tbInfo.HasClusteredIndex() { + // Primary key cannot be invisible. 
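BuildTableInfoWithLike starts from a shallow struct copy of the referred table and then re-copies every field that holds a pointer or slice (the TiFlash replica, the partition definitions, the TTL info), so the new table never aliases the original's mutable state. The hazard it avoids looks like this (the types below are illustrative stand-ins, not TiDB's):

package ddlsketch

type partitionInfo struct {
	Definitions []string
}

type tableInfo struct {
	Name      string
	Partition *partitionInfo
}

// copyForLike mimics the copy discipline: shallow-copy the struct, then clone
// any shared pointer/slice fields before handing the result out.
func copyForLike(refer *tableInfo, newName string) *tableInfo {
	cp := *refer // shallow copy: cp.Partition still points at refer's data
	if refer.Partition != nil {
		pi := *refer.Partition
		pi.Definitions = append([]string(nil), refer.Partition.Definitions...)
		cp.Partition = &pi // mutations through cp no longer touch refer
	}
	cp.Name = newName
	return &cp
}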
+ if constr.Option != nil && constr.Option.Visibility == ast.IndexVisibilityInvisible { + return nil, dbterror.ErrPKIndexCantBeInvisible + } + } + if tbInfo.PKIsHandle { + continue + } + } + + if constr.Tp == ast.ConstraintFulltext { + ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrTableCantHandleFt.FastGenByArgs()) + continue + } + + var ( + indexName = constr.Name + primary, unique bool + ) + + // Check if the index is primary or unique. + switch constr.Tp { + case ast.ConstraintPrimaryKey: + primary = true + unique = true + indexName = mysql.PrimaryKeyName + case ast.ConstraintUniq, ast.ConstraintUniqKey, ast.ConstraintUniqIndex: + unique = true + } + + // check constraint + if constr.Tp == ast.ConstraintCheck { + if !variable.EnableCheckConstraint.Load() { + ctx.GetSessionVars().StmtCtx.AppendWarning(errCheckConstraintIsOff) + continue + } + // Since column check constraint dependency has been done in columnDefToCol. + // Here do the table check constraint dependency check, table constraint + // can only refer the columns in defined columns of the table. + // Refer: https://dev.mysql.com/doc/refman/8.0/en/create-table-check-constraints.html + if ok, err := table.IsSupportedExpr(constr); !ok { + return nil, err + } + var dependedCols []model.CIStr + dependedColsMap := findDependentColsInExpr(constr.Expr) + if !constr.InColumn { + dependedCols = make([]model.CIStr, 0, len(dependedColsMap)) + for k := range dependedColsMap { + if _, ok := existedColsMap[k]; !ok { + // The table constraint depended on a non-existed column. + return nil, dbterror.ErrTableCheckConstraintReferUnknown.GenWithStackByArgs(constr.Name, k) + } + dependedCols = append(dependedCols, model.NewCIStr(k)) + } + } else { + // Check the column-type constraint dependency. + if len(dependedColsMap) > 1 { + return nil, dbterror.ErrColumnCheckConstraintReferOther.GenWithStackByArgs(constr.Name) + } else if len(dependedColsMap) == 0 { + // If dependedCols is empty, the expression must be true/false. + valExpr, ok := constr.Expr.(*driver.ValueExpr) + if !ok || !mysql.HasIsBooleanFlag(valExpr.GetType().GetFlag()) { + return nil, errors.Trace(errors.New("unsupported expression in check constraint")) + } + } else { + if _, ok := dependedColsMap[constr.InColumnName]; !ok { + return nil, dbterror.ErrColumnCheckConstraintReferOther.GenWithStackByArgs(constr.Name) + } + dependedCols = []model.CIStr{model.NewCIStr(constr.InColumnName)} + } + } + // check auto-increment column + if table.ContainsAutoIncrementCol(dependedCols, tbInfo) { + return nil, dbterror.ErrCheckConstraintRefersAutoIncrementColumn.GenWithStackByArgs(constr.Name) + } + // check foreign key + if err := table.HasForeignKeyRefAction(tbInfo.ForeignKeys, constraints, constr, dependedCols); err != nil { + return nil, err + } + // build constraint meta info. + constraintInfo, err := buildConstraintInfo(tbInfo, dependedCols, constr, model.StatePublic) + if err != nil { + return nil, errors.Trace(err) + } + // check if the expression is bool type + if err := table.IfCheckConstraintExprBoolType(ctx.GetExprCtx().GetEvalCtx(), constraintInfo, tbInfo); err != nil { + return nil, err + } + constraintInfo.ID = allocateConstraintID(tbInfo) + tbInfo.Constraints = append(tbInfo.Constraints, constraintInfo) + continue + } + + // build index info. 
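The check-constraint branch above enforces two dependency rules: a table-level CHECK may only reference columns defined on the table, and a column-level CHECK may only reference its own column (a constant boolean expression being the one exception). A compact sketch of those rules, assuming the referenced column names have already been extracted from the expression (names are illustrative):

package ddlsketch

import "fmt"

// validateCheckDeps applies the dependency rules for a CHECK constraint.
// deps are the referenced column names (lower-cased), existing is the set of
// defined columns, and ownCol is non-empty for a column-level constraint.
func validateCheckDeps(deps []string, existing map[string]bool, ownCol string) error {
	for _, d := range deps {
		if !existing[d] {
			return fmt.Errorf("check constraint refers to unknown column %q", d)
		}
		if ownCol != "" && d != ownCol {
			return fmt.Errorf("column check constraint may only refer to column %q", ownCol)
		}
	}
	return nil
}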
+ idxInfo, err := BuildIndexInfo( + ctx, + tbInfo.Columns, + model.NewCIStr(indexName), + primary, + unique, + false, + constr.Keys, + constr.Option, + model.StatePublic, + ) + if err != nil { + return nil, errors.Trace(err) + } + + if len(hiddenCols) > 0 { + AddIndexColumnFlag(tbInfo, idxInfo) + } + sessionVars := ctx.GetSessionVars() + _, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, idxInfo.Name.String(), &idxInfo.Comment, dbterror.ErrTooLongIndexComment) + if err != nil { + return nil, errors.Trace(err) + } + idxInfo.ID = AllocateIndexID(tbInfo) + tbInfo.Indices = append(tbInfo.Indices, idxInfo) + } + + err = addIndexForForeignKey(ctx, tbInfo) + return tbInfo, err +} + +func precheckBuildHiddenColumnInfo( + indexPartSpecifications []*ast.IndexPartSpecification, + indexName model.CIStr, +) error { + for i, idxPart := range indexPartSpecifications { + if idxPart.Expr == nil { + continue + } + name := fmt.Sprintf("%s_%s_%d", expressionIndexPrefix, indexName, i) + if utf8.RuneCountInString(name) > mysql.MaxColumnNameLength { + // TODO: Refine the error message. + return dbterror.ErrTooLongIdent.GenWithStackByArgs("hidden column") + } + // TODO: Refine the error message. + if err := checkIllegalFn4Generated(indexName.L, typeIndex, idxPart.Expr); err != nil { + return errors.Trace(err) + } + } + return nil +} + +func buildHiddenColumnInfoWithCheck(ctx sessionctx.Context, indexPartSpecifications []*ast.IndexPartSpecification, indexName model.CIStr, tblInfo *model.TableInfo, existCols []*table.Column) ([]*model.ColumnInfo, error) { + if err := precheckBuildHiddenColumnInfo(indexPartSpecifications, indexName); err != nil { + return nil, err + } + return BuildHiddenColumnInfo(ctx, indexPartSpecifications, indexName, tblInfo, existCols) +} + +// BuildHiddenColumnInfo builds hidden column info. +func BuildHiddenColumnInfo(ctx sessionctx.Context, indexPartSpecifications []*ast.IndexPartSpecification, indexName model.CIStr, tblInfo *model.TableInfo, existCols []*table.Column) ([]*model.ColumnInfo, error) { + hiddenCols := make([]*model.ColumnInfo, 0, len(indexPartSpecifications)) + for i, idxPart := range indexPartSpecifications { + if idxPart.Expr == nil { + continue + } + idxPart.Column = &ast.ColumnName{Name: model.NewCIStr(fmt.Sprintf("%s_%s_%d", expressionIndexPrefix, indexName, i))} + // Check whether the hidden columns have existed. + col := table.FindCol(existCols, idxPart.Column.Name.L) + if col != nil { + // TODO: Use expression index related error. + return nil, infoschema.ErrColumnExists.GenWithStackByArgs(col.Name.String()) + } + idxPart.Length = types.UnspecifiedLength + // The index part is an expression, prepare a hidden column for it. + + var sb strings.Builder + restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | + format.RestoreSpacesAroundBinaryOperation | format.RestoreWithoutSchemaName | format.RestoreWithoutTableName + restoreCtx := format.NewRestoreCtx(restoreFlags, &sb) + sb.Reset() + err := idxPart.Expr.Restore(restoreCtx) + if err != nil { + return nil, errors.Trace(err) + } + expr, err := expression.BuildSimpleExpr(ctx.GetExprCtx(), idxPart.Expr, + expression.WithTableInfo(ctx.GetSessionVars().CurrentDB, tblInfo), + expression.WithAllowCastArray(true), + ) + if err != nil { + // TODO: refine the error message. 
+ return nil, err + } + if _, ok := expr.(*expression.Column); ok { + return nil, dbterror.ErrFunctionalIndexOnField + } + + colInfo := &model.ColumnInfo{ + Name: idxPart.Column.Name, + GeneratedExprString: sb.String(), + GeneratedStored: false, + Version: model.CurrLatestColumnInfoVersion, + Dependences: make(map[string]struct{}), + Hidden: true, + FieldType: *expr.GetType(ctx.GetExprCtx().GetEvalCtx()), + } + // Reset some flag, it may be caused by wrong type infer. But it's not easy to fix them all, so reset them here for safety. + colInfo.DelFlag(mysql.PriKeyFlag | mysql.UniqueKeyFlag | mysql.AutoIncrementFlag) + + if colInfo.GetType() == mysql.TypeDatetime || colInfo.GetType() == mysql.TypeDate || colInfo.GetType() == mysql.TypeTimestamp || colInfo.GetType() == mysql.TypeDuration { + if colInfo.FieldType.GetDecimal() == types.UnspecifiedLength { + colInfo.FieldType.SetDecimal(types.MaxFsp) + } + } + // For an array, the collation is set to "binary". The collation has no effect on the array itself (as it's usually + // regarded as a JSON), but will influence how TiKV handles the index value. + if colInfo.FieldType.IsArray() { + colInfo.SetCharset("binary") + colInfo.SetCollate("binary") + } + checkDependencies := make(map[string]struct{}) + for _, colName := range FindColumnNamesInExpr(idxPart.Expr) { + colInfo.Dependences[colName.Name.L] = struct{}{} + checkDependencies[colName.Name.L] = struct{}{} + } + if err = checkDependedColExist(checkDependencies, existCols); err != nil { + return nil, errors.Trace(err) + } + if !ctx.GetSessionVars().EnableAutoIncrementInGenerated { + if err = checkExpressionIndexAutoIncrement(indexName.O, colInfo.Dependences, tblInfo); err != nil { + return nil, errors.Trace(err) + } + } + idxPart.Expr = nil + hiddenCols = append(hiddenCols, colInfo) + } + return hiddenCols, nil +} + +// addIndexForForeignKey uses to auto create an index for the foreign key if the table doesn't have any index cover the +// foreign key columns. +func addIndexForForeignKey(ctx sessionctx.Context, tbInfo *model.TableInfo) error { + if len(tbInfo.ForeignKeys) == 0 { + return nil + } + var handleCol *model.ColumnInfo + if tbInfo.PKIsHandle { + handleCol = tbInfo.GetPkColInfo() + } + for _, fk := range tbInfo.ForeignKeys { + if fk.Version < model.FKVersion1 { + continue + } + if handleCol != nil && len(fk.Cols) == 1 && handleCol.Name.L == fk.Cols[0].L { + continue + } + if model.FindIndexByColumns(tbInfo, tbInfo.Indices, fk.Cols...) 
!= nil { + continue + } + idxName := fk.Name + if tbInfo.FindIndexByName(idxName.L) != nil { + return dbterror.ErrDupKeyName.GenWithStack("duplicate key name %s", fk.Name.O) + } + keys := make([]*ast.IndexPartSpecification, 0, len(fk.Cols)) + for _, col := range fk.Cols { + keys = append(keys, &ast.IndexPartSpecification{ + Column: &ast.ColumnName{Name: col}, + Length: types.UnspecifiedLength, + }) + } + idxInfo, err := BuildIndexInfo(ctx, tbInfo.Columns, idxName, false, false, false, keys, nil, model.StatePublic) + if err != nil { + return errors.Trace(err) + } + idxInfo.ID = AllocateIndexID(tbInfo) + tbInfo.Indices = append(tbInfo.Indices, idxInfo) + } + return nil +} + +func isSingleIntPK(constr *ast.Constraint, lastCol *model.ColumnInfo) bool { + if len(constr.Keys) != 1 { + return false + } + switch lastCol.GetType() { + case mysql.TypeLong, mysql.TypeLonglong, + mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24: + return true + } + return false +} + +// ShouldBuildClusteredIndex is used to determine whether the CREATE TABLE statement should build a clustered index table. +func ShouldBuildClusteredIndex(ctx sessionctx.Context, opt *ast.IndexOption, isSingleIntPK bool) bool { + if opt == nil || opt.PrimaryKeyTp == model.PrimaryKeyTypeDefault { + switch ctx.GetSessionVars().EnableClusteredIndex { + case variable.ClusteredIndexDefModeOn: + return true + case variable.ClusteredIndexDefModeIntOnly: + return !config.GetGlobalConfig().AlterPrimaryKey && isSingleIntPK + default: + return false + } + } + return opt.PrimaryKeyTp == model.PrimaryKeyTypeClustered +} + +// BuildViewInfo builds a ViewInfo structure from an ast.CreateViewStmt. +func BuildViewInfo(s *ast.CreateViewStmt) (*model.ViewInfo, error) { + // Always Use `format.RestoreNameBackQuotes` to restore `SELECT` statement despite the `ANSI_QUOTES` SQL Mode is enabled or not. 
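ShouldBuildClusteredIndex above is a three-way decision: an explicit CLUSTERED/NONCLUSTERED on the PRIMARY KEY wins; otherwise the session's clustered-index mode decides, with the INT_ONLY mode additionally requiring a single integer primary key (and, in the real code, that alter-primary-key is disabled). The same decision with plain enums (names are illustrative, not TiDB's):

package ddlsketch

type pkType int

const (
	pkDefault pkType = iota
	pkClustered
	pkNonClustered
)

type clusteredMode int

const (
	modeIntOnly clusteredMode = iota
	modeOn
	modeOff
)

// shouldCluster mirrors the precedence: explicit option first, session mode second.
func shouldCluster(explicit pkType, mode clusteredMode, singleIntPK bool) bool {
	if explicit != pkDefault {
		return explicit == pkClustered
	}
	switch mode {
	case modeOn:
		return true
	case modeIntOnly:
		return singleIntPK
	default:
		return false
	}
}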
+ restoreFlag := format.RestoreStringSingleQuotes | format.RestoreKeyWordUppercase | format.RestoreNameBackQuotes + var sb strings.Builder + if err := s.Select.Restore(format.NewRestoreCtx(restoreFlag, &sb)); err != nil { + return nil, err + } + + return &model.ViewInfo{Definer: s.Definer, Algorithm: s.Algorithm, + Security: s.Security, SelectStmt: sb.String(), CheckOption: s.CheckOption, Cols: nil}, nil +} diff --git a/pkg/ddl/ddl.go b/pkg/ddl/ddl.go index 4ec779e8cc0e7..12eea8d51efbc 100644 --- a/pkg/ddl/ddl.go +++ b/pkg/ddl/ddl.go @@ -1224,11 +1224,11 @@ func processJobs( ids []int64, byWho model.AdminCommandOperator, ) (jobErrs []error, err error) { - failpoint.Inject("mockFailedCommandOnConcurencyDDL", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockFailedCommandOnConcurencyDDL")); _err_ == nil { if val.(bool) { - failpoint.Return(nil, errors.New("mock failed admin command on ddl jobs")) + return nil, errors.New("mock failed admin command on ddl jobs") } - }) + } if len(ids) == 0 { return nil, nil @@ -1279,11 +1279,11 @@ func processJobs( } } - failpoint.Inject("mockCommitFailedOnDDLCommand", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockCommitFailedOnDDLCommand")); _err_ == nil { if val.(bool) { - failpoint.Return(jobErrs, errors.New("mock commit failed on admin command on ddl jobs")) + return jobErrs, errors.New("mock commit failed on admin command on ddl jobs") } - }) + } // There may be some conflict during the update, try it again if err = ns.Commit(context.Background()); err != nil { diff --git a/pkg/ddl/ddl.go__failpoint_stash__ b/pkg/ddl/ddl.go__failpoint_stash__ new file mode 100644 index 0000000000000..4ec779e8cc0e7 --- /dev/null +++ b/pkg/ddl/ddl.go__failpoint_stash__ @@ -0,0 +1,1421 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright 2013 The ql Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSES/QL-LICENSE file. 
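The pkg/ddl/ddl.go hunks above are the mechanical rewrite applied when failpoints are activated (failpoint-ctl keeps the original source in the __failpoint_stash__ file added below): each failpoint.Inject callback is inlined into an if around failpoint.Eval, so failpoint.Return can become a plain return from the enclosing function. A minimal before/after of that pattern; the _curpkg_ helper here is a stand-in for the generated per-package prefixer:

package ddlsketch

import (
	"errors"

	"github.com/pingcap/failpoint"
)

// _curpkg_ stands in for the generated helper that prefixes failpoint names
// with the package path.
func _curpkg_(name string) string { return "example.com/ddlsketch/" + name }

// beforeRewrite is the marker form that lives in the committed source.
func beforeRewrite() error {
	failpoint.Inject("mockError", func(val failpoint.Value) {
		if val.(bool) {
			failpoint.Return(errors.New("mock error"))
		}
	})
	return nil
}

// afterRewrite is the activated form: the callback body is inlined and
// failpoint.Return becomes a normal return.
func afterRewrite() error {
	if val, _err_ := failpoint.Eval(_curpkg_("mockError")); _err_ == nil {
		if val.(bool) {
			return errors.New("mock error")
		}
	}
	return nil
}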
+ +package ddl + +import ( + "context" + "fmt" + "strconv" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "github.com/ngaut/pools" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl/ingest" + "github.com/pingcap/tidb/pkg/ddl/logutil" + sess "github.com/pingcap/tidb/pkg/ddl/session" + "github.com/pingcap/tidb/pkg/ddl/syncer" + "github.com/pingcap/tidb/pkg/ddl/systable" + "github.com/pingcap/tidb/pkg/ddl/util" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" + "github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/owner" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/binloginfo" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/statistics/handle" + statsutil "github.com/pingcap/tidb/pkg/statistics/handle/util" + "github.com/pingcap/tidb/pkg/table" + pumpcli "github.com/pingcap/tidb/pkg/tidb-binlog/pump_client" + tidbutil "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/gcutil" + "github.com/pingcap/tidb/pkg/util/generic" + "github.com/tikv/client-go/v2/tikvrpc" + clientv3 "go.etcd.io/etcd/client/v3" + atomicutil "go.uber.org/atomic" + "go.uber.org/zap" +) + +const ( + // currentVersion is for all new DDL jobs. + currentVersion = 1 + // DDLOwnerKey is the ddl owner path that is saved to etcd, and it's exported for testing. + DDLOwnerKey = "/tidb/ddl/fg/owner" + ddlSchemaVersionKeyLock = "/tidb/ddl/schema_version_lock" + // addingDDLJobPrefix is the path prefix used to record the newly added DDL job, and it's saved to etcd. + addingDDLJobPrefix = "/tidb/ddl/add_ddl_job_" + ddlPrompt = "ddl" + + shardRowIDBitsMax = 15 + + batchAddingJobs = 100 + + reorgWorkerCnt = 10 + generalWorkerCnt = 10 + + // checkFlagIndexInJobArgs is the recoverCheckFlag index used in RecoverTable/RecoverSchema job arg list. + checkFlagIndexInJobArgs = 1 +) + +const ( + // The recoverCheckFlag is used to judge the gc work status when RecoverTable/RecoverSchema. + recoverCheckFlagNone int64 = iota + recoverCheckFlagEnableGC + recoverCheckFlagDisableGC +) + +// OnExist specifies what to do when a new object has a name collision. +type OnExist uint8 + +// CreateTableConfig is the configuration of `CreateTableWithInfo`. +type CreateTableConfig struct { + OnExist OnExist + // IDAllocated indicates whether the job has allocated all IDs for tables affected + // in the job, if true, DDL will not allocate IDs for them again, it's only used + // by BR now. By reusing IDs BR can save a lot of works such as rewriting table + // IDs in backed up KVs. + IDAllocated bool +} + +// CreateTableOption is the option for creating table. +type CreateTableOption func(*CreateTableConfig) + +// GetCreateTableConfig applies the series of config options from default config +// and returns the final config. 
+func GetCreateTableConfig(cs []CreateTableOption) CreateTableConfig { + cfg := CreateTableConfig{} + for _, c := range cs { + c(&cfg) + } + return cfg +} + +// WithOnExist applies the OnExist option. +func WithOnExist(o OnExist) CreateTableOption { + return func(cfg *CreateTableConfig) { + cfg.OnExist = o + } +} + +// WithIDAllocated applies the IDAllocated option. +// WARNING!!!: if idAllocated == true, DDL will NOT allocate IDs by itself. That +// means if the caller can not promise ID is unique, then we got inconsistency. +// This option is only exposed to be used by BR. +func WithIDAllocated(idAllocated bool) CreateTableOption { + return func(cfg *CreateTableConfig) { + cfg.IDAllocated = idAllocated + } +} + +const ( + // OnExistError throws an error on name collision. + OnExistError OnExist = iota + // OnExistIgnore skips creating the new object. + OnExistIgnore + // OnExistReplace replaces the old object by the new object. This is only + // supported by VIEWs at the moment. For other object types, this is + // equivalent to OnExistError. + OnExistReplace + + jobRecordCapacity = 16 + jobOnceCapacity = 1000 +) + +var ( + // EnableSplitTableRegion is a flag to decide whether to split a new region for + // a newly created table. It takes effect only if the Storage supports split + // region. + EnableSplitTableRegion = uint32(0) +) + +// DDL is responsible for updating schema in data store and maintaining in-memory InfoSchema cache. +type DDL interface { + // Start campaigns the owner and starts workers. + // ctxPool is used for the worker's delRangeManager and creates sessions. + Start(ctxPool *pools.ResourcePool) error + // GetLease returns current schema lease time. + GetLease() time.Duration + // Stats returns the DDL statistics. + Stats(vars *variable.SessionVars) (map[string]any, error) + // GetScope gets the status variables scope. + GetScope(status string) variable.ScopeFlag + // Stop stops DDL worker. + Stop() error + // RegisterStatsHandle registers statistics handle and its corresponding event channel for ddl. + RegisterStatsHandle(*handle.Handle) + // SchemaSyncer gets the schema syncer. + SchemaSyncer() syncer.SchemaSyncer + // StateSyncer gets the cluster state syncer. + StateSyncer() syncer.StateSyncer + // OwnerManager gets the owner manager. + OwnerManager() owner.Manager + // GetID gets the ddl ID. + GetID() string + // GetTableMaxHandle gets the max row ID of a normal table or a partition. + GetTableMaxHandle(ctx *JobContext, startTS uint64, tbl table.PhysicalTable) (kv.Handle, bool, error) + // SetBinlogClient sets the binlog client for DDL worker. It's exported for testing. + SetBinlogClient(*pumpcli.PumpsClient) + // GetMinJobIDRefresher gets the MinJobIDRefresher, this api only works after Start. + GetMinJobIDRefresher() *systable.MinJobIDRefresher +} + +type jobSubmitResult struct { + err error + jobID int64 + // merged indicates whether the job is merged into another job together with + // other jobs. we only merge multiple create table jobs into one job when fast + // create table is enabled. + merged bool +} + +// JobWrapper is used to wrap a job and some other information. +// exported for testing. +type JobWrapper struct { + *model.Job + // IDAllocated see config of same name in CreateTableConfig. + // exported for test. + IDAllocated bool + // job submission is run in async, we use this channel to notify the caller. + // when fast create table enabled, we might combine multiple jobs into one, and + // append the channel to this slice. 
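CreateTableOption is the usual Go functional-options pattern: each With* helper returns a closure that mutates the config, and GetCreateTableConfig folds the closures over a zero value. The same pattern in miniature (illustrative names only):

package ddlsketch

type config struct {
	onExistIgnore bool
	idAllocated   bool
}

type option func(*config)

// withOnExistIgnore and withIDAllocated mirror the shape of WithOnExist / WithIDAllocated.
func withOnExistIgnore() option { return func(c *config) { c.onExistIgnore = true } }

func withIDAllocated(v bool) option { return func(c *config) { c.idAllocated = v } }

// apply folds the options over a zero-valued config, exactly like
// GetCreateTableConfig does above.
func apply(opts ...option) config {
	var c config
	for _, o := range opts {
		o(&c)
	}
	return c
}

apply(withOnExistIgnore(), withIDAllocated(true)) yields a config with both fields set, and callers that pass no options get the zero-value defaults for free.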
+ ResultCh []chan jobSubmitResult + cacheErr error +} + +// NewJobWrapper creates a new JobWrapper. +// exported for testing. +func NewJobWrapper(job *model.Job, idAllocated bool) *JobWrapper { + return &JobWrapper{ + Job: job, + IDAllocated: idAllocated, + ResultCh: []chan jobSubmitResult{make(chan jobSubmitResult)}, + } +} + +// NotifyResult notifies the job submit result. +func (t *JobWrapper) NotifyResult(err error) { + merged := len(t.ResultCh) > 1 + for _, resultCh := range t.ResultCh { + resultCh <- jobSubmitResult{ + err: err, + jobID: t.ID, + merged: merged, + } + } +} + +// ddl is used to handle the statements that define the structure or schema of the database. +type ddl struct { + m sync.RWMutex + wg tidbutil.WaitGroupWrapper // It's only used to deal with data race in restart_test. + limitJobCh chan *JobWrapper + + *ddlCtx + sessPool *sess.Pool + delRangeMgr delRangeManager + enableTiFlashPoll *atomicutil.Bool + // get notification if any DDL job submitted or finished. + ddlJobNotifyCh chan struct{} + sysTblMgr systable.Manager + minJobIDRefresher *systable.MinJobIDRefresher + + // globalIDLock locks global id to reduce write conflict. + globalIDLock sync.Mutex + executor *executor +} + +// waitSchemaSyncedController is to control whether to waitSchemaSynced or not. +type waitSchemaSyncedController struct { + mu sync.RWMutex + job map[int64]struct{} + + // Use to check if the DDL job is the first run on this owner. + onceMap map[int64]struct{} +} + +func newWaitSchemaSyncedController() *waitSchemaSyncedController { + return &waitSchemaSyncedController{ + job: make(map[int64]struct{}, jobRecordCapacity), + onceMap: make(map[int64]struct{}, jobOnceCapacity), + } +} + +func (w *waitSchemaSyncedController) registerSync(job *model.Job) { + w.mu.Lock() + defer w.mu.Unlock() + w.job[job.ID] = struct{}{} +} + +func (w *waitSchemaSyncedController) isSynced(job *model.Job) bool { + w.mu.RLock() + defer w.mu.RUnlock() + _, ok := w.job[job.ID] + return !ok +} + +func (w *waitSchemaSyncedController) synced(job *model.Job) { + w.mu.Lock() + defer w.mu.Unlock() + delete(w.job, job.ID) +} + +// maybeAlreadyRunOnce returns true means that the job may be the first run on this owner. +// Returns false means that the job must not be the first run on this owner. +func (w *waitSchemaSyncedController) maybeAlreadyRunOnce(id int64) bool { + w.mu.Lock() + defer w.mu.Unlock() + _, ok := w.onceMap[id] + return ok +} + +func (w *waitSchemaSyncedController) setAlreadyRunOnce(id int64) { + w.mu.Lock() + defer w.mu.Unlock() + if len(w.onceMap) > jobOnceCapacity { + // If the map is too large, we reset it. These jobs may need to check schema synced again, but it's ok. + w.onceMap = make(map[int64]struct{}, jobRecordCapacity) + } + w.onceMap[id] = struct{}{} +} + +func (w *waitSchemaSyncedController) clearOnceMap() { + w.mu.Lock() + defer w.mu.Unlock() + w.onceMap = make(map[int64]struct{}, jobOnceCapacity) +} + +// ddlCtx is the context when we use worker to handle DDL jobs. +type ddlCtx struct { + ctx context.Context + cancel context.CancelFunc + uuid string + store kv.Storage + ownerManager owner.Manager + schemaSyncer syncer.SchemaSyncer + stateSyncer syncer.StateSyncer + // ddlJobDoneChMap is used to notify the session that the DDL job is finished. + // jobID -> chan struct{} + ddlJobDoneChMap generic.SyncMap[int64, chan struct{}] + ddlEventCh chan<- *statsutil.DDLEvent + lease time.Duration // lease is schema lease, default 45s, see config.Lease. 
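waitSchemaSyncedController keeps two small mutex-guarded sets; the onceMap one is the interesting part, because it is reset wholesale once it grows past a capacity, trading a little repeated work (re-checking schema sync) for bounded memory. The "bounded remembered set" idea on its own (illustrative type, not TiDB's):

package ddlsketch

import "sync"

// boundedSet remembers int64 IDs but forgets everything once it grows past
// max, so memory stays bounded even though entries are never removed one by one.
type boundedSet struct {
	mu  sync.Mutex
	m   map[int64]struct{}
	max int
}

func newBoundedSet(max int) *boundedSet {
	return &boundedSet{m: make(map[int64]struct{}, max), max: max}
}

func (s *boundedSet) Has(id int64) bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	_, ok := s.m[id]
	return ok
}

func (s *boundedSet) Add(id int64) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if len(s.m) > s.max {
		// Reset: callers may have to redo a cheap check, which is acceptable.
		s.m = make(map[int64]struct{}, s.max)
	}
	s.m[id] = struct{}{}
}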
+ binlogCli *pumpcli.PumpsClient // binlogCli is used for Binlog. + infoCache *infoschema.InfoCache + statsHandle *handle.Handle + tableLockCkr util.DeadTableLockChecker + etcdCli *clientv3.Client + autoidCli *autoid.ClientDiscover + schemaLoader SchemaLoader + + *waitSchemaSyncedController + *schemaVersionManager + + // reorgCtx is used for reorganization. + reorgCtx reorgContexts + + jobCtx struct { + sync.RWMutex + // jobCtxMap maps job ID to job's ctx. + jobCtxMap map[int64]*JobContext + } +} + +// SchemaLoader is used to avoid import loop, the only impl is domain currently. +type SchemaLoader interface { + Reload() error +} + +// schemaVersionManager is used to manage the schema version. To prevent the conflicts on this key between different DDL job, +// we use another transaction to update the schema version, so that we need to lock the schema version and unlock it until the job is committed. +// for version2, we use etcd lock to lock the schema version between TiDB nodes now. +type schemaVersionManager struct { + schemaVersionMu sync.Mutex + // lockOwner stores the job ID that is holding the lock. + lockOwner atomicutil.Int64 +} + +func newSchemaVersionManager() *schemaVersionManager { + return &schemaVersionManager{} +} + +func (sv *schemaVersionManager) setSchemaVersion(job *model.Job, store kv.Storage) (schemaVersion int64, err error) { + err = sv.lockSchemaVersion(job.ID) + if err != nil { + return schemaVersion, errors.Trace(err) + } + err = kv.RunInNewTxn(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), store, true, func(_ context.Context, txn kv.Transaction) error { + var err error + m := meta.NewMeta(txn) + schemaVersion, err = m.GenSchemaVersion() + return err + }) + return schemaVersion, err +} + +// lockSchemaVersion gets the lock to prevent the schema version from being updated. +func (sv *schemaVersionManager) lockSchemaVersion(jobID int64) error { + ownerID := sv.lockOwner.Load() + // There may exist one job update schema version many times in multiple-schema-change, so we do not lock here again + // if they are the same job. + if ownerID != jobID { + sv.schemaVersionMu.Lock() + sv.lockOwner.Store(jobID) + } + return nil +} + +// unlockSchemaVersion releases the lock. 
+func (sv *schemaVersionManager) unlockSchemaVersion(jobID int64) { + ownerID := sv.lockOwner.Load() + if ownerID == jobID { + sv.lockOwner.Store(0) + sv.schemaVersionMu.Unlock() + } +} + +func (dc *ddlCtx) isOwner() bool { + isOwner := dc.ownerManager.IsOwner() + logutil.DDLLogger().Debug("check whether is the DDL owner", zap.Bool("isOwner", isOwner), zap.String("selfID", dc.uuid)) + if isOwner { + metrics.DDLCounter.WithLabelValues(metrics.DDLOwner + "_" + mysql.TiDBReleaseVersion).Inc() + } + return isOwner +} + +func (dc *ddlCtx) setDDLLabelForTopSQL(jobID int64, jobQuery string) { + dc.jobCtx.Lock() + defer dc.jobCtx.Unlock() + ctx, exists := dc.jobCtx.jobCtxMap[jobID] + if !exists { + ctx = NewJobContext() + dc.jobCtx.jobCtxMap[jobID] = ctx + } + ctx.setDDLLabelForTopSQL(jobQuery) +} + +func (dc *ddlCtx) setDDLSourceForDiagnosis(jobID int64, jobType model.ActionType) { + dc.jobCtx.Lock() + defer dc.jobCtx.Unlock() + ctx, exists := dc.jobCtx.jobCtxMap[jobID] + if !exists { + ctx = NewJobContext() + dc.jobCtx.jobCtxMap[jobID] = ctx + } + ctx.setDDLLabelForDiagnosis(jobType) +} + +func (dc *ddlCtx) getResourceGroupTaggerForTopSQL(jobID int64) tikvrpc.ResourceGroupTagger { + dc.jobCtx.Lock() + defer dc.jobCtx.Unlock() + ctx, exists := dc.jobCtx.jobCtxMap[jobID] + if !exists { + return nil + } + return ctx.getResourceGroupTaggerForTopSQL() +} + +func (dc *ddlCtx) removeJobCtx(job *model.Job) { + dc.jobCtx.Lock() + defer dc.jobCtx.Unlock() + delete(dc.jobCtx.jobCtxMap, job.ID) +} + +func (dc *ddlCtx) jobContext(jobID int64, reorgMeta *model.DDLReorgMeta) *JobContext { + dc.jobCtx.RLock() + defer dc.jobCtx.RUnlock() + var ctx *JobContext + if jobContext, exists := dc.jobCtx.jobCtxMap[jobID]; exists { + ctx = jobContext + } else { + ctx = NewJobContext() + } + if reorgMeta != nil && len(ctx.resourceGroupName) == 0 { + ctx.resourceGroupName = reorgMeta.ResourceGroupName + } + return ctx +} + +type reorgContexts struct { + sync.RWMutex + // reorgCtxMap maps job ID to reorg context. 
+ reorgCtxMap map[int64]*reorgCtx + beOwnerTS int64 +} + +func (r *reorgContexts) getOwnerTS() int64 { + r.RLock() + defer r.RUnlock() + return r.beOwnerTS +} + +func (r *reorgContexts) setOwnerTS(ts int64) { + r.Lock() + r.beOwnerTS = ts + r.Unlock() +} + +func (dc *ddlCtx) getReorgCtx(jobID int64) *reorgCtx { + dc.reorgCtx.RLock() + defer dc.reorgCtx.RUnlock() + return dc.reorgCtx.reorgCtxMap[jobID] +} + +func (dc *ddlCtx) newReorgCtx(jobID int64, rowCount int64) *reorgCtx { + dc.reorgCtx.Lock() + defer dc.reorgCtx.Unlock() + existedRC, ok := dc.reorgCtx.reorgCtxMap[jobID] + if ok { + existedRC.references.Add(1) + return existedRC + } + rc := &reorgCtx{} + rc.doneCh = make(chan reorgFnResult, 1) + // initial reorgCtx + rc.setRowCount(rowCount) + rc.mu.warnings = make(map[errors.ErrorID]*terror.Error) + rc.mu.warningsCount = make(map[errors.ErrorID]int64) + rc.references.Add(1) + dc.reorgCtx.reorgCtxMap[jobID] = rc + return rc +} + +func (dc *ddlCtx) removeReorgCtx(jobID int64) { + dc.reorgCtx.Lock() + defer dc.reorgCtx.Unlock() + ctx, ok := dc.reorgCtx.reorgCtxMap[jobID] + if ok { + ctx.references.Sub(1) + if ctx.references.Load() == 0 { + delete(dc.reorgCtx.reorgCtxMap, jobID) + } + } +} + +func (dc *ddlCtx) notifyReorgWorkerJobStateChange(job *model.Job) { + rc := dc.getReorgCtx(job.ID) + if rc == nil { + logutil.DDLLogger().Warn("cannot find reorgCtx", zap.Int64("Job ID", job.ID)) + return + } + logutil.DDLLogger().Info("notify reorg worker the job's state", + zap.Int64("Job ID", job.ID), zap.Stringer("Job State", job.State), + zap.Stringer("Schema State", job.SchemaState)) + rc.notifyJobState(job.State) +} + +func (dc *ddlCtx) notifyJobDone(jobID int64) { + if ch, ok := dc.ddlJobDoneChMap.Delete(jobID); ok { + // broadcast done event as we might merge multiple jobs into one when fast + // create table is enabled. + close(ch) + } +} + +// EnableTiFlashPoll enables TiFlash poll loop aka PollTiFlashReplicaStatus. +func EnableTiFlashPoll(d any) { + if dd, ok := d.(*ddl); ok { + dd.enableTiFlashPoll.Store(true) + } +} + +// DisableTiFlashPoll disables TiFlash poll loop aka PollTiFlashReplicaStatus. +func DisableTiFlashPoll(d any) { + if dd, ok := d.(*ddl); ok { + dd.enableTiFlashPoll.Store(false) + } +} + +// IsTiFlashPollEnabled reveals enableTiFlashPoll +func (d *ddl) IsTiFlashPollEnabled() bool { + return d.enableTiFlashPoll.Load() +} + +// RegisterStatsHandle registers statistics handle and its corresponding even channel for ddl. +// TODO this is called after ddl started, will cause panic if related DDL are executed +// in between. +func (d *ddl) RegisterStatsHandle(h *handle.Handle) { + d.ddlCtx.statsHandle = h + d.executor.statsHandle = h + d.ddlEventCh = h.DDLEventCh() +} + +// asyncNotifyEvent will notify the ddl event to outside world, say statistic handle. When the channel is full, we may +// give up notify and log it. +func asyncNotifyEvent(d *ddlCtx, e *statsutil.DDLEvent) { + if d.ddlEventCh != nil { + if d.lease == 0 { + // If lease is 0, it's always used in test. + select { + case d.ddlEventCh <- e: + default: + } + return + } + for i := 0; i < 10; i++ { + select { + case d.ddlEventCh <- e: + return + default: + time.Sleep(time.Microsecond * 10) + } + } + logutil.DDLLogger().Warn("fail to notify DDL event", zap.Stringer("event", e)) + } +} + +// NewDDL creates a new DDL. +// TODO remove it, to simplify this PR we use this way. +func NewDDL(ctx context.Context, options ...Option) (DDL, Executor) { + return newDDL(ctx, options...) 
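asyncNotifyEvent above prefers dropping a notification over blocking DDL: it tries a non-blocking send a handful of times with a tiny sleep in between and then gives up with a warning. The core pattern, separated from the DDL types (the generic helper name is illustrative):

package ddlsketch

import "time"

// trySend attempts a non-blocking send up to attempts times, sleeping briefly
// between tries, and reports whether the value was delivered.
func trySend[T any](ch chan<- T, v T, attempts int, backoff time.Duration) bool {
	for i := 0; i < attempts; i++ {
		select {
		case ch <- v:
			return true
		default:
			time.Sleep(backoff)
		}
	}
	return false
}

A caller would log and move on when trySend returns false, which is exactly the trade-off the DDL event channel makes.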
+} + +func newDDL(ctx context.Context, options ...Option) (*ddl, *executor) { + opt := &Options{} + for _, o := range options { + o(opt) + } + + id := uuid.New().String() + var manager owner.Manager + var schemaSyncer syncer.SchemaSyncer + var stateSyncer syncer.StateSyncer + var deadLockCkr util.DeadTableLockChecker + if etcdCli := opt.EtcdCli; etcdCli == nil { + // The etcdCli is nil if the store is localstore which is only used for testing. + // So we use mockOwnerManager and MockSchemaSyncer. + manager = owner.NewMockManager(ctx, id, opt.Store, DDLOwnerKey) + schemaSyncer = NewMockSchemaSyncer() + stateSyncer = NewMockStateSyncer() + } else { + manager = owner.NewOwnerManager(ctx, etcdCli, ddlPrompt, id, DDLOwnerKey) + schemaSyncer = syncer.NewSchemaSyncer(etcdCli, id) + stateSyncer = syncer.NewStateSyncer(etcdCli, util.ServerGlobalState) + deadLockCkr = util.NewDeadTableLockChecker(etcdCli) + } + + // TODO: make store and infoCache explicit arguments + // these two should be ensured to exist + if opt.Store == nil { + panic("store should not be nil") + } + if opt.InfoCache == nil { + panic("infoCache should not be nil") + } + + ddlCtx := &ddlCtx{ + uuid: id, + store: opt.Store, + lease: opt.Lease, + ddlJobDoneChMap: generic.NewSyncMap[int64, chan struct{}](10), + ownerManager: manager, + schemaSyncer: schemaSyncer, + stateSyncer: stateSyncer, + binlogCli: binloginfo.GetPumpsClient(), + infoCache: opt.InfoCache, + tableLockCkr: deadLockCkr, + etcdCli: opt.EtcdCli, + autoidCli: opt.AutoIDClient, + schemaLoader: opt.SchemaLoader, + waitSchemaSyncedController: newWaitSchemaSyncedController(), + } + ddlCtx.reorgCtx.reorgCtxMap = make(map[int64]*reorgCtx) + ddlCtx.jobCtx.jobCtxMap = make(map[int64]*JobContext) + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnDDL) + ddlCtx.ctx, ddlCtx.cancel = context.WithCancel(ctx) + ddlCtx.schemaVersionManager = newSchemaVersionManager() + + d := &ddl{ + ddlCtx: ddlCtx, + limitJobCh: make(chan *JobWrapper, batchAddingJobs), + enableTiFlashPoll: atomicutil.NewBool(true), + ddlJobNotifyCh: make(chan struct{}, 100), + } + + taskexecutor.RegisterTaskType(proto.Backfill, + func(ctx context.Context, id string, task *proto.Task, taskTable taskexecutor.TaskTable) taskexecutor.TaskExecutor { + return newBackfillDistExecutor(ctx, id, task, taskTable, d) + }, + ) + + scheduler.RegisterSchedulerFactory(proto.Backfill, + func(ctx context.Context, task *proto.Task, param scheduler.Param) scheduler.Scheduler { + return newLitBackfillScheduler(ctx, d, task, param) + }) + scheduler.RegisterSchedulerCleanUpFactory(proto.Backfill, newBackfillCleanUpS3) + // Register functions for enable/disable ddl when changing system variable `tidb_enable_ddl`. + variable.EnableDDL = d.EnableDDL + variable.DisableDDL = d.DisableDDL + variable.SwitchMDL = d.SwitchMDL + + e := &executor{ + ctx: d.ctx, + uuid: d.uuid, + store: d.store, + etcdCli: d.etcdCli, + autoidCli: d.autoidCli, + infoCache: d.infoCache, + limitJobCh: d.limitJobCh, + schemaLoader: d.schemaLoader, + lease: d.lease, + ownerManager: d.ownerManager, + ddlJobDoneChMap: &d.ddlJobDoneChMap, + ddlJobNotifyCh: d.ddlJobNotifyCh, + globalIDLock: &d.globalIDLock, + } + d.executor = e + + return d, e +} + +// Stop implements DDL.Stop interface. 
+func (d *ddl) Stop() error { + d.m.Lock() + defer d.m.Unlock() + + d.close() + logutil.DDLLogger().Info("stop DDL", zap.String("ID", d.uuid)) + return nil +} + +func (d *ddl) newDeleteRangeManager(mock bool) delRangeManager { + var delRangeMgr delRangeManager + if !mock { + delRangeMgr = newDelRangeManager(d.store, d.sessPool) + logutil.DDLLogger().Info("start delRangeManager OK", zap.Bool("is a emulator", !d.store.SupportDeleteRange())) + } else { + delRangeMgr = newMockDelRangeManager() + } + + delRangeMgr.start() + return delRangeMgr +} + +// Start implements DDL.Start interface. +func (d *ddl) Start(ctxPool *pools.ResourcePool) error { + logutil.DDLLogger().Info("start DDL", zap.String("ID", d.uuid), zap.Bool("runWorker", config.GetGlobalConfig().Instance.TiDBEnableDDL.Load())) + + d.sessPool = sess.NewSessionPool(ctxPool) + d.executor.sessPool = d.sessPool + d.sysTblMgr = systable.NewManager(d.sessPool) + d.minJobIDRefresher = systable.NewMinJobIDRefresher(d.sysTblMgr) + d.wg.Run(func() { + d.limitDDLJobs() + }) + d.wg.Run(func() { + d.minJobIDRefresher.Start(d.ctx) + }) + + d.delRangeMgr = d.newDeleteRangeManager(ctxPool == nil) + + if err := d.stateSyncer.Init(d.ctx); err != nil { + logutil.DDLLogger().Warn("start DDL init state syncer failed", zap.Error(err)) + return errors.Trace(err) + } + d.ownerManager.SetListener(&ownerListener{ + ddl: d, + }) + + if config.TableLockEnabled() { + d.wg.Add(1) + go d.startCleanDeadTableLock() + } + + // If tidb_enable_ddl is true, we need campaign owner and do DDL jobs. Besides, we also can do backfill jobs. + // Otherwise, we needn't do that. + if config.GetGlobalConfig().Instance.TiDBEnableDDL.Load() { + if err := d.EnableDDL(); err != nil { + return err + } + } + + variable.RegisterStatistics(d) + + metrics.DDLCounter.WithLabelValues(metrics.CreateDDLInstance).Inc() + + // Start some background routine to manage TiFlash replica. + d.wg.Run(d.PollTiFlashRoutine) + + ingestDataDir, err := ingest.GenIngestTempDataDir() + if err != nil { + logutil.DDLIngestLogger().Warn(ingest.LitWarnEnvInitFail, + zap.Error(err)) + } else { + ok := ingest.InitGlobalLightningEnv(ingestDataDir) + if ok { + d.wg.Run(func() { + d.CleanUpTempDirLoop(d.ctx, ingestDataDir) + }) + } + } + + return nil +} + +func (d *ddl) CleanUpTempDirLoop(ctx context.Context, path string) { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + for { + select { + case <-ticker.C: + se, err := d.sessPool.Get() + if err != nil { + logutil.DDLLogger().Warn("get session from pool failed", zap.Error(err)) + return + } + ingest.CleanUpTempDir(ctx, se, path) + d.sessPool.Put(se) + case <-d.ctx.Done(): + return + } + } +} + +// EnableDDL enable this node to execute ddl. +// Since ownerManager.CampaignOwner will start a new goroutine to run ownerManager.campaignLoop, +// we should make sure that before invoking EnableDDL(), ddl is DISABLE. +func (d *ddl) EnableDDL() error { + err := d.ownerManager.CampaignOwner() + return errors.Trace(err) +} + +// DisableDDL disable this node to execute ddl. +// We should make sure that before invoking DisableDDL(), ddl is ENABLE. +func (d *ddl) DisableDDL() error { + if d.ownerManager.IsOwner() { + // If there is only one node, we should NOT disable ddl. 
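CleanUpTempDirLoop above, like the dead table lock cleaner further down, uses the same background-loop skeleton: a ticker drives the periodic work and a context (or the ddl's done channel) ends the goroutine on shutdown. The skeleton on its own, with an illustrative name:

package ddlsketch

import (
	"context"
	"time"
)

// runPeriodically calls work on every tick until ctx is cancelled.
func runPeriodically(ctx context.Context, interval time.Duration, work func()) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			work()
		case <-ctx.Done():
			return
		}
	}
}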
+ serverInfo, err := infosync.GetAllServerInfo(d.ctx) + if err != nil { + logutil.DDLLogger().Error("error when GetAllServerInfo", zap.Error(err)) + return err + } + if len(serverInfo) <= 1 { + return dbterror.ErrDDLSetting.GenWithStackByArgs("disabling", "can not disable ddl owner when it is the only one tidb instance") + } + // FIXME: if possible, when this node is the only node with DDL, ths setting of DisableDDL should fail. + } + + // disable campaign by interrupting campaignLoop + d.ownerManager.CampaignCancel() + return nil +} + +func (d *ddl) close() { + if d.ctx.Err() != nil { + return + } + + startTime := time.Now() + d.cancel() + d.wg.Wait() + d.ownerManager.Cancel() + d.schemaSyncer.Close() + + // d.delRangeMgr using sessions from d.sessPool. + // Put it before d.sessPool.close to reduce the time spent by d.sessPool.close. + if d.delRangeMgr != nil { + d.delRangeMgr.clear() + } + if d.sessPool != nil { + d.sessPool.Close() + } + variable.UnregisterStatistics(d) + + logutil.DDLLogger().Info("DDL closed", zap.String("ID", d.uuid), zap.Duration("take time", time.Since(startTime))) +} + +// GetLease implements DDL.GetLease interface. +func (d *ddl) GetLease() time.Duration { + lease := d.lease + return lease +} + +// SchemaSyncer implements DDL.SchemaSyncer interface. +func (d *ddl) SchemaSyncer() syncer.SchemaSyncer { + return d.schemaSyncer +} + +// StateSyncer implements DDL.StateSyncer interface. +func (d *ddl) StateSyncer() syncer.StateSyncer { + return d.stateSyncer +} + +// OwnerManager implements DDL.OwnerManager interface. +func (d *ddl) OwnerManager() owner.Manager { + return d.ownerManager +} + +// GetID implements DDL.GetID interface. +func (d *ddl) GetID() string { + return d.uuid +} + +// SetBinlogClient implements DDL.SetBinlogClient interface. +func (d *ddl) SetBinlogClient(binlogCli *pumpcli.PumpsClient) { + d.binlogCli = binlogCli +} + +func (d *ddl) GetMinJobIDRefresher() *systable.MinJobIDRefresher { + return d.minJobIDRefresher +} + +func (d *ddl) startCleanDeadTableLock() { + defer func() { + d.wg.Done() + }() + + defer tidbutil.Recover(metrics.LabelDDL, "startCleanDeadTableLock", nil, false) + + ticker := time.NewTicker(time.Second * 10) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if !d.ownerManager.IsOwner() { + continue + } + deadLockTables, err := d.tableLockCkr.GetDeadLockedTables(d.ctx, d.infoCache.GetLatest()) + if err != nil { + logutil.DDLLogger().Info("get dead table lock failed.", zap.Error(err)) + continue + } + for se, tables := range deadLockTables { + err := d.cleanDeadTableLock(tables, se) + if err != nil { + logutil.DDLLogger().Info("clean dead table lock failed.", zap.Error(err)) + } + } + case <-d.ctx.Done(): + return + } + } +} + +// cleanDeadTableLock uses to clean dead table locks. +func (d *ddl) cleanDeadTableLock(unlockTables []model.TableLockTpInfo, se model.SessionInfo) error { + if len(unlockTables) == 0 { + return nil + } + arg := &LockTablesArg{ + UnlockTables: unlockTables, + SessionInfo: se, + } + job := &model.Job{ + SchemaID: unlockTables[0].SchemaID, + TableID: unlockTables[0].TableID, + Type: model.ActionUnlockTable, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{arg}, + } + + ctx, err := d.sessPool.Get() + if err != nil { + return err + } + defer d.sessPool.Put(ctx) + err = d.executor.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +// SwitchMDL enables MDL or disable MDL. 
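close() above cancels the DDL context before waiting on the worker group, and only then releases the session pool; CleanUpTempDirLoop and startCleanDeadTableLock are the kind of ticker-driven goroutines that ordering protects. A hedged, stdlib-only sketch of the same sequencing (worker and sessionPool are invented stand-ins):

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// worker stands in for loops such as CleanUpTempDirLoop / startCleanDeadTableLock:
// wake on a ticker, exit promptly when the shared context is cancelled.
func worker(ctx context.Context, wg *sync.WaitGroup) {
	defer wg.Done()
	ticker := time.NewTicker(10 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			// periodic work would go here
		case <-ctx.Done():
			return
		}
	}
}

type sessionPool struct{ closed bool }

func (p *sessionPool) Close() { p.closed = true }

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	pool := &sessionPool{}
	var wg sync.WaitGroup
	wg.Add(1)
	go worker(ctx, &wg)

	start := time.Now()
	cancel()     // 1. interrupt the loops (d.cancel())
	wg.Wait()    // 2. wait until nothing can touch the pool anymore (d.wg.Wait())
	pool.Close() // 3. only now close shared resources (d.sessPool.Close())
	fmt.Printf("pool closed=%v after %s\n", pool.closed, time.Since(start).Round(time.Millisecond))
}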
+func (d *ddl) SwitchMDL(enable bool) error { + isEnableBefore := variable.EnableMDL.Load() + if isEnableBefore == enable { + return nil + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + + // Check if there is any DDL running. + // This check can not cover every corner cases, so users need to guarantee that there is no DDL running by themselves. + sessCtx, err := d.sessPool.Get() + if err != nil { + return err + } + defer d.sessPool.Put(sessCtx) + se := sess.NewSession(sessCtx) + rows, err := se.Execute(ctx, "select 1 from mysql.tidb_ddl_job", "check job") + if err != nil { + return err + } + if len(rows) != 0 { + return errors.New("please wait for all jobs done") + } + + variable.EnableMDL.Store(enable) + err = kv.RunInNewTxn(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), d.store, true, func(_ context.Context, txn kv.Transaction) error { + m := meta.NewMeta(txn) + oldEnable, _, err := m.GetMetadataLock() + if err != nil { + return err + } + if oldEnable != enable { + err = m.SetMetadataLock(enable) + } + return err + }) + if err != nil { + logutil.DDLLogger().Warn("switch metadata lock feature", zap.Bool("enable", enable), zap.Error(err)) + return err + } + logutil.DDLLogger().Info("switch metadata lock feature", zap.Bool("enable", enable)) + return nil +} + +// RecoverInfo contains information needed by DDL.RecoverTable. +type RecoverInfo struct { + SchemaID int64 + TableInfo *model.TableInfo + DropJobID int64 + SnapshotTS uint64 + AutoIDs meta.AutoIDGroup + OldSchemaName string + OldTableName string +} + +// RecoverSchemaInfo contains information needed by DDL.RecoverSchema. +type RecoverSchemaInfo struct { + *model.DBInfo + RecoverTabsInfo []*RecoverInfo + // LoadTablesOnExecute is the new logic to avoid a large RecoverTabsInfo can't be + // persisted. If it's true, DDL owner will recover RecoverTabsInfo instead of the + // job submit node. + LoadTablesOnExecute bool + DropJobID int64 + SnapshotTS uint64 + OldSchemaName model.CIStr +} + +// delayForAsyncCommit sleeps `SafeWindow + AllowedClockDrift` before a DDL job finishes. +// It should be called before any DDL that could break data consistency. +// This provides a safe window for async commit and 1PC to commit with an old schema. +func delayForAsyncCommit() { + if variable.EnableMDL.Load() { + // If metadata lock is enabled. The transaction of DDL must begin after prewrite of the async commit transaction, + // then the commit ts of DDL must be greater than the async commit transaction. In this case, the corresponding schema of the async commit transaction + // is correct. But if metadata lock is disabled, we can't ensure that the corresponding schema of the async commit transaction isn't change. + return + } + cfg := config.GetGlobalConfig().TiKVClient.AsyncCommit + duration := cfg.SafeWindow + cfg.AllowedClockDrift + logutil.DDLLogger().Info("sleep before DDL finishes to make async commit and 1PC safe", + zap.Duration("duration", duration)) + time.Sleep(duration) +} + +var ( + // RunInGoTest is used to identify whether ddl in running in the test. + RunInGoTest bool +) + +// GetDropOrTruncateTableInfoFromJobsByStore implements GetDropOrTruncateTableInfoFromJobs +func GetDropOrTruncateTableInfoFromJobsByStore(jobs []*model.Job, gcSafePoint uint64, getTable func(uint64, int64, int64) (*model.TableInfo, error), fn func(*model.Job, *model.TableInfo) (bool, error)) (bool, error) { + for _, job := range jobs { + // Check GC safe point for getting snapshot infoSchema. 
+ err := gcutil.ValidateSnapshotWithGCSafePoint(job.StartTS, gcSafePoint) + if err != nil { + return false, err + } + if job.Type != model.ActionDropTable && job.Type != model.ActionTruncateTable { + continue + } + + tbl, err := getTable(job.StartTS, job.SchemaID, job.TableID) + if err != nil { + if meta.ErrDBNotExists.Equal(err) { + // The dropped/truncated DDL maybe execute failed that caused by the parallel DDL execution, + // then can't find the table from the snapshot info-schema. Should just ignore error here, + // see more in TestParallelDropSchemaAndDropTable. + continue + } + return false, err + } + if tbl == nil { + // The dropped/truncated DDL maybe execute failed that caused by the parallel DDL execution, + // then can't find the table from the snapshot info-schema. Should just ignore error here, + // see more in TestParallelDropSchemaAndDropTable. + continue + } + finish, err := fn(job, tbl) + if err != nil || finish { + return finish, err + } + } + return false, nil +} + +// Info is for DDL information. +type Info struct { + SchemaVer int64 + ReorgHandle kv.Key // It's only used for DDL information. + Jobs []*model.Job // It's the currently running jobs. +} + +// GetDDLInfoWithNewTxn returns DDL information using a new txn. +func GetDDLInfoWithNewTxn(s sessionctx.Context) (*Info, error) { + se := sess.NewSession(s) + err := se.Begin(context.Background()) + if err != nil { + return nil, err + } + info, err := GetDDLInfo(s) + se.Rollback() + return info, err +} + +// GetDDLInfo returns DDL information and only uses for testing. +func GetDDLInfo(s sessionctx.Context) (*Info, error) { + var err error + info := &Info{} + se := sess.NewSession(s) + txn, err := se.Txn() + if err != nil { + return nil, errors.Trace(err) + } + t := meta.NewMeta(txn) + info.Jobs = make([]*model.Job, 0, 2) + var generalJob, reorgJob *model.Job + generalJob, reorgJob, err = get2JobsFromTable(se) + if err != nil { + return nil, errors.Trace(err) + } + + if generalJob != nil { + info.Jobs = append(info.Jobs, generalJob) + } + + if reorgJob != nil { + info.Jobs = append(info.Jobs, reorgJob) + } + + info.SchemaVer, err = t.GetSchemaVersionWithNonEmptyDiff() + if err != nil { + return nil, errors.Trace(err) + } + if reorgJob == nil { + return info, nil + } + + _, info.ReorgHandle, _, _, err = newReorgHandler(se).GetDDLReorgHandle(reorgJob) + if err != nil { + if meta.ErrDDLReorgElementNotExist.Equal(err) { + return info, nil + } + return nil, errors.Trace(err) + } + + return info, nil +} + +func get2JobsFromTable(sess *sess.Session) (*model.Job, *model.Job, error) { + var generalJob, reorgJob *model.Job + jobs, err := getJobsBySQL(sess, JobTable, "not reorg order by job_id limit 1") + if err != nil { + return nil, nil, errors.Trace(err) + } + + if len(jobs) != 0 { + generalJob = jobs[0] + } + jobs, err = getJobsBySQL(sess, JobTable, "reorg order by job_id limit 1") + if err != nil { + return nil, nil, errors.Trace(err) + } + if len(jobs) != 0 { + reorgJob = jobs[0] + } + return generalJob, reorgJob, nil +} + +// cancelRunningJob cancel a DDL job that is in the concurrent state. +func cancelRunningJob(_ *sess.Session, job *model.Job, + byWho model.AdminCommandOperator) (err error) { + // These states can't be cancelled. + if job.IsDone() || job.IsSynced() { + return dbterror.ErrCancelFinishedDDLJob.GenWithStackByArgs(job.ID) + } + + // If the state is rolling back, it means the work is cleaning the data after cancelling the job. 
+ if job.IsCancelled() || job.IsRollingback() || job.IsRollbackDone() { + return nil + } + + if !job.IsRollbackable() { + return dbterror.ErrCannotCancelDDLJob.GenWithStackByArgs(job.ID) + } + job.State = model.JobStateCancelling + job.AdminOperator = byWho + return nil +} + +// pauseRunningJob check and pause the running Job +func pauseRunningJob(_ *sess.Session, job *model.Job, + byWho model.AdminCommandOperator) (err error) { + if job.IsPausing() || job.IsPaused() { + return dbterror.ErrPausedDDLJob.GenWithStackByArgs(job.ID) + } + if !job.IsPausable() { + errMsg := fmt.Sprintf("state [%s] or schema state [%s]", job.State.String(), job.SchemaState.String()) + err = dbterror.ErrCannotPauseDDLJob.GenWithStackByArgs(job.ID, errMsg) + if err != nil { + return err + } + } + + job.State = model.JobStatePausing + job.AdminOperator = byWho + return nil +} + +// resumePausedJob check and resume the Paused Job +func resumePausedJob(_ *sess.Session, job *model.Job, + byWho model.AdminCommandOperator) (err error) { + if !job.IsResumable() { + errMsg := fmt.Sprintf("job has not been paused, job state:%s, schema state:%s", + job.State, job.SchemaState) + return dbterror.ErrCannotResumeDDLJob.GenWithStackByArgs(job.ID, errMsg) + } + // The Paused job should only be resumed by who paused it + if job.AdminOperator != byWho { + errMsg := fmt.Sprintf("job has been paused by [%s], should not resumed by [%s]", + job.AdminOperator.String(), byWho.String()) + return dbterror.ErrCannotResumeDDLJob.GenWithStackByArgs(job.ID, errMsg) + } + + job.State = model.JobStateQueueing + + return nil +} + +// processJobs command on the Job according to the process +func processJobs( + process func(*sess.Session, *model.Job, model.AdminCommandOperator) (err error), + sessCtx sessionctx.Context, + ids []int64, + byWho model.AdminCommandOperator, +) (jobErrs []error, err error) { + failpoint.Inject("mockFailedCommandOnConcurencyDDL", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(nil, errors.New("mock failed admin command on ddl jobs")) + } + }) + + if len(ids) == 0 { + return nil, nil + } + + ns := sess.NewSession(sessCtx) + // We should process (and try) all the jobs in one Transaction. 
+ for tryN := uint(0); tryN < 3; tryN++ { + jobErrs = make([]error, len(ids)) + // Need to figure out which one could not be paused + jobMap := make(map[int64]int, len(ids)) + idsStr := make([]string, 0, len(ids)) + for idx, id := range ids { + jobMap[id] = idx + idsStr = append(idsStr, strconv.FormatInt(id, 10)) + } + + err = ns.Begin(context.Background()) + if err != nil { + return nil, err + } + jobs, err := getJobsBySQL(ns, JobTable, fmt.Sprintf("job_id in (%s) order by job_id", strings.Join(idsStr, ", "))) + if err != nil { + ns.Rollback() + return nil, err + } + + for _, job := range jobs { + i, ok := jobMap[job.ID] + if !ok { + logutil.DDLLogger().Debug("Job ID from meta is not consistent with requested job id,", + zap.Int64("fetched job ID", job.ID)) + jobErrs[i] = dbterror.ErrInvalidDDLJob.GenWithStackByArgs(job.ID) + continue + } + delete(jobMap, job.ID) + + err = process(ns, job, byWho) + if err != nil { + jobErrs[i] = err + continue + } + + err = updateDDLJob2Table(ns, job, false) + if err != nil { + jobErrs[i] = err + continue + } + } + + failpoint.Inject("mockCommitFailedOnDDLCommand", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(jobErrs, errors.New("mock commit failed on admin command on ddl jobs")) + } + }) + + // There may be some conflict during the update, try it again + if err = ns.Commit(context.Background()); err != nil { + continue + } + + for id, idx := range jobMap { + jobErrs[idx] = dbterror.ErrDDLJobNotFound.GenWithStackByArgs(id) + } + + return jobErrs, nil + } + + return jobErrs, err +} + +// CancelJobs cancels the DDL jobs according to user command. +func CancelJobs(se sessionctx.Context, ids []int64) (errs []error, err error) { + return processJobs(cancelRunningJob, se, ids, model.AdminCommandByEndUser) +} + +// PauseJobs pause all the DDL jobs according to user command. +func PauseJobs(se sessionctx.Context, ids []int64) ([]error, error) { + return processJobs(pauseRunningJob, se, ids, model.AdminCommandByEndUser) +} + +// ResumeJobs resume all the DDL jobs according to user command. +func ResumeJobs(se sessionctx.Context, ids []int64) ([]error, error) { + return processJobs(resumePausedJob, se, ids, model.AdminCommandByEndUser) +} + +// CancelJobsBySystem cancels Jobs because of internal reasons. +func CancelJobsBySystem(se sessionctx.Context, ids []int64) (errs []error, err error) { + return processJobs(cancelRunningJob, se, ids, model.AdminCommandBySystem) +} + +// PauseJobsBySystem pauses Jobs because of internal reasons. +func PauseJobsBySystem(se sessionctx.Context, ids []int64) (errs []error, err error) { + return processJobs(pauseRunningJob, se, ids, model.AdminCommandBySystem) +} + +// ResumeJobsBySystem resumes Jobs that are paused by TiDB itself. +func ResumeJobsBySystem(se sessionctx.Context, ids []int64) (errs []error, err error) { + return processJobs(resumePausedJob, se, ids, model.AdminCommandBySystem) +} + +// pprocessAllJobs processes all the jobs in the job table, 100 jobs at a time in case of high memory usage. 
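processJobs above applies an admin command to a batch of job IDs inside one transaction, records a per-job error slice instead of failing the whole batch, and retries up to three times when the commit conflicts. A rough stdlib sketch of that shape; the conflict error and callbacks are invented for illustration.

package main

import (
	"errors"
	"fmt"
)

var errCommitConflict = errors.New("commit conflict") // stand-in for a write conflict

// applyBatch mimics the processJobs loop: per-ID errors are collected, and only
// a failed commit triggers another attempt at the whole batch.
func applyBatch(ids []int, process func(int) error, commit func() error) []error {
	const maxRetry = 3
	var jobErrs []error
	for try := 0; try < maxRetry; try++ {
		jobErrs = make([]error, len(ids))
		for i, id := range ids {
			jobErrs[i] = process(id) // nil on success, kept per job otherwise
		}
		if err := commit(); err != nil {
			continue // conflict: retry the whole batch in a fresh "transaction"
		}
		return jobErrs
	}
	return jobErrs
}

func main() {
	attempts := 0
	commit := func() error {
		attempts++
		if attempts < 2 {
			return errCommitConflict
		}
		return nil
	}
	process := func(id int) error {
		if id == 7 {
			return fmt.Errorf("job %d is not pausable", id)
		}
		return nil
	}
	fmt.Println(applyBatch([]int{5, 7, 9}, process, commit), "attempts:", attempts)
}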
+func processAllJobs(process func(*sess.Session, *model.Job, model.AdminCommandOperator) (err error), + se sessionctx.Context, byWho model.AdminCommandOperator) (map[int64]error, error) { + var err error + var jobErrs = make(map[int64]error) + + ns := sess.NewSession(se) + err = ns.Begin(context.Background()) + if err != nil { + return nil, err + } + + var jobID int64 + var jobIDMax int64 + var limit = 100 + for { + var jobs []*model.Job + jobs, err = getJobsBySQL(ns, JobTable, + fmt.Sprintf("job_id >= %s order by job_id asc limit %s", + strconv.FormatInt(jobID, 10), + strconv.FormatInt(int64(limit), 10))) + if err != nil { + ns.Rollback() + return nil, err + } + + for _, job := range jobs { + err = process(ns, job, byWho) + if err != nil { + jobErrs[job.ID] = err + continue + } + + err = updateDDLJob2Table(ns, job, false) + if err != nil { + jobErrs[job.ID] = err + continue + } + } + + // Just in case the job ID is not sequential + if len(jobs) > 0 && jobs[len(jobs)-1].ID > jobIDMax { + jobIDMax = jobs[len(jobs)-1].ID + } + + // If rows returned is smaller than $limit, then there is no more records + if len(jobs) < limit { + break + } + + jobID = jobIDMax + 1 + } + + err = ns.Commit(context.Background()) + if err != nil { + return nil, err + } + return jobErrs, nil +} + +// PauseAllJobsBySystem pauses all running Jobs because of internal reasons. +func PauseAllJobsBySystem(se sessionctx.Context) (map[int64]error, error) { + return processAllJobs(pauseRunningJob, se, model.AdminCommandBySystem) +} + +// ResumeAllJobsBySystem resumes all paused Jobs because of internal reasons. +func ResumeAllJobsBySystem(se sessionctx.Context) (map[int64]error, error) { + return processAllJobs(resumePausedJob, se, model.AdminCommandBySystem) +} + +// GetAllDDLJobs get all DDL jobs and sorts jobs by job.ID. +func GetAllDDLJobs(se sessionctx.Context) ([]*model.Job, error) { + return getJobsBySQL(sess.NewSession(se), JobTable, "1 order by job_id") +} + +// IterAllDDLJobs will iterates running DDL jobs first, return directly if `finishFn` return true or error, +// then iterates history DDL jobs until the `finishFn` return true or error. +func IterAllDDLJobs(ctx sessionctx.Context, txn kv.Transaction, finishFn func([]*model.Job) (bool, error)) error { + jobs, err := GetAllDDLJobs(ctx) + if err != nil { + return err + } + + finish, err := finishFn(jobs) + if err != nil || finish { + return err + } + return IterHistoryDDLJobs(txn, finishFn) +} diff --git a/pkg/ddl/ddl_tiflash_api.go b/pkg/ddl/ddl_tiflash_api.go index 9f4c2512f0019..d232d78758f91 100644 --- a/pkg/ddl/ddl_tiflash_api.go +++ b/pkg/ddl/ddl_tiflash_api.go @@ -352,9 +352,9 @@ func updateTiFlashStores(pollTiFlashContext *TiFlashManagementContext) error { // PollAvailableTableProgress will poll and check availability of available tables. 
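processAllJobs above walks the job table in ID order, 100 rows at a time, advancing the cursor to max(job_id)+1 and stopping once a page comes back short. The same keyset-pagination loop over an in-memory slice, purely for illustration:

package main

import "fmt"

// fetchPage plays the role of `job_id >= ? order by job_id asc limit ?`.
func fetchPage(all []int64, from int64, limit int) []int64 {
	page := make([]int64, 0, limit)
	for _, id := range all {
		if id >= from {
			page = append(page, id)
			if len(page) == limit {
				break
			}
		}
	}
	return page
}

func main() {
	jobIDs := []int64{3, 4, 9, 10, 11, 25, 31, 32, 40} // sorted, like job_id
	const limit = 4

	var cursor, maxSeen int64
	pages := 0
	for {
		page := fetchPage(jobIDs, cursor, limit)
		pages++
		for _, id := range page {
			if id > maxSeen {
				maxSeen = id // tolerate ID gaps, exactly as processAllJobs does
			}
		}
		if len(page) < limit { // short page: nothing left
			break
		}
		cursor = maxSeen + 1
	}
	fmt.Printf("visited up to job %d in %d pages\n", maxSeen, pages)
}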
func PollAvailableTableProgress(schemas infoschema.InfoSchema, _ sessionctx.Context, pollTiFlashContext *TiFlashManagementContext) { pollMaxCount := RefreshProgressMaxTableCount - failpoint.Inject("PollAvailableTableProgressMaxCount", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("PollAvailableTableProgressMaxCount")); _err_ == nil { pollMaxCount = uint64(val.(int)) - }) + } for element := pollTiFlashContext.UpdatingProgressTables.Front(); element != nil && pollMaxCount > 0; pollMaxCount-- { availableTableID := element.Value.(AvailableTableID) var table table.Table @@ -431,13 +431,13 @@ func (d *ddl) refreshTiFlashTicker(ctx sessionctx.Context, pollTiFlashContext *T } } - failpoint.Inject("OneTiFlashStoreDown", func() { + if _, _err_ := failpoint.Eval(_curpkg_("OneTiFlashStoreDown")); _err_ == nil { for storeID, store := range pollTiFlashContext.TiFlashStores { store.Store.StateName = "Down" pollTiFlashContext.TiFlashStores[storeID] = store break } - }) + } pollTiFlashContext.PollCounter++ // Start to process every table. @@ -458,7 +458,7 @@ func (d *ddl) refreshTiFlashTicker(ctx sessionctx.Context, pollTiFlashContext *T } } - failpoint.Inject("waitForAddPartition", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("waitForAddPartition")); _err_ == nil { for _, phyTable := range tableList { is := d.infoCache.GetLatest() _, ok := is.TableByID(phyTable.ID) @@ -471,7 +471,7 @@ func (d *ddl) refreshTiFlashTicker(ctx sessionctx.Context, pollTiFlashContext *T } } } - }) + } needPushPending := false if pollTiFlashContext.UpdatingProgressTables.Len() == 0 { @@ -482,9 +482,9 @@ func (d *ddl) refreshTiFlashTicker(ctx sessionctx.Context, pollTiFlashContext *T // For every region in each table, if it has one replica, we reckon it ready. // These request can be batched as an optimization. available := tb.Available - failpoint.Inject("PollTiFlashReplicaStatusReplacePrevAvailableValue", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("PollTiFlashReplicaStatusReplacePrevAvailableValue")); _err_ == nil { available = val.(bool) - }) + } // We only check unavailable tables here, so doesn't include blocked add partition case. if !available && !tb.LogicalTableAvailable { enabled, inqueue, _ := pollTiFlashContext.Backoff.Tick(tb.ID) @@ -514,9 +514,9 @@ func (d *ddl) refreshTiFlashTicker(ctx sessionctx.Context, pollTiFlashContext *T } avail := progress == 1 - failpoint.Inject("PollTiFlashReplicaStatusReplaceCurAvailableValue", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("PollTiFlashReplicaStatusReplaceCurAvailableValue")); _err_ == nil { avail = val.(bool) - }) + } if !avail { logutil.DDLLogger().Info("Tiflash replica is not available", zap.Int64("tableID", tb.ID), zap.Float64("progress", progress)) @@ -525,9 +525,9 @@ func (d *ddl) refreshTiFlashTicker(ctx sessionctx.Context, pollTiFlashContext *T logutil.DDLLogger().Info("Tiflash replica is available", zap.Int64("tableID", tb.ID), zap.Float64("progress", progress)) pollTiFlashContext.Backoff.Remove(tb.ID) } - failpoint.Inject("skipUpdateTableReplicaInfoInLoop", func() { - failpoint.Continue() - }) + if _, _err_ := failpoint.Eval(_curpkg_("skipUpdateTableReplicaInfoInLoop")); _err_ == nil { + continue + } // Will call `onUpdateFlashReplicaStatus` to update `TiFlashReplica`. 
if err := d.executor.UpdateTableReplicaInfo(ctx, tb.ID, avail); err != nil { if infoschema.ErrTableNotExists.Equal(err) && tb.IsPartition { @@ -566,9 +566,9 @@ func (d *ddl) PollTiFlashRoutine() { logutil.DDLLogger().Error("failed to get sessionPool for refreshTiFlashTicker") return } - failpoint.Inject("BeforeRefreshTiFlashTickeLoop", func() { - failpoint.Continue() - }) + if _, _err_ := failpoint.Eval(_curpkg_("BeforeRefreshTiFlashTickeLoop")); _err_ == nil { + continue + } if !hasSetTiFlashGroup && !time.Now().Before(nextSetTiFlashGroupTime) { // We should set tiflash rule group a higher index than other placement groups to forbid override by them. diff --git a/pkg/ddl/ddl_tiflash_api.go__failpoint_stash__ b/pkg/ddl/ddl_tiflash_api.go__failpoint_stash__ new file mode 100644 index 0000000000000..9f4c2512f0019 --- /dev/null +++ b/pkg/ddl/ddl_tiflash_api.go__failpoint_stash__ @@ -0,0 +1,608 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright 2013 The ql Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSES/QL-LICENSE file. + +package ddl + +import ( + "bytes" + "container/list" + "context" + "encoding/json" + "fmt" + "net" + "strconv" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/ddl/logutil" + ddlutil "github.com/pingcap/tidb/pkg/ddl/util" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/engine" + "github.com/pingcap/tidb/pkg/util/intest" + pd "github.com/tikv/pd/client/http" + atomicutil "go.uber.org/atomic" + "go.uber.org/zap" +) + +// TiFlashReplicaStatus records status for each TiFlash replica. +type TiFlashReplicaStatus struct { + ID int64 + Count uint64 + LocationLabels []string + Available bool + LogicalTableAvailable bool + HighPriority bool + IsPartition bool +} + +// TiFlashTick is type for backoff threshold. +type TiFlashTick float64 + +// PollTiFlashBackoffElement records backoff for each TiFlash Table. +// `Counter` increases every `Tick`, if it reached `Threshold`, it will be reset to 0 while `Threshold` grows. +// `TotalCounter` records total `Tick`s this element has since created. +type PollTiFlashBackoffElement struct { + Counter int + Threshold TiFlashTick + TotalCounter int +} + +// NewPollTiFlashBackoffElement initialize backoff element for a TiFlash table. +func NewPollTiFlashBackoffElement() *PollTiFlashBackoffElement { + return &PollTiFlashBackoffElement{ + Counter: 0, + Threshold: PollTiFlashBackoffMinTick, + TotalCounter: 0, + } +} + +// PollTiFlashBackoffContext is a collection of all backoff states. 
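The hunks above are the mechanical failpoint-ctl rewrite: each `failpoint.Inject(name, fn)` callback becomes an inline `if _, _err_ := failpoint.Eval(...); _err_ == nil { ... }`, and the untouched source is preserved in the `__failpoint_stash__` copy so the transform can be reverted. A toy registry showing why the two spellings behave the same; this is not the real failpoint API, only the idea.

package main

import "fmt"

// enabled stands in for the registry that `failpoint-ctl enable` would populate.
var enabled = map[string]any{"OneTiFlashStoreDown": true}

// eval mirrors the Eval-style call sites generated in the rewritten file.
func eval(name string) (any, bool) {
	v, ok := enabled[name]
	return v, ok
}

// inject mirrors the callback-style call sites kept in the stashed original.
func inject(name string, fn func(val any)) {
	if v, ok := enabled[name]; ok {
		fn(v)
	}
}

func main() {
	state := "Up"

	// Original (stash) spelling: the closure runs only when the failpoint is enabled.
	inject("OneTiFlashStoreDown", func(any) { state = "Down" })
	fmt.Println("after inject:", state)

	// Generated spelling: same condition, inlined so no closure is carried around.
	state = "Up"
	if _, ok := eval("OneTiFlashStoreDown"); ok {
		state = "Down"
	}
	fmt.Println("after eval:  ", state)
}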
+type PollTiFlashBackoffContext struct { + MinThreshold TiFlashTick + MaxThreshold TiFlashTick + // Capacity limits tables a backoff pool can handle, in order to limit handling of big tables. + Capacity int + Rate TiFlashTick + elements map[int64]*PollTiFlashBackoffElement +} + +// NewPollTiFlashBackoffContext creates an instance of PollTiFlashBackoffContext. +func NewPollTiFlashBackoffContext(minThreshold, maxThreshold TiFlashTick, capacity int, rate TiFlashTick) (*PollTiFlashBackoffContext, error) { + if maxThreshold < minThreshold { + return nil, fmt.Errorf("`maxThreshold` should always be larger than `minThreshold`") + } + if minThreshold < 1 { + return nil, fmt.Errorf("`minThreshold` should not be less than 1") + } + if capacity < 0 { + return nil, fmt.Errorf("negative `capacity`") + } + if rate <= 1 { + return nil, fmt.Errorf("`rate` should always be larger than 1") + } + return &PollTiFlashBackoffContext{ + MinThreshold: minThreshold, + MaxThreshold: maxThreshold, + Capacity: capacity, + elements: make(map[int64]*PollTiFlashBackoffElement), + Rate: rate, + }, nil +} + +// TiFlashManagementContext is the context for TiFlash Replica Management +type TiFlashManagementContext struct { + TiFlashStores map[int64]pd.StoreInfo + PollCounter uint64 + Backoff *PollTiFlashBackoffContext + // tables waiting for updating progress after become available. + UpdatingProgressTables *list.List +} + +// AvailableTableID is the table id info of available table for waiting to update TiFlash replica progress. +type AvailableTableID struct { + ID int64 + IsPartition bool +} + +// Tick will first check increase Counter. +// It returns: +// 1. A bool indicates whether threshold is grown during this tick. +// 2. A bool indicates whether this ID exists. +// 3. A int indicates how many ticks ID has counted till now. +func (b *PollTiFlashBackoffContext) Tick(id int64) (grew bool, exist bool, cnt int) { + e, ok := b.Get(id) + if !ok { + return false, false, 0 + } + grew = e.MaybeGrow(b) + e.Counter++ + e.TotalCounter++ + return grew, true, e.TotalCounter +} + +// NeedGrow returns if we need to grow. +// It is exported for testing. +func (e *PollTiFlashBackoffElement) NeedGrow() bool { + return e.Counter >= int(e.Threshold) +} + +func (e *PollTiFlashBackoffElement) doGrow(b *PollTiFlashBackoffContext) { + if e.Threshold < b.MinThreshold { + e.Threshold = b.MinThreshold + } + if e.Threshold*b.Rate > b.MaxThreshold { + e.Threshold = b.MaxThreshold + } else { + e.Threshold *= b.Rate + } + e.Counter = 0 +} + +// MaybeGrow grows threshold and reset counter when needed. +func (e *PollTiFlashBackoffElement) MaybeGrow(b *PollTiFlashBackoffContext) bool { + if !e.NeedGrow() { + return false + } + e.doGrow(b) + return true +} + +// Remove will reset table from backoff. +func (b *PollTiFlashBackoffContext) Remove(id int64) bool { + _, ok := b.elements[id] + delete(b.elements, id) + return ok +} + +// Get returns pointer to inner PollTiFlashBackoffElement. +// Only exported for test. +func (b *PollTiFlashBackoffContext) Get(id int64) (*PollTiFlashBackoffElement, bool) { + res, ok := b.elements[id] + return res, ok +} + +// Put will record table into backoff pool, if there is enough room, or returns false. +func (b *PollTiFlashBackoffContext) Put(id int64) bool { + _, ok := b.elements[id] + if ok { + return true + } else if b.Len() < b.Capacity { + b.elements[id] = NewPollTiFlashBackoffElement() + return true + } + return false +} + +// Len gets size of PollTiFlashBackoffContext. 
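PollTiFlashBackoffElement above grows its threshold by Rate each time Counter reaches it, capped at MaxThreshold, so tables that stay unavailable are checked less and less often. A self-contained sketch of that tick/grow behaviour, using the default constants declared in this file (1 to 10 at rate 1.5); the type below is simplified, not the real struct.

package main

import "fmt"

const (
	minThreshold = 1.0
	maxThreshold = 10.0
	rate         = 1.5
)

type backoffElement struct {
	counter   int
	threshold float64
	total     int
}

// tick reports whether this tick is one where the caller should run the expensive
// availability check, i.e. the counter just wrapped and the threshold grew.
func (e *backoffElement) tick() bool {
	grew := false
	if e.counter >= int(e.threshold) {
		if e.threshold*rate > maxThreshold {
			e.threshold = maxThreshold
		} else {
			e.threshold *= rate
		}
		e.counter = 0
		grew = true
	}
	e.counter++
	e.total++
	return grew
}

func main() {
	e := &backoffElement{threshold: minThreshold}
	for i := 1; i <= 20; i++ {
		if e.tick() {
			fmt.Printf("tick %2d: check table, next threshold %.2f\n", i, e.threshold)
		}
	}
}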
+func (b *PollTiFlashBackoffContext) Len() int { + return len(b.elements) +} + +// NewTiFlashManagementContext creates an instance for TiFlashManagementContext. +func NewTiFlashManagementContext() (*TiFlashManagementContext, error) { + c, err := NewPollTiFlashBackoffContext(PollTiFlashBackoffMinTick, PollTiFlashBackoffMaxTick, PollTiFlashBackoffCapacity, PollTiFlashBackoffRate) + if err != nil { + return nil, err + } + return &TiFlashManagementContext{ + PollCounter: 0, + TiFlashStores: make(map[int64]pd.StoreInfo), + Backoff: c, + UpdatingProgressTables: list.New(), + }, nil +} + +var ( + // PollTiFlashInterval is the interval between every pollTiFlashReplicaStatus call. + PollTiFlashInterval = 2 * time.Second + // PullTiFlashPdTick indicates the number of intervals before we fully sync all TiFlash pd rules and tables. + PullTiFlashPdTick = atomicutil.NewUint64(30 * 5) + // UpdateTiFlashStoreTick indicates the number of intervals before we fully update TiFlash stores. + UpdateTiFlashStoreTick = atomicutil.NewUint64(5) + // PollTiFlashBackoffMaxTick is the max tick before we try to update TiFlash replica availability for one table. + PollTiFlashBackoffMaxTick TiFlashTick = 10 + // PollTiFlashBackoffMinTick is the min tick before we try to update TiFlash replica availability for one table. + PollTiFlashBackoffMinTick TiFlashTick = 1 + // PollTiFlashBackoffCapacity is the cache size of backoff struct. + PollTiFlashBackoffCapacity = 1000 + // PollTiFlashBackoffRate is growth rate of exponential backoff threshold. + PollTiFlashBackoffRate TiFlashTick = 1.5 + // RefreshProgressMaxTableCount is the max count of table to refresh progress after available each poll. + RefreshProgressMaxTableCount uint64 = 1000 +) + +func getTiflashHTTPAddr(host string, statusAddr string) (string, error) { + configURL := fmt.Sprintf("%s://%s/config", + util.InternalHTTPSchema(), + statusAddr, + ) + resp, err := util.InternalHTTPClient().Get(configURL) + if err != nil { + return "", errors.Trace(err) + } + defer func() { + resp.Body.Close() + }() + + buf := new(bytes.Buffer) + _, err = buf.ReadFrom(resp.Body) + if err != nil { + return "", errors.Trace(err) + } + + var j map[string]any + err = json.Unmarshal(buf.Bytes(), &j) + if err != nil { + return "", errors.Trace(err) + } + + engineStore, ok := j["engine-store"].(map[string]any) + if !ok { + return "", errors.New("Error json") + } + port64, ok := engineStore["http_port"].(float64) + if !ok { + return "", errors.New("Error json") + } + + addr := net.JoinHostPort(host, strconv.FormatUint(uint64(port64), 10)) + return addr, nil +} + +// LoadTiFlashReplicaInfo parses model.TableInfo into []TiFlashReplicaStatus. 
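getTiflashHTTPAddr above pulls /config from the TiFlash status port and digs `engine-store.http_port` out of the JSON to build the HTTP address. The JSON handling in isolation, with the HTTP fetch replaced by a literal payload so it runs without a cluster:

package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"strconv"
)

// httpAddrFromConfig extracts engine-store.http_port the same way getTiflashHTTPAddr
// does: decode into map[string]any and assert the nested types step by step.
func httpAddrFromConfig(host string, body []byte) (string, error) {
	var j map[string]any
	if err := json.Unmarshal(body, &j); err != nil {
		return "", err
	}
	engineStore, ok := j["engine-store"].(map[string]any)
	if !ok {
		return "", errors.New("missing engine-store section")
	}
	port, ok := engineStore["http_port"].(float64) // JSON numbers decode as float64
	if !ok {
		return "", errors.New("missing engine-store.http_port")
	}
	return net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)), nil
}

func main() {
	body := []byte(`{"engine-store": {"http_port": 8123, "tcp_port": 9000}}`)
	addr, err := httpAddrFromConfig("10.0.0.5", body)
	fmt.Println(addr, err) // 10.0.0.5:8123 <nil>
}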
+func LoadTiFlashReplicaInfo(tblInfo *model.TableInfo, tableList *[]TiFlashReplicaStatus) { + if tblInfo.TiFlashReplica == nil { + // reject tables that has no tiflash replica such like `INFORMATION_SCHEMA` + return + } + if pi := tblInfo.GetPartitionInfo(); pi != nil { + for _, p := range pi.Definitions { + logutil.DDLLogger().Debug(fmt.Sprintf("Table %v has partition %v\n", tblInfo.ID, p.ID)) + *tableList = append(*tableList, TiFlashReplicaStatus{p.ID, + tblInfo.TiFlashReplica.Count, tblInfo.TiFlashReplica.LocationLabels, tblInfo.TiFlashReplica.IsPartitionAvailable(p.ID), tblInfo.TiFlashReplica.Available, false, true}) + } + // partitions that in adding mid-state + for _, p := range pi.AddingDefinitions { + logutil.DDLLogger().Debug(fmt.Sprintf("Table %v has partition adding %v\n", tblInfo.ID, p.ID)) + *tableList = append(*tableList, TiFlashReplicaStatus{p.ID, tblInfo.TiFlashReplica.Count, tblInfo.TiFlashReplica.LocationLabels, tblInfo.TiFlashReplica.IsPartitionAvailable(p.ID), tblInfo.TiFlashReplica.Available, true, true}) + } + } else { + logutil.DDLLogger().Debug(fmt.Sprintf("Table %v has no partition\n", tblInfo.ID)) + *tableList = append(*tableList, TiFlashReplicaStatus{tblInfo.ID, tblInfo.TiFlashReplica.Count, tblInfo.TiFlashReplica.LocationLabels, tblInfo.TiFlashReplica.Available, tblInfo.TiFlashReplica.Available, false, false}) + } +} + +// UpdateTiFlashHTTPAddress report TiFlash's StatusAddress's port to Pd's etcd. +func (d *ddl) UpdateTiFlashHTTPAddress(store *pd.StoreInfo) error { + host, _, err := net.SplitHostPort(store.Store.StatusAddress) + if err != nil { + return errors.Trace(err) + } + httpAddr, err := getTiflashHTTPAddr(host, store.Store.StatusAddress) + if err != nil { + return errors.Trace(err) + } + // Report to pd + key := fmt.Sprintf("/tiflash/cluster/http_port/%v", store.Store.Address) + if d.etcdCli == nil { + return errors.New("no etcdCli in ddl") + } + origin := "" + resp, err := d.etcdCli.Get(d.ctx, key) + if err != nil { + return errors.Trace(err) + } + // Try to update. + for _, kv := range resp.Kvs { + if string(kv.Key) == key { + origin = string(kv.Value) + break + } + } + if origin != httpAddr { + logutil.DDLLogger().Warn(fmt.Sprintf("Update status addr of %v from %v to %v", key, origin, httpAddr)) + err := ddlutil.PutKVToEtcd(d.ctx, d.etcdCli, 1, key, httpAddr) + if err != nil { + return errors.Trace(err) + } + } + + return nil +} + +func updateTiFlashStores(pollTiFlashContext *TiFlashManagementContext) error { + // We need the up-to-date information about TiFlash stores. + // Since TiFlash Replica synchronize may happen immediately after new TiFlash stores are added. + tikvStats, err := infosync.GetTiFlashStoresStat(context.Background()) + // If MockTiFlash is not set, will issue a MockTiFlashError here. + if err != nil { + return err + } + pollTiFlashContext.TiFlashStores = make(map[int64]pd.StoreInfo) + for _, store := range tikvStats.Stores { + if engine.IsTiFlashHTTPResp(&store.Store) { + pollTiFlashContext.TiFlashStores[store.Store.ID] = store + } + } + logutil.DDLLogger().Debug("updateTiFlashStores finished", zap.Int("TiFlash store count", len(pollTiFlashContext.TiFlashStores))) + return nil +} + +// PollAvailableTableProgress will poll and check availability of available tables. 
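LoadTiFlashReplicaInfo above expands one logical table into a status entry per physical object: each partition plus the plain-table case, skipping tables with no TiFlash replica at all. A simplified version of only that flattening, over invented types (the adding-partition and location-label details are omitted):

package main

import "fmt"

type replicaStatus struct {
	PhysicalID  int64
	Available   bool
	IsPartition bool
}

type tableMeta struct {
	ID                 int64
	ReplicaCount       uint64
	Available          bool
	PartitionIDs       []int64 // nil for non-partitioned tables
	PartitionAvailable map[int64]bool
}

// loadReplicaInfo appends one entry per physical table, mirroring the shape of
// LoadTiFlashReplicaInfo.
func loadReplicaInfo(tbl tableMeta, out *[]replicaStatus) {
	if tbl.ReplicaCount == 0 {
		return // no TiFlash replica configured, nothing to poll
	}
	if len(tbl.PartitionIDs) == 0 {
		*out = append(*out, replicaStatus{tbl.ID, tbl.Available, false})
		return
	}
	for _, pid := range tbl.PartitionIDs {
		*out = append(*out, replicaStatus{pid, tbl.PartitionAvailable[pid], true})
	}
}

func main() {
	var statuses []replicaStatus
	loadReplicaInfo(tableMeta{
		ID:                 100,
		ReplicaCount:       2,
		PartitionIDs:       []int64{101, 102},
		PartitionAvailable: map[int64]bool{101: true},
	}, &statuses)
	fmt.Println(statuses) // [{101 true true} {102 false true}]
}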
+func PollAvailableTableProgress(schemas infoschema.InfoSchema, _ sessionctx.Context, pollTiFlashContext *TiFlashManagementContext) { + pollMaxCount := RefreshProgressMaxTableCount + failpoint.Inject("PollAvailableTableProgressMaxCount", func(val failpoint.Value) { + pollMaxCount = uint64(val.(int)) + }) + for element := pollTiFlashContext.UpdatingProgressTables.Front(); element != nil && pollMaxCount > 0; pollMaxCount-- { + availableTableID := element.Value.(AvailableTableID) + var table table.Table + if availableTableID.IsPartition { + table, _, _ = schemas.FindTableByPartitionID(availableTableID.ID) + if table == nil { + logutil.DDLLogger().Info("get table by partition failed, may be dropped or truncated", + zap.Int64("partitionID", availableTableID.ID), + ) + pollTiFlashContext.UpdatingProgressTables.Remove(element) + element = element.Next() + continue + } + } else { + var ok bool + table, ok = schemas.TableByID(availableTableID.ID) + if !ok { + logutil.DDLLogger().Info("get table id failed, may be dropped or truncated", + zap.Int64("tableID", availableTableID.ID), + ) + pollTiFlashContext.UpdatingProgressTables.Remove(element) + element = element.Next() + continue + } + } + tableInfo := table.Meta() + if tableInfo.TiFlashReplica == nil { + logutil.DDLLogger().Info("table has no TiFlash replica", + zap.Int64("tableID or partitionID", availableTableID.ID), + zap.Bool("IsPartition", availableTableID.IsPartition), + ) + pollTiFlashContext.UpdatingProgressTables.Remove(element) + element = element.Next() + continue + } + + progress, err := infosync.CalculateTiFlashProgress(availableTableID.ID, tableInfo.TiFlashReplica.Count, pollTiFlashContext.TiFlashStores) + if err != nil { + if intest.InTest && err.Error() != "EOF" { + // In the test, the server cannot start up because the port is occupied. + // Although the port is random. so we need to quickly return when to + // fail to get tiflash sync. + // https://github.com/pingcap/tidb/issues/39949 + panic(err) + } + pollTiFlashContext.UpdatingProgressTables.Remove(element) + element = element.Next() + continue + } + err = infosync.UpdateTiFlashProgressCache(availableTableID.ID, progress) + if err != nil { + logutil.DDLLogger().Error("update tiflash sync progress cache failed", + zap.Error(err), + zap.Int64("tableID", availableTableID.ID), + zap.Bool("IsPartition", availableTableID.IsPartition), + zap.Float64("progress", progress), + ) + pollTiFlashContext.UpdatingProgressTables.Remove(element) + element = element.Next() + continue + } + next := element.Next() + pollTiFlashContext.UpdatingProgressTables.Remove(element) + element = next + } +} + +func (d *ddl) refreshTiFlashTicker(ctx sessionctx.Context, pollTiFlashContext *TiFlashManagementContext) error { + if pollTiFlashContext.PollCounter%UpdateTiFlashStoreTick.Load() == 0 { + if err := updateTiFlashStores(pollTiFlashContext); err != nil { + // If we failed to get from pd, retry everytime. + pollTiFlashContext.PollCounter = 0 + return err + } + } + + failpoint.Inject("OneTiFlashStoreDown", func() { + for storeID, store := range pollTiFlashContext.TiFlashStores { + store.Store.StateName = "Down" + pollTiFlashContext.TiFlashStores[storeID] = store + break + } + }) + pollTiFlashContext.PollCounter++ + + // Start to process every table. 
+ schema := d.infoCache.GetLatest() + if schema == nil { + return errors.New("Schema is nil") + } + + PollAvailableTableProgress(schema, ctx, pollTiFlashContext) + + var tableList = make([]TiFlashReplicaStatus, 0) + + // Collect TiFlash Replica info, for every table. + ch := schema.ListTablesWithSpecialAttribute(infoschema.TiFlashAttribute) + for _, v := range ch { + for _, tblInfo := range v.TableInfos { + LoadTiFlashReplicaInfo(tblInfo, &tableList) + } + } + + failpoint.Inject("waitForAddPartition", func(val failpoint.Value) { + for _, phyTable := range tableList { + is := d.infoCache.GetLatest() + _, ok := is.TableByID(phyTable.ID) + if !ok { + tb, _, _ := is.FindTableByPartitionID(phyTable.ID) + if tb == nil { + logutil.DDLLogger().Info("waitForAddPartition") + sleepSecond := val.(int) + time.Sleep(time.Duration(sleepSecond) * time.Second) + } + } + } + }) + + needPushPending := false + if pollTiFlashContext.UpdatingProgressTables.Len() == 0 { + needPushPending = true + } + + for _, tb := range tableList { + // For every region in each table, if it has one replica, we reckon it ready. + // These request can be batched as an optimization. + available := tb.Available + failpoint.Inject("PollTiFlashReplicaStatusReplacePrevAvailableValue", func(val failpoint.Value) { + available = val.(bool) + }) + // We only check unavailable tables here, so doesn't include blocked add partition case. + if !available && !tb.LogicalTableAvailable { + enabled, inqueue, _ := pollTiFlashContext.Backoff.Tick(tb.ID) + if inqueue && !enabled { + logutil.DDLLogger().Info("Escape checking available status due to backoff", zap.Int64("tableId", tb.ID)) + continue + } + + progress, err := infosync.CalculateTiFlashProgress(tb.ID, tb.Count, pollTiFlashContext.TiFlashStores) + if err != nil { + logutil.DDLLogger().Error("get tiflash sync progress failed", + zap.Error(err), + zap.Int64("tableID", tb.ID), + ) + continue + } + + err = infosync.UpdateTiFlashProgressCache(tb.ID, progress) + if err != nil { + logutil.DDLLogger().Error("get tiflash sync progress from cache failed", + zap.Error(err), + zap.Int64("tableID", tb.ID), + zap.Bool("IsPartition", tb.IsPartition), + zap.Float64("progress", progress), + ) + continue + } + + avail := progress == 1 + failpoint.Inject("PollTiFlashReplicaStatusReplaceCurAvailableValue", func(val failpoint.Value) { + avail = val.(bool) + }) + + if !avail { + logutil.DDLLogger().Info("Tiflash replica is not available", zap.Int64("tableID", tb.ID), zap.Float64("progress", progress)) + pollTiFlashContext.Backoff.Put(tb.ID) + } else { + logutil.DDLLogger().Info("Tiflash replica is available", zap.Int64("tableID", tb.ID), zap.Float64("progress", progress)) + pollTiFlashContext.Backoff.Remove(tb.ID) + } + failpoint.Inject("skipUpdateTableReplicaInfoInLoop", func() { + failpoint.Continue() + }) + // Will call `onUpdateFlashReplicaStatus` to update `TiFlashReplica`. 
+ if err := d.executor.UpdateTableReplicaInfo(ctx, tb.ID, avail); err != nil { + if infoschema.ErrTableNotExists.Equal(err) && tb.IsPartition { + // May be due to blocking add partition + logutil.DDLLogger().Info("updating TiFlash replica status err, maybe false alarm by blocking add", zap.Error(err), zap.Int64("tableID", tb.ID), zap.Bool("isPartition", tb.IsPartition)) + } else { + logutil.DDLLogger().Error("updating TiFlash replica status err", zap.Error(err), zap.Int64("tableID", tb.ID), zap.Bool("isPartition", tb.IsPartition)) + } + } + } else { + if needPushPending { + pollTiFlashContext.UpdatingProgressTables.PushFront(AvailableTableID{tb.ID, tb.IsPartition}) + } + } + } + + return nil +} + +func (d *ddl) PollTiFlashRoutine() { + pollTiflashContext, err := NewTiFlashManagementContext() + if err != nil { + logutil.DDLLogger().Fatal("TiFlashManagement init failed", zap.Error(err)) + } + + hasSetTiFlashGroup := false + nextSetTiFlashGroupTime := time.Now() + for { + select { + case <-d.ctx.Done(): + return + case <-time.After(PollTiFlashInterval): + } + if d.IsTiFlashPollEnabled() { + if d.sessPool == nil { + logutil.DDLLogger().Error("failed to get sessionPool for refreshTiFlashTicker") + return + } + failpoint.Inject("BeforeRefreshTiFlashTickeLoop", func() { + failpoint.Continue() + }) + + if !hasSetTiFlashGroup && !time.Now().Before(nextSetTiFlashGroupTime) { + // We should set tiflash rule group a higher index than other placement groups to forbid override by them. + // Once `SetTiFlashGroupConfig` succeed, we do not need to invoke it again. If failed, we should retry it util success. + if err = infosync.SetTiFlashGroupConfig(d.ctx); err != nil { + logutil.DDLLogger().Warn("SetTiFlashGroupConfig failed", zap.Error(err)) + nextSetTiFlashGroupTime = time.Now().Add(time.Minute) + } else { + hasSetTiFlashGroup = true + } + } + + sctx, err := d.sessPool.Get() + if err == nil { + if d.ownerManager.IsOwner() { + err := d.refreshTiFlashTicker(sctx, pollTiflashContext) + if err != nil { + switch err.(type) { + case *infosync.MockTiFlashError: + // If we have not set up MockTiFlash instance, for those tests without TiFlash, just suppress. 
+ default: + logutil.DDLLogger().Warn("refreshTiFlashTicker returns error", zap.Error(err)) + } + } + } else { + infosync.CleanTiFlashProgressCache() + } + d.sessPool.Put(sctx) + } else { + if sctx != nil { + d.sessPool.Put(sctx) + } + logutil.DDLLogger().Error("failed to get session for pollTiFlashReplicaStatus", zap.Error(err)) + } + } + } +} diff --git a/pkg/ddl/delete_range.go b/pkg/ddl/delete_range.go index bafa01c847977..bd5e0920aa000 100644 --- a/pkg/ddl/delete_range.go +++ b/pkg/ddl/delete_range.go @@ -372,11 +372,11 @@ func insertJobIntoDeleteRangeTable(ctx context.Context, wrapper DelRangeExecWrap if len(partitionIDs) == 0 { return errors.Trace(doBatchDeleteIndiceRange(ctx, wrapper, job.ID, tableID, allIndexIDs, ea, "drop index: table ID")) } - failpoint.Inject("checkDropGlobalIndex", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("checkDropGlobalIndex")); _err_ == nil { if val.(bool) { panic("drop global index must not delete partition index range") } - }) + } for _, pid := range partitionIDs { if err := doBatchDeleteIndiceRange(ctx, wrapper, job.ID, pid, allIndexIDs, ea, "drop index: partition table ID"); err != nil { return errors.Trace(err) diff --git a/pkg/ddl/delete_range.go__failpoint_stash__ b/pkg/ddl/delete_range.go__failpoint_stash__ new file mode 100644 index 0000000000000..bafa01c847977 --- /dev/null +++ b/pkg/ddl/delete_range.go__failpoint_stash__ @@ -0,0 +1,548 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl + +import ( + "context" + "encoding/hex" + "math" + "strings" + "sync" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/tidb/pkg/ddl/logutil" + sess "github.com/pingcap/tidb/pkg/ddl/session" + "github.com/pingcap/tidb/pkg/ddl/util" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/tablecodec" + topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" + "go.uber.org/zap" +) + +const ( + insertDeleteRangeSQLPrefix = `INSERT IGNORE INTO mysql.gc_delete_range VALUES ` + insertDeleteRangeSQLValue = `(%?, %?, %?, %?, %?)` + + delBatchSize = 65536 + delBackLog = 128 +) + +// Only used in the BR unit test. Once these const variables modified, please make sure compatible with BR. +const ( + BRInsertDeleteRangeSQLPrefix = insertDeleteRangeSQLPrefix + BRInsertDeleteRangeSQLValue = insertDeleteRangeSQLValue +) + +var ( + // batchInsertDeleteRangeSize is the maximum size for each batch insert statement in the delete-range. + batchInsertDeleteRangeSize = 256 +) + +type delRangeManager interface { + // addDelRangeJob add a DDL job into gc_delete_range table. + addDelRangeJob(ctx context.Context, job *model.Job) error + // removeFromGCDeleteRange removes the deleting table job from gc_delete_range table by jobID and tableID. 
+ // It's use for recover the table that was mistakenly deleted. + removeFromGCDeleteRange(ctx context.Context, jobID int64) error + start() + clear() +} + +type delRange struct { + store kv.Storage + sessPool *sess.Pool + emulatorCh chan struct{} + keys []kv.Key + quitCh chan struct{} + + wait sync.WaitGroup // wait is only used when storeSupport is false. + storeSupport bool +} + +// newDelRangeManager returns a delRangeManager. +func newDelRangeManager(store kv.Storage, sessPool *sess.Pool) delRangeManager { + dr := &delRange{ + store: store, + sessPool: sessPool, + storeSupport: store.SupportDeleteRange(), + quitCh: make(chan struct{}), + } + if !dr.storeSupport { + dr.emulatorCh = make(chan struct{}, delBackLog) + dr.keys = make([]kv.Key, 0, delBatchSize) + } + return dr +} + +// addDelRangeJob implements delRangeManager interface. +func (dr *delRange) addDelRangeJob(ctx context.Context, job *model.Job) error { + sctx, err := dr.sessPool.Get() + if err != nil { + return errors.Trace(err) + } + defer dr.sessPool.Put(sctx) + + // The same Job ID uses the same element ID allocator + wrapper := newDelRangeExecWrapper(sctx) + err = AddDelRangeJobInternal(ctx, wrapper, job) + if err != nil { + logutil.DDLLogger().Error("add job into delete-range table failed", zap.Int64("jobID", job.ID), zap.String("jobType", job.Type.String()), zap.Error(err)) + return errors.Trace(err) + } + if !dr.storeSupport { + dr.emulatorCh <- struct{}{} + } + logutil.DDLLogger().Info("add job into delete-range table", zap.Int64("jobID", job.ID), zap.String("jobType", job.Type.String())) + return nil +} + +// AddDelRangeJobInternal implements the generation the delete ranges for the provided job and consumes the delete ranges through delRangeExecWrapper. +func AddDelRangeJobInternal(ctx context.Context, wrapper DelRangeExecWrapper, job *model.Job) error { + var err error + var ea elementIDAlloc + if job.MultiSchemaInfo != nil { + err = insertJobIntoDeleteRangeTableMultiSchema(ctx, wrapper, job, &ea) + } else { + err = insertJobIntoDeleteRangeTable(ctx, wrapper, job, &ea) + } + return errors.Trace(err) +} + +func insertJobIntoDeleteRangeTableMultiSchema(ctx context.Context, wrapper DelRangeExecWrapper, job *model.Job, ea *elementIDAlloc) error { + for i, sub := range job.MultiSchemaInfo.SubJobs { + proxyJob := sub.ToProxyJob(job, i) + if JobNeedGC(&proxyJob) { + err := insertJobIntoDeleteRangeTable(ctx, wrapper, &proxyJob, ea) + if err != nil { + return errors.Trace(err) + } + } + } + return nil +} + +// removeFromGCDeleteRange implements delRangeManager interface. +func (dr *delRange) removeFromGCDeleteRange(ctx context.Context, jobID int64) error { + sctx, err := dr.sessPool.Get() + if err != nil { + return errors.Trace(err) + } + defer dr.sessPool.Put(sctx) + err = util.RemoveMultiFromGCDeleteRange(ctx, sctx, jobID) + return errors.Trace(err) +} + +// start implements delRangeManager interface. +func (dr *delRange) start() { + if !dr.storeSupport { + dr.wait.Add(1) + go dr.startEmulator() + } +} + +// clear implements delRangeManager interface. +func (dr *delRange) clear() { + logutil.DDLLogger().Info("closing delRange") + close(dr.quitCh) + dr.wait.Wait() +} + +// startEmulator is only used for those storage engines which don't support +// delete-range. The emulator fetches records from gc_delete_range table and +// deletes all keys in each DelRangeTask. 
+func (dr *delRange) startEmulator() { + defer dr.wait.Done() + logutil.DDLLogger().Info("start delRange emulator") + for { + select { + case <-dr.emulatorCh: + case <-dr.quitCh: + return + } + if util.IsEmulatorGCEnable() { + err := dr.doDelRangeWork() + terror.Log(errors.Trace(err)) + } + } +} + +func (dr *delRange) doDelRangeWork() error { + sctx, err := dr.sessPool.Get() + if err != nil { + logutil.DDLLogger().Error("delRange emulator get session failed", zap.Error(err)) + return errors.Trace(err) + } + defer dr.sessPool.Put(sctx) + + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) + ranges, err := util.LoadDeleteRanges(ctx, sctx, math.MaxInt64) + if err != nil { + logutil.DDLLogger().Error("delRange emulator load tasks failed", zap.Error(err)) + return errors.Trace(err) + } + + for _, r := range ranges { + if err := dr.doTask(sctx, r); err != nil { + logutil.DDLLogger().Error("delRange emulator do task failed", zap.Error(err)) + return errors.Trace(err) + } + } + return nil +} + +func (dr *delRange) doTask(sctx sessionctx.Context, r util.DelRangeTask) error { + var oldStartKey, newStartKey kv.Key + oldStartKey = r.StartKey + for { + finish := true + dr.keys = dr.keys[:0] + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) + err := kv.RunInNewTxn(ctx, dr.store, false, func(_ context.Context, txn kv.Transaction) error { + if topsqlstate.TopSQLEnabled() { + // Only when TiDB run without PD(use unistore as storage for test) will run into here, so just set a mock internal resource tagger. + txn.SetOption(kv.ResourceGroupTagger, util.GetInternalResourceGroupTaggerForTopSQL()) + } + iter, err := txn.Iter(oldStartKey, r.EndKey) + if err != nil { + return errors.Trace(err) + } + defer iter.Close() + + txn.SetDiskFullOpt(kvrpcpb.DiskFullOpt_AllowedOnAlmostFull) + for i := 0; i < delBatchSize; i++ { + if !iter.Valid() { + break + } + finish = false + dr.keys = append(dr.keys, iter.Key().Clone()) + newStartKey = iter.Key().Next() + + if err := iter.Next(); err != nil { + return errors.Trace(err) + } + } + + for _, key := range dr.keys { + err := txn.Delete(key) + if err != nil && !kv.ErrNotExist.Equal(err) { + return errors.Trace(err) + } + } + return nil + }) + if err != nil { + return errors.Trace(err) + } + if finish { + if err := util.CompleteDeleteRange(sctx, r, true); err != nil { + logutil.DDLLogger().Error("delRange emulator complete task failed", zap.Error(err)) + return errors.Trace(err) + } + startKey, endKey := r.Range() + logutil.DDLLogger().Info("delRange emulator complete task", zap.String("category", "ddl"), + zap.Int64("jobID", r.JobID), + zap.Int64("elementID", r.ElementID), + zap.Stringer("startKey", startKey), + zap.Stringer("endKey", endKey)) + break + } + if err := util.UpdateDeleteRange(sctx, r, newStartKey, oldStartKey); err != nil { + logutil.DDLLogger().Error("delRange emulator update task failed", zap.Error(err)) + } + oldStartKey = newStartKey + } + return nil +} + +// insertJobIntoDeleteRangeTable parses the job into delete-range arguments, +// and inserts a new record into gc_delete_range table. The primary key is +// (job ID, element ID), so we ignore key conflict error. 
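doTask above emulates delete-range by scanning at most delBatchSize keys per transaction, deleting them, then moving the start key forward until the scan is exhausted. The same cursor-advancing loop over a sorted in-memory key list; the batch size and key set are invented, and the termination check is simplified to "the scan came back short".

package main

import (
	"fmt"
	"sort"
)

const batchSize = 3 // stands in for delBatchSize (65536 in the real code)

// deleteRangeBatched removes keys in [start, end) from the sorted slice, at most
// batchSize per "transaction", and returns how many rounds it took.
func deleteRangeBatched(keys *[]string, start, end string) int {
	rounds := 0
	cursor := start
	for {
		rounds++
		// "Scan" up to batchSize keys in [cursor, end).
		batch := make([]string, 0, batchSize)
		for _, k := range *keys {
			if k >= cursor && k < end {
				batch = append(batch, k)
				if len(batch) == batchSize {
					break
				}
			}
		}
		// "Delete" the batch.
		kept := (*keys)[:0]
		for _, k := range *keys {
			i := sort.SearchStrings(batch, k)
			if i < len(batch) && batch[i] == k {
				continue
			}
			kept = append(kept, k)
		}
		*keys = kept
		if len(batch) < batchSize {
			return rounds // short scan: the range is empty now
		}
		cursor = batch[len(batch)-1] + "\x00" // next key after the last deleted one
	}
}

func main() {
	keys := []string{"t_a", "t_b", "t_c", "t_d", "t_e", "t_f", "t_g", "u_a"}
	rounds := deleteRangeBatched(&keys, "t_", "u_")
	fmt.Println("remaining:", keys, "rounds:", rounds)
}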
+func insertJobIntoDeleteRangeTable(ctx context.Context, wrapper DelRangeExecWrapper, job *model.Job, ea *elementIDAlloc) error { + if err := wrapper.UpdateTSOForJob(); err != nil { + return errors.Trace(err) + } + + ctx = kv.WithInternalSourceType(ctx, getDDLRequestSource(job.Type)) + switch job.Type { + case model.ActionDropSchema: + var tableIDs []int64 + if err := job.DecodeArgs(&tableIDs); err != nil { + return errors.Trace(err) + } + for i := 0; i < len(tableIDs); i += batchInsertDeleteRangeSize { + batchEnd := len(tableIDs) + if batchEnd > i+batchInsertDeleteRangeSize { + batchEnd = i + batchInsertDeleteRangeSize + } + if err := doBatchDeleteTablesRange(ctx, wrapper, job.ID, tableIDs[i:batchEnd], ea, "drop schema: table IDs"); err != nil { + return errors.Trace(err) + } + } + case model.ActionDropTable, model.ActionTruncateTable: + tableID := job.TableID + // The startKey here is for compatibility with previous versions, old version did not endKey so don't have to deal with. + var startKey kv.Key + var physicalTableIDs []int64 + var ruleIDs []string + if err := job.DecodeArgs(&startKey, &physicalTableIDs, &ruleIDs); err != nil { + return errors.Trace(err) + } + if len(physicalTableIDs) > 0 { + if err := doBatchDeleteTablesRange(ctx, wrapper, job.ID, physicalTableIDs, ea, "drop table: partition table IDs"); err != nil { + return errors.Trace(err) + } + // logical table may contain global index regions, so delete the logical table range. + return errors.Trace(doBatchDeleteTablesRange(ctx, wrapper, job.ID, []int64{tableID}, ea, "drop table: table ID")) + } + return errors.Trace(doBatchDeleteTablesRange(ctx, wrapper, job.ID, []int64{tableID}, ea, "drop table: table ID")) + case model.ActionDropTablePartition, model.ActionTruncateTablePartition, + model.ActionReorganizePartition, model.ActionRemovePartitioning, + model.ActionAlterTablePartitioning: + var physicalTableIDs []int64 + if err := job.DecodeArgs(&physicalTableIDs); err != nil { + return errors.Trace(err) + } + return errors.Trace(doBatchDeleteTablesRange(ctx, wrapper, job.ID, physicalTableIDs, ea, "reorganize/drop partition: physical table ID(s)")) + // ActionAddIndex, ActionAddPrimaryKey needs do it, because it needs to be rolled back when it's canceled. + case model.ActionAddIndex, model.ActionAddPrimaryKey: + allIndexIDs := make([]int64, 1) + ifExists := make([]bool, 1) + isGlobal := make([]bool, 0, 1) + var partitionIDs []int64 + if err := job.DecodeArgs(&allIndexIDs[0], &ifExists[0], &partitionIDs); err != nil { + if err = job.DecodeArgs(&allIndexIDs, &ifExists, &partitionIDs, &isGlobal); err != nil { + return errors.Trace(err) + } + } + // Determine the physicalIDs to be added. + physicalIDs := []int64{job.TableID} + if len(partitionIDs) > 0 { + physicalIDs = partitionIDs + } + for i, indexID := range allIndexIDs { + // Determine the index IDs to be added. 
+ tempIdxID := tablecodec.TempIndexPrefix | indexID + var indexIDs []int64 + if job.State == model.JobStateRollbackDone { + indexIDs = []int64{indexID, tempIdxID} + } else { + indexIDs = []int64{tempIdxID} + } + if len(isGlobal) != 0 && isGlobal[i] { + if err := doBatchDeleteIndiceRange(ctx, wrapper, job.ID, job.TableID, indexIDs, ea, "add index: physical table ID(s)"); err != nil { + return errors.Trace(err) + } + } else { + for _, pid := range physicalIDs { + if err := doBatchDeleteIndiceRange(ctx, wrapper, job.ID, pid, indexIDs, ea, "add index: physical table ID(s)"); err != nil { + return errors.Trace(err) + } + } + } + } + case model.ActionDropIndex, model.ActionDropPrimaryKey: + tableID := job.TableID + var indexName any + var partitionIDs []int64 + ifExists := make([]bool, 1) + allIndexIDs := make([]int64, 1) + if err := job.DecodeArgs(&indexName, &ifExists[0], &allIndexIDs[0], &partitionIDs); err != nil { + if err = job.DecodeArgs(&indexName, &ifExists, &allIndexIDs, &partitionIDs); err != nil { + return errors.Trace(err) + } + } + // partitionIDs len is 0 if the dropped index is a global index, even if it is a partitioned table. + if len(partitionIDs) == 0 { + return errors.Trace(doBatchDeleteIndiceRange(ctx, wrapper, job.ID, tableID, allIndexIDs, ea, "drop index: table ID")) + } + failpoint.Inject("checkDropGlobalIndex", func(val failpoint.Value) { + if val.(bool) { + panic("drop global index must not delete partition index range") + } + }) + for _, pid := range partitionIDs { + if err := doBatchDeleteIndiceRange(ctx, wrapper, job.ID, pid, allIndexIDs, ea, "drop index: partition table ID"); err != nil { + return errors.Trace(err) + } + } + case model.ActionDropColumn: + var colName model.CIStr + var ifExists bool + var indexIDs []int64 + var partitionIDs []int64 + if err := job.DecodeArgs(&colName, &ifExists, &indexIDs, &partitionIDs); err != nil { + return errors.Trace(err) + } + if len(indexIDs) > 0 { + if len(partitionIDs) == 0 { + return errors.Trace(doBatchDeleteIndiceRange(ctx, wrapper, job.ID, job.TableID, indexIDs, ea, "drop column: table ID")) + } + for _, pid := range partitionIDs { + if err := doBatchDeleteIndiceRange(ctx, wrapper, job.ID, pid, indexIDs, ea, "drop column: partition table ID"); err != nil { + return errors.Trace(err) + } + } + } + case model.ActionModifyColumn: + var indexIDs []int64 + var partitionIDs []int64 + if err := job.DecodeArgs(&indexIDs, &partitionIDs); err != nil { + return errors.Trace(err) + } + if len(indexIDs) == 0 { + return nil + } + if len(partitionIDs) == 0 { + return doBatchDeleteIndiceRange(ctx, wrapper, job.ID, job.TableID, indexIDs, ea, "modify column: table ID") + } + for _, pid := range partitionIDs { + if err := doBatchDeleteIndiceRange(ctx, wrapper, job.ID, pid, indexIDs, ea, "modify column: partition table ID"); err != nil { + return errors.Trace(err) + } + } + } + return nil +} + +func doBatchDeleteIndiceRange(ctx context.Context, wrapper DelRangeExecWrapper, jobID, tableID int64, indexIDs []int64, ea *elementIDAlloc, comment string) error { + logutil.DDLLogger().Info("insert into delete-range indices", zap.Int64("jobID", jobID), zap.Int64("tableID", tableID), zap.Int64s("indexIDs", indexIDs), zap.String("comment", comment)) + var buf strings.Builder + buf.WriteString(insertDeleteRangeSQLPrefix) + wrapper.PrepareParamsList(len(indexIDs) * 5) + tableID, ok := wrapper.RewriteTableID(tableID) + if !ok { + return nil + } + for i, indexID := range indexIDs { + startKey := tablecodec.EncodeTableIndexPrefix(tableID, indexID) + 
endKey := tablecodec.EncodeTableIndexPrefix(tableID, indexID+1) + startKeyEncoded := hex.EncodeToString(startKey) + endKeyEncoded := hex.EncodeToString(endKey) + buf.WriteString(insertDeleteRangeSQLValue) + if i != len(indexIDs)-1 { + buf.WriteString(",") + } + elemID := ea.allocForIndexID(tableID, indexID) + wrapper.AppendParamsList(jobID, elemID, startKeyEncoded, endKeyEncoded) + } + + return errors.Trace(wrapper.ConsumeDeleteRange(ctx, buf.String())) +} + +func doBatchDeleteTablesRange(ctx context.Context, wrapper DelRangeExecWrapper, jobID int64, tableIDs []int64, ea *elementIDAlloc, comment string) error { + logutil.DDLLogger().Info("insert into delete-range table", zap.Int64("jobID", jobID), zap.Int64s("tableIDs", tableIDs), zap.String("comment", comment)) + var buf strings.Builder + buf.WriteString(insertDeleteRangeSQLPrefix) + wrapper.PrepareParamsList(len(tableIDs) * 5) + for i, tableID := range tableIDs { + tableID, ok := wrapper.RewriteTableID(tableID) + if !ok { + continue + } + startKey := tablecodec.EncodeTablePrefix(tableID) + endKey := tablecodec.EncodeTablePrefix(tableID + 1) + startKeyEncoded := hex.EncodeToString(startKey) + endKeyEncoded := hex.EncodeToString(endKey) + buf.WriteString(insertDeleteRangeSQLValue) + if i != len(tableIDs)-1 { + buf.WriteString(",") + } + elemID := ea.allocForPhysicalID(tableID) + wrapper.AppendParamsList(jobID, elemID, startKeyEncoded, endKeyEncoded) + } + + return errors.Trace(wrapper.ConsumeDeleteRange(ctx, buf.String())) +} + +// DelRangeExecWrapper consumes the delete ranges with the provided table ID(s) and index ID(s). +type DelRangeExecWrapper interface { + // generate a new tso for the next job + UpdateTSOForJob() error + + // initialize the paramsList + PrepareParamsList(sz int) + + // rewrite table id if necessary, used for BR + RewriteTableID(tableID int64) (int64, bool) + + // (job_id, element_id, start_key, end_key, ts) + // ts is generated by delRangeExecWrapper itself + AppendParamsList(jobID, elemID int64, startKey, endKey string) + + // consume the delete range. For TiDB Server, it insert rows into mysql.gc_delete_range. + ConsumeDeleteRange(ctx context.Context, sql string) error +} + +// sessionDelRangeExecWrapper is a lightweight wrapper that implements the DelRangeExecWrapper interface and used for TiDB Server. +// It consumes the delete ranges by directly insert rows into mysql.gc_delete_range. 
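+// A single TSO fetched in UpdateTSOForJob is reused for every range appended for that job, and
+// paramsList is buffered between PrepareParamsList and ConsumeDeleteRange, then reset to nil.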
+type sessionDelRangeExecWrapper struct { + sctx sessionctx.Context + ts uint64 + + // temporary values + paramsList []any +} + +func newDelRangeExecWrapper(sctx sessionctx.Context) DelRangeExecWrapper { + return &sessionDelRangeExecWrapper{ + sctx: sctx, + paramsList: nil, + } +} + +func (sdr *sessionDelRangeExecWrapper) UpdateTSOForJob() error { + now, err := getNowTSO(sdr.sctx) + if err != nil { + return errors.Trace(err) + } + sdr.ts = now + return nil +} + +func (sdr *sessionDelRangeExecWrapper) PrepareParamsList(sz int) { + sdr.paramsList = make([]any, 0, sz) +} + +func (*sessionDelRangeExecWrapper) RewriteTableID(tableID int64) (int64, bool) { + return tableID, true +} + +func (sdr *sessionDelRangeExecWrapper) AppendParamsList(jobID, elemID int64, startKey, endKey string) { + sdr.paramsList = append(sdr.paramsList, jobID, elemID, startKey, endKey, sdr.ts) +} + +func (sdr *sessionDelRangeExecWrapper) ConsumeDeleteRange(ctx context.Context, sql string) error { + // set session disk full opt + sdr.sctx.GetSessionVars().SetDiskFullOpt(kvrpcpb.DiskFullOpt_AllowedOnAlmostFull) + _, err := sdr.sctx.GetSQLExecutor().ExecuteInternal(ctx, sql, sdr.paramsList...) + // clear session disk full opt + sdr.sctx.GetSessionVars().ClearDiskFullOpt() + sdr.paramsList = nil + return errors.Trace(err) +} + +// getNowTS gets the current timestamp, in TSO. +func getNowTSO(ctx sessionctx.Context) (uint64, error) { + currVer, err := ctx.GetStore().CurrentVersion(kv.GlobalTxnScope) + if err != nil { + return 0, errors.Trace(err) + } + return currVer.Ver, nil +} diff --git a/pkg/ddl/executor.go b/pkg/ddl/executor.go index 225ba84582730..834d1e7048785 100644 --- a/pkg/ddl/executor.go +++ b/pkg/ddl/executor.go @@ -435,19 +435,19 @@ func isSessionDone(sctx sessionctx.Context) (bool, uint32) { if killed { return true, 1 } - failpoint.Inject("BatchAddTiFlashSendDone", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("BatchAddTiFlashSendDone")); _err_ == nil { done = val.(bool) - }) + } return done, 0 } func (e *executor) waitPendingTableThreshold(sctx sessionctx.Context, schemaID int64, tableID int64, originVersion int64, pendingCount uint32, threshold uint32) (bool, int64, uint32, bool) { configRetry := tiflashCheckPendingTablesRetry configWaitTime := tiflashCheckPendingTablesWaitTime - failpoint.Inject("FastFailCheckTiFlashPendingTables", func(value failpoint.Value) { + if value, _err_ := failpoint.Eval(_curpkg_("FastFailCheckTiFlashPendingTables")); _err_ == nil { configRetry = value.(int) configWaitTime = time.Millisecond * 200 - }) + } for retry := 0; retry < configRetry; retry++ { done, killed := isSessionDone(sctx) @@ -1190,12 +1190,12 @@ func (e *executor) BatchCreateTableWithInfo(ctx sessionctx.Context, infos []*model.TableInfo, cs ...CreateTableOption, ) error { - failpoint.Inject("RestoreBatchCreateTableEntryTooLarge", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("RestoreBatchCreateTableEntryTooLarge")); _err_ == nil { injectBatchSize := val.(int) if len(infos) > injectBatchSize { - failpoint.Return(kv.ErrEntryTooLarge) + return kv.ErrEntryTooLarge } - }) + } c := GetCreateTableConfig(cs) jobW := NewJobWrapper( @@ -2181,7 +2181,7 @@ func (e *executor) AddColumn(ctx sessionctx.Context, ti ast.Ident, spec *ast.Alt if err != nil { return errors.Trace(err) } - failpoint.InjectCall("afterGetSchemaAndTableByIdent", ctx) + failpoint.Call(_curpkg_("afterGetSchemaAndTableByIdent"), ctx) tbInfo := t.Meta() if err = checkAddColumnTooManyColumns(len(t.Cols()) + 1); err 
!= nil { return errors.Trace(err) @@ -2577,7 +2577,7 @@ func (e *executor) ReorganizePartitions(ctx sessionctx.Context, ident ast.Ident, // No preSplitAndScatter here, it will be done by the worker in onReorganizePartition instead. err = e.DoDDLJob(ctx, job) - failpoint.InjectCall("afterReorganizePartition") + failpoint.Call(_curpkg_("afterReorganizePartition")) if err == nil { ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError("The statistics of related partitions will be outdated after reorganizing partitions. Please use 'ANALYZE TABLE' statement if you want to update it now")) } @@ -3172,7 +3172,7 @@ func (e *executor) DropColumn(ctx sessionctx.Context, ti ast.Ident, spec *ast.Al if err != nil { return errors.Trace(err) } - failpoint.InjectCall("afterGetSchemaAndTableByIdent", ctx) + failpoint.Call(_curpkg_("afterGetSchemaAndTableByIdent"), ctx) isDropable, err := checkIsDroppableColumn(ctx, e.infoCache.GetLatest(), schema, t, spec) if err != nil { @@ -4766,9 +4766,9 @@ func newReorgMetaFromVariables(job *model.Job, sctx sessionctx.Context) (*model. } reorgMeta.IsDistReorg = false reorgMeta.IsFastReorg = false - failpoint.Inject("reorgMetaRecordFastReorgDisabled", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("reorgMetaRecordFastReorgDisabled")); _err_ == nil { LastReorgMetaFastReorgDisabled = true - }) + } } return reorgMeta, nil } @@ -6296,7 +6296,7 @@ func (e *executor) DoDDLJobWrapper(ctx sessionctx.Context, jobW *JobWrapper) err setDDLJobQuery(ctx, job) e.deliverJobTask(jobW) - failpoint.Inject("mockParallelSameDDLJobTwice", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockParallelSameDDLJobTwice")); _err_ == nil { if val.(bool) { <-jobW.ResultCh[0] // The same job will be put to the DDL queue twice. @@ -6306,7 +6306,7 @@ func (e *executor) DoDDLJobWrapper(ctx sessionctx.Context, jobW *JobWrapper) err // The second job result is used for test. jobW = newJobW } - }) + } // worker should restart to continue handling tasks in limitJobCh, and send back through jobW.err result := <-jobW.ResultCh[0] @@ -6317,7 +6317,7 @@ func (e *executor) DoDDLJobWrapper(ctx sessionctx.Context, jobW *JobWrapper) err // The transaction of enqueuing job is failed. return errors.Trace(err) } - failpoint.InjectCall("waitJobSubmitted") + failpoint.Call(_curpkg_("waitJobSubmitted")) sessVars := ctx.GetSessionVars() sessVars.StmtCtx.IsDDLJobInQueue = true @@ -6353,7 +6353,7 @@ func (e *executor) DoDDLJobWrapper(ctx sessionctx.Context, jobW *JobWrapper) err i := 0 notifyCh, _ := e.getJobDoneCh(jobID) for { - failpoint.InjectCall("storeCloseInLoop") + failpoint.Call(_curpkg_("storeCloseInLoop")) select { case _, ok := <-notifyCh: if !ok { diff --git a/pkg/ddl/executor.go__failpoint_stash__ b/pkg/ddl/executor.go__failpoint_stash__ new file mode 100644 index 0000000000000..225ba84582730 --- /dev/null +++ b/pkg/ddl/executor.go__failpoint_stash__ @@ -0,0 +1,6540 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Copyright 2013 The ql Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSES/QL-LICENSE file. + +package ddl + +import ( + "bytes" + "context" + "fmt" + "math" + "strings" + "sync" + "sync/atomic" + "time" + "unicode/utf8" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl/label" + "github.com/pingcap/tidb/pkg/ddl/logutil" + "github.com/pingcap/tidb/pkg/ddl/resourcegroup" + sess "github.com/pingcap/tidb/pkg/ddl/session" + ddlutil "github.com/pingcap/tidb/pkg/ddl/util" + rg "github.com/pingcap/tidb/pkg/domain/resourcegroup" + "github.com/pingcap/tidb/pkg/errctx" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/owner" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/privilege" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/statistics/handle" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/collate" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" + "github.com/pingcap/tidb/pkg/util/domainutil" + "github.com/pingcap/tidb/pkg/util/generic" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/mathutil" + "github.com/pingcap/tidb/pkg/util/stringutil" + "github.com/tikv/client-go/v2/oracle" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" +) + +const ( + expressionIndexPrefix = "_V$" + changingColumnPrefix = "_Col$_" + changingIndexPrefix = "_Idx$_" + tableNotExist = -1 + tinyBlobMaxLength = 255 + blobMaxLength = 65535 + mediumBlobMaxLength = 16777215 + longBlobMaxLength = 4294967295 + // When setting the placement policy with "PLACEMENT POLICY `default`", + // it means to remove placement policy from the specified object. + defaultPlacementPolicyName = "default" + tiflashCheckPendingTablesWaitTime = 3000 * time.Millisecond + // Once tiflashCheckPendingTablesLimit is reached, we trigger a limiter detection. + tiflashCheckPendingTablesLimit = 100 + tiflashCheckPendingTablesRetry = 7 +) + +var errCheckConstraintIsOff = errors.NewNoStackError(variable.TiDBEnableCheckConstraint + " is off") + +// Executor is the interface for executing DDL statements. +// it's mostly called by SQL executor. 
+type Executor interface { + CreateSchema(ctx sessionctx.Context, stmt *ast.CreateDatabaseStmt) error + AlterSchema(sctx sessionctx.Context, stmt *ast.AlterDatabaseStmt) error + DropSchema(ctx sessionctx.Context, stmt *ast.DropDatabaseStmt) error + CreateTable(ctx sessionctx.Context, stmt *ast.CreateTableStmt) error + CreateView(ctx sessionctx.Context, stmt *ast.CreateViewStmt) error + DropTable(ctx sessionctx.Context, stmt *ast.DropTableStmt) (err error) + RecoverTable(ctx sessionctx.Context, recoverInfo *RecoverInfo) (err error) + RecoverSchema(ctx sessionctx.Context, recoverSchemaInfo *RecoverSchemaInfo) error + DropView(ctx sessionctx.Context, stmt *ast.DropTableStmt) (err error) + CreateIndex(ctx sessionctx.Context, stmt *ast.CreateIndexStmt) error + DropIndex(ctx sessionctx.Context, stmt *ast.DropIndexStmt) error + AlterTable(ctx context.Context, sctx sessionctx.Context, stmt *ast.AlterTableStmt) error + TruncateTable(ctx sessionctx.Context, tableIdent ast.Ident) error + RenameTable(ctx sessionctx.Context, stmt *ast.RenameTableStmt) error + LockTables(ctx sessionctx.Context, stmt *ast.LockTablesStmt) error + UnlockTables(ctx sessionctx.Context, lockedTables []model.TableLockTpInfo) error + CleanupTableLock(ctx sessionctx.Context, tables []*ast.TableName) error + UpdateTableReplicaInfo(ctx sessionctx.Context, physicalID int64, available bool) error + RepairTable(ctx sessionctx.Context, createStmt *ast.CreateTableStmt) error + CreateSequence(ctx sessionctx.Context, stmt *ast.CreateSequenceStmt) error + DropSequence(ctx sessionctx.Context, stmt *ast.DropSequenceStmt) (err error) + AlterSequence(ctx sessionctx.Context, stmt *ast.AlterSequenceStmt) error + CreatePlacementPolicy(ctx sessionctx.Context, stmt *ast.CreatePlacementPolicyStmt) error + DropPlacementPolicy(ctx sessionctx.Context, stmt *ast.DropPlacementPolicyStmt) error + AlterPlacementPolicy(ctx sessionctx.Context, stmt *ast.AlterPlacementPolicyStmt) error + AddResourceGroup(ctx sessionctx.Context, stmt *ast.CreateResourceGroupStmt) error + AlterResourceGroup(ctx sessionctx.Context, stmt *ast.AlterResourceGroupStmt) error + DropResourceGroup(ctx sessionctx.Context, stmt *ast.DropResourceGroupStmt) error + FlashbackCluster(ctx sessionctx.Context, flashbackTS uint64) error + + // CreateSchemaWithInfo creates a database (schema) given its database info. + // + // WARNING: the DDL owns the `info` after calling this function, and will modify its fields + // in-place. If you want to keep using `info`, please call Clone() first. + CreateSchemaWithInfo( + ctx sessionctx.Context, + info *model.DBInfo, + onExist OnExist) error + + // CreateTableWithInfo creates a table, view or sequence given its table info. + // + // WARNING: the DDL owns the `info` after calling this function, and will modify its fields + // in-place. If you want to keep using `info`, please call Clone() first. + CreateTableWithInfo( + ctx sessionctx.Context, + schema model.CIStr, + info *model.TableInfo, + involvingRef []model.InvolvingSchemaInfo, + cs ...CreateTableOption) error + + // BatchCreateTableWithInfo is like CreateTableWithInfo, but can handle multiple tables. + BatchCreateTableWithInfo(ctx sessionctx.Context, + schema model.CIStr, + info []*model.TableInfo, + cs ...CreateTableOption) error + + // CreatePlacementPolicyWithInfo creates a placement policy + // + // WARNING: the DDL owns the `policy` after calling this function, and will modify its fields + // in-place. If you want to keep using `policy`, please call Clone() first. 
+ CreatePlacementPolicyWithInfo(ctx sessionctx.Context, policy *model.PolicyInfo, onExist OnExist) error +} + +// ExecutorForTest is the interface for executing DDL statements in tests. +// TODO remove it later +type ExecutorForTest interface { + // DoDDLJob does the DDL job, it's exported for test. + DoDDLJob(ctx sessionctx.Context, job *model.Job) error + // DoDDLJobWrapper similar to DoDDLJob, but with JobWrapper as input. + DoDDLJobWrapper(ctx sessionctx.Context, jobW *JobWrapper) error +} + +// all fields are shared with ddl now. +type executor struct { + sessPool *sess.Pool + statsHandle *handle.Handle + + ctx context.Context + uuid string + store kv.Storage + etcdCli *clientv3.Client + autoidCli *autoid.ClientDiscover + infoCache *infoschema.InfoCache + limitJobCh chan *JobWrapper + schemaLoader SchemaLoader + lease time.Duration // lease is schema lease, default 45s, see config.Lease. + ownerManager owner.Manager + ddlJobDoneChMap *generic.SyncMap[int64, chan struct{}] + ddlJobNotifyCh chan struct{} + globalIDLock *sync.Mutex +} + +var _ Executor = (*executor)(nil) +var _ ExecutorForTest = (*executor)(nil) + +func (e *executor) CreateSchema(ctx sessionctx.Context, stmt *ast.CreateDatabaseStmt) (err error) { + var placementPolicyRef *model.PolicyRefInfo + sessionVars := ctx.GetSessionVars() + + // If no charset and/or collation is specified use collation_server and character_set_server + charsetOpt := ast.CharsetOpt{} + if sessionVars.GlobalVarsAccessor != nil { + charsetOpt.Col, err = sessionVars.GetSessionOrGlobalSystemVar(context.Background(), variable.CollationServer) + if err != nil { + return err + } + charsetOpt.Chs, err = sessionVars.GetSessionOrGlobalSystemVar(context.Background(), variable.CharacterSetServer) + if err != nil { + return err + } + } + + explicitCharset := false + explicitCollation := false + for _, val := range stmt.Options { + switch val.Tp { + case ast.DatabaseOptionCharset: + charsetOpt.Chs = val.Value + explicitCharset = true + case ast.DatabaseOptionCollate: + charsetOpt.Col = val.Value + explicitCollation = true + case ast.DatabaseOptionPlacementPolicy: + placementPolicyRef = &model.PolicyRefInfo{ + Name: model.NewCIStr(val.Value), + } + } + } + + if charsetOpt.Col != "" { + coll, err := collate.GetCollationByName(charsetOpt.Col) + if err != nil { + return err + } + + // The collation is not valid for the specified character set. + // Try to remove any of them, but not if they are explicitly defined. + if coll.CharsetName != charsetOpt.Chs { + if explicitCollation && !explicitCharset { + // Use the explicitly set collation, not the implicit charset. + charsetOpt.Chs = "" + } + if !explicitCollation && explicitCharset { + // Use the explicitly set charset, not the (session) collation. 
+ charsetOpt.Col = "" + } + } + } + if !explicitCollation && explicitCharset { + coll, err := getDefaultCollationForUTF8MB4(ctx.GetSessionVars(), charsetOpt.Chs) + if err != nil { + return err + } + if len(coll) != 0 { + charsetOpt.Col = coll + } + } + dbInfo := &model.DBInfo{Name: stmt.Name} + chs, coll, err := ResolveCharsetCollation(ctx.GetSessionVars(), charsetOpt) + if err != nil { + return errors.Trace(err) + } + dbInfo.Charset = chs + dbInfo.Collate = coll + dbInfo.PlacementPolicyRef = placementPolicyRef + + onExist := OnExistError + if stmt.IfNotExists { + onExist = OnExistIgnore + } + return e.CreateSchemaWithInfo(ctx, dbInfo, onExist) +} + +func (e *executor) CreateSchemaWithInfo( + ctx sessionctx.Context, + dbInfo *model.DBInfo, + onExist OnExist, +) error { + is := e.infoCache.GetLatest() + _, ok := is.SchemaByName(dbInfo.Name) + if ok { + // since this error may be seen as error, keep it stack info. + err := infoschema.ErrDatabaseExists.GenWithStackByArgs(dbInfo.Name) + switch onExist { + case OnExistIgnore: + ctx.GetSessionVars().StmtCtx.AppendNote(err) + return nil + case OnExistError, OnExistReplace: + // FIXME: can we implement MariaDB's CREATE OR REPLACE SCHEMA? + return err + } + } + + if err := checkTooLongSchema(dbInfo.Name); err != nil { + return errors.Trace(err) + } + + if err := checkCharsetAndCollation(dbInfo.Charset, dbInfo.Collate); err != nil { + return errors.Trace(err) + } + + if err := handleDatabasePlacement(ctx, dbInfo); err != nil { + return errors.Trace(err) + } + + job := &model.Job{ + SchemaName: dbInfo.Name.L, + Type: model.ActionCreateSchema, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{dbInfo}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ + Database: dbInfo.Name.L, + Table: model.InvolvingAll, + }}, + SQLMode: ctx.GetSessionVars().SQLMode, + } + if ref := dbInfo.PlacementPolicyRef; ref != nil { + job.InvolvingSchemaInfo = append(job.InvolvingSchemaInfo, model.InvolvingSchemaInfo{ + Policy: ref.Name.L, + Mode: model.SharedInvolving, + }) + } + + err := e.DoDDLJob(ctx, job) + + if infoschema.ErrDatabaseExists.Equal(err) && onExist == OnExistIgnore { + ctx.GetSessionVars().StmtCtx.AppendNote(err) + return nil + } + + return errors.Trace(err) +} + +func (e *executor) ModifySchemaCharsetAndCollate(ctx sessionctx.Context, stmt *ast.AlterDatabaseStmt, toCharset, toCollate string) (err error) { + if toCollate == "" { + if toCollate, err = GetDefaultCollation(ctx.GetSessionVars(), toCharset); err != nil { + return errors.Trace(err) + } + } + + // Check if need to change charset/collation. + dbName := stmt.Name + is := e.infoCache.GetLatest() + dbInfo, ok := is.SchemaByName(dbName) + if !ok { + return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(dbName.O) + } + if dbInfo.Charset == toCharset && dbInfo.Collate == toCollate { + return nil + } + // Do the DDL job. 
+ job := &model.Job{ + SchemaID: dbInfo.ID, + SchemaName: dbInfo.Name.L, + Type: model.ActionModifySchemaCharsetAndCollate, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{toCharset, toCollate}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ + Database: dbInfo.Name.L, + Table: model.InvolvingAll, + }}, + SQLMode: ctx.GetSessionVars().SQLMode, + } + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +func (e *executor) ModifySchemaDefaultPlacement(ctx sessionctx.Context, stmt *ast.AlterDatabaseStmt, placementPolicyRef *model.PolicyRefInfo) (err error) { + dbName := stmt.Name + is := e.infoCache.GetLatest() + dbInfo, ok := is.SchemaByName(dbName) + if !ok { + return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(dbName.O) + } + + if checkIgnorePlacementDDL(ctx) { + return nil + } + + placementPolicyRef, err = checkAndNormalizePlacementPolicy(ctx, placementPolicyRef) + if err != nil { + return err + } + + // Do the DDL job. + job := &model.Job{ + SchemaID: dbInfo.ID, + SchemaName: dbInfo.Name.L, + Type: model.ActionModifySchemaDefaultPlacement, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{placementPolicyRef}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ + Database: dbInfo.Name.L, + Table: model.InvolvingAll, + }}, + SQLMode: ctx.GetSessionVars().SQLMode, + } + if placementPolicyRef != nil { + job.InvolvingSchemaInfo = append(job.InvolvingSchemaInfo, model.InvolvingSchemaInfo{ + Policy: placementPolicyRef.Name.L, + Mode: model.SharedInvolving, + }) + } + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +// getPendingTiFlashTableCount counts unavailable TiFlash replica by iterating all tables in infoCache. +func (e *executor) getPendingTiFlashTableCount(originVersion int64, pendingCount uint32) (int64, uint32) { + is := e.infoCache.GetLatest() + // If there are no schema change since last time(can be weird). 
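+ // Descriptive note (added for clarity): when the schema meta version is unchanged, reuse the
+ // previously computed pending count instead of rescanning every table.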
+ if is.SchemaMetaVersion() == originVersion { + return originVersion, pendingCount + } + cnt := uint32(0) + dbs := is.ListTablesWithSpecialAttribute(infoschema.TiFlashAttribute) + for _, db := range dbs { + if util.IsMemOrSysDB(db.DBName) { + continue + } + for _, tbl := range db.TableInfos { + if tbl.TiFlashReplica != nil && !tbl.TiFlashReplica.Available { + cnt++ + } + } + } + return is.SchemaMetaVersion(), cnt +} + +func isSessionDone(sctx sessionctx.Context) (bool, uint32) { + done := false + killed := sctx.GetSessionVars().SQLKiller.HandleSignal() == exeerrors.ErrQueryInterrupted + if killed { + return true, 1 + } + failpoint.Inject("BatchAddTiFlashSendDone", func(val failpoint.Value) { + done = val.(bool) + }) + return done, 0 +} + +func (e *executor) waitPendingTableThreshold(sctx sessionctx.Context, schemaID int64, tableID int64, originVersion int64, pendingCount uint32, threshold uint32) (bool, int64, uint32, bool) { + configRetry := tiflashCheckPendingTablesRetry + configWaitTime := tiflashCheckPendingTablesWaitTime + failpoint.Inject("FastFailCheckTiFlashPendingTables", func(value failpoint.Value) { + configRetry = value.(int) + configWaitTime = time.Millisecond * 200 + }) + + for retry := 0; retry < configRetry; retry++ { + done, killed := isSessionDone(sctx) + if done { + logutil.DDLLogger().Info("abort batch add TiFlash replica", zap.Int64("schemaID", schemaID), zap.Uint32("isKilled", killed)) + return true, originVersion, pendingCount, false + } + originVersion, pendingCount = e.getPendingTiFlashTableCount(originVersion, pendingCount) + delay := time.Duration(0) + if pendingCount < threshold { + // If there are not many unavailable tables, we don't need a force check. + return false, originVersion, pendingCount, false + } + logutil.DDLLogger().Info("too many unavailable tables, wait", + zap.Uint32("threshold", threshold), + zap.Uint32("currentPendingCount", pendingCount), + zap.Int64("schemaID", schemaID), + zap.Int64("tableID", tableID), + zap.Duration("time", configWaitTime)) + delay = configWaitTime + time.Sleep(delay) + } + logutil.DDLLogger().Info("too many unavailable tables, timeout", zap.Int64("schemaID", schemaID), zap.Int64("tableID", tableID)) + // If timeout here, we will trigger a ddl job, to force sync schema. However, it doesn't mean we remove limiter, + // so there is a force check immediately after that. 
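+ // Descriptive note (added for clarity): the final `true` return value asks the caller to force
+ // the next pending-table check even though the threshold wait timed out.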
+ return false, originVersion, pendingCount, true +} + +func (e *executor) ModifySchemaSetTiFlashReplica(sctx sessionctx.Context, stmt *ast.AlterDatabaseStmt, tiflashReplica *ast.TiFlashReplicaSpec) error { + dbName := stmt.Name + is := e.infoCache.GetLatest() + dbInfo, ok := is.SchemaByName(dbName) + if !ok { + return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(dbName.O) + } + + if util.IsMemOrSysDB(dbInfo.Name.L) { + return errors.Trace(dbterror.ErrUnsupportedTiFlashOperationForSysOrMemTable) + } + + tbls, err := is.SchemaTableInfos(context.Background(), dbInfo.Name) + if err != nil { + return errors.Trace(err) + } + + total := len(tbls) + succ := 0 + skip := 0 + fail := 0 + oneFail := int64(0) + + if total == 0 { + return infoschema.ErrEmptyDatabase.GenWithStack("Empty database '%v'", dbName.O) + } + err = checkTiFlashReplicaCount(sctx, tiflashReplica.Count) + if err != nil { + return errors.Trace(err) + } + + var originVersion int64 + var pendingCount uint32 + forceCheck := false + + logutil.DDLLogger().Info("start batch add TiFlash replicas", zap.Int("total", total), zap.Int64("schemaID", dbInfo.ID)) + threshold := uint32(sctx.GetSessionVars().BatchPendingTiFlashCount) + + for _, tbl := range tbls { + done, killed := isSessionDone(sctx) + if done { + logutil.DDLLogger().Info("abort batch add TiFlash replica", zap.Int64("schemaID", dbInfo.ID), zap.Uint32("isKilled", killed)) + return nil + } + + tbReplicaInfo := tbl.TiFlashReplica + if !shouldModifyTiFlashReplica(tbReplicaInfo, tiflashReplica) { + logutil.DDLLogger().Info("skip repeated processing table", + zap.Int64("tableID", tbl.ID), + zap.Int64("schemaID", dbInfo.ID), + zap.String("tableName", tbl.Name.String()), + zap.String("schemaName", dbInfo.Name.String())) + skip++ + continue + } + + // If table is not supported, add err to warnings. + err = isTableTiFlashSupported(dbName, tbl) + if err != nil { + logutil.DDLLogger().Info("skip processing table", zap.Int64("tableID", tbl.ID), + zap.Int64("schemaID", dbInfo.ID), + zap.String("tableName", tbl.Name.String()), + zap.String("schemaName", dbInfo.Name.String()), + zap.Error(err)) + sctx.GetSessionVars().StmtCtx.AppendNote(err) + skip++ + continue + } + + // Alter `tiflashCheckPendingTablesLimit` tables are handled, we need to check if we have reached threshold. + if (succ+fail)%tiflashCheckPendingTablesLimit == 0 || forceCheck { + // We can execute one probing ddl to the latest schema, if we timeout in `pendingFunc`. + // However, we shall mark `forceCheck` to true, because we may still reach `threshold`. 
+ finished := false + finished, originVersion, pendingCount, forceCheck = e.waitPendingTableThreshold(sctx, dbInfo.ID, tbl.ID, originVersion, pendingCount, threshold) + if finished { + logutil.DDLLogger().Info("abort batch add TiFlash replica", zap.Int64("schemaID", dbInfo.ID)) + return nil + } + } + + job := &model.Job{ + SchemaID: dbInfo.ID, + SchemaName: dbInfo.Name.L, + TableID: tbl.ID, + Type: model.ActionSetTiFlashReplica, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{*tiflashReplica}, + CDCWriteSource: sctx.GetSessionVars().CDCWriteSource, + InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ + Database: dbInfo.Name.L, + Table: model.InvolvingAll, + }}, + SQLMode: sctx.GetSessionVars().SQLMode, + } + err := e.DoDDLJob(sctx, job) + if err != nil { + oneFail = tbl.ID + fail++ + logutil.DDLLogger().Info("processing schema table error", + zap.Int64("tableID", tbl.ID), + zap.Int64("schemaID", dbInfo.ID), + zap.Stringer("tableName", tbl.Name), + zap.Stringer("schemaName", dbInfo.Name), + zap.Error(err)) + } else { + succ++ + } + } + failStmt := "" + if fail > 0 { + failStmt = fmt.Sprintf("(including table %v)", oneFail) + } + msg := fmt.Sprintf("In total %v tables: %v succeed, %v failed%v, %v skipped", total, succ, fail, failStmt, skip) + sctx.GetSessionVars().StmtCtx.SetMessage(msg) + logutil.DDLLogger().Info("finish batch add TiFlash replica", zap.Int64("schemaID", dbInfo.ID)) + return nil +} + +func (e *executor) AlterTablePlacement(ctx sessionctx.Context, ident ast.Ident, placementPolicyRef *model.PolicyRefInfo) (err error) { + is := e.infoCache.GetLatest() + schema, ok := is.SchemaByName(ident.Schema) + if !ok { + return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ident.Schema) + } + + tb, err := is.TableByName(e.ctx, ident.Schema, ident.Name) + if err != nil { + return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name)) + } + + if checkIgnorePlacementDDL(ctx) { + return nil + } + + tblInfo := tb.Meta() + if tblInfo.TempTableType != model.TempTableNone { + return errors.Trace(dbterror.ErrOptOnTemporaryTable.GenWithStackByArgs("placement")) + } + + placementPolicyRef, err = checkAndNormalizePlacementPolicy(ctx, placementPolicyRef) + if err != nil { + return err + } + + var involvingSchemaInfo []model.InvolvingSchemaInfo + if placementPolicyRef != nil { + involvingSchemaInfo = []model.InvolvingSchemaInfo{ + { + Database: schema.Name.L, + Table: tblInfo.Name.L, + }, + { + Policy: placementPolicyRef.Name.L, + Mode: model.SharedInvolving, + }, + } + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: tblInfo.ID, + SchemaName: schema.Name.L, + TableName: tblInfo.Name.L, + Type: model.ActionAlterTablePlacement, + BinlogInfo: &model.HistoryInfo{}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + Args: []any{placementPolicyRef}, + InvolvingSchemaInfo: involvingSchemaInfo, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +func checkMultiSchemaSpecs(_ sessionctx.Context, specs []*ast.DatabaseOption) error { + hasSetTiFlashReplica := false + if len(specs) == 1 { + return nil + } + for _, spec := range specs { + if spec.Tp == ast.DatabaseSetTiFlashReplica { + if hasSetTiFlashReplica { + return dbterror.ErrRunMultiSchemaChanges.FastGenByArgs(model.ActionSetTiFlashReplica.String()) + } + hasSetTiFlashReplica = true + } + } + return nil +} + +func (e *executor) AlterSchema(sctx sessionctx.Context, stmt *ast.AlterDatabaseStmt) (err error) { + // Resolve target charset and collation 
from options. + var ( + toCharset, toCollate string + isAlterCharsetAndCollate bool + placementPolicyRef *model.PolicyRefInfo + tiflashReplica *ast.TiFlashReplicaSpec + ) + + err = checkMultiSchemaSpecs(sctx, stmt.Options) + if err != nil { + return err + } + + for _, val := range stmt.Options { + switch val.Tp { + case ast.DatabaseOptionCharset: + if toCharset == "" { + toCharset = val.Value + } else if toCharset != val.Value { + return dbterror.ErrConflictingDeclarations.GenWithStackByArgs(toCharset, val.Value) + } + isAlterCharsetAndCollate = true + case ast.DatabaseOptionCollate: + info, errGetCollate := collate.GetCollationByName(val.Value) + if errGetCollate != nil { + return errors.Trace(errGetCollate) + } + if toCharset == "" { + toCharset = info.CharsetName + } else if toCharset != info.CharsetName { + return dbterror.ErrConflictingDeclarations.GenWithStackByArgs(toCharset, info.CharsetName) + } + toCollate = info.Name + isAlterCharsetAndCollate = true + case ast.DatabaseOptionPlacementPolicy: + placementPolicyRef = &model.PolicyRefInfo{Name: model.NewCIStr(val.Value)} + case ast.DatabaseSetTiFlashReplica: + tiflashReplica = val.TiFlashReplica + } + } + + if isAlterCharsetAndCollate { + if err = e.ModifySchemaCharsetAndCollate(sctx, stmt, toCharset, toCollate); err != nil { + return err + } + } + if placementPolicyRef != nil { + if err = e.ModifySchemaDefaultPlacement(sctx, stmt, placementPolicyRef); err != nil { + return err + } + } + if tiflashReplica != nil { + if err = e.ModifySchemaSetTiFlashReplica(sctx, stmt, tiflashReplica); err != nil { + return err + } + } + return nil +} + +func (e *executor) DropSchema(ctx sessionctx.Context, stmt *ast.DropDatabaseStmt) (err error) { + is := e.infoCache.GetLatest() + old, ok := is.SchemaByName(stmt.Name) + if !ok { + if stmt.IfExists { + return nil + } + return infoschema.ErrDatabaseDropExists.GenWithStackByArgs(stmt.Name) + } + fkCheck := ctx.GetSessionVars().ForeignKeyChecks + err = checkDatabaseHasForeignKeyReferred(e.ctx, is, old.Name, fkCheck) + if err != nil { + return err + } + job := &model.Job{ + SchemaID: old.ID, + SchemaName: old.Name.L, + SchemaState: old.State, + Type: model.ActionDropSchema, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{fkCheck}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ + Database: old.Name.L, + Table: model.InvolvingAll, + }}, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err = e.DoDDLJob(ctx, job) + if err != nil { + if infoschema.ErrDatabaseNotExists.Equal(err) { + if stmt.IfExists { + return nil + } + return infoschema.ErrDatabaseDropExists.GenWithStackByArgs(stmt.Name) + } + return errors.Trace(err) + } + if !config.TableLockEnabled() { + return nil + } + // Clear table locks hold by the session. 
+ tbs, err := is.SchemaTableInfos(e.ctx, stmt.Name) + if err != nil { + return errors.Trace(err) + } + + lockTableIDs := make([]int64, 0) + for _, tb := range tbs { + if ok, _ := ctx.CheckTableLocked(tb.ID); ok { + lockTableIDs = append(lockTableIDs, tb.ID) + } + } + ctx.ReleaseTableLockByTableIDs(lockTableIDs) + return nil +} + +func (e *executor) RecoverSchema(ctx sessionctx.Context, recoverSchemaInfo *RecoverSchemaInfo) error { + involvedSchemas := []model.InvolvingSchemaInfo{{ + Database: recoverSchemaInfo.DBInfo.Name.L, + Table: model.InvolvingAll, + }} + if recoverSchemaInfo.OldSchemaName.L != recoverSchemaInfo.DBInfo.Name.L { + involvedSchemas = append(involvedSchemas, model.InvolvingSchemaInfo{ + Database: recoverSchemaInfo.OldSchemaName.L, + Table: model.InvolvingAll, + }) + } + recoverSchemaInfo.State = model.StateNone + job := &model.Job{ + Type: model.ActionRecoverSchema, + BinlogInfo: &model.HistoryInfo{}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + Args: []any{recoverSchemaInfo, recoverCheckFlagNone}, + InvolvingSchemaInfo: involvedSchemas, + SQLMode: ctx.GetSessionVars().SQLMode, + } + err := e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +func checkTooLongSchema(schema model.CIStr) error { + if utf8.RuneCountInString(schema.L) > mysql.MaxDatabaseNameLength { + return dbterror.ErrTooLongIdent.GenWithStackByArgs(schema) + } + return nil +} + +func checkTooLongTable(table model.CIStr) error { + if utf8.RuneCountInString(table.L) > mysql.MaxTableNameLength { + return dbterror.ErrTooLongIdent.GenWithStackByArgs(table) + } + return nil +} + +func checkTooLongIndex(index model.CIStr) error { + if utf8.RuneCountInString(index.L) > mysql.MaxIndexIdentifierLen { + return dbterror.ErrTooLongIdent.GenWithStackByArgs(index) + } + return nil +} + +func checkTooLongColumn(col model.CIStr) error { + if utf8.RuneCountInString(col.L) > mysql.MaxColumnNameLength { + return dbterror.ErrTooLongIdent.GenWithStackByArgs(col) + } + return nil +} + +func checkTooLongForeignKey(fk model.CIStr) error { + if utf8.RuneCountInString(fk.L) > mysql.MaxForeignKeyIdentifierLen { + return dbterror.ErrTooLongIdent.GenWithStackByArgs(fk) + } + return nil +} + +func getDefaultCollationForUTF8MB4(sessVars *variable.SessionVars, cs string) (string, error) { + if sessVars == nil || cs != charset.CharsetUTF8MB4 { + return "", nil + } + defaultCollation, err := sessVars.GetSessionOrGlobalSystemVar(context.Background(), variable.DefaultCollationForUTF8MB4) + if err != nil { + return "", err + } + return defaultCollation, nil +} + +// GetDefaultCollation returns the default collation for charset and handle the default collation for UTF8MB4. +func GetDefaultCollation(sessVars *variable.SessionVars, cs string) (string, error) { + coll, err := getDefaultCollationForUTF8MB4(sessVars, cs) + if err != nil { + return "", errors.Trace(err) + } + if coll != "" { + return coll, nil + } + + coll, err = charset.GetDefaultCollation(cs) + if err != nil { + return "", errors.Trace(err) + } + return coll, nil +} + +// ResolveCharsetCollation will resolve the charset and collate by the order of parameters: +// * If any given ast.CharsetOpt is not empty, the resolved charset and collate will be returned. +// * If all ast.CharsetOpts are empty, the default charset and collate will be returned. 
+func ResolveCharsetCollation(sessVars *variable.SessionVars, charsetOpts ...ast.CharsetOpt) (chs string, coll string, err error) { + for _, v := range charsetOpts { + if v.Col != "" { + collation, err := collate.GetCollationByName(v.Col) + if err != nil { + return "", "", errors.Trace(err) + } + if v.Chs != "" && collation.CharsetName != v.Chs { + return "", "", charset.ErrCollationCharsetMismatch.GenWithStackByArgs(v.Col, v.Chs) + } + return collation.CharsetName, v.Col, nil + } + if v.Chs != "" { + coll, err := GetDefaultCollation(sessVars, v.Chs) + if err != nil { + return "", "", errors.Trace(err) + } + return v.Chs, coll, nil + } + } + chs, coll = charset.GetDefaultCharsetAndCollate() + utf8mb4Coll, err := getDefaultCollationForUTF8MB4(sessVars, chs) + if err != nil { + return "", "", errors.Trace(err) + } + if utf8mb4Coll != "" { + return chs, utf8mb4Coll, nil + } + return chs, coll, nil +} + +// IsAutoRandomColumnID returns true if the given column ID belongs to an auto_random column. +func IsAutoRandomColumnID(tblInfo *model.TableInfo, colID int64) bool { + if !tblInfo.ContainsAutoRandomBits() { + return false + } + if tblInfo.PKIsHandle { + return tblInfo.GetPkColInfo().ID == colID + } else if tblInfo.IsCommonHandle { + pk := tables.FindPrimaryIndex(tblInfo) + if pk == nil { + return false + } + offset := pk.Columns[0].Offset + return tblInfo.Columns[offset].ID == colID + } + return false +} + +// checkInvisibleIndexOnPK check if primary key is invisible index. +// Note: PKIsHandle == true means the table already has a visible primary key, +// we do not need do a check for this case and return directly, +// because whether primary key is invisible has been check when creating table. +func checkInvisibleIndexOnPK(tblInfo *model.TableInfo) error { + if tblInfo.PKIsHandle { + return nil + } + pk := tblInfo.GetPrimaryKey() + if pk != nil && pk.Invisible { + return dbterror.ErrPKIndexCantBeInvisible + } + return nil +} + +func (e *executor) assignPartitionIDs(defs []model.PartitionDefinition) error { + genIDs, err := e.genGlobalIDs(len(defs)) + if err != nil { + return errors.Trace(err) + } + for i := range defs { + defs[i].ID = genIDs[i] + } + return nil +} + +func (e *executor) CreateTable(ctx sessionctx.Context, s *ast.CreateTableStmt) (err error) { + ident := ast.Ident{Schema: s.Table.Schema, Name: s.Table.Name} + is := e.infoCache.GetLatest() + schema, ok := is.SchemaByName(ident.Schema) + if !ok { + return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ident.Schema) + } + + var ( + referTbl table.Table + involvingRef []model.InvolvingSchemaInfo + ) + if s.ReferTable != nil { + referIdent := ast.Ident{Schema: s.ReferTable.Schema, Name: s.ReferTable.Name} + _, ok := is.SchemaByName(referIdent.Schema) + if !ok { + return infoschema.ErrTableNotExists.GenWithStackByArgs(referIdent.Schema, referIdent.Name) + } + referTbl, err = is.TableByName(e.ctx, referIdent.Schema, referIdent.Name) + if err != nil { + return infoschema.ErrTableNotExists.GenWithStackByArgs(referIdent.Schema, referIdent.Name) + } + involvingRef = append(involvingRef, model.InvolvingSchemaInfo{ + Database: s.ReferTable.Schema.L, + Table: s.ReferTable.Name.L, + Mode: model.SharedInvolving, + }) + } + + // build tableInfo + var tbInfo *model.TableInfo + if s.ReferTable != nil { + tbInfo, err = BuildTableInfoWithLike(ctx, ident, referTbl.Meta(), s) + } else { + tbInfo, err = BuildTableInfoWithStmt(ctx, s, schema.Charset, schema.Collate, schema.PlacementPolicyRef) + } + if err != nil { + return errors.Trace(err) + } + 
+ if err = checkTableInfoValidWithStmt(ctx, tbInfo, s); err != nil { + return err + } + if err = checkTableForeignKeysValid(ctx, is, schema.Name.L, tbInfo); err != nil { + return err + } + + onExist := OnExistError + if s.IfNotExists { + onExist = OnExistIgnore + } + + return e.CreateTableWithInfo(ctx, schema.Name, tbInfo, involvingRef, WithOnExist(onExist)) +} + +// createTableWithInfoJob returns the table creation job. +// WARNING: it may return a nil job, which means you don't need to submit any DDL job. +func (e *executor) createTableWithInfoJob( + ctx sessionctx.Context, + dbName model.CIStr, + tbInfo *model.TableInfo, + involvingRef []model.InvolvingSchemaInfo, + onExist OnExist, +) (job *model.Job, err error) { + is := e.infoCache.GetLatest() + schema, ok := is.SchemaByName(dbName) + if !ok { + return nil, infoschema.ErrDatabaseNotExists.GenWithStackByArgs(dbName) + } + + if err = handleTablePlacement(ctx, tbInfo); err != nil { + return nil, errors.Trace(err) + } + + var oldViewTblID int64 + if oldTable, err := is.TableByName(e.ctx, schema.Name, tbInfo.Name); err == nil { + err = infoschema.ErrTableExists.GenWithStackByArgs(ast.Ident{Schema: schema.Name, Name: tbInfo.Name}) + switch onExist { + case OnExistIgnore: + ctx.GetSessionVars().StmtCtx.AppendNote(err) + return nil, nil + case OnExistReplace: + // only CREATE OR REPLACE VIEW is supported at the moment. + if tbInfo.View != nil { + if oldTable.Meta().IsView() { + oldViewTblID = oldTable.Meta().ID + break + } + // The object to replace isn't a view. + return nil, dbterror.ErrWrongObject.GenWithStackByArgs(dbName, tbInfo.Name, "VIEW") + } + return nil, err + default: + return nil, err + } + } + + if err := checkTableInfoValidExtra(tbInfo); err != nil { + return nil, err + } + + var actionType model.ActionType + args := []any{tbInfo} + switch { + case tbInfo.View != nil: + actionType = model.ActionCreateView + args = append(args, onExist == OnExistReplace, oldViewTblID) + case tbInfo.Sequence != nil: + actionType = model.ActionCreateSequence + default: + actionType = model.ActionCreateTable + args = append(args, ctx.GetSessionVars().ForeignKeyChecks) + } + + var involvingSchemas []model.InvolvingSchemaInfo + sharedInvolvingFromTableInfo := getSharedInvolvingSchemaInfo(tbInfo) + + if sum := len(involvingRef) + len(sharedInvolvingFromTableInfo); sum > 0 { + involvingSchemas = make([]model.InvolvingSchemaInfo, 0, sum+1) + involvingSchemas = append(involvingSchemas, model.InvolvingSchemaInfo{ + Database: schema.Name.L, + Table: tbInfo.Name.L, + }) + involvingSchemas = append(involvingSchemas, involvingRef...) + involvingSchemas = append(involvingSchemas, sharedInvolvingFromTableInfo...) 
+ } + + job = &model.Job{ + SchemaID: schema.ID, + SchemaName: schema.Name.L, + TableName: tbInfo.Name.L, + Type: actionType, + BinlogInfo: &model.HistoryInfo{}, + Args: args, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + InvolvingSchemaInfo: involvingSchemas, + SQLMode: ctx.GetSessionVars().SQLMode, + } + return job, nil +} + +func getSharedInvolvingSchemaInfo(info *model.TableInfo) []model.InvolvingSchemaInfo { + ret := make([]model.InvolvingSchemaInfo, 0, len(info.ForeignKeys)+1) + for _, fk := range info.ForeignKeys { + ret = append(ret, model.InvolvingSchemaInfo{ + Database: fk.RefSchema.L, + Table: fk.RefTable.L, + Mode: model.SharedInvolving, + }) + } + if ref := info.PlacementPolicyRef; ref != nil { + ret = append(ret, model.InvolvingSchemaInfo{ + Policy: ref.Name.L, + Mode: model.SharedInvolving, + }) + } + return ret +} + +func (e *executor) createTableWithInfoPost( + ctx sessionctx.Context, + tbInfo *model.TableInfo, + schemaID int64, +) error { + var err error + var partitions []model.PartitionDefinition + if pi := tbInfo.GetPartitionInfo(); pi != nil { + partitions = pi.Definitions + } + preSplitAndScatter(ctx, e.store, tbInfo, partitions) + if tbInfo.AutoIncID > 1 { + // Default tableAutoIncID base is 0. + // If the first ID is expected to greater than 1, we need to do rebase. + newEnd := tbInfo.AutoIncID - 1 + var allocType autoid.AllocatorType + if tbInfo.SepAutoInc() { + allocType = autoid.AutoIncrementType + } else { + allocType = autoid.RowIDAllocType + } + if err = e.handleAutoIncID(tbInfo, schemaID, newEnd, allocType); err != nil { + return errors.Trace(err) + } + } + // For issue https://github.com/pingcap/tidb/issues/46093 + if tbInfo.AutoIncIDExtra != 0 { + if err = e.handleAutoIncID(tbInfo, schemaID, tbInfo.AutoIncIDExtra-1, autoid.RowIDAllocType); err != nil { + return errors.Trace(err) + } + } + if tbInfo.AutoRandID > 1 { + // Default tableAutoRandID base is 0. + // If the first ID is expected to greater than 1, we need to do rebase. + newEnd := tbInfo.AutoRandID - 1 + err = e.handleAutoIncID(tbInfo, schemaID, newEnd, autoid.AutoRandomType) + } + return err +} + +func (e *executor) CreateTableWithInfo( + ctx sessionctx.Context, + dbName model.CIStr, + tbInfo *model.TableInfo, + involvingRef []model.InvolvingSchemaInfo, + cs ...CreateTableOption, +) (err error) { + c := GetCreateTableConfig(cs) + + job, err := e.createTableWithInfoJob( + ctx, dbName, tbInfo, involvingRef, c.OnExist, + ) + if err != nil { + return err + } + if job == nil { + return nil + } + + jobW := NewJobWrapper(job, c.IDAllocated) + + err = e.DoDDLJobWrapper(ctx, jobW) + if err != nil { + // table exists, but if_not_exists flags is true, so we ignore this error. 
+ if c.OnExist == OnExistIgnore && infoschema.ErrTableExists.Equal(err) { + ctx.GetSessionVars().StmtCtx.AppendNote(err) + err = nil + } + } else { + err = e.createTableWithInfoPost(ctx, tbInfo, job.SchemaID) + } + + return errors.Trace(err) +} + +func (e *executor) BatchCreateTableWithInfo(ctx sessionctx.Context, + dbName model.CIStr, + infos []*model.TableInfo, + cs ...CreateTableOption, +) error { + failpoint.Inject("RestoreBatchCreateTableEntryTooLarge", func(val failpoint.Value) { + injectBatchSize := val.(int) + if len(infos) > injectBatchSize { + failpoint.Return(kv.ErrEntryTooLarge) + } + }) + c := GetCreateTableConfig(cs) + + jobW := NewJobWrapper( + &model.Job{ + BinlogInfo: &model.HistoryInfo{}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + }, + c.IDAllocated, + ) + args := make([]*model.TableInfo, 0, len(infos)) + + var err error + + // check if there are any duplicated table names + duplication := make(map[string]struct{}) + // TODO filter those duplicated info out. + for _, info := range infos { + if _, ok := duplication[info.Name.L]; ok { + err = infoschema.ErrTableExists.FastGenByArgs("can not batch create tables with same name") + if c.OnExist == OnExistIgnore && infoschema.ErrTableExists.Equal(err) { + ctx.GetSessionVars().StmtCtx.AppendNote(err) + err = nil + } + } + if err != nil { + return errors.Trace(err) + } + + duplication[info.Name.L] = struct{}{} + } + + for _, info := range infos { + job, err := e.createTableWithInfoJob(ctx, dbName, info, nil, c.OnExist) + if err != nil { + return errors.Trace(err) + } + if job == nil { + continue + } + + // if jobW.Type == model.ActionCreateTables, it is initialized + // if not, initialize jobW by job.XXXX + if jobW.Type != model.ActionCreateTables { + jobW.Type = model.ActionCreateTables + jobW.SchemaID = job.SchemaID + jobW.SchemaName = job.SchemaName + } + + // append table job args + info, ok := job.Args[0].(*model.TableInfo) + if !ok { + return errors.Trace(fmt.Errorf("except table info")) + } + args = append(args, info) + jobW.InvolvingSchemaInfo = append(jobW.InvolvingSchemaInfo, model.InvolvingSchemaInfo{ + Database: dbName.L, + Table: info.Name.L, + }) + if sharedInv := getSharedInvolvingSchemaInfo(info); len(sharedInv) > 0 { + jobW.InvolvingSchemaInfo = append(jobW.InvolvingSchemaInfo, sharedInv...) + } + } + if len(args) == 0 { + return nil + } + jobW.Args = append(jobW.Args, args) + jobW.Args = append(jobW.Args, ctx.GetSessionVars().ForeignKeyChecks) + + err = e.DoDDLJobWrapper(ctx, jobW) + if err != nil { + // table exists, but if_not_exists flags is true, so we ignore this error. + if c.OnExist == OnExistIgnore && infoschema.ErrTableExists.Equal(err) { + ctx.GetSessionVars().StmtCtx.AppendNote(err) + err = nil + } + return errors.Trace(err) + } + + for j := range args { + if err = e.createTableWithInfoPost(ctx, args[j], jobW.SchemaID); err != nil { + return errors.Trace(err) + } + } + + return nil +} + +func (e *executor) CreatePlacementPolicyWithInfo(ctx sessionctx.Context, policy *model.PolicyInfo, onExist OnExist) error { + if checkIgnorePlacementDDL(ctx) { + return nil + } + + policyName := policy.Name + if policyName.L == defaultPlacementPolicyName { + return errors.Trace(infoschema.ErrReservedSyntax.GenWithStackByArgs(policyName)) + } + + // Check policy existence. 
+ _, ok := e.infoCache.GetLatest().PolicyByName(policyName) + if ok { + err := infoschema.ErrPlacementPolicyExists.GenWithStackByArgs(policyName) + switch onExist { + case OnExistIgnore: + ctx.GetSessionVars().StmtCtx.AppendNote(err) + return nil + case OnExistError: + return err + } + } + + if err := checkPolicyValidation(policy.PlacementSettings); err != nil { + return err + } + + policyID, err := e.genPlacementPolicyID() + if err != nil { + return err + } + policy.ID = policyID + + job := &model.Job{ + SchemaName: policy.Name.L, + Type: model.ActionCreatePlacementPolicy, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{policy, onExist == OnExistReplace}, + InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ + Policy: policy.Name.L, + }}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +// preSplitAndScatter performs pre-split and scatter of the table's regions. +// If `pi` is not nil, will only split region for `pi`, this is used when add partition. +func preSplitAndScatter(ctx sessionctx.Context, store kv.Storage, tbInfo *model.TableInfo, parts []model.PartitionDefinition) { + if tbInfo.TempTableType != model.TempTableNone { + return + } + sp, ok := store.(kv.SplittableStore) + if !ok || atomic.LoadUint32(&EnableSplitTableRegion) == 0 { + return + } + var ( + preSplit func() + scatterRegion bool + ) + val, err := ctx.GetSessionVars().GetGlobalSystemVar(context.Background(), variable.TiDBScatterRegion) + if err != nil { + logutil.DDLLogger().Warn("won't scatter region", zap.Error(err)) + } else { + scatterRegion = variable.TiDBOptOn(val) + } + if len(parts) > 0 { + preSplit = func() { splitPartitionTableRegion(ctx, sp, tbInfo, parts, scatterRegion) } + } else { + preSplit = func() { splitTableRegion(ctx, sp, tbInfo, scatterRegion) } + } + if scatterRegion { + preSplit() + } else { + go preSplit() + } +} + +func (e *executor) FlashbackCluster(ctx sessionctx.Context, flashbackTS uint64) error { + logutil.DDLLogger().Info("get flashback cluster job", zap.Stringer("flashbackTS", oracle.GetTimeFromTS(flashbackTS))) + nowTS, err := ctx.GetStore().GetOracle().GetTimestamp(e.ctx, &oracle.Option{}) + if err != nil { + return errors.Trace(err) + } + gap := time.Until(oracle.GetTimeFromTS(nowTS)).Abs() + if gap > 1*time.Second { + ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("Gap between local time and PD TSO is %s, please check PD/system time", gap)) + } + job := &model.Job{ + Type: model.ActionFlashbackCluster, + BinlogInfo: &model.HistoryInfo{}, + // The value for global variables is meaningless, it will cover during flashback cluster. + Args: []any{ + flashbackTS, + map[string]any{}, + true, /* tidb_gc_enable */ + variable.On, /* tidb_enable_auto_analyze */ + variable.Off, /* tidb_super_read_only */ + 0, /* totalRegions */ + 0, /* startTS */ + 0, /* commitTS */ + variable.On, /* tidb_ttl_job_enable */ + []kv.KeyRange{} /* flashback key_ranges */}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + // FLASHBACK CLUSTER affects all schemas and tables. 
+ InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ + Database: model.InvolvingAll, + Table: model.InvolvingAll, + }}, + SQLMode: ctx.GetSessionVars().SQLMode, + } + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +func (e *executor) RecoverTable(ctx sessionctx.Context, recoverInfo *RecoverInfo) (err error) { + is := e.infoCache.GetLatest() + schemaID, tbInfo := recoverInfo.SchemaID, recoverInfo.TableInfo + // Check schema exist. + schema, ok := is.SchemaByID(schemaID) + if !ok { + return errors.Trace(infoschema.ErrDatabaseNotExists.GenWithStackByArgs( + fmt.Sprintf("(Schema ID %d)", schemaID), + )) + } + // Check not exist table with same name. + if ok := is.TableExists(schema.Name, tbInfo.Name); ok { + return infoschema.ErrTableExists.GenWithStackByArgs(tbInfo.Name) + } + + // for "flashback table xxx to yyy" + // Note: this case only allow change table name, schema remains the same. + var involvedSchemas []model.InvolvingSchemaInfo + if recoverInfo.OldTableName != tbInfo.Name.L { + involvedSchemas = []model.InvolvingSchemaInfo{ + {Database: schema.Name.L, Table: recoverInfo.OldTableName}, + {Database: schema.Name.L, Table: tbInfo.Name.L}, + } + } + + tbInfo.State = model.StateNone + job := &model.Job{ + SchemaID: schemaID, + TableID: tbInfo.ID, + SchemaName: schema.Name.L, + TableName: tbInfo.Name.L, + + Type: model.ActionRecoverTable, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{recoverInfo, recoverCheckFlagNone}, + InvolvingSchemaInfo: involvedSchemas, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +func (e *executor) CreateView(ctx sessionctx.Context, s *ast.CreateViewStmt) (err error) { + viewInfo, err := BuildViewInfo(s) + if err != nil { + return err + } + + cols := make([]*table.Column, len(s.Cols)) + for i, v := range s.Cols { + cols[i] = table.ToColumn(&model.ColumnInfo{ + Name: v, + ID: int64(i), + Offset: i, + State: model.StatePublic, + }) + } + + tblCharset := "" + tblCollate := "" + if v, ok := ctx.GetSessionVars().GetSystemVar(variable.CharacterSetConnection); ok { + tblCharset = v + } + if v, ok := ctx.GetSessionVars().GetSystemVar(variable.CollationConnection); ok { + tblCollate = v + } + + tbInfo, err := BuildTableInfo(ctx, s.ViewName.Name, cols, nil, tblCharset, tblCollate) + if err != nil { + return err + } + tbInfo.View = viewInfo + + onExist := OnExistError + if s.OrReplace { + onExist = OnExistReplace + } + + return e.CreateTableWithInfo(ctx, s.ViewName.Schema, tbInfo, nil, WithOnExist(onExist)) +} + +func checkCharsetAndCollation(cs string, co string) error { + if !charset.ValidCharsetAndCollation(cs, co) { + return dbterror.ErrUnknownCharacterSet.GenWithStackByArgs(cs) + } + if co != "" { + if _, err := collate.GetCollationByName(co); err != nil { + return errors.Trace(err) + } + } + return nil +} + +// handleAutoIncID handles auto_increment option in DDL. It creates a ID counter for the table and initiates the counter to a proper value. +// For example if the option sets auto_increment to 10. The counter will be set to 9. So the next allocated ID will be 10. 
+func (e *executor) handleAutoIncID(tbInfo *model.TableInfo, schemaID int64, newEnd int64, tp autoid.AllocatorType) error {
+	allocs := autoid.NewAllocatorsFromTblInfo(e.getAutoIDRequirement(), schemaID, tbInfo)
+	if alloc := allocs.Get(tp); alloc != nil {
+		err := alloc.Rebase(context.Background(), newEnd, false)
+		if err != nil {
+			return errors.Trace(err)
+		}
+	}
+	return nil
+}
+
+// TODO we can unify this part with ddlCtx.
+func (e *executor) getAutoIDRequirement() autoid.Requirement {
+	return &asAutoIDRequirement{
+		store:     e.store,
+		autoidCli: e.autoidCli,
+	}
+}
+
+func shardingBits(tblInfo *model.TableInfo) uint64 {
+	if tblInfo.ShardRowIDBits > 0 {
+		return tblInfo.ShardRowIDBits
+	}
+	return tblInfo.AutoRandomBits
+}
+
+// isIgnorableSpec checks if the spec type is ignorable.
+// Some specs are parsed but ignored. This is for compatibility.
+func isIgnorableSpec(tp ast.AlterTableType) bool {
+	// AlterTableLock/AlterTableAlgorithm are ignored.
+	return tp == ast.AlterTableLock || tp == ast.AlterTableAlgorithm
+}
+
+// GetCharsetAndCollateInTableOption will iterate the charset and collate in the options,
+// and returns the last charset and collate in options. If there is no charset in the options,
+// the returned charset will be "", the same as collate.
+func GetCharsetAndCollateInTableOption(sessVars *variable.SessionVars, startIdx int, options []*ast.TableOption) (chs, coll string, err error) {
+	for i := startIdx; i < len(options); i++ {
+		opt := options[i]
+		// we set the charset to the last option. example: alter table t charset latin1 charset utf8 collate utf8_bin;
+		// the charset will be utf8, collate will be utf8_bin
+		switch opt.Tp {
+		case ast.TableOptionCharset:
+			info, err := charset.GetCharsetInfo(opt.StrValue)
+			if err != nil {
+				return "", "", err
+			}
+			if len(chs) == 0 {
+				chs = info.Name
+			} else if chs != info.Name {
+				return "", "", dbterror.ErrConflictingDeclarations.GenWithStackByArgs(chs, info.Name)
+			}
+			if len(coll) == 0 {
+				defaultColl, err := getDefaultCollationForUTF8MB4(sessVars, chs)
+				if err != nil {
+					return "", "", errors.Trace(err)
+				}
+				if len(defaultColl) == 0 {
+					coll = info.DefaultCollation
+				} else {
+					coll = defaultColl
+				}
+			}
+		case ast.TableOptionCollate:
+			info, err := collate.GetCollationByName(opt.StrValue)
+			if err != nil {
+				return "", "", err
+			}
+			if len(chs) == 0 {
+				chs = info.CharsetName
+			} else if chs != info.CharsetName {
+				return "", "", dbterror.ErrCollationCharsetMismatch.GenWithStackByArgs(info.Name, chs)
+			}
+			coll = info.Name
+		}
+	}
+	return
+}
+
+// NeedToOverwriteColCharset returns true for altering charset and specified CONVERT TO.
+func NeedToOverwriteColCharset(options []*ast.TableOption) bool {
+	for i := len(options) - 1; i >= 0; i-- {
+		opt := options[i]
+		if opt.Tp == ast.TableOptionCharset {
+			// Only overwrite columns charset if the option contains `CONVERT TO`.
+			return opt.UintValue == ast.TableOptionCharsetWithConvertTo
+		}
+	}
+	return false
+}
+
+// resolveAlterTableAddColumns splits "add columns" into multiple specs. For example,
+// `ALTER TABLE ADD COLUMN (c1 INT, c2 INT)` is split into
+// `ALTER TABLE ADD COLUMN c1 INT, ADD COLUMN c2 INT`.
+func resolveAlterTableAddColumns(spec *ast.AlterTableSpec) []*ast.AlterTableSpec {
+	specs := make([]*ast.AlterTableSpec, 0, len(spec.NewColumns)+len(spec.NewConstraints))
+	for _, col := range spec.NewColumns {
+		t := *spec
+		t.NewColumns = []*ast.ColumnDef{col}
+		t.NewConstraints = []*ast.Constraint{}
+		specs = append(specs, &t)
+	}
+	// Split the add constraints from AlterTableSpec.
+	for _, con := range spec.NewConstraints {
+		t := *spec
+		t.NewColumns = []*ast.ColumnDef{}
+		t.NewConstraints = []*ast.Constraint{}
+		t.Constraint = con
+		t.Tp = ast.AlterTableAddConstraint
+		specs = append(specs, &t)
+	}
+	return specs
+}
+
+// ResolveAlterTableSpec resolves alter table algorithm and removes ignorable table specs in specs.
+// returns valid specs, and the error that occurred.
+func ResolveAlterTableSpec(ctx sessionctx.Context, specs []*ast.AlterTableSpec) ([]*ast.AlterTableSpec, error) {
+	validSpecs := make([]*ast.AlterTableSpec, 0, len(specs))
+	algorithm := ast.AlgorithmTypeDefault
+	for _, spec := range specs {
+		if spec.Tp == ast.AlterTableAlgorithm {
+			// Find the last AlterTableAlgorithm.
+			algorithm = spec.Algorithm
+		}
+		if isIgnorableSpec(spec.Tp) {
+			continue
+		}
+		if spec.Tp == ast.AlterTableAddColumns && (len(spec.NewColumns) > 1 || len(spec.NewConstraints) > 0) {
+			validSpecs = append(validSpecs, resolveAlterTableAddColumns(spec)...)
+		} else {
+			validSpecs = append(validSpecs, spec)
+		}
+		// TODO: Only allow REMOVE PARTITIONING as a single ALTER TABLE statement?
+	}
+
+	// Verify whether the algorithm is supported.
+	for _, spec := range validSpecs {
+		resolvedAlgorithm, err := ResolveAlterAlgorithm(spec, algorithm)
+		if err != nil {
+			// If TiDB failed to choose a better algorithm, report the error
+			if resolvedAlgorithm == ast.AlgorithmTypeDefault {
+				return nil, errors.Trace(err)
+			}
+			// For compatibility, we return a warning instead of an error when a better algorithm is chosen by TiDB
+			ctx.GetSessionVars().StmtCtx.AppendError(err)
+		}
+
+		spec.Algorithm = resolvedAlgorithm
+	}
+
+	// Only handle valid specs.
+ return validSpecs, nil +} + +func isMultiSchemaChanges(specs []*ast.AlterTableSpec) bool { + if len(specs) > 1 { + return true + } + if len(specs) == 1 && len(specs[0].NewColumns) > 1 && specs[0].Tp == ast.AlterTableAddColumns { + return true + } + return false +} + +func (e *executor) AlterTable(ctx context.Context, sctx sessionctx.Context, stmt *ast.AlterTableStmt) (err error) { + ident := ast.Ident{Schema: stmt.Table.Schema, Name: stmt.Table.Name} + validSpecs, err := ResolveAlterTableSpec(sctx, stmt.Specs) + if err != nil { + return errors.Trace(err) + } + + is := e.infoCache.GetLatest() + tb, err := is.TableByName(ctx, ident.Schema, ident.Name) + if err != nil { + return errors.Trace(err) + } + if tb.Meta().IsView() || tb.Meta().IsSequence() { + return dbterror.ErrWrongObject.GenWithStackByArgs(ident.Schema, ident.Name, "BASE TABLE") + } + if tb.Meta().TableCacheStatusType != model.TableCacheStatusDisable { + if len(validSpecs) != 1 { + return dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Alter Table") + } + if validSpecs[0].Tp != ast.AlterTableCache && validSpecs[0].Tp != ast.AlterTableNoCache { + return dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Alter Table") + } + } + if isMultiSchemaChanges(validSpecs) && (sctx.GetSessionVars().EnableRowLevelChecksum || variable.EnableRowLevelChecksum.Load()) { + return dbterror.ErrRunMultiSchemaChanges.GenWithStack("Unsupported multi schema change when row level checksum is enabled") + } + // set name for anonymous foreign key. + maxForeignKeyID := tb.Meta().MaxForeignKeyID + for _, spec := range validSpecs { + if spec.Tp == ast.AlterTableAddConstraint && spec.Constraint.Tp == ast.ConstraintForeignKey && spec.Constraint.Name == "" { + maxForeignKeyID++ + spec.Constraint.Name = fmt.Sprintf("fk_%d", maxForeignKeyID) + } + } + + if len(validSpecs) > 1 { + // after MultiSchemaInfo is set, DoDDLJob will collect all jobs into + // MultiSchemaInfo and skip running them. Then we will run them in + // d.multiSchemaChange all at once. 
+ sctx.GetSessionVars().StmtCtx.MultiSchemaInfo = model.NewMultiSchemaInfo() + } + for _, spec := range validSpecs { + var handledCharsetOrCollate bool + var ttlOptionsHandled bool + switch spec.Tp { + case ast.AlterTableAddColumns: + err = e.AddColumn(sctx, ident, spec) + case ast.AlterTableAddPartitions, ast.AlterTableAddLastPartition: + err = e.AddTablePartitions(sctx, ident, spec) + case ast.AlterTableCoalescePartitions: + err = e.CoalescePartitions(sctx, ident, spec) + case ast.AlterTableReorganizePartition: + err = e.ReorganizePartitions(sctx, ident, spec) + case ast.AlterTableReorganizeFirstPartition: + err = dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("MERGE FIRST PARTITION") + case ast.AlterTableReorganizeLastPartition: + err = dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("SPLIT LAST PARTITION") + case ast.AlterTableCheckPartitions: + err = errors.Trace(dbterror.ErrUnsupportedCheckPartition) + case ast.AlterTableRebuildPartition: + err = errors.Trace(dbterror.ErrUnsupportedRebuildPartition) + case ast.AlterTableOptimizePartition: + err = errors.Trace(dbterror.ErrUnsupportedOptimizePartition) + case ast.AlterTableRemovePartitioning: + err = e.RemovePartitioning(sctx, ident, spec) + case ast.AlterTableRepairPartition: + err = errors.Trace(dbterror.ErrUnsupportedRepairPartition) + case ast.AlterTableDropColumn: + err = e.DropColumn(sctx, ident, spec) + case ast.AlterTableDropIndex: + err = e.dropIndex(sctx, ident, model.NewCIStr(spec.Name), spec.IfExists, false) + case ast.AlterTableDropPrimaryKey: + err = e.dropIndex(sctx, ident, model.NewCIStr(mysql.PrimaryKeyName), spec.IfExists, false) + case ast.AlterTableRenameIndex: + err = e.RenameIndex(sctx, ident, spec) + case ast.AlterTableDropPartition, ast.AlterTableDropFirstPartition: + err = e.DropTablePartition(sctx, ident, spec) + case ast.AlterTableTruncatePartition: + err = e.TruncateTablePartition(sctx, ident, spec) + case ast.AlterTableWriteable: + if !config.TableLockEnabled() { + return nil + } + tName := &ast.TableName{Schema: ident.Schema, Name: ident.Name} + if spec.Writeable { + err = e.CleanupTableLock(sctx, []*ast.TableName{tName}) + } else { + lockStmt := &ast.LockTablesStmt{ + TableLocks: []ast.TableLock{ + { + Table: tName, + Type: model.TableLockReadOnly, + }, + }, + } + err = e.LockTables(sctx, lockStmt) + } + case ast.AlterTableExchangePartition: + err = e.ExchangeTablePartition(sctx, ident, spec) + case ast.AlterTableAddConstraint: + constr := spec.Constraint + switch spec.Constraint.Tp { + case ast.ConstraintKey, ast.ConstraintIndex: + err = e.createIndex(sctx, ident, ast.IndexKeyTypeNone, model.NewCIStr(constr.Name), + spec.Constraint.Keys, constr.Option, constr.IfNotExists) + case ast.ConstraintUniq, ast.ConstraintUniqIndex, ast.ConstraintUniqKey: + err = e.createIndex(sctx, ident, ast.IndexKeyTypeUnique, model.NewCIStr(constr.Name), + spec.Constraint.Keys, constr.Option, false) // IfNotExists should be not applied + case ast.ConstraintForeignKey: + // NOTE: we do not handle `symbol` and `index_name` well in the parser and we do not check ForeignKey already exists, + // so we just also ignore the `if not exists` check. 
+ err = e.CreateForeignKey(sctx, ident, model.NewCIStr(constr.Name), spec.Constraint.Keys, spec.Constraint.Refer) + case ast.ConstraintPrimaryKey: + err = e.CreatePrimaryKey(sctx, ident, model.NewCIStr(constr.Name), spec.Constraint.Keys, constr.Option) + case ast.ConstraintFulltext: + sctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrTableCantHandleFt) + case ast.ConstraintCheck: + if !variable.EnableCheckConstraint.Load() { + sctx.GetSessionVars().StmtCtx.AppendWarning(errCheckConstraintIsOff) + } else { + err = e.CreateCheckConstraint(sctx, ident, model.NewCIStr(constr.Name), spec.Constraint) + } + default: + // Nothing to do now. + } + case ast.AlterTableDropForeignKey: + // NOTE: we do not check `if not exists` and `if exists` for ForeignKey now. + err = e.DropForeignKey(sctx, ident, model.NewCIStr(spec.Name)) + case ast.AlterTableModifyColumn: + err = e.ModifyColumn(ctx, sctx, ident, spec) + case ast.AlterTableChangeColumn: + err = e.ChangeColumn(ctx, sctx, ident, spec) + case ast.AlterTableRenameColumn: + err = e.RenameColumn(sctx, ident, spec) + case ast.AlterTableAlterColumn: + err = e.AlterColumn(sctx, ident, spec) + case ast.AlterTableRenameTable: + newIdent := ast.Ident{Schema: spec.NewTable.Schema, Name: spec.NewTable.Name} + isAlterTable := true + err = e.renameTable(sctx, ident, newIdent, isAlterTable) + case ast.AlterTablePartition: + err = e.AlterTablePartitioning(sctx, ident, spec) + case ast.AlterTableOption: + var placementPolicyRef *model.PolicyRefInfo + for i, opt := range spec.Options { + switch opt.Tp { + case ast.TableOptionShardRowID: + if opt.UintValue > shardRowIDBitsMax { + opt.UintValue = shardRowIDBitsMax + } + err = e.ShardRowID(sctx, ident, opt.UintValue) + case ast.TableOptionAutoIncrement: + err = e.RebaseAutoID(sctx, ident, int64(opt.UintValue), autoid.AutoIncrementType, opt.BoolValue) + case ast.TableOptionAutoIdCache: + if opt.UintValue > uint64(math.MaxInt64) { + // TODO: Refine this error. + return errors.New("table option auto_id_cache overflows int64") + } + err = e.AlterTableAutoIDCache(sctx, ident, int64(opt.UintValue)) + case ast.TableOptionAutoRandomBase: + err = e.RebaseAutoID(sctx, ident, int64(opt.UintValue), autoid.AutoRandomType, opt.BoolValue) + case ast.TableOptionComment: + spec.Comment = opt.StrValue + err = e.AlterTableComment(sctx, ident, spec) + case ast.TableOptionCharset, ast.TableOptionCollate: + // GetCharsetAndCollateInTableOption will get the last charset and collate in the options, + // so it should be handled only once. 
+ if handledCharsetOrCollate { + continue + } + var toCharset, toCollate string + toCharset, toCollate, err = GetCharsetAndCollateInTableOption(sctx.GetSessionVars(), i, spec.Options) + if err != nil { + return err + } + needsOverwriteCols := NeedToOverwriteColCharset(spec.Options) + err = e.AlterTableCharsetAndCollate(sctx, ident, toCharset, toCollate, needsOverwriteCols) + handledCharsetOrCollate = true + case ast.TableOptionPlacementPolicy: + placementPolicyRef = &model.PolicyRefInfo{ + Name: model.NewCIStr(opt.StrValue), + } + case ast.TableOptionEngine: + case ast.TableOptionRowFormat: + case ast.TableOptionTTL, ast.TableOptionTTLEnable, ast.TableOptionTTLJobInterval: + var ttlInfo *model.TTLInfo + var ttlEnable *bool + var ttlJobInterval *string + + if ttlOptionsHandled { + continue + } + ttlInfo, ttlEnable, ttlJobInterval, err = getTTLInfoInOptions(spec.Options) + if err != nil { + return err + } + err = e.AlterTableTTLInfoOrEnable(sctx, ident, ttlInfo, ttlEnable, ttlJobInterval) + + ttlOptionsHandled = true + default: + err = dbterror.ErrUnsupportedAlterTableOption + } + + if err != nil { + return errors.Trace(err) + } + } + + if placementPolicyRef != nil { + err = e.AlterTablePlacement(sctx, ident, placementPolicyRef) + } + case ast.AlterTableSetTiFlashReplica: + err = e.AlterTableSetTiFlashReplica(sctx, ident, spec.TiFlashReplica) + case ast.AlterTableOrderByColumns: + err = e.OrderByColumns(sctx, ident) + case ast.AlterTableIndexInvisible: + err = e.AlterIndexVisibility(sctx, ident, spec.IndexName, spec.Visibility) + case ast.AlterTableAlterCheck: + if !variable.EnableCheckConstraint.Load() { + sctx.GetSessionVars().StmtCtx.AppendWarning(errCheckConstraintIsOff) + } else { + err = e.AlterCheckConstraint(sctx, ident, model.NewCIStr(spec.Constraint.Name), spec.Constraint.Enforced) + } + case ast.AlterTableDropCheck: + if !variable.EnableCheckConstraint.Load() { + sctx.GetSessionVars().StmtCtx.AppendWarning(errCheckConstraintIsOff) + } else { + err = e.DropCheckConstraint(sctx, ident, model.NewCIStr(spec.Constraint.Name)) + } + case ast.AlterTableWithValidation: + sctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedAlterTableWithValidation) + case ast.AlterTableWithoutValidation: + sctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedAlterTableWithoutValidation) + case ast.AlterTableAddStatistics: + err = e.AlterTableAddStatistics(sctx, ident, spec.Statistics, spec.IfNotExists) + case ast.AlterTableDropStatistics: + err = e.AlterTableDropStatistics(sctx, ident, spec.Statistics, spec.IfExists) + case ast.AlterTableAttributes: + err = e.AlterTableAttributes(sctx, ident, spec) + case ast.AlterTablePartitionAttributes: + err = e.AlterTablePartitionAttributes(sctx, ident, spec) + case ast.AlterTablePartitionOptions: + err = e.AlterTablePartitionOptions(sctx, ident, spec) + case ast.AlterTableCache: + err = e.AlterTableCache(sctx, ident) + case ast.AlterTableNoCache: + err = e.AlterTableNoCache(sctx, ident) + case ast.AlterTableDisableKeys, ast.AlterTableEnableKeys: + // Nothing to do now, see https://github.com/pingcap/tidb/issues/1051 + // MyISAM specific + case ast.AlterTableRemoveTTL: + // the parser makes sure we have only one `ast.AlterTableRemoveTTL` in an alter statement + err = e.AlterTableRemoveTTL(sctx, ident) + default: + err = errors.Trace(dbterror.ErrUnsupportedAlterTableSpec) + } + + if err != nil { + return errors.Trace(err) + } + } + + if sctx.GetSessionVars().StmtCtx.MultiSchemaInfo != nil { + info := 
sctx.GetSessionVars().StmtCtx.MultiSchemaInfo + sctx.GetSessionVars().StmtCtx.MultiSchemaInfo = nil + err = e.multiSchemaChange(sctx, ident, info) + if err != nil { + return errors.Trace(err) + } + } + + return nil +} + +func (e *executor) multiSchemaChange(ctx sessionctx.Context, ti ast.Ident, info *model.MultiSchemaInfo) error { + subJobs := info.SubJobs + if len(subJobs) == 0 { + return nil + } + schema, t, err := e.getSchemaAndTableByIdent(ti) + if err != nil { + return errors.Trace(err) + } + + logFn := logutil.DDLLogger().Warn + if intest.InTest { + logFn = logutil.DDLLogger().Fatal + } + + var involvingSchemaInfo []model.InvolvingSchemaInfo + for _, j := range subJobs { + switch j.Type { + case model.ActionAlterTablePlacement: + ref, ok := j.Args[0].(*model.PolicyRefInfo) + if !ok { + logFn("unexpected type of policy reference info", + zap.Any("args[0]", j.Args[0]), + zap.String("type", fmt.Sprintf("%T", j.Args[0]))) + continue + } + if ref == nil { + continue + } + involvingSchemaInfo = append(involvingSchemaInfo, model.InvolvingSchemaInfo{ + Policy: ref.Name.L, + Mode: model.SharedInvolving, + }) + case model.ActionAddForeignKey: + ref, ok := j.Args[0].(*model.FKInfo) + if !ok { + logFn("unexpected type of foreign key info", + zap.Any("args[0]", j.Args[0]), + zap.String("type", fmt.Sprintf("%T", j.Args[0]))) + continue + } + involvingSchemaInfo = append(involvingSchemaInfo, model.InvolvingSchemaInfo{ + Database: ref.RefSchema.L, + Table: ref.RefTable.L, + Mode: model.SharedInvolving, + }) + case model.ActionAlterTablePartitionPlacement: + if len(j.Args) < 2 { + logFn("unexpected number of arguments for partition placement", + zap.Int("len(args)", len(j.Args)), + zap.Any("args", j.Args)) + continue + } + ref, ok := j.Args[1].(*model.PolicyRefInfo) + if !ok { + logFn("unexpected type of policy reference info", + zap.Any("args[0]", j.Args[0]), + zap.String("type", fmt.Sprintf("%T", j.Args[0]))) + continue + } + if ref == nil { + continue + } + involvingSchemaInfo = append(involvingSchemaInfo, model.InvolvingSchemaInfo{ + Policy: ref.Name.L, + Mode: model.SharedInvolving, + }) + } + } + + if len(involvingSchemaInfo) > 0 { + involvingSchemaInfo = append(involvingSchemaInfo, model.InvolvingSchemaInfo{ + Database: schema.Name.L, + Table: t.Meta().Name.L, + }) + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: t.Meta().ID, + SchemaName: schema.Name.L, + TableName: t.Meta().Name.L, + Type: model.ActionMultiSchemaChange, + BinlogInfo: &model.HistoryInfo{}, + Args: nil, + MultiSchemaInfo: info, + ReorgMeta: nil, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + InvolvingSchemaInfo: involvingSchemaInfo, + SQLMode: ctx.GetSessionVars().SQLMode, + } + if containsDistTaskSubJob(subJobs) { + job.ReorgMeta, err = newReorgMetaFromVariables(job, ctx) + if err != nil { + return err + } + } else { + job.ReorgMeta = NewDDLReorgMeta(ctx) + } + + err = checkMultiSchemaInfo(info, t) + if err != nil { + return errors.Trace(err) + } + mergeAddIndex(info) + return e.DoDDLJob(ctx, job) +} + +func containsDistTaskSubJob(subJobs []*model.SubJob) bool { + for _, sub := range subJobs { + if sub.Type == model.ActionAddIndex || + sub.Type == model.ActionAddPrimaryKey { + return true + } + } + return false +} + +func (e *executor) RebaseAutoID(ctx sessionctx.Context, ident ast.Ident, newBase int64, tp autoid.AllocatorType, force bool) error { + schema, t, err := e.getSchemaAndTableByIdent(ident) + if err != nil { + return errors.Trace(err) + } + tbInfo := t.Meta() + var actionType model.ActionType + 
switch tp { + case autoid.AutoRandomType: + pkCol := tbInfo.GetPkColInfo() + if tbInfo.AutoRandomBits == 0 || pkCol == nil { + return errors.Trace(dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomRebaseNotApplicable)) + } + shardFmt := autoid.NewShardIDFormat(&pkCol.FieldType, tbInfo.AutoRandomBits, tbInfo.AutoRandomRangeBits) + if shardFmt.IncrementalMask()&newBase != newBase { + errMsg := fmt.Sprintf(autoid.AutoRandomRebaseOverflow, newBase, shardFmt.IncrementalBitsCapacity()) + return errors.Trace(dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(errMsg)) + } + actionType = model.ActionRebaseAutoRandomBase + case autoid.RowIDAllocType: + actionType = model.ActionRebaseAutoID + case autoid.AutoIncrementType: + actionType = model.ActionRebaseAutoID + default: + panic(fmt.Sprintf("unimplemented rebase autoid type %s", tp)) + } + + if !force { + newBaseTemp, err := adjustNewBaseToNextGlobalID(ctx.GetTableCtx(), t, tp, newBase) + if err != nil { + return err + } + if newBase != newBaseTemp { + ctx.GetSessionVars().StmtCtx.AppendWarning( + errors.NewNoStackErrorf("Can't reset AUTO_INCREMENT to %d without FORCE option, using %d instead", + newBase, newBaseTemp, + )) + } + newBase = newBaseTemp + } + job := &model.Job{ + SchemaID: schema.ID, + TableID: tbInfo.ID, + SchemaName: schema.Name.L, + TableName: tbInfo.Name.L, + Type: actionType, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{newBase, force}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +func adjustNewBaseToNextGlobalID(ctx table.AllocatorContext, t table.Table, tp autoid.AllocatorType, newBase int64) (int64, error) { + alloc := t.Allocators(ctx).Get(tp) + if alloc == nil { + return newBase, nil + } + autoID, err := alloc.NextGlobalAutoID() + if err != nil { + return newBase, errors.Trace(err) + } + // If newBase < autoID, we need to do a rebase before returning. + // Assume there are 2 TiDB servers: TiDB-A with allocator range of 0 ~ 30000; TiDB-B with allocator range of 30001 ~ 60000. + // If the user sends SQL `alter table t1 auto_increment = 100` to TiDB-B, + // and TiDB-B finds 100 < 30001 but returns without any handling, + // then TiDB-A may still allocate 99 for auto_increment column. This doesn't make sense for the user. + return int64(mathutil.Max(uint64(newBase), uint64(autoID))), nil +} + +// ShardRowID shards the implicit row ID by adding shard value to the row ID's first few bits. +func (e *executor) ShardRowID(ctx sessionctx.Context, tableIdent ast.Ident, uVal uint64) error { + schema, t, err := e.getSchemaAndTableByIdent(tableIdent) + if err != nil { + return errors.Trace(err) + } + tbInfo := t.Meta() + if tbInfo.TempTableType != model.TempTableNone { + return dbterror.ErrOptOnTemporaryTable.GenWithStackByArgs("shard_row_id_bits") + } + if uVal == tbInfo.ShardRowIDBits { + // Nothing need to do. 
+ return nil + } + if uVal > 0 && tbInfo.HasClusteredIndex() { + return dbterror.ErrUnsupportedShardRowIDBits + } + err = verifyNoOverflowShardBits(e.sessPool, t, uVal) + if err != nil { + return err + } + job := &model.Job{ + Type: model.ActionShardRowID, + SchemaID: schema.ID, + TableID: tbInfo.ID, + SchemaName: schema.Name.L, + TableName: tbInfo.Name.L, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{uVal}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +func (e *executor) getSchemaAndTableByIdent(tableIdent ast.Ident) (dbInfo *model.DBInfo, t table.Table, err error) { + is := e.infoCache.GetLatest() + schema, ok := is.SchemaByName(tableIdent.Schema) + if !ok { + return nil, nil, infoschema.ErrDatabaseNotExists.GenWithStackByArgs(tableIdent.Schema) + } + t, err = is.TableByName(e.ctx, tableIdent.Schema, tableIdent.Name) + if err != nil { + return nil, nil, infoschema.ErrTableNotExists.GenWithStackByArgs(tableIdent.Schema, tableIdent.Name) + } + return schema, t, nil +} + +// AddColumn will add a new column to the table. +func (e *executor) AddColumn(ctx sessionctx.Context, ti ast.Ident, spec *ast.AlterTableSpec) error { + specNewColumn := spec.NewColumns[0] + schema, t, err := e.getSchemaAndTableByIdent(ti) + if err != nil { + return errors.Trace(err) + } + failpoint.InjectCall("afterGetSchemaAndTableByIdent", ctx) + tbInfo := t.Meta() + if err = checkAddColumnTooManyColumns(len(t.Cols()) + 1); err != nil { + return errors.Trace(err) + } + col, err := checkAndCreateNewColumn(ctx, ti, schema, spec, t, specNewColumn) + if err != nil { + return errors.Trace(err) + } + // Added column has existed and if_not_exists flag is true. + if col == nil { + return nil + } + err = CheckAfterPositionExists(tbInfo, spec.Position) + if err != nil { + return errors.Trace(err) + } + + txn, err := ctx.Txn(true) + if err != nil { + return errors.Trace(err) + } + bdrRole, err := meta.NewMeta(txn).GetBDRRole() + if err != nil { + return errors.Trace(err) + } + if bdrRole == string(ast.BDRRolePrimary) && deniedByBDRWhenAddColumn(specNewColumn.Options) { + return dbterror.ErrBDRRestrictedDDL.FastGenByArgs(bdrRole) + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: tbInfo.ID, + SchemaName: schema.Name.L, + TableName: tbInfo.Name.L, + Type: model.ActionAddColumn, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{col, spec.Position, 0, spec.IfNotExists}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +// AddTablePartitions will add a new partition to the table. +func (e *executor) AddTablePartitions(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { + is := e.infoCache.GetLatest() + schema, ok := is.SchemaByName(ident.Schema) + if !ok { + return errors.Trace(infoschema.ErrDatabaseNotExists.GenWithStackByArgs(schema)) + } + t, err := is.TableByName(e.ctx, ident.Schema, ident.Name) + if err != nil { + return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name)) + } + + meta := t.Meta() + pi := meta.GetPartitionInfo() + if pi == nil { + return errors.Trace(dbterror.ErrPartitionMgmtOnNonpartitioned) + } + if pi.Type == model.PartitionTypeHash || pi.Type == model.PartitionTypeKey { + // Add partition for hash/key is actually a reorganize partition + // operation and not a metadata only change! 
+ switch spec.Tp { + case ast.AlterTableAddLastPartition: + return errors.Trace(dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("LAST PARTITION of HASH/KEY partitioned table")) + case ast.AlterTableAddPartitions: + // only thing supported + default: + return errors.Trace(dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("ADD PARTITION of HASH/KEY partitioned table")) + } + return e.hashPartitionManagement(ctx, ident, spec, pi) + } + + partInfo, err := BuildAddedPartitionInfo(ctx.GetExprCtx(), meta, spec) + if err != nil { + return errors.Trace(err) + } + if pi.Type == model.PartitionTypeList { + // TODO: make sure that checks in ddl_api and ddl_worker is the same. + err = checkAddListPartitions(meta) + if err != nil { + return errors.Trace(err) + } + } + if err := e.assignPartitionIDs(partInfo.Definitions); err != nil { + return errors.Trace(err) + } + + // partInfo contains only the new added partition, we have to combine it with the + // old partitions to check all partitions is strictly increasing. + clonedMeta := meta.Clone() + tmp := *partInfo + tmp.Definitions = append(pi.Definitions, tmp.Definitions...) + clonedMeta.Partition = &tmp + if err := checkPartitionDefinitionConstraints(ctx, clonedMeta); err != nil { + if dbterror.ErrSameNamePartition.Equal(err) && spec.IfNotExists { + ctx.GetSessionVars().StmtCtx.AppendNote(err) + return nil + } + return errors.Trace(err) + } + + if err = handlePartitionPlacement(ctx, partInfo); err != nil { + return errors.Trace(err) + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: meta.ID, + SchemaName: schema.Name.L, + TableName: t.Meta().Name.L, + Type: model.ActionAddTablePartition, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{partInfo}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + if spec.Tp == ast.AlterTableAddLastPartition && spec.Partition != nil { + query, ok := ctx.Value(sessionctx.QueryString).(string) + if ok { + sqlMode := ctx.GetSessionVars().SQLMode + var buf bytes.Buffer + AppendPartitionDefs(partInfo, &buf, sqlMode) + + syntacticSugar := spec.Partition.PartitionMethod.OriginalText() + syntacticStart := spec.Partition.PartitionMethod.OriginTextPosition() + newQuery := query[:syntacticStart] + "ADD PARTITION (" + buf.String() + ")" + query[syntacticStart+len(syntacticSugar):] + defer ctx.SetValue(sessionctx.QueryString, query) + ctx.SetValue(sessionctx.QueryString, newQuery) + } + } + err = e.DoDDLJob(ctx, job) + if dbterror.ErrSameNamePartition.Equal(err) && spec.IfNotExists { + ctx.GetSessionVars().StmtCtx.AppendNote(err) + return nil + } + return errors.Trace(err) +} + +// getReorganizedDefinitions return the definitions as they would look like after the REORGANIZE PARTITION is done. +func getReorganizedDefinitions(pi *model.PartitionInfo, firstPartIdx, lastPartIdx int, idMap map[int]struct{}) []model.PartitionDefinition { + tmpDefs := make([]model.PartitionDefinition, 0, len(pi.Definitions)+len(pi.AddingDefinitions)-len(idMap)) + if pi.Type == model.PartitionTypeList { + replaced := false + for i := range pi.Definitions { + if _, ok := idMap[i]; ok { + if !replaced { + tmpDefs = append(tmpDefs, pi.AddingDefinitions...) + replaced = true + } + continue + } + tmpDefs = append(tmpDefs, pi.Definitions[i]) + } + if !replaced { + // For safety, for future non-partitioned table -> partitioned + tmpDefs = append(tmpDefs, pi.AddingDefinitions...) + } + return tmpDefs + } + // Range + tmpDefs = append(tmpDefs, pi.Definitions[:firstPartIdx]...) 
+	tmpDefs = append(tmpDefs, pi.AddingDefinitions...)
+	if len(pi.Definitions) > (lastPartIdx + 1) {
+		tmpDefs = append(tmpDefs, pi.Definitions[lastPartIdx+1:]...)
+	}
+	return tmpDefs
+}
+
+func getReplacedPartitionIDs(names []string, pi *model.PartitionInfo) (firstPartIdx int, lastPartIdx int, idMap map[int]struct{}, err error) {
+	idMap = make(map[int]struct{})
+	firstPartIdx, lastPartIdx = -1, -1
+	for _, name := range names {
+		nameL := strings.ToLower(name)
+		partIdx := pi.FindPartitionDefinitionByName(nameL)
+		if partIdx == -1 {
+			return 0, 0, nil, errors.Trace(dbterror.ErrWrongPartitionName)
+		}
+		if _, ok := idMap[partIdx]; ok {
+			return 0, 0, nil, errors.Trace(dbterror.ErrSameNamePartition)
+		}
+		idMap[partIdx] = struct{}{}
+		if firstPartIdx == -1 {
+			firstPartIdx = partIdx
+		} else {
+			firstPartIdx = mathutil.Min[int](firstPartIdx, partIdx)
+		}
+		if lastPartIdx == -1 {
+			lastPartIdx = partIdx
+		} else {
+			lastPartIdx = mathutil.Max[int](lastPartIdx, partIdx)
+		}
+	}
+	switch pi.Type {
+	case model.PartitionTypeRange:
+		if len(idMap) != (lastPartIdx - firstPartIdx + 1) {
+			return 0, 0, nil, errors.Trace(dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs(
+				"REORGANIZE PARTITION of RANGE; not adjacent partitions"))
+		}
+	case model.PartitionTypeHash, model.PartitionTypeKey:
+		if len(idMap) != len(pi.Definitions) {
+			return 0, 0, nil, errors.Trace(dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs(
+				"REORGANIZE PARTITION of HASH/RANGE; must reorganize all partitions"))
+		}
+	}
+
+	return firstPartIdx, lastPartIdx, idMap, nil
+}
+
+func getPartitionInfoTypeNone() *model.PartitionInfo {
+	return &model.PartitionInfo{
+		Type:   model.PartitionTypeNone,
+		Enable: true,
+		Definitions: []model.PartitionDefinition{{
+			Name:    model.NewCIStr("pFullTable"),
+			Comment: "Intermediate partition during ALTER TABLE ... PARTITION BY ...",
+		}},
+		Num: 1,
+	}
+}
+
+// AlterTablePartitioning reorganizes one set of partitions to a new set of partitions.
+func (e *executor) AlterTablePartitioning(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { + schema, t, err := e.getSchemaAndTableByIdent(ident) + if err != nil { + return errors.Trace(infoschema.ErrTableNotExists.FastGenByArgs(ident.Schema, ident.Name)) + } + + meta := t.Meta().Clone() + piOld := meta.GetPartitionInfo() + var partNames []string + if piOld != nil { + partNames = make([]string, 0, len(piOld.Definitions)) + for i := range piOld.Definitions { + partNames = append(partNames, piOld.Definitions[i].Name.L) + } + } else { + piOld = getPartitionInfoTypeNone() + meta.Partition = piOld + partNames = append(partNames, piOld.Definitions[0].Name.L) + } + newMeta := meta.Clone() + err = buildTablePartitionInfo(ctx, spec.Partition, newMeta) + if err != nil { + return err + } + newPartInfo := newMeta.Partition + + for _, index := range newMeta.Indices { + if index.Unique { + ck, err := checkPartitionKeysConstraint(newMeta.GetPartitionInfo(), index.Columns, newMeta) + if err != nil { + return err + } + if !ck { + indexTp := "" + if !ctx.GetSessionVars().EnableGlobalIndex { + if index.Primary { + indexTp = "PRIMARY KEY" + } else { + indexTp = "UNIQUE INDEX" + } + } else if t.Meta().IsCommonHandle { + indexTp = "CLUSTERED INDEX" + } + if indexTp != "" { + return dbterror.ErrUniqueKeyNeedAllFieldsInPf.GenWithStackByArgs(indexTp) + } + // Also mark the unique index as global index + index.Global = true + } + } + } + if newMeta.PKIsHandle { + // This case is covers when the Handle is the PK (only ints), since it would not + // have an entry in the tblInfo.Indices + indexCols := []*model.IndexColumn{{ + Name: newMeta.GetPkName(), + Length: types.UnspecifiedLength, + }} + ck, err := checkPartitionKeysConstraint(newMeta.GetPartitionInfo(), indexCols, newMeta) + if err != nil { + return err + } + if !ck { + if !ctx.GetSessionVars().EnableGlobalIndex { + return dbterror.ErrUniqueKeyNeedAllFieldsInPf.GenWithStackByArgs("PRIMARY KEY") + } + return dbterror.ErrUniqueKeyNeedAllFieldsInPf.GenWithStackByArgs("CLUSTERED INDEX") + } + } + + if err = handlePartitionPlacement(ctx, newPartInfo); err != nil { + return errors.Trace(err) + } + + if err = e.assignPartitionIDs(newPartInfo.Definitions); err != nil { + return errors.Trace(err) + } + // A new table ID would be needed for + // the global index, which cannot be the same as the current table id, + // since this table id will be removed in the final state when removing + // all the data with this table id. + var newID []int64 + newID, err = e.genGlobalIDs(1) + if err != nil { + return errors.Trace(err) + } + newPartInfo.NewTableID = newID[0] + newPartInfo.DDLType = piOld.Type + + job := &model.Job{ + SchemaID: schema.ID, + TableID: meta.ID, + SchemaName: schema.Name.L, + TableName: t.Meta().Name.L, + Type: model.ActionAlterTablePartitioning, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{partNames, newPartInfo}, + ReorgMeta: NewDDLReorgMeta(ctx), + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + // No preSplitAndScatter here, it will be done by the worker in onReorganizePartition instead. + err = e.DoDDLJob(ctx, job) + if err == nil { + ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError("The statistics of new partitions will be outdated after reorganizing partitions. Please use 'ANALYZE TABLE' statement if you want to update it now")) + } + return errors.Trace(err) +} + +// ReorganizePartitions reorganize one set of partitions to a new set of partitions. 
+func (e *executor) ReorganizePartitions(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { + schema, t, err := e.getSchemaAndTableByIdent(ident) + if err != nil { + return errors.Trace(infoschema.ErrTableNotExists.FastGenByArgs(ident.Schema, ident.Name)) + } + + meta := t.Meta() + pi := meta.GetPartitionInfo() + if pi == nil { + return dbterror.ErrPartitionMgmtOnNonpartitioned + } + switch pi.Type { + case model.PartitionTypeRange, model.PartitionTypeList: + case model.PartitionTypeHash, model.PartitionTypeKey: + if spec.Tp != ast.AlterTableCoalescePartitions && + spec.Tp != ast.AlterTableAddPartitions { + return errors.Trace(dbterror.ErrUnsupportedReorganizePartition) + } + default: + return errors.Trace(dbterror.ErrUnsupportedReorganizePartition) + } + partNames := make([]string, 0, len(spec.PartitionNames)) + for _, name := range spec.PartitionNames { + partNames = append(partNames, name.L) + } + firstPartIdx, lastPartIdx, idMap, err := getReplacedPartitionIDs(partNames, pi) + if err != nil { + return errors.Trace(err) + } + partInfo, err := BuildAddedPartitionInfo(ctx.GetExprCtx(), meta, spec) + if err != nil { + return errors.Trace(err) + } + if err = e.assignPartitionIDs(partInfo.Definitions); err != nil { + return errors.Trace(err) + } + if err = checkReorgPartitionDefs(ctx, model.ActionReorganizePartition, meta, partInfo, firstPartIdx, lastPartIdx, idMap); err != nil { + return errors.Trace(err) + } + if err = handlePartitionPlacement(ctx, partInfo); err != nil { + return errors.Trace(err) + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: meta.ID, + SchemaName: schema.Name.L, + TableName: t.Meta().Name.L, + Type: model.ActionReorganizePartition, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{partNames, partInfo}, + ReorgMeta: NewDDLReorgMeta(ctx), + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + // No preSplitAndScatter here, it will be done by the worker in onReorganizePartition instead. + err = e.DoDDLJob(ctx, job) + failpoint.InjectCall("afterReorganizePartition") + if err == nil { + ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError("The statistics of related partitions will be outdated after reorganizing partitions. Please use 'ANALYZE TABLE' statement if you want to update it now")) + } + return errors.Trace(err) +} + +// RemovePartitioning removes partitioning from a table. +func (e *executor) RemovePartitioning(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { + schema, t, err := e.getSchemaAndTableByIdent(ident) + if err != nil { + return errors.Trace(infoschema.ErrTableNotExists.FastGenByArgs(ident.Schema, ident.Name)) + } + + meta := t.Meta().Clone() + pi := meta.GetPartitionInfo() + if pi == nil { + return dbterror.ErrPartitionMgmtOnNonpartitioned + } + // TODO: Optimize for remove partitioning with a single partition + // TODO: Add the support for this in onReorganizePartition + // skip if only one partition + // If there are only one partition, then we can do: + // change the table id to the partition id + // and keep the statistics for the partition id (which should be similar to the global statistics) + // and it let the GC clean up the old table metadata including possible global index. 
+ + newSpec := &ast.AlterTableSpec{} + newSpec.Tp = spec.Tp + defs := make([]*ast.PartitionDefinition, 1) + defs[0] = &ast.PartitionDefinition{} + defs[0].Name = model.NewCIStr("CollapsedPartitions") + newSpec.PartDefinitions = defs + partNames := make([]string, len(pi.Definitions)) + for i := range pi.Definitions { + partNames[i] = pi.Definitions[i].Name.L + } + meta.Partition.Type = model.PartitionTypeNone + partInfo, err := BuildAddedPartitionInfo(ctx.GetExprCtx(), meta, newSpec) + if err != nil { + return errors.Trace(err) + } + if err = e.assignPartitionIDs(partInfo.Definitions); err != nil { + return errors.Trace(err) + } + // TODO: check where the default placement comes from (i.e. table level) + if err = handlePartitionPlacement(ctx, partInfo); err != nil { + return errors.Trace(err) + } + partInfo.NewTableID = partInfo.Definitions[0].ID + + job := &model.Job{ + SchemaID: schema.ID, + TableID: meta.ID, + SchemaName: schema.Name.L, + TableName: meta.Name.L, + Type: model.ActionRemovePartitioning, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{partNames, partInfo}, + ReorgMeta: NewDDLReorgMeta(ctx), + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + // No preSplitAndScatter here, it will be done by the worker in onReorganizePartition instead. + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +func checkReorgPartitionDefs(ctx sessionctx.Context, action model.ActionType, tblInfo *model.TableInfo, partInfo *model.PartitionInfo, firstPartIdx, lastPartIdx int, idMap map[int]struct{}) error { + // partInfo contains only the new added partition, we have to combine it with the + // old partitions to check all partitions is strictly increasing. + pi := tblInfo.Partition + clonedMeta := tblInfo.Clone() + switch action { + case model.ActionRemovePartitioning, model.ActionAlterTablePartitioning: + clonedMeta.Partition = partInfo + clonedMeta.ID = partInfo.NewTableID + case model.ActionReorganizePartition: + clonedMeta.Partition.AddingDefinitions = partInfo.Definitions + clonedMeta.Partition.Definitions = getReorganizedDefinitions(clonedMeta.Partition, firstPartIdx, lastPartIdx, idMap) + default: + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("partition type") + } + if err := checkPartitionDefinitionConstraints(ctx, clonedMeta); err != nil { + return errors.Trace(err) + } + if action == model.ActionReorganizePartition { + if pi.Type == model.PartitionTypeRange { + if lastPartIdx == len(pi.Definitions)-1 { + // Last partition dropped, OK to change the end range + // Also includes MAXVALUE + return nil + } + // Check if the replaced end range is the same as before + lastAddingPartition := partInfo.Definitions[len(partInfo.Definitions)-1] + lastOldPartition := pi.Definitions[lastPartIdx] + if len(pi.Columns) > 0 { + newGtOld, err := checkTwoRangeColumns(ctx, &lastAddingPartition, &lastOldPartition, pi, tblInfo) + if err != nil { + return errors.Trace(err) + } + if newGtOld { + return errors.Trace(dbterror.ErrRangeNotIncreasing) + } + oldGtNew, err := checkTwoRangeColumns(ctx, &lastOldPartition, &lastAddingPartition, pi, tblInfo) + if err != nil { + return errors.Trace(err) + } + if oldGtNew { + return errors.Trace(dbterror.ErrRangeNotIncreasing) + } + return nil + } + + isUnsigned := isPartExprUnsigned(ctx.GetExprCtx().GetEvalCtx(), tblInfo) + currentRangeValue, _, err := getRangeValue(ctx.GetExprCtx(), pi.Definitions[lastPartIdx].LessThan[0], isUnsigned) + if err != nil { + return errors.Trace(err) + } + 
newRangeValue, _, err := getRangeValue(ctx.GetExprCtx(), partInfo.Definitions[len(partInfo.Definitions)-1].LessThan[0], isUnsigned) + if err != nil { + return errors.Trace(err) + } + + if currentRangeValue != newRangeValue { + return errors.Trace(dbterror.ErrRangeNotIncreasing) + } + } + } else { + if len(pi.Definitions) != (lastPartIdx - firstPartIdx + 1) { + // if not ActionReorganizePartition, require all partitions to be changed. + return errors.Trace(dbterror.ErrAlterOperationNotSupported) + } + } + return nil +} + +// CoalescePartitions coalesce partitions can be used with a table that is partitioned by hash or key to reduce the number of partitions by number. +func (e *executor) CoalescePartitions(sctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { + is := e.infoCache.GetLatest() + schema, ok := is.SchemaByName(ident.Schema) + if !ok { + return errors.Trace(infoschema.ErrDatabaseNotExists.GenWithStackByArgs(schema)) + } + t, err := is.TableByName(e.ctx, ident.Schema, ident.Name) + if err != nil { + return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name)) + } + + pi := t.Meta().GetPartitionInfo() + if pi == nil { + return errors.Trace(dbterror.ErrPartitionMgmtOnNonpartitioned) + } + + switch pi.Type { + case model.PartitionTypeHash, model.PartitionTypeKey: + return e.hashPartitionManagement(sctx, ident, spec, pi) + + // Coalesce partition can only be used on hash/key partitions. + default: + return errors.Trace(dbterror.ErrCoalesceOnlyOnHashPartition) + } +} + +func (e *executor) hashPartitionManagement(sctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec, pi *model.PartitionInfo) error { + newSpec := *spec + newSpec.PartitionNames = make([]model.CIStr, len(pi.Definitions)) + for i := 0; i < len(pi.Definitions); i++ { + // reorganize ALL partitions into the new number of partitions + newSpec.PartitionNames[i] = pi.Definitions[i].Name + } + for i := 0; i < len(newSpec.PartDefinitions); i++ { + switch newSpec.PartDefinitions[i].Clause.(type) { + case *ast.PartitionDefinitionClauseNone: + // OK, expected + case *ast.PartitionDefinitionClauseIn: + return errors.Trace(ast.ErrPartitionWrongValues.FastGenByArgs("LIST", "IN")) + case *ast.PartitionDefinitionClauseLessThan: + return errors.Trace(ast.ErrPartitionWrongValues.FastGenByArgs("RANGE", "LESS THAN")) + case *ast.PartitionDefinitionClauseHistory: + return errors.Trace(ast.ErrPartitionWrongValues.FastGenByArgs("SYSTEM_TIME", "HISTORY")) + + default: + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs( + "partitioning clause") + } + } + if newSpec.Num < uint64(len(newSpec.PartDefinitions)) { + newSpec.Num = uint64(len(newSpec.PartDefinitions)) + } + if spec.Tp == ast.AlterTableCoalescePartitions { + if newSpec.Num < 1 { + return ast.ErrCoalescePartitionNoPartition + } + if newSpec.Num >= uint64(len(pi.Definitions)) { + return dbterror.ErrDropLastPartition + } + if isNonDefaultPartitionOptionsUsed(pi.Definitions) { + // The partition definitions will be copied in buildHashPartitionDefinitions() + // if there is a non-empty list of definitions + newSpec.PartDefinitions = []*ast.PartitionDefinition{{}} + } + } + + return e.ReorganizePartitions(sctx, ident, &newSpec) +} + +func (e *executor) TruncateTablePartition(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { + is := e.infoCache.GetLatest() + schema, ok := is.SchemaByName(ident.Schema) + if !ok { + return errors.Trace(infoschema.ErrDatabaseNotExists.GenWithStackByArgs(schema)) 
+ } + t, err := is.TableByName(e.ctx, ident.Schema, ident.Name) + if err != nil { + return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name)) + } + meta := t.Meta() + if meta.GetPartitionInfo() == nil { + return errors.Trace(dbterror.ErrPartitionMgmtOnNonpartitioned) + } + + getTruncatedParts := func(pi *model.PartitionInfo) (*model.PartitionInfo, error) { + if spec.OnAllPartitions { + return pi.Clone(), nil + } + var defs []model.PartitionDefinition + // MySQL allows duplicate partition names in truncate partition + // so we filter them out through a hash + posMap := make(map[int]bool) + for _, name := range spec.PartitionNames { + pos := pi.FindPartitionDefinitionByName(name.L) + if pos < 0 { + return nil, errors.Trace(table.ErrUnknownPartition.GenWithStackByArgs(name.L, ident.Name.O)) + } + if _, ok := posMap[pos]; !ok { + defs = append(defs, pi.Definitions[pos]) + posMap[pos] = true + } + } + pi = pi.Clone() + pi.Definitions = defs + return pi, nil + } + pi, err := getTruncatedParts(meta.GetPartitionInfo()) + if err != nil { + return err + } + pids := make([]int64, 0, len(pi.Definitions)) + for i := range pi.Definitions { + pids = append(pids, pi.Definitions[i].ID) + } + + genIDs, err := e.genGlobalIDs(len(pids)) + if err != nil { + return errors.Trace(err) + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: meta.ID, + SchemaName: schema.Name.L, + SchemaState: model.StatePublic, + TableName: t.Meta().Name.L, + Type: model.ActionTruncateTablePartition, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{pids, genIDs}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err = e.DoDDLJob(ctx, job) + if err != nil { + return errors.Trace(err) + } + return nil +} + +func (e *executor) DropTablePartition(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { + is := e.infoCache.GetLatest() + schema, ok := is.SchemaByName(ident.Schema) + if !ok { + return errors.Trace(infoschema.ErrDatabaseNotExists.GenWithStackByArgs(schema)) + } + t, err := is.TableByName(e.ctx, ident.Schema, ident.Name) + if err != nil { + return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name)) + } + meta := t.Meta() + if meta.GetPartitionInfo() == nil { + return errors.Trace(dbterror.ErrPartitionMgmtOnNonpartitioned) + } + + if spec.Tp == ast.AlterTableDropFirstPartition { + intervalOptions := getPartitionIntervalFromTable(ctx.GetExprCtx(), meta) + if intervalOptions == nil { + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs( + "FIRST PARTITION, does not seem like an INTERVAL partitioned table") + } + if len(spec.Partition.Definitions) != 0 { + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs( + "FIRST PARTITION, table info already contains partition definitions") + } + spec.Partition.Interval = intervalOptions + err = GeneratePartDefsFromInterval(ctx.GetExprCtx(), spec.Tp, meta, spec.Partition) + if err != nil { + return err + } + pNullOffset := 0 + if intervalOptions.NullPart { + pNullOffset = 1 + } + if len(spec.Partition.Definitions) == 0 || + len(spec.Partition.Definitions) >= len(meta.Partition.Definitions)-pNullOffset { + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs( + "FIRST PARTITION, number of partitions does not match") + } + if len(spec.PartitionNames) != 0 || len(spec.Partition.Definitions) <= 1 { + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs( + "FIRST PARTITION, given value does not generate a list of 
partition names to be dropped") + } + for i := range spec.Partition.Definitions { + spec.PartitionNames = append(spec.PartitionNames, meta.Partition.Definitions[i+pNullOffset].Name) + } + // Use the last generated partition as First, i.e. do not drop the last name in the slice + spec.PartitionNames = spec.PartitionNames[:len(spec.PartitionNames)-1] + + query, ok := ctx.Value(sessionctx.QueryString).(string) + if ok { + partNames := make([]string, 0, len(spec.PartitionNames)) + sqlMode := ctx.GetSessionVars().SQLMode + for i := range spec.PartitionNames { + partNames = append(partNames, stringutil.Escape(spec.PartitionNames[i].O, sqlMode)) + } + syntacticSugar := spec.Partition.PartitionMethod.OriginalText() + syntacticStart := spec.Partition.PartitionMethod.OriginTextPosition() + newQuery := query[:syntacticStart] + "DROP PARTITION " + strings.Join(partNames, ", ") + query[syntacticStart+len(syntacticSugar):] + defer ctx.SetValue(sessionctx.QueryString, query) + ctx.SetValue(sessionctx.QueryString, newQuery) + } + } + partNames := make([]string, len(spec.PartitionNames)) + for i, partCIName := range spec.PartitionNames { + partNames[i] = partCIName.L + } + err = CheckDropTablePartition(meta, partNames) + if err != nil { + if dbterror.ErrDropPartitionNonExistent.Equal(err) && spec.IfExists { + ctx.GetSessionVars().StmtCtx.AppendNote(err) + return nil + } + return errors.Trace(err) + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: meta.ID, + SchemaName: schema.Name.L, + SchemaState: model.StatePublic, + TableName: meta.Name.L, + Type: model.ActionDropTablePartition, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{partNames}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err = e.DoDDLJob(ctx, job) + if err != nil { + if dbterror.ErrDropPartitionNonExistent.Equal(err) && spec.IfExists { + ctx.GetSessionVars().StmtCtx.AppendNote(err) + return nil + } + return errors.Trace(err) + } + return errors.Trace(err) +} + +func checkFieldTypeCompatible(ft *types.FieldType, other *types.FieldType) bool { + // int(1) could match the type with int(8) + partialEqual := ft.GetType() == other.GetType() && + ft.GetDecimal() == other.GetDecimal() && + ft.GetCharset() == other.GetCharset() && + ft.GetCollate() == other.GetCollate() && + (ft.GetFlen() == other.GetFlen() || ft.StorageLength() != types.VarStorageLen) && + mysql.HasUnsignedFlag(ft.GetFlag()) == mysql.HasUnsignedFlag(other.GetFlag()) && + mysql.HasAutoIncrementFlag(ft.GetFlag()) == mysql.HasAutoIncrementFlag(other.GetFlag()) && + mysql.HasNotNullFlag(ft.GetFlag()) == mysql.HasNotNullFlag(other.GetFlag()) && + mysql.HasZerofillFlag(ft.GetFlag()) == mysql.HasZerofillFlag(other.GetFlag()) && + mysql.HasBinaryFlag(ft.GetFlag()) == mysql.HasBinaryFlag(other.GetFlag()) && + mysql.HasPriKeyFlag(ft.GetFlag()) == mysql.HasPriKeyFlag(other.GetFlag()) + if !partialEqual || len(ft.GetElems()) != len(other.GetElems()) { + return false + } + for i := range ft.GetElems() { + if ft.GetElems()[i] != other.GetElems()[i] { + return false + } + } + return true +} + +func checkTiFlashReplicaCompatible(source *model.TiFlashReplicaInfo, target *model.TiFlashReplicaInfo) bool { + if source == target { + return true + } + if source == nil || target == nil { + return false + } + if source.Count != target.Count || + source.Available != target.Available || len(source.LocationLabels) != len(target.LocationLabels) { + return false + } + for i, lable := range source.LocationLabels { + if target.LocationLabels[i] 
!= lable { + return false + } + } + return true +} + +func checkTableDefCompatible(source *model.TableInfo, target *model.TableInfo) error { + // check temp table + if target.TempTableType != model.TempTableNone { + return errors.Trace(dbterror.ErrPartitionExchangeTempTable.FastGenByArgs(target.Name)) + } + + // check auto_random + if source.AutoRandomBits != target.AutoRandomBits || + source.AutoRandomRangeBits != target.AutoRandomRangeBits || + source.Charset != target.Charset || + source.Collate != target.Collate || + source.ShardRowIDBits != target.ShardRowIDBits || + source.MaxShardRowIDBits != target.MaxShardRowIDBits || + !checkTiFlashReplicaCompatible(source.TiFlashReplica, target.TiFlashReplica) { + return errors.Trace(dbterror.ErrTablesDifferentMetadata) + } + if len(source.Cols()) != len(target.Cols()) { + return errors.Trace(dbterror.ErrTablesDifferentMetadata) + } + // Col compatible check + for i, sourceCol := range source.Cols() { + targetCol := target.Cols()[i] + if sourceCol.IsVirtualGenerated() != targetCol.IsVirtualGenerated() { + return dbterror.ErrUnsupportedOnGeneratedColumn.GenWithStackByArgs("Exchanging partitions for non-generated columns") + } + // It should strictyle compare expressions for generated columns + if sourceCol.Name.L != targetCol.Name.L || + sourceCol.Hidden != targetCol.Hidden || + !checkFieldTypeCompatible(&sourceCol.FieldType, &targetCol.FieldType) || + sourceCol.GeneratedExprString != targetCol.GeneratedExprString { + return errors.Trace(dbterror.ErrTablesDifferentMetadata) + } + if sourceCol.State != model.StatePublic || + targetCol.State != model.StatePublic { + return errors.Trace(dbterror.ErrTablesDifferentMetadata) + } + if sourceCol.ID != targetCol.ID { + return dbterror.ErrPartitionExchangeDifferentOption.GenWithStackByArgs(fmt.Sprintf("column: %s", sourceCol.Name)) + } + } + if len(source.Indices) != len(target.Indices) { + return errors.Trace(dbterror.ErrTablesDifferentMetadata) + } + for _, sourceIdx := range source.Indices { + if sourceIdx.Global { + return dbterror.ErrPartitionExchangeDifferentOption.GenWithStackByArgs(fmt.Sprintf("global index: %s", sourceIdx.Name)) + } + var compatIdx *model.IndexInfo + for _, targetIdx := range target.Indices { + if strings.EqualFold(sourceIdx.Name.L, targetIdx.Name.L) { + compatIdx = targetIdx + } + } + // No match index + if compatIdx == nil { + return errors.Trace(dbterror.ErrTablesDifferentMetadata) + } + // Index type is not compatible + if sourceIdx.Tp != compatIdx.Tp || + sourceIdx.Unique != compatIdx.Unique || + sourceIdx.Primary != compatIdx.Primary { + return errors.Trace(dbterror.ErrTablesDifferentMetadata) + } + // The index column + if len(sourceIdx.Columns) != len(compatIdx.Columns) { + return errors.Trace(dbterror.ErrTablesDifferentMetadata) + } + for i, sourceIdxCol := range sourceIdx.Columns { + compatIdxCol := compatIdx.Columns[i] + if sourceIdxCol.Length != compatIdxCol.Length || + sourceIdxCol.Name.L != compatIdxCol.Name.L { + return errors.Trace(dbterror.ErrTablesDifferentMetadata) + } + } + if sourceIdx.ID != compatIdx.ID { + return dbterror.ErrPartitionExchangeDifferentOption.GenWithStackByArgs(fmt.Sprintf("index: %s", sourceIdx.Name)) + } + } + + return nil +} + +func checkExchangePartition(pt *model.TableInfo, nt *model.TableInfo) error { + if nt.IsView() || nt.IsSequence() { + return errors.Trace(dbterror.ErrCheckNoSuchTable) + } + if pt.GetPartitionInfo() == nil { + return errors.Trace(dbterror.ErrPartitionMgmtOnNonpartitioned) + } + if nt.GetPartitionInfo() != nil { + 
return errors.Trace(dbterror.ErrPartitionExchangePartTable.GenWithStackByArgs(nt.Name)) + } + + if len(nt.ForeignKeys) > 0 { + return errors.Trace(dbterror.ErrPartitionExchangeForeignKey.GenWithStackByArgs(nt.Name)) + } + + return nil +} + +func (e *executor) ExchangeTablePartition(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { + ptSchema, pt, err := e.getSchemaAndTableByIdent(ident) + if err != nil { + return errors.Trace(err) + } + + ptMeta := pt.Meta() + + ntIdent := ast.Ident{Schema: spec.NewTable.Schema, Name: spec.NewTable.Name} + + // We should check local temporary here using session's info schema because the local temporary tables are only stored in session. + ntLocalTempTable, err := sessiontxn.GetTxnManager(ctx).GetTxnInfoSchema().TableByName(context.Background(), ntIdent.Schema, ntIdent.Name) + if err == nil && ntLocalTempTable.Meta().TempTableType == model.TempTableLocal { + return errors.Trace(dbterror.ErrPartitionExchangeTempTable.FastGenByArgs(ntLocalTempTable.Meta().Name)) + } + + ntSchema, nt, err := e.getSchemaAndTableByIdent(ntIdent) + if err != nil { + return errors.Trace(err) + } + + ntMeta := nt.Meta() + + err = checkExchangePartition(ptMeta, ntMeta) + if err != nil { + return errors.Trace(err) + } + + partName := spec.PartitionNames[0].L + + // NOTE: if pt is subPartitioned, it should be checked + + defID, err := tables.FindPartitionByName(ptMeta, partName) + if err != nil { + return errors.Trace(err) + } + + err = checkTableDefCompatible(ptMeta, ntMeta) + if err != nil { + return errors.Trace(err) + } + + job := &model.Job{ + SchemaID: ntSchema.ID, + TableID: ntMeta.ID, + SchemaName: ntSchema.Name.L, + TableName: ntMeta.Name.L, + Type: model.ActionExchangeTablePartition, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{defID, ptSchema.ID, ptMeta.ID, partName, spec.WithValidation}, + CtxVars: []any{[]int64{ntSchema.ID, ptSchema.ID}, []int64{ntMeta.ID, ptMeta.ID}}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + InvolvingSchemaInfo: []model.InvolvingSchemaInfo{ + {Database: ptSchema.Name.L, Table: ptMeta.Name.L}, + {Database: ntSchema.Name.L, Table: ntMeta.Name.L}, + }, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err = e.DoDDLJob(ctx, job) + if err != nil { + return errors.Trace(err) + } + ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError("after the exchange, please analyze related table of the exchange to update statistics")) + return nil +} + +// DropColumn will drop a column from the table, now we don't support drop the column with index covered. 
+func (e *executor) DropColumn(ctx sessionctx.Context, ti ast.Ident, spec *ast.AlterTableSpec) error { + schema, t, err := e.getSchemaAndTableByIdent(ti) + if err != nil { + return errors.Trace(err) + } + failpoint.InjectCall("afterGetSchemaAndTableByIdent", ctx) + + isDropable, err := checkIsDroppableColumn(ctx, e.infoCache.GetLatest(), schema, t, spec) + if err != nil { + return err + } + if !isDropable { + return nil + } + colName := spec.OldColumnName.Name + err = checkVisibleColumnCnt(t, 0, 1) + if err != nil { + return err + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: t.Meta().ID, + SchemaName: schema.Name.L, + SchemaState: model.StatePublic, + TableName: t.Meta().Name.L, + Type: model.ActionDropColumn, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{colName, spec.IfExists}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +func checkIsDroppableColumn(ctx sessionctx.Context, is infoschema.InfoSchema, schema *model.DBInfo, t table.Table, spec *ast.AlterTableSpec) (isDrapable bool, err error) { + tblInfo := t.Meta() + // Check whether dropped column has existed. + colName := spec.OldColumnName.Name + col := table.FindCol(t.VisibleCols(), colName.L) + if col == nil { + err = dbterror.ErrCantDropFieldOrKey.GenWithStackByArgs(colName) + if spec.IfExists { + ctx.GetSessionVars().StmtCtx.AppendNote(err) + return false, nil + } + return false, err + } + + if err = isDroppableColumn(tblInfo, colName); err != nil { + return false, errors.Trace(err) + } + if err = checkDropColumnWithPartitionConstraint(t, colName); err != nil { + return false, errors.Trace(err) + } + // Check the column with foreign key. + err = checkDropColumnWithForeignKeyConstraint(is, schema.Name.L, tblInfo, colName.L) + if err != nil { + return false, errors.Trace(err) + } + // Check the column with TTL config + err = checkDropColumnWithTTLConfig(tblInfo, colName.L) + if err != nil { + return false, errors.Trace(err) + } + // We don't support dropping column with PK handle covered now. + if col.IsPKHandleColumn(tblInfo) { + return false, dbterror.ErrUnsupportedPKHandle + } + if mysql.HasAutoIncrementFlag(col.GetFlag()) && !ctx.GetSessionVars().AllowRemoveAutoInc { + return false, dbterror.ErrCantDropColWithAutoInc + } + return true, nil +} + +// checkDropColumnWithPartitionConstraint is used to check the partition constraint of the drop column. +func checkDropColumnWithPartitionConstraint(t table.Table, colName model.CIStr) error { + if t.Meta().Partition == nil { + return nil + } + pt, ok := t.(table.PartitionedTable) + if !ok { + // Should never happen! + return errors.Trace(dbterror.ErrDependentByPartitionFunctional.GenWithStackByArgs(colName.L)) + } + for _, name := range pt.GetPartitionColumnNames() { + if strings.EqualFold(name.L, colName.L) { + return errors.Trace(dbterror.ErrDependentByPartitionFunctional.GenWithStackByArgs(colName.L)) + } + } + return nil +} + +func checkVisibleColumnCnt(t table.Table, addCnt, dropCnt int) error { + tblInfo := t.Meta() + visibleColumCnt := 0 + for _, column := range tblInfo.Columns { + if !column.Hidden { + visibleColumCnt++ + } + } + if visibleColumCnt+addCnt > dropCnt { + return nil + } + if len(tblInfo.Columns)-visibleColumCnt > 0 { + // There are only invisible columns. 
+ return dbterror.ErrTableMustHaveColumns + } + return dbterror.ErrCantRemoveAllFields +} + +// checkModifyCharsetAndCollation returns error when the charset or collation is not modifiable. +// needRewriteCollationData is used when trying to modify the collation of a column, it is true when the column is with +// index because index of a string column is collation-aware. +func checkModifyCharsetAndCollation(toCharset, toCollate, origCharset, origCollate string, needRewriteCollationData bool) error { + if !charset.ValidCharsetAndCollation(toCharset, toCollate) { + return dbterror.ErrUnknownCharacterSet.GenWithStack("Unknown character set: '%s', collation: '%s'", toCharset, toCollate) + } + + if needRewriteCollationData && collate.NewCollationEnabled() && !collate.CompatibleCollate(origCollate, toCollate) { + return dbterror.ErrUnsupportedModifyCollation.GenWithStackByArgs(origCollate, toCollate) + } + + if (origCharset == charset.CharsetUTF8 && toCharset == charset.CharsetUTF8MB4) || + (origCharset == charset.CharsetUTF8 && toCharset == charset.CharsetUTF8) || + (origCharset == charset.CharsetUTF8MB4 && toCharset == charset.CharsetUTF8MB4) || + (origCharset == charset.CharsetLatin1 && toCharset == charset.CharsetUTF8MB4) { + // TiDB only allow utf8/latin1 to be changed to utf8mb4, or changing the collation when the charset is utf8/utf8mb4/latin1. + return nil + } + + if toCharset != origCharset { + msg := fmt.Sprintf("charset from %s to %s", origCharset, toCharset) + return dbterror.ErrUnsupportedModifyCharset.GenWithStackByArgs(msg) + } + if toCollate != origCollate { + msg := fmt.Sprintf("change collate from %s to %s", origCollate, toCollate) + return dbterror.ErrUnsupportedModifyCharset.GenWithStackByArgs(msg) + } + return nil +} + +func (e *executor) getModifiableColumnJob(ctx context.Context, sctx sessionctx.Context, ident ast.Ident, originalColName model.CIStr, + spec *ast.AlterTableSpec) (*model.Job, error) { + is := e.infoCache.GetLatest() + schema, ok := is.SchemaByName(ident.Schema) + if !ok { + return nil, errors.Trace(infoschema.ErrDatabaseNotExists) + } + t, err := is.TableByName(ctx, ident.Schema, ident.Name) + if err != nil { + return nil, errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name)) + } + + return GetModifiableColumnJob(ctx, sctx, is, ident, originalColName, schema, t, spec) +} + +// ChangeColumn renames an existing column and modifies the column's definition, +// currently we only support limited kind of changes +// that do not need to change or check data on the table. 
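// A minimal sketch (illustrative only; exampleCharsetConversionRules is a hypothetical helper,
// not taken from the surrounding diff) of how checkModifyCharsetAndCollation above treats a few
// charset pairs, assuming the usual default collations:
func exampleCharsetConversionRules() {
	// Widening utf8 or latin1 to utf8mb4 is allowed.
	_ = checkModifyCharsetAndCollation("utf8mb4", "utf8mb4_bin", "utf8", "utf8_bin", false)     // nil
	_ = checkModifyCharsetAndCollation("utf8mb4", "utf8mb4_bin", "latin1", "latin1_bin", false) // nil
	// Narrowing utf8mb4 back to utf8 changes the charset in an unsupported direction.
	_ = checkModifyCharsetAndCollation("utf8", "utf8_bin", "utf8mb4", "utf8mb4_bin", false) // ErrUnsupportedModifyCharset
}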
+func (e *executor) ChangeColumn(ctx context.Context, sctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { + specNewColumn := spec.NewColumns[0] + if len(specNewColumn.Name.Schema.O) != 0 && ident.Schema.L != specNewColumn.Name.Schema.L { + return dbterror.ErrWrongDBName.GenWithStackByArgs(specNewColumn.Name.Schema.O) + } + if len(spec.OldColumnName.Schema.O) != 0 && ident.Schema.L != spec.OldColumnName.Schema.L { + return dbterror.ErrWrongDBName.GenWithStackByArgs(spec.OldColumnName.Schema.O) + } + if len(specNewColumn.Name.Table.O) != 0 && ident.Name.L != specNewColumn.Name.Table.L { + return dbterror.ErrWrongTableName.GenWithStackByArgs(specNewColumn.Name.Table.O) + } + if len(spec.OldColumnName.Table.O) != 0 && ident.Name.L != spec.OldColumnName.Table.L { + return dbterror.ErrWrongTableName.GenWithStackByArgs(spec.OldColumnName.Table.O) + } + + job, err := e.getModifiableColumnJob(ctx, sctx, ident, spec.OldColumnName.Name, spec) + if err != nil { + if infoschema.ErrColumnNotExists.Equal(err) && spec.IfExists { + sctx.GetSessionVars().StmtCtx.AppendNote(infoschema.ErrColumnNotExists.FastGenByArgs(spec.OldColumnName.Name, ident.Name)) + return nil + } + return errors.Trace(err) + } + + err = e.DoDDLJob(sctx, job) + // column not exists, but if_exists flags is true, so we ignore this error. + if infoschema.ErrColumnNotExists.Equal(err) && spec.IfExists { + sctx.GetSessionVars().StmtCtx.AppendNote(err) + return nil + } + return errors.Trace(err) +} + +// RenameColumn renames an existing column. +func (e *executor) RenameColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { + oldColName := spec.OldColumnName.Name + newColName := spec.NewColumnName.Name + + schema, tbl, err := e.getSchemaAndTableByIdent(ident) + if err != nil { + return errors.Trace(err) + } + + oldCol := table.FindCol(tbl.VisibleCols(), oldColName.L) + if oldCol == nil { + return infoschema.ErrColumnNotExists.GenWithStackByArgs(oldColName, ident.Name) + } + // check if column can rename with check constraint + err = IsColumnRenameableWithCheckConstraint(oldCol.Name, tbl.Meta()) + if err != nil { + return err + } + + if oldColName.L == newColName.L { + return nil + } + if newColName.L == model.ExtraHandleName.L { + return dbterror.ErrWrongColumnName.GenWithStackByArgs(newColName.L) + } + + allCols := tbl.Cols() + colWithNewNameAlreadyExist := table.FindCol(allCols, newColName.L) != nil + if colWithNewNameAlreadyExist { + return infoschema.ErrColumnExists.GenWithStackByArgs(newColName) + } + + // Check generated expression. + err = checkModifyColumnWithGeneratedColumnsConstraint(allCols, oldColName) + if err != nil { + return errors.Trace(err) + } + err = checkDropColumnWithPartitionConstraint(tbl, oldColName) + if err != nil { + return errors.Trace(err) + } + + newCol := oldCol.Clone() + newCol.Name = newColName + job := &model.Job{ + SchemaID: schema.ID, + TableID: tbl.Meta().ID, + SchemaName: schema.Name.L, + TableName: tbl.Meta().Name.L, + Type: model.ActionModifyColumn, + BinlogInfo: &model.HistoryInfo{}, + ReorgMeta: NewDDLReorgMeta(ctx), + Args: []any{&newCol, oldColName, spec.Position, 0, 0}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +// ModifyColumn does modification on an existing column, currently we only support limited kind of changes +// that do not need to change or check data on the table. 
+func (e *executor) ModifyColumn(ctx context.Context, sctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { + specNewColumn := spec.NewColumns[0] + if len(specNewColumn.Name.Schema.O) != 0 && ident.Schema.L != specNewColumn.Name.Schema.L { + return dbterror.ErrWrongDBName.GenWithStackByArgs(specNewColumn.Name.Schema.O) + } + if len(specNewColumn.Name.Table.O) != 0 && ident.Name.L != specNewColumn.Name.Table.L { + return dbterror.ErrWrongTableName.GenWithStackByArgs(specNewColumn.Name.Table.O) + } + + originalColName := specNewColumn.Name.Name + job, err := e.getModifiableColumnJob(ctx, sctx, ident, originalColName, spec) + if err != nil { + if infoschema.ErrColumnNotExists.Equal(err) && spec.IfExists { + sctx.GetSessionVars().StmtCtx.AppendNote(infoschema.ErrColumnNotExists.FastGenByArgs(originalColName, ident.Name)) + return nil + } + return errors.Trace(err) + } + + err = e.DoDDLJob(sctx, job) + // column not exists, but if_exists flags is true, so we ignore this error. + if infoschema.ErrColumnNotExists.Equal(err) && spec.IfExists { + sctx.GetSessionVars().StmtCtx.AppendNote(err) + return nil + } + return errors.Trace(err) +} + +func (e *executor) AlterColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { + specNewColumn := spec.NewColumns[0] + is := e.infoCache.GetLatest() + schema, ok := is.SchemaByName(ident.Schema) + if !ok { + return infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name) + } + t, err := is.TableByName(e.ctx, ident.Schema, ident.Name) + if err != nil { + return infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name) + } + + colName := specNewColumn.Name.Name + // Check whether alter column has existed. + oldCol := table.FindCol(t.Cols(), colName.L) + if oldCol == nil { + return dbterror.ErrBadField.GenWithStackByArgs(colName, ident.Name) + } + col := table.ToColumn(oldCol.Clone()) + + // Clean the NoDefaultValueFlag value. + col.DelFlag(mysql.NoDefaultValueFlag) + col.DefaultIsExpr = false + if len(specNewColumn.Options) == 0 { + err = col.SetDefaultValue(nil) + if err != nil { + return errors.Trace(err) + } + col.AddFlag(mysql.NoDefaultValueFlag) + } else { + if IsAutoRandomColumnID(t.Meta(), col.ID) { + return dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomIncompatibleWithDefaultValueErrMsg) + } + hasDefaultValue, err := SetDefaultValue(ctx, col, specNewColumn.Options[0]) + if err != nil { + return errors.Trace(err) + } + if err = checkDefaultValue(ctx.GetExprCtx(), col, hasDefaultValue); err != nil { + return errors.Trace(err) + } + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: t.Meta().ID, + SchemaName: schema.Name.L, + TableName: t.Meta().Name.L, + Type: model.ActionSetDefaultValue, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{col}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +// AlterTableComment updates the table comment information. 
+func (e *executor) AlterTableComment(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error {
+	is := e.infoCache.GetLatest()
+	schema, ok := is.SchemaByName(ident.Schema)
+	if !ok {
+		return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ident.Schema)
+	}
+
+	tb, err := is.TableByName(e.ctx, ident.Schema, ident.Name)
+	if err != nil {
+		return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name))
+	}
+	sessionVars := ctx.GetSessionVars()
+	if _, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, ident.Name.L, &spec.Comment, dbterror.ErrTooLongTableComment); err != nil {
+		return errors.Trace(err)
+	}
+
+	job := &model.Job{
+		SchemaID:       schema.ID,
+		TableID:        tb.Meta().ID,
+		SchemaName:     schema.Name.L,
+		TableName:      tb.Meta().Name.L,
+		Type:           model.ActionModifyTableComment,
+		BinlogInfo:     &model.HistoryInfo{},
+		Args:           []any{spec.Comment},
+		CDCWriteSource: ctx.GetSessionVars().CDCWriteSource,
+		SQLMode:        ctx.GetSessionVars().SQLMode,
+	}
+
+	err = e.DoDDLJob(ctx, job)
+	return errors.Trace(err)
+}
+
+// AlterTableAutoIDCache updates the table AUTO_ID_CACHE option.
+func (e *executor) AlterTableAutoIDCache(ctx sessionctx.Context, ident ast.Ident, newCache int64) error {
+	schema, tb, err := e.getSchemaAndTableByIdent(ident)
+	if err != nil {
+		return errors.Trace(err)
+	}
+	tbInfo := tb.Meta()
+	if (newCache == 1 && tbInfo.AutoIdCache != 1) ||
+		(newCache != 1 && tbInfo.AutoIdCache == 1) {
+		return fmt.Errorf("Can't Alter AUTO_ID_CACHE between 1 and non-1, the underlying implementation is different")
+	}
+
+	job := &model.Job{
+		SchemaID:       schema.ID,
+		TableID:        tb.Meta().ID,
+		SchemaName:     schema.Name.L,
+		TableName:      tb.Meta().Name.L,
+		Type:           model.ActionModifyTableAutoIdCache,
+		BinlogInfo:     &model.HistoryInfo{},
+		Args:           []any{newCache},
+		CDCWriteSource: ctx.GetSessionVars().CDCWriteSource,
+		SQLMode:        ctx.GetSessionVars().SQLMode,
+	}
+
+	err = e.DoDDLJob(ctx, job)
+	return errors.Trace(err)
+}
+
+// AlterTableCharsetAndCollate changes the table charset and collate.
+func (e *executor) AlterTableCharsetAndCollate(ctx sessionctx.Context, ident ast.Ident, toCharset, toCollate string, needsOverwriteCols bool) error {
+	// use the last one.
+	if toCharset == "" && toCollate == "" {
+		return dbterror.ErrUnknownCharacterSet.GenWithStackByArgs(toCharset)
+	}
+
+	is := e.infoCache.GetLatest()
+	schema, ok := is.SchemaByName(ident.Schema)
+	if !ok {
+		return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ident.Schema)
+	}
+
+	tb, err := is.TableByName(e.ctx, ident.Schema, ident.Name)
+	if err != nil {
+		return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name))
+	}
+
+	if toCharset == "" {
+		// charset does not change.
+		toCharset = tb.Meta().Charset
+	}
+
+	if toCollate == "" {
+		// Get the default collation of the charset.
+ toCollate, err = GetDefaultCollation(ctx.GetSessionVars(), toCharset) + if err != nil { + return errors.Trace(err) + } + } + doNothing, err := checkAlterTableCharset(tb.Meta(), schema, toCharset, toCollate, needsOverwriteCols) + if err != nil { + return err + } + if doNothing { + return nil + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: tb.Meta().ID, + SchemaName: schema.Name.L, + TableName: tb.Meta().Name.L, + Type: model.ActionModifyTableCharsetAndCollate, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{toCharset, toCollate, needsOverwriteCols}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +func shouldModifyTiFlashReplica(tbReplicaInfo *model.TiFlashReplicaInfo, replicaInfo *ast.TiFlashReplicaSpec) bool { + if tbReplicaInfo != nil && tbReplicaInfo.Count == replicaInfo.Count && + len(tbReplicaInfo.LocationLabels) == len(replicaInfo.Labels) { + for i, label := range tbReplicaInfo.LocationLabels { + if replicaInfo.Labels[i] != label { + return true + } + } + return false + } + return true +} + +// addHypoTiFlashReplicaIntoCtx adds this hypothetical tiflash replica into this ctx. +func (*executor) setHypoTiFlashReplica(ctx sessionctx.Context, schemaName, tableName model.CIStr, replicaInfo *ast.TiFlashReplicaSpec) error { + sctx := ctx.GetSessionVars() + if sctx.HypoTiFlashReplicas == nil { + sctx.HypoTiFlashReplicas = make(map[string]map[string]struct{}) + } + if sctx.HypoTiFlashReplicas[schemaName.L] == nil { + sctx.HypoTiFlashReplicas[schemaName.L] = make(map[string]struct{}) + } + if replicaInfo.Count > 0 { // add replicas + sctx.HypoTiFlashReplicas[schemaName.L][tableName.L] = struct{}{} + } else { // delete replicas + delete(sctx.HypoTiFlashReplicas[schemaName.L], tableName.L) + } + return nil +} + +// AlterTableSetTiFlashReplica sets the TiFlash replicas info. +func (e *executor) AlterTableSetTiFlashReplica(ctx sessionctx.Context, ident ast.Ident, replicaInfo *ast.TiFlashReplicaSpec) error { + schema, tb, err := e.getSchemaAndTableByIdent(ident) + if err != nil { + return errors.Trace(err) + } + + err = isTableTiFlashSupported(schema.Name, tb.Meta()) + if err != nil { + return errors.Trace(err) + } + + tbReplicaInfo := tb.Meta().TiFlashReplica + if !shouldModifyTiFlashReplica(tbReplicaInfo, replicaInfo) { + return nil + } + + if replicaInfo.Hypo { + return e.setHypoTiFlashReplica(ctx, schema.Name, tb.Meta().Name, replicaInfo) + } + + err = checkTiFlashReplicaCount(ctx, replicaInfo.Count) + if err != nil { + return errors.Trace(err) + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: tb.Meta().ID, + SchemaName: schema.Name.L, + TableName: tb.Meta().Name.L, + Type: model.ActionSetTiFlashReplica, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{*replicaInfo}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +// AlterTableTTLInfoOrEnable submit ddl job to change table info according to the ttlInfo, or ttlEnable +// at least one of the `ttlInfo`, `ttlEnable` or `ttlCronJobSchedule` should be not nil. +// When `ttlInfo` is nil, and `ttlEnable` is not, it will use the original `.TTLInfo` in the table info and modify the +// `.Enable`. If the `.TTLInfo` in the table info is empty, this function will return an error. 
+// When `ttlInfo` is nil, and `ttlCronJobSchedule` is not, it will use the original `.TTLInfo` in the table info and modify the +// `.JobInterval`. If the `.TTLInfo` in the table info is empty, this function will return an error. +// When `ttlInfo` is not nil, it simply submits the job with the `ttlInfo` and ignore the `ttlEnable`. +func (e *executor) AlterTableTTLInfoOrEnable(ctx sessionctx.Context, ident ast.Ident, ttlInfo *model.TTLInfo, ttlEnable *bool, ttlCronJobSchedule *string) error { + is := e.infoCache.GetLatest() + schema, ok := is.SchemaByName(ident.Schema) + if !ok { + return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ident.Schema) + } + + tb, err := is.TableByName(e.ctx, ident.Schema, ident.Name) + if err != nil { + return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name)) + } + + tblInfo := tb.Meta().Clone() + tableID := tblInfo.ID + tableName := tblInfo.Name.L + + var job *model.Job + if ttlInfo != nil { + tblInfo.TTLInfo = ttlInfo + err = checkTTLInfoValid(ctx, ident.Schema, tblInfo) + if err != nil { + return err + } + } else { + if tblInfo.TTLInfo == nil { + if ttlEnable != nil { + return errors.Trace(dbterror.ErrSetTTLOptionForNonTTLTable.FastGenByArgs("TTL_ENABLE")) + } + if ttlCronJobSchedule != nil { + return errors.Trace(dbterror.ErrSetTTLOptionForNonTTLTable.FastGenByArgs("TTL_JOB_INTERVAL")) + } + } + } + + job = &model.Job{ + SchemaID: schema.ID, + TableID: tableID, + SchemaName: schema.Name.L, + TableName: tableName, + Type: model.ActionAlterTTLInfo, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{ttlInfo, ttlEnable, ttlCronJobSchedule}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +func (e *executor) AlterTableRemoveTTL(ctx sessionctx.Context, ident ast.Ident) error { + is := e.infoCache.GetLatest() + + schema, ok := is.SchemaByName(ident.Schema) + if !ok { + return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ident.Schema) + } + + tb, err := is.TableByName(e.ctx, ident.Schema, ident.Name) + if err != nil { + return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name)) + } + + tblInfo := tb.Meta().Clone() + tableID := tblInfo.ID + tableName := tblInfo.Name.L + + if tblInfo.TTLInfo != nil { + job := &model.Job{ + SchemaID: schema.ID, + TableID: tableID, + SchemaName: schema.Name.L, + TableName: tableName, + Type: model.ActionAlterTTLRemove, + BinlogInfo: &model.HistoryInfo{}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) + } + + return nil +} + +func isTableTiFlashSupported(dbName model.CIStr, tbl *model.TableInfo) error { + // Memory tables and system tables are not supported by TiFlash + if util.IsMemOrSysDB(dbName.L) { + return errors.Trace(dbterror.ErrUnsupportedTiFlashOperationForSysOrMemTable) + } else if tbl.TempTableType != model.TempTableNone { + return dbterror.ErrOptOnTemporaryTable.GenWithStackByArgs("set on tiflash") + } else if tbl.IsView() || tbl.IsSequence() { + return dbterror.ErrWrongObject.GenWithStackByArgs(dbName, tbl.Name, "BASE TABLE") + } + + // Tables that has charset are not supported by TiFlash + for _, col := range tbl.Cols() { + _, ok := charset.TiFlashSupportedCharsets[col.GetCharset()] + if !ok { + return dbterror.ErrUnsupportedTiFlashOperationForUnsupportedCharsetTable.GenWithStackByArgs(col.GetCharset()) + } + } + + 
return nil +} + +func checkTiFlashReplicaCount(ctx sessionctx.Context, replicaCount uint64) error { + // Check the tiflash replica count should be less than the total tiflash stores. + tiflashStoreCnt, err := infoschema.GetTiFlashStoreCount(ctx) + if err != nil { + return errors.Trace(err) + } + if replicaCount > tiflashStoreCnt { + return errors.Errorf("the tiflash replica count: %d should be less than the total tiflash server count: %d", replicaCount, tiflashStoreCnt) + } + return nil +} + +// AlterTableAddStatistics registers extended statistics for a table. +func (e *executor) AlterTableAddStatistics(ctx sessionctx.Context, ident ast.Ident, stats *ast.StatisticsSpec, ifNotExists bool) error { + if !ctx.GetSessionVars().EnableExtendedStats { + return errors.New("Extended statistics feature is not generally available now, and tidb_enable_extended_stats is OFF") + } + // Not support Cardinality and Dependency statistics type for now. + if stats.StatsType == ast.StatsTypeCardinality || stats.StatsType == ast.StatsTypeDependency { + return errors.New("Cardinality and Dependency statistics types are not supported now") + } + _, tbl, err := e.getSchemaAndTableByIdent(ident) + if err != nil { + return err + } + tblInfo := tbl.Meta() + if tblInfo.GetPartitionInfo() != nil { + return errors.New("Extended statistics on partitioned tables are not supported now") + } + colIDs := make([]int64, 0, 2) + colIDSet := make(map[int64]struct{}, 2) + // Check whether columns exist. + for _, colName := range stats.Columns { + col := table.FindCol(tbl.VisibleCols(), colName.Name.L) + if col == nil { + return infoschema.ErrColumnNotExists.GenWithStackByArgs(colName.Name, ident.Name) + } + if stats.StatsType == ast.StatsTypeCorrelation && tblInfo.PKIsHandle && mysql.HasPriKeyFlag(col.GetFlag()) { + ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError("No need to create correlation statistics on the integer primary key column")) + return nil + } + if _, exist := colIDSet[col.ID]; exist { + return errors.Errorf("Cannot create extended statistics on duplicate column names '%s'", colName.Name.L) + } + colIDSet[col.ID] = struct{}{} + colIDs = append(colIDs, col.ID) + } + if len(colIDs) != 2 && (stats.StatsType == ast.StatsTypeCorrelation || stats.StatsType == ast.StatsTypeDependency) { + return errors.New("Only support Correlation and Dependency statistics types on 2 columns") + } + if len(colIDs) < 1 && stats.StatsType == ast.StatsTypeCardinality { + return errors.New("Only support Cardinality statistics type on at least 2 columns") + } + // TODO: check whether covering index exists for cardinality / dependency types. + + // Call utilities of statistics.Handle to modify system tables instead of doing DML directly, + // because locking in Handle can guarantee the correctness of `version` in system tables. + return e.statsHandle.InsertExtendedStats(stats.StatsName, colIDs, int(stats.StatsType), tblInfo.ID, ifNotExists) +} + +// AlterTableDropStatistics logically deletes extended statistics for a table. 
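// A minimal sketch (exampleCorrelationStatsSpec is a hypothetical helper; the ast.StatisticsSpec
// fields shown are assumed from the accesses in AlterTableAddStatistics above) of the spec that
// roughly corresponds to `ALTER TABLE t ADD STATS_EXTENDED s1 CORRELATION(a, b)`:
func exampleCorrelationStatsSpec() *ast.StatisticsSpec {
	return &ast.StatisticsSpec{
		StatsName: "s1",
		StatsType: ast.StatsTypeCorrelation,
		Columns: []*ast.ColumnName{
			{Name: model.NewCIStr("a")},
			{Name: model.NewCIStr("b")},
		},
	}
}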
+func (e *executor) AlterTableDropStatistics(ctx sessionctx.Context, ident ast.Ident, stats *ast.StatisticsSpec, ifExists bool) error {
+	if !ctx.GetSessionVars().EnableExtendedStats {
+		return errors.New("Extended statistics feature is not generally available now, and tidb_enable_extended_stats is OFF")
+	}
+	_, tbl, err := e.getSchemaAndTableByIdent(ident)
+	if err != nil {
+		return err
+	}
+	tblInfo := tbl.Meta()
+	// Call utilities of statistics.Handle to modify system tables instead of doing DML directly,
+	// because locking in Handle can guarantee the correctness of `version` in system tables.
+	return e.statsHandle.MarkExtendedStatsDeleted(stats.StatsName, tblInfo.ID, ifExists)
+}
+
+// UpdateTableReplicaInfo updates the table TiFlash replica info.
+func (e *executor) UpdateTableReplicaInfo(ctx sessionctx.Context, physicalID int64, available bool) error {
+	is := e.infoCache.GetLatest()
+	tb, ok := is.TableByID(physicalID)
+	if !ok {
+		tb, _, _ = is.FindTableByPartitionID(physicalID)
+		if tb == nil {
+			return infoschema.ErrTableNotExists.GenWithStack("Table which ID = %d does not exist.", physicalID)
+		}
+	}
+	tbInfo := tb.Meta()
+	if tbInfo.TiFlashReplica == nil || (tbInfo.ID == physicalID && tbInfo.TiFlashReplica.Available == available) ||
+		(tbInfo.ID != physicalID && available == tbInfo.TiFlashReplica.IsPartitionAvailable(physicalID)) {
+		return nil
+	}
+
+	db, ok := infoschema.SchemaByTable(is, tbInfo)
+	if !ok {
+		return infoschema.ErrDatabaseNotExists.GenWithStack("Database of table `%s` does not exist.", tb.Meta().Name)
+	}
+
+	job := &model.Job{
+		SchemaID:       db.ID,
+		TableID:        tb.Meta().ID,
+		SchemaName:     db.Name.L,
+		TableName:      tb.Meta().Name.L,
+		Type:           model.ActionUpdateTiFlashReplicaStatus,
+		BinlogInfo:     &model.HistoryInfo{},
+		Args:           []any{available, physicalID},
+		CDCWriteSource: ctx.GetSessionVars().CDCWriteSource,
+		SQLMode:        ctx.GetSessionVars().SQLMode,
+	}
+	err := e.DoDDLJob(ctx, job)
+	return errors.Trace(err)
+}
+
+// checkAlterTableCharset is used to check whether it is possible to change the charset of a table.
+// This function returns two variables:
+// doNothing: if doNothing is true, there is no need to change anything, because the target charset is the same as the table's current charset.
+// err: if err is not nil, it is not possible to change the table charset to the target charset.
+func checkAlterTableCharset(tblInfo *model.TableInfo, dbInfo *model.DBInfo, toCharset, toCollate string, needsOverwriteCols bool) (doNothing bool, err error) {
+	origCharset := tblInfo.Charset
+	origCollate := tblInfo.Collate
+	// The charset of an old-version schema may have been modified when loading the schema if TreatOldVersionUTF8AsUTF8MB4 was enabled.
+	// So even if origCharset equals toCharset, we still need to do the DDL for old-version schemas.
+	if origCharset == toCharset && origCollate == toCollate && tblInfo.Version >= model.TableInfoVersion2 {
+		// nothing to do.
+		doNothing = true
+		for _, col := range tblInfo.Columns {
+			if col.GetCharset() == charset.CharsetBin {
+				continue
+			}
+			if col.GetCharset() == toCharset && col.GetCollate() == toCollate {
+				continue
+			}
+			doNothing = false
+		}
+		if doNothing {
+			return doNothing, nil
+		}
+	}
+
+	// This DDL will update the table charset to the default charset.
+	origCharset, origCollate, err = ResolveCharsetCollation(nil,
+		ast.CharsetOpt{Chs: origCharset, Col: origCollate},
+		ast.CharsetOpt{Chs: dbInfo.Charset, Col: dbInfo.Collate},
+	)
+	if err != nil {
+		return doNothing, err
+	}
+
+	if err = checkModifyCharsetAndCollation(toCharset, toCollate, origCharset, origCollate, false); err != nil {
+		return doNothing, err
+	}
+	if !needsOverwriteCols {
+		// If we don't change the charset and collation of columns, skip the next checks.
+		return doNothing, nil
+	}
+
+	for _, col := range tblInfo.Columns {
+		if col.GetType() == mysql.TypeVarchar {
+			if err = types.IsVarcharTooBigFieldLength(col.GetFlen(), col.Name.O, toCharset); err != nil {
+				return doNothing, err
+			}
+		}
+		if col.GetCharset() == charset.CharsetBin {
+			continue
+		}
+		if len(col.GetCharset()) == 0 {
+			continue
+		}
+		if err = checkModifyCharsetAndCollation(toCharset, toCollate, col.GetCharset(), col.GetCollate(), isColumnWithIndex(col.Name.L, tblInfo.Indices)); err != nil {
+			if strings.Contains(err.Error(), "Unsupported modifying collation") {
+				colErrMsg := "Unsupported converting collation of column '%s' from '%s' to '%s' when index is defined on it."
+				err = dbterror.ErrUnsupportedModifyCollation.GenWithStack(colErrMsg, col.Name.L, col.GetCollate(), toCollate)
+			}
+			return doNothing, err
+		}
+	}
+	return doNothing, nil
+}
+
+// RenameIndex renames an index.
+// In TiDB, index lookup is case-insensitive (index 'a' and 'A' are considered the same index),
+// but index names are case-preserving (we can rename index 'a' to 'A').
+func (e *executor) RenameIndex(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error {
+	is := e.infoCache.GetLatest()
+	schema, ok := is.SchemaByName(ident.Schema)
+	if !ok {
+		return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ident.Schema)
+	}
+
+	tb, err := is.TableByName(e.ctx, ident.Schema, ident.Name)
+	if err != nil {
+		return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name))
+	}
+	if tb.Meta().TableCacheStatusType != model.TableCacheStatusDisable {
+		return errors.Trace(dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Rename Index"))
+	}
+	duplicate, err := ValidateRenameIndex(spec.FromKey, spec.ToKey, tb.Meta())
+	if duplicate {
+		return nil
+	}
+	if err != nil {
+		return errors.Trace(err)
+	}
+
+	job := &model.Job{
+		SchemaID:       schema.ID,
+		TableID:        tb.Meta().ID,
+		SchemaName:     schema.Name.L,
+		TableName:      tb.Meta().Name.L,
+		Type:           model.ActionRenameIndex,
+		BinlogInfo:     &model.HistoryInfo{},
+		Args:           []any{spec.FromKey, spec.ToKey},
+		CDCWriteSource: ctx.GetSessionVars().CDCWriteSource,
+		SQLMode:        ctx.GetSessionVars().SQLMode,
+	}
+
+	err = e.DoDDLJob(ctx, job)
+	return errors.Trace(err)
+}
+
+// If one drops those tables by mistake, it is difficult to recover.
+// In the worst case, the whole TiDB cluster fails to bootstrap, so we prevent users from dropping them.
+var systemTables = map[string]struct{}{
+	"tidb":                 {},
+	"gc_delete_range":      {},
+	"gc_delete_range_done": {},
+}
+
+func isUndroppableTable(schema, table string) bool {
+	if schema != mysql.SystemDB {
+		return false
+	}
+	if _, ok := systemTables[table]; ok {
+		return true
+	}
+	return false
+}
+
+type objectType int
+
+const (
+	tableObject objectType = iota
+	viewObject
+	sequenceObject
+)
+
+// dropTableObject provides common logic to DROP TABLE/VIEW/SEQUENCE.
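// A minimal sketch (exampleUndroppable is a hypothetical helper) of how the guard above behaves:
// only the listed bootstrap tables inside the mysql schema are protected.
func exampleUndroppable() {
	_ = isUndroppableTable("mysql", "tidb")           // true: protected bootstrap table
	_ = isUndroppableTable("mysql", "stats_meta")     // false: not in the systemTables list
	_ = isUndroppableTable("test", "gc_delete_range") // false: outside the mysql schema
}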
+func (e *executor) dropTableObject( + ctx sessionctx.Context, + objects []*ast.TableName, + ifExists bool, + tableObjectType objectType, +) error { + var ( + notExistTables []string + sessVars = ctx.GetSessionVars() + is = e.infoCache.GetLatest() + dropExistErr *terror.Error + jobType model.ActionType + ) + + var jobArgs []any + switch tableObjectType { + case tableObject: + dropExistErr = infoschema.ErrTableDropExists + jobType = model.ActionDropTable + objectIdents := make([]ast.Ident, len(objects)) + fkCheck := ctx.GetSessionVars().ForeignKeyChecks + jobArgs = []any{objectIdents, fkCheck} + for i, tn := range objects { + objectIdents[i] = ast.Ident{Schema: tn.Schema, Name: tn.Name} + } + for _, tn := range objects { + if referredFK := checkTableHasForeignKeyReferred(is, tn.Schema.L, tn.Name.L, objectIdents, fkCheck); referredFK != nil { + return errors.Trace(dbterror.ErrForeignKeyCannotDropParent.GenWithStackByArgs(tn.Name, referredFK.ChildFKName, referredFK.ChildTable)) + } + } + case viewObject: + dropExistErr = infoschema.ErrTableDropExists + jobType = model.ActionDropView + case sequenceObject: + dropExistErr = infoschema.ErrSequenceDropExists + jobType = model.ActionDropSequence + } + for _, tn := range objects { + fullti := ast.Ident{Schema: tn.Schema, Name: tn.Name} + schema, ok := is.SchemaByName(tn.Schema) + if !ok { + // TODO: we should return special error for table not exist, checking "not exist" is not enough, + // because some other errors may contain this error string too. + notExistTables = append(notExistTables, fullti.String()) + continue + } + tableInfo, err := is.TableByName(e.ctx, tn.Schema, tn.Name) + if err != nil && infoschema.ErrTableNotExists.Equal(err) { + notExistTables = append(notExistTables, fullti.String()) + continue + } else if err != nil { + return err + } + + // prechecks before build DDL job + + // Protect important system table from been dropped by a mistake. + // I can hardly find a case that a user really need to do this. 
+ if isUndroppableTable(tn.Schema.L, tn.Name.L) { + return errors.Errorf("Drop tidb system table '%s.%s' is forbidden", tn.Schema.L, tn.Name.L) + } + switch tableObjectType { + case tableObject: + if !tableInfo.Meta().IsBaseTable() { + notExistTables = append(notExistTables, fullti.String()) + continue + } + + tempTableType := tableInfo.Meta().TempTableType + if config.CheckTableBeforeDrop && tempTableType == model.TempTableNone { + logutil.DDLLogger().Warn("admin check table before drop", + zap.String("database", fullti.Schema.O), + zap.String("table", fullti.Name.O), + ) + exec := ctx.GetRestrictedSQLExecutor() + internalCtx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) + _, _, err := exec.ExecRestrictedSQL(internalCtx, nil, "admin check table %n.%n", fullti.Schema.O, fullti.Name.O) + if err != nil { + return err + } + } + + if tableInfo.Meta().TableCacheStatusType != model.TableCacheStatusDisable { + return dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Drop Table") + } + case viewObject: + if !tableInfo.Meta().IsView() { + return dbterror.ErrWrongObject.GenWithStackByArgs(fullti.Schema, fullti.Name, "VIEW") + } + case sequenceObject: + if !tableInfo.Meta().IsSequence() { + err = dbterror.ErrWrongObject.GenWithStackByArgs(fullti.Schema, fullti.Name, "SEQUENCE") + if ifExists { + ctx.GetSessionVars().StmtCtx.AppendNote(err) + continue + } + return err + } + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: tableInfo.Meta().ID, + SchemaName: schema.Name.L, + SchemaState: schema.State, + TableName: tableInfo.Meta().Name.L, + Type: jobType, + BinlogInfo: &model.HistoryInfo{}, + Args: jobArgs, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err = e.DoDDLJob(ctx, job) + if infoschema.ErrDatabaseNotExists.Equal(err) || infoschema.ErrTableNotExists.Equal(err) { + notExistTables = append(notExistTables, fullti.String()) + continue + } else if err != nil { + return errors.Trace(err) + } + + // unlock table after drop + if tableObjectType != tableObject { + continue + } + if !config.TableLockEnabled() { + continue + } + if ok, _ := ctx.CheckTableLocked(tableInfo.Meta().ID); ok { + ctx.ReleaseTableLockByTableIDs([]int64{tableInfo.Meta().ID}) + } + } + if len(notExistTables) > 0 && !ifExists { + return dropExistErr.FastGenByArgs(strings.Join(notExistTables, ",")) + } + // We need add warning when use if exists. + if len(notExistTables) > 0 && ifExists { + for _, table := range notExistTables { + sessVars.StmtCtx.AppendNote(dropExistErr.FastGenByArgs(table)) + } + } + return nil +} + +// DropTable will proceed even if some table in the list does not exists. +func (e *executor) DropTable(ctx sessionctx.Context, stmt *ast.DropTableStmt) (err error) { + return e.dropTableObject(ctx, stmt.Tables, stmt.IfExists, tableObject) +} + +// DropView will proceed even if some view in the list does not exists. 
+func (e *executor) DropView(ctx sessionctx.Context, stmt *ast.DropTableStmt) (err error) { + return e.dropTableObject(ctx, stmt.Tables, stmt.IfExists, viewObject) +} + +func (e *executor) TruncateTable(ctx sessionctx.Context, ti ast.Ident) error { + schema, tb, err := e.getSchemaAndTableByIdent(ti) + if err != nil { + return errors.Trace(err) + } + if tb.Meta().IsView() || tb.Meta().IsSequence() { + return infoschema.ErrTableNotExists.GenWithStackByArgs(schema.Name.O, tb.Meta().Name.O) + } + if tb.Meta().TableCacheStatusType != model.TableCacheStatusDisable { + return dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Truncate Table") + } + fkCheck := ctx.GetSessionVars().ForeignKeyChecks + referredFK := checkTableHasForeignKeyReferred(e.infoCache.GetLatest(), ti.Schema.L, ti.Name.L, []ast.Ident{{Name: ti.Name, Schema: ti.Schema}}, fkCheck) + if referredFK != nil { + msg := fmt.Sprintf("`%s`.`%s` CONSTRAINT `%s`", referredFK.ChildSchema, referredFK.ChildTable, referredFK.ChildFKName) + return errors.Trace(dbterror.ErrTruncateIllegalForeignKey.GenWithStackByArgs(msg)) + } + + ids := 1 + if tb.Meta().Partition != nil { + ids += len(tb.Meta().Partition.Definitions) + } + genIDs, err := e.genGlobalIDs(ids) + if err != nil { + return errors.Trace(err) + } + newTableID := genIDs[0] + job := &model.Job{ + SchemaID: schema.ID, + TableID: tb.Meta().ID, + SchemaName: schema.Name.L, + TableName: tb.Meta().Name.L, + Type: model.ActionTruncateTable, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{newTableID, fkCheck, genIDs[1:]}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + if ok, _ := ctx.CheckTableLocked(tb.Meta().ID); ok && config.TableLockEnabled() { + // AddTableLock here to avoid this ddl job was executed successfully but the session was been kill before return. + // The session will release all table locks it holds, if we don't add the new locking table id here, + // the session may forget to release the new locked table id when this ddl job was executed successfully + // but the session was killed before return. 
+ ctx.AddTableLock([]model.TableLockTpInfo{{SchemaID: schema.ID, TableID: newTableID, Tp: tb.Meta().Lock.Tp}}) + } + err = e.DoDDLJob(ctx, job) + if err != nil { + if config.TableLockEnabled() { + ctx.ReleaseTableLockByTableIDs([]int64{newTableID}) + } + return errors.Trace(err) + } + + if !config.TableLockEnabled() { + return nil + } + if ok, _ := ctx.CheckTableLocked(tb.Meta().ID); ok { + ctx.ReleaseTableLockByTableIDs([]int64{tb.Meta().ID}) + } + return nil +} + +func (e *executor) RenameTable(ctx sessionctx.Context, s *ast.RenameTableStmt) error { + isAlterTable := false + var err error + if len(s.TableToTables) == 1 { + oldIdent := ast.Ident{Schema: s.TableToTables[0].OldTable.Schema, Name: s.TableToTables[0].OldTable.Name} + newIdent := ast.Ident{Schema: s.TableToTables[0].NewTable.Schema, Name: s.TableToTables[0].NewTable.Name} + err = e.renameTable(ctx, oldIdent, newIdent, isAlterTable) + } else { + oldIdents := make([]ast.Ident, 0, len(s.TableToTables)) + newIdents := make([]ast.Ident, 0, len(s.TableToTables)) + for _, tables := range s.TableToTables { + oldIdent := ast.Ident{Schema: tables.OldTable.Schema, Name: tables.OldTable.Name} + newIdent := ast.Ident{Schema: tables.NewTable.Schema, Name: tables.NewTable.Name} + oldIdents = append(oldIdents, oldIdent) + newIdents = append(newIdents, newIdent) + } + err = e.renameTables(ctx, oldIdents, newIdents, isAlterTable) + } + return err +} + +func (e *executor) renameTable(ctx sessionctx.Context, oldIdent, newIdent ast.Ident, isAlterTable bool) error { + is := e.infoCache.GetLatest() + tables := make(map[string]int64) + schemas, tableID, err := ExtractTblInfos(is, oldIdent, newIdent, isAlterTable, tables) + if err != nil { + return err + } + + if schemas == nil { + return nil + } + + if tbl, ok := is.TableByID(tableID); ok { + if tbl.Meta().TableCacheStatusType != model.TableCacheStatusDisable { + return errors.Trace(dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Rename Table")) + } + } + + job := &model.Job{ + SchemaID: schemas[1].ID, + TableID: tableID, + SchemaName: schemas[1].Name.L, + TableName: oldIdent.Name.L, + Type: model.ActionRenameTable, + BinlogInfo: &model.HistoryInfo{}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + Args: []any{schemas[0].ID, newIdent.Name, schemas[0].Name}, + CtxVars: []any{[]int64{schemas[0].ID, schemas[1].ID}, []int64{tableID}}, + InvolvingSchemaInfo: []model.InvolvingSchemaInfo{ + {Database: schemas[0].Name.L, Table: oldIdent.Name.L}, + {Database: schemas[1].Name.L, Table: newIdent.Name.L}, + }, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +func (e *executor) renameTables(ctx sessionctx.Context, oldIdents, newIdents []ast.Ident, isAlterTable bool) error { + is := e.infoCache.GetLatest() + oldTableNames := make([]*model.CIStr, 0, len(oldIdents)) + tableNames := make([]*model.CIStr, 0, len(oldIdents)) + oldSchemaIDs := make([]int64, 0, len(oldIdents)) + newSchemaIDs := make([]int64, 0, len(oldIdents)) + tableIDs := make([]int64, 0, len(oldIdents)) + oldSchemaNames := make([]*model.CIStr, 0, len(oldIdents)) + involveSchemaInfo := make([]model.InvolvingSchemaInfo, 0, len(oldIdents)*2) + + var schemas []*model.DBInfo + var tableID int64 + var err error + + tables := make(map[string]int64) + for i := 0; i < len(oldIdents); i++ { + schemas, tableID, err = ExtractTblInfos(is, oldIdents[i], newIdents[i], isAlterTable, tables) + if err != nil { + return err + } + + if t, ok := is.TableByID(tableID); ok { + if 
t.Meta().TableCacheStatusType != model.TableCacheStatusDisable { + return errors.Trace(dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Rename Tables")) + } + } + + tableIDs = append(tableIDs, tableID) + oldTableNames = append(oldTableNames, &oldIdents[i].Name) + tableNames = append(tableNames, &newIdents[i].Name) + oldSchemaIDs = append(oldSchemaIDs, schemas[0].ID) + newSchemaIDs = append(newSchemaIDs, schemas[1].ID) + oldSchemaNames = append(oldSchemaNames, &schemas[0].Name) + involveSchemaInfo = append(involveSchemaInfo, + model.InvolvingSchemaInfo{ + Database: schemas[0].Name.L, Table: oldIdents[i].Name.L, + }, + model.InvolvingSchemaInfo{ + Database: schemas[1].Name.L, Table: newIdents[i].Name.L, + }, + ) + } + + job := &model.Job{ + SchemaID: schemas[1].ID, + TableID: tableIDs[0], + SchemaName: schemas[1].Name.L, + Type: model.ActionRenameTables, + BinlogInfo: &model.HistoryInfo{}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + Args: []any{oldSchemaIDs, newSchemaIDs, tableNames, tableIDs, oldSchemaNames, oldTableNames}, + CtxVars: []any{append(oldSchemaIDs, newSchemaIDs...), tableIDs}, + InvolvingSchemaInfo: involveSchemaInfo, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +// ExtractTblInfos extracts the table information from the infoschema. +func ExtractTblInfos(is infoschema.InfoSchema, oldIdent, newIdent ast.Ident, isAlterTable bool, tables map[string]int64) ([]*model.DBInfo, int64, error) { + oldSchema, ok := is.SchemaByName(oldIdent.Schema) + if !ok { + if isAlterTable { + return nil, 0, infoschema.ErrTableNotExists.GenWithStackByArgs(oldIdent.Schema, oldIdent.Name) + } + if tableExists(is, newIdent, tables) { + return nil, 0, infoschema.ErrTableExists.GenWithStackByArgs(newIdent) + } + return nil, 0, infoschema.ErrTableNotExists.GenWithStackByArgs(oldIdent.Schema, oldIdent.Name) + } + if !tableExists(is, oldIdent, tables) { + if isAlterTable { + return nil, 0, infoschema.ErrTableNotExists.GenWithStackByArgs(oldIdent.Schema, oldIdent.Name) + } + if tableExists(is, newIdent, tables) { + return nil, 0, infoschema.ErrTableExists.GenWithStackByArgs(newIdent) + } + return nil, 0, infoschema.ErrTableNotExists.GenWithStackByArgs(oldIdent.Schema, oldIdent.Name) + } + if isAlterTable && newIdent.Schema.L == oldIdent.Schema.L && newIdent.Name.L == oldIdent.Name.L { + // oldIdent is equal to newIdent, do nothing + return nil, 0, nil + } + //View can be renamed only in the same schema. 
Compatible with mysql + if infoschema.TableIsView(is, oldIdent.Schema, oldIdent.Name) { + if oldIdent.Schema != newIdent.Schema { + return nil, 0, infoschema.ErrForbidSchemaChange.GenWithStackByArgs(oldIdent.Schema, newIdent.Schema) + } + } + + newSchema, ok := is.SchemaByName(newIdent.Schema) + if !ok { + return nil, 0, dbterror.ErrErrorOnRename.GenWithStackByArgs( + fmt.Sprintf("%s.%s", oldIdent.Schema, oldIdent.Name), + fmt.Sprintf("%s.%s", newIdent.Schema, newIdent.Name), + 168, + fmt.Sprintf("Database `%s` doesn't exist", newIdent.Schema)) + } + if tableExists(is, newIdent, tables) { + return nil, 0, infoschema.ErrTableExists.GenWithStackByArgs(newIdent) + } + if err := checkTooLongTable(newIdent.Name); err != nil { + return nil, 0, errors.Trace(err) + } + oldTableID := getTableID(is, oldIdent, tables) + oldIdentKey := getIdentKey(oldIdent) + tables[oldIdentKey] = tableNotExist + newIdentKey := getIdentKey(newIdent) + tables[newIdentKey] = oldTableID + return []*model.DBInfo{oldSchema, newSchema}, oldTableID, nil +} + +func tableExists(is infoschema.InfoSchema, ident ast.Ident, tables map[string]int64) bool { + identKey := getIdentKey(ident) + tableID, ok := tables[identKey] + if (ok && tableID != tableNotExist) || (!ok && is.TableExists(ident.Schema, ident.Name)) { + return true + } + return false +} + +func getTableID(is infoschema.InfoSchema, ident ast.Ident, tables map[string]int64) int64 { + identKey := getIdentKey(ident) + tableID, ok := tables[identKey] + if !ok { + oldTbl, err := is.TableByName(context.Background(), ident.Schema, ident.Name) + if err != nil { + return tableNotExist + } + tableID = oldTbl.Meta().ID + } + return tableID +} + +func getIdentKey(ident ast.Ident) string { + return fmt.Sprintf("%s.%s", ident.Schema.L, ident.Name.L) +} + +// GetName4AnonymousIndex returns a valid name for anonymous index. +func GetName4AnonymousIndex(t table.Table, colName model.CIStr, idxName model.CIStr) model.CIStr { + // `id` is used to indicated the index name's suffix. + id := 2 + l := len(t.Indices()) + indexName := colName + if idxName.O != "" { + // Use the provided index name, it only happens when the original index name is too long and be truncated. + indexName = idxName + id = 3 + } + if strings.EqualFold(indexName.L, mysql.PrimaryKeyName) { + indexName = model.NewCIStr(fmt.Sprintf("%s_%d", colName.O, id)) + id = 3 + } + for i := 0; i < l; i++ { + if t.Indices()[i].Meta().Name.L == indexName.L { + indexName = model.NewCIStr(fmt.Sprintf("%s_%d", colName.O, id)) + if err := checkTooLongIndex(indexName); err != nil { + indexName = GetName4AnonymousIndex(t, model.NewCIStr(colName.O[:30]), model.NewCIStr(fmt.Sprintf("%s_%d", colName.O[:30], 2))) + } + i = -1 + id++ + } + } + return indexName +} + +func (e *executor) CreatePrimaryKey(ctx sessionctx.Context, ti ast.Ident, indexName model.CIStr, + indexPartSpecifications []*ast.IndexPartSpecification, indexOption *ast.IndexOption) error { + if indexOption != nil && indexOption.PrimaryKeyTp == model.PrimaryKeyTypeClustered { + return dbterror.ErrUnsupportedModifyPrimaryKey.GenWithStack("Adding clustered primary key is not supported. 
" + + "Please consider adding NONCLUSTERED primary key instead") + } + schema, t, err := e.getSchemaAndTableByIdent(ti) + if err != nil { + return errors.Trace(err) + } + + if err = checkTooLongIndex(indexName); err != nil { + return dbterror.ErrTooLongIdent.GenWithStackByArgs(mysql.PrimaryKeyName) + } + + indexName = model.NewCIStr(mysql.PrimaryKeyName) + if indexInfo := t.Meta().FindIndexByName(indexName.L); indexInfo != nil || + // If the table's PKIsHandle is true, it also means that this table has a primary key. + t.Meta().PKIsHandle { + return infoschema.ErrMultiplePriKey + } + + // Primary keys cannot include expression index parts. A primary key requires the generated column to be stored, + // but expression index parts are implemented as virtual generated columns, not stored generated columns. + for _, idxPart := range indexPartSpecifications { + if idxPart.Expr != nil { + return dbterror.ErrFunctionalIndexPrimaryKey + } + } + + tblInfo := t.Meta() + // Check before the job is put to the queue. + // This check is redundant, but useful. If DDL check fail before the job is put + // to job queue, the fail path logic is super fast. + // After DDL job is put to the queue, and if the check fail, TiDB will run the DDL cancel logic. + // The recover step causes DDL wait a few seconds, makes the unit test painfully slow. + // For same reason, decide whether index is global here. + indexColumns, _, err := buildIndexColumns(ctx, tblInfo.Columns, indexPartSpecifications) + if err != nil { + return errors.Trace(err) + } + if _, err = CheckPKOnGeneratedColumn(tblInfo, indexPartSpecifications); err != nil { + return err + } + + global := false + if tblInfo.GetPartitionInfo() != nil { + ck, err := checkPartitionKeysConstraint(tblInfo.GetPartitionInfo(), indexColumns, tblInfo) + if err != nil { + return err + } + if !ck { + if !ctx.GetSessionVars().EnableGlobalIndex { + return dbterror.ErrUniqueKeyNeedAllFieldsInPf.GenWithStackByArgs("PRIMARY") + } + // index columns does not contain all partition columns, must set global + global = true + } + } + + // May be truncate comment here, when index comment too long and sql_mode is't strict. + if indexOption != nil { + sessionVars := ctx.GetSessionVars() + if _, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, indexName.String(), &indexOption.Comment, dbterror.ErrTooLongIndexComment); err != nil { + return errors.Trace(err) + } + } + + unique := true + sqlMode := ctx.GetSessionVars().SQLMode + job := &model.Job{ + SchemaID: schema.ID, + TableID: t.Meta().ID, + SchemaName: schema.Name.L, + TableName: t.Meta().Name.L, + Type: model.ActionAddPrimaryKey, + BinlogInfo: &model.HistoryInfo{}, + ReorgMeta: nil, + Args: []any{unique, indexName, indexPartSpecifications, indexOption, sqlMode, nil, global}, + Priority: ctx.GetSessionVars().DDLReorgPriority, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + reorgMeta, err := newReorgMetaFromVariables(job, ctx) + if err != nil { + return err + } + job.ReorgMeta = reorgMeta + + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +func (e *executor) CreateIndex(ctx sessionctx.Context, stmt *ast.CreateIndexStmt) error { + ident := ast.Ident{Schema: stmt.Table.Schema, Name: stmt.Table.Name} + return e.createIndex(ctx, ident, stmt.KeyType, model.NewCIStr(stmt.IndexName), + stmt.IndexPartSpecifications, stmt.IndexOption, stmt.IfNotExists) +} + +// addHypoIndexIntoCtx adds this index as a hypo-index into this ctx. 
+func (*executor) addHypoIndexIntoCtx(ctx sessionctx.Context, schemaName, tableName model.CIStr, indexInfo *model.IndexInfo) error { + sctx := ctx.GetSessionVars() + indexName := indexInfo.Name + + if sctx.HypoIndexes == nil { + sctx.HypoIndexes = make(map[string]map[string]map[string]*model.IndexInfo) + } + if sctx.HypoIndexes[schemaName.L] == nil { + sctx.HypoIndexes[schemaName.L] = make(map[string]map[string]*model.IndexInfo) + } + if sctx.HypoIndexes[schemaName.L][tableName.L] == nil { + sctx.HypoIndexes[schemaName.L][tableName.L] = make(map[string]*model.IndexInfo) + } + if _, exist := sctx.HypoIndexes[schemaName.L][tableName.L][indexName.L]; exist { + return errors.Trace(errors.Errorf("conflict hypo index name %s", indexName.L)) + } + + sctx.HypoIndexes[schemaName.L][tableName.L][indexName.L] = indexInfo + return nil +} + +func (e *executor) createIndex(ctx sessionctx.Context, ti ast.Ident, keyType ast.IndexKeyType, indexName model.CIStr, + indexPartSpecifications []*ast.IndexPartSpecification, indexOption *ast.IndexOption, ifNotExists bool) error { + // not support Spatial and FullText index + if keyType == ast.IndexKeyTypeFullText || keyType == ast.IndexKeyTypeSpatial { + return dbterror.ErrUnsupportedIndexType.GenWithStack("FULLTEXT and SPATIAL index is not supported") + } + unique := keyType == ast.IndexKeyTypeUnique + schema, t, err := e.getSchemaAndTableByIdent(ti) + if err != nil { + return errors.Trace(err) + } + + if t.Meta().TableCacheStatusType != model.TableCacheStatusDisable { + return errors.Trace(dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Create Index")) + } + // Deal with anonymous index. + if len(indexName.L) == 0 { + colName := model.NewCIStr("expression_index") + if indexPartSpecifications[0].Column != nil { + colName = indexPartSpecifications[0].Column.Name + } + indexName = GetName4AnonymousIndex(t, colName, model.NewCIStr("")) + } + + if indexInfo := t.Meta().FindIndexByName(indexName.L); indexInfo != nil { + if indexInfo.State != model.StatePublic { + // NOTE: explicit error message. See issue #18363. + err = dbterror.ErrDupKeyName.GenWithStack("Duplicate key name '%s'; "+ + "a background job is trying to add the same index, "+ + "please check by `ADMIN SHOW DDL JOBS`", indexName) + } else { + err = dbterror.ErrDupKeyName.GenWithStackByArgs(indexName) + } + if ifNotExists { + ctx.GetSessionVars().StmtCtx.AppendNote(err) + return nil + } + return err + } + + if err = checkTooLongIndex(indexName); err != nil { + return errors.Trace(err) + } + + tblInfo := t.Meta() + + // Build hidden columns if necessary. + hiddenCols, err := buildHiddenColumnInfoWithCheck(ctx, indexPartSpecifications, indexName, t.Meta(), t.Cols()) + if err != nil { + return err + } + if err = checkAddColumnTooManyColumns(len(t.Cols()) + len(hiddenCols)); err != nil { + return errors.Trace(err) + } + + finalColumns := make([]*model.ColumnInfo, len(tblInfo.Columns), len(tblInfo.Columns)+len(hiddenCols)) + copy(finalColumns, tblInfo.Columns) + finalColumns = append(finalColumns, hiddenCols...) + // Check before the job is put to the queue. + // This check is redundant, but useful. If DDL check fail before the job is put + // to job queue, the fail path logic is super fast. + // After DDL job is put to the queue, and if the check fail, TiDB will run the DDL cancel logic. + // The recover step causes DDL wait a few seconds, makes the unit test painfully slow. + // For same reason, decide whether index is global here. 
+ indexColumns, _, err := buildIndexColumns(ctx, finalColumns, indexPartSpecifications) + if err != nil { + return errors.Trace(err) + } + + global := false + if unique && tblInfo.GetPartitionInfo() != nil { + ck, err := checkPartitionKeysConstraint(tblInfo.GetPartitionInfo(), indexColumns, tblInfo) + if err != nil { + return err + } + if !ck { + if !ctx.GetSessionVars().EnableGlobalIndex { + return dbterror.ErrUniqueKeyNeedAllFieldsInPf.GenWithStackByArgs("UNIQUE INDEX") + } + // index columns does not contain all partition columns, must set global + global = true + } + } + // May be truncate comment here, when index comment too long and sql_mode is't strict. + if indexOption != nil { + sessionVars := ctx.GetSessionVars() + if _, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, indexName.String(), &indexOption.Comment, dbterror.ErrTooLongIndexComment); err != nil { + return errors.Trace(err) + } + } + + if indexOption != nil && indexOption.Tp == model.IndexTypeHypo { // for hypo-index + indexInfo, err := BuildIndexInfo(ctx, tblInfo.Columns, indexName, false, unique, global, + indexPartSpecifications, indexOption, model.StatePublic) + if err != nil { + return err + } + return e.addHypoIndexIntoCtx(ctx, ti.Schema, ti.Name, indexInfo) + } + + chs, coll := ctx.GetSessionVars().GetCharsetInfo() + job := &model.Job{ + SchemaID: schema.ID, + TableID: t.Meta().ID, + SchemaName: schema.Name.L, + TableName: t.Meta().Name.L, + Type: model.ActionAddIndex, + BinlogInfo: &model.HistoryInfo{}, + ReorgMeta: nil, + Args: []any{unique, indexName, indexPartSpecifications, indexOption, hiddenCols, global}, + Priority: ctx.GetSessionVars().DDLReorgPriority, + Charset: chs, + Collate: coll, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + reorgMeta, err := newReorgMetaFromVariables(job, ctx) + if err != nil { + return err + } + job.ReorgMeta = reorgMeta + + err = e.DoDDLJob(ctx, job) + // key exists, but if_not_exists flags is true, so we ignore this error. + if dbterror.ErrDupKeyName.Equal(err) && ifNotExists { + ctx.GetSessionVars().StmtCtx.AppendNote(err) + return nil + } + return errors.Trace(err) +} + +func newReorgMetaFromVariables(job *model.Job, sctx sessionctx.Context) (*model.DDLReorgMeta, error) { + reorgMeta := NewDDLReorgMeta(sctx) + reorgMeta.IsDistReorg = variable.EnableDistTask.Load() + reorgMeta.IsFastReorg = variable.EnableFastReorg.Load() + reorgMeta.TargetScope = variable.ServiceScope.Load() + + if reorgMeta.IsDistReorg && !reorgMeta.IsFastReorg { + return nil, dbterror.ErrUnsupportedDistTask + } + if hasSysDB(job) { + if reorgMeta.IsDistReorg { + logutil.DDLLogger().Info("cannot use distributed task execution on system DB", + zap.Stringer("job", job)) + } + reorgMeta.IsDistReorg = false + reorgMeta.IsFastReorg = false + failpoint.Inject("reorgMetaRecordFastReorgDisabled", func(_ failpoint.Value) { + LastReorgMetaFastReorgDisabled = true + }) + } + return reorgMeta, nil +} + +// LastReorgMetaFastReorgDisabled is used for test. 
+var LastReorgMetaFastReorgDisabled bool + +func buildFKInfo(fkName model.CIStr, keys []*ast.IndexPartSpecification, refer *ast.ReferenceDef, cols []*table.Column) (*model.FKInfo, error) { + if len(keys) != len(refer.IndexPartSpecifications) { + return nil, infoschema.ErrForeignKeyNotMatch.GenWithStackByArgs(fkName, "Key reference and table reference don't match") + } + if err := checkTooLongForeignKey(fkName); err != nil { + return nil, err + } + if err := checkTooLongSchema(refer.Table.Schema); err != nil { + return nil, err + } + if err := checkTooLongTable(refer.Table.Name); err != nil { + return nil, err + } + + // all base columns of stored generated columns + baseCols := make(map[string]struct{}) + for _, col := range cols { + if col.IsGenerated() && col.GeneratedStored { + for name := range col.Dependences { + baseCols[name] = struct{}{} + } + } + } + + fkInfo := &model.FKInfo{ + Name: fkName, + RefSchema: refer.Table.Schema, + RefTable: refer.Table.Name, + Cols: make([]model.CIStr, len(keys)), + } + if variable.EnableForeignKey.Load() { + fkInfo.Version = model.FKVersion1 + } + + for i, key := range keys { + // Check add foreign key to generated columns + // For more detail, see https://dev.mysql.com/doc/refman/8.0/en/innodb-foreign-key-constraints.html#innodb-foreign-key-generated-columns + for _, col := range cols { + if col.Name.L != key.Column.Name.L { + continue + } + if col.IsGenerated() { + // Check foreign key on virtual generated columns + if !col.GeneratedStored { + return nil, infoschema.ErrForeignKeyCannotUseVirtualColumn.GenWithStackByArgs(fkInfo.Name.O, col.Name.O) + } + + // Check wrong reference options of foreign key on stored generated columns + switch refer.OnUpdate.ReferOpt { + case model.ReferOptionCascade, model.ReferOptionSetNull, model.ReferOptionSetDefault: + //nolint: gosec + return nil, dbterror.ErrWrongFKOptionForGeneratedColumn.GenWithStackByArgs("ON UPDATE " + refer.OnUpdate.ReferOpt.String()) + } + switch refer.OnDelete.ReferOpt { + case model.ReferOptionSetNull, model.ReferOptionSetDefault: + //nolint: gosec + return nil, dbterror.ErrWrongFKOptionForGeneratedColumn.GenWithStackByArgs("ON DELETE " + refer.OnDelete.ReferOpt.String()) + } + continue + } + // Check wrong reference options of foreign key on base columns of stored generated columns + if _, ok := baseCols[col.Name.L]; ok { + switch refer.OnUpdate.ReferOpt { + case model.ReferOptionCascade, model.ReferOptionSetNull, model.ReferOptionSetDefault: + return nil, infoschema.ErrCannotAddForeign + } + switch refer.OnDelete.ReferOpt { + case model.ReferOptionCascade, model.ReferOptionSetNull, model.ReferOptionSetDefault: + return nil, infoschema.ErrCannotAddForeign + } + } + } + col := table.FindCol(cols, key.Column.Name.O) + if col == nil { + return nil, dbterror.ErrKeyColumnDoesNotExits.GenWithStackByArgs(key.Column.Name) + } + if mysql.HasNotNullFlag(col.GetFlag()) && (refer.OnDelete.ReferOpt == model.ReferOptionSetNull || refer.OnUpdate.ReferOpt == model.ReferOptionSetNull) { + return nil, infoschema.ErrForeignKeyColumnNotNull.GenWithStackByArgs(col.Name.O, fkName) + } + fkInfo.Cols[i] = key.Column.Name + } + + fkInfo.RefCols = make([]model.CIStr, len(refer.IndexPartSpecifications)) + for i, key := range refer.IndexPartSpecifications { + if err := checkTooLongColumn(key.Column.Name); err != nil { + return nil, err + } + fkInfo.RefCols[i] = key.Column.Name + } + + fkInfo.OnDelete = int(refer.OnDelete.ReferOpt) + fkInfo.OnUpdate = int(refer.OnUpdate.ReferOpt) + + return fkInfo, nil +} + +func (e 
*executor) CreateForeignKey(ctx sessionctx.Context, ti ast.Ident, fkName model.CIStr, keys []*ast.IndexPartSpecification, refer *ast.ReferenceDef) error { + is := e.infoCache.GetLatest() + schema, ok := is.SchemaByName(ti.Schema) + if !ok { + return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ti.Schema) + } + + t, err := is.TableByName(context.Background(), ti.Schema, ti.Name) + if err != nil { + return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ti.Schema, ti.Name)) + } + if t.Meta().TempTableType != model.TempTableNone { + return infoschema.ErrCannotAddForeign + } + + if fkName.L == "" { + fkName = model.NewCIStr(fmt.Sprintf("fk_%d", t.Meta().MaxForeignKeyID+1)) + } + err = checkFKDupName(t.Meta(), fkName) + if err != nil { + return err + } + fkInfo, err := buildFKInfo(fkName, keys, refer, t.Cols()) + if err != nil { + return errors.Trace(err) + } + fkCheck := ctx.GetSessionVars().ForeignKeyChecks + err = checkAddForeignKeyValid(is, schema.Name.L, t.Meta(), fkInfo, fkCheck) + if err != nil { + return err + } + if model.FindIndexByColumns(t.Meta(), t.Meta().Indices, fkInfo.Cols...) == nil { + // Need to auto create index for fk cols + if ctx.GetSessionVars().StmtCtx.MultiSchemaInfo == nil { + ctx.GetSessionVars().StmtCtx.MultiSchemaInfo = model.NewMultiSchemaInfo() + } + indexPartSpecifications := make([]*ast.IndexPartSpecification, 0, len(fkInfo.Cols)) + for _, col := range fkInfo.Cols { + indexPartSpecifications = append(indexPartSpecifications, &ast.IndexPartSpecification{ + Column: &ast.ColumnName{Name: col}, + Length: types.UnspecifiedLength, // Index prefixes on foreign key columns are not supported. + }) + } + indexOption := &ast.IndexOption{} + err = e.createIndex(ctx, ti, ast.IndexKeyTypeNone, fkInfo.Name, indexPartSpecifications, indexOption, false) + if err != nil { + return err + } + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: t.Meta().ID, + SchemaName: schema.Name.L, + TableName: t.Meta().Name.L, + Type: model.ActionAddForeignKey, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{fkInfo, fkCheck}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + InvolvingSchemaInfo: []model.InvolvingSchemaInfo{ + { + Database: schema.Name.L, + Table: t.Meta().Name.L, + }, + { + Database: fkInfo.RefSchema.L, + Table: fkInfo.RefTable.L, + Mode: model.SharedInvolving, + }, + }, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +func (e *executor) DropForeignKey(ctx sessionctx.Context, ti ast.Ident, fkName model.CIStr) error { + is := e.infoCache.GetLatest() + schema, ok := is.SchemaByName(ti.Schema) + if !ok { + return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ti.Schema) + } + + t, err := is.TableByName(context.Background(), ti.Schema, ti.Name) + if err != nil { + return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ti.Schema, ti.Name)) + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: t.Meta().ID, + SchemaName: schema.Name.L, + SchemaState: model.StatePublic, + TableName: t.Meta().Name.L, + Type: model.ActionDropForeignKey, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{fkName}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +func (e *executor) DropIndex(ctx sessionctx.Context, stmt *ast.DropIndexStmt) error { + ti := ast.Ident{Schema: stmt.Table.Schema, Name: stmt.Table.Name} + err := e.dropIndex(ctx, ti, model.NewCIStr(stmt.IndexName), 
stmt.IfExists, stmt.IsHypo) + if (infoschema.ErrDatabaseNotExists.Equal(err) || infoschema.ErrTableNotExists.Equal(err)) && stmt.IfExists { + err = nil + } + return err +} + +// dropHypoIndexFromCtx drops this hypo-index from this ctx. +func (*executor) dropHypoIndexFromCtx(ctx sessionctx.Context, schema, table, index model.CIStr, ifExists bool) error { + sctx := ctx.GetSessionVars() + if sctx.HypoIndexes != nil && + sctx.HypoIndexes[schema.L] != nil && + sctx.HypoIndexes[schema.L][table.L] != nil && + sctx.HypoIndexes[schema.L][table.L][index.L] != nil { + delete(sctx.HypoIndexes[schema.L][table.L], index.L) + return nil + } + if !ifExists { + return dbterror.ErrCantDropFieldOrKey.GenWithStack("index %s doesn't exist", index) + } + return nil +} + +// dropIndex drops the specified index. +// isHypo is used to indicate whether this operation is for a hypo-index. +func (e *executor) dropIndex(ctx sessionctx.Context, ti ast.Ident, indexName model.CIStr, ifExists, isHypo bool) error { + is := e.infoCache.GetLatest() + schema, ok := is.SchemaByName(ti.Schema) + if !ok { + return errors.Trace(infoschema.ErrDatabaseNotExists) + } + t, err := is.TableByName(context.Background(), ti.Schema, ti.Name) + if err != nil { + return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ti.Schema, ti.Name)) + } + if t.Meta().TableCacheStatusType != model.TableCacheStatusDisable { + return errors.Trace(dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Drop Index")) + } + + if isHypo { + return e.dropHypoIndexFromCtx(ctx, ti.Schema, ti.Name, indexName, ifExists) + } + + indexInfo := t.Meta().FindIndexByName(indexName.L) + + isPK, err := CheckIsDropPrimaryKey(indexName, indexInfo, t) + if err != nil { + return err + } + + if !ctx.GetSessionVars().InRestrictedSQL && ctx.GetSessionVars().PrimaryKeyRequired && isPK { + return infoschema.ErrTableWithoutPrimaryKey + } + + if indexInfo == nil { + err = dbterror.ErrCantDropFieldOrKey.GenWithStack("index %s doesn't exist", indexName) + if ifExists { + ctx.GetSessionVars().StmtCtx.AppendNote(err) + return nil + } + return err + } + + err = checkIndexNeededInForeignKey(is, schema.Name.L, t.Meta(), indexInfo) + if err != nil { + return err + } + + jobTp := model.ActionDropIndex + if isPK { + jobTp = model.ActionDropPrimaryKey + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: t.Meta().ID, + SchemaName: schema.Name.L, + SchemaState: indexInfo.State, + TableName: t.Meta().Name.L, + Type: jobTp, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{indexName, ifExists}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +// CheckIsDropPrimaryKey checks if we will drop PK, there are many PK implementations so we provide a helper function. +func CheckIsDropPrimaryKey(indexName model.CIStr, indexInfo *model.IndexInfo, t table.Table) (bool, error) { + var isPK bool + if indexName.L == strings.ToLower(mysql.PrimaryKeyName) && + // Before we fixed #14243, there might be a general index named `primary` but not a primary key. + (indexInfo == nil || indexInfo.Primary) { + isPK = true + } + if isPK { + // If the table's PKIsHandle is true, we can't find the index from the table. So we check the value of PKIsHandle. 
+ if indexInfo == nil && !t.Meta().PKIsHandle { + return isPK, dbterror.ErrCantDropFieldOrKey.GenWithStackByArgs("PRIMARY") + } + if t.Meta().IsCommonHandle || t.Meta().PKIsHandle { + return isPK, dbterror.ErrUnsupportedModifyPrimaryKey.GenWithStack("Unsupported drop primary key when the table is using clustered index") + } + } + + return isPK, nil +} + +// validateCommentLength checks comment length of table, column, or index +// If comment length is more than the standard length truncate it +// and store the comment length upto the standard comment length size. +func validateCommentLength(ec errctx.Context, sqlMode mysql.SQLMode, name string, comment *string, errTooLongComment *terror.Error) (string, error) { + if comment == nil { + return "", nil + } + + maxLen := MaxCommentLength + // The maximum length of table comment in MySQL 5.7 is 2048 + // Other comment is 1024 + switch errTooLongComment { + case dbterror.ErrTooLongTableComment: + maxLen *= 2 + case dbterror.ErrTooLongFieldComment, dbterror.ErrTooLongIndexComment, dbterror.ErrTooLongTablePartitionComment: + default: + // add more types of terror.Error if need + } + if len(*comment) > maxLen { + err := errTooLongComment.GenWithStackByArgs(name, maxLen) + if sqlMode.HasStrictMode() { + // may be treated like an error. + return "", err + } + ec.AppendWarning(err) + *comment = (*comment)[:maxLen] + } + return *comment, nil +} + +// BuildAddedPartitionInfo build alter table add partition info +func BuildAddedPartitionInfo(ctx expression.BuildContext, meta *model.TableInfo, spec *ast.AlterTableSpec) (*model.PartitionInfo, error) { + numParts := uint64(0) + switch meta.Partition.Type { + case model.PartitionTypeNone: + // OK + case model.PartitionTypeList: + if len(spec.PartDefinitions) == 0 { + return nil, ast.ErrPartitionsMustBeDefined.GenWithStackByArgs(meta.Partition.Type) + } + err := checkListPartitions(spec.PartDefinitions) + if err != nil { + return nil, err + } + + case model.PartitionTypeRange: + if spec.Tp == ast.AlterTableAddLastPartition { + err := buildAddedPartitionDefs(ctx, meta, spec) + if err != nil { + return nil, err + } + spec.PartDefinitions = spec.Partition.Definitions + } else { + if len(spec.PartDefinitions) == 0 { + return nil, ast.ErrPartitionsMustBeDefined.GenWithStackByArgs(meta.Partition.Type) + } + } + case model.PartitionTypeHash, model.PartitionTypeKey: + switch spec.Tp { + case ast.AlterTableRemovePartitioning: + numParts = 1 + default: + return nil, errors.Trace(dbterror.ErrUnsupportedAddPartition) + case ast.AlterTableCoalescePartitions: + if int(spec.Num) >= len(meta.Partition.Definitions) { + return nil, dbterror.ErrDropLastPartition + } + numParts = uint64(len(meta.Partition.Definitions)) - spec.Num + case ast.AlterTableAddPartitions: + if len(spec.PartDefinitions) > 0 { + numParts = uint64(len(meta.Partition.Definitions)) + uint64(len(spec.PartDefinitions)) + } else { + numParts = uint64(len(meta.Partition.Definitions)) + spec.Num + } + } + default: + // we don't support ADD PARTITION for all other partition types yet. 
+ return nil, errors.Trace(dbterror.ErrUnsupportedAddPartition) + } + + part := &model.PartitionInfo{ + Type: meta.Partition.Type, + Expr: meta.Partition.Expr, + Columns: meta.Partition.Columns, + Enable: meta.Partition.Enable, + } + + defs, err := buildPartitionDefinitionsInfo(ctx, spec.PartDefinitions, meta, numParts) + if err != nil { + return nil, err + } + + part.Definitions = defs + part.Num = uint64(len(defs)) + return part, nil +} + +func buildAddedPartitionDefs(ctx expression.BuildContext, meta *model.TableInfo, spec *ast.AlterTableSpec) error { + partInterval := getPartitionIntervalFromTable(ctx, meta) + if partInterval == nil { + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs( + "LAST PARTITION, does not seem like an INTERVAL partitioned table") + } + if partInterval.MaxValPart { + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("LAST PARTITION when MAXVALUE partition exists") + } + + spec.Partition.Interval = partInterval + + if len(spec.PartDefinitions) > 0 { + return errors.Trace(dbterror.ErrUnsupportedAddPartition) + } + return GeneratePartDefsFromInterval(ctx, spec.Tp, meta, spec.Partition) +} + +// LockTables uses to execute lock tables statement. +func (e *executor) LockTables(ctx sessionctx.Context, stmt *ast.LockTablesStmt) error { + lockTables := make([]model.TableLockTpInfo, 0, len(stmt.TableLocks)) + sessionInfo := model.SessionInfo{ + ServerID: e.uuid, + SessionID: ctx.GetSessionVars().ConnectionID, + } + uniqueTableID := make(map[int64]struct{}) + involveSchemaInfo := make([]model.InvolvingSchemaInfo, 0, len(stmt.TableLocks)) + // Check whether the table was already locked by another. + for _, tl := range stmt.TableLocks { + tb := tl.Table + err := throwErrIfInMemOrSysDB(ctx, tb.Schema.L) + if err != nil { + return err + } + schema, t, err := e.getSchemaAndTableByIdent(ast.Ident{Schema: tb.Schema, Name: tb.Name}) + if err != nil { + return errors.Trace(err) + } + if t.Meta().IsView() || t.Meta().IsSequence() { + return table.ErrUnsupportedOp.GenWithStackByArgs() + } + + err = checkTableLocked(t.Meta(), tl.Type, sessionInfo) + if err != nil { + return err + } + if _, ok := uniqueTableID[t.Meta().ID]; ok { + return infoschema.ErrNonuniqTable.GenWithStackByArgs(t.Meta().Name) + } + uniqueTableID[t.Meta().ID] = struct{}{} + lockTables = append(lockTables, model.TableLockTpInfo{SchemaID: schema.ID, TableID: t.Meta().ID, Tp: tl.Type}) + involveSchemaInfo = append(involveSchemaInfo, model.InvolvingSchemaInfo{ + Database: schema.Name.L, + Table: t.Meta().Name.L, + }) + } + + unlockTables := ctx.GetAllTableLocks() + arg := &LockTablesArg{ + LockTables: lockTables, + UnlockTables: unlockTables, + SessionInfo: sessionInfo, + } + job := &model.Job{ + SchemaID: lockTables[0].SchemaID, + TableID: lockTables[0].TableID, + Type: model.ActionLockTable, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{arg}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + InvolvingSchemaInfo: involveSchemaInfo, + SQLMode: ctx.GetSessionVars().SQLMode, + } + // AddTableLock here is avoiding this job was executed successfully but the session was killed before return. + ctx.AddTableLock(lockTables) + err := e.DoDDLJob(ctx, job) + if err == nil { + ctx.ReleaseTableLocks(unlockTables) + ctx.AddTableLock(lockTables) + } + return errors.Trace(err) +} + +// UnlockTables uses to execute unlock tables statement. 
+func (e *executor) UnlockTables(ctx sessionctx.Context, unlockTables []model.TableLockTpInfo) error { + if len(unlockTables) == 0 { + return nil + } + arg := &LockTablesArg{ + UnlockTables: unlockTables, + SessionInfo: model.SessionInfo{ + ServerID: e.uuid, + SessionID: ctx.GetSessionVars().ConnectionID, + }, + } + + involveSchemaInfo := make([]model.InvolvingSchemaInfo, 0, len(unlockTables)) + is := e.infoCache.GetLatest() + for _, t := range unlockTables { + schema, ok := is.SchemaByID(t.SchemaID) + if !ok { + continue + } + tbl, ok := is.TableByID(t.TableID) + if !ok { + continue + } + involveSchemaInfo = append(involveSchemaInfo, model.InvolvingSchemaInfo{ + Database: schema.Name.L, + Table: tbl.Meta().Name.L, + }) + } + job := &model.Job{ + SchemaID: unlockTables[0].SchemaID, + TableID: unlockTables[0].TableID, + Type: model.ActionUnlockTable, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{arg}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + InvolvingSchemaInfo: involveSchemaInfo, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err := e.DoDDLJob(ctx, job) + if err == nil { + ctx.ReleaseAllTableLocks() + } + return errors.Trace(err) +} + +func throwErrIfInMemOrSysDB(ctx sessionctx.Context, dbLowerName string) error { + if util.IsMemOrSysDB(dbLowerName) { + if ctx.GetSessionVars().User != nil { + return infoschema.ErrAccessDenied.GenWithStackByArgs(ctx.GetSessionVars().User.Username, ctx.GetSessionVars().User.Hostname) + } + return infoschema.ErrAccessDenied.GenWithStackByArgs("", "") + } + return nil +} + +func (e *executor) CleanupTableLock(ctx sessionctx.Context, tables []*ast.TableName) error { + uniqueTableID := make(map[int64]struct{}) + cleanupTables := make([]model.TableLockTpInfo, 0, len(tables)) + unlockedTablesNum := 0 + involvingSchemaInfo := make([]model.InvolvingSchemaInfo, 0, len(tables)) + // Check whether the table was already locked by another. + for _, tb := range tables { + err := throwErrIfInMemOrSysDB(ctx, tb.Schema.L) + if err != nil { + return err + } + schema, t, err := e.getSchemaAndTableByIdent(ast.Ident{Schema: tb.Schema, Name: tb.Name}) + if err != nil { + return errors.Trace(err) + } + if t.Meta().IsView() || t.Meta().IsSequence() { + return table.ErrUnsupportedOp + } + // Maybe the table t was not locked, but still try to unlock this table. + // If we skip unlock the table here, the job maybe not consistent with the job.Query. + // eg: unlock tables t1,t2; If t2 is not locked and skip here, then the job will only unlock table t1, + // and this behaviour is not consistent with the sql query. + if !t.Meta().IsLocked() { + unlockedTablesNum++ + } + if _, ok := uniqueTableID[t.Meta().ID]; ok { + return infoschema.ErrNonuniqTable.GenWithStackByArgs(t.Meta().Name) + } + uniqueTableID[t.Meta().ID] = struct{}{} + cleanupTables = append(cleanupTables, model.TableLockTpInfo{SchemaID: schema.ID, TableID: t.Meta().ID}) + involvingSchemaInfo = append(involvingSchemaInfo, model.InvolvingSchemaInfo{ + Database: schema.Name.L, + Table: t.Meta().Name.L, + }) + } + // If the num of cleanupTables is 0, or all cleanupTables is unlocked, just return here. 
+ if len(cleanupTables) == 0 || len(cleanupTables) == unlockedTablesNum { + return nil + } + + arg := &LockTablesArg{ + UnlockTables: cleanupTables, + IsCleanup: true, + } + job := &model.Job{ + SchemaID: cleanupTables[0].SchemaID, + TableID: cleanupTables[0].TableID, + Type: model.ActionUnlockTable, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{arg}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + InvolvingSchemaInfo: involvingSchemaInfo, + SQLMode: ctx.GetSessionVars().SQLMode, + } + err := e.DoDDLJob(ctx, job) + if err == nil { + ctx.ReleaseTableLocks(cleanupTables) + } + return errors.Trace(err) +} + +// LockTablesArg is the argument for LockTables, export for test. +type LockTablesArg struct { + LockTables []model.TableLockTpInfo + IndexOfLock int + UnlockTables []model.TableLockTpInfo + IndexOfUnlock int + SessionInfo model.SessionInfo + IsCleanup bool +} + +func (e *executor) RepairTable(ctx sessionctx.Context, createStmt *ast.CreateTableStmt) error { + // Existence of DB and table has been checked in the preprocessor. + oldTableInfo, ok := (ctx.Value(domainutil.RepairedTable)).(*model.TableInfo) + if !ok || oldTableInfo == nil { + return dbterror.ErrRepairTableFail.GenWithStack("Failed to get the repaired table") + } + oldDBInfo, ok := (ctx.Value(domainutil.RepairedDatabase)).(*model.DBInfo) + if !ok || oldDBInfo == nil { + return dbterror.ErrRepairTableFail.GenWithStack("Failed to get the repaired database") + } + // By now only support same DB repair. + if createStmt.Table.Schema.L != oldDBInfo.Name.L { + return dbterror.ErrRepairTableFail.GenWithStack("Repaired table should in same database with the old one") + } + + // It is necessary to specify the table.ID and partition.ID manually. + newTableInfo, err := buildTableInfoWithCheck(ctx, createStmt, oldTableInfo.Charset, oldTableInfo.Collate, oldTableInfo.PlacementPolicyRef) + if err != nil { + return errors.Trace(err) + } + // Override newTableInfo with oldTableInfo's element necessary. + // TODO: There may be more element assignments here, and the new TableInfo should be verified with the actual data. + newTableInfo.ID = oldTableInfo.ID + if err = checkAndOverridePartitionID(newTableInfo, oldTableInfo); err != nil { + return err + } + newTableInfo.AutoIncID = oldTableInfo.AutoIncID + // If any old columnInfo has lost, that means the old column ID lost too, repair failed. + for i, newOne := range newTableInfo.Columns { + old := oldTableInfo.FindPublicColumnByName(newOne.Name.L) + if old == nil { + return dbterror.ErrRepairTableFail.GenWithStackByArgs("Column " + newOne.Name.L + " has lost") + } + if newOne.GetType() != old.GetType() { + return dbterror.ErrRepairTableFail.GenWithStackByArgs("Column " + newOne.Name.L + " type should be the same") + } + if newOne.GetFlen() != old.GetFlen() { + logutil.DDLLogger().Warn("admin repair table : Column " + newOne.Name.L + " flen is not equal to the old one") + } + newTableInfo.Columns[i].ID = old.ID + } + // If any old indexInfo has lost, that means the index ID lost too, so did the data, repair failed. 
+ for i, newOne := range newTableInfo.Indices { + old := getIndexInfoByNameAndColumn(oldTableInfo, newOne) + if old == nil { + return dbterror.ErrRepairTableFail.GenWithStackByArgs("Index " + newOne.Name.L + " has lost") + } + if newOne.Tp != old.Tp { + return dbterror.ErrRepairTableFail.GenWithStackByArgs("Index " + newOne.Name.L + " type should be the same") + } + newTableInfo.Indices[i].ID = old.ID + } + + newTableInfo.State = model.StatePublic + err = checkTableInfoValid(newTableInfo) + if err != nil { + return err + } + newTableInfo.State = model.StateNone + + job := &model.Job{ + SchemaID: oldDBInfo.ID, + TableID: newTableInfo.ID, + SchemaName: oldDBInfo.Name.L, + TableName: newTableInfo.Name.L, + Type: model.ActionRepairTable, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{newTableInfo}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + err = e.DoDDLJob(ctx, job) + if err == nil { + // Remove the old TableInfo from repairInfo before domain reload. + domainutil.RepairInfo.RemoveFromRepairInfo(oldDBInfo.Name.L, oldTableInfo.Name.L) + } + return errors.Trace(err) +} + +func (e *executor) OrderByColumns(ctx sessionctx.Context, ident ast.Ident) error { + _, tb, err := e.getSchemaAndTableByIdent(ident) + if err != nil { + return errors.Trace(err) + } + if tb.Meta().GetPkColInfo() != nil { + ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("ORDER BY ignored as there is a user-defined clustered index in the table '%s'", ident.Name)) + } + return nil +} + +func (e *executor) CreateSequence(ctx sessionctx.Context, stmt *ast.CreateSequenceStmt) error { + ident := ast.Ident{Name: stmt.Name.Name, Schema: stmt.Name.Schema} + sequenceInfo, err := buildSequenceInfo(stmt, ident) + if err != nil { + return err + } + // TiDB describe the sequence within a tableInfo, as a same-level object of a table and view. + tbInfo, err := BuildTableInfo(ctx, ident.Name, nil, nil, "", "") + if err != nil { + return err + } + tbInfo.Sequence = sequenceInfo + + onExist := OnExistError + if stmt.IfNotExists { + onExist = OnExistIgnore + } + + return e.CreateTableWithInfo(ctx, ident.Schema, tbInfo, nil, WithOnExist(onExist)) +} + +func (e *executor) AlterSequence(ctx sessionctx.Context, stmt *ast.AlterSequenceStmt) error { + ident := ast.Ident{Name: stmt.Name.Name, Schema: stmt.Name.Schema} + is := e.infoCache.GetLatest() + // Check schema existence. + db, ok := is.SchemaByName(ident.Schema) + if !ok { + return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ident.Schema) + } + // Check table existence. + tbl, err := is.TableByName(context.Background(), ident.Schema, ident.Name) + if err != nil { + if stmt.IfExists { + ctx.GetSessionVars().StmtCtx.AppendNote(err) + return nil + } + return err + } + if !tbl.Meta().IsSequence() { + return dbterror.ErrWrongObject.GenWithStackByArgs(ident.Schema, ident.Name, "SEQUENCE") + } + + // Validate the new sequence option value in old sequenceInfo. 
+ oldSequenceInfo := tbl.Meta().Sequence + copySequenceInfo := *oldSequenceInfo + _, _, err = alterSequenceOptions(stmt.SeqOptions, ident, &copySequenceInfo) + if err != nil { + return err + } + + job := &model.Job{ + SchemaID: db.ID, + TableID: tbl.Meta().ID, + SchemaName: db.Name.L, + TableName: tbl.Meta().Name.L, + Type: model.ActionAlterSequence, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{ident, stmt.SeqOptions}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +func (e *executor) DropSequence(ctx sessionctx.Context, stmt *ast.DropSequenceStmt) (err error) { + return e.dropTableObject(ctx, stmt.Sequences, stmt.IfExists, sequenceObject) +} + +func (e *executor) AlterIndexVisibility(ctx sessionctx.Context, ident ast.Ident, indexName model.CIStr, visibility ast.IndexVisibility) error { + schema, tb, err := e.getSchemaAndTableByIdent(ident) + if err != nil { + return err + } + + invisible := false + if visibility == ast.IndexVisibilityInvisible { + invisible = true + } + + skip, err := validateAlterIndexVisibility(ctx, indexName, invisible, tb.Meta()) + if err != nil { + return errors.Trace(err) + } + if skip { + return nil + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: tb.Meta().ID, + SchemaName: schema.Name.L, + TableName: tb.Meta().Name.L, + Type: model.ActionAlterIndexVisibility, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{indexName, invisible}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +func (e *executor) AlterTableAttributes(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { + schema, tb, err := e.getSchemaAndTableByIdent(ident) + if err != nil { + return errors.Trace(err) + } + meta := tb.Meta() + + rule := label.NewRule() + err = rule.ApplyAttributesSpec(spec.AttributesSpec) + if err != nil { + return dbterror.ErrInvalidAttributesSpec.GenWithStackByArgs(err) + } + ids := getIDs([]*model.TableInfo{meta}) + rule.Reset(schema.Name.L, meta.Name.L, "", ids...)
+ + job := &model.Job{ + SchemaID: schema.ID, + TableID: meta.ID, + SchemaName: schema.Name.L, + TableName: meta.Name.L, + Type: model.ActionAlterTableAttributes, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{rule}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err = e.DoDDLJob(ctx, job) + if err != nil { + return errors.Trace(err) + } + + return errors.Trace(err) +} + +func (e *executor) AlterTablePartitionAttributes(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) (err error) { + schema, tb, err := e.getSchemaAndTableByIdent(ident) + if err != nil { + return errors.Trace(err) + } + + meta := tb.Meta() + if meta.Partition == nil { + return errors.Trace(dbterror.ErrPartitionMgmtOnNonpartitioned) + } + + partitionID, err := tables.FindPartitionByName(meta, spec.PartitionNames[0].L) + if err != nil { + return errors.Trace(err) + } + + rule := label.NewRule() + err = rule.ApplyAttributesSpec(spec.AttributesSpec) + if err != nil { + return dbterror.ErrInvalidAttributesSpec.GenWithStackByArgs(err) + } + rule.Reset(schema.Name.L, meta.Name.L, spec.PartitionNames[0].L, partitionID) + + job := &model.Job{ + SchemaID: schema.ID, + TableID: meta.ID, + SchemaName: schema.Name.L, + TableName: meta.Name.L, + Type: model.ActionAlterTablePartitionAttributes, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{partitionID, rule}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err = e.DoDDLJob(ctx, job) + if err != nil { + return errors.Trace(err) + } + + return errors.Trace(err) +} + +func (e *executor) AlterTablePartitionOptions(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) (err error) { + var policyRefInfo *model.PolicyRefInfo + if spec.Options != nil { + for _, op := range spec.Options { + switch op.Tp { + case ast.TableOptionPlacementPolicy: + policyRefInfo = &model.PolicyRefInfo{ + Name: model.NewCIStr(op.StrValue), + } + default: + return errors.Trace(errors.New("unknown partition option")) + } + } + } + + if policyRefInfo != nil { + err = e.AlterTablePartitionPlacement(ctx, ident, spec, policyRefInfo) + if err != nil { + return errors.Trace(err) + } + } + + return nil +} + +func (e *executor) AlterTablePartitionPlacement(ctx sessionctx.Context, tableIdent ast.Ident, spec *ast.AlterTableSpec, policyRefInfo *model.PolicyRefInfo) (err error) { + schema, tb, err := e.getSchemaAndTableByIdent(tableIdent) + if err != nil { + return errors.Trace(err) + } + + tblInfo := tb.Meta() + if tblInfo.Partition == nil { + return errors.Trace(dbterror.ErrPartitionMgmtOnNonpartitioned) + } + + partitionID, err := tables.FindPartitionByName(tblInfo, spec.PartitionNames[0].L) + if err != nil { + return errors.Trace(err) + } + + if checkIgnorePlacementDDL(ctx) { + return nil + } + + policyRefInfo, err = checkAndNormalizePlacementPolicy(ctx, policyRefInfo) + if err != nil { + return errors.Trace(err) + } + + var involveSchemaInfo []model.InvolvingSchemaInfo + if policyRefInfo != nil { + involveSchemaInfo = []model.InvolvingSchemaInfo{ + { + Database: schema.Name.L, + Table: tblInfo.Name.L, + }, + { + Policy: policyRefInfo.Name.L, + Mode: model.SharedInvolving, + }, + } + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: tblInfo.ID, + SchemaName: schema.Name.L, + TableName: tblInfo.Name.L, + Type: model.ActionAlterTablePartitionPlacement, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{partitionID, policyRefInfo}, + CDCWriteSource: 
ctx.GetSessionVars().CDCWriteSource, + InvolvingSchemaInfo: involveSchemaInfo, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +// AddResourceGroup implements the DDL interface, creates a resource group. +func (e *executor) AddResourceGroup(ctx sessionctx.Context, stmt *ast.CreateResourceGroupStmt) (err error) { + groupName := stmt.ResourceGroupName + groupInfo := &model.ResourceGroupInfo{Name: groupName, ResourceGroupSettings: model.NewResourceGroupSettings()} + groupInfo, err = buildResourceGroup(groupInfo, stmt.ResourceGroupOptionList) + if err != nil { + return err + } + + if _, ok := e.infoCache.GetLatest().ResourceGroupByName(groupName); ok { + if stmt.IfNotExists { + err = infoschema.ErrResourceGroupExists.FastGenByArgs(groupName) + ctx.GetSessionVars().StmtCtx.AppendNote(err) + return nil + } + return infoschema.ErrResourceGroupExists.GenWithStackByArgs(groupName) + } + + if err := checkResourceGroupValidation(groupInfo); err != nil { + return err + } + + logutil.DDLLogger().Debug("create resource group", zap.String("name", groupName.O), zap.Stringer("resource group settings", groupInfo.ResourceGroupSettings)) + groupIDs, err := e.genGlobalIDs(1) + if err != nil { + return err + } + groupInfo.ID = groupIDs[0] + + job := &model.Job{ + SchemaName: groupName.L, + Type: model.ActionCreateResourceGroup, + BinlogInfo: &model.HistoryInfo{}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + Args: []any{groupInfo, false}, + InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ + ResourceGroup: groupInfo.Name.L, + }}, + SQLMode: ctx.GetSessionVars().SQLMode, + } + err = e.DoDDLJob(ctx, job) + return err +} + +// DropResourceGroup implements the DDL interface. +func (e *executor) DropResourceGroup(ctx sessionctx.Context, stmt *ast.DropResourceGroupStmt) (err error) { + groupName := stmt.ResourceGroupName + if groupName.L == rg.DefaultResourceGroupName { + return resourcegroup.ErrDroppingInternalResourceGroup + } + is := e.infoCache.GetLatest() + // Check group existence. + group, ok := is.ResourceGroupByName(groupName) + if !ok { + err = infoschema.ErrResourceGroupNotExists.GenWithStackByArgs(groupName) + if stmt.IfExists { + ctx.GetSessionVars().StmtCtx.AppendNote(err) + return nil + } + return err + } + + // check to see if some user has dependency on the group + checker := privilege.GetPrivilegeManager(ctx) + if checker == nil { + return errors.New("miss privilege checker") + } + user, matched := checker.MatchUserResourceGroupName(groupName.L) + if matched { + err = errors.Errorf("user [%s] depends on the resource group to drop", user) + return err + } + + job := &model.Job{ + SchemaID: group.ID, + SchemaName: group.Name.L, + Type: model.ActionDropResourceGroup, + BinlogInfo: &model.HistoryInfo{}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + Args: []any{groupName}, + InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ + ResourceGroup: groupName.L, + }}, + SQLMode: ctx.GetSessionVars().SQLMode, + } + err = e.DoDDLJob(ctx, job) + return err +} + +// AlterResourceGroup implements the DDL interface. +func (e *executor) AlterResourceGroup(ctx sessionctx.Context, stmt *ast.AlterResourceGroupStmt) (err error) { + groupName := stmt.ResourceGroupName + is := e.infoCache.GetLatest() + // Check group existence. 
+ group, ok := is.ResourceGroupByName(groupName) + if !ok { + err := infoschema.ErrResourceGroupNotExists.GenWithStackByArgs(groupName) + if stmt.IfExists { + ctx.GetSessionVars().StmtCtx.AppendNote(err) + return nil + } + return err + } + newGroupInfo, err := buildResourceGroup(group, stmt.ResourceGroupOptionList) + if err != nil { + return errors.Trace(err) + } + + if err := checkResourceGroupValidation(newGroupInfo); err != nil { + return err + } + + logutil.DDLLogger().Debug("alter resource group", zap.String("name", groupName.L), zap.Stringer("new resource group settings", newGroupInfo.ResourceGroupSettings)) + + job := &model.Job{ + SchemaID: newGroupInfo.ID, + SchemaName: newGroupInfo.Name.L, + Type: model.ActionAlterResourceGroup, + BinlogInfo: &model.HistoryInfo{}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + Args: []any{newGroupInfo}, + InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ + ResourceGroup: newGroupInfo.Name.L, + }}, + SQLMode: ctx.GetSessionVars().SQLMode, + } + err = e.DoDDLJob(ctx, job) + return err +} + +func (e *executor) CreatePlacementPolicy(ctx sessionctx.Context, stmt *ast.CreatePlacementPolicyStmt) (err error) { + if checkIgnorePlacementDDL(ctx) { + return nil + } + + if stmt.OrReplace && stmt.IfNotExists { + return dbterror.ErrWrongUsage.GenWithStackByArgs("OR REPLACE", "IF NOT EXISTS") + } + + policyInfo, err := buildPolicyInfo(stmt.PolicyName, stmt.PlacementOptions) + if err != nil { + return errors.Trace(err) + } + + var onExists OnExist + switch { + case stmt.IfNotExists: + onExists = OnExistIgnore + case stmt.OrReplace: + onExists = OnExistReplace + default: + onExists = OnExistError + } + + return e.CreatePlacementPolicyWithInfo(ctx, policyInfo, onExists) +} + +func (e *executor) DropPlacementPolicy(ctx sessionctx.Context, stmt *ast.DropPlacementPolicyStmt) (err error) { + if checkIgnorePlacementDDL(ctx) { + return nil + } + policyName := stmt.PolicyName + is := e.infoCache.GetLatest() + // Check policy existence. + policy, ok := is.PolicyByName(policyName) + if !ok { + err = infoschema.ErrPlacementPolicyNotExists.GenWithStackByArgs(policyName) + if stmt.IfExists { + ctx.GetSessionVars().StmtCtx.AppendNote(err) + return nil + } + return err + } + + if err = CheckPlacementPolicyNotInUseFromInfoSchema(is, policy); err != nil { + return err + } + + job := &model.Job{ + SchemaID: policy.ID, + SchemaName: policy.Name.L, + Type: model.ActionDropPlacementPolicy, + BinlogInfo: &model.HistoryInfo{}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + Args: []any{policyName}, + InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ + Policy: policyName.L, + }}, + SQLMode: ctx.GetSessionVars().SQLMode, + } + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +func (e *executor) AlterPlacementPolicy(ctx sessionctx.Context, stmt *ast.AlterPlacementPolicyStmt) (err error) { + if checkIgnorePlacementDDL(ctx) { + return nil + } + policyName := stmt.PolicyName + is := e.infoCache.GetLatest() + // Check policy existence. 
+ policy, ok := is.PolicyByName(policyName) + if !ok { + return infoschema.ErrPlacementPolicyNotExists.GenWithStackByArgs(policyName) + } + + newPolicyInfo, err := buildPolicyInfo(policy.Name, stmt.PlacementOptions) + if err != nil { + return errors.Trace(err) + } + + err = checkPolicyValidation(newPolicyInfo.PlacementSettings) + if err != nil { + return err + } + + job := &model.Job{ + SchemaID: policy.ID, + SchemaName: policy.Name.L, + Type: model.ActionAlterPlacementPolicy, + BinlogInfo: &model.HistoryInfo{}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + Args: []any{newPolicyInfo}, + InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ + Policy: newPolicyInfo.Name.L, + }}, + SQLMode: ctx.GetSessionVars().SQLMode, + } + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +func (e *executor) AlterTableCache(sctx sessionctx.Context, ti ast.Ident) (err error) { + schema, t, err := e.getSchemaAndTableByIdent(ti) + if err != nil { + return err + } + // if a table is already in cache state, return directly + if t.Meta().TableCacheStatusType == model.TableCacheStatusEnable { + return nil + } + + // forbidden cache table in system database. + if util.IsMemOrSysDB(schema.Name.L) { + return errors.Trace(dbterror.ErrUnsupportedAlterCacheForSysTable) + } else if t.Meta().TempTableType != model.TempTableNone { + return dbterror.ErrOptOnTemporaryTable.GenWithStackByArgs("alter temporary table cache") + } + + if t.Meta().Partition != nil { + return dbterror.ErrOptOnCacheTable.GenWithStackByArgs("partition mode") + } + + succ, err := checkCacheTableSize(e.store, t.Meta().ID) + if err != nil { + return errors.Trace(err) + } + if !succ { + return dbterror.ErrOptOnCacheTable.GenWithStackByArgs("table too large") + } + + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) + ddlQuery, _ := sctx.Value(sessionctx.QueryString).(string) + // Initialize the cached table meta lock info in `mysql.table_cache_meta`. + // The operation shouldn't fail in most cases, and if it does, return the error directly. + // This DML and the following DDL is not atomic, that's not a problem. 
+ _, _, err = sctx.GetRestrictedSQLExecutor().ExecRestrictedSQL(ctx, nil, + "replace into mysql.table_cache_meta values (%?, 'NONE', 0, 0)", t.Meta().ID) + if err != nil { + return errors.Trace(err) + } + + sctx.SetValue(sessionctx.QueryString, ddlQuery) + + job := &model.Job{ + SchemaID: schema.ID, + SchemaName: schema.Name.L, + TableName: t.Meta().Name.L, + TableID: t.Meta().ID, + Type: model.ActionAlterCacheTable, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{}, + CDCWriteSource: sctx.GetSessionVars().CDCWriteSource, + SQLMode: sctx.GetSessionVars().SQLMode, + } + + return e.DoDDLJob(sctx, job) +} + +func checkCacheTableSize(store kv.Storage, tableID int64) (bool, error) { + const cacheTableSizeLimit = 64 * (1 << 20) // 64M + succ := true + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnCacheTable) + err := kv.RunInNewTxn(ctx, store, true, func(_ context.Context, txn kv.Transaction) error { + txn.SetOption(kv.RequestSourceType, kv.InternalTxnCacheTable) + prefix := tablecodec.GenTablePrefix(tableID) + it, err := txn.Iter(prefix, prefix.PrefixNext()) + if err != nil { + return errors.Trace(err) + } + defer it.Close() + + totalSize := 0 + for it.Valid() && it.Key().HasPrefix(prefix) { + key := it.Key() + value := it.Value() + totalSize += len(key) + totalSize += len(value) + + if totalSize > cacheTableSizeLimit { + succ = false + break + } + + err = it.Next() + if err != nil { + return errors.Trace(err) + } + } + return nil + }) + return succ, err +} + +func (e *executor) AlterTableNoCache(ctx sessionctx.Context, ti ast.Ident) (err error) { + schema, t, err := e.getSchemaAndTableByIdent(ti) + if err != nil { + return err + } + // if a table is not in cache state, return directly + if t.Meta().TableCacheStatusType == model.TableCacheStatusDisable { + return nil + } + + job := &model.Job{ + SchemaID: schema.ID, + SchemaName: schema.Name.L, + TableName: t.Meta().Name.L, + TableID: t.Meta().ID, + Type: model.ActionAlterNoCacheTable, + BinlogInfo: &model.HistoryInfo{}, + Args: []any{}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + return e.DoDDLJob(ctx, job) +} + +func (e *executor) CreateCheckConstraint(ctx sessionctx.Context, ti ast.Ident, constrName model.CIStr, constr *ast.Constraint) error { + schema, t, err := e.getSchemaAndTableByIdent(ti) + if err != nil { + return errors.Trace(err) + } + if constraintInfo := t.Meta().FindConstraintInfoByName(constrName.L); constraintInfo != nil { + return infoschema.ErrCheckConstraintDupName.GenWithStackByArgs(constrName.L) + } + + // allocate the temporary constraint name for dependency-check-error-output below. + constrNames := map[string]bool{} + for _, constr := range t.Meta().Constraints { + constrNames[constr.Name.L] = true + } + setEmptyCheckConstraintName(t.Meta().Name.L, constrNames, []*ast.Constraint{constr}) + + // existedColsMap can be used to check the existence of depended. + existedColsMap := make(map[string]struct{}) + cols := t.Cols() + for _, v := range cols { + existedColsMap[v.Name.L] = struct{}{} + } + // check expression if supported + if ok, err := table.IsSupportedExpr(constr); !ok { + return err + } + + dependedColsMap := findDependentColsInExpr(constr.Expr) + dependedCols := make([]model.CIStr, 0, len(dependedColsMap)) + for k := range dependedColsMap { + if _, ok := existedColsMap[k]; !ok { + // The table constraint depended on a non-existed column. 
+ return dbterror.ErrBadField.GenWithStackByArgs(k, "check constraint "+constr.Name+" expression") + } + dependedCols = append(dependedCols, model.NewCIStr(k)) + } + + // build constraint meta info. + tblInfo := t.Meta() + + // check auto-increment column + if table.ContainsAutoIncrementCol(dependedCols, tblInfo) { + return dbterror.ErrCheckConstraintRefersAutoIncrementColumn.GenWithStackByArgs(constr.Name) + } + // check foreign key + if err := table.HasForeignKeyRefAction(tblInfo.ForeignKeys, nil, constr, dependedCols); err != nil { + return err + } + constraintInfo, err := buildConstraintInfo(tblInfo, dependedCols, constr, model.StateNone) + if err != nil { + return errors.Trace(err) + } + // check if the expression is bool type + if err := table.IfCheckConstraintExprBoolType(ctx.GetExprCtx().GetEvalCtx(), constraintInfo, tblInfo); err != nil { + return err + } + job := &model.Job{ + SchemaID: schema.ID, + TableID: tblInfo.ID, + SchemaName: schema.Name.L, + TableName: tblInfo.Name.L, + Type: model.ActionAddCheckConstraint, + BinlogInfo: &model.HistoryInfo{}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + Args: []any{constraintInfo}, + Priority: ctx.GetSessionVars().DDLReorgPriority, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +func (e *executor) DropCheckConstraint(ctx sessionctx.Context, ti ast.Ident, constrName model.CIStr) error { + is := e.infoCache.GetLatest() + schema, ok := is.SchemaByName(ti.Schema) + if !ok { + return errors.Trace(infoschema.ErrDatabaseNotExists) + } + t, err := is.TableByName(context.Background(), ti.Schema, ti.Name) + if err != nil { + return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ti.Schema, ti.Name)) + } + tblInfo := t.Meta() + + constraintInfo := tblInfo.FindConstraintInfoByName(constrName.L) + if constraintInfo == nil { + return dbterror.ErrConstraintNotFound.GenWithStackByArgs(constrName) + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: tblInfo.ID, + SchemaName: schema.Name.L, + TableName: tblInfo.Name.L, + Type: model.ActionDropCheckConstraint, + BinlogInfo: &model.HistoryInfo{}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + Args: []any{constrName}, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +func (e *executor) AlterCheckConstraint(ctx sessionctx.Context, ti ast.Ident, constrName model.CIStr, enforced bool) error { + is := e.infoCache.GetLatest() + schema, ok := is.SchemaByName(ti.Schema) + if !ok { + return errors.Trace(infoschema.ErrDatabaseNotExists) + } + t, err := is.TableByName(context.Background(), ti.Schema, ti.Name) + if err != nil { + return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ti.Schema, ti.Name)) + } + tblInfo := t.Meta() + + constraintInfo := tblInfo.FindConstraintInfoByName(constrName.L) + if constraintInfo == nil { + return dbterror.ErrConstraintNotFound.GenWithStackByArgs(constrName) + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: tblInfo.ID, + SchemaName: schema.Name.L, + TableName: tblInfo.Name.L, + Type: model.ActionAlterCheckConstraint, + BinlogInfo: &model.HistoryInfo{}, + CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, + Args: []any{constrName, enforced}, + SQLMode: ctx.GetSessionVars().SQLMode, + } + + err = e.DoDDLJob(ctx, job) + return errors.Trace(err) +} + +func (e *executor) genGlobalIDs(count int) ([]int64, error) { + var ret []int64 + ctx := kv.WithInternalSourceType(context.Background(), 
kv.InternalTxnDDL) + // lock to reduce conflict + e.globalIDLock.Lock() + defer e.globalIDLock.Unlock() + err := kv.RunInNewTxn(ctx, e.store, true, func(_ context.Context, txn kv.Transaction) error { + m := meta.NewMeta(txn) + var err error + ret, err = m.GenGlobalIDs(count) + return err + }) + + return ret, err +} + +func (e *executor) genPlacementPolicyID() (int64, error) { + var ret int64 + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) + err := kv.RunInNewTxn(ctx, e.store, true, func(_ context.Context, txn kv.Transaction) error { + m := meta.NewMeta(txn) + var err error + ret, err = m.GenPlacementPolicyID() + return err + }) + + return ret, err +} + +// DoDDLJob will return +// - nil: found in history DDL job and no job error +// - context.Cancel: job has been sent to worker, but not found in history DDL job before cancel +// - other: found in history DDL job and return that job error +func (e *executor) DoDDLJob(ctx sessionctx.Context, job *model.Job) error { + return e.DoDDLJobWrapper(ctx, NewJobWrapper(job, false)) +} + +// DoDDLJobWrapper submit DDL job and wait it finishes. +// When fast create is enabled, we might merge multiple jobs into one, so do not +// depend on job.ID, use JobID from jobSubmitResult. +func (e *executor) DoDDLJobWrapper(ctx sessionctx.Context, jobW *JobWrapper) error { + job := jobW.Job + job.TraceInfo = &model.TraceInfo{ + ConnectionID: ctx.GetSessionVars().ConnectionID, + SessionAlias: ctx.GetSessionVars().SessionAlias, + } + if mci := ctx.GetSessionVars().StmtCtx.MultiSchemaInfo; mci != nil { + // In multiple schema change, we don't run the job. + // Instead, we merge all the jobs into one pending job. + return appendToSubJobs(mci, job) + } + // Get a global job ID and put the DDL job in the queue. + setDDLJobQuery(ctx, job) + e.deliverJobTask(jobW) + + failpoint.Inject("mockParallelSameDDLJobTwice", func(val failpoint.Value) { + if val.(bool) { + <-jobW.ResultCh[0] + // The same job will be put to the DDL queue twice. + job = job.Clone() + newJobW := NewJobWrapper(job, jobW.IDAllocated) + e.deliverJobTask(newJobW) + // The second job result is used for test. + jobW = newJobW + } + }) + + // worker should restart to continue handling tasks in limitJobCh, and send back through jobW.err + result := <-jobW.ResultCh[0] + // job.ID must be allocated after previous channel receive returns nil. + jobID, err := result.jobID, result.err + defer e.delJobDoneCh(jobID) + if err != nil { + // The transaction of enqueuing job is failed. + return errors.Trace(err) + } + failpoint.InjectCall("waitJobSubmitted") + + sessVars := ctx.GetSessionVars() + sessVars.StmtCtx.IsDDLJobInQueue = true + + ddlAction := job.Type + // Notice worker that we push a new job and wait the job done. + e.notifyNewJobSubmitted(e.ddlJobNotifyCh, addingDDLJobNotifyKey, jobID, ddlAction.String()) + if result.merged { + logutil.DDLLogger().Info("DDL job submitted", zap.Int64("job_id", jobID), zap.String("query", job.Query), zap.String("merged", "true")) + } else { + logutil.DDLLogger().Info("DDL job submitted", zap.Stringer("job", job), zap.String("query", job.Query)) + } + + var historyJob *model.Job + + // Attach the context of the jobId to the calling session so that + // KILL can cancel this DDL job. 
+ ctx.GetSessionVars().StmtCtx.DDLJobID = jobID + + // For a job from start to end, the state of it will be none -> delete only -> write only -> reorganization -> public + // For every state changes, we will wait as lease 2 * lease time, so here the ticker check is 10 * lease. + // But we use etcd to speed up, normally it takes less than 0.5s now, so we use 0.5s or 1s or 3s as the max value. + initInterval, _ := getJobCheckInterval(ddlAction, 0) + ticker := time.NewTicker(chooseLeaseTime(10*e.lease, initInterval)) + startTime := time.Now() + metrics.JobsGauge.WithLabelValues(ddlAction.String()).Inc() + defer func() { + ticker.Stop() + metrics.JobsGauge.WithLabelValues(ddlAction.String()).Dec() + metrics.HandleJobHistogram.WithLabelValues(ddlAction.String(), metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) + recordLastDDLInfo(ctx, historyJob) + }() + i := 0 + notifyCh, _ := e.getJobDoneCh(jobID) + for { + failpoint.InjectCall("storeCloseInLoop") + select { + case _, ok := <-notifyCh: + if !ok { + // when fast create enabled, jobs might be merged, and we broadcast + // the result by closing the channel, to avoid this loop keeps running + // without sleeping on retryable error, we set it to nil. + notifyCh = nil + } + case <-ticker.C: + i++ + ticker = updateTickerInterval(ticker, 10*e.lease, ddlAction, i) + case <-e.ctx.Done(): + logutil.DDLLogger().Info("DoDDLJob will quit because context done") + return context.Canceled + } + + // If the connection being killed, we need to CANCEL the DDL job. + if sessVars.SQLKiller.HandleSignal() == exeerrors.ErrQueryInterrupted { + if atomic.LoadInt32(&sessVars.ConnectionStatus) == variable.ConnStatusShutdown { + logutil.DDLLogger().Info("DoDDLJob will quit because context done") + return context.Canceled + } + if sessVars.StmtCtx.DDLJobID != 0 { + se, err := e.sessPool.Get() + if err != nil { + logutil.DDLLogger().Error("get session failed, check again", zap.Error(err)) + continue + } + sessVars.StmtCtx.DDLJobID = 0 // Avoid repeat. + errs, err := CancelJobsBySystem(se, []int64{jobID}) + e.sessPool.Put(se) + if len(errs) > 0 { + logutil.DDLLogger().Warn("error canceling DDL job", zap.Error(errs[0])) + } + if err != nil { + logutil.DDLLogger().Warn("Kill command could not cancel DDL job", zap.Error(err)) + continue + } + } + } + + se, err := e.sessPool.Get() + if err != nil { + logutil.DDLLogger().Error("get session failed, check again", zap.Error(err)) + continue + } + historyJob, err = GetHistoryJobByID(se, jobID) + e.sessPool.Put(se) + if err != nil { + logutil.DDLLogger().Error("get history DDL job failed, check again", zap.Error(err)) + continue + } + if historyJob == nil { + logutil.DDLLogger().Debug("DDL job is not in history, maybe not run", zap.Int64("jobID", jobID)) + continue + } + + e.checkHistoryJobInTest(ctx, historyJob) + + // If a job is a history job, the state must be JobStateSynced or JobStateRollbackDone or JobStateCancelled. + if historyJob.IsSynced() { + // Judge whether there are some warnings when executing DDL under the certain SQL mode. 
+ if historyJob.ReorgMeta != nil && len(historyJob.ReorgMeta.Warnings) != 0 { + if len(historyJob.ReorgMeta.Warnings) != len(historyJob.ReorgMeta.WarningsCount) { + logutil.DDLLogger().Info("DDL warnings doesn't match the warnings count", zap.Int64("jobID", jobID)) + } else { + for key, warning := range historyJob.ReorgMeta.Warnings { + keyCount := historyJob.ReorgMeta.WarningsCount[key] + if keyCount == 1 { + ctx.GetSessionVars().StmtCtx.AppendWarning(warning) + } else { + newMsg := fmt.Sprintf("%d warnings with this error code, first warning: "+warning.GetMsg(), keyCount) + newWarning := dbterror.ClassTypes.Synthesize(terror.ErrCode(warning.Code()), newMsg) + ctx.GetSessionVars().StmtCtx.AppendWarning(newWarning) + } + } + } + } + appendMultiChangeWarningsToOwnerCtx(ctx, historyJob) + + logutil.DDLLogger().Info("DDL job is finished", zap.Int64("jobID", jobID)) + return nil + } + + if historyJob.Error != nil { + logutil.DDLLogger().Info("DDL job is failed", zap.Int64("jobID", jobID)) + return errors.Trace(historyJob.Error) + } + panic("When the state is JobStateRollbackDone or JobStateCancelled, historyJob.Error should never be nil") + } +} + +func (e *executor) getJobDoneCh(jobID int64) (chan struct{}, bool) { + return e.ddlJobDoneChMap.Load(jobID) +} + +func (e *executor) delJobDoneCh(jobID int64) { + e.ddlJobDoneChMap.Delete(jobID) +} + +func (e *executor) deliverJobTask(task *JobWrapper) { + // TODO this might block forever, as the consumer part considers context cancel. + e.limitJobCh <- task +} + +func updateTickerInterval(ticker *time.Ticker, lease time.Duration, action model.ActionType, i int) *time.Ticker { + interval, changed := getJobCheckInterval(action, i) + if !changed { + return ticker + } + // For now we should stop old ticker and create a new ticker + ticker.Stop() + return time.NewTicker(chooseLeaseTime(lease, interval)) +} + +func recordLastDDLInfo(ctx sessionctx.Context, job *model.Job) { + if job == nil { + return + } + ctx.GetSessionVars().LastDDLInfo.Query = job.Query + ctx.GetSessionVars().LastDDLInfo.SeqNum = job.SeqNum +} + +func setDDLJobQuery(ctx sessionctx.Context, job *model.Job) { + switch job.Type { + case model.ActionUpdateTiFlashReplicaStatus, model.ActionUnlockTable: + job.Query = "" + default: + job.Query, _ = ctx.Value(sessionctx.QueryString).(string) + } +} + +var ( + fastDDLIntervalPolicy = []time.Duration{ + 500 * time.Millisecond, + } + normalDDLIntervalPolicy = []time.Duration{ + 500 * time.Millisecond, + 500 * time.Millisecond, + 1 * time.Second, + } + slowDDLIntervalPolicy = []time.Duration{ + 500 * time.Millisecond, + 500 * time.Millisecond, + 1 * time.Second, + 1 * time.Second, + 3 * time.Second, + } +) + +func getIntervalFromPolicy(policy []time.Duration, i int) (time.Duration, bool) { + plen := len(policy) + if i < plen { + return policy[i], true + } + return policy[plen-1], false +} + +func getJobCheckInterval(action model.ActionType, i int) (time.Duration, bool) { + switch action { + case model.ActionAddIndex, model.ActionAddPrimaryKey, model.ActionModifyColumn, + model.ActionReorganizePartition, + model.ActionRemovePartitioning, + model.ActionAlterTablePartitioning: + return getIntervalFromPolicy(slowDDLIntervalPolicy, i) + case model.ActionCreateTable, model.ActionCreateSchema: + return getIntervalFromPolicy(fastDDLIntervalPolicy, i) + default: + return getIntervalFromPolicy(normalDDLIntervalPolicy, i) + } +} + +// NewDDLReorgMeta create a DDL ReorgMeta. 
+func NewDDLReorgMeta(ctx sessionctx.Context) *model.DDLReorgMeta { + tzName, tzOffset := ddlutil.GetTimeZone(ctx) + return &model.DDLReorgMeta{ + SQLMode: ctx.GetSessionVars().SQLMode, + Warnings: make(map[errors.ErrorID]*terror.Error), + WarningsCount: make(map[errors.ErrorID]int64), + Location: &model.TimeZoneLocation{Name: tzName, Offset: tzOffset}, + ResourceGroupName: ctx.GetSessionVars().StmtCtx.ResourceGroupName, + Version: model.CurrentReorgMetaVersion, + } +} diff --git a/pkg/ddl/index.go b/pkg/ddl/index.go index d9b2c97d50903..168022effcbdf 100644 --- a/pkg/ddl/index.go +++ b/pkg/ddl/index.go @@ -791,7 +791,7 @@ SwitchIndexState: } // Inject the failpoint to prevent the progress of index creation. - failpoint.Inject("create-index-stuck-before-public", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("create-index-stuck-before-public")); _err_ == nil { if sigFile, ok := v.(string); ok { for { time.Sleep(1 * time.Second) @@ -799,12 +799,12 @@ SwitchIndexState: if os.IsNotExist(err) { continue } - failpoint.Return(ver, errors.Trace(err)) + return ver, errors.Trace(err) } break } } - }) + } ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != model.StatePublic) if err != nil { @@ -931,11 +931,11 @@ func doReorgWorkForCreateIndex( ver, err = updateVersionAndTableInfo(d, t, job, tbl.Meta(), true) return false, ver, errors.Trace(err) case model.BackfillStateReadyToMerge: - failpoint.Inject("mockDMLExecutionStateBeforeMerge", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("mockDMLExecutionStateBeforeMerge")); _err_ == nil { if MockDMLExecutionStateBeforeMerge != nil { MockDMLExecutionStateBeforeMerge() } - }) + } logutil.DDLLogger().Info("index backfill state ready to merge", zap.Int64("job ID", job.ID), zap.String("table", tbl.Meta().Name.O), @@ -993,7 +993,7 @@ func runIngestReorgJob(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job, } return false, ver, errors.Trace(err) } - failpoint.InjectCall("afterRunIngestReorgJob", job, done) + failpoint.Call(_curpkg_("afterRunIngestReorgJob"), job, done) return done, ver, nil } @@ -1025,13 +1025,13 @@ func runReorgJobAndHandleErr( elements = append(elements, &meta.Element{ID: indexInfo.ID, TypeKey: meta.IndexElementKey}) } - failpoint.Inject("mockDMLExecutionStateMerging", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockDMLExecutionStateMerging")); _err_ == nil { //nolint:forcetypeassert if val.(bool) && allIndexInfos[0].BackfillState == model.BackfillStateMerging && MockDMLExecutionStateMerging != nil { MockDMLExecutionStateMerging() } - }) + } sctx, err1 := w.sessPool.Get() if err1 != nil { @@ -1080,11 +1080,11 @@ func runReorgJobAndHandleErr( } return false, ver, errors.Trace(err) } - failpoint.Inject("mockDMLExecutionStateBeforeImport", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("mockDMLExecutionStateBeforeImport")); _err_ == nil { if MockDMLExecutionStateBeforeImport != nil { MockDMLExecutionStateBeforeImport() } - }) + } return true, ver, nil } @@ -1149,12 +1149,12 @@ func onDropIndex(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { idxIDs = append(idxIDs, indexInfo.ID) } - failpoint.Inject("mockExceedErrorLimit", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockExceedErrorLimit")); _err_ == nil { //nolint:forcetypeassert if val.(bool) { panic("panic test in cancelling add index") } - }) + } ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, originalState != 
model.StateNone) if err != nil { @@ -1423,25 +1423,25 @@ var mockNotOwnerErrOnce uint32 // getIndexRecord gets index columns values use w.rowDecoder, and generate indexRecord. func (w *baseIndexWorker) getIndexRecord(idxInfo *model.IndexInfo, handle kv.Handle, recordKey []byte) (*indexRecord, error) { cols := w.table.WritableCols() - failpoint.Inject("MockGetIndexRecordErr", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("MockGetIndexRecordErr")); _err_ == nil { if valStr, ok := val.(string); ok { switch valStr { case "cantDecodeRecordErr": - failpoint.Return(nil, errors.Trace(dbterror.ErrCantDecodeRecord.GenWithStackByArgs("index", - errors.New("mock can't decode record error")))) + return nil, errors.Trace(dbterror.ErrCantDecodeRecord.GenWithStackByArgs("index", + errors.New("mock can't decode record error"))) case "modifyColumnNotOwnerErr": if idxInfo.Name.O == "_Idx$_idx_0" && handle.IntValue() == 7168 && atomic.CompareAndSwapUint32(&mockNotOwnerErrOnce, 0, 1) { - failpoint.Return(nil, errors.Trace(dbterror.ErrNotOwner)) + return nil, errors.Trace(dbterror.ErrNotOwner) } case "addIdxNotOwnerErr": // For the case of the old TiDB version(do not exist the element information) is upgraded to the new TiDB version. // First step, we need to exit "addPhysicalTableIndex". if idxInfo.Name.O == "idx2" && handle.IntValue() == 6144 && atomic.CompareAndSwapUint32(&mockNotOwnerErrOnce, 1, 2) { - failpoint.Return(nil, errors.Trace(dbterror.ErrNotOwner)) + return nil, errors.Trace(dbterror.ErrNotOwner) } } } - }) + } idxVal := make([]types.Datum, len(idxInfo.Columns)) var err error for j, v := range idxInfo.Columns { @@ -1768,16 +1768,16 @@ func writeOneKVToLocal( if err != nil { return errors.Trace(err) } - failpoint.Inject("mockLocalWriterPanic", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockLocalWriterPanic")); _err_ == nil { panic("mock panic") - }) + } err = writer.WriteRow(ctx, key, idxVal, handle) if err != nil { return errors.Trace(err) } - failpoint.Inject("mockLocalWriterError", func() { - failpoint.Return(errors.New("mock engine error")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("mockLocalWriterError")); _err_ == nil { + return errors.New("mock engine error") + } writeBufs.IndexKeyBuf = key writeBufs.RowValBuf = idxVal } @@ -1788,12 +1788,12 @@ func writeOneKVToLocal( // Note that index columns values may change, and an index is not allowed to be added, so the txn will rollback and retry. // BackfillData will add w.batchCnt indices once, default value of w.batchCnt is 128. 
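The hunks in pkg/ddl/index.go are the mechanical expansion produced by the failpoint tooling: every failpoint.Inject / failpoint.InjectCall marker becomes an explicit failpoint.Eval / failpoint.Call guarded by its error, and the generated _curpkg_ helper (assumed here to come from the failpoint-ctl binding file, it is not shown in this patch) prefixes the failpoint name with the package path. A minimal sketch of the underlying mechanism, using a hypothetical plain-string failpoint name:

package main

import (
	"fmt"

	"github.com/pingcap/failpoint"
)

const fpName = "example/pkg/mockError" // hypothetical failpoint name

// doWork evaluates the failpoint directly, which is the expanded form checked
// in above: the Inject marker is rewritten into an Eval call guarded by its
// error, so the branch only runs while the failpoint is enabled.
func doWork() error {
	if val, err := failpoint.Eval(fpName); err == nil {
		if val.(bool) {
			return fmt.Errorf("mock error injected")
		}
	}
	return nil
}

func main() {
	fmt.Println("before enable:", doWork())
	if err := failpoint.Enable(fpName, "return(true)"); err != nil {
		panic(err)
	}
	defer failpoint.Disable(fpName)
	fmt.Println("after enable:", doWork())
}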
func (w *addIndexTxnWorker) BackfillData(handleRange reorgBackfillTask) (taskCtx backfillTaskContext, errInTxn error) { - failpoint.Inject("errorMockPanic", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("errorMockPanic")); _err_ == nil { //nolint:forcetypeassert if val.(bool) { panic("panic test") } - }) + } oprStartTime := time.Now() jobID := handleRange.getJobID() @@ -1858,12 +1858,12 @@ func (w *addIndexTxnWorker) BackfillData(handleRange reorgBackfillTask) (taskCtx return nil }) logSlowOperations(time.Since(oprStartTime), "AddIndexBackfillData", 3000) - failpoint.Inject("mockDMLExecution", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockDMLExecution")); _err_ == nil { //nolint:forcetypeassert if val.(bool) && MockDMLExecution != nil { MockDMLExecution() } - }) + } return } @@ -1923,7 +1923,7 @@ func (w *worker) addTableIndex(t table.Table, reorgInfo *reorgInfo) error { if err != nil { return errors.Trace(err) } - failpoint.InjectCall("afterUpdatePartitionReorgInfo", reorgInfo.Job) + failpoint.Call(_curpkg_("afterUpdatePartitionReorgInfo"), reorgInfo.Job) // Every time we finish a partition, we update the progress of the job. if rc := w.getReorgCtx(reorgInfo.Job.ID); rc != nil { reorgInfo.Job.SetRowCount(rc.getRowCount()) @@ -2055,7 +2055,7 @@ func (w *worker) executeDistTask(t table.Table, reorgInfo *reorgInfo) error { g.Go(func() error { defer close(done) err := submitAndWaitTask(ctx, taskKey, taskType, concurrency, reorgInfo.ReorgMeta.TargetScope, metaData) - failpoint.InjectCall("pauseAfterDistTaskFinished") + failpoint.Call(_curpkg_("pauseAfterDistTaskFinished")) if err := w.isReorgRunnable(reorgInfo.Job.ID, true); err != nil { if dbterror.ErrPausedDDLJob.Equal(err) { logutil.DDLLogger().Warn("job paused by user", zap.Error(err)) @@ -2083,7 +2083,7 @@ func (w *worker) executeDistTask(t table.Table, reorgInfo *reorgInfo) error { logutil.DDLLogger().Error("pause task error", zap.String("task_key", taskKey), zap.Error(err)) continue } - failpoint.InjectCall("syncDDLTaskPause") + failpoint.Call(_curpkg_("syncDDLTaskPause")) } if !dbterror.ErrCancelledDDLJob.Equal(err) { return errors.Trace(err) @@ -2264,7 +2264,7 @@ func getNextPartitionInfo(reorg *reorgInfo, t table.PartitionedTable, currPhysic return 0, nil, nil, nil } - failpoint.Inject("mockUpdateCachedSafePoint", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockUpdateCachedSafePoint")); _err_ == nil { //nolint:forcetypeassert if val.(bool) { ts := oracle.GoTimeToTS(time.Now()) @@ -2273,7 +2273,7 @@ func getNextPartitionInfo(reorg *reorgInfo, t table.PartitionedTable, currPhysic s.UpdateSPCache(ts, time.Now()) time.Sleep(time.Second * 3) } - }) + } var startKey, endKey kv.Key if reorg.mergingTmpIdx { @@ -2414,12 +2414,12 @@ func newCleanUpIndexWorker(id int, t table.PhysicalTable, decodeColMap map[int64 } func (w *cleanUpIndexWorker) BackfillData(handleRange reorgBackfillTask) (taskCtx backfillTaskContext, errInTxn error) { - failpoint.Inject("errorMockPanic", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("errorMockPanic")); _err_ == nil { //nolint:forcetypeassert if val.(bool) { panic("panic test") } - }) + } oprStartTime := time.Now() ctx := kv.WithInternalSourceAndTaskType(context.Background(), w.jobContext.ddlJobSourceType(), kvutil.ExplicitTypeDDL) @@ -2457,12 +2457,12 @@ func (w *cleanUpIndexWorker) BackfillData(handleRange reorgBackfillTask) (taskCt return nil }) logSlowOperations(time.Since(oprStartTime), 
"cleanUpIndexBackfillDataInTxn", 3000) - failpoint.Inject("mockDMLExecution", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockDMLExecution")); _err_ == nil { //nolint:forcetypeassert if val.(bool) && MockDMLExecution != nil { MockDMLExecution() } - }) + } return } diff --git a/pkg/ddl/index.go__failpoint_stash__ b/pkg/ddl/index.go__failpoint_stash__ new file mode 100644 index 0000000000000..d9b2c97d50903 --- /dev/null +++ b/pkg/ddl/index.go__failpoint_stash__ @@ -0,0 +1,2616 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl + +import ( + "bytes" + "cmp" + "context" + "encoding/hex" + "encoding/json" + "fmt" + "os" + "slices" + "strings" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl/copr" + "github.com/pingcap/tidb/pkg/ddl/ingest" + "github.com/pingcap/tidb/pkg/ddl/logutil" + sess "github.com/pingcap/tidb/pkg/ddl/session" + ddlutil "github.com/pingcap/tidb/pkg/ddl/util" + "github.com/pingcap/tidb/pkg/disttask/framework/handle" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" + "github.com/pingcap/tidb/pkg/disttask/framework/storage" + "github.com/pingcap/tidb/pkg/errctx" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/lightning/backend" + litconfig "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/store/helper" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/backoff" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/dbterror" + tidblogutil "github.com/pingcap/tidb/pkg/util/logutil" + decoder "github.com/pingcap/tidb/pkg/util/rowDecoder" + "github.com/pingcap/tidb/pkg/util/size" + "github.com/pingcap/tidb/pkg/util/sqlexec" + "github.com/pingcap/tidb/pkg/util/stringutil" + "github.com/tikv/client-go/v2/oracle" + "github.com/tikv/client-go/v2/tikv" + kvutil "github.com/tikv/client-go/v2/util" + pd "github.com/tikv/pd/client" + pdHttp "github.com/tikv/pd/client/http" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" +) + +const ( + // MaxCommentLength is exported for testing. 
+ MaxCommentLength = 1024 +) + +var ( + // SuppressErrorTooLongKeyKey is used by SchemaTracker to suppress err too long key error + SuppressErrorTooLongKeyKey stringutil.StringerStr = "suppressErrorTooLongKeyKey" +) + +func suppressErrorTooLongKeyKey(sctx sessionctx.Context) bool { + if suppress, ok := sctx.Value(SuppressErrorTooLongKeyKey).(bool); ok && suppress { + return true + } + return false +} + +func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, indexPartSpecifications []*ast.IndexPartSpecification) ([]*model.IndexColumn, bool, error) { + // Build offsets. + idxParts := make([]*model.IndexColumn, 0, len(indexPartSpecifications)) + var col *model.ColumnInfo + var mvIndex bool + maxIndexLength := config.GetGlobalConfig().MaxIndexLength + // The sum of length of all index columns. + sumLength := 0 + for _, ip := range indexPartSpecifications { + col = model.FindColumnInfo(columns, ip.Column.Name.L) + if col == nil { + return nil, false, dbterror.ErrKeyColumnDoesNotExits.GenWithStack("column does not exist: %s", ip.Column.Name) + } + + if err := checkIndexColumn(ctx, col, ip.Length); err != nil { + return nil, false, err + } + if col.FieldType.IsArray() { + if mvIndex { + return nil, false, dbterror.ErrNotSupportedYet.GenWithStackByArgs("more than one multi-valued key part per index") + } + mvIndex = true + } + indexColLen := ip.Length + if indexColLen != types.UnspecifiedLength && + types.IsTypeChar(col.FieldType.GetType()) && + indexColLen == col.FieldType.GetFlen() { + indexColLen = types.UnspecifiedLength + } + indexColumnLength, err := getIndexColumnLength(col, indexColLen) + if err != nil { + return nil, false, err + } + sumLength += indexColumnLength + + // The sum of all lengths must be shorter than the max length for prefix. + if sumLength > maxIndexLength { + // The multiple column index and the unique index in which the length sum exceeds the maximum size + // will return an error instead produce a warning. + if ctx == nil || (ctx.GetSessionVars().SQLMode.HasStrictMode() && !suppressErrorTooLongKeyKey(ctx)) || mysql.HasUniKeyFlag(col.GetFlag()) || len(indexPartSpecifications) > 1 { + return nil, false, dbterror.ErrTooLongKey.GenWithStackByArgs(sumLength, maxIndexLength) + } + // truncate index length and produce warning message in non-restrict sql mode. + colLenPerUint, err := getIndexColumnLength(col, 1) + if err != nil { + return nil, false, err + } + indexColLen = maxIndexLength / colLenPerUint + // produce warning message + ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrTooLongKey.FastGenByArgs(sumLength, maxIndexLength)) + } + + idxParts = append(idxParts, &model.IndexColumn{ + Name: col.Name, + Offset: col.Offset, + Length: indexColLen, + }) + } + + return idxParts, mvIndex, nil +} + +// CheckPKOnGeneratedColumn checks the specification of PK is valid. +func CheckPKOnGeneratedColumn(tblInfo *model.TableInfo, indexPartSpecifications []*ast.IndexPartSpecification) (*model.ColumnInfo, error) { + var lastCol *model.ColumnInfo + for _, colName := range indexPartSpecifications { + lastCol = tblInfo.FindPublicColumnByName(colName.Column.Name.L) + if lastCol == nil { + return nil, dbterror.ErrKeyColumnDoesNotExits.GenWithStackByArgs(colName.Column.Name) + } + // Virtual columns cannot be used in primary key. 
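buildIndexColumns above sums the byte length of every key part and compares the total against max-index-length, multiplying character-counted prefix lengths by the charset's bytes per character. A toy calculation under those rules; the column definitions and the 3072 limit are example inputs chosen for illustration, not values taken from this patch:

package main

import "fmt"

// prefixKeyLen models the per-column accounting in buildIndexColumns: a prefix
// length counted in characters is multiplied by the charset's bytes per
// character before being added to the index-wide sum.
func prefixKeyLen(prefixChars, bytesPerChar int) int {
	return prefixChars * bytesPerChar
}

func main() {
	const maxIndexLength = 3072 // example limit; the real value comes from config
	parts := []struct {
		name         string
		prefixChars  int
		bytesPerChar int
	}{
		{"a varchar(255), utf8mb4, full column", 255, 4},
		{"b varchar(1024), utf8mb4, prefix(600)", 600, 4},
	}
	sum := 0
	for _, p := range parts {
		sum += prefixKeyLen(p.prefixChars, p.bytesPerChar)
		fmt.Printf("%-40s -> %4d bytes\n", p.name, prefixKeyLen(p.prefixChars, p.bytesPerChar))
	}
	fmt.Printf("total %d bytes, limit %d, too long: %v\n", sum, maxIndexLength, sum > maxIndexLength)
}

When the sum exceeds the limit, strict mode (or a unique/multi-column key) yields ErrTooLongKey, while non-strict mode truncates the prefix and appends a warning instead.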
+ if lastCol.IsGenerated() && !lastCol.GeneratedStored { + if lastCol.Hidden { + return nil, dbterror.ErrFunctionalIndexPrimaryKey + } + return nil, dbterror.ErrUnsupportedOnGeneratedColumn.GenWithStackByArgs("Defining a virtual generated column as primary key") + } + } + + return lastCol, nil +} + +func checkIndexPrefixLength(columns []*model.ColumnInfo, idxColumns []*model.IndexColumn) error { + idxLen, err := indexColumnsLen(columns, idxColumns) + if err != nil { + return err + } + if idxLen > config.GetGlobalConfig().MaxIndexLength { + return dbterror.ErrTooLongKey.GenWithStackByArgs(idxLen, config.GetGlobalConfig().MaxIndexLength) + } + return nil +} + +func indexColumnsLen(cols []*model.ColumnInfo, idxCols []*model.IndexColumn) (colLen int, err error) { + for _, idxCol := range idxCols { + col := model.FindColumnInfo(cols, idxCol.Name.L) + if col == nil { + err = dbterror.ErrKeyColumnDoesNotExits.GenWithStack("column does not exist: %s", idxCol.Name.L) + return + } + var l int + l, err = getIndexColumnLength(col, idxCol.Length) + if err != nil { + return + } + colLen += l + } + return +} + +func checkIndexColumn(ctx sessionctx.Context, col *model.ColumnInfo, indexColumnLen int) error { + if col.GetFlen() == 0 && (types.IsTypeChar(col.FieldType.GetType()) || types.IsTypeVarchar(col.FieldType.GetType())) { + if col.Hidden { + return errors.Trace(dbterror.ErrWrongKeyColumnFunctionalIndex.GenWithStackByArgs(col.GeneratedExprString)) + } + return errors.Trace(dbterror.ErrWrongKeyColumn.GenWithStackByArgs(col.Name)) + } + + // JSON column cannot index. + if col.FieldType.GetType() == mysql.TypeJSON && !col.FieldType.IsArray() { + if col.Hidden { + return dbterror.ErrFunctionalIndexOnJSONOrGeometryFunction + } + return errors.Trace(dbterror.ErrJSONUsedAsKey.GenWithStackByArgs(col.Name.O)) + } + + // Length must be specified and non-zero for BLOB and TEXT column indexes. + if types.IsTypeBlob(col.FieldType.GetType()) { + if indexColumnLen == types.UnspecifiedLength { + if col.Hidden { + return dbterror.ErrFunctionalIndexOnBlob + } + return errors.Trace(dbterror.ErrBlobKeyWithoutLength.GenWithStackByArgs(col.Name.O)) + } + if indexColumnLen == types.ErrorLength { + return errors.Trace(dbterror.ErrKeyPart0.GenWithStackByArgs(col.Name.O)) + } + } + + // Length can only be specified for specifiable types. + if indexColumnLen != types.UnspecifiedLength && !types.IsTypePrefixable(col.FieldType.GetType()) { + return errors.Trace(dbterror.ErrIncorrectPrefixKey) + } + + // Key length must be shorter or equal to the column length. + if indexColumnLen != types.UnspecifiedLength && + types.IsTypeChar(col.FieldType.GetType()) { + if col.GetFlen() < indexColumnLen { + return errors.Trace(dbterror.ErrIncorrectPrefixKey) + } + // Length must be non-zero for char. + if indexColumnLen == types.ErrorLength { + return errors.Trace(dbterror.ErrKeyPart0.GenWithStackByArgs(col.Name.O)) + } + } + + if types.IsString(col.FieldType.GetType()) { + desc, err := charset.GetCharsetInfo(col.GetCharset()) + if err != nil { + return err + } + indexColumnLen *= desc.Maxlen + } + // Specified length must be shorter than the max length for prefix. 
+ maxIndexLength := config.GetGlobalConfig().MaxIndexLength + if indexColumnLen > maxIndexLength { + if ctx == nil || (ctx.GetSessionVars().SQLMode.HasStrictMode() && !suppressErrorTooLongKeyKey(ctx)) { + // return error in strict sql mode + return dbterror.ErrTooLongKey.GenWithStackByArgs(indexColumnLen, maxIndexLength) + } + } + return nil +} + +// getIndexColumnLength calculate the bytes number required in an index column. +func getIndexColumnLength(col *model.ColumnInfo, colLen int) (int, error) { + length := types.UnspecifiedLength + if colLen != types.UnspecifiedLength { + length = colLen + } else if col.GetFlen() != types.UnspecifiedLength { + length = col.GetFlen() + } + + switch col.GetType() { + case mysql.TypeBit: + return (length + 7) >> 3, nil + case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeBlob, mysql.TypeLongBlob: + // Different charsets occupy different numbers of bytes on each character. + desc, err := charset.GetCharsetInfo(col.GetCharset()) + if err != nil { + return 0, dbterror.ErrUnsupportedCharset.GenWithStackByArgs(col.GetCharset(), col.GetCollate()) + } + return desc.Maxlen * length, nil + case mysql.TypeTiny, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeDouble, mysql.TypeShort: + return mysql.DefaultLengthOfMysqlTypes[col.GetType()], nil + case mysql.TypeFloat: + if length <= mysql.MaxFloatPrecisionLength { + return mysql.DefaultLengthOfMysqlTypes[mysql.TypeFloat], nil + } + return mysql.DefaultLengthOfMysqlTypes[mysql.TypeDouble], nil + case mysql.TypeNewDecimal: + return calcBytesLengthForDecimal(length), nil + case mysql.TypeYear, mysql.TypeDate, mysql.TypeDuration, mysql.TypeDatetime, mysql.TypeTimestamp: + return mysql.DefaultLengthOfMysqlTypes[col.GetType()], nil + default: + return length, nil + } +} + +// decimal using a binary format that packs nine decimal (base 10) digits into four bytes. +func calcBytesLengthForDecimal(m int) int { + return (m / 9 * 4) + ((m%9)+1)/2 +} + +// BuildIndexInfo builds a new IndexInfo according to the index information. +func BuildIndexInfo( + ctx sessionctx.Context, + allTableColumns []*model.ColumnInfo, + indexName model.CIStr, + isPrimary bool, + isUnique bool, + isGlobal bool, + indexPartSpecifications []*ast.IndexPartSpecification, + indexOption *ast.IndexOption, + state model.SchemaState, +) (*model.IndexInfo, error) { + if err := checkTooLongIndex(indexName); err != nil { + return nil, errors.Trace(err) + } + + idxColumns, mvIndex, err := buildIndexColumns(ctx, allTableColumns, indexPartSpecifications) + if err != nil { + return nil, errors.Trace(err) + } + + // Create index info. + idxInfo := &model.IndexInfo{ + Name: indexName, + Columns: idxColumns, + State: state, + Primary: isPrimary, + Unique: isUnique, + Global: isGlobal, + MVIndex: mvIndex, + } + + if indexOption != nil { + idxInfo.Comment = indexOption.Comment + if indexOption.Visibility == ast.IndexVisibilityInvisible { + idxInfo.Invisible = true + } + if indexOption.Tp == model.IndexTypeInvalid { + // Use btree as default index type. + idxInfo.Tp = model.IndexTypeBtree + } else { + idxInfo.Tp = indexOption.Tp + } + } else { + // Use btree as default index type. + idxInfo.Tp = model.IndexTypeBtree + } + + return idxInfo, nil +} + +// AddIndexColumnFlag aligns the column flags of columns in TableInfo to IndexInfo. 
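calcBytesLengthForDecimal above packs nine decimal digits into four bytes and charges the leftover digits roughly half a byte each, rounded up. A quick way to sanity-check the arithmetic for a few precisions:

package main

import "fmt"

// calcBytesLengthForDecimal repeats the formula defined above: every full
// group of nine decimal digits packs into four bytes, and the leftover digits
// take about half a byte each, rounded up.
func calcBytesLengthForDecimal(m int) int {
	return (m / 9 * 4) + ((m%9)+1)/2
}

func main() {
	for _, digits := range []int{1, 9, 10, 18, 20, 65} {
		fmt.Printf("decimal with %2d digits -> %d bytes in the index key\n",
			digits, calcBytesLengthForDecimal(digits))
	}
	// Prints 1, 4, 5, 8, 9 and 29 bytes respectively.
}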
+func AddIndexColumnFlag(tblInfo *model.TableInfo, indexInfo *model.IndexInfo) { + if indexInfo.Primary { + for _, col := range indexInfo.Columns { + tblInfo.Columns[col.Offset].AddFlag(mysql.PriKeyFlag) + } + return + } + + col := indexInfo.Columns[0] + if indexInfo.Unique && len(indexInfo.Columns) == 1 { + tblInfo.Columns[col.Offset].AddFlag(mysql.UniqueKeyFlag) + } else { + tblInfo.Columns[col.Offset].AddFlag(mysql.MultipleKeyFlag) + } +} + +// DropIndexColumnFlag drops the column flag of columns in TableInfo according to the IndexInfo. +func DropIndexColumnFlag(tblInfo *model.TableInfo, indexInfo *model.IndexInfo) { + if indexInfo.Primary { + for _, col := range indexInfo.Columns { + tblInfo.Columns[col.Offset].DelFlag(mysql.PriKeyFlag) + } + } else if indexInfo.Unique && len(indexInfo.Columns) == 1 { + tblInfo.Columns[indexInfo.Columns[0].Offset].DelFlag(mysql.UniqueKeyFlag) + } else { + tblInfo.Columns[indexInfo.Columns[0].Offset].DelFlag(mysql.MultipleKeyFlag) + } + + col := indexInfo.Columns[0] + // other index may still cover this col + for _, index := range tblInfo.Indices { + if index.Name.L == indexInfo.Name.L { + continue + } + + if index.Columns[0].Name.L != col.Name.L { + continue + } + + AddIndexColumnFlag(tblInfo, index) + } +} + +// ValidateRenameIndex checks if index name is ok to be renamed. +func ValidateRenameIndex(from, to model.CIStr, tbl *model.TableInfo) (ignore bool, err error) { + if fromIdx := tbl.FindIndexByName(from.L); fromIdx == nil { + return false, errors.Trace(infoschema.ErrKeyNotExists.GenWithStackByArgs(from.O, tbl.Name)) + } + // Take case-sensitivity into account, if `FromKey` and `ToKey` are the same, nothing need to be changed + if from.O == to.O { + return true, nil + } + // If spec.FromKey.L == spec.ToKey.L, we operate on the same index(case-insensitive) and change its name (case-sensitive) + // e.g: from `inDex` to `IndEX`. Otherwise, we try to rename an index to another different index which already exists, + // that's illegal by rule. + if toIdx := tbl.FindIndexByName(to.L); toIdx != nil && from.L != to.L { + return false, errors.Trace(infoschema.ErrKeyNameDuplicate.GenWithStackByArgs(toIdx.Name.O)) + } + return false, nil +} + +func onRenameIndex(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + tblInfo, from, to, err := checkRenameIndex(t, job) + if err != nil || tblInfo == nil { + return ver, errors.Trace(err) + } + if tblInfo.TableCacheStatusType != model.TableCacheStatusDisable { + return ver, errors.Trace(dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Rename Index")) + } + + if job.MultiSchemaInfo != nil && job.MultiSchemaInfo.Revertible { + job.MarkNonRevertible() + // Store the mark and enter the next DDL handling loop. 
+ return updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, false) + } + + renameIndexes(tblInfo, from, to) + renameHiddenColumns(tblInfo, from, to) + + if ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true); err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + return ver, nil +} + +func validateAlterIndexVisibility(ctx sessionctx.Context, indexName model.CIStr, invisible bool, tbl *model.TableInfo) (bool, error) { + var idx *model.IndexInfo + if idx = tbl.FindIndexByName(indexName.L); idx == nil || idx.State != model.StatePublic { + return false, errors.Trace(infoschema.ErrKeyNotExists.GenWithStackByArgs(indexName.O, tbl.Name)) + } + if ctx == nil || ctx.GetSessionVars() == nil || ctx.GetSessionVars().StmtCtx.MultiSchemaInfo == nil { + // Early return. + if idx.Invisible == invisible { + return true, nil + } + } + return false, nil +} + +func onAlterIndexVisibility(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + tblInfo, from, invisible, err := checkAlterIndexVisibility(t, job) + if err != nil || tblInfo == nil { + return ver, errors.Trace(err) + } + + if job.MultiSchemaInfo != nil && job.MultiSchemaInfo.Revertible { + job.MarkNonRevertible() + return updateVersionAndTableInfo(d, t, job, tblInfo, false) + } + + setIndexVisibility(tblInfo, from, invisible) + if ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true); err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + return ver, nil +} + +func setIndexVisibility(tblInfo *model.TableInfo, name model.CIStr, invisible bool) { + for _, idx := range tblInfo.Indices { + if idx.Name.L == name.L || (isTempIdxInfo(idx, tblInfo) && getChangingIndexOriginName(idx) == name.O) { + idx.Invisible = invisible + } + } +} + +func getNullColInfos(tblInfo *model.TableInfo, indexInfo *model.IndexInfo) ([]*model.ColumnInfo, error) { + nullCols := make([]*model.ColumnInfo, 0, len(indexInfo.Columns)) + for _, colName := range indexInfo.Columns { + col := model.FindColumnInfo(tblInfo.Columns, colName.Name.L) + if !mysql.HasNotNullFlag(col.GetFlag()) || mysql.HasPreventNullInsertFlag(col.GetFlag()) { + nullCols = append(nullCols, col) + } + } + return nullCols, nil +} + +func checkPrimaryKeyNotNull(d *ddlCtx, w *worker, t *meta.Meta, job *model.Job, + tblInfo *model.TableInfo, indexInfo *model.IndexInfo) (warnings []string, err error) { + if !indexInfo.Primary { + return nil, nil + } + + dbInfo, err := checkSchemaExistAndCancelNotExistJob(t, job) + if err != nil { + return nil, err + } + nullCols, err := getNullColInfos(tblInfo, indexInfo) + if err != nil { + return nil, err + } + if len(nullCols) == 0 { + return nil, nil + } + + err = modifyColsFromNull2NotNull(w, dbInfo, tblInfo, nullCols, &model.ColumnInfo{Name: model.NewCIStr("")}, false) + if err == nil { + return nil, nil + } + _, err = convertAddIdxJob2RollbackJob(d, t, job, tblInfo, []*model.IndexInfo{indexInfo}, err) + // TODO: Support non-strict mode. + // warnings = append(warnings, ErrWarnDataTruncated.GenWithStackByArgs(oldCol.Name.L, 0).Error()) + return nil, err +} + +// moveAndUpdateHiddenColumnsToPublic updates the hidden columns to public, and +// moves the hidden columns to proper offsets, so that Table.Columns' states meet the assumption of +// [public, public, ..., public, non-public, non-public, ..., non-public]. 
+func moveAndUpdateHiddenColumnsToPublic(tblInfo *model.TableInfo, idxInfo *model.IndexInfo) { + hiddenColOffset := make(map[int]struct{}, 0) + for _, col := range idxInfo.Columns { + if tblInfo.Columns[col.Offset].Hidden { + hiddenColOffset[col.Offset] = struct{}{} + } + } + if len(hiddenColOffset) == 0 { + return + } + // Find the first non-public column. + firstNonPublicPos := len(tblInfo.Columns) - 1 + for i, c := range tblInfo.Columns { + if c.State != model.StatePublic { + firstNonPublicPos = i + break + } + } + for _, col := range idxInfo.Columns { + tblInfo.Columns[col.Offset].State = model.StatePublic + if _, needMove := hiddenColOffset[col.Offset]; needMove { + tblInfo.MoveColumnInfo(col.Offset, firstNonPublicPos) + } + } +} + +func decodeAddIndexArgs(job *model.Job) ( + uniques []bool, + indexNames []model.CIStr, + indexPartSpecifications [][]*ast.IndexPartSpecification, + indexOptions []*ast.IndexOption, + hiddenCols [][]*model.ColumnInfo, + globals []bool, + err error, +) { + var ( + unique bool + indexName model.CIStr + indexPartSpecification []*ast.IndexPartSpecification + indexOption *ast.IndexOption + hiddenCol []*model.ColumnInfo + global bool + ) + err = job.DecodeArgs(&unique, &indexName, &indexPartSpecification, &indexOption, &hiddenCol, &global) + if err == nil { + return []bool{unique}, + []model.CIStr{indexName}, + [][]*ast.IndexPartSpecification{indexPartSpecification}, + []*ast.IndexOption{indexOption}, + [][]*model.ColumnInfo{hiddenCol}, + []bool{global}, + nil + } + + err = job.DecodeArgs(&uniques, &indexNames, &indexPartSpecifications, &indexOptions, &hiddenCols, &globals) + return +} + +func (w *worker) onCreateIndex(d *ddlCtx, t *meta.Meta, job *model.Job, isPK bool) (ver int64, err error) { + // Handle the rolling back job. + if job.IsRollingback() { + ver, err = onDropIndex(d, t, job) + if err != nil { + return ver, errors.Trace(err) + } + return ver, nil + } + + // Handle normal job. + schemaID := job.SchemaID + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) + if err != nil { + return ver, errors.Trace(err) + } + if tblInfo.TableCacheStatusType != model.TableCacheStatusDisable { + return ver, errors.Trace(dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Create Index")) + } + + uniques := make([]bool, 1) + global := make([]bool, 1) + indexNames := make([]model.CIStr, 1) + indexPartSpecifications := make([][]*ast.IndexPartSpecification, 1) + indexOption := make([]*ast.IndexOption, 1) + var sqlMode mysql.SQLMode + var warnings []string + hiddenCols := make([][]*model.ColumnInfo, 1) + + if isPK { + // Notice: sqlMode and warnings is used to support non-strict mode. 
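moveAndUpdateHiddenColumnsToPublic above keeps Table.Columns partitioned as [public..., non-public...] while promoting the hidden expression columns of the new index. A stripped-down sketch of the shift it performs; the col type and promote helper are stand-ins, and the example assumes the promoted column sits at or after the target position:

package main

import "fmt"

// col is a stripped-down stand-in for model.ColumnInfo: just a name and a
// public/non-public state.
type col struct {
	name   string
	public bool
}

// promote marks cols[i] public and moves it to position pos, shifting the
// elements in between one slot to the right, which is roughly what
// MoveColumnInfo does for the real TableInfo. Requires i >= pos.
func promote(cols []col, i, pos int) {
	c := cols[i]
	c.public = true
	copy(cols[pos+1:i+1], cols[pos:i])
	cols[pos] = c
}

func main() {
	// Two public columns, then two hidden (non-public) expression columns.
	cols := []col{{"a", true}, {"b", true}, {"_v$_idx_0", false}, {"_v$_idx_1", false}}
	// Promote the second hidden column: it must land at the first non-public
	// position (index 2) so the public prefix stays contiguous.
	firstNonPublic := 2
	promote(cols, 3, firstNonPublic)
	fmt.Printf("%+v\n", cols)
}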
+ err = job.DecodeArgs(&uniques[0], &indexNames[0], &indexPartSpecifications[0], &indexOption[0], &sqlMode, &warnings, &global[0]) + } else { + uniques, indexNames, indexPartSpecifications, indexOption, hiddenCols, global, err = decodeAddIndexArgs(job) + } + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + allIndexInfos := make([]*model.IndexInfo, 0, len(indexNames)) + for i, indexName := range indexNames { + indexInfo := tblInfo.FindIndexByName(indexName.L) + if indexInfo != nil && indexInfo.State == model.StatePublic { + job.State = model.JobStateCancelled + err = dbterror.ErrDupKeyName.GenWithStack("index already exist %s", indexName) + if isPK { + err = infoschema.ErrMultiplePriKey + } + return ver, err + } + if indexInfo == nil { + for _, hiddenCol := range hiddenCols[i] { + columnInfo := model.FindColumnInfo(tblInfo.Columns, hiddenCol.Name.L) + if columnInfo != nil && columnInfo.State == model.StatePublic { + // We already have a column with the same column name. + job.State = model.JobStateCancelled + // TODO: refine the error message + return ver, infoschema.ErrColumnExists.GenWithStackByArgs(hiddenCol.Name) + } + } + } + if indexInfo == nil { + if len(hiddenCols) > 0 { + for _, hiddenCol := range hiddenCols[i] { + InitAndAddColumnToTable(tblInfo, hiddenCol) + } + } + if err = checkAddColumnTooManyColumns(len(tblInfo.Columns)); err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + indexInfo, err = BuildIndexInfo( + nil, + tblInfo.Columns, + indexName, + isPK, + uniques[i], + global[i], + indexPartSpecifications[i], + indexOption[i], + model.StateNone, + ) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + if isPK { + if _, err = CheckPKOnGeneratedColumn(tblInfo, indexPartSpecifications[i]); err != nil { + job.State = model.JobStateCancelled + return ver, err + } + } + indexInfo.ID = AllocateIndexID(tblInfo) + tblInfo.Indices = append(tblInfo.Indices, indexInfo) + if err = checkTooManyIndexes(tblInfo.Indices); err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + // Here we need do this check before set state to `DeleteOnly`, + // because if hidden columns has been set to `DeleteOnly`, + // the `DeleteOnly` columns are missing when we do this check. 
+ if err := checkInvisibleIndexOnPK(tblInfo); err != nil { + job.State = model.JobStateCancelled + return ver, err + } + logutil.DDLLogger().Info("run add index job", zap.Stringer("job", job), zap.Reflect("indexInfo", indexInfo)) + } + allIndexInfos = append(allIndexInfos, indexInfo) + } + + originalState := allIndexInfos[0].State + +SwitchIndexState: + switch allIndexInfos[0].State { + case model.StateNone: + // none -> delete only + var reorgTp model.ReorgType + reorgTp, err = pickBackfillType(job) + if err != nil { + if !errorIsRetryable(err, job) { + job.State = model.JobStateCancelled + } + return ver, err + } + loadCloudStorageURI(w, job) + if reorgTp.NeedMergeProcess() { + for _, indexInfo := range allIndexInfos { + indexInfo.BackfillState = model.BackfillStateRunning + } + } + for _, indexInfo := range allIndexInfos { + indexInfo.State = model.StateDeleteOnly + moveAndUpdateHiddenColumnsToPublic(tblInfo, indexInfo) + } + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, originalState != model.StateDeleteOnly) + if err != nil { + return ver, err + } + job.SchemaState = model.StateDeleteOnly + case model.StateDeleteOnly: + // delete only -> write only + for _, indexInfo := range allIndexInfos { + indexInfo.State = model.StateWriteOnly + _, err = checkPrimaryKeyNotNull(d, w, t, job, tblInfo, indexInfo) + if err != nil { + break SwitchIndexState + } + } + + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != model.StateWriteOnly) + if err != nil { + return ver, err + } + job.SchemaState = model.StateWriteOnly + case model.StateWriteOnly: + // write only -> reorganization + for _, indexInfo := range allIndexInfos { + indexInfo.State = model.StateWriteReorganization + _, err = checkPrimaryKeyNotNull(d, w, t, job, tblInfo, indexInfo) + if err != nil { + break SwitchIndexState + } + } + + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != model.StateWriteReorganization) + if err != nil { + return ver, err + } + // Initialize SnapshotVer to 0 for later reorganization check. + job.SnapshotVer = 0 + job.SchemaState = model.StateWriteReorganization + case model.StateWriteReorganization: + // reorganization -> public + tbl, err := getTable(d.getAutoIDRequirement(), schemaID, tblInfo) + if err != nil { + return ver, errors.Trace(err) + } + + var done bool + if job.MultiSchemaInfo != nil { + done, ver, err = doReorgWorkForCreateIndexMultiSchema(w, d, t, job, tbl, allIndexInfos) + } else { + done, ver, err = doReorgWorkForCreateIndex(w, d, t, job, tbl, allIndexInfos) + } + if !done { + return ver, err + } + + // Set column index flag. + for _, indexInfo := range allIndexInfos { + AddIndexColumnFlag(tblInfo, indexInfo) + if isPK { + if err = UpdateColsNull2NotNull(tblInfo, indexInfo); err != nil { + return ver, errors.Trace(err) + } + } + indexInfo.State = model.StatePublic + } + + // Inject the failpoint to prevent the progress of index creation. 
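onCreateIndex above walks the new index through the usual online-DDL schema states, bumping the schema version at every step. A compact sketch of just the state progression, with the version bumps and lease waits omitted:

package main

import "fmt"

// indexState is a simplified copy of the schema states an added index walks
// through in onCreateIndex; the real job also bumps the schema version and
// waits for every TiDB instance to catch up between any two states.
type indexState int

const (
	stateNone indexState = iota
	stateDeleteOnly
	stateWriteOnly
	stateWriteReorganization
	statePublic
)

func (s indexState) String() string {
	return [...]string{"none", "delete only", "write only", "write reorganization", "public"}[s]
}

// next returns the following state, or the same state once the index is public.
func next(s indexState) indexState {
	if s == statePublic {
		return s
	}
	return s + 1
}

func main() {
	for s := stateNone; ; s = next(s) {
		fmt.Println(s)
		if s == statePublic {
			break
		}
	}
}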
+ failpoint.Inject("create-index-stuck-before-public", func(v failpoint.Value) { + if sigFile, ok := v.(string); ok { + for { + time.Sleep(1 * time.Second) + if _, err := os.Stat(sigFile); err != nil { + if os.IsNotExist(err) { + continue + } + failpoint.Return(ver, errors.Trace(err)) + } + break + } + } + }) + + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != model.StatePublic) + if err != nil { + return ver, errors.Trace(err) + } + + allIndexIDs := make([]int64, 0, len(allIndexInfos)) + ifExists := make([]bool, 0, len(allIndexInfos)) + isGlobal := make([]bool, 0, len(allIndexInfos)) + for _, indexInfo := range allIndexInfos { + allIndexIDs = append(allIndexIDs, indexInfo.ID) + ifExists = append(ifExists, false) + isGlobal = append(isGlobal, indexInfo.Global) + } + job.Args = []any{allIndexIDs, ifExists, getPartitionIDs(tbl.Meta()), isGlobal} + // Finish this job. + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + if !job.ReorgMeta.IsDistReorg && job.ReorgMeta.ReorgTp == model.ReorgTypeLitMerge { + ingest.LitBackCtxMgr.Unregister(job.ID) + } + logutil.DDLLogger().Info("run add index job done", + zap.String("charset", job.Charset), + zap.String("collation", job.Collate)) + default: + err = dbterror.ErrInvalidDDLState.GenWithStackByArgs("index", tblInfo.State) + } + + return ver, errors.Trace(err) +} + +// pickBackfillType determines which backfill process will be used. The result is +// both stored in job.ReorgMeta.ReorgTp and returned. +func pickBackfillType(job *model.Job) (model.ReorgType, error) { + if job.ReorgMeta.ReorgTp != model.ReorgTypeNone { + // The backfill task has been started. + // Don't change the backfill type. + return job.ReorgMeta.ReorgTp, nil + } + if !job.ReorgMeta.IsFastReorg { + job.ReorgMeta.ReorgTp = model.ReorgTypeTxn + return model.ReorgTypeTxn, nil + } + if ingest.LitInitialized { + if job.ReorgMeta.UseCloudStorage { + job.ReorgMeta.ReorgTp = model.ReorgTypeLitMerge + return model.ReorgTypeLitMerge, nil + } + available, err := ingest.LitBackCtxMgr.CheckMoreTasksAvailable() + if err != nil { + return model.ReorgTypeNone, err + } + if available { + job.ReorgMeta.ReorgTp = model.ReorgTypeLitMerge + return model.ReorgTypeLitMerge, nil + } + } + // The lightning environment is unavailable, but we can still use the txn-merge backfill. + logutil.DDLLogger().Info("fallback to txn-merge backfill process", + zap.Bool("lightning env initialized", ingest.LitInitialized)) + job.ReorgMeta.ReorgTp = model.ReorgTypeTxnMerge + return model.ReorgTypeTxnMerge, nil +} + +func loadCloudStorageURI(w *worker, job *model.Job) { + jc := w.jobContext(job.ID, job.ReorgMeta) + jc.cloudStorageURI = variable.CloudStorageURI.Load() + job.ReorgMeta.UseCloudStorage = len(jc.cloudStorageURI) > 0 +} + +func doReorgWorkForCreateIndexMultiSchema(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job, + tbl table.Table, allIndexInfos []*model.IndexInfo) (done bool, ver int64, err error) { + if job.MultiSchemaInfo.Revertible { + done, ver, err = doReorgWorkForCreateIndex(w, d, t, job, tbl, allIndexInfos) + if done { + job.MarkNonRevertible() + if err == nil { + ver, err = updateVersionAndTableInfo(d, t, job, tbl.Meta(), true) + } + } + // We need another round to wait for all the others sub-jobs to finish. 
+ return false, ver, err + } + return true, ver, err +} + +func doReorgWorkForCreateIndex( + w *worker, + d *ddlCtx, + t *meta.Meta, + job *model.Job, + tbl table.Table, + allIndexInfos []*model.IndexInfo, +) (done bool, ver int64, err error) { + var reorgTp model.ReorgType + reorgTp, err = pickBackfillType(job) + if err != nil { + return false, ver, err + } + if !reorgTp.NeedMergeProcess() { + return runReorgJobAndHandleErr(w, d, t, job, tbl, allIndexInfos, false) + } + switch allIndexInfos[0].BackfillState { + case model.BackfillStateRunning: + logutil.DDLLogger().Info("index backfill state running", + zap.Int64("job ID", job.ID), zap.String("table", tbl.Meta().Name.O), + zap.Bool("ingest mode", reorgTp == model.ReorgTypeLitMerge), + zap.String("index", allIndexInfos[0].Name.O)) + switch reorgTp { + case model.ReorgTypeLitMerge: + if job.ReorgMeta.IsDistReorg { + done, ver, err = runIngestReorgJobDist(w, d, t, job, tbl, allIndexInfos) + } else { + done, ver, err = runIngestReorgJob(w, d, t, job, tbl, allIndexInfos) + } + case model.ReorgTypeTxnMerge: + done, ver, err = runReorgJobAndHandleErr(w, d, t, job, tbl, allIndexInfos, false) + } + if err != nil || !done { + return false, ver, errors.Trace(err) + } + for _, indexInfo := range allIndexInfos { + indexInfo.BackfillState = model.BackfillStateReadyToMerge + } + ver, err = updateVersionAndTableInfo(d, t, job, tbl.Meta(), true) + return false, ver, errors.Trace(err) + case model.BackfillStateReadyToMerge: + failpoint.Inject("mockDMLExecutionStateBeforeMerge", func(_ failpoint.Value) { + if MockDMLExecutionStateBeforeMerge != nil { + MockDMLExecutionStateBeforeMerge() + } + }) + logutil.DDLLogger().Info("index backfill state ready to merge", + zap.Int64("job ID", job.ID), + zap.String("table", tbl.Meta().Name.O), + zap.String("index", allIndexInfos[0].Name.O)) + for _, indexInfo := range allIndexInfos { + indexInfo.BackfillState = model.BackfillStateMerging + } + if reorgTp == model.ReorgTypeLitMerge { + ingest.LitBackCtxMgr.Unregister(job.ID) + } + job.SnapshotVer = 0 // Reset the snapshot version for merge index reorg. + ver, err = updateVersionAndTableInfo(d, t, job, tbl.Meta(), true) + return false, ver, errors.Trace(err) + case model.BackfillStateMerging: + done, ver, err = runReorgJobAndHandleErr(w, d, t, job, tbl, allIndexInfos, true) + if !done { + return false, ver, err + } + for _, indexInfo := range allIndexInfos { + indexInfo.BackfillState = model.BackfillStateInapplicable // Prevent double-write on this index. 
+ } + return true, ver, err + default: + return false, 0, dbterror.ErrInvalidDDLState.GenWithStackByArgs("backfill", allIndexInfos[0].BackfillState) + } +} + +func runIngestReorgJobDist(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job, + tbl table.Table, allIndexInfos []*model.IndexInfo) (done bool, ver int64, err error) { + done, ver, err = runReorgJobAndHandleErr(w, d, t, job, tbl, allIndexInfos, false) + if err != nil { + return false, ver, errors.Trace(err) + } + + if !done { + return false, ver, nil + } + + return true, ver, nil +} + +func runIngestReorgJob(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job, + tbl table.Table, allIndexInfos []*model.IndexInfo) (done bool, ver int64, err error) { + done, ver, err = runReorgJobAndHandleErr(w, d, t, job, tbl, allIndexInfos, false) + if err != nil { + if kv.ErrKeyExists.Equal(err) { + logutil.DDLLogger().Warn("import index duplicate key, convert job to rollback", zap.Stringer("job", job), zap.Error(err)) + ver, err = convertAddIdxJob2RollbackJob(d, t, job, tbl.Meta(), allIndexInfos, err) + } else if !errorIsRetryable(err, job) { + logutil.DDLLogger().Warn("run reorg job failed, convert job to rollback", + zap.String("job", job.String()), zap.Error(err)) + ver, err = convertAddIdxJob2RollbackJob(d, t, job, tbl.Meta(), allIndexInfos, err) + } else { + logutil.DDLLogger().Warn("run add index ingest job error", zap.Error(err)) + } + return false, ver, errors.Trace(err) + } + failpoint.InjectCall("afterRunIngestReorgJob", job, done) + return done, ver, nil +} + +func errorIsRetryable(err error, job *model.Job) bool { + if job.ErrorCount+1 >= variable.GetDDLErrorCountLimit() { + return false + } + originErr := errors.Cause(err) + if tErr, ok := originErr.(*terror.Error); ok { + sqlErr := terror.ToSQLError(tErr) + _, ok := dbterror.ReorgRetryableErrCodes[sqlErr.Code] + return ok + } + // For the unknown errors, we should retry. + return true +} + +func runReorgJobAndHandleErr( + w *worker, + d *ddlCtx, + t *meta.Meta, + job *model.Job, + tbl table.Table, + allIndexInfos []*model.IndexInfo, + mergingTmpIdx bool, +) (done bool, ver int64, err error) { + elements := make([]*meta.Element, 0, len(allIndexInfos)) + for _, indexInfo := range allIndexInfos { + elements = append(elements, &meta.Element{ID: indexInfo.ID, TypeKey: meta.IndexElementKey}) + } + + failpoint.Inject("mockDMLExecutionStateMerging", func(val failpoint.Value) { + //nolint:forcetypeassert + if val.(bool) && allIndexInfos[0].BackfillState == model.BackfillStateMerging && + MockDMLExecutionStateMerging != nil { + MockDMLExecutionStateMerging() + } + }) + + sctx, err1 := w.sessPool.Get() + if err1 != nil { + err = err1 + return + } + defer w.sessPool.Put(sctx) + rh := newReorgHandler(sess.NewSession(sctx)) + dbInfo, err := t.GetDatabase(job.SchemaID) + if err != nil { + return false, ver, errors.Trace(err) + } + reorgInfo, err := getReorgInfo(d.jobContext(job.ID, job.ReorgMeta), d, rh, job, dbInfo, tbl, elements, mergingTmpIdx) + if err != nil || reorgInfo == nil || reorgInfo.first { + // If we run reorg firstly, we should update the job snapshot version + // and then run the reorg next time. 
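errorIsRetryable above gives up once the job's error count reaches the configured limit, retries errors whose code is in the retryable set, and retries unknown errors by default. A sketch of the same shape with simplified types; the code set and limit below are illustrative, not the real dbterror table:

package main

import (
	"errors"
	"fmt"
)

// codedError is a stand-in for *terror.Error: an error that carries a code.
type codedError struct {
	code int
	msg  string
}

func (e *codedError) Error() string { return e.msg }

// retryable follows the shape of errorIsRetryable: give up once the job has
// failed too many times, retry known-transient codes, and retry unknown
// errors by default.
func retryable(err error, errorCount, limit int, transient map[int]bool) bool {
	if errorCount+1 >= limit {
		return false
	}
	var ce *codedError
	if errors.As(err, &ce) {
		return transient[ce.code]
	}
	return true // unknown errors are assumed transient
}

func main() {
	transient := map[int]bool{9007: true} // e.g. a write-conflict style code
	fmt.Println(retryable(&codedError{9007, "write conflict"}, 0, 512, transient))  // true
	fmt.Println(retryable(&codedError{1062, "duplicate entry"}, 0, 512, transient)) // false
	fmt.Println(retryable(errors.New("network blip"), 511, 512, transient))         // false: limit hit
}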
+ return false, ver, errors.Trace(err) + } + err = overwriteReorgInfoFromGlobalCheckpoint(w, rh.s, job, reorgInfo) + if err != nil { + return false, ver, errors.Trace(err) + } + err = w.runReorgJob(reorgInfo, tbl.Meta(), d.lease, func() (addIndexErr error) { + defer util.Recover(metrics.LabelDDL, "onCreateIndex", + func() { + addIndexErr = dbterror.ErrCancelledDDLJob.GenWithStack("add table `%v` index `%v` panic", tbl.Meta().Name, allIndexInfos[0].Name) + }, false) + return w.addTableIndex(tbl, reorgInfo) + }) + if err != nil { + if dbterror.ErrPausedDDLJob.Equal(err) { + return false, ver, nil + } + if dbterror.ErrWaitReorgTimeout.Equal(err) { + // if timeout, we should return, check for the owner and re-wait job done. + return false, ver, nil + } + // TODO(tangenta): get duplicate column and match index. + err = ingest.TryConvertToKeyExistsErr(err, allIndexInfos[0], tbl.Meta()) + if !errorIsRetryable(err, job) { + logutil.DDLLogger().Warn("run add index job failed, convert job to rollback", zap.Stringer("job", job), zap.Error(err)) + ver, err = convertAddIdxJob2RollbackJob(d, t, job, tbl.Meta(), allIndexInfos, err) + if err1 := rh.RemoveDDLReorgHandle(job, reorgInfo.elements); err1 != nil { + logutil.DDLLogger().Warn("run add index job failed, convert job to rollback, RemoveDDLReorgHandle failed", zap.Stringer("job", job), zap.Error(err1)) + } + } + return false, ver, errors.Trace(err) + } + failpoint.Inject("mockDMLExecutionStateBeforeImport", func(_ failpoint.Value) { + if MockDMLExecutionStateBeforeImport != nil { + MockDMLExecutionStateBeforeImport() + } + }) + return true, ver, nil +} + +func onDropIndex(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + tblInfo, allIndexInfos, ifExists, err := checkDropIndex(d, t, job) + if err != nil { + if ifExists && dbterror.ErrCantDropFieldOrKey.Equal(err) { + job.Warning = toTError(err) + job.State = model.JobStateDone + return ver, nil + } + return ver, errors.Trace(err) + } + if tblInfo.TableCacheStatusType != model.TableCacheStatusDisable { + return ver, errors.Trace(dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Drop Index")) + } + + if job.MultiSchemaInfo != nil && !job.IsRollingback() && job.MultiSchemaInfo.Revertible { + job.MarkNonRevertible() + job.SchemaState = allIndexInfos[0].State + return updateVersionAndTableInfo(d, t, job, tblInfo, false) + } + + originalState := allIndexInfos[0].State + switch allIndexInfos[0].State { + case model.StatePublic: + // public -> write only + for _, indexInfo := range allIndexInfos { + indexInfo.State = model.StateWriteOnly + } + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != model.StateWriteOnly) + if err != nil { + return ver, errors.Trace(err) + } + case model.StateWriteOnly: + // write only -> delete only + for _, indexInfo := range allIndexInfos { + indexInfo.State = model.StateDeleteOnly + } + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != model.StateDeleteOnly) + if err != nil { + return ver, errors.Trace(err) + } + case model.StateDeleteOnly: + // delete only -> reorganization + for _, indexInfo := range allIndexInfos { + indexInfo.State = model.StateDeleteReorganization + } + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != model.StateDeleteReorganization) + if err != nil { + return ver, errors.Trace(err) + } + case model.StateDeleteReorganization: + // reorganization -> absent + idxIDs := make([]int64, 0, len(allIndexInfos)) + for _, indexInfo := range allIndexInfos { + indexInfo.State = 
model.StateNone + // Set column index flag. + DropIndexColumnFlag(tblInfo, indexInfo) + RemoveDependentHiddenColumns(tblInfo, indexInfo) + removeIndexInfo(tblInfo, indexInfo) + idxIDs = append(idxIDs, indexInfo.ID) + } + + failpoint.Inject("mockExceedErrorLimit", func(val failpoint.Value) { + //nolint:forcetypeassert + if val.(bool) { + panic("panic test in cancelling add index") + } + }) + + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, originalState != model.StateNone) + if err != nil { + return ver, errors.Trace(err) + } + + // Finish this job. + if job.IsRollingback() { + job.FinishTableJob(model.JobStateRollbackDone, model.StateNone, ver, tblInfo) + job.Args[0] = idxIDs + } else { + // the partition ids were append by convertAddIdxJob2RollbackJob, it is weird, but for the compatibility, + // we should keep appending the partitions in the convertAddIdxJob2RollbackJob. + job.FinishTableJob(model.JobStateDone, model.StateNone, ver, tblInfo) + // Global index key has t{tableID}_ prefix. + // Assign partitionIDs empty to guarantee correct prefix in insertJobIntoDeleteRangeTable. + if allIndexInfos[0].Global { + job.Args = append(job.Args, idxIDs[0], []int64{}) + } else { + job.Args = append(job.Args, idxIDs[0], getPartitionIDs(tblInfo)) + } + } + default: + return ver, errors.Trace(dbterror.ErrInvalidDDLState.GenWithStackByArgs("index", allIndexInfos[0].State)) + } + job.SchemaState = allIndexInfos[0].State + return ver, errors.Trace(err) +} + +// RemoveDependentHiddenColumns removes hidden columns by the indexInfo. +func RemoveDependentHiddenColumns(tblInfo *model.TableInfo, idxInfo *model.IndexInfo) { + hiddenColOffs := make([]int, 0) + for _, indexColumn := range idxInfo.Columns { + col := tblInfo.Columns[indexColumn.Offset] + if col.Hidden { + hiddenColOffs = append(hiddenColOffs, col.Offset) + } + } + // Sort the offset in descending order. + slices.SortFunc(hiddenColOffs, func(a, b int) int { return cmp.Compare(b, a) }) + // Move all the dependent hidden columns to the end. + endOffset := len(tblInfo.Columns) - 1 + for _, offset := range hiddenColOffs { + tblInfo.MoveColumnInfo(offset, endOffset) + } + tblInfo.Columns = tblInfo.Columns[:len(tblInfo.Columns)-len(hiddenColOffs)] +} + +func removeIndexInfo(tblInfo *model.TableInfo, idxInfo *model.IndexInfo) { + indices := tblInfo.Indices + offset := -1 + for i, idx := range indices { + if idxInfo.ID == idx.ID { + offset = i + break + } + } + if offset == -1 { + // The target index has been removed. + return + } + // Remove the target index. + tblInfo.Indices = append(tblInfo.Indices[:offset], tblInfo.Indices[offset+1:]...) 
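removeIndexInfo above splices one entry out of tblInfo.Indices with append, and RemoveDependentHiddenColumns processes the doomed column offsets from highest to lowest so earlier removals cannot shift the offsets still pending. A simplified variant that removes in place rather than moving columns to the tail and truncating:

package main

import (
	"cmp"
	"fmt"
	"slices"
)

// removeAt splices out one element, the same append trick removeIndexInfo
// uses for tblInfo.Indices.
func removeAt[T any](s []T, i int) []T {
	return append(s[:i], s[i+1:]...)
}

func main() {
	cols := []string{"a", "b", "_hidden_0", "c", "_hidden_1"}
	// Like RemoveDependentHiddenColumns: collect the offsets to drop, then
	// process them from the highest offset down so earlier removals do not
	// invalidate the ones we still have to visit.
	drop := []int{2, 4}
	slices.SortFunc(drop, func(a, b int) int { return cmp.Compare(b, a) })
	for _, off := range drop {
		cols = removeAt(cols, off)
	}
	fmt.Println(cols) // [a b c]
}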
+} + +func checkDropIndex(d *ddlCtx, t *meta.Meta, job *model.Job) (*model.TableInfo, []*model.IndexInfo, bool /* ifExists */, error) { + schemaID := job.SchemaID + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) + if err != nil { + return nil, nil, false, errors.Trace(err) + } + + indexNames := make([]model.CIStr, 1) + ifExists := make([]bool, 1) + if err = job.DecodeArgs(&indexNames[0], &ifExists[0]); err != nil { + if err = job.DecodeArgs(&indexNames, &ifExists); err != nil { + job.State = model.JobStateCancelled + return nil, nil, false, errors.Trace(err) + } + } + + indexInfos := make([]*model.IndexInfo, 0, len(indexNames)) + for i, idxName := range indexNames { + indexInfo := tblInfo.FindIndexByName(idxName.L) + if indexInfo == nil { + job.State = model.JobStateCancelled + return nil, nil, ifExists[i], dbterror.ErrCantDropFieldOrKey.GenWithStack("index %s doesn't exist", idxName) + } + + // Check that drop primary index will not cause invisible implicit primary index. + if err := checkInvisibleIndexesOnPK(tblInfo, []*model.IndexInfo{indexInfo}, job); err != nil { + job.State = model.JobStateCancelled + return nil, nil, false, errors.Trace(err) + } + + // Double check for drop index needed in foreign key. + if err := checkIndexNeededInForeignKeyInOwner(d, t, job, job.SchemaName, tblInfo, indexInfo); err != nil { + return nil, nil, false, errors.Trace(err) + } + indexInfos = append(indexInfos, indexInfo) + } + return tblInfo, indexInfos, false, nil +} + +func checkInvisibleIndexesOnPK(tblInfo *model.TableInfo, indexInfos []*model.IndexInfo, job *model.Job) error { + newIndices := make([]*model.IndexInfo, 0, len(tblInfo.Indices)) + for _, oidx := range tblInfo.Indices { + needAppend := true + for _, idx := range indexInfos { + if idx.Name.L == oidx.Name.L { + needAppend = false + break + } + } + if needAppend { + newIndices = append(newIndices, oidx) + } + } + newTbl := tblInfo.Clone() + newTbl.Indices = newIndices + if err := checkInvisibleIndexOnPK(newTbl); err != nil { + job.State = model.JobStateCancelled + return err + } + + return nil +} + +func checkRenameIndex(t *meta.Meta, job *model.Job) (*model.TableInfo, model.CIStr, model.CIStr, error) { + var from, to model.CIStr + schemaID := job.SchemaID + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) + if err != nil { + return nil, from, to, errors.Trace(err) + } + + if err := job.DecodeArgs(&from, &to); err != nil { + job.State = model.JobStateCancelled + return nil, from, to, errors.Trace(err) + } + + // Double check. 
See function `RenameIndex` in executor.go + duplicate, err := ValidateRenameIndex(from, to, tblInfo) + if duplicate { + return nil, from, to, nil + } + if err != nil { + job.State = model.JobStateCancelled + return nil, from, to, errors.Trace(err) + } + return tblInfo, from, to, errors.Trace(err) +} + +func checkAlterIndexVisibility(t *meta.Meta, job *model.Job) (*model.TableInfo, model.CIStr, bool, error) { + var ( + indexName model.CIStr + invisible bool + ) + + schemaID := job.SchemaID + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) + if err != nil { + return nil, indexName, invisible, errors.Trace(err) + } + + if err := job.DecodeArgs(&indexName, &invisible); err != nil { + job.State = model.JobStateCancelled + return nil, indexName, invisible, errors.Trace(err) + } + + skip, err := validateAlterIndexVisibility(nil, indexName, invisible, tblInfo) + if err != nil { + job.State = model.JobStateCancelled + return nil, indexName, invisible, errors.Trace(err) + } + if skip { + job.State = model.JobStateDone + return nil, indexName, invisible, nil + } + return tblInfo, indexName, invisible, nil +} + +// indexRecord is the record information of an index. +type indexRecord struct { + handle kv.Handle + key []byte // It's used to lock a record. Record it to reduce the encoding time. + vals []types.Datum // It's the index values. + rsData []types.Datum // It's the restored data for handle. + skip bool // skip indicates that the index key is already exists, we should not add it. +} + +type baseIndexWorker struct { + *backfillCtx + indexes []table.Index + + tp backfillerType + // The following attributes are used to reduce memory allocation. + defaultVals []types.Datum + idxRecords []*indexRecord + rowMap map[int64]types.Datum + rowDecoder *decoder.RowDecoder +} + +type addIndexTxnWorker struct { + baseIndexWorker + + // The following attributes are used to reduce memory allocation. 
+ idxKeyBufs [][]byte + batchCheckKeys []kv.Key + batchCheckValues [][]byte + distinctCheckFlags []bool + recordIdx []int +} + +func newAddIndexTxnWorker( + decodeColMap map[int64]decoder.Column, + t table.PhysicalTable, + bfCtx *backfillCtx, + jobID int64, + elements []*meta.Element, + eleTypeKey []byte, +) (*addIndexTxnWorker, error) { + if !bytes.Equal(eleTypeKey, meta.IndexElementKey) { + logutil.DDLLogger().Error("Element type for addIndexTxnWorker incorrect", + zap.Int64("job ID", jobID), zap.ByteString("element type", eleTypeKey), zap.Int64("element ID", elements[0].ID)) + return nil, errors.Errorf("element type is not index, typeKey: %v", eleTypeKey) + } + + allIndexes := make([]table.Index, 0, len(elements)) + for _, elem := range elements { + if !bytes.Equal(elem.TypeKey, meta.IndexElementKey) { + continue + } + indexInfo := model.FindIndexInfoByID(t.Meta().Indices, elem.ID) + index := tables.NewIndex(t.GetPhysicalID(), t.Meta(), indexInfo) + allIndexes = append(allIndexes, index) + } + rowDecoder := decoder.NewRowDecoder(t, t.WritableCols(), decodeColMap) + + return &addIndexTxnWorker{ + baseIndexWorker: baseIndexWorker{ + backfillCtx: bfCtx, + indexes: allIndexes, + rowDecoder: rowDecoder, + defaultVals: make([]types.Datum, len(t.WritableCols())), + rowMap: make(map[int64]types.Datum, len(decodeColMap)), + }, + }, nil +} + +func (w *baseIndexWorker) AddMetricInfo(cnt float64) { + w.metricCounter.Add(cnt) +} + +func (w *baseIndexWorker) String() string { + return w.tp.String() +} + +func (w *baseIndexWorker) GetCtx() *backfillCtx { + return w.backfillCtx +} + +// mockNotOwnerErrOnce uses to make sure `notOwnerErr` only mock error once. +var mockNotOwnerErrOnce uint32 + +// getIndexRecord gets index columns values use w.rowDecoder, and generate indexRecord. +func (w *baseIndexWorker) getIndexRecord(idxInfo *model.IndexInfo, handle kv.Handle, recordKey []byte) (*indexRecord, error) { + cols := w.table.WritableCols() + failpoint.Inject("MockGetIndexRecordErr", func(val failpoint.Value) { + if valStr, ok := val.(string); ok { + switch valStr { + case "cantDecodeRecordErr": + failpoint.Return(nil, errors.Trace(dbterror.ErrCantDecodeRecord.GenWithStackByArgs("index", + errors.New("mock can't decode record error")))) + case "modifyColumnNotOwnerErr": + if idxInfo.Name.O == "_Idx$_idx_0" && handle.IntValue() == 7168 && atomic.CompareAndSwapUint32(&mockNotOwnerErrOnce, 0, 1) { + failpoint.Return(nil, errors.Trace(dbterror.ErrNotOwner)) + } + case "addIdxNotOwnerErr": + // For the case of the old TiDB version(do not exist the element information) is upgraded to the new TiDB version. + // First step, we need to exit "addPhysicalTableIndex". 
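+				// The CompareAndSwap from 1 to 2 makes this mock error fire at most once,
+				// and only after the modifyColumnNotOwnerErr case above has swapped 0 to 1.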
+ if idxInfo.Name.O == "idx2" && handle.IntValue() == 6144 && atomic.CompareAndSwapUint32(&mockNotOwnerErrOnce, 1, 2) { + failpoint.Return(nil, errors.Trace(dbterror.ErrNotOwner)) + } + } + } + }) + idxVal := make([]types.Datum, len(idxInfo.Columns)) + var err error + for j, v := range idxInfo.Columns { + col := cols[v.Offset] + idxColumnVal, ok := w.rowMap[col.ID] + if ok { + idxVal[j] = idxColumnVal + continue + } + idxColumnVal, err = tables.GetColDefaultValue(w.exprCtx, col, w.defaultVals) + if err != nil { + return nil, errors.Trace(err) + } + + idxVal[j] = idxColumnVal + } + + rsData := tables.TryGetHandleRestoredDataWrapper(w.table.Meta(), nil, w.rowMap, idxInfo) + idxRecord := &indexRecord{handle: handle, key: recordKey, vals: idxVal, rsData: rsData} + return idxRecord, nil +} + +func (w *baseIndexWorker) cleanRowMap() { + for id := range w.rowMap { + delete(w.rowMap, id) + } +} + +// getNextKey gets next key of entry that we are going to process. +func (w *baseIndexWorker) getNextKey(taskRange reorgBackfillTask, taskDone bool) (nextKey kv.Key) { + if !taskDone { + // The task is not done. So we need to pick the last processed entry's handle and add one. + lastHandle := w.idxRecords[len(w.idxRecords)-1].handle + recordKey := tablecodec.EncodeRecordKey(taskRange.physicalTable.RecordPrefix(), lastHandle) + return recordKey.Next() + } + return taskRange.endKey +} + +func (w *baseIndexWorker) updateRowDecoder(handle kv.Handle, rawRecord []byte) error { + sysZone := w.loc + _, err := w.rowDecoder.DecodeAndEvalRowWithMap(w.exprCtx, handle, rawRecord, sysZone, w.rowMap) + return errors.Trace(err) +} + +// fetchRowColVals fetch w.batchCnt count records that need to reorganize indices, and build the corresponding indexRecord slice. +// fetchRowColVals returns: +// 1. The corresponding indexRecord slice. +// 2. Next handle of entry that we need to process. +// 3. Boolean indicates whether the task is done. +// 4. error occurs in fetchRowColVals. nil if no error occurs. +func (w *baseIndexWorker) fetchRowColVals(txn kv.Transaction, taskRange reorgBackfillTask) ([]*indexRecord, kv.Key, bool, error) { + // TODO: use tableScan to prune columns. + w.idxRecords = w.idxRecords[:0] + startTime := time.Now() + + // taskDone means that the reorged handle is out of taskRange.endHandle. + taskDone := false + oprStartTime := startTime + err := iterateSnapshotKeys(w.jobContext, w.ddlCtx.store, taskRange.priority, taskRange.physicalTable.RecordPrefix(), txn.StartTS(), + taskRange.startKey, taskRange.endKey, func(handle kv.Handle, recordKey kv.Key, rawRow []byte) (bool, error) { + oprEndTime := time.Now() + logSlowOperations(oprEndTime.Sub(oprStartTime), "iterateSnapshotKeys in baseIndexWorker fetchRowColVals", 0) + oprStartTime = oprEndTime + + taskDone = recordKey.Cmp(taskRange.endKey) >= 0 + + if taskDone || len(w.idxRecords) >= w.batchCnt { + return false, nil + } + + // Decode one row, generate records of this row. + err := w.updateRowDecoder(handle, rawRow) + if err != nil { + return false, err + } + for _, index := range w.indexes { + idxRecord, err1 := w.getIndexRecord(index.Meta(), handle, recordKey) + if err1 != nil { + return false, errors.Trace(err1) + } + w.idxRecords = append(w.idxRecords, idxRecord) + } + // If there are generated column, rowDecoder will use column value that not in idxInfo.Columns to calculate + // the generated value, so we need to clear up the reusing map. 
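+			// cleanRowMap only deletes the entries, so the map itself is reused for the
+			// next row instead of being reallocated.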
+ w.cleanRowMap() + + if recordKey.Cmp(taskRange.endKey) == 0 { + taskDone = true + return false, nil + } + return true, nil + }) + + if len(w.idxRecords) == 0 { + taskDone = true + } + + logutil.DDLLogger().Debug("txn fetches handle info", zap.Stringer("worker", w), zap.Uint64("txnStartTS", txn.StartTS()), + zap.String("taskRange", taskRange.String()), zap.Duration("takeTime", time.Since(startTime))) + return w.idxRecords, w.getNextKey(taskRange, taskDone), taskDone, errors.Trace(err) +} + +func (w *addIndexTxnWorker) initBatchCheckBufs(batchCount int) { + if len(w.idxKeyBufs) < batchCount { + w.idxKeyBufs = make([][]byte, batchCount) + } + + w.batchCheckKeys = w.batchCheckKeys[:0] + w.batchCheckValues = w.batchCheckValues[:0] + w.distinctCheckFlags = w.distinctCheckFlags[:0] + w.recordIdx = w.recordIdx[:0] +} + +func (w *addIndexTxnWorker) checkHandleExists(idxInfo *model.IndexInfo, key kv.Key, value []byte, handle kv.Handle) error { + tblInfo := w.table.Meta() + idxColLen := len(idxInfo.Columns) + h, err := tablecodec.DecodeIndexHandle(key, value, idxColLen) + if err != nil { + return errors.Trace(err) + } + hasBeenBackFilled := h.Equal(handle) + if hasBeenBackFilled { + return nil + } + return ddlutil.GenKeyExistsErr(key, value, idxInfo, tblInfo) +} + +// batchCheckUniqueKey checks the unique keys in the batch. +// Note that `idxRecords` may belong to multiple indexes. +func (w *addIndexTxnWorker) batchCheckUniqueKey(txn kv.Transaction, idxRecords []*indexRecord) error { + w.initBatchCheckBufs(len(idxRecords)) + evalCtx := w.exprCtx.GetEvalCtx() + ec := evalCtx.ErrCtx() + uniqueBatchKeys := make([]kv.Key, 0, len(idxRecords)) + cnt := 0 + for i, record := range idxRecords { + idx := w.indexes[i%len(w.indexes)] + if !idx.Meta().Unique { + // non-unique key need not to check, use `nil` as a placeholder to keep + // `idxRecords[i]` belonging to `indexes[i%len(indexes)]`. + w.batchCheckKeys = append(w.batchCheckKeys, nil) + w.batchCheckValues = append(w.batchCheckValues, nil) + w.distinctCheckFlags = append(w.distinctCheckFlags, false) + w.recordIdx = append(w.recordIdx, 0) + continue + } + // skip by default. + idxRecords[i].skip = true + iter := idx.GenIndexKVIter(ec, w.loc, record.vals, record.handle, idxRecords[i].rsData) + for iter.Valid() { + var buf []byte + if cnt < len(w.idxKeyBufs) { + buf = w.idxKeyBufs[cnt] + } + key, val, distinct, err := iter.Next(buf, nil) + if err != nil { + return errors.Trace(err) + } + if cnt < len(w.idxKeyBufs) { + w.idxKeyBufs[cnt] = key + } else { + w.idxKeyBufs = append(w.idxKeyBufs, key) + } + cnt++ + w.batchCheckKeys = append(w.batchCheckKeys, key) + w.batchCheckValues = append(w.batchCheckValues, val) + w.distinctCheckFlags = append(w.distinctCheckFlags, distinct) + w.recordIdx = append(w.recordIdx, i) + uniqueBatchKeys = append(uniqueBatchKeys, key) + } + } + + if len(uniqueBatchKeys) == 0 { + return nil + } + + batchVals, err := txn.BatchGet(context.Background(), uniqueBatchKeys) + if err != nil { + return errors.Trace(err) + } + + // 1. unique-key/primary-key is duplicate and the handle is equal, skip it. + // 2. unique-key/primary-key is duplicate and the handle is not equal, return duplicate error. + // 3. non-unique-key is duplicate, skip it. 
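+	// Non-unique indexes left nil placeholders in w.batchCheckKeys above, so the loop
+	// below skips them via the len(key) == 0 check; only unique-key records can have
+	// their skip flag flipped back to false here.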
+ for i, key := range w.batchCheckKeys { + if len(key) == 0 { + continue + } + idx := w.indexes[i%len(w.indexes)] + val, found := batchVals[string(key)] + if found { + if w.distinctCheckFlags[i] { + if err := w.checkHandleExists(idx.Meta(), key, val, idxRecords[w.recordIdx[i]].handle); err != nil { + return errors.Trace(err) + } + } + } else if w.distinctCheckFlags[i] { + // The keys in w.batchCheckKeys also maybe duplicate, + // so we need to backfill the not found key into `batchVals` map. + batchVals[string(key)] = w.batchCheckValues[i] + } + idxRecords[w.recordIdx[i]].skip = found && idxRecords[w.recordIdx[i]].skip + } + return nil +} + +func getLocalWriterConfig(indexCnt, writerCnt int) *backend.LocalWriterConfig { + writerCfg := &backend.LocalWriterConfig{} + // avoid unit test panic + memRoot := ingest.LitMemRoot + if memRoot == nil { + return writerCfg + } + + // leave some room for objects overhead + availMem := memRoot.MaxMemoryQuota() - memRoot.CurrentUsage() - int64(10*size.MB) + memLimitPerWriter := availMem / int64(indexCnt) / int64(writerCnt) + memLimitPerWriter = min(memLimitPerWriter, litconfig.DefaultLocalWriterMemCacheSize) + writerCfg.Local.MemCacheSize = memLimitPerWriter + return writerCfg +} + +func writeChunkToLocal( + ctx context.Context, + writers []ingest.Writer, + indexes []table.Index, + copCtx copr.CopContext, + loc *time.Location, + errCtx errctx.Context, + writeStmtBufs *variable.WriteStmtBufs, + copChunk *chunk.Chunk, +) (int, kv.Handle, error) { + iter := chunk.NewIterator4Chunk(copChunk) + c := copCtx.GetBase() + ectx := c.ExprCtx.GetEvalCtx() + + maxIdxColCnt := maxIndexColumnCount(indexes) + idxDataBuf := make([]types.Datum, maxIdxColCnt) + handleDataBuf := make([]types.Datum, len(c.HandleOutputOffsets)) + var restoreDataBuf []types.Datum + count := 0 + var lastHandle kv.Handle + + unlockFns := make([]func(), 0, len(writers)) + for _, w := range writers { + unlock := w.LockForWrite() + unlockFns = append(unlockFns, unlock) + } + defer func() { + for _, unlock := range unlockFns { + unlock() + } + }() + needRestoreForIndexes := make([]bool, len(indexes)) + restore, pkNeedRestore := false, false + if c.PrimaryKeyInfo != nil && c.TableInfo.IsCommonHandle && c.TableInfo.CommonHandleVersion != 0 { + pkNeedRestore = tables.NeedRestoredData(c.PrimaryKeyInfo.Columns, c.TableInfo.Columns) + } + for i, index := range indexes { + needRestore := pkNeedRestore || tables.NeedRestoredData(index.Meta().Columns, c.TableInfo.Columns) + needRestoreForIndexes[i] = needRestore + restore = restore || needRestore + } + if restore { + restoreDataBuf = make([]types.Datum, len(c.HandleOutputOffsets)) + } + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + handleDataBuf := extractDatumByOffsets(ectx, row, c.HandleOutputOffsets, c.ExprColumnInfos, handleDataBuf) + if restore { + // restoreDataBuf should not truncate index values. 
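+			// handleDataBuf is reused for every row, so the datums are deep-cloned here
+			// before buildHandle below may truncate the common-handle values in place.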
+ for i, datum := range handleDataBuf { + restoreDataBuf[i] = *datum.Clone() + } + } + h, err := buildHandle(handleDataBuf, c.TableInfo, c.PrimaryKeyInfo, loc, errCtx) + if err != nil { + return 0, nil, errors.Trace(err) + } + for i, index := range indexes { + idxID := index.Meta().ID + idxDataBuf = extractDatumByOffsets(ectx, + row, copCtx.IndexColumnOutputOffsets(idxID), c.ExprColumnInfos, idxDataBuf) + idxData := idxDataBuf[:len(index.Meta().Columns)] + var rsData []types.Datum + if needRestoreForIndexes[i] { + rsData = getRestoreData(c.TableInfo, copCtx.IndexInfo(idxID), c.PrimaryKeyInfo, restoreDataBuf) + } + err = writeOneKVToLocal(ctx, writers[i], index, loc, errCtx, writeStmtBufs, idxData, rsData, h) + if err != nil { + return 0, nil, errors.Trace(err) + } + } + count++ + lastHandle = h + } + return count, lastHandle, nil +} + +func maxIndexColumnCount(indexes []table.Index) int { + maxCnt := 0 + for _, idx := range indexes { + colCnt := len(idx.Meta().Columns) + if colCnt > maxCnt { + maxCnt = colCnt + } + } + return maxCnt +} + +func writeOneKVToLocal( + ctx context.Context, + writer ingest.Writer, + index table.Index, + loc *time.Location, + errCtx errctx.Context, + writeBufs *variable.WriteStmtBufs, + idxDt, rsData []types.Datum, + handle kv.Handle, +) error { + iter := index.GenIndexKVIter(errCtx, loc, idxDt, handle, rsData) + for iter.Valid() { + key, idxVal, _, err := iter.Next(writeBufs.IndexKeyBuf, writeBufs.RowValBuf) + if err != nil { + return errors.Trace(err) + } + failpoint.Inject("mockLocalWriterPanic", func() { + panic("mock panic") + }) + err = writer.WriteRow(ctx, key, idxVal, handle) + if err != nil { + return errors.Trace(err) + } + failpoint.Inject("mockLocalWriterError", func() { + failpoint.Return(errors.New("mock engine error")) + }) + writeBufs.IndexKeyBuf = key + writeBufs.RowValBuf = idxVal + } + return nil +} + +// BackfillData will backfill table index in a transaction. A lock corresponds to a rowKey if the value of rowKey is changed, +// Note that index columns values may change, and an index is not allowed to be added, so the txn will rollback and retry. +// BackfillData will add w.batchCnt indices once, default value of w.batchCnt is 128. 
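+// Roughly, each transaction fetches a batch of rows with fetchRowColVals, filters out
+// index keys that already exist via batchCheckUniqueKey, locks the corresponding row keys,
+// and then creates the index KVs with DupKeyCheckSkip since uniqueness was already checked.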
+func (w *addIndexTxnWorker) BackfillData(handleRange reorgBackfillTask) (taskCtx backfillTaskContext, errInTxn error) { + failpoint.Inject("errorMockPanic", func(val failpoint.Value) { + //nolint:forcetypeassert + if val.(bool) { + panic("panic test") + } + }) + + oprStartTime := time.Now() + jobID := handleRange.getJobID() + ctx := kv.WithInternalSourceAndTaskType(context.Background(), w.jobContext.ddlJobSourceType(), kvutil.ExplicitTypeDDL) + errInTxn = kv.RunInNewTxn(ctx, w.ddlCtx.store, true, func(_ context.Context, txn kv.Transaction) (err error) { + taskCtx.finishTS = txn.StartTS() + taskCtx.addedCount = 0 + taskCtx.scanCount = 0 + updateTxnEntrySizeLimitIfNeeded(txn) + txn.SetOption(kv.Priority, handleRange.priority) + if tagger := w.GetCtx().getResourceGroupTaggerForTopSQL(jobID); tagger != nil { + txn.SetOption(kv.ResourceGroupTagger, tagger) + } + txn.SetOption(kv.ResourceGroupName, w.jobContext.resourceGroupName) + + idxRecords, nextKey, taskDone, err := w.fetchRowColVals(txn, handleRange) + if err != nil { + return errors.Trace(err) + } + taskCtx.nextKey = nextKey + taskCtx.done = taskDone + + err = w.batchCheckUniqueKey(txn, idxRecords) + if err != nil { + return errors.Trace(err) + } + + for i, idxRecord := range idxRecords { + taskCtx.scanCount++ + // The index is already exists, we skip it, no needs to backfill it. + // The following update, delete, insert on these rows, TiDB can handle it correctly. + if idxRecord.skip { + continue + } + + // We need to add this lock to make sure pessimistic transaction can realize this operation. + // For the normal pessimistic transaction, it's ok. But if async commit is used, it may lead to inconsistent data and index. + // TODO: For global index, lock the correct key?! Currently it locks the partition (phyTblID) and the handle or actual key? + // but should really lock the table's ID + key col(s) + err := txn.LockKeys(context.Background(), new(kv.LockCtx), idxRecord.key) + if err != nil { + return errors.Trace(err) + } + + handle, err := w.indexes[i%len(w.indexes)].Create( + w.tblCtx, txn, idxRecord.vals, idxRecord.handle, idxRecord.rsData, + table.WithIgnoreAssertion, + table.FromBackfill, + // Constrains is already checked in batchCheckUniqueKey + table.DupKeyCheckSkip, + ) + if err != nil { + if kv.ErrKeyExists.Equal(err) && idxRecord.handle.Equal(handle) { + // Index already exists, skip it. + continue + } + return errors.Trace(err) + } + taskCtx.addedCount++ + } + + return nil + }) + logSlowOperations(time.Since(oprStartTime), "AddIndexBackfillData", 3000) + failpoint.Inject("mockDMLExecution", func(val failpoint.Value) { + //nolint:forcetypeassert + if val.(bool) && MockDMLExecution != nil { + MockDMLExecution() + } + }) + return +} + +// MockDMLExecution is only used for test. +var MockDMLExecution func() + +// MockDMLExecutionMerging is only used for test. +var MockDMLExecutionMerging func() + +// MockDMLExecutionStateMerging is only used for test. +var MockDMLExecutionStateMerging func() + +// MockDMLExecutionStateBeforeImport is only used for test. +var MockDMLExecutionStateBeforeImport func() + +// MockDMLExecutionStateBeforeMerge is only used for test. 
+var MockDMLExecutionStateBeforeMerge func() + +func (w *worker) addPhysicalTableIndex(t table.PhysicalTable, reorgInfo *reorgInfo) error { + if reorgInfo.mergingTmpIdx { + logutil.DDLLogger().Info("start to merge temp index", zap.Stringer("job", reorgInfo.Job), zap.Stringer("reorgInfo", reorgInfo)) + return w.writePhysicalTableRecord(w.ctx, w.sessPool, t, typeAddIndexMergeTmpWorker, reorgInfo) + } + logutil.DDLLogger().Info("start to add table index", zap.Stringer("job", reorgInfo.Job), zap.Stringer("reorgInfo", reorgInfo)) + return w.writePhysicalTableRecord(w.ctx, w.sessPool, t, typeAddIndexWorker, reorgInfo) +} + +// addTableIndex handles the add index reorganization state for a table. +func (w *worker) addTableIndex(t table.Table, reorgInfo *reorgInfo) error { + // TODO: Support typeAddIndexMergeTmpWorker. + if reorgInfo.ReorgMeta.IsDistReorg && !reorgInfo.mergingTmpIdx { + if reorgInfo.ReorgMeta.ReorgTp == model.ReorgTypeLitMerge { + err := w.executeDistTask(t, reorgInfo) + if err != nil { + return err + } + //nolint:forcetypeassert + discovery := w.store.(tikv.Storage).GetRegionCache().PDClient().GetServiceDiscovery() + return checkDuplicateForUniqueIndex(w.ctx, t, reorgInfo, discovery) + } + } + + var err error + if tbl, ok := t.(table.PartitionedTable); ok { + var finish bool + for !finish { + p := tbl.GetPartition(reorgInfo.PhysicalTableID) + if p == nil { + return dbterror.ErrCancelledDDLJob.GenWithStack("Can not find partition id %d for table %d", reorgInfo.PhysicalTableID, t.Meta().ID) + } + err = w.addPhysicalTableIndex(p, reorgInfo) + if err != nil { + break + } + + finish, err = updateReorgInfo(w.sessPool, tbl, reorgInfo) + if err != nil { + return errors.Trace(err) + } + failpoint.InjectCall("afterUpdatePartitionReorgInfo", reorgInfo.Job) + // Every time we finish a partition, we update the progress of the job. + if rc := w.getReorgCtx(reorgInfo.Job.ID); rc != nil { + reorgInfo.Job.SetRowCount(rc.getRowCount()) + } + } + } else { + //nolint:forcetypeassert + phyTbl := t.(table.PhysicalTable) + err = w.addPhysicalTableIndex(phyTbl, reorgInfo) + } + return errors.Trace(err) +} + +func checkDuplicateForUniqueIndex(ctx context.Context, t table.Table, reorgInfo *reorgInfo, discovery pd.ServiceDiscovery) error { + var bc ingest.BackendCtx + var err error + defer func() { + if bc != nil { + ingest.LitBackCtxMgr.Unregister(reorgInfo.ID) + } + }() + + for _, elem := range reorgInfo.elements { + indexInfo := model.FindIndexInfoByID(t.Meta().Indices, elem.ID) + if indexInfo == nil { + return errors.New("unexpected error, can't find index info") + } + if indexInfo.Unique { + ctx := tidblogutil.WithCategory(ctx, "ddl-ingest") + if bc == nil { + bc, err = ingest.LitBackCtxMgr.Register(ctx, reorgInfo.ID, indexInfo.Unique, nil, discovery, reorgInfo.ReorgMeta.ResourceGroupName) + if err != nil { + return err + } + } + err = bc.CollectRemoteDuplicateRows(indexInfo.ID, t) + if err != nil { + return err + } + } + } + return nil +} + +func (w *worker) executeDistTask(t table.Table, reorgInfo *reorgInfo) error { + if reorgInfo.mergingTmpIdx { + return errors.New("do not support merge index") + } + + taskType := proto.Backfill + taskKey := fmt.Sprintf("ddl/%s/%d", taskType, reorgInfo.Job.ID) + g, ctx := errgroup.WithContext(w.ctx) + ctx = kv.WithInternalSourceType(ctx, kv.InternalDistTask) + + done := make(chan struct{}) + + // generate taskKey for multi schema change. 
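+	// For a multi-schema-change job the sub-job sequence number is appended below, so the
+	// key becomes "ddl/<taskType>/<jobID>/<seq>" instead of "ddl/<taskType>/<jobID>".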
+ if mInfo := reorgInfo.Job.MultiSchemaInfo; mInfo != nil { + taskKey = fmt.Sprintf("%s/%d", taskKey, mInfo.Seq) + } + + // For resuming add index task. + // Need to fetch task by taskKey in tidb_global_task and tidb_global_task_history tables. + // When pausing the related ddl job, it is possible that the task with taskKey is succeed and in tidb_global_task_history. + // As a result, when resuming the related ddl job, + // it is necessary to check task exits in tidb_global_task and tidb_global_task_history tables. + taskManager, err := storage.GetTaskManager() + if err != nil { + return err + } + task, err := taskManager.GetTaskByKeyWithHistory(w.ctx, taskKey) + if err != nil && err != storage.ErrTaskNotFound { + return err + } + if task != nil { + // It's possible that the task state is succeed but the ddl job is paused. + // When task in succeed state, we can skip the dist task execution/scheduing process. + if task.State == proto.TaskStateSucceed { + logutil.DDLLogger().Info( + "task succeed, start to resume the ddl job", + zap.String("task-key", taskKey)) + return nil + } + g.Go(func() error { + defer close(done) + backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) + err := handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, logutil.DDLLogger(), + func(context.Context) (bool, error) { + return true, handle.ResumeTask(w.ctx, taskKey) + }, + ) + if err != nil { + return err + } + err = handle.WaitTaskDoneOrPaused(ctx, task.ID) + if err := w.isReorgRunnable(reorgInfo.Job.ID, true); err != nil { + if dbterror.ErrPausedDDLJob.Equal(err) { + logutil.DDLLogger().Warn("job paused by user", zap.Error(err)) + return dbterror.ErrPausedDDLJob.GenWithStackByArgs(reorgInfo.Job.ID) + } + } + return err + }) + } else { + job := reorgInfo.Job + workerCntLimit := int(variable.GetDDLReorgWorkerCounter()) + cpuCount, err := handle.GetCPUCountOfNode(ctx) + if err != nil { + return err + } + concurrency := min(workerCntLimit, cpuCount) + logutil.DDLLogger().Info("adjusted add-index task concurrency", + zap.Int("worker-cnt", workerCntLimit), zap.Int("task-concurrency", concurrency), + zap.String("task-key", taskKey)) + rowSize := estimateTableRowSize(w.ctx, w.store, w.sess.GetRestrictedSQLExecutor(), t) + taskMeta := &BackfillTaskMeta{ + Job: *job.Clone(), + EleIDs: extractElemIDs(reorgInfo), + EleTypeKey: reorgInfo.currElement.TypeKey, + CloudStorageURI: w.jobContext(job.ID, job.ReorgMeta).cloudStorageURI, + EstimateRowSize: rowSize, + } + + metaData, err := json.Marshal(taskMeta) + if err != nil { + return err + } + + g.Go(func() error { + defer close(done) + err := submitAndWaitTask(ctx, taskKey, taskType, concurrency, reorgInfo.ReorgMeta.TargetScope, metaData) + failpoint.InjectCall("pauseAfterDistTaskFinished") + if err := w.isReorgRunnable(reorgInfo.Job.ID, true); err != nil { + if dbterror.ErrPausedDDLJob.Equal(err) { + logutil.DDLLogger().Warn("job paused by user", zap.Error(err)) + return dbterror.ErrPausedDDLJob.GenWithStackByArgs(reorgInfo.Job.ID) + } + } + return err + }) + } + + g.Go(func() error { + checkFinishTk := time.NewTicker(CheckBackfillJobFinishInterval) + defer checkFinishTk.Stop() + updateRowCntTk := time.NewTicker(UpdateBackfillJobRowCountInterval) + defer updateRowCntTk.Stop() + for { + select { + case <-done: + w.updateJobRowCount(taskKey, reorgInfo.Job.ID) + return nil + case <-checkFinishTk.C: + if err = w.isReorgRunnable(reorgInfo.Job.ID, true); err != nil { + if dbterror.ErrPausedDDLJob.Equal(err) { + if err = 
handle.PauseTask(w.ctx, taskKey); err != nil { + logutil.DDLLogger().Error("pause task error", zap.String("task_key", taskKey), zap.Error(err)) + continue + } + failpoint.InjectCall("syncDDLTaskPause") + } + if !dbterror.ErrCancelledDDLJob.Equal(err) { + return errors.Trace(err) + } + if err = handle.CancelTask(w.ctx, taskKey); err != nil { + logutil.DDLLogger().Error("cancel task error", zap.String("task_key", taskKey), zap.Error(err)) + // continue to cancel task. + continue + } + } + case <-updateRowCntTk.C: + w.updateJobRowCount(taskKey, reorgInfo.Job.ID) + } + } + }) + err = g.Wait() + return err +} + +// EstimateTableRowSizeForTest is used for test. +var EstimateTableRowSizeForTest = estimateTableRowSize + +// estimateTableRowSize estimates the row size in bytes of a table. +// This function tries to retrieve row size in following orders: +// 1. AVG_ROW_LENGTH column from information_schema.tables. +// 2. region info's approximate key size / key number. +func estimateTableRowSize( + ctx context.Context, + store kv.Storage, + exec sqlexec.RestrictedSQLExecutor, + tbl table.Table, +) (sizeInBytes int) { + defer util.Recover(metrics.LabelDDL, "estimateTableRowSize", nil, false) + var gErr error + defer func() { + tidblogutil.Logger(ctx).Info("estimate row size", + zap.Int64("tableID", tbl.Meta().ID), zap.Int("size", sizeInBytes), zap.Error(gErr)) + }() + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, + "select AVG_ROW_LENGTH from information_schema.tables where TIDB_TABLE_ID = %?", tbl.Meta().ID) + if err != nil { + gErr = err + return 0 + } + if len(rows) == 0 { + gErr = errors.New("no average row data") + return 0 + } + avgRowSize := rows[0].GetInt64(0) + if avgRowSize != 0 { + return int(avgRowSize) + } + regionRowSize, err := estimateRowSizeFromRegion(ctx, store, tbl) + if err != nil { + gErr = err + return 0 + } + return regionRowSize +} + +func estimateRowSizeFromRegion(ctx context.Context, store kv.Storage, tbl table.Table) (int, error) { + hStore, ok := store.(helper.Storage) + if !ok { + return 0, fmt.Errorf("not a helper.Storage") + } + h := &helper.Helper{ + Store: hStore, + RegionCache: hStore.GetRegionCache(), + } + pdCli, err := h.TryGetPDHTTPClient() + if err != nil { + return 0, err + } + pid := tbl.Meta().ID + sk, ek := tablecodec.GetTableHandleKeyRange(pid) + sRegion, err := pdCli.GetRegionByKey(ctx, codec.EncodeBytes(nil, sk)) + if err != nil { + return 0, err + } + eRegion, err := pdCli.GetRegionByKey(ctx, codec.EncodeBytes(nil, ek)) + if err != nil { + return 0, err + } + sk, err = hex.DecodeString(sRegion.StartKey) + if err != nil { + return 0, err + } + ek, err = hex.DecodeString(eRegion.EndKey) + if err != nil { + return 0, err + } + // We use the second region to prevent the influence of the front and back tables. 
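+	// Fetch up to three regions in the handle range and sample the middle one: its
+	// approximate size divided by its approximate key count gives the per-row size in bytes.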
+ regionLimit := 3 + regionInfos, err := pdCli.GetRegionsByKeyRange(ctx, pdHttp.NewKeyRange(sk, ek), regionLimit) + if err != nil { + return 0, err + } + if len(regionInfos.Regions) != regionLimit { + return 0, fmt.Errorf("less than 3 regions") + } + sample := regionInfos.Regions[1] + if sample.ApproximateKeys == 0 || sample.ApproximateSize == 0 { + return 0, fmt.Errorf("zero approximate size") + } + return int(uint64(sample.ApproximateSize)*size.MB) / int(sample.ApproximateKeys), nil +} + +func (w *worker) updateJobRowCount(taskKey string, jobID int64) { + taskMgr, err := storage.GetTaskManager() + if err != nil { + logutil.DDLLogger().Warn("cannot get task manager", zap.String("task_key", taskKey), zap.Error(err)) + return + } + task, err := taskMgr.GetTaskByKey(w.ctx, taskKey) + if err != nil { + logutil.DDLLogger().Warn("cannot get task", zap.String("task_key", taskKey), zap.Error(err)) + return + } + rowCount, err := taskMgr.GetSubtaskRowCount(w.ctx, task.ID, proto.BackfillStepReadIndex) + if err != nil { + logutil.DDLLogger().Warn("cannot get subtask row count", zap.String("task_key", taskKey), zap.Error(err)) + return + } + w.getReorgCtx(jobID).setRowCount(rowCount) +} + +// submitAndWaitTask submits a task and wait for it to finish. +func submitAndWaitTask(ctx context.Context, taskKey string, taskType proto.TaskType, concurrency int, targetScope string, taskMeta []byte) error { + task, err := handle.SubmitTask(ctx, taskKey, taskType, concurrency, targetScope, taskMeta) + if err != nil { + return err + } + return handle.WaitTaskDoneOrPaused(ctx, task.ID) +} + +func getNextPartitionInfo(reorg *reorgInfo, t table.PartitionedTable, currPhysicalTableID int64) (int64, kv.Key, kv.Key, error) { + pi := t.Meta().GetPartitionInfo() + if pi == nil { + return 0, nil, nil, nil + } + + // This will be used in multiple different scenarios/ALTER TABLE: + // ADD INDEX - no change in partitions, just use pi.Definitions (1) + // REORGANIZE PARTITION - copy data from partitions to be dropped (2) + // REORGANIZE PARTITION - (re)create indexes on partitions to be added (3) + // REORGANIZE PARTITION - Update new Global indexes with data from non-touched partitions (4) + // (i.e. pi.Definitions - pi.DroppingDefinitions) + var pid int64 + var err error + if bytes.Equal(reorg.currElement.TypeKey, meta.IndexElementKey) { + // case 1, 3 or 4 + if len(pi.AddingDefinitions) == 0 { + // case 1 + // Simply AddIndex, without any partitions added or dropped! + pid, err = findNextPartitionID(currPhysicalTableID, pi.Definitions) + } else { + // case 3 (or if not found AddingDefinitions; 4) + // check if recreating Global Index (during Reorg Partition) + pid, err = findNextPartitionID(currPhysicalTableID, pi.AddingDefinitions) + if err != nil { + // case 4 + // Not a partition in the AddingDefinitions, so it must be an existing + // non-touched partition, i.e. recreating Global Index for the non-touched partitions + pid, err = findNextNonTouchedPartitionID(currPhysicalTableID, pi) + } + } + } else { + // case 2 + pid, err = findNextPartitionID(currPhysicalTableID, pi.DroppingDefinitions) + } + if err != nil { + // Fatal error, should not run here. + logutil.DDLLogger().Error("find next partition ID failed", zap.Reflect("table", t), zap.Error(err)) + return 0, nil, nil, errors.Trace(err) + } + if pid == 0 { + // Next partition does not exist, all the job done. 
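+		// A zero pid is the sentinel that callers such as updateReorgInfo use to report
+		// that the reorganize work for this element is finished.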
+ return 0, nil, nil, nil + } + + failpoint.Inject("mockUpdateCachedSafePoint", func(val failpoint.Value) { + //nolint:forcetypeassert + if val.(bool) { + ts := oracle.GoTimeToTS(time.Now()) + //nolint:forcetypeassert + s := reorg.d.store.(tikv.Storage) + s.UpdateSPCache(ts, time.Now()) + time.Sleep(time.Second * 3) + } + }) + + var startKey, endKey kv.Key + if reorg.mergingTmpIdx { + elements := reorg.elements + firstElemTempID := tablecodec.TempIndexPrefix | elements[0].ID + lastElemTempID := tablecodec.TempIndexPrefix | elements[len(elements)-1].ID + startKey = tablecodec.EncodeIndexSeekKey(pid, firstElemTempID, nil) + endKey = tablecodec.EncodeIndexSeekKey(pid, lastElemTempID, []byte{255}) + } else { + currentVer, err := getValidCurrentVersion(reorg.d.store) + if err != nil { + return 0, nil, nil, errors.Trace(err) + } + startKey, endKey, err = getTableRange(reorg.NewJobContext(), reorg.d, t.GetPartition(pid), currentVer.Ver, reorg.Job.Priority) + if err != nil { + return 0, nil, nil, errors.Trace(err) + } + } + return pid, startKey, endKey, nil +} + +// updateReorgInfo will find the next partition according to current reorgInfo. +// If no more partitions, or table t is not a partitioned table, returns true to +// indicate that the reorganize work is finished. +func updateReorgInfo(sessPool *sess.Pool, t table.PartitionedTable, reorg *reorgInfo) (bool, error) { + pid, startKey, endKey, err := getNextPartitionInfo(reorg, t, reorg.PhysicalTableID) + if err != nil { + return false, errors.Trace(err) + } + if pid == 0 { + // Next partition does not exist, all the job done. + return true, nil + } + reorg.PhysicalTableID, reorg.StartKey, reorg.EndKey = pid, startKey, endKey + + // Write the reorg info to store so the whole reorganize process can recover from panic. + err = reorg.UpdateReorgMeta(reorg.StartKey, sessPool) + logutil.DDLLogger().Info("job update reorgInfo", + zap.Int64("jobID", reorg.Job.ID), + zap.Stringer("element", reorg.currElement), + zap.Int64("partitionTableID", pid), + zap.String("startKey", hex.EncodeToString(reorg.StartKey)), + zap.String("endKey", hex.EncodeToString(reorg.EndKey)), zap.Error(err)) + return false, errors.Trace(err) +} + +// findNextPartitionID finds the next partition ID in the PartitionDefinition array. +// Returns 0 if current partition is already the last one. +func findNextPartitionID(currentPartition int64, defs []model.PartitionDefinition) (int64, error) { + for i, def := range defs { + if currentPartition == def.ID { + if i == len(defs)-1 { + return 0, nil + } + return defs[i+1].ID, nil + } + } + return 0, errors.Errorf("partition id not found %d", currentPartition) +} + +func findNextNonTouchedPartitionID(currPartitionID int64, pi *model.PartitionInfo) (int64, error) { + pid, err := findNextPartitionID(currPartitionID, pi.Definitions) + if err != nil { + return 0, err + } + if pid == 0 { + return 0, nil + } + for _, notFoundErr := findNextPartitionID(pid, pi.DroppingDefinitions); notFoundErr == nil; { + // This can be optimized, but it is not frequently called, so keeping as-is + pid, err = findNextPartitionID(pid, pi.Definitions) + if pid == 0 { + break + } + } + return pid, err +} + +// AllocateIndexID allocates an index ID from TableInfo. 
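+// The returned IDs are monotonically increasing and are not reused, which, for example,
+// the temp index handling assumes (index IDs are always increasing).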
+func AllocateIndexID(tblInfo *model.TableInfo) int64 { + tblInfo.MaxIndexID++ + return tblInfo.MaxIndexID +} + +func getIndexInfoByNameAndColumn(oldTableInfo *model.TableInfo, newOne *model.IndexInfo) *model.IndexInfo { + for _, oldOne := range oldTableInfo.Indices { + if newOne.Name.L == oldOne.Name.L && indexColumnSliceEqual(newOne.Columns, oldOne.Columns) { + return oldOne + } + } + return nil +} + +func indexColumnSliceEqual(a, b []*model.IndexColumn) bool { + if len(a) != len(b) { + return false + } + if len(a) == 0 { + logutil.DDLLogger().Warn("admin repair table : index's columns length equal to 0") + return true + } + // Accelerate the compare by eliminate index bound check. + b = b[:len(a)] + for i, v := range a { + if v.Name.L != b[i].Name.L { + return false + } + } + return true +} + +type cleanUpIndexWorker struct { + baseIndexWorker +} + +func newCleanUpIndexWorker(id int, t table.PhysicalTable, decodeColMap map[int64]decoder.Column, reorgInfo *reorgInfo, jc *JobContext) (*cleanUpIndexWorker, error) { + bCtx, err := newBackfillCtx(id, reorgInfo, reorgInfo.SchemaName, t, jc, "cleanup_idx_rate", false) + if err != nil { + return nil, err + } + + indexes := make([]table.Index, 0, len(t.Indices())) + rowDecoder := decoder.NewRowDecoder(t, t.WritableCols(), decodeColMap) + for _, index := range t.Indices() { + if index.Meta().Global { + indexes = append(indexes, index) + } + } + return &cleanUpIndexWorker{ + baseIndexWorker: baseIndexWorker{ + backfillCtx: bCtx, + indexes: indexes, + rowDecoder: rowDecoder, + defaultVals: make([]types.Datum, len(t.WritableCols())), + rowMap: make(map[int64]types.Datum, len(decodeColMap)), + }, + }, nil +} + +func (w *cleanUpIndexWorker) BackfillData(handleRange reorgBackfillTask) (taskCtx backfillTaskContext, errInTxn error) { + failpoint.Inject("errorMockPanic", func(val failpoint.Value) { + //nolint:forcetypeassert + if val.(bool) { + panic("panic test") + } + }) + + oprStartTime := time.Now() + ctx := kv.WithInternalSourceAndTaskType(context.Background(), w.jobContext.ddlJobSourceType(), kvutil.ExplicitTypeDDL) + errInTxn = kv.RunInNewTxn(ctx, w.ddlCtx.store, true, func(_ context.Context, txn kv.Transaction) error { + taskCtx.addedCount = 0 + taskCtx.scanCount = 0 + updateTxnEntrySizeLimitIfNeeded(txn) + txn.SetOption(kv.Priority, handleRange.priority) + if tagger := w.GetCtx().getResourceGroupTaggerForTopSQL(handleRange.getJobID()); tagger != nil { + txn.SetOption(kv.ResourceGroupTagger, tagger) + } + txn.SetOption(kv.ResourceGroupName, w.jobContext.resourceGroupName) + + idxRecords, nextKey, taskDone, err := w.fetchRowColVals(txn, handleRange) + if err != nil { + return errors.Trace(err) + } + taskCtx.nextKey = nextKey + taskCtx.done = taskDone + + txn.SetDiskFullOpt(kvrpcpb.DiskFullOpt_AllowedOnAlmostFull) + + n := len(w.indexes) + for i, idxRecord := range idxRecords { + taskCtx.scanCount++ + // we fetch records row by row, so records will belong to + // index[0], index[1] ... index[n-1], index[0], index[1] ... + // respectively. So indexes[i%n] is the index of idxRecords[i]. 
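+			// newCleanUpIndexWorker only collects global indexes, so n here is the number
+			// of global indexes whose entries must be removed for this partition.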
+ err := w.indexes[i%n].Delete(w.tblCtx, txn, idxRecord.vals, idxRecord.handle) + if err != nil { + return errors.Trace(err) + } + taskCtx.addedCount++ + } + return nil + }) + logSlowOperations(time.Since(oprStartTime), "cleanUpIndexBackfillDataInTxn", 3000) + failpoint.Inject("mockDMLExecution", func(val failpoint.Value) { + //nolint:forcetypeassert + if val.(bool) && MockDMLExecution != nil { + MockDMLExecution() + } + }) + + return +} + +// cleanupPhysicalTableIndex handles the drop partition reorganization state for a non-partitioned table or a partition. +func (w *worker) cleanupPhysicalTableIndex(t table.PhysicalTable, reorgInfo *reorgInfo) error { + logutil.DDLLogger().Info("start to clean up index", zap.Stringer("job", reorgInfo.Job), zap.Stringer("reorgInfo", reorgInfo)) + return w.writePhysicalTableRecord(w.ctx, w.sessPool, t, typeCleanUpIndexWorker, reorgInfo) +} + +// cleanupGlobalIndex handles the drop partition reorganization state to clean up index entries of partitions. +func (w *worker) cleanupGlobalIndexes(tbl table.PartitionedTable, partitionIDs []int64, reorgInfo *reorgInfo) error { + var err error + var finish bool + for !finish { + p := tbl.GetPartition(reorgInfo.PhysicalTableID) + if p == nil { + return dbterror.ErrCancelledDDLJob.GenWithStack("Can not find partition id %d for table %d", reorgInfo.PhysicalTableID, tbl.Meta().ID) + } + err = w.cleanupPhysicalTableIndex(p, reorgInfo) + if err != nil { + break + } + finish, err = w.updateReorgInfoForPartitions(tbl, reorgInfo, partitionIDs) + if err != nil { + return errors.Trace(err) + } + } + + return errors.Trace(err) +} + +// updateReorgInfoForPartitions will find the next partition in partitionIDs according to current reorgInfo. +// If no more partitions, or table t is not a partitioned table, returns true to +// indicate that the reorganize work is finished. +func (w *worker) updateReorgInfoForPartitions(t table.PartitionedTable, reorg *reorgInfo, partitionIDs []int64) (bool, error) { + pi := t.Meta().GetPartitionInfo() + if pi == nil { + return true, nil + } + + var pid int64 + for i, pi := range partitionIDs { + if pi == reorg.PhysicalTableID { + if i == len(partitionIDs)-1 { + return true, nil + } + pid = partitionIDs[i+1] + break + } + } + + currentVer, err := getValidCurrentVersion(reorg.d.store) + if err != nil { + return false, errors.Trace(err) + } + start, end, err := getTableRange(reorg.NewJobContext(), reorg.d, t.GetPartition(pid), currentVer.Ver, reorg.Job.Priority) + if err != nil { + return false, errors.Trace(err) + } + reorg.StartKey, reorg.EndKey, reorg.PhysicalTableID = start, end, pid + + // Write the reorg info to store so the whole reorganize process can recover from panic. + err = reorg.UpdateReorgMeta(reorg.StartKey, w.sessPool) + logutil.DDLLogger().Info("job update reorg info", zap.Int64("jobID", reorg.Job.ID), + zap.Stringer("element", reorg.currElement), + zap.Int64("partition table ID", pid), zap.String("start key", hex.EncodeToString(start)), + zap.String("end key", hex.EncodeToString(end)), zap.Error(err)) + return false, errors.Trace(err) +} + +// changingIndex is used to store the index that need to be changed during modifying column. +type changingIndex struct { + IndexInfo *model.IndexInfo + // Column offset in idxInfo.Columns. + Offset int + // When the modifying column is contained in the index, a temp index is created. + // isTemp indicates whether the indexInfo is a temp index created by a previous modify column job. 
+ isTemp bool +} + +// FindRelatedIndexesToChange finds the indexes that covering the given column. +// The normal one will be overridden by the temp one. +func FindRelatedIndexesToChange(tblInfo *model.TableInfo, colName model.CIStr) []changingIndex { + // In multi-schema change jobs that contains several "modify column" sub-jobs, there may be temp indexes for another temp index. + // To prevent reorganizing too many indexes, we should create the temp indexes that are really necessary. + var normalIdxInfos, tempIdxInfos []changingIndex + for _, idxInfo := range tblInfo.Indices { + if pos := findIdxCol(idxInfo, colName); pos != -1 { + isTemp := isTempIdxInfo(idxInfo, tblInfo) + r := changingIndex{IndexInfo: idxInfo, Offset: pos, isTemp: isTemp} + if isTemp { + tempIdxInfos = append(tempIdxInfos, r) + } else { + normalIdxInfos = append(normalIdxInfos, r) + } + } + } + // Overwrite if the index has the corresponding temp index. For example, + // we try to find the indexes that contain the column `b` and there are two indexes, `i(a, b)` and `$i($a, b)`. + // Note that the symbol `$` means temporary. The index `$i($a, b)` is temporarily created by the previous "modify a" statement. + // In this case, we would create a temporary index like $$i($a, $b), so the latter should be chosen. + result := normalIdxInfos + for _, tmpIdx := range tempIdxInfos { + origName := getChangingIndexOriginName(tmpIdx.IndexInfo) + for i, normIdx := range normalIdxInfos { + if normIdx.IndexInfo.Name.O == origName { + result[i] = tmpIdx + } + } + } + return result +} + +func isTempIdxInfo(idxInfo *model.IndexInfo, tblInfo *model.TableInfo) bool { + for _, idxCol := range idxInfo.Columns { + if tblInfo.Columns[idxCol.Offset].ChangeStateInfo != nil { + return true + } + } + return false +} + +func findIdxCol(idxInfo *model.IndexInfo, colName model.CIStr) int { + for offset, idxCol := range idxInfo.Columns { + if idxCol.Name.L == colName.L { + return offset + } + } + return -1 +} + +func renameIndexes(tblInfo *model.TableInfo, from, to model.CIStr) { + for _, idx := range tblInfo.Indices { + if idx.Name.L == from.L { + idx.Name = to + } else if isTempIdxInfo(idx, tblInfo) && getChangingIndexOriginName(idx) == from.O { + idx.Name.L = strings.Replace(idx.Name.L, from.L, to.L, 1) + idx.Name.O = strings.Replace(idx.Name.O, from.O, to.O, 1) + } + } +} + +func renameHiddenColumns(tblInfo *model.TableInfo, from, to model.CIStr) { + for _, col := range tblInfo.Columns { + if col.Hidden && getExpressionIndexOriginName(col) == from.O { + col.Name.L = strings.Replace(col.Name.L, from.L, to.L, 1) + col.Name.O = strings.Replace(col.Name.O, from.O, to.O, 1) + } + } +} diff --git a/pkg/ddl/index_cop.go b/pkg/ddl/index_cop.go index 30d6f70c8b9e0..dd05fd7fee847 100644 --- a/pkg/ddl/index_cop.go +++ b/pkg/ddl/index_cop.go @@ -145,11 +145,11 @@ func scanRecords(p *copReqSenderPool, task *reorgBackfillTask, se *sess.Session) if err != nil { return err } - failpoint.Inject("mockCopSenderPanic", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockCopSenderPanic")); _err_ == nil { if val.(bool) { panic("mock panic") } - }) + } if p.checkpointMgr != nil { p.checkpointMgr.Register(task.id, task.endKey) } @@ -169,9 +169,9 @@ func scanRecords(p *copReqSenderPool, task *reorgBackfillTask, se *sess.Session) idxRs := IndexRecordChunk{ID: task.id, Chunk: srcChk, Done: done} rate := float64(srcChk.MemoryUsage()) / 1024.0 / 1024.0 / time.Since(startTime).Seconds() 
metrics.AddIndexScanRate.WithLabelValues(metrics.LblAddIndex).Observe(rate) - failpoint.Inject("mockCopSenderError", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockCopSenderError")); _err_ == nil { idxRs.Err = errors.New("mock cop error") - }) + } p.chunkSender.AddTask(idxRs) startTime = time.Now() } diff --git a/pkg/ddl/index_cop.go__failpoint_stash__ b/pkg/ddl/index_cop.go__failpoint_stash__ new file mode 100644 index 0000000000000..30d6f70c8b9e0 --- /dev/null +++ b/pkg/ddl/index_cop.go__failpoint_stash__ @@ -0,0 +1,392 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl + +import ( + "context" + "encoding/hex" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/ddl/copr" + "github.com/pingcap/tidb/pkg/ddl/ingest" + sess "github.com/pingcap/tidb/pkg/ddl/session" + "github.com/pingcap/tidb/pkg/distsql" + distsqlctx "github.com/pingcap/tidb/pkg/distsql/context" + "github.com/pingcap/tidb/pkg/errctx" + "github.com/pingcap/tidb/pkg/expression" + exprctx "github.com/pingcap/tidb/pkg/expression/context" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/collate" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/timeutil" + "github.com/pingcap/tipb/go-tipb" + kvutil "github.com/tikv/client-go/v2/util" + "go.uber.org/zap" +) + +// copReadBatchSize is the batch size of coprocessor read. +// It multiplies the tidb_ddl_reorg_batch_size by 10 to avoid +// sending too many cop requests for the same handle range. +func copReadBatchSize() int { + return 10 * int(variable.GetDDLReorgBatchSize()) +} + +// copReadChunkPoolSize is the size of chunk pool, which +// represents the max concurrent ongoing coprocessor requests. +// It multiplies the tidb_ddl_reorg_worker_cnt by 10. +func copReadChunkPoolSize() int { + return 10 * int(variable.GetDDLReorgWorkerCounter()) +} + +// chunkSender is used to receive the result of coprocessor request. 
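+// Both results and errors are delivered as IndexRecordChunk values through AddTask;
+// a non-nil Err field carries any failure from the sender side.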
+type chunkSender interface { + AddTask(IndexRecordChunk) +} + +type copReqSenderPool struct { + tasksCh chan *reorgBackfillTask + chunkSender chunkSender + checkpointMgr *ingest.CheckpointManager + sessPool *sess.Pool + + ctx context.Context + copCtx copr.CopContext + store kv.Storage + + senders []*copReqSender + wg sync.WaitGroup + closed bool + + srcChkPool chan *chunk.Chunk +} + +type copReqSender struct { + senderPool *copReqSenderPool + + ctx context.Context + cancel context.CancelFunc +} + +func (c *copReqSender) run() { + p := c.senderPool + defer p.wg.Done() + defer util.Recover(metrics.LabelDDL, "copReqSender.run", func() { + p.chunkSender.AddTask(IndexRecordChunk{Err: dbterror.ErrReorgPanic}) + }, false) + sessCtx, err := p.sessPool.Get() + if err != nil { + logutil.Logger(p.ctx).Error("copReqSender get session from pool failed", zap.Error(err)) + p.chunkSender.AddTask(IndexRecordChunk{Err: err}) + return + } + se := sess.NewSession(sessCtx) + defer p.sessPool.Put(sessCtx) + var ( + task *reorgBackfillTask + ok bool + ) + + for { + select { + case <-c.ctx.Done(): + return + case task, ok = <-p.tasksCh: + } + if !ok { + return + } + if p.checkpointMgr != nil && p.checkpointMgr.IsKeyProcessed(task.endKey) { + logutil.Logger(p.ctx).Info("checkpoint detected, skip a cop-request task", + zap.Int("task ID", task.id), + zap.String("task end key", hex.EncodeToString(task.endKey))) + continue + } + err := scanRecords(p, task, se) + if err != nil { + p.chunkSender.AddTask(IndexRecordChunk{ID: task.id, Err: err}) + return + } + } +} + +func scanRecords(p *copReqSenderPool, task *reorgBackfillTask, se *sess.Session) error { + logutil.Logger(p.ctx).Info("start a cop-request task", + zap.Int("id", task.id), zap.Stringer("task", task)) + + return wrapInBeginRollback(se, func(startTS uint64) error { + rs, err := buildTableScan(p.ctx, p.copCtx.GetBase(), startTS, task.startKey, task.endKey) + if err != nil { + return err + } + failpoint.Inject("mockCopSenderPanic", func(val failpoint.Value) { + if val.(bool) { + panic("mock panic") + } + }) + if p.checkpointMgr != nil { + p.checkpointMgr.Register(task.id, task.endKey) + } + var done bool + startTime := time.Now() + for !done { + srcChk := p.getChunk() + done, err = fetchTableScanResult(p.ctx, p.copCtx.GetBase(), rs, srcChk) + if err != nil { + p.recycleChunk(srcChk) + terror.Call(rs.Close) + return err + } + if p.checkpointMgr != nil { + p.checkpointMgr.UpdateTotalKeys(task.id, srcChk.NumRows(), done) + } + idxRs := IndexRecordChunk{ID: task.id, Chunk: srcChk, Done: done} + rate := float64(srcChk.MemoryUsage()) / 1024.0 / 1024.0 / time.Since(startTime).Seconds() + metrics.AddIndexScanRate.WithLabelValues(metrics.LblAddIndex).Observe(rate) + failpoint.Inject("mockCopSenderError", func() { + idxRs.Err = errors.New("mock cop error") + }) + p.chunkSender.AddTask(idxRs) + startTime = time.Now() + } + terror.Call(rs.Close) + return nil + }) +} + +func wrapInBeginRollback(se *sess.Session, f func(startTS uint64) error) error { + err := se.Begin(context.Background()) + if err != nil { + return errors.Trace(err) + } + defer se.Rollback() + var startTS uint64 + sessVars := se.GetSessionVars() + sessVars.TxnCtxMu.Lock() + startTS = sessVars.TxnCtx.StartTS + sessVars.TxnCtxMu.Unlock() + return f(startTS) +} + +func newCopReqSenderPool(ctx context.Context, copCtx copr.CopContext, store kv.Storage, + taskCh chan *reorgBackfillTask, sessPool *sess.Pool, + checkpointMgr *ingest.CheckpointManager) *copReqSenderPool { + poolSize := copReadChunkPoolSize() + 
srcChkPool := make(chan *chunk.Chunk, poolSize) + for i := 0; i < poolSize; i++ { + srcChkPool <- chunk.NewChunkWithCapacity(copCtx.GetBase().FieldTypes, copReadBatchSize()) + } + return &copReqSenderPool{ + tasksCh: taskCh, + ctx: ctx, + copCtx: copCtx, + store: store, + senders: make([]*copReqSender, 0, variable.GetDDLReorgWorkerCounter()), + wg: sync.WaitGroup{}, + srcChkPool: srcChkPool, + sessPool: sessPool, + checkpointMgr: checkpointMgr, + } +} + +func (c *copReqSenderPool) adjustSize(n int) { + // Add some senders. + for i := len(c.senders); i < n; i++ { + ctx, cancel := context.WithCancel(c.ctx) + c.senders = append(c.senders, &copReqSender{ + senderPool: c, + ctx: ctx, + cancel: cancel, + }) + c.wg.Add(1) + go c.senders[i].run() + } + // Remove some senders. + if n < len(c.senders) { + for i := n; i < len(c.senders); i++ { + c.senders[i].cancel() + } + c.senders = c.senders[:n] + } +} + +func (c *copReqSenderPool) close(force bool) { + if c.closed { + return + } + logutil.Logger(c.ctx).Info("close cop-request sender pool", zap.Bool("force", force)) + if force { + for _, w := range c.senders { + w.cancel() + } + } + // Wait for all cop-req senders to exit. + c.wg.Wait() + c.closed = true +} + +func (c *copReqSenderPool) getChunk() *chunk.Chunk { + chk := <-c.srcChkPool + newCap := copReadBatchSize() + if chk.Capacity() != newCap { + chk = chunk.NewChunkWithCapacity(c.copCtx.GetBase().FieldTypes, newCap) + } + chk.Reset() + return chk +} + +// recycleChunk puts the index record slice and the chunk back to the pool for reuse. +func (c *copReqSenderPool) recycleChunk(chk *chunk.Chunk) { + if chk == nil { + return + } + c.srcChkPool <- chk +} + +func buildTableScan(ctx context.Context, c *copr.CopContextBase, startTS uint64, start, end kv.Key) (distsql.SelectResult, error) { + dagPB, err := buildDAGPB(c.ExprCtx, c.DistSQLCtx, c.PushDownFlags, c.TableInfo, c.ColumnInfos) + if err != nil { + return nil, err + } + + var builder distsql.RequestBuilder + kvReq, err := builder. + SetDAGRequest(dagPB). + SetStartTS(startTS). + SetKeyRanges([]kv.KeyRange{{StartKey: start, EndKey: end}}). + SetKeepOrder(true). + SetFromSessionVars(c.DistSQLCtx). + SetConcurrency(1). 
+ Build() + kvReq.RequestSource.RequestSourceInternal = true + kvReq.RequestSource.RequestSourceType = getDDLRequestSource(model.ActionAddIndex) + kvReq.RequestSource.ExplicitRequestSourceType = kvutil.ExplicitTypeDDL + if err != nil { + return nil, err + } + return distsql.Select(ctx, c.DistSQLCtx, kvReq, c.FieldTypes) +} + +func fetchTableScanResult( + ctx context.Context, + copCtx *copr.CopContextBase, + result distsql.SelectResult, + chk *chunk.Chunk, +) (bool, error) { + err := result.Next(ctx, chk) + if err != nil { + return false, errors.Trace(err) + } + if chk.NumRows() == 0 { + return true, nil + } + err = table.FillVirtualColumnValue( + copCtx.VirtualColumnsFieldTypes, copCtx.VirtualColumnsOutputOffsets, + copCtx.ExprColumnInfos, copCtx.ColumnInfos, copCtx.ExprCtx, chk) + return false, err +} + +func completeErr(err error, idxInfo *model.IndexInfo) error { + if expression.ErrInvalidJSONForFuncIndex.Equal(err) { + err = expression.ErrInvalidJSONForFuncIndex.GenWithStackByArgs(idxInfo.Name.O) + } + return errors.Trace(err) +} + +func getRestoreData(tblInfo *model.TableInfo, targetIdx, pkIdx *model.IndexInfo, handleDts []types.Datum) []types.Datum { + if !collate.NewCollationEnabled() || !tblInfo.IsCommonHandle || tblInfo.CommonHandleVersion == 0 { + return nil + } + if pkIdx == nil { + return nil + } + for i, pkIdxCol := range pkIdx.Columns { + pkCol := tblInfo.Columns[pkIdxCol.Offset] + if !types.NeedRestoredData(&pkCol.FieldType) { + // Since the handle data cannot be null, we can use SetNull to + // indicate that this column does not need to be restored. + handleDts[i].SetNull() + continue + } + tables.TryTruncateRestoredData(&handleDts[i], pkCol, pkIdxCol, targetIdx) + tables.ConvertDatumToTailSpaceCount(&handleDts[i], pkCol) + } + dtToRestored := handleDts[:0] + for _, handleDt := range handleDts { + if !handleDt.IsNull() { + dtToRestored = append(dtToRestored, handleDt) + } + } + return dtToRestored +} + +func buildDAGPB(exprCtx exprctx.BuildContext, distSQLCtx *distsqlctx.DistSQLContext, pushDownFlags uint64, tblInfo *model.TableInfo, colInfos []*model.ColumnInfo) (*tipb.DAGRequest, error) { + dagReq := &tipb.DAGRequest{} + dagReq.TimeZoneName, dagReq.TimeZoneOffset = timeutil.Zone(exprCtx.GetEvalCtx().Location()) + dagReq.Flags = pushDownFlags + for i := range colInfos { + dagReq.OutputOffsets = append(dagReq.OutputOffsets, uint32(i)) + } + execPB, err := constructTableScanPB(exprCtx, tblInfo, colInfos) + if err != nil { + return nil, err + } + dagReq.Executors = append(dagReq.Executors, execPB) + distsql.SetEncodeType(distSQLCtx, dagReq) + return dagReq, nil +} + +func constructTableScanPB(ctx exprctx.BuildContext, tblInfo *model.TableInfo, colInfos []*model.ColumnInfo) (*tipb.Executor, error) { + tblScan := tables.BuildTableScanFromInfos(tblInfo, colInfos) + tblScan.TableId = tblInfo.ID + err := tables.SetPBColumnsDefaultValue(ctx, tblScan.Columns, colInfos) + return &tipb.Executor{Tp: tipb.ExecType_TypeTableScan, TblScan: tblScan}, err +} + +func extractDatumByOffsets(ctx expression.EvalContext, row chunk.Row, offsets []int, expCols []*expression.Column, buf []types.Datum) []types.Datum { + for i, offset := range offsets { + c := expCols[offset] + row.DatumWithBuffer(offset, c.GetType(ctx), &buf[i]) + } + return buf +} + +func buildHandle(pkDts []types.Datum, tblInfo *model.TableInfo, + pkInfo *model.IndexInfo, loc *time.Location, errCtx errctx.Context) (kv.Handle, error) { + if tblInfo.IsCommonHandle { + tablecodec.TruncateIndexValues(tblInfo, pkInfo, pkDts) + 
handleBytes, err := codec.EncodeKey(loc, nil, pkDts...) + err = errCtx.HandleError(err) + if err != nil { + return nil, err + } + return kv.NewCommonHandle(handleBytes) + } + return kv.IntHandle(pkDts[0].GetInt64()), nil +} diff --git a/pkg/ddl/index_merge_tmp.go b/pkg/ddl/index_merge_tmp.go index c0001250b6c22..e6c3c1c61cd33 100644 --- a/pkg/ddl/index_merge_tmp.go +++ b/pkg/ddl/index_merge_tmp.go @@ -241,12 +241,12 @@ func (w *mergeIndexWorker) BackfillData(taskRange reorgBackfillTask) (taskCtx ba return nil }) - failpoint.Inject("mockDMLExecutionMerging", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockDMLExecutionMerging")); _err_ == nil { //nolint:forcetypeassert if val.(bool) && MockDMLExecutionMerging != nil { MockDMLExecutionMerging() } - }) + } logSlowOperations(time.Since(oprStartTime), "AddIndexMergeDataInTxn", 3000) return } diff --git a/pkg/ddl/index_merge_tmp.go__failpoint_stash__ b/pkg/ddl/index_merge_tmp.go__failpoint_stash__ new file mode 100644 index 0000000000000..c0001250b6c22 --- /dev/null +++ b/pkg/ddl/index_merge_tmp.go__failpoint_stash__ @@ -0,0 +1,400 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl + +import ( + "bytes" + "context" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/ddl/logutil" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/parser/model" + driver "github.com/pingcap/tidb/pkg/store/driver/txn" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/tablecodec" + kvutil "github.com/tikv/client-go/v2/util" + "go.uber.org/zap" +) + +func (w *mergeIndexWorker) batchCheckTemporaryUniqueKey( + txn kv.Transaction, + idxRecords []*temporaryIndexRecord, +) error { + if !w.currentIndex.Unique { + // non-unique key need no check, just overwrite it, + // because in most case, backfilling indices is not exists. + return nil + } + + batchVals, err := txn.BatchGet(context.Background(), w.originIdxKeys) + if err != nil { + return errors.Trace(err) + } + + for i, key := range w.originIdxKeys { + if val, found := batchVals[string(key)]; found { + // Found a value in the original index key. + err := checkTempIndexKey(txn, idxRecords[i], val, w.table) + if err != nil { + if kv.ErrKeyExists.Equal(err) { + return driver.ExtractKeyExistsErrFromIndex(key, val, w.table.Meta(), w.currentIndex.ID) + } + return errors.Trace(err) + } + } else if idxRecords[i].distinct { + // The keys in w.batchCheckKeys also maybe duplicate, + // so we need to backfill the not found key into `batchVals` map. 
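+			// A later record in this batch that produces the same distinct key is then
+			// compared against this value and reported as a duplicate if the values differ.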
+ batchVals[string(key)] = idxRecords[i].vals + } + } + return nil +} + +func checkTempIndexKey(txn kv.Transaction, tmpRec *temporaryIndexRecord, originIdxVal []byte, tblInfo table.Table) error { + if !tmpRec.delete { + if tmpRec.distinct && !bytes.Equal(originIdxVal, tmpRec.vals) { + return kv.ErrKeyExists + } + // The key has been found in the original index, skip merging it. + tmpRec.skip = true + return nil + } + // Delete operation. + distinct := tablecodec.IndexKVIsUnique(originIdxVal) + if !distinct { + // For non-distinct key, it is consist of a null value and the handle. + // Same as the non-unique indexes, replay the delete operation on non-distinct keys. + return nil + } + // For distinct index key values, prevent deleting an unexpected index KV in original index. + hdInVal, err := tablecodec.DecodeHandleInIndexValue(originIdxVal) + if err != nil { + return errors.Trace(err) + } + if !tmpRec.handle.Equal(hdInVal) { + // The inequality means multiple modifications happened in the same key. + // We use the handle in origin index value to check if the row exists. + rowKey := tablecodec.EncodeRecordKey(tblInfo.RecordPrefix(), hdInVal) + _, err := txn.Get(context.Background(), rowKey) + if err != nil { + if kv.IsErrNotFound(err) { + // The row is deleted, so we can merge the delete operation to the origin index. + tmpRec.skip = false + return nil + } + // Unexpected errors. + return errors.Trace(err) + } + // Don't delete the index key if the row exists. + tmpRec.skip = true + return nil + } + return nil +} + +// temporaryIndexRecord is the record information of an index. +type temporaryIndexRecord struct { + vals []byte + skip bool // skip indicates that the index key is already exists, we should not add it. + delete bool + unique bool + distinct bool + handle kv.Handle +} + +type mergeIndexWorker struct { + *backfillCtx + + indexes []table.Index + + tmpIdxRecords []*temporaryIndexRecord + originIdxKeys []kv.Key + tmpIdxKeys []kv.Key + + needValidateKey bool + currentTempIndexPrefix []byte + currentIndex *model.IndexInfo +} + +func newMergeTempIndexWorker(bfCtx *backfillCtx, t table.PhysicalTable, elements []*meta.Element) *mergeIndexWorker { + allIndexes := make([]table.Index, 0, len(elements)) + for _, elem := range elements { + indexInfo := model.FindIndexInfoByID(t.Meta().Indices, elem.ID) + index := tables.NewIndex(t.GetPhysicalID(), t.Meta(), indexInfo) + allIndexes = append(allIndexes, index) + } + + return &mergeIndexWorker{ + backfillCtx: bfCtx, + indexes: allIndexes, + } +} + +func (w *mergeIndexWorker) validateTaskRange(taskRange *reorgBackfillTask) (skip bool, err error) { + tmpID, err := tablecodec.DecodeIndexID(taskRange.startKey) + if err != nil { + return false, err + } + startIndexID := tmpID & tablecodec.IndexIDMask + tmpID, err = tablecodec.DecodeIndexID(taskRange.endKey) + if err != nil { + return false, err + } + endIndexID := tmpID & tablecodec.IndexIDMask + + w.needValidateKey = startIndexID != endIndexID + containsTargetID := false + for _, idx := range w.indexes { + idxInfo := idx.Meta() + if idxInfo.ID == startIndexID { + containsTargetID = true + w.currentIndex = idxInfo + break + } + if idxInfo.ID == endIndexID { + containsTargetID = true + } + } + return !containsTargetID, nil +} + +// BackfillData merge temp index data in txn. 
+func (w *mergeIndexWorker) BackfillData(taskRange reorgBackfillTask) (taskCtx backfillTaskContext, errInTxn error) { + skip, err := w.validateTaskRange(&taskRange) + if skip || err != nil { + return taskCtx, err + } + + oprStartTime := time.Now() + ctx := kv.WithInternalSourceAndTaskType(context.Background(), w.jobContext.ddlJobSourceType(), kvutil.ExplicitTypeDDL) + + errInTxn = kv.RunInNewTxn(ctx, w.ddlCtx.store, true, func(_ context.Context, txn kv.Transaction) error { + taskCtx.addedCount = 0 + taskCtx.scanCount = 0 + updateTxnEntrySizeLimitIfNeeded(txn) + txn.SetOption(kv.Priority, taskRange.priority) + if tagger := w.GetCtx().getResourceGroupTaggerForTopSQL(taskRange.getJobID()); tagger != nil { + txn.SetOption(kv.ResourceGroupTagger, tagger) + } + txn.SetOption(kv.ResourceGroupName, w.jobContext.resourceGroupName) + + tmpIdxRecords, nextKey, taskDone, err := w.fetchTempIndexVals(txn, taskRange) + if err != nil { + return errors.Trace(err) + } + taskCtx.nextKey = nextKey + taskCtx.done = taskDone + + err = w.batchCheckTemporaryUniqueKey(txn, tmpIdxRecords) + if err != nil { + return errors.Trace(err) + } + + for i, idxRecord := range tmpIdxRecords { + taskCtx.scanCount++ + // The index is already exists, we skip it, no needs to backfill it. + // The following update, delete, insert on these rows, TiDB can handle it correctly. + // If all batch are skipped, update first index key to make txn commit to release lock. + if idxRecord.skip { + continue + } + + // Lock the corresponding row keys so that it doesn't modify the index KVs + // that are changing by a pessimistic transaction. + rowKey := tablecodec.EncodeRecordKey(w.table.RecordPrefix(), idxRecord.handle) + err := txn.LockKeys(context.Background(), new(kv.LockCtx), rowKey) + if err != nil { + return errors.Trace(err) + } + + if idxRecord.delete { + if idxRecord.unique { + err = txn.GetMemBuffer().DeleteWithFlags(w.originIdxKeys[i], kv.SetNeedLocked) + } else { + err = txn.GetMemBuffer().Delete(w.originIdxKeys[i]) + } + } else { + err = txn.GetMemBuffer().Set(w.originIdxKeys[i], idxRecord.vals) + } + if err != nil { + return err + } + taskCtx.addedCount++ + } + return nil + }) + + failpoint.Inject("mockDMLExecutionMerging", func(val failpoint.Value) { + //nolint:forcetypeassert + if val.(bool) && MockDMLExecutionMerging != nil { + MockDMLExecutionMerging() + } + }) + logSlowOperations(time.Since(oprStartTime), "AddIndexMergeDataInTxn", 3000) + return +} + +func (*mergeIndexWorker) AddMetricInfo(float64) { +} + +func (*mergeIndexWorker) String() string { + return typeAddIndexMergeTmpWorker.String() +} + +func (w *mergeIndexWorker) GetCtx() *backfillCtx { + return w.backfillCtx +} + +func (w *mergeIndexWorker) prefixIsChanged(newKey kv.Key) bool { + return len(w.currentTempIndexPrefix) == 0 || !bytes.HasPrefix(newKey, w.currentTempIndexPrefix) +} + +func (w *mergeIndexWorker) updateCurrentIndexInfo(newIndexKey kv.Key) (skip bool, err error) { + tempIdxID, err := tablecodec.DecodeIndexID(newIndexKey) + if err != nil { + return false, err + } + idxID := tablecodec.IndexIDMask & tempIdxID + var curIdx *model.IndexInfo + for _, idx := range w.indexes { + if idx.Meta().ID == idxID { + curIdx = idx.Meta() + } + } + if curIdx == nil { + // Index IDs are always increasing, but not always continuous: + // if DDL adds another index between these indexes, it is possible that: + // multi-schema add index IDs = [1, 2, 4, 5] + // another index ID = [3] + // If the new index get rollback, temp index 0xFFxxx03 may have dirty records. 
+ // We should skip these dirty records. + return true, nil + } + pfx := tablecodec.CutIndexPrefix(newIndexKey) + + w.currentTempIndexPrefix = kv.Key(pfx).Clone() + w.currentIndex = curIdx + + return false, nil +} + +func (w *mergeIndexWorker) fetchTempIndexVals( + txn kv.Transaction, + taskRange reorgBackfillTask, +) ([]*temporaryIndexRecord, kv.Key, bool, error) { + startTime := time.Now() + w.tmpIdxRecords = w.tmpIdxRecords[:0] + w.tmpIdxKeys = w.tmpIdxKeys[:0] + w.originIdxKeys = w.originIdxKeys[:0] + // taskDone means that the merged handle is out of taskRange.endHandle. + taskDone := false + oprStartTime := startTime + idxPrefix := w.table.IndexPrefix() + var lastKey kv.Key + err := iterateSnapshotKeys(w.jobContext, w.ddlCtx.store, taskRange.priority, idxPrefix, txn.StartTS(), + taskRange.startKey, taskRange.endKey, func(_ kv.Handle, indexKey kv.Key, rawValue []byte) (more bool, err error) { + oprEndTime := time.Now() + logSlowOperations(oprEndTime.Sub(oprStartTime), "iterate temporary index in merge process", 0) + oprStartTime = oprEndTime + + taskDone = indexKey.Cmp(taskRange.endKey) >= 0 + + if taskDone || len(w.tmpIdxRecords) >= w.batchCnt { + return false, nil + } + + if w.needValidateKey && w.prefixIsChanged(indexKey) { + skip, err := w.updateCurrentIndexInfo(indexKey) + if err != nil || skip { + return skip, err + } + } + + tempIdxVal, err := tablecodec.DecodeTempIndexValue(rawValue) + if err != nil { + return false, err + } + tempIdxVal, err = decodeTempIndexHandleFromIndexKV(indexKey, tempIdxVal, len(w.currentIndex.Columns)) + if err != nil { + return false, err + } + + tempIdxVal = tempIdxVal.FilterOverwritten() + + // Extract the operations on the original index and replay them later. + for _, elem := range tempIdxVal { + if elem.KeyVer == tables.TempIndexKeyTypeMerge || elem.KeyVer == tables.TempIndexKeyTypeDelete { + // For 'm' version kvs, they are double-written. + // For 'd' version kvs, they are written in the delete-only state and can be dropped safely. + continue + } + + originIdxKey := make([]byte, len(indexKey)) + copy(originIdxKey, indexKey) + tablecodec.TempIndexKey2IndexKey(originIdxKey) + + idxRecord := &temporaryIndexRecord{ + handle: elem.Handle, + delete: elem.Delete, + unique: elem.Distinct, + skip: false, + } + if !elem.Delete { + idxRecord.vals = elem.Value + idxRecord.distinct = tablecodec.IndexKVIsUnique(elem.Value) + } + w.tmpIdxRecords = append(w.tmpIdxRecords, idxRecord) + w.originIdxKeys = append(w.originIdxKeys, originIdxKey) + w.tmpIdxKeys = append(w.tmpIdxKeys, indexKey) + } + + lastKey = indexKey + return true, nil + }) + + if len(w.tmpIdxRecords) == 0 { + taskDone = true + } + var nextKey kv.Key + if taskDone { + nextKey = taskRange.endKey + } else { + nextKey = lastKey + } + + logutil.DDLLogger().Debug("merge temp index txn fetches handle info", zap.Uint64("txnStartTS", txn.StartTS()), + zap.String("taskRange", taskRange.String()), zap.Duration("takeTime", time.Since(startTime))) + return w.tmpIdxRecords, nextKey.Next(), taskDone, errors.Trace(err) +} + +func decodeTempIndexHandleFromIndexKV(indexKey kv.Key, tmpVal tablecodec.TempIndexValue, idxColLen int) (ret tablecodec.TempIndexValue, err error) { + for _, elem := range tmpVal { + if elem.Handle == nil { + // If the handle is not found in the value of the temp index, it means + // 1) This is not a deletion marker, the handle is in the key or the origin value. + // 2) This is a deletion marker, but the handle is in the key of temp index. 
+ elem.Handle, err = tablecodec.DecodeIndexHandle(indexKey, elem.Value, idxColLen) + if err != nil { + return nil, err + } + } + } + return tmpVal, nil +} diff --git a/pkg/ddl/ingest/backend.go b/pkg/ddl/ingest/backend.go index 6a8a4e0e6666a..802a5424ad9d2 100644 --- a/pkg/ddl/ingest/backend.go +++ b/pkg/ddl/ingest/backend.go @@ -215,11 +215,11 @@ func (bc *litBackendCtx) Flush(mode FlushMode) (flushed, imported bool, err erro } }() } - failpoint.Inject("mockDMLExecutionStateBeforeImport", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("mockDMLExecutionStateBeforeImport")); _err_ == nil { if MockDMLExecutionStateBeforeImport != nil { MockDMLExecutionStateBeforeImport() } - }) + } for indexID, ei := range bc.engines { if err = bc.unsafeImportAndReset(ei); err != nil { @@ -286,9 +286,9 @@ func (bc *litBackendCtx) unsafeImportAndReset(ei *engineInfo) error { } err := resetFn(bc.ctx, ei.uuid) - failpoint.Inject("mockResetEngineFailed", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockResetEngineFailed")); _err_ == nil { err = fmt.Errorf("mock reset engine failed") - }) + } if err != nil { logutil.Logger(bc.ctx).Error(LitErrResetEngineFail, zap.Int64("index ID", ei.indexID)) err1 := closedEngine.Cleanup(bc.ctx) @@ -306,10 +306,10 @@ func (bc *litBackendCtx) unsafeImportAndReset(ei *engineInfo) error { var ForceSyncFlagForTest = false func (bc *litBackendCtx) checkFlush(mode FlushMode) (shouldFlush bool, shouldImport bool) { - failpoint.Inject("forceSyncFlagForTest", func() { + if _, _err_ := failpoint.Eval(_curpkg_("forceSyncFlagForTest")); _err_ == nil { // used in a manual test ForceSyncFlagForTest = true - }) + } if mode == FlushModeForceFlushAndImport || ForceSyncFlagForTest { return true, true } @@ -317,11 +317,11 @@ func (bc *litBackendCtx) checkFlush(mode FlushMode) (shouldFlush bool, shouldImp shouldImport = bc.diskRoot.ShouldImport() interval := bc.updateInterval // This failpoint will be manually set through HTTP status port. - failpoint.Inject("mockSyncIntervalMs", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockSyncIntervalMs")); _err_ == nil { if v, ok := val.(int); ok { interval = time.Duration(v) * time.Millisecond } - }) + } shouldFlush = shouldImport || time.Since(bc.timeOfLastFlush.Load()) >= interval return shouldFlush, shouldImport diff --git a/pkg/ddl/ingest/backend.go__failpoint_stash__ b/pkg/ddl/ingest/backend.go__failpoint_stash__ new file mode 100644 index 0000000000000..6a8a4e0e6666a --- /dev/null +++ b/pkg/ddl/ingest/backend.go__failpoint_stash__ @@ -0,0 +1,343 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
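The hunks above all apply the same mechanical rewrite: each failpoint.Inject call with a callback becomes an explicit failpoint.Eval check, guarded by the returned error, with the failpoint name qualified by the package import path through the generated _curpkg_ helper. A minimal standalone sketch of the two equivalent forms, using a hypothetical failpoint name and package path (nothing below is TiDB code):

package demo

import (
	"time"

	"github.com/pingcap/failpoint"
)

// _curpkg_ mimics the generated helper: it prefixes a failpoint name with the
// package import path so names are unique across the whole binary.
func _curpkg_(name string) string {
	return "example.com/x/demo/" + name
}

// intervalWithInject is the marker style: the closure runs only while the
// failpoint "mockInterval" is enabled, and val carries the injected term.
func intervalWithInject(base time.Duration) time.Duration {
	failpoint.Inject("mockInterval", func(val failpoint.Value) {
		if v, ok := val.(int); ok {
			base = time.Duration(v) * time.Millisecond
		}
	})
	return base
}

// intervalWithEval is the rewritten style seen in the hunks: Eval returns a
// nil error only while the failpoint is active, so the body is skipped otherwise.
func intervalWithEval(base time.Duration) time.Duration {
	if val, _err_ := failpoint.Eval(_curpkg_("mockInterval")); _err_ == nil {
		if v, ok := val.(int); ok {
			base = time.Duration(v) * time.Millisecond
		}
	}
	return base
}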
+ +package ingest + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + tikv "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/lightning/backend" + "github.com/pingcap/tidb/pkg/lightning/backend/encode" + "github.com/pingcap/tidb/pkg/lightning/backend/local" + "github.com/pingcap/tidb/pkg/lightning/common" + lightning "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/lightning/errormanager" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/util/logutil" + clientv3 "go.etcd.io/etcd/client/v3" + "go.etcd.io/etcd/client/v3/concurrency" + atomicutil "go.uber.org/atomic" + "go.uber.org/zap" +) + +// MockDMLExecutionStateBeforeImport is a failpoint to mock the DML execution state before import. +var MockDMLExecutionStateBeforeImport func() + +// BackendCtx is the backend context for one add index reorg task. +type BackendCtx interface { + // Register create a new engineInfo for each index ID and register it to the + // backend context. If the index ID is already registered, it will return the + // associated engines. Only one group of index ID is allowed to register for a + // BackendCtx. + // + // Register is only used in local disk based ingest. + Register(indexIDs []int64, uniques []bool, tbl table.Table) ([]Engine, error) + // FinishAndUnregisterEngines finishes the task and unregisters all engines that + // are Register-ed before. It's safe to call it multiple times. + // + // FinishAndUnregisterEngines is only used in local disk based ingest. + FinishAndUnregisterEngines(opt UnregisterOpt) error + + FlushController + + AttachCheckpointManager(*CheckpointManager) + GetCheckpointManager() *CheckpointManager + + // GetLocalBackend exposes local.Backend. It's only used in global sort based + // ingest. + GetLocalBackend() *local.Backend + // CollectRemoteDuplicateRows collects duplicate entry error for given index as + // the supplement of FlushController.Flush. + // + // CollectRemoteDuplicateRows is only used in global sort based ingest. + CollectRemoteDuplicateRows(indexID int64, tbl table.Table) error +} + +// FlushMode is used to control how to flush. +type FlushMode byte + +const ( + // FlushModeAuto means caller does not enforce any flush, the implementation can + // decide it. + FlushModeAuto FlushMode = iota + // FlushModeForceFlushAndImport means flush and import all data to TiKV. + FlushModeForceFlushAndImport +) + +// litBackendCtx implements BackendCtx. +type litBackendCtx struct { + engines map[int64]*engineInfo + memRoot MemRoot + diskRoot DiskRoot + jobID int64 + tbl table.Table + backend *local.Backend + ctx context.Context + cfg *lightning.Config + sysVars map[string]string + + flushing atomic.Bool + timeOfLastFlush atomicutil.Time + updateInterval time.Duration + checkpointMgr *CheckpointManager + etcdClient *clientv3.Client + + // unregisterMu prevents concurrent calls of `FinishAndUnregisterEngines`. + // For details, see https://github.com/pingcap/tidb/issues/53843. 
+ unregisterMu sync.Mutex +} + +func (bc *litBackendCtx) handleErrorAfterCollectRemoteDuplicateRows( + err error, + indexID int64, + tbl table.Table, + hasDupe bool, +) error { + if err != nil && !common.ErrFoundIndexConflictRecords.Equal(err) { + logutil.Logger(bc.ctx).Error(LitInfoRemoteDupCheck, zap.Error(err), + zap.String("table", tbl.Meta().Name.O), zap.Int64("index ID", indexID)) + return errors.Trace(err) + } else if hasDupe { + logutil.Logger(bc.ctx).Error(LitErrRemoteDupExistErr, + zap.String("table", tbl.Meta().Name.O), zap.Int64("index ID", indexID)) + + if common.ErrFoundIndexConflictRecords.Equal(err) { + tErr, ok := errors.Cause(err).(*terror.Error) + if !ok { + return errors.Trace(tikv.ErrKeyExists) + } + if len(tErr.Args()) != 4 { + return errors.Trace(tikv.ErrKeyExists) + } + //nolint: forcetypeassert + indexName := tErr.Args()[1].(string) + //nolint: forcetypeassert + keyCols := tErr.Args()[2].([]string) + return errors.Trace(tikv.GenKeyExistsErr(keyCols, indexName)) + } + return errors.Trace(tikv.ErrKeyExists) + } + return nil +} + +// CollectRemoteDuplicateRows collects duplicate rows from remote TiKV. +func (bc *litBackendCtx) CollectRemoteDuplicateRows(indexID int64, tbl table.Table) error { + return bc.collectRemoteDuplicateRows(indexID, tbl) +} + +func (bc *litBackendCtx) collectRemoteDuplicateRows(indexID int64, tbl table.Table) error { + errorMgr := errormanager.New(nil, bc.cfg, log.Logger{Logger: logutil.Logger(bc.ctx)}) + dupeController := bc.backend.GetDupeController(bc.cfg.TikvImporter.RangeConcurrency*2, errorMgr) + hasDupe, err := dupeController.CollectRemoteDuplicateRows(bc.ctx, tbl, tbl.Meta().Name.L, &encode.SessionOptions{ + SQLMode: mysql.ModeStrictAllTables, + SysVars: bc.sysVars, + IndexID: indexID, + }, lightning.ErrorOnDup) + return bc.handleErrorAfterCollectRemoteDuplicateRows(err, indexID, tbl, hasDupe) +} + +func acquireLock(ctx context.Context, se *concurrency.Session, key string) (*concurrency.Mutex, error) { + mu := concurrency.NewMutex(se, key) + err := mu.Lock(ctx) + if err != nil { + return nil, err + } + return mu, nil +} + +// Flush implements FlushController. +func (bc *litBackendCtx) Flush(mode FlushMode) (flushed, imported bool, err error) { + shouldFlush, shouldImport := bc.checkFlush(mode) + if !shouldFlush { + return false, false, nil + } + if !bc.flushing.CompareAndSwap(false, true) { + return false, false, nil + } + defer bc.flushing.Store(false) + + for _, ei := range bc.engines { + ei.flushLock.Lock() + //nolint: all_revive,revive + defer ei.flushLock.Unlock() + + if err = ei.Flush(); err != nil { + return false, false, err + } + } + bc.timeOfLastFlush.Store(time.Now()) + + if !shouldImport { + return true, false, nil + } + + // Use distributed lock if run in distributed mode). 
+ if bc.etcdClient != nil { + distLockKey := fmt.Sprintf("/tidb/distributeLock/%d", bc.jobID) + se, _ := concurrency.NewSession(bc.etcdClient) + mu, err := acquireLock(bc.ctx, se, distLockKey) + if err != nil { + return true, false, errors.Trace(err) + } + logutil.Logger(bc.ctx).Info("acquire distributed flush lock success", zap.Int64("jobID", bc.jobID)) + defer func() { + err = mu.Unlock(bc.ctx) + if err != nil { + logutil.Logger(bc.ctx).Warn("release distributed flush lock error", zap.Error(err), zap.Int64("jobID", bc.jobID)) + } else { + logutil.Logger(bc.ctx).Info("release distributed flush lock success", zap.Int64("jobID", bc.jobID)) + } + err = se.Close() + if err != nil { + logutil.Logger(bc.ctx).Warn("close session error", zap.Error(err)) + } + }() + } + failpoint.Inject("mockDMLExecutionStateBeforeImport", func(_ failpoint.Value) { + if MockDMLExecutionStateBeforeImport != nil { + MockDMLExecutionStateBeforeImport() + } + }) + + for indexID, ei := range bc.engines { + if err = bc.unsafeImportAndReset(ei); err != nil { + if common.ErrFoundDuplicateKeys.Equal(err) { + idxInfo := model.FindIndexInfoByID(bc.tbl.Meta().Indices, indexID) + if idxInfo == nil { + logutil.Logger(bc.ctx).Error( + "index not found", + zap.Int64("indexID", indexID)) + err = tikv.ErrKeyExists + } else { + err = TryConvertToKeyExistsErr(err, idxInfo, bc.tbl.Meta()) + } + } + return true, false, err + } + } + + var newTS uint64 + if mgr := bc.GetCheckpointManager(); mgr != nil { + // for local disk case, we need to refresh TS because duplicate detection + // requires each ingest to have a unique TS. + // + // TODO(lance6716): there's still a chance that data is imported but because of + // checkpoint is low-watermark, the data will still be imported again with + // another TS after failover. Need to refine the checkpoint mechanism. + newTS, err = mgr.refreshTSAndUpdateCP() + if err == nil { + for _, ei := range bc.engines { + ei.openedEngine.SetTS(newTS) + } + } + } + + return true, true, err +} + +func (bc *litBackendCtx) unsafeImportAndReset(ei *engineInfo) error { + logger := log.FromContext(bc.ctx).With( + zap.Stringer("engineUUID", ei.uuid), + ) + logger.Info(LitInfoUnsafeImport, + zap.Int64("index ID", ei.indexID), + zap.String("usage info", bc.diskRoot.UsageInfo())) + + closedEngine := backend.NewClosedEngine(bc.backend, logger, ei.uuid, 0) + + regionSplitSize := int64(lightning.SplitRegionSize) * int64(lightning.MaxSplitRegionSizeRatio) + regionSplitKeys := int64(lightning.SplitRegionKeys) + if err := closedEngine.Import(bc.ctx, regionSplitSize, regionSplitKeys); err != nil { + logutil.Logger(bc.ctx).Error(LitErrIngestDataErr, zap.Int64("index ID", ei.indexID), + zap.String("usage info", bc.diskRoot.UsageInfo())) + return err + } + + resetFn := bc.backend.ResetEngineSkipAllocTS + mgr := bc.GetCheckpointManager() + if mgr == nil { + // disttask case, no need to refresh TS. + // + // TODO(lance6716): for disttask local sort case, we need to use a fixed TS. But + // it doesn't have checkpoint, so we need to find a way to save TS. 
+ resetFn = bc.backend.ResetEngine + } + + err := resetFn(bc.ctx, ei.uuid) + failpoint.Inject("mockResetEngineFailed", func() { + err = fmt.Errorf("mock reset engine failed") + }) + if err != nil { + logutil.Logger(bc.ctx).Error(LitErrResetEngineFail, zap.Int64("index ID", ei.indexID)) + err1 := closedEngine.Cleanup(bc.ctx) + if err1 != nil { + logutil.Logger(ei.ctx).Error(LitErrCleanEngineErr, zap.Error(err1), + zap.Int64("job ID", ei.jobID), zap.Int64("index ID", ei.indexID)) + } + ei.openedEngine = nil + return err + } + return nil +} + +// ForceSyncFlagForTest is a flag to force sync only for test. +var ForceSyncFlagForTest = false + +func (bc *litBackendCtx) checkFlush(mode FlushMode) (shouldFlush bool, shouldImport bool) { + failpoint.Inject("forceSyncFlagForTest", func() { + // used in a manual test + ForceSyncFlagForTest = true + }) + if mode == FlushModeForceFlushAndImport || ForceSyncFlagForTest { + return true, true + } + bc.diskRoot.UpdateUsage() + shouldImport = bc.diskRoot.ShouldImport() + interval := bc.updateInterval + // This failpoint will be manually set through HTTP status port. + failpoint.Inject("mockSyncIntervalMs", func(val failpoint.Value) { + if v, ok := val.(int); ok { + interval = time.Duration(v) * time.Millisecond + } + }) + shouldFlush = shouldImport || + time.Since(bc.timeOfLastFlush.Load()) >= interval + return shouldFlush, shouldImport +} + +// AttachCheckpointManager attaches a checkpoint manager to the backend context. +func (bc *litBackendCtx) AttachCheckpointManager(mgr *CheckpointManager) { + bc.checkpointMgr = mgr +} + +// GetCheckpointManager returns the checkpoint manager attached to the backend context. +func (bc *litBackendCtx) GetCheckpointManager() *CheckpointManager { + return bc.checkpointMgr +} + +// GetLocalBackend returns the local backend. +func (bc *litBackendCtx) GetLocalBackend() *local.Backend { + return bc.backend +} diff --git a/pkg/ddl/ingest/backend_mgr.go b/pkg/ddl/ingest/backend_mgr.go index 068047e5a8710..031a5e0da6886 100644 --- a/pkg/ddl/ingest/backend_mgr.go +++ b/pkg/ddl/ingest/backend_mgr.go @@ -136,9 +136,9 @@ func (m *litBackendCtxMgr) Register( logutil.Logger(ctx).Warn(LitWarnConfigError, zap.Int64("job ID", jobID), zap.Error(err)) return nil, err } - failpoint.Inject("beforeCreateLocalBackend", func() { + if _, _err_ := failpoint.Eval(_curpkg_("beforeCreateLocalBackend")); _err_ == nil { ResignOwnerForTest.Store(true) - }) + } // lock backends because createLocalBackend will let lightning create the sort // folder, which may cause cleanupSortPath wrongly delete the sort folder if only // checking the existence of the entry in backends. diff --git a/pkg/ddl/ingest/backend_mgr.go__failpoint_stash__ b/pkg/ddl/ingest/backend_mgr.go__failpoint_stash__ new file mode 100644 index 0000000000000..068047e5a8710 --- /dev/null +++ b/pkg/ddl/ingest/backend_mgr.go__failpoint_stash__ @@ -0,0 +1,285 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
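For context on the Flush path above: when an etcd client is present, the import phase is serialized across TiDB instances with an etcd mutex keyed by the job ID. A minimal sketch of that locking pattern, with error handling trimmed and the helper name invented for illustration:

package demo

import (
	"context"
	"fmt"

	clientv3 "go.etcd.io/etcd/client/v3"
	"go.etcd.io/etcd/client/v3/concurrency"
)

// withFlushLock runs fn while holding a distributed lock on the same style of
// key used in Flush (/tidb/distributeLock/<jobID>). The session is closed and
// the mutex released on the way out, best-effort, as in the original.
func withFlushLock(ctx context.Context, cli *clientv3.Client, jobID int64, fn func() error) error {
	se, err := concurrency.NewSession(cli)
	if err != nil {
		return err
	}
	defer se.Close()

	mu := concurrency.NewMutex(se, fmt.Sprintf("/tidb/distributeLock/%d", jobID))
	if err := mu.Lock(ctx); err != nil {
		return err
	}
	defer func() {
		_ = mu.Unlock(ctx) // best-effort unlock, matching the original's warning-only handling
	}()

	return fn()
}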
+ +package ingest + +import ( + "context" + "math" + "os" + "path/filepath" + "strconv" + "sync" + "time" + + "github.com/pingcap/failpoint" + ddllogutil "github.com/pingcap/tidb/pkg/ddl/logutil" + "github.com/pingcap/tidb/pkg/lightning/backend/local" + "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/util/logutil" + kvutil "github.com/tikv/client-go/v2/util" + pd "github.com/tikv/pd/client" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/atomic" + "go.uber.org/zap" +) + +// BackendCtxMgr is used to manage the BackendCtx. +type BackendCtxMgr interface { + // CheckMoreTasksAvailable checks if it can run more ingest backfill tasks. + CheckMoreTasksAvailable() (bool, error) + // Register uses jobID to identify the BackendCtx. If there's already a + // BackendCtx with the same jobID, it will be returned. Otherwise, a new + // BackendCtx will be created and returned. + Register( + ctx context.Context, + jobID int64, + hasUnique bool, + etcdClient *clientv3.Client, + pdSvcDiscovery pd.ServiceDiscovery, + resourceGroupName string, + ) (BackendCtx, error) + Unregister(jobID int64) + // EncodeJobSortPath encodes the job ID to the local disk sort path. + EncodeJobSortPath(jobID int64) string + // Load returns the registered BackendCtx with the given jobID. + Load(jobID int64) (BackendCtx, bool) +} + +// litBackendCtxMgr manages multiple litBackendCtx for each DDL job. Each +// litBackendCtx can use some local disk space and memory resource which are +// controlled by litBackendCtxMgr. +type litBackendCtxMgr struct { + // the lifetime of entries in backends should cover all other resources so it can + // be used as a lightweight indicator when interacts with other resources. + // Currently, the entry must be created not after disk folder is created and + // memory usage is tracked, and vice versa when considering deletion. + backends struct { + mu sync.RWMutex + m map[int64]*litBackendCtx + } + // all disk resources of litBackendCtx should be used under path. Currently the + // hierarchy is ${path}/${jobID} for each litBackendCtx. + path string + memRoot MemRoot + diskRoot DiskRoot +} + +// NewLitBackendCtxMgr creates a new litBackendCtxMgr. +func NewLitBackendCtxMgr(path string, memQuota uint64) BackendCtxMgr { + mgr := &litBackendCtxMgr{ + path: path, + } + mgr.backends.m = make(map[int64]*litBackendCtx, 4) + mgr.memRoot = NewMemRootImpl(int64(memQuota), mgr) + mgr.diskRoot = NewDiskRootImpl(path, mgr) + LitMemRoot = mgr.memRoot + litDiskRoot = mgr.diskRoot + litDiskRoot.UpdateUsage() + err := litDiskRoot.StartupCheck() + if err != nil { + ddllogutil.DDLIngestLogger().Warn("ingest backfill may not be available", zap.Error(err)) + } + return mgr +} + +// CheckMoreTasksAvailable implements BackendCtxMgr.CheckMoreTaskAvailable interface. +func (m *litBackendCtxMgr) CheckMoreTasksAvailable() (bool, error) { + if err := m.diskRoot.PreCheckUsage(); err != nil { + ddllogutil.DDLIngestLogger().Info("ingest backfill is not available", zap.Error(err)) + return false, err + } + return true, nil +} + +// ResignOwnerForTest is only used for test. +var ResignOwnerForTest = atomic.NewBool(false) + +// Register creates a new backend and registers it to the backend context. 
+func (m *litBackendCtxMgr) Register( + ctx context.Context, + jobID int64, + hasUnique bool, + etcdClient *clientv3.Client, + pdSvcDiscovery pd.ServiceDiscovery, + resourceGroupName string, +) (BackendCtx, error) { + bc, exist := m.Load(jobID) + if exist { + return bc, nil + } + + m.memRoot.RefreshConsumption() + ok := m.memRoot.CheckConsume(structSizeBackendCtx) + if !ok { + return nil, genBackendAllocMemFailedErr(ctx, m.memRoot, jobID) + } + sortPath := m.EncodeJobSortPath(jobID) + err := os.MkdirAll(sortPath, 0700) + if err != nil { + logutil.Logger(ctx).Error(LitErrCreateDirFail, zap.Error(err)) + return nil, err + } + cfg, err := genConfig(ctx, sortPath, m.memRoot, hasUnique, resourceGroupName) + if err != nil { + logutil.Logger(ctx).Warn(LitWarnConfigError, zap.Int64("job ID", jobID), zap.Error(err)) + return nil, err + } + failpoint.Inject("beforeCreateLocalBackend", func() { + ResignOwnerForTest.Store(true) + }) + // lock backends because createLocalBackend will let lightning create the sort + // folder, which may cause cleanupSortPath wrongly delete the sort folder if only + // checking the existence of the entry in backends. + m.backends.mu.Lock() + bd, err := createLocalBackend(ctx, cfg, pdSvcDiscovery) + if err != nil { + m.backends.mu.Unlock() + logutil.Logger(ctx).Error(LitErrCreateBackendFail, zap.Int64("job ID", jobID), zap.Error(err)) + return nil, err + } + + bcCtx := newBackendContext(ctx, jobID, bd, cfg.lightning, defaultImportantVariables, m.memRoot, m.diskRoot, etcdClient) + m.backends.m[jobID] = bcCtx + m.memRoot.Consume(structSizeBackendCtx) + m.backends.mu.Unlock() + + logutil.Logger(ctx).Info(LitInfoCreateBackend, zap.Int64("job ID", jobID), + zap.Int64("current memory usage", m.memRoot.CurrentUsage()), + zap.Int64("max memory quota", m.memRoot.MaxMemoryQuota()), + zap.Bool("has unique index", hasUnique)) + return bcCtx, nil +} + +// EncodeJobSortPath implements BackendCtxMgr. +func (m *litBackendCtxMgr) EncodeJobSortPath(jobID int64) string { + return filepath.Join(m.path, encodeBackendTag(jobID)) +} + +func createLocalBackend( + ctx context.Context, + cfg *litConfig, + pdSvcDiscovery pd.ServiceDiscovery, +) (*local.Backend, error) { + tls, err := cfg.lightning.ToTLS() + if err != nil { + logutil.Logger(ctx).Error(LitErrCreateBackendFail, zap.Error(err)) + return nil, err + } + + ddllogutil.DDLIngestLogger().Info("create local backend for adding index", + zap.String("sortDir", cfg.lightning.TikvImporter.SortedKVDir), + zap.String("keyspaceName", cfg.keyspaceName)) + // We disable the switch TiKV mode feature for now, + // because the impact is not fully tested. 
+ var raftKV2SwitchModeDuration time.Duration + backendConfig := local.NewBackendConfig(cfg.lightning, int(litRLimit), cfg.keyspaceName, cfg.resourceGroup, kvutil.ExplicitTypeDDL, raftKV2SwitchModeDuration) + return local.NewBackend(ctx, tls, backendConfig, pdSvcDiscovery) +} + +const checkpointUpdateInterval = 10 * time.Minute + +func newBackendContext( + ctx context.Context, + jobID int64, + be *local.Backend, + cfg *config.Config, + vars map[string]string, + memRoot MemRoot, + diskRoot DiskRoot, + etcdClient *clientv3.Client, +) *litBackendCtx { + bCtx := &litBackendCtx{ + engines: make(map[int64]*engineInfo, 10), + memRoot: memRoot, + diskRoot: diskRoot, + jobID: jobID, + backend: be, + ctx: ctx, + cfg: cfg, + sysVars: vars, + updateInterval: checkpointUpdateInterval, + etcdClient: etcdClient, + } + bCtx.timeOfLastFlush.Store(time.Now()) + return bCtx +} + +// Unregister removes a backend context from the backend context manager. +func (m *litBackendCtxMgr) Unregister(jobID int64) { + m.backends.mu.RLock() + _, exist := m.backends.m[jobID] + m.backends.mu.RUnlock() + if !exist { + return + } + + m.backends.mu.Lock() + defer m.backends.mu.Unlock() + bc, exist := m.backends.m[jobID] + if !exist { + return + } + _ = bc.FinishAndUnregisterEngines(OptCloseEngines) + bc.backend.Close() + m.memRoot.Release(structSizeBackendCtx) + m.memRoot.ReleaseWithTag(encodeBackendTag(jobID)) + logutil.Logger(bc.ctx).Info(LitInfoCloseBackend, zap.Int64("job ID", jobID), + zap.Int64("current memory usage", m.memRoot.CurrentUsage()), + zap.Int64("max memory quota", m.memRoot.MaxMemoryQuota())) + delete(m.backends.m, jobID) +} + +func (m *litBackendCtxMgr) Load(jobID int64) (BackendCtx, bool) { + m.backends.mu.RLock() + defer m.backends.mu.RUnlock() + ret, ok := m.backends.m[jobID] + return ret, ok +} + +// TotalDiskUsage returns the total disk usage of all backends. +func (m *litBackendCtxMgr) TotalDiskUsage() uint64 { + var totalDiskUsed uint64 + m.backends.mu.RLock() + defer m.backends.mu.RUnlock() + + for _, bc := range m.backends.m { + _, _, bcDiskUsed, _ := local.CheckDiskQuota(bc.backend, math.MaxInt64) + totalDiskUsed += uint64(bcDiskUsed) + } + return totalDiskUsed +} + +// UpdateMemoryUsage collects the memory usages from all the backend and updates it to the memRoot. +func (m *litBackendCtxMgr) UpdateMemoryUsage() { + m.backends.mu.RLock() + defer m.backends.mu.RUnlock() + + for _, bc := range m.backends.m { + curSize := bc.backend.TotalMemoryConsume() + m.memRoot.ReleaseWithTag(encodeBackendTag(bc.jobID)) + m.memRoot.ConsumeWithTag(encodeBackendTag(bc.jobID), curSize) + } +} + +// encodeBackendTag encodes the job ID to backend tag. +// The backend tag is also used as the file name of the local index data files. +func encodeBackendTag(jobID int64) string { + return strconv.FormatInt(jobID, 10) +} + +// decodeBackendTag decodes the backend tag to job ID. 
+func decodeBackendTag(name string) (int64, error) { + return strconv.ParseInt(name, 10, 64) +} diff --git a/pkg/ddl/ingest/binding__failpoint_binding__.go b/pkg/ddl/ingest/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..207ba83505336 --- /dev/null +++ b/pkg/ddl/ingest/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package ingest + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/ddl/ingest/checkpoint.go b/pkg/ddl/ingest/checkpoint.go index 2d1536905736c..8560f616abe3d 100644 --- a/pkg/ddl/ingest/checkpoint.go +++ b/pkg/ddl/ingest/checkpoint.go @@ -221,14 +221,14 @@ func (s *CheckpointManager) AdvanceWatermark(flushed, imported bool) { return } - failpoint.Inject("resignAfterFlush", func() { + if _, _err_ := failpoint.Eval(_curpkg_("resignAfterFlush")); _err_ == nil { // used in a manual test ResignOwnerForTest.Store(true) // wait until ResignOwnerForTest is processed for ResignOwnerForTest.Load() { time.Sleep(100 * time.Millisecond) } - }) + } s.mu.Lock() defer s.mu.Unlock() @@ -445,10 +445,10 @@ func (s *CheckpointManager) updateCheckpointImpl() error { } func (s *CheckpointManager) updateCheckpointLoop() { - failpoint.Inject("checkpointLoopExit", func() { + if _, _err_ := failpoint.Eval(_curpkg_("checkpointLoopExit")); _err_ == nil { // used in a manual test - failpoint.Return() - }) + return + } ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() for { @@ -477,10 +477,10 @@ func (s *CheckpointManager) updateCheckpointLoop() { } func (s *CheckpointManager) updateCheckpoint() error { - failpoint.Inject("checkpointLoopExit", func() { + if _, _err_ := failpoint.Eval(_curpkg_("checkpointLoopExit")); _err_ == nil { // used in a manual test - failpoint.Return(errors.New("failpoint triggered so can't update checkpoint")) - }) + return errors.New("failpoint triggered so can't update checkpoint") + } finishCh := make(chan struct{}) select { case s.updaterCh <- finishCh: diff --git a/pkg/ddl/ingest/checkpoint.go__failpoint_stash__ b/pkg/ddl/ingest/checkpoint.go__failpoint_stash__ new file mode 100644 index 0000000000000..2d1536905736c --- /dev/null +++ b/pkg/ddl/ingest/checkpoint.go__failpoint_stash__ @@ -0,0 +1,509 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
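The failpoints referenced in these hunks are addressed by their fully qualified name, that is, the package path prepended by the generated _curpkg_ helper shown above. A short illustrative helper for toggling one of them from a test; the constant below assumes the package path github.com/pingcap/tidb/pkg/ddl/ingest and should be treated as an example rather than a guaranteed identifier:

package demo

import "github.com/pingcap/failpoint"

const forceSyncFP = "github.com/pingcap/tidb/pkg/ddl/ingest/forceSyncFlagForTest"

// enableForceSync activates the failpoint so checkFlush always chooses to
// flush and import; "return(true)" is standard failpoint term syntax.
func enableForceSync() error {
	return failpoint.Enable(forceSyncFP, "return(true)")
}

// disableForceSync switches it back off.
func disableForceSync() error {
	return failpoint.Disable(forceSyncFP)
}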
+ +package ingest + +import ( + "context" + "encoding/hex" + "encoding/json" + "fmt" + "net" + "strconv" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl/logutil" + sess "github.com/pingcap/tidb/pkg/ddl/session" + "github.com/pingcap/tidb/pkg/ddl/util" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/tikv/client-go/v2/oracle" + pd "github.com/tikv/pd/client" + "go.uber.org/zap" +) + +// CheckpointManager is a checkpoint manager implementation that used by +// non-distributed reorganization. It manages the data as two-level checkpoints: +// "flush"ed to local storage and "import"ed to TiKV. The checkpoint is saved in +// a table in the TiDB cluster. +type CheckpointManager struct { + ctx context.Context + cancel context.CancelFunc + sessPool *sess.Pool + jobID int64 + indexIDs []int64 + localStoreDir string + pdCli pd.Client + logger *zap.Logger + physicalID int64 + + // Derived and unchanged after the initialization. + instanceAddr string + localDataIsValid bool + + // Live in memory. + mu sync.Mutex + checkpoints map[int]*taskCheckpoint // task ID -> checkpoint + // we require each task ID to be continuous and start from 1. + minTaskIDFinished int + dirty bool + + // Persisted to the storage. + flushedKeyLowWatermark kv.Key + importedKeyLowWatermark kv.Key + flushedKeyCnt int + importedKeyCnt int + + ts uint64 + + // For persisting the checkpoint periodically. + updaterWg sync.WaitGroup + updaterCh chan chan struct{} +} + +// taskCheckpoint is the checkpoint for a single task. +type taskCheckpoint struct { + totalKeys int + writtenKeys int + checksum int64 + endKey kv.Key + lastBatchRead bool +} + +// FlushController is an interface to control the flush of data so after it +// returns caller can save checkpoint. +type FlushController interface { + // Flush checks if al engines need to be flushed and imported based on given + // FlushMode. It's concurrent safe. + Flush(mode FlushMode) (flushed, imported bool, err error) +} + +// NewCheckpointManager creates a new checkpoint manager. +func NewCheckpointManager( + ctx context.Context, + sessPool *sess.Pool, + physicalID int64, + jobID int64, + indexIDs []int64, + localStoreDir string, + pdCli pd.Client, +) (*CheckpointManager, error) { + instanceAddr := InstanceAddr() + ctx2, cancel := context.WithCancel(ctx) + logger := logutil.DDLIngestLogger().With( + zap.Int64("jobID", jobID), zap.Int64s("indexIDs", indexIDs)) + + cm := &CheckpointManager{ + ctx: ctx2, + cancel: cancel, + sessPool: sessPool, + jobID: jobID, + indexIDs: indexIDs, + localStoreDir: localStoreDir, + pdCli: pdCli, + logger: logger, + checkpoints: make(map[int]*taskCheckpoint, 16), + mu: sync.Mutex{}, + instanceAddr: instanceAddr, + physicalID: physicalID, + updaterWg: sync.WaitGroup{}, + updaterCh: make(chan chan struct{}), + } + err := cm.resumeOrInitCheckpoint() + if err != nil { + return nil, err + } + cm.updaterWg.Add(1) + go func() { + cm.updateCheckpointLoop() + cm.updaterWg.Done() + }() + logger.Info("create checkpoint manager") + return cm, nil +} + +// InstanceAddr returns the string concat with instance address and temp-dir. +func InstanceAddr() string { + cfg := config.GetGlobalConfig() + dsn := net.JoinHostPort(cfg.AdvertiseAddress, strconv.Itoa(int(cfg.Port))) + return fmt.Sprintf("%s:%s", dsn, cfg.TempDir) +} + +// IsKeyProcessed checks if the key is processed. The key may not be imported. 
+// This is called before the reader reads the data and decides whether to skip +// the current task. +func (s *CheckpointManager) IsKeyProcessed(end kv.Key) bool { + s.mu.Lock() + defer s.mu.Unlock() + if len(s.importedKeyLowWatermark) > 0 && end.Cmp(s.importedKeyLowWatermark) <= 0 { + return true + } + return s.localDataIsValid && len(s.flushedKeyLowWatermark) > 0 && end.Cmp(s.flushedKeyLowWatermark) <= 0 +} + +// LastProcessedKey finds the last processed key in checkpoint. +// If there is no processed key, it returns nil. +func (s *CheckpointManager) LastProcessedKey() kv.Key { + s.mu.Lock() + defer s.mu.Unlock() + + if s.localDataIsValid && len(s.flushedKeyLowWatermark) > 0 { + return s.flushedKeyLowWatermark.Clone() + } + if len(s.importedKeyLowWatermark) > 0 { + return s.importedKeyLowWatermark.Clone() + } + return nil +} + +// Status returns the status of the checkpoint. +func (s *CheckpointManager) Status() (keyCnt int, minKeyImported kv.Key) { + s.mu.Lock() + defer s.mu.Unlock() + total := 0 + for _, cp := range s.checkpoints { + total += cp.writtenKeys + } + // TODO(lance6716): ??? + return s.flushedKeyCnt + total, s.importedKeyLowWatermark +} + +// Register registers a new task. taskID MUST be continuous ascending and start +// from 1. +// +// TODO(lance6716): remove this constraint, use endKey as taskID and use +// ordered map type for checkpoints. +func (s *CheckpointManager) Register(taskID int, end kv.Key) { + s.mu.Lock() + defer s.mu.Unlock() + s.checkpoints[taskID] = &taskCheckpoint{ + endKey: end, + } +} + +// UpdateTotalKeys updates the total keys of the task. +// This is called by the reader after reading the data to update the number of rows contained in the current chunk. +func (s *CheckpointManager) UpdateTotalKeys(taskID int, delta int, last bool) { + s.mu.Lock() + defer s.mu.Unlock() + cp := s.checkpoints[taskID] + cp.totalKeys += delta + cp.lastBatchRead = last +} + +// UpdateWrittenKeys updates the written keys of the task. +// This is called by the writer after writing the local engine to update the current number of rows written. +func (s *CheckpointManager) UpdateWrittenKeys(taskID int, delta int) { + s.mu.Lock() + cp := s.checkpoints[taskID] + cp.writtenKeys += delta + s.mu.Unlock() +} + +// AdvanceWatermark advances the watermark according to flushed or imported status. +func (s *CheckpointManager) AdvanceWatermark(flushed, imported bool) { + if !flushed { + return + } + + failpoint.Inject("resignAfterFlush", func() { + // used in a manual test + ResignOwnerForTest.Store(true) + // wait until ResignOwnerForTest is processed + for ResignOwnerForTest.Load() { + time.Sleep(100 * time.Millisecond) + } + }) + + s.mu.Lock() + defer s.mu.Unlock() + s.afterFlush() + + if imported { + s.afterImport() + } +} + +// afterFlush should be called after all engine is flushed. 
+func (s *CheckpointManager) afterFlush() { + for { + cp := s.checkpoints[s.minTaskIDFinished+1] + if cp == nil || !cp.lastBatchRead || cp.writtenKeys < cp.totalKeys { + break + } + s.minTaskIDFinished++ + s.flushedKeyLowWatermark = cp.endKey + s.flushedKeyCnt += cp.totalKeys + delete(s.checkpoints, s.minTaskIDFinished) + s.dirty = true + } +} + +func (s *CheckpointManager) afterImport() { + if s.importedKeyLowWatermark.Cmp(s.flushedKeyLowWatermark) > 0 { + s.logger.Warn("lower watermark of flushed key is less than imported key", + zap.String("flushed", hex.EncodeToString(s.flushedKeyLowWatermark)), + zap.String("imported", hex.EncodeToString(s.importedKeyLowWatermark)), + ) + return + } + s.importedKeyLowWatermark = s.flushedKeyLowWatermark + s.importedKeyCnt = s.flushedKeyCnt + s.dirty = true +} + +// Close closes the checkpoint manager. +func (s *CheckpointManager) Close() { + err := s.updateCheckpoint() + if err != nil { + s.logger.Error("update checkpoint failed", zap.Error(err)) + } + + s.cancel() + s.updaterWg.Wait() + s.logger.Info("checkpoint manager closed") +} + +// GetTS returns the TS saved in checkpoint. +func (s *CheckpointManager) GetTS() uint64 { + s.mu.Lock() + defer s.mu.Unlock() + return s.ts +} + +// JobReorgMeta is the metadata for a reorg job. +type JobReorgMeta struct { + Checkpoint *ReorgCheckpoint `json:"reorg_checkpoint"` +} + +// ReorgCheckpoint is the checkpoint for a reorg job. +type ReorgCheckpoint struct { + LocalSyncKey kv.Key `json:"local_sync_key"` + LocalKeyCount int `json:"local_key_count"` + GlobalSyncKey kv.Key `json:"global_sync_key"` + GlobalKeyCount int `json:"global_key_count"` + InstanceAddr string `json:"instance_addr"` + + PhysicalID int64 `json:"physical_id"` + // TS of next engine ingest. + TS uint64 `json:"ts"` + + Version int64 `json:"version"` +} + +// JobCheckpointVersionCurrent is the current version of the checkpoint. 
+const ( + JobCheckpointVersionCurrent = JobCheckpointVersion1 + JobCheckpointVersion1 = 1 +) + +func (s *CheckpointManager) resumeOrInitCheckpoint() error { + sessCtx, err := s.sessPool.Get() + if err != nil { + return errors.Trace(err) + } + defer s.sessPool.Put(sessCtx) + ddlSess := sess.NewSession(sessCtx) + err = ddlSess.RunInTxn(func(se *sess.Session) error { + template := "select reorg_meta from mysql.tidb_ddl_reorg where job_id = %d and ele_type = %s;" + sql := fmt.Sprintf(template, s.jobID, util.WrapKey2String(meta.IndexElementKey)) + ctx := kv.WithInternalSourceType(s.ctx, kv.InternalTxnBackfillDDLPrefix+"add_index") + rows, err := se.Execute(ctx, sql, "get_checkpoint") + if err != nil { + return errors.Trace(err) + } + + if len(rows) == 0 || rows[0].IsNull(0) { + return nil + } + rawReorgMeta := rows[0].GetBytes(0) + var reorgMeta JobReorgMeta + err = json.Unmarshal(rawReorgMeta, &reorgMeta) + if err != nil { + return errors.Trace(err) + } + if cp := reorgMeta.Checkpoint; cp != nil { + if cp.PhysicalID != s.physicalID { + s.logger.Info("checkpoint physical table ID mismatch", + zap.Int64("current", s.physicalID), + zap.Int64("get", cp.PhysicalID)) + return nil + } + s.importedKeyLowWatermark = cp.GlobalSyncKey + s.importedKeyCnt = cp.GlobalKeyCount + s.ts = cp.TS + folderNotEmpty := util.FolderNotEmpty(s.localStoreDir) + if folderNotEmpty && + (s.instanceAddr == cp.InstanceAddr || cp.InstanceAddr == "" /* initial state */) { + s.localDataIsValid = true + s.flushedKeyLowWatermark = cp.LocalSyncKey + s.flushedKeyCnt = cp.LocalKeyCount + } + s.logger.Info("resume checkpoint", + zap.String("flushed key low watermark", hex.EncodeToString(s.flushedKeyLowWatermark)), + zap.String("imported key low watermark", hex.EncodeToString(s.importedKeyLowWatermark)), + zap.Int64("physical table ID", cp.PhysicalID), + zap.String("previous instance", cp.InstanceAddr), + zap.String("current instance", s.instanceAddr), + zap.Bool("folder is empty", !folderNotEmpty)) + return nil + } + s.logger.Info("checkpoint not found") + return nil + }) + if err != nil { + return errors.Trace(err) + } + + if s.ts > 0 { + return nil + } + // if TS is not set, we need to allocate a TS and save it to the storage before + // continue. + p, l, err := s.pdCli.GetTS(s.ctx) + if err != nil { + return errors.Trace(err) + } + ts := oracle.ComposeTS(p, l) + s.ts = ts + return s.updateCheckpointImpl() +} + +// updateCheckpointImpl is only used by updateCheckpointLoop goroutine or in +// NewCheckpointManager. In other cases, use updateCheckpoint instead. 
+func (s *CheckpointManager) updateCheckpointImpl() error { + s.mu.Lock() + flushedKeyLowWatermark := s.flushedKeyLowWatermark + importedKeyLowWatermark := s.importedKeyLowWatermark + flushedKeyCnt := s.flushedKeyCnt + importedKeyCnt := s.importedKeyCnt + physicalID := s.physicalID + ts := s.ts + s.mu.Unlock() + + sessCtx, err := s.sessPool.Get() + if err != nil { + return errors.Trace(err) + } + defer s.sessPool.Put(sessCtx) + ddlSess := sess.NewSession(sessCtx) + err = ddlSess.RunInTxn(func(se *sess.Session) error { + template := "update mysql.tidb_ddl_reorg set reorg_meta = %s where job_id = %d and ele_type = %s;" + cp := &ReorgCheckpoint{ + LocalSyncKey: flushedKeyLowWatermark, + GlobalSyncKey: importedKeyLowWatermark, + LocalKeyCount: flushedKeyCnt, + GlobalKeyCount: importedKeyCnt, + InstanceAddr: s.instanceAddr, + PhysicalID: physicalID, + TS: ts, + Version: JobCheckpointVersionCurrent, + } + rawReorgMeta, err := json.Marshal(JobReorgMeta{Checkpoint: cp}) + if err != nil { + return errors.Trace(err) + } + sql := fmt.Sprintf(template, util.WrapKey2String(rawReorgMeta), s.jobID, util.WrapKey2String(meta.IndexElementKey)) + ctx := kv.WithInternalSourceType(s.ctx, kv.InternalTxnBackfillDDLPrefix+"add_index") + _, err = se.Execute(ctx, sql, "update_checkpoint") + if err != nil { + return errors.Trace(err) + } + s.mu.Lock() + s.dirty = false + s.mu.Unlock() + return nil + }) + + logFunc := s.logger.Info + if err != nil { + logFunc = s.logger.With(zap.Error(err)).Error + } + logFunc("update checkpoint", + zap.String("local checkpoint", hex.EncodeToString(flushedKeyLowWatermark)), + zap.String("global checkpoint", hex.EncodeToString(importedKeyLowWatermark)), + zap.Int("flushed keys", flushedKeyCnt), + zap.Int("imported keys", importedKeyCnt), + zap.Int64("global physical ID", physicalID), + zap.Uint64("ts", ts)) + return err +} + +func (s *CheckpointManager) updateCheckpointLoop() { + failpoint.Inject("checkpointLoopExit", func() { + // used in a manual test + failpoint.Return() + }) + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + for { + select { + case finishCh := <-s.updaterCh: + err := s.updateCheckpointImpl() + if err != nil { + s.logger.Error("update checkpoint failed", zap.Error(err)) + } + close(finishCh) + case <-ticker.C: + s.mu.Lock() + if !s.dirty { + s.mu.Unlock() + continue + } + s.mu.Unlock() + err := s.updateCheckpointImpl() + if err != nil { + s.logger.Error("periodically update checkpoint failed", zap.Error(err)) + } + case <-s.ctx.Done(): + return + } + } +} + +func (s *CheckpointManager) updateCheckpoint() error { + failpoint.Inject("checkpointLoopExit", func() { + // used in a manual test + failpoint.Return(errors.New("failpoint triggered so can't update checkpoint")) + }) + finishCh := make(chan struct{}) + select { + case s.updaterCh <- finishCh: + case <-s.ctx.Done(): + return s.ctx.Err() + } + // wait updateCheckpointLoop to finish checkpoint update. 
+ select { + case <-finishCh: + case <-s.ctx.Done(): + return s.ctx.Err() + } + return nil +} + +func (s *CheckpointManager) refreshTSAndUpdateCP() (uint64, error) { + p, l, err := s.pdCli.GetTS(s.ctx) + if err != nil { + return 0, errors.Trace(err) + } + newTS := oracle.ComposeTS(p, l) + s.mu.Lock() + s.ts = newTS + s.mu.Unlock() + return newTS, s.updateCheckpoint() +} diff --git a/pkg/ddl/ingest/disk_root.go b/pkg/ddl/ingest/disk_root.go index 90e3fa4f62922..adfacda863dab 100644 --- a/pkg/ddl/ingest/disk_root.go +++ b/pkg/ddl/ingest/disk_root.go @@ -116,9 +116,9 @@ func (d *diskRootImpl) usageInfo() string { // PreCheckUsage implements DiskRoot interface. func (d *diskRootImpl) PreCheckUsage() error { - failpoint.Inject("mockIngestCheckEnvFailed", func(_ failpoint.Value) { - failpoint.Return(dbterror.ErrIngestCheckEnvFailed.FastGenByArgs("mock error")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("mockIngestCheckEnvFailed")); _err_ == nil { + return dbterror.ErrIngestCheckEnvFailed.FastGenByArgs("mock error") + } err := os.MkdirAll(d.path, 0700) if err != nil { return dbterror.ErrIngestCheckEnvFailed.FastGenByArgs(err.Error()) diff --git a/pkg/ddl/ingest/disk_root.go__failpoint_stash__ b/pkg/ddl/ingest/disk_root.go__failpoint_stash__ new file mode 100644 index 0000000000000..90e3fa4f62922 --- /dev/null +++ b/pkg/ddl/ingest/disk_root.go__failpoint_stash__ @@ -0,0 +1,157 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ingest + +import ( + "fmt" + "os" + "sync" + "sync/atomic" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/ddl/logutil" + lcom "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/util/dbterror" + "go.uber.org/zap" +) + +// DiskRoot is used to track the disk usage for the lightning backfill process. +type DiskRoot interface { + UpdateUsage() + ShouldImport() bool + UsageInfo() string + PreCheckUsage() error + StartupCheck() error +} + +const capacityThreshold = 0.9 + +// diskRootImpl implements DiskRoot interface. +type diskRootImpl struct { + path string + capacity uint64 + used uint64 + bcUsed uint64 + bcCtx *litBackendCtxMgr + mu sync.RWMutex + updating atomic.Bool +} + +// NewDiskRootImpl creates a new DiskRoot. +func NewDiskRootImpl(path string, bcCtx *litBackendCtxMgr) DiskRoot { + return &diskRootImpl{ + path: path, + bcCtx: bcCtx, + } +} + +// UpdateUsage implements DiskRoot interface. 
+func (d *diskRootImpl) UpdateUsage() { + if !d.updating.CompareAndSwap(false, true) { + return + } + bcUsed := d.bcCtx.TotalDiskUsage() + var capacity, used uint64 + sz, err := lcom.GetStorageSize(d.path) + if err != nil { + logutil.DDLIngestLogger().Error(LitErrGetStorageQuota, zap.Error(err)) + } else { + capacity, used = sz.Capacity, sz.Capacity-sz.Available + } + d.updating.Store(false) + d.mu.Lock() + d.bcUsed = bcUsed + d.capacity = capacity + d.used = used + d.mu.Unlock() +} + +// ShouldImport implements DiskRoot interface. +func (d *diskRootImpl) ShouldImport() bool { + d.mu.RLock() + defer d.mu.RUnlock() + if d.bcUsed > variable.DDLDiskQuota.Load() { + logutil.DDLIngestLogger().Info("disk usage is over quota", + zap.Uint64("quota", variable.DDLDiskQuota.Load()), + zap.String("usage", d.usageInfo())) + return true + } + if d.used == 0 && d.capacity == 0 { + return false + } + if float64(d.used) >= float64(d.capacity)*capacityThreshold { + logutil.DDLIngestLogger().Warn("available disk space is less than 10%, "+ + "this may degrade the performance, "+ + "please make sure the disk available space is larger than @@tidb_ddl_disk_quota before adding index", + zap.String("usage", d.usageInfo())) + return true + } + return false +} + +// UsageInfo implements DiskRoot interface. +func (d *diskRootImpl) UsageInfo() string { + d.mu.RLock() + defer d.mu.RUnlock() + return d.usageInfo() +} + +func (d *diskRootImpl) usageInfo() string { + return fmt.Sprintf("disk usage: %d/%d, backend usage: %d", d.used, d.capacity, d.bcUsed) +} + +// PreCheckUsage implements DiskRoot interface. +func (d *diskRootImpl) PreCheckUsage() error { + failpoint.Inject("mockIngestCheckEnvFailed", func(_ failpoint.Value) { + failpoint.Return(dbterror.ErrIngestCheckEnvFailed.FastGenByArgs("mock error")) + }) + err := os.MkdirAll(d.path, 0700) + if err != nil { + return dbterror.ErrIngestCheckEnvFailed.FastGenByArgs(err.Error()) + } + sz, err := lcom.GetStorageSize(d.path) + if err != nil { + return dbterror.ErrIngestCheckEnvFailed.FastGenByArgs(err.Error()) + } + if RiskOfDiskFull(sz.Available, sz.Capacity) { + logutil.DDLIngestLogger().Warn("available disk space is less than 10%, cannot use ingest mode", + zap.String("sort path", d.path), + zap.String("usage", d.usageInfo())) + msg := fmt.Sprintf("no enough space in %s", d.path) + return dbterror.ErrIngestCheckEnvFailed.FastGenByArgs(msg) + } + return nil +} + +// StartupCheck implements DiskRoot interface. +func (d *diskRootImpl) StartupCheck() error { + sz, err := lcom.GetStorageSize(d.path) + if err != nil { + return errors.Trace(err) + } + quota := variable.DDLDiskQuota.Load() + if sz.Available < quota { + return errors.Errorf("the available disk space(%d) in %s should be greater than @@tidb_ddl_disk_quota(%d)", + sz.Available, d.path, quota) + } + return nil +} + +// RiskOfDiskFull checks if the disk has less than 10% space. 
+func RiskOfDiskFull(available, capacity uint64) bool { + return float64(available) < (1-capacityThreshold)*float64(capacity) +} diff --git a/pkg/ddl/ingest/env.go b/pkg/ddl/ingest/env.go index fda7e76720319..6aaff89dea58d 100644 --- a/pkg/ddl/ingest/env.go +++ b/pkg/ddl/ingest/env.go @@ -70,11 +70,11 @@ func InitGlobalLightningEnv(path string) (ok bool) { } else { memTotal = memTotal / 2 } - failpoint.Inject("setMemTotalInMB", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("setMemTotalInMB")); _err_ == nil { //nolint: forcetypeassert i := val.(int) memTotal = uint64(i) * size.MB - }) + } LitBackCtxMgr = NewLitBackendCtxMgr(path, memTotal) litRLimit = util.GenRLimit("ddl-ingest") LitInitialized = true diff --git a/pkg/ddl/ingest/env.go__failpoint_stash__ b/pkg/ddl/ingest/env.go__failpoint_stash__ new file mode 100644 index 0000000000000..fda7e76720319 --- /dev/null +++ b/pkg/ddl/ingest/env.go__failpoint_stash__ @@ -0,0 +1,189 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ingest + +import ( + "context" + "fmt" + "os" + "path/filepath" + "slices" + "strconv" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl/logutil" + sess "github.com/pingcap/tidb/pkg/ddl/session" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/size" + "go.uber.org/zap" + "golang.org/x/exp/maps" +) + +var ( + // LitBackCtxMgr is the entry for the lightning backfill process. + LitBackCtxMgr BackendCtxMgr + // LitMemRoot is used to track the memory usage of the lightning backfill process. + LitMemRoot MemRoot + // litDiskRoot is used to track the disk usage of the lightning backfill process. + litDiskRoot DiskRoot + // litRLimit is the max open file number of the lightning backfill process. + litRLimit uint64 + // LitInitialized is the flag indicates whether the lightning backfill process is initialized. + LitInitialized bool +) + +const defaultMemoryQuota = 2 * size.GB + +// InitGlobalLightningEnv initialize Lightning backfill environment. 
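+// It only takes effect for the TiKV storage engine. The memory quota defaults to
+// half of the detected total memory, falling back to defaultMemoryQuota (2 GB)
+// when detection fails; tests may override it through the setMemTotalInMB
+// failpoint shown below.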
+func InitGlobalLightningEnv(path string) (ok bool) { + log.SetAppLogger(logutil.DDLIngestLogger()) + globalCfg := config.GetGlobalConfig() + if globalCfg.Store != "tikv" { + logutil.DDLIngestLogger().Warn(LitWarnEnvInitFail, + zap.String("storage limitation", "only support TiKV storage"), + zap.String("current storage", globalCfg.Store), + zap.Bool("lightning is initialized", LitInitialized)) + return false + } + memTotal, err := memory.MemTotal() + if err != nil { + logutil.DDLIngestLogger().Warn("get total memory fail", zap.Error(err)) + memTotal = defaultMemoryQuota + } else { + memTotal = memTotal / 2 + } + failpoint.Inject("setMemTotalInMB", func(val failpoint.Value) { + //nolint: forcetypeassert + i := val.(int) + memTotal = uint64(i) * size.MB + }) + LitBackCtxMgr = NewLitBackendCtxMgr(path, memTotal) + litRLimit = util.GenRLimit("ddl-ingest") + LitInitialized = true + logutil.DDLIngestLogger().Info(LitInfoEnvInitSucc, + zap.Uint64("memory limitation", memTotal), + zap.String("disk usage info", litDiskRoot.UsageInfo()), + zap.Uint64("max open file number", litRLimit), + zap.Bool("lightning is initialized", LitInitialized)) + return true +} + +// GenIngestTempDataDir generates a path for DDL ingest. +// Format: ${temp-dir}/tmp_ddl-{port} +func GenIngestTempDataDir() (string, error) { + tidbCfg := config.GetGlobalConfig() + sortPathSuffix := "/tmp_ddl-" + strconv.Itoa(int(tidbCfg.Port)) + sortPath := filepath.Join(tidbCfg.TempDir, sortPathSuffix) + + if _, err := os.Stat(sortPath); err != nil { + if !os.IsNotExist(err) { + logutil.DDLIngestLogger().Error(LitErrStatDirFail, + zap.String("sort path", sortPath), zap.Error(err)) + return "", err + } + } + err := os.MkdirAll(sortPath, 0o700) + if err != nil { + logutil.DDLIngestLogger().Error(LitErrCreateDirFail, + zap.String("sort path", sortPath), zap.Error(err)) + return "", err + } + logutil.DDLIngestLogger().Info(LitInfoSortDir, zap.String("data path", sortPath)) + return sortPath, nil +} + +// CleanUpTempDir is used to remove the stale index data. +// This function gets running DDL jobs from `mysql.tidb_ddl_job` and +// it only removes the folders that related to finished jobs. 
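+// The flow is: list the sort-path directories, decode a job ID from each
+// directory name, drop the IDs that still appear in mysql.tidb_ddl_job, and
+// remove the remaining directories, which belong to jobs that have finished.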
+func CleanUpTempDir(ctx context.Context, se sessionctx.Context, path string) { + entries, err := os.ReadDir(path) + if err != nil { + if strings.Contains(err.Error(), "no such file") { + return + } + logutil.DDLIngestLogger().Warn(LitErrCleanSortPath, zap.Error(err)) + return + } + toCheckJobIDs := make(map[int64]struct{}, len(entries)) + for _, entry := range entries { + if !entry.IsDir() { + continue + } + jobID, err := decodeBackendTag(entry.Name()) + if err != nil { + logutil.DDLIngestLogger().Error(LitErrCleanSortPath, zap.Error(err)) + continue + } + toCheckJobIDs[jobID] = struct{}{} + } + + if len(toCheckJobIDs) == 0 { + return + } + + idSlice := maps.Keys(toCheckJobIDs) + slices.Sort(idSlice) + processing, err := filterProcessingJobIDs(ctx, sess.NewSession(se), idSlice) + if err != nil { + logutil.DDLIngestLogger().Error(LitErrCleanSortPath, zap.Error(err)) + return + } + + for _, id := range processing { + delete(toCheckJobIDs, id) + } + + if len(toCheckJobIDs) == 0 { + return + } + + for id := range toCheckJobIDs { + logutil.DDLIngestLogger().Info("remove stale temp index data", + zap.Int64("jobID", id)) + p := filepath.Join(path, encodeBackendTag(id)) + err = os.RemoveAll(p) + if err != nil { + logutil.DDLIngestLogger().Error(LitErrCleanSortPath, zap.Error(err)) + } + } +} + +func filterProcessingJobIDs(ctx context.Context, se *sess.Session, jobIDs []int64) ([]int64, error) { + var sb strings.Builder + for i, id := range jobIDs { + if i != 0 { + sb.WriteString(",") + } + sb.WriteString(strconv.FormatInt(id, 10)) + } + sql := fmt.Sprintf( + "SELECT job_id FROM mysql.tidb_ddl_job WHERE job_id IN (%s)", + sb.String()) + rows, err := se.Execute(ctx, sql, "filter_processing_job_ids") + if err != nil { + return nil, errors.Trace(err) + } + ret := make([]int64, 0, len(rows)) + for _, row := range rows { + ret = append(ret, row.GetInt64(0)) + } + return ret, nil +} diff --git a/pkg/ddl/ingest/mock.go b/pkg/ddl/ingest/mock.go index 4d9261ecfc672..f7361d490c833 100644 --- a/pkg/ddl/ingest/mock.go +++ b/pkg/ddl/ingest/mock.go @@ -206,7 +206,7 @@ func (m *MockWriter) WriteRow(_ context.Context, key, idxVal []byte, _ kv.Handle zap.String("key", hex.EncodeToString(key)), zap.String("idxVal", hex.EncodeToString(idxVal))) - failpoint.InjectCall("onMockWriterWriteRow") + failpoint.Call(_curpkg_("onMockWriterWriteRow")) m.mu.Lock() defer m.mu.Unlock() if m.onWrite != nil { diff --git a/pkg/ddl/ingest/mock.go__failpoint_stash__ b/pkg/ddl/ingest/mock.go__failpoint_stash__ new file mode 100644 index 0000000000000..4d9261ecfc672 --- /dev/null +++ b/pkg/ddl/ingest/mock.go__failpoint_stash__ @@ -0,0 +1,236 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ingest + +import ( + "context" + "encoding/hex" + "os" + "path/filepath" + "strconv" + "sync" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/ddl/logutil" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/lightning/backend" + "github.com/pingcap/tidb/pkg/lightning/backend/local" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/table" + pd "github.com/tikv/pd/client" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" +) + +// MockBackendCtxMgr is a mock backend context manager. +type MockBackendCtxMgr struct { + sessCtxProvider func() sessionctx.Context + runningJobs map[int64]*MockBackendCtx +} + +var _ BackendCtxMgr = (*MockBackendCtxMgr)(nil) + +// NewMockBackendCtxMgr creates a new mock backend context manager. +func NewMockBackendCtxMgr(sessCtxProvider func() sessionctx.Context) *MockBackendCtxMgr { + return &MockBackendCtxMgr{ + sessCtxProvider: sessCtxProvider, + runningJobs: make(map[int64]*MockBackendCtx), + } +} + +// CheckMoreTasksAvailable implements BackendCtxMgr.CheckMoreTaskAvailable interface. +func (m *MockBackendCtxMgr) CheckMoreTasksAvailable() (bool, error) { + return len(m.runningJobs) == 0, nil +} + +// Register implements BackendCtxMgr.Register interface. +func (m *MockBackendCtxMgr) Register(ctx context.Context, jobID int64, unique bool, etcdClient *clientv3.Client, pdSvcDiscovery pd.ServiceDiscovery, resourceGroupName string) (BackendCtx, error) { + logutil.DDLIngestLogger().Info("mock backend mgr register", zap.Int64("jobID", jobID)) + if mockCtx, ok := m.runningJobs[jobID]; ok { + return mockCtx, nil + } + sessCtx := m.sessCtxProvider() + mockCtx := &MockBackendCtx{ + mu: sync.Mutex{}, + sessCtx: sessCtx, + jobID: jobID, + } + m.runningJobs[jobID] = mockCtx + return mockCtx, nil +} + +// Unregister implements BackendCtxMgr.Unregister interface. +func (m *MockBackendCtxMgr) Unregister(jobID int64) { + if mCtx, ok := m.runningJobs[jobID]; ok { + mCtx.sessCtx.StmtCommit(context.Background()) + err := mCtx.sessCtx.CommitTxn(context.Background()) + logutil.DDLIngestLogger().Info("mock backend mgr unregister", zap.Int64("jobID", jobID), zap.Error(err)) + delete(m.runningJobs, jobID) + } +} + +// EncodeJobSortPath implements BackendCtxMgr interface. +func (m *MockBackendCtxMgr) EncodeJobSortPath(int64) string { + return "" +} + +// Load implements BackendCtxMgr.Load interface. +func (m *MockBackendCtxMgr) Load(jobID int64) (BackendCtx, bool) { + logutil.DDLIngestLogger().Info("mock backend mgr load", zap.Int64("jobID", jobID)) + if mockCtx, ok := m.runningJobs[jobID]; ok { + return mockCtx, true + } + return nil, false +} + +// ResetSessCtx is only used for mocking test. +func (m *MockBackendCtxMgr) ResetSessCtx() { + for _, mockCtx := range m.runningJobs { + mockCtx.sessCtx = m.sessCtxProvider() + } +} + +// MockBackendCtx is a mock backend context. +type MockBackendCtx struct { + sessCtx sessionctx.Context + mu sync.Mutex + jobID int64 + checkpointMgr *CheckpointManager +} + +// Register implements BackendCtx.Register interface. +func (m *MockBackendCtx) Register(indexIDs []int64, _ []bool, _ table.Table) ([]Engine, error) { + logutil.DDLIngestLogger().Info("mock backend ctx register", zap.Int64("jobID", m.jobID), zap.Int64s("indexIDs", indexIDs)) + ret := make([]Engine, 0, len(indexIDs)) + for range indexIDs { + ret = append(ret, &MockEngineInfo{sessCtx: m.sessCtx, mu: &m.mu}) + } + return ret, nil +} + +// FinishAndUnregisterEngines implements BackendCtx interface. 
+func (*MockBackendCtx) FinishAndUnregisterEngines(_ UnregisterOpt) error { + logutil.DDLIngestLogger().Info("mock backend ctx unregister") + return nil +} + +// CollectRemoteDuplicateRows implements BackendCtx.CollectRemoteDuplicateRows interface. +func (*MockBackendCtx) CollectRemoteDuplicateRows(indexID int64, _ table.Table) error { + logutil.DDLIngestLogger().Info("mock backend ctx collect remote duplicate rows", zap.Int64("indexID", indexID)) + return nil +} + +// Flush implements BackendCtx.Flush interface. +func (*MockBackendCtx) Flush(mode FlushMode) (flushed, imported bool, err error) { + return false, false, nil +} + +// AttachCheckpointManager attaches a checkpoint manager to the backend context. +func (m *MockBackendCtx) AttachCheckpointManager(mgr *CheckpointManager) { + m.checkpointMgr = mgr +} + +// GetCheckpointManager returns the checkpoint manager attached to the backend context. +func (m *MockBackendCtx) GetCheckpointManager() *CheckpointManager { + return m.checkpointMgr +} + +// GetLocalBackend returns the local backend. +func (m *MockBackendCtx) GetLocalBackend() *local.Backend { + b := &local.Backend{} + b.LocalStoreDir = filepath.Join(os.TempDir(), "mock_backend", strconv.FormatInt(m.jobID, 10)) + return b +} + +// MockWriteHook the hook for write in mock engine. +type MockWriteHook func(key, val []byte) + +// MockEngineInfo is a mock engine info. +type MockEngineInfo struct { + sessCtx sessionctx.Context + mu *sync.Mutex + + onWrite MockWriteHook +} + +// NewMockEngineInfo creates a new mock engine info. +func NewMockEngineInfo(sessCtx sessionctx.Context) *MockEngineInfo { + return &MockEngineInfo{ + sessCtx: sessCtx, + mu: &sync.Mutex{}, + } +} + +// Flush implements Engine.Flush interface. +func (*MockEngineInfo) Flush() error { + return nil +} + +// Close implements Engine.Close interface. +func (*MockEngineInfo) Close(_ bool) { +} + +// SetHook set the write hook. +func (m *MockEngineInfo) SetHook(onWrite func(key, val []byte)) { + m.onWrite = onWrite +} + +// CreateWriter implements Engine.CreateWriter interface. +func (m *MockEngineInfo) CreateWriter(id int, _ *backend.LocalWriterConfig) (Writer, error) { + logutil.DDLIngestLogger().Info("mock engine info create writer", zap.Int("id", id)) + return &MockWriter{sessCtx: m.sessCtx, mu: m.mu, onWrite: m.onWrite}, nil +} + +// MockWriter is a mock writer. +type MockWriter struct { + sessCtx sessionctx.Context + mu *sync.Mutex + onWrite MockWriteHook +} + +// WriteRow implements Writer.WriteRow interface. +func (m *MockWriter) WriteRow(_ context.Context, key, idxVal []byte, _ kv.Handle) error { + logutil.DDLIngestLogger().Info("mock writer write row", + zap.String("key", hex.EncodeToString(key)), + zap.String("idxVal", hex.EncodeToString(idxVal))) + + failpoint.InjectCall("onMockWriterWriteRow") + m.mu.Lock() + defer m.mu.Unlock() + if m.onWrite != nil { + m.onWrite(key, idxVal) + return nil + } + txn, err := m.sessCtx.Txn(true) + if err != nil { + return err + } + err = txn.Set(key, idxVal) + if err != nil { + return err + } + if MockExecAfterWriteRow != nil { + MockExecAfterWriteRow() + } + return nil +} + +// LockForWrite implements Writer.LockForWrite interface. +func (*MockWriter) LockForWrite() func() { + return func() {} +} + +// MockExecAfterWriteRow is only used for test. 
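+// When set, MockWriter.WriteRow invokes it after each successful txn.Set, so a
+// test can observe every ingested key/value pair. A minimal sketch of how a test
+// might use it (the counter name is hypothetical, not part of this package):
+//
+//	ingest.MockExecAfterWriteRow = func() { written.Add(1) }
+//	defer func() { ingest.MockExecAfterWriteRow = nil }()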
+var MockExecAfterWriteRow func() diff --git a/pkg/ddl/job_scheduler.go b/pkg/ddl/job_scheduler.go index 1c3446b36704a..75d1257c70fec 100644 --- a/pkg/ddl/job_scheduler.go +++ b/pkg/ddl/job_scheduler.go @@ -177,7 +177,7 @@ func (s *jobScheduler) close() { if s.generalDDLWorkerPool != nil { s.generalDDLWorkerPool.close() } - failpoint.InjectCall("afterSchedulerClose") + failpoint.Call(_curpkg_("afterSchedulerClose")) } func hasSysDB(job *model.Job) bool { @@ -286,7 +286,7 @@ func (s *jobScheduler) schedule() error { if err := s.schCtx.Err(); err != nil { return err } - failpoint.Inject("ownerResignAfterDispatchLoopCheck", func() { + if _, _err_ := failpoint.Eval(_curpkg_("ownerResignAfterDispatchLoopCheck")); _err_ == nil { if ingest.ResignOwnerForTest.Load() { err2 := s.ownerManager.ResignOwner(context.Background()) if err2 != nil { @@ -294,7 +294,7 @@ func (s *jobScheduler) schedule() error { } ingest.ResignOwnerForTest.Store(false) } - }) + } select { case <-s.ddlJobNotifyCh: case <-ticker.C: @@ -311,7 +311,7 @@ func (s *jobScheduler) schedule() error { if err := s.checkAndUpdateClusterState(false); err != nil { continue } - failpoint.InjectCall("beforeLoadAndDeliverJobs") + failpoint.Call(_curpkg_("beforeLoadAndDeliverJobs")) if err := s.loadAndDeliverJobs(se); err != nil { logutil.SampleLogger().Warn("load and deliver jobs failed", zap.Error(err)) } @@ -451,7 +451,7 @@ func (s *jobScheduler) mustReloadSchemas() { // the worker will run the job until it's finished, paused or another owner takes // over and finished it. func (s *jobScheduler) deliveryJob(wk *worker, pool *workerPool, job *model.Job) { - failpoint.InjectCall("beforeDeliveryJob", job) + failpoint.Call(_curpkg_("beforeDeliveryJob"), job) injectFailPointForGetJob(job) jobID, involvedSchemaInfos := job.ID, job.GetInvolvingSchemaInfo() s.runningJobs.addRunning(jobID, involvedSchemaInfos) @@ -462,7 +462,7 @@ func (s *jobScheduler) deliveryJob(wk *worker, pool *workerPool, job *model.Job) if r != nil { logutil.DDLLogger().Error("panic in deliveryJob", zap.Any("recover", r), zap.Stack("stack")) } - failpoint.InjectCall("afterDeliveryJob", job) + failpoint.Call(_curpkg_("afterDeliveryJob"), job) // Because there is a gap between `allIDs()` and `checkRunnable()`, // we append unfinished job to pending atomically to prevent `getJob()` // chosing another runnable job that involves the same schema object. @@ -483,10 +483,10 @@ func (s *jobScheduler) deliveryJob(wk *worker, pool *workerPool, job *model.Job) // or the job is finished by another owner. // TODO for JobStateRollbackDone we have to query 1 additional time when the // job is already moved to history. - failpoint.InjectCall("beforeRefreshJob", job) + failpoint.Call(_curpkg_("beforeRefreshJob"), job) for { job, err = s.sysTblMgr.GetJobByID(s.schCtx, jobID) - failpoint.InjectCall("mockGetJobByIDFail", &err) + failpoint.Call(_curpkg_("mockGetJobByIDFail"), &err) if err == nil { break } @@ -511,7 +511,7 @@ func (s *jobScheduler) deliveryJob(wk *worker, pool *workerPool, job *model.Job) // transitOneJobStepAndWaitSync runs one step of the DDL job, persist it and // waits for other TiDB node to synchronize. func (s *jobScheduler) transitOneJobStepAndWaitSync(wk *worker, job *model.Job) error { - failpoint.InjectCall("beforeRunOneJobStep") + failpoint.Call(_curpkg_("beforeRunOneJobStep")) ownerID := s.ownerManager.ID() // suppose we failed to sync version last time, we need to check and sync it // before run to maintain the 2-version invariant. 
@@ -545,16 +545,16 @@ func (s *jobScheduler) transitOneJobStepAndWaitSync(wk *worker, job *model.Job) tidblogutil.Logger(wk.logCtx).Info("handle ddl job failed", zap.Error(err), zap.Stringer("job", job)) return err } - failpoint.Inject("mockDownBeforeUpdateGlobalVersion", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockDownBeforeUpdateGlobalVersion")); _err_ == nil { if val.(bool) { if mockDDLErrOnce == 0 { mockDDLErrOnce = schemaVer - failpoint.Return(errors.New("mock down before update global version")) + return errors.New("mock down before update global version") } } - }) + } - failpoint.InjectCall("beforeWaitSchemaChanged", job, schemaVer) + failpoint.Call(_curpkg_("beforeWaitSchemaChanged"), job, schemaVer) // Here means the job enters another state (delete only, write only, public, etc...) or is cancelled. // If the job is done or still running or rolling back, we will wait 2 * lease time or util MDL synced to guarantee other servers to update // the newest schema. @@ -564,7 +564,7 @@ func (s *jobScheduler) transitOneJobStepAndWaitSync(wk *worker, job *model.Job) s.cleanMDLInfo(job, ownerID) s.synced(job) - failpoint.InjectCall("onJobUpdated", job) + failpoint.Call(_curpkg_("onJobUpdated"), job) return nil } @@ -626,11 +626,11 @@ const ( ) func insertDDLJobs2Table(ctx context.Context, se *sess.Session, jobWs ...*JobWrapper) error { - failpoint.Inject("mockAddBatchDDLJobsErr", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockAddBatchDDLJobsErr")); _err_ == nil { if val.(bool) { - failpoint.Return(errors.Errorf("mockAddBatchDDLJobsErr")) + return errors.Errorf("mockAddBatchDDLJobsErr") } - }) + } if len(jobWs) == 0 { return nil } diff --git a/pkg/ddl/job_scheduler.go__failpoint_stash__ b/pkg/ddl/job_scheduler.go__failpoint_stash__ new file mode 100644 index 0000000000000..1c3446b36704a --- /dev/null +++ b/pkg/ddl/job_scheduler.go__failpoint_stash__ @@ -0,0 +1,837 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ddl + +import ( + "bytes" + "context" + "encoding/hex" + "encoding/json" + "fmt" + "runtime" + "slices" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/ngaut/pools" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/tidb/pkg/ddl/ingest" + "github.com/pingcap/tidb/pkg/ddl/logutil" + sess "github.com/pingcap/tidb/pkg/ddl/session" + "github.com/pingcap/tidb/pkg/ddl/syncer" + "github.com/pingcap/tidb/pkg/ddl/systable" + "github.com/pingcap/tidb/pkg/ddl/util" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/owner" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/table" + tidbutil "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/intest" + tidblogutil "github.com/pingcap/tidb/pkg/util/logutil" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" +) + +var ( + // addingDDLJobNotifyKey is the key in etcd to notify DDL scheduler that there + // is a new DDL job. + addingDDLJobNotifyKey = "/tidb/ddl/add_ddl_job_general" + dispatchLoopWaitingDuration = 1 * time.Second + schedulerLoopRetryInterval = time.Second +) + +func init() { + // In test the wait duration can be reduced to make test case run faster + if intest.InTest { + dispatchLoopWaitingDuration = 50 * time.Millisecond + } +} + +type jobType int + +func (t jobType) String() string { + switch t { + case jobTypeGeneral: + return "general" + case jobTypeReorg: + return "reorg" + } + return "unknown job type: " + strconv.Itoa(int(t)) +} + +const ( + jobTypeGeneral jobType = iota + jobTypeReorg +) + +type ownerListener struct { + ddl *ddl + scheduler *jobScheduler +} + +var _ owner.Listener = (*ownerListener)(nil) + +func (l *ownerListener) OnBecomeOwner() { + ctx, cancelFunc := context.WithCancel(l.ddl.ddlCtx.ctx) + sysTblMgr := systable.NewManager(l.ddl.sessPool) + l.scheduler = &jobScheduler{ + schCtx: ctx, + cancel: cancelFunc, + runningJobs: newRunningJobs(), + sysTblMgr: sysTblMgr, + schemaLoader: l.ddl.schemaLoader, + minJobIDRefresher: l.ddl.minJobIDRefresher, + + ddlCtx: l.ddl.ddlCtx, + ddlJobNotifyCh: l.ddl.ddlJobNotifyCh, + sessPool: l.ddl.sessPool, + delRangeMgr: l.ddl.delRangeMgr, + } + l.ddl.reorgCtx.setOwnerTS(time.Now().Unix()) + l.scheduler.start() +} + +func (l *ownerListener) OnRetireOwner() { + if l.scheduler == nil { + return + } + l.scheduler.close() +} + +// jobScheduler is used to schedule the DDL jobs, it's only run on the DDL owner. +type jobScheduler struct { + // *ddlCtx already have context named as "ctx", so we use "schCtx" here to avoid confusion. + schCtx context.Context + cancel context.CancelFunc + wg tidbutil.WaitGroupWrapper + runningJobs *runningJobs + sysTblMgr systable.Manager + schemaLoader SchemaLoader + minJobIDRefresher *systable.MinJobIDRefresher + + // those fields are created or initialized on start + reorgWorkerPool *workerPool + generalDDLWorkerPool *workerPool + seqAllocator atomic.Uint64 + + // those fields are shared with 'ddl' instance + // TODO ddlCtx is too large for here, we should remove dependency on it. 
+ *ddlCtx + ddlJobNotifyCh chan struct{} + sessPool *sess.Pool + delRangeMgr delRangeManager +} + +func (s *jobScheduler) start() { + workerFactory := func(tp workerType) func() (pools.Resource, error) { + return func() (pools.Resource, error) { + wk := newWorker(s.schCtx, tp, s.sessPool, s.delRangeMgr, s.ddlCtx) + sessForJob, err := s.sessPool.Get() + if err != nil { + return nil, err + } + wk.seqAllocator = &s.seqAllocator + sessForJob.GetSessionVars().SetDiskFullOpt(kvrpcpb.DiskFullOpt_AllowedOnAlmostFull) + wk.sess = sess.NewSession(sessForJob) + metrics.DDLCounter.WithLabelValues(fmt.Sprintf("%s_%s", metrics.CreateDDL, wk.String())).Inc() + return wk, nil + } + } + // reorg worker count at least 1 at most 10. + reorgCnt := min(max(runtime.GOMAXPROCS(0)/4, 1), reorgWorkerCnt) + s.reorgWorkerPool = newDDLWorkerPool(pools.NewResourcePool(workerFactory(addIdxWorker), reorgCnt, reorgCnt, 0), jobTypeReorg) + s.generalDDLWorkerPool = newDDLWorkerPool(pools.NewResourcePool(workerFactory(generalWorker), generalWorkerCnt, generalWorkerCnt, 0), jobTypeGeneral) + s.wg.RunWithLog(s.scheduleLoop) + s.wg.RunWithLog(func() { + s.schemaSyncer.SyncJobSchemaVerLoop(s.schCtx) + }) +} + +func (s *jobScheduler) close() { + s.cancel() + s.wg.Wait() + if s.reorgWorkerPool != nil { + s.reorgWorkerPool.close() + } + if s.generalDDLWorkerPool != nil { + s.generalDDLWorkerPool.close() + } + failpoint.InjectCall("afterSchedulerClose") +} + +func hasSysDB(job *model.Job) bool { + for _, info := range job.GetInvolvingSchemaInfo() { + if tidbutil.IsSysDB(info.Database) { + return true + } + } + return false +} + +func (s *jobScheduler) processJobDuringUpgrade(sess *sess.Session, job *model.Job) (isRunnable bool, err error) { + if s.stateSyncer.IsUpgradingState() { + if job.IsPaused() { + return false, nil + } + // We need to turn the 'pausing' job to be 'paused' in ddl worker, + // and stop the reorganization workers + if job.IsPausing() || hasSysDB(job) { + return true, nil + } + var errs []error + // During binary upgrade, pause all running DDL jobs + errs, err = PauseJobsBySystem(sess.Session(), []int64{job.ID}) + if len(errs) > 0 && errs[0] != nil { + err = errs[0] + } + + if err != nil { + isCannotPauseDDLJobErr := dbterror.ErrCannotPauseDDLJob.Equal(err) + logutil.DDLUpgradingLogger().Warn("pause the job failed", zap.Stringer("job", job), + zap.Bool("isRunnable", isCannotPauseDDLJobErr), zap.Error(err)) + if isCannotPauseDDLJobErr { + return true, nil + } + } else { + logutil.DDLUpgradingLogger().Warn("pause the job successfully", zap.Stringer("job", job)) + } + + return false, nil + } + + if job.IsPausedBySystem() { + var errs []error + errs, err = ResumeJobsBySystem(sess.Session(), []int64{job.ID}) + if len(errs) > 0 && errs[0] != nil { + logutil.DDLUpgradingLogger().Warn("normal cluster state, resume the job failed", zap.Stringer("job", job), zap.Error(errs[0])) + return false, errs[0] + } + if err != nil { + logutil.DDLUpgradingLogger().Warn("normal cluster state, resume the job failed", zap.Stringer("job", job), zap.Error(err)) + return false, err + } + logutil.DDLUpgradingLogger().Warn("normal cluster state, resume the job successfully", zap.Stringer("job", job)) + return false, errors.Errorf("system paused job:%d need to be resumed", job.ID) + } + + if job.IsPaused() { + return false, nil + } + + return true, nil +} + +func (s *jobScheduler) scheduleLoop() { + const retryInterval = 3 * time.Second + for { + err := s.schedule() + if err == context.Canceled { + logutil.DDLLogger().Info("scheduleLoop quit 
due to context canceled") + return + } + logutil.DDLLogger().Warn("scheduleLoop failed, retrying", + zap.Error(err)) + + select { + case <-s.schCtx.Done(): + logutil.DDLLogger().Info("scheduleLoop quit due to context done") + return + case <-time.After(retryInterval): + } + } +} + +func (s *jobScheduler) schedule() error { + sessCtx, err := s.sessPool.Get() + if err != nil { + return errors.Trace(err) + } + defer s.sessPool.Put(sessCtx) + se := sess.NewSession(sessCtx) + var notifyDDLJobByEtcdCh clientv3.WatchChan + if s.etcdCli != nil { + notifyDDLJobByEtcdCh = s.etcdCli.Watch(s.schCtx, addingDDLJobNotifyKey) + } + if err := s.checkAndUpdateClusterState(true); err != nil { + return errors.Trace(err) + } + ticker := time.NewTicker(dispatchLoopWaitingDuration) + defer ticker.Stop() + // TODO move waitSchemaSyncedController out of ddlCtx. + s.clearOnceMap() + s.mustReloadSchemas() + + for { + if err := s.schCtx.Err(); err != nil { + return err + } + failpoint.Inject("ownerResignAfterDispatchLoopCheck", func() { + if ingest.ResignOwnerForTest.Load() { + err2 := s.ownerManager.ResignOwner(context.Background()) + if err2 != nil { + logutil.DDLLogger().Info("resign meet error", zap.Error(err2)) + } + ingest.ResignOwnerForTest.Store(false) + } + }) + select { + case <-s.ddlJobNotifyCh: + case <-ticker.C: + case _, ok := <-notifyDDLJobByEtcdCh: + if !ok { + logutil.DDLLogger().Warn("start worker watch channel closed", zap.String("watch key", addingDDLJobNotifyKey)) + notifyDDLJobByEtcdCh = s.etcdCli.Watch(s.schCtx, addingDDLJobNotifyKey) + time.Sleep(time.Second) + continue + } + case <-s.schCtx.Done(): + return s.schCtx.Err() + } + if err := s.checkAndUpdateClusterState(false); err != nil { + continue + } + failpoint.InjectCall("beforeLoadAndDeliverJobs") + if err := s.loadAndDeliverJobs(se); err != nil { + logutil.SampleLogger().Warn("load and deliver jobs failed", zap.Error(err)) + } + } +} + +// TODO make it run in a separate routine. +func (s *jobScheduler) checkAndUpdateClusterState(needUpdate bool) error { + select { + case _, ok := <-s.stateSyncer.WatchChan(): + if !ok { + // TODO stateSyncer should only be started when we are the owner, and use + // the context of scheduler, will refactor it later. 
+ s.stateSyncer.Rewatch(s.ddlCtx.ctx) + } + default: + if !needUpdate { + return nil + } + } + + oldState := s.stateSyncer.IsUpgradingState() + stateInfo, err := s.stateSyncer.GetGlobalState(s.schCtx) + if err != nil { + logutil.DDLLogger().Warn("get global state failed", zap.Error(err)) + return errors.Trace(err) + } + logutil.DDLLogger().Info("get global state and global state change", + zap.Bool("oldState", oldState), zap.Bool("currState", s.stateSyncer.IsUpgradingState())) + + ownerOp := owner.OpNone + if stateInfo.State == syncer.StateUpgrading { + ownerOp = owner.OpSyncUpgradingState + } + err = s.ownerManager.SetOwnerOpValue(s.schCtx, ownerOp) + if err != nil { + logutil.DDLLogger().Warn("the owner sets global state to owner operator value failed", zap.Error(err)) + return errors.Trace(err) + } + logutil.DDLLogger().Info("the owner sets owner operator value", zap.Stringer("ownerOp", ownerOp)) + return nil +} + +func (s *jobScheduler) loadAndDeliverJobs(se *sess.Session) error { + if s.generalDDLWorkerPool.available() == 0 && s.reorgWorkerPool.available() == 0 { + return nil + } + + defer s.runningJobs.resetAllPending() + + const getJobSQL = `select reorg, job_meta from mysql.tidb_ddl_job where job_id >= %d %s order by job_id` + var whereClause string + if ids := s.runningJobs.allIDs(); len(ids) > 0 { + whereClause = fmt.Sprintf("and job_id not in (%s)", ids) + } + sql := fmt.Sprintf(getJobSQL, s.minJobIDRefresher.GetCurrMinJobID(), whereClause) + rows, err := se.Execute(context.Background(), sql, "load_ddl_jobs") + if err != nil { + return errors.Trace(err) + } + for _, row := range rows { + reorgJob := row.GetInt64(0) == 1 + targetPool := s.generalDDLWorkerPool + if reorgJob { + targetPool = s.reorgWorkerPool + } + jobBinary := row.GetBytes(1) + + job := model.Job{} + err = job.Decode(jobBinary) + if err != nil { + return errors.Trace(err) + } + + involving := job.GetInvolvingSchemaInfo() + if targetPool.available() == 0 { + s.runningJobs.addPending(involving) + continue + } + + isRunnable, err := s.processJobDuringUpgrade(se, &job) + if err != nil { + return errors.Trace(err) + } + if !isRunnable { + s.runningJobs.addPending(involving) + continue + } + + if !s.runningJobs.checkRunnable(job.ID, involving) { + s.runningJobs.addPending(involving) + continue + } + + wk, err := targetPool.get() + if err != nil { + return errors.Trace(err) + } + intest.Assert(wk != nil, "worker should not be nil") + if wk == nil { + // should not happen, we have checked available() before, and we are + // the only routine consumes worker. + logutil.DDLLogger().Info("no worker available now", zap.Stringer("type", targetPool.tp())) + s.runningJobs.addPending(involving) + continue + } + + s.deliveryJob(wk, targetPool, &job) + + if s.generalDDLWorkerPool.available() == 0 && s.reorgWorkerPool.available() == 0 { + break + } + } + return nil +} + +// mustReloadSchemas is used to reload schema when we become the DDL owner, in case +// the schema version is outdated before we become the owner. +// It will keep reloading schema until either success or context done. +// Domain also have a similar method 'mustReload', but its methods don't accept context. 
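+// Failed attempts are logged and retried every schedulerLoopRetryInterval
+// (one second).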
+func (s *jobScheduler) mustReloadSchemas() { + for { + err := s.schemaLoader.Reload() + if err == nil { + return + } + logutil.DDLLogger().Warn("reload schema failed, will retry later", zap.Error(err)) + select { + case <-s.schCtx.Done(): + return + case <-time.After(schedulerLoopRetryInterval): + } + } +} + +// deliveryJob deliver the job to the worker to run it asynchronously. +// the worker will run the job until it's finished, paused or another owner takes +// over and finished it. +func (s *jobScheduler) deliveryJob(wk *worker, pool *workerPool, job *model.Job) { + failpoint.InjectCall("beforeDeliveryJob", job) + injectFailPointForGetJob(job) + jobID, involvedSchemaInfos := job.ID, job.GetInvolvingSchemaInfo() + s.runningJobs.addRunning(jobID, involvedSchemaInfos) + metrics.DDLRunningJobCount.WithLabelValues(pool.tp().String()).Inc() + s.wg.Run(func() { + defer func() { + r := recover() + if r != nil { + logutil.DDLLogger().Error("panic in deliveryJob", zap.Any("recover", r), zap.Stack("stack")) + } + failpoint.InjectCall("afterDeliveryJob", job) + // Because there is a gap between `allIDs()` and `checkRunnable()`, + // we append unfinished job to pending atomically to prevent `getJob()` + // chosing another runnable job that involves the same schema object. + moveRunningJobsToPending := r != nil || (job != nil && !job.IsFinished()) + s.runningJobs.finishOrPendJob(jobID, involvedSchemaInfos, moveRunningJobsToPending) + asyncNotify(s.ddlJobNotifyCh) + metrics.DDLRunningJobCount.WithLabelValues(pool.tp().String()).Dec() + pool.put(wk) + }() + for { + err := s.transitOneJobStepAndWaitSync(wk, job) + if err != nil { + logutil.DDLLogger().Info("run job failed", zap.Error(err), zap.Stringer("job", job)) + } else if job.InFinalState() { + return + } + // we have to refresh the job, to handle cases like job cancel or pause + // or the job is finished by another owner. + // TODO for JobStateRollbackDone we have to query 1 additional time when the + // job is already moved to history. + failpoint.InjectCall("beforeRefreshJob", job) + for { + job, err = s.sysTblMgr.GetJobByID(s.schCtx, jobID) + failpoint.InjectCall("mockGetJobByIDFail", &err) + if err == nil { + break + } + + if err == systable.ErrNotFound { + logutil.DDLLogger().Info("job not found, might already finished", + zap.Int64("job_id", jobID)) + return + } + logutil.DDLLogger().Error("get job failed", zap.Int64("job_id", jobID), zap.Error(err)) + select { + case <-s.schCtx.Done(): + return + case <-time.After(500 * time.Millisecond): + continue + } + } + } + }) +} + +// transitOneJobStepAndWaitSync runs one step of the DDL job, persist it and +// waits for other TiDB node to synchronize. +func (s *jobScheduler) transitOneJobStepAndWaitSync(wk *worker, job *model.Job) error { + failpoint.InjectCall("beforeRunOneJobStep") + ownerID := s.ownerManager.ID() + // suppose we failed to sync version last time, we need to check and sync it + // before run to maintain the 2-version invariant. 
+ if !job.NotStarted() && (!s.isSynced(job) || !s.maybeAlreadyRunOnce(job.ID)) { + if variable.EnableMDL.Load() { + version, err := s.sysTblMgr.GetMDLVer(s.schCtx, job.ID) + if err == nil { + err = waitSchemaSyncedForMDL(wk.ctx, s.ddlCtx, job, version) + if err != nil { + return err + } + s.setAlreadyRunOnce(job.ID) + s.cleanMDLInfo(job, ownerID) + return nil + } else if err != systable.ErrNotFound { + wk.jobLogger(job).Warn("check MDL info failed", zap.Error(err)) + return err + } + } else { + err := waitSchemaSynced(wk.ctx, s.ddlCtx, job) + if err != nil { + time.Sleep(time.Second) + return err + } + s.setAlreadyRunOnce(job.ID) + } + } + + schemaVer, err := wk.transitOneJobStep(s.ddlCtx, job) + if err != nil { + tidblogutil.Logger(wk.logCtx).Info("handle ddl job failed", zap.Error(err), zap.Stringer("job", job)) + return err + } + failpoint.Inject("mockDownBeforeUpdateGlobalVersion", func(val failpoint.Value) { + if val.(bool) { + if mockDDLErrOnce == 0 { + mockDDLErrOnce = schemaVer + failpoint.Return(errors.New("mock down before update global version")) + } + } + }) + + failpoint.InjectCall("beforeWaitSchemaChanged", job, schemaVer) + // Here means the job enters another state (delete only, write only, public, etc...) or is cancelled. + // If the job is done or still running or rolling back, we will wait 2 * lease time or util MDL synced to guarantee other servers to update + // the newest schema. + if err = waitSchemaChanged(wk.ctx, s.ddlCtx, schemaVer, job); err != nil { + return err + } + s.cleanMDLInfo(job, ownerID) + s.synced(job) + + failpoint.InjectCall("onJobUpdated", job) + return nil +} + +// cleanMDLInfo cleans metadata lock info. +func (s *jobScheduler) cleanMDLInfo(job *model.Job, ownerID string) { + if !variable.EnableMDL.Load() { + return + } + var sql string + if tidbutil.IsSysDB(strings.ToLower(job.SchemaName)) { + // DDLs that modify system tables could only happen in upgrade process, + // we should not reference 'owner_id'. Otherwise, there is a circular blocking problem. + sql = fmt.Sprintf("delete from mysql.tidb_mdl_info where job_id = %d", job.ID) + } else { + sql = fmt.Sprintf("delete from mysql.tidb_mdl_info where job_id = %d and owner_id = '%s'", job.ID, ownerID) + } + sctx, _ := s.sessPool.Get() + defer s.sessPool.Put(sctx) + se := sess.NewSession(sctx) + se.GetSessionVars().SetDiskFullOpt(kvrpcpb.DiskFullOpt_AllowedOnAlmostFull) + _, err := se.Execute(s.schCtx, sql, "delete-mdl-info") + if err != nil { + logutil.DDLLogger().Warn("unexpected error when clean mdl info", zap.Int64("job ID", job.ID), zap.Error(err)) + return + } + // TODO we need clean it when version of JobStateRollbackDone is synced also. 
+ if job.State == model.JobStateSynced && s.etcdCli != nil { + path := fmt.Sprintf("%s/%d/", util.DDLAllSchemaVersionsByJob, job.ID) + _, err = s.etcdCli.Delete(s.schCtx, path, clientv3.WithPrefix()) + if err != nil { + logutil.DDLLogger().Warn("delete versions failed", zap.Int64("job ID", job.ID), zap.Error(err)) + } + } +} + +func (d *ddl) getTableByTxn(r autoid.Requirement, schemaID, tableID int64) (*model.DBInfo, table.Table, error) { + var tbl table.Table + var dbInfo *model.DBInfo + err := kv.RunInNewTxn(d.ctx, r.Store(), false, func(_ context.Context, txn kv.Transaction) error { + t := meta.NewMeta(txn) + var err1 error + dbInfo, err1 = t.GetDatabase(schemaID) + if err1 != nil { + return errors.Trace(err1) + } + tblInfo, err1 := getTableInfo(t, tableID, schemaID) + if err1 != nil { + return errors.Trace(err1) + } + tbl, err1 = getTable(r, schemaID, tblInfo) + return errors.Trace(err1) + }) + return dbInfo, tbl, err +} + +const ( + addDDLJobSQL = "insert into mysql.tidb_ddl_job(job_id, reorg, schema_ids, table_ids, job_meta, type, processing) values" + updateDDLJobSQL = "update mysql.tidb_ddl_job set job_meta = %s where job_id = %d" +) + +func insertDDLJobs2Table(ctx context.Context, se *sess.Session, jobWs ...*JobWrapper) error { + failpoint.Inject("mockAddBatchDDLJobsErr", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(errors.Errorf("mockAddBatchDDLJobsErr")) + } + }) + if len(jobWs) == 0 { + return nil + } + var sql bytes.Buffer + sql.WriteString(addDDLJobSQL) + for i, jobW := range jobWs { + b, err := jobW.Encode(true) + if err != nil { + return err + } + if i != 0 { + sql.WriteString(",") + } + fmt.Fprintf(&sql, "(%d, %t, %s, %s, %s, %d, %t)", jobW.ID, jobW.MayNeedReorg(), + strconv.Quote(job2SchemaIDs(jobW.Job)), strconv.Quote(job2TableIDs(jobW.Job)), + util.WrapKey2String(b), jobW.Type, !jobW.NotStarted()) + } + se.GetSessionVars().SetDiskFullOpt(kvrpcpb.DiskFullOpt_AllowedOnAlmostFull) + _, err := se.Execute(ctx, sql.String(), "insert_job") + logutil.DDLLogger().Debug("add job to mysql.tidb_ddl_job table", zap.String("sql", sql.String())) + return errors.Trace(err) +} + +func job2SchemaIDs(job *model.Job) string { + return job2UniqueIDs(job, true) +} + +func job2TableIDs(job *model.Job) string { + return job2UniqueIDs(job, false) +} + +func job2UniqueIDs(job *model.Job, schema bool) string { + switch job.Type { + case model.ActionExchangeTablePartition, model.ActionRenameTables, model.ActionRenameTable: + var ids []int64 + if schema { + ids = job.CtxVars[0].([]int64) + } else { + ids = job.CtxVars[1].([]int64) + } + set := make(map[int64]struct{}, len(ids)) + for _, id := range ids { + set[id] = struct{}{} + } + + s := make([]string, 0, len(set)) + for id := range set { + s = append(s, strconv.FormatInt(id, 10)) + } + slices.Sort(s) + return strings.Join(s, ",") + case model.ActionTruncateTable: + if schema { + return strconv.FormatInt(job.SchemaID, 10) + } + return strconv.FormatInt(job.TableID, 10) + "," + strconv.FormatInt(job.Args[0].(int64), 10) + } + if schema { + return strconv.FormatInt(job.SchemaID, 10) + } + return strconv.FormatInt(job.TableID, 10) +} + +func updateDDLJob2Table(se *sess.Session, job *model.Job, updateRawArgs bool) error { + b, err := job.Encode(updateRawArgs) + if err != nil { + return err + } + sql := fmt.Sprintf(updateDDLJobSQL, util.WrapKey2String(b), job.ID) + _, err = se.Execute(context.Background(), sql, "update_job") + return errors.Trace(err) +} + +// getDDLReorgHandle gets DDL reorg handle. 
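+// It reads the element (ele_id/ele_type), the start and end keys, and the
+// physical table ID from the job's row in mysql.tidb_ddl_reorg, returning
+// meta.ErrDDLReorgElementNotExist when no row is present.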
+func getDDLReorgHandle(se *sess.Session, job *model.Job) (element *meta.Element, + startKey, endKey kv.Key, physicalTableID int64, err error) { + sql := fmt.Sprintf("select ele_id, ele_type, start_key, end_key, physical_id, reorg_meta from mysql.tidb_ddl_reorg where job_id = %d", job.ID) + ctx := kv.WithInternalSourceType(context.Background(), getDDLRequestSource(job.Type)) + rows, err := se.Execute(ctx, sql, "get_handle") + if err != nil { + return nil, nil, nil, 0, err + } + if len(rows) == 0 { + return nil, nil, nil, 0, meta.ErrDDLReorgElementNotExist + } + id := rows[0].GetInt64(0) + tp := rows[0].GetBytes(1) + element = &meta.Element{ + ID: id, + TypeKey: tp, + } + startKey = rows[0].GetBytes(2) + endKey = rows[0].GetBytes(3) + physicalTableID = rows[0].GetInt64(4) + return +} + +func getImportedKeyFromCheckpoint(se *sess.Session, job *model.Job) (imported kv.Key, physicalTableID int64, err error) { + sql := fmt.Sprintf("select reorg_meta from mysql.tidb_ddl_reorg where job_id = %d", job.ID) + ctx := kv.WithInternalSourceType(context.Background(), getDDLRequestSource(job.Type)) + rows, err := se.Execute(ctx, sql, "get_handle") + if err != nil { + return nil, 0, err + } + if len(rows) == 0 { + return nil, 0, meta.ErrDDLReorgElementNotExist + } + if !rows[0].IsNull(0) { + rawReorgMeta := rows[0].GetBytes(0) + var reorgMeta ingest.JobReorgMeta + err = json.Unmarshal(rawReorgMeta, &reorgMeta) + if err != nil { + return nil, 0, errors.Trace(err) + } + if cp := reorgMeta.Checkpoint; cp != nil { + logutil.DDLIngestLogger().Info("resume physical table ID from checkpoint", + zap.Int64("jobID", job.ID), + zap.String("global sync key", hex.EncodeToString(cp.GlobalSyncKey)), + zap.Int64("checkpoint physical ID", cp.PhysicalID)) + return cp.GlobalSyncKey, cp.PhysicalID, nil + } + } + return +} + +// updateDDLReorgHandle update startKey, endKey physicalTableID and element of the handle. +// Caller should wrap this in a separate transaction, to avoid conflicts. +func updateDDLReorgHandle(se *sess.Session, jobID int64, startKey kv.Key, endKey kv.Key, physicalTableID int64, element *meta.Element) error { + sql := fmt.Sprintf("update mysql.tidb_ddl_reorg set ele_id = %d, ele_type = %s, start_key = %s, end_key = %s, physical_id = %d where job_id = %d", + element.ID, util.WrapKey2String(element.TypeKey), util.WrapKey2String(startKey), util.WrapKey2String(endKey), physicalTableID, jobID) + _, err := se.Execute(context.Background(), sql, "update_handle") + return err +} + +// initDDLReorgHandle initializes the handle for ddl reorg. 
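+// It runs a delete-then-insert in one transaction so a restarted job always
+// starts from a single fresh row, seeding reorg_meta with a JSON-encoded
+// ReorgCheckpoint that records the physical table ID and checkpoint version.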
+func initDDLReorgHandle(s *sess.Session, jobID int64, startKey kv.Key, endKey kv.Key, physicalTableID int64, element *meta.Element) error { + rawReorgMeta, err := json.Marshal(ingest.JobReorgMeta{ + Checkpoint: &ingest.ReorgCheckpoint{ + PhysicalID: physicalTableID, + Version: ingest.JobCheckpointVersionCurrent, + }}) + if err != nil { + return errors.Trace(err) + } + del := fmt.Sprintf("delete from mysql.tidb_ddl_reorg where job_id = %d", jobID) + ins := fmt.Sprintf("insert into mysql.tidb_ddl_reorg(job_id, ele_id, ele_type, start_key, end_key, physical_id, reorg_meta) values (%d, %d, %s, %s, %s, %d, %s)", + jobID, element.ID, util.WrapKey2String(element.TypeKey), util.WrapKey2String(startKey), util.WrapKey2String(endKey), physicalTableID, util.WrapKey2String(rawReorgMeta)) + return s.RunInTxn(func(se *sess.Session) error { + _, err := se.Execute(context.Background(), del, "init_handle") + if err != nil { + logutil.DDLLogger().Info("initDDLReorgHandle failed to delete", zap.Int64("jobID", jobID), zap.Error(err)) + } + _, err = se.Execute(context.Background(), ins, "init_handle") + return err + }) +} + +// deleteDDLReorgHandle deletes the handle for ddl reorg. +func removeDDLReorgHandle(se *sess.Session, job *model.Job, elements []*meta.Element) error { + if len(elements) == 0 { + return nil + } + sql := fmt.Sprintf("delete from mysql.tidb_ddl_reorg where job_id = %d", job.ID) + return se.RunInTxn(func(se *sess.Session) error { + _, err := se.Execute(context.Background(), sql, "remove_handle") + return err + }) +} + +// removeReorgElement removes the element from ddl reorg, it is the same with removeDDLReorgHandle, only used in failpoint +func removeReorgElement(se *sess.Session, job *model.Job) error { + sql := fmt.Sprintf("delete from mysql.tidb_ddl_reorg where job_id = %d", job.ID) + return se.RunInTxn(func(se *sess.Session) error { + _, err := se.Execute(context.Background(), sql, "remove_handle") + return err + }) +} + +// cleanDDLReorgHandles removes handles that are no longer needed. 
+func cleanDDLReorgHandles(se *sess.Session, job *model.Job) error { + sql := "delete from mysql.tidb_ddl_reorg where job_id = " + strconv.FormatInt(job.ID, 10) + return se.RunInTxn(func(se *sess.Session) error { + _, err := se.Execute(context.Background(), sql, "clean_handle") + return err + }) +} + +func getJobsBySQL(se *sess.Session, tbl, condition string) ([]*model.Job, error) { + rows, err := se.Execute(context.Background(), fmt.Sprintf("select job_meta from mysql.%s where %s", tbl, condition), "get_job") + if err != nil { + return nil, errors.Trace(err) + } + jobs := make([]*model.Job, 0, 16) + for _, row := range rows { + jobBinary := row.GetBytes(0) + job := model.Job{} + err := job.Decode(jobBinary) + if err != nil { + return nil, errors.Trace(err) + } + jobs = append(jobs, &job) + } + return jobs, nil +} diff --git a/pkg/ddl/job_submitter.go b/pkg/ddl/job_submitter.go index da78b30bc5a3e..dd99d93376ca9 100644 --- a/pkg/ddl/job_submitter.go +++ b/pkg/ddl/job_submitter.go @@ -55,7 +55,7 @@ func (d *ddl) limitDDLJobs() { // the channel is never closed case jobW := <-ch: jobWs = jobWs[:0] - failpoint.InjectCall("afterGetJobFromLimitCh", ch) + failpoint.Call(_curpkg_("afterGetJobFromLimitCh"), ch) jobLen := len(ch) jobWs = append(jobWs, jobW) for i := 0; i < jobLen; i++ { @@ -369,11 +369,11 @@ func (d *ddl) addBatchDDLJobs2Queue(jobWs []*JobWrapper) error { return errors.Trace(err) } } - failpoint.Inject("mockAddBatchDDLJobsErr", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockAddBatchDDLJobsErr")); _err_ == nil { if val.(bool) { - failpoint.Return(errors.Errorf("mockAddBatchDDLJobsErr")) + return errors.Errorf("mockAddBatchDDLJobsErr") } - }) + } return nil }) } @@ -399,11 +399,11 @@ func (*ddl) checkFlashbackJobInQueue(t *meta.Meta) error { func GenGIDAndInsertJobsWithRetry(ctx context.Context, ddlSe *sess.Session, jobWs []*JobWrapper) error { count := getRequiredGIDCount(jobWs) return genGIDAndCallWithRetry(ctx, ddlSe, count, func(ids []int64) error { - failpoint.Inject("mockGenGlobalIDFail", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockGenGlobalIDFail")); _err_ == nil { if val.(bool) { - failpoint.Return(errors.New("gofail genGlobalIDs error")) + return errors.New("gofail genGlobalIDs error") } - }) + } assignGIDsForJobs(jobWs, ids) injectModifyJobArgFailPoint(jobWs) return insertDDLJobs2Table(ctx, ddlSe, jobWs...) @@ -578,7 +578,7 @@ func lockGlobalIDKey(ctx context.Context, ddlSe *sess.Session, txn kv.Transactio // TODO this failpoint is only checking how job scheduler handle // corrupted job args, we should test it there by UT, not here. func injectModifyJobArgFailPoint(jobWs []*JobWrapper) { - failpoint.Inject("MockModifyJobArg", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("MockModifyJobArg")); _err_ == nil { if val.(bool) { for _, jobW := range jobWs { job := jobW.Job @@ -592,7 +592,7 @@ func injectModifyJobArgFailPoint(jobWs []*JobWrapper) { } } } - }) + } } func setJobStateToQueueing(job *model.Job) { diff --git a/pkg/ddl/job_submitter.go__failpoint_stash__ b/pkg/ddl/job_submitter.go__failpoint_stash__ new file mode 100644 index 0000000000000..da78b30bc5a3e --- /dev/null +++ b/pkg/ddl/job_submitter.go__failpoint_stash__ @@ -0,0 +1,669 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl + +import ( + "context" + "fmt" + "math" + "strconv" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl/logutil" + sess "github.com/pingcap/tidb/pkg/ddl/session" + ddlutil "github.com/pingcap/tidb/pkg/ddl/util" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/mathutil" + tikv "github.com/tikv/client-go/v2/kv" + "github.com/tikv/client-go/v2/oracle" + "go.uber.org/zap" +) + +func (d *ddl) limitDDLJobs() { + defer util.Recover(metrics.LabelDDL, "limitDDLJobs", nil, true) + + jobWs := make([]*JobWrapper, 0, batchAddingJobs) + ch := d.limitJobCh + for { + select { + // the channel is never closed + case jobW := <-ch: + jobWs = jobWs[:0] + failpoint.InjectCall("afterGetJobFromLimitCh", ch) + jobLen := len(ch) + jobWs = append(jobWs, jobW) + for i := 0; i < jobLen; i++ { + jobWs = append(jobWs, <-ch) + } + d.addBatchDDLJobs(jobWs) + case <-d.ctx.Done(): + return + } + } +} + +// addBatchDDLJobs gets global job IDs and puts the DDL jobs in the DDL queue. +func (d *ddl) addBatchDDLJobs(jobWs []*JobWrapper) { + startTime := time.Now() + var ( + err error + newWs []*JobWrapper + ) + // DDLForce2Queue is a flag to tell DDL worker to always push the job to the DDL queue. + toTable := !variable.DDLForce2Queue.Load() + fastCreate := variable.EnableFastCreateTable.Load() + if toTable { + if fastCreate { + newWs, err = mergeCreateTableJobs(jobWs) + if err != nil { + logutil.DDLLogger().Warn("failed to merge create table jobs", zap.Error(err)) + } else { + jobWs = newWs + } + } + err = d.addBatchDDLJobs2Table(jobWs) + } else { + err = d.addBatchDDLJobs2Queue(jobWs) + } + var jobs string + for _, jobW := range jobWs { + if err == nil { + err = jobW.cacheErr + } + jobW.NotifyResult(err) + jobs += jobW.Job.String() + "; " + metrics.DDLWorkerHistogram.WithLabelValues(metrics.WorkerAddDDLJob, jobW.Job.Type.String(), + metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) + } + if err != nil { + logutil.DDLLogger().Warn("add DDL jobs failed", zap.String("jobs", jobs), zap.Error(err)) + } else { + logutil.DDLLogger().Info("add DDL jobs", + zap.Int("batch count", len(jobWs)), + zap.String("jobs", jobs), + zap.Bool("table", toTable), + zap.Bool("fast_create", fastCreate)) + } +} + +// mergeCreateTableJobs merges CreateTable jobs to CreateTables. 
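+// Jobs with pre-allocated IDs or foreign keys are passed through unchanged; the
+// rest are grouped by schema and combined into ActionCreateTables jobs of at
+// most 8 tables each, with the per-job result channels merged so every original
+// caller is still notified.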
+func mergeCreateTableJobs(jobWs []*JobWrapper) ([]*JobWrapper, error) { + if len(jobWs) <= 1 { + return jobWs, nil + } + resJobWs := make([]*JobWrapper, 0, len(jobWs)) + mergeableJobWs := make(map[string][]*JobWrapper, len(jobWs)) + for _, jobW := range jobWs { + // we don't merge jobs with ID pre-allocated. + if jobW.Type != model.ActionCreateTable || jobW.IDAllocated { + resJobWs = append(resJobWs, jobW) + continue + } + // ActionCreateTables doesn't support foreign key now. + tbInfo, ok := jobW.Args[0].(*model.TableInfo) + if !ok || len(tbInfo.ForeignKeys) > 0 { + resJobWs = append(resJobWs, jobW) + continue + } + // CreateTables only support tables of same schema now. + mergeableJobWs[jobW.Job.SchemaName] = append(mergeableJobWs[jobW.Job.SchemaName], jobW) + } + + for schema, jobs := range mergeableJobWs { + total := len(jobs) + if total <= 1 { + resJobWs = append(resJobWs, jobs...) + continue + } + const maxBatchSize = 8 + batchCount := (total + maxBatchSize - 1) / maxBatchSize + start := 0 + for _, batchSize := range mathutil.Divide2Batches(total, batchCount) { + batch := jobs[start : start+batchSize] + job, err := mergeCreateTableJobsOfSameSchema(batch) + if err != nil { + return nil, err + } + start += batchSize + logutil.DDLLogger().Info("merge create table jobs", zap.String("schema", schema), + zap.Int("total", total), zap.Int("batch_size", batchSize)) + + newJobW := &JobWrapper{ + Job: job, + ResultCh: make([]chan jobSubmitResult, 0, batchSize), + } + // merge the result channels. + for _, j := range batch { + newJobW.ResultCh = append(newJobW.ResultCh, j.ResultCh...) + } + resJobWs = append(resJobWs, newJobW) + } + } + return resJobWs, nil +} + +// buildQueryStringFromJobs takes a slice of Jobs and concatenates their +// queries into a single query string. +// Each query is separated by a semicolon and a space. +// Trailing spaces are removed from each query, and a semicolon is appended +// if it's not already present. +func buildQueryStringFromJobs(jobs []*JobWrapper) string { + var queryBuilder strings.Builder + for i, job := range jobs { + q := strings.TrimSpace(job.Query) + if !strings.HasSuffix(q, ";") { + q += ";" + } + queryBuilder.WriteString(q) + + if i < len(jobs)-1 { + queryBuilder.WriteString(" ") + } + } + return queryBuilder.String() +} + +// mergeCreateTableJobsOfSameSchema combine CreateTableJobs to BatchCreateTableJob. 
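+// The combined job reuses the first job's foreign_key_checks setting, rejects
+// duplicate table names (even for CREATE TABLE IF NOT EXISTS), and rebuilds the
+// Query field by concatenating the individual statements.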
+func mergeCreateTableJobsOfSameSchema(jobWs []*JobWrapper) (*model.Job, error) { + if len(jobWs) == 0 { + return nil, errors.Trace(fmt.Errorf("expect non-empty jobs")) + } + + var combinedJob *model.Job + + args := make([]*model.TableInfo, 0, len(jobWs)) + involvingSchemaInfo := make([]model.InvolvingSchemaInfo, 0, len(jobWs)) + var foreignKeyChecks bool + + // if there is any duplicated table name + duplication := make(map[string]struct{}) + for _, job := range jobWs { + if combinedJob == nil { + combinedJob = job.Clone() + combinedJob.Type = model.ActionCreateTables + combinedJob.Args = combinedJob.Args[:0] + foreignKeyChecks = job.Args[1].(bool) + } + // append table job args + info, ok := job.Args[0].(*model.TableInfo) + if !ok { + return nil, errors.Trace(fmt.Errorf("expect model.TableInfo, but got %T", job.Args[0])) + } + args = append(args, info) + + if _, ok := duplication[info.Name.L]; ok { + // return err even if create table if not exists + return nil, infoschema.ErrTableExists.FastGenByArgs("can not batch create tables with same name") + } + + duplication[info.Name.L] = struct{}{} + + involvingSchemaInfo = append(involvingSchemaInfo, + model.InvolvingSchemaInfo{ + Database: job.SchemaName, + Table: info.Name.L, + }) + } + + combinedJob.Args = append(combinedJob.Args, args) + combinedJob.Args = append(combinedJob.Args, foreignKeyChecks) + combinedJob.InvolvingSchemaInfo = involvingSchemaInfo + combinedJob.Query = buildQueryStringFromJobs(jobWs) + + return combinedJob, nil +} + +// addBatchDDLJobs2Table gets global job IDs and puts the DDL jobs in the DDL job table. +func (d *ddl) addBatchDDLJobs2Table(jobWs []*JobWrapper) error { + var err error + + if len(jobWs) == 0 { + return nil + } + + ctx := kv.WithInternalSourceType(d.ctx, kv.InternalTxnDDL) + se, err := d.sessPool.Get() + if err != nil { + return errors.Trace(err) + } + defer d.sessPool.Put(se) + found, err := d.sysTblMgr.HasFlashbackClusterJob(ctx, d.minJobIDRefresher.GetCurrMinJobID()) + if err != nil { + return errors.Trace(err) + } + if found { + return errors.Errorf("Can't add ddl job, have flashback cluster job") + } + + var ( + startTS = uint64(0) + bdrRole = string(ast.BDRRoleNone) + ) + + err = kv.RunInNewTxn(ctx, d.store, true, func(_ context.Context, txn kv.Transaction) error { + t := meta.NewMeta(txn) + + bdrRole, err = t.GetBDRRole() + if err != nil { + return errors.Trace(err) + } + startTS = txn.StartTS() + + if variable.DDLForce2Queue.Load() { + if err := d.checkFlashbackJobInQueue(t); err != nil { + return err + } + } + + return nil + }) + if err != nil { + return errors.Trace(err) + } + + for _, jobW := range jobWs { + job := jobW.Job + job.Version = currentVersion + job.StartTS = startTS + job.BDRRole = bdrRole + + // BDR mode only affects the DDL not from CDC + if job.CDCWriteSource == 0 && bdrRole != string(ast.BDRRoleNone) { + if job.Type == model.ActionMultiSchemaChange && job.MultiSchemaInfo != nil { + for _, subJob := range job.MultiSchemaInfo.SubJobs { + if ast.DeniedByBDR(ast.BDRRole(bdrRole), subJob.Type, job) { + return dbterror.ErrBDRRestrictedDDL.FastGenByArgs(bdrRole) + } + } + } else if ast.DeniedByBDR(ast.BDRRole(bdrRole), job.Type, job) { + return dbterror.ErrBDRRestrictedDDL.FastGenByArgs(bdrRole) + } + } + + setJobStateToQueueing(job) + + if d.stateSyncer.IsUpgradingState() && !hasSysDB(job) { + if err = pauseRunningJob(sess.NewSession(se), job, model.AdminCommandBySystem); err != nil { + logutil.DDLUpgradingLogger().Warn("pause user DDL by system failed", zap.Stringer("job", job), 
zap.Error(err)) + jobW.cacheErr = err + continue + } + logutil.DDLUpgradingLogger().Info("pause user DDL by system successful", zap.Stringer("job", job)) + } + } + + se.GetSessionVars().SetDiskFullOpt(kvrpcpb.DiskFullOpt_AllowedOnAlmostFull) + ddlSe := sess.NewSession(se) + if err = GenGIDAndInsertJobsWithRetry(ctx, ddlSe, jobWs); err != nil { + return errors.Trace(err) + } + for _, jobW := range jobWs { + d.initJobDoneCh(jobW.ID) + } + + return nil +} + +func (d *ddl) initJobDoneCh(jobID int64) { + d.ddlJobDoneChMap.Store(jobID, make(chan struct{}, 1)) +} + +func (d *ddl) addBatchDDLJobs2Queue(jobWs []*JobWrapper) error { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) + // lock to reduce conflict + d.globalIDLock.Lock() + defer d.globalIDLock.Unlock() + return kv.RunInNewTxn(ctx, d.store, true, func(_ context.Context, txn kv.Transaction) error { + t := meta.NewMeta(txn) + + count := getRequiredGIDCount(jobWs) + ids, err := t.GenGlobalIDs(count) + if err != nil { + return errors.Trace(err) + } + assignGIDsForJobs(jobWs, ids) + + if err := d.checkFlashbackJobInQueue(t); err != nil { + return errors.Trace(err) + } + + for _, jobW := range jobWs { + job := jobW.Job + job.Version = currentVersion + job.StartTS = txn.StartTS() + setJobStateToQueueing(job) + if err = buildJobDependence(t, job); err != nil { + return errors.Trace(err) + } + jobListKey := meta.DefaultJobListKey + if job.MayNeedReorg() { + jobListKey = meta.AddIndexJobListKey + } + if err = t.EnQueueDDLJob(job, jobListKey); err != nil { + return errors.Trace(err) + } + } + failpoint.Inject("mockAddBatchDDLJobsErr", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(errors.Errorf("mockAddBatchDDLJobsErr")) + } + }) + return nil + }) +} + +func (*ddl) checkFlashbackJobInQueue(t *meta.Meta) error { + jobs, err := t.GetAllDDLJobsInQueue(meta.DefaultJobListKey) + if err != nil { + return errors.Trace(err) + } + for _, job := range jobs { + if job.Type == model.ActionFlashbackCluster { + return errors.Errorf("Can't add ddl job, have flashback cluster job") + } + } + return nil +} + +// GenGIDAndInsertJobsWithRetry generate job related global ID and inserts DDL jobs to the DDL job +// table with retry. job id allocation and job insertion are in the same transaction, +// as we want to make sure DDL jobs are inserted in id order, then we can query from +// a min job ID when scheduling DDL jobs to mitigate https://github.com/pingcap/tidb/issues/52905. +// so this function has side effect, it will set table/db/job id of 'jobs'. +func GenGIDAndInsertJobsWithRetry(ctx context.Context, ddlSe *sess.Session, jobWs []*JobWrapper) error { + count := getRequiredGIDCount(jobWs) + return genGIDAndCallWithRetry(ctx, ddlSe, count, func(ids []int64) error { + failpoint.Inject("mockGenGlobalIDFail", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(errors.New("gofail genGlobalIDs error")) + } + }) + assignGIDsForJobs(jobWs, ids) + injectModifyJobArgFailPoint(jobWs) + return insertDDLJobs2Table(ctx, ddlSe, jobWs...) + }) +} + +// getRequiredGIDCount returns the count of required global IDs for the jobs. it's calculated +// as: the count of jobs + the count of IDs for the jobs which do NOT have pre-allocated ID. 
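+// For example (illustrative only): a CREATE TABLE job without pre-allocated IDs for
+// a table with 4 partitions needs 1 (job ID) + 1 (table ID) + 4 (partition IDs) = 6
+// global IDs, while a CREATE SCHEMA job needs 1 (job ID) + 1 (schema ID) = 2.
+// assignGIDsForJobs below must then be called with exactly that many IDs.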
+func getRequiredGIDCount(jobWs []*JobWrapper) int { + count := len(jobWs) + idCountForTable := func(info *model.TableInfo) int { + c := 1 + if partitionInfo := info.GetPartitionInfo(); partitionInfo != nil { + c += len(partitionInfo.Definitions) + } + return c + } + for _, jobW := range jobWs { + if jobW.IDAllocated { + continue + } + switch jobW.Type { + case model.ActionCreateView, model.ActionCreateSequence, model.ActionCreateTable: + info := jobW.Args[0].(*model.TableInfo) + count += idCountForTable(info) + case model.ActionCreateTables: + infos := jobW.Args[0].([]*model.TableInfo) + for _, info := range infos { + count += idCountForTable(info) + } + case model.ActionCreateSchema: + count++ + } + // TODO support other type of jobs + } + return count +} + +// assignGIDsForJobs should be used with getRequiredGIDCount, and len(ids) must equal +// what getRequiredGIDCount returns. +func assignGIDsForJobs(jobWs []*JobWrapper, ids []int64) { + idx := 0 + + assignIDsForTable := func(info *model.TableInfo) { + info.ID = ids[idx] + idx++ + if partitionInfo := info.GetPartitionInfo(); partitionInfo != nil { + for i := range partitionInfo.Definitions { + partitionInfo.Definitions[i].ID = ids[idx] + idx++ + } + } + } + for _, jobW := range jobWs { + switch jobW.Type { + case model.ActionCreateView, model.ActionCreateSequence, model.ActionCreateTable: + info := jobW.Args[0].(*model.TableInfo) + if !jobW.IDAllocated { + assignIDsForTable(info) + } + jobW.TableID = info.ID + case model.ActionCreateTables: + if !jobW.IDAllocated { + infos := jobW.Args[0].([]*model.TableInfo) + for _, info := range infos { + assignIDsForTable(info) + } + } + case model.ActionCreateSchema: + dbInfo := jobW.Args[0].(*model.DBInfo) + if !jobW.IDAllocated { + dbInfo.ID = ids[idx] + idx++ + } + jobW.SchemaID = dbInfo.ID + } + // TODO support other type of jobs + jobW.ID = ids[idx] + idx++ + } +} + +// genGIDAndCallWithRetry generates global IDs and calls the function with retry. +// generate ID and call function runs in the same transaction. +func genGIDAndCallWithRetry(ctx context.Context, ddlSe *sess.Session, count int, fn func(ids []int64) error) error { + var resErr error + for i := uint(0); i < kv.MaxRetryCnt; i++ { + resErr = func() (err error) { + if err := ddlSe.Begin(ctx); err != nil { + return errors.Trace(err) + } + defer func() { + if err != nil { + ddlSe.Rollback() + } + }() + txn, err := ddlSe.Txn() + if err != nil { + return errors.Trace(err) + } + txn.SetOption(kv.Pessimistic, true) + forUpdateTS, err := lockGlobalIDKey(ctx, ddlSe, txn) + if err != nil { + return errors.Trace(err) + } + txn.GetSnapshot().SetOption(kv.SnapshotTS, forUpdateTS) + + m := meta.NewMeta(txn) + ids, err := m.GenGlobalIDs(count) + if err != nil { + return errors.Trace(err) + } + if err = fn(ids); err != nil { + return errors.Trace(err) + } + return ddlSe.Commit(ctx) + }() + + if resErr != nil && kv.IsTxnRetryableError(resErr) { + logutil.DDLLogger().Warn("insert job meet retryable error", zap.Error(resErr)) + kv.BackOff(i) + continue + } + break + } + return resErr +} + +// lockGlobalIDKey locks the global ID key in the meta store. it keeps trying if +// meet write conflict, we cannot have a fixed retry count for this error, see this +// https://github.com/pingcap/tidb/issues/27197#issuecomment-2216315057. +// this part is same as how we implement pessimistic + repeatable read isolation +// level in SQL executor, see doLockKeys. 
+// NextGlobalID is a meta key, so we cannot use "select xx for update", if we store +// it into a table row or using advisory lock, we will depends on a system table +// that is created by us, cyclic. although we can create a system table without using +// DDL logic, we will only consider change it when we have data dictionary and keep +// it this way now. +// TODO maybe we can unify the lock mechanism with SQL executor in the future, or +// implement it inside TiKV client-go. +func lockGlobalIDKey(ctx context.Context, ddlSe *sess.Session, txn kv.Transaction) (uint64, error) { + var ( + iteration uint + forUpdateTs = txn.StartTS() + ver kv.Version + err error + ) + waitTime := ddlSe.GetSessionVars().LockWaitTimeout + m := meta.NewMeta(txn) + idKey := m.GlobalIDKey() + for { + lockCtx := tikv.NewLockCtx(forUpdateTs, waitTime, time.Now()) + err = txn.LockKeys(ctx, lockCtx, idKey) + if err == nil || !terror.ErrorEqual(kv.ErrWriteConflict, err) { + break + } + // ErrWriteConflict contains a conflict-commit-ts in most case, but it cannot + // be used as forUpdateTs, see comments inside handleAfterPessimisticLockError + ver, err = ddlSe.GetStore().CurrentVersion(oracle.GlobalTxnScope) + if err != nil { + break + } + forUpdateTs = ver.Ver + + kv.BackOff(iteration) + // avoid it keep growing and overflow. + iteration = min(iteration+1, math.MaxInt) + } + return forUpdateTs, err +} + +// TODO this failpoint is only checking how job scheduler handle +// corrupted job args, we should test it there by UT, not here. +func injectModifyJobArgFailPoint(jobWs []*JobWrapper) { + failpoint.Inject("MockModifyJobArg", func(val failpoint.Value) { + if val.(bool) { + for _, jobW := range jobWs { + job := jobW.Job + // Corrupt the DDL job argument. + if job.Type == model.ActionMultiSchemaChange { + if len(job.MultiSchemaInfo.SubJobs) > 0 && len(job.MultiSchemaInfo.SubJobs[0].Args) > 0 { + job.MultiSchemaInfo.SubJobs[0].Args[0] = 1 + } + } else if len(job.Args) > 0 { + job.Args[0] = 1 + } + } + } + }) +} + +func setJobStateToQueueing(job *model.Job) { + if job.Type == model.ActionMultiSchemaChange && job.MultiSchemaInfo != nil { + for _, sub := range job.MultiSchemaInfo.SubJobs { + sub.State = model.JobStateQueueing + } + } + job.State = model.JobStateQueueing +} + +// buildJobDependence sets the curjob's dependency-ID. +// The dependency-job's ID must less than the current job's ID, and we need the largest one in the list. +func buildJobDependence(t *meta.Meta, curJob *model.Job) error { + // Jobs in the same queue are ordered. If we want to find a job's dependency-job, we need to look for + // it from the other queue. So if the job is "ActionAddIndex" job, we need find its dependency-job from DefaultJobList. + jobListKey := meta.DefaultJobListKey + if !curJob.MayNeedReorg() { + jobListKey = meta.AddIndexJobListKey + } + jobs, err := t.GetAllDDLJobsInQueue(jobListKey) + if err != nil { + return errors.Trace(err) + } + + for _, job := range jobs { + if curJob.ID < job.ID { + continue + } + isDependent, err := curJob.IsDependentOn(job) + if err != nil { + return errors.Trace(err) + } + if isDependent { + logutil.DDLLogger().Info("current DDL job depends on other job", + zap.Stringer("currentJob", curJob), + zap.Stringer("dependentJob", job)) + curJob.DependencyID = job.ID + break + } + } + return nil +} + +func (e *executor) notifyNewJobSubmitted(ch chan struct{}, etcdPath string, jobID int64, jobType string) { + // If the workers don't run, we needn't notify workers. 
+ // TODO: It does not affect informing the backfill worker. + if !config.GetGlobalConfig().Instance.TiDBEnableDDL.Load() { + return + } + if e.ownerManager.IsOwner() { + asyncNotify(ch) + } else { + e.notifyNewJobByEtcd(etcdPath, jobID, jobType) + } +} + +func (e *executor) notifyNewJobByEtcd(etcdPath string, jobID int64, jobType string) { + if e.etcdCli == nil { + return + } + + jobIDStr := strconv.FormatInt(jobID, 10) + timeStart := time.Now() + err := ddlutil.PutKVToEtcd(e.ctx, e.etcdCli, 1, etcdPath, jobIDStr) + if err != nil { + logutil.DDLLogger().Info("notify handling DDL job failed", + zap.String("etcdPath", etcdPath), + zap.Int64("jobID", jobID), + zap.String("type", jobType), + zap.Error(err)) + } + metrics.DDLWorkerHistogram.WithLabelValues(metrics.WorkerNotifyDDLJob, jobType, metrics.RetLabel(err)).Observe(time.Since(timeStart).Seconds()) +} diff --git a/pkg/ddl/job_worker.go b/pkg/ddl/job_worker.go index c204711ce5217..ff2e392a2557e 100644 --- a/pkg/ddl/job_worker.go +++ b/pkg/ddl/job_worker.go @@ -191,12 +191,12 @@ func injectFailPointForGetJob(job *model.Job) { if job == nil { return } - failpoint.Inject("mockModifyJobSchemaId", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockModifyJobSchemaId")); _err_ == nil { job.SchemaID = int64(val.(int)) - }) - failpoint.Inject("MockModifyJobTableId", func(val failpoint.Value) { + } + if val, _err_ := failpoint.Eval(_curpkg_("MockModifyJobTableId")); _err_ == nil { job.TableID = int64(val.(int)) - }) + } } // handleUpdateJobError handles the too large DDL job. @@ -224,11 +224,11 @@ func (w *worker) handleUpdateJobError(t *meta.Meta, job *model.Job, err error) e // updateDDLJob updates the DDL job information. func (w *worker) updateDDLJob(job *model.Job, updateRawArgs bool) error { - failpoint.Inject("mockErrEntrySizeTooLarge", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockErrEntrySizeTooLarge")); _err_ == nil { if val.(bool) { - failpoint.Return(kv.ErrEntryTooLarge) + return kv.ErrEntryTooLarge } - }) + } if !updateRawArgs { w.jobLogger(job).Info("meet something wrong before update DDL job, shouldn't update raw args", @@ -348,7 +348,7 @@ func (w *worker) finishDDLJob(t *meta.Meta, job *model.Job) (err error) { } job.SeqNum = w.seqAllocator.Add(1) w.removeJobCtx(job) - failpoint.InjectCall("afterFinishDDLJob", job) + failpoint.Call(_curpkg_("afterFinishDDLJob"), job) err = AddHistoryDDLJob(w.ctx, w.sess, t, job, updateRawArgs) return errors.Trace(err) } @@ -456,11 +456,11 @@ func (w *worker) prepareTxn(job *model.Job) (kv.Transaction, error) { if err != nil { return nil, err } - failpoint.Inject("mockRunJobTime", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockRunJobTime")); _err_ == nil { if val.(bool) { time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond) // #nosec G404 } - }) + } txn, err := w.sess.Txn() if err != nil { w.sess.Rollback() @@ -503,7 +503,7 @@ func (w *worker) transitOneJobStep(d *ddlCtx, job *model.Job) (int64, error) { job.State = model.JobStateSynced } // Inject the failpoint to prevent the progress of index creation. 
- failpoint.Inject("create-index-stuck-before-ddlhistory", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("create-index-stuck-before-ddlhistory")); _err_ == nil { if sigFile, ok := v.(string); ok && job.Type == model.ActionAddIndex { for { time.Sleep(1 * time.Second) @@ -511,21 +511,21 @@ func (w *worker) transitOneJobStep(d *ddlCtx, job *model.Job) (int64, error) { if os.IsNotExist(err) { continue } - failpoint.Return(0, errors.Trace(err)) + return 0, errors.Trace(err) } break } } - }) + } return 0, w.handleJobDone(d, job, t) } - failpoint.InjectCall("onJobRunBefore", job) + failpoint.Call(_curpkg_("onJobRunBefore"), job) // If running job meets error, we will save this error in job Error and retry // later if the job is not cancelled. schemaVer, updateRawArgs, runJobErr := w.runOneJobStep(d, t, job) - failpoint.InjectCall("onJobRunAfter", job) + failpoint.Call(_curpkg_("onJobRunAfter"), job) if job.IsCancelled() { defer d.unlockSchemaVersion(job.ID) @@ -748,7 +748,7 @@ func (w *worker) runOneJobStep( }, false) // Mock for run ddl job panic. - failpoint.Inject("mockPanicInRunDDLJob", func(failpoint.Value) {}) + failpoint.Eval(_curpkg_("mockPanicInRunDDLJob")) if job.Type != model.ActionMultiSchemaChange { w.jobLogger(job).Info("run DDL job", zap.String("category", "ddl"), zap.String("job", job.String())) diff --git a/pkg/ddl/job_worker.go__failpoint_stash__ b/pkg/ddl/job_worker.go__failpoint_stash__ new file mode 100644 index 0000000000000..c204711ce5217 --- /dev/null +++ b/pkg/ddl/job_worker.go__failpoint_stash__ @@ -0,0 +1,1013 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl + +import ( + "context" + "fmt" + "math/rand" + "os" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/tidb/pkg/ddl/logutil" + sess "github.com/pingcap/tidb/pkg/ddl/session" + "github.com/pingcap/tidb/pkg/ddl/util" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/binloginfo" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + pumpcli "github.com/pingcap/tidb/pkg/tidb-binlog/pump_client" + tidbutil "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/dbterror" + tidblogutil "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/resourcegrouptag" + "github.com/pingcap/tidb/pkg/util/topsql" + topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" + "github.com/tikv/client-go/v2/tikvrpc" + kvutil "github.com/tikv/client-go/v2/util" + atomicutil "go.uber.org/atomic" + "go.uber.org/zap" +) + +var ( + // ddlWorkerID is used for generating the next DDL worker ID. 
+ ddlWorkerID = atomicutil.NewInt32(0) + // backfillContextID is used for generating the next backfill context ID. + backfillContextID = atomicutil.NewInt32(0) + // WaitTimeWhenErrorOccurred is waiting interval when processing DDL jobs encounter errors. + WaitTimeWhenErrorOccurred = int64(1 * time.Second) + + mockDDLErrOnce = int64(0) + // TestNotifyBeginTxnCh is used for if the txn is beginning in runInTxn. + TestNotifyBeginTxnCh = make(chan struct{}) +) + +// GetWaitTimeWhenErrorOccurred return waiting interval when processing DDL jobs encounter errors. +func GetWaitTimeWhenErrorOccurred() time.Duration { + return time.Duration(atomic.LoadInt64(&WaitTimeWhenErrorOccurred)) +} + +// SetWaitTimeWhenErrorOccurred update waiting interval when processing DDL jobs encounter errors. +func SetWaitTimeWhenErrorOccurred(dur time.Duration) { + atomic.StoreInt64(&WaitTimeWhenErrorOccurred, int64(dur)) +} + +type workerType byte + +const ( + // generalWorker is the worker who handles all DDL statements except “add index”. + generalWorker workerType = 0 + // addIdxWorker is the worker who handles the operation of adding indexes. + addIdxWorker workerType = 1 +) + +// worker is used for handling DDL jobs. +// Now we have two kinds of workers. +type worker struct { + id int32 + tp workerType + addingDDLJobKey string + ddlJobCh chan struct{} + // it's the ctx of 'job scheduler'. + ctx context.Context + wg sync.WaitGroup + + sessPool *sess.Pool // sessPool is used to new sessions to execute SQL in ddl package. + sess *sess.Session // sess is used and only used in running DDL job. + delRangeManager delRangeManager + logCtx context.Context + seqAllocator *atomic.Uint64 + + *ddlCtx +} + +// JobContext is the ddl job execution context. +type JobContext struct { + // below fields are cache for top sql + ddlJobCtx context.Context + cacheSQL string + cacheNormalizedSQL string + cacheDigest *parser.Digest + tp string + + resourceGroupName string + cloudStorageURI string +} + +// NewJobContext returns a new ddl job context. 
+func NewJobContext() *JobContext { + return &JobContext{ + ddlJobCtx: context.Background(), + cacheSQL: "", + cacheNormalizedSQL: "", + cacheDigest: nil, + tp: "", + } +} + +func newWorker(ctx context.Context, tp workerType, sessPool *sess.Pool, delRangeMgr delRangeManager, dCtx *ddlCtx) *worker { + worker := &worker{ + id: ddlWorkerID.Add(1), + tp: tp, + ddlJobCh: make(chan struct{}, 1), + ctx: ctx, + ddlCtx: dCtx, + sessPool: sessPool, + delRangeManager: delRangeMgr, + } + worker.addingDDLJobKey = addingDDLJobPrefix + worker.typeStr() + worker.logCtx = tidblogutil.WithFields(context.Background(), zap.String("worker", worker.String()), zap.String("category", "ddl")) + return worker +} + +func (w *worker) jobLogger(job *model.Job) *zap.Logger { + logger := tidblogutil.Logger(w.logCtx) + if job != nil { + logger = tidblogutil.LoggerWithTraceInfo( + logger.With(zap.Int64("jobID", job.ID)), + job.TraceInfo, + ) + } + return logger +} + +func (w *worker) typeStr() string { + var str string + switch w.tp { + case generalWorker: + str = "general" + case addIdxWorker: + str = "add index" + default: + str = "unknown" + } + return str +} + +func (w *worker) String() string { + return fmt.Sprintf("worker %d, tp %s", w.id, w.typeStr()) +} + +func (w *worker) Close() { + startTime := time.Now() + if w.sess != nil { + w.sessPool.Put(w.sess.Session()) + } + w.wg.Wait() + tidblogutil.Logger(w.logCtx).Info("DDL worker closed", zap.Duration("take time", time.Since(startTime))) +} + +func asyncNotify(ch chan struct{}) { + select { + case ch <- struct{}{}: + default: + } +} + +func injectFailPointForGetJob(job *model.Job) { + if job == nil { + return + } + failpoint.Inject("mockModifyJobSchemaId", func(val failpoint.Value) { + job.SchemaID = int64(val.(int)) + }) + failpoint.Inject("MockModifyJobTableId", func(val failpoint.Value) { + job.TableID = int64(val.(int)) + }) +} + +// handleUpdateJobError handles the too large DDL job. +func (w *worker) handleUpdateJobError(t *meta.Meta, job *model.Job, err error) error { + if err == nil { + return nil + } + if kv.ErrEntryTooLarge.Equal(err) { + w.jobLogger(job).Warn("update DDL job failed", zap.String("job", job.String()), zap.Error(err)) + w.sess.Rollback() + err1 := w.sess.Begin(w.ctx) + if err1 != nil { + return errors.Trace(err1) + } + // Reduce this txn entry size. + job.BinlogInfo.Clean() + job.Error = toTError(err) + job.ErrorCount++ + job.SchemaState = model.StateNone + job.State = model.JobStateCancelled + err = w.finishDDLJob(t, job) + } + return errors.Trace(err) +} + +// updateDDLJob updates the DDL job information. +func (w *worker) updateDDLJob(job *model.Job, updateRawArgs bool) error { + failpoint.Inject("mockErrEntrySizeTooLarge", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(kv.ErrEntryTooLarge) + } + }) + + if !updateRawArgs { + w.jobLogger(job).Info("meet something wrong before update DDL job, shouldn't update raw args", + zap.String("job", job.String())) + } + return errors.Trace(updateDDLJob2Table(w.sess, job, updateRawArgs)) +} + +// registerMDLInfo registers metadata lock info. 
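+// For a regular user-table job it writes a row like the following (values are
+// illustrative only; table_ids is read back from mysql.tidb_ddl_job):
+//
+//	replace into mysql.tidb_mdl_info (job_id, version, table_ids, owner_id)
+//	values (101, 58, '105', '<owner uuid>')
+//
+// Jobs that touch system schemas omit owner_id to avoid a circular wait during
+// upgrade, and nothing is written when MDL is disabled or the version is 0.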
+func (w *worker) registerMDLInfo(job *model.Job, ver int64) error { + if !variable.EnableMDL.Load() { + return nil + } + if ver == 0 { + return nil + } + rows, err := w.sess.Execute(w.ctx, fmt.Sprintf("select table_ids from mysql.tidb_ddl_job where job_id = %d", job.ID), "register-mdl-info") + if err != nil { + return err + } + if len(rows) == 0 { + return errors.Errorf("can't find ddl job %d", job.ID) + } + ownerID := w.ownerManager.ID() + ids := rows[0].GetString(0) + var sql string + if tidbutil.IsSysDB(strings.ToLower(job.SchemaName)) { + // DDLs that modify system tables could only happen in upgrade process, + // we should not reference 'owner_id'. Otherwise, there is a circular blocking problem. + sql = fmt.Sprintf("replace into mysql.tidb_mdl_info (job_id, version, table_ids) values (%d, %d, '%s')", job.ID, ver, ids) + } else { + sql = fmt.Sprintf("replace into mysql.tidb_mdl_info (job_id, version, table_ids, owner_id) values (%d, %d, '%s', '%s')", job.ID, ver, ids, ownerID) + } + _, err = w.sess.Execute(w.ctx, sql, "register-mdl-info") + return err +} + +// JobNeedGC is called to determine whether delete-ranges need to be generated for the provided job. +// +// NOTICE: BR also uses jobNeedGC to determine whether delete-ranges need to be generated for the provided job. +// Therefore, please make sure any modification is compatible with BR. +func JobNeedGC(job *model.Job) bool { + if !job.IsCancelled() { + if job.Warning != nil && dbterror.ErrCantDropFieldOrKey.Equal(job.Warning) { + // For the field/key not exists warnings, there is no need to + // delete the ranges. + return false + } + switch job.Type { + case model.ActionDropSchema, model.ActionDropTable, + model.ActionTruncateTable, model.ActionDropIndex, + model.ActionDropPrimaryKey, + model.ActionDropTablePartition, model.ActionTruncateTablePartition, + model.ActionDropColumn, model.ActionModifyColumn, + model.ActionAddIndex, model.ActionAddPrimaryKey, + model.ActionReorganizePartition, model.ActionRemovePartitioning, + model.ActionAlterTablePartitioning: + return true + case model.ActionMultiSchemaChange: + for i, sub := range job.MultiSchemaInfo.SubJobs { + proxyJob := sub.ToProxyJob(job, i) + needGC := JobNeedGC(&proxyJob) + if needGC { + return true + } + } + return false + } + } + return false +} + +// finishDDLJob deletes the finished DDL job in the ddl queue and puts it to history queue. +// If the DDL job need to handle in background, it will prepare a background job. 
+func (w *worker) finishDDLJob(t *meta.Meta, job *model.Job) (err error) { + startTime := time.Now() + defer func() { + metrics.DDLWorkerHistogram.WithLabelValues(metrics.WorkerFinishDDLJob, job.Type.String(), metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) + }() + + if JobNeedGC(job) { + err = w.delRangeManager.addDelRangeJob(w.ctx, job) + if err != nil { + return errors.Trace(err) + } + } + + switch job.Type { + case model.ActionRecoverTable: + err = finishRecoverTable(w, job) + case model.ActionFlashbackCluster: + err = finishFlashbackCluster(w, job) + case model.ActionRecoverSchema: + err = finishRecoverSchema(w, job) + case model.ActionCreateTables: + if job.IsCancelled() { + // it may be too large that it can not be added to the history queue, so + // delete its arguments + job.Args = nil + } + } + if err != nil { + return errors.Trace(err) + } + err = w.deleteDDLJob(job) + if err != nil { + return errors.Trace(err) + } + + job.BinlogInfo.FinishedTS = t.StartTS + w.jobLogger(job).Info("finish DDL job", zap.String("job", job.String())) + updateRawArgs := true + if job.Type == model.ActionAddPrimaryKey && !job.IsCancelled() { + // ActionAddPrimaryKey needs to check the warnings information in job.Args. + // Notice: warnings is used to support non-strict mode. + updateRawArgs = false + } + job.SeqNum = w.seqAllocator.Add(1) + w.removeJobCtx(job) + failpoint.InjectCall("afterFinishDDLJob", job) + err = AddHistoryDDLJob(w.ctx, w.sess, t, job, updateRawArgs) + return errors.Trace(err) +} + +func (w *worker) deleteDDLJob(job *model.Job) error { + sql := fmt.Sprintf("delete from mysql.tidb_ddl_job where job_id = %d", job.ID) + _, err := w.sess.Execute(context.Background(), sql, "delete_job") + return errors.Trace(err) +} + +func finishRecoverTable(w *worker, job *model.Job) error { + var ( + recoverInfo *RecoverInfo + recoverTableCheckFlag int64 + ) + err := job.DecodeArgs(&recoverInfo, &recoverTableCheckFlag) + if err != nil { + return errors.Trace(err) + } + if recoverTableCheckFlag == recoverCheckFlagEnableGC { + err = enableGC(w) + if err != nil { + return errors.Trace(err) + } + } + return nil +} + +func finishRecoverSchema(w *worker, job *model.Job) error { + var ( + recoverSchemaInfo *RecoverSchemaInfo + recoverSchemaCheckFlag int64 + ) + err := job.DecodeArgs(&recoverSchemaInfo, &recoverSchemaCheckFlag) + if err != nil { + return errors.Trace(err) + } + if recoverSchemaCheckFlag == recoverCheckFlagEnableGC { + err = enableGC(w) + if err != nil { + return errors.Trace(err) + } + } + return nil +} + +func (w *JobContext) setDDLLabelForTopSQL(jobQuery string) { + if !topsqlstate.TopSQLEnabled() || jobQuery == "" { + return + } + + if jobQuery != w.cacheSQL || w.cacheDigest == nil { + w.cacheNormalizedSQL, w.cacheDigest = parser.NormalizeDigest(jobQuery) + w.cacheSQL = jobQuery + w.ddlJobCtx = topsql.AttachAndRegisterSQLInfo(context.Background(), w.cacheNormalizedSQL, w.cacheDigest, false) + } else { + topsql.AttachAndRegisterSQLInfo(w.ddlJobCtx, w.cacheNormalizedSQL, w.cacheDigest, false) + } +} + +// DDLBackfillers contains the DDL need backfill step. 
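+// Together with getDDLRequestSource below, each entry becomes a request-source
+// label for the backfill transactions; for example ActionAddIndex is reported as
+// kv.InternalTxnBackfillDDLPrefix + "add_index" (i.e. "ddl_add_index", assuming the
+// prefix is "ddl_"), while job types not listed here fall back to kv.InternalTxnDDL.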
+var DDLBackfillers = map[model.ActionType]string{ + model.ActionAddIndex: "add_index", + model.ActionModifyColumn: "modify_column", + model.ActionDropIndex: "drop_index", + model.ActionReorganizePartition: "reorganize_partition", +} + +func getDDLRequestSource(jobType model.ActionType) string { + if tp, ok := DDLBackfillers[jobType]; ok { + return kv.InternalTxnBackfillDDLPrefix + tp + } + return kv.InternalTxnDDL +} + +func (w *JobContext) setDDLLabelForDiagnosis(jobType model.ActionType) { + if w.tp != "" { + return + } + w.tp = getDDLRequestSource(jobType) + w.ddlJobCtx = kv.WithInternalSourceAndTaskType(w.ddlJobCtx, w.ddlJobSourceType(), kvutil.ExplicitTypeDDL) +} + +func (w *worker) handleJobDone(d *ddlCtx, job *model.Job, t *meta.Meta) error { + if err := w.checkBeforeCommit(); err != nil { + return err + } + err := w.finishDDLJob(t, job) + if err != nil { + w.sess.Rollback() + return err + } + + err = w.sess.Commit(w.ctx) + if err != nil { + return err + } + cleanupDDLReorgHandles(job, w.sess) + d.notifyJobDone(job.ID) + return nil +} + +func (w *worker) prepareTxn(job *model.Job) (kv.Transaction, error) { + err := w.sess.Begin(w.ctx) + if err != nil { + return nil, err + } + failpoint.Inject("mockRunJobTime", func(val failpoint.Value) { + if val.(bool) { + time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond) // #nosec G404 + } + }) + txn, err := w.sess.Txn() + if err != nil { + w.sess.Rollback() + return txn, err + } + // Only general DDLs are allowed to be executed when TiKV is disk full. + if w.tp == addIdxWorker && job.IsRunning() { + txn.SetDiskFullOpt(kvrpcpb.DiskFullOpt_NotAllowedOnFull) + } + w.setDDLLabelForTopSQL(job.ID, job.Query) + w.setDDLSourceForDiagnosis(job.ID, job.Type) + jobContext := w.jobContext(job.ID, job.ReorgMeta) + if tagger := w.getResourceGroupTaggerForTopSQL(job.ID); tagger != nil { + txn.SetOption(kv.ResourceGroupTagger, tagger) + } + txn.SetOption(kv.ResourceGroupName, jobContext.resourceGroupName) + // set request source type to DDL type + txn.SetOption(kv.RequestSourceType, jobContext.ddlJobSourceType()) + return txn, err +} + +// transitOneJobStep runs one step of the DDL job and persist the new job +// information. +// +// The first return value is the schema version after running the job. If it's +// non-zero, caller should wait for other nodes to catch up. +func (w *worker) transitOneJobStep(d *ddlCtx, job *model.Job) (int64, error) { + var ( + err error + ) + + txn, err := w.prepareTxn(job) + if err != nil { + return 0, err + } + t := meta.NewMeta(txn) + + if job.IsDone() || job.IsRollbackDone() || job.IsCancelled() { + if job.IsDone() { + job.State = model.JobStateSynced + } + // Inject the failpoint to prevent the progress of index creation. + failpoint.Inject("create-index-stuck-before-ddlhistory", func(v failpoint.Value) { + if sigFile, ok := v.(string); ok && job.Type == model.ActionAddIndex { + for { + time.Sleep(1 * time.Second) + if _, err := os.Stat(sigFile); err != nil { + if os.IsNotExist(err) { + continue + } + failpoint.Return(0, errors.Trace(err)) + } + break + } + } + }) + return 0, w.handleJobDone(d, job, t) + } + failpoint.InjectCall("onJobRunBefore", job) + + // If running job meets error, we will save this error in job Error and retry + // later if the job is not cancelled. 
+ schemaVer, updateRawArgs, runJobErr := w.runOneJobStep(d, t, job) + + failpoint.InjectCall("onJobRunAfter", job) + + if job.IsCancelled() { + defer d.unlockSchemaVersion(job.ID) + w.sess.Reset() + return 0, w.handleJobDone(d, job, t) + } + + if err = w.checkBeforeCommit(); err != nil { + d.unlockSchemaVersion(job.ID) + return 0, err + } + + if runJobErr != nil && !job.IsRollingback() && !job.IsRollbackDone() { + // If the running job meets an error + // and the job state is rolling back, it means that we have already handled this error. + // Some DDL jobs (such as adding indexes) may need to update the table info and the schema version, + // then shouldn't discard the KV modification. + // And the job state is rollback done, it means the job was already finished, also shouldn't discard too. + // Otherwise, we should discard the KV modification when running job. + w.sess.Reset() + // If error happens after updateSchemaVersion(), then the schemaVer is updated. + // Result in the retry duration is up to 2 * lease. + schemaVer = 0 + } + + err = w.registerMDLInfo(job, schemaVer) + if err != nil { + w.sess.Rollback() + d.unlockSchemaVersion(job.ID) + return 0, err + } + err = w.updateDDLJob(job, updateRawArgs) + if err = w.handleUpdateJobError(t, job, err); err != nil { + w.sess.Rollback() + d.unlockSchemaVersion(job.ID) + return 0, err + } + writeBinlog(d.binlogCli, txn, job) + // reset the SQL digest to make topsql work right. + w.sess.GetSessionVars().StmtCtx.ResetSQLDigest(job.Query) + err = w.sess.Commit(w.ctx) + d.unlockSchemaVersion(job.ID) + if err != nil { + return 0, err + } + w.registerSync(job) + + // If error is non-retryable, we can ignore the sleep. + if runJobErr != nil && errorIsRetryable(runJobErr, job) { + w.jobLogger(job).Info("run DDL job failed, sleeps a while then retries it.", + zap.Duration("waitTime", GetWaitTimeWhenErrorOccurred()), zap.Error(runJobErr)) + // wait a while to retry again. If we don't wait here, DDL will retry this job immediately, + // which may act like a deadlock. + select { + case <-time.After(GetWaitTimeWhenErrorOccurred()): + case <-w.ctx.Done(): + } + } + + return schemaVer, nil +} + +func (w *worker) checkBeforeCommit() error { + if !w.ddlCtx.isOwner() { + // Since this TiDB instance is not a DDL owner anymore, + // it should not commit any transaction. + w.sess.Rollback() + return dbterror.ErrNotOwner + } + + if err := w.ctx.Err(); err != nil { + // The worker context is canceled, it should not commit any transaction. + return err + } + return nil +} + +func (w *JobContext) getResourceGroupTaggerForTopSQL() tikvrpc.ResourceGroupTagger { + if !topsqlstate.TopSQLEnabled() || w.cacheDigest == nil { + return nil + } + + digest := w.cacheDigest + tagger := func(req *tikvrpc.Request) { + req.ResourceGroupTag = resourcegrouptag.EncodeResourceGroupTag(digest, nil, + resourcegrouptag.GetResourceGroupLabelByKey(resourcegrouptag.GetFirstKeyFromRequest(req))) + } + return tagger +} + +func (w *JobContext) ddlJobSourceType() string { + return w.tp +} + +func skipWriteBinlog(job *model.Job) bool { + switch job.Type { + // ActionUpdateTiFlashReplicaStatus is a TiDB internal DDL, + // it's used to update table's TiFlash replica available status. + case model.ActionUpdateTiFlashReplicaStatus: + return true + // Don't sync 'alter table cache|nocache' to other tools. + // It's internal to the current cluster. 
+ case model.ActionAlterCacheTable, model.ActionAlterNoCacheTable: + return true + } + + return false +} + +func writeBinlog(binlogCli *pumpcli.PumpsClient, txn kv.Transaction, job *model.Job) { + if job.IsDone() || job.IsRollbackDone() || + // When this column is in the "delete only" and "delete reorg" states, the binlog of "drop column" has not been written yet, + // but the column has been removed from the binlog of the write operation. + // So we add this binlog to enable downstream components to handle DML correctly in this schema state. + (job.Type == model.ActionDropColumn && job.SchemaState == model.StateDeleteOnly) { + if skipWriteBinlog(job) { + return + } + binloginfo.SetDDLBinlog(binlogCli, txn, job.ID, int32(job.SchemaState), job.Query) + } +} + +func chooseLeaseTime(t, max time.Duration) time.Duration { + if t == 0 || t > max { + return max + } + return t +} + +// countForPanic records the error count for DDL job. +func (w *worker) countForPanic(job *model.Job) { + // If run DDL job panic, just cancel the DDL jobs. + if job.State == model.JobStateRollingback { + job.State = model.JobStateCancelled + } else { + job.State = model.JobStateCancelling + } + job.ErrorCount++ + + logger := w.jobLogger(job) + // Load global DDL variables. + if err1 := loadDDLVars(w); err1 != nil { + logger.Error("load DDL global variable failed", zap.Error(err1)) + } + errorCount := variable.GetDDLErrorCountLimit() + + if job.ErrorCount > errorCount { + msg := fmt.Sprintf("panic in handling DDL logic and error count beyond the limitation %d, cancelled", errorCount) + logger.Warn(msg) + job.Error = toTError(errors.New(msg)) + job.State = model.JobStateCancelled + } +} + +// countForError records the error count for DDL job. +func (w *worker) countForError(err error, job *model.Job) error { + job.Error = toTError(err) + job.ErrorCount++ + + logger := w.jobLogger(job) + // If job is cancelled, we shouldn't return an error and shouldn't load DDL variables. + if job.State == model.JobStateCancelled { + logger.Info("DDL job is cancelled normally", zap.Error(err)) + return nil + } + logger.Warn("run DDL job error", zap.Error(err)) + + // Load global DDL variables. + if err1 := loadDDLVars(w); err1 != nil { + logger.Error("load DDL global variable failed", zap.Error(err1)) + } + // Check error limit to avoid falling into an infinite loop. + if job.ErrorCount > variable.GetDDLErrorCountLimit() && job.State == model.JobStateRunning && job.IsRollbackable() { + logger.Warn("DDL job error count exceed the limit, cancelling it now", zap.Int64("errorCountLimit", variable.GetDDLErrorCountLimit())) + job.State = model.JobStateCancelling + } + return err +} + +func (w *worker) processJobPausingRequest(d *ddlCtx, job *model.Job) (isRunnable bool, err error) { + if job.IsPaused() { + w.jobLogger(job).Debug("paused DDL job ", zap.String("job", job.String())) + return false, err + } + if job.IsPausing() { + w.jobLogger(job).Debug("pausing DDL job ", zap.String("job", job.String())) + job.State = model.JobStatePaused + return false, pauseReorgWorkers(w, d, job) + } + return true, nil +} + +// runOneJobStep runs a DDL job *step*. It returns the current schema version in +// this transaction, if the given job.Args has changed, and the error. The *step* +// is defined as the following reasons: +// +// - TiDB uses "Asynchronous Schema Change in F1", one job may have multiple +// *steps* each for a schema state change such as 'delete only' -> 'write only'. 
+// Combined with caller transitOneJobStepAndWaitSync waiting for other nodes to +// catch up with the returned schema version, we can make sure the cluster will +// only have two adjacent schema state for a DDL object. +// +// - Some types of DDL jobs has defined its own *step*s other than F1 paper. +// These *step*s may not be schema state change, and their purposes are various. +// For example, onLockTables updates the lock state of one table every *step*. +// +// - To provide linearizability we have added extra job state change *step*. For +// example, if job becomes JobStateDone in runOneJobStep, we cannot return to +// user that the job is finished because other nodes in cluster may not be +// synchronized. So JobStateSynced *step* is added to make sure there is +// waitSchemaChanged to wait for all nodes to catch up JobStateDone. +func (w *worker) runOneJobStep( + d *ddlCtx, + t *meta.Meta, + job *model.Job, +) (ver int64, updateRawArgs bool, err error) { + defer tidbutil.Recover(metrics.LabelDDLWorker, fmt.Sprintf("%s runOneJobStep", w), + func() { + w.countForPanic(job) + }, false) + + // Mock for run ddl job panic. + failpoint.Inject("mockPanicInRunDDLJob", func(failpoint.Value) {}) + + if job.Type != model.ActionMultiSchemaChange { + w.jobLogger(job).Info("run DDL job", zap.String("category", "ddl"), zap.String("job", job.String())) + } + timeStart := time.Now() + if job.RealStartTS == 0 { + job.RealStartTS = t.StartTS + } + defer func() { + metrics.DDLWorkerHistogram.WithLabelValues(metrics.WorkerRunDDLJob, job.Type.String(), metrics.RetLabel(err)).Observe(time.Since(timeStart).Seconds()) + }() + + if job.IsCancelling() { + w.jobLogger(job).Debug("cancel DDL job", zap.String("job", job.String())) + ver, err = convertJob2RollbackJob(w, d, t, job) + // if job is converted to rollback job, the job.Args may be changed for the + // rollback logic, so we let caller persist the new arguments. + updateRawArgs = job.IsRollingback() + return + } + + isRunnable, err := w.processJobPausingRequest(d, job) + if !isRunnable { + return ver, false, err + } + + // It would be better to do the positive check, but no idea to list all valid states here now. + if !job.IsRollingback() { + job.State = model.JobStateRunning + } + + prevState := job.State + + // For every type, `schema/table` modification and `job` modification are conducted + // in the one kv transaction. The `schema/table` modification can be always discarded + // by kv reset when meets an unhandled error, but the `job` modification can't. + // So make sure job state and args change is after all other checks or make sure these + // change has no effect when retrying it. 
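+	// As an illustration of the *step* notion described above: an ADD INDEX job
+	// typically walks none -> delete only -> write only -> write reorganization ->
+	// public, and each call of runOneJobStep advances at most one of these states,
+	// returning a new schema version for the caller to sync before the next step.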
+ switch job.Type { + case model.ActionCreateSchema: + ver, err = onCreateSchema(d, t, job) + case model.ActionModifySchemaCharsetAndCollate: + ver, err = onModifySchemaCharsetAndCollate(d, t, job) + case model.ActionDropSchema: + ver, err = onDropSchema(d, t, job) + case model.ActionRecoverSchema: + ver, err = w.onRecoverSchema(d, t, job) + case model.ActionModifySchemaDefaultPlacement: + ver, err = onModifySchemaDefaultPlacement(d, t, job) + case model.ActionCreateTable: + ver, err = onCreateTable(d, t, job) + case model.ActionCreateTables: + ver, err = onCreateTables(d, t, job) + case model.ActionRepairTable: + ver, err = onRepairTable(d, t, job) + case model.ActionCreateView: + ver, err = onCreateView(d, t, job) + case model.ActionDropTable, model.ActionDropView, model.ActionDropSequence: + ver, err = onDropTableOrView(d, t, job) + case model.ActionDropTablePartition: + ver, err = w.onDropTablePartition(d, t, job) + case model.ActionTruncateTablePartition: + ver, err = w.onTruncateTablePartition(d, t, job) + case model.ActionExchangeTablePartition: + ver, err = w.onExchangeTablePartition(d, t, job) + case model.ActionAddColumn: + ver, err = onAddColumn(d, t, job) + case model.ActionDropColumn: + ver, err = onDropColumn(d, t, job) + case model.ActionModifyColumn: + ver, err = w.onModifyColumn(d, t, job) + case model.ActionSetDefaultValue: + ver, err = onSetDefaultValue(d, t, job) + case model.ActionAddIndex: + ver, err = w.onCreateIndex(d, t, job, false) + case model.ActionAddPrimaryKey: + ver, err = w.onCreateIndex(d, t, job, true) + case model.ActionDropIndex, model.ActionDropPrimaryKey: + ver, err = onDropIndex(d, t, job) + case model.ActionRenameIndex: + ver, err = onRenameIndex(d, t, job) + case model.ActionAddForeignKey: + ver, err = w.onCreateForeignKey(d, t, job) + case model.ActionDropForeignKey: + ver, err = onDropForeignKey(d, t, job) + case model.ActionTruncateTable: + ver, err = w.onTruncateTable(d, t, job) + case model.ActionRebaseAutoID: + ver, err = onRebaseAutoIncrementIDType(d, t, job) + case model.ActionRebaseAutoRandomBase: + ver, err = onRebaseAutoRandomType(d, t, job) + case model.ActionRenameTable: + ver, err = onRenameTable(d, t, job) + case model.ActionShardRowID: + ver, err = w.onShardRowID(d, t, job) + case model.ActionModifyTableComment: + ver, err = onModifyTableComment(d, t, job) + case model.ActionModifyTableAutoIdCache: + ver, err = onModifyTableAutoIDCache(d, t, job) + case model.ActionAddTablePartition: + ver, err = w.onAddTablePartition(d, t, job) + case model.ActionModifyTableCharsetAndCollate: + ver, err = onModifyTableCharsetAndCollate(d, t, job) + case model.ActionRecoverTable: + ver, err = w.onRecoverTable(d, t, job) + case model.ActionLockTable: + ver, err = onLockTables(d, t, job) + case model.ActionUnlockTable: + ver, err = onUnlockTables(d, t, job) + case model.ActionSetTiFlashReplica: + ver, err = w.onSetTableFlashReplica(d, t, job) + case model.ActionUpdateTiFlashReplicaStatus: + ver, err = onUpdateFlashReplicaStatus(d, t, job) + case model.ActionCreateSequence: + ver, err = onCreateSequence(d, t, job) + case model.ActionAlterIndexVisibility: + ver, err = onAlterIndexVisibility(d, t, job) + case model.ActionAlterSequence: + ver, err = onAlterSequence(d, t, job) + case model.ActionRenameTables: + ver, err = onRenameTables(d, t, job) + case model.ActionAlterTableAttributes: + ver, err = onAlterTableAttributes(d, t, job) + case model.ActionAlterTablePartitionAttributes: + ver, err = onAlterTablePartitionAttributes(d, t, job) + case 
model.ActionCreatePlacementPolicy: + ver, err = onCreatePlacementPolicy(d, t, job) + case model.ActionDropPlacementPolicy: + ver, err = onDropPlacementPolicy(d, t, job) + case model.ActionAlterPlacementPolicy: + ver, err = onAlterPlacementPolicy(d, t, job) + case model.ActionAlterTablePartitionPlacement: + ver, err = onAlterTablePartitionPlacement(d, t, job) + case model.ActionAlterTablePlacement: + ver, err = onAlterTablePlacement(d, t, job) + case model.ActionCreateResourceGroup: + ver, err = onCreateResourceGroup(w.ctx, d, t, job) + case model.ActionAlterResourceGroup: + ver, err = onAlterResourceGroup(d, t, job) + case model.ActionDropResourceGroup: + ver, err = onDropResourceGroup(d, t, job) + case model.ActionAlterCacheTable: + ver, err = onAlterCacheTable(d, t, job) + case model.ActionAlterNoCacheTable: + ver, err = onAlterNoCacheTable(d, t, job) + case model.ActionFlashbackCluster: + ver, err = w.onFlashbackCluster(d, t, job) + case model.ActionMultiSchemaChange: + ver, err = onMultiSchemaChange(w, d, t, job) + case model.ActionReorganizePartition, model.ActionRemovePartitioning, + model.ActionAlterTablePartitioning: + ver, err = w.onReorganizePartition(d, t, job) + case model.ActionAlterTTLInfo: + ver, err = onTTLInfoChange(d, t, job) + case model.ActionAlterTTLRemove: + ver, err = onTTLInfoRemove(d, t, job) + case model.ActionAddCheckConstraint: + ver, err = w.onAddCheckConstraint(d, t, job) + case model.ActionDropCheckConstraint: + ver, err = onDropCheckConstraint(d, t, job) + case model.ActionAlterCheckConstraint: + ver, err = w.onAlterCheckConstraint(d, t, job) + default: + // Invalid job, cancel it. + job.State = model.JobStateCancelled + err = dbterror.ErrInvalidDDLJob.GenWithStack("invalid ddl job type: %v", job.Type) + } + + // there are too many job types, instead let every job type output its own + // updateRawArgs, we try to use these rules as a generalization: + // + // if job has no error, some arguments may be changed, there's no harm to update + // it. + updateRawArgs = err == nil + // if job changed from running to rolling back, arguments may be changed + if prevState == model.JobStateRunning && job.IsRollingback() { + updateRawArgs = true + } + + // Save errors in job if any, so that others can know errors happened. + if err != nil { + err = w.countForError(err, job) + } + return ver, updateRawArgs, err +} + +func loadDDLVars(w *worker) error { + // Get sessionctx from context resource pool. + var ctx sessionctx.Context + ctx, err := w.sessPool.Get() + if err != nil { + return errors.Trace(err) + } + defer w.sessPool.Put(ctx) + return util.LoadDDLVars(ctx) +} + +func toTError(err error) *terror.Error { + originErr := errors.Cause(err) + tErr, ok := originErr.(*terror.Error) + if ok { + return tErr + } + + // TODO: Add the error code. + return dbterror.ClassDDL.Synthesize(terror.CodeUnknown, err.Error()) +} + +// waitSchemaChanged waits for the completion of updating all servers' schema or MDL synced. In order to make sure that happens, +// we wait at most 2 * lease time(sessionTTL, 90 seconds). 
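+// As the code below shows: when MDL is enabled a failed version push is returned to
+// the caller, while with MDL disabled a context.DeadlineExceeded after 2 * lease is
+// tolerated (and the wait skipped) because every node reloads the schema by ticker;
+// otherwise checkAllVersions waits until all nodes report the new schema version.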
+func waitSchemaChanged(ctx context.Context, d *ddlCtx, latestSchemaVersion int64, job *model.Job) error { + if !job.IsRunning() && !job.IsRollingback() && !job.IsDone() && !job.IsRollbackDone() { + return nil + } + + timeStart := time.Now() + var err error + defer func() { + metrics.DDLWorkerHistogram.WithLabelValues(metrics.WorkerWaitSchemaChanged, job.Type.String(), metrics.RetLabel(err)).Observe(time.Since(timeStart).Seconds()) + }() + + if latestSchemaVersion == 0 { + logutil.DDLLogger().Info("schema version doesn't change", zap.Int64("jobID", job.ID)) + return nil + } + + err = d.schemaSyncer.OwnerUpdateGlobalVersion(ctx, latestSchemaVersion) + if err != nil { + logutil.DDLLogger().Info("update latest schema version failed", zap.Int64("ver", latestSchemaVersion), zap.Error(err)) + if variable.EnableMDL.Load() { + return err + } + if terror.ErrorEqual(err, context.DeadlineExceeded) { + // If err is context.DeadlineExceeded, it means waitTime(2 * lease) is elapsed. So all the schemas are synced by ticker. + // There is no need to use etcd to sync. The function returns directly. + return nil + } + } + + return checkAllVersions(ctx, d, job, latestSchemaVersion, timeStart) +} + +// waitSchemaSyncedForMDL likes waitSchemaSynced, but it waits for getting the metadata lock of the latest version of this DDL. +func waitSchemaSyncedForMDL(ctx context.Context, d *ddlCtx, job *model.Job, latestSchemaVersion int64) error { + timeStart := time.Now() + return checkAllVersions(ctx, d, job, latestSchemaVersion, timeStart) +} + +func buildPlacementAffects(oldIDs []int64, newIDs []int64) []*model.AffectedOption { + if len(oldIDs) == 0 { + return nil + } + + affects := make([]*model.AffectedOption, len(oldIDs)) + for i := 0; i < len(oldIDs); i++ { + affects[i] = &model.AffectedOption{ + OldTableID: oldIDs[i], + TableID: newIDs[i], + } + } + return affects +} diff --git a/pkg/ddl/mock.go b/pkg/ddl/mock.go index 7be8f499fa01e..330aa62ce6980 100644 --- a/pkg/ddl/mock.go +++ b/pkg/ddl/mock.go @@ -72,11 +72,11 @@ func (*MockSchemaSyncer) WatchGlobalSchemaVer(context.Context) {} // UpdateSelfVersion implements SchemaSyncer.UpdateSelfVersion interface. 
func (s *MockSchemaSyncer) UpdateSelfVersion(_ context.Context, jobID int64, version int64) error { - failpoint.Inject("mockUpdateMDLToETCDError", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockUpdateMDLToETCDError")); _err_ == nil { if val.(bool) { - failpoint.Return(errors.New("mock update mdl to etcd error")) + return errors.New("mock update mdl to etcd error") } - }) + } if variable.EnableMDL.Load() { s.mdlSchemaVersions.Store(jobID, version) } else { @@ -115,20 +115,20 @@ func (s *MockSchemaSyncer) OwnerCheckAllVersions(ctx context.Context, jobID int6 ticker := time.NewTicker(mockCheckVersInterval) defer ticker.Stop() - failpoint.Inject("mockOwnerCheckAllVersionSlow", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockOwnerCheckAllVersionSlow")); _err_ == nil { if v, ok := val.(int); ok && v == int(jobID) { time.Sleep(2 * time.Second) } - }) + } for { select { case <-ctx.Done(): - failpoint.Inject("checkOwnerCheckAllVersionsWaitTime", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("checkOwnerCheckAllVersionsWaitTime")); _err_ == nil { if v.(bool) { panic("shouldn't happen") } - }) + } return errors.Trace(ctx.Err()) case <-ticker.C: if variable.EnableMDL.Load() { @@ -181,12 +181,12 @@ func (s *MockStateSyncer) Init(context.Context) error { // UpdateGlobalState implements StateSyncer.UpdateGlobalState interface. func (s *MockStateSyncer) UpdateGlobalState(_ context.Context, stateInfo *syncer.StateInfo) error { - failpoint.Inject("mockUpgradingState", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockUpgradingState")); _err_ == nil { if val.(bool) { clusterState.Store(stateInfo) - failpoint.Return(nil) + return nil } - }) + } s.globalVerCh <- clientv3.WatchResponse{} clusterState.Store(stateInfo) return nil diff --git a/pkg/ddl/mock.go__failpoint_stash__ b/pkg/ddl/mock.go__failpoint_stash__ new file mode 100644 index 0000000000000..7be8f499fa01e --- /dev/null +++ b/pkg/ddl/mock.go__failpoint_stash__ @@ -0,0 +1,260 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/ddl/syncer" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + clientv3 "go.etcd.io/etcd/client/v3" + atomicutil "go.uber.org/atomic" +) + +// SetBatchInsertDeleteRangeSize sets the batch insert/delete range size in the test +func SetBatchInsertDeleteRangeSize(i int) { + batchInsertDeleteRangeSize = i +} + +var _ syncer.SchemaSyncer = &MockSchemaSyncer{} + +const mockCheckVersInterval = 2 * time.Millisecond + +// MockSchemaSyncer is a mock schema syncer, it is exported for testing. 
+type MockSchemaSyncer struct { + selfSchemaVersion int64 + mdlSchemaVersions sync.Map + globalVerCh chan clientv3.WatchResponse + mockSession chan struct{} +} + +// NewMockSchemaSyncer creates a new mock SchemaSyncer. +func NewMockSchemaSyncer() syncer.SchemaSyncer { + return &MockSchemaSyncer{} +} + +// Init implements SchemaSyncer.Init interface. +func (s *MockSchemaSyncer) Init(_ context.Context) error { + s.mdlSchemaVersions = sync.Map{} + s.globalVerCh = make(chan clientv3.WatchResponse, 1) + s.mockSession = make(chan struct{}, 1) + return nil +} + +// GlobalVersionCh implements SchemaSyncer.GlobalVersionCh interface. +func (s *MockSchemaSyncer) GlobalVersionCh() clientv3.WatchChan { + return s.globalVerCh +} + +// WatchGlobalSchemaVer implements SchemaSyncer.WatchGlobalSchemaVer interface. +func (*MockSchemaSyncer) WatchGlobalSchemaVer(context.Context) {} + +// UpdateSelfVersion implements SchemaSyncer.UpdateSelfVersion interface. +func (s *MockSchemaSyncer) UpdateSelfVersion(_ context.Context, jobID int64, version int64) error { + failpoint.Inject("mockUpdateMDLToETCDError", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(errors.New("mock update mdl to etcd error")) + } + }) + if variable.EnableMDL.Load() { + s.mdlSchemaVersions.Store(jobID, version) + } else { + atomic.StoreInt64(&s.selfSchemaVersion, version) + } + return nil +} + +// Done implements SchemaSyncer.Done interface. +func (s *MockSchemaSyncer) Done() <-chan struct{} { + return s.mockSession +} + +// CloseSession mockSession, it is exported for testing. +func (s *MockSchemaSyncer) CloseSession() { + close(s.mockSession) +} + +// Restart implements SchemaSyncer.Restart interface. +func (s *MockSchemaSyncer) Restart(_ context.Context) error { + s.mockSession = make(chan struct{}, 1) + return nil +} + +// OwnerUpdateGlobalVersion implements SchemaSyncer.OwnerUpdateGlobalVersion interface. +func (s *MockSchemaSyncer) OwnerUpdateGlobalVersion(_ context.Context, _ int64) error { + select { + case s.globalVerCh <- clientv3.WatchResponse{}: + default: + } + return nil +} + +// OwnerCheckAllVersions implements SchemaSyncer.OwnerCheckAllVersions interface. +func (s *MockSchemaSyncer) OwnerCheckAllVersions(ctx context.Context, jobID int64, latestVer int64) error { + ticker := time.NewTicker(mockCheckVersInterval) + defer ticker.Stop() + + failpoint.Inject("mockOwnerCheckAllVersionSlow", func(val failpoint.Value) { + if v, ok := val.(int); ok && v == int(jobID) { + time.Sleep(2 * time.Second) + } + }) + + for { + select { + case <-ctx.Done(): + failpoint.Inject("checkOwnerCheckAllVersionsWaitTime", func(v failpoint.Value) { + if v.(bool) { + panic("shouldn't happen") + } + }) + return errors.Trace(ctx.Err()) + case <-ticker.C: + if variable.EnableMDL.Load() { + ver, ok := s.mdlSchemaVersions.Load(jobID) + if ok && ver.(int64) >= latestVer { + return nil + } + } else { + ver := atomic.LoadInt64(&s.selfSchemaVersion) + if ver >= latestVer { + return nil + } + } + } + } +} + +// SyncJobSchemaVerLoop implements SchemaSyncer.SyncJobSchemaVerLoop interface. +func (*MockSchemaSyncer) SyncJobSchemaVerLoop(context.Context) { +} + +// Close implements SchemaSyncer.Close interface. +func (*MockSchemaSyncer) Close() {} + +// NewMockStateSyncer creates a new mock StateSyncer. +func NewMockStateSyncer() syncer.StateSyncer { + return &MockStateSyncer{} +} + +// clusterState mocks cluster state. +// We move it from MockStateSyncer to here. Because we want to make it unaffected by ddl close. 
+var clusterState *atomicutil.Pointer[syncer.StateInfo] + +// MockStateSyncer is a mock state syncer, it is exported for testing. +type MockStateSyncer struct { + globalVerCh chan clientv3.WatchResponse + mockSession chan struct{} +} + +// Init implements StateSyncer.Init interface. +func (s *MockStateSyncer) Init(context.Context) error { + s.globalVerCh = make(chan clientv3.WatchResponse, 1) + s.mockSession = make(chan struct{}, 1) + state := syncer.NewStateInfo(syncer.StateNormalRunning) + if clusterState == nil { + clusterState = atomicutil.NewPointer(state) + } + return nil +} + +// UpdateGlobalState implements StateSyncer.UpdateGlobalState interface. +func (s *MockStateSyncer) UpdateGlobalState(_ context.Context, stateInfo *syncer.StateInfo) error { + failpoint.Inject("mockUpgradingState", func(val failpoint.Value) { + if val.(bool) { + clusterState.Store(stateInfo) + failpoint.Return(nil) + } + }) + s.globalVerCh <- clientv3.WatchResponse{} + clusterState.Store(stateInfo) + return nil +} + +// GetGlobalState implements StateSyncer.GetGlobalState interface. +func (*MockStateSyncer) GetGlobalState(context.Context) (*syncer.StateInfo, error) { + return clusterState.Load(), nil +} + +// IsUpgradingState implements StateSyncer.IsUpgradingState interface. +func (*MockStateSyncer) IsUpgradingState() bool { + return clusterState.Load().State == syncer.StateUpgrading +} + +// WatchChan implements StateSyncer.WatchChan interface. +func (s *MockStateSyncer) WatchChan() clientv3.WatchChan { + return s.globalVerCh +} + +// Rewatch implements StateSyncer.Rewatch interface. +func (*MockStateSyncer) Rewatch(context.Context) {} + +type mockDelRange struct { +} + +// newMockDelRangeManager creates a mock delRangeManager only used for test. +func newMockDelRangeManager() delRangeManager { + return &mockDelRange{} +} + +// addDelRangeJob implements delRangeManager interface. +func (*mockDelRange) addDelRangeJob(_ context.Context, _ *model.Job) error { + return nil +} + +// removeFromGCDeleteRange implements delRangeManager interface. +func (*mockDelRange) removeFromGCDeleteRange(_ context.Context, _ int64) error { + return nil +} + +// start implements delRangeManager interface. +func (*mockDelRange) start() {} + +// clear implements delRangeManager interface. +func (*mockDelRange) clear() {} + +// MockTableInfo mocks a table info by create table stmt ast and a specified table id. 
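+// A hypothetical usage sketch (the session context and table ID are made up, and
+// the parser import is assumed):
+//
+//	p := parser.New()
+//	stmt, _ := p.ParseOneStmt("create table t (a int primary key)", "", "")
+//	tblInfo, err := MockTableInfo(sctx, stmt.(*ast.CreateTableStmt), 100)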
+func MockTableInfo(ctx sessionctx.Context, stmt *ast.CreateTableStmt, tableID int64) (*model.TableInfo, error) { + chs, coll := charset.GetDefaultCharsetAndCollate() + cols, newConstraints, err := buildColumnsAndConstraints(ctx, stmt.Cols, stmt.Constraints, chs, coll) + if err != nil { + return nil, errors.Trace(err) + } + tbl, err := BuildTableInfo(ctx, stmt.Table.Name, cols, newConstraints, "", "") + if err != nil { + return nil, errors.Trace(err) + } + tbl.ID = tableID + + if err = setTableAutoRandomBits(ctx, tbl, stmt.Cols); err != nil { + return nil, errors.Trace(err) + } + + // The specified charset will be handled in handleTableOptions + if err = handleTableOptions(stmt.Options, tbl); err != nil { + return nil, errors.Trace(err) + } + + return tbl, nil +} diff --git a/pkg/ddl/modify_column.go b/pkg/ddl/modify_column.go index 09a654d133a26..1c95073988983 100644 --- a/pkg/ddl/modify_column.go +++ b/pkg/ddl/modify_column.go @@ -84,14 +84,14 @@ func (w *worker) onModifyColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver in } } - failpoint.Inject("uninitializedOffsetAndState", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("uninitializedOffsetAndState")); _err_ == nil { //nolint:forcetypeassert if val.(bool) { if modifyInfo.newCol.State != model.StatePublic { - failpoint.Return(ver, errors.New("the column state is wrong")) + return ver, errors.New("the column state is wrong") } } - }) + } err = checkAndApplyAutoRandomBits(d, t, dbInfo, tblInfo, oldCol, modifyInfo.newCol, modifyInfo.updatedAutoRandomBits) if err != nil { @@ -442,12 +442,12 @@ func (w *worker) doModifyColumnTypeWithData( } // none -> delete only updateChangingObjState(changingCol, changingIdxs, model.StateDeleteOnly) - failpoint.Inject("mockInsertValueAfterCheckNull", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockInsertValueAfterCheckNull")); _err_ == nil { if valStr, ok := val.(string); ok { var sctx sessionctx.Context sctx, err := w.sessPool.Get() if err != nil { - failpoint.Return(ver, err) + return ver, err } defer w.sessPool.Put(sctx) @@ -456,10 +456,10 @@ func (w *worker) doModifyColumnTypeWithData( _, _, err = sctx.GetRestrictedSQLExecutor().ExecRestrictedSQL(ctx, nil, valStr) if err != nil { job.State = model.JobStateCancelled - failpoint.Return(ver, err) + return ver, err } } - }) + } ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, originalState != changingCol.State) if err != nil { return ver, errors.Trace(err) @@ -488,7 +488,7 @@ func (w *worker) doModifyColumnTypeWithData( return ver, errors.Trace(err) } job.SchemaState = model.StateWriteOnly - failpoint.InjectCall("afterModifyColumnStateDeleteOnly", job.ID) + failpoint.Call(_curpkg_("afterModifyColumnStateDeleteOnly"), job.ID) case model.StateWriteOnly: // write only -> reorganization updateChangingObjState(changingCol, changingIdxs, model.StateWriteReorganization) @@ -587,7 +587,7 @@ func doReorgWorkForModifyColumn(w *worker, d *ddlCtx, t *meta.Meta, job *model.J // With a failpoint-enabled version of TiDB, you can trigger this failpoint by the following command: // enable: curl -X PUT -d "pause" "http://127.0.0.1:10080/fail/github.com/pingcap/tidb/pkg/ddl/mockDelayInModifyColumnTypeWithData". 
// disable: curl -X DELETE "http://127.0.0.1:10080/fail/github.com/pingcap/tidb/pkg/ddl/mockDelayInModifyColumnTypeWithData" - failpoint.Inject("mockDelayInModifyColumnTypeWithData", func() {}) + failpoint.Eval(_curpkg_("mockDelayInModifyColumnTypeWithData")) err = w.runReorgJob(reorgInfo, tbl.Meta(), d.lease, func() (addIndexErr error) { defer util.Recover(metrics.LabelDDL, "onModifyColumn", func() { diff --git a/pkg/ddl/modify_column.go__failpoint_stash__ b/pkg/ddl/modify_column.go__failpoint_stash__ new file mode 100644 index 0000000000000..09a654d133a26 --- /dev/null +++ b/pkg/ddl/modify_column.go__failpoint_stash__ @@ -0,0 +1,1318 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl + +import ( + "bytes" + "context" + "fmt" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/ddl/logutil" + sess "github.com/pingcap/tidb/pkg/ddl/session" + "github.com/pingcap/tidb/pkg/errctx" + "github.com/pingcap/tidb/pkg/expression" + exprctx "github.com/pingcap/tidb/pkg/expression/context" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/format" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + statsutil "github.com/pingcap/tidb/pkg/statistics/handle/util" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/dbterror" + "go.uber.org/zap" +) + +type modifyingColInfo struct { + newCol *model.ColumnInfo + oldColName *model.CIStr + modifyColumnTp byte + updatedAutoRandomBits uint64 + changingCol *model.ColumnInfo + changingIdxs []*model.IndexInfo + pos *ast.ColumnPosition + removedIdxs []int64 +} + +func (w *worker) onModifyColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + dbInfo, tblInfo, oldCol, modifyInfo, err := getModifyColumnInfo(t, job) + if err != nil { + return ver, err + } + + if job.IsRollingback() { + // For those column-type-change jobs which don't reorg the data. + if !needChangeColumnData(oldCol, modifyInfo.newCol) { + return rollbackModifyColumnJob(d, t, tblInfo, job, modifyInfo.newCol, oldCol, modifyInfo.modifyColumnTp) + } + // For those column-type-change jobs which reorg the data. + return rollbackModifyColumnJobWithData(d, t, tblInfo, job, oldCol, modifyInfo) + } + + // If we want to rename the column name, we need to check whether it already exists. 
+ if modifyInfo.newCol.Name.L != modifyInfo.oldColName.L { + c := model.FindColumnInfo(tblInfo.Columns, modifyInfo.newCol.Name.L) + if c != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(infoschema.ErrColumnExists.GenWithStackByArgs(modifyInfo.newCol.Name)) + } + } + + failpoint.Inject("uninitializedOffsetAndState", func(val failpoint.Value) { + //nolint:forcetypeassert + if val.(bool) { + if modifyInfo.newCol.State != model.StatePublic { + failpoint.Return(ver, errors.New("the column state is wrong")) + } + } + }) + + err = checkAndApplyAutoRandomBits(d, t, dbInfo, tblInfo, oldCol, modifyInfo.newCol, modifyInfo.updatedAutoRandomBits) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + if !needChangeColumnData(oldCol, modifyInfo.newCol) { + return w.doModifyColumn(d, t, job, dbInfo, tblInfo, modifyInfo.newCol, oldCol, modifyInfo.pos) + } + + if err = isGeneratedRelatedColumn(tblInfo, modifyInfo.newCol, oldCol); err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + if tblInfo.Partition != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs("table is partition table")) + } + + changingCol := modifyInfo.changingCol + if changingCol == nil { + newColName := model.NewCIStr(genChangingColumnUniqueName(tblInfo, oldCol)) + if mysql.HasPriKeyFlag(oldCol.GetFlag()) { + job.State = model.JobStateCancelled + msg := "this column has primary key flag" + return ver, dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs(msg) + } + + changingCol = modifyInfo.newCol.Clone() + changingCol.Name = newColName + changingCol.ChangeStateInfo = &model.ChangeStateInfo{DependencyColumnOffset: oldCol.Offset} + + originDefVal, err := GetOriginDefaultValueForModifyColumn(newReorgExprCtx(), changingCol, oldCol) + if err != nil { + return ver, errors.Trace(err) + } + if err = changingCol.SetOriginDefaultValue(originDefVal); err != nil { + return ver, errors.Trace(err) + } + + InitAndAddColumnToTable(tblInfo, changingCol) + indexesToChange := FindRelatedIndexesToChange(tblInfo, oldCol.Name) + for _, info := range indexesToChange { + newIdxID := AllocateIndexID(tblInfo) + if !info.isTemp { + // We create a temp index for each normal index. + tmpIdx := info.IndexInfo.Clone() + tmpIdxName := genChangingIndexUniqueName(tblInfo, info.IndexInfo) + setIdxIDName(tmpIdx, newIdxID, model.NewCIStr(tmpIdxName)) + SetIdxColNameOffset(tmpIdx.Columns[info.Offset], changingCol) + tblInfo.Indices = append(tblInfo.Indices, tmpIdx) + } else { + // The index is a temp index created by previous modify column job(s). + // We can overwrite it to reduce reorg cost, because it will be dropped eventually. 
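+ // The overwritten temp index keeps its name but gets a fresh ID; the old ID is recorded in
+ // modifyInfo.removedIdxs below so its data is still cleaned up by the delete-range worker.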
+ tmpIdx := info.IndexInfo + oldTempIdxID := tmpIdx.ID + setIdxIDName(tmpIdx, newIdxID, tmpIdx.Name /* unchanged */) + SetIdxColNameOffset(tmpIdx.Columns[info.Offset], changingCol) + modifyInfo.removedIdxs = append(modifyInfo.removedIdxs, oldTempIdxID) + } + } + } else { + changingCol = model.FindColumnInfoByID(tblInfo.Columns, modifyInfo.changingCol.ID) + if changingCol == nil { + logutil.DDLLogger().Error("the changing column has been removed", zap.Error(err)) + job.State = model.JobStateCancelled + return ver, errors.Trace(infoschema.ErrColumnNotExists.GenWithStackByArgs(oldCol.Name, tblInfo.Name)) + } + } + + return w.doModifyColumnTypeWithData(d, t, job, dbInfo, tblInfo, changingCol, oldCol, modifyInfo.newCol.Name, modifyInfo.pos, modifyInfo.removedIdxs) +} + +// rollbackModifyColumnJob rollbacks the job when an error occurs. +func rollbackModifyColumnJob(d *ddlCtx, t *meta.Meta, tblInfo *model.TableInfo, job *model.Job, newCol, oldCol *model.ColumnInfo, modifyColumnTp byte) (ver int64, _ error) { + var err error + if oldCol.ID == newCol.ID && modifyColumnTp == mysql.TypeNull { + // field NotNullFlag flag reset. + tblInfo.Columns[oldCol.Offset].SetFlag(oldCol.GetFlag() &^ mysql.NotNullFlag) + // field PreventNullInsertFlag flag reset. + tblInfo.Columns[oldCol.Offset].SetFlag(oldCol.GetFlag() &^ mysql.PreventNullInsertFlag) + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + } + job.FinishTableJob(model.JobStateRollbackDone, model.StateNone, ver, tblInfo) + // For those column-type-change type which doesn't need reorg data, we should also mock the job args for delete range. + job.Args = []any{[]int64{}, []int64{}} + return ver, nil +} + +func getModifyColumnInfo(t *meta.Meta, job *model.Job) (*model.DBInfo, *model.TableInfo, *model.ColumnInfo, *modifyingColInfo, error) { + modifyInfo := &modifyingColInfo{pos: &ast.ColumnPosition{}} + err := job.DecodeArgs(&modifyInfo.newCol, &modifyInfo.oldColName, modifyInfo.pos, &modifyInfo.modifyColumnTp, + &modifyInfo.updatedAutoRandomBits, &modifyInfo.changingCol, &modifyInfo.changingIdxs, &modifyInfo.removedIdxs) + if err != nil { + job.State = model.JobStateCancelled + return nil, nil, nil, modifyInfo, errors.Trace(err) + } + + dbInfo, err := checkSchemaExistAndCancelNotExistJob(t, job) + if err != nil { + return nil, nil, nil, modifyInfo, errors.Trace(err) + } + + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) + if err != nil { + return nil, nil, nil, modifyInfo, errors.Trace(err) + } + + oldCol := model.FindColumnInfo(tblInfo.Columns, modifyInfo.oldColName.L) + if oldCol == nil || oldCol.State != model.StatePublic { + job.State = model.JobStateCancelled + return nil, nil, nil, modifyInfo, errors.Trace(infoschema.ErrColumnNotExists.GenWithStackByArgs(*(modifyInfo.oldColName), tblInfo.Name)) + } + + return dbInfo, tblInfo, oldCol, modifyInfo, errors.Trace(err) +} + +// GetOriginDefaultValueForModifyColumn gets the original default value for modifying column. +// Since column type change is implemented as adding a new column then substituting the old one. +// Case exists when update-where statement fetch a NULL for not-null column without any default data, +// it will errors. +// So we set original default value here to prevent this error. If the oldCol has the original default value, we use it. +// Otherwise we set the zero value as original default value. 
+// Besides, in insert & update records, we have already implement using the casted value of relative column to insert +// rather than the original default value. +func GetOriginDefaultValueForModifyColumn(ctx exprctx.BuildContext, changingCol, oldCol *model.ColumnInfo) (any, error) { + var err error + originDefVal := oldCol.GetOriginDefaultValue() + if originDefVal != nil { + odv, err := table.CastColumnValue(ctx, types.NewDatum(originDefVal), changingCol, false, false) + if err != nil { + logutil.DDLLogger().Info("cast origin default value failed", zap.Error(err)) + } + if !odv.IsNull() { + if originDefVal, err = odv.ToString(); err != nil { + originDefVal = nil + logutil.DDLLogger().Info("convert default value to string failed", zap.Error(err)) + } + } + } + if originDefVal == nil { + originDefVal, err = generateOriginDefaultValue(changingCol, nil) + if err != nil { + return nil, errors.Trace(err) + } + } + return originDefVal, nil +} + +// rollbackModifyColumnJobWithData is used to rollback modify-column job which need to reorg the data. +func rollbackModifyColumnJobWithData(d *ddlCtx, t *meta.Meta, tblInfo *model.TableInfo, job *model.Job, oldCol *model.ColumnInfo, modifyInfo *modifyingColInfo) (ver int64, err error) { + // If the not-null change is included, we should clean the flag info in oldCol. + if modifyInfo.modifyColumnTp == mysql.TypeNull { + // Reset NotNullFlag flag. + tblInfo.Columns[oldCol.Offset].SetFlag(oldCol.GetFlag() &^ mysql.NotNullFlag) + // Reset PreventNullInsertFlag flag. + tblInfo.Columns[oldCol.Offset].SetFlag(oldCol.GetFlag() &^ mysql.PreventNullInsertFlag) + } + var changingIdxIDs []int64 + if modifyInfo.changingCol != nil { + changingIdxIDs = buildRelatedIndexIDs(tblInfo, modifyInfo.changingCol.ID) + // The job is in the middle state. The appended changingCol and changingIndex should + // be removed from the tableInfo as well. + removeChangingColAndIdxs(tblInfo, modifyInfo.changingCol.ID) + } + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + job.FinishTableJob(model.JobStateRollbackDone, model.StateNone, ver, tblInfo) + // Reconstruct the job args to add the temporary index ids into delete range table. + job.Args = []any{changingIdxIDs, getPartitionIDs(tblInfo)} + return ver, nil +} + +// doModifyColumn updates the column information and reorders all columns. It does not support modifying column data. +func (w *worker) doModifyColumn( + d *ddlCtx, t *meta.Meta, job *model.Job, dbInfo *model.DBInfo, tblInfo *model.TableInfo, + newCol, oldCol *model.ColumnInfo, pos *ast.ColumnPosition) (ver int64, _ error) { + if oldCol.ID != newCol.ID { + job.State = model.JobStateRollingback + return ver, dbterror.ErrColumnInChange.GenWithStackByArgs(oldCol.Name, newCol.ID) + } + // Column from null to not null. + if !mysql.HasNotNullFlag(oldCol.GetFlag()) && mysql.HasNotNullFlag(newCol.GetFlag()) { + noPreventNullFlag := !mysql.HasPreventNullInsertFlag(oldCol.GetFlag()) + + // lease = 0 means it's in an integration test. In this case we don't delay so the test won't run too slowly. + // We need to check after the flag is set + if d.lease > 0 && !noPreventNullFlag { + delayForAsyncCommit() + } + + // Introduce the `mysql.PreventNullInsertFlag` flag to prevent users from inserting or updating null values. 
+ err := modifyColsFromNull2NotNull(w, dbInfo, tblInfo, []*model.ColumnInfo{oldCol}, newCol, oldCol.GetType() != newCol.GetType()) + if err != nil { + if dbterror.ErrWarnDataTruncated.Equal(err) || dbterror.ErrInvalidUseOfNull.Equal(err) { + job.State = model.JobStateRollingback + } + return ver, err + } + // The column should get into prevent null status first. + if noPreventNullFlag { + return updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) + } + } + + if job.MultiSchemaInfo != nil && job.MultiSchemaInfo.Revertible { + job.MarkNonRevertible() + // Store the mark and enter the next DDL handling loop. + return updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, false) + } + + if err := adjustTableInfoAfterModifyColumn(tblInfo, newCol, oldCol, pos); err != nil { + job.State = model.JobStateRollingback + return ver, errors.Trace(err) + } + + childTableInfos, err := adjustForeignKeyChildTableInfoAfterModifyColumn(d, t, job, tblInfo, newCol, oldCol) + if err != nil { + return ver, errors.Trace(err) + } + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true, childTableInfos...) + if err != nil { + // Modified the type definition of 'null' to 'not null' before this, so rollBack the job when an error occurs. + job.State = model.JobStateRollingback + return ver, errors.Trace(err) + } + + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + // For those column-type-change type which doesn't need reorg data, we should also mock the job args for delete range. + job.Args = []any{[]int64{}, []int64{}} + return ver, nil +} + +func adjustTableInfoAfterModifyColumn( + tblInfo *model.TableInfo, newCol, oldCol *model.ColumnInfo, pos *ast.ColumnPosition) error { + // We need the latest column's offset and state. This information can be obtained from the store. + newCol.Offset = oldCol.Offset + newCol.State = oldCol.State + if pos != nil && pos.RelativeColumn != nil && oldCol.Name.L == pos.RelativeColumn.Name.L { + // For cases like `modify column b after b`, it should report this error. 
+ return errors.Trace(infoschema.ErrColumnNotExists.GenWithStackByArgs(oldCol.Name, tblInfo.Name)) + } + destOffset, err := LocateOffsetToMove(oldCol.Offset, pos, tblInfo) + if err != nil { + return errors.Trace(infoschema.ErrColumnNotExists.GenWithStackByArgs(oldCol.Name, tblInfo.Name)) + } + tblInfo.Columns[oldCol.Offset] = newCol + tblInfo.MoveColumnInfo(oldCol.Offset, destOffset) + updateNewIdxColsNameOffset(tblInfo.Indices, oldCol.Name, newCol) + updateFKInfoWhenModifyColumn(tblInfo, oldCol.Name, newCol.Name) + updateTTLInfoWhenModifyColumn(tblInfo, oldCol.Name, newCol.Name) + return nil +} + +func updateFKInfoWhenModifyColumn(tblInfo *model.TableInfo, oldCol, newCol model.CIStr) { + if oldCol.L == newCol.L { + return + } + for _, fk := range tblInfo.ForeignKeys { + for i := range fk.Cols { + if fk.Cols[i].L == oldCol.L { + fk.Cols[i] = newCol + } + } + } +} + +func updateTTLInfoWhenModifyColumn(tblInfo *model.TableInfo, oldCol, newCol model.CIStr) { + if oldCol.L == newCol.L { + return + } + if tblInfo.TTLInfo != nil { + if tblInfo.TTLInfo.ColumnName.L == oldCol.L { + tblInfo.TTLInfo.ColumnName = newCol + } + } +} + +func adjustForeignKeyChildTableInfoAfterModifyColumn(d *ddlCtx, t *meta.Meta, job *model.Job, tblInfo *model.TableInfo, newCol, oldCol *model.ColumnInfo) ([]schemaIDAndTableInfo, error) { + if !variable.EnableForeignKey.Load() || newCol.Name.L == oldCol.Name.L { + return nil, nil + } + is, err := getAndCheckLatestInfoSchema(d, t) + if err != nil { + return nil, err + } + referredFKs := is.GetTableReferredForeignKeys(job.SchemaName, tblInfo.Name.L) + if len(referredFKs) == 0 { + return nil, nil + } + fkh := newForeignKeyHelper() + fkh.addLoadedTable(job.SchemaName, tblInfo.Name.L, job.SchemaID, tblInfo) + for _, referredFK := range referredFKs { + info, err := fkh.getTableFromStorage(is, t, referredFK.ChildSchema, referredFK.ChildTable) + if err != nil { + if infoschema.ErrTableNotExists.Equal(err) || infoschema.ErrDatabaseNotExists.Equal(err) { + continue + } + return nil, err + } + fkInfo := model.FindFKInfoByName(info.tblInfo.ForeignKeys, referredFK.ChildFKName.L) + if fkInfo == nil { + continue + } + for i := range fkInfo.RefCols { + if fkInfo.RefCols[i].L == oldCol.Name.L { + fkInfo.RefCols[i] = newCol.Name + } + } + } + infoList := make([]schemaIDAndTableInfo, 0, len(fkh.loaded)) + for _, info := range fkh.loaded { + if info.tblInfo.ID == tblInfo.ID { + continue + } + infoList = append(infoList, info) + } + return infoList, nil +} + +func (w *worker) doModifyColumnTypeWithData( + d *ddlCtx, t *meta.Meta, job *model.Job, + dbInfo *model.DBInfo, tblInfo *model.TableInfo, changingCol, oldCol *model.ColumnInfo, + colName model.CIStr, pos *ast.ColumnPosition, rmIdxIDs []int64) (ver int64, _ error) { + var err error + originalState := changingCol.State + targetCol := changingCol.Clone() + targetCol.Name = colName + changingIdxs := buildRelatedIndexInfos(tblInfo, changingCol.ID) + switch changingCol.State { + case model.StateNone: + // Column from null to not null. + if !mysql.HasNotNullFlag(oldCol.GetFlag()) && mysql.HasNotNullFlag(changingCol.GetFlag()) { + // Introduce the `mysql.PreventNullInsertFlag` flag to prevent users from inserting or updating null values. 
+ err := modifyColsFromNull2NotNull(w, dbInfo, tblInfo, []*model.ColumnInfo{oldCol}, targetCol, oldCol.GetType() != changingCol.GetType()) + if err != nil { + if dbterror.ErrWarnDataTruncated.Equal(err) || dbterror.ErrInvalidUseOfNull.Equal(err) { + job.State = model.JobStateRollingback + } + return ver, err + } + } + // none -> delete only + updateChangingObjState(changingCol, changingIdxs, model.StateDeleteOnly) + failpoint.Inject("mockInsertValueAfterCheckNull", func(val failpoint.Value) { + if valStr, ok := val.(string); ok { + var sctx sessionctx.Context + sctx, err := w.sessPool.Get() + if err != nil { + failpoint.Return(ver, err) + } + defer w.sessPool.Put(sctx) + + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) + //nolint:forcetypeassert + _, _, err = sctx.GetRestrictedSQLExecutor().ExecRestrictedSQL(ctx, nil, valStr) + if err != nil { + job.State = model.JobStateCancelled + failpoint.Return(ver, err) + } + } + }) + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, originalState != changingCol.State) + if err != nil { + return ver, errors.Trace(err) + } + // Make sure job args change after `updateVersionAndTableInfoWithCheck`, otherwise, the job args will + // be updated in `updateDDLJob` even if it meets an error in `updateVersionAndTableInfoWithCheck`. + job.SchemaState = model.StateDeleteOnly + metrics.GetBackfillProgressByLabel(metrics.LblModifyColumn, job.SchemaName, tblInfo.Name.String()).Set(0) + job.Args = append(job.Args, changingCol, changingIdxs, rmIdxIDs) + case model.StateDeleteOnly: + // Column from null to not null. + if !mysql.HasNotNullFlag(oldCol.GetFlag()) && mysql.HasNotNullFlag(changingCol.GetFlag()) { + // Introduce the `mysql.PreventNullInsertFlag` flag to prevent users from inserting or updating null values. + err := modifyColsFromNull2NotNull(w, dbInfo, tblInfo, []*model.ColumnInfo{oldCol}, targetCol, oldCol.GetType() != changingCol.GetType()) + if err != nil { + if dbterror.ErrWarnDataTruncated.Equal(err) || dbterror.ErrInvalidUseOfNull.Equal(err) { + job.State = model.JobStateRollingback + } + return ver, err + } + } + // delete only -> write only + updateChangingObjState(changingCol, changingIdxs, model.StateWriteOnly) + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != changingCol.State) + if err != nil { + return ver, errors.Trace(err) + } + job.SchemaState = model.StateWriteOnly + failpoint.InjectCall("afterModifyColumnStateDeleteOnly", job.ID) + case model.StateWriteOnly: + // write only -> reorganization + updateChangingObjState(changingCol, changingIdxs, model.StateWriteReorganization) + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != changingCol.State) + if err != nil { + return ver, errors.Trace(err) + } + // Initialize SnapshotVer to 0 for later reorganization check. + job.SnapshotVer = 0 + job.SchemaState = model.StateWriteReorganization + case model.StateWriteReorganization: + tbl, err := getTable(d.getAutoIDRequirement(), dbInfo.ID, tblInfo) + if err != nil { + return ver, errors.Trace(err) + } + + var done bool + if job.MultiSchemaInfo != nil { + done, ver, err = doReorgWorkForModifyColumnMultiSchema(w, d, t, job, tbl, oldCol, changingCol, changingIdxs) + } else { + done, ver, err = doReorgWorkForModifyColumn(w, d, t, job, tbl, oldCol, changingCol, changingIdxs) + } + if !done { + return ver, err + } + + rmIdxIDs = append(buildRelatedIndexIDs(tblInfo, oldCol.ID), rmIdxIDs...) 
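+ // rmIdxIDs now covers the old column's indexes plus any temp indexes that were overwritten
+ // earlier; they are handed to the delete-range worker through job.Args when the job finishes.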
+ + err = adjustTableInfoAfterModifyColumnWithData(tblInfo, pos, oldCol, changingCol, colName, changingIdxs) + if err != nil { + job.State = model.JobStateRollingback + return ver, errors.Trace(err) + } + + updateChangingObjState(changingCol, changingIdxs, model.StatePublic) + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != changingCol.State) + if err != nil { + return ver, errors.Trace(err) + } + + // Finish this job. + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + // Refactor the job args to add the old index ids into delete range table. + job.Args = []any{rmIdxIDs, getPartitionIDs(tblInfo)} + modifyColumnEvent := statsutil.NewModifyColumnEvent( + job.SchemaID, + tblInfo, + []*model.ColumnInfo{changingCol}, + ) + asyncNotifyEvent(d, modifyColumnEvent) + default: + err = dbterror.ErrInvalidDDLState.GenWithStackByArgs("column", changingCol.State) + } + + return ver, errors.Trace(err) +} + +func doReorgWorkForModifyColumnMultiSchema(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job, tbl table.Table, + oldCol, changingCol *model.ColumnInfo, changingIdxs []*model.IndexInfo) (done bool, ver int64, err error) { + if job.MultiSchemaInfo.Revertible { + done, ver, err = doReorgWorkForModifyColumn(w, d, t, job, tbl, oldCol, changingCol, changingIdxs) + if done { + // We need another round to wait for all the others sub-jobs to finish. + job.MarkNonRevertible() + } + // We need another round to run the reorg process. + return false, ver, err + } + // Non-revertible means all the sub jobs finished. + return true, ver, err +} + +func doReorgWorkForModifyColumn(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job, tbl table.Table, + oldCol, changingCol *model.ColumnInfo, changingIdxs []*model.IndexInfo) (done bool, ver int64, err error) { + job.ReorgMeta.ReorgTp = model.ReorgTypeTxn + sctx, err1 := w.sessPool.Get() + if err1 != nil { + err = errors.Trace(err1) + return + } + defer w.sessPool.Put(sctx) + rh := newReorgHandler(sess.NewSession(sctx)) + dbInfo, err := t.GetDatabase(job.SchemaID) + if err != nil { + return false, ver, errors.Trace(err) + } + reorgInfo, err := getReorgInfo(d.jobContext(job.ID, job.ReorgMeta), + d, rh, job, dbInfo, tbl, BuildElements(changingCol, changingIdxs), false) + if err != nil || reorgInfo == nil || reorgInfo.first { + // If we run reorg firstly, we should update the job snapshot version + // and then run the reorg next time. + return false, ver, errors.Trace(err) + } + + // Inject a failpoint so that we can pause here and do verification on other components. + // With a failpoint-enabled version of TiDB, you can trigger this failpoint by the following command: + // enable: curl -X PUT -d "pause" "http://127.0.0.1:10080/fail/github.com/pingcap/tidb/pkg/ddl/mockDelayInModifyColumnTypeWithData". + // disable: curl -X DELETE "http://127.0.0.1:10080/fail/github.com/pingcap/tidb/pkg/ddl/mockDelayInModifyColumnTypeWithData" + failpoint.Inject("mockDelayInModifyColumnTypeWithData", func() {}) + err = w.runReorgJob(reorgInfo, tbl.Meta(), d.lease, func() (addIndexErr error) { + defer util.Recover(metrics.LabelDDL, "onModifyColumn", + func() { + addIndexErr = dbterror.ErrCancelledDDLJob.GenWithStack("modify table `%v` column `%v` panic", tbl.Meta().Name, oldCol.Name) + }, false) + // Use old column name to generate less confusing error messages. 
+ changingColCpy := changingCol.Clone() + changingColCpy.Name = oldCol.Name + return w.updateCurrentElement(tbl, reorgInfo) + }) + if err != nil { + if dbterror.ErrPausedDDLJob.Equal(err) { + return false, ver, nil + } + + if dbterror.ErrWaitReorgTimeout.Equal(err) { + // If timeout, we should return, check for the owner and re-wait job done. + return false, ver, nil + } + if kv.IsTxnRetryableError(err) || dbterror.ErrNotOwner.Equal(err) { + return false, ver, errors.Trace(err) + } + if err1 := rh.RemoveDDLReorgHandle(job, reorgInfo.elements); err1 != nil { + logutil.DDLLogger().Warn("run modify column job failed, RemoveDDLReorgHandle failed, can't convert job to rollback", + zap.String("job", job.String()), zap.Error(err1)) + } + logutil.DDLLogger().Warn("run modify column job failed, convert job to rollback", zap.Stringer("job", job), zap.Error(err)) + job.State = model.JobStateRollingback + return false, ver, errors.Trace(err) + } + return true, ver, nil +} + +func adjustTableInfoAfterModifyColumnWithData(tblInfo *model.TableInfo, pos *ast.ColumnPosition, + oldCol, changingCol *model.ColumnInfo, newName model.CIStr, changingIdxs []*model.IndexInfo) (err error) { + if pos != nil && pos.RelativeColumn != nil && oldCol.Name.L == pos.RelativeColumn.Name.L { + // For cases like `modify column b after b`, it should report this error. + return errors.Trace(infoschema.ErrColumnNotExists.GenWithStackByArgs(oldCol.Name, tblInfo.Name)) + } + internalColName := changingCol.Name + changingCol = replaceOldColumn(tblInfo, oldCol, changingCol, newName) + if len(changingIdxs) > 0 { + updateNewIdxColsNameOffset(changingIdxs, internalColName, changingCol) + indexesToRemove := filterIndexesToRemove(changingIdxs, newName, tblInfo) + replaceOldIndexes(tblInfo, indexesToRemove) + } + if tblInfo.TTLInfo != nil { + updateTTLInfoWhenModifyColumn(tblInfo, oldCol.Name, changingCol.Name) + } + // Move the new column to a correct offset. + destOffset, err := LocateOffsetToMove(changingCol.Offset, pos, tblInfo) + if err != nil { + return errors.Trace(err) + } + tblInfo.MoveColumnInfo(changingCol.Offset, destOffset) + return nil +} + +func checkModifyColumnWithGeneratedColumnsConstraint(allCols []*table.Column, oldColName model.CIStr) error { + for _, col := range allCols { + if col.GeneratedExpr == nil { + continue + } + dependedColNames := FindColumnNamesInExpr(col.GeneratedExpr.Internal()) + for _, name := range dependedColNames { + if name.Name.L == oldColName.L { + if col.Hidden { + return dbterror.ErrDependentByFunctionalIndex.GenWithStackByArgs(oldColName.O) + } + return dbterror.ErrDependentByGeneratedColumn.GenWithStackByArgs(oldColName.O) + } + } + } + return nil +} + +// GetModifiableColumnJob returns a DDL job of model.ActionModifyColumn. +func GetModifiableColumnJob( + ctx context.Context, + sctx sessionctx.Context, + is infoschema.InfoSchema, // WARN: is maybe nil here. 
+ ident ast.Ident,
+ originalColName model.CIStr,
+ schema *model.DBInfo,
+ t table.Table,
+ spec *ast.AlterTableSpec,
+) (*model.Job, error) {
+ var err error
+ specNewColumn := spec.NewColumns[0]
+
+ col := table.FindCol(t.Cols(), originalColName.L)
+ if col == nil {
+ return nil, infoschema.ErrColumnNotExists.GenWithStackByArgs(originalColName, ident.Name)
+ }
+ newColName := specNewColumn.Name.Name
+ if newColName.L == model.ExtraHandleName.L {
+ return nil, dbterror.ErrWrongColumnName.GenWithStackByArgs(newColName.L)
+ }
+ errG := checkModifyColumnWithGeneratedColumnsConstraint(t.Cols(), originalColName)
+
+ // If we want to rename the column name, we need to check whether it already exists.
+ if newColName.L != originalColName.L {
+ c := table.FindCol(t.Cols(), newColName.L)
+ if c != nil {
+ return nil, infoschema.ErrColumnExists.GenWithStackByArgs(newColName)
+ }
+
+ // Also check the generated column dependencies: if some generated columns
+ // depend on this column, we can't rename the column name.
+ if errG != nil {
+ return nil, errors.Trace(errG)
+ }
+ }
+
+ // Constraints in the new column mean adding new constraints. Errors should be thrown,
+ // which will be done by `processColumnOptions` later.
+ if specNewColumn.Tp == nil {
+ // Make sure the column definition is a simple field type.
+ return nil, errors.Trace(dbterror.ErrUnsupportedModifyColumn)
+ }
+
+ if err = checkColumnAttributes(specNewColumn.Name.OrigColName(), specNewColumn.Tp); err != nil {
+ return nil, errors.Trace(err)
+ }
+
+ newCol := table.ToColumn(&model.ColumnInfo{
+ ID: col.ID,
+ // We use this PR (https://github.com/pingcap/tidb/pull/6274) as the dividing line to define whether it is a new version or an old version TiDB.
+ // The old version TiDB initializes the column's offset and state here.
+ // The new version TiDB doesn't initialize the column's offset and state, and it will do the initialization in the run DDL function.
+ // When we do the rolling upgrade the following may happen:
+ // a new version TiDB builds the DDL job without setting the column's offset and state,
+ // and the old version TiDB is the DDL owner; it doesn't get the offset and state from the store and will then encounter errors.
+ // So here we set offset and state to support the rolling upgrade.
+ Offset: col.Offset,
+ State: col.State,
+ OriginDefaultValue: col.OriginDefaultValue,
+ OriginDefaultValueBit: col.OriginDefaultValueBit,
+ FieldType: *specNewColumn.Tp,
+ Name: newColName,
+ Version: col.Version,
+ })
+
+ if err = ProcessColumnCharsetAndCollation(sctx, col, newCol, t.Meta(), specNewColumn, schema); err != nil {
+ return nil, err
+ }
+
+ if err = checkModifyColumnWithForeignKeyConstraint(is, schema.Name.L, t.Meta(), col.ColumnInfo, newCol.ColumnInfo); err != nil {
+ return nil, errors.Trace(err)
+ }
+
+ // Copy index related options to the new spec.
+ indexFlags := col.FieldType.GetFlag() & (mysql.PriKeyFlag | mysql.UniqueKeyFlag | mysql.MultipleKeyFlag)
+ newCol.FieldType.AddFlag(indexFlags)
+ if mysql.HasPriKeyFlag(col.FieldType.GetFlag()) {
+ newCol.FieldType.AddFlag(mysql.NotNullFlag)
+ // TODO: If the user explicitly sets NULL, we should throw error ErrPrimaryCantHaveNull.
+ } + + if err = ProcessModifyColumnOptions(sctx, newCol, specNewColumn.Options); err != nil { + return nil, errors.Trace(err) + } + + if err = checkModifyTypes(&col.FieldType, &newCol.FieldType, isColumnWithIndex(col.Name.L, t.Meta().Indices)); err != nil { + if strings.Contains(err.Error(), "Unsupported modifying collation") { + colErrMsg := "Unsupported modifying collation of column '%s' from '%s' to '%s' when index is defined on it." + err = dbterror.ErrUnsupportedModifyCollation.GenWithStack(colErrMsg, col.Name.L, col.GetCollate(), newCol.GetCollate()) + } + return nil, errors.Trace(err) + } + needChangeColData := needChangeColumnData(col.ColumnInfo, newCol.ColumnInfo) + if needChangeColData { + if err = isGeneratedRelatedColumn(t.Meta(), newCol.ColumnInfo, col.ColumnInfo); err != nil { + return nil, errors.Trace(err) + } + if t.Meta().Partition != nil { + return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs("table is partition table") + } + } + + // Check that the column change does not affect the partitioning column + // It must keep the same type, int [unsigned], [var]char, date[time] + if t.Meta().Partition != nil { + pt, ok := t.(table.PartitionedTable) + if !ok { + // Should never happen! + return nil, dbterror.ErrNotAllowedTypeInPartition.GenWithStackByArgs(newCol.Name.O) + } + isPartitioningColumn := false + for _, name := range pt.GetPartitionColumnNames() { + if strings.EqualFold(name.L, col.Name.L) { + isPartitioningColumn = true + break + } + } + if isPartitioningColumn { + // TODO: update the partitioning columns with new names if column is renamed + // Would be an extension from MySQL which does not support it. + if col.Name.L != newCol.Name.L { + return nil, dbterror.ErrDependentByPartitionFunctional.GenWithStackByArgs(col.Name.L) + } + if !isColTypeAllowedAsPartitioningCol(t.Meta().Partition.Type, newCol.FieldType) { + return nil, dbterror.ErrNotAllowedTypeInPartition.GenWithStackByArgs(newCol.Name.O) + } + pi := pt.Meta().GetPartitionInfo() + if len(pi.Columns) == 0 { + // non COLUMNS partitioning, only checks INTs, not their actual range + // There are many edge cases, like when truncating SQL Mode is allowed + // which will change the partitioning expression value resulting in a + // different partition. Better be safe and not allow decreasing of length. + // TODO: Should we allow it in strict mode? Wait for a use case / request. + if newCol.FieldType.GetFlen() < col.FieldType.GetFlen() { + return nil, dbterror.ErrUnsupportedModifyCollation.GenWithStack("Unsupported modify column, decreasing length of int may result in truncation and change of partition") + } + } + // Basically only allow changes of the length/decimals for the column + // Note that enum is not allowed, so elems are not checked + // TODO: support partition by ENUM + if newCol.FieldType.EvalType() != col.FieldType.EvalType() || + newCol.FieldType.GetFlag() != col.FieldType.GetFlag() || + newCol.FieldType.GetCollate() != col.FieldType.GetCollate() || + newCol.FieldType.GetCharset() != col.FieldType.GetCharset() { + return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs("can't change the partitioning column, since it would require reorganize all partitions") + } + // Generate a new PartitionInfo and validate it together with the new column definition + // Checks if all partition definition values are compatible. + // Similar to what buildRangePartitionDefinitions would do in terms of checks. 
+ + tblInfo := pt.Meta() + newTblInfo := *tblInfo + // Replace col with newCol and see if we can generate a new SHOW CREATE TABLE + // and reparse it and build new partition definitions (which will do additional + // checks columns vs partition definition values + newCols := make([]*model.ColumnInfo, 0, len(newTblInfo.Columns)) + for _, c := range newTblInfo.Columns { + if c.ID == col.ID { + newCols = append(newCols, newCol.ColumnInfo) + continue + } + newCols = append(newCols, c) + } + newTblInfo.Columns = newCols + + var buf bytes.Buffer + AppendPartitionInfo(tblInfo.GetPartitionInfo(), &buf, mysql.ModeNone) + // The parser supports ALTER TABLE ... PARTITION BY ... even if the ddl code does not yet :) + // Ignoring warnings + stmt, _, err := parser.New().ParseSQL("ALTER TABLE t " + buf.String()) + if err != nil { + // Should never happen! + return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStack("cannot parse generated PartitionInfo") + } + at, ok := stmt[0].(*ast.AlterTableStmt) + if !ok || len(at.Specs) != 1 || at.Specs[0].Partition == nil { + return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStack("cannot parse generated PartitionInfo") + } + pAst := at.Specs[0].Partition + _, err = buildPartitionDefinitionsInfo( + exprctx.CtxWithHandleTruncateErrLevel(sctx.GetExprCtx(), errctx.LevelError), + pAst.Definitions, &newTblInfo, uint64(len(newTblInfo.Partition.Definitions)), + ) + if err != nil { + return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStack("New column does not match partition definitions: %s", err.Error()) + } + } + } + + // We don't support modifying column from not_auto_increment to auto_increment. + if !mysql.HasAutoIncrementFlag(col.GetFlag()) && mysql.HasAutoIncrementFlag(newCol.GetFlag()) { + return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs("can't set auto_increment") + } + // Not support auto id with default value. + if mysql.HasAutoIncrementFlag(newCol.GetFlag()) && newCol.GetDefaultValue() != nil { + return nil, dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(newCol.Name) + } + // Disallow modifying column from auto_increment to not auto_increment if the session variable `AllowRemoveAutoInc` is false. + if !sctx.GetSessionVars().AllowRemoveAutoInc && mysql.HasAutoIncrementFlag(col.GetFlag()) && !mysql.HasAutoIncrementFlag(newCol.GetFlag()) { + return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs("can't remove auto_increment without @@tidb_allow_remove_auto_inc enabled") + } + + // We support modifying the type definitions of 'null' to 'not null' now. + var modifyColumnTp byte + if !mysql.HasNotNullFlag(col.GetFlag()) && mysql.HasNotNullFlag(newCol.GetFlag()) { + if err = checkForNullValue(ctx, sctx, true, ident.Schema, ident.Name, newCol.ColumnInfo, col.ColumnInfo); err != nil { + return nil, errors.Trace(err) + } + // `modifyColumnTp` indicates that there is a type modification. + modifyColumnTp = mysql.TypeNull + } + + if err = checkColumnWithIndexConstraint(t.Meta(), col.ColumnInfo, newCol.ColumnInfo); err != nil { + return nil, err + } + + // As same with MySQL, we don't support modifying the stored status for generated columns. + if err = checkModifyGeneratedColumn(sctx, schema.Name, t, col, newCol, specNewColumn, spec.Position); err != nil { + return nil, errors.Trace(err) + } + if errG != nil { + // According to issue https://github.com/pingcap/tidb/issues/24321, + // changing the type of a column involving generating a column is prohibited. 
+ return nil, dbterror.ErrUnsupportedOnGeneratedColumn.GenWithStackByArgs(errG.Error()) + } + + if t.Meta().TTLInfo != nil { + // the column referenced by TTL should be a time type + if t.Meta().TTLInfo.ColumnName.L == originalColName.L && !types.IsTypeTime(newCol.ColumnInfo.FieldType.GetType()) { + return nil, errors.Trace(dbterror.ErrUnsupportedColumnInTTLConfig.GenWithStackByArgs(newCol.ColumnInfo.Name.O)) + } + } + + var newAutoRandBits uint64 + if newAutoRandBits, err = checkAutoRandom(t.Meta(), col, specNewColumn); err != nil { + return nil, errors.Trace(err) + } + + txn, err := sctx.Txn(true) + if err != nil { + return nil, errors.Trace(err) + } + bdrRole, err := meta.NewMeta(txn).GetBDRRole() + if err != nil { + return nil, errors.Trace(err) + } + if bdrRole == string(ast.BDRRolePrimary) && + deniedByBDRWhenModifyColumn(newCol.FieldType, col.FieldType, specNewColumn.Options) { + return nil, dbterror.ErrBDRRestrictedDDL.FastGenByArgs(bdrRole) + } + + job := &model.Job{ + SchemaID: schema.ID, + TableID: t.Meta().ID, + SchemaName: schema.Name.L, + TableName: t.Meta().Name.L, + Type: model.ActionModifyColumn, + BinlogInfo: &model.HistoryInfo{}, + ReorgMeta: NewDDLReorgMeta(sctx), + CtxVars: []any{needChangeColData}, + Args: []any{&newCol.ColumnInfo, originalColName, spec.Position, modifyColumnTp, newAutoRandBits}, + CDCWriteSource: sctx.GetSessionVars().CDCWriteSource, + SQLMode: sctx.GetSessionVars().SQLMode, + } + return job, nil +} + +func needChangeColumnData(oldCol, newCol *model.ColumnInfo) bool { + toUnsigned := mysql.HasUnsignedFlag(newCol.GetFlag()) + originUnsigned := mysql.HasUnsignedFlag(oldCol.GetFlag()) + needTruncationOrToggleSign := func() bool { + return (newCol.GetFlen() > 0 && (newCol.GetFlen() < oldCol.GetFlen() || newCol.GetDecimal() < oldCol.GetDecimal())) || + (toUnsigned != originUnsigned) + } + // Ignore the potential max display length represented by integer's flen, use default flen instead. + defaultOldColFlen, _ := mysql.GetDefaultFieldLengthAndDecimal(oldCol.GetType()) + defaultNewColFlen, _ := mysql.GetDefaultFieldLengthAndDecimal(newCol.GetType()) + needTruncationOrToggleSignForInteger := func() bool { + return (defaultNewColFlen > 0 && defaultNewColFlen < defaultOldColFlen) || (toUnsigned != originUnsigned) + } + + // Deal with the same type. + if oldCol.GetType() == newCol.GetType() { + switch oldCol.GetType() { + case mysql.TypeNewDecimal: + // Since type decimal will encode the precision, frac, negative(signed) and wordBuf into storage together, there is no short + // cut to eliminate data reorg change for column type change between decimal. + return oldCol.GetFlen() != newCol.GetFlen() || oldCol.GetDecimal() != newCol.GetDecimal() || toUnsigned != originUnsigned + case mysql.TypeEnum, mysql.TypeSet: + return IsElemsChangedToModifyColumn(oldCol.GetElems(), newCol.GetElems()) + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: + return toUnsigned != originUnsigned + case mysql.TypeString: + // Due to the behavior of padding \x00 at binary type, always change column data when binary length changed + if types.IsBinaryStr(&oldCol.FieldType) { + return newCol.GetFlen() != oldCol.GetFlen() + } + } + + return needTruncationOrToggleSign() + } + + if ConvertBetweenCharAndVarchar(oldCol.GetType(), newCol.GetType()) { + return true + } + + // Deal with the different type. 
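+ // Only string-family to string-family and integer to integer conversions below may still skip
+ // the data reorg (subject to the truncation/sign checks above); any other cross-type change
+ // falls through to `return true`.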
+ switch oldCol.GetType() { + case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: + switch newCol.GetType() { + case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: + return needTruncationOrToggleSign() + } + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: + switch newCol.GetType() { + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: + return needTruncationOrToggleSignForInteger() + } + // conversion between float and double needs reorganization, see issue #31372 + } + + return true +} + +// ConvertBetweenCharAndVarchar check whether column converted between char and varchar +// TODO: it is used for plugins. so change plugin's using and remove it. +func ConvertBetweenCharAndVarchar(oldCol, newCol byte) bool { + return types.ConvertBetweenCharAndVarchar(oldCol, newCol) +} + +// IsElemsChangedToModifyColumn check elems changed +func IsElemsChangedToModifyColumn(oldElems, newElems []string) bool { + if len(newElems) < len(oldElems) { + return true + } + for index, oldElem := range oldElems { + newElem := newElems[index] + if oldElem != newElem { + return true + } + } + return false +} + +// ProcessColumnCharsetAndCollation process column charset and collation +func ProcessColumnCharsetAndCollation(sctx sessionctx.Context, col *table.Column, newCol *table.Column, meta *model.TableInfo, specNewColumn *ast.ColumnDef, schema *model.DBInfo) error { + var chs, coll string + var err error + // TODO: Remove it when all table versions are greater than or equal to TableInfoVersion1. + // If newCol's charset is empty and the table's version less than TableInfoVersion1, + // we will not modify the charset of the column. This behavior is not compatible with MySQL. + if len(newCol.FieldType.GetCharset()) == 0 && meta.Version < model.TableInfoVersion1 { + chs = col.FieldType.GetCharset() + coll = col.FieldType.GetCollate() + } else { + chs, coll, err = getCharsetAndCollateInColumnDef(sctx.GetSessionVars(), specNewColumn) + if err != nil { + return errors.Trace(err) + } + chs, coll, err = ResolveCharsetCollation(sctx.GetSessionVars(), + ast.CharsetOpt{Chs: chs, Col: coll}, + ast.CharsetOpt{Chs: meta.Charset, Col: meta.Collate}, + ast.CharsetOpt{Chs: schema.Charset, Col: schema.Collate}, + ) + chs, coll = OverwriteCollationWithBinaryFlag(sctx.GetSessionVars(), specNewColumn, chs, coll) + if err != nil { + return errors.Trace(err) + } + } + + if err = setCharsetCollationFlenDecimal(&newCol.FieldType, newCol.Name.O, chs, coll, sctx.GetSessionVars()); err != nil { + return errors.Trace(err) + } + decodeEnumSetBinaryLiteralToUTF8(&newCol.FieldType, chs) + return nil +} + +// checkColumnWithIndexConstraint is used to check the related index constraint of the modified column. +// Index has a max-prefix-length constraint. eg: a varchar(100), index idx(a), modifying column a to a varchar(4000) +// will cause index idx to break the max-prefix-length constraint. +func checkColumnWithIndexConstraint(tbInfo *model.TableInfo, originalCol, newCol *model.ColumnInfo) error { + columns := make([]*model.ColumnInfo, 0, len(tbInfo.Columns)) + columns = append(columns, tbInfo.Columns...) + // Replace old column with new column. 
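+ // The clone keeps the original column name so that the prefix-length and index checks below
+ // are evaluated against the new field type under the existing name.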
+ for i, col := range columns { + if col.Name.L != originalCol.Name.L { + continue + } + columns[i] = newCol.Clone() + columns[i].Name = originalCol.Name + break + } + + pkIndex := tables.FindPrimaryIndex(tbInfo) + + checkOneIndex := func(indexInfo *model.IndexInfo) (err error) { + var modified bool + for _, col := range indexInfo.Columns { + if col.Name.L == originalCol.Name.L { + modified = true + break + } + } + if !modified { + return + } + err = checkIndexInModifiableColumns(columns, indexInfo.Columns) + if err != nil { + return + } + err = checkIndexPrefixLength(columns, indexInfo.Columns) + return + } + + // Check primary key first. + var err error + + if pkIndex != nil { + err = checkOneIndex(pkIndex) + if err != nil { + return err + } + } + + // Check secondary indexes. + for _, indexInfo := range tbInfo.Indices { + if indexInfo.Primary { + continue + } + // the second param should always be set to true, check index length only if it was modified + // checkOneIndex needs one param only. + err = checkOneIndex(indexInfo) + if err != nil { + return err + } + } + return nil +} + +func checkIndexInModifiableColumns(columns []*model.ColumnInfo, idxColumns []*model.IndexColumn) error { + for _, ic := range idxColumns { + col := model.FindColumnInfo(columns, ic.Name.L) + if col == nil { + return dbterror.ErrKeyColumnDoesNotExits.GenWithStack("column does not exist: %s", ic.Name) + } + + prefixLength := types.UnspecifiedLength + if types.IsTypePrefixable(col.FieldType.GetType()) && col.FieldType.GetFlen() > ic.Length { + // When the index column is changed, prefix length is only valid + // if the type is still prefixable and larger than old prefix length. + prefixLength = ic.Length + } + if err := checkIndexColumn(nil, col, prefixLength); err != nil { + return err + } + } + return nil +} + +// checkModifyTypes checks if the 'origin' type can be modified to 'to' type no matter directly change +// or change by reorg. It returns error if the two types are incompatible and correlated change are not +// supported. However, even the two types can be change, if the "origin" type contains primary key, error will be returned. +func checkModifyTypes(origin *types.FieldType, to *types.FieldType, needRewriteCollationData bool) error { + canReorg, err := types.CheckModifyTypeCompatible(origin, to) + if err != nil { + if !canReorg { + return errors.Trace(dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs(err.Error())) + } + if mysql.HasPriKeyFlag(origin.GetFlag()) { + msg := "this column has primary key flag" + return dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs(msg) + } + } + + err = checkModifyCharsetAndCollation(to.GetCharset(), to.GetCollate(), origin.GetCharset(), origin.GetCollate(), needRewriteCollationData) + + if err != nil { + if to.GetCharset() == charset.CharsetGBK || origin.GetCharset() == charset.CharsetGBK { + return errors.Trace(err) + } + // column type change can handle the charset change between these two types in the process of the reorg. + if dbterror.ErrUnsupportedModifyCharset.Equal(err) && canReorg { + return nil + } + } + return errors.Trace(err) +} + +// ProcessModifyColumnOptions process column options. 
+func ProcessModifyColumnOptions(ctx sessionctx.Context, col *table.Column, options []*ast.ColumnOption) error { + var sb strings.Builder + restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | + format.RestoreSpacesAroundBinaryOperation | format.RestoreWithoutSchemaName | format.RestoreWithoutSchemaName + restoreCtx := format.NewRestoreCtx(restoreFlags, &sb) + + var hasDefaultValue, setOnUpdateNow bool + var err error + var hasNullFlag bool + for _, opt := range options { + switch opt.Tp { + case ast.ColumnOptionDefaultValue: + hasDefaultValue, err = SetDefaultValue(ctx, col, opt) + if err != nil { + return errors.Trace(err) + } + case ast.ColumnOptionComment: + err := setColumnComment(ctx, col, opt) + if err != nil { + return errors.Trace(err) + } + case ast.ColumnOptionNotNull: + col.AddFlag(mysql.NotNullFlag) + case ast.ColumnOptionNull: + hasNullFlag = true + col.DelFlag(mysql.NotNullFlag) + case ast.ColumnOptionAutoIncrement: + col.AddFlag(mysql.AutoIncrementFlag) + case ast.ColumnOptionPrimaryKey: + return errors.Trace(dbterror.ErrUnsupportedModifyColumn.GenWithStack("can't change column constraint (PRIMARY KEY)")) + case ast.ColumnOptionUniqKey: + return errors.Trace(dbterror.ErrUnsupportedModifyColumn.GenWithStack("can't change column constraint (UNIQUE KEY)")) + case ast.ColumnOptionOnUpdate: + // TODO: Support other time functions. + if !(col.GetType() == mysql.TypeTimestamp || col.GetType() == mysql.TypeDatetime) { + return dbterror.ErrInvalidOnUpdate.GenWithStackByArgs(col.Name) + } + if !expression.IsValidCurrentTimestampExpr(opt.Expr, &col.FieldType) { + return dbterror.ErrInvalidOnUpdate.GenWithStackByArgs(col.Name) + } + col.AddFlag(mysql.OnUpdateNowFlag) + setOnUpdateNow = true + case ast.ColumnOptionGenerated: + sb.Reset() + err = opt.Expr.Restore(restoreCtx) + if err != nil { + return errors.Trace(err) + } + col.GeneratedExprString = sb.String() + col.GeneratedStored = opt.Stored + col.Dependences = make(map[string]struct{}) + // Only used by checkModifyGeneratedColumn, there is no need to set a ctor for it. + col.GeneratedExpr = table.NewClonableExprNode(nil, opt.Expr) + for _, colName := range FindColumnNamesInExpr(opt.Expr) { + col.Dependences[colName.Name.L] = struct{}{} + } + case ast.ColumnOptionCollate: + col.SetCollate(opt.StrValue) + case ast.ColumnOptionReference: + return errors.Trace(dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs("can't modify with references")) + case ast.ColumnOptionFulltext: + return errors.Trace(dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs("can't modify with full text")) + case ast.ColumnOptionCheck: + return errors.Trace(dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs("can't modify with check")) + // Ignore ColumnOptionAutoRandom. It will be handled later. 
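+ // The auto_random bits themselves are validated separately by checkAutoRandom when the job is built.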
+ case ast.ColumnOptionAutoRandom: + default: + return errors.Trace(dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs(fmt.Sprintf("unknown column option type: %d", opt.Tp))) + } + } + + if err = processAndCheckDefaultValueAndColumn(ctx, col, nil, hasDefaultValue, setOnUpdateNow, hasNullFlag); err != nil { + return errors.Trace(err) + } + + return nil +} + +func checkAutoRandom(tableInfo *model.TableInfo, originCol *table.Column, specNewColumn *ast.ColumnDef) (uint64, error) { + var oldShardBits, oldRangeBits uint64 + if isClusteredPKColumn(originCol, tableInfo) { + oldShardBits = tableInfo.AutoRandomBits + oldRangeBits = tableInfo.AutoRandomRangeBits + } + newShardBits, newRangeBits, err := extractAutoRandomBitsFromColDef(specNewColumn) + if err != nil { + return 0, errors.Trace(err) + } + switch { + case oldShardBits == newShardBits: + case oldShardBits < newShardBits: + addingAutoRandom := oldShardBits == 0 + if addingAutoRandom { + convFromAutoInc := mysql.HasAutoIncrementFlag(originCol.GetFlag()) && originCol.IsPKHandleColumn(tableInfo) + if !convFromAutoInc { + return 0, dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomAlterChangeFromAutoInc) + } + } + if autoid.AutoRandomShardBitsMax < newShardBits { + errMsg := fmt.Sprintf(autoid.AutoRandomOverflowErrMsg, + autoid.AutoRandomShardBitsMax, newShardBits, specNewColumn.Name.Name.O) + return 0, dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(errMsg) + } + // increasing auto_random shard bits is allowed. + case oldShardBits > newShardBits: + if newShardBits == 0 { + return 0, dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomAlterErrMsg) + } + return 0, dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomDecreaseBitErrMsg) + } + + modifyingAutoRandCol := oldShardBits > 0 || newShardBits > 0 + if modifyingAutoRandCol { + // Disallow changing the column field type. + if originCol.GetType() != specNewColumn.Tp.GetType() { + return 0, dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomModifyColTypeErrMsg) + } + if originCol.GetType() != mysql.TypeLonglong { + return 0, dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(fmt.Sprintf(autoid.AutoRandomOnNonBigIntColumn, types.TypeStr(originCol.GetType()))) + } + // Disallow changing from auto_random to auto_increment column. + if containsColumnOption(specNewColumn, ast.ColumnOptionAutoIncrement) { + return 0, dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomIncompatibleWithAutoIncErrMsg) + } + // Disallow specifying a default value on auto_random column. 
+ if containsColumnOption(specNewColumn, ast.ColumnOptionDefaultValue) { + return 0, dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomIncompatibleWithDefaultValueErrMsg) + } + } + if rangeBitsIsChanged(oldRangeBits, newRangeBits) { + return 0, dbterror.ErrInvalidAutoRandom.FastGenByArgs(autoid.AutoRandomUnsupportedAlterRangeBits) + } + return newShardBits, nil +} + +func isClusteredPKColumn(col *table.Column, tblInfo *model.TableInfo) bool { + switch { + case tblInfo.PKIsHandle: + return mysql.HasPriKeyFlag(col.GetFlag()) + case tblInfo.IsCommonHandle: + pk := tables.FindPrimaryIndex(tblInfo) + for _, c := range pk.Columns { + if c.Name.L == col.Name.L { + return true + } + } + return false + default: + return false + } +} + +func rangeBitsIsChanged(oldBits, newBits uint64) bool { + if oldBits == 0 { + oldBits = autoid.AutoRandomRangeBitsDefault + } + if newBits == 0 { + newBits = autoid.AutoRandomRangeBitsDefault + } + return oldBits != newBits +} diff --git a/pkg/ddl/partition.go b/pkg/ddl/partition.go index a0ee830c0e870..dcc59bf5cd8ae 100644 --- a/pkg/ddl/partition.go +++ b/pkg/ddl/partition.go @@ -179,10 +179,10 @@ func (w *worker) onAddTablePartition(d *ddlCtx, t *meta.Meta, job *model.Job) (v job.SchemaState = model.StateReplicaOnly case model.StateReplicaOnly: // replica only -> public - failpoint.Inject("sleepBeforeReplicaOnly", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("sleepBeforeReplicaOnly")); _err_ == nil { sleepSecond := val.(int) time.Sleep(time.Duration(sleepSecond) * time.Second) - }) + } // Here need do some tiflash replica complement check. // TODO: If a table is with no TiFlashReplica or it is not available, the replica-only state can be eliminated. if tblInfo.TiFlashReplica != nil && tblInfo.TiFlashReplica.Available { @@ -410,16 +410,16 @@ func checkAddPartitionValue(meta *model.TableInfo, part *model.PartitionInfo) er } func checkPartitionReplica(replicaCount uint64, addingDefinitions []model.PartitionDefinition, d *ddlCtx) (needWait bool, err error) { - failpoint.Inject("mockWaitTiFlashReplica", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockWaitTiFlashReplica")); _err_ == nil { if val.(bool) { - failpoint.Return(true, nil) + return true, nil } - }) - failpoint.Inject("mockWaitTiFlashReplicaOK", func(val failpoint.Value) { + } + if val, _err_ := failpoint.Eval(_curpkg_("mockWaitTiFlashReplicaOK")); _err_ == nil { if val.(bool) { - failpoint.Return(false, nil) + return false, nil } - }) + } ctx := context.Background() pdCli := d.store.(tikv.Storage).GetRegionCache().PDClient() @@ -451,9 +451,9 @@ func checkPartitionReplica(replicaCount uint64, addingDefinitions []model.Partit return needWait, errors.Trace(err) } tiflashPeerAtLeastOne := checkTiFlashPeerStoreAtLeastOne(stores, regionState.Meta.Peers) - failpoint.Inject("ForceTiflashNotAvailable", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("ForceTiflashNotAvailable")); _err_ == nil { tiflashPeerAtLeastOne = v.(bool) - }) + } // It's unnecessary to wait all tiflash peer to be replicated. // Here only make sure that tiflash peer count > 0 (at least one). if tiflashPeerAtLeastOne { @@ -2516,9 +2516,9 @@ func clearTruncatePartitionTiflashStatus(tblInfo *model.TableInfo, newPartitions // Clear the tiflash replica available status. 
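// Note on the rewrite applied throughout this file (generic sketch for
// illustration; "name", val and the body are placeholders):
//
//	failpoint.Inject("name", func(val failpoint.Value) {
//		... // may call failpoint.Return(x, err)
//	})
//
// is replaced by the statically bound form
//
//	if val, _err_ := failpoint.Eval(_curpkg_("name")); _err_ == nil {
//		... // failpoint.Return(x, err) becomes a plain `return x, err`
//	}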
if tblInfo.TiFlashReplica != nil { e := infosync.ConfigureTiFlashPDForPartitions(true, &newPartitions, tblInfo.TiFlashReplica.Count, &tblInfo.TiFlashReplica.LocationLabels, tblInfo.ID) - failpoint.Inject("FailTiFlashTruncatePartition", func() { + if _, _err_ := failpoint.Eval(_curpkg_("FailTiFlashTruncatePartition")); _err_ == nil { e = errors.New("enforced error") - }) + } if e != nil { logutil.DDLLogger().Error("ConfigureTiFlashPDForPartitions fails", zap.Error(e)) return e @@ -2778,11 +2778,11 @@ func (w *worker) onExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Jo return ver, errors.Trace(err) } - failpoint.Inject("exchangePartitionErr", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("exchangePartitionErr")); _err_ == nil { if val.(bool) { - failpoint.Return(ver, errors.New("occur an error after updating partition id")) + return ver, errors.New("occur an error after updating partition id") } - }) + } // Set both tables to the maximum auto IDs between normal table and partitioned table. // TODO: Fix the issue of big transactions during EXCHANGE PARTITION with AutoID. @@ -2801,20 +2801,20 @@ func (w *worker) onExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Jo return ver, errors.Trace(err) } - failpoint.Inject("exchangePartitionAutoID", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("exchangePartitionAutoID")); _err_ == nil { if val.(bool) { seCtx, err := w.sessPool.Get() defer w.sessPool.Put(seCtx) if err != nil { - failpoint.Return(ver, err) + return ver, err } se := sess.NewSession(seCtx) _, err = se.Execute(context.Background(), "insert ignore into test.pt values (40000000)", "exchange_partition_test") if err != nil { - failpoint.Return(ver, err) + return ver, err } } - }) + } // the follow code is a swap function for rules of two partitions // though partitions has exchanged their ID, swap still take effect @@ -3246,11 +3246,11 @@ func (w *worker) onReorganizePartition(d *ddlCtx, t *meta.Meta, job *model.Job) } } firstPartIdx, lastPartIdx, idMap, err2 := getReplacedPartitionIDs(partNames, tblInfo.Partition) - failpoint.Inject("reorgPartWriteReorgReplacedPartIDsFail", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("reorgPartWriteReorgReplacedPartIDsFail")); _err_ == nil { if val.(bool) { err2 = errors.New("Injected error by reorgPartWriteReorgReplacedPartIDsFail") } - }) + } if err2 != nil { return ver, err2 } @@ -3354,11 +3354,11 @@ func (w *worker) onReorganizePartition(d *ddlCtx, t *meta.Meta, job *model.Job) } job.CtxVars = []any{physicalTableIDs, newIDs} ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - failpoint.Inject("reorgPartWriteReorgSchemaVersionUpdateFail", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("reorgPartWriteReorgSchemaVersionUpdateFail")); _err_ == nil { if val.(bool) { err = errors.New("Injected error by reorgPartWriteReorgSchemaVersionUpdateFail") } - }) + } if err != nil { return ver, errors.Trace(err) } @@ -3717,12 +3717,12 @@ func (w *worker) reorgPartitionDataAndIndex(t table.Table, reorgInfo *reorgInfo) } } - failpoint.Inject("reorgPartitionAfterDataCopy", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("reorgPartitionAfterDataCopy")); _err_ == nil { //nolint:forcetypeassert if val.(bool) { panic("panic test in reorgPartitionAfterDataCopy") } - }) + } if !bytes.Equal(reorgInfo.currElement.TypeKey, meta.IndexElementKey) { // row data has been copied, now proceed with creating the indexes @@ -4800,10 
+4800,10 @@ func checkPartitionByHash(ctx sessionctx.Context, tbInfo *model.TableInfo) error // checkPartitionByRange checks validity of a "BY RANGE" partition. func checkPartitionByRange(ctx sessionctx.Context, tbInfo *model.TableInfo) error { - failpoint.Inject("CheckPartitionByRangeErr", func() { + if _, _err_ := failpoint.Eval(_curpkg_("CheckPartitionByRangeErr")); _err_ == nil { ctx.GetSessionVars().SQLKiller.SendKillSignal(sqlkiller.QueryMemoryExceeded) panic(ctx.GetSessionVars().SQLKiller.HandleSignal()) - }) + } pi := tbInfo.Partition if len(pi.Columns) == 0 { diff --git a/pkg/ddl/partition.go__failpoint_stash__ b/pkg/ddl/partition.go__failpoint_stash__ new file mode 100644 index 0000000000000..a0ee830c0e870 --- /dev/null +++ b/pkg/ddl/partition.go__failpoint_stash__ @@ -0,0 +1,4922 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl + +import ( + "bytes" + "context" + "encoding/hex" + "fmt" + "math" + "strconv" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/tidb/pkg/ddl/label" + "github.com/pingcap/tidb/pkg/ddl/logutil" + "github.com/pingcap/tidb/pkg/ddl/placement" + sess "github.com/pingcap/tidb/pkg/ddl/session" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/format" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/opcode" + "github.com/pingcap/tidb/pkg/parser/terror" + field_types "github.com/pingcap/tidb/pkg/parser/types" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + statsutil "github.com/pingcap/tidb/pkg/statistics/handle/util" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + driver "github.com/pingcap/tidb/pkg/types/parser_driver" + tidbutil "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/collate" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/hack" + "github.com/pingcap/tidb/pkg/util/mathutil" + decoder "github.com/pingcap/tidb/pkg/util/rowDecoder" + "github.com/pingcap/tidb/pkg/util/slice" + "github.com/pingcap/tidb/pkg/util/sqlkiller" + "github.com/pingcap/tidb/pkg/util/stringutil" + "github.com/tikv/client-go/v2/tikv" + kvutil "github.com/tikv/client-go/v2/util" + pd "github.com/tikv/pd/client" + "go.uber.org/zap" +) + +const ( + partitionMaxValue = "MAXVALUE" +) + +func checkAddPartition(t *meta.Meta, job *model.Job) 
(*model.TableInfo, *model.PartitionInfo, []model.PartitionDefinition, error) { + schemaID := job.SchemaID + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) + if err != nil { + return nil, nil, nil, errors.Trace(err) + } + partInfo := &model.PartitionInfo{} + err = job.DecodeArgs(&partInfo) + if err != nil { + job.State = model.JobStateCancelled + return nil, nil, nil, errors.Trace(err) + } + if len(tblInfo.Partition.AddingDefinitions) > 0 { + return tblInfo, partInfo, tblInfo.Partition.AddingDefinitions, nil + } + return tblInfo, partInfo, []model.PartitionDefinition{}, nil +} + +// TODO: Move this into reorganize partition! +func (w *worker) onAddTablePartition(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + // Handle the rolling back job + if job.IsRollingback() { + ver, err := w.onDropTablePartition(d, t, job) + if err != nil { + return ver, errors.Trace(err) + } + return ver, nil + } + + // notice: addingDefinitions is empty when job is in state model.StateNone + tblInfo, partInfo, addingDefinitions, err := checkAddPartition(t, job) + if err != nil { + return ver, err + } + + // In order to skip maintaining the state check in partitionDefinition, TiDB use addingDefinition instead of state field. + // So here using `job.SchemaState` to judge what the stage of this job is. + switch job.SchemaState { + case model.StateNone: + // job.SchemaState == model.StateNone means the job is in the initial state of add partition. + // Here should use partInfo from job directly and do some check action. + err = checkAddPartitionTooManyPartitions(uint64(len(tblInfo.Partition.Definitions) + len(partInfo.Definitions))) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + err = checkAddPartitionValue(tblInfo, partInfo) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + err = checkAddPartitionNameUnique(tblInfo, partInfo) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + // move the adding definition into tableInfo. + updateAddingPartitionInfo(partInfo, tblInfo) + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + + // modify placement settings + for _, def := range tblInfo.Partition.AddingDefinitions { + if _, err = checkPlacementPolicyRefValidAndCanNonValidJob(t, job, def.PlacementPolicyRef); err != nil { + return ver, errors.Trace(err) + } + } + + if tblInfo.TiFlashReplica != nil { + // Must set placement rule, and make sure it succeeds. 
+ if err := infosync.ConfigureTiFlashPDForPartitions(true, &tblInfo.Partition.AddingDefinitions, tblInfo.TiFlashReplica.Count, &tblInfo.TiFlashReplica.LocationLabels, tblInfo.ID); err != nil { + logutil.DDLLogger().Error("ConfigureTiFlashPDForPartitions fails", zap.Error(err)) + return ver, errors.Trace(err) + } + } + + bundles, err := alterTablePartitionBundles(t, tblInfo, tblInfo.Partition.AddingDefinitions) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + if err = infosync.PutRuleBundlesWithDefaultRetry(context.TODO(), bundles); err != nil { + job.State = model.JobStateCancelled + return ver, errors.Wrapf(err, "failed to notify PD the placement rules") + } + + ids := getIDs([]*model.TableInfo{tblInfo}) + for _, p := range tblInfo.Partition.AddingDefinitions { + ids = append(ids, p.ID) + } + if _, err := alterTableLabelRule(job.SchemaName, tblInfo, ids); err != nil { + job.State = model.JobStateCancelled + return ver, err + } + + // none -> replica only + job.SchemaState = model.StateReplicaOnly + case model.StateReplicaOnly: + // replica only -> public + failpoint.Inject("sleepBeforeReplicaOnly", func(val failpoint.Value) { + sleepSecond := val.(int) + time.Sleep(time.Duration(sleepSecond) * time.Second) + }) + // Here need do some tiflash replica complement check. + // TODO: If a table is with no TiFlashReplica or it is not available, the replica-only state can be eliminated. + if tblInfo.TiFlashReplica != nil && tblInfo.TiFlashReplica.Available { + // For available state, the new added partition should wait it's replica to + // be finished. Otherwise the query to this partition will be blocked. + needRetry, err := checkPartitionReplica(tblInfo.TiFlashReplica.Count, addingDefinitions, d) + if err != nil { + return convertAddTablePartitionJob2RollbackJob(d, t, job, err, tblInfo) + } + if needRetry { + // The new added partition hasn't been replicated. + // Do nothing to the job this time, wait next worker round. + time.Sleep(tiflashCheckTiDBHTTPAPIHalfInterval) + // Set the error here which will lead this job exit when it's retry times beyond the limitation. + return ver, errors.Errorf("[ddl] add partition wait for tiflash replica to complete") + } + } + + // When TiFlash Replica is ready, we must move them into `AvailablePartitionIDs`. + if tblInfo.TiFlashReplica != nil && tblInfo.TiFlashReplica.Available { + for _, d := range partInfo.Definitions { + tblInfo.TiFlashReplica.AvailablePartitionIDs = append(tblInfo.TiFlashReplica.AvailablePartitionIDs, d.ID) + err = infosync.UpdateTiFlashProgressCache(d.ID, 1) + if err != nil { + // just print log, progress will be updated in `refreshTiFlashTicker` + logutil.DDLLogger().Error("update tiflash sync progress cache failed", + zap.Error(err), + zap.Int64("tableID", tblInfo.ID), + zap.Int64("partitionID", d.ID), + ) + } + } + } + // For normal and replica finished table, move the `addingDefinitions` into `Definitions`. + updatePartitionInfo(tblInfo) + + preSplitAndScatter(w.sess.Context, d.store, tblInfo, addingDefinitions) + + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + + // Finish this job. 
+ job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + addPartitionEvent := statsutil.NewAddPartitionEvent( + job.SchemaID, + tblInfo, + partInfo, + ) + asyncNotifyEvent(d, addPartitionEvent) + default: + err = dbterror.ErrInvalidDDLState.GenWithStackByArgs("partition", job.SchemaState) + } + + return ver, errors.Trace(err) +} + +// alterTableLabelRule updates Label Rules if they exists +// returns true if changed. +func alterTableLabelRule(schemaName string, meta *model.TableInfo, ids []int64) (bool, error) { + tableRuleID := fmt.Sprintf(label.TableIDFormat, label.IDPrefix, schemaName, meta.Name.L) + oldRule, err := infosync.GetLabelRules(context.TODO(), []string{tableRuleID}) + if err != nil { + return false, errors.Trace(err) + } + if len(oldRule) == 0 { + return false, nil + } + + r, ok := oldRule[tableRuleID] + if ok { + rule := r.Reset(schemaName, meta.Name.L, "", ids...) + err = infosync.PutLabelRule(context.TODO(), rule) + if err != nil { + return false, errors.Wrapf(err, "failed to notify PD label rule") + } + return true, nil + } + return false, nil +} + +func alterTablePartitionBundles(t *meta.Meta, tblInfo *model.TableInfo, addingDefinitions []model.PartitionDefinition) ([]*placement.Bundle, error) { + var bundles []*placement.Bundle + + // tblInfo do not include added partitions, so we should add them first + tblInfo = tblInfo.Clone() + p := *tblInfo.Partition + p.Definitions = append([]model.PartitionDefinition{}, p.Definitions...) + p.Definitions = append(tblInfo.Partition.Definitions, addingDefinitions...) + tblInfo.Partition = &p + + // bundle for table should be recomputed because it includes some default configs for partitions + tblBundle, err := placement.NewTableBundle(t, tblInfo) + if err != nil { + return nil, errors.Trace(err) + } + + if tblBundle != nil { + bundles = append(bundles, tblBundle) + } + + partitionBundles, err := placement.NewPartitionListBundles(t, addingDefinitions) + if err != nil { + return nil, errors.Trace(err) + } + + bundles = append(bundles, partitionBundles...) + return bundles, nil +} + +// When drop/truncate a partition, we should still keep the dropped partition's placement settings to avoid unnecessary region schedules. +// When a partition is not configured with a placement policy directly, its rule is in the table's placement group which will be deleted after +// partition truncated/dropped. So it is necessary to create a standalone placement group with partition id after it. +func droppedPartitionBundles(t *meta.Meta, tblInfo *model.TableInfo, dropPartitions []model.PartitionDefinition) ([]*placement.Bundle, error) { + partitions := make([]model.PartitionDefinition, 0, len(dropPartitions)) + for _, def := range dropPartitions { + def = def.Clone() + if def.PlacementPolicyRef == nil { + def.PlacementPolicyRef = tblInfo.PlacementPolicyRef + } + + if def.PlacementPolicyRef != nil { + partitions = append(partitions, def) + } + } + + return placement.NewPartitionListBundles(t, partitions) +} + +// updatePartitionInfo merge `addingDefinitions` into `Definitions` in the tableInfo. +func updatePartitionInfo(tblInfo *model.TableInfo) { + parInfo := &model.PartitionInfo{} + oldDefs, newDefs := tblInfo.Partition.Definitions, tblInfo.Partition.AddingDefinitions + parInfo.Definitions = make([]model.PartitionDefinition, 0, len(newDefs)+len(oldDefs)) + parInfo.Definitions = append(parInfo.Definitions, oldDefs...) + parInfo.Definitions = append(parInfo.Definitions, newDefs...) 
+ tblInfo.Partition.Definitions = parInfo.Definitions + tblInfo.Partition.AddingDefinitions = nil +} + +// updateAddingPartitionInfo write adding partitions into `addingDefinitions` field in the tableInfo. +func updateAddingPartitionInfo(partitionInfo *model.PartitionInfo, tblInfo *model.TableInfo) { + newDefs := partitionInfo.Definitions + tblInfo.Partition.AddingDefinitions = make([]model.PartitionDefinition, 0, len(newDefs)) + tblInfo.Partition.AddingDefinitions = append(tblInfo.Partition.AddingDefinitions, newDefs...) +} + +// rollbackAddingPartitionInfo remove the `addingDefinitions` in the tableInfo. +func rollbackAddingPartitionInfo(tblInfo *model.TableInfo) ([]int64, []string, []*placement.Bundle) { + physicalTableIDs := make([]int64, 0, len(tblInfo.Partition.AddingDefinitions)) + partNames := make([]string, 0, len(tblInfo.Partition.AddingDefinitions)) + rollbackBundles := make([]*placement.Bundle, 0, len(tblInfo.Partition.AddingDefinitions)) + for _, one := range tblInfo.Partition.AddingDefinitions { + physicalTableIDs = append(physicalTableIDs, one.ID) + partNames = append(partNames, one.Name.L) + if one.PlacementPolicyRef != nil { + rollbackBundles = append(rollbackBundles, placement.NewBundle(one.ID)) + } + } + tblInfo.Partition.AddingDefinitions = nil + return physicalTableIDs, partNames, rollbackBundles +} + +// Check if current table already contains DEFAULT list partition +func checkAddListPartitions(tblInfo *model.TableInfo) error { + for i := range tblInfo.Partition.Definitions { + for j := range tblInfo.Partition.Definitions[i].InValues { + for _, val := range tblInfo.Partition.Definitions[i].InValues[j] { + if val == "DEFAULT" { // should already be normalized + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("ADD List partition, already contains DEFAULT partition. Please use REORGANIZE PARTITION instead") + } + } + } + } + return nil +} + +// checkAddPartitionValue check add Partition Values, +// For Range: values less than value must be strictly increasing for each partition. +// For List: if a Default partition exists, +// +// no ADD partition can be allowed +// (needs reorganize partition instead). 
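// Illustrative example of the RANGE rule (hypothetical partitions): given
//
//	PARTITION p0 VALUES LESS THAN (100),
//	PARTITION p1 VALUES LESS THAN (200)
//
// adding `PARTITION p2 VALUES LESS THAN (300)` is accepted, adding
// `PARTITION p2 VALUES LESS THAN (150)` fails with ErrRangeNotIncreasing,
// and adding any partition after an existing MAXVALUE partition fails with
// ErrPartitionMaxvalue.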
+func checkAddPartitionValue(meta *model.TableInfo, part *model.PartitionInfo) error { + switch meta.Partition.Type { + case model.PartitionTypeRange: + if len(meta.Partition.Columns) == 0 { + newDefs, oldDefs := part.Definitions, meta.Partition.Definitions + rangeValue := oldDefs[len(oldDefs)-1].LessThan[0] + if strings.EqualFold(rangeValue, "MAXVALUE") { + return errors.Trace(dbterror.ErrPartitionMaxvalue) + } + + currentRangeValue, err := strconv.Atoi(rangeValue) + if err != nil { + return errors.Trace(err) + } + + for i := 0; i < len(newDefs); i++ { + ifMaxvalue := strings.EqualFold(newDefs[i].LessThan[0], "MAXVALUE") + if ifMaxvalue && i == len(newDefs)-1 { + return nil + } else if ifMaxvalue && i != len(newDefs)-1 { + return errors.Trace(dbterror.ErrPartitionMaxvalue) + } + + nextRangeValue, err := strconv.Atoi(newDefs[i].LessThan[0]) + if err != nil { + return errors.Trace(err) + } + if nextRangeValue <= currentRangeValue { + return errors.Trace(dbterror.ErrRangeNotIncreasing) + } + currentRangeValue = nextRangeValue + } + } + case model.PartitionTypeList: + err := checkAddListPartitions(meta) + if err != nil { + return err + } + } + return nil +} + +func checkPartitionReplica(replicaCount uint64, addingDefinitions []model.PartitionDefinition, d *ddlCtx) (needWait bool, err error) { + failpoint.Inject("mockWaitTiFlashReplica", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(true, nil) + } + }) + failpoint.Inject("mockWaitTiFlashReplicaOK", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(false, nil) + } + }) + + ctx := context.Background() + pdCli := d.store.(tikv.Storage).GetRegionCache().PDClient() + stores, err := pdCli.GetAllStores(ctx) + if err != nil { + return needWait, errors.Trace(err) + } + // Check whether stores have `count` tiflash engines. + tiFlashStoreCount := uint64(0) + for _, store := range stores { + if storeHasEngineTiFlashLabel(store) { + tiFlashStoreCount++ + } + } + if replicaCount > tiFlashStoreCount { + return false, errors.Errorf("[ddl] the tiflash replica count: %d should be less than the total tiflash server count: %d", replicaCount, tiFlashStoreCount) + } + for _, pDef := range addingDefinitions { + startKey, endKey := tablecodec.GetTableHandleKeyRange(pDef.ID) + regions, err := pdCli.BatchScanRegions(ctx, []pd.KeyRange{{StartKey: startKey, EndKey: endKey}}, -1) + if err != nil { + return needWait, errors.Trace(err) + } + // For every region in the partition, if it has some corresponding peers and + // no pending peers, that means the replication has completed. + for _, region := range regions { + regionState, err := pdCli.GetRegionByID(ctx, region.Meta.Id) + if err != nil { + return needWait, errors.Trace(err) + } + tiflashPeerAtLeastOne := checkTiFlashPeerStoreAtLeastOne(stores, regionState.Meta.Peers) + failpoint.Inject("ForceTiflashNotAvailable", func(v failpoint.Value) { + tiflashPeerAtLeastOne = v.(bool) + }) + // It's unnecessary to wait all tiflash peer to be replicated. + // Here only make sure that tiflash peer count > 0 (at least one). 
+ if tiflashPeerAtLeastOne { + continue + } + needWait = true + logutil.DDLLogger().Info("partition replicas check failed in replica-only DDL state", zap.Int64("pID", pDef.ID), zap.Uint64("wait region ID", region.Meta.Id), zap.Bool("tiflash peer at least one", tiflashPeerAtLeastOne), zap.Time("check time", time.Now())) + return needWait, nil + } + } + logutil.DDLLogger().Info("partition replicas check ok in replica-only DDL state") + return needWait, nil +} + +func checkTiFlashPeerStoreAtLeastOne(stores []*metapb.Store, peers []*metapb.Peer) bool { + for _, peer := range peers { + for _, store := range stores { + if peer.StoreId == store.Id && storeHasEngineTiFlashLabel(store) { + return true + } + } + } + return false +} + +func storeHasEngineTiFlashLabel(store *metapb.Store) bool { + for _, label := range store.Labels { + if label.Key == placement.EngineLabelKey && label.Value == placement.EngineLabelTiFlash { + return true + } + } + return false +} + +func checkListPartitions(defs []*ast.PartitionDefinition) error { + for _, def := range defs { + _, ok := def.Clause.(*ast.PartitionDefinitionClauseIn) + if !ok { + switch def.Clause.(type) { + case *ast.PartitionDefinitionClauseLessThan: + return ast.ErrPartitionWrongValues.GenWithStackByArgs("RANGE", "LESS THAN") + case *ast.PartitionDefinitionClauseNone: + return ast.ErrPartitionRequiresValues.GenWithStackByArgs("LIST", "IN") + default: + return dbterror.ErrUnsupportedCreatePartition.GenWithStack("Only VALUES IN () is supported for LIST partitioning") + } + } + } + return nil +} + +// buildTablePartitionInfo builds partition info and checks for some errors. +func buildTablePartitionInfo(ctx sessionctx.Context, s *ast.PartitionOptions, tbInfo *model.TableInfo) error { + if s == nil { + return nil + } + + if strings.EqualFold(ctx.GetSessionVars().EnableTablePartition, "OFF") { + ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrTablePartitionDisabled) + return nil + } + + var enable bool + switch s.Tp { + case model.PartitionTypeRange: + enable = true + case model.PartitionTypeList: + // Partition by list is enabled only when tidb_enable_list_partition is 'ON'. + enable = ctx.GetSessionVars().EnableListTablePartition + if enable { + err := checkListPartitions(s.Definitions) + if err != nil { + return err + } + } + case model.PartitionTypeHash, model.PartitionTypeKey: + // Partition by hash and key is enabled by default. + if s.Sub != nil { + // Subpartitioning only allowed with Range or List + return ast.ErrSubpartition + } + // Note that linear hash is simply ignored, and creates non-linear hash/key. 
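// For illustration: `CREATE TABLE t (id INT) PARTITION BY LINEAR HASH (id) PARTITIONS 4`
// is accepted, but the table is created as a plain (non-linear) HASH partitioned
// table and the warning appended below is the only trace of the LINEAR keyword.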
+ if s.Linear { + ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedCreatePartition.FastGen(fmt.Sprintf("LINEAR %s is not supported, using non-linear %s instead", s.Tp.String(), s.Tp.String()))) + } + if s.Tp == model.PartitionTypeHash || len(s.ColumnNames) != 0 { + enable = true + } + if s.Tp == model.PartitionTypeKey && len(s.ColumnNames) == 0 { + enable = true + } + } + + if !enable { + ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedCreatePartition.FastGen(fmt.Sprintf("Unsupported partition type %v, treat as normal table", s.Tp))) + return nil + } + if s.Sub != nil { + ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedCreatePartition.FastGen(fmt.Sprintf("Unsupported subpartitioning, only using %v partitioning", s.Tp))) + } + + pi := &model.PartitionInfo{ + Type: s.Tp, + Enable: enable, + Num: s.Num, + } + tbInfo.Partition = pi + if s.Expr != nil { + if err := checkPartitionFuncValid(ctx.GetExprCtx(), tbInfo, s.Expr); err != nil { + return errors.Trace(err) + } + buf := new(bytes.Buffer) + restoreFlags := format.DefaultRestoreFlags | format.RestoreBracketAroundBinaryOperation | + format.RestoreWithoutSchemaName | format.RestoreWithoutTableName + restoreCtx := format.NewRestoreCtx(restoreFlags, buf) + if err := s.Expr.Restore(restoreCtx); err != nil { + return err + } + pi.Expr = buf.String() + } else if s.ColumnNames != nil { + pi.Columns = make([]model.CIStr, 0, len(s.ColumnNames)) + for _, cn := range s.ColumnNames { + pi.Columns = append(pi.Columns, cn.Name) + } + if pi.Type == model.PartitionTypeKey && len(s.ColumnNames) == 0 { + if tbInfo.PKIsHandle { + pi.Columns = append(pi.Columns, tbInfo.GetPkName()) + pi.IsEmptyColumns = true + } else if key := tbInfo.GetPrimaryKey(); key != nil { + for _, col := range key.Columns { + pi.Columns = append(pi.Columns, col.Name) + } + pi.IsEmptyColumns = true + } + } + if err := checkColumnsPartitionType(tbInfo); err != nil { + return err + } + } + + exprCtx := ctx.GetExprCtx() + err := generatePartitionDefinitionsFromInterval(exprCtx, s, tbInfo) + if err != nil { + return errors.Trace(err) + } + + defs, err := buildPartitionDefinitionsInfo(exprCtx, s.Definitions, tbInfo, s.Num) + if err != nil { + return errors.Trace(err) + } + + tbInfo.Partition.Definitions = defs + + if s.Interval != nil { + // Syntactic sugar for INTERVAL partitioning + // Generate the resulting CREATE TABLE as the query string + query, ok := ctx.Value(sessionctx.QueryString).(string) + if ok { + sqlMode := ctx.GetSessionVars().SQLMode + var buf bytes.Buffer + AppendPartitionDefs(tbInfo.Partition, &buf, sqlMode) + + syntacticSugar := s.Interval.OriginalText() + syntacticStart := s.Interval.OriginTextPosition() + newQuery := query[:syntacticStart] + "(" + buf.String() + ")" + query[syntacticStart+len(syntacticSugar):] + ctx.SetValue(sessionctx.QueryString, newQuery) + } + } + + partCols, err := getPartitionColSlices(exprCtx, tbInfo, s) + if err != nil { + return errors.Trace(err) + } + + for _, index := range tbInfo.Indices { + if index.Unique && !checkUniqueKeyIncludePartKey(partCols, index.Columns) { + index.Global = ctx.GetSessionVars().EnableGlobalIndex + } + } + return nil +} + +func getPartitionColSlices(sctx expression.BuildContext, tblInfo *model.TableInfo, s *ast.PartitionOptions) (partCols stringSlice, err error) { + if s.Expr != nil { + extractCols := newPartitionExprChecker(sctx, tblInfo) + s.Expr.Accept(extractCols) + partColumns, err := extractCols.columns, extractCols.err + if err != nil { + return nil, err + } + 
return columnInfoSlice(partColumns), nil + } else if len(s.ColumnNames) > 0 { + return columnNameSlice(s.ColumnNames), nil + } else if len(s.ColumnNames) == 0 { + if tblInfo.PKIsHandle { + return columnInfoSlice([]*model.ColumnInfo{tblInfo.GetPkColInfo()}), nil + } else if key := tblInfo.GetPrimaryKey(); key != nil { + colInfos := make([]*model.ColumnInfo, 0, len(key.Columns)) + for _, col := range key.Columns { + colInfos = append(colInfos, model.FindColumnInfo(tblInfo.Cols(), col.Name.L)) + } + return columnInfoSlice(colInfos), nil + } + } + return nil, errors.Errorf("Table partition metadata not correct, neither partition expression or list of partition columns") +} + +func checkColumnsPartitionType(tbInfo *model.TableInfo) error { + for _, col := range tbInfo.Partition.Columns { + colInfo := tbInfo.FindPublicColumnByName(col.L) + if colInfo == nil { + return errors.Trace(dbterror.ErrFieldNotFoundPart) + } + if !isColTypeAllowedAsPartitioningCol(tbInfo.Partition.Type, colInfo.FieldType) { + return dbterror.ErrNotAllowedTypeInPartition.GenWithStackByArgs(col.O) + } + } + return nil +} + +func isValidKeyPartitionColType(fieldType types.FieldType) bool { + switch fieldType.GetType() { + case mysql.TypeBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeJSON, mysql.TypeGeometry, mysql.TypeTiDBVectorFloat32: + return false + default: + return true + } +} + +func isColTypeAllowedAsPartitioningCol(partType model.PartitionType, fieldType types.FieldType) bool { + // For key partition, the permitted partition field types can be all field types except + // BLOB, JSON, Geometry + if partType == model.PartitionTypeKey { + return isValidKeyPartitionColType(fieldType) + } + // The permitted data types are shown in the following list: + // All integer types + // DATE and DATETIME + // CHAR, VARCHAR, BINARY, and VARBINARY + // See https://dev.mysql.com/doc/mysql-partitioning-excerpt/5.7/en/partitioning-columns.html + // Note that also TIME is allowed in MySQL. Also see https://bugs.mysql.com/bug.php?id=84362 + switch fieldType.GetType() { + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: + case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeDuration: + case mysql.TypeVarchar, mysql.TypeString: + default: + return false + } + return true +} + +// getPartitionIntervalFromTable checks if a partitioned table matches a generated INTERVAL partitioned scheme +// will return nil if error occurs, i.e. 
not an INTERVAL partitioned table +func getPartitionIntervalFromTable(ctx expression.BuildContext, tbInfo *model.TableInfo) *ast.PartitionInterval { + if tbInfo.Partition == nil || + tbInfo.Partition.Type != model.PartitionTypeRange { + return nil + } + if len(tbInfo.Partition.Columns) > 1 { + // Multi-column RANGE COLUMNS is not supported with INTERVAL + return nil + } + if len(tbInfo.Partition.Definitions) < 2 { + // Must have at least two partitions to calculate an INTERVAL + return nil + } + + var ( + interval ast.PartitionInterval + startIdx = 0 + endIdx = len(tbInfo.Partition.Definitions) - 1 + isIntType = true + minVal = "0" + ) + if len(tbInfo.Partition.Columns) > 0 { + partCol := findColumnByName(tbInfo.Partition.Columns[0].L, tbInfo) + if partCol.FieldType.EvalType() == types.ETInt { + min := getLowerBoundInt(partCol) + minVal = strconv.FormatInt(min, 10) + } else if partCol.FieldType.EvalType() == types.ETDatetime { + isIntType = false + minVal = "0000-01-01" + } else { + // Only INT and Datetime columns are supported for INTERVAL partitioning + return nil + } + } else { + if !isPartExprUnsigned(ctx.GetEvalCtx(), tbInfo) { + minVal = "-9223372036854775808" + } + } + + // Check if possible null partition + firstPartLessThan := driver.UnwrapFromSingleQuotes(tbInfo.Partition.Definitions[0].LessThan[0]) + if strings.EqualFold(firstPartLessThan, minVal) { + interval.NullPart = true + startIdx++ + firstPartLessThan = driver.UnwrapFromSingleQuotes(tbInfo.Partition.Definitions[startIdx].LessThan[0]) + } + // flag if MAXVALUE partition + lastPartLessThan := driver.UnwrapFromSingleQuotes(tbInfo.Partition.Definitions[endIdx].LessThan[0]) + if strings.EqualFold(lastPartLessThan, partitionMaxValue) { + interval.MaxValPart = true + endIdx-- + lastPartLessThan = driver.UnwrapFromSingleQuotes(tbInfo.Partition.Definitions[endIdx].LessThan[0]) + } + // Guess the interval + if startIdx >= endIdx { + // Must have at least two partitions to calculate an INTERVAL + return nil + } + var firstExpr, lastExpr ast.ExprNode + if isIntType { + exprStr := fmt.Sprintf("((%s) - (%s)) DIV %d", lastPartLessThan, firstPartLessThan, endIdx-startIdx) + expr, err := expression.ParseSimpleExpr(ctx, exprStr) + if err != nil { + return nil + } + val, isNull, err := expr.EvalInt(ctx.GetEvalCtx(), chunk.Row{}) + if isNull || err != nil || val < 1 { + // If NULL, error or interval < 1 then cannot be an INTERVAL partitioned table + return nil + } + interval.IntervalExpr.Expr = ast.NewValueExpr(val, "", "") + interval.IntervalExpr.TimeUnit = ast.TimeUnitInvalid + firstExpr, err = astIntValueExprFromStr(firstPartLessThan, minVal == "0") + if err != nil { + return nil + } + interval.FirstRangeEnd = &firstExpr + lastExpr, err = astIntValueExprFromStr(lastPartLessThan, minVal == "0") + if err != nil { + return nil + } + interval.LastRangeEnd = &lastExpr + } else { // types.ETDatetime + exprStr := fmt.Sprintf("TIMESTAMPDIFF(SECOND, '%s', '%s')", firstPartLessThan, lastPartLessThan) + expr, err := expression.ParseSimpleExpr(ctx, exprStr) + if err != nil { + return nil + } + val, isNull, err := expr.EvalInt(ctx.GetEvalCtx(), chunk.Row{}) + if isNull || err != nil || val < 1 { + // If NULL, error or interval < 1 then cannot be an INTERVAL partitioned table + return nil + } + + // This will not find all matches > 28 days, since INTERVAL 1 MONTH can generate + // 2022-01-31, 2022-02-28, 2022-03-31 etc. so we just assume that if there is a + // diff >= 28 days, we will try with Month and not retry with something else... 
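// Worked example (for illustration): FIRST bound '2022-01-01' and LAST bound
// '2023-01-01' with 12 ranges in between give val = 31,536,000 seconds, so
// i = 31,536,000 / 12 = 2,628,000 seconds (about 30.4 days), which is >= 28 days;
// since the range count is greater than 3, i is divided by 30*24*60*60 = 2,592,000,
// yielding INTERVAL 1 MONTH.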
+ i := val / int64(endIdx-startIdx) + if i < (28 * 24 * 60 * 60) { + // Since it is not stored or displayed, non need to try Minute..Week! + interval.IntervalExpr.Expr = ast.NewValueExpr(i, "", "") + interval.IntervalExpr.TimeUnit = ast.TimeUnitSecond + } else { + // Since it is not stored or displayed, non need to try to match Quarter or Year! + if (endIdx - startIdx) <= 3 { + // in case February is in the range + i = i / (28 * 24 * 60 * 60) + } else { + // This should be good for intervals up to 5 years + i = i / (30 * 24 * 60 * 60) + } + interval.IntervalExpr.Expr = ast.NewValueExpr(i, "", "") + interval.IntervalExpr.TimeUnit = ast.TimeUnitMonth + } + + firstExpr = ast.NewValueExpr(firstPartLessThan, "", "") + lastExpr = ast.NewValueExpr(lastPartLessThan, "", "") + interval.FirstRangeEnd = &firstExpr + interval.LastRangeEnd = &lastExpr + } + + partitionMethod := ast.PartitionMethod{ + Tp: model.PartitionTypeRange, + Interval: &interval, + } + partOption := &ast.PartitionOptions{PartitionMethod: partitionMethod} + // Generate the definitions from interval, first and last + err := generatePartitionDefinitionsFromInterval(ctx, partOption, tbInfo) + if err != nil { + return nil + } + + return &interval +} + +// comparePartitionAstAndModel compares a generated *ast.PartitionOptions and a *model.PartitionInfo +func comparePartitionAstAndModel(ctx expression.BuildContext, pAst *ast.PartitionOptions, pModel *model.PartitionInfo, partCol *model.ColumnInfo) error { + a := pAst.Definitions + m := pModel.Definitions + if len(pAst.Definitions) != len(pModel.Definitions) { + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL partitioning: number of partitions generated != partition defined (%d != %d)", len(a), len(m)) + } + + evalCtx := ctx.GetEvalCtx() + evalFn := func(expr ast.ExprNode) (types.Datum, error) { + val, err := expression.EvalSimpleAst(ctx, ast.NewValueExpr(expr, "", "")) + if err != nil || partCol == nil { + return val, err + } + return val.ConvertTo(evalCtx.TypeCtx(), &partCol.FieldType) + } + for i := range pAst.Definitions { + // Allow options to differ! (like Placement Rules) + // Allow names to differ! 
+ + // Check MAXVALUE + maxVD := false + if strings.EqualFold(m[i].LessThan[0], partitionMaxValue) { + maxVD = true + } + generatedExpr := a[i].Clause.(*ast.PartitionDefinitionClauseLessThan).Exprs[0] + _, maxVG := generatedExpr.(*ast.MaxValueExpr) + if maxVG || maxVD { + if maxVG && maxVD { + continue + } + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs(fmt.Sprintf("INTERVAL partitioning: MAXVALUE clause defined for partition %s differs between generated and defined", m[i].Name.O)) + } + + lessThan := m[i].LessThan[0] + if len(lessThan) > 1 && lessThan[:1] == "'" && lessThan[len(lessThan)-1:] == "'" { + lessThan = driver.UnwrapFromSingleQuotes(lessThan) + } + lessThanVal, err := evalFn(ast.NewValueExpr(lessThan, "", "")) + if err != nil { + return err + } + generatedExprVal, err := evalFn(generatedExpr) + if err != nil { + return err + } + cmp, err := lessThanVal.Compare(evalCtx.TypeCtx(), &generatedExprVal, collate.GetBinaryCollator()) + if err != nil { + return err + } + if cmp != 0 { + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs(fmt.Sprintf("INTERVAL partitioning: LESS THAN for partition %s differs between generated and defined", m[i].Name.O)) + } + } + return nil +} + +// comparePartitionDefinitions check if generated definitions are the same as the given ones +// Allow names to differ +// returns error in case of error or non-accepted difference +func comparePartitionDefinitions(ctx expression.BuildContext, a, b []*ast.PartitionDefinition) error { + if len(a) != len(b) { + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("number of partitions generated != partition defined (%d != %d)", len(a), len(b)) + } + for i := range a { + if len(b[i].Sub) > 0 { + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs(fmt.Sprintf("partition %s does have unsupported subpartitions", b[i].Name.O)) + } + // TODO: We could extend the syntax to allow for table options too, like: + // CREATE TABLE t ... INTERVAL ... 
LAST PARTITION LESS THAN ('2015-01-01') PLACEMENT POLICY = 'cheapStorage' + // ALTER TABLE t LAST PARTITION LESS THAN ('2022-01-01') PLACEMENT POLICY 'defaultStorage' + // ALTER TABLE t LAST PARTITION LESS THAN ('2023-01-01') PLACEMENT POLICY 'fastStorage' + if len(b[i].Options) > 0 { + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs(fmt.Sprintf("partition %s does have unsupported options", b[i].Name.O)) + } + lessThan, ok := b[i].Clause.(*ast.PartitionDefinitionClauseLessThan) + if !ok { + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs(fmt.Sprintf("partition %s does not have the right type for LESS THAN", b[i].Name.O)) + } + definedExpr := lessThan.Exprs[0] + generatedExpr := a[i].Clause.(*ast.PartitionDefinitionClauseLessThan).Exprs[0] + _, maxVD := definedExpr.(*ast.MaxValueExpr) + _, maxVG := generatedExpr.(*ast.MaxValueExpr) + if maxVG || maxVD { + if maxVG && maxVD { + continue + } + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs(fmt.Sprintf("partition %s differs between generated and defined for MAXVALUE", b[i].Name.O)) + } + cmpExpr := &ast.BinaryOperationExpr{ + Op: opcode.EQ, + L: definedExpr, + R: generatedExpr, + } + cmp, err := expression.EvalSimpleAst(ctx, cmpExpr) + if err != nil { + return err + } + if cmp.GetInt64() != 1 { + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs(fmt.Sprintf("partition %s differs between generated and defined for expression", b[i].Name.O)) + } + } + return nil +} + +func getLowerBoundInt(partCols ...*model.ColumnInfo) int64 { + ret := int64(0) + for _, col := range partCols { + if mysql.HasUnsignedFlag(col.FieldType.GetFlag()) { + return 0 + } + ret = min(ret, types.IntergerSignedLowerBound(col.GetType())) + } + return ret +} + +// generatePartitionDefinitionsFromInterval generates partition Definitions according to INTERVAL options on partOptions +func generatePartitionDefinitionsFromInterval(ctx expression.BuildContext, partOptions *ast.PartitionOptions, tbInfo *model.TableInfo) error { + if partOptions.Interval == nil { + return nil + } + if tbInfo.Partition.Type != model.PartitionTypeRange { + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL partitioning, only allowed on RANGE partitioning") + } + if len(partOptions.ColumnNames) > 1 || len(tbInfo.Partition.Columns) > 1 { + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL partitioning, does not allow RANGE COLUMNS with more than one column") + } + var partCol *model.ColumnInfo + if len(tbInfo.Partition.Columns) > 0 { + partCol = findColumnByName(tbInfo.Partition.Columns[0].L, tbInfo) + if partCol == nil { + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL partitioning, could not find any RANGE COLUMNS") + } + // Only support Datetime, date and INT column types for RANGE INTERVAL! + switch partCol.FieldType.EvalType() { + case types.ETInt, types.ETDatetime: + default: + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL partitioning, only supports Date, Datetime and INT types") + } + } + // Allow given partition definitions, but check it later! 
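// Illustrative shape of the INTERVAL syntax sugar that ends up here (the exact
// SQL form is an assumption for illustration, not taken from this diff):
//
//	CREATE TABLE t (d DATE) PARTITION BY RANGE COLUMNS (d)
//	INTERVAL (1 MONTH)
//	FIRST PARTITION LESS THAN ('2023-01-01')
//	LAST PARTITION LESS THAN ('2023-04-01')
//
// which is expanded into P_LT_... range partitions (plus optional P_NULL and
// P_MAXVALUE partitions); when explicit definitions are also supplied, they are
// compared against the generated ones further down.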
+ definedPartDefs := partOptions.Definitions + partOptions.Definitions = make([]*ast.PartitionDefinition, 0, 1) + if partOptions.Interval.FirstRangeEnd == nil || partOptions.Interval.LastRangeEnd == nil { + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL partitioning, currently requires FIRST and LAST partitions to be defined") + } + switch partOptions.Interval.IntervalExpr.TimeUnit { + case ast.TimeUnitInvalid, ast.TimeUnitYear, ast.TimeUnitQuarter, ast.TimeUnitMonth, ast.TimeUnitWeek, ast.TimeUnitDay, ast.TimeUnitHour, ast.TimeUnitDayMinute, ast.TimeUnitSecond: + default: + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL partitioning, only supports YEAR, QUARTER, MONTH, WEEK, DAY, HOUR, MINUTE and SECOND as time unit") + } + first := ast.PartitionDefinitionClauseLessThan{ + Exprs: []ast.ExprNode{*partOptions.Interval.FirstRangeEnd}, + } + last := ast.PartitionDefinitionClauseLessThan{ + Exprs: []ast.ExprNode{*partOptions.Interval.LastRangeEnd}, + } + if len(tbInfo.Partition.Columns) > 0 { + colTypes := collectColumnsType(tbInfo) + if len(colTypes) != len(tbInfo.Partition.Columns) { + return dbterror.ErrWrongPartitionName.GenWithStack("partition column name cannot be found") + } + if _, err := checkAndGetColumnsTypeAndValuesMatch(ctx, colTypes, first.Exprs); err != nil { + return err + } + if _, err := checkAndGetColumnsTypeAndValuesMatch(ctx, colTypes, last.Exprs); err != nil { + return err + } + } else { + if err := checkPartitionValuesIsInt(ctx, "FIRST PARTITION", first.Exprs, tbInfo); err != nil { + return err + } + if err := checkPartitionValuesIsInt(ctx, "LAST PARTITION", last.Exprs, tbInfo); err != nil { + return err + } + } + if partOptions.Interval.NullPart { + var partExpr ast.ExprNode + if len(tbInfo.Partition.Columns) == 1 && partOptions.Interval.IntervalExpr.TimeUnit != ast.TimeUnitInvalid { + // Notice compatibility with MySQL, keyword here is 'supported range' but MySQL seems to work from 0000-01-01 too + // https://dev.mysql.com/doc/refman/8.0/en/datetime.html says range 1000-01-01 - 9999-12-31 + // https://docs.pingcap.com/tidb/dev/data-type-date-and-time says The supported range is '0000-01-01' to '9999-12-31' + // set LESS THAN to ZeroTime + partExpr = ast.NewValueExpr("0000-01-01", "", "") + } else { + var min int64 + if partCol != nil { + min = getLowerBoundInt(partCol) + } else { + if !isPartExprUnsigned(ctx.GetEvalCtx(), tbInfo) { + min = math.MinInt64 + } + } + partExpr = ast.NewValueExpr(min, "", "") + } + partOptions.Definitions = append(partOptions.Definitions, &ast.PartitionDefinition{ + Name: model.NewCIStr("P_NULL"), + Clause: &ast.PartitionDefinitionClauseLessThan{ + Exprs: []ast.ExprNode{partExpr}, + }, + }) + } + + err := GeneratePartDefsFromInterval(ctx, ast.AlterTablePartition, tbInfo, partOptions) + if err != nil { + return err + } + + if partOptions.Interval.MaxValPart { + partOptions.Definitions = append(partOptions.Definitions, &ast.PartitionDefinition{ + Name: model.NewCIStr("P_MAXVALUE"), + Clause: &ast.PartitionDefinitionClauseLessThan{ + Exprs: []ast.ExprNode{&ast.MaxValueExpr{}}, + }, + }) + } + + if len(definedPartDefs) > 0 { + err := comparePartitionDefinitions(ctx, partOptions.Definitions, definedPartDefs) + if err != nil { + return err + } + // Seems valid, so keep the defined so that the user defined names are kept etc. 
+ partOptions.Definitions = definedPartDefs + } else if len(tbInfo.Partition.Definitions) > 0 { + err := comparePartitionAstAndModel(ctx, partOptions, tbInfo.Partition, partCol) + if err != nil { + return err + } + } + + return nil +} + +func checkAndGetColumnsTypeAndValuesMatch(ctx expression.BuildContext, colTypes []types.FieldType, exprs []ast.ExprNode) ([]types.Datum, error) { + // Validate() has already checked len(colNames) = len(exprs) + // create table ... partition by range columns (cols) + // partition p0 values less than (expr) + // check the type of cols[i] and expr is consistent. + valDatums := make([]types.Datum, 0, len(colTypes)) + for i, colExpr := range exprs { + if _, ok := colExpr.(*ast.MaxValueExpr); ok { + valDatums = append(valDatums, types.NewStringDatum(partitionMaxValue)) + continue + } + if d, ok := colExpr.(*ast.DefaultExpr); ok { + if d.Name != nil { + return nil, dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs() + } + continue + } + colType := colTypes[i] + val, err := expression.EvalSimpleAst(ctx, colExpr) + if err != nil { + return nil, err + } + // Check val.ConvertTo(colType) doesn't work, so we need this case by case check. + vkind := val.Kind() + switch colType.GetType() { + case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeDuration: + switch vkind { + case types.KindString, types.KindBytes, types.KindNull: + default: + return nil, dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs() + } + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: + switch vkind { + case types.KindInt64, types.KindUint64, types.KindNull: + default: + return nil, dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs() + } + case mysql.TypeFloat, mysql.TypeDouble: + switch vkind { + case types.KindFloat32, types.KindFloat64, types.KindNull: + default: + return nil, dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs() + } + case mysql.TypeString, mysql.TypeVarString: + switch vkind { + case types.KindString, types.KindBytes, types.KindNull, types.KindBinaryLiteral: + default: + return nil, dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs() + } + } + evalCtx := ctx.GetEvalCtx() + newVal, err := val.ConvertTo(evalCtx.TypeCtx(), &colType) + if err != nil { + return nil, dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs() + } + valDatums = append(valDatums, newVal) + } + return valDatums, nil +} + +func astIntValueExprFromStr(s string, unsigned bool) (ast.ExprNode, error) { + if unsigned { + u, err := strconv.ParseUint(s, 10, 64) + if err != nil { + return nil, err + } + return ast.NewValueExpr(u, "", ""), nil + } + i, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return nil, err + } + return ast.NewValueExpr(i, "", ""), nil +} + +// GeneratePartDefsFromInterval generates range partitions from INTERVAL partitioning. +// Handles +// - CREATE TABLE: all partitions are generated +// - ALTER TABLE FIRST PARTITION (expr): Drops all partitions before the partition matching the expr (i.e. sets that partition as the new first partition) +// i.e. 
will return the partitions from old FIRST partition to (and including) new FIRST partition +// - ALTER TABLE LAST PARTITION (expr): Creates new partitions from (excluding) old LAST partition to (including) new LAST partition +// +// partition definitions will be set on partitionOptions +func GeneratePartDefsFromInterval(ctx expression.BuildContext, tp ast.AlterTableType, tbInfo *model.TableInfo, partitionOptions *ast.PartitionOptions) error { + if partitionOptions == nil { + return nil + } + var sb strings.Builder + err := partitionOptions.Interval.IntervalExpr.Expr.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)) + if err != nil { + return err + } + intervalString := driver.UnwrapFromSingleQuotes(sb.String()) + if len(intervalString) < 1 || intervalString[:1] < "1" || intervalString[:1] > "9" { + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL, should be a positive number") + } + var currVal types.Datum + var startExpr, lastExpr, currExpr ast.ExprNode + var timeUnit ast.TimeUnitType + var partCol *model.ColumnInfo + if len(tbInfo.Partition.Columns) == 1 { + partCol = findColumnByName(tbInfo.Partition.Columns[0].L, tbInfo) + if partCol == nil { + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL COLUMNS partitioning: could not find partitioning column") + } + } + timeUnit = partitionOptions.Interval.IntervalExpr.TimeUnit + switch tp { + case ast.AlterTablePartition: + // CREATE TABLE + startExpr = *partitionOptions.Interval.FirstRangeEnd + lastExpr = *partitionOptions.Interval.LastRangeEnd + case ast.AlterTableDropFirstPartition: + startExpr = *partitionOptions.Interval.FirstRangeEnd + lastExpr = partitionOptions.Expr + case ast.AlterTableAddLastPartition: + startExpr = *partitionOptions.Interval.LastRangeEnd + lastExpr = partitionOptions.Expr + default: + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL partitioning: Internal error during generating altered INTERVAL partitions, no known alter type") + } + lastVal, err := expression.EvalSimpleAst(ctx, lastExpr) + if err != nil { + return err + } + evalCtx := ctx.GetEvalCtx() + if partCol != nil { + lastVal, err = lastVal.ConvertTo(evalCtx.TypeCtx(), &partCol.FieldType) + if err != nil { + return err + } + } + var partDefs []*ast.PartitionDefinition + if len(partitionOptions.Definitions) != 0 { + partDefs = partitionOptions.Definitions + } else { + partDefs = make([]*ast.PartitionDefinition, 0, 1) + } + for i := 0; i < mysql.PartitionCountLimit; i++ { + if i == 0 { + currExpr = startExpr + // TODO: adjust the startExpr and have an offset for interval to handle + // Month/Quarters with start partition on day 28/29/30 + if tp == ast.AlterTableAddLastPartition { + // ALTER TABLE LAST PARTITION ... 
+ // Current LAST PARTITION/start already exists, skip to next partition + continue + } + } else { + currExpr = &ast.BinaryOperationExpr{ + Op: opcode.Mul, + L: ast.NewValueExpr(i, "", ""), + R: partitionOptions.Interval.IntervalExpr.Expr, + } + if timeUnit == ast.TimeUnitInvalid { + currExpr = &ast.BinaryOperationExpr{ + Op: opcode.Plus, + L: startExpr, + R: currExpr, + } + } else { + currExpr = &ast.FuncCallExpr{ + FnName: model.NewCIStr("DATE_ADD"), + Args: []ast.ExprNode{ + startExpr, + currExpr, + &ast.TimeUnitExpr{Unit: timeUnit}, + }, + } + } + } + currVal, err = expression.EvalSimpleAst(ctx, currExpr) + if err != nil { + return err + } + if partCol != nil { + currVal, err = currVal.ConvertTo(evalCtx.TypeCtx(), &partCol.FieldType) + if err != nil { + return err + } + } + cmp, err := currVal.Compare(evalCtx.TypeCtx(), &lastVal, collate.GetBinaryCollator()) + if err != nil { + return err + } + if cmp > 0 { + lastStr, err := lastVal.ToString() + if err != nil { + return err + } + sb.Reset() + err = startExpr.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)) + if err != nil { + return err + } + startStr := sb.String() + errStr := fmt.Sprintf("INTERVAL: expr (%s) not matching FIRST + n INTERVALs (%s + n * %s", + lastStr, startStr, intervalString) + if timeUnit != ast.TimeUnitInvalid { + errStr = errStr + " " + timeUnit.String() + } + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs(errStr + ")") + } + valStr, err := currVal.ToString() + if err != nil { + return err + } + if len(valStr) == 0 || valStr[0:1] == "'" { + return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL partitioning: Error when generating partition values") + } + partName := "P_LT_" + valStr + if timeUnit != ast.TimeUnitInvalid { + currExpr = ast.NewValueExpr(valStr, "", "") + } else { + if valStr[:1] == "-" { + currExpr = ast.NewValueExpr(currVal.GetInt64(), "", "") + } else { + currExpr = ast.NewValueExpr(currVal.GetUint64(), "", "") + } + } + partDefs = append(partDefs, &ast.PartitionDefinition{ + Name: model.NewCIStr(partName), + Clause: &ast.PartitionDefinitionClauseLessThan{ + Exprs: []ast.ExprNode{currExpr}, + }, + }) + if cmp == 0 { + // Last partition! + break + } + // The last loop still not reach the max value, return error. + if i == mysql.PartitionCountLimit-1 { + return errors.Trace(dbterror.ErrTooManyPartitions) + } + } + if len(tbInfo.Partition.Definitions)+len(partDefs) > mysql.PartitionCountLimit { + return errors.Trace(dbterror.ErrTooManyPartitions) + } + partitionOptions.Definitions = partDefs + return nil +} + +// buildPartitionDefinitionsInfo build partition definitions info without assign partition id. 
tbInfo will be constant +func buildPartitionDefinitionsInfo(ctx expression.BuildContext, defs []*ast.PartitionDefinition, tbInfo *model.TableInfo, numParts uint64) (partitions []model.PartitionDefinition, err error) { + switch tbInfo.Partition.Type { + case model.PartitionTypeNone: + if len(defs) != 1 { + return nil, dbterror.ErrUnsupportedPartitionType + } + partitions = []model.PartitionDefinition{{Name: defs[0].Name}} + if comment, set := defs[0].Comment(); set { + partitions[0].Comment = comment + } + case model.PartitionTypeRange: + partitions, err = buildRangePartitionDefinitions(ctx, defs, tbInfo) + case model.PartitionTypeHash, model.PartitionTypeKey: + partitions, err = buildHashPartitionDefinitions(defs, tbInfo, numParts) + case model.PartitionTypeList: + partitions, err = buildListPartitionDefinitions(ctx, defs, tbInfo) + default: + err = dbterror.ErrUnsupportedPartitionType + } + + if err != nil { + return nil, err + } + + return partitions, nil +} + +func setPartitionPlacementFromOptions(partition *model.PartitionDefinition, options []*ast.TableOption) error { + // the partition inheritance of placement rules don't have to copy the placement elements to themselves. + // For example: + // t placement policy x (p1 placement policy y, p2) + // p2 will share the same rule as table t does, but it won't copy the meta to itself. we will + // append p2 range to the coverage of table t's rules. This mechanism is good for cascading change + // when policy x is altered. + for _, opt := range options { + if opt.Tp == ast.TableOptionPlacementPolicy { + partition.PlacementPolicyRef = &model.PolicyRefInfo{ + Name: model.NewCIStr(opt.StrValue), + } + } + } + + return nil +} + +func isNonDefaultPartitionOptionsUsed(defs []model.PartitionDefinition) bool { + for i := range defs { + orgDef := defs[i] + if orgDef.Name.O != fmt.Sprintf("p%d", i) { + return true + } + if len(orgDef.Comment) > 0 { + return true + } + if orgDef.PlacementPolicyRef != nil { + return true + } + } + return false +} + +func buildHashPartitionDefinitions(defs []*ast.PartitionDefinition, tbInfo *model.TableInfo, numParts uint64) ([]model.PartitionDefinition, error) { + if err := checkAddPartitionTooManyPartitions(tbInfo.Partition.Num); err != nil { + return nil, err + } + + definitions := make([]model.PartitionDefinition, numParts) + oldParts := uint64(len(tbInfo.Partition.Definitions)) + for i := uint64(0); i < numParts; i++ { + if i < oldParts { + // Use the existing definitions + def := tbInfo.Partition.Definitions[i] + definitions[i].Name = def.Name + definitions[i].Comment = def.Comment + definitions[i].PlacementPolicyRef = def.PlacementPolicyRef + } else if i < oldParts+uint64(len(defs)) { + // Use the new defs + def := defs[i-oldParts] + definitions[i].Name = def.Name + definitions[i].Comment, _ = def.Comment() + if err := setPartitionPlacementFromOptions(&definitions[i], def.Options); err != nil { + return nil, err + } + } else { + // Use the default + definitions[i].Name = model.NewCIStr(fmt.Sprintf("p%d", i)) + } + } + return definitions, nil +} + +func buildListPartitionDefinitions(ctx expression.BuildContext, defs []*ast.PartitionDefinition, tbInfo *model.TableInfo) ([]model.PartitionDefinition, error) { + definitions := make([]model.PartitionDefinition, 0, len(defs)) + exprChecker := newPartitionExprChecker(ctx, nil, checkPartitionExprAllowed) + colTypes := collectColumnsType(tbInfo) + if len(colTypes) != len(tbInfo.Partition.Columns) { + return nil, dbterror.ErrWrongPartitionName.GenWithStack("partition column 
name cannot be found") + } + for _, def := range defs { + if err := def.Clause.Validate(model.PartitionTypeList, len(tbInfo.Partition.Columns)); err != nil { + return nil, err + } + clause := def.Clause.(*ast.PartitionDefinitionClauseIn) + partVals := make([][]types.Datum, 0, len(clause.Values)) + if len(tbInfo.Partition.Columns) > 0 { + for _, vs := range clause.Values { + vals, err := checkAndGetColumnsTypeAndValuesMatch(ctx, colTypes, vs) + if err != nil { + return nil, err + } + partVals = append(partVals, vals) + } + } else { + for _, vs := range clause.Values { + if err := checkPartitionValuesIsInt(ctx, def.Name, vs, tbInfo); err != nil { + return nil, err + } + } + } + comment, _ := def.Comment() + err := checkTooLongTable(def.Name) + if err != nil { + return nil, err + } + piDef := model.PartitionDefinition{ + Name: def.Name, + Comment: comment, + } + + if err = setPartitionPlacementFromOptions(&piDef, def.Options); err != nil { + return nil, err + } + + buf := new(bytes.Buffer) + for valIdx, vs := range clause.Values { + inValue := make([]string, 0, len(vs)) + isDefault := false + if len(vs) == 1 { + if _, ok := vs[0].(*ast.DefaultExpr); ok { + isDefault = true + } + } + if len(partVals) > valIdx && !isDefault { + for colIdx := range partVals[valIdx] { + partVal, err := generatePartValuesWithTp(partVals[valIdx][colIdx], colTypes[colIdx]) + if err != nil { + return nil, err + } + inValue = append(inValue, partVal) + } + } else { + for i := range vs { + vs[i].Accept(exprChecker) + if exprChecker.err != nil { + return nil, exprChecker.err + } + buf.Reset() + vs[i].Format(buf) + inValue = append(inValue, buf.String()) + } + } + piDef.InValues = append(piDef.InValues, inValue) + buf.Reset() + } + definitions = append(definitions, piDef) + } + return definitions, nil +} + +func collectColumnsType(tbInfo *model.TableInfo) []types.FieldType { + if len(tbInfo.Partition.Columns) > 0 { + colTypes := make([]types.FieldType, 0, len(tbInfo.Partition.Columns)) + for _, col := range tbInfo.Partition.Columns { + c := findColumnByName(col.L, tbInfo) + if c == nil { + return nil + } + colTypes = append(colTypes, c.FieldType) + } + + return colTypes + } + + return nil +} + +func buildRangePartitionDefinitions(ctx expression.BuildContext, defs []*ast.PartitionDefinition, tbInfo *model.TableInfo) ([]model.PartitionDefinition, error) { + definitions := make([]model.PartitionDefinition, 0, len(defs)) + exprChecker := newPartitionExprChecker(ctx, nil, checkPartitionExprAllowed) + colTypes := collectColumnsType(tbInfo) + if len(colTypes) != len(tbInfo.Partition.Columns) { + return nil, dbterror.ErrWrongPartitionName.GenWithStack("partition column name cannot be found") + } + for _, def := range defs { + if err := def.Clause.Validate(model.PartitionTypeRange, len(tbInfo.Partition.Columns)); err != nil { + return nil, err + } + clause := def.Clause.(*ast.PartitionDefinitionClauseLessThan) + var partValDatums []types.Datum + if len(tbInfo.Partition.Columns) > 0 { + var err error + if partValDatums, err = checkAndGetColumnsTypeAndValuesMatch(ctx, colTypes, clause.Exprs); err != nil { + return nil, err + } + } else { + if err := checkPartitionValuesIsInt(ctx, def.Name, clause.Exprs, tbInfo); err != nil { + return nil, err + } + } + comment, _ := def.Comment() + evalCtx := ctx.GetEvalCtx() + comment, err := validateCommentLength(evalCtx.ErrCtx(), evalCtx.SQLMode(), def.Name.L, &comment, dbterror.ErrTooLongTablePartitionComment) + if err != nil { + return nil, err + } + err = checkTooLongTable(def.Name) + if err 
!= nil { + return nil, err + } + piDef := model.PartitionDefinition{ + Name: def.Name, + Comment: comment, + } + + if err = setPartitionPlacementFromOptions(&piDef, def.Options); err != nil { + return nil, err + } + + buf := new(bytes.Buffer) + // Range columns partitions support multi-column partitions. + for i, expr := range clause.Exprs { + expr.Accept(exprChecker) + if exprChecker.err != nil { + return nil, exprChecker.err + } + // If multi-column use new evaluated+normalized output, instead of just formatted expression + if len(partValDatums) > i { + var partVal string + if partValDatums[i].Kind() == types.KindNull { + return nil, dbterror.ErrNullInValuesLessThan + } + if _, ok := clause.Exprs[i].(*ast.MaxValueExpr); ok { + partVal, err = partValDatums[i].ToString() + if err != nil { + return nil, err + } + } else { + partVal, err = generatePartValuesWithTp(partValDatums[i], colTypes[i]) + if err != nil { + return nil, err + } + } + + piDef.LessThan = append(piDef.LessThan, partVal) + } else { + expr.Format(buf) + piDef.LessThan = append(piDef.LessThan, buf.String()) + buf.Reset() + } + } + definitions = append(definitions, piDef) + } + return definitions, nil +} + +func checkPartitionValuesIsInt(ctx expression.BuildContext, defName any, exprs []ast.ExprNode, tbInfo *model.TableInfo) error { + tp := types.NewFieldType(mysql.TypeLonglong) + if isPartExprUnsigned(ctx.GetEvalCtx(), tbInfo) { + tp.AddFlag(mysql.UnsignedFlag) + } + for _, exp := range exprs { + if _, ok := exp.(*ast.MaxValueExpr); ok { + continue + } + if d, ok := exp.(*ast.DefaultExpr); ok { + if d.Name != nil { + return dbterror.ErrPartitionConstDomain.GenWithStackByArgs() + } + continue + } + val, err := expression.EvalSimpleAst(ctx, exp) + if err != nil { + return err + } + switch val.Kind() { + case types.KindUint64, types.KindNull: + case types.KindInt64: + if mysql.HasUnsignedFlag(tp.GetFlag()) && val.GetInt64() < 0 { + return dbterror.ErrPartitionConstDomain.GenWithStackByArgs() + } + default: + return dbterror.ErrValuesIsNotIntType.GenWithStackByArgs(defName) + } + + evalCtx := ctx.GetEvalCtx() + _, err = val.ConvertTo(evalCtx.TypeCtx(), tp) + if err != nil && !types.ErrOverflow.Equal(err) { + return dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs() + } + } + + return nil +} + +func checkPartitionNameUnique(pi *model.PartitionInfo) error { + newPars := pi.Definitions + partNames := make(map[string]struct{}, len(newPars)) + for _, newPar := range newPars { + if _, ok := partNames[newPar.Name.L]; ok { + return dbterror.ErrSameNamePartition.GenWithStackByArgs(newPar.Name) + } + partNames[newPar.Name.L] = struct{}{} + } + return nil +} + +func checkAddPartitionNameUnique(tbInfo *model.TableInfo, pi *model.PartitionInfo) error { + partNames := make(map[string]struct{}) + if tbInfo.Partition != nil { + oldPars := tbInfo.Partition.Definitions + for _, oldPar := range oldPars { + partNames[oldPar.Name.L] = struct{}{} + } + } + newPars := pi.Definitions + for _, newPar := range newPars { + if _, ok := partNames[newPar.Name.L]; ok { + return dbterror.ErrSameNamePartition.GenWithStackByArgs(newPar.Name) + } + partNames[newPar.Name.L] = struct{}{} + } + return nil +} + +func checkReorgPartitionNames(p *model.PartitionInfo, droppedNames []string, pi *model.PartitionInfo) error { + partNames := make(map[string]struct{}) + oldDefs := p.Definitions + for _, oldDef := range oldDefs { + partNames[oldDef.Name.L] = struct{}{} + } + for _, delName := range droppedNames { + droppedName := strings.ToLower(delName) + if _, ok := 
partNames[droppedName]; !ok { + return dbterror.ErrSameNamePartition.GenWithStackByArgs(delName) + } + delete(partNames, droppedName) + } + newDefs := pi.Definitions + for _, newDef := range newDefs { + if _, ok := partNames[newDef.Name.L]; ok { + return dbterror.ErrSameNamePartition.GenWithStackByArgs(newDef.Name) + } + partNames[newDef.Name.L] = struct{}{} + } + return nil +} + +func checkAndOverridePartitionID(newTableInfo, oldTableInfo *model.TableInfo) error { + // If any old partitionInfo has been lost, then the partition ID has been lost too, and so has the data, so the repair fails. + if newTableInfo.Partition == nil { + return nil + } + if oldTableInfo.Partition == nil { + return dbterror.ErrRepairTableFail.GenWithStackByArgs("Old table doesn't have partitions") + } + if newTableInfo.Partition.Type != oldTableInfo.Partition.Type { + return dbterror.ErrRepairTableFail.GenWithStackByArgs("Partition type should be the same") + } + // Check whether partitionType is hash partition. + if newTableInfo.Partition.Type == model.PartitionTypeHash { + if newTableInfo.Partition.Num != oldTableInfo.Partition.Num { + return dbterror.ErrRepairTableFail.GenWithStackByArgs("Hash partition num should be the same") + } + } + for i, newOne := range newTableInfo.Partition.Definitions { + found := false + for _, oldOne := range oldTableInfo.Partition.Definitions { + // Fix issue 17952, which wants to substitute the partition range expr. + // So eliminate stringSliceEqual(newOne.LessThan, oldOne.LessThan) here. + if newOne.Name.L == oldOne.Name.L { + newTableInfo.Partition.Definitions[i].ID = oldOne.ID + found = true + break + } + } + if !found { + return dbterror.ErrRepairTableFail.GenWithStackByArgs("Partition " + newOne.Name.L + " has lost") + } + } + return nil +} + +// checkPartitionFuncValid checks whether the partition function is valid. +func checkPartitionFuncValid(ctx expression.BuildContext, tblInfo *model.TableInfo, expr ast.ExprNode) error { + if expr == nil { + return nil + } + exprChecker := newPartitionExprChecker(ctx, tblInfo, checkPartitionExprArgs, checkPartitionExprAllowed) + expr.Accept(exprChecker) + if exprChecker.err != nil { + return errors.Trace(exprChecker.err) + } + if len(exprChecker.columns) == 0 { + return errors.Trace(dbterror.ErrWrongExprInPartitionFunc) + } + return nil +} + +// checkResultOK derives from https://github.com/mysql/mysql-server/blob/5.7/sql/item_timefunc +// For partition tables, MySQL does not support constant, random or timezone-dependent expressions. +// Based on the MySQL code that checks whether a field is valid: every time-related type has a check_valid_arguments_processor function. +func checkResultOK(ok bool) error { + if !ok { + return errors.Trace(dbterror.ErrWrongExprInPartitionFunc) + } + + return nil +} + +// checkPartitionFuncType checks the partition function's return type.
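As a rough standalone illustration of the name-based ID matching that checkAndOverridePartitionID above performs during repair, with simplified stand-in types and a hypothetical overrideIDs helper instead of the real model structs:

package main

import (
	"fmt"
	"strings"
)

// partDef is a stripped-down stand-in for model.PartitionDefinition.
type partDef struct {
	Name string
	ID   int64
}

// overrideIDs copies the ID of each old partition into the new definition
// with the same (case-insensitive) name, mirroring the repair matching above.
func overrideIDs(newDefs, oldDefs []partDef) error {
	byName := make(map[string]int64, len(oldDefs))
	for _, d := range oldDefs {
		byName[strings.ToLower(d.Name)] = d.ID
	}
	for i := range newDefs {
		id, ok := byName[strings.ToLower(newDefs[i].Name)]
		if !ok {
			return fmt.Errorf("partition %s has been lost", newDefs[i].Name)
		}
		newDefs[i].ID = id
	}
	return nil
}

func main() {
	oldDefs := []partDef{{Name: "p0", ID: 101}, {Name: "p1", ID: 102}}
	newDefs := []partDef{{Name: "P0"}, {Name: "p1"}}
	if err := overrideIDs(newDefs, oldDefs); err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(newDefs) // [{P0 101} {p1 102}]
}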
+func checkPartitionFuncType(ctx sessionctx.Context, expr ast.ExprNode, schema string, tblInfo *model.TableInfo) error { + if expr == nil { + return nil + } + + if schema == "" { + schema = ctx.GetSessionVars().CurrentDB + } + + e, err := expression.BuildSimpleExpr(ctx.GetExprCtx(), expr, expression.WithTableInfo(schema, tblInfo)) + if err != nil { + return errors.Trace(err) + } + if e.GetType(ctx.GetExprCtx().GetEvalCtx()).EvalType() == types.ETInt { + return nil + } + + if col, ok := expr.(*ast.ColumnNameExpr); ok { + return errors.Trace(dbterror.ErrNotAllowedTypeInPartition.GenWithStackByArgs(col.Name.Name.L)) + } + + return errors.Trace(dbterror.ErrPartitionFuncNotAllowed.GenWithStackByArgs("PARTITION")) +} + +// checkRangePartitionValue checks whether `less than value` is strictly increasing for each partition. +// Side effect: it may simplify the partition range definition from a constant expression to an integer. +func checkRangePartitionValue(ctx sessionctx.Context, tblInfo *model.TableInfo) error { + pi := tblInfo.Partition + defs := pi.Definitions + if len(defs) == 0 { + return nil + } + + if strings.EqualFold(defs[len(defs)-1].LessThan[0], partitionMaxValue) { + defs = defs[:len(defs)-1] + } + isUnsigned := isPartExprUnsigned(ctx.GetExprCtx().GetEvalCtx(), tblInfo) + var prevRangeValue any + for i := 0; i < len(defs); i++ { + if strings.EqualFold(defs[i].LessThan[0], partitionMaxValue) { + return errors.Trace(dbterror.ErrPartitionMaxvalue) + } + + currentRangeValue, fromExpr, err := getRangeValue(ctx.GetExprCtx(), defs[i].LessThan[0], isUnsigned) + if err != nil { + return errors.Trace(err) + } + if fromExpr { + // Constant fold the expression. + defs[i].LessThan[0] = fmt.Sprintf("%d", currentRangeValue) + } + + if i == 0 { + prevRangeValue = currentRangeValue + continue + } + + if isUnsigned { + if currentRangeValue.(uint64) <= prevRangeValue.(uint64) { + return errors.Trace(dbterror.ErrRangeNotIncreasing) + } + } else { + if currentRangeValue.(int64) <= prevRangeValue.(int64) { + return errors.Trace(dbterror.ErrRangeNotIncreasing) + } + } + prevRangeValue = currentRangeValue + } + return nil +} + +func checkListPartitionValue(ctx expression.BuildContext, tblInfo *model.TableInfo) error { + pi := tblInfo.Partition + if len(pi.Definitions) == 0 { + return ast.ErrPartitionsMustBeDefined.GenWithStackByArgs("LIST") + } + expStr, err := formatListPartitionValue(ctx, tblInfo) + if err != nil { + return errors.Trace(err) + } + + partitionsValuesMap := make(map[string]struct{}) + for _, s := range expStr { + if _, ok := partitionsValuesMap[s]; ok { + return errors.Trace(dbterror.ErrMultipleDefConstInListPart) + } + partitionsValuesMap[s] = struct{}{} + } + + return nil +} + +func formatListPartitionValue(ctx expression.BuildContext, tblInfo *model.TableInfo) ([]string, error) { + defs := tblInfo.Partition.Definitions + pi := tblInfo.Partition + var colTps []*types.FieldType + cols := make([]*model.ColumnInfo, 0, len(pi.Columns)) + if len(pi.Columns) == 0 { + tp := types.NewFieldType(mysql.TypeLonglong) + if isPartExprUnsigned(ctx.GetEvalCtx(), tblInfo) { + tp.AddFlag(mysql.UnsignedFlag) + } + colTps = []*types.FieldType{tp} + } else { + colTps = make([]*types.FieldType, 0, len(pi.Columns)) + for _, colName := range pi.Columns { + colInfo := findColumnByName(colName.L, tblInfo) + if colInfo == nil { + return nil, errors.Trace(dbterror.ErrFieldNotFoundPart) + } + colTps = append(colTps, colInfo.FieldType.Clone()) + cols = append(cols, colInfo) + } + } + + haveDefault := false + exprStrs 
:= make([]string, 0) + inValueStrs := make([]string, 0, mathutil.Max(len(pi.Columns), 1)) + for i := range defs { + inValuesLoop: + for j, vs := range defs[i].InValues { + inValueStrs = inValueStrs[:0] + for k, v := range vs { + // if DEFAULT would be given as string, like "DEFAULT", + // it would be stored as "'DEFAULT'", + if strings.EqualFold(v, "DEFAULT") && k == 0 && len(vs) == 1 { + if haveDefault { + return nil, dbterror.ErrMultipleDefConstInListPart + } + haveDefault = true + continue inValuesLoop + } + if strings.EqualFold(v, "MAXVALUE") { + return nil, errors.Trace(dbterror.ErrMaxvalueInValuesIn) + } + expr, err := expression.ParseSimpleExpr(ctx, v, expression.WithCastExprTo(colTps[k])) + if err != nil { + return nil, errors.Trace(err) + } + eval, err := expr.Eval(ctx.GetEvalCtx(), chunk.Row{}) + if err != nil { + return nil, errors.Trace(err) + } + s, err := eval.ToString() + if err != nil { + return nil, errors.Trace(err) + } + if eval.IsNull() { + s = "NULL" + } else { + if colTps[k].EvalType() == types.ETInt { + defs[i].InValues[j][k] = s + } + if colTps[k].EvalType() == types.ETString { + s = string(hack.String(collate.GetCollator(cols[k].GetCollate()).Key(s))) + s = driver.WrapInSingleQuotes(s) + } + } + inValueStrs = append(inValueStrs, s) + } + exprStrs = append(exprStrs, strings.Join(inValueStrs, ",")) + } + } + return exprStrs, nil +} + +// getRangeValue gets an integer from the range value string. +// The returned boolean value indicates whether the input string is a constant expression. +func getRangeValue(ctx expression.BuildContext, str string, unsigned bool) (any, bool, error) { + // Unsigned bigint was converted to uint64 handle. + if unsigned { + if value, err := strconv.ParseUint(str, 10, 64); err == nil { + return value, false, nil + } + + e, err1 := expression.ParseSimpleExpr(ctx, str) + if err1 != nil { + return 0, false, err1 + } + res, isNull, err2 := e.EvalInt(ctx.GetEvalCtx(), chunk.Row{}) + if err2 == nil && !isNull { + return uint64(res), true, nil + } + } else { + if value, err := strconv.ParseInt(str, 10, 64); err == nil { + return value, false, nil + } + // The range value maybe not an integer, it could be a constant expression. + // For example, the following two cases are the same: + // PARTITION p0 VALUES LESS THAN (TO_SECONDS('2004-01-01')) + // PARTITION p0 VALUES LESS THAN (63340531200) + e, err1 := expression.ParseSimpleExpr(ctx, str) + if err1 != nil { + return 0, false, err1 + } + res, isNull, err2 := e.EvalInt(ctx.GetEvalCtx(), chunk.Row{}) + if err2 == nil && !isNull { + return res, true, nil + } + } + return 0, false, dbterror.ErrNotAllowedTypeInPartition.GenWithStackByArgs(str) +} + +// CheckDropTablePartition checks if the partition exists and does not allow deleting the last existing partition in the table. +func CheckDropTablePartition(meta *model.TableInfo, partLowerNames []string) error { + pi := meta.Partition + if pi.Type != model.PartitionTypeRange && pi.Type != model.PartitionTypeList { + return dbterror.ErrOnlyOnRangeListPartition.GenWithStackByArgs("DROP") + } + + // To be error compatible with MySQL, we need to do this first! 
+ // see https://github.com/pingcap/tidb/issues/31681#issuecomment-1015536214 + oldDefs := pi.Definitions + if len(oldDefs) <= len(partLowerNames) { + return errors.Trace(dbterror.ErrDropLastPartition) + } + + dupCheck := make(map[string]bool) + for _, pn := range partLowerNames { + found := false + for _, def := range oldDefs { + if def.Name.L == pn { + if _, ok := dupCheck[pn]; ok { + return errors.Trace(dbterror.ErrDropPartitionNonExistent.GenWithStackByArgs("DROP")) + } + dupCheck[pn] = true + found = true + break + } + } + if !found { + return errors.Trace(dbterror.ErrDropPartitionNonExistent.GenWithStackByArgs("DROP")) + } + } + return nil +} + +// updateDroppingPartitionInfo move dropping partitions to DroppingDefinitions, and return partitionIDs +func updateDroppingPartitionInfo(tblInfo *model.TableInfo, partLowerNames []string) []int64 { + oldDefs := tblInfo.Partition.Definitions + newDefs := make([]model.PartitionDefinition, 0, len(oldDefs)-len(partLowerNames)) + droppingDefs := make([]model.PartitionDefinition, 0, len(partLowerNames)) + pids := make([]int64, 0, len(partLowerNames)) + + // consider using a map to probe partLowerNames if too many partLowerNames + for i := range oldDefs { + found := false + for _, partName := range partLowerNames { + if oldDefs[i].Name.L == partName { + found = true + break + } + } + if found { + pids = append(pids, oldDefs[i].ID) + droppingDefs = append(droppingDefs, oldDefs[i]) + } else { + newDefs = append(newDefs, oldDefs[i]) + } + } + + tblInfo.Partition.Definitions = newDefs + tblInfo.Partition.DroppingDefinitions = droppingDefs + return pids +} + +func getPartitionDef(tblInfo *model.TableInfo, partName string) (index int, def *model.PartitionDefinition, _ error) { + defs := tblInfo.Partition.Definitions + for i := 0; i < len(defs); i++ { + if strings.EqualFold(defs[i].Name.L, strings.ToLower(partName)) { + return i, &(defs[i]), nil + } + } + return index, nil, table.ErrUnknownPartition.GenWithStackByArgs(partName, tblInfo.Name.O) +} + +func getPartitionIDsFromDefinitions(defs []model.PartitionDefinition) []int64 { + pids := make([]int64, 0, len(defs)) + for _, def := range defs { + pids = append(pids, def.ID) + } + return pids +} + +func hasGlobalIndex(tblInfo *model.TableInfo) bool { + for _, idxInfo := range tblInfo.Indices { + if idxInfo.Global { + return true + } + } + return false +} + +// getTableInfoWithDroppingPartitions builds oldTableInfo including dropping partitions, only used by onDropTablePartition. +func getTableInfoWithDroppingPartitions(t *model.TableInfo) *model.TableInfo { + p := t.Partition + nt := t.Clone() + np := *p + npd := make([]model.PartitionDefinition, 0, len(p.Definitions)+len(p.DroppingDefinitions)) + npd = append(npd, p.Definitions...) + npd = append(npd, p.DroppingDefinitions...) + np.Definitions = npd + np.DroppingDefinitions = nil + nt.Partition = &np + return nt +} + +// getTableInfoWithOriginalPartitions builds oldTableInfo including truncating partitions, only used by onTruncateTablePartition. 
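A small standalone sketch of the split performed by updateDroppingPartitionInfo above, using simplified stand-in types and assuming the drop names are already lower-cased, as in the real code:

package main

import (
	"fmt"
	"strings"
)

type partDef struct {
	Name string
	ID   int64
}

// splitDropping separates defs into the partitions to keep and the ones being
// dropped (matched by lower-cased name), and returns the IDs of the dropped ones.
func splitDropping(defs []partDef, dropLowerNames []string) (kept, dropping []partDef, ids []int64) {
	drop := make(map[string]struct{}, len(dropLowerNames))
	for _, n := range dropLowerNames {
		drop[n] = struct{}{}
	}
	for _, d := range defs {
		if _, ok := drop[strings.ToLower(d.Name)]; ok {
			dropping = append(dropping, d)
			ids = append(ids, d.ID)
		} else {
			kept = append(kept, d)
		}
	}
	return kept, dropping, ids
}

func main() {
	defs := []partDef{{Name: "p0", ID: 1}, {Name: "p1", ID: 2}, {Name: "p2", ID: 3}}
	kept, dropping, ids := splitDropping(defs, []string{"p1"})
	fmt.Println(kept, dropping, ids) // [{p0 1} {p2 3}] [{p1 2}] [2]
}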
+func getTableInfoWithOriginalPartitions(t *model.TableInfo, oldIDs []int64, newIDs []int64) *model.TableInfo { + nt := t.Clone() + np := nt.Partition + + // reconstruct original definitions + for _, oldDef := range np.DroppingDefinitions { + var newID int64 + for i := range newIDs { + if oldDef.ID == oldIDs[i] { + newID = newIDs[i] + break + } + } + for i := range np.Definitions { + newDef := &np.Definitions[i] + if newDef.ID == newID { + newDef.ID = oldDef.ID + break + } + } + } + + np.DroppingDefinitions = nil + np.NewPartitionIDs = nil + return nt +} + +func dropLabelRules(ctx context.Context, schemaName, tableName string, partNames []string) error { + deleteRules := make([]string, 0, len(partNames)) + for _, partName := range partNames { + deleteRules = append(deleteRules, fmt.Sprintf(label.PartitionIDFormat, label.IDPrefix, schemaName, tableName, partName)) + } + // delete batch rules + patch := label.NewRulePatch([]*label.Rule{}, deleteRules) + return infosync.UpdateLabelRules(ctx, patch) +} + +// onDropTablePartition deletes old partition meta. +func (w *worker) onDropTablePartition(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + var partNames []string + partInfo := model.PartitionInfo{} + if err := job.DecodeArgs(&partNames, &partInfo); err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) + if err != nil { + return ver, errors.Trace(err) + } + if job.Type != model.ActionDropTablePartition { + // If rollback from reorganize partition, remove DroppingDefinitions from tableInfo + tblInfo.Partition.DroppingDefinitions = nil + // If rollback from adding table partition, remove addingDefinitions from tableInfo. + physicalTableIDs, pNames, rollbackBundles := rollbackAddingPartitionInfo(tblInfo) + err = infosync.PutRuleBundlesWithDefaultRetry(context.TODO(), rollbackBundles) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Wrapf(err, "failed to notify PD the placement rules") + } + // TODO: Will this drop LabelRules for existing partitions, if the new partitions have the same name? + err = dropLabelRules(w.ctx, job.SchemaName, tblInfo.Name.L, pNames) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Wrapf(err, "failed to notify PD the label rules") + } + + if _, err := alterTableLabelRule(job.SchemaName, tblInfo, getIDs([]*model.TableInfo{tblInfo})); err != nil { + job.State = model.JobStateCancelled + return ver, err + } + // ALTER TABLE ... PARTITION BY + if partInfo.Type != model.PartitionTypeNone { + // Also remove anything with the new table id + physicalTableIDs = append(physicalTableIDs, partInfo.NewTableID) + // Reset if it was normal table before + if tblInfo.Partition.Type == model.PartitionTypeNone || + tblInfo.Partition.DDLType == model.PartitionTypeNone { + tblInfo.Partition = nil + } else { + tblInfo.Partition.ClearReorgIntermediateInfo() + } + } else { + // REMOVE PARTITIONING + tblInfo.Partition.ClearReorgIntermediateInfo() + } + + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + job.FinishTableJob(model.JobStateRollbackDone, model.StateNone, ver, tblInfo) + job.Args = []any{physicalTableIDs} + return ver, nil + } + + var physicalTableIDs []int64 + // In order to skip maintaining the state check in partitionDefinition, TiDB use droppingDefinition instead of state field. 
+ // So here using `job.SchemaState` to judge what the stage of this job is. + originalState := job.SchemaState + switch job.SchemaState { + case model.StatePublic: + // If an error occurs, it returns that it cannot delete all partitions or that the partition doesn't exist. + err = CheckDropTablePartition(tblInfo, partNames) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + physicalTableIDs = updateDroppingPartitionInfo(tblInfo, partNames) + err = dropLabelRules(w.ctx, job.SchemaName, tblInfo.Name.L, partNames) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Wrapf(err, "failed to notify PD the label rules") + } + + if _, err := alterTableLabelRule(job.SchemaName, tblInfo, getIDs([]*model.TableInfo{tblInfo})); err != nil { + job.State = model.JobStateCancelled + return ver, err + } + + var bundles []*placement.Bundle + // create placement groups for each dropped partition to keep the data's placement before GC + // These placements groups will be deleted after GC + bundles, err = droppedPartitionBundles(t, tblInfo, tblInfo.Partition.DroppingDefinitions) + if err != nil { + job.State = model.JobStateCancelled + return ver, err + } + + var tableBundle *placement.Bundle + // Recompute table bundle to remove dropped partitions rules from its group + tableBundle, err = placement.NewTableBundle(t, tblInfo) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + if tableBundle != nil { + bundles = append(bundles, tableBundle) + } + + if err = infosync.PutRuleBundlesWithDefaultRetry(context.TODO(), bundles); err != nil { + job.State = model.JobStateCancelled + return ver, err + } + + job.SchemaState = model.StateDeleteOnly + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != job.SchemaState) + case model.StateDeleteOnly: + // This state is not a real 'DeleteOnly' state, because tidb does not maintaining the state check in partitionDefinition. + // Insert this state to confirm all servers can not see the old partitions when reorg is running, + // so that no new data will be inserted into old partitions when reorganizing. + job.SchemaState = model.StateDeleteReorganization + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != job.SchemaState) + case model.StateDeleteReorganization: + oldTblInfo := getTableInfoWithDroppingPartitions(tblInfo) + physicalTableIDs = getPartitionIDsFromDefinitions(tblInfo.Partition.DroppingDefinitions) + tbl, err := getTable(d.getAutoIDRequirement(), job.SchemaID, oldTblInfo) + if err != nil { + return ver, errors.Trace(err) + } + dbInfo, err := t.GetDatabase(job.SchemaID) + if err != nil { + return ver, errors.Trace(err) + } + // If table has global indexes, we need reorg to clean up them. + if pt, ok := tbl.(table.PartitionedTable); ok && hasGlobalIndex(tblInfo) { + // Build elements for compatible with modify column type. elements will not be used when reorganizing. 
+ elements := make([]*meta.Element, 0, len(tblInfo.Indices)) + for _, idxInfo := range tblInfo.Indices { + if idxInfo.Global { + elements = append(elements, &meta.Element{ID: idxInfo.ID, TypeKey: meta.IndexElementKey}) + } + } + sctx, err1 := w.sessPool.Get() + if err1 != nil { + return ver, err1 + } + defer w.sessPool.Put(sctx) + rh := newReorgHandler(sess.NewSession(sctx)) + reorgInfo, err := getReorgInfoFromPartitions(d.jobContext(job.ID, job.ReorgMeta), d, rh, job, dbInfo, pt, physicalTableIDs, elements) + + if err != nil || reorgInfo.first { + // If we run reorg firstly, we should update the job snapshot version + // and then run the reorg next time. + return ver, errors.Trace(err) + } + err = w.runReorgJob(reorgInfo, tbl.Meta(), d.lease, func() (dropIndexErr error) { + defer tidbutil.Recover(metrics.LabelDDL, "onDropTablePartition", + func() { + dropIndexErr = dbterror.ErrCancelledDDLJob.GenWithStack("drop partition panic") + }, false) + return w.cleanupGlobalIndexes(pt, physicalTableIDs, reorgInfo) + }) + if err != nil { + if dbterror.ErrWaitReorgTimeout.Equal(err) { + // if timeout, we should return, check for the owner and re-wait job done. + return ver, nil + } + if dbterror.ErrPausedDDLJob.Equal(err) { + // if ErrPausedDDLJob, we should return, check for the owner and re-wait job done. + return ver, nil + } + return ver, errors.Trace(err) + } + } + if tblInfo.TiFlashReplica != nil { + removeTiFlashAvailablePartitionIDs(tblInfo, physicalTableIDs) + } + droppedDefs := tblInfo.Partition.DroppingDefinitions + tblInfo.Partition.DroppingDefinitions = nil + // used by ApplyDiff in updateSchemaVersion + job.CtxVars = []any{physicalTableIDs} + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + job.SchemaState = model.StateNone + job.FinishTableJob(model.JobStateDone, model.StateNone, ver, tblInfo) + dropPartitionEvent := statsutil.NewDropPartitionEvent( + job.SchemaID, + tblInfo, + &model.PartitionInfo{Definitions: droppedDefs}, + ) + asyncNotifyEvent(d, dropPartitionEvent) + // A background job will be created to delete old partition data. + job.Args = []any{physicalTableIDs} + default: + err = dbterror.ErrInvalidDDLState.GenWithStackByArgs("partition", job.SchemaState) + } + return ver, errors.Trace(err) +} + +func removeTiFlashAvailablePartitionIDs(tblInfo *model.TableInfo, pids []int64) { + // Remove the partitions + ids := tblInfo.TiFlashReplica.AvailablePartitionIDs + // Rarely called, so OK to take some time, to make it easy + for _, id := range pids { + for i, avail := range ids { + if id == avail { + tmp := ids[:i] + tmp = append(tmp, ids[i+1:]...) + ids = tmp + break + } + } + } + tblInfo.TiFlashReplica.AvailablePartitionIDs = ids +} + +// onTruncateTablePartition truncates old partition meta. 
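removeTiFlashAvailablePartitionIDs above is essentially a slice filter; a compact standalone sketch of the same idea, simplified to plain int64 IDs:

package main

import "fmt"

// removeIDs returns ids with every element of drop removed, reusing the
// backing array and preserving the order of the remaining IDs.
func removeIDs(ids, drop []int64) []int64 {
	del := make(map[int64]struct{}, len(drop))
	for _, id := range drop {
		del[id] = struct{}{}
	}
	kept := ids[:0]
	for _, id := range ids {
		if _, ok := del[id]; !ok {
			kept = append(kept, id)
		}
	}
	return kept
}

func main() {
	fmt.Println(removeIDs([]int64{11, 12, 13, 14}, []int64{12, 14})) // [11 13]
}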
+func (w *worker) onTruncateTablePartition(d *ddlCtx, t *meta.Meta, job *model.Job) (int64, error) { + var ver int64 + var oldIDs, newIDs []int64 + if err := job.DecodeArgs(&oldIDs, &newIDs); err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + if len(oldIDs) != len(newIDs) { + job.State = model.JobStateCancelled + return ver, errors.Trace(errors.New("len(oldIDs) must be the same as len(newIDs)")) + } + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) + if err != nil { + return ver, errors.Trace(err) + } + pi := tblInfo.GetPartitionInfo() + if pi == nil { + return ver, errors.Trace(dbterror.ErrPartitionMgmtOnNonpartitioned) + } + + if !hasGlobalIndex(tblInfo) { + oldPartitions := make([]model.PartitionDefinition, 0, len(oldIDs)) + newPartitions := make([]model.PartitionDefinition, 0, len(oldIDs)) + for k, oldID := range oldIDs { + for i := 0; i < len(pi.Definitions); i++ { + def := &pi.Definitions[i] + if def.ID == oldID { + oldPartitions = append(oldPartitions, def.Clone()) + def.ID = newIDs[k] + // Shallow copy only use the def.ID in event handle. + newPartitions = append(newPartitions, *def) + break + } + } + } + if len(newPartitions) == 0 { + job.State = model.JobStateCancelled + return ver, table.ErrUnknownPartition.GenWithStackByArgs(fmt.Sprintf("pid:%v", oldIDs), tblInfo.Name.O) + } + + if err = clearTruncatePartitionTiflashStatus(tblInfo, newPartitions, oldIDs); err != nil { + job.State = model.JobStateCancelled + return ver, err + } + + if err = updateTruncatePartitionLabelRules(job, t, oldPartitions, newPartitions, tblInfo, oldIDs); err != nil { + job.State = model.JobStateCancelled + return ver, err + } + + preSplitAndScatter(w.sess.Context, d.store, tblInfo, newPartitions) + + job.CtxVars = []any{oldIDs, newIDs} + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + + // Finish this job. + job.FinishTableJob(model.JobStateDone, model.StateNone, ver, tblInfo) + truncatePartitionEvent := statsutil.NewTruncatePartitionEvent( + job.SchemaID, + tblInfo, + &model.PartitionInfo{Definitions: newPartitions}, + &model.PartitionInfo{Definitions: oldPartitions}, + ) + asyncNotifyEvent(d, truncatePartitionEvent) + // A background job will be created to delete old partition data. + job.Args = []any{oldIDs} + + return ver, err + } + + // When table has global index, public->deleteOnly->deleteReorg->none schema changes should be handled. + switch job.SchemaState { + case model.StatePublic: + // Step1: generate new partition ids + truncatingDefinitions := make([]model.PartitionDefinition, 0, len(oldIDs)) + for i, oldID := range oldIDs { + for j := 0; j < len(pi.Definitions); j++ { + def := &pi.Definitions[j] + if def.ID == oldID { + truncatingDefinitions = append(truncatingDefinitions, def.Clone()) + def.ID = newIDs[i] + break + } + } + } + pi.DroppingDefinitions = truncatingDefinitions + pi.NewPartitionIDs = newIDs[:] + + job.SchemaState = model.StateDeleteOnly + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + case model.StateDeleteOnly: + // This state is not a real 'DeleteOnly' state, because tidb does not maintaining the state check in partitionDefinition. + // Insert this state to confirm all servers can not see the old partitions when reorg is running, + // so that no new data will be inserted into old partitions when reorganizing. 
+ job.SchemaState = model.StateDeleteReorganization + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + case model.StateDeleteReorganization: + // Step2: clear global index rows. + physicalTableIDs := oldIDs + oldTblInfo := getTableInfoWithOriginalPartitions(tblInfo, oldIDs, newIDs) + + tbl, err := getTable(d.getAutoIDRequirement(), job.SchemaID, oldTblInfo) + if err != nil { + return ver, errors.Trace(err) + } + dbInfo, err := t.GetDatabase(job.SchemaID) + if err != nil { + return ver, errors.Trace(err) + } + // If table has global indexes, we need reorg to clean up them. + if pt, ok := tbl.(table.PartitionedTable); ok && hasGlobalIndex(tblInfo) { + // Build elements for compatible with modify column type. elements will not be used when reorganizing. + elements := make([]*meta.Element, 0, len(tblInfo.Indices)) + for _, idxInfo := range tblInfo.Indices { + if idxInfo.Global { + elements = append(elements, &meta.Element{ID: idxInfo.ID, TypeKey: meta.IndexElementKey}) + } + } + sctx, err1 := w.sessPool.Get() + if err1 != nil { + return ver, err1 + } + defer w.sessPool.Put(sctx) + rh := newReorgHandler(sess.NewSession(sctx)) + reorgInfo, err := getReorgInfoFromPartitions(d.jobContext(job.ID, job.ReorgMeta), d, rh, job, dbInfo, pt, physicalTableIDs, elements) + + if err != nil || reorgInfo.first { + // If we run reorg firstly, we should update the job snapshot version + // and then run the reorg next time. + return ver, errors.Trace(err) + } + err = w.runReorgJob(reorgInfo, tbl.Meta(), d.lease, func() (dropIndexErr error) { + defer tidbutil.Recover(metrics.LabelDDL, "onDropTablePartition", + func() { + dropIndexErr = dbterror.ErrCancelledDDLJob.GenWithStack("drop partition panic") + }, false) + return w.cleanupGlobalIndexes(pt, physicalTableIDs, reorgInfo) + }) + if err != nil { + if dbterror.ErrWaitReorgTimeout.Equal(err) { + // if timeout, we should return, check for the owner and re-wait job done. + return ver, nil + } + return ver, errors.Trace(err) + } + } + + // Step3: generate new partition ids and finish rest works + oldPartitions := make([]model.PartitionDefinition, 0, len(oldIDs)) + newPartitions := make([]model.PartitionDefinition, 0, len(oldIDs)) + for _, oldDef := range pi.DroppingDefinitions { + var newID int64 + for i := range oldIDs { + if oldDef.ID == oldIDs[i] { + newID = newIDs[i] + break + } + } + for i := 0; i < len(pi.Definitions); i++ { + def := &pi.Definitions[i] + if newID == def.ID { + oldPartitions = append(oldPartitions, oldDef.Clone()) + newPartitions = append(newPartitions, def.Clone()) + break + } + } + } + if len(newPartitions) == 0 { + job.State = model.JobStateCancelled + return ver, table.ErrUnknownPartition.GenWithStackByArgs(fmt.Sprintf("pid:%v", oldIDs), tblInfo.Name.O) + } + + if err = clearTruncatePartitionTiflashStatus(tblInfo, newPartitions, oldIDs); err != nil { + job.State = model.JobStateCancelled + return ver, err + } + + if err = updateTruncatePartitionLabelRules(job, t, oldPartitions, newPartitions, tblInfo, oldIDs); err != nil { + job.State = model.JobStateCancelled + return ver, err + } + + // Step4: clear DroppingDefinitions and finish job. + tblInfo.Partition.DroppingDefinitions = nil + tblInfo.Partition.NewPartitionIDs = nil + + preSplitAndScatter(w.sess.Context, d.store, tblInfo, newPartitions) + + // used by ApplyDiff in updateSchemaVersion + job.CtxVars = []any{oldIDs, newIDs} + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + // Finish this job. 
+ job.FinishTableJob(model.JobStateDone, model.StateNone, ver, tblInfo) + truncatePartitionEvent := statsutil.NewTruncatePartitionEvent( + job.SchemaID, + tblInfo, + &model.PartitionInfo{Definitions: newPartitions}, + &model.PartitionInfo{Definitions: oldPartitions}, + ) + asyncNotifyEvent(d, truncatePartitionEvent) + // A background job will be created to delete old partition data. + job.Args = []any{oldIDs} + default: + err = dbterror.ErrInvalidDDLState.GenWithStackByArgs("partition", job.SchemaState) + } + + return ver, errors.Trace(err) +} + +func clearTruncatePartitionTiflashStatus(tblInfo *model.TableInfo, newPartitions []model.PartitionDefinition, oldIDs []int64) error { + // Clear the tiflash replica available status. + if tblInfo.TiFlashReplica != nil { + e := infosync.ConfigureTiFlashPDForPartitions(true, &newPartitions, tblInfo.TiFlashReplica.Count, &tblInfo.TiFlashReplica.LocationLabels, tblInfo.ID) + failpoint.Inject("FailTiFlashTruncatePartition", func() { + e = errors.New("enforced error") + }) + if e != nil { + logutil.DDLLogger().Error("ConfigureTiFlashPDForPartitions fails", zap.Error(e)) + return e + } + tblInfo.TiFlashReplica.Available = false + // Set partition replica become unavailable. + removeTiFlashAvailablePartitionIDs(tblInfo, oldIDs) + } + return nil +} + +func updateTruncatePartitionLabelRules(job *model.Job, t *meta.Meta, oldPartitions, newPartitions []model.PartitionDefinition, tblInfo *model.TableInfo, oldIDs []int64) error { + bundles, err := placement.NewPartitionListBundles(t, newPartitions) + if err != nil { + return errors.Trace(err) + } + + tableBundle, err := placement.NewTableBundle(t, tblInfo) + if err != nil { + job.State = model.JobStateCancelled + return errors.Trace(err) + } + + if tableBundle != nil { + bundles = append(bundles, tableBundle) + } + + // create placement groups for each dropped partition to keep the data's placement before GC + // These placements groups will be deleted after GC + keepDroppedBundles, err := droppedPartitionBundles(t, tblInfo, oldPartitions) + if err != nil { + job.State = model.JobStateCancelled + return errors.Trace(err) + } + bundles = append(bundles, keepDroppedBundles...) 
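Both the placement bundles gathered above and the label rules handled just below are applied as one batched patch: a set of objects to upsert plus a list of IDs to delete. A toy standalone version of that pattern, with hypothetical rule types rather than the real label/placement APIs:

package main

import "fmt"

// rule is a hypothetical stand-in for a label rule, keyed by its ID.
type rule struct {
	ID     string
	Labels []string
}

// applyPatch upserts setRules into rules and removes deleteIDs,
// mirroring how a rule patch is applied in a single batch.
func applyPatch(rules map[string]rule, setRules []rule, deleteIDs []string) {
	for _, r := range setRules {
		rules[r.ID] = r
	}
	for _, id := range deleteIDs {
		delete(rules, id)
	}
}

func main() {
	rules := map[string]rule{"schema/db/t/p0": {ID: "schema/db/t/p0"}}
	applyPatch(rules,
		[]rule{{ID: "schema/db/t/p1", Labels: []string{"region=us"}}},
		[]string{"schema/db/t/p0"})
	fmt.Println(len(rules), rules["schema/db/t/p1"].ID) // 1 schema/db/t/p1
}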
+ + err = infosync.PutRuleBundlesWithDefaultRetry(context.TODO(), bundles) + if err != nil { + return errors.Wrapf(err, "failed to notify PD the placement rules") + } + + tableID := fmt.Sprintf(label.TableIDFormat, label.IDPrefix, job.SchemaName, tblInfo.Name.L) + oldPartRules := make([]string, 0, len(oldIDs)) + for _, newPartition := range newPartitions { + oldPartRuleID := fmt.Sprintf(label.PartitionIDFormat, label.IDPrefix, job.SchemaName, tblInfo.Name.L, newPartition.Name.L) + oldPartRules = append(oldPartRules, oldPartRuleID) + } + + rules, err := infosync.GetLabelRules(context.TODO(), append(oldPartRules, tableID)) + if err != nil { + return errors.Wrapf(err, "failed to get label rules from PD") + } + + newPartIDs := getPartitionIDs(tblInfo) + newRules := make([]*label.Rule, 0, len(oldIDs)+1) + if tr, ok := rules[tableID]; ok { + newRules = append(newRules, tr.Clone().Reset(job.SchemaName, tblInfo.Name.L, "", append(newPartIDs, tblInfo.ID)...)) + } + + for idx, newPartition := range newPartitions { + if pr, ok := rules[oldPartRules[idx]]; ok { + newRules = append(newRules, pr.Clone().Reset(job.SchemaName, tblInfo.Name.L, newPartition.Name.L, newPartition.ID)) + } + } + + patch := label.NewRulePatch(newRules, []string{}) + err = infosync.UpdateLabelRules(context.TODO(), patch) + if err != nil { + return errors.Wrapf(err, "failed to notify PD the label rules") + } + + return nil +} + +// onExchangeTablePartition exchange partition data +func (w *worker) onExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + var ( + // defID only for updateSchemaVersion + defID int64 + ptSchemaID int64 + ptID int64 + partName string + withValidation bool + ) + + if err := job.DecodeArgs(&defID, &ptSchemaID, &ptID, &partName, &withValidation); err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + ntDbInfo, err := checkSchemaExistAndCancelNotExistJob(t, job) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + ptDbInfo, err := t.GetDatabase(ptSchemaID) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + nt, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) + if err != nil { + return ver, errors.Trace(err) + } + + if job.IsRollingback() { + return rollbackExchangeTablePartition(d, t, job, nt) + } + pt, err := getTableInfo(t, ptID, ptSchemaID) + if err != nil { + if infoschema.ErrDatabaseNotExists.Equal(err) || infoschema.ErrTableNotExists.Equal(err) { + job.State = model.JobStateCancelled + } + return ver, errors.Trace(err) + } + + _, partDef, err := getPartitionDef(pt, partName) + if err != nil { + return ver, errors.Trace(err) + } + if job.SchemaState == model.StateNone { + if pt.State != model.StatePublic { + job.State = model.JobStateCancelled + return ver, dbterror.ErrInvalidDDLState.GenWithStack("table %s is not in public, but %s", pt.Name, pt.State) + } + err = checkExchangePartition(pt, nt) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + err = checkTableDefCompatible(pt, nt) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + err = checkExchangePartitionPlacementPolicy(t, nt.PlacementPolicyRef, pt.PlacementPolicyRef, partDef.PlacementPolicyRef) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + if defID != partDef.ID { + logutil.DDLLogger().Info("Exchange partition id changed, updating to actual 
id", + zap.Stringer("job", job), zap.Int64("defID", defID), zap.Int64("partDef.ID", partDef.ID)) + job.Args[0] = partDef.ID + defID = partDef.ID + err = updateDDLJob2Table(w.sess, job, true) + if err != nil { + return ver, errors.Trace(err) + } + } + var ptInfo []schemaIDAndTableInfo + if len(nt.Constraints) > 0 { + pt.ExchangePartitionInfo = &model.ExchangePartitionInfo{ + ExchangePartitionTableID: nt.ID, + ExchangePartitionDefID: defID, + } + ptInfo = append(ptInfo, schemaIDAndTableInfo{ + schemaID: ptSchemaID, + tblInfo: pt, + }) + } + nt.ExchangePartitionInfo = &model.ExchangePartitionInfo{ + ExchangePartitionTableID: ptID, + ExchangePartitionDefID: defID, + } + // We need an interim schema version, + // so there are no non-matching rows inserted + // into the table using the schema version + // before the exchange is made. + job.SchemaState = model.StateWriteOnly + return updateVersionAndTableInfoWithCheck(d, t, job, nt, true, ptInfo...) + } + // From now on, nt (the non-partitioned table) has + // ExchangePartitionInfo set, meaning it is restricted + // to only allow writes that would match the + // partition to be exchange with. + // So we need to rollback that change, instead of just cancelling. + + if d.lease > 0 { + delayForAsyncCommit() + } + + if defID != partDef.ID { + // Should never happen, should have been updated above, in previous state! + logutil.DDLLogger().Error("Exchange partition id changed, updating to actual id", + zap.Stringer("job", job), zap.Int64("defID", defID), zap.Int64("partDef.ID", partDef.ID)) + job.Args[0] = partDef.ID + defID = partDef.ID + err = updateDDLJob2Table(w.sess, job, true) + if err != nil { + return ver, errors.Trace(err) + } + } + + if withValidation { + ntbl, err := getTable(d.getAutoIDRequirement(), job.SchemaID, nt) + if err != nil { + return ver, errors.Trace(err) + } + ptbl, err := getTable(d.getAutoIDRequirement(), ptSchemaID, pt) + if err != nil { + return ver, errors.Trace(err) + } + err = checkExchangePartitionRecordValidation(w, ptbl, ntbl, ptDbInfo.Name.L, ntDbInfo.Name.L, partName) + if err != nil { + job.State = model.JobStateRollingback + return ver, errors.Trace(err) + } + } + + // partition table auto IDs. + ptAutoIDs, err := t.GetAutoIDAccessors(ptSchemaID, ptID).Get() + if err != nil { + return ver, errors.Trace(err) + } + // non-partition table auto IDs. + ntAutoIDs, err := t.GetAutoIDAccessors(job.SchemaID, nt.ID).Get() + if err != nil { + return ver, errors.Trace(err) + } + + if pt.TiFlashReplica != nil { + for i, id := range pt.TiFlashReplica.AvailablePartitionIDs { + if id == partDef.ID { + pt.TiFlashReplica.AvailablePartitionIDs[i] = nt.ID + break + } + } + } + + // Recreate non-partition table meta info, + // by first delete it with the old table id + err = t.DropTableOrView(job.SchemaID, nt.ID) + if err != nil { + return ver, errors.Trace(err) + } + + // exchange table meta id + pt.ExchangePartitionInfo = nil + // Used below to update the partitioned table's stats meta. 
+ originalPartitionDef := partDef.Clone() + originalNt := nt.Clone() + partDef.ID, nt.ID = nt.ID, partDef.ID + + err = t.UpdateTable(ptSchemaID, pt) + if err != nil { + return ver, errors.Trace(err) + } + + err = t.CreateTableOrView(job.SchemaID, nt) + if err != nil { + return ver, errors.Trace(err) + } + + failpoint.Inject("exchangePartitionErr", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(ver, errors.New("occur an error after updating partition id")) + } + }) + + // Set both tables to the maximum auto IDs between normal table and partitioned table. + // TODO: Fix the issue of big transactions during EXCHANGE PARTITION with AutoID. + // Similar to https://github.com/pingcap/tidb/issues/46904 + newAutoIDs := meta.AutoIDGroup{ + RowID: mathutil.Max(ptAutoIDs.RowID, ntAutoIDs.RowID), + IncrementID: mathutil.Max(ptAutoIDs.IncrementID, ntAutoIDs.IncrementID), + RandomID: mathutil.Max(ptAutoIDs.RandomID, ntAutoIDs.RandomID), + } + err = t.GetAutoIDAccessors(ptSchemaID, pt.ID).Put(newAutoIDs) + if err != nil { + return ver, errors.Trace(err) + } + err = t.GetAutoIDAccessors(job.SchemaID, nt.ID).Put(newAutoIDs) + if err != nil { + return ver, errors.Trace(err) + } + + failpoint.Inject("exchangePartitionAutoID", func(val failpoint.Value) { + if val.(bool) { + seCtx, err := w.sessPool.Get() + defer w.sessPool.Put(seCtx) + if err != nil { + failpoint.Return(ver, err) + } + se := sess.NewSession(seCtx) + _, err = se.Execute(context.Background(), "insert ignore into test.pt values (40000000)", "exchange_partition_test") + if err != nil { + failpoint.Return(ver, err) + } + } + }) + + // the follow code is a swap function for rules of two partitions + // though partitions has exchanged their ID, swap still take effect + + bundles, err := bundlesForExchangeTablePartition(t, pt, partDef, nt) + if err != nil { + return ver, errors.Trace(err) + } + + if err = infosync.PutRuleBundlesWithDefaultRetry(context.TODO(), bundles); err != nil { + return ver, errors.Wrapf(err, "failed to notify PD the placement rules") + } + + ntrID := fmt.Sprintf(label.TableIDFormat, label.IDPrefix, job.SchemaName, nt.Name.L) + ptrID := fmt.Sprintf(label.PartitionIDFormat, label.IDPrefix, job.SchemaName, pt.Name.L, partDef.Name.L) + + rules, err := infosync.GetLabelRules(context.TODO(), []string{ntrID, ptrID}) + if err != nil { + return 0, errors.Wrapf(err, "failed to get PD the label rules") + } + + ntr := rules[ntrID] + ptr := rules[ptrID] + + // This must be a bug, nt cannot be partitioned! 
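For the auto-ID handling above, both tables end up with the element-wise maximum of the two auto-ID groups, so neither side can hand out an ID the other may already have used. A small standalone sketch, using a simplified stand-in for meta.AutoIDGroup:

package main

import "fmt"

// autoIDGroup is a simplified stand-in for meta.AutoIDGroup.
type autoIDGroup struct {
	RowID, IncrementID, RandomID int64
}

// mergeAutoIDs takes the element-wise maximum of the two groups, which is the
// value written back for both the partitioned and the non-partitioned table.
func mergeAutoIDs(a, b autoIDGroup) autoIDGroup {
	return autoIDGroup{
		RowID:       maxI64(a.RowID, b.RowID),
		IncrementID: maxI64(a.IncrementID, b.IncrementID),
		RandomID:    maxI64(a.RandomID, b.RandomID),
	}
}

func maxI64(x, y int64) int64 {
	if x > y {
		return x
	}
	return y
}

func main() {
	pt := autoIDGroup{RowID: 4000, IncrementID: 10, RandomID: 7}
	nt := autoIDGroup{RowID: 1500, IncrementID: 42, RandomID: 3}
	fmt.Println(mergeAutoIDs(pt, nt)) // {4000 42 7}
}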
+ partIDs := getPartitionIDs(nt) + + var setRules []*label.Rule + var deleteRules []string + if ntr != nil && ptr != nil { + setRules = append(setRules, ntr.Clone().Reset(job.SchemaName, pt.Name.L, partDef.Name.L, partDef.ID)) + setRules = append(setRules, ptr.Clone().Reset(job.SchemaName, nt.Name.L, "", append(partIDs, nt.ID)...)) + } else if ptr != nil { + setRules = append(setRules, ptr.Clone().Reset(job.SchemaName, nt.Name.L, "", append(partIDs, nt.ID)...)) + // delete ptr + deleteRules = append(deleteRules, ptrID) + } else if ntr != nil { + setRules = append(setRules, ntr.Clone().Reset(job.SchemaName, pt.Name.L, partDef.Name.L, partDef.ID)) + // delete ntr + deleteRules = append(deleteRules, ntrID) + } + + patch := label.NewRulePatch(setRules, deleteRules) + err = infosync.UpdateLabelRules(context.TODO(), patch) + if err != nil { + return ver, errors.Wrapf(err, "failed to notify PD the label rules") + } + + job.SchemaState = model.StatePublic + nt.ExchangePartitionInfo = nil + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, nt, true) + if err != nil { + return ver, errors.Trace(err) + } + + job.FinishTableJob(model.JobStateDone, model.StateNone, ver, pt) + exchangePartitionEvent := statsutil.NewExchangePartitionEvent( + job.SchemaID, + pt, + &model.PartitionInfo{Definitions: []model.PartitionDefinition{originalPartitionDef}}, + originalNt, + ) + asyncNotifyEvent(d, exchangePartitionEvent) + return ver, nil +} + +func getReorgPartitionInfo(t *meta.Meta, job *model.Job) (*model.TableInfo, []string, *model.PartitionInfo, []model.PartitionDefinition, []model.PartitionDefinition, error) { + schemaID := job.SchemaID + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) + if err != nil { + return nil, nil, nil, nil, nil, errors.Trace(err) + } + partInfo := &model.PartitionInfo{} + var partNames []string + err = job.DecodeArgs(&partNames, &partInfo) + if err != nil { + job.State = model.JobStateCancelled + return nil, nil, nil, nil, nil, errors.Trace(err) + } + var addingDefs, droppingDefs []model.PartitionDefinition + if tblInfo.Partition != nil { + addingDefs = tblInfo.Partition.AddingDefinitions + droppingDefs = tblInfo.Partition.DroppingDefinitions + tblInfo.Partition.NewTableID = partInfo.NewTableID + tblInfo.Partition.DDLType = partInfo.Type + tblInfo.Partition.DDLExpr = partInfo.Expr + tblInfo.Partition.DDLColumns = partInfo.Columns + } else { + tblInfo.Partition = getPartitionInfoTypeNone() + tblInfo.Partition.NewTableID = partInfo.NewTableID + tblInfo.Partition.Definitions[0].ID = tblInfo.ID + tblInfo.Partition.DDLType = partInfo.Type + tblInfo.Partition.DDLExpr = partInfo.Expr + tblInfo.Partition.DDLColumns = partInfo.Columns + } + if len(addingDefs) == 0 { + addingDefs = []model.PartitionDefinition{} + } + if len(droppingDefs) == 0 { + droppingDefs = []model.PartitionDefinition{} + } + return tblInfo, partNames, partInfo, droppingDefs, addingDefs, nil +} + +// onReorganizePartition reorganized the partitioning of a table including its indexes. +// ALTER TABLE t REORGANIZE PARTITION p0 [, p1...] INTO (PARTITION p0 ...) +// +// Takes one set of partitions and copies the data to a newly defined set of partitions +// +// ALTER TABLE t REMOVE PARTITIONING +// +// Makes a partitioned table non-partitioned, by first collapsing all partitions into a +// single partition and then converts that partition to a non-partitioned table +// +// ALTER TABLE t PARTITION BY ... 
+// +// Changes the partitioning to the newly defined partitioning type and definitions, +// works for both partitioned and non-partitioned tables. +// If the table is non-partitioned, then it will first convert it to a partitioned +// table with a single partition, i.e. the full table as a single partition. +// +// job.SchemaState goes through the following SchemaState(s): +// StateNone -> StateDeleteOnly -> StateWriteOnly -> StateWriteReorganization +// -> StateDeleteOrganization -> StatePublic +// There are more details embedded in the implementation, but the high level changes are: +// StateNone -> StateDeleteOnly: +// +// Various checks and validations. +// Add possible new unique/global indexes. They share the same state as job.SchemaState +// until end of StateWriteReorganization -> StateDeleteReorganization. +// Set DroppingDefinitions and AddingDefinitions. +// So both the new partitions and new indexes will be included in new delete/update DML. +// +// StateDeleteOnly -> StateWriteOnly: +// +// So both the new partitions and new indexes will be included also in update/insert DML. +// +// StateWriteOnly -> StateWriteReorganization: +// +// To make sure that when we are reorganizing the data, +// both the old and new partitions/indexes will be updated. +// +// StateWriteReorganization -> StateDeleteOrganization: +// +// Here is where all data is reorganized, both partitions and indexes. +// It copies all data from the old set of partitions into the new set of partitions, +// and creates the local indexes on the new set of partitions, +// and if new unique indexes are added, it also updates them with the rest of data from +// the non-touched partitions. +// For indexes that are to be replaced with new ones (old/new global index), +// mark the old indexes as StateDeleteReorganization and new ones as StatePublic +// Finally make the table visible with the new partition definitions. +// I.e. in this state clients will read from the old set of partitions, +// and will read the new set of partitions in StateDeleteReorganization. +// +// StateDeleteOrganization -> StatePublic: +// +// Now all heavy lifting is done, and we just need to finalize and drop things, while still doing +// double writes, since previous state sees the old partitions/indexes. +// Remove the old indexes and old partitions from the TableInfo. +// Add the old indexes and old partitions to the queue for later cleanup (see delete_range.go). +// Queue new partitions for statistics update. +// if ALTER TABLE t PARTITION BY/REMOVE PARTITIONING: +// Recreate the table with the new TableID, by DropTableOrView+CreateTableOrView +// +// StatePublic: +// +// Everything now looks as it should, no memory of old partitions/indexes, +// and no more double writing, since the previous state is only reading the new partitions/indexes. +func (w *worker) onReorganizePartition(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + // Handle the rolling back job + if job.IsRollingback() { + ver, err := w.onDropTablePartition(d, t, job) + if err != nil { + return ver, errors.Trace(err) + } + return ver, nil + } + + tblInfo, partNames, partInfo, _, addingDefinitions, err := getReorgPartitionInfo(t, job) + if err != nil { + return ver, err + } + + switch job.SchemaState { + case model.StateNone: + // job.SchemaState == model.StateNone means the job is in the initial state of reorg partition. + // Here should use partInfo from job directly and do some check action. 
+ // In case there was a race for queueing different schema changes on the same + // table and the checks was not done on the current schema version. + // The partInfo may have been checked against an older schema version for example. + // If the check is done here, it does not need to be repeated, since no other + // DDL on the same table can be run concurrently. + num := len(partInfo.Definitions) - len(partNames) + len(tblInfo.Partition.Definitions) + err = checkAddPartitionTooManyPartitions(uint64(num)) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + err = checkReorgPartitionNames(tblInfo.Partition, partNames, partInfo) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + // Re-check that the dropped/added partitions are compatible with current definition + firstPartIdx, lastPartIdx, idMap, err := getReplacedPartitionIDs(partNames, tblInfo.Partition) + if err != nil { + job.State = model.JobStateCancelled + return ver, err + } + sctx := w.sess.Context + if err = checkReorgPartitionDefs(sctx, job.Type, tblInfo, partInfo, firstPartIdx, lastPartIdx, idMap); err != nil { + job.State = model.JobStateCancelled + return ver, err + } + + // move the adding definition into tableInfo. + updateAddingPartitionInfo(partInfo, tblInfo) + orgDefs := tblInfo.Partition.Definitions + _ = updateDroppingPartitionInfo(tblInfo, partNames) + // Reset original partitions, and keep DroppedDefinitions + tblInfo.Partition.Definitions = orgDefs + + // modify placement settings + for _, def := range tblInfo.Partition.AddingDefinitions { + if _, err = checkPlacementPolicyRefValidAndCanNonValidJob(t, job, def.PlacementPolicyRef); err != nil { + // job.State = model.JobStateCancelled may be set depending on error in function above. + return ver, errors.Trace(err) + } + } + + // All global indexes must be recreated, we cannot update them in-place, since we must have + // both old and new set of partition ids in the unique index at the same time! + for _, index := range tblInfo.Indices { + if !index.Unique { + // for now, only unique index can be global, non-unique indexes are 'local' + continue + } + inAllPartitionColumns, err := checkPartitionKeysConstraint(partInfo, index.Columns, tblInfo) + if err != nil { + return ver, errors.Trace(err) + } + if index.Global || !inAllPartitionColumns { + // Duplicate the unique indexes with new index ids. + // If previously was Global or will be Global: + // it must be recreated with new index ID + newIndex := index.Clone() + newIndex.State = model.StateDeleteOnly + newIndex.ID = AllocateIndexID(tblInfo) + if inAllPartitionColumns { + newIndex.Global = false + } else { + // If not including all partitioning columns, make it Global + newIndex.Global = true + } + tblInfo.Indices = append(tblInfo.Indices, newIndex) + } + } + // From now on we cannot just cancel the DDL, we must roll back if changesMade! + changesMade := false + if tblInfo.TiFlashReplica != nil { + // Must set placement rule, and make sure it succeeds. 
+ if err := infosync.ConfigureTiFlashPDForPartitions(true, &tblInfo.Partition.AddingDefinitions, tblInfo.TiFlashReplica.Count, &tblInfo.TiFlashReplica.LocationLabels, tblInfo.ID); err != nil { + logutil.DDLLogger().Error("ConfigureTiFlashPDForPartitions fails", zap.Error(err)) + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + changesMade = true + // In the next step, StateDeleteOnly, wait to verify the TiFlash replicas are OK + } + + bundles, err := alterTablePartitionBundles(t, tblInfo, tblInfo.Partition.AddingDefinitions) + if err != nil { + if !changesMade { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + return convertAddTablePartitionJob2RollbackJob(d, t, job, err, tblInfo) + } + + if len(bundles) > 0 { + if err = infosync.PutRuleBundlesWithDefaultRetry(context.TODO(), bundles); err != nil { + if !changesMade { + job.State = model.JobStateCancelled + return ver, errors.Wrapf(err, "failed to notify PD the placement rules") + } + return convertAddTablePartitionJob2RollbackJob(d, t, job, err, tblInfo) + } + changesMade = true + } + + ids := getIDs([]*model.TableInfo{tblInfo}) + for _, p := range tblInfo.Partition.AddingDefinitions { + ids = append(ids, p.ID) + } + changed, err := alterTableLabelRule(job.SchemaName, tblInfo, ids) + changesMade = changesMade || changed + if err != nil { + if !changesMade { + job.State = model.JobStateCancelled + return ver, err + } + return convertAddTablePartitionJob2RollbackJob(d, t, job, err, tblInfo) + } + + // Doing the preSplitAndScatter here, since all checks are completed, + // and we will soon start writing to the new partitions. + if s, ok := d.store.(kv.SplittableStore); ok && s != nil { + // partInfo only contains the AddingPartitions + splitPartitionTableRegion(w.sess.Context, s, tblInfo, partInfo.Definitions, true) + } + + // Assume we cannot have more than MaxUint64 rows, set the progress to 1/10 of that. + metrics.GetBackfillProgressByLabel(metrics.LblReorgPartition, job.SchemaName, tblInfo.Name.String()).Set(0.1 / float64(math.MaxUint64)) + job.SchemaState = model.StateDeleteOnly + tblInfo.Partition.DDLState = model.StateDeleteOnly + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + + // Is really both StateDeleteOnly AND StateWriteOnly needed? + // If transaction A in WriteOnly inserts row 1 (into both new and old partition set) + // and then transaction B in DeleteOnly deletes that row (in both new and old) + // does really transaction B need to do the delete in the new partition? + // Yes, otherwise it would still be there when the WriteReorg happens, + // and WriteReorg would only copy existing rows to the new table, so unless it is + // deleted it would result in a ghost row! + // What about update then? + // Updates also need to be handled for new partitions in DeleteOnly, + // since it would not be overwritten during Reorganize phase. + // BUT if the update results in adding in one partition and deleting in another, + // THEN only the delete must happen in the new partition set, not the insert! + case model.StateDeleteOnly: + // This state is to confirm all servers can not see the new partitions when reorg is running, + // so that all deletes will be done in both old and new partitions when in either DeleteOnly + // or WriteOnly state. 
+ // Also using the state for checking that the optional TiFlash replica is available, making it + // in a state without (much) data and easy to retry without side effects. + + // Reason for having it here, is to make it easy for retry, and better to make sure it is in-sync + // as early as possible, to avoid a long wait after the data copying. + if tblInfo.TiFlashReplica != nil && tblInfo.TiFlashReplica.Available { + // For available state, the new added partition should wait its replica to + // be finished, otherwise the query to this partition will be blocked. + count := tblInfo.TiFlashReplica.Count + needRetry, err := checkPartitionReplica(count, addingDefinitions, d) + if err != nil { + // need to rollback, since we tried to register the new + // partitions before! + return convertAddTablePartitionJob2RollbackJob(d, t, job, err, tblInfo) + } + if needRetry { + // The new added partition hasn't been replicated. + // Do nothing to the job this time, wait next worker round. + time.Sleep(tiflashCheckTiDBHTTPAPIHalfInterval) + // Set the error here which will lead this job exit when it's retry times beyond the limitation. + return ver, errors.Errorf("[ddl] add partition wait for tiflash replica to complete") + } + + // When TiFlash Replica is ready, we must move them into `AvailablePartitionIDs`. + // Since onUpdateFlashReplicaStatus cannot see the partitions yet (not public) + for _, d := range addingDefinitions { + tblInfo.TiFlashReplica.AvailablePartitionIDs = append(tblInfo.TiFlashReplica.AvailablePartitionIDs, d.ID) + } + } + + for i := range tblInfo.Indices { + if tblInfo.Indices[i].Unique && tblInfo.Indices[i].State == model.StateDeleteOnly { + tblInfo.Indices[i].State = model.StateWriteOnly + } + } + tblInfo.Partition.DDLState = model.StateWriteOnly + metrics.GetBackfillProgressByLabel(metrics.LblReorgPartition, job.SchemaName, tblInfo.Name.String()).Set(0.2 / float64(math.MaxUint64)) + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + job.SchemaState = model.StateWriteOnly + case model.StateWriteOnly: + // Insert this state to confirm all servers can see the new partitions when reorg is running, + // so that new data will be updated in both old and new partitions when reorganizing. 
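The new unique/global indexes created in the StateNone step are promoted one state per schema version, in lock step with the partition states: the loops just above and below bump them from DeleteOnly to WriteOnly and then to WriteReorganization. Here is a minimal sketch of that one-rung-per-version ladder; the names are invented for illustration and nothing in it is a TiDB API.

package main

import "fmt"

// promote advances an index state one step along the rebuild ladder, mirroring
// the per-schema-version bumps done in the surrounding switch cases.
func promote(state string) string {
	ladder := []string{"delete-only", "write-only", "write-reorganization", "public"}
	for i := 0; i+1 < len(ladder); i++ {
		if ladder[i] == state {
			return ladder[i+1]
		}
	}
	return state // already public (or unknown): leave unchanged
}

func main() {
	s := "delete-only"
	for i := 0; i < 3; i++ {
		next := promote(s)
		fmt.Printf("schema version bump: %s -> %s\n", s, next)
		s = next
	}
}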
+ job.SnapshotVer = 0 + for i := range tblInfo.Indices { + if tblInfo.Indices[i].Unique && tblInfo.Indices[i].State == model.StateWriteOnly { + tblInfo.Indices[i].State = model.StateWriteReorganization + } + } + tblInfo.Partition.DDLState = model.StateWriteReorganization + metrics.GetBackfillProgressByLabel(metrics.LblReorgPartition, job.SchemaName, tblInfo.Name.String()).Set(0.3 / float64(math.MaxUint64)) + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + job.SchemaState = model.StateWriteReorganization + case model.StateWriteReorganization: + physicalTableIDs := getPartitionIDsFromDefinitions(tblInfo.Partition.DroppingDefinitions) + tbl, err2 := getTable(d.getAutoIDRequirement(), job.SchemaID, tblInfo) + if err2 != nil { + return ver, errors.Trace(err2) + } + var done bool + done, ver, err = doPartitionReorgWork(w, d, t, job, tbl, physicalTableIDs) + + if !done { + return ver, err + } + + for _, index := range tblInfo.Indices { + if !index.Unique { + continue + } + switch index.State { + case model.StateWriteReorganization: + // Newly created index, replacing old unique/global index + index.State = model.StatePublic + case model.StatePublic: + if index.Global { + // Mark the old global index as non-readable, and to be dropped + index.State = model.StateDeleteReorganization + } else { + inAllPartitionColumns, err := checkPartitionKeysConstraint(partInfo, index.Columns, tblInfo) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + if !inAllPartitionColumns { + // Mark the old unique index as non-readable, and to be dropped, + // since it is replaced by a global index + index.State = model.StateDeleteReorganization + } + } + } + } + firstPartIdx, lastPartIdx, idMap, err2 := getReplacedPartitionIDs(partNames, tblInfo.Partition) + failpoint.Inject("reorgPartWriteReorgReplacedPartIDsFail", func(val failpoint.Value) { + if val.(bool) { + err2 = errors.New("Injected error by reorgPartWriteReorgReplacedPartIDsFail") + } + }) + if err2 != nil { + return ver, err2 + } + newDefs := getReorganizedDefinitions(tblInfo.Partition, firstPartIdx, lastPartIdx, idMap) + + // From now on, use the new partitioning, but keep the Adding and Dropping for double write + tblInfo.Partition.Definitions = newDefs + tblInfo.Partition.Num = uint64(len(newDefs)) + if job.Type == model.ActionAlterTablePartitioning || + job.Type == model.ActionRemovePartitioning { + tblInfo.Partition.Type, tblInfo.Partition.DDLType = tblInfo.Partition.DDLType, tblInfo.Partition.Type + tblInfo.Partition.Expr, tblInfo.Partition.DDLExpr = tblInfo.Partition.DDLExpr, tblInfo.Partition.Expr + tblInfo.Partition.Columns, tblInfo.Partition.DDLColumns = tblInfo.Partition.DDLColumns, tblInfo.Partition.Columns + } + + // Now all the data copying is done, but we cannot simply remove the droppingDefinitions + // since they are a part of the normal Definitions that other nodes with + // the current schema version. 
So we need to double write for one more schema version + tblInfo.Partition.DDLState = model.StateDeleteReorganization + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + job.SchemaState = model.StateDeleteReorganization + + case model.StateDeleteReorganization: + // Drop the droppingDefinitions and finish the DDL + // This state is needed for the case where client A sees the schema + // with version of StateWriteReorg and would not see updates of + // client B that writes to the new partitions, previously + // addingDefinitions, since it would not double write to + // the droppingDefinitions during this time + // By adding StateDeleteReorg state, client B will write to both + // the new (previously addingDefinitions) AND droppingDefinitions + + // Register the droppingDefinitions ids for rangeDelete + // and the addingDefinitions for handling in the updateSchemaVersion + physicalTableIDs := getPartitionIDsFromDefinitions(tblInfo.Partition.DroppingDefinitions) + newIDs := getPartitionIDsFromDefinitions(partInfo.Definitions) + statisticsPartInfo := &model.PartitionInfo{Definitions: tblInfo.Partition.AddingDefinitions} + droppedPartInfo := &model.PartitionInfo{Definitions: tblInfo.Partition.DroppingDefinitions} + + tblInfo.Partition.DroppingDefinitions = nil + tblInfo.Partition.AddingDefinitions = nil + tblInfo.Partition.DDLState = model.StateNone + + var dropIndices []*model.IndexInfo + for _, indexInfo := range tblInfo.Indices { + if indexInfo.Unique && indexInfo.State == model.StateDeleteReorganization { + // Drop the old unique (possible global) index, see onDropIndex + indexInfo.State = model.StateNone + DropIndexColumnFlag(tblInfo, indexInfo) + RemoveDependentHiddenColumns(tblInfo, indexInfo) + dropIndices = append(dropIndices, indexInfo) + } + } + for _, indexInfo := range dropIndices { + removeIndexInfo(tblInfo, indexInfo) + } + var oldTblID int64 + if job.Type != model.ActionReorganizePartition { + // ALTER TABLE ... PARTITION BY + // REMOVE PARTITIONING + // Storing the old table ID, used for updating statistics. + oldTblID = tblInfo.ID + // TODO: Handle bundles? + // TODO: Add concurrent test! + // TODO: Will this result in big gaps? + // TODO: How to carrie over AUTO_INCREMENT etc.? + // Check if they are carried over in ApplyDiff?!? + autoIDs, err := t.GetAutoIDAccessors(job.SchemaID, tblInfo.ID).Get() + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + err = t.DropTableOrView(job.SchemaID, tblInfo.ID) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + tblInfo.ID = partInfo.NewTableID + if partInfo.DDLType != model.PartitionTypeNone { + // if partitioned before, then also add the old table ID, + // otherwise it will be the already included first partition + physicalTableIDs = append(physicalTableIDs, oldTblID) + } + if job.Type == model.ActionRemovePartitioning { + tblInfo.Partition = nil + } else { + // ALTER TABLE ... PARTITION BY + tblInfo.Partition.ClearReorgIntermediateInfo() + } + err = t.GetAutoIDAccessors(job.SchemaID, tblInfo.ID).Put(autoIDs) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + // TODO: Add failpoint here? 
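The block just above keeps the auto-ID counters across the table-ID change done by ALTER TABLE ... PARTITION BY / REMOVE PARTITIONING: it reads the accessors under the old table ID, drops the old table record, and writes the counters back under the new ID before recreating the table. A toy sketch of that ordering follows, using a plain map as a stand-in for the meta accessors; this is not how TiDB actually stores auto-ID state.

package main

import "fmt"

// Save/drop/recreate/restore ordering for auto-ID state when the table ID
// changes. The map stands in for the meta accessors and is purely illustrative.
func main() {
	autoIDs := map[int64]int64{42: 1001} // old table ID -> next auto-increment value

	oldID, newID := int64(42), int64(43)
	saved := autoIDs[oldID] // 1. read the counters before dropping the old table
	delete(autoIDs, oldID)  // 2. drop the old table entry
	autoIDs[newID] = saved  // 3. recreate under the new table ID, keeping the counters

	fmt.Printf("table %d continues AUTO_INCREMENT at %d\n", newID, autoIDs[newID])
}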
+ err = t.CreateTableOrView(job.SchemaID, tblInfo) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + } + job.CtxVars = []any{physicalTableIDs, newIDs} + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + failpoint.Inject("reorgPartWriteReorgSchemaVersionUpdateFail", func(val failpoint.Value) { + if val.(bool) { + err = errors.New("Injected error by reorgPartWriteReorgSchemaVersionUpdateFail") + } + }) + if err != nil { + return ver, errors.Trace(err) + } + job.FinishTableJob(model.JobStateDone, model.StateNone, ver, tblInfo) + // How to handle this? + // Seems to only trigger asynchronous update of statistics. + // Should it actually be synchronous? + // Include the old table ID, if changed, which may contain global statistics, + // so it can be reused for the new (non)partitioned table. + event, err := newStatsDDLEventForJob( + job.SchemaID, + job.Type, oldTblID, tblInfo, statisticsPartInfo, droppedPartInfo, + ) + if err != nil { + return ver, errors.Trace(err) + } + asyncNotifyEvent(d, event) + // A background job will be created to delete old partition data. + job.Args = []any{physicalTableIDs} + + default: + err = dbterror.ErrInvalidDDLState.GenWithStackByArgs("partition", job.SchemaState) + } + + return ver, errors.Trace(err) +} + +// newStatsDDLEventForJob creates a statsutil.DDLEvent for a job. +// It is used for reorganize partition, add partitioning and remove partitioning. +func newStatsDDLEventForJob( + schemaID int64, + jobType model.ActionType, + oldTblID int64, + tblInfo *model.TableInfo, + addedPartInfo *model.PartitionInfo, + droppedPartInfo *model.PartitionInfo, +) (*statsutil.DDLEvent, error) { + var event *statsutil.DDLEvent + switch jobType { + case model.ActionReorganizePartition: + event = statsutil.NewReorganizePartitionEvent( + schemaID, + tblInfo, + addedPartInfo, + droppedPartInfo, + ) + case model.ActionAlterTablePartitioning: + event = statsutil.NewAddPartitioningEvent( + schemaID, + oldTblID, + tblInfo, + addedPartInfo, + ) + case model.ActionRemovePartitioning: + event = statsutil.NewRemovePartitioningEvent( + schemaID, + oldTblID, + tblInfo, + droppedPartInfo, + ) + default: + return nil, errors.Errorf("unknown job type: %s", jobType.String()) + } + return event, nil +} + +func doPartitionReorgWork(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job, tbl table.Table, physTblIDs []int64) (done bool, ver int64, err error) { + job.ReorgMeta.ReorgTp = model.ReorgTypeTxn + sctx, err1 := w.sessPool.Get() + if err1 != nil { + return done, ver, err1 + } + defer w.sessPool.Put(sctx) + rh := newReorgHandler(sess.NewSession(sctx)) + indices := make([]*model.IndexInfo, 0, len(tbl.Meta().Indices)) + for _, index := range tbl.Meta().Indices { + if index.Global && index.State == model.StatePublic { + // Skip old global indexes, but rebuild all other indexes + continue + } + indices = append(indices, index) + } + elements := BuildElements(tbl.Meta().Columns[0], indices) + partTbl, ok := tbl.(table.PartitionedTable) + if !ok { + return false, ver, dbterror.ErrUnsupportedReorganizePartition.GenWithStackByArgs() + } + dbInfo, err := t.GetDatabase(job.SchemaID) + if err != nil { + return false, ver, errors.Trace(err) + } + reorgInfo, err := getReorgInfoFromPartitions(d.jobContext(job.ID, job.ReorgMeta), d, rh, job, dbInfo, partTbl, physTblIDs, elements) + err = w.runReorgJob(reorgInfo, tbl.Meta(), d.lease, func() (reorgErr error) { + defer tidbutil.Recover(metrics.LabelDDL, "doPartitionReorgWork", + func() { + reorgErr = 
dbterror.ErrCancelledDDLJob.GenWithStack("reorganize partition for table `%v` panic", tbl.Meta().Name) + }, false) + return w.reorgPartitionDataAndIndex(tbl, reorgInfo) + }) + if err != nil { + if dbterror.ErrPausedDDLJob.Equal(err) { + return false, ver, nil + } + + if dbterror.ErrWaitReorgTimeout.Equal(err) { + // If timeout, we should return, check for the owner and re-wait job done. + return false, ver, nil + } + if kv.IsTxnRetryableError(err) { + return false, ver, errors.Trace(err) + } + if err1 := rh.RemoveDDLReorgHandle(job, reorgInfo.elements); err1 != nil { + logutil.DDLLogger().Warn("reorg partition job failed, RemoveDDLReorgHandle failed, can't convert job to rollback", + zap.Stringer("job", job), zap.Error(err1)) + } + logutil.DDLLogger().Warn("reorg partition job failed, convert job to rollback", zap.Stringer("job", job), zap.Error(err)) + // TODO: rollback new global indexes! TODO: How to handle new index ids? + ver, err = convertAddTablePartitionJob2RollbackJob(d, t, job, err, tbl.Meta()) + return false, ver, errors.Trace(err) + } + return true, ver, err +} + +type reorgPartitionWorker struct { + *backfillCtx + // Static allocated to limit memory allocations + rowRecords []*rowRecord + rowDecoder *decoder.RowDecoder + rowMap map[int64]types.Datum + writeColOffsetMap map[int64]int + maxOffset int + reorgedTbl table.PartitionedTable +} + +func newReorgPartitionWorker(i int, t table.PhysicalTable, decodeColMap map[int64]decoder.Column, reorgInfo *reorgInfo, jc *JobContext) (*reorgPartitionWorker, error) { + bCtx, err := newBackfillCtx(i, reorgInfo, reorgInfo.SchemaName, t, jc, "reorg_partition_rate", false) + if err != nil { + return nil, err + } + reorgedTbl, err := tables.GetReorganizedPartitionedTable(t) + if err != nil { + return nil, errors.Trace(err) + } + pt := t.GetPartitionedTable() + if pt == nil { + return nil, dbterror.ErrUnsupportedReorganizePartition.GenWithStackByArgs() + } + partColIDs := reorgedTbl.GetPartitionColumnIDs() + writeColOffsetMap := make(map[int64]int, len(partColIDs)) + maxOffset := 0 + for _, id := range partColIDs { + var offset int + for _, col := range pt.Cols() { + if col.ID == id { + offset = col.Offset + break + } + } + writeColOffsetMap[id] = offset + maxOffset = mathutil.Max[int](maxOffset, offset) + } + return &reorgPartitionWorker{ + backfillCtx: bCtx, + rowDecoder: decoder.NewRowDecoder(t, t.WritableCols(), decodeColMap), + rowMap: make(map[int64]types.Datum, len(decodeColMap)), + writeColOffsetMap: writeColOffsetMap, + maxOffset: maxOffset, + reorgedTbl: reorgedTbl, + }, nil +} + +func (w *reorgPartitionWorker) BackfillData(handleRange reorgBackfillTask) (taskCtx backfillTaskContext, errInTxn error) { + oprStartTime := time.Now() + ctx := kv.WithInternalSourceAndTaskType(context.Background(), w.jobContext.ddlJobSourceType(), kvutil.ExplicitTypeDDL) + errInTxn = kv.RunInNewTxn(ctx, w.ddlCtx.store, true, func(_ context.Context, txn kv.Transaction) error { + taskCtx.addedCount = 0 + taskCtx.scanCount = 0 + updateTxnEntrySizeLimitIfNeeded(txn) + txn.SetOption(kv.Priority, handleRange.priority) + if tagger := w.GetCtx().getResourceGroupTaggerForTopSQL(handleRange.getJobID()); tagger != nil { + txn.SetOption(kv.ResourceGroupTagger, tagger) + } + txn.SetOption(kv.ResourceGroupName, w.jobContext.resourceGroupName) + + rowRecords, nextKey, taskDone, err := w.fetchRowColVals(txn, handleRange) + if err != nil { + return errors.Trace(err) + } + taskCtx.nextKey = nextKey + taskCtx.done = taskDone + + warningsMap := 
make(map[errors.ErrorID]*terror.Error) + warningsCountMap := make(map[errors.ErrorID]int64) + for _, prr := range rowRecords { + taskCtx.scanCount++ + + err = txn.Set(prr.key, prr.vals) + if err != nil { + return errors.Trace(err) + } + taskCtx.addedCount++ + if prr.warning != nil { + if _, ok := warningsCountMap[prr.warning.ID()]; ok { + warningsCountMap[prr.warning.ID()]++ + } else { + warningsCountMap[prr.warning.ID()] = 1 + warningsMap[prr.warning.ID()] = prr.warning + } + } + // TODO: Future optimization: also write the indexes here? + // What if the transaction limit is just enough for a single row, without index? + // Hmm, how could that be in the first place? + // For now, implement the batch-txn w.addTableIndex, + // since it already exists and is in use + } + + // Collect the warnings. + taskCtx.warnings, taskCtx.warningsCount = warningsMap, warningsCountMap + + // also add the index entries here? And make sure they are not added somewhere else + + return nil + }) + logSlowOperations(time.Since(oprStartTime), "BackfillData", 3000) + + return +} + +func (w *reorgPartitionWorker) fetchRowColVals(txn kv.Transaction, taskRange reorgBackfillTask) ([]*rowRecord, kv.Key, bool, error) { + w.rowRecords = w.rowRecords[:0] + startTime := time.Now() + + // taskDone means that the added handle is out of taskRange.endHandle. + taskDone := false + sysTZ := w.loc + + tmpRow := make([]types.Datum, w.maxOffset+1) + var lastAccessedHandle kv.Key + oprStartTime := startTime + err := iterateSnapshotKeys(w.jobContext, w.ddlCtx.store, taskRange.priority, w.table.RecordPrefix(), txn.StartTS(), taskRange.startKey, taskRange.endKey, + func(handle kv.Handle, recordKey kv.Key, rawRow []byte) (bool, error) { + oprEndTime := time.Now() + logSlowOperations(oprEndTime.Sub(oprStartTime), "iterateSnapshotKeys in reorgPartitionWorker fetchRowColVals", 0) + oprStartTime = oprEndTime + + taskDone = recordKey.Cmp(taskRange.endKey) >= 0 + + if taskDone || len(w.rowRecords) >= w.batchCnt { + return false, nil + } + + _, err := w.rowDecoder.DecodeTheExistedColumnMap(w.exprCtx, handle, rawRow, sysTZ, w.rowMap) + if err != nil { + return false, errors.Trace(err) + } + + // Set the partitioning columns and calculate which partition to write to + for colID, offset := range w.writeColOffsetMap { + d, ok := w.rowMap[colID] + if !ok { + return false, dbterror.ErrUnsupportedReorganizePartition.GenWithStackByArgs() + } + tmpRow[offset] = d + } + p, err := w.reorgedTbl.GetPartitionByRow(w.exprCtx.GetEvalCtx(), tmpRow) + if err != nil { + return false, errors.Trace(err) + } + var newKey kv.Key + if w.reorgedTbl.Meta().PKIsHandle || w.reorgedTbl.Meta().IsCommonHandle { + pid := p.GetPhysicalID() + newKey = tablecodec.EncodeTablePrefix(pid) + newKey = append(newKey, recordKey[len(newKey):]...) + } else { + // Non-clustered table / not unique _tidb_rowid for the whole table + // Generate new _tidb_rowid if exists. + // Due to EXCHANGE PARTITION, the existing _tidb_rowid may collide between partitions! + if reserved, ok := w.tblCtx.GetReservedRowIDAlloc(); ok && reserved.Exhausted() { + // TODO: Which autoid allocator to use? 
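The branch right below refills a pre-reserved block of _tidb_rowid values when the current block runs out, instead of calling the allocator once per copied row. A minimal sketch of that reserve-and-refill pattern: Exhausted and Reset mirror the method names used here, while the struct and the refill callback are invented for illustration and are not TiDB's allocator.

package main

import "fmt"

// reservedRange hands out IDs from a pre-reserved [next, max] window and asks
// for a new block only when the window is used up.
type reservedRange struct {
	next, max int64
}

func (r *reservedRange) Exhausted() bool       { return r.next > r.max }
func (r *reservedRange) Reset(base, max int64) { r.next, r.max = base, max }

func (r *reservedRange) alloc(refill func(n int64) (base, max int64)) int64 {
	if r.Exhausted() {
		r.Reset(refill(100)) // one batched round trip instead of one per row
	}
	id := r.next
	r.next++
	return id
}

func main() {
	nextBase := int64(1)
	refill := func(n int64) (int64, int64) { // stands in for a batched allocator call
		base := nextBase
		nextBase += n
		return base, base + n - 1
	}
	r := &reservedRange{next: 1, max: 0} // starts exhausted
	for i := 0; i < 3; i++ {
		fmt.Println("row id:", r.alloc(refill))
	}
}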
+ ids := uint64(max(1, w.batchCnt-len(w.rowRecords))) + // Keep using the original table's allocator + var baseRowID, maxRowID int64 + baseRowID, maxRowID, err = tables.AllocHandleIDs(w.ctx, w.tblCtx, w.reorgedTbl, ids) + if err != nil { + return false, errors.Trace(err) + } + reserved.Reset(baseRowID, maxRowID) + } + recordID, err := tables.AllocHandle(w.ctx, w.tblCtx, w.reorgedTbl) + if err != nil { + return false, errors.Trace(err) + } + newKey = tablecodec.EncodeRecordKey(p.RecordPrefix(), recordID) + } + w.rowRecords = append(w.rowRecords, &rowRecord{ + key: newKey, vals: rawRow, + }) + + w.cleanRowMap() + lastAccessedHandle = recordKey + if recordKey.Cmp(taskRange.endKey) == 0 { + taskDone = true + return false, nil + } + return true, nil + }) + + if len(w.rowRecords) == 0 { + taskDone = true + } + + logutil.DDLLogger().Debug("txn fetches handle info", + zap.Uint64("txnStartTS", txn.StartTS()), + zap.Stringer("taskRange", &taskRange), + zap.Duration("takeTime", time.Since(startTime))) + return w.rowRecords, getNextHandleKey(taskRange, taskDone, lastAccessedHandle), taskDone, errors.Trace(err) +} + +func (w *reorgPartitionWorker) cleanRowMap() { + for id := range w.rowMap { + delete(w.rowMap, id) + } +} + +func (w *reorgPartitionWorker) AddMetricInfo(cnt float64) { + w.metricCounter.Add(cnt) +} + +func (*reorgPartitionWorker) String() string { + return typeReorgPartitionWorker.String() +} + +func (w *reorgPartitionWorker) GetCtx() *backfillCtx { + return w.backfillCtx +} + +func (w *worker) reorgPartitionDataAndIndex(t table.Table, reorgInfo *reorgInfo) (err error) { + // First copy all table data to the new AddingDefinitions partitions + // from each of the DroppingDefinitions partitions. + // Then create all indexes on the AddingDefinitions partitions, + // both new local and new global indexes + // And last update new global indexes from the non-touched partitions + // Note it is hard to update global indexes in-place due to: + // - Transactions on different TiDB nodes/domains may see different states of the table/partitions + // - We cannot have multiple partition ids for a unique index entry. + + // Copy the data from the DroppingDefinitions to the AddingDefinitions + if bytes.Equal(reorgInfo.currElement.TypeKey, meta.ColumnElementKey) { + err = w.updatePhysicalTableRow(t, reorgInfo) + if err != nil { + return errors.Trace(err) + } + if len(reorgInfo.elements) <= 1 { + // No indexes to (re)create, all done! + return nil + } + } + + failpoint.Inject("reorgPartitionAfterDataCopy", func(val failpoint.Value) { + //nolint:forcetypeassert + if val.(bool) { + panic("panic test in reorgPartitionAfterDataCopy") + } + }) + + if !bytes.Equal(reorgInfo.currElement.TypeKey, meta.IndexElementKey) { + // row data has been copied, now proceed with creating the indexes + // on the new AddingDefinitions partitions + reorgInfo.PhysicalTableID = t.Meta().Partition.AddingDefinitions[0].ID + reorgInfo.currElement = reorgInfo.elements[1] + var physTbl table.PhysicalTable + if tbl, ok := t.(table.PartitionedTable); ok { + physTbl = tbl.GetPartition(reorgInfo.PhysicalTableID) + } else if tbl, ok := t.(table.PhysicalTable); ok { + // This may be used when partitioning a non-partitioned table + physTbl = tbl + } + // Get the original start handle and end handle. 
+ currentVer, err := getValidCurrentVersion(reorgInfo.d.store) + if err != nil { + return errors.Trace(err) + } + startHandle, endHandle, err := getTableRange(reorgInfo.NewJobContext(), reorgInfo.d, physTbl, currentVer.Ver, reorgInfo.Job.Priority) + if err != nil { + return errors.Trace(err) + } + + // Always (re)start with the full PhysicalTable range + reorgInfo.StartKey, reorgInfo.EndKey = startHandle, endHandle + + // Write the reorg info to store so the whole reorganize process can recover from panic. + err = reorgInfo.UpdateReorgMeta(reorgInfo.StartKey, w.sessPool) + logutil.DDLLogger().Info("update column and indexes", + zap.Int64("jobID", reorgInfo.Job.ID), + zap.ByteString("elementType", reorgInfo.currElement.TypeKey), + zap.Int64("elementID", reorgInfo.currElement.ID), + zap.Int64("partitionTableId", physTbl.GetPhysicalID()), + zap.String("startHandle", hex.EncodeToString(reorgInfo.StartKey)), + zap.String("endHandle", hex.EncodeToString(reorgInfo.EndKey))) + if err != nil { + return errors.Trace(err) + } + } + + pi := t.Meta().GetPartitionInfo() + if _, err = findNextPartitionID(reorgInfo.PhysicalTableID, pi.AddingDefinitions); err == nil { + // Now build all the indexes in the new partitions + err = w.addTableIndex(t, reorgInfo) + if err != nil { + return errors.Trace(err) + } + // All indexes are up-to-date for new partitions, + // now we only need to add the existing non-touched partitions + // to the global indexes + reorgInfo.elements = reorgInfo.elements[:0] + for _, indexInfo := range t.Meta().Indices { + if indexInfo.Global && indexInfo.State == model.StateWriteReorganization { + reorgInfo.elements = append(reorgInfo.elements, &meta.Element{ID: indexInfo.ID, TypeKey: meta.IndexElementKey}) + } + } + if len(reorgInfo.elements) == 0 { + // No global indexes + return nil + } + reorgInfo.currElement = reorgInfo.elements[0] + pid := pi.Definitions[0].ID + if _, err = findNextPartitionID(pid, pi.DroppingDefinitions); err == nil { + // Skip all dropped partitions + pid, err = findNextNonTouchedPartitionID(pid, pi) + if err != nil { + return errors.Trace(err) + } + } + if pid == 0 { + // All partitions will be dropped, nothing more to add to global indexes. + return nil + } + reorgInfo.PhysicalTableID = pid + var physTbl table.PhysicalTable + if tbl, ok := t.(table.PartitionedTable); ok { + physTbl = tbl.GetPartition(reorgInfo.PhysicalTableID) + } else if tbl, ok := t.(table.PhysicalTable); ok { + // This may be used when partitioning a non-partitioned table + physTbl = tbl + } + // Get the original start handle and end handle. + currentVer, err := getValidCurrentVersion(reorgInfo.d.store) + if err != nil { + return errors.Trace(err) + } + startHandle, endHandle, err := getTableRange(reorgInfo.NewJobContext(), reorgInfo.d, physTbl, currentVer.Ver, reorgInfo.Job.Priority) + if err != nil { + return errors.Trace(err) + } + + // Always (re)start with the full PhysicalTable range + reorgInfo.StartKey, reorgInfo.EndKey = startHandle, endHandle + + // Write the reorg info to store so the whole reorganize process can recover from panic. 
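Earlier in this branch the code picks the first partition that is neither being dropped nor newly added, so that only the untouched partitions are backfilled into the new global indexes, and a result of 0 means there is nothing left to scan. A small stand-alone restatement of that selection logic, with an invented helper name:

package main

import "fmt"

// firstUntouched returns the first partition ID that is not being replaced,
// or 0 when every partition is part of the reorganization.
func firstUntouched(all []int64, dropping map[int64]bool) int64 {
	for _, id := range all {
		if !dropping[id] {
			return id
		}
	}
	return 0 // nothing untouched: the new global indexes need no extra backfill
}

func main() {
	all := []int64{101, 102, 103}
	fmt.Println(firstUntouched(all, map[int64]bool{101: true, 102: true}))            // 103
	fmt.Println(firstUntouched(all, map[int64]bool{101: true, 102: true, 103: true})) // 0
}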
+ err = reorgInfo.UpdateReorgMeta(reorgInfo.StartKey, w.sessPool) + logutil.DDLLogger().Info("update column and indexes", + zap.Int64("jobID", reorgInfo.Job.ID), + zap.ByteString("elementType", reorgInfo.currElement.TypeKey), + zap.Int64("elementID", reorgInfo.currElement.ID), + zap.Int64("partitionTableId", physTbl.GetPhysicalID()), + zap.String("startHandle", hex.EncodeToString(reorgInfo.StartKey)), + zap.String("endHandle", hex.EncodeToString(reorgInfo.EndKey))) + if err != nil { + return errors.Trace(err) + } + } + return w.addTableIndex(t, reorgInfo) +} + +func bundlesForExchangeTablePartition(t *meta.Meta, pt *model.TableInfo, newPar *model.PartitionDefinition, nt *model.TableInfo) ([]*placement.Bundle, error) { + bundles := make([]*placement.Bundle, 0, 3) + + ptBundle, err := placement.NewTableBundle(t, pt) + if err != nil { + return nil, errors.Trace(err) + } + if ptBundle != nil { + bundles = append(bundles, ptBundle) + } + + parBundle, err := placement.NewPartitionBundle(t, *newPar) + if err != nil { + return nil, errors.Trace(err) + } + if parBundle != nil { + bundles = append(bundles, parBundle) + } + + ntBundle, err := placement.NewTableBundle(t, nt) + if err != nil { + return nil, errors.Trace(err) + } + if ntBundle != nil { + bundles = append(bundles, ntBundle) + } + + if parBundle == nil && ntBundle != nil { + // newPar.ID is the ID of old table to exchange, so ntBundle != nil means it has some old placement settings. + // We should remove it in this situation + bundles = append(bundles, placement.NewBundle(newPar.ID)) + } + + if parBundle != nil && ntBundle == nil { + // nt.ID is the ID of old partition to exchange, so parBundle != nil means it has some old placement settings. + // We should remove it in this situation + bundles = append(bundles, placement.NewBundle(nt.ID)) + } + + return bundles, nil +} + +func checkExchangePartitionRecordValidation(w *worker, ptbl, ntbl table.Table, pschemaName, nschemaName, partitionName string) error { + verifyFunc := func(sql string, params ...any) error { + var ctx sessionctx.Context + ctx, err := w.sessPool.Get() + if err != nil { + return errors.Trace(err) + } + defer w.sessPool.Put(ctx) + + rows, _, err := ctx.GetRestrictedSQLExecutor().ExecRestrictedSQL(w.ctx, nil, sql, params...) + if err != nil { + return errors.Trace(err) + } + rowCount := len(rows) + if rowCount != 0 { + return errors.Trace(dbterror.ErrRowDoesNotMatchPartition) + } + // Check warnings! + // Is it possible to check how many rows where checked as well? + return nil + } + genConstraintCondition := func(constraints []*table.Constraint) string { + var buf strings.Builder + buf.WriteString("not (") + for i, cons := range constraints { + if i != 0 { + buf.WriteString(" and ") + } + buf.WriteString(fmt.Sprintf("(%s)", cons.ExprString)) + } + buf.WriteString(")") + return buf.String() + } + type CheckConstraintTable interface { + WritableConstraint() []*table.Constraint + } + + pt := ptbl.Meta() + index, _, err := getPartitionDef(pt, partitionName) + if err != nil { + return errors.Trace(err) + } + + var buf strings.Builder + buf.WriteString("select 1 from %n.%n where ") + paramList := []any{nschemaName, ntbl.Meta().Name.L} + checkNt := true + + pi := pt.Partition + switch pi.Type { + case model.PartitionTypeHash: + if pi.Num == 1 { + checkNt = false + } else { + buf.WriteString("mod(") + buf.WriteString(pi.Expr) + buf.WriteString(", %?) 
!= %?") + paramList = append(paramList, pi.Num, index) + if index != 0 { + // TODO: if hash result can't be NULL, we can remove the check part. + // For example hash(id), but id is defined not NULL. + buf.WriteString(" or mod(") + buf.WriteString(pi.Expr) + buf.WriteString(", %?) is null") + paramList = append(paramList, pi.Num, index) + } + } + case model.PartitionTypeRange: + // Table has only one partition and has the maximum value + if len(pi.Definitions) == 1 && strings.EqualFold(pi.Definitions[index].LessThan[0], partitionMaxValue) { + checkNt = false + } else { + // For range expression and range columns + if len(pi.Columns) == 0 { + conds, params := buildCheckSQLConditionForRangeExprPartition(pi, index) + buf.WriteString(conds) + paramList = append(paramList, params...) + } else { + conds, params := buildCheckSQLConditionForRangeColumnsPartition(pi, index) + buf.WriteString(conds) + paramList = append(paramList, params...) + } + } + case model.PartitionTypeList: + if len(pi.Columns) == 0 { + conds := buildCheckSQLConditionForListPartition(pi, index) + buf.WriteString(conds) + } else { + conds := buildCheckSQLConditionForListColumnsPartition(pi, index) + buf.WriteString(conds) + } + default: + return dbterror.ErrUnsupportedPartitionType.GenWithStackByArgs(pt.Name.O) + } + + if variable.EnableCheckConstraint.Load() { + pcc, ok := ptbl.(CheckConstraintTable) + if !ok { + return errors.Errorf("exchange partition process assert table partition failed") + } + pCons := pcc.WritableConstraint() + if len(pCons) > 0 { + if !checkNt { + checkNt = true + } else { + buf.WriteString(" or ") + } + buf.WriteString(genConstraintCondition(pCons)) + } + } + // Check non-partition table records. + if checkNt { + buf.WriteString(" limit 1") + err = verifyFunc(buf.String(), paramList...) + if err != nil { + return errors.Trace(err) + } + } + + // Check partition table records. + if variable.EnableCheckConstraint.Load() { + ncc, ok := ntbl.(CheckConstraintTable) + if !ok { + return errors.Errorf("exchange partition process assert table partition failed") + } + nCons := ncc.WritableConstraint() + if len(nCons) > 0 { + buf.Reset() + buf.WriteString("select 1 from %n.%n partition(%n) where ") + buf.WriteString(genConstraintCondition(nCons)) + buf.WriteString(" limit 1") + err = verifyFunc(buf.String(), pschemaName, pt.Name.L, partitionName) + if err != nil { + return errors.Trace(err) + } + } + } + return nil +} + +func checkExchangePartitionPlacementPolicy(t *meta.Meta, ntPPRef, ptPPRef, partPPRef *model.PolicyRefInfo) error { + partitionPPRef := partPPRef + if partitionPPRef == nil { + partitionPPRef = ptPPRef + } + + if ntPPRef == nil && partitionPPRef == nil { + return nil + } + if ntPPRef == nil || partitionPPRef == nil { + return dbterror.ErrTablesDifferentMetadata + } + + ptPlacementPolicyInfo, _ := getPolicyInfo(t, partitionPPRef.ID) + ntPlacementPolicyInfo, _ := getPolicyInfo(t, ntPPRef.ID) + if ntPlacementPolicyInfo == nil && ptPlacementPolicyInfo == nil { + return nil + } + if ntPlacementPolicyInfo == nil || ptPlacementPolicyInfo == nil { + return dbterror.ErrTablesDifferentMetadata + } + if ntPlacementPolicyInfo.Name.L != ptPlacementPolicyInfo.Name.L { + return dbterror.ErrTablesDifferentMetadata + } + + return nil +} + +func buildCheckSQLConditionForRangeExprPartition(pi *model.PartitionInfo, index int) (string, []any) { + var buf strings.Builder + paramList := make([]any, 0, 2) + // Since the pi.Expr string may contain the identifier, which couldn't be escaped in our ParseWithParams(...) 
+ // So we write it to the origin sql string here. + if index == 0 { + buf.WriteString(pi.Expr) + buf.WriteString(" >= %?") + paramList = append(paramList, driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0])) + } else if index == len(pi.Definitions)-1 && strings.EqualFold(pi.Definitions[index].LessThan[0], partitionMaxValue) { + buf.WriteString(pi.Expr) + buf.WriteString(" < %? or ") + buf.WriteString(pi.Expr) + buf.WriteString(" is null") + paramList = append(paramList, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0])) + } else { + buf.WriteString(pi.Expr) + buf.WriteString(" < %? or ") + buf.WriteString(pi.Expr) + buf.WriteString(" >= %? or ") + buf.WriteString(pi.Expr) + buf.WriteString(" is null") + paramList = append(paramList, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0]), driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0])) + } + return buf.String(), paramList +} + +func buildCheckSQLConditionForRangeColumnsPartition(pi *model.PartitionInfo, index int) (string, []any) { + paramList := make([]any, 0, 2) + colName := pi.Columns[0].L + if index == 0 { + paramList = append(paramList, colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0])) + return "%n >= %?", paramList + } else if index == len(pi.Definitions)-1 && strings.EqualFold(pi.Definitions[index].LessThan[0], partitionMaxValue) { + paramList = append(paramList, colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0])) + return "%n < %?", paramList + } + paramList = append(paramList, colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0]), colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0])) + return "%n < %? or %n >= %?", paramList +} + +func buildCheckSQLConditionForListPartition(pi *model.PartitionInfo, index int) string { + var buf strings.Builder + buf.WriteString("not (") + for i, inValue := range pi.Definitions[index].InValues { + if i != 0 { + buf.WriteString(" OR ") + } + // AND has higher priority than OR, so no need for parentheses + for j, val := range inValue { + if j != 0 { + // Should never happen, since there should be no multi-columns, only a single expression :) + buf.WriteString(" AND ") + } + // null-safe compare '<=>' + buf.WriteString(fmt.Sprintf("(%s) <=> %s", pi.Expr, val)) + } + } + buf.WriteString(")") + return buf.String() +} + +func buildCheckSQLConditionForListColumnsPartition(pi *model.PartitionInfo, index int) string { + var buf strings.Builder + // How to find a match? + // (row <=> vals1) OR (row <=> vals2) + // How to find a non-matching row: + // NOT ( (row <=> vals1) OR (row <=> vals2) ... ) + buf.WriteString("not (") + colNames := make([]string, 0, len(pi.Columns)) + for i := range pi.Columns { + // TODO: check if there are no proper quoting function for this? 
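The condition builders above all produce a negated match predicate: if a single row of the non-partitioned table satisfies NOT(row belongs to the target partition), the EXCHANGE PARTITION is rejected. Below is a minimal sketch of the list-partition shape using only the standard library; <=> is MySQL's null-safe equality, and buildListCheck is an invented name, not the function used in this file.

package main

import (
	"fmt"
	"strings"
)

// buildListCheck builds "not ((expr) <=> v1 OR (expr) <=> v2 ...)"; any row
// matching it does not belong to the partition being exchanged, so finding one
// such row is enough to reject the exchange.
func buildListCheck(expr string, values []string) string {
	var b strings.Builder
	b.WriteString("not (")
	for i, v := range values {
		if i > 0 {
			b.WriteString(" OR ")
		}
		fmt.Fprintf(&b, "(%s) <=> %s", expr, v)
	}
	b.WriteString(")")
	return b.String()
}

func main() {
	fmt.Println("select 1 from db.t where " + buildListCheck("a", []string{"1", "2", "NULL"}) + " limit 1")
}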
+ n := "`" + strings.ReplaceAll(pi.Columns[i].O, "`", "``") + "`" + colNames = append(colNames, n) + } + for i, colValues := range pi.Definitions[index].InValues { + if i != 0 { + buf.WriteString(" OR ") + } + // AND has higher priority than OR, so no need for parentheses + for j, val := range colValues { + if j != 0 { + buf.WriteString(" AND ") + } + // null-safe compare '<=>' + buf.WriteString(fmt.Sprintf("%s <=> %s", colNames[j], val)) + } + } + buf.WriteString(")") + return buf.String() +} + +func checkAddPartitionTooManyPartitions(piDefs uint64) error { + if piDefs > uint64(mysql.PartitionCountLimit) { + return errors.Trace(dbterror.ErrTooManyPartitions) + } + return nil +} + +func checkAddPartitionOnTemporaryMode(tbInfo *model.TableInfo) error { + if tbInfo.Partition != nil && tbInfo.TempTableType != model.TempTableNone { + return dbterror.ErrPartitionNoTemporary + } + return nil +} + +func checkPartitionColumnsUnique(tbInfo *model.TableInfo) error { + if len(tbInfo.Partition.Columns) <= 1 { + return nil + } + var columnsMap = make(map[string]struct{}) + for _, col := range tbInfo.Partition.Columns { + if _, ok := columnsMap[col.L]; ok { + return dbterror.ErrSameNamePartitionField.GenWithStackByArgs(col.L) + } + columnsMap[col.L] = struct{}{} + } + return nil +} + +func checkNoHashPartitions(_ sessionctx.Context, partitionNum uint64) error { + if partitionNum == 0 { + return ast.ErrNoParts.GenWithStackByArgs("partitions") + } + return nil +} + +func getPartitionIDs(table *model.TableInfo) []int64 { + if table.GetPartitionInfo() == nil { + return []int64{} + } + physicalTableIDs := make([]int64, 0, len(table.Partition.Definitions)) + for _, def := range table.Partition.Definitions { + physicalTableIDs = append(physicalTableIDs, def.ID) + } + return physicalTableIDs +} + +func getPartitionRuleIDs(dbName string, table *model.TableInfo) []string { + if table.GetPartitionInfo() == nil { + return []string{} + } + partRuleIDs := make([]string, 0, len(table.Partition.Definitions)) + for _, def := range table.Partition.Definitions { + partRuleIDs = append(partRuleIDs, fmt.Sprintf(label.PartitionIDFormat, label.IDPrefix, dbName, table.Name.L, def.Name.L)) + } + return partRuleIDs +} + +// checkPartitioningKeysConstraints checks that the range partitioning key is included in the table constraint. +func checkPartitioningKeysConstraints(sctx sessionctx.Context, s *ast.CreateTableStmt, tblInfo *model.TableInfo) error { + // Returns directly if there are no unique keys in the table. + if len(tblInfo.Indices) == 0 && !tblInfo.PKIsHandle { + return nil + } + + partCols, err := getPartitionColSlices(sctx.GetExprCtx(), tblInfo, s.Partition) + if err != nil { + return errors.Trace(err) + } + + // Checks that the partitioning key is included in the constraint. + // Every unique key on the table must use every column in the table's partitioning expression. 
+ // See https://dev.mysql.com/doc/refman/5.7/en/partitioning-limitations-partitioning-keys-unique-keys.html + for _, index := range tblInfo.Indices { + if index.Unique && !checkUniqueKeyIncludePartKey(partCols, index.Columns) { + if index.Primary { + // global index does not support clustered index + if tblInfo.IsCommonHandle { + return dbterror.ErrUniqueKeyNeedAllFieldsInPf.GenWithStackByArgs("CLUSTERED INDEX") + } + if !sctx.GetSessionVars().EnableGlobalIndex { + return dbterror.ErrUniqueKeyNeedAllFieldsInPf.GenWithStackByArgs("PRIMARY KEY") + } + } + if !sctx.GetSessionVars().EnableGlobalIndex { + return dbterror.ErrUniqueKeyNeedAllFieldsInPf.GenWithStackByArgs("UNIQUE INDEX") + } + } + } + // when PKIsHandle, tblInfo.Indices will not contain the primary key. + if tblInfo.PKIsHandle { + indexCols := []*model.IndexColumn{{ + Name: tblInfo.GetPkName(), + Length: types.UnspecifiedLength, + }} + if !checkUniqueKeyIncludePartKey(partCols, indexCols) { + return dbterror.ErrUniqueKeyNeedAllFieldsInPf.GenWithStackByArgs("CLUSTERED INDEX") + } + } + return nil +} + +func checkPartitionKeysConstraint(pi *model.PartitionInfo, indexColumns []*model.IndexColumn, tblInfo *model.TableInfo) (bool, error) { + var ( + partCols []*model.ColumnInfo + err error + ) + if pi.Type == model.PartitionTypeNone { + return true, nil + } + // The expr will be an empty string if the partition is defined by: + // CREATE TABLE t (...) PARTITION BY RANGE COLUMNS(...) + if partExpr := pi.Expr; partExpr != "" { + // Parse partitioning key, extract the column names in the partitioning key to slice. + partCols, err = extractPartitionColumns(partExpr, tblInfo) + if err != nil { + return false, err + } + } else { + partCols = make([]*model.ColumnInfo, 0, len(pi.Columns)) + for _, col := range pi.Columns { + colInfo := tblInfo.FindPublicColumnByName(col.L) + if colInfo == nil { + return false, infoschema.ErrColumnNotExists.GenWithStackByArgs(col, tblInfo.Name) + } + partCols = append(partCols, colInfo) + } + } + + // In MySQL, every unique key on the table must use every column in the table's partitioning expression.(This + // also includes the table's primary key.) + // In TiDB, global index will be built when this constraint is not satisfied and EnableGlobalIndex is set. 
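checkUniqueKeyIncludePartKey, defined a little further down, is what decides this: a unique key can stay partition-local only if every partitioning column appears in it (and, in the real check, without a prefix length); otherwise TiDB either builds a global index or rejects the DDL. A plain-string restatement of the containment part of that check, with invented names:

package main

import "fmt"

// coversPartitionKey reports whether every partitioning column appears in the
// index columns. Prefix lengths, which the real check also rejects, are not
// modeled here.
func coversPartitionKey(partCols, indexCols []string) bool {
	inIndex := make(map[string]bool, len(indexCols))
	for _, c := range indexCols {
		inIndex[c] = true
	}
	for _, c := range partCols {
		if !inIndex[c] {
			return false
		}
	}
	return true
}

func main() {
	partCols := []string{"a"}
	fmt.Println(coversPartitionKey(partCols, []string{"a", "b"})) // true: local unique index is fine
	fmt.Println(coversPartitionKey(partCols, []string{"b"}))      // false: needs a global index (or is rejected)
}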
+ // See https://dev.mysql.com/doc/refman/5.7/en/partitioning-limitations-partitioning-keys-unique-keys.html + return checkUniqueKeyIncludePartKey(columnInfoSlice(partCols), indexColumns), nil +} + +type columnNameExtractor struct { + extractedColumns []*model.ColumnInfo + tblInfo *model.TableInfo + err error +} + +func (*columnNameExtractor) Enter(node ast.Node) (ast.Node, bool) { + return node, false +} + +func (cne *columnNameExtractor) Leave(node ast.Node) (ast.Node, bool) { + if c, ok := node.(*ast.ColumnNameExpr); ok { + info := findColumnByName(c.Name.Name.L, cne.tblInfo) + if info != nil { + cne.extractedColumns = append(cne.extractedColumns, info) + return node, true + } + cne.err = dbterror.ErrBadField.GenWithStackByArgs(c.Name.Name.O, "expression") + return nil, false + } + return node, true +} + +func findColumnByName(colName string, tblInfo *model.TableInfo) *model.ColumnInfo { + if tblInfo == nil { + return nil + } + for _, info := range tblInfo.Columns { + if info.Name.L == colName { + return info + } + } + return nil +} + +func extractPartitionColumns(partExpr string, tblInfo *model.TableInfo) ([]*model.ColumnInfo, error) { + partExpr = "select " + partExpr + stmts, _, err := parser.New().ParseSQL(partExpr) + if err != nil { + return nil, errors.Trace(err) + } + extractor := &columnNameExtractor{ + tblInfo: tblInfo, + extractedColumns: make([]*model.ColumnInfo, 0), + } + stmts[0].Accept(extractor) + if extractor.err != nil { + return nil, errors.Trace(extractor.err) + } + return extractor.extractedColumns, nil +} + +// stringSlice is defined for checkUniqueKeyIncludePartKey. +// if Go supports covariance, the code shouldn't be so complex. +type stringSlice interface { + Len() int + At(i int) string +} + +// checkUniqueKeyIncludePartKey checks that the partitioning key is included in the constraint. +func checkUniqueKeyIncludePartKey(partCols stringSlice, idxCols []*model.IndexColumn) bool { + for i := 0; i < partCols.Len(); i++ { + partCol := partCols.At(i) + _, idxCol := model.FindIndexColumnByName(idxCols, partCol) + if idxCol == nil { + // Partition column is not found in the index columns. + return false + } + if idxCol.Length > 0 { + // The partition column is found in the index columns, but the index column is a prefix index + return false + } + } + return true +} + +// columnInfoSlice implements the stringSlice interface. +type columnInfoSlice []*model.ColumnInfo + +func (cis columnInfoSlice) Len() int { + return len(cis) +} + +func (cis columnInfoSlice) At(i int) string { + return cis[i].Name.L +} + +// columnNameSlice implements the stringSlice interface. +type columnNameSlice []*ast.ColumnName + +func (cns columnNameSlice) Len() int { + return len(cns) +} + +func (cns columnNameSlice) At(i int) string { + return cns[i].Name.L +} + +func isPartExprUnsigned(ectx expression.EvalContext, tbInfo *model.TableInfo) bool { + ctx := tables.NewPartitionExprBuildCtx() + expr, err := expression.ParseSimpleExpr(ctx, tbInfo.Partition.Expr, expression.WithTableInfo("", tbInfo)) + if err != nil { + logutil.DDLLogger().Error("isPartExpr failed parsing expression!", zap.Error(err)) + return false + } + if mysql.HasUnsignedFlag(expr.GetType(ectx).GetFlag()) { + return true + } + return false +} + +// truncateTableByReassignPartitionIDs reassigns new partition ids. 
+func truncateTableByReassignPartitionIDs(t *meta.Meta, tblInfo *model.TableInfo, pids []int64) (err error) { + if len(pids) < len(tblInfo.Partition.Definitions) { + // To make it compatible with older versions when pids was not given + // and if there has been any add/reorganize partition increasing the number of partitions + morePids, err := t.GenGlobalIDs(len(tblInfo.Partition.Definitions) - len(pids)) + if err != nil { + return errors.Trace(err) + } + pids = append(pids, morePids...) + } + newDefs := make([]model.PartitionDefinition, 0, len(tblInfo.Partition.Definitions)) + for i, def := range tblInfo.Partition.Definitions { + newDef := def + newDef.ID = pids[i] + newDefs = append(newDefs, newDef) + } + tblInfo.Partition.Definitions = newDefs + return nil +} + +type partitionExprProcessor func(expression.BuildContext, *model.TableInfo, ast.ExprNode) error + +type partitionExprChecker struct { + processors []partitionExprProcessor + ctx expression.BuildContext + tbInfo *model.TableInfo + err error + + columns []*model.ColumnInfo +} + +func newPartitionExprChecker(ctx expression.BuildContext, tbInfo *model.TableInfo, processor ...partitionExprProcessor) *partitionExprChecker { + p := &partitionExprChecker{processors: processor, ctx: ctx, tbInfo: tbInfo} + p.processors = append(p.processors, p.extractColumns) + return p +} + +func (p *partitionExprChecker) Enter(n ast.Node) (node ast.Node, skipChildren bool) { + expr, ok := n.(ast.ExprNode) + if !ok { + return n, true + } + for _, processor := range p.processors { + if err := processor(p.ctx, p.tbInfo, expr); err != nil { + p.err = err + return n, true + } + } + + return n, false +} + +func (p *partitionExprChecker) Leave(n ast.Node) (node ast.Node, ok bool) { + return n, p.err == nil +} + +func (p *partitionExprChecker) extractColumns(_ expression.BuildContext, _ *model.TableInfo, expr ast.ExprNode) error { + columnNameExpr, ok := expr.(*ast.ColumnNameExpr) + if !ok { + return nil + } + colInfo := findColumnByName(columnNameExpr.Name.Name.L, p.tbInfo) + if colInfo == nil { + return errors.Trace(dbterror.ErrBadField.GenWithStackByArgs(columnNameExpr.Name.Name.L, "partition function")) + } + + p.columns = append(p.columns, colInfo) + return nil +} + +func checkPartitionExprAllowed(_ expression.BuildContext, tb *model.TableInfo, e ast.ExprNode) error { + switch v := e.(type) { + case *ast.FuncCallExpr: + if _, ok := expression.AllowedPartitionFuncMap[v.FnName.L]; ok { + return nil + } + case *ast.BinaryOperationExpr: + if _, ok := expression.AllowedPartition4BinaryOpMap[v.Op]; ok { + return errors.Trace(checkNoTimestampArgs(tb, v.L, v.R)) + } + case *ast.UnaryOperationExpr: + if _, ok := expression.AllowedPartition4UnaryOpMap[v.Op]; ok { + return errors.Trace(checkNoTimestampArgs(tb, v.V)) + } + case *ast.ColumnNameExpr, *ast.ParenthesesExpr, *driver.ValueExpr, *ast.MaxValueExpr, + *ast.DefaultExpr, *ast.TimeUnitExpr: + return nil + } + return errors.Trace(dbterror.ErrPartitionFunctionIsNotAllowed) +} + +func checkPartitionExprArgs(_ expression.BuildContext, tblInfo *model.TableInfo, e ast.ExprNode) error { + expr, ok := e.(*ast.FuncCallExpr) + if !ok { + return nil + } + argsType, err := collectArgsType(tblInfo, expr.Args...) 
+ if err != nil { + return errors.Trace(err) + } + switch expr.FnName.L { + case ast.ToDays, ast.ToSeconds, ast.DayOfMonth, ast.Month, ast.DayOfYear, ast.Quarter, ast.YearWeek, + ast.Year, ast.Weekday, ast.DayOfWeek, ast.Day: + return errors.Trace(checkResultOK(hasDateArgs(argsType...))) + case ast.Hour, ast.Minute, ast.Second, ast.TimeToSec, ast.MicroSecond: + return errors.Trace(checkResultOK(hasTimeArgs(argsType...))) + case ast.UnixTimestamp: + return errors.Trace(checkResultOK(hasTimestampArgs(argsType...))) + case ast.FromDays: + return errors.Trace(checkResultOK(hasDateArgs(argsType...) || hasTimeArgs(argsType...))) + case ast.Extract: + switch expr.Args[0].(*ast.TimeUnitExpr).Unit { + case ast.TimeUnitYear, ast.TimeUnitYearMonth, ast.TimeUnitQuarter, ast.TimeUnitMonth, ast.TimeUnitDay: + return errors.Trace(checkResultOK(hasDateArgs(argsType...))) + case ast.TimeUnitDayMicrosecond, ast.TimeUnitDayHour, ast.TimeUnitDayMinute, ast.TimeUnitDaySecond: + return errors.Trace(checkResultOK(hasDatetimeArgs(argsType...))) + case ast.TimeUnitHour, ast.TimeUnitHourMinute, ast.TimeUnitHourSecond, ast.TimeUnitMinute, ast.TimeUnitMinuteSecond, + ast.TimeUnitSecond, ast.TimeUnitMicrosecond, ast.TimeUnitHourMicrosecond, ast.TimeUnitMinuteMicrosecond, ast.TimeUnitSecondMicrosecond: + return errors.Trace(checkResultOK(hasTimeArgs(argsType...))) + default: + return errors.Trace(dbterror.ErrWrongExprInPartitionFunc) + } + case ast.DateDiff: + return errors.Trace(checkResultOK(slice.AllOf(argsType, func(i int) bool { + return hasDateArgs(argsType[i]) + }))) + + case ast.Abs, ast.Ceiling, ast.Floor, ast.Mod: + has := hasTimestampArgs(argsType...) + if has { + return errors.Trace(dbterror.ErrWrongExprInPartitionFunc) + } + } + return nil +} + +func collectArgsType(tblInfo *model.TableInfo, exprs ...ast.ExprNode) ([]byte, error) { + ts := make([]byte, 0, len(exprs)) + for _, arg := range exprs { + col, ok := arg.(*ast.ColumnNameExpr) + if !ok { + continue + } + columnInfo := findColumnByName(col.Name.Name.L, tblInfo) + if columnInfo == nil { + return nil, errors.Trace(dbterror.ErrBadField.GenWithStackByArgs(col.Name.Name.L, "partition function")) + } + ts = append(ts, columnInfo.GetType()) + } + + return ts, nil +} + +func hasDateArgs(argsType ...byte) bool { + return slice.AnyOf(argsType, func(i int) bool { + return argsType[i] == mysql.TypeDate || argsType[i] == mysql.TypeDatetime + }) +} + +func hasTimeArgs(argsType ...byte) bool { + return slice.AnyOf(argsType, func(i int) bool { + return argsType[i] == mysql.TypeDuration || argsType[i] == mysql.TypeDatetime + }) +} + +func hasTimestampArgs(argsType ...byte) bool { + return slice.AnyOf(argsType, func(i int) bool { + return argsType[i] == mysql.TypeTimestamp + }) +} + +func hasDatetimeArgs(argsType ...byte) bool { + return slice.AnyOf(argsType, func(i int) bool { + return argsType[i] == mysql.TypeDatetime + }) +} + +func checkNoTimestampArgs(tbInfo *model.TableInfo, exprs ...ast.ExprNode) error { + argsType, err := collectArgsType(tbInfo, exprs...) + if err != nil { + return err + } + if hasTimestampArgs(argsType...) { + return errors.Trace(dbterror.ErrWrongExprInPartitionFunc) + } + return nil +} + +// hexIfNonPrint checks if printable UTF-8 characters from a single quoted string, +// if so, just returns the string +// else returns a hex string of the binary string (i.e. actual encoding, not unicode code points!) +func hexIfNonPrint(s string) string { + isPrint := true + // https://go.dev/blog/strings `for range` of string converts to runes! 
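A quick demonstration of the point made in the comment above: indexing a Go string walks bytes while `for range` decodes runes, which is what lets the loop below apply strconv.IsPrint to multi-byte UTF-8 characters. Stand-alone and purely for illustration:

package main

import (
	"fmt"
	"strconv"
)

func main() {
	s := "p0_æøå"
	fmt.Println("len in bytes:", len(s)) // 9: æ, ø and å take two bytes each

	runes := 0
	printable := true
	for _, r := range s { // rune by rune, not byte by byte
		runes++
		if !strconv.IsPrint(r) {
			printable = false
		}
	}
	fmt.Println("len in runes:", runes)      // 6
	fmt.Println("all printable:", printable) // true
	fmt.Println("printable \\x07?", strconv.IsPrint('\x07')) // false: would trigger the escape/hex fallback
}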
+ for _, runeVal := range s { + if !strconv.IsPrint(runeVal) { + isPrint = false + break + } + } + if isPrint { + return s + } + // To avoid 'simple' MySQL accepted escape characters, to be showed as hex, just escape them + // \0 \b \n \r \t \Z, see https://dev.mysql.com/doc/refman/8.0/en/string-literals.html + isPrint = true + res := "" + for _, runeVal := range s { + switch runeVal { + case 0: // Null + res += `\0` + case 7: // Bell + res += `\b` + case '\t': // 9 + res += `\t` + case '\n': // 10 + res += `\n` + case '\r': // 13 + res += `\r` + case 26: // ctrl-z / Substitute + res += `\Z` + default: + if !strconv.IsPrint(runeVal) { + isPrint = false + break + } + res += string(runeVal) + } + } + if isPrint { + return res + } + // Not possible to create an easy interpreted MySQL string, return as hex string + // Can be converted to string in MySQL like: CAST(UNHEX('') AS CHAR(255)) + return "0x" + hex.EncodeToString([]byte(driver.UnwrapFromSingleQuotes(s))) +} + +func writeColumnListToBuffer(partitionInfo *model.PartitionInfo, sqlMode mysql.SQLMode, buf *bytes.Buffer) { + if partitionInfo.IsEmptyColumns { + return + } + for i, col := range partitionInfo.Columns { + buf.WriteString(stringutil.Escape(col.O, sqlMode)) + if i < len(partitionInfo.Columns)-1 { + buf.WriteString(",") + } + } +} + +// AppendPartitionInfo is used in SHOW CREATE TABLE as well as generation the SQL syntax +// for the PartitionInfo during validation of various DDL commands +func AppendPartitionInfo(partitionInfo *model.PartitionInfo, buf *bytes.Buffer, sqlMode mysql.SQLMode) { + if partitionInfo == nil { + return + } + // Since MySQL 5.1/5.5 is very old and TiDB aims for 5.7/8.0 compatibility, we will not + // include the /*!50100 or /*!50500 comments for TiDB. + // This also solves the issue with comments within comments that would happen for + // PLACEMENT POLICY options. + defaultPartitionDefinitions := true + if partitionInfo.Type == model.PartitionTypeHash || + partitionInfo.Type == model.PartitionTypeKey { + for i, def := range partitionInfo.Definitions { + if def.Name.O != fmt.Sprintf("p%d", i) { + defaultPartitionDefinitions = false + break + } + if len(def.Comment) > 0 || def.PlacementPolicyRef != nil { + defaultPartitionDefinitions = false + break + } + } + + if defaultPartitionDefinitions { + if partitionInfo.Type == model.PartitionTypeHash { + fmt.Fprintf(buf, "\nPARTITION BY HASH (%s) PARTITIONS %d", partitionInfo.Expr, partitionInfo.Num) + } else { + buf.WriteString("\nPARTITION BY KEY (") + writeColumnListToBuffer(partitionInfo, sqlMode, buf) + buf.WriteString(")") + fmt.Fprintf(buf, " PARTITIONS %d", partitionInfo.Num) + } + return + } + } + // this if statement takes care of lists/range/key columns case + if len(partitionInfo.Columns) > 0 { + // partitionInfo.Type == model.PartitionTypeRange || partitionInfo.Type == model.PartitionTypeList + // || partitionInfo.Type == model.PartitionTypeKey + // Notice that MySQL uses two spaces between LIST and COLUMNS... 
+ if partitionInfo.Type == model.PartitionTypeKey { + fmt.Fprintf(buf, "\nPARTITION BY %s (", partitionInfo.Type.String()) + } else { + fmt.Fprintf(buf, "\nPARTITION BY %s COLUMNS(", partitionInfo.Type.String()) + } + writeColumnListToBuffer(partitionInfo, sqlMode, buf) + buf.WriteString(")\n(") + } else { + fmt.Fprintf(buf, "\nPARTITION BY %s (%s)\n(", partitionInfo.Type.String(), partitionInfo.Expr) + } + + AppendPartitionDefs(partitionInfo, buf, sqlMode) + buf.WriteString(")") +} + +// AppendPartitionDefs generates a list of partition definitions needed for SHOW CREATE TABLE (in executor/show.go) +// as well as needed for generating the ADD PARTITION query for INTERVAL partitioning of ALTER TABLE t LAST PARTITION +// and generating the CREATE TABLE query from CREATE TABLE ... INTERVAL +func AppendPartitionDefs(partitionInfo *model.PartitionInfo, buf *bytes.Buffer, sqlMode mysql.SQLMode) { + for i, def := range partitionInfo.Definitions { + if i > 0 { + fmt.Fprintf(buf, ",\n ") + } + fmt.Fprintf(buf, "PARTITION %s", stringutil.Escape(def.Name.O, sqlMode)) + // PartitionTypeHash and PartitionTypeKey do not have any VALUES definition + if partitionInfo.Type == model.PartitionTypeRange { + lessThans := make([]string, len(def.LessThan)) + for idx, v := range def.LessThan { + lessThans[idx] = hexIfNonPrint(v) + } + fmt.Fprintf(buf, " VALUES LESS THAN (%s)", strings.Join(lessThans, ",")) + } else if partitionInfo.Type == model.PartitionTypeList { + if len(def.InValues) == 0 { + fmt.Fprintf(buf, " DEFAULT") + } else if len(def.InValues) == 1 && + len(def.InValues[0]) == 1 && + strings.EqualFold(def.InValues[0][0], "DEFAULT") { + fmt.Fprintf(buf, " DEFAULT") + } else { + values := bytes.NewBuffer(nil) + for j, inValues := range def.InValues { + if j > 0 { + values.WriteString(",") + } + if len(inValues) > 1 { + values.WriteString("(") + tmpVals := make([]string, len(inValues)) + for idx, v := range inValues { + tmpVals[idx] = hexIfNonPrint(v) + } + values.WriteString(strings.Join(tmpVals, ",")) + values.WriteString(")") + } else if len(inValues) == 1 { + values.WriteString(hexIfNonPrint(inValues[0])) + } + } + fmt.Fprintf(buf, " VALUES IN (%s)", values.String()) + } + } + if len(def.Comment) > 0 { + fmt.Fprintf(buf, " COMMENT '%s'", format.OutputFormat(def.Comment)) + } + if def.PlacementPolicyRef != nil { + // add placement ref info here + fmt.Fprintf(buf, " /*T![placement] PLACEMENT POLICY=%s */", stringutil.Escape(def.PlacementPolicyRef.Name.O, sqlMode)) + } + } +} + +func generatePartValuesWithTp(partVal types.Datum, tp types.FieldType) (string, error) { + if partVal.Kind() == types.KindNull { + return "NULL", nil + } + + s, err := partVal.ToString() + if err != nil { + return "", err + } + + switch tp.EvalType() { + case types.ETInt: + return s, nil + case types.ETString: + // The `partVal` can be an invalid utf8 string if it's converted to BINARY, then the content will be lost after + // marshaling and storing in the schema. In this case, we use a hex literal to work around this issue. 
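A short demonstration of why the hex-literal workaround described above is needed, assuming the table definition is stored as JSON: encoding/json coerces invalid UTF-8 in strings to U+FFFD, so the original BINARY bytes would not survive the round trip.

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	in := "\xff\xfe" // not valid UTF-8

	b, _ := json.Marshal(in) // invalid bytes become the replacement rune
	var out string
	_ = json.Unmarshal(b, &out)

	fmt.Println("round-trips unchanged:", out == in)        // false
	fmt.Printf("hex literal to store:   _binary 0x%x\n", in) // the workaround used below
}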
+ if tp.GetCharset() == charset.CharsetBin { + return fmt.Sprintf("_binary 0x%x", s), nil + } + return driver.WrapInSingleQuotes(s), nil + case types.ETDatetime, types.ETDuration: + return driver.WrapInSingleQuotes(s), nil + } + + return "", dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs() +} + +func checkPartitionDefinitionConstraints(ctx sessionctx.Context, tbInfo *model.TableInfo) error { + var err error + if err = checkPartitionNameUnique(tbInfo.Partition); err != nil { + return errors.Trace(err) + } + if err = checkAddPartitionTooManyPartitions(uint64(len(tbInfo.Partition.Definitions))); err != nil { + return err + } + if err = checkAddPartitionOnTemporaryMode(tbInfo); err != nil { + return err + } + if err = checkPartitionColumnsUnique(tbInfo); err != nil { + return err + } + + switch tbInfo.Partition.Type { + case model.PartitionTypeRange: + err = checkPartitionByRange(ctx, tbInfo) + case model.PartitionTypeHash, model.PartitionTypeKey: + err = checkPartitionByHash(ctx, tbInfo) + case model.PartitionTypeList: + err = checkPartitionByList(ctx, tbInfo) + } + return errors.Trace(err) +} + +func checkPartitionByHash(ctx sessionctx.Context, tbInfo *model.TableInfo) error { + return checkNoHashPartitions(ctx, tbInfo.Partition.Num) +} + +// checkPartitionByRange checks validity of a "BY RANGE" partition. +func checkPartitionByRange(ctx sessionctx.Context, tbInfo *model.TableInfo) error { + failpoint.Inject("CheckPartitionByRangeErr", func() { + ctx.GetSessionVars().SQLKiller.SendKillSignal(sqlkiller.QueryMemoryExceeded) + panic(ctx.GetSessionVars().SQLKiller.HandleSignal()) + }) + pi := tbInfo.Partition + + if len(pi.Columns) == 0 { + return checkRangePartitionValue(ctx, tbInfo) + } + + return checkRangeColumnsPartitionValue(ctx, tbInfo) +} + +func checkRangeColumnsPartitionValue(ctx sessionctx.Context, tbInfo *model.TableInfo) error { + // Range columns partition key supports multiple data types with integer、datetime、string. + pi := tbInfo.Partition + defs := pi.Definitions + if len(defs) < 1 { + return ast.ErrPartitionsMustBeDefined.GenWithStackByArgs("RANGE") + } + + curr := &defs[0] + if len(curr.LessThan) != len(pi.Columns) { + return errors.Trace(ast.ErrPartitionColumnList) + } + var prev *model.PartitionDefinition + for i := 1; i < len(defs); i++ { + prev, curr = curr, &defs[i] + succ, err := checkTwoRangeColumns(ctx, curr, prev, pi, tbInfo) + if err != nil { + return err + } + if !succ { + return errors.Trace(dbterror.ErrRangeNotIncreasing) + } + } + return nil +} + +func checkTwoRangeColumns(ctx sessionctx.Context, curr, prev *model.PartitionDefinition, pi *model.PartitionInfo, tbInfo *model.TableInfo) (bool, error) { + if len(curr.LessThan) != len(pi.Columns) { + return false, errors.Trace(ast.ErrPartitionColumnList) + } + for i := 0; i < len(pi.Columns); i++ { + // Special handling for MAXVALUE. + if strings.EqualFold(curr.LessThan[i], partitionMaxValue) && !strings.EqualFold(prev.LessThan[i], partitionMaxValue) { + // If current is maxvalue, it certainly >= previous. + return true, nil + } + if strings.EqualFold(prev.LessThan[i], partitionMaxValue) { + // Current is not maxvalue, and previous is maxvalue. 
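+ // A non-MAXVALUE bound can never be greater than MAXVALUE, so the range is not increasing.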
+ return false, nil + } + + // The tuples of column values used to define the partitions are strictly increasing: + // PARTITION p0 VALUES LESS THAN (5,10,'ggg') + // PARTITION p1 VALUES LESS THAN (10,20,'mmm') + // PARTITION p2 VALUES LESS THAN (15,30,'sss') + colInfo := findColumnByName(pi.Columns[i].L, tbInfo) + cmp, err := parseAndEvalBoolExpr(ctx.GetExprCtx(), curr.LessThan[i], prev.LessThan[i], colInfo, tbInfo) + if err != nil { + return false, err + } + + if cmp > 0 { + return true, nil + } + + if cmp < 0 { + return false, nil + } + } + return false, nil +} + +// equal, return 0 +// greater, return 1 +// less, return -1 +func parseAndEvalBoolExpr(ctx expression.BuildContext, l, r string, colInfo *model.ColumnInfo, tbInfo *model.TableInfo) (int64, error) { + lexpr, err := expression.ParseSimpleExpr(ctx, l, expression.WithTableInfo("", tbInfo), expression.WithCastExprTo(&colInfo.FieldType)) + if err != nil { + return 0, err + } + rexpr, err := expression.ParseSimpleExpr(ctx, r, expression.WithTableInfo("", tbInfo), expression.WithCastExprTo(&colInfo.FieldType)) + if err != nil { + return 0, err + } + + e, err := expression.NewFunctionBase(ctx, ast.EQ, field_types.NewFieldType(mysql.TypeLonglong), lexpr, rexpr) + if err != nil { + return 0, err + } + e.SetCharsetAndCollation(colInfo.GetCharset(), colInfo.GetCollate()) + res, _, err1 := e.EvalInt(ctx.GetEvalCtx(), chunk.Row{}) + if err1 != nil { + return 0, err1 + } + if res == 1 { + return 0, nil + } + + e, err = expression.NewFunctionBase(ctx, ast.GT, field_types.NewFieldType(mysql.TypeLonglong), lexpr, rexpr) + if err != nil { + return 0, err + } + e.SetCharsetAndCollation(colInfo.GetCharset(), colInfo.GetCollate()) + res, _, err1 = e.EvalInt(ctx.GetEvalCtx(), chunk.Row{}) + if err1 != nil { + return 0, err1 + } + if res > 0 { + return 1, nil + } + return -1, nil +} + +// checkPartitionByList checks validity of a "BY LIST" partition. +func checkPartitionByList(ctx sessionctx.Context, tbInfo *model.TableInfo) error { + return checkListPartitionValue(ctx.GetExprCtx(), tbInfo) +} diff --git a/pkg/ddl/placement/binding__failpoint_binding__.go b/pkg/ddl/placement/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..72feaf66a5d4f --- /dev/null +++ b/pkg/ddl/placement/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package placement + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/ddl/placement/bundle.go b/pkg/ddl/placement/bundle.go index 427bbdf2a7d98..dbd8a70418f46 100644 --- a/pkg/ddl/placement/bundle.go +++ b/pkg/ddl/placement/bundle.go @@ -301,11 +301,11 @@ func NewBundleFromOptions(options *model.PlacementSettings) (bundle *Bundle, err // String implements fmt.Stringer. 
func (b *Bundle) String() string { t, err := json.Marshal(b) - failpoint.Inject("MockMarshalFailure", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("MockMarshalFailure")); _err_ == nil { if _, ok := val.(bool); ok { err = errors.New("test") } - }) + } if err != nil { return "" } diff --git a/pkg/ddl/placement/bundle.go__failpoint_stash__ b/pkg/ddl/placement/bundle.go__failpoint_stash__ new file mode 100644 index 0000000000000..427bbdf2a7d98 --- /dev/null +++ b/pkg/ddl/placement/bundle.go__failpoint_stash__ @@ -0,0 +1,712 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package placement + +import ( + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "math" + "slices" + "sort" + "strconv" + "strings" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/util/codec" + pd "github.com/tikv/pd/client/http" + "gopkg.in/yaml.v2" +) + +// Bundle is a group of all rules and configurations. It is used to support rule cache. +// Alias `pd.GroupBundle` is to wrap more methods. +type Bundle pd.GroupBundle + +// NewBundle will create a bundle with the provided ID. +// Note that you should never pass negative id. +func NewBundle(id int64) *Bundle { + return &Bundle{ + ID: GroupID(id), + } +} + +// NewBundleFromConstraintsOptions will transform constraints options into the bundle. +func NewBundleFromConstraintsOptions(options *model.PlacementSettings) (*Bundle, error) { + if options == nil { + return nil, fmt.Errorf("%w: options can not be nil", ErrInvalidPlacementOptions) + } + + if len(options.PrimaryRegion) > 0 || len(options.Regions) > 0 || len(options.Schedule) > 0 { + return nil, fmt.Errorf("%w: should be [LEADER/VOTER/LEARNER/FOLLOWER]_CONSTRAINTS=.. [VOTERS/FOLLOWERS/LEARNERS]=.., mixed other sugar options %s", ErrInvalidPlacementOptions, options) + } + + constraints := options.Constraints + leaderConst := options.LeaderConstraints + learnerConstraints := options.LearnerConstraints + followerConstraints := options.FollowerConstraints + explicitFollowerCount := options.Followers + explicitLearnerCount := options.Learners + + rules := []*pd.Rule{} + commonConstraints, err := NewConstraintsFromYaml([]byte(constraints)) + if err != nil { + // If it's not in array format, attempt to parse it as a dictionary for more detailed definitions. + // The dictionary format specifies details for each replica. Constraints are used to define normal + // replicas that should act as voters. + // For example: CONSTRAINTS='{ "+region=us-east-1":2, "+region=us-east-2": 2, "+region=us-west-1": 1}' + normalReplicasRules, err := NewRuleBuilder(). + SetRole(pd.Voter). + SetConstraintStr(constraints). + BuildRulesWithDictConstraintsOnly() + if err != nil { + return nil, err + } + rules = append(rules, normalReplicasRules...) 
+ } + needCreateDefault := len(rules) == 0 + leaderConstraints, err := NewConstraintsFromYaml([]byte(leaderConst)) + if err != nil { + return nil, fmt.Errorf("%w: 'LeaderConstraints' should be [constraint1, ...] or any yaml compatible array representation", err) + } + for _, cnst := range commonConstraints { + if err := AddConstraint(&leaderConstraints, cnst); err != nil { + return nil, fmt.Errorf("%w: LeaderConstraints conflicts with Constraints", err) + } + } + leaderReplicas, followerReplicas := uint64(1), uint64(2) + if explicitFollowerCount > 0 { + followerReplicas = explicitFollowerCount + } + if !needCreateDefault { + if len(leaderConst) == 0 { + leaderReplicas = 0 + } + if len(followerConstraints) == 0 { + if explicitFollowerCount > 0 { + return nil, fmt.Errorf("%w: specify follower count without specify follower constraints when specify other constraints", ErrInvalidPlacementOptions) + } + followerReplicas = 0 + } + } + + // create leader rule. + // if no constraints, we need create default leader rule. + if leaderReplicas > 0 { + leaderRule := NewRule(pd.Leader, leaderReplicas, leaderConstraints) + rules = append(rules, leaderRule) + } + + // create follower rules. + // if no constraints, we need create default follower rules. + if followerReplicas > 0 { + builder := NewRuleBuilder(). + SetRole(pd.Voter). + SetReplicasNum(followerReplicas). + SetSkipCheckReplicasConsistent(needCreateDefault && (explicitFollowerCount == 0)). + SetConstraintStr(followerConstraints) + followerRules, err := builder.BuildRules() + if err != nil { + return nil, fmt.Errorf("%w: invalid FollowerConstraints", err) + } + for _, followerRule := range followerRules { + for _, cnst := range commonConstraints { + if err := AddConstraint(&followerRule.LabelConstraints, cnst); err != nil { + return nil, fmt.Errorf("%w: FollowerConstraints conflicts with Constraints", err) + } + } + } + rules = append(rules, followerRules...) + } + + // create learner rules. + builder := NewRuleBuilder(). + SetRole(pd.Learner). + SetReplicasNum(explicitLearnerCount). + SetConstraintStr(learnerConstraints) + learnerRules, err := builder.BuildRules() + if err != nil { + return nil, fmt.Errorf("%w: invalid LearnerConstraints", err) + } + for _, rule := range learnerRules { + for _, cnst := range commonConstraints { + if err := AddConstraint(&rule.LabelConstraints, cnst); err != nil { + return nil, fmt.Errorf("%w: LearnerConstraints conflicts with Constraints", err) + } + } + } + rules = append(rules, learnerRules...) + labels, err := newLocationLabelsFromSurvivalPreferences(options.SurvivalPreferences) + if err != nil { + return nil, err + } + for _, rule := range rules { + rule.LocationLabels = labels + } + return &Bundle{Rules: rules}, nil +} + +// NewBundleFromSugarOptions will transform syntax sugar options into the bundle. +func NewBundleFromSugarOptions(options *model.PlacementSettings) (*Bundle, error) { + if options == nil { + return nil, fmt.Errorf("%w: options can not be nil", ErrInvalidPlacementOptions) + } + + if len(options.LeaderConstraints) > 0 || len(options.LearnerConstraints) > 0 || len(options.FollowerConstraints) > 0 || len(options.Constraints) > 0 || options.Learners > 0 { + return nil, fmt.Errorf("%w: should be PRIMARY_REGION=.. REGIONS=.. FOLLOWERS=.. 
SCHEDULE=.., mixed other constraints into options %s", ErrInvalidPlacementOptions, options) + } + + primaryRegion := strings.TrimSpace(options.PrimaryRegion) + + var regions []string + if k := strings.TrimSpace(options.Regions); len(k) > 0 { + regions = strings.Split(k, ",") + for i, r := range regions { + regions[i] = strings.TrimSpace(r) + } + } + + followers := options.Followers + if followers == 0 { + followers = 2 + } + schedule := options.Schedule + + var rules []*pd.Rule + + locationLabels, err := newLocationLabelsFromSurvivalPreferences(options.SurvivalPreferences) + if err != nil { + return nil, err + } + + // in case empty primaryRegion and regions, just return an empty bundle + if primaryRegion == "" && len(regions) == 0 { + rules = append(rules, NewRule(pd.Voter, followers+1, NewConstraintsDirect())) + for _, rule := range rules { + rule.LocationLabels = locationLabels + } + return &Bundle{Rules: rules}, nil + } + + // regions must include the primary + slices.Sort(regions) + primaryIndex := sort.SearchStrings(regions, primaryRegion) + if primaryIndex >= len(regions) || regions[primaryIndex] != primaryRegion { + return nil, fmt.Errorf("%w: primary region must be included in regions", ErrInvalidPlacementOptions) + } + + // primaryCount only makes sense when len(regions) > 0 + // but we will compute it here anyway for reusing code + var primaryCount uint64 + switch strings.ToLower(schedule) { + case "", "even": + primaryCount = uint64(math.Ceil(float64(followers+1) / float64(len(regions)))) + case "majority_in_primary": + // calculate how many replicas need to be in the primary region for quorum + primaryCount = (followers+1)/2 + 1 + default: + return nil, fmt.Errorf("%w: unsupported schedule %s", ErrInvalidPlacementOptions, schedule) + } + + rules = append(rules, NewRule(pd.Leader, 1, NewConstraintsDirect(NewConstraintDirect("region", pd.In, primaryRegion)))) + if primaryCount > 1 { + rules = append(rules, NewRule(pd.Voter, primaryCount-1, NewConstraintsDirect(NewConstraintDirect("region", pd.In, primaryRegion)))) + } + if cnt := followers + 1 - primaryCount; cnt > 0 { + // delete primary from regions + regions = regions[:primaryIndex+copy(regions[primaryIndex:], regions[primaryIndex+1:])] + if len(regions) > 0 { + rules = append(rules, NewRule(pd.Voter, cnt, NewConstraintsDirect(NewConstraintDirect("region", pd.In, regions...)))) + } else { + rules = append(rules, NewRule(pd.Voter, cnt, NewConstraintsDirect())) + } + } + + // set location labels + for _, rule := range rules { + rule.LocationLabels = locationLabels + } + + return &Bundle{Rules: rules}, nil +} + +// Non-Exported functionality function, do not use it directly but NewBundleFromOptions +// here is for only directly used in the test. 
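+// Worked example (illustrative numbers) for the replica math in NewBundleFromSugarOptions above:
+// with FOLLOWERS=4 and three REGIONS, SCHEDULE="even" keeps ceil(5/3)=2 replicas in the primary
+// region, while "majority_in_primary" keeps 5/2+1=3 there, so the primary region alone has a quorum.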
+func newBundleFromOptions(options *model.PlacementSettings) (bundle *Bundle, err error) { + if options == nil { + return nil, fmt.Errorf("%w: options can not be nil", ErrInvalidPlacementOptions) + } + + if options.Followers > uint64(8) { + return nil, fmt.Errorf("%w: followers should be less than or equal to 8: %d", ErrInvalidPlacementOptions, options.Followers) + } + + // always prefer the sugar syntax, which gives better schedule results most of the time + isSyntaxSugar := true + if len(options.LeaderConstraints) > 0 || len(options.LearnerConstraints) > 0 || len(options.FollowerConstraints) > 0 || len(options.Constraints) > 0 || options.Learners > 0 { + isSyntaxSugar = false + } + + if isSyntaxSugar { + bundle, err = NewBundleFromSugarOptions(options) + } else { + bundle, err = NewBundleFromConstraintsOptions(options) + } + return bundle, err +} + +// newLocationLabelsFromSurvivalPreferences will parse the survival preferences into location labels. +func newLocationLabelsFromSurvivalPreferences(survivalPreferenceStr string) ([]string, error) { + if len(survivalPreferenceStr) > 0 { + labels := []string{} + err := yaml.UnmarshalStrict([]byte(survivalPreferenceStr), &labels) + if err != nil { + return nil, ErrInvalidSurvivalPreferenceFormat + } + return labels, nil + } + return nil, nil +} + +// NewBundleFromOptions will transform options into the bundle. +func NewBundleFromOptions(options *model.PlacementSettings) (bundle *Bundle, err error) { + bundle, err = newBundleFromOptions(options) + if err != nil { + return nil, err + } + if bundle == nil { + return nil, nil + } + err = bundle.Tidy() + if err != nil { + return nil, err + } + return bundle, err +} + +// String implements fmt.Stringer. +func (b *Bundle) String() string { + t, err := json.Marshal(b) + failpoint.Inject("MockMarshalFailure", func(val failpoint.Value) { + if _, ok := val.(bool); ok { + err = errors.New("test") + } + }) + if err != nil { + return "" + } + return string(t) +} + +// Tidy will post optimize Rules, trying to generate rules that suits PD. +func (b *Bundle) Tidy() error { + tempRules := b.Rules[:0] + id := 0 + for _, rule := range b.Rules { + // useless Rule + if rule.Count <= 0 { + continue + } + // refer to tidb#22065. + // add -engine=tiflash to every rule to avoid schedules to tiflash instances. + // placement rules in SQL is not compatible with `set tiflash replica` yet + err := AddConstraint(&rule.LabelConstraints, pd.LabelConstraint{ + Op: pd.NotIn, + Key: EngineLabelKey, + Values: []string{EngineLabelTiFlash}, + }) + if err != nil { + return err + } + rule.ID = strconv.Itoa(id) + tempRules = append(tempRules, rule) + id++ + } + + groups := make(map[string]*constraintsGroup) + finalRules := tempRules[:0] + for _, rule := range tempRules { + key := ConstraintsFingerPrint(&rule.LabelConstraints) + existing, ok := groups[key] + if !ok { + groups[key] = &constraintsGroup{rules: []*pd.Rule{rule}} + continue + } + existing.rules = append(existing.rules, rule) + } + for _, group := range groups { + group.MergeRulesByRole() + } + if err := transformableLeaderConstraint(groups); err != nil { + return err + } + for _, group := range groups { + finalRules = append(finalRules, group.rules...) + } + // sort by id + sort.SliceStable(finalRules, func(i, j int) bool { + return finalRules[i].ID < finalRules[j].ID + }) + b.Rules = finalRules + return nil +} + +// constraintsGroup is a group of rules with the same constraints. 
+type constraintsGroup struct { + rules []*pd.Rule + // canBecameLeader means the group has leader/voter role, + // it's valid if it has leader. + canBecameLeader bool + // isLeaderGroup means it has specified leader role in this group. + isLeaderGroup bool +} + +func transformableLeaderConstraint(groups map[string]*constraintsGroup) error { + var leaderGroup *constraintsGroup + canBecameLeaderNum := 0 + for _, group := range groups { + if group.isLeaderGroup { + if leaderGroup != nil { + return ErrInvalidPlacementOptions + } + leaderGroup = group + } + if group.canBecameLeader { + canBecameLeaderNum++ + } + } + // If there is a specified group should have leader, and only this group can be a leader, that means + // the leader's priority is certain, so we can merge the transformable rules into one. + // eg: + // - [ group1 (L F), group2 (F) ], after merging is [group1 (2*V), group2 (F)], we still know the leader prefers group1. + // - [ group1 (L F), group2 (V) ], after merging is [group1 (2*V), group2 (V)], we can't know leader priority after merge. + if leaderGroup != nil && canBecameLeaderNum == 1 { + leaderGroup.MergeTransformableRoles() + } + return nil +} + +// MergeRulesByRole merges the rules with the same role. +func (c *constraintsGroup) MergeRulesByRole() { + // Create a map to store rules by role + rulesByRole := make(map[pd.PeerRoleType][]*pd.Rule) + + // Iterate through each rule + for _, rule := range c.rules { + // Add the rule to the map based on its role + rulesByRole[rule.Role] = append(rulesByRole[rule.Role], rule) + if rule.Role == pd.Leader || rule.Role == pd.Voter { + c.canBecameLeader = true + } + if rule.Role == pd.Leader { + c.isLeaderGroup = true + } + } + + // Clear existing rules + c.rules = nil + + // Iterate through each role and merge the rules + for _, rules := range rulesByRole { + mergedRule := rules[0] + for i, rule := range rules { + if i == 0 { + continue + } + mergedRule.Count += rule.Count + if mergedRule.ID > rule.ID { + mergedRule.ID = rule.ID + } + } + c.rules = append(c.rules, mergedRule) + } +} + +// MergeTransformableRoles merges all the rules to one that can be transformed to other roles. +func (c *constraintsGroup) MergeTransformableRoles() { + if len(c.rules) == 0 || len(c.rules) == 1 { + return + } + var mergedRule *pd.Rule + newRules := make([]*pd.Rule, 0, len(c.rules)) + for _, rule := range c.rules { + // Learner is not transformable, it should be promote by PD. + if rule.Role == pd.Learner { + newRules = append(newRules, rule) + continue + } + if mergedRule == nil { + mergedRule = rule + continue + } + mergedRule.Count += rule.Count + if mergedRule.ID > rule.ID { + mergedRule.ID = rule.ID + } + } + if mergedRule != nil { + mergedRule.Role = pd.Voter + newRules = append(newRules, mergedRule) + } + c.rules = newRules +} + +// GetRangeStartAndEndKeyHex get startKeyHex and endKeyHex of range by rangeBundleID. +func GetRangeStartAndEndKeyHex(rangeBundleID string) (startKey string, endKey string) { + startKey, endKey = "", "" + if rangeBundleID == TiDBBundleRangePrefixForMeta { + startKey = hex.EncodeToString(metaPrefix) + endKey = hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(0))) + } + return startKey, endKey +} + +// RebuildForRange rebuilds the bundle for system range. 
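+// Each cloned rule gets an id of the form "<lower-cased policy name>_rule_<i>" and the start/end keys
+// of the chosen system range, so the policy only applies to that key range (descriptive note).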
+func (b *Bundle) RebuildForRange(rangeName string, policyName string) *Bundle { + rule := b.Rules + switch rangeName { + case KeyRangeGlobal: + b.ID = TiDBBundleRangePrefixForGlobal + b.Index = RuleIndexKeyRangeForGlobal + case KeyRangeMeta: + b.ID = TiDBBundleRangePrefixForMeta + b.Index = RuleIndexKeyRangeForMeta + } + + startKey, endKey := GetRangeStartAndEndKeyHex(b.ID) + b.Override = true + newRules := make([]*pd.Rule, 0, len(rule)) + for i, r := range b.Rules { + cp := r.Clone() + cp.ID = fmt.Sprintf("%s_rule_%d", strings.ToLower(policyName), i) + cp.GroupID = b.ID + cp.StartKeyHex = startKey + cp.EndKeyHex = endKey + cp.Index = i + newRules = append(newRules, cp) + } + b.Rules = newRules + return b +} + +// Reset resets the bundle ID and keyrange of all rules. +func (b *Bundle) Reset(ruleIndex int, newIDs []int64) *Bundle { + // eliminate the redundant rules. + var basicRules []*pd.Rule + if len(b.Rules) != 0 { + // Make priority for rules with RuleIndexTable cause of duplication rules existence with RuleIndexPartition. + // If RuleIndexTable doesn't exist, bundle itself is a independent series of rules for a partition. + for _, rule := range b.Rules { + if rule.Index == RuleIndexTable { + basicRules = append(basicRules, rule) + } + } + if len(basicRules) == 0 { + basicRules = b.Rules + } + } + + // extend and reset basic rules for all new ids, the first id should be the group id. + b.ID = GroupID(newIDs[0]) + b.Index = ruleIndex + b.Override = true + newRules := make([]*pd.Rule, 0, len(basicRules)*len(newIDs)) + for i, newID := range newIDs { + // rule.id should be distinguished with each other, otherwise it will be de-duplicated in pd http api. + var ruleID string + if ruleIndex == RuleIndexPartition { + ruleID = "partition_rule_" + strconv.FormatInt(newID, 10) + } else { + if i == 0 { + ruleID = "table_rule_" + strconv.FormatInt(newID, 10) + } else { + ruleID = "partition_rule_" + strconv.FormatInt(newID, 10) + } + } + // Involve all the table level objects. + startKey := hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(newID))) + endKey := hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(newID+1))) + for j, rule := range basicRules { + clone := rule.Clone() + // for the rules of one element id, distinguishing the rule ids to avoid the PD's overlap. + clone.ID = ruleID + "_" + strconv.FormatInt(int64(j), 10) + clone.GroupID = b.ID + clone.StartKeyHex = startKey + clone.EndKeyHex = endKey + if i == 0 { + clone.Index = RuleIndexTable + } else { + clone.Index = RuleIndexPartition + } + newRules = append(newRules, clone) + } + } + b.Rules = newRules + return b +} + +// Clone is used to duplicate a bundle. +func (b *Bundle) Clone() *Bundle { + newBundle := &Bundle{} + *newBundle = *b + if len(b.Rules) > 0 { + newBundle.Rules = make([]*pd.Rule, 0, len(b.Rules)) + for i := range b.Rules { + newBundle.Rules = append(newBundle.Rules, b.Rules[i].Clone()) + } + } + return newBundle +} + +// IsEmpty is used to check if a bundle is empty. +func (b *Bundle) IsEmpty() bool { + return len(b.Rules) == 0 && b.Index == 0 && !b.Override +} + +// ObjectID extracts the db/table/partition ID from the group ID +func (b *Bundle) ObjectID() (int64, error) { + // If the rule doesn't come from TiDB, skip it. 
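+ // e.g. a group ID produced by GroupID(51), i.e. BundleIDPrefix followed by "51", yields object ID 51
+ // (illustrative value).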
+ if !strings.HasPrefix(b.ID, BundleIDPrefix) { + return 0, ErrInvalidBundleIDFormat + } + id, err := strconv.ParseInt(b.ID[len(BundleIDPrefix):], 10, 64) + if err != nil { + return 0, fmt.Errorf("%w: %s", ErrInvalidBundleID, err) + } + if id <= 0 { + return 0, fmt.Errorf("%w: %s doesn't include an id", ErrInvalidBundleID, b.ID) + } + return id, nil +} + +func isValidLeaderRule(rule *pd.Rule, dcLabelKey string) bool { + if rule.Role == pd.Leader && rule.Count == 1 { + for _, con := range rule.LabelConstraints { + if con.Op == pd.In && con.Key == dcLabelKey && len(con.Values) == 1 { + return true + } + } + } + return false +} + +// GetLeaderDC returns the leader's DC by Bundle if found. +func (b *Bundle) GetLeaderDC(dcLabelKey string) (string, bool) { + for _, rule := range b.Rules { + if isValidLeaderRule(rule, dcLabelKey) { + return rule.LabelConstraints[0].Values[0], true + } + } + return "", false +} + +// PolicyGetter is the interface to get the policy +type PolicyGetter interface { + GetPolicy(policyID int64) (*model.PolicyInfo, error) +} + +// NewTableBundle creates a bundle for table key range. +// If table is a partitioned table, it also contains the rules that inherited from table for every partition. +// The bundle does not contain the rules specified independently by each partition +func NewTableBundle(getter PolicyGetter, tbInfo *model.TableInfo) (*Bundle, error) { + bundle, err := newBundleFromPolicy(getter, tbInfo.PlacementPolicyRef) + if err != nil { + return nil, err + } + + if bundle == nil { + return nil, nil + } + ids := []int64{tbInfo.ID} + // build the default partition rules in the table-level bundle. + if tbInfo.Partition != nil { + for _, pDef := range tbInfo.Partition.Definitions { + ids = append(ids, pDef.ID) + } + } + bundle.Reset(RuleIndexTable, ids) + return bundle, nil +} + +// NewPartitionBundle creates a bundle for partition key range. +// It only contains the rules specified independently by the partition. +// That is to say the inherited rules from table is not included. +func NewPartitionBundle(getter PolicyGetter, def model.PartitionDefinition) (*Bundle, error) { + bundle, err := newBundleFromPolicy(getter, def.PlacementPolicyRef) + if err != nil { + return nil, err + } + + if bundle != nil { + bundle.Reset(RuleIndexPartition, []int64{def.ID}) + } + + return bundle, nil +} + +// NewPartitionListBundles creates a bundle list for a partition list +func NewPartitionListBundles(getter PolicyGetter, defs []model.PartitionDefinition) ([]*Bundle, error) { + bundles := make([]*Bundle, 0, len(defs)) + // If the partition has the placement rules on their own, build the partition-level bundles additionally. + for _, def := range defs { + bundle, err := NewPartitionBundle(getter, def) + if err != nil { + return nil, err + } + + if bundle != nil { + bundles = append(bundles, bundle) + } + } + return bundles, nil +} + +// NewFullTableBundles returns a bundle list with both table bundle and partition bundles +func NewFullTableBundles(getter PolicyGetter, tbInfo *model.TableInfo) ([]*Bundle, error) { + var bundles []*Bundle + tableBundle, err := NewTableBundle(getter, tbInfo) + if err != nil { + return nil, err + } + + if tableBundle != nil { + bundles = append(bundles, tableBundle) + } + + if tbInfo.Partition != nil { + partitionBundles, err := NewPartitionListBundles(getter, tbInfo.Partition.Definitions) + if err != nil { + return nil, err + } + bundles = append(bundles, partitionBundles...) 
+ } + + return bundles, nil +} + +func newBundleFromPolicy(getter PolicyGetter, ref *model.PolicyRefInfo) (*Bundle, error) { + if ref != nil { + policy, err := getter.GetPolicy(ref.ID) + if err != nil { + return nil, err + } + + return NewBundleFromOptions(policy.PlacementSettings) + } + + return nil, nil +} diff --git a/pkg/ddl/reorg.go b/pkg/ddl/reorg.go index 2154974826a5a..b00762f6ce8bd 100644 --- a/pkg/ddl/reorg.go +++ b/pkg/ddl/reorg.go @@ -733,15 +733,15 @@ func getReorgInfo(ctx *JobContext, d *ddlCtx, rh *reorgHandler, job *model.Job, if job.SnapshotVer == 0 { // For the case of the old TiDB version(do not exist the element information) is upgraded to the new TiDB version. // Third step, we need to remove the element information to make sure we can save the reorganized information to storage. - failpoint.Inject("MockGetIndexRecordErr", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("MockGetIndexRecordErr")); _err_ == nil { if val.(string) == "addIdxNotOwnerErr" && atomic.CompareAndSwapUint32(&mockNotOwnerErrOnce, 3, 4) { if err := rh.RemoveReorgElementFailPoint(job); err != nil { - failpoint.Return(nil, errors.Trace(err)) + return nil, errors.Trace(err) } info.first = true - failpoint.Return(&info, nil) + return &info, nil } - }) + } info.first = true if d.lease > 0 { // Only delay when it's not in test. @@ -776,9 +776,9 @@ func getReorgInfo(ctx *JobContext, d *ddlCtx, rh *reorgHandler, job *model.Job, zap.String("startKey", hex.EncodeToString(start)), zap.String("endKey", hex.EncodeToString(end))) - failpoint.Inject("errorUpdateReorgHandle", func() (*reorgInfo, error) { + if _, _err_ := failpoint.Eval(_curpkg_("errorUpdateReorgHandle")); _err_ == nil { return &info, errors.New("occur an error when update reorg handle") - }) + } err = rh.InitDDLReorgHandle(job, start, end, pid, elements[0]) if err != nil { return &info, errors.Trace(err) @@ -787,16 +787,16 @@ func getReorgInfo(ctx *JobContext, d *ddlCtx, rh *reorgHandler, job *model.Job, job.SnapshotVer = ver.Ver element = elements[0] } else { - failpoint.Inject("MockGetIndexRecordErr", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("MockGetIndexRecordErr")); _err_ == nil { // For the case of the old TiDB version(do not exist the element information) is upgraded to the new TiDB version. // Second step, we need to remove the element information to make sure we can get the error of "ErrDDLReorgElementNotExist". // However, since "txn.Reset()" will be called later, the reorganized information cannot be saved to storage. if val.(string) == "addIdxNotOwnerErr" && atomic.CompareAndSwapUint32(&mockNotOwnerErrOnce, 2, 3) { if err := rh.RemoveReorgElementFailPoint(job); err != nil { - failpoint.Return(nil, errors.Trace(err)) + return nil, errors.Trace(err) } } - }) + } var err error element, start, end, pid, err = rh.GetDDLReorgHandle(job) diff --git a/pkg/ddl/reorg.go__failpoint_stash__ b/pkg/ddl/reorg.go__failpoint_stash__ new file mode 100644 index 0000000000000..2154974826a5a --- /dev/null +++ b/pkg/ddl/reorg.go__failpoint_stash__ @@ -0,0 +1,982 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl + +import ( + "context" + "encoding/hex" + "fmt" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/ddl/ingest" + "github.com/pingcap/tidb/pkg/ddl/logutil" + sess "github.com/pingcap/tidb/pkg/ddl/session" + "github.com/pingcap/tidb/pkg/distsql" + distsqlctx "github.com/pingcap/tidb/pkg/distsql/context" + "github.com/pingcap/tidb/pkg/errctx" + exprctx "github.com/pingcap/tidb/pkg/expression/context" + "github.com/pingcap/tidb/pkg/expression/contextstatic" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/statistics" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/codec" + contextutil "github.com/pingcap/tidb/pkg/util/context" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/mock" + "github.com/pingcap/tidb/pkg/util/ranger" + "github.com/pingcap/tidb/pkg/util/timeutil" + "github.com/pingcap/tipb/go-tipb" + atomicutil "go.uber.org/atomic" + "go.uber.org/zap" +) + +// reorgCtx is for reorganization. +type reorgCtx struct { + // doneCh is used to notify. + // If the reorganization job is done, we will use this channel to notify outer. + // TODO: Now we use goroutine to simulate reorganization jobs, later we may + // use a persistent job list. + doneCh chan reorgFnResult + // rowCount is used to simulate a job's row count. + rowCount int64 + jobState model.JobState + + mu struct { + sync.Mutex + // warnings are used to store the warnings when doing the reorg job under certain SQL modes. + warnings map[errors.ErrorID]*terror.Error + warningsCount map[errors.ErrorID]int64 + } + + references atomicutil.Int32 +} + +// reorgFnResult records the DDL owner TS before executing reorg function, in order to help +// receiver determine if the result is from reorg function of previous DDL owner in this instance. +type reorgFnResult struct { + ownerTS int64 + err error +} + +func newReorgExprCtx() exprctx.ExprContext { + evalCtx := contextstatic.NewStaticEvalContext( + contextstatic.WithSQLMode(mysql.ModeNone), + contextstatic.WithTypeFlags(types.DefaultStmtFlags), + contextstatic.WithErrLevelMap(stmtctx.DefaultStmtErrLevels), + ) + + planCacheTracker := contextutil.NewPlanCacheTracker(contextutil.IgnoreWarn) + + return contextstatic.NewStaticExprContext( + contextstatic.WithEvalCtx(evalCtx), + contextstatic.WithPlanCacheTracker(&planCacheTracker), + ) +} + +func reorgTypeFlagsWithSQLMode(mode mysql.SQLMode) types.Flags { + return types.StrictFlags. + WithTruncateAsWarning(!mode.HasStrictMode()). 
+ WithIgnoreInvalidDateErr(mode.HasAllowInvalidDatesMode()). + WithIgnoreZeroInDate(!mode.HasStrictMode() || mode.HasAllowInvalidDatesMode()). + WithCastTimeToYearThroughConcat(true) +} + +func reorgErrLevelsWithSQLMode(mode mysql.SQLMode) errctx.LevelMap { + return errctx.LevelMap{ + errctx.ErrGroupTruncate: errctx.ResolveErrLevel(false, !mode.HasStrictMode()), + errctx.ErrGroupBadNull: errctx.ResolveErrLevel(false, !mode.HasStrictMode()), + errctx.ErrGroupDividedByZero: errctx.ResolveErrLevel( + !mode.HasErrorForDivisionByZeroMode(), + !mode.HasStrictMode(), + ), + } +} + +func reorgTimeZoneWithTzLoc(tzLoc *model.TimeZoneLocation) (*time.Location, error) { + if tzLoc == nil { + // It is set to SystemLocation to be compatible with nil LocationInfo. + return timeutil.SystemLocation(), nil + } + return tzLoc.GetLocation() +} + +func newReorgSessCtx(store kv.Storage) sessionctx.Context { + c := mock.NewContext() + c.Store = store + c.GetSessionVars().SetStatusFlag(mysql.ServerStatusAutocommit, false) + + tz := *time.UTC + c.GetSessionVars().TimeZone = &tz + c.GetSessionVars().StmtCtx.SetTimeZone(&tz) + return c +} + +const defaultWaitReorgTimeout = 10 * time.Second + +// ReorgWaitTimeout is the timeout that wait ddl in write reorganization stage. +var ReorgWaitTimeout = 5 * time.Second + +func (rc *reorgCtx) notifyJobState(state model.JobState) { + atomic.StoreInt32((*int32)(&rc.jobState), int32(state)) +} + +func (rc *reorgCtx) isReorgCanceled() bool { + s := atomic.LoadInt32((*int32)(&rc.jobState)) + return int32(model.JobStateCancelled) == s || int32(model.JobStateCancelling) == s +} + +func (rc *reorgCtx) isReorgPaused() bool { + s := atomic.LoadInt32((*int32)(&rc.jobState)) + return int32(model.JobStatePaused) == s || int32(model.JobStatePausing) == s +} + +func (rc *reorgCtx) setRowCount(count int64) { + atomic.StoreInt64(&rc.rowCount, count) +} + +func (rc *reorgCtx) mergeWarnings(warnings map[errors.ErrorID]*terror.Error, warningsCount map[errors.ErrorID]int64) { + if len(warnings) == 0 || len(warningsCount) == 0 { + return + } + rc.mu.Lock() + defer rc.mu.Unlock() + rc.mu.warnings, rc.mu.warningsCount = mergeWarningsAndWarningsCount(warnings, rc.mu.warnings, warningsCount, rc.mu.warningsCount) +} + +func (rc *reorgCtx) resetWarnings() { + rc.mu.Lock() + defer rc.mu.Unlock() + rc.mu.warnings = make(map[errors.ErrorID]*terror.Error) + rc.mu.warningsCount = make(map[errors.ErrorID]int64) +} + +func (rc *reorgCtx) increaseRowCount(count int64) { + atomic.AddInt64(&rc.rowCount, count) +} + +func (rc *reorgCtx) getRowCount() int64 { + row := atomic.LoadInt64(&rc.rowCount) + return row +} + +// runReorgJob is used as a portal to do the reorganization work. +// eg: +// 1: add index +// 2: alter column type +// 3: clean global index +// 4: reorganize partitions +/* + ddl goroutine >---------+ + ^ | + | | + | | + | | <---(doneCh)--- f() + HandleDDLQueue(...) | <---(regular timeout) + | | <---(ctx done) + | | + | | + A more ddl round <-----+ +*/ +// How can we cancel reorg job? +// +// The background reorg is continuously running except for several factors, for instances, ddl owner change, +// logic error (kv duplicate when insert index / cast error when alter column), ctx done, and cancel signal. +// +// When `admin cancel ddl jobs xxx` takes effect, we will give this kind of reorg ddl one more round. +// because we should pull the result from doneCh out, otherwise, the reorg worker will hang on `f()` logic, +// which is a kind of goroutine leak. 
+// +// That's why we couldn't set the job to rollingback state directly in `convertJob2RollbackJob`, which is a +// cancelling portal for admin cancel action. +// +// In other words, the cancelling signal is informed from the bottom up, we set the atomic cancel variable +// in the cancelling portal to notify the lower worker goroutine, and fetch the cancel error from them in +// the additional ddl round. +// +// After that, we can make sure that the worker goroutine is correctly shut down. +func (w *worker) runReorgJob( + reorgInfo *reorgInfo, + tblInfo *model.TableInfo, + lease time.Duration, + reorgFn func() error, +) error { + job := reorgInfo.Job + d := reorgInfo.d + // This is for tests compatible, because most of the early tests try to build the reorg job manually + // without reorg meta info, which will cause nil pointer in here. + if job.ReorgMeta == nil { + job.ReorgMeta = &model.DDLReorgMeta{ + SQLMode: mysql.ModeNone, + Warnings: make(map[errors.ErrorID]*terror.Error), + WarningsCount: make(map[errors.ErrorID]int64), + Location: &model.TimeZoneLocation{Name: time.UTC.String(), Offset: 0}, + Version: model.CurrentReorgMetaVersion, + } + } + + rc := w.getReorgCtx(job.ID) + if rc == nil { + // This job is cancelling, we should return ErrCancelledDDLJob directly. + // + // Q: Is there any possibility that the job is cancelling and has no reorgCtx? + // A: Yes, consider the case that : + // - we cancel the job when backfilling the last batch of data, the cancel txn is commit first, + // - and then the backfill workers send signal to the `doneCh` of the reorgCtx, + // - and then the DDL worker will remove the reorgCtx + // - and update the DDL job to `done` + // - but at the commit time, the DDL txn will raise a "write conflict" error and retry, and it happens. + if job.IsCancelling() { + return dbterror.ErrCancelledDDLJob + } + + beOwnerTS := w.ddlCtx.reorgCtx.getOwnerTS() + rc = w.newReorgCtx(reorgInfo.Job.ID, reorgInfo.Job.GetRowCount()) + w.wg.Add(1) + go func() { + defer w.wg.Done() + err := reorgFn() + rc.doneCh <- reorgFnResult{ownerTS: beOwnerTS, err: err} + }() + } + + waitTimeout := defaultWaitReorgTimeout + // if lease is 0, we are using a local storage, + // and we can wait the reorganization to be done here. + // if lease > 0, we don't need to wait here because + // we should update some job's progress context and try checking again, + // so we use a very little timeout here. + if lease > 0 { + waitTimeout = ReorgWaitTimeout + } + + // wait reorganization job done or timeout + select { + case res := <-rc.doneCh: + err := res.err + curTS := w.ddlCtx.reorgCtx.getOwnerTS() + if res.ownerTS != curTS { + d.removeReorgCtx(job.ID) + logutil.DDLLogger().Warn("owner ts mismatch, return timeout error and retry", + zap.Int64("prevTS", res.ownerTS), + zap.Int64("curTS", curTS)) + return dbterror.ErrWaitReorgTimeout + } + // Since job is cancelled,we don't care about its partial counts. + if rc.isReorgCanceled() || terror.ErrorEqual(err, dbterror.ErrCancelledDDLJob) { + d.removeReorgCtx(job.ID) + return dbterror.ErrCancelledDDLJob + } + rowCount := rc.getRowCount() + job.SetRowCount(rowCount) + if err != nil { + logutil.DDLLogger().Warn("run reorg job done", zap.Int64("handled rows", rowCount), zap.Error(err)) + } else { + logutil.DDLLogger().Info("run reorg job done", zap.Int64("handled rows", rowCount)) + } + + // Update a job's warnings. 
+ w.mergeWarningsIntoJob(job) + + d.removeReorgCtx(job.ID) + + updateBackfillProgress(w, reorgInfo, tblInfo, rowCount) + + // For other errors, even err is not nil here, we still wait the partial counts to be collected. + // since in the next round, the startKey is brand new which is stored by last time. + if err != nil { + return errors.Trace(err) + } + case <-time.After(waitTimeout): + rowCount := rc.getRowCount() + job.SetRowCount(rowCount) + updateBackfillProgress(w, reorgInfo, tblInfo, rowCount) + + // Update a job's warnings. + w.mergeWarningsIntoJob(job) + + rc.resetWarnings() + + logutil.DDLLogger().Info("run reorg job wait timeout", + zap.Duration("wait time", waitTimeout), + zap.Int64("total added row count", rowCount)) + // If timeout, we will return, check the owner and retry to wait job done again. + return dbterror.ErrWaitReorgTimeout + } + return nil +} + +func overwriteReorgInfoFromGlobalCheckpoint(w *worker, sess *sess.Session, job *model.Job, reorgInfo *reorgInfo) error { + if job.ReorgMeta.ReorgTp != model.ReorgTypeLitMerge { + // Only used for the ingest mode job. + return nil + } + if reorgInfo.mergingTmpIdx { + // Merging the temporary index uses txn mode, so we don't need to consider the checkpoint. + return nil + } + if job.ReorgMeta.IsDistReorg { + // The global checkpoint is not used in distributed tasks. + return nil + } + if w.getReorgCtx(job.ID) != nil { + // We only overwrite from checkpoint when the job runs for the first time on this TiDB instance. + return nil + } + start, pid, err := getImportedKeyFromCheckpoint(sess, job) + if err != nil { + return errors.Trace(err) + } + if pid != reorgInfo.PhysicalTableID { + // Current physical ID does not match checkpoint physical ID. + // Don't overwrite reorgInfo.StartKey. + return nil + } + if len(start) > 0 { + reorgInfo.StartKey = start + } + return nil +} + +func extractElemIDs(r *reorgInfo) []int64 { + elemIDs := make([]int64, 0, len(r.elements)) + for _, elem := range r.elements { + elemIDs = append(elemIDs, elem.ID) + } + return elemIDs +} + +func (w *worker) mergeWarningsIntoJob(job *model.Job) { + rc := w.getReorgCtx(job.ID) + rc.mu.Lock() + partWarnings := rc.mu.warnings + partWarningsCount := rc.mu.warningsCount + rc.mu.Unlock() + warnings, warningsCount := job.GetWarnings() + warnings, warningsCount = mergeWarningsAndWarningsCount(partWarnings, warnings, partWarningsCount, warningsCount) + job.SetWarnings(warnings, warningsCount) +} + +func updateBackfillProgress(w *worker, reorgInfo *reorgInfo, tblInfo *model.TableInfo, + addedRowCount int64) { + if tblInfo == nil { + return + } + progress := float64(0) + if addedRowCount != 0 { + totalCount := getTableTotalCount(w, tblInfo) + if totalCount > 0 { + progress = float64(addedRowCount) / float64(totalCount) + } else { + progress = 0 + } + if progress > 1 { + progress = 1 + } + logutil.DDLLogger().Debug("update progress", + zap.Float64("progress", progress), + zap.Int64("addedRowCount", addedRowCount), + zap.Int64("totalCount", totalCount)) + } + switch reorgInfo.Type { + case model.ActionAddIndex, model.ActionAddPrimaryKey: + var label string + if reorgInfo.mergingTmpIdx { + label = metrics.LblAddIndexMerge + } else { + label = metrics.LblAddIndex + } + metrics.GetBackfillProgressByLabel(label, reorgInfo.SchemaName, tblInfo.Name.String()).Set(progress * 100) + case model.ActionModifyColumn: + metrics.GetBackfillProgressByLabel(metrics.LblModifyColumn, reorgInfo.SchemaName, tblInfo.Name.String()).Set(progress * 100) + case model.ActionReorganizePartition, 
model.ActionRemovePartitioning, + model.ActionAlterTablePartitioning: + metrics.GetBackfillProgressByLabel(metrics.LblReorgPartition, reorgInfo.SchemaName, tblInfo.Name.String()).Set(progress * 100) + } +} + +func getTableTotalCount(w *worker, tblInfo *model.TableInfo) int64 { + var ctx sessionctx.Context + ctx, err := w.sessPool.Get() + if err != nil { + return statistics.PseudoRowCount + } + defer w.sessPool.Put(ctx) + + // `mock.Context` is used in tests, which doesn't support sql exec + if _, ok := ctx.(*mock.Context); ok { + return statistics.PseudoRowCount + } + + executor := ctx.GetRestrictedSQLExecutor() + var rows []chunk.Row + if tblInfo.Partition != nil && len(tblInfo.Partition.DroppingDefinitions) > 0 { + // if Reorganize Partition, only select number of rows from the selected partitions! + defs := tblInfo.Partition.DroppingDefinitions + partIDs := make([]string, 0, len(defs)) + for _, def := range defs { + partIDs = append(partIDs, strconv.FormatInt(def.ID, 10)) + } + sql := "select sum(table_rows) from information_schema.partitions where tidb_partition_id in (%?);" + rows, _, err = executor.ExecRestrictedSQL(w.ctx, nil, sql, strings.Join(partIDs, ",")) + } else { + sql := "select table_rows from information_schema.tables where tidb_table_id=%?;" + rows, _, err = executor.ExecRestrictedSQL(w.ctx, nil, sql, tblInfo.ID) + } + if err != nil { + return statistics.PseudoRowCount + } + if len(rows) != 1 { + return statistics.PseudoRowCount + } + return rows[0].GetInt64(0) +} + +func (dc *ddlCtx) isReorgCancelled(jobID int64) bool { + return dc.getReorgCtx(jobID).isReorgCanceled() +} +func (dc *ddlCtx) isReorgPaused(jobID int64) bool { + return dc.getReorgCtx(jobID).isReorgPaused() +} + +func (dc *ddlCtx) isReorgRunnable(jobID int64, isDistReorg bool) error { + if dc.ctx.Err() != nil { + // Worker is closed. So it can't do the reorganization. + return dbterror.ErrInvalidWorker.GenWithStack("worker is closed") + } + + if dc.isReorgCancelled(jobID) { + // Job is cancelled. So it can't be done. + return dbterror.ErrCancelledDDLJob + } + + if dc.isReorgPaused(jobID) { + logutil.DDLLogger().Warn("job paused by user", zap.String("ID", dc.uuid)) + return dbterror.ErrPausedDDLJob.GenWithStackByArgs(jobID) + } + + // If isDistReorg is true, we needn't check if it is owner. + if isDistReorg { + return nil + } + if !dc.isOwner() { + // If it's not the owner, we will try later, so here just returns an error. + logutil.DDLLogger().Info("DDL is not the DDL owner", zap.String("ID", dc.uuid)) + return errors.Trace(dbterror.ErrNotOwner) + } + return nil +} + +type reorgInfo struct { + *model.Job + + StartKey kv.Key + EndKey kv.Key + d *ddlCtx + first bool + mergingTmpIdx bool + // PhysicalTableID is used for partitioned table. + // DDL reorganize for a partitioned table will handle partitions one by one, + // PhysicalTableID is used to trace the current partition we are handling. + // If the table is not partitioned, PhysicalTableID would be TableID. 
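+ // For example, while reorganizing a single partition of a partitioned table, PhysicalTableID is that
+ // partition's ID even though Job.TableID still refers to the table itself (descriptive note).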
+ PhysicalTableID int64 + dbInfo *model.DBInfo + elements []*meta.Element + currElement *meta.Element +} + +func (r *reorgInfo) NewJobContext() *JobContext { + return r.d.jobContext(r.Job.ID, r.Job.ReorgMeta) +} + +func (r *reorgInfo) String() string { + var isEnabled bool + if ingest.LitInitialized { + _, isEnabled = ingest.LitBackCtxMgr.Load(r.Job.ID) + } + return "CurrElementType:" + string(r.currElement.TypeKey) + "," + + "CurrElementID:" + strconv.FormatInt(r.currElement.ID, 10) + "," + + "StartKey:" + hex.EncodeToString(r.StartKey) + "," + + "EndKey:" + hex.EncodeToString(r.EndKey) + "," + + "First:" + strconv.FormatBool(r.first) + "," + + "PhysicalTableID:" + strconv.FormatInt(r.PhysicalTableID, 10) + "," + + "Ingest mode:" + strconv.FormatBool(isEnabled) +} + +func constructDescTableScanPB(physicalTableID int64, tblInfo *model.TableInfo, handleCols []*model.ColumnInfo) *tipb.Executor { + tblScan := tables.BuildTableScanFromInfos(tblInfo, handleCols) + tblScan.TableId = physicalTableID + tblScan.Desc = true + return &tipb.Executor{Tp: tipb.ExecType_TypeTableScan, TblScan: tblScan} +} + +func constructLimitPB(count uint64) *tipb.Executor { + limitExec := &tipb.Limit{ + Limit: count, + } + return &tipb.Executor{Tp: tipb.ExecType_TypeLimit, Limit: limitExec} +} + +func buildDescTableScanDAG(distSQLCtx *distsqlctx.DistSQLContext, tbl table.PhysicalTable, handleCols []*model.ColumnInfo, limit uint64) (*tipb.DAGRequest, error) { + dagReq := &tipb.DAGRequest{} + _, timeZoneOffset := time.Now().In(time.UTC).Zone() + dagReq.TimeZoneOffset = int64(timeZoneOffset) + for i := range handleCols { + dagReq.OutputOffsets = append(dagReq.OutputOffsets, uint32(i)) + } + dagReq.Flags |= model.FlagInSelectStmt + + tblScanExec := constructDescTableScanPB(tbl.GetPhysicalID(), tbl.Meta(), handleCols) + dagReq.Executors = append(dagReq.Executors, tblScanExec) + dagReq.Executors = append(dagReq.Executors, constructLimitPB(limit)) + distsql.SetEncodeType(distSQLCtx, dagReq) + return dagReq, nil +} + +func getColumnsTypes(columns []*model.ColumnInfo) []*types.FieldType { + colTypes := make([]*types.FieldType, 0, len(columns)) + for _, col := range columns { + colTypes = append(colTypes, &col.FieldType) + } + return colTypes +} + +// buildDescTableScan builds a desc table scan upon tblInfo. +func (dc *ddlCtx) buildDescTableScan(ctx *JobContext, startTS uint64, tbl table.PhysicalTable, + handleCols []*model.ColumnInfo, limit uint64) (distsql.SelectResult, error) { + distSQLCtx := newDefaultReorgDistSQLCtx(dc.store.GetClient()) + dagPB, err := buildDescTableScanDAG(distSQLCtx, tbl, handleCols, limit) + if err != nil { + return nil, errors.Trace(err) + } + var b distsql.RequestBuilder + var builder *distsql.RequestBuilder + var ranges []*ranger.Range + if tbl.Meta().IsCommonHandle { + ranges = ranger.FullNotNullRange() + } else { + ranges = ranger.FullIntRange(false) + } + builder = b.SetHandleRanges(distSQLCtx, tbl.GetPhysicalID(), tbl.Meta().IsCommonHandle, ranges) + builder.SetDAGRequest(dagPB). + SetStartTS(startTS). + SetKeepOrder(true). + SetConcurrency(1). + SetDesc(true). + SetResourceGroupTagger(ctx.getResourceGroupTaggerForTopSQL()). 
+ SetResourceGroupName(ctx.resourceGroupName) + + builder.Request.NotFillCache = true + builder.Request.Priority = kv.PriorityLow + builder.RequestSource.RequestSourceInternal = true + builder.RequestSource.RequestSourceType = ctx.ddlJobSourceType() + + kvReq, err := builder.Build() + if err != nil { + return nil, errors.Trace(err) + } + + result, err := distsql.Select(ctx.ddlJobCtx, distSQLCtx, kvReq, getColumnsTypes(handleCols)) + if err != nil { + return nil, errors.Trace(err) + } + return result, nil +} + +// GetTableMaxHandle gets the max handle of a PhysicalTable. +func (dc *ddlCtx) GetTableMaxHandle(ctx *JobContext, startTS uint64, tbl table.PhysicalTable) (maxHandle kv.Handle, emptyTable bool, err error) { + var handleCols []*model.ColumnInfo + var pkIdx *model.IndexInfo + tblInfo := tbl.Meta() + switch { + case tblInfo.PKIsHandle: + for _, col := range tbl.Meta().Columns { + if mysql.HasPriKeyFlag(col.GetFlag()) { + handleCols = []*model.ColumnInfo{col} + break + } + } + case tblInfo.IsCommonHandle: + pkIdx = tables.FindPrimaryIndex(tblInfo) + cols := tblInfo.Cols() + for _, idxCol := range pkIdx.Columns { + handleCols = append(handleCols, cols[idxCol.Offset]) + } + default: + handleCols = []*model.ColumnInfo{model.NewExtraHandleColInfo()} + } + + // build a desc scan of tblInfo, which limit is 1, we can use it to retrieve the last handle of the table. + result, err := dc.buildDescTableScan(ctx, startTS, tbl, handleCols, 1) + if err != nil { + return nil, false, errors.Trace(err) + } + defer terror.Call(result.Close) + + chk := chunk.New(getColumnsTypes(handleCols), 1, 1) + err = result.Next(ctx.ddlJobCtx, chk) + if err != nil { + return nil, false, errors.Trace(err) + } + + if chk.NumRows() == 0 { + // empty table + return nil, true, nil + } + row := chk.GetRow(0) + if tblInfo.IsCommonHandle { + maxHandle, err = buildCommonHandleFromChunkRow(time.UTC, tblInfo, pkIdx, handleCols, row) + return maxHandle, false, err + } + return kv.IntHandle(row.GetInt64(0)), false, nil +} + +func buildCommonHandleFromChunkRow(loc *time.Location, tblInfo *model.TableInfo, idxInfo *model.IndexInfo, + cols []*model.ColumnInfo, row chunk.Row) (kv.Handle, error) { + fieldTypes := make([]*types.FieldType, 0, len(cols)) + for _, col := range cols { + fieldTypes = append(fieldTypes, &col.FieldType) + } + datumRow := row.GetDatumRow(fieldTypes) + tablecodec.TruncateIndexValues(tblInfo, idxInfo, datumRow) + + var handleBytes []byte + handleBytes, err := codec.EncodeKey(loc, nil, datumRow...) + if err != nil { + return nil, err + } + return kv.NewCommonHandle(handleBytes) +} + +// getTableRange gets the start and end handle of a table (or partition). +func getTableRange(ctx *JobContext, d *ddlCtx, tbl table.PhysicalTable, snapshotVer uint64, priority int) (startHandleKey, endHandleKey kv.Key, err error) { + // Get the start handle of this partition. 
+ err = iterateSnapshotKeys(ctx, d.store, priority, tbl.RecordPrefix(), snapshotVer, nil, nil, + func(_ kv.Handle, rowKey kv.Key, _ []byte) (bool, error) { + startHandleKey = rowKey + return false, nil + }) + if err != nil { + return startHandleKey, endHandleKey, errors.Trace(err) + } + maxHandle, isEmptyTable, err := d.GetTableMaxHandle(ctx, snapshotVer, tbl) + if err != nil { + return startHandleKey, nil, errors.Trace(err) + } + if maxHandle != nil { + endHandleKey = tablecodec.EncodeRecordKey(tbl.RecordPrefix(), maxHandle).Next() + } + if isEmptyTable || endHandleKey.Cmp(startHandleKey) <= 0 { + logutil.DDLLogger().Info("get noop table range", + zap.String("table", fmt.Sprintf("%v", tbl.Meta())), + zap.Int64("table/partition ID", tbl.GetPhysicalID()), + zap.String("start key", hex.EncodeToString(startHandleKey)), + zap.String("end key", hex.EncodeToString(endHandleKey)), + zap.Bool("is empty table", isEmptyTable)) + if startHandleKey == nil { + endHandleKey = nil + } else { + endHandleKey = startHandleKey.Next() + } + } + return +} + +func getValidCurrentVersion(store kv.Storage) (ver kv.Version, err error) { + ver, err = store.CurrentVersion(kv.GlobalTxnScope) + if err != nil { + return ver, errors.Trace(err) + } else if ver.Ver <= 0 { + return ver, dbterror.ErrInvalidStoreVer.GenWithStack("invalid storage current version %d", ver.Ver) + } + return ver, nil +} + +func getReorgInfo(ctx *JobContext, d *ddlCtx, rh *reorgHandler, job *model.Job, dbInfo *model.DBInfo, + tbl table.Table, elements []*meta.Element, mergingTmpIdx bool) (*reorgInfo, error) { + var ( + element *meta.Element + start kv.Key + end kv.Key + pid int64 + info reorgInfo + ) + + if job.SnapshotVer == 0 { + // For the case of the old TiDB version(do not exist the element information) is upgraded to the new TiDB version. + // Third step, we need to remove the element information to make sure we can save the reorganized information to storage. + failpoint.Inject("MockGetIndexRecordErr", func(val failpoint.Value) { + if val.(string) == "addIdxNotOwnerErr" && atomic.CompareAndSwapUint32(&mockNotOwnerErrOnce, 3, 4) { + if err := rh.RemoveReorgElementFailPoint(job); err != nil { + failpoint.Return(nil, errors.Trace(err)) + } + info.first = true + failpoint.Return(&info, nil) + } + }) + + info.first = true + if d.lease > 0 { // Only delay when it's not in test. 
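+ // Leave time for async-commit/1PC transactions started under the previous schema to finish, so the
+ // snapshot read below does not miss their writes (descriptive note; see delayForAsyncCommit).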
+ delayForAsyncCommit() + } + ver, err := getValidCurrentVersion(d.store) + if err != nil { + return nil, errors.Trace(err) + } + tblInfo := tbl.Meta() + pid = tblInfo.ID + var tb table.PhysicalTable + if pi := tblInfo.GetPartitionInfo(); pi != nil { + pid = pi.Definitions[0].ID + tb = tbl.(table.PartitionedTable).GetPartition(pid) + } else { + tb = tbl.(table.PhysicalTable) + } + if mergingTmpIdx { + firstElemTempID := tablecodec.TempIndexPrefix | elements[0].ID + lastElemTempID := tablecodec.TempIndexPrefix | elements[len(elements)-1].ID + start = tablecodec.EncodeIndexSeekKey(pid, firstElemTempID, nil) + end = tablecodec.EncodeIndexSeekKey(pid, lastElemTempID, []byte{255}) + } else { + start, end, err = getTableRange(ctx, d, tb, ver.Ver, job.Priority) + if err != nil { + return nil, errors.Trace(err) + } + } + logutil.DDLLogger().Info("job get table range", + zap.Int64("jobID", job.ID), zap.Int64("physicalTableID", pid), + zap.String("startKey", hex.EncodeToString(start)), + zap.String("endKey", hex.EncodeToString(end))) + + failpoint.Inject("errorUpdateReorgHandle", func() (*reorgInfo, error) { + return &info, errors.New("occur an error when update reorg handle") + }) + err = rh.InitDDLReorgHandle(job, start, end, pid, elements[0]) + if err != nil { + return &info, errors.Trace(err) + } + // Update info should after data persistent. + job.SnapshotVer = ver.Ver + element = elements[0] + } else { + failpoint.Inject("MockGetIndexRecordErr", func(val failpoint.Value) { + // For the case of the old TiDB version(do not exist the element information) is upgraded to the new TiDB version. + // Second step, we need to remove the element information to make sure we can get the error of "ErrDDLReorgElementNotExist". + // However, since "txn.Reset()" will be called later, the reorganized information cannot be saved to storage. + if val.(string) == "addIdxNotOwnerErr" && atomic.CompareAndSwapUint32(&mockNotOwnerErrOnce, 2, 3) { + if err := rh.RemoveReorgElementFailPoint(job); err != nil { + failpoint.Return(nil, errors.Trace(err)) + } + } + }) + + var err error + element, start, end, pid, err = rh.GetDDLReorgHandle(job) + if err != nil { + // If the reorg element doesn't exist, this reorg info should be saved by the older TiDB versions. + // It's compatible with the older TiDB versions. + // We'll try to remove it in the next major TiDB version. + if meta.ErrDDLReorgElementNotExist.Equal(err) { + job.SnapshotVer = 0 + logutil.DDLLogger().Warn("get reorg info, the element does not exist", zap.Stringer("job", job)) + if job.IsCancelling() { + return nil, nil + } + } + return &info, errors.Trace(err) + } + } + info.Job = job + info.d = d + info.StartKey = start + info.EndKey = end + info.PhysicalTableID = pid + info.currElement = element + info.elements = elements + info.mergingTmpIdx = mergingTmpIdx + info.dbInfo = dbInfo + + return &info, nil +} + +func getReorgInfoFromPartitions(ctx *JobContext, d *ddlCtx, rh *reorgHandler, job *model.Job, dbInfo *model.DBInfo, tbl table.PartitionedTable, partitionIDs []int64, elements []*meta.Element) (*reorgInfo, error) { + var ( + element *meta.Element + start kv.Key + end kv.Key + pid int64 + info reorgInfo + ) + if job.SnapshotVer == 0 { + info.first = true + if d.lease > 0 { // Only delay when it's not in test. 
+ delayForAsyncCommit() + } + ver, err := getValidCurrentVersion(d.store) + if err != nil { + return nil, errors.Trace(err) + } + pid = partitionIDs[0] + physTbl := tbl.GetPartition(pid) + + start, end, err = getTableRange(ctx, d, physTbl, ver.Ver, job.Priority) + if err != nil { + return nil, errors.Trace(err) + } + logutil.DDLLogger().Info("job get table range", + zap.Int64("job ID", job.ID), zap.Int64("physical table ID", pid), + zap.String("start key", hex.EncodeToString(start)), + zap.String("end key", hex.EncodeToString(end))) + + err = rh.InitDDLReorgHandle(job, start, end, pid, elements[0]) + if err != nil { + return &info, errors.Trace(err) + } + // Update info should after data persistent. + job.SnapshotVer = ver.Ver + element = elements[0] + } else { + var err error + element, start, end, pid, err = rh.GetDDLReorgHandle(job) + if err != nil { + // If the reorg element doesn't exist, this reorg info should be saved by the older TiDB versions. + // It's compatible with the older TiDB versions. + // We'll try to remove it in the next major TiDB version. + if meta.ErrDDLReorgElementNotExist.Equal(err) { + job.SnapshotVer = 0 + logutil.DDLLogger().Warn("get reorg info, the element does not exist", zap.Stringer("job", job)) + } + return &info, errors.Trace(err) + } + } + info.Job = job + info.d = d + info.StartKey = start + info.EndKey = end + info.PhysicalTableID = pid + info.currElement = element + info.elements = elements + info.dbInfo = dbInfo + + return &info, nil +} + +// UpdateReorgMeta creates a new transaction and updates tidb_ddl_reorg table, +// so the reorg can restart in case of issues. +func (r *reorgInfo) UpdateReorgMeta(startKey kv.Key, pool *sess.Pool) (err error) { + if startKey == nil && r.EndKey == nil { + return nil + } + sctx, err := pool.Get() + if err != nil { + return + } + defer pool.Put(sctx) + + se := sess.NewSession(sctx) + err = se.Begin(context.Background()) + if err != nil { + return + } + rh := newReorgHandler(se) + err = updateDDLReorgHandle(rh.s, r.Job.ID, startKey, r.EndKey, r.PhysicalTableID, r.currElement) + err1 := se.Commit(context.Background()) + if err == nil { + err = err1 + } + return errors.Trace(err) +} + +// reorgHandler is used to handle the reorg information duration reorganization DDL job. +type reorgHandler struct { + s *sess.Session +} + +// NewReorgHandlerForTest creates a new reorgHandler, only used in test. +func NewReorgHandlerForTest(se sessionctx.Context) *reorgHandler { + return newReorgHandler(sess.NewSession(se)) +} + +func newReorgHandler(sess *sess.Session) *reorgHandler { + return &reorgHandler{s: sess} +} + +// InitDDLReorgHandle initializes the job reorganization information. +func (r *reorgHandler) InitDDLReorgHandle(job *model.Job, startKey, endKey kv.Key, physicalTableID int64, element *meta.Element) error { + return initDDLReorgHandle(r.s, job.ID, startKey, endKey, physicalTableID, element) +} + +// RemoveReorgElementFailPoint removes the element of the reorganization information. +func (r *reorgHandler) RemoveReorgElementFailPoint(job *model.Job) error { + return removeReorgElement(r.s, job) +} + +// RemoveDDLReorgHandle removes the job reorganization related handles. +func (r *reorgHandler) RemoveDDLReorgHandle(job *model.Job, elements []*meta.Element) error { + return removeDDLReorgHandle(r.s, job, elements) +} + +// cleanupDDLReorgHandles removes the job reorganization related handles. 
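+// It only performs the cleanup when the job is nil, finished, or synced; for an unfinished job it returns without touching anything.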
+func cleanupDDLReorgHandles(job *model.Job, s *sess.Session) { + if job != nil && !job.IsFinished() && !job.IsSynced() { + // Job is given, but it is neither finished nor synced; do nothing + return + } + + err := cleanDDLReorgHandles(s, job) + if err != nil { + // ignore error, cleanup is not that critical + logutil.DDLLogger().Warn("Failed removing the DDL reorg entry in tidb_ddl_reorg", zap.Stringer("job", job), zap.Error(err)) + } +} + +// GetDDLReorgHandle gets the latest processed DDL reorganize position. +func (r *reorgHandler) GetDDLReorgHandle(job *model.Job) (element *meta.Element, startKey, endKey kv.Key, physicalTableID int64, err error) { + element, startKey, endKey, physicalTableID, err = getDDLReorgHandle(r.s, job) + if err != nil { + return element, startKey, endKey, physicalTableID, err + } + adjustedEndKey := adjustEndKeyAcrossVersion(job, endKey) + return element, startKey, adjustedEndKey, physicalTableID, nil +} + +// #46306 changes the table range from [start_key, end_key] to [start_key, end_key.next). +// For old version TiDB, the semantic is still [start_key, end_key], we need to adjust it in new version TiDB. +func adjustEndKeyAcrossVersion(job *model.Job, endKey kv.Key) kv.Key { + if job.ReorgMeta != nil && job.ReorgMeta.Version == model.ReorgMetaVersion0 { + logutil.DDLLogger().Info("adjust range end key for old version ReorgMetas", + zap.Int64("jobID", job.ID), + zap.Int64("reorgMetaVersion", job.ReorgMeta.Version), + zap.String("endKey", hex.EncodeToString(endKey))) + return endKey.Next() + } + return endKey +} diff --git a/pkg/ddl/rollingback.go b/pkg/ddl/rollingback.go index b72b2ee76d392..459955cc150d8 100644 --- a/pkg/ddl/rollingback.go +++ b/pkg/ddl/rollingback.go @@ -52,11 +52,11 @@ func convertAddIdxJob2RollbackJob( allIndexInfos []*model.IndexInfo, err error, ) (int64, error) { - failpoint.Inject("mockConvertAddIdxJob2RollbackJobError", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockConvertAddIdxJob2RollbackJobError")); _err_ == nil { if val.(bool) { - failpoint.Return(0, errors.New("mock convert add index job to rollback job error")) + return 0, errors.New("mock convert add index job to rollback job error") } - }) + } originalState := allIndexInfos[0].State idxNames := make([]model.CIStr, 0, len(allIndexInfos)) diff --git a/pkg/ddl/rollingback.go__failpoint_stash__ b/pkg/ddl/rollingback.go__failpoint_stash__ new file mode 100644 index 0000000000000..b72b2ee76d392 --- /dev/null +++ b/pkg/ddl/rollingback.go__failpoint_stash__ @@ -0,0 +1,629 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ddl + +import ( + "fmt" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/ddl/ingest" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/util/dbterror" + "go.uber.org/zap" +) + +// UpdateColsNull2NotNull changes the null option of columns of an index. +func UpdateColsNull2NotNull(tblInfo *model.TableInfo, indexInfo *model.IndexInfo) error { + nullCols, err := getNullColInfos(tblInfo, indexInfo) + if err != nil { + return errors.Trace(err) + } + + for _, col := range nullCols { + col.AddFlag(mysql.NotNullFlag) + col.DelFlag(mysql.PreventNullInsertFlag) + } + return nil +} + +func convertAddIdxJob2RollbackJob( + d *ddlCtx, + t *meta.Meta, + job *model.Job, + tblInfo *model.TableInfo, + allIndexInfos []*model.IndexInfo, + err error, +) (int64, error) { + failpoint.Inject("mockConvertAddIdxJob2RollbackJobError", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(0, errors.New("mock convert add index job to rollback job error")) + } + }) + + originalState := allIndexInfos[0].State + idxNames := make([]model.CIStr, 0, len(allIndexInfos)) + ifExists := make([]bool, 0, len(allIndexInfos)) + for _, indexInfo := range allIndexInfos { + if indexInfo.Primary { + nullCols, err := getNullColInfos(tblInfo, indexInfo) + if err != nil { + return 0, errors.Trace(err) + } + for _, col := range nullCols { + // Field PreventNullInsertFlag flag reset. + col.DelFlag(mysql.PreventNullInsertFlag) + } + } + // If add index job rollbacks in write reorganization state, its need to delete all keys which has been added. + // Its work is the same as drop index job do. + // The write reorganization state in add index job that likes write only state in drop index job. + // So the next state is delete only state. + indexInfo.State = model.StateDeleteOnly + idxNames = append(idxNames, indexInfo.Name) + ifExists = append(ifExists, false) + } + + // the second and the third args will be used in onDropIndex. + job.Args = []any{idxNames, ifExists, getPartitionIDs(tblInfo)} + job.SchemaState = model.StateDeleteOnly + ver, err1 := updateVersionAndTableInfo(d, t, job, tblInfo, originalState != model.StateDeleteOnly) + if err1 != nil { + return ver, errors.Trace(err1) + } + job.State = model.JobStateRollingback + // TODO(tangenta): get duplicate column and match index. + err = completeErr(err, allIndexInfos[0]) + if ingest.LitBackCtxMgr != nil { + ingest.LitBackCtxMgr.Unregister(job.ID) + } + return ver, errors.Trace(err) +} + +// convertNotReorgAddIdxJob2RollbackJob converts the add index job that are not started workers to rollingbackJob, +// to rollback add index operations. job.SnapshotVer == 0 indicates the workers are not started. 
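+// If none of the index names recorded in the job arguments can be found in the table info, the job is cancelled directly.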
+func convertNotReorgAddIdxJob2RollbackJob(d *ddlCtx, t *meta.Meta, job *model.Job, occuredErr error) (ver int64, err error) { + defer func() { + if ingest.LitBackCtxMgr != nil { + ingest.LitBackCtxMgr.Unregister(job.ID) + } + }() + schemaID := job.SchemaID + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) + if err != nil { + return ver, errors.Trace(err) + } + + unique := make([]bool, 1) + indexName := make([]model.CIStr, 1) + indexPartSpecifications := make([][]*ast.IndexPartSpecification, 1) + indexOption := make([]*ast.IndexOption, 1) + + err = job.DecodeArgs(&unique[0], &indexName[0], &indexPartSpecifications[0], &indexOption[0]) + if err != nil { + err = job.DecodeArgs(&unique, &indexName, &indexPartSpecifications, &indexOption) + } + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + var indexesInfo []*model.IndexInfo + for _, idxName := range indexName { + indexInfo := tblInfo.FindIndexByName(idxName.L) + if indexInfo != nil { + indexesInfo = append(indexesInfo, indexInfo) + } + } + if len(indexesInfo) == 0 { + job.State = model.JobStateCancelled + return ver, dbterror.ErrCancelledDDLJob + } + return convertAddIdxJob2RollbackJob(d, t, job, tblInfo, indexesInfo, occuredErr) +} + +// rollingbackModifyColumn change the modifying-column job into rolling back state. +// Since modifying column job has two types: normal-type and reorg-type, we should handle it respectively. +// normal-type has only two states: None -> Public +// reorg-type has five states: None -> Delete-only -> Write-only -> Write-org -> Public +func rollingbackModifyColumn(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + if needNotifyAndStopReorgWorker(job) { + // column type change workers are started. we have to ask them to exit. + w.jobLogger(job).Info("run the cancelling DDL job", zap.String("job", job.String())) + d.notifyReorgWorkerJobStateChange(job) + // Give the this kind of ddl one more round to run, the dbterror.ErrCancelledDDLJob should be fetched from the bottom up. + return w.onModifyColumn(d, t, job) + } + _, tblInfo, oldCol, jp, err := getModifyColumnInfo(t, job) + if err != nil { + return ver, err + } + if !needChangeColumnData(oldCol, jp.newCol) { + // Normal-type rolling back + if job.SchemaState == model.StateNone { + // When change null to not null, although state is unchanged with none, the oldCol flag's has been changed to preNullInsertFlag. + // To roll back this kind of normal job, it is necessary to mark the state as JobStateRollingback to restore the old col's flag. + if jp.modifyColumnTp == mysql.TypeNull && tblInfo.Columns[oldCol.Offset].GetFlag()|mysql.PreventNullInsertFlag != 0 { + job.State = model.JobStateRollingback + return ver, dbterror.ErrCancelledDDLJob + } + // Normal job with stateNone can be cancelled directly. + job.State = model.JobStateCancelled + return ver, dbterror.ErrCancelledDDLJob + } + // StatePublic couldn't be cancelled. + job.State = model.JobStateRunning + return ver, nil + } + // reorg-type rolling back + if jp.changingCol == nil { + // The job hasn't been handled and we cancel it directly. + job.State = model.JobStateCancelled + return ver, dbterror.ErrCancelledDDLJob + } + // The job has been in its middle state (but the reorg worker hasn't started) and we roll it back here. 
+ job.State = model.JobStateRollingback + return ver, dbterror.ErrCancelledDDLJob +} + +func rollingbackAddColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + tblInfo, columnInfo, col, _, _, err := checkAddColumn(t, job) + if err != nil { + return ver, errors.Trace(err) + } + if columnInfo == nil { + job.State = model.JobStateCancelled + return ver, dbterror.ErrCancelledDDLJob + } + + originalState := columnInfo.State + columnInfo.State = model.StateDeleteOnly + job.SchemaState = model.StateDeleteOnly + + job.Args = []any{col.Name} + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != columnInfo.State) + if err != nil { + return ver, errors.Trace(err) + } + + job.State = model.JobStateRollingback + return ver, dbterror.ErrCancelledDDLJob +} + +func rollingbackDropColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + _, colInfo, idxInfos, _, err := checkDropColumn(d, t, job) + if err != nil { + return ver, errors.Trace(err) + } + + for _, indexInfo := range idxInfos { + switch indexInfo.State { + case model.StateWriteOnly, model.StateDeleteOnly, model.StateDeleteReorganization, model.StateNone: + // We can not rollback now, so just continue to drop index. + // In function isJobRollbackable will let job rollback when state is StateNone. + // When there is no index related to the drop column job it is OK, but when there has indices, we should + // make sure the job is not rollback. + job.State = model.JobStateRunning + return ver, nil + case model.StatePublic: + default: + return ver, dbterror.ErrInvalidDDLState.GenWithStackByArgs("index", indexInfo.State) + } + } + + // StatePublic means when the job is not running yet. + if colInfo.State == model.StatePublic { + job.State = model.JobStateCancelled + return ver, dbterror.ErrCancelledDDLJob + } + // In the state of drop column `write only -> delete only -> reorganization`, + // We can not rollback now, so just continue to drop column. + job.State = model.JobStateRunning + return ver, nil +} + +func rollingbackDropIndex(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + _, indexInfo, _, err := checkDropIndex(d, t, job) + if err != nil { + return ver, errors.Trace(err) + } + + switch indexInfo[0].State { + case model.StateWriteOnly, model.StateDeleteOnly, model.StateDeleteReorganization, model.StateNone: + // We can not rollback now, so just continue to drop index. + // Normally won't fetch here, because there is check when cancel ddl jobs. see function: isJobRollbackable. + job.State = model.JobStateRunning + return ver, nil + case model.StatePublic: + job.State = model.JobStateCancelled + return ver, dbterror.ErrCancelledDDLJob + default: + return ver, dbterror.ErrInvalidDDLState.GenWithStackByArgs("index", indexInfo[0].State) + } +} + +func rollingbackAddIndex(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job, isPK bool) (ver int64, err error) { + if needNotifyAndStopReorgWorker(job) { + // add index workers are started. need to ask them to exit. + w.jobLogger(job).Info("run the cancelling DDL job", zap.String("job", job.String())) + d.notifyReorgWorkerJobStateChange(job) + ver, err = w.onCreateIndex(d, t, job, isPK) + } else { + // add index's reorg workers are not running, remove the indexInfo in tableInfo. 
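+ // convertNotReorgAddIdxJob2RollbackJob also unregisters the job from the ingest backend manager before it returns.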
+ ver, err = convertNotReorgAddIdxJob2RollbackJob(d, t, job, dbterror.ErrCancelledDDLJob) + } + return +} + +func needNotifyAndStopReorgWorker(job *model.Job) bool { + if job.SchemaState == model.StateWriteReorganization && job.SnapshotVer != 0 { + // If the value of SnapshotVer isn't zero, it means the reorg workers have been started. + if job.MultiSchemaInfo != nil { + // However, if the sub-job is non-revertible, it means the reorg process is finished. + // We don't need to start another round to notify reorg workers to exit. + return job.MultiSchemaInfo.Revertible + } + return true + } + return false +} + +// rollbackExchangeTablePartition will clear the non-partitioned +// table's ExchangePartitionInfo state. +func rollbackExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Job, tblInfo *model.TableInfo) (ver int64, err error) { + tblInfo.ExchangePartitionInfo = nil + job.State = model.JobStateRollbackDone + job.SchemaState = model.StatePublic + if len(tblInfo.Constraints) == 0 { + return updateVersionAndTableInfo(d, t, job, tblInfo, true) + } + var ( + defID int64 + ptSchemaID int64 + ptID int64 + partName string + withValidation bool + ) + if err = job.DecodeArgs(&defID, &ptSchemaID, &ptID, &partName, &withValidation); err != nil { + return ver, errors.Trace(err) + } + pt, err := getTableInfo(t, ptID, ptSchemaID) + if err != nil { + return ver, errors.Trace(err) + } + pt.ExchangePartitionInfo = nil + var ptInfo []schemaIDAndTableInfo + ptInfo = append(ptInfo, schemaIDAndTableInfo{ + schemaID: ptSchemaID, + tblInfo: pt, + }) + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true, ptInfo...) + return ver, errors.Trace(err) +} + +func rollingbackExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + if job.SchemaState == model.StateNone { + // Nothing is changed + job.State = model.JobStateCancelled + return ver, dbterror.ErrCancelledDDLJob + } + var nt *model.TableInfo + nt, err = GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) + if err != nil { + return ver, errors.Trace(err) + } + ver, err = rollbackExchangeTablePartition(d, t, job, nt) + return ver, errors.Trace(err) +} + +func convertAddTablePartitionJob2RollbackJob(d *ddlCtx, t *meta.Meta, job *model.Job, otherwiseErr error, tblInfo *model.TableInfo) (ver int64, err error) { + addingDefinitions := tblInfo.Partition.AddingDefinitions + partNames := make([]string, 0, len(addingDefinitions)) + for _, pd := range addingDefinitions { + partNames = append(partNames, pd.Name.L) + } + if job.Type == model.ActionReorganizePartition || + job.Type == model.ActionAlterTablePartitioning || + job.Type == model.ActionRemovePartitioning { + partInfo := &model.PartitionInfo{} + var pNames []string + err = job.DecodeArgs(&pNames, &partInfo) + if err != nil { + return ver, err + } + job.Args = []any{partNames, partInfo} + } else { + job.Args = []any{partNames} + } + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + job.State = model.JobStateRollingback + return ver, errors.Trace(otherwiseErr) +} + +func rollingbackAddTablePartition(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + tblInfo, _, addingDefinitions, err := checkAddPartition(t, job) + if err != nil { + return ver, errors.Trace(err) + } + // addingDefinitions' len = 0 means the job hasn't reached the replica-only state. 
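+ // In that case there is nothing to roll back, so the job is cancelled directly.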
+ if len(addingDefinitions) == 0 { + job.State = model.JobStateCancelled + return ver, errors.Trace(dbterror.ErrCancelledDDLJob) + } + // addingDefinitions is also in tblInfo, here pass the tblInfo as parameter directly. + return convertAddTablePartitionJob2RollbackJob(d, t, job, dbterror.ErrCancelledDDLJob, tblInfo) +} + +func rollingbackDropTableOrView(t *meta.Meta, job *model.Job) error { + tblInfo, err := checkTableExistAndCancelNonExistJob(t, job, job.SchemaID) + if err != nil { + return errors.Trace(err) + } + // To simplify the rollback logic, cannot be canceled after job start to run. + // Normally won't fetch here, because there is check when cancel ddl jobs. see function: isJobRollbackable. + if tblInfo.State == model.StatePublic { + job.State = model.JobStateCancelled + return dbterror.ErrCancelledDDLJob + } + job.State = model.JobStateRunning + return nil +} + +func rollingbackDropTablePartition(t *meta.Meta, job *model.Job) (ver int64, err error) { + _, err = GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) + if err != nil { + return ver, errors.Trace(err) + } + return cancelOnlyNotHandledJob(job, model.StatePublic) +} + +func rollingbackDropSchema(t *meta.Meta, job *model.Job) error { + dbInfo, err := checkSchemaExistAndCancelNotExistJob(t, job) + if err != nil { + return errors.Trace(err) + } + // To simplify the rollback logic, cannot be canceled after job start to run. + // Normally won't fetch here, because there is check when cancel ddl jobs. see function: isJobRollbackable. + if dbInfo.State == model.StatePublic { + job.State = model.JobStateCancelled + return dbterror.ErrCancelledDDLJob + } + job.State = model.JobStateRunning + return nil +} + +func rollingbackRenameIndex(t *meta.Meta, job *model.Job) (ver int64, err error) { + tblInfo, from, _, err := checkRenameIndex(t, job) + if err != nil { + return ver, errors.Trace(err) + } + // Here rename index is done in a transaction, if the job is not completed, it can be canceled. + idx := tblInfo.FindIndexByName(from.L) + if idx.State == model.StatePublic { + job.State = model.JobStateCancelled + return ver, dbterror.ErrCancelledDDLJob + } + job.State = model.JobStateRunning + return ver, errors.Trace(err) +} + +func cancelOnlyNotHandledJob(job *model.Job, initialState model.SchemaState) (ver int64, err error) { + // We can only cancel the not handled job. + if job.SchemaState == initialState { + job.State = model.JobStateCancelled + return ver, dbterror.ErrCancelledDDLJob + } + + job.State = model.JobStateRunning + + return ver, nil +} + +func rollingbackTruncateTable(t *meta.Meta, job *model.Job) (ver int64, err error) { + _, err = GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) + if err != nil { + return ver, errors.Trace(err) + } + return cancelOnlyNotHandledJob(job, model.StateNone) +} + +func rollingbackReorganizePartition(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + if job.SchemaState == model.StateNone { + job.State = model.JobStateCancelled + return ver, dbterror.ErrCancelledDDLJob + } + + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) + if err != nil { + return ver, errors.Trace(err) + } + + // addingDefinitions is also in tblInfo, here pass the tblInfo as parameter directly. + // TODO: Test this with reorganize partition p1 into (partition p1 ...)! 
+ return convertAddTablePartitionJob2RollbackJob(d, t, job, dbterror.ErrCancelledDDLJob, tblInfo) +} + +func pauseReorgWorkers(w *worker, d *ddlCtx, job *model.Job) (err error) { + if needNotifyAndStopReorgWorker(job) { + w.jobLogger(job).Info("pausing the DDL job", zap.String("job", job.String())) + d.notifyReorgWorkerJobStateChange(job) + } + + return dbterror.ErrPausedDDLJob.GenWithStackByArgs(job.ID) +} + +func convertJob2RollbackJob(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + switch job.Type { + case model.ActionAddColumn: + ver, err = rollingbackAddColumn(d, t, job) + case model.ActionAddIndex: + ver, err = rollingbackAddIndex(w, d, t, job, false) + case model.ActionAddPrimaryKey: + ver, err = rollingbackAddIndex(w, d, t, job, true) + case model.ActionAddTablePartition: + ver, err = rollingbackAddTablePartition(d, t, job) + case model.ActionReorganizePartition, model.ActionRemovePartitioning, + model.ActionAlterTablePartitioning: + ver, err = rollingbackReorganizePartition(d, t, job) + case model.ActionDropColumn: + ver, err = rollingbackDropColumn(d, t, job) + case model.ActionDropIndex, model.ActionDropPrimaryKey: + ver, err = rollingbackDropIndex(d, t, job) + case model.ActionDropTable, model.ActionDropView, model.ActionDropSequence: + err = rollingbackDropTableOrView(t, job) + case model.ActionDropTablePartition: + ver, err = rollingbackDropTablePartition(t, job) + case model.ActionExchangeTablePartition: + ver, err = rollingbackExchangeTablePartition(d, t, job) + case model.ActionDropSchema: + err = rollingbackDropSchema(t, job) + case model.ActionRenameIndex: + ver, err = rollingbackRenameIndex(t, job) + case model.ActionTruncateTable: + ver, err = rollingbackTruncateTable(t, job) + case model.ActionModifyColumn: + ver, err = rollingbackModifyColumn(w, d, t, job) + case model.ActionDropForeignKey, model.ActionTruncateTablePartition: + ver, err = cancelOnlyNotHandledJob(job, model.StatePublic) + case model.ActionRebaseAutoID, model.ActionShardRowID, model.ActionAddForeignKey, + model.ActionRenameTable, model.ActionRenameTables, + model.ActionModifyTableCharsetAndCollate, + model.ActionModifySchemaCharsetAndCollate, model.ActionRepairTable, + model.ActionModifyTableAutoIdCache, model.ActionAlterIndexVisibility, + model.ActionModifySchemaDefaultPlacement, model.ActionRecoverSchema: + ver, err = cancelOnlyNotHandledJob(job, model.StateNone) + case model.ActionMultiSchemaChange: + err = rollingBackMultiSchemaChange(job) + case model.ActionAddCheckConstraint: + ver, err = rollingBackAddConstraint(d, t, job) + case model.ActionDropCheckConstraint: + ver, err = rollingBackDropConstraint(t, job) + case model.ActionAlterCheckConstraint: + ver, err = rollingBackAlterConstraint(d, t, job) + default: + job.State = model.JobStateCancelled + err = dbterror.ErrCancelledDDLJob + } + + logger := w.jobLogger(job) + if err != nil { + if job.Error == nil { + job.Error = toTError(err) + } + job.ErrorCount++ + + if dbterror.ErrCancelledDDLJob.Equal(err) { + // The job is normally cancelled. + if !job.Error.Equal(dbterror.ErrCancelledDDLJob) { + job.Error = terror.GetErrClass(job.Error).Synthesize(terror.ErrCode(job.Error.Code()), + fmt.Sprintf("DDL job rollback, error msg: %s", terror.ToSQLError(job.Error).Message)) + } + } else { + // A job canceling meet other error. + // + // Once `convertJob2RollbackJob` meets an error, the job state can't be set as `JobStateRollingback` since + // job state and args may not be correctly overwritten. 
The job will be fetched to run with the cancelling + // state again. So we should check the error count here. + if err1 := loadDDLVars(w); err1 != nil { + logger.Error("load DDL global variable failed", zap.Error(err1)) + } + errorCount := variable.GetDDLErrorCountLimit() + if job.ErrorCount > errorCount { + logger.Warn("rollback DDL job error count exceed the limit, cancelled it now", zap.Int64("errorCountLimit", errorCount)) + job.Error = toTError(errors.Errorf("rollback DDL job error count exceed the limit %d, cancelled it now", errorCount)) + job.State = model.JobStateCancelled + } + } + + if !(job.State != model.JobStateRollingback && job.State != model.JobStateCancelled) { + logger.Info("the DDL job is cancelled normally", zap.String("job", job.String()), zap.Error(err)) + // If job is cancelled, we shouldn't return an error. + return ver, nil + } + logger.Error("run DDL job failed", zap.String("job", job.String()), zap.Error(err)) + } + + return +} + +func rollingBackAddConstraint(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + _, tblInfo, constrInfoInMeta, _, err := checkAddCheckConstraint(t, job) + if err != nil { + return ver, errors.Trace(err) + } + if constrInfoInMeta == nil { + // Add constraint hasn't stored constraint info into meta, so we can cancel the job + // directly without further rollback action. + job.State = model.JobStateCancelled + return ver, dbterror.ErrCancelledDDLJob + } + for i, constr := range tblInfo.Constraints { + if constr.Name.L == constrInfoInMeta.Name.L { + tblInfo.Constraints = append(tblInfo.Constraints[0:i], tblInfo.Constraints[i+1:]...) + break + } + } + if job.IsRollingback() { + job.State = model.JobStateRollbackDone + } + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + return ver, errors.Trace(err) +} + +func rollingBackDropConstraint(t *meta.Meta, job *model.Job) (ver int64, err error) { + _, constrInfoInMeta, err := checkDropCheckConstraint(t, job) + if err != nil { + return ver, errors.Trace(err) + } + + // StatePublic means when the job is not running yet. + if constrInfoInMeta.State == model.StatePublic { + job.State = model.JobStateCancelled + return ver, dbterror.ErrCancelledDDLJob + } + // Can not rollback like drop other element, so just continue to drop constraint. + job.State = model.JobStateRunning + return ver, nil +} + +func rollingBackAlterConstraint(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + _, tblInfo, constraintInfo, enforced, err := checkAlterCheckConstraint(t, job) + if err != nil { + return ver, errors.Trace(err) + } + + // StatePublic means when the job is not running yet. + if constraintInfo.State == model.StatePublic { + job.State = model.JobStateCancelled + return ver, dbterror.ErrCancelledDDLJob + } + + // Only alter check constraints ENFORCED can get here. + constraintInfo.Enforced = !enforced + constraintInfo.State = model.StatePublic + if job.IsRollingback() { + job.State = model.JobStateRollbackDone + } + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) + return ver, errors.Trace(err) +} diff --git a/pkg/ddl/schema_version.go b/pkg/ddl/schema_version.go index 96fbdad427d81..782c6397e124b 100644 --- a/pkg/ddl/schema_version.go +++ b/pkg/ddl/schema_version.go @@ -362,14 +362,14 @@ func updateSchemaVersion(d *ddlCtx, t *meta.Meta, job *model.Job, multiInfos ... 
} func checkAllVersions(ctx context.Context, d *ddlCtx, job *model.Job, latestSchemaVersion int64, timeStart time.Time) error { - failpoint.Inject("checkDownBeforeUpdateGlobalVersion", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("checkDownBeforeUpdateGlobalVersion")); _err_ == nil { if val.(bool) { if mockDDLErrOnce > 0 && mockDDLErrOnce != latestSchemaVersion { panic("check down before update global version failed") } mockDDLErrOnce = -1 } - }) + } // OwnerCheckAllVersions returns only when all TiDB schemas are synced(exclude the isolated TiDB). err := d.schemaSyncer.OwnerCheckAllVersions(ctx, job.ID, latestSchemaVersion) @@ -405,14 +405,14 @@ func waitSchemaSynced(ctx context.Context, d *ddlCtx, job *model.Job) error { return err } - failpoint.Inject("checkDownBeforeUpdateGlobalVersion", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("checkDownBeforeUpdateGlobalVersion")); _err_ == nil { if val.(bool) { if mockDDLErrOnce > 0 && mockDDLErrOnce != latestSchemaVersion { panic("check down before update global version failed") } mockDDLErrOnce = -1 } - }) + } return waitSchemaChanged(ctx, d, latestSchemaVersion, job) } diff --git a/pkg/ddl/schema_version.go__failpoint_stash__ b/pkg/ddl/schema_version.go__failpoint_stash__ new file mode 100644 index 0000000000000..96fbdad427d81 --- /dev/null +++ b/pkg/ddl/schema_version.go__failpoint_stash__ @@ -0,0 +1,418 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ddl + +import ( + "context" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/ddl/logutil" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/util/mathutil" + "go.uber.org/zap" +) + +// SetSchemaDiffForCreateTables set SchemaDiff for ActionCreateTables. +func SetSchemaDiffForCreateTables(diff *model.SchemaDiff, job *model.Job) error { + var tableInfos []*model.TableInfo + err := job.DecodeArgs(&tableInfos) + if err != nil { + return errors.Trace(err) + } + diff.AffectedOpts = make([]*model.AffectedOption, len(tableInfos)) + for i := range tableInfos { + diff.AffectedOpts[i] = &model.AffectedOption{ + SchemaID: job.SchemaID, + OldSchemaID: job.SchemaID, + TableID: tableInfos[i].ID, + OldTableID: tableInfos[i].ID, + } + } + return nil +} + +// SetSchemaDiffForTruncateTable set SchemaDiff for ActionTruncateTable. +func SetSchemaDiffForTruncateTable(diff *model.SchemaDiff, job *model.Job) error { + // Truncate table has two table ID, should be handled differently. 
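+ // The new table ID comes from the job arguments, while the old one is the job's own TableID.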
+ err := job.DecodeArgs(&diff.TableID) + if err != nil { + return errors.Trace(err) + } + diff.OldTableID = job.TableID + + // affects are used to update placement rule cache + if len(job.CtxVars) > 0 { + oldIDs := job.CtxVars[0].([]int64) + newIDs := job.CtxVars[1].([]int64) + diff.AffectedOpts = buildPlacementAffects(oldIDs, newIDs) + } + return nil +} + +// SetSchemaDiffForCreateView set SchemaDiff for ActionCreateView. +func SetSchemaDiffForCreateView(diff *model.SchemaDiff, job *model.Job) error { + tbInfo := &model.TableInfo{} + var orReplace bool + var oldTbInfoID int64 + if err := job.DecodeArgs(tbInfo, &orReplace, &oldTbInfoID); err != nil { + return errors.Trace(err) + } + // When the statement is "create or replace view " and we need to drop the old view, + // it has two table IDs and should be handled differently. + if oldTbInfoID > 0 && orReplace { + diff.OldTableID = oldTbInfoID + } + diff.TableID = tbInfo.ID + return nil +} + +// SetSchemaDiffForRenameTable set SchemaDiff for ActionRenameTable. +func SetSchemaDiffForRenameTable(diff *model.SchemaDiff, job *model.Job) error { + err := job.DecodeArgs(&diff.OldSchemaID) + if err != nil { + return errors.Trace(err) + } + diff.TableID = job.TableID + return nil +} + +// SetSchemaDiffForRenameTables set SchemaDiff for ActionRenameTables. +func SetSchemaDiffForRenameTables(diff *model.SchemaDiff, job *model.Job) error { + var ( + oldSchemaIDs, newSchemaIDs, tableIDs []int64 + tableNames, oldSchemaNames []*model.CIStr + ) + err := job.DecodeArgs(&oldSchemaIDs, &newSchemaIDs, &tableNames, &tableIDs, &oldSchemaNames) + if err != nil { + return errors.Trace(err) + } + affects := make([]*model.AffectedOption, len(newSchemaIDs)-1) + for i, newSchemaID := range newSchemaIDs { + // Do not add the first table to AffectedOpts. Related issue tidb#47064. + if i == 0 { + continue + } + affects[i-1] = &model.AffectedOption{ + SchemaID: newSchemaID, + TableID: tableIDs[i], + OldTableID: tableIDs[i], + OldSchemaID: oldSchemaIDs[i], + } + } + diff.TableID = tableIDs[0] + diff.SchemaID = newSchemaIDs[0] + diff.OldSchemaID = oldSchemaIDs[0] + diff.AffectedOpts = affects + return nil +} + +// SetSchemaDiffForExchangeTablePartition set SchemaDiff for ActionExchangeTablePartition. +func SetSchemaDiffForExchangeTablePartition(diff *model.SchemaDiff, job *model.Job, multiInfos ...schemaIDAndTableInfo) error { + // From start of function: diff.SchemaID = job.SchemaID + // Old is original non partitioned table + diff.OldTableID = job.TableID + diff.OldSchemaID = job.SchemaID + // Update the partitioned table (it is only done in the last state) + var ( + ptSchemaID int64 + ptTableID int64 + ptDefID int64 + partName string // Not used + withValidation bool // Not used + ) + // See ddl.ExchangeTablePartition + err := job.DecodeArgs(&ptDefID, &ptSchemaID, &ptTableID, &partName, &withValidation) + if err != nil { + return errors.Trace(err) + } + // This is needed for not crashing TiFlash! + // TODO: Update TiFlash, to handle StateWriteOnly + diff.AffectedOpts = []*model.AffectedOption{{ + TableID: ptTableID, + }} + if job.SchemaState != model.StatePublic { + // No change, just to refresh the non-partitioned table + // with its new ExchangePartitionInfo. + diff.TableID = job.TableID + // Keep this as Schema ID of non-partitioned table + // to avoid trigger early rename in TiFlash + diff.AffectedOpts[0].SchemaID = job.SchemaID + // Need reload partition table, use diff.AffectedOpts[0].OldSchemaID to mark it. 
+ if len(multiInfos) > 0 { + diff.AffectedOpts[0].OldSchemaID = ptSchemaID + } + } else { + // Swap + diff.TableID = ptDefID + // Also add correct SchemaID in case different schemas + diff.AffectedOpts[0].SchemaID = ptSchemaID + } + return nil +} + +// SetSchemaDiffForTruncateTablePartition set SchemaDiff for ActionTruncateTablePartition. +func SetSchemaDiffForTruncateTablePartition(diff *model.SchemaDiff, job *model.Job) { + diff.TableID = job.TableID + if len(job.CtxVars) > 0 { + oldIDs := job.CtxVars[0].([]int64) + newIDs := job.CtxVars[1].([]int64) + diff.AffectedOpts = buildPlacementAffects(oldIDs, newIDs) + } +} + +// SetSchemaDiffForDropTable set SchemaDiff for ActionDropTablePartition, ActionRecoverTable, ActionDropTable. +func SetSchemaDiffForDropTable(diff *model.SchemaDiff, job *model.Job) { + // affects are used to update placement rule cache + diff.TableID = job.TableID + if len(job.CtxVars) > 0 { + if oldIDs, ok := job.CtxVars[0].([]int64); ok { + diff.AffectedOpts = buildPlacementAffects(oldIDs, oldIDs) + } + } +} + +// SetSchemaDiffForReorganizePartition set SchemaDiff for ActionReorganizePartition. +func SetSchemaDiffForReorganizePartition(diff *model.SchemaDiff, job *model.Job) { + diff.TableID = job.TableID + // TODO: should this be for every state of Reorganize? + if len(job.CtxVars) > 0 { + if droppedIDs, ok := job.CtxVars[0].([]int64); ok { + if addedIDs, ok := job.CtxVars[1].([]int64); ok { + // to use AffectedOpts we need both new and old to have the same length + maxParts := mathutil.Max[int](len(droppedIDs), len(addedIDs)) + // Also initialize them to 0! + oldIDs := make([]int64, maxParts) + copy(oldIDs, droppedIDs) + newIDs := make([]int64, maxParts) + copy(newIDs, addedIDs) + diff.AffectedOpts = buildPlacementAffects(oldIDs, newIDs) + } + } + } +} + +// SetSchemaDiffForPartitionModify set SchemaDiff for ActionRemovePartitioning, ActionAlterTablePartitioning. +func SetSchemaDiffForPartitionModify(diff *model.SchemaDiff, job *model.Job) error { + diff.TableID = job.TableID + diff.OldTableID = job.TableID + if job.SchemaState == model.StateDeleteReorganization { + partInfo := &model.PartitionInfo{} + var partNames []string + err := job.DecodeArgs(&partNames, &partInfo) + if err != nil { + return errors.Trace(err) + } + // Final part, new table id is assigned + diff.TableID = partInfo.NewTableID + if len(job.CtxVars) > 0 { + if droppedIDs, ok := job.CtxVars[0].([]int64); ok { + if addedIDs, ok := job.CtxVars[1].([]int64); ok { + // to use AffectedOpts we need both new and old to have the same length + maxParts := mathutil.Max[int](len(droppedIDs), len(addedIDs)) + // Also initialize them to 0! + oldIDs := make([]int64, maxParts) + copy(oldIDs, droppedIDs) + newIDs := make([]int64, maxParts) + copy(newIDs, addedIDs) + diff.AffectedOpts = buildPlacementAffects(oldIDs, newIDs) + } + } + } + } + return nil +} + +// SetSchemaDiffForCreateTable set SchemaDiff for ActionCreateTable. +func SetSchemaDiffForCreateTable(diff *model.SchemaDiff, job *model.Job) { + diff.TableID = job.TableID + if len(job.Args) > 0 { + tbInfo, _ := job.Args[0].(*model.TableInfo) + // When create table with foreign key, there are two schema status change: + // 1. none -> write-only + // 2. write-only -> public + // In the second status change write-only -> public, infoschema loader should apply drop old table first, then + // apply create new table. So need to set diff.OldTableID here to make sure it. 
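+ // Tables without foreign keys are created in a single schema state change, so they keep OldTableID as zero.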
+ if tbInfo != nil && tbInfo.State == model.StatePublic && len(tbInfo.ForeignKeys) > 0 { + diff.OldTableID = job.TableID + } + } +} + +// SetSchemaDiffForRecoverSchema set SchemaDiff for ActionRecoverSchema. +func SetSchemaDiffForRecoverSchema(diff *model.SchemaDiff, job *model.Job) error { + var ( + recoverSchemaInfo *RecoverSchemaInfo + recoverSchemaCheckFlag int64 + ) + err := job.DecodeArgs(&recoverSchemaInfo, &recoverSchemaCheckFlag) + if err != nil { + return errors.Trace(err) + } + // Reserved recoverSchemaCheckFlag value for gc work judgment. + job.Args[checkFlagIndexInJobArgs] = recoverSchemaCheckFlag + recoverTabsInfo := recoverSchemaInfo.RecoverTabsInfo + diff.AffectedOpts = make([]*model.AffectedOption, len(recoverTabsInfo)) + for i := range recoverTabsInfo { + diff.AffectedOpts[i] = &model.AffectedOption{ + SchemaID: job.SchemaID, + OldSchemaID: job.SchemaID, + TableID: recoverTabsInfo[i].TableInfo.ID, + OldTableID: recoverTabsInfo[i].TableInfo.ID, + } + } + diff.ReadTableFromMeta = true + return nil +} + +// SetSchemaDiffForFlashbackCluster set SchemaDiff for ActionFlashbackCluster. +func SetSchemaDiffForFlashbackCluster(diff *model.SchemaDiff, job *model.Job) { + diff.TableID = -1 + if job.SchemaState == model.StatePublic { + diff.RegenerateSchemaMap = true + } +} + +// SetSchemaDiffForMultiInfos set SchemaDiff for multiInfos. +func SetSchemaDiffForMultiInfos(diff *model.SchemaDiff, multiInfos ...schemaIDAndTableInfo) { + if len(multiInfos) > 0 { + existsMap := make(map[int64]struct{}) + existsMap[diff.TableID] = struct{}{} + for _, affect := range diff.AffectedOpts { + existsMap[affect.TableID] = struct{}{} + } + for _, info := range multiInfos { + _, exist := existsMap[info.tblInfo.ID] + if exist { + continue + } + existsMap[info.tblInfo.ID] = struct{}{} + diff.AffectedOpts = append(diff.AffectedOpts, &model.AffectedOption{ + SchemaID: info.schemaID, + OldSchemaID: info.schemaID, + TableID: info.tblInfo.ID, + OldTableID: info.tblInfo.ID, + }) + } + } +} + +// updateSchemaVersion increments the schema version by 1 and sets SchemaDiff. +func updateSchemaVersion(d *ddlCtx, t *meta.Meta, job *model.Job, multiInfos ...schemaIDAndTableInfo) (int64, error) { + schemaVersion, err := d.setSchemaVersion(job, d.store) + if err != nil { + return 0, errors.Trace(err) + } + diff := &model.SchemaDiff{ + Version: schemaVersion, + Type: job.Type, + SchemaID: job.SchemaID, + } + switch job.Type { + case model.ActionCreateTables: + err = SetSchemaDiffForCreateTables(diff, job) + case model.ActionTruncateTable: + err = SetSchemaDiffForTruncateTable(diff, job) + case model.ActionCreateView: + err = SetSchemaDiffForCreateView(diff, job) + case model.ActionRenameTable: + err = SetSchemaDiffForRenameTable(diff, job) + case model.ActionRenameTables: + err = SetSchemaDiffForRenameTables(diff, job) + case model.ActionExchangeTablePartition: + err = SetSchemaDiffForExchangeTablePartition(diff, job, multiInfos...) 
+ case model.ActionTruncateTablePartition: + SetSchemaDiffForTruncateTablePartition(diff, job) + case model.ActionDropTablePartition, model.ActionRecoverTable, model.ActionDropTable: + SetSchemaDiffForDropTable(diff, job) + case model.ActionReorganizePartition: + SetSchemaDiffForReorganizePartition(diff, job) + case model.ActionRemovePartitioning, model.ActionAlterTablePartitioning: + err = SetSchemaDiffForPartitionModify(diff, job) + case model.ActionCreateTable: + SetSchemaDiffForCreateTable(diff, job) + case model.ActionRecoverSchema: + err = SetSchemaDiffForRecoverSchema(diff, job) + case model.ActionFlashbackCluster: + SetSchemaDiffForFlashbackCluster(diff, job) + default: + diff.TableID = job.TableID + } + if err != nil { + return 0, err + } + SetSchemaDiffForMultiInfos(diff, multiInfos...) + err = t.SetSchemaDiff(diff) + return schemaVersion, errors.Trace(err) +} + +func checkAllVersions(ctx context.Context, d *ddlCtx, job *model.Job, latestSchemaVersion int64, timeStart time.Time) error { + failpoint.Inject("checkDownBeforeUpdateGlobalVersion", func(val failpoint.Value) { + if val.(bool) { + if mockDDLErrOnce > 0 && mockDDLErrOnce != latestSchemaVersion { + panic("check down before update global version failed") + } + mockDDLErrOnce = -1 + } + }) + + // OwnerCheckAllVersions returns only when all TiDB schemas are synced(exclude the isolated TiDB). + err := d.schemaSyncer.OwnerCheckAllVersions(ctx, job.ID, latestSchemaVersion) + if err != nil { + logutil.DDLLogger().Info("wait latest schema version encounter error", zap.Int64("ver", latestSchemaVersion), + zap.Int64("jobID", job.ID), zap.Duration("take time", time.Since(timeStart)), zap.Error(err)) + return err + } + logutil.DDLLogger().Info("wait latest schema version changed(get the metadata lock if tidb_enable_metadata_lock is true)", + zap.Int64("ver", latestSchemaVersion), + zap.Duration("take time", time.Since(timeStart)), + zap.String("job", job.String())) + return nil +} + +// waitSchemaSynced handles the following situation: +// If the job enters a new state, and the worker crash when it's in the process of +// version sync, then the worker restarts quickly, we may run the job immediately again, +// but schema version might not sync. +// So here we get the latest schema version to make sure all servers' schema version +// update to the latest schema version in a cluster. 
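+// Jobs that are not running, rolling back, done, or rollback-done return immediately without waiting.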
+func waitSchemaSynced(ctx context.Context, d *ddlCtx, job *model.Job) error { + if !job.IsRunning() && !job.IsRollingback() && !job.IsDone() && !job.IsRollbackDone() { + return nil + } + + ver, _ := d.store.CurrentVersion(kv.GlobalTxnScope) + snapshot := d.store.GetSnapshot(ver) + m := meta.NewSnapshotMeta(snapshot) + latestSchemaVersion, err := m.GetSchemaVersionWithNonEmptyDiff() + if err != nil { + logutil.DDLLogger().Warn("get global version failed", zap.Int64("jobID", job.ID), zap.Error(err)) + return err + } + + failpoint.Inject("checkDownBeforeUpdateGlobalVersion", func(val failpoint.Value) { + if val.(bool) { + if mockDDLErrOnce > 0 && mockDDLErrOnce != latestSchemaVersion { + panic("check down before update global version failed") + } + mockDDLErrOnce = -1 + } + }) + + return waitSchemaChanged(ctx, d, latestSchemaVersion, job) +} diff --git a/pkg/ddl/session/binding__failpoint_binding__.go b/pkg/ddl/session/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..9ef59b452261c --- /dev/null +++ b/pkg/ddl/session/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package session + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/ddl/session/session.go b/pkg/ddl/session/session.go index bc16de5b0c484..17ebdf16bae53 100644 --- a/pkg/ddl/session/session.go +++ b/pkg/ddl/session/session.go @@ -109,7 +109,7 @@ func (s *Session) RunInTxn(f func(*Session) error) (err error) { if err != nil { return err } - failpoint.Inject("NotifyBeginTxnCh", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("NotifyBeginTxnCh")); _err_ == nil { //nolint:forcetypeassert v := val.(int) if v == 1 { @@ -119,7 +119,7 @@ func (s *Session) RunInTxn(f func(*Session) error) (err error) { <-TestNotifyBeginTxnCh MockDDLOnce = 0 } - }) + } err = f(s) if err != nil { diff --git a/pkg/ddl/session/session.go__failpoint_stash__ b/pkg/ddl/session/session.go__failpoint_stash__ new file mode 100644 index 0000000000000..bc16de5b0c484 --- /dev/null +++ b/pkg/ddl/session/session.go__failpoint_stash__ @@ -0,0 +1,137 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/sqlexec" +) + +// Session wraps sessionctx.Context for transaction usage. +type Session struct { + sessionctx.Context +} + +// NewSession creates a new Session. 
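+//
+// A minimal usage sketch (sctx, ctx, and the query below are illustrative placeholders, not part of this package):
+//
+//	se := NewSession(sctx)
+//	err := se.RunInTxn(func(s *Session) error {
+//		_, err := s.Execute(ctx, "select 1", "example")
+//		return err
+//	})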
+func NewSession(s sessionctx.Context) *Session { + return &Session{s} +} + +// Begin starts a transaction. +func (s *Session) Begin(ctx context.Context) error { + err := sessiontxn.NewTxn(ctx, s.Context) + if err != nil { + return err + } + s.GetSessionVars().SetInTxn(true) + return nil +} + +// Commit commits the transaction. +func (s *Session) Commit(ctx context.Context) error { + s.StmtCommit(ctx) + return s.CommitTxn(ctx) +} + +// Txn activate and returns the current transaction. +func (s *Session) Txn() (kv.Transaction, error) { + return s.Context.Txn(true) +} + +// Rollback aborts the transaction. +func (s *Session) Rollback() { + s.StmtRollback(context.Background(), false) + s.RollbackTxn(context.Background()) +} + +// Reset resets the session. +func (s *Session) Reset() { + s.StmtRollback(context.Background(), false) +} + +// Execute executes a query. +func (s *Session) Execute(ctx context.Context, query string, label string) ([]chunk.Row, error) { + startTime := time.Now() + var err error + defer func() { + metrics.DDLJobTableDuration.WithLabelValues(label + "-" + metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) + }() + + if ctx.Value(kv.RequestSourceKey) == nil { + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnDDL) + } + rs, err := s.Context.GetSQLExecutor().ExecuteInternal(ctx, query) + if err != nil { + return nil, errors.Trace(err) + } + + if rs == nil { + return nil, nil + } + var rows []chunk.Row + defer terror.Call(rs.Close) + if rows, err = sqlexec.DrainRecordSet(ctx, rs, 8); err != nil { + return nil, errors.Trace(err) + } + return rows, nil +} + +// Session returns the sessionctx.Context. +func (s *Session) Session() sessionctx.Context { + return s.Context +} + +// RunInTxn runs a function in a transaction. +func (s *Session) RunInTxn(f func(*Session) error) (err error) { + err = s.Begin(context.Background()) + if err != nil { + return err + } + failpoint.Inject("NotifyBeginTxnCh", func(val failpoint.Value) { + //nolint:forcetypeassert + v := val.(int) + if v == 1 { + MockDDLOnce = 1 + TestNotifyBeginTxnCh <- struct{}{} + } else if v == 2 && MockDDLOnce == 1 { + <-TestNotifyBeginTxnCh + MockDDLOnce = 0 + } + }) + + err = f(s) + if err != nil { + s.Rollback() + return + } + return errors.Trace(s.Commit(context.Background())) +} + +var ( + // MockDDLOnce is only used for test. + MockDDLOnce = int64(0) + // TestNotifyBeginTxnCh is used for if the txn is beginning in RunInTxn. + TestNotifyBeginTxnCh = make(chan struct{}) +) diff --git a/pkg/ddl/syncer/binding__failpoint_binding__.go b/pkg/ddl/syncer/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..3db8a06874be4 --- /dev/null +++ b/pkg/ddl/syncer/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package syncer + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/ddl/syncer/syncer.go b/pkg/ddl/syncer/syncer.go index 818d4747eed06..d9e34b6fdf5c4 100644 --- a/pkg/ddl/syncer/syncer.go +++ b/pkg/ddl/syncer/syncer.go @@ -273,12 +273,12 @@ func (s *schemaVersionSyncer) storeSession(session *concurrency.Session) { // Done implements SchemaSyncer.Done interface. 
func (s *schemaVersionSyncer) Done() <-chan struct{} { - failpoint.Inject("ErrorMockSessionDone", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("ErrorMockSessionDone")); _err_ == nil { if val.(bool) { err := s.loadSession().Close() logutil.DDLLogger().Error("close session failed", zap.Error(err)) } - }) + } return s.loadSession().Done() } @@ -530,9 +530,9 @@ func (s *schemaVersionSyncer) syncJobSchemaVer(ctx context.Context) { return } } - failpoint.Inject("mockCompaction", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockCompaction")); _err_ == nil { wresp.CompactRevision = 123 - }) + } if err := wresp.Err(); err != nil { logutil.DDLLogger().Warn("watch job version failed", zap.Error(err)) return diff --git a/pkg/ddl/syncer/syncer.go__failpoint_stash__ b/pkg/ddl/syncer/syncer.go__failpoint_stash__ new file mode 100644 index 0000000000000..818d4747eed06 --- /dev/null +++ b/pkg/ddl/syncer/syncer.go__failpoint_stash__ @@ -0,0 +1,629 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package syncer + +import ( + "context" + "fmt" + "math" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + "unsafe" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/ddl/logutil" + "github.com/pingcap/tidb/pkg/ddl/util" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + tidbutil "github.com/pingcap/tidb/pkg/util" + disttaskutil "github.com/pingcap/tidb/pkg/util/disttask" + "go.etcd.io/etcd/api/v3/mvccpb" + clientv3 "go.etcd.io/etcd/client/v3" + "go.etcd.io/etcd/client/v3/concurrency" + "go.uber.org/zap" +) + +const ( + // InitialVersion is the initial schema version for every server. + // It's exported for testing. + InitialVersion = "0" + putKeyNoRetry = 1 + keyOpDefaultRetryCnt = 3 + putKeyRetryUnlimited = math.MaxInt64 + checkVersInterval = 20 * time.Millisecond + ddlPrompt = "ddl-syncer" +) + +var ( + // CheckVersFirstWaitTime is a waitting time before the owner checks all the servers of the schema version, + // and it's an exported variable for testing. + CheckVersFirstWaitTime = 50 * time.Millisecond +) + +// Watcher is responsible for watching the etcd path related operations. +type Watcher interface { + // WatchChan returns the chan for watching etcd path. + WatchChan() clientv3.WatchChan + // Watch watches the etcd path. + Watch(ctx context.Context, etcdCli *clientv3.Client, path string) + // Rewatch rewatches the etcd path. + Rewatch(ctx context.Context, etcdCli *clientv3.Client, path string) +} + +type watcher struct { + sync.RWMutex + wCh clientv3.WatchChan +} + +// WatchChan implements SyncerWatch.WatchChan interface. +func (w *watcher) WatchChan() clientv3.WatchChan { + w.RLock() + defer w.RUnlock() + return w.wCh +} + +// Watch implements SyncerWatch.Watch interface. 
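+// The new watch channel is stored under the write lock, pairing with the read lock taken in WatchChan.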
+func (w *watcher) Watch(ctx context.Context, etcdCli *clientv3.Client, path string) { + w.Lock() + w.wCh = etcdCli.Watch(ctx, path) + w.Unlock() +} + +// Rewatch implements SyncerWatch.Rewatch interface. +func (w *watcher) Rewatch(ctx context.Context, etcdCli *clientv3.Client, path string) { + startTime := time.Now() + // Make sure the wCh doesn't receive the information of 'close' before we finish the rewatch. + w.Lock() + w.wCh = nil + w.Unlock() + + go func() { + defer func() { + metrics.DeploySyncerHistogram.WithLabelValues(metrics.SyncerRewatch, metrics.RetLabel(nil)).Observe(time.Since(startTime).Seconds()) + }() + wCh := etcdCli.Watch(ctx, path) + + w.Lock() + w.wCh = wCh + w.Unlock() + logutil.DDLLogger().Info("syncer rewatch global info finished") + }() +} + +// SchemaSyncer is used to synchronize schema version between the DDL worker leader and followers through etcd. +type SchemaSyncer interface { + // Init sets the global schema version path to etcd if it isn't exist, + // then watch this path, and initializes the self schema version to etcd. + Init(ctx context.Context) error + // UpdateSelfVersion updates the current version to the self path on etcd. + UpdateSelfVersion(ctx context.Context, jobID int64, version int64) error + // OwnerUpdateGlobalVersion updates the latest version to the global path on etcd until updating is successful or the ctx is done. + OwnerUpdateGlobalVersion(ctx context.Context, version int64) error + // GlobalVersionCh gets the chan for watching global version. + GlobalVersionCh() clientv3.WatchChan + // WatchGlobalSchemaVer watches the global schema version. + WatchGlobalSchemaVer(ctx context.Context) + // Done returns a channel that closes when the syncer is no longer being refreshed. + Done() <-chan struct{} + // Restart restarts the syncer when it's on longer being refreshed. + Restart(ctx context.Context) error + // OwnerCheckAllVersions checks whether all followers' schema version are equal to + // the latest schema version. (exclude the isolated TiDB) + // It returns until all servers' versions are equal to the latest version. + OwnerCheckAllVersions(ctx context.Context, jobID int64, latestVer int64) error + // SyncJobSchemaVerLoop syncs the schema versions on all TiDB nodes for DDL jobs. + SyncJobSchemaVerLoop(ctx context.Context) + // Close ends SchemaSyncer. + Close() +} + +// nodeVersions is used to record the schema versions of all TiDB nodes for a DDL job. +type nodeVersions struct { + sync.Mutex + nodeVersions map[string]int64 + // onceMatchFn is used to check if all the servers report the least version. + // If all the servers report the least version, i.e. return true, it will be + // set to nil. + onceMatchFn func(map[string]int64) bool +} + +func newNodeVersions(initialCap int, fn func(map[string]int64) bool) *nodeVersions { + return &nodeVersions{ + nodeVersions: make(map[string]int64, initialCap), + onceMatchFn: fn, + } +} + +func (v *nodeVersions) add(nodeID string, ver int64) { + v.Lock() + defer v.Unlock() + v.nodeVersions[nodeID] = ver + if v.onceMatchFn != nil { + if ok := v.onceMatchFn(v.nodeVersions); ok { + v.onceMatchFn = nil + } + } +} + +func (v *nodeVersions) del(nodeID string) { + v.Lock() + defer v.Unlock() + delete(v.nodeVersions, nodeID) + // we don't call onceMatchFn here, for only "add" can cause onceMatchFn return + // true currently. +} + +func (v *nodeVersions) len() int { + v.Lock() + defer v.Unlock() + return len(v.nodeVersions) +} + +// matchOrSet onceMatchFn must be nil before calling this method. 
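// exampleWaitForTwoNodes is an illustrative sketch, not part of the patch,
// of how the callback registered through newNodeVersions behaves: add stores
// a node's reported version and, while a callback is registered, re-evaluates
// it, dropping it after the first time it returns true. The node IDs and the
// target version 100 are hypothetical; this would sit in the same package
// because the helpers are unexported.
func exampleWaitForTwoNodes() {
	done := make(chan struct{})
	item := newNodeVersions(2, func(vers map[string]int64) bool {
		if vers["tidb-a"] >= 100 && vers["tidb-b"] >= 100 {
			close(done) // every expected node reached the target version
			return true // add() clears onceMatchFn after this
		}
		return false
	})
	item.add("tidb-a", 100) // callback runs: tidb-b missing, no match yet
	item.add("tidb-b", 101) // callback runs again, matches, closes done
	<-done
}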
+func (v *nodeVersions) matchOrSet(fn func(nodeVersions map[string]int64) bool) { + v.Lock() + defer v.Unlock() + if ok := fn(v.nodeVersions); !ok { + v.onceMatchFn = fn + } +} + +func (v *nodeVersions) clearData() { + v.Lock() + defer v.Unlock() + v.nodeVersions = make(map[string]int64, len(v.nodeVersions)) +} + +func (v *nodeVersions) clearMatchFn() { + v.Lock() + defer v.Unlock() + v.onceMatchFn = nil +} + +func (v *nodeVersions) emptyAndNotUsed() bool { + v.Lock() + defer v.Unlock() + return len(v.nodeVersions) == 0 && v.onceMatchFn == nil +} + +// for test +func (v *nodeVersions) getMatchFn() func(map[string]int64) bool { + v.Lock() + defer v.Unlock() + return v.onceMatchFn +} + +type schemaVersionSyncer struct { + selfSchemaVerPath string + etcdCli *clientv3.Client + session unsafe.Pointer + globalVerWatcher watcher + ddlID string + + mu sync.RWMutex + jobNodeVersions map[int64]*nodeVersions + jobNodeVerPrefix string +} + +// NewSchemaSyncer creates a new SchemaSyncer. +func NewSchemaSyncer(etcdCli *clientv3.Client, id string) SchemaSyncer { + return &schemaVersionSyncer{ + etcdCli: etcdCli, + selfSchemaVerPath: fmt.Sprintf("%s/%s", util.DDLAllSchemaVersions, id), + ddlID: id, + + jobNodeVersions: make(map[int64]*nodeVersions), + jobNodeVerPrefix: util.DDLAllSchemaVersionsByJob + "/", + } +} + +// Init implements SchemaSyncer.Init interface. +func (s *schemaVersionSyncer) Init(ctx context.Context) error { + startTime := time.Now() + var err error + defer func() { + metrics.DeploySyncerHistogram.WithLabelValues(metrics.SyncerInit, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) + }() + + _, err = s.etcdCli.Txn(ctx). + If(clientv3.Compare(clientv3.CreateRevision(util.DDLGlobalSchemaVersion), "=", 0)). + Then(clientv3.OpPut(util.DDLGlobalSchemaVersion, InitialVersion)). + Commit() + if err != nil { + return errors.Trace(err) + } + logPrefix := fmt.Sprintf("[%s] %s", ddlPrompt, s.selfSchemaVerPath) + session, err := tidbutil.NewSession(ctx, logPrefix, s.etcdCli, tidbutil.NewSessionDefaultRetryCnt, util.SessionTTL) + if err != nil { + return errors.Trace(err) + } + s.storeSession(session) + + s.globalVerWatcher.Watch(ctx, s.etcdCli, util.DDLGlobalSchemaVersion) + + err = util.PutKVToEtcd(ctx, s.etcdCli, keyOpDefaultRetryCnt, s.selfSchemaVerPath, InitialVersion, + clientv3.WithLease(s.loadSession().Lease())) + return errors.Trace(err) +} + +func (s *schemaVersionSyncer) loadSession() *concurrency.Session { + return (*concurrency.Session)(atomic.LoadPointer(&s.session)) +} + +func (s *schemaVersionSyncer) storeSession(session *concurrency.Session) { + atomic.StorePointer(&s.session, (unsafe.Pointer)(session)) +} + +// Done implements SchemaSyncer.Done interface. +func (s *schemaVersionSyncer) Done() <-chan struct{} { + failpoint.Inject("ErrorMockSessionDone", func(val failpoint.Value) { + if val.(bool) { + err := s.loadSession().Close() + logutil.DDLLogger().Error("close session failed", zap.Error(err)) + } + }) + + return s.loadSession().Done() +} + +// Restart implements SchemaSyncer.Restart interface. +func (s *schemaVersionSyncer) Restart(ctx context.Context) error { + startTime := time.Now() + var err error + defer func() { + metrics.DeploySyncerHistogram.WithLabelValues(metrics.SyncerRestart, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) + }() + + logPrefix := fmt.Sprintf("[%s] %s", ddlPrompt, s.selfSchemaVerPath) + // NewSession's context will affect the exit of the session. 
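	// Illustrative aside, not part of the patch: the self-version key is
	// written with clientv3.WithLease(session.Lease()), so it only lives in
	// etcd while the session lease is kept alive; if the TiDB node dies, the
	// key expires on its own. <-session.Done() (what Done() above returns)
	// signals that the lease was lost, after which this Restart method
	// re-creates the session and re-puts the key below.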
+ session, err := tidbutil.NewSession(ctx, logPrefix, s.etcdCli, tidbutil.NewSessionRetryUnlimited, util.SessionTTL) + if err != nil { + return errors.Trace(err) + } + s.storeSession(session) + + childCtx, cancel := context.WithTimeout(ctx, util.KeyOpDefaultTimeout) + defer cancel() + err = util.PutKVToEtcd(childCtx, s.etcdCli, putKeyRetryUnlimited, s.selfSchemaVerPath, InitialVersion, + clientv3.WithLease(s.loadSession().Lease())) + + return errors.Trace(err) +} + +// GlobalVersionCh implements SchemaSyncer.GlobalVersionCh interface. +func (s *schemaVersionSyncer) GlobalVersionCh() clientv3.WatchChan { + return s.globalVerWatcher.WatchChan() +} + +// WatchGlobalSchemaVer implements SchemaSyncer.WatchGlobalSchemaVer interface. +func (s *schemaVersionSyncer) WatchGlobalSchemaVer(ctx context.Context) { + s.globalVerWatcher.Rewatch(ctx, s.etcdCli, util.DDLGlobalSchemaVersion) +} + +// UpdateSelfVersion implements SchemaSyncer.UpdateSelfVersion interface. +func (s *schemaVersionSyncer) UpdateSelfVersion(ctx context.Context, jobID int64, version int64) error { + startTime := time.Now() + ver := strconv.FormatInt(version, 10) + var err error + var path string + if variable.EnableMDL.Load() { + path = fmt.Sprintf("%s/%d/%s", util.DDLAllSchemaVersionsByJob, jobID, s.ddlID) + err = util.PutKVToEtcdMono(ctx, s.etcdCli, keyOpDefaultRetryCnt, path, ver) + } else { + path = s.selfSchemaVerPath + err = util.PutKVToEtcd(ctx, s.etcdCli, putKeyNoRetry, path, ver, + clientv3.WithLease(s.loadSession().Lease())) + } + + metrics.UpdateSelfVersionHistogram.WithLabelValues(metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) + return errors.Trace(err) +} + +// OwnerUpdateGlobalVersion implements SchemaSyncer.OwnerUpdateGlobalVersion interface. +func (s *schemaVersionSyncer) OwnerUpdateGlobalVersion(ctx context.Context, version int64) error { + startTime := time.Now() + ver := strconv.FormatInt(version, 10) + // TODO: If the version is larger than the original global version, we need set the version. + // Otherwise, we'd better set the original global version. + err := util.PutKVToEtcd(ctx, s.etcdCli, putKeyRetryUnlimited, util.DDLGlobalSchemaVersion, ver) + metrics.OwnerHandleSyncerHistogram.WithLabelValues(metrics.OwnerUpdateGlobalVersion, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) + return errors.Trace(err) +} + +// removeSelfVersionPath remove the self path from etcd. +func (s *schemaVersionSyncer) removeSelfVersionPath() error { + startTime := time.Now() + var err error + defer func() { + metrics.DeploySyncerHistogram.WithLabelValues(metrics.SyncerClear, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) + }() + + err = util.DeleteKeyFromEtcd(s.selfSchemaVerPath, s.etcdCli, keyOpDefaultRetryCnt, util.KeyOpDefaultTimeout) + return errors.Trace(err) +} + +// OwnerCheckAllVersions implements SchemaSyncer.OwnerCheckAllVersions interface. +func (s *schemaVersionSyncer) OwnerCheckAllVersions(ctx context.Context, jobID int64, latestVer int64) error { + startTime := time.Now() + if !variable.EnableMDL.Load() { + time.Sleep(CheckVersFirstWaitTime) + } + notMatchVerCnt := 0 + intervalCnt := int(time.Second / checkVersInterval) + + var err error + defer func() { + metrics.OwnerHandleSyncerHistogram.WithLabelValues(metrics.OwnerCheckAllVersions, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) + }() + + // If MDL is disabled, updatedMap is a cache. We need to ensure all the keys equal to the least version. 
+ // We can skip checking the key if it is checked in the cache(set by the previous loop). + // If MDL is enabled, updatedMap is used to check if all the servers report the least version. + // updatedMap is initialed to record all the server in every loop. We delete a server from the map if it gets the metadata lock(the key version equal the given version. + // updatedMap should be empty if all the servers get the metadata lock. + updatedMap := make(map[string]string) + for { + if err := ctx.Err(); err != nil { + // ctx is canceled or timeout. + return errors.Trace(err) + } + + if variable.EnableMDL.Load() { + serverInfos, err := infosync.GetAllServerInfo(ctx) + if err != nil { + return err + } + updatedMap = make(map[string]string) + instance2id := make(map[string]string) + + for _, info := range serverInfos { + instance := disttaskutil.GenerateExecID(info) + // if some node shutdown abnormally and start, we might see some + // instance with different id, we should use the latest one. + if id, ok := instance2id[instance]; ok { + if info.StartTimestamp > serverInfos[id].StartTimestamp { + // Replace it. + delete(updatedMap, id) + updatedMap[info.ID] = fmt.Sprintf("instance ip %s, port %d, id %s", info.IP, info.Port, info.ID) + instance2id[instance] = info.ID + } + } else { + updatedMap[info.ID] = fmt.Sprintf("instance ip %s, port %d, id %s", info.IP, info.Port, info.ID) + instance2id[instance] = info.ID + } + } + } + + // Check all schema versions. + if variable.EnableMDL.Load() { + notifyCh := make(chan struct{}) + var unmatchedNodeID atomic.Pointer[string] + matchFn := func(nodeVersions map[string]int64) bool { + if len(nodeVersions) < len(updatedMap) { + return false + } + for tidbID := range updatedMap { + if nodeVer, ok := nodeVersions[tidbID]; !ok || nodeVer < latestVer { + id := tidbID + unmatchedNodeID.Store(&id) + return false + } + } + close(notifyCh) + return true + } + item := s.jobSchemaVerMatchOrSet(jobID, matchFn) + select { + case <-notifyCh: + return nil + case <-ctx.Done(): + item.clearMatchFn() + return errors.Trace(ctx.Err()) + case <-time.After(time.Second): + item.clearMatchFn() + if id := unmatchedNodeID.Load(); id != nil { + logutil.DDLLogger().Info("syncer check all versions, someone is not synced", + zap.String("info", *id), + zap.Int64("ddl job id", jobID), + zap.Int64("ver", latestVer)) + } else { + logutil.DDLLogger().Info("syncer check all versions, all nodes are not synced", + zap.Int64("ddl job id", jobID), + zap.Int64("ver", latestVer)) + } + } + } else { + // Get all the schema versions from ETCD. + resp, err := s.etcdCli.Get(ctx, util.DDLAllSchemaVersions, clientv3.WithPrefix()) + if err != nil { + logutil.DDLLogger().Info("syncer check all versions failed, continue checking.", zap.Error(err)) + continue + } + succ := true + for _, kv := range resp.Kvs { + if _, ok := updatedMap[string(kv.Key)]; ok { + continue + } + + succ = isUpdatedLatestVersion(string(kv.Key), string(kv.Value), latestVer, notMatchVerCnt, intervalCnt, true) + if !succ { + break + } + updatedMap[string(kv.Key)] = "" + } + + if succ { + return nil + } + time.Sleep(checkVersInterval) + notMatchVerCnt++ + } + } +} + +// SyncJobSchemaVerLoop implements SchemaSyncer.SyncJobSchemaVerLoop interface. 
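// allNodesSynced is a compact, illustrative restatement, not part of the
// patch, of the predicate the MDL branch above builds as matchFn: a schema
// version counts as synced only when every node the owner expects to be
// alive has reported at least latestVer for this job. The names here are
// hypothetical.
func allNodesSynced(expected map[string]struct{}, reported map[string]int64, latestVer int64) bool {
	if len(reported) < len(expected) {
		return false // some expected node has not reported at all yet
	}
	for id := range expected {
		if ver, ok := reported[id]; !ok || ver < latestVer {
			return false // a missing or stale report blocks the DDL job
		}
	}
	return true
}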
+func (s *schemaVersionSyncer) SyncJobSchemaVerLoop(ctx context.Context) { + for { + s.syncJobSchemaVer(ctx) + logutil.DDLLogger().Info("schema version sync loop interrupted, retrying...") + select { + case <-ctx.Done(): + return + case <-time.After(time.Second): + } + } +} + +func (s *schemaVersionSyncer) syncJobSchemaVer(ctx context.Context) { + resp, err := s.etcdCli.Get(ctx, s.jobNodeVerPrefix, clientv3.WithPrefix()) + if err != nil { + logutil.DDLLogger().Info("get all job versions failed", zap.Error(err)) + return + } + s.mu.Lock() + for jobID, item := range s.jobNodeVersions { + item.clearData() + // we might miss some DELETE events during retry, some items might be emptyAndNotUsed, remove them. + if item.emptyAndNotUsed() { + delete(s.jobNodeVersions, jobID) + } + } + s.mu.Unlock() + for _, oneKV := range resp.Kvs { + s.handleJobSchemaVerKV(oneKV, mvccpb.PUT) + } + + startRev := resp.Header.Revision + 1 + watchCtx, watchCtxCancel := context.WithCancel(ctx) + defer watchCtxCancel() + watchCtx = clientv3.WithRequireLeader(watchCtx) + watchCh := s.etcdCli.Watch(watchCtx, s.jobNodeVerPrefix, clientv3.WithPrefix(), clientv3.WithRev(startRev)) + for { + var ( + wresp clientv3.WatchResponse + ok bool + ) + select { + case <-watchCtx.Done(): + return + case wresp, ok = <-watchCh: + if !ok { + // ctx must be cancelled, else we should have received a response + // with err and caught by below err check. + return + } + } + failpoint.Inject("mockCompaction", func() { + wresp.CompactRevision = 123 + }) + if err := wresp.Err(); err != nil { + logutil.DDLLogger().Warn("watch job version failed", zap.Error(err)) + return + } + for _, ev := range wresp.Events { + s.handleJobSchemaVerKV(ev.Kv, ev.Type) + } + } +} + +func (s *schemaVersionSyncer) handleJobSchemaVerKV(kv *mvccpb.KeyValue, tp mvccpb.Event_EventType) { + jobID, tidbID, schemaVer, valid := decodeJobVersionEvent(kv, tp, s.jobNodeVerPrefix) + if !valid { + logutil.DDLLogger().Error("invalid job version kv", zap.Stringer("kv", kv), zap.Stringer("type", tp)) + return + } + if tp == mvccpb.PUT { + s.mu.Lock() + item, exists := s.jobNodeVersions[jobID] + if !exists { + item = newNodeVersions(1, nil) + s.jobNodeVersions[jobID] = item + } + s.mu.Unlock() + item.add(tidbID, schemaVer) + } else { // DELETE + s.mu.Lock() + if item, exists := s.jobNodeVersions[jobID]; exists { + item.del(tidbID) + if item.len() == 0 { + delete(s.jobNodeVersions, jobID) + } + } + s.mu.Unlock() + } +} + +func (s *schemaVersionSyncer) jobSchemaVerMatchOrSet(jobID int64, matchFn func(map[string]int64) bool) *nodeVersions { + s.mu.Lock() + defer s.mu.Unlock() + + item, exists := s.jobNodeVersions[jobID] + if exists { + item.matchOrSet(matchFn) + } else { + item = newNodeVersions(1, matchFn) + s.jobNodeVersions[jobID] = item + } + return item +} + +func decodeJobVersionEvent(kv *mvccpb.KeyValue, tp mvccpb.Event_EventType, prefix string) (jobID int64, tidbID string, schemaVer int64, valid bool) { + left := strings.TrimPrefix(string(kv.Key), prefix) + parts := strings.Split(left, "/") + if len(parts) != 2 { + return 0, "", 0, false + } + jobID, err := strconv.ParseInt(parts[0], 10, 64) + if err != nil { + return 0, "", 0, false + } + // there is no Value in DELETE event, so we need to check it. 
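	// Illustrative aside, not part of the patch: keys under jobNodeVerPrefix
	// have the shape <prefix><jobID>/<tidbID> with the schema version as the
	// value, mirroring what UpdateSelfVersion writes. For example, a
	// hypothetical PUT of ".../53/uuid-of-node-1" = "112" decodes here to
	// jobID=53, tidbID="uuid-of-node-1", schemaVer=112; DELETE events carry
	// no value, hence the guard below.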
+ if tp == mvccpb.PUT { + schemaVer, err = strconv.ParseInt(string(kv.Value), 10, 64) + if err != nil { + return 0, "", 0, false + } + } + return jobID, parts[1], schemaVer, true +} + +func isUpdatedLatestVersion(key, val string, latestVer int64, notMatchVerCnt, intervalCnt int, nodeAlive bool) bool { + ver, err := strconv.Atoi(val) + if err != nil { + logutil.DDLLogger().Info("syncer check all versions, convert value to int failed, continue checking.", + zap.String("ddl", key), zap.String("value", val), zap.Error(err)) + return false + } + if int64(ver) < latestVer && nodeAlive { + if notMatchVerCnt%intervalCnt == 0 { + logutil.DDLLogger().Info("syncer check all versions, someone is not synced, continue checking", + zap.String("ddl", key), zap.Int("currentVer", ver), zap.Int64("latestVer", latestVer)) + } + return false + } + return true +} + +func (s *schemaVersionSyncer) Close() { + err := s.removeSelfVersionPath() + if err != nil { + logutil.DDLLogger().Error("remove self version path failed", zap.Error(err)) + } +} diff --git a/pkg/ddl/table.go b/pkg/ddl/table.go index 4d4696b43339c..bcc0a0bb5d91c 100644 --- a/pkg/ddl/table.go +++ b/pkg/ddl/table.go @@ -273,14 +273,14 @@ func (w *worker) recoverTable(t *meta.Meta, job *model.Job, recoverInfo *Recover return ver, errors.Trace(err) } - failpoint.Inject("mockRecoverTableCommitErr", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockRecoverTableCommitErr")); _err_ == nil { if val.(bool) && atomic.CompareAndSwapUint32(&mockRecoverTableCommitErrOnce, 0, 1) { err = failpoint.Enable(`tikvclient/mockCommitErrorOpt`, "return(true)") if err != nil { return } } - }) + } err = updateLabelRules(job, recoverInfo.TableInfo, oldRules, tableRuleID, partRuleIDs, oldRuleIDs, recoverInfo.TableInfo.ID) if err != nil { @@ -292,9 +292,9 @@ func (w *worker) recoverTable(t *meta.Meta, job *model.Job, recoverInfo *Recover } func clearTablePlacementAndBundles(ctx context.Context, tblInfo *model.TableInfo) error { - failpoint.Inject("mockClearTablePlacementAndBundlesErr", func() { - failpoint.Return(errors.New("mock error for clearTablePlacementAndBundles")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("mockClearTablePlacementAndBundlesErr")); _err_ == nil { + return errors.New("mock error for clearTablePlacementAndBundles") + } var bundles []*placement.Bundle if tblInfo.PlacementPolicyRef != nil { tblInfo.PlacementPolicyRef = nil @@ -459,12 +459,12 @@ func (w *worker) onTruncateTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver i job.State = model.JobStateCancelled return ver, errors.Trace(err) } - failpoint.Inject("truncateTableErr", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("truncateTableErr")); _err_ == nil { if val.(bool) { job.State = model.JobStateCancelled - failpoint.Return(ver, errors.New("occur an error after dropping table")) + return ver, errors.New("occur an error after dropping table") } - }) + } // Clear the TiFlash replica progress from ETCD. 
if tblInfo.TiFlashReplica != nil { @@ -552,11 +552,11 @@ func (w *worker) onTruncateTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver i return ver, errors.Trace(err) } - failpoint.Inject("mockTruncateTableUpdateVersionError", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockTruncateTableUpdateVersionError")); _err_ == nil { if val.(bool) { - failpoint.Return(ver, errors.New("mock update version error")) + return ver, errors.New("mock update version error") } - }) + } var partitions []model.PartitionDefinition if pi := tblInfo.GetPartitionInfo(); pi != nil { @@ -829,14 +829,14 @@ func checkAndRenameTables(t *meta.Meta, job *model.Job, tblInfo *model.TableInfo return ver, errors.Trace(err) } - failpoint.Inject("renameTableErr", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("renameTableErr")); _err_ == nil { if valStr, ok := val.(string); ok { if tableName.L == valStr { job.State = model.JobStateCancelled - failpoint.Return(ver, errors.New("occur an error after renaming table")) + return ver, errors.New("occur an error after renaming table") } } - }) + } oldTableName := tblInfo.Name tableRuleID, partRuleIDs, oldRuleIDs, oldRules, err := getOldLabelRules(tblInfo, oldSchemaName.L, oldTableName.L) @@ -1291,18 +1291,18 @@ func updateVersionAndTableInfoWithCheck(d *ddlCtx, t *meta.Meta, job *model.Job, // updateVersionAndTableInfo updates the schema version and the table information. func updateVersionAndTableInfo(d *ddlCtx, t *meta.Meta, job *model.Job, tblInfo *model.TableInfo, shouldUpdateVer bool, multiInfos ...schemaIDAndTableInfo) ( ver int64, err error) { - failpoint.Inject("mockUpdateVersionAndTableInfoErr", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockUpdateVersionAndTableInfoErr")); _err_ == nil { switch val.(int) { case 1: - failpoint.Return(ver, errors.New("mock update version and tableInfo error")) + return ver, errors.New("mock update version and tableInfo error") case 2: // We change it cancelled directly here, because we want to get the original error with the job id appended. // The job ID will be used to get the job from history queue and we will assert it's args. job.State = model.JobStateCancelled - failpoint.Return(ver, errors.New("mock update version and tableInfo error, jobID="+strconv.Itoa(int(job.ID)))) + return ver, errors.New("mock update version and tableInfo error, jobID=" + strconv.Itoa(int(job.ID))) default: } - }) + } if shouldUpdateVer && (job.MultiSchemaInfo == nil || !job.MultiSchemaInfo.SkipVersion) { ver, err = updateSchemaVersion(d, t, job, multiInfos...) if err != nil { diff --git a/pkg/ddl/table.go__failpoint_stash__ b/pkg/ddl/table.go__failpoint_stash__ new file mode 100644 index 0000000000000..4d4696b43339c --- /dev/null +++ b/pkg/ddl/table.go__failpoint_stash__ @@ -0,0 +1,1681 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ddl + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/ddl/label" + "github.com/pingcap/tidb/pkg/ddl/logutil" + "github.com/pingcap/tidb/pkg/ddl/placement" + sess "github.com/pingcap/tidb/pkg/ddl/session" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/model" + field_types "github.com/pingcap/tidb/pkg/parser/types" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + statsutil "github.com/pingcap/tidb/pkg/statistics/handle/util" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/tablecodec" + tidb_util "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/gcutil" + "go.uber.org/zap" +) + +const tiflashCheckTiDBHTTPAPIHalfInterval = 2500 * time.Millisecond + +func repairTableOrViewWithCheck(t *meta.Meta, job *model.Job, schemaID int64, tbInfo *model.TableInfo) error { + err := checkTableInfoValid(tbInfo) + if err != nil { + job.State = model.JobStateCancelled + return errors.Trace(err) + } + return t.UpdateTable(schemaID, tbInfo) +} + +func onDropTableOrView(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + tblInfo, err := checkTableExistAndCancelNonExistJob(t, job, job.SchemaID) + if err != nil { + return ver, errors.Trace(err) + } + + originalState := job.SchemaState + switch tblInfo.State { + case model.StatePublic: + // public -> write only + if job.Type == model.ActionDropTable { + err = checkDropTableHasForeignKeyReferredInOwner(d, t, job) + if err != nil { + return ver, err + } + } + tblInfo.State = model.StateWriteOnly + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != tblInfo.State) + if err != nil { + return ver, errors.Trace(err) + } + case model.StateWriteOnly: + // write only -> delete only + tblInfo.State = model.StateDeleteOnly + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != tblInfo.State) + if err != nil { + return ver, errors.Trace(err) + } + case model.StateDeleteOnly: + tblInfo.State = model.StateNone + oldIDs := getPartitionIDs(tblInfo) + ruleIDs := append(getPartitionRuleIDs(job.SchemaName, tblInfo), fmt.Sprintf(label.TableIDFormat, label.IDPrefix, job.SchemaName, tblInfo.Name.L)) + job.CtxVars = []any{oldIDs} + + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != tblInfo.State) + if err != nil { + return ver, errors.Trace(err) + } + if tblInfo.IsSequence() { + if err = t.DropSequence(job.SchemaID, job.TableID); err != nil { + return ver, errors.Trace(err) + } + } else { + if err = t.DropTableOrView(job.SchemaID, job.TableID); err != nil { + return ver, errors.Trace(err) + } + if err = t.GetAutoIDAccessors(job.SchemaID, job.TableID).Del(); err != nil { + return ver, errors.Trace(err) + } + } + if tblInfo.TiFlashReplica != nil { + e := infosync.DeleteTiFlashTableSyncProgress(tblInfo) + if e != nil { + logutil.DDLLogger().Error("DeleteTiFlashTableSyncProgress fails", zap.Error(e)) + } + } + // Placement rules cannot be removed immediately after drop table / truncate table, because the + // tables can be flashed back or recovered, 
therefore it moved to doGCPlacementRules in gc_worker.go. + + // Finish this job. + job.FinishTableJob(model.JobStateDone, model.StateNone, ver, tblInfo) + startKey := tablecodec.EncodeTablePrefix(job.TableID) + job.Args = append(job.Args, startKey, oldIDs, ruleIDs) + if !tblInfo.IsSequence() && !tblInfo.IsView() { + dropTableEvent := statsutil.NewDropTableEvent( + job.SchemaID, + tblInfo, + ) + asyncNotifyEvent(d, dropTableEvent) + } + default: + return ver, errors.Trace(dbterror.ErrInvalidDDLState.GenWithStackByArgs("table", tblInfo.State)) + } + job.SchemaState = tblInfo.State + return ver, errors.Trace(err) +} + +func (w *worker) onRecoverTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + var ( + recoverInfo *RecoverInfo + recoverTableCheckFlag int64 + ) + if err = job.DecodeArgs(&recoverInfo, &recoverTableCheckFlag); err != nil { + // Invalid arguments, cancel this job. + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + schemaID := recoverInfo.SchemaID + tblInfo := recoverInfo.TableInfo + if tblInfo.TTLInfo != nil { + // force disable TTL job schedule for recovered table + tblInfo.TTLInfo.Enable = false + } + + // check GC and safe point + gcEnable, err := checkGCEnable(w) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + err = checkTableNotExists(d, schemaID, tblInfo.Name.L) + if err != nil { + if infoschema.ErrDatabaseNotExists.Equal(err) || infoschema.ErrTableExists.Equal(err) { + job.State = model.JobStateCancelled + } + return ver, errors.Trace(err) + } + + err = checkTableIDNotExists(t, schemaID, tblInfo.ID) + if err != nil { + if infoschema.ErrDatabaseNotExists.Equal(err) || infoschema.ErrTableExists.Equal(err) { + job.State = model.JobStateCancelled + } + return ver, errors.Trace(err) + } + + // Recover table divide into 2 steps: + // 1. Check GC enable status, to decided whether enable GC after recover table. + // a. Why not disable GC before put the job to DDL job queue? + // Think about concurrency problem. If a recover job-1 is doing and already disabled GC, + // then, another recover table job-2 check GC enable will get disable before into the job queue. + // then, after recover table job-2 finished, the GC will be disabled. + // b. Why split into 2 steps? 1 step also can finish this job: check GC -> disable GC -> recover table -> finish job. + // What if the transaction commit failed? then, the job will retry, but the GC already disabled when first running. + // So, after this job retry succeed, the GC will be disabled. + // 2. Do recover table job. + // a. Check whether GC enabled, if enabled, disable GC first. + // b. Check GC safe point. If drop table time if after safe point time, then can do recover. + // otherwise, can't recover table, because the records of the table may already delete by gc. + // c. Remove GC task of the table from gc_delete_range table. + // d. Create table and rebase table auto ID. + // e. Finish. + switch tblInfo.State { + case model.StateNone: + // none -> write only + // check GC enable and update flag. + if gcEnable { + job.Args[checkFlagIndexInJobArgs] = recoverCheckFlagEnableGC + } else { + job.Args[checkFlagIndexInJobArgs] = recoverCheckFlagDisableGC + } + + job.SchemaState = model.StateWriteOnly + tblInfo.State = model.StateWriteOnly + case model.StateWriteOnly: + // write only -> public + // do recover table. 
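		// Illustrative aside, not part of the patch: the flag recorded in the
		// StateNone step is what keeps retries of this step correct. If this
		// transaction fails to commit after GC has been disabled here, the
		// retry still learns the original GC setting from the job args
		// instead of re-reading the live switch that the first attempt
		// already flipped.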
+ if gcEnable { + err = disableGC(w) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Errorf("disable gc failed, try again later. err: %v", err) + } + } + // check GC safe point + err = checkSafePoint(w, recoverInfo.SnapshotTS) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + ver, err = w.recoverTable(t, job, recoverInfo) + if err != nil { + return ver, errors.Trace(err) + } + tableInfo := tblInfo.Clone() + tableInfo.State = model.StatePublic + tableInfo.UpdateTS = t.StartTS + ver, err = updateVersionAndTableInfo(d, t, job, tableInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + tblInfo.State = model.StatePublic + tblInfo.UpdateTS = t.StartTS + // Finish this job. + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + default: + return ver, dbterror.ErrInvalidDDLState.GenWithStackByArgs("table", tblInfo.State) + } + return ver, nil +} + +func (w *worker) recoverTable(t *meta.Meta, job *model.Job, recoverInfo *RecoverInfo) (ver int64, err error) { + var tids []int64 + if recoverInfo.TableInfo.GetPartitionInfo() != nil { + tids = getPartitionIDs(recoverInfo.TableInfo) + tids = append(tids, recoverInfo.TableInfo.ID) + } else { + tids = []int64{recoverInfo.TableInfo.ID} + } + tableRuleID, partRuleIDs, oldRuleIDs, oldRules, err := getOldLabelRules(recoverInfo.TableInfo, recoverInfo.OldSchemaName, recoverInfo.OldTableName) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Wrapf(err, "failed to get old label rules from PD") + } + // Remove dropped table DDL job from gc_delete_range table. + err = w.delRangeManager.removeFromGCDeleteRange(w.ctx, recoverInfo.DropJobID) + if err != nil { + return ver, errors.Trace(err) + } + err = clearTablePlacementAndBundles(w.ctx, recoverInfo.TableInfo) + if err != nil { + return ver, errors.Trace(err) + } + + tableInfo := recoverInfo.TableInfo.Clone() + tableInfo.State = model.StatePublic + tableInfo.UpdateTS = t.StartTS + err = t.CreateTableAndSetAutoID(recoverInfo.SchemaID, tableInfo, recoverInfo.AutoIDs) + if err != nil { + return ver, errors.Trace(err) + } + + failpoint.Inject("mockRecoverTableCommitErr", func(val failpoint.Value) { + if val.(bool) && atomic.CompareAndSwapUint32(&mockRecoverTableCommitErrOnce, 0, 1) { + err = failpoint.Enable(`tikvclient/mockCommitErrorOpt`, "return(true)") + if err != nil { + return + } + } + }) + + err = updateLabelRules(job, recoverInfo.TableInfo, oldRules, tableRuleID, partRuleIDs, oldRuleIDs, recoverInfo.TableInfo.ID) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Wrapf(err, "failed to update the label rule to PD") + } + job.CtxVars = []any{tids} + return ver, nil +} + +func clearTablePlacementAndBundles(ctx context.Context, tblInfo *model.TableInfo) error { + failpoint.Inject("mockClearTablePlacementAndBundlesErr", func() { + failpoint.Return(errors.New("mock error for clearTablePlacementAndBundles")) + }) + var bundles []*placement.Bundle + if tblInfo.PlacementPolicyRef != nil { + tblInfo.PlacementPolicyRef = nil + bundles = append(bundles, placement.NewBundle(tblInfo.ID)) + } + + if tblInfo.Partition != nil { + for i := range tblInfo.Partition.Definitions { + par := &tblInfo.Partition.Definitions[i] + if par.PlacementPolicyRef != nil { + par.PlacementPolicyRef = nil + bundles = append(bundles, placement.NewBundle(par.ID)) + } + } + } + + if len(bundles) == 0 { + return nil + } + + return infosync.PutRuleBundlesWithDefaultRetry(ctx, bundles) +} + +// 
mockRecoverTableCommitErrOnce uses to make sure +// `mockRecoverTableCommitErr` only mock error once. +var mockRecoverTableCommitErrOnce uint32 + +func enableGC(w *worker) error { + ctx, err := w.sessPool.Get() + if err != nil { + return errors.Trace(err) + } + defer w.sessPool.Put(ctx) + + return gcutil.EnableGC(ctx) +} + +func disableGC(w *worker) error { + ctx, err := w.sessPool.Get() + if err != nil { + return errors.Trace(err) + } + defer w.sessPool.Put(ctx) + + return gcutil.DisableGC(ctx) +} + +func checkGCEnable(w *worker) (enable bool, err error) { + ctx, err := w.sessPool.Get() + if err != nil { + return false, errors.Trace(err) + } + defer w.sessPool.Put(ctx) + + return gcutil.CheckGCEnable(ctx) +} + +func checkSafePoint(w *worker, snapshotTS uint64) error { + ctx, err := w.sessPool.Get() + if err != nil { + return errors.Trace(err) + } + defer w.sessPool.Put(ctx) + + return gcutil.ValidateSnapshot(ctx, snapshotTS) +} + +func getTable(r autoid.Requirement, schemaID int64, tblInfo *model.TableInfo) (table.Table, error) { + allocs := autoid.NewAllocatorsFromTblInfo(r, schemaID, tblInfo) + tbl, err := table.TableFromMeta(allocs, tblInfo) + return tbl, errors.Trace(err) +} + +// GetTableInfoAndCancelFaultJob is exported for test. +func GetTableInfoAndCancelFaultJob(t *meta.Meta, job *model.Job, schemaID int64) (*model.TableInfo, error) { + tblInfo, err := checkTableExistAndCancelNonExistJob(t, job, schemaID) + if err != nil { + return nil, errors.Trace(err) + } + + if tblInfo.State != model.StatePublic { + job.State = model.JobStateCancelled + return nil, dbterror.ErrInvalidDDLState.GenWithStack("table %s is not in public, but %s", tblInfo.Name, tblInfo.State) + } + + return tblInfo, nil +} + +func checkTableExistAndCancelNonExistJob(t *meta.Meta, job *model.Job, schemaID int64) (*model.TableInfo, error) { + tblInfo, err := getTableInfo(t, job.TableID, schemaID) + if err == nil { + // Check if table name is renamed. + if job.TableName != "" && tblInfo.Name.L != job.TableName && job.Type != model.ActionRepairTable { + job.State = model.JobStateCancelled + return nil, infoschema.ErrTableNotExists.GenWithStackByArgs(job.SchemaName, job.TableName) + } + return tblInfo, nil + } + if infoschema.ErrDatabaseNotExists.Equal(err) || infoschema.ErrTableNotExists.Equal(err) { + job.State = model.JobStateCancelled + } + return nil, err +} + +func getTableInfo(t *meta.Meta, tableID, schemaID int64) (*model.TableInfo, error) { + // Check this table's database. + tblInfo, err := t.GetTable(schemaID, tableID) + if err != nil { + if meta.ErrDBNotExists.Equal(err) { + return nil, errors.Trace(infoschema.ErrDatabaseNotExists.GenWithStackByArgs( + fmt.Sprintf("(Schema ID %d)", schemaID), + )) + } + return nil, errors.Trace(err) + } + + // Check the table. + if tblInfo == nil { + return nil, errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs( + fmt.Sprintf("(Schema ID %d)", schemaID), + fmt.Sprintf("(Table ID %d)", tableID), + )) + } + return tblInfo, nil +} + +// onTruncateTable delete old table meta, and creates a new table identical to old table except for table ID. +// As all the old data is encoded with old table ID, it can not be accessed anymore. +// A background job will be created to delete old data. 
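// exampleTruncatedKeyRange is an illustrative sketch, not part of the patch,
// of the key range the comment above refers to: truncation allocates a fresh
// table ID, so the old rows become unreachable immediately, and the
// background delete-range job later clears the range derived from the old
// ID. oldTableID is hypothetical, and the end key shown is only the
// conceptual bound of that range; kv and tablecodec are assumed to be
// imported as in this file.
func exampleTruncatedKeyRange(oldTableID int64) (start, end kv.Key) {
	// All rows of the old table sit under the prefix for oldTableID.
	start = tablecodec.EncodeTablePrefix(oldTableID)
	// Conceptually the delete-range job clears [start, end).
	end = start.PrefixNext()
	return start, end
}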
+func (w *worker) onTruncateTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + schemaID := job.SchemaID + tableID := job.TableID + var newTableID int64 + var fkCheck bool + var newPartitionIDs []int64 + err := job.DecodeArgs(&newTableID, &fkCheck, &newPartitionIDs) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) + if err != nil { + return ver, errors.Trace(err) + } + if tblInfo.IsView() || tblInfo.IsSequence() { + job.State = model.JobStateCancelled + return ver, infoschema.ErrTableNotExists.GenWithStackByArgs(job.SchemaName, tblInfo.Name.O) + } + // Copy the old tableInfo for later usage. + oldTblInfo := tblInfo.Clone() + err = checkTruncateTableHasForeignKeyReferredInOwner(d, t, job, tblInfo, fkCheck) + if err != nil { + return ver, err + } + err = t.DropTableOrView(schemaID, tblInfo.ID) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + err = t.GetAutoIDAccessors(schemaID, tblInfo.ID).Del() + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + failpoint.Inject("truncateTableErr", func(val failpoint.Value) { + if val.(bool) { + job.State = model.JobStateCancelled + failpoint.Return(ver, errors.New("occur an error after dropping table")) + } + }) + + // Clear the TiFlash replica progress from ETCD. + if tblInfo.TiFlashReplica != nil { + e := infosync.DeleteTiFlashTableSyncProgress(tblInfo) + if e != nil { + logutil.DDLLogger().Error("DeleteTiFlashTableSyncProgress fails", zap.Error(e)) + } + } + + var oldPartitionIDs []int64 + if tblInfo.GetPartitionInfo() != nil { + oldPartitionIDs = getPartitionIDs(tblInfo) + // We use the new partition ID because all the old data is encoded with the old partition ID, it can not be accessed anymore. + err = truncateTableByReassignPartitionIDs(t, tblInfo, newPartitionIDs) + if err != nil { + return ver, errors.Trace(err) + } + } + + if pi := tblInfo.GetPartitionInfo(); pi != nil { + oldIDs := make([]int64, 0, len(oldPartitionIDs)) + newIDs := make([]int64, 0, len(oldPartitionIDs)) + newDefs := pi.Definitions + for i := range oldPartitionIDs { + newDef := &newDefs[i] + newID := newDef.ID + if newDef.PlacementPolicyRef != nil { + oldIDs = append(oldIDs, oldPartitionIDs[i]) + newIDs = append(newIDs, newID) + } + } + job.CtxVars = []any{oldIDs, newIDs} + } + + tableRuleID, partRuleIDs, _, oldRules, err := getOldLabelRules(tblInfo, job.SchemaName, tblInfo.Name.L) + if err != nil { + job.State = model.JobStateCancelled + return 0, errors.Wrapf(err, "failed to get old label rules from PD") + } + + err = updateLabelRules(job, tblInfo, oldRules, tableRuleID, partRuleIDs, []string{}, newTableID) + if err != nil { + job.State = model.JobStateCancelled + return 0, errors.Wrapf(err, "failed to update the label rule to PD") + } + + // Clear the TiFlash replica available status. 
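	// Illustrative aside, not part of the patch: truncation hands out new
	// physical IDs, so any TiFlash sync progress recorded for the old IDs is
	// meaningless. The block below re-creates the PD placement rules for the
	// new IDs and resets Available/AvailablePartitionIDs so that replication
	// starts over for the new table.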
+ if tblInfo.TiFlashReplica != nil { + // Set PD rules for TiFlash + if pi := tblInfo.GetPartitionInfo(); pi != nil { + if e := infosync.ConfigureTiFlashPDForPartitions(true, &pi.Definitions, tblInfo.TiFlashReplica.Count, &tblInfo.TiFlashReplica.LocationLabels, tblInfo.ID); e != nil { + logutil.DDLLogger().Error("ConfigureTiFlashPDForPartitions fails", zap.Error(err)) + job.State = model.JobStateCancelled + return ver, errors.Trace(e) + } + } else { + if e := infosync.ConfigureTiFlashPDForTable(newTableID, tblInfo.TiFlashReplica.Count, &tblInfo.TiFlashReplica.LocationLabels); e != nil { + logutil.DDLLogger().Error("ConfigureTiFlashPDForTable fails", zap.Error(err)) + job.State = model.JobStateCancelled + return ver, errors.Trace(e) + } + } + tblInfo.TiFlashReplica.AvailablePartitionIDs = nil + tblInfo.TiFlashReplica.Available = false + } + + tblInfo.ID = newTableID + + // build table & partition bundles if any. + bundles, err := placement.NewFullTableBundles(t, tblInfo) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + err = infosync.PutRuleBundlesWithDefaultRetry(context.TODO(), bundles) + if err != nil { + job.State = model.JobStateCancelled + return 0, errors.Wrapf(err, "failed to notify PD the placement rules") + } + + err = t.CreateTableOrView(schemaID, tblInfo) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + failpoint.Inject("mockTruncateTableUpdateVersionError", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(ver, errors.New("mock update version error")) + } + }) + + var partitions []model.PartitionDefinition + if pi := tblInfo.GetPartitionInfo(); pi != nil { + partitions = tblInfo.GetPartitionInfo().Definitions + } + preSplitAndScatter(w.sess.Context, d.store, tblInfo, partitions) + + ver, err = updateSchemaVersion(d, t, job) + if err != nil { + return ver, errors.Trace(err) + } + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + truncateTableEvent := statsutil.NewTruncateTableEvent( + job.SchemaID, + tblInfo, + oldTblInfo, + ) + asyncNotifyEvent(d, truncateTableEvent) + startKey := tablecodec.EncodeTablePrefix(tableID) + job.Args = []any{startKey, oldPartitionIDs} + return ver, nil +} + +func onRebaseAutoIncrementIDType(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + return onRebaseAutoID(d, d.store, t, job, autoid.AutoIncrementType) +} + +func onRebaseAutoRandomType(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + return onRebaseAutoID(d, d.store, t, job, autoid.AutoRandomType) +} + +func onRebaseAutoID(d *ddlCtx, _ kv.Storage, t *meta.Meta, job *model.Job, tp autoid.AllocatorType) (ver int64, _ error) { + schemaID := job.SchemaID + var ( + newBase int64 + force bool + ) + err := job.DecodeArgs(&newBase, &force) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + if job.MultiSchemaInfo != nil && job.MultiSchemaInfo.Revertible { + job.MarkNonRevertible() + return ver, nil + } + + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + tbl, err := getTable(d.getAutoIDRequirement(), schemaID, tblInfo) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + if !force { + newBaseTemp, err := adjustNewBaseToNextGlobalID(nil, tbl, tp, newBase) + if err != nil { + return ver, errors.Trace(err) + } + if newBase != newBaseTemp { + 
job.Warning = toTError(fmt.Errorf("Can't reset AUTO_INCREMENT to %d without FORCE option, using %d instead", + newBase, newBaseTemp, + )) + } + newBase = newBaseTemp + } + + if tp == autoid.AutoIncrementType { + tblInfo.AutoIncID = newBase + } else { + tblInfo.AutoRandID = newBase + } + + if alloc := tbl.Allocators(nil).Get(tp); alloc != nil { + // The next value to allocate is `newBase`. + newEnd := newBase - 1 + if force { + err = alloc.ForceRebase(newEnd) + } else { + err = alloc.Rebase(context.Background(), newEnd, false) + } + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + } + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + return ver, nil +} + +func onModifyTableAutoIDCache(d *ddlCtx, t *meta.Meta, job *model.Job) (int64, error) { + var cache int64 + if err := job.DecodeArgs(&cache); err != nil { + job.State = model.JobStateCancelled + return 0, errors.Trace(err) + } + + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) + if err != nil { + return 0, errors.Trace(err) + } + + tblInfo.AutoIdCache = cache + ver, err := updateVersionAndTableInfo(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + return ver, nil +} + +func (w *worker) onShardRowID(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + var shardRowIDBits uint64 + err := job.DecodeArgs(&shardRowIDBits) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + if shardRowIDBits < tblInfo.ShardRowIDBits { + tblInfo.ShardRowIDBits = shardRowIDBits + } else { + tbl, err := getTable(d.getAutoIDRequirement(), job.SchemaID, tblInfo) + if err != nil { + return ver, errors.Trace(err) + } + err = verifyNoOverflowShardBits(w.sessPool, tbl, shardRowIDBits) + if err != nil { + job.State = model.JobStateCancelled + return ver, err + } + tblInfo.ShardRowIDBits = shardRowIDBits + // MaxShardRowIDBits use to check the overflow of auto ID. + tblInfo.MaxShardRowIDBits = shardRowIDBits + } + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + return ver, nil +} + +func verifyNoOverflowShardBits(s *sess.Pool, tbl table.Table, shardRowIDBits uint64) error { + if shardRowIDBits == 0 { + return nil + } + ctx, err := s.Get() + if err != nil { + return errors.Trace(err) + } + defer s.Put(ctx) + // Check next global max auto ID first. 
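	// Illustrative aside, not part of the patch: with shard_row_id_bits = S
	// and a 64-bit row id that reserves one sign bit, roughly 64-1-S bits
	// remain for the allocated part, so e.g. S=4 caps the usable id at about
	// 2^59-1. If the next global auto ID already needs any of those reserved
	// top bits, the OverflowShardBits check below rejects the change instead
	// of letting future ids overflow.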
+ autoIncID, err := tbl.Allocators(ctx.GetTableCtx()).Get(autoid.RowIDAllocType).NextGlobalAutoID() + if err != nil { + return errors.Trace(err) + } + if tables.OverflowShardBits(autoIncID, shardRowIDBits, autoid.RowIDBitLength, true) { + return autoid.ErrAutoincReadFailed.GenWithStack("shard_row_id_bits %d will cause next global auto ID %v overflow", shardRowIDBits, autoIncID) + } + return nil +} + +func onRenameTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + var oldSchemaID int64 + var oldSchemaName model.CIStr + var tableName model.CIStr + if err := job.DecodeArgs(&oldSchemaID, &tableName, &oldSchemaName); err != nil { + // Invalid arguments, cancel this job. + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + if job.SchemaState == model.StatePublic { + return finishJobRenameTable(d, t, job) + } + newSchemaID := job.SchemaID + err := checkTableNotExists(d, newSchemaID, tableName.L) + if err != nil { + if infoschema.ErrDatabaseNotExists.Equal(err) || infoschema.ErrTableExists.Equal(err) { + job.State = model.JobStateCancelled + } + return ver, errors.Trace(err) + } + + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, oldSchemaID) + if err != nil { + return ver, errors.Trace(err) + } + oldTableName := tblInfo.Name + ver, err = checkAndRenameTables(t, job, tblInfo, oldSchemaID, job.SchemaID, &oldSchemaName, &tableName) + if err != nil { + return ver, errors.Trace(err) + } + fkh := newForeignKeyHelper() + err = adjustForeignKeyChildTableInfoAfterRenameTable(d, t, job, &fkh, tblInfo, oldSchemaName, oldTableName, tableName, newSchemaID) + if err != nil { + return ver, errors.Trace(err) + } + ver, err = updateSchemaVersion(d, t, job, fkh.getLoadedTables()...) + if err != nil { + return ver, errors.Trace(err) + } + job.SchemaState = model.StatePublic + return ver, nil +} + +func onRenameTables(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + oldSchemaIDs := []int64{} + newSchemaIDs := []int64{} + tableNames := []*model.CIStr{} + tableIDs := []int64{} + oldSchemaNames := []*model.CIStr{} + oldTableNames := []*model.CIStr{} + if err := job.DecodeArgs(&oldSchemaIDs, &newSchemaIDs, &tableNames, &tableIDs, &oldSchemaNames, &oldTableNames); err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + if job.SchemaState == model.StatePublic { + return finishJobRenameTables(d, t, job, tableNames, tableIDs, newSchemaIDs) + } + + var err error + fkh := newForeignKeyHelper() + for i, oldSchemaID := range oldSchemaIDs { + job.TableID = tableIDs[i] + job.TableName = oldTableNames[i].L + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, oldSchemaID) + if err != nil { + return ver, errors.Trace(err) + } + ver, err := checkAndRenameTables(t, job, tblInfo, oldSchemaID, newSchemaIDs[i], oldSchemaNames[i], tableNames[i]) + if err != nil { + return ver, errors.Trace(err) + } + err = adjustForeignKeyChildTableInfoAfterRenameTable(d, t, job, &fkh, tblInfo, *oldSchemaNames[i], *oldTableNames[i], *tableNames[i], newSchemaIDs[i]) + if err != nil { + return ver, errors.Trace(err) + } + } + + ver, err = updateSchemaVersion(d, t, job, fkh.getLoadedTables()...) 
+ if err != nil { + return ver, errors.Trace(err) + } + job.SchemaState = model.StatePublic + return ver, nil +} + +func checkAndRenameTables(t *meta.Meta, job *model.Job, tblInfo *model.TableInfo, oldSchemaID, newSchemaID int64, oldSchemaName, tableName *model.CIStr) (ver int64, _ error) { + err := t.DropTableOrView(oldSchemaID, tblInfo.ID) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + failpoint.Inject("renameTableErr", func(val failpoint.Value) { + if valStr, ok := val.(string); ok { + if tableName.L == valStr { + job.State = model.JobStateCancelled + failpoint.Return(ver, errors.New("occur an error after renaming table")) + } + } + }) + + oldTableName := tblInfo.Name + tableRuleID, partRuleIDs, oldRuleIDs, oldRules, err := getOldLabelRules(tblInfo, oldSchemaName.L, oldTableName.L) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Wrapf(err, "failed to get old label rules from PD") + } + + if tblInfo.AutoIDSchemaID == 0 && newSchemaID != oldSchemaID { + // The auto id is referenced by a schema id + table id + // Table ID is not changed between renames, but schema id can change. + // To allow concurrent use of the auto id during rename, keep the auto id + // by always reference it with the schema id it was originally created in. + tblInfo.AutoIDSchemaID = oldSchemaID + } + if newSchemaID == tblInfo.AutoIDSchemaID { + // Back to the original schema id, no longer needed. + tblInfo.AutoIDSchemaID = 0 + } + + tblInfo.Name = *tableName + err = t.CreateTableOrView(newSchemaID, tblInfo) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + err = updateLabelRules(job, tblInfo, oldRules, tableRuleID, partRuleIDs, oldRuleIDs, tblInfo.ID) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Wrapf(err, "failed to update the label rule to PD") + } + + return ver, nil +} + +func adjustForeignKeyChildTableInfoAfterRenameTable(d *ddlCtx, t *meta.Meta, job *model.Job, fkh *foreignKeyHelper, tblInfo *model.TableInfo, oldSchemaName, oldTableName, newTableName model.CIStr, newSchemaID int64) error { + if !variable.EnableForeignKey.Load() || newTableName.L == oldTableName.L { + return nil + } + is, err := getAndCheckLatestInfoSchema(d, t) + if err != nil { + return err + } + newDB, ok := is.SchemaByID(newSchemaID) + if !ok { + job.State = model.JobStateCancelled + return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(fmt.Sprintf("schema-ID: %v", newSchemaID)) + } + referredFKs := is.GetTableReferredForeignKeys(oldSchemaName.L, oldTableName.L) + if len(referredFKs) == 0 { + return nil + } + fkh.addLoadedTable(oldSchemaName.L, oldTableName.L, newDB.ID, tblInfo) + for _, referredFK := range referredFKs { + childTableInfo, err := fkh.getTableFromStorage(is, t, referredFK.ChildSchema, referredFK.ChildTable) + if err != nil { + if infoschema.ErrTableNotExists.Equal(err) || infoschema.ErrDatabaseNotExists.Equal(err) { + continue + } + return err + } + childFKInfo := model.FindFKInfoByName(childTableInfo.tblInfo.ForeignKeys, referredFK.ChildFKName.L) + if childFKInfo == nil { + continue + } + childFKInfo.RefSchema = newDB.Name + childFKInfo.RefTable = newTableName + } + for _, info := range fkh.loaded { + err = updateTable(t, info.schemaID, info.tblInfo) + if err != nil { + return err + } + } + return nil +} + +// We split the renaming table job into two steps: +// 1. rename table and update the schema version. +// 2. update the job state to JobStateDone. 
+// This is the requirement from TiCDC because +// - it uses the job state to check whether the DDL is finished. +// - there is a gap between schema reloading and job state updating: +// when the job state is updated to JobStateDone, before the new schema reloaded, +// there may be DMLs that use the old schema. +// - TiCDC cannot handle the DMLs that use the old schema, because +// the commit TS of the DMLs are greater than the job state updating TS. +func finishJobRenameTable(d *ddlCtx, t *meta.Meta, job *model.Job) (int64, error) { + tblInfo, err := getTableInfo(t, job.TableID, job.SchemaID) + if err != nil { + job.State = model.JobStateCancelled + return 0, errors.Trace(err) + } + // Before updating the schema version, we need to reset the old schema ID to new schema ID, so that + // the table info can be dropped normally in `ApplyDiff`. This is because renaming table requires two + // schema versions to complete. + oldRawArgs := job.RawArgs + job.Args[0] = job.SchemaID + job.RawArgs, err = json.Marshal(job.Args) + if err != nil { + return 0, errors.Trace(err) + } + ver, err := updateSchemaVersion(d, t, job) + if err != nil { + return ver, errors.Trace(err) + } + job.RawArgs = oldRawArgs + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + return ver, nil +} + +func finishJobRenameTables(d *ddlCtx, t *meta.Meta, job *model.Job, + tableNames []*model.CIStr, tableIDs, newSchemaIDs []int64) (int64, error) { + tblSchemaIDs := make(map[int64]int64, len(tableIDs)) + for i := range tableIDs { + tblSchemaIDs[tableIDs[i]] = newSchemaIDs[i] + } + tblInfos := make([]*model.TableInfo, 0, len(tableNames)) + for i := range tableIDs { + tblID := tableIDs[i] + tblInfo, err := getTableInfo(t, tblID, tblSchemaIDs[tblID]) + if err != nil { + job.State = model.JobStateCancelled + return 0, errors.Trace(err) + } + tblInfos = append(tblInfos, tblInfo) + } + // Before updating the schema version, we need to reset the old schema ID to new schema ID, so that + // the table info can be dropped normally in `ApplyDiff`. This is because renaming table requires two + // schema versions to complete. 
+ var err error + oldRawArgs := job.RawArgs + job.Args[0] = newSchemaIDs + job.RawArgs, err = json.Marshal(job.Args) + if err != nil { + return 0, errors.Trace(err) + } + ver, err := updateSchemaVersion(d, t, job) + if err != nil { + return ver, errors.Trace(err) + } + job.RawArgs = oldRawArgs + job.FinishMultipleTableJob(model.JobStateDone, model.StatePublic, ver, tblInfos) + return ver, nil +} + +func onModifyTableComment(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + var comment string + if err := job.DecodeArgs(&comment); err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) + if err != nil { + return ver, errors.Trace(err) + } + + if job.MultiSchemaInfo != nil && job.MultiSchemaInfo.Revertible { + job.MarkNonRevertible() + return ver, nil + } + + tblInfo.Comment = comment + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + return ver, nil +} + +func onModifyTableCharsetAndCollate(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + var toCharset, toCollate string + var needsOverwriteCols bool + if err := job.DecodeArgs(&toCharset, &toCollate, &needsOverwriteCols); err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + dbInfo, err := checkSchemaExistAndCancelNotExistJob(t, job) + if err != nil { + return ver, errors.Trace(err) + } + + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) + if err != nil { + return ver, errors.Trace(err) + } + + // double check. + _, err = checkAlterTableCharset(tblInfo, dbInfo, toCharset, toCollate, needsOverwriteCols) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + if job.MultiSchemaInfo != nil && job.MultiSchemaInfo.Revertible { + job.MarkNonRevertible() + return ver, nil + } + + tblInfo.Charset = toCharset + tblInfo.Collate = toCollate + + if needsOverwriteCols { + // update column charset. + for _, col := range tblInfo.Columns { + if field_types.HasCharset(&col.FieldType) { + col.SetCharset(toCharset) + col.SetCollate(toCollate) + } else { + col.SetCharset(charset.CharsetBin) + col.SetCollate(charset.CharsetBin) + } + } + } + + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + return ver, nil +} + +func (w *worker) onSetTableFlashReplica(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + var replicaInfo ast.TiFlashReplicaSpec + if err := job.DecodeArgs(&replicaInfo); err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) + if err != nil { + return ver, errors.Trace(err) + } + + // Ban setting replica count for tables in system database. + if tidb_util.IsMemOrSysDB(job.SchemaName) { + return ver, errors.Trace(dbterror.ErrUnsupportedTiFlashOperationForSysOrMemTable) + } + + err = w.checkTiFlashReplicaCount(replicaInfo.Count) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + // We should check this first, in order to avoid creating redundant DDL jobs. 
+ if pi := tblInfo.GetPartitionInfo(); pi != nil { + logutil.DDLLogger().Info("Set TiFlash replica pd rule for partitioned table", zap.Int64("tableID", tblInfo.ID)) + if e := infosync.ConfigureTiFlashPDForPartitions(false, &pi.Definitions, replicaInfo.Count, &replicaInfo.Labels, tblInfo.ID); e != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(e) + } + // Partitions that in adding mid-state. They have high priorities, so we should set accordingly pd rules. + if e := infosync.ConfigureTiFlashPDForPartitions(true, &pi.AddingDefinitions, replicaInfo.Count, &replicaInfo.Labels, tblInfo.ID); e != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(e) + } + } else { + logutil.DDLLogger().Info("Set TiFlash replica pd rule", zap.Int64("tableID", tblInfo.ID)) + if e := infosync.ConfigureTiFlashPDForTable(tblInfo.ID, replicaInfo.Count, &replicaInfo.Labels); e != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(e) + } + } + + available := false + if tblInfo.TiFlashReplica != nil { + available = tblInfo.TiFlashReplica.Available + } + if replicaInfo.Count > 0 { + tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ + Count: replicaInfo.Count, + LocationLabels: replicaInfo.Labels, + Available: available, + } + } else { + if tblInfo.TiFlashReplica != nil { + err = infosync.DeleteTiFlashTableSyncProgress(tblInfo) + if err != nil { + logutil.DDLLogger().Error("DeleteTiFlashTableSyncProgress fails", zap.Error(err)) + } + } + tblInfo.TiFlashReplica = nil + } + + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + return ver, nil +} + +func (w *worker) checkTiFlashReplicaCount(replicaCount uint64) error { + ctx, err := w.sessPool.Get() + if err != nil { + return errors.Trace(err) + } + defer w.sessPool.Put(ctx) + + return checkTiFlashReplicaCount(ctx, replicaCount) +} + +func onUpdateFlashReplicaStatus(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + var available bool + var physicalID int64 + if err := job.DecodeArgs(&available, &physicalID); err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) + if err != nil { + return ver, errors.Trace(err) + } + if tblInfo.TiFlashReplica == nil || (tblInfo.ID == physicalID && tblInfo.TiFlashReplica.Available == available) || + (tblInfo.ID != physicalID && available == tblInfo.TiFlashReplica.IsPartitionAvailable(physicalID)) { + job.State = model.JobStateCancelled + return ver, errors.Errorf("the replica available status of table %s is already updated", tblInfo.Name.String()) + } + + if tblInfo.ID == physicalID { + tblInfo.TiFlashReplica.Available = available + } else if pi := tblInfo.GetPartitionInfo(); pi != nil { + // Partition replica become available. + if available { + allAvailable := true + for _, p := range pi.Definitions { + if p.ID == physicalID { + tblInfo.TiFlashReplica.AvailablePartitionIDs = append(tblInfo.TiFlashReplica.AvailablePartitionIDs, physicalID) + } + allAvailable = allAvailable && tblInfo.TiFlashReplica.IsPartitionAvailable(p.ID) + } + tblInfo.TiFlashReplica.Available = allAvailable + } else { + // Partition replica become unavailable. 
+ for i, id := range tblInfo.TiFlashReplica.AvailablePartitionIDs { + if id == physicalID { + newIDs := tblInfo.TiFlashReplica.AvailablePartitionIDs[:i] + newIDs = append(newIDs, tblInfo.TiFlashReplica.AvailablePartitionIDs[i+1:]...) + tblInfo.TiFlashReplica.AvailablePartitionIDs = newIDs + tblInfo.TiFlashReplica.Available = false + logutil.DDLLogger().Info("TiFlash replica become unavailable", zap.Int64("tableID", tblInfo.ID), zap.Int64("partitionID", id)) + break + } + } + } + } else { + job.State = model.JobStateCancelled + return ver, errors.Errorf("unknown physical ID %v in table %v", physicalID, tblInfo.Name.O) + } + + if tblInfo.TiFlashReplica.Available { + logutil.DDLLogger().Info("TiFlash replica available", zap.Int64("tableID", tblInfo.ID)) + } + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + return ver, nil +} + +// checking using cached info schema should be enough, as: +// - we will reload schema until success when become the owner +// - existing tables are correctly checked in the first place +// - we calculate job dependencies before running jobs, so there will not be 2 +// jobs creating same table running concurrently. +// +// if there are 2 owners A and B, we have 2 consecutive jobs J1 and J2 which +// are creating the same table T. those 2 jobs might be running concurrently when +// A sees J1 first and B sees J2 first. But for B sees J2 first, J1 must already +// be done and synced, and been deleted from tidb_ddl_job table, as we are querying +// jobs in the order of job id. During syncing J1, B should have synced the schema +// with the latest schema version, so when B runs J2, below check will see the table +// T already exists, and J2 will fail. +func checkTableNotExists(d *ddlCtx, schemaID int64, tableName string) error { + is := d.infoCache.GetLatest() + return checkTableNotExistsFromInfoSchema(is, schemaID, tableName) +} + +func checkConstraintNamesNotExists(t *meta.Meta, schemaID int64, constraints []*model.ConstraintInfo) error { + if len(constraints) == 0 { + return nil + } + tbInfos, err := t.ListTables(schemaID) + if err != nil { + return err + } + + for _, tb := range tbInfos { + for _, constraint := range constraints { + if constraint.State != model.StateWriteOnly { + if constraintInfo := tb.FindConstraintInfoByName(constraint.Name.L); constraintInfo != nil { + return infoschema.ErrCheckConstraintDupName.GenWithStackByArgs(constraint.Name.L) + } + } + } + } + + return nil +} + +func checkTableIDNotExists(t *meta.Meta, schemaID, tableID int64) error { + tbl, err := t.GetTable(schemaID, tableID) + if err != nil { + if meta.ErrDBNotExists.Equal(err) { + return infoschema.ErrDatabaseNotExists.GenWithStackByArgs("") + } + return errors.Trace(err) + } + if tbl != nil { + return infoschema.ErrTableExists.GenWithStackByArgs(tbl.Name) + } + return nil +} + +func checkTableNotExistsFromInfoSchema(is infoschema.InfoSchema, schemaID int64, tableName string) error { + // Check this table's database. 
+ schema, ok := is.SchemaByID(schemaID) + if !ok { + return infoschema.ErrDatabaseNotExists.GenWithStackByArgs("") + } + if is.TableExists(schema.Name, model.NewCIStr(tableName)) { + return infoschema.ErrTableExists.GenWithStackByArgs(tableName) + } + return nil +} + +// updateVersionAndTableInfoWithCheck checks table info validate and updates the schema version and the table information +func updateVersionAndTableInfoWithCheck(d *ddlCtx, t *meta.Meta, job *model.Job, tblInfo *model.TableInfo, shouldUpdateVer bool, multiInfos ...schemaIDAndTableInfo) ( + ver int64, err error) { + err = checkTableInfoValid(tblInfo) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + for _, info := range multiInfos { + err = checkTableInfoValid(info.tblInfo) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + } + return updateVersionAndTableInfo(d, t, job, tblInfo, shouldUpdateVer, multiInfos...) +} + +// updateVersionAndTableInfo updates the schema version and the table information. +func updateVersionAndTableInfo(d *ddlCtx, t *meta.Meta, job *model.Job, tblInfo *model.TableInfo, shouldUpdateVer bool, multiInfos ...schemaIDAndTableInfo) ( + ver int64, err error) { + failpoint.Inject("mockUpdateVersionAndTableInfoErr", func(val failpoint.Value) { + switch val.(int) { + case 1: + failpoint.Return(ver, errors.New("mock update version and tableInfo error")) + case 2: + // We change it cancelled directly here, because we want to get the original error with the job id appended. + // The job ID will be used to get the job from history queue and we will assert it's args. + job.State = model.JobStateCancelled + failpoint.Return(ver, errors.New("mock update version and tableInfo error, jobID="+strconv.Itoa(int(job.ID)))) + default: + } + }) + if shouldUpdateVer && (job.MultiSchemaInfo == nil || !job.MultiSchemaInfo.SkipVersion) { + ver, err = updateSchemaVersion(d, t, job, multiInfos...) + if err != nil { + return 0, errors.Trace(err) + } + } + + err = updateTable(t, job.SchemaID, tblInfo) + if err != nil { + return 0, errors.Trace(err) + } + for _, info := range multiInfos { + err = updateTable(t, info.schemaID, info.tblInfo) + if err != nil { + return 0, errors.Trace(err) + } + } + return ver, nil +} + +func updateTable(t *meta.Meta, schemaID int64, tblInfo *model.TableInfo) error { + if tblInfo.State == model.StatePublic { + tblInfo.UpdateTS = t.StartTS + } + return t.UpdateTable(schemaID, tblInfo) +} + +type schemaIDAndTableInfo struct { + schemaID int64 + tblInfo *model.TableInfo +} + +func onRepairTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { + schemaID := job.SchemaID + tblInfo := &model.TableInfo{} + + if err := job.DecodeArgs(tblInfo); err != nil { + // Invalid arguments, cancel this job. + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + tblInfo.State = model.StateNone + + // Check the old DB and old table exist. + _, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) + if err != nil { + return ver, errors.Trace(err) + } + + // When in repair mode, the repaired table in a server is not access to user, + // the table after repairing will be removed from repair list. Other server left + // behind alive may need to restart to get the latest schema version. 
+ ver, err = updateSchemaVersion(d, t, job) + if err != nil { + return ver, errors.Trace(err) + } + switch tblInfo.State { + case model.StateNone: + // none -> public + tblInfo.State = model.StatePublic + tblInfo.UpdateTS = t.StartTS + err = repairTableOrViewWithCheck(t, job, schemaID, tblInfo) + if err != nil { + return ver, errors.Trace(err) + } + // Finish this job. + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + return ver, nil + default: + return ver, dbterror.ErrInvalidDDLState.GenWithStackByArgs("table", tblInfo.State) + } +} + +func onAlterTableAttributes(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + rule := label.NewRule() + err = job.DecodeArgs(rule) + if err != nil { + job.State = model.JobStateCancelled + return 0, errors.Trace(err) + } + + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) + if err != nil { + return 0, err + } + + if len(rule.Labels) == 0 { + patch := label.NewRulePatch([]*label.Rule{}, []string{rule.ID}) + err = infosync.UpdateLabelRules(context.TODO(), patch) + } else { + err = infosync.PutLabelRule(context.TODO(), rule) + } + if err != nil { + job.State = model.JobStateCancelled + return 0, errors.Wrapf(err, "failed to notify PD the label rules") + } + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + + return ver, nil +} + +func onAlterTablePartitionAttributes(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + var partitionID int64 + rule := label.NewRule() + err = job.DecodeArgs(&partitionID, rule) + if err != nil { + job.State = model.JobStateCancelled + return 0, errors.Trace(err) + } + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) + if err != nil { + return 0, err + } + + ptInfo := tblInfo.GetPartitionInfo() + if ptInfo.GetNameByID(partitionID) == "" { + job.State = model.JobStateCancelled + return 0, errors.Trace(table.ErrUnknownPartition.GenWithStackByArgs("drop?", tblInfo.Name.O)) + } + + if len(rule.Labels) == 0 { + patch := label.NewRulePatch([]*label.Rule{}, []string{rule.ID}) + err = infosync.UpdateLabelRules(context.TODO(), patch) + } else { + err = infosync.PutLabelRule(context.TODO(), rule) + } + if err != nil { + job.State = model.JobStateCancelled + return 0, errors.Wrapf(err, "failed to notify PD the label rules") + } + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + + return ver, nil +} + +func onAlterTablePartitionPlacement(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + var partitionID int64 + policyRefInfo := &model.PolicyRefInfo{} + err = job.DecodeArgs(&partitionID, &policyRefInfo) + if err != nil { + job.State = model.JobStateCancelled + return 0, errors.Trace(err) + } + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) + if err != nil { + return 0, err + } + + ptInfo := tblInfo.GetPartitionInfo() + var partitionDef *model.PartitionDefinition + definitions := ptInfo.Definitions + oldPartitionEnablesPlacement := false + for i := range definitions { + if partitionID == definitions[i].ID { + def := &definitions[i] + oldPartitionEnablesPlacement = def.PlacementPolicyRef != nil + def.PlacementPolicyRef = policyRefInfo + partitionDef = &definitions[i] + break + } + } + + if partitionDef == nil { + job.State = 
model.JobStateCancelled + return 0, errors.Trace(table.ErrUnknownPartition.GenWithStackByArgs("drop?", tblInfo.Name.O)) + } + + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + + if _, err = checkPlacementPolicyRefValidAndCanNonValidJob(t, job, partitionDef.PlacementPolicyRef); err != nil { + return ver, errors.Trace(err) + } + + bundle, err := placement.NewPartitionBundle(t, *partitionDef) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + if bundle == nil && oldPartitionEnablesPlacement { + bundle = placement.NewBundle(partitionDef.ID) + } + + // Send the placement bundle to PD. + if bundle != nil { + err = infosync.PutRuleBundlesWithDefaultRetry(context.TODO(), []*placement.Bundle{bundle}) + } + + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Wrapf(err, "failed to notify PD the placement rules") + } + + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + return ver, nil +} + +func onAlterTablePlacement(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + policyRefInfo := &model.PolicyRefInfo{} + err = job.DecodeArgs(&policyRefInfo) + if err != nil { + job.State = model.JobStateCancelled + return 0, errors.Trace(err) + } + + tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) + if err != nil { + return 0, err + } + + if _, err = checkPlacementPolicyRefValidAndCanNonValidJob(t, job, policyRefInfo); err != nil { + return 0, errors.Trace(err) + } + + oldTableEnablesPlacement := tblInfo.PlacementPolicyRef != nil + tblInfo.PlacementPolicyRef = policyRefInfo + ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) + if err != nil { + return ver, errors.Trace(err) + } + + bundle, err := placement.NewTableBundle(t, tblInfo) + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + if bundle == nil && oldTableEnablesPlacement { + bundle = placement.NewBundle(tblInfo.ID) + } + + // Send the placement bundle to PD. + if bundle != nil { + err = infosync.PutRuleBundlesWithDefaultRetry(context.TODO(), []*placement.Bundle{bundle}) + } + + if err != nil { + job.State = model.JobStateCancelled + return ver, errors.Trace(err) + } + + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) + + return ver, nil +} + +func getOldLabelRules(tblInfo *model.TableInfo, oldSchemaName, oldTableName string) (string, []string, []string, map[string]*label.Rule, error) { + tableRuleID := fmt.Sprintf(label.TableIDFormat, label.IDPrefix, oldSchemaName, oldTableName) + oldRuleIDs := []string{tableRuleID} + var partRuleIDs []string + if tblInfo.GetPartitionInfo() != nil { + for _, def := range tblInfo.GetPartitionInfo().Definitions { + partRuleIDs = append(partRuleIDs, fmt.Sprintf(label.PartitionIDFormat, label.IDPrefix, oldSchemaName, oldTableName, def.Name.L)) + } + } + + oldRuleIDs = append(oldRuleIDs, partRuleIDs...) 
+ oldRules, err := infosync.GetLabelRules(context.TODO(), oldRuleIDs) + return tableRuleID, partRuleIDs, oldRuleIDs, oldRules, err +} + +func updateLabelRules(job *model.Job, tblInfo *model.TableInfo, oldRules map[string]*label.Rule, tableRuleID string, partRuleIDs, oldRuleIDs []string, tID int64) error { + if oldRules == nil { + return nil + } + var newRules []*label.Rule + if tblInfo.GetPartitionInfo() != nil { + for idx, def := range tblInfo.GetPartitionInfo().Definitions { + if r, ok := oldRules[partRuleIDs[idx]]; ok { + newRules = append(newRules, r.Clone().Reset(job.SchemaName, tblInfo.Name.L, def.Name.L, def.ID)) + } + } + } + ids := []int64{tID} + if r, ok := oldRules[tableRuleID]; ok { + if tblInfo.GetPartitionInfo() != nil { + for _, def := range tblInfo.GetPartitionInfo().Definitions { + ids = append(ids, def.ID) + } + } + newRules = append(newRules, r.Clone().Reset(job.SchemaName, tblInfo.Name.L, "", ids...)) + } + + patch := label.NewRulePatch(newRules, oldRuleIDs) + return infosync.UpdateLabelRules(context.TODO(), patch) +} + +func onAlterCacheTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + tbInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) + if err != nil { + return 0, errors.Trace(err) + } + // If the table is already in the cache state + if tbInfo.TableCacheStatusType == model.TableCacheStatusEnable { + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tbInfo) + return ver, nil + } + + if tbInfo.TempTableType != model.TempTableNone { + return ver, errors.Trace(dbterror.ErrOptOnTemporaryTable.GenWithStackByArgs("alter temporary table cache")) + } + + if tbInfo.Partition != nil { + return ver, errors.Trace(dbterror.ErrOptOnCacheTable.GenWithStackByArgs("partition mode")) + } + + switch tbInfo.TableCacheStatusType { + case model.TableCacheStatusDisable: + // disable -> switching + tbInfo.TableCacheStatusType = model.TableCacheStatusSwitching + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tbInfo, true) + if err != nil { + return ver, err + } + case model.TableCacheStatusSwitching: + // switching -> enable + tbInfo.TableCacheStatusType = model.TableCacheStatusEnable + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tbInfo, true) + if err != nil { + return ver, err + } + // Finish this job. + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tbInfo) + default: + job.State = model.JobStateCancelled + err = dbterror.ErrInvalidDDLState.GenWithStackByArgs("alter table cache", tbInfo.TableCacheStatusType.String()) + } + return ver, err +} + +func onAlterNoCacheTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { + tbInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) + if err != nil { + return 0, errors.Trace(err) + } + // If the table is not in the cache state + if tbInfo.TableCacheStatusType == model.TableCacheStatusDisable { + job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tbInfo) + return ver, nil + } + + switch tbInfo.TableCacheStatusType { + case model.TableCacheStatusEnable: + // enable -> switching + tbInfo.TableCacheStatusType = model.TableCacheStatusSwitching + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tbInfo, true) + if err != nil { + return ver, err + } + case model.TableCacheStatusSwitching: + // switching -> disable + tbInfo.TableCacheStatusType = model.TableCacheStatusDisable + ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tbInfo, true) + if err != nil { + return ver, err + } + // Finish this job. 
+ job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tbInfo) + default: + job.State = model.JobStateCancelled + err = dbterror.ErrInvalidDDLState.GenWithStackByArgs("alter table no cache", tbInfo.TableCacheStatusType.String()) + } + return ver, err +} diff --git a/pkg/ddl/util/binding__failpoint_binding__.go b/pkg/ddl/util/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..c7fcdb8c0fcf4 --- /dev/null +++ b/pkg/ddl/util/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package util + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/ddl/util/util.go b/pkg/ddl/util/util.go index 6f088e0646854..7342d8b020d51 100644 --- a/pkg/ddl/util/util.go +++ b/pkg/ddl/util/util.go @@ -381,10 +381,10 @@ const ( func IsRaftKv2(ctx context.Context, sctx sessionctx.Context) (bool, error) { // Mock store does not support `show config` now, so we use failpoint here // to control whether we are in raft-kv2 - failpoint.Inject("IsRaftKv2", func(v failpoint.Value) (bool, error) { + if v, _err_ := failpoint.Eval(_curpkg_("IsRaftKv2")); _err_ == nil { v2, _ := v.(bool) return v2, nil - }) + } rs, err := sctx.GetSQLExecutor().ExecuteInternal(ctx, getRaftKvVersionSQL) if err != nil { diff --git a/pkg/ddl/util/util.go__failpoint_stash__ b/pkg/ddl/util/util.go__failpoint_stash__ new file mode 100644 index 0000000000000..6f088e0646854 --- /dev/null +++ b/pkg/ddl/util/util.go__failpoint_stash__ @@ -0,0 +1,427 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "bytes" + "context" + "encoding/hex" + "fmt" + "os" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/ddl/logutil" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/mock" + "github.com/pingcap/tidb/pkg/util/sqlexec" + "github.com/tikv/client-go/v2/tikvrpc" + clientv3 "go.etcd.io/etcd/client/v3" + atomicutil "go.uber.org/atomic" + "go.uber.org/zap" +) + +const ( + deleteRangesTable = `gc_delete_range` + doneDeleteRangesTable = `gc_delete_range_done` + loadDeleteRangeSQL = `SELECT HIGH_PRIORITY job_id, element_id, start_key, end_key FROM mysql.%n WHERE ts < %?` + recordDoneDeletedRangeSQL = `INSERT IGNORE INTO mysql.gc_delete_range_done SELECT * FROM mysql.gc_delete_range WHERE job_id = %? AND element_id = %?` + completeDeleteRangeSQL = `DELETE FROM mysql.gc_delete_range WHERE job_id = %? 
AND element_id = %?` + completeDeleteMultiRangesSQL = `DELETE FROM mysql.gc_delete_range WHERE job_id = %?` + updateDeleteRangeSQL = `UPDATE mysql.gc_delete_range SET start_key = %? WHERE job_id = %? AND element_id = %? AND start_key = %?` + deleteDoneRecordSQL = `DELETE FROM mysql.gc_delete_range_done WHERE job_id = %? AND element_id = %?` + loadGlobalVars = `SELECT HIGH_PRIORITY variable_name, variable_value from mysql.global_variables where variable_name in (` // + nameList + ")" + // KeyOpDefaultTimeout is the default timeout for each key operation. + KeyOpDefaultTimeout = 2 * time.Second + // KeyOpRetryInterval is the interval between two key operations. + KeyOpRetryInterval = 30 * time.Millisecond + // DDLAllSchemaVersions is the path on etcd that is used to store all servers current schema versions. + DDLAllSchemaVersions = "/tidb/ddl/all_schema_versions" + // DDLAllSchemaVersionsByJob is the path on etcd that is used to store all servers current schema versions. + // /tidb/ddl/all_schema_by_job_versions// ---> + DDLAllSchemaVersionsByJob = "/tidb/ddl/all_schema_by_job_versions" + // DDLGlobalSchemaVersion is the path on etcd that is used to store the latest schema versions. + DDLGlobalSchemaVersion = "/tidb/ddl/global_schema_version" + // ServerGlobalState is the path on etcd that is used to store the server global state. + ServerGlobalState = "/tidb/server/global_state" + // SessionTTL is the etcd session's TTL in seconds. + SessionTTL = 90 +) + +// DelRangeTask is for run delete-range command in gc_worker. +type DelRangeTask struct { + StartKey kv.Key + EndKey kv.Key + JobID int64 + ElementID int64 +} + +// Range returns the range [start, end) to delete. +func (t DelRangeTask) Range() (kv.Key, kv.Key) { + return t.StartKey, t.EndKey +} + +// LoadDeleteRanges loads delete range tasks from gc_delete_range table. +func LoadDeleteRanges(ctx context.Context, sctx sessionctx.Context, safePoint uint64) (ranges []DelRangeTask, _ error) { + return loadDeleteRangesFromTable(ctx, sctx, deleteRangesTable, safePoint) +} + +// LoadDoneDeleteRanges loads deleted ranges from gc_delete_range_done table. +func LoadDoneDeleteRanges(ctx context.Context, sctx sessionctx.Context, safePoint uint64) (ranges []DelRangeTask, _ error) { + return loadDeleteRangesFromTable(ctx, sctx, doneDeleteRangesTable, safePoint) +} + +func loadDeleteRangesFromTable(ctx context.Context, sctx sessionctx.Context, table string, safePoint uint64) (ranges []DelRangeTask, _ error) { + rs, err := sctx.GetSQLExecutor().ExecuteInternal(ctx, loadDeleteRangeSQL, table, safePoint) + if rs != nil { + defer terror.Call(rs.Close) + } + if err != nil { + return nil, errors.Trace(err) + } + + req := rs.NewChunk(nil) + it := chunk.NewIterator4Chunk(req) + for { + err = rs.Next(context.TODO(), req) + if err != nil { + return nil, errors.Trace(err) + } + if req.NumRows() == 0 { + break + } + + for row := it.Begin(); row != it.End(); row = it.Next() { + startKey, err := hex.DecodeString(row.GetString(2)) + if err != nil { + return nil, errors.Trace(err) + } + endKey, err := hex.DecodeString(row.GetString(3)) + if err != nil { + return nil, errors.Trace(err) + } + ranges = append(ranges, DelRangeTask{ + JobID: row.GetInt64(0), + ElementID: row.GetInt64(1), + StartKey: startKey, + EndKey: endKey, + }) + } + } + return ranges, nil +} + +// CompleteDeleteRange moves a record from gc_delete_range table to gc_delete_range_done table. 
+func CompleteDeleteRange(sctx sessionctx.Context, dr DelRangeTask, needToRecordDone bool) error { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) + + _, err := sctx.GetSQLExecutor().ExecuteInternal(ctx, "BEGIN") + if err != nil { + return errors.Trace(err) + } + + if needToRecordDone { + _, err = sctx.GetSQLExecutor().ExecuteInternal(ctx, recordDoneDeletedRangeSQL, dr.JobID, dr.ElementID) + if err != nil { + return errors.Trace(err) + } + } + + err = RemoveFromGCDeleteRange(sctx, dr.JobID, dr.ElementID) + if err != nil { + return errors.Trace(err) + } + _, err = sctx.GetSQLExecutor().ExecuteInternal(ctx, "COMMIT") + return errors.Trace(err) +} + +// RemoveFromGCDeleteRange is exported for ddl pkg to use. +func RemoveFromGCDeleteRange(sctx sessionctx.Context, jobID, elementID int64) error { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) + _, err := sctx.GetSQLExecutor().ExecuteInternal(ctx, completeDeleteRangeSQL, jobID, elementID) + return errors.Trace(err) +} + +// RemoveMultiFromGCDeleteRange is exported for ddl pkg to use. +func RemoveMultiFromGCDeleteRange(ctx context.Context, sctx sessionctx.Context, jobID int64) error { + _, err := sctx.GetSQLExecutor().ExecuteInternal(ctx, completeDeleteMultiRangesSQL, jobID) + return errors.Trace(err) +} + +// DeleteDoneRecord removes a record from gc_delete_range_done table. +func DeleteDoneRecord(sctx sessionctx.Context, dr DelRangeTask) error { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) + _, err := sctx.GetSQLExecutor().ExecuteInternal(ctx, deleteDoneRecordSQL, dr.JobID, dr.ElementID) + return errors.Trace(err) +} + +// UpdateDeleteRange is only for emulator. +func UpdateDeleteRange(sctx sessionctx.Context, dr DelRangeTask, newStartKey, oldStartKey kv.Key) error { + newStartKeyHex := hex.EncodeToString(newStartKey) + oldStartKeyHex := hex.EncodeToString(oldStartKey) + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) + _, err := sctx.GetSQLExecutor().ExecuteInternal(ctx, updateDeleteRangeSQL, newStartKeyHex, dr.JobID, dr.ElementID, oldStartKeyHex) + return errors.Trace(err) +} + +// LoadDDLReorgVars loads ddl reorg variable from mysql.global_variables. +func LoadDDLReorgVars(ctx context.Context, sctx sessionctx.Context) error { + // close issue #21391 + // variable.TiDBRowFormatVersion is used to encode the new row for column type change. + return LoadGlobalVars(ctx, sctx, []string{variable.TiDBDDLReorgWorkerCount, variable.TiDBDDLReorgBatchSize, variable.TiDBRowFormatVersion}) +} + +// LoadDDLVars loads ddl variable from mysql.global_variables. +func LoadDDLVars(ctx sessionctx.Context) error { + return LoadGlobalVars(context.Background(), ctx, []string{variable.TiDBDDLErrorCountLimit}) +} + +// LoadGlobalVars loads global variable from mysql.global_variables. +func LoadGlobalVars(ctx context.Context, sctx sessionctx.Context, varNames []string) error { + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnDDL) + // *mock.Context does not support SQL execution. Only do it when sctx is not `mock.Context` + if _, ok := sctx.(*mock.Context); !ok { + e := sctx.GetRestrictedSQLExecutor() + var buf strings.Builder + buf.WriteString(loadGlobalVars) + paramNames := make([]any, 0, len(varNames)) + for i, name := range varNames { + if i > 0 { + buf.WriteString(", ") + } + buf.WriteString("%?") + paramNames = append(paramNames, name) + } + buf.WriteString(")") + rows, _, err := e.ExecRestrictedSQL(ctx, nil, buf.String(), paramNames...) 
+ if err != nil { + return errors.Trace(err) + } + for _, row := range rows { + varName := row.GetString(0) + varValue := row.GetString(1) + if err = sctx.GetSessionVars().SetSystemVarWithoutValidation(varName, varValue); err != nil { + return err + } + } + } + return nil +} + +// GetTimeZone gets the session location's zone name and offset. +func GetTimeZone(sctx sessionctx.Context) (string, int) { + loc := sctx.GetSessionVars().Location() + name := loc.String() + if name != "" { + _, err := time.LoadLocation(name) + if err == nil { + return name, 0 + } + } + _, offset := time.Now().In(loc).Zone() + return "", offset +} + +// enableEmulatorGC means whether to enable emulator GC. The default is enable. +// In some unit tests, we want to stop emulator GC, then wen can set enableEmulatorGC to 0. +var emulatorGCEnable = atomicutil.NewInt32(1) + +// EmulatorGCEnable enables emulator gc. It exports for testing. +func EmulatorGCEnable() { + emulatorGCEnable.Store(1) +} + +// EmulatorGCDisable disables emulator gc. It exports for testing. +func EmulatorGCDisable() { + emulatorGCEnable.Store(0) +} + +// IsEmulatorGCEnable indicates whether emulator GC enabled. It exports for testing. +func IsEmulatorGCEnable() bool { + return emulatorGCEnable.Load() == 1 +} + +var internalResourceGroupTag = []byte{0} + +// GetInternalResourceGroupTaggerForTopSQL only use for testing. +func GetInternalResourceGroupTaggerForTopSQL() tikvrpc.ResourceGroupTagger { + tagger := func(req *tikvrpc.Request) { + req.ResourceGroupTag = internalResourceGroupTag + } + return tagger +} + +// IsInternalResourceGroupTaggerForTopSQL use for testing. +func IsInternalResourceGroupTaggerForTopSQL(tag []byte) bool { + return bytes.Equal(tag, internalResourceGroupTag) +} + +// DeleteKeyFromEtcd deletes key value from etcd. +func DeleteKeyFromEtcd(key string, etcdCli *clientv3.Client, retryCnt int, timeout time.Duration) error { + var err error + ctx := context.Background() + for i := 0; i < retryCnt; i++ { + childCtx, cancel := context.WithTimeout(ctx, timeout) + _, err = etcdCli.Delete(childCtx, key) + cancel() + if err == nil { + return nil + } + logutil.DDLLogger().Warn("etcd-cli delete key failed", zap.String("key", key), zap.Error(err), zap.Int("retryCnt", i)) + } + return errors.Trace(err) +} + +// PutKVToEtcdMono puts key value to etcd monotonously. +// etcdCli is client of etcd. +// retryCnt is retry time when an error occurs. +// opts are configures of etcd Operations. +func PutKVToEtcdMono(ctx context.Context, etcdCli *clientv3.Client, retryCnt int, key, val string, + opts ...clientv3.OpOption) error { + var err error + for i := 0; i < retryCnt; i++ { + if err = ctx.Err(); err != nil { + return errors.Trace(err) + } + + childCtx, cancel := context.WithTimeout(ctx, KeyOpDefaultTimeout) + var resp *clientv3.GetResponse + resp, err = etcdCli.Get(childCtx, key) + if err != nil { + cancel() + logutil.DDLLogger().Warn("etcd-cli put kv failed", zap.String("key", key), zap.String("value", val), zap.Error(err), zap.Int("retryCnt", i)) + time.Sleep(KeyOpRetryInterval) + continue + } + prevRevision := int64(0) + if len(resp.Kvs) > 0 { + prevRevision = resp.Kvs[0].ModRevision + } + + var txnResp *clientv3.TxnResponse + txnResp, err = etcdCli.Txn(childCtx). + If(clientv3.Compare(clientv3.ModRevision(key), "=", prevRevision)). + Then(clientv3.OpPut(key, val, opts...)). 
+ Commit() + + cancel() + + if err == nil && txnResp.Succeeded { + return nil + } + + if err == nil { + err = errors.New("performing compare-and-swap during PutKVToEtcd failed") + } + + logutil.DDLLogger().Warn("etcd-cli put kv failed", zap.String("key", key), zap.String("value", val), zap.Error(err), zap.Int("retryCnt", i)) + time.Sleep(KeyOpRetryInterval) + } + return errors.Trace(err) +} + +// PutKVToEtcd puts key value to etcd. +// etcdCli is client of etcd. +// retryCnt is retry time when an error occurs. +// opts are configures of etcd Operations. +func PutKVToEtcd(ctx context.Context, etcdCli *clientv3.Client, retryCnt int, key, val string, + opts ...clientv3.OpOption) error { + var err error + for i := 0; i < retryCnt; i++ { + if err = ctx.Err(); err != nil { + return errors.Trace(err) + } + + childCtx, cancel := context.WithTimeout(ctx, KeyOpDefaultTimeout) + _, err = etcdCli.Put(childCtx, key, val, opts...) + cancel() + if err == nil { + return nil + } + logutil.DDLLogger().Warn("etcd-cli put kv failed", zap.String("key", key), zap.String("value", val), zap.Error(err), zap.Int("retryCnt", i)) + time.Sleep(KeyOpRetryInterval) + } + return errors.Trace(err) +} + +// WrapKey2String wraps the key to a string. +func WrapKey2String(key []byte) string { + if len(key) == 0 { + return "''" + } + return fmt.Sprintf("0x%x", key) +} + +const ( + getRaftKvVersionSQL = "select `value` from information_schema.cluster_config where type = 'tikv' and `key` = 'storage.engine'" + raftKv2 = "raft-kv2" +) + +// IsRaftKv2 checks whether the raft-kv2 is enabled +func IsRaftKv2(ctx context.Context, sctx sessionctx.Context) (bool, error) { + // Mock store does not support `show config` now, so we use failpoint here + // to control whether we are in raft-kv2 + failpoint.Inject("IsRaftKv2", func(v failpoint.Value) (bool, error) { + v2, _ := v.(bool) + return v2, nil + }) + + rs, err := sctx.GetSQLExecutor().ExecuteInternal(ctx, getRaftKvVersionSQL) + if err != nil { + return false, err + } + if rs == nil { + return false, nil + } + + defer terror.Call(rs.Close) + rows, err := sqlexec.DrainRecordSet(ctx, rs, sctx.GetSessionVars().MaxChunkSize) + if err != nil { + return false, errors.Trace(err) + } + if len(rows) == 0 { + return false, nil + } + + // All nodes should have the same type of engine + raftVersion := rows[0].GetString(0) + return raftVersion == raftKv2, nil +} + +// FolderNotEmpty returns true only when the folder is not empty. +func FolderNotEmpty(path string) bool { + entries, _ := os.ReadDir(path) + return len(entries) > 0 +} + +// GenKeyExistsErr builds a ErrKeyExists error. 
+func GenKeyExistsErr(key, value []byte, idxInfo *model.IndexInfo, tblInfo *model.TableInfo) error { + indexName := fmt.Sprintf("%s.%s", tblInfo.Name.String(), idxInfo.Name.String()) + valueStr, err := tables.GenIndexValueFromIndex(key, value, tblInfo, idxInfo) + if err != nil { + logutil.DDLLogger().Warn("decode index key value / column value failed", zap.String("index", indexName), + zap.String("key", hex.EncodeToString(key)), zap.String("value", hex.EncodeToString(value)), zap.Error(err)) + return errors.Trace(kv.ErrKeyExists.FastGenByArgs(key, indexName)) + } + return kv.GenKeyExistsErr(valueStr, indexName) +} diff --git a/pkg/distsql/binding__failpoint_binding__.go b/pkg/distsql/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..b2e21b1442daa --- /dev/null +++ b/pkg/distsql/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package distsql + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/distsql/request_builder.go b/pkg/distsql/request_builder.go index 9f0b8b24a2a18..9c9f05296038c 100644 --- a/pkg/distsql/request_builder.go +++ b/pkg/distsql/request_builder.go @@ -65,12 +65,12 @@ func (builder *RequestBuilder) Build() (*kv.Request, error) { }, } } - failpoint.Inject("assertRequestBuilderReplicaOption", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("assertRequestBuilderReplicaOption")); _err_ == nil { assertScope := val.(string) if builder.ReplicaRead.IsClosestRead() && assertScope != builder.ReadReplicaScope { panic("request builder get staleness option fail") } - }) + } err := builder.verifyTxnScope() if err != nil { builder.err = err @@ -90,12 +90,12 @@ func (builder *RequestBuilder) Build() (*kv.Request, error) { switch dag.Executors[0].Tp { case tipb.ExecType_TypeTableScan, tipb.ExecType_TypeIndexScan, tipb.ExecType_TypePartitionTableScan: builder.Request.Concurrency = 2 - failpoint.Inject("testRateLimitActionMockConsumeAndAssert", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("testRateLimitActionMockConsumeAndAssert")); _err_ == nil { if val.(bool) { // When the concurrency is too small, test case tests/realtikvtest/sessiontest.TestCoprocessorOOMAction can't trigger OOM condition builder.Request.Concurrency = oldConcurrency } - }) + } } } } diff --git a/pkg/distsql/request_builder.go__failpoint_stash__ b/pkg/distsql/request_builder.go__failpoint_stash__ new file mode 100644 index 0000000000000..9f0b8b24a2a18 --- /dev/null +++ b/pkg/distsql/request_builder.go__failpoint_stash__ @@ -0,0 +1,862 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
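The binding file and the rewritten hunks above follow the failpoint code-generation pattern used in this patch series: failpoint.Inject callbacks are expanded into plain if-blocks guarded by failpoint.Eval, and _curpkg_ prefixes every failpoint name with the package import path. A minimal sketch of how such a failpoint might be activated from a test; the full path below is an assumption derived from the _curpkg_ helper shown above, and t stands for a *testing.T:

	fp := "github.com/pingcap/tidb/pkg/distsql/assertRequestBuilderReplicaOption"
	// Make failpoint.Eval(_curpkg_("assertRequestBuilderReplicaOption")) return ("local", nil).
	if err := failpoint.Enable(fp, `return("local")`); err != nil {
		t.Fatal(err)
	}
	defer func() {
		// Disable the failpoint again so other tests are unaffected.
		_ = failpoint.Disable(fp)
	}()
	// While enabled, the guarded block in RequestBuilder.Build() runs and checks that
	// closest-replica reads carry the expected read-replica scope.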
+ +package distsql + +import ( + "fmt" + "math" + "sort" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/tidb/pkg/ddl/placement" + distsqlctx "github.com/pingcap/tidb/pkg/distsql/context" + "github.com/pingcap/tidb/pkg/errctx" + infoschema "github.com/pingcap/tidb/pkg/infoschema/context" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/collate" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/ranger" + "github.com/pingcap/tipb/go-tipb" + "github.com/tikv/client-go/v2/tikvrpc" +) + +// RequestBuilder is used to build a "kv.Request". +// It is called before we issue a kv request by "Select". +type RequestBuilder struct { + kv.Request + is infoschema.MetaOnlyInfoSchema + err error + + // When SetDAGRequest is called, builder will also this field. + dag *tipb.DAGRequest +} + +// Build builds a "kv.Request". +func (builder *RequestBuilder) Build() (*kv.Request, error) { + if builder.ReadReplicaScope == "" { + builder.ReadReplicaScope = kv.GlobalReplicaScope + } + if builder.ReplicaRead.IsClosestRead() && builder.ReadReplicaScope != kv.GlobalReplicaScope { + builder.MatchStoreLabels = []*metapb.StoreLabel{ + { + Key: placement.DCLabelKey, + Value: builder.ReadReplicaScope, + }, + } + } + failpoint.Inject("assertRequestBuilderReplicaOption", func(val failpoint.Value) { + assertScope := val.(string) + if builder.ReplicaRead.IsClosestRead() && assertScope != builder.ReadReplicaScope { + panic("request builder get staleness option fail") + } + }) + err := builder.verifyTxnScope() + if err != nil { + builder.err = err + } + if builder.Request.KeyRanges == nil { + builder.Request.KeyRanges = kv.NewNonPartitionedKeyRanges(nil) + } + + if dag := builder.dag; dag != nil { + if execCnt := len(dag.Executors); execCnt == 1 { + oldConcurrency := builder.Request.Concurrency + // select * from t order by id + if builder.Request.KeepOrder { + // When the DAG is just simple scan and keep order, set concurrency to 2. + // If a lot data are returned to client, mysql protocol is the bottleneck so concurrency 2 is enough. + // If very few data are returned to client, the speed is not optimal but good enough. + switch dag.Executors[0].Tp { + case tipb.ExecType_TypeTableScan, tipb.ExecType_TypeIndexScan, tipb.ExecType_TypePartitionTableScan: + builder.Request.Concurrency = 2 + failpoint.Inject("testRateLimitActionMockConsumeAndAssert", func(val failpoint.Value) { + if val.(bool) { + // When the concurrency is too small, test case tests/realtikvtest/sessiontest.TestCoprocessorOOMAction can't trigger OOM condition + builder.Request.Concurrency = oldConcurrency + } + }) + } + } + } + } + + return &builder.Request, builder.err +} + +// SetMemTracker sets a memTracker for this request. +func (builder *RequestBuilder) SetMemTracker(tracker *memory.Tracker) *RequestBuilder { + builder.Request.MemTracker = tracker + return builder +} + +// SetTableRanges sets "KeyRanges" for "kv.Request" by converting "tableRanges" +// to "KeyRanges" firstly. +// Note this function should be deleted or at least not exported, but currently +// br refers it, so have to keep it. 
+func (builder *RequestBuilder) SetTableRanges(tid int64, tableRanges []*ranger.Range) *RequestBuilder { + if builder.err == nil { + builder.Request.KeyRanges = kv.NewNonPartitionedKeyRanges(TableRangesToKVRanges(tid, tableRanges)) + } + return builder +} + +// SetIndexRanges sets "KeyRanges" for "kv.Request" by converting index range +// "ranges" to "KeyRanges" firstly. +func (builder *RequestBuilder) SetIndexRanges(dctx *distsqlctx.DistSQLContext, tid, idxID int64, ranges []*ranger.Range) *RequestBuilder { + if builder.err == nil { + builder.Request.KeyRanges, builder.err = IndexRangesToKVRanges(dctx, tid, idxID, ranges) + } + return builder +} + +// SetIndexRangesForTables sets "KeyRanges" for "kv.Request" by converting multiple indexes range +// "ranges" to "KeyRanges" firstly. +func (builder *RequestBuilder) SetIndexRangesForTables(dctx *distsqlctx.DistSQLContext, tids []int64, idxID int64, ranges []*ranger.Range) *RequestBuilder { + if builder.err == nil { + builder.Request.KeyRanges, builder.err = IndexRangesToKVRangesForTables(dctx, tids, idxID, ranges) + } + return builder +} + +// SetHandleRanges sets "KeyRanges" for "kv.Request" by converting table handle range +// "ranges" to "KeyRanges" firstly. +func (builder *RequestBuilder) SetHandleRanges(dctx *distsqlctx.DistSQLContext, tid int64, isCommonHandle bool, ranges []*ranger.Range) *RequestBuilder { + builder = builder.SetHandleRangesForTables(dctx, []int64{tid}, isCommonHandle, ranges) + builder.err = builder.Request.KeyRanges.SetToNonPartitioned() + return builder +} + +// SetHandleRangesForTables sets "KeyRanges" for "kv.Request" by converting table handle range +// "ranges" to "KeyRanges" firstly for multiple tables. +func (builder *RequestBuilder) SetHandleRangesForTables(dctx *distsqlctx.DistSQLContext, tid []int64, isCommonHandle bool, ranges []*ranger.Range) *RequestBuilder { + if builder.err == nil { + builder.Request.KeyRanges, builder.err = TableHandleRangesToKVRanges(dctx, tid, isCommonHandle, ranges) + } + return builder +} + +// SetTableHandles sets "KeyRanges" for "kv.Request" by converting table handles +// "handles" to "KeyRanges" firstly. +func (builder *RequestBuilder) SetTableHandles(tid int64, handles []kv.Handle) *RequestBuilder { + keyRanges, hints := TableHandlesToKVRanges(tid, handles) + builder.Request.KeyRanges = kv.NewNonParitionedKeyRangesWithHint(keyRanges, hints) + return builder +} + +// SetPartitionsAndHandles sets "KeyRanges" for "kv.Request" by converting ParitionHandles to KeyRanges. +// handles in slice must be kv.PartitionHandle. +func (builder *RequestBuilder) SetPartitionsAndHandles(handles []kv.Handle) *RequestBuilder { + keyRanges, hints := PartitionHandlesToKVRanges(handles) + builder.Request.KeyRanges = kv.NewNonParitionedKeyRangesWithHint(keyRanges, hints) + return builder +} + +const estimatedRegionRowCount = 100000 + +// SetDAGRequest sets the request type to "ReqTypeDAG" and construct request data. +func (builder *RequestBuilder) SetDAGRequest(dag *tipb.DAGRequest) *RequestBuilder { + if builder.err == nil { + builder.Request.Tp = kv.ReqTypeDAG + builder.Request.Cacheable = true + builder.Request.Data, builder.err = dag.Marshal() + builder.dag = dag + if execCnt := len(dag.Executors); execCnt != 0 && dag.Executors[execCnt-1].GetLimit() != nil { + limit := dag.Executors[execCnt-1].GetLimit() + builder.Request.LimitSize = limit.GetLimit() + // When the DAG is just simple scan and small limit, set concurrency to 1 would be sufficient. 
+ if execCnt == 2 { + if limit.Limit < estimatedRegionRowCount { + if kr := builder.Request.KeyRanges; kr != nil { + builder.Request.Concurrency = kr.PartitionNum() + } else { + builder.Request.Concurrency = 1 + } + } + } + } + } + return builder +} + +// SetAnalyzeRequest sets the request type to "ReqTypeAnalyze" and construct request data. +func (builder *RequestBuilder) SetAnalyzeRequest(ana *tipb.AnalyzeReq, isoLevel kv.IsoLevel) *RequestBuilder { + if builder.err == nil { + builder.Request.Tp = kv.ReqTypeAnalyze + builder.Request.Data, builder.err = ana.Marshal() + builder.Request.NotFillCache = true + builder.Request.IsolationLevel = isoLevel + builder.Request.Priority = kv.PriorityLow + } + + return builder +} + +// SetChecksumRequest sets the request type to "ReqTypeChecksum" and construct request data. +func (builder *RequestBuilder) SetChecksumRequest(checksum *tipb.ChecksumRequest) *RequestBuilder { + if builder.err == nil { + builder.Request.Tp = kv.ReqTypeChecksum + builder.Request.Data, builder.err = checksum.Marshal() + builder.Request.NotFillCache = true + } + + return builder +} + +// SetKeyRanges sets "KeyRanges" for "kv.Request". +func (builder *RequestBuilder) SetKeyRanges(keyRanges []kv.KeyRange) *RequestBuilder { + builder.Request.KeyRanges = kv.NewNonPartitionedKeyRanges(keyRanges) + return builder +} + +// SetKeyRangesWithHints sets "KeyRanges" for "kv.Request" with row count hints. +func (builder *RequestBuilder) SetKeyRangesWithHints(keyRanges []kv.KeyRange, hints []int) *RequestBuilder { + builder.Request.KeyRanges = kv.NewNonParitionedKeyRangesWithHint(keyRanges, hints) + return builder +} + +// SetWrappedKeyRanges sets "KeyRanges" for "kv.Request". +func (builder *RequestBuilder) SetWrappedKeyRanges(keyRanges *kv.KeyRanges) *RequestBuilder { + builder.Request.KeyRanges = keyRanges + return builder +} + +// SetPartitionKeyRanges sets the "KeyRanges" for "kv.Request" on partitioned table cases. +func (builder *RequestBuilder) SetPartitionKeyRanges(keyRanges [][]kv.KeyRange) *RequestBuilder { + builder.Request.KeyRanges = kv.NewPartitionedKeyRanges(keyRanges) + return builder +} + +// SetStartTS sets "StartTS" for "kv.Request". +func (builder *RequestBuilder) SetStartTS(startTS uint64) *RequestBuilder { + builder.Request.StartTs = startTS + return builder +} + +// SetDesc sets "Desc" for "kv.Request". +func (builder *RequestBuilder) SetDesc(desc bool) *RequestBuilder { + builder.Request.Desc = desc + return builder +} + +// SetKeepOrder sets "KeepOrder" for "kv.Request". +func (builder *RequestBuilder) SetKeepOrder(order bool) *RequestBuilder { + builder.Request.KeepOrder = order + return builder +} + +// SetStoreType sets "StoreType" for "kv.Request". +func (builder *RequestBuilder) SetStoreType(storeType kv.StoreType) *RequestBuilder { + builder.Request.StoreType = storeType + return builder +} + +// SetAllowBatchCop sets `BatchCop` property. +func (builder *RequestBuilder) SetAllowBatchCop(batchCop bool) *RequestBuilder { + builder.Request.BatchCop = batchCop + return builder +} + +// SetPartitionIDAndRanges sets `PartitionIDAndRanges` property. 
+func (builder *RequestBuilder) SetPartitionIDAndRanges(partitionIDAndRanges []kv.PartitionIDAndRanges) *RequestBuilder { + builder.PartitionIDAndRanges = partitionIDAndRanges + return builder +} + +func (builder *RequestBuilder) getIsolationLevel() kv.IsoLevel { + if builder.Tp == kv.ReqTypeAnalyze { + return kv.RC + } + return kv.SI +} + +func (*RequestBuilder) getKVPriority(dctx *distsqlctx.DistSQLContext) int { + switch dctx.Priority { + case mysql.NoPriority, mysql.DelayedPriority: + return kv.PriorityNormal + case mysql.LowPriority: + return kv.PriorityLow + case mysql.HighPriority: + return kv.PriorityHigh + } + return kv.PriorityNormal +} + +// SetFromSessionVars sets the following fields for "kv.Request" from session variables: +// "Concurrency", "IsolationLevel", "NotFillCache", "TaskID", "Priority", "ReplicaRead", +// "ResourceGroupTagger", "ResourceGroupName" +func (builder *RequestBuilder) SetFromSessionVars(dctx *distsqlctx.DistSQLContext) *RequestBuilder { + distsqlConcurrency := dctx.DistSQLConcurrency + if builder.Request.Concurrency == 0 { + // Concurrency unset. + builder.Request.Concurrency = distsqlConcurrency + } else if builder.Request.Concurrency > distsqlConcurrency { + // Concurrency is set in SetDAGRequest, check the upper limit. + builder.Request.Concurrency = distsqlConcurrency + } + replicaReadType := dctx.ReplicaReadType + if dctx.WeakConsistency { + builder.Request.IsolationLevel = kv.RC + } else if dctx.RCCheckTS { + builder.Request.IsolationLevel = kv.RCCheckTS + replicaReadType = kv.ReplicaReadLeader + } else { + builder.Request.IsolationLevel = builder.getIsolationLevel() + } + builder.Request.NotFillCache = dctx.NotFillCache + builder.Request.TaskID = dctx.TaskID + builder.Request.Priority = builder.getKVPriority(dctx) + builder.Request.ReplicaRead = replicaReadType + builder.SetResourceGroupTagger(dctx.ResourceGroupTagger) + { + builder.SetPaging(dctx.EnablePaging) + builder.Request.Paging.MinPagingSize = uint64(dctx.MinPagingSize) + builder.Request.Paging.MaxPagingSize = uint64(dctx.MaxPagingSize) + } + builder.RequestSource.RequestSourceInternal = dctx.InRestrictedSQL + builder.RequestSource.RequestSourceType = dctx.RequestSourceType + builder.RequestSource.ExplicitRequestSourceType = dctx.ExplicitRequestSourceType + builder.StoreBatchSize = dctx.StoreBatchSize + builder.Request.ResourceGroupName = dctx.ResourceGroupName + builder.Request.StoreBusyThreshold = dctx.LoadBasedReplicaReadThreshold + builder.Request.RunawayChecker = dctx.RunawayChecker + builder.Request.TiKVClientReadTimeout = dctx.TiKVClientReadTimeout + return builder +} + +// SetPaging sets "Paging" flag for "kv.Request". +func (builder *RequestBuilder) SetPaging(paging bool) *RequestBuilder { + builder.Request.Paging.Enable = paging + return builder +} + +// SetConcurrency sets "Concurrency" for "kv.Request". +func (builder *RequestBuilder) SetConcurrency(concurrency int) *RequestBuilder { + builder.Request.Concurrency = concurrency + return builder +} + +// SetTiDBServerID sets "TiDBServerID" for "kv.Request" +// +// ServerID is a unique id of TiDB instance among the cluster. 
+// See https://github.com/pingcap/tidb/blob/master/docs/design/2020-06-01-global-kill.md +func (builder *RequestBuilder) SetTiDBServerID(serverID uint64) *RequestBuilder { + builder.Request.TiDBServerID = serverID + return builder +} + +// SetFromInfoSchema sets the following fields from infoSchema: +// "bundles" +func (builder *RequestBuilder) SetFromInfoSchema(is infoschema.MetaOnlyInfoSchema) *RequestBuilder { + builder.is = is + builder.Request.SchemaVar = is.SchemaMetaVersion() + return builder +} + +// SetResourceGroupTagger sets the request resource group tagger. +func (builder *RequestBuilder) SetResourceGroupTagger(tagger tikvrpc.ResourceGroupTagger) *RequestBuilder { + builder.Request.ResourceGroupTagger = tagger + return builder +} + +// SetResourceGroupName sets the request resource group name. +func (builder *RequestBuilder) SetResourceGroupName(name string) *RequestBuilder { + builder.Request.ResourceGroupName = name + return builder +} + +// SetExplicitRequestSourceType sets the explicit request source type. +func (builder *RequestBuilder) SetExplicitRequestSourceType(sourceType string) *RequestBuilder { + builder.RequestSource.ExplicitRequestSourceType = sourceType + return builder +} + +func (builder *RequestBuilder) verifyTxnScope() error { + txnScope := builder.TxnScope + if txnScope == "" || txnScope == kv.GlobalReplicaScope || builder.is == nil { + return nil + } + visitPhysicalTableID := make(map[int64]struct{}) + tids, err := tablecodec.VerifyTableIDForRanges(builder.Request.KeyRanges) + if err != nil { + return err + } + for _, tid := range tids { + visitPhysicalTableID[tid] = struct{}{} + } + + for phyTableID := range visitPhysicalTableID { + valid := VerifyTxnScope(txnScope, phyTableID, builder.is) + if !valid { + var tblName string + var partName string + tblInfo, _, partInfo := builder.is.FindTableInfoByPartitionID(phyTableID) + if tblInfo != nil && partInfo != nil { + tblName = tblInfo.Name.String() + partName = partInfo.Name.String() + } else { + tblInfo, _ = builder.is.TableInfoByID(phyTableID) + tblName = tblInfo.Name.String() + } + err := fmt.Errorf("table %v can not be read by %v txn_scope", tblName, txnScope) + if len(partName) > 0 { + err = fmt.Errorf("table %v's partition %v can not be read by %v txn_scope", + tblName, partName, txnScope) + } + return err + } + } + return nil +} + +// SetTxnScope sets request TxnScope +func (builder *RequestBuilder) SetTxnScope(scope string) *RequestBuilder { + builder.TxnScope = scope + return builder +} + +// SetReadReplicaScope sets request readReplicaScope +func (builder *RequestBuilder) SetReadReplicaScope(scope string) *RequestBuilder { + builder.ReadReplicaScope = scope + return builder +} + +// SetIsStaleness sets request IsStaleness +func (builder *RequestBuilder) SetIsStaleness(is bool) *RequestBuilder { + builder.IsStaleness = is + return builder +} + +// SetClosestReplicaReadAdjuster sets request CoprRequestAdjuster +func (builder *RequestBuilder) SetClosestReplicaReadAdjuster(chkFn kv.CoprRequestAdjuster) *RequestBuilder { + builder.ClosestReplicaReadAdjuster = chkFn + return builder +} + +// SetConnIDAndConnAlias sets connection id for the builder. +func (builder *RequestBuilder) SetConnIDAndConnAlias(connID uint64, connAlias string) *RequestBuilder { + builder.ConnID = connID + builder.ConnAlias = connAlias + return builder +} + +// TableHandleRangesToKVRanges convert table handle ranges to "KeyRanges" for multiple tables. 
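The setter methods above all return the receiver, so a request is normally assembled as a chain ending in Build(). A minimal usage sketch, assuming dagReq, tid, ranges, startTS, dctx and tracker have already been prepared by the caller (these names are placeholders, not identifiers from the patch):

	var builder distsql.RequestBuilder
	kvReq, err := builder.
		SetDAGRequest(dagReq).       // *tipb.DAGRequest to run in the coprocessor
		SetTableRanges(tid, ranges). // key ranges of one physical table
		SetStartTS(startTS).
		SetKeepOrder(false).
		SetFromSessionVars(dctx).    // concurrency, isolation level, priority, paging, ...
		SetMemTracker(tracker).
		Build()
	// kvReq is then issued via Select (see the RequestBuilder doc comment above);
	// err carries any range-encoding or txn-scope verification failure.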
+func TableHandleRangesToKVRanges(dctx *distsqlctx.DistSQLContext, tid []int64, isCommonHandle bool, ranges []*ranger.Range) (*kv.KeyRanges, error) { + if !isCommonHandle { + return tablesRangesToKVRanges(tid, ranges), nil + } + return CommonHandleRangesToKVRanges(dctx, tid, ranges) +} + +// TableRangesToKVRanges converts table ranges to "KeyRange". +// Note this function should not be exported, but currently +// br refers to it, so have to keep it. +func TableRangesToKVRanges(tid int64, ranges []*ranger.Range) []kv.KeyRange { + if len(ranges) == 0 { + return []kv.KeyRange{} + } + return tablesRangesToKVRanges([]int64{tid}, ranges).FirstPartitionRange() +} + +// tablesRangesToKVRanges converts table ranges to "KeyRange". +func tablesRangesToKVRanges(tids []int64, ranges []*ranger.Range) *kv.KeyRanges { + return tableRangesToKVRangesWithoutSplit(tids, ranges) +} + +func tableRangesToKVRangesWithoutSplit(tids []int64, ranges []*ranger.Range) *kv.KeyRanges { + krs := make([][]kv.KeyRange, len(tids)) + for i := range krs { + krs[i] = make([]kv.KeyRange, 0, len(ranges)) + } + for _, ran := range ranges { + low, high := encodeHandleKey(ran) + for i, tid := range tids { + startKey := tablecodec.EncodeRowKey(tid, low) + endKey := tablecodec.EncodeRowKey(tid, high) + krs[i] = append(krs[i], kv.KeyRange{StartKey: startKey, EndKey: endKey}) + } + } + return kv.NewPartitionedKeyRanges(krs) +} + +func encodeHandleKey(ran *ranger.Range) ([]byte, []byte) { + low := codec.EncodeInt(nil, ran.LowVal[0].GetInt64()) + high := codec.EncodeInt(nil, ran.HighVal[0].GetInt64()) + if ran.LowExclude { + low = kv.Key(low).PrefixNext() + } + if !ran.HighExclude { + high = kv.Key(high).PrefixNext() + } + return low, high +} + +// SplitRangesAcrossInt64Boundary split the ranges into two groups: +// 1. signedRanges is less or equal than MaxInt64 +// 2. unsignedRanges is greater than MaxInt64 +// +// We do this because every key of tikv is encoded as an int64. As a result, MaxUInt64 is smaller than zero when +// interpreted as an int64 variable. +// +// This function does the following: +// 1. split ranges into two groups as described above. +// 2. if there's a range that straddles the int64 boundary, split it into two ranges, which results in one smaller and +// one greater than MaxInt64. +// +// if `KeepOrder` is false, we merge the two groups of ranges into one group, to save a rpc call later +// if `desc` is false, return signed ranges first, vice versa. +func SplitRangesAcrossInt64Boundary(ranges []*ranger.Range, keepOrder bool, desc bool, isCommonHandle bool) ([]*ranger.Range, []*ranger.Range) { + if isCommonHandle || len(ranges) == 0 || ranges[0].LowVal[0].Kind() == types.KindInt64 { + return ranges, nil + } + idx := sort.Search(len(ranges), func(i int) bool { return ranges[i].HighVal[0].GetUint64() > math.MaxInt64 }) + if idx == len(ranges) { + return ranges, nil + } + if ranges[idx].LowVal[0].GetUint64() > math.MaxInt64 { + signedRanges := ranges[0:idx] + unsignedRanges := ranges[idx:] + if !keepOrder { + return append(unsignedRanges, signedRanges...), nil + } + if desc { + return unsignedRanges, signedRanges + } + return signedRanges, unsignedRanges + } + // need to split the range that straddles the int64 boundary + signedRanges := make([]*ranger.Range, 0, idx+1) + unsignedRanges := make([]*ranger.Range, 0, len(ranges)-idx) + signedRanges = append(signedRanges, ranges[0:idx]...) 
+ if !(ranges[idx].LowVal[0].GetUint64() == math.MaxInt64 && ranges[idx].LowExclude) { + signedRanges = append(signedRanges, &ranger.Range{ + LowVal: ranges[idx].LowVal, + LowExclude: ranges[idx].LowExclude, + HighVal: []types.Datum{types.NewUintDatum(math.MaxInt64)}, + Collators: ranges[idx].Collators, + }) + } + if !(ranges[idx].HighVal[0].GetUint64() == math.MaxInt64+1 && ranges[idx].HighExclude) { + unsignedRanges = append(unsignedRanges, &ranger.Range{ + LowVal: []types.Datum{types.NewUintDatum(math.MaxInt64 + 1)}, + HighVal: ranges[idx].HighVal, + HighExclude: ranges[idx].HighExclude, + Collators: ranges[idx].Collators, + }) + } + if idx < len(ranges) { + unsignedRanges = append(unsignedRanges, ranges[idx+1:]...) + } + if !keepOrder { + return append(unsignedRanges, signedRanges...), nil + } + if desc { + return unsignedRanges, signedRanges + } + return signedRanges, unsignedRanges +} + +// TableHandlesToKVRanges converts sorted handle to kv ranges. +// For continuous handles, we should merge them to a single key range. +func TableHandlesToKVRanges(tid int64, handles []kv.Handle) ([]kv.KeyRange, []int) { + krs := make([]kv.KeyRange, 0, len(handles)) + hints := make([]int, 0, len(handles)) + i := 0 + for i < len(handles) { + var isCommonHandle bool + var commonHandle *kv.CommonHandle + if partitionHandle, ok := handles[i].(kv.PartitionHandle); ok { + tid = partitionHandle.PartitionID + commonHandle, isCommonHandle = partitionHandle.Handle.(*kv.CommonHandle) + } else { + commonHandle, isCommonHandle = handles[i].(*kv.CommonHandle) + } + if isCommonHandle { + ran := kv.KeyRange{ + StartKey: tablecodec.EncodeRowKey(tid, commonHandle.Encoded()), + EndKey: tablecodec.EncodeRowKey(tid, kv.Key(commonHandle.Encoded()).Next()), + } + krs = append(krs, ran) + hints = append(hints, 1) + i++ + continue + } + j := i + 1 + for ; j < len(handles) && handles[j-1].IntValue() != math.MaxInt64; j++ { + if p, ok := handles[j].(kv.PartitionHandle); ok && p.PartitionID != tid { + break + } + if handles[j].IntValue() != handles[j-1].IntValue()+1 { + break + } + } + low := codec.EncodeInt(nil, handles[i].IntValue()) + high := codec.EncodeInt(nil, handles[j-1].IntValue()) + high = kv.Key(high).PrefixNext() + startKey := tablecodec.EncodeRowKey(tid, low) + endKey := tablecodec.EncodeRowKey(tid, high) + krs = append(krs, kv.KeyRange{StartKey: startKey, EndKey: endKey}) + hints = append(hints, j-i) + i = j + } + return krs, hints +} + +// PartitionHandlesToKVRanges convert ParitionHandles to kv ranges. 
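+// Consecutive int handles that belong to the same partition are merged into one
+// key range; the returned hints record how many handles each range covers.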
+// Handle in slices must be kv.PartitionHandle +func PartitionHandlesToKVRanges(handles []kv.Handle) ([]kv.KeyRange, []int) { + krs := make([]kv.KeyRange, 0, len(handles)) + hints := make([]int, 0, len(handles)) + i := 0 + for i < len(handles) { + ph := handles[i].(kv.PartitionHandle) + h := ph.Handle + pid := ph.PartitionID + if commonHandle, ok := h.(*kv.CommonHandle); ok { + ran := kv.KeyRange{ + StartKey: tablecodec.EncodeRowKey(pid, commonHandle.Encoded()), + EndKey: tablecodec.EncodeRowKey(pid, append(commonHandle.Encoded(), 0)), + } + krs = append(krs, ran) + hints = append(hints, 1) + i++ + continue + } + j := i + 1 + for ; j < len(handles) && handles[j-1].IntValue() != math.MaxInt64; j++ { + if handles[j].IntValue() != handles[j-1].IntValue()+1 { + break + } + if handles[j].(kv.PartitionHandle).PartitionID != pid { + break + } + } + low := codec.EncodeInt(nil, handles[i].IntValue()) + high := codec.EncodeInt(nil, handles[j-1].IntValue()) + high = kv.Key(high).PrefixNext() + startKey := tablecodec.EncodeRowKey(pid, low) + endKey := tablecodec.EncodeRowKey(pid, high) + krs = append(krs, kv.KeyRange{StartKey: startKey, EndKey: endKey}) + hints = append(hints, j-i) + i = j + } + return krs, hints +} + +// IndexRangesToKVRanges converts index ranges to "KeyRange". +func IndexRangesToKVRanges(dctx *distsqlctx.DistSQLContext, tid, idxID int64, ranges []*ranger.Range) (*kv.KeyRanges, error) { + return IndexRangesToKVRangesWithInterruptSignal(dctx, tid, idxID, ranges, nil, nil) +} + +// IndexRangesToKVRangesWithInterruptSignal converts index ranges to "KeyRange". +// The process can be interrupted by set `interruptSignal` to true. +func IndexRangesToKVRangesWithInterruptSignal(dctx *distsqlctx.DistSQLContext, tid, idxID int64, ranges []*ranger.Range, memTracker *memory.Tracker, interruptSignal *atomic.Value) (*kv.KeyRanges, error) { + keyRanges, err := indexRangesToKVRangesForTablesWithInterruptSignal(dctx, []int64{tid}, idxID, ranges, memTracker, interruptSignal) + if err != nil { + return nil, err + } + err = keyRanges.SetToNonPartitioned() + return keyRanges, err +} + +// IndexRangesToKVRangesForTables converts indexes ranges to "KeyRange". +func IndexRangesToKVRangesForTables(dctx *distsqlctx.DistSQLContext, tids []int64, idxID int64, ranges []*ranger.Range) (*kv.KeyRanges, error) { + return indexRangesToKVRangesForTablesWithInterruptSignal(dctx, tids, idxID, ranges, nil, nil) +} + +// IndexRangesToKVRangesForTablesWithInterruptSignal converts indexes ranges to "KeyRange". +// The process can be interrupted by set `interruptSignal` to true. +func indexRangesToKVRangesForTablesWithInterruptSignal(dctx *distsqlctx.DistSQLContext, tids []int64, idxID int64, ranges []*ranger.Range, memTracker *memory.Tracker, interruptSignal *atomic.Value) (*kv.KeyRanges, error) { + return indexRangesToKVWithoutSplit(dctx, tids, idxID, ranges, memTracker, interruptSignal) +} + +// CommonHandleRangesToKVRanges converts common handle ranges to "KeyRange". 
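+// Each range is first encoded into binary index-key bounds (using the time zone
+// and error context carried by dctx) and then expanded into row-key ranges for
+// every table ID in tids.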
+func CommonHandleRangesToKVRanges(dctx *distsqlctx.DistSQLContext, tids []int64, ranges []*ranger.Range) (*kv.KeyRanges, error) { + rans := make([]*ranger.Range, 0, len(ranges)) + for _, ran := range ranges { + low, high, err := EncodeIndexKey(dctx, ran) + if err != nil { + return nil, err + } + rans = append(rans, &ranger.Range{ + LowVal: []types.Datum{types.NewBytesDatum(low)}, + HighVal: []types.Datum{types.NewBytesDatum(high)}, + LowExclude: false, + HighExclude: true, + Collators: collate.GetBinaryCollatorSlice(1), + }) + } + krs := make([][]kv.KeyRange, len(tids)) + for i := range krs { + krs[i] = make([]kv.KeyRange, 0, len(ranges)) + } + for _, ran := range rans { + low, high := ran.LowVal[0].GetBytes(), ran.HighVal[0].GetBytes() + if ran.LowExclude { + low = kv.Key(low).PrefixNext() + } + ran.LowVal[0].SetBytes(low) + for i, tid := range tids { + startKey := tablecodec.EncodeRowKey(tid, low) + endKey := tablecodec.EncodeRowKey(tid, high) + krs[i] = append(krs[i], kv.KeyRange{StartKey: startKey, EndKey: endKey}) + } + } + return kv.NewPartitionedKeyRanges(krs), nil +} + +// VerifyTxnScope verify whether the txnScope and visited physical table break the leader rule's dcLocation. +func VerifyTxnScope(txnScope string, physicalTableID int64, is infoschema.MetaOnlyInfoSchema) bool { + if txnScope == "" || txnScope == kv.GlobalTxnScope { + return true + } + bundle, ok := is.PlacementBundleByPhysicalTableID(physicalTableID) + if !ok { + return true + } + leaderDC, ok := bundle.GetLeaderDC(placement.DCLabelKey) + if !ok { + return true + } + if leaderDC != txnScope { + return false + } + return true +} + +func indexRangesToKVWithoutSplit(dctx *distsqlctx.DistSQLContext, tids []int64, idxID int64, ranges []*ranger.Range, memTracker *memory.Tracker, interruptSignal *atomic.Value) (*kv.KeyRanges, error) { + krs := make([][]kv.KeyRange, len(tids)) + for i := range krs { + krs[i] = make([]kv.KeyRange, 0, len(ranges)) + } + + const checkSignalStep = 8 + var estimatedMemUsage int64 + // encodeIndexKey and EncodeIndexSeekKey is time-consuming, thus we need to + // check the interrupt signal periodically. + for i, ran := range ranges { + low, high, err := EncodeIndexKey(dctx, ran) + if err != nil { + return nil, err + } + if i == 0 { + estimatedMemUsage += int64(cap(low) + cap(high)) + } + for j, tid := range tids { + startKey := tablecodec.EncodeIndexSeekKey(tid, idxID, low) + endKey := tablecodec.EncodeIndexSeekKey(tid, idxID, high) + if i == 0 { + estimatedMemUsage += int64(cap(startKey)) + int64(cap(endKey)) + } + krs[j] = append(krs[j], kv.KeyRange{StartKey: startKey, EndKey: endKey}) + } + if i%checkSignalStep == 0 { + if i == 0 && memTracker != nil { + estimatedMemUsage *= int64(len(ranges)) + memTracker.Consume(estimatedMemUsage) + } + if interruptSignal != nil && interruptSignal.Load().(bool) { + return kv.NewPartitionedKeyRanges(nil), nil + } + } + } + return kv.NewPartitionedKeyRanges(krs), nil +} + +// EncodeIndexKey gets encoded keys containing low and high +func EncodeIndexKey(dctx *distsqlctx.DistSQLContext, ran *ranger.Range) ([]byte, []byte, error) { + tz := time.UTC + errCtx := errctx.StrictNoWarningContext + if dctx != nil { + tz = dctx.Location + errCtx = dctx.ErrCtx + } + + low, err := codec.EncodeKey(tz, nil, ran.LowVal...) + err = errCtx.HandleError(err) + if err != nil { + return nil, nil, err + } + if ran.LowExclude { + low = kv.Key(low).PrefixNext() + } + high, err := codec.EncodeKey(tz, nil, ran.HighVal...) 
+ err = errCtx.HandleError(err) + if err != nil { + return nil, nil, err + } + + if !ran.HighExclude { + high = kv.Key(high).PrefixNext() + } + return low, high, nil +} + +// BuildTableRanges returns the key ranges encompassing the entire table, +// and its partitions if exists. +func BuildTableRanges(tbl *model.TableInfo) ([]kv.KeyRange, error) { + pis := tbl.GetPartitionInfo() + if pis == nil { + // Short path, no partition. + return appendRanges(tbl, tbl.ID) + } + + ranges := make([]kv.KeyRange, 0, len(pis.Definitions)*(len(tbl.Indices)+1)+1) + for _, def := range pis.Definitions { + rgs, err := appendRanges(tbl, def.ID) + if err != nil { + return nil, errors.Trace(err) + } + ranges = append(ranges, rgs...) + } + return ranges, nil +} + +func appendRanges(tbl *model.TableInfo, tblID int64) ([]kv.KeyRange, error) { + var ranges []*ranger.Range + if tbl.IsCommonHandle { + ranges = ranger.FullNotNullRange() + } else { + ranges = ranger.FullIntRange(false) + } + + retRanges := make([]kv.KeyRange, 0, 1+len(tbl.Indices)) + kvRanges, err := TableHandleRangesToKVRanges(nil, []int64{tblID}, tbl.IsCommonHandle, ranges) + if err != nil { + return nil, errors.Trace(err) + } + retRanges = kvRanges.AppendSelfTo(retRanges) + + for _, index := range tbl.Indices { + if index.State != model.StatePublic { + continue + } + ranges = ranger.FullRange() + idxRanges, err := IndexRangesToKVRanges(nil, tblID, index.ID, ranges) + if err != nil { + return nil, errors.Trace(err) + } + retRanges = idxRanges.AppendSelfTo(retRanges) + } + return retRanges, nil +} diff --git a/pkg/distsql/select_result.go b/pkg/distsql/select_result.go index 5d485d4cf1483..1f854a80ea236 100644 --- a/pkg/distsql/select_result.go +++ b/pkg/distsql/select_result.go @@ -402,11 +402,11 @@ func (r *selectResult) Next(ctx context.Context, chk *chunk.Chunk) error { // NextRaw returns the next raw partial result. func (r *selectResult) NextRaw(ctx context.Context) (data []byte, err error) { - failpoint.Inject("mockNextRawError", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockNextRawError")); _err_ == nil { if val.(bool) { - failpoint.Return(nil, errors.New("mockNextRawError")) + return nil, errors.New("mockNextRawError") } - }) + } resultSubset, err := r.resp.Next(ctx) r.partialCount++ diff --git a/pkg/distsql/select_result.go__failpoint_stash__ b/pkg/distsql/select_result.go__failpoint_stash__ new file mode 100644 index 0000000000000..5d485d4cf1483 --- /dev/null +++ b/pkg/distsql/select_result.go__failpoint_stash__ @@ -0,0 +1,815 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
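+
+// This __failpoint_stash__ file is the untouched copy of select_result.go that is
+// kept while failpoints are enabled (the convention used by failpoint-ctl, which
+// restores it on disable). In the active .go file above, marker calls such as
+//
+//	failpoint.Inject("mockNextRawError", func(val failpoint.Value) { ... })
+//
+// are rewritten into explicit evaluation:
+//
+//	if val, _err_ := failpoint.Eval(_curpkg_("mockNextRawError")); _err_ == nil { ... }
+//
+// where _curpkg_, generated per package in a binding__failpoint_binding__.go file,
+// prefixes the failpoint name with the package path.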
+ +package distsql + +import ( + "bytes" + "container/heap" + "context" + "fmt" + "strconv" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" + dcontext "github.com/pingcap/tidb/pkg/distsql/context" + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/planner/util" + "github.com/pingcap/tidb/pkg/store/copr" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tipb/go-tipb" + tikvmetrics "github.com/tikv/client-go/v2/metrics" + "github.com/tikv/client-go/v2/tikv" + clientutil "github.com/tikv/client-go/v2/util" + "go.uber.org/zap" + "golang.org/x/exp/maps" +) + +var ( + errQueryInterrupted = dbterror.ClassExecutor.NewStd(errno.ErrQueryInterrupted) +) + +var ( + _ SelectResult = (*selectResult)(nil) + _ SelectResult = (*serialSelectResults)(nil) + _ SelectResult = (*sortedSelectResults)(nil) +) + +// SelectResult is an iterator of coprocessor partial results. +type SelectResult interface { + // NextRaw gets the next raw result. + NextRaw(context.Context) ([]byte, error) + // Next reads the data into chunk. + Next(context.Context, *chunk.Chunk) error + // Close closes the iterator. + Close() error +} + +type chunkRowHeap struct { + *sortedSelectResults +} + +func (h chunkRowHeap) Len() int { + return len(h.rowPtrs) +} + +func (h chunkRowHeap) Less(i, j int) bool { + iPtr := h.rowPtrs[i] + jPtr := h.rowPtrs[j] + return h.lessRow(h.cachedChunks[iPtr.ChkIdx].GetRow(int(iPtr.RowIdx)), + h.cachedChunks[jPtr.ChkIdx].GetRow(int(jPtr.RowIdx))) +} + +func (h chunkRowHeap) Swap(i, j int) { + h.rowPtrs[i], h.rowPtrs[j] = h.rowPtrs[j], h.rowPtrs[i] +} + +func (h *chunkRowHeap) Push(x any) { + h.rowPtrs = append(h.rowPtrs, x.(chunk.RowPtr)) +} + +func (h *chunkRowHeap) Pop() any { + ret := h.rowPtrs[len(h.rowPtrs)-1] + h.rowPtrs = h.rowPtrs[0 : len(h.rowPtrs)-1] + return ret +} + +// NewSortedSelectResults is only for partition table +// If schema == nil, sort by first few columns. 
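+// It performs a k-way merge over the per-partition results: every SelectResult
+// keeps one cached chunk, and a min-heap ordered by byitems (chunkRowHeap) decides
+// which cached row is emitted next.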
+func NewSortedSelectResults(ectx expression.EvalContext, selectResult []SelectResult, schema *expression.Schema, byitems []*util.ByItems, memTracker *memory.Tracker) SelectResult { + s := &sortedSelectResults{ + schema: schema, + selectResult: selectResult, + byItems: byitems, + memTracker: memTracker, + } + s.initCompareFuncs(ectx) + s.buildKeyColumns() + s.heap = &chunkRowHeap{s} + s.cachedChunks = make([]*chunk.Chunk, len(selectResult)) + return s +} + +type sortedSelectResults struct { + schema *expression.Schema + selectResult []SelectResult + compareFuncs []chunk.CompareFunc + byItems []*util.ByItems + keyColumns []int + + cachedChunks []*chunk.Chunk + rowPtrs []chunk.RowPtr + heap *chunkRowHeap + + memTracker *memory.Tracker +} + +func (ssr *sortedSelectResults) updateCachedChunk(ctx context.Context, idx uint32) error { + prevMemUsage := ssr.cachedChunks[idx].MemoryUsage() + if err := ssr.selectResult[idx].Next(ctx, ssr.cachedChunks[idx]); err != nil { + return err + } + ssr.memTracker.Consume(ssr.cachedChunks[idx].MemoryUsage() - prevMemUsage) + if ssr.cachedChunks[idx].NumRows() == 0 { + return nil + } + heap.Push(ssr.heap, chunk.RowPtr{ChkIdx: idx, RowIdx: 0}) + return nil +} + +func (ssr *sortedSelectResults) initCompareFuncs(ectx expression.EvalContext) { + ssr.compareFuncs = make([]chunk.CompareFunc, len(ssr.byItems)) + for i, item := range ssr.byItems { + keyType := item.Expr.GetType(ectx) + ssr.compareFuncs[i] = chunk.GetCompareFunc(keyType) + } +} + +func (ssr *sortedSelectResults) buildKeyColumns() { + ssr.keyColumns = make([]int, 0, len(ssr.byItems)) + for i, by := range ssr.byItems { + col := by.Expr.(*expression.Column) + if ssr.schema == nil { + ssr.keyColumns = append(ssr.keyColumns, i) + } else { + ssr.keyColumns = append(ssr.keyColumns, ssr.schema.ColumnIndex(col)) + } + } +} + +func (ssr *sortedSelectResults) lessRow(rowI, rowJ chunk.Row) bool { + for i, colIdx := range ssr.keyColumns { + cmpFunc := ssr.compareFuncs[i] + cmp := cmpFunc(rowI, colIdx, rowJ, colIdx) + if ssr.byItems[i].Desc { + cmp = -cmp + } + if cmp < 0 { + return true + } else if cmp > 0 { + return false + } + } + return false +} + +func (*sortedSelectResults) NextRaw(context.Context) ([]byte, error) { + panic("Not support NextRaw for sortedSelectResults") +} + +func (ssr *sortedSelectResults) Next(ctx context.Context, c *chunk.Chunk) (err error) { + c.Reset() + for i := range ssr.cachedChunks { + if ssr.cachedChunks[i] == nil { + ssr.cachedChunks[i] = c.CopyConstruct() + ssr.memTracker.Consume(ssr.cachedChunks[i].MemoryUsage()) + } + } + + if ssr.heap.Len() == 0 { + for i := range ssr.cachedChunks { + if err = ssr.updateCachedChunk(ctx, uint32(i)); err != nil { + return err + } + } + } + + for c.NumRows() < c.RequiredRows() { + if ssr.heap.Len() == 0 { + break + } + + idx := heap.Pop(ssr.heap).(chunk.RowPtr) + c.AppendRow(ssr.cachedChunks[idx.ChkIdx].GetRow(int(idx.RowIdx))) + if int(idx.RowIdx) >= ssr.cachedChunks[idx.ChkIdx].NumRows()-1 { + if err = ssr.updateCachedChunk(ctx, idx.ChkIdx); err != nil { + return err + } + } else { + heap.Push(ssr.heap, chunk.RowPtr{ChkIdx: idx.ChkIdx, RowIdx: idx.RowIdx + 1}) + } + } + return nil +} + +func (ssr *sortedSelectResults) Close() (err error) { + for i, sr := range ssr.selectResult { + err = sr.Close() + if err != nil { + return err + } + ssr.memTracker.Consume(-ssr.cachedChunks[i].MemoryUsage()) + ssr.cachedChunks[i] = nil + } + return nil +} + +// NewSerialSelectResults create a SelectResult which will read each SelectResult serially. 
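+// Next and NextRaw drain selectResults[cur] completely before advancing cur, so
+// at most one underlying result is being read at any time.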
+func NewSerialSelectResults(selectResults []SelectResult) SelectResult { + return &serialSelectResults{ + selectResults: selectResults, + cur: 0, + } +} + +// serialSelectResults reads each SelectResult serially +type serialSelectResults struct { + selectResults []SelectResult + cur int +} + +func (ssr *serialSelectResults) NextRaw(ctx context.Context) ([]byte, error) { + for ssr.cur < len(ssr.selectResults) { + resultSubset, err := ssr.selectResults[ssr.cur].NextRaw(ctx) + if err != nil { + return nil, err + } + if len(resultSubset) > 0 { + return resultSubset, nil + } + ssr.cur++ // move to the next SelectResult + } + return nil, nil +} + +func (ssr *serialSelectResults) Next(ctx context.Context, chk *chunk.Chunk) error { + for ssr.cur < len(ssr.selectResults) { + if err := ssr.selectResults[ssr.cur].Next(ctx, chk); err != nil { + return err + } + if chk.NumRows() > 0 { + return nil + } + ssr.cur++ // move to the next SelectResult + } + return nil +} + +func (ssr *serialSelectResults) Close() (err error) { + for _, r := range ssr.selectResults { + if rerr := r.Close(); rerr != nil { + err = rerr + } + } + return +} + +type selectResult struct { + label string + resp kv.Response + + rowLen int + fieldTypes []*types.FieldType + ctx *dcontext.DistSQLContext + + selectResp *tipb.SelectResponse + selectRespSize int64 // record the selectResp.Size() when it is initialized. + respChkIdx int + respChunkDecoder *chunk.Decoder + + partialCount int64 // number of partial results. + sqlType string + + // copPlanIDs contains all copTasks' planIDs, + // which help to collect copTasks' runtime stats. + copPlanIDs []int + rootPlanID int + + storeType kv.StoreType + + fetchDuration time.Duration + durationReported bool + memTracker *memory.Tracker + + stats *selectResultRuntimeStats + // distSQLConcurrency and paging are only for collecting information, and they don't affect the process of execution. + distSQLConcurrency int + paging bool +} + +func (r *selectResult) fetchResp(ctx context.Context) error { + for { + r.respChkIdx = 0 + startTime := time.Now() + resultSubset, err := r.resp.Next(ctx) + duration := time.Since(startTime) + r.fetchDuration += duration + if err != nil { + return errors.Trace(err) + } + if r.selectResp != nil { + r.memConsume(-atomic.LoadInt64(&r.selectRespSize)) + } + if resultSubset == nil { + r.selectResp = nil + atomic.StoreInt64(&r.selectRespSize, 0) + if !r.durationReported { + // final round of fetch + // TODO: Add a label to distinguish between success or failure. 
+ // https://github.com/pingcap/tidb/issues/11397 + if r.paging { + metrics.DistSQLQueryHistogram.WithLabelValues(r.label, r.sqlType, "paging").Observe(r.fetchDuration.Seconds()) + } else { + metrics.DistSQLQueryHistogram.WithLabelValues(r.label, r.sqlType, "common").Observe(r.fetchDuration.Seconds()) + } + r.durationReported = true + } + return nil + } + r.selectResp = new(tipb.SelectResponse) + err = r.selectResp.Unmarshal(resultSubset.GetData()) + if err != nil { + return errors.Trace(err) + } + respSize := int64(r.selectResp.Size()) + atomic.StoreInt64(&r.selectRespSize, respSize) + r.memConsume(respSize) + if err := r.selectResp.Error; err != nil { + return dbterror.ClassTiKV.Synthesize(terror.ErrCode(err.Code), err.Msg) + } + if err = r.ctx.SQLKiller.HandleSignal(); err != nil { + return err + } + for _, warning := range r.selectResp.Warnings { + r.ctx.AppendWarning(dbterror.ClassTiKV.Synthesize(terror.ErrCode(warning.Code), warning.Msg)) + } + + r.partialCount++ + + hasStats, ok := resultSubset.(CopRuntimeStats) + if ok { + copStats := hasStats.GetCopRuntimeStats() + if copStats != nil { + if err := r.updateCopRuntimeStats(ctx, copStats, resultSubset.RespTime()); err != nil { + return err + } + copStats.CopTime = duration + r.ctx.ExecDetails.MergeExecDetails(&copStats.ExecDetails, nil) + } + } + if len(r.selectResp.Chunks) != 0 { + break + } + } + return nil +} + +func (r *selectResult) Next(ctx context.Context, chk *chunk.Chunk) error { + chk.Reset() + if r.selectResp == nil || r.respChkIdx == len(r.selectResp.Chunks) { + err := r.fetchResp(ctx) + if err != nil { + return err + } + if r.selectResp == nil { + return nil + } + } + // TODO(Shenghui Wu): add metrics + encodeType := r.selectResp.GetEncodeType() + switch encodeType { + case tipb.EncodeType_TypeDefault: + return r.readFromDefault(ctx, chk) + case tipb.EncodeType_TypeChunk: + return r.readFromChunk(ctx, chk) + } + return errors.Errorf("unsupported encode type:%v", encodeType) +} + +// NextRaw returns the next raw partial result. +func (r *selectResult) NextRaw(ctx context.Context) (data []byte, err error) { + failpoint.Inject("mockNextRawError", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(nil, errors.New("mockNextRawError")) + } + }) + + resultSubset, err := r.resp.Next(ctx) + r.partialCount++ + if resultSubset != nil && err == nil { + data = resultSubset.GetData() + } + return data, err +} + +func (r *selectResult) readFromDefault(ctx context.Context, chk *chunk.Chunk) error { + for !chk.IsFull() { + if r.respChkIdx == len(r.selectResp.Chunks) { + err := r.fetchResp(ctx) + if err != nil || r.selectResp == nil { + return err + } + } + err := r.readRowsData(chk) + if err != nil { + return err + } + if len(r.selectResp.Chunks[r.respChkIdx].RowsData) == 0 { + r.respChkIdx++ + } + } + return nil +} + +func (r *selectResult) readFromChunk(ctx context.Context, chk *chunk.Chunk) error { + if r.respChunkDecoder == nil { + r.respChunkDecoder = chunk.NewDecoder( + chunk.NewChunkWithCapacity(r.fieldTypes, 0), + r.fieldTypes, + ) + } + + for !chk.IsFull() { + if r.respChkIdx == len(r.selectResp.Chunks) { + err := r.fetchResp(ctx) + if err != nil || r.selectResp == nil { + return err + } + } + + if r.respChunkDecoder.IsFinished() { + r.respChunkDecoder.Reset(r.selectResp.Chunks[r.respChkIdx].RowsData) + } + // If the next chunk size is greater than required rows * 0.8, reuse the memory of the next chunk and return + // immediately. Otherwise, splice the data to one chunk and wait the next chunk. 
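+		// Reuse only happens while chk is still empty, so ReuseIntermChk can hand over
+		// the decoder's buffers without copying or losing already appended rows.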
+ if r.respChunkDecoder.RemainedRows() > int(float64(chk.RequiredRows())*0.8) { + if chk.NumRows() > 0 { + return nil + } + r.respChunkDecoder.ReuseIntermChk(chk) + r.respChkIdx++ + return nil + } + r.respChunkDecoder.Decode(chk) + if r.respChunkDecoder.IsFinished() { + r.respChkIdx++ + } + } + return nil +} + +// FillDummySummariesForTiFlashTasks fills dummy execution summaries for mpp tasks which lack summaries +func FillDummySummariesForTiFlashTasks(runtimeStatsColl *execdetails.RuntimeStatsColl, callee string, storeTypeName string, allPlanIDs []int, recordedPlanIDs map[int]int) { + num := uint64(0) + dummySummary := &tipb.ExecutorExecutionSummary{TimeProcessedNs: &num, NumProducedRows: &num, NumIterations: &num, ExecutorId: nil} + for _, planID := range allPlanIDs { + if _, ok := recordedPlanIDs[planID]; !ok { + runtimeStatsColl.RecordOneCopTask(planID, storeTypeName, callee, dummySummary) + } + } +} + +// recordExecutionSummariesForTiFlashTasks records mpp task execution summaries +func recordExecutionSummariesForTiFlashTasks(runtimeStatsColl *execdetails.RuntimeStatsColl, executionSummaries []*tipb.ExecutorExecutionSummary, callee string, storeTypeName string, allPlanIDs []int) { + var recordedPlanIDs = make(map[int]int) + for _, detail := range executionSummaries { + if detail != nil && detail.TimeProcessedNs != nil && + detail.NumProducedRows != nil && detail.NumIterations != nil { + recordedPlanIDs[runtimeStatsColl. + RecordOneCopTask(-1, storeTypeName, callee, detail)] = 0 + } + } + FillDummySummariesForTiFlashTasks(runtimeStatsColl, callee, storeTypeName, allPlanIDs, recordedPlanIDs) +} + +func (r *selectResult) updateCopRuntimeStats(ctx context.Context, copStats *copr.CopRuntimeStats, respTime time.Duration) (err error) { + callee := copStats.CalleeAddress + if r.rootPlanID <= 0 || r.ctx.RuntimeStatsColl == nil || (callee == "" && (copStats.ReqStats == nil || len(copStats.ReqStats.RPCStats) == 0)) { + return + } + + if copStats.ScanDetail != nil { + readKeys := copStats.ScanDetail.ProcessedKeys + readTime := copStats.TimeDetail.KvReadWallTime.Seconds() + readSize := float64(copStats.ScanDetail.ProcessedKeysSize) + tikvmetrics.ObserveReadSLI(uint64(readKeys), readTime, readSize) + } + + if r.stats == nil { + r.stats = &selectResultRuntimeStats{ + backoffSleep: make(map[string]time.Duration), + reqStat: tikv.NewRegionRequestRuntimeStats(), + distSQLConcurrency: r.distSQLConcurrency, + } + if ci, ok := r.resp.(copr.CopInfo); ok { + conc, extraConc := ci.GetConcurrency() + r.stats.distSQLConcurrency = conc + r.stats.extraConcurrency = extraConc + } + } + r.stats.mergeCopRuntimeStats(copStats, respTime) + + if copStats.ScanDetail != nil && len(r.copPlanIDs) > 0 { + r.ctx.RuntimeStatsColl.RecordScanDetail(r.copPlanIDs[len(r.copPlanIDs)-1], r.storeType.Name(), copStats.ScanDetail) + } + if len(r.copPlanIDs) > 0 { + r.ctx.RuntimeStatsColl.RecordTimeDetail(r.copPlanIDs[len(r.copPlanIDs)-1], r.storeType.Name(), &copStats.TimeDetail) + } + + // If hasExecutor is true, it means the summary is returned from TiFlash. 
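+	// TiFlash summaries carry an ExecutorId, whereas TiKV cop-task summaries are
+	// matched to plans positionally via copPlanIDs in the else branch below.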
+ hasExecutor := false + for _, detail := range r.selectResp.GetExecutionSummaries() { + if detail != nil && detail.TimeProcessedNs != nil && + detail.NumProducedRows != nil && detail.NumIterations != nil { + if detail.ExecutorId != nil { + hasExecutor = true + } + break + } + } + + if ruDetailsRaw := ctx.Value(clientutil.RUDetailsCtxKey); ruDetailsRaw != nil && r.storeType == kv.TiFlash { + if err = execdetails.MergeTiFlashRUConsumption(r.selectResp.GetExecutionSummaries(), ruDetailsRaw.(*clientutil.RUDetails)); err != nil { + return err + } + } + if hasExecutor { + recordExecutionSummariesForTiFlashTasks(r.ctx.RuntimeStatsColl, r.selectResp.GetExecutionSummaries(), callee, r.storeType.Name(), r.copPlanIDs) + } else { + // For cop task cases, we still need this protection. + if len(r.selectResp.GetExecutionSummaries()) != len(r.copPlanIDs) { + // for TiFlash streaming call(BatchCop and MPP), it is by design that only the last response will + // carry the execution summaries, so it is ok if some responses have no execution summaries, should + // not trigger an error log in this case. + if !(r.storeType == kv.TiFlash && len(r.selectResp.GetExecutionSummaries()) == 0) { + logutil.Logger(ctx).Error("invalid cop task execution summaries length", + zap.Int("expected", len(r.copPlanIDs)), + zap.Int("received", len(r.selectResp.GetExecutionSummaries()))) + } + return + } + for i, detail := range r.selectResp.GetExecutionSummaries() { + if detail != nil && detail.TimeProcessedNs != nil && + detail.NumProducedRows != nil && detail.NumIterations != nil { + planID := r.copPlanIDs[i] + r.ctx.RuntimeStatsColl. + RecordOneCopTask(planID, r.storeType.Name(), callee, detail) + } + } + } + return +} + +func (r *selectResult) readRowsData(chk *chunk.Chunk) (err error) { + rowsData := r.selectResp.Chunks[r.respChkIdx].RowsData + decoder := codec.NewDecoder(chk, r.ctx.Location) + for !chk.IsFull() && len(rowsData) > 0 { + for i := 0; i < r.rowLen; i++ { + rowsData, err = decoder.DecodeOne(rowsData, i, r.fieldTypes[i]) + if err != nil { + return err + } + } + } + r.selectResp.Chunks[r.respChkIdx].RowsData = rowsData + return nil +} + +func (r *selectResult) memConsume(bytes int64) { + if r.memTracker != nil { + r.memTracker.Consume(bytes) + } +} + +// Close closes selectResult. +func (r *selectResult) Close() error { + metrics.DistSQLPartialCountHistogram.Observe(float64(r.partialCount)) + respSize := atomic.SwapInt64(&r.selectRespSize, 0) + if respSize > 0 { + r.memConsume(-respSize) + } + if r.ctx != nil { + if unconsumed, ok := r.resp.(copr.HasUnconsumedCopRuntimeStats); ok && unconsumed != nil { + unconsumedCopStats := unconsumed.CollectUnconsumedCopRuntimeStats() + for _, copStats := range unconsumedCopStats { + _ = r.updateCopRuntimeStats(context.Background(), copStats, time.Duration(0)) + r.ctx.ExecDetails.MergeExecDetails(&copStats.ExecDetails, nil) + } + } + } + if r.stats != nil && r.ctx != nil { + defer func() { + if ci, ok := r.resp.(copr.CopInfo); ok { + r.stats.buildTaskDuration = ci.GetBuildTaskElapsed() + batched, fallback := ci.GetStoreBatchInfo() + if batched != 0 || fallback != 0 { + r.stats.storeBatchedNum, r.stats.storeBatchedFallbackNum = batched, fallback + } + } + r.ctx.RuntimeStatsColl.RegisterStats(r.rootPlanID, r.stats) + }() + } + return r.resp.Close() +} + +// CopRuntimeStats is an interface uses to check whether the result has cop runtime stats. +type CopRuntimeStats interface { + // GetCopRuntimeStats gets the cop runtime stats information. 
+ GetCopRuntimeStats() *copr.CopRuntimeStats +} + +type selectResultRuntimeStats struct { + copRespTime execdetails.Percentile[execdetails.Duration] + procKeys execdetails.Percentile[execdetails.Int64] + backoffSleep map[string]time.Duration + totalProcessTime time.Duration + totalWaitTime time.Duration + reqStat *tikv.RegionRequestRuntimeStats + distSQLConcurrency int + extraConcurrency int + CoprCacheHitNum int64 + storeBatchedNum uint64 + storeBatchedFallbackNum uint64 + buildTaskDuration time.Duration +} + +func (s *selectResultRuntimeStats) mergeCopRuntimeStats(copStats *copr.CopRuntimeStats, respTime time.Duration) { + s.copRespTime.Add(execdetails.Duration(respTime)) + if copStats.ScanDetail != nil { + s.procKeys.Add(execdetails.Int64(copStats.ScanDetail.ProcessedKeys)) + } else { + s.procKeys.Add(0) + } + maps.Copy(s.backoffSleep, copStats.BackoffSleep) + s.totalProcessTime += copStats.TimeDetail.ProcessTime + s.totalWaitTime += copStats.TimeDetail.WaitTime + s.reqStat.Merge(copStats.ReqStats) + if copStats.CoprCacheHit { + s.CoprCacheHitNum++ + } +} + +func (s *selectResultRuntimeStats) Clone() execdetails.RuntimeStats { + newRs := selectResultRuntimeStats{ + copRespTime: execdetails.Percentile[execdetails.Duration]{}, + procKeys: execdetails.Percentile[execdetails.Int64]{}, + backoffSleep: make(map[string]time.Duration, len(s.backoffSleep)), + reqStat: tikv.NewRegionRequestRuntimeStats(), + distSQLConcurrency: s.distSQLConcurrency, + extraConcurrency: s.extraConcurrency, + CoprCacheHitNum: s.CoprCacheHitNum, + storeBatchedNum: s.storeBatchedNum, + storeBatchedFallbackNum: s.storeBatchedFallbackNum, + buildTaskDuration: s.buildTaskDuration, + } + newRs.copRespTime.MergePercentile(&s.copRespTime) + newRs.procKeys.MergePercentile(&s.procKeys) + for k, v := range s.backoffSleep { + newRs.backoffSleep[k] += v + } + newRs.totalProcessTime += s.totalProcessTime + newRs.totalWaitTime += s.totalWaitTime + newRs.reqStat = s.reqStat.Clone() + return &newRs +} + +func (s *selectResultRuntimeStats) Merge(rs execdetails.RuntimeStats) { + other, ok := rs.(*selectResultRuntimeStats) + if !ok { + return + } + s.copRespTime.MergePercentile(&other.copRespTime) + s.procKeys.MergePercentile(&other.procKeys) + + for k, v := range other.backoffSleep { + s.backoffSleep[k] += v + } + s.totalProcessTime += other.totalProcessTime + s.totalWaitTime += other.totalWaitTime + s.reqStat.Merge(other.reqStat) + s.CoprCacheHitNum += other.CoprCacheHitNum + if other.distSQLConcurrency > s.distSQLConcurrency { + s.distSQLConcurrency = other.distSQLConcurrency + } + if other.extraConcurrency > s.extraConcurrency { + s.extraConcurrency = other.extraConcurrency + } + s.storeBatchedNum += other.storeBatchedNum + s.storeBatchedFallbackNum += other.storeBatchedFallbackNum + s.buildTaskDuration += other.buildTaskDuration +} + +func (s *selectResultRuntimeStats) String() string { + buf := bytes.NewBuffer(nil) + reqStat := s.reqStat + if s.copRespTime.Size() > 0 { + size := s.copRespTime.Size() + if size == 1 { + fmt.Fprintf(buf, "cop_task: {num: 1, max: %v, proc_keys: %v", execdetails.FormatDuration(time.Duration(s.copRespTime.GetPercentile(0))), s.procKeys.GetPercentile(0)) + } else { + vMax, vMin := s.copRespTime.GetMax(), s.copRespTime.GetMin() + vP95 := s.copRespTime.GetPercentile(0.95) + sum := s.copRespTime.Sum() + vAvg := time.Duration(sum / float64(size)) + + keyMax := s.procKeys.GetMax() + keyP95 := s.procKeys.GetPercentile(0.95) + fmt.Fprintf(buf, "cop_task: {num: %v, max: %v, min: %v, avg: %v, p95: %v", size, + 
execdetails.FormatDuration(time.Duration(vMax.GetFloat64())), execdetails.FormatDuration(time.Duration(vMin.GetFloat64())), + execdetails.FormatDuration(vAvg), execdetails.FormatDuration(time.Duration(vP95))) + if keyMax > 0 { + buf.WriteString(", max_proc_keys: ") + buf.WriteString(strconv.FormatInt(int64(keyMax), 10)) + buf.WriteString(", p95_proc_keys: ") + buf.WriteString(strconv.FormatInt(int64(keyP95), 10)) + } + } + if s.totalProcessTime > 0 { + buf.WriteString(", tot_proc: ") + buf.WriteString(execdetails.FormatDuration(s.totalProcessTime)) + if s.totalWaitTime > 0 { + buf.WriteString(", tot_wait: ") + buf.WriteString(execdetails.FormatDuration(s.totalWaitTime)) + } + } + if config.GetGlobalConfig().TiKVClient.CoprCache.CapacityMB > 0 { + fmt.Fprintf(buf, ", copr_cache_hit_ratio: %v", + strconv.FormatFloat(s.calcCacheHit(), 'f', 2, 64)) + } else { + buf.WriteString(", copr_cache: disabled") + } + if s.buildTaskDuration > 0 { + buf.WriteString(", build_task_duration: ") + buf.WriteString(execdetails.FormatDuration(s.buildTaskDuration)) + } + if s.distSQLConcurrency > 0 { + buf.WriteString(", max_distsql_concurrency: ") + buf.WriteString(strconv.FormatInt(int64(s.distSQLConcurrency), 10)) + } + if s.extraConcurrency > 0 { + buf.WriteString(", max_extra_concurrency: ") + buf.WriteString(strconv.FormatInt(int64(s.extraConcurrency), 10)) + } + if s.storeBatchedNum > 0 { + buf.WriteString(", store_batch_num: ") + buf.WriteString(strconv.FormatInt(int64(s.storeBatchedNum), 10)) + } + if s.storeBatchedFallbackNum > 0 { + buf.WriteString(", store_batch_fallback_num: ") + buf.WriteString(strconv.FormatInt(int64(s.storeBatchedFallbackNum), 10)) + } + buf.WriteString("}") + } + + rpcStatsStr := reqStat.String() + if len(rpcStatsStr) > 0 { + buf.WriteString(", rpc_info:{") + buf.WriteString(rpcStatsStr) + buf.WriteString("}") + } + + if len(s.backoffSleep) > 0 { + buf.WriteString(", backoff{") + idx := 0 + for k, d := range s.backoffSleep { + if idx > 0 { + buf.WriteString(", ") + } + idx++ + fmt.Fprintf(buf, "%s: %s", k, execdetails.FormatDuration(d)) + } + buf.WriteString("}") + } + return buf.String() +} + +// Tp implements the RuntimeStats interface. 
+func (*selectResultRuntimeStats) Tp() int { + return execdetails.TpSelectResultRuntimeStats +} + +func (s *selectResultRuntimeStats) calcCacheHit() float64 { + hit := s.CoprCacheHitNum + tot := s.copRespTime.Size() + if s.storeBatchedNum > 0 { + tot += int(s.storeBatchedNum) + } + if tot == 0 { + return 0 + } + return float64(hit) / float64(tot) +} diff --git a/pkg/disttask/framework/scheduler/binding__failpoint_binding__.go b/pkg/disttask/framework/scheduler/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..c960860aa5565 --- /dev/null +++ b/pkg/disttask/framework/scheduler/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package scheduler + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/disttask/framework/scheduler/nodes.go b/pkg/disttask/framework/scheduler/nodes.go index 7a8dc8cb401c4..99d81656a4804 100644 --- a/pkg/disttask/framework/scheduler/nodes.go +++ b/pkg/disttask/framework/scheduler/nodes.go @@ -150,9 +150,9 @@ func (nm *NodeManager) refreshNodes(ctx context.Context, taskMgr TaskManager, sl slotMgr.updateCapacity(cpuCount) nm.nodes.Store(&newNodes) - failpoint.Inject("syncRefresh", func() { + if _, _err_ := failpoint.Eval(_curpkg_("syncRefresh")); _err_ == nil { TestRefreshedChan <- struct{}{} - }) + } } // GetNodes returns the nodes managed by the framework. diff --git a/pkg/disttask/framework/scheduler/nodes.go__failpoint_stash__ b/pkg/disttask/framework/scheduler/nodes.go__failpoint_stash__ new file mode 100644 index 0000000000000..7a8dc8cb401c4 --- /dev/null +++ b/pkg/disttask/framework/scheduler/nodes.go__failpoint_stash__ @@ -0,0 +1,191 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package scheduler + +import ( + "context" + "sync/atomic" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + llog "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/util/intest" + "go.uber.org/zap" +) + +var ( + // liveNodesCheckInterval is the tick interval of fetching all server infos from etcs. + nodesCheckInterval = 2 * CheckTaskFinishedInterval +) + +// NodeManager maintains live TiDB nodes in the cluster, and maintains the nodes +// managed by the framework. +type NodeManager struct { + logger *zap.Logger + // prevLiveNodes is used to record the live nodes in last checking. + prevLiveNodes map[string]struct{} + // nodes is the cached nodes managed by the framework. + // see TaskManager.GetNodes for more details. 
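+	// The slice is swapped atomically by refreshNodes, so getNodes can read it
+	// without taking a lock.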
+ nodes atomic.Pointer[[]proto.ManagedNode] +} + +func newNodeManager(serverID string) *NodeManager { + logger := log.L() + if intest.InTest { + logger = log.L().With(zap.String("server-id", serverID)) + } + nm := &NodeManager{ + logger: logger, + prevLiveNodes: make(map[string]struct{}), + } + nodes := make([]proto.ManagedNode, 0, 10) + nm.nodes.Store(&nodes) + return nm +} + +func (nm *NodeManager) maintainLiveNodesLoop(ctx context.Context, taskMgr TaskManager) { + ticker := time.NewTicker(nodesCheckInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + nm.maintainLiveNodes(ctx, taskMgr) + } + } +} + +// maintainLiveNodes manages live node info in dist_framework_meta table +// see recoverMetaLoop in task executor for when node is inserted into dist_framework_meta. +func (nm *NodeManager) maintainLiveNodes(ctx context.Context, taskMgr TaskManager) { + // Safe to discard errors since this function can be called at regular intervals. + liveExecIDs, err := GetLiveExecIDs(ctx) + if err != nil { + nm.logger.Warn("generate task executor nodes met error", llog.ShortError(err)) + return + } + nodeChanged := len(liveExecIDs) != len(nm.prevLiveNodes) + currLiveNodes := make(map[string]struct{}, len(liveExecIDs)) + for _, execID := range liveExecIDs { + if _, ok := nm.prevLiveNodes[execID]; !ok { + nodeChanged = true + } + currLiveNodes[execID] = struct{}{} + } + if !nodeChanged { + return + } + + oldNodes, err := taskMgr.GetAllNodes(ctx) + if err != nil { + nm.logger.Warn("get all nodes met error", llog.ShortError(err)) + return + } + + deadNodes := make([]string, 0) + for _, node := range oldNodes { + if _, ok := currLiveNodes[node.ID]; !ok { + deadNodes = append(deadNodes, node.ID) + } + } + if len(deadNodes) == 0 { + nm.prevLiveNodes = currLiveNodes + return + } + nm.logger.Info("delete dead nodes from dist_framework_meta", + zap.Strings("dead-nodes", deadNodes)) + err = taskMgr.DeleteDeadNodes(ctx, deadNodes) + if err != nil { + nm.logger.Warn("delete dead nodes from dist_framework_meta failed", llog.ShortError(err)) + return + } + nm.prevLiveNodes = currLiveNodes +} + +func (nm *NodeManager) refreshNodesLoop(ctx context.Context, taskMgr TaskManager, slotMgr *SlotManager) { + ticker := time.NewTicker(nodesCheckInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + nm.refreshNodes(ctx, taskMgr, slotMgr) + } + } +} + +// TestRefreshedChan is used to sync the test. +var TestRefreshedChan = make(chan struct{}) + +// refreshNodes maintains the nodes managed by the framework. +func (nm *NodeManager) refreshNodes(ctx context.Context, taskMgr TaskManager, slotMgr *SlotManager) { + newNodes, err := taskMgr.GetAllNodes(ctx) + if err != nil { + nm.logger.Warn("get managed nodes met error", llog.ShortError(err)) + return + } + + var cpuCount int + for _, node := range newNodes { + if node.CPUCount > 0 { + cpuCount = node.CPUCount + } + } + slotMgr.updateCapacity(cpuCount) + nm.nodes.Store(&newNodes) + + failpoint.Inject("syncRefresh", func() { + TestRefreshedChan <- struct{}{} + }) +} + +// GetNodes returns the nodes managed by the framework. +// return a copy of the nodes. 
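+// Returning a copy keeps callers from mutating the slice stored by refreshNodes.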
+func (nm *NodeManager) getNodes() []proto.ManagedNode { + nodes := *nm.nodes.Load() + res := make([]proto.ManagedNode, len(nodes)) + copy(res, nodes) + return res +} + +func filterByScope(nodes []proto.ManagedNode, targetScope string) []string { + var nodeIDs []string + haveBackground := false + for _, node := range nodes { + if node.Role == "background" { + haveBackground = true + } + } + // prefer to use "background" node instead of "" node. + if targetScope == "" && haveBackground { + for _, node := range nodes { + if node.Role == "background" { + nodeIDs = append(nodeIDs, node.ID) + } + } + return nodeIDs + } + + for _, node := range nodes { + if node.Role == targetScope { + nodeIDs = append(nodeIDs, node.ID) + } + } + return nodeIDs +} diff --git a/pkg/disttask/framework/scheduler/scheduler.go b/pkg/disttask/framework/scheduler/scheduler.go index 476339645762d..39f3be3a48134 100644 --- a/pkg/disttask/framework/scheduler/scheduler.go +++ b/pkg/disttask/framework/scheduler/scheduler.go @@ -177,19 +177,19 @@ func (s *BaseScheduler) scheduleTask() { } task := *s.GetTask() // TODO: refine failpoints below. - failpoint.Inject("exitScheduler", func() { - failpoint.Return() - }) - failpoint.Inject("cancelTaskAfterRefreshTask", func(val failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("exitScheduler")); _err_ == nil { + return + } + if val, _err_ := failpoint.Eval(_curpkg_("cancelTaskAfterRefreshTask")); _err_ == nil { if val.(bool) && task.State == proto.TaskStateRunning { err := s.taskMgr.CancelTask(s.ctx, task.ID) if err != nil { s.logger.Error("cancel task failed", zap.Error(err)) } } - }) + } - failpoint.Inject("pausePendingTask", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("pausePendingTask")); _err_ == nil { if val.(bool) && task.State == proto.TaskStatePending { _, err := s.taskMgr.PauseTask(s.ctx, task.Key) if err != nil { @@ -198,9 +198,9 @@ func (s *BaseScheduler) scheduleTask() { task.State = proto.TaskStatePausing s.task.Store(&task) } - }) + } - failpoint.Inject("pauseTaskAfterRefreshTask", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("pauseTaskAfterRefreshTask")); _err_ == nil { if val.(bool) && task.State == proto.TaskStateRunning { _, err := s.taskMgr.PauseTask(s.ctx, task.Key) if err != nil { @@ -209,7 +209,7 @@ func (s *BaseScheduler) scheduleTask() { task.State = proto.TaskStatePausing s.task.Store(&task) } - }) + } switch task.State { case proto.TaskStateCancelling: @@ -261,7 +261,7 @@ func (s *BaseScheduler) scheduleTask() { s.logger.Info("schedule task meet err, reschedule it", zap.Error(err)) } - failpoint.InjectCall("mockOwnerChange") + failpoint.Call(_curpkg_("mockOwnerChange")) } } } @@ -302,7 +302,7 @@ func (s *BaseScheduler) onPausing() error { func (s *BaseScheduler) onPaused() error { task := s.GetTask() s.logger.Info("on paused state", zap.Stringer("state", task.State), zap.String("step", proto.Step2Str(task.Type, task.Step))) - failpoint.InjectCall("mockDMLExecutionOnPausedState") + failpoint.Call(_curpkg_("mockDMLExecutionOnPausedState")) return nil } @@ -483,9 +483,9 @@ func (s *BaseScheduler) scheduleSubTask( size += uint64(len(meta)) } - failpoint.Inject("cancelBeforeUpdateTask", func() { + if _, _err_ := failpoint.Eval(_curpkg_("cancelBeforeUpdateTask")); _err_ == nil { _ = s.taskMgr.CancelTask(s.ctx, task.ID) - }) + } // as other fields and generated key and index KV takes space too, we limit // the size of subtasks to 80% of the transaction limit. 
@@ -537,9 +537,9 @@ var MockServerInfo atomic.Pointer[[]string] // GetLiveExecIDs returns all live executor node IDs. func GetLiveExecIDs(ctx context.Context) ([]string, error) { - failpoint.Inject("mockTaskExecutorNodes", func() { - failpoint.Return(*MockServerInfo.Load(), nil) - }) + if _, _err_ := failpoint.Eval(_curpkg_("mockTaskExecutorNodes")); _err_ == nil { + return *MockServerInfo.Load(), nil + } serverInfos, err := generateTaskExecutorNodes(ctx) if err != nil { return nil, err diff --git a/pkg/disttask/framework/scheduler/scheduler.go__failpoint_stash__ b/pkg/disttask/framework/scheduler/scheduler.go__failpoint_stash__ new file mode 100644 index 0000000000000..476339645762d --- /dev/null +++ b/pkg/disttask/framework/scheduler/scheduler.go__failpoint_stash__ @@ -0,0 +1,624 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package scheduler + +import ( + "context" + "math/rand" + "strings" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + "github.com/pingcap/tidb/pkg/disttask/framework/handle" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/disttask/framework/storage" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/util/backoff" + disttaskutil "github.com/pingcap/tidb/pkg/util/disttask" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.uber.org/zap" +) + +const ( + // for a cancelled task, it's terminal state is reverted or reverted_failed, + // so we use a special error message to indicate that the task is cancelled + // by user. + taskCancelMsg = "cancelled by user" +) + +var ( + // CheckTaskFinishedInterval is the interval for scheduler. + // exported for testing. + CheckTaskFinishedInterval = 500 * time.Millisecond + // RetrySQLTimes is the max retry times when executing SQL. + RetrySQLTimes = 30 + // RetrySQLInterval is the initial interval between two SQL retries. + RetrySQLInterval = 3 * time.Second + // RetrySQLMaxInterval is the max interval between two SQL retries. + RetrySQLMaxInterval = 30 * time.Second +) + +// Scheduler manages the lifetime of a task +// including submitting subtasks and updating the status of a task. +type Scheduler interface { + // Init initializes the scheduler, should be called before ExecuteTask. + // if Init returns error, scheduler manager will fail the task directly, + // so the returned error should be a fatal error. + Init() error + // ScheduleTask schedules the task execution step by step. + ScheduleTask() + // Close closes the scheduler, should be called if Init returns nil. + Close() + // GetTask returns the task that the scheduler is managing. + GetTask() *proto.Task + Extension +} + +// BaseScheduler is the base struct for Scheduler. 
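+// It drives the per-task state machine in scheduleTask and persists every state
+// transition through the TaskManager.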
+// each task type embed this struct and implement the Extension interface. +type BaseScheduler struct { + ctx context.Context + Param + // task might be accessed by multiple goroutines, so don't change its fields + // directly, make a copy, update and store it back to the atomic pointer. + task atomic.Pointer[proto.Task] + logger *zap.Logger + // when RegisterSchedulerFactory, the factory MUST initialize this fields. + Extension + + balanceSubtaskTick int + // rand is for generating random selection of nodes. + rand *rand.Rand +} + +// NewBaseScheduler creates a new BaseScheduler. +func NewBaseScheduler(ctx context.Context, task *proto.Task, param Param) *BaseScheduler { + logger := log.L().With(zap.Int64("task-id", task.ID), zap.Stringer("task-type", task.Type), zap.Bool("allocated-slots", param.allocatedSlots)) + if intest.InTest { + logger = logger.With(zap.String("server-id", param.serverID)) + } + s := &BaseScheduler{ + ctx: ctx, + Param: param, + logger: logger, + rand: rand.New(rand.NewSource(time.Now().UnixNano())), + } + s.task.Store(task) + return s +} + +// Init implements the Scheduler interface. +func (*BaseScheduler) Init() error { + return nil +} + +// ScheduleTask implements the Scheduler interface. +func (s *BaseScheduler) ScheduleTask() { + task := s.GetTask() + s.logger.Info("schedule task", + zap.Stringer("state", task.State), zap.Int("concurrency", task.Concurrency)) + s.scheduleTask() +} + +// Close closes the scheduler. +func (*BaseScheduler) Close() { +} + +// GetTask implements the Scheduler interface. +func (s *BaseScheduler) GetTask() *proto.Task { + return s.task.Load() +} + +// refreshTaskIfNeeded fetch task state from tidb_global_task table. +func (s *BaseScheduler) refreshTaskIfNeeded() error { + task := s.GetTask() + // we only query the base fields of task to reduce memory usage, other fields + // are refreshed when needed. + newTaskBase, err := s.taskMgr.GetTaskBaseByID(s.ctx, task.ID) + if err != nil { + return err + } + // state might be changed by user to pausing/resuming/cancelling, or + // in case of network partition, state/step/meta might be changed by other scheduler, + // in both cases we refresh the whole task object. + if newTaskBase.State != task.State || newTaskBase.Step != task.Step { + s.logger.Info("task state/step changed by user or other scheduler", + zap.Stringer("old-state", task.State), + zap.Stringer("new-state", newTaskBase.State), + zap.String("old-step", proto.Step2Str(task.Type, task.Step)), + zap.String("new-step", proto.Step2Str(task.Type, newTaskBase.Step))) + newTask, err := s.taskMgr.GetTaskByID(s.ctx, task.ID) + if err != nil { + return err + } + s.task.Store(newTask) + } + return nil +} + +// scheduleTask schedule the task execution step by step. +func (s *BaseScheduler) scheduleTask() { + ticker := time.NewTicker(CheckTaskFinishedInterval) + defer ticker.Stop() + for { + select { + case <-s.ctx.Done(): + s.logger.Info("schedule task exits") + return + case <-ticker.C: + err := s.refreshTaskIfNeeded() + if err != nil { + if errors.Cause(err) == storage.ErrTaskNotFound { + // this can happen when task is reverted/succeed, but before + // we reach here, cleanup routine move it to history. + return + } + s.logger.Error("refresh task failed", zap.Error(err)) + continue + } + task := *s.GetTask() + // TODO: refine failpoints below. 
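+			// These failpoints let tests exit the scheduler or force a cancel/pause right
+			// after the task is refreshed, before the state switch below runs.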
+ failpoint.Inject("exitScheduler", func() { + failpoint.Return() + }) + failpoint.Inject("cancelTaskAfterRefreshTask", func(val failpoint.Value) { + if val.(bool) && task.State == proto.TaskStateRunning { + err := s.taskMgr.CancelTask(s.ctx, task.ID) + if err != nil { + s.logger.Error("cancel task failed", zap.Error(err)) + } + } + }) + + failpoint.Inject("pausePendingTask", func(val failpoint.Value) { + if val.(bool) && task.State == proto.TaskStatePending { + _, err := s.taskMgr.PauseTask(s.ctx, task.Key) + if err != nil { + s.logger.Error("pause task failed", zap.Error(err)) + } + task.State = proto.TaskStatePausing + s.task.Store(&task) + } + }) + + failpoint.Inject("pauseTaskAfterRefreshTask", func(val failpoint.Value) { + if val.(bool) && task.State == proto.TaskStateRunning { + _, err := s.taskMgr.PauseTask(s.ctx, task.Key) + if err != nil { + s.logger.Error("pause task failed", zap.Error(err)) + } + task.State = proto.TaskStatePausing + s.task.Store(&task) + } + }) + + switch task.State { + case proto.TaskStateCancelling: + err = s.onCancelling() + case proto.TaskStatePausing: + err = s.onPausing() + case proto.TaskStatePaused: + err = s.onPaused() + // close the scheduler. + if err == nil { + return + } + case proto.TaskStateResuming: + // Case with 2 nodes. + // Here is the timeline + // 1. task in pausing state. + // 2. node1 and node2 start schedulers with task in pausing state without allocatedSlots. + // 3. node1's scheduler transfer the node from pausing to paused state. + // 4. resume the task. + // 5. node2 scheduler call refreshTask and get task with resuming state. + if !s.allocatedSlots { + s.logger.Info("scheduler exit since not allocated slots", zap.Stringer("state", task.State)) + return + } + err = s.onResuming() + case proto.TaskStateReverting: + err = s.onReverting() + case proto.TaskStatePending: + err = s.onPending() + case proto.TaskStateRunning: + // Case with 2 nodes. + // Here is the timeline + // 1. task in pausing state. + // 2. node1 and node2 start schedulers with task in pausing state without allocatedSlots. + // 3. node1's scheduler transfer the node from pausing to paused state. + // 4. resume the task. + // 5. node1 start another scheduler and transfer the node from resuming to running state. + // 6. node2 scheduler call refreshTask and get task with running state. + if !s.allocatedSlots { + s.logger.Info("scheduler exit since not allocated slots", zap.Stringer("state", task.State)) + return + } + err = s.onRunning() + case proto.TaskStateSucceed, proto.TaskStateReverted, proto.TaskStateFailed: + s.onFinished() + return + } + if err != nil { + s.logger.Info("schedule task meet err, reschedule it", zap.Error(err)) + } + + failpoint.InjectCall("mockOwnerChange") + } + } +} + +// handle task in cancelling state, schedule revert subtasks. +func (s *BaseScheduler) onCancelling() error { + task := s.GetTask() + s.logger.Info("on cancelling state", zap.Stringer("state", task.State), zap.String("step", proto.Step2Str(task.Type, task.Step))) + + return s.revertTask(errors.New(taskCancelMsg)) +} + +// handle task in pausing state, cancel all running subtasks. 
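+// The task stays in pausing until no subtask is left in pending or running state,
+// after which it is moved to paused.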
+func (s *BaseScheduler) onPausing() error { + task := *s.GetTask() + s.logger.Info("on pausing state", zap.Stringer("state", task.State), zap.String("step", proto.Step2Str(task.Type, task.Step))) + cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, task.ID, task.Step) + if err != nil { + s.logger.Warn("check task failed", zap.Error(err)) + return err + } + runningPendingCnt := cntByStates[proto.SubtaskStateRunning] + cntByStates[proto.SubtaskStatePending] + if runningPendingCnt > 0 { + s.logger.Debug("on pausing state, this task keeps current state", zap.Stringer("state", task.State)) + return nil + } + + s.logger.Info("all running subtasks paused, update the task to paused state") + if err = s.taskMgr.PausedTask(s.ctx, task.ID); err != nil { + return err + } + task.State = proto.TaskStatePaused + s.task.Store(&task) + return nil +} + +// handle task in paused state. +func (s *BaseScheduler) onPaused() error { + task := s.GetTask() + s.logger.Info("on paused state", zap.Stringer("state", task.State), zap.String("step", proto.Step2Str(task.Type, task.Step))) + failpoint.InjectCall("mockDMLExecutionOnPausedState") + return nil +} + +// handle task in resuming state. +func (s *BaseScheduler) onResuming() error { + task := *s.GetTask() + s.logger.Info("on resuming state", zap.Stringer("state", task.State), zap.String("step", proto.Step2Str(task.Type, task.Step))) + cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, task.ID, task.Step) + if err != nil { + s.logger.Warn("check task failed", zap.Error(err)) + return err + } + if cntByStates[proto.SubtaskStatePaused] == 0 { + // Finish the resuming process. + s.logger.Info("all paused tasks converted to pending state, update the task to running state") + if err = s.taskMgr.ResumedTask(s.ctx, task.ID); err != nil { + return err + } + task.State = proto.TaskStateRunning + s.task.Store(&task) + return nil + } + + return s.taskMgr.ResumeSubtasks(s.ctx, task.ID) +} + +// handle task in reverting state, check all revert subtasks finishes. +func (s *BaseScheduler) onReverting() error { + task := *s.GetTask() + s.logger.Debug("on reverting state", zap.Stringer("state", task.State), zap.String("step", proto.Step2Str(task.Type, task.Step))) + cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, task.ID, task.Step) + if err != nil { + s.logger.Warn("check task failed", zap.Error(err)) + return err + } + runnableSubtaskCnt := cntByStates[proto.SubtaskStatePending] + cntByStates[proto.SubtaskStateRunning] + if runnableSubtaskCnt == 0 { + if err = s.OnDone(s.ctx, s, &task); err != nil { + return errors.Trace(err) + } + if err = s.taskMgr.RevertedTask(s.ctx, task.ID); err != nil { + return errors.Trace(err) + } + task.State = proto.TaskStateReverted + s.task.Store(&task) + return nil + } + // Wait all subtasks in this step finishes. + s.OnTick(s.ctx, &task) + s.logger.Debug("on reverting state, this task keeps current state", zap.Stringer("state", task.State)) + return nil +} + +// handle task in pending state, schedule subtasks. +func (s *BaseScheduler) onPending() error { + task := s.GetTask() + s.logger.Debug("on pending state", zap.Stringer("state", task.State), zap.String("step", proto.Step2Str(task.Type, task.Step))) + return s.switch2NextStep() +} + +// handle task in running state, check all running subtasks finishes. +// If subtasks finished, run into the next step. 
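+// A failed or canceled subtask instead moves the whole task to reverting, using
+// the first collected subtask error as the task error.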
+func (s *BaseScheduler) onRunning() error { + task := s.GetTask() + s.logger.Debug("on running state", + zap.Stringer("state", task.State), + zap.String("step", proto.Step2Str(task.Type, task.Step))) + // check current step finishes. + cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, task.ID, task.Step) + if err != nil { + s.logger.Warn("check task failed", zap.Error(err)) + return err + } + if cntByStates[proto.SubtaskStateFailed] > 0 || cntByStates[proto.SubtaskStateCanceled] > 0 { + subTaskErrs, err := s.taskMgr.GetSubtaskErrors(s.ctx, task.ID) + if err != nil { + s.logger.Warn("collect subtask error failed", zap.Error(err)) + return err + } + if len(subTaskErrs) > 0 { + s.logger.Warn("subtasks encounter errors", zap.Errors("subtask-errs", subTaskErrs)) + // we only store the first error as task error. + return s.revertTask(subTaskErrs[0]) + } + } else if s.isStepSucceed(cntByStates) { + return s.switch2NextStep() + } + + // Wait all subtasks in this step finishes. + s.OnTick(s.ctx, task) + s.logger.Debug("on running state, this task keeps current state", zap.Stringer("state", task.State)) + return nil +} + +func (s *BaseScheduler) onFinished() { + task := s.GetTask() + metrics.UpdateMetricsForFinishTask(task) + s.logger.Debug("schedule task, task is finished", zap.Stringer("state", task.State)) +} + +func (s *BaseScheduler) switch2NextStep() error { + task := *s.GetTask() + nextStep := s.GetNextStep(&task.TaskBase) + s.logger.Info("switch to next step", + zap.String("current-step", proto.Step2Str(task.Type, task.Step)), + zap.String("next-step", proto.Step2Str(task.Type, nextStep))) + + if nextStep == proto.StepDone { + if err := s.OnDone(s.ctx, s, &task); err != nil { + return errors.Trace(err) + } + if err := s.taskMgr.SucceedTask(s.ctx, task.ID); err != nil { + return errors.Trace(err) + } + task.Step = nextStep + task.State = proto.TaskStateSucceed + s.task.Store(&task) + return nil + } + + nodes := s.nodeMgr.getNodes() + nodeIDs := filterByScope(nodes, task.TargetScope) + eligibleNodes, err := getEligibleNodes(s.ctx, s, nodeIDs) + if err != nil { + return err + } + + s.logger.Info("eligible instances", zap.Int("num", len(eligibleNodes))) + if len(eligibleNodes) == 0 { + return errors.New("no available TiDB node to dispatch subtasks") + } + + metas, err := s.OnNextSubtasksBatch(s.ctx, s, &task, eligibleNodes, nextStep) + if err != nil { + s.logger.Warn("generate part of subtasks failed", zap.Error(err)) + return s.handlePlanErr(err) + } + + if err = s.scheduleSubTask(&task, nextStep, metas, eligibleNodes); err != nil { + return err + } + task.Step = nextStep + task.State = proto.TaskStateRunning + // and OnNextSubtasksBatch might change meta of task. + s.task.Store(&task) + return nil +} + +func (s *BaseScheduler) scheduleSubTask( + task *proto.Task, + subtaskStep proto.Step, + metas [][]byte, + eligibleNodes []string) error { + s.logger.Info("schedule subtasks", + zap.Stringer("state", task.State), + zap.String("step", proto.Step2Str(task.Type, subtaskStep)), + zap.Int("concurrency", task.Concurrency), + zap.Int("subtasks", len(metas))) + + // the scheduled node of the subtask might not be optimal, as we run all + // scheduler in parallel, and update might be called too many times when + // multiple tasks are switching to next step. + // balancer will assign the subtasks to the right instance according to + // the system load of all nodes. 
+ if err := s.slotMgr.update(s.ctx, s.nodeMgr, s.taskMgr); err != nil { + return err + } + adjustedEligibleNodes := s.slotMgr.adjustEligibleNodes(eligibleNodes, task.Concurrency) + var size uint64 + subTasks := make([]*proto.Subtask, 0, len(metas)) + for i, meta := range metas { + // we assign the subtask to the instance in a round-robin way. + pos := i % len(adjustedEligibleNodes) + instanceID := adjustedEligibleNodes[pos] + s.logger.Debug("create subtasks", zap.String("instanceID", instanceID)) + subTasks = append(subTasks, proto.NewSubtask( + subtaskStep, task.ID, task.Type, instanceID, task.Concurrency, meta, i+1)) + + size += uint64(len(meta)) + } + failpoint.Inject("cancelBeforeUpdateTask", func() { + _ = s.taskMgr.CancelTask(s.ctx, task.ID) + }) + + // as other fields and generated key and index KV takes space too, we limit + // the size of subtasks to 80% of the transaction limit. + limit := max(uint64(float64(kv.TxnTotalSizeLimit.Load())*0.8), 1) + fn := s.taskMgr.SwitchTaskStep + if size >= limit { + // On default, transaction size limit is controlled by tidb_mem_quota_query + // which is 1G on default, so it's unlikely to reach this limit, but in + // case user set txn-total-size-limit explicitly, we insert in batch. + s.logger.Info("subtasks size exceed limit, will insert in batch", + zap.Uint64("size", size), zap.Uint64("limit", limit)) + fn = s.taskMgr.SwitchTaskStepInBatch + } + + backoffer := backoff.NewExponential(RetrySQLInterval, 2, RetrySQLMaxInterval) + return handle.RunWithRetry(s.ctx, RetrySQLTimes, backoffer, s.logger, + func(context.Context) (bool, error) { + err := fn(s.ctx, task, proto.TaskStateRunning, subtaskStep, subTasks) + if errors.Cause(err) == storage.ErrUnstableSubtasks { + return false, err + } + return true, err + }, + ) +} + +func (s *BaseScheduler) handlePlanErr(err error) error { + task := *s.GetTask() + s.logger.Warn("generate plan failed", zap.Error(err), zap.Stringer("state", task.State)) + if s.IsRetryableErr(err) { + return err + } + return s.revertTask(err) +} + +func (s *BaseScheduler) revertTask(taskErr error) error { + task := *s.GetTask() + if err := s.taskMgr.RevertTask(s.ctx, task.ID, task.State, taskErr); err != nil { + return err + } + task.State = proto.TaskStateReverting + task.Error = taskErr + s.task.Store(&task) + return nil +} + +// MockServerInfo exported for scheduler_test.go +var MockServerInfo atomic.Pointer[[]string] + +// GetLiveExecIDs returns all live executor node IDs. 
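// --- Illustrative sketch, not part of the patch -----------------------------
// The retry helper used by scheduleSubTask above, shown in isolation.
// backoff.NewExponential and handle.RunWithRetry are the same helpers this
// patch calls; the flaky operation, intervals and logger wiring here are made
// up. Judging from the call site above, the boolean returned by the callback
// tells RunWithRetry whether the error is retryable.
package main

import (
	"context"
	"errors"
	"time"

	"github.com/pingcap/tidb/pkg/disttask/framework/handle"
	"github.com/pingcap/tidb/pkg/util/backoff"
	"go.uber.org/zap"
)

func main() {
	logger := zap.NewExample()
	// retry up to 5 times, sleeping 500ms, 1s, 2s, ... capped at 10s between attempts.
	backoffer := backoff.NewExponential(500*time.Millisecond, 2, 10*time.Second)
	attempts := 0
	err := handle.RunWithRetry(context.Background(), 5, backoffer, logger,
		func(context.Context) (bool, error) {
			attempts++
			if attempts < 3 {
				return true, errors.New("transient error") // retryable, try again
			}
			return true, nil // success on the third attempt
		})
	logger.Info("done", zap.Int("attempts", attempts), zap.Error(err))
}
// -----------------------------------------------------------------------------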
+func GetLiveExecIDs(ctx context.Context) ([]string, error) { + failpoint.Inject("mockTaskExecutorNodes", func() { + failpoint.Return(*MockServerInfo.Load(), nil) + }) + serverInfos, err := generateTaskExecutorNodes(ctx) + if err != nil { + return nil, err + } + execIDs := make([]string, 0, len(serverInfos)) + for _, info := range serverInfos { + execIDs = append(execIDs, disttaskutil.GenerateExecID(info)) + } + return execIDs, nil +} + +func generateTaskExecutorNodes(ctx context.Context) (serverNodes []*infosync.ServerInfo, err error) { + var serverInfos map[string]*infosync.ServerInfo + _, etcd := ctx.Value("etcd").(bool) + if intest.InTest && !etcd { + serverInfos = infosync.MockGlobalServerInfoManagerEntry.GetAllServerInfo() + } else { + serverInfos, err = infosync.GetAllServerInfo(ctx) + } + if err != nil { + return nil, err + } + if len(serverInfos) == 0 { + return nil, errors.New("not found instance") + } + + serverNodes = make([]*infosync.ServerInfo, 0, len(serverInfos)) + for _, serverInfo := range serverInfos { + serverNodes = append(serverNodes, serverInfo) + } + return serverNodes, nil +} + +// GetPreviousSubtaskMetas get subtask metas from specific step. +func (s *BaseScheduler) GetPreviousSubtaskMetas(taskID int64, step proto.Step) ([][]byte, error) { + previousSubtasks, err := s.taskMgr.GetAllSubtasksByStepAndState(s.ctx, taskID, step, proto.SubtaskStateSucceed) + if err != nil { + s.logger.Warn("get previous succeed subtask failed", zap.String("step", proto.Step2Str(s.GetTask().Type, step))) + return nil, err + } + previousSubtaskMetas := make([][]byte, 0, len(previousSubtasks)) + for _, subtask := range previousSubtasks { + previousSubtaskMetas = append(previousSubtaskMetas, subtask.Meta) + } + return previousSubtaskMetas, nil +} + +// WithNewSession executes the function with a new session. +func (s *BaseScheduler) WithNewSession(fn func(se sessionctx.Context) error) error { + return s.taskMgr.WithNewSession(fn) +} + +// WithNewTxn executes the fn in a new transaction. +func (s *BaseScheduler) WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error { + return s.taskMgr.WithNewTxn(ctx, fn) +} + +func (*BaseScheduler) isStepSucceed(cntByStates map[proto.SubtaskState]int64) bool { + _, ok := cntByStates[proto.SubtaskStateSucceed] + return len(cntByStates) == 0 || (len(cntByStates) == 1 && ok) +} + +// IsCancelledErr checks if the error is a cancelled error. +func IsCancelledErr(err error) bool { + return strings.Contains(err.Error(), taskCancelMsg) +} + +// getEligibleNodes returns the eligible(live) nodes for the task. +// if the task can only be scheduled to some specific nodes, return them directly, +// we don't care liveliness of them. 
+func getEligibleNodes(ctx context.Context, sch Scheduler, managedNodes []string) ([]string, error) { + serverNodes, err := sch.GetEligibleInstances(ctx, sch.GetTask()) + if err != nil { + return nil, err + } + logutil.BgLogger().Debug("eligible instances", zap.Int("num", len(serverNodes))) + if len(serverNodes) == 0 { + serverNodes = managedNodes + } + + return serverNodes, nil +} diff --git a/pkg/disttask/framework/scheduler/scheduler_manager.go b/pkg/disttask/framework/scheduler/scheduler_manager.go index 186fc77702a7d..8e7b4933fbe61 100644 --- a/pkg/disttask/framework/scheduler/scheduler_manager.go +++ b/pkg/disttask/framework/scheduler/scheduler_manager.go @@ -301,13 +301,13 @@ func (sm *Manager) failTask(id int64, currState proto.TaskState, err error) { func (sm *Manager) gcSubtaskHistoryTableLoop() { historySubtaskTableGcInterval := defaultHistorySubtaskTableGcInterval - failpoint.Inject("historySubtaskTableGcInterval", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("historySubtaskTableGcInterval")); _err_ == nil { if seconds, ok := val.(int); ok { historySubtaskTableGcInterval = time.Second * time.Duration(seconds) } <-WaitTaskFinished - }) + } sm.logger.Info("subtask table gc loop start") ticker := time.NewTicker(historySubtaskTableGcInterval) @@ -413,9 +413,9 @@ func (sm *Manager) doCleanupTask() { sm.logger.Warn("cleanup routine failed", zap.Error(err)) return } - failpoint.Inject("WaitCleanUpFinished", func() { + if _, _err_ := failpoint.Eval(_curpkg_("WaitCleanUpFinished")); _err_ == nil { WaitCleanUpFinished <- struct{}{} - }) + } sm.logger.Info("cleanup routine success") } @@ -442,9 +442,9 @@ func (sm *Manager) cleanupFinishedTasks(tasks []*proto.Task) error { sm.logger.Warn("cleanup routine failed", zap.Error(errors.Trace(firstErr))) } - failpoint.Inject("mockTransferErr", func() { - failpoint.Return(errors.New("transfer err")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("mockTransferErr")); _err_ == nil { + return errors.New("transfer err") + } return sm.taskMgr.TransferTasks2History(sm.ctx, cleanedTasks) } diff --git a/pkg/disttask/framework/scheduler/scheduler_manager.go__failpoint_stash__ b/pkg/disttask/framework/scheduler/scheduler_manager.go__failpoint_stash__ new file mode 100644 index 0000000000000..186fc77702a7d --- /dev/null +++ b/pkg/disttask/framework/scheduler/scheduler_manager.go__failpoint_stash__ @@ -0,0 +1,485 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package scheduler + +import ( + "context" + "slices" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + "github.com/pingcap/tidb/pkg/disttask/framework/handle" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/metrics" + tidbutil "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/syncutil" + "go.uber.org/zap" +) + +var ( + // CheckTaskRunningInterval is the interval for loading tasks. + // It is exported for testing. + CheckTaskRunningInterval = 3 * time.Second + // defaultHistorySubtaskTableGcInterval is the interval of gc history subtask table. + defaultHistorySubtaskTableGcInterval = 24 * time.Hour + // DefaultCleanUpInterval is the interval of cleanup routine. + DefaultCleanUpInterval = 10 * time.Minute + defaultCollectMetricsInterval = 5 * time.Second +) + +// WaitTaskFinished is used to sync the test. +var WaitTaskFinished = make(chan struct{}) + +func (sm *Manager) getSchedulerCount() int { + sm.mu.RLock() + defer sm.mu.RUnlock() + return len(sm.mu.schedulerMap) +} + +func (sm *Manager) addScheduler(taskID int64, scheduler Scheduler) { + sm.mu.Lock() + defer sm.mu.Unlock() + sm.mu.schedulerMap[taskID] = scheduler + sm.mu.schedulers = append(sm.mu.schedulers, scheduler) + slices.SortFunc(sm.mu.schedulers, func(i, j Scheduler) int { + return i.GetTask().CompareTask(j.GetTask()) + }) +} + +func (sm *Manager) hasScheduler(taskID int64) bool { + sm.mu.Lock() + defer sm.mu.Unlock() + _, ok := sm.mu.schedulerMap[taskID] + return ok +} + +func (sm *Manager) delScheduler(taskID int64) { + sm.mu.Lock() + defer sm.mu.Unlock() + delete(sm.mu.schedulerMap, taskID) + for i, scheduler := range sm.mu.schedulers { + if scheduler.GetTask().ID == taskID { + sm.mu.schedulers = append(sm.mu.schedulers[:i], sm.mu.schedulers[i+1:]...) + break + } + } +} + +func (sm *Manager) clearSchedulers() { + sm.mu.Lock() + defer sm.mu.Unlock() + sm.mu.schedulerMap = make(map[int64]Scheduler) + sm.mu.schedulers = sm.mu.schedulers[:0] +} + +// getSchedulers returns a copy of schedulers. +func (sm *Manager) getSchedulers() []Scheduler { + sm.mu.RLock() + defer sm.mu.RUnlock() + res := make([]Scheduler, len(sm.mu.schedulers)) + copy(res, sm.mu.schedulers) + return res +} + +// Manager manage a bunch of schedulers. +// Scheduler schedule and monitor tasks. +// The scheduling task number is limited by size of gPool. +type Manager struct { + ctx context.Context + cancel context.CancelFunc + taskMgr TaskManager + wg tidbutil.WaitGroupWrapper + schedulerWG tidbutil.WaitGroupWrapper + slotMgr *SlotManager + nodeMgr *NodeManager + balancer *balancer + initialized bool + // serverID, it's value is ip:port now. + serverID string + logger *zap.Logger + + finishCh chan struct{} + + mu struct { + syncutil.RWMutex + schedulerMap map[int64]Scheduler + // in task order + schedulers []Scheduler + } +} + +// NewManager creates a scheduler struct. 
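// --- Illustrative sketch, not part of the patch -----------------------------
// The typical lifecycle of the scheduler Manager defined in this file.
// storage.GetTaskManager, NewManager, Start and Stop are the APIs shown in the
// patch; the server ID value and the timeout are made up, and it is assumed
// that *storage.TaskManager satisfies the scheduler.TaskManager interface, as
// it does elsewhere in the framework.
package main

import (
	"context"
	"time"

	"github.com/pingcap/tidb/pkg/disttask/framework/scheduler"
	"github.com/pingcap/tidb/pkg/disttask/framework/storage"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()
	taskMgr, err := storage.GetTaskManager()
	if err != nil {
		panic(err) // the task manager must be initialized by the TiDB domain first
	}
	mgr := scheduler.NewManager(ctx, taskMgr, "127.0.0.1:4000")
	mgr.Start() // starts the schedule/gc/cleanup/collect/balance loops
	defer mgr.Stop()
	<-ctx.Done()
}
// -----------------------------------------------------------------------------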
+func NewManager(ctx context.Context, taskMgr TaskManager, serverID string) *Manager { + logger := log.L() + if intest.InTest { + logger = log.L().With(zap.String("server-id", serverID)) + } + subCtx, cancel := context.WithCancel(ctx) + slotMgr := newSlotManager() + nodeMgr := newNodeManager(serverID) + schedulerManager := &Manager{ + ctx: subCtx, + cancel: cancel, + taskMgr: taskMgr, + serverID: serverID, + slotMgr: slotMgr, + nodeMgr: nodeMgr, + balancer: newBalancer(Param{ + taskMgr: taskMgr, + nodeMgr: nodeMgr, + slotMgr: slotMgr, + serverID: serverID, + }), + logger: logger, + finishCh: make(chan struct{}, proto.MaxConcurrentTask), + } + schedulerManager.mu.schedulerMap = make(map[int64]Scheduler) + + return schedulerManager +} + +// Start the schedulerManager, start the scheduleTaskLoop to start multiple schedulers. +func (sm *Manager) Start() { + // init cached managed nodes + sm.nodeMgr.refreshNodes(sm.ctx, sm.taskMgr, sm.slotMgr) + + sm.wg.Run(sm.scheduleTaskLoop) + sm.wg.Run(sm.gcSubtaskHistoryTableLoop) + sm.wg.Run(sm.cleanupTaskLoop) + sm.wg.Run(sm.collectLoop) + sm.wg.Run(func() { + sm.nodeMgr.maintainLiveNodesLoop(sm.ctx, sm.taskMgr) + }) + sm.wg.Run(func() { + sm.nodeMgr.refreshNodesLoop(sm.ctx, sm.taskMgr, sm.slotMgr) + }) + sm.wg.Run(func() { + sm.balancer.balanceLoop(sm.ctx, sm) + }) + sm.initialized = true +} + +// Cancel cancels the scheduler manager. +// used in test to simulate tidb node shutdown. +func (sm *Manager) Cancel() { + sm.cancel() +} + +// Stop the schedulerManager. +func (sm *Manager) Stop() { + sm.cancel() + sm.schedulerWG.Wait() + sm.wg.Wait() + sm.clearSchedulers() + sm.initialized = false + close(sm.finishCh) +} + +// Initialized check the manager initialized. +func (sm *Manager) Initialized() bool { + return sm.initialized +} + +// scheduleTaskLoop schedules the tasks. +func (sm *Manager) scheduleTaskLoop() { + sm.logger.Info("schedule task loop start") + ticker := time.NewTicker(CheckTaskRunningInterval) + defer ticker.Stop() + for { + select { + case <-sm.ctx.Done(): + sm.logger.Info("schedule task loop exits") + return + case <-ticker.C: + case <-handle.TaskChangedCh: + } + + taskCnt := sm.getSchedulerCount() + if taskCnt >= proto.MaxConcurrentTask { + sm.logger.Debug("scheduled tasks reached limit", + zap.Int("current", taskCnt), zap.Int("max", proto.MaxConcurrentTask)) + continue + } + + schedulableTasks, err := sm.getSchedulableTasks() + if err != nil { + continue + } + + err = sm.startSchedulers(schedulableTasks) + if err != nil { + continue + } + } +} + +func (sm *Manager) getSchedulableTasks() ([]*proto.TaskBase, error) { + tasks, err := sm.taskMgr.GetTopUnfinishedTasks(sm.ctx) + if err != nil { + sm.logger.Warn("get unfinished tasks failed", zap.Error(err)) + return nil, err + } + + schedulableTasks := make([]*proto.TaskBase, 0, len(tasks)) + for _, task := range tasks { + if sm.hasScheduler(task.ID) { + continue + } + // we check it before start scheduler, so no need to check it again. + // see startScheduler. + // this should not happen normally, unless user modify system table + // directly. 
+ if getSchedulerFactory(task.Type) == nil { + sm.logger.Warn("unknown task type", zap.Int64("task-id", task.ID), + zap.Stringer("task-type", task.Type)) + sm.failTask(task.ID, task.State, errors.New("unknown task type")) + continue + } + schedulableTasks = append(schedulableTasks, task) + } + return schedulableTasks, nil +} + +func (sm *Manager) startSchedulers(schedulableTasks []*proto.TaskBase) error { + if len(schedulableTasks) == 0 { + return nil + } + if err := sm.slotMgr.update(sm.ctx, sm.nodeMgr, sm.taskMgr); err != nil { + sm.logger.Warn("update used slot failed", zap.Error(err)) + return err + } + for _, task := range schedulableTasks { + taskCnt := sm.getSchedulerCount() + if taskCnt >= proto.MaxConcurrentTask { + break + } + var reservedExecID string + allocateSlots := true + var ok bool + switch task.State { + case proto.TaskStatePending, proto.TaskStateRunning, proto.TaskStateResuming: + reservedExecID, ok = sm.slotMgr.canReserve(task) + if !ok { + // task of lower rank might be able to be scheduled. + continue + } + // reverting/cancelling/pausing + default: + allocateSlots = false + sm.logger.Info("start scheduler without allocating slots", + zap.Int64("task-id", task.ID), zap.Stringer("state", task.State)) + } + + metrics.DistTaskGauge.WithLabelValues(task.Type.String(), metrics.SchedulingStatus).Inc() + metrics.UpdateMetricsForScheduleTask(task.ID, task.Type) + sm.startScheduler(task, allocateSlots, reservedExecID) + } + return nil +} + +func (sm *Manager) failTask(id int64, currState proto.TaskState, err error) { + if err2 := sm.taskMgr.FailTask(sm.ctx, id, currState, err); err2 != nil { + sm.logger.Warn("failed to update task state to failed", + zap.Int64("task-id", id), zap.Error(err2)) + } +} + +func (sm *Manager) gcSubtaskHistoryTableLoop() { + historySubtaskTableGcInterval := defaultHistorySubtaskTableGcInterval + failpoint.Inject("historySubtaskTableGcInterval", func(val failpoint.Value) { + if seconds, ok := val.(int); ok { + historySubtaskTableGcInterval = time.Second * time.Duration(seconds) + } + + <-WaitTaskFinished + }) + + sm.logger.Info("subtask table gc loop start") + ticker := time.NewTicker(historySubtaskTableGcInterval) + defer ticker.Stop() + for { + select { + case <-sm.ctx.Done(): + sm.logger.Info("subtask history table gc loop exits") + return + case <-ticker.C: + err := sm.taskMgr.GCSubtasks(sm.ctx) + if err != nil { + sm.logger.Warn("subtask history table gc failed", zap.Error(err)) + } else { + sm.logger.Info("subtask history table gc success") + } + } + } +} + +func (sm *Manager) startScheduler(basicTask *proto.TaskBase, allocateSlots bool, reservedExecID string) { + task, err := sm.taskMgr.GetTaskByID(sm.ctx, basicTask.ID) + if err != nil { + sm.logger.Error("get task failed", zap.Int64("task-id", basicTask.ID), zap.Error(err)) + return + } + + schedulerFactory := getSchedulerFactory(task.Type) + scheduler := schedulerFactory(sm.ctx, task, Param{ + taskMgr: sm.taskMgr, + nodeMgr: sm.nodeMgr, + slotMgr: sm.slotMgr, + serverID: sm.serverID, + allocatedSlots: allocateSlots, + }) + if err = scheduler.Init(); err != nil { + sm.logger.Error("init scheduler failed", zap.Error(err)) + sm.failTask(task.ID, task.State, err) + return + } + sm.addScheduler(task.ID, scheduler) + if allocateSlots { + sm.slotMgr.reserve(basicTask, reservedExecID) + } + sm.logger.Info("task scheduler started", zap.Int64("task-id", task.ID)) + sm.schedulerWG.RunWithLog(func() { + defer func() { + scheduler.Close() + sm.delScheduler(task.ID) + if allocateSlots { + 
sm.slotMgr.unReserve(basicTask, reservedExecID) + } + handle.NotifyTaskChange() + sm.logger.Info("task scheduler exit", zap.Int64("task-id", task.ID)) + }() + metrics.UpdateMetricsForRunTask(task) + scheduler.ScheduleTask() + sm.finishCh <- struct{}{} + }) +} + +func (sm *Manager) cleanupTaskLoop() { + sm.logger.Info("cleanup loop start") + ticker := time.NewTicker(DefaultCleanUpInterval) + defer ticker.Stop() + for { + select { + case <-sm.ctx.Done(): + sm.logger.Info("cleanup loop exits") + return + case <-sm.finishCh: + sm.doCleanupTask() + case <-ticker.C: + sm.doCleanupTask() + } + } +} + +// WaitCleanUpFinished is used to sync the test. +var WaitCleanUpFinished = make(chan struct{}, 1) + +// doCleanupTask processes clean up routine defined by each type of tasks and cleanupMeta. +// For example: +// +// tasks with global sort should clean up tmp files stored on S3. +func (sm *Manager) doCleanupTask() { + tasks, err := sm.taskMgr.GetTasksInStates( + sm.ctx, + proto.TaskStateFailed, + proto.TaskStateReverted, + proto.TaskStateSucceed, + ) + if err != nil { + sm.logger.Warn("get task in states failed", zap.Error(err)) + return + } + if len(tasks) == 0 { + return + } + sm.logger.Info("cleanup routine start") + err = sm.cleanupFinishedTasks(tasks) + if err != nil { + sm.logger.Warn("cleanup routine failed", zap.Error(err)) + return + } + failpoint.Inject("WaitCleanUpFinished", func() { + WaitCleanUpFinished <- struct{}{} + }) + sm.logger.Info("cleanup routine success") +} + +func (sm *Manager) cleanupFinishedTasks(tasks []*proto.Task) error { + cleanedTasks := make([]*proto.Task, 0) + var firstErr error + for _, task := range tasks { + sm.logger.Info("cleanup task", zap.Int64("task-id", task.ID)) + cleanupFactory := getSchedulerCleanUpFactory(task.Type) + if cleanupFactory != nil { + cleanup := cleanupFactory() + err := cleanup.CleanUp(sm.ctx, task) + if err != nil { + firstErr = err + break + } + cleanedTasks = append(cleanedTasks, task) + } else { + // if task doesn't register cleanup function, mark it as cleaned. + cleanedTasks = append(cleanedTasks, task) + } + } + if firstErr != nil { + sm.logger.Warn("cleanup routine failed", zap.Error(errors.Trace(firstErr))) + } + + failpoint.Inject("mockTransferErr", func() { + failpoint.Return(errors.New("transfer err")) + }) + + return sm.taskMgr.TransferTasks2History(sm.ctx, cleanedTasks) +} + +func (sm *Manager) collectLoop() { + sm.logger.Info("collect loop start") + ticker := time.NewTicker(defaultCollectMetricsInterval) + defer ticker.Stop() + for { + select { + case <-sm.ctx.Done(): + sm.logger.Info("collect loop exits") + return + case <-ticker.C: + sm.collect() + } + } +} + +func (sm *Manager) collect() { + subtasks, err := sm.taskMgr.GetAllSubtasks(sm.ctx) + if err != nil { + sm.logger.Warn("get all subtasks failed", zap.Error(err)) + return + } + + subtaskCollector.subtaskInfo.Store(&subtasks) +} + +// MockScheduler mock one scheduler for one task, only used for tests. 
+func (sm *Manager) MockScheduler(task *proto.Task) *BaseScheduler { + return NewBaseScheduler(sm.ctx, task, Param{ + taskMgr: sm.taskMgr, + nodeMgr: sm.nodeMgr, + slotMgr: sm.slotMgr, + serverID: sm.serverID, + }) +} diff --git a/pkg/disttask/framework/storage/binding__failpoint_binding__.go b/pkg/disttask/framework/storage/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..a1a747a15d57f --- /dev/null +++ b/pkg/disttask/framework/storage/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package storage + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/disttask/framework/storage/history.go b/pkg/disttask/framework/storage/history.go index 9129a002da60e..de6da56d7b2c7 100644 --- a/pkg/disttask/framework/storage/history.go +++ b/pkg/disttask/framework/storage/history.go @@ -83,11 +83,11 @@ func (mgr *TaskManager) TransferTasks2History(ctx context.Context, tasks []*prot // GCSubtasks deletes the history subtask which is older than the given days. func (mgr *TaskManager) GCSubtasks(ctx context.Context) error { subtaskHistoryKeepSeconds := defaultSubtaskKeepDays * 24 * 60 * 60 - failpoint.Inject("subtaskHistoryKeepSeconds", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("subtaskHistoryKeepSeconds")); _err_ == nil { if val, ok := val.(int); ok { subtaskHistoryKeepSeconds = val } - }) + } _, err := mgr.ExecuteSQLWithNewSession( ctx, fmt.Sprintf("DELETE FROM mysql.tidb_background_subtask_history WHERE state_update_time < UNIX_TIMESTAMP() - %d ;", subtaskHistoryKeepSeconds), diff --git a/pkg/disttask/framework/storage/history.go__failpoint_stash__ b/pkg/disttask/framework/storage/history.go__failpoint_stash__ new file mode 100644 index 0000000000000..9129a002da60e --- /dev/null +++ b/pkg/disttask/framework/storage/history.go__failpoint_stash__ @@ -0,0 +1,96 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + "fmt" + "strings" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/util/sqlexec" +) + +// TransferSubtasks2HistoryWithSession transfer the selected subtasks into tidb_background_subtask_history table by taskID. 
+func (*TaskManager) TransferSubtasks2HistoryWithSession(ctx context.Context, se sessionctx.Context, taskID int64) error { + exec := se.GetSQLExecutor() + _, err := sqlexec.ExecSQL(ctx, exec, `insert into mysql.tidb_background_subtask_history select * from mysql.tidb_background_subtask where task_key = %?`, taskID) + if err != nil { + return err + } + // delete taskID subtask + _, err = sqlexec.ExecSQL(ctx, exec, "delete from mysql.tidb_background_subtask where task_key = %?", taskID) + return err +} + +// TransferTasks2History transfer the selected tasks into tidb_global_task_history table by taskIDs. +func (mgr *TaskManager) TransferTasks2History(ctx context.Context, tasks []*proto.Task) error { + if len(tasks) == 0 { + return nil + } + taskIDStrs := make([]string, 0, len(tasks)) + for _, task := range tasks { + taskIDStrs = append(taskIDStrs, fmt.Sprintf("%d", task.ID)) + } + return mgr.WithNewTxn(ctx, func(se sessionctx.Context) error { + // sensitive data in meta might be redacted, need update first. + exec := se.GetSQLExecutor() + for _, t := range tasks { + _, err := sqlexec.ExecSQL(ctx, exec, ` + update mysql.tidb_global_task + set meta= %?, state_update_time = CURRENT_TIMESTAMP() + where id = %?`, t.Meta, t.ID) + if err != nil { + return err + } + } + _, err := sqlexec.ExecSQL(ctx, exec, ` + insert into mysql.tidb_global_task_history + select * from mysql.tidb_global_task + where id in(`+strings.Join(taskIDStrs, `, `)+`)`) + if err != nil { + return err + } + + _, err = sqlexec.ExecSQL(ctx, exec, ` + delete from mysql.tidb_global_task + where id in(`+strings.Join(taskIDStrs, `, `)+`)`) + + for _, t := range tasks { + err = mgr.TransferSubtasks2HistoryWithSession(ctx, se, t.ID) + if err != nil { + return err + } + } + return err + }) +} + +// GCSubtasks deletes the history subtask which is older than the given days. 
+func (mgr *TaskManager) GCSubtasks(ctx context.Context) error { + subtaskHistoryKeepSeconds := defaultSubtaskKeepDays * 24 * 60 * 60 + failpoint.Inject("subtaskHistoryKeepSeconds", func(val failpoint.Value) { + if val, ok := val.(int); ok { + subtaskHistoryKeepSeconds = val + } + }) + _, err := mgr.ExecuteSQLWithNewSession( + ctx, + fmt.Sprintf("DELETE FROM mysql.tidb_background_subtask_history WHERE state_update_time < UNIX_TIMESTAMP() - %d ;", subtaskHistoryKeepSeconds), + ) + return err +} diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index d2c6f77e232c3..c28a42b595cf2 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -227,7 +227,9 @@ func (mgr *TaskManager) CreateTaskWithSession( } taskID = int64(rs[0].GetUint64(0)) - failpoint.Inject("testSetLastTaskID", func() { TestLastTaskID.Store(taskID) }) + if _, _err_ := failpoint.Eval(_curpkg_("testSetLastTaskID")); _err_ == nil { + TestLastTaskID.Store(taskID) + } return taskID, nil } @@ -645,10 +647,10 @@ func (*TaskManager) insertSubtasks(ctx context.Context, se sessionctx.Context, s if len(subtasks) == 0 { return nil } - failpoint.Inject("waitBeforeInsertSubtasks", func() { + if _, _err_ := failpoint.Eval(_curpkg_("waitBeforeInsertSubtasks")); _err_ == nil { <-TestChannel <-TestChannel - }) + } var ( sb strings.Builder markerList = make([]string, 0, len(subtasks)) diff --git a/pkg/disttask/framework/storage/task_table.go__failpoint_stash__ b/pkg/disttask/framework/storage/task_table.go__failpoint_stash__ new file mode 100644 index 0000000000000..d2c6f77e232c3 --- /dev/null +++ b/pkg/disttask/framework/storage/task_table.go__failpoint_stash__ @@ -0,0 +1,809 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + "strconv" + "strings" + "sync/atomic" + + "github.com/docker/go-units" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/sqlexec" + clitutil "github.com/tikv/client-go/v2/util" +) + +const ( + defaultSubtaskKeepDays = 14 + + basicTaskColumns = `t.id, t.task_key, t.type, t.state, t.step, t.priority, t.concurrency, t.create_time, t.target_scope` + // TaskColumns is the columns for task. + // TODO: dispatcher_id will update to scheduler_id later + TaskColumns = basicTaskColumns + `, t.start_time, t.state_update_time, t.meta, t.dispatcher_id, t.error` + // InsertTaskColumns is the columns used in insert task. 
+ InsertTaskColumns = `task_key, type, state, priority, concurrency, step, meta, create_time, target_scope` + basicSubtaskColumns = `id, step, task_key, type, exec_id, state, concurrency, create_time, ordinal, start_time` + // SubtaskColumns is the columns for subtask. + SubtaskColumns = basicSubtaskColumns + `, state_update_time, meta, summary` + // InsertSubtaskColumns is the columns used in insert subtask. + InsertSubtaskColumns = `step, task_key, exec_id, meta, state, type, concurrency, ordinal, create_time, checkpoint, summary` +) + +var ( + maxSubtaskBatchSize = 16 * units.MiB + + // ErrUnstableSubtasks is the error when we detected that the subtasks are + // unstable, i.e. count, order and content of the subtasks are changed on + // different call. + ErrUnstableSubtasks = errors.New("unstable subtasks") + + // ErrTaskNotFound is the error when we can't found task. + // i.e. TransferTasks2History move task from tidb_global_task to tidb_global_task_history. + ErrTaskNotFound = errors.New("task not found") + + // ErrTaskAlreadyExists is the error when we submit a task with the same task key. + // i.e. SubmitTask in handle may submit a task twice. + ErrTaskAlreadyExists = errors.New("task already exists") + + // ErrSubtaskNotFound is the error when can't find subtask by subtask_id and execId, + // i.e. scheduler change the subtask's execId when subtask need to balance to other nodes. + ErrSubtaskNotFound = errors.New("subtask not found") +) + +// TaskExecInfo is the execution information of a task, on some exec node. +type TaskExecInfo struct { + *proto.TaskBase + // SubtaskConcurrency is the concurrency of subtask in current task step. + // TODO: will be used when support subtask have smaller concurrency than task, + // TODO: such as post-process of import-into. + // TODO: we might need create one task executor for each step in this case, to alloc + // TODO: minimal resource + SubtaskConcurrency int +} + +// SessionExecutor defines the interface for executing SQLs in a session. +type SessionExecutor interface { + // WithNewSession executes the function with a new session. + WithNewSession(fn func(se sessionctx.Context) error) error + // WithNewTxn executes the fn in a new transaction. + WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error +} + +// TaskHandle provides the interface for operations needed by Scheduler. +// Then we can use scheduler's function in Scheduler interface. +type TaskHandle interface { + // GetPreviousSubtaskMetas gets previous subtask metas. + GetPreviousSubtaskMetas(taskID int64, step proto.Step) ([][]byte, error) + SessionExecutor +} + +// TaskManager is the manager of task and subtask. +type TaskManager struct { + sePool util.SessionPool +} + +var _ SessionExecutor = &TaskManager{} + +var taskManagerInstance atomic.Pointer[TaskManager] + +var ( + // TestLastTaskID is used for test to set the last task ID. + TestLastTaskID atomic.Int64 +) + +// NewTaskManager creates a new task manager. +func NewTaskManager(sePool util.SessionPool) *TaskManager { + return &TaskManager{ + sePool: sePool, + } +} + +// GetTaskManager gets the task manager. +func GetTaskManager() (*TaskManager, error) { + v := taskManagerInstance.Load() + if v == nil { + return nil, errors.New("task manager is not initialized") + } + return v, nil +} + +// SetTaskManager sets the task manager. +func SetTaskManager(is *TaskManager) { + taskManagerInstance.Store(is) +} + +// WithNewSession executes the function with a new session. 
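// --- Illustrative sketch, not part of the patch -----------------------------
// How callers typically use the session helpers of TaskManager shown below.
// WithNewTxn, sqlexec.ExecSQL and se.GetSQLExecutor() are the same helpers this
// file uses; the SQL text, the literal "pending" state and CountPendingSubtasks
// are made up for illustration.
package example

import (
	"context"

	"github.com/pingcap/tidb/pkg/disttask/framework/storage"
	"github.com/pingcap/tidb/pkg/sessionctx"
	"github.com/pingcap/tidb/pkg/util/sqlexec"
)

// CountPendingSubtasks counts a task's pending subtasks inside one transaction.
func CountPendingSubtasks(ctx context.Context, mgr *storage.TaskManager, taskID int64) (int64, error) {
	var cnt int64
	// WithNewTxn wraps the callback in begin/commit and rolls back on error.
	err := mgr.WithNewTxn(ctx, func(se sessionctx.Context) error {
		rs, err := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(),
			"select count(1) from mysql.tidb_background_subtask where task_key = %? and state = %?",
			taskID, "pending")
		if err != nil {
			return err
		}
		cnt = rs[0].GetInt64(0)
		return nil
	})
	return cnt, err
}
// -----------------------------------------------------------------------------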
+func (mgr *TaskManager) WithNewSession(fn func(se sessionctx.Context) error) error { + se, err := mgr.sePool.Get() + if err != nil { + return err + } + defer mgr.sePool.Put(se) + return fn(se.(sessionctx.Context)) +} + +// WithNewTxn executes the fn in a new transaction. +func (mgr *TaskManager) WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error { + ctx = clitutil.WithInternalSourceType(ctx, kv.InternalDistTask) + return mgr.WithNewSession(func(se sessionctx.Context) (err error) { + _, err = sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), "begin") + if err != nil { + return err + } + + success := false + defer func() { + sql := "rollback" + if success { + sql = "commit" + } + _, commitErr := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), sql) + if err == nil && commitErr != nil { + err = commitErr + } + }() + + if err = fn(se); err != nil { + return err + } + + success = true + return nil + }) +} + +// ExecuteSQLWithNewSession executes one SQL with new session. +func (mgr *TaskManager) ExecuteSQLWithNewSession(ctx context.Context, sql string, args ...any) (rs []chunk.Row, err error) { + err = mgr.WithNewSession(func(se sessionctx.Context) error { + rs, err = sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), sql, args...) + return err + }) + + if err != nil { + return nil, err + } + + return +} + +// CreateTask adds a new task to task table. +func (mgr *TaskManager) CreateTask(ctx context.Context, key string, tp proto.TaskType, concurrency int, targetScope string, meta []byte) (taskID int64, err error) { + err = mgr.WithNewSession(func(se sessionctx.Context) error { + var err2 error + taskID, err2 = mgr.CreateTaskWithSession(ctx, se, key, tp, concurrency, targetScope, meta) + return err2 + }) + return +} + +// CreateTaskWithSession adds a new task to task table with session. +func (mgr *TaskManager) CreateTaskWithSession( + ctx context.Context, + se sessionctx.Context, + key string, + tp proto.TaskType, + concurrency int, + targetScope string, + meta []byte, +) (taskID int64, err error) { + cpuCount, err := mgr.getCPUCountOfNode(ctx, se) + if err != nil { + return 0, err + } + if concurrency > cpuCount { + return 0, errors.Errorf("task concurrency(%d) larger than cpu count(%d) of managed node", concurrency, cpuCount) + } + _, err = sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), ` + insert into mysql.tidb_global_task(`+InsertTaskColumns+`) + values (%?, %?, %?, %?, %?, %?, %?, CURRENT_TIMESTAMP(), %?)`, + key, tp, proto.TaskStatePending, proto.NormalPriority, concurrency, proto.StepInit, meta, targetScope) + if err != nil { + return 0, err + } + + rs, err := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), "select @@last_insert_id") + if err != nil { + return 0, err + } + + taskID = int64(rs[0].GetUint64(0)) + failpoint.Inject("testSetLastTaskID", func() { TestLastTaskID.Store(taskID) }) + + return taskID, nil +} + +// GetTopUnfinishedTasks implements the scheduler.TaskManager interface. +func (mgr *TaskManager) GetTopUnfinishedTasks(ctx context.Context) ([]*proto.TaskBase, error) { + rs, err := mgr.ExecuteSQLWithNewSession(ctx, + `select `+basicTaskColumns+` from mysql.tidb_global_task t + where state in (%?, %?, %?, %?, %?, %?) 
+ order by priority asc, create_time asc, id asc + limit %?`, + proto.TaskStatePending, + proto.TaskStateRunning, + proto.TaskStateReverting, + proto.TaskStateCancelling, + proto.TaskStatePausing, + proto.TaskStateResuming, + proto.MaxConcurrentTask*2, + ) + if err != nil { + return nil, err + } + + tasks := make([]*proto.TaskBase, 0, len(rs)) + for _, r := range rs { + tasks = append(tasks, row2TaskBasic(r)) + } + return tasks, nil +} + +// GetTaskExecInfoByExecID implements the scheduler.TaskManager interface. +func (mgr *TaskManager) GetTaskExecInfoByExecID(ctx context.Context, execID string) ([]*TaskExecInfo, error) { + rs, err := mgr.ExecuteSQLWithNewSession(ctx, + `select `+basicTaskColumns+`, max(st.concurrency) + from mysql.tidb_global_task t join mysql.tidb_background_subtask st + on t.id = st.task_key and t.step = st.step + where t.state in (%?, %?, %?) and st.state in (%?, %?) and st.exec_id = %? + group by t.id + order by priority asc, create_time asc, id asc`, + proto.TaskStateRunning, proto.TaskStateReverting, proto.TaskStatePausing, + proto.SubtaskStatePending, proto.SubtaskStateRunning, execID) + if err != nil { + return nil, err + } + + res := make([]*TaskExecInfo, 0, len(rs)) + for _, r := range rs { + res = append(res, &TaskExecInfo{ + TaskBase: row2TaskBasic(r), + SubtaskConcurrency: int(r.GetInt64(9)), + }) + } + return res, nil +} + +// GetTasksInStates gets the tasks in the states(order by priority asc, create_time acs, id asc). +func (mgr *TaskManager) GetTasksInStates(ctx context.Context, states ...any) (task []*proto.Task, err error) { + if len(states) == 0 { + return task, nil + } + + rs, err := mgr.ExecuteSQLWithNewSession(ctx, + "select "+TaskColumns+" from mysql.tidb_global_task t "+ + "where state in ("+strings.Repeat("%?,", len(states)-1)+"%?)"+ + " order by priority asc, create_time asc, id asc", states...) + if err != nil { + return task, err + } + + for _, r := range rs { + task = append(task, Row2Task(r)) + } + return task, nil +} + +// GetTaskByID gets the task by the task ID. +func (mgr *TaskManager) GetTaskByID(ctx context.Context, taskID int64) (task *proto.Task, err error) { + rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task t where id = %?", taskID) + if err != nil { + return task, err + } + if len(rs) == 0 { + return nil, ErrTaskNotFound + } + + return Row2Task(rs[0]), nil +} + +// GetTaskBaseByID implements the TaskManager.GetTaskBaseByID interface. +func (mgr *TaskManager) GetTaskBaseByID(ctx context.Context, taskID int64) (task *proto.TaskBase, err error) { + rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+basicTaskColumns+" from mysql.tidb_global_task t where id = %?", taskID) + if err != nil { + return task, err + } + if len(rs) == 0 { + return nil, ErrTaskNotFound + } + + return row2TaskBasic(rs[0]), nil +} + +// GetTaskByIDWithHistory gets the task by the task ID from both tidb_global_task and tidb_global_task_history. +func (mgr *TaskManager) GetTaskByIDWithHistory(ctx context.Context, taskID int64) (task *proto.Task, err error) { + rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task t where id = %? 
"+ + "union select "+TaskColumns+" from mysql.tidb_global_task_history t where id = %?", taskID, taskID) + if err != nil { + return task, err + } + if len(rs) == 0 { + return nil, ErrTaskNotFound + } + + return Row2Task(rs[0]), nil +} + +// GetTaskBaseByIDWithHistory gets the task by the task ID from both tidb_global_task and tidb_global_task_history. +func (mgr *TaskManager) GetTaskBaseByIDWithHistory(ctx context.Context, taskID int64) (task *proto.TaskBase, err error) { + rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+basicTaskColumns+" from mysql.tidb_global_task t where id = %? "+ + "union select "+basicTaskColumns+" from mysql.tidb_global_task_history t where id = %?", taskID, taskID) + if err != nil { + return task, err + } + if len(rs) == 0 { + return nil, ErrTaskNotFound + } + + return row2TaskBasic(rs[0]), nil +} + +// GetTaskByKey gets the task by the task key. +func (mgr *TaskManager) GetTaskByKey(ctx context.Context, key string) (task *proto.Task, err error) { + rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task t where task_key = %?", key) + if err != nil { + return task, err + } + if len(rs) == 0 { + return nil, ErrTaskNotFound + } + + return Row2Task(rs[0]), nil +} + +// GetTaskByKeyWithHistory gets the task from history table by the task key. +func (mgr *TaskManager) GetTaskByKeyWithHistory(ctx context.Context, key string) (task *proto.Task, err error) { + rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task t where task_key = %?"+ + "union select "+TaskColumns+" from mysql.tidb_global_task_history t where task_key = %?", key, key) + if err != nil { + return task, err + } + if len(rs) == 0 { + return nil, ErrTaskNotFound + } + + return Row2Task(rs[0]), nil +} + +// GetTaskBaseByKeyWithHistory gets the task base from history table by the task key. +func (mgr *TaskManager) GetTaskBaseByKeyWithHistory(ctx context.Context, key string) (task *proto.TaskBase, err error) { + rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+basicTaskColumns+" from mysql.tidb_global_task t where task_key = %?"+ + "union select "+basicTaskColumns+" from mysql.tidb_global_task_history t where task_key = %?", key, key) + if err != nil { + return task, err + } + if len(rs) == 0 { + return nil, ErrTaskNotFound + } + + return row2TaskBasic(rs[0]), nil +} + +// GetSubtasksByExecIDAndStepAndStates gets all subtasks by given states on one node. +func (mgr *TaskManager) GetSubtasksByExecIDAndStepAndStates(ctx context.Context, execID string, taskID int64, step proto.Step, states ...proto.SubtaskState) ([]*proto.Subtask, error) { + args := []any{execID, taskID, step} + for _, state := range states { + args = append(args, state) + } + rs, err := mgr.ExecuteSQLWithNewSession(ctx, `select `+SubtaskColumns+` from mysql.tidb_background_subtask + where exec_id = %? and task_key = %? and step = %? + and state in (`+strings.Repeat("%?,", len(states)-1)+"%?)", args...) + if err != nil { + return nil, err + } + + subtasks := make([]*proto.Subtask, len(rs)) + for i, row := range rs { + subtasks[i] = Row2SubTask(row) + } + return subtasks, nil +} + +// GetFirstSubtaskInStates gets the first subtask by given states. 
+func (mgr *TaskManager) GetFirstSubtaskInStates(ctx context.Context, tidbID string, taskID int64, step proto.Step, states ...proto.SubtaskState) (*proto.Subtask, error) { + args := []any{tidbID, taskID, step} + for _, state := range states { + args = append(args, state) + } + rs, err := mgr.ExecuteSQLWithNewSession(ctx, `select `+SubtaskColumns+` from mysql.tidb_background_subtask + where exec_id = %? and task_key = %? and step = %? + and state in (`+strings.Repeat("%?,", len(states)-1)+"%?) limit 1", args...) + if err != nil { + return nil, err + } + + if len(rs) == 0 { + return nil, nil + } + return Row2SubTask(rs[0]), nil +} + +// GetActiveSubtasks implements TaskManager.GetActiveSubtasks. +func (mgr *TaskManager) GetActiveSubtasks(ctx context.Context, taskID int64) ([]*proto.SubtaskBase, error) { + rs, err := mgr.ExecuteSQLWithNewSession(ctx, ` + select `+basicSubtaskColumns+` from mysql.tidb_background_subtask + where task_key = %? and state in (%?, %?)`, + taskID, proto.SubtaskStatePending, proto.SubtaskStateRunning) + if err != nil { + return nil, err + } + subtasks := make([]*proto.SubtaskBase, 0, len(rs)) + for _, r := range rs { + subtasks = append(subtasks, row2BasicSubTask(r)) + } + return subtasks, nil +} + +// GetAllSubtasksByStepAndState gets the subtask by step and state. +func (mgr *TaskManager) GetAllSubtasksByStepAndState(ctx context.Context, taskID int64, step proto.Step, state proto.SubtaskState) ([]*proto.Subtask, error) { + rs, err := mgr.ExecuteSQLWithNewSession(ctx, `select `+SubtaskColumns+` from mysql.tidb_background_subtask + where task_key = %? and state = %? and step = %?`, + taskID, state, step) + if err != nil { + return nil, err + } + if len(rs) == 0 { + return nil, nil + } + subtasks := make([]*proto.Subtask, 0, len(rs)) + for _, r := range rs { + subtasks = append(subtasks, Row2SubTask(r)) + } + return subtasks, nil +} + +// GetSubtaskRowCount gets the subtask row count. +func (mgr *TaskManager) GetSubtaskRowCount(ctx context.Context, taskID int64, step proto.Step) (int64, error) { + rs, err := mgr.ExecuteSQLWithNewSession(ctx, `select + cast(sum(json_extract(summary, '$.row_count')) as signed) as row_count + from mysql.tidb_background_subtask where task_key = %? and step = %?`, + taskID, step) + if err != nil { + return 0, err + } + if len(rs) == 0 { + return 0, nil + } + return rs[0].GetInt64(0), nil +} + +// UpdateSubtaskRowCount updates the subtask row count. +func (mgr *TaskManager) UpdateSubtaskRowCount(ctx context.Context, subtaskID int64, rowCount int64) error { + _, err := mgr.ExecuteSQLWithNewSession(ctx, + `update mysql.tidb_background_subtask + set summary = json_set(summary, '$.row_count', %?) where id = %?`, + rowCount, subtaskID) + return err +} + +// GetSubtaskCntGroupByStates gets the subtask count by states. +func (mgr *TaskManager) GetSubtaskCntGroupByStates(ctx context.Context, taskID int64, step proto.Step) (map[proto.SubtaskState]int64, error) { + rs, err := mgr.ExecuteSQLWithNewSession(ctx, ` + select state, count(*) + from mysql.tidb_background_subtask + where task_key = %? and step = %? + group by state`, + taskID, step) + if err != nil { + return nil, err + } + + res := make(map[proto.SubtaskState]int64, len(rs)) + for _, r := range rs { + state := proto.SubtaskState(r.GetString(0)) + res[state] = r.GetInt64(1) + } + + return res, nil +} + +// GetSubtaskErrors gets subtasks' errors. 
+func (mgr *TaskManager) GetSubtaskErrors(ctx context.Context, taskID int64) ([]error, error) { + rs, err := mgr.ExecuteSQLWithNewSession(ctx, + `select error from mysql.tidb_background_subtask + where task_key = %? AND state in (%?, %?)`, taskID, proto.SubtaskStateFailed, proto.SubtaskStateCanceled) + if err != nil { + return nil, err + } + subTaskErrors := make([]error, 0, len(rs)) + for _, row := range rs { + if row.IsNull(0) { + subTaskErrors = append(subTaskErrors, nil) + continue + } + errBytes := row.GetBytes(0) + if len(errBytes) == 0 { + subTaskErrors = append(subTaskErrors, nil) + continue + } + stdErr := errors.Normalize("") + err := stdErr.UnmarshalJSON(errBytes) + if err != nil { + return nil, err + } + subTaskErrors = append(subTaskErrors, stdErr) + } + + return subTaskErrors, nil +} + +// HasSubtasksInStates checks if there are subtasks in the states. +func (mgr *TaskManager) HasSubtasksInStates(ctx context.Context, tidbID string, taskID int64, step proto.Step, states ...proto.SubtaskState) (bool, error) { + args := []any{tidbID, taskID, step} + for _, state := range states { + args = append(args, state) + } + rs, err := mgr.ExecuteSQLWithNewSession(ctx, `select 1 from mysql.tidb_background_subtask + where exec_id = %? and task_key = %? and step = %? + and state in (`+strings.Repeat("%?,", len(states)-1)+"%?) limit 1", args...) + if err != nil { + return false, err + } + + return len(rs) > 0, nil +} + +// UpdateSubtasksExecIDs update subtasks' execID. +func (mgr *TaskManager) UpdateSubtasksExecIDs(ctx context.Context, subtasks []*proto.SubtaskBase) error { + // skip the update process. + if len(subtasks) == 0 { + return nil + } + err := mgr.WithNewTxn(ctx, func(se sessionctx.Context) error { + for _, subtask := range subtasks { + _, err := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), ` + update mysql.tidb_background_subtask + set exec_id = %? + where id = %? and state = %?`, + subtask.ExecID, subtask.ID, subtask.State) + if err != nil { + return err + } + } + return nil + }) + return err +} + +// SwitchTaskStep implements the scheduler.TaskManager interface. +func (mgr *TaskManager) SwitchTaskStep( + ctx context.Context, + task *proto.Task, + nextState proto.TaskState, + nextStep proto.Step, + subtasks []*proto.Subtask, +) error { + return mgr.WithNewTxn(ctx, func(se sessionctx.Context) error { + vars := se.GetSessionVars() + if vars.MemQuotaQuery < variable.DefTiDBMemQuotaQuery { + bak := vars.MemQuotaQuery + if err := vars.SetSystemVar(variable.TiDBMemQuotaQuery, + strconv.Itoa(variable.DefTiDBMemQuotaQuery)); err != nil { + return err + } + defer func() { + _ = vars.SetSystemVar(variable.TiDBMemQuotaQuery, strconv.Itoa(int(bak))) + }() + } + err := mgr.updateTaskStateStep(ctx, se, task, nextState, nextStep) + if err != nil { + return err + } + if vars.StmtCtx.AffectedRows() == 0 { + // on network partition or owner change, there might be multiple + // schedulers for the same task, if other scheduler has switched + // the task to next step, skip the update process. + // Or when there is no such task. + return nil + } + return mgr.insertSubtasks(ctx, se, subtasks) + }) +} + +func (*TaskManager) updateTaskStateStep(ctx context.Context, se sessionctx.Context, + task *proto.Task, nextState proto.TaskState, nextStep proto.Step) error { + var extraUpdateStr string + if task.State == proto.TaskStatePending { + extraUpdateStr = `start_time = CURRENT_TIMESTAMP(),` + } + // TODO: during generating subtask, task meta might change, maybe move meta + // update to another place. 
+ _, err := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), ` + update mysql.tidb_global_task + set state = %?, + step = %?, `+extraUpdateStr+` + state_update_time = CURRENT_TIMESTAMP(), + meta = %? + where id = %? and state = %? and step = %?`, + nextState, nextStep, task.Meta, task.ID, task.State, task.Step) + return err +} + +// TestChannel is used for test. +var TestChannel = make(chan struct{}) + +func (*TaskManager) insertSubtasks(ctx context.Context, se sessionctx.Context, subtasks []*proto.Subtask) error { + if len(subtasks) == 0 { + return nil + } + failpoint.Inject("waitBeforeInsertSubtasks", func() { + <-TestChannel + <-TestChannel + }) + var ( + sb strings.Builder + markerList = make([]string, 0, len(subtasks)) + args = make([]any, 0, len(subtasks)*7) + ) + sb.WriteString(`insert into mysql.tidb_background_subtask(` + InsertSubtaskColumns + `) values `) + for _, subtask := range subtasks { + markerList = append(markerList, "(%?, %?, %?, %?, %?, %?, %?, %?, CURRENT_TIMESTAMP(), '{}', '{}')") + args = append(args, subtask.Step, subtask.TaskID, subtask.ExecID, subtask.Meta, + proto.SubtaskStatePending, proto.Type2Int(subtask.Type), subtask.Concurrency, subtask.Ordinal) + } + sb.WriteString(strings.Join(markerList, ",")) + _, err := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), sb.String(), args...) + return err +} + +// SwitchTaskStepInBatch implements the scheduler.TaskManager interface. +func (mgr *TaskManager) SwitchTaskStepInBatch( + ctx context.Context, + task *proto.Task, + nextState proto.TaskState, + nextStep proto.Step, + subtasks []*proto.Subtask, +) error { + return mgr.WithNewSession(func(se sessionctx.Context) error { + // some subtasks may be inserted by other schedulers, we can skip them. + rs, err := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), ` + select count(1) from mysql.tidb_background_subtask + where task_key = %? and step = %?`, task.ID, nextStep) + if err != nil { + return err + } + existingTaskCnt := int(rs[0].GetInt64(0)) + if existingTaskCnt > len(subtasks) { + return errors.Annotatef(ErrUnstableSubtasks, "expected %d, got %d", + len(subtasks), existingTaskCnt) + } + subtaskBatches := mgr.splitSubtasks(subtasks[existingTaskCnt:]) + for _, batch := range subtaskBatches { + if err = mgr.insertSubtasks(ctx, se, batch); err != nil { + return err + } + } + return mgr.updateTaskStateStep(ctx, se, task, nextState, nextStep) + }) +} + +func (*TaskManager) splitSubtasks(subtasks []*proto.Subtask) [][]*proto.Subtask { + var ( + res = make([][]*proto.Subtask, 0, 10) + currBatch = make([]*proto.Subtask, 0, 10) + size int + ) + maxSize := int(min(kv.TxnTotalSizeLimit.Load(), uint64(maxSubtaskBatchSize))) + for _, s := range subtasks { + if size+len(s.Meta) > maxSize { + res = append(res, currBatch) + currBatch = nil + size = 0 + } + currBatch = append(currBatch, s) + size += len(s.Meta) + } + if len(currBatch) > 0 { + res = append(res, currBatch) + } + return res +} + +func serializeErr(err error) []byte { + if err == nil { + return nil + } + originErr := errors.Cause(err) + tErr, ok := originErr.(*errors.Error) + if !ok { + tErr = errors.Normalize(originErr.Error()) + } + errBytes, err := tErr.MarshalJSON() + if err != nil { + return nil + } + return errBytes +} + +// GetSubtasksWithHistory gets the subtasks from tidb_global_task and tidb_global_task_history. 
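// --- Illustrative sketch, not part of the patch -----------------------------
// The byte-budget batching rule implemented by splitSubtasks above, reduced to
// plain byte slices and slightly simplified (it only flushes non-empty batches).
// maxBatch stands in for min(TxnTotalSizeLimit, maxSubtaskBatchSize).
package main

import "fmt"

func splitBySize(metas [][]byte, maxBatch int) [][][]byte {
	var (
		res   [][][]byte
		batch [][]byte
		size  int
	)
	for _, m := range metas {
		if size+len(m) > maxBatch && len(batch) > 0 {
			res = append(res, batch)
			batch, size = nil, 0
		}
		batch = append(batch, m)
		size += len(m)
	}
	if len(batch) > 0 {
		res = append(res, batch)
	}
	return res
}

func main() {
	metas := [][]byte{make([]byte, 40), make([]byte, 70), make([]byte, 30)}
	fmt.Println(len(splitBySize(metas, 100))) // 2 batches: [40] and [70, 30]
}
// -----------------------------------------------------------------------------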
+func (mgr *TaskManager) GetSubtasksWithHistory(ctx context.Context, taskID int64, step proto.Step) ([]*proto.Subtask, error) { + var ( + rs []chunk.Row + err error + ) + err = mgr.WithNewTxn(ctx, func(se sessionctx.Context) error { + rs, err = sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), + `select `+SubtaskColumns+` from mysql.tidb_background_subtask where task_key = %? and step = %?`, + taskID, step, + ) + if err != nil { + return err + } + + // To avoid the situation that the subtasks has been `TransferTasks2History` + // when the user show import jobs, we need to check the history table. + rsFromHistory, err := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), + `select `+SubtaskColumns+` from mysql.tidb_background_subtask_history where task_key = %? and step = %?`, + taskID, step, + ) + if err != nil { + return err + } + + rs = append(rs, rsFromHistory...) + return nil + }) + + if err != nil { + return nil, err + } + if len(rs) == 0 { + return nil, nil + } + subtasks := make([]*proto.Subtask, 0, len(rs)) + for _, r := range rs { + subtasks = append(subtasks, Row2SubTask(r)) + } + return subtasks, nil +} + +// GetAllSubtasks gets all subtasks with basic columns. +func (mgr *TaskManager) GetAllSubtasks(ctx context.Context) ([]*proto.SubtaskBase, error) { + rs, err := mgr.ExecuteSQLWithNewSession(ctx, `select `+basicSubtaskColumns+` from mysql.tidb_background_subtask`) + if err != nil { + return nil, err + } + if len(rs) == 0 { + return nil, nil + } + subtasks := make([]*proto.SubtaskBase, 0, len(rs)) + for _, r := range rs { + subtasks = append(subtasks, row2BasicSubTask(r)) + } + return subtasks, nil +} + +// AdjustTaskOverflowConcurrency change the task concurrency to a max value supported by current cluster. +// This is a workaround for an upgrade bug: in v7.5.x, the task concurrency is hard-coded to 16, resulting in +// a stuck issue if the new version TiDB has less than 16 CPU count. +// We don't adjust the concurrency in subtask table because this field does not exist in v7.5.0. +// For details, see https://github.com/pingcap/tidb/issues/50894. +// For the following versions, there is a check when submitting a new task. This function should be a no-op. +func (mgr *TaskManager) AdjustTaskOverflowConcurrency(ctx context.Context, se sessionctx.Context) error { + cpuCount, err := mgr.getCPUCountOfNode(ctx, se) + if err != nil { + return err + } + sql := "update mysql.tidb_global_task set concurrency = %? 
where concurrency > %?;" + _, err = sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), sql, cpuCount, cpuCount) + return err +} diff --git a/pkg/disttask/framework/taskexecutor/binding__failpoint_binding__.go b/pkg/disttask/framework/taskexecutor/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..42a8c5efc3dee --- /dev/null +++ b/pkg/disttask/framework/taskexecutor/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package taskexecutor + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/disttask/framework/taskexecutor/task_executor.go b/pkg/disttask/framework/taskexecutor/task_executor.go index 2551173673f38..911a1f9f23c98 100644 --- a/pkg/disttask/framework/taskexecutor/task_executor.go +++ b/pkg/disttask/framework/taskexecutor/task_executor.go @@ -306,9 +306,9 @@ func (e *BaseTaskExecutor) runStep(resource *proto.StepResource) (resErr error) } execute.SetFrameworkInfo(stepExecutor, resource) - failpoint.Inject("mockExecSubtaskInitEnvErr", func() { - failpoint.Return(errors.New("mockExecSubtaskInitEnvErr")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("mockExecSubtaskInitEnvErr")); _err_ == nil { + return errors.New("mockExecSubtaskInitEnvErr") + } if err := stepExecutor.Init(runStepCtx); err != nil { e.onError(err) return e.getError() @@ -367,9 +367,9 @@ func (e *BaseTaskExecutor) runStep(resource *proto.StepResource) (resErr error) } } - failpoint.Inject("cancelBeforeRunSubtask", func() { + if _, _err_ := failpoint.Eval(_curpkg_("cancelBeforeRunSubtask")); _err_ == nil { runStepCancel(nil) - }) + } e.runSubtask(runStepCtx, stepExecutor, subtask) } @@ -402,17 +402,17 @@ func (e *BaseTaskExecutor) runSubtask(ctx context.Context, stepExecutor execute. }() return stepExecutor.RunSubtask(ctx, subtask) }() - failpoint.Inject("MockRunSubtaskCancel", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("MockRunSubtaskCancel")); _err_ == nil { if val.(bool) { err = ErrCancelSubtask } - }) + } - failpoint.Inject("MockRunSubtaskContextCanceled", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("MockRunSubtaskContextCanceled")); _err_ == nil { if val.(bool) { err = context.Canceled } - }) + } if err != nil { e.onError(err) @@ -423,18 +423,18 @@ func (e *BaseTaskExecutor) runSubtask(ctx context.Context, stepExecutor execute. return } - failpoint.Inject("mockTiDBShutdown", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockTiDBShutdown")); _err_ == nil { if MockTiDBDown(e.id, e.GetTaskBase()) { - failpoint.Return() + return } - }) + } - failpoint.Inject("MockExecutorRunErr", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("MockExecutorRunErr")); _err_ == nil { if val.(bool) { e.onError(errors.New("MockExecutorRunErr")) } - }) - failpoint.Inject("MockExecutorRunCancel", func(val failpoint.Value) { + } + if val, _err_ := failpoint.Eval(_curpkg_("MockExecutorRunCancel")); _err_ == nil { if taskID, ok := val.(int); ok { mgr, err := storage.GetTaskManager() if err != nil { @@ -446,7 +446,7 @@ func (e *BaseTaskExecutor) runSubtask(ctx context.Context, stepExecutor execute. 
} } } - }) + } e.onSubtaskFinished(ctx, stepExecutor, subtask) } @@ -456,11 +456,11 @@ func (e *BaseTaskExecutor) onSubtaskFinished(ctx context.Context, executor execu e.onError(err) } } - failpoint.Inject("MockSubtaskFinishedCancel", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("MockSubtaskFinishedCancel")); _err_ == nil { if val.(bool) { e.onError(ErrCancelSubtask) } - }) + } finished := e.markSubTaskCanceledOrFailed(ctx, subtask) if finished { @@ -474,7 +474,7 @@ func (e *BaseTaskExecutor) onSubtaskFinished(ctx context.Context, executor execu return } - failpoint.InjectCall("syncAfterSubtaskFinish") + failpoint.Call(_curpkg_("syncAfterSubtaskFinish")) } // GetTaskBase implements TaskExecutor.GetTaskBase. diff --git a/pkg/disttask/framework/taskexecutor/task_executor.go__failpoint_stash__ b/pkg/disttask/framework/taskexecutor/task_executor.go__failpoint_stash__ new file mode 100644 index 0000000000000..2551173673f38 --- /dev/null +++ b/pkg/disttask/framework/taskexecutor/task_executor.go__failpoint_stash__ @@ -0,0 +1,683 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package taskexecutor + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + "github.com/pingcap/tidb/pkg/disttask/framework/handle" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" + "github.com/pingcap/tidb/pkg/disttask/framework/storage" + "github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor/execute" + "github.com/pingcap/tidb/pkg/lightning/common" + llog "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/backoff" + "github.com/pingcap/tidb/pkg/util/gctuner" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/memory" + "go.uber.org/zap" +) + +var ( + // checkBalanceSubtaskInterval is the default check interval for checking + // subtasks balance to/away from this node. + checkBalanceSubtaskInterval = 2 * time.Second + + // updateSubtaskSummaryInterval is the interval for updating the subtask summary to + // subtask table. + updateSubtaskSummaryInterval = 3 * time.Second +) + +var ( + // ErrCancelSubtask is the cancel cause when cancelling subtasks. + ErrCancelSubtask = errors.New("cancel subtasks") + // ErrFinishSubtask is the cancel cause when TaskExecutor successfully processed subtasks. + ErrFinishSubtask = errors.New("finish subtasks") + // ErrNonIdempotentSubtask means the subtask is left in running state and is not idempotent, + // so cannot be run again. + ErrNonIdempotentSubtask = errors.New("subtask in running state and is not idempotent") + + // MockTiDBDown is used to mock TiDB node down, return true if it's chosen. + MockTiDBDown func(execID string, task *proto.TaskBase) bool +) + +// BaseTaskExecutor is the base implementation of TaskExecutor. 
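// The binding__failpoint_binding__.go and *__failpoint_stash__ files above look
// like failpoint-ctl "enable" output: the original source is stashed, a
// per-package _curpkg_ helper (package path + "/" + name) is generated, and each
// failpoint.Inject marker is rewritten into an explicit failpoint.Eval call, as
// the task_executor.go hunks show. A minimal sketch of that before/after shape
// on a toy function; "demoFailpoint" is an illustrative name, not a failpoint
// used by TiDB, and the package path is inlined instead of calling _curpkg_.
package main

import (
	"fmt"

	"github.com/pingcap/failpoint"
)

// Marker form, as developers write it; Inject is a no-op unless the code is rewritten.
func stepMarker() error {
	failpoint.Inject("demoFailpoint", func() {
		failpoint.Return(fmt.Errorf("injected error"))
	})
	return nil
}

// Rewritten form, the shape failpoint-ctl produces: evaluate the failpoint and,
// if it is active, execute the body inline.
func stepRewritten() error {
	if _, _err_ := failpoint.Eval("main/demoFailpoint"); _err_ == nil {
		return fmt.Errorf("injected error")
	}
	return nil
}

func main() {
	// With no failpoint enabled, both forms are plain no-ops and return nil.
	fmt.Println(stepMarker(), stepRewritten())
}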
+type BaseTaskExecutor struct { + // id, it's the same as server id now, i.e. host:port. + id string + // we only store task base here to reduce overhead of refreshing it. + // task meta is loaded when we do execute subtasks, see GetStepExecutor. + taskBase atomic.Pointer[proto.TaskBase] + taskTable TaskTable + logger *zap.Logger + ctx context.Context + cancel context.CancelFunc + Extension + + currSubtaskID atomic.Int64 + + mu struct { + sync.RWMutex + err error + // handled indicates whether the error has been updated to one of the subtask. + handled bool + // runtimeCancel is used to cancel the Run/Rollback when error occurs. + runtimeCancel context.CancelCauseFunc + } +} + +// NewBaseTaskExecutor creates a new BaseTaskExecutor. +// see TaskExecutor.Init for why we want to use task-base to create TaskExecutor. +// TODO: we can refactor this part to pass task base only, but currently ADD-INDEX +// depends on it to init, so we keep it for now. +func NewBaseTaskExecutor(ctx context.Context, id string, task *proto.Task, taskTable TaskTable) *BaseTaskExecutor { + logger := log.L().With(zap.Int64("task-id", task.ID), zap.String("task-type", string(task.Type))) + if intest.InTest { + logger = logger.With(zap.String("server-id", id)) + } + subCtx, cancelFunc := context.WithCancel(ctx) + taskExecutorImpl := &BaseTaskExecutor{ + id: id, + taskTable: taskTable, + ctx: subCtx, + cancel: cancelFunc, + logger: logger, + } + taskExecutorImpl.taskBase.Store(&task.TaskBase) + return taskExecutorImpl +} + +// checkBalanceSubtask check whether the subtasks are balanced to or away from this node. +// - If other subtask of `running` state is scheduled to this node, try changed to +// `pending` state, to make sure subtasks can be balanced later when node scale out. +// - If current running subtask are scheduled away from this node, i.e. this node +// is taken as down, cancel running. +func (e *BaseTaskExecutor) checkBalanceSubtask(ctx context.Context) { + ticker := time.NewTicker(checkBalanceSubtaskInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + } + + task := e.taskBase.Load() + subtasks, err := e.taskTable.GetSubtasksByExecIDAndStepAndStates(ctx, e.id, task.ID, task.Step, + proto.SubtaskStateRunning) + if err != nil { + e.logger.Error("get subtasks failed", zap.Error(err)) + continue + } + if len(subtasks) == 0 { + e.logger.Info("subtask is scheduled away, cancel running") + // cancels runStep, but leave the subtask state unchanged. 
+ e.cancelRunStepWith(nil) + return + } + + extraRunningSubtasks := make([]*proto.SubtaskBase, 0, len(subtasks)) + for _, st := range subtasks { + if st.ID == e.currSubtaskID.Load() { + continue + } + if !e.IsIdempotent(st) { + e.updateSubtaskStateAndErrorImpl(ctx, st.ExecID, st.ID, proto.SubtaskStateFailed, ErrNonIdempotentSubtask) + return + } + extraRunningSubtasks = append(extraRunningSubtasks, &st.SubtaskBase) + } + if len(extraRunningSubtasks) > 0 { + if err = e.taskTable.RunningSubtasksBack2Pending(ctx, extraRunningSubtasks); err != nil { + e.logger.Error("update running subtasks back to pending failed", zap.Error(err)) + } else { + e.logger.Info("update extra running subtasks back to pending", + zap.Stringers("subtasks", extraRunningSubtasks)) + } + } + } +} + +func (e *BaseTaskExecutor) updateSubtaskSummaryLoop( + checkCtx, runStepCtx context.Context, stepExec execute.StepExecutor) { + taskMgr := e.taskTable.(*storage.TaskManager) + ticker := time.NewTicker(updateSubtaskSummaryInterval) + defer ticker.Stop() + curSubtaskID := e.currSubtaskID.Load() + update := func() { + summary := stepExec.RealtimeSummary() + err := taskMgr.UpdateSubtaskRowCount(runStepCtx, curSubtaskID, summary.RowCount) + if err != nil { + e.logger.Info("update subtask row count failed", zap.Error(err)) + } + } + for { + select { + case <-checkCtx.Done(): + update() + return + case <-ticker.C: + } + update() + } +} + +// Init implements the TaskExecutor interface. +func (*BaseTaskExecutor) Init(_ context.Context) error { + return nil +} + +// Ctx returns the context of the task executor. +// TODO: remove it when add-index.taskexecutor.Init don't depends on it. +func (e *BaseTaskExecutor) Ctx() context.Context { + return e.ctx +} + +// Run implements the TaskExecutor interface. +func (e *BaseTaskExecutor) Run(resource *proto.StepResource) { + var err error + // task executor occupies resources, if there's no subtask to run for 10s, + // we release the resources so that other tasks can use them. + // 300ms + 600ms + 1.2s + 2s * 4 = 10.1s + backoffer := backoff.NewExponential(SubtaskCheckInterval, 2, MaxSubtaskCheckInterval) + checkInterval, noSubtaskCheckCnt := SubtaskCheckInterval, 0 + for { + select { + case <-e.ctx.Done(): + return + case <-time.After(checkInterval): + } + if err = e.refreshTask(); err != nil { + if errors.Cause(err) == storage.ErrTaskNotFound { + return + } + e.logger.Error("refresh task failed", zap.Error(err)) + continue + } + task := e.taskBase.Load() + if task.State != proto.TaskStateRunning { + return + } + if exist, err := e.taskTable.HasSubtasksInStates(e.ctx, e.id, task.ID, task.Step, + unfinishedSubtaskStates...); err != nil { + e.logger.Error("check whether there are subtasks to run failed", zap.Error(err)) + continue + } else if !exist { + if noSubtaskCheckCnt >= maxChecksWhenNoSubtask { + e.logger.Info("no subtask to run for a while, exit") + break + } + checkInterval = backoffer.Backoff(noSubtaskCheckCnt) + noSubtaskCheckCnt++ + continue + } + // reset it when we get a subtask + checkInterval, noSubtaskCheckCnt = SubtaskCheckInterval, 0 + err = e.RunStep(resource) + if err != nil { + e.logger.Error("failed to handle task", zap.Error(err)) + } + } +} + +// RunStep start to fetch and run all subtasks for the step of task on the node. +// return if there's no subtask to run. 
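// The Run loop above polls for new subtasks with a capped exponential backoff:
// the wait starts at SubtaskCheckInterval, doubles up to MaxSubtaskCheckInterval,
// and the executor exits after maxChecksWhenNoSubtask empty checks (the in-code
// arithmetic "300ms + 600ms + 1.2s + 2s * 4 = 10.1s"). A minimal sketch of that
// wait schedule; the 300ms/2s/7-check numbers are taken from that comment, not
// from the exported constants, and this is not the util/backoff API itself.
package main

import (
	"fmt"
	"time"
)

// backoffDuration returns base * 2^attempt, clamped to max.
func backoffDuration(base, max time.Duration, attempt int) time.Duration {
	d := base << uint(attempt)
	if d <= 0 || d > max { // d <= 0 guards against shift overflow
		return max
	}
	return d
}

func main() {
	base, max := 300*time.Millisecond, 2*time.Second
	var total time.Duration
	for attempt := 0; attempt < 7; attempt++ {
		d := backoffDuration(base, max, attempt)
		total += d
		fmt.Printf("idle check %d: wait %v\n", attempt+1, d)
	}
	fmt.Println("idle budget before exit:", total) // 300ms+600ms+1.2s+2s*4 = 10.1s
}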
+func (e *BaseTaskExecutor) RunStep(resource *proto.StepResource) (err error) { + defer func() { + if r := recover(); r != nil { + e.logger.Error("BaseTaskExecutor panicked", zap.Any("recover", r), zap.Stack("stack")) + err4Panic := errors.Errorf("%v", r) + err1 := e.updateSubtask(err4Panic) + if err == nil { + err = err1 + } + } + }() + err = e.runStep(resource) + if e.mu.handled { + return err + } + if err == nil { + // may have error in + // 1. defer function in run(ctx, task) + // 2. cancel ctx + // TODO: refine onError/getError + if e.getError() != nil { + err = e.getError() + } else if e.ctx.Err() != nil { + err = e.ctx.Err() + } else { + return nil + } + } + + return e.updateSubtask(err) +} + +func (e *BaseTaskExecutor) runStep(resource *proto.StepResource) (resErr error) { + runStepCtx, runStepCancel := context.WithCancelCause(e.ctx) + e.registerRunStepCancelFunc(runStepCancel) + defer func() { + runStepCancel(ErrFinishSubtask) + e.unregisterRunStepCancelFunc() + }() + e.resetError() + taskBase := e.taskBase.Load() + task, err := e.taskTable.GetTaskByID(e.ctx, taskBase.ID) + if err != nil { + e.onError(err) + return e.getError() + } + stepLogger := llog.BeginTask(e.logger.With( + zap.String("step", proto.Step2Str(task.Type, task.Step)), + zap.Float64("mem-limit-percent", gctuner.GlobalMemoryLimitTuner.GetPercentage()), + zap.String("server-mem-limit", memory.ServerMemoryLimitOriginText.Load()), + zap.Stringer("resource", resource), + ), "execute task step") + // log as info level, subtask might be cancelled, let caller check it. + defer func() { + stepLogger.End(zap.InfoLevel, resErr) + }() + + stepExecutor, err := e.GetStepExecutor(task) + if err != nil { + e.onError(err) + return e.getError() + } + execute.SetFrameworkInfo(stepExecutor, resource) + + failpoint.Inject("mockExecSubtaskInitEnvErr", func() { + failpoint.Return(errors.New("mockExecSubtaskInitEnvErr")) + }) + if err := stepExecutor.Init(runStepCtx); err != nil { + e.onError(err) + return e.getError() + } + + defer func() { + err := stepExecutor.Cleanup(runStepCtx) + if err != nil { + e.logger.Error("cleanup subtask exec env failed", zap.Error(err)) + e.onError(err) + } + }() + + for { + // check if any error occurs. + if err := e.getError(); err != nil { + break + } + if runStepCtx.Err() != nil { + break + } + + subtask, err := e.taskTable.GetFirstSubtaskInStates(runStepCtx, e.id, task.ID, task.Step, + proto.SubtaskStatePending, proto.SubtaskStateRunning) + if err != nil { + e.logger.Warn("GetFirstSubtaskInStates meets error", zap.Error(err)) + continue + } + if subtask == nil { + break + } + + if subtask.State == proto.SubtaskStateRunning { + if !e.IsIdempotent(subtask) { + e.logger.Info("subtask in running state and is not idempotent, fail it", + zap.Int64("subtask-id", subtask.ID)) + e.onError(ErrNonIdempotentSubtask) + e.updateSubtaskStateAndErrorImpl(runStepCtx, subtask.ExecID, subtask.ID, proto.SubtaskStateFailed, ErrNonIdempotentSubtask) + e.markErrorHandled() + break + } + e.logger.Info("subtask in running state and is idempotent", + zap.Int64("subtask-id", subtask.ID)) + } else { + // subtask.State == proto.SubtaskStatePending + err := e.startSubtask(runStepCtx, subtask.ID) + if err != nil { + e.logger.Warn("startSubtask meets error", zap.Error(err)) + // should ignore ErrSubtaskNotFound + // since it only means that the subtask not owned by current task executor. 
+ if err == storage.ErrSubtaskNotFound { + continue + } + e.onError(err) + continue + } + } + + failpoint.Inject("cancelBeforeRunSubtask", func() { + runStepCancel(nil) + }) + + e.runSubtask(runStepCtx, stepExecutor, subtask) + } + return e.getError() +} + +func (e *BaseTaskExecutor) hasRealtimeSummary(stepExecutor execute.StepExecutor) bool { + _, ok := e.taskTable.(*storage.TaskManager) + return ok && stepExecutor.RealtimeSummary() != nil +} + +func (e *BaseTaskExecutor) runSubtask(ctx context.Context, stepExecutor execute.StepExecutor, subtask *proto.Subtask) { + err := func() error { + e.currSubtaskID.Store(subtask.ID) + + var wg util.WaitGroupWrapper + checkCtx, checkCancel := context.WithCancel(ctx) + wg.RunWithLog(func() { + e.checkBalanceSubtask(checkCtx) + }) + + if e.hasRealtimeSummary(stepExecutor) { + wg.RunWithLog(func() { + e.updateSubtaskSummaryLoop(checkCtx, ctx, stepExecutor) + }) + } + defer func() { + checkCancel() + wg.Wait() + }() + return stepExecutor.RunSubtask(ctx, subtask) + }() + failpoint.Inject("MockRunSubtaskCancel", func(val failpoint.Value) { + if val.(bool) { + err = ErrCancelSubtask + } + }) + + failpoint.Inject("MockRunSubtaskContextCanceled", func(val failpoint.Value) { + if val.(bool) { + err = context.Canceled + } + }) + + if err != nil { + e.onError(err) + } + + finished := e.markSubTaskCanceledOrFailed(ctx, subtask) + if finished { + return + } + + failpoint.Inject("mockTiDBShutdown", func() { + if MockTiDBDown(e.id, e.GetTaskBase()) { + failpoint.Return() + } + }) + + failpoint.Inject("MockExecutorRunErr", func(val failpoint.Value) { + if val.(bool) { + e.onError(errors.New("MockExecutorRunErr")) + } + }) + failpoint.Inject("MockExecutorRunCancel", func(val failpoint.Value) { + if taskID, ok := val.(int); ok { + mgr, err := storage.GetTaskManager() + if err != nil { + e.logger.Error("get task manager failed", zap.Error(err)) + } else { + err = mgr.CancelTask(ctx, int64(taskID)) + if err != nil { + e.logger.Error("cancel task failed", zap.Error(err)) + } + } + } + }) + e.onSubtaskFinished(ctx, stepExecutor, subtask) +} + +func (e *BaseTaskExecutor) onSubtaskFinished(ctx context.Context, executor execute.StepExecutor, subtask *proto.Subtask) { + if err := e.getError(); err == nil { + if err = executor.OnFinished(ctx, subtask); err != nil { + e.onError(err) + } + } + failpoint.Inject("MockSubtaskFinishedCancel", func(val failpoint.Value) { + if val.(bool) { + e.onError(ErrCancelSubtask) + } + }) + + finished := e.markSubTaskCanceledOrFailed(ctx, subtask) + if finished { + return + } + + e.finishSubtask(ctx, subtask) + + finished = e.markSubTaskCanceledOrFailed(ctx, subtask) + if finished { + return + } + + failpoint.InjectCall("syncAfterSubtaskFinish") +} + +// GetTaskBase implements TaskExecutor.GetTaskBase. +func (e *BaseTaskExecutor) GetTaskBase() *proto.TaskBase { + return e.taskBase.Load() +} + +// CancelRunningSubtask implements TaskExecutor.CancelRunningSubtask. +func (e *BaseTaskExecutor) CancelRunningSubtask() { + e.cancelRunStepWith(ErrCancelSubtask) +} + +// Cancel implements TaskExecutor.Cancel. +func (e *BaseTaskExecutor) Cancel() { + e.cancel() +} + +// Close closes the TaskExecutor when all the subtasks are complete. +func (e *BaseTaskExecutor) Close() { + e.Cancel() +} + +// refreshTask fetch task state from tidb_global_task table. 
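// runSubtask above runs the subtask body while watcher goroutines
// (checkBalanceSubtask and, when available, the summary updater) observe it,
// then cancels the watcher context and waits for the goroutines before
// returning. A minimal stdlib-only sketch of that run-with-watchdog pattern;
// the 50ms tick and the sleeping workload are illustrative placeholders.
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// runWithWatchdog executes work while watch runs in the background, and makes
// sure the watcher has stopped before returning.
func runWithWatchdog(ctx context.Context, work func(context.Context) error, watch func(context.Context)) error {
	watchCtx, cancel := context.WithCancel(ctx)
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		watch(watchCtx)
	}()
	defer func() {
		cancel()  // stop the watcher once the work is done
		wg.Wait() // and wait for it to exit, as runSubtask does via checkCancel + wg.Wait
	}()
	return work(ctx)
}

func main() {
	err := runWithWatchdog(context.Background(),
		func(context.Context) error { // stands in for stepExecutor.RunSubtask
			time.Sleep(120 * time.Millisecond)
			return nil
		},
		func(ctx context.Context) { // stands in for checkBalanceSubtask
			ticker := time.NewTicker(50 * time.Millisecond)
			defer ticker.Stop()
			for {
				select {
				case <-ctx.Done():
					return
				case <-ticker.C:
					fmt.Println("checking subtask balance")
				}
			}
		})
	fmt.Println("done:", err)
}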
+func (e *BaseTaskExecutor) refreshTask() error { + task := e.GetTaskBase() + newTaskBase, err := e.taskTable.GetTaskBaseByID(e.ctx, task.ID) + if err != nil { + return err + } + e.taskBase.Store(newTaskBase) + return nil +} + +func (e *BaseTaskExecutor) registerRunStepCancelFunc(cancel context.CancelCauseFunc) { + e.mu.Lock() + defer e.mu.Unlock() + e.mu.runtimeCancel = cancel +} + +func (e *BaseTaskExecutor) unregisterRunStepCancelFunc() { + e.mu.Lock() + defer e.mu.Unlock() + e.mu.runtimeCancel = nil +} + +func (e *BaseTaskExecutor) cancelRunStepWith(cause error) { + e.mu.Lock() + defer e.mu.Unlock() + if e.mu.runtimeCancel != nil { + e.mu.runtimeCancel(cause) + } +} + +func (e *BaseTaskExecutor) onError(err error) { + if err == nil { + return + } + err = errors.Trace(err) + e.logger.Error("onError", zap.Error(err), zap.Stack("stack")) + e.mu.Lock() + defer e.mu.Unlock() + + if e.mu.err == nil { + e.mu.err = err + e.logger.Error("taskExecutor met first error", zap.Error(err)) + } + + if e.mu.runtimeCancel != nil { + e.mu.runtimeCancel(err) + } +} + +func (e *BaseTaskExecutor) markErrorHandled() { + e.mu.Lock() + defer e.mu.Unlock() + e.mu.handled = true +} + +func (e *BaseTaskExecutor) getError() error { + e.mu.RLock() + defer e.mu.RUnlock() + return e.mu.err +} + +func (e *BaseTaskExecutor) resetError() { + e.mu.Lock() + defer e.mu.Unlock() + e.mu.err = nil + e.mu.handled = false +} + +func (e *BaseTaskExecutor) updateSubtaskStateAndErrorImpl(ctx context.Context, execID string, subtaskID int64, state proto.SubtaskState, subTaskErr error) { + // retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes + backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) + err := handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, e.logger, + func(ctx context.Context) (bool, error) { + return true, e.taskTable.UpdateSubtaskStateAndError(ctx, execID, subtaskID, state, subTaskErr) + }, + ) + if err != nil { + e.onError(err) + } +} + +// startSubtask try to change the state of the subtask to running. +// If the subtask is not owned by the task executor, +// the update will fail and task executor should not run the subtask. +func (e *BaseTaskExecutor) startSubtask(ctx context.Context, subtaskID int64) error { + // retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes + backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) + return handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, e.logger, + func(ctx context.Context) (bool, error) { + err := e.taskTable.StartSubtask(ctx, subtaskID, e.id) + if err == storage.ErrSubtaskNotFound { + // No need to retry. + return false, err + } + return true, err + }, + ) +} + +func (e *BaseTaskExecutor) finishSubtask(ctx context.Context, subtask *proto.Subtask) { + backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) + err := handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, e.logger, + func(ctx context.Context) (bool, error) { + return true, e.taskTable.FinishSubtask(ctx, subtask.ExecID, subtask.ID, subtask.Meta) + }, + ) + if err != nil { + e.onError(err) + } +} + +// markSubTaskCanceledOrFailed check the error type and decide the subtasks' state. +// 1. Only cancel subtasks when meet ErrCancelSubtask. +// 2. Only fail subtasks when meet non retryable error. +// 3. When meet other errors, don't change subtasks' state. 
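// The executor distinguishes *why* the step context was cancelled through
// context.WithCancelCause: CancelRunningSubtask cancels with ErrCancelSubtask,
// normal completion cancels with ErrFinishSubtask, onError cancels with the
// first error, and markSubTaskCanceledOrFailed (below) inspects context.Cause
// to decide whether the subtask should be marked canceled. A minimal
// stand-alone sketch of that cause propagation, with local error values
// standing in for the package-level ones.
package main

import (
	"context"
	"errors"
	"fmt"
)

var (
	errCancelSubtask = errors.New("cancel subtasks")
	errFinishSubtask = errors.New("finish subtasks")
)

func main() {
	ctx, cancel := context.WithCancelCause(context.Background())
	cancel(errCancelSubtask) // e.g. CancelRunningSubtask was called

	// ctx.Err() only says "canceled"; the cause carries the reason.
	fmt.Println(ctx.Err())                                       // context canceled
	fmt.Println(context.Cause(ctx))                              // cancel subtasks
	fmt.Println(errors.Is(context.Cause(ctx), errCancelSubtask)) // true
	fmt.Println(errors.Is(context.Cause(ctx), errFinishSubtask)) // false
}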
+func (e *BaseTaskExecutor) markSubTaskCanceledOrFailed(ctx context.Context, subtask *proto.Subtask) bool { + if err := e.getError(); err != nil { + err := errors.Cause(err) + if ctx.Err() != nil && context.Cause(ctx) == ErrCancelSubtask { + e.logger.Warn("subtask canceled", zap.Error(err)) + e.updateSubtaskStateAndErrorImpl(e.ctx, subtask.ExecID, subtask.ID, proto.SubtaskStateCanceled, nil) + } else if e.IsRetryableError(err) { + e.logger.Warn("meet retryable error", zap.Error(err)) + } else if common.IsContextCanceledError(err) { + e.logger.Info("meet context canceled for gracefully shutdown", zap.Error(err)) + } else { + e.logger.Warn("subtask failed", zap.Error(err)) + e.updateSubtaskStateAndErrorImpl(e.ctx, subtask.ExecID, subtask.ID, proto.SubtaskStateFailed, err) + } + e.markErrorHandled() + return true + } + return false +} + +func (e *BaseTaskExecutor) failSubtaskWithRetry(ctx context.Context, taskID int64, err error) error { + backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) + err1 := handle.RunWithRetry(e.ctx, scheduler.RetrySQLTimes, backoffer, e.logger, + func(_ context.Context) (bool, error) { + return true, e.taskTable.FailSubtask(ctx, e.id, taskID, err) + }, + ) + if err1 == nil { + e.logger.Info("failed one subtask succeed", zap.NamedError("subtask-err", err)) + } + return err1 +} + +func (e *BaseTaskExecutor) cancelSubtaskWithRetry(ctx context.Context, taskID int64, err error) error { + e.logger.Warn("subtask canceled", zap.NamedError("subtask-cancel", err)) + backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) + err1 := handle.RunWithRetry(e.ctx, scheduler.RetrySQLTimes, backoffer, e.logger, + func(_ context.Context) (bool, error) { + return true, e.taskTable.CancelSubtask(ctx, e.id, taskID) + }, + ) + if err1 == nil { + e.logger.Info("canceled one subtask succeed", zap.NamedError("subtask-cancel", err)) + } + return err1 +} + +// updateSubtask check the error type and decide the subtasks' state. +// 1. Only cancel subtasks when meet ErrCancelSubtask. +// 2. Only fail subtasks when meet non retryable error. +// 3. When meet other errors, don't change subtasks' state. +// Handled errors should not happen during subtasks execution. +// Only handle errors before subtasks execution and after subtasks execution. +func (e *BaseTaskExecutor) updateSubtask(err error) error { + task := e.taskBase.Load() + err = errors.Cause(err) + // TODO this branch is unreachable now, remove it when we refactor error handling. 
+ if e.ctx.Err() != nil && context.Cause(e.ctx) == ErrCancelSubtask { + return e.cancelSubtaskWithRetry(e.ctx, task.ID, ErrCancelSubtask) + } else if e.IsRetryableError(err) { + e.logger.Warn("meet retryable error", zap.Error(err)) + } else if common.IsContextCanceledError(err) { + e.logger.Info("meet context canceled for gracefully shutdown", zap.Error(err)) + } else { + return e.failSubtaskWithRetry(e.ctx, task.ID, err) + } + return nil +} diff --git a/pkg/disttask/importinto/binding__failpoint_binding__.go b/pkg/disttask/importinto/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..f923a525842a5 --- /dev/null +++ b/pkg/disttask/importinto/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package importinto + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/disttask/importinto/planner.go b/pkg/disttask/importinto/planner.go index 8fda8686facbb..cc4bec874a880 100644 --- a/pkg/disttask/importinto/planner.go +++ b/pkg/disttask/importinto/planner.go @@ -292,12 +292,12 @@ func generateImportSpecs(pCtx planner.PlanCtx, p *LogicalPlan) ([]planner.Pipeli } func skipMergeSort(kvGroup string, stats []external.MultipleFilesStat) bool { - failpoint.Inject("forceMergeSort", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("forceMergeSort")); _err_ == nil { in := val.(string) if in == kvGroup || in == "*" { - failpoint.Return(false) + return false } - }) + } return external.GetMaxOverlappingTotal(stats) <= external.MergeSortOverlapThreshold } @@ -349,8 +349,8 @@ func generateWriteIngestSpecs(planCtx planner.PlanCtx, p *LogicalPlan) ([]planne if err != nil { return nil, err } - failpoint.Inject("mockWriteIngestSpecs", func() { - failpoint.Return([]planner.PipelineSpec{ + if _, _err_ := failpoint.Eval(_curpkg_("mockWriteIngestSpecs")); _err_ == nil { + return []planner.PipelineSpec{ &WriteIngestSpec{ WriteIngestStepMeta: &WriteIngestStepMeta{ KVGroup: dataKVGroup, @@ -361,8 +361,8 @@ func generateWriteIngestSpecs(planCtx planner.PlanCtx, p *LogicalPlan) ([]planne KVGroup: "1", }, }, - }, nil) - }) + }, nil + } pTS, lTS, err := planCtx.Store.GetPDClient().GetTS(ctx) if err != nil { diff --git a/pkg/disttask/importinto/planner.go__failpoint_stash__ b/pkg/disttask/importinto/planner.go__failpoint_stash__ new file mode 100644 index 0000000000000..8fda8686facbb --- /dev/null +++ b/pkg/disttask/importinto/planner.go__failpoint_stash__ @@ -0,0 +1,528 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package importinto + +import ( + "context" + "encoding/hex" + "encoding/json" + "math" + "strconv" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/pkg/disttask/framework/planner" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/executor/importer" + tidbkv "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/lightning/backend/external" + "github.com/pingcap/tidb/pkg/lightning/backend/kv" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/config" + verify "github.com/pingcap/tidb/pkg/lightning/verification" + "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/tikv/client-go/v2/oracle" + "go.uber.org/zap" +) + +var ( + _ planner.LogicalPlan = &LogicalPlan{} + _ planner.PipelineSpec = &ImportSpec{} + _ planner.PipelineSpec = &PostProcessSpec{} +) + +// LogicalPlan represents a logical plan for import into. +type LogicalPlan struct { + JobID int64 + Plan importer.Plan + Stmt string + EligibleInstances []*infosync.ServerInfo + ChunkMap map[int32][]Chunk +} + +// ToTaskMeta converts the logical plan to task meta. +func (p *LogicalPlan) ToTaskMeta() ([]byte, error) { + taskMeta := TaskMeta{ + JobID: p.JobID, + Plan: p.Plan, + Stmt: p.Stmt, + EligibleInstances: p.EligibleInstances, + ChunkMap: p.ChunkMap, + } + return json.Marshal(taskMeta) +} + +// FromTaskMeta converts the task meta to logical plan. +func (p *LogicalPlan) FromTaskMeta(bs []byte) error { + var taskMeta TaskMeta + if err := json.Unmarshal(bs, &taskMeta); err != nil { + return errors.Trace(err) + } + p.JobID = taskMeta.JobID + p.Plan = taskMeta.Plan + p.Stmt = taskMeta.Stmt + p.EligibleInstances = taskMeta.EligibleInstances + p.ChunkMap = taskMeta.ChunkMap + return nil +} + +// ToPhysicalPlan converts the logical plan to physical plan. +func (p *LogicalPlan) ToPhysicalPlan(planCtx planner.PlanCtx) (*planner.PhysicalPlan, error) { + physicalPlan := &planner.PhysicalPlan{} + inputLinks := make([]planner.LinkSpec, 0) + addSpecs := func(specs []planner.PipelineSpec) { + for i, spec := range specs { + physicalPlan.AddProcessor(planner.ProcessorSpec{ + ID: i, + Pipeline: spec, + Output: planner.OutputSpec{ + Links: []planner.LinkSpec{ + { + ProcessorID: len(specs), + }, + }, + }, + Step: planCtx.NextTaskStep, + }) + inputLinks = append(inputLinks, planner.LinkSpec{ + ProcessorID: i, + }) + } + } + // physical plan only needs to be generated once. + // However, our current implementation requires generating it for each step. + // we only generate needed plans for the next step. 
+ switch planCtx.NextTaskStep { + case proto.ImportStepImport, proto.ImportStepEncodeAndSort: + specs, err := generateImportSpecs(planCtx, p) + if err != nil { + return nil, err + } + + addSpecs(specs) + case proto.ImportStepMergeSort: + specs, err := generateMergeSortSpecs(planCtx, p) + if err != nil { + return nil, err + } + + addSpecs(specs) + case proto.ImportStepWriteAndIngest: + specs, err := generateWriteIngestSpecs(planCtx, p) + if err != nil { + return nil, err + } + + addSpecs(specs) + case proto.ImportStepPostProcess: + physicalPlan.AddProcessor(planner.ProcessorSpec{ + ID: len(inputLinks), + Input: planner.InputSpec{ + ColumnTypes: []byte{ + // Checksum_crc64_xor, Total_kvs, Total_bytes, ReadRowCnt, LoadedRowCnt, ColSizeMap + mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeJSON, + }, + Links: inputLinks, + }, + Pipeline: &PostProcessSpec{ + Schema: p.Plan.DBName, + Table: p.Plan.TableInfo.Name.L, + }, + Step: planCtx.NextTaskStep, + }) + } + + return physicalPlan, nil +} + +// ImportSpec is the specification of an import pipeline. +type ImportSpec struct { + ID int32 + Plan importer.Plan + Chunks []Chunk +} + +// ToSubtaskMeta converts the import spec to subtask meta. +func (s *ImportSpec) ToSubtaskMeta(planner.PlanCtx) ([]byte, error) { + importStepMeta := ImportStepMeta{ + ID: s.ID, + Chunks: s.Chunks, + } + return json.Marshal(importStepMeta) +} + +// WriteIngestSpec is the specification of a write-ingest pipeline. +type WriteIngestSpec struct { + *WriteIngestStepMeta +} + +// ToSubtaskMeta converts the write-ingest spec to subtask meta. +func (s *WriteIngestSpec) ToSubtaskMeta(planner.PlanCtx) ([]byte, error) { + return json.Marshal(s.WriteIngestStepMeta) +} + +// MergeSortSpec is the specification of a merge-sort pipeline. +type MergeSortSpec struct { + *MergeSortStepMeta +} + +// ToSubtaskMeta converts the merge-sort spec to subtask meta. +func (s *MergeSortSpec) ToSubtaskMeta(planner.PlanCtx) ([]byte, error) { + return json.Marshal(s.MergeSortStepMeta) +} + +// PostProcessSpec is the specification of a post process pipeline. +type PostProcessSpec struct { + // for checksum request + Schema string + Table string +} + +// ToSubtaskMeta converts the post process spec to subtask meta. 
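// Each pipeline spec above serializes itself to subtask meta with json.Marshal
// (ToSubtaskMeta), and the consuming side decodes the bytes back into the
// corresponding *StepMeta struct, mirroring LogicalPlan.ToTaskMeta/FromTaskMeta.
// A minimal sketch of that round trip with a toy step-meta struct; the field
// names here are illustrative, not the real ImportStepMeta layout.
package main

import (
	"encoding/json"
	"fmt"
)

type demoStepMeta struct {
	ID     int32    `json:"id"`
	Chunks []string `json:"chunks"`
}

func main() {
	// planner side: spec -> bytes stored as the subtask meta
	meta, err := json.Marshal(demoStepMeta{ID: 1, Chunks: []string{"a.csv:0", "a.csv:1024"}})
	if err != nil {
		panic(err)
	}

	// executor side: bytes -> step meta
	var decoded demoStepMeta
	if err := json.Unmarshal(meta, &decoded); err != nil {
		panic(err)
	}
	fmt.Printf("%s -> %+v\n", meta, decoded)
}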
+func (*PostProcessSpec) ToSubtaskMeta(planCtx planner.PlanCtx) ([]byte, error) { + encodeStep := getStepOfEncode(planCtx.GlobalSort) + subtaskMetas := make([]*ImportStepMeta, 0, len(planCtx.PreviousSubtaskMetas)) + for _, bs := range planCtx.PreviousSubtaskMetas[encodeStep] { + var subtaskMeta ImportStepMeta + if err := json.Unmarshal(bs, &subtaskMeta); err != nil { + return nil, errors.Trace(err) + } + subtaskMetas = append(subtaskMetas, &subtaskMeta) + } + localChecksum := verify.NewKVGroupChecksumForAdd() + maxIDs := make(map[autoid.AllocatorType]int64, 3) + for _, subtaskMeta := range subtaskMetas { + for id, c := range subtaskMeta.Checksum { + localChecksum.AddRawGroup(id, c.Size, c.KVs, c.Sum) + } + + for key, val := range subtaskMeta.MaxIDs { + if maxIDs[key] < val { + maxIDs[key] = val + } + } + } + c := localChecksum.GetInnerChecksums() + postProcessStepMeta := &PostProcessStepMeta{ + Checksum: make(map[int64]Checksum, len(c)), + MaxIDs: maxIDs, + } + for id, cksum := range c { + postProcessStepMeta.Checksum[id] = Checksum{ + Size: cksum.SumSize(), + KVs: cksum.SumKVS(), + Sum: cksum.Sum(), + } + } + return json.Marshal(postProcessStepMeta) +} + +func buildControllerForPlan(p *LogicalPlan) (*importer.LoadDataController, error) { + return buildController(&p.Plan, p.Stmt) +} + +func buildController(plan *importer.Plan, stmt string) (*importer.LoadDataController, error) { + idAlloc := kv.NewPanickingAllocators(plan.TableInfo.SepAutoInc(), 0) + tbl, err := tables.TableFromMeta(idAlloc, plan.TableInfo) + if err != nil { + return nil, err + } + + astArgs, err := importer.ASTArgsFromStmt(stmt) + if err != nil { + return nil, err + } + controller, err := importer.NewLoadDataController(plan, tbl, astArgs) + if err != nil { + return nil, err + } + return controller, nil +} + +func generateImportSpecs(pCtx planner.PlanCtx, p *LogicalPlan) ([]planner.PipelineSpec, error) { + var chunkMap map[int32][]Chunk + if len(p.ChunkMap) > 0 { + chunkMap = p.ChunkMap + } else { + controller, err2 := buildControllerForPlan(p) + if err2 != nil { + return nil, err2 + } + if err2 = controller.InitDataFiles(pCtx.Ctx); err2 != nil { + return nil, err2 + } + + controller.SetExecuteNodeCnt(pCtx.ExecuteNodesCnt) + engineCheckpoints, err2 := controller.PopulateChunks(pCtx.Ctx) + if err2 != nil { + return nil, err2 + } + chunkMap = toChunkMap(engineCheckpoints) + } + importSpecs := make([]planner.PipelineSpec, 0, len(chunkMap)) + for id := range chunkMap { + if id == common.IndexEngineID { + continue + } + importSpec := &ImportSpec{ + ID: id, + Plan: p.Plan, + Chunks: chunkMap[id], + } + importSpecs = append(importSpecs, importSpec) + } + return importSpecs, nil +} + +func skipMergeSort(kvGroup string, stats []external.MultipleFilesStat) bool { + failpoint.Inject("forceMergeSort", func(val failpoint.Value) { + in := val.(string) + if in == kvGroup || in == "*" { + failpoint.Return(false) + } + }) + return external.GetMaxOverlappingTotal(stats) <= external.MergeSortOverlapThreshold +} + +func generateMergeSortSpecs(planCtx planner.PlanCtx, p *LogicalPlan) ([]planner.PipelineSpec, error) { + step := external.MergeSortFileCountStep + result := make([]planner.PipelineSpec, 0, 16) + kvMetas, err := getSortedKVMetasOfEncodeStep(planCtx.PreviousSubtaskMetas[proto.ImportStepEncodeAndSort]) + if err != nil { + return nil, err + } + for kvGroup, kvMeta := range kvMetas { + if !p.Plan.ForceMergeStep && skipMergeSort(kvGroup, kvMeta.MultipleFilesStats) { + logutil.Logger(planCtx.Ctx).Info("skip merge sort for kv group", + 
zap.Int64("task-id", planCtx.TaskID), + zap.String("kv-group", kvGroup)) + continue + } + dataFiles := kvMeta.GetDataFiles() + length := len(dataFiles) + for start := 0; start < length; start += step { + end := start + step + if end > length { + end = length + } + result = append(result, &MergeSortSpec{ + MergeSortStepMeta: &MergeSortStepMeta{ + KVGroup: kvGroup, + DataFiles: dataFiles[start:end], + }, + }) + } + } + return result, nil +} + +func generateWriteIngestSpecs(planCtx planner.PlanCtx, p *LogicalPlan) ([]planner.PipelineSpec, error) { + ctx := planCtx.Ctx + controller, err2 := buildControllerForPlan(p) + if err2 != nil { + return nil, err2 + } + if err2 = controller.InitDataStore(ctx); err2 != nil { + return nil, err2 + } + // kvMetas contains data kv meta and all index kv metas. + // each kvMeta will be split into multiple range group individually, + // i.e. data and index kv will NOT be in the same subtask. + kvMetas, err := getSortedKVMetasForIngest(planCtx, p) + if err != nil { + return nil, err + } + failpoint.Inject("mockWriteIngestSpecs", func() { + failpoint.Return([]planner.PipelineSpec{ + &WriteIngestSpec{ + WriteIngestStepMeta: &WriteIngestStepMeta{ + KVGroup: dataKVGroup, + }, + }, + &WriteIngestSpec{ + WriteIngestStepMeta: &WriteIngestStepMeta{ + KVGroup: "1", + }, + }, + }, nil) + }) + + pTS, lTS, err := planCtx.Store.GetPDClient().GetTS(ctx) + if err != nil { + return nil, err + } + ts := oracle.ComposeTS(pTS, lTS) + + specs := make([]planner.PipelineSpec, 0, 16) + for kvGroup, kvMeta := range kvMetas { + splitter, err1 := getRangeSplitter(ctx, controller.GlobalSortStore, kvMeta) + if err1 != nil { + return nil, err1 + } + + err1 = func() error { + defer func() { + err2 := splitter.Close() + if err2 != nil { + logutil.Logger(ctx).Warn("close range splitter failed", zap.Error(err2)) + } + }() + startKey := tidbkv.Key(kvMeta.StartKey) + var endKey tidbkv.Key + for { + endKeyOfGroup, dataFiles, statFiles, rangeSplitKeys, err2 := splitter.SplitOneRangesGroup() + if err2 != nil { + return err2 + } + if len(endKeyOfGroup) == 0 { + endKey = kvMeta.EndKey + } else { + endKey = tidbkv.Key(endKeyOfGroup).Clone() + } + logutil.Logger(ctx).Info("kv range as subtask", + zap.String("startKey", hex.EncodeToString(startKey)), + zap.String("endKey", hex.EncodeToString(endKey)), + zap.Int("dataFiles", len(dataFiles))) + if startKey.Cmp(endKey) >= 0 { + return errors.Errorf("invalid kv range, startKey: %s, endKey: %s", + hex.EncodeToString(startKey), hex.EncodeToString(endKey)) + } + // each subtask will write and ingest one range group + m := &WriteIngestStepMeta{ + KVGroup: kvGroup, + SortedKVMeta: external.SortedKVMeta{ + StartKey: startKey, + EndKey: endKey, + // this is actually an estimate, we don't know the exact size of the data + TotalKVSize: uint64(config.DefaultBatchSize), + }, + DataFiles: dataFiles, + StatFiles: statFiles, + RangeSplitKeys: rangeSplitKeys, + RangeSplitSize: splitter.GetRangeSplitSize(), + TS: ts, + } + specs = append(specs, &WriteIngestSpec{m}) + + startKey = endKey + if len(endKeyOfGroup) == 0 { + break + } + } + return nil + }() + if err1 != nil { + return nil, err1 + } + } + return specs, nil +} + +func getSortedKVMetasOfEncodeStep(subTaskMetas [][]byte) (map[string]*external.SortedKVMeta, error) { + dataKVMeta := &external.SortedKVMeta{} + indexKVMetas := make(map[int64]*external.SortedKVMeta) + for _, subTaskMeta := range subTaskMetas { + var stepMeta ImportStepMeta + err := json.Unmarshal(subTaskMeta, &stepMeta) + if err != nil { + return nil, 
errors.Trace(err) + } + dataKVMeta.Merge(stepMeta.SortedDataMeta) + for indexID, sortedIndexMeta := range stepMeta.SortedIndexMetas { + if item, ok := indexKVMetas[indexID]; !ok { + indexKVMetas[indexID] = sortedIndexMeta + } else { + item.Merge(sortedIndexMeta) + } + } + } + res := make(map[string]*external.SortedKVMeta, 1+len(indexKVMetas)) + res[dataKVGroup] = dataKVMeta + for indexID, item := range indexKVMetas { + res[strconv.Itoa(int(indexID))] = item + } + return res, nil +} + +func getSortedKVMetasOfMergeStep(subTaskMetas [][]byte) (map[string]*external.SortedKVMeta, error) { + result := make(map[string]*external.SortedKVMeta, len(subTaskMetas)) + for _, subTaskMeta := range subTaskMetas { + var stepMeta MergeSortStepMeta + err := json.Unmarshal(subTaskMeta, &stepMeta) + if err != nil { + return nil, errors.Trace(err) + } + meta, ok := result[stepMeta.KVGroup] + if !ok { + result[stepMeta.KVGroup] = &stepMeta.SortedKVMeta + continue + } + meta.Merge(&stepMeta.SortedKVMeta) + } + return result, nil +} + +func getSortedKVMetasForIngest(planCtx planner.PlanCtx, p *LogicalPlan) (map[string]*external.SortedKVMeta, error) { + kvMetasOfMergeSort, err := getSortedKVMetasOfMergeStep(planCtx.PreviousSubtaskMetas[proto.ImportStepMergeSort]) + if err != nil { + return nil, err + } + kvMetasOfEncodeStep, err := getSortedKVMetasOfEncodeStep(planCtx.PreviousSubtaskMetas[proto.ImportStepEncodeAndSort]) + if err != nil { + return nil, err + } + for kvGroup, kvMeta := range kvMetasOfEncodeStep { + // only part of kv files are merge sorted. we need to merge kv metas that + // are not merged into the kvMetasOfMergeSort. + if !p.Plan.ForceMergeStep && skipMergeSort(kvGroup, kvMeta.MultipleFilesStats) { + if _, ok := kvMetasOfMergeSort[kvGroup]; ok { + // this should not happen, because we only generate merge sort + // subtasks for those kv groups with MaxOverlappingTotal > MergeSortOverlapThreshold + logutil.Logger(planCtx.Ctx).Error("kv group of encode step conflict with merge sort step") + return nil, errors.New("kv group of encode step conflict with merge sort step") + } + kvMetasOfMergeSort[kvGroup] = kvMeta + } + } + return kvMetasOfMergeSort, nil +} + +func getRangeSplitter(ctx context.Context, store storage.ExternalStorage, kvMeta *external.SortedKVMeta) ( + *external.RangeSplitter, error) { + regionSplitSize, regionSplitKeys, err := importer.GetRegionSplitSizeKeys(ctx) + if err != nil { + logutil.Logger(ctx).Warn("fail to get region split size and keys", zap.Error(err)) + } + regionSplitSize = max(regionSplitSize, int64(config.SplitRegionSize)) + regionSplitKeys = max(regionSplitKeys, int64(config.SplitRegionKeys)) + logutil.Logger(ctx).Info("split kv range with split size and keys", + zap.Int64("region-split-size", regionSplitSize), + zap.Int64("region-split-keys", regionSplitKeys)) + + return external.NewRangeSplitter( + ctx, + kvMeta.MultipleFilesStats, + store, + int64(config.DefaultBatchSize), + int64(math.MaxInt64), + regionSplitSize, + regionSplitKeys, + ) +} diff --git a/pkg/disttask/importinto/scheduler.go b/pkg/disttask/importinto/scheduler.go index 1dd1c0021ae7f..52a98941d99df 100644 --- a/pkg/disttask/importinto/scheduler.go +++ b/pkg/disttask/importinto/scheduler.go @@ -249,9 +249,9 @@ func (sch *ImportSchedulerExt) OnNextSubtasksBatch( } previousSubtaskMetas[proto.ImportStepEncodeAndSort] = sortAndEncodeMeta case proto.ImportStepWriteAndIngest: - failpoint.Inject("failWhenDispatchWriteIngestSubtask", func() { - failpoint.Return(nil, errors.New("injected error")) - }) + if _, _err_ 
:= failpoint.Eval(_curpkg_("failWhenDispatchWriteIngestSubtask")); _err_ == nil { + return nil, errors.New("injected error") + } // merge sort might be skipped for some kv groups, so we need to get all // subtask metas of ImportStepEncodeAndSort step too. encodeAndSortMetas, err := taskHandle.GetPreviousSubtaskMetas(task.ID, proto.ImportStepEncodeAndSort) @@ -269,15 +269,15 @@ func (sch *ImportSchedulerExt) OnNextSubtasksBatch( } case proto.ImportStepPostProcess: sch.switchTiKV2NormalMode(ctx, task, logger) - failpoint.Inject("clearLastSwitchTime", func() { + if _, _err_ := failpoint.Eval(_curpkg_("clearLastSwitchTime")); _err_ == nil { sch.lastSwitchTime.Store(time.Time{}) - }) + } if err = job2Step(ctx, logger, taskMeta, importer.JobStepValidating); err != nil { return nil, err } - failpoint.Inject("failWhenDispatchPostProcessSubtask", func() { - failpoint.Return(nil, errors.New("injected error after ImportStepImport")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("failWhenDispatchPostProcessSubtask")); _err_ == nil { + return nil, errors.New("injected error after ImportStepImport") + } // we need get metas where checksum is stored. if err := updateResult(taskHandle, task, taskMeta, sch.GlobalSort); err != nil { return nil, err @@ -614,7 +614,7 @@ func getLoadedRowCountOnGlobalSort(handle storage.TaskHandle, task *proto.Task) } func startJob(ctx context.Context, logger *zap.Logger, taskHandle storage.TaskHandle, taskMeta *TaskMeta, jobStep string) error { - failpoint.InjectCall("syncBeforeJobStarted", taskMeta.JobID) + failpoint.Call(_curpkg_("syncBeforeJobStarted"), taskMeta.JobID) // retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes // we consider all errors as retryable errors, except context done. // the errors include errors happened when communicate with PD and TiKV. @@ -628,7 +628,7 @@ func startJob(ctx context.Context, logger *zap.Logger, taskHandle storage.TaskHa }) }, ) - failpoint.InjectCall("syncAfterJobStarted") + failpoint.Call(_curpkg_("syncAfterJobStarted")) return err } diff --git a/pkg/disttask/importinto/scheduler.go__failpoint_stash__ b/pkg/disttask/importinto/scheduler.go__failpoint_stash__ new file mode 100644 index 0000000000000..1dd1c0021ae7f --- /dev/null +++ b/pkg/disttask/importinto/scheduler.go__failpoint_stash__ @@ -0,0 +1,719 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package importinto + +import ( + "context" + "encoding/json" + "strconv" + "strings" + "sync" + "time" + + dmysql "github.com/go-sql-driver/mysql" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/utils" + tidb "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/disttask/framework/handle" + "github.com/pingcap/tidb/pkg/disttask/framework/planner" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" + "github.com/pingcap/tidb/pkg/disttask/framework/storage" + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/executor/importer" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/lightning/checkpoints" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/lightning/metric" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/backoff" + disttaskutil "github.com/pingcap/tidb/pkg/util/disttask" + "github.com/pingcap/tidb/pkg/util/etcd" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.uber.org/atomic" + "go.uber.org/zap" +) + +const ( + registerTaskTTL = 10 * time.Minute + refreshTaskTTLInterval = 3 * time.Minute + registerTimeout = 5 * time.Second +) + +// NewTaskRegisterWithTTL is the ctor for TaskRegister. +// It is exported for testing. +var NewTaskRegisterWithTTL = utils.NewTaskRegisterWithTTL + +type taskInfo struct { + taskID int64 + + // operation on taskInfo is run inside detect-task goroutine, so no need to synchronize. + lastRegisterTime time.Time + + // initialized lazily in register() + etcdClient *etcd.Client + taskRegister utils.TaskRegister +} + +func (t *taskInfo) register(ctx context.Context) { + if time.Since(t.lastRegisterTime) < refreshTaskTTLInterval { + return + } + + if time.Since(t.lastRegisterTime) < refreshTaskTTLInterval { + return + } + logger := logutil.BgLogger().With(zap.Int64("task-id", t.taskID)) + if t.taskRegister == nil { + client, err := importer.GetEtcdClient() + if err != nil { + logger.Warn("get etcd client failed", zap.Error(err)) + return + } + t.etcdClient = client + t.taskRegister = NewTaskRegisterWithTTL(client.GetClient(), registerTaskTTL, + utils.RegisterImportInto, strconv.FormatInt(t.taskID, 10)) + } + timeoutCtx, cancel := context.WithTimeout(ctx, registerTimeout) + defer cancel() + if err := t.taskRegister.RegisterTaskOnce(timeoutCtx); err != nil { + logger.Warn("register task failed", zap.Error(err)) + } else { + logger.Info("register task to pd or refresh lease success") + } + // we set it even if register failed, TTL is 10min, refresh interval is 3min, + // we can try 2 times before the lease is expired. + t.lastRegisterTime = time.Now() +} + +func (t *taskInfo) close(ctx context.Context) { + logger := logutil.BgLogger().With(zap.Int64("task-id", t.taskID)) + if t.taskRegister != nil { + timeoutCtx, cancel := context.WithTimeout(ctx, registerTimeout) + defer cancel() + if err := t.taskRegister.Close(timeoutCtx); err != nil { + logger.Warn("unregister task failed", zap.Error(err)) + } else { + logger.Info("unregister task success") + } + t.taskRegister = nil + } + if t.etcdClient != nil { + if err := t.etcdClient.Close(); err != nil { + logger.Warn("close etcd client failed", zap.Error(err)) + } + t.etcdClient = nil + } +} + +// ImportSchedulerExt is an extension of ImportScheduler, exported for test. 
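// taskInfo.register above throttles etcd lease refreshes: it re-registers only
// when more than refreshTaskTTLInterval (3m) has elapsed, and advances the
// timestamp even on failure, so with a 10m TTL there is room for roughly two
// failed refreshes before the lease expires. A minimal sketch of that
// throttled-refresh bookkeeping; the refresh function is a stub and the
// handling is simplified compared to the real TaskRegister.
package main

import (
	"fmt"
	"time"
)

type refresher struct {
	interval    time.Duration
	lastAttempt time.Time
}

// maybeRefresh runs fn at most once per interval. lastAttempt is advanced even
// when fn fails, matching the "we set it even if register failed" comment above.
func (r *refresher) maybeRefresh(now time.Time, fn func() error) bool {
	if now.Sub(r.lastAttempt) < r.interval {
		return false
	}
	if err := fn(); err != nil {
		fmt.Println("refresh failed, will retry after the interval:", err)
	}
	r.lastAttempt = now
	return true
}

func main() {
	r := &refresher{interval: 3 * time.Minute}
	start := time.Now()
	ok := func() error { return nil }
	fmt.Println(r.maybeRefresh(start, ok))                    // true: first registration
	fmt.Println(r.maybeRefresh(start.Add(time.Minute), ok))   // false: refreshed too recently
	fmt.Println(r.maybeRefresh(start.Add(4*time.Minute), ok)) // true: interval elapsed
}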
+type ImportSchedulerExt struct { + GlobalSort bool + mu sync.RWMutex + // NOTE: there's no need to sync for below 2 fields actually, since we add a restriction that only one + // task can be running at a time. but we might support task queuing in the future, leave it for now. + // the last time we switch TiKV into IMPORT mode, this is a global operation, do it for one task makes + // no difference to do it for all tasks. So we do not need to record the switch time for each task. + lastSwitchTime atomic.Time + // taskInfoMap is a map from taskID to taskInfo + taskInfoMap sync.Map + + // currTaskID is the taskID of the current running task. + // It may be changed when we switch to a new task or switch to a new owner. + currTaskID atomic.Int64 + disableTiKVImportMode atomic.Bool + + storeWithPD kv.StorageWithPD +} + +var _ scheduler.Extension = (*ImportSchedulerExt)(nil) + +// OnTick implements scheduler.Extension interface. +func (sch *ImportSchedulerExt) OnTick(ctx context.Context, task *proto.Task) { + // only switch TiKV mode or register task when task is running + if task.State != proto.TaskStateRunning { + return + } + sch.switchTiKVMode(ctx, task) + sch.registerTask(ctx, task) +} + +func (*ImportSchedulerExt) isImporting2TiKV(task *proto.Task) bool { + return task.Step == proto.ImportStepImport || task.Step == proto.ImportStepWriteAndIngest +} + +func (sch *ImportSchedulerExt) switchTiKVMode(ctx context.Context, task *proto.Task) { + sch.updateCurrentTask(task) + // only import step need to switch to IMPORT mode, + // If TiKV is in IMPORT mode during checksum, coprocessor will time out. + if sch.disableTiKVImportMode.Load() || !sch.isImporting2TiKV(task) { + return + } + + if time.Since(sch.lastSwitchTime.Load()) < config.DefaultSwitchTiKVModeInterval { + return + } + + sch.mu.Lock() + defer sch.mu.Unlock() + if time.Since(sch.lastSwitchTime.Load()) < config.DefaultSwitchTiKVModeInterval { + return + } + + logger := logutil.BgLogger().With(zap.Int64("task-id", task.ID)) + // TODO: use the TLS object from TiDB server + tidbCfg := tidb.GetGlobalConfig() + tls, err := util.NewTLSConfig( + util.WithCAPath(tidbCfg.Security.ClusterSSLCA), + util.WithCertAndKeyPath(tidbCfg.Security.ClusterSSLCert, tidbCfg.Security.ClusterSSLKey), + ) + if err != nil { + logger.Warn("get tikv mode switcher failed", zap.Error(err)) + return + } + pdHTTPCli := sch.storeWithPD.GetPDHTTPClient() + switcher := importer.NewTiKVModeSwitcher(tls, pdHTTPCli, logger) + + switcher.ToImportMode(ctx) + sch.lastSwitchTime.Store(time.Now()) +} + +func (sch *ImportSchedulerExt) registerTask(ctx context.Context, task *proto.Task) { + val, _ := sch.taskInfoMap.LoadOrStore(task.ID, &taskInfo{taskID: task.ID}) + info := val.(*taskInfo) + info.register(ctx) +} + +func (sch *ImportSchedulerExt) unregisterTask(ctx context.Context, task *proto.Task) { + if val, loaded := sch.taskInfoMap.LoadAndDelete(task.ID); loaded { + info := val.(*taskInfo) + info.close(ctx) + } +} + +// OnNextSubtasksBatch generate batch of next stage's plan. 
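// switchTiKVMode above rate-limits a global mode switch with a check-lock-check
// pattern: read lastSwitchTime atomically, return early if the switch is recent,
// then take the mutex and re-check before switching, so concurrent OnTick calls
// do not flip TiKV to import mode twice. A minimal sketch of the same pattern;
// the patch stores lastSwitchTime in go.uber.org/atomic's atomic.Time, while
// this sketch approximates it with the standard library's sync/atomic and a
// UnixNano timestamp, and the 1s interval is an illustrative value.
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

type switcher struct {
	mu           sync.Mutex
	lastSwitchNS atomic.Int64 // UnixNano of the last switch; 0 means "never"
	interval     time.Duration
	switches     int
}

func (s *switcher) maybeSwitch() {
	if time.Since(time.Unix(0, s.lastSwitchNS.Load())) < s.interval {
		return // lock-free fast path
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	if time.Since(time.Unix(0, s.lastSwitchNS.Load())) < s.interval {
		return // another goroutine switched while we waited for the lock
	}
	s.switches++ // the real code calls the TiKV mode switcher here
	s.lastSwitchNS.Store(time.Now().UnixNano())
}

func main() {
	s := &switcher{interval: time.Second}
	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() { defer wg.Done(); s.maybeSwitch() }()
	}
	wg.Wait()
	fmt.Println("switch count:", s.switches) // 1
}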
+func (sch *ImportSchedulerExt) OnNextSubtasksBatch( + ctx context.Context, + taskHandle storage.TaskHandle, + task *proto.Task, + execIDs []string, + nextStep proto.Step, +) ( + resSubtaskMeta [][]byte, err error) { + logger := logutil.BgLogger().With( + zap.Stringer("type", task.Type), + zap.Int64("task-id", task.ID), + zap.String("curr-step", proto.Step2Str(task.Type, task.Step)), + zap.String("next-step", proto.Step2Str(task.Type, nextStep)), + ) + taskMeta := &TaskMeta{} + err = json.Unmarshal(task.Meta, taskMeta) + if err != nil { + return nil, errors.Trace(err) + } + logger.Info("on next subtasks batch") + + previousSubtaskMetas := make(map[proto.Step][][]byte, 1) + switch nextStep { + case proto.ImportStepImport, proto.ImportStepEncodeAndSort: + if metrics, ok := metric.GetCommonMetric(ctx); ok { + metrics.BytesCounter.WithLabelValues(metric.StateTotalRestore).Add(float64(taskMeta.Plan.TotalFileSize)) + } + jobStep := importer.JobStepImporting + if sch.GlobalSort { + jobStep = importer.JobStepGlobalSorting + } + if err = startJob(ctx, logger, taskHandle, taskMeta, jobStep); err != nil { + return nil, err + } + case proto.ImportStepMergeSort: + sortAndEncodeMeta, err := taskHandle.GetPreviousSubtaskMetas(task.ID, proto.ImportStepEncodeAndSort) + if err != nil { + return nil, err + } + previousSubtaskMetas[proto.ImportStepEncodeAndSort] = sortAndEncodeMeta + case proto.ImportStepWriteAndIngest: + failpoint.Inject("failWhenDispatchWriteIngestSubtask", func() { + failpoint.Return(nil, errors.New("injected error")) + }) + // merge sort might be skipped for some kv groups, so we need to get all + // subtask metas of ImportStepEncodeAndSort step too. + encodeAndSortMetas, err := taskHandle.GetPreviousSubtaskMetas(task.ID, proto.ImportStepEncodeAndSort) + if err != nil { + return nil, err + } + mergeSortMetas, err := taskHandle.GetPreviousSubtaskMetas(task.ID, proto.ImportStepMergeSort) + if err != nil { + return nil, err + } + previousSubtaskMetas[proto.ImportStepEncodeAndSort] = encodeAndSortMetas + previousSubtaskMetas[proto.ImportStepMergeSort] = mergeSortMetas + if err = job2Step(ctx, logger, taskMeta, importer.JobStepImporting); err != nil { + return nil, err + } + case proto.ImportStepPostProcess: + sch.switchTiKV2NormalMode(ctx, task, logger) + failpoint.Inject("clearLastSwitchTime", func() { + sch.lastSwitchTime.Store(time.Time{}) + }) + if err = job2Step(ctx, logger, taskMeta, importer.JobStepValidating); err != nil { + return nil, err + } + failpoint.Inject("failWhenDispatchPostProcessSubtask", func() { + failpoint.Return(nil, errors.New("injected error after ImportStepImport")) + }) + // we need get metas where checksum is stored. 
+ if err := updateResult(taskHandle, task, taskMeta, sch.GlobalSort); err != nil { + return nil, err + } + step := getStepOfEncode(sch.GlobalSort) + metas, err := taskHandle.GetPreviousSubtaskMetas(task.ID, step) + if err != nil { + return nil, err + } + previousSubtaskMetas[step] = metas + logger.Info("move to post-process step ", zap.Any("result", taskMeta.Result)) + case proto.StepDone: + return nil, nil + default: + return nil, errors.Errorf("unknown step %d", task.Step) + } + + planCtx := planner.PlanCtx{ + Ctx: ctx, + TaskID: task.ID, + PreviousSubtaskMetas: previousSubtaskMetas, + GlobalSort: sch.GlobalSort, + NextTaskStep: nextStep, + ExecuteNodesCnt: len(execIDs), + Store: sch.storeWithPD, + } + logicalPlan := &LogicalPlan{} + if err := logicalPlan.FromTaskMeta(task.Meta); err != nil { + return nil, err + } + physicalPlan, err := logicalPlan.ToPhysicalPlan(planCtx) + if err != nil { + return nil, err + } + metaBytes, err := physicalPlan.ToSubtaskMetas(planCtx, nextStep) + if err != nil { + return nil, err + } + logger.Info("generate subtasks", zap.Int("subtask-count", len(metaBytes))) + return metaBytes, nil +} + +// OnDone implements scheduler.Extension interface. +func (sch *ImportSchedulerExt) OnDone(ctx context.Context, handle storage.TaskHandle, task *proto.Task) error { + logger := logutil.BgLogger().With( + zap.Stringer("type", task.Type), + zap.Int64("task-id", task.ID), + zap.String("step", proto.Step2Str(task.Type, task.Step)), + ) + logger.Info("task done", zap.Stringer("state", task.State), zap.Error(task.Error)) + taskMeta := &TaskMeta{} + err := json.Unmarshal(task.Meta, taskMeta) + if err != nil { + return errors.Trace(err) + } + if task.Error == nil { + return sch.finishJob(ctx, logger, handle, task, taskMeta) + } + if scheduler.IsCancelledErr(task.Error) { + return sch.cancelJob(ctx, handle, task, taskMeta, logger) + } + return sch.failJob(ctx, handle, task, taskMeta, logger, task.Error.Error()) +} + +// GetEligibleInstances implements scheduler.Extension interface. +func (*ImportSchedulerExt) GetEligibleInstances(_ context.Context, task *proto.Task) ([]string, error) { + taskMeta := &TaskMeta{} + err := json.Unmarshal(task.Meta, taskMeta) + if err != nil { + return nil, errors.Trace(err) + } + res := make([]string, 0, len(taskMeta.EligibleInstances)) + for _, instance := range taskMeta.EligibleInstances { + res = append(res, disttaskutil.GenerateExecID(instance)) + } + return res, nil +} + +// IsRetryableErr implements scheduler.Extension interface. +func (*ImportSchedulerExt) IsRetryableErr(error) bool { + // TODO: check whether the error is retryable. + return false +} + +// GetNextStep implements scheduler.Extension interface. 
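// GetNextStep, which follows, is a fixed step ladder with two paths: the
// local-sort path StepInit -> ImportStepImport -> ImportStepPostProcess ->
// StepDone, and the global-sort path StepInit -> ImportStepEncodeAndSort ->
// ImportStepMergeSort -> ImportStepWriteAndIngest -> ImportStepPostProcess ->
// StepDone. A minimal sketch that walks both ladders, with string names
// standing in for the proto step constants.
package main

import "fmt"

func nextStep(globalSort bool, step string) string {
	switch step {
	case "init":
		if globalSort {
			return "encode-and-sort"
		}
		return "import"
	case "encode-and-sort":
		return "merge-sort"
	case "merge-sort":
		return "write-and-ingest"
	case "import", "write-and-ingest":
		return "post-process"
	default: // post-process
		return "done"
	}
}

func main() {
	for _, globalSort := range []bool{false, true} {
		path := []string{"init"}
		for path[len(path)-1] != "done" {
			path = append(path, nextStep(globalSort, path[len(path)-1]))
		}
		fmt.Printf("globalSort=%v: %v\n", globalSort, path)
	}
}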
+func (sch *ImportSchedulerExt) GetNextStep(task *proto.TaskBase) proto.Step { + switch task.Step { + case proto.StepInit: + if sch.GlobalSort { + return proto.ImportStepEncodeAndSort + } + return proto.ImportStepImport + case proto.ImportStepEncodeAndSort: + return proto.ImportStepMergeSort + case proto.ImportStepMergeSort: + return proto.ImportStepWriteAndIngest + case proto.ImportStepImport, proto.ImportStepWriteAndIngest: + return proto.ImportStepPostProcess + default: + // current step must be ImportStepPostProcess + return proto.StepDone + } +} + +func (sch *ImportSchedulerExt) switchTiKV2NormalMode(ctx context.Context, task *proto.Task, logger *zap.Logger) { + sch.updateCurrentTask(task) + if sch.disableTiKVImportMode.Load() { + return + } + + sch.mu.Lock() + defer sch.mu.Unlock() + + // TODO: use the TLS object from TiDB server + tidbCfg := tidb.GetGlobalConfig() + tls, err := util.NewTLSConfig( + util.WithCAPath(tidbCfg.Security.ClusterSSLCA), + util.WithCertAndKeyPath(tidbCfg.Security.ClusterSSLCert, tidbCfg.Security.ClusterSSLKey), + ) + if err != nil { + logger.Warn("get tikv mode switcher failed", zap.Error(err)) + return + } + pdHTTPCli := sch.storeWithPD.GetPDHTTPClient() + switcher := importer.NewTiKVModeSwitcher(tls, pdHTTPCli, logger) + + switcher.ToNormalMode(ctx) + + // clear it, so next task can switch TiKV mode again. + sch.lastSwitchTime.Store(time.Time{}) +} + +func (sch *ImportSchedulerExt) updateCurrentTask(task *proto.Task) { + if sch.currTaskID.Swap(task.ID) != task.ID { + taskMeta := &TaskMeta{} + if err := json.Unmarshal(task.Meta, taskMeta); err == nil { + // for raftkv2, switch mode in local backend + sch.disableTiKVImportMode.Store(taskMeta.Plan.DisableTiKVImportMode || taskMeta.Plan.IsRaftKV2) + } + } +} + +type importScheduler struct { + *scheduler.BaseScheduler + storeWithPD kv.StorageWithPD +} + +// NewImportScheduler creates a new import scheduler. +func NewImportScheduler( + ctx context.Context, + task *proto.Task, + param scheduler.Param, + storeWithPD kv.StorageWithPD, +) scheduler.Scheduler { + metrics := metricsManager.getOrCreateMetrics(task.ID) + subCtx := metric.WithCommonMetric(ctx, metrics) + sch := importScheduler{ + BaseScheduler: scheduler.NewBaseScheduler(subCtx, task, param), + storeWithPD: storeWithPD, + } + return &sch +} + +func (sch *importScheduler) Init() (err error) { + defer func() { + if err != nil { + // if init failed, close is not called, so we need to unregister here. 
+ metricsManager.unregister(sch.GetTask().ID) + } + }() + taskMeta := &TaskMeta{} + if err = json.Unmarshal(sch.BaseScheduler.GetTask().Meta, taskMeta); err != nil { + return errors.Annotate(err, "unmarshal task meta failed") + } + + sch.BaseScheduler.Extension = &ImportSchedulerExt{ + GlobalSort: taskMeta.Plan.CloudStorageURI != "", + storeWithPD: sch.storeWithPD, + } + return sch.BaseScheduler.Init() +} + +func (sch *importScheduler) Close() { + metricsManager.unregister(sch.GetTask().ID) + sch.BaseScheduler.Close() +} + +// nolint:deadcode +func dropTableIndexes(ctx context.Context, handle storage.TaskHandle, taskMeta *TaskMeta, logger *zap.Logger) error { + tblInfo := taskMeta.Plan.TableInfo + + remainIndexes, dropIndexes := common.GetDropIndexInfos(tblInfo) + for _, idxInfo := range dropIndexes { + sqlStr := common.BuildDropIndexSQL(taskMeta.Plan.DBName, tblInfo.Name.L, idxInfo) + if err := executeSQL(ctx, handle, logger, sqlStr); err != nil { + if merr, ok := errors.Cause(err).(*dmysql.MySQLError); ok { + switch merr.Number { + case errno.ErrCantDropFieldOrKey, errno.ErrDropIndexNeededInForeignKey: + remainIndexes = append(remainIndexes, idxInfo) + logger.Warn("can't drop index, skip", zap.String("index", idxInfo.Name.O), zap.Error(err)) + continue + } + } + return err + } + } + if len(remainIndexes) < len(tblInfo.Indices) { + taskMeta.Plan.TableInfo = taskMeta.Plan.TableInfo.Clone() + taskMeta.Plan.TableInfo.Indices = remainIndexes + } + return nil +} + +// nolint:deadcode +func createTableIndexes(ctx context.Context, executor storage.SessionExecutor, taskMeta *TaskMeta, logger *zap.Logger) error { + tableName := common.UniqueTable(taskMeta.Plan.DBName, taskMeta.Plan.TableInfo.Name.L) + singleSQL, multiSQLs := common.BuildAddIndexSQL(tableName, taskMeta.Plan.TableInfo, taskMeta.Plan.DesiredTableInfo) + logger.Info("build add index sql", zap.String("singleSQL", singleSQL), zap.Strings("multiSQLs", multiSQLs)) + if len(multiSQLs) == 0 { + return nil + } + + err := executeSQL(ctx, executor, logger, singleSQL) + if err == nil { + return nil + } + if !common.IsDupKeyError(err) { + // TODO: refine err msg and error code according to spec. + return errors.Errorf("Failed to create index: %v, please execute the SQL manually, sql: %s", err, singleSQL) + } + if len(multiSQLs) == 1 { + return nil + } + logger.Warn("cannot add all indexes in one statement, try to add them one by one", zap.Strings("sqls", multiSQLs), zap.Error(err)) + + for i, ddl := range multiSQLs { + err := executeSQL(ctx, executor, logger, ddl) + if err != nil && !common.IsDupKeyError(err) { + // TODO: refine err msg and error code according to spec. + return errors.Errorf("Failed to create index: %v, please execute the SQLs manually, sqls: %s", err, strings.Join(multiSQLs[i:], ";")) + } + } + return nil +} + +// TODO: return the result of sql. +func executeSQL(ctx context.Context, executor storage.SessionExecutor, logger *zap.Logger, sql string, args ...any) (err error) { + logger.Info("execute sql", zap.String("sql", sql), zap.Any("args", args)) + return executor.WithNewSession(func(se sessionctx.Context) error { + _, err := se.GetSQLExecutor().ExecuteInternal(ctx, sql, args...) + return err + }) +} + +func updateMeta(task *proto.Task, taskMeta *TaskMeta) error { + bs, err := json.Marshal(taskMeta) + if err != nil { + return errors.Trace(err) + } + task.Meta = bs + + return nil +} + +// todo: converting back and forth, we should unify struct and remove this function later. 
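+// toChunkMap converts engine checkpoints into the per-engine Chunk slices used by the import step metas.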
+func toChunkMap(engineCheckpoints map[int32]*checkpoints.EngineCheckpoint) map[int32][]Chunk { + chunkMap := make(map[int32][]Chunk, len(engineCheckpoints)) + for id, ecp := range engineCheckpoints { + chunkMap[id] = make([]Chunk, 0, len(ecp.Chunks)) + for _, chunkCheckpoint := range ecp.Chunks { + chunkMap[id] = append(chunkMap[id], toChunk(*chunkCheckpoint)) + } + } + return chunkMap +} + +func getStepOfEncode(globalSort bool) proto.Step { + if globalSort { + return proto.ImportStepEncodeAndSort + } + return proto.ImportStepImport +} + +// we will update taskMeta in place and make task.Meta point to the new taskMeta. +func updateResult(handle storage.TaskHandle, task *proto.Task, taskMeta *TaskMeta, globalSort bool) error { + stepOfEncode := getStepOfEncode(globalSort) + metas, err := handle.GetPreviousSubtaskMetas(task.ID, stepOfEncode) + if err != nil { + return err + } + + subtaskMetas := make([]*ImportStepMeta, 0, len(metas)) + for _, bs := range metas { + var subtaskMeta ImportStepMeta + if err := json.Unmarshal(bs, &subtaskMeta); err != nil { + return errors.Trace(err) + } + subtaskMetas = append(subtaskMetas, &subtaskMeta) + } + columnSizeMap := make(map[int64]int64) + for _, subtaskMeta := range subtaskMetas { + taskMeta.Result.LoadedRowCnt += subtaskMeta.Result.LoadedRowCnt + for key, val := range subtaskMeta.Result.ColSizeMap { + columnSizeMap[key] += val + } + } + taskMeta.Result.ColSizeMap = columnSizeMap + + if globalSort { + taskMeta.Result.LoadedRowCnt, err = getLoadedRowCountOnGlobalSort(handle, task) + if err != nil { + return err + } + } + + return updateMeta(task, taskMeta) +} + +func getLoadedRowCountOnGlobalSort(handle storage.TaskHandle, task *proto.Task) (uint64, error) { + metas, err := handle.GetPreviousSubtaskMetas(task.ID, proto.ImportStepWriteAndIngest) + if err != nil { + return 0, err + } + + var loadedRowCount uint64 + for _, bs := range metas { + var subtaskMeta WriteIngestStepMeta + if err = json.Unmarshal(bs, &subtaskMeta); err != nil { + return 0, errors.Trace(err) + } + loadedRowCount += subtaskMeta.Result.LoadedRowCnt + } + return loadedRowCount, nil +} + +func startJob(ctx context.Context, logger *zap.Logger, taskHandle storage.TaskHandle, taskMeta *TaskMeta, jobStep string) error { + failpoint.InjectCall("syncBeforeJobStarted", taskMeta.JobID) + // retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes + // we consider all errors as retryable errors, except context done. + // the errors include errors happened when communicate with PD and TiKV. + // we didn't consider system corrupt cases like system table dropped/altered. + backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) + err := handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, logger, + func(ctx context.Context) (bool, error) { + return true, taskHandle.WithNewSession(func(se sessionctx.Context) error { + exec := se.GetSQLExecutor() + return importer.StartJob(ctx, exec, taskMeta.JobID, jobStep) + }) + }, + ) + failpoint.InjectCall("syncAfterJobStarted") + return err +} + +func job2Step(ctx context.Context, logger *zap.Logger, taskMeta *TaskMeta, step string) error { + taskManager, err := storage.GetTaskManager() + if err != nil { + return err + } + // todo: use scheduler.TaskHandle + // we might call this in taskExecutor later, there's no scheduler.Extension, so we use taskManager here. 
+ // retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes + backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) + return handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, logger, + func(ctx context.Context) (bool, error) { + return true, taskManager.WithNewSession(func(se sessionctx.Context) error { + exec := se.GetSQLExecutor() + return importer.Job2Step(ctx, exec, taskMeta.JobID, step) + }) + }, + ) +} + +func (sch *ImportSchedulerExt) finishJob(ctx context.Context, logger *zap.Logger, + taskHandle storage.TaskHandle, task *proto.Task, taskMeta *TaskMeta) error { + // we have already switch import-mode when switch to post-process step. + sch.unregisterTask(ctx, task) + summary := &importer.JobSummary{ImportedRows: taskMeta.Result.LoadedRowCnt} + // retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes + backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) + return handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, logger, + func(ctx context.Context) (bool, error) { + return true, taskHandle.WithNewSession(func(se sessionctx.Context) error { + if err := importer.FlushTableStats(ctx, se, taskMeta.Plan.TableInfo.ID, &importer.JobImportResult{ + Affected: taskMeta.Result.LoadedRowCnt, + ColSizeMap: taskMeta.Result.ColSizeMap, + }); err != nil { + logger.Warn("flush table stats failed", zap.Error(err)) + } + exec := se.GetSQLExecutor() + return importer.FinishJob(ctx, exec, taskMeta.JobID, summary) + }) + }, + ) +} + +func (sch *ImportSchedulerExt) failJob(ctx context.Context, taskHandle storage.TaskHandle, task *proto.Task, + taskMeta *TaskMeta, logger *zap.Logger, errorMsg string) error { + sch.switchTiKV2NormalMode(ctx, task, logger) + sch.unregisterTask(ctx, task) + // retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes + backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) + return handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, logger, + func(ctx context.Context) (bool, error) { + return true, taskHandle.WithNewSession(func(se sessionctx.Context) error { + exec := se.GetSQLExecutor() + return importer.FailJob(ctx, exec, taskMeta.JobID, errorMsg) + }) + }, + ) +} + +func (sch *ImportSchedulerExt) cancelJob(ctx context.Context, taskHandle storage.TaskHandle, task *proto.Task, + meta *TaskMeta, logger *zap.Logger) error { + sch.switchTiKV2NormalMode(ctx, task, logger) + sch.unregisterTask(ctx, task) + // retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes + backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) + return handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, logger, + func(ctx context.Context) (bool, error) { + return true, taskHandle.WithNewSession(func(se sessionctx.Context) error { + exec := se.GetSQLExecutor() + return importer.CancelJob(ctx, exec, meta.JobID) + }) + }, + ) +} + +func redactSensitiveInfo(task *proto.Task, taskMeta *TaskMeta) { + taskMeta.Stmt = "" + taskMeta.Plan.Path = ast.RedactURL(taskMeta.Plan.Path) + if taskMeta.Plan.CloudStorageURI != "" { + taskMeta.Plan.CloudStorageURI = ast.RedactURL(taskMeta.Plan.CloudStorageURI) + } + if err := updateMeta(task, taskMeta); err != nil { + // marshal failed, should not happen + logutil.BgLogger().Warn("failed to update task meta", zap.Error(err)) + } +} diff --git a/pkg/disttask/importinto/subtask_executor.go b/pkg/disttask/importinto/subtask_executor.go index 2ca7fc241cc0b..92300944ef8ac 100644 --- 
a/pkg/disttask/importinto/subtask_executor.go +++ b/pkg/disttask/importinto/subtask_executor.go @@ -55,13 +55,13 @@ func newImportMinimalTaskExecutor0(t *importStepMinimalTask) MiniTaskExecutor { func (e *importMinimalTaskExecutor) Run(ctx context.Context, dataWriter, indexWriter backend.EngineWriter) error { logger := logutil.BgLogger().With(zap.Stringer("type", proto.ImportInto), zap.Int64("table-id", e.mTtask.Plan.TableInfo.ID)) logger.Info("execute chunk") - failpoint.Inject("waitBeforeSortChunk", func() { + if _, _err_ := failpoint.Eval(_curpkg_("waitBeforeSortChunk")); _err_ == nil { time.Sleep(3 * time.Second) - }) - failpoint.Inject("errorWhenSortChunk", func() { - failpoint.Return(errors.New("occur an error when sort chunk")) - }) - failpoint.InjectCall("syncBeforeSortChunk") + } + if _, _err_ := failpoint.Eval(_curpkg_("errorWhenSortChunk")); _err_ == nil { + return errors.New("occur an error when sort chunk") + } + failpoint.Call(_curpkg_("syncBeforeSortChunk")) chunkCheckpoint := toChunkCheckpoint(e.mTtask.Chunk) sharedVars := e.mTtask.SharedVars checksum := verify.NewKVGroupChecksumWithKeyspace(sharedVars.TableImporter.GetKeySpace()) @@ -101,7 +101,7 @@ func (e *importMinimalTaskExecutor) Run(ctx context.Context, dataWriter, indexWr // postProcess does the post-processing for the task. func postProcess(ctx context.Context, store kv.Storage, taskMeta *TaskMeta, subtaskMeta *PostProcessStepMeta, logger *zap.Logger) (err error) { - failpoint.InjectCall("syncBeforePostProcess", taskMeta.JobID) + failpoint.Call(_curpkg_("syncBeforePostProcess"), taskMeta.JobID) callLog := log.BeginTask(logger, "post process") defer func() { diff --git a/pkg/disttask/importinto/subtask_executor.go__failpoint_stash__ b/pkg/disttask/importinto/subtask_executor.go__failpoint_stash__ new file mode 100644 index 0000000000000..2ca7fc241cc0b --- /dev/null +++ b/pkg/disttask/importinto/subtask_executor.go__failpoint_stash__ @@ -0,0 +1,142 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package importinto + +import ( + "context" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/disttask/framework/storage" + "github.com/pingcap/tidb/pkg/executor/importer" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/lightning/backend" + "github.com/pingcap/tidb/pkg/lightning/log" + verify "github.com/pingcap/tidb/pkg/lightning/verification" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/tikv/client-go/v2/util" + "go.uber.org/zap" +) + +// MiniTaskExecutor is the interface for a minimal task executor. +// exported for testing. +type MiniTaskExecutor interface { + Run(ctx context.Context, dataWriter, indexWriter backend.EngineWriter) error +} + +// importMinimalTaskExecutor is a minimal task executor for IMPORT INTO. 
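+// Each instance wraps a single importStepMinimalTask and processes exactly one data chunk.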
+type importMinimalTaskExecutor struct { + mTtask *importStepMinimalTask +} + +var newImportMinimalTaskExecutor = newImportMinimalTaskExecutor0 + +func newImportMinimalTaskExecutor0(t *importStepMinimalTask) MiniTaskExecutor { + return &importMinimalTaskExecutor{ + mTtask: t, + } +} + +func (e *importMinimalTaskExecutor) Run(ctx context.Context, dataWriter, indexWriter backend.EngineWriter) error { + logger := logutil.BgLogger().With(zap.Stringer("type", proto.ImportInto), zap.Int64("table-id", e.mTtask.Plan.TableInfo.ID)) + logger.Info("execute chunk") + failpoint.Inject("waitBeforeSortChunk", func() { + time.Sleep(3 * time.Second) + }) + failpoint.Inject("errorWhenSortChunk", func() { + failpoint.Return(errors.New("occur an error when sort chunk")) + }) + failpoint.InjectCall("syncBeforeSortChunk") + chunkCheckpoint := toChunkCheckpoint(e.mTtask.Chunk) + sharedVars := e.mTtask.SharedVars + checksum := verify.NewKVGroupChecksumWithKeyspace(sharedVars.TableImporter.GetKeySpace()) + if sharedVars.TableImporter.IsLocalSort() { + if err := importer.ProcessChunk( + ctx, + &chunkCheckpoint, + sharedVars.TableImporter, + sharedVars.DataEngine, + sharedVars.IndexEngine, + sharedVars.Progress, + logger, + checksum, + ); err != nil { + return err + } + } else { + if err := importer.ProcessChunkWithWriter( + ctx, + &chunkCheckpoint, + sharedVars.TableImporter, + dataWriter, + indexWriter, + sharedVars.Progress, + logger, + checksum, + ); err != nil { + return err + } + } + + sharedVars.mu.Lock() + defer sharedVars.mu.Unlock() + sharedVars.Checksum.Add(checksum) + return nil +} + +// postProcess does the post-processing for the task. +func postProcess(ctx context.Context, store kv.Storage, taskMeta *TaskMeta, subtaskMeta *PostProcessStepMeta, logger *zap.Logger) (err error) { + failpoint.InjectCall("syncBeforePostProcess", taskMeta.JobID) + + callLog := log.BeginTask(logger, "post process") + defer func() { + callLog.End(zap.ErrorLevel, err) + }() + + if err = importer.RebaseAllocatorBases(ctx, store, subtaskMeta.MaxIDs, &taskMeta.Plan, logger); err != nil { + return err + } + + // TODO: create table indexes depends on the option. + // create table indexes even if the post process is failed. 
+ // defer func() { + // err2 := createTableIndexes(ctx, globalTaskManager, taskMeta, logger) + // err = multierr.Append(err, err2) + // }() + + localChecksum := verify.NewKVGroupChecksumForAdd() + for id, cksum := range subtaskMeta.Checksum { + callLog.Info( + "kv group checksum", + zap.Int64("groupId", id), + zap.Uint64("size", cksum.Size), + zap.Uint64("kvs", cksum.KVs), + zap.Uint64("checksum", cksum.Sum), + ) + localChecksum.AddRawGroup(id, cksum.Size, cksum.KVs, cksum.Sum) + } + + taskManager, err := storage.GetTaskManager() + ctx = util.WithInternalSourceType(ctx, kv.InternalDistTask) + if err != nil { + return err + } + return taskManager.WithNewSession(func(se sessionctx.Context) error { + return importer.VerifyChecksum(ctx, &taskMeta.Plan, localChecksum.MergedChecksum(), se, logger) + }) +} diff --git a/pkg/disttask/importinto/task_executor.go b/pkg/disttask/importinto/task_executor.go index 78f41ff03c0f7..bb0ea5da6b3c8 100644 --- a/pkg/disttask/importinto/task_executor.go +++ b/pkg/disttask/importinto/task_executor.go @@ -492,9 +492,9 @@ func (p *postProcessStepExecutor) RunSubtask(ctx context.Context, subtask *proto if err = json.Unmarshal(subtask.Meta, &stepMeta); err != nil { return errors.Trace(err) } - failpoint.Inject("waitBeforePostProcess", func() { + if _, _err_ := failpoint.Eval(_curpkg_("waitBeforePostProcess")); _err_ == nil { time.Sleep(5 * time.Second) - }) + } return postProcess(ctx, p.store, p.taskMeta, &stepMeta, logger) } diff --git a/pkg/disttask/importinto/task_executor.go__failpoint_stash__ b/pkg/disttask/importinto/task_executor.go__failpoint_stash__ new file mode 100644 index 0000000000000..78f41ff03c0f7 --- /dev/null +++ b/pkg/disttask/importinto/task_executor.go__failpoint_stash__ @@ -0,0 +1,577 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package importinto + +import ( + "context" + "encoding/json" + "strconv" + "sync" + "time" + + "github.com/docker/go-units" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + brlogutil "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor" + "github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor/execute" + "github.com/pingcap/tidb/pkg/disttask/operator" + "github.com/pingcap/tidb/pkg/executor/importer" + tidbkv "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/lightning/backend" + "github.com/pingcap/tidb/pkg/lightning/backend/external" + "github.com/pingcap/tidb/pkg/lightning/backend/kv" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/lightning/metric" + "github.com/pingcap/tidb/pkg/lightning/verification" + "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +// importStepExecutor is a executor for import step. +// StepExecutor is equivalent to a Lightning instance. +type importStepExecutor struct { + execute.StepExecFrameworkInfo + + taskID int64 + taskMeta *TaskMeta + tableImporter *importer.TableImporter + store tidbkv.Storage + sharedVars sync.Map + logger *zap.Logger + + dataKVMemSizePerCon uint64 + perIndexKVMemSizePerCon uint64 + + importCtx context.Context + importCancel context.CancelFunc + wg sync.WaitGroup +} + +func getTableImporter( + ctx context.Context, + taskID int64, + taskMeta *TaskMeta, + store tidbkv.Storage, +) (*importer.TableImporter, error) { + idAlloc := kv.NewPanickingAllocators(taskMeta.Plan.TableInfo.SepAutoInc(), 0) + tbl, err := tables.TableFromMeta(idAlloc, taskMeta.Plan.TableInfo) + if err != nil { + return nil, err + } + astArgs, err := importer.ASTArgsFromStmt(taskMeta.Stmt) + if err != nil { + return nil, err + } + controller, err := importer.NewLoadDataController(&taskMeta.Plan, tbl, astArgs) + if err != nil { + return nil, err + } + if err = controller.InitDataStore(ctx); err != nil { + return nil, err + } + + return importer.NewTableImporter(ctx, controller, strconv.FormatInt(taskID, 10), store) +} + +func (s *importStepExecutor) Init(ctx context.Context) error { + s.logger.Info("init subtask env") + tableImporter, err := getTableImporter(ctx, s.taskID, s.taskMeta, s.store) + if err != nil { + return err + } + s.tableImporter = tableImporter + + // we need this sub context since Cleanup which wait on this routine is called + // before parent context is canceled in normal flow. + s.importCtx, s.importCancel = context.WithCancel(ctx) + // only need to check disk quota when we are using local sort. 
+ if s.tableImporter.IsLocalSort() { + s.wg.Add(1) + go func() { + defer s.wg.Done() + s.tableImporter.CheckDiskQuota(s.importCtx) + }() + } + s.dataKVMemSizePerCon, s.perIndexKVMemSizePerCon = getWriterMemorySizeLimit(s.GetResource(), s.tableImporter.Plan) + s.logger.Info("KV writer memory size limit per concurrency", + zap.String("data", units.BytesSize(float64(s.dataKVMemSizePerCon))), + zap.String("per-index", units.BytesSize(float64(s.perIndexKVMemSizePerCon)))) + return nil +} + +func (s *importStepExecutor) RunSubtask(ctx context.Context, subtask *proto.Subtask) (err error) { + logger := s.logger.With(zap.Int64("subtask-id", subtask.ID)) + task := log.BeginTask(logger, "run subtask") + defer func() { + task.End(zapcore.ErrorLevel, err) + }() + bs := subtask.Meta + var subtaskMeta ImportStepMeta + err = json.Unmarshal(bs, &subtaskMeta) + if err != nil { + return errors.Trace(err) + } + + var dataEngine, indexEngine *backend.OpenedEngine + if s.tableImporter.IsLocalSort() { + dataEngine, err = s.tableImporter.OpenDataEngine(ctx, subtaskMeta.ID) + if err != nil { + return err + } + // Unlike in Lightning, we start an index engine for each subtask, + // whereas previously there was only a single index engine globally. + // This is because the executor currently does not have a post-processing mechanism. + // If we import the index in `cleanupSubtaskEnv`, the scheduler will not wait for the import to complete. + // Multiple index engines may suffer performance degradation due to range overlap. + // These issues will be alleviated after we integrate s3 sorter. + // engineID = -1, -2, -3, ... + indexEngine, err = s.tableImporter.OpenIndexEngine(ctx, common.IndexEngineID-subtaskMeta.ID) + if err != nil { + return err + } + } + sharedVars := &SharedVars{ + TableImporter: s.tableImporter, + DataEngine: dataEngine, + IndexEngine: indexEngine, + Progress: importer.NewProgress(), + Checksum: verification.NewKVGroupChecksumWithKeyspace(s.tableImporter.GetKeySpace()), + SortedDataMeta: &external.SortedKVMeta{}, + SortedIndexMetas: make(map[int64]*external.SortedKVMeta), + } + s.sharedVars.Store(subtaskMeta.ID, sharedVars) + + source := operator.NewSimpleDataChannel(make(chan *importStepMinimalTask)) + op := newEncodeAndSortOperator(ctx, s, sharedVars, subtask.ID, subtask.Concurrency) + op.SetSource(source) + pipeline := operator.NewAsyncPipeline(op) + if err = pipeline.Execute(); err != nil { + return err + } + +outer: + for _, chunk := range subtaskMeta.Chunks { + // TODO: current workpool impl doesn't drain the input channel, it will + // just return on context cancel(error happened), so we add this select. 
+ select { + case source.Channel() <- &importStepMinimalTask{ + Plan: s.taskMeta.Plan, + Chunk: chunk, + SharedVars: sharedVars, + }: + case <-op.Done(): + break outer + } + } + source.Finish() + + return pipeline.Close() +} + +func (*importStepExecutor) RealtimeSummary() *execute.SubtaskSummary { + return nil +} + +func (s *importStepExecutor) OnFinished(ctx context.Context, subtask *proto.Subtask) error { + var subtaskMeta ImportStepMeta + if err := json.Unmarshal(subtask.Meta, &subtaskMeta); err != nil { + return errors.Trace(err) + } + s.logger.Info("on subtask finished", zap.Int32("engine-id", subtaskMeta.ID)) + + val, ok := s.sharedVars.Load(subtaskMeta.ID) + if !ok { + return errors.Errorf("sharedVars %d not found", subtaskMeta.ID) + } + sharedVars, ok := val.(*SharedVars) + if !ok { + return errors.Errorf("sharedVars %d not found", subtaskMeta.ID) + } + + var dataKVCount uint64 + if s.tableImporter.IsLocalSort() { + // TODO: we should close and cleanup engine in all case, since there's no checkpoint. + s.logger.Info("import data engine", zap.Int32("engine-id", subtaskMeta.ID)) + closedDataEngine, err := sharedVars.DataEngine.Close(ctx) + if err != nil { + return err + } + dataKVCount2, err := s.tableImporter.ImportAndCleanup(ctx, closedDataEngine) + if err != nil { + return err + } + dataKVCount = uint64(dataKVCount2) + + s.logger.Info("import index engine", zap.Int32("engine-id", subtaskMeta.ID)) + if closedEngine, err := sharedVars.IndexEngine.Close(ctx); err != nil { + return err + } else if _, err := s.tableImporter.ImportAndCleanup(ctx, closedEngine); err != nil { + return err + } + } + // there's no imported dataKVCount on this stage when using global sort. + + sharedVars.mu.Lock() + defer sharedVars.mu.Unlock() + subtaskMeta.Checksum = map[int64]Checksum{} + for id, c := range sharedVars.Checksum.GetInnerChecksums() { + subtaskMeta.Checksum[id] = Checksum{ + Sum: c.Sum(), + KVs: c.SumKVS(), + Size: c.SumSize(), + } + } + subtaskMeta.Result = Result{ + LoadedRowCnt: dataKVCount, + ColSizeMap: sharedVars.Progress.GetColSize(), + } + allocators := sharedVars.TableImporter.Allocators() + subtaskMeta.MaxIDs = map[autoid.AllocatorType]int64{ + autoid.RowIDAllocType: allocators.Get(autoid.RowIDAllocType).Base(), + autoid.AutoIncrementType: allocators.Get(autoid.AutoIncrementType).Base(), + autoid.AutoRandomType: allocators.Get(autoid.AutoRandomType).Base(), + } + subtaskMeta.SortedDataMeta = sharedVars.SortedDataMeta + subtaskMeta.SortedIndexMetas = sharedVars.SortedIndexMetas + s.sharedVars.Delete(subtaskMeta.ID) + newMeta, err := json.Marshal(subtaskMeta) + if err != nil { + return errors.Trace(err) + } + subtask.Meta = newMeta + return nil +} + +func (s *importStepExecutor) Cleanup(_ context.Context) (err error) { + s.logger.Info("cleanup subtask env") + s.importCancel() + s.wg.Wait() + return s.tableImporter.Close() +} + +type mergeSortStepExecutor struct { + taskexecutor.EmptyStepExecutor + taskID int64 + taskMeta *TaskMeta + logger *zap.Logger + controller *importer.LoadDataController + // subtask of a task is run in serial now, so we don't need lock here. + // change to SyncMap when we support parallel subtask in the future. 
+ subtaskSortedKVMeta *external.SortedKVMeta + // part-size for uploading merged files, it's calculated by: + // max(max-merged-files * max-file-size / max-part-num(10000), min-part-size) + dataKVPartSize int64 + indexKVPartSize int64 +} + +var _ execute.StepExecutor = &mergeSortStepExecutor{} + +func (m *mergeSortStepExecutor) Init(ctx context.Context) error { + controller, err := buildController(&m.taskMeta.Plan, m.taskMeta.Stmt) + if err != nil { + return err + } + if err = controller.InitDataStore(ctx); err != nil { + return err + } + m.controller = controller + dataKVMemSizePerCon, perIndexKVMemSizePerCon := getWriterMemorySizeLimit(m.GetResource(), &m.taskMeta.Plan) + m.dataKVPartSize = max(external.MinUploadPartSize, int64(dataKVMemSizePerCon*uint64(external.MaxMergingFilesPerThread)/10000)) + m.indexKVPartSize = max(external.MinUploadPartSize, int64(perIndexKVMemSizePerCon*uint64(external.MaxMergingFilesPerThread)/10000)) + + m.logger.Info("merge sort partSize", + zap.String("data-kv", units.BytesSize(float64(m.dataKVPartSize))), + zap.String("index-kv", units.BytesSize(float64(m.indexKVPartSize))), + ) + return nil +} + +func (m *mergeSortStepExecutor) RunSubtask(ctx context.Context, subtask *proto.Subtask) (err error) { + sm := &MergeSortStepMeta{} + err = json.Unmarshal(subtask.Meta, sm) + if err != nil { + return errors.Trace(err) + } + logger := m.logger.With(zap.Int64("subtask-id", subtask.ID), zap.String("kv-group", sm.KVGroup)) + task := log.BeginTask(logger, "run subtask") + defer func() { + task.End(zapcore.ErrorLevel, err) + }() + + var mu sync.Mutex + m.subtaskSortedKVMeta = &external.SortedKVMeta{} + onClose := func(summary *external.WriterSummary) { + mu.Lock() + defer mu.Unlock() + m.subtaskSortedKVMeta.MergeSummary(summary) + } + + prefix := subtaskPrefix(m.taskID, subtask.ID) + + partSize := m.dataKVPartSize + if sm.KVGroup != dataKVGroup { + partSize = m.indexKVPartSize + } + err = external.MergeOverlappingFiles( + logutil.WithFields(ctx, zap.String("kv-group", sm.KVGroup), zap.Int64("subtask-id", subtask.ID)), + sm.DataFiles, + m.controller.GlobalSortStore, + partSize, + prefix, + getKVGroupBlockSize(sm.KVGroup), + onClose, + subtask.Concurrency, + false) + logger.Info( + "merge sort finished", + zap.Uint64("total-kv-size", m.subtaskSortedKVMeta.TotalKVSize), + zap.Uint64("total-kv-count", m.subtaskSortedKVMeta.TotalKVCnt), + brlogutil.Key("start-key", m.subtaskSortedKVMeta.StartKey), + brlogutil.Key("end-key", m.subtaskSortedKVMeta.EndKey), + ) + return err +} + +func (m *mergeSortStepExecutor) OnFinished(_ context.Context, subtask *proto.Subtask) error { + var subtaskMeta MergeSortStepMeta + if err := json.Unmarshal(subtask.Meta, &subtaskMeta); err != nil { + return errors.Trace(err) + } + subtaskMeta.SortedKVMeta = *m.subtaskSortedKVMeta + m.subtaskSortedKVMeta = nil + newMeta, err := json.Marshal(subtaskMeta) + if err != nil { + return errors.Trace(err) + } + subtask.Meta = newMeta + return nil +} + +type writeAndIngestStepExecutor struct { + execute.StepExecFrameworkInfo + + taskID int64 + taskMeta *TaskMeta + logger *zap.Logger + tableImporter *importer.TableImporter + store tidbkv.Storage +} + +var _ execute.StepExecutor = &writeAndIngestStepExecutor{} + +func (e *writeAndIngestStepExecutor) Init(ctx context.Context) error { + tableImporter, err := getTableImporter(ctx, e.taskID, e.taskMeta, e.store) + if err != nil { + return err + } + e.tableImporter = tableImporter + return nil +} + +func (e *writeAndIngestStepExecutor) RunSubtask(ctx context.Context, 
subtask *proto.Subtask) (err error) { + sm := &WriteIngestStepMeta{} + err = json.Unmarshal(subtask.Meta, sm) + if err != nil { + return errors.Trace(err) + } + + logger := e.logger.With(zap.Int64("subtask-id", subtask.ID), + zap.String("kv-group", sm.KVGroup)) + task := log.BeginTask(logger, "run subtask") + defer func() { + task.End(zapcore.ErrorLevel, err) + }() + + _, engineUUID := backend.MakeUUID("", subtask.ID) + localBackend := e.tableImporter.Backend() + localBackend.WorkerConcurrency = subtask.Concurrency * 2 + err = localBackend.CloseEngine(ctx, &backend.EngineConfig{ + External: &backend.ExternalEngineConfig{ + StorageURI: e.taskMeta.Plan.CloudStorageURI, + DataFiles: sm.DataFiles, + StatFiles: sm.StatFiles, + StartKey: sm.StartKey, + EndKey: sm.EndKey, + SplitKeys: sm.RangeSplitKeys, + RegionSplitSize: sm.RangeSplitSize, + TotalFileSize: int64(sm.TotalKVSize), + TotalKVCount: 0, + CheckHotspot: false, + }, + TS: sm.TS, + }, engineUUID) + if err != nil { + return err + } + return localBackend.ImportEngine(ctx, engineUUID, int64(config.SplitRegionSize), int64(config.SplitRegionKeys)) +} + +func (*writeAndIngestStepExecutor) RealtimeSummary() *execute.SubtaskSummary { + return nil +} + +func (e *writeAndIngestStepExecutor) OnFinished(ctx context.Context, subtask *proto.Subtask) error { + var subtaskMeta WriteIngestStepMeta + if err := json.Unmarshal(subtask.Meta, &subtaskMeta); err != nil { + return errors.Trace(err) + } + if subtaskMeta.KVGroup != dataKVGroup { + return nil + } + + // only data kv group has loaded row count + _, engineUUID := backend.MakeUUID("", subtask.ID) + localBackend := e.tableImporter.Backend() + _, kvCount := localBackend.GetExternalEngineKVStatistics(engineUUID) + subtaskMeta.Result.LoadedRowCnt = uint64(kvCount) + err := localBackend.CleanupEngine(ctx, engineUUID) + if err != nil { + e.logger.Warn("failed to cleanup engine", zap.Error(err)) + } + + newMeta, err := json.Marshal(subtaskMeta) + if err != nil { + return errors.Trace(err) + } + subtask.Meta = newMeta + return nil +} + +func (e *writeAndIngestStepExecutor) Cleanup(_ context.Context) (err error) { + e.logger.Info("cleanup subtask env") + return e.tableImporter.Close() +} + +type postProcessStepExecutor struct { + taskexecutor.EmptyStepExecutor + taskID int64 + store tidbkv.Storage + taskMeta *TaskMeta + logger *zap.Logger +} + +var _ execute.StepExecutor = &postProcessStepExecutor{} + +// NewPostProcessStepExecutor creates a new post process step executor. +// exported for testing. +func NewPostProcessStepExecutor(taskID int64, store tidbkv.Storage, taskMeta *TaskMeta, logger *zap.Logger) execute.StepExecutor { + return &postProcessStepExecutor{ + taskID: taskID, + store: store, + taskMeta: taskMeta, + logger: logger, + } +} + +func (p *postProcessStepExecutor) RunSubtask(ctx context.Context, subtask *proto.Subtask) (err error) { + logger := p.logger.With(zap.Int64("subtask-id", subtask.ID)) + task := log.BeginTask(logger, "run subtask") + defer func() { + task.End(zapcore.ErrorLevel, err) + }() + stepMeta := PostProcessStepMeta{} + if err = json.Unmarshal(subtask.Meta, &stepMeta); err != nil { + return errors.Trace(err) + } + failpoint.Inject("waitBeforePostProcess", func() { + time.Sleep(5 * time.Second) + }) + return postProcess(ctx, p.store, p.taskMeta, &stepMeta, logger) +} + +type importExecutor struct { + *taskexecutor.BaseTaskExecutor + store tidbkv.Storage +} + +// NewImportExecutor creates a new import task executor. 
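+// It attaches the task's common metrics to the executor context before building the base task executor.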
+func NewImportExecutor( + ctx context.Context, + id string, + task *proto.Task, + taskTable taskexecutor.TaskTable, + store tidbkv.Storage, +) taskexecutor.TaskExecutor { + metrics := metricsManager.getOrCreateMetrics(task.ID) + subCtx := metric.WithCommonMetric(ctx, metrics) + s := &importExecutor{ + BaseTaskExecutor: taskexecutor.NewBaseTaskExecutor(subCtx, id, task, taskTable), + store: store, + } + s.BaseTaskExecutor.Extension = s + return s +} + +func (*importExecutor) IsIdempotent(*proto.Subtask) bool { + // import don't have conflict detection and resolution now, so it's ok + // to import data twice. + return true +} + +func (*importExecutor) IsRetryableError(err error) bool { + return common.IsRetryableError(err) +} + +func (e *importExecutor) GetStepExecutor(task *proto.Task) (execute.StepExecutor, error) { + taskMeta := TaskMeta{} + if err := json.Unmarshal(task.Meta, &taskMeta); err != nil { + return nil, errors.Trace(err) + } + logger := logutil.BgLogger().With( + zap.Stringer("type", proto.ImportInto), + zap.Int64("task-id", task.ID), + zap.String("step", proto.Step2Str(task.Type, task.Step)), + ) + + switch task.Step { + case proto.ImportStepImport, proto.ImportStepEncodeAndSort: + return &importStepExecutor{ + taskID: task.ID, + taskMeta: &taskMeta, + logger: logger, + store: e.store, + }, nil + case proto.ImportStepMergeSort: + return &mergeSortStepExecutor{ + taskID: task.ID, + taskMeta: &taskMeta, + logger: logger, + }, nil + case proto.ImportStepWriteAndIngest: + return &writeAndIngestStepExecutor{ + taskID: task.ID, + taskMeta: &taskMeta, + logger: logger, + store: e.store, + }, nil + case proto.ImportStepPostProcess: + return NewPostProcessStepExecutor(task.ID, e.store, &taskMeta, logger), nil + default: + return nil, errors.Errorf("unknown step %d for import task %d", task.Step, task.ID) + } +} + +func (e *importExecutor) Close() { + task := e.GetTaskBase() + metricsManager.unregister(task.ID) + e.BaseTaskExecutor.Close() +} diff --git a/pkg/domain/binding__failpoint_binding__.go b/pkg/domain/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..17ae9cb59498c --- /dev/null +++ b/pkg/domain/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package domain + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/domain/domain.go b/pkg/domain/domain.go index 8fe52df356fd4..e1889a312a056 100644 --- a/pkg/domain/domain.go +++ b/pkg/domain/domain.go @@ -566,22 +566,22 @@ func (do *Domain) tryLoadSchemaDiffs(builder *infoschema.Builder, m *meta.Meta, diffs = append(diffs, diff) } - failpoint.Inject("MockTryLoadDiffError", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("MockTryLoadDiffError")); _err_ == nil { switch val.(string) { case "exchangepartition": if diffs[0].Type == model.ActionExchangeTablePartition { - failpoint.Return(nil, nil, nil, errors.New("mock error")) + return nil, nil, nil, errors.New("mock error") } case "renametable": if diffs[0].Type == model.ActionRenameTable { - failpoint.Return(nil, nil, nil, errors.New("mock error")) + return nil, nil, nil, errors.New("mock error") } case "dropdatabase": if diffs[0].Type == model.ActionDropSchema { - failpoint.Return(nil, nil, nil, errors.New("mock error")) + return nil, nil, nil, 
errors.New("mock error") } } - }) + } err := builder.InitWithOldInfoSchema(do.infoCache.GetLatest()) if err != nil { @@ -720,11 +720,11 @@ func getFlashbackStartTSFromErrorMsg(err error) uint64 { // Reload reloads InfoSchema. // It's public in order to do the test. func (do *Domain) Reload() error { - failpoint.Inject("ErrorMockReloadFailed", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("ErrorMockReloadFailed")); _err_ == nil { if val.(bool) { - failpoint.Return(errors.New("mock reload failed")) + return errors.New("mock reload failed") } - }) + } // Lock here for only once at the same time. do.m.Lock() @@ -1048,9 +1048,9 @@ func (do *Domain) loadSchemaInLoop(ctx context.Context, lease time.Duration) { for { select { case <-ticker.C: - failpoint.Inject("disableOnTickReload", func() { - failpoint.Continue() - }) + if _, _err_ := failpoint.Eval(_curpkg_("disableOnTickReload")); _err_ == nil { + continue + } err := do.Reload() if err != nil { logutil.BgLogger().Error("reload schema in loop failed", zap.Error(err)) @@ -1346,12 +1346,12 @@ func (do *Domain) Init( ddl.WithSchemaLoader(do), ) - failpoint.Inject("MockReplaceDDL", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("MockReplaceDDL")); _err_ == nil { if val.(bool) { do.ddl = d do.ddlExecutor = eBak } - }) + } if ddlInjector != nil { checker := ddlInjector(do.ddl, do.ddlExecutor, do.infoCache) checker.CreateTestDB(nil) @@ -1645,11 +1645,11 @@ func (do *Domain) checkReplicaRead(ctx context.Context, pdClient pd.Client) erro // InitDistTaskLoop initializes the distributed task framework. func (do *Domain) InitDistTaskLoop() error { ctx := kv.WithInternalSourceType(context.Background(), kv.InternalDistTask) - failpoint.Inject("MockDisableDistTask", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("MockDisableDistTask")); _err_ == nil { if val.(bool) { - failpoint.Return(nil) + return nil } - }) + } taskManager := storage.NewTaskManager(do.sysSessionPool) var serverID string @@ -1857,7 +1857,7 @@ func (do *Domain) LoadSysVarCacheLoop(ctx sessionctx.Context) error { case <-time.After(duration): } - failpoint.Inject("skipLoadSysVarCacheLoop", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("skipLoadSysVarCacheLoop")); _err_ == nil { // In some pkg integration test, there are many testSuite, and each testSuite has separate storage and // `LoadSysVarCacheLoop` background goroutine. Then each testSuite `RebuildSysVarCache` from it's // own storage. @@ -1865,9 +1865,9 @@ func (do *Domain) LoadSysVarCacheLoop(ctx sessionctx.Context) error { // That's the problem, each testSuit use different storage to update some same local variables. // So just skip `RebuildSysVarCache` in some integration testing. if val.(bool) { - failpoint.Continue() + continue } - }) + } if !ok { logutil.BgLogger().Error("LoadSysVarCacheLoop loop watch channel closed") diff --git a/pkg/domain/domain.go__failpoint_stash__ b/pkg/domain/domain.go__failpoint_stash__ new file mode 100644 index 0000000000000..8fe52df356fd4 --- /dev/null +++ b/pkg/domain/domain.go__failpoint_stash__ @@ -0,0 +1,3239 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package domain + +import ( + "context" + "fmt" + "math" + "math/rand" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/ngaut/pools" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/kvproto/pkg/pdpb" + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/streamhelper" + "github.com/pingcap/tidb/br/pkg/streamhelper/daemon" + "github.com/pingcap/tidb/pkg/bindinfo" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl" + "github.com/pingcap/tidb/pkg/ddl/placement" + "github.com/pingcap/tidb/pkg/ddl/schematracker" + "github.com/pingcap/tidb/pkg/ddl/systable" + ddlutil "github.com/pingcap/tidb/pkg/ddl/util" + "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" + "github.com/pingcap/tidb/pkg/disttask/framework/storage" + "github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor" + "github.com/pingcap/tidb/pkg/domain/globalconfigsync" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/domain/resourcegroup" + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/infoschema" + infoschema_metrics "github.com/pingcap/tidb/pkg/infoschema/metrics" + "github.com/pingcap/tidb/pkg/infoschema/perfschema" + "github.com/pingcap/tidb/pkg/keyspace" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/owner" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + metrics2 "github.com/pingcap/tidb/pkg/planner/core/metrics" + "github.com/pingcap/tidb/pkg/privilege/privileges" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/sessionstates" + "github.com/pingcap/tidb/pkg/sessionctx/sysproctrack" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/statistics/handle" + statslogutil "github.com/pingcap/tidb/pkg/statistics/handle/logutil" + "github.com/pingcap/tidb/pkg/store/helper" + "github.com/pingcap/tidb/pkg/ttl/ttlworker" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/dbterror" + disttaskutil "github.com/pingcap/tidb/pkg/util/disttask" + "github.com/pingcap/tidb/pkg/util/domainutil" + "github.com/pingcap/tidb/pkg/util/engine" + "github.com/pingcap/tidb/pkg/util/etcd" + "github.com/pingcap/tidb/pkg/util/expensivequery" + "github.com/pingcap/tidb/pkg/util/gctuner" + "github.com/pingcap/tidb/pkg/util/globalconn" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/memoryusagealarm" + "github.com/pingcap/tidb/pkg/util/replayer" + "github.com/pingcap/tidb/pkg/util/servermemorylimit" + "github.com/pingcap/tidb/pkg/util/sqlkiller" + "github.com/pingcap/tidb/pkg/util/syncutil" + 
"github.com/tikv/client-go/v2/tikv" + "github.com/tikv/client-go/v2/txnkv/transaction" + pd "github.com/tikv/pd/client" + pdhttp "github.com/tikv/pd/client/http" + rmclient "github.com/tikv/pd/client/resource_group/controller" + clientv3 "go.etcd.io/etcd/client/v3" + "go.etcd.io/etcd/client/v3/concurrency" + atomicutil "go.uber.org/atomic" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/backoff" + "google.golang.org/grpc/keepalive" +) + +var ( + mdlCheckLookDuration = 50 * time.Millisecond + + // LoadSchemaDiffVersionGapThreshold is the threshold for version gap to reload domain by loading schema diffs + LoadSchemaDiffVersionGapThreshold int64 = 10000 + + // NewInstancePlanCache creates a new instance level plan cache, this function is designed to avoid cycle-import. + NewInstancePlanCache func(softMemLimit, hardMemLimit int64) sessionctx.InstancePlanCache +) + +const ( + indexUsageGCDuration = 30 * time.Minute +) + +func init() { + if intest.InTest { + // In test we can set duration lower to make test faster. + mdlCheckLookDuration = 2 * time.Millisecond + } +} + +// NewMockDomain is only used for test +func NewMockDomain() *Domain { + do := &Domain{} + do.infoCache = infoschema.NewCache(do, 1) + do.infoCache.Insert(infoschema.MockInfoSchema(nil), 0) + return do +} + +// Domain represents a storage space. Different domains can use the same database name. +// Multiple domains can be used in parallel without synchronization. +type Domain struct { + store kv.Storage + infoCache *infoschema.InfoCache + privHandle *privileges.Handle + bindHandle atomic.Value + statsHandle atomic.Pointer[handle.Handle] + statsLease time.Duration + ddl ddl.DDL + ddlExecutor ddl.Executor + info *infosync.InfoSyncer + globalCfgSyncer *globalconfigsync.GlobalConfigSyncer + m syncutil.Mutex + SchemaValidator SchemaValidator + sysSessionPool util.SessionPool + exit chan struct{} + // `etcdClient` must be used when keyspace is not set, or when the logic to each etcd path needs to be separated by keyspace. + etcdClient *clientv3.Client + // autoidClient is used when there are tables with AUTO_ID_CACHE=1, it is the client to the autoid service. + autoidClient *autoid.ClientDiscover + // `unprefixedEtcdCli` will never set the etcd namespace prefix by keyspace. + // It is only used in storeMinStartTS and RemoveMinStartTS now. + // It must be used when the etcd path isn't needed to separate by keyspace. + // See keyspace RFC: https://github.com/pingcap/tidb/pull/39685 + unprefixedEtcdCli *clientv3.Client + sysVarCache sysVarCache // replaces GlobalVariableCache + slowQuery *topNSlowQueries + expensiveQueryHandle *expensivequery.Handle + memoryUsageAlarmHandle *memoryusagealarm.Handle + serverMemoryLimitHandle *servermemorylimit.Handle + // TODO: use Run for each process in future pr + wg *util.WaitGroupEnhancedWrapper + statsUpdating atomicutil.Int32 + // this is the parent context of DDL, and also used by other loops such as closestReplicaReadCheckLoop. + // there are other top level contexts in the domain, such as the ones used in + // InitDistTaskLoop and loadStatsWorker, domain only stores the cancelFns of them. + // TODO unify top level context. 
+ ctx context.Context + cancelFns struct { + mu sync.Mutex + fns []context.CancelFunc + } + dumpFileGcChecker *dumpFileGcChecker + planReplayerHandle *planReplayerHandle + extractTaskHandle *ExtractHandle + expiredTimeStamp4PC struct { + // let `expiredTimeStamp4PC` use its own lock to avoid any block across domain.Reload() + // and compiler.Compile(), see issue https://github.com/pingcap/tidb/issues/45400 + sync.RWMutex + expiredTimeStamp types.Time + } + + logBackupAdvancer *daemon.OwnerDaemon + historicalStatsWorker *HistoricalStatsWorker + ttlJobManager atomic.Pointer[ttlworker.JobManager] + runawayManager *resourcegroup.RunawayManager + runawaySyncer *runawaySyncer + resourceGroupsController *rmclient.ResourceGroupsController + + serverID uint64 + serverIDSession *concurrency.Session + isLostConnectionToPD atomicutil.Int32 // !0: true, 0: false. + connIDAllocator globalconn.Allocator + + onClose func() + sysExecutorFactory func(*Domain) (pools.Resource, error) + + sysProcesses SysProcesses + + mdlCheckTableInfo *mdlCheckTableInfo + + analyzeMu struct { + sync.Mutex + sctxs map[sessionctx.Context]bool + } + + mdlCheckCh chan struct{} + stopAutoAnalyze atomicutil.Bool + minJobIDRefresher *systable.MinJobIDRefresher + + instancePlanCache sessionctx.InstancePlanCache // the instance level plan cache + + // deferFn is used to release infoschema object lazily during v1 and v2 switch + deferFn +} + +type deferFn struct { + sync.Mutex + data []deferFnRecord +} + +type deferFnRecord struct { + fn func() + fire time.Time +} + +func (df *deferFn) add(fn func(), fire time.Time) { + df.Lock() + defer df.Unlock() + df.data = append(df.data, deferFnRecord{fn: fn, fire: fire}) +} + +func (df *deferFn) check() { + now := time.Now() + df.Lock() + defer df.Unlock() + + // iterate the slice, call the defer function and remove it. + rm := 0 + for i := 0; i < len(df.data); i++ { + record := &df.data[i] + if now.After(record.fire) { + record.fn() + rm++ + } else { + df.data[i-rm] = df.data[i] + } + } + df.data = df.data[:len(df.data)-rm] +} + +type mdlCheckTableInfo struct { + mu sync.Mutex + newestVer int64 + jobsVerMap map[int64]int64 + jobsIDsMap map[int64]string +} + +// InfoCache export for test. +func (do *Domain) InfoCache() *infoschema.InfoCache { + return do.infoCache +} + +// EtcdClient export for test. +func (do *Domain) EtcdClient() *clientv3.Client { + return do.etcdClient +} + +// loadInfoSchema loads infoschema at startTS. +// It returns: +// 1. the needed infoschema +// 2. cache hit indicator +// 3. currentSchemaVersion(before loading) +// 4. the changed table IDs if it is not full load +// 5. an error if any +func (do *Domain) loadInfoSchema(startTS uint64, isSnapshot bool) (infoschema.InfoSchema, bool, int64, *transaction.RelatedSchemaChange, error) { + beginTime := time.Now() + defer func() { + infoschema_metrics.LoadSchemaDurationTotal.Observe(time.Since(beginTime).Seconds()) + }() + snapshot := do.store.GetSnapshot(kv.NewVersion(startTS)) + // Using the KV timeout read feature to address the issue of potential DDL lease expiration when + // the meta region leader is slow. + snapshot.SetOption(kv.TiKVClientReadTimeout, uint64(3000)) // 3000ms. 
+ m := meta.NewSnapshotMeta(snapshot) + neededSchemaVersion, err := m.GetSchemaVersionWithNonEmptyDiff() + if err != nil { + return nil, false, 0, nil, err + } + // fetch the commit timestamp of the schema diff + schemaTs, err := do.getTimestampForSchemaVersionWithNonEmptyDiff(m, neededSchemaVersion, startTS) + if err != nil { + logutil.BgLogger().Warn("failed to get schema version", zap.Error(err), zap.Int64("version", neededSchemaVersion)) + schemaTs = 0 + } + + if is := do.infoCache.GetByVersion(neededSchemaVersion); is != nil { + isV2, raw := infoschema.IsV2(is) + if isV2 { + // Copy the infoschema V2 instance and update its ts. + // For example, the DDL run 30 minutes ago, GC happened 10 minutes ago. If we use + // that infoschema it would get error "GC life time is shorter than transaction + // duration" when visiting TiKV. + // So we keep updating the ts of the infoschema v2. + is = raw.CloneAndUpdateTS(startTS) + } + + // try to insert here as well to correct the schemaTs if previous is wrong + // the insert method check if schemaTs is zero + do.infoCache.Insert(is, schemaTs) + + return is, true, 0, nil, nil + } + + var oldIsV2 bool + enableV2 := variable.SchemaCacheSize.Load() > 0 + currentSchemaVersion := int64(0) + if oldInfoSchema := do.infoCache.GetLatest(); oldInfoSchema != nil { + currentSchemaVersion = oldInfoSchema.SchemaMetaVersion() + oldIsV2, _ = infoschema.IsV2(oldInfoSchema) + } + useV2, isV1V2Switch := shouldUseV2(enableV2, oldIsV2, isSnapshot) + builder := infoschema.NewBuilder(do, do.sysFacHack, do.infoCache.Data, useV2) + + // TODO: tryLoadSchemaDiffs has potential risks of failure. And it becomes worse in history reading cases. + // It is only kept because there is no alternative diff/partial loading solution. + // And it is only used to diff upgrading the current latest infoschema, if: + // 1. Not first time bootstrap loading, which needs a full load. + // 2. It is newer than the current one, so it will be "the current one" after this function call. + // 3. There are less 100 diffs. + // 4. No regenerated schema diff. + startTime := time.Now() + if !isV1V2Switch && currentSchemaVersion != 0 && neededSchemaVersion > currentSchemaVersion && neededSchemaVersion-currentSchemaVersion < LoadSchemaDiffVersionGapThreshold { + is, relatedChanges, diffTypes, err := do.tryLoadSchemaDiffs(builder, m, currentSchemaVersion, neededSchemaVersion, startTS) + if err == nil { + infoschema_metrics.LoadSchemaDurationLoadDiff.Observe(time.Since(startTime).Seconds()) + isV2, _ := infoschema.IsV2(is) + do.infoCache.Insert(is, schemaTs) + logutil.BgLogger().Info("diff load InfoSchema success", + zap.Bool("isV2", isV2), + zap.Int64("currentSchemaVersion", currentSchemaVersion), + zap.Int64("neededSchemaVersion", neededSchemaVersion), + zap.Duration("elapsed time", time.Since(startTime)), + zap.Int64("gotSchemaVersion", is.SchemaMetaVersion()), + zap.Int64s("phyTblIDs", relatedChanges.PhyTblIDS), + zap.Uint64s("actionTypes", relatedChanges.ActionTypes), + zap.Strings("diffTypes", diffTypes)) + return is, false, currentSchemaVersion, relatedChanges, nil + } + // We can fall back to full load, don't need to return the error. + logutil.BgLogger().Error("failed to load schema diff", zap.Error(err)) + } + // full load. 
+ schemas, err := do.fetchAllSchemasWithTables(m) + if err != nil { + return nil, false, currentSchemaVersion, nil, err + } + + policies, err := do.fetchPolicies(m) + if err != nil { + return nil, false, currentSchemaVersion, nil, err + } + + resourceGroups, err := do.fetchResourceGroups(m) + if err != nil { + return nil, false, currentSchemaVersion, nil, err + } + infoschema_metrics.LoadSchemaDurationLoadAll.Observe(time.Since(startTime).Seconds()) + + err = builder.InitWithDBInfos(schemas, policies, resourceGroups, neededSchemaVersion) + if err != nil { + return nil, false, currentSchemaVersion, nil, err + } + is := builder.Build(startTS) + isV2, _ := infoschema.IsV2(is) + logutil.BgLogger().Info("full load InfoSchema success", + zap.Bool("isV2", isV2), + zap.Int64("currentSchemaVersion", currentSchemaVersion), + zap.Int64("neededSchemaVersion", neededSchemaVersion), + zap.Duration("elapsed time", time.Since(startTime))) + + if isV1V2Switch && schemaTs > 0 { + // Reset the whole info cache to avoid co-existing of both v1 and v2, causing the memory usage doubled. + fn := do.infoCache.Upsert(is, schemaTs) + do.deferFn.add(fn, time.Now().Add(10*time.Minute)) + } else { + do.infoCache.Insert(is, schemaTs) + } + return is, false, currentSchemaVersion, nil, nil +} + +// Returns the timestamp of a schema version, which is the commit timestamp of the schema diff +func (do *Domain) getTimestampForSchemaVersionWithNonEmptyDiff(m *meta.Meta, version int64, startTS uint64) (uint64, error) { + tikvStore, ok := do.Store().(helper.Storage) + if ok { + newHelper := helper.NewHelper(tikvStore) + mvccResp, err := newHelper.GetMvccByEncodedKeyWithTS(m.EncodeSchemaDiffKey(version), startTS) + if err != nil { + return 0, err + } + if mvccResp == nil || mvccResp.Info == nil || len(mvccResp.Info.Writes) == 0 { + return 0, errors.Errorf("There is no Write MVCC info for the schema version") + } + return mvccResp.Info.Writes[0].CommitTs, nil + } + return 0, errors.Errorf("cannot get store from domain") +} + +func (do *Domain) sysFacHack() (pools.Resource, error) { + // TODO: Here we create new sessions with sysFac in DDL, + // which will use `do` as Domain instead of call `domap.Get`. + // That's because `domap.Get` requires a lock, but before + // we initialize Domain finish, we can't require that again. + // After we remove the lazy logic of creating Domain, we + // can simplify code here. 
+ return do.sysExecutorFactory(do) +} + +func (*Domain) fetchPolicies(m *meta.Meta) ([]*model.PolicyInfo, error) { + allPolicies, err := m.ListPolicies() + if err != nil { + return nil, err + } + return allPolicies, nil +} + +func (*Domain) fetchResourceGroups(m *meta.Meta) ([]*model.ResourceGroupInfo, error) { + allResourceGroups, err := m.ListResourceGroups() + if err != nil { + return nil, err + } + return allResourceGroups, nil +} + +func (do *Domain) fetchAllSchemasWithTables(m *meta.Meta) ([]*model.DBInfo, error) { + allSchemas, err := m.ListDatabases() + if err != nil { + return nil, err + } + splittedSchemas := do.splitForConcurrentFetch(allSchemas) + doneCh := make(chan error, len(splittedSchemas)) + for _, schemas := range splittedSchemas { + go do.fetchSchemasWithTables(schemas, m, doneCh) + } + for range splittedSchemas { + err = <-doneCh + if err != nil { + return nil, err + } + } + return allSchemas, nil +} + +// fetchSchemaConcurrency controls the goroutines to load schemas, but more goroutines +// increase the memory usage when calling json.Unmarshal(), which would cause OOM, +// so we decrease the concurrency. +const fetchSchemaConcurrency = 1 + +func (*Domain) splitForConcurrentFetch(schemas []*model.DBInfo) [][]*model.DBInfo { + groupSize := (len(schemas) + fetchSchemaConcurrency - 1) / fetchSchemaConcurrency + if variable.SchemaCacheSize.Load() > 0 && len(schemas) > 1000 { + // TODO: Temporary solution to speed up when too many databases, will refactor it later. + groupSize = 8 + } + splitted := make([][]*model.DBInfo, 0, fetchSchemaConcurrency) + schemaCnt := len(schemas) + for i := 0; i < schemaCnt; i += groupSize { + end := i + groupSize + if end > schemaCnt { + end = schemaCnt + } + splitted = append(splitted, schemas[i:end]) + } + return splitted +} + +func (*Domain) fetchSchemasWithTables(schemas []*model.DBInfo, m *meta.Meta, done chan error) { + for _, di := range schemas { + if di.State != model.StatePublic { + // schema is not public, can't be used outside. + continue + } + var tables []*model.TableInfo + var err error + if variable.SchemaCacheSize.Load() > 0 && !infoschema.IsSpecialDB(di.Name.L) { + name2ID, specialTableInfos, err := meta.GetAllNameToIDAndTheMustLoadedTableInfo(m, di.ID) + if err != nil { + done <- err + return + } + di.TableName2ID = name2ID + tables = specialTableInfos + } else { + tables, err = m.ListTables(di.ID) + if err != nil { + done <- err + return + } + } + // If TreatOldVersionUTF8AsUTF8MB4 was enable, need to convert the old version schema UTF8 charset to UTF8MB4. + if config.GetGlobalConfig().TreatOldVersionUTF8AsUTF8MB4 { + for _, tbInfo := range tables { + infoschema.ConvertOldVersionUTF8ToUTF8MB4IfNeed(tbInfo) + } + } + diTables := make([]*model.TableInfo, 0, len(tables)) + for _, tbl := range tables { + if tbl.State != model.StatePublic { + // schema is not public, can't be used outside. + continue + } + infoschema.ConvertCharsetCollateToLowerCaseIfNeed(tbl) + // Check whether the table is in repair mode. + if domainutil.RepairInfo.InRepairMode() && domainutil.RepairInfo.CheckAndFetchRepairedTable(di, tbl) { + if tbl.State != model.StatePublic { + // Do not load it because we are reparing the table and the table info could be `bad` + // before repair is done. + continue + } + // If the state is public, it means that the DDL job is done, but the table + // haven't been deleted from the repair table list. + // Since the repairment is done and table is visible, we should load it. 
+ } + diTables = append(diTables, tbl) + } + di.Deprecated.Tables = diTables + } + done <- nil +} + +// shouldUseV2 decides whether to use infoschema v2. +// When loading snapshot, infoschema should keep the same as before to avoid v1/v2 switch. +// Otherwise, it is decided by enabledV2. +func shouldUseV2(enableV2 bool, oldIsV2 bool, isSnapshot bool) (useV2 bool, isV1V2Switch bool) { + if isSnapshot { + return oldIsV2, false + } + return enableV2, enableV2 != oldIsV2 +} + +// tryLoadSchemaDiffs tries to only load latest schema changes. +// Return true if the schema is loaded successfully. +// Return false if the schema can not be loaded by schema diff, then we need to do full load. +// The second returned value is the delta updated table and partition IDs. +func (do *Domain) tryLoadSchemaDiffs(builder *infoschema.Builder, m *meta.Meta, usedVersion, newVersion int64, startTS uint64) (infoschema.InfoSchema, *transaction.RelatedSchemaChange, []string, error) { + var diffs []*model.SchemaDiff + for usedVersion < newVersion { + usedVersion++ + diff, err := m.GetSchemaDiff(usedVersion) + if err != nil { + return nil, nil, nil, err + } + if diff == nil { + // Empty diff means the txn of generating schema version is committed, but the txn of `runDDLJob` is not or fail. + // It is safe to skip the empty diff because the infoschema is new enough and consistent. + logutil.BgLogger().Info("diff load InfoSchema get empty schema diff", zap.Int64("version", usedVersion)) + do.infoCache.InsertEmptySchemaVersion(usedVersion) + continue + } + diffs = append(diffs, diff) + } + + failpoint.Inject("MockTryLoadDiffError", func(val failpoint.Value) { + switch val.(string) { + case "exchangepartition": + if diffs[0].Type == model.ActionExchangeTablePartition { + failpoint.Return(nil, nil, nil, errors.New("mock error")) + } + case "renametable": + if diffs[0].Type == model.ActionRenameTable { + failpoint.Return(nil, nil, nil, errors.New("mock error")) + } + case "dropdatabase": + if diffs[0].Type == model.ActionDropSchema { + failpoint.Return(nil, nil, nil, errors.New("mock error")) + } + } + }) + + err := builder.InitWithOldInfoSchema(do.infoCache.GetLatest()) + if err != nil { + return nil, nil, nil, errors.Trace(err) + } + + builder.WithStore(do.store).SetDeltaUpdateBundles() + phyTblIDs := make([]int64, 0, len(diffs)) + actions := make([]uint64, 0, len(diffs)) + diffTypes := make([]string, 0, len(diffs)) + for _, diff := range diffs { + if diff.RegenerateSchemaMap { + return nil, nil, nil, errors.Errorf("Meets a schema diff with RegenerateSchemaMap flag") + } + ids, err := builder.ApplyDiff(m, diff) + if err != nil { + return nil, nil, nil, err + } + if canSkipSchemaCheckerDDL(diff.Type) { + continue + } + diffTypes = append(diffTypes, diff.Type.String()) + phyTblIDs = append(phyTblIDs, ids...) + for i := 0; i < len(ids); i++ { + actions = append(actions, uint64(diff.Type)) + } + } + + is := builder.Build(startTS) + relatedChange := transaction.RelatedSchemaChange{} + relatedChange.PhyTblIDS = phyTblIDs + relatedChange.ActionTypes = actions + return is, &relatedChange, diffTypes, nil +} + +func canSkipSchemaCheckerDDL(tp model.ActionType) bool { + switch tp { + case model.ActionUpdateTiFlashReplicaStatus, model.ActionSetTiFlashReplica: + return true + } + return false +} + +// InfoSchema gets the latest information schema from domain. +func (do *Domain) InfoSchema() infoschema.InfoSchema { + return do.infoCache.GetLatest() +} + +// GetSnapshotInfoSchema gets a snapshot information schema. 
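shouldUseV2 above is a pure decision function, which makes it easy to read as a truth table: snapshot reads always keep the format of the previous infoschema and never count as a switch; otherwise the cache-size sysvar decides, and a switch is reported whenever the decision differs from the old format. A standalone sketch with illustrative names, not part of the patch:

package main

import "fmt"

// shouldUseV2Sketch reproduces the decision described above: snapshot reads stick
// with whatever format the old schema used (no v1/v2 switch allowed), otherwise the
// sysvar decides, and a switch is flagged whenever the choice differs.
func shouldUseV2Sketch(enableV2, oldIsV2, isSnapshot bool) (useV2, isSwitch bool) {
	if isSnapshot {
		return oldIsV2, false
	}
	return enableV2, enableV2 != oldIsV2
}

func main() {
	for _, c := range []struct{ enableV2, oldIsV2, isSnapshot bool }{
		{true, false, false}, // normal load, switching v1 -> v2
		{true, false, true},  // snapshot load keeps v1, no switch
		{false, false, false},
	} {
		use, sw := shouldUseV2Sketch(c.enableV2, c.oldIsV2, c.isSnapshot)
		fmt.Printf("enableV2=%v oldIsV2=%v snapshot=%v -> useV2=%v switch=%v\n",
			c.enableV2, c.oldIsV2, c.isSnapshot, use, sw)
	}
}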
+func (do *Domain) GetSnapshotInfoSchema(snapshotTS uint64) (infoschema.InfoSchema, error) { + // if the snapshotTS is new enough, we can get infoschema directly through snapshotTS. + if is := do.infoCache.GetBySnapshotTS(snapshotTS); is != nil { + return is, nil + } + is, _, _, _, err := do.loadInfoSchema(snapshotTS, true) + infoschema_metrics.LoadSchemaCounterSnapshot.Inc() + return is, err +} + +// GetSnapshotMeta gets a new snapshot meta at startTS. +func (do *Domain) GetSnapshotMeta(startTS uint64) *meta.Meta { + snapshot := do.store.GetSnapshot(kv.NewVersion(startTS)) + return meta.NewSnapshotMeta(snapshot) +} + +// ExpiredTimeStamp4PC gets expiredTimeStamp4PC from domain. +func (do *Domain) ExpiredTimeStamp4PC() types.Time { + do.expiredTimeStamp4PC.RLock() + defer do.expiredTimeStamp4PC.RUnlock() + + return do.expiredTimeStamp4PC.expiredTimeStamp +} + +// SetExpiredTimeStamp4PC sets the expiredTimeStamp4PC from domain. +func (do *Domain) SetExpiredTimeStamp4PC(time types.Time) { + do.expiredTimeStamp4PC.Lock() + defer do.expiredTimeStamp4PC.Unlock() + + do.expiredTimeStamp4PC.expiredTimeStamp = time +} + +// DDL gets DDL from domain. +func (do *Domain) DDL() ddl.DDL { + return do.ddl +} + +// DDLExecutor gets the ddl executor from domain. +func (do *Domain) DDLExecutor() ddl.Executor { + return do.ddlExecutor +} + +// SetDDL sets DDL to domain, it's only used in tests. +func (do *Domain) SetDDL(d ddl.DDL, executor ddl.Executor) { + do.ddl = d + do.ddlExecutor = executor +} + +// InfoSyncer gets infoSyncer from domain. +func (do *Domain) InfoSyncer() *infosync.InfoSyncer { + return do.info +} + +// NotifyGlobalConfigChange notify global config syncer to store the global config into PD. +func (do *Domain) NotifyGlobalConfigChange(name, value string) { + do.globalCfgSyncer.Notify(pd.GlobalConfigItem{Name: name, Value: value, EventType: pdpb.EventType_PUT}) +} + +// GetGlobalConfigSyncer exports for testing. +func (do *Domain) GetGlobalConfigSyncer() *globalconfigsync.GlobalConfigSyncer { + return do.globalCfgSyncer +} + +// Store gets KV store from domain. +func (do *Domain) Store() kv.Storage { + return do.store +} + +// GetScope gets the status variables scope. +func (*Domain) GetScope(string) variable.ScopeFlag { + // Now domain status variables scope are all default scope. + return variable.DefaultStatusVarScopeFlag +} + +func getFlashbackStartTSFromErrorMsg(err error) uint64 { + slices := strings.Split(err.Error(), "is in flashback progress, FlashbackStartTS is ") + if len(slices) != 2 { + return 0 + } + version, err := strconv.ParseUint(slices[1], 10, 0) + if err != nil { + return 0 + } + return version +} + +// Reload reloads InfoSchema. +// It's public in order to do the test. +func (do *Domain) Reload() error { + failpoint.Inject("ErrorMockReloadFailed", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(errors.New("mock reload failed")) + } + }) + + // Lock here for only once at the same time. 
+ do.m.Lock() + defer do.m.Unlock() + + startTime := time.Now() + ver, err := do.store.CurrentVersion(kv.GlobalTxnScope) + if err != nil { + return err + } + + version := ver.Ver + is, hitCache, oldSchemaVersion, changes, err := do.loadInfoSchema(version, false) + if err != nil { + if version = getFlashbackStartTSFromErrorMsg(err); version != 0 { + // use the latest available version to create domain + version-- + is, hitCache, oldSchemaVersion, changes, err = do.loadInfoSchema(version, false) + } + } + if err != nil { + metrics.LoadSchemaCounter.WithLabelValues("failed").Inc() + return err + } + metrics.LoadSchemaCounter.WithLabelValues("succ").Inc() + + // only update if it is not from cache + if !hitCache { + // loaded newer schema + if oldSchemaVersion < is.SchemaMetaVersion() { + // Update self schema version to etcd. + err = do.ddl.SchemaSyncer().UpdateSelfVersion(context.Background(), 0, is.SchemaMetaVersion()) + if err != nil { + logutil.BgLogger().Info("update self version failed", + zap.Int64("oldSchemaVersion", oldSchemaVersion), + zap.Int64("neededSchemaVersion", is.SchemaMetaVersion()), zap.Error(err)) + } + } + + // it is full load + if changes == nil { + logutil.BgLogger().Info("full load and reset schema validator") + do.SchemaValidator.Reset() + } + } + + // lease renew, so it must be executed despite it is cache or not + do.SchemaValidator.Update(version, oldSchemaVersion, is.SchemaMetaVersion(), changes) + lease := do.DDL().GetLease() + sub := time.Since(startTime) + // Reload interval is lease / 2, if load schema time elapses more than this interval, + // some query maybe responded by ErrInfoSchemaExpired error. + if sub > (lease/2) && lease > 0 { + logutil.BgLogger().Warn("loading schema takes a long time", zap.Duration("take time", sub)) + } + + return nil +} + +// LogSlowQuery keeps topN recent slow queries in domain. +func (do *Domain) LogSlowQuery(query *SlowQueryInfo) { + do.slowQuery.mu.RLock() + defer do.slowQuery.mu.RUnlock() + if do.slowQuery.mu.closed { + return + } + + select { + case do.slowQuery.ch <- query: + default: + } +} + +// ShowSlowQuery returns the slow queries. 
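The flashback fallback in Reload above depends on recovering the FlashbackStartTS that is embedded in the error text: getFlashbackStartTSFromErrorMsg splits on the fixed message fragment and parses the trailing number, and Reload then retries loading at that version minus one. A self-contained sketch of the same parsing, with a made-up error message and timestamp purely for illustration:

package main

import (
	"errors"
	"fmt"
	"strconv"
	"strings"
)

// parseFlashbackStartTS extracts the FlashbackStartTS embedded in the error message,
// mirroring getFlashbackStartTSFromErrorMsg; a zero result means "not a flashback error".
func parseFlashbackStartTS(err error) uint64 {
	parts := strings.Split(err.Error(), "is in flashback progress, FlashbackStartTS is ")
	if len(parts) != 2 {
		return 0
	}
	ts, perr := strconv.ParseUint(parts[1], 10, 64)
	if perr != nil {
		return 0
	}
	return ts
}

func main() {
	// The message and timestamp below are invented for the example.
	err := errors.New("region 5 is in flashback progress, FlashbackStartTS is 449573468160")
	if ts := parseFlashbackStartTS(err); ts != 0 {
		// Reload retries with ts-1, the last version readable before the flashback window.
		fmt.Println("retry loading schema at version", ts-1)
	}
}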
+func (do *Domain) ShowSlowQuery(showSlow *ast.ShowSlow) []*SlowQueryInfo { + msg := &showSlowMessage{ + request: showSlow, + } + msg.Add(1) + do.slowQuery.msgCh <- msg + msg.Wait() + return msg.result +} + +func (do *Domain) topNSlowQueryLoop() { + defer util.Recover(metrics.LabelDomain, "topNSlowQueryLoop", nil, false) + ticker := time.NewTicker(time.Minute * 10) + defer func() { + ticker.Stop() + logutil.BgLogger().Info("topNSlowQueryLoop exited.") + }() + for { + select { + case now := <-ticker.C: + do.slowQuery.RemoveExpired(now) + case info, ok := <-do.slowQuery.ch: + if !ok { + return + } + do.slowQuery.Append(info) + case msg := <-do.slowQuery.msgCh: + req := msg.request + switch req.Tp { + case ast.ShowSlowTop: + msg.result = do.slowQuery.QueryTop(int(req.Count), req.Kind) + case ast.ShowSlowRecent: + msg.result = do.slowQuery.QueryRecent(int(req.Count)) + default: + msg.result = do.slowQuery.QueryAll() + } + msg.Done() + } + } +} + +func (do *Domain) infoSyncerKeeper() { + defer func() { + logutil.BgLogger().Info("infoSyncerKeeper exited.") + }() + + defer util.Recover(metrics.LabelDomain, "infoSyncerKeeper", nil, false) + + ticker := time.NewTicker(infosync.ReportInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + do.info.ReportMinStartTS(do.Store()) + case <-do.info.Done(): + logutil.BgLogger().Info("server info syncer need to restart") + if err := do.info.Restart(context.Background()); err != nil { + logutil.BgLogger().Error("server info syncer restart failed", zap.Error(err)) + } else { + logutil.BgLogger().Info("server info syncer restarted") + } + case <-do.exit: + return + } + } +} + +func (do *Domain) globalConfigSyncerKeeper() { + defer func() { + logutil.BgLogger().Info("globalConfigSyncerKeeper exited.") + }() + + defer util.Recover(metrics.LabelDomain, "globalConfigSyncerKeeper", nil, false) + + for { + select { + case entry := <-do.globalCfgSyncer.NotifyCh: + err := do.globalCfgSyncer.StoreGlobalConfig(context.Background(), entry) + if err != nil { + logutil.BgLogger().Error("global config syncer store failed", zap.Error(err)) + } + // TODO(crazycs520): Add owner to maintain global config is consistency with global variable. + case <-do.exit: + return + } + } +} + +func (do *Domain) topologySyncerKeeper() { + defer util.Recover(metrics.LabelDomain, "topologySyncerKeeper", nil, false) + ticker := time.NewTicker(infosync.TopologyTimeToRefresh) + defer func() { + ticker.Stop() + logutil.BgLogger().Info("topologySyncerKeeper exited.") + }() + + for { + select { + case <-ticker.C: + err := do.info.StoreTopologyInfo(context.Background()) + if err != nil { + logutil.BgLogger().Error("refresh topology in loop failed", zap.Error(err)) + } + case <-do.info.TopologyDone(): + logutil.BgLogger().Info("server topology syncer need to restart") + if err := do.info.RestartTopology(context.Background()); err != nil { + logutil.BgLogger().Error("server topology syncer restart failed", zap.Error(err)) + } else { + logutil.BgLogger().Info("server topology syncer restarted") + } + case <-do.exit: + return + } + } +} + +func (do *Domain) refreshMDLCheckTableInfo() { + se, err := do.sysSessionPool.Get() + + if err != nil { + logutil.BgLogger().Warn("get system session failed", zap.Error(err)) + return + } + // Make sure the session is new. 
+ sctx := se.(sessionctx.Context) + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnMeta) + if _, err := sctx.GetSQLExecutor().ExecuteInternal(ctx, "rollback"); err != nil { + se.Close() + return + } + defer do.sysSessionPool.Put(se) + exec := sctx.GetRestrictedSQLExecutor() + domainSchemaVer := do.InfoSchema().SchemaMetaVersion() + // the job must stay inside tidb_ddl_job if we need to wait schema version for it. + sql := fmt.Sprintf(`select job_id, version, table_ids from mysql.tidb_mdl_info + where job_id >= %d and version <= %d`, do.minJobIDRefresher.GetCurrMinJobID(), domainSchemaVer) + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql) + if err != nil { + logutil.BgLogger().Warn("get mdl info from tidb_mdl_info failed", zap.Error(err)) + return + } + do.mdlCheckTableInfo.mu.Lock() + defer do.mdlCheckTableInfo.mu.Unlock() + + do.mdlCheckTableInfo.newestVer = domainSchemaVer + do.mdlCheckTableInfo.jobsVerMap = make(map[int64]int64, len(rows)) + do.mdlCheckTableInfo.jobsIDsMap = make(map[int64]string, len(rows)) + for i := 0; i < len(rows); i++ { + do.mdlCheckTableInfo.jobsVerMap[rows[i].GetInt64(0)] = rows[i].GetInt64(1) + do.mdlCheckTableInfo.jobsIDsMap[rows[i].GetInt64(0)] = rows[i].GetString(2) + } +} + +func (do *Domain) mdlCheckLoop() { + ticker := time.Tick(mdlCheckLookDuration) + var saveMaxSchemaVersion int64 + jobNeedToSync := false + jobCache := make(map[int64]int64, 1000) + + for { + // Wait for channels + select { + case <-do.mdlCheckCh: + case <-ticker: + case <-do.exit: + return + } + + if !variable.EnableMDL.Load() { + continue + } + + do.mdlCheckTableInfo.mu.Lock() + maxVer := do.mdlCheckTableInfo.newestVer + if maxVer > saveMaxSchemaVersion { + saveMaxSchemaVersion = maxVer + } else if !jobNeedToSync { + // Schema doesn't change, and no job to check in the last run. + do.mdlCheckTableInfo.mu.Unlock() + continue + } + + jobNeedToCheckCnt := len(do.mdlCheckTableInfo.jobsVerMap) + if jobNeedToCheckCnt == 0 { + jobNeedToSync = false + do.mdlCheckTableInfo.mu.Unlock() + continue + } + + jobsVerMap := make(map[int64]int64, len(do.mdlCheckTableInfo.jobsVerMap)) + jobsIDsMap := make(map[int64]string, len(do.mdlCheckTableInfo.jobsIDsMap)) + for k, v := range do.mdlCheckTableInfo.jobsVerMap { + jobsVerMap[k] = v + } + for k, v := range do.mdlCheckTableInfo.jobsIDsMap { + jobsIDsMap[k] = v + } + do.mdlCheckTableInfo.mu.Unlock() + + jobNeedToSync = true + + sm := do.InfoSyncer().GetSessionManager() + if sm == nil { + logutil.BgLogger().Info("session manager is nil") + } else { + sm.CheckOldRunningTxn(jobsVerMap, jobsIDsMap) + } + + if len(jobsVerMap) == jobNeedToCheckCnt { + jobNeedToSync = false + } + + // Try to gc jobCache. + if len(jobCache) > 1000 { + jobCache = make(map[int64]int64, 1000) + } + + for jobID, ver := range jobsVerMap { + if cver, ok := jobCache[jobID]; ok && cver >= ver { + // Already update, skip it. 
+ continue + } + logutil.BgLogger().Info("mdl gets lock, update self version to owner", zap.Int64("jobID", jobID), zap.Int64("version", ver)) + err := do.ddl.SchemaSyncer().UpdateSelfVersion(context.Background(), jobID, ver) + if err != nil { + jobNeedToSync = true + logutil.BgLogger().Warn("mdl gets lock, update self version to owner failed", + zap.Int64("jobID", jobID), zap.Int64("version", ver), zap.Error(err)) + } else { + jobCache[jobID] = ver + } + } + } +} + +func (do *Domain) loadSchemaInLoop(ctx context.Context, lease time.Duration) { + defer util.Recover(metrics.LabelDomain, "loadSchemaInLoop", nil, true) + // Lease renewal can run at any frequency. + // Use lease/2 here as recommend by paper. + ticker := time.NewTicker(lease / 2) + defer func() { + ticker.Stop() + logutil.BgLogger().Info("loadSchemaInLoop exited.") + }() + syncer := do.ddl.SchemaSyncer() + + for { + select { + case <-ticker.C: + failpoint.Inject("disableOnTickReload", func() { + failpoint.Continue() + }) + err := do.Reload() + if err != nil { + logutil.BgLogger().Error("reload schema in loop failed", zap.Error(err)) + } + do.deferFn.check() + case _, ok := <-syncer.GlobalVersionCh(): + err := do.Reload() + if err != nil { + logutil.BgLogger().Error("reload schema in loop failed", zap.Error(err)) + } + if !ok { + logutil.BgLogger().Warn("reload schema in loop, schema syncer need rewatch") + // Make sure the rewatch doesn't affect load schema, so we watch the global schema version asynchronously. + syncer.WatchGlobalSchemaVer(context.Background()) + } + case <-syncer.Done(): + // The schema syncer stops, we need stop the schema validator to synchronize the schema version. + logutil.BgLogger().Info("reload schema in loop, schema syncer need restart") + // The etcd is responsible for schema synchronization, we should ensure there is at most two different schema version + // in the TiDB cluster, to make the data/schema be consistent. If we lost connection/session to etcd, the cluster + // will treats this TiDB as a down instance, and etcd will remove the key of `/tidb/ddl/all_schema_versions/tidb-id`. + // Say the schema version now is 1, the owner is changing the schema version to 2, it will not wait for this down TiDB syncing the schema, + // then continue to change the TiDB schema to version 3. Unfortunately, this down TiDB schema version will still be version 1. + // And version 1 is not consistent to version 3. So we need to stop the schema validator to prohibit the DML executing. + do.SchemaValidator.Stop() + err := do.mustRestartSyncer(ctx) + if err != nil { + logutil.BgLogger().Error("reload schema in loop, schema syncer restart failed", zap.Error(err)) + break + } + // The schema maybe changed, must reload schema then the schema validator can restart. + exitLoop := do.mustReload() + // domain is closed. + if exitLoop { + logutil.BgLogger().Error("domain is closed, exit loadSchemaInLoop") + return + } + do.SchemaValidator.Restart() + logutil.BgLogger().Info("schema syncer restarted") + case <-do.exit: + return + } + do.refreshMDLCheckTableInfo() + select { + case do.mdlCheckCh <- struct{}{}: + default: + } + } +} + +// mustRestartSyncer tries to restart the SchemaSyncer. +// It returns until it's successful or the domain is stopped. +func (do *Domain) mustRestartSyncer(ctx context.Context) error { + syncer := do.ddl.SchemaSyncer() + + for { + err := syncer.Restart(ctx) + if err == nil { + return nil + } + // If the domain has stopped, we return an error immediately. 
+ if do.isClose() { + return err + } + logutil.BgLogger().Error("restart the schema syncer failed", zap.Error(err)) + time.Sleep(time.Second) + } +} + +// mustReload tries to Reload the schema, it returns until it's successful or the domain is closed. +// it returns false when it is successful, returns true when the domain is closed. +func (do *Domain) mustReload() (exitLoop bool) { + for { + err := do.Reload() + if err == nil { + logutil.BgLogger().Info("mustReload succeed") + return false + } + + // If the domain is closed, we returns immediately. + logutil.BgLogger().Info("reload the schema failed", zap.Error(err)) + if do.isClose() { + return true + } + time.Sleep(200 * time.Millisecond) + } +} + +func (do *Domain) isClose() bool { + select { + case <-do.exit: + logutil.BgLogger().Info("domain is closed") + return true + default: + } + return false +} + +// Close closes the Domain and release its resource. +func (do *Domain) Close() { + if do == nil { + return + } + startTime := time.Now() + if do.ddl != nil { + terror.Log(do.ddl.Stop()) + } + if do.info != nil { + do.info.RemoveServerInfo() + do.info.RemoveMinStartTS() + } + ttlJobManager := do.ttlJobManager.Load() + if ttlJobManager != nil { + logutil.BgLogger().Info("stopping ttlJobManager") + ttlJobManager.Stop() + err := ttlJobManager.WaitStopped(context.Background(), func() time.Duration { + if intest.InTest { + return 10 * time.Second + } + return 30 * time.Second + }()) + if err != nil { + logutil.BgLogger().Warn("fail to wait until the ttl job manager stop", zap.Error(err)) + } else { + logutil.BgLogger().Info("ttlJobManager exited.") + } + } + do.releaseServerID(context.Background()) + close(do.exit) + if do.etcdClient != nil { + terror.Log(errors.Trace(do.etcdClient.Close())) + } + + do.runawayManager.Stop() + + if do.unprefixedEtcdCli != nil { + terror.Log(errors.Trace(do.unprefixedEtcdCli.Close())) + } + + do.slowQuery.Close() + do.cancelFns.mu.Lock() + for _, f := range do.cancelFns.fns { + f() + } + do.cancelFns.mu.Unlock() + do.wg.Wait() + do.sysSessionPool.Close() + variable.UnregisterStatistics(do.BindHandle()) + if do.onClose != nil { + do.onClose() + } + gctuner.WaitMemoryLimitTunerExitInTest() + close(do.mdlCheckCh) + + // close MockGlobalServerInfoManagerEntry in order to refresh mock server info. + if intest.InTest { + infosync.MockGlobalServerInfoManagerEntry.Close() + } + if handle := do.statsHandle.Load(); handle != nil { + handle.Close() + } + + logutil.BgLogger().Info("domain closed", zap.Duration("take time", time.Since(startTime))) +} + +const resourceIdleTimeout = 3 * time.Minute // resources in the ResourcePool will be recycled after idleTimeout + +// NewDomain creates a new domain. Should not create multiple domains for the same store. 
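isClose above relies on a property of Go channels that several of these loops use: receiving from a closed channel never blocks, so a select with a default branch turns the exit channel into a cheap "has the domain shut down" probe. A minimal standalone example of that idiom:

package main

import "fmt"

// closedCheck shows the non-blocking select used by isClose: reading from a closed
// channel succeeds immediately, while an open channel falls through to default.
func closedCheck(exit chan struct{}) bool {
	select {
	case <-exit:
		return true
	default:
		return false
	}
}

func main() {
	exit := make(chan struct{})
	fmt.Println(closedCheck(exit)) // false: still running
	close(exit)
	fmt.Println(closedCheck(exit)) // true: shut down
}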
+func NewDomain(store kv.Storage, ddlLease time.Duration, statsLease time.Duration, dumpFileGcLease time.Duration, factory pools.Factory) *Domain { + capacity := 200 // capacity of the sysSessionPool size + do := &Domain{ + store: store, + exit: make(chan struct{}), + sysSessionPool: util.NewSessionPool( + capacity, factory, + func(r pools.Resource) { + _, ok := r.(sessionctx.Context) + intest.Assert(ok) + infosync.StoreInternalSession(r) + }, + func(r pools.Resource) { + _, ok := r.(sessionctx.Context) + intest.Assert(ok) + infosync.DeleteInternalSession(r) + }, + ), + statsLease: statsLease, + slowQuery: newTopNSlowQueries(config.GetGlobalConfig().InMemSlowQueryTopNNum, time.Hour*24*7, config.GetGlobalConfig().InMemSlowQueryRecentNum), + dumpFileGcChecker: &dumpFileGcChecker{gcLease: dumpFileGcLease, paths: []string{replayer.GetPlanReplayerDirName(), GetOptimizerTraceDirName(), GetExtractTaskDirName()}}, + mdlCheckTableInfo: &mdlCheckTableInfo{ + mu: sync.Mutex{}, + jobsVerMap: make(map[int64]int64), + jobsIDsMap: make(map[int64]string), + }, + mdlCheckCh: make(chan struct{}), + } + + do.infoCache = infoschema.NewCache(do, int(variable.SchemaVersionCacheLimit.Load())) + do.stopAutoAnalyze.Store(false) + do.wg = util.NewWaitGroupEnhancedWrapper("domain", do.exit, config.GetGlobalConfig().TiDBEnableExitCheck) + do.SchemaValidator = NewSchemaValidator(ddlLease, do) + do.expensiveQueryHandle = expensivequery.NewExpensiveQueryHandle(do.exit) + do.memoryUsageAlarmHandle = memoryusagealarm.NewMemoryUsageAlarmHandle(do.exit) + do.serverMemoryLimitHandle = servermemorylimit.NewServerMemoryLimitHandle(do.exit) + do.sysProcesses = SysProcesses{mu: &sync.RWMutex{}, procMap: make(map[uint64]sysproctrack.TrackProc)} + do.initDomainSysVars() + do.expiredTimeStamp4PC.expiredTimeStamp = types.NewTime(types.ZeroCoreTime, mysql.TypeTimestamp, types.DefaultFsp) + return do +} + +const serverIDForStandalone = 1 // serverID for standalone deployment. + +func newEtcdCli(addrs []string, ebd kv.EtcdBackend) (*clientv3.Client, error) { + cfg := config.GetGlobalConfig() + etcdLogCfg := zap.NewProductionConfig() + etcdLogCfg.Level = zap.NewAtomicLevelAt(zap.ErrorLevel) + backoffCfg := backoff.DefaultConfig + backoffCfg.MaxDelay = 3 * time.Second + cli, err := clientv3.New(clientv3.Config{ + LogConfig: &etcdLogCfg, + Endpoints: addrs, + AutoSyncInterval: 30 * time.Second, + DialTimeout: 5 * time.Second, + DialOptions: []grpc.DialOption{ + grpc.WithConnectParams(grpc.ConnectParams{ + Backoff: backoffCfg, + }), + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: time.Duration(cfg.TiKVClient.GrpcKeepAliveTime) * time.Second, + Timeout: time.Duration(cfg.TiKVClient.GrpcKeepAliveTimeout) * time.Second, + }), + }, + TLS: ebd.TLSConfig(), + }) + return cli, err +} + +// Init initializes a domain. after return, session can be used to do DMLs but not +// DDLs which can be used after domain Start. +func (do *Domain) Init( + ddlLease time.Duration, + sysExecutorFactory func(*Domain) (pools.Resource, error), + ddlInjector func(ddl.DDL, ddl.Executor, *infoschema.InfoCache) *schematracker.Checker, +) error { + // TODO there are many place set ddlLease to 0, remove them completely, we want + // UT and even local uni-store to run similar code path as normal. 
+ if ddlLease == 0 { + ddlLease = time.Second + } + + do.sysExecutorFactory = sysExecutorFactory + perfschema.Init() + if ebd, ok := do.store.(kv.EtcdBackend); ok { + var addrs []string + var err error + if addrs, err = ebd.EtcdAddrs(); err != nil { + return err + } + if addrs != nil { + cli, err := newEtcdCli(addrs, ebd) + if err != nil { + return errors.Trace(err) + } + + etcd.SetEtcdCliByNamespace(cli, keyspace.MakeKeyspaceEtcdNamespace(do.store.GetCodec())) + + do.etcdClient = cli + + do.autoidClient = autoid.NewClientDiscover(cli) + + unprefixedEtcdCli, err := newEtcdCli(addrs, ebd) + if err != nil { + return errors.Trace(err) + } + do.unprefixedEtcdCli = unprefixedEtcdCli + } + } + + ctx, cancelFunc := context.WithCancel(context.Background()) + do.ctx = ctx + do.cancelFns.mu.Lock() + do.cancelFns.fns = append(do.cancelFns.fns, cancelFunc) + do.cancelFns.mu.Unlock() + d := do.ddl + eBak := do.ddlExecutor + do.ddl, do.ddlExecutor = ddl.NewDDL( + ctx, + ddl.WithEtcdClient(do.etcdClient), + ddl.WithStore(do.store), + ddl.WithAutoIDClient(do.autoidClient), + ddl.WithInfoCache(do.infoCache), + ddl.WithLease(ddlLease), + ddl.WithSchemaLoader(do), + ) + + failpoint.Inject("MockReplaceDDL", func(val failpoint.Value) { + if val.(bool) { + do.ddl = d + do.ddlExecutor = eBak + } + }) + if ddlInjector != nil { + checker := ddlInjector(do.ddl, do.ddlExecutor, do.infoCache) + checker.CreateTestDB(nil) + do.ddl = checker + do.ddlExecutor = checker + } + + // step 1: prepare the info/schema syncer which domain reload needed. + pdCli, pdHTTPCli := do.GetPDClient(), do.GetPDHTTPClient() + skipRegisterToDashboard := config.GetGlobalConfig().SkipRegisterToDashboard + var err error + do.info, err = infosync.GlobalInfoSyncerInit(ctx, do.ddl.GetID(), do.ServerID, + do.etcdClient, do.unprefixedEtcdCli, pdCli, pdHTTPCli, + do.Store().GetCodec(), skipRegisterToDashboard) + if err != nil { + return err + } + do.globalCfgSyncer = globalconfigsync.NewGlobalConfigSyncer(pdCli) + err = do.ddl.SchemaSyncer().Init(ctx) + if err != nil { + return err + } + + // step 2: initialize the global kill, which depends on `globalInfoSyncer`.` + if config.GetGlobalConfig().EnableGlobalKill { + do.connIDAllocator = globalconn.NewGlobalAllocator(do.ServerID, config.GetGlobalConfig().Enable32BitsConnectionID) + + if do.etcdClient != nil { + err := do.acquireServerID(ctx) + if err != nil { + logutil.BgLogger().Error("acquire serverID failed", zap.Error(err)) + do.isLostConnectionToPD.Store(1) // will retry in `do.serverIDKeeper` + } else { + if err := do.info.StoreServerInfo(context.Background()); err != nil { + return errors.Trace(err) + } + do.isLostConnectionToPD.Store(0) + } + } else { + // set serverID for standalone deployment to enable 'KILL'. + atomic.StoreUint64(&do.serverID, serverIDForStandalone) + } + } else { + do.connIDAllocator = globalconn.NewSimpleAllocator() + } + + // should put `initResourceGroupsController` after fetching server ID + err = do.initResourceGroupsController(ctx, pdCli, do.ServerID()) + if err != nil { + return err + } + + startReloadTime := time.Now() + // step 3: domain reload the infoSchema. + err = do.Reload() + if err != nil { + return err + } + + sub := time.Since(startReloadTime) + // The reload(in step 2) operation takes more than ddlLease and a new reload operation was not performed, + // the next query will respond by ErrInfoSchemaExpired error. So we do a new reload to update schemaValidator.latestSchemaExpire. 
+ if sub > (ddlLease / 2) { + logutil.BgLogger().Warn("loading schema and starting ddl take a long time, we do a new reload", zap.Duration("take time", sub)) + err = do.Reload() + if err != nil { + return err + } + } + return nil +} + +// Start starts the domain. After start, DDLs can be executed using session, see +// Init also. +func (do *Domain) Start() error { + gCfg := config.GetGlobalConfig() + if gCfg.EnableGlobalKill && do.etcdClient != nil { + do.wg.Add(1) + go do.serverIDKeeper() + } + + // TODO: Here we create new sessions with sysFac in DDL, + // which will use `do` as Domain instead of call `domap.Get`. + // That's because `domap.Get` requires a lock, but before + // we initialize Domain finish, we can't require that again. + // After we remove the lazy logic of creating Domain, we + // can simplify code here. + sysFac := func() (pools.Resource, error) { + return do.sysExecutorFactory(do) + } + sysCtxPool := pools.NewResourcePool(sysFac, 512, 512, resourceIdleTimeout) + + // start the ddl after the domain reload, avoiding some internal sql running before infoSchema construction. + err := do.ddl.Start(sysCtxPool) + if err != nil { + return err + } + do.minJobIDRefresher = do.ddl.GetMinJobIDRefresher() + + // Local store needs to get the change information for every DDL state in each session. + do.wg.Run(func() { + do.loadSchemaInLoop(do.ctx, do.ddl.GetLease()) + }, "loadSchemaInLoop") + do.wg.Run(do.mdlCheckLoop, "mdlCheckLoop") + do.wg.Run(do.topNSlowQueryLoop, "topNSlowQueryLoop") + do.wg.Run(do.infoSyncerKeeper, "infoSyncerKeeper") + do.wg.Run(do.globalConfigSyncerKeeper, "globalConfigSyncerKeeper") + do.wg.Run(do.runawayStartLoop, "runawayStartLoop") + do.wg.Run(do.requestUnitsWriterLoop, "requestUnitsWriterLoop") + skipRegisterToDashboard := gCfg.SkipRegisterToDashboard + if !skipRegisterToDashboard { + do.wg.Run(do.topologySyncerKeeper, "topologySyncerKeeper") + } + pdCli := do.GetPDClient() + if pdCli != nil { + do.wg.Run(func() { + do.closestReplicaReadCheckLoop(do.ctx, pdCli) + }, "closestReplicaReadCheckLoop") + } + + err = do.initLogBackup(do.ctx, pdCli) + if err != nil { + return err + } + + return nil +} + +// InitInfo4Test init infosync for distributed execution test. +func (do *Domain) InitInfo4Test() { + infosync.MockGlobalServerInfoManagerEntry.Add(do.ddl.GetID(), do.ServerID) +} + +// SetOnClose used to set do.onClose func. +func (do *Domain) SetOnClose(onClose func()) { + do.onClose = onClose +} + +func (do *Domain) initLogBackup(ctx context.Context, pdClient pd.Client) error { + cfg := config.GetGlobalConfig() + if pdClient == nil || do.etcdClient == nil { + log.Warn("pd / etcd client not provided, won't begin Advancer.") + return nil + } + tikvStore, ok := do.Store().(tikv.Storage) + if !ok { + log.Warn("non tikv store, stop begin Advancer.") + return nil + } + env, err := streamhelper.TiDBEnv(tikvStore, pdClient, do.etcdClient, cfg) + if err != nil { + return err + } + adv := streamhelper.NewCheckpointAdvancer(env) + do.logBackupAdvancer = daemon.New(adv, streamhelper.OwnerManagerForLogBackup(ctx, do.etcdClient), adv.Config().TickDuration) + loop, err := do.logBackupAdvancer.Begin(ctx) + if err != nil { + return err + } + do.wg.Run(loop, "logBackupAdvancer") + return nil +} + +// when tidb_replica_read = 'closest-adaptive', check tidb and tikv's zone label matches. +// if not match, disable replica_read to avoid uneven read traffic distribution. 
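The extra reload at the end of Init is driven by a simple timing rule: if the first reload already consumed more than half of the DDL lease, the schema validator's window may be close to expiring, so one more reload is performed before the domain is handed out. A tiny sketch of that check, with made-up durations:

package main

import (
	"fmt"
	"time"
)

// needExtraReload mirrors the check above: when the previous reload took longer than
// half of the DDL lease, another reload is done to refresh the validator's expiry.
func needExtraReload(elapsed, ddlLease time.Duration) bool {
	return elapsed > ddlLease/2
}

func main() {
	lease := 45 * time.Second
	fmt.Println(needExtraReload(10*time.Second, lease)) // false: fast enough
	fmt.Println(needExtraReload(30*time.Second, lease)) // true: reload once more
}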
+func (do *Domain) closestReplicaReadCheckLoop(ctx context.Context, pdClient pd.Client) { + defer util.Recover(metrics.LabelDomain, "closestReplicaReadCheckLoop", nil, false) + + // trigger check once instantly. + if err := do.checkReplicaRead(ctx, pdClient); err != nil { + logutil.BgLogger().Warn("refresh replicaRead flag failed", zap.Error(err)) + } + + ticker := time.NewTicker(time.Minute) + defer func() { + ticker.Stop() + logutil.BgLogger().Info("closestReplicaReadCheckLoop exited.") + }() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := do.checkReplicaRead(ctx, pdClient); err != nil { + logutil.BgLogger().Warn("refresh replicaRead flag failed", zap.Error(err)) + } + } + } +} + +// Periodically check and update the replica-read status when `tidb_replica_read` is set to "closest-adaptive" +// We disable "closest-adaptive" in following conditions to ensure the read traffic is evenly distributed across +// all AZs: +// - There are no TiKV servers in the AZ of this tidb instance +// - The AZ if this tidb contains more tidb than other AZ and this tidb's id is the bigger one. +func (do *Domain) checkReplicaRead(ctx context.Context, pdClient pd.Client) error { + do.sysVarCache.RLock() + replicaRead := do.sysVarCache.global[variable.TiDBReplicaRead] + do.sysVarCache.RUnlock() + + if !strings.EqualFold(replicaRead, "closest-adaptive") { + logutil.BgLogger().Debug("closest replica read is not enabled, skip check!", zap.String("mode", replicaRead)) + return nil + } + + serverInfo, err := infosync.GetServerInfo() + if err != nil { + return err + } + zone := "" + for k, v := range serverInfo.Labels { + if k == placement.DCLabelKey && v != "" { + zone = v + break + } + } + if zone == "" { + logutil.BgLogger().Debug("server contains no 'zone' label, disable closest replica read", zap.Any("labels", serverInfo.Labels)) + variable.SetEnableAdaptiveReplicaRead(false) + return nil + } + + stores, err := pdClient.GetAllStores(ctx, pd.WithExcludeTombstone()) + if err != nil { + return err + } + + storeZones := make(map[string]int) + for _, s := range stores { + // skip tumbstone stores or tiflash + if s.NodeState == metapb.NodeState_Removing || s.NodeState == metapb.NodeState_Removed || engine.IsTiFlash(s) { + continue + } + for _, label := range s.Labels { + if label.Key == placement.DCLabelKey && label.Value != "" { + storeZones[label.Value] = 0 + break + } + } + } + + // no stores in this AZ + if _, ok := storeZones[zone]; !ok { + variable.SetEnableAdaptiveReplicaRead(false) + return nil + } + + servers, err := infosync.GetAllServerInfo(ctx) + if err != nil { + return err + } + svrIDsInThisZone := make([]string, 0) + for _, s := range servers { + if v, ok := s.Labels[placement.DCLabelKey]; ok && v != "" { + if _, ok := storeZones[v]; ok { + storeZones[v]++ + if v == zone { + svrIDsInThisZone = append(svrIDsInThisZone, s.ID) + } + } + } + } + enabledCount := math.MaxInt + for _, count := range storeZones { + if count < enabledCount { + enabledCount = count + } + } + // sort tidb in the same AZ by ID and disable the tidb with bigger ID + // because ID is unchangeable, so this is a simple and stable algorithm to select + // some instances across all tidb servers. 
+ if enabledCount < len(svrIDsInThisZone) { + sort.Slice(svrIDsInThisZone, func(i, j int) bool { + return strings.Compare(svrIDsInThisZone[i], svrIDsInThisZone[j]) < 0 + }) + } + enabled := true + for _, s := range svrIDsInThisZone[enabledCount:] { + if s == serverInfo.ID { + enabled = false + break + } + } + + if variable.SetEnableAdaptiveReplicaRead(enabled) { + logutil.BgLogger().Info("tidb server adaptive closest replica read is changed", zap.Bool("enable", enabled)) + } + return nil +} + +// InitDistTaskLoop initializes the distributed task framework. +func (do *Domain) InitDistTaskLoop() error { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalDistTask) + failpoint.Inject("MockDisableDistTask", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(nil) + } + }) + + taskManager := storage.NewTaskManager(do.sysSessionPool) + var serverID string + if intest.InTest { + do.InitInfo4Test() + serverID = disttaskutil.GenerateSubtaskExecID4Test(do.ddl.GetID()) + } else { + serverID = disttaskutil.GenerateSubtaskExecID(ctx, do.ddl.GetID()) + } + + if serverID == "" { + errMsg := fmt.Sprintf("TiDB node ID( = %s ) not found in available TiDB nodes list", do.ddl.GetID()) + return errors.New(errMsg) + } + managerCtx, cancel := context.WithCancel(ctx) + do.cancelFns.mu.Lock() + do.cancelFns.fns = append(do.cancelFns.fns, cancel) + do.cancelFns.mu.Unlock() + executorManager, err := taskexecutor.NewManager(managerCtx, serverID, taskManager) + if err != nil { + return err + } + + storage.SetTaskManager(taskManager) + if err = executorManager.InitMeta(); err != nil { + // executor manager loop will try to recover meta repeatedly, so we can + // just log the error here. + logutil.BgLogger().Warn("init task executor manager meta failed", zap.Error(err)) + } + do.wg.Run(func() { + defer func() { + storage.SetTaskManager(nil) + }() + do.distTaskFrameworkLoop(ctx, taskManager, executorManager, serverID) + }, "distTaskFrameworkLoop") + return nil +} + +func (do *Domain) distTaskFrameworkLoop(ctx context.Context, taskManager *storage.TaskManager, executorManager *taskexecutor.Manager, serverID string) { + err := executorManager.Start() + if err != nil { + logutil.BgLogger().Error("dist task executor manager start failed", zap.Error(err)) + return + } + logutil.BgLogger().Info("dist task executor manager started") + defer func() { + logutil.BgLogger().Info("stopping dist task executor manager") + executorManager.Stop() + logutil.BgLogger().Info("dist task executor manager stopped") + }() + + var schedulerManager *scheduler.Manager + startSchedulerMgrIfNeeded := func() { + if schedulerManager != nil && schedulerManager.Initialized() { + return + } + schedulerManager = scheduler.NewManager(ctx, taskManager, serverID) + schedulerManager.Start() + } + stopSchedulerMgrIfNeeded := func() { + if schedulerManager != nil && schedulerManager.Initialized() { + logutil.BgLogger().Info("stopping dist task scheduler manager because the current node is not DDL owner anymore", zap.String("id", do.ddl.GetID())) + schedulerManager.Stop() + logutil.BgLogger().Info("dist task scheduler manager stopped", zap.String("id", do.ddl.GetID())) + } + } + + ticker := time.NewTicker(time.Second) + for { + select { + case <-do.exit: + stopSchedulerMgrIfNeeded() + return + case <-ticker.C: + if do.ddl.OwnerManager().IsOwner() { + startSchedulerMgrIfNeeded() + } else { + stopSchedulerMgrIfNeeded() + } + } + } +} + +// SysSessionPool returns the system session pool. 
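The zone balancing above can be summarized as: count the tidb servers per zone, take the smallest count as the budget every zone may keep, and inside the local zone keep only the servers with the smallest IDs enabled; because IDs are stable, every node reaches the same decision independently. A standalone sketch of the selection step with illustrative server names:

package main

import (
	"fmt"
	"sort"
)

// pickEnabled mirrors the balancing rule described above: every zone may keep at most
// minCount adaptive readers, and within a zone the servers with the smallest IDs win.
func pickEnabled(idsInZone []string, minCount int) map[string]bool {
	sort.Strings(idsInZone)
	enabled := make(map[string]bool, len(idsInZone))
	for i, id := range idsInZone {
		enabled[id] = i < minCount
	}
	return enabled
}

func main() {
	// This zone has 3 tidb servers but the smallest zone only has 2,
	// so the server with the biggest ID here is disabled.
	fmt.Println(pickEnabled([]string{"tidb-c", "tidb-a", "tidb-b"}, 2))
	// map[tidb-a:true tidb-b:true tidb-c:false]
}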
+func (do *Domain) SysSessionPool() util.SessionPool { + return do.sysSessionPool +} + +// SysProcTracker returns the system processes tracker. +func (do *Domain) SysProcTracker() sysproctrack.Tracker { + return &do.sysProcesses +} + +// GetEtcdClient returns the etcd client. +func (do *Domain) GetEtcdClient() *clientv3.Client { + return do.etcdClient +} + +// AutoIDClient returns the autoid client. +func (do *Domain) AutoIDClient() *autoid.ClientDiscover { + return do.autoidClient +} + +// GetPDClient returns the PD client. +func (do *Domain) GetPDClient() pd.Client { + if store, ok := do.store.(kv.StorageWithPD); ok { + return store.GetPDClient() + } + return nil +} + +// GetPDHTTPClient returns the PD HTTP client. +func (do *Domain) GetPDHTTPClient() pdhttp.Client { + if store, ok := do.store.(kv.StorageWithPD); ok { + return store.GetPDHTTPClient() + } + return nil +} + +// LoadPrivilegeLoop create a goroutine loads privilege tables in a loop, it +// should be called only once in BootstrapSession. +func (do *Domain) LoadPrivilegeLoop(sctx sessionctx.Context) error { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnPrivilege) + sctx.GetSessionVars().InRestrictedSQL = true + _, err := sctx.GetSQLExecutor().ExecuteInternal(ctx, "set @@autocommit = 1") + if err != nil { + return err + } + do.privHandle = privileges.NewHandle() + err = do.privHandle.Update(sctx) + if err != nil { + return err + } + + var watchCh clientv3.WatchChan + duration := 5 * time.Minute + if do.etcdClient != nil { + watchCh = do.etcdClient.Watch(context.Background(), privilegeKey) + duration = 10 * time.Minute + } + + do.wg.Run(func() { + defer func() { + logutil.BgLogger().Info("loadPrivilegeInLoop exited.") + }() + defer util.Recover(metrics.LabelDomain, "loadPrivilegeInLoop", nil, false) + + var count int + for { + ok := true + select { + case <-do.exit: + return + case _, ok = <-watchCh: + case <-time.After(duration): + } + if !ok { + logutil.BgLogger().Error("load privilege loop watch channel closed") + watchCh = do.etcdClient.Watch(context.Background(), privilegeKey) + count++ + if count > 10 { + time.Sleep(time.Duration(count) * time.Second) + } + continue + } + + count = 0 + err := do.privHandle.Update(sctx) + metrics.LoadPrivilegeCounter.WithLabelValues(metrics.RetLabel(err)).Inc() + if err != nil { + logutil.BgLogger().Error("load privilege failed", zap.Error(err)) + } + } + }, "loadPrivilegeInLoop") + return nil +} + +// LoadSysVarCacheLoop create a goroutine loads sysvar cache in a loop, +// it should be called only once in BootstrapSession. +func (do *Domain) LoadSysVarCacheLoop(ctx sessionctx.Context) error { + ctx.GetSessionVars().InRestrictedSQL = true + err := do.rebuildSysVarCache(ctx) + if err != nil { + return err + } + var watchCh clientv3.WatchChan + duration := 30 * time.Second + if do.etcdClient != nil { + watchCh = do.etcdClient.Watch(context.Background(), sysVarCacheKey) + } + + do.wg.Run(func() { + defer func() { + logutil.BgLogger().Info("LoadSysVarCacheLoop exited.") + }() + defer util.Recover(metrics.LabelDomain, "LoadSysVarCacheLoop", nil, false) + + var count int + for { + ok := true + select { + case <-do.exit: + return + case _, ok = <-watchCh: + case <-time.After(duration): + } + + failpoint.Inject("skipLoadSysVarCacheLoop", func(val failpoint.Value) { + // In some pkg integration test, there are many testSuite, and each testSuite has separate storage and + // `LoadSysVarCacheLoop` background goroutine. 
Then each testSuite `RebuildSysVarCache` from it's + // own storage. + // Each testSuit will also call `checkEnableServerGlobalVar` to update some local variables. + // That's the problem, each testSuit use different storage to update some same local variables. + // So just skip `RebuildSysVarCache` in some integration testing. + if val.(bool) { + failpoint.Continue() + } + }) + + if !ok { + logutil.BgLogger().Error("LoadSysVarCacheLoop loop watch channel closed") + watchCh = do.etcdClient.Watch(context.Background(), sysVarCacheKey) + count++ + if count > 10 { + time.Sleep(time.Duration(count) * time.Second) + } + continue + } + count = 0 + logutil.BgLogger().Debug("Rebuilding sysvar cache from etcd watch event.") + err := do.rebuildSysVarCache(ctx) + metrics.LoadSysVarCacheCounter.WithLabelValues(metrics.RetLabel(err)).Inc() + if err != nil { + logutil.BgLogger().Error("LoadSysVarCacheLoop failed", zap.Error(err)) + } + } + }, "LoadSysVarCacheLoop") + return nil +} + +// WatchTiFlashComputeNodeChange create a routine to watch if the topology of tiflash_compute node is changed. +// TODO: tiflashComputeNodeKey is not put to etcd yet(finish this when AutoScaler is done) +// +// store cache will only be invalidated every n seconds. +func (do *Domain) WatchTiFlashComputeNodeChange() error { + var watchCh clientv3.WatchChan + if do.etcdClient != nil { + watchCh = do.etcdClient.Watch(context.Background(), tiflashComputeNodeKey) + } + duration := 10 * time.Second + do.wg.Run(func() { + defer func() { + logutil.BgLogger().Info("WatchTiFlashComputeNodeChange exit") + }() + defer util.Recover(metrics.LabelDomain, "WatchTiFlashComputeNodeChange", nil, false) + + var count int + var logCount int + for { + ok := true + var watched bool + select { + case <-do.exit: + return + case _, ok = <-watchCh: + watched = true + case <-time.After(duration): + } + if !ok { + logutil.BgLogger().Error("WatchTiFlashComputeNodeChange watch channel closed") + watchCh = do.etcdClient.Watch(context.Background(), tiflashComputeNodeKey) + count++ + if count > 10 { + time.Sleep(time.Duration(count) * time.Second) + } + continue + } + count = 0 + switch s := do.store.(type) { + case tikv.Storage: + logCount++ + s.GetRegionCache().InvalidateTiFlashComputeStores() + if logCount == 6 { + // Print log every 6*duration seconds. + logutil.BgLogger().Debug("tiflash_compute store cache invalied, will update next query", zap.Bool("watched", watched)) + logCount = 0 + } + default: + logutil.BgLogger().Debug("No need to watch tiflash_compute store cache for non-tikv store") + return + } + } + }, "WatchTiFlashComputeNodeChange") + return nil +} + +// PrivilegeHandle returns the MySQLPrivilege. +func (do *Domain) PrivilegeHandle() *privileges.Handle { + return do.privHandle +} + +// BindHandle returns domain's bindHandle. +func (do *Domain) BindHandle() bindinfo.GlobalBindingHandle { + v := do.bindHandle.Load() + if v == nil { + return nil + } + return v.(bindinfo.GlobalBindingHandle) +} + +// LoadBindInfoLoop create a goroutine loads BindInfo in a loop, it should +// be called only once in BootstrapSession. 
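The privilege, sysvar and tiflash_compute watch loops above all share the same recovery shape: when the etcd watch channel is closed they re-create the watch, and after ten consecutive closures they start sleeping count seconds between attempts so a flapping etcd is not hammered. A sketch of just that backoff rule, with illustrative naming:

package main

import (
	"fmt"
	"time"
)

// rewatchDelay mirrors the backoff used by the watch loops above: the first ten
// consecutive channel closures re-watch immediately, after that the loop sleeps
// `count` seconds before trying again.
func rewatchDelay(consecutiveClosures int) time.Duration {
	if consecutiveClosures > 10 {
		return time.Duration(consecutiveClosures) * time.Second
	}
	return 0
}

func main() {
	for _, n := range []int{1, 10, 11, 15} {
		fmt.Printf("closure #%d -> wait %v before re-watching\n", n, rewatchDelay(n))
	}
}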
+func (do *Domain) LoadBindInfoLoop(ctxForHandle sessionctx.Context, ctxForEvolve sessionctx.Context) error { + ctxForHandle.GetSessionVars().InRestrictedSQL = true + ctxForEvolve.GetSessionVars().InRestrictedSQL = true + if !do.bindHandle.CompareAndSwap(nil, bindinfo.NewGlobalBindingHandle(do.sysSessionPool)) { + do.BindHandle().Reset() + } + + err := do.BindHandle().LoadFromStorageToCache(true) + if err != nil || bindinfo.Lease == 0 { + return err + } + + owner := do.newOwnerManager(bindinfo.Prompt, bindinfo.OwnerKey) + do.globalBindHandleWorkerLoop(owner) + return nil +} + +func (do *Domain) globalBindHandleWorkerLoop(owner owner.Manager) { + do.wg.Run(func() { + defer func() { + logutil.BgLogger().Info("globalBindHandleWorkerLoop exited.") + }() + defer util.Recover(metrics.LabelDomain, "globalBindHandleWorkerLoop", nil, false) + + bindWorkerTicker := time.NewTicker(bindinfo.Lease) + gcBindTicker := time.NewTicker(100 * bindinfo.Lease) + defer func() { + bindWorkerTicker.Stop() + gcBindTicker.Stop() + }() + for { + select { + case <-do.exit: + owner.Cancel() + return + case <-bindWorkerTicker.C: + bindHandle := do.BindHandle() + err := bindHandle.LoadFromStorageToCache(false) + if err != nil { + logutil.BgLogger().Error("update bindinfo failed", zap.Error(err)) + } + bindHandle.DropInvalidGlobalBinding() + // Get Global + optVal, err := do.GetGlobalVar(variable.TiDBCapturePlanBaseline) + if err == nil && variable.TiDBOptOn(optVal) { + bindHandle.CaptureBaselines() + } + case <-gcBindTicker.C: + if !owner.IsOwner() { + continue + } + err := do.BindHandle().GCGlobalBinding() + if err != nil { + logutil.BgLogger().Error("GC bind record failed", zap.Error(err)) + } + } + } + }, "globalBindHandleWorkerLoop") +} + +// SetupPlanReplayerHandle setup plan replayer handle +func (do *Domain) SetupPlanReplayerHandle(collectorSctx sessionctx.Context, workersSctxs []sessionctx.Context) { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) + do.planReplayerHandle = &planReplayerHandle{} + do.planReplayerHandle.planReplayerTaskCollectorHandle = &planReplayerTaskCollectorHandle{ + ctx: ctx, + sctx: collectorSctx, + } + taskCH := make(chan *PlanReplayerDumpTask, 16) + taskStatus := &planReplayerDumpTaskStatus{} + taskStatus.finishedTaskMu.finishedTask = map[replayer.PlanReplayerTaskKey]struct{}{} + taskStatus.runningTaskMu.runningTasks = map[replayer.PlanReplayerTaskKey]struct{}{} + + do.planReplayerHandle.planReplayerTaskDumpHandle = &planReplayerTaskDumpHandle{ + taskCH: taskCH, + status: taskStatus, + } + do.planReplayerHandle.planReplayerTaskDumpHandle.workers = make([]*planReplayerTaskDumpWorker, 0) + for i := 0; i < len(workersSctxs); i++ { + worker := &planReplayerTaskDumpWorker{ + ctx: ctx, + sctx: workersSctxs[i], + taskCH: taskCH, + status: taskStatus, + } + do.planReplayerHandle.planReplayerTaskDumpHandle.workers = append(do.planReplayerHandle.planReplayerTaskDumpHandle.workers, worker) + } +} + +// RunawayManager returns the runaway manager. +func (do *Domain) RunawayManager() *resourcegroup.RunawayManager { + return do.runawayManager +} + +// ResourceGroupsController returns the resource groups controller. +func (do *Domain) ResourceGroupsController() *rmclient.ResourceGroupsController { + return do.resourceGroupsController +} + +// SetResourceGroupsController is only used in test. 
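globalBindHandleWorkerLoop above runs two cadences from one goroutine: a frequent ticker that reloads bindings and captures baselines, and a ticker one hundred times slower that garbage-collects bindings, with GC skipped on nodes that are not the current owner. A compact standalone sketch of that shape; the printed actions stand in for the real handle calls:

package main

import (
	"fmt"
	"time"
)

// twoTickerLoop sketches the worker shape above: a fast refresh ticker, a much
// slower GC ticker, and GC gated on ownership.
func twoTickerLoop(exit <-chan struct{}, lease time.Duration, isOwner func() bool) {
	refresh := time.NewTicker(lease)
	gc := time.NewTicker(100 * lease)
	defer refresh.Stop()
	defer gc.Stop()
	for {
		select {
		case <-exit:
			return
		case <-refresh.C:
			fmt.Println("reload bindings from storage")
		case <-gc.C:
			if !isOwner() {
				continue
			}
			fmt.Println("owner runs binding GC")
		}
	}
}

func main() {
	exit := make(chan struct{})
	go twoTickerLoop(exit, 10*time.Millisecond, func() bool { return true })
	time.Sleep(50 * time.Millisecond)
	close(exit)
	time.Sleep(10 * time.Millisecond)
}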
+func (do *Domain) SetResourceGroupsController(controller *rmclient.ResourceGroupsController) { + do.resourceGroupsController = controller +} + +// SetupHistoricalStatsWorker setups worker +func (do *Domain) SetupHistoricalStatsWorker(ctx sessionctx.Context) { + do.historicalStatsWorker = &HistoricalStatsWorker{ + tblCH: make(chan int64, 16), + sctx: ctx, + } +} + +// SetupDumpFileGCChecker setup sctx +func (do *Domain) SetupDumpFileGCChecker(ctx sessionctx.Context) { + do.dumpFileGcChecker.setupSctx(ctx) + do.dumpFileGcChecker.planReplayerTaskStatus = do.planReplayerHandle.status +} + +// SetupExtractHandle setups extract handler +func (do *Domain) SetupExtractHandle(sctxs []sessionctx.Context) { + do.extractTaskHandle = newExtractHandler(do.ctx, sctxs) +} + +var planReplayerHandleLease atomic.Uint64 + +func init() { + planReplayerHandleLease.Store(uint64(10 * time.Second)) + enableDumpHistoricalStats.Store(true) +} + +// DisablePlanReplayerBackgroundJob4Test disable plan replayer handle for test +func DisablePlanReplayerBackgroundJob4Test() { + planReplayerHandleLease.Store(0) +} + +// DisableDumpHistoricalStats4Test disable historical dump worker for test +func DisableDumpHistoricalStats4Test() { + enableDumpHistoricalStats.Store(false) +} + +// StartPlanReplayerHandle start plan replayer handle job +func (do *Domain) StartPlanReplayerHandle() { + lease := planReplayerHandleLease.Load() + if lease < 1 { + return + } + do.wg.Run(func() { + logutil.BgLogger().Info("PlanReplayerTaskCollectHandle started") + tikcer := time.NewTicker(time.Duration(lease)) + defer func() { + tikcer.Stop() + logutil.BgLogger().Info("PlanReplayerTaskCollectHandle exited.") + }() + defer util.Recover(metrics.LabelDomain, "PlanReplayerTaskCollectHandle", nil, false) + + for { + select { + case <-do.exit: + return + case <-tikcer.C: + err := do.planReplayerHandle.CollectPlanReplayerTask() + if err != nil { + logutil.BgLogger().Warn("plan replayer handle collect tasks failed", zap.Error(err)) + } + } + } + }, "PlanReplayerTaskCollectHandle") + + do.wg.Run(func() { + logutil.BgLogger().Info("PlanReplayerTaskDumpHandle started") + defer func() { + logutil.BgLogger().Info("PlanReplayerTaskDumpHandle exited.") + }() + defer util.Recover(metrics.LabelDomain, "PlanReplayerTaskDumpHandle", nil, false) + + for _, worker := range do.planReplayerHandle.planReplayerTaskDumpHandle.workers { + go worker.run() + } + <-do.exit + do.planReplayerHandle.planReplayerTaskDumpHandle.Close() + }, "PlanReplayerTaskDumpHandle") +} + +// GetPlanReplayerHandle returns plan replayer handle +func (do *Domain) GetPlanReplayerHandle() *planReplayerHandle { + return do.planReplayerHandle +} + +// GetExtractHandle returns extract handle +func (do *Domain) GetExtractHandle() *ExtractHandle { + return do.extractTaskHandle +} + +// GetDumpFileGCChecker returns dump file GC checker for plan replayer and plan trace +func (do *Domain) GetDumpFileGCChecker() *dumpFileGcChecker { + return do.dumpFileGcChecker +} + +// DumpFileGcCheckerLoop creates a goroutine that handles `exit` and `gc`. 
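StartPlanReplayerHandle above gates its background workers on a lease stored in an atomic, which is what lets DisablePlanReplayerBackgroundJob4Test switch the job off by storing zero. A small sketch of that on/off mechanism with illustrative names:

package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

// jobLease mimics planReplayerHandleLease: the lease lives in an atomic so tests can
// set it to zero, and the worker refuses to start when the lease is below 1.
var jobLease atomic.Uint64

func startJob() bool {
	lease := jobLease.Load()
	if lease < 1 {
		return false // background job disabled
	}
	fmt.Println("ticking every", time.Duration(lease))
	return true
}

func main() {
	jobLease.Store(uint64(10 * time.Second))
	fmt.Println(startJob()) // true
	jobLease.Store(0)       // what the *4Test helper above does
	fmt.Println(startJob()) // false
}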
+func (do *Domain) DumpFileGcCheckerLoop() { + do.wg.Run(func() { + logutil.BgLogger().Info("dumpFileGcChecker started") + gcTicker := time.NewTicker(do.dumpFileGcChecker.gcLease) + defer func() { + gcTicker.Stop() + logutil.BgLogger().Info("dumpFileGcChecker exited.") + }() + defer util.Recover(metrics.LabelDomain, "dumpFileGcCheckerLoop", nil, false) + + for { + select { + case <-do.exit: + return + case <-gcTicker.C: + do.dumpFileGcChecker.GCDumpFiles(time.Hour, time.Hour*24*7) + } + } + }, "dumpFileGcChecker") +} + +// GetHistoricalStatsWorker gets historical workers +func (do *Domain) GetHistoricalStatsWorker() *HistoricalStatsWorker { + return do.historicalStatsWorker +} + +// EnableDumpHistoricalStats used to control whether enable dump stats for unit test +var enableDumpHistoricalStats atomic.Bool + +// StartHistoricalStatsWorker start historical workers running +func (do *Domain) StartHistoricalStatsWorker() { + if !enableDumpHistoricalStats.Load() { + return + } + do.wg.Run(func() { + logutil.BgLogger().Info("HistoricalStatsWorker started") + defer func() { + logutil.BgLogger().Info("HistoricalStatsWorker exited.") + }() + defer util.Recover(metrics.LabelDomain, "HistoricalStatsWorkerLoop", nil, false) + + for { + select { + case <-do.exit: + close(do.historicalStatsWorker.tblCH) + return + case tblID, ok := <-do.historicalStatsWorker.tblCH: + if !ok { + return + } + err := do.historicalStatsWorker.DumpHistoricalStats(tblID, do.StatsHandle()) + if err != nil { + logutil.BgLogger().Warn("dump historical stats failed", zap.Error(err), zap.Int64("tableID", tblID)) + } + } + } + }, "HistoricalStatsWorker") +} + +// StatsHandle returns the statistic handle. +func (do *Domain) StatsHandle() *handle.Handle { + return do.statsHandle.Load() +} + +// CreateStatsHandle is used only for test. +func (do *Domain) CreateStatsHandle(ctx, initStatsCtx sessionctx.Context) error { + h, err := handle.NewHandle(ctx, initStatsCtx, do.statsLease, do.sysSessionPool, &do.sysProcesses, do.NextConnID, do.ReleaseConnID) + if err != nil { + return err + } + h.StartWorker() + do.statsHandle.Store(h) + return nil +} + +// StatsUpdating checks if the stats worker is updating. +func (do *Domain) StatsUpdating() bool { + return do.statsUpdating.Load() > 0 +} + +// SetStatsUpdating sets the value of stats updating. +func (do *Domain) SetStatsUpdating(val bool) { + if val { + do.statsUpdating.Store(1) + } else { + do.statsUpdating.Store(0) + } +} + +// ReleaseAnalyzeExec returned extra exec for Analyze +func (do *Domain) ReleaseAnalyzeExec(sctxs []sessionctx.Context) { + do.analyzeMu.Lock() + defer do.analyzeMu.Unlock() + for _, ctx := range sctxs { + do.analyzeMu.sctxs[ctx] = false + } +} + +// FetchAnalyzeExec get needed exec for analyze +func (do *Domain) FetchAnalyzeExec(need int) []sessionctx.Context { + if need < 1 { + return nil + } + count := 0 + r := make([]sessionctx.Context, 0) + do.analyzeMu.Lock() + defer do.analyzeMu.Unlock() + for sctx, used := range do.analyzeMu.sctxs { + if used { + continue + } + r = append(r, sctx) + do.analyzeMu.sctxs[sctx] = true + count++ + if count >= need { + break + } + } + return r +} + +// SetupAnalyzeExec setups exec for Analyze Executor +func (do *Domain) SetupAnalyzeExec(ctxs []sessionctx.Context) { + do.analyzeMu.sctxs = make(map[sessionctx.Context]bool) + for _, ctx := range ctxs { + do.analyzeMu.sctxs[ctx] = false + } +} + +// LoadAndUpdateStatsLoop loads and updates stats info. 
+func (do *Domain) LoadAndUpdateStatsLoop(ctxs []sessionctx.Context, initStatsCtx sessionctx.Context) error { + if err := do.UpdateTableStatsLoop(ctxs[0], initStatsCtx); err != nil { + return err + } + do.StartLoadStatsSubWorkers(ctxs[1:]) + return nil +} + +// UpdateTableStatsLoop creates a goroutine loads stats info and updates stats info in a loop. +// It will also start a goroutine to analyze tables automatically. +// It should be called only once in BootstrapSession. +func (do *Domain) UpdateTableStatsLoop(ctx, initStatsCtx sessionctx.Context) error { + ctx.GetSessionVars().InRestrictedSQL = true + statsHandle, err := handle.NewHandle(ctx, initStatsCtx, do.statsLease, do.sysSessionPool, &do.sysProcesses, do.NextConnID, do.ReleaseConnID) + if err != nil { + return err + } + statsHandle.StartWorker() + do.statsHandle.Store(statsHandle) + do.ddl.RegisterStatsHandle(statsHandle) + // Negative stats lease indicates that it is in test or in br binary mode, it does not need update. + if do.statsLease >= 0 { + do.wg.Run(do.loadStatsWorker, "loadStatsWorker") + } + owner := do.newOwnerManager(handle.StatsPrompt, handle.StatsOwnerKey) + do.wg.Run(func() { + do.indexUsageWorker() + }, "indexUsageWorker") + if do.statsLease <= 0 { + // For statsLease > 0, `updateStatsWorker` handles the quit of stats owner. + do.wg.Run(func() { quitStatsOwner(do, owner) }, "quitStatsOwner") + return nil + } + do.SetStatsUpdating(true) + // The stats updated worker doesn't require the stats initialization to be completed. + // This is because the updated worker's primary responsibilities are to update the change delta and handle DDL operations. + // These tasks do not interfere with or depend on the initialization process. + do.wg.Run(func() { do.updateStatsWorker(ctx, owner) }, "updateStatsWorker") + do.wg.Run(func() { + do.handleDDLEvent() + }, "handleDDLEvent") + // Wait for the stats worker to finish the initialization. + // Otherwise, we may start the auto analyze worker before the stats cache is initialized. + do.wg.Run( + func() { + select { + case <-do.StatsHandle().InitStatsDone: + case <-do.exit: // It may happen that before initStatsDone, tidb receive Ctrl+C + return + } + do.autoAnalyzeWorker(owner) + }, + "autoAnalyzeWorker", + ) + do.wg.Run( + func() { + select { + case <-do.StatsHandle().InitStatsDone: + case <-do.exit: // It may happen that before initStatsDone, tidb receive Ctrl+C + return + } + do.analyzeJobsCleanupWorker(owner) + }, + "analyzeJobsCleanupWorker", + ) + do.wg.Run( + func() { + // The initStatsCtx is used to store the internal session for initializing stats, + // so we need the gc min start ts calculation to track it as an internal session. + // Since the session manager may not be ready at this moment, `infosync.StoreInternalSession` can fail. + // we need to retry until the session manager is ready or the init stats completes. + for !infosync.StoreInternalSession(initStatsCtx) { + waitRetry := time.After(time.Second) + select { + case <-do.StatsHandle().InitStatsDone: + return + case <-waitRetry: + } + } + select { + case <-do.StatsHandle().InitStatsDone: + case <-do.exit: // It may happen that before initStatsDone, tidb receive Ctrl+C + return + } + infosync.DeleteInternalSession(initStatsCtx) + }, + "RemoveInitStatsFromInternalSessions", + ) + return nil +} + +func quitStatsOwner(do *Domain, mgr owner.Manager) { + <-do.exit + mgr.Cancel() +} + +// StartLoadStatsSubWorkers starts sub workers with new sessions to load stats concurrently. 
+func (do *Domain) StartLoadStatsSubWorkers(ctxList []sessionctx.Context) { + statsHandle := do.StatsHandle() + for _, ctx := range ctxList { + do.wg.Add(1) + go statsHandle.SubLoadWorker(ctx, do.exit, do.wg) + } + logutil.BgLogger().Info("start load stats sub workers", zap.Int("worker count", len(ctxList))) +} + +func (do *Domain) newOwnerManager(prompt, ownerKey string) owner.Manager { + id := do.ddl.OwnerManager().ID() + var statsOwner owner.Manager + if do.etcdClient == nil { + statsOwner = owner.NewMockManager(context.Background(), id, do.store, ownerKey) + } else { + statsOwner = owner.NewOwnerManager(context.Background(), do.etcdClient, prompt, id, ownerKey) + } + // TODO: Need to do something when err is not nil. + err := statsOwner.CampaignOwner() + if err != nil { + logutil.BgLogger().Warn("campaign owner failed", zap.Error(err)) + } + return statsOwner +} + +func (do *Domain) initStats(ctx context.Context) { + statsHandle := do.StatsHandle() + defer func() { + if r := recover(); r != nil { + logutil.BgLogger().Error("panic when initiating stats", zap.Any("r", r), + zap.Stack("stack")) + } + close(statsHandle.InitStatsDone) + }() + t := time.Now() + liteInitStats := config.GetGlobalConfig().Performance.LiteInitStats + var err error + if liteInitStats { + err = statsHandle.InitStatsLite(ctx, do.InfoSchema()) + } else { + err = statsHandle.InitStats(ctx, do.InfoSchema()) + } + if err != nil { + logutil.BgLogger().Error("init stats info failed", zap.Bool("lite", liteInitStats), zap.Duration("take time", time.Since(t)), zap.Error(err)) + } else { + logutil.BgLogger().Info("init stats info time", zap.Bool("lite", liteInitStats), zap.Duration("take time", time.Since(t))) + } +} + +func (do *Domain) loadStatsWorker() { + defer util.Recover(metrics.LabelDomain, "loadStatsWorker", nil, false) + lease := do.statsLease + if lease == 0 { + lease = 3 * time.Second + } + loadTicker := time.NewTicker(lease) + updStatsHealthyTicker := time.NewTicker(20 * lease) + defer func() { + loadTicker.Stop() + updStatsHealthyTicker.Stop() + logutil.BgLogger().Info("loadStatsWorker exited.") + }() + + ctx, cancelFunc := context.WithCancel(context.Background()) + do.cancelFns.mu.Lock() + do.cancelFns.fns = append(do.cancelFns.fns, cancelFunc) + do.cancelFns.mu.Unlock() + + do.initStats(ctx) + statsHandle := do.StatsHandle() + var err error + for { + select { + case <-loadTicker.C: + err = statsHandle.Update(ctx, do.InfoSchema()) + if err != nil { + logutil.BgLogger().Debug("update stats info failed", zap.Error(err)) + } + err = statsHandle.LoadNeededHistograms() + if err != nil { + logutil.BgLogger().Debug("load histograms failed", zap.Error(err)) + } + case <-updStatsHealthyTicker.C: + statsHandle.UpdateStatsHealthyMetrics() + case <-do.exit: + return + } + } +} + +func (do *Domain) indexUsageWorker() { + defer util.Recover(metrics.LabelDomain, "indexUsageWorker", nil, false) + gcStatsTicker := time.NewTicker(indexUsageGCDuration) + handle := do.StatsHandle() + defer func() { + logutil.BgLogger().Info("indexUsageWorker exited.") + }() + for { + select { + case <-do.exit: + return + case <-gcStatsTicker.C: + if err := handle.GCIndexUsage(); err != nil { + statslogutil.StatsLogger().Error("gc index usage failed", zap.Error(err)) + } + } + } +} + +func (*Domain) updateStatsWorkerExitPreprocessing(statsHandle *handle.Handle, owner owner.Manager) { + ch := make(chan struct{}, 1) + timeout, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + go func() { + 
logutil.BgLogger().Info("updateStatsWorker is going to exit, start to flush stats") + statsHandle.FlushStats() + logutil.BgLogger().Info("updateStatsWorker ready to release owner") + owner.Cancel() + ch <- struct{}{} + }() + select { + case <-ch: + logutil.BgLogger().Info("updateStatsWorker exit preprocessing finished") + return + case <-timeout.Done(): + logutil.BgLogger().Warn("updateStatsWorker exit preprocessing timeout, force exiting") + return + } +} + +func (do *Domain) handleDDLEvent() { + logutil.BgLogger().Info("handleDDLEvent started.") + defer util.Recover(metrics.LabelDomain, "handleDDLEvent", nil, false) + statsHandle := do.StatsHandle() + for { + select { + case <-do.exit: + return + // This channel is sent only by ddl owner. + case t := <-statsHandle.DDLEventCh(): + err := statsHandle.HandleDDLEvent(t) + if err != nil { + logutil.BgLogger().Error("handle ddl event failed", zap.String("event", t.String()), zap.Error(err)) + } + } + } +} + +func (do *Domain) updateStatsWorker(_ sessionctx.Context, owner owner.Manager) { + defer util.Recover(metrics.LabelDomain, "updateStatsWorker", nil, false) + logutil.BgLogger().Info("updateStatsWorker started.") + lease := do.statsLease + // We need to have different nodes trigger tasks at different times to avoid the herd effect. + randDuration := time.Duration(rand.Int63n(int64(time.Minute))) + deltaUpdateTicker := time.NewTicker(20*lease + randDuration) + gcStatsTicker := time.NewTicker(100 * lease) + dumpColStatsUsageTicker := time.NewTicker(100 * lease) + readMemTicker := time.NewTicker(memory.ReadMemInterval) + statsHandle := do.StatsHandle() + defer func() { + dumpColStatsUsageTicker.Stop() + gcStatsTicker.Stop() + deltaUpdateTicker.Stop() + readMemTicker.Stop() + do.SetStatsUpdating(false) + logutil.BgLogger().Info("updateStatsWorker exited.") + }() + defer util.Recover(metrics.LabelDomain, "updateStatsWorker", nil, false) + + for { + select { + case <-do.exit: + do.updateStatsWorkerExitPreprocessing(statsHandle, owner) + return + case <-deltaUpdateTicker.C: + err := statsHandle.DumpStatsDeltaToKV(false) + if err != nil { + logutil.BgLogger().Debug("dump stats delta failed", zap.Error(err)) + } + case <-gcStatsTicker.C: + if !owner.IsOwner() { + continue + } + err := statsHandle.GCStats(do.InfoSchema(), do.DDL().GetLease()) + if err != nil { + logutil.BgLogger().Debug("GC stats failed", zap.Error(err)) + } + case <-dumpColStatsUsageTicker.C: + err := statsHandle.DumpColStatsUsageToKV() + if err != nil { + logutil.BgLogger().Debug("dump column stats usage failed", zap.Error(err)) + } + + case <-readMemTicker.C: + memory.ForceReadMemStats() + } + } +} + +func (do *Domain) autoAnalyzeWorker(owner owner.Manager) { + defer util.Recover(metrics.LabelDomain, "autoAnalyzeWorker", nil, false) + statsHandle := do.StatsHandle() + analyzeTicker := time.NewTicker(do.statsLease) + defer func() { + analyzeTicker.Stop() + logutil.BgLogger().Info("autoAnalyzeWorker exited.") + }() + for { + select { + case <-analyzeTicker.C: + if variable.RunAutoAnalyze.Load() && !do.stopAutoAnalyze.Load() && owner.IsOwner() { + statsHandle.HandleAutoAnalyze() + } + case <-do.exit: + return + } + } +} + +// analyzeJobsCleanupWorker is a background worker that periodically performs two main tasks: +// +// 1. Garbage Collection: It removes outdated analyze jobs from the statistics handle. +// This operation is performed every hour and only if the current instance is the owner. +// Analyze jobs older than 7 days are considered outdated and are removed. +// +// 2. 
Cleanup: It cleans up corrupted analyze jobs. +// A corrupted analyze job is one that is in a 'pending' or 'running' state, +// but is associated with a TiDB instance that is either not currently running or has been restarted. +// Also, if the analyze job is killed by the user, it is considered corrupted. +// This operation is performed every 100 stats leases. +// It first retrieves the list of current analyze processes, then removes any analyze job +// that is not associated with a current process. Additionally, if the current instance is the owner, +// it also cleans up corrupted analyze jobs on dead instances. +func (do *Domain) analyzeJobsCleanupWorker(owner owner.Manager) { + defer util.Recover(metrics.LabelDomain, "analyzeJobsCleanupWorker", nil, false) + // For GC. + const gcInterval = time.Hour + const daysToKeep = 7 + gcTicker := time.NewTicker(gcInterval) + // For clean up. + // Default stats lease is 3 * time.Second. + // So cleanupInterval is 100 * 3 * time.Second = 5 * time.Minute. + var cleanupInterval = do.statsLease * 100 + cleanupTicker := time.NewTicker(cleanupInterval) + defer func() { + gcTicker.Stop() + cleanupTicker.Stop() + logutil.BgLogger().Info("analyzeJobsCleanupWorker exited.") + }() + statsHandle := do.StatsHandle() + for { + select { + case <-gcTicker.C: + // Only the owner should perform this operation. + if owner.IsOwner() { + updateTime := time.Now().AddDate(0, 0, -daysToKeep) + err := statsHandle.DeleteAnalyzeJobs(updateTime) + if err != nil { + logutil.BgLogger().Warn("gc analyze history failed", zap.Error(err)) + } + } + case <-cleanupTicker.C: + sm := do.InfoSyncer().GetSessionManager() + if sm == nil { + continue + } + analyzeProcessIDs := make(map[uint64]struct{}, 8) + for _, process := range sm.ShowProcessList() { + if isAnalyzeTableSQL(process.Info) { + analyzeProcessIDs[process.ID] = struct{}{} + } + } + + err := statsHandle.CleanupCorruptedAnalyzeJobsOnCurrentInstance(analyzeProcessIDs) + if err != nil { + logutil.BgLogger().Warn("cleanup analyze jobs on current instance failed", zap.Error(err)) + } + + if owner.IsOwner() { + err = statsHandle.CleanupCorruptedAnalyzeJobsOnDeadInstances() + if err != nil { + logutil.BgLogger().Warn("cleanup analyze jobs on dead instances failed", zap.Error(err)) + } + } + case <-do.exit: + return + } + } +} + +func isAnalyzeTableSQL(sql string) bool { + // Get rid of the comments. + normalizedSQL := parser.Normalize(sql, "ON") + return strings.HasPrefix(normalizedSQL, "analyze table") +} + +// ExpensiveQueryHandle returns the expensive query handle. +func (do *Domain) ExpensiveQueryHandle() *expensivequery.Handle { + return do.expensiveQueryHandle +} + +// MemoryUsageAlarmHandle returns the memory usage alarm handle. +func (do *Domain) MemoryUsageAlarmHandle() *memoryusagealarm.Handle { + return do.memoryUsageAlarmHandle +} + +// ServerMemoryLimitHandle returns the expensive query handle. +func (do *Domain) ServerMemoryLimitHandle() *servermemorylimit.Handle { + return do.serverMemoryLimitHandle +} + +const ( + privilegeKey = "/tidb/privilege" + sysVarCacheKey = "/tidb/sysvars" + tiflashComputeNodeKey = "/tiflash/new_tiflash_compute_nodes" +) + +// NotifyUpdatePrivilege updates privilege key in etcd, TiDB client that watches +// the key will get notification. +func (do *Domain) NotifyUpdatePrivilege() error { + // No matter skip-grant-table is configured or not, sending an etcd message is required. 
+ // Because we need to tell other TiDB instances to update privilege data, say, we're changing the + // password using a special TiDB instance and want the new password to take effect. + if do.etcdClient != nil { + row := do.etcdClient.KV + _, err := row.Put(context.Background(), privilegeKey, "") + if err != nil { + logutil.BgLogger().Warn("notify update privilege failed", zap.Error(err)) + } + } + + // If skip-grant-table is configured, do not flush privileges. + // Because LoadPrivilegeLoop does not run and the privilege Handle is nil, + // the call to do.PrivilegeHandle().Update would panic. + if config.GetGlobalConfig().Security.SkipGrantTable { + return nil + } + + // update locally + ctx, err := do.sysSessionPool.Get() + if err != nil { + return err + } + defer do.sysSessionPool.Put(ctx) + return do.PrivilegeHandle().Update(ctx.(sessionctx.Context)) +} + +// NotifyUpdateSysVarCache updates the sysvar cache key in etcd, which other TiDB +// clients are subscribed to for updates. For the caller, the cache is also built +// synchronously so that the effect is immediate. +func (do *Domain) NotifyUpdateSysVarCache(updateLocal bool) { + if do.etcdClient != nil { + row := do.etcdClient.KV + _, err := row.Put(context.Background(), sysVarCacheKey, "") + if err != nil { + logutil.BgLogger().Warn("notify update sysvar cache failed", zap.Error(err)) + } + } + // update locally + if updateLocal { + if err := do.rebuildSysVarCache(nil); err != nil { + logutil.BgLogger().Error("rebuilding sysvar cache failed", zap.Error(err)) + } + } +} + +// LoadSigningCertLoop loads the signing cert periodically to make sure it's fresh new. +func (do *Domain) LoadSigningCertLoop(signingCert, signingKey string) { + sessionstates.SetCertPath(signingCert) + sessionstates.SetKeyPath(signingKey) + + do.wg.Run(func() { + defer func() { + logutil.BgLogger().Debug("loadSigningCertLoop exited.") + }() + defer util.Recover(metrics.LabelDomain, "LoadSigningCertLoop", nil, false) + + for { + select { + case <-time.After(sessionstates.LoadCertInterval): + sessionstates.ReloadSigningCert() + case <-do.exit: + return + } + } + }, "loadSigningCertLoop") +} + +// ServerID gets serverID. +func (do *Domain) ServerID() uint64 { + return atomic.LoadUint64(&do.serverID) +} + +// IsLostConnectionToPD indicates lost connection to PD or not. +func (do *Domain) IsLostConnectionToPD() bool { + return do.isLostConnectionToPD.Load() != 0 +} + +// NextConnID return next connection ID. +func (do *Domain) NextConnID() uint64 { + return do.connIDAllocator.NextID() +} + +// ReleaseConnID releases connection ID. +func (do *Domain) ReleaseConnID(connID uint64) { + do.connIDAllocator.Release(connID) +} + +const ( + serverIDEtcdPath = "/tidb/server_id" + refreshServerIDRetryCnt = 3 + acquireServerIDRetryInterval = 300 * time.Millisecond + acquireServerIDTimeout = 10 * time.Second + retrieveServerIDSessionTimeout = 10 * time.Second + + acquire32BitsServerIDRetryCnt = 3 +) + +var ( + // serverIDTTL should be LONG ENOUGH to avoid barbarically killing an on-going long-run SQL. + serverIDTTL = 12 * time.Hour + // serverIDTimeToKeepAlive is the interval that we keep serverID TTL alive periodically. + serverIDTimeToKeepAlive = 5 * time.Minute + // serverIDTimeToCheckPDConnectionRestored is the interval that we check connection to PD restored (after broken) periodically. 
+ serverIDTimeToCheckPDConnectionRestored = 10 * time.Second + // lostConnectionToPDTimeout is the duration that when TiDB cannot connect to PD excceeds this limit, + // we realize the connection to PD is lost utterly, and server ID acquired before should be released. + // Must be SHORTER than `serverIDTTL`. + lostConnectionToPDTimeout = 6 * time.Hour +) + +var ( + ldflagIsGlobalKillTest = "0" // 1:Yes, otherwise:No. + ldflagServerIDTTL = "10" // in seconds. + ldflagServerIDTimeToKeepAlive = "1" // in seconds. + ldflagServerIDTimeToCheckPDConnectionRestored = "1" // in seconds. + ldflagLostConnectionToPDTimeout = "5" // in seconds. +) + +func initByLDFlagsForGlobalKill() { + if ldflagIsGlobalKillTest == "1" { + var ( + i int + err error + ) + + if i, err = strconv.Atoi(ldflagServerIDTTL); err != nil { + panic("invalid ldflagServerIDTTL") + } + serverIDTTL = time.Duration(i) * time.Second + + if i, err = strconv.Atoi(ldflagServerIDTimeToKeepAlive); err != nil { + panic("invalid ldflagServerIDTimeToKeepAlive") + } + serverIDTimeToKeepAlive = time.Duration(i) * time.Second + + if i, err = strconv.Atoi(ldflagServerIDTimeToCheckPDConnectionRestored); err != nil { + panic("invalid ldflagServerIDTimeToCheckPDConnectionRestored") + } + serverIDTimeToCheckPDConnectionRestored = time.Duration(i) * time.Second + + if i, err = strconv.Atoi(ldflagLostConnectionToPDTimeout); err != nil { + panic("invalid ldflagLostConnectionToPDTimeout") + } + lostConnectionToPDTimeout = time.Duration(i) * time.Second + + logutil.BgLogger().Info("global_kill_test is enabled", zap.Duration("serverIDTTL", serverIDTTL), + zap.Duration("serverIDTimeToKeepAlive", serverIDTimeToKeepAlive), + zap.Duration("serverIDTimeToCheckPDConnectionRestored", serverIDTimeToCheckPDConnectionRestored), + zap.Duration("lostConnectionToPDTimeout", lostConnectionToPDTimeout)) + } +} + +func (do *Domain) retrieveServerIDSession(ctx context.Context) (*concurrency.Session, error) { + if do.serverIDSession != nil { + return do.serverIDSession, nil + } + + // `etcdClient.Grant` needs a shortterm timeout, to avoid blocking if connection to PD lost, + // while `etcdClient.KeepAlive` should be longterm. + // So we separately invoke `etcdClient.Grant` and `concurrency.NewSession` with leaseID. 
+ childCtx, cancel := context.WithTimeout(ctx, retrieveServerIDSessionTimeout) + resp, err := do.etcdClient.Grant(childCtx, int64(serverIDTTL.Seconds())) + cancel() + if err != nil { + logutil.BgLogger().Error("retrieveServerIDSession.Grant fail", zap.Error(err)) + return nil, err + } + leaseID := resp.ID + + session, err := concurrency.NewSession(do.etcdClient, + concurrency.WithLease(leaseID), concurrency.WithContext(context.Background())) + if err != nil { + logutil.BgLogger().Error("retrieveServerIDSession.NewSession fail", zap.Error(err)) + return nil, err + } + do.serverIDSession = session + return session, nil +} + +func (do *Domain) acquireServerID(ctx context.Context) error { + atomic.StoreUint64(&do.serverID, 0) + + session, err := do.retrieveServerIDSession(ctx) + if err != nil { + return err + } + + conflictCnt := 0 + for { + var proposeServerID uint64 + if config.GetGlobalConfig().Enable32BitsConnectionID { + proposeServerID, err = do.proposeServerID(ctx, conflictCnt) + if err != nil { + return errors.Trace(err) + } + } else { + // get a random serverID: [1, MaxServerID64] + proposeServerID = uint64(rand.Int63n(int64(globalconn.MaxServerID64)) + 1) // #nosec G404 + } + + key := fmt.Sprintf("%s/%v", serverIDEtcdPath, proposeServerID) + cmp := clientv3.Compare(clientv3.CreateRevision(key), "=", 0) + value := "0" + + childCtx, cancel := context.WithTimeout(ctx, acquireServerIDTimeout) + txn := do.etcdClient.Txn(childCtx) + t := txn.If(cmp) + resp, err := t.Then(clientv3.OpPut(key, value, clientv3.WithLease(session.Lease()))).Commit() + cancel() + if err != nil { + return err + } + if !resp.Succeeded { + logutil.BgLogger().Info("propose serverID exists, try again", zap.Uint64("proposeServerID", proposeServerID)) + time.Sleep(acquireServerIDRetryInterval) + conflictCnt++ + continue + } + + atomic.StoreUint64(&do.serverID, proposeServerID) + logutil.BgLogger().Info("acquireServerID", zap.Uint64("serverID", do.ServerID()), + zap.String("lease id", strconv.FormatInt(int64(session.Lease()), 16))) + return nil + } +} + +func (do *Domain) releaseServerID(context.Context) { + serverID := do.ServerID() + if serverID == 0 { + return + } + atomic.StoreUint64(&do.serverID, 0) + + if do.etcdClient == nil { + return + } + key := fmt.Sprintf("%s/%v", serverIDEtcdPath, serverID) + err := ddlutil.DeleteKeyFromEtcd(key, do.etcdClient, refreshServerIDRetryCnt, acquireServerIDTimeout) + if err != nil { + logutil.BgLogger().Error("releaseServerID fail", zap.Uint64("serverID", serverID), zap.Error(err)) + } else { + logutil.BgLogger().Info("releaseServerID succeed", zap.Uint64("serverID", serverID)) + } +} + +// propose server ID by random. +func (*Domain) proposeServerID(ctx context.Context, conflictCnt int) (uint64, error) { + // get a random server ID in range [min, max] + randomServerID := func(min uint64, max uint64) uint64 { + return uint64(rand.Int63n(int64(max-min+1)) + int64(min)) // #nosec G404 + } + + if conflictCnt < acquire32BitsServerIDRetryCnt { + // get existing server IDs. + allServerInfo, err := infosync.GetAllServerInfo(ctx) + if err != nil { + return 0, errors.Trace(err) + } + // `allServerInfo` contains current TiDB. 
+ if float32(len(allServerInfo)) <= 0.9*float32(globalconn.MaxServerID32) { + serverIDs := make(map[uint64]struct{}, len(allServerInfo)) + for _, info := range allServerInfo { + serverID := info.ServerIDGetter() + if serverID <= globalconn.MaxServerID32 { + serverIDs[serverID] = struct{}{} + } + } + + for retry := 0; retry < 15; retry++ { + randServerID := randomServerID(1, globalconn.MaxServerID32) + if _, ok := serverIDs[randServerID]; !ok { + return randServerID, nil + } + } + } + logutil.BgLogger().Info("upgrade to 64 bits server ID due to used up", zap.Int("len(allServerInfo)", len(allServerInfo))) + } else { + logutil.BgLogger().Info("upgrade to 64 bits server ID due to conflict", zap.Int("conflictCnt", conflictCnt)) + } + + // upgrade to 64 bits. + return randomServerID(globalconn.MaxServerID32+1, globalconn.MaxServerID64), nil +} + +func (do *Domain) refreshServerIDTTL(ctx context.Context) error { + session, err := do.retrieveServerIDSession(ctx) + if err != nil { + return err + } + + key := fmt.Sprintf("%s/%v", serverIDEtcdPath, do.ServerID()) + value := "0" + err = ddlutil.PutKVToEtcd(ctx, do.etcdClient, refreshServerIDRetryCnt, key, value, clientv3.WithLease(session.Lease())) + if err != nil { + logutil.BgLogger().Error("refreshServerIDTTL fail", zap.Uint64("serverID", do.ServerID()), zap.Error(err)) + } else { + logutil.BgLogger().Info("refreshServerIDTTL succeed", zap.Uint64("serverID", do.ServerID()), + zap.String("lease id", strconv.FormatInt(int64(session.Lease()), 16))) + } + return err +} + +func (do *Domain) serverIDKeeper() { + defer func() { + do.wg.Done() + logutil.BgLogger().Info("serverIDKeeper exited.") + }() + defer util.Recover(metrics.LabelDomain, "serverIDKeeper", func() { + logutil.BgLogger().Info("recover serverIDKeeper.") + // should be called before `do.wg.Done()`, to ensure that Domain.Close() waits for the new `serverIDKeeper()` routine. + do.wg.Add(1) + go do.serverIDKeeper() + }, false) + + tickerKeepAlive := time.NewTicker(serverIDTimeToKeepAlive) + tickerCheckRestored := time.NewTicker(serverIDTimeToCheckPDConnectionRestored) + defer func() { + tickerKeepAlive.Stop() + tickerCheckRestored.Stop() + }() + + blocker := make(chan struct{}) // just used for blocking the sessionDone() when session is nil. + sessionDone := func() <-chan struct{} { + if do.serverIDSession == nil { + return blocker + } + return do.serverIDSession.Done() + } + + var lastSucceedTimestamp time.Time + + onConnectionToPDRestored := func() { + logutil.BgLogger().Info("restored connection to PD") + do.isLostConnectionToPD.Store(0) + lastSucceedTimestamp = time.Now() + + if err := do.info.StoreServerInfo(context.Background()); err != nil { + logutil.BgLogger().Error("StoreServerInfo failed", zap.Error(err)) + } + } + + onConnectionToPDLost := func() { + logutil.BgLogger().Warn("lost connection to PD") + do.isLostConnectionToPD.Store(1) + + // Kill all connections when lost connection to PD, + // to avoid the possibility that another TiDB instance acquires the same serverID and generates a same connection ID, + // which will lead to a wrong connection killed. 
+ do.InfoSyncer().GetSessionManager().KillAllConnections() + } + + for { + select { + case <-tickerKeepAlive.C: + if !do.IsLostConnectionToPD() { + if err := do.refreshServerIDTTL(context.Background()); err == nil { + lastSucceedTimestamp = time.Now() + } else { + if lostConnectionToPDTimeout > 0 && time.Since(lastSucceedTimestamp) > lostConnectionToPDTimeout { + onConnectionToPDLost() + } + } + } + case <-tickerCheckRestored.C: + if do.IsLostConnectionToPD() { + if err := do.acquireServerID(context.Background()); err == nil { + onConnectionToPDRestored() + } + } + case <-sessionDone(): + // inform that TTL of `serverID` is expired. See https://godoc.org/github.com/coreos/etcd/clientv3/concurrency#Session.Done + // Should be in `IsLostConnectionToPD` state, as `lostConnectionToPDTimeout` is shorter than `serverIDTTL`. + // So just set `do.serverIDSession = nil` to restart `serverID` session in `retrieveServerIDSession()`. + logutil.BgLogger().Info("serverIDSession need restart") + do.serverIDSession = nil + case <-do.exit: + return + } + } +} + +// StartTTLJobManager creates and starts the ttl job manager +func (do *Domain) StartTTLJobManager() { + ttlJobManager := ttlworker.NewJobManager(do.ddl.GetID(), do.sysSessionPool, do.store, do.etcdClient, do.ddl.OwnerManager().IsOwner) + do.ttlJobManager.Store(ttlJobManager) + ttlJobManager.Start() +} + +// TTLJobManager returns the ttl job manager on this domain +func (do *Domain) TTLJobManager() *ttlworker.JobManager { + return do.ttlJobManager.Load() +} + +// StopAutoAnalyze stops (*Domain).autoAnalyzeWorker to launch new auto analyze jobs. +func (do *Domain) StopAutoAnalyze() { + do.stopAutoAnalyze.Store(true) +} + +// InitInstancePlanCache initializes the instance level plan cache for this Domain. +func (do *Domain) InitInstancePlanCache() { + softLimit := variable.InstancePlanCacheTargetMemSize.Load() + hardLimit := variable.InstancePlanCacheMaxMemSize.Load() + do.instancePlanCache = NewInstancePlanCache(softLimit, hardLimit) + // use a separate goroutine to avoid the eviction blocking other operations. + do.wg.Run(do.planCacheEvictTrigger, "planCacheEvictTrigger") + do.wg.Run(do.planCacheMetricsAndVars, "planCacheMetricsAndVars") +} + +// GetInstancePlanCache returns the instance level plan cache in this Domain. +func (do *Domain) GetInstancePlanCache() sessionctx.InstancePlanCache { + return do.instancePlanCache +} + +// planCacheMetricsAndVars updates metrics and variables for Instance Plan Cache periodically. +func (do *Domain) planCacheMetricsAndVars() { + defer util.Recover(metrics.LabelDomain, "planCacheMetricsAndVars", nil, false) + ticker := time.NewTicker(time.Second * 15) // 15s by default + defer func() { + ticker.Stop() + logutil.BgLogger().Info("planCacheMetricsAndVars exited.") + }() + + for { + select { + case <-ticker.C: + // update limits + softLimit := variable.InstancePlanCacheTargetMemSize.Load() + hardLimit := variable.InstancePlanCacheMaxMemSize.Load() + curSoft, curHard := do.instancePlanCache.GetLimits() + if curSoft != softLimit || curHard != hardLimit { + do.instancePlanCache.SetLimits(softLimit, hardLimit) + } + + // update the metrics + size := do.instancePlanCache.Size() + memUsage := do.instancePlanCache.MemUsage() + metrics2.GetPlanCacheInstanceNumCounter(true).Set(float64(size)) + metrics2.GetPlanCacheInstanceMemoryUsage(true).Set(float64(memUsage)) + case <-do.exit: + return + } + } +} + +// planCacheEvictTrigger triggers the plan cache eviction periodically. 
+func (do *Domain) planCacheEvictTrigger() { + defer util.Recover(metrics.LabelDomain, "planCacheEvictTrigger", nil, false) + ticker := time.NewTicker(time.Second * 15) // 15s by default + defer func() { + ticker.Stop() + logutil.BgLogger().Info("planCacheEvictTrigger exited.") + }() + + for { + select { + case <-ticker.C: + // trigger the eviction + do.instancePlanCache.Evict() + case <-do.exit: + return + } + } +} + +func init() { + initByLDFlagsForGlobalKill() +} + +var ( + // ErrInfoSchemaExpired returns the error that information schema is out of date. + ErrInfoSchemaExpired = dbterror.ClassDomain.NewStd(errno.ErrInfoSchemaExpired) + // ErrInfoSchemaChanged returns the error that information schema is changed. + ErrInfoSchemaChanged = dbterror.ClassDomain.NewStdErr(errno.ErrInfoSchemaChanged, + mysql.Message(errno.MySQLErrName[errno.ErrInfoSchemaChanged].Raw+". "+kv.TxnRetryableMark, nil)) +) + +// SysProcesses holds the sys processes infos +type SysProcesses struct { + mu *sync.RWMutex + procMap map[uint64]sysproctrack.TrackProc +} + +// Track tracks the sys process into procMap +func (s *SysProcesses) Track(id uint64, proc sysproctrack.TrackProc) error { + s.mu.Lock() + defer s.mu.Unlock() + if oldProc, ok := s.procMap[id]; ok && oldProc != proc { + return errors.Errorf("The ID is in use: %v", id) + } + s.procMap[id] = proc + proc.GetSessionVars().ConnectionID = id + proc.GetSessionVars().SQLKiller.Reset() + return nil +} + +// UnTrack removes the sys process from procMap +func (s *SysProcesses) UnTrack(id uint64) { + s.mu.Lock() + defer s.mu.Unlock() + if proc, ok := s.procMap[id]; ok { + delete(s.procMap, id) + proc.GetSessionVars().ConnectionID = 0 + proc.GetSessionVars().SQLKiller.Reset() + } +} + +// GetSysProcessList gets list of system ProcessInfo +func (s *SysProcesses) GetSysProcessList() map[uint64]*util.ProcessInfo { + s.mu.RLock() + defer s.mu.RUnlock() + rs := make(map[uint64]*util.ProcessInfo) + for connID, proc := range s.procMap { + // if session is still tracked in this map, it's not returned to sysSessionPool yet + if pi := proc.ShowProcess(); pi != nil && pi.ID == connID { + rs[connID] = pi + } + } + return rs +} + +// KillSysProcess kills sys process with specified ID +func (s *SysProcesses) KillSysProcess(id uint64) { + s.mu.Lock() + defer s.mu.Unlock() + if proc, ok := s.procMap[id]; ok { + proc.GetSessionVars().SQLKiller.SendKillSignal(sqlkiller.QueryInterrupted) + } +} diff --git a/pkg/domain/historical_stats.go b/pkg/domain/historical_stats.go index 9b4dd016d2711..4f8cb16ce8ff9 100644 --- a/pkg/domain/historical_stats.go +++ b/pkg/domain/historical_stats.go @@ -35,11 +35,11 @@ type HistoricalStatsWorker struct { // SendTblToDumpHistoricalStats send tableID to worker to dump historical stats func (w *HistoricalStatsWorker) SendTblToDumpHistoricalStats(tableID int64) { send := enableDumpHistoricalStats.Load() - failpoint.Inject("sendHistoricalStats", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("sendHistoricalStats")); _err_ == nil { if val.(bool) { send = true } - }) + } if !send { return } diff --git a/pkg/domain/historical_stats.go__failpoint_stash__ b/pkg/domain/historical_stats.go__failpoint_stash__ new file mode 100644 index 0000000000000..9b4dd016d2711 --- /dev/null +++ b/pkg/domain/historical_stats.go__failpoint_stash__ @@ -0,0 +1,98 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package domain + +import ( + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + domain_metrics "github.com/pingcap/tidb/pkg/domain/metrics" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/statistics/handle" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.uber.org/zap" +) + +// HistoricalStatsWorker indicates for dump historical stats +type HistoricalStatsWorker struct { + tblCH chan int64 + sctx sessionctx.Context +} + +// SendTblToDumpHistoricalStats send tableID to worker to dump historical stats +func (w *HistoricalStatsWorker) SendTblToDumpHistoricalStats(tableID int64) { + send := enableDumpHistoricalStats.Load() + failpoint.Inject("sendHistoricalStats", func(val failpoint.Value) { + if val.(bool) { + send = true + } + }) + if !send { + return + } + select { + case w.tblCH <- tableID: + return + default: + logutil.BgLogger().Warn("discard dump historical stats task", zap.Int64("table-id", tableID)) + } +} + +// DumpHistoricalStats dump stats by given tableID +func (w *HistoricalStatsWorker) DumpHistoricalStats(tableID int64, statsHandle *handle.Handle) error { + historicalStatsEnabled, err := statsHandle.CheckHistoricalStatsEnable() + if err != nil { + return errors.Errorf("check tidb_enable_historical_stats failed: %v", err) + } + if !historicalStatsEnabled { + return nil + } + sctx := w.sctx + is := GetDomain(sctx).InfoSchema() + isPartition := false + var tblInfo *model.TableInfo + tbl, existed := is.TableByID(tableID) + if !existed { + tbl, db, p := is.FindTableByPartitionID(tableID) + if !(tbl != nil && db != nil && p != nil) { + return errors.Errorf("cannot get table by id %d", tableID) + } + isPartition = true + tblInfo = tbl.Meta() + } else { + tblInfo = tbl.Meta() + } + dbInfo, existed := infoschema.SchemaByTable(is, tblInfo) + if !existed { + return errors.Errorf("cannot get DBInfo by TableID %d", tableID) + } + if _, err := statsHandle.RecordHistoricalStatsToStorage(dbInfo.Name.O, tblInfo, tableID, isPartition); err != nil { + domain_metrics.GenerateHistoricalStatsFailedCounter.Inc() + return errors.Errorf("record table %s.%s's historical stats failed, err:%v", dbInfo.Name.O, tblInfo.Name.O, err) + } + domain_metrics.GenerateHistoricalStatsSuccessCounter.Inc() + return nil +} + +// GetOneHistoricalStatsTable gets one tableID from channel, only used for test +func (w *HistoricalStatsWorker) GetOneHistoricalStatsTable() int64 { + select { + case tblID := <-w.tblCH: + return tblID + default: + return -1 + } +} diff --git a/pkg/domain/infosync/binding__failpoint_binding__.go b/pkg/domain/infosync/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..c4cc7873735bf --- /dev/null +++ b/pkg/domain/infosync/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package infosync + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func 
_curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/domain/infosync/info.go b/pkg/domain/infosync/info.go index 1e5e0cb51e6e6..9845ee180a7e1 100644 --- a/pkg/domain/infosync/info.go +++ b/pkg/domain/infosync/info.go @@ -276,7 +276,7 @@ func (is *InfoSyncer) initResourceManagerClient(pdCli pd.Client) { if pdCli == nil { cli = NewMockResourceManagerClient() } - failpoint.Inject("managerAlreadyCreateSomeGroups", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("managerAlreadyCreateSomeGroups")); _err_ == nil { if val.(bool) { _, err := cli.AddResourceGroup(context.TODO(), &rmpb.ResourceGroup{ @@ -305,7 +305,7 @@ func (is *InfoSyncer) initResourceManagerClient(pdCli pd.Client) { log.Warn("fail to create default group", zap.Error(err)) } } - }) + } is.resourceManagerClient = cli } @@ -355,11 +355,11 @@ func SetMockTiFlash(tiflash *MockTiFlash) { // GetServerInfo gets self server static information. func GetServerInfo() (*ServerInfo, error) { - failpoint.Inject("mockGetServerInfo", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("mockGetServerInfo")); _err_ == nil { var res ServerInfo err := json.Unmarshal([]byte(v.(string)), &res) - failpoint.Return(&res, err) - }) + return &res, err + } is, err := getGlobalInfoSyncer() if err != nil { return nil, err @@ -394,11 +394,11 @@ func (is *InfoSyncer) getServerInfoByID(ctx context.Context, id string) (*Server // GetAllServerInfo gets all servers static information from etcd. func GetAllServerInfo(ctx context.Context) (map[string]*ServerInfo, error) { - failpoint.Inject("mockGetAllServerInfo", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockGetAllServerInfo")); _err_ == nil { res := make(map[string]*ServerInfo) err := json.Unmarshal([]byte(val.(string)), &res) - failpoint.Return(res, err) - }) + return res, err + } is, err := getGlobalInfoSyncer() if err != nil { return nil, err @@ -538,15 +538,15 @@ func GetRuleBundle(ctx context.Context, name string) (*placement.Bundle, error) // PutRuleBundles is used to post specific rule bundles to PD. func PutRuleBundles(ctx context.Context, bundles []*placement.Bundle) error { - failpoint.Inject("putRuleBundlesError", func(isServiceError failpoint.Value) { + if isServiceError, _err_ := failpoint.Eval(_curpkg_("putRuleBundlesError")); _err_ == nil { var err error if isServiceError.(bool) { err = ErrHTTPServiceError.FastGen("mock service error") } else { err = errors.New("mock other error") } - failpoint.Return(err) - }) + return err + } is, err := getGlobalInfoSyncer() if err != nil { @@ -1034,14 +1034,14 @@ func getServerInfo(id string, serverIDGetter func() uint64) *ServerInfo { metrics.ServerInfo.WithLabelValues(mysql.TiDBReleaseVersion, info.GitHash).Set(float64(info.StartTimestamp)) - failpoint.Inject("mockServerInfo", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockServerInfo")); _err_ == nil { if val.(bool) { info.StartTimestamp = 1282967700 info.Labels = map[string]string{ "foo": "bar", } } - }) + } return info } @@ -1337,11 +1337,11 @@ type TiProxyServerInfo struct { // GetTiProxyServerInfo gets all TiProxy servers information from etcd. 
func GetTiProxyServerInfo(ctx context.Context) (map[string]*TiProxyServerInfo, error) { - failpoint.Inject("mockGetTiProxyServerInfo", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockGetTiProxyServerInfo")); _err_ == nil { res := make(map[string]*TiProxyServerInfo) err := json.Unmarshal([]byte(val.(string)), &res) - failpoint.Return(res, err) - }) + return res, err + } is, err := getGlobalInfoSyncer() if err != nil { return nil, err diff --git a/pkg/domain/infosync/info.go__failpoint_stash__ b/pkg/domain/infosync/info.go__failpoint_stash__ new file mode 100644 index 0000000000000..1e5e0cb51e6e6 --- /dev/null +++ b/pkg/domain/infosync/info.go__failpoint_stash__ @@ -0,0 +1,1457 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package infosync + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "os" + "path" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + rmpb "github.com/pingcap/kvproto/pkg/resource_manager" + "github.com/pingcap/log" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl/label" + "github.com/pingcap/tidb/pkg/ddl/placement" + "github.com/pingcap/tidb/pkg/ddl/util" + "github.com/pingcap/tidb/pkg/domain/resourcegroup" + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/session/cursor" + "github.com/pingcap/tidb/pkg/sessionctx/binloginfo" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + util2 "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/engine" + "github.com/pingcap/tidb/pkg/util/hack" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/versioninfo" + "github.com/tikv/client-go/v2/oracle" + "github.com/tikv/client-go/v2/tikv" + pd "github.com/tikv/pd/client" + pdhttp "github.com/tikv/pd/client/http" + clientv3 "go.etcd.io/etcd/client/v3" + "go.etcd.io/etcd/client/v3/concurrency" + "go.uber.org/zap" +) + +const ( + // ServerInformationPath store server information such as IP, port and so on. + ServerInformationPath = "/tidb/server/info" + // ServerMinStartTSPath store the server min start timestamp. + ServerMinStartTSPath = "/tidb/server/minstartts" + // TiFlashTableSyncProgressPath store the tiflash table replica sync progress. + TiFlashTableSyncProgressPath = "/tiflash/table/sync" + // keyOpDefaultRetryCnt is the default retry count for etcd store. + keyOpDefaultRetryCnt = 5 + // keyOpDefaultTimeout is the default time out for etcd store. + keyOpDefaultTimeout = 1 * time.Second + // ReportInterval is interval of infoSyncerKeeper reporting min startTS. 
+ ReportInterval = 30 * time.Second + // TopologyInformationPath means etcd path for storing topology info. + TopologyInformationPath = "/topology/tidb" + // TopologySessionTTL is ttl for topology, ant it's the ETCD session's TTL in seconds. + TopologySessionTTL = 45 + // TopologyTimeToRefresh means time to refresh etcd. + TopologyTimeToRefresh = 30 * time.Second + // TopologyPrometheus means address of prometheus. + TopologyPrometheus = "/topology/prometheus" + // TopologyTiProxy means address of TiProxy. + TopologyTiProxy = "/topology/tiproxy" + // infoSuffix is the suffix of TiDB/TiProxy topology info. + infoSuffix = "/info" + // TopologyTiCDC means address of TiCDC. + TopologyTiCDC = "/topology/ticdc" + // TablePrometheusCacheExpiry is the expiry time for prometheus address cache. + TablePrometheusCacheExpiry = 10 * time.Second + // RequestRetryInterval is the sleep time before next retry for http request + RequestRetryInterval = 200 * time.Millisecond + // SyncBundlesMaxRetry is the max retry times for sync placement bundles + SyncBundlesMaxRetry = 3 +) + +// ErrPrometheusAddrIsNotSet is the error that Prometheus address is not set in PD and etcd +var ErrPrometheusAddrIsNotSet = dbterror.ClassDomain.NewStd(errno.ErrPrometheusAddrIsNotSet) + +// InfoSyncer stores server info to etcd when the tidb-server starts and delete when tidb-server shuts down. +type InfoSyncer struct { + // `etcdClient` must be used when keyspace is not set, or when the logic to each etcd path needs to be separated by keyspace. + etcdCli *clientv3.Client + // `unprefixedEtcdCli` will never set the etcd namespace prefix by keyspace. + // It is only used in storeMinStartTS and RemoveMinStartTS now. + // It must be used when the etcd path isn't needed to separate by keyspace. + // See keyspace RFC: https://github.com/pingcap/tidb/pull/39685 + unprefixedEtcdCli *clientv3.Client + pdHTTPCli pdhttp.Client + info *ServerInfo + serverInfoPath string + minStartTS uint64 + minStartTSPath string + managerMu struct { + mu sync.RWMutex + util2.SessionManager + } + session *concurrency.Session + topologySession *concurrency.Session + prometheusAddr string + modifyTime time.Time + labelRuleManager LabelRuleManager + placementManager PlacementManager + scheduleManager ScheduleManager + tiflashReplicaManager TiFlashReplicaManager + resourceManagerClient pd.ResourceManagerClient +} + +// ServerInfo is server static information. +// It will not be updated when tidb-server running. So please only put static information in ServerInfo struct. +type ServerInfo struct { + ServerVersionInfo + ID string `json:"ddl_id"` + IP string `json:"ip"` + Port uint `json:"listening_port"` + StatusPort uint `json:"status_port"` + Lease string `json:"lease"` + BinlogStatus string `json:"binlog_status"` + StartTimestamp int64 `json:"start_timestamp"` + Labels map[string]string `json:"labels"` + // ServerID is a function, to always retrieve latest serverID from `Domain`, + // which will be changed on occasions such as connection to PD is restored after broken. + ServerIDGetter func() uint64 `json:"-"` + + // JSONServerID is `serverID` for json marshal/unmarshal ONLY. + JSONServerID uint64 `json:"server_id"` +} + +// Marshal `ServerInfo` into bytes. +func (info *ServerInfo) Marshal() ([]byte, error) { + info.JSONServerID = info.ServerIDGetter() + infoBuf, err := json.Marshal(info) + if err != nil { + return nil, errors.Trace(err) + } + return infoBuf, nil +} + +// Unmarshal `ServerInfo` from bytes. 
+func (info *ServerInfo) Unmarshal(v []byte) error { + if err := json.Unmarshal(v, info); err != nil { + return err + } + info.ServerIDGetter = func() uint64 { + return info.JSONServerID + } + return nil +} + +// ServerVersionInfo is the server version and git_hash. +type ServerVersionInfo struct { + Version string `json:"version"` + GitHash string `json:"git_hash"` +} + +// globalInfoSyncer stores the global infoSyncer. +// Use a global variable for simply the code, use the domain.infoSyncer will have circle import problem in some pkg. +// Use atomic.Pointer to avoid data race in the test. +var globalInfoSyncer atomic.Pointer[InfoSyncer] + +func getGlobalInfoSyncer() (*InfoSyncer, error) { + v := globalInfoSyncer.Load() + if v == nil { + return nil, errors.New("infoSyncer is not initialized") + } + return v, nil +} + +func setGlobalInfoSyncer(is *InfoSyncer) { + globalInfoSyncer.Store(is) +} + +// GlobalInfoSyncerInit return a new InfoSyncer. It is exported for testing. +func GlobalInfoSyncerInit( + ctx context.Context, + id string, + serverIDGetter func() uint64, + etcdCli, unprefixedEtcdCli *clientv3.Client, + pdCli pd.Client, pdHTTPCli pdhttp.Client, + codec tikv.Codec, + skipRegisterToDashBoard bool, +) (*InfoSyncer, error) { + if pdHTTPCli != nil { + pdHTTPCli = pdHTTPCli. + WithCallerID("tidb-info-syncer"). + WithRespHandler(pdResponseHandler) + } + is := &InfoSyncer{ + etcdCli: etcdCli, + unprefixedEtcdCli: unprefixedEtcdCli, + pdHTTPCli: pdHTTPCli, + info: getServerInfo(id, serverIDGetter), + serverInfoPath: fmt.Sprintf("%s/%s", ServerInformationPath, id), + minStartTSPath: fmt.Sprintf("%s/%s", ServerMinStartTSPath, id), + } + err := is.init(ctx, skipRegisterToDashBoard) + if err != nil { + return nil, err + } + is.initLabelRuleManager() + is.initPlacementManager() + is.initScheduleManager() + is.initTiFlashReplicaManager(codec) + is.initResourceManagerClient(pdCli) + setGlobalInfoSyncer(is) + return is, nil +} + +// Init creates a new etcd session and stores server info to etcd. +func (is *InfoSyncer) init(ctx context.Context, skipRegisterToDashboard bool) error { + err := is.newSessionAndStoreServerInfo(ctx, util2.NewSessionDefaultRetryCnt) + if err != nil { + return err + } + if skipRegisterToDashboard { + return nil + } + return is.newTopologySessionAndStoreServerInfo(ctx, util2.NewSessionDefaultRetryCnt) +} + +// SetSessionManager set the session manager for InfoSyncer. +func (is *InfoSyncer) SetSessionManager(manager util2.SessionManager) { + is.managerMu.mu.Lock() + defer is.managerMu.mu.Unlock() + is.managerMu.SessionManager = manager +} + +// GetSessionManager get the session manager. 
+func (is *InfoSyncer) GetSessionManager() util2.SessionManager { + is.managerMu.mu.RLock() + defer is.managerMu.mu.RUnlock() + return is.managerMu.SessionManager +} + +func (is *InfoSyncer) initLabelRuleManager() { + if is.pdHTTPCli == nil { + is.labelRuleManager = &mockLabelManager{labelRules: map[string][]byte{}} + return + } + is.labelRuleManager = &PDLabelManager{is.pdHTTPCli} +} + +func (is *InfoSyncer) initPlacementManager() { + if is.pdHTTPCli == nil { + is.placementManager = &mockPlacementManager{} + return + } + is.placementManager = &PDPlacementManager{is.pdHTTPCli} +} + +func (is *InfoSyncer) initResourceManagerClient(pdCli pd.Client) { + var cli pd.ResourceManagerClient = pdCli + if pdCli == nil { + cli = NewMockResourceManagerClient() + } + failpoint.Inject("managerAlreadyCreateSomeGroups", func(val failpoint.Value) { + if val.(bool) { + _, err := cli.AddResourceGroup(context.TODO(), + &rmpb.ResourceGroup{ + Name: resourcegroup.DefaultResourceGroupName, + Mode: rmpb.GroupMode_RUMode, + RUSettings: &rmpb.GroupRequestUnitSettings{ + RU: &rmpb.TokenBucket{ + Settings: &rmpb.TokenLimitSettings{FillRate: 1000000, BurstLimit: -1}, + }, + }, + }) + if err != nil { + log.Warn("fail to create default group", zap.Error(err)) + } + _, err = cli.AddResourceGroup(context.TODO(), + &rmpb.ResourceGroup{ + Name: "oltp", + Mode: rmpb.GroupMode_RUMode, + RUSettings: &rmpb.GroupRequestUnitSettings{ + RU: &rmpb.TokenBucket{ + Settings: &rmpb.TokenLimitSettings{FillRate: 1000000, BurstLimit: -1}, + }, + }, + }) + if err != nil { + log.Warn("fail to create default group", zap.Error(err)) + } + } + }) + is.resourceManagerClient = cli +} + +func (is *InfoSyncer) initTiFlashReplicaManager(codec tikv.Codec) { + if is.pdHTTPCli == nil { + is.tiflashReplicaManager = &mockTiFlashReplicaManagerCtx{tiflashProgressCache: make(map[int64]float64)} + return + } + logutil.BgLogger().Warn("init TiFlashReplicaManager") + is.tiflashReplicaManager = &TiFlashReplicaManagerCtx{pdHTTPCli: is.pdHTTPCli, tiflashProgressCache: make(map[int64]float64), codec: codec} +} + +func (is *InfoSyncer) initScheduleManager() { + if is.pdHTTPCli == nil { + is.scheduleManager = &mockScheduleManager{} + return + } + is.scheduleManager = &PDScheduleManager{is.pdHTTPCli} +} + +// GetMockTiFlash can only be used in tests to get MockTiFlash +func GetMockTiFlash() *MockTiFlash { + is, err := getGlobalInfoSyncer() + if err != nil { + return nil + } + + m, ok := is.tiflashReplicaManager.(*mockTiFlashReplicaManagerCtx) + if ok { + return m.tiflash + } + return nil +} + +// SetMockTiFlash can only be used in tests to set MockTiFlash +func SetMockTiFlash(tiflash *MockTiFlash) { + is, err := getGlobalInfoSyncer() + if err != nil { + return + } + + m, ok := is.tiflashReplicaManager.(*mockTiFlashReplicaManagerCtx) + if ok { + m.SetMockTiFlash(tiflash) + } +} + +// GetServerInfo gets self server static information. +func GetServerInfo() (*ServerInfo, error) { + failpoint.Inject("mockGetServerInfo", func(v failpoint.Value) { + var res ServerInfo + err := json.Unmarshal([]byte(v.(string)), &res) + failpoint.Return(&res, err) + }) + is, err := getGlobalInfoSyncer() + if err != nil { + return nil, err + } + return is.info, nil +} + +// GetServerInfoByID gets specified server static information from etcd. 
+func GetServerInfoByID(ctx context.Context, id string) (*ServerInfo, error) { + is, err := getGlobalInfoSyncer() + if err != nil { + return nil, err + } + return is.getServerInfoByID(ctx, id) +} + +func (is *InfoSyncer) getServerInfoByID(ctx context.Context, id string) (*ServerInfo, error) { + if is.etcdCli == nil || id == is.info.ID { + return is.info, nil + } + key := fmt.Sprintf("%s/%s", ServerInformationPath, id) + infoMap, err := getInfo(ctx, is.etcdCli, key, keyOpDefaultRetryCnt, keyOpDefaultTimeout) + if err != nil { + return nil, err + } + info, ok := infoMap[id] + if !ok { + return nil, errors.Errorf("[info-syncer] get %s failed", key) + } + return info, nil +} + +// GetAllServerInfo gets all servers static information from etcd. +func GetAllServerInfo(ctx context.Context) (map[string]*ServerInfo, error) { + failpoint.Inject("mockGetAllServerInfo", func(val failpoint.Value) { + res := make(map[string]*ServerInfo) + err := json.Unmarshal([]byte(val.(string)), &res) + failpoint.Return(res, err) + }) + is, err := getGlobalInfoSyncer() + if err != nil { + return nil, err + } + return is.getAllServerInfo(ctx) +} + +// UpdateServerLabel updates the server label for global info syncer. +func UpdateServerLabel(ctx context.Context, labels map[string]string) error { + is, err := getGlobalInfoSyncer() + if err != nil { + return err + } + // when etcdCli is nil, the server infos are generated from the latest config, no need to update. + if is.etcdCli == nil { + return nil + } + selfInfo, err := is.getServerInfoByID(ctx, is.info.ID) + if err != nil { + return err + } + changed := false + for k, v := range labels { + if selfInfo.Labels[k] != v { + changed = true + selfInfo.Labels[k] = v + } + } + if !changed { + return nil + } + infoBuf, err := selfInfo.Marshal() + if err != nil { + return errors.Trace(err) + } + str := string(hack.String(infoBuf)) + err = util.PutKVToEtcd(ctx, is.etcdCli, keyOpDefaultRetryCnt, is.serverInfoPath, str, clientv3.WithLease(is.session.Lease())) + return err +} + +// DeleteTiFlashTableSyncProgress is used to delete the tiflash table replica sync progress. +func DeleteTiFlashTableSyncProgress(tableInfo *model.TableInfo) error { + is, err := getGlobalInfoSyncer() + if err != nil { + return err + } + if pi := tableInfo.GetPartitionInfo(); pi != nil { + for _, p := range pi.Definitions { + is.tiflashReplicaManager.DeleteTiFlashProgressFromCache(p.ID) + } + } else { + is.tiflashReplicaManager.DeleteTiFlashProgressFromCache(tableInfo.ID) + } + return nil +} + +// MustGetTiFlashProgress gets tiflash replica progress from tiflashProgressCache, if cache not exist, it calculates progress from PD and TiFlash and inserts progress into cache. +func MustGetTiFlashProgress(tableID int64, replicaCount uint64, tiFlashStores *map[int64]pdhttp.StoreInfo) (float64, error) { + is, err := getGlobalInfoSyncer() + if err != nil { + return 0, err + } + progressCache, isExist := is.tiflashReplicaManager.GetTiFlashProgressFromCache(tableID) + if isExist { + return progressCache, nil + } + if *tiFlashStores == nil { + // We need the up-to-date information about TiFlash stores. + // Since TiFlash Replica synchronize may happen immediately after new TiFlash stores are added. + tikvStats, err := is.tiflashReplicaManager.GetStoresStat(context.Background()) + // If MockTiFlash is not set, will issue a MockTiFlashError here. 
+ if err != nil { + return 0, err + } + stores := make(map[int64]pdhttp.StoreInfo) + for _, store := range tikvStats.Stores { + if engine.IsTiFlashHTTPResp(&store.Store) { + stores[store.Store.ID] = store + } + } + *tiFlashStores = stores + logutil.BgLogger().Debug("updateTiFlashStores finished", zap.Int("TiFlash store count", len(*tiFlashStores))) + } + progress, err := is.tiflashReplicaManager.CalculateTiFlashProgress(tableID, replicaCount, *tiFlashStores) + if err != nil { + return 0, err + } + is.tiflashReplicaManager.UpdateTiFlashProgressCache(tableID, progress) + return progress, nil +} + +// pdResponseHandler will be injected into the PD HTTP client to handle the response, +// this is to maintain consistency with the original logic without the PD HTTP client. +func pdResponseHandler(resp *http.Response, res any) error { + defer func() { terror.Log(resp.Body.Close()) }() + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + if resp.StatusCode == http.StatusOK { + if res != nil && bodyBytes != nil { + return json.Unmarshal(bodyBytes, res) + } + return nil + } + logutil.BgLogger().Warn("response not 200", + zap.String("method", resp.Request.Method), + zap.String("host", resp.Request.URL.Host), + zap.String("url", resp.Request.URL.RequestURI()), + zap.Int("http status", resp.StatusCode), + ) + if resp.StatusCode != http.StatusNotFound && resp.StatusCode != http.StatusPreconditionFailed { + return ErrHTTPServiceError.FastGen("%s", bodyBytes) + } + return nil +} + +// GetAllRuleBundles is used to get all rule bundles from PD It is used to load full rules from PD while fullload infoschema. +func GetAllRuleBundles(ctx context.Context) ([]*placement.Bundle, error) { + is, err := getGlobalInfoSyncer() + if err != nil { + return nil, err + } + + return is.placementManager.GetAllRuleBundles(ctx) +} + +// GetRuleBundle is used to get one specific rule bundle from PD. +func GetRuleBundle(ctx context.Context, name string) (*placement.Bundle, error) { + is, err := getGlobalInfoSyncer() + if err != nil { + return nil, err + } + + return is.placementManager.GetRuleBundle(ctx, name) +} + +// PutRuleBundles is used to post specific rule bundles to PD. +func PutRuleBundles(ctx context.Context, bundles []*placement.Bundle) error { + failpoint.Inject("putRuleBundlesError", func(isServiceError failpoint.Value) { + var err error + if isServiceError.(bool) { + err = ErrHTTPServiceError.FastGen("mock service error") + } else { + err = errors.New("mock other error") + } + failpoint.Return(err) + }) + + is, err := getGlobalInfoSyncer() + if err != nil { + return err + } + + return is.placementManager.PutRuleBundles(ctx, bundles) +} + +// PutRuleBundlesWithRetry will retry for specified times when PutRuleBundles failed +func PutRuleBundlesWithRetry(ctx context.Context, bundles []*placement.Bundle, maxRetry int, interval time.Duration) (err error) { + if maxRetry < 0 { + maxRetry = 0 + } + + for i := 0; i <= maxRetry; i++ { + if err = PutRuleBundles(ctx, bundles); err == nil || ErrHTTPServiceError.Equal(err) { + return err + } + + if i != maxRetry { + logutil.BgLogger().Warn("Error occurs when PutRuleBundles, retry", zap.Error(err)) + time.Sleep(interval) + } + } + + return +} + +// GetResourceGroup is used to get one specific resource group from resource manager. 
+func GetResourceGroup(ctx context.Context, name string) (*rmpb.ResourceGroup, error) { + is, err := getGlobalInfoSyncer() + if err != nil { + return nil, err + } + + return is.resourceManagerClient.GetResourceGroup(ctx, name) +} + +// ListResourceGroups is used to get all resource groups from resource manager. +func ListResourceGroups(ctx context.Context) ([]*rmpb.ResourceGroup, error) { + is, err := getGlobalInfoSyncer() + if err != nil { + return nil, err + } + + return is.resourceManagerClient.ListResourceGroups(ctx) +} + +// AddResourceGroup is used to create one specific resource group to resource manager. +func AddResourceGroup(ctx context.Context, group *rmpb.ResourceGroup) error { + is, err := getGlobalInfoSyncer() + if err != nil { + return err + } + _, err = is.resourceManagerClient.AddResourceGroup(ctx, group) + return err +} + +// ModifyResourceGroup is used to modify one specific resource group to resource manager. +func ModifyResourceGroup(ctx context.Context, group *rmpb.ResourceGroup) error { + is, err := getGlobalInfoSyncer() + if err != nil { + return err + } + _, err = is.resourceManagerClient.ModifyResourceGroup(ctx, group) + return err +} + +// DeleteResourceGroup is used to delete one specific resource group from resource manager. +func DeleteResourceGroup(ctx context.Context, name string) error { + is, err := getGlobalInfoSyncer() + if err != nil { + return err + } + _, err = is.resourceManagerClient.DeleteResourceGroup(ctx, name) + return err +} + +// PutRuleBundlesWithDefaultRetry will retry for default times +func PutRuleBundlesWithDefaultRetry(ctx context.Context, bundles []*placement.Bundle) (err error) { + return PutRuleBundlesWithRetry(ctx, bundles, SyncBundlesMaxRetry, RequestRetryInterval) +} + +func (is *InfoSyncer) getAllServerInfo(ctx context.Context) (map[string]*ServerInfo, error) { + allInfo := make(map[string]*ServerInfo) + if is.etcdCli == nil { + allInfo[is.info.ID] = getServerInfo(is.info.ID, is.info.ServerIDGetter) + return allInfo, nil + } + allInfo, err := getInfo(ctx, is.etcdCli, ServerInformationPath, keyOpDefaultRetryCnt, keyOpDefaultTimeout, clientv3.WithPrefix()) + if err != nil { + return nil, err + } + return allInfo, nil +} + +// StoreServerInfo stores self server static information to etcd. +func (is *InfoSyncer) StoreServerInfo(ctx context.Context) error { + if is.etcdCli == nil { + return nil + } + infoBuf, err := is.info.Marshal() + if err != nil { + return errors.Trace(err) + } + str := string(hack.String(infoBuf)) + err = util.PutKVToEtcd(ctx, is.etcdCli, keyOpDefaultRetryCnt, is.serverInfoPath, str, clientv3.WithLease(is.session.Lease())) + return err +} + +// RemoveServerInfo remove self server static information from etcd. 
+func (is *InfoSyncer) RemoveServerInfo() { + if is.etcdCli == nil { + return + } + err := util.DeleteKeyFromEtcd(is.serverInfoPath, is.etcdCli, keyOpDefaultRetryCnt, keyOpDefaultTimeout) + if err != nil { + logutil.BgLogger().Error("remove server info failed", zap.Error(err)) + } +} + +// TopologyInfo is the topology info +type TopologyInfo struct { + ServerVersionInfo + IP string `json:"ip"` + StatusPort uint `json:"status_port"` + DeployPath string `json:"deploy_path"` + StartTimestamp int64 `json:"start_timestamp"` + Labels map[string]string `json:"labels"` +} + +func (is *InfoSyncer) getTopologyInfo() TopologyInfo { + s, err := os.Executable() + if err != nil { + s = "" + } + dir := path.Dir(s) + return TopologyInfo{ + ServerVersionInfo: ServerVersionInfo{ + Version: mysql.TiDBReleaseVersion, + GitHash: is.info.ServerVersionInfo.GitHash, + }, + IP: is.info.IP, + StatusPort: is.info.StatusPort, + DeployPath: dir, + StartTimestamp: is.info.StartTimestamp, + Labels: is.info.Labels, + } +} + +// StoreTopologyInfo stores the topology of tidb to etcd. +func (is *InfoSyncer) StoreTopologyInfo(ctx context.Context) error { + if is.etcdCli == nil { + return nil + } + topologyInfo := is.getTopologyInfo() + infoBuf, err := json.Marshal(topologyInfo) + if err != nil { + return errors.Trace(err) + } + str := string(hack.String(infoBuf)) + key := fmt.Sprintf("%s/%s/info", TopologyInformationPath, net.JoinHostPort(is.info.IP, strconv.Itoa(int(is.info.Port)))) + // Note: no lease is required here. + err = util.PutKVToEtcd(ctx, is.etcdCli, keyOpDefaultRetryCnt, key, str) + if err != nil { + return err + } + // Initialize ttl. + return is.updateTopologyAliveness(ctx) +} + +// GetMinStartTS get min start timestamp. +// Export for testing. +func (is *InfoSyncer) GetMinStartTS() uint64 { + return is.minStartTS +} + +// storeMinStartTS stores self server min start timestamp to etcd. +func (is *InfoSyncer) storeMinStartTS(ctx context.Context) error { + if is.unprefixedEtcdCli == nil { + return nil + } + return util.PutKVToEtcd(ctx, is.unprefixedEtcdCli, keyOpDefaultRetryCnt, is.minStartTSPath, + strconv.FormatUint(is.minStartTS, 10), + clientv3.WithLease(is.session.Lease())) +} + +// RemoveMinStartTS removes self server min start timestamp from etcd. +func (is *InfoSyncer) RemoveMinStartTS() { + if is.unprefixedEtcdCli == nil { + return + } + err := util.DeleteKeyFromEtcd(is.minStartTSPath, is.unprefixedEtcdCli, keyOpDefaultRetryCnt, keyOpDefaultTimeout) + if err != nil { + logutil.BgLogger().Error("remove minStartTS failed", zap.Error(err)) + } +} + +// ReportMinStartTS reports self server min start timestamp to ETCD. +func (is *InfoSyncer) ReportMinStartTS(store kv.Storage) { + sm := is.GetSessionManager() + if sm == nil { + return + } + pl := sm.ShowProcessList() + innerSessionStartTSList := sm.GetInternalSessionStartTSList() + + // Calculate the lower limit of the start timestamp to avoid extremely old transaction delaying GC. + currentVer, err := store.CurrentVersion(kv.GlobalTxnScope) + if err != nil { + logutil.BgLogger().Error("update minStartTS failed", zap.Error(err)) + return + } + now := oracle.GetTimeFromTS(currentVer.Ver) + // GCMaxWaitTime is in seconds, GCMaxWaitTime * 1000 converts it to milliseconds. 
+ startTSLowerLimit := oracle.GoTimeToLowerLimitStartTS(now, variable.GCMaxWaitTime.Load()*1000) + minStartTS := oracle.GoTimeToTS(now) + logutil.BgLogger().Debug("ReportMinStartTS", zap.Uint64("initial minStartTS", minStartTS), + zap.Uint64("StartTSLowerLimit", startTSLowerLimit)) + for _, info := range pl { + if info.CurTxnStartTS > startTSLowerLimit && info.CurTxnStartTS < minStartTS { + minStartTS = info.CurTxnStartTS + } + + if info.CursorTracker != nil { + info.CursorTracker.RangeCursor(func(c cursor.Handle) bool { + startTS := c.GetState().StartTS + if startTS > startTSLowerLimit && startTS < minStartTS { + minStartTS = startTS + } + return true + }) + } + } + + for _, innerTS := range innerSessionStartTSList { + logutil.BgLogger().Debug("ReportMinStartTS", zap.Uint64("Internal Session Transaction StartTS", innerTS)) + kv.PrintLongTimeInternalTxn(now, innerTS, false) + if innerTS > startTSLowerLimit && innerTS < minStartTS { + minStartTS = innerTS + } + } + + is.minStartTS = kv.GetMinInnerTxnStartTS(now, startTSLowerLimit, minStartTS) + + err = is.storeMinStartTS(context.Background()) + if err != nil { + logutil.BgLogger().Error("update minStartTS failed", zap.Error(err)) + } + logutil.BgLogger().Debug("ReportMinStartTS", zap.Uint64("final minStartTS", is.minStartTS)) +} + +// Done returns a channel that closes when the info syncer is no longer being refreshed. +func (is *InfoSyncer) Done() <-chan struct{} { + if is.etcdCli == nil { + return make(chan struct{}, 1) + } + return is.session.Done() +} + +// TopologyDone returns a channel that closes when the topology syncer is no longer being refreshed. +func (is *InfoSyncer) TopologyDone() <-chan struct{} { + if is.etcdCli == nil { + return make(chan struct{}, 1) + } + return is.topologySession.Done() +} + +// Restart restart the info syncer with new session leaseID and store server info to etcd again. +func (is *InfoSyncer) Restart(ctx context.Context) error { + return is.newSessionAndStoreServerInfo(ctx, util2.NewSessionDefaultRetryCnt) +} + +// RestartTopology restart the topology syncer with new session leaseID and store server info to etcd again. +func (is *InfoSyncer) RestartTopology(ctx context.Context) error { + return is.newTopologySessionAndStoreServerInfo(ctx, util2.NewSessionDefaultRetryCnt) +} + +// GetAllTiDBTopology gets all tidb topology +func (is *InfoSyncer) GetAllTiDBTopology(ctx context.Context) ([]*TopologyInfo, error) { + topos := make([]*TopologyInfo, 0) + response, err := is.etcdCli.Get(ctx, TopologyInformationPath, clientv3.WithPrefix()) + if err != nil { + return nil, err + } + for _, kv := range response.Kvs { + if !strings.HasSuffix(string(kv.Key), "/info") { + continue + } + var topo *TopologyInfo + err = json.Unmarshal(kv.Value, &topo) + if err != nil { + return nil, err + } + topos = append(topos, topo) + } + return topos, nil +} + +// newSessionAndStoreServerInfo creates a new etcd session and stores server info to etcd. 
+func (is *InfoSyncer) newSessionAndStoreServerInfo(ctx context.Context, retryCnt int) error { + if is.etcdCli == nil { + return nil + } + logPrefix := fmt.Sprintf("[Info-syncer] %s", is.serverInfoPath) + session, err := util2.NewSession(ctx, logPrefix, is.etcdCli, retryCnt, util.SessionTTL) + if err != nil { + return err + } + is.session = session + binloginfo.RegisterStatusListener(func(status binloginfo.BinlogStatus) error { + is.info.BinlogStatus = status.String() + err := is.StoreServerInfo(ctx) + return errors.Trace(err) + }) + return is.StoreServerInfo(ctx) +} + +// newTopologySessionAndStoreServerInfo creates a new etcd session and stores server info to etcd. +func (is *InfoSyncer) newTopologySessionAndStoreServerInfo(ctx context.Context, retryCnt int) error { + if is.etcdCli == nil { + return nil + } + logPrefix := fmt.Sprintf("[topology-syncer] %s/%s", TopologyInformationPath, net.JoinHostPort(is.info.IP, strconv.Itoa(int(is.info.Port)))) + session, err := util2.NewSession(ctx, logPrefix, is.etcdCli, retryCnt, TopologySessionTTL) + if err != nil { + return err + } + + is.topologySession = session + return is.StoreTopologyInfo(ctx) +} + +// refreshTopology refreshes etcd topology with ttl stored in "/topology/tidb/ip:port/ttl". +func (is *InfoSyncer) updateTopologyAliveness(ctx context.Context) error { + if is.etcdCli == nil { + return nil + } + key := fmt.Sprintf("%s/%s/ttl", TopologyInformationPath, net.JoinHostPort(is.info.IP, strconv.Itoa(int(is.info.Port)))) + return util.PutKVToEtcd(ctx, is.etcdCli, keyOpDefaultRetryCnt, key, + fmt.Sprintf("%v", time.Now().UnixNano()), + clientv3.WithLease(is.topologySession.Lease())) +} + +// GetPrometheusAddr gets prometheus Address +func GetPrometheusAddr() (string, error) { + is, err := getGlobalInfoSyncer() + if err != nil { + return "", err + } + + // if the cache of prometheusAddr is over 10s, update the prometheusAddr + if time.Since(is.modifyTime) < TablePrometheusCacheExpiry { + return is.prometheusAddr, nil + } + return is.getPrometheusAddr() +} + +type prometheus struct { + IP string `json:"ip"` + BinaryPath string `json:"binary_path"` + Port int `json:"port"` +} + +type metricStorage struct { + PDServer struct { + MetricStorage string `json:"metric-storage"` + } `json:"pd-server"` +} + +func (is *InfoSyncer) getPrometheusAddr() (string, error) { + // Get PD servers info. + clientAvailable := is.etcdCli != nil + var pdAddrs []string + if clientAvailable { + pdAddrs = is.etcdCli.Endpoints() + } + if !clientAvailable || len(pdAddrs) == 0 { + return "", errors.Errorf("pd unavailable") + } + // Get prometheus address from pdhttp. + url := util2.ComposeURL(pdAddrs[0], pdhttp.Config) + resp, err := util2.InternalHTTPClient().Get(url) + if err != nil { + return "", err + } + defer resp.Body.Close() + var metricStorage metricStorage + dec := json.NewDecoder(resp.Body) + err = dec.Decode(&metricStorage) + if err != nil { + return "", err + } + res := metricStorage.PDServer.MetricStorage + + // Get prometheus address from etcdApi. 
+ if res == "" { + values, err := is.getPrometheusAddrFromEtcd(TopologyPrometheus) + if err != nil { + return "", errors.Trace(err) + } + if values == "" { + return "", ErrPrometheusAddrIsNotSet + } + var prometheus prometheus + err = json.Unmarshal([]byte(values), &prometheus) + if err != nil { + return "", errors.Trace(err) + } + res = fmt.Sprintf("http://%s", net.JoinHostPort(prometheus.IP, strconv.Itoa(prometheus.Port))) + } + is.prometheusAddr = res + is.modifyTime = time.Now() + setGlobalInfoSyncer(is) + return res, nil +} + +func (is *InfoSyncer) getPrometheusAddrFromEtcd(k string) (string, error) { + ctx, cancel := context.WithTimeout(context.Background(), keyOpDefaultTimeout) + resp, err := is.etcdCli.Get(ctx, k) + cancel() + if err != nil { + return "", errors.Trace(err) + } + if len(resp.Kvs) > 0 { + return string(resp.Kvs[0].Value), nil + } + return "", nil +} + +// getInfo gets server information from etcd according to the key and opts. +func getInfo(ctx context.Context, etcdCli *clientv3.Client, key string, retryCnt int, timeout time.Duration, opts ...clientv3.OpOption) (map[string]*ServerInfo, error) { + var err error + var resp *clientv3.GetResponse + allInfo := make(map[string]*ServerInfo) + for i := 0; i < retryCnt; i++ { + select { + case <-ctx.Done(): + err = errors.Trace(ctx.Err()) + return nil, err + default: + } + childCtx, cancel := context.WithTimeout(ctx, timeout) + resp, err = etcdCli.Get(childCtx, key, opts...) + cancel() + if err != nil { + logutil.BgLogger().Info("get key failed", zap.String("key", key), zap.Error(err)) + time.Sleep(200 * time.Millisecond) + continue + } + for _, kv := range resp.Kvs { + info := &ServerInfo{ + BinlogStatus: binloginfo.BinlogStatusUnknown.String(), + } + err = info.Unmarshal(kv.Value) + if err != nil { + logutil.BgLogger().Info("get key failed", zap.String("key", string(kv.Key)), zap.ByteString("value", kv.Value), + zap.Error(err)) + return nil, errors.Trace(err) + } + allInfo[info.ID] = info + } + return allInfo, nil + } + return nil, errors.Trace(err) +} + +// getServerInfo gets self tidb server information. +func getServerInfo(id string, serverIDGetter func() uint64) *ServerInfo { + cfg := config.GetGlobalConfig() + info := &ServerInfo{ + ID: id, + IP: cfg.AdvertiseAddress, + Port: cfg.Port, + StatusPort: cfg.Status.StatusPort, + Lease: cfg.Lease, + BinlogStatus: binloginfo.GetStatus().String(), + StartTimestamp: time.Now().Unix(), + Labels: cfg.Labels, + ServerIDGetter: serverIDGetter, + } + info.Version = mysql.ServerVersion + info.GitHash = versioninfo.TiDBGitHash + + metrics.ServerInfo.WithLabelValues(mysql.TiDBReleaseVersion, info.GitHash).Set(float64(info.StartTimestamp)) + + failpoint.Inject("mockServerInfo", func(val failpoint.Value) { + if val.(bool) { + info.StartTimestamp = 1282967700 + info.Labels = map[string]string{ + "foo": "bar", + } + } + }) + + return info +} + +// PutLabelRule synchronizes the label rule to PD. +func PutLabelRule(ctx context.Context, rule *label.Rule) error { + if rule == nil { + return nil + } + + is, err := getGlobalInfoSyncer() + if err != nil { + return err + } + if is.labelRuleManager == nil { + return nil + } + return is.labelRuleManager.PutLabelRule(ctx, rule) +} + +// UpdateLabelRules synchronizes the label rule to PD. 
+func UpdateLabelRules(ctx context.Context, patch *pdhttp.LabelRulePatch) error { + if patch == nil || (len(patch.DeleteRules) == 0 && len(patch.SetRules) == 0) { + return nil + } + + is, err := getGlobalInfoSyncer() + if err != nil { + return err + } + if is.labelRuleManager == nil { + return nil + } + return is.labelRuleManager.UpdateLabelRules(ctx, patch) +} + +// GetAllLabelRules gets all label rules from PD. +func GetAllLabelRules(ctx context.Context) ([]*label.Rule, error) { + is, err := getGlobalInfoSyncer() + if err != nil { + return nil, err + } + if is.labelRuleManager == nil { + return nil, nil + } + return is.labelRuleManager.GetAllLabelRules(ctx) +} + +// GetLabelRules gets the label rules according to the given IDs from PD. +func GetLabelRules(ctx context.Context, ruleIDs []string) (map[string]*label.Rule, error) { + if len(ruleIDs) == 0 { + return nil, nil + } + + is, err := getGlobalInfoSyncer() + if err != nil { + return nil, err + } + if is.labelRuleManager == nil { + return nil, nil + } + return is.labelRuleManager.GetLabelRules(ctx, ruleIDs) +} + +// CalculateTiFlashProgress calculates TiFlash replica progress +func CalculateTiFlashProgress(tableID int64, replicaCount uint64, tiFlashStores map[int64]pdhttp.StoreInfo) (float64, error) { + is, err := getGlobalInfoSyncer() + if err != nil { + return 0, errors.Trace(err) + } + return is.tiflashReplicaManager.CalculateTiFlashProgress(tableID, replicaCount, tiFlashStores) +} + +// UpdateTiFlashProgressCache updates tiflashProgressCache +func UpdateTiFlashProgressCache(tableID int64, progress float64) error { + is, err := getGlobalInfoSyncer() + if err != nil { + return errors.Trace(err) + } + is.tiflashReplicaManager.UpdateTiFlashProgressCache(tableID, progress) + return nil +} + +// GetTiFlashProgressFromCache gets tiflash replica progress from tiflashProgressCache +func GetTiFlashProgressFromCache(tableID int64) (float64, bool) { + is, err := getGlobalInfoSyncer() + if err != nil { + logutil.BgLogger().Error("GetTiFlashProgressFromCache get info sync failed", zap.Int64("tableID", tableID), zap.Error(err)) + return 0, false + } + return is.tiflashReplicaManager.GetTiFlashProgressFromCache(tableID) +} + +// CleanTiFlashProgressCache clean progress cache +func CleanTiFlashProgressCache() { + is, err := getGlobalInfoSyncer() + if err != nil { + return + } + is.tiflashReplicaManager.CleanTiFlashProgressCache() +} + +// SetTiFlashGroupConfig is a helper function to set tiflash rule group config +func SetTiFlashGroupConfig(ctx context.Context) error { + is, err := getGlobalInfoSyncer() + if err != nil { + return errors.Trace(err) + } + logutil.BgLogger().Info("SetTiFlashGroupConfig") + return is.tiflashReplicaManager.SetTiFlashGroupConfig(ctx) +} + +// SetTiFlashPlacementRule is a helper function to set placement rule. +// It is discouraged to use SetTiFlashPlacementRule directly, +// use `ConfigureTiFlashPDForTable`/`ConfigureTiFlashPDForPartitions` instead. +func SetTiFlashPlacementRule(ctx context.Context, rule pdhttp.Rule) error { + is, err := getGlobalInfoSyncer() + if err != nil { + return errors.Trace(err) + } + logutil.BgLogger().Info("SetTiFlashPlacementRule", zap.String("ruleID", rule.ID)) + return is.tiflashReplicaManager.SetPlacementRule(ctx, &rule) +} + +// DeleteTiFlashPlacementRules is a helper function to delete TiFlash placement rules of given physical table IDs. 
+func DeleteTiFlashPlacementRules(ctx context.Context, physicalTableIDs []int64) error { + is, err := getGlobalInfoSyncer() + if err != nil { + return errors.Trace(err) + } + logutil.BgLogger().Info("DeleteTiFlashPlacementRules", zap.Int64s("physicalTableIDs", physicalTableIDs)) + rules := make([]*pdhttp.Rule, 0, len(physicalTableIDs)) + for _, id := range physicalTableIDs { + // make a rule with count 0 to delete the rule + rule := MakeNewRule(id, 0, nil) + rules = append(rules, &rule) + } + return is.tiflashReplicaManager.SetPlacementRuleBatch(ctx, rules) +} + +// GetTiFlashGroupRules to get all placement rule in a certain group. +func GetTiFlashGroupRules(ctx context.Context, group string) ([]*pdhttp.Rule, error) { + is, err := getGlobalInfoSyncer() + if err != nil { + return nil, errors.Trace(err) + } + return is.tiflashReplicaManager.GetGroupRules(ctx, group) +} + +// GetTiFlashRegionCountFromPD is a helper function calling `/stats/region`. +func GetTiFlashRegionCountFromPD(ctx context.Context, tableID int64, regionCount *int) error { + is, err := getGlobalInfoSyncer() + if err != nil { + return errors.Trace(err) + } + return is.tiflashReplicaManager.GetRegionCountFromPD(ctx, tableID, regionCount) +} + +// GetTiFlashStoresStat gets the TiKV store information by accessing PD's api. +func GetTiFlashStoresStat(ctx context.Context) (*pdhttp.StoresInfo, error) { + is, err := getGlobalInfoSyncer() + if err != nil { + return nil, errors.Trace(err) + } + return is.tiflashReplicaManager.GetStoresStat(ctx) +} + +// CloseTiFlashManager closes TiFlash manager. +func CloseTiFlashManager(ctx context.Context) { + is, err := getGlobalInfoSyncer() + if err != nil { + return + } + is.tiflashReplicaManager.Close(ctx) +} + +// ConfigureTiFlashPDForTable configures pd rule for unpartitioned tables. +func ConfigureTiFlashPDForTable(id int64, count uint64, locationLabels *[]string) error { + is, err := getGlobalInfoSyncer() + if err != nil { + return errors.Trace(err) + } + ctx := context.Background() + logutil.BgLogger().Info("ConfigureTiFlashPDForTable", zap.Int64("tableID", id), zap.Uint64("count", count)) + ruleNew := MakeNewRule(id, count, *locationLabels) + if e := is.tiflashReplicaManager.SetPlacementRule(ctx, &ruleNew); e != nil { + return errors.Trace(e) + } + return nil +} + +// ConfigureTiFlashPDForPartitions configures pd rule for all partition in partitioned tables. +func ConfigureTiFlashPDForPartitions(accel bool, definitions *[]model.PartitionDefinition, count uint64, locationLabels *[]string, tableID int64) error { + is, err := getGlobalInfoSyncer() + if err != nil { + return errors.Trace(err) + } + ctx := context.Background() + rules := make([]*pdhttp.Rule, 0, len(*definitions)) + pids := make([]int64, 0, len(*definitions)) + for _, p := range *definitions { + logutil.BgLogger().Info("ConfigureTiFlashPDForPartitions", zap.Int64("tableID", tableID), zap.Int64("partID", p.ID), zap.Bool("accel", accel), zap.Uint64("count", count)) + ruleNew := MakeNewRule(p.ID, count, *locationLabels) + rules = append(rules, &ruleNew) + pids = append(pids, p.ID) + } + if e := is.tiflashReplicaManager.SetPlacementRuleBatch(ctx, rules); e != nil { + return errors.Trace(e) + } + if accel { + if e := is.tiflashReplicaManager.PostAccelerateScheduleBatch(ctx, pids); e != nil { + return errors.Trace(e) + } + } + return nil +} + +// StoreInternalSession is the entry function for store an internal session to SessionManager. +// return whether the session is stored successfully. 
+func StoreInternalSession(se any) bool { + is, err := getGlobalInfoSyncer() + if err != nil { + return false + } + sm := is.GetSessionManager() + if sm == nil { + return false + } + sm.StoreInternalSession(se) + return true +} + +// DeleteInternalSession is the entry function for delete an internal session from SessionManager. +func DeleteInternalSession(se any) { + is, err := getGlobalInfoSyncer() + if err != nil { + return + } + sm := is.GetSessionManager() + if sm == nil { + return + } + sm.DeleteInternalSession(se) +} + +// SetEtcdClient is only used for test. +func SetEtcdClient(etcdCli *clientv3.Client) { + is, err := getGlobalInfoSyncer() + + if err != nil { + return + } + is.etcdCli = etcdCli +} + +// GetEtcdClient is only used for test. +func GetEtcdClient() *clientv3.Client { + is, err := getGlobalInfoSyncer() + + if err != nil { + return nil + } + return is.etcdCli +} + +// GetPDScheduleConfig gets the schedule information from pd +func GetPDScheduleConfig(ctx context.Context) (map[string]any, error) { + is, err := getGlobalInfoSyncer() + if err != nil { + return nil, errors.Trace(err) + } + return is.scheduleManager.GetScheduleConfig(ctx) +} + +// SetPDScheduleConfig sets the schedule information for pd +func SetPDScheduleConfig(ctx context.Context, config map[string]any) error { + is, err := getGlobalInfoSyncer() + if err != nil { + return errors.Trace(err) + } + return is.scheduleManager.SetScheduleConfig(ctx, config) +} + +// TiProxyServerInfo is the server info for TiProxy. +type TiProxyServerInfo struct { + Version string `json:"version"` + GitHash string `json:"git_hash"` + IP string `json:"ip"` + Port string `json:"port"` + StatusPort string `json:"status_port"` + StartTimestamp int64 `json:"start_timestamp"` +} + +// GetTiProxyServerInfo gets all TiProxy servers information from etcd. +func GetTiProxyServerInfo(ctx context.Context) (map[string]*TiProxyServerInfo, error) { + failpoint.Inject("mockGetTiProxyServerInfo", func(val failpoint.Value) { + res := make(map[string]*TiProxyServerInfo) + err := json.Unmarshal([]byte(val.(string)), &res) + failpoint.Return(res, err) + }) + is, err := getGlobalInfoSyncer() + if err != nil { + return nil, err + } + return is.getTiProxyServerInfo(ctx) +} + +func (is *InfoSyncer) getTiProxyServerInfo(ctx context.Context) (map[string]*TiProxyServerInfo, error) { + // In test. + if is.etcdCli == nil { + return nil, nil + } + + var err error + var resp *clientv3.GetResponse + allInfo := make(map[string]*TiProxyServerInfo) + for i := 0; i < keyOpDefaultRetryCnt; i++ { + if ctx.Err() != nil { + return nil, errors.Trace(ctx.Err()) + } + childCtx, cancel := context.WithTimeout(ctx, keyOpDefaultTimeout) + resp, err = is.etcdCli.Get(childCtx, TopologyTiProxy, clientv3.WithPrefix()) + cancel() + if err != nil { + logutil.BgLogger().Info("get key failed", zap.String("key", TopologyTiProxy), zap.Error(err)) + time.Sleep(200 * time.Millisecond) + continue + } + for _, kv := range resp.Kvs { + key := string(kv.Key) + if !strings.HasSuffix(key, infoSuffix) { + continue + } + addr := key[len(TopologyTiProxy)+1 : len(key)-len(infoSuffix)] + var info TiProxyServerInfo + err = json.Unmarshal(kv.Value, &info) + if err != nil { + logutil.BgLogger().Info("unmarshal key failed", zap.String("key", key), zap.ByteString("value", kv.Value), + zap.Error(err)) + return nil, errors.Trace(err) + } + allInfo[addr] = &info + } + return allInfo, nil + } + return nil, errors.Trace(err) +} + +// TiCDCInfo is the server info for TiCDC. 
+type TiCDCInfo struct { + ID string `json:"id"` + Address string `json:"address"` + Version string `json:"version"` + GitHash string `json:"git-hash"` + DeployPath string `json:"deploy-path"` + StartTimestamp int64 `json:"start-timestamp"` + ClusterID string `json:"cluster-id"` +} + +// GetTiCDCServerInfo gets all TiCDC servers information from etcd. +func GetTiCDCServerInfo(ctx context.Context) ([]*TiCDCInfo, error) { + is, err := getGlobalInfoSyncer() + if err != nil { + return nil, err + } + return is.getTiCDCServerInfo(ctx) +} + +func (is *InfoSyncer) getTiCDCServerInfo(ctx context.Context) ([]*TiCDCInfo, error) { + // In test. + if is.etcdCli == nil { + return nil, nil + } + + var err error + var resp *clientv3.GetResponse + allInfo := make([]*TiCDCInfo, 0) + for i := 0; i < keyOpDefaultRetryCnt; i++ { + if ctx.Err() != nil { + return nil, errors.Trace(ctx.Err()) + } + childCtx, cancel := context.WithTimeout(ctx, keyOpDefaultTimeout) + resp, err = is.etcdCli.Get(childCtx, TopologyTiCDC, clientv3.WithPrefix()) + cancel() + if err != nil { + logutil.BgLogger().Info("get key failed", zap.String("key", TopologyTiCDC), zap.Error(err)) + time.Sleep(200 * time.Millisecond) + continue + } + for _, kv := range resp.Kvs { + key := string(kv.Key) + keyParts := strings.Split(key, "/") + if len(keyParts) < 3 { + logutil.BgLogger().Info("invalid ticdc key", zap.String("key", key)) + continue + } + clusterID := keyParts[1] + + var info TiCDCInfo + err := json.Unmarshal(kv.Value, &info) + if err != nil { + logutil.BgLogger().Info("unmarshal key failed", zap.String("key", key), zap.ByteString("value", kv.Value), + zap.Error(err)) + return nil, errors.Trace(err) + } + info.Version = strings.TrimPrefix(info.Version, "v") + info.ClusterID = clusterID + allInfo = append(allInfo, &info) + } + return allInfo, nil + } + return nil, errors.Trace(err) +} diff --git a/pkg/domain/infosync/tiflash_manager.go b/pkg/domain/infosync/tiflash_manager.go index bd84d5fb4c043..040a7611e50ad 100644 --- a/pkg/domain/infosync/tiflash_manager.go +++ b/pkg/domain/infosync/tiflash_manager.go @@ -95,11 +95,11 @@ func getTiFlashPeerWithoutLagCount(tiFlashStores map[int64]pd.StoreInfo, keyspac for _, store := range tiFlashStores { regionReplica := make(map[int64]int) err := helper.CollectTiFlashStatus(store.Store.StatusAddress, keyspaceID, tableID, ®ionReplica) - failpoint.Inject("OneTiFlashStoreDown", func() { + if _, _err_ := failpoint.Eval(_curpkg_("OneTiFlashStoreDown")); _err_ == nil { if store.Store.StateName == "Down" { err = errors.New("mock TiFlasah down") } - }) + } if err != nil { logutil.BgLogger().Error("Fail to get peer status from TiFlash.", zap.Int64("tableID", tableID)) diff --git a/pkg/domain/infosync/tiflash_manager.go__failpoint_stash__ b/pkg/domain/infosync/tiflash_manager.go__failpoint_stash__ new file mode 100644 index 0000000000000..bd84d5fb4c043 --- /dev/null +++ b/pkg/domain/infosync/tiflash_manager.go__failpoint_stash__ @@ -0,0 +1,893 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package infosync + +import ( + "bytes" + "context" + "encoding/hex" + "fmt" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "sync" + "time" + + "github.com/gorilla/mux" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/tidb/pkg/ddl/placement" + "github.com/pingcap/tidb/pkg/store/helper" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/syncutil" + "github.com/tikv/client-go/v2/tikv" + pd "github.com/tikv/pd/client/http" + "go.uber.org/zap" +) + +var ( + _ TiFlashReplicaManager = &TiFlashReplicaManagerCtx{} + _ TiFlashReplicaManager = &mockTiFlashReplicaManagerCtx{} +) + +// TiFlashReplicaManager manages placement settings and replica progress for TiFlash. +type TiFlashReplicaManager interface { + // SetTiFlashGroupConfig sets the group index of the tiflash placement rule + SetTiFlashGroupConfig(ctx context.Context) error + // SetPlacementRule is a helper function to set placement rule. + SetPlacementRule(ctx context.Context, rule *pd.Rule) error + // SetPlacementRuleBatch is a helper function to set a batch of placement rules. + SetPlacementRuleBatch(ctx context.Context, rules []*pd.Rule) error + // DeletePlacementRule is to delete placement rule for certain group. + DeletePlacementRule(ctx context.Context, group string, ruleID string) error + // GetGroupRules to get all placement rule in a certain group. + GetGroupRules(ctx context.Context, group string) ([]*pd.Rule, error) + // PostAccelerateScheduleBatch sends `regions/accelerate-schedule/batch` request. + PostAccelerateScheduleBatch(ctx context.Context, tableIDs []int64) error + // GetRegionCountFromPD is a helper function calling `/stats/region`. + GetRegionCountFromPD(ctx context.Context, tableID int64, regionCount *int) error + // GetStoresStat gets the TiKV store information by accessing PD's api. + GetStoresStat(ctx context.Context) (*pd.StoresInfo, error) + // CalculateTiFlashProgress calculates TiFlash replica progress + CalculateTiFlashProgress(tableID int64, replicaCount uint64, TiFlashStores map[int64]pd.StoreInfo) (float64, error) + // UpdateTiFlashProgressCache updates tiflashProgressCache + UpdateTiFlashProgressCache(tableID int64, progress float64) + // GetTiFlashProgressFromCache gets tiflash replica progress from tiflashProgressCache + GetTiFlashProgressFromCache(tableID int64) (float64, bool) + // DeleteTiFlashProgressFromCache delete tiflash replica progress from tiflashProgressCache + DeleteTiFlashProgressFromCache(tableID int64) + // CleanTiFlashProgressCache clean progress cache + CleanTiFlashProgressCache() + // Close is to close TiFlashReplicaManager + Close(ctx context.Context) +} + +// TiFlashReplicaManagerCtx manages placement with pd and replica progress for TiFlash. +type TiFlashReplicaManagerCtx struct { + pdHTTPCli pd.Client + sync.RWMutex // protect tiflashProgressCache + tiflashProgressCache map[int64]float64 + codec tikv.Codec +} + +// Close is called to close TiFlashReplicaManagerCtx. 
+func (*TiFlashReplicaManagerCtx) Close(context.Context) {} + +func getTiFlashPeerWithoutLagCount(tiFlashStores map[int64]pd.StoreInfo, keyspaceID tikv.KeyspaceID, tableID int64) (int, error) { + // storeIDs -> regionID, PD will not create two peer on the same store + var flashPeerCount int + for _, store := range tiFlashStores { + regionReplica := make(map[int64]int) + err := helper.CollectTiFlashStatus(store.Store.StatusAddress, keyspaceID, tableID, ®ionReplica) + failpoint.Inject("OneTiFlashStoreDown", func() { + if store.Store.StateName == "Down" { + err = errors.New("mock TiFlasah down") + } + }) + if err != nil { + logutil.BgLogger().Error("Fail to get peer status from TiFlash.", + zap.Int64("tableID", tableID)) + // Just skip down or offline or tomestone stores, because PD will migrate regions from these stores. + if store.Store.StateName == "Up" || store.Store.StateName == "Disconnected" { + return 0, err + } + continue + } + flashPeerCount += len(regionReplica) + } + return flashPeerCount, nil +} + +// calculateTiFlashProgress calculates progress based on the region status from PD and TiFlash. +func calculateTiFlashProgress(keyspaceID tikv.KeyspaceID, tableID int64, replicaCount uint64, tiFlashStores map[int64]pd.StoreInfo) (float64, error) { + var regionCount int + if err := GetTiFlashRegionCountFromPD(context.Background(), tableID, ®ionCount); err != nil { + logutil.BgLogger().Error("Fail to get regionCount from PD.", + zap.Int64("tableID", tableID)) + return 0, errors.Trace(err) + } + + if regionCount == 0 { + logutil.BgLogger().Warn("region count getting from PD is 0.", + zap.Int64("tableID", tableID)) + return 0, fmt.Errorf("region count getting from PD is 0") + } + + tiflashPeerCount, err := getTiFlashPeerWithoutLagCount(tiFlashStores, keyspaceID, tableID) + if err != nil { + logutil.BgLogger().Error("Fail to get peer count from TiFlash.", + zap.Int64("tableID", tableID)) + return 0, errors.Trace(err) + } + progress := float64(tiflashPeerCount) / float64(regionCount*int(replicaCount)) + if progress > 1 { // when pd do balance + logutil.BgLogger().Debug("TiFlash peer count > pd peer count, maybe doing balance.", + zap.Int64("tableID", tableID), zap.Int("tiflashPeerCount", tiflashPeerCount), zap.Int("regionCount", regionCount), zap.Uint64("replicaCount", replicaCount)) + progress = 1 + } + if progress < 1 { + logutil.BgLogger().Debug("TiFlash replica progress < 1.", + zap.Int64("tableID", tableID), zap.Int("tiflashPeerCount", tiflashPeerCount), zap.Int("regionCount", regionCount), zap.Uint64("replicaCount", replicaCount)) + } + return progress, nil +} + +func encodeRule(c tikv.Codec, rule *pd.Rule) { + rule.StartKey, rule.EndKey = c.EncodeRange(rule.StartKey, rule.EndKey) + rule.ID = encodeRuleID(c, rule.ID) +} + +// encodeRule encodes the rule ID by the following way: +// 1. if the codec is in API V1 then the rule ID is not encoded, should be like "table--r". +// 2. if the codec is in API V2 then the rule ID is encoded, +// should be like "keyspace--table--r". +func encodeRuleID(c tikv.Codec, ruleID string) string { + if c.GetAPIVersion() == kvrpcpb.APIVersion_V2 { + return fmt.Sprintf("keyspace-%v-%s", c.GetKeyspaceID(), ruleID) + } + return ruleID +} + +// CalculateTiFlashProgress calculates TiFlash replica progress. 
+func (m *TiFlashReplicaManagerCtx) CalculateTiFlashProgress(tableID int64, replicaCount uint64, tiFlashStores map[int64]pd.StoreInfo) (float64, error) { + return calculateTiFlashProgress(m.codec.GetKeyspaceID(), tableID, replicaCount, tiFlashStores) +} + +// UpdateTiFlashProgressCache updates tiflashProgressCache +func (m *TiFlashReplicaManagerCtx) UpdateTiFlashProgressCache(tableID int64, progress float64) { + m.Lock() + defer m.Unlock() + m.tiflashProgressCache[tableID] = progress +} + +// GetTiFlashProgressFromCache gets tiflash replica progress from tiflashProgressCache +func (m *TiFlashReplicaManagerCtx) GetTiFlashProgressFromCache(tableID int64) (float64, bool) { + m.RLock() + defer m.RUnlock() + progress, ok := m.tiflashProgressCache[tableID] + return progress, ok +} + +// DeleteTiFlashProgressFromCache delete tiflash replica progress from tiflashProgressCache +func (m *TiFlashReplicaManagerCtx) DeleteTiFlashProgressFromCache(tableID int64) { + m.Lock() + defer m.Unlock() + delete(m.tiflashProgressCache, tableID) +} + +// CleanTiFlashProgressCache clean progress cache +func (m *TiFlashReplicaManagerCtx) CleanTiFlashProgressCache() { + m.Lock() + defer m.Unlock() + m.tiflashProgressCache = make(map[int64]float64) +} + +// SetTiFlashGroupConfig sets the tiflash's rule group config +func (m *TiFlashReplicaManagerCtx) SetTiFlashGroupConfig(ctx context.Context) error { + groupConfig, err := m.pdHTTPCli.GetPlacementRuleGroupByID(ctx, placement.TiFlashRuleGroupID) + if err != nil { + return errors.Trace(err) + } + if groupConfig != nil && groupConfig.Index == placement.RuleIndexTiFlash && !groupConfig.Override { + return nil + } + groupConfig = &pd.RuleGroup{ + ID: placement.TiFlashRuleGroupID, + Index: placement.RuleIndexTiFlash, + Override: false, + } + return m.pdHTTPCli.SetPlacementRuleGroup(ctx, groupConfig) +} + +// SetPlacementRule is a helper function to set placement rule. +func (m *TiFlashReplicaManagerCtx) SetPlacementRule(ctx context.Context, rule *pd.Rule) error { + encodeRule(m.codec, rule) + return m.doSetPlacementRule(ctx, rule) +} + +func (m *TiFlashReplicaManagerCtx) doSetPlacementRule(ctx context.Context, rule *pd.Rule) error { + if err := m.SetTiFlashGroupConfig(ctx); err != nil { + return err + } + if rule.Count == 0 { + return m.pdHTTPCli.DeletePlacementRule(ctx, rule.GroupID, rule.ID) + } + return m.pdHTTPCli.SetPlacementRule(ctx, rule) +} + +// SetPlacementRuleBatch is a helper function to set a batch of placement rules. +func (m *TiFlashReplicaManagerCtx) SetPlacementRuleBatch(ctx context.Context, rules []*pd.Rule) error { + r := make([]*pd.Rule, 0, len(rules)) + for _, rule := range rules { + encodeRule(m.codec, rule) + r = append(r, rule) + } + return m.doSetPlacementRuleBatch(ctx, r) +} + +func (m *TiFlashReplicaManagerCtx) doSetPlacementRuleBatch(ctx context.Context, rules []*pd.Rule) error { + if err := m.SetTiFlashGroupConfig(ctx); err != nil { + return err + } + ruleOps := make([]*pd.RuleOp, 0, len(rules)) + for i, r := range rules { + if r.Count == 0 { + ruleOps = append(ruleOps, &pd.RuleOp{ + Rule: rules[i], + Action: pd.RuleOpDel, + }) + } else { + ruleOps = append(ruleOps, &pd.RuleOp{ + Rule: rules[i], + Action: pd.RuleOpAdd, + }) + } + } + return m.pdHTTPCli.SetPlacementRuleInBatch(ctx, ruleOps) +} + +// DeletePlacementRule is to delete placement rule for certain group. 
+func (m *TiFlashReplicaManagerCtx) DeletePlacementRule(ctx context.Context, group string, ruleID string) error { + ruleID = encodeRuleID(m.codec, ruleID) + return m.pdHTTPCli.DeletePlacementRule(ctx, group, ruleID) +} + +// GetGroupRules to get all placement rule in a certain group. +func (m *TiFlashReplicaManagerCtx) GetGroupRules(ctx context.Context, group string) ([]*pd.Rule, error) { + return m.pdHTTPCli.GetPlacementRulesByGroup(ctx, group) +} + +// PostAccelerateScheduleBatch sends `regions/batch-accelerate-schedule` request. +func (m *TiFlashReplicaManagerCtx) PostAccelerateScheduleBatch(ctx context.Context, tableIDs []int64) error { + if len(tableIDs) == 0 { + return nil + } + input := make([]*pd.KeyRange, 0, len(tableIDs)) + for _, tableID := range tableIDs { + startKey := tablecodec.GenTableRecordPrefix(tableID) + endKey := tablecodec.EncodeTablePrefix(tableID + 1) + startKey, endKey = m.codec.EncodeRegionRange(startKey, endKey) + input = append(input, pd.NewKeyRange(startKey, endKey)) + } + return m.pdHTTPCli.AccelerateScheduleInBatch(ctx, input) +} + +// GetRegionCountFromPD is a helper function calling `/stats/region`. +func (m *TiFlashReplicaManagerCtx) GetRegionCountFromPD(ctx context.Context, tableID int64, regionCount *int) error { + startKey := tablecodec.GenTableRecordPrefix(tableID) + endKey := tablecodec.EncodeTablePrefix(tableID + 1) + startKey, endKey = m.codec.EncodeRegionRange(startKey, endKey) + stats, err := m.pdHTTPCli.GetRegionStatusByKeyRange(ctx, pd.NewKeyRange(startKey, endKey), true) + if err != nil { + return err + } + *regionCount = stats.Count + return nil +} + +// GetStoresStat gets the TiKV store information by accessing PD's api. +func (m *TiFlashReplicaManagerCtx) GetStoresStat(ctx context.Context) (*pd.StoresInfo, error) { + return m.pdHTTPCli.GetStores(ctx) +} + +type mockTiFlashReplicaManagerCtx struct { + sync.RWMutex + // Set to nil if there is no need to set up a mock TiFlash server. + // Otherwise use NewMockTiFlash to create one. + tiflash *MockTiFlash + tiflashProgressCache map[int64]float64 +} + +func makeBaseRule() pd.Rule { + return pd.Rule{ + GroupID: placement.TiFlashRuleGroupID, + ID: "", + Index: placement.RuleIndexTiFlash, + Override: false, + Role: pd.Learner, + Count: 2, + LabelConstraints: []pd.LabelConstraint{ + { + Key: "engine", + Op: pd.In, + Values: []string{"tiflash"}, + }, + }, + } +} + +// MakeNewRule creates a pd rule for TiFlash. +func MakeNewRule(id int64, count uint64, locationLabels []string) pd.Rule { + ruleID := MakeRuleID(id) + startKey := tablecodec.GenTableRecordPrefix(id) + endKey := tablecodec.EncodeTablePrefix(id + 1) + + ruleNew := makeBaseRule() + ruleNew.ID = ruleID + ruleNew.StartKey = startKey + ruleNew.EndKey = endKey + ruleNew.Count = int(count) + ruleNew.LocationLabels = locationLabels + + return ruleNew +} + +// MakeRuleID creates a rule ID for TiFlash with given TableID. +// This interface is exported for the module who wants to manipulate the TiFlash rule. +// The rule ID is in the format of "table--r". +// NOTE: PLEASE DO NOT write the rule ID manually, use this interface instead. 
+func MakeRuleID(id int64) string { + return fmt.Sprintf("table-%v-r", id) +} + +type mockTiFlashTableInfo struct { + Regions []int + Accel bool +} + +func (m *mockTiFlashTableInfo) String() string { + regionStr := "" + for _, s := range m.Regions { + regionStr = regionStr + strconv.Itoa(s) + "\n" + } + if regionStr == "" { + regionStr = "\n" + } + return fmt.Sprintf("%v\n%v", len(m.Regions), regionStr) +} + +// MockTiFlash mocks a TiFlash, with necessary Pd support. +type MockTiFlash struct { + syncutil.Mutex + groupIndex int + StatusAddr string + StatusServer *httptest.Server + SyncStatus map[int]mockTiFlashTableInfo + StoreInfo map[uint64]pd.MetaStore + GlobalTiFlashPlacementRules map[string]*pd.Rule + PdEnabled bool + TiflashDelay time.Duration + StartTime time.Time + NotAvailable bool + NetworkError bool +} + +func (tiflash *MockTiFlash) setUpMockTiFlashHTTPServer() { + tiflash.Lock() + defer tiflash.Unlock() + // mock TiFlash http server + router := mux.NewRouter() + server := httptest.NewServer(router) + // mock store stats stat + statusAddr := strings.TrimPrefix(server.URL, "http://") + statusAddrVec := strings.Split(statusAddr, ":") + statusPort, _ := strconv.Atoi(statusAddrVec[1]) + router.HandleFunc("/tiflash/sync-status/keyspace/{keyspaceid:\\d+}/table/{tableid:\\d+}", func(w http.ResponseWriter, req *http.Request) { + tiflash.Lock() + defer tiflash.Unlock() + if tiflash.NetworkError { + w.WriteHeader(http.StatusNotFound) + return + } + params := mux.Vars(req) + tableID, err := strconv.Atoi(params["tableid"]) + if err != nil { + w.WriteHeader(http.StatusNotFound) + return + } + table, ok := tiflash.SyncStatus[tableID] + if tiflash.NotAvailable { + // No region is available, so the table is not available. + table.Regions = []int{} + } + if !ok { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("0\n\n")) + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(table.String())) + }) + router.HandleFunc("/config", func(w http.ResponseWriter, _ *http.Request) { + tiflash.Lock() + defer tiflash.Unlock() + s := fmt.Sprintf("{\n \"engine-store\": {\n \"http_port\": %v\n }\n}", statusPort) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(s)) + }) + tiflash.StatusServer = server + tiflash.StatusAddr = statusAddr +} + +// NewMockTiFlash creates a MockTiFlash with a mocked TiFlash server. +func NewMockTiFlash() *MockTiFlash { + tiflash := &MockTiFlash{ + StatusAddr: "", + StatusServer: nil, + SyncStatus: make(map[int]mockTiFlashTableInfo), + StoreInfo: make(map[uint64]pd.MetaStore), + GlobalTiFlashPlacementRules: make(map[string]*pd.Rule), + PdEnabled: true, + TiflashDelay: 0, + StartTime: time.Now(), + NotAvailable: false, + } + tiflash.setUpMockTiFlashHTTPServer() + return tiflash +} + +// HandleSetPlacementRule is mock function for SetTiFlashPlacementRule. 
+func (tiflash *MockTiFlash) HandleSetPlacementRule(rule *pd.Rule) error { + tiflash.Lock() + defer tiflash.Unlock() + tiflash.groupIndex = placement.RuleIndexTiFlash + if !tiflash.PdEnabled { + logutil.BgLogger().Info("pd server is manually disabled, just quit") + return nil + } + + if rule.Count == 0 { + delete(tiflash.GlobalTiFlashPlacementRules, rule.ID) + } else { + tiflash.GlobalTiFlashPlacementRules[rule.ID] = rule + } + // Pd shall schedule TiFlash, we can mock here + tid := 0 + _, err := fmt.Sscanf(rule.ID, "table-%d-r", &tid) + if err != nil { + return errors.New("Can't parse rule") + } + // Set up mock TiFlash replica + f := func() { + if z, ok := tiflash.SyncStatus[tid]; ok { + z.Regions = []int{1} + tiflash.SyncStatus[tid] = z + } else { + tiflash.SyncStatus[tid] = mockTiFlashTableInfo{ + Regions: []int{1}, + Accel: false, + } + } + } + if tiflash.TiflashDelay > 0 { + go func() { + time.Sleep(tiflash.TiflashDelay) + logutil.BgLogger().Warn("TiFlash replica is available after delay", zap.Duration("duration", tiflash.TiflashDelay)) + f() + }() + } else { + f() + } + return nil +} + +// HandleSetPlacementRuleBatch is mock function for batch SetTiFlashPlacementRule. +func (tiflash *MockTiFlash) HandleSetPlacementRuleBatch(rules []*pd.Rule) error { + for _, r := range rules { + if err := tiflash.HandleSetPlacementRule(r); err != nil { + return err + } + } + return nil +} + +// ResetSyncStatus is mock function for reset sync status. +func (tiflash *MockTiFlash) ResetSyncStatus(tableID int, canAvailable bool) { + tiflash.Lock() + defer tiflash.Unlock() + if canAvailable { + if z, ok := tiflash.SyncStatus[tableID]; ok { + z.Regions = []int{1} + tiflash.SyncStatus[tableID] = z + } else { + tiflash.SyncStatus[tableID] = mockTiFlashTableInfo{ + Regions: []int{1}, + Accel: false, + } + } + } else { + delete(tiflash.SyncStatus, tableID) + } +} + +// HandleDeletePlacementRule is mock function for DeleteTiFlashPlacementRule. +func (tiflash *MockTiFlash) HandleDeletePlacementRule(_ string, ruleID string) { + tiflash.Lock() + defer tiflash.Unlock() + delete(tiflash.GlobalTiFlashPlacementRules, ruleID) +} + +// HandleGetGroupRules is mock function for GetTiFlashGroupRules. +func (tiflash *MockTiFlash) HandleGetGroupRules(_ string) ([]*pd.Rule, error) { + tiflash.Lock() + defer tiflash.Unlock() + var result = make([]*pd.Rule, 0) + for _, item := range tiflash.GlobalTiFlashPlacementRules { + result = append(result, item) + } + return result, nil +} + +// HandlePostAccelerateSchedule is mock function for PostAccelerateSchedule +func (tiflash *MockTiFlash) HandlePostAccelerateSchedule(endKey string) error { + tiflash.Lock() + defer tiflash.Unlock() + tableID := helper.GetTiFlashTableIDFromEndKey(endKey) + + table, ok := tiflash.SyncStatus[int(tableID)] + if ok { + table.Accel = true + tiflash.SyncStatus[int(tableID)] = table + } else { + tiflash.SyncStatus[int(tableID)] = mockTiFlashTableInfo{ + Regions: []int{}, + Accel: true, + } + } + return nil +} + +// HandleGetPDRegionRecordStats is mock function for GetRegionCountFromPD. +// It currently always returns 1 Region for convenience. +func (*MockTiFlash) HandleGetPDRegionRecordStats(int64) pd.RegionStats { + return pd.RegionStats{ + Count: 1, + } +} + +// AddStore is mock function for adding store info into MockTiFlash. 
+func (tiflash *MockTiFlash) AddStore(storeID uint64, address string) { + tiflash.StoreInfo[storeID] = pd.MetaStore{ + ID: int64(storeID), + Address: address, + State: 0, + StateName: "Up", + Version: "4.0.0-alpha", + StatusAddress: tiflash.StatusAddr, + GitHash: "mock-tikv-githash", + StartTimestamp: tiflash.StartTime.Unix(), + Labels: []pd.StoreLabel{{ + Key: "engine", + Value: "tiflash", + }}, + } +} + +// HandleGetStoresStat is mock function for GetStoresStat. +// It returns address of our mocked TiFlash server. +func (tiflash *MockTiFlash) HandleGetStoresStat() *pd.StoresInfo { + tiflash.Lock() + defer tiflash.Unlock() + if len(tiflash.StoreInfo) == 0 { + // default Store + return &pd.StoresInfo{ + Count: 1, + Stores: []pd.StoreInfo{ + { + Store: pd.MetaStore{ + ID: 1, + Address: "127.0.0.1:3930", + State: 0, + StateName: "Up", + Version: "4.0.0-alpha", + StatusAddress: tiflash.StatusAddr, + GitHash: "mock-tikv-githash", + StartTimestamp: tiflash.StartTime.Unix(), + Labels: []pd.StoreLabel{{ + Key: "engine", + Value: "tiflash", + }}, + }, + }, + }, + } + } + stores := make([]pd.StoreInfo, 0, len(tiflash.StoreInfo)) + for _, storeInfo := range tiflash.StoreInfo { + stores = append(stores, pd.StoreInfo{Store: storeInfo, Status: pd.StoreStatus{}}) + } + return &pd.StoresInfo{ + Count: len(tiflash.StoreInfo), + Stores: stores, + } +} + +// SetRuleGroupIndex sets the group index of tiflash +func (tiflash *MockTiFlash) SetRuleGroupIndex(groupIndex int) { + tiflash.Lock() + defer tiflash.Unlock() + tiflash.groupIndex = groupIndex +} + +// GetRuleGroupIndex gets the group index of tiflash +func (tiflash *MockTiFlash) GetRuleGroupIndex() int { + tiflash.Lock() + defer tiflash.Unlock() + return tiflash.groupIndex +} + +// Compare supposed rule, and we actually get from TableInfo +func isRuleMatch(rule pd.Rule, startKey []byte, endKey []byte, count int, labels []string) bool { + // Compute startKey + if !(bytes.Equal(rule.StartKey, startKey) && bytes.Equal(rule.EndKey, endKey)) { + return false + } + ok := false + for _, c := range rule.LabelConstraints { + if c.Key == "engine" && len(c.Values) == 1 && c.Values[0] == "tiflash" && c.Op == pd.In { + ok = true + break + } + } + if !ok { + return false + } + + if len(rule.LocationLabels) != len(labels) { + return false + } + for i, lb := range labels { + if lb != rule.LocationLabels[i] { + return false + } + } + if rule.Count != count { + return false + } + if rule.Role != pd.Learner { + return false + } + return true +} + +// CheckPlacementRule find if a given rule precisely matches already set rules. +func (tiflash *MockTiFlash) CheckPlacementRule(rule pd.Rule) bool { + tiflash.Lock() + defer tiflash.Unlock() + for _, r := range tiflash.GlobalTiFlashPlacementRules { + if isRuleMatch(rule, r.StartKey, r.EndKey, r.Count, r.LocationLabels) { + return true + } + } + return false +} + +// GetPlacementRule find a rule by name. +func (tiflash *MockTiFlash) GetPlacementRule(ruleName string) (*pd.Rule, bool) { + tiflash.Lock() + defer tiflash.Unlock() + if r, ok := tiflash.GlobalTiFlashPlacementRules[ruleName]; ok { + p := r + return p, ok + } + return nil, false +} + +// CleanPlacementRules cleans all placement rules. +func (tiflash *MockTiFlash) CleanPlacementRules() { + tiflash.Lock() + defer tiflash.Unlock() + tiflash.GlobalTiFlashPlacementRules = make(map[string]*pd.Rule) +} + +// PlacementRulesLen gets length of all currently set placement rules. 
+func (tiflash *MockTiFlash) PlacementRulesLen() int { + tiflash.Lock() + defer tiflash.Unlock() + return len(tiflash.GlobalTiFlashPlacementRules) +} + +// GetTableSyncStatus returns table sync status by given tableID. +func (tiflash *MockTiFlash) GetTableSyncStatus(tableID int) (*mockTiFlashTableInfo, bool) { + tiflash.Lock() + defer tiflash.Unlock() + if r, ok := tiflash.SyncStatus[tableID]; ok { + p := r + return &p, ok + } + return nil, false +} + +// PdSwitch controls if pd is enabled. +func (tiflash *MockTiFlash) PdSwitch(enabled bool) { + tiflash.Lock() + defer tiflash.Unlock() + tiflash.PdEnabled = enabled +} + +// SetNetworkError sets network error state. +func (tiflash *MockTiFlash) SetNetworkError(e bool) { + tiflash.Lock() + defer tiflash.Unlock() + tiflash.NetworkError = e +} + +// CalculateTiFlashProgress return truncated string to avoid float64 comparison. +func (*mockTiFlashReplicaManagerCtx) CalculateTiFlashProgress(tableID int64, replicaCount uint64, tiFlashStores map[int64]pd.StoreInfo) (float64, error) { + return calculateTiFlashProgress(tikv.NullspaceID, tableID, replicaCount, tiFlashStores) +} + +// UpdateTiFlashProgressCache updates tiflashProgressCache +func (m *mockTiFlashReplicaManagerCtx) UpdateTiFlashProgressCache(tableID int64, progress float64) { + m.Lock() + defer m.Unlock() + m.tiflashProgressCache[tableID] = progress +} + +// GetTiFlashProgressFromCache gets tiflash replica progress from tiflashProgressCache +func (m *mockTiFlashReplicaManagerCtx) GetTiFlashProgressFromCache(tableID int64) (float64, bool) { + m.RLock() + defer m.RUnlock() + progress, ok := m.tiflashProgressCache[tableID] + return progress, ok +} + +// DeleteTiFlashProgressFromCache delete tiflash replica progress from tiflashProgressCache +func (m *mockTiFlashReplicaManagerCtx) DeleteTiFlashProgressFromCache(tableID int64) { + m.Lock() + defer m.Unlock() + delete(m.tiflashProgressCache, tableID) +} + +// CleanTiFlashProgressCache clean progress cache +func (m *mockTiFlashReplicaManagerCtx) CleanTiFlashProgressCache() { + m.Lock() + defer m.Unlock() + m.tiflashProgressCache = make(map[int64]float64) +} + +// SetMockTiFlash is set a mock TiFlash server. +func (m *mockTiFlashReplicaManagerCtx) SetMockTiFlash(tiflash *MockTiFlash) { + m.Lock() + defer m.Unlock() + m.tiflash = tiflash +} + +// SetTiFlashGroupConfig sets the tiflash's rule group config +func (m *mockTiFlashReplicaManagerCtx) SetTiFlashGroupConfig(_ context.Context) error { + m.Lock() + defer m.Unlock() + if m.tiflash == nil { + return nil + } + m.tiflash.SetRuleGroupIndex(placement.RuleIndexTiFlash) + return nil +} + +// SetPlacementRule is a helper function to set placement rule. +func (m *mockTiFlashReplicaManagerCtx) SetPlacementRule(_ context.Context, rule *pd.Rule) error { + m.Lock() + defer m.Unlock() + if m.tiflash == nil { + return nil + } + return m.tiflash.HandleSetPlacementRule(rule) +} + +// SetPlacementRuleBatch is a helper function to set a batch of placement rules. +func (m *mockTiFlashReplicaManagerCtx) SetPlacementRuleBatch(_ context.Context, rules []*pd.Rule) error { + m.Lock() + defer m.Unlock() + if m.tiflash == nil { + return nil + } + return m.tiflash.HandleSetPlacementRuleBatch(rules) +} + +// DeletePlacementRule is to delete placement rule for certain group. 
+func (m *mockTiFlashReplicaManagerCtx) DeletePlacementRule(_ context.Context, group string, ruleID string) error { + m.Lock() + defer m.Unlock() + if m.tiflash == nil { + return nil + } + logutil.BgLogger().Info("Remove TiFlash rule", zap.String("ruleID", ruleID)) + m.tiflash.HandleDeletePlacementRule(group, ruleID) + return nil +} + +// GetGroupRules to get all placement rule in a certain group. +func (m *mockTiFlashReplicaManagerCtx) GetGroupRules(_ context.Context, group string) ([]*pd.Rule, error) { + m.Lock() + defer m.Unlock() + if m.tiflash == nil { + return []*pd.Rule{}, nil + } + return m.tiflash.HandleGetGroupRules(group) +} + +// PostAccelerateScheduleBatch sends `regions/batch-accelerate-schedule` request. +func (m *mockTiFlashReplicaManagerCtx) PostAccelerateScheduleBatch(_ context.Context, tableIDs []int64) error { + m.Lock() + defer m.Unlock() + if m.tiflash == nil { + return nil + } + for _, tableID := range tableIDs { + endKey := tablecodec.EncodeTablePrefix(tableID + 1) + endKey = codec.EncodeBytes([]byte{}, endKey) + if err := m.tiflash.HandlePostAccelerateSchedule(hex.EncodeToString(endKey)); err != nil { + return err + } + } + return nil +} + +// GetRegionCountFromPD is a helper function calling `/stats/region`. +func (m *mockTiFlashReplicaManagerCtx) GetRegionCountFromPD(_ context.Context, tableID int64, regionCount *int) error { + m.Lock() + defer m.Unlock() + if m.tiflash == nil { + return nil + } + stats := m.tiflash.HandleGetPDRegionRecordStats(tableID) + *regionCount = stats.Count + return nil +} + +// GetStoresStat gets the TiKV store information by accessing PD's api. +func (m *mockTiFlashReplicaManagerCtx) GetStoresStat(_ context.Context) (*pd.StoresInfo, error) { + m.Lock() + defer m.Unlock() + if m.tiflash == nil { + return nil, &MockTiFlashError{"MockTiFlash is not accessible"} + } + return m.tiflash.HandleGetStoresStat(), nil +} + +// Close is called to close mockTiFlashReplicaManager. +func (m *mockTiFlashReplicaManagerCtx) Close(_ context.Context) { + m.Lock() + defer m.Unlock() + if m.tiflash == nil { + return + } + if m.tiflash.StatusServer != nil { + m.tiflash.StatusServer.Close() + } +} + +// MockTiFlashError represents MockTiFlash error +type MockTiFlashError struct { + Message string +} + +func (me *MockTiFlashError) Error() string { + return me.Message +} diff --git a/pkg/domain/plan_replayer_dump.go b/pkg/domain/plan_replayer_dump.go index 1a5b95e4eb0c2..55a777ab7d556 100644 --- a/pkg/domain/plan_replayer_dump.go +++ b/pkg/domain/plan_replayer_dump.go @@ -308,11 +308,11 @@ func DumpPlanReplayerInfo(ctx context.Context, sctx sessionctx.Context, errMsgs = append(errMsgs, fallbackMsg) } } else { - failpoint.Inject("shouldDumpStats", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("shouldDumpStats")); _err_ == nil { if val.(bool) { panic("shouldDumpStats") } - }) + } } } else { // Dump stats diff --git a/pkg/domain/plan_replayer_dump.go__failpoint_stash__ b/pkg/domain/plan_replayer_dump.go__failpoint_stash__ new file mode 100644 index 0000000000000..1a5b95e4eb0c2 --- /dev/null +++ b/pkg/domain/plan_replayer_dump.go__failpoint_stash__ @@ -0,0 +1,923 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package domain + +import ( + "archive/zip" + "context" + "encoding/json" + "fmt" + "io" + "strconv" + "strings" + + "github.com/BurntSushi/toml" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/bindinfo" + "github.com/pingcap/tidb/pkg/config" + domain_metrics "github.com/pingcap/tidb/pkg/domain/metrics" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/statistics" + "github.com/pingcap/tidb/pkg/statistics/handle/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/printer" + "github.com/pingcap/tidb/pkg/util/sqlexec" + "go.uber.org/zap" +) + +const ( + // PlanReplayerSQLMetaFile indicates sql meta path for plan replayer + PlanReplayerSQLMetaFile = "sql_meta.toml" + // PlanReplayerConfigFile indicates config file path for plan replayer + PlanReplayerConfigFile = "config.toml" + // PlanReplayerMetaFile meta file path for plan replayer + PlanReplayerMetaFile = "meta.txt" + // PlanReplayerVariablesFile indicates for session variables file path for plan replayer + PlanReplayerVariablesFile = "variables.toml" + // PlanReplayerTiFlashReplicasFile indicates for table tiflash replica file path for plan replayer + PlanReplayerTiFlashReplicasFile = "table_tiflash_replica.txt" + // PlanReplayerSessionBindingFile indicates session binding file path for plan replayer + PlanReplayerSessionBindingFile = "session_bindings.sql" + // PlanReplayerGlobalBindingFile indicates global binding file path for plan replayer + PlanReplayerGlobalBindingFile = "global_bindings.sql" + // PlanReplayerSchemaMetaFile indicates the schema meta + PlanReplayerSchemaMetaFile = "schema_meta.txt" + // PlanReplayerErrorMessageFile is the file name for error messages + PlanReplayerErrorMessageFile = "errors.txt" +) + +const ( + // PlanReplayerSQLMetaStartTS indicates the startTS in plan replayer sql meta + PlanReplayerSQLMetaStartTS = "startTS" + // PlanReplayerTaskMetaIsCapture indicates whether this task is capture task + PlanReplayerTaskMetaIsCapture = "isCapture" + // PlanReplayerTaskMetaIsContinues indicates whether this task is continues task + PlanReplayerTaskMetaIsContinues = "isContinues" + // PlanReplayerTaskMetaSQLDigest indicates the sql digest of this task + PlanReplayerTaskMetaSQLDigest = "sqlDigest" + // PlanReplayerTaskMetaPlanDigest indicates the plan digest of this task + PlanReplayerTaskMetaPlanDigest = "planDigest" + // PlanReplayerTaskEnableHistoricalStats indicates whether the task is using historical stats + PlanReplayerTaskEnableHistoricalStats = "enableHistoricalStats" + // PlanReplayerHistoricalStatsTS indicates the expected TS of the historical stats if it's specified by the user. 
+ PlanReplayerHistoricalStatsTS = "historicalStatsTS" +) + +type tableNamePair struct { + DBName string + TableName string + IsView bool +} + +type tableNameExtractor struct { + ctx context.Context + executor sqlexec.RestrictedSQLExecutor + is infoschema.InfoSchema + curDB model.CIStr + names map[tableNamePair]struct{} + cteNames map[string]struct{} + err error +} + +func (tne *tableNameExtractor) getTablesAndViews() map[tableNamePair]struct{} { + r := make(map[tableNamePair]struct{}) + for tablePair := range tne.names { + if tablePair.IsView { + r[tablePair] = struct{}{} + continue + } + // remove cte in table names + _, ok := tne.cteNames[tablePair.TableName] + if !ok { + r[tablePair] = struct{}{} + } + } + return r +} + +func (*tableNameExtractor) Enter(in ast.Node) (ast.Node, bool) { + if _, ok := in.(*ast.TableName); ok { + return in, true + } + return in, false +} + +func (tne *tableNameExtractor) Leave(in ast.Node) (ast.Node, bool) { + if tne.err != nil { + return in, true + } + if t, ok := in.(*ast.TableName); ok { + isView, err := tne.handleIsView(t) + if err != nil { + tne.err = err + return in, true + } + if tne.is.TableExists(t.Schema, t.Name) { + tp := tableNamePair{DBName: t.Schema.L, TableName: t.Name.L, IsView: isView} + if tp.DBName == "" { + tp.DBName = tne.curDB.L + } + tne.names[tp] = struct{}{} + } + } else if s, ok := in.(*ast.SelectStmt); ok { + if s.With != nil && len(s.With.CTEs) > 0 { + for _, cte := range s.With.CTEs { + tne.cteNames[cte.Name.L] = struct{}{} + } + } + } + return in, true +} + +func (tne *tableNameExtractor) handleIsView(t *ast.TableName) (bool, error) { + schema := t.Schema + if schema.L == "" { + schema = tne.curDB + } + table := t.Name + isView := infoschema.TableIsView(tne.is, schema, table) + if !isView { + return false, nil + } + viewTbl, err := tne.is.TableByName(context.Background(), schema, table) + if err != nil { + return false, err + } + sql := viewTbl.Meta().View.SelectStmt + node, err := tne.executor.ParseWithParams(tne.ctx, sql) + if err != nil { + return false, err + } + node.Accept(tne) + return true, nil +} + +// DumpPlanReplayerInfo will dump the information about sqls. +// The files will be organized into the following format: +/* + |-sql_meta.toml + |-meta.txt + |-schema + | |-schema_meta.txt + | |-db1.table1.schema.txt + | |-db2.table2.schema.txt + | |-.... + |-view + | |-db1.view1.view.txt + | |-db2.view2.view.txt + | |-.... + |-stats + | |-stats1.json + | |-stats2.json + | |-.... + |-statsMem + | |-stats1.txt + | |-stats2.txt + | |-.... + |-config.toml + |-table_tiflash_replica.txt + |-variables.toml + |-bindings.sql + |-sql + | |-sql1.sql + | |-sql2.sql + | |-.... 
+ |-explain.txt +*/ +func DumpPlanReplayerInfo(ctx context.Context, sctx sessionctx.Context, + task *PlanReplayerDumpTask) (err error) { + zf := task.Zf + fileName := task.FileName + sessionVars := task.SessionVars + execStmts := task.ExecStmts + zw := zip.NewWriter(zf) + var records []PlanReplayerStatusRecord + var errMsgs []string + sqls := make([]string, 0) + for _, execStmt := range task.ExecStmts { + sqls = append(sqls, execStmt.Text()) + } + if task.IsCapture { + logutil.BgLogger().Info("start to dump plan replayer result", zap.String("category", "plan-replayer-dump"), + zap.String("sql-digest", task.SQLDigest), + zap.String("plan-digest", task.PlanDigest), + zap.Strings("sql", sqls), + zap.Bool("isContinues", task.IsContinuesCapture)) + } else { + logutil.BgLogger().Info("start to dump plan replayer result", zap.String("category", "plan-replayer-dump"), + zap.Strings("sqls", sqls)) + } + defer func() { + errMsg := "" + if err != nil { + if task.IsCapture { + logutil.BgLogger().Info("dump file failed", zap.String("category", "plan-replayer-dump"), + zap.String("sql-digest", task.SQLDigest), + zap.String("plan-digest", task.PlanDigest), + zap.Strings("sql", sqls), + zap.Bool("isContinues", task.IsContinuesCapture)) + } else { + logutil.BgLogger().Info("start to dump plan replayer result", zap.String("category", "plan-replayer-dump"), + zap.Strings("sqls", sqls)) + } + errMsg = err.Error() + domain_metrics.PlanReplayerDumpTaskFailed.Inc() + } else { + domain_metrics.PlanReplayerDumpTaskSuccess.Inc() + } + err1 := zw.Close() + if err1 != nil { + logutil.BgLogger().Error("Closing zip writer failed", zap.String("category", "plan-replayer-dump"), zap.Error(err), zap.String("filename", fileName)) + errMsg = errMsg + "," + err1.Error() + } + err2 := zf.Close() + if err2 != nil { + logutil.BgLogger().Error("Closing zip file failed", zap.String("category", "plan-replayer-dump"), zap.Error(err), zap.String("filename", fileName)) + errMsg = errMsg + "," + err2.Error() + } + if len(errMsg) > 0 { + for i, record := range records { + record.FailedReason = errMsg + records[i] = record + } + } + insertPlanReplayerStatus(ctx, sctx, records) + }() + // Dump SQLMeta + if err = dumpSQLMeta(zw, task); err != nil { + return err + } + + // Dump config + if err = dumpConfig(zw); err != nil { + return err + } + + // Dump meta + if err = dumpMeta(zw); err != nil { + return err + } + // Retrieve current DB + dbName := model.NewCIStr(sessionVars.CurrentDB) + do := GetDomain(sctx) + + // Retrieve all tables + pairs, err := extractTableNames(ctx, sctx, execStmts, dbName) + if err != nil { + return errors.AddStack(fmt.Errorf("plan replayer: invalid SQL text, err: %v", err)) + } + + // Dump Schema and View + if err = dumpSchemas(sctx, zw, pairs); err != nil { + return err + } + + // Dump tables tiflash replicas + if err = dumpTiFlashReplica(sctx, zw, pairs); err != nil { + return err + } + + // For continuous capture task, we dump stats in storage only if EnableHistoricalStatsForCapture is disabled. 
+ // For manual plan replayer dump command or capture, we directly dump stats in storage + if task.IsCapture && task.IsContinuesCapture { + if !variable.EnableHistoricalStatsForCapture.Load() { + // Dump stats + fallbackMsg, err := dumpStats(zw, pairs, do, 0) + if err != nil { + return err + } + if len(fallbackMsg) > 0 { + errMsgs = append(errMsgs, fallbackMsg) + } + } else { + failpoint.Inject("shouldDumpStats", func(val failpoint.Value) { + if val.(bool) { + panic("shouldDumpStats") + } + }) + } + } else { + // Dump stats + fallbackMsg, err := dumpStats(zw, pairs, do, task.HistoricalStatsTS) + if err != nil { + return err + } + if len(fallbackMsg) > 0 { + errMsgs = append(errMsgs, fallbackMsg) + } + } + + if err = dumpStatsMemStatus(zw, pairs, do); err != nil { + return err + } + + // Dump variables + if err = dumpVariables(sctx, sessionVars, zw); err != nil { + return err + } + + // Dump sql + if err = dumpSQLs(execStmts, zw); err != nil { + return err + } + + // Dump session bindings + if len(task.SessionBindings) > 0 { + if err = dumpSessionBindRecords(task.SessionBindings, zw); err != nil { + return err + } + } else { + if err = dumpSessionBindings(sctx, zw); err != nil { + return err + } + } + + // Dump global bindings + if err = dumpGlobalBindings(sctx, zw); err != nil { + return err + } + + if len(task.EncodedPlan) > 0 { + records = generateRecords(task) + if err = dumpEncodedPlan(sctx, zw, task.EncodedPlan); err != nil { + return err + } + } else { + // Dump explain + if err = dumpPlanReplayerExplain(sctx, zw, task, &records); err != nil { + return err + } + } + + if task.DebugTrace != nil { + if err = dumpDebugTrace(zw, task.DebugTrace); err != nil { + return err + } + } + + if len(errMsgs) > 0 { + if err = dumpErrorMsgs(zw, errMsgs); err != nil { + return err + } + } + return nil +} + +func generateRecords(task *PlanReplayerDumpTask) []PlanReplayerStatusRecord { + records := make([]PlanReplayerStatusRecord, 0) + if len(task.ExecStmts) > 0 { + for _, execStmt := range task.ExecStmts { + records = append(records, PlanReplayerStatusRecord{ + SQLDigest: task.SQLDigest, + PlanDigest: task.PlanDigest, + OriginSQL: execStmt.Text(), + Token: task.FileName, + }) + } + } + return records +} + +func dumpSQLMeta(zw *zip.Writer, task *PlanReplayerDumpTask) error { + cf, err := zw.Create(PlanReplayerSQLMetaFile) + if err != nil { + return errors.AddStack(err) + } + varMap := make(map[string]string) + varMap[PlanReplayerSQLMetaStartTS] = strconv.FormatUint(task.StartTS, 10) + varMap[PlanReplayerTaskMetaIsCapture] = strconv.FormatBool(task.IsCapture) + varMap[PlanReplayerTaskMetaIsContinues] = strconv.FormatBool(task.IsContinuesCapture) + varMap[PlanReplayerTaskMetaSQLDigest] = task.SQLDigest + varMap[PlanReplayerTaskMetaPlanDigest] = task.PlanDigest + varMap[PlanReplayerTaskEnableHistoricalStats] = strconv.FormatBool(variable.EnableHistoricalStatsForCapture.Load()) + if task.HistoricalStatsTS > 0 { + varMap[PlanReplayerHistoricalStatsTS] = strconv.FormatUint(task.HistoricalStatsTS, 10) + } + if err := toml.NewEncoder(cf).Encode(varMap); err != nil { + return errors.AddStack(err) + } + return nil +} + +func dumpConfig(zw *zip.Writer) error { + cf, err := zw.Create(PlanReplayerConfigFile) + if err != nil { + return errors.AddStack(err) + } + if err := toml.NewEncoder(cf).Encode(config.GetGlobalConfig()); err != nil { + return errors.AddStack(err) + } + return nil +} + +func dumpMeta(zw *zip.Writer) error { + mt, err := zw.Create(PlanReplayerMetaFile) + if err != nil { + return 
errors.AddStack(err) + } + _, err = mt.Write([]byte(printer.GetTiDBInfo())) + if err != nil { + return errors.AddStack(err) + } + return nil +} + +func dumpTiFlashReplica(ctx sessionctx.Context, zw *zip.Writer, pairs map[tableNamePair]struct{}) error { + bf, err := zw.Create(PlanReplayerTiFlashReplicasFile) + if err != nil { + return errors.AddStack(err) + } + is := GetDomain(ctx).InfoSchema() + for pair := range pairs { + dbName := model.NewCIStr(pair.DBName) + tableName := model.NewCIStr(pair.TableName) + t, err := is.TableByName(context.Background(), dbName, tableName) + if err != nil { + logutil.BgLogger().Warn("failed to find table info", zap.Error(err), + zap.String("dbName", dbName.L), zap.String("tableName", tableName.L)) + continue + } + if t.Meta().TiFlashReplica != nil && t.Meta().TiFlashReplica.Count > 0 { + row := []string{ + pair.DBName, pair.TableName, strconv.FormatUint(t.Meta().TiFlashReplica.Count, 10), + } + fmt.Fprintf(bf, "%s\n", strings.Join(row, "\t")) + } + } + return nil +} + +func dumpSchemas(ctx sessionctx.Context, zw *zip.Writer, pairs map[tableNamePair]struct{}) error { + tables := make(map[tableNamePair]struct{}) + for pair := range pairs { + err := getShowCreateTable(pair, zw, ctx) + if err != nil { + return err + } + if !pair.IsView { + tables[pair] = struct{}{} + } + } + return dumpSchemaMeta(zw, tables) +} + +func dumpSchemaMeta(zw *zip.Writer, tables map[tableNamePair]struct{}) error { + zf, err := zw.Create(fmt.Sprintf("schema/%v", PlanReplayerSchemaMetaFile)) + if err != nil { + return err + } + for table := range tables { + _, err := fmt.Fprintf(zf, "%s.%s;", table.DBName, table.TableName) + if err != nil { + return err + } + } + return nil +} + +func dumpStatsMemStatus(zw *zip.Writer, pairs map[tableNamePair]struct{}, do *Domain) error { + statsHandle := do.StatsHandle() + is := do.InfoSchema() + for pair := range pairs { + if pair.IsView { + continue + } + tbl, err := is.TableByName(context.Background(), model.NewCIStr(pair.DBName), model.NewCIStr(pair.TableName)) + if err != nil { + return err + } + tblStats := statsHandle.GetTableStats(tbl.Meta()) + if tblStats == nil { + continue + } + statsMemFw, err := zw.Create(fmt.Sprintf("statsMem/%v.%v.txt", pair.DBName, pair.TableName)) + if err != nil { + return errors.AddStack(err) + } + fmt.Fprintf(statsMemFw, "[INDEX]\n") + tblStats.ForEachIndexImmutable(func(_ int64, idx *statistics.Index) bool { + fmt.Fprintf(statsMemFw, "%s\n", fmt.Sprintf("%s=%s", idx.Info.Name.String(), idx.StatusToString())) + return false + }) + fmt.Fprintf(statsMemFw, "[COLUMN]\n") + tblStats.ForEachColumnImmutable(func(_ int64, c *statistics.Column) bool { + fmt.Fprintf(statsMemFw, "%s\n", fmt.Sprintf("%s=%s", c.Info.Name.String(), c.StatusToString())) + return false + }) + } + return nil +} + +func dumpStats(zw *zip.Writer, pairs map[tableNamePair]struct{}, do *Domain, historyStatsTS uint64) (string, error) { + allFallBackTbls := make([]string, 0) + for pair := range pairs { + if pair.IsView { + continue + } + jsonTbl, fallBackTbls, err := getStatsForTable(do, pair, historyStatsTS) + if err != nil { + return "", err + } + statsFw, err := zw.Create(fmt.Sprintf("stats/%v.%v.json", pair.DBName, pair.TableName)) + if err != nil { + return "", errors.AddStack(err) + } + data, err := json.Marshal(jsonTbl) + if err != nil { + return "", errors.AddStack(err) + } + _, err = statsFw.Write(data) + if err != nil { + return "", errors.AddStack(err) + } + allFallBackTbls = append(allFallBackTbls, fallBackTbls...) 
+ } + var msg string + if len(allFallBackTbls) > 0 { + msg = "Historical stats for " + strings.Join(allFallBackTbls, ", ") + " are unavailable, fallback to latest stats" + } + return msg, nil +} + +func dumpSQLs(execStmts []ast.StmtNode, zw *zip.Writer) error { + for i, stmtExec := range execStmts { + zf, err := zw.Create(fmt.Sprintf("sql/sql%v.sql", i)) + if err != nil { + return err + } + _, err = zf.Write([]byte(stmtExec.Text())) + if err != nil { + return err + } + } + return nil +} + +func dumpVariables(sctx sessionctx.Context, sessionVars *variable.SessionVars, zw *zip.Writer) error { + varMap := make(map[string]string) + for _, v := range variable.GetSysVars() { + if v.IsNoop && !variable.EnableNoopVariables.Load() { + continue + } + if infoschema.SysVarHiddenForSem(sctx, v.Name) { + continue + } + value, err := sessionVars.GetSessionOrGlobalSystemVar(context.Background(), v.Name) + if err != nil { + return errors.Trace(err) + } + varMap[v.Name] = value + } + vf, err := zw.Create(PlanReplayerVariablesFile) + if err != nil { + return errors.AddStack(err) + } + if err := toml.NewEncoder(vf).Encode(varMap); err != nil { + return errors.AddStack(err) + } + return nil +} + +func dumpSessionBindRecords(records []bindinfo.Bindings, zw *zip.Writer) error { + sRows := make([][]string, 0) + for _, bindData := range records { + for _, hint := range bindData { + sRows = append(sRows, []string{ + hint.OriginalSQL, + hint.BindSQL, + hint.Db, + hint.Status, + hint.CreateTime.String(), + hint.UpdateTime.String(), + hint.Charset, + hint.Collation, + hint.Source, + }) + } + } + bf, err := zw.Create(PlanReplayerSessionBindingFile) + if err != nil { + return errors.AddStack(err) + } + for _, row := range sRows { + fmt.Fprintf(bf, "%s\n", strings.Join(row, "\t")) + } + return nil +} + +func dumpSessionBindings(ctx sessionctx.Context, zw *zip.Writer) error { + recordSets, err := ctx.GetSQLExecutor().Execute(context.Background(), "show bindings") + if err != nil { + return err + } + sRows, err := resultSetToStringSlice(context.Background(), recordSets[0], true) + if err != nil { + return err + } + bf, err := zw.Create(PlanReplayerSessionBindingFile) + if err != nil { + return errors.AddStack(err) + } + for _, row := range sRows { + fmt.Fprintf(bf, "%s\n", strings.Join(row, "\t")) + } + if len(recordSets) > 0 { + if err := recordSets[0].Close(); err != nil { + return err + } + } + return nil +} + +func dumpGlobalBindings(ctx sessionctx.Context, zw *zip.Writer) error { + recordSets, err := ctx.GetSQLExecutor().Execute(context.Background(), "show global bindings") + if err != nil { + return err + } + sRows, err := resultSetToStringSlice(context.Background(), recordSets[0], false) + if err != nil { + return err + } + bf, err := zw.Create(PlanReplayerGlobalBindingFile) + if err != nil { + return errors.AddStack(err) + } + for _, row := range sRows { + fmt.Fprintf(bf, "%s\n", strings.Join(row, "\t")) + } + if len(recordSets) > 0 { + if err := recordSets[0].Close(); err != nil { + return err + } + } + return nil +} + +func dumpEncodedPlan(ctx sessionctx.Context, zw *zip.Writer, encodedPlan string) error { + var recordSets []sqlexec.RecordSet + var err error + recordSets, err = ctx.GetSQLExecutor().Execute(context.Background(), fmt.Sprintf("select tidb_decode_plan('%s')", encodedPlan)) + if err != nil { + return err + } + sRows, err := resultSetToStringSlice(context.Background(), recordSets[0], false) + if err != nil { + return err + } + fw, err := zw.Create("explain/sql.txt") + if err != nil { + return 
errors.AddStack(err) + } + for _, row := range sRows { + fmt.Fprintf(fw, "%s\n", strings.Join(row, "\t")) + } + if len(recordSets) > 0 { + if err := recordSets[0].Close(); err != nil { + return err + } + } + return nil +} + +func dumpExplain(ctx sessionctx.Context, zw *zip.Writer, isAnalyze bool, sqls []string, emptyAsNil bool) (debugTraces []any, err error) { + fw, err := zw.Create("explain.txt") + if err != nil { + return nil, errors.AddStack(err) + } + ctx.GetSessionVars().InPlanReplayer = true + defer func() { + ctx.GetSessionVars().InPlanReplayer = false + }() + for i, sql := range sqls { + var recordSets []sqlexec.RecordSet + if isAnalyze { + // Explain analyze + recordSets, err = ctx.GetSQLExecutor().Execute(context.Background(), fmt.Sprintf("explain analyze %s", sql)) + if err != nil { + return nil, err + } + } else { + // Explain + recordSets, err = ctx.GetSQLExecutor().Execute(context.Background(), fmt.Sprintf("explain %s", sql)) + if err != nil { + return nil, err + } + } + debugTrace := ctx.GetSessionVars().StmtCtx.OptimizerDebugTrace + debugTraces = append(debugTraces, debugTrace) + sRows, err := resultSetToStringSlice(context.Background(), recordSets[0], emptyAsNil) + if err != nil { + return nil, err + } + for _, row := range sRows { + fmt.Fprintf(fw, "%s\n", strings.Join(row, "\t")) + } + if len(recordSets) > 0 { + if err := recordSets[0].Close(); err != nil { + return nil, err + } + } + if i < len(sqls)-1 { + fmt.Fprintf(fw, "<--------->\n") + } + } + return +} + +func dumpPlanReplayerExplain(ctx sessionctx.Context, zw *zip.Writer, task *PlanReplayerDumpTask, records *[]PlanReplayerStatusRecord) error { + sqls := make([]string, 0) + for _, execStmt := range task.ExecStmts { + sql := execStmt.Text() + sqls = append(sqls, sql) + *records = append(*records, PlanReplayerStatusRecord{ + OriginSQL: sql, + Token: task.FileName, + }) + } + debugTraces, err := dumpExplain(ctx, zw, task.Analyze, sqls, false) + task.DebugTrace = debugTraces + return err +} + +// extractTableNames extracts table names from the given stmts. 
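
Most of the dump helpers above share one shape: create a named entry in the plan replayer zip writer, then stream either TOML or tab-separated rows into it. A minimal sketch of that shape, writing to a throwaway zip file; the entry names and values below are illustrative only:

package main

import (
	"archive/zip"
	"fmt"
	"log"
	"os"

	"github.com/BurntSushi/toml"
)

func main() {
	// Illustrative output path; the real dump writes into the plan replayer's zip file.
	f, err := os.Create("replayer_sketch.zip")
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	zw := zip.NewWriter(f)

	// TOML-style entry, the pattern used by dumpSQLMeta/dumpConfig/dumpVariables.
	metaWriter, err := zw.Create("sql_meta.toml")
	if err != nil {
		log.Fatal(err)
	}
	meta := map[string]string{"startTS": "449999999", "isCapture": "false"} // made-up values
	if err := toml.NewEncoder(metaWriter).Encode(meta); err != nil {
		log.Fatal(err)
	}

	// Tab-separated text entry, the pattern used by the binding and replica dumps.
	bindWriter, err := zw.Create("session_bindings.sql")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Fprintf(bindWriter, "%s\t%s\n", "select * from t", "select /*+ use_index(t, idx) */ * from t")

	// Entries are only flushed when the zip writer is closed.
	if err := zw.Close(); err != nil {
		log.Fatal(err)
	}
}
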
+func extractTableNames(ctx context.Context, sctx sessionctx.Context, + execStmts []ast.StmtNode, curDB model.CIStr) (map[tableNamePair]struct{}, error) { + tableExtractor := &tableNameExtractor{ + ctx: ctx, + executor: sctx.GetRestrictedSQLExecutor(), + is: GetDomain(sctx).InfoSchema(), + curDB: curDB, + names: make(map[tableNamePair]struct{}), + cteNames: make(map[string]struct{}), + } + for _, execStmt := range execStmts { + execStmt.Accept(tableExtractor) + } + if tableExtractor.err != nil { + return nil, tableExtractor.err + } + return tableExtractor.getTablesAndViews(), nil +} + +func getStatsForTable(do *Domain, pair tableNamePair, historyStatsTS uint64) (*util.JSONTable, []string, error) { + is := do.InfoSchema() + h := do.StatsHandle() + tbl, err := is.TableByName(context.Background(), model.NewCIStr(pair.DBName), model.NewCIStr(pair.TableName)) + if err != nil { + return nil, nil, err + } + if historyStatsTS > 0 { + return h.DumpHistoricalStatsBySnapshot(pair.DBName, tbl.Meta(), historyStatsTS) + } + jt, err := h.DumpStatsToJSON(pair.DBName, tbl.Meta(), nil, true) + return jt, nil, err +} + +func getShowCreateTable(pair tableNamePair, zw *zip.Writer, ctx sessionctx.Context) error { + recordSets, err := ctx.GetSQLExecutor().Execute(context.Background(), fmt.Sprintf("show create table `%v`.`%v`", pair.DBName, pair.TableName)) + if err != nil { + return err + } + sRows, err := resultSetToStringSlice(context.Background(), recordSets[0], false) + if err != nil { + return err + } + var fw io.Writer + if pair.IsView { + fw, err = zw.Create(fmt.Sprintf("view/%v.%v.view.txt", pair.DBName, pair.TableName)) + if err != nil { + return errors.AddStack(err) + } + if len(sRows) == 0 || len(sRows[0]) != 4 { + return fmt.Errorf("plan replayer: get create view %v.%v failed", pair.DBName, pair.TableName) + } + } else { + fw, err = zw.Create(fmt.Sprintf("schema/%v.%v.schema.txt", pair.DBName, pair.TableName)) + if err != nil { + return errors.AddStack(err) + } + if len(sRows) == 0 || len(sRows[0]) != 2 { + return fmt.Errorf("plan replayer: get create table %v.%v failed", pair.DBName, pair.TableName) + } + } + fmt.Fprintf(fw, "create database if not exists `%v`; use `%v`;", pair.DBName, pair.DBName) + fmt.Fprintf(fw, "%s", sRows[0][1]) + if len(recordSets) > 0 { + if err := recordSets[0].Close(); err != nil { + return err + } + } + return nil +} + +func resultSetToStringSlice(ctx context.Context, rs sqlexec.RecordSet, emptyAsNil bool) ([][]string, error) { + rows, err := getRows(ctx, rs) + if err != nil { + return nil, err + } + err = rs.Close() + if err != nil { + return nil, err + } + sRows := make([][]string, len(rows)) + for i, row := range rows { + iRow := make([]string, row.Len()) + for j := 0; j < row.Len(); j++ { + if row.IsNull(j) { + iRow[j] = "" + } else { + d := row.GetDatum(j, &rs.Fields()[j].Column.FieldType) + iRow[j], err = d.ToString() + if err != nil { + return nil, err + } + if len(iRow[j]) < 1 && emptyAsNil { + iRow[j] = "" + } + } + } + sRows[i] = iRow + } + return sRows, nil +} + +func getRows(ctx context.Context, rs sqlexec.RecordSet) ([]chunk.Row, error) { + if rs == nil { + return nil, nil + } + var rows []chunk.Row + req := rs.NewChunk(nil) + // Must reuse `req` for imitating server.(*clientConn).writeChunks + for { + err := rs.Next(ctx, req) + if err != nil { + return nil, err + } + if req.NumRows() == 0 { + break + } + + iter := chunk.NewIterator4Chunk(req.CopyConstruct()) + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + rows = append(rows, row) + } + } 
+ return rows, nil +} + +func dumpDebugTrace(zw *zip.Writer, debugTraces []any) error { + for i, trace := range debugTraces { + fw, err := zw.Create(fmt.Sprintf("debug_trace/debug_trace%d.json", i)) + if err != nil { + return errors.AddStack(err) + } + err = dumpOneDebugTrace(fw, trace) + if err != nil { + return errors.AddStack(err) + } + } + return nil +} + +func dumpOneDebugTrace(w io.Writer, debugTrace any) error { + jsonEncoder := json.NewEncoder(w) + // If we do not set this to false, ">", "<", "&"... will be escaped to "\u003c","\u003e", "\u0026"... + jsonEncoder.SetEscapeHTML(false) + return jsonEncoder.Encode(debugTrace) +} + +func dumpErrorMsgs(zw *zip.Writer, msgs []string) error { + mt, err := zw.Create(PlanReplayerErrorMessageFile) + if err != nil { + return errors.AddStack(err) + } + for _, msg := range msgs { + _, err = mt.Write([]byte(msg)) + if err != nil { + return errors.AddStack(err) + } + _, err = mt.Write([]byte{'\n'}) + if err != nil { + return errors.AddStack(err) + } + } + return nil +} diff --git a/pkg/domain/runaway.go b/pkg/domain/runaway.go index ae5d50724a047..f6dfa79d96049 100644 --- a/pkg/domain/runaway.go +++ b/pkg/domain/runaway.go @@ -64,9 +64,9 @@ func (do *Domain) deleteExpiredRows(tableName, colName string, expiredDuration t if !do.DDL().OwnerManager().IsOwner() { return } - failpoint.Inject("FastRunawayGC", func() { + if _, _err_ := failpoint.Eval(_curpkg_("FastRunawayGC")); _err_ == nil { expiredDuration = time.Second * 1 - }) + } expiredTime := time.Now().Add(-expiredDuration) tbCIStr := model.NewCIStr(tableName) tbl, err := do.InfoSchema().TableByName(context.Background(), systemSchemaCIStr, tbCIStr) @@ -244,12 +244,12 @@ func (do *Domain) runawayRecordFlushLoop() { // we can guarantee a watch record can be seen by the user within 1s. runawayRecordFlushTimer := time.NewTimer(runawayRecordFlushInterval) runawayRecordGCTicker := time.NewTicker(runawayRecordGCInterval) - failpoint.Inject("FastRunawayGC", func() { + if _, _err_ := failpoint.Eval(_curpkg_("FastRunawayGC")); _err_ == nil { runawayRecordFlushTimer.Stop() runawayRecordGCTicker.Stop() runawayRecordFlushTimer = time.NewTimer(time.Millisecond * 50) runawayRecordGCTicker = time.NewTicker(time.Millisecond * 200) - }) + } fired := false recordCh := do.runawayManager.RunawayRecordChan() @@ -278,9 +278,9 @@ func (do *Domain) runawayRecordFlushLoop() { fired = true case r := <-recordCh: records = append(records, r) - failpoint.Inject("FastRunawayGC", func() { + if _, _err_ := failpoint.Eval(_curpkg_("FastRunawayGC")); _err_ == nil { flushRunawayRecords() - }) + } if len(records) >= flushThreshold { flushRunawayRecords() } else if fired { diff --git a/pkg/domain/runaway.go__failpoint_stash__ b/pkg/domain/runaway.go__failpoint_stash__ new file mode 100644 index 0000000000000..ae5d50724a047 --- /dev/null +++ b/pkg/domain/runaway.go__failpoint_stash__ @@ -0,0 +1,659 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
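
The runaway.go hunks above show the failpoint rewrite: each failpoint.Inject marker becomes an explicit failpoint.Eval call keyed by the package path (via the generated _curpkg_ helper), and the original source is kept in a __failpoint_stash__ copy. A minimal sketch of how such a failpoint behaves at runtime, assuming an illustrative failpath string and the FastRunawayGC-style use of shrinking an interval:

package main

import (
	"fmt"

	"github.com/pingcap/failpoint"
)

// gcExpirySeconds mirrors the FastRunawayGC pattern from the diff above: the
// rewritten failpoint call only fires while the named failpoint is enabled.
// The failpath string is illustrative; the generated code derives it from the
// package import path via _curpkg_.
func gcExpirySeconds() int {
	d := 3600
	if _, err := failpoint.Eval("example/fastGC"); err == nil {
		d = 1
	}
	return d
}

func main() {
	fmt.Println(gcExpirySeconds()) // 3600: failpoint not enabled, Eval returns an error

	// Tests typically enable the failpoint by name with a term such as "return(true)".
	if err := failpoint.Enable("example/fastGC", "return(true)"); err != nil {
		fmt.Println("enable failpoint:", err)
		return
	}
	defer func() { _ = failpoint.Disable("example/fastGC") }()

	fmt.Println(gcExpirySeconds()) // 1: failpoint enabled, Eval succeeds
}
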
+ +package domain + +import ( + "context" + "net" + "strconv" + "strings" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + rmpb "github.com/pingcap/kvproto/pkg/resource_manager" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/domain/resourcegroup" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/ttl/cache" + "github.com/pingcap/tidb/pkg/ttl/sqlbuilder" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/sqlexec" + "github.com/tikv/client-go/v2/tikv" + pd "github.com/tikv/pd/client" + rmclient "github.com/tikv/pd/client/resource_group/controller" + "go.uber.org/zap" +) + +const ( + runawayRecordFlushInterval = time.Second + runawayRecordGCInterval = time.Hour * 24 + runawayRecordExpiredDuration = time.Hour * 24 * 7 + runawayWatchSyncInterval = time.Second + + runawayRecordGCBatchSize = 100 + runawayRecordGCSelectBatchSize = runawayRecordGCBatchSize * 5 + + maxIDRetries = 3 + runawayLoopLogErrorIntervalCount = 1800 +) + +var systemSchemaCIStr = model.NewCIStr("mysql") + +func (do *Domain) deleteExpiredRows(tableName, colName string, expiredDuration time.Duration) { + if !do.DDL().OwnerManager().IsOwner() { + return + } + failpoint.Inject("FastRunawayGC", func() { + expiredDuration = time.Second * 1 + }) + expiredTime := time.Now().Add(-expiredDuration) + tbCIStr := model.NewCIStr(tableName) + tbl, err := do.InfoSchema().TableByName(context.Background(), systemSchemaCIStr, tbCIStr) + if err != nil { + logutil.BgLogger().Error("delete system table failed", zap.String("table", tableName), zap.Error(err)) + return + } + tbInfo := tbl.Meta() + col := tbInfo.FindPublicColumnByName(colName) + if col == nil { + logutil.BgLogger().Error("time column is not public in table", zap.String("table", tableName), zap.String("column", colName)) + return + } + tb, err := cache.NewBasePhysicalTable(systemSchemaCIStr, tbInfo, model.NewCIStr(""), col) + if err != nil { + logutil.BgLogger().Error("delete system table failed", zap.String("table", tableName), zap.Error(err)) + return + } + generator, err := sqlbuilder.NewScanQueryGenerator(tb, expiredTime, nil, nil) + if err != nil { + logutil.BgLogger().Error("delete system table failed", zap.String("table", tableName), zap.Error(err)) + return + } + var leftRows [][]types.Datum + for { + sql := "" + if sql, err = generator.NextSQL(leftRows, runawayRecordGCSelectBatchSize); err != nil { + logutil.BgLogger().Error("delete system table failed", zap.String("table", tableName), zap.Error(err)) + return + } + // to remove + if len(sql) == 0 { + return + } + + rows, sqlErr := execRestrictedSQL(do.sysSessionPool, sql, nil) + if sqlErr != nil { + logutil.BgLogger().Error("delete system table failed", zap.String("table", tableName), zap.Error(err)) + return + } + leftRows = make([][]types.Datum, len(rows)) + for i, row := range rows { + leftRows[i] = row.GetDatumRow(tb.KeyColumnTypes) + } + + for len(leftRows) > 0 { + var delBatch [][]types.Datum + if len(leftRows) < runawayRecordGCBatchSize { + delBatch = leftRows + leftRows = nil + } else { + delBatch = leftRows[0:runawayRecordGCBatchSize] + leftRows = leftRows[runawayRecordGCBatchSize:] + } + sql, err := 
sqlbuilder.BuildDeleteSQL(tb, delBatch, expiredTime) + if err != nil { + logutil.BgLogger().Error( + "build delete SQL failed when deleting system table", + zap.Error(err), + zap.String("table", tb.Schema.O+"."+tb.Name.O), + ) + return + } + + _, err = execRestrictedSQL(do.sysSessionPool, sql, nil) + if err != nil { + logutil.BgLogger().Error( + "delete SQL failed when deleting system table", zap.Error(err), zap.String("SQL", sql), + ) + } + } + } +} + +func (do *Domain) runawayStartLoop() { + defer util.Recover(metrics.LabelDomain, "runawayStartLoop", nil, false) + runawayWatchSyncTicker := time.NewTicker(runawayWatchSyncInterval) + count := 0 + var err error + logutil.BgLogger().Info("try to start runaway manager loop") + for { + select { + case <-do.exit: + return + case <-runawayWatchSyncTicker.C: + // Due to the watch and watch done tables is created later than runaway queries table + err = do.updateNewAndDoneWatch() + if err == nil { + logutil.BgLogger().Info("preparations for the runaway manager are finished and start runaway manager loop") + do.wg.Run(do.runawayRecordFlushLoop, "runawayRecordFlushLoop") + do.wg.Run(do.runawayWatchSyncLoop, "runawayWatchSyncLoop") + do.runawayManager.MarkSyncerInitialized() + return + } + } + if count %= runawayLoopLogErrorIntervalCount; count == 0 { + logutil.BgLogger().Warn( + "failed to start runaway manager loop, please check whether the bootstrap or update is finished", + zap.Error(err)) + } + count++ + } +} + +func (do *Domain) updateNewAndDoneWatch() error { + do.runawaySyncer.mu.Lock() + defer do.runawaySyncer.mu.Unlock() + records, err := do.runawaySyncer.getNewWatchRecords() + if err != nil { + return err + } + for _, r := range records { + do.runawayManager.AddWatch(r) + } + doneRecords, err := do.runawaySyncer.getNewWatchDoneRecords() + if err != nil { + return err + } + for _, r := range doneRecords { + do.runawayManager.RemoveWatch(r) + } + return nil +} + +func (do *Domain) runawayWatchSyncLoop() { + defer util.Recover(metrics.LabelDomain, "runawayWatchSyncLoop", nil, false) + runawayWatchSyncTicker := time.NewTicker(runawayWatchSyncInterval) + count := 0 + for { + select { + case <-do.exit: + return + case <-runawayWatchSyncTicker.C: + err := do.updateNewAndDoneWatch() + if err != nil { + if count %= runawayLoopLogErrorIntervalCount; count == 0 { + logutil.BgLogger().Warn("get runaway watch record failed", zap.Error(err)) + } + count++ + } + } + } +} + +// GetRunawayWatchList is used to get all items from runaway watch list. +func (do *Domain) GetRunawayWatchList() []*resourcegroup.QuarantineRecord { + return do.runawayManager.GetWatchList() +} + +// TryToUpdateRunawayWatch is used to update watch list including +// creation and deletion by manual trigger. +func (do *Domain) TryToUpdateRunawayWatch() error { + return do.updateNewAndDoneWatch() +} + +// RemoveRunawayWatch is used to remove runaway watch item manually. 
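
deleteExpiredRows above scans expired rows in large chunks and then deletes them in smaller fixed-size batches so that each DELETE statement stays bounded. A standalone sketch of that slicing, with a hypothetical row type and delete callback; the batch size of 100 matches runawayRecordGCBatchSize:

package main

import "fmt"

// deleteInBatches mirrors the inner loop of deleteExpiredRows: rows fetched by
// one scan are deleted in fixed-size batches. The row type and the delete
// callback are hypothetical.
func deleteInBatches(rows []int64, batchSize int, deleteBatch func([]int64) error) error {
	for len(rows) > 0 {
		var batch []int64
		if len(rows) < batchSize {
			batch = rows
			rows = nil
		} else {
			batch = rows[:batchSize]
			rows = rows[batchSize:]
		}
		if err := deleteBatch(batch); err != nil {
			return err
		}
	}
	return nil
}

func main() {
	rows := make([]int64, 0, 250)
	for i := int64(0); i < 250; i++ {
		rows = append(rows, i)
	}
	_ = deleteInBatches(rows, 100, func(batch []int64) error {
		fmt.Println("deleting", len(batch), "rows") // prints 100, 100, 50
		return nil
	})
}
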
+func (do *Domain) RemoveRunawayWatch(recordID int64) error { + do.runawaySyncer.mu.Lock() + defer do.runawaySyncer.mu.Unlock() + records, err := do.runawaySyncer.getWatchRecordByID(recordID) + if err != nil { + return err + } + if len(records) != 1 { + return errors.Errorf("no runaway watch with the specific ID") + } + err = do.handleRunawayWatchDone(records[0]) + return err +} + +func (do *Domain) runawayRecordFlushLoop() { + defer util.Recover(metrics.LabelDomain, "runawayRecordFlushLoop", nil, false) + + // this times is used to batch flushing records, with 1s duration, + // we can guarantee a watch record can be seen by the user within 1s. + runawayRecordFlushTimer := time.NewTimer(runawayRecordFlushInterval) + runawayRecordGCTicker := time.NewTicker(runawayRecordGCInterval) + failpoint.Inject("FastRunawayGC", func() { + runawayRecordFlushTimer.Stop() + runawayRecordGCTicker.Stop() + runawayRecordFlushTimer = time.NewTimer(time.Millisecond * 50) + runawayRecordGCTicker = time.NewTicker(time.Millisecond * 200) + }) + + fired := false + recordCh := do.runawayManager.RunawayRecordChan() + quarantineRecordCh := do.runawayManager.QuarantineRecordChan() + staleQuarantineRecordCh := do.runawayManager.StaleQuarantineRecordChan() + flushThreshold := do.runawayManager.FlushThreshold() + records := make([]*resourcegroup.RunawayRecord, 0, flushThreshold) + + flushRunawayRecords := func() { + if len(records) == 0 { + return + } + sql, params := resourcegroup.GenRunawayQueriesStmt(records) + if _, err := execRestrictedSQL(do.sysSessionPool, sql, params); err != nil { + logutil.BgLogger().Error("flush runaway records failed", zap.Error(err), zap.Int("count", len(records))) + } + records = records[:0] + } + + for { + select { + case <-do.exit: + return + case <-runawayRecordFlushTimer.C: + flushRunawayRecords() + fired = true + case r := <-recordCh: + records = append(records, r) + failpoint.Inject("FastRunawayGC", func() { + flushRunawayRecords() + }) + if len(records) >= flushThreshold { + flushRunawayRecords() + } else if fired { + fired = false + // meet a new record, reset the timer. + runawayRecordFlushTimer.Reset(runawayRecordFlushInterval) + } + case <-runawayRecordGCTicker.C: + go do.deleteExpiredRows("tidb_runaway_queries", "time", runawayRecordExpiredDuration) + case r := <-quarantineRecordCh: + go func() { + _, err := do.AddRunawayWatch(r) + if err != nil { + logutil.BgLogger().Error("add runaway watch", zap.Error(err)) + } + }() + case r := <-staleQuarantineRecordCh: + go func() { + for i := 0; i < 3; i++ { + err := do.handleRemoveStaleRunawayWatch(r) + if err == nil { + break + } + logutil.BgLogger().Error("remove stale runaway watch", zap.Error(err)) + time.Sleep(time.Second) + } + }() + } + } +} + +// AddRunawayWatch is used to add runaway watch item manually. 
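
runawayRecordFlushLoop above accumulates records from a channel and flushes them either when the batch reaches the flush threshold or when the flush timer fires. A standalone sketch of that batch-or-timeout shape; the record type, sizes, and flush action are illustrative only:

package main

import (
	"fmt"
	"time"
)

// flushLoop batches strings from recordCh and flushes them either when the
// batch reaches threshold or when interval elapses, the same shape as
// runawayRecordFlushLoop above.
func flushLoop(recordCh <-chan string, done <-chan struct{}, threshold int, interval time.Duration) {
	timer := time.NewTimer(interval)
	defer timer.Stop()
	records := make([]string, 0, threshold)

	flush := func() {
		if len(records) == 0 {
			return
		}
		fmt.Println("flushing", len(records), "records")
		records = records[:0]
	}

	for {
		select {
		case <-done:
			flush()
			return
		case <-timer.C:
			flush()
			timer.Reset(interval)
		case r := <-recordCh:
			records = append(records, r)
			if len(records) >= threshold {
				flush()
			}
		}
	}
}

func main() {
	recordCh := make(chan string)
	done := make(chan struct{})
	go flushLoop(recordCh, done, 3, 50*time.Millisecond)

	for i := 0; i < 5; i++ {
		recordCh <- fmt.Sprintf("record-%d", i)
	}
	time.Sleep(100 * time.Millisecond) // let the interval flush pick up the remainder
	close(done)
	time.Sleep(10 * time.Millisecond)
}
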
+func (do *Domain) AddRunawayWatch(record *resourcegroup.QuarantineRecord) (uint64, error) { + se, err := do.sysSessionPool.Get() + defer func() { + do.sysSessionPool.Put(se) + }() + if err != nil { + return 0, errors.Annotate(err, "get session failed") + } + exec := se.(sessionctx.Context).GetSQLExecutor() + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnOthers) + _, err = exec.ExecuteInternal(ctx, "BEGIN") + if err != nil { + return 0, errors.Trace(err) + } + defer func() { + if err != nil { + _, err1 := exec.ExecuteInternal(ctx, "ROLLBACK") + terror.Log(err1) + return + } + _, err = exec.ExecuteInternal(ctx, "COMMIT") + if err != nil { + return + } + }() + sql, params := record.GenInsertionStmt() + _, err = exec.ExecuteInternal(ctx, sql, params...) + if err != nil { + return 0, err + } + for retry := 0; retry < maxIDRetries; retry++ { + if retry > 0 { + select { + case <-do.exit: + return 0, err + case <-time.After(time.Millisecond * time.Duration(retry*100)): + logutil.BgLogger().Warn("failed to get last insert id when adding runaway watch", zap.Error(err)) + } + } + var rs sqlexec.RecordSet + rs, err = exec.ExecuteInternal(ctx, `SELECT LAST_INSERT_ID();`) + if err != nil { + continue + } + var rows []chunk.Row + rows, err = sqlexec.DrainRecordSet(ctx, rs, 1) + //nolint: errcheck + rs.Close() + if err != nil { + continue + } + if len(rows) != 1 { + err = errors.Errorf("unexpected result length: %d", len(rows)) + continue + } + return rows[0].GetUint64(0), nil + } + return 0, errors.Errorf("An error: %v occurred while getting the ID of the newly added watch record. Try querying information_schema.runaway_watches later", err) +} + +func (do *Domain) handleRunawayWatchDone(record *resourcegroup.QuarantineRecord) error { + se, err := do.sysSessionPool.Get() + defer func() { + do.sysSessionPool.Put(se) + }() + if err != nil { + return errors.Annotate(err, "get session failed") + } + sctx, _ := se.(sessionctx.Context) + exec := sctx.GetSQLExecutor() + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnOthers) + _, err = exec.ExecuteInternal(ctx, "BEGIN") + if err != nil { + return errors.Trace(err) + } + defer func() { + if err != nil { + _, err1 := exec.ExecuteInternal(ctx, "ROLLBACK") + terror.Log(err1) + return + } + _, err = exec.ExecuteInternal(ctx, "COMMIT") + if err != nil { + return + } + }() + sql, params := record.GenInsertionDoneStmt() + _, err = exec.ExecuteInternal(ctx, sql, params...) + if err != nil { + return err + } + sql, params = record.GenDeletionStmt() + _, err = exec.ExecuteInternal(ctx, sql, params...) + return err +} + +func (do *Domain) handleRemoveStaleRunawayWatch(record *resourcegroup.QuarantineRecord) error { + se, err := do.sysSessionPool.Get() + defer func() { + do.sysSessionPool.Put(se) + }() + if err != nil { + return errors.Annotate(err, "get session failed") + } + sctx, _ := se.(sessionctx.Context) + exec := sctx.GetSQLExecutor() + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnOthers) + _, err = exec.ExecuteInternal(ctx, "BEGIN") + if err != nil { + return errors.Trace(err) + } + defer func() { + if err != nil { + _, err1 := exec.ExecuteInternal(ctx, "ROLLBACK") + terror.Log(err1) + return + } + _, err = exec.ExecuteInternal(ctx, "COMMIT") + if err != nil { + return + } + }() + sql, params := record.GenDeletionStmt() + _, err = exec.ExecuteInternal(ctx, sql, params...) 
+ return err +} + +func execRestrictedSQL(sessPool util.SessionPool, sql string, params []any) ([]chunk.Row, error) { + se, err := sessPool.Get() + defer func() { + sessPool.Put(se) + }() + if err != nil { + return nil, errors.Annotate(err, "get session failed") + } + sctx := se.(sessionctx.Context) + exec := sctx.GetRestrictedSQLExecutor() + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnOthers) + r, _, err := exec.ExecRestrictedSQL(ctx, []sqlexec.OptionFuncAlias{sqlexec.ExecOptionUseCurSession}, + sql, params..., + ) + return r, err +} + +func (do *Domain) initResourceGroupsController(ctx context.Context, pdClient pd.Client, uniqueID uint64) error { + if pdClient == nil { + logutil.BgLogger().Warn("cannot setup up resource controller, not using tikv storage") + // return nil as unistore doesn't support it + return nil + } + + control, err := rmclient.NewResourceGroupController(ctx, uniqueID, pdClient, nil, rmclient.WithMaxWaitDuration(resourcegroup.MaxWaitDuration)) + if err != nil { + return err + } + control.Start(ctx) + serverInfo, err := infosync.GetServerInfo() + if err != nil { + return err + } + serverAddr := net.JoinHostPort(serverInfo.IP, strconv.Itoa(int(serverInfo.Port))) + do.runawayManager = resourcegroup.NewRunawayManager(control, serverAddr) + do.runawaySyncer = newRunawaySyncer(do.sysSessionPool) + do.resourceGroupsController = control + tikv.SetResourceControlInterceptor(control) + return nil +} + +type runawaySyncer struct { + newWatchReader *SystemTableReader + deletionWatchReader *SystemTableReader + sysSessionPool util.SessionPool + mu sync.Mutex +} + +func newRunawaySyncer(sysSessionPool util.SessionPool) *runawaySyncer { + return &runawaySyncer{ + sysSessionPool: sysSessionPool, + newWatchReader: &SystemTableReader{ + resourcegroup.RunawayWatchTableName, + "start_time", + resourcegroup.NullTime}, + deletionWatchReader: &SystemTableReader{resourcegroup.RunawayWatchDoneTableName, + "done_time", + resourcegroup.NullTime}, + } +} + +func (s *runawaySyncer) getWatchRecordByID(id int64) ([]*resourcegroup.QuarantineRecord, error) { + return s.getWatchRecord(s.newWatchReader, s.newWatchReader.genSelectByIDStmt(id), false) +} + +func (s *runawaySyncer) getNewWatchRecords() ([]*resourcegroup.QuarantineRecord, error) { + return s.getWatchRecord(s.newWatchReader, s.newWatchReader.genSelectStmt, true) +} + +func (s *runawaySyncer) getNewWatchDoneRecords() ([]*resourcegroup.QuarantineRecord, error) { + return s.getWatchDoneRecord(s.deletionWatchReader, s.deletionWatchReader.genSelectStmt, true) +} + +func (s *runawaySyncer) getWatchRecord(reader *SystemTableReader, sqlGenFn func() (string, []any), push bool) ([]*resourcegroup.QuarantineRecord, error) { + se, err := s.sysSessionPool.Get() + defer func() { + s.sysSessionPool.Put(se) + }() + if err != nil { + return nil, errors.Annotate(err, "get session failed") + } + sctx := se.(sessionctx.Context) + exec := sctx.GetRestrictedSQLExecutor() + return getRunawayWatchRecord(exec, reader, sqlGenFn, push) +} + +func (s *runawaySyncer) getWatchDoneRecord(reader *SystemTableReader, sqlGenFn func() (string, []any), push bool) ([]*resourcegroup.QuarantineRecord, error) { + se, err := s.sysSessionPool.Get() + defer func() { + s.sysSessionPool.Put(se) + }() + if err != nil { + return nil, errors.Annotate(err, "get session failed") + } + sctx := se.(sessionctx.Context) + exec := sctx.GetRestrictedSQLExecutor() + return getRunawayWatchDoneRecord(exec, reader, sqlGenFn, push) +} + +func getRunawayWatchRecord(exec 
sqlexec.RestrictedSQLExecutor, reader *SystemTableReader, sqlGenFn func() (string, []any), push bool) ([]*resourcegroup.QuarantineRecord, error) { + rs, err := reader.Read(exec, sqlGenFn) + if err != nil { + return nil, err + } + ret := make([]*resourcegroup.QuarantineRecord, 0, len(rs)) + now := time.Now().UTC() + for _, r := range rs { + startTime, err := r.GetTime(2).GoTime(time.UTC) + if err != nil { + continue + } + var endTime time.Time + if !r.IsNull(3) { + endTime, err = r.GetTime(3).GoTime(time.UTC) + if err != nil { + continue + } + } + qr := &resourcegroup.QuarantineRecord{ + ID: r.GetInt64(0), + ResourceGroupName: r.GetString(1), + StartTime: startTime, + EndTime: endTime, + Watch: rmpb.RunawayWatchType(r.GetInt64(4)), + WatchText: r.GetString(5), + Source: r.GetString(6), + Action: rmpb.RunawayAction(r.GetInt64(7)), + } + // If a TiDB write record slow, it will occur that the record which has earlier start time is inserted later than others. + // So we start the scan a little earlier. + if push { + reader.CheckPoint = now.Add(-3 * runawayWatchSyncInterval) + } + ret = append(ret, qr) + } + return ret, nil +} + +func getRunawayWatchDoneRecord(exec sqlexec.RestrictedSQLExecutor, reader *SystemTableReader, sqlGenFn func() (string, []any), push bool) ([]*resourcegroup.QuarantineRecord, error) { + rs, err := reader.Read(exec, sqlGenFn) + if err != nil { + return nil, err + } + length := len(rs) + ret := make([]*resourcegroup.QuarantineRecord, 0, length) + now := time.Now().UTC() + for _, r := range rs { + startTime, err := r.GetTime(3).GoTime(time.UTC) + if err != nil { + continue + } + var endTime time.Time + if !r.IsNull(4) { + endTime, err = r.GetTime(4).GoTime(time.UTC) + if err != nil { + continue + } + } + qr := &resourcegroup.QuarantineRecord{ + ID: r.GetInt64(1), + ResourceGroupName: r.GetString(2), + StartTime: startTime, + EndTime: endTime, + Watch: rmpb.RunawayWatchType(r.GetInt64(5)), + WatchText: r.GetString(6), + Source: r.GetString(7), + Action: rmpb.RunawayAction(r.GetInt64(8)), + } + // Ditto as getRunawayWatchRecord. + if push { + reader.CheckPoint = now.Add(-3 * runawayWatchSyncInterval) + } + ret = append(ret, qr) + } + return ret, nil +} + +// SystemTableReader is used to read table `runaway_watch` and `runaway_watch_done`. +type SystemTableReader struct { + TableName string + KeyCol string + CheckPoint time.Time +} + +func (r *SystemTableReader) genSelectByIDStmt(id int64) func() (string, []any) { + return func() (string, []any) { + var builder strings.Builder + params := make([]any, 0, 1) + builder.WriteString("select * from ") + builder.WriteString(r.TableName) + builder.WriteString(" where id = %?") + params = append(params, id) + return builder.String(), params + } +} + +func (r *SystemTableReader) genSelectStmt() (string, []any) { + var builder strings.Builder + params := make([]any, 0, 1) + builder.WriteString("select * from ") + builder.WriteString(r.TableName) + builder.WriteString(" where ") + builder.WriteString(r.KeyCol) + builder.WriteString(" > %? 
order by ") + builder.WriteString(r.KeyCol) + params = append(params, r.CheckPoint) + return builder.String(), params +} + +func (*SystemTableReader) Read(exec sqlexec.RestrictedSQLExecutor, genFn func() (string, []any)) ([]chunk.Row, error) { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnOthers) + sql, params := genFn() + rows, _, err := exec.ExecRestrictedSQL(ctx, []sqlexec.OptionFuncAlias{sqlexec.ExecOptionUseCurSession}, + sql, params..., + ) + return rows, err +} diff --git a/pkg/executor/adapter.go b/pkg/executor/adapter.go index 5ea9fdcafb178..2c1c5cd305918 100644 --- a/pkg/executor/adapter.go +++ b/pkg/executor/adapter.go @@ -299,12 +299,12 @@ func (a *ExecStmt) PointGet(ctx context.Context) (*recordSet, error) { r.Span.LogKV("sql", a.OriginText()) } - failpoint.Inject("assertTxnManagerInShortPointGetPlan", func() { + if _, _err_ := failpoint.Eval(_curpkg_("assertTxnManagerInShortPointGetPlan")); _err_ == nil { sessiontxn.RecordAssert(a.Ctx, "assertTxnManagerInShortPointGetPlan", true) // stale read should not reach here staleread.AssertStmtStaleness(a.Ctx, false) sessiontxn.AssertTxnManagerInfoSchema(a.Ctx, a.InfoSchema) - }) + } ctx = a.observeStmtBeginForTopSQL(ctx) startTs, err := sessiontxn.GetTxnManager(a.Ctx).GetStmtReadTS() @@ -399,7 +399,7 @@ func (a *ExecStmt) RebuildPlan(ctx context.Context) (int64, error) { return 0, err } - failpoint.Inject("assertTxnManagerInRebuildPlan", func() { + if _, _err_ := failpoint.Eval(_curpkg_("assertTxnManagerInRebuildPlan")); _err_ == nil { if is, ok := a.Ctx.Value(sessiontxn.AssertTxnInfoSchemaAfterRetryKey).(infoschema.InfoSchema); ok { a.Ctx.SetValue(sessiontxn.AssertTxnInfoSchemaKey, is) a.Ctx.SetValue(sessiontxn.AssertTxnInfoSchemaAfterRetryKey, nil) @@ -410,7 +410,7 @@ func (a *ExecStmt) RebuildPlan(ctx context.Context) (int64, error) { if ret.IsStaleness { sessiontxn.AssertTxnManagerReadTS(a.Ctx, ret.LastSnapshotTS) } - }) + } a.InfoSchema = sessiontxn.GetTxnManager(a.Ctx).GetTxnInfoSchema() replicaReadScope := sessiontxn.GetTxnManager(a.Ctx).GetReadReplicaScope() @@ -488,7 +488,7 @@ func (a *ExecStmt) Exec(ctx context.Context) (_ sqlexec.RecordSet, err error) { logutil.Logger(ctx).Error("execute sql panic", zap.String("sql", a.GetTextToLog(false)), zap.Stack("stack")) }() - failpoint.Inject("assertStaleTSO", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("assertStaleTSO")); _err_ == nil { if n, ok := val.(int); ok && staleread.IsStmtStaleness(a.Ctx) { txnManager := sessiontxn.GetTxnManager(a.Ctx) ts, err := txnManager.GetStmtReadTS() @@ -500,7 +500,7 @@ func (a *ExecStmt) Exec(ctx context.Context) (_ sqlexec.RecordSet, err error) { panic(fmt.Sprintf("different tso %d != %d", n, startTS)) } } - }) + } sctx := a.Ctx ctx = util.SetSessionID(ctx, sctx.GetSessionVars().ConnectionID) if _, ok := a.Plan.(*plannercore.Analyze); ok && sctx.GetSessionVars().InRestrictedSQL { @@ -729,12 +729,12 @@ func (a *ExecStmt) handleForeignKeyCascade(ctx context.Context, fkc *FKCascadeEx return err } err = exec.Next(ctx, e, exec.NewFirstChunk(e)) - failpoint.Inject("handleForeignKeyCascadeError", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("handleForeignKeyCascadeError")); _err_ == nil { // Next can recover panic and convert it to error. So we inject error directly here. 
if val.(bool) && err == nil { err = errors.New("handleForeignKeyCascadeError") } - }) + } closeErr := exec.Close(e) if err == nil { err = closeErr @@ -943,7 +943,7 @@ func (a *ExecStmt) handlePessimisticSelectForUpdate(ctx context.Context, e exec. return rs, nil } - failpoint.Inject("pessimisticSelectForUpdateRetry", nil) + failpoint.Eval(_curpkg_("pessimisticSelectForUpdateRetry")) } } @@ -1050,7 +1050,7 @@ func (a *ExecStmt) handlePessimisticDML(ctx context.Context, e exec.Executor) (e for { if !isFirstAttempt { - failpoint.Inject("pessimisticDMLRetry", nil) + failpoint.Eval(_curpkg_("pessimisticDMLRetry")) } startTime := time.Now() @@ -1122,13 +1122,13 @@ func (a *ExecStmt) handlePessimisticLockError(ctx context.Context, lockErr error if lockErr == nil { return nil, nil } - failpoint.Inject("assertPessimisticLockErr", func() { + if _, _err_ := failpoint.Eval(_curpkg_("assertPessimisticLockErr")); _err_ == nil { if terror.ErrorEqual(kv.ErrWriteConflict, lockErr) { sessiontxn.AddAssertEntranceForLockError(a.Ctx, "errWriteConflict") } else if terror.ErrorEqual(kv.ErrKeyExists, lockErr) { sessiontxn.AddAssertEntranceForLockError(a.Ctx, "errDuplicateKey") } - }) + } defer func() { if _, ok := errors.Cause(err).(*tikverr.ErrDeadlock); ok { @@ -1177,9 +1177,9 @@ func (a *ExecStmt) handlePessimisticLockError(ctx context.Context, lockErr error a.Ctx.GetSessionVars().StmtCtx.ResetForRetry() a.Ctx.GetSessionVars().RetryInfo.ResetOffset() - failpoint.Inject("assertTxnManagerAfterPessimisticLockErrorRetry", func() { + if _, _err_ := failpoint.Eval(_curpkg_("assertTxnManagerAfterPessimisticLockErrorRetry")); _err_ == nil { sessiontxn.RecordAssert(a.Ctx, "assertTxnManagerAfterPessimisticLockErrorRetry", true) - }) + } if err = a.openExecutor(ctx, e); err != nil { return nil, err @@ -1213,10 +1213,10 @@ func (a *ExecStmt) buildExecutor() (exec.Executor, error) { return nil, errors.Trace(b.err) } - failpoint.Inject("assertTxnManagerAfterBuildExecutor", func() { + if _, _err_ := failpoint.Eval(_curpkg_("assertTxnManagerAfterBuildExecutor")); _err_ == nil { sessiontxn.RecordAssert(a.Ctx, "assertTxnManagerAfterBuildExecutor", true) sessiontxn.AssertTxnManagerInfoSchema(b.ctx, b.is) - }) + } // ExecuteExec is not a real Executor, we only use it to build another Executor from a prepared statement. if executorExec, ok := e.(*ExecuteExec); ok { @@ -1474,9 +1474,9 @@ func (a *ExecStmt) recordLastQueryInfo(err error) { ruDetail := ruDetailRaw.(*util.RUDetails) lastRUConsumption = ruDetail.RRU() + ruDetail.WRU() } - failpoint.Inject("mockRUConsumption", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("mockRUConsumption")); _err_ == nil { lastRUConsumption = float64(len(sessVars.StmtCtx.OriginalSQL)) - }) + } // Keep the previous queryInfo for `show session_states` because the statement needs to encode it. 
sessVars.LastQueryInfo = sessionstates.QueryInfo{ TxnScope: sessVars.CheckAndGetTxnScope(), @@ -1668,13 +1668,13 @@ func (a *ExecStmt) LogSlowQuery(txnTS uint64, succ bool, hasMoreResults bool) { WRU: ruDetails.WRU(), WaitRUDuration: ruDetails.RUWaitDuration(), } - failpoint.Inject("assertSyncStatsFailed", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("assertSyncStatsFailed")); _err_ == nil { if val.(bool) { if !slowItems.IsSyncStatsFailed { panic("isSyncStatsFailed should be true") } } - }) + } if a.retryCount > 0 { slowItems.ExecRetryTime = costTime - sessVars.DurationParse - sessVars.DurationCompile - time.Since(a.retryStartTime) } diff --git a/pkg/executor/adapter.go__failpoint_stash__ b/pkg/executor/adapter.go__failpoint_stash__ new file mode 100644 index 0000000000000..5ea9fdcafb178 --- /dev/null +++ b/pkg/executor/adapter.go__failpoint_stash__ @@ -0,0 +1,2240 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "bytes" + "context" + "fmt" + "math" + "runtime/trace" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + "github.com/pingcap/tidb/pkg/bindinfo" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl/placement" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + executor_metrics "github.com/pingcap/tidb/pkg/executor/metrics" + "github.com/pingcap/tidb/pkg/executor/staticrecordset" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/keyspace" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/planner" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" + "github.com/pingcap/tidb/pkg/plugin" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/sessionstates" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/sessiontxn/staleread" + "github.com/pingcap/tidb/pkg/types" + util2 "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/breakpoint" + "github.com/pingcap/tidb/pkg/util/chunk" + contextutil "github.com/pingcap/tidb/pkg/util/context" + "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/hint" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/plancodec" + 
"github.com/pingcap/tidb/pkg/util/redact" + "github.com/pingcap/tidb/pkg/util/replayer" + "github.com/pingcap/tidb/pkg/util/sqlexec" + "github.com/pingcap/tidb/pkg/util/stmtsummary" + stmtsummaryv2 "github.com/pingcap/tidb/pkg/util/stmtsummary/v2" + "github.com/pingcap/tidb/pkg/util/stringutil" + "github.com/pingcap/tidb/pkg/util/topsql" + topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" + "github.com/pingcap/tidb/pkg/util/tracing" + "github.com/prometheus/client_golang/prometheus" + tikverr "github.com/tikv/client-go/v2/error" + "github.com/tikv/client-go/v2/oracle" + "github.com/tikv/client-go/v2/util" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +// processinfoSetter is the interface use to set current running process info. +type processinfoSetter interface { + SetProcessInfo(string, time.Time, byte, uint64) + UpdateProcessInfo() +} + +// recordSet wraps an executor, implements sqlexec.RecordSet interface +type recordSet struct { + fields []*ast.ResultField + executor exec.Executor + // The `Fields` method may be called after `Close`, and the executor is cleared in the `Close` function. + // Therefore, we need to store the schema in `recordSet` to avoid a null pointer exception when calling `executor.Schema()`. + schema *expression.Schema + stmt *ExecStmt + lastErrs []error + txnStartTS uint64 + once sync.Once +} + +func (a *recordSet) Fields() []*ast.ResultField { + if len(a.fields) == 0 { + a.fields = colNames2ResultFields(a.schema, a.stmt.OutputNames, a.stmt.Ctx.GetSessionVars().CurrentDB) + } + return a.fields +} + +func colNames2ResultFields(schema *expression.Schema, names []*types.FieldName, defaultDB string) []*ast.ResultField { + rfs := make([]*ast.ResultField, 0, schema.Len()) + defaultDBCIStr := model.NewCIStr(defaultDB) + for i := 0; i < schema.Len(); i++ { + dbName := names[i].DBName + if dbName.L == "" && names[i].TblName.L != "" { + dbName = defaultDBCIStr + } + origColName := names[i].OrigColName + emptyOrgName := false + if origColName.L == "" { + origColName = names[i].ColName + emptyOrgName = true + } + rf := &ast.ResultField{ + Column: &model.ColumnInfo{Name: origColName, FieldType: *schema.Columns[i].RetType}, + ColumnAsName: names[i].ColName, + EmptyOrgName: emptyOrgName, + Table: &model.TableInfo{Name: names[i].OrigTblName}, + TableAsName: names[i].TblName, + DBName: dbName, + } + // This is for compatibility. + // See issue https://github.com/pingcap/tidb/issues/10513 . + if len(rf.ColumnAsName.O) > mysql.MaxAliasIdentifierLen { + rf.ColumnAsName.O = rf.ColumnAsName.O[:mysql.MaxAliasIdentifierLen] + } + // Usually the length of O equals the length of L. + // Add this len judgement to avoid panic. + if len(rf.ColumnAsName.L) > mysql.MaxAliasIdentifierLen { + rf.ColumnAsName.L = rf.ColumnAsName.L[:mysql.MaxAliasIdentifierLen] + } + rfs = append(rfs, rf) + } + return rfs +} + +// Next use uses recordSet's executor to get next available chunk for later usage. +// If chunk does not contain any rows, then we update last query found rows in session variable as current found rows. +// The reason we need update is that chunk with 0 rows indicating we already finished current query, we need prepare for +// next query. +// If stmt is not nil and chunk with some rows inside, we simply update last query found rows by the number of row in chunk. 
+func (a *recordSet) Next(ctx context.Context, req *chunk.Chunk) (err error) { + defer func() { + r := recover() + if r == nil { + return + } + err = util2.GetRecoverError(r) + logutil.Logger(ctx).Error("execute sql panic", zap.String("sql", a.stmt.GetTextToLog(false)), zap.Stack("stack")) + }() + if a.stmt != nil { + if err := a.stmt.Ctx.GetSessionVars().SQLKiller.HandleSignal(); err != nil { + return err + } + } + + err = a.stmt.next(ctx, a.executor, req) + if err != nil { + a.lastErrs = append(a.lastErrs, err) + return err + } + numRows := req.NumRows() + if numRows == 0 { + if a.stmt != nil { + a.stmt.Ctx.GetSessionVars().LastFoundRows = a.stmt.Ctx.GetSessionVars().StmtCtx.FoundRows() + } + return nil + } + if a.stmt != nil { + a.stmt.Ctx.GetSessionVars().StmtCtx.AddFoundRows(uint64(numRows)) + } + return nil +} + +// NewChunk create a chunk base on top-level executor's exec.NewFirstChunk(). +func (a *recordSet) NewChunk(alloc chunk.Allocator) *chunk.Chunk { + if alloc == nil { + return exec.NewFirstChunk(a.executor) + } + + return alloc.Alloc(a.executor.RetFieldTypes(), a.executor.InitCap(), a.executor.MaxChunkSize()) +} + +func (a *recordSet) Finish() error { + var err error + a.once.Do(func() { + err = exec.Close(a.executor) + cteErr := resetCTEStorageMap(a.stmt.Ctx) + if cteErr != nil { + logutil.BgLogger().Error("got error when reset cte storage, should check if the spill disk file deleted or not", zap.Error(cteErr)) + } + if err == nil { + err = cteErr + } + a.executor = nil + if a.stmt != nil { + status := a.stmt.Ctx.GetSessionVars().SQLKiller.GetKillSignal() + inWriteResultSet := a.stmt.Ctx.GetSessionVars().SQLKiller.InWriteResultSet.Load() + if status > 0 && inWriteResultSet { + logutil.BgLogger().Warn("kill, this SQL might be stuck in the network stack while writing packets to the client.", zap.Uint64("connection ID", a.stmt.Ctx.GetSessionVars().ConnectionID)) + } + } + }) + if err != nil { + a.lastErrs = append(a.lastErrs, err) + } + return err +} + +func (a *recordSet) Close() error { + err := a.Finish() + if err != nil { + logutil.BgLogger().Error("close recordSet error", zap.Error(err)) + } + a.stmt.CloseRecordSet(a.txnStartTS, errors.Join(a.lastErrs...)) + return err +} + +// OnFetchReturned implements commandLifeCycle#OnFetchReturned +func (a *recordSet) OnFetchReturned() { + a.stmt.LogSlowQuery(a.txnStartTS, len(a.lastErrs) == 0, true) +} + +// TryDetach creates a new `RecordSet` which doesn't depend on the current session context. +func (a *recordSet) TryDetach() (sqlexec.RecordSet, bool, error) { + e, ok := Detach(a.executor) + if !ok { + return nil, false, nil + } + return staticrecordset.New(a.Fields(), e, a.stmt.GetTextToLog(false)), true, nil +} + +// GetExecutor4Test exports the internal executor for test purpose. +func (a *recordSet) GetExecutor4Test() any { + return a.executor +} + +// ExecStmt implements the sqlexec.Statement interface, it builds a planner.Plan to an sqlexec.Statement. +type ExecStmt struct { + // GoCtx stores parent go context.Context for a stmt. + GoCtx context.Context + // InfoSchema stores a reference to the schema information. + InfoSchema infoschema.InfoSchema + // Plan stores a reference to the final physical plan. + Plan base.Plan + // Text represents the origin query text. + Text string + + StmtNode ast.StmtNode + + Ctx sessionctx.Context + + // LowerPriority represents whether to lower the execution priority of a query. 
+ LowerPriority bool + isPreparedStmt bool + isSelectForUpdate bool + retryCount uint + retryStartTime time.Time + + // Phase durations are splited into two parts: 1. trying to lock keys (but + // failed); 2. the final iteration of the retry loop. Here we use + // [2]time.Duration to record such info for each phase. The first duration + // is increased only within the current iteration. When we meet a + // pessimistic lock error and decide to retry, we add the first duration to + // the second and reset the first to 0 by calling `resetPhaseDurations`. + phaseBuildDurations [2]time.Duration + phaseOpenDurations [2]time.Duration + phaseNextDurations [2]time.Duration + phaseLockDurations [2]time.Duration + + // OutputNames will be set if using cached plan + OutputNames []*types.FieldName + PsStmt *plannercore.PlanCacheStmt +} + +// GetStmtNode returns the stmtNode inside Statement +func (a *ExecStmt) GetStmtNode() ast.StmtNode { + return a.StmtNode +} + +// PointGet short path for point exec directly from plan, keep only necessary steps +func (a *ExecStmt) PointGet(ctx context.Context) (*recordSet, error) { + r, ctx := tracing.StartRegionEx(ctx, "ExecStmt.PointGet") + defer r.End() + if r.Span != nil { + r.Span.LogKV("sql", a.OriginText()) + } + + failpoint.Inject("assertTxnManagerInShortPointGetPlan", func() { + sessiontxn.RecordAssert(a.Ctx, "assertTxnManagerInShortPointGetPlan", true) + // stale read should not reach here + staleread.AssertStmtStaleness(a.Ctx, false) + sessiontxn.AssertTxnManagerInfoSchema(a.Ctx, a.InfoSchema) + }) + + ctx = a.observeStmtBeginForTopSQL(ctx) + startTs, err := sessiontxn.GetTxnManager(a.Ctx).GetStmtReadTS() + if err != nil { + return nil, err + } + a.Ctx.GetSessionVars().StmtCtx.Priority = kv.PriorityHigh + + var executor exec.Executor + useMaxTS := startTs == math.MaxUint64 + + // try to reuse point get executor + // We should only use the cached the executor when the startTS is MaxUint64 + if a.PsStmt.PointGet.Executor != nil && useMaxTS { + exec, ok := a.PsStmt.PointGet.Executor.(*PointGetExecutor) + if !ok { + logutil.Logger(ctx).Error("invalid executor type, not PointGetExecutor for point get path") + a.PsStmt.PointGet.Executor = nil + } else { + // CachedPlan type is already checked in last step + pointGetPlan := a.Plan.(*plannercore.PointGetPlan) + exec.Init(pointGetPlan) + a.PsStmt.PointGet.Executor = exec + executor = exec + } + } + + if executor == nil { + b := newExecutorBuilder(a.Ctx, a.InfoSchema) + executor = b.build(a.Plan) + if b.err != nil { + return nil, b.err + } + pointExecutor, ok := executor.(*PointGetExecutor) + + // Don't cache the executor for non point-get (table dual) or partitioned tables + if ok && useMaxTS && pointExecutor.partitionDefIdx == nil { + a.PsStmt.PointGet.Executor = pointExecutor + } + } + + if err = exec.Open(ctx, executor); err != nil { + terror.Log(exec.Close(executor)) + return nil, err + } + + sctx := a.Ctx + cmd32 := atomic.LoadUint32(&sctx.GetSessionVars().CommandValue) + cmd := byte(cmd32) + var pi processinfoSetter + if raw, ok := sctx.(processinfoSetter); ok { + pi = raw + sql := a.OriginText() + maxExecutionTime := sctx.GetSessionVars().GetMaxExecutionTime() + // Update processinfo, ShowProcess() will use it. 
+ pi.SetProcessInfo(sql, time.Now(), cmd, maxExecutionTime) + if sctx.GetSessionVars().StmtCtx.StmtType == "" { + sctx.GetSessionVars().StmtCtx.StmtType = ast.GetStmtLabel(a.StmtNode) + } + } + + return &recordSet{ + executor: executor, + schema: executor.Schema(), + stmt: a, + txnStartTS: startTs, + }, nil +} + +// OriginText returns original statement as a string. +func (a *ExecStmt) OriginText() string { + return a.Text +} + +// IsPrepared returns true if stmt is a prepare statement. +func (a *ExecStmt) IsPrepared() bool { + return a.isPreparedStmt +} + +// IsReadOnly returns true if a statement is read only. +// If current StmtNode is an ExecuteStmt, we can get its prepared stmt, +// then using ast.IsReadOnly function to determine a statement is read only or not. +func (a *ExecStmt) IsReadOnly(vars *variable.SessionVars) bool { + return planner.IsReadOnly(a.StmtNode, vars) +} + +// RebuildPlan rebuilds current execute statement plan. +// It returns the current information schema version that 'a' is using. +func (a *ExecStmt) RebuildPlan(ctx context.Context) (int64, error) { + ret := &plannercore.PreprocessorReturn{} + if err := plannercore.Preprocess(ctx, a.Ctx, a.StmtNode, plannercore.InTxnRetry, plannercore.InitTxnContextProvider, plannercore.WithPreprocessorReturn(ret)); err != nil { + return 0, err + } + + failpoint.Inject("assertTxnManagerInRebuildPlan", func() { + if is, ok := a.Ctx.Value(sessiontxn.AssertTxnInfoSchemaAfterRetryKey).(infoschema.InfoSchema); ok { + a.Ctx.SetValue(sessiontxn.AssertTxnInfoSchemaKey, is) + a.Ctx.SetValue(sessiontxn.AssertTxnInfoSchemaAfterRetryKey, nil) + } + sessiontxn.RecordAssert(a.Ctx, "assertTxnManagerInRebuildPlan", true) + sessiontxn.AssertTxnManagerInfoSchema(a.Ctx, ret.InfoSchema) + staleread.AssertStmtStaleness(a.Ctx, ret.IsStaleness) + if ret.IsStaleness { + sessiontxn.AssertTxnManagerReadTS(a.Ctx, ret.LastSnapshotTS) + } + }) + + a.InfoSchema = sessiontxn.GetTxnManager(a.Ctx).GetTxnInfoSchema() + replicaReadScope := sessiontxn.GetTxnManager(a.Ctx).GetReadReplicaScope() + if a.Ctx.GetSessionVars().GetReplicaRead().IsClosestRead() && replicaReadScope == kv.GlobalReplicaScope { + logutil.BgLogger().Warn(fmt.Sprintf("tidb can't read closest replicas due to it haven't %s label", placement.DCLabelKey)) + } + p, names, err := planner.Optimize(ctx, a.Ctx, a.StmtNode, a.InfoSchema) + if err != nil { + return 0, err + } + a.OutputNames = names + a.Plan = p + a.Ctx.GetSessionVars().StmtCtx.SetPlan(p) + return a.InfoSchema.SchemaMetaVersion(), nil +} + +// IsFastPlan exports for testing. +func IsFastPlan(p base.Plan) bool { + if proj, ok := p.(*plannercore.PhysicalProjection); ok { + p = proj.Children()[0] + } + switch p.(type) { + case *plannercore.PointGetPlan: + return true + case *plannercore.PhysicalTableDual: + // Plan of following SQL is PhysicalTableDual: + // select 1; + // select @@autocommit; + return true + case *plannercore.Set: + // Plan of following SQL is Set: + // set @a=1; + // set @@autocommit=1; + return true + } + return false +} + +// Exec builds an Executor from a plan. If the Executor doesn't return result, +// like the INSERT, UPDATE statements, it executes in this function. If the Executor returns +// result, execution is done after this function returns, in the returned sqlexec.RecordSet Next method. 
+func (a *ExecStmt) Exec(ctx context.Context) (_ sqlexec.RecordSet, err error) { + defer func() { + r := recover() + if r == nil { + if a.retryCount > 0 { + metrics.StatementPessimisticRetryCount.Observe(float64(a.retryCount)) + } + lockKeysCnt := a.Ctx.GetSessionVars().StmtCtx.LockKeysCount + if lockKeysCnt > 0 { + metrics.StatementLockKeysCount.Observe(float64(lockKeysCnt)) + } + + execDetails := a.Ctx.GetSessionVars().StmtCtx.GetExecDetails() + if err == nil && execDetails.LockKeysDetail != nil && + (execDetails.LockKeysDetail.AggressiveLockNewCount > 0 || execDetails.LockKeysDetail.AggressiveLockDerivedCount > 0) { + a.Ctx.GetSessionVars().TxnCtx.FairLockingUsed = true + // If this statement is finished when some of the keys are locked with conflict in the last retry, or + // some of the keys are derived from the previous retry, we consider the optimization of fair locking + // takes effect on this statement. + if execDetails.LockKeysDetail.LockedWithConflictCount > 0 || execDetails.LockKeysDetail.AggressiveLockDerivedCount > 0 { + a.Ctx.GetSessionVars().TxnCtx.FairLockingEffective = true + } + } + return + } + recoverdErr, ok := r.(error) + if !ok || !(exeerrors.ErrMemoryExceedForQuery.Equal(recoverdErr) || + exeerrors.ErrMemoryExceedForInstance.Equal(recoverdErr) || + exeerrors.ErrQueryInterrupted.Equal(recoverdErr) || + exeerrors.ErrMaxExecTimeExceeded.Equal(recoverdErr)) { + panic(r) + } + err = recoverdErr + logutil.Logger(ctx).Error("execute sql panic", zap.String("sql", a.GetTextToLog(false)), zap.Stack("stack")) + }() + + failpoint.Inject("assertStaleTSO", func(val failpoint.Value) { + if n, ok := val.(int); ok && staleread.IsStmtStaleness(a.Ctx) { + txnManager := sessiontxn.GetTxnManager(a.Ctx) + ts, err := txnManager.GetStmtReadTS() + if err != nil { + panic(err) + } + startTS := oracle.ExtractPhysical(ts) / 1000 + if n != int(startTS) { + panic(fmt.Sprintf("different tso %d != %d", n, startTS)) + } + } + }) + sctx := a.Ctx + ctx = util.SetSessionID(ctx, sctx.GetSessionVars().ConnectionID) + if _, ok := a.Plan.(*plannercore.Analyze); ok && sctx.GetSessionVars().InRestrictedSQL { + oriStats, ok := sctx.GetSessionVars().GetSystemVar(variable.TiDBBuildStatsConcurrency) + if !ok { + oriStats = strconv.Itoa(variable.DefBuildStatsConcurrency) + } + oriScan := sctx.GetSessionVars().AnalyzeDistSQLScanConcurrency() + oriIso, ok := sctx.GetSessionVars().GetSystemVar(variable.TxnIsolation) + if !ok { + oriIso = "REPEATABLE-READ" + } + autoConcurrency, err1 := sctx.GetSessionVars().GetSessionOrGlobalSystemVar(ctx, variable.TiDBAutoBuildStatsConcurrency) + terror.Log(err1) + if err1 == nil { + terror.Log(sctx.GetSessionVars().SetSystemVar(variable.TiDBBuildStatsConcurrency, autoConcurrency)) + } + sVal, err2 := sctx.GetSessionVars().GetSessionOrGlobalSystemVar(ctx, variable.TiDBSysProcScanConcurrency) + terror.Log(err2) + if err2 == nil { + concurrency, err3 := strconv.ParseInt(sVal, 10, 64) + terror.Log(err3) + if err3 == nil { + sctx.GetSessionVars().SetAnalyzeDistSQLScanConcurrency(int(concurrency)) + } + } + terror.Log(sctx.GetSessionVars().SetSystemVar(variable.TxnIsolation, ast.ReadCommitted)) + defer func() { + terror.Log(sctx.GetSessionVars().SetSystemVar(variable.TiDBBuildStatsConcurrency, oriStats)) + sctx.GetSessionVars().SetAnalyzeDistSQLScanConcurrency(oriScan) + terror.Log(sctx.GetSessionVars().SetSystemVar(variable.TxnIsolation, oriIso)) + }() + } + + if sctx.GetSessionVars().StmtCtx.HasMemQuotaHint { + 
sctx.GetSessionVars().MemTracker.SetBytesLimit(sctx.GetSessionVars().StmtCtx.MemQuotaQuery) + } + + // must set plan according to the `Execute` plan before getting planDigest + a.inheritContextFromExecuteStmt() + if rm := domain.GetDomain(sctx).RunawayManager(); variable.EnableResourceControl.Load() && rm != nil { + sessionVars := sctx.GetSessionVars() + stmtCtx := sessionVars.StmtCtx + _, planDigest := GetPlanDigest(stmtCtx) + _, sqlDigest := stmtCtx.SQLDigest() + stmtCtx.RunawayChecker = rm.DeriveChecker(stmtCtx.ResourceGroupName, stmtCtx.OriginalSQL, sqlDigest.String(), planDigest.String(), sessionVars.StartTime) + if err := stmtCtx.RunawayChecker.BeforeExecutor(); err != nil { + return nil, err + } + } + ctx = a.observeStmtBeginForTopSQL(ctx) + + e, err := a.buildExecutor() + if err != nil { + return nil, err + } + + cmd32 := atomic.LoadUint32(&sctx.GetSessionVars().CommandValue) + cmd := byte(cmd32) + var pi processinfoSetter + if raw, ok := sctx.(processinfoSetter); ok { + pi = raw + sql := a.getSQLForProcessInfo() + maxExecutionTime := sctx.GetSessionVars().GetMaxExecutionTime() + // Update processinfo, ShowProcess() will use it. + if a.Ctx.GetSessionVars().StmtCtx.StmtType == "" { + a.Ctx.GetSessionVars().StmtCtx.StmtType = ast.GetStmtLabel(a.StmtNode) + } + // Since maxExecutionTime is used only for query statement, here we limit it affect scope. + if !a.IsReadOnly(a.Ctx.GetSessionVars()) { + maxExecutionTime = 0 + } + pi.SetProcessInfo(sql, time.Now(), cmd, maxExecutionTime) + } + + breakpoint.Inject(a.Ctx, sessiontxn.BreakPointBeforeExecutorFirstRun) + if err = a.openExecutor(ctx, e); err != nil { + terror.Log(exec.Close(e)) + return nil, err + } + + isPessimistic := sctx.GetSessionVars().TxnCtx.IsPessimistic + + if a.isSelectForUpdate { + if sctx.GetSessionVars().UseLowResolutionTSO() { + return nil, errors.New("can not execute select for update statement when 'tidb_low_resolution_tso' is set") + } + // Special handle for "select for update statement" in pessimistic transaction. + if isPessimistic { + return a.handlePessimisticSelectForUpdate(ctx, e) + } + } + + a.prepareFKCascadeContext(e) + if handled, result, err := a.handleNoDelay(ctx, e, isPessimistic); handled || err != nil { + return result, err + } + + var txnStartTS uint64 + txn, err := sctx.Txn(false) + if err != nil { + return nil, err + } + if txn.Valid() { + txnStartTS = txn.StartTS() + } + + return &recordSet{ + executor: e, + schema: e.Schema(), + stmt: a, + txnStartTS: txnStartTS, + }, nil +} + +func (a *ExecStmt) inheritContextFromExecuteStmt() { + if executePlan, ok := a.Plan.(*plannercore.Execute); ok { + a.Ctx.SetValue(sessionctx.QueryString, executePlan.Stmt.Text()) + a.OutputNames = executePlan.OutputNames() + a.isPreparedStmt = true + a.Plan = executePlan.Plan + a.Ctx.GetSessionVars().StmtCtx.SetPlan(executePlan.Plan) + } +} + +func (a *ExecStmt) getSQLForProcessInfo() string { + sql := a.OriginText() + if simple, ok := a.Plan.(*plannercore.Simple); ok && simple.Statement != nil { + if ss, ok := simple.Statement.(ast.SensitiveStmtNode); ok { + // Use SecureText to avoid leak password information. 
+ sql = ss.SecureText() + } + } else if sn, ok2 := a.StmtNode.(ast.SensitiveStmtNode); ok2 { + // such as import into statement + sql = sn.SecureText() + } + return sql +} + +func (a *ExecStmt) handleStmtForeignKeyTrigger(ctx context.Context, e exec.Executor) error { + stmtCtx := a.Ctx.GetSessionVars().StmtCtx + if stmtCtx.ForeignKeyTriggerCtx.HasFKCascades { + // If the ExecStmt has foreign key cascade to be executed, we need call `StmtCommit` to commit the ExecStmt itself + // change first. + // Since `UnionScanExec` use `SnapshotIter` and `SnapshotGetter` to read txn mem-buffer, if we don't do `StmtCommit`, + // then the fk cascade executor can't read the mem-buffer changed by the ExecStmt. + a.Ctx.StmtCommit(ctx) + } + err := a.handleForeignKeyTrigger(ctx, e, 1) + if err != nil { + err1 := a.handleFKTriggerError(stmtCtx) + if err1 != nil { + return errors.Errorf("handle foreign key trigger error failed, err: %v, original_err: %v", err1, err) + } + return err + } + if stmtCtx.ForeignKeyTriggerCtx.SavepointName != "" { + a.Ctx.GetSessionVars().TxnCtx.ReleaseSavepoint(stmtCtx.ForeignKeyTriggerCtx.SavepointName) + } + return nil +} + +var maxForeignKeyCascadeDepth = 15 + +func (a *ExecStmt) handleForeignKeyTrigger(ctx context.Context, e exec.Executor, depth int) error { + exec, ok := e.(WithForeignKeyTrigger) + if !ok { + return nil + } + fkChecks := exec.GetFKChecks() + for _, fkCheck := range fkChecks { + err := fkCheck.doCheck(ctx) + if err != nil { + return err + } + } + fkCascades := exec.GetFKCascades() + for _, fkCascade := range fkCascades { + err := a.handleForeignKeyCascade(ctx, fkCascade, depth) + if err != nil { + return err + } + } + return nil +} + +// handleForeignKeyCascade uses to execute foreign key cascade behaviour, the progress is: +// 1. Build delete/update executor for foreign key on delete/update behaviour. +// a. Construct delete/update AST. We used to try generated SQL string first and then parse the SQL to get AST, +// but we need convert Datum to string, there may be some risks here, since assert_eq(datum_a, parse(datum_a.toString())) may be broken. +// so we chose to construct AST directly. +// b. Build plan by the delete/update AST. +// c. Build executor by the delete/update plan. +// 2. Execute the delete/update executor. +// 3. Close the executor. +// 4. `StmtCommit` to commit the kv change to transaction mem-buffer. +// 5. If the foreign key cascade behaviour has more fk value need to be cascaded, go to step 1. 
+func (a *ExecStmt) handleForeignKeyCascade(ctx context.Context, fkc *FKCascadeExec, depth int) error { + if a.Ctx.GetSessionVars().StmtCtx.RuntimeStatsColl != nil { + fkc.stats = &FKCascadeRuntimeStats{} + defer a.Ctx.GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(fkc.plan.ID(), fkc.stats) + } + if len(fkc.fkValues) == 0 && len(fkc.fkUpdatedValuesMap) == 0 { + return nil + } + if depth > maxForeignKeyCascadeDepth { + return exeerrors.ErrForeignKeyCascadeDepthExceeded.GenWithStackByArgs(maxForeignKeyCascadeDepth) + } + a.Ctx.GetSessionVars().StmtCtx.InHandleForeignKeyTrigger = true + defer func() { + a.Ctx.GetSessionVars().StmtCtx.InHandleForeignKeyTrigger = false + }() + if fkc.stats != nil { + start := time.Now() + defer func() { + fkc.stats.Total += time.Since(start) + }() + } + for { + e, err := fkc.buildExecutor(ctx) + if err != nil || e == nil { + return err + } + if err := exec.Open(ctx, e); err != nil { + terror.Log(exec.Close(e)) + return err + } + err = exec.Next(ctx, e, exec.NewFirstChunk(e)) + failpoint.Inject("handleForeignKeyCascadeError", func(val failpoint.Value) { + // Next can recover panic and convert it to error. So we inject error directly here. + if val.(bool) && err == nil { + err = errors.New("handleForeignKeyCascadeError") + } + }) + closeErr := exec.Close(e) + if err == nil { + err = closeErr + } + if err != nil { + return err + } + // Call `StmtCommit` uses to flush the fk cascade executor change into txn mem-buffer, + // then the later fk cascade executors can see the mem-buffer changes. + a.Ctx.StmtCommit(ctx) + err = a.handleForeignKeyTrigger(ctx, e, depth+1) + if err != nil { + return err + } + } +} + +// prepareFKCascadeContext records a transaction savepoint for foreign key cascade when this ExecStmt has foreign key +// cascade behaviour and this ExecStmt is in transaction. +func (a *ExecStmt) prepareFKCascadeContext(e exec.Executor) { + exec, ok := e.(WithForeignKeyTrigger) + if !ok || !exec.HasFKCascades() { + return + } + sessVar := a.Ctx.GetSessionVars() + sessVar.StmtCtx.ForeignKeyTriggerCtx.HasFKCascades = true + if !sessVar.InTxn() { + return + } + txn, err := a.Ctx.Txn(false) + if err != nil || !txn.Valid() { + return + } + // Record a txn savepoint if ExecStmt in transaction, the savepoint is use to do rollback when handle foreign key + // cascade failed. + savepointName := "fk_sp_" + strconv.FormatUint(txn.StartTS(), 10) + memDBCheckpoint := txn.GetMemDBCheckpoint() + sessVar.TxnCtx.AddSavepoint(savepointName, memDBCheckpoint) + sessVar.StmtCtx.ForeignKeyTriggerCtx.SavepointName = savepointName +} + +func (a *ExecStmt) handleFKTriggerError(sc *stmtctx.StatementContext) error { + if sc.ForeignKeyTriggerCtx.SavepointName == "" { + return nil + } + txn, err := a.Ctx.Txn(false) + if err != nil || !txn.Valid() { + return err + } + savepointRecord := a.Ctx.GetSessionVars().TxnCtx.RollbackToSavepoint(sc.ForeignKeyTriggerCtx.SavepointName) + if savepointRecord == nil { + // Normally should never run into here, but just in case, rollback the transaction. 
+ err = txn.Rollback() + if err != nil { + return err + } + return errors.Errorf("foreign key cascade savepoint '%s' not found, transaction is rollback, should never happen", sc.ForeignKeyTriggerCtx.SavepointName) + } + txn.RollbackMemDBToCheckpoint(savepointRecord.MemDBCheckpoint) + a.Ctx.GetSessionVars().TxnCtx.ReleaseSavepoint(sc.ForeignKeyTriggerCtx.SavepointName) + return nil +} + +func (a *ExecStmt) handleNoDelay(ctx context.Context, e exec.Executor, isPessimistic bool) (handled bool, rs sqlexec.RecordSet, err error) { + sc := a.Ctx.GetSessionVars().StmtCtx + defer func() { + // If the stmt have no rs like `insert`, The session tracker detachment will be directly + // done in the `defer` function. If the rs is not nil, the detachment will be done in + // `rs.Close` in `handleStmt` + if handled && sc != nil && rs == nil { + sc.DetachMemDiskTracker() + cteErr := resetCTEStorageMap(a.Ctx) + if err == nil { + // Only overwrite err when it's nil. + err = cteErr + } + } + }() + + toCheck := e + isExplainAnalyze := false + if explain, ok := e.(*ExplainExec); ok { + if analyze := explain.getAnalyzeExecToExecutedNoDelay(); analyze != nil { + toCheck = analyze + isExplainAnalyze = true + a.Ctx.GetSessionVars().StmtCtx.IsExplainAnalyzeDML = isExplainAnalyze + } + } + + // If the executor doesn't return any result to the client, we execute it without delay. + if toCheck.Schema().Len() == 0 { + handled = !isExplainAnalyze + if isPessimistic { + err := a.handlePessimisticDML(ctx, toCheck) + return handled, nil, err + } + r, err := a.handleNoDelayExecutor(ctx, toCheck) + return handled, r, err + } else if proj, ok := toCheck.(*ProjectionExec); ok && proj.calculateNoDelay { + // Currently this is only for the "DO" statement. Take "DO 1, @a=2;" as an example: + // the Projection has two expressions and two columns in the schema, but we should + // not return the result of the two expressions. + r, err := a.handleNoDelayExecutor(ctx, e) + return true, r, err + } + + return false, nil, nil +} + +func isNoResultPlan(p base.Plan) bool { + if p.Schema().Len() == 0 { + return true + } + + // Currently this is only for the "DO" statement. Take "DO 1, @a=2;" as an example: + // the Projection has two expressions and two columns in the schema, but we should + // not return the result of the two expressions. 
+ switch raw := p.(type) { + case *logicalop.LogicalProjection: + if raw.CalculateNoDelay { + return true + } + case *plannercore.PhysicalProjection: + if raw.CalculateNoDelay { + return true + } + } + return false +} + +type chunkRowRecordSet struct { + rows []chunk.Row + idx int + fields []*ast.ResultField + e exec.Executor + execStmt *ExecStmt +} + +func (c *chunkRowRecordSet) Fields() []*ast.ResultField { + if c.fields == nil { + c.fields = colNames2ResultFields(c.e.Schema(), c.execStmt.OutputNames, c.execStmt.Ctx.GetSessionVars().CurrentDB) + } + return c.fields +} + +func (c *chunkRowRecordSet) Next(_ context.Context, chk *chunk.Chunk) error { + chk.Reset() + if !chk.IsFull() && c.idx < len(c.rows) { + numToAppend := min(len(c.rows)-c.idx, chk.RequiredRows()-chk.NumRows()) + chk.AppendRows(c.rows[c.idx : c.idx+numToAppend]) + c.idx += numToAppend + } + return nil +} + +func (c *chunkRowRecordSet) NewChunk(alloc chunk.Allocator) *chunk.Chunk { + if alloc == nil { + return exec.NewFirstChunk(c.e) + } + + return alloc.Alloc(c.e.RetFieldTypes(), c.e.InitCap(), c.e.MaxChunkSize()) +} + +func (c *chunkRowRecordSet) Close() error { + c.execStmt.CloseRecordSet(c.execStmt.Ctx.GetSessionVars().TxnCtx.StartTS, nil) + return nil +} + +func (a *ExecStmt) handlePessimisticSelectForUpdate(ctx context.Context, e exec.Executor) (_ sqlexec.RecordSet, retErr error) { + if snapshotTS := a.Ctx.GetSessionVars().SnapshotTS; snapshotTS != 0 { + terror.Log(exec.Close(e)) + return nil, errors.New("can not execute write statement when 'tidb_snapshot' is set") + } + + txnManager := sessiontxn.GetTxnManager(a.Ctx) + err := txnManager.OnPessimisticStmtStart(ctx) + if err != nil { + return nil, err + } + defer func() { + isSuccessful := retErr == nil + err1 := txnManager.OnPessimisticStmtEnd(ctx, isSuccessful) + if retErr == nil && err1 != nil { + retErr = err1 + } + }() + + isFirstAttempt := true + + for { + startTime := time.Now() + rs, err := a.runPessimisticSelectForUpdate(ctx, e) + + if isFirstAttempt { + executor_metrics.SelectForUpdateFirstAttemptDuration.Observe(time.Since(startTime).Seconds()) + isFirstAttempt = false + } else { + executor_metrics.SelectForUpdateRetryDuration.Observe(time.Since(startTime).Seconds()) + } + + e, err = a.handlePessimisticLockError(ctx, err) + if err != nil { + return nil, err + } + if e == nil { + return rs, nil + } + + failpoint.Inject("pessimisticSelectForUpdateRetry", nil) + } +} + +func (a *ExecStmt) runPessimisticSelectForUpdate(ctx context.Context, e exec.Executor) (sqlexec.RecordSet, error) { + defer func() { + terror.Log(exec.Close(e)) + }() + var rows []chunk.Row + var err error + req := exec.TryNewCacheChunk(e) + for { + err = a.next(ctx, e, req) + if err != nil { + // Handle 'write conflict' error. + break + } + if req.NumRows() == 0 { + return &chunkRowRecordSet{rows: rows, e: e, execStmt: a}, nil + } + iter := chunk.NewIterator4Chunk(req) + for r := iter.Begin(); r != iter.End(); r = iter.Next() { + rows = append(rows, r) + } + req = chunk.Renew(req, a.Ctx.GetSessionVars().MaxChunkSize) + } + return nil, err +} + +func (a *ExecStmt) handleNoDelayExecutor(ctx context.Context, e exec.Executor) (sqlexec.RecordSet, error) { + sctx := a.Ctx + r, ctx := tracing.StartRegionEx(ctx, "executor.handleNoDelayExecutor") + defer r.End() + + var err error + defer func() { + terror.Log(exec.Close(e)) + a.logAudit() + }() + + // Check if "tidb_snapshot" is set for the write executors. + // In history read mode, we can not do write operations. 
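handlePessimisticSelectForUpdate above, and handlePessimisticDML below, share one control-flow idea: execute the statement, and when the failure is a retryable pessimistic lock conflict, rebuild the executor and run it again, giving up once a retry cap is reached (the real code does the rebuilding inside handlePessimisticLockError and bounds retries with PessimisticTxn.MaxRetryCount). A condensed, self-contained sketch of that loop with purely illustrative names, not the actual executor types:

package sketch

import (
	"context"
	"errors"
)

// errRetryable stands in for a write-conflict error that is worth retrying.
var errRetryable = errors.New("retryable lock conflict")

// retryOnLockConflict mirrors the shape of the pessimistic retry loops:
// run once, retry on a retryable conflict, and surface any other error or a
// conflict that persists past the retry cap.
func retryOnLockConflict(ctx context.Context, maxRetry int, runOnce func(context.Context) error) error {
	for attempt := 0; ; attempt++ {
		err := runOnce(ctx)
		if err == nil {
			return nil
		}
		if !errors.Is(err, errRetryable) || attempt >= maxRetry {
			return err
		}
		// A real implementation would roll back the statement's changes and
		// rebuild the executor here before the next attempt.
	}
}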
+ switch e.(type) { + case *DeleteExec, *InsertExec, *UpdateExec, *ReplaceExec, *LoadDataExec, *DDLExec, *ImportIntoExec: + snapshotTS := sctx.GetSessionVars().SnapshotTS + if snapshotTS != 0 { + return nil, errors.New("can not execute write statement when 'tidb_snapshot' is set") + } + if sctx.GetSessionVars().UseLowResolutionTSO() { + return nil, errors.New("can not execute write statement when 'tidb_low_resolution_tso' is set") + } + } + + err = a.next(ctx, e, exec.TryNewCacheChunk(e)) + if err != nil { + return nil, err + } + err = a.handleStmtForeignKeyTrigger(ctx, e) + return nil, err +} + +func (a *ExecStmt) handlePessimisticDML(ctx context.Context, e exec.Executor) (err error) { + sctx := a.Ctx + // Do not activate the transaction here. + // When autocommit = 0 and transaction in pessimistic mode, + // statements like set xxx = xxx; should not active the transaction. + txn, err := sctx.Txn(false) + if err != nil { + return err + } + txnCtx := sctx.GetSessionVars().TxnCtx + defer func() { + if err != nil && !sctx.GetSessionVars().ConstraintCheckInPlacePessimistic && sctx.GetSessionVars().InTxn() { + // If it's not a retryable error, rollback current transaction instead of rolling back current statement like + // in normal transactions, because we cannot locate and rollback the statement that leads to the lock error. + // This is too strict, but since the feature is not for everyone, it's the easiest way to guarantee safety. + stmtText := parser.Normalize(a.OriginText(), sctx.GetSessionVars().EnableRedactLog) + logutil.Logger(ctx).Info("Transaction abort for the safety of lazy uniqueness check. "+ + "Note this may not be a uniqueness violation.", + zap.Error(err), + zap.String("statement", stmtText), + zap.Uint64("conn", sctx.GetSessionVars().ConnectionID), + zap.Uint64("txnStartTS", txnCtx.StartTS), + zap.Uint64("forUpdateTS", txnCtx.GetForUpdateTS()), + ) + sctx.GetSessionVars().SetInTxn(false) + err = exeerrors.ErrLazyUniquenessCheckFailure.GenWithStackByArgs(err.Error()) + } + }() + + txnManager := sessiontxn.GetTxnManager(a.Ctx) + err = txnManager.OnPessimisticStmtStart(ctx) + if err != nil { + return err + } + defer func() { + isSuccessful := err == nil + err1 := txnManager.OnPessimisticStmtEnd(ctx, isSuccessful) + if err == nil && err1 != nil { + err = err1 + } + }() + + isFirstAttempt := true + + for { + if !isFirstAttempt { + failpoint.Inject("pessimisticDMLRetry", nil) + } + + startTime := time.Now() + _, err = a.handleNoDelayExecutor(ctx, e) + if !txn.Valid() { + return err + } + + if isFirstAttempt { + executor_metrics.DmlFirstAttemptDuration.Observe(time.Since(startTime).Seconds()) + isFirstAttempt = false + } else { + executor_metrics.DmlRetryDuration.Observe(time.Since(startTime).Seconds()) + } + + if err != nil { + // It is possible the DML has point get plan that locks the key. 
+ e, err = a.handlePessimisticLockError(ctx, err) + if err != nil { + if exeerrors.ErrDeadlock.Equal(err) { + metrics.StatementDeadlockDetectDuration.Observe(time.Since(startTime).Seconds()) + } + return err + } + continue + } + keys, err1 := txn.(pessimisticTxn).KeysNeedToLock() + if err1 != nil { + return err1 + } + keys = txnCtx.CollectUnchangedKeysForLock(keys) + if len(keys) == 0 { + return nil + } + keys = filterTemporaryTableKeys(sctx.GetSessionVars(), keys) + seVars := sctx.GetSessionVars() + keys = filterLockTableKeys(seVars.StmtCtx, keys) + lockCtx, err := newLockCtx(sctx, seVars.LockWaitTimeout, len(keys)) + if err != nil { + return err + } + var lockKeyStats *util.LockKeysDetails + ctx = context.WithValue(ctx, util.LockKeysDetailCtxKey, &lockKeyStats) + startLocking := time.Now() + err = txn.LockKeys(ctx, lockCtx, keys...) + a.phaseLockDurations[0] += time.Since(startLocking) + if e.RuntimeStats() != nil { + e.RuntimeStats().Record(time.Since(startLocking), 0) + } + if lockKeyStats != nil { + seVars.StmtCtx.MergeLockKeysExecDetails(lockKeyStats) + } + if err == nil { + return nil + } + e, err = a.handlePessimisticLockError(ctx, err) + if err != nil { + // todo: Report deadlock + if exeerrors.ErrDeadlock.Equal(err) { + metrics.StatementDeadlockDetectDuration.Observe(time.Since(startLocking).Seconds()) + } + return err + } + } +} + +// handlePessimisticLockError updates TS and rebuild executor if the err is write conflict. +func (a *ExecStmt) handlePessimisticLockError(ctx context.Context, lockErr error) (_ exec.Executor, err error) { + if lockErr == nil { + return nil, nil + } + failpoint.Inject("assertPessimisticLockErr", func() { + if terror.ErrorEqual(kv.ErrWriteConflict, lockErr) { + sessiontxn.AddAssertEntranceForLockError(a.Ctx, "errWriteConflict") + } else if terror.ErrorEqual(kv.ErrKeyExists, lockErr) { + sessiontxn.AddAssertEntranceForLockError(a.Ctx, "errDuplicateKey") + } + }) + + defer func() { + if _, ok := errors.Cause(err).(*tikverr.ErrDeadlock); ok { + err = exeerrors.ErrDeadlock + } + }() + + txnManager := sessiontxn.GetTxnManager(a.Ctx) + action, err := txnManager.OnStmtErrorForNextAction(ctx, sessiontxn.StmtErrAfterPessimisticLock, lockErr) + if err != nil { + return nil, err + } + + if action != sessiontxn.StmtActionRetryReady { + return nil, lockErr + } + + if a.retryCount >= config.GetGlobalConfig().PessimisticTxn.MaxRetryCount { + return nil, errors.New("pessimistic lock retry limit reached") + } + a.retryCount++ + a.retryStartTime = time.Now() + + err = txnManager.OnStmtRetry(ctx) + if err != nil { + return nil, err + } + + // Without this line of code, the result will still be correct. But it can ensure that the update time of for update read + // is determined which is beneficial for testing. + if _, err = txnManager.GetStmtForUpdateTS(); err != nil { + return nil, err + } + + breakpoint.Inject(a.Ctx, sessiontxn.BreakPointOnStmtRetryAfterLockError) + + a.resetPhaseDurations() + + a.inheritContextFromExecuteStmt() + e, err := a.buildExecutor() + if err != nil { + return nil, err + } + // Rollback the statement change before retry it. 
+ a.Ctx.StmtRollback(ctx, true) + a.Ctx.GetSessionVars().StmtCtx.ResetForRetry() + a.Ctx.GetSessionVars().RetryInfo.ResetOffset() + + failpoint.Inject("assertTxnManagerAfterPessimisticLockErrorRetry", func() { + sessiontxn.RecordAssert(a.Ctx, "assertTxnManagerAfterPessimisticLockErrorRetry", true) + }) + + if err = a.openExecutor(ctx, e); err != nil { + return nil, err + } + return e, nil +} + +type pessimisticTxn interface { + kv.Transaction + // KeysNeedToLock returns the keys need to be locked. + KeysNeedToLock() ([]kv.Key, error) +} + +// buildExecutor build an executor from plan, prepared statement may need additional procedure. +func (a *ExecStmt) buildExecutor() (exec.Executor, error) { + defer func(start time.Time) { a.phaseBuildDurations[0] += time.Since(start) }(time.Now()) + ctx := a.Ctx + stmtCtx := ctx.GetSessionVars().StmtCtx + if _, ok := a.Plan.(*plannercore.Execute); !ok { + if stmtCtx.Priority == mysql.NoPriority && a.LowerPriority { + stmtCtx.Priority = kv.PriorityLow + } + } + if _, ok := a.Plan.(*plannercore.Analyze); ok && ctx.GetSessionVars().InRestrictedSQL { + ctx.GetSessionVars().StmtCtx.Priority = kv.PriorityLow + } + + b := newExecutorBuilder(ctx, a.InfoSchema) + e := b.build(a.Plan) + if b.err != nil { + return nil, errors.Trace(b.err) + } + + failpoint.Inject("assertTxnManagerAfterBuildExecutor", func() { + sessiontxn.RecordAssert(a.Ctx, "assertTxnManagerAfterBuildExecutor", true) + sessiontxn.AssertTxnManagerInfoSchema(b.ctx, b.is) + }) + + // ExecuteExec is not a real Executor, we only use it to build another Executor from a prepared statement. + if executorExec, ok := e.(*ExecuteExec); ok { + err := executorExec.Build(b) + if err != nil { + return nil, err + } + if executorExec.lowerPriority { + ctx.GetSessionVars().StmtCtx.Priority = kv.PriorityLow + } + e = executorExec.stmtExec + } + a.isSelectForUpdate = b.hasLock && (!stmtCtx.InDeleteStmt && !stmtCtx.InUpdateStmt && !stmtCtx.InInsertStmt) + return e, nil +} + +func (a *ExecStmt) openExecutor(ctx context.Context, e exec.Executor) (err error) { + defer func() { + if r := recover(); r != nil { + err = util2.GetRecoverError(r) + } + }() + start := time.Now() + err = exec.Open(ctx, e) + a.phaseOpenDurations[0] += time.Since(start) + return err +} + +func (a *ExecStmt) next(ctx context.Context, e exec.Executor, req *chunk.Chunk) error { + start := time.Now() + err := exec.Next(ctx, e, req) + a.phaseNextDurations[0] += time.Since(start) + return err +} + +func (a *ExecStmt) resetPhaseDurations() { + a.phaseBuildDurations[1] += a.phaseBuildDurations[0] + a.phaseBuildDurations[0] = 0 + a.phaseOpenDurations[1] += a.phaseOpenDurations[0] + a.phaseOpenDurations[0] = 0 + a.phaseNextDurations[1] += a.phaseNextDurations[0] + a.phaseNextDurations[0] = 0 + a.phaseLockDurations[1] += a.phaseLockDurations[0] + a.phaseLockDurations[0] = 0 +} + +// QueryReplacer replaces new line and tab for grep result including query string. 
+var QueryReplacer = strings.NewReplacer("\r", " ", "\n", " ", "\t", " ") + +func (a *ExecStmt) logAudit() { + sessVars := a.Ctx.GetSessionVars() + if sessVars.InRestrictedSQL { + return + } + + err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { + audit := plugin.DeclareAuditManifest(p.Manifest) + if audit.OnGeneralEvent != nil { + cmd := mysql.Command2Str[byte(atomic.LoadUint32(&a.Ctx.GetSessionVars().CommandValue))] + ctx := context.WithValue(context.Background(), plugin.ExecStartTimeCtxKey, a.Ctx.GetSessionVars().StartTime) + audit.OnGeneralEvent(ctx, sessVars, plugin.Completed, cmd) + } + return nil + }) + if err != nil { + log.Error("log audit log failure", zap.Error(err)) + } +} + +// FormatSQL is used to format the original SQL, e.g. truncating long SQL, appending prepared arguments. +func FormatSQL(sql string) stringutil.StringerFunc { + return func() string { return formatSQL(sql) } +} + +func formatSQL(sql string) string { + length := len(sql) + maxQueryLen := variable.QueryLogMaxLen.Load() + if maxQueryLen <= 0 { + return QueryReplacer.Replace(sql) // no limit + } + if int32(length) > maxQueryLen { + var result strings.Builder + result.Grow(int(maxQueryLen)) + result.WriteString(sql[:maxQueryLen]) + fmt.Fprintf(&result, "(len:%d)", length) + return QueryReplacer.Replace(result.String()) + } + return QueryReplacer.Replace(sql) +} + +func getPhaseDurationObserver(phase string, internal bool) prometheus.Observer { + if internal { + if ob, found := executor_metrics.PhaseDurationObserverMapInternal[phase]; found { + return ob + } + return executor_metrics.ExecUnknownInternal + } + if ob, found := executor_metrics.PhaseDurationObserverMap[phase]; found { + return ob + } + return executor_metrics.ExecUnknown +} + +func (a *ExecStmt) observePhaseDurations(internal bool, commitDetails *util.CommitDetails) { + for _, it := range []struct { + duration time.Duration + phase string + }{ + {a.phaseBuildDurations[0], executor_metrics.PhaseBuildFinal}, + {a.phaseBuildDurations[1], executor_metrics.PhaseBuildLocking}, + {a.phaseOpenDurations[0], executor_metrics.PhaseOpenFinal}, + {a.phaseOpenDurations[1], executor_metrics.PhaseOpenLocking}, + {a.phaseNextDurations[0], executor_metrics.PhaseNextFinal}, + {a.phaseNextDurations[1], executor_metrics.PhaseNextLocking}, + {a.phaseLockDurations[0], executor_metrics.PhaseLockFinal}, + {a.phaseLockDurations[1], executor_metrics.PhaseLockLocking}, + } { + if it.duration > 0 { + getPhaseDurationObserver(it.phase, internal).Observe(it.duration.Seconds()) + } + } + if commitDetails != nil { + for _, it := range []struct { + duration time.Duration + phase string + }{ + {commitDetails.PrewriteTime, executor_metrics.PhaseCommitPrewrite}, + {commitDetails.CommitTime, executor_metrics.PhaseCommitCommit}, + {commitDetails.GetCommitTsTime, executor_metrics.PhaseCommitWaitCommitTS}, + {commitDetails.GetLatestTsTime, executor_metrics.PhaseCommitWaitLatestTS}, + {commitDetails.LocalLatchTime, executor_metrics.PhaseCommitWaitLatch}, + {commitDetails.WaitPrewriteBinlogTime, executor_metrics.PhaseCommitWaitBinlog}, + } { + if it.duration > 0 { + getPhaseDurationObserver(it.phase, internal).Observe(it.duration.Seconds()) + } + } + } + if stmtDetailsRaw := a.GoCtx.Value(execdetails.StmtExecDetailKey); stmtDetailsRaw != nil { + d := stmtDetailsRaw.(*execdetails.StmtExecDetails).WriteSQLRespDuration + if d > 0 { + getPhaseDurationObserver(executor_metrics.PhaseWriteResponse, internal).Observe(d.Seconds()) + } + } +} + +// FinishExecuteStmt is used to 
record some information after `ExecStmt` execution finished: +// 1. record slow log if needed. +// 2. record summary statement. +// 3. record execute duration metric. +// 4. update the `PrevStmt` in session variable. +// 5. reset `DurationParse` in session variable. +func (a *ExecStmt) FinishExecuteStmt(txnTS uint64, err error, hasMoreResults bool) { + a.checkPlanReplayerCapture(txnTS) + + sessVars := a.Ctx.GetSessionVars() + execDetail := sessVars.StmtCtx.GetExecDetails() + // Attach commit/lockKeys runtime stats to executor runtime stats. + if (execDetail.CommitDetail != nil || execDetail.LockKeysDetail != nil) && sessVars.StmtCtx.RuntimeStatsColl != nil { + statsWithCommit := &execdetails.RuntimeStatsWithCommit{ + Commit: execDetail.CommitDetail, + LockKeys: execDetail.LockKeysDetail, + } + sessVars.StmtCtx.RuntimeStatsColl.RegisterStats(a.Plan.ID(), statsWithCommit) + } + // Record related SLI metrics. + if execDetail.CommitDetail != nil && execDetail.CommitDetail.WriteSize > 0 { + a.Ctx.GetTxnWriteThroughputSLI().AddTxnWriteSize(execDetail.CommitDetail.WriteSize, execDetail.CommitDetail.WriteKeys) + } + if execDetail.ScanDetail != nil && sessVars.StmtCtx.AffectedRows() > 0 { + processedKeys := atomic.LoadInt64(&execDetail.ScanDetail.ProcessedKeys) + if processedKeys > 0 { + // Only record the read keys in write statement which affect row more than 0. + a.Ctx.GetTxnWriteThroughputSLI().AddReadKeys(processedKeys) + } + } + succ := err == nil + if a.Plan != nil { + // If this statement has a Plan, the StmtCtx.plan should have been set when it comes here, + // but we set it again in case we missed some code paths. + sessVars.StmtCtx.SetPlan(a.Plan) + } + // `LowSlowQuery` and `SummaryStmt` must be called before recording `PrevStmt`. + a.LogSlowQuery(txnTS, succ, hasMoreResults) + a.SummaryStmt(succ) + a.observeStmtFinishedForTopSQL() + if sessVars.StmtCtx.IsTiFlash.Load() { + if succ { + executor_metrics.TotalTiFlashQuerySuccCounter.Inc() + } else { + metrics.TiFlashQueryTotalCounter.WithLabelValues(metrics.ExecuteErrorToLabel(err), metrics.LblError).Inc() + } + } + a.updatePrevStmt() + a.recordLastQueryInfo(err) + a.observePhaseDurations(sessVars.InRestrictedSQL, execDetail.CommitDetail) + executeDuration := sessVars.GetExecuteDuration() + if sessVars.InRestrictedSQL { + executor_metrics.SessionExecuteRunDurationInternal.Observe(executeDuration.Seconds()) + } else { + executor_metrics.SessionExecuteRunDurationGeneral.Observe(executeDuration.Seconds()) + } + // Reset DurationParse due to the next statement may not need to be parsed (not a text protocol query). 
+ sessVars.DurationParse = 0 + // Clean the stale read flag when statement execution finish + sessVars.StmtCtx.IsStaleness = false + // Clean the MPP query info + sessVars.StmtCtx.MPPQueryInfo.QueryID.Store(0) + sessVars.StmtCtx.MPPQueryInfo.QueryTS.Store(0) + sessVars.StmtCtx.MPPQueryInfo.AllocatedMPPTaskID.Store(0) + sessVars.StmtCtx.MPPQueryInfo.AllocatedMPPGatherID.Store(0) + + if sessVars.StmtCtx.ReadFromTableCache { + metrics.ReadFromTableCacheCounter.Inc() + } + + // Update fair locking related counters by stmt + if execDetail.LockKeysDetail != nil { + if execDetail.LockKeysDetail.AggressiveLockNewCount > 0 || execDetail.LockKeysDetail.AggressiveLockDerivedCount > 0 { + executor_metrics.FairLockingStmtUsedCount.Inc() + // If this statement is finished when some of the keys are locked with conflict in the last retry, or + // some of the keys are derived from the previous retry, we consider the optimization of fair locking + // takes effect on this statement. + if execDetail.LockKeysDetail.LockedWithConflictCount > 0 || execDetail.LockKeysDetail.AggressiveLockDerivedCount > 0 { + executor_metrics.FairLockingStmtEffectiveCount.Inc() + } + } + } + // If the transaction is committed, update fair locking related counters by txn + if execDetail.CommitDetail != nil { + if sessVars.TxnCtx.FairLockingUsed { + executor_metrics.FairLockingTxnUsedCount.Inc() + } + if sessVars.TxnCtx.FairLockingEffective { + executor_metrics.FairLockingTxnEffectiveCount.Inc() + } + } + + a.Ctx.ReportUsageStats() +} + +func (a *ExecStmt) recordLastQueryInfo(err error) { + sessVars := a.Ctx.GetSessionVars() + // Record diagnostic information for DML statements + recordLastQuery := false + switch typ := a.StmtNode.(type) { + case *ast.ShowStmt: + recordLastQuery = typ.Tp != ast.ShowSessionStates + case *ast.ExecuteStmt, ast.DMLNode: + recordLastQuery = true + } + if recordLastQuery { + var lastRUConsumption float64 + if ruDetailRaw := a.GoCtx.Value(util.RUDetailsCtxKey); ruDetailRaw != nil { + ruDetail := ruDetailRaw.(*util.RUDetails) + lastRUConsumption = ruDetail.RRU() + ruDetail.WRU() + } + failpoint.Inject("mockRUConsumption", func(_ failpoint.Value) { + lastRUConsumption = float64(len(sessVars.StmtCtx.OriginalSQL)) + }) + // Keep the previous queryInfo for `show session_states` because the statement needs to encode it. + sessVars.LastQueryInfo = sessionstates.QueryInfo{ + TxnScope: sessVars.CheckAndGetTxnScope(), + StartTS: sessVars.TxnCtx.StartTS, + ForUpdateTS: sessVars.TxnCtx.GetForUpdateTS(), + RUConsumption: lastRUConsumption, + } + if err != nil { + sessVars.LastQueryInfo.ErrMsg = err.Error() + } + } +} + +func (a *ExecStmt) checkPlanReplayerCapture(txnTS uint64) { + if kv.GetInternalSourceType(a.GoCtx) == kv.InternalTxnStats { + return + } + se := a.Ctx + if !se.GetSessionVars().InRestrictedSQL && se.GetSessionVars().IsPlanReplayerCaptureEnabled() { + stmtNode := a.GetStmtNode() + if se.GetSessionVars().EnablePlanReplayedContinuesCapture { + if checkPlanReplayerContinuesCaptureValidStmt(stmtNode) { + checkPlanReplayerContinuesCapture(se, stmtNode, txnTS) + } + } else { + checkPlanReplayerCaptureTask(se, stmtNode, txnTS) + } + } +} + +// CloseRecordSet will finish the execution of current statement and do some record work +func (a *ExecStmt) CloseRecordSet(txnStartTS uint64, lastErr error) { + a.FinishExecuteStmt(txnStartTS, lastErr, false) + a.logAudit() + a.Ctx.GetSessionVars().StmtCtx.DetachMemDiskTracker() +} + +// Clean CTE storage shared by different CTEFullScan executor within a SQL stmt. 
+// Will return err in two situations: +// 1. Got err when remove disk spill file. +// 2. Some logical error like ref count of CTEStorage is less than 0. +func resetCTEStorageMap(se sessionctx.Context) error { + tmp := se.GetSessionVars().StmtCtx.CTEStorageMap + if tmp == nil { + // Close() is already called, so no need to reset. Such as TraceExec. + return nil + } + storageMap, ok := tmp.(map[int]*CTEStorages) + if !ok { + return errors.New("type assertion for CTEStorageMap failed") + } + for _, v := range storageMap { + v.ResTbl.Lock() + err1 := v.ResTbl.DerefAndClose() + // Make sure we do not hold the lock for longer than necessary. + v.ResTbl.Unlock() + // No need to lock IterInTbl. + err2 := v.IterInTbl.DerefAndClose() + if err1 != nil { + return err1 + } + if err2 != nil { + return err2 + } + } + se.GetSessionVars().StmtCtx.CTEStorageMap = nil + return nil +} + +// LogSlowQuery is used to print the slow query in the log files. +func (a *ExecStmt) LogSlowQuery(txnTS uint64, succ bool, hasMoreResults bool) { + sessVars := a.Ctx.GetSessionVars() + stmtCtx := sessVars.StmtCtx + level := log.GetLevel() + cfg := config.GetGlobalConfig() + costTime := sessVars.GetTotalCostDuration() + threshold := time.Duration(atomic.LoadUint64(&cfg.Instance.SlowThreshold)) * time.Millisecond + enable := cfg.Instance.EnableSlowLog.Load() + // if the level is Debug, or trace is enabled, print slow logs anyway + force := level <= zapcore.DebugLevel || trace.IsEnabled() + if (!enable || costTime < threshold) && !force { + return + } + sql := FormatSQL(a.GetTextToLog(true)) + _, digest := stmtCtx.SQLDigest() + + var indexNames string + if len(stmtCtx.IndexNames) > 0 { + // remove duplicate index. + idxMap := make(map[string]struct{}) + buf := bytes.NewBuffer(make([]byte, 0, 4)) + buf.WriteByte('[') + for _, idx := range stmtCtx.IndexNames { + _, ok := idxMap[idx] + if ok { + continue + } + idxMap[idx] = struct{}{} + if buf.Len() > 1 { + buf.WriteByte(',') + } + buf.WriteString(idx) + } + buf.WriteByte(']') + indexNames = buf.String() + } + var stmtDetail execdetails.StmtExecDetails + stmtDetailRaw := a.GoCtx.Value(execdetails.StmtExecDetailKey) + if stmtDetailRaw != nil { + stmtDetail = *(stmtDetailRaw.(*execdetails.StmtExecDetails)) + } + var tikvExecDetail util.ExecDetails + tikvExecDetailRaw := a.GoCtx.Value(util.ExecDetailsKey) + if tikvExecDetailRaw != nil { + tikvExecDetail = *(tikvExecDetailRaw.(*util.ExecDetails)) + } + ruDetails := util.NewRUDetails() + if ruDetailsVal := a.GoCtx.Value(util.RUDetailsCtxKey); ruDetailsVal != nil { + ruDetails = ruDetailsVal.(*util.RUDetails) + } + + execDetail := stmtCtx.GetExecDetails() + copTaskInfo := stmtCtx.CopTasksDetails() + memMax := sessVars.MemTracker.MaxConsumed() + diskMax := sessVars.DiskTracker.MaxConsumed() + _, planDigest := GetPlanDigest(stmtCtx) + + binaryPlan := "" + if variable.GenerateBinaryPlan.Load() { + binaryPlan = getBinaryPlan(a.Ctx) + if len(binaryPlan) > 0 { + binaryPlan = variable.SlowLogBinaryPlanPrefix + binaryPlan + variable.SlowLogPlanSuffix + } + } + + resultRows := GetResultRowsCount(stmtCtx, a.Plan) + + var ( + keyspaceName string + keyspaceID uint32 + ) + keyspaceName = keyspace.GetKeyspaceNameBySettings() + if !keyspace.IsKeyspaceNameEmpty(keyspaceName) { + keyspaceID = uint32(a.Ctx.GetStore().GetCodec().GetKeyspaceID()) + } + if txnTS == 0 { + // TODO: txnTS maybe ambiguous, consider logging stale-read-ts with a new field in the slow log. 
+ txnTS = sessVars.TxnCtx.StaleReadTs + } + + slowItems := &variable.SlowQueryLogItems{ + TxnTS: txnTS, + KeyspaceName: keyspaceName, + KeyspaceID: keyspaceID, + SQL: sql.String(), + Digest: digest.String(), + TimeTotal: costTime, + TimeParse: sessVars.DurationParse, + TimeCompile: sessVars.DurationCompile, + TimeOptimize: sessVars.DurationOptimization, + TimeWaitTS: sessVars.DurationWaitTS, + IndexNames: indexNames, + CopTasks: &copTaskInfo, + ExecDetail: execDetail, + MemMax: memMax, + DiskMax: diskMax, + Succ: succ, + Plan: getPlanTree(stmtCtx), + PlanDigest: planDigest.String(), + BinaryPlan: binaryPlan, + Prepared: a.isPreparedStmt, + HasMoreResults: hasMoreResults, + PlanFromCache: sessVars.FoundInPlanCache, + PlanFromBinding: sessVars.FoundInBinding, + RewriteInfo: sessVars.RewritePhaseInfo, + KVTotal: time.Duration(atomic.LoadInt64(&tikvExecDetail.WaitKVRespDuration)), + PDTotal: time.Duration(atomic.LoadInt64(&tikvExecDetail.WaitPDRespDuration)), + BackoffTotal: time.Duration(atomic.LoadInt64(&tikvExecDetail.BackoffDuration)), + WriteSQLRespTotal: stmtDetail.WriteSQLRespDuration, + ResultRows: resultRows, + ExecRetryCount: a.retryCount, + IsExplicitTxn: sessVars.TxnCtx.IsExplicit, + IsWriteCacheTable: stmtCtx.WaitLockLeaseTime > 0, + UsedStats: stmtCtx.GetUsedStatsInfo(false), + IsSyncStatsFailed: stmtCtx.IsSyncStatsFailed, + Warnings: collectWarningsForSlowLog(stmtCtx), + ResourceGroupName: sessVars.StmtCtx.ResourceGroupName, + RRU: ruDetails.RRU(), + WRU: ruDetails.WRU(), + WaitRUDuration: ruDetails.RUWaitDuration(), + } + failpoint.Inject("assertSyncStatsFailed", func(val failpoint.Value) { + if val.(bool) { + if !slowItems.IsSyncStatsFailed { + panic("isSyncStatsFailed should be true") + } + } + }) + if a.retryCount > 0 { + slowItems.ExecRetryTime = costTime - sessVars.DurationParse - sessVars.DurationCompile - time.Since(a.retryStartTime) + } + if _, ok := a.StmtNode.(*ast.CommitStmt); ok && sessVars.PrevStmt != nil { + slowItems.PrevStmt = sessVars.PrevStmt.String() + } + slowLog := sessVars.SlowLogFormat(slowItems) + if trace.IsEnabled() { + trace.Log(a.GoCtx, "details", slowLog) + } + logutil.SlowQueryLogger.Warn(slowLog) + if costTime >= threshold { + if sessVars.InRestrictedSQL { + executor_metrics.TotalQueryProcHistogramInternal.Observe(costTime.Seconds()) + executor_metrics.TotalCopProcHistogramInternal.Observe(execDetail.TimeDetail.ProcessTime.Seconds()) + executor_metrics.TotalCopWaitHistogramInternal.Observe(execDetail.TimeDetail.WaitTime.Seconds()) + } else { + executor_metrics.TotalQueryProcHistogramGeneral.Observe(costTime.Seconds()) + executor_metrics.TotalCopProcHistogramGeneral.Observe(execDetail.TimeDetail.ProcessTime.Seconds()) + executor_metrics.TotalCopWaitHistogramGeneral.Observe(execDetail.TimeDetail.WaitTime.Seconds()) + if execDetail.ScanDetail != nil && execDetail.ScanDetail.ProcessedKeys != 0 { + executor_metrics.CopMVCCRatioHistogramGeneral.Observe(float64(execDetail.ScanDetail.TotalKeys) / float64(execDetail.ScanDetail.ProcessedKeys)) + } + } + var userString string + if sessVars.User != nil { + userString = sessVars.User.String() + } + var tableIDs string + if len(stmtCtx.TableIDs) > 0 { + tableIDs = strings.ReplaceAll(fmt.Sprintf("%v", stmtCtx.TableIDs), " ", ",") + } + domain.GetDomain(a.Ctx).LogSlowQuery(&domain.SlowQueryInfo{ + SQL: sql.String(), + Digest: digest.String(), + Start: sessVars.StartTime, + Duration: costTime, + Detail: stmtCtx.GetExecDetails(), + Succ: succ, + ConnID: sessVars.ConnectionID, + SessAlias: sessVars.SessionAlias, + 
TxnTS: txnTS, + User: userString, + DB: sessVars.CurrentDB, + TableIDs: tableIDs, + IndexNames: indexNames, + Internal: sessVars.InRestrictedSQL, + }) + } +} + +func extractMsgFromSQLWarn(sqlWarn *contextutil.SQLWarn) string { + // Currently, this function is only used in collectWarningsForSlowLog. + // collectWarningsForSlowLog can make sure SQLWarn is not nil so no need to add a nil check here. + warn := errors.Cause(sqlWarn.Err) + if x, ok := warn.(*terror.Error); ok && x != nil { + sqlErr := terror.ToSQLError(x) + return sqlErr.Message + } + return warn.Error() +} + +func collectWarningsForSlowLog(stmtCtx *stmtctx.StatementContext) []variable.JSONSQLWarnForSlowLog { + warnings := stmtCtx.GetWarnings() + extraWarnings := stmtCtx.GetExtraWarnings() + res := make([]variable.JSONSQLWarnForSlowLog, len(warnings)+len(extraWarnings)) + for i := range warnings { + res[i].Level = warnings[i].Level + res[i].Message = extractMsgFromSQLWarn(&warnings[i]) + } + for i := range extraWarnings { + res[len(warnings)+i].Level = extraWarnings[i].Level + res[len(warnings)+i].Message = extractMsgFromSQLWarn(&extraWarnings[i]) + res[len(warnings)+i].IsExtra = true + } + return res +} + +// GetResultRowsCount gets the count of the statement result rows. +func GetResultRowsCount(stmtCtx *stmtctx.StatementContext, p base.Plan) int64 { + runtimeStatsColl := stmtCtx.RuntimeStatsColl + if runtimeStatsColl == nil { + return 0 + } + rootPlanID := p.ID() + if !runtimeStatsColl.ExistsRootStats(rootPlanID) { + return 0 + } + rootStats := runtimeStatsColl.GetRootStats(rootPlanID) + return rootStats.GetActRows() +} + +// getFlatPlan generates a FlatPhysicalPlan from the plan stored in stmtCtx.plan, +// then stores it in stmtCtx.flatPlan. +func getFlatPlan(stmtCtx *stmtctx.StatementContext) *plannercore.FlatPhysicalPlan { + pp := stmtCtx.GetPlan() + if pp == nil { + return nil + } + if flat := stmtCtx.GetFlatPlan(); flat != nil { + f := flat.(*plannercore.FlatPhysicalPlan) + return f + } + p := pp.(base.Plan) + flat := plannercore.FlattenPhysicalPlan(p, false) + if flat != nil { + stmtCtx.SetFlatPlan(flat) + return flat + } + return nil +} + +func getBinaryPlan(sCtx sessionctx.Context) string { + stmtCtx := sCtx.GetSessionVars().StmtCtx + binaryPlan := stmtCtx.GetBinaryPlan() + if len(binaryPlan) > 0 { + return binaryPlan + } + flat := getFlatPlan(stmtCtx) + binaryPlan = plannercore.BinaryPlanStrFromFlatPlan(sCtx.GetPlanCtx(), flat) + stmtCtx.SetBinaryPlan(binaryPlan) + return binaryPlan +} + +// getPlanTree will try to get the select plan tree if the plan is select or the select plan of delete/update/insert statement. +func getPlanTree(stmtCtx *stmtctx.StatementContext) string { + cfg := config.GetGlobalConfig() + if atomic.LoadUint32(&cfg.Instance.RecordPlanInSlowLog) == 0 { + return "" + } + planTree, _ := getEncodedPlan(stmtCtx, false) + if len(planTree) == 0 { + return planTree + } + return variable.SlowLogPlanPrefix + planTree + variable.SlowLogPlanSuffix +} + +// GetPlanDigest will try to get the select plan tree if the plan is select or the select plan of delete/update/insert statement. 
+func GetPlanDigest(stmtCtx *stmtctx.StatementContext) (string, *parser.Digest) { + normalized, planDigest := stmtCtx.GetPlanDigest() + if len(normalized) > 0 && planDigest != nil { + return normalized, planDigest + } + flat := getFlatPlan(stmtCtx) + normalized, planDigest = plannercore.NormalizeFlatPlan(flat) + stmtCtx.SetPlanDigest(normalized, planDigest) + return normalized, planDigest +} + +// GetEncodedPlan returned same as getEncodedPlan +func GetEncodedPlan(stmtCtx *stmtctx.StatementContext, genHint bool) (encodedPlan, hintStr string) { + return getEncodedPlan(stmtCtx, genHint) +} + +// getEncodedPlan gets the encoded plan, and generates the hint string if indicated. +func getEncodedPlan(stmtCtx *stmtctx.StatementContext, genHint bool) (encodedPlan, hintStr string) { + var hintSet bool + encodedPlan = stmtCtx.GetEncodedPlan() + hintStr, hintSet = stmtCtx.GetPlanHint() + if len(encodedPlan) > 0 && (!genHint || hintSet) { + return + } + flat := getFlatPlan(stmtCtx) + if len(encodedPlan) == 0 { + encodedPlan = plannercore.EncodeFlatPlan(flat) + stmtCtx.SetEncodedPlan(encodedPlan) + } + if genHint { + hints := plannercore.GenHintsFromFlatPlan(flat) + for _, tableHint := range stmtCtx.OriginalTableHints { + // some hints like 'memory_quota' cannot be extracted from the PhysicalPlan directly, + // so we have to iterate all hints from the customer and keep some other necessary hints. + switch tableHint.HintName.L { + case hint.HintMemoryQuota, hint.HintUseToja, hint.HintNoIndexMerge, + hint.HintMaxExecutionTime, hint.HintIgnoreIndex, hint.HintReadFromStorage, + hint.HintMerge, hint.HintSemiJoinRewrite, hint.HintNoDecorrelate: + hints = append(hints, tableHint) + } + } + + hintStr = hint.RestoreOptimizerHints(hints) + stmtCtx.SetPlanHint(hintStr) + } + return +} + +// SummaryStmt collects statements for information_schema.statements_summary +func (a *ExecStmt) SummaryStmt(succ bool) { + sessVars := a.Ctx.GetSessionVars() + var userString string + if sessVars.User != nil { + userString = sessVars.User.Username + } + + // Internal SQLs must also be recorded to keep the consistency of `PrevStmt` and `PrevStmtDigest`. + if !stmtsummaryv2.Enabled() || ((sessVars.InRestrictedSQL || len(userString) == 0) && !stmtsummaryv2.EnabledInternal()) { + sessVars.SetPrevStmtDigest("") + return + } + // Ignore `PREPARE` statements, but record `EXECUTE` statements. + if _, ok := a.StmtNode.(*ast.PrepareStmt); ok { + return + } + stmtCtx := sessVars.StmtCtx + // Make sure StmtType is filled even if succ is false. + if stmtCtx.StmtType == "" { + stmtCtx.StmtType = ast.GetStmtLabel(a.StmtNode) + } + normalizedSQL, digest := stmtCtx.SQLDigest() + costTime := sessVars.GetTotalCostDuration() + charset, collation := sessVars.GetCharsetInfo() + + var prevSQL, prevSQLDigest string + if _, ok := a.StmtNode.(*ast.CommitStmt); ok { + // If prevSQLDigest is not recorded, it means this `commit` is the first SQL once stmt summary is enabled, + // so it's OK just to ignore it. + if prevSQLDigest = sessVars.GetPrevStmtDigest(); len(prevSQLDigest) == 0 { + return + } + prevSQL = sessVars.PrevStmt.String() + } + sessVars.SetPrevStmtDigest(digest.String()) + + // No need to encode every time, so encode lazily. 
+ planGenerator := func() (p string, h string, e any) { + defer func() { + e = recover() + if e != nil { + logutil.BgLogger().Warn("fail to generate plan info", + zap.Stack("backtrace"), + zap.Any("error", e)) + } + }() + p, h = getEncodedPlan(stmtCtx, !sessVars.InRestrictedSQL) + return + } + var binPlanGen func() string + if variable.GenerateBinaryPlan.Load() { + binPlanGen = func() string { + binPlan := getBinaryPlan(a.Ctx) + return binPlan + } + } + // Generating plan digest is slow, only generate it once if it's 'Point_Get'. + // If it's a point get, different SQLs leads to different plans, so SQL digest + // is enough to distinguish different plans in this case. + var planDigest string + var planDigestGen func() string + if a.Plan.TP() == plancodec.TypePointGet { + planDigestGen = func() string { + _, planDigest := GetPlanDigest(stmtCtx) + return planDigest.String() + } + } else { + _, tmp := GetPlanDigest(stmtCtx) + planDigest = tmp.String() + } + + execDetail := stmtCtx.GetExecDetails() + copTaskInfo := stmtCtx.CopTasksDetails() + memMax := sessVars.MemTracker.MaxConsumed() + diskMax := sessVars.DiskTracker.MaxConsumed() + sql := a.getLazyStmtText() + var stmtDetail execdetails.StmtExecDetails + stmtDetailRaw := a.GoCtx.Value(execdetails.StmtExecDetailKey) + if stmtDetailRaw != nil { + stmtDetail = *(stmtDetailRaw.(*execdetails.StmtExecDetails)) + } + var tikvExecDetail util.ExecDetails + tikvExecDetailRaw := a.GoCtx.Value(util.ExecDetailsKey) + if tikvExecDetailRaw != nil { + tikvExecDetail = *(tikvExecDetailRaw.(*util.ExecDetails)) + } + var ruDetail *util.RUDetails + if ruDetailRaw := a.GoCtx.Value(util.RUDetailsCtxKey); ruDetailRaw != nil { + ruDetail = ruDetailRaw.(*util.RUDetails) + } + + if stmtCtx.WaitLockLeaseTime > 0 { + if execDetail.BackoffSleep == nil { + execDetail.BackoffSleep = make(map[string]time.Duration) + } + execDetail.BackoffSleep["waitLockLeaseForCacheTable"] = stmtCtx.WaitLockLeaseTime + execDetail.BackoffTime += stmtCtx.WaitLockLeaseTime + execDetail.TimeDetail.WaitTime += stmtCtx.WaitLockLeaseTime + } + + resultRows := GetResultRowsCount(stmtCtx, a.Plan) + + var ( + keyspaceName string + keyspaceID uint32 + ) + keyspaceName = keyspace.GetKeyspaceNameBySettings() + if !keyspace.IsKeyspaceNameEmpty(keyspaceName) { + keyspaceID = uint32(a.Ctx.GetStore().GetCodec().GetKeyspaceID()) + } + + stmtExecInfo := &stmtsummary.StmtExecInfo{ + SchemaName: strings.ToLower(sessVars.CurrentDB), + OriginalSQL: &sql, + Charset: charset, + Collation: collation, + NormalizedSQL: normalizedSQL, + Digest: digest.String(), + PrevSQL: prevSQL, + PrevSQLDigest: prevSQLDigest, + PlanGenerator: planGenerator, + BinaryPlanGenerator: binPlanGen, + PlanDigest: planDigest, + PlanDigestGen: planDigestGen, + User: userString, + TotalLatency: costTime, + ParseLatency: sessVars.DurationParse, + CompileLatency: sessVars.DurationCompile, + StmtCtx: stmtCtx, + CopTasks: &copTaskInfo, + ExecDetail: &execDetail, + MemMax: memMax, + DiskMax: diskMax, + StartTime: sessVars.StartTime, + IsInternal: sessVars.InRestrictedSQL, + Succeed: succ, + PlanInCache: sessVars.FoundInPlanCache, + PlanInBinding: sessVars.FoundInBinding, + ExecRetryCount: a.retryCount, + StmtExecDetails: stmtDetail, + ResultRows: resultRows, + TiKVExecDetails: tikvExecDetail, + Prepared: a.isPreparedStmt, + KeyspaceName: keyspaceName, + KeyspaceID: keyspaceID, + RUDetail: ruDetail, + ResourceGroupName: sessVars.StmtCtx.ResourceGroupName, + + PlanCacheUnqualified: sessVars.StmtCtx.PlanCacheUnqualified(), + } + if a.retryCount > 0 
{ + stmtExecInfo.ExecRetryTime = costTime - sessVars.DurationParse - sessVars.DurationCompile - time.Since(a.retryStartTime) + } + stmtsummaryv2.Add(stmtExecInfo) +} + +// GetTextToLog return the query text to log. +func (a *ExecStmt) GetTextToLog(keepHint bool) string { + var sql string + sessVars := a.Ctx.GetSessionVars() + rmode := sessVars.EnableRedactLog + if rmode == errors.RedactLogEnable { + if keepHint { + sql = parser.NormalizeKeepHint(sessVars.StmtCtx.OriginalSQL) + } else { + sql, _ = sessVars.StmtCtx.SQLDigest() + } + } else if sensitiveStmt, ok := a.StmtNode.(ast.SensitiveStmtNode); ok { + sql = sensitiveStmt.SecureText() + } else { + sql = redact.String(rmode, sessVars.StmtCtx.OriginalSQL+sessVars.PlanCacheParams.String()) + } + return sql +} + +// getLazyText is equivalent to `a.GetTextToLog(false)`. Note that the s.Params is a shallow copy of +// `sessVars.PlanCacheParams`, so you can only use the lazy text within the current stmt context. +func (a *ExecStmt) getLazyStmtText() (s variable.LazyStmtText) { + sessVars := a.Ctx.GetSessionVars() + rmode := sessVars.EnableRedactLog + if rmode == errors.RedactLogEnable { + sql, _ := sessVars.StmtCtx.SQLDigest() + s.SetText(sql) + } else if sensitiveStmt, ok := a.StmtNode.(ast.SensitiveStmtNode); ok { + sql := sensitiveStmt.SecureText() + s.SetText(sql) + } else { + s.Redact = rmode + s.SQL = sessVars.StmtCtx.OriginalSQL + s.Params = *sessVars.PlanCacheParams + } + return +} + +// updatePrevStmt is equivalent to `sessVars.PrevStmt = FormatSQL(a.GetTextToLog(false))` +func (a *ExecStmt) updatePrevStmt() { + sessVars := a.Ctx.GetSessionVars() + if sessVars.PrevStmt == nil { + sessVars.PrevStmt = &variable.LazyStmtText{Format: formatSQL} + } + rmode := sessVars.EnableRedactLog + if rmode == errors.RedactLogEnable { + sql, _ := sessVars.StmtCtx.SQLDigest() + sessVars.PrevStmt.SetText(sql) + } else if sensitiveStmt, ok := a.StmtNode.(ast.SensitiveStmtNode); ok { + sql := sensitiveStmt.SecureText() + sessVars.PrevStmt.SetText(sql) + } else { + sessVars.PrevStmt.Update(rmode, sessVars.StmtCtx.OriginalSQL, sessVars.PlanCacheParams) + } +} + +func (a *ExecStmt) observeStmtBeginForTopSQL(ctx context.Context) context.Context { + vars := a.Ctx.GetSessionVars() + sc := vars.StmtCtx + normalizedSQL, sqlDigest := sc.SQLDigest() + normalizedPlan, planDigest := GetPlanDigest(sc) + var sqlDigestByte, planDigestByte []byte + if sqlDigest != nil { + sqlDigestByte = sqlDigest.Bytes() + } + if planDigest != nil { + planDigestByte = planDigest.Bytes() + } + stats := a.Ctx.GetStmtStats() + if !topsqlstate.TopSQLEnabled() { + // To reduce the performance impact on fast plan. + // Drop them does not cause notable accuracy issue in TopSQL. + if IsFastPlan(a.Plan) { + return ctx + } + // Always attach the SQL and plan info uses to catch the running SQL when Top SQL is enabled in execution. + if stats != nil { + stats.OnExecutionBegin(sqlDigestByte, planDigestByte) + } + return topsql.AttachSQLAndPlanInfo(ctx, sqlDigest, planDigest) + } + + if stats != nil { + stats.OnExecutionBegin(sqlDigestByte, planDigestByte) + // This is a special logic prepared for TiKV's SQLExecCount. 
+ sc.KvExecCounter = stats.CreateKvExecCounter(sqlDigestByte, planDigestByte) + } + + isSQLRegistered := sc.IsSQLRegistered.Load() + if !isSQLRegistered { + topsql.RegisterSQL(normalizedSQL, sqlDigest, vars.InRestrictedSQL) + } + sc.IsSQLAndPlanRegistered.Store(true) + if len(normalizedPlan) == 0 { + return ctx + } + topsql.RegisterPlan(normalizedPlan, planDigest) + return topsql.AttachSQLAndPlanInfo(ctx, sqlDigest, planDigest) +} + +func (a *ExecStmt) observeStmtFinishedForTopSQL() { + vars := a.Ctx.GetSessionVars() + if vars == nil { + return + } + if stats := a.Ctx.GetStmtStats(); stats != nil && topsqlstate.TopSQLEnabled() { + sqlDigest, planDigest := a.getSQLPlanDigest() + execDuration := vars.GetTotalCostDuration() + stats.OnExecutionFinished(sqlDigest, planDigest, execDuration) + } +} + +func (a *ExecStmt) getSQLPlanDigest() ([]byte, []byte) { + var sqlDigest, planDigest []byte + vars := a.Ctx.GetSessionVars() + if _, d := vars.StmtCtx.SQLDigest(); d != nil { + sqlDigest = d.Bytes() + } + if _, d := vars.StmtCtx.GetPlanDigest(); d != nil { + planDigest = d.Bytes() + } + return sqlDigest, planDigest +} + +// only allow select/delete/update/insert/execute stmt captured by continues capture +func checkPlanReplayerContinuesCaptureValidStmt(stmtNode ast.StmtNode) bool { + switch stmtNode.(type) { + case *ast.SelectStmt, *ast.DeleteStmt, *ast.UpdateStmt, *ast.InsertStmt, *ast.ExecuteStmt: + return true + default: + return false + } +} + +func checkPlanReplayerCaptureTask(sctx sessionctx.Context, stmtNode ast.StmtNode, startTS uint64) { + dom := domain.GetDomain(sctx) + if dom == nil { + return + } + handle := dom.GetPlanReplayerHandle() + if handle == nil { + return + } + tasks := handle.GetTasks() + if len(tasks) == 0 { + return + } + _, sqlDigest := sctx.GetSessionVars().StmtCtx.SQLDigest() + _, planDigest := sctx.GetSessionVars().StmtCtx.GetPlanDigest() + if sqlDigest == nil || planDigest == nil { + return + } + key := replayer.PlanReplayerTaskKey{ + SQLDigest: sqlDigest.String(), + PlanDigest: planDigest.String(), + } + for _, task := range tasks { + if task.SQLDigest == sqlDigest.String() { + if task.PlanDigest == "*" || task.PlanDigest == planDigest.String() { + sendPlanReplayerDumpTask(key, sctx, stmtNode, startTS, false) + return + } + } + } +} + +func checkPlanReplayerContinuesCapture(sctx sessionctx.Context, stmtNode ast.StmtNode, startTS uint64) { + dom := domain.GetDomain(sctx) + if dom == nil { + return + } + handle := dom.GetPlanReplayerHandle() + if handle == nil { + return + } + _, sqlDigest := sctx.GetSessionVars().StmtCtx.SQLDigest() + _, planDigest := sctx.GetSessionVars().StmtCtx.GetPlanDigest() + key := replayer.PlanReplayerTaskKey{ + SQLDigest: sqlDigest.String(), + PlanDigest: planDigest.String(), + } + existed := sctx.GetSessionVars().CheckPlanReplayerFinishedTaskKey(key) + if existed { + return + } + sendPlanReplayerDumpTask(key, sctx, stmtNode, startTS, true) + sctx.GetSessionVars().AddPlanReplayerFinishedTaskKey(key) +} + +func sendPlanReplayerDumpTask(key replayer.PlanReplayerTaskKey, sctx sessionctx.Context, stmtNode ast.StmtNode, + startTS uint64, isContinuesCapture bool) { + stmtCtx := sctx.GetSessionVars().StmtCtx + handle := sctx.Value(bindinfo.SessionBindInfoKeyType).(bindinfo.SessionBindingHandle) + bindings := handle.GetAllSessionBindings() + dumpTask := &domain.PlanReplayerDumpTask{ + PlanReplayerTaskKey: key, + StartTS: startTS, + TblStats: stmtCtx.TableStats, + SessionBindings: []bindinfo.Bindings{bindings}, + SessionVars: sctx.GetSessionVars(), + 
ExecStmts: []ast.StmtNode{stmtNode}, + DebugTrace: []any{stmtCtx.OptimizerDebugTrace}, + Analyze: false, + IsCapture: true, + IsContinuesCapture: isContinuesCapture, + } + dumpTask.EncodedPlan, _ = GetEncodedPlan(stmtCtx, false) + if execStmtAst, ok := stmtNode.(*ast.ExecuteStmt); ok { + planCacheStmt, err := plannercore.GetPreparedStmt(execStmtAst, sctx.GetSessionVars()) + if err != nil { + logutil.BgLogger().Warn("fail to find prepared ast for dumping plan replayer", zap.String("category", "plan-replayer-capture"), + zap.String("sqlDigest", key.SQLDigest), + zap.String("planDigest", key.PlanDigest), + zap.Error(err)) + } else { + dumpTask.ExecStmts = []ast.StmtNode{planCacheStmt.PreparedAst.Stmt} + } + } + domain.GetDomain(sctx).GetPlanReplayerHandle().SendTask(dumpTask) +} diff --git a/pkg/executor/aggregate/agg_hash_executor.go b/pkg/executor/aggregate/agg_hash_executor.go index 9b05153015803..80c78de579c3d 100644 --- a/pkg/executor/aggregate/agg_hash_executor.go +++ b/pkg/executor/aggregate/agg_hash_executor.go @@ -215,23 +215,23 @@ func (e *HashAggExec) Close() error { } err := e.BaseExecutor.Close() - failpoint.Inject("injectHashAggClosePanic", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("injectHashAggClosePanic")); _err_ == nil { if enabled := val.(bool); enabled { if e.Ctx().GetSessionVars().ConnectionID != 0 { panic(errors.New("test")) } } - }) + } return err } // Open implements the Executor Open interface. func (e *HashAggExec) Open(ctx context.Context) error { - failpoint.Inject("mockHashAggExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockHashAggExecBaseExecutorOpenReturnedError")); _err_ == nil { if val, _ := val.(bool); val { - failpoint.Return(errors.New("mock HashAggExec.baseExecutor.Open returned error")) + return errors.New("mock HashAggExec.baseExecutor.Open returned error") } - }) + } if err := e.BaseExecutor.Open(ctx); err != nil { return err @@ -264,7 +264,7 @@ func (e *HashAggExec) initForUnparallelExec() { e.groupSet, setSize = set.NewStringSetWithMemoryUsage() e.partialResultMap = make(aggfuncs.AggPartialResultMapper) e.bInMap = 0 - failpoint.Inject("ConsumeRandomPanic", nil) + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) e.memTracker.Consume(hack.DefBucketMemoryUsageForMapStrToSlice*(1< | | | | ...... | | partialInputChs + +--------------+ +-+ +-+ +-+ +*/ +type HashAggExec struct { + exec.BaseExecutor + + Sc *stmtctx.StatementContext + PartialAggFuncs []aggfuncs.AggFunc + FinalAggFuncs []aggfuncs.AggFunc + partialResultMap aggfuncs.AggPartialResultMapper + bInMap int64 // indicate there are 2^bInMap buckets in partialResultMap + groupSet set.StringSetWithMemoryUsage + groupKeys []string + cursor4GroupKey int + GroupByItems []expression.Expression + groupKeyBuffer [][]byte + + finishCh chan struct{} + finalOutputCh chan *AfFinalResult + partialOutputChs []chan *aggfuncs.AggPartialResultMapper + inputCh chan *HashAggInput + partialInputChs []chan *chunk.Chunk + partialWorkers []HashAggPartialWorker + finalWorkers []HashAggFinalWorker + DefaultVal *chunk.Chunk + childResult *chunk.Chunk + + // IsChildReturnEmpty indicates whether the child executor only returns an empty input. + IsChildReturnEmpty bool + // After we support parallel execution for aggregation functions with distinct, + // we can remove this attribute. + IsUnparallelExec bool + parallelExecValid bool + prepared atomic.Bool + executed atomic.Bool + + memTracker *memory.Tracker // track memory usage. 
+ diskTracker *disk.Tracker + + stats *HashAggRuntimeStats + + // dataInDisk is the chunks to store row values for spilled data. + // The HashAggExec may be set to `spill mode` multiple times, and all spilled data will be appended to DataInDiskByRows. + dataInDisk *chunk.DataInDiskByChunks + // numOfSpilledChks indicates the number of all the spilled chunks. + numOfSpilledChks int + // offsetOfSpilledChks indicates the offset of the chunk be read from the disk. + // In each round of processing, we need to re-fetch all the chunks spilled in the last one. + offsetOfSpilledChks int + // inSpillMode indicates whether HashAgg is in `spill mode`. + // When HashAgg is in `spill mode`, the size of `partialResultMap` is no longer growing and all the data fetched + // from the child executor is spilled to the disk. + inSpillMode uint32 + // tmpChkForSpill is the temp chunk for spilling. + tmpChkForSpill *chunk.Chunk + // The `inflightChunkSync` calls `Add(1)` when the data fetcher goroutine inserts a chunk into the channel, + // and `Done()` when any partial worker retrieves a chunk from the channel and updates it in the `partialResultMap`. + // In scenarios where it is necessary to wait for all partial workers to finish processing the inflight chunk, + // `inflightChunkSync` can be used for synchronization. + inflightChunkSync *sync.WaitGroup + // spillAction save the Action for spilling. + spillAction *AggSpillDiskAction + // parallelAggSpillAction save the Action for spilling of parallel aggregation. + parallelAggSpillAction *ParallelAggSpillDiskAction + // spillHelper helps to carry out the spill action + spillHelper *parallelHashAggSpillHelper + // isChildDrained indicates whether the all data from child has been taken out. + isChildDrained bool +} + +// Close implements the Executor Close interface. +func (e *HashAggExec) Close() error { + if e.stats != nil { + defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), e.stats) + } + + if e.IsUnparallelExec { + e.childResult = nil + e.groupSet, _ = set.NewStringSetWithMemoryUsage() + e.partialResultMap = nil + if e.memTracker != nil { + e.memTracker.ReplaceBytesUsed(0) + } + if e.dataInDisk != nil { + e.dataInDisk.Close() + } + if e.spillAction != nil { + e.spillAction.SetFinished() + } + e.spillAction, e.tmpChkForSpill = nil, nil + err := e.BaseExecutor.Close() + if err != nil { + return err + } + return nil + } + if e.parallelExecValid { + // `Close` may be called after `Open` without calling `Next` in test. + if e.prepared.CompareAndSwap(false, true) { + close(e.inputCh) + for _, ch := range e.partialOutputChs { + close(ch) + } + for _, ch := range e.partialInputChs { + close(ch) + } + close(e.finalOutputCh) + } + close(e.finishCh) + for _, ch := range e.partialOutputChs { + channel.Clear(ch) + } + for _, ch := range e.partialInputChs { + channel.Clear(ch) + } + channel.Clear(e.finalOutputCh) + e.executed.Store(false) + if e.memTracker != nil { + e.memTracker.ReplaceBytesUsed(0) + } + e.parallelExecValid = false + if e.parallelAggSpillAction != nil { + e.parallelAggSpillAction.SetFinished() + e.parallelAggSpillAction = nil + e.spillHelper.close() + } + } + + err := e.BaseExecutor.Close() + failpoint.Inject("injectHashAggClosePanic", func(val failpoint.Value) { + if enabled := val.(bool); enabled { + if e.Ctx().GetSessionVars().ConnectionID != 0 { + panic(errors.New("test")) + } + } + }) + return err +} + +// Open implements the Executor Open interface. 
+func (e *HashAggExec) Open(ctx context.Context) error { + failpoint.Inject("mockHashAggExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { + if val, _ := val.(bool); val { + failpoint.Return(errors.New("mock HashAggExec.baseExecutor.Open returned error")) + } + }) + + if err := e.BaseExecutor.Open(ctx); err != nil { + return err + } + return e.OpenSelf() +} + +// OpenSelf just opens the hash aggregation executor. +func (e *HashAggExec) OpenSelf() error { + e.prepared.Store(false) + + if e.memTracker != nil { + e.memTracker.Reset() + } else { + e.memTracker = memory.NewTracker(e.ID(), -1) + } + if e.Ctx().GetSessionVars().TrackAggregateMemoryUsage { + e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) + } + + if e.IsUnparallelExec { + e.initForUnparallelExec() + return nil + } + return e.initForParallelExec(e.Ctx()) +} + +func (e *HashAggExec) initForUnparallelExec() { + var setSize int64 + e.groupSet, setSize = set.NewStringSetWithMemoryUsage() + e.partialResultMap = make(aggfuncs.AggPartialResultMapper) + e.bInMap = 0 + failpoint.Inject("ConsumeRandomPanic", nil) + e.memTracker.Consume(hack.DefBucketMemoryUsageForMapStrToSlice*(1< 0 { + e.IsChildReturnEmpty = false + return nil + } + } +} + +// unparallelExec executes hash aggregation algorithm in single thread. +func (e *HashAggExec) unparallelExec(ctx context.Context, chk *chunk.Chunk) error { + chk.Reset() + for { + exprCtx := e.Ctx().GetExprCtx() + if e.prepared.Load() { + // Since we return e.MaxChunkSize() rows every time, so we should not traverse + // `groupSet` because of its randomness. + for ; e.cursor4GroupKey < len(e.groupKeys); e.cursor4GroupKey++ { + partialResults := e.getPartialResults(e.groupKeys[e.cursor4GroupKey]) + if len(e.PartialAggFuncs) == 0 { + chk.SetNumVirtualRows(chk.NumRows() + 1) + } + for i, af := range e.PartialAggFuncs { + if err := af.AppendFinalResult2Chunk(exprCtx.GetEvalCtx(), partialResults[i], chk); err != nil { + return err + } + } + if chk.IsFull() { + e.cursor4GroupKey++ + return nil + } + } + e.resetSpillMode() + } + if e.executed.Load() { + return nil + } + if err := e.execute(ctx); err != nil { + return err + } + if (len(e.groupSet.StringSet) == 0) && len(e.GroupByItems) == 0 { + // If no groupby and no data, we should add an empty group. + // For example: + // "select count(c) from t;" should return one row [0] + // "select count(c) from t group by c1;" should return empty result set. + e.memTracker.Consume(e.groupSet.Insert("")) + e.groupKeys = append(e.groupKeys, "") + } + e.prepared.Store(true) + } +} + +func (e *HashAggExec) resetSpillMode() { + e.cursor4GroupKey, e.groupKeys = 0, e.groupKeys[:0] + var setSize int64 + e.groupSet, setSize = set.NewStringSetWithMemoryUsage() + e.partialResultMap = make(aggfuncs.AggPartialResultMapper) + e.bInMap = 0 + e.prepared.Store(false) + e.executed.Store(e.numOfSpilledChks == e.dataInDisk.NumChunks()) // No data is spilling again, all data have been processed. + e.numOfSpilledChks = e.dataInDisk.NumChunks() + e.memTracker.ReplaceBytesUsed(setSize) + atomic.StoreUint32(&e.inSpillMode, 0) +} + +// execute fetches Chunks from src and update each aggregate function for each row in Chunk. 
+func (e *HashAggExec) execute(ctx context.Context) (err error) { + defer func() { + if e.tmpChkForSpill.NumRows() > 0 && err == nil { + err = e.dataInDisk.Add(e.tmpChkForSpill) + e.tmpChkForSpill.Reset() + } + }() + exprCtx := e.Ctx().GetExprCtx() + for { + mSize := e.childResult.MemoryUsage() + if err := e.getNextChunk(ctx); err != nil { + return err + } + failpoint.Inject("ConsumeRandomPanic", nil) + e.memTracker.Consume(e.childResult.MemoryUsage() - mSize) + if err != nil { + return err + } + + failpoint.Inject("unparallelHashAggError", func(val failpoint.Value) { + if val, _ := val.(bool); val { + failpoint.Return(errors.New("HashAggExec.unparallelExec error")) + } + }) + + // no more data. + if e.childResult.NumRows() == 0 { + return nil + } + e.groupKeyBuffer, err = GetGroupKey(e.Ctx(), e.childResult, e.groupKeyBuffer, e.GroupByItems) + if err != nil { + return err + } + + allMemDelta := int64(0) + sel := make([]int, 0, e.childResult.NumRows()) + var tmpBuf [1]chunk.Row + for j := 0; j < e.childResult.NumRows(); j++ { + groupKey := string(e.groupKeyBuffer[j]) // do memory copy here, because e.groupKeyBuffer may be reused. + if !e.groupSet.Exist(groupKey) { + if atomic.LoadUint32(&e.inSpillMode) == 1 && e.groupSet.Count() > 0 { + sel = append(sel, j) + continue + } + allMemDelta += e.groupSet.Insert(groupKey) + e.groupKeys = append(e.groupKeys, groupKey) + } + partialResults := e.getPartialResults(groupKey) + for i, af := range e.PartialAggFuncs { + tmpBuf[0] = e.childResult.GetRow(j) + memDelta, err := af.UpdatePartialResult(exprCtx.GetEvalCtx(), tmpBuf[:], partialResults[i]) + if err != nil { + return err + } + allMemDelta += memDelta + } + } + + // spill unprocessed data when exceeded. + if len(sel) > 0 { + e.childResult.SetSel(sel) + err = e.spillUnprocessedData(len(sel) == cap(sel)) + if err != nil { + return err + } + } + + failpoint.Inject("ConsumeRandomPanic", nil) + e.memTracker.Consume(allMemDelta) + } +} + +func (e *HashAggExec) spillUnprocessedData(isFullChk bool) (err error) { + if isFullChk { + return e.dataInDisk.Add(e.childResult) + } + for i := 0; i < e.childResult.NumRows(); i++ { + e.tmpChkForSpill.AppendRow(e.childResult.GetRow(i)) + if e.tmpChkForSpill.IsFull() { + err = e.dataInDisk.Add(e.tmpChkForSpill) + if err != nil { + return err + } + e.tmpChkForSpill.Reset() + } + } + return nil +} + +func (e *HashAggExec) getNextChunk(ctx context.Context) (err error) { + e.childResult.Reset() + if !e.isChildDrained { + if err := exec.Next(ctx, e.Children(0), e.childResult); err != nil { + return err + } + if e.childResult.NumRows() != 0 { + return nil + } + e.isChildDrained = true + } + if e.offsetOfSpilledChks < e.numOfSpilledChks { + e.childResult, err = e.dataInDisk.GetChunk(e.offsetOfSpilledChks) + if err != nil { + return err + } + e.offsetOfSpilledChks++ + } + return nil +} + +func (e *HashAggExec) getPartialResults(groupKey string) []aggfuncs.PartialResult { + partialResults, ok := e.partialResultMap[groupKey] + allMemDelta := int64(0) + if !ok { + partialResults = make([]aggfuncs.PartialResult, 0, len(e.PartialAggFuncs)) + for _, af := range e.PartialAggFuncs { + partialResult, memDelta := af.AllocPartialResult() + partialResults = append(partialResults, partialResult) + allMemDelta += memDelta + } + // Map will expand when count > bucketNum * loadFactor. The memory usage will doubled. 
+ if len(e.partialResultMap)+1 > (1< 0 { + return true + } + } + return false +} diff --git a/pkg/executor/aggregate/agg_hash_final_worker.go b/pkg/executor/aggregate/agg_hash_final_worker.go index d2cc2cb047f10..6cb21c8045069 100644 --- a/pkg/executor/aggregate/agg_hash_final_worker.go +++ b/pkg/executor/aggregate/agg_hash_final_worker.go @@ -128,7 +128,7 @@ func (w *HashAggFinalWorker) consumeIntermData(sctx sessionctx.Context) error { return nil } - failpoint.Inject("ConsumeRandomPanic", nil) + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) if err := w.mergeInputIntoResultMap(sctx, input); err != nil { return err @@ -168,7 +168,7 @@ func (w *HashAggFinalWorker) sendFinalResult(sctx sessionctx.Context) { return } - failpoint.Inject("ConsumeRandomPanic", nil) + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) execStart := time.Now() updateExecTime(w.stats, execStart) @@ -253,7 +253,7 @@ func (w *HashAggFinalWorker) cleanup(start time.Time, waitGroup *sync.WaitGroup) } func intestBeforeFinalWorkerStart() { - failpoint.Inject("enableAggSpillIntest", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("enableAggSpillIntest")); _err_ == nil { if val.(bool) { num := rand.Intn(50) if num < 3 { @@ -262,11 +262,11 @@ func intestBeforeFinalWorkerStart() { time.Sleep(1 * time.Millisecond) } } - }) + } } func (w *HashAggFinalWorker) intestDuringFinalWorkerRun(err *error) { - failpoint.Inject("enableAggSpillIntest", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("enableAggSpillIntest")); _err_ == nil { if val.(bool) { num := rand.Intn(10000) if num < 5 { @@ -279,5 +279,5 @@ func (w *HashAggFinalWorker) intestDuringFinalWorkerRun(err *error) { *err = errors.New("Random fail is triggered in final worker") } } - }) + } } diff --git a/pkg/executor/aggregate/agg_hash_final_worker.go__failpoint_stash__ b/pkg/executor/aggregate/agg_hash_final_worker.go__failpoint_stash__ new file mode 100644 index 0000000000000..d2cc2cb047f10 --- /dev/null +++ b/pkg/executor/aggregate/agg_hash_final_worker.go__failpoint_stash__ @@ -0,0 +1,283 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggregate + +import ( + "math/rand" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/executor/aggfuncs" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/hack" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.uber.org/zap" +) + +// AfFinalResult indicates aggregation functions final result. +type AfFinalResult struct { + chk *chunk.Chunk + err error + giveBackCh chan *chunk.Chunk +} + +// HashAggFinalWorker indicates the final workers of parallel hash agg execution, +// the number of the worker can be set by `tidb_hashagg_final_concurrency`. 
+type HashAggFinalWorker struct { + baseHashAggWorker + + mutableRow chunk.MutRow + partialResultMap aggfuncs.AggPartialResultMapper + BInMap int + inputCh chan *aggfuncs.AggPartialResultMapper + outputCh chan *AfFinalResult + finalResultHolderCh chan *chunk.Chunk + + spillHelper *parallelHashAggSpillHelper + + restoredAggResultMapperMem int64 +} + +func (w *HashAggFinalWorker) getInputFromDisk(sctx sessionctx.Context) (ret aggfuncs.AggPartialResultMapper, restoredMem int64, err error) { + ret, restoredMem, err = w.spillHelper.restoreOnePartition(sctx) + w.intestDuringFinalWorkerRun(&err) + return ret, restoredMem, err +} + +func (w *HashAggFinalWorker) getPartialInput() (input *aggfuncs.AggPartialResultMapper, ok bool) { + waitStart := time.Now() + defer updateWaitTime(w.stats, waitStart) + select { + case <-w.finishCh: + return nil, false + case input, ok = <-w.inputCh: + if !ok { + return nil, false + } + } + return +} + +func (w *HashAggFinalWorker) initBInMap() { + w.BInMap = 0 + mapLen := len(w.partialResultMap) + for mapLen > (1< (1<= 2 && num < 4 { + time.Sleep(1 * time.Millisecond) + } + } + }) +} + +func (w *HashAggPartialWorker) finalizeWorkerProcess(needShuffle bool, finalConcurrency int, hasError bool) { + // Consume all chunks to avoid hang of fetcher + for range w.inputCh { + w.inflightChunkSync.Done() + } + + if w.checkFinishChClosed() { + return + } + + if hasError { + return + } + + if needShuffle && w.spillHelper.isSpilledChunksIOEmpty() { + w.shuffleIntermData(finalConcurrency) + } +} + +func (w *HashAggPartialWorker) run(ctx sessionctx.Context, waitGroup *sync.WaitGroup, finalConcurrency int) { + start := time.Now() + hasError := false + needShuffle := false + + defer func() { + if r := recover(); r != nil { + recoveryHashAgg(w.globalOutputCh, r) + } + + w.finalizeWorkerProcess(needShuffle, finalConcurrency, hasError) + + w.memTracker.Consume(-w.chk.MemoryUsage()) + updateWorkerTime(w.stats, start) + + // We must ensure that there is no panic before `waitGroup.Done()` or there will be hang + waitGroup.Done() + + tryRecycleBuffer(&w.partialResultsBuffer, &w.groupKeyBuf) + }() + + intestBeforePartialWorkerRun() + + for w.fetchChunkAndProcess(ctx, &hasError, &needShuffle) { + } +} + +// If the group key has appeared before, reuse the partial result. +// If the group key has not appeared before, create empty partial results. +func (w *HashAggPartialWorker) getPartialResultsOfEachRow(groupKey [][]byte, finalConcurrency int) [][]aggfuncs.PartialResult { + mapper := w.partialResultsMap + numRows := len(groupKey) + allMemDelta := int64(0) + w.partialResultsBuffer = w.partialResultsBuffer[0:0] + + for i := 0; i < numRows; i++ { + finalWorkerIdx := int(murmur3.Sum32(groupKey[i])) % finalConcurrency + tmp, ok := mapper[finalWorkerIdx][string(hack.String(groupKey[i]))] + + // This group by key has appeared before, reuse the partial result. + if ok { + w.partialResultsBuffer = append(w.partialResultsBuffer, tmp) + continue + } + + // It's the first time that this group by key appeared, create it + w.partialResultsBuffer = append(w.partialResultsBuffer, make([]aggfuncs.PartialResult, w.partialResultNumInRow)) + lastIdx := len(w.partialResultsBuffer) - 1 + for j, af := range w.aggFuncs { + partialResult, memDelta := af.AllocPartialResult() + w.partialResultsBuffer[lastIdx][j] = partialResult + allMemDelta += memDelta // the memory usage of PartialResult + } + allMemDelta += int64(w.partialResultNumInRow * 8) + + // Map will expand when count > bucketNum * loadFactor. 
The memory usage will double. + if len(mapper[finalWorkerIdx])+1 > (1< 0 { + err := w.spilledChunksIO[i].Add(w.tmpChksForSpill[i]) + if err != nil { + return err + } + w.tmpChksForSpill[i].Reset() + } + } + return nil +} + +func (w *HashAggPartialWorker) processError(err error) { + w.globalOutputCh <- &AfFinalResult{err: err} + w.spillHelper.setError() +} diff --git a/pkg/executor/aggregate/agg_stream_executor.go b/pkg/executor/aggregate/agg_stream_executor.go index 6e08503325731..9416d8cd60b38 100644 --- a/pkg/executor/aggregate/agg_stream_executor.go +++ b/pkg/executor/aggregate/agg_stream_executor.go @@ -53,11 +53,11 @@ type StreamAggExec struct { // Open implements the Executor Open interface. func (e *StreamAggExec) Open(ctx context.Context) error { - failpoint.Inject("mockStreamAggExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockStreamAggExecBaseExecutorOpenReturnedError")); _err_ == nil { if val, _ := val.(bool); val { - failpoint.Return(errors.New("mock StreamAggExec.baseExecutor.Open returned error")) + return errors.New("mock StreamAggExec.baseExecutor.Open returned error") } - }) + } if err := e.BaseExecutor.Open(ctx); err != nil { return err @@ -91,7 +91,7 @@ func (e *StreamAggExec) OpenSelf() error { if e.Ctx().GetSessionVars().TrackAggregateMemoryUsage { e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) } - failpoint.Inject("ConsumeRandomPanic", nil) + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) e.memTracker.Consume(e.childResult.MemoryUsage() + e.memUsageOfInitialPartialResult) return nil } @@ -179,7 +179,7 @@ func (e *StreamAggExec) consumeGroupRows() error { } allMemDelta += memDelta } - failpoint.Inject("ConsumeRandomPanic", nil) + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) e.memTracker.Consume(allMemDelta) e.groupRows = e.groupRows[:0] return nil @@ -194,7 +194,7 @@ func (e *StreamAggExec) consumeCurGroupRowsAndFetchChild(ctx context.Context, ch mSize := e.childResult.MemoryUsage() err = exec.Next(ctx, e.Children(0), e.childResult) - failpoint.Inject("ConsumeRandomPanic", nil) + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) e.memTracker.Consume(e.childResult.MemoryUsage() - mSize) if err != nil { return err @@ -227,7 +227,7 @@ func (e *StreamAggExec) appendResult2Chunk(chk *chunk.Chunk) error { } aggFunc.ResetPartialResult(e.partialResults[i]) } - failpoint.Inject("ConsumeRandomPanic", nil) + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) // All partial results have been reset, so reset the memory usage. e.memTracker.ReplaceBytesUsed(e.childResult.MemoryUsage() + e.memUsageOfInitialPartialResult) if len(e.AggFuncs) == 0 { diff --git a/pkg/executor/aggregate/agg_stream_executor.go__failpoint_stash__ b/pkg/executor/aggregate/agg_stream_executor.go__failpoint_stash__ new file mode 100644 index 0000000000000..6e08503325731 --- /dev/null +++ b/pkg/executor/aggregate/agg_stream_executor.go__failpoint_stash__ @@ -0,0 +1,237 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package aggregate + +import ( + "context" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/executor/aggfuncs" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/executor/internal/vecgroupchecker" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/memory" +) + +// StreamAggExec deals with all the aggregate functions. +// It assumes all the input data is sorted by group by key. +// When Next() is called, it will return a result for the same group. +type StreamAggExec struct { + exec.BaseExecutor + + executed bool + // IsChildReturnEmpty indicates whether the child executor only returns an empty input. + IsChildReturnEmpty bool + DefaultVal *chunk.Chunk + GroupChecker *vecgroupchecker.VecGroupChecker + inputIter *chunk.Iterator4Chunk + inputRow chunk.Row + AggFuncs []aggfuncs.AggFunc + partialResults []aggfuncs.PartialResult + groupRows []chunk.Row + childResult *chunk.Chunk + + memTracker *memory.Tracker // track memory usage. + // memUsageOfInitialPartialResult indicates the memory usage of all partial results after initialization. + // All partial results will be reset after processing one group data, and the memory usage should also be reset. + // We can't get memory delta from ResetPartialResult, so record the memory usage here. + memUsageOfInitialPartialResult int64 +} + +// Open implements the Executor Open interface. +func (e *StreamAggExec) Open(ctx context.Context) error { + failpoint.Inject("mockStreamAggExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { + if val, _ := val.(bool); val { + failpoint.Return(errors.New("mock StreamAggExec.baseExecutor.Open returned error")) + } + }) + + if err := e.BaseExecutor.Open(ctx); err != nil { + return err + } + // If panic in Open, the children executor should be closed because they are open. + defer closeBaseExecutor(&e.BaseExecutor) + return e.OpenSelf() +} + +// OpenSelf just opens the StreamAggExec. +func (e *StreamAggExec) OpenSelf() error { + e.childResult = exec.TryNewCacheChunk(e.Children(0)) + e.executed = false + e.IsChildReturnEmpty = true + e.inputIter = chunk.NewIterator4Chunk(e.childResult) + e.inputRow = e.inputIter.End() + + e.partialResults = make([]aggfuncs.PartialResult, 0, len(e.AggFuncs)) + for _, aggFunc := range e.AggFuncs { + partialResult, memDelta := aggFunc.AllocPartialResult() + e.partialResults = append(e.partialResults, partialResult) + e.memUsageOfInitialPartialResult += memDelta + } + + if e.memTracker != nil { + e.memTracker.Reset() + } else { + // bytesLimit <= 0 means no limit, for now we just track the memory footprint + e.memTracker = memory.NewTracker(e.ID(), -1) + } + if e.Ctx().GetSessionVars().TrackAggregateMemoryUsage { + e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) + } + failpoint.Inject("ConsumeRandomPanic", nil) + e.memTracker.Consume(e.childResult.MemoryUsage() + e.memUsageOfInitialPartialResult) + return nil +} + +// Close implements the Executor Close interface. +func (e *StreamAggExec) Close() error { + if e.childResult != nil { + e.memTracker.Consume(-e.childResult.MemoryUsage() - e.memUsageOfInitialPartialResult) + e.childResult = nil + } + e.GroupChecker.Reset() + return e.BaseExecutor.Close() +} + +// Next implements the Executor Next interface. 
+func (e *StreamAggExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { + req.Reset() + for !e.executed && !req.IsFull() { + err = e.consumeOneGroup(ctx, req) + if err != nil { + e.executed = true + return err + } + } + return nil +} + +func (e *StreamAggExec) consumeOneGroup(ctx context.Context, chk *chunk.Chunk) (err error) { + if e.GroupChecker.IsExhausted() { + if err = e.consumeCurGroupRowsAndFetchChild(ctx, chk); err != nil { + return err + } + if e.executed { + return nil + } + _, err := e.GroupChecker.SplitIntoGroups(e.childResult) + if err != nil { + return err + } + } + begin, end := e.GroupChecker.GetNextGroup() + for i := begin; i < end; i++ { + e.groupRows = append(e.groupRows, e.childResult.GetRow(i)) + } + + for meetLastGroup := end == e.childResult.NumRows(); meetLastGroup; { + meetLastGroup = false + if err = e.consumeCurGroupRowsAndFetchChild(ctx, chk); err != nil || e.executed { + return err + } + + isFirstGroupSameAsPrev, err := e.GroupChecker.SplitIntoGroups(e.childResult) + if err != nil { + return err + } + + if isFirstGroupSameAsPrev { + begin, end = e.GroupChecker.GetNextGroup() + for i := begin; i < end; i++ { + e.groupRows = append(e.groupRows, e.childResult.GetRow(i)) + } + meetLastGroup = end == e.childResult.NumRows() + } + } + + err = e.consumeGroupRows() + if err != nil { + return err + } + + return e.appendResult2Chunk(chk) +} + +func (e *StreamAggExec) consumeGroupRows() error { + if len(e.groupRows) == 0 { + return nil + } + + allMemDelta := int64(0) + exprCtx := e.Ctx().GetExprCtx() + for i, aggFunc := range e.AggFuncs { + memDelta, err := aggFunc.UpdatePartialResult(exprCtx.GetEvalCtx(), e.groupRows, e.partialResults[i]) + if err != nil { + return err + } + allMemDelta += memDelta + } + failpoint.Inject("ConsumeRandomPanic", nil) + e.memTracker.Consume(allMemDelta) + e.groupRows = e.groupRows[:0] + return nil +} + +func (e *StreamAggExec) consumeCurGroupRowsAndFetchChild(ctx context.Context, chk *chunk.Chunk) (err error) { + // Before fetching a new batch of input, we should consume the last group. + err = e.consumeGroupRows() + if err != nil { + return err + } + + mSize := e.childResult.MemoryUsage() + err = exec.Next(ctx, e.Children(0), e.childResult) + failpoint.Inject("ConsumeRandomPanic", nil) + e.memTracker.Consume(e.childResult.MemoryUsage() - mSize) + if err != nil { + return err + } + + // No more data. + if e.childResult.NumRows() == 0 { + if !e.IsChildReturnEmpty { + err = e.appendResult2Chunk(chk) + } else if e.DefaultVal != nil { + chk.Append(e.DefaultVal, 0, 1) + } + e.executed = true + return err + } + // Reach here, "e.childrenResults[0].NumRows() > 0" is guaranteed. + e.IsChildReturnEmpty = false + e.inputRow = e.inputIter.Begin() + return nil +} + +// appendResult2Chunk appends result of all the aggregation functions to the +// result chunk, and reset the evaluation context for each aggregation. +func (e *StreamAggExec) appendResult2Chunk(chk *chunk.Chunk) error { + exprCtx := e.Ctx().GetExprCtx() + for i, aggFunc := range e.AggFuncs { + err := aggFunc.AppendFinalResult2Chunk(exprCtx.GetEvalCtx(), e.partialResults[i], chk) + if err != nil { + return err + } + aggFunc.ResetPartialResult(e.partialResults[i]) + } + failpoint.Inject("ConsumeRandomPanic", nil) + // All partial results have been reset, so reset the memory usage. 
+ e.memTracker.ReplaceBytesUsed(e.childResult.MemoryUsage() + e.memUsageOfInitialPartialResult) + if len(e.AggFuncs) == 0 { + chk.SetNumVirtualRows(chk.NumRows() + 1) + } + return nil +} diff --git a/pkg/executor/aggregate/agg_util.go b/pkg/executor/aggregate/agg_util.go index 81b06147ee1d0..9b1568b88bf2a 100644 --- a/pkg/executor/aggregate/agg_util.go +++ b/pkg/executor/aggregate/agg_util.go @@ -281,14 +281,14 @@ func (e *HashAggExec) ActionSpill() memory.ActionOnExceed { func failpointError() error { var err error - failpoint.Inject("enableAggSpillIntest", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("enableAggSpillIntest")); _err_ == nil { if val.(bool) { num := rand.Intn(1000) if num < 3 { err = errors.Errorf("Random fail is triggered in ParallelAggSpillDiskAction") } } - }) + } return err } diff --git a/pkg/executor/aggregate/agg_util.go__failpoint_stash__ b/pkg/executor/aggregate/agg_util.go__failpoint_stash__ new file mode 100644 index 0000000000000..81b06147ee1d0 --- /dev/null +++ b/pkg/executor/aggregate/agg_util.go__failpoint_stash__ @@ -0,0 +1,312 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggregate + +import ( + "bytes" + "cmp" + "fmt" + "math/rand" + "slices" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/executor/aggfuncs" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "go.uber.org/zap" +) + +const defaultPartialResultsBufferCap = 2048 +const defaultGroupKeyCap = 8 + +var partialResultsBufferPool = sync.Pool{ + New: func() any { + s := make([][]aggfuncs.PartialResult, 0, defaultPartialResultsBufferCap) + return &s + }, +} + +var groupKeyPool = sync.Pool{ + New: func() any { + s := make([][]byte, 0, defaultGroupKeyCap) + return &s + }, +} + +func getBuffer() (*[][]aggfuncs.PartialResult, *[][]byte) { + partialResultsBuffer := partialResultsBufferPool.Get().(*[][]aggfuncs.PartialResult) + *partialResultsBuffer = (*partialResultsBuffer)[:0] + groupKey := groupKeyPool.Get().(*[][]byte) + *groupKey = (*groupKey)[:0] + return partialResultsBuffer, groupKey +} + +// tryRecycleBuffer recycles small buffers only. This approach reduces the CPU pressure +// from memory allocation during high concurrency aggregation computations (like DDL's scheduled tasks), +// and also prevents the pool from holding too much memory and causing memory pressure. 
+func tryRecycleBuffer(buf *[][]aggfuncs.PartialResult, groupKey *[][]byte) { + if cap(*buf) <= defaultPartialResultsBufferCap { + partialResultsBufferPool.Put(buf) + } + if cap(*groupKey) <= defaultGroupKeyCap { + groupKeyPool.Put(groupKey) + } +} + +func closeBaseExecutor(b *exec.BaseExecutor) { + if r := recover(); r != nil { + // Release the resource, but throw the panic again and let the top level handle it. + terror.Log(b.Close()) + logutil.BgLogger().Warn("panic in Open(), close base executor and throw exception again") + panic(r) + } +} + +func recoveryHashAgg(output chan *AfFinalResult, r any) { + err := util.GetRecoverError(r) + output <- &AfFinalResult{err: err} + logutil.BgLogger().Error("parallel hash aggregation panicked", zap.Error(err), zap.Stack("stack")) +} + +func getGroupKeyMemUsage(groupKey [][]byte) int64 { + mem := int64(0) + for _, key := range groupKey { + mem += int64(cap(key)) + } + mem += aggfuncs.DefSliceSize * int64(cap(groupKey)) + return mem +} + +// GetGroupKey evaluates the group items and args of aggregate functions. +func GetGroupKey(ctx sessionctx.Context, input *chunk.Chunk, groupKey [][]byte, groupByItems []expression.Expression) ([][]byte, error) { + numRows := input.NumRows() + avlGroupKeyLen := min(len(groupKey), numRows) + for i := 0; i < avlGroupKeyLen; i++ { + groupKey[i] = groupKey[i][:0] + } + for i := avlGroupKeyLen; i < numRows; i++ { + groupKey = append(groupKey, make([]byte, 0, 10*len(groupByItems))) + } + + errCtx := ctx.GetSessionVars().StmtCtx.ErrCtx() + exprCtx := ctx.GetExprCtx() + for _, item := range groupByItems { + tp := item.GetType(ctx.GetExprCtx().GetEvalCtx()) + + buf, err := expression.GetColumn(tp.EvalType(), numRows) + if err != nil { + return nil, err + } + + // In strict sql mode like ‘STRICT_TRANS_TABLES’,can not insert an invalid enum value like 0. + // While in sql mode like '', can insert an invalid enum value like 0, + // then the enum value 0 will have the enum name '', which maybe conflict with user defined enum ''. + // Ref to issue #26885. + // This check is used to handle invalid enum name same with user defined enum name. + // Use enum value as groupKey instead of enum name. + if item.GetType(ctx.GetExprCtx().GetEvalCtx()).GetType() == mysql.TypeEnum { + newTp := *tp + newTp.AddFlag(mysql.EnumSetAsIntFlag) + tp = &newTp + } + + if err := expression.EvalExpr(exprCtx.GetEvalCtx(), ctx.GetSessionVars().EnableVectorizedExpression, item, tp.EvalType(), input, buf); err != nil { + expression.PutColumn(buf) + return nil, err + } + // This check is used to avoid error during the execution of `EncodeDecimal`. 
+ if item.GetType(ctx.GetExprCtx().GetEvalCtx()).GetType() == mysql.TypeNewDecimal { + newTp := *tp + newTp.SetFlen(0) + tp = &newTp + } + + groupKey, err = codec.HashGroupKey(ctx.GetSessionVars().StmtCtx.TimeZone(), input.NumRows(), buf, groupKey, tp) + err = errCtx.HandleError(err) + if err != nil { + expression.PutColumn(buf) + return nil, err + } + expression.PutColumn(buf) + } + return groupKey[:numRows], nil +} + +// HashAggRuntimeStats record the HashAggExec runtime stat +type HashAggRuntimeStats struct { + PartialConcurrency int + PartialWallTime int64 + FinalConcurrency int + FinalWallTime int64 + PartialStats []*AggWorkerStat + FinalStats []*AggWorkerStat +} + +func (*HashAggRuntimeStats) workerString(buf *bytes.Buffer, prefix string, concurrency int, wallTime int64, workerStats []*AggWorkerStat) { + var totalTime, totalWait, totalExec, totalTaskNum int64 + for _, w := range workerStats { + totalTime += w.WorkerTime + totalWait += w.WaitTime + totalExec += w.ExecTime + totalTaskNum += w.TaskNum + } + buf.WriteString(prefix) + fmt.Fprintf(buf, "_worker:{wall_time:%s, concurrency:%d, task_num:%d, tot_wait:%s, tot_exec:%s, tot_time:%s", + time.Duration(wallTime), concurrency, totalTaskNum, time.Duration(totalWait), time.Duration(totalExec), time.Duration(totalTime)) + n := len(workerStats) + if n > 0 { + slices.SortFunc(workerStats, func(i, j *AggWorkerStat) int { return cmp.Compare(i.WorkerTime, j.WorkerTime) }) + fmt.Fprintf(buf, ", max:%v, p95:%v", + time.Duration(workerStats[n-1].WorkerTime), time.Duration(workerStats[n*19/20].WorkerTime)) + } + buf.WriteString("}") +} + +// String implements the RuntimeStats interface. +func (e *HashAggRuntimeStats) String() string { + buf := bytes.NewBuffer(make([]byte, 0, 64)) + e.workerString(buf, "partial", e.PartialConcurrency, atomic.LoadInt64(&e.PartialWallTime), e.PartialStats) + buf.WriteString(", ") + e.workerString(buf, "final", e.FinalConcurrency, atomic.LoadInt64(&e.FinalWallTime), e.FinalStats) + return buf.String() +} + +// Clone implements the RuntimeStats interface. +func (e *HashAggRuntimeStats) Clone() execdetails.RuntimeStats { + newRs := &HashAggRuntimeStats{ + PartialConcurrency: e.PartialConcurrency, + PartialWallTime: atomic.LoadInt64(&e.PartialWallTime), + FinalConcurrency: e.FinalConcurrency, + FinalWallTime: atomic.LoadInt64(&e.FinalWallTime), + PartialStats: make([]*AggWorkerStat, 0, e.PartialConcurrency), + FinalStats: make([]*AggWorkerStat, 0, e.FinalConcurrency), + } + for _, s := range e.PartialStats { + newRs.PartialStats = append(newRs.PartialStats, s.Clone()) + } + for _, s := range e.FinalStats { + newRs.FinalStats = append(newRs.FinalStats, s.Clone()) + } + return newRs +} + +// Merge implements the RuntimeStats interface. +func (e *HashAggRuntimeStats) Merge(other execdetails.RuntimeStats) { + tmp, ok := other.(*HashAggRuntimeStats) + if !ok { + return + } + atomic.AddInt64(&e.PartialWallTime, atomic.LoadInt64(&tmp.PartialWallTime)) + atomic.AddInt64(&e.FinalWallTime, atomic.LoadInt64(&tmp.FinalWallTime)) + e.PartialStats = append(e.PartialStats, tmp.PartialStats...) + e.FinalStats = append(e.FinalStats, tmp.FinalStats...) +} + +// Tp implements the RuntimeStats interface. +func (*HashAggRuntimeStats) Tp() int { + return execdetails.TpHashAggRuntimeStat +} + +// AggWorkerInfo contains the agg worker information. 
+type AggWorkerInfo struct { + Concurrency int + WallTime int64 +} + +// AggWorkerStat record the AggWorker runtime stat +type AggWorkerStat struct { + TaskNum int64 + WaitTime int64 + ExecTime int64 + WorkerTime int64 +} + +// Clone implements the RuntimeStats interface. +func (w *AggWorkerStat) Clone() *AggWorkerStat { + return &AggWorkerStat{ + TaskNum: w.TaskNum, + WaitTime: w.WaitTime, + ExecTime: w.ExecTime, + WorkerTime: w.WorkerTime, + } +} + +func (e *HashAggExec) actionSpillForUnparallel() memory.ActionOnExceed { + e.spillAction = &AggSpillDiskAction{ + e: e, + } + return e.spillAction +} + +func (e *HashAggExec) actionSpillForParallel() memory.ActionOnExceed { + e.parallelAggSpillAction = &ParallelAggSpillDiskAction{ + e: e, + spillHelper: e.spillHelper, + } + return e.parallelAggSpillAction +} + +// ActionSpill returns an action for spilling intermediate data for hashAgg. +func (e *HashAggExec) ActionSpill() memory.ActionOnExceed { + if e.IsUnparallelExec { + return e.actionSpillForUnparallel() + } + return e.actionSpillForParallel() +} + +func failpointError() error { + var err error + failpoint.Inject("enableAggSpillIntest", func(val failpoint.Value) { + if val.(bool) { + num := rand.Intn(1000) + if num < 3 { + err = errors.Errorf("Random fail is triggered in ParallelAggSpillDiskAction") + } + } + }) + return err +} + +func updateWaitTime(stats *AggWorkerStat, startTime time.Time) { + if stats != nil { + stats.WaitTime += int64(time.Since(startTime)) + } +} + +func updateWorkerTime(stats *AggWorkerStat, startTime time.Time) { + if stats != nil { + stats.WorkerTime += int64(time.Since(startTime)) + } +} + +func updateExecTime(stats *AggWorkerStat, startTime time.Time) { + if stats != nil { + stats.ExecTime += int64(time.Since(startTime)) + stats.TaskNum++ + } +} diff --git a/pkg/executor/aggregate/binding__failpoint_binding__.go b/pkg/executor/aggregate/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..f2796e9412af6 --- /dev/null +++ b/pkg/executor/aggregate/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package aggregate + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/executor/analyze.go b/pkg/executor/analyze.go index 9ab4963870dec..ed6c7c31ca7b1 100644 --- a/pkg/executor/analyze.go +++ b/pkg/executor/analyze.go @@ -124,12 +124,12 @@ func (e *AnalyzeExec) Next(ctx context.Context, _ *chunk.Chunk) error { prepareV2AnalyzeJobInfo(task.colExec) AddNewAnalyzeJob(e.Ctx(), task.job) } - failpoint.Inject("mockKillPendingAnalyzeJob", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockKillPendingAnalyzeJob")); _err_ == nil { dom := domain.GetDomain(e.Ctx()) for _, id := range handleutil.GlobalAutoAnalyzeProcessList.All() { dom.SysProcTracker().KillSysProcess(id) } - }) + } TASKLOOP: for _, task := range tasks { select { @@ -154,12 +154,12 @@ TASKLOOP: return err } - failpoint.Inject("mockKillFinishedAnalyzeJob", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockKillFinishedAnalyzeJob")); _err_ == nil { dom := domain.GetDomain(e.Ctx()) for _, id := range handleutil.GlobalAutoAnalyzeProcessList.All() { dom.SysProcTracker().KillSysProcess(id) } - }) + } // If we enabled dynamic prune mode, then we need to generate global stats here for partition tables. 
if needGlobalStats { err = e.handleGlobalStats(statsHandle, globalStatsMap) @@ -415,7 +415,7 @@ func (e *AnalyzeExec) handleResultsError( } } logutil.BgLogger().Info("use single session to save analyze results") - failpoint.Inject("handleResultsErrorSingleThreadPanic", nil) + failpoint.Eval(_curpkg_("handleResultsErrorSingleThreadPanic")) subSctxs := []sessionctx.Context{e.Ctx()} return e.handleResultsErrorWithConcurrency(internalCtx, concurrency, needGlobalStats, subSctxs, globalStatsMap, resultsCh) } @@ -510,7 +510,7 @@ func (e *AnalyzeExec) analyzeWorker(taskCh <-chan *analyzeTask, resultsCh chan<- if !ok { break } - failpoint.Inject("handleAnalyzeWorkerPanic", nil) + failpoint.Eval(_curpkg_("handleAnalyzeWorkerPanic")) statsHandle.StartAnalyzeJob(task.job) switch task.taskType { case colTask: diff --git a/pkg/executor/analyze.go__failpoint_stash__ b/pkg/executor/analyze.go__failpoint_stash__ new file mode 100644 index 0000000000000..9ab4963870dec --- /dev/null +++ b/pkg/executor/analyze.go__failpoint_stash__ @@ -0,0 +1,619 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "context" + stderrors "errors" + "fmt" + "math" + "net" + "strconv" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/statistics" + "github.com/pingcap/tidb/pkg/statistics/handle" + statstypes "github.com/pingcap/tidb/pkg/statistics/handle/types" + handleutil "github.com/pingcap/tidb/pkg/statistics/handle/util" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/sqlescape" + "github.com/pingcap/tipb/go-tipb" + "github.com/tiancaiamao/gp" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" +) + +var _ exec.Executor = &AnalyzeExec{} + +// AnalyzeExec represents Analyze executor. +type AnalyzeExec struct { + exec.BaseExecutor + tasks []*analyzeTask + wg *util.WaitGroupPool + opts map[ast.AnalyzeOptionType]uint64 + OptionsMap map[int64]core.V2AnalyzeOptions + gp *gp.Pool + // errExitCh is used to notice the worker that the whole analyze task is finished when to meet error. + errExitCh chan struct{} +} + +var ( + // RandSeed is the seed for randing package. + // It's public for test. + RandSeed = int64(1) + + // MaxRegionSampleSize is the max sample size for one region when analyze v1 collects samples from table. + // It's public for test. 
+ MaxRegionSampleSize = int64(1000) +) + +type taskType int + +const ( + colTask taskType = iota + idxTask +) + +// Next implements the Executor Next interface. +// It will collect all the sample task and run them concurrently. +func (e *AnalyzeExec) Next(ctx context.Context, _ *chunk.Chunk) error { + statsHandle := domain.GetDomain(e.Ctx()).StatsHandle() + infoSchema := sessiontxn.GetTxnManager(e.Ctx()).GetTxnInfoSchema() + sessionVars := e.Ctx().GetSessionVars() + + // Filter the locked tables. + tasks, needAnalyzeTableCnt, skippedTables, err := filterAndCollectTasks(e.tasks, statsHandle, infoSchema) + if err != nil { + return err + } + warnLockedTableMsg(sessionVars, needAnalyzeTableCnt, skippedTables) + + if len(tasks) == 0 { + return nil + } + + // Get the min number of goroutines for parallel execution. + concurrency, err := getBuildStatsConcurrency(e.Ctx()) + if err != nil { + return err + } + concurrency = min(len(tasks), concurrency) + + // Start workers with channel to collect results. + taskCh := make(chan *analyzeTask, concurrency) + resultsCh := make(chan *statistics.AnalyzeResults, 1) + for i := 0; i < concurrency; i++ { + e.wg.Run(func() { e.analyzeWorker(taskCh, resultsCh) }) + } + pruneMode := variable.PartitionPruneMode(sessionVars.PartitionPruneMode.Load()) + // needGlobalStats used to indicate whether we should merge the partition-level stats to global-level stats. + needGlobalStats := pruneMode == variable.Dynamic + globalStatsMap := make(map[globalStatsKey]statstypes.GlobalStatsInfo) + g, gctx := errgroup.WithContext(ctx) + g.Go(func() error { + return e.handleResultsError(ctx, concurrency, needGlobalStats, globalStatsMap, resultsCh, len(tasks)) + }) + for _, task := range tasks { + prepareV2AnalyzeJobInfo(task.colExec) + AddNewAnalyzeJob(e.Ctx(), task.job) + } + failpoint.Inject("mockKillPendingAnalyzeJob", func() { + dom := domain.GetDomain(e.Ctx()) + for _, id := range handleutil.GlobalAutoAnalyzeProcessList.All() { + dom.SysProcTracker().KillSysProcess(id) + } + }) +TASKLOOP: + for _, task := range tasks { + select { + case taskCh <- task: + case <-e.errExitCh: + break TASKLOOP + case <-gctx.Done(): + break TASKLOOP + } + } + close(taskCh) + defer func() { + for _, task := range tasks { + if task.colExec != nil && task.colExec.memTracker != nil { + task.colExec.memTracker.Detach() + } + } + }() + + err = e.waitFinish(ctx, g, resultsCh) + if err != nil { + return err + } + + failpoint.Inject("mockKillFinishedAnalyzeJob", func() { + dom := domain.GetDomain(e.Ctx()) + for _, id := range handleutil.GlobalAutoAnalyzeProcessList.All() { + dom.SysProcTracker().KillSysProcess(id) + } + }) + // If we enabled dynamic prune mode, then we need to generate global stats here for partition tables. + if needGlobalStats { + err = e.handleGlobalStats(statsHandle, globalStatsMap) + if err != nil { + return err + } + } + + // Update analyze options to mysql.analyze_options for auto analyze. + err = e.saveV2AnalyzeOpts() + if err != nil { + sessionVars.StmtCtx.AppendWarning(err) + } + return statsHandle.Update(ctx, infoSchema) +} + +func (e *AnalyzeExec) waitFinish(ctx context.Context, g *errgroup.Group, resultsCh chan *statistics.AnalyzeResults) error { + checkwg, _ := errgroup.WithContext(ctx) + checkwg.Go(func() error { + // It is to wait for the completion of the result handler. if the result handler meets error, we should cancel + // the analyze process by closing the errExitCh. 
+ err := g.Wait() + if err != nil { + close(e.errExitCh) + return err + } + return nil + }) + checkwg.Go(func() error { + // Wait all workers done and close the results channel. + e.wg.Wait() + close(resultsCh) + return nil + }) + return checkwg.Wait() +} + +// filterAndCollectTasks filters the tasks that are not locked and collects the table IDs. +func filterAndCollectTasks(tasks []*analyzeTask, statsHandle *handle.Handle, is infoschema.InfoSchema) ([]*analyzeTask, uint, []string, error) { + var ( + filteredTasks []*analyzeTask + skippedTables []string + needAnalyzeTableCnt uint + // tidMap is used to deduplicate table IDs. + // In stats v1, analyze for each index is a single task, and they have the same table id. + tidAndPidsMap = make(map[int64]struct{}, len(tasks)) + ) + + lockedTableAndPartitionIDs, err := getLockedTableAndPartitionIDs(statsHandle, tasks) + if err != nil { + return nil, 0, nil, err + } + + for _, task := range tasks { + // Check if the table or partition is locked. + tableID := getTableIDFromTask(task) + _, isLocked := lockedTableAndPartitionIDs[tableID.TableID] + // If the whole table is not locked, we should check whether the partition is locked. + if !isLocked && tableID.IsPartitionTable() { + _, isLocked = lockedTableAndPartitionIDs[tableID.PartitionID] + } + + // Only analyze the table that is not locked. + if !isLocked { + filteredTasks = append(filteredTasks, task) + } + + // Get the physical table ID. + physicalTableID := tableID.TableID + if tableID.IsPartitionTable() { + physicalTableID = tableID.PartitionID + } + if _, ok := tidAndPidsMap[physicalTableID]; !ok { + if isLocked { + if tableID.IsPartitionTable() { + tbl, _, def := is.FindTableByPartitionID(tableID.PartitionID) + if def == nil { + logutil.BgLogger().Warn("Unknown partition ID in analyze task", zap.Int64("pid", tableID.PartitionID)) + } else { + schema, _ := infoschema.SchemaByTable(is, tbl.Meta()) + skippedTables = append(skippedTables, fmt.Sprintf("%s.%s partition (%s)", schema.Name, tbl.Meta().Name.O, def.Name.O)) + } + } else { + tbl, ok := is.TableByID(physicalTableID) + if !ok { + logutil.BgLogger().Warn("Unknown table ID in analyze task", zap.Int64("tid", physicalTableID)) + } else { + schema, _ := infoschema.SchemaByTable(is, tbl.Meta()) + skippedTables = append(skippedTables, fmt.Sprintf("%s.%s", schema.Name, tbl.Meta().Name.O)) + } + } + } else { + needAnalyzeTableCnt++ + } + tidAndPidsMap[physicalTableID] = struct{}{} + } + } + + return filteredTasks, needAnalyzeTableCnt, skippedTables, nil +} + +// getLockedTableAndPartitionIDs queries the locked tables and partitions. +func getLockedTableAndPartitionIDs(statsHandle *handle.Handle, tasks []*analyzeTask) (map[int64]struct{}, error) { + tidAndPids := make([]int64, 0, len(tasks)) + // Check the locked tables in one transaction. + // We need to check all tables and its partitions. + // Because if the whole table is locked, we should skip all partitions. + for _, task := range tasks { + tableID := getTableIDFromTask(task) + tidAndPids = append(tidAndPids, tableID.TableID) + if tableID.IsPartitionTable() { + tidAndPids = append(tidAndPids, tableID.PartitionID) + } + } + return statsHandle.GetLockedTables(tidAndPids...) +} + +// warnLockedTableMsg warns the locked table IDs. 
+func warnLockedTableMsg(sessionVars *variable.SessionVars, needAnalyzeTableCnt uint, skippedTables []string) { + if len(skippedTables) > 0 { + tables := strings.Join(skippedTables, ", ") + var msg string + if len(skippedTables) > 1 { + msg = "skip analyze locked tables: %s" + if needAnalyzeTableCnt > 0 { + msg = "skip analyze locked tables: %s, other tables will be analyzed" + } + } else { + msg = "skip analyze locked table: %s" + } + sessionVars.StmtCtx.AppendWarning(errors.NewNoStackErrorf(msg, tables)) + } +} + +func getTableIDFromTask(task *analyzeTask) statistics.AnalyzeTableID { + switch task.taskType { + case colTask: + return task.colExec.tableID + case idxTask: + return task.idxExec.tableID + } + + panic("unreachable") +} + +func (e *AnalyzeExec) saveV2AnalyzeOpts() error { + if !variable.PersistAnalyzeOptions.Load() || len(e.OptionsMap) == 0 { + return nil + } + // only to save table options if dynamic prune mode + dynamicPrune := variable.PartitionPruneMode(e.Ctx().GetSessionVars().PartitionPruneMode.Load()) == variable.Dynamic + toSaveMap := make(map[int64]core.V2AnalyzeOptions) + for id, opts := range e.OptionsMap { + if !opts.IsPartition || !dynamicPrune { + toSaveMap[id] = opts + } + } + sql := new(strings.Builder) + sqlescape.MustFormatSQL(sql, "REPLACE INTO mysql.analyze_options (table_id,sample_num,sample_rate,buckets,topn,column_choice,column_ids) VALUES ") + idx := 0 + for _, opts := range toSaveMap { + sampleNum := opts.RawOpts[ast.AnalyzeOptNumSamples] + sampleRate := float64(0) + if val, ok := opts.RawOpts[ast.AnalyzeOptSampleRate]; ok { + sampleRate = math.Float64frombits(val) + } + buckets := opts.RawOpts[ast.AnalyzeOptNumBuckets] + topn := int64(-1) + if val, ok := opts.RawOpts[ast.AnalyzeOptNumTopN]; ok { + topn = int64(val) + } + colChoice := opts.ColChoice.String() + colIDs := make([]string, 0, len(opts.ColumnList)) + for _, colInfo := range opts.ColumnList { + colIDs = append(colIDs, strconv.FormatInt(colInfo.ID, 10)) + } + colIDStrs := strings.Join(colIDs, ",") + sqlescape.MustFormatSQL(sql, "(%?,%?,%?,%?,%?,%?,%?)", opts.PhyTableID, sampleNum, sampleRate, buckets, topn, colChoice, colIDStrs) + if idx < len(toSaveMap)-1 { + sqlescape.MustFormatSQL(sql, ",") + } + idx++ + } + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) + exec := e.Ctx().GetRestrictedSQLExecutor() + _, _, err := exec.ExecRestrictedSQL(ctx, nil, sql.String()) + if err != nil { + return err + } + return nil +} + +func recordHistoricalStats(sctx sessionctx.Context, tableID int64) error { + statsHandle := domain.GetDomain(sctx).StatsHandle() + historicalStatsEnabled, err := statsHandle.CheckHistoricalStatsEnable() + if err != nil { + return errors.Errorf("check tidb_enable_historical_stats failed: %v", err) + } + if !historicalStatsEnabled { + return nil + } + historicalStatsWorker := domain.GetDomain(sctx).GetHistoricalStatsWorker() + historicalStatsWorker.SendTblToDumpHistoricalStats(tableID) + return nil +} + +// handleResultsError will handle the error fetch from resultsCh and record it in log +func (e *AnalyzeExec) handleResultsError( + ctx context.Context, + concurrency int, + needGlobalStats bool, + globalStatsMap globalStatsMap, + resultsCh <-chan *statistics.AnalyzeResults, + taskNum int, +) (err error) { + defer func() { + if r := recover(); r != nil { + logutil.BgLogger().Error("analyze save stats panic", zap.Any("recover", r), zap.Stack("stack")) + if err != nil { + err = stderrors.Join(err, getAnalyzePanicErr(r)) + } else { + err = getAnalyzePanicErr(r) + 
} + } + }() + partitionStatsConcurrency := e.Ctx().GetSessionVars().AnalyzePartitionConcurrency + // the concurrency of handleResultsError cannot be more than partitionStatsConcurrency + partitionStatsConcurrency = min(taskNum, partitionStatsConcurrency) + // If partitionStatsConcurrency > 1, we will try to demand extra session from Domain to save Analyze results in concurrency. + // If there is no extra session we can use, we will save analyze results in single-thread. + dom := domain.GetDomain(e.Ctx()) + internalCtx := kv.WithInternalSourceType(ctx, kv.InternalTxnStats) + if partitionStatsConcurrency > 1 { + // FIXME: Since we don't use it either to save analysis results or to store job history, it has no effect. Please remove this :( + subSctxs := dom.FetchAnalyzeExec(partitionStatsConcurrency) + warningMessage := "Insufficient sessions to save analyze results. Consider increasing the 'analyze-partition-concurrency-quota' configuration to improve analyze performance. " + + "This value should typically be greater than or equal to the 'tidb_analyze_partition_concurrency' variable." + if len(subSctxs) < partitionStatsConcurrency { + e.Ctx().GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError(warningMessage)) + logutil.BgLogger().Warn( + warningMessage, + zap.Int("sessionCount", len(subSctxs)), + zap.Int("needSessionCount", partitionStatsConcurrency), + ) + } + if len(subSctxs) > 0 { + sessionCount := len(subSctxs) + logutil.BgLogger().Info("use multiple sessions to save analyze results", zap.Int("sessionCount", sessionCount)) + defer func() { + dom.ReleaseAnalyzeExec(subSctxs) + }() + return e.handleResultsErrorWithConcurrency(internalCtx, concurrency, needGlobalStats, subSctxs, globalStatsMap, resultsCh) + } + } + logutil.BgLogger().Info("use single session to save analyze results") + failpoint.Inject("handleResultsErrorSingleThreadPanic", nil) + subSctxs := []sessionctx.Context{e.Ctx()} + return e.handleResultsErrorWithConcurrency(internalCtx, concurrency, needGlobalStats, subSctxs, globalStatsMap, resultsCh) +} + +func (e *AnalyzeExec) handleResultsErrorWithConcurrency( + ctx context.Context, + statsConcurrency int, + needGlobalStats bool, + subSctxs []sessionctx.Context, + globalStatsMap globalStatsMap, + resultsCh <-chan *statistics.AnalyzeResults, +) error { + partitionStatsConcurrency := len(subSctxs) + statsHandle := domain.GetDomain(e.Ctx()).StatsHandle() + wg := util.NewWaitGroupPool(e.gp) + saveResultsCh := make(chan *statistics.AnalyzeResults, partitionStatsConcurrency) + errCh := make(chan error, partitionStatsConcurrency) + for i := 0; i < partitionStatsConcurrency; i++ { + worker := newAnalyzeSaveStatsWorker(saveResultsCh, subSctxs[i], errCh, &e.Ctx().GetSessionVars().SQLKiller) + ctx1 := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) + wg.Run(func() { + worker.run(ctx1, statsHandle, e.Ctx().GetSessionVars().EnableAnalyzeSnapshot) + }) + } + tableIDs := map[int64]struct{}{} + panicCnt := 0 + var err error + for panicCnt < statsConcurrency { + if err := e.Ctx().GetSessionVars().SQLKiller.HandleSignal(); err != nil { + close(saveResultsCh) + return err + } + results, ok := <-resultsCh + if !ok { + break + } + if results.Err != nil { + err = results.Err + if isAnalyzeWorkerPanic(err) { + panicCnt++ + } else { + logutil.Logger(ctx).Error("analyze failed", zap.Error(err)) + } + finishJobWithLog(statsHandle, results.Job, err) + continue + } + handleGlobalStats(needGlobalStats, globalStatsMap, results) + tableIDs[results.TableID.GetStatisticsID()] = 
struct{}{} + saveResultsCh <- results + } + close(saveResultsCh) + wg.Wait() + close(errCh) + if len(errCh) > 0 { + errMsg := make([]string, 0) + for err1 := range errCh { + errMsg = append(errMsg, err1.Error()) + } + err = errors.New(strings.Join(errMsg, ",")) + } + for tableID := range tableIDs { + // Dump stats to historical storage. + if err := recordHistoricalStats(e.Ctx(), tableID); err != nil { + logutil.BgLogger().Error("record historical stats failed", zap.Error(err)) + } + } + return err +} + +func (e *AnalyzeExec) analyzeWorker(taskCh <-chan *analyzeTask, resultsCh chan<- *statistics.AnalyzeResults) { + var task *analyzeTask + statsHandle := domain.GetDomain(e.Ctx()).StatsHandle() + defer func() { + if r := recover(); r != nil { + logutil.BgLogger().Error("analyze worker panicked", zap.Any("recover", r), zap.Stack("stack")) + metrics.PanicCounter.WithLabelValues(metrics.LabelAnalyze).Inc() + // If errExitCh is closed, it means the whole analyze task is aborted. So we do not need to send the result to resultsCh. + err := getAnalyzePanicErr(r) + select { + case resultsCh <- &statistics.AnalyzeResults{ + Err: err, + Job: task.job, + }: + case <-e.errExitCh: + logutil.BgLogger().Error("analyze worker exits because the whole analyze task is aborted", zap.Error(err)) + } + } + }() + for { + var ok bool + task, ok = <-taskCh + if !ok { + break + } + failpoint.Inject("handleAnalyzeWorkerPanic", nil) + statsHandle.StartAnalyzeJob(task.job) + switch task.taskType { + case colTask: + select { + case <-e.errExitCh: + return + case resultsCh <- analyzeColumnsPushDownEntry(e.gp, task.colExec): + } + case idxTask: + select { + case <-e.errExitCh: + return + case resultsCh <- analyzeIndexPushdown(task.idxExec): + } + } + } +} + +type analyzeTask struct { + taskType taskType + idxExec *AnalyzeIndexExec + colExec *AnalyzeColumnsExec + job *statistics.AnalyzeJob +} + +type baseAnalyzeExec struct { + ctx sessionctx.Context + tableID statistics.AnalyzeTableID + concurrency int + analyzePB *tipb.AnalyzeReq + opts map[ast.AnalyzeOptionType]uint64 + job *statistics.AnalyzeJob + snapshot uint64 +} + +// AddNewAnalyzeJob records the new analyze job. 
+func AddNewAnalyzeJob(ctx sessionctx.Context, job *statistics.AnalyzeJob) { + if job == nil { + return + } + var instance string + serverInfo, err := infosync.GetServerInfo() + if err != nil { + logutil.BgLogger().Error("failed to get server info", zap.Error(err)) + instance = "unknown" + } else { + instance = net.JoinHostPort(serverInfo.IP, strconv.Itoa(int(serverInfo.Port))) + } + statsHandle := domain.GetDomain(ctx).StatsHandle() + err = statsHandle.InsertAnalyzeJob(job, instance, ctx.GetSessionVars().ConnectionID) + if err != nil { + logutil.BgLogger().Error("failed to insert analyze job", zap.Error(err)) + } +} + +func finishJobWithLog(statsHandle *handle.Handle, job *statistics.AnalyzeJob, analyzeErr error) { + statsHandle.FinishAnalyzeJob(job, analyzeErr, statistics.TableAnalysisJob) + if job != nil { + var state string + if analyzeErr != nil { + state = statistics.AnalyzeFailed + logutil.BgLogger().Warn(fmt.Sprintf("analyze table `%s`.`%s` has %s", job.DBName, job.TableName, state), + zap.String("partition", job.PartitionName), + zap.String("job info", job.JobInfo), + zap.Time("start time", job.StartTime), + zap.Time("end time", job.EndTime), + zap.String("cost", job.EndTime.Sub(job.StartTime).String()), + zap.String("sample rate reason", job.SampleRateReason), + zap.Error(analyzeErr)) + } else { + state = statistics.AnalyzeFinished + logutil.BgLogger().Info(fmt.Sprintf("analyze table `%s`.`%s` has %s", job.DBName, job.TableName, state), + zap.String("partition", job.PartitionName), + zap.String("job info", job.JobInfo), + zap.Time("start time", job.StartTime), + zap.Time("end time", job.EndTime), + zap.String("cost", job.EndTime.Sub(job.StartTime).String()), + zap.String("sample rate reason", job.SampleRateReason)) + } + } +} + +func handleGlobalStats(needGlobalStats bool, globalStatsMap globalStatsMap, results *statistics.AnalyzeResults) { + if results.TableID.IsPartitionTable() && needGlobalStats { + for _, result := range results.Ars { + if result.IsIndex == 0 { + // If it does not belong to the statistics of index, we need to set it to -1 to distinguish. + globalStatsID := globalStatsKey{tableID: results.TableID.TableID, indexID: int64(-1)} + histIDs := make([]int64, 0, len(result.Hist)) + for _, hg := range result.Hist { + // It's normal virtual column, skip. 
+ if hg == nil { + continue + } + histIDs = append(histIDs, hg.ID) + } + globalStatsMap[globalStatsID] = statstypes.GlobalStatsInfo{IsIndex: result.IsIndex, HistIDs: histIDs, StatsVersion: results.StatsVer} + } else { + for _, hg := range result.Hist { + globalStatsID := globalStatsKey{tableID: results.TableID.TableID, indexID: hg.ID} + globalStatsMap[globalStatsID] = statstypes.GlobalStatsInfo{IsIndex: result.IsIndex, HistIDs: []int64{hg.ID}, StatsVersion: results.StatsVer} + } + } + } + } +} diff --git a/pkg/executor/analyze_col.go b/pkg/executor/analyze_col.go index 78a5190b8d21b..cab8a5101e3d9 100644 --- a/pkg/executor/analyze_col.go +++ b/pkg/executor/analyze_col.go @@ -176,18 +176,18 @@ func (e *AnalyzeColumnsExec) buildStats(ranges []*ranger.Range, needExtStats boo } statsHandle := domain.GetDomain(e.ctx).StatsHandle() for { - failpoint.Inject("mockKillRunningV1AnalyzeJob", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockKillRunningV1AnalyzeJob")); _err_ == nil { dom := domain.GetDomain(e.ctx) for _, id := range handleutil.GlobalAutoAnalyzeProcessList.All() { dom.SysProcTracker().KillSysProcess(id) } - }) + } if err := e.ctx.GetSessionVars().SQLKiller.HandleSignal(); err != nil { return nil, nil, nil, nil, nil, err } - failpoint.Inject("mockSlowAnalyzeV1", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockSlowAnalyzeV1")); _err_ == nil { time.Sleep(1000 * time.Second) - }) + } data, err1 := e.resultHandler.nextRaw(context.TODO()) if err1 != nil { return nil, nil, nil, nil, nil, err1 diff --git a/pkg/executor/analyze_col.go__failpoint_stash__ b/pkg/executor/analyze_col.go__failpoint_stash__ new file mode 100644 index 0000000000000..78a5190b8d21b --- /dev/null +++ b/pkg/executor/analyze_col.go__failpoint_stash__ @@ -0,0 +1,494 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "context" + "fmt" + "math" + "strings" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/distsql" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/core" + plannerutil "github.com/pingcap/tidb/pkg/planner/util" + "github.com/pingcap/tidb/pkg/statistics" + handleutil "github.com/pingcap/tidb/pkg/statistics/handle/util" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/ranger" + "github.com/pingcap/tipb/go-tipb" + "github.com/tiancaiamao/gp" +) + +// AnalyzeColumnsExec represents Analyze columns push down executor. 
+type AnalyzeColumnsExec struct { + baseAnalyzeExec + + tableInfo *model.TableInfo + colsInfo []*model.ColumnInfo + handleCols plannerutil.HandleCols + commonHandle *model.IndexInfo + resultHandler *tableResultHandler + indexes []*model.IndexInfo + core.AnalyzeInfo + + samplingBuilderWg *notifyErrorWaitGroupWrapper + samplingMergeWg *util.WaitGroupWrapper + + schemaForVirtualColEval *expression.Schema + baseCount int64 + baseModifyCnt int64 + + memTracker *memory.Tracker +} + +func analyzeColumnsPushDownEntry(gp *gp.Pool, e *AnalyzeColumnsExec) *statistics.AnalyzeResults { + if e.AnalyzeInfo.StatsVersion >= statistics.Version2 { + return e.toV2().analyzeColumnsPushDownV2(gp) + } + return e.toV1().analyzeColumnsPushDownV1() +} + +func (e *AnalyzeColumnsExec) toV1() *AnalyzeColumnsExecV1 { + return &AnalyzeColumnsExecV1{ + AnalyzeColumnsExec: e, + } +} + +func (e *AnalyzeColumnsExec) toV2() *AnalyzeColumnsExecV2 { + return &AnalyzeColumnsExecV2{ + AnalyzeColumnsExec: e, + } +} + +func (e *AnalyzeColumnsExec) open(ranges []*ranger.Range) error { + e.memTracker = memory.NewTracker(int(e.ctx.GetSessionVars().PlanID.Load()), -1) + e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker) + e.resultHandler = &tableResultHandler{} + firstPartRanges, secondPartRanges := distsql.SplitRangesAcrossInt64Boundary(ranges, true, false, !hasPkHist(e.handleCols)) + firstResult, err := e.buildResp(firstPartRanges) + if err != nil { + return err + } + if len(secondPartRanges) == 0 { + e.resultHandler.open(nil, firstResult) + return nil + } + var secondResult distsql.SelectResult + secondResult, err = e.buildResp(secondPartRanges) + if err != nil { + return err + } + e.resultHandler.open(firstResult, secondResult) + + return nil +} + +func (e *AnalyzeColumnsExec) buildResp(ranges []*ranger.Range) (distsql.SelectResult, error) { + var builder distsql.RequestBuilder + reqBuilder := builder.SetHandleRangesForTables(e.ctx.GetDistSQLCtx(), []int64{e.TableID.GetStatisticsID()}, e.handleCols != nil && !e.handleCols.IsInt(), ranges) + builder.SetResourceGroupTagger(e.ctx.GetSessionVars().StmtCtx.GetResourceGroupTagger()) + startTS := uint64(math.MaxUint64) + isoLevel := kv.RC + if e.ctx.GetSessionVars().EnableAnalyzeSnapshot { + startTS = e.snapshot + isoLevel = kv.SI + } + // Always set KeepOrder of the request to be true, in order to compute + // correct `correlation` of columns. + kvReq, err := reqBuilder. + SetAnalyzeRequest(e.analyzePB, isoLevel). + SetStartTS(startTS). + SetKeepOrder(true). + SetConcurrency(e.concurrency). + SetMemTracker(e.memTracker). + SetResourceGroupName(e.ctx.GetSessionVars().StmtCtx.ResourceGroupName). + SetExplicitRequestSourceType(e.ctx.GetSessionVars().ExplicitRequestSourceType). 
+ Build() + if err != nil { + return nil, err + } + ctx := context.TODO() + result, err := distsql.Analyze(ctx, e.ctx.GetClient(), kvReq, e.ctx.GetSessionVars().KVVars, e.ctx.GetSessionVars().InRestrictedSQL, e.ctx.GetDistSQLCtx()) + if err != nil { + return nil, err + } + return result, nil +} + +func (e *AnalyzeColumnsExec) buildStats(ranges []*ranger.Range, needExtStats bool) (hists []*statistics.Histogram, cms []*statistics.CMSketch, topNs []*statistics.TopN, fms []*statistics.FMSketch, extStats *statistics.ExtendedStatsColl, err error) { + if err = e.open(ranges); err != nil { + return nil, nil, nil, nil, nil, err + } + defer func() { + if err1 := e.resultHandler.Close(); err1 != nil { + hists = nil + cms = nil + extStats = nil + err = err1 + } + }() + var handleHist *statistics.Histogram + var handleCms *statistics.CMSketch + var handleFms *statistics.FMSketch + var handleTopn *statistics.TopN + statsVer := statistics.Version1 + if e.analyzePB.Tp == tipb.AnalyzeType_TypeMixed { + handleHist = &statistics.Histogram{} + handleCms = statistics.NewCMSketch(int32(e.opts[ast.AnalyzeOptCMSketchDepth]), int32(e.opts[ast.AnalyzeOptCMSketchWidth])) + handleTopn = statistics.NewTopN(int(e.opts[ast.AnalyzeOptNumTopN])) + handleFms = statistics.NewFMSketch(statistics.MaxSketchSize) + if e.analyzePB.IdxReq.Version != nil { + statsVer = int(*e.analyzePB.IdxReq.Version) + } + } + pkHist := &statistics.Histogram{} + collectors := make([]*statistics.SampleCollector, len(e.colsInfo)) + for i := range collectors { + collectors[i] = &statistics.SampleCollector{ + IsMerger: true, + FMSketch: statistics.NewFMSketch(statistics.MaxSketchSize), + MaxSampleSize: int64(e.opts[ast.AnalyzeOptNumSamples]), + CMSketch: statistics.NewCMSketch(int32(e.opts[ast.AnalyzeOptCMSketchDepth]), int32(e.opts[ast.AnalyzeOptCMSketchWidth])), + } + } + statsHandle := domain.GetDomain(e.ctx).StatsHandle() + for { + failpoint.Inject("mockKillRunningV1AnalyzeJob", func() { + dom := domain.GetDomain(e.ctx) + for _, id := range handleutil.GlobalAutoAnalyzeProcessList.All() { + dom.SysProcTracker().KillSysProcess(id) + } + }) + if err := e.ctx.GetSessionVars().SQLKiller.HandleSignal(); err != nil { + return nil, nil, nil, nil, nil, err + } + failpoint.Inject("mockSlowAnalyzeV1", func() { + time.Sleep(1000 * time.Second) + }) + data, err1 := e.resultHandler.nextRaw(context.TODO()) + if err1 != nil { + return nil, nil, nil, nil, nil, err1 + } + if data == nil { + break + } + var colResp *tipb.AnalyzeColumnsResp + if e.analyzePB.Tp == tipb.AnalyzeType_TypeMixed { + resp := &tipb.AnalyzeMixedResp{} + err = resp.Unmarshal(data) + if err != nil { + return nil, nil, nil, nil, nil, err + } + colResp = resp.ColumnsResp + handleHist, handleCms, handleFms, handleTopn, err = updateIndexResult(e.ctx, resp.IndexResp, nil, handleHist, + handleCms, handleFms, handleTopn, e.commonHandle, int(e.opts[ast.AnalyzeOptNumBuckets]), + int(e.opts[ast.AnalyzeOptNumTopN]), statsVer) + + if err != nil { + return nil, nil, nil, nil, nil, err + } + } else { + colResp = &tipb.AnalyzeColumnsResp{} + err = colResp.Unmarshal(data) + } + sc := e.ctx.GetSessionVars().StmtCtx + rowCount := int64(0) + if hasPkHist(e.handleCols) { + respHist := statistics.HistogramFromProto(colResp.PkHist) + rowCount = int64(respHist.TotalRowCount()) + pkHist, err = statistics.MergeHistograms(sc, pkHist, respHist, int(e.opts[ast.AnalyzeOptNumBuckets]), statistics.Version1) + if err != nil { + return nil, nil, nil, nil, nil, err + } + } + for i, rc := range colResp.Collectors { + respSample 
:= statistics.SampleCollectorFromProto(rc) + rowCount = respSample.Count + respSample.NullCount + collectors[i].MergeSampleCollector(sc, respSample) + } + statsHandle.UpdateAnalyzeJobProgress(e.job, rowCount) + } + timeZone := e.ctx.GetSessionVars().Location() + if hasPkHist(e.handleCols) { + pkInfo := e.handleCols.GetCol(0) + pkHist.ID = pkInfo.ID + err = pkHist.DecodeTo(pkInfo.RetType, timeZone) + if err != nil { + return nil, nil, nil, nil, nil, err + } + hists = append(hists, pkHist) + cms = append(cms, nil) + topNs = append(topNs, nil) + fms = append(fms, nil) + } + for i, col := range e.colsInfo { + if e.StatsVersion < 2 { + // In analyze version 2, we don't collect TopN this way. We will collect TopN from samples in `BuildColumnHistAndTopN()` below. + err := collectors[i].ExtractTopN(uint32(e.opts[ast.AnalyzeOptNumTopN]), e.ctx.GetSessionVars().StmtCtx, &col.FieldType, timeZone) + if err != nil { + return nil, nil, nil, nil, nil, err + } + topNs = append(topNs, collectors[i].TopN) + } + for j, s := range collectors[i].Samples { + s.Ordinal = j + s.Value, err = tablecodec.DecodeColumnValue(s.Value.GetBytes(), &col.FieldType, timeZone) + if err != nil { + return nil, nil, nil, nil, nil, err + } + // When collation is enabled, we store the Key representation of the sampling data. So we set it to kind `Bytes` here + // to avoid to convert it to its Key representation once more. + if s.Value.Kind() == types.KindString { + s.Value.SetBytes(s.Value.GetBytes()) + } + } + var hg *statistics.Histogram + var err error + var topn *statistics.TopN + if e.StatsVersion < 2 { + hg, err = statistics.BuildColumn(e.ctx, int64(e.opts[ast.AnalyzeOptNumBuckets]), col.ID, collectors[i], &col.FieldType) + } else { + hg, topn, err = statistics.BuildHistAndTopN(e.ctx, int(e.opts[ast.AnalyzeOptNumBuckets]), int(e.opts[ast.AnalyzeOptNumTopN]), col.ID, collectors[i], &col.FieldType, true, nil, true) + topNs = append(topNs, topn) + } + if err != nil { + return nil, nil, nil, nil, nil, err + } + hists = append(hists, hg) + collectors[i].CMSketch.CalcDefaultValForAnalyze(uint64(hg.NDV)) + cms = append(cms, collectors[i].CMSketch) + fms = append(fms, collectors[i].FMSketch) + } + if needExtStats { + extStats, err = statistics.BuildExtendedStats(e.ctx, e.TableID.GetStatisticsID(), e.colsInfo, collectors) + if err != nil { + return nil, nil, nil, nil, nil, err + } + } + if handleHist != nil { + handleHist.ID = e.commonHandle.ID + if handleTopn != nil && handleTopn.TotalCount() > 0 { + handleHist.RemoveVals(handleTopn.TopN) + } + if handleCms != nil { + handleCms.CalcDefaultValForAnalyze(uint64(handleHist.NDV)) + } + hists = append([]*statistics.Histogram{handleHist}, hists...) + cms = append([]*statistics.CMSketch{handleCms}, cms...) + fms = append([]*statistics.FMSketch{handleFms}, fms...) + topNs = append([]*statistics.TopN{handleTopn}, topNs...) 
+ } + return hists, cms, topNs, fms, extStats, nil +} + +// AnalyzeColumnsExecV1 is used to maintain v1 analyze process +type AnalyzeColumnsExecV1 struct { + *AnalyzeColumnsExec +} + +func (e *AnalyzeColumnsExecV1) analyzeColumnsPushDownV1() *statistics.AnalyzeResults { + var ranges []*ranger.Range + if hc := e.handleCols; hc != nil { + if hc.IsInt() { + ranges = ranger.FullIntRange(mysql.HasUnsignedFlag(hc.GetCol(0).RetType.GetFlag())) + } else { + ranges = ranger.FullNotNullRange() + } + } else { + ranges = ranger.FullIntRange(false) + } + collExtStats := e.ctx.GetSessionVars().EnableExtendedStats + hists, cms, topNs, fms, extStats, err := e.buildStats(ranges, collExtStats) + if err != nil { + return &statistics.AnalyzeResults{Err: err, Job: e.job} + } + + if hasPkHist(e.handleCols) { + pkResult := &statistics.AnalyzeResult{ + Hist: hists[:1], + Cms: cms[:1], + TopNs: topNs[:1], + Fms: fms[:1], + } + restResult := &statistics.AnalyzeResult{ + Hist: hists[1:], + Cms: cms[1:], + TopNs: topNs[1:], + Fms: fms[1:], + } + return &statistics.AnalyzeResults{ + TableID: e.tableID, + Ars: []*statistics.AnalyzeResult{pkResult, restResult}, + ExtStats: extStats, + Job: e.job, + StatsVer: e.StatsVersion, + Count: int64(pkResult.Hist[0].TotalRowCount()), + Snapshot: e.snapshot, + } + } + var ars []*statistics.AnalyzeResult + if e.analyzePB.Tp == tipb.AnalyzeType_TypeMixed { + ars = append(ars, &statistics.AnalyzeResult{ + Hist: []*statistics.Histogram{hists[0]}, + Cms: []*statistics.CMSketch{cms[0]}, + TopNs: []*statistics.TopN{topNs[0]}, + Fms: []*statistics.FMSketch{nil}, + IsIndex: 1, + }) + hists = hists[1:] + cms = cms[1:] + topNs = topNs[1:] + } + colResult := &statistics.AnalyzeResult{ + Hist: hists, + Cms: cms, + TopNs: topNs, + Fms: fms, + } + ars = append(ars, colResult) + cnt := int64(hists[0].TotalRowCount()) + if e.StatsVersion >= statistics.Version2 { + cnt += int64(topNs[0].TotalCount()) + } + return &statistics.AnalyzeResults{ + TableID: e.tableID, + Ars: ars, + Job: e.job, + StatsVer: e.StatsVersion, + ExtStats: extStats, + Count: cnt, + Snapshot: e.snapshot, + } +} + +func hasPkHist(handleCols plannerutil.HandleCols) bool { + return handleCols != nil && handleCols.IsInt() +} + +// prepareColumns prepares the columns for the analyze job. +func prepareColumns(e *AnalyzeColumnsExec, b *strings.Builder) { + cols := e.colsInfo + // Ignore the _row_id column. + if len(cols) > 0 && cols[len(cols)-1].ID == model.ExtraHandleID { + cols = cols[:len(cols)-1] + } + // If there are no columns, skip the process. + if len(cols) == 0 { + return + } + if len(cols) < len(e.tableInfo.Columns) { + if len(cols) > 1 { + b.WriteString(" columns ") + } else { + b.WriteString(" column ") + } + for i, col := range cols { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(col.Name.O) + } + } else { + b.WriteString(" all columns") + } +} + +// prepareIndexes prepares the indexes for the analyze job. +func prepareIndexes(e *AnalyzeColumnsExec, b *strings.Builder) { + indexes := e.indexes + + // If there are no indexes, skip the process. + if len(indexes) == 0 { + return + } + if len(indexes) < len(e.tableInfo.Indices) { + if len(indexes) > 1 { + b.WriteString(" indexes ") + } else { + b.WriteString(" index ") + } + for i, index := range indexes { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(index.Name.O) + } + } else { + b.WriteString(" all indexes") + } +} + +// prepareV2AnalyzeJobInfo prepares the job info for the analyze job. 
+func prepareV2AnalyzeJobInfo(e *AnalyzeColumnsExec) { + // For v1, we analyze all columns in a single job, so we don't need to set the job info. + if e == nil || e.StatsVersion != statistics.Version2 { + return + } + + opts := e.opts + if e.V2Options != nil { + opts = e.V2Options.FilledOpts + } + sampleRate := *e.analyzePB.ColReq.SampleRate + var b strings.Builder + // If it is an internal SQL, it means it is triggered by the system itself(auto-analyze). + if e.ctx.GetSessionVars().InRestrictedSQL { + b.WriteString("auto ") + } + b.WriteString("analyze table") + + prepareIndexes(e, &b) + if len(e.indexes) > 0 && len(e.colsInfo) > 0 { + b.WriteString(",") + } + prepareColumns(e, &b) + + var needComma bool + b.WriteString(" with ") + printOption := func(optType ast.AnalyzeOptionType) { + if val, ok := opts[optType]; ok { + if needComma { + b.WriteString(", ") + } else { + needComma = true + } + b.WriteString(fmt.Sprintf("%v %s", val, strings.ToLower(ast.AnalyzeOptionString[optType]))) + } + } + printOption(ast.AnalyzeOptNumBuckets) + printOption(ast.AnalyzeOptNumTopN) + if opts[ast.AnalyzeOptNumSamples] != 0 { + printOption(ast.AnalyzeOptNumSamples) + } else { + if needComma { + b.WriteString(", ") + } else { + needComma = true + } + b.WriteString(fmt.Sprintf("%v samplerate", sampleRate)) + } + e.job.JobInfo = b.String() +} diff --git a/pkg/executor/analyze_col_v2.go b/pkg/executor/analyze_col_v2.go index 6eefe269fc865..4e41f0b1bbabb 100644 --- a/pkg/executor/analyze_col_v2.go +++ b/pkg/executor/analyze_col_v2.go @@ -616,16 +616,16 @@ func (e *AnalyzeColumnsExecV2) subMergeWorker(resultCh chan<- *samplingMergeResu close(resultCh) } }() - failpoint.Inject("mockAnalyzeSamplingMergeWorkerPanic", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockAnalyzeSamplingMergeWorkerPanic")); _err_ == nil { panic("failpoint triggered") - }) - failpoint.Inject("mockAnalyzeMergeWorkerSlowConsume", func(val failpoint.Value) { + } + if val, _err_ := failpoint.Eval(_curpkg_("mockAnalyzeMergeWorkerSlowConsume")); _err_ == nil { times := val.(int) for i := 0; i < times; i++ { e.memTracker.Consume(5 << 20) time.Sleep(100 * time.Millisecond) } - }) + } retCollector := statistics.NewRowSampleCollector(int(e.analyzePB.ColReq.SampleSize), e.analyzePB.ColReq.GetSampleRate(), l) for i := 0; i < l; i++ { retCollector.Base().FMSketches = append(retCollector.Base().FMSketches, statistics.NewFMSketch(statistics.MaxSketchSize)) @@ -682,9 +682,9 @@ func (e *AnalyzeColumnsExecV2) subBuildWorker(resultCh chan error, taskCh chan * resultCh <- getAnalyzePanicErr(r) } }() - failpoint.Inject("mockAnalyzeSamplingBuildWorkerPanic", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockAnalyzeSamplingBuildWorkerPanic")); _err_ == nil { panic("failpoint triggered") - }) + } colLen := len(e.colsInfo) bufferedMemSize := int64(0) @@ -856,18 +856,18 @@ func readDataAndSendTask(ctx sessionctx.Context, handler *tableResultHandler, me // After all tasks are sent, close the mergeTaskCh to notify the mergeWorker that all tasks have been sent. 
defer close(mergeTaskCh) for { - failpoint.Inject("mockKillRunningV2AnalyzeJob", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockKillRunningV2AnalyzeJob")); _err_ == nil { dom := domain.GetDomain(ctx) for _, id := range handleutil.GlobalAutoAnalyzeProcessList.All() { dom.SysProcTracker().KillSysProcess(id) } - }) + } if err := ctx.GetSessionVars().SQLKiller.HandleSignal(); err != nil { return err } - failpoint.Inject("mockSlowAnalyzeV2", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockSlowAnalyzeV2")); _err_ == nil { time.Sleep(1000 * time.Second) - }) + } data, err := handler.nextRaw(context.TODO()) if err != nil { diff --git a/pkg/executor/analyze_col_v2.go__failpoint_stash__ b/pkg/executor/analyze_col_v2.go__failpoint_stash__ new file mode 100644 index 0000000000000..6eefe269fc865 --- /dev/null +++ b/pkg/executor/analyze_col_v2.go__failpoint_stash__ @@ -0,0 +1,885 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "context" + stderrors "errors" + "slices" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/statistics" + handleutil "github.com/pingcap/tidb/pkg/statistics/handle/util" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/collate" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/ranger" + "github.com/pingcap/tidb/pkg/util/timeutil" + "github.com/pingcap/tipb/go-tipb" + "github.com/tiancaiamao/gp" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" +) + +// AnalyzeColumnsExecV2 is used to maintain v2 analyze process +type AnalyzeColumnsExecV2 struct { + *AnalyzeColumnsExec +} + +func (e *AnalyzeColumnsExecV2) analyzeColumnsPushDownV2(gp *gp.Pool) *statistics.AnalyzeResults { + var ranges []*ranger.Range + if hc := e.handleCols; hc != nil { + if hc.IsInt() { + ranges = ranger.FullIntRange(mysql.HasUnsignedFlag(hc.GetCol(0).RetType.GetFlag())) + } else { + ranges = ranger.FullNotNullRange() + } + } else { + ranges = ranger.FullIntRange(false) + } + + collExtStats := e.ctx.GetSessionVars().EnableExtendedStats + // specialIndexes holds indexes that include virtual or prefix columns. For these indexes, + // only the number of distinct values (NDV) is computed using TiKV. Other statistic + // are derived from sample data processed within TiDB. + // The reason is that we want to keep the same row sampling for all columns. 
+ specialIndexes := make([]*model.IndexInfo, 0, len(e.indexes)) + specialIndexesOffsets := make([]int, 0, len(e.indexes)) + for i, idx := range e.indexes { + isSpecial := false + for _, col := range idx.Columns { + colInfo := e.colsInfo[col.Offset] + isVirtualCol := colInfo.IsGenerated() && !colInfo.GeneratedStored + isPrefixCol := col.Length != types.UnspecifiedLength + if isVirtualCol || isPrefixCol { + isSpecial = true + break + } + } + if isSpecial { + specialIndexesOffsets = append(specialIndexesOffsets, i) + specialIndexes = append(specialIndexes, idx) + } + } + samplingStatsConcurrency, err := getBuildSamplingStatsConcurrency(e.ctx) + if err != nil { + e.memTracker.Release(e.memTracker.BytesConsumed()) + return &statistics.AnalyzeResults{Err: err, Job: e.job} + } + statsConcurrncy, err := getBuildStatsConcurrency(e.ctx) + if err != nil { + e.memTracker.Release(e.memTracker.BytesConsumed()) + return &statistics.AnalyzeResults{Err: err, Job: e.job} + } + idxNDVPushDownCh := make(chan analyzeIndexNDVTotalResult, 1) + // subIndexWorkerWg is better to be initialized in handleNDVForSpecialIndexes, however if we do so, golang would + // report unexpected/unreasonable data race error on subIndexWorkerWg when running TestAnalyzeVirtualCol test + // case with `-race` flag now. + wg := util.NewWaitGroupPool(gp) + wg.Run(func() { + e.handleNDVForSpecialIndexes(specialIndexes, idxNDVPushDownCh, statsConcurrncy) + }) + defer wg.Wait() + + count, hists, topNs, fmSketches, extStats, err := e.buildSamplingStats(gp, ranges, collExtStats, specialIndexesOffsets, idxNDVPushDownCh, samplingStatsConcurrency) + if err != nil { + e.memTracker.Release(e.memTracker.BytesConsumed()) + return &statistics.AnalyzeResults{Err: err, Job: e.job} + } + cLen := len(e.analyzePB.ColReq.ColumnsInfo) + colGroupResult := &statistics.AnalyzeResult{ + Hist: hists[cLen:], + TopNs: topNs[cLen:], + Fms: fmSketches[cLen:], + IsIndex: 1, + } + // Discard stats of _tidb_rowid. + // Because the process of analyzing will keep the order of results be the same as the colsInfo in the analyze task, + // and in `buildAnalyzeFullSamplingTask` we always place the _tidb_rowid at the last of colsInfo, so if there are + // stats for _tidb_rowid, it must be at the end of the column stats. + // Virtual column has no histogram yet. So we check nil here. + if hists[cLen-1] != nil && hists[cLen-1].ID == -1 { + cLen-- + } + colResult := &statistics.AnalyzeResult{ + Hist: hists[:cLen], + TopNs: topNs[:cLen], + Fms: fmSketches[:cLen], + } + + return &statistics.AnalyzeResults{ + TableID: e.tableID, + Ars: []*statistics.AnalyzeResult{colResult, colGroupResult}, + Job: e.job, + StatsVer: e.StatsVersion, + Count: count, + Snapshot: e.snapshot, + ExtStats: extStats, + BaseCount: e.baseCount, + BaseModifyCnt: e.baseModifyCnt, + } +} + +// decodeSampleDataWithVirtualColumn constructs the virtual column by evaluating from the decoded normal columns. 
+func (e *AnalyzeColumnsExecV2) decodeSampleDataWithVirtualColumn( + collector statistics.RowSampleCollector, + fieldTps []*types.FieldType, + virtualColIdx []int, + schema *expression.Schema, +) error { + totFts := make([]*types.FieldType, 0, e.schemaForVirtualColEval.Len()) + for _, col := range e.schemaForVirtualColEval.Columns { + totFts = append(totFts, col.RetType) + } + chk := chunk.NewChunkWithCapacity(totFts, len(collector.Base().Samples)) + decoder := codec.NewDecoder(chk, e.ctx.GetSessionVars().Location()) + for _, sample := range collector.Base().Samples { + for i, columns := range sample.Columns { + if schema.Columns[i].VirtualExpr != nil { + continue + } + _, err := decoder.DecodeOne(columns.GetBytes(), i, e.schemaForVirtualColEval.Columns[i].RetType) + if err != nil { + return err + } + } + } + err := table.FillVirtualColumnValue(fieldTps, virtualColIdx, schema.Columns, e.colsInfo, e.ctx.GetExprCtx(), chk) + if err != nil { + return err + } + iter := chunk.NewIterator4Chunk(chk) + for row, i := iter.Begin(), 0; row != iter.End(); row, i = iter.Next(), i+1 { + datums := row.GetDatumRow(totFts) + collector.Base().Samples[i].Columns = datums + } + return nil +} + +func printAnalyzeMergeCollectorLog(oldRootCount, newRootCount, subCount, tableID, partitionID int64, isPartition bool, info string, index int) { + if index < 0 { + logutil.BgLogger().Debug(info, + zap.Int64("tableID", tableID), + zap.Int64("partitionID", partitionID), + zap.Bool("isPartitionTable", isPartition), + zap.Int64("oldRootCount", oldRootCount), + zap.Int64("newRootCount", newRootCount), + zap.Int64("subCount", subCount)) + } else { + logutil.BgLogger().Debug(info, + zap.Int64("tableID", tableID), + zap.Int64("partitionID", partitionID), + zap.Bool("isPartitionTable", isPartition), + zap.Int64("oldRootCount", oldRootCount), + zap.Int64("newRootCount", newRootCount), + zap.Int64("subCount", subCount), + zap.Int("subCollectorIndex", index)) + } +} + +func (e *AnalyzeColumnsExecV2) buildSamplingStats( + gp *gp.Pool, + ranges []*ranger.Range, + needExtStats bool, + indexesWithVirtualColOffsets []int, + idxNDVPushDownCh chan analyzeIndexNDVTotalResult, + samplingStatsConcurrency int, +) ( + count int64, + hists []*statistics.Histogram, + topns []*statistics.TopN, + fmSketches []*statistics.FMSketch, + extStats *statistics.ExtendedStatsColl, + err error, +) { + // Open memory tracker and resultHandler. + if err = e.open(ranges); err != nil { + return 0, nil, nil, nil, nil, err + } + defer func() { + if err1 := e.resultHandler.Close(); err1 != nil { + err = err1 + } + }() + + l := len(e.analyzePB.ColReq.ColumnsInfo) + len(e.analyzePB.ColReq.ColumnGroups) + rootRowCollector := statistics.NewRowSampleCollector(int(e.analyzePB.ColReq.SampleSize), e.analyzePB.ColReq.GetSampleRate(), l) + for i := 0; i < l; i++ { + rootRowCollector.Base().FMSketches = append(rootRowCollector.Base().FMSketches, statistics.NewFMSketch(statistics.MaxSketchSize)) + } + + sc := e.ctx.GetSessionVars().StmtCtx + + // Start workers to merge the result from collectors. + mergeResultCh := make(chan *samplingMergeResult, 1) + mergeTaskCh := make(chan []byte, 1) + var taskEg errgroup.Group + // Start read data from resultHandler and send them to mergeTaskCh. 
+ taskEg.Go(func() (err error) { + defer func() { + if r := recover(); r != nil { + err = getAnalyzePanicErr(r) + } + }() + return readDataAndSendTask(e.ctx, e.resultHandler, mergeTaskCh, e.memTracker) + }) + e.samplingMergeWg = &util.WaitGroupWrapper{} + e.samplingMergeWg.Add(samplingStatsConcurrency) + for i := 0; i < samplingStatsConcurrency; i++ { + id := i + gp.Go(func() { + e.subMergeWorker(mergeResultCh, mergeTaskCh, l, id) + }) + } + // Merge the result from collectors. + mergeWorkerPanicCnt := 0 + mergeEg, mergeCtx := errgroup.WithContext(context.Background()) + mergeEg.Go(func() (err error) { + defer func() { + if r := recover(); r != nil { + err = getAnalyzePanicErr(r) + } + }() + for mergeWorkerPanicCnt < samplingStatsConcurrency { + mergeResult, ok := <-mergeResultCh + if !ok { + break + } + if mergeResult.err != nil { + err = mergeResult.err + if isAnalyzeWorkerPanic(mergeResult.err) { + mergeWorkerPanicCnt++ + } + continue + } + oldRootCollectorSize := rootRowCollector.Base().MemSize + oldRootCollectorCount := rootRowCollector.Base().Count + // Merge the result from sub-collectors. + rootRowCollector.MergeCollector(mergeResult.collector) + newRootCollectorCount := rootRowCollector.Base().Count + printAnalyzeMergeCollectorLog(oldRootCollectorCount, newRootCollectorCount, + mergeResult.collector.Base().Count, e.tableID.TableID, e.tableID.PartitionID, e.tableID.IsPartitionTable(), + "merge subMergeWorker in AnalyzeColumnsExecV2", -1) + e.memTracker.Consume(rootRowCollector.Base().MemSize - oldRootCollectorSize - mergeResult.collector.Base().MemSize) + mergeResult.collector.DestroyAndPutToPool() + } + return err + }) + err = taskEg.Wait() + if err != nil { + mergeCtx.Done() + if err1 := mergeEg.Wait(); err1 != nil { + err = stderrors.Join(err, err1) + } + return 0, nil, nil, nil, nil, getAnalyzePanicErr(err) + } + err = mergeEg.Wait() + defer e.memTracker.Release(rootRowCollector.Base().MemSize) + if err != nil { + return 0, nil, nil, nil, nil, err + } + + // Decode the data from sample collectors. + virtualColIdx := buildVirtualColumnIndex(e.schemaForVirtualColEval, e.colsInfo) + // Filling virtual columns is necessary here because these samples are used to build statistics for indexes that constructed by virtual columns. + if len(virtualColIdx) > 0 { + fieldTps := make([]*types.FieldType, 0, len(virtualColIdx)) + for _, colOffset := range virtualColIdx { + fieldTps = append(fieldTps, e.schemaForVirtualColEval.Columns[colOffset].RetType) + } + err = e.decodeSampleDataWithVirtualColumn(rootRowCollector, fieldTps, virtualColIdx, e.schemaForVirtualColEval) + if err != nil { + return 0, nil, nil, nil, nil, err + } + } else { + // If there's no virtual column, normal decode way is enough. + for _, sample := range rootRowCollector.Base().Samples { + for i := range sample.Columns { + sample.Columns[i], err = tablecodec.DecodeColumnValue(sample.Columns[i].GetBytes(), &e.colsInfo[i].FieldType, sc.TimeZone()) + if err != nil { + return 0, nil, nil, nil, nil, err + } + } + } + } + + // Calculate handle from the row data for each row. It will be used to sort the samples. + for _, sample := range rootRowCollector.Base().Samples { + sample.Handle, err = e.handleCols.BuildHandleByDatums(sample.Columns) + if err != nil { + return 0, nil, nil, nil, nil, err + } + } + colLen := len(e.colsInfo) + // The order of the samples are broken when merging samples from sub-collectors. + // So now we need to sort the samples according to the handle in order to calculate correlation. 
+ slices.SortFunc(rootRowCollector.Base().Samples, func(i, j *statistics.ReservoirRowSampleItem) int { + return i.Handle.Compare(j.Handle) + }) + + totalLen := len(e.colsInfo) + len(e.indexes) + hists = make([]*statistics.Histogram, totalLen) + topns = make([]*statistics.TopN, totalLen) + fmSketches = make([]*statistics.FMSketch, 0, totalLen) + buildResultChan := make(chan error, totalLen) + buildTaskChan := make(chan *samplingBuildTask, totalLen) + if totalLen < samplingStatsConcurrency { + samplingStatsConcurrency = totalLen + } + e.samplingBuilderWg = newNotifyErrorWaitGroupWrapper(gp, buildResultChan) + sampleCollectors := make([]*statistics.SampleCollector, len(e.colsInfo)) + exitCh := make(chan struct{}) + e.samplingBuilderWg.Add(samplingStatsConcurrency) + + // Start workers to build stats. + for i := 0; i < samplingStatsConcurrency; i++ { + e.samplingBuilderWg.Run(func() { + e.subBuildWorker(buildResultChan, buildTaskChan, hists, topns, sampleCollectors, exitCh) + }) + } + // Generate tasks for building stats. + for i, col := range e.colsInfo { + buildTaskChan <- &samplingBuildTask{ + id: col.ID, + rootRowCollector: rootRowCollector, + tp: &col.FieldType, + isColumn: true, + slicePos: i, + } + fmSketches = append(fmSketches, rootRowCollector.Base().FMSketches[i]) + } + + indexPushedDownResult := <-idxNDVPushDownCh + if indexPushedDownResult.err != nil { + close(exitCh) + e.samplingBuilderWg.Wait() + return 0, nil, nil, nil, nil, indexPushedDownResult.err + } + for _, offset := range indexesWithVirtualColOffsets { + ret := indexPushedDownResult.results[e.indexes[offset].ID] + rootRowCollector.Base().NullCount[colLen+offset] = ret.Count + rootRowCollector.Base().FMSketches[colLen+offset] = ret.Ars[0].Fms[0] + } + + // Generate tasks for building stats for indexes. + for i, idx := range e.indexes { + buildTaskChan <- &samplingBuildTask{ + id: idx.ID, + rootRowCollector: rootRowCollector, + tp: types.NewFieldType(mysql.TypeBlob), + isColumn: false, + slicePos: colLen + i, + } + fmSketches = append(fmSketches, rootRowCollector.Base().FMSketches[colLen+i]) + } + close(buildTaskChan) + + panicCnt := 0 + for panicCnt < samplingStatsConcurrency { + err1, ok := <-buildResultChan + if !ok { + break + } + if err1 != nil { + err = err1 + if isAnalyzeWorkerPanic(err1) { + panicCnt++ + } + continue + } + } + defer func() { + totalSampleCollectorSize := int64(0) + for _, sampleCollector := range sampleCollectors { + if sampleCollector != nil { + totalSampleCollectorSize += sampleCollector.MemSize + } + } + e.memTracker.Release(totalSampleCollectorSize) + }() + if err != nil { + return 0, nil, nil, nil, nil, err + } + + count = rootRowCollector.Base().Count + if needExtStats { + extStats, err = statistics.BuildExtendedStats(e.ctx, e.TableID.GetStatisticsID(), e.colsInfo, sampleCollectors) + if err != nil { + return 0, nil, nil, nil, nil, err + } + } + + return +} + +// handleNDVForSpecialIndexes deals with the logic to analyze the index containing the virtual column when the mode is full sampling. 
+func (e *AnalyzeColumnsExecV2) handleNDVForSpecialIndexes(indexInfos []*model.IndexInfo, totalResultCh chan analyzeIndexNDVTotalResult, statsConcurrncy int) { + defer func() { + if r := recover(); r != nil { + logutil.BgLogger().Error("analyze ndv for special index panicked", zap.Any("recover", r), zap.Stack("stack")) + metrics.PanicCounter.WithLabelValues(metrics.LabelAnalyze).Inc() + totalResultCh <- analyzeIndexNDVTotalResult{ + err: getAnalyzePanicErr(r), + } + } + }() + tasks := e.buildSubIndexJobForSpecialIndex(indexInfos) + taskCh := make(chan *analyzeTask, len(tasks)) + for _, task := range tasks { + AddNewAnalyzeJob(e.ctx, task.job) + } + resultsCh := make(chan *statistics.AnalyzeResults, len(tasks)) + if len(tasks) < statsConcurrncy { + statsConcurrncy = len(tasks) + } + var subIndexWorkerWg = NewAnalyzeResultsNotifyWaitGroupWrapper(resultsCh) + subIndexWorkerWg.Add(statsConcurrncy) + for i := 0; i < statsConcurrncy; i++ { + subIndexWorkerWg.Run(func() { e.subIndexWorkerForNDV(taskCh, resultsCh) }) + } + for _, task := range tasks { + taskCh <- task + } + close(taskCh) + panicCnt := 0 + totalResult := analyzeIndexNDVTotalResult{ + results: make(map[int64]*statistics.AnalyzeResults, len(indexInfos)), + } + var err error + statsHandle := domain.GetDomain(e.ctx).StatsHandle() + for panicCnt < statsConcurrncy { + results, ok := <-resultsCh + if !ok { + break + } + if results.Err != nil { + err = results.Err + statsHandle.FinishAnalyzeJob(results.Job, err, statistics.TableAnalysisJob) + if isAnalyzeWorkerPanic(err) { + panicCnt++ + } + continue + } + statsHandle.FinishAnalyzeJob(results.Job, nil, statistics.TableAnalysisJob) + totalResult.results[results.Ars[0].Hist[0].ID] = results + } + if err != nil { + totalResult.err = err + } + totalResultCh <- totalResult +} + +// subIndexWorker receive the task for each index and return the result for them. +func (e *AnalyzeColumnsExecV2) subIndexWorkerForNDV(taskCh chan *analyzeTask, resultsCh chan *statistics.AnalyzeResults) { + var task *analyzeTask + statsHandle := domain.GetDomain(e.ctx).StatsHandle() + defer func() { + if r := recover(); r != nil { + logutil.BgLogger().Error("analyze worker panicked", zap.Any("recover", r), zap.Stack("stack")) + metrics.PanicCounter.WithLabelValues(metrics.LabelAnalyze).Inc() + resultsCh <- &statistics.AnalyzeResults{ + Err: getAnalyzePanicErr(r), + Job: task.job, + } + } + }() + for { + var ok bool + task, ok = <-taskCh + if !ok { + break + } + statsHandle.StartAnalyzeJob(task.job) + if task.taskType != idxTask { + resultsCh <- &statistics.AnalyzeResults{ + Err: errors.Errorf("incorrect analyze type"), + Job: task.job, + } + continue + } + task.idxExec.job = task.job + resultsCh <- analyzeIndexNDVPushDown(task.idxExec) + } +} + +// buildSubIndexJobForSpecialIndex builds sub index pushed down task to calculate the NDV information for indexes containing virtual column. +// This is because we cannot push the calculation of the virtual column down to the tikv side. 
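[Editor's note] Both buildSamplingStats and handleNDVForSpecialIndexes above use the same termination rule for their worker pools: each worker reports exactly once on a channel (a result, or an error derived from a recovered panic), and the consuming loop counts panic-style errors so it can stop waiting even when workers have crashed instead of blocking on a channel nobody will write to again. Below is a minimal, self-contained sketch of that rule; runPool, workerResult and errWorkerPanic are illustrative names invented for the sketch, not identifiers from the patch.

package main

import (
	"errors"
	"fmt"
)

var errWorkerPanic = errors.New("worker panicked")

type workerResult struct {
	val int
	err error
}

// runPool starts n workers; each reports exactly once on resultCh, either a
// value or an error wrapping errWorkerPanic when it recovers from a panic.
// The drain loop counts panicked workers, mirroring the panicCnt checks in
// the patch, so it terminates even if every worker has died.
func runPool(n int, work func(id int) (int, error)) (sum int, err error) {
	resultCh := make(chan workerResult, n)
	for i := 0; i < n; i++ {
		go func(id int) {
			defer func() {
				if r := recover(); r != nil {
					resultCh <- workerResult{err: fmt.Errorf("%w: %v", errWorkerPanic, r)}
				}
			}()
			v, e := work(id)
			resultCh <- workerResult{val: v, err: e}
		}(i)
	}
	panicCnt := 0
	for received := 0; received < n && panicCnt < n; received++ {
		r := <-resultCh
		if r.err != nil {
			err = r.err
			if errors.Is(r.err, errWorkerPanic) {
				panicCnt++
			}
			continue
		}
		sum += r.val
	}
	return sum, err
}

func main() {
	total, err := runPool(4, func(id int) (int, error) { return id * id, nil })
	fmt.Println(total, err) // 14 <nil>
}
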
+func (e *AnalyzeColumnsExecV2) buildSubIndexJobForSpecialIndex(indexInfos []*model.IndexInfo) []*analyzeTask { + _, offset := timeutil.Zone(e.ctx.GetSessionVars().Location()) + tasks := make([]*analyzeTask, 0, len(indexInfos)) + sc := e.ctx.GetSessionVars().StmtCtx + concurrency := adaptiveAnlayzeDistSQLConcurrency(context.Background(), e.ctx) + for _, indexInfo := range indexInfos { + base := baseAnalyzeExec{ + ctx: e.ctx, + tableID: e.TableID, + concurrency: concurrency, + analyzePB: &tipb.AnalyzeReq{ + Tp: tipb.AnalyzeType_TypeIndex, + Flags: sc.PushDownFlags(), + TimeZoneOffset: offset, + }, + snapshot: e.snapshot, + } + idxExec := &AnalyzeIndexExec{ + baseAnalyzeExec: base, + isCommonHandle: e.tableInfo.IsCommonHandle, + idxInfo: indexInfo, + } + idxExec.opts = make(map[ast.AnalyzeOptionType]uint64, len(ast.AnalyzeOptionString)) + idxExec.opts[ast.AnalyzeOptNumTopN] = 0 + idxExec.opts[ast.AnalyzeOptCMSketchDepth] = 0 + idxExec.opts[ast.AnalyzeOptCMSketchWidth] = 0 + idxExec.opts[ast.AnalyzeOptNumSamples] = 0 + idxExec.opts[ast.AnalyzeOptNumBuckets] = 1 + statsVersion := new(int32) + *statsVersion = statistics.Version1 + // No Top-N + topnSize := int32(0) + idxExec.analyzePB.IdxReq = &tipb.AnalyzeIndexReq{ + // One bucket to store the null for null histogram. + BucketSize: 1, + NumColumns: int32(len(indexInfo.Columns)), + TopNSize: &topnSize, + Version: statsVersion, + SketchSize: statistics.MaxSketchSize, + } + if idxExec.isCommonHandle && indexInfo.Primary { + idxExec.analyzePB.Tp = tipb.AnalyzeType_TypeCommonHandle + } + // No CM-Sketch. + depth := int32(0) + width := int32(0) + idxExec.analyzePB.IdxReq.CmsketchDepth = &depth + idxExec.analyzePB.IdxReq.CmsketchWidth = &width + autoAnalyze := "" + if e.ctx.GetSessionVars().InRestrictedSQL { + autoAnalyze = "auto " + } + job := &statistics.AnalyzeJob{DBName: e.job.DBName, TableName: e.job.TableName, PartitionName: e.job.PartitionName, JobInfo: autoAnalyze + "analyze ndv for index " + indexInfo.Name.O} + idxExec.job = job + tasks = append(tasks, &analyzeTask{ + taskType: idxTask, + idxExec: idxExec, + job: job, + }) + } + return tasks +} + +func (e *AnalyzeColumnsExecV2) subMergeWorker(resultCh chan<- *samplingMergeResult, taskCh <-chan []byte, l int, index int) { + // Only close the resultCh in the first worker. + closeTheResultCh := index == 0 + defer func() { + if r := recover(); r != nil { + logutil.BgLogger().Error("analyze worker panicked", zap.Any("recover", r), zap.Stack("stack")) + metrics.PanicCounter.WithLabelValues(metrics.LabelAnalyze).Inc() + resultCh <- &samplingMergeResult{err: getAnalyzePanicErr(r)} + } + // Consume the remaining things. + for { + _, ok := <-taskCh + if !ok { + break + } + } + e.samplingMergeWg.Done() + if closeTheResultCh { + e.samplingMergeWg.Wait() + close(resultCh) + } + }() + failpoint.Inject("mockAnalyzeSamplingMergeWorkerPanic", func() { + panic("failpoint triggered") + }) + failpoint.Inject("mockAnalyzeMergeWorkerSlowConsume", func(val failpoint.Value) { + times := val.(int) + for i := 0; i < times; i++ { + e.memTracker.Consume(5 << 20) + time.Sleep(100 * time.Millisecond) + } + }) + retCollector := statistics.NewRowSampleCollector(int(e.analyzePB.ColReq.SampleSize), e.analyzePB.ColReq.GetSampleRate(), l) + for i := 0; i < l; i++ { + retCollector.Base().FMSketches = append(retCollector.Base().FMSketches, statistics.NewFMSketch(statistics.MaxSketchSize)) + } + statsHandle := domain.GetDomain(e.ctx).StatsHandle() + for { + data, ok := <-taskCh + if !ok { + break + } + + // Unmarshal the data. 
+ dataSize := int64(cap(data)) + colResp := &tipb.AnalyzeColumnsResp{} + err := colResp.Unmarshal(data) + if err != nil { + resultCh <- &samplingMergeResult{err: err} + return + } + // Consume the memory of the data. + colRespSize := int64(colResp.Size()) + e.memTracker.Consume(colRespSize) + + // Update processed rows. + subCollector := statistics.NewRowSampleCollector(int(e.analyzePB.ColReq.SampleSize), e.analyzePB.ColReq.GetSampleRate(), l) + subCollector.Base().FromProto(colResp.RowCollector, e.memTracker) + statsHandle.UpdateAnalyzeJobProgress(e.job, subCollector.Base().Count) + + // Print collect log. + oldRetCollectorSize := retCollector.Base().MemSize + oldRetCollectorCount := retCollector.Base().Count + retCollector.MergeCollector(subCollector) + newRetCollectorCount := retCollector.Base().Count + printAnalyzeMergeCollectorLog(oldRetCollectorCount, newRetCollectorCount, subCollector.Base().Count, + e.tableID.TableID, e.tableID.PartitionID, e.TableID.IsPartitionTable(), + "merge subCollector in concurrency in AnalyzeColumnsExecV2", index) + + // Consume the memory of the result. + newRetCollectorSize := retCollector.Base().MemSize + subCollectorSize := subCollector.Base().MemSize + e.memTracker.Consume(newRetCollectorSize - oldRetCollectorSize - subCollectorSize) + e.memTracker.Release(dataSize + colRespSize) + subCollector.DestroyAndPutToPool() + } + + resultCh <- &samplingMergeResult{collector: retCollector} +} + +func (e *AnalyzeColumnsExecV2) subBuildWorker(resultCh chan error, taskCh chan *samplingBuildTask, hists []*statistics.Histogram, topns []*statistics.TopN, collectors []*statistics.SampleCollector, exitCh chan struct{}) { + defer func() { + if r := recover(); r != nil { + logutil.BgLogger().Error("analyze worker panicked", zap.Any("recover", r), zap.Stack("stack")) + metrics.PanicCounter.WithLabelValues(metrics.LabelAnalyze).Inc() + resultCh <- getAnalyzePanicErr(r) + } + }() + failpoint.Inject("mockAnalyzeSamplingBuildWorkerPanic", func() { + panic("failpoint triggered") + }) + + colLen := len(e.colsInfo) + bufferedMemSize := int64(0) + bufferedReleaseSize := int64(0) + defer e.memTracker.Consume(bufferedMemSize) + defer e.memTracker.Release(bufferedReleaseSize) + +workLoop: + for { + select { + case task, ok := <-taskCh: + if !ok { + break workLoop + } + var collector *statistics.SampleCollector + if task.isColumn { + if e.colsInfo[task.slicePos].IsGenerated() && !e.colsInfo[task.slicePos].GeneratedStored { + hists[task.slicePos] = nil + topns[task.slicePos] = nil + continue + } + sampleNum := task.rootRowCollector.Base().Samples.Len() + sampleItems := make([]*statistics.SampleItem, 0, sampleNum) + // consume mandatory memory at the beginning, including empty SampleItems of all sample rows, if exceeds, fast fail + collectorMemSize := int64(sampleNum) * (8 + statistics.EmptySampleItemSize) + e.memTracker.Consume(collectorMemSize) + var collator collate.Collator + ft := e.colsInfo[task.slicePos].FieldType + // When it's new collation data, we need to use its collate key instead of original value because only + // the collate key can ensure the correct ordering. + // This is also corresponding to similar operation in (*statistics.Column).GetColumnRowCount(). 
+ if ft.EvalType() == types.ETString && ft.GetType() != mysql.TypeEnum && ft.GetType() != mysql.TypeSet { + collator = collate.GetCollator(ft.GetCollate()) + } + for j, row := range task.rootRowCollector.Base().Samples { + if row.Columns[task.slicePos].IsNull() { + continue + } + val := row.Columns[task.slicePos] + // If this value is very big, we think that it is not a value that can occur many times. So we don't record it. + if len(val.GetBytes()) > statistics.MaxSampleValueLength { + continue + } + if collator != nil { + val.SetBytes(collator.Key(val.GetString())) + deltaSize := int64(cap(val.GetBytes())) + collectorMemSize += deltaSize + e.memTracker.BufferedConsume(&bufferedMemSize, deltaSize) + } + sampleItems = append(sampleItems, &statistics.SampleItem{ + Value: val, + Ordinal: j, + }) + // tmp memory usage + deltaSize := val.MemUsage() + 4 // content of SampleItem is copied + e.memTracker.BufferedConsume(&bufferedMemSize, deltaSize) + e.memTracker.BufferedRelease(&bufferedReleaseSize, deltaSize) + } + collector = &statistics.SampleCollector{ + Samples: sampleItems, + NullCount: task.rootRowCollector.Base().NullCount[task.slicePos], + Count: task.rootRowCollector.Base().Count - task.rootRowCollector.Base().NullCount[task.slicePos], + FMSketch: task.rootRowCollector.Base().FMSketches[task.slicePos], + TotalSize: task.rootRowCollector.Base().TotalSizes[task.slicePos], + MemSize: collectorMemSize, + } + } else { + var tmpDatum types.Datum + var err error + idx := e.indexes[task.slicePos-colLen] + sampleNum := task.rootRowCollector.Base().Samples.Len() + sampleItems := make([]*statistics.SampleItem, 0, sampleNum) + // consume mandatory memory at the beginning, including all SampleItems, if exceeds, fast fail + // 8 is size of reference, 8 is the size of "b := make([]byte, 0, 8)" + collectorMemSize := int64(sampleNum) * (8 + statistics.EmptySampleItemSize + 8) + e.memTracker.Consume(collectorMemSize) + errCtx := e.ctx.GetSessionVars().StmtCtx.ErrCtx() + indexSampleCollectLoop: + for _, row := range task.rootRowCollector.Base().Samples { + if len(idx.Columns) == 1 && row.Columns[idx.Columns[0].Offset].IsNull() { + continue + } + b := make([]byte, 0, 8) + for _, col := range idx.Columns { + // If the index value contains one value which is too long, we think that it's a value that doesn't occur many times. 
+ if len(row.Columns[col.Offset].GetBytes()) > statistics.MaxSampleValueLength { + continue indexSampleCollectLoop + } + if col.Length != types.UnspecifiedLength { + row.Columns[col.Offset].Copy(&tmpDatum) + ranger.CutDatumByPrefixLen(&tmpDatum, col.Length, &e.colsInfo[col.Offset].FieldType) + b, err = codec.EncodeKey(e.ctx.GetSessionVars().StmtCtx.TimeZone(), b, tmpDatum) + err = errCtx.HandleError(err) + if err != nil { + resultCh <- err + continue workLoop + } + continue + } + b, err = codec.EncodeKey(e.ctx.GetSessionVars().StmtCtx.TimeZone(), b, row.Columns[col.Offset]) + err = errCtx.HandleError(err) + if err != nil { + resultCh <- err + continue workLoop + } + } + sampleItems = append(sampleItems, &statistics.SampleItem{ + Value: types.NewBytesDatum(b), + }) + // tmp memory usage + deltaSize := sampleItems[len(sampleItems)-1].Value.MemUsage() + e.memTracker.BufferedConsume(&bufferedMemSize, deltaSize) + e.memTracker.BufferedRelease(&bufferedReleaseSize, deltaSize) + } + collector = &statistics.SampleCollector{ + Samples: sampleItems, + NullCount: task.rootRowCollector.Base().NullCount[task.slicePos], + Count: task.rootRowCollector.Base().Count - task.rootRowCollector.Base().NullCount[task.slicePos], + FMSketch: task.rootRowCollector.Base().FMSketches[task.slicePos], + TotalSize: task.rootRowCollector.Base().TotalSizes[task.slicePos], + MemSize: collectorMemSize, + } + } + if task.isColumn { + collectors[task.slicePos] = collector + } + releaseCollectorMemory := func() { + if !task.isColumn { + e.memTracker.Release(collector.MemSize) + } + } + hist, topn, err := statistics.BuildHistAndTopN(e.ctx, int(e.opts[ast.AnalyzeOptNumBuckets]), int(e.opts[ast.AnalyzeOptNumTopN]), task.id, collector, task.tp, task.isColumn, e.memTracker, e.ctx.GetSessionVars().EnableExtendedStats) + if err != nil { + resultCh <- err + releaseCollectorMemory() + continue + } + finalMemSize := hist.MemoryUsage() + topn.MemoryUsage() + e.memTracker.Consume(finalMemSize) + hists[task.slicePos] = hist + topns[task.slicePos] = topn + resultCh <- nil + releaseCollectorMemory() + case <-exitCh: + return + } + } +} + +type analyzeIndexNDVTotalResult struct { + results map[int64]*statistics.AnalyzeResults + err error +} + +type samplingMergeResult struct { + collector statistics.RowSampleCollector + err error +} + +type samplingBuildTask struct { + id int64 + rootRowCollector statistics.RowSampleCollector + tp *types.FieldType + isColumn bool + slicePos int +} + +func readDataAndSendTask(ctx sessionctx.Context, handler *tableResultHandler, mergeTaskCh chan []byte, memTracker *memory.Tracker) error { + // After all tasks are sent, close the mergeTaskCh to notify the mergeWorker that all tasks have been sent. 
+ defer close(mergeTaskCh) + for { + failpoint.Inject("mockKillRunningV2AnalyzeJob", func() { + dom := domain.GetDomain(ctx) + for _, id := range handleutil.GlobalAutoAnalyzeProcessList.All() { + dom.SysProcTracker().KillSysProcess(id) + } + }) + if err := ctx.GetSessionVars().SQLKiller.HandleSignal(); err != nil { + return err + } + failpoint.Inject("mockSlowAnalyzeV2", func() { + time.Sleep(1000 * time.Second) + }) + + data, err := handler.nextRaw(context.TODO()) + if err != nil { + return errors.Trace(err) + } + if data == nil { + break + } + + memTracker.Consume(int64(cap(data))) + mergeTaskCh <- data + } + + return nil +} diff --git a/pkg/executor/analyze_idx.go b/pkg/executor/analyze_idx.go index 1c26be7782d4f..65b1850491e6d 100644 --- a/pkg/executor/analyze_idx.go +++ b/pkg/executor/analyze_idx.go @@ -180,11 +180,11 @@ func (e *AnalyzeIndexExec) fetchAnalyzeResult(ranges []*ranger.Range, isNullRang } func (e *AnalyzeIndexExec) buildStatsFromResult(result distsql.SelectResult, needCMS bool) (*statistics.Histogram, *statistics.CMSketch, *statistics.FMSketch, *statistics.TopN, error) { - failpoint.Inject("buildStatsFromResult", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("buildStatsFromResult")); _err_ == nil { if val.(bool) { - failpoint.Return(nil, nil, nil, nil, errors.New("mock buildStatsFromResult error")) + return nil, nil, nil, nil, errors.New("mock buildStatsFromResult error") } - }) + } hist := &statistics.Histogram{} var cms *statistics.CMSketch var topn *statistics.TopN @@ -198,18 +198,18 @@ func (e *AnalyzeIndexExec) buildStatsFromResult(result distsql.SelectResult, nee statsVer = int(*e.analyzePB.IdxReq.Version) } for { - failpoint.Inject("mockKillRunningAnalyzeIndexJob", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockKillRunningAnalyzeIndexJob")); _err_ == nil { dom := domain.GetDomain(e.ctx) for _, id := range handleutil.GlobalAutoAnalyzeProcessList.All() { dom.SysProcTracker().KillSysProcess(id) } - }) + } if err := e.ctx.GetSessionVars().SQLKiller.HandleSignal(); err != nil { return nil, nil, nil, nil, err } - failpoint.Inject("mockSlowAnalyzeIndex", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockSlowAnalyzeIndex")); _err_ == nil { time.Sleep(1000 * time.Second) - }) + } data, err := result.NextRaw(context.TODO()) if err != nil { return nil, nil, nil, nil, err diff --git a/pkg/executor/analyze_idx.go__failpoint_stash__ b/pkg/executor/analyze_idx.go__failpoint_stash__ new file mode 100644 index 0000000000000..1c26be7782d4f --- /dev/null +++ b/pkg/executor/analyze_idx.go__failpoint_stash__ @@ -0,0 +1,344 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package executor + +import ( + "context" + "math" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/distsql" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/statistics" + handleutil "github.com/pingcap/tidb/pkg/statistics/handle/util" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/ranger" + "github.com/pingcap/tipb/go-tipb" + "go.uber.org/zap" +) + +// AnalyzeIndexExec represents analyze index push down executor. +type AnalyzeIndexExec struct { + baseAnalyzeExec + + idxInfo *model.IndexInfo + isCommonHandle bool + result distsql.SelectResult + countNullRes distsql.SelectResult +} + +func analyzeIndexPushdown(idxExec *AnalyzeIndexExec) *statistics.AnalyzeResults { + ranges := ranger.FullRange() + // For single-column index, we do not load null rows from TiKV, so the built histogram would not include + // null values, and its `NullCount` would be set by result of another distsql call to get null rows. + // For multi-column index, we cannot define null for the rows, so we still use full range, and the rows + // containing null fields would exist in built histograms. Note that, the `NullCount` of histograms for + // multi-column index is always 0 then. + if len(idxExec.idxInfo.Columns) == 1 { + ranges = ranger.FullNotNullRange() + } + hist, cms, fms, topN, err := idxExec.buildStats(ranges, true) + if err != nil { + return &statistics.AnalyzeResults{Err: err, Job: idxExec.job} + } + var statsVer = statistics.Version1 + if idxExec.analyzePB.IdxReq.Version != nil { + statsVer = int(*idxExec.analyzePB.IdxReq.Version) + } + idxResult := &statistics.AnalyzeResult{ + Hist: []*statistics.Histogram{hist}, + TopNs: []*statistics.TopN{topN}, + Fms: []*statistics.FMSketch{fms}, + IsIndex: 1, + } + if statsVer != statistics.Version2 { + idxResult.Cms = []*statistics.CMSketch{cms} + } + cnt := hist.NullCount + if hist.Len() > 0 { + cnt += hist.Buckets[hist.Len()-1].Count + } + if topN.TotalCount() > 0 { + cnt += int64(topN.TotalCount()) + } + result := &statistics.AnalyzeResults{ + TableID: idxExec.tableID, + Ars: []*statistics.AnalyzeResult{idxResult}, + Job: idxExec.job, + StatsVer: statsVer, + Count: cnt, + Snapshot: idxExec.snapshot, + } + if idxExec.idxInfo.MVIndex { + result.ForMVIndex = true + } + return result +} + +func (e *AnalyzeIndexExec) buildStats(ranges []*ranger.Range, considerNull bool) (hist *statistics.Histogram, cms *statistics.CMSketch, fms *statistics.FMSketch, topN *statistics.TopN, err error) { + if err = e.open(ranges, considerNull); err != nil { + return nil, nil, nil, nil, err + } + defer func() { + err1 := closeAll(e.result, e.countNullRes) + if err == nil { + err = err1 + } + }() + hist, cms, fms, topN, err = e.buildStatsFromResult(e.result, true) + if err != nil { + return nil, nil, nil, nil, err + } + if e.countNullRes != nil { + nullHist, _, _, _, err := e.buildStatsFromResult(e.countNullRes, false) + if err != nil { + return nil, nil, nil, nil, err + } + if l := nullHist.Len(); l > 0 { + hist.NullCount = nullHist.Buckets[l-1].Count + } + } + hist.ID = e.idxInfo.ID + return hist, cms, fms, topN, nil +} + +func (e *AnalyzeIndexExec) open(ranges []*ranger.Range, considerNull bool) error { + err := 
e.fetchAnalyzeResult(ranges, false) + if err != nil { + return err + } + if considerNull && len(e.idxInfo.Columns) == 1 { + ranges = ranger.NullRange() + err = e.fetchAnalyzeResult(ranges, true) + if err != nil { + return err + } + } + return nil +} + +// fetchAnalyzeResult builds and dispatches the `kv.Request` from given ranges, and stores the `SelectResult` +// in corresponding fields based on the input `isNullRange` argument, which indicates if the range is the +// special null range for single-column index to get the null count. +func (e *AnalyzeIndexExec) fetchAnalyzeResult(ranges []*ranger.Range, isNullRange bool) error { + var builder distsql.RequestBuilder + var kvReqBuilder *distsql.RequestBuilder + if e.isCommonHandle && e.idxInfo.Primary { + kvReqBuilder = builder.SetHandleRangesForTables(e.ctx.GetDistSQLCtx(), []int64{e.tableID.GetStatisticsID()}, true, ranges) + } else { + kvReqBuilder = builder.SetIndexRangesForTables(e.ctx.GetDistSQLCtx(), []int64{e.tableID.GetStatisticsID()}, e.idxInfo.ID, ranges) + } + kvReqBuilder.SetResourceGroupTagger(e.ctx.GetSessionVars().StmtCtx.GetResourceGroupTagger()) + startTS := uint64(math.MaxUint64) + isoLevel := kv.RC + if e.ctx.GetSessionVars().EnableAnalyzeSnapshot { + startTS = e.snapshot + isoLevel = kv.SI + } + kvReq, err := kvReqBuilder. + SetAnalyzeRequest(e.analyzePB, isoLevel). + SetStartTS(startTS). + SetKeepOrder(true). + SetConcurrency(e.concurrency). + SetResourceGroupName(e.ctx.GetSessionVars().StmtCtx.ResourceGroupName). + SetExplicitRequestSourceType(e.ctx.GetSessionVars().ExplicitRequestSourceType). + Build() + if err != nil { + return err + } + ctx := context.TODO() + result, err := distsql.Analyze(ctx, e.ctx.GetClient(), kvReq, e.ctx.GetSessionVars().KVVars, e.ctx.GetSessionVars().InRestrictedSQL, e.ctx.GetDistSQLCtx()) + if err != nil { + return err + } + if isNullRange { + e.countNullRes = result + } else { + e.result = result + } + return nil +} + +func (e *AnalyzeIndexExec) buildStatsFromResult(result distsql.SelectResult, needCMS bool) (*statistics.Histogram, *statistics.CMSketch, *statistics.FMSketch, *statistics.TopN, error) { + failpoint.Inject("buildStatsFromResult", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(nil, nil, nil, nil, errors.New("mock buildStatsFromResult error")) + } + }) + hist := &statistics.Histogram{} + var cms *statistics.CMSketch + var topn *statistics.TopN + if needCMS { + cms = statistics.NewCMSketch(int32(e.opts[ast.AnalyzeOptCMSketchDepth]), int32(e.opts[ast.AnalyzeOptCMSketchWidth])) + topn = statistics.NewTopN(int(e.opts[ast.AnalyzeOptNumTopN])) + } + fms := statistics.NewFMSketch(statistics.MaxSketchSize) + statsVer := statistics.Version1 + if e.analyzePB.IdxReq.Version != nil { + statsVer = int(*e.analyzePB.IdxReq.Version) + } + for { + failpoint.Inject("mockKillRunningAnalyzeIndexJob", func() { + dom := domain.GetDomain(e.ctx) + for _, id := range handleutil.GlobalAutoAnalyzeProcessList.All() { + dom.SysProcTracker().KillSysProcess(id) + } + }) + if err := e.ctx.GetSessionVars().SQLKiller.HandleSignal(); err != nil { + return nil, nil, nil, nil, err + } + failpoint.Inject("mockSlowAnalyzeIndex", func() { + time.Sleep(1000 * time.Second) + }) + data, err := result.NextRaw(context.TODO()) + if err != nil { + return nil, nil, nil, nil, err + } + if data == nil { + break + } + resp := &tipb.AnalyzeIndexResp{} + err = resp.Unmarshal(data) + if err != nil { + return nil, nil, nil, nil, err + } + hist, cms, fms, topn, err = updateIndexResult(e.ctx, resp, e.job, hist, cms, 
fms, topn, + e.idxInfo, int(e.opts[ast.AnalyzeOptNumBuckets]), int(e.opts[ast.AnalyzeOptNumTopN]), statsVer) + if err != nil { + return nil, nil, nil, nil, err + } + } + if needCMS && topn.TotalCount() > 0 { + hist.RemoveVals(topn.TopN) + } + if statsVer == statistics.Version2 { + hist.StandardizeForV2AnalyzeIndex() + } + if needCMS && cms != nil { + cms.CalcDefaultValForAnalyze(uint64(hist.NDV)) + } + return hist, cms, fms, topn, nil +} + +func (e *AnalyzeIndexExec) buildSimpleStats(ranges []*ranger.Range, considerNull bool) (fms *statistics.FMSketch, nullHist *statistics.Histogram, err error) { + if err = e.open(ranges, considerNull); err != nil { + return nil, nil, err + } + defer func() { + err1 := closeAll(e.result, e.countNullRes) + if err == nil { + err = err1 + } + }() + _, _, fms, _, err = e.buildStatsFromResult(e.result, false) + if e.countNullRes != nil { + nullHist, _, _, _, err := e.buildStatsFromResult(e.countNullRes, false) + if err != nil { + return nil, nil, err + } + if l := nullHist.Len(); l > 0 { + return fms, nullHist, nil + } + } + return fms, nil, nil +} + +func analyzeIndexNDVPushDown(idxExec *AnalyzeIndexExec) *statistics.AnalyzeResults { + ranges := ranger.FullRange() + // For single-column index, we do not load null rows from TiKV, so the built histogram would not include + // null values, and its `NullCount` would be set by result of another distsql call to get null rows. + // For multi-column index, we cannot define null for the rows, so we still use full range, and the rows + // containing null fields would exist in built histograms. Note that, the `NullCount` of histograms for + // multi-column index is always 0 then. + if len(idxExec.idxInfo.Columns) == 1 { + ranges = ranger.FullNotNullRange() + } + fms, nullHist, err := idxExec.buildSimpleStats(ranges, len(idxExec.idxInfo.Columns) == 1) + if err != nil { + return &statistics.AnalyzeResults{Err: err, Job: idxExec.job} + } + result := &statistics.AnalyzeResult{ + Fms: []*statistics.FMSketch{fms}, + // We use histogram to get the Index's ID. + Hist: []*statistics.Histogram{statistics.NewHistogram(idxExec.idxInfo.ID, 0, 0, statistics.Version1, types.NewFieldType(mysql.TypeBlob), 0, 0)}, + IsIndex: 1, + } + r := &statistics.AnalyzeResults{ + TableID: idxExec.tableID, + Ars: []*statistics.AnalyzeResult{result}, + Job: idxExec.job, + // TODO: avoid reusing Version1. 
+ StatsVer: statistics.Version1, + } + if nullHist != nil && nullHist.Len() > 0 { + r.Count = nullHist.Buckets[nullHist.Len()-1].Count + } + return r +} + +func updateIndexResult( + ctx sessionctx.Context, + resp *tipb.AnalyzeIndexResp, + job *statistics.AnalyzeJob, + hist *statistics.Histogram, + cms *statistics.CMSketch, + fms *statistics.FMSketch, + topn *statistics.TopN, + idxInfo *model.IndexInfo, + numBuckets int, + numTopN int, + statsVer int, +) ( + *statistics.Histogram, + *statistics.CMSketch, + *statistics.FMSketch, + *statistics.TopN, + error, +) { + var err error + needCMS := cms != nil + respHist := statistics.HistogramFromProto(resp.Hist) + if job != nil { + statsHandle := domain.GetDomain(ctx).StatsHandle() + statsHandle.UpdateAnalyzeJobProgress(job, int64(respHist.TotalRowCount())) + } + hist, err = statistics.MergeHistograms(ctx.GetSessionVars().StmtCtx, hist, respHist, numBuckets, statsVer) + if err != nil { + return nil, nil, nil, nil, err + } + if needCMS { + if resp.Cms == nil { + logutil.Logger(context.TODO()).Warn("nil CMS in response", zap.String("table", idxInfo.Table.O), zap.String("index", idxInfo.Name.O)) + } else { + cm, tmpTopN := statistics.CMSketchAndTopNFromProto(resp.Cms) + if err := cms.MergeCMSketch(cm); err != nil { + return nil, nil, nil, nil, err + } + statistics.MergeTopNAndUpdateCMSketch(topn, tmpTopN, cms, uint32(numTopN)) + } + } + if fms != nil && resp.Collector != nil && resp.Collector.FmSketch != nil { + fms.MergeFMSketch(statistics.FMSketchFromProto(resp.Collector.FmSketch)) + } + return hist, cms, fms, topn, nil +} diff --git a/pkg/executor/batch_point_get.go b/pkg/executor/batch_point_get.go index 88a2a442f158d..f7d8a425590ac 100644 --- a/pkg/executor/batch_point_get.go +++ b/pkg/executor/batch_point_get.go @@ -327,14 +327,14 @@ func (e *BatchPointGetExec) initialize(ctx context.Context) error { // 2. Session B create an UPDATE query to update the record that will be obtained in step 1 // 3. Then point get retrieve data from backend after step 2 finished // 4. Check the result - failpoint.InjectContext(ctx, "batchPointGetRepeatableReadTest-step1", func() { + if _, _err_ := failpoint.EvalContext(ctx, _curpkg_("batchPointGetRepeatableReadTest-step1")); _err_ == nil { if ch, ok := ctx.Value("batchPointGetRepeatableReadTest").(chan struct{}); ok { // Make `UPDATE` continue close(ch) } // Wait `UPDATE` finished - failpoint.InjectContext(ctx, "batchPointGetRepeatableReadTest-step2", nil) - }) + failpoint.EvalContext(ctx, _curpkg_("batchPointGetRepeatableReadTest-step2")) + } } else if e.keepOrder { less := func(i, j kv.Handle) int { if e.desc { diff --git a/pkg/executor/batch_point_get.go__failpoint_stash__ b/pkg/executor/batch_point_get.go__failpoint_stash__ new file mode 100644 index 0000000000000..88a2a442f158d --- /dev/null +++ b/pkg/executor/batch_point_get.go__failpoint_stash__ @@ -0,0 +1,528 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package executor + +import ( + "context" + "fmt" + "slices" + "sync/atomic" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + driver "github.com/pingcap/tidb/pkg/store/driver/txn" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/hack" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/logutil/consistency" + "github.com/pingcap/tidb/pkg/util/rowcodec" + "github.com/tikv/client-go/v2/tikvrpc" +) + +// BatchPointGetExec executes a bunch of point select queries. +type BatchPointGetExec struct { + exec.BaseExecutor + indexUsageReporter *exec.IndexUsageReporter + + tblInfo *model.TableInfo + idxInfo *model.IndexInfo + handles []kv.Handle + // table/partition IDs for handle or index read + // (can be secondary unique key, + // and need lookup through handle) + planPhysIDs []int64 + // If != 0 then it is a single partition under Static Prune mode. + singlePartID int64 + partitionNames []model.CIStr + idxVals [][]types.Datum + txn kv.Transaction + lock bool + waitTime int64 + inited uint32 + values [][]byte + index int + rowDecoder *rowcodec.ChunkDecoder + keepOrder bool + desc bool + batchGetter kv.BatchGetter + + columns []*model.ColumnInfo + // virtualColumnIndex records all the indices of virtual columns and sort them in definition + // to make sure we can compute the virtual column in right order. + virtualColumnIndex []int + + // virtualColumnRetFieldTypes records the RetFieldTypes of virtual columns. + virtualColumnRetFieldTypes []*types.FieldType + + snapshot kv.Snapshot + stats *runtimeStatsWithSnapshot +} + +// buildVirtualColumnInfo saves virtual column indices and sort them in definition order +func (e *BatchPointGetExec) buildVirtualColumnInfo() { + e.virtualColumnIndex = buildVirtualColumnIndex(e.Schema(), e.columns) + if len(e.virtualColumnIndex) > 0 { + e.virtualColumnRetFieldTypes = make([]*types.FieldType, len(e.virtualColumnIndex)) + for i, idx := range e.virtualColumnIndex { + e.virtualColumnRetFieldTypes[i] = e.Schema().Columns[idx].RetType + } + } +} + +// Open implements the Executor interface. +func (e *BatchPointGetExec) Open(context.Context) error { + sessVars := e.Ctx().GetSessionVars() + txnCtx := sessVars.TxnCtx + txn, err := e.Ctx().Txn(false) + if err != nil { + return err + } + e.txn = txn + + setOptionForTopSQL(e.Ctx().GetSessionVars().StmtCtx, e.snapshot) + var batchGetter kv.BatchGetter = e.snapshot + if txn.Valid() { + lock := e.tblInfo.Lock + if e.lock { + batchGetter = driver.NewBufferBatchGetter(txn.GetMemBuffer(), &PessimisticLockCacheGetter{txnCtx: txnCtx}, e.snapshot) + } else if lock != nil && (lock.Tp == model.TableLockRead || lock.Tp == model.TableLockReadOnly) && e.Ctx().GetSessionVars().EnablePointGetCache { + batchGetter = newCacheBatchGetter(e.Ctx(), e.tblInfo.ID, e.snapshot) + } else { + batchGetter = driver.NewBufferBatchGetter(txn.GetMemBuffer(), nil, e.snapshot) + } + } + e.batchGetter = batchGetter + return nil +} + +// CacheTable always use memBuffer in session as snapshot. 
+// cacheTableSnapshot inherits kv.Snapshot and override the BatchGet methods and Get methods. +type cacheTableSnapshot struct { + kv.Snapshot + memBuffer kv.MemBuffer +} + +func (s cacheTableSnapshot) BatchGet(ctx context.Context, keys []kv.Key) (map[string][]byte, error) { + values := make(map[string][]byte) + if s.memBuffer == nil { + return values, nil + } + + for _, key := range keys { + val, err := s.memBuffer.Get(ctx, key) + if kv.ErrNotExist.Equal(err) { + continue + } + + if err != nil { + return nil, err + } + + if len(val) == 0 { + continue + } + + values[string(key)] = val + } + + return values, nil +} + +func (s cacheTableSnapshot) Get(ctx context.Context, key kv.Key) ([]byte, error) { + return s.memBuffer.Get(ctx, key) +} + +// MockNewCacheTableSnapShot only serves for test. +func MockNewCacheTableSnapShot(snapshot kv.Snapshot, memBuffer kv.MemBuffer) *cacheTableSnapshot { + return &cacheTableSnapshot{snapshot, memBuffer} +} + +// Close implements the Executor interface. +func (e *BatchPointGetExec) Close() error { + if e.RuntimeStats() != nil { + defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), e.stats) + } + if e.RuntimeStats() != nil && e.snapshot != nil { + e.snapshot.SetOption(kv.CollectRuntimeStats, nil) + } + if e.indexUsageReporter != nil && e.idxInfo != nil { + kvReqTotal := e.stats.GetCmdRPCCount(tikvrpc.CmdBatchGet) + // We cannot distinguish how many rows are coming from each partition. Here, we calculate all index usages + // percentage according to the row counts for the whole table. + e.indexUsageReporter.ReportPointGetIndexUsage(e.tblInfo.ID, e.tblInfo.ID, e.idxInfo.ID, e.ID(), kvReqTotal) + } + e.inited = 0 + e.index = 0 + return nil +} + +// Next implements the Executor interface. +func (e *BatchPointGetExec) Next(ctx context.Context, req *chunk.Chunk) error { + req.Reset() + if atomic.CompareAndSwapUint32(&e.inited, 0, 1) { + if err := e.initialize(ctx); err != nil { + return err + } + if e.lock { + e.UpdateDeltaForTableID(e.tblInfo.ID) + } + } + + if e.index >= len(e.values) { + return nil + } + + schema := e.Schema() + sctx := e.BaseExecutor.Ctx() + start := e.index + for !req.IsFull() && e.index < len(e.values) { + handle, val := e.handles[e.index], e.values[e.index] + err := DecodeRowValToChunk(sctx, schema, e.tblInfo, handle, val, req, e.rowDecoder) + if err != nil { + return err + } + e.index++ + } + + err := fillRowChecksum(sctx, start, e.index, schema, e.tblInfo, e.values, e.handles, req, nil) + if err != nil { + return err + } + err = table.FillVirtualColumnValue(e.virtualColumnRetFieldTypes, e.virtualColumnIndex, schema.Columns, e.columns, sctx.GetExprCtx(), req) + if err != nil { + return err + } + return nil +} + +func (e *BatchPointGetExec) initialize(ctx context.Context) error { + var handleVals map[string][]byte + var indexKeys []kv.Key + var err error + batchGetter := e.batchGetter + rc := e.Ctx().GetSessionVars().IsPessimisticReadConsistency() + if e.idxInfo != nil && !isCommonHandleRead(e.tblInfo, e.idxInfo) { + // `SELECT a, b FROM t WHERE (a, b) IN ((1, 2), (1, 2), (2, 1), (1, 2))` should not return duplicated rows + dedup := make(map[hack.MutableString]struct{}) + toFetchIndexKeys := make([]kv.Key, 0, len(e.idxVals)) + for i, idxVals := range e.idxVals { + physID := e.tblInfo.ID + if e.singlePartID != 0 { + physID = e.singlePartID + } else if len(e.planPhysIDs) > i { + physID = e.planPhysIDs[i] + } + idxKey, err1 := plannercore.EncodeUniqueIndexKey(e.Ctx(), e.tblInfo, e.idxInfo, idxVals, physID) + if err1 != 
nil && !kv.ErrNotExist.Equal(err1) { + return err1 + } + if idxKey == nil { + continue + } + s := hack.String(idxKey) + if _, found := dedup[s]; found { + continue + } + dedup[s] = struct{}{} + toFetchIndexKeys = append(toFetchIndexKeys, idxKey) + } + if e.keepOrder { + // TODO: if multiple partitions, then the IDs needs to be + // in the same order as the index keys + // and should skip table id part when comparing + intest.Assert(e.singlePartID != 0 || len(e.planPhysIDs) <= 1 || e.idxInfo.Global) + slices.SortFunc(toFetchIndexKeys, func(i, j kv.Key) int { + if e.desc { + return j.Cmp(i) + } + return i.Cmp(j) + }) + } + + // lock all keys in repeatable read isolation. + // for read consistency, only lock exist keys, + // indexKeys will be generated after getting handles. + if !rc { + indexKeys = toFetchIndexKeys + } else { + indexKeys = make([]kv.Key, 0, len(toFetchIndexKeys)) + } + + // SELECT * FROM t WHERE x IN (null), in this case there is no key. + if len(toFetchIndexKeys) == 0 { + return nil + } + + // Fetch all handles. + handleVals, err = batchGetter.BatchGet(ctx, toFetchIndexKeys) + if err != nil { + return err + } + + e.handles = make([]kv.Handle, 0, len(toFetchIndexKeys)) + if e.tblInfo.Partition != nil { + e.planPhysIDs = e.planPhysIDs[:0] + } + for _, key := range toFetchIndexKeys { + handleVal := handleVals[string(key)] + if len(handleVal) == 0 { + continue + } + handle, err1 := tablecodec.DecodeHandleInIndexValue(handleVal) + if err1 != nil { + return err1 + } + if e.tblInfo.Partition != nil { + var pid int64 + if e.idxInfo.Global { + _, pid, err = codec.DecodeInt(tablecodec.SplitIndexValue(handleVal).PartitionID) + if err != nil { + return err + } + if e.singlePartID != 0 && e.singlePartID != pid { + continue + } + if !matchPartitionNames(pid, e.partitionNames, e.tblInfo.GetPartitionInfo()) { + continue + } + e.planPhysIDs = append(e.planPhysIDs, pid) + } else { + pid = tablecodec.DecodeTableID(key) + e.planPhysIDs = append(e.planPhysIDs, pid) + } + if e.lock { + e.UpdateDeltaForTableID(pid) + } + } + e.handles = append(e.handles, handle) + if rc { + indexKeys = append(indexKeys, key) + } + } + + // The injection is used to simulate following scenario: + // 1. Session A create a point get query but pause before second time `GET` kv from backend + // 2. Session B create an UPDATE query to update the record that will be obtained in step 1 + // 3. Then point get retrieve data from backend after step 2 finished + // 4. 
Check the result + failpoint.InjectContext(ctx, "batchPointGetRepeatableReadTest-step1", func() { + if ch, ok := ctx.Value("batchPointGetRepeatableReadTest").(chan struct{}); ok { + // Make `UPDATE` continue + close(ch) + } + // Wait `UPDATE` finished + failpoint.InjectContext(ctx, "batchPointGetRepeatableReadTest-step2", nil) + }) + } else if e.keepOrder { + less := func(i, j kv.Handle) int { + if e.desc { + return j.Compare(i) + } + return i.Compare(j) + } + if e.tblInfo.PKIsHandle && mysql.HasUnsignedFlag(e.tblInfo.GetPkColInfo().GetFlag()) { + uintComparator := func(i, h kv.Handle) int { + if !i.IsInt() || !h.IsInt() { + panic(fmt.Sprintf("both handles need be IntHandle, but got %T and %T ", i, h)) + } + ihVal := uint64(i.IntValue()) + hVal := uint64(h.IntValue()) + if ihVal > hVal { + return 1 + } + if ihVal < hVal { + return -1 + } + return 0 + } + less = func(i, j kv.Handle) int { + if e.desc { + return uintComparator(j, i) + } + return uintComparator(i, j) + } + } + slices.SortFunc(e.handles, less) + // TODO: if partitioned table, sorting the handles would also + // need to have the physIDs rearranged in the same order! + intest.Assert(e.singlePartID != 0 || len(e.planPhysIDs) <= 1) + } + + keys := make([]kv.Key, 0, len(e.handles)) + newHandles := make([]kv.Handle, 0, len(e.handles)) + for i, handle := range e.handles { + tID := e.tblInfo.ID + if e.singlePartID != 0 { + tID = e.singlePartID + } else if len(e.planPhysIDs) > 0 { + // Direct handle read + tID = e.planPhysIDs[i] + } + if tID <= 0 { + // not matching any partition + continue + } + key := tablecodec.EncodeRowKeyWithHandle(tID, handle) + keys = append(keys, key) + newHandles = append(newHandles, handle) + } + e.handles = newHandles + + var values map[string][]byte + // Lock keys (include exists and non-exists keys) before fetch all values for Repeatable Read Isolation. + if e.lock && !rc { + lockKeys := make([]kv.Key, len(keys)+len(indexKeys)) + copy(lockKeys, keys) + copy(lockKeys[len(keys):], indexKeys) + err = LockKeys(ctx, e.Ctx(), e.waitTime, lockKeys...) + if err != nil { + return err + } + } + // Fetch all values. 
+ values, err = batchGetter.BatchGet(ctx, keys) + if err != nil { + return err + } + handles := make([]kv.Handle, 0, len(values)) + var existKeys []kv.Key + if e.lock && rc { + existKeys = make([]kv.Key, 0, 2*len(values)) + } + e.values = make([][]byte, 0, len(values)) + for i, key := range keys { + val := values[string(key)] + if len(val) == 0 { + if e.idxInfo != nil && (!e.tblInfo.IsCommonHandle || !e.idxInfo.Primary) && + !e.Ctx().GetSessionVars().StmtCtx.WeakConsistency { + return (&consistency.Reporter{ + HandleEncode: func(_ kv.Handle) kv.Key { + return key + }, + IndexEncode: func(_ *consistency.RecordData) kv.Key { + return indexKeys[i] + }, + Tbl: e.tblInfo, + Idx: e.idxInfo, + EnableRedactLog: e.Ctx().GetSessionVars().EnableRedactLog, + Storage: e.Ctx().GetStore(), + }).ReportLookupInconsistent(ctx, + 1, 0, + e.handles[i:i+1], + e.handles, + []consistency.RecordData{{}}, + ) + } + continue + } + e.values = append(e.values, val) + handles = append(handles, e.handles[i]) + if e.lock && rc { + existKeys = append(existKeys, key) + // when e.handles is set in builder directly, index should be primary key and the plan is CommonHandleRead + // with clustered index enabled, indexKeys is empty in this situation + // lock primary key for clustered index table is redundant + if len(indexKeys) != 0 { + existKeys = append(existKeys, indexKeys[i]) + } + } + } + // Lock exists keys only for Read Committed Isolation. + if e.lock && rc { + err = LockKeys(ctx, e.Ctx(), e.waitTime, existKeys...) + if err != nil { + return err + } + } + e.handles = handles + return nil +} + +// LockKeys locks the keys for pessimistic transaction. +func LockKeys(ctx context.Context, sctx sessionctx.Context, lockWaitTime int64, keys ...kv.Key) error { + txnCtx := sctx.GetSessionVars().TxnCtx + lctx, err := newLockCtx(sctx, lockWaitTime, len(keys)) + if err != nil { + return err + } + if txnCtx.IsPessimistic { + lctx.InitReturnValues(len(keys)) + } + err = doLockKeys(ctx, sctx, lctx, keys...) + if err != nil { + return err + } + if txnCtx.IsPessimistic { + // When doLockKeys returns without error, no other goroutines access the map, + // it's safe to read it without mutex. + for _, key := range keys { + if v, ok := lctx.GetValueNotLocked(key); ok { + txnCtx.SetPessimisticLockCache(key, v) + } + } + } + return nil +} + +// PessimisticLockCacheGetter implements the kv.Getter interface. +// It is used as a middle cache to construct the BufferedBatchGetter. +type PessimisticLockCacheGetter struct { + txnCtx *variable.TransactionContext +} + +// Get implements the kv.Getter interface. 
+func (getter *PessimisticLockCacheGetter) Get(_ context.Context, key kv.Key) ([]byte, error) { + val, ok := getter.txnCtx.GetKeyInPessimisticLockCache(key) + if ok { + return val, nil + } + return nil, kv.ErrNotExist +} + +type cacheBatchGetter struct { + ctx sessionctx.Context + tid int64 + snapshot kv.Snapshot +} + +func (b *cacheBatchGetter) BatchGet(ctx context.Context, keys []kv.Key) (map[string][]byte, error) { + cacheDB := b.ctx.GetStore().GetMemCache() + vals := make(map[string][]byte) + for _, key := range keys { + val, err := cacheDB.UnionGet(ctx, b.tid, b.snapshot, key) + if err != nil { + if !kv.ErrNotExist.Equal(err) { + return nil, err + } + continue + } + vals[string(key)] = val + } + return vals, nil +} + +func newCacheBatchGetter(ctx sessionctx.Context, tid int64, snapshot kv.Snapshot) *cacheBatchGetter { + return &cacheBatchGetter{ctx, tid, snapshot} +} diff --git a/pkg/executor/binding__failpoint_binding__.go b/pkg/executor/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..4ed4af3ddbf4a --- /dev/null +++ b/pkg/executor/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package executor + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/executor/brie.go b/pkg/executor/brie.go index 130f627bc6609..478ba583f4dd2 100644 --- a/pkg/executor/brie.go +++ b/pkg/executor/brie.go @@ -313,9 +313,9 @@ func (b *executorBuilder) buildBRIE(s *ast.BRIEStmt, schema *expression.Schema) } store := tidbCfg.Store - failpoint.Inject("modifyStore", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("modifyStore")); _err_ == nil { store = v.(string) - }) + } if store != "tikv" { b.err = errors.Errorf("%s requires tikv store, not %s", s.Kind, store) return nil @@ -581,13 +581,13 @@ func (e *BRIEExec) Next(ctx context.Context, req *chunk.Chunk) error { e.info.queueTime = types.CurrentTime(mysql.TypeDatetime) taskCtx, taskID := bq.registerTask(ctx, e.info) defer bq.cancelTask(taskID) - failpoint.Inject("block-on-brie", func() { + if _, _err_ := failpoint.Eval(_curpkg_("block-on-brie")); _err_ == nil { log.Warn("You shall not pass, nya. :3") <-taskCtx.Done() if taskCtx.Err() != nil { - failpoint.Return(taskCtx.Err()) + return taskCtx.Err() } - }) + } // manually monitor the Killed status... go func() { ticker := time.NewTicker(3 * time.Second) diff --git a/pkg/executor/brie.go__failpoint_stash__ b/pkg/executor/brie.go__failpoint_stash__ new file mode 100644 index 0000000000000..130f627bc6609 --- /dev/null +++ b/pkg/executor/brie.go__failpoint_stash__ @@ -0,0 +1,826 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
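[Editor's note] The analyze_idx.go and batch_point_get.go hunks above, together with the *__failpoint_stash__ copies and the generated binding__failpoint_binding__.go file, look like the output of enabling failpoints in place (e.g. via failpoint-ctl or an equivalent rewrite step) rather than hand-written changes: each failpoint.Inject marker is expanded into a plain if around failpoint.Eval, with _curpkg_ (defined in the binding file above) qualifying the failpoint name by package path, and the original source preserved in the stash file. The two forms, taken from the buildStatsFromResult hunk:

// Marker form, as written by developers (kept verbatim in the __failpoint_stash__ copy):
failpoint.Inject("mockSlowAnalyzeIndex", func() {
	time.Sleep(1000 * time.Second)
})

// Expanded form after the rewrite, as it appears in the patched analyze_idx.go:
if _, _err_ := failpoint.Eval(_curpkg_("mockSlowAnalyzeIndex")); _err_ == nil {
	time.Sleep(1000 * time.Second)
}
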
+ +package executor + +import ( + "bytes" + "context" + "fmt" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + backuppb "github.com/pingcap/kvproto/pkg/brpb" + "github.com/pingcap/kvproto/pkg/encryptionpb" + "github.com/pingcap/log" + "github.com/pingcap/tidb/br/pkg/glue" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/br/pkg/task" + "github.com/pingcap/tidb/br/pkg/task/show" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/format" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" + "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" + "github.com/pingcap/tidb/pkg/util/printer" + "github.com/pingcap/tidb/pkg/util/sem" + "github.com/pingcap/tidb/pkg/util/syncutil" + filter "github.com/pingcap/tidb/pkg/util/table-filter" + "github.com/tikv/client-go/v2/oracle" + pd "github.com/tikv/pd/client" + "go.uber.org/zap" +) + +const clearInterval = 10 * time.Minute + +var outdatedDuration = types.Duration{ + Duration: 30 * time.Minute, + Fsp: types.DefaultFsp, +} + +// brieTaskProgress tracks a task's current progress. +type brieTaskProgress struct { + // current progress of the task. + // this field is atomically updated outside of the lock below. + current int64 + + // lock is the mutex protected the two fields below. + lock syncutil.Mutex + // cmd is the name of the step the BRIE task is currently performing. + cmd string + // total is the total progress of the task. + // the percentage of completeness is `(100%) * current / total`. + total int64 +} + +// Inc implements glue.Progress +func (p *brieTaskProgress) Inc() { + atomic.AddInt64(&p.current, 1) +} + +// IncBy implements glue.Progress +func (p *brieTaskProgress) IncBy(cnt int64) { + atomic.AddInt64(&p.current, cnt) +} + +// GetCurrent implements glue.Progress +func (p *brieTaskProgress) GetCurrent() int64 { + return atomic.LoadInt64(&p.current) +} + +// Close implements glue.Progress +func (p *brieTaskProgress) Close() { + p.lock.Lock() + current := atomic.LoadInt64(&p.current) + if current < p.total { + p.cmd = fmt.Sprintf("%s Canceled", p.cmd) + } + atomic.StoreInt64(&p.current, p.total) + p.lock.Unlock() +} + +type brieTaskInfo struct { + id uint64 + query string + queueTime types.Time + execTime types.Time + finishTime types.Time + kind ast.BRIEKind + storage string + connID uint64 + backupTS uint64 + restoreTS uint64 + archiveSize uint64 + message string +} + +type brieQueueItem struct { + info *brieTaskInfo + progress *brieTaskProgress + cancel func() +} + +type brieQueue struct { + nextID uint64 + tasks sync.Map + + lastClearTime time.Time + + workerCh chan struct{} +} + +// globalBRIEQueue is the BRIE execution queue. Only one BRIE task can be executed each time. +// TODO: perhaps copy the DDL Job queue so only one task can be executed in the whole cluster. 
+var globalBRIEQueue = &brieQueue{ + workerCh: make(chan struct{}, 1), +} + +// ResetGlobalBRIEQueueForTest resets the ID allocation for the global BRIE queue. +// In some of our test cases, we rely on the ID is allocated from 1. +// When batch executing test cases, the assumation may be broken and make the cases fail. +func ResetGlobalBRIEQueueForTest() { + globalBRIEQueue = &brieQueue{ + workerCh: make(chan struct{}, 1), + } +} + +// registerTask registers a BRIE task in the queue. +func (bq *brieQueue) registerTask( + ctx context.Context, + info *brieTaskInfo, +) (context.Context, uint64) { + taskCtx, taskCancel := context.WithCancel(ctx) + item := &brieQueueItem{ + info: info, + cancel: taskCancel, + progress: &brieTaskProgress{ + cmd: "Wait", + total: 1, + }, + } + + taskID := atomic.AddUint64(&bq.nextID, 1) + bq.tasks.Store(taskID, item) + info.id = taskID + + return taskCtx, taskID +} + +// query task queries a task from the queue. +func (bq *brieQueue) queryTask(taskID uint64) (*brieTaskInfo, bool) { + if item, ok := bq.tasks.Load(taskID); ok { + return item.(*brieQueueItem).info, true + } + return nil, false +} + +// acquireTask prepares to execute a BRIE task. Only one BRIE task can be +// executed at a time, and this function blocks until the task is ready. +// +// Returns an object to track the task's progress. +func (bq *brieQueue) acquireTask(taskCtx context.Context, taskID uint64) (*brieTaskProgress, error) { + // wait until we are at the front of the queue. + select { + case bq.workerCh <- struct{}{}: + if item, ok := bq.tasks.Load(taskID); ok { + return item.(*brieQueueItem).progress, nil + } + // cannot find task, perhaps it has been canceled. allow the next task to run. + bq.releaseTask() + return nil, errors.Errorf("backup/restore task %d is canceled", taskID) + case <-taskCtx.Done(): + return nil, taskCtx.Err() + } +} + +func (bq *brieQueue) releaseTask() { + <-bq.workerCh +} + +func (bq *brieQueue) cancelTask(taskID uint64) bool { + item, ok := bq.tasks.Load(taskID) + if !ok { + return false + } + i := item.(*brieQueueItem) + i.cancel() + i.progress.Close() + log.Info("BRIE job canceled.", zap.Uint64("ID", i.info.id)) + return true +} + +func (bq *brieQueue) clearTask(sc *stmtctx.StatementContext) { + if time.Since(bq.lastClearTime) < clearInterval { + return + } + + bq.lastClearTime = time.Now() + currTime := types.CurrentTime(mysql.TypeDatetime) + + bq.tasks.Range(func(key, value any) bool { + item := value.(*brieQueueItem) + if d := currTime.Sub(sc.TypeCtx(), &item.info.finishTime); d.Compare(outdatedDuration) > 0 { + bq.tasks.Delete(key) + } + return true + }) +} + +func (b *executorBuilder) parseTSString(ts string) (uint64, error) { + sc := stmtctx.NewStmtCtxWithTimeZone(b.ctx.GetSessionVars().Location()) + t, err := types.ParseTime(sc.TypeCtx(), ts, mysql.TypeTimestamp, types.MaxFsp) + if err != nil { + return 0, err + } + t1, err := t.GoTime(sc.TimeZone()) + if err != nil { + return 0, err + } + return oracle.GoTimeToTS(t1), nil +} + +func (b *executorBuilder) buildBRIE(s *ast.BRIEStmt, schema *expression.Schema) exec.Executor { + if s.Kind == ast.BRIEKindShowBackupMeta { + return execOnce(&showMetaExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, schema, 0), + showConfig: buildShowMetadataConfigFrom(s), + }) + } + + if s.Kind == ast.BRIEKindShowQuery { + return execOnce(&showQueryExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, schema, 0), + targetID: uint64(s.JobID), + }) + } + + if s.Kind == ast.BRIEKindCancelJob { + return &cancelJobExec{ + BaseExecutor: 
exec.NewBaseExecutor(b.ctx, schema, 0), + targetID: uint64(s.JobID), + } + } + + e := &BRIEExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, schema, 0), + info: &brieTaskInfo{ + kind: s.Kind, + }, + } + + tidbCfg := config.GetGlobalConfig() + tlsCfg := task.TLSConfig{ + CA: tidbCfg.Security.ClusterSSLCA, + Cert: tidbCfg.Security.ClusterSSLCert, + Key: tidbCfg.Security.ClusterSSLKey, + } + pds := strings.Split(tidbCfg.Path, ",") + cfg := task.DefaultConfig() + cfg.PD = pds + cfg.TLS = tlsCfg + + storageURL, err := storage.ParseRawURL(s.Storage) + if err != nil { + b.err = errors.Annotate(err, "invalid destination URL") + return nil + } + + switch storageURL.Scheme { + case "s3": + storage.ExtractQueryParameters(storageURL, &cfg.S3) + case "gs", "gcs": + storage.ExtractQueryParameters(storageURL, &cfg.GCS) + case "hdfs": + if sem.IsEnabled() { + // Storage is not permitted to be hdfs when SEM is enabled. + b.err = plannererrors.ErrNotSupportedWithSem.GenWithStackByArgs("hdfs storage") + return nil + } + case "local", "file", "": + if sem.IsEnabled() { + // Storage is not permitted to be local when SEM is enabled. + b.err = plannererrors.ErrNotSupportedWithSem.GenWithStackByArgs("local storage") + return nil + } + default: + } + + store := tidbCfg.Store + failpoint.Inject("modifyStore", func(v failpoint.Value) { + store = v.(string) + }) + if store != "tikv" { + b.err = errors.Errorf("%s requires tikv store, not %s", s.Kind, store) + return nil + } + + cfg.Storage = storageURL.String() + e.info.storage = cfg.Storage + + for _, opt := range s.Options { + switch opt.Tp { + case ast.BRIEOptionRateLimit: + cfg.RateLimit = opt.UintValue + case ast.BRIEOptionConcurrency: + cfg.Concurrency = uint32(opt.UintValue) + case ast.BRIEOptionChecksum: + cfg.Checksum = opt.UintValue != 0 + case ast.BRIEOptionSendCreds: + cfg.SendCreds = opt.UintValue != 0 + case ast.BRIEOptionChecksumConcurrency: + cfg.ChecksumConcurrency = uint(opt.UintValue) + case ast.BRIEOptionEncryptionKeyFile: + cfg.CipherInfo.CipherKey, err = task.GetCipherKeyContent("", opt.StrValue) + if err != nil { + b.err = err + return nil + } + case ast.BRIEOptionEncryptionMethod: + switch opt.StrValue { + case "aes128-ctr": + cfg.CipherInfo.CipherType = encryptionpb.EncryptionMethod_AES128_CTR + case "aes192-ctr": + cfg.CipherInfo.CipherType = encryptionpb.EncryptionMethod_AES192_CTR + case "aes256-ctr": + cfg.CipherInfo.CipherType = encryptionpb.EncryptionMethod_AES256_CTR + case "plaintext": + cfg.CipherInfo.CipherType = encryptionpb.EncryptionMethod_PLAINTEXT + default: + b.err = errors.Errorf("unsupported encryption method: %s", opt.StrValue) + return nil + } + } + } + + switch { + case len(s.Tables) != 0: + tables := make([]filter.Table, 0, len(s.Tables)) + for _, tbl := range s.Tables { + tables = append(tables, filter.Table{Name: tbl.Name.O, Schema: tbl.Schema.O}) + } + cfg.TableFilter = filter.NewTablesFilter(tables...) + case len(s.Schemas) != 0: + cfg.TableFilter = filter.NewSchemasFilter(s.Schemas...) + default: + cfg.TableFilter = filter.All() + } + + // table options are stored in original case, but comparison + // is expected to be performed insensitive. + cfg.TableFilter = filter.CaseInsensitive(cfg.TableFilter) + + // We cannot directly use the query string, or the secret may be print. + // NOTE: the ownership of `s.Storage` is taken here. 
+ s.Storage = e.info.storage + e.info.query = restoreQuery(s) + + switch s.Kind { + case ast.BRIEKindBackup: + bcfg := task.DefaultBackupConfig() + bcfg.Config = cfg + e.backupCfg = &bcfg + + for _, opt := range s.Options { + switch opt.Tp { + case ast.BRIEOptionLastBackupTS: + tso, err := b.parseTSString(opt.StrValue) + if err != nil { + b.err = err + return nil + } + e.backupCfg.LastBackupTS = tso + case ast.BRIEOptionLastBackupTSO: + e.backupCfg.LastBackupTS = opt.UintValue + case ast.BRIEOptionBackupTimeAgo: + e.backupCfg.TimeAgo = time.Duration(opt.UintValue) + case ast.BRIEOptionBackupTSO: + e.backupCfg.BackupTS = opt.UintValue + case ast.BRIEOptionBackupTS: + tso, err := b.parseTSString(opt.StrValue) + if err != nil { + b.err = err + return nil + } + e.backupCfg.BackupTS = tso + case ast.BRIEOptionCompression: + switch opt.StrValue { + case "zstd": + e.backupCfg.CompressionConfig.CompressionType = backuppb.CompressionType_ZSTD + case "snappy": + e.backupCfg.CompressionConfig.CompressionType = backuppb.CompressionType_SNAPPY + case "lz4": + e.backupCfg.CompressionConfig.CompressionType = backuppb.CompressionType_LZ4 + default: + b.err = errors.Errorf("unsupported compression type: %s", opt.StrValue) + return nil + } + case ast.BRIEOptionCompressionLevel: + e.backupCfg.CompressionConfig.CompressionLevel = int32(opt.UintValue) + case ast.BRIEOptionIgnoreStats: + e.backupCfg.IgnoreStats = opt.UintValue != 0 + } + } + + case ast.BRIEKindRestore: + rcfg := task.DefaultRestoreConfig() + rcfg.Config = cfg + e.restoreCfg = &rcfg + for _, opt := range s.Options { + switch opt.Tp { + case ast.BRIEOptionOnline: + e.restoreCfg.Online = opt.UintValue != 0 + case ast.BRIEOptionWaitTiflashReady: + e.restoreCfg.WaitTiflashReady = opt.UintValue != 0 + case ast.BRIEOptionWithSysTable: + e.restoreCfg.WithSysTable = opt.UintValue != 0 + case ast.BRIEOptionLoadStats: + e.restoreCfg.LoadStats = opt.UintValue != 0 + } + } + + default: + b.err = errors.Errorf("unsupported BRIE statement kind: %s", s.Kind) + return nil + } + + return e +} + +// oneshotExecutor wraps a executor, making its `Next` would only be called once. 
+type oneshotExecutor struct { + exec.Executor + finished bool +} + +func (o *oneshotExecutor) Next(ctx context.Context, req *chunk.Chunk) error { + if o.finished { + req.Reset() + return nil + } + + if err := o.Executor.Next(ctx, req); err != nil { + return err + } + o.finished = true + return nil +} + +func execOnce(ex exec.Executor) exec.Executor { + return &oneshotExecutor{Executor: ex} +} + +type showQueryExec struct { + exec.BaseExecutor + + targetID uint64 +} + +func (s *showQueryExec) Next(_ context.Context, req *chunk.Chunk) error { + req.Reset() + + tsk, ok := globalBRIEQueue.queryTask(s.targetID) + if !ok { + return nil + } + + req.AppendString(0, tsk.query) + return nil +} + +type cancelJobExec struct { + exec.BaseExecutor + + targetID uint64 +} + +func (s cancelJobExec) Next(_ context.Context, req *chunk.Chunk) error { + req.Reset() + if !globalBRIEQueue.cancelTask(s.targetID) { + s.Ctx().GetSessionVars().StmtCtx.AppendWarning(exeerrors.ErrLoadDataJobNotFound.FastGenByArgs(s.targetID)) + } + return nil +} + +type showMetaExec struct { + exec.BaseExecutor + + showConfig show.Config +} + +// BRIEExec represents an executor for BRIE statements (BACKUP, RESTORE, etc) +type BRIEExec struct { + exec.BaseExecutor + + backupCfg *task.BackupConfig + restoreCfg *task.RestoreConfig + showConfig *show.Config + info *brieTaskInfo +} + +func buildShowMetadataConfigFrom(s *ast.BRIEStmt) show.Config { + if s.Kind != ast.BRIEKindShowBackupMeta { + panic(fmt.Sprintf("precondition failed: `fillByShowMetadata` should always called by a ast.BRIEKindShowBackupMeta, but it is %s.", s.Kind)) + } + + store := s.Storage + cfg := show.Config{ + Storage: store, + Cipher: backuppb.CipherInfo{ + CipherType: encryptionpb.EncryptionMethod_PLAINTEXT, + }, + } + return cfg +} + +func (e *showMetaExec) Next(ctx context.Context, req *chunk.Chunk) error { + exe, err := show.CreateExec(ctx, e.showConfig) + if err != nil { + return errors.Annotate(err, "failed to create show exec") + } + res, err := exe.Read(ctx) + if err != nil { + return errors.Annotate(err, "failed to read metadata from backupmeta") + } + + startTime := oracle.GetTimeFromTS(uint64(res.StartVersion)) + endTime := oracle.GetTimeFromTS(uint64(res.EndVersion)) + + for _, table := range res.Tables { + req.AppendString(0, table.DBName) + req.AppendString(1, table.TableName) + req.AppendInt64(2, int64(table.KVCount)) + req.AppendInt64(3, int64(table.KVSize)) + if res.StartVersion > 0 { + req.AppendTime(4, types.NewTime(types.FromGoTime(startTime.In(e.Ctx().GetSessionVars().Location())), mysql.TypeDatetime, 0)) + } else { + req.AppendNull(4) + } + req.AppendTime(5, types.NewTime(types.FromGoTime(endTime.In(e.Ctx().GetSessionVars().Location())), mysql.TypeDatetime, 0)) + } + return nil +} + +// Next implements the Executor Next interface. +func (e *BRIEExec) Next(ctx context.Context, req *chunk.Chunk) error { + req.Reset() + if e.info == nil { + return nil + } + + bq := globalBRIEQueue + bq.clearTask(e.Ctx().GetSessionVars().StmtCtx) + + e.info.connID = e.Ctx().GetSessionVars().ConnectionID + e.info.queueTime = types.CurrentTime(mysql.TypeDatetime) + taskCtx, taskID := bq.registerTask(ctx, e.info) + defer bq.cancelTask(taskID) + failpoint.Inject("block-on-brie", func() { + log.Warn("You shall not pass, nya. :3") + <-taskCtx.Done() + if taskCtx.Err() != nil { + failpoint.Return(taskCtx.Err()) + } + }) + // manually monitor the Killed status... 
+ go func() { + ticker := time.NewTicker(3 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if e.Ctx().GetSessionVars().SQLKiller.HandleSignal() == exeerrors.ErrQueryInterrupted { + bq.cancelTask(taskID) + return + } + case <-taskCtx.Done(): + return + } + } + }() + + progress, err := bq.acquireTask(taskCtx, taskID) + if err != nil { + return err + } + defer bq.releaseTask() + + e.info.execTime = types.CurrentTime(mysql.TypeDatetime) + glue := &tidbGlue{se: e.Ctx(), progress: progress, info: e.info} + + switch e.info.kind { + case ast.BRIEKindBackup: + err = handleBRIEError(task.RunBackup(taskCtx, glue, "Backup", e.backupCfg), exeerrors.ErrBRIEBackupFailed) + case ast.BRIEKindRestore: + err = handleBRIEError(task.RunRestore(taskCtx, glue, "Restore", e.restoreCfg), exeerrors.ErrBRIERestoreFailed) + default: + err = errors.Errorf("unsupported BRIE statement kind: %s", e.info.kind) + } + e.info.finishTime = types.CurrentTime(mysql.TypeDatetime) + if err != nil { + e.info.message = err.Error() + return err + } + e.info.message = "" + + req.AppendString(0, e.info.storage) + req.AppendUint64(1, e.info.archiveSize) + switch e.info.kind { + case ast.BRIEKindBackup: + req.AppendUint64(2, e.info.backupTS) + req.AppendTime(3, e.info.queueTime) + req.AppendTime(4, e.info.execTime) + case ast.BRIEKindRestore: + req.AppendUint64(2, e.info.backupTS) + req.AppendUint64(3, e.info.restoreTS) + req.AppendTime(4, e.info.queueTime) + req.AppendTime(5, e.info.execTime) + } + e.info = nil + return nil +} + +func handleBRIEError(err error, terror *terror.Error) error { + if err == nil { + return nil + } + return terror.GenWithStackByArgs(err) +} + +func (e *ShowExec) fetchShowBRIE(kind ast.BRIEKind) error { + globalBRIEQueue.tasks.Range(func(_, value any) bool { + item := value.(*brieQueueItem) + if item.info.kind == kind { + item.progress.lock.Lock() + defer item.progress.lock.Unlock() + current := atomic.LoadInt64(&item.progress.current) + e.result.AppendUint64(0, item.info.id) + e.result.AppendString(1, item.info.storage) + e.result.AppendString(2, item.progress.cmd) + e.result.AppendFloat64(3, 100.0*float64(current)/float64(item.progress.total)) + e.result.AppendTime(4, item.info.queueTime) + e.result.AppendTime(5, item.info.execTime) + e.result.AppendTime(6, item.info.finishTime) + e.result.AppendUint64(7, item.info.connID) + if len(item.info.message) > 0 { + e.result.AppendString(8, item.info.message) + } else { + e.result.AppendNull(8) + } + } + return true + }) + globalBRIEQueue.clearTask(e.Ctx().GetSessionVars().StmtCtx) + return nil +} + +type tidbGlue struct { + // the session context of the brie task + se sessionctx.Context + progress *brieTaskProgress + info *brieTaskInfo +} + +// GetDomain implements glue.Glue +func (gs *tidbGlue) GetDomain(_ kv.Storage) (*domain.Domain, error) { + return domain.GetDomain(gs.se), nil +} + +// CreateSession implements glue.Glue +func (gs *tidbGlue) CreateSession(_ kv.Storage) (glue.Session, error) { + newSCtx, err := CreateSession(gs.se) + if err != nil { + return nil, err + } + return &tidbGlueSession{se: newSCtx}, nil +} + +// Open implements glue.Glue +func (gs *tidbGlue) Open(string, pd.SecurityOption) (kv.Storage, error) { + return gs.se.GetStore(), nil +} + +// OwnsStorage implements glue.Glue +func (*tidbGlue) OwnsStorage() bool { + return false +} + +// StartProgress implements glue.Glue +func (gs *tidbGlue) StartProgress(_ context.Context, cmdName string, total int64, _ bool) glue.Progress { + gs.progress.lock.Lock() + 
gs.progress.cmd = cmdName + gs.progress.total = total + atomic.StoreInt64(&gs.progress.current, 0) + gs.progress.lock.Unlock() + return gs.progress +} + +// Record implements glue.Glue +func (gs *tidbGlue) Record(name string, value uint64) { + switch name { + case "BackupTS": + gs.info.backupTS = value + case "RestoreTS": + gs.info.restoreTS = value + case "Size": + gs.info.archiveSize = value + } +} + +func (*tidbGlue) GetVersion() string { + return "TiDB\n" + printer.GetTiDBInfo() +} + +// UseOneShotSession implements glue.Glue +func (gs *tidbGlue) UseOneShotSession(_ kv.Storage, _ bool, fn func(se glue.Session) error) error { + // In SQL backup, we don't need to close domain, + // but need to create an new session. + newSCtx, err := CreateSession(gs.se) + if err != nil { + return err + } + glueSession := &tidbGlueSession{se: newSCtx} + defer func() { + CloseSession(newSCtx) + log.Info("one shot session from brie closed") + }() + return fn(glueSession) +} + +type tidbGlueSession struct { + // the session context of the brie task's subtask, such as `CREATE TABLE`. + se sessionctx.Context +} + +// Execute implements glue.Session +// These queries execute without privilege checking, since the calling statements +// such as BACKUP and RESTORE have already been privilege checked. +// NOTE: Maybe drain the restult too? See `gluetidb.tidbSession.ExecuteInternal` for more details. +func (gs *tidbGlueSession) Execute(ctx context.Context, sql string) error { + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnBR) + _, _, err := gs.se.GetRestrictedSQLExecutor().ExecRestrictedSQL(ctx, nil, sql) + return err +} + +func (gs *tidbGlueSession) ExecuteInternal(ctx context.Context, sql string, args ...any) error { + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnBR) + exec := gs.se.GetSQLExecutor() + _, err := exec.ExecuteInternal(ctx, sql, args...) + return err +} + +// CreateDatabase implements glue.Session +func (gs *tidbGlueSession) CreateDatabase(_ context.Context, schema *model.DBInfo) error { + return BRIECreateDatabase(gs.se, schema, "") +} + +// CreateTable implements glue.Session +func (gs *tidbGlueSession) CreateTable(_ context.Context, dbName model.CIStr, table *model.TableInfo, cs ...ddl.CreateTableOption) error { + return BRIECreateTable(gs.se, dbName, table, "", cs...) +} + +// CreateTables implements glue.BatchCreateTableSession. +func (gs *tidbGlueSession) CreateTables(_ context.Context, + tables map[string][]*model.TableInfo, cs ...ddl.CreateTableOption) error { + return BRIECreateTables(gs.se, tables, "", cs...) +} + +// CreatePlacementPolicy implements glue.Session +func (gs *tidbGlueSession) CreatePlacementPolicy(_ context.Context, policy *model.PolicyInfo) error { + originQueryString := gs.se.Value(sessionctx.QueryString) + defer gs.se.SetValue(sessionctx.QueryString, originQueryString) + gs.se.SetValue(sessionctx.QueryString, ConstructResultOfShowCreatePlacementPolicy(policy)) + d := domain.GetDomain(gs.se).DDLExecutor() + // the default behaviour is ignoring duplicated policy during restore. + return d.CreatePlacementPolicyWithInfo(gs.se, policy, ddl.OnExistIgnore) +} + +// Close implements glue.Session +func (gs *tidbGlueSession) Close() { + CloseSession(gs.se) +} + +// GetGlobalVariables implements glue.Session. 
+func (gs *tidbGlueSession) GetGlobalVariable(name string) (string, error) { + return gs.se.GetSessionVars().GlobalVarsAccessor.GetTiDBTableValue(name) +} + +// GetSessionCtx implements glue.Glue +func (gs *tidbGlueSession) GetSessionCtx() sessionctx.Context { + return gs.se +} + +func restoreQuery(stmt *ast.BRIEStmt) string { + out := bytes.NewBuffer(nil) + rc := format.NewRestoreCtx(format.RestoreNameBackQuotes|format.RestoreStringSingleQuotes, out) + if err := stmt.Restore(rc); err != nil { + return "N/A" + } + return out.String() +} diff --git a/pkg/executor/builder.go b/pkg/executor/builder.go index 6bcee9f5f8be3..95fb0b69094a4 100644 --- a/pkg/executor/builder.go +++ b/pkg/executor/builder.go @@ -846,7 +846,7 @@ func (b *executorBuilder) buildExecute(v *plannercore.Execute) exec.Executor { outputNames: v.OutputNames(), } - failpoint.Inject("assertExecutePrepareStatementStalenessOption", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("assertExecutePrepareStatementStalenessOption")); _err_ == nil { vs := strings.Split(val.(string), "_") assertTS, assertReadReplicaScope := vs[0], vs[1] staleread.AssertStmtStaleness(b.ctx, true) @@ -859,7 +859,7 @@ func (b *executorBuilder) buildExecute(v *plannercore.Execute) exec.Executor { assertReadReplicaScope != b.readReplicaScope { panic("execute prepare statement have wrong staleness option") } - }) + } return e } @@ -2729,9 +2729,9 @@ func (b *executorBuilder) buildAnalyzeIndexPushdown(task plannercore.AnalyzeInde b.err = err return nil } - failpoint.Inject("injectAnalyzeSnapshot", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("injectAnalyzeSnapshot")); _err_ == nil { startTS = uint64(val.(int)) - }) + } concurrency := adaptiveAnlayzeDistSQLConcurrency(context.Background(), b.ctx) base := baseAnalyzeExec{ ctx: b.ctx, @@ -2802,21 +2802,21 @@ func (b *executorBuilder) buildAnalyzeSamplingPushdown( b.err = err return nil } - failpoint.Inject("injectAnalyzeSnapshot", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("injectAnalyzeSnapshot")); _err_ == nil { startTS = uint64(val.(int)) - }) + } statsHandle := domain.GetDomain(b.ctx).StatsHandle() count, modifyCount, err := statsHandle.StatsMetaCountAndModifyCount(task.TableID.GetStatisticsID()) if err != nil { b.err = err return nil } - failpoint.Inject("injectBaseCount", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("injectBaseCount")); _err_ == nil { count = int64(val.(int)) - }) - failpoint.Inject("injectBaseModifyCount", func(val failpoint.Value) { + } + if val, _err_ := failpoint.Eval(_curpkg_("injectBaseModifyCount")); _err_ == nil { modifyCount = int64(val.(int)) - }) + } sampleRate := new(float64) var sampleRateReason string if opts[ast.AnalyzeOptNumSamples] == 0 { @@ -2980,9 +2980,9 @@ func (b *executorBuilder) buildAnalyzeColumnsPushdown( b.err = err return nil } - failpoint.Inject("injectAnalyzeSnapshot", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("injectAnalyzeSnapshot")); _err_ == nil { startTS = uint64(val.(int)) - }) + } concurrency := adaptiveAnlayzeDistSQLConcurrency(context.Background(), b.ctx) base := baseAnalyzeExec{ ctx: b.ctx, @@ -3588,16 +3588,16 @@ func (b *executorBuilder) buildMPPGather(v *plannercore.PhysicalTableReader) exe // buildTableReader builds a table reader executor. It first build a no range table reader, // and then update it ranges from table scan plan. 
func (b *executorBuilder) buildTableReader(v *plannercore.PhysicalTableReader) exec.Executor { - failpoint.Inject("checkUseMPP", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("checkUseMPP")); _err_ == nil { if !b.ctx.GetSessionVars().InRestrictedSQL && val.(bool) != useMPPExecution(b.ctx, v) { if val.(bool) { b.err = errors.New("expect mpp but not used") } else { b.err = errors.New("don't expect mpp but we used it") } - failpoint.Return(nil) + return nil } - }) + } // https://github.com/pingcap/tidb/issues/50358 if len(v.Schema().Columns) == 0 && len(v.GetTablePlan().Schema().Columns) > 0 { v.SetSchema(v.GetTablePlan().Schema()) @@ -4790,11 +4790,11 @@ func (builder *dataReaderBuilder) buildProjectionForIndexJoin( if int64(v.StatsCount()) < int64(builder.ctx.GetSessionVars().MaxChunkSize) { e.numWorkers = 0 } - failpoint.Inject("buildProjectionForIndexJoinPanic", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("buildProjectionForIndexJoinPanic")); _err_ == nil { if v, ok := val.(bool); ok && v { panic("buildProjectionForIndexJoinPanic") } - }) + } err = e.open(ctx) if err != nil { return nil, err @@ -4892,9 +4892,9 @@ func buildKvRangesForIndexJoin(dctx *distsqlctx.DistSQLContext, pctx *rangerctx. } } if len(kvRanges) != 0 && memTracker != nil { - failpoint.Inject("testIssue49033", func() { + if _, _err_ := failpoint.Eval(_curpkg_("testIssue49033")); _err_ == nil { panic("testIssue49033") - }) + } memTracker.Consume(int64(2 * cap(kvRanges[0].StartKey) * len(kvRanges))) } if len(tmpDatumRanges) != 0 && memTracker != nil { @@ -5251,12 +5251,12 @@ func (b *executorBuilder) buildBatchPointGet(plan *plannercore.BatchPointGetPlan sctx.IndexNames = append(sctx.IndexNames, plan.TblInfo.Name.O+":"+plan.IndexInfo.Name.O) } - failpoint.Inject("assertBatchPointReplicaOption", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("assertBatchPointReplicaOption")); _err_ == nil { assertScope := val.(string) if e.Ctx().GetSessionVars().GetReplicaRead().IsClosestRead() && assertScope != b.readReplicaScope { panic("batch point get replica option fail") } - }) + } snapshotTS, err := b.getSnapshotTS() if err != nil { diff --git a/pkg/executor/builder.go__failpoint_stash__ b/pkg/executor/builder.go__failpoint_stash__ new file mode 100644 index 0000000000000..6bcee9f5f8be3 --- /dev/null +++ b/pkg/executor/builder.go__failpoint_stash__ @@ -0,0 +1,5659 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package executor + +import ( + "bytes" + "cmp" + "context" + "fmt" + "math" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + "unsafe" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/diagnosticspb" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl" + "github.com/pingcap/tidb/pkg/ddl/placement" + "github.com/pingcap/tidb/pkg/distsql" + distsqlctx "github.com/pingcap/tidb/pkg/distsql/context" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/executor/aggfuncs" + "github.com/pingcap/tidb/pkg/executor/aggregate" + "github.com/pingcap/tidb/pkg/executor/internal/builder" + "github.com/pingcap/tidb/pkg/executor/internal/calibrateresource" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/executor/internal/pdhelper" + "github.com/pingcap/tidb/pkg/executor/internal/querywatch" + "github.com/pingcap/tidb/pkg/executor/internal/testutil" + "github.com/pingcap/tidb/pkg/executor/internal/vecgroupchecker" + "github.com/pingcap/tidb/pkg/executor/join" + "github.com/pingcap/tidb/pkg/executor/lockstats" + executor_metrics "github.com/pingcap/tidb/pkg/executor/metrics" + "github.com/pingcap/tidb/pkg/executor/sortexec" + "github.com/pingcap/tidb/pkg/executor/unionexec" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/expression/aggregation" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" + plannerutil "github.com/pingcap/tidb/pkg/planner/util" + "github.com/pingcap/tidb/pkg/planner/util/coreusage" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/sessiontxn/staleread" + "github.com/pingcap/tidb/pkg/statistics" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/table/temptable" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/collate" + "github.com/pingcap/tidb/pkg/util/cteutil" + "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" + "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/ranger" + rangerctx "github.com/pingcap/tidb/pkg/util/ranger/context" + "github.com/pingcap/tidb/pkg/util/rowcodec" + "github.com/pingcap/tidb/pkg/util/tiflash" + "github.com/pingcap/tidb/pkg/util/timeutil" + "github.com/pingcap/tipb/go-tipb" + clientkv "github.com/tikv/client-go/v2/kv" + "github.com/tikv/client-go/v2/tikv" + "github.com/tikv/client-go/v2/txnkv" + "github.com/tikv/client-go/v2/txnkv/txnsnapshot" +) + +// executorBuilder builds an Executor from a Plan. +// The InfoSchema must not change during execution. 
+type executorBuilder struct { + ctx sessionctx.Context + is infoschema.InfoSchema + err error // err is set when there is error happened during Executor building process. + hasLock bool + // isStaleness means whether this statement use stale read. + isStaleness bool + txnScope string + readReplicaScope string + inUpdateStmt bool + inDeleteStmt bool + inInsertStmt bool + inSelectLockStmt bool + + // forDataReaderBuilder indicates whether the builder is used by a dataReaderBuilder. + // When forDataReader is true, the builder should use the dataReaderTS as the executor read ts. This is because + // dataReaderBuilder can be used in concurrent goroutines, so we must ensure that getting the ts should be thread safe and + // can return a correct value even if the session context has already been destroyed + forDataReaderBuilder bool + dataReaderTS uint64 + + // Used when building MPPGather. + encounterUnionScan bool +} + +// CTEStorages stores resTbl and iterInTbl for CTEExec. +// There will be a map[CTEStorageID]*CTEStorages in StmtCtx, +// which will store all CTEStorages to make all shared CTEs use same the CTEStorages. +type CTEStorages struct { + ResTbl cteutil.Storage + IterInTbl cteutil.Storage + Producer *cteProducer +} + +func newExecutorBuilder(ctx sessionctx.Context, is infoschema.InfoSchema) *executorBuilder { + txnManager := sessiontxn.GetTxnManager(ctx) + return &executorBuilder{ + ctx: ctx, + is: is, + isStaleness: staleread.IsStmtStaleness(ctx), + txnScope: txnManager.GetTxnScope(), + readReplicaScope: txnManager.GetReadReplicaScope(), + } +} + +// MockExecutorBuilder is a wrapper for executorBuilder. +// ONLY used in test. +type MockExecutorBuilder struct { + *executorBuilder +} + +// NewMockExecutorBuilderForTest is ONLY used in test. +func NewMockExecutorBuilderForTest(ctx sessionctx.Context, is infoschema.InfoSchema) *MockExecutorBuilder { + return &MockExecutorBuilder{ + executorBuilder: newExecutorBuilder(ctx, is)} +} + +// Build builds an executor tree according to `p`. 
+func (b *MockExecutorBuilder) Build(p base.Plan) exec.Executor { + return b.build(p) +} + +func (b *executorBuilder) build(p base.Plan) exec.Executor { + switch v := p.(type) { + case nil: + return nil + case *plannercore.Change: + return b.buildChange(v) + case *plannercore.CheckTable: + return b.buildCheckTable(v) + case *plannercore.RecoverIndex: + return b.buildRecoverIndex(v) + case *plannercore.CleanupIndex: + return b.buildCleanupIndex(v) + case *plannercore.CheckIndexRange: + return b.buildCheckIndexRange(v) + case *plannercore.ChecksumTable: + return b.buildChecksumTable(v) + case *plannercore.ReloadExprPushdownBlacklist: + return b.buildReloadExprPushdownBlacklist(v) + case *plannercore.ReloadOptRuleBlacklist: + return b.buildReloadOptRuleBlacklist(v) + case *plannercore.AdminPlugins: + return b.buildAdminPlugins(v) + case *plannercore.DDL: + return b.buildDDL(v) + case *plannercore.Deallocate: + return b.buildDeallocate(v) + case *plannercore.Delete: + return b.buildDelete(v) + case *plannercore.Execute: + return b.buildExecute(v) + case *plannercore.Trace: + return b.buildTrace(v) + case *plannercore.Explain: + return b.buildExplain(v) + case *plannercore.PointGetPlan: + return b.buildPointGet(v) + case *plannercore.BatchPointGetPlan: + return b.buildBatchPointGet(v) + case *plannercore.Insert: + return b.buildInsert(v) + case *plannercore.ImportInto: + return b.buildImportInto(v) + case *plannercore.LoadData: + return b.buildLoadData(v) + case *plannercore.LoadStats: + return b.buildLoadStats(v) + case *plannercore.LockStats: + return b.buildLockStats(v) + case *plannercore.UnlockStats: + return b.buildUnlockStats(v) + case *plannercore.IndexAdvise: + return b.buildIndexAdvise(v) + case *plannercore.PlanReplayer: + return b.buildPlanReplayer(v) + case *plannercore.PhysicalLimit: + return b.buildLimit(v) + case *plannercore.Prepare: + return b.buildPrepare(v) + case *plannercore.PhysicalLock: + return b.buildSelectLock(v) + case *plannercore.CancelDDLJobs: + return b.buildCancelDDLJobs(v) + case *plannercore.PauseDDLJobs: + return b.buildPauseDDLJobs(v) + case *plannercore.ResumeDDLJobs: + return b.buildResumeDDLJobs(v) + case *plannercore.ShowNextRowID: + return b.buildShowNextRowID(v) + case *plannercore.ShowDDL: + return b.buildShowDDL(v) + case *plannercore.PhysicalShowDDLJobs: + return b.buildShowDDLJobs(v) + case *plannercore.ShowDDLJobQueries: + return b.buildShowDDLJobQueries(v) + case *plannercore.ShowDDLJobQueriesWithRange: + return b.buildShowDDLJobQueriesWithRange(v) + case *plannercore.ShowSlow: + return b.buildShowSlow(v) + case *plannercore.PhysicalShow: + return b.buildShow(v) + case *plannercore.Simple: + return b.buildSimple(v) + case *plannercore.PhysicalSimpleWrapper: + return b.buildSimple(&v.Inner) + case *plannercore.Set: + return b.buildSet(v) + case *plannercore.SetConfig: + return b.buildSetConfig(v) + case *plannercore.PhysicalSort: + return b.buildSort(v) + case *plannercore.PhysicalTopN: + return b.buildTopN(v) + case *plannercore.PhysicalUnionAll: + return b.buildUnionAll(v) + case *plannercore.Update: + return b.buildUpdate(v) + case *plannercore.PhysicalUnionScan: + return b.buildUnionScanExec(v) + case *plannercore.PhysicalHashJoin: + return b.buildHashJoin(v) + case *plannercore.PhysicalMergeJoin: + return b.buildMergeJoin(v) + case *plannercore.PhysicalIndexJoin: + return b.buildIndexLookUpJoin(v) + case *plannercore.PhysicalIndexMergeJoin: + return b.buildIndexLookUpMergeJoin(v) + case *plannercore.PhysicalIndexHashJoin: + return 
b.buildIndexNestedLoopHashJoin(v) + case *plannercore.PhysicalSelection: + return b.buildSelection(v) + case *plannercore.PhysicalHashAgg: + return b.buildHashAgg(v) + case *plannercore.PhysicalStreamAgg: + return b.buildStreamAgg(v) + case *plannercore.PhysicalProjection: + return b.buildProjection(v) + case *plannercore.PhysicalMemTable: + return b.buildMemTable(v) + case *plannercore.PhysicalTableDual: + return b.buildTableDual(v) + case *plannercore.PhysicalApply: + return b.buildApply(v) + case *plannercore.PhysicalMaxOneRow: + return b.buildMaxOneRow(v) + case *plannercore.Analyze: + return b.buildAnalyze(v) + case *plannercore.PhysicalTableReader: + return b.buildTableReader(v) + case *plannercore.PhysicalTableSample: + return b.buildTableSample(v) + case *plannercore.PhysicalIndexReader: + return b.buildIndexReader(v) + case *plannercore.PhysicalIndexLookUpReader: + return b.buildIndexLookUpReader(v) + case *plannercore.PhysicalWindow: + return b.buildWindow(v) + case *plannercore.PhysicalShuffle: + return b.buildShuffle(v) + case *plannercore.PhysicalShuffleReceiverStub: + return b.buildShuffleReceiverStub(v) + case *plannercore.SQLBindPlan: + return b.buildSQLBindExec(v) + case *plannercore.SplitRegion: + return b.buildSplitRegion(v) + case *plannercore.PhysicalIndexMergeReader: + return b.buildIndexMergeReader(v) + case *plannercore.SelectInto: + return b.buildSelectInto(v) + case *plannercore.PhysicalCTE: + return b.buildCTE(v) + case *plannercore.PhysicalCTETable: + return b.buildCTETableReader(v) + case *plannercore.CompactTable: + return b.buildCompactTable(v) + case *plannercore.AdminShowBDRRole: + return b.buildAdminShowBDRRole(v) + case *plannercore.PhysicalExpand: + return b.buildExpand(v) + default: + if mp, ok := p.(testutil.MockPhysicalPlan); ok { + return mp.GetExecutor() + } + + b.err = exeerrors.ErrUnknownPlan.GenWithStack("Unknown Plan %T", p) + return nil + } +} + +func (b *executorBuilder) buildCancelDDLJobs(v *plannercore.CancelDDLJobs) exec.Executor { + e := &CancelDDLJobsExec{ + CommandDDLJobsExec: &CommandDDLJobsExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + jobIDs: v.JobIDs, + execute: ddl.CancelJobs, + }, + } + return e +} + +func (b *executorBuilder) buildPauseDDLJobs(v *plannercore.PauseDDLJobs) exec.Executor { + e := &PauseDDLJobsExec{ + CommandDDLJobsExec: &CommandDDLJobsExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + jobIDs: v.JobIDs, + execute: ddl.PauseJobs, + }, + } + return e +} + +func (b *executorBuilder) buildResumeDDLJobs(v *plannercore.ResumeDDLJobs) exec.Executor { + e := &ResumeDDLJobsExec{ + CommandDDLJobsExec: &CommandDDLJobsExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + jobIDs: v.JobIDs, + execute: ddl.ResumeJobs, + }, + } + return e +} + +func (b *executorBuilder) buildChange(v *plannercore.Change) exec.Executor { + return &ChangeExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + ChangeStmt: v.ChangeStmt, + } +} + +func (b *executorBuilder) buildShowNextRowID(v *plannercore.ShowNextRowID) exec.Executor { + e := &ShowNextRowIDExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + tblName: v.TableName, + } + return e +} + +func (b *executorBuilder) buildShowDDL(v *plannercore.ShowDDL) exec.Executor { + // We get Info here because for Executors that returns result set, + // next will be called after transaction has been committed. + // We need the transaction to get Info. 
+ e := &ShowDDLExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + } + + var err error + ownerManager := domain.GetDomain(e.Ctx()).DDL().OwnerManager() + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + e.ddlOwnerID, err = ownerManager.GetOwnerID(ctx) + cancel() + if err != nil { + b.err = err + return nil + } + + session, err := e.GetSysSession() + if err != nil { + b.err = err + return nil + } + ddlInfo, err := ddl.GetDDLInfoWithNewTxn(session) + e.ReleaseSysSession(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), session) + if err != nil { + b.err = err + return nil + } + e.ddlInfo = ddlInfo + e.selfID = ownerManager.ID() + return e +} + +func (b *executorBuilder) buildShowDDLJobs(v *plannercore.PhysicalShowDDLJobs) exec.Executor { + loc := b.ctx.GetSessionVars().Location() + ddlJobRetriever := DDLJobRetriever{TZLoc: loc} + e := &ShowDDLJobsExec{ + jobNumber: int(v.JobNumber), + is: b.is, + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + DDLJobRetriever: ddlJobRetriever, + } + return e +} + +func (b *executorBuilder) buildShowDDLJobQueries(v *plannercore.ShowDDLJobQueries) exec.Executor { + e := &ShowDDLJobQueriesExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + jobIDs: v.JobIDs, + } + return e +} + +func (b *executorBuilder) buildShowDDLJobQueriesWithRange(v *plannercore.ShowDDLJobQueriesWithRange) exec.Executor { + e := &ShowDDLJobQueriesWithRangeExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + offset: v.Offset, + limit: v.Limit, + } + return e +} + +func (b *executorBuilder) buildShowSlow(v *plannercore.ShowSlow) exec.Executor { + e := &ShowSlowExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + ShowSlow: v.ShowSlow, + } + return e +} + +// buildIndexLookUpChecker builds check information to IndexLookUpReader. +func buildIndexLookUpChecker(b *executorBuilder, p *plannercore.PhysicalIndexLookUpReader, + e *IndexLookUpExecutor) { + is := p.IndexPlans[0].(*plannercore.PhysicalIndexScan) + fullColLen := len(is.Index.Columns) + len(p.CommonHandleCols) + if !e.isCommonHandle() { + fullColLen++ + } + if e.index.Global { + fullColLen++ + } + e.dagPB.OutputOffsets = make([]uint32, fullColLen) + for i := 0; i < fullColLen; i++ { + e.dagPB.OutputOffsets[i] = uint32(i) + } + + ts := p.TablePlans[0].(*plannercore.PhysicalTableScan) + e.handleIdx = ts.HandleIdx + + e.ranges = ranger.FullRange() + + tps := make([]*types.FieldType, 0, fullColLen) + for _, col := range is.Columns { + // tps is used to decode the index, we should use the element type of the array if any. 
+ tps = append(tps, col.FieldType.ArrayType()) + } + + if !e.isCommonHandle() { + tps = append(tps, types.NewFieldType(mysql.TypeLonglong)) + } + if e.index.Global { + tps = append(tps, types.NewFieldType(mysql.TypeLonglong)) + } + + e.checkIndexValue = &checkIndexValue{idxColTps: tps} + + colNames := make([]string, 0, len(is.IdxCols)) + for i := range is.IdxCols { + colNames = append(colNames, is.Columns[i].Name.L) + } + if cols, missingColOffset := table.FindColumns(e.table.Cols(), colNames, true); missingColOffset >= 0 { + b.err = plannererrors.ErrUnknownColumn.GenWithStack("Unknown column %s", is.Columns[missingColOffset].Name.O) + } else { + e.idxTblCols = cols + } +} + +func (b *executorBuilder) buildCheckTable(v *plannercore.CheckTable) exec.Executor { + noMVIndexOrPrefixIndex := true + for _, idx := range v.IndexInfos { + if idx.MVIndex { + noMVIndexOrPrefixIndex = false + break + } + for _, col := range idx.Columns { + if col.Length != types.UnspecifiedLength { + noMVIndexOrPrefixIndex = false + break + } + } + if !noMVIndexOrPrefixIndex { + break + } + } + if b.ctx.GetSessionVars().FastCheckTable && noMVIndexOrPrefixIndex { + e := &FastCheckTableExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + dbName: v.DBName, + table: v.Table, + indexInfos: v.IndexInfos, + is: b.is, + err: &atomic.Pointer[error]{}, + } + return e + } + + readerExecs := make([]*IndexLookUpExecutor, 0, len(v.IndexLookUpReaders)) + for _, readerPlan := range v.IndexLookUpReaders { + readerExec, err := buildNoRangeIndexLookUpReader(b, readerPlan) + if err != nil { + b.err = errors.Trace(err) + return nil + } + buildIndexLookUpChecker(b, readerPlan, readerExec) + + readerExecs = append(readerExecs, readerExec) + } + + e := &CheckTableExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + dbName: v.DBName, + table: v.Table, + indexInfos: v.IndexInfos, + is: b.is, + srcs: readerExecs, + exitCh: make(chan struct{}), + retCh: make(chan error, len(readerExecs)), + checkIndex: v.CheckIndex, + } + return e +} + +func buildIdxColsConcatHandleCols(tblInfo *model.TableInfo, indexInfo *model.IndexInfo, hasGenedCol bool) []*model.ColumnInfo { + var pkCols []*model.IndexColumn + if tblInfo.IsCommonHandle { + pkIdx := tables.FindPrimaryIndex(tblInfo) + pkCols = pkIdx.Columns + } + + columns := make([]*model.ColumnInfo, 0, len(indexInfo.Columns)+len(pkCols)) + if hasGenedCol { + columns = tblInfo.Columns + } else { + for _, idxCol := range indexInfo.Columns { + if tblInfo.PKIsHandle && tblInfo.GetPkColInfo().Offset == idxCol.Offset { + continue + } + columns = append(columns, tblInfo.Columns[idxCol.Offset]) + } + } + + if tblInfo.IsCommonHandle { + for _, c := range pkCols { + if model.FindColumnInfo(columns, c.Name.L) == nil { + columns = append(columns, tblInfo.Columns[c.Offset]) + } + } + return columns + } + if tblInfo.PKIsHandle { + columns = append(columns, tblInfo.Columns[tblInfo.GetPkColInfo().Offset]) + return columns + } + handleOffset := len(columns) + handleColsInfo := &model.ColumnInfo{ + ID: model.ExtraHandleID, + Name: model.ExtraHandleName, + Offset: handleOffset, + } + handleColsInfo.FieldType = *types.NewFieldType(mysql.TypeLonglong) + columns = append(columns, handleColsInfo) + return columns +} + +func (b *executorBuilder) buildRecoverIndex(v *plannercore.RecoverIndex) exec.Executor { + tblInfo := v.Table.TableInfo + t, err := b.is.TableByName(context.Background(), v.Table.Schema, tblInfo.Name) + if err != nil { + b.err = err + return nil + } + idxName := 
strings.ToLower(v.IndexName) + index := tables.GetWritableIndexByName(idxName, t) + if index == nil { + b.err = errors.Errorf("secondary index `%v` is not found in table `%v`", v.IndexName, v.Table.Name.O) + return nil + } + var hasGenedCol bool + for _, iCol := range index.Meta().Columns { + if tblInfo.Columns[iCol.Offset].IsGenerated() { + hasGenedCol = true + } + } + cols := buildIdxColsConcatHandleCols(tblInfo, index.Meta(), hasGenedCol) + e := &RecoverIndexExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + columns: cols, + containsGenedCol: hasGenedCol, + index: index, + table: t, + physicalID: t.Meta().ID, + } + sessCtx := e.Ctx().GetSessionVars().StmtCtx + e.handleCols = buildHandleColsForExec(sessCtx, tblInfo, e.columns) + return e +} + +func buildHandleColsForExec(sctx *stmtctx.StatementContext, tblInfo *model.TableInfo, + allColInfo []*model.ColumnInfo) plannerutil.HandleCols { + if !tblInfo.IsCommonHandle { + extraColPos := len(allColInfo) - 1 + intCol := &expression.Column{ + Index: extraColPos, + RetType: types.NewFieldType(mysql.TypeLonglong), + } + return plannerutil.NewIntHandleCols(intCol) + } + tblCols := make([]*expression.Column, len(tblInfo.Columns)) + for i := 0; i < len(tblInfo.Columns); i++ { + c := tblInfo.Columns[i] + tblCols[i] = &expression.Column{ + RetType: &c.FieldType, + ID: c.ID, + } + } + pkIdx := tables.FindPrimaryIndex(tblInfo) + for _, c := range pkIdx.Columns { + for j, colInfo := range allColInfo { + if colInfo.Name.L == c.Name.L { + tblCols[c.Offset].Index = j + } + } + } + return plannerutil.NewCommonHandleCols(sctx, tblInfo, pkIdx, tblCols) +} + +func (b *executorBuilder) buildCleanupIndex(v *plannercore.CleanupIndex) exec.Executor { + tblInfo := v.Table.TableInfo + t, err := b.is.TableByName(context.Background(), v.Table.Schema, tblInfo.Name) + if err != nil { + b.err = err + return nil + } + idxName := strings.ToLower(v.IndexName) + var index table.Index + for _, idx := range t.Indices() { + if idx.Meta().State != model.StatePublic { + continue + } + if idxName == idx.Meta().Name.L { + index = idx + break + } + } + + if index == nil { + b.err = errors.Errorf("secondary index `%v` is not found in table `%v`", v.IndexName, v.Table.Name.O) + return nil + } + e := &CleanupIndexExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + columns: buildIdxColsConcatHandleCols(tblInfo, index.Meta(), false), + index: index, + table: t, + physicalID: t.Meta().ID, + batchSize: 20000, + } + sessCtx := e.Ctx().GetSessionVars().StmtCtx + e.handleCols = buildHandleColsForExec(sessCtx, tblInfo, e.columns) + if e.index.Meta().Global { + e.columns = append(e.columns, model.NewExtraPhysTblIDColInfo()) + } + return e +} + +func (b *executorBuilder) buildCheckIndexRange(v *plannercore.CheckIndexRange) exec.Executor { + tb, err := b.is.TableByName(context.Background(), v.Table.Schema, v.Table.Name) + if err != nil { + b.err = err + return nil + } + e := &CheckIndexRangeExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + handleRanges: v.HandleRanges, + table: tb.Meta(), + is: b.is, + } + idxName := strings.ToLower(v.IndexName) + for _, idx := range tb.Indices() { + if idx.Meta().Name.L == idxName { + e.index = idx.Meta() + e.startKey = make([]types.Datum, len(e.index.Columns)) + break + } + } + return e +} + +func (b *executorBuilder) buildChecksumTable(v *plannercore.ChecksumTable) exec.Executor { + e := &ChecksumTableExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + tables: 
make(map[int64]*checksumContext), + done: false, + } + startTs, err := b.getSnapshotTS() + if err != nil { + b.err = err + return nil + } + for _, t := range v.Tables { + e.tables[t.TableInfo.ID] = newChecksumContext(t.DBInfo, t.TableInfo, startTs) + } + return e +} + +func (b *executorBuilder) buildReloadExprPushdownBlacklist(_ *plannercore.ReloadExprPushdownBlacklist) exec.Executor { + base := exec.NewBaseExecutor(b.ctx, nil, 0) + return &ReloadExprPushdownBlacklistExec{base} +} + +func (b *executorBuilder) buildReloadOptRuleBlacklist(_ *plannercore.ReloadOptRuleBlacklist) exec.Executor { + base := exec.NewBaseExecutor(b.ctx, nil, 0) + return &ReloadOptRuleBlacklistExec{BaseExecutor: base} +} + +func (b *executorBuilder) buildAdminPlugins(v *plannercore.AdminPlugins) exec.Executor { + base := exec.NewBaseExecutor(b.ctx, nil, 0) + return &AdminPluginsExec{BaseExecutor: base, Action: v.Action, Plugins: v.Plugins} +} + +func (b *executorBuilder) buildDeallocate(v *plannercore.Deallocate) exec.Executor { + base := exec.NewBaseExecutor(b.ctx, nil, v.ID()) + base.SetInitCap(chunk.ZeroCapacity) + e := &DeallocateExec{ + BaseExecutor: base, + Name: v.Name, + } + return e +} + +func (b *executorBuilder) buildSelectLock(v *plannercore.PhysicalLock) exec.Executor { + if !b.inSelectLockStmt { + b.inSelectLockStmt = true + defer func() { b.inSelectLockStmt = false }() + } + if b.err = b.updateForUpdateTS(); b.err != nil { + return nil + } + + src := b.build(v.Children()[0]) + if b.err != nil { + return nil + } + if !b.ctx.GetSessionVars().PessimisticLockEligible() { + // Locking of rows for update using SELECT FOR UPDATE only applies when autocommit + // is disabled (either by beginning transaction with START TRANSACTION or by setting + // autocommit to 0. If autocommit is enabled, the rows matching the specification are not locked. + // See https://dev.mysql.com/doc/refman/5.7/en/innodb-locking-reads.html + return src + } + // If the `PhysicalLock` is not ignored by the above logic, set the `hasLock` flag. + b.hasLock = true + e := &SelectLockExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), src), + Lock: v.Lock, + tblID2Handle: v.TblID2Handle, + tblID2PhysTblIDCol: v.TblID2PhysTblIDCol, + } + + // filter out temporary tables because they do not store any record in tikv and should not write any lock + is := e.Ctx().GetInfoSchema().(infoschema.InfoSchema) + for tblID := range e.tblID2Handle { + tblInfo, ok := is.TableByID(tblID) + if !ok { + b.err = errors.Errorf("Can not get table %d", tblID) + } + + if tblInfo.Meta().TempTableType != model.TempTableNone { + delete(e.tblID2Handle, tblID) + } + } + + return e +} + +func (b *executorBuilder) buildLimit(v *plannercore.PhysicalLimit) exec.Executor { + childExec := b.build(v.Children()[0]) + if b.err != nil { + return nil + } + n := int(min(v.Count, uint64(b.ctx.GetSessionVars().MaxChunkSize))) + base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), childExec) + base.SetInitCap(n) + e := &LimitExec{ + BaseExecutor: base, + begin: v.Offset, + end: v.Offset + v.Count, + } + + childUsedSchemaLen := v.Children()[0].Schema().Len() + childUsedSchema := markChildrenUsedCols(v.Schema().Columns, v.Children()[0].Schema())[0] + e.columnIdxsUsedByChild = make([]int, 0, len(childUsedSchema)) + e.columnIdxsUsedByChild = append(e.columnIdxsUsedByChild, childUsedSchema...) + if len(e.columnIdxsUsedByChild) == childUsedSchemaLen { + e.columnIdxsUsedByChild = nil // indicates that all columns are used. 
LimitExec will improve performance for this condition. + } + return e +} + +func (b *executorBuilder) buildPrepare(v *plannercore.Prepare) exec.Executor { + base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()) + base.SetInitCap(chunk.ZeroCapacity) + return &PrepareExec{ + BaseExecutor: base, + name: v.Name, + sqlText: v.SQLText, + } +} + +func (b *executorBuilder) buildExecute(v *plannercore.Execute) exec.Executor { + e := &ExecuteExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + is: b.is, + name: v.Name, + usingVars: v.Params, + stmt: v.Stmt, + plan: v.Plan, + outputNames: v.OutputNames(), + } + + failpoint.Inject("assertExecutePrepareStatementStalenessOption", func(val failpoint.Value) { + vs := strings.Split(val.(string), "_") + assertTS, assertReadReplicaScope := vs[0], vs[1] + staleread.AssertStmtStaleness(b.ctx, true) + ts, err := sessiontxn.GetTxnManager(b.ctx).GetStmtReadTS() + if err != nil { + panic(e) + } + + if strconv.FormatUint(ts, 10) != assertTS || + assertReadReplicaScope != b.readReplicaScope { + panic("execute prepare statement have wrong staleness option") + } + }) + + return e +} + +func (b *executorBuilder) buildShow(v *plannercore.PhysicalShow) exec.Executor { + e := &ShowExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + Tp: v.Tp, + CountWarningsOrErrors: v.CountWarningsOrErrors, + DBName: model.NewCIStr(v.DBName), + Table: v.Table, + Partition: v.Partition, + Column: v.Column, + IndexName: v.IndexName, + ResourceGroupName: model.NewCIStr(v.ResourceGroupName), + Flag: v.Flag, + Roles: v.Roles, + User: v.User, + is: b.is, + Full: v.Full, + IfNotExists: v.IfNotExists, + GlobalScope: v.GlobalScope, + Extended: v.Extended, + Extractor: v.Extractor, + ImportJobID: v.ImportJobID, + } + if e.Tp == ast.ShowMasterStatus || e.Tp == ast.ShowBinlogStatus { + // show master status need start ts. 
+ if _, err := e.Ctx().Txn(true); err != nil { + b.err = err + } + } + return e +} + +func (b *executorBuilder) buildSimple(v *plannercore.Simple) exec.Executor { + switch s := v.Statement.(type) { + case *ast.GrantStmt: + return b.buildGrant(s) + case *ast.RevokeStmt: + return b.buildRevoke(s) + case *ast.BRIEStmt: + return b.buildBRIE(s, v.Schema()) + case *ast.CalibrateResourceStmt: + return &calibrateresource.Executor{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), 0), + WorkloadType: s.Tp, + OptionList: s.DynamicCalibrateResourceOptionList, + } + case *ast.AddQueryWatchStmt: + return &querywatch.AddExecutor{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), 0), + QueryWatchOptionList: s.QueryWatchOptionList, + } + case *ast.ImportIntoActionStmt: + return &ImportIntoActionExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, nil, 0), + tp: s.Tp, + jobID: s.JobID, + } + } + base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()) + base.SetInitCap(chunk.ZeroCapacity) + e := &SimpleExec{ + BaseExecutor: base, + Statement: v.Statement, + IsFromRemote: v.IsFromRemote, + is: b.is, + staleTxnStartTS: v.StaleTxnStartTS, + } + return e +} + +func (b *executorBuilder) buildSet(v *plannercore.Set) exec.Executor { + base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()) + base.SetInitCap(chunk.ZeroCapacity) + e := &SetExecutor{ + BaseExecutor: base, + vars: v.VarAssigns, + } + return e +} + +func (b *executorBuilder) buildSetConfig(v *plannercore.SetConfig) exec.Executor { + return &SetConfigExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + p: v, + } +} + +func (b *executorBuilder) buildInsert(v *plannercore.Insert) exec.Executor { + b.inInsertStmt = true + if b.err = b.updateForUpdateTS(); b.err != nil { + return nil + } + + selectExec := b.build(v.SelectPlan) + if b.err != nil { + return nil + } + var baseExec exec.BaseExecutor + if selectExec != nil { + baseExec = exec.NewBaseExecutor(b.ctx, nil, v.ID(), selectExec) + } else { + baseExec = exec.NewBaseExecutor(b.ctx, nil, v.ID()) + } + baseExec.SetInitCap(chunk.ZeroCapacity) + + ivs := &InsertValues{ + BaseExecutor: baseExec, + Table: v.Table, + Columns: v.Columns, + Lists: v.Lists, + GenExprs: v.GenCols.Exprs, + allAssignmentsAreConstant: v.AllAssignmentsAreConstant, + hasRefCols: v.NeedFillDefaultValue, + SelectExec: selectExec, + rowLen: v.RowLen, + } + err := ivs.initInsertColumns() + if err != nil { + b.err = err + return nil + } + ivs.fkChecks, b.err = buildFKCheckExecs(b.ctx, ivs.Table, v.FKChecks) + if b.err != nil { + return nil + } + ivs.fkCascades, b.err = b.buildFKCascadeExecs(ivs.Table, v.FKCascades) + if b.err != nil { + return nil + } + + if v.IsReplace { + return b.buildReplace(ivs) + } + insert := &InsertExec{ + InsertValues: ivs, + OnDuplicate: append(v.OnDuplicate, v.GenCols.OnDuplicates...), + } + return insert +} + +func (b *executorBuilder) buildImportInto(v *plannercore.ImportInto) exec.Executor { + // see planBuilder.buildImportInto for detail why we use the latest schema here. 
+ latestIS := b.ctx.GetDomainInfoSchema().(infoschema.InfoSchema) + tbl, ok := latestIS.TableByID(v.Table.TableInfo.ID) + if !ok { + b.err = errors.Errorf("Can not get table %d", v.Table.TableInfo.ID) + return nil + } + if !tbl.Meta().IsBaseTable() { + b.err = plannererrors.ErrNonUpdatableTable.GenWithStackByArgs(tbl.Meta().Name.O, "IMPORT") + return nil + } + + var ( + selectExec exec.Executor + base exec.BaseExecutor + ) + if v.SelectPlan != nil { + selectExec = b.build(v.SelectPlan) + if b.err != nil { + return nil + } + base = exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), selectExec) + } else { + base = exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()) + } + executor, err := newImportIntoExec(base, selectExec, b.ctx, v, tbl) + if err != nil { + b.err = err + return nil + } + + return executor +} + +func (b *executorBuilder) buildLoadData(v *plannercore.LoadData) exec.Executor { + tbl, ok := b.is.TableByID(v.Table.TableInfo.ID) + if !ok { + b.err = errors.Errorf("Can not get table %d", v.Table.TableInfo.ID) + return nil + } + if !tbl.Meta().IsBaseTable() { + b.err = plannererrors.ErrNonUpdatableTable.GenWithStackByArgs(tbl.Meta().Name.O, "LOAD") + return nil + } + + base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()) + worker, err := NewLoadDataWorker(b.ctx, v, tbl) + if err != nil { + b.err = err + return nil + } + + return &LoadDataExec{ + BaseExecutor: base, + loadDataWorker: worker, + FileLocRef: v.FileLocRef, + } +} + +func (b *executorBuilder) buildLoadStats(v *plannercore.LoadStats) exec.Executor { + e := &LoadStatsExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, nil, v.ID()), + info: &LoadStatsInfo{v.Path, b.ctx}, + } + return e +} + +func (b *executorBuilder) buildLockStats(v *plannercore.LockStats) exec.Executor { + e := &lockstats.LockExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, nil, v.ID()), + Tables: v.Tables, + } + return e +} + +func (b *executorBuilder) buildUnlockStats(v *plannercore.UnlockStats) exec.Executor { + e := &lockstats.UnlockExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, nil, v.ID()), + Tables: v.Tables, + } + return e +} + +func (b *executorBuilder) buildIndexAdvise(v *plannercore.IndexAdvise) exec.Executor { + e := &IndexAdviseExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, nil, v.ID()), + IsLocal: v.IsLocal, + indexAdviseInfo: &IndexAdviseInfo{ + Path: v.Path, + MaxMinutes: v.MaxMinutes, + MaxIndexNum: v.MaxIndexNum, + LineFieldsInfo: v.LineFieldsInfo, + Ctx: b.ctx, + }, + } + return e +} + +func (b *executorBuilder) buildPlanReplayer(v *plannercore.PlanReplayer) exec.Executor { + if v.Load { + e := &PlanReplayerLoadExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, nil, v.ID()), + info: &PlanReplayerLoadInfo{Path: v.File, Ctx: b.ctx}, + } + return e + } + if v.Capture { + e := &PlanReplayerExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, nil, v.ID()), + CaptureInfo: &PlanReplayerCaptureInfo{ + SQLDigest: v.SQLDigest, + PlanDigest: v.PlanDigest, + }, + } + return e + } + if v.Remove { + e := &PlanReplayerExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, nil, v.ID()), + CaptureInfo: &PlanReplayerCaptureInfo{ + SQLDigest: v.SQLDigest, + PlanDigest: v.PlanDigest, + Remove: true, + }, + } + return e + } + + e := &PlanReplayerExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + DumpInfo: &PlanReplayerDumpInfo{ + Analyze: v.Analyze, + Path: v.File, + ctx: b.ctx, + HistoricalStatsTS: v.HistoricalStatsTS, + }, + } + if v.ExecStmt != nil { + e.DumpInfo.ExecStmts = []ast.StmtNode{v.ExecStmt} + } else { + e.BaseExecutor = 
exec.NewBaseExecutor(b.ctx, nil, v.ID()) + } + return e +} + +func (*executorBuilder) buildReplace(vals *InsertValues) exec.Executor { + replaceExec := &ReplaceExec{ + InsertValues: vals, + } + return replaceExec +} + +func (b *executorBuilder) buildGrant(grant *ast.GrantStmt) exec.Executor { + e := &GrantExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, nil, 0), + Privs: grant.Privs, + ObjectType: grant.ObjectType, + Level: grant.Level, + Users: grant.Users, + WithGrant: grant.WithGrant, + AuthTokenOrTLSOptions: grant.AuthTokenOrTLSOptions, + is: b.is, + } + return e +} + +func (b *executorBuilder) buildRevoke(revoke *ast.RevokeStmt) exec.Executor { + e := &RevokeExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, nil, 0), + ctx: b.ctx, + Privs: revoke.Privs, + ObjectType: revoke.ObjectType, + Level: revoke.Level, + Users: revoke.Users, + is: b.is, + } + return e +} + +func (b *executorBuilder) buildDDL(v *plannercore.DDL) exec.Executor { + e := &DDLExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + ddlExecutor: domain.GetDomain(b.ctx).DDLExecutor(), + stmt: v.Statement, + is: b.is, + tempTableDDL: temptable.GetTemporaryTableDDL(b.ctx), + } + return e +} + +// buildTrace builds a TraceExec for future executing. This method will be called +// at build(). +func (b *executorBuilder) buildTrace(v *plannercore.Trace) exec.Executor { + t := &TraceExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + stmtNode: v.StmtNode, + builder: b, + format: v.Format, + + optimizerTrace: v.OptimizerTrace, + optimizerTraceTarget: v.OptimizerTraceTarget, + } + if t.format == plannercore.TraceFormatLog && !t.optimizerTrace { + return &sortexec.SortExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), t), + ByItems: []*plannerutil.ByItems{ + {Expr: &expression.Column{ + Index: 0, + RetType: types.NewFieldType(mysql.TypeTimestamp), + }}, + }, + ExecSchema: v.Schema(), + } + } + return t +} + +// buildExplain builds a explain executor. `e.rows` collects final result to `ExplainExec`. +func (b *executorBuilder) buildExplain(v *plannercore.Explain) exec.Executor { + explainExec := &ExplainExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + explain: v, + } + if v.Analyze { + if b.ctx.GetSessionVars().StmtCtx.RuntimeStatsColl == nil { + b.ctx.GetSessionVars().StmtCtx.RuntimeStatsColl = execdetails.NewRuntimeStatsColl(nil) + } + } + // Needs to build the target plan, even if not executing it + // to get partition pruning. + explainExec.analyzeExec = b.build(v.TargetPlan) + return explainExec +} + +func (b *executorBuilder) buildSelectInto(v *plannercore.SelectInto) exec.Executor { + child := b.build(v.TargetPlan) + if b.err != nil { + return nil + } + return &SelectIntoExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), child), + intoOpt: v.IntoOpt, + LineFieldsInfo: v.LineFieldsInfo, + } +} + +func (b *executorBuilder) buildUnionScanExec(v *plannercore.PhysicalUnionScan) exec.Executor { + oriEncounterUnionScan := b.encounterUnionScan + b.encounterUnionScan = true + defer func() { + b.encounterUnionScan = oriEncounterUnionScan + }() + reader := b.build(v.Children()[0]) + if b.err != nil { + return nil + } + + return b.buildUnionScanFromReader(reader, v) +} + +// buildUnionScanFromReader builds union scan executor from child executor. +// Note that this function may be called by inner workers of index lookup join concurrently. +// Be careful to avoid data race. 
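// Conceptually, the union scan overlays the transaction's uncommitted writes
// on the rows coming back from the snapshot reader, keyed by handle, so a
// statement observes its own in-memory modifications. A minimal sketch of
// that overlay with plain maps and slices (simplified stand-ins, not the
// real UnionScanExec logic):

package sketch

import "sort"

type row struct {
    handle int64
    data   string
}

// unionScan merges snapshot rows with the transaction's dirty rows: a dirty
// entry with the same handle replaces the snapshot row, a nil payload marks
// a delete, and handles only present in dirty are newly inserted rows.
// Output is ordered by handle.
func unionScan(snapshot []row, dirty map[int64]*string) []row {
    merged := make(map[int64]string, len(snapshot))
    for _, r := range snapshot {
        merged[r.handle] = r.data
    }
    for h, d := range dirty {
        if d == nil {
            delete(merged, h) // deleted inside this transaction
        } else {
            merged[h] = *d // updated or inserted inside this transaction
        }
    }
    out := make([]row, 0, len(merged))
    for h, d := range merged {
        out = append(out, row{handle: h, data: d})
    }
    sort.Slice(out, func(i, j int) bool { return out[i].handle < out[j].handle })
    return out
}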
+func (b *executorBuilder) buildUnionScanFromReader(reader exec.Executor, v *plannercore.PhysicalUnionScan) exec.Executor { + // If reader is union, it means a partition table and we should transfer as above. + if x, ok := reader.(*unionexec.UnionExec); ok { + for i, child := range x.AllChildren() { + x.SetChildren(i, b.buildUnionScanFromReader(child, v)) + if b.err != nil { + return nil + } + } + return x + } + us := &UnionScanExec{BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), reader)} + // Get the handle column index of the below Plan. + us.handleCols = v.HandleCols + us.mutableRow = chunk.MutRowFromTypes(exec.RetTypes(us)) + + // If the push-downed condition contains virtual column, we may build a selection upon reader + originReader := reader + if sel, ok := reader.(*SelectionExec); ok { + reader = sel.Children(0) + } + + us.collators = make([]collate.Collator, 0, len(us.columns)) + for _, tp := range exec.RetTypes(us) { + us.collators = append(us.collators, collate.GetCollator(tp.GetCollate())) + } + + startTS, err := b.getSnapshotTS() + sessionVars := b.ctx.GetSessionVars() + if err != nil { + b.err = err + return nil + } + + switch x := reader.(type) { + case *MPPGather: + us.desc = false + us.keepOrder = false + us.conditions, us.conditionsWithVirCol = plannercore.SplitSelCondsWithVirtualColumn(v.Conditions) + us.columns = x.columns + us.table = x.table + us.virtualColumnIndex = x.virtualColumnIndex + us.handleCachedTable(b, x, sessionVars, startTS) + case *TableReaderExecutor: + us.desc = x.desc + us.keepOrder = x.keepOrder + us.conditions, us.conditionsWithVirCol = plannercore.SplitSelCondsWithVirtualColumn(v.Conditions) + us.columns = x.columns + us.table = x.table + us.virtualColumnIndex = x.virtualColumnIndex + us.handleCachedTable(b, x, sessionVars, startTS) + case *IndexReaderExecutor: + us.desc = x.desc + us.keepOrder = x.keepOrder + for _, ic := range x.index.Columns { + for i, col := range x.columns { + if col.Name.L == ic.Name.L { + us.usedIndex = append(us.usedIndex, i) + break + } + } + } + us.conditions, us.conditionsWithVirCol = plannercore.SplitSelCondsWithVirtualColumn(v.Conditions) + us.columns = x.columns + us.partitionIDMap = x.partitionIDMap + us.table = x.table + us.handleCachedTable(b, x, sessionVars, startTS) + case *IndexLookUpExecutor: + us.desc = x.desc + us.keepOrder = x.keepOrder + for _, ic := range x.index.Columns { + for i, col := range x.columns { + if col.Name.L == ic.Name.L { + us.usedIndex = append(us.usedIndex, i) + break + } + } + } + us.conditions, us.conditionsWithVirCol = plannercore.SplitSelCondsWithVirtualColumn(v.Conditions) + us.columns = x.columns + us.table = x.table + us.partitionIDMap = x.partitionIDMap + us.virtualColumnIndex = buildVirtualColumnIndex(us.Schema(), us.columns) + us.handleCachedTable(b, x, sessionVars, startTS) + case *IndexMergeReaderExecutor: + if len(x.byItems) != 0 { + us.keepOrder = x.keepOrder + us.desc = x.byItems[0].Desc + for _, item := range x.byItems { + c, ok := item.Expr.(*expression.Column) + if !ok { + b.err = errors.Errorf("Not support non-column in orderBy pushed down") + return nil + } + for i, col := range x.columns { + if col.ID == c.ID { + us.usedIndex = append(us.usedIndex, i) + break + } + } + } + } + us.partitionIDMap = x.partitionIDMap + us.conditions, us.conditionsWithVirCol = plannercore.SplitSelCondsWithVirtualColumn(v.Conditions) + us.columns = x.columns + us.table = x.table + us.virtualColumnIndex = buildVirtualColumnIndex(us.Schema(), us.columns) + case *PointGetExecutor, 
*BatchPointGetExec, // PointGet and BatchPoint can handle virtual columns and dirty txn data themselves. + *TableDualExec, // If TableDual, the result must be empty, so we can skip UnionScan and use TableDual directly here. + *TableSampleExecutor: // TableSample only supports sampling from disk, don't need to consider in-memory txn data for simplicity. + return originReader + default: + // TODO: consider more operators like Projection. + b.err = errors.NewNoStackErrorf("unexpected operator %T under UnionScan", reader) + return nil + } + return us +} + +type bypassDataSourceExecutor interface { + dataSourceExecutor + setDummy() +} + +func (us *UnionScanExec) handleCachedTable(b *executorBuilder, x bypassDataSourceExecutor, vars *variable.SessionVars, startTS uint64) { + tbl := x.Table() + if tbl.Meta().TableCacheStatusType == model.TableCacheStatusEnable { + cachedTable := tbl.(table.CachedTable) + // Determine whether the cache can be used. + leaseDuration := time.Duration(variable.TableCacheLease.Load()) * time.Second + cacheData, loading := cachedTable.TryReadFromCache(startTS, leaseDuration) + if cacheData != nil { + vars.StmtCtx.ReadFromTableCache = true + x.setDummy() + us.cacheTable = cacheData + } else if loading { + return + } else { + if !b.inUpdateStmt && !b.inDeleteStmt && !b.inInsertStmt && !vars.StmtCtx.InExplainStmt { + store := b.ctx.GetStore() + cachedTable.UpdateLockForRead(context.Background(), store, startTS, leaseDuration) + } + } + } +} + +// buildMergeJoin builds MergeJoinExec executor. +func (b *executorBuilder) buildMergeJoin(v *plannercore.PhysicalMergeJoin) exec.Executor { + leftExec := b.build(v.Children()[0]) + if b.err != nil { + return nil + } + + rightExec := b.build(v.Children()[1]) + if b.err != nil { + return nil + } + + defaultValues := v.DefaultValues + if defaultValues == nil { + if v.JoinType == plannercore.RightOuterJoin { + defaultValues = make([]types.Datum, leftExec.Schema().Len()) + } else { + defaultValues = make([]types.Datum, rightExec.Schema().Len()) + } + } + + colsFromChildren := v.Schema().Columns + if v.JoinType == plannercore.LeftOuterSemiJoin || v.JoinType == plannercore.AntiLeftOuterSemiJoin { + colsFromChildren = colsFromChildren[:len(colsFromChildren)-1] + } + + e := &join.MergeJoinExec{ + StmtCtx: b.ctx.GetSessionVars().StmtCtx, + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), leftExec, rightExec), + CompareFuncs: v.CompareFuncs, + Joiner: join.NewJoiner( + b.ctx, + v.JoinType, + v.JoinType == plannercore.RightOuterJoin, + defaultValues, + v.OtherConditions, + exec.RetTypes(leftExec), + exec.RetTypes(rightExec), + markChildrenUsedCols(colsFromChildren, v.Children()[0].Schema(), v.Children()[1].Schema()), + false, + ), + IsOuterJoin: v.JoinType.IsOuterJoin(), + Desc: v.Desc, + } + + leftTable := &join.MergeJoinTable{ + ChildIndex: 0, + JoinKeys: v.LeftJoinKeys, + Filters: v.LeftConditions, + } + rightTable := &join.MergeJoinTable{ + ChildIndex: 1, + JoinKeys: v.RightJoinKeys, + Filters: v.RightConditions, + } + + if v.JoinType == plannercore.RightOuterJoin { + e.InnerTable = leftTable + e.OuterTable = rightTable + } else { + e.InnerTable = rightTable + e.OuterTable = leftTable + } + e.InnerTable.IsInner = true + + // optimizer should guarantee that filters on inner table are pushed down + // to tikv or extracted to a Selection. 
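// The inner/outer assignment above follows one rule: the outer side is the
// child whose unmatched rows must be preserved, so a right outer join makes
// the right child the outer table and demotes the left child to the inner
// (buffered) side, while every other join type keeps the left child as the
// outer side. The inner side must also arrive with no residual filters,
// which is what the check right below asserts. A minimal sketch of the
// selection rule with simplified stand-in types (hypothetical, not the
// planner's structures):

package sketch

type joinType int

const (
    innerJoin joinType = iota
    leftOuterJoin
    rightOuterJoin
)

type mergeSide struct{ name string }

// chooseMergeSides returns (outer, inner) for a merge join. The inner side
// is the one whose equal-key groups are buffered, so any filter on it is
// expected to have been pushed below the join already.
func chooseMergeSides(tp joinType, left, right mergeSide) (outer, inner mergeSide) {
    if tp == rightOuterJoin {
        return right, left
    }
    return left, right
}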
+ if len(e.InnerTable.Filters) != 0 { + b.err = errors.Annotate(exeerrors.ErrBuildExecutor, "merge join's inner filter should be empty.") + return nil + } + + executor_metrics.ExecutorCounterMergeJoinExec.Inc() + return e +} + +func collectColumnIndexFromExpr(expr expression.Expression, leftColumnSize int, leftColumnIndex []int, rightColumnIndex []int) ([]int, []int) { + switch x := expr.(type) { + case *expression.Column: + colIndex := x.Index + if colIndex >= leftColumnSize { + rightColumnIndex = append(rightColumnIndex, colIndex-leftColumnSize) + } else { + leftColumnIndex = append(leftColumnIndex, colIndex) + } + return leftColumnIndex, rightColumnIndex + case *expression.Constant: + return leftColumnIndex, rightColumnIndex + case *expression.ScalarFunction: + for _, arg := range x.GetArgs() { + leftColumnIndex, rightColumnIndex = collectColumnIndexFromExpr(arg, leftColumnSize, leftColumnIndex, rightColumnIndex) + } + return leftColumnIndex, rightColumnIndex + default: + panic("unsupported expression") + } +} + +func extractUsedColumnsInJoinOtherCondition(expr expression.CNFExprs, leftColumnSize int) ([]int, []int) { + leftColumnIndex := make([]int, 0, 1) + rightColumnIndex := make([]int, 0, 1) + for _, subExpr := range expr { + leftColumnIndex, rightColumnIndex = collectColumnIndexFromExpr(subExpr, leftColumnSize, leftColumnIndex, rightColumnIndex) + } + return leftColumnIndex, rightColumnIndex +} + +func (b *executorBuilder) buildHashJoinV2(v *plannercore.PhysicalHashJoin) exec.Executor { + leftExec := b.build(v.Children()[0]) + if b.err != nil { + return nil + } + + rightExec := b.build(v.Children()[1]) + if b.err != nil { + return nil + } + + e := &join.HashJoinV2Exec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), leftExec, rightExec), + ProbeSideTupleFetcher: &join.ProbeSideTupleFetcherV2{}, + ProbeWorkers: make([]*join.ProbeWorkerV2, v.Concurrency), + BuildWorkers: make([]*join.BuildWorkerV2, v.Concurrency), + HashJoinCtxV2: &join.HashJoinCtxV2{ + OtherCondition: v.OtherConditions, + }, + } + e.HashJoinCtxV2.SessCtx = b.ctx + e.HashJoinCtxV2.JoinType = v.JoinType + e.HashJoinCtxV2.Concurrency = v.Concurrency + e.HashJoinCtxV2.SetupPartitionInfo() + e.ChunkAllocPool = e.AllocPool + e.HashJoinCtxV2.RightAsBuildSide = true + if v.InnerChildIdx == 1 && v.UseOuterToBuild { + e.HashJoinCtxV2.RightAsBuildSide = false + } else if v.InnerChildIdx == 0 && !v.UseOuterToBuild { + e.HashJoinCtxV2.RightAsBuildSide = false + } + + lhsTypes, rhsTypes := exec.RetTypes(leftExec), exec.RetTypes(rightExec) + joinedTypes := make([]*types.FieldType, 0, len(lhsTypes)+len(rhsTypes)) + joinedTypes = append(joinedTypes, lhsTypes...) + joinedTypes = append(joinedTypes, rhsTypes...) 
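// Which child feeds the hash table is derived from two plan flags: the index
// of the inner child and whether the optimizer asked to build from the outer
// side (UseOuterToBuild, used when outer rows build and inner rows probe).
// A minimal sketch of that decision as a pure function (simplified, not the
// executor's fields):

package sketch

// buildSideIsRight reports whether the right child should feed the hash
// table. By default the inner child builds; with useOuterToBuild the roles
// flip and the outer child builds instead.
//
//	innerChildIdx  useOuterToBuild  build side
//	      0              false        left  (inner)
//	      0              true         right (outer)
//	      1              false        right (inner)
//	      1              true         left  (outer)
func buildSideIsRight(innerChildIdx int, useOuterToBuild bool) bool {
    buildIdx := innerChildIdx
    if useOuterToBuild {
        buildIdx = 1 - innerChildIdx // flip to the outer child
    }
    return buildIdx == 1
}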
+ + if v.InnerChildIdx == 1 { + if len(v.RightConditions) > 0 { + b.err = errors.Annotate(exeerrors.ErrBuildExecutor, "join's inner condition should be empty") + return nil + } + } else { + if len(v.LeftConditions) > 0 { + b.err = errors.Annotate(exeerrors.ErrBuildExecutor, "join's inner condition should be empty") + return nil + } + } + + var probeKeys, buildKeys []*expression.Column + var buildSideExec exec.Executor + if v.UseOuterToBuild { + if v.InnerChildIdx == 1 { + buildSideExec, buildKeys = leftExec, v.LeftJoinKeys + e.ProbeSideTupleFetcher.ProbeSideExec, probeKeys = rightExec, v.RightJoinKeys + e.HashJoinCtxV2.BuildFilter = v.LeftConditions + } else { + buildSideExec, buildKeys = rightExec, v.RightJoinKeys + e.ProbeSideTupleFetcher.ProbeSideExec, probeKeys = leftExec, v.LeftJoinKeys + e.HashJoinCtxV2.BuildFilter = v.RightConditions + } + } else { + if v.InnerChildIdx == 0 { + buildSideExec, buildKeys = leftExec, v.LeftJoinKeys + e.ProbeSideTupleFetcher.ProbeSideExec, probeKeys = rightExec, v.RightJoinKeys + e.HashJoinCtxV2.ProbeFilter = v.RightConditions + } else { + buildSideExec, buildKeys = rightExec, v.RightJoinKeys + e.ProbeSideTupleFetcher.ProbeSideExec, probeKeys = leftExec, v.LeftJoinKeys + e.HashJoinCtxV2.ProbeFilter = v.LeftConditions + } + } + probeKeyColIdx := make([]int, len(probeKeys)) + buildKeyColIdx := make([]int, len(buildKeys)) + for i := range buildKeys { + buildKeyColIdx[i] = buildKeys[i].Index + } + for i := range probeKeys { + probeKeyColIdx[i] = probeKeys[i].Index + } + + colsFromChildren := v.Schema().Columns + if v.JoinType == plannercore.LeftOuterSemiJoin || v.JoinType == plannercore.AntiLeftOuterSemiJoin { + // the matched column is added inside join + colsFromChildren = colsFromChildren[:len(colsFromChildren)-1] + } + childrenUsedSchema := markChildrenUsedCols(colsFromChildren, v.Children()[0].Schema(), v.Children()[1].Schema()) + if childrenUsedSchema == nil { + b.err = errors.New("children used should never be nil") + return nil + } + e.LUsed = make([]int, 0, len(childrenUsedSchema[0])) + e.LUsed = append(e.LUsed, childrenUsedSchema[0]...) + e.RUsed = make([]int, 0, len(childrenUsedSchema[1])) + e.RUsed = append(e.RUsed, childrenUsedSchema[1]...) 
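// In the joined row layout the left child's columns occupy offsets
// [0, leftLen) and the right child's columns follow them, so every column
// referenced by a pushed-down condition can be attributed to one side by
// comparing its offset with the left schema length and rebasing right-side
// offsets. A minimal sketch of that split, mirroring what
// collectColumnIndexFromExpr does over expression trees (simplified to plain
// ints):

package sketch

// splitColumnOffsets partitions column offsets used by a join condition into
// positions local to the left and right children.
func splitColumnOffsets(offsets []int, leftLen int) (left, right []int) {
    for _, off := range offsets {
        if off < leftLen {
            left = append(left, off)
        } else {
            right = append(right, off-leftLen) // rebase onto the right child
        }
    }
    return left, right
}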
+ if v.OtherConditions != nil { + leftColumnSize := v.Children()[0].Schema().Len() + e.LUsedInOtherCondition, e.RUsedInOtherCondition = extractUsedColumnsInJoinOtherCondition(v.OtherConditions, leftColumnSize) + } + // todo add partition hash join exec + executor_metrics.ExecutorCountHashJoinExec.Inc() + + leftExecTypes, rightExecTypes := exec.RetTypes(leftExec), exec.RetTypes(rightExec) + leftTypes, rightTypes := make([]*types.FieldType, 0, len(v.LeftJoinKeys)+len(v.LeftNAJoinKeys)), make([]*types.FieldType, 0, len(v.RightJoinKeys)+len(v.RightNAJoinKeys)) + for i, col := range v.LeftJoinKeys { + leftTypes = append(leftTypes, leftExecTypes[col.Index].Clone()) + leftTypes[i].SetFlag(col.RetType.GetFlag()) + } + offset := len(v.LeftJoinKeys) + for i, col := range v.LeftNAJoinKeys { + leftTypes = append(leftTypes, leftExecTypes[col.Index].Clone()) + leftTypes[i+offset].SetFlag(col.RetType.GetFlag()) + } + for i, col := range v.RightJoinKeys { + rightTypes = append(rightTypes, rightExecTypes[col.Index].Clone()) + rightTypes[i].SetFlag(col.RetType.GetFlag()) + } + offset = len(v.RightJoinKeys) + for i, col := range v.RightNAJoinKeys { + rightTypes = append(rightTypes, rightExecTypes[col.Index].Clone()) + rightTypes[i+offset].SetFlag(col.RetType.GetFlag()) + } + + // consider collations + for i := range v.EqualConditions { + chs, coll := v.EqualConditions[i].CharsetAndCollation() + leftTypes[i].SetCharset(chs) + leftTypes[i].SetCollate(coll) + rightTypes[i].SetCharset(chs) + rightTypes[i].SetCollate(coll) + } + offset = len(v.EqualConditions) + for i := range v.NAEqualConditions { + chs, coll := v.NAEqualConditions[i].CharsetAndCollation() + leftTypes[i+offset].SetCharset(chs) + leftTypes[i+offset].SetCollate(coll) + rightTypes[i+offset].SetCharset(chs) + rightTypes[i+offset].SetCollate(coll) + } + if e.RightAsBuildSide { + e.BuildKeyTypes, e.ProbeKeyTypes = rightTypes, leftTypes + } else { + e.BuildKeyTypes, e.ProbeKeyTypes = leftTypes, rightTypes + } + for i := uint(0); i < e.Concurrency; i++ { + e.ProbeWorkers[i] = &join.ProbeWorkerV2{ + HashJoinCtx: e.HashJoinCtxV2, + JoinProbe: join.NewJoinProbe(e.HashJoinCtxV2, i, v.JoinType, probeKeyColIdx, joinedTypes, e.ProbeKeyTypes, e.RightAsBuildSide), + } + e.ProbeWorkers[i].WorkerID = i + + e.BuildWorkers[i] = join.NewJoinBuildWorkerV2(e.HashJoinCtxV2, i, buildSideExec, buildKeyColIdx, exec.RetTypes(buildSideExec)) + } + return e +} + +func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) exec.Executor { + if join.IsHashJoinV2Enabled() && v.CanUseHashJoinV2() { + return b.buildHashJoinV2(v) + } + leftExec := b.build(v.Children()[0]) + if b.err != nil { + return nil + } + + rightExec := b.build(v.Children()[1]) + if b.err != nil { + return nil + } + + e := &join.HashJoinV1Exec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), leftExec, rightExec), + ProbeSideTupleFetcher: &join.ProbeSideTupleFetcherV1{}, + ProbeWorkers: make([]*join.ProbeWorkerV1, v.Concurrency), + BuildWorker: &join.BuildWorkerV1{}, + HashJoinCtxV1: &join.HashJoinCtxV1{ + IsOuterJoin: v.JoinType.IsOuterJoin(), + UseOuterToBuild: v.UseOuterToBuild, + }, + } + e.HashJoinCtxV1.SessCtx = b.ctx + e.HashJoinCtxV1.JoinType = v.JoinType + e.HashJoinCtxV1.Concurrency = v.Concurrency + e.HashJoinCtxV1.ChunkAllocPool = e.AllocPool + defaultValues := v.DefaultValues + lhsTypes, rhsTypes := exec.RetTypes(leftExec), exec.RetTypes(rightExec) + if v.InnerChildIdx == 1 { + if len(v.RightConditions) > 0 { + b.err = errors.Annotate(exeerrors.ErrBuildExecutor, "join's 
inner condition should be empty") + return nil + } + } else { + if len(v.LeftConditions) > 0 { + b.err = errors.Annotate(exeerrors.ErrBuildExecutor, "join's inner condition should be empty") + return nil + } + } + + leftIsBuildSide := true + + e.IsNullEQ = v.IsNullEQ + var probeKeys, probeNAKeys, buildKeys, buildNAKeys []*expression.Column + var buildSideExec exec.Executor + if v.UseOuterToBuild { + // update the buildSideEstCount due to changing the build side + if v.InnerChildIdx == 1 { + buildSideExec, buildKeys, buildNAKeys = leftExec, v.LeftJoinKeys, v.LeftNAJoinKeys + e.ProbeSideTupleFetcher.ProbeSideExec, probeKeys, probeNAKeys = rightExec, v.RightJoinKeys, v.RightNAJoinKeys + e.OuterFilter = v.LeftConditions + } else { + buildSideExec, buildKeys, buildNAKeys = rightExec, v.RightJoinKeys, v.RightNAJoinKeys + e.ProbeSideTupleFetcher.ProbeSideExec, probeKeys, probeNAKeys = leftExec, v.LeftJoinKeys, v.LeftNAJoinKeys + e.OuterFilter = v.RightConditions + leftIsBuildSide = false + } + if defaultValues == nil { + defaultValues = make([]types.Datum, e.ProbeSideTupleFetcher.ProbeSideExec.Schema().Len()) + } + } else { + if v.InnerChildIdx == 0 { + buildSideExec, buildKeys, buildNAKeys = leftExec, v.LeftJoinKeys, v.LeftNAJoinKeys + e.ProbeSideTupleFetcher.ProbeSideExec, probeKeys, probeNAKeys = rightExec, v.RightJoinKeys, v.RightNAJoinKeys + e.OuterFilter = v.RightConditions + } else { + buildSideExec, buildKeys, buildNAKeys = rightExec, v.RightJoinKeys, v.RightNAJoinKeys + e.ProbeSideTupleFetcher.ProbeSideExec, probeKeys, probeNAKeys = leftExec, v.LeftJoinKeys, v.LeftNAJoinKeys + e.OuterFilter = v.LeftConditions + leftIsBuildSide = false + } + if defaultValues == nil { + defaultValues = make([]types.Datum, buildSideExec.Schema().Len()) + } + } + probeKeyColIdx := make([]int, len(probeKeys)) + probeNAKeColIdx := make([]int, len(probeNAKeys)) + buildKeyColIdx := make([]int, len(buildKeys)) + buildNAKeyColIdx := make([]int, len(buildNAKeys)) + for i := range buildKeys { + buildKeyColIdx[i] = buildKeys[i].Index + } + for i := range buildNAKeys { + buildNAKeyColIdx[i] = buildNAKeys[i].Index + } + for i := range probeKeys { + probeKeyColIdx[i] = probeKeys[i].Index + } + for i := range probeNAKeys { + probeNAKeColIdx[i] = probeNAKeys[i].Index + } + isNAJoin := len(v.LeftNAJoinKeys) > 0 + colsFromChildren := v.Schema().Columns + if v.JoinType == plannercore.LeftOuterSemiJoin || v.JoinType == plannercore.AntiLeftOuterSemiJoin { + colsFromChildren = colsFromChildren[:len(colsFromChildren)-1] + } + childrenUsedSchema := markChildrenUsedCols(colsFromChildren, v.Children()[0].Schema(), v.Children()[1].Schema()) + for i := uint(0); i < e.Concurrency; i++ { + e.ProbeWorkers[i] = &join.ProbeWorkerV1{ + HashJoinCtx: e.HashJoinCtxV1, + Joiner: join.NewJoiner(b.ctx, v.JoinType, v.InnerChildIdx == 0, defaultValues, v.OtherConditions, lhsTypes, rhsTypes, childrenUsedSchema, isNAJoin), + ProbeKeyColIdx: probeKeyColIdx, + ProbeNAKeyColIdx: probeNAKeColIdx, + } + e.ProbeWorkers[i].WorkerID = i + } + e.BuildWorker.BuildKeyColIdx, e.BuildWorker.BuildNAKeyColIdx, e.BuildWorker.BuildSideExec, e.BuildWorker.HashJoinCtx = buildKeyColIdx, buildNAKeyColIdx, buildSideExec, e.HashJoinCtxV1 + e.HashJoinCtxV1.IsNullAware = isNAJoin + executor_metrics.ExecutorCountHashJoinExec.Inc() + + // We should use JoinKey to construct the type information using by hashing, instead of using the child's schema directly. + // When a hybrid type column is hashed multiple times, we need to distinguish what field types are used. 
+ // For example, the condition `enum = int and enum = string`, we should use ETInt to hash the first column, + // and use ETString to hash the second column, although they may be the same column. + leftExecTypes, rightExecTypes := exec.RetTypes(leftExec), exec.RetTypes(rightExec) + leftTypes, rightTypes := make([]*types.FieldType, 0, len(v.LeftJoinKeys)+len(v.LeftNAJoinKeys)), make([]*types.FieldType, 0, len(v.RightJoinKeys)+len(v.RightNAJoinKeys)) + // set left types and right types for joiner. + for i, col := range v.LeftJoinKeys { + leftTypes = append(leftTypes, leftExecTypes[col.Index].Clone()) + leftTypes[i].SetFlag(col.RetType.GetFlag()) + } + offset := len(v.LeftJoinKeys) + for i, col := range v.LeftNAJoinKeys { + leftTypes = append(leftTypes, leftExecTypes[col.Index].Clone()) + leftTypes[i+offset].SetFlag(col.RetType.GetFlag()) + } + for i, col := range v.RightJoinKeys { + rightTypes = append(rightTypes, rightExecTypes[col.Index].Clone()) + rightTypes[i].SetFlag(col.RetType.GetFlag()) + } + offset = len(v.RightJoinKeys) + for i, col := range v.RightNAJoinKeys { + rightTypes = append(rightTypes, rightExecTypes[col.Index].Clone()) + rightTypes[i+offset].SetFlag(col.RetType.GetFlag()) + } + + // consider collations + for i := range v.EqualConditions { + chs, coll := v.EqualConditions[i].CharsetAndCollation() + leftTypes[i].SetCharset(chs) + leftTypes[i].SetCollate(coll) + rightTypes[i].SetCharset(chs) + rightTypes[i].SetCollate(coll) + } + offset = len(v.EqualConditions) + for i := range v.NAEqualConditions { + chs, coll := v.NAEqualConditions[i].CharsetAndCollation() + leftTypes[i+offset].SetCharset(chs) + leftTypes[i+offset].SetCollate(coll) + rightTypes[i+offset].SetCharset(chs) + rightTypes[i+offset].SetCollate(coll) + } + if leftIsBuildSide { + e.BuildTypes, e.ProbeTypes = leftTypes, rightTypes + } else { + e.BuildTypes, e.ProbeTypes = rightTypes, leftTypes + } + return e +} + +func (b *executorBuilder) buildHashAgg(v *plannercore.PhysicalHashAgg) exec.Executor { + src := b.build(v.Children()[0]) + if b.err != nil { + return nil + } + return b.buildHashAggFromChildExec(src, v) +} + +func (b *executorBuilder) buildHashAggFromChildExec(childExec exec.Executor, v *plannercore.PhysicalHashAgg) *aggregate.HashAggExec { + sessionVars := b.ctx.GetSessionVars() + e := &aggregate.HashAggExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), childExec), + Sc: sessionVars.StmtCtx, + PartialAggFuncs: make([]aggfuncs.AggFunc, 0, len(v.AggFuncs)), + GroupByItems: v.GroupByItems, + } + // We take `create table t(a int, b int);` as example. + // + // 1. If all the aggregation functions are FIRST_ROW, we do not need to set the defaultVal for them: + // e.g. + // mysql> select distinct a, b from t; + // 0 rows in set (0.00 sec) + // + // 2. If there exists group by items, we do not need to set the defaultVal for them either: + // e.g. 
+ // mysql> select avg(a) from t group by b; + // Empty set (0.00 sec) + // + // mysql> select avg(a) from t group by a; + // +--------+ + // | avg(a) | + // +--------+ + // | NULL | + // +--------+ + // 1 row in set (0.00 sec) + if len(v.GroupByItems) != 0 || aggregation.IsAllFirstRow(v.AggFuncs) { + e.DefaultVal = nil + } else { + if v.IsFinalAgg() { + e.DefaultVal = e.AllocPool.Alloc(exec.RetTypes(e), 1, 1) + } + } + for _, aggDesc := range v.AggFuncs { + if aggDesc.HasDistinct || len(aggDesc.OrderByItems) > 0 { + e.IsUnparallelExec = true + } + } + // When we set both tidb_hashagg_final_concurrency and tidb_hashagg_partial_concurrency to 1, + // we do not need to parallelly execute hash agg, + // and this action can be a workaround when meeting some unexpected situation using parallelExec. + if finalCon, partialCon := sessionVars.HashAggFinalConcurrency(), sessionVars.HashAggPartialConcurrency(); finalCon <= 0 || partialCon <= 0 || finalCon == 1 && partialCon == 1 { + e.IsUnparallelExec = true + } + partialOrdinal := 0 + exprCtx := b.ctx.GetExprCtx() + for i, aggDesc := range v.AggFuncs { + if e.IsUnparallelExec { + e.PartialAggFuncs = append(e.PartialAggFuncs, aggfuncs.Build(exprCtx, aggDesc, i)) + } else { + ordinal := []int{partialOrdinal} + partialOrdinal++ + if aggDesc.Name == ast.AggFuncAvg { + ordinal = append(ordinal, partialOrdinal+1) + partialOrdinal++ + } + partialAggDesc, finalDesc := aggDesc.Split(ordinal) + partialAggFunc := aggfuncs.Build(exprCtx, partialAggDesc, i) + finalAggFunc := aggfuncs.Build(exprCtx, finalDesc, i) + e.PartialAggFuncs = append(e.PartialAggFuncs, partialAggFunc) + e.FinalAggFuncs = append(e.FinalAggFuncs, finalAggFunc) + if partialAggDesc.Name == ast.AggFuncGroupConcat { + // For group_concat, finalAggFunc and partialAggFunc need shared `truncate` flag to do duplicate. 
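// For parallel hash aggregation every aggregate descriptor is split into a
// partial function (run by partial workers over input chunks) and a final
// function (run over the merged partial results); AVG is the case that needs
// two partial columns, because only COUNT and SUM are mergeable and the
// final step recombines them, which is why it consumes an extra ordinal
// above. A minimal sketch of that split for AVG over float64 values
// (simplified, not the aggfuncs framework):

package sketch

// avgPartial is the mergeable intermediate state produced by one partial worker.
type avgPartial struct {
    count int64
    sum   float64
}

// partialAvg folds one chunk of values into an intermediate state.
func partialAvg(values []float64) avgPartial {
    var p avgPartial
    for _, v := range values {
        p.count++
        p.sum += v
    }
    return p
}

// finalAvg merges the partial states and produces the final average; the
// second return value is false when no rows were aggregated (NULL result).
func finalAvg(parts []avgPartial) (float64, bool) {
    var count int64
    var sum float64
    for _, p := range parts {
        count += p.count
        sum += p.sum
    }
    if count == 0 {
        return 0, false
    }
    return sum / float64(count), true
}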
+ finalAggFunc.(interface{ SetTruncated(t *int32) }).SetTruncated( + partialAggFunc.(interface{ GetTruncated() *int32 }).GetTruncated(), + ) + } + } + if e.DefaultVal != nil { + value := aggDesc.GetDefaultValue() + e.DefaultVal.AppendDatum(i, &value) + } + } + + executor_metrics.ExecutorCounterHashAggExec.Inc() + return e +} + +func (b *executorBuilder) buildStreamAgg(v *plannercore.PhysicalStreamAgg) exec.Executor { + src := b.build(v.Children()[0]) + if b.err != nil { + return nil + } + return b.buildStreamAggFromChildExec(src, v) +} + +func (b *executorBuilder) buildStreamAggFromChildExec(childExec exec.Executor, v *plannercore.PhysicalStreamAgg) *aggregate.StreamAggExec { + exprCtx := b.ctx.GetExprCtx() + e := &aggregate.StreamAggExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), childExec), + GroupChecker: vecgroupchecker.NewVecGroupChecker(exprCtx.GetEvalCtx(), b.ctx.GetSessionVars().EnableVectorizedExpression, v.GroupByItems), + AggFuncs: make([]aggfuncs.AggFunc, 0, len(v.AggFuncs)), + } + + if len(v.GroupByItems) != 0 || aggregation.IsAllFirstRow(v.AggFuncs) { + e.DefaultVal = nil + } else { + // Only do this for final agg, see issue #35295, #30923 + if v.IsFinalAgg() { + e.DefaultVal = e.AllocPool.Alloc(exec.RetTypes(e), 1, 1) + } + } + for i, aggDesc := range v.AggFuncs { + aggFunc := aggfuncs.Build(exprCtx, aggDesc, i) + e.AggFuncs = append(e.AggFuncs, aggFunc) + if e.DefaultVal != nil { + value := aggDesc.GetDefaultValue() + e.DefaultVal.AppendDatum(i, &value) + } + } + + executor_metrics.ExecutorStreamAggExec.Inc() + return e +} + +func (b *executorBuilder) buildSelection(v *plannercore.PhysicalSelection) exec.Executor { + childExec := b.build(v.Children()[0]) + if b.err != nil { + return nil + } + e := &SelectionExec{ + selectionExecutorContext: newSelectionExecutorContext(b.ctx), + BaseExecutorV2: exec.NewBaseExecutorV2(b.ctx.GetSessionVars(), v.Schema(), v.ID(), childExec), + filters: v.Conditions, + } + return e +} + +func (b *executorBuilder) buildExpand(v *plannercore.PhysicalExpand) exec.Executor { + childExec := b.build(v.Children()[0]) + if b.err != nil { + return nil + } + levelES := make([]*expression.EvaluatorSuite, 0, len(v.LevelExprs)) + for _, exprs := range v.LevelExprs { + // column evaluator can always refer others inside expand. + // grouping column's nullability change should be seen as a new column projecting. + // since input inside expand logic should be targeted and reused for N times. + // column evaluator's swapping columns logic will pollute the input data. + levelE := expression.NewEvaluatorSuite(exprs, true) + levelES = append(levelES, levelE) + } + e := &ExpandExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), childExec), + numWorkers: int64(b.ctx.GetSessionVars().ProjectionConcurrency()), + levelEvaluatorSuits: levelES, + } + + // If the calculation row count for this Projection operator is smaller + // than a Chunk size, we turn back to the un-parallel Projection + // implementation to reduce the goroutine overhead. + if int64(v.StatsCount()) < int64(b.ctx.GetSessionVars().MaxChunkSize) { + e.numWorkers = 0 + } + + // Use un-parallel projection for query that write on memdb to avoid data race. 
+ // See also https://github.com/pingcap/tidb/issues/26832 + if b.inUpdateStmt || b.inDeleteStmt || b.inInsertStmt || b.hasLock { + e.numWorkers = 0 + } + return e +} + +func (b *executorBuilder) buildProjection(v *plannercore.PhysicalProjection) exec.Executor { + childExec := b.build(v.Children()[0]) + if b.err != nil { + return nil + } + e := &ProjectionExec{ + projectionExecutorContext: newProjectionExecutorContext(b.ctx), + BaseExecutorV2: exec.NewBaseExecutorV2(b.ctx.GetSessionVars(), v.Schema(), v.ID(), childExec), + numWorkers: int64(b.ctx.GetSessionVars().ProjectionConcurrency()), + evaluatorSuit: expression.NewEvaluatorSuite(v.Exprs, v.AvoidColumnEvaluator), + calculateNoDelay: v.CalculateNoDelay, + } + + // If the calculation row count for this Projection operator is smaller + // than a Chunk size, we turn back to the un-parallel Projection + // implementation to reduce the goroutine overhead. + if int64(v.StatsCount()) < int64(b.ctx.GetSessionVars().MaxChunkSize) { + e.numWorkers = 0 + } + + // Use un-parallel projection for query that write on memdb to avoid data race. + // See also https://github.com/pingcap/tidb/issues/26832 + if b.inUpdateStmt || b.inDeleteStmt || b.inInsertStmt || b.hasLock { + e.numWorkers = 0 + } + return e +} + +func (b *executorBuilder) buildTableDual(v *plannercore.PhysicalTableDual) exec.Executor { + if v.RowCount != 0 && v.RowCount != 1 { + b.err = errors.Errorf("buildTableDual failed, invalid row count for dual table: %v", v.RowCount) + return nil + } + base := exec.NewBaseExecutorV2(b.ctx.GetSessionVars(), v.Schema(), v.ID()) + base.SetInitCap(v.RowCount) + e := &TableDualExec{ + BaseExecutorV2: base, + numDualRows: v.RowCount, + } + return e +} + +// `getSnapshotTS` returns for-update-ts if in insert/update/delete/lock statement otherwise the isolation read ts +// Please notice that in RC isolation, the above two ts are the same +func (b *executorBuilder) getSnapshotTS() (ts uint64, err error) { + if b.forDataReaderBuilder { + return b.dataReaderTS, nil + } + + txnManager := sessiontxn.GetTxnManager(b.ctx) + if b.inInsertStmt || b.inUpdateStmt || b.inDeleteStmt || b.inSelectLockStmt { + return txnManager.GetStmtForUpdateTS() + } + return txnManager.GetStmtReadTS() +} + +// getSnapshot get the appropriate snapshot from txnManager and set +// the relevant snapshot options before return. 
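// The timestamp a statement reads at depends on its kind: insert, update,
// delete and locking reads use the transaction's for-update timestamp so
// they see the rows they are about to lock, while plain reads use the
// statement read timestamp; under read-committed isolation the two coincide.
// A minimal sketch of that selection (hypothetical txnTS holder, not the
// real sessiontxn manager):

package sketch

type txnTS struct {
    forUpdateTS uint64
    readTS      uint64
}

// snapshotTS picks the timestamp a statement should read with.
func snapshotTS(ts txnTS, inWriteOrLockStmt bool) uint64 {
    if inWriteOrLockStmt {
        return ts.forUpdateTS
    }
    return ts.readTS
}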
+func (b *executorBuilder) getSnapshot() (kv.Snapshot, error) { + var snapshot kv.Snapshot + var err error + + txnManager := sessiontxn.GetTxnManager(b.ctx) + if b.inInsertStmt || b.inUpdateStmt || b.inDeleteStmt || b.inSelectLockStmt { + snapshot, err = txnManager.GetSnapshotWithStmtForUpdateTS() + } else { + snapshot, err = txnManager.GetSnapshotWithStmtReadTS() + } + if err != nil { + return nil, err + } + + sessVars := b.ctx.GetSessionVars() + replicaReadType := sessVars.GetReplicaRead() + snapshot.SetOption(kv.ReadReplicaScope, b.readReplicaScope) + snapshot.SetOption(kv.TaskID, sessVars.StmtCtx.TaskID) + snapshot.SetOption(kv.TiKVClientReadTimeout, sessVars.GetTiKVClientReadTimeout()) + snapshot.SetOption(kv.ResourceGroupName, sessVars.StmtCtx.ResourceGroupName) + snapshot.SetOption(kv.ExplicitRequestSourceType, sessVars.ExplicitRequestSourceType) + + if replicaReadType.IsClosestRead() && b.readReplicaScope != kv.GlobalTxnScope { + snapshot.SetOption(kv.MatchStoreLabels, []*metapb.StoreLabel{ + { + Key: placement.DCLabelKey, + Value: b.readReplicaScope, + }, + }) + } + + return snapshot, nil +} + +func (b *executorBuilder) buildMemTable(v *plannercore.PhysicalMemTable) exec.Executor { + switch v.DBName.L { + case util.MetricSchemaName.L: + return &MemTableReaderExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + table: v.Table, + retriever: &MetricRetriever{ + table: v.Table, + extractor: v.Extractor.(*plannercore.MetricTableExtractor), + }, + } + case util.InformationSchemaName.L: + switch v.Table.Name.L { + case strings.ToLower(infoschema.TableClusterConfig): + return &MemTableReaderExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + table: v.Table, + retriever: &clusterConfigRetriever{ + extractor: v.Extractor.(*plannercore.ClusterTableExtractor), + }, + } + case strings.ToLower(infoschema.TableClusterLoad): + return &MemTableReaderExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + table: v.Table, + retriever: &clusterServerInfoRetriever{ + extractor: v.Extractor.(*plannercore.ClusterTableExtractor), + serverInfoType: diagnosticspb.ServerInfoType_LoadInfo, + }, + } + case strings.ToLower(infoschema.TableClusterHardware): + return &MemTableReaderExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + table: v.Table, + retriever: &clusterServerInfoRetriever{ + extractor: v.Extractor.(*plannercore.ClusterTableExtractor), + serverInfoType: diagnosticspb.ServerInfoType_HardwareInfo, + }, + } + case strings.ToLower(infoschema.TableClusterSystemInfo): + return &MemTableReaderExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + table: v.Table, + retriever: &clusterServerInfoRetriever{ + extractor: v.Extractor.(*plannercore.ClusterTableExtractor), + serverInfoType: diagnosticspb.ServerInfoType_SystemInfo, + }, + } + case strings.ToLower(infoschema.TableClusterLog): + return &MemTableReaderExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + table: v.Table, + retriever: &clusterLogRetriever{ + extractor: v.Extractor.(*plannercore.ClusterLogTableExtractor), + }, + } + case strings.ToLower(infoschema.TableTiDBHotRegionsHistory): + return &MemTableReaderExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + table: v.Table, + retriever: &hotRegionsHistoryRetriver{ + extractor: v.Extractor.(*plannercore.HotRegionsHistoryTableExtractor), + }, + } + case strings.ToLower(infoschema.TableInspectionResult): + return &MemTableReaderExec{ + BaseExecutor: 
exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + table: v.Table, + retriever: &inspectionResultRetriever{ + extractor: v.Extractor.(*plannercore.InspectionResultTableExtractor), + timeRange: v.QueryTimeRange, + }, + } + case strings.ToLower(infoschema.TableInspectionSummary): + return &MemTableReaderExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + table: v.Table, + retriever: &inspectionSummaryRetriever{ + table: v.Table, + extractor: v.Extractor.(*plannercore.InspectionSummaryTableExtractor), + timeRange: v.QueryTimeRange, + }, + } + case strings.ToLower(infoschema.TableInspectionRules): + return &MemTableReaderExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + table: v.Table, + retriever: &inspectionRuleRetriever{ + extractor: v.Extractor.(*plannercore.InspectionRuleTableExtractor), + }, + } + case strings.ToLower(infoschema.TableMetricSummary): + return &MemTableReaderExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + table: v.Table, + retriever: &MetricsSummaryRetriever{ + table: v.Table, + extractor: v.Extractor.(*plannercore.MetricSummaryTableExtractor), + timeRange: v.QueryTimeRange, + }, + } + case strings.ToLower(infoschema.TableMetricSummaryByLabel): + return &MemTableReaderExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + table: v.Table, + retriever: &MetricsSummaryByLabelRetriever{ + table: v.Table, + extractor: v.Extractor.(*plannercore.MetricSummaryTableExtractor), + timeRange: v.QueryTimeRange, + }, + } + case strings.ToLower(infoschema.TableTiKVRegionPeers): + return &MemTableReaderExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + table: v.Table, + retriever: &tikvRegionPeersRetriever{ + extractor: v.Extractor.(*plannercore.TikvRegionPeersExtractor), + }, + } + case strings.ToLower(infoschema.TableSchemata), + strings.ToLower(infoschema.TableStatistics), + strings.ToLower(infoschema.TableTiDBIndexes), + strings.ToLower(infoschema.TableViews), + strings.ToLower(infoschema.TableTables), + strings.ToLower(infoschema.TableReferConst), + strings.ToLower(infoschema.TableSequences), + strings.ToLower(infoschema.TablePartitions), + strings.ToLower(infoschema.TableEngines), + strings.ToLower(infoschema.TableCollations), + strings.ToLower(infoschema.TableAnalyzeStatus), + strings.ToLower(infoschema.TableClusterInfo), + strings.ToLower(infoschema.TableProfiling), + strings.ToLower(infoschema.TableCharacterSets), + strings.ToLower(infoschema.TableKeyColumn), + strings.ToLower(infoschema.TableUserPrivileges), + strings.ToLower(infoschema.TableMetricTables), + strings.ToLower(infoschema.TableCollationCharacterSetApplicability), + strings.ToLower(infoschema.TableProcesslist), + strings.ToLower(infoschema.ClusterTableProcesslist), + strings.ToLower(infoschema.TableTiKVRegionStatus), + strings.ToLower(infoschema.TableTiDBHotRegions), + strings.ToLower(infoschema.TableSessionVar), + strings.ToLower(infoschema.TableConstraints), + strings.ToLower(infoschema.TableTiFlashReplica), + strings.ToLower(infoschema.TableTiDBServersInfo), + strings.ToLower(infoschema.TableTiKVStoreStatus), + strings.ToLower(infoschema.TableClientErrorsSummaryGlobal), + strings.ToLower(infoschema.TableClientErrorsSummaryByUser), + strings.ToLower(infoschema.TableClientErrorsSummaryByHost), + strings.ToLower(infoschema.TableAttributes), + strings.ToLower(infoschema.TablePlacementPolicies), + strings.ToLower(infoschema.TableTrxSummary), + strings.ToLower(infoschema.TableVariablesInfo), + 
strings.ToLower(infoschema.TableUserAttributes), + strings.ToLower(infoschema.ClusterTableTrxSummary), + strings.ToLower(infoschema.TableMemoryUsage), + strings.ToLower(infoschema.TableMemoryUsageOpsHistory), + strings.ToLower(infoschema.ClusterTableMemoryUsage), + strings.ToLower(infoschema.ClusterTableMemoryUsageOpsHistory), + strings.ToLower(infoschema.TableResourceGroups), + strings.ToLower(infoschema.TableRunawayWatches), + strings.ToLower(infoschema.TableCheckConstraints), + strings.ToLower(infoschema.TableTiDBCheckConstraints), + strings.ToLower(infoschema.TableKeywords), + strings.ToLower(infoschema.TableTiDBIndexUsage), + strings.ToLower(infoschema.ClusterTableTiDBIndexUsage): + memTracker := memory.NewTracker(v.ID(), -1) + memTracker.AttachTo(b.ctx.GetSessionVars().StmtCtx.MemTracker) + return &MemTableReaderExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + table: v.Table, + retriever: &memtableRetriever{ + table: v.Table, + columns: v.Columns, + extractor: v.Extractor, + memTracker: memTracker, + }, + } + case strings.ToLower(infoschema.TableTiDBTrx), + strings.ToLower(infoschema.ClusterTableTiDBTrx): + return &MemTableReaderExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + table: v.Table, + retriever: &tidbTrxTableRetriever{ + table: v.Table, + columns: v.Columns, + }, + } + case strings.ToLower(infoschema.TableDataLockWaits): + return &MemTableReaderExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + table: v.Table, + retriever: &dataLockWaitsTableRetriever{ + table: v.Table, + columns: v.Columns, + }, + } + case strings.ToLower(infoschema.TableDeadlocks), + strings.ToLower(infoschema.ClusterTableDeadlocks): + return &MemTableReaderExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + table: v.Table, + retriever: &deadlocksTableRetriever{ + table: v.Table, + columns: v.Columns, + }, + } + case strings.ToLower(infoschema.TableStatementsSummary), + strings.ToLower(infoschema.TableStatementsSummaryHistory), + strings.ToLower(infoschema.TableStatementsSummaryEvicted), + strings.ToLower(infoschema.ClusterTableStatementsSummary), + strings.ToLower(infoschema.ClusterTableStatementsSummaryHistory), + strings.ToLower(infoschema.ClusterTableStatementsSummaryEvicted): + var extractor *plannercore.StatementsSummaryExtractor + if v.Extractor != nil { + extractor = v.Extractor.(*plannercore.StatementsSummaryExtractor) + } + return &MemTableReaderExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + table: v.Table, + retriever: buildStmtSummaryRetriever(v.Table, v.Columns, extractor), + } + case strings.ToLower(infoschema.TableColumns): + return &MemTableReaderExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + table: v.Table, + retriever: &hugeMemTableRetriever{ + table: v.Table, + columns: v.Columns, + extractor: v.Extractor.(*plannercore.ColumnsTableExtractor), + viewSchemaMap: make(map[int64]*expression.Schema), + viewOutputNamesMap: make(map[int64]types.NameSlice), + }, + } + case strings.ToLower(infoschema.TableSlowQuery), strings.ToLower(infoschema.ClusterTableSlowLog): + memTracker := memory.NewTracker(v.ID(), -1) + memTracker.AttachTo(b.ctx.GetSessionVars().StmtCtx.MemTracker) + return &MemTableReaderExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + table: v.Table, + retriever: &slowQueryRetriever{ + table: v.Table, + outputCols: v.Columns, + extractor: v.Extractor.(*plannercore.SlowQueryExtractor), + memTracker: memTracker, + }, + } + 
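// The switch above is effectively a registry keyed by lower-cased table
// name: every information_schema memory table maps to a retriever that knows
// how to produce its rows on demand. A minimal sketch of the same dispatch
// written as a map of constructors (simplified retriever interface and stub
// entries, hypothetical, not the executor's types):

package sketch

import "strings"

type retriever interface {
    retrieve() ([][]string, error)
}

type stubRetriever struct{ rows [][]string }

func (s stubRetriever) retrieve() ([][]string, error) { return s.rows, nil }

// registry maps a memory-table name to a function that builds its retriever.
var registry = map[string]func() retriever{
    "engines":    func() retriever { return stubRetriever{rows: [][]string{{"stub engine row"}}} },
    "collations": func() retriever { return stubRetriever{} },
}

// retrieverFor looks up the retriever for a table name, case-insensitively,
// much like the switch matches on strings.ToLower(...).
func retrieverFor(name string) (retriever, bool) {
    build, ok := registry[strings.ToLower(name)]
    if !ok {
        return nil, false
    }
    return build(), true
}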
case strings.ToLower(infoschema.TableStorageStats): + return &MemTableReaderExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + table: v.Table, + retriever: &tableStorageStatsRetriever{ + table: v.Table, + outputCols: v.Columns, + extractor: v.Extractor.(*plannercore.TableStorageStatsExtractor), + }, + } + case strings.ToLower(infoschema.TableDDLJobs): + loc := b.ctx.GetSessionVars().Location() + ddlJobRetriever := DDLJobRetriever{TZLoc: loc} + return &DDLJobsReaderExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + is: b.is, + DDLJobRetriever: ddlJobRetriever, + } + case strings.ToLower(infoschema.TableTiFlashTables), + strings.ToLower(infoschema.TableTiFlashSegments): + return &MemTableReaderExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + table: v.Table, + retriever: &TiFlashSystemTableRetriever{ + table: v.Table, + outputCols: v.Columns, + extractor: v.Extractor.(*plannercore.TiFlashSystemTableExtractor), + }, + } + } + } + tb, _ := b.is.TableByID(v.Table.ID) + return &TableScanExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + t: tb, + columns: v.Columns, + } +} + +func (b *executorBuilder) buildSort(v *plannercore.PhysicalSort) exec.Executor { + childExec := b.build(v.Children()[0]) + if b.err != nil { + return nil + } + sortExec := sortexec.SortExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), childExec), + ByItems: v.ByItems, + ExecSchema: v.Schema(), + } + executor_metrics.ExecutorCounterSortExec.Inc() + return &sortExec +} + +func (b *executorBuilder) buildTopN(v *plannercore.PhysicalTopN) exec.Executor { + childExec := b.build(v.Children()[0]) + if b.err != nil { + return nil + } + sortExec := sortexec.SortExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), childExec), + ByItems: v.ByItems, + ExecSchema: v.Schema(), + } + executor_metrics.ExecutorCounterTopNExec.Inc() + return &sortexec.TopNExec{ + SortExec: sortExec, + Limit: &plannercore.PhysicalLimit{Count: v.Count, Offset: v.Offset}, + Concurrency: b.ctx.GetSessionVars().Concurrency.ExecutorConcurrency, + } +} + +func (b *executorBuilder) buildApply(v *plannercore.PhysicalApply) exec.Executor { + var ( + innerPlan base.PhysicalPlan + outerPlan base.PhysicalPlan + ) + if v.InnerChildIdx == 0 { + innerPlan = v.Children()[0] + outerPlan = v.Children()[1] + } else { + innerPlan = v.Children()[1] + outerPlan = v.Children()[0] + } + v.OuterSchema = coreusage.ExtractCorColumnsBySchema4PhysicalPlan(innerPlan, outerPlan.Schema()) + leftChild := b.build(v.Children()[0]) + if b.err != nil { + return nil + } + rightChild := b.build(v.Children()[1]) + if b.err != nil { + return nil + } + // test is in the explain/naaj.test#part5. + // although we prepared the NAEqualConditions, but for Apply mode, we still need move it to other conditions like eq condition did here. + otherConditions := append(expression.ScalarFuncs2Exprs(v.EqualConditions), expression.ScalarFuncs2Exprs(v.NAEqualConditions)...) + otherConditions = append(otherConditions, v.OtherConditions...) 
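// An Apply is a nested-loop join whose inner side depends on the current
// outer row: for every outer row the correlated columns are re-bound and the
// inner side is re-evaluated, or served from a cache when the plan marks the
// result reusable. A minimal sketch of that evaluation order with plain
// functions and generics (simplified, not the chunk-based executor
// interface):

package sketch

// applyLoop evaluates a correlated inner query once per distinct outer
// binding and keeps the pairs accepted by the join condition. innerFor
// receives the outer row as the correlated binding; repeated bindings reuse
// the cached inner result when caching is allowed.
func applyLoop[O comparable, I any](
    outer []O,
    innerFor func(O) []I,
    match func(O, I) bool,
) [][2]any {
    var out [][2]any
    cache := make(map[O][]I)
    for _, o := range outer {
        inner, ok := cache[o]
        if !ok {
            inner = innerFor(o)
            cache[o] = inner
        }
        for _, i := range inner {
            if match(o, i) {
                out = append(out, [2]any{o, i})
            }
        }
    }
    return out
}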
+ defaultValues := v.DefaultValues + if defaultValues == nil { + defaultValues = make([]types.Datum, v.Children()[v.InnerChildIdx].Schema().Len()) + } + outerExec, innerExec := leftChild, rightChild + outerFilter, innerFilter := v.LeftConditions, v.RightConditions + if v.InnerChildIdx == 0 { + outerExec, innerExec = rightChild, leftChild + outerFilter, innerFilter = v.RightConditions, v.LeftConditions + } + tupleJoiner := join.NewJoiner(b.ctx, v.JoinType, v.InnerChildIdx == 0, + defaultValues, otherConditions, exec.RetTypes(leftChild), exec.RetTypes(rightChild), nil, false) + serialExec := &join.NestedLoopApplyExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), outerExec, innerExec), + InnerExec: innerExec, + OuterExec: outerExec, + OuterFilter: outerFilter, + InnerFilter: innerFilter, + Outer: v.JoinType != plannercore.InnerJoin, + Joiner: tupleJoiner, + OuterSchema: v.OuterSchema, + Sctx: b.ctx, + CanUseCache: v.CanUseCache, + } + executor_metrics.ExecutorCounterNestedLoopApplyExec.Inc() + + // try parallel mode + if v.Concurrency > 1 { + innerExecs := make([]exec.Executor, 0, v.Concurrency) + innerFilters := make([]expression.CNFExprs, 0, v.Concurrency) + corCols := make([][]*expression.CorrelatedColumn, 0, v.Concurrency) + joiners := make([]join.Joiner, 0, v.Concurrency) + for i := 0; i < v.Concurrency; i++ { + clonedInnerPlan, err := plannercore.SafeClone(v.SCtx(), innerPlan) + if err != nil { + b.err = nil + return serialExec + } + corCol := coreusage.ExtractCorColumnsBySchema4PhysicalPlan(clonedInnerPlan, outerPlan.Schema()) + clonedInnerExec := b.build(clonedInnerPlan) + if b.err != nil { + b.err = nil + return serialExec + } + innerExecs = append(innerExecs, clonedInnerExec) + corCols = append(corCols, corCol) + innerFilters = append(innerFilters, innerFilter.Clone()) + joiners = append(joiners, join.NewJoiner(b.ctx, v.JoinType, v.InnerChildIdx == 0, + defaultValues, otherConditions, exec.RetTypes(leftChild), exec.RetTypes(rightChild), nil, false)) + } + + allExecs := append([]exec.Executor{outerExec}, innerExecs...) 
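// Running Apply in parallel requires a private copy of the inner plan per
// worker: correlated columns are mutable bindings, so sharing one inner
// executor across goroutines would race. The code above therefore clones the
// inner plan once per worker and quietly falls back to the serial executor
// when any clone or build fails. A minimal sketch of that per-worker cloning
// with fallback (simplified stand-ins, not the planner's SafeClone):

package sketch

import "errors"

// worker owns a private correlated binding so workers never share state.
type worker struct {
    binding *int
}

// buildWorkers returns one worker per degree of concurrency; any error is a
// signal for the caller to keep using the serial execution path instead.
func buildWorkers(concurrency int, clone func() (*int, error)) ([]worker, error) {
    if concurrency < 1 {
        return nil, errors.New("concurrency must be at least 1")
    }
    ws := make([]worker, 0, concurrency)
    for i := 0; i < concurrency; i++ {
        b, err := clone()
        if err != nil {
            return nil, err // fall back to the serial executor
        }
        ws = append(ws, worker{binding: b})
    }
    return ws, nil
}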
+ + return &ParallelNestedLoopApplyExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), allExecs...), + innerExecs: innerExecs, + outerExec: outerExec, + outerFilter: outerFilter, + innerFilter: innerFilters, + outer: v.JoinType != plannercore.InnerJoin, + joiners: joiners, + corCols: corCols, + concurrency: v.Concurrency, + useCache: v.CanUseCache, + } + } + return serialExec +} + +func (b *executorBuilder) buildMaxOneRow(v *plannercore.PhysicalMaxOneRow) exec.Executor { + childExec := b.build(v.Children()[0]) + if b.err != nil { + return nil + } + base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), childExec) + base.SetInitCap(2) + base.SetMaxChunkSize(2) + e := &MaxOneRowExec{BaseExecutor: base} + return e +} + +func (b *executorBuilder) buildUnionAll(v *plannercore.PhysicalUnionAll) exec.Executor { + childExecs := make([]exec.Executor, len(v.Children())) + for i, child := range v.Children() { + childExecs[i] = b.build(child) + if b.err != nil { + return nil + } + } + e := &unionexec.UnionExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), childExecs...), + Concurrency: b.ctx.GetSessionVars().UnionConcurrency(), + } + return e +} + +func buildHandleColsForSplit(sc *stmtctx.StatementContext, tbInfo *model.TableInfo) plannerutil.HandleCols { + if tbInfo.IsCommonHandle { + primaryIdx := tables.FindPrimaryIndex(tbInfo) + tableCols := make([]*expression.Column, len(tbInfo.Columns)) + for i, col := range tbInfo.Columns { + tableCols[i] = &expression.Column{ + ID: col.ID, + RetType: &col.FieldType, + } + } + for i, pkCol := range primaryIdx.Columns { + tableCols[pkCol.Offset].Index = i + } + return plannerutil.NewCommonHandleCols(sc, tbInfo, primaryIdx, tableCols) + } + intCol := &expression.Column{ + RetType: types.NewFieldType(mysql.TypeLonglong), + } + return plannerutil.NewIntHandleCols(intCol) +} + +func (b *executorBuilder) buildSplitRegion(v *plannercore.SplitRegion) exec.Executor { + base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()) + base.SetInitCap(1) + base.SetMaxChunkSize(1) + if v.IndexInfo != nil { + return &SplitIndexRegionExec{ + BaseExecutor: base, + tableInfo: v.TableInfo, + partitionNames: v.PartitionNames, + indexInfo: v.IndexInfo, + lower: v.Lower, + upper: v.Upper, + num: v.Num, + valueLists: v.ValueLists, + } + } + handleCols := buildHandleColsForSplit(b.ctx.GetSessionVars().StmtCtx, v.TableInfo) + if len(v.ValueLists) > 0 { + return &SplitTableRegionExec{ + BaseExecutor: base, + tableInfo: v.TableInfo, + partitionNames: v.PartitionNames, + handleCols: handleCols, + valueLists: v.ValueLists, + } + } + return &SplitTableRegionExec{ + BaseExecutor: base, + tableInfo: v.TableInfo, + partitionNames: v.PartitionNames, + handleCols: handleCols, + lower: v.Lower, + upper: v.Upper, + num: v.Num, + } +} + +func (b *executorBuilder) buildUpdate(v *plannercore.Update) exec.Executor { + b.inUpdateStmt = true + tblID2table := make(map[int64]table.Table, len(v.TblColPosInfos)) + multiUpdateOnSameTable := make(map[int64]bool) + for _, info := range v.TblColPosInfos { + tbl, _ := b.is.TableByID(info.TblID) + if _, ok := tblID2table[info.TblID]; ok { + multiUpdateOnSameTable[info.TblID] = true + } + tblID2table[info.TblID] = tbl + if len(v.PartitionedTable) > 0 { + // The v.PartitionedTable collects the partitioned table. + // Replace the original table with the partitioned table to support partition selection. + // e.g. 
update t partition (p0, p1), the new values are not belong to the given set p0, p1 + // Using the table in v.PartitionedTable returns a proper error, while using the original table can't. + for _, p := range v.PartitionedTable { + if info.TblID == p.Meta().ID { + tblID2table[info.TblID] = p + } + } + } + } + if b.err = b.updateForUpdateTS(); b.err != nil { + return nil + } + + selExec := b.build(v.SelectPlan) + if b.err != nil { + return nil + } + base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), selExec) + base.SetInitCap(chunk.ZeroCapacity) + var assignFlag []int + assignFlag, b.err = getAssignFlag(b.ctx, v, selExec.Schema().Len()) + if b.err != nil { + return nil + } + // should use the new tblID2table, since the update's schema may have been changed in Execstmt. + b.err = plannercore.CheckUpdateList(assignFlag, v, tblID2table) + if b.err != nil { + return nil + } + updateExec := &UpdateExec{ + BaseExecutor: base, + OrderedList: v.OrderedList, + allAssignmentsAreConstant: v.AllAssignmentsAreConstant, + virtualAssignmentsOffset: v.VirtualAssignmentsOffset, + multiUpdateOnSameTable: multiUpdateOnSameTable, + tblID2table: tblID2table, + tblColPosInfos: v.TblColPosInfos, + assignFlag: assignFlag, + } + updateExec.fkChecks, b.err = buildTblID2FKCheckExecs(b.ctx, tblID2table, v.FKChecks) + if b.err != nil { + return nil + } + updateExec.fkCascades, b.err = b.buildTblID2FKCascadeExecs(tblID2table, v.FKCascades) + if b.err != nil { + return nil + } + return updateExec +} + +func getAssignFlag(ctx sessionctx.Context, v *plannercore.Update, schemaLen int) ([]int, error) { + assignFlag := make([]int, schemaLen) + for i := range assignFlag { + assignFlag[i] = -1 + } + for _, assign := range v.OrderedList { + if !ctx.GetSessionVars().AllowWriteRowID && assign.Col.ID == model.ExtraHandleID { + return nil, errors.Errorf("insert, update and replace statements for _tidb_rowid are not supported") + } + tblIdx, found := v.TblColPosInfos.FindTblIdx(assign.Col.Index) + if found { + colIdx := assign.Col.Index + assignFlag[colIdx] = tblIdx + } + } + return assignFlag, nil +} + +func (b *executorBuilder) buildDelete(v *plannercore.Delete) exec.Executor { + b.inDeleteStmt = true + tblID2table := make(map[int64]table.Table, len(v.TblColPosInfos)) + for _, info := range v.TblColPosInfos { + tblID2table[info.TblID], _ = b.is.TableByID(info.TblID) + } + + if b.err = b.updateForUpdateTS(); b.err != nil { + return nil + } + + selExec := b.build(v.SelectPlan) + if b.err != nil { + return nil + } + base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), selExec) + base.SetInitCap(chunk.ZeroCapacity) + deleteExec := &DeleteExec{ + BaseExecutor: base, + tblID2Table: tblID2table, + IsMultiTable: v.IsMultiTable, + tblColPosInfos: v.TblColPosInfos, + } + deleteExec.fkChecks, b.err = buildTblID2FKCheckExecs(b.ctx, tblID2table, v.FKChecks) + if b.err != nil { + return nil + } + deleteExec.fkCascades, b.err = b.buildTblID2FKCascadeExecs(tblID2table, v.FKCascades) + if b.err != nil { + return nil + } + return deleteExec +} + +func (b *executorBuilder) updateForUpdateTS() error { + // GetStmtForUpdateTS will auto update the for update ts if it is necessary + _, err := sessiontxn.GetTxnManager(b.ctx).GetStmtForUpdateTS() + return err +} + +func (b *executorBuilder) buildAnalyzeIndexPushdown(task plannercore.AnalyzeIndexTask, opts map[ast.AnalyzeOptionType]uint64, autoAnalyze string) *analyzeTask { + job := &statistics.AnalyzeJob{DBName: task.DBName, TableName: task.TableName, PartitionName: task.PartitionName, JobInfo: 
autoAnalyze + "analyze index " + task.IndexInfo.Name.O} + _, offset := timeutil.Zone(b.ctx.GetSessionVars().Location()) + sc := b.ctx.GetSessionVars().StmtCtx + startTS, err := b.getSnapshotTS() + if err != nil { + b.err = err + return nil + } + failpoint.Inject("injectAnalyzeSnapshot", func(val failpoint.Value) { + startTS = uint64(val.(int)) + }) + concurrency := adaptiveAnlayzeDistSQLConcurrency(context.Background(), b.ctx) + base := baseAnalyzeExec{ + ctx: b.ctx, + tableID: task.TableID, + concurrency: concurrency, + analyzePB: &tipb.AnalyzeReq{ + Tp: tipb.AnalyzeType_TypeIndex, + Flags: sc.PushDownFlags(), + TimeZoneOffset: offset, + }, + opts: opts, + job: job, + snapshot: startTS, + } + e := &AnalyzeIndexExec{ + baseAnalyzeExec: base, + isCommonHandle: task.TblInfo.IsCommonHandle, + idxInfo: task.IndexInfo, + } + topNSize := new(int32) + *topNSize = int32(opts[ast.AnalyzeOptNumTopN]) + statsVersion := new(int32) + *statsVersion = int32(task.StatsVersion) + e.analyzePB.IdxReq = &tipb.AnalyzeIndexReq{ + BucketSize: int64(opts[ast.AnalyzeOptNumBuckets]), + NumColumns: int32(len(task.IndexInfo.Columns)), + TopNSize: topNSize, + Version: statsVersion, + SketchSize: statistics.MaxSketchSize, + } + if e.isCommonHandle && e.idxInfo.Primary { + e.analyzePB.Tp = tipb.AnalyzeType_TypeCommonHandle + } + depth := int32(opts[ast.AnalyzeOptCMSketchDepth]) + width := int32(opts[ast.AnalyzeOptCMSketchWidth]) + e.analyzePB.IdxReq.CmsketchDepth = &depth + e.analyzePB.IdxReq.CmsketchWidth = &width + return &analyzeTask{taskType: idxTask, idxExec: e, job: job} +} + +func (b *executorBuilder) buildAnalyzeSamplingPushdown( + task plannercore.AnalyzeColumnsTask, + opts map[ast.AnalyzeOptionType]uint64, + schemaForVirtualColEval *expression.Schema, +) *analyzeTask { + if task.V2Options != nil { + opts = task.V2Options.FilledOpts + } + availableIdx := make([]*model.IndexInfo, 0, len(task.Indexes)) + colGroups := make([]*tipb.AnalyzeColumnGroup, 0, len(task.Indexes)) + if len(task.Indexes) > 0 { + for _, idx := range task.Indexes { + availableIdx = append(availableIdx, idx) + colGroup := &tipb.AnalyzeColumnGroup{ + ColumnOffsets: make([]int64, 0, len(idx.Columns)), + } + for _, col := range idx.Columns { + colGroup.ColumnOffsets = append(colGroup.ColumnOffsets, int64(col.Offset)) + } + colGroups = append(colGroups, colGroup) + } + } + + _, offset := timeutil.Zone(b.ctx.GetSessionVars().Location()) + sc := b.ctx.GetSessionVars().StmtCtx + startTS, err := b.getSnapshotTS() + if err != nil { + b.err = err + return nil + } + failpoint.Inject("injectAnalyzeSnapshot", func(val failpoint.Value) { + startTS = uint64(val.(int)) + }) + statsHandle := domain.GetDomain(b.ctx).StatsHandle() + count, modifyCount, err := statsHandle.StatsMetaCountAndModifyCount(task.TableID.GetStatisticsID()) + if err != nil { + b.err = err + return nil + } + failpoint.Inject("injectBaseCount", func(val failpoint.Value) { + count = int64(val.(int)) + }) + failpoint.Inject("injectBaseModifyCount", func(val failpoint.Value) { + modifyCount = int64(val.(int)) + }) + sampleRate := new(float64) + var sampleRateReason string + if opts[ast.AnalyzeOptNumSamples] == 0 { + *sampleRate = math.Float64frombits(opts[ast.AnalyzeOptSampleRate]) + if *sampleRate < 0 { + *sampleRate, sampleRateReason = b.getAdjustedSampleRate(task) + if task.PartitionName != "" { + sc.AppendNote(errors.NewNoStackErrorf( + `Analyze use auto adjusted sample rate %f for table %s.%s's partition %s, reason to use this rate is "%s"`, + *sampleRate, + task.DBName, + task.TableName, 
+ task.PartitionName, + sampleRateReason, + )) + } else { + sc.AppendNote(errors.NewNoStackErrorf( + `Analyze use auto adjusted sample rate %f for table %s.%s, reason to use this rate is "%s"`, + *sampleRate, + task.DBName, + task.TableName, + sampleRateReason, + )) + } + } + } + job := &statistics.AnalyzeJob{ + DBName: task.DBName, + TableName: task.TableName, + PartitionName: task.PartitionName, + SampleRateReason: sampleRateReason, + } + concurrency := adaptiveAnlayzeDistSQLConcurrency(context.Background(), b.ctx) + base := baseAnalyzeExec{ + ctx: b.ctx, + tableID: task.TableID, + concurrency: concurrency, + analyzePB: &tipb.AnalyzeReq{ + Tp: tipb.AnalyzeType_TypeFullSampling, + Flags: sc.PushDownFlags(), + TimeZoneOffset: offset, + }, + opts: opts, + job: job, + snapshot: startTS, + } + e := &AnalyzeColumnsExec{ + baseAnalyzeExec: base, + tableInfo: task.TblInfo, + colsInfo: task.ColsInfo, + handleCols: task.HandleCols, + indexes: availableIdx, + AnalyzeInfo: task.AnalyzeInfo, + schemaForVirtualColEval: schemaForVirtualColEval, + baseCount: count, + baseModifyCnt: modifyCount, + } + e.analyzePB.ColReq = &tipb.AnalyzeColumnsReq{ + BucketSize: int64(opts[ast.AnalyzeOptNumBuckets]), + SampleSize: int64(opts[ast.AnalyzeOptNumSamples]), + SampleRate: sampleRate, + SketchSize: statistics.MaxSketchSize, + ColumnsInfo: util.ColumnsToProto(task.ColsInfo, task.TblInfo.PKIsHandle, false), + ColumnGroups: colGroups, + } + if task.TblInfo != nil { + e.analyzePB.ColReq.PrimaryColumnIds = tables.TryGetCommonPkColumnIds(task.TblInfo) + if task.TblInfo.IsCommonHandle { + e.analyzePB.ColReq.PrimaryPrefixColumnIds = tables.PrimaryPrefixColumnIDs(task.TblInfo) + } + } + b.err = tables.SetPBColumnsDefaultValue(b.ctx.GetExprCtx(), e.analyzePB.ColReq.ColumnsInfo, task.ColsInfo) + return &analyzeTask{taskType: colTask, colExec: e, job: job} +} + +// getAdjustedSampleRate calculate the sample rate by the table size. If we cannot get the table size. We use the 0.001 as the default sample rate. +// From the paper "Random sampling for histogram construction: how much is enough?"'s Corollary 1 to Theorem 5, +// for a table size n, histogram size k, maximum relative error in bin size f, and error probability gamma, +// the minimum random sample size is +// +// r = 4 * k * ln(2*n/gamma) / f^2 +// +// If we take f = 0.5, gamma = 0.01, n =1e6, we would got r = 305.82* k. +// Since the there's log function over the table size n, the r grows slowly when the n increases. +// If we take n = 1e12, a 300*k sample still gives <= 0.66 bin size error with probability 0.99. +// So if we don't consider the top-n values, we can keep the sample size at 300*256. +// But we may take some top-n before building the histogram, so we increase the sample a little. +func (b *executorBuilder) getAdjustedSampleRate(task plannercore.AnalyzeColumnsTask) (sampleRate float64, reason string) { + statsHandle := domain.GetDomain(b.ctx).StatsHandle() + defaultRate := 0.001 + if statsHandle == nil { + return defaultRate, fmt.Sprintf("statsHandler is nil, use the default-rate=%v", defaultRate) + } + var statsTbl *statistics.Table + tid := task.TableID.GetStatisticsID() + if tid == task.TblInfo.ID { + statsTbl = statsHandle.GetTableStats(task.TblInfo) + } else { + statsTbl = statsHandle.GetPartitionStats(task.TblInfo, tid) + } + approxiCount, hasPD := b.getApproximateTableCountFromStorage(tid, task) + // If there's no stats meta and no pd, return the default rate. 
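To make the sample-size bound quoted in the comment above concrete, here is a small standalone Go sketch (illustration only, not part of this patch; minSampleSize is a hypothetical helper): plugging k = 256 buckets, f = 0.5 and gamma = 0.01 into r = 4*k*ln(2n/gamma)/f^2 reproduces the ~305.82*k figure for n = 1e6, and shows that roughly 300*k samples still keep the error near 0.66 at n = 1e12.

package main

import (
	"fmt"
	"math"
)

// minSampleSize evaluates r = 4*k*ln(2*n/gamma) / f^2 from the referenced paper:
// the minimum random sample size for n rows, k histogram buckets, maximum
// relative bin-size error f and failure probability gamma.
func minSampleSize(n, k, f, gamma float64) float64 {
	return 4 * k * math.Log(2*n/gamma) / (f * f)
}

func main() {
	const k = 256
	r := minSampleSize(1e6, k, 0.5, 0.01)
	fmt.Printf("n=1e6:  r = %.0f rows = %.2f*k\n", r, r/k) // ~78290 rows, ~305.82*k
	r = minSampleSize(1e12, k, 0.66, 0.01)
	fmt.Printf("n=1e12: r = %.0f rows = %.2f*k\n", r, r/k) // still roughly 300*k for f=0.66
}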
+ if statsTbl == nil && !hasPD { + return defaultRate, fmt.Sprintf("TiDB cannot get the row count of the table, use the default-rate=%v", defaultRate) + } + // If the count in stats_meta is still 0 and there's no information from pd side, we scan all rows. + if statsTbl.RealtimeCount == 0 && !hasPD { + return 1, "TiDB assumes that the table is empty and cannot get row count from PD, use sample-rate=1" + } + // we have issue https://github.com/pingcap/tidb/issues/29216. + // To do a workaround for this issue, we check the approxiCount from the pd side to do a comparison. + // If the count from the stats_meta is extremely smaller than the approximate count from the pd, + // we think that we meet this issue and use the approximate count to calculate the sample rate. + if float64(statsTbl.RealtimeCount*5) < approxiCount { + // Confirmed by TiKV side, the experience error rate of the approximate count is about 20%. + // So we increase the number to 150000 to reduce this error rate. + sampleRate = math.Min(1, 150000/approxiCount) + return sampleRate, fmt.Sprintf("Row count in stats_meta is much smaller compared with the row count got by PD, use min(1, 15000/%v) as the sample-rate=%v", approxiCount, sampleRate) + } + // If we don't go into the above if branch and we still detect the count is zero. Return 1 to prevent the dividing zero. + if statsTbl.RealtimeCount == 0 { + return 1, "TiDB assumes that the table is empty, use sample-rate=1" + } + // We are expected to scan about 100000 rows or so. + // Since there's tiny error rate around the count from the stats meta, we use 110000 to get a little big result + sampleRate = math.Min(1, config.DefRowsForSampleRate/float64(statsTbl.RealtimeCount)) + return sampleRate, fmt.Sprintf("use min(1, %v/%v) as the sample-rate=%v", config.DefRowsForSampleRate, statsTbl.RealtimeCount, sampleRate) +} + +func (b *executorBuilder) getApproximateTableCountFromStorage(tid int64, task plannercore.AnalyzeColumnsTask) (float64, bool) { + return pdhelper.GlobalPDHelper.GetApproximateTableCountFromStorage(context.Background(), b.ctx, tid, task.DBName, task.TableName, task.PartitionName) +} + +func (b *executorBuilder) buildAnalyzeColumnsPushdown( + task plannercore.AnalyzeColumnsTask, + opts map[ast.AnalyzeOptionType]uint64, + autoAnalyze string, + schemaForVirtualColEval *expression.Schema, +) *analyzeTask { + if task.StatsVersion == statistics.Version2 { + return b.buildAnalyzeSamplingPushdown(task, opts, schemaForVirtualColEval) + } + job := &statistics.AnalyzeJob{DBName: task.DBName, TableName: task.TableName, PartitionName: task.PartitionName, JobInfo: autoAnalyze + "analyze columns"} + cols := task.ColsInfo + if hasPkHist(task.HandleCols) { + colInfo := task.TblInfo.Columns[task.HandleCols.GetCol(0).Index] + cols = append([]*model.ColumnInfo{colInfo}, cols...) + } else if task.HandleCols != nil && !task.HandleCols.IsInt() { + cols = make([]*model.ColumnInfo, 0, len(task.ColsInfo)+task.HandleCols.NumCols()) + for i := 0; i < task.HandleCols.NumCols(); i++ { + cols = append(cols, task.TblInfo.Columns[task.HandleCols.GetCol(i).Index]) + } + cols = append(cols, task.ColsInfo...) 
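Stepping back to getAdjustedSampleRate above, its branch structure can be summarized in a standalone sketch (hypothetical function and parameter names; the 0.001 default, the min(1, 150000/approxCount) workaround for issue #29216, and the ~110000-row target described in the comments are taken from the code above):

package main

import (
	"fmt"
	"math"
)

// adjustedSampleRate mirrors the decision flow above: statsCount is the row count
// from stats_meta, approxCount the PD-side estimate. Constants follow the comments
// in the patch; targetRows stands in for config.DefRowsForSampleRate.
func adjustedSampleRate(statsCount, approxCount float64, hasStats, hasPD bool) float64 {
	const defaultRate = 0.001
	const targetRows = 110000
	if !hasStats && !hasPD {
		return defaultRate // no information at all
	}
	if statsCount == 0 && !hasPD {
		return 1 // assume the table is empty and scan everything
	}
	if statsCount*5 < approxCount {
		// stats_meta lags far behind PD's estimate (issue #29216 workaround)
		return math.Min(1, 150000/approxCount)
	}
	if statsCount == 0 {
		return 1
	}
	return math.Min(1, targetRows/statsCount)
}

func main() {
	fmt.Println(adjustedSampleRate(0, -1, true, false))     // 1: treated as empty
	fmt.Println(adjustedSampleRate(1e7, 1.1e7, true, true)) // 0.011: ~110k rows sampled
	fmt.Println(adjustedSampleRate(1e4, 1e8, true, true))   // 0.0015: PD count wins
}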
+ task.ColsInfo = cols + } + + _, offset := timeutil.Zone(b.ctx.GetSessionVars().Location()) + sc := b.ctx.GetSessionVars().StmtCtx + startTS, err := b.getSnapshotTS() + if err != nil { + b.err = err + return nil + } + failpoint.Inject("injectAnalyzeSnapshot", func(val failpoint.Value) { + startTS = uint64(val.(int)) + }) + concurrency := adaptiveAnlayzeDistSQLConcurrency(context.Background(), b.ctx) + base := baseAnalyzeExec{ + ctx: b.ctx, + tableID: task.TableID, + concurrency: concurrency, + analyzePB: &tipb.AnalyzeReq{ + Tp: tipb.AnalyzeType_TypeColumn, + Flags: sc.PushDownFlags(), + TimeZoneOffset: offset, + }, + opts: opts, + job: job, + snapshot: startTS, + } + e := &AnalyzeColumnsExec{ + baseAnalyzeExec: base, + colsInfo: task.ColsInfo, + handleCols: task.HandleCols, + AnalyzeInfo: task.AnalyzeInfo, + } + depth := int32(opts[ast.AnalyzeOptCMSketchDepth]) + width := int32(opts[ast.AnalyzeOptCMSketchWidth]) + e.analyzePB.ColReq = &tipb.AnalyzeColumnsReq{ + BucketSize: int64(opts[ast.AnalyzeOptNumBuckets]), + SampleSize: MaxRegionSampleSize, + SketchSize: statistics.MaxSketchSize, + ColumnsInfo: util.ColumnsToProto(cols, task.HandleCols != nil && task.HandleCols.IsInt(), false), + CmsketchDepth: &depth, + CmsketchWidth: &width, + } + if task.TblInfo != nil { + e.analyzePB.ColReq.PrimaryColumnIds = tables.TryGetCommonPkColumnIds(task.TblInfo) + if task.TblInfo.IsCommonHandle { + e.analyzePB.ColReq.PrimaryPrefixColumnIds = tables.PrimaryPrefixColumnIDs(task.TblInfo) + } + } + if task.CommonHandleInfo != nil { + topNSize := new(int32) + *topNSize = int32(opts[ast.AnalyzeOptNumTopN]) + statsVersion := new(int32) + *statsVersion = int32(task.StatsVersion) + e.analyzePB.IdxReq = &tipb.AnalyzeIndexReq{ + BucketSize: int64(opts[ast.AnalyzeOptNumBuckets]), + NumColumns: int32(len(task.CommonHandleInfo.Columns)), + TopNSize: topNSize, + Version: statsVersion, + } + depth := int32(opts[ast.AnalyzeOptCMSketchDepth]) + width := int32(opts[ast.AnalyzeOptCMSketchWidth]) + e.analyzePB.IdxReq.CmsketchDepth = &depth + e.analyzePB.IdxReq.CmsketchWidth = &width + e.analyzePB.IdxReq.SketchSize = statistics.MaxSketchSize + e.analyzePB.ColReq.PrimaryColumnIds = tables.TryGetCommonPkColumnIds(task.TblInfo) + e.analyzePB.Tp = tipb.AnalyzeType_TypeMixed + e.commonHandle = task.CommonHandleInfo + } + b.err = tables.SetPBColumnsDefaultValue(b.ctx.GetExprCtx(), e.analyzePB.ColReq.ColumnsInfo, cols) + return &analyzeTask{taskType: colTask, colExec: e, job: job} +} + +func (b *executorBuilder) buildAnalyze(v *plannercore.Analyze) exec.Executor { + gp := domain.GetDomain(b.ctx).StatsHandle().GPool() + e := &AnalyzeExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + tasks: make([]*analyzeTask, 0, len(v.ColTasks)+len(v.IdxTasks)), + opts: v.Opts, + OptionsMap: v.OptionsMap, + wg: util.NewWaitGroupPool(gp), + gp: gp, + errExitCh: make(chan struct{}), + } + autoAnalyze := "" + if b.ctx.GetSessionVars().InRestrictedSQL { + autoAnalyze = "auto " + } + exprCtx := b.ctx.GetExprCtx() + for _, task := range v.ColTasks { + columns, _, err := expression.ColumnInfos2ColumnsAndNames( + exprCtx, + model.NewCIStr(task.AnalyzeInfo.DBName), + task.TblInfo.Name, + task.ColsInfo, + task.TblInfo, + ) + if err != nil { + b.err = err + return nil + } + schema := expression.NewSchema(columns...) + e.tasks = append(e.tasks, b.buildAnalyzeColumnsPushdown(task, v.Opts, autoAnalyze, schema)) + // Other functions may set b.err, so we need to check it here. 
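As the trailing comment notes, the build helpers above report failures through b.err rather than a return value, so the caller must re-check the builder after every step; a minimal sketch of that convention (hypothetical builder type, not TiDB code):

package main

import (
	"errors"
	"fmt"
)

type builder struct {
	err error
}

// buildStep records the first failure on the builder and returns a nil result,
// mirroring how the analyze build helpers set b.err instead of returning an error.
func (b *builder) buildStep(fail bool) *struct{} {
	if b.err != nil {
		return nil
	}
	if fail {
		b.err = errors.New("build failed")
		return nil
	}
	return &struct{}{}
}

func main() {
	b := &builder{}
	for _, fail := range []bool{false, true, false} {
		task := b.buildStep(fail)
		if b.err != nil { // other steps may set b.err, so check after each one
			fmt.Println("stop:", b.err)
			return
		}
		fmt.Println("built task:", task != nil)
	}
}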
+ if b.err != nil { + return nil + } + } + for _, task := range v.IdxTasks { + e.tasks = append(e.tasks, b.buildAnalyzeIndexPushdown(task, v.Opts, autoAnalyze)) + if b.err != nil { + return nil + } + } + return e +} + +// markChildrenUsedCols compares each child with the output schema, and mark +// each column of the child is used by output or not. +func markChildrenUsedCols(outputCols []*expression.Column, childSchemas ...*expression.Schema) (childrenUsed [][]int) { + childrenUsed = make([][]int, 0, len(childSchemas)) + markedOffsets := make(map[int]int) + // keep the original maybe reversed order. + for originalIdx, col := range outputCols { + markedOffsets[col.Index] = originalIdx + } + prefixLen := 0 + type intPair struct { + first int + second int + } + // for example here. + // left child schema: [col11] + // right child schema: [col21, col22] + // output schema is [col11, col22, col21], if not records the original derived order after physical resolve index. + // the lused will be [0], the rused will be [0,1], while the actual order is dismissed, [1,0] is correct for rused. + for _, childSchema := range childSchemas { + usedIdxPair := make([]intPair, 0, len(childSchema.Columns)) + for i := range childSchema.Columns { + if originalIdx, ok := markedOffsets[prefixLen+i]; ok { + usedIdxPair = append(usedIdxPair, intPair{first: originalIdx, second: i}) + } + } + // sort the used idxes according their original indexes derived after resolveIndex. + slices.SortFunc(usedIdxPair, func(a, b intPair) int { + return cmp.Compare(a.first, b.first) + }) + usedIdx := make([]int, 0, len(childSchema.Columns)) + for _, one := range usedIdxPair { + usedIdx = append(usedIdx, one.second) + } + childrenUsed = append(childrenUsed, usedIdx) + prefixLen += childSchema.Len() + } + return +} + +func (*executorBuilder) corColInDistPlan(plans []base.PhysicalPlan) bool { + for _, p := range plans { + switch x := p.(type) { + case *plannercore.PhysicalSelection: + for _, cond := range x.Conditions { + if len(expression.ExtractCorColumns(cond)) > 0 { + return true + } + } + case *plannercore.PhysicalProjection: + for _, expr := range x.Exprs { + if len(expression.ExtractCorColumns(expr)) > 0 { + return true + } + } + case *plannercore.PhysicalTopN: + for _, byItem := range x.ByItems { + if len(expression.ExtractCorColumns(byItem.Expr)) > 0 { + return true + } + } + case *plannercore.PhysicalTableScan: + for _, cond := range x.LateMaterializationFilterCondition { + if len(expression.ExtractCorColumns(cond)) > 0 { + return true + } + } + } + } + return false +} + +// corColInAccess checks whether there's correlated column in access conditions. 
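As a quick check of the worked example in the markChildrenUsedCols comment above, the same pair-and-sort trick can be reproduced with plain ints (hypothetical helper, not the planner code): for output schema [col11, col22, col21] over children [col11] and [col21, col22] it yields lused = [0] and rused = [1 0], preserving the output order rather than the child order.

package main

import (
	"cmp"
	"fmt"
	"slices"
)

// childrenUsed mirrors the pairing-and-sort step above: outputIdx[k] is the
// position, in the concatenated child schemas, of the k-th output column.
func childrenUsed(outputIdx []int, childLens ...int) [][]int {
	marked := make(map[int]int, len(outputIdx))
	for originalIdx, schemaPos := range outputIdx {
		marked[schemaPos] = originalIdx
	}
	res := make([][]int, 0, len(childLens))
	prefix := 0
	for _, l := range childLens {
		type pair struct{ orig, off int }
		pairs := make([]pair, 0, l)
		for i := 0; i < l; i++ {
			if orig, ok := marked[prefix+i]; ok {
				pairs = append(pairs, pair{orig, i})
			}
		}
		// order the used offsets by their position in the output schema
		slices.SortFunc(pairs, func(a, b pair) int { return cmp.Compare(a.orig, b.orig) })
		used := make([]int, 0, len(pairs))
		for _, p := range pairs {
			used = append(used, p.off)
		}
		res = append(res, used)
		prefix += l
	}
	return res
}

func main() {
	// output schema [col11, col22, col21] over children [col11] and [col21, col22]
	fmt.Println(childrenUsed([]int{0, 2, 1}, 1, 2)) // [[0] [1 0]]
}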
+func (*executorBuilder) corColInAccess(p base.PhysicalPlan) bool { + var access []expression.Expression + switch x := p.(type) { + case *plannercore.PhysicalTableScan: + access = x.AccessCondition + case *plannercore.PhysicalIndexScan: + access = x.AccessCondition + } + for _, cond := range access { + if len(expression.ExtractCorColumns(cond)) > 0 { + return true + } + } + return false +} + +func (b *executorBuilder) newDataReaderBuilder(p base.PhysicalPlan) (*dataReaderBuilder, error) { + ts, err := b.getSnapshotTS() + if err != nil { + return nil, err + } + + builderForDataReader := *b + builderForDataReader.forDataReaderBuilder = true + builderForDataReader.dataReaderTS = ts + + return &dataReaderBuilder{ + plan: p, + executorBuilder: &builderForDataReader, + }, nil +} + +func (b *executorBuilder) buildIndexLookUpJoin(v *plannercore.PhysicalIndexJoin) exec.Executor { + outerExec := b.build(v.Children()[1-v.InnerChildIdx]) + if b.err != nil { + return nil + } + outerTypes := exec.RetTypes(outerExec) + innerPlan := v.Children()[v.InnerChildIdx] + innerTypes := make([]*types.FieldType, innerPlan.Schema().Len()) + for i, col := range innerPlan.Schema().Columns { + innerTypes[i] = col.RetType.Clone() + // The `innerTypes` would be called for `Datum.ConvertTo` when converting the columns from outer table + // to build hash map or construct lookup keys. So we need to modify its flen otherwise there would be + // truncate error. See issue https://github.com/pingcap/tidb/issues/21232 for example. + if innerTypes[i].EvalType() == types.ETString { + innerTypes[i].SetFlen(types.UnspecifiedLength) + } + } + + // Use the probe table's collation. + for i, col := range v.OuterHashKeys { + outerTypes[col.Index] = outerTypes[col.Index].Clone() + outerTypes[col.Index].SetCollate(innerTypes[v.InnerHashKeys[i].Index].GetCollate()) + outerTypes[col.Index].SetFlag(col.RetType.GetFlag()) + } + + // We should use JoinKey to construct the type information using by hashing, instead of using the child's schema directly. + // When a hybrid type column is hashed multiple times, we need to distinguish what field types are used. + // For example, the condition `enum = int and enum = string`, we should use ETInt to hash the first column, + // and use ETString to hash the second column, although they may be the same column. 
+ innerHashTypes := make([]*types.FieldType, len(v.InnerHashKeys)) + outerHashTypes := make([]*types.FieldType, len(v.OuterHashKeys)) + for i, col := range v.InnerHashKeys { + innerHashTypes[i] = innerTypes[col.Index].Clone() + innerHashTypes[i].SetFlag(col.RetType.GetFlag()) + } + for i, col := range v.OuterHashKeys { + outerHashTypes[i] = outerTypes[col.Index].Clone() + outerHashTypes[i].SetFlag(col.RetType.GetFlag()) + } + + var ( + outerFilter []expression.Expression + leftTypes, rightTypes []*types.FieldType + ) + + if v.InnerChildIdx == 0 { + leftTypes, rightTypes = innerTypes, outerTypes + outerFilter = v.RightConditions + if len(v.LeftConditions) > 0 { + b.err = errors.Annotate(exeerrors.ErrBuildExecutor, "join's inner condition should be empty") + return nil + } + } else { + leftTypes, rightTypes = outerTypes, innerTypes + outerFilter = v.LeftConditions + if len(v.RightConditions) > 0 { + b.err = errors.Annotate(exeerrors.ErrBuildExecutor, "join's inner condition should be empty") + return nil + } + } + defaultValues := v.DefaultValues + if defaultValues == nil { + defaultValues = make([]types.Datum, len(innerTypes)) + } + hasPrefixCol := false + for _, l := range v.IdxColLens { + if l != types.UnspecifiedLength { + hasPrefixCol = true + break + } + } + + readerBuilder, err := b.newDataReaderBuilder(innerPlan) + if err != nil { + b.err = err + return nil + } + + e := &join.IndexLookUpJoin{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), outerExec), + OuterCtx: join.OuterCtx{ + RowTypes: outerTypes, + HashTypes: outerHashTypes, + Filter: outerFilter, + }, + InnerCtx: join.InnerCtx{ + ReaderBuilder: readerBuilder, + RowTypes: innerTypes, + HashTypes: innerHashTypes, + ColLens: v.IdxColLens, + HasPrefixCol: hasPrefixCol, + }, + WorkerWg: new(sync.WaitGroup), + IsOuterJoin: v.JoinType.IsOuterJoin(), + IndexRanges: v.Ranges, + KeyOff2IdxOff: v.KeyOff2IdxOff, + LastColHelper: v.CompareFilters, + Finished: &atomic.Value{}, + } + colsFromChildren := v.Schema().Columns + if v.JoinType == plannercore.LeftOuterSemiJoin || v.JoinType == plannercore.AntiLeftOuterSemiJoin { + colsFromChildren = colsFromChildren[:len(colsFromChildren)-1] + } + childrenUsedSchema := markChildrenUsedCols(colsFromChildren, v.Children()[0].Schema(), v.Children()[1].Schema()) + e.Joiner = join.NewJoiner(b.ctx, v.JoinType, v.InnerChildIdx == 0, defaultValues, v.OtherConditions, leftTypes, rightTypes, childrenUsedSchema, false) + outerKeyCols := make([]int, len(v.OuterJoinKeys)) + for i := 0; i < len(v.OuterJoinKeys); i++ { + outerKeyCols[i] = v.OuterJoinKeys[i].Index + } + innerKeyCols := make([]int, len(v.InnerJoinKeys)) + innerKeyColIDs := make([]int64, len(v.InnerJoinKeys)) + keyCollators := make([]collate.Collator, 0, len(v.InnerJoinKeys)) + for i := 0; i < len(v.InnerJoinKeys); i++ { + innerKeyCols[i] = v.InnerJoinKeys[i].Index + innerKeyColIDs[i] = v.InnerJoinKeys[i].ID + keyCollators = append(keyCollators, collate.GetCollator(v.InnerJoinKeys[i].RetType.GetCollate())) + } + e.OuterCtx.KeyCols = outerKeyCols + e.InnerCtx.KeyCols = innerKeyCols + e.InnerCtx.KeyColIDs = innerKeyColIDs + e.InnerCtx.KeyCollators = keyCollators + + outerHashCols, innerHashCols := make([]int, len(v.OuterHashKeys)), make([]int, len(v.InnerHashKeys)) + hashCollators := make([]collate.Collator, 0, len(v.InnerHashKeys)) + for i := 0; i < len(v.OuterHashKeys); i++ { + outerHashCols[i] = v.OuterHashKeys[i].Index + } + for i := 0; i < len(v.InnerHashKeys); i++ { + innerHashCols[i] = v.InnerHashKeys[i].Index + hashCollators = 
append(hashCollators, collate.GetCollator(v.InnerHashKeys[i].RetType.GetCollate())) + } + e.OuterCtx.HashCols = outerHashCols + e.InnerCtx.HashCols = innerHashCols + e.InnerCtx.HashCollators = hashCollators + + e.JoinResult = exec.TryNewCacheChunk(e) + executor_metrics.ExecutorCounterIndexLookUpJoin.Inc() + return e +} + +func (b *executorBuilder) buildIndexLookUpMergeJoin(v *plannercore.PhysicalIndexMergeJoin) exec.Executor { + outerExec := b.build(v.Children()[1-v.InnerChildIdx]) + if b.err != nil { + return nil + } + outerTypes := exec.RetTypes(outerExec) + innerPlan := v.Children()[v.InnerChildIdx] + innerTypes := make([]*types.FieldType, innerPlan.Schema().Len()) + for i, col := range innerPlan.Schema().Columns { + innerTypes[i] = col.RetType.Clone() + // The `innerTypes` would be called for `Datum.ConvertTo` when converting the columns from outer table + // to build hash map or construct lookup keys. So we need to modify its flen otherwise there would be + // truncate error. See issue https://github.com/pingcap/tidb/issues/21232 for example. + if innerTypes[i].EvalType() == types.ETString { + innerTypes[i].SetFlen(types.UnspecifiedLength) + } + } + var ( + outerFilter []expression.Expression + leftTypes, rightTypes []*types.FieldType + ) + if v.InnerChildIdx == 0 { + leftTypes, rightTypes = innerTypes, outerTypes + outerFilter = v.RightConditions + if len(v.LeftConditions) > 0 { + b.err = errors.Annotate(exeerrors.ErrBuildExecutor, "join's inner condition should be empty") + return nil + } + } else { + leftTypes, rightTypes = outerTypes, innerTypes + outerFilter = v.LeftConditions + if len(v.RightConditions) > 0 { + b.err = errors.Annotate(exeerrors.ErrBuildExecutor, "join's inner condition should be empty") + return nil + } + } + defaultValues := v.DefaultValues + if defaultValues == nil { + defaultValues = make([]types.Datum, len(innerTypes)) + } + outerKeyCols := make([]int, len(v.OuterJoinKeys)) + for i := 0; i < len(v.OuterJoinKeys); i++ { + outerKeyCols[i] = v.OuterJoinKeys[i].Index + } + innerKeyCols := make([]int, len(v.InnerJoinKeys)) + keyCollators := make([]collate.Collator, 0, len(v.InnerJoinKeys)) + for i := 0; i < len(v.InnerJoinKeys); i++ { + innerKeyCols[i] = v.InnerJoinKeys[i].Index + keyCollators = append(keyCollators, collate.GetCollator(v.InnerJoinKeys[i].RetType.GetCollate())) + } + executor_metrics.ExecutorCounterIndexLookUpJoin.Inc() + + readerBuilder, err := b.newDataReaderBuilder(innerPlan) + if err != nil { + b.err = err + return nil + } + + e := &join.IndexLookUpMergeJoin{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), outerExec), + OuterMergeCtx: join.OuterMergeCtx{ + RowTypes: outerTypes, + Filter: outerFilter, + JoinKeys: v.OuterJoinKeys, + KeyCols: outerKeyCols, + NeedOuterSort: v.NeedOuterSort, + CompareFuncs: v.OuterCompareFuncs, + }, + InnerMergeCtx: join.InnerMergeCtx{ + ReaderBuilder: readerBuilder, + RowTypes: innerTypes, + JoinKeys: v.InnerJoinKeys, + KeyCols: innerKeyCols, + KeyCollators: keyCollators, + CompareFuncs: v.CompareFuncs, + ColLens: v.IdxColLens, + Desc: v.Desc, + KeyOff2KeyOffOrderByIdx: v.KeyOff2KeyOffOrderByIdx, + }, + WorkerWg: new(sync.WaitGroup), + IsOuterJoin: v.JoinType.IsOuterJoin(), + IndexRanges: v.Ranges, + KeyOff2IdxOff: v.KeyOff2IdxOff, + LastColHelper: v.CompareFilters, + } + colsFromChildren := v.Schema().Columns + if v.JoinType == plannercore.LeftOuterSemiJoin || v.JoinType == plannercore.AntiLeftOuterSemiJoin { + colsFromChildren = colsFromChildren[:len(colsFromChildren)-1] + } + childrenUsedSchema 
:= markChildrenUsedCols(colsFromChildren, v.Children()[0].Schema(), v.Children()[1].Schema()) + joiners := make([]join.Joiner, e.Ctx().GetSessionVars().IndexLookupJoinConcurrency()) + for i := 0; i < len(joiners); i++ { + joiners[i] = join.NewJoiner(b.ctx, v.JoinType, v.InnerChildIdx == 0, defaultValues, v.OtherConditions, leftTypes, rightTypes, childrenUsedSchema, false) + } + e.Joiners = joiners + return e +} + +func (b *executorBuilder) buildIndexNestedLoopHashJoin(v *plannercore.PhysicalIndexHashJoin) exec.Executor { + joinExec := b.buildIndexLookUpJoin(&(v.PhysicalIndexJoin)) + if b.err != nil { + return nil + } + e := joinExec.(*join.IndexLookUpJoin) + idxHash := &join.IndexNestedLoopHashJoin{ + IndexLookUpJoin: *e, + KeepOuterOrder: v.KeepOuterOrder, + } + concurrency := e.Ctx().GetSessionVars().IndexLookupJoinConcurrency() + idxHash.Joiners = make([]join.Joiner, concurrency) + for i := 0; i < concurrency; i++ { + idxHash.Joiners[i] = e.Joiner.Clone() + } + return idxHash +} + +func buildNoRangeTableReader(b *executorBuilder, v *plannercore.PhysicalTableReader) (*TableReaderExecutor, error) { + tablePlans := v.TablePlans + if v.StoreType == kv.TiFlash { + tablePlans = []base.PhysicalPlan{v.GetTablePlan()} + } + dagReq, err := builder.ConstructDAGReq(b.ctx, tablePlans, v.StoreType) + if err != nil { + return nil, err + } + ts, err := v.GetTableScan() + if err != nil { + return nil, err + } + if err = b.validCanReadTemporaryOrCacheTable(ts.Table); err != nil { + return nil, err + } + + tbl, _ := b.is.TableByID(ts.Table.ID) + isPartition, physicalTableID := ts.IsPartition() + if isPartition { + pt := tbl.(table.PartitionedTable) + tbl = pt.GetPartition(physicalTableID) + } + startTS, err := b.getSnapshotTS() + if err != nil { + return nil, err + } + paging := b.ctx.GetSessionVars().EnablePaging + + e := &TableReaderExecutor{ + BaseExecutorV2: exec.NewBaseExecutorV2(b.ctx.GetSessionVars(), v.Schema(), v.ID()), + tableReaderExecutorContext: newTableReaderExecutorContext(b.ctx), + dagPB: dagReq, + startTS: startTS, + txnScope: b.txnScope, + readReplicaScope: b.readReplicaScope, + isStaleness: b.isStaleness, + netDataSize: v.GetNetDataSize(), + table: tbl, + keepOrder: ts.KeepOrder, + desc: ts.Desc, + byItems: ts.ByItems, + columns: ts.Columns, + paging: paging, + corColInFilter: b.corColInDistPlan(v.TablePlans), + corColInAccess: b.corColInAccess(v.TablePlans[0]), + plans: v.TablePlans, + tablePlan: v.GetTablePlan(), + storeType: v.StoreType, + batchCop: v.ReadReqType == plannercore.BatchCop, + } + e.buildVirtualColumnInfo() + + if v.StoreType == kv.TiDB && b.ctx.GetSessionVars().User != nil { + // User info is used to do privilege check. It is only used in TiDB cluster memory table. 
+ e.dagPB.User = &tipb.UserIdentity{ + UserName: b.ctx.GetSessionVars().User.Username, + UserHost: b.ctx.GetSessionVars().User.Hostname, + } + } + + for i := range v.Schema().Columns { + dagReq.OutputOffsets = append(dagReq.OutputOffsets, uint32(i)) + } + + return e, nil +} + +func (b *executorBuilder) buildMPPGather(v *plannercore.PhysicalTableReader) exec.Executor { + startTs, err := b.getSnapshotTS() + if err != nil { + b.err = err + return nil + } + + gather := &MPPGather{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + is: b.is, + originalPlan: v.GetTablePlan(), + startTS: startTs, + mppQueryID: kv.MPPQueryID{QueryTs: getMPPQueryTS(b.ctx), LocalQueryID: getMPPQueryID(b.ctx), ServerID: domain.GetDomain(b.ctx).ServerID()}, + memTracker: memory.NewTracker(v.ID(), -1), + + columns: []*model.ColumnInfo{}, + virtualColumnIndex: []int{}, + virtualColumnRetFieldTypes: []*types.FieldType{}, + } + + gather.memTracker.AttachTo(b.ctx.GetSessionVars().StmtCtx.MemTracker) + + var hasVirtualCol bool + for _, col := range v.Schema().Columns { + if col.VirtualExpr != nil { + hasVirtualCol = true + break + } + } + + var isSingleDataSource bool + tableScans := v.GetTableScans() + if len(tableScans) == 1 { + isSingleDataSource = true + } + + // 1. hasVirtualCol: when got virtual column in TableScan, will generate plan like the following, + // and there will be no other operators in the MPP fragment. + // MPPGather + // ExchangeSender + // PhysicalTableScan + // 2. UnionScan: there won't be any operators like Join between UnionScan and TableScan. + // and UnionScan cannot push down to tiflash. + if !isSingleDataSource { + if hasVirtualCol || b.encounterUnionScan { + b.err = errors.Errorf("should only have one TableScan in MPP fragment(hasVirtualCol: %v, encounterUnionScan: %v)", hasVirtualCol, b.encounterUnionScan) + return nil + } + return gather + } + + // Setup MPPGather.table if isSingleDataSource. + // Virtual Column or UnionScan need to use it. + ts := tableScans[0] + gather.columns = ts.Columns + if hasVirtualCol { + gather.virtualColumnIndex, gather.virtualColumnRetFieldTypes = buildVirtualColumnInfo(gather.Schema(), gather.columns) + } + tbl, _ := b.is.TableByID(ts.Table.ID) + isPartition, physicalTableID := ts.IsPartition() + if isPartition { + // Only for static pruning partition table. + pt := tbl.(table.PartitionedTable) + tbl = pt.GetPartition(physicalTableID) + } + gather.table = tbl + return gather +} + +// buildTableReader builds a table reader executor. It first build a no range table reader, +// and then update it ranges from table scan plan. 
+func (b *executorBuilder) buildTableReader(v *plannercore.PhysicalTableReader) exec.Executor { + failpoint.Inject("checkUseMPP", func(val failpoint.Value) { + if !b.ctx.GetSessionVars().InRestrictedSQL && val.(bool) != useMPPExecution(b.ctx, v) { + if val.(bool) { + b.err = errors.New("expect mpp but not used") + } else { + b.err = errors.New("don't expect mpp but we used it") + } + failpoint.Return(nil) + } + }) + // https://github.com/pingcap/tidb/issues/50358 + if len(v.Schema().Columns) == 0 && len(v.GetTablePlan().Schema().Columns) > 0 { + v.SetSchema(v.GetTablePlan().Schema()) + } + useMPP := useMPPExecution(b.ctx, v) + useTiFlashBatchCop := v.ReadReqType == plannercore.BatchCop + useTiFlash := useMPP || useTiFlashBatchCop + if useTiFlash { + if _, isTiDBZoneLabelSet := config.GetGlobalConfig().Labels[placement.DCLabelKey]; b.ctx.GetSessionVars().TiFlashReplicaRead != tiflash.AllReplicas && !isTiDBZoneLabelSet { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("the variable tiflash_replica_read is ignored, because the entry TiDB[%s] does not set the zone attribute and tiflash_replica_read is '%s'", config.GetGlobalConfig().AdvertiseAddress, tiflash.GetTiFlashReplicaRead(b.ctx.GetSessionVars().TiFlashReplicaRead))) + } + } + if useMPP { + return b.buildMPPGather(v) + } + ts, err := v.GetTableScan() + if err != nil { + b.err = err + return nil + } + ret, err := buildNoRangeTableReader(b, v) + if err != nil { + b.err = err + return nil + } + if err = b.validCanReadTemporaryOrCacheTable(ts.Table); err != nil { + b.err = err + return nil + } + + if ret.table.Meta().TempTableType != model.TempTableNone { + ret.dummy = true + } + + ret.ranges = ts.Ranges + sctx := b.ctx.GetSessionVars().StmtCtx + sctx.TableIDs = append(sctx.TableIDs, ts.Table.ID) + + if !b.ctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { + return ret + } + // When isPartition is set, it means the union rewriting is done, so a partition reader is preferred. + if ok, _ := ts.IsPartition(); ok { + return ret + } + + pi := ts.Table.GetPartitionInfo() + if pi == nil { + return ret + } + + tmp, _ := b.is.TableByID(ts.Table.ID) + tbl := tmp.(table.PartitionedTable) + partitions, err := partitionPruning(b.ctx, tbl, v.PlanPartInfo) + if err != nil { + b.err = err + return nil + } + if v.StoreType == kv.TiFlash { + sctx.IsTiFlash.Store(true) + } + + if len(partitions) == 0 { + return &TableDualExec{BaseExecutorV2: ret.BaseExecutorV2} + } + + // Sort the partition is necessary to make the final multiple partition key ranges ordered. 
+ slices.SortFunc(partitions, func(i, j table.PhysicalTable) int { + return cmp.Compare(i.GetPhysicalID(), j.GetPhysicalID()) + }) + ret.kvRangeBuilder = kvRangeBuilderFromRangeAndPartition{ + partitions: partitions, + } + + return ret +} + +func buildIndexRangeForEachPartition(rctx *rangerctx.RangerContext, usedPartitions []table.PhysicalTable, contentPos []int64, + lookUpContent []*join.IndexJoinLookUpContent, indexRanges []*ranger.Range, keyOff2IdxOff []int, cwc *plannercore.ColWithCmpFuncManager) (map[int64][]*ranger.Range, error) { + contentBucket := make(map[int64][]*join.IndexJoinLookUpContent) + for _, p := range usedPartitions { + contentBucket[p.GetPhysicalID()] = make([]*join.IndexJoinLookUpContent, 0, 8) + } + for i, pos := range contentPos { + if _, ok := contentBucket[pos]; ok { + contentBucket[pos] = append(contentBucket[pos], lookUpContent[i]) + } + } + nextRange := make(map[int64][]*ranger.Range) + for _, p := range usedPartitions { + ranges, err := buildRangesForIndexJoin(rctx, contentBucket[p.GetPhysicalID()], indexRanges, keyOff2IdxOff, cwc) + if err != nil { + return nil, err + } + nextRange[p.GetPhysicalID()] = ranges + } + return nextRange, nil +} + +func getPartitionKeyColOffsets(keyColIDs []int64, pt table.PartitionedTable) []int { + keyColOffsets := make([]int, len(keyColIDs)) + for i, colID := range keyColIDs { + offset := -1 + for j, col := range pt.Cols() { + if colID == col.ID { + offset = j + break + } + } + if offset == -1 { + return nil + } + keyColOffsets[i] = offset + } + + t, ok := pt.(interface { + PartitionExpr() *tables.PartitionExpr + }) + if !ok { + return nil + } + pe := t.PartitionExpr() + if pe == nil { + return nil + } + + offsetMap := make(map[int]struct{}) + for _, offset := range keyColOffsets { + offsetMap[offset] = struct{}{} + } + for _, offset := range pe.ColumnOffset { + if _, ok := offsetMap[offset]; !ok { + return nil + } + } + return keyColOffsets +} + +func (builder *dataReaderBuilder) prunePartitionForInnerExecutor(tbl table.Table, physPlanPartInfo *plannercore.PhysPlanPartInfo, + lookUpContent []*join.IndexJoinLookUpContent) (usedPartition []table.PhysicalTable, canPrune bool, contentPos []int64, err error) { + partitionTbl := tbl.(table.PartitionedTable) + + // In index join, this is called by multiple goroutines simultaneously, but partitionPruning is not thread-safe. + // Use once.Do to avoid DATA RACE here. + // TODO: condition based pruning can be do in advance. 
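The matching rule in getPartitionKeyColOffsets above, i.e. every join key column must exist in the table and every column referenced by the partition expression must be covered by the join keys or pruning is abandoned, can be sketched in isolation as follows (hypothetical types and IDs; the real code walks pt.Cols() and PartitionExpr().ColumnOffset):

package main

import "fmt"

// keyColOffsets maps join-key column IDs to their offsets in tableColIDs and
// returns nil if any key column is missing or if the partition expression uses
// a column that the join keys do not cover.
func keyColOffsets(keyColIDs, tableColIDs []int64, partExprOffsets []int) []int {
	offsets := make([]int, len(keyColIDs))
	for i, id := range keyColIDs {
		offsets[i] = -1
		for j, colID := range tableColIDs {
			if id == colID {
				offsets[i] = j
				break
			}
		}
		if offsets[i] == -1 {
			return nil
		}
	}
	covered := make(map[int]struct{}, len(offsets))
	for _, off := range offsets {
		covered[off] = struct{}{}
	}
	for _, off := range partExprOffsets {
		if _, ok := covered[off]; !ok {
			return nil // partition expression needs a column the join keys don't supply
		}
	}
	return offsets
}

func main() {
	cols := []int64{101, 102, 103}                                // table column IDs by offset
	fmt.Println(keyColOffsets([]int64{102, 101}, cols, []int{1})) // [1 0]: partition column covered
	fmt.Println(keyColOffsets([]int64{101}, cols, []int{2}))      // []: partition column at offset 2 not in keys
}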
+ condPruneResult, err := builder.partitionPruning(partitionTbl, physPlanPartInfo) + if err != nil { + return nil, false, nil, err + } + + // recalculate key column offsets + if len(lookUpContent) == 0 { + return nil, false, nil, nil + } + if lookUpContent[0].KeyColIDs == nil { + return nil, false, nil, plannererrors.ErrInternal.GenWithStack("cannot get column IDs when dynamic pruning") + } + keyColOffsets := getPartitionKeyColOffsets(lookUpContent[0].KeyColIDs, partitionTbl) + if len(keyColOffsets) == 0 { + return condPruneResult, false, nil, nil + } + + locateKey := make([]types.Datum, len(partitionTbl.Cols())) + partitions := make(map[int64]table.PhysicalTable) + contentPos = make([]int64, len(lookUpContent)) + exprCtx := builder.ctx.GetExprCtx() + for idx, content := range lookUpContent { + for i, data := range content.Keys { + locateKey[keyColOffsets[i]] = data + } + p, err := partitionTbl.GetPartitionByRow(exprCtx.GetEvalCtx(), locateKey) + if table.ErrNoPartitionForGivenValue.Equal(err) { + continue + } + if err != nil { + return nil, false, nil, err + } + if _, ok := partitions[p.GetPhysicalID()]; !ok { + partitions[p.GetPhysicalID()] = p + } + contentPos[idx] = p.GetPhysicalID() + } + + usedPartition = make([]table.PhysicalTable, 0, len(partitions)) + for _, p := range condPruneResult { + if _, ok := partitions[p.GetPhysicalID()]; ok { + usedPartition = append(usedPartition, p) + } + } + + // To make the final key ranges involving multiple partitions ordered. + slices.SortFunc(usedPartition, func(i, j table.PhysicalTable) int { + return cmp.Compare(i.GetPhysicalID(), j.GetPhysicalID()) + }) + return usedPartition, true, contentPos, nil +} + +func buildNoRangeIndexReader(b *executorBuilder, v *plannercore.PhysicalIndexReader) (*IndexReaderExecutor, error) { + dagReq, err := builder.ConstructDAGReq(b.ctx, v.IndexPlans, kv.TiKV) + if err != nil { + return nil, err + } + is := v.IndexPlans[0].(*plannercore.PhysicalIndexScan) + tbl, _ := b.is.TableByID(is.Table.ID) + isPartition, physicalTableID := is.IsPartition() + if isPartition { + pt := tbl.(table.PartitionedTable) + tbl = pt.GetPartition(physicalTableID) + } else { + physicalTableID = is.Table.ID + } + startTS, err := b.getSnapshotTS() + if err != nil { + return nil, err + } + paging := b.ctx.GetSessionVars().EnablePaging + + e := &IndexReaderExecutor{ + indexReaderExecutorContext: newIndexReaderExecutorContext(b.ctx), + BaseExecutorV2: exec.NewBaseExecutorV2(b.ctx.GetSessionVars(), v.Schema(), v.ID()), + indexUsageReporter: b.buildIndexUsageReporter(v), + dagPB: dagReq, + startTS: startTS, + txnScope: b.txnScope, + readReplicaScope: b.readReplicaScope, + isStaleness: b.isStaleness, + netDataSize: v.GetNetDataSize(), + physicalTableID: physicalTableID, + table: tbl, + index: is.Index, + keepOrder: is.KeepOrder, + desc: is.Desc, + columns: is.Columns, + byItems: is.ByItems, + paging: paging, + corColInFilter: b.corColInDistPlan(v.IndexPlans), + corColInAccess: b.corColInAccess(v.IndexPlans[0]), + idxCols: is.IdxCols, + colLens: is.IdxColLens, + plans: v.IndexPlans, + outputColumns: v.OutputColumns, + } + + for _, col := range v.OutputColumns { + dagReq.OutputOffsets = append(dagReq.OutputOffsets, uint32(col.Index)) + } + + return e, nil +} + +func (b *executorBuilder) buildIndexReader(v *plannercore.PhysicalIndexReader) exec.Executor { + is := v.IndexPlans[0].(*plannercore.PhysicalIndexScan) + if err := b.validCanReadTemporaryOrCacheTable(is.Table); err != nil { + b.err = err + return nil + } + + ret, err := 
buildNoRangeIndexReader(b, v) + if err != nil { + b.err = err + return nil + } + + if ret.table.Meta().TempTableType != model.TempTableNone { + ret.dummy = true + } + + ret.ranges = is.Ranges + sctx := b.ctx.GetSessionVars().StmtCtx + sctx.IndexNames = append(sctx.IndexNames, is.Table.Name.O+":"+is.Index.Name.O) + + if !b.ctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { + return ret + } + // When isPartition is set, it means the union rewriting is done, so a partition reader is preferred. + if ok, _ := is.IsPartition(); ok { + return ret + } + + pi := is.Table.GetPartitionInfo() + if pi == nil { + return ret + } + + if is.Index.Global { + ret.partitionIDMap, err = getPartitionIDsAfterPruning(b.ctx, ret.table.(table.PartitionedTable), v.PlanPartInfo) + if err != nil { + b.err = err + return nil + } + return ret + } + + tmp, _ := b.is.TableByID(is.Table.ID) + tbl := tmp.(table.PartitionedTable) + partitions, err := partitionPruning(b.ctx, tbl, v.PlanPartInfo) + if err != nil { + b.err = err + return nil + } + ret.partitions = partitions + return ret +} + +func buildTableReq(b *executorBuilder, schemaLen int, plans []base.PhysicalPlan) (dagReq *tipb.DAGRequest, val table.Table, err error) { + tableReq, err := builder.ConstructDAGReq(b.ctx, plans, kv.TiKV) + if err != nil { + return nil, nil, err + } + for i := 0; i < schemaLen; i++ { + tableReq.OutputOffsets = append(tableReq.OutputOffsets, uint32(i)) + } + ts := plans[0].(*plannercore.PhysicalTableScan) + tbl, _ := b.is.TableByID(ts.Table.ID) + isPartition, physicalTableID := ts.IsPartition() + if isPartition { + pt := tbl.(table.PartitionedTable) + tbl = pt.GetPartition(physicalTableID) + } + return tableReq, tbl, err +} + +// buildIndexReq is designed to create a DAG for index request. +// If len(ByItems) != 0 means index request should return related columns +// to sort result rows in TiDB side for partition tables. 
+func buildIndexReq(ctx sessionctx.Context, columns []*model.IndexColumn, handleLen int, plans []base.PhysicalPlan) (dagReq *tipb.DAGRequest, err error) { + indexReq, err := builder.ConstructDAGReq(ctx, plans, kv.TiKV) + if err != nil { + return nil, err + } + + indexReq.OutputOffsets = []uint32{} + idxScan := plans[0].(*plannercore.PhysicalIndexScan) + if len(idxScan.ByItems) != 0 { + schema := idxScan.Schema() + for _, item := range idxScan.ByItems { + c, ok := item.Expr.(*expression.Column) + if !ok { + return nil, errors.Errorf("Not support non-column in orderBy pushed down") + } + find := false + for i, schemaColumn := range schema.Columns { + if schemaColumn.ID == c.ID { + indexReq.OutputOffsets = append(indexReq.OutputOffsets, uint32(i)) + find = true + break + } + } + if !find { + return nil, errors.Errorf("Not found order by related columns in indexScan.schema") + } + } + } + + for i := 0; i < handleLen; i++ { + indexReq.OutputOffsets = append(indexReq.OutputOffsets, uint32(len(columns)+i)) + } + + if idxScan.NeedExtraOutputCol() { + // need add one more column for pid or physical table id + indexReq.OutputOffsets = append(indexReq.OutputOffsets, uint32(len(columns)+handleLen)) + } + return indexReq, err +} + +func buildNoRangeIndexLookUpReader(b *executorBuilder, v *plannercore.PhysicalIndexLookUpReader) (*IndexLookUpExecutor, error) { + is := v.IndexPlans[0].(*plannercore.PhysicalIndexScan) + var handleLen int + if len(v.CommonHandleCols) != 0 { + handleLen = len(v.CommonHandleCols) + } else { + handleLen = 1 + } + indexReq, err := buildIndexReq(b.ctx, is.Index.Columns, handleLen, v.IndexPlans) + if err != nil { + return nil, err + } + indexPaging := false + if v.Paging { + indexPaging = true + } + tableReq, tbl, err := buildTableReq(b, v.Schema().Len(), v.TablePlans) + if err != nil { + return nil, err + } + ts := v.TablePlans[0].(*plannercore.PhysicalTableScan) + startTS, err := b.getSnapshotTS() + if err != nil { + return nil, err + } + + readerBuilder, err := b.newDataReaderBuilder(nil) + if err != nil { + return nil, err + } + + e := &IndexLookUpExecutor{ + indexLookUpExecutorContext: newIndexLookUpExecutorContext(b.ctx), + BaseExecutorV2: exec.NewBaseExecutorV2(b.ctx.GetSessionVars(), v.Schema(), v.ID()), + indexUsageReporter: b.buildIndexUsageReporter(v), + dagPB: indexReq, + startTS: startTS, + table: tbl, + index: is.Index, + keepOrder: is.KeepOrder, + byItems: is.ByItems, + desc: is.Desc, + tableRequest: tableReq, + columns: ts.Columns, + indexPaging: indexPaging, + dataReaderBuilder: readerBuilder, + corColInIdxSide: b.corColInDistPlan(v.IndexPlans), + corColInTblSide: b.corColInDistPlan(v.TablePlans), + corColInAccess: b.corColInAccess(v.IndexPlans[0]), + idxCols: is.IdxCols, + colLens: is.IdxColLens, + idxPlans: v.IndexPlans, + tblPlans: v.TablePlans, + PushedLimit: v.PushedLimit, + idxNetDataSize: v.GetAvgTableRowSize(), + avgRowSize: v.GetAvgTableRowSize(), + } + + if v.ExtraHandleCol != nil { + e.handleIdx = append(e.handleIdx, v.ExtraHandleCol.Index) + e.handleCols = []*expression.Column{v.ExtraHandleCol} + } else { + for _, handleCol := range v.CommonHandleCols { + e.handleIdx = append(e.handleIdx, handleCol.Index) + } + e.handleCols = v.CommonHandleCols + e.primaryKeyIndex = tables.FindPrimaryIndex(tbl.Meta()) + } + return e, nil +} + +func (b *executorBuilder) buildIndexLookUpReader(v *plannercore.PhysicalIndexLookUpReader) exec.Executor { + is := v.IndexPlans[0].(*plannercore.PhysicalIndexScan) + if err := b.validCanReadTemporaryOrCacheTable(is.Table); err 
!= nil { + b.err = err + return nil + } + + ret, err := buildNoRangeIndexLookUpReader(b, v) + if err != nil { + b.err = err + return nil + } + + if ret.table.Meta().TempTableType != model.TempTableNone { + ret.dummy = true + } + + ts := v.TablePlans[0].(*plannercore.PhysicalTableScan) + + ret.ranges = is.Ranges + executor_metrics.ExecutorCounterIndexLookUpExecutor.Inc() + + sctx := b.ctx.GetSessionVars().StmtCtx + sctx.IndexNames = append(sctx.IndexNames, is.Table.Name.O+":"+is.Index.Name.O) + sctx.TableIDs = append(sctx.TableIDs, ts.Table.ID) + + if !b.ctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { + return ret + } + + if pi := is.Table.GetPartitionInfo(); pi == nil { + return ret + } + + if is.Index.Global { + ret.partitionIDMap, err = getPartitionIDsAfterPruning(b.ctx, ret.table.(table.PartitionedTable), v.PlanPartInfo) + if err != nil { + b.err = err + return nil + } + + return ret + } + if ok, _ := is.IsPartition(); ok { + // Already pruned when translated to logical union. + return ret + } + + tmp, _ := b.is.TableByID(is.Table.ID) + tbl := tmp.(table.PartitionedTable) + partitions, err := partitionPruning(b.ctx, tbl, v.PlanPartInfo) + if err != nil { + b.err = err + return nil + } + ret.partitionTableMode = true + ret.prunedPartitions = partitions + return ret +} + +func buildNoRangeIndexMergeReader(b *executorBuilder, v *plannercore.PhysicalIndexMergeReader) (*IndexMergeReaderExecutor, error) { + partialPlanCount := len(v.PartialPlans) + partialReqs := make([]*tipb.DAGRequest, 0, partialPlanCount) + partialDataSizes := make([]float64, 0, partialPlanCount) + indexes := make([]*model.IndexInfo, 0, partialPlanCount) + descs := make([]bool, 0, partialPlanCount) + ts := v.TablePlans[0].(*plannercore.PhysicalTableScan) + isCorColInPartialFilters := make([]bool, 0, partialPlanCount) + isCorColInPartialAccess := make([]bool, 0, partialPlanCount) + hasGlobalIndex := false + for i := 0; i < partialPlanCount; i++ { + var tempReq *tipb.DAGRequest + var err error + + if is, ok := v.PartialPlans[i][0].(*plannercore.PhysicalIndexScan); ok { + tempReq, err = buildIndexReq(b.ctx, is.Index.Columns, ts.HandleCols.NumCols(), v.PartialPlans[i]) + descs = append(descs, is.Desc) + indexes = append(indexes, is.Index) + if is.Index.Global { + hasGlobalIndex = true + } + } else { + ts := v.PartialPlans[i][0].(*plannercore.PhysicalTableScan) + tempReq, _, err = buildTableReq(b, len(ts.Columns), v.PartialPlans[i]) + descs = append(descs, ts.Desc) + indexes = append(indexes, nil) + } + if err != nil { + return nil, err + } + collect := false + tempReq.CollectRangeCounts = &collect + partialReqs = append(partialReqs, tempReq) + isCorColInPartialFilters = append(isCorColInPartialFilters, b.corColInDistPlan(v.PartialPlans[i])) + isCorColInPartialAccess = append(isCorColInPartialAccess, b.corColInAccess(v.PartialPlans[i][0])) + partialDataSizes = append(partialDataSizes, v.GetPartialReaderNetDataSize(v.PartialPlans[i][0])) + } + tableReq, tblInfo, err := buildTableReq(b, v.Schema().Len(), v.TablePlans) + isCorColInTableFilter := b.corColInDistPlan(v.TablePlans) + if err != nil { + return nil, err + } + startTS, err := b.getSnapshotTS() + if err != nil { + return nil, err + } + + readerBuilder, err := b.newDataReaderBuilder(nil) + if err != nil { + return nil, err + } + + e := &IndexMergeReaderExecutor{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + indexUsageReporter: b.buildIndexUsageReporter(v), + dagPBs: partialReqs, + startTS: startTS, + table: tblInfo, + indexes: indexes, + 
descs: descs, + tableRequest: tableReq, + columns: ts.Columns, + partialPlans: v.PartialPlans, + tblPlans: v.TablePlans, + partialNetDataSizes: partialDataSizes, + dataAvgRowSize: v.GetAvgTableRowSize(), + dataReaderBuilder: readerBuilder, + handleCols: v.HandleCols, + isCorColInPartialFilters: isCorColInPartialFilters, + isCorColInTableFilter: isCorColInTableFilter, + isCorColInPartialAccess: isCorColInPartialAccess, + isIntersection: v.IsIntersectionType, + byItems: v.ByItems, + pushedLimit: v.PushedLimit, + keepOrder: v.KeepOrder, + hasGlobalIndex: hasGlobalIndex, + } + collectTable := false + e.tableRequest.CollectRangeCounts = &collectTable + return e, nil +} + +type tableStatsPreloader interface { + LoadTableStats(sessionctx.Context) +} + +func (b *executorBuilder) buildIndexUsageReporter(plan tableStatsPreloader) (indexUsageReporter *exec.IndexUsageReporter) { + sc := b.ctx.GetSessionVars().StmtCtx + if b.ctx.GetSessionVars().StmtCtx.IndexUsageCollector != nil && + sc.RuntimeStatsColl != nil { + // Preload the table stats. If the statement is a point-get or execute, the planner may not have loaded the + // stats. + plan.LoadTableStats(b.ctx) + + statsMap := sc.GetUsedStatsInfo(false) + indexUsageReporter = exec.NewIndexUsageReporter( + sc.IndexUsageCollector, + sc.RuntimeStatsColl, statsMap) + } + + return indexUsageReporter +} + +func (b *executorBuilder) buildIndexMergeReader(v *plannercore.PhysicalIndexMergeReader) exec.Executor { + ts := v.TablePlans[0].(*plannercore.PhysicalTableScan) + if err := b.validCanReadTemporaryOrCacheTable(ts.Table); err != nil { + b.err = err + return nil + } + + ret, err := buildNoRangeIndexMergeReader(b, v) + if err != nil { + b.err = err + return nil + } + ret.ranges = make([][]*ranger.Range, 0, len(v.PartialPlans)) + sctx := b.ctx.GetSessionVars().StmtCtx + hasGlobalIndex := false + for i := 0; i < len(v.PartialPlans); i++ { + if is, ok := v.PartialPlans[i][0].(*plannercore.PhysicalIndexScan); ok { + ret.ranges = append(ret.ranges, is.Ranges) + sctx.IndexNames = append(sctx.IndexNames, is.Table.Name.O+":"+is.Index.Name.O) + if is.Index.Global { + hasGlobalIndex = true + } + } else { + ret.ranges = append(ret.ranges, v.PartialPlans[i][0].(*plannercore.PhysicalTableScan).Ranges) + if ret.table.Meta().IsCommonHandle { + tblInfo := ret.table.Meta() + sctx.IndexNames = append(sctx.IndexNames, tblInfo.Name.O+":"+tables.FindPrimaryIndex(tblInfo).Name.O) + } + } + } + sctx.TableIDs = append(sctx.TableIDs, ts.Table.ID) + executor_metrics.ExecutorCounterIndexMergeReaderExecutor.Inc() + + if !b.ctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { + return ret + } + + if pi := ts.Table.GetPartitionInfo(); pi == nil { + return ret + } + + tmp, _ := b.is.TableByID(ts.Table.ID) + partitions, err := partitionPruning(b.ctx, tmp.(table.PartitionedTable), v.PlanPartInfo) + if err != nil { + b.err = err + return nil + } + ret.partitionTableMode, ret.prunedPartitions = true, partitions + if hasGlobalIndex { + ret.partitionIDMap = make(map[int64]struct{}) + for _, p := range partitions { + ret.partitionIDMap[p.GetPhysicalID()] = struct{}{} + } + } + return ret +} + +// dataReaderBuilder build an executor. +// The executor can be used to read data in the ranges which are constructed by datums. +// Differences from executorBuilder: +// 1. dataReaderBuilder calculate data range from argument, rather than plan. +// 2. the result executor is already opened. 
+type dataReaderBuilder struct { + plan base.Plan + *executorBuilder + + selectResultHook // for testing + once struct { + sync.Once + condPruneResult []table.PhysicalTable + err error + } +} + +type mockPhysicalIndexReader struct { + base.PhysicalPlan + + e exec.Executor +} + +// MemoryUsage of mockPhysicalIndexReader is only for testing +func (*mockPhysicalIndexReader) MemoryUsage() (sum int64) { + return +} + +func (builder *dataReaderBuilder) BuildExecutorForIndexJoin(ctx context.Context, lookUpContents []*join.IndexJoinLookUpContent, + indexRanges []*ranger.Range, keyOff2IdxOff []int, cwc *plannercore.ColWithCmpFuncManager, canReorderHandles bool, memTracker *memory.Tracker, interruptSignal *atomic.Value) (exec.Executor, error) { + return builder.buildExecutorForIndexJoinInternal(ctx, builder.plan, lookUpContents, indexRanges, keyOff2IdxOff, cwc, canReorderHandles, memTracker, interruptSignal) +} + +func (builder *dataReaderBuilder) buildExecutorForIndexJoinInternal(ctx context.Context, plan base.Plan, lookUpContents []*join.IndexJoinLookUpContent, + indexRanges []*ranger.Range, keyOff2IdxOff []int, cwc *plannercore.ColWithCmpFuncManager, canReorderHandles bool, memTracker *memory.Tracker, interruptSignal *atomic.Value) (exec.Executor, error) { + switch v := plan.(type) { + case *plannercore.PhysicalTableReader: + return builder.buildTableReaderForIndexJoin(ctx, v, lookUpContents, indexRanges, keyOff2IdxOff, cwc, canReorderHandles, memTracker, interruptSignal) + case *plannercore.PhysicalIndexReader: + return builder.buildIndexReaderForIndexJoin(ctx, v, lookUpContents, indexRanges, keyOff2IdxOff, cwc, memTracker, interruptSignal) + case *plannercore.PhysicalIndexLookUpReader: + return builder.buildIndexLookUpReaderForIndexJoin(ctx, v, lookUpContents, indexRanges, keyOff2IdxOff, cwc, memTracker, interruptSignal) + case *plannercore.PhysicalUnionScan: + return builder.buildUnionScanForIndexJoin(ctx, v, lookUpContents, indexRanges, keyOff2IdxOff, cwc, canReorderHandles, memTracker, interruptSignal) + case *plannercore.PhysicalProjection: + return builder.buildProjectionForIndexJoin(ctx, v, lookUpContents, indexRanges, keyOff2IdxOff, cwc, canReorderHandles, memTracker, interruptSignal) + // Need to support physical selection because after PR 16389, TiDB will push down all the expr supported by TiKV or TiFlash + // in predicate push down stage, so if there is an expr which only supported by TiFlash, a physical selection will be added after index read + case *plannercore.PhysicalSelection: + childExec, err := builder.buildExecutorForIndexJoinInternal(ctx, v.Children()[0], lookUpContents, indexRanges, keyOff2IdxOff, cwc, canReorderHandles, memTracker, interruptSignal) + if err != nil { + return nil, err + } + exec := &SelectionExec{ + selectionExecutorContext: newSelectionExecutorContext(builder.ctx), + BaseExecutorV2: exec.NewBaseExecutorV2(builder.ctx.GetSessionVars(), v.Schema(), v.ID(), childExec), + filters: v.Conditions, + } + err = exec.open(ctx) + return exec, err + case *plannercore.PhysicalHashAgg: + childExec, err := builder.buildExecutorForIndexJoinInternal(ctx, v.Children()[0], lookUpContents, indexRanges, keyOff2IdxOff, cwc, canReorderHandles, memTracker, interruptSignal) + if err != nil { + return nil, err + } + exec := builder.buildHashAggFromChildExec(childExec, v) + err = exec.OpenSelf() + return exec, err + case *plannercore.PhysicalStreamAgg: + childExec, err := builder.buildExecutorForIndexJoinInternal(ctx, v.Children()[0], lookUpContents, indexRanges, keyOff2IdxOff, 
cwc, canReorderHandles, memTracker, interruptSignal) + if err != nil { + return nil, err + } + exec := builder.buildStreamAggFromChildExec(childExec, v) + err = exec.OpenSelf() + return exec, err + case *mockPhysicalIndexReader: + return v.e, nil + } + return nil, errors.New("Wrong plan type for dataReaderBuilder") +} + +func (builder *dataReaderBuilder) buildUnionScanForIndexJoin(ctx context.Context, v *plannercore.PhysicalUnionScan, + values []*join.IndexJoinLookUpContent, indexRanges []*ranger.Range, keyOff2IdxOff []int, + cwc *plannercore.ColWithCmpFuncManager, canReorderHandles bool, memTracker *memory.Tracker, interruptSignal *atomic.Value) (exec.Executor, error) { + childBuilder, err := builder.newDataReaderBuilder(v.Children()[0]) + if err != nil { + return nil, err + } + + reader, err := childBuilder.BuildExecutorForIndexJoin(ctx, values, indexRanges, keyOff2IdxOff, cwc, canReorderHandles, memTracker, interruptSignal) + if err != nil { + return nil, err + } + + ret := builder.buildUnionScanFromReader(reader, v) + if builder.err != nil { + return nil, builder.err + } + if us, ok := ret.(*UnionScanExec); ok { + err = us.open(ctx) + } + return ret, err +} + +func (builder *dataReaderBuilder) buildTableReaderForIndexJoin(ctx context.Context, v *plannercore.PhysicalTableReader, + lookUpContents []*join.IndexJoinLookUpContent, indexRanges []*ranger.Range, keyOff2IdxOff []int, + cwc *plannercore.ColWithCmpFuncManager, canReorderHandles bool, memTracker *memory.Tracker, interruptSignal *atomic.Value) (exec.Executor, error) { + e, err := buildNoRangeTableReader(builder.executorBuilder, v) + if !canReorderHandles { + // `canReorderHandles` is set to false only in IndexMergeJoin. IndexMergeJoin will trigger a dead loop problem + // when enabling paging(tidb/issues/35831). But IndexMergeJoin is not visible to the user and is deprecated + // for now. Thus, we disable paging here. + e.paging = false + } + if err != nil { + return nil, err + } + tbInfo := e.table.Meta() + if tbInfo.GetPartitionInfo() == nil || !builder.ctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { + if v.IsCommonHandle { + kvRanges, err := buildKvRangesForIndexJoin(e.dctx, e.rctx, getPhysicalTableID(e.table), -1, lookUpContents, indexRanges, keyOff2IdxOff, cwc, memTracker, interruptSignal) + if err != nil { + return nil, err + } + return builder.buildTableReaderFromKvRanges(ctx, e, kvRanges) + } + handles, _ := dedupHandles(lookUpContents) + return builder.buildTableReaderFromHandles(ctx, e, handles, canReorderHandles) + } + tbl, _ := builder.is.TableByID(tbInfo.ID) + pt := tbl.(table.PartitionedTable) + usedPartitionList, err := builder.partitionPruning(pt, v.PlanPartInfo) + if err != nil { + return nil, err + } + usedPartitions := make(map[int64]table.PhysicalTable, len(usedPartitionList)) + for _, p := range usedPartitionList { + usedPartitions[p.GetPhysicalID()] = p + } + var kvRanges []kv.KeyRange + var keyColOffsets []int + if len(lookUpContents) > 0 { + keyColOffsets = getPartitionKeyColOffsets(lookUpContents[0].KeyColIDs, pt) + } + if v.IsCommonHandle { + if len(keyColOffsets) > 0 { + locateKey := make([]types.Datum, len(pt.Cols())) + kvRanges = make([]kv.KeyRange, 0, len(lookUpContents)) + // lookUpContentsByPID groups lookUpContents by pid(partition) so that kv ranges for same partition can be merged. 
+ lookUpContentsByPID := make(map[int64][]*join.IndexJoinLookUpContent) + exprCtx := e.ectx + for _, content := range lookUpContents { + for i, data := range content.Keys { + locateKey[keyColOffsets[i]] = data + } + p, err := pt.GetPartitionByRow(exprCtx.GetEvalCtx(), locateKey) + if table.ErrNoPartitionForGivenValue.Equal(err) { + continue + } + if err != nil { + return nil, err + } + pid := p.GetPhysicalID() + if _, ok := usedPartitions[pid]; !ok { + continue + } + lookUpContentsByPID[pid] = append(lookUpContentsByPID[pid], content) + } + for pid, contents := range lookUpContentsByPID { + // buildKvRanges for each partition. + tmp, err := buildKvRangesForIndexJoin(e.dctx, e.rctx, pid, -1, contents, indexRanges, keyOff2IdxOff, cwc, nil, interruptSignal) + if err != nil { + return nil, err + } + kvRanges = append(kvRanges, tmp...) + } + } else { + kvRanges = make([]kv.KeyRange, 0, len(usedPartitions)*len(lookUpContents)) + for _, p := range usedPartitionList { + tmp, err := buildKvRangesForIndexJoin(e.dctx, e.rctx, p.GetPhysicalID(), -1, lookUpContents, indexRanges, keyOff2IdxOff, cwc, memTracker, interruptSignal) + if err != nil { + return nil, err + } + kvRanges = append(tmp, kvRanges...) + } + } + // The key ranges should be ordered. + slices.SortFunc(kvRanges, func(i, j kv.KeyRange) int { + return bytes.Compare(i.StartKey, j.StartKey) + }) + return builder.buildTableReaderFromKvRanges(ctx, e, kvRanges) + } + + handles, lookUpContents := dedupHandles(lookUpContents) + + if len(keyColOffsets) > 0 { + locateKey := make([]types.Datum, len(pt.Cols())) + kvRanges = make([]kv.KeyRange, 0, len(lookUpContents)) + exprCtx := e.ectx + for _, content := range lookUpContents { + for i, data := range content.Keys { + locateKey[keyColOffsets[i]] = data + } + p, err := pt.GetPartitionByRow(exprCtx.GetEvalCtx(), locateKey) + if table.ErrNoPartitionForGivenValue.Equal(err) { + continue + } + if err != nil { + return nil, err + } + pid := p.GetPhysicalID() + if _, ok := usedPartitions[pid]; !ok { + continue + } + handle := kv.IntHandle(content.Keys[0].GetInt64()) + ranges, _ := distsql.TableHandlesToKVRanges(pid, []kv.Handle{handle}) + kvRanges = append(kvRanges, ranges...) + } + } else { + for _, p := range usedPartitionList { + ranges, _ := distsql.TableHandlesToKVRanges(p.GetPhysicalID(), handles) + kvRanges = append(kvRanges, ranges...) + } + } + + // The key ranges should be ordered. 
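The overall shape of the per-partition range construction above, i.e. locate the partition for each lookup row, skip partitions that pruning removed, group the rows by physical partition ID, and finally sort every produced range by start key so the multi-partition result is ordered, can be illustrated with a reduced sketch (hypothetical keyRange and locate helpers, not the executor code):

package main

import (
	"bytes"
	"fmt"
	"slices"
)

type keyRange struct{ start, end []byte }

// buildOrderedRanges groups row keys by the partition they land in (locate is a
// stand-in for GetPartitionByRow), drops partitions not in used, and returns all
// ranges sorted by start key, as the comment above requires.
func buildOrderedRanges(keys [][]byte, locate func([]byte) int64, used map[int64]bool) []keyRange {
	byPID := make(map[int64][][]byte)
	for _, k := range keys {
		pid := locate(k)
		if !used[pid] {
			continue // pruned partition
		}
		byPID[pid] = append(byPID[pid], k)
	}
	var ranges []keyRange
	for pid, ks := range byPID { // map order is arbitrary, hence the sort below
		for _, k := range ks {
			prefix := append([]byte(fmt.Sprintf("t%d_", pid)), k...)
			ranges = append(ranges, keyRange{start: prefix, end: append(prefix, 0xff)})
		}
	}
	slices.SortFunc(ranges, func(a, b keyRange) int { return bytes.Compare(a.start, b.start) })
	return ranges
}

func main() {
	locate := func(k []byte) int64 { return int64(k[0]) % 2 } // toy 2-way hash partitioning
	rs := buildOrderedRanges([][]byte{{3}, {2}, {1}}, locate, map[int64]bool{0: true, 1: true})
	for _, r := range rs {
		fmt.Printf("%q -> %q\n", r.start, r.end)
	}
}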
+ slices.SortFunc(kvRanges, func(i, j kv.KeyRange) int { + return bytes.Compare(i.StartKey, j.StartKey) + }) + return builder.buildTableReaderFromKvRanges(ctx, e, kvRanges) +} + +func dedupHandles(lookUpContents []*join.IndexJoinLookUpContent) ([]kv.Handle, []*join.IndexJoinLookUpContent) { + handles := make([]kv.Handle, 0, len(lookUpContents)) + validLookUpContents := make([]*join.IndexJoinLookUpContent, 0, len(lookUpContents)) + for _, content := range lookUpContents { + isValidHandle := true + handle := kv.IntHandle(content.Keys[0].GetInt64()) + for _, key := range content.Keys { + if handle.IntValue() != key.GetInt64() { + isValidHandle = false + break + } + } + if isValidHandle { + handles = append(handles, handle) + validLookUpContents = append(validLookUpContents, content) + } + } + return handles, validLookUpContents +} + +type kvRangeBuilderFromRangeAndPartition struct { + partitions []table.PhysicalTable +} + +func (h kvRangeBuilderFromRangeAndPartition) buildKeyRangeSeparately(dctx *distsqlctx.DistSQLContext, ranges []*ranger.Range) ([]int64, [][]kv.KeyRange, error) { + ret := make([][]kv.KeyRange, len(h.partitions)) + pids := make([]int64, 0, len(h.partitions)) + for i, p := range h.partitions { + pid := p.GetPhysicalID() + pids = append(pids, pid) + meta := p.Meta() + if len(ranges) == 0 { + continue + } + kvRange, err := distsql.TableHandleRangesToKVRanges(dctx, []int64{pid}, meta != nil && meta.IsCommonHandle, ranges) + if err != nil { + return nil, nil, err + } + ret[i] = kvRange.AppendSelfTo(ret[i]) + } + return pids, ret, nil +} + +func (h kvRangeBuilderFromRangeAndPartition) buildKeyRange(dctx *distsqlctx.DistSQLContext, ranges []*ranger.Range) ([][]kv.KeyRange, error) { + ret := make([][]kv.KeyRange, len(h.partitions)) + if len(ranges) == 0 { + return ret, nil + } + for i, p := range h.partitions { + pid := p.GetPhysicalID() + meta := p.Meta() + kvRange, err := distsql.TableHandleRangesToKVRanges(dctx, []int64{pid}, meta != nil && meta.IsCommonHandle, ranges) + if err != nil { + return nil, err + } + ret[i] = kvRange.AppendSelfTo(ret[i]) + } + return ret, nil +} + +// newClosestReadAdjuster let the request be sent to closest replica(within the same zone) +// if response size exceeds certain threshold. +func newClosestReadAdjuster(dctx *distsqlctx.DistSQLContext, req *kv.Request, netDataSize float64) kv.CoprRequestAdjuster { + if req.ReplicaRead != kv.ReplicaReadClosestAdaptive { + return nil + } + return func(req *kv.Request, copTaskCount int) bool { + // copTaskCount is the number of coprocessor requests + if int64(netDataSize/float64(copTaskCount)) >= dctx.ReplicaClosestReadThreshold { + req.MatchStoreLabels = append(req.MatchStoreLabels, &metapb.StoreLabel{ + Key: placement.DCLabelKey, + Value: config.GetTxnScopeFromConfig(), + }) + return true + } + // reset to read from leader when the data size is small. + req.ReplicaRead = kv.ReplicaReadLeader + return false + } +} + +func (builder *dataReaderBuilder) buildTableReaderBase(ctx context.Context, e *TableReaderExecutor, reqBuilderWithRange distsql.RequestBuilder) (*TableReaderExecutor, error) { + startTS, err := builder.getSnapshotTS() + if err != nil { + return nil, err + } + kvReq, err := reqBuilderWithRange. + SetDAGRequest(e.dagPB). + SetStartTS(startTS). + SetDesc(e.desc). + SetKeepOrder(e.keepOrder). + SetTxnScope(e.txnScope). + SetReadReplicaScope(e.readReplicaScope). + SetIsStaleness(e.isStaleness). + SetFromSessionVars(e.dctx). + SetFromInfoSchema(e.GetInfoSchema()). 
+ SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.dctx, &reqBuilderWithRange.Request, e.netDataSize)). + SetPaging(e.paging). + SetConnIDAndConnAlias(e.dctx.ConnectionID, e.dctx.SessionAlias). + Build() + if err != nil { + return nil, err + } + e.kvRanges = kvReq.KeyRanges.AppendSelfTo(e.kvRanges) + e.resultHandler = &tableResultHandler{} + result, err := builder.SelectResult(ctx, builder.ctx.GetDistSQLCtx(), kvReq, exec.RetTypes(e), getPhysicalPlanIDs(e.plans), e.ID()) + if err != nil { + return nil, err + } + e.resultHandler.open(nil, result) + return e, nil +} + +func (builder *dataReaderBuilder) buildTableReaderFromHandles(ctx context.Context, e *TableReaderExecutor, handles []kv.Handle, canReorderHandles bool) (*TableReaderExecutor, error) { + if canReorderHandles { + slices.SortFunc(handles, func(i, j kv.Handle) int { + return i.Compare(j) + }) + } + var b distsql.RequestBuilder + if len(handles) > 0 { + if _, ok := handles[0].(kv.PartitionHandle); ok { + b.SetPartitionsAndHandles(handles) + } else { + b.SetTableHandles(getPhysicalTableID(e.table), handles) + } + } else { + b.SetKeyRanges(nil) + } + return builder.buildTableReaderBase(ctx, e, b) +} + +func (builder *dataReaderBuilder) buildTableReaderFromKvRanges(ctx context.Context, e *TableReaderExecutor, ranges []kv.KeyRange) (exec.Executor, error) { + var b distsql.RequestBuilder + b.SetKeyRanges(ranges) + return builder.buildTableReaderBase(ctx, e, b) +} + +func (builder *dataReaderBuilder) buildIndexReaderForIndexJoin(ctx context.Context, v *plannercore.PhysicalIndexReader, + lookUpContents []*join.IndexJoinLookUpContent, indexRanges []*ranger.Range, keyOff2IdxOff []int, cwc *plannercore.ColWithCmpFuncManager, memoryTracker *memory.Tracker, interruptSignal *atomic.Value) (exec.Executor, error) { + e, err := buildNoRangeIndexReader(builder.executorBuilder, v) + if err != nil { + return nil, err + } + tbInfo := e.table.Meta() + if tbInfo.GetPartitionInfo() == nil || !builder.ctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { + kvRanges, err := buildKvRangesForIndexJoin(e.dctx, e.rctx, e.physicalTableID, e.index.ID, lookUpContents, indexRanges, keyOff2IdxOff, cwc, memoryTracker, interruptSignal) + if err != nil { + return nil, err + } + err = e.open(ctx, kvRanges) + return e, err + } + + is := v.IndexPlans[0].(*plannercore.PhysicalIndexScan) + if is.Index.Global { + e.partitionIDMap, err = getPartitionIDsAfterPruning(builder.ctx, e.table.(table.PartitionedTable), v.PlanPartInfo) + if err != nil { + return nil, err + } + if e.ranges, err = buildRangesForIndexJoin(e.rctx, lookUpContents, indexRanges, keyOff2IdxOff, cwc); err != nil { + return nil, err + } + if err := exec.Open(ctx, e); err != nil { + return nil, err + } + return e, nil + } + + tbl, _ := builder.executorBuilder.is.TableByID(tbInfo.ID) + usedPartition, canPrune, contentPos, err := builder.prunePartitionForInnerExecutor(tbl, v.PlanPartInfo, lookUpContents) + if err != nil { + return nil, err + } + if len(usedPartition) != 0 { + if canPrune { + rangeMap, err := buildIndexRangeForEachPartition(e.rctx, usedPartition, contentPos, lookUpContents, indexRanges, keyOff2IdxOff, cwc) + if err != nil { + return nil, err + } + e.partitions = usedPartition + e.ranges = indexRanges + e.partRangeMap = rangeMap + } else { + e.partitions = usedPartition + if e.ranges, err = buildRangesForIndexJoin(e.rctx, lookUpContents, indexRanges, keyOff2IdxOff, cwc); err != nil { + return nil, err + } + } + if err := exec.Open(ctx, e); err != nil { + return nil, err + } + return e, 
nil + } + ret := &TableDualExec{BaseExecutorV2: e.BaseExecutorV2} + err = exec.Open(ctx, ret) + return ret, err +} + +func (builder *dataReaderBuilder) buildIndexLookUpReaderForIndexJoin(ctx context.Context, v *plannercore.PhysicalIndexLookUpReader, + lookUpContents []*join.IndexJoinLookUpContent, indexRanges []*ranger.Range, keyOff2IdxOff []int, cwc *plannercore.ColWithCmpFuncManager, memTracker *memory.Tracker, interruptSignal *atomic.Value) (exec.Executor, error) { + e, err := buildNoRangeIndexLookUpReader(builder.executorBuilder, v) + if err != nil { + return nil, err + } + + tbInfo := e.table.Meta() + if tbInfo.GetPartitionInfo() == nil || !builder.ctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { + e.kvRanges, err = buildKvRangesForIndexJoin(e.dctx, e.rctx, getPhysicalTableID(e.table), e.index.ID, lookUpContents, indexRanges, keyOff2IdxOff, cwc, memTracker, interruptSignal) + if err != nil { + return nil, err + } + err = e.open(ctx) + return e, err + } + + is := v.IndexPlans[0].(*plannercore.PhysicalIndexScan) + if is.Index.Global { + e.partitionIDMap, err = getPartitionIDsAfterPruning(builder.ctx, e.table.(table.PartitionedTable), v.PlanPartInfo) + if err != nil { + return nil, err + } + e.ranges, err = buildRangesForIndexJoin(e.rctx, lookUpContents, indexRanges, keyOff2IdxOff, cwc) + if err != nil { + return nil, err + } + if err := exec.Open(ctx, e); err != nil { + return nil, err + } + return e, err + } + + tbl, _ := builder.executorBuilder.is.TableByID(tbInfo.ID) + usedPartition, canPrune, contentPos, err := builder.prunePartitionForInnerExecutor(tbl, v.PlanPartInfo, lookUpContents) + if err != nil { + return nil, err + } + if len(usedPartition) != 0 { + if canPrune { + rangeMap, err := buildIndexRangeForEachPartition(e.rctx, usedPartition, contentPos, lookUpContents, indexRanges, keyOff2IdxOff, cwc) + if err != nil { + return nil, err + } + e.prunedPartitions = usedPartition + e.ranges = indexRanges + e.partitionRangeMap = rangeMap + } else { + e.prunedPartitions = usedPartition + e.ranges, err = buildRangesForIndexJoin(e.rctx, lookUpContents, indexRanges, keyOff2IdxOff, cwc) + if err != nil { + return nil, err + } + } + e.partitionTableMode = true + if err := exec.Open(ctx, e); err != nil { + return nil, err + } + return e, err + } + ret := &TableDualExec{BaseExecutorV2: e.BaseExecutorV2} + err = exec.Open(ctx, ret) + return ret, err +} + +func (builder *dataReaderBuilder) buildProjectionForIndexJoin( + ctx context.Context, + v *plannercore.PhysicalProjection, + lookUpContents []*join.IndexJoinLookUpContent, + indexRanges []*ranger.Range, + keyOff2IdxOff []int, + cwc *plannercore.ColWithCmpFuncManager, + canReorderHandles bool, + memTracker *memory.Tracker, + interruptSignal *atomic.Value, +) (executor exec.Executor, err error) { + var childExec exec.Executor + childExec, err = builder.buildExecutorForIndexJoinInternal(ctx, v.Children()[0], lookUpContents, indexRanges, keyOff2IdxOff, cwc, canReorderHandles, memTracker, interruptSignal) + if err != nil { + return nil, err + } + defer func() { + if r := recover(); r != nil { + err = util.GetRecoverError(r) + } + if err != nil { + terror.Log(exec.Close(childExec)) + } + }() + + e := &ProjectionExec{ + projectionExecutorContext: newProjectionExecutorContext(builder.ctx), + BaseExecutorV2: exec.NewBaseExecutorV2(builder.ctx.GetSessionVars(), v.Schema(), v.ID(), childExec), + numWorkers: int64(builder.ctx.GetSessionVars().ProjectionConcurrency()), + evaluatorSuit: expression.NewEvaluatorSuite(v.Exprs, 
v.AvoidColumnEvaluator), + calculateNoDelay: v.CalculateNoDelay, + } + + // If the calculation row count for this Projection operator is smaller + // than a Chunk size, we turn back to the un-parallel Projection + // implementation to reduce the goroutine overhead. + if int64(v.StatsCount()) < int64(builder.ctx.GetSessionVars().MaxChunkSize) { + e.numWorkers = 0 + } + failpoint.Inject("buildProjectionForIndexJoinPanic", func(val failpoint.Value) { + if v, ok := val.(bool); ok && v { + panic("buildProjectionForIndexJoinPanic") + } + }) + err = e.open(ctx) + if err != nil { + return nil, err + } + return e, nil +} + +// buildRangesForIndexJoin builds kv ranges for index join when the inner plan is index scan plan. +func buildRangesForIndexJoin(rctx *rangerctx.RangerContext, lookUpContents []*join.IndexJoinLookUpContent, + ranges []*ranger.Range, keyOff2IdxOff []int, cwc *plannercore.ColWithCmpFuncManager) ([]*ranger.Range, error) { + retRanges := make([]*ranger.Range, 0, len(ranges)*len(lookUpContents)) + lastPos := len(ranges[0].LowVal) - 1 + tmpDatumRanges := make([]*ranger.Range, 0, len(lookUpContents)) + for _, content := range lookUpContents { + for _, ran := range ranges { + for keyOff, idxOff := range keyOff2IdxOff { + ran.LowVal[idxOff] = content.Keys[keyOff] + ran.HighVal[idxOff] = content.Keys[keyOff] + } + } + if cwc == nil { + // A deep copy is need here because the old []*range.Range is overwritten + for _, ran := range ranges { + retRanges = append(retRanges, ran.Clone()) + } + continue + } + nextColRanges, err := cwc.BuildRangesByRow(rctx, content.Row) + if err != nil { + return nil, err + } + for _, nextColRan := range nextColRanges { + for _, ran := range ranges { + ran.LowVal[lastPos] = nextColRan.LowVal[0] + ran.HighVal[lastPos] = nextColRan.HighVal[0] + ran.LowExclude = nextColRan.LowExclude + ran.HighExclude = nextColRan.HighExclude + ran.Collators = nextColRan.Collators + tmpDatumRanges = append(tmpDatumRanges, ran.Clone()) + } + } + } + + if cwc == nil { + return retRanges, nil + } + + return ranger.UnionRanges(rctx, tmpDatumRanges, true) +} + +// buildKvRangesForIndexJoin builds kv ranges for index join when the inner plan is index scan plan. +func buildKvRangesForIndexJoin(dctx *distsqlctx.DistSQLContext, pctx *rangerctx.RangerContext, tableID, indexID int64, lookUpContents []*join.IndexJoinLookUpContent, + ranges []*ranger.Range, keyOff2IdxOff []int, cwc *plannercore.ColWithCmpFuncManager, memTracker *memory.Tracker, interruptSignal *atomic.Value) (_ []kv.KeyRange, err error) { + kvRanges := make([]kv.KeyRange, 0, len(ranges)*len(lookUpContents)) + if len(ranges) == 0 { + return []kv.KeyRange{}, nil + } + lastPos := len(ranges[0].LowVal) - 1 + tmpDatumRanges := make([]*ranger.Range, 0, len(lookUpContents)) + for _, content := range lookUpContents { + for _, ran := range ranges { + for keyOff, idxOff := range keyOff2IdxOff { + ran.LowVal[idxOff] = content.Keys[keyOff] + ran.HighVal[idxOff] = content.Keys[keyOff] + } + } + if cwc == nil { + // Index id is -1 means it's a common handle. 
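+ // A common handle is encoded as record (row-key) ranges; otherwise the ranges are
+ // encoded as index-key ranges for the given index ID.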
+ var tmpKvRanges *kv.KeyRanges + var err error + if indexID == -1 { + tmpKvRanges, err = distsql.CommonHandleRangesToKVRanges(dctx, []int64{tableID}, ranges) + } else { + tmpKvRanges, err = distsql.IndexRangesToKVRangesWithInterruptSignal(dctx, tableID, indexID, ranges, memTracker, interruptSignal) + } + if err != nil { + return nil, err + } + kvRanges = tmpKvRanges.AppendSelfTo(kvRanges) + continue + } + nextColRanges, err := cwc.BuildRangesByRow(pctx, content.Row) + if err != nil { + return nil, err + } + for _, nextColRan := range nextColRanges { + for _, ran := range ranges { + ran.LowVal[lastPos] = nextColRan.LowVal[0] + ran.HighVal[lastPos] = nextColRan.HighVal[0] + ran.LowExclude = nextColRan.LowExclude + ran.HighExclude = nextColRan.HighExclude + ran.Collators = nextColRan.Collators + tmpDatumRanges = append(tmpDatumRanges, ran.Clone()) + } + } + } + if len(kvRanges) != 0 && memTracker != nil { + failpoint.Inject("testIssue49033", func() { + panic("testIssue49033") + }) + memTracker.Consume(int64(2 * cap(kvRanges[0].StartKey) * len(kvRanges))) + } + if len(tmpDatumRanges) != 0 && memTracker != nil { + memTracker.Consume(2 * types.EstimatedMemUsage(tmpDatumRanges[0].LowVal, len(tmpDatumRanges))) + } + if cwc == nil { + slices.SortFunc(kvRanges, func(i, j kv.KeyRange) int { + return bytes.Compare(i.StartKey, j.StartKey) + }) + return kvRanges, nil + } + + tmpDatumRanges, err = ranger.UnionRanges(pctx, tmpDatumRanges, true) + if err != nil { + return nil, err + } + // Index id is -1 means it's a common handle. + if indexID == -1 { + tmpKeyRanges, err := distsql.CommonHandleRangesToKVRanges(dctx, []int64{tableID}, tmpDatumRanges) + return tmpKeyRanges.FirstPartitionRange(), err + } + tmpKeyRanges, err := distsql.IndexRangesToKVRangesWithInterruptSignal(dctx, tableID, indexID, tmpDatumRanges, memTracker, interruptSignal) + return tmpKeyRanges.FirstPartitionRange(), err +} + +func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) exec.Executor { + childExec := b.build(v.Children()[0]) + if b.err != nil { + return nil + } + base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), childExec) + groupByItems := make([]expression.Expression, 0, len(v.PartitionBy)) + for _, item := range v.PartitionBy { + groupByItems = append(groupByItems, item.Col) + } + orderByCols := make([]*expression.Column, 0, len(v.OrderBy)) + for _, item := range v.OrderBy { + orderByCols = append(orderByCols, item.Col) + } + windowFuncs := make([]aggfuncs.AggFunc, 0, len(v.WindowFuncDescs)) + partialResults := make([]aggfuncs.PartialResult, 0, len(v.WindowFuncDescs)) + resultColIdx := v.Schema().Len() - len(v.WindowFuncDescs) + exprCtx := b.ctx.GetExprCtx() + for _, desc := range v.WindowFuncDescs { + aggDesc, err := aggregation.NewAggFuncDescForWindowFunc(exprCtx, desc, false) + if err != nil { + b.err = err + return nil + } + agg := aggfuncs.BuildWindowFunctions(exprCtx, aggDesc, resultColIdx, orderByCols) + windowFuncs = append(windowFuncs, agg) + partialResult, _ := agg.AllocPartialResult() + partialResults = append(partialResults, partialResult) + resultColIdx++ + } + + var err error + if b.ctx.GetSessionVars().EnablePipelinedWindowExec { + exec := &PipelinedWindowExec{ + BaseExecutor: base, + groupChecker: vecgroupchecker.NewVecGroupChecker(b.ctx.GetExprCtx().GetEvalCtx(), b.ctx.GetSessionVars().EnableVectorizedExpression, groupByItems), + numWindowFuncs: len(v.WindowFuncDescs), + } + + exec.windowFuncs = windowFuncs + exec.partialResults = partialResults + if v.Frame == nil { + exec.start = 
&logicalop.FrameBound{ + Type: ast.Preceding, + UnBounded: true, + } + exec.end = &logicalop.FrameBound{ + Type: ast.Following, + UnBounded: true, + } + } else { + exec.start = v.Frame.Start + exec.end = v.Frame.End + if v.Frame.Type == ast.Ranges { + cmpResult := int64(-1) + if len(v.OrderBy) > 0 && v.OrderBy[0].Desc { + cmpResult = 1 + } + exec.orderByCols = orderByCols + exec.expectedCmpResult = cmpResult + exec.isRangeFrame = true + err = exec.start.UpdateCompareCols(b.ctx, exec.orderByCols) + if err != nil { + return nil + } + err = exec.end.UpdateCompareCols(b.ctx, exec.orderByCols) + if err != nil { + return nil + } + } + } + return exec + } + var processor windowProcessor + if v.Frame == nil { + processor = &aggWindowProcessor{ + windowFuncs: windowFuncs, + partialResults: partialResults, + } + } else if v.Frame.Type == ast.Rows { + processor = &rowFrameWindowProcessor{ + windowFuncs: windowFuncs, + partialResults: partialResults, + start: v.Frame.Start, + end: v.Frame.End, + } + } else { + cmpResult := int64(-1) + if len(v.OrderBy) > 0 && v.OrderBy[0].Desc { + cmpResult = 1 + } + tmpProcessor := &rangeFrameWindowProcessor{ + windowFuncs: windowFuncs, + partialResults: partialResults, + start: v.Frame.Start, + end: v.Frame.End, + orderByCols: orderByCols, + expectedCmpResult: cmpResult, + } + + err = tmpProcessor.start.UpdateCompareCols(b.ctx, orderByCols) + if err != nil { + return nil + } + err = tmpProcessor.end.UpdateCompareCols(b.ctx, orderByCols) + if err != nil { + return nil + } + + processor = tmpProcessor + } + return &WindowExec{BaseExecutor: base, + processor: processor, + groupChecker: vecgroupchecker.NewVecGroupChecker(b.ctx.GetExprCtx().GetEvalCtx(), b.ctx.GetSessionVars().EnableVectorizedExpression, groupByItems), + numWindowFuncs: len(v.WindowFuncDescs), + } +} + +func (b *executorBuilder) buildShuffle(v *plannercore.PhysicalShuffle) *ShuffleExec { + base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()) + shuffle := &ShuffleExec{ + BaseExecutor: base, + concurrency: v.Concurrency, + } + + // 1. initialize the splitters + splitters := make([]partitionSplitter, len(v.ByItemArrays)) + switch v.SplitterType { + case plannercore.PartitionHashSplitterType: + for i, byItems := range v.ByItemArrays { + splitters[i] = buildPartitionHashSplitter(shuffle.concurrency, byItems) + } + case plannercore.PartitionRangeSplitterType: + for i, byItems := range v.ByItemArrays { + splitters[i] = buildPartitionRangeSplitter(b.ctx, shuffle.concurrency, byItems) + } + default: + panic("Not implemented. Should not reach here.") + } + shuffle.splitters = splitters + + // 2. initialize the data sources (build the data sources from physical plan to executors) + shuffle.dataSources = make([]exec.Executor, len(v.DataSources)) + for i, dataSource := range v.DataSources { + shuffle.dataSources[i] = b.build(dataSource) + if b.err != nil { + return nil + } + } + + // 3. initialize the workers + head := v.Children()[0] + // A `PhysicalShuffleReceiverStub` for every worker have the same `DataSource` but different `Receiver`. + // We preallocate `PhysicalShuffleReceiverStub`s here and reuse them below. 
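+ // Each worker below gets its own shuffleReceiver per data source; the stubs' Receiver
+ // pointers are rewired to that worker's receivers before building its child executor from `head`.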
+ stubs := make([]*plannercore.PhysicalShuffleReceiverStub, 0, len(v.DataSources)) + for _, dataSource := range v.DataSources { + stub := plannercore.PhysicalShuffleReceiverStub{ + DataSource: dataSource, + }.Init(b.ctx.GetPlanCtx(), dataSource.StatsInfo(), dataSource.QueryBlockOffset(), nil) + stub.SetSchema(dataSource.Schema()) + stubs = append(stubs, stub) + } + shuffle.workers = make([]*shuffleWorker, shuffle.concurrency) + for i := range shuffle.workers { + receivers := make([]*shuffleReceiver, len(v.DataSources)) + for j, dataSource := range v.DataSources { + receivers[j] = &shuffleReceiver{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, dataSource.Schema(), stubs[j].ID()), + } + } + + w := &shuffleWorker{ + receivers: receivers, + } + + for j := range v.DataSources { + stub := stubs[j] + stub.Receiver = (unsafe.Pointer)(receivers[j]) + v.Tails[j].SetChildren(stub) + } + + w.childExec = b.build(head) + if b.err != nil { + return nil + } + + shuffle.workers[i] = w + } + + return shuffle +} + +func (*executorBuilder) buildShuffleReceiverStub(v *plannercore.PhysicalShuffleReceiverStub) *shuffleReceiver { + return (*shuffleReceiver)(v.Receiver) +} + +func (b *executorBuilder) buildSQLBindExec(v *plannercore.SQLBindPlan) exec.Executor { + base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()) + base.SetInitCap(chunk.ZeroCapacity) + + e := &SQLBindExec{ + BaseExecutor: base, + sqlBindOp: v.SQLBindOp, + normdOrigSQL: v.NormdOrigSQL, + bindSQL: v.BindSQL, + charset: v.Charset, + collation: v.Collation, + db: v.Db, + isGlobal: v.IsGlobal, + bindAst: v.BindStmt, + newStatus: v.NewStatus, + source: v.Source, + sqlDigest: v.SQLDigest, + planDigest: v.PlanDigest, + } + return e +} + +// NewRowDecoder creates a chunk decoder for new row format row value decode. +func NewRowDecoder(ctx sessionctx.Context, schema *expression.Schema, tbl *model.TableInfo) *rowcodec.ChunkDecoder { + getColInfoByID := func(tbl *model.TableInfo, colID int64) *model.ColumnInfo { + for _, col := range tbl.Columns { + if col.ID == colID { + return col + } + } + return nil + } + var pkCols []int64 + reqCols := make([]rowcodec.ColInfo, len(schema.Columns)) + for i := range schema.Columns { + idx, col := i, schema.Columns[i] + isPK := (tbl.PKIsHandle && mysql.HasPriKeyFlag(col.RetType.GetFlag())) || col.ID == model.ExtraHandleID + if isPK { + pkCols = append(pkCols, col.ID) + } + isGeneratedCol := false + if col.VirtualExpr != nil { + isGeneratedCol = true + } + reqCols[idx] = rowcodec.ColInfo{ + ID: col.ID, + VirtualGenCol: isGeneratedCol, + Ft: col.RetType, + } + } + if len(pkCols) == 0 { + pkCols = tables.TryGetCommonPkColumnIds(tbl) + if len(pkCols) == 0 { + pkCols = []int64{-1} + } + } + defVal := func(i int, chk *chunk.Chunk) error { + if reqCols[i].ID < 0 { + // model.ExtraHandleID, ExtraPhysTblID... etc + // Don't set the default value for that column. 
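+ // Internal columns with negative IDs carry no table default, so append NULL instead.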
+ chk.AppendNull(i) + return nil + } + + ci := getColInfoByID(tbl, reqCols[i].ID) + d, err := table.GetColOriginDefaultValue(ctx.GetExprCtx(), ci) + if err != nil { + return err + } + chk.AppendDatum(i, &d) + return nil + } + return rowcodec.NewChunkDecoder(reqCols, pkCols, defVal, ctx.GetSessionVars().Location()) +} + +func (b *executorBuilder) buildBatchPointGet(plan *plannercore.BatchPointGetPlan) exec.Executor { + var err error + if err = b.validCanReadTemporaryOrCacheTable(plan.TblInfo); err != nil { + b.err = err + return nil + } + + if plan.Lock && !b.inSelectLockStmt { + b.inSelectLockStmt = true + defer func() { + b.inSelectLockStmt = false + }() + } + handles, isTableDual := plan.PrunePartitionsAndValues(b.ctx) + if isTableDual { + // No matching partitions + return &TableDualExec{ + BaseExecutorV2: exec.NewBaseExecutorV2(b.ctx.GetSessionVars(), plan.Schema(), plan.ID()), + numDualRows: 0, + } + } + + decoder := NewRowDecoder(b.ctx, plan.Schema(), plan.TblInfo) + e := &BatchPointGetExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, plan.Schema(), plan.ID()), + indexUsageReporter: b.buildIndexUsageReporter(plan), + tblInfo: plan.TblInfo, + idxInfo: plan.IndexInfo, + rowDecoder: decoder, + keepOrder: plan.KeepOrder, + desc: plan.Desc, + lock: plan.Lock, + waitTime: plan.LockWaitTime, + columns: plan.Columns, + handles: handles, + idxVals: plan.IndexValues, + partitionNames: plan.PartitionNames, + } + + e.snapshot, err = b.getSnapshot() + if err != nil { + b.err = err + return nil + } + if e.Ctx().GetSessionVars().IsReplicaReadClosestAdaptive() { + e.snapshot.SetOption(kv.ReplicaReadAdjuster, newReplicaReadAdjuster(e.Ctx(), plan.GetAvgRowSize())) + } + if e.RuntimeStats() != nil { + snapshotStats := &txnsnapshot.SnapshotRuntimeStats{} + e.stats = &runtimeStatsWithSnapshot{ + SnapshotRuntimeStats: snapshotStats, + } + e.snapshot.SetOption(kv.CollectRuntimeStats, snapshotStats) + } + + if plan.IndexInfo != nil { + sctx := b.ctx.GetSessionVars().StmtCtx + sctx.IndexNames = append(sctx.IndexNames, plan.TblInfo.Name.O+":"+plan.IndexInfo.Name.O) + } + + failpoint.Inject("assertBatchPointReplicaOption", func(val failpoint.Value) { + assertScope := val.(string) + if e.Ctx().GetSessionVars().GetReplicaRead().IsClosestRead() && assertScope != b.readReplicaScope { + panic("batch point get replica option fail") + } + }) + + snapshotTS, err := b.getSnapshotTS() + if err != nil { + b.err = err + return nil + } + if plan.TblInfo.TableCacheStatusType == model.TableCacheStatusEnable { + if cacheTable := b.getCacheTable(plan.TblInfo, snapshotTS); cacheTable != nil { + e.snapshot = cacheTableSnapshot{e.snapshot, cacheTable} + } + } + + if plan.TblInfo.TempTableType != model.TempTableNone { + // Temporary table should not do any lock operations + e.lock = false + e.waitTime = 0 + } + + if e.lock { + b.hasLock = true + } + if pi := plan.TblInfo.GetPartitionInfo(); pi != nil && len(plan.PartitionIdxs) > 0 { + defs := plan.TblInfo.GetPartitionInfo().Definitions + if plan.SinglePartition { + e.singlePartID = defs[plan.PartitionIdxs[0]].ID + } else { + e.planPhysIDs = make([]int64, len(plan.PartitionIdxs)) + for i, idx := range plan.PartitionIdxs { + e.planPhysIDs[i] = defs[idx].ID + } + } + } + + capacity := len(e.handles) + if capacity == 0 { + capacity = len(e.idxVals) + } + e.SetInitCap(capacity) + e.SetMaxChunkSize(capacity) + e.buildVirtualColumnInfo() + return e +} + +func newReplicaReadAdjuster(ctx sessionctx.Context, avgRowSize float64) txnkv.ReplicaReadAdjuster { + return func(count int) 
(tikv.StoreSelectorOption, clientkv.ReplicaReadType) { + if int64(avgRowSize*float64(count)) >= ctx.GetSessionVars().ReplicaClosestReadThreshold { + return tikv.WithMatchLabels([]*metapb.StoreLabel{ + { + Key: placement.DCLabelKey, + Value: config.GetTxnScopeFromConfig(), + }, + }), clientkv.ReplicaReadMixed + } + // fallback to read from leader if the request is small + return nil, clientkv.ReplicaReadLeader + } +} + +func isCommonHandleRead(tbl *model.TableInfo, idx *model.IndexInfo) bool { + return tbl.IsCommonHandle && idx.Primary +} + +func getPhysicalTableID(t table.Table) int64 { + if p, ok := t.(table.PhysicalTable); ok { + return p.GetPhysicalID() + } + return t.Meta().ID +} + +func (builder *dataReaderBuilder) partitionPruning(tbl table.PartitionedTable, planPartInfo *plannercore.PhysPlanPartInfo) ([]table.PhysicalTable, error) { + builder.once.Do(func() { + condPruneResult, err := partitionPruning(builder.executorBuilder.ctx, tbl, planPartInfo) + builder.once.condPruneResult = condPruneResult + builder.once.err = err + }) + return builder.once.condPruneResult, builder.once.err +} + +func partitionPruning(ctx sessionctx.Context, tbl table.PartitionedTable, planPartInfo *plannercore.PhysPlanPartInfo) ([]table.PhysicalTable, error) { + var pruningConds []expression.Expression + var partitionNames []model.CIStr + var columns []*expression.Column + var columnNames types.NameSlice + if planPartInfo != nil { + pruningConds = planPartInfo.PruningConds + partitionNames = planPartInfo.PartitionNames + columns = planPartInfo.Columns + columnNames = planPartInfo.ColumnNames + } + idxArr, err := plannercore.PartitionPruning(ctx.GetPlanCtx(), tbl, pruningConds, partitionNames, columns, columnNames) + if err != nil { + return nil, err + } + + pi := tbl.Meta().GetPartitionInfo() + var ret []table.PhysicalTable + if fullRangePartition(idxArr) { + ret = make([]table.PhysicalTable, 0, len(pi.Definitions)) + for _, def := range pi.Definitions { + p := tbl.GetPartition(def.ID) + ret = append(ret, p) + } + } else { + ret = make([]table.PhysicalTable, 0, len(idxArr)) + for _, idx := range idxArr { + pid := pi.Definitions[idx].ID + p := tbl.GetPartition(pid) + ret = append(ret, p) + } + } + return ret, nil +} + +func getPartitionIDsAfterPruning(ctx sessionctx.Context, tbl table.PartitionedTable, physPlanPartInfo *plannercore.PhysPlanPartInfo) (map[int64]struct{}, error) { + if physPlanPartInfo == nil { + return nil, errors.New("physPlanPartInfo in getPartitionIDsAfterPruning must not be nil") + } + idxArr, err := plannercore.PartitionPruning(ctx.GetPlanCtx(), tbl, physPlanPartInfo.PruningConds, physPlanPartInfo.PartitionNames, physPlanPartInfo.Columns, physPlanPartInfo.ColumnNames) + if err != nil { + return nil, err + } + + var ret map[int64]struct{} + + pi := tbl.Meta().GetPartitionInfo() + if fullRangePartition(idxArr) { + ret = make(map[int64]struct{}, len(pi.Definitions)) + for _, def := range pi.Definitions { + ret[def.ID] = struct{}{} + } + } else { + ret = make(map[int64]struct{}, len(idxArr)) + for _, idx := range idxArr { + pid := pi.Definitions[idx].ID + ret[pid] = struct{}{} + } + } + return ret, nil +} + +func fullRangePartition(idxArr []int) bool { + return len(idxArr) == 1 && idxArr[0] == plannercore.FullRange +} + +type emptySampler struct{} + +func (*emptySampler) writeChunk(_ *chunk.Chunk) error { + return nil +} + +func (*emptySampler) finished() bool { + return true +} + +func (b *executorBuilder) buildTableSample(v *plannercore.PhysicalTableSample) *TableSampleExecutor { + startTS, 
err := b.getSnapshotTS() + if err != nil { + b.err = err + return nil + } + e := &TableSampleExecutor{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + table: v.TableInfo, + startTS: startTS, + } + + tblInfo := v.TableInfo.Meta() + if tblInfo.TempTableType != model.TempTableNone { + if tblInfo.TempTableType != model.TempTableGlobal { + b.err = errors.New("TABLESAMPLE clause can not be applied to local temporary tables") + return nil + } + e.sampler = &emptySampler{} + } else if v.TableSampleInfo.AstNode.SampleMethod == ast.SampleMethodTypeTiDBRegion { + e.sampler = newTableRegionSampler( + b.ctx, v.TableInfo, startTS, v.PhysicalTableID, v.TableSampleInfo.Partitions, v.Schema(), + v.TableSampleInfo.FullSchema, e.RetFieldTypes(), v.Desc) + } + + return e +} + +func (b *executorBuilder) buildCTE(v *plannercore.PhysicalCTE) exec.Executor { + storageMap, ok := b.ctx.GetSessionVars().StmtCtx.CTEStorageMap.(map[int]*CTEStorages) + if !ok { + b.err = errors.New("type assertion for CTEStorageMap failed") + return nil + } + + chkSize := b.ctx.GetSessionVars().MaxChunkSize + // iterOutTbl will be constructed in CTEExec.Open(). + var resTbl cteutil.Storage + var iterInTbl cteutil.Storage + var producer *cteProducer + storages, ok := storageMap[v.CTE.IDForStorage] + if ok { + // Storage already setup. + resTbl = storages.ResTbl + iterInTbl = storages.IterInTbl + producer = storages.Producer + } else { + if v.SeedPlan == nil { + b.err = errors.New("cte.seedPlan cannot be nil") + return nil + } + // Build seed part. + corCols := plannercore.ExtractOuterApplyCorrelatedCols(v.SeedPlan) + seedExec := b.build(v.SeedPlan) + if b.err != nil { + return nil + } + + // Setup storages. + tps := seedExec.RetFieldTypes() + resTbl = cteutil.NewStorageRowContainer(tps, chkSize) + if err := resTbl.OpenAndRef(); err != nil { + b.err = err + return nil + } + iterInTbl = cteutil.NewStorageRowContainer(tps, chkSize) + if err := iterInTbl.OpenAndRef(); err != nil { + b.err = err + return nil + } + storageMap[v.CTE.IDForStorage] = &CTEStorages{ResTbl: resTbl, IterInTbl: iterInTbl} + + // Build recursive part. + var recursiveExec exec.Executor + if v.RecurPlan != nil { + recursiveExec = b.build(v.RecurPlan) + if b.err != nil { + return nil + } + corCols = append(corCols, plannercore.ExtractOuterApplyCorrelatedCols(v.RecurPlan)...) 
+ } + + var sel []int + if v.CTE.IsDistinct { + sel = make([]int, chkSize) + for i := 0; i < chkSize; i++ { + sel[i] = i + } + } + + var corColHashCodes [][]byte + for _, corCol := range corCols { + corColHashCodes = append(corColHashCodes, getCorColHashCode(corCol)) + } + + producer = &cteProducer{ + ctx: b.ctx, + seedExec: seedExec, + recursiveExec: recursiveExec, + resTbl: resTbl, + iterInTbl: iterInTbl, + isDistinct: v.CTE.IsDistinct, + sel: sel, + hasLimit: v.CTE.HasLimit, + limitBeg: v.CTE.LimitBeg, + limitEnd: v.CTE.LimitEnd, + corCols: corCols, + corColHashCodes: corColHashCodes, + } + storageMap[v.CTE.IDForStorage].Producer = producer + } + + return &CTEExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + producer: producer, + } +} + +func (b *executorBuilder) buildCTETableReader(v *plannercore.PhysicalCTETable) exec.Executor { + storageMap, ok := b.ctx.GetSessionVars().StmtCtx.CTEStorageMap.(map[int]*CTEStorages) + if !ok { + b.err = errors.New("type assertion for CTEStorageMap failed") + return nil + } + storages, ok := storageMap[v.IDForStorage] + if !ok { + b.err = errors.Errorf("iterInTbl should already be set up by CTEExec(id: %d)", v.IDForStorage) + return nil + } + return &CTETableReaderExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + iterInTbl: storages.IterInTbl, + chkIdx: 0, + } +} +func (b *executorBuilder) validCanReadTemporaryOrCacheTable(tbl *model.TableInfo) error { + err := b.validCanReadTemporaryTable(tbl) + if err != nil { + return err + } + return b.validCanReadCacheTable(tbl) +} + +func (b *executorBuilder) validCanReadCacheTable(tbl *model.TableInfo) error { + if tbl.TableCacheStatusType == model.TableCacheStatusDisable { + return nil + } + + sessionVars := b.ctx.GetSessionVars() + + // Temporary table can't switch into cache table. so the following code will not cause confusion + if sessionVars.TxnCtx.IsStaleness || b.isStaleness { + return errors.Trace(errors.New("can not stale read cache table")) + } + + return nil +} + +func (b *executorBuilder) validCanReadTemporaryTable(tbl *model.TableInfo) error { + if tbl.TempTableType == model.TempTableNone { + return nil + } + + // Some tools like dumpling use history read to dump all table's records and will be fail if we return an error. 
+ // So we do not check SnapshotTS here + + sessionVars := b.ctx.GetSessionVars() + + if tbl.TempTableType == model.TempTableLocal && sessionVars.SnapshotTS != 0 { + return errors.New("can not read local temporary table when 'tidb_snapshot' is set") + } + + if sessionVars.TxnCtx.IsStaleness || b.isStaleness { + return errors.New("can not stale read temporary table") + } + + return nil +} + +func (b *executorBuilder) getCacheTable(tblInfo *model.TableInfo, startTS uint64) kv.MemBuffer { + tbl, ok := b.is.TableByID(tblInfo.ID) + if !ok { + b.err = errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(b.ctx.GetSessionVars().CurrentDB, tblInfo.Name)) + return nil + } + sessVars := b.ctx.GetSessionVars() + leaseDuration := time.Duration(variable.TableCacheLease.Load()) * time.Second + cacheData, loading := tbl.(table.CachedTable).TryReadFromCache(startTS, leaseDuration) + if cacheData != nil { + sessVars.StmtCtx.ReadFromTableCache = true + return cacheData + } else if loading { + return nil + } + if !b.ctx.GetSessionVars().StmtCtx.InExplainStmt && !b.inDeleteStmt && !b.inUpdateStmt { + tbl.(table.CachedTable).UpdateLockForRead(context.Background(), b.ctx.GetStore(), startTS, leaseDuration) + } + return nil +} + +func (b *executorBuilder) buildCompactTable(v *plannercore.CompactTable) exec.Executor { + if v.ReplicaKind != ast.CompactReplicaKindTiFlash && v.ReplicaKind != ast.CompactReplicaKindAll { + b.err = errors.Errorf("compact %v replica is not supported", strings.ToLower(string(v.ReplicaKind))) + return nil + } + + store := b.ctx.GetStore() + tikvStore, ok := store.(tikv.Storage) + if !ok { + b.err = errors.New("compact tiflash replica can only run with tikv compatible storage") + return nil + } + + var partitionIDs []int64 + if v.PartitionNames != nil { + if v.TableInfo.Partition == nil { + b.err = errors.Errorf("table:%s is not a partition table, but user specify partition name list:%+v", v.TableInfo.Name.O, v.PartitionNames) + return nil + } + // use map to avoid FindPartitionDefinitionByName + partitionMap := map[string]int64{} + for _, partition := range v.TableInfo.Partition.Definitions { + partitionMap[partition.Name.L] = partition.ID + } + + for _, partitionName := range v.PartitionNames { + partitionID, ok := partitionMap[partitionName.L] + if !ok { + b.err = table.ErrUnknownPartition.GenWithStackByArgs(partitionName.O, v.TableInfo.Name.O) + return nil + } + partitionIDs = append(partitionIDs, partitionID) + } + } + + return &CompactTableTiFlashExec{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), + tableInfo: v.TableInfo, + partitionIDs: partitionIDs, + tikvStore: tikvStore, + } +} + +func (b *executorBuilder) buildAdminShowBDRRole(v *plannercore.AdminShowBDRRole) exec.Executor { + return &AdminShowBDRRoleExec{BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID())} +} diff --git a/pkg/executor/checksum.go b/pkg/executor/checksum.go index c4d349359a74d..2fde5a48869fb 100644 --- a/pkg/executor/checksum.go +++ b/pkg/executor/checksum.go @@ -159,7 +159,7 @@ func (e *ChecksumTableExec) handleChecksumRequest(req *kv.Request) (resp *tipb.C if err1 := res.Close(); err1 != nil { err = err1 } - failpoint.Inject("afterHandleChecksumRequest", nil) + failpoint.Eval(_curpkg_("afterHandleChecksumRequest")) }() resp = &tipb.ChecksumResponse{} diff --git a/pkg/executor/checksum.go__failpoint_stash__ b/pkg/executor/checksum.go__failpoint_stash__ new file mode 100644 index 0000000000000..c4d349359a74d --- /dev/null +++ b/pkg/executor/checksum.go__failpoint_stash__ @@ 
-0,0 +1,336 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "context" + "strconv" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/distsql" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/ranger" + "github.com/pingcap/tipb/go-tipb" + "go.uber.org/zap" +) + +var _ exec.Executor = &ChecksumTableExec{} + +// ChecksumTableExec represents ChecksumTable executor. +type ChecksumTableExec struct { + exec.BaseExecutor + + tables map[int64]*checksumContext // tableID -> checksumContext + done bool +} + +// Open implements the Executor Open interface. +func (e *ChecksumTableExec) Open(ctx context.Context) error { + if err := e.BaseExecutor.Open(ctx); err != nil { + return err + } + + concurrency, err := getChecksumTableConcurrency(e.Ctx()) + if err != nil { + return err + } + + tasks, err := e.buildTasks() + if err != nil { + return err + } + + taskCh := make(chan *checksumTask, len(tasks)) + resultCh := make(chan *checksumResult, len(tasks)) + for i := 0; i < concurrency; i++ { + go e.checksumWorker(taskCh, resultCh) + } + + for _, task := range tasks { + taskCh <- task + } + close(taskCh) + + for i := 0; i < len(tasks); i++ { + result := <-resultCh + if result.err != nil { + err = result.err + logutil.Logger(ctx).Error("checksum failed", zap.Error(err)) + continue + } + logutil.Logger(ctx).Info( + "got one checksum result", + zap.Int64("tableID", result.tableID), + zap.Int64("physicalTableID", result.physicalTableID), + zap.Int64("indexID", result.indexID), + zap.Uint64("checksum", result.response.Checksum), + zap.Uint64("totalKvs", result.response.TotalKvs), + zap.Uint64("totalBytes", result.response.TotalBytes), + ) + e.handleResult(result) + } + if err != nil { + return err + } + + return nil +} + +// Next implements the Executor Next interface. +func (e *ChecksumTableExec) Next(_ context.Context, req *chunk.Chunk) error { + req.Reset() + if e.done { + return nil + } + for _, t := range e.tables { + req.AppendString(0, t.dbInfo.Name.O) + req.AppendString(1, t.tableInfo.Name.O) + req.AppendUint64(2, t.response.Checksum) + req.AppendUint64(3, t.response.TotalKvs) + req.AppendUint64(4, t.response.TotalBytes) + } + e.done = true + return nil +} + +func (e *ChecksumTableExec) buildTasks() ([]*checksumTask, error) { + allTasks := make([][]*checksumTask, 0, len(e.tables)) + taskCnt := 0 + for _, t := range e.tables { + tasks, err := t.buildTasks(e.Ctx()) + if err != nil { + return nil, err + } + allTasks = append(allTasks, tasks) + taskCnt += len(tasks) + } + ret := make([]*checksumTask, 0, taskCnt) + for _, task := range allTasks { + ret = append(ret, task...) 
+ } + return ret, nil +} + +func (e *ChecksumTableExec) handleResult(result *checksumResult) { + table := e.tables[result.tableID] + table.handleResponse(result.response) +} + +func (e *ChecksumTableExec) checksumWorker(taskCh <-chan *checksumTask, resultCh chan<- *checksumResult) { + for task := range taskCh { + result := &checksumResult{ + tableID: task.tableID, + physicalTableID: task.physicalTableID, + indexID: task.indexID, + } + result.response, result.err = e.handleChecksumRequest(task.request) + resultCh <- result + } +} + +func (e *ChecksumTableExec) handleChecksumRequest(req *kv.Request) (resp *tipb.ChecksumResponse, err error) { + if err = e.Ctx().GetSessionVars().SQLKiller.HandleSignal(); err != nil { + return nil, err + } + ctx := distsql.WithSQLKvExecCounterInterceptor(context.TODO(), e.Ctx().GetSessionVars().StmtCtx.KvExecCounter) + res, err := distsql.Checksum(ctx, e.Ctx().GetClient(), req, e.Ctx().GetSessionVars().KVVars) + if err != nil { + return nil, err + } + defer func() { + if err1 := res.Close(); err1 != nil { + err = err1 + } + failpoint.Inject("afterHandleChecksumRequest", nil) + }() + + resp = &tipb.ChecksumResponse{} + + for { + data, err := res.NextRaw(ctx) + if err != nil { + return nil, err + } + if data == nil { + break + } + checksum := &tipb.ChecksumResponse{} + if err = checksum.Unmarshal(data); err != nil { + return nil, err + } + updateChecksumResponse(resp, checksum) + if err = e.Ctx().GetSessionVars().SQLKiller.HandleSignal(); err != nil { + return nil, err + } + } + + return resp, nil +} + +type checksumTask struct { + tableID int64 + physicalTableID int64 + indexID int64 + request *kv.Request +} + +type checksumResult struct { + err error + tableID int64 + physicalTableID int64 + indexID int64 + response *tipb.ChecksumResponse +} + +type checksumContext struct { + dbInfo *model.DBInfo + tableInfo *model.TableInfo + startTs uint64 + response *tipb.ChecksumResponse +} + +func newChecksumContext(db *model.DBInfo, table *model.TableInfo, startTs uint64) *checksumContext { + return &checksumContext{ + dbInfo: db, + tableInfo: table, + startTs: startTs, + response: &tipb.ChecksumResponse{}, + } +} + +func (c *checksumContext) buildTasks(ctx sessionctx.Context) ([]*checksumTask, error) { + var partDefs []model.PartitionDefinition + if part := c.tableInfo.Partition; part != nil { + partDefs = part.Definitions + } + + reqs := make([]*checksumTask, 0, (len(c.tableInfo.Indices)+1)*(len(partDefs)+1)) + if err := c.appendRequest4PhysicalTable(ctx, c.tableInfo.ID, c.tableInfo.ID, &reqs); err != nil { + return nil, err + } + + for _, partDef := range partDefs { + if err := c.appendRequest4PhysicalTable(ctx, c.tableInfo.ID, partDef.ID, &reqs); err != nil { + return nil, err + } + } + + return reqs, nil +} + +func (c *checksumContext) appendRequest4PhysicalTable( + ctx sessionctx.Context, + tableID int64, + physicalTableID int64, + reqs *[]*checksumTask, +) error { + req, err := c.buildTableRequest(ctx, physicalTableID) + if err != nil { + return err + } + + *reqs = append(*reqs, &checksumTask{ + tableID: tableID, + physicalTableID: physicalTableID, + indexID: -1, + request: req, + }) + for _, indexInfo := range c.tableInfo.Indices { + if indexInfo.State != model.StatePublic { + continue + } + req, err = c.buildIndexRequest(ctx, physicalTableID, indexInfo) + if err != nil { + return err + } + *reqs = append(*reqs, &checksumTask{ + tableID: tableID, + physicalTableID: physicalTableID, + indexID: indexInfo.ID, + request: req, + }) + } + + return nil +} + +func (c 
*checksumContext) buildTableRequest(ctx sessionctx.Context, physicalTableID int64) (*kv.Request, error) { + checksum := &tipb.ChecksumRequest{ + ScanOn: tipb.ChecksumScanOn_Table, + Algorithm: tipb.ChecksumAlgorithm_Crc64_Xor, + } + + var ranges []*ranger.Range + if c.tableInfo.IsCommonHandle { + ranges = ranger.FullNotNullRange() + } else { + ranges = ranger.FullIntRange(false) + } + + var builder distsql.RequestBuilder + builder.SetResourceGroupTagger(ctx.GetSessionVars().StmtCtx.GetResourceGroupTagger()) + return builder.SetHandleRanges(ctx.GetDistSQLCtx(), physicalTableID, c.tableInfo.IsCommonHandle, ranges). + SetChecksumRequest(checksum). + SetStartTS(c.startTs). + SetConcurrency(ctx.GetSessionVars().DistSQLScanConcurrency()). + SetResourceGroupName(ctx.GetSessionVars().StmtCtx.ResourceGroupName). + SetExplicitRequestSourceType(ctx.GetSessionVars().ExplicitRequestSourceType). + Build() +} + +func (c *checksumContext) buildIndexRequest(ctx sessionctx.Context, physicalTableID int64, indexInfo *model.IndexInfo) (*kv.Request, error) { + checksum := &tipb.ChecksumRequest{ + ScanOn: tipb.ChecksumScanOn_Index, + Algorithm: tipb.ChecksumAlgorithm_Crc64_Xor, + } + + ranges := ranger.FullRange() + + var builder distsql.RequestBuilder + builder.SetResourceGroupTagger(ctx.GetSessionVars().StmtCtx.GetResourceGroupTagger()) + return builder.SetIndexRanges(ctx.GetDistSQLCtx(), physicalTableID, indexInfo.ID, ranges). + SetChecksumRequest(checksum). + SetStartTS(c.startTs). + SetConcurrency(ctx.GetSessionVars().DistSQLScanConcurrency()). + SetResourceGroupName(ctx.GetSessionVars().StmtCtx.ResourceGroupName). + SetExplicitRequestSourceType(ctx.GetSessionVars().ExplicitRequestSourceType). + Build() +} + +func (c *checksumContext) handleResponse(update *tipb.ChecksumResponse) { + updateChecksumResponse(c.response, update) +} + +func getChecksumTableConcurrency(ctx sessionctx.Context) (int, error) { + sessionVars := ctx.GetSessionVars() + concurrency, err := sessionVars.GetSessionOrGlobalSystemVar(context.Background(), variable.TiDBChecksumTableConcurrency) + if err != nil { + return 0, err + } + c, err := strconv.ParseInt(concurrency, 10, 64) + return int(c), err +} + +func updateChecksumResponse(resp, update *tipb.ChecksumResponse) { + resp.Checksum ^= update.Checksum + resp.TotalKvs += update.TotalKvs + resp.TotalBytes += update.TotalBytes +} diff --git a/pkg/executor/compiler.go b/pkg/executor/compiler.go index 8771753e78626..ea0a6cf1d4b7d 100644 --- a/pkg/executor/compiler.go +++ b/pkg/executor/compiler.go @@ -75,14 +75,14 @@ func (c *Compiler) Compile(ctx context.Context, stmtNode ast.StmtNode) (_ *ExecS return nil, err } - failpoint.Inject("assertTxnManagerInCompile", func() { + if _, _err_ := failpoint.Eval(_curpkg_("assertTxnManagerInCompile")); _err_ == nil { sessiontxn.RecordAssert(c.Ctx, "assertTxnManagerInCompile", true) sessiontxn.AssertTxnManagerInfoSchema(c.Ctx, ret.InfoSchema) if ret.LastSnapshotTS != 0 { staleread.AssertStmtStaleness(c.Ctx, true) sessiontxn.AssertTxnManagerReadTS(c.Ctx, ret.LastSnapshotTS) } - }) + } is := sessiontxn.GetTxnManager(c.Ctx).GetTxnInfoSchema() sessVars := c.Ctx.GetSessionVars() @@ -101,9 +101,9 @@ func (c *Compiler) Compile(ctx context.Context, stmtNode ast.StmtNode) (_ *ExecS return nil, err } - failpoint.Inject("assertStmtCtxIsStaleness", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("assertStmtCtxIsStaleness")); _err_ == nil { staleread.AssertStmtStaleness(c.Ctx, val.(bool)) - }) + } if preparedObj != nil { 
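+ // For EXECUTE statements, record the statement metric against the underlying prepared
+ // statement rather than the EXECUTE wrapper.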
CountStmtNode(preparedObj.PreparedAst.Stmt, sessVars.InRestrictedSQL, stmtCtx.ResourceGroupName) diff --git a/pkg/executor/compiler.go__failpoint_stash__ b/pkg/executor/compiler.go__failpoint_stash__ new file mode 100644 index 0000000000000..8771753e78626 --- /dev/null +++ b/pkg/executor/compiler.go__failpoint_stash__ @@ -0,0 +1,568 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "context" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/sessiontxn/staleread" + "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/tracing" + "go.uber.org/zap" +) + +// Compiler compiles an ast.StmtNode to a physical plan. +type Compiler struct { + Ctx sessionctx.Context +} + +// Compile compiles an ast.StmtNode to a physical plan. +func (c *Compiler) Compile(ctx context.Context, stmtNode ast.StmtNode) (_ *ExecStmt, err error) { + r, ctx := tracing.StartRegionEx(ctx, "executor.Compile") + defer r.End() + + defer func() { + r := recover() + if r == nil { + return + } + recoveredErr, ok := r.(error) + if !ok || !(exeerrors.ErrMemoryExceedForQuery.Equal(recoveredErr) || + exeerrors.ErrMemoryExceedForInstance.Equal(recoveredErr) || + exeerrors.ErrQueryInterrupted.Equal(recoveredErr) || + exeerrors.ErrMaxExecTimeExceeded.Equal(recoveredErr)) { + panic(r) + } + err = recoveredErr + logutil.Logger(ctx).Error("compile SQL panic", zap.String("SQL", stmtNode.Text()), zap.Stack("stack"), zap.Any("recover", r)) + }() + + c.Ctx.GetSessionVars().StmtCtx.IsReadOnly = plannercore.IsReadOnly(stmtNode, c.Ctx.GetSessionVars()) + + // Do preprocess and validate. 
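+ // Preprocess resolves and validates the statement and, through InitTxnContextProvider,
+ // sets up the transaction context provider before planning.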
+ ret := &plannercore.PreprocessorReturn{} + err = plannercore.Preprocess( + ctx, + c.Ctx, + stmtNode, + plannercore.WithPreprocessorReturn(ret), + plannercore.InitTxnContextProvider, + ) + if err != nil { + return nil, err + } + + failpoint.Inject("assertTxnManagerInCompile", func() { + sessiontxn.RecordAssert(c.Ctx, "assertTxnManagerInCompile", true) + sessiontxn.AssertTxnManagerInfoSchema(c.Ctx, ret.InfoSchema) + if ret.LastSnapshotTS != 0 { + staleread.AssertStmtStaleness(c.Ctx, true) + sessiontxn.AssertTxnManagerReadTS(c.Ctx, ret.LastSnapshotTS) + } + }) + + is := sessiontxn.GetTxnManager(c.Ctx).GetTxnInfoSchema() + sessVars := c.Ctx.GetSessionVars() + stmtCtx := sessVars.StmtCtx + // handle the execute statement + var preparedObj *plannercore.PlanCacheStmt + + if execStmt, ok := stmtNode.(*ast.ExecuteStmt); ok { + if preparedObj, err = plannercore.GetPreparedStmt(execStmt, sessVars); err != nil { + return nil, err + } + } + // Build the final physical plan. + finalPlan, names, err := planner.Optimize(ctx, c.Ctx, stmtNode, is) + if err != nil { + return nil, err + } + + failpoint.Inject("assertStmtCtxIsStaleness", func(val failpoint.Value) { + staleread.AssertStmtStaleness(c.Ctx, val.(bool)) + }) + + if preparedObj != nil { + CountStmtNode(preparedObj.PreparedAst.Stmt, sessVars.InRestrictedSQL, stmtCtx.ResourceGroupName) + } else { + CountStmtNode(stmtNode, sessVars.InRestrictedSQL, stmtCtx.ResourceGroupName) + } + var lowerPriority bool + if c.Ctx.GetSessionVars().StmtCtx.Priority == mysql.NoPriority { + lowerPriority = needLowerPriority(finalPlan) + } + stmtCtx.SetPlan(finalPlan) + stmt := &ExecStmt{ + GoCtx: ctx, + InfoSchema: is, + Plan: finalPlan, + LowerPriority: lowerPriority, + Text: stmtNode.Text(), + StmtNode: stmtNode, + Ctx: c.Ctx, + OutputNames: names, + } + // Use cached plan if possible. + if preparedObj != nil && plannercore.IsSafeToReusePointGetExecutor(c.Ctx, is, preparedObj) { + if exec, isExec := finalPlan.(*plannercore.Execute); isExec { + if pointPlan, isPointPlan := exec.Plan.(*plannercore.PointGetPlan); isPointPlan { + stmt.PsStmt, stmt.Plan = preparedObj, pointPlan // notify to re-use the cached plan + } + } + } + + // Perform optimization and initialization related to the transaction level. + if err = sessiontxn.AdviseOptimizeWithPlanAndThenWarmUp(c.Ctx, stmt.Plan); err != nil { + return nil, err + } + + return stmt, nil +} + +// needLowerPriority checks whether it's needed to lower the execution priority +// of a query. +// If the estimated output row count of any operator in the physical plan tree +// is greater than the specific threshold, we'll set it to lowPriority when +// sending it to the coprocessor. 
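+// The threshold is config.GetGlobalConfig().Log.ExpensiveThreshold; see isPhysicalPlanNeedLowerPriority.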
+func needLowerPriority(p base.Plan) bool { + switch x := p.(type) { + case base.PhysicalPlan: + return isPhysicalPlanNeedLowerPriority(x) + case *plannercore.Execute: + return needLowerPriority(x.Plan) + case *plannercore.Insert: + if x.SelectPlan != nil { + return isPhysicalPlanNeedLowerPriority(x.SelectPlan) + } + case *plannercore.Delete: + if x.SelectPlan != nil { + return isPhysicalPlanNeedLowerPriority(x.SelectPlan) + } + case *plannercore.Update: + if x.SelectPlan != nil { + return isPhysicalPlanNeedLowerPriority(x.SelectPlan) + } + } + return false +} + +func isPhysicalPlanNeedLowerPriority(p base.PhysicalPlan) bool { + expensiveThreshold := int64(config.GetGlobalConfig().Log.ExpensiveThreshold) + if int64(p.StatsCount()) > expensiveThreshold { + return true + } + + for _, child := range p.Children() { + if isPhysicalPlanNeedLowerPriority(child) { + return true + } + } + + return false +} + +// CountStmtNode records the number of statements with the same type. +func CountStmtNode(stmtNode ast.StmtNode, inRestrictedSQL bool, resourceGroup string) { + if inRestrictedSQL { + return + } + + typeLabel := ast.GetStmtLabel(stmtNode) + + if config.GetGlobalConfig().Status.RecordQPSbyDB || config.GetGlobalConfig().Status.RecordDBLabel { + dbLabels := getStmtDbLabel(stmtNode) + switch { + case config.GetGlobalConfig().Status.RecordQPSbyDB: + for dbLabel := range dbLabels { + metrics.DbStmtNodeCounter.WithLabelValues(dbLabel, typeLabel).Inc() + } + case config.GetGlobalConfig().Status.RecordDBLabel: + for dbLabel := range dbLabels { + metrics.StmtNodeCounter.WithLabelValues(typeLabel, dbLabel, resourceGroup).Inc() + } + } + } else { + metrics.StmtNodeCounter.WithLabelValues(typeLabel, "", resourceGroup).Inc() + } +} + +func getStmtDbLabel(stmtNode ast.StmtNode) map[string]struct{} { + dbLabelSet := make(map[string]struct{}) + + switch x := stmtNode.(type) { + case *ast.AlterTableStmt: + if x.Table != nil { + dbLabel := x.Table.Schema.O + dbLabelSet[dbLabel] = struct{}{} + } + case *ast.CreateIndexStmt: + if x.Table != nil { + dbLabel := x.Table.Schema.O + dbLabelSet[dbLabel] = struct{}{} + } + case *ast.CreateTableStmt: + if x.Table != nil { + dbLabel := x.Table.Schema.O + dbLabelSet[dbLabel] = struct{}{} + } + case *ast.InsertStmt: + var dbLabels []string + if x.Table != nil { + dbLabels = getDbFromResultNode(x.Table.TableRefs) + for _, db := range dbLabels { + dbLabelSet[db] = struct{}{} + } + } + dbLabels = getDbFromResultNode(x.Select) + for _, db := range dbLabels { + dbLabelSet[db] = struct{}{} + } + case *ast.DropIndexStmt: + if x.Table != nil { + dbLabel := x.Table.Schema.O + dbLabelSet[dbLabel] = struct{}{} + } + case *ast.TruncateTableStmt: + if x.Table != nil { + dbLabel := x.Table.Schema.O + dbLabelSet[dbLabel] = struct{}{} + } + case *ast.RepairTableStmt: + if x.Table != nil { + dbLabel := x.Table.Schema.O + dbLabelSet[dbLabel] = struct{}{} + } + case *ast.FlashBackTableStmt: + if x.Table != nil { + dbLabel := x.Table.Schema.O + dbLabelSet[dbLabel] = struct{}{} + } + case *ast.RecoverTableStmt: + if x.Table != nil { + dbLabel := x.Table.Schema.O + dbLabelSet[dbLabel] = struct{}{} + } + case *ast.CreateViewStmt: + if x.ViewName != nil { + dbLabel := x.ViewName.Schema.O + dbLabelSet[dbLabel] = struct{}{} + } + case *ast.RenameTableStmt: + tables := x.TableToTables + for _, table := range tables { + if table.OldTable != nil { + dbLabel := table.OldTable.Schema.O + if _, ok := dbLabelSet[dbLabel]; !ok { + dbLabelSet[dbLabel] = struct{}{} + } + } + } + case *ast.DropTableStmt: + 
tables := x.Tables + for _, table := range tables { + dbLabel := table.Schema.O + if _, ok := dbLabelSet[dbLabel]; !ok { + dbLabelSet[dbLabel] = struct{}{} + } + } + case *ast.SelectStmt: + dbLabels := getDbFromResultNode(x) + for _, db := range dbLabels { + dbLabelSet[db] = struct{}{} + } + case *ast.SetOprStmt: + dbLabels := getDbFromResultNode(x) + for _, db := range dbLabels { + dbLabelSet[db] = struct{}{} + } + case *ast.UpdateStmt: + if x.TableRefs != nil { + dbLabels := getDbFromResultNode(x.TableRefs.TableRefs) + for _, db := range dbLabels { + dbLabelSet[db] = struct{}{} + } + } + case *ast.DeleteStmt: + if x.TableRefs != nil { + dbLabels := getDbFromResultNode(x.TableRefs.TableRefs) + for _, db := range dbLabels { + dbLabelSet[db] = struct{}{} + } + } + case *ast.CallStmt: + if x.Procedure != nil { + dbLabel := x.Procedure.Schema.O + dbLabelSet[dbLabel] = struct{}{} + } + case *ast.ShowStmt: + dbLabelSet[x.DBName] = struct{}{} + if x.Table != nil { + dbLabel := x.Table.Schema.O + dbLabelSet[dbLabel] = struct{}{} + } + case *ast.LoadDataStmt: + if x.Table != nil { + dbLabel := x.Table.Schema.O + dbLabelSet[dbLabel] = struct{}{} + } + case *ast.ImportIntoStmt: + if x.Table != nil { + dbLabel := x.Table.Schema.O + dbLabelSet[dbLabel] = struct{}{} + } + case *ast.SplitRegionStmt: + if x.Table != nil { + dbLabel := x.Table.Schema.O + dbLabelSet[dbLabel] = struct{}{} + } + case *ast.NonTransactionalDMLStmt: + if x.ShardColumn != nil { + dbLabel := x.ShardColumn.Schema.O + dbLabelSet[dbLabel] = struct{}{} + } + case *ast.AnalyzeTableStmt: + tables := x.TableNames + for _, table := range tables { + dbLabel := table.Schema.O + if _, ok := dbLabelSet[dbLabel]; !ok { + dbLabelSet[dbLabel] = struct{}{} + } + } + case *ast.DropStatsStmt: + tables := x.Tables + for _, table := range tables { + dbLabel := table.Schema.O + if _, ok := dbLabelSet[dbLabel]; !ok { + dbLabelSet[dbLabel] = struct{}{} + } + } + case *ast.AdminStmt: + tables := x.Tables + for _, table := range tables { + dbLabel := table.Schema.O + if _, ok := dbLabelSet[dbLabel]; !ok { + dbLabelSet[dbLabel] = struct{}{} + } + } + case *ast.UseStmt: + if _, ok := dbLabelSet[x.DBName]; !ok { + dbLabelSet[x.DBName] = struct{}{} + } + case *ast.FlushStmt: + tables := x.Tables + for _, table := range tables { + dbLabel := table.Schema.O + if _, ok := dbLabelSet[dbLabel]; !ok { + dbLabelSet[dbLabel] = struct{}{} + } + } + case *ast.CompactTableStmt: + if x.Table != nil { + dbLabel := x.Table.Schema.O + dbLabelSet[dbLabel] = struct{}{} + } + case *ast.CreateBindingStmt: + var resNode ast.ResultSetNode + var tableRef *ast.TableRefsClause + if x.OriginNode != nil { + switch n := x.OriginNode.(type) { + case *ast.SelectStmt: + tableRef = n.From + case *ast.DeleteStmt: + tableRef = n.TableRefs + case *ast.UpdateStmt: + tableRef = n.TableRefs + case *ast.InsertStmt: + tableRef = n.Table + } + if tableRef != nil { + resNode = tableRef.TableRefs + } else { + resNode = nil + } + dbLabels := getDbFromResultNode(resNode) + for _, db := range dbLabels { + dbLabelSet[db] = struct{}{} + } + } + if len(dbLabelSet) == 0 && x.HintedNode != nil { + switch n := x.HintedNode.(type) { + case *ast.SelectStmt: + tableRef = n.From + case *ast.DeleteStmt: + tableRef = n.TableRefs + case *ast.UpdateStmt: + tableRef = n.TableRefs + case *ast.InsertStmt: + tableRef = n.Table + } + if tableRef != nil { + resNode = tableRef.TableRefs + } else { + resNode = nil + } + dbLabels := getDbFromResultNode(resNode) + for _, db := range dbLabels { + dbLabelSet[db] = struct{}{} + 
} + } + case *ast.DropBindingStmt: + var resNode ast.ResultSetNode + var tableRef *ast.TableRefsClause + if x.OriginNode != nil { + switch n := x.OriginNode.(type) { + case *ast.SelectStmt: + tableRef = n.From + case *ast.DeleteStmt: + tableRef = n.TableRefs + case *ast.UpdateStmt: + tableRef = n.TableRefs + case *ast.InsertStmt: + tableRef = n.Table + } + if tableRef != nil { + resNode = tableRef.TableRefs + } else { + resNode = nil + } + dbLabels := getDbFromResultNode(resNode) + for _, db := range dbLabels { + dbLabelSet[db] = struct{}{} + } + } + if len(dbLabelSet) == 0 && x.HintedNode != nil { + switch n := x.HintedNode.(type) { + case *ast.SelectStmt: + tableRef = n.From + case *ast.DeleteStmt: + tableRef = n.TableRefs + case *ast.UpdateStmt: + tableRef = n.TableRefs + case *ast.InsertStmt: + tableRef = n.Table + } + if tableRef != nil { + resNode = tableRef.TableRefs + } else { + resNode = nil + } + dbLabels := getDbFromResultNode(resNode) + for _, db := range dbLabels { + dbLabelSet[db] = struct{}{} + } + } + case *ast.SetBindingStmt: + var resNode ast.ResultSetNode + var tableRef *ast.TableRefsClause + if x.OriginNode != nil { + switch n := x.OriginNode.(type) { + case *ast.SelectStmt: + tableRef = n.From + case *ast.DeleteStmt: + tableRef = n.TableRefs + case *ast.UpdateStmt: + tableRef = n.TableRefs + case *ast.InsertStmt: + tableRef = n.Table + } + if tableRef != nil { + resNode = tableRef.TableRefs + } else { + resNode = nil + } + dbLabels := getDbFromResultNode(resNode) + for _, db := range dbLabels { + dbLabelSet[db] = struct{}{} + } + } + + if len(dbLabelSet) == 0 && x.HintedNode != nil { + switch n := x.HintedNode.(type) { + case *ast.SelectStmt: + tableRef = n.From + case *ast.DeleteStmt: + tableRef = n.TableRefs + case *ast.UpdateStmt: + tableRef = n.TableRefs + case *ast.InsertStmt: + tableRef = n.Table + } + if tableRef != nil { + resNode = tableRef.TableRefs + } else { + resNode = nil + } + dbLabels := getDbFromResultNode(resNode) + for _, db := range dbLabels { + dbLabelSet[db] = struct{}{} + } + } + } + + // add "" db label + if len(dbLabelSet) == 0 { + dbLabelSet[""] = struct{}{} + } + + return dbLabelSet +} + +func getDbFromResultNode(resultNode ast.ResultSetNode) []string { // may have duplicate db name + var dbLabels []string + + if resultNode == nil { + return dbLabels + } + + switch x := resultNode.(type) { + case *ast.TableSource: + return getDbFromResultNode(x.Source) + case *ast.SelectStmt: + if x.From != nil { + return getDbFromResultNode(x.From.TableRefs) + } + case *ast.TableName: + if x.DBInfo != nil { + dbLabels = append(dbLabels, x.DBInfo.Name.O) + } + case *ast.Join: + if x.Left != nil { + dbs := getDbFromResultNode(x.Left) + if dbs != nil { + dbLabels = append(dbLabels, dbs...) + } + } + + if x.Right != nil { + dbs := getDbFromResultNode(x.Right) + if dbs != nil { + dbLabels = append(dbLabels, dbs...) + } + } + } + + return dbLabels +} diff --git a/pkg/executor/cte.go b/pkg/executor/cte.go index 4f57b89d3bd92..bd92fa3796176 100644 --- a/pkg/executor/cte.go +++ b/pkg/executor/cte.go @@ -134,13 +134,13 @@ func (e *CTEExec) Close() (firstErr error) { e.producer.resTbl.Lock() defer e.producer.resTbl.Unlock() if !e.producer.closed { - failpoint.Inject("mock_cte_exec_panic_avoid_deadlock", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("mock_cte_exec_panic_avoid_deadlock")); _err_ == nil { ok := v.(bool) if ok { // mock an oom panic, returning ErrMemoryExceedForQuery for error identification in recovery work. 
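// Note on the rewrite shown in this and the following hunks: it looks like the output of the
// pingcap/failpoint code generator (`failpoint-ctl enable`), which expands the
// failpoint.Inject / failpoint.Return marker calls into plain Go so failpoints can be toggled
// at runtime, and stashes the untouched source alongside as <file>.go__failpoint_stash__.
// Schematically (the failpoint name and error below are generic placeholders, not taken from
// this patch):
//
//	// marker form, as written by developers
//	failpoint.Inject("someFailpoint", func(val failpoint.Value) {
//		if val.(bool) {
//			failpoint.Return(errors.New("injected error"))
//		}
//	})
//
//	// expanded form produced by the tool; _curpkg_ prefixes the failpoint name
//	// with the package path so names stay unique across packages
//	if val, _err_ := failpoint.Eval(_curpkg_("someFailpoint")); _err_ == nil {
//		if val.(bool) {
//			return errors.New("injected error")
//		}
//	}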
panic(exeerrors.ErrMemoryExceedForQuery) } - }) + } // closeProducer() only close seedExec and recursiveExec, will not touch resTbl. // It means you can still read resTbl after call closeProducer(). // You can even call all three functions(openProducer/produce/closeProducer) in CTEExec.Next(). @@ -350,7 +350,7 @@ func (p *cteProducer) produce(ctx context.Context) (err error) { iterOutAction = setupCTEStorageTracker(p.iterOutTbl, p.ctx, p.memTracker, p.diskTracker) } - failpoint.Inject("testCTEStorageSpill", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("testCTEStorageSpill")); _err_ == nil { if val.(bool) && variable.EnableTmpStorageOnOOM.Load() { defer resAction.WaitForTest() defer iterInAction.WaitForTest() @@ -358,7 +358,7 @@ func (p *cteProducer) produce(ctx context.Context) (err error) { defer iterOutAction.WaitForTest() } } - }) + } if err = p.computeSeedPart(ctx); err != nil { p.resTbl.SetError(err) @@ -378,7 +378,7 @@ func (p *cteProducer) computeSeedPart(ctx context.Context) (err error) { err = util.GetRecoverError(r) } }() - failpoint.Inject("testCTESeedPanic", nil) + failpoint.Eval(_curpkg_("testCTESeedPanic")) p.curIter = 0 p.iterInTbl.SetIter(p.curIter) chks := make([]*chunk.Chunk, 0, 10) @@ -417,7 +417,7 @@ func (p *cteProducer) computeRecursivePart(ctx context.Context) (err error) { err = util.GetRecoverError(r) } }() - failpoint.Inject("testCTERecursivePanic", nil) + failpoint.Eval(_curpkg_("testCTERecursivePanic")) if p.recursiveExec == nil || p.iterInTbl.NumChunks() == 0 { return } @@ -442,14 +442,14 @@ func (p *cteProducer) computeRecursivePart(ctx context.Context) (err error) { p.logTbls(ctx, err, iterNum, zapcore.DebugLevel) } iterNum++ - failpoint.Inject("assertIterTableSpillToDisk", func(maxIter failpoint.Value) { + if maxIter, _err_ := failpoint.Eval(_curpkg_("assertIterTableSpillToDisk")); _err_ == nil { if iterNum > 0 && iterNum < uint64(maxIter.(int)) && err == nil { if p.iterInTbl.GetDiskBytes() == 0 && p.iterOutTbl.GetDiskBytes() == 0 && p.resTbl.GetDiskBytes() == 0 { p.logTbls(ctx, err, iterNum, zapcore.InfoLevel) panic("assert row container spill disk failed") } } - }) + } if err = p.setupTblsForNewIteration(); err != nil { return @@ -582,11 +582,11 @@ func setupCTEStorageTracker(tbl cteutil.Storage, ctx sessionctx.Context, parentM if variable.EnableTmpStorageOnOOM.Load() { actionSpill = tbl.ActionSpill() - failpoint.Inject("testCTEStorageSpill", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("testCTEStorageSpill")); _err_ == nil { if val.(bool) { actionSpill = tbl.(*cteutil.StorageRC).ActionSpillForTest() } - }) + } ctx.GetSessionVars().MemTracker.FallbackOldAndSetNewAction(actionSpill) } return actionSpill diff --git a/pkg/executor/cte.go__failpoint_stash__ b/pkg/executor/cte.go__failpoint_stash__ new file mode 100644 index 0000000000000..4f57b89d3bd92 --- /dev/null +++ b/pkg/executor/cte.go__failpoint_stash__ @@ -0,0 +1,770 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "bytes" + "context" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/executor/join" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/cteutil" + "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" + "github.com/pingcap/tidb/pkg/util/disk" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +var _ exec.Executor = &CTEExec{} + +// CTEExec implements CTE. +// Following diagram describes how CTEExec works. +// +// `iterInTbl` is shared by `CTEExec` and `CTETableReaderExec`. +// `CTETableReaderExec` reads data from `iterInTbl`, +// and its output will be stored `iterOutTbl` by `CTEExec`. +// +// When an iteration ends, `CTEExec` will move all data from `iterOutTbl` into `iterInTbl`, +// which will be the input for new iteration. +// At the end of each iteration, data in `iterOutTbl` will also be added into `resTbl`. +// `resTbl` stores data of all iteration. +/* + +----------+ + write |iterOutTbl| + CTEExec ------------------->| | + | +----+-----+ + ------------- | write + | | v + other op other op +----------+ + (seed) (recursive) | resTbl | + ^ | | + | +----------+ + CTETableReaderExec + ^ + | read +----------+ + +---------------+iterInTbl | + | | + +----------+ +*/ +type CTEExec struct { + exec.BaseExecutor + + chkIdx int + producer *cteProducer + + // limit in recursive CTE. + cursor uint64 + meetFirstBatch bool +} + +// Open implements the Executor interface. +func (e *CTEExec) Open(ctx context.Context) (err error) { + e.reset() + if err := e.BaseExecutor.Open(ctx); err != nil { + return err + } + + e.producer.resTbl.Lock() + defer e.producer.resTbl.Unlock() + + if e.producer.checkAndUpdateCorColHashCode() { + e.producer.reset() + if err = e.producer.reopenTbls(); err != nil { + return err + } + } + if e.producer.openErr != nil { + return e.producer.openErr + } + if !e.producer.opened { + if err = e.producer.openProducer(ctx, e); err != nil { + return err + } + } + return nil +} + +// Next implements the Executor interface. +func (e *CTEExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { + e.producer.resTbl.Lock() + defer e.producer.resTbl.Unlock() + if !e.producer.resTbl.Done() { + if err = e.producer.produce(ctx); err != nil { + return err + } + } + return e.producer.getChunk(e, req) +} + +func setFirstErr(firstErr error, newErr error, msg string) error { + if newErr != nil { + logutil.BgLogger().Error("cte got error", zap.Any("err", newErr), zap.Any("extra msg", msg)) + if firstErr == nil { + firstErr = newErr + } + } + return firstErr +} + +// Close implements the Executor interface. +func (e *CTEExec) Close() (firstErr error) { + func() { + e.producer.resTbl.Lock() + defer e.producer.resTbl.Unlock() + if !e.producer.closed { + failpoint.Inject("mock_cte_exec_panic_avoid_deadlock", func(v failpoint.Value) { + ok := v.(bool) + if ok { + // mock an oom panic, returning ErrMemoryExceedForQuery for error identification in recovery work. 
+ panic(exeerrors.ErrMemoryExceedForQuery) + } + }) + // closeProducer() only close seedExec and recursiveExec, will not touch resTbl. + // It means you can still read resTbl after call closeProducer(). + // You can even call all three functions(openProducer/produce/closeProducer) in CTEExec.Next(). + // Separating these three function calls is only to follow the abstraction of the volcano model. + err := e.producer.closeProducer() + firstErr = setFirstErr(firstErr, err, "close cte producer error") + } + }() + err := e.BaseExecutor.Close() + firstErr = setFirstErr(firstErr, err, "close cte children error") + return +} + +func (e *CTEExec) reset() { + e.chkIdx = 0 + e.cursor = 0 + e.meetFirstBatch = false +} + +type cteProducer struct { + // opened should be false when not open or open fail(a.k.a. openErr != nil) + opened bool + produced bool + closed bool + + // cteProducer is shared by multiple operators, so if the first operator tries to open + // and got error, the second should return open error directly instead of open again. + // Otherwise there may be resource leak because Close() only clean resource for the last Open(). + openErr error + + ctx sessionctx.Context + + seedExec exec.Executor + recursiveExec exec.Executor + + // `resTbl` and `iterInTbl` are shared by all CTEExec which reference to same the CTE. + // `iterInTbl` is also shared by CTETableReaderExec. + resTbl cteutil.Storage + iterInTbl cteutil.Storage + iterOutTbl cteutil.Storage + + hashTbl join.BaseHashTable + + // UNION ALL or UNION DISTINCT. + isDistinct bool + curIter int + hCtx *join.HashContext + sel []int + + // Limit related info. + hasLimit bool + limitBeg uint64 + limitEnd uint64 + + memTracker *memory.Tracker + diskTracker *disk.Tracker + + // Correlated Column. + corCols []*expression.CorrelatedColumn + corColHashCodes [][]byte +} + +func (p *cteProducer) openProducer(ctx context.Context, cteExec *CTEExec) (err error) { + defer func() { + p.openErr = err + if err == nil { + p.opened = true + } else { + p.opened = false + } + }() + if p.seedExec == nil { + return errors.New("seedExec for CTEExec is nil") + } + if err = exec.Open(ctx, p.seedExec); err != nil { + return err + } + + p.resetTracker() + p.memTracker = memory.NewTracker(cteExec.ID(), -1) + p.diskTracker = disk.NewTracker(cteExec.ID(), -1) + p.memTracker.AttachTo(p.ctx.GetSessionVars().StmtCtx.MemTracker) + p.diskTracker.AttachTo(p.ctx.GetSessionVars().StmtCtx.DiskTracker) + + if p.recursiveExec != nil { + if err = exec.Open(ctx, p.recursiveExec); err != nil { + return err + } + // For non-recursive CTE, the result will be put into resTbl directly. + // So no need to build iterOutTbl. + // Construct iterOutTbl in Open() instead of buildCTE(), because its destruct is in Close(). + recursiveTypes := p.recursiveExec.RetFieldTypes() + p.iterOutTbl = cteutil.NewStorageRowContainer(recursiveTypes, cteExec.MaxChunkSize()) + if err = p.iterOutTbl.OpenAndRef(); err != nil { + return err + } + } + + if p.isDistinct { + p.hashTbl = join.NewConcurrentMapHashTable() + p.hCtx = &join.HashContext{ + AllTypes: cteExec.RetFieldTypes(), + } + // We use all columns to compute hash. 
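// For example, a UNION DISTINCT CTE producing two columns (say a BIGINT and a VARCHAR)
// ends up with KeyColIdx = []int{0, 1}, so deduplication hashes and compares every column
// of each candidate row.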
+ p.hCtx.KeyColIdx = make([]int, len(p.hCtx.AllTypes)) + for i := range p.hCtx.KeyColIdx { + p.hCtx.KeyColIdx[i] = i + } + } + return nil +} + +func (p *cteProducer) closeProducer() (firstErr error) { + err := exec.Close(p.seedExec) + firstErr = setFirstErr(firstErr, err, "close seedExec err") + + if p.recursiveExec != nil { + err = exec.Close(p.recursiveExec) + firstErr = setFirstErr(firstErr, err, "close recursiveExec err") + + // `iterInTbl` and `resTbl` are shared by multiple operators, + // so will be closed when the SQL finishes. + if p.iterOutTbl != nil { + err = p.iterOutTbl.DerefAndClose() + firstErr = setFirstErr(firstErr, err, "deref iterOutTbl err") + } + } + // Reset to nil instead of calling Detach(), + // because ExplainExec still needs tracker to get mem usage info. + p.memTracker = nil + p.diskTracker = nil + p.closed = true + return +} + +func (p *cteProducer) getChunk(cteExec *CTEExec, req *chunk.Chunk) (err error) { + req.Reset() + if p.hasLimit { + return p.nextChunkLimit(cteExec, req) + } + if cteExec.chkIdx < p.resTbl.NumChunks() { + res, err := p.resTbl.GetChunk(cteExec.chkIdx) + if err != nil { + return err + } + // Need to copy chunk to make sure upper operator will not change chunk in resTbl. + // Also we ignore copying rows not selected, because some operators like Projection + // doesn't support swap column if chunk.sel is no nil. + req.SwapColumns(res.CopyConstructSel()) + cteExec.chkIdx++ + } + return nil +} + +func (p *cteProducer) nextChunkLimit(cteExec *CTEExec, req *chunk.Chunk) error { + if !cteExec.meetFirstBatch { + for cteExec.chkIdx < p.resTbl.NumChunks() { + res, err := p.resTbl.GetChunk(cteExec.chkIdx) + if err != nil { + return err + } + cteExec.chkIdx++ + numRows := uint64(res.NumRows()) + if newCursor := cteExec.cursor + numRows; newCursor >= p.limitBeg { + cteExec.meetFirstBatch = true + begInChk, endInChk := p.limitBeg-cteExec.cursor, numRows + if newCursor > p.limitEnd { + endInChk = p.limitEnd - cteExec.cursor + } + cteExec.cursor += endInChk + if begInChk == endInChk { + break + } + tmpChk := res.CopyConstructSel() + req.Append(tmpChk, int(begInChk), int(endInChk)) + return nil + } + cteExec.cursor += numRows + } + } + if cteExec.chkIdx < p.resTbl.NumChunks() && cteExec.cursor < p.limitEnd { + res, err := p.resTbl.GetChunk(cteExec.chkIdx) + if err != nil { + return err + } + cteExec.chkIdx++ + numRows := uint64(res.NumRows()) + if cteExec.cursor+numRows > p.limitEnd { + numRows = p.limitEnd - cteExec.cursor + req.Append(res.CopyConstructSel(), 0, int(numRows)) + } else { + req.SwapColumns(res.CopyConstructSel()) + } + cteExec.cursor += numRows + } + return nil +} + +func (p *cteProducer) produce(ctx context.Context) (err error) { + if p.resTbl.Error() != nil { + return p.resTbl.Error() + } + resAction := setupCTEStorageTracker(p.resTbl, p.ctx, p.memTracker, p.diskTracker) + iterInAction := setupCTEStorageTracker(p.iterInTbl, p.ctx, p.memTracker, p.diskTracker) + var iterOutAction *chunk.SpillDiskAction + if p.iterOutTbl != nil { + iterOutAction = setupCTEStorageTracker(p.iterOutTbl, p.ctx, p.memTracker, p.diskTracker) + } + + failpoint.Inject("testCTEStorageSpill", func(val failpoint.Value) { + if val.(bool) && variable.EnableTmpStorageOnOOM.Load() { + defer resAction.WaitForTest() + defer iterInAction.WaitForTest() + if iterOutAction != nil { + defer iterOutAction.WaitForTest() + } + } + }) + + if err = p.computeSeedPart(ctx); err != nil { + p.resTbl.SetError(err) + return err + } + if err = p.computeRecursivePart(ctx); err != nil { + 
p.resTbl.SetError(err) + return err + } + p.resTbl.SetDone() + return nil +} + +func (p *cteProducer) computeSeedPart(ctx context.Context) (err error) { + defer func() { + if r := recover(); r != nil && err == nil { + err = util.GetRecoverError(r) + } + }() + failpoint.Inject("testCTESeedPanic", nil) + p.curIter = 0 + p.iterInTbl.SetIter(p.curIter) + chks := make([]*chunk.Chunk, 0, 10) + for { + if p.limitDone(p.iterInTbl) { + break + } + chk := exec.TryNewCacheChunk(p.seedExec) + if err = exec.Next(ctx, p.seedExec, chk); err != nil { + return + } + if chk.NumRows() == 0 { + break + } + if chk, err = p.tryDedupAndAdd(chk, p.iterInTbl, p.hashTbl); err != nil { + return + } + chks = append(chks, chk) + } + // Initial resTbl is empty, so no need to deduplicate chk using resTbl. + // Just adding is ok. + for _, chk := range chks { + if err = p.resTbl.Add(chk); err != nil { + return + } + } + p.curIter++ + p.iterInTbl.SetIter(p.curIter) + + return +} + +func (p *cteProducer) computeRecursivePart(ctx context.Context) (err error) { + defer func() { + if r := recover(); r != nil && err == nil { + err = util.GetRecoverError(r) + } + }() + failpoint.Inject("testCTERecursivePanic", nil) + if p.recursiveExec == nil || p.iterInTbl.NumChunks() == 0 { + return + } + + if p.curIter > p.ctx.GetSessionVars().CTEMaxRecursionDepth { + return exeerrors.ErrCTEMaxRecursionDepth.GenWithStackByArgs(p.curIter) + } + + if p.limitDone(p.resTbl) { + return + } + + var iterNum uint64 + for { + chk := exec.TryNewCacheChunk(p.recursiveExec) + if err = exec.Next(ctx, p.recursiveExec, chk); err != nil { + return + } + if chk.NumRows() == 0 { + if iterNum%1000 == 0 { + // To avoid too many logs. + p.logTbls(ctx, err, iterNum, zapcore.DebugLevel) + } + iterNum++ + failpoint.Inject("assertIterTableSpillToDisk", func(maxIter failpoint.Value) { + if iterNum > 0 && iterNum < uint64(maxIter.(int)) && err == nil { + if p.iterInTbl.GetDiskBytes() == 0 && p.iterOutTbl.GetDiskBytes() == 0 && p.resTbl.GetDiskBytes() == 0 { + p.logTbls(ctx, err, iterNum, zapcore.InfoLevel) + panic("assert row container spill disk failed") + } + } + }) + + if err = p.setupTblsForNewIteration(); err != nil { + return + } + if p.limitDone(p.resTbl) { + break + } + if p.iterInTbl.NumChunks() == 0 { + break + } + // Next iteration begins. Need use iterOutTbl as input of next iteration. + p.curIter++ + p.iterInTbl.SetIter(p.curIter) + if p.curIter > p.ctx.GetSessionVars().CTEMaxRecursionDepth { + return exeerrors.ErrCTEMaxRecursionDepth.GenWithStackByArgs(p.curIter) + } + // Make sure iterInTbl is setup before Close/Open, + // because some executors will read iterInTbl in Open() (like IndexLookupJoin). + if err = exec.Close(p.recursiveExec); err != nil { + return + } + if err = exec.Open(ctx, p.recursiveExec); err != nil { + return + } + } else { + if err = p.iterOutTbl.Add(chk); err != nil { + return + } + } + } + return +} + +func (p *cteProducer) setupTblsForNewIteration() (err error) { + num := p.iterOutTbl.NumChunks() + chks := make([]*chunk.Chunk, 0, num) + // Setup resTbl's data. + for i := 0; i < num; i++ { + chk, err := p.iterOutTbl.GetChunk(i) + if err != nil { + return err + } + // Data should be copied in UNION DISTINCT. + // Because deduplicate() will change data in iterOutTbl, + // which will cause panic when spilling data into disk concurrently. 
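// In other words: deduplication rewrites the chunk's selection vector in place, and the
// same chunk may concurrently be spilled to disk by iterOutTbl, so UNION DISTINCT takes a
// private copy first.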
+ if p.isDistinct { + chk = chk.CopyConstruct() + } + chk, err = p.tryDedupAndAdd(chk, p.resTbl, p.hashTbl) + if err != nil { + return err + } + chks = append(chks, chk) + } + + // Setup new iteration data in iterInTbl. + if err = p.iterInTbl.Reopen(); err != nil { + return err + } + setupCTEStorageTracker(p.iterInTbl, p.ctx, p.memTracker, p.diskTracker) + + if p.isDistinct { + // Already deduplicated by resTbl, adding directly is ok. + for _, chk := range chks { + if err = p.iterInTbl.Add(chk); err != nil { + return err + } + } + } else { + if err = p.iterInTbl.SwapData(p.iterOutTbl); err != nil { + return err + } + } + + // Clear data in iterOutTbl. + if err = p.iterOutTbl.Reopen(); err != nil { + return err + } + setupCTEStorageTracker(p.iterOutTbl, p.ctx, p.memTracker, p.diskTracker) + return nil +} + +func (p *cteProducer) reset() { + p.curIter = 0 + p.hashTbl = nil + + p.opened = false + p.openErr = nil + p.produced = false + p.closed = false +} + +func (p *cteProducer) resetTracker() { + if p.memTracker != nil { + p.memTracker.Reset() + p.memTracker = nil + } + if p.diskTracker != nil { + p.diskTracker.Reset() + p.diskTracker = nil + } +} + +func (p *cteProducer) reopenTbls() (err error) { + if p.isDistinct { + p.hashTbl = join.NewConcurrentMapHashTable() + } + // Normally we need to setup tracker after calling Reopen(), + // But reopen resTbl means we need to call produce() again, it will setup tracker. + if err := p.resTbl.Reopen(); err != nil { + return err + } + return p.iterInTbl.Reopen() +} + +// Check if tbl meets the requirement of limit. +func (p *cteProducer) limitDone(tbl cteutil.Storage) bool { + return p.hasLimit && uint64(tbl.NumRows()) >= p.limitEnd +} + +func setupCTEStorageTracker(tbl cteutil.Storage, ctx sessionctx.Context, parentMemTracker *memory.Tracker, + parentDiskTracker *disk.Tracker) (actionSpill *chunk.SpillDiskAction) { + memTracker := tbl.GetMemTracker() + memTracker.SetLabel(memory.LabelForCTEStorage) + memTracker.AttachTo(parentMemTracker) + + diskTracker := tbl.GetDiskTracker() + diskTracker.SetLabel(memory.LabelForCTEStorage) + diskTracker.AttachTo(parentDiskTracker) + + if variable.EnableTmpStorageOnOOM.Load() { + actionSpill = tbl.ActionSpill() + failpoint.Inject("testCTEStorageSpill", func(val failpoint.Value) { + if val.(bool) { + actionSpill = tbl.(*cteutil.StorageRC).ActionSpillForTest() + } + }) + ctx.GetSessionVars().MemTracker.FallbackOldAndSetNewAction(actionSpill) + } + return actionSpill +} + +func (p *cteProducer) tryDedupAndAdd(chk *chunk.Chunk, + storage cteutil.Storage, + hashTbl join.BaseHashTable) (res *chunk.Chunk, err error) { + if p.isDistinct { + if chk, err = p.deduplicate(chk, storage, hashTbl); err != nil { + return nil, err + } + } + return chk, storage.Add(chk) +} + +// Compute hash values in chk and put it in hCtx.hashVals. +// Use the returned sel to choose the computed hash values. +func (p *cteProducer) computeChunkHash(chk *chunk.Chunk) (sel []int, err error) { + numRows := chk.NumRows() + p.hCtx.InitHash(numRows) + // Continue to reset to make sure all hasher is new. + for i := numRows; i < len(p.hCtx.HashVals); i++ { + p.hCtx.HashVals[i].Reset() + } + sel = chk.Sel() + var hashBitMap []bool + if sel != nil { + hashBitMap = make([]bool, chk.Capacity()) + for _, val := range sel { + hashBitMap[val] = true + } + } else { + // Length of p.sel is init as MaxChunkSize, but the row num of chunk may still exceeds MaxChunkSize. + // So needs to handle here to make sure len(p.sel) == chk.NumRows(). 
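// For example, p.sel may hold 1024 entries (MaxChunkSize) while the incoming chunk carries
// more rows than that; the missing indices are appended lazily before p.sel is used as the
// identity selection [0, numRows).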
+ if len(p.sel) < numRows { + tmpSel := make([]int, numRows-len(p.sel)) + for i := 0; i < len(tmpSel); i++ { + tmpSel[i] = i + len(p.sel) + } + p.sel = append(p.sel, tmpSel...) + } + + // All rows is selected, sel will be [0....numRows). + // e.sel is setup when building executor. + sel = p.sel + } + + for i := 0; i < chk.NumCols(); i++ { + if err = codec.HashChunkSelected(p.ctx.GetSessionVars().StmtCtx.TypeCtx(), p.hCtx.HashVals, + chk, p.hCtx.AllTypes[i], i, p.hCtx.Buf, p.hCtx.HasNull, + hashBitMap, false); err != nil { + return nil, err + } + } + return sel, nil +} + +// Use hashTbl to deduplicate rows, and unique rows will be added to hashTbl. +// Duplicated rows are only marked to be removed by sel in Chunk, instead of really deleted. +func (p *cteProducer) deduplicate(chk *chunk.Chunk, + storage cteutil.Storage, + hashTbl join.BaseHashTable) (chkNoDup *chunk.Chunk, err error) { + numRows := chk.NumRows() + if numRows == 0 { + return chk, nil + } + + // 1. Compute hash values for chunk. + chkHashTbl := join.NewConcurrentMapHashTable() + selOri, err := p.computeChunkHash(chk) + if err != nil { + return nil, err + } + + // 2. Filter rows duplicated in input chunk. + // This sel is for filtering rows duplicated in cur chk. + selChk := make([]int, 0, numRows) + for i := 0; i < numRows; i++ { + key := p.hCtx.HashVals[selOri[i]].Sum64() + row := chk.GetRow(i) + + hasDup, err := p.checkHasDup(key, row, chk, storage, chkHashTbl) + if err != nil { + return nil, err + } + if hasDup { + continue + } + + selChk = append(selChk, selOri[i]) + + rowPtr := chunk.RowPtr{ChkIdx: uint32(0), RowIdx: uint32(i)} + chkHashTbl.Put(key, rowPtr) + } + chk.SetSel(selChk) + chkIdx := storage.NumChunks() + + // 3. Filter rows duplicated in RowContainer. + // This sel is for filtering rows duplicated in cteutil.Storage. + selStorage := make([]int, 0, len(selChk)) + for i := 0; i < len(selChk); i++ { + key := p.hCtx.HashVals[selChk[i]].Sum64() + row := chk.GetRow(i) + + hasDup, err := p.checkHasDup(key, row, nil, storage, hashTbl) + if err != nil { + return nil, err + } + if hasDup { + continue + } + + rowIdx := len(selStorage) + selStorage = append(selStorage, selChk[i]) + + rowPtr := chunk.RowPtr{ChkIdx: uint32(chkIdx), RowIdx: uint32(rowIdx)} + hashTbl.Put(key, rowPtr) + } + + chk.SetSel(selStorage) + return chk, nil +} + +// Use the row's probe key to check if it already exists in chk or storage. +// We also need to compare the row's real encoding value to avoid hash collision. +func (p *cteProducer) checkHasDup(probeKey uint64, + row chunk.Row, + curChk *chunk.Chunk, + storage cteutil.Storage, + hashTbl join.BaseHashTable) (hasDup bool, err error) { + entry := hashTbl.Get(probeKey) + + for ; entry != nil; entry = entry.Next { + ptr := entry.Ptr + var matchedRow chunk.Row + if curChk != nil { + matchedRow = curChk.GetRow(int(ptr.RowIdx)) + } else { + matchedRow, err = storage.GetRow(ptr) + } + if err != nil { + return false, err + } + isEqual, err := codec.EqualChunkRow(p.ctx.GetSessionVars().StmtCtx.TypeCtx(), + row, p.hCtx.AllTypes, p.hCtx.KeyColIdx, + matchedRow, p.hCtx.AllTypes, p.hCtx.KeyColIdx) + if err != nil { + return false, err + } + if isEqual { + return true, nil + } + } + return false, nil +} + +func getCorColHashCode(corCol *expression.CorrelatedColumn) (res []byte) { + return codec.HashCode(res, *corCol.Data) +} + +// Return true if cor col has changed. 
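// A CTE whose definition references correlated columns yields different rows for different
// outer rows, so Open() compares these hash codes and, if any changed, resets the producer
// and reopens resTbl/iterInTbl to avoid reusing results cached for the previous outer row.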
+func (p *cteProducer) checkAndUpdateCorColHashCode() bool {
+	var changed bool
+	for i, corCol := range p.corCols {
+		newHashCode := getCorColHashCode(corCol)
+		if !bytes.Equal(newHashCode, p.corColHashCodes[i]) {
+			changed = true
+			p.corColHashCodes[i] = newHashCode
+		}
+	}
+	return changed
+}
+
+func (p *cteProducer) logTbls(ctx context.Context, err error, iterNum uint64, lvl zapcore.Level) {
+	logutil.Logger(ctx).Log(lvl, "cte iteration info",
+		zap.Any("iterInTbl mem usage", p.iterInTbl.GetMemBytes()), zap.Any("iterInTbl disk usage", p.iterInTbl.GetDiskBytes()),
+		zap.Any("iterOutTbl mem usage", p.iterOutTbl.GetMemBytes()), zap.Any("iterOutTbl disk usage", p.iterOutTbl.GetDiskBytes()),
+		zap.Any("resTbl mem usage", p.resTbl.GetMemBytes()), zap.Any("resTbl disk usage", p.resTbl.GetDiskBytes()),
+		zap.Any("resTbl rows", p.resTbl.NumRows()), zap.Any("iteration num", iterNum), zap.Error(err))
+}
diff --git a/pkg/executor/executor.go b/pkg/executor/executor.go
index d6e6252d51834..27731493fb0e5 100644
--- a/pkg/executor/executor.go
+++ b/pkg/executor/executor.go
@@ -1561,11 +1561,11 @@ func (e *SelectionExec) Open(ctx context.Context) error {
 	if err := e.BaseExecutorV2.Open(ctx); err != nil {
 		return err
 	}
-	failpoint.Inject("mockSelectionExecBaseExecutorOpenReturnedError", func(val failpoint.Value) {
+	if val, _err_ := failpoint.Eval(_curpkg_("mockSelectionExecBaseExecutorOpenReturnedError")); _err_ == nil {
 		if val.(bool) {
-			failpoint.Return(errors.New("mock SelectionExec.baseExecutor.Open returned error"))
+			return errors.New("mock SelectionExec.baseExecutor.Open returned error")
 		}
-	})
+	}
 	return e.open(ctx)
 }
diff --git a/pkg/executor/executor.go__failpoint_stash__ b/pkg/executor/executor.go__failpoint_stash__
new file mode 100644
index 0000000000000..d6e6252d51834
--- /dev/null
+++ b/pkg/executor/executor.go__failpoint_stash__
@@ -0,0 +1,2673 @@
+// Copyright 2015 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package executor + +import ( + "cmp" + "context" + stderrors "errors" + "fmt" + "math" + "runtime/pprof" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/opentracing/opentracing-go" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl" + "github.com/pingcap/tidb/pkg/ddl/schematracker" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/errctx" + "github.com/pingcap/tidb/pkg/executor/aggregate" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/executor/internal/pdhelper" + "github.com/pingcap/tidb/pkg/executor/sortexec" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/auth" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + planctx "github.com/pingcap/tidb/pkg/planner/context" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" + plannerutil "github.com/pingcap/tidb/pkg/planner/util" + "github.com/pingcap/tidb/pkg/planner/util/fixcontrol" + "github.com/pingcap/tidb/pkg/privilege" + "github.com/pingcap/tidb/pkg/resourcemanager/pool/workerpool" + poolutil "github.com/pingcap/tidb/pkg/resourcemanager/util" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/admin" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" + "github.com/pingcap/tidb/pkg/util/deadlockhistory" + "github.com/pingcap/tidb/pkg/util/disk" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/logutil/consistency" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/resourcegrouptag" + "github.com/pingcap/tidb/pkg/util/sqlexec" + "github.com/pingcap/tidb/pkg/util/syncutil" + "github.com/pingcap/tidb/pkg/util/topsql" + topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" + "github.com/pingcap/tidb/pkg/util/tracing" + tikverr "github.com/tikv/client-go/v2/error" + tikvstore "github.com/tikv/client-go/v2/kv" + tikvutil "github.com/tikv/client-go/v2/util" + atomicutil "go.uber.org/atomic" + "go.uber.org/zap" +) + +var ( + _ exec.Executor = &CheckTableExec{} + _ exec.Executor = &aggregate.HashAggExec{} + _ exec.Executor = &IndexLookUpExecutor{} + _ exec.Executor = &IndexReaderExecutor{} + _ exec.Executor = &LimitExec{} + _ exec.Executor = &MaxOneRowExec{} + _ exec.Executor = &ProjectionExec{} + _ exec.Executor = &SelectionExec{} + _ exec.Executor = &SelectLockExec{} + _ exec.Executor = &ShowNextRowIDExec{} + _ exec.Executor = 
&ShowDDLExec{} + _ exec.Executor = &ShowDDLJobsExec{} + _ exec.Executor = &ShowDDLJobQueriesExec{} + _ exec.Executor = &sortexec.SortExec{} + _ exec.Executor = &aggregate.StreamAggExec{} + _ exec.Executor = &TableDualExec{} + _ exec.Executor = &TableReaderExecutor{} + _ exec.Executor = &TableScanExec{} + _ exec.Executor = &sortexec.TopNExec{} + _ exec.Executor = &FastCheckTableExec{} + _ exec.Executor = &AdminShowBDRRoleExec{} + + // GlobalMemoryUsageTracker is the ancestor of all the Executors' memory tracker and GlobalMemory Tracker + GlobalMemoryUsageTracker *memory.Tracker + // GlobalDiskUsageTracker is the ancestor of all the Executors' disk tracker + GlobalDiskUsageTracker *disk.Tracker + // GlobalAnalyzeMemoryTracker is the ancestor of all the Analyze jobs' memory tracker and child of global Tracker + GlobalAnalyzeMemoryTracker *memory.Tracker +) + +var ( + _ dataSourceExecutor = &TableReaderExecutor{} + _ dataSourceExecutor = &IndexReaderExecutor{} + _ dataSourceExecutor = &IndexLookUpExecutor{} + _ dataSourceExecutor = &IndexMergeReaderExecutor{} + + // CheckTableFastBucketSize is the bucket size of fast check table. + CheckTableFastBucketSize = atomic.Int64{} +) + +// dataSourceExecutor is a table DataSource converted Executor. +// Currently, there are TableReader/IndexReader/IndexLookUp/IndexMergeReader. +// Note, partition reader is special and the caller should handle it carefully. +type dataSourceExecutor interface { + exec.Executor + Table() table.Table +} + +const ( + // globalPanicStorageExceed represents the panic message when out of storage quota. + globalPanicStorageExceed string = "Out Of Quota For Local Temporary Space!" + // globalPanicMemoryExceed represents the panic message when out of memory limit. + globalPanicMemoryExceed string = "Out Of Global Memory Limit!" + // globalPanicAnalyzeMemoryExceed represents the panic message when out of analyze memory limit. + globalPanicAnalyzeMemoryExceed string = "Out Of Global Analyze Memory Limit!" +) + +// globalPanicOnExceed panics when GlobalDisTracker storage usage exceeds storage quota. +type globalPanicOnExceed struct { + memory.BaseOOMAction + mutex syncutil.Mutex // For synchronization. +} + +func init() { + action := &globalPanicOnExceed{} + GlobalMemoryUsageTracker = memory.NewGlobalTracker(memory.LabelForGlobalMemory, -1) + GlobalMemoryUsageTracker.SetActionOnExceed(action) + GlobalDiskUsageTracker = disk.NewGlobalTrcaker(memory.LabelForGlobalStorage, -1) + GlobalDiskUsageTracker.SetActionOnExceed(action) + GlobalAnalyzeMemoryTracker = memory.NewTracker(memory.LabelForGlobalAnalyzeMemory, -1) + GlobalAnalyzeMemoryTracker.SetActionOnExceed(action) + // register quota funcs + variable.SetMemQuotaAnalyze = GlobalAnalyzeMemoryTracker.SetBytesLimit + variable.GetMemQuotaAnalyze = GlobalAnalyzeMemoryTracker.GetBytesLimit + // TODO: do not attach now to avoid impact to global, will attach later when analyze memory track is stable + //GlobalAnalyzeMemoryTracker.AttachToGlobalTracker(GlobalMemoryUsageTracker) + + schematracker.ConstructResultOfShowCreateDatabase = ConstructResultOfShowCreateDatabase + schematracker.ConstructResultOfShowCreateTable = ConstructResultOfShowCreateTable + + // CheckTableFastBucketSize is used to set the fast analyze bucket size for check table. 
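// The value lives in an atomic.Int64 so it can be read and adjusted safely at runtime;
// init() below seeds it with 1024.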
+ CheckTableFastBucketSize.Store(1024) +} + +// Start the backend components +func Start() { + pdhelper.GlobalPDHelper.Start() +} + +// Stop the backend components +func Stop() { + pdhelper.GlobalPDHelper.Stop() +} + +// Action panics when storage usage exceeds storage quota. +func (a *globalPanicOnExceed) Action(t *memory.Tracker) { + a.mutex.Lock() + defer a.mutex.Unlock() + msg := "" + switch t.Label() { + case memory.LabelForGlobalStorage: + msg = globalPanicStorageExceed + case memory.LabelForGlobalMemory: + msg = globalPanicMemoryExceed + case memory.LabelForGlobalAnalyzeMemory: + msg = globalPanicAnalyzeMemoryExceed + default: + msg = "Out of Unknown Resource Quota!" + } + // TODO(hawkingrei): should return error instead. + panic(msg) +} + +// GetPriority get the priority of the Action +func (*globalPanicOnExceed) GetPriority() int64 { + return memory.DefPanicPriority +} + +// CommandDDLJobsExec is the general struct for Cancel/Pause/Resume commands on +// DDL jobs. These command currently by admin have the very similar struct and +// operations, it should be a better idea to have them in the same struct. +type CommandDDLJobsExec struct { + exec.BaseExecutor + + cursor int + jobIDs []int64 + errs []error + + execute func(se sessionctx.Context, ids []int64) (errs []error, err error) +} + +// Open implements the Executor for all Cancel/Pause/Resume command on DDL jobs +// just with different processes. And, it should not be called directly by the +// Executor. +func (e *CommandDDLJobsExec) Open(context.Context) error { + // We want to use a global transaction to execute the admin command, so we don't use e.Ctx() here. + newSess, err := e.GetSysSession() + if err != nil { + return err + } + e.errs, err = e.execute(newSess, e.jobIDs) + e.ReleaseSysSession(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), newSess) + return err +} + +// Next implements the Executor Next interface for Cancel/Pause/Resume +func (e *CommandDDLJobsExec) Next(_ context.Context, req *chunk.Chunk) error { + req.GrowAndReset(e.MaxChunkSize()) + if e.cursor >= len(e.jobIDs) { + return nil + } + numCurBatch := min(req.Capacity(), len(e.jobIDs)-e.cursor) + for i := e.cursor; i < e.cursor+numCurBatch; i++ { + req.AppendString(0, strconv.FormatInt(e.jobIDs[i], 10)) + if e.errs != nil && e.errs[i] != nil { + req.AppendString(1, fmt.Sprintf("error: %v", e.errs[i])) + } else { + req.AppendString(1, "successful") + } + } + e.cursor += numCurBatch + return nil +} + +// CancelDDLJobsExec represents a cancel DDL jobs executor. +type CancelDDLJobsExec struct { + *CommandDDLJobsExec +} + +// PauseDDLJobsExec indicates an Executor for Pause a DDL Job. +type PauseDDLJobsExec struct { + *CommandDDLJobsExec +} + +// ResumeDDLJobsExec indicates an Executor for Resume a DDL Job. +type ResumeDDLJobsExec struct { + *CommandDDLJobsExec +} + +// ShowNextRowIDExec represents a show the next row ID executor. +type ShowNextRowIDExec struct { + exec.BaseExecutor + tblName *ast.TableName + done bool +} + +// Next implements the Executor Next interface. 
+func (e *ShowNextRowIDExec) Next(ctx context.Context, req *chunk.Chunk) error { + req.Reset() + if e.done { + return nil + } + is := domain.GetDomain(e.Ctx()).InfoSchema() + tbl, err := is.TableByName(ctx, e.tblName.Schema, e.tblName.Name) + if err != nil { + return err + } + tblMeta := tbl.Meta() + + allocators := tbl.Allocators(e.Ctx().GetTableCtx()) + for _, alloc := range allocators.Allocs { + nextGlobalID, err := alloc.NextGlobalAutoID() + if err != nil { + return err + } + + var colName, idType string + switch alloc.GetType() { + case autoid.RowIDAllocType: + idType = "_TIDB_ROWID" + if tblMeta.PKIsHandle { + if col := tblMeta.GetAutoIncrementColInfo(); col != nil { + colName = col.Name.O + } + } else { + colName = model.ExtraHandleName.O + } + case autoid.AutoIncrementType: + idType = "AUTO_INCREMENT" + if tblMeta.PKIsHandle { + if col := tblMeta.GetAutoIncrementColInfo(); col != nil { + colName = col.Name.O + } + } else { + colName = model.ExtraHandleName.O + } + case autoid.AutoRandomType: + idType = "AUTO_RANDOM" + colName = tblMeta.GetPkName().O + case autoid.SequenceType: + idType = "SEQUENCE" + colName = "" + default: + return autoid.ErrInvalidAllocatorType.GenWithStackByArgs() + } + + req.AppendString(0, e.tblName.Schema.O) + req.AppendString(1, e.tblName.Name.O) + req.AppendString(2, colName) + req.AppendInt64(3, nextGlobalID) + req.AppendString(4, idType) + } + + e.done = true + return nil +} + +// ShowDDLExec represents a show DDL executor. +type ShowDDLExec struct { + exec.BaseExecutor + + ddlOwnerID string + selfID string + ddlInfo *ddl.Info + done bool +} + +// Next implements the Executor Next interface. +func (e *ShowDDLExec) Next(ctx context.Context, req *chunk.Chunk) error { + req.Reset() + if e.done { + return nil + } + + ddlJobs := "" + query := "" + l := len(e.ddlInfo.Jobs) + for i, job := range e.ddlInfo.Jobs { + ddlJobs += job.String() + query += job.Query + if i != l-1 { + ddlJobs += "\n" + query += "\n" + } + } + + serverInfo, err := infosync.GetServerInfoByID(ctx, e.ddlOwnerID) + if err != nil { + return err + } + + serverAddress := serverInfo.IP + ":" + + strconv.FormatUint(uint64(serverInfo.Port), 10) + + req.AppendInt64(0, e.ddlInfo.SchemaVer) + req.AppendString(1, e.ddlOwnerID) + req.AppendString(2, serverAddress) + req.AppendString(3, ddlJobs) + req.AppendString(4, e.selfID) + req.AppendString(5, query) + + e.done = true + return nil +} + +// ShowDDLJobsExec represent a show DDL jobs executor. +type ShowDDLJobsExec struct { + exec.BaseExecutor + DDLJobRetriever + + jobNumber int + is infoschema.InfoSchema + sess sessionctx.Context +} + +// DDLJobRetriever retrieve the DDLJobs. 
+// nolint:structcheck +type DDLJobRetriever struct { + runningJobs []*model.Job + historyJobIter meta.LastJobIterator + cursor int + is infoschema.InfoSchema + activeRoles []*auth.RoleIdentity + cacheJobs []*model.Job + TZLoc *time.Location +} + +func (e *DDLJobRetriever) initial(txn kv.Transaction, sess sessionctx.Context) error { + m := meta.NewMeta(txn) + jobs, err := ddl.GetAllDDLJobs(sess) + if err != nil { + return err + } + e.historyJobIter, err = ddl.GetLastHistoryDDLJobsIterator(m) + if err != nil { + return err + } + e.runningJobs = jobs + e.cursor = 0 + return nil +} + +func (e *DDLJobRetriever) appendJobToChunk(req *chunk.Chunk, job *model.Job, checker privilege.Manager) { + schemaName := job.SchemaName + tableName := "" + finishTS := uint64(0) + if job.BinlogInfo != nil { + finishTS = job.BinlogInfo.FinishedTS + if job.BinlogInfo.TableInfo != nil { + tableName = job.BinlogInfo.TableInfo.Name.L + } + if job.BinlogInfo.MultipleTableInfos != nil { + tablenames := new(strings.Builder) + for i, affect := range job.BinlogInfo.MultipleTableInfos { + if i > 0 { + fmt.Fprintf(tablenames, ",") + } + fmt.Fprintf(tablenames, "%s", affect.Name.L) + } + tableName = tablenames.String() + } + if len(schemaName) == 0 && job.BinlogInfo.DBInfo != nil { + schemaName = job.BinlogInfo.DBInfo.Name.L + } + } + if len(tableName) == 0 { + tableName = job.TableName + } + // For compatibility, the old version of DDL Job wasn't store the schema name and table name. + if len(schemaName) == 0 { + schemaName = getSchemaName(e.is, job.SchemaID) + } + if len(tableName) == 0 { + tableName = getTableName(e.is, job.TableID) + } + + createTime := ts2Time(job.StartTS, e.TZLoc) + startTime := ts2Time(job.RealStartTS, e.TZLoc) + finishTime := ts2Time(finishTS, e.TZLoc) + + // Check the privilege. 
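// Jobs on schemas or tables the current user has no privilege on are silently skipped here
// rather than reported as an error row.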
+ if checker != nil && !checker.RequestVerification(e.activeRoles, strings.ToLower(schemaName), strings.ToLower(tableName), "", mysql.AllPrivMask) { + return + } + + req.AppendInt64(0, job.ID) + req.AppendString(1, schemaName) + req.AppendString(2, tableName) + req.AppendString(3, job.Type.String()+showAddIdxReorgTp(job)) + req.AppendString(4, job.SchemaState.String()) + req.AppendInt64(5, job.SchemaID) + req.AppendInt64(6, job.TableID) + req.AppendInt64(7, job.RowCount) + req.AppendTime(8, createTime) + if job.RealStartTS > 0 { + req.AppendTime(9, startTime) + } else { + req.AppendNull(9) + } + if finishTS > 0 { + req.AppendTime(10, finishTime) + } else { + req.AppendNull(10) + } + req.AppendString(11, job.State.String()) + if job.Type == model.ActionMultiSchemaChange { + isDistTask := job.ReorgMeta != nil && job.ReorgMeta.IsDistReorg + for _, subJob := range job.MultiSchemaInfo.SubJobs { + req.AppendInt64(0, job.ID) + req.AppendString(1, schemaName) + req.AppendString(2, tableName) + req.AppendString(3, subJob.Type.String()+" /* subjob */"+showAddIdxReorgTpInSubJob(subJob, isDistTask)) + req.AppendString(4, subJob.SchemaState.String()) + req.AppendInt64(5, job.SchemaID) + req.AppendInt64(6, job.TableID) + req.AppendInt64(7, subJob.RowCount) + req.AppendTime(8, createTime) + if subJob.RealStartTS > 0 { + realStartTS := ts2Time(subJob.RealStartTS, e.TZLoc) + req.AppendTime(9, realStartTS) + } else { + req.AppendNull(9) + } + if finishTS > 0 { + req.AppendTime(10, finishTime) + } else { + req.AppendNull(10) + } + req.AppendString(11, subJob.State.String()) + } + } +} + +func showAddIdxReorgTp(job *model.Job) string { + if job.Type == model.ActionAddIndex || job.Type == model.ActionAddPrimaryKey { + if job.ReorgMeta != nil { + sb := strings.Builder{} + tp := job.ReorgMeta.ReorgTp.String() + if len(tp) > 0 { + sb.WriteString(" /* ") + sb.WriteString(tp) + if job.ReorgMeta.ReorgTp == model.ReorgTypeLitMerge && + job.ReorgMeta.IsDistReorg && + job.ReorgMeta.UseCloudStorage { + sb.WriteString(" cloud") + } + sb.WriteString(" */") + } + return sb.String() + } + } + return "" +} + +func showAddIdxReorgTpInSubJob(subJob *model.SubJob, useDistTask bool) string { + if subJob.Type == model.ActionAddIndex || subJob.Type == model.ActionAddPrimaryKey { + sb := strings.Builder{} + tp := subJob.ReorgTp.String() + if len(tp) > 0 { + sb.WriteString(" /* ") + sb.WriteString(tp) + if subJob.ReorgTp == model.ReorgTypeLitMerge && useDistTask && subJob.UseCloud { + sb.WriteString(" cloud") + } + sb.WriteString(" */") + } + return sb.String() + } + return "" +} + +func ts2Time(timestamp uint64, loc *time.Location) types.Time { + duration := time.Duration(math.Pow10(9-types.DefaultFsp)) * time.Nanosecond + t := model.TSConvert2Time(timestamp) + t.Truncate(duration) + return types.NewTime(types.FromGoTime(t.In(loc)), mysql.TypeDatetime, types.MaxFsp) +} + +// ShowDDLJobQueriesExec represents a show DDL job queries executor. +// The jobs id that is given by 'admin show ddl job queries' statement, +// only be searched in the latest 10 history jobs. +type ShowDDLJobQueriesExec struct { + exec.BaseExecutor + + cursor int + jobs []*model.Job + jobIDs []int64 +} + +// Open implements the Executor Open interface. 
+func (e *ShowDDLJobQueriesExec) Open(ctx context.Context) error { + var err error + var jobs []*model.Job + if err := e.BaseExecutor.Open(ctx); err != nil { + return err + } + session, err := e.GetSysSession() + if err != nil { + return err + } + err = sessiontxn.NewTxn(context.Background(), session) + if err != nil { + return err + } + defer func() { + // ReleaseSysSession will rollbacks txn automatically. + e.ReleaseSysSession(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), session) + }() + txn, err := session.Txn(true) + if err != nil { + return err + } + session.GetSessionVars().SetInTxn(true) + + m := meta.NewMeta(txn) + jobs, err = ddl.GetAllDDLJobs(session) + if err != nil { + return err + } + + historyJobs, err := ddl.GetLastNHistoryDDLJobs(m, ddl.DefNumHistoryJobs) + if err != nil { + return err + } + + appendedJobID := make(map[int64]struct{}) + // deduplicate job results + // for situations when this operation happens at the same time with new DDLs being executed + for _, job := range jobs { + if _, ok := appendedJobID[job.ID]; !ok { + appendedJobID[job.ID] = struct{}{} + e.jobs = append(e.jobs, job) + } + } + for _, historyJob := range historyJobs { + if _, ok := appendedJobID[historyJob.ID]; !ok { + appendedJobID[historyJob.ID] = struct{}{} + e.jobs = append(e.jobs, historyJob) + } + } + + return nil +} + +// Next implements the Executor Next interface. +func (e *ShowDDLJobQueriesExec) Next(_ context.Context, req *chunk.Chunk) error { + req.GrowAndReset(e.MaxChunkSize()) + if e.cursor >= len(e.jobs) { + return nil + } + if len(e.jobIDs) >= len(e.jobs) { + return nil + } + numCurBatch := min(req.Capacity(), len(e.jobs)-e.cursor) + for _, id := range e.jobIDs { + for i := e.cursor; i < e.cursor+numCurBatch; i++ { + if id == e.jobs[i].ID { + req.AppendString(0, e.jobs[i].Query) + } + } + } + e.cursor += numCurBatch + return nil +} + +// ShowDDLJobQueriesWithRangeExec represents a show DDL job queries with range executor. +// The jobs id that is given by 'admin show ddl job queries' statement, +// can be searched within a specified range in history jobs using offset and limit. +type ShowDDLJobQueriesWithRangeExec struct { + exec.BaseExecutor + + cursor int + jobs []*model.Job + offset uint64 + limit uint64 +} + +// Open implements the Executor Open interface. +func (e *ShowDDLJobQueriesWithRangeExec) Open(ctx context.Context) error { + var err error + var jobs []*model.Job + if err := e.BaseExecutor.Open(ctx); err != nil { + return err + } + session, err := e.GetSysSession() + if err != nil { + return err + } + err = sessiontxn.NewTxn(context.Background(), session) + if err != nil { + return err + } + defer func() { + // ReleaseSysSession will rollbacks txn automatically. 
+ e.ReleaseSysSession(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), session) + }() + txn, err := session.Txn(true) + if err != nil { + return err + } + session.GetSessionVars().SetInTxn(true) + + m := meta.NewMeta(txn) + jobs, err = ddl.GetAllDDLJobs(session) + if err != nil { + return err + } + + historyJobs, err := ddl.GetLastNHistoryDDLJobs(m, int(e.offset+e.limit)) + if err != nil { + return err + } + + appendedJobID := make(map[int64]struct{}) + // deduplicate job results + // for situations when this operation happens at the same time with new DDLs being executed + for _, job := range jobs { + if _, ok := appendedJobID[job.ID]; !ok { + appendedJobID[job.ID] = struct{}{} + e.jobs = append(e.jobs, job) + } + } + for _, historyJob := range historyJobs { + if _, ok := appendedJobID[historyJob.ID]; !ok { + appendedJobID[historyJob.ID] = struct{}{} + e.jobs = append(e.jobs, historyJob) + } + } + + if e.cursor < int(e.offset) { + e.cursor = int(e.offset) + } + + return nil +} + +// Next implements the Executor Next interface. +func (e *ShowDDLJobQueriesWithRangeExec) Next(_ context.Context, req *chunk.Chunk) error { + req.GrowAndReset(e.MaxChunkSize()) + if e.cursor >= len(e.jobs) { + return nil + } + if int(e.offset) > len(e.jobs) { + return nil + } + numCurBatch := min(req.Capacity(), len(e.jobs)-e.cursor) + for i := e.cursor; i < e.cursor+numCurBatch; i++ { + // i is make true to be >= int(e.offset) + if i >= int(e.offset+e.limit) { + break + } + req.AppendString(0, strconv.FormatInt(e.jobs[i].ID, 10)) + req.AppendString(1, e.jobs[i].Query) + } + e.cursor += numCurBatch + return nil +} + +// Open implements the Executor Open interface. +func (e *ShowDDLJobsExec) Open(ctx context.Context) error { + if err := e.BaseExecutor.Open(ctx); err != nil { + return err + } + e.DDLJobRetriever.is = e.is + if e.jobNumber == 0 { + e.jobNumber = ddl.DefNumHistoryJobs + } + sess, err := e.GetSysSession() + if err != nil { + return err + } + e.sess = sess + err = sessiontxn.NewTxn(context.Background(), sess) + if err != nil { + return err + } + txn, err := sess.Txn(true) + if err != nil { + return err + } + sess.GetSessionVars().SetInTxn(true) + err = e.DDLJobRetriever.initial(txn, sess) + return err +} + +// Next implements the Executor Next interface. +func (e *ShowDDLJobsExec) Next(_ context.Context, req *chunk.Chunk) error { + req.GrowAndReset(e.MaxChunkSize()) + if (e.cursor - len(e.runningJobs)) >= e.jobNumber { + return nil + } + count := 0 + + // Append running ddl jobs. + if e.cursor < len(e.runningJobs) { + numCurBatch := min(req.Capacity(), len(e.runningJobs)-e.cursor) + for i := e.cursor; i < e.cursor+numCurBatch; i++ { + e.appendJobToChunk(req, e.runningJobs[i], nil) + } + e.cursor += numCurBatch + count += numCurBatch + } + + // Append history ddl jobs. + var err error + if count < req.Capacity() { + num := req.Capacity() - count + remainNum := e.jobNumber - (e.cursor - len(e.runningJobs)) + num = min(num, remainNum) + e.cacheJobs, err = e.historyJobIter.GetLastJobs(num, e.cacheJobs) + if err != nil { + return err + } + for _, job := range e.cacheJobs { + e.appendJobToChunk(req, job, nil) + } + e.cursor += len(e.cacheJobs) + } + return nil +} + +// Close implements the Executor Close interface. 
+func (e *ShowDDLJobsExec) Close() error { + e.ReleaseSysSession(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), e.sess) + return e.BaseExecutor.Close() +} + +func getSchemaName(is infoschema.InfoSchema, id int64) string { + var schemaName string + dbInfo, ok := is.SchemaByID(id) + if ok { + schemaName = dbInfo.Name.O + return schemaName + } + + return schemaName +} + +func getTableName(is infoschema.InfoSchema, id int64) string { + var tableName string + table, ok := is.TableByID(id) + if ok { + tableName = table.Meta().Name.O + return tableName + } + + return tableName +} + +// CheckTableExec represents a check table executor. +// It is built from the "admin check table" statement, and it checks if the +// index matches the records in the table. +type CheckTableExec struct { + exec.BaseExecutor + + dbName string + table table.Table + indexInfos []*model.IndexInfo + srcs []*IndexLookUpExecutor + done bool + is infoschema.InfoSchema + exitCh chan struct{} + retCh chan error + checkIndex bool +} + +// Open implements the Executor Open interface. +func (e *CheckTableExec) Open(ctx context.Context) error { + if err := e.BaseExecutor.Open(ctx); err != nil { + return err + } + for _, src := range e.srcs { + if err := exec.Open(ctx, src); err != nil { + return errors.Trace(err) + } + } + e.done = false + return nil +} + +// Close implements the Executor Close interface. +func (e *CheckTableExec) Close() error { + var firstErr error + close(e.exitCh) + for _, src := range e.srcs { + if err := exec.Close(src); err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr +} + +func (e *CheckTableExec) checkTableIndexHandle(ctx context.Context, idxInfo *model.IndexInfo) error { + // For partition table, there will be multi same index indexLookUpReaders on different partitions. + for _, src := range e.srcs { + if src.index.Name.L == idxInfo.Name.L { + err := e.checkIndexHandle(ctx, src) + if err != nil { + return err + } + } + } + return nil +} + +func (e *CheckTableExec) checkIndexHandle(ctx context.Context, src *IndexLookUpExecutor) error { + cols := src.Schema().Columns + retFieldTypes := make([]*types.FieldType, len(cols)) + for i := range cols { + retFieldTypes[i] = cols[i].RetType + } + chk := chunk.New(retFieldTypes, e.InitCap(), e.MaxChunkSize()) + + var err error + for { + err = exec.Next(ctx, src, chk) + if err != nil { + e.retCh <- errors.Trace(err) + break + } + if chk.NumRows() == 0 { + break + } + } + return errors.Trace(err) +} + +func (e *CheckTableExec) handlePanic(r any) { + if r != nil { + e.retCh <- errors.Errorf("%v", r) + } +} + +// Next implements the Executor Next interface. +func (e *CheckTableExec) Next(ctx context.Context, _ *chunk.Chunk) error { + if e.done || len(e.srcs) == 0 { + return nil + } + defer func() { e.done = true }() + + idxNames := make([]string, 0, len(e.indexInfos)) + for _, idx := range e.indexInfos { + if idx.MVIndex { + continue + } + idxNames = append(idxNames, idx.Name.O) + } + greater, idxOffset, err := admin.CheckIndicesCount(e.Ctx(), e.dbName, e.table.Meta().Name.O, idxNames) + if err != nil { + // For admin check index statement, for speed up and compatibility, doesn't do below checks. 
+ if e.checkIndex { + return errors.Trace(err) + } + if greater == admin.IdxCntGreater { + err = e.checkTableIndexHandle(ctx, e.indexInfos[idxOffset]) + } else if greater == admin.TblCntGreater { + err = e.checkTableRecord(ctx, idxOffset) + } + return errors.Trace(err) + } + + // The number of table rows is equal to the number of index rows. + // TODO: Make the value of concurrency adjustable. And we can consider the number of records. + if len(e.srcs) == 1 { + err = e.checkIndexHandle(ctx, e.srcs[0]) + if err == nil && e.srcs[0].index.MVIndex { + err = e.checkTableRecord(ctx, 0) + } + if err != nil { + return err + } + } + taskCh := make(chan *IndexLookUpExecutor, len(e.srcs)) + failure := atomicutil.NewBool(false) + concurrency := min(3, len(e.srcs)) + var wg util.WaitGroupWrapper + for _, src := range e.srcs { + taskCh <- src + } + for i := 0; i < concurrency; i++ { + wg.Run(func() { + util.WithRecovery(func() { + for { + if fail := failure.Load(); fail { + return + } + select { + case src := <-taskCh: + err1 := e.checkIndexHandle(ctx, src) + if err1 == nil && src.index.MVIndex { + for offset, idx := range e.indexInfos { + if idx.ID == src.index.ID { + err1 = e.checkTableRecord(ctx, offset) + break + } + } + } + if err1 != nil { + failure.Store(true) + logutil.Logger(ctx).Info("check index handle failed", zap.Error(err1)) + return + } + case <-e.exitCh: + return + default: + return + } + } + }, e.handlePanic) + }) + } + wg.Wait() + select { + case err := <-e.retCh: + return errors.Trace(err) + default: + return nil + } +} + +func (e *CheckTableExec) checkTableRecord(ctx context.Context, idxOffset int) error { + idxInfo := e.indexInfos[idxOffset] + txn, err := e.Ctx().Txn(true) + if err != nil { + return err + } + if e.table.Meta().GetPartitionInfo() == nil { + idx := tables.NewIndex(e.table.Meta().ID, e.table.Meta(), idxInfo) + return admin.CheckRecordAndIndex(ctx, e.Ctx(), txn, e.table, idx) + } + + info := e.table.Meta().GetPartitionInfo() + for _, def := range info.Definitions { + pid := def.ID + partition := e.table.(table.PartitionedTable).GetPartition(pid) + idx := tables.NewIndex(def.ID, e.table.Meta(), idxInfo) + if err := admin.CheckRecordAndIndex(ctx, e.Ctx(), txn, partition, idx); err != nil { + return errors.Trace(err) + } + } + return nil +} + +// ShowSlowExec represents the executor of showing the slow queries. +// It is build from the "admin show slow" statement: +// +// admin show slow top [internal | all] N +// admin show slow recent N +type ShowSlowExec struct { + exec.BaseExecutor + + ShowSlow *ast.ShowSlow + result []*domain.SlowQueryInfo + cursor int +} + +// Open implements the Executor Open interface. +func (e *ShowSlowExec) Open(ctx context.Context) error { + if err := e.BaseExecutor.Open(ctx); err != nil { + return err + } + + dom := domain.GetDomain(e.Ctx()) + e.result = dom.ShowSlowQuery(e.ShowSlow) + return nil +} + +// Next implements the Executor Next interface. 
+func (e *ShowSlowExec) Next(_ context.Context, req *chunk.Chunk) error { + req.Reset() + if e.cursor >= len(e.result) { + return nil + } + + for e.cursor < len(e.result) && req.NumRows() < e.MaxChunkSize() { + slow := e.result[e.cursor] + req.AppendString(0, slow.SQL) + req.AppendTime(1, types.NewTime(types.FromGoTime(slow.Start), mysql.TypeTimestamp, types.MaxFsp)) + req.AppendDuration(2, types.Duration{Duration: slow.Duration, Fsp: types.MaxFsp}) + req.AppendString(3, slow.Detail.String()) + if slow.Succ { + req.AppendInt64(4, 1) + } else { + req.AppendInt64(4, 0) + } + req.AppendUint64(5, slow.ConnID) + req.AppendUint64(6, slow.TxnTS) + req.AppendString(7, slow.User) + req.AppendString(8, slow.DB) + req.AppendString(9, slow.TableIDs) + req.AppendString(10, slow.IndexNames) + if slow.Internal { + req.AppendInt64(11, 1) + } else { + req.AppendInt64(11, 0) + } + req.AppendString(12, slow.Digest) + req.AppendString(13, slow.SessAlias) + e.cursor++ + } + return nil +} + +// SelectLockExec represents a select lock executor. +// It is built from the "SELECT .. FOR UPDATE" or the "SELECT .. LOCK IN SHARE MODE" statement. +// For "SELECT .. FOR UPDATE" statement, it locks every row key from source Executor. +// After the execution, the keys are buffered in transaction, and will be sent to KV +// when doing commit. If there is any key already locked by another transaction, +// the transaction will rollback and retry. +type SelectLockExec struct { + exec.BaseExecutor + + Lock *ast.SelectLockInfo + keys []kv.Key + + // The children may be a join of multiple tables, so we need a map. + tblID2Handle map[int64][]plannerutil.HandleCols + + // When SelectLock work on a partition table, we need the partition ID + // (Physical Table ID) instead of the 'logical' table ID to calculate + // the lock KV. In that case, the Physical Table ID is extracted + // from the row key in the store and as an extra column in the chunk row. + + // tblID2PhyTblIDCol is used for partitioned tables. + // The child executor need to return an extra column containing + // the Physical Table ID (i.e. from which partition the row came from) + // Used during building + tblID2PhysTblIDCol map[int64]*expression.Column + + // Used during execution + // Map from logic tableID to column index where the physical table id is stored + // For dynamic prune mode, model.ExtraPhysTblID columns are requested from + // storage and used for physical table id + // For static prune mode, model.ExtraPhysTblID is still sent to storage/Protobuf + // but could be filled in by the partitions TableReaderExecutor + // due to issues with chunk handling between the TableReaderExecutor and the + // SelectReader result. + tblID2PhysTblIDColIdx map[int64]int +} + +// Open implements the Executor Open interface. +func (e *SelectLockExec) Open(ctx context.Context) error { + if len(e.tblID2PhysTblIDCol) > 0 { + e.tblID2PhysTblIDColIdx = make(map[int64]int) + cols := e.Schema().Columns + for i := len(cols) - 1; i >= 0; i-- { + if cols[i].ID == model.ExtraPhysTblID { + for tblID, col := range e.tblID2PhysTblIDCol { + if cols[i].UniqueID == col.UniqueID { + e.tblID2PhysTblIDColIdx[tblID] = i + break + } + } + } + } + } + return e.BaseExecutor.Open(ctx) +} + +// Next implements the Executor Next interface. +func (e *SelectLockExec) Next(ctx context.Context, req *chunk.Chunk) error { + req.GrowAndReset(e.MaxChunkSize()) + err := exec.Next(ctx, e.Children(0), req) + if err != nil { + return err + } + // If there's no handle or it's not a `SELECT FOR UPDATE` statement. 
+ if len(e.tblID2Handle) == 0 || (!logicalop.IsSelectForUpdateLockType(e.Lock.LockType)) { + return nil + } + + if req.NumRows() > 0 { + iter := chunk.NewIterator4Chunk(req) + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + for tblID, cols := range e.tblID2Handle { + for _, col := range cols { + handle, err := col.BuildHandle(row) + if err != nil { + return err + } + physTblID := tblID + if physTblColIdx, ok := e.tblID2PhysTblIDColIdx[tblID]; ok { + physTblID = row.GetInt64(physTblColIdx) + if physTblID == 0 { + // select * from t1 left join t2 on t1.c = t2.c for update + // The join right side might be added NULL in left join + // In that case, physTblID is 0, so skip adding the lock. + // + // Note, we can't distinguish whether it's the left join case, + // or a bug that TiKV return without correct physical ID column. + continue + } + } + e.keys = append(e.keys, tablecodec.EncodeRowKeyWithHandle(physTblID, handle)) + } + } + } + return nil + } + lockWaitTime := e.Ctx().GetSessionVars().LockWaitTimeout + if e.Lock.LockType == ast.SelectLockForUpdateNoWait { + lockWaitTime = tikvstore.LockNoWait + } else if e.Lock.LockType == ast.SelectLockForUpdateWaitN { + lockWaitTime = int64(e.Lock.WaitSec) * 1000 + } + + for id := range e.tblID2Handle { + e.UpdateDeltaForTableID(id) + } + lockCtx, err := newLockCtx(e.Ctx(), lockWaitTime, len(e.keys)) + if err != nil { + return err + } + return doLockKeys(ctx, e.Ctx(), lockCtx, e.keys...) +} + +func newLockCtx(sctx sessionctx.Context, lockWaitTime int64, numKeys int) (*tikvstore.LockCtx, error) { + seVars := sctx.GetSessionVars() + forUpdateTS, err := sessiontxn.GetTxnManager(sctx).GetStmtForUpdateTS() + if err != nil { + return nil, err + } + lockCtx := tikvstore.NewLockCtx(forUpdateTS, lockWaitTime, seVars.StmtCtx.GetLockWaitStartTime()) + lockCtx.Killed = &seVars.SQLKiller.Signal + lockCtx.PessimisticLockWaited = &seVars.StmtCtx.PessimisticLockWaited + lockCtx.LockKeysDuration = &seVars.StmtCtx.LockKeysDuration + lockCtx.LockKeysCount = &seVars.StmtCtx.LockKeysCount + lockCtx.LockExpired = &seVars.TxnCtx.LockExpire + lockCtx.ResourceGroupTagger = func(req *kvrpcpb.PessimisticLockRequest) []byte { + if req == nil { + return nil + } + if len(req.Mutations) == 0 { + return nil + } + if mutation := req.Mutations[0]; mutation != nil { + label := resourcegrouptag.GetResourceGroupLabelByKey(mutation.Key) + normalized, digest := seVars.StmtCtx.SQLDigest() + if len(normalized) == 0 { + return nil + } + _, planDigest := seVars.StmtCtx.GetPlanDigest() + return resourcegrouptag.EncodeResourceGroupTag(digest, planDigest, label) + } + return nil + } + lockCtx.OnDeadlock = func(deadlock *tikverr.ErrDeadlock) { + cfg := config.GetGlobalConfig() + if deadlock.IsRetryable && !cfg.PessimisticTxn.DeadlockHistoryCollectRetryable { + return + } + rec := deadlockhistory.ErrDeadlockToDeadlockRecord(deadlock) + deadlockhistory.GlobalDeadlockHistory.Push(rec) + } + if lockCtx.ForUpdateTS > 0 && seVars.AssertionLevel != variable.AssertionLevelOff { + lockCtx.InitCheckExistence(numKeys) + } + return lockCtx, nil +} + +// doLockKeys is the main entry for pessimistic lock keys +// waitTime means the lock operation will wait in milliseconds if target key is already +// locked by others. 
used for (select for update nowait) situation +func doLockKeys(ctx context.Context, se sessionctx.Context, lockCtx *tikvstore.LockCtx, keys ...kv.Key) error { + sessVars := se.GetSessionVars() + sctx := sessVars.StmtCtx + if !sctx.InUpdateStmt && !sctx.InDeleteStmt { + atomic.StoreUint32(&se.GetSessionVars().TxnCtx.ForUpdate, 1) + } + // Lock keys only once when finished fetching all results. + txn, err := se.Txn(true) + if err != nil { + return err + } + + // Skip the temporary table keys. + keys = filterTemporaryTableKeys(sessVars, keys) + + keys = filterLockTableKeys(sessVars.StmtCtx, keys) + var lockKeyStats *tikvutil.LockKeysDetails + ctx = context.WithValue(ctx, tikvutil.LockKeysDetailCtxKey, &lockKeyStats) + err = txn.LockKeys(tikvutil.SetSessionID(ctx, se.GetSessionVars().ConnectionID), lockCtx, keys...) + if lockKeyStats != nil { + sctx.MergeLockKeysExecDetails(lockKeyStats) + } + return err +} + +func filterTemporaryTableKeys(vars *variable.SessionVars, keys []kv.Key) []kv.Key { + txnCtx := vars.TxnCtx + if txnCtx == nil || txnCtx.TemporaryTables == nil { + return keys + } + + newKeys := keys[:0:len(keys)] + for _, key := range keys { + tblID := tablecodec.DecodeTableID(key) + if _, ok := txnCtx.TemporaryTables[tblID]; !ok { + newKeys = append(newKeys, key) + } + } + return newKeys +} + +func filterLockTableKeys(stmtCtx *stmtctx.StatementContext, keys []kv.Key) []kv.Key { + if len(stmtCtx.LockTableIDs) == 0 { + return keys + } + newKeys := keys[:0:len(keys)] + for _, key := range keys { + tblID := tablecodec.DecodeTableID(key) + if _, ok := stmtCtx.LockTableIDs[tblID]; ok { + newKeys = append(newKeys, key) + } + } + return newKeys +} + +// LimitExec represents limit executor +// It ignores 'Offset' rows from src, then returns 'Count' rows at maximum. +type LimitExec struct { + exec.BaseExecutor + + begin uint64 + end uint64 + cursor uint64 + + // meetFirstBatch represents whether we have met the first valid Chunk from child. + meetFirstBatch bool + + childResult *chunk.Chunk + + // columnIdxsUsedByChild keep column indexes of child executor used for inline projection + columnIdxsUsedByChild []int + + // Log the close time when opentracing is enabled. + span opentracing.Span +} + +// Next implements the Executor Next interface. +func (e *LimitExec) Next(ctx context.Context, req *chunk.Chunk) error { + req.Reset() + if e.cursor >= e.end { + return nil + } + for !e.meetFirstBatch { + // transfer req's requiredRows to childResult and then adjust it in childResult + e.childResult = e.childResult.SetRequiredRows(req.RequiredRows(), e.MaxChunkSize()) + err := exec.Next(ctx, e.Children(0), e.adjustRequiredRows(e.childResult)) + if err != nil { + return err + } + batchSize := uint64(e.childResult.NumRows()) + // no more data. 
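// --- Illustrative sketch, not part of the patch ---
// filterTemporaryTableKeys and filterLockTableKeys above reuse the original
// backing array via a zero-length, full-capacity reslice (keys[:0:len(keys)]),
// so dropping keys allocates nothing. A minimal standalone version of the
// same idiom (the predicate here is hypothetical):
package main

import "fmt"

// filterInPlace keeps the elements for which keep returns true, reusing the
// input slice's backing array; the input must not be reused afterwards.
func filterInPlace(keys [][]byte, keep func([]byte) bool) [][]byte {
	out := keys[:0:len(keys)]
	for _, k := range keys {
		if keep(k) {
			out = append(out, k)
		}
	}
	return out
}

func main() {
	keys := [][]byte{[]byte("t1_r1"), []byte("t2_r1"), []byte("t1_r2")}
	kept := filterInPlace(keys, func(k []byte) bool { return k[1] == '1' })
	fmt.Println(len(kept)) // 2
}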
+ if batchSize == 0 { + return nil + } + if newCursor := e.cursor + batchSize; newCursor >= e.begin { + e.meetFirstBatch = true + begin, end := e.begin-e.cursor, batchSize + if newCursor > e.end { + end = e.end - e.cursor + } + e.cursor += end + if begin == end { + break + } + if e.columnIdxsUsedByChild != nil { + req.Append(e.childResult.Prune(e.columnIdxsUsedByChild), int(begin), int(end)) + } else { + req.Append(e.childResult, int(begin), int(end)) + } + return nil + } + e.cursor += batchSize + } + e.childResult.Reset() + e.childResult = e.childResult.SetRequiredRows(req.RequiredRows(), e.MaxChunkSize()) + e.adjustRequiredRows(e.childResult) + err := exec.Next(ctx, e.Children(0), e.childResult) + if err != nil { + return err + } + batchSize := uint64(e.childResult.NumRows()) + // no more data. + if batchSize == 0 { + return nil + } + if e.cursor+batchSize > e.end { + e.childResult.TruncateTo(int(e.end - e.cursor)) + batchSize = e.end - e.cursor + } + e.cursor += batchSize + + if e.columnIdxsUsedByChild != nil { + for i, childIdx := range e.columnIdxsUsedByChild { + if err = req.SwapColumn(i, e.childResult, childIdx); err != nil { + return err + } + } + } else { + req.SwapColumns(e.childResult) + } + return nil +} + +// Open implements the Executor Open interface. +func (e *LimitExec) Open(ctx context.Context) error { + if err := e.BaseExecutor.Open(ctx); err != nil { + return err + } + e.childResult = exec.TryNewCacheChunk(e.Children(0)) + e.cursor = 0 + e.meetFirstBatch = e.begin == 0 + if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { + e.span = span + } + return nil +} + +// Close implements the Executor Close interface. +func (e *LimitExec) Close() error { + start := time.Now() + + e.childResult = nil + err := e.BaseExecutor.Close() + + elapsed := time.Since(start) + if elapsed > time.Millisecond { + logutil.BgLogger().Info("limit executor close takes a long time", + zap.Duration("elapsed", elapsed)) + if e.span != nil { + span1 := e.span.Tracer().StartSpan("limitExec.Close", opentracing.ChildOf(e.span.Context()), opentracing.StartTime(start)) + defer span1.Finish() + } + } + return err +} + +func (e *LimitExec) adjustRequiredRows(chk *chunk.Chunk) *chunk.Chunk { + // the limit of maximum number of rows the LimitExec should read + limitTotal := int(e.end - e.cursor) + + var limitRequired int + if e.cursor < e.begin { + // if cursor is less than begin, it have to read (begin-cursor) rows to ignore + // and then read chk.RequiredRows() rows to return, + // so the limit is (begin-cursor)+chk.RequiredRows(). + limitRequired = int(e.begin) - int(e.cursor) + chk.RequiredRows() + } else { + // if cursor is equal or larger than begin, just read chk.RequiredRows() rows to return. + limitRequired = chk.RequiredRows() + } + + return chk.SetRequiredRows(min(limitTotal, limitRequired), e.MaxChunkSize()) +} + +func init() { + // While doing optimization in the plan package, we need to execute uncorrelated subquery, + // but the plan package cannot import the executor package because of the dependency cycle. + // So we assign a function implemented in the executor package to the plan package to avoid the dependency cycle. 
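// --- Illustrative sketch, not part of the patch ---
// LimitExec above maintains a [begin, end) row window over the child's stream:
// for each child batch it either skips the batch entirely or copies the slice
// [begin-cursor, min(end, cursor+batchSize)-cursor) of it. The helper below
// mirrors only that per-batch arithmetic; names are hypothetical.
package main

import "fmt"

// limitWindow returns the half-open range of rows to keep from a batch of
// batchSize rows, given the rows already consumed (cursor) and the LIMIT
// window [begin, end). ok is false when the whole batch is skipped.
func limitWindow(cursor, batchSize, begin, end uint64) (lo, hi uint64, ok bool) {
	if cursor >= end || cursor+batchSize <= begin {
		return 0, 0, false
	}
	if begin > cursor {
		lo = begin - cursor
	}
	hi = batchSize
	if cursor+batchSize > end {
		hi = end - cursor
	}
	return lo, hi, true
}

func main() {
	// LIMIT 3 OFFSET 2 => begin=2, end=5; child batches of 4 rows.
	fmt.Println(limitWindow(0, 4, 2, 5)) // 2 4 true  (rows 2,3)
	fmt.Println(limitWindow(4, 4, 2, 5)) // 0 1 true  (row 4)
	fmt.Println(limitWindow(8, 4, 2, 5)) // 0 0 false (past the window)
}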
+ plannercore.EvalSubqueryFirstRow = func(ctx context.Context, p base.PhysicalPlan, is infoschema.InfoSchema, pctx planctx.PlanContext) ([]types.Datum, error) { + if fixcontrol.GetBoolWithDefault(pctx.GetSessionVars().OptimizerFixControl, fixcontrol.Fix43817, false) { + return nil, errors.NewNoStackError("evaluate non-correlated sub-queries during optimization phase is not allowed by fix-control 43817") + } + + defer func(begin time.Time) { + s := pctx.GetSessionVars() + s.StmtCtx.SetSkipPlanCache("query has uncorrelated sub-queries is un-cacheable") + s.RewritePhaseInfo.PreprocessSubQueries++ + s.RewritePhaseInfo.DurationPreprocessSubQuery += time.Since(begin) + }(time.Now()) + + r, ctx := tracing.StartRegionEx(ctx, "executor.EvalSubQuery") + defer r.End() + + sctx, err := plannercore.AsSctx(pctx) + intest.AssertNoError(err) + if err != nil { + return nil, err + } + + e := newExecutorBuilder(sctx, is) + executor := e.build(p) + if e.err != nil { + return nil, e.err + } + err = exec.Open(ctx, executor) + defer func() { terror.Log(exec.Close(executor)) }() + if err != nil { + return nil, err + } + if pi, ok := sctx.(processinfoSetter); ok { + // Before executing the sub-query, we need update the processinfo to make the progress bar more accurate. + // because the sub-query may take a long time. + pi.UpdateProcessInfo() + } + chk := exec.TryNewCacheChunk(executor) + err = exec.Next(ctx, executor, chk) + if err != nil { + return nil, err + } + if chk.NumRows() == 0 { + return nil, nil + } + row := chk.GetRow(0).GetDatumRow(exec.RetTypes(executor)) + return row, err + } +} + +// TableDualExec represents a dual table executor. +type TableDualExec struct { + exec.BaseExecutorV2 + + // numDualRows can only be 0 or 1. + numDualRows int + numReturned int +} + +// Open implements the Executor Open interface. +func (e *TableDualExec) Open(context.Context) error { + e.numReturned = 0 + return nil +} + +// Next implements the Executor Next interface. +func (e *TableDualExec) Next(_ context.Context, req *chunk.Chunk) error { + req.Reset() + if e.numReturned >= e.numDualRows { + return nil + } + if e.Schema().Len() == 0 { + req.SetNumVirtualRows(1) + } else { + for i := range e.Schema().Columns { + req.AppendNull(i) + } + } + e.numReturned = e.numDualRows + return nil +} + +type selectionExecutorContext struct { + stmtMemTracker *memory.Tracker + evalCtx expression.EvalContext + enableVectorizedExpression bool +} + +func newSelectionExecutorContext(sctx sessionctx.Context) selectionExecutorContext { + return selectionExecutorContext{ + stmtMemTracker: sctx.GetSessionVars().StmtCtx.MemTracker, + evalCtx: sctx.GetExprCtx().GetEvalCtx(), + enableVectorizedExpression: sctx.GetSessionVars().EnableVectorizedExpression, + } +} + +// SelectionExec represents a filter executor. +type SelectionExec struct { + selectionExecutorContext + exec.BaseExecutorV2 + + batched bool + filters []expression.Expression + selected []bool + inputIter *chunk.Iterator4Chunk + inputRow chunk.Row + childResult *chunk.Chunk + + memTracker *memory.Tracker +} + +// Open implements the Executor Open interface. 
+func (e *SelectionExec) Open(ctx context.Context) error { + if err := e.BaseExecutorV2.Open(ctx); err != nil { + return err + } + failpoint.Inject("mockSelectionExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(errors.New("mock SelectionExec.baseExecutor.Open returned error")) + } + }) + return e.open(ctx) +} + +func (e *SelectionExec) open(context.Context) error { + if e.memTracker != nil { + e.memTracker.Reset() + } else { + e.memTracker = memory.NewTracker(e.ID(), -1) + } + e.memTracker.AttachTo(e.stmtMemTracker) + e.childResult = exec.TryNewCacheChunk(e.Children(0)) + e.memTracker.Consume(e.childResult.MemoryUsage()) + e.batched = expression.Vectorizable(e.filters) + if e.batched { + e.selected = make([]bool, 0, chunk.InitialCapacity) + } + e.inputIter = chunk.NewIterator4Chunk(e.childResult) + e.inputRow = e.inputIter.End() + return nil +} + +// Close implements plannercore.Plan Close interface. +func (e *SelectionExec) Close() error { + if e.childResult != nil { + e.memTracker.Consume(-e.childResult.MemoryUsage()) + e.childResult = nil + } + e.selected = nil + return e.BaseExecutorV2.Close() +} + +// Next implements the Executor Next interface. +func (e *SelectionExec) Next(ctx context.Context, req *chunk.Chunk) error { + req.GrowAndReset(e.MaxChunkSize()) + + if !e.batched { + return e.unBatchedNext(ctx, req) + } + + for { + for ; e.inputRow != e.inputIter.End(); e.inputRow = e.inputIter.Next() { + if req.IsFull() { + return nil + } + + if !e.selected[e.inputRow.Idx()] { + continue + } + + req.AppendRow(e.inputRow) + } + mSize := e.childResult.MemoryUsage() + err := exec.Next(ctx, e.Children(0), e.childResult) + e.memTracker.Consume(e.childResult.MemoryUsage() - mSize) + if err != nil { + return err + } + // no more data. + if e.childResult.NumRows() == 0 { + return nil + } + e.selected, err = expression.VectorizedFilter(e.evalCtx, e.enableVectorizedExpression, e.filters, e.inputIter, e.selected) + if err != nil { + return err + } + e.inputRow = e.inputIter.Begin() + } +} + +// unBatchedNext filters input rows one by one and returns once an input row is selected. +// For sql with "SETVAR" in filter and "GETVAR" in projection, for example: "SELECT @a FROM t WHERE (@a := 2) > 0", +// we have to set batch size to 1 to do the evaluation of filter and projection. +func (e *SelectionExec) unBatchedNext(ctx context.Context, chk *chunk.Chunk) error { + evalCtx := e.evalCtx + for { + for ; e.inputRow != e.inputIter.End(); e.inputRow = e.inputIter.Next() { + selected, _, err := expression.EvalBool(evalCtx, e.filters, e.inputRow) + if err != nil { + return err + } + if selected { + chk.AppendRow(e.inputRow) + e.inputRow = e.inputIter.Next() + return nil + } + } + mSize := e.childResult.MemoryUsage() + err := exec.Next(ctx, e.Children(0), e.childResult) + e.memTracker.Consume(e.childResult.MemoryUsage() - mSize) + if err != nil { + return err + } + e.inputRow = e.inputIter.Begin() + // no more data. + if e.childResult.NumRows() == 0 { + return nil + } + } +} + +// TableScanExec is a table scan executor without result fields. +type TableScanExec struct { + exec.BaseExecutor + + t table.Table + columns []*model.ColumnInfo + virtualTableChunkList *chunk.List + virtualTableChunkIdx int +} + +// Next implements the Executor Next interface. 
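// --- Illustrative sketch, not part of the patch ---
// When the filters are vectorizable, SelectionExec above evaluates them once
// per child chunk into a []bool selection vector and copies only the selected
// rows into the output chunk. A minimal standalone version of that pattern
// over plain int rows (chunk types and the predicate are hypothetical):
package main

import "fmt"

// applySelection copies the rows whose selection flag is true.
func applySelection(rows []int, selected []bool) []int {
	out := make([]int, 0, len(rows))
	for i, row := range rows {
		if selected[i] {
			out = append(out, row)
		}
	}
	return out
}

func main() {
	rows := []int{10, 20, 30, 40}
	selected := []bool{true, false, true, false} // e.g. "row % 20 != 0"
	fmt.Println(applySelection(rows, selected))  // [10 30]
}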
+func (e *TableScanExec) Next(ctx context.Context, req *chunk.Chunk) error { + req.GrowAndReset(e.MaxChunkSize()) + return e.nextChunk4InfoSchema(ctx, req) +} + +func (e *TableScanExec) nextChunk4InfoSchema(ctx context.Context, chk *chunk.Chunk) error { + chk.GrowAndReset(e.MaxChunkSize()) + if e.virtualTableChunkList == nil { + e.virtualTableChunkList = chunk.NewList(exec.RetTypes(e), e.InitCap(), e.MaxChunkSize()) + columns := make([]*table.Column, e.Schema().Len()) + for i, colInfo := range e.columns { + columns[i] = table.ToColumn(colInfo) + } + mutableRow := chunk.MutRowFromTypes(exec.RetTypes(e)) + type tableIter interface { + IterRecords(ctx context.Context, sctx sessionctx.Context, cols []*table.Column, fn table.RecordIterFunc) error + } + err := (e.t.(tableIter)).IterRecords(ctx, e.Ctx(), columns, func(_ kv.Handle, rec []types.Datum, _ []*table.Column) (bool, error) { + mutableRow.SetDatums(rec...) + e.virtualTableChunkList.AppendRow(mutableRow.ToRow()) + return true, nil + }) + if err != nil { + return err + } + } + // no more data. + if e.virtualTableChunkIdx >= e.virtualTableChunkList.NumChunks() { + return nil + } + virtualTableChunk := e.virtualTableChunkList.GetChunk(e.virtualTableChunkIdx) + e.virtualTableChunkIdx++ + chk.SwapColumns(virtualTableChunk) + return nil +} + +// Open implements the Executor Open interface. +func (e *TableScanExec) Open(context.Context) error { + e.virtualTableChunkList = nil + return nil +} + +// MaxOneRowExec checks if the number of rows that a query returns is at maximum one. +// It's built from subquery expression. +type MaxOneRowExec struct { + exec.BaseExecutor + + evaluated bool +} + +// Open implements the Executor Open interface. +func (e *MaxOneRowExec) Open(ctx context.Context) error { + if err := e.BaseExecutor.Open(ctx); err != nil { + return err + } + e.evaluated = false + return nil +} + +// Next implements the Executor Next interface. +func (e *MaxOneRowExec) Next(ctx context.Context, req *chunk.Chunk) error { + req.Reset() + if e.evaluated { + return nil + } + e.evaluated = true + err := exec.Next(ctx, e.Children(0), req) + if err != nil { + return err + } + + if num := req.NumRows(); num == 0 { + for i := range e.Schema().Columns { + req.AppendNull(i) + } + return nil + } else if num != 1 { + return exeerrors.ErrSubqueryMoreThan1Row + } + + childChunk := exec.TryNewCacheChunk(e.Children(0)) + err = exec.Next(ctx, e.Children(0), childChunk) + if err != nil { + return err + } + if childChunk.NumRows() != 0 { + return exeerrors.ErrSubqueryMoreThan1Row + } + + return nil +} + +// ResetContextOfStmt resets the StmtContext and session variables. +// Before every execution, we must clear statement context. 
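// --- Illustrative sketch, not part of the patch ---
// MaxOneRowExec above enforces the scalar-subquery contract: zero child rows
// become a single all-NULL row, exactly one row passes through, and a second
// row raises ErrSubqueryMoreThan1Row. A standalone sketch of the same rule,
// using *int as a stand-in for a nullable value:
package main

import (
	"errors"
	"fmt"
)

var errMoreThanOneRow = errors.New("subquery returns more than 1 row")

// maxOneRow returns the single row, nil (NULL) for an empty input, or an
// error when the input contains more than one row.
func maxOneRow(rows []int) (*int, error) {
	switch len(rows) {
	case 0:
		return nil, nil // NULL row
	case 1:
		return &rows[0], nil
	default:
		return nil, errMoreThanOneRow
	}
}

func main() {
	v, _ := maxOneRow([]int{42})
	fmt.Println(*v) // 42
	_, err := maxOneRow([]int{1, 2})
	fmt.Println(err) // subquery returns more than 1 row
}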
+func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { + defer func() { + if r := recover(); r != nil { + logutil.BgLogger().Warn("ResetContextOfStmt panicked", zap.Stack("stack"), zap.Any("recover", r), zap.Error(err)) + if err != nil { + err = stderrors.Join(err, util.GetRecoverError(r)) + } else { + err = util.GetRecoverError(r) + } + } + }() + vars := ctx.GetSessionVars() + for name, val := range vars.StmtCtx.SetVarHintRestore { + err := vars.SetSystemVar(name, val) + if err != nil { + logutil.BgLogger().Warn("Failed to restore the variable after SET_VAR hint", zap.String("variable name", name), zap.String("expected value", val)) + } + } + vars.StmtCtx.SetVarHintRestore = nil + var sc *stmtctx.StatementContext + if vars.TxnCtx.CouldRetry || vars.HasStatusFlag(mysql.ServerStatusCursorExists) { + // Must construct new statement context object, the retry history need context for every statement. + // TODO: Maybe one day we can get rid of transaction retry, then this logic can be deleted. + sc = stmtctx.NewStmtCtx() + } else { + sc = vars.InitStatementContext() + } + sc.SetTimeZone(vars.Location()) + sc.TaskID = stmtctx.AllocateTaskID() + if sc.CTEStorageMap == nil { + sc.CTEStorageMap = map[int]*CTEStorages{} + } else { + clear(sc.CTEStorageMap.(map[int]*CTEStorages)) + } + if sc.LockTableIDs == nil { + sc.LockTableIDs = make(map[int64]struct{}) + } else { + clear(sc.LockTableIDs) + } + if sc.TableStats == nil { + sc.TableStats = make(map[int64]any) + } else { + clear(sc.TableStats) + } + if sc.MDLRelatedTableIDs == nil { + sc.MDLRelatedTableIDs = make(map[int64]struct{}) + } else { + clear(sc.MDLRelatedTableIDs) + } + if sc.TblInfo2UnionScan == nil { + sc.TblInfo2UnionScan = make(map[*model.TableInfo]bool) + } else { + clear(sc.TblInfo2UnionScan) + } + sc.IsStaleness = false + sc.EnableOptimizeTrace = false + sc.OptimizeTracer = nil + sc.OptimizerCETrace = nil + sc.IsSyncStatsFailed = false + sc.IsExplainAnalyzeDML = false + sc.ResourceGroupName = vars.ResourceGroupName + // Firstly we assume that UseDynamicPruneMode can be enabled according session variable, then we will check other conditions + // in PlanBuilder.buildDataSource + if ctx.GetSessionVars().IsDynamicPartitionPruneEnabled() { + sc.UseDynamicPruneMode = true + } else { + sc.UseDynamicPruneMode = false + } + + sc.StatsLoad.Timeout = 0 + sc.StatsLoad.NeededItems = nil + sc.StatsLoad.ResultCh = nil + + sc.SysdateIsNow = ctx.GetSessionVars().SysdateIsNow + + vars.MemTracker.Detach() + vars.MemTracker.UnbindActions() + vars.MemTracker.SetBytesLimit(vars.MemQuotaQuery) + vars.MemTracker.ResetMaxConsumed() + vars.DiskTracker.Detach() + vars.DiskTracker.ResetMaxConsumed() + vars.MemTracker.SessionID.Store(vars.ConnectionID) + vars.MemTracker.Killer = &vars.SQLKiller + vars.DiskTracker.Killer = &vars.SQLKiller + vars.SQLKiller.Reset() + vars.SQLKiller.ConnID = vars.ConnectionID + + isAnalyze := false + if execStmt, ok := s.(*ast.ExecuteStmt); ok { + prepareStmt, err := plannercore.GetPreparedStmt(execStmt, vars) + if err != nil { + return err + } + _, isAnalyze = prepareStmt.PreparedAst.Stmt.(*ast.AnalyzeTableStmt) + } else if _, ok := s.(*ast.AnalyzeTableStmt); ok { + isAnalyze = true + } + if isAnalyze { + sc.InitMemTracker(memory.LabelForAnalyzeMemory, -1) + vars.MemTracker.SetBytesLimit(-1) + vars.MemTracker.AttachTo(GlobalAnalyzeMemoryTracker) + } else { + sc.InitMemTracker(memory.LabelForSQLText, -1) + } + logOnQueryExceedMemQuota := domain.GetDomain(ctx).ExpensiveQueryHandle().LogOnQueryExceedMemQuota + 
switch variable.OOMAction.Load() { + case variable.OOMActionCancel: + action := &memory.PanicOnExceed{ConnID: vars.ConnectionID, Killer: vars.MemTracker.Killer} + action.SetLogHook(logOnQueryExceedMemQuota) + vars.MemTracker.SetActionOnExceed(action) + case variable.OOMActionLog: + fallthrough + default: + action := &memory.LogOnExceed{ConnID: vars.ConnectionID} + action.SetLogHook(logOnQueryExceedMemQuota) + vars.MemTracker.SetActionOnExceed(action) + } + sc.MemTracker.SessionID.Store(vars.ConnectionID) + sc.MemTracker.AttachTo(vars.MemTracker) + sc.InitDiskTracker(memory.LabelForSQLText, -1) + globalConfig := config.GetGlobalConfig() + if variable.EnableTmpStorageOnOOM.Load() && sc.DiskTracker != nil { + sc.DiskTracker.AttachTo(vars.DiskTracker) + if GlobalDiskUsageTracker != nil { + vars.DiskTracker.AttachTo(GlobalDiskUsageTracker) + } + } + if execStmt, ok := s.(*ast.ExecuteStmt); ok { + prepareStmt, err := plannercore.GetPreparedStmt(execStmt, vars) + if err != nil { + return err + } + s = prepareStmt.PreparedAst.Stmt + sc.InitSQLDigest(prepareStmt.NormalizedSQL, prepareStmt.SQLDigest) + // For `execute stmt` SQL, should reset the SQL digest with the prepare SQL digest. + goCtx := context.Background() + if variable.EnablePProfSQLCPU.Load() && len(prepareStmt.NormalizedSQL) > 0 { + goCtx = pprof.WithLabels(goCtx, pprof.Labels("sql", FormatSQL(prepareStmt.NormalizedSQL).String())) + pprof.SetGoroutineLabels(goCtx) + } + if topsqlstate.TopSQLEnabled() && prepareStmt.SQLDigest != nil { + sc.IsSQLRegistered.Store(true) + topsql.AttachAndRegisterSQLInfo(goCtx, prepareStmt.NormalizedSQL, prepareStmt.SQLDigest, vars.InRestrictedSQL) + } + if s, ok := prepareStmt.PreparedAst.Stmt.(*ast.SelectStmt); ok { + if s.LockInfo == nil { + sc.WeakConsistency = isWeakConsistencyRead(ctx, execStmt) + } + } + } + // execute missed stmtID uses empty sql + sc.OriginalSQL = s.Text() + if explainStmt, ok := s.(*ast.ExplainStmt); ok { + sc.InExplainStmt = true + sc.ExplainFormat = explainStmt.Format + sc.InExplainAnalyzeStmt = explainStmt.Analyze + sc.IgnoreExplainIDSuffix = strings.ToLower(explainStmt.Format) == types.ExplainFormatBrief + sc.InVerboseExplain = strings.ToLower(explainStmt.Format) == types.ExplainFormatVerbose + s = explainStmt.Stmt + } else { + sc.ExplainFormat = "" + } + if explainForStmt, ok := s.(*ast.ExplainForStmt); ok { + sc.InExplainStmt = true + sc.InExplainAnalyzeStmt = true + sc.InVerboseExplain = strings.ToLower(explainForStmt.Format) == types.ExplainFormatVerbose + } + + // TODO: Many same bool variables here. + // We should set only two variables ( + // IgnoreErr and StrictSQLMode) to avoid setting the same bool variables and + // pushing them down to TiKV as flags. + + sc.InRestrictedSQL = vars.InRestrictedSQL + strictSQLMode := vars.SQLMode.HasStrictMode() + + errLevels := sc.ErrLevels() + errLevels[errctx.ErrGroupDividedByZero] = errctx.LevelWarn + switch stmt := s.(type) { + // `ResetUpdateStmtCtx` and `ResetDeleteStmtCtx` may modify the flags, so we'll need to store them. + case *ast.UpdateStmt: + ResetUpdateStmtCtx(sc, stmt, vars) + errLevels = sc.ErrLevels() + case *ast.DeleteStmt: + ResetDeleteStmtCtx(sc, stmt, vars) + errLevels = sc.ErrLevels() + case *ast.InsertStmt: + sc.InInsertStmt = true + // For insert statement (not for update statement), disabling the StrictSQLMode + // should make TruncateAsWarning and DividedByZeroAsWarning, + // but should not make DupKeyAsWarning. 
+ if stmt.IgnoreErr { + errLevels[errctx.ErrGroupDupKey] = errctx.LevelWarn + errLevels[errctx.ErrGroupAutoIncReadFailed] = errctx.LevelWarn + errLevels[errctx.ErrGroupNoMatchedPartition] = errctx.LevelWarn + } + errLevels[errctx.ErrGroupBadNull] = errctx.ResolveErrLevel(false, !strictSQLMode || stmt.IgnoreErr) + errLevels[errctx.ErrGroupDividedByZero] = errctx.ResolveErrLevel( + !vars.SQLMode.HasErrorForDivisionByZeroMode(), + !strictSQLMode || stmt.IgnoreErr, + ) + sc.Priority = stmt.Priority + sc.SetTypeFlags(sc.TypeFlags(). + WithTruncateAsWarning(!strictSQLMode || stmt.IgnoreErr). + WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode()). + WithIgnoreZeroInDate(!vars.SQLMode.HasNoZeroInDateMode() || + !vars.SQLMode.HasNoZeroDateMode() || !strictSQLMode || stmt.IgnoreErr || + vars.SQLMode.HasAllowInvalidDatesMode())) + case *ast.CreateTableStmt, *ast.AlterTableStmt: + sc.InCreateOrAlterStmt = true + sc.SetTypeFlags(sc.TypeFlags(). + WithTruncateAsWarning(!strictSQLMode). + WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode()). + WithIgnoreZeroInDate(!vars.SQLMode.HasNoZeroInDateMode() || !strictSQLMode || + vars.SQLMode.HasAllowInvalidDatesMode()). + WithIgnoreZeroDateErr(!vars.SQLMode.HasNoZeroDateMode() || !strictSQLMode)) + + case *ast.LoadDataStmt: + sc.InLoadDataStmt = true + // return warning instead of error when load data meet no partition for value + errLevels[errctx.ErrGroupNoMatchedPartition] = errctx.LevelWarn + case *ast.SelectStmt: + sc.InSelectStmt = true + + // Return warning for truncate error in selection. + sc.SetTypeFlags(sc.TypeFlags(). + WithTruncateAsWarning(true). + WithIgnoreZeroInDate(true). + WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode())) + if opts := stmt.SelectStmtOpts; opts != nil { + sc.Priority = opts.Priority + sc.NotFillCache = !opts.SQLCache + } + sc.WeakConsistency = isWeakConsistencyRead(ctx, stmt) + case *ast.SetOprStmt: + sc.InSelectStmt = true + sc.SetTypeFlags(sc.TypeFlags(). + WithTruncateAsWarning(true). + WithIgnoreZeroInDate(true). + WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode())) + case *ast.ShowStmt: + sc.SetTypeFlags(sc.TypeFlags(). + WithIgnoreTruncateErr(true). + WithIgnoreZeroInDate(true). + WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode())) + if stmt.Tp == ast.ShowWarnings || stmt.Tp == ast.ShowErrors || stmt.Tp == ast.ShowSessionStates { + sc.InShowWarning = true + sc.SetWarnings(vars.StmtCtx.GetWarnings()) + } + case *ast.SplitRegionStmt: + sc.SetTypeFlags(sc.TypeFlags(). + WithIgnoreTruncateErr(false). + WithIgnoreZeroInDate(true). + WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode())) + case *ast.SetSessionStatesStmt: + sc.InSetSessionStatesStmt = true + sc.SetTypeFlags(sc.TypeFlags(). + WithIgnoreTruncateErr(true). + WithIgnoreZeroInDate(true). + WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode())) + default: + sc.SetTypeFlags(sc.TypeFlags(). + WithIgnoreTruncateErr(true). + WithIgnoreZeroInDate(true). + WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode())) + } + + if errLevels != sc.ErrLevels() { + sc.SetErrLevels(errLevels) + } + + sc.SetTypeFlags(sc.TypeFlags(). + WithSkipUTF8Check(vars.SkipUTF8Check). + WithSkipSACIICheck(vars.SkipASCIICheck). + WithSkipUTF8MB4Check(!globalConfig.Instance.CheckMb4ValueInUTF8.Load()). + // WithAllowNegativeToUnsigned with false value indicates values less than 0 should be clipped to 0 for unsigned integer types. 
+ // This is the case for `insert`, `update`, `alter table`, `create table` and `load data infile` statements, when not in strict SQL mode. + // see https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html + WithAllowNegativeToUnsigned(!sc.InInsertStmt && !sc.InLoadDataStmt && !sc.InUpdateStmt && !sc.InCreateOrAlterStmt), + ) + + vars.PlanCacheParams.Reset() + if priority := mysql.PriorityEnum(atomic.LoadInt32(&variable.ForcePriority)); priority != mysql.NoPriority { + sc.Priority = priority + } + if vars.StmtCtx.LastInsertID > 0 { + sc.PrevLastInsertID = vars.StmtCtx.LastInsertID + } else { + sc.PrevLastInsertID = vars.StmtCtx.PrevLastInsertID + } + sc.PrevAffectedRows = 0 + if vars.StmtCtx.InUpdateStmt || vars.StmtCtx.InDeleteStmt || vars.StmtCtx.InInsertStmt || vars.StmtCtx.InSetSessionStatesStmt { + sc.PrevAffectedRows = int64(vars.StmtCtx.AffectedRows()) + } else if vars.StmtCtx.InSelectStmt { + sc.PrevAffectedRows = -1 + } + if globalConfig.Instance.EnableCollectExecutionInfo.Load() { + // In ExplainFor case, RuntimeStatsColl should not be reset for reuse, + // because ExplainFor need to display the last statement information. + reuseObj := vars.StmtCtx.RuntimeStatsColl + if _, ok := s.(*ast.ExplainForStmt); ok { + reuseObj = nil + } + sc.RuntimeStatsColl = execdetails.NewRuntimeStatsColl(reuseObj) + + // also enable index usage collector + if sc.IndexUsageCollector == nil { + sc.IndexUsageCollector = ctx.NewStmtIndexUsageCollector() + } else { + sc.IndexUsageCollector.Reset() + } + } else { + // turn off the index usage collector + sc.IndexUsageCollector = nil + } + + sc.SetForcePlanCache(fixcontrol.GetBoolWithDefault(vars.OptimizerFixControl, fixcontrol.Fix49736, false)) + sc.SetAlwaysWarnSkipCache(sc.InExplainStmt && sc.ExplainFormat == "plan_cache") + errCount, warnCount := vars.StmtCtx.NumErrorWarnings() + vars.SysErrorCount = errCount + vars.SysWarningCount = warnCount + vars.ExchangeChunkStatus() + vars.StmtCtx = sc + vars.PrevFoundInPlanCache = vars.FoundInPlanCache + vars.FoundInPlanCache = false + vars.PrevFoundInBinding = vars.FoundInBinding + vars.FoundInBinding = false + vars.DurationWaitTS = 0 + vars.CurrInsertBatchExtraCols = nil + vars.CurrInsertValues = chunk.Row{} + + return +} + +// ResetUpdateStmtCtx resets statement context for UpdateStmt. +func ResetUpdateStmtCtx(sc *stmtctx.StatementContext, stmt *ast.UpdateStmt, vars *variable.SessionVars) { + strictSQLMode := vars.SQLMode.HasStrictMode() + sc.InUpdateStmt = true + errLevels := sc.ErrLevels() + errLevels[errctx.ErrGroupDupKey] = errctx.ResolveErrLevel(false, stmt.IgnoreErr) + errLevels[errctx.ErrGroupBadNull] = errctx.ResolveErrLevel(false, !strictSQLMode || stmt.IgnoreErr) + errLevels[errctx.ErrGroupDividedByZero] = errctx.ResolveErrLevel( + !vars.SQLMode.HasErrorForDivisionByZeroMode(), + !strictSQLMode || stmt.IgnoreErr, + ) + errLevels[errctx.ErrGroupNoMatchedPartition] = errctx.ResolveErrLevel(false, stmt.IgnoreErr) + sc.SetErrLevels(errLevels) + sc.Priority = stmt.Priority + sc.SetTypeFlags(sc.TypeFlags(). + WithTruncateAsWarning(!strictSQLMode || stmt.IgnoreErr). + WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode()). + WithIgnoreZeroInDate(!vars.SQLMode.HasNoZeroInDateMode() || !vars.SQLMode.HasNoZeroDateMode() || + !strictSQLMode || stmt.IgnoreErr || vars.SQLMode.HasAllowInvalidDatesMode())) +} + +// ResetDeleteStmtCtx resets statement context for DeleteStmt. 
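// --- Illustrative sketch, not part of the patch ---
// ResetUpdateStmtCtx and the statement switch above map SQL mode and the
// IGNORE modifier into per-error-group levels via errctx.ResolveErrLevel.
// The helper below only illustrates the assumed precedence
// (ignore > warn > error); the real errctx package may differ in details.
package main

import "fmt"

type errLevel int

const (
	levelError errLevel = iota
	levelWarn
	levelIgnore
)

// resolveErrLevel mirrors the assumed semantics: fully ignoring an error
// group wins over downgrading it to a warning, which wins over erroring out.
func resolveErrLevel(ignore, warn bool) errLevel {
	if ignore {
		return levelIgnore
	}
	if warn {
		return levelWarn
	}
	return levelError
}

func main() {
	strictSQLMode, ignoreErr := true, false
	// e.g. bad-NULL handling: warn only when not in strict mode or when IGNORE is given.
	fmt.Println(resolveErrLevel(false, !strictSQLMode || ignoreErr)) // 0 (error)
}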
+func ResetDeleteStmtCtx(sc *stmtctx.StatementContext, stmt *ast.DeleteStmt, vars *variable.SessionVars) { + strictSQLMode := vars.SQLMode.HasStrictMode() + sc.InDeleteStmt = true + errLevels := sc.ErrLevels() + errLevels[errctx.ErrGroupDupKey] = errctx.ResolveErrLevel(false, stmt.IgnoreErr) + errLevels[errctx.ErrGroupBadNull] = errctx.ResolveErrLevel(false, !strictSQLMode || stmt.IgnoreErr) + errLevels[errctx.ErrGroupDividedByZero] = errctx.ResolveErrLevel( + !vars.SQLMode.HasErrorForDivisionByZeroMode(), + !strictSQLMode || stmt.IgnoreErr, + ) + sc.SetErrLevels(errLevels) + sc.Priority = stmt.Priority + sc.SetTypeFlags(sc.TypeFlags(). + WithTruncateAsWarning(!strictSQLMode || stmt.IgnoreErr). + WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode()). + WithIgnoreZeroInDate(!vars.SQLMode.HasNoZeroInDateMode() || !vars.SQLMode.HasNoZeroDateMode() || + !strictSQLMode || stmt.IgnoreErr || vars.SQLMode.HasAllowInvalidDatesMode())) +} + +func setOptionForTopSQL(sc *stmtctx.StatementContext, snapshot kv.Snapshot) { + if snapshot == nil { + return + } + // pipelined dml may already flush in background, don't touch it to avoid race. + if txn, ok := snapshot.(kv.Transaction); ok && txn.IsPipelined() { + return + } + snapshot.SetOption(kv.ResourceGroupTagger, sc.GetResourceGroupTagger()) + if sc.KvExecCounter != nil { + snapshot.SetOption(kv.RPCInterceptor, sc.KvExecCounter.RPCInterceptor()) + } +} + +func isWeakConsistencyRead(ctx sessionctx.Context, node ast.Node) bool { + sessionVars := ctx.GetSessionVars() + return sessionVars.ConnectionID > 0 && sessionVars.ReadConsistency.IsWeak() && + plannercore.IsAutoCommitTxn(sessionVars) && plannercore.IsReadOnly(node, sessionVars) +} + +// FastCheckTableExec represents a check table executor. +// It is built from the "admin check table" statement, and it checks if the +// index matches the records in the table. +// It uses a new algorithms to check table data, which is faster than the old one(CheckTableExec). +type FastCheckTableExec struct { + exec.BaseExecutor + + dbName string + table table.Table + indexInfos []*model.IndexInfo + done bool + is infoschema.InfoSchema + err *atomic.Pointer[error] + wg sync.WaitGroup + contextCtx context.Context +} + +// Open implements the Executor Open interface. 
+func (e *FastCheckTableExec) Open(ctx context.Context) error { + if err := e.BaseExecutor.Open(ctx); err != nil { + return err + } + + e.done = false + e.contextCtx = ctx + return nil +} + +type checkIndexTask struct { + indexOffset int +} + +type checkIndexWorker struct { + sctx sessionctx.Context + dbName string + table table.Table + indexInfos []*model.IndexInfo + e *FastCheckTableExec +} + +type groupByChecksum struct { + bucket uint64 + checksum uint64 + count int64 +} + +func getCheckSum(ctx context.Context, se sessionctx.Context, sql string) ([]groupByChecksum, error) { + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnAdmin) + rs, err := se.GetSQLExecutor().ExecuteInternal(ctx, sql) + if err != nil { + return nil, err + } + defer func(rs sqlexec.RecordSet) { + err := rs.Close() + if err != nil { + logutil.BgLogger().Error("close record set failed", zap.Error(err)) + } + }(rs) + rows, err := sqlexec.DrainRecordSet(ctx, rs, 256) + if err != nil { + return nil, err + } + checksums := make([]groupByChecksum, 0, len(rows)) + for _, row := range rows { + checksums = append(checksums, groupByChecksum{bucket: row.GetUint64(1), checksum: row.GetUint64(0), count: row.GetInt64(2)}) + } + return checksums, nil +} + +func (w *checkIndexWorker) initSessCtx(se sessionctx.Context) (restore func()) { + sessVars := se.GetSessionVars() + originOptUseInvisibleIdx := sessVars.OptimizerUseInvisibleIndexes + originMemQuotaQuery := sessVars.MemQuotaQuery + + sessVars.OptimizerUseInvisibleIndexes = true + sessVars.MemQuotaQuery = w.sctx.GetSessionVars().MemQuotaQuery + return func() { + sessVars.OptimizerUseInvisibleIndexes = originOptUseInvisibleIdx + sessVars.MemQuotaQuery = originMemQuotaQuery + } +} + +// HandleTask implements the Worker interface. +func (w *checkIndexWorker) HandleTask(task checkIndexTask, _ func(workerpool.None)) { + defer w.e.wg.Done() + idxInfo := w.indexInfos[task.indexOffset] + bucketSize := int(CheckTableFastBucketSize.Load()) + + ctx := kv.WithInternalSourceType(w.e.contextCtx, kv.InternalTxnAdmin) + + trySaveErr := func(err error) { + w.e.err.CompareAndSwap(nil, &err) + } + + se, err := w.e.BaseExecutor.GetSysSession() + if err != nil { + trySaveErr(err) + return + } + restoreCtx := w.initSessCtx(se) + defer func() { + restoreCtx() + w.e.BaseExecutor.ReleaseSysSession(ctx, se) + }() + + var pkCols []string + var pkTypes []*types.FieldType + switch { + case w.e.table.Meta().IsCommonHandle: + pkColsInfo := w.e.table.Meta().GetPrimaryKey().Columns + for _, colInfo := range pkColsInfo { + colStr := colInfo.Name.O + pkCols = append(pkCols, colStr) + pkTypes = append(pkTypes, &w.e.table.Meta().Columns[colInfo.Offset].FieldType) + } + case w.e.table.Meta().PKIsHandle: + pkCols = append(pkCols, w.e.table.Meta().GetPkName().O) + default: // support decoding _tidb_rowid. + pkCols = append(pkCols, model.ExtraHandleName.O) + } + + // CheckSum of (handle + index columns). 
+ var md5HandleAndIndexCol strings.Builder + md5HandleAndIndexCol.WriteString("crc32(md5(concat_ws(0x2, ") + for _, col := range pkCols { + md5HandleAndIndexCol.WriteString(ColumnName(col)) + md5HandleAndIndexCol.WriteString(", ") + } + for offset, col := range idxInfo.Columns { + tblCol := w.table.Meta().Columns[col.Offset] + if tblCol.IsGenerated() && !tblCol.GeneratedStored { + md5HandleAndIndexCol.WriteString(tblCol.GeneratedExprString) + } else { + md5HandleAndIndexCol.WriteString(ColumnName(col.Name.O)) + } + if offset != len(idxInfo.Columns)-1 { + md5HandleAndIndexCol.WriteString(", ") + } + } + md5HandleAndIndexCol.WriteString(")))") + + // Used to group by and order. + var md5Handle strings.Builder + md5Handle.WriteString("crc32(md5(concat_ws(0x2, ") + for i, col := range pkCols { + md5Handle.WriteString(ColumnName(col)) + if i != len(pkCols)-1 { + md5Handle.WriteString(", ") + } + } + md5Handle.WriteString(")))") + + handleColumnField := strings.Join(pkCols, ", ") + var indexColumnField strings.Builder + for offset, col := range idxInfo.Columns { + indexColumnField.WriteString(ColumnName(col.Name.O)) + if offset != len(idxInfo.Columns)-1 { + indexColumnField.WriteString(", ") + } + } + + tableRowCntToCheck := int64(0) + + offset := 0 + mod := 1 + meetError := false + + lookupCheckThreshold := int64(100) + checkOnce := false + + if w.e.Ctx().GetSessionVars().SnapshotTS != 0 { + se.GetSessionVars().SnapshotTS = w.e.Ctx().GetSessionVars().SnapshotTS + defer func() { + se.GetSessionVars().SnapshotTS = 0 + }() + } + _, err = se.GetSQLExecutor().ExecuteInternal(ctx, "begin") + if err != nil { + trySaveErr(err) + return + } + + times := 0 + const maxTimes = 10 + for tableRowCntToCheck > lookupCheckThreshold || !checkOnce { + times++ + if times == maxTimes { + logutil.BgLogger().Warn("compare checksum by group reaches time limit", zap.Int("times", times)) + break + } + whereKey := fmt.Sprintf("((cast(%s as signed) - %d) %% %d)", md5Handle.String(), offset, mod) + groupByKey := fmt.Sprintf("((cast(%s as signed) - %d) div %d %% %d)", md5Handle.String(), offset, mod, bucketSize) + if !checkOnce { + whereKey = "0" + } + checkOnce = true + + tblQuery := fmt.Sprintf("select /*+ read_from_storage(tikv[%s]) */ bit_xor(%s), %s, count(*) from %s use index() where %s = 0 group by %s", TableName(w.e.dbName, w.e.table.Meta().Name.String()), md5HandleAndIndexCol.String(), groupByKey, TableName(w.e.dbName, w.e.table.Meta().Name.String()), whereKey, groupByKey) + idxQuery := fmt.Sprintf("select bit_xor(%s), %s, count(*) from %s use index(`%s`) where %s = 0 group by %s", md5HandleAndIndexCol.String(), groupByKey, TableName(w.e.dbName, w.e.table.Meta().Name.String()), idxInfo.Name, whereKey, groupByKey) + + logutil.BgLogger().Info("fast check table by group", zap.String("table name", w.table.Meta().Name.String()), zap.String("index name", idxInfo.Name.String()), zap.Int("times", times), zap.Int("current offset", offset), zap.Int("current mod", mod), zap.String("table sql", tblQuery), zap.String("index sql", idxQuery)) + + // compute table side checksum. + tableChecksum, err := getCheckSum(w.e.contextCtx, se, tblQuery) + if err != nil { + trySaveErr(err) + return + } + slices.SortFunc(tableChecksum, func(i, j groupByChecksum) int { + return cmp.Compare(i.bucket, j.bucket) + }) + + // compute index side checksum. 
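// --- Illustrative sketch, not part of the patch ---
// The loop above narrows down a checksum mismatch: rows with
// (hash - offset) % mod == 0 are grouped by ((hash - offset) / mod) % bucketSize,
// the per-bucket bit_xor checksums of table and index are compared, and when
// bucket b mismatches the next round sets offset += b*mod and mod *= bucketSize
// so only that bucket is split further. A standalone sketch of the narrowing
// arithmetic over plain uint64 hashes (bucketSize and the hash are hypothetical):
package main

import "fmt"

// bucketOf reports whether the hash belongs to the current slice and, if so,
// which sub-bucket it falls into.
func bucketOf(hash, offset, mod, bucketSize uint64) (bucket uint64, inSlice bool) {
	if (hash-offset)%mod != 0 {
		return 0, false
	}
	return (hash - offset) / mod % bucketSize, true
}

func main() {
	const bucketSize = 4
	offset, mod := uint64(0), uint64(1)

	// Round 1: hash 13 falls into bucket 1.
	b, _ := bucketOf(13, offset, mod, bucketSize)
	fmt.Println(b) // 1

	// Suppose bucket 1 mismatched: fold it into offset/mod and split again.
	offset += b * mod // 1
	mod *= bucketSize // 4
	b, ok := bucketOf(13, offset, mod, bucketSize)
	fmt.Println(b, ok) // 3 true  (13 = 1 + 3*4)
}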
+ indexChecksum, err := getCheckSum(w.e.contextCtx, se, idxQuery) + if err != nil { + trySaveErr(err) + return + } + slices.SortFunc(indexChecksum, func(i, j groupByChecksum) int { + return cmp.Compare(i.bucket, j.bucket) + }) + + currentOffset := 0 + + // Every checksum in table side should be the same as the index side. + i := 0 + for i < len(tableChecksum) && i < len(indexChecksum) { + if tableChecksum[i].bucket != indexChecksum[i].bucket || tableChecksum[i].checksum != indexChecksum[i].checksum { + if tableChecksum[i].bucket <= indexChecksum[i].bucket { + currentOffset = int(tableChecksum[i].bucket) + tableRowCntToCheck = tableChecksum[i].count + } else { + currentOffset = int(indexChecksum[i].bucket) + tableRowCntToCheck = indexChecksum[i].count + } + meetError = true + break + } + i++ + } + + if !meetError && i < len(indexChecksum) && i == len(tableChecksum) { + // Table side has fewer buckets. + currentOffset = int(indexChecksum[i].bucket) + tableRowCntToCheck = indexChecksum[i].count + meetError = true + } else if !meetError && i < len(tableChecksum) && i == len(indexChecksum) { + // Index side has fewer buckets. + currentOffset = int(tableChecksum[i].bucket) + tableRowCntToCheck = tableChecksum[i].count + meetError = true + } + + if !meetError { + if times != 1 { + logutil.BgLogger().Error("unexpected result, no error detected in this round, but an error is detected in the previous round", zap.Int("times", times), zap.Int("offset", offset), zap.Int("mod", mod)) + } + break + } + + offset += currentOffset * mod + mod *= bucketSize + } + + queryToRow := func(se sessionctx.Context, sql string) ([]chunk.Row, error) { + rs, err := se.GetSQLExecutor().ExecuteInternal(ctx, sql) + if err != nil { + return nil, err + } + row, err := sqlexec.DrainRecordSet(ctx, rs, 4096) + if err != nil { + return nil, err + } + err = rs.Close() + if err != nil { + logutil.BgLogger().Warn("close result set failed", zap.Error(err)) + } + return row, nil + } + + if meetError { + groupByKey := fmt.Sprintf("((cast(%s as signed) - %d) %% %d)", md5Handle.String(), offset, mod) + indexSQL := fmt.Sprintf("select %s, %s, %s from %s use index(`%s`) where %s = 0 order by %s", handleColumnField, indexColumnField.String(), md5HandleAndIndexCol.String(), TableName(w.e.dbName, w.e.table.Meta().Name.String()), idxInfo.Name, groupByKey, handleColumnField) + tableSQL := fmt.Sprintf("select /*+ read_from_storage(tikv[%s]) */ %s, %s, %s from %s use index() where %s = 0 order by %s", TableName(w.e.dbName, w.e.table.Meta().Name.String()), handleColumnField, indexColumnField.String(), md5HandleAndIndexCol.String(), TableName(w.e.dbName, w.e.table.Meta().Name.String()), groupByKey, handleColumnField) + + idxRow, err := queryToRow(se, indexSQL) + if err != nil { + trySaveErr(err) + return + } + tblRow, err := queryToRow(se, tableSQL) + if err != nil { + trySaveErr(err) + return + } + + errCtx := w.sctx.GetSessionVars().StmtCtx.ErrCtx() + getHandleFromRow := func(row chunk.Row) (kv.Handle, error) { + handleDatum := make([]types.Datum, 0) + for i, t := range pkTypes { + handleDatum = append(handleDatum, row.GetDatum(i, t)) + } + if w.table.Meta().IsCommonHandle { + handleBytes, err := codec.EncodeKey(w.sctx.GetSessionVars().StmtCtx.TimeZone(), nil, handleDatum...) 
+ err = errCtx.HandleError(err) + if err != nil { + return nil, err + } + return kv.NewCommonHandle(handleBytes) + } + return kv.IntHandle(row.GetInt64(0)), nil + } + getValueFromRow := func(row chunk.Row) ([]types.Datum, error) { + valueDatum := make([]types.Datum, 0) + for i, t := range idxInfo.Columns { + valueDatum = append(valueDatum, row.GetDatum(i+len(pkCols), &w.table.Meta().Columns[t.Offset].FieldType)) + } + return valueDatum, nil + } + + ir := func() *consistency.Reporter { + return &consistency.Reporter{ + HandleEncode: func(handle kv.Handle) kv.Key { + return tablecodec.EncodeRecordKey(w.table.RecordPrefix(), handle) + }, + IndexEncode: func(idxRow *consistency.RecordData) kv.Key { + var idx table.Index + for _, v := range w.table.Indices() { + if strings.EqualFold(v.Meta().Name.String(), idxInfo.Name.O) { + idx = v + break + } + } + if idx == nil { + return nil + } + sc := w.sctx.GetSessionVars().StmtCtx + k, _, err := idx.GenIndexKey(sc.ErrCtx(), sc.TimeZone(), idxRow.Values[:len(idx.Meta().Columns)], idxRow.Handle, nil) + if err != nil { + return nil + } + return k + }, + Tbl: w.table.Meta(), + Idx: idxInfo, + EnableRedactLog: w.sctx.GetSessionVars().EnableRedactLog, + Storage: w.sctx.GetStore(), + } + } + + getCheckSum := func(row chunk.Row) uint64 { + return row.GetUint64(len(pkCols) + len(idxInfo.Columns)) + } + + var handle kv.Handle + var tableRecord *consistency.RecordData + var lastTableRecord *consistency.RecordData + var indexRecord *consistency.RecordData + i := 0 + for i < len(tblRow) || i < len(idxRow) { + if i == len(tblRow) { + // No more rows in table side. + tableRecord = nil + } else { + handle, err = getHandleFromRow(tblRow[i]) + if err != nil { + trySaveErr(err) + return + } + value, err := getValueFromRow(tblRow[i]) + if err != nil { + trySaveErr(err) + return + } + tableRecord = &consistency.RecordData{Handle: handle, Values: value} + } + if i == len(idxRow) { + // No more rows in index side. 
+ indexRecord = nil + } else { + indexHandle, err := getHandleFromRow(idxRow[i]) + if err != nil { + trySaveErr(err) + return + } + indexValue, err := getValueFromRow(idxRow[i]) + if err != nil { + trySaveErr(err) + return + } + indexRecord = &consistency.RecordData{Handle: indexHandle, Values: indexValue} + } + + if tableRecord == nil { + if lastTableRecord != nil && lastTableRecord.Handle.Equal(indexRecord.Handle) { + tableRecord = lastTableRecord + } + err = ir().ReportAdminCheckInconsistent(w.e.contextCtx, indexRecord.Handle, indexRecord, tableRecord) + } else if indexRecord == nil { + err = ir().ReportAdminCheckInconsistent(w.e.contextCtx, tableRecord.Handle, indexRecord, tableRecord) + } else if tableRecord.Handle.Equal(indexRecord.Handle) && getCheckSum(tblRow[i]) != getCheckSum(idxRow[i]) { + err = ir().ReportAdminCheckInconsistent(w.e.contextCtx, tableRecord.Handle, indexRecord, tableRecord) + } else if !tableRecord.Handle.Equal(indexRecord.Handle) { + if tableRecord.Handle.Compare(indexRecord.Handle) < 0 { + err = ir().ReportAdminCheckInconsistent(w.e.contextCtx, tableRecord.Handle, nil, tableRecord) + } else { + if lastTableRecord != nil && lastTableRecord.Handle.Equal(indexRecord.Handle) { + err = ir().ReportAdminCheckInconsistent(w.e.contextCtx, indexRecord.Handle, indexRecord, lastTableRecord) + } else { + err = ir().ReportAdminCheckInconsistent(w.e.contextCtx, indexRecord.Handle, indexRecord, nil) + } + } + } + if err != nil { + trySaveErr(err) + return + } + i++ + if tableRecord != nil { + lastTableRecord = &consistency.RecordData{Handle: tableRecord.Handle, Values: tableRecord.Values} + } else { + lastTableRecord = nil + } + } + } +} + +// Close implements the Worker interface. +func (*checkIndexWorker) Close() {} + +func (e *FastCheckTableExec) createWorker() workerpool.Worker[checkIndexTask, workerpool.None] { + return &checkIndexWorker{sctx: e.Ctx(), dbName: e.dbName, table: e.table, indexInfos: e.indexInfos, e: e} +} + +// Next implements the Executor Next interface. +func (e *FastCheckTableExec) Next(ctx context.Context, _ *chunk.Chunk) error { + if e.done || len(e.indexInfos) == 0 { + return nil + } + defer func() { e.done = true }() + + // Here we need check all indexes, includes invisible index + e.Ctx().GetSessionVars().OptimizerUseInvisibleIndexes = true + defer func() { + e.Ctx().GetSessionVars().OptimizerUseInvisibleIndexes = false + }() + + workerPool := workerpool.NewWorkerPool[checkIndexTask]("checkIndex", + poolutil.CheckTable, 3, e.createWorker) + workerPool.Start(ctx) + + e.wg.Add(len(e.indexInfos)) + for i := range e.indexInfos { + workerPool.AddTask(checkIndexTask{indexOffset: i}) + } + + e.wg.Wait() + workerPool.ReleaseAndWait() + + p := e.err.Load() + if p == nil { + return nil + } + return *p +} + +// TableName returns `schema`.`table` +func TableName(schema, table string) string { + return fmt.Sprintf("`%s`.`%s`", escapeName(schema), escapeName(table)) +} + +// ColumnName returns `column` +func ColumnName(column string) string { + return fmt.Sprintf("`%s`", escapeName(column)) +} + +func escapeName(name string) string { + return strings.ReplaceAll(name, "`", "``") +} + +// AdminShowBDRRoleExec represents a show BDR role executor. +type AdminShowBDRRoleExec struct { + exec.BaseExecutor + + done bool +} + +// Next implements the Executor Next interface. 
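// --- Illustrative sketch, not part of the patch ---
// TableName and ColumnName above build quoted identifiers for the generated
// admin-check SQL by doubling any embedded backtick. The same escaping in a
// standalone form, followed by a small usage example:
package main

import (
	"fmt"
	"strings"
)

// quoteIdent wraps an identifier in backticks, doubling embedded backticks.
func quoteIdent(name string) string {
	return "`" + strings.ReplaceAll(name, "`", "``") + "`"
}

func main() {
	fmt.Println(quoteIdent("weird`name"))                     // `weird``name`
	fmt.Printf("%s.%s\n", quoteIdent("db"), quoteIdent("t1")) // `db`.`t1`
}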
+func (e *AdminShowBDRRoleExec) Next(ctx context.Context, req *chunk.Chunk) error { + req.Reset() + if e.done { + return nil + } + + return kv.RunInNewTxn(kv.WithInternalSourceType(ctx, kv.InternalTxnAdmin), e.Ctx().GetStore(), true, func(_ context.Context, txn kv.Transaction) error { + role, err := meta.NewMeta(txn).GetBDRRole() + if err != nil { + return err + } + + req.AppendString(0, role) + e.done = true + return nil + }) +} diff --git a/pkg/executor/import_into.go b/pkg/executor/import_into.go index 7d9a4f92efd95..6e2a979c6f2fe 100644 --- a/pkg/executor/import_into.go +++ b/pkg/executor/import_into.go @@ -119,7 +119,7 @@ func (e *ImportIntoExec) Next(ctx context.Context, req *chunk.Chunk) (err error) return err } - failpoint.InjectCall("cancellableCtx", &ctx) + failpoint.Call(_curpkg_("cancellableCtx"), &ctx) jobID, task, err := e.submitTask(ctx) if err != nil { diff --git a/pkg/executor/import_into.go__failpoint_stash__ b/pkg/executor/import_into.go__failpoint_stash__ new file mode 100644 index 0000000000000..7d9a4f92efd95 --- /dev/null +++ b/pkg/executor/import_into.go__failpoint_stash__ @@ -0,0 +1,344 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "context" + "fmt" + + "github.com/google/uuid" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/pkg/disttask/framework/handle" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + fstorage "github.com/pingcap/tidb/pkg/disttask/framework/storage" + "github.com/pingcap/tidb/pkg/disttask/importinto" + "github.com/pingcap/tidb/pkg/executor/importer" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/mysql" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/privilege" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/tikv/client-go/v2/util" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" +) + +const unknownImportedRowCount = -1 + +// ImportIntoExec represents a IMPORT INTO executor. 
+type ImportIntoExec struct { + exec.BaseExecutor + selectExec exec.Executor + userSctx sessionctx.Context + controller *importer.LoadDataController + stmt string + + plan *plannercore.ImportInto + tbl table.Table + dataFilled bool +} + +var ( + _ exec.Executor = (*ImportIntoExec)(nil) +) + +func newImportIntoExec(b exec.BaseExecutor, selectExec exec.Executor, userSctx sessionctx.Context, + plan *plannercore.ImportInto, tbl table.Table) (*ImportIntoExec, error) { + return &ImportIntoExec{ + BaseExecutor: b, + selectExec: selectExec, + userSctx: userSctx, + stmt: plan.Stmt, + plan: plan, + tbl: tbl, + }, nil +} + +// Next implements the Executor Next interface. +func (e *ImportIntoExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { + req.GrowAndReset(e.MaxChunkSize()) + ctx = kv.WithInternalSourceType(ctx, kv.InternalImportInto) + if e.dataFilled { + // need to return an empty req to indicate all results have been written + return nil + } + importPlan, err := importer.NewImportPlan(ctx, e.userSctx, e.plan, e.tbl) + if err != nil { + return err + } + astArgs := importer.ASTArgsFromImportPlan(e.plan) + controller, err := importer.NewLoadDataController(importPlan, e.tbl, astArgs) + if err != nil { + return err + } + e.controller = controller + + if e.selectExec != nil { + // `import from select` doesn't return rows, so no need to set dataFilled. + return e.importFromSelect(ctx) + } + + if err2 := e.controller.InitDataFiles(ctx); err2 != nil { + return err2 + } + + // must use a new session to pre-check, else the stmt in show processlist will be changed. + newSCtx, err2 := CreateSession(e.userSctx) + if err2 != nil { + return err2 + } + defer CloseSession(newSCtx) + sqlExec := newSCtx.GetSQLExecutor() + if err2 = e.controller.CheckRequirements(ctx, sqlExec); err2 != nil { + return err2 + } + + if err := e.controller.InitTiKVConfigs(ctx, newSCtx); err != nil { + return err + } + + failpoint.InjectCall("cancellableCtx", &ctx) + + jobID, task, err := e.submitTask(ctx) + if err != nil { + return err + } + + if !e.controller.Detached { + if err = e.waitTask(ctx, jobID, task); err != nil { + return err + } + } + return e.fillJobInfo(ctx, jobID, req) +} + +func (e *ImportIntoExec) fillJobInfo(ctx context.Context, jobID int64, req *chunk.Chunk) error { + e.dataFilled = true + // we use taskManager to get job, user might not have the privilege to system tables. + taskManager, err := fstorage.GetTaskManager() + ctx = util.WithInternalSourceType(ctx, kv.InternalDistTask) + if err != nil { + return err + } + var info *importer.JobInfo + if err = taskManager.WithNewSession(func(se sessionctx.Context) error { + sqlExec := se.GetSQLExecutor() + var err2 error + info, err2 = importer.GetJob(ctx, sqlExec, jobID, e.Ctx().GetSessionVars().User.String(), false) + return err2 + }); err != nil { + return err + } + fillOneImportJobInfo(info, req, unknownImportedRowCount) + return nil +} + +func (e *ImportIntoExec) submitTask(ctx context.Context) (int64, *proto.TaskBase, error) { + importFromServer, err := storage.IsLocalPath(e.controller.Path) + if err != nil { + // since we have checked this during creating controller, this should not happen. 
+ return 0, nil, exeerrors.ErrLoadDataInvalidURI.FastGenByArgs(plannercore.ImportIntoDataSource, err.Error()) + } + logutil.Logger(ctx).Info("get job importer", zap.Stringer("param", e.controller.Parameters), + zap.Bool("dist-task-enabled", variable.EnableDistTask.Load())) + if importFromServer { + ecp, err2 := e.controller.PopulateChunks(ctx) + if err2 != nil { + return 0, nil, err2 + } + return importinto.SubmitStandaloneTask(ctx, e.controller.Plan, e.stmt, ecp) + } + // if tidb_enable_dist_task=true, we import distributively, otherwise we import on current node. + if variable.EnableDistTask.Load() { + return importinto.SubmitTask(ctx, e.controller.Plan, e.stmt) + } + return importinto.SubmitStandaloneTask(ctx, e.controller.Plan, e.stmt, nil) +} + +// waitTask waits for the task to finish. +// NOTE: WaitTaskDoneOrPaused also return error when task fails. +func (*ImportIntoExec) waitTask(ctx context.Context, jobID int64, task *proto.TaskBase) error { + err := handle.WaitTaskDoneOrPaused(ctx, task.ID) + // when user KILL the connection, the ctx will be canceled, we need to cancel the import job. + if errors.Cause(err) == context.Canceled { + taskManager, err2 := fstorage.GetTaskManager() + if err2 != nil { + return err2 + } + // use background, since ctx is canceled already. + return cancelAndWaitImportJob(context.Background(), taskManager, jobID) + } + return err +} + +func (e *ImportIntoExec) importFromSelect(ctx context.Context) error { + e.dataFilled = true + // must use a new session as: + // - pre-check will execute other sql, the stmt in show processlist will be changed. + // - userSctx might be in stale read, we cannot do write. + newSCtx, err2 := CreateSession(e.userSctx) + if err2 != nil { + return err2 + } + defer CloseSession(newSCtx) + + sqlExec := newSCtx.GetSQLExecutor() + if err2 = e.controller.CheckRequirements(ctx, sqlExec); err2 != nil { + return err2 + } + if err := e.controller.InitTiKVConfigs(ctx, newSCtx); err != nil { + return err + } + + importID := uuid.New().String() + logutil.Logger(ctx).Info("importing data from select statement", + zap.String("import-id", importID), zap.Int("concurrency", e.controller.ThreadCnt), + zap.String("target-table", e.controller.FullTableName()), + zap.Int64("target-table-id", e.controller.TableInfo.ID)) + ti, err2 := importer.NewTableImporter(ctx, e.controller, importID, e.Ctx().GetStore()) + if err2 != nil { + return err2 + } + defer func() { + if err := ti.Close(); err != nil { + logutil.Logger(ctx).Error("close importer failed", zap.Error(err)) + } + }() + selectedRowCh := make(chan importer.QueryRow) + ti.SetSelectedRowCh(selectedRowCh) + + var importResult *importer.JobImportResult + eg, egCtx := errgroup.WithContext(ctx) + eg.Go(func() error { + var err error + importResult, err = ti.ImportSelectedRows(egCtx, newSCtx) + return err + }) + eg.Go(func() error { + defer close(selectedRowCh) + fields := exec.RetTypes(e.selectExec) + var idAllocator int64 + for { + // rows will be consumed concurrently, we cannot use chunk pool in session ctx. 
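(Editorial sketch, not part of the patch.) The import-from-select path above is a producer/consumer pair: one errgroup goroutine drains the SELECT executor chunk by chunk and pushes rows into selectedRowCh, the other imports whatever arrives via ImportSelectedRows, and a failure on either side cancels the shared errgroup context so the other side unblocks. A self-contained sketch of that shape, with plain ints standing in for importer.QueryRow and made-up function names:

    package main

    import (
        "context"
        "errors"
        "fmt"

        "golang.org/x/sync/errgroup"
    )

    func run(ctx context.Context, rows []int) error {
        ch := make(chan int)
        eg, egCtx := errgroup.WithContext(ctx)

        // Consumer: plays the role of ImportSelectedRows.
        eg.Go(func() error {
            for r := range ch {
                if r < 0 {
                    return errors.New("bad row") // cancels egCtx, unblocking the producer
                }
                fmt.Println("imported", r)
            }
            return nil
        })

        // Producer: plays the role of the goroutine draining the SELECT executor.
        eg.Go(func() error {
            defer close(ch) // lets the consumer's range loop finish
            for _, r := range rows {
                select {
                case ch <- r:
                case <-egCtx.Done():
                    return egCtx.Err()
                }
            }
            return nil
        })

        return eg.Wait()
    }

    func main() {
        fmt.Println(run(context.Background(), []int{1, 2, 3}))
    }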
+ chk := exec.NewFirstChunk(e.selectExec) + iter := chunk.NewIterator4Chunk(chk) + err := exec.Next(egCtx, e.selectExec, chk) + if err != nil { + return err + } + if chk.NumRows() == 0 { + break + } + for innerChunkRow := iter.Begin(); innerChunkRow != iter.End(); innerChunkRow = iter.Next() { + idAllocator++ + select { + case selectedRowCh <- importer.QueryRow{ + ID: idAllocator, + Data: innerChunkRow.GetDatumRow(fields), + }: + case <-egCtx.Done(): + return egCtx.Err() + } + } + } + return nil + }) + if err := eg.Wait(); err != nil { + return err + } + + if err2 = importer.FlushTableStats(ctx, newSCtx, e.controller.TableInfo.ID, importResult); err2 != nil { + logutil.Logger(ctx).Error("flush stats failed", zap.Error(err2)) + } + + stmtCtx := e.userSctx.GetSessionVars().StmtCtx + stmtCtx.SetAffectedRows(importResult.Affected) + // TODO: change it after spec is ready. + stmtCtx.SetMessage(fmt.Sprintf("Records: %d, ID: %s", importResult.Affected, importID)) + return nil +} + +// ImportIntoActionExec represents a import into action executor. +type ImportIntoActionExec struct { + exec.BaseExecutor + tp ast.ImportIntoActionTp + jobID int64 +} + +var ( + _ exec.Executor = (*ImportIntoActionExec)(nil) +) + +// Next implements the Executor Next interface. +func (e *ImportIntoActionExec) Next(ctx context.Context, _ *chunk.Chunk) (err error) { + ctx = kv.WithInternalSourceType(ctx, kv.InternalImportInto) + + var hasSuperPriv bool + if pm := privilege.GetPrivilegeManager(e.Ctx()); pm != nil { + hasSuperPriv = pm.RequestVerification(e.Ctx().GetSessionVars().ActiveRoles, "", "", "", mysql.SuperPriv) + } + // we use sessionCtx from GetTaskManager, user ctx might not have enough privileges. + taskManager, err := fstorage.GetTaskManager() + ctx = util.WithInternalSourceType(ctx, kv.InternalDistTask) + if err != nil { + return err + } + if err = e.checkPrivilegeAndStatus(ctx, taskManager, hasSuperPriv); err != nil { + return err + } + + task := log.BeginTask(logutil.Logger(ctx).With(zap.Int64("jobID", e.jobID), + zap.Any("action", e.tp)), "import into action") + defer func() { + task.End(zap.ErrorLevel, err) + }() + return cancelAndWaitImportJob(ctx, taskManager, e.jobID) +} + +func (e *ImportIntoActionExec) checkPrivilegeAndStatus(ctx context.Context, manager *fstorage.TaskManager, hasSuperPriv bool) error { + var info *importer.JobInfo + if err := manager.WithNewSession(func(se sessionctx.Context) error { + exec := se.GetSQLExecutor() + var err2 error + info, err2 = importer.GetJob(ctx, exec, e.jobID, e.Ctx().GetSessionVars().User.String(), hasSuperPriv) + return err2 + }); err != nil { + return err + } + if !info.CanCancel() { + return exeerrors.ErrLoadDataInvalidOperation.FastGenByArgs("CANCEL") + } + return nil +} + +func cancelAndWaitImportJob(ctx context.Context, manager *fstorage.TaskManager, jobID int64) error { + if err := manager.WithNewTxn(ctx, func(se sessionctx.Context) error { + ctx = util.WithInternalSourceType(ctx, kv.InternalDistTask) + return manager.CancelTaskByKeySession(ctx, se, importinto.TaskKey(jobID)) + }); err != nil { + return err + } + return handle.WaitTaskDoneByKey(ctx, importinto.TaskKey(jobID)) +} diff --git a/pkg/executor/importer/binding__failpoint_binding__.go b/pkg/executor/importer/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..62be625525776 --- /dev/null +++ b/pkg/executor/importer/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package importer + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var 
__failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/executor/importer/job.go b/pkg/executor/importer/job.go index dc0dec7b20d90..b79ca87fb337f 100644 --- a/pkg/executor/importer/job.go +++ b/pkg/executor/importer/job.go @@ -221,9 +221,9 @@ func CreateJob( return 0, errors.Errorf("unexpected result length: %d", len(rows)) } - failpoint.Inject("setLastImportJobID", func() { + if _, _err_ := failpoint.Eval(_curpkg_("setLastImportJobID")); _err_ == nil { TestLastImportJobID.Store(rows[0].GetInt64(0)) - }) + } return rows[0].GetInt64(0), nil } diff --git a/pkg/executor/importer/job.go__failpoint_stash__ b/pkg/executor/importer/job.go__failpoint_stash__ new file mode 100644 index 0000000000000..dc0dec7b20d90 --- /dev/null +++ b/pkg/executor/importer/job.go__failpoint_stash__ @@ -0,0 +1,370 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package importer + +import ( + "context" + "encoding/json" + "fmt" + "sync/atomic" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" + "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" + "github.com/pingcap/tidb/pkg/util/sqlexec" + "github.com/tikv/client-go/v2/util" +) + +// vars used for test. +var ( + // TestLastImportJobID last created job id, used in unit test. + TestLastImportJobID atomic.Int64 +) + +// constants for job status and step. +const ( + // JobStatus + // ┌───────┐ ┌───────┐ ┌────────┐ + // │pending├────►│running├───►│finished│ + // └────┬──┘ └────┬──┘ └────────┘ + // │ │ ┌──────┐ + // │ ├──────►│failed│ + // │ │ └──────┘ + // │ │ ┌─────────┐ + // └─────────────┴──────►│cancelled│ + // └─────────┘ + jobStatusPending = "pending" + // JobStatusRunning exported since it's used in show import jobs + JobStatusRunning = "running" + jogStatusCancelled = "cancelled" + jobStatusFailed = "failed" + jobStatusFinished = "finished" + + // when the job is finished, step will be set to none. + jobStepNone = "" + // JobStepGlobalSorting is the first step when using global sort, + // step goes from none -> global-sorting -> importing -> validating -> none. + JobStepGlobalSorting = "global-sorting" + // JobStepImporting is the first step when using local sort, + // step goes from none -> importing -> validating -> none. + // when used in global sort, it means importing the sorted data. + // when used in local sort, it means encode&sort data and then importing the data. 
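(Editorial sketch, not part of the patch.) The diagram and constants above define a small state machine: pending can only move to running or cancelled, running can finish, fail, or be cancelled, and the three right-hand states are terminal. The map and helpers below are illustrative only — the real code enforces these transitions through guarded UPDATE ... WHERE status = ? statements such as StartJob, FinishJob, and FailJob:

    package main

    import "fmt"

    var allowed = map[string][]string{
        "pending": {"running", "cancelled"},
        "running": {"finished", "failed", "cancelled"},
        // finished, failed and cancelled are terminal.
    }

    func canTransition(from, to string) bool {
        for _, s := range allowed[from] {
            if s == to {
                return true
            }
        }
        return false
    }

    // canCancel mirrors JobInfo.CanCancel: only pending or running jobs accept CANCEL.
    func canCancel(status string) bool {
        return status == "pending" || status == "running"
    }

    func main() {
        fmt.Println(canTransition("pending", "running"))  // true
        fmt.Println(canTransition("finished", "running")) // false
        fmt.Println(canCancel("running"))                 // true
    }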
+ JobStepImporting = "importing" + JobStepValidating = "validating" + + baseQuerySQL = `SELECT + id, create_time, start_time, end_time, + table_schema, table_name, table_id, created_by, parameters, source_file_size, + status, step, summary, error_message + FROM mysql.tidb_import_jobs` +) + +// ImportParameters is the parameters for import into statement. +// it's a minimal meta info to store in tidb_import_jobs for diagnose. +// for detailed info, see tidb_global_tasks. +type ImportParameters struct { + ColumnsAndVars string `json:"columns-and-vars,omitempty"` + SetClause string `json:"set-clause,omitempty"` + // for s3 URL, AK/SK is redacted for security + FileLocation string `json:"file-location"` + Format string `json:"format"` + // only include what user specified, not include default value. + Options map[string]any `json:"options,omitempty"` +} + +var _ fmt.Stringer = &ImportParameters{} + +// String implements fmt.Stringer interface. +func (ip *ImportParameters) String() string { + b, _ := json.Marshal(ip) + return string(b) +} + +// JobSummary is the summary info of import into job. +type JobSummary struct { + // ImportedRows is the number of rows imported into TiKV. + ImportedRows uint64 `json:"imported-rows,omitempty"` +} + +// JobInfo is the information of import into job. +type JobInfo struct { + ID int64 + CreateTime types.Time + StartTime types.Time + EndTime types.Time + TableSchema string + TableName string + TableID int64 + CreatedBy string + Parameters ImportParameters + SourceFileSize int64 + Status string + // in SHOW IMPORT JOB, we name it as phase. + // here, we use the same name as in distributed framework. + Step string + // the summary info of the job, it's updated only when the job is finished. + // for running job, we should query the progress from the distributed framework. + Summary *JobSummary + ErrorMessage string +} + +// CanCancel returns whether the job can be cancelled. +func (j *JobInfo) CanCancel() bool { + return j.Status == jobStatusPending || j.Status == JobStatusRunning +} + +// GetJob returns the job with the given id if the user has privilege. +// hasSuperPriv: whether the user has super privilege. +// If the user has super privilege, the user can show or operate all jobs, +// else the user can only show or operate his own jobs. +func GetJob(ctx context.Context, conn sqlexec.SQLExecutor, jobID int64, user string, hasSuperPriv bool) (*JobInfo, error) { + ctx = util.WithInternalSourceType(ctx, kv.InternalImportInto) + + sql := baseQuerySQL + ` WHERE id = %?` + rs, err := conn.ExecuteInternal(ctx, sql, jobID) + if err != nil { + return nil, err + } + defer terror.Call(rs.Close) + rows, err := sqlexec.DrainRecordSet(ctx, rs, 1) + if err != nil { + return nil, err + } + if len(rows) != 1 { + return nil, exeerrors.ErrLoadDataJobNotFound.GenWithStackByArgs(jobID) + } + + info, err := convert2JobInfo(rows[0]) + if err != nil { + return nil, err + } + if !hasSuperPriv && info.CreatedBy != user { + return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("SUPER") + } + return info, nil +} + +// GetActiveJobCnt returns the count of active import jobs. +// Active import jobs include pending and running jobs. +func GetActiveJobCnt(ctx context.Context, conn sqlexec.SQLExecutor, tableSchema, tableName string) (int64, error) { + ctx = util.WithInternalSourceType(ctx, kv.InternalImportInto) + + sql := `select count(1) from mysql.tidb_import_jobs + where status in (%?, %?) + and table_schema = %? 
and table_name = %?; + ` + rs, err := conn.ExecuteInternal(ctx, sql, jobStatusPending, JobStatusRunning, + tableSchema, tableName) + if err != nil { + return 0, err + } + defer terror.Call(rs.Close) + rows, err := sqlexec.DrainRecordSet(ctx, rs, 1) + if err != nil { + return 0, err + } + cnt := rows[0].GetInt64(0) + return cnt, nil +} + +// CreateJob creates import into job by insert a record to system table. +// The AUTO_INCREMENT value will be returned as jobID. +func CreateJob( + ctx context.Context, + conn sqlexec.SQLExecutor, + db, table string, + tableID int64, + user string, + parameters *ImportParameters, + sourceFileSize int64, +) (int64, error) { + bytes, err := json.Marshal(parameters) + if err != nil { + return 0, err + } + ctx = util.WithInternalSourceType(ctx, kv.InternalImportInto) + _, err = conn.ExecuteInternal(ctx, `INSERT INTO mysql.tidb_import_jobs + (table_schema, table_name, table_id, created_by, parameters, source_file_size, status, step) + VALUES (%?, %?, %?, %?, %?, %?, %?, %?);`, + db, table, tableID, user, bytes, sourceFileSize, jobStatusPending, jobStepNone) + if err != nil { + return 0, err + } + rs, err := conn.ExecuteInternal(ctx, `SELECT LAST_INSERT_ID();`) + if err != nil { + return 0, err + } + defer terror.Call(rs.Close) + + rows, err := sqlexec.DrainRecordSet(ctx, rs, 1) + if err != nil { + return 0, err + } + if len(rows) != 1 { + return 0, errors.Errorf("unexpected result length: %d", len(rows)) + } + + failpoint.Inject("setLastImportJobID", func() { + TestLastImportJobID.Store(rows[0].GetInt64(0)) + }) + return rows[0].GetInt64(0), nil +} + +// StartJob tries to start a pending job with jobID, change its status/step to running/input step. +// It will not return error when there's no matched job or the job has already started. +func StartJob(ctx context.Context, conn sqlexec.SQLExecutor, jobID int64, step string) error { + ctx = util.WithInternalSourceType(ctx, kv.InternalImportInto) + _, err := conn.ExecuteInternal(ctx, `UPDATE mysql.tidb_import_jobs + SET update_time = CURRENT_TIMESTAMP(6), start_time = CURRENT_TIMESTAMP(6), status = %?, step = %? + WHERE id = %? AND status = %?;`, + JobStatusRunning, step, jobID, jobStatusPending) + + return err +} + +// Job2Step tries to change the step of a running job with jobID. +// It will not return error when there's no matched job. +func Job2Step(ctx context.Context, conn sqlexec.SQLExecutor, jobID int64, step string) error { + ctx = util.WithInternalSourceType(ctx, kv.InternalImportInto) + _, err := conn.ExecuteInternal(ctx, `UPDATE mysql.tidb_import_jobs + SET update_time = CURRENT_TIMESTAMP(6), step = %? + WHERE id = %? AND status = %?;`, + step, jobID, JobStatusRunning) + + return err +} + +// FinishJob tries to finish a running job with jobID, change its status to finished, clear its step. +// It will not return error when there's no matched job. +func FinishJob(ctx context.Context, conn sqlexec.SQLExecutor, jobID int64, summary *JobSummary) error { + bytes, err := json.Marshal(summary) + if err != nil { + return err + } + ctx = util.WithInternalSourceType(ctx, kv.InternalImportInto) + _, err = conn.ExecuteInternal(ctx, `UPDATE mysql.tidb_import_jobs + SET update_time = CURRENT_TIMESTAMP(6), end_time = CURRENT_TIMESTAMP(6), status = %?, step = %?, summary = %? + WHERE id = %? AND status = %?;`, + jobStatusFinished, jobStepNone, bytes, jobID, JobStatusRunning) + return err +} + +// FailJob fails import into job. A job can only be failed once. +// It will not return error when there's no matched job. 
+func FailJob(ctx context.Context, conn sqlexec.SQLExecutor, jobID int64, errorMsg string) error { + ctx = util.WithInternalSourceType(ctx, kv.InternalImportInto) + _, err := conn.ExecuteInternal(ctx, `UPDATE mysql.tidb_import_jobs + SET update_time = CURRENT_TIMESTAMP(6), end_time = CURRENT_TIMESTAMP(6), status = %?, error_message = %? + WHERE id = %? AND status = %?;`, + jobStatusFailed, errorMsg, jobID, JobStatusRunning) + return err +} + +func convert2JobInfo(row chunk.Row) (*JobInfo, error) { + // start_time, end_time, summary, error_message can be NULL, need to use row.IsNull() to check. + startTime, endTime := types.ZeroTime, types.ZeroTime + if !row.IsNull(2) { + startTime = row.GetTime(2) + } + if !row.IsNull(3) { + endTime = row.GetTime(3) + } + + parameters := ImportParameters{} + parametersStr := row.GetString(8) + if err := json.Unmarshal([]byte(parametersStr), ¶meters); err != nil { + return nil, errors.Trace(err) + } + + var summary *JobSummary + var summaryStr string + if !row.IsNull(12) { + summaryStr = row.GetString(12) + } + if len(summaryStr) > 0 { + summary = &JobSummary{} + if err := json.Unmarshal([]byte(summaryStr), summary); err != nil { + return nil, errors.Trace(err) + } + } + + var errMsg string + if !row.IsNull(13) { + errMsg = row.GetString(13) + } + return &JobInfo{ + ID: row.GetInt64(0), + CreateTime: row.GetTime(1), + StartTime: startTime, + EndTime: endTime, + TableSchema: row.GetString(4), + TableName: row.GetString(5), + TableID: row.GetInt64(6), + CreatedBy: row.GetString(7), + Parameters: parameters, + SourceFileSize: row.GetInt64(9), + Status: row.GetString(10), + Step: row.GetString(11), + Summary: summary, + ErrorMessage: errMsg, + }, nil +} + +// GetAllViewableJobs gets all viewable jobs. +func GetAllViewableJobs(ctx context.Context, conn sqlexec.SQLExecutor, user string, hasSuperPriv bool) ([]*JobInfo, error) { + ctx = util.WithInternalSourceType(ctx, kv.InternalImportInto) + sql := baseQuerySQL + args := []any{} + if !hasSuperPriv { + sql += " WHERE created_by = %?" + args = append(args, user) + } + rs, err := conn.ExecuteInternal(ctx, sql, args...) + if err != nil { + return nil, err + } + defer terror.Call(rs.Close) + rows, err := sqlexec.DrainRecordSet(ctx, rs, 1) + if err != nil { + return nil, err + } + ret := make([]*JobInfo, 0, len(rows)) + for _, row := range rows { + jobInfo, err2 := convert2JobInfo(row) + if err2 != nil { + return nil, err2 + } + ret = append(ret, jobInfo) + } + + return ret, nil +} + +// CancelJob cancels import into job. Only a running/paused job can be canceled. +// check privileges using get before calling this method. +func CancelJob(ctx context.Context, conn sqlexec.SQLExecutor, jobID int64) (err error) { + ctx = util.WithInternalSourceType(ctx, kv.InternalImportInto) + sql := `UPDATE mysql.tidb_import_jobs + SET update_time = CURRENT_TIMESTAMP(6), status = %?, error_message = 'cancelled by user' + WHERE id = %? AND status IN (%?, %?);` + args := []any{jogStatusCancelled, jobID, jobStatusPending, JobStatusRunning} + _, err = conn.ExecuteInternal(ctx, sql, args...) 
+ return err +} diff --git a/pkg/executor/importer/table_import.go b/pkg/executor/importer/table_import.go index 1e5d1ad2b14e6..649d0eff2c512 100644 --- a/pkg/executor/importer/table_import.go +++ b/pkg/executor/importer/table_import.go @@ -684,9 +684,9 @@ func (ti *TableImporter) ImportSelectedRows(ctx context.Context, se sessionctx.C if err != nil { return nil, err } - failpoint.Inject("mockImportFromSelectErr", func() { - failpoint.Return(nil, errors.New("mock import from select error")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("mockImportFromSelectErr")); _err_ == nil { + return nil, errors.New("mock import from select error") + } if err = closedDataEngine.Import(ctx, ti.regionSplitSize, ti.regionSplitKeys); err != nil { if common.ErrFoundDuplicateKeys.Equal(err) { err = local.ConvertToErrFoundConflictRecords(err, ti.encTable) @@ -834,9 +834,9 @@ func VerifyChecksum(ctx context.Context, plan *Plan, localChecksum verify.KVChec } logger.Info("local checksum", zap.Object("checksum", &localChecksum)) - failpoint.Inject("waitCtxDone", func() { + if _, _err_ := failpoint.Eval(_curpkg_("waitCtxDone")); _err_ == nil { <-ctx.Done() - }) + } remoteChecksum, err := checksumTable(ctx, se, plan, logger) if err != nil { @@ -911,9 +911,9 @@ func checksumTable(ctx context.Context, se sessionctx.Context, plan *Plan, logge return errors.New("empty checksum result") } - failpoint.Inject("errWhenChecksum", func() { - failpoint.Return(errors.New("occur an error when checksum, coprocessor task terminated due to exceeding the deadline")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("errWhenChecksum")); _err_ == nil { + return errors.New("occur an error when checksum, coprocessor task terminated due to exceeding the deadline") + } // ADMIN CHECKSUM TABLE . example. // mysql> admin checksum table test.t; diff --git a/pkg/executor/importer/table_import.go__failpoint_stash__ b/pkg/executor/importer/table_import.go__failpoint_stash__ new file mode 100644 index 0000000000000..1e5d1ad2b14e6 --- /dev/null +++ b/pkg/executor/importer/table_import.go__failpoint_stash__ @@ -0,0 +1,983 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package importer + +import ( + "context" + "io" + "math" + "net" + "os" + "path/filepath" + "strconv" + "strings" + "sync" + "time" + "unicode/utf8" + + "github.com/docker/go-units" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/storage" + tidb "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/keyspace" + tidbkv "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/lightning/backend" + "github.com/pingcap/tidb/pkg/lightning/backend/encode" + "github.com/pingcap/tidb/pkg/lightning/backend/kv" + "github.com/pingcap/tidb/pkg/lightning/backend/local" + "github.com/pingcap/tidb/pkg/lightning/checkpoints" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/lightning/metric" + "github.com/pingcap/tidb/pkg/lightning/mydump" + verify "github.com/pingcap/tidb/pkg/lightning/verification" + "github.com/pingcap/tidb/pkg/meta/autoid" + tidbmetrics "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/tables" + tidbutil "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/etcd" + "github.com/pingcap/tidb/pkg/util/mathutil" + "github.com/pingcap/tidb/pkg/util/promutil" + "github.com/pingcap/tidb/pkg/util/sqlexec" + "github.com/pingcap/tidb/pkg/util/sqlkiller" + "github.com/pingcap/tidb/pkg/util/syncutil" + "github.com/prometheus/client_golang/prometheus" + "github.com/tikv/client-go/v2/util" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/multierr" + "go.uber.org/zap" +) + +// NewTiKVModeSwitcher make it a var, so we can mock it in tests. +var NewTiKVModeSwitcher = local.NewTiKVModeSwitcher + +var ( + // CheckDiskQuotaInterval is the default time interval to check disk quota. + // TODO: make it dynamically adjusting according to the speed of import and the disk size. + CheckDiskQuotaInterval = 10 * time.Second + + // defaultMaxEngineSize is the default max engine size in bytes. + // we make it 5 times larger than lightning default engine size to reduce range overlap, especially for index, + // since we have an index engine per distributed subtask. + // for 1TiB data, we can divide it into 2 engines that runs on 2 TiDB. it can have a good balance between + // range overlap and sort speed in one of our test of: + // - 10 columns, PK + 6 secondary index 2 of which is mv index + // - 1.05 KiB per row, 527 MiB per file, 1024000000 rows, 1 TiB total + // + // it might not be the optimal value for other cases. + defaultMaxEngineSize = int64(5 * config.DefaultBatchSize) +) + +// prepareSortDir creates a new directory for import, remove previous sort directory if exists. 
+func prepareSortDir(e *LoadDataController, id string, tidbCfg *tidb.Config) (string, error) { + importDir := GetImportRootDir(tidbCfg) + sortDir := filepath.Join(importDir, id) + + if info, err := os.Stat(importDir); err != nil || !info.IsDir() { + if err != nil && !os.IsNotExist(err) { + e.logger.Error("stat import dir failed", zap.String("import_dir", importDir), zap.Error(err)) + return "", errors.Trace(err) + } + if info != nil && !info.IsDir() { + e.logger.Warn("import dir is not a dir, remove it", zap.String("import_dir", importDir)) + if err := os.RemoveAll(importDir); err != nil { + return "", errors.Trace(err) + } + } + e.logger.Info("import dir not exists, create it", zap.String("import_dir", importDir)) + if err := os.MkdirAll(importDir, 0o700); err != nil { + e.logger.Error("failed to make dir", zap.String("import_dir", importDir), zap.Error(err)) + return "", errors.Trace(err) + } + } + + // todo: remove this after we support checkpoint + if _, err := os.Stat(sortDir); err != nil { + if !os.IsNotExist(err) { + e.logger.Error("stat sort dir failed", zap.String("sort_dir", sortDir), zap.Error(err)) + return "", errors.Trace(err) + } + } else { + e.logger.Warn("sort dir already exists, remove it", zap.String("sort_dir", sortDir)) + if err := os.RemoveAll(sortDir); err != nil { + return "", errors.Trace(err) + } + } + return sortDir, nil +} + +// GetRegionSplitSizeKeys gets the region split size and keys from PD. +func GetRegionSplitSizeKeys(ctx context.Context) (regionSplitSize int64, regionSplitKeys int64, err error) { + tidbCfg := tidb.GetGlobalConfig() + tls, err := common.NewTLS( + tidbCfg.Security.ClusterSSLCA, + tidbCfg.Security.ClusterSSLCert, + tidbCfg.Security.ClusterSSLKey, + "", + nil, nil, nil, + ) + if err != nil { + return 0, 0, err + } + tlsOpt := tls.ToPDSecurityOption() + addrs := strings.Split(tidbCfg.Path, ",") + pdCli, err := NewClientWithContext(ctx, addrs, tlsOpt) + if err != nil { + return 0, 0, errors.Trace(err) + } + defer pdCli.Close() + return local.GetRegionSplitSizeKeys(ctx, pdCli, tls) +} + +// NewTableImporter creates a new table importer. 
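(Editorial sketch, not part of the patch.) prepareSortDir above has to cope with three situations: the import root does not exist yet, the root path is occupied by a plain file, or a stale sort directory from a previous run is still present. A trimmed-down sketch of that ensure-fresh-directory sequence using only the standard library (paths and names here are placeholders, and the error logging of the original is omitted):

    package main

    import (
        "fmt"
        "os"
        "path/filepath"
    )

    // freshDir makes sure root exists as a directory and that root/id starts empty.
    func freshDir(root, id string) (string, error) {
        if info, err := os.Stat(root); err != nil {
            if !os.IsNotExist(err) {
                return "", err
            }
            if err := os.MkdirAll(root, 0o700); err != nil {
                return "", err
            }
        } else if !info.IsDir() {
            // The root path is a regular file: replace it with a directory.
            if err := os.RemoveAll(root); err != nil {
                return "", err
            }
            if err := os.MkdirAll(root, 0o700); err != nil {
                return "", err
            }
        }

        sortDir := filepath.Join(root, id)
        // No checkpoint support yet, so a leftover sort dir from a previous run is discarded.
        if err := os.RemoveAll(sortDir); err != nil {
            return "", err
        }
        return sortDir, nil
    }

    func main() {
        dir, err := freshDir(filepath.Join(os.TempDir(), "import-demo"), "job-1")
        fmt.Println(dir, err)
    }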
+func NewTableImporter( + ctx context.Context, + e *LoadDataController, + id string, + kvStore tidbkv.Storage, +) (ti *TableImporter, err error) { + idAlloc := kv.NewPanickingAllocators(e.Table.Meta().SepAutoInc(), 0) + tbl, err := tables.TableFromMeta(idAlloc, e.Table.Meta()) + if err != nil { + return nil, errors.Annotatef(err, "failed to tables.TableFromMeta %s", e.Table.Meta().Name) + } + + tidbCfg := tidb.GetGlobalConfig() + // todo: we only need to prepare this once on each node(we might call it 3 times in distribution framework) + dir, err := prepareSortDir(e, id, tidbCfg) + if err != nil { + return nil, err + } + + hostPort := net.JoinHostPort("127.0.0.1", strconv.Itoa(int(tidbCfg.Status.StatusPort))) + tls, err := common.NewTLS( + tidbCfg.Security.ClusterSSLCA, + tidbCfg.Security.ClusterSSLCert, + tidbCfg.Security.ClusterSSLKey, + hostPort, + nil, nil, nil, + ) + if err != nil { + return nil, err + } + + backendConfig := e.getLocalBackendCfg(tidbCfg.Path, dir) + d := kvStore.(tidbkv.StorageWithPD).GetPDClient().GetServiceDiscovery() + localBackend, err := local.NewBackend(ctx, tls, backendConfig, d) + if err != nil { + return nil, err + } + + return &TableImporter{ + LoadDataController: e, + id: id, + backend: localBackend, + tableInfo: &checkpoints.TidbTableInfo{ + ID: e.Table.Meta().ID, + Name: e.Table.Meta().Name.O, + Core: e.Table.Meta(), + }, + encTable: tbl, + dbID: e.DBID, + keyspace: kvStore.GetCodec().GetKeyspace(), + logger: e.logger.With(zap.String("import-id", id)), + // this is the value we use for 50TiB data parallel import. + // this might not be the optimal value. + // todo: use different default for single-node import and distributed import. + regionSplitSize: 2 * int64(config.SplitRegionSize), + regionSplitKeys: 2 * int64(config.SplitRegionKeys), + diskQuota: adjustDiskQuota(int64(e.DiskQuota), dir, e.logger), + diskQuotaLock: new(syncutil.RWMutex), + }, nil +} + +// TableImporter is a table importer. +type TableImporter struct { + *LoadDataController + // id is the unique id for this importer. + // it's the task id if we are running in distributed framework, else it's an + // uuid. we use this id to create a unique directory for this importer. + id string + backend *local.Backend + tableInfo *checkpoints.TidbTableInfo + // this table has a separate id allocator used to record the max row id allocated. + encTable table.Table + dbID int64 + + keyspace []byte + logger *zap.Logger + regionSplitSize int64 + regionSplitKeys int64 + diskQuota int64 + diskQuotaLock *syncutil.RWMutex + + rowCh chan QueryRow +} + +// NewTableImporterForTest creates a new table importer for test. 
+func NewTableImporterForTest(ctx context.Context, e *LoadDataController, id string, helper local.StoreHelper) (*TableImporter, error) { + idAlloc := kv.NewPanickingAllocators(e.Table.Meta().SepAutoInc(), 0) + tbl, err := tables.TableFromMeta(idAlloc, e.Table.Meta()) + if err != nil { + return nil, errors.Annotatef(err, "failed to tables.TableFromMeta %s", e.Table.Meta().Name) + } + + tidbCfg := tidb.GetGlobalConfig() + dir, err := prepareSortDir(e, id, tidbCfg) + if err != nil { + return nil, err + } + + backendConfig := e.getLocalBackendCfg(tidbCfg.Path, dir) + localBackend, err := local.NewBackendForTest(ctx, backendConfig, helper) + if err != nil { + return nil, err + } + + return &TableImporter{ + LoadDataController: e, + id: id, + backend: localBackend, + tableInfo: &checkpoints.TidbTableInfo{ + ID: e.Table.Meta().ID, + Name: e.Table.Meta().Name.O, + Core: e.Table.Meta(), + }, + encTable: tbl, + dbID: e.DBID, + logger: e.logger.With(zap.String("import-id", id)), + diskQuotaLock: new(syncutil.RWMutex), + }, nil +} + +// GetKeySpace gets the keyspace of the kv store. +func (ti *TableImporter) GetKeySpace() []byte { + return ti.keyspace +} + +func (ti *TableImporter) getParser(ctx context.Context, chunk *checkpoints.ChunkCheckpoint) (mydump.Parser, error) { + info := LoadDataReaderInfo{ + Opener: func(ctx context.Context) (io.ReadSeekCloser, error) { + reader, err := mydump.OpenReader(ctx, &chunk.FileMeta, ti.dataStore, storage.DecompressConfig{ + ZStdDecodeConcurrency: 1, + }) + if err != nil { + return nil, errors.Trace(err) + } + return reader, nil + }, + Remote: &chunk.FileMeta, + } + parser, err := ti.LoadDataController.GetParser(ctx, info) + if err != nil { + return nil, err + } + if chunk.Chunk.Offset == 0 { + // if data file is split, only the first chunk need to do skip. + // see check in initOptions. + if err = ti.LoadDataController.HandleSkipNRows(parser); err != nil { + return nil, err + } + parser.SetRowID(chunk.Chunk.PrevRowIDMax) + } else { + // if we reached here, the file must be an uncompressed CSV file. + if err = parser.SetPos(chunk.Chunk.Offset, chunk.Chunk.PrevRowIDMax); err != nil { + return nil, err + } + } + return parser, nil +} + +func (ti *TableImporter) getKVEncoder(chunk *checkpoints.ChunkCheckpoint) (KVEncoder, error) { + cfg := &encode.EncodingConfig{ + SessionOptions: encode.SessionOptions{ + SQLMode: ti.SQLMode, + Timestamp: chunk.Timestamp, + SysVars: ti.ImportantSysVars, + AutoRandomSeed: chunk.Chunk.PrevRowIDMax, + }, + Path: chunk.FileMeta.Path, + Table: ti.encTable, + Logger: log.Logger{Logger: ti.logger.With(zap.String("path", chunk.FileMeta.Path))}, + } + return NewTableKVEncoder(cfg, ti) +} + +func (e *LoadDataController) calculateSubtaskCnt() int { + // we want to split data files into subtask of size close to MaxEngineSize to reduce range overlap, + // and evenly distribute them to subtasks. + // we calculate subtask count first by round(TotalFileSize / maxEngineSize) + + // AllocateEngineIDs is using ceil() to calculate subtask count, engine size might be too small in some case, + // such as 501G data, maxEngineSize will be about 250G, so we don't relay on it. 
+ // see https://github.com/pingcap/tidb/blob/b4183e1dc9bb01fb81d3aa79ca4b5b74387c6c2a/br/pkg/lightning/mydump/region.go#L109 + // + // for default e.MaxEngineSize = 500GiB, we have: + // data size range(G) cnt adjusted-engine-size range(G) + // [0, 750) 1 [0, 750) + // [750, 1250) 2 [375, 625) + // [1250, 1750) 3 [416, 583) + // [1750, 2250) 4 [437, 562) + var ( + subtaskCount float64 + maxEngineSize = int64(e.MaxEngineSize) + ) + if e.TotalFileSize <= maxEngineSize { + subtaskCount = 1 + } else { + subtaskCount = math.Round(float64(e.TotalFileSize) / float64(e.MaxEngineSize)) + } + + // for global sort task, since there is no overlap, + // we make sure subtask count is a multiple of execute nodes count + if e.IsGlobalSort() && e.ExecuteNodesCnt > 0 { + subtaskCount = math.Ceil(subtaskCount/float64(e.ExecuteNodesCnt)) * float64(e.ExecuteNodesCnt) + } + return int(subtaskCount) +} + +func (e *LoadDataController) getAdjustedMaxEngineSize() int64 { + subtaskCount := e.calculateSubtaskCnt() + // we adjust MaxEngineSize to make sure each subtask has a similar amount of data to import. + return int64(math.Ceil(float64(e.TotalFileSize) / float64(subtaskCount))) +} + +// SetExecuteNodeCnt sets the execute node count. +func (e *LoadDataController) SetExecuteNodeCnt(cnt int) { + e.ExecuteNodesCnt = cnt +} + +// PopulateChunks populates chunks from table regions. +// in dist framework, this should be done in the tidb node which is responsible for splitting job into subtasks +// then table-importer handles data belongs to the subtask. +func (e *LoadDataController) PopulateChunks(ctx context.Context) (ecp map[int32]*checkpoints.EngineCheckpoint, err error) { + task := log.BeginTask(e.logger, "populate chunks") + defer func() { + task.End(zap.ErrorLevel, err) + }() + + tableMeta := &mydump.MDTableMeta{ + DB: e.DBName, + Name: e.Table.Meta().Name.O, + DataFiles: e.toMyDumpFiles(), + } + adjustedMaxEngineSize := e.getAdjustedMaxEngineSize() + e.logger.Info("adjust max engine size", zap.Int64("before", int64(e.MaxEngineSize)), + zap.Int64("after", adjustedMaxEngineSize)) + dataDivideCfg := &mydump.DataDivideConfig{ + ColumnCnt: len(e.Table.Meta().Columns), + EngineDataSize: adjustedMaxEngineSize, + MaxChunkSize: int64(config.MaxRegionSize), + Concurrency: e.ThreadCnt, + IOWorkers: nil, + Store: e.dataStore, + TableMeta: tableMeta, + + StrictFormat: e.SplitFile, + DataCharacterSet: *e.Charset, + DataInvalidCharReplace: string(utf8.RuneError), + ReadBlockSize: LoadDataReadBlockSize, + CSV: *e.GenerateCSVConfig(), + } + makeEngineCtx := log.NewContext(ctx, log.Logger{Logger: e.logger}) + tableRegions, err2 := mydump.MakeTableRegions(makeEngineCtx, dataDivideCfg) + + if err2 != nil { + e.logger.Error("populate chunks failed", zap.Error(err2)) + return nil, err2 + } + + var maxRowID int64 + timestamp := time.Now().Unix() + tableCp := &checkpoints.TableCheckpoint{ + Engines: map[int32]*checkpoints.EngineCheckpoint{}, + } + for _, region := range tableRegions { + engine, found := tableCp.Engines[region.EngineID] + if !found { + engine = &checkpoints.EngineCheckpoint{ + Status: checkpoints.CheckpointStatusLoaded, + } + tableCp.Engines[region.EngineID] = engine + } + ccp := &checkpoints.ChunkCheckpoint{ + Key: checkpoints.ChunkCheckpointKey{ + Path: region.FileMeta.Path, + Offset: region.Chunk.Offset, + }, + FileMeta: region.FileMeta, + ColumnPermutation: nil, + Chunk: region.Chunk, + Timestamp: timestamp, + } + engine.Chunks = append(engine.Chunks, ccp) + if region.Chunk.RowIDMax > maxRowID { + maxRowID = 
region.Chunk.RowIDMax + } + } + + // Add index engine checkpoint + tableCp.Engines[common.IndexEngineID] = &checkpoints.EngineCheckpoint{Status: checkpoints.CheckpointStatusLoaded} + return tableCp.Engines, nil +} + +// a simplified version of EstimateCompactionThreshold +func (ti *TableImporter) getTotalRawFileSize(indexCnt int64) int64 { + var totalSize int64 + for _, file := range ti.dataFiles { + size := file.RealSize + if file.Type == mydump.SourceTypeParquet { + // parquet file is compressed, thus estimates with a factor of 2 + size *= 2 + } + totalSize += size + } + return totalSize * indexCnt +} + +// OpenIndexEngine opens an index engine. +func (ti *TableImporter) OpenIndexEngine(ctx context.Context, engineID int32) (*backend.OpenedEngine, error) { + idxEngineCfg := &backend.EngineConfig{ + TableInfo: ti.tableInfo, + } + idxCnt := len(ti.tableInfo.Core.Indices) + if !common.TableHasAutoRowID(ti.tableInfo.Core) { + idxCnt-- + } + // todo: getTotalRawFileSize returns size of all data files, but in distributed framework, + // we create one index engine for each engine, should reflect this in the future. + threshold := local.EstimateCompactionThreshold2(ti.getTotalRawFileSize(int64(idxCnt))) + idxEngineCfg.Local = backend.LocalEngineConfig{ + Compact: threshold > 0, + CompactConcurrency: 4, + CompactThreshold: threshold, + BlockSize: 16 * 1024, + } + fullTableName := ti.FullTableName() + // todo: cleanup all engine data on any error since we don't support checkpoint for now + // some return path, didn't make sure all data engine and index engine are cleaned up. + // maybe we can add this in upper level to clean the whole local-sort directory + mgr := backend.MakeEngineManager(ti.backend) + return mgr.OpenEngine(ctx, idxEngineCfg, fullTableName, engineID) +} + +// OpenDataEngine opens a data engine. +func (ti *TableImporter) OpenDataEngine(ctx context.Context, engineID int32) (*backend.OpenedEngine, error) { + dataEngineCfg := &backend.EngineConfig{ + TableInfo: ti.tableInfo, + } + // todo: support checking IsRowOrdered later. + // also see test result here: https://github.com/pingcap/tidb/pull/47147 + //if ti.tableMeta.IsRowOrdered { + // dataEngineCfg.Local.Compact = true + // dataEngineCfg.Local.CompactConcurrency = 4 + // dataEngineCfg.Local.CompactThreshold = local.CompactionUpperThreshold + //} + mgr := backend.MakeEngineManager(ti.backend) + return mgr.OpenEngine(ctx, dataEngineCfg, ti.FullTableName(), engineID) +} + +// ImportAndCleanup imports the engine and cleanup the engine data. +func (ti *TableImporter) ImportAndCleanup(ctx context.Context, closedEngine *backend.ClosedEngine) (int64, error) { + var kvCount int64 + importErr := closedEngine.Import(ctx, ti.regionSplitSize, ti.regionSplitKeys) + if common.ErrFoundDuplicateKeys.Equal(importErr) { + importErr = local.ConvertToErrFoundConflictRecords(importErr, ti.encTable) + } + if closedEngine.GetID() != common.IndexEngineID { + // todo: change to a finer-grain progress later. + // each row is encoded into 1 data key + kvCount = ti.backend.GetImportedKVCount(closedEngine.GetUUID()) + } + cleanupErr := closedEngine.Cleanup(ctx) + return kvCount, multierr.Combine(importErr, cleanupErr) +} + +// Backend returns the backend of the importer. +func (ti *TableImporter) Backend() *local.Backend { + return ti.backend +} + +// Close implements the io.Closer interface. +func (ti *TableImporter) Close() error { + ti.backend.Close() + return nil +} + +// Allocators returns allocators used to record max used ID, i.e. PanickingAllocators. 
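(Editorial sketch, not part of the patch.) The subtask-count arithmetic explained in calculateSubtaskCnt earlier in this file is easy to check numerically: the count is round(TotalFileSize / MaxEngineSize) with a minimum of 1, rounded up to a multiple of the executor node count for global sort, and the adjusted engine size is then ceil(TotalFileSize / count). The standalone sketch below reproduces the worked table from that comment for a 500 GiB max engine size:

    package main

    import (
        "fmt"
        "math"
    )

    const gib = int64(1) << 30

    // subtaskCnt mirrors the rounding described in calculateSubtaskCnt.
    func subtaskCnt(totalSize, maxEngineSize int64, globalSort bool, nodes int) int {
        cnt := 1.0
        if totalSize > maxEngineSize {
            cnt = math.Round(float64(totalSize) / float64(maxEngineSize))
        }
        if globalSort && nodes > 0 {
            cnt = math.Ceil(cnt/float64(nodes)) * float64(nodes)
        }
        return int(cnt)
    }

    func adjustedEngineSize(totalSize int64, cnt int) int64 {
        return int64(math.Ceil(float64(totalSize) / float64(cnt)))
    }

    func main() {
        maxEngine := 500 * gib
        for _, size := range []int64{600 * gib, 800 * gib, 1300 * gib, 1800 * gib} {
            cnt := subtaskCnt(size, maxEngine, false, 0)
            fmt.Printf("%4d GiB -> %d subtasks, ~%d GiB each\n",
                size/gib, cnt, adjustedEngineSize(size, cnt)/gib)
        }
    }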
+func (ti *TableImporter) Allocators() autoid.Allocators { + return ti.encTable.Allocators(nil) +} + +// CheckDiskQuota checks disk quota. +func (ti *TableImporter) CheckDiskQuota(ctx context.Context) { + var locker sync.Locker + lockDiskQuota := func() { + if locker == nil { + ti.diskQuotaLock.Lock() + locker = ti.diskQuotaLock + } + } + unlockDiskQuota := func() { + if locker != nil { + locker.Unlock() + locker = nil + } + } + + defer unlockDiskQuota() + ti.logger.Info("start checking disk quota", zap.String("disk-quota", units.BytesSize(float64(ti.diskQuota)))) + for { + select { + case <-ctx.Done(): + return + case <-time.After(CheckDiskQuotaInterval): + } + + largeEngines, inProgressLargeEngines, totalDiskSize, totalMemSize := local.CheckDiskQuota(ti.backend, ti.diskQuota) + if len(largeEngines) == 0 && inProgressLargeEngines == 0 { + unlockDiskQuota() + continue + } + + ti.logger.Warn("disk quota exceeded", + zap.Int64("diskSize", totalDiskSize), + zap.Int64("memSize", totalMemSize), + zap.Int64("quota", ti.diskQuota), + zap.Int("largeEnginesCount", len(largeEngines)), + zap.Int("inProgressLargeEnginesCount", inProgressLargeEngines)) + + lockDiskQuota() + + if len(largeEngines) == 0 { + ti.logger.Warn("all large engines are already importing, keep blocking all writes") + continue + } + + if err := ti.backend.FlushAllEngines(ctx); err != nil { + ti.logger.Error("flush engine for disk quota failed, check again later", log.ShortError(err)) + unlockDiskQuota() + continue + } + + // at this point, all engines are synchronized on disk. + // we then import every large engines one by one and complete. + // if any engine failed to import, we just try again next time, since the data are still intact. + var importErr error + for _, engine := range largeEngines { + // Use a larger split region size to avoid split the same region by many times. + if err := ti.backend.UnsafeImportAndReset( + ctx, + engine, + int64(config.SplitRegionSize)*int64(config.MaxSplitRegionSizeRatio), + int64(config.SplitRegionKeys)*int64(config.MaxSplitRegionSizeRatio), + ); err != nil { + if common.ErrFoundDuplicateKeys.Equal(err) { + err = local.ConvertToErrFoundConflictRecords(err, ti.encTable) + } + importErr = multierr.Append(importErr, err) + } + } + if importErr != nil { + // discuss: should we return the error and cancel the import? + ti.logger.Error("import large engines failed, check again later", log.ShortError(importErr)) + } + unlockDiskQuota() + } +} + +// SetSelectedRowCh sets the channel to receive selected rows. +func (ti *TableImporter) SetSelectedRowCh(ch chan QueryRow) { + ti.rowCh = ch +} + +func (ti *TableImporter) closeAndCleanupEngine(engine *backend.OpenedEngine) { + // outer context might be done, so we create a new context here. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + closedEngine, err := engine.Close(ctx) + if err != nil { + ti.logger.Error("close engine failed", zap.Error(err)) + return + } + if err = closedEngine.Cleanup(ctx); err != nil { + ti.logger.Error("cleanup engine failed", zap.Error(err)) + } +} + +// ImportSelectedRows imports selected rows. 
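(Editorial sketch, not part of the patch.) CheckDiskQuota above runs as a background loop: wake up every CheckDiskQuotaInterval, see whether any engine has grown past the quota, and if so hold the write lock while oversized engines are flushed and imported. The real implementation keeps the lock held across iterations while the quota is still exceeded (via the locker/lockDiskQuota closure); the minimal sketch below only shows the poll-and-act control flow, with the quota check faked:

    package main

    import (
        "context"
        "fmt"
        "sync"
        "time"
    )

    func watchQuota(ctx context.Context, interval time.Duration, mu *sync.RWMutex, overQuota func() bool) {
        for {
            select {
            case <-ctx.Done():
                return
            case <-time.After(interval):
            }
            if !overQuota() {
                continue
            }
            // Block writers only while space is actually being reclaimed.
            mu.Lock()
            fmt.Println("over quota: flush and import the largest engines")
            mu.Unlock()
        }
    }

    func main() {
        ctx, cancel := context.WithTimeout(context.Background(), 350*time.Millisecond)
        defer cancel()
        var mu sync.RWMutex
        n := 0
        watchQuota(ctx, 100*time.Millisecond, &mu, func() bool { n++; return n%2 == 0 })
    }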
+func (ti *TableImporter) ImportSelectedRows(ctx context.Context, se sessionctx.Context) (*JobImportResult, error) { + var ( + err error + dataEngine, indexEngine *backend.OpenedEngine + ) + metrics := tidbmetrics.GetRegisteredImportMetrics(promutil.NewDefaultFactory(), + prometheus.Labels{ + proto.TaskIDLabelName: ti.id, + }) + ctx = metric.WithCommonMetric(ctx, metrics) + defer func() { + tidbmetrics.UnregisterImportMetrics(metrics) + if dataEngine != nil { + ti.closeAndCleanupEngine(dataEngine) + } + if indexEngine != nil { + ti.closeAndCleanupEngine(indexEngine) + } + }() + + dataEngine, err = ti.OpenDataEngine(ctx, 1) + if err != nil { + return nil, err + } + indexEngine, err = ti.OpenIndexEngine(ctx, common.IndexEngineID) + if err != nil { + return nil, err + } + + var ( + mu sync.Mutex + checksum = verify.NewKVGroupChecksumWithKeyspace(ti.keyspace) + colSizeMap = make(map[int64]int64) + ) + eg, egCtx := tidbutil.NewErrorGroupWithRecoverWithCtx(ctx) + for i := 0; i < ti.ThreadCnt; i++ { + eg.Go(func() error { + chunkCheckpoint := checkpoints.ChunkCheckpoint{} + chunkChecksum := verify.NewKVGroupChecksumWithKeyspace(ti.keyspace) + progress := NewProgress() + defer func() { + mu.Lock() + defer mu.Unlock() + checksum.Add(chunkChecksum) + for k, v := range progress.GetColSize() { + colSizeMap[k] += v + } + }() + return ProcessChunk(egCtx, &chunkCheckpoint, ti, dataEngine, indexEngine, progress, ti.logger, chunkChecksum) + }) + } + if err = eg.Wait(); err != nil { + return nil, err + } + + closedDataEngine, err := dataEngine.Close(ctx) + if err != nil { + return nil, err + } + failpoint.Inject("mockImportFromSelectErr", func() { + failpoint.Return(nil, errors.New("mock import from select error")) + }) + if err = closedDataEngine.Import(ctx, ti.regionSplitSize, ti.regionSplitKeys); err != nil { + if common.ErrFoundDuplicateKeys.Equal(err) { + err = local.ConvertToErrFoundConflictRecords(err, ti.encTable) + } + return nil, err + } + dataKVCount := ti.backend.GetImportedKVCount(closedDataEngine.GetUUID()) + + closedIndexEngine, err := indexEngine.Close(ctx) + if err != nil { + return nil, err + } + if err = closedIndexEngine.Import(ctx, ti.regionSplitSize, ti.regionSplitKeys); err != nil { + if common.ErrFoundDuplicateKeys.Equal(err) { + err = local.ConvertToErrFoundConflictRecords(err, ti.encTable) + } + return nil, err + } + + allocators := ti.Allocators() + maxIDs := map[autoid.AllocatorType]int64{ + autoid.RowIDAllocType: allocators.Get(autoid.RowIDAllocType).Base(), + autoid.AutoIncrementType: allocators.Get(autoid.AutoIncrementType).Base(), + autoid.AutoRandomType: allocators.Get(autoid.AutoRandomType).Base(), + } + if err = PostProcess(ctx, se, maxIDs, ti.Plan, checksum, ti.logger); err != nil { + return nil, err + } + + return &JobImportResult{ + Affected: uint64(dataKVCount), + ColSizeMap: colSizeMap, + }, nil +} + +func adjustDiskQuota(diskQuota int64, sortDir string, logger *zap.Logger) int64 { + sz, err := common.GetStorageSize(sortDir) + if err != nil { + logger.Warn("failed to get storage size", zap.Error(err)) + if diskQuota != 0 { + return diskQuota + } + logger.Info("use default quota instead", zap.Int64("quota", int64(DefaultDiskQuota))) + return int64(DefaultDiskQuota) + } + + maxDiskQuota := int64(float64(sz.Capacity) * 0.8) + switch { + case diskQuota == 0: + logger.Info("use 0.8 of the storage size as default disk quota", + zap.String("quota", units.HumanSize(float64(maxDiskQuota)))) + return maxDiskQuota + case diskQuota > maxDiskQuota: + logger.Warn("disk quota is 
larger than 0.8 of the storage size, use 0.8 of the storage size instead", + zap.String("quota", units.HumanSize(float64(maxDiskQuota)))) + return maxDiskQuota + default: + return diskQuota + } +} + +// PostProcess does the post-processing for the task. +// exported for testing. +func PostProcess( + ctx context.Context, + se sessionctx.Context, + maxIDs map[autoid.AllocatorType]int64, + plan *Plan, + localChecksum *verify.KVGroupChecksum, + logger *zap.Logger, +) (err error) { + callLog := log.BeginTask(logger.With(zap.Object("checksum", localChecksum)), "post process") + defer func() { + callLog.End(zap.ErrorLevel, err) + }() + + if err = RebaseAllocatorBases(ctx, se.GetStore(), maxIDs, plan, logger); err != nil { + return err + } + + return VerifyChecksum(ctx, plan, localChecksum.MergedChecksum(), se, logger) +} + +type autoIDRequirement struct { + store tidbkv.Storage + autoidCli *autoid.ClientDiscover +} + +func (r *autoIDRequirement) Store() tidbkv.Storage { + return r.store +} + +func (r *autoIDRequirement) AutoIDClient() *autoid.ClientDiscover { + return r.autoidCli +} + +// RebaseAllocatorBases rebase the allocator bases. +func RebaseAllocatorBases(ctx context.Context, kvStore tidbkv.Storage, maxIDs map[autoid.AllocatorType]int64, plan *Plan, logger *zap.Logger) (err error) { + callLog := log.BeginTask(logger, "rebase allocators") + defer func() { + callLog.End(zap.ErrorLevel, err) + }() + + if !common.TableHasAutoID(plan.DesiredTableInfo) { + return nil + } + + tidbCfg := tidb.GetGlobalConfig() + hostPort := net.JoinHostPort("127.0.0.1", strconv.Itoa(int(tidbCfg.Status.StatusPort))) + tls, err2 := common.NewTLS( + tidbCfg.Security.ClusterSSLCA, + tidbCfg.Security.ClusterSSLCert, + tidbCfg.Security.ClusterSSLKey, + hostPort, + nil, nil, nil, + ) + if err2 != nil { + return err2 + } + + addrs := strings.Split(tidbCfg.Path, ",") + etcdCli, err := clientv3.New(clientv3.Config{ + Endpoints: addrs, + AutoSyncInterval: 30 * time.Second, + TLS: tls.TLSConfig(), + }) + if err != nil { + return errors.Trace(err) + } + etcd.SetEtcdCliByNamespace(etcdCli, keyspace.MakeKeyspaceEtcdNamespace(kvStore.GetCodec())) + autoidCli := autoid.NewClientDiscover(etcdCli) + r := autoIDRequirement{store: kvStore, autoidCli: autoidCli} + err = common.RebaseTableAllocators(ctx, maxIDs, &r, plan.DBID, plan.DesiredTableInfo) + if err1 := etcdCli.Close(); err1 != nil { + logger.Info("close etcd client error", zap.Error(err1)) + } + autoidCli.ResetConn(nil) + return errors.Trace(err) +} + +// VerifyChecksum verify the checksum of the table. 
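(Editorial sketch, not part of the patch.) adjustDiskQuota above reduces to a clamp: an unset quota defaults to 80% of the sort directory's capacity, and an explicit quota is capped at the same 80% ceiling. A small sketch of just that arithmetic, with the capacity passed in directly instead of being read from the filesystem:

    package main

    import "fmt"

    func clampQuota(requested, capacity int64) int64 {
        limit := int64(float64(capacity) * 0.8)
        switch {
        case requested == 0: // unset: default to 80% of the disk
            return limit
        case requested > limit: // too large: cap at 80% of the disk
            return limit
        default:
            return requested
        }
    }

    func main() {
        const gib = int64(1) << 30
        fmt.Println(clampQuota(0, 100*gib)/gib)       // 80
        fmt.Println(clampQuota(200*gib, 100*gib)/gib) // 80
        fmt.Println(clampQuota(30*gib, 100*gib)/gib)  // 30
    }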
+func VerifyChecksum(ctx context.Context, plan *Plan, localChecksum verify.KVChecksum, se sessionctx.Context, logger *zap.Logger) error { + if plan.Checksum == config.OpLevelOff { + return nil + } + logger.Info("local checksum", zap.Object("checksum", &localChecksum)) + + failpoint.Inject("waitCtxDone", func() { + <-ctx.Done() + }) + + remoteChecksum, err := checksumTable(ctx, se, plan, logger) + if err != nil { + if plan.Checksum != config.OpLevelOptional { + return err + } + logger.Warn("checksumTable failed, will skip this error and go on", zap.Error(err)) + } + if remoteChecksum != nil { + if !remoteChecksum.IsEqual(&localChecksum) { + err2 := common.ErrChecksumMismatch.GenWithStackByArgs( + remoteChecksum.Checksum, localChecksum.Sum(), + remoteChecksum.TotalKVs, localChecksum.SumKVS(), + remoteChecksum.TotalBytes, localChecksum.SumSize(), + ) + if plan.Checksum == config.OpLevelOptional { + logger.Warn("verify checksum failed, but checksum is optional, will skip it", zap.Error(err2)) + err2 = nil + } + return err2 + } + logger.Info("checksum pass", zap.Object("local", &localChecksum)) + } + return nil +} + +func checksumTable(ctx context.Context, se sessionctx.Context, plan *Plan, logger *zap.Logger) (*local.RemoteChecksum, error) { + var ( + tableName = common.UniqueTable(plan.DBName, plan.TableInfo.Name.L) + sql = "ADMIN CHECKSUM TABLE " + tableName + maxErrorRetryCount = 3 + distSQLScanConcurrencyFactor = 1 + remoteChecksum *local.RemoteChecksum + txnErr error + doneCh = make(chan struct{}) + ) + checkCtx, cancel := context.WithCancel(ctx) + defer func() { + cancel() + <-doneCh + }() + + go func() { + <-checkCtx.Done() + se.GetSessionVars().SQLKiller.SendKillSignal(sqlkiller.QueryInterrupted) + close(doneCh) + }() + + distSQLScanConcurrencyBak := se.GetSessionVars().DistSQLScanConcurrency() + defer func() { + se.GetSessionVars().SetDistSQLScanConcurrency(distSQLScanConcurrencyBak) + }() + ctx = util.WithInternalSourceType(checkCtx, tidbkv.InternalImportInto) + for i := 0; i < maxErrorRetryCount; i++ { + txnErr = func() error { + // increase backoff weight + if err := setBackoffWeight(se, plan, logger); err != nil { + logger.Warn("set tidb_backoff_weight failed", zap.Error(err)) + } + + newConcurrency := mathutil.Max(plan.DistSQLScanConcurrency/distSQLScanConcurrencyFactor, local.MinDistSQLScanConcurrency) + logger.Info("checksum with adjusted distsql scan concurrency", zap.Int("concurrency", newConcurrency)) + se.GetSessionVars().SetDistSQLScanConcurrency(newConcurrency) + + // TODO: add resource group name + + rs, err := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), sql) + if err != nil { + return err + } + if len(rs) < 1 { + return errors.New("empty checksum result") + } + + failpoint.Inject("errWhenChecksum", func() { + failpoint.Return(errors.New("occur an error when checksum, coprocessor task terminated due to exceeding the deadline")) + }) + + // ADMIN CHECKSUM TABLE .
example. + // mysql> admin checksum table test.t; + // +---------+------------+---------------------+-----------+-------------+ + // | Db_name | Table_name | Checksum_crc64_xor | Total_kvs | Total_bytes | + // +---------+------------+---------------------+-----------+-------------+ + // | test | t | 8520875019404689597 | 7296873 | 357601387 | + // +---------+------------+------------- + remoteChecksum = &local.RemoteChecksum{ + Schema: rs[0].GetString(0), + Table: rs[0].GetString(1), + Checksum: rs[0].GetUint64(2), + TotalKVs: rs[0].GetUint64(3), + TotalBytes: rs[0].GetUint64(4), + } + return nil + }() + if !common.IsRetryableError(txnErr) { + break + } + distSQLScanConcurrencyFactor *= 2 + logger.Warn("retry checksum table", zap.Int("retry count", i+1), zap.Error(txnErr)) + } + return remoteChecksum, txnErr +} + +func setBackoffWeight(se sessionctx.Context, plan *Plan, logger *zap.Logger) error { + backoffWeight := local.DefaultBackoffWeight + if val, ok := plan.ImportantSysVars[variable.TiDBBackOffWeight]; ok { + if weight, err := strconv.Atoi(val); err == nil && weight > backoffWeight { + backoffWeight = weight + } + } + logger.Info("set backoff weight", zap.Int("weight", backoffWeight)) + return se.GetSessionVars().SetSystemVar(variable.TiDBBackOffWeight, strconv.Itoa(backoffWeight)) +} + +// GetImportRootDir returns the root directory for import. +// The directory structure is like: +// +// -> /path/to/tidb-tmpdir +// -> import-4000 +// -> 1 +// -> some-uuid +// +// exported for testing. +func GetImportRootDir(tidbCfg *tidb.Config) string { + sortPathSuffix := "import-" + strconv.Itoa(int(tidbCfg.Port)) + return filepath.Join(tidbCfg.TempDir, sortPathSuffix) +} + +// FlushTableStats flushes the stats of the table. +// stats will be flushed in domain.updateStatsWorker, default interval is [1, 2) minutes, +// see DumpStatsDeltaToKV for more details. then the background analyzer will analyze +// the table. +// the stats stay in memory until the next flush, so it might be lost if the tidb-server restarts. +func FlushTableStats(ctx context.Context, se sessionctx.Context, tableID int64, result *JobImportResult) error { + if err := sessiontxn.NewTxn(ctx, se); err != nil { + return err + } + sessionVars := se.GetSessionVars() + sessionVars.TxnCtxMu.Lock() + defer sessionVars.TxnCtxMu.Unlock() + sessionVars.TxnCtx.UpdateDeltaForTable(tableID, int64(result.Affected), int64(result.Affected), result.ColSizeMap) + se.StmtCommit(ctx) + return se.CommitTxn(ctx) +} diff --git a/pkg/executor/index_merge_reader.go b/pkg/executor/index_merge_reader.go index 84261164a5925..2034dac4b2eea 100644 --- a/pkg/executor/index_merge_reader.go +++ b/pkg/executor/index_merge_reader.go @@ -326,12 +326,12 @@ func (e *IndexMergeReaderExecutor) startIndexMergeProcessWorker(ctx context.Cont } func (e *IndexMergeReaderExecutor) startPartialIndexWorker(ctx context.Context, exitCh <-chan struct{}, fetchCh chan<- *indexMergeTableTask, workID int) error { - failpoint.Inject("testIndexMergeResultChCloseEarly", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("testIndexMergeResultChCloseEarly")); _err_ == nil { // Wait for processWorker to close resultCh. time.Sleep(time.Second * 2) // Should use fetchCh instead of resultCh to send error. 
syncErr(ctx, e.finished, fetchCh, errors.New("testIndexMergeResultChCloseEarly")) - }) + } if e.RuntimeStats() != nil { collExec := true e.dagPBs[workID].CollectExecutionSummaries = &collExec @@ -343,9 +343,9 @@ func (e *IndexMergeReaderExecutor) startPartialIndexWorker(ctx context.Context, } else { keyRanges = [][]kv.KeyRange{e.keyRanges[workID]} } - failpoint.Inject("startPartialIndexWorkerErr", func() error { + if _, _err_ := failpoint.Eval(_curpkg_("startPartialIndexWorkerErr")); _err_ == nil { return errors.New("inject an error before start partialIndexWorker") - }) + } // for union case, the push-downLimit can be utilized to limit index fetched handles. // for intersection case, the push-downLimit can only be conducted after all index path/table finished. @@ -359,7 +359,7 @@ func (e *IndexMergeReaderExecutor) startPartialIndexWorker(ctx context.Context, defer e.idxWorkerWg.Done() util.WithRecovery( func() { - failpoint.Inject("testIndexMergePanicPartialIndexWorker", nil) + failpoint.Eval(_curpkg_("testIndexMergePanicPartialIndexWorker")) is := e.partialPlans[workID][0].(*plannercore.PhysicalIndexScan) worker := &partialIndexWorker{ stats: e.stats, @@ -443,7 +443,7 @@ func (e *IndexMergeReaderExecutor) startPartialIndexWorker(ctx context.Context, return } results = append(results, result) - failpoint.Inject("testIndexMergePartialIndexWorkerCoprLeak", nil) + failpoint.Eval(_curpkg_("testIndexMergePartialIndexWorkerCoprLeak")) } if len(results) > 1 && len(e.byItems) != 0 { // e.Schema() not the output schema for partialIndexReader, and we put byItems related column at first in `buildIndexReq`, so use nil here. @@ -486,7 +486,7 @@ func (e *IndexMergeReaderExecutor) startPartialTableWorker(ctx context.Context, defer e.idxWorkerWg.Done() util.WithRecovery( func() { - failpoint.Inject("testIndexMergePanicPartialTableWorker", nil) + failpoint.Eval(_curpkg_("testIndexMergePanicPartialTableWorker")) var err error partialTableReader := &TableReaderExecutor{ BaseExecutorV2: exec.NewBaseExecutorV2(e.Ctx().GetSessionVars(), ts.Schema(), e.getPartitalPlanID(workID)), @@ -558,7 +558,7 @@ func (e *IndexMergeReaderExecutor) startPartialTableWorker(ctx context.Context, syncErr(ctx, e.finished, fetchCh, err) break } - failpoint.Inject("testIndexMergePartialTableWorkerCoprLeak", nil) + failpoint.Eval(_curpkg_("testIndexMergePartialTableWorkerCoprLeak")) tableReaderClosed = false worker.batchSize = e.MaxChunkSize() if worker.batchSize > worker.maxBatchSize { @@ -705,9 +705,9 @@ func (w *partialTableWorker) extractTaskHandles(ctx context.Context, chk *chunk. w.tableReader.RuntimeStats().Record(time.Since(start), chk.NumRows()) } if chk.NumRows() == 0 { - failpoint.Inject("testIndexMergeErrorPartialTableWorker", func(v failpoint.Value) { - failpoint.Return(handles, nil, errors.New(v.(string))) - }) + if v, _err_ := failpoint.Eval(_curpkg_("testIndexMergeErrorPartialTableWorker")); _err_ == nil { + return handles, nil, errors.New(v.(string)) + } return handles, retChk, nil } memDelta := chk.MemoryUsage() @@ -851,12 +851,12 @@ func (e *IndexMergeReaderExecutor) Next(ctx context.Context, req *chunk.Chunk) e } func (e *IndexMergeReaderExecutor) getResultTask(ctx context.Context) (*indexMergeTableTask, error) { - failpoint.Inject("testIndexMergeMainReturnEarly", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("testIndexMergeMainReturnEarly")); _err_ == nil { // To make sure processWorker make resultCh to be full. 
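The hunks above are the mechanical rewrite applied to failpoints throughout this file: every failpoint.Inject callback becomes an explicit failpoint.Eval check, where _curpkg_ is a generated helper that prefixes the failpoint name with the package path and _err_ tells whether the failpoint is currently enabled. A small self-contained example of the Eval-style form, using a made-up failpoint path ("example/mockZero") instead of the generated _curpkg_ binding:

package main

import (
	"fmt"

	"github.com/pingcap/failpoint"
)

// compute guards a code path with a failpoint in the rewritten (Eval) style.
// While the failpoint is disabled, Eval returns an error and the normal path runs.
func compute() int {
	if val, _err_ := failpoint.Eval("example/mockZero"); _err_ == nil {
		return val.(int) // injected value from the "return(0)" term
	}
	return 42
}

func main() {
	fmt.Println(compute()) // 42: failpoint disabled
	if err := failpoint.Enable("example/mockZero", "return(0)"); err != nil {
		panic(err)
	}
	fmt.Println(compute()) // 0: failpoint active, value injected
	_ = failpoint.Disable("example/mockZero")
}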
// When main goroutine close finished, processWorker may be stuck when writing resultCh. time.Sleep(time.Second * 20) - failpoint.Return(nil, errors.New("failpoint testIndexMergeMainReturnEarly")) - }) + return nil, errors.New("failpoint testIndexMergeMainReturnEarly") + } if e.resultCurr != nil && e.resultCurr.cursor < len(e.resultCurr.rows) { return e.resultCurr, nil } @@ -1154,9 +1154,9 @@ func (w *indexMergeProcessWorker) fetchLoopUnionWithOrderBy(ctx context.Context, case <-finished: return case resultCh <- task: - failpoint.Inject("testCancelContext", func() { + if _, _err_ := failpoint.Eval(_curpkg_("testCancelContext")); _err_ == nil { IndexMergeCancelFuncForTest() - }) + } select { case <-ctx.Done(): return @@ -1190,14 +1190,14 @@ func pushedLimitCountingDown(pushedLimit *plannercore.PushedDownLimit, handles [ func (w *indexMergeProcessWorker) fetchLoopUnion(ctx context.Context, fetchCh <-chan *indexMergeTableTask, workCh chan<- *indexMergeTableTask, resultCh chan<- *indexMergeTableTask, finished <-chan struct{}) { - failpoint.Inject("testIndexMergeResultChCloseEarly", func(_ failpoint.Value) { - failpoint.Return() - }) + if _, _err_ := failpoint.Eval(_curpkg_("testIndexMergeResultChCloseEarly")); _err_ == nil { + return + } memTracker := memory.NewTracker(w.indexMerge.ID(), -1) memTracker.AttachTo(w.indexMerge.memTracker) defer memTracker.Detach() defer close(workCh) - failpoint.Inject("testIndexMergePanicProcessWorkerUnion", nil) + failpoint.Eval(_curpkg_("testIndexMergePanicProcessWorkerUnion")) var pushedLimit *plannercore.PushedDownLimit if w.indexMerge.pushedLimit != nil { @@ -1267,23 +1267,23 @@ func (w *indexMergeProcessWorker) fetchLoopUnion(ctx context.Context, fetchCh <- if w.stats != nil { w.stats.IndexMergeProcess += time.Since(start) } - failpoint.Inject("testIndexMergeProcessWorkerUnionHang", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("testIndexMergeProcessWorkerUnionHang")); _err_ == nil { for i := 0; i < cap(resultCh); i++ { select { case resultCh <- &indexMergeTableTask{}: default: } } - }) + } select { case <-ctx.Done(): return case <-finished: return case resultCh <- task: - failpoint.Inject("testCancelContext", func() { + if _, _err_ := failpoint.Eval(_curpkg_("testCancelContext")); _err_ == nil { IndexMergeCancelFuncForTest() - }) + } select { case <-ctx.Done(): return @@ -1377,7 +1377,7 @@ func (w *intersectionProcessWorker) consumeMemDelta() { // doIntersectionPerPartition fetch all the task from workerChannel, and after that, then do the intersection pruning, which // will cause wasting a lot of time waiting for all the fetch task done. 
func (w *intersectionProcessWorker) doIntersectionPerPartition(ctx context.Context, workCh chan<- *indexMergeTableTask, resultCh chan<- *indexMergeTableTask, finished, limitDone <-chan struct{}) { - failpoint.Inject("testIndexMergePanicPartitionTableIntersectionWorker", nil) + failpoint.Eval(_curpkg_("testIndexMergePanicPartitionTableIntersectionWorker")) defer w.memTracker.Detach() for task := range w.workerCh { @@ -1419,7 +1419,7 @@ func (w *intersectionProcessWorker) doIntersectionPerPartition(ctx context.Conte if w.rowDelta >= int64(w.batchSize) { w.consumeMemDelta() } - failpoint.Inject("testIndexMergeIntersectionWorkerPanic", nil) + failpoint.Eval(_curpkg_("testIndexMergeIntersectionWorkerPanic")) } if w.rowDelta > 0 { w.consumeMemDelta() @@ -1460,7 +1460,7 @@ func (w *intersectionProcessWorker) doIntersectionPerPartition(ctx context.Conte zap.Int("parTblIdx", parTblIdx), zap.Int("task.handles", len(task.handles))) } } - failpoint.Inject("testIndexMergeProcessWorkerIntersectionHang", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("testIndexMergeProcessWorkerIntersectionHang")); _err_ == nil { if resultCh != nil { for i := 0; i < cap(resultCh); i++ { select { @@ -1469,7 +1469,7 @@ func (w *intersectionProcessWorker) doIntersectionPerPartition(ctx context.Conte } } } - }) + } for _, task := range tasks { select { case <-ctx.Done(): @@ -1534,7 +1534,7 @@ func (w *indexMergeProcessWorker) fetchLoopIntersection(ctx context.Context, fet }() } - failpoint.Inject("testIndexMergePanicProcessWorkerIntersection", nil) + failpoint.Eval(_curpkg_("testIndexMergePanicProcessWorkerIntersection")) // One goroutine may handle one or multiple partitions. // Max number of partition number is 8192, we use ExecutorConcurrency to avoid too many goroutines. @@ -1548,12 +1548,12 @@ func (w *indexMergeProcessWorker) fetchLoopIntersection(ctx context.Context, fet partCnt = len(w.indexMerge.prunedPartitions) } workerCnt := min(partCnt, maxWorkerCnt) - failpoint.Inject("testIndexMergeIntersectionConcurrency", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("testIndexMergeIntersectionConcurrency")); _err_ == nil { con := val.(int) if con != workerCnt { panic(fmt.Sprintf("unexpected workerCnt, expect %d, got %d", con, workerCnt)) } - }) + } partitionIDMap := make(map[int64]int) if w.indexMerge.hasGlobalIndex { @@ -1803,9 +1803,9 @@ func (w *partialIndexWorker) extractTaskHandles(ctx context.Context, chk *chunk. w.sc.GetSessionVars().StmtCtx.RuntimeStatsColl.GetBasicRuntimeStats(w.idxID).Record(time.Since(start), chk.NumRows()) } if chk.NumRows() == 0 { - failpoint.Inject("testIndexMergeErrorPartialIndexWorker", func(v failpoint.Value) { - failpoint.Return(handles, nil, errors.New(v.(string))) - }) + if v, _err_ := failpoint.Eval(_curpkg_("testIndexMergeErrorPartialIndexWorker")); _err_ == nil { + return handles, nil, errors.New(v.(string)) + } return handles, retChk, nil } memDelta := chk.MemoryUsage() @@ -1891,7 +1891,7 @@ func (w *indexMergeTableScanWorker) pickAndExecTask(ctx context.Context, task ** } // Make sure panic failpoint is after fetch task from workCh. // Otherwise, cannot send error to task.doneCh. 
- failpoint.Inject("testIndexMergePanicTableScanWorker", nil) + failpoint.Eval(_curpkg_("testIndexMergePanicTableScanWorker")) execStart := time.Now() err := w.executeTask(ctx, *task) if w.stats != nil { @@ -1899,7 +1899,7 @@ func (w *indexMergeTableScanWorker) pickAndExecTask(ctx context.Context, task ** atomic.AddInt64(&w.stats.FetchRow, int64(time.Since(execStart))) atomic.AddInt64(&w.stats.TableTaskNum, 1) } - failpoint.Inject("testIndexMergePickAndExecTaskPanic", nil) + failpoint.Eval(_curpkg_("testIndexMergePickAndExecTaskPanic")) select { case <-ctx.Done(): return diff --git a/pkg/executor/index_merge_reader.go__failpoint_stash__ b/pkg/executor/index_merge_reader.go__failpoint_stash__ new file mode 100644 index 0000000000000..84261164a5925 --- /dev/null +++ b/pkg/executor/index_merge_reader.go__failpoint_stash__ @@ -0,0 +1,2056 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "bytes" + "cmp" + "container/heap" + "context" + "fmt" + "runtime/trace" + "slices" + "sort" + "sync" + "sync/atomic" + "time" + "unsafe" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/distsql" + "github.com/pingcap/tidb/pkg/executor/internal/builder" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/planner/core/base" + plannerutil "github.com/pingcap/tidb/pkg/planner/util" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/channel" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/ranger" + "github.com/pingcap/tipb/go-tipb" + "go.uber.org/zap" +) + +var ( + _ exec.Executor = &IndexMergeReaderExecutor{} + + // IndexMergeCancelFuncForTest is used just for test + IndexMergeCancelFuncForTest func() +) + +const ( + partialIndexWorkerType = "IndexMergePartialIndexWorker" + partialTableWorkerType = "IndexMergePartialTableWorker" + processWorkerType = "IndexMergeProcessWorker" + partTblIntersectionWorkerType = "IndexMergePartTblIntersectionWorker" + tableScanWorkerType = "IndexMergeTableScanWorker" +) + +// IndexMergeReaderExecutor accesses a table with multiple index/table scan. +// There are three types of workers: +// 1. partialTableWorker/partialIndexWorker, which are used to fetch the handles +// 2. indexMergeProcessWorker, which is used to do the `Union` operation. +// 3. indexMergeTableScanWorker, which is used to get the table tuples with the given handles. 
+// +// The execution flow is really like IndexLookUpReader. However, it uses multiple index scans +// or table scans to get the handles: +// 1. use the partialTableWorkers and partialIndexWorkers to fetch the handles (a batch per time) +// and send them to the indexMergeProcessWorker. +// 2. indexMergeProcessWorker do the `Union` operation for a batch of handles it have got. +// For every handle in the batch: +// 1. check whether it has been accessed. +// 2. if not, record it and send it to the indexMergeTableScanWorker. +// 3. if accessed, just ignore it. +type IndexMergeReaderExecutor struct { + exec.BaseExecutor + indexUsageReporter *exec.IndexUsageReporter + + table table.Table + indexes []*model.IndexInfo + descs []bool + ranges [][]*ranger.Range + dagPBs []*tipb.DAGRequest + startTS uint64 + tableRequest *tipb.DAGRequest + + keepOrder bool + pushedLimit *plannercore.PushedDownLimit + byItems []*plannerutil.ByItems + + // columns are only required by union scan. + columns []*model.ColumnInfo + // partitionIDMap are only required by union scan with global index. + partitionIDMap map[int64]struct{} + *dataReaderBuilder + + // fields about accessing partition tables + partitionTableMode bool // if this IndexMerge is accessing a partition table + prunedPartitions []table.PhysicalTable // pruned partition tables need to access + partitionKeyRanges [][][]kv.KeyRange // [partialIndex][partitionIdx][ranges] + + // All fields above are immutable. + + tblWorkerWg sync.WaitGroup + idxWorkerWg sync.WaitGroup + processWorkerWg sync.WaitGroup + finished chan struct{} + + workerStarted bool + keyRanges [][]kv.KeyRange + + resultCh chan *indexMergeTableTask + resultCurr *indexMergeTableTask + + // memTracker is used to track the memory usage of this executor. + memTracker *memory.Tracker + + partialPlans [][]base.PhysicalPlan + tblPlans []base.PhysicalPlan + partialNetDataSizes []float64 + dataAvgRowSize float64 + + handleCols plannerutil.HandleCols + stats *IndexMergeRuntimeStat + + // Indicates whether there is correlated column in filter or table/index range. + // We need to refresh dagPBs before send DAGReq to storage. + isCorColInPartialFilters []bool + isCorColInTableFilter bool + isCorColInPartialAccess []bool + + // Whether it's intersection or union. + isIntersection bool + + hasGlobalIndex bool +} + +type indexMergeTableTask struct { + lookupTableTask + + // parTblIdx are only used in indexMergeProcessWorker.fetchLoopIntersection. + parTblIdx int + + // partialPlanID are only used for indexMergeProcessWorker.fetchLoopUnionWithOrderBy. + partialPlanID int +} + +// Table implements the dataSourceExecutor interface. 
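A stripped-down sketch of the pipeline described in the doc comment above: several partial workers feed batches of handles into a single process worker, which performs the union (de-duplication) and hands batches on to the table-scan stage. Channel names and the int handles are purely illustrative:

package main

import (
	"fmt"
	"sync"
)

func main() {
	fetchCh := make(chan []int, 4) // partial workers -> process worker
	workCh := make(chan []int, 4)  // process worker -> table scan workers

	var producers sync.WaitGroup
	for _, batch := range [][]int{{1, 3, 5, 7}, {3, 4, 5, 6}} {
		producers.Add(1)
		go func(handles []int) {
			defer producers.Done()
			fetchCh <- handles
		}(batch)
	}
	go func() { producers.Wait(); close(fetchCh) }()

	// Process worker: union with de-duplication, in the spirit of fetchLoopUnion above.
	go func() {
		defer close(workCh)
		seen := make(map[int]struct{})
		for batch := range fetchCh {
			var uniq []int
			for _, h := range batch {
				if _, ok := seen[h]; !ok {
					seen[h] = struct{}{}
					uniq = append(uniq, h)
				}
			}
			if len(uniq) > 0 {
				workCh <- uniq
			}
		}
	}()

	for batch := range workCh {
		fmt.Println("table scan would fetch rows for handles:", batch)
	}
}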
+func (e *IndexMergeReaderExecutor) Table() table.Table { + return e.table +} + +// Open implements the Executor Open interface +func (e *IndexMergeReaderExecutor) Open(_ context.Context) (err error) { + e.keyRanges = make([][]kv.KeyRange, 0, len(e.partialPlans)) + e.initRuntimeStats() + if e.isCorColInTableFilter { + e.tableRequest.Executors, err = builder.ConstructListBasedDistExec(e.Ctx().GetBuildPBCtx(), e.tblPlans) + if err != nil { + return err + } + } + if err = e.rebuildRangeForCorCol(); err != nil { + return err + } + + if !e.partitionTableMode { + if e.keyRanges, err = e.buildKeyRangesForTable(e.table); err != nil { + return err + } + } else { + e.partitionKeyRanges = make([][][]kv.KeyRange, len(e.indexes)) + tmpPartitionKeyRanges := make([][][]kv.KeyRange, len(e.prunedPartitions)) + for i, p := range e.prunedPartitions { + if tmpPartitionKeyRanges[i], err = e.buildKeyRangesForTable(p); err != nil { + return err + } + } + for i, idx := range e.indexes { + if idx != nil && idx.Global { + keyRange, _ := distsql.IndexRangesToKVRanges(e.ctx.GetDistSQLCtx(), e.table.Meta().ID, idx.ID, e.ranges[i]) + e.partitionKeyRanges[i] = [][]kv.KeyRange{keyRange.FirstPartitionRange()} + } else { + for _, pKeyRanges := range tmpPartitionKeyRanges { + e.partitionKeyRanges[i] = append(e.partitionKeyRanges[i], pKeyRanges[i]) + } + } + } + } + e.finished = make(chan struct{}) + e.resultCh = make(chan *indexMergeTableTask, atomic.LoadInt32(&LookupTableTaskChannelSize)) + if e.memTracker != nil { + e.memTracker.Reset() + } else { + e.memTracker = memory.NewTracker(e.ID(), -1) + } + e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) + return nil +} + +func (e *IndexMergeReaderExecutor) rebuildRangeForCorCol() (err error) { + len1 := len(e.partialPlans) + len2 := len(e.isCorColInPartialAccess) + if len1 != len2 { + return errors.Errorf("unexpect length for partialPlans(%d) and isCorColInPartialAccess(%d)", len1, len2) + } + for i, plan := range e.partialPlans { + if e.isCorColInPartialAccess[i] { + switch x := plan[0].(type) { + case *plannercore.PhysicalIndexScan: + e.ranges[i], err = rebuildIndexRanges(e.Ctx().GetExprCtx(), e.Ctx().GetRangerCtx(), x, x.IdxCols, x.IdxColLens) + case *plannercore.PhysicalTableScan: + e.ranges[i], err = x.ResolveCorrelatedColumns() + default: + err = errors.Errorf("unsupported plan type %T", plan[0]) + } + if err != nil { + return err + } + } + } + return nil +} + +func (e *IndexMergeReaderExecutor) buildKeyRangesForTable(tbl table.Table) (ranges [][]kv.KeyRange, err error) { + dctx := e.Ctx().GetDistSQLCtx() + for i, plan := range e.partialPlans { + _, ok := plan[0].(*plannercore.PhysicalIndexScan) + if !ok { + firstPartRanges, secondPartRanges := distsql.SplitRangesAcrossInt64Boundary(e.ranges[i], false, e.descs[i], tbl.Meta().IsCommonHandle) + firstKeyRanges, err := distsql.TableHandleRangesToKVRanges(dctx, []int64{getPhysicalTableID(tbl)}, tbl.Meta().IsCommonHandle, firstPartRanges) + if err != nil { + return nil, err + } + secondKeyRanges, err := distsql.TableHandleRangesToKVRanges(dctx, []int64{getPhysicalTableID(tbl)}, tbl.Meta().IsCommonHandle, secondPartRanges) + if err != nil { + return nil, err + } + keyRanges := append(firstKeyRanges.FirstPartitionRange(), secondKeyRanges.FirstPartitionRange()...) 
+ ranges = append(ranges, keyRanges) + continue + } + keyRange, err := distsql.IndexRangesToKVRanges(dctx, getPhysicalTableID(tbl), e.indexes[i].ID, e.ranges[i]) + if err != nil { + return nil, err + } + ranges = append(ranges, keyRange.FirstPartitionRange()) + } + return ranges, nil +} + +func (e *IndexMergeReaderExecutor) startWorkers(ctx context.Context) error { + exitCh := make(chan struct{}) + workCh := make(chan *indexMergeTableTask, 1) + fetchCh := make(chan *indexMergeTableTask, len(e.keyRanges)) + + e.startIndexMergeProcessWorker(ctx, workCh, fetchCh) + + var err error + for i := 0; i < len(e.partialPlans); i++ { + e.idxWorkerWg.Add(1) + if e.indexes[i] != nil { + err = e.startPartialIndexWorker(ctx, exitCh, fetchCh, i) + } else { + err = e.startPartialTableWorker(ctx, exitCh, fetchCh, i) + } + if err != nil { + e.idxWorkerWg.Done() + break + } + } + go e.waitPartialWorkersAndCloseFetchChan(fetchCh) + if err != nil { + close(exitCh) + return err + } + e.startIndexMergeTableScanWorker(ctx, workCh) + e.workerStarted = true + return nil +} + +func (e *IndexMergeReaderExecutor) waitPartialWorkersAndCloseFetchChan(fetchCh chan *indexMergeTableTask) { + e.idxWorkerWg.Wait() + close(fetchCh) +} + +func (e *IndexMergeReaderExecutor) startIndexMergeProcessWorker(ctx context.Context, workCh chan<- *indexMergeTableTask, fetch <-chan *indexMergeTableTask) { + idxMergeProcessWorker := &indexMergeProcessWorker{ + indexMerge: e, + stats: e.stats, + } + e.processWorkerWg.Add(1) + go func() { + defer trace.StartRegion(ctx, "IndexMergeProcessWorker").End() + util.WithRecovery( + func() { + if e.isIntersection { + if e.keepOrder { + // todo: implementing fetchLoopIntersectionWithOrderBy if necessary. + panic("Not support intersection with keepOrder = true") + } + idxMergeProcessWorker.fetchLoopIntersection(ctx, fetch, workCh, e.resultCh, e.finished) + } else if len(e.byItems) != 0 { + idxMergeProcessWorker.fetchLoopUnionWithOrderBy(ctx, fetch, workCh, e.resultCh, e.finished) + } else { + idxMergeProcessWorker.fetchLoopUnion(ctx, fetch, workCh, e.resultCh, e.finished) + } + }, + handleWorkerPanic(ctx, e.finished, nil, e.resultCh, nil, processWorkerType), + ) + e.processWorkerWg.Done() + }() +} + +func (e *IndexMergeReaderExecutor) startPartialIndexWorker(ctx context.Context, exitCh <-chan struct{}, fetchCh chan<- *indexMergeTableTask, workID int) error { + failpoint.Inject("testIndexMergeResultChCloseEarly", func(_ failpoint.Value) { + // Wait for processWorker to close resultCh. + time.Sleep(time.Second * 2) + // Should use fetchCh instead of resultCh to send error. + syncErr(ctx, e.finished, fetchCh, errors.New("testIndexMergeResultChCloseEarly")) + }) + if e.RuntimeStats() != nil { + collExec := true + e.dagPBs[workID].CollectExecutionSummaries = &collExec + } + + var keyRanges [][]kv.KeyRange + if e.partitionTableMode { + keyRanges = e.partitionKeyRanges[workID] + } else { + keyRanges = [][]kv.KeyRange{e.keyRanges[workID]} + } + failpoint.Inject("startPartialIndexWorkerErr", func() error { + return errors.New("inject an error before start partialIndexWorker") + }) + + // for union case, the push-downLimit can be utilized to limit index fetched handles. + // for intersection case, the push-downLimit can only be conducted after all index path/table finished. 
+ pushedIndexLimit := e.pushedLimit + if e.isIntersection { + pushedIndexLimit = nil + } + + go func() { + defer trace.StartRegion(ctx, "IndexMergePartialIndexWorker").End() + defer e.idxWorkerWg.Done() + util.WithRecovery( + func() { + failpoint.Inject("testIndexMergePanicPartialIndexWorker", nil) + is := e.partialPlans[workID][0].(*plannercore.PhysicalIndexScan) + worker := &partialIndexWorker{ + stats: e.stats, + idxID: e.getPartitalPlanID(workID), + sc: e.Ctx(), + dagPB: e.dagPBs[workID], + plan: e.partialPlans[workID], + batchSize: e.MaxChunkSize(), + maxBatchSize: e.Ctx().GetSessionVars().IndexLookupSize, + maxChunkSize: e.MaxChunkSize(), + memTracker: e.memTracker, + partitionTableMode: e.partitionTableMode, + prunedPartitions: e.prunedPartitions, + byItems: is.ByItems, + pushedLimit: pushedIndexLimit, + } + if e.isCorColInPartialFilters[workID] { + // We got correlated column, so need to refresh Selection operator. + var err error + if e.dagPBs[workID].Executors, err = builder.ConstructListBasedDistExec(e.Ctx().GetBuildPBCtx(), e.partialPlans[workID]); err != nil { + syncErr(ctx, e.finished, fetchCh, err) + return + } + } + var builder distsql.RequestBuilder + builder.SetDAGRequest(e.dagPBs[workID]). + SetStartTS(e.startTS). + SetDesc(e.descs[workID]). + SetKeepOrder(e.keepOrder). + SetTxnScope(e.txnScope). + SetReadReplicaScope(e.readReplicaScope). + SetIsStaleness(e.isStaleness). + SetFromSessionVars(e.Ctx().GetDistSQLCtx()). + SetMemTracker(e.memTracker). + SetFromInfoSchema(e.Ctx().GetInfoSchema()). + SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.Ctx().GetDistSQLCtx(), &builder.Request, e.partialNetDataSizes[workID])). + SetConnIDAndConnAlias(e.Ctx().GetSessionVars().ConnectionID, e.Ctx().GetSessionVars().SessionAlias) + + worker.batchSize = CalculateBatchSize(int(is.StatsCount()), e.MaxChunkSize(), worker.maxBatchSize) + if builder.Request.Paging.Enable && builder.Request.Paging.MinPagingSize < uint64(worker.batchSize) { + // when paging enabled and Paging.MinPagingSize less than initBatchSize, change Paging.MinPagingSize to + // initial batchSize to avoid redundant paging RPC, see more detail in https://github.com/pingcap/tidb/issues/54066 + builder.Request.Paging.MinPagingSize = uint64(worker.batchSize) + if builder.Request.Paging.MaxPagingSize < uint64(worker.batchSize) { + builder.Request.Paging.MaxPagingSize = uint64(worker.batchSize) + } + } + tps := worker.getRetTpsForIndexScan(e.handleCols) + results := make([]distsql.SelectResult, 0, len(keyRanges)) + defer func() { + // To make sure SelectResult.Close() is called even got panic in fetchHandles(). + for _, result := range results { + if err := result.Close(); err != nil { + logutil.Logger(ctx).Error("close Select result failed", zap.Error(err)) + } + } + }() + for _, keyRange := range keyRanges { + // check if this executor is closed + select { + case <-ctx.Done(): + return + case <-e.finished: + return + default: + } + + // init kvReq and worker for this partition + // The key ranges should be ordered. 
+ slices.SortFunc(keyRange, func(i, j kv.KeyRange) int { + return bytes.Compare(i.StartKey, j.StartKey) + }) + kvReq, err := builder.SetKeyRanges(keyRange).Build() + if err != nil { + syncErr(ctx, e.finished, fetchCh, err) + return + } + result, err := distsql.SelectWithRuntimeStats(ctx, e.Ctx().GetDistSQLCtx(), kvReq, tps, getPhysicalPlanIDs(e.partialPlans[workID]), e.getPartitalPlanID(workID)) + if err != nil { + syncErr(ctx, e.finished, fetchCh, err) + return + } + results = append(results, result) + failpoint.Inject("testIndexMergePartialIndexWorkerCoprLeak", nil) + } + if len(results) > 1 && len(e.byItems) != 0 { + // e.Schema() not the output schema for partialIndexReader, and we put byItems related column at first in `buildIndexReq`, so use nil here. + ssr := distsql.NewSortedSelectResults(e.Ctx().GetExprCtx().GetEvalCtx(), results, nil, e.byItems, e.memTracker) + results = []distsql.SelectResult{ssr} + } + ctx1, cancel := context.WithCancel(ctx) + // this error is reported in fetchHandles(), so ignore it here. + _, _ = worker.fetchHandles(ctx1, results, exitCh, fetchCh, e.finished, e.handleCols, workID) + cancel() + }, + handleWorkerPanic(ctx, e.finished, nil, fetchCh, nil, partialIndexWorkerType), + ) + }() + + return nil +} + +func (e *IndexMergeReaderExecutor) startPartialTableWorker(ctx context.Context, exitCh <-chan struct{}, fetchCh chan<- *indexMergeTableTask, workID int) error { + ts := e.partialPlans[workID][0].(*plannercore.PhysicalTableScan) + + tbls := make([]table.Table, 0, 1) + if e.partitionTableMode && len(e.byItems) == 0 { + for _, p := range e.prunedPartitions { + tbls = append(tbls, p) + } + } else { + tbls = append(tbls, e.table) + } + + // for union case, the push-downLimit can be utilized to limit index fetched handles. + // for intersection case, the push-downLimit can only be conducted after all index/table path finished. 
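The sort near the top of this hunk exists because the request builder requires key ranges ordered by StartKey before the kv request is built. A tiny standalone version of that ordering step, with a keyRange struct standing in for kv.KeyRange:

package main

import (
	"bytes"
	"fmt"
	"slices"
)

// keyRange stands in for kv.KeyRange in this sketch.
type keyRange struct {
	StartKey, EndKey []byte
}

func main() {
	ranges := []keyRange{
		{StartKey: []byte("t\x80k3"), EndKey: []byte("t\x80k4")},
		{StartKey: []byte("t\x80k1"), EndKey: []byte("t\x80k2")},
	}
	// Byte-wise ordering on StartKey, exactly what slices.SortFunc + bytes.Compare expresses.
	slices.SortFunc(ranges, func(a, b keyRange) int {
		return bytes.Compare(a.StartKey, b.StartKey)
	})
	for _, r := range ranges {
		fmt.Printf("%q -> %q\n", r.StartKey, r.EndKey)
	}
}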
+ pushedTableLimit := e.pushedLimit + if e.isIntersection { + pushedTableLimit = nil + } + + go func() { + defer trace.StartRegion(ctx, "IndexMergePartialTableWorker").End() + defer e.idxWorkerWg.Done() + util.WithRecovery( + func() { + failpoint.Inject("testIndexMergePanicPartialTableWorker", nil) + var err error + partialTableReader := &TableReaderExecutor{ + BaseExecutorV2: exec.NewBaseExecutorV2(e.Ctx().GetSessionVars(), ts.Schema(), e.getPartitalPlanID(workID)), + tableReaderExecutorContext: newTableReaderExecutorContext(e.Ctx()), + dagPB: e.dagPBs[workID], + startTS: e.startTS, + txnScope: e.txnScope, + readReplicaScope: e.readReplicaScope, + isStaleness: e.isStaleness, + plans: e.partialPlans[workID], + ranges: e.ranges[workID], + netDataSize: e.partialNetDataSizes[workID], + keepOrder: ts.KeepOrder, + byItems: ts.ByItems, + } + + worker := &partialTableWorker{ + stats: e.stats, + sc: e.Ctx(), + batchSize: e.MaxChunkSize(), + maxBatchSize: e.Ctx().GetSessionVars().IndexLookupSize, + maxChunkSize: e.MaxChunkSize(), + tableReader: partialTableReader, + memTracker: e.memTracker, + partitionTableMode: e.partitionTableMode, + prunedPartitions: e.prunedPartitions, + byItems: ts.ByItems, + pushedLimit: pushedTableLimit, + } + + if len(e.prunedPartitions) != 0 && len(e.byItems) != 0 { + slices.SortFunc(worker.prunedPartitions, func(i, j table.PhysicalTable) int { + return cmp.Compare(i.GetPhysicalID(), j.GetPhysicalID()) + }) + partialTableReader.kvRangeBuilder = kvRangeBuilderFromRangeAndPartition{ + partitions: worker.prunedPartitions, + } + } + + if e.isCorColInPartialFilters[workID] { + if e.dagPBs[workID].Executors, err = builder.ConstructListBasedDistExec(e.Ctx().GetBuildPBCtx(), e.partialPlans[workID]); err != nil { + syncErr(ctx, e.finished, fetchCh, err) + return + } + partialTableReader.dagPB = e.dagPBs[workID] + } + + var tableReaderClosed bool + defer func() { + // To make sure SelectResult.Close() is called even got panic in fetchHandles(). + if !tableReaderClosed { + terror.Log(exec.Close(worker.tableReader)) + } + }() + for parTblIdx, tbl := range tbls { + // check if this executor is closed + select { + case <-ctx.Done(): + return + case <-e.finished: + return + default: + } + + // init partialTableReader and partialTableWorker again for the next table + partialTableReader.table = tbl + if err = exec.Open(ctx, partialTableReader); err != nil { + logutil.Logger(ctx).Error("open Select result failed:", zap.Error(err)) + syncErr(ctx, e.finished, fetchCh, err) + break + } + failpoint.Inject("testIndexMergePartialTableWorkerCoprLeak", nil) + tableReaderClosed = false + worker.batchSize = e.MaxChunkSize() + if worker.batchSize > worker.maxBatchSize { + worker.batchSize = worker.maxBatchSize + } + + // fetch all handles from this table + ctx1, cancel := context.WithCancel(ctx) + _, fetchErr := worker.fetchHandles(ctx1, exitCh, fetchCh, e.finished, e.handleCols, parTblIdx, workID) + // release related resources + cancel() + tableReaderClosed = true + if err = exec.Close(worker.tableReader); err != nil { + logutil.Logger(ctx).Error("close Select result failed:", zap.Error(err)) + } + // this error is reported in fetchHandles(), so ignore it here. 
+ if fetchErr != nil { + break + } + } + }, + handleWorkerPanic(ctx, e.finished, nil, fetchCh, nil, partialTableWorkerType), + ) + }() + return nil +} + +func (e *IndexMergeReaderExecutor) initRuntimeStats() { + if e.RuntimeStats() != nil { + e.stats = &IndexMergeRuntimeStat{ + Concurrency: e.Ctx().GetSessionVars().IndexLookupConcurrency(), + } + } +} + +func (e *IndexMergeReaderExecutor) getPartitalPlanID(workID int) int { + if len(e.partialPlans[workID]) > 0 { + return e.partialPlans[workID][len(e.partialPlans[workID])-1].ID() + } + return 0 +} + +func (e *IndexMergeReaderExecutor) getTablePlanRootID() int { + if len(e.tblPlans) > 0 { + return e.tblPlans[len(e.tblPlans)-1].ID() + } + return e.ID() +} + +type partialTableWorker struct { + stats *IndexMergeRuntimeStat + sc sessionctx.Context + batchSize int + maxBatchSize int + maxChunkSize int + tableReader exec.Executor + memTracker *memory.Tracker + partitionTableMode bool + prunedPartitions []table.PhysicalTable + byItems []*plannerutil.ByItems + scannedKeys uint64 + pushedLimit *plannercore.PushedDownLimit +} + +// needPartitionHandle indicates whether we need create a partitionHandle or not. +// If the schema from planner part contains ExtraPhysTblID, +// we need create a partitionHandle, otherwise create a normal handle. +// In TableRowIDScan, the partitionHandle will be used to create key ranges. +func (w *partialTableWorker) needPartitionHandle() (bool, error) { + cols := w.tableReader.(*TableReaderExecutor).plans[0].Schema().Columns + outputOffsets := w.tableReader.(*TableReaderExecutor).dagPB.OutputOffsets + col := cols[outputOffsets[len(outputOffsets)-1]] + + needPartitionHandle := w.partitionTableMode && len(w.byItems) > 0 + hasExtraCol := col.ID == model.ExtraPhysTblID + + // There will be two needPartitionHandle != hasExtraCol situations. + // Only `needPartitionHandle` == true and `hasExtraCol` == false are not allowed. + // `ExtraPhysTblID` will be used in `SelectLock` when `needPartitionHandle` == false and `hasExtraCol` == true. 
+ if needPartitionHandle && !hasExtraCol { + return needPartitionHandle, errors.Errorf("Internal error, needPartitionHandle != ret") + } + return needPartitionHandle, nil +} + +func (w *partialTableWorker) fetchHandles(ctx context.Context, exitCh <-chan struct{}, fetchCh chan<- *indexMergeTableTask, + finished <-chan struct{}, handleCols plannerutil.HandleCols, parTblIdx int, partialPlanIndex int) (count int64, err error) { + chk := w.tableReader.NewChunkWithCapacity(w.getRetTpsForTableScan(), w.maxChunkSize, w.maxBatchSize) + for { + start := time.Now() + handles, retChunk, err := w.extractTaskHandles(ctx, chk, handleCols) + if err != nil { + syncErr(ctx, finished, fetchCh, err) + return count, err + } + if len(handles) == 0 { + return count, nil + } + count += int64(len(handles)) + task := w.buildTableTask(handles, retChunk, parTblIdx, partialPlanIndex) + if w.stats != nil { + atomic.AddInt64(&w.stats.FetchIdxTime, int64(time.Since(start))) + } + select { + case <-ctx.Done(): + return count, ctx.Err() + case <-exitCh: + return count, nil + case <-finished: + return count, nil + case fetchCh <- task: + } + } +} + +func (w *partialTableWorker) getRetTpsForTableScan() []*types.FieldType { + return exec.RetTypes(w.tableReader) +} + +func (w *partialTableWorker) extractTaskHandles(ctx context.Context, chk *chunk.Chunk, handleCols plannerutil.HandleCols) ( + handles []kv.Handle, retChk *chunk.Chunk, err error) { + handles = make([]kv.Handle, 0, w.batchSize) + if len(w.byItems) != 0 { + retChk = chunk.NewChunkWithCapacity(w.getRetTpsForTableScan(), w.batchSize) + } + var memUsage int64 + var chunkRowOffset int + defer w.memTracker.Consume(-memUsage) + for len(handles) < w.batchSize { + requiredRows := w.batchSize - len(handles) + if w.pushedLimit != nil { + if w.pushedLimit.Offset+w.pushedLimit.Count <= w.scannedKeys { + return handles, retChk, nil + } + requiredRows = min(int(w.pushedLimit.Offset+w.pushedLimit.Count-w.scannedKeys), requiredRows) + } + chk.SetRequiredRows(requiredRows, w.maxChunkSize) + start := time.Now() + err = errors.Trace(w.tableReader.Next(ctx, chk)) + if err != nil { + return nil, nil, err + } + if w.tableReader != nil && w.tableReader.RuntimeStats() != nil { + w.tableReader.RuntimeStats().Record(time.Since(start), chk.NumRows()) + } + if chk.NumRows() == 0 { + failpoint.Inject("testIndexMergeErrorPartialTableWorker", func(v failpoint.Value) { + failpoint.Return(handles, nil, errors.New(v.(string))) + }) + return handles, retChk, nil + } + memDelta := chk.MemoryUsage() + memUsage += memDelta + w.memTracker.Consume(memDelta) + for chunkRowOffset = 0; chunkRowOffset < chk.NumRows(); chunkRowOffset++ { + if w.pushedLimit != nil { + w.scannedKeys++ + if w.scannedKeys > (w.pushedLimit.Offset + w.pushedLimit.Count) { + // Skip the handles after Offset+Count. 
+ break + } + } + var handle kv.Handle + ok, err1 := w.needPartitionHandle() + if err1 != nil { + return nil, nil, err1 + } + if ok { + handle, err = handleCols.BuildPartitionHandleFromIndexRow(chk.GetRow(chunkRowOffset)) + } else { + handle, err = handleCols.BuildHandleFromIndexRow(chk.GetRow(chunkRowOffset)) + } + if err != nil { + return nil, nil, err + } + handles = append(handles, handle) + } + // used for order by + if len(w.byItems) != 0 { + retChk.Append(chk, 0, chunkRowOffset) + } + } + w.batchSize *= 2 + if w.batchSize > w.maxBatchSize { + w.batchSize = w.maxBatchSize + } + return handles, retChk, nil +} + +func (w *partialTableWorker) buildTableTask(handles []kv.Handle, retChk *chunk.Chunk, parTblIdx int, partialPlanID int) *indexMergeTableTask { + task := &indexMergeTableTask{ + lookupTableTask: lookupTableTask{ + handles: handles, + idxRows: retChk, + }, + parTblIdx: parTblIdx, + partialPlanID: partialPlanID, + } + + if w.prunedPartitions != nil { + task.partitionTable = w.prunedPartitions[parTblIdx] + } + + task.doneCh = make(chan error, 1) + return task +} + +func (e *IndexMergeReaderExecutor) startIndexMergeTableScanWorker(ctx context.Context, workCh <-chan *indexMergeTableTask) { + lookupConcurrencyLimit := e.Ctx().GetSessionVars().IndexLookupConcurrency() + e.tblWorkerWg.Add(lookupConcurrencyLimit) + for i := 0; i < lookupConcurrencyLimit; i++ { + worker := &indexMergeTableScanWorker{ + stats: e.stats, + workCh: workCh, + finished: e.finished, + indexMergeExec: e, + tblPlans: e.tblPlans, + memTracker: e.memTracker, + } + ctx1, cancel := context.WithCancel(ctx) + go func() { + defer trace.StartRegion(ctx, "IndexMergeTableScanWorker").End() + var task *indexMergeTableTask + util.WithRecovery( + // Note we use the address of `task` as the argument of both `pickAndExecTask` and `handleTableScanWorkerPanic` + // because `task` is expected to be assigned in `pickAndExecTask`, and this assignment should also be visible + // in `handleTableScanWorkerPanic` since it will get `doneCh` from `task`. Golang always pass argument by value, + // so if we don't use the address of `task` as the argument, the assignment to `task` in `pickAndExecTask` is + // not visible in `handleTableScanWorkerPanic` + func() { worker.pickAndExecTask(ctx1, &task) }, + worker.handleTableScanWorkerPanic(ctx1, e.finished, &task, tableScanWorkerType), + ) + cancel() + e.tblWorkerWg.Done() + }() + } +} + +func (e *IndexMergeReaderExecutor) buildFinalTableReader(ctx context.Context, tbl table.Table, handles []kv.Handle) (_ exec.Executor, err error) { + tableReaderExec := &TableReaderExecutor{ + BaseExecutorV2: exec.NewBaseExecutorV2(e.Ctx().GetSessionVars(), e.Schema(), e.getTablePlanRootID()), + tableReaderExecutorContext: newTableReaderExecutorContext(e.Ctx()), + table: tbl, + dagPB: e.tableRequest, + startTS: e.startTS, + txnScope: e.txnScope, + readReplicaScope: e.readReplicaScope, + isStaleness: e.isStaleness, + columns: e.columns, + plans: e.tblPlans, + netDataSize: e.dataAvgRowSize * float64(len(handles)), + } + tableReaderExec.buildVirtualColumnInfo() + // Reorder handles because SplitKeyRangesByLocationsWith/WithoutBuckets() requires startKey of kvRanges is ordered. + // Also it's good for performance. 
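Two sizing rules show up in the handle-extraction loops above: the per-call row budget is capped by what a pushed-down LIMIT still allows (offset+count minus keys already scanned), and the batch size doubles after each batch up to maxBatchSize. A small sketch of both rules, with plain ints standing in for the executor state:

package main

import "fmt"

// requiredRows caps the number of rows requested in one call by the remaining
// LIMIT budget, mirroring the pushedLimit handling in extractTaskHandles.
func requiredRows(batchSize, collected int, offset, count, scanned uint64, hasLimit bool) int {
	required := batchSize - collected
	if hasLimit {
		remaining := int(offset + count - scanned)
		if remaining < required {
			required = remaining
		}
	}
	return required
}

// growBatch doubles the batch size after each full batch, clamped to maxBatchSize.
func growBatch(batchSize, maxBatchSize int) int {
	batchSize *= 2
	if batchSize > maxBatchSize {
		batchSize = maxBatchSize
	}
	return batchSize
}

func main() {
	fmt.Println(requiredRows(1024, 0, 10, 100, 90, true)) // 20: only 20 rows left under LIMIT 100 OFFSET 10
	fmt.Println(growBatch(1024, 25000))                   // 2048
	fmt.Println(growBatch(20000, 25000))                  // 25000 (clamped)
}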
+ tableReader, err := e.dataReaderBuilder.buildTableReaderFromHandles(ctx, tableReaderExec, handles, true) + if err != nil { + logutil.Logger(ctx).Error("build table reader from handles failed", zap.Error(err)) + return nil, err + } + return tableReader, nil +} + +// Next implements Executor Next interface. +func (e *IndexMergeReaderExecutor) Next(ctx context.Context, req *chunk.Chunk) error { + if !e.workerStarted { + if err := e.startWorkers(ctx); err != nil { + return err + } + } + + req.Reset() + for { + resultTask, err := e.getResultTask(ctx) + if err != nil { + return errors.Trace(err) + } + if resultTask == nil { + return nil + } + if resultTask.cursor < len(resultTask.rows) { + numToAppend := min(len(resultTask.rows)-resultTask.cursor, e.MaxChunkSize()-req.NumRows()) + req.AppendRows(resultTask.rows[resultTask.cursor : resultTask.cursor+numToAppend]) + resultTask.cursor += numToAppend + if req.NumRows() >= e.MaxChunkSize() { + return nil + } + } + } +} + +func (e *IndexMergeReaderExecutor) getResultTask(ctx context.Context) (*indexMergeTableTask, error) { + failpoint.Inject("testIndexMergeMainReturnEarly", func(_ failpoint.Value) { + // To make sure processWorker make resultCh to be full. + // When main goroutine close finished, processWorker may be stuck when writing resultCh. + time.Sleep(time.Second * 20) + failpoint.Return(nil, errors.New("failpoint testIndexMergeMainReturnEarly")) + }) + if e.resultCurr != nil && e.resultCurr.cursor < len(e.resultCurr.rows) { + return e.resultCurr, nil + } + task, ok := <-e.resultCh + if !ok { + return nil, nil + } + + select { + case <-ctx.Done(): + return nil, errors.Trace(ctx.Err()) + case err := <-task.doneCh: + if err != nil { + return nil, errors.Trace(err) + } + } + + // Release the memory usage of last task before we handle a new task. + if e.resultCurr != nil { + e.resultCurr.memTracker.Consume(-e.resultCurr.memUsage) + } + e.resultCurr = task + return e.resultCurr, nil +} + +func handleWorkerPanic(ctx context.Context, finished, limitDone <-chan struct{}, ch chan<- *indexMergeTableTask, extraNotifyCh chan bool, worker string) func(r any) { + return func(r any) { + if worker == processWorkerType { + // There is only one processWorker, so it's safe to close here. + // No need to worry about "close on closed channel" error. + defer close(ch) + } + if r == nil { + logutil.BgLogger().Debug("worker finish without panic", zap.Any("worker", worker)) + return + } + + if extraNotifyCh != nil { + extraNotifyCh <- true + } + + err4Panic := util.GetRecoverError(r) + logutil.Logger(ctx).Error(err4Panic.Error()) + doneCh := make(chan error, 1) + doneCh <- err4Panic + task := &indexMergeTableTask{ + lookupTableTask: lookupTableTask{ + doneCh: doneCh, + }, + } + select { + case <-ctx.Done(): + return + case <-finished: + return + case <-limitDone: + // once the process worker recovered from panic, once finding the limitDone signal, actually we can return. + return + case ch <- task: + return + } + } +} + +// Close implements Exec Close interface. 
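handleWorkerPanic above converts a recovered panic into an ordinary task whose doneCh carries the error, so the reading goroutine sees a normal failure instead of the whole executor crashing. A minimal analogue of that recover-to-channel pattern; the names are illustrative and this is not the real util.WithRecovery API:

package main

import "fmt"

// task mimics the doneCh convention used above: whoever consumes a task waits
// on doneCh, so a recovered panic can be delivered as an ordinary error.
type task struct {
	doneCh chan error
}

// runWorker runs the worker body under a deferred recover; a panic becomes a
// task carrying the error, and the output channel is always closed.
func runWorker(out chan *task, body func()) {
	go func() {
		defer func() {
			if r := recover(); r != nil {
				doneCh := make(chan error, 1)
				doneCh <- fmt.Errorf("worker panic: %v", r)
				out <- &task{doneCh: doneCh}
			}
			close(out)
		}()
		body()
	}()
}

func main() {
	out := make(chan *task, 1)
	runWorker(out, func() { panic("boom") })
	for t := range out {
		fmt.Println("received error from doneCh:", <-t.doneCh)
	}
}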
+func (e *IndexMergeReaderExecutor) Close() error { + if e.stats != nil { + defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), e.stats) + } + if e.indexUsageReporter != nil { + for _, p := range e.partialPlans { + is, ok := p[0].(*plannercore.PhysicalIndexScan) + if !ok { + continue + } + + e.indexUsageReporter.ReportCopIndexUsageForTable(e.table, is.Index.ID, is.ID()) + } + } + if e.finished == nil { + return nil + } + close(e.finished) + e.tblWorkerWg.Wait() + e.idxWorkerWg.Wait() + e.processWorkerWg.Wait() + e.finished = nil + e.workerStarted = false + return nil +} + +type indexMergeProcessWorker struct { + indexMerge *IndexMergeReaderExecutor + stats *IndexMergeRuntimeStat +} + +type rowIdx struct { + partialID int + taskID int + rowID int +} + +type handleHeap struct { + // requiredCnt == 0 means need all handles + requiredCnt uint64 + tracker *memory.Tracker + taskMap map[int][]*indexMergeTableTask + + idx []rowIdx + compareFunc []chunk.CompareFunc + byItems []*plannerutil.ByItems +} + +func (h handleHeap) Len() int { + return len(h.idx) +} + +func (h handleHeap) Less(i, j int) bool { + rowI := h.taskMap[h.idx[i].partialID][h.idx[i].taskID].idxRows.GetRow(h.idx[i].rowID) + rowJ := h.taskMap[h.idx[j].partialID][h.idx[j].taskID].idxRows.GetRow(h.idx[j].rowID) + + for i, compFunc := range h.compareFunc { + cmp := compFunc(rowI, i, rowJ, i) + if !h.byItems[i].Desc { + cmp = -cmp + } + if cmp < 0 { + return true + } else if cmp > 0 { + return false + } + } + return false +} + +func (h handleHeap) Swap(i, j int) { + h.idx[i], h.idx[j] = h.idx[j], h.idx[i] +} + +func (h *handleHeap) Push(x any) { + idx := x.(rowIdx) + h.idx = append(h.idx, idx) + if h.tracker != nil { + h.tracker.Consume(int64(unsafe.Sizeof(h.idx))) + } +} + +func (h *handleHeap) Pop() any { + idxRet := h.idx[len(h.idx)-1] + h.idx = h.idx[:len(h.idx)-1] + if h.tracker != nil { + h.tracker.Consume(-int64(unsafe.Sizeof(h.idx))) + } + return idxRet +} + +func (w *indexMergeProcessWorker) NewHandleHeap(taskMap map[int][]*indexMergeTableTask, memTracker *memory.Tracker) *handleHeap { + compareFuncs := make([]chunk.CompareFunc, 0, len(w.indexMerge.byItems)) + for _, item := range w.indexMerge.byItems { + keyType := item.Expr.GetType(w.indexMerge.Ctx().GetExprCtx().GetEvalCtx()) + compareFuncs = append(compareFuncs, chunk.GetCompareFunc(keyType)) + } + + requiredCnt := uint64(0) + if w.indexMerge.pushedLimit != nil { + // Pre-allocate up to 1024 to avoid oom + requiredCnt = min(1024, w.indexMerge.pushedLimit.Count+w.indexMerge.pushedLimit.Offset) + } + return &handleHeap{ + requiredCnt: requiredCnt, + tracker: memTracker, + taskMap: taskMap, + idx: make([]rowIdx, 0, requiredCnt), + compareFunc: compareFuncs, + byItems: w.indexMerge.byItems, + } +} + +// pruneTableWorkerTaskIdxRows prune idxRows and only keep columns that will be used in byItems. +// e.g. the common handle is (`b`, `c`) and order by with column `c`, we should make column `c` at the first. +func (w *indexMergeProcessWorker) pruneTableWorkerTaskIdxRows(task *indexMergeTableTask) { + if task.idxRows == nil { + return + } + // IndexScan no need to prune retChk, Columns required by byItems are always first. 
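The handle heap above keeps at most requiredCnt (the pushed-down offset+count) rows: every push beyond that immediately pops the current worst candidate, so memory is bounded by the LIMIT rather than by the number of fetched handles. The same idea on plain ints, as a hedged sketch with container/heap:

package main

import (
	"container/heap"
	"fmt"
)

// topK keeps at most k values; vals is a max-heap, so vals[0] is always the
// worst of the values currently kept.
type topK struct {
	k    int
	vals []int
}

func (h topK) Len() int           { return len(h.vals) }
func (h topK) Less(i, j int) bool { return h.vals[i] > h.vals[j] }
func (h topK) Swap(i, j int)      { h.vals[i], h.vals[j] = h.vals[j], h.vals[i] }
func (h *topK) Push(x any)        { h.vals = append(h.vals, x.(int)) }
func (h *topK) Pop() any {
	old := h.vals
	n := len(old)
	x := old[n-1]
	h.vals = old[:n-1]
	return x
}

func main() {
	h := &topK{k: 3}
	for _, v := range []int{9, 1, 7, 3, 8, 2} {
		heap.Push(h, v)
		if h.Len() > h.k {
			heap.Pop(h) // discard the current worst, like the requiredCnt check above
		}
	}
	fmt.Println(h.vals) // the 3 smallest values, in heap order
}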
+ if plan, ok := w.indexMerge.partialPlans[task.partialPlanID][0].(*plannercore.PhysicalTableScan); ok { + prune := make([]int, 0, len(w.indexMerge.byItems)) + for _, item := range plan.ByItems { + c, _ := item.Expr.(*expression.Column) + idx := plan.Schema().ColumnIndex(c) + // couldn't equals to -1 here, if idx == -1, just let it panic + prune = append(prune, idx) + } + task.idxRows = task.idxRows.Prune(prune) + } +} + +func (w *indexMergeProcessWorker) fetchLoopUnionWithOrderBy(ctx context.Context, fetchCh <-chan *indexMergeTableTask, + workCh chan<- *indexMergeTableTask, resultCh chan<- *indexMergeTableTask, finished <-chan struct{}) { + memTracker := memory.NewTracker(w.indexMerge.ID(), -1) + memTracker.AttachTo(w.indexMerge.memTracker) + defer memTracker.Detach() + defer close(workCh) + + if w.stats != nil { + start := time.Now() + defer func() { + w.stats.IndexMergeProcess += time.Since(start) + }() + } + + distinctHandles := kv.NewHandleMap() + taskMap := make(map[int][]*indexMergeTableTask) + uselessMap := make(map[int]struct{}) + taskHeap := w.NewHandleHeap(taskMap, memTracker) + + for task := range fetchCh { + select { + case err := <-task.doneCh: + // If got error from partialIndexWorker/partialTableWorker, stop processing. + if err != nil { + syncErr(ctx, finished, resultCh, err) + return + } + default: + } + if _, ok := uselessMap[task.partialPlanID]; ok { + continue + } + if _, ok := taskMap[task.partialPlanID]; !ok { + taskMap[task.partialPlanID] = make([]*indexMergeTableTask, 0, 1) + } + w.pruneTableWorkerTaskIdxRows(task) + taskMap[task.partialPlanID] = append(taskMap[task.partialPlanID], task) + for i, h := range task.handles { + if _, ok := distinctHandles.Get(h); !ok { + distinctHandles.Set(h, true) + heap.Push(taskHeap, rowIdx{task.partialPlanID, len(taskMap[task.partialPlanID]) - 1, i}) + if int(taskHeap.requiredCnt) != 0 && taskHeap.Len() > int(taskHeap.requiredCnt) { + top := heap.Pop(taskHeap).(rowIdx) + if top.partialID == task.partialPlanID && top.taskID == len(taskMap[task.partialPlanID])-1 && top.rowID == i { + uselessMap[task.partialPlanID] = struct{}{} + task.handles = task.handles[:i] + break + } + } + } + memTracker.Consume(int64(h.MemUsage())) + } + memTracker.Consume(task.idxRows.MemoryUsage()) + if len(uselessMap) == len(w.indexMerge.partialPlans) { + // consume reset tasks + go func() { + channel.Clear(fetchCh) + }() + break + } + } + + needCount := taskHeap.Len() + if w.indexMerge.pushedLimit != nil { + needCount = max(0, taskHeap.Len()-int(w.indexMerge.pushedLimit.Offset)) + } + if needCount == 0 { + return + } + fhs := make([]kv.Handle, needCount) + for i := needCount - 1; i >= 0; i-- { + idx := heap.Pop(taskHeap).(rowIdx) + fhs[i] = taskMap[idx.partialID][idx.taskID].handles[idx.rowID] + } + + batchSize := w.indexMerge.Ctx().GetSessionVars().IndexLookupSize + tasks := make([]*indexMergeTableTask, 0, len(fhs)/batchSize+1) + for len(fhs) > 0 { + l := min(len(fhs), batchSize) + // Save the index order. 
+ indexOrder := kv.NewHandleMap() + for i, h := range fhs[:l] { + indexOrder.Set(h, i) + } + tasks = append(tasks, &indexMergeTableTask{ + lookupTableTask: lookupTableTask{ + handles: fhs[:l], + indexOrder: indexOrder, + doneCh: make(chan error, 1), + }, + }) + fhs = fhs[l:] + } + for _, task := range tasks { + select { + case <-ctx.Done(): + return + case <-finished: + return + case resultCh <- task: + failpoint.Inject("testCancelContext", func() { + IndexMergeCancelFuncForTest() + }) + select { + case <-ctx.Done(): + return + case <-finished: + return + case workCh <- task: + continue + } + } + } +} + +func pushedLimitCountingDown(pushedLimit *plannercore.PushedDownLimit, handles []kv.Handle) (next bool, res []kv.Handle) { + fhsLen := uint64(len(handles)) + // The number of handles is less than the offset, discard all handles. + if fhsLen <= pushedLimit.Offset { + pushedLimit.Offset -= fhsLen + return true, nil + } + handles = handles[pushedLimit.Offset:] + pushedLimit.Offset = 0 + + fhsLen = uint64(len(handles)) + // The number of handles is greater than the limit, only keep limit count. + if fhsLen > pushedLimit.Count { + handles = handles[:pushedLimit.Count] + } + pushedLimit.Count -= min(pushedLimit.Count, fhsLen) + return false, handles +} + +func (w *indexMergeProcessWorker) fetchLoopUnion(ctx context.Context, fetchCh <-chan *indexMergeTableTask, + workCh chan<- *indexMergeTableTask, resultCh chan<- *indexMergeTableTask, finished <-chan struct{}) { + failpoint.Inject("testIndexMergeResultChCloseEarly", func(_ failpoint.Value) { + failpoint.Return() + }) + memTracker := memory.NewTracker(w.indexMerge.ID(), -1) + memTracker.AttachTo(w.indexMerge.memTracker) + defer memTracker.Detach() + defer close(workCh) + failpoint.Inject("testIndexMergePanicProcessWorkerUnion", nil) + + var pushedLimit *plannercore.PushedDownLimit + if w.indexMerge.pushedLimit != nil { + pushedLimit = w.indexMerge.pushedLimit.Clone() + } + hMap := kv.NewHandleMap() + for { + var ok bool + var task *indexMergeTableTask + if pushedLimit != nil && pushedLimit.Count == 0 { + return + } + select { + case <-ctx.Done(): + return + case <-finished: + return + case task, ok = <-fetchCh: + if !ok { + return + } + } + + select { + case err := <-task.doneCh: + // If got error from partialIndexWorker/partialTableWorker, stop processing. 
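pushedLimitCountingDown above consumes a pushed-down OFFSET/COUNT across successive handle batches: whole batches are skipped while the offset is not yet consumed, then at most count handles are kept as batches stream through. A worked example of the same counting on int handles, where limit is a stand-in for plannercore.PushedDownLimit:

package main

import "fmt"

type limit struct{ offset, count uint64 }

// countDown reproduces the counting logic on plain int handles.
func countDown(l *limit, handles []int) (skip bool, kept []int) {
	n := uint64(len(handles))
	if n <= l.offset {
		l.offset -= n
		return true, nil
	}
	handles = handles[l.offset:]
	l.offset = 0

	n = uint64(len(handles))
	if n > l.count {
		handles = handles[:l.count]
	}
	l.count -= min(l.count, n)
	return false, handles
}

func main() {
	l := &limit{offset: 3, count: 4} // LIMIT 4 OFFSET 3
	for _, batch := range [][]int{{1, 2}, {3, 4, 5}, {6, 7, 8}, {9}} {
		skip, kept := countDown(l, batch)
		fmt.Println(skip, kept)
	}
	// Prints: true [], false [4 5], false [6 7], false []
}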
+ if err != nil { + syncErr(ctx, finished, resultCh, err) + return + } + default: + } + start := time.Now() + handles := task.handles + fhs := make([]kv.Handle, 0, 8) + + memTracker.Consume(int64(cap(task.handles) * 8)) + for _, h := range handles { + if w.indexMerge.partitionTableMode { + if _, ok := h.(kv.PartitionHandle); !ok { + h = kv.NewPartitionHandle(task.partitionTable.GetPhysicalID(), h) + } + } + if _, ok := hMap.Get(h); !ok { + fhs = append(fhs, h) + hMap.Set(h, true) + } + } + if len(fhs) == 0 { + continue + } + if pushedLimit != nil { + next, res := pushedLimitCountingDown(pushedLimit, fhs) + if next { + continue + } + fhs = res + } + task = &indexMergeTableTask{ + lookupTableTask: lookupTableTask{ + handles: fhs, + doneCh: make(chan error, 1), + + partitionTable: task.partitionTable, + }, + } + if w.stats != nil { + w.stats.IndexMergeProcess += time.Since(start) + } + failpoint.Inject("testIndexMergeProcessWorkerUnionHang", func(_ failpoint.Value) { + for i := 0; i < cap(resultCh); i++ { + select { + case resultCh <- &indexMergeTableTask{}: + default: + } + } + }) + select { + case <-ctx.Done(): + return + case <-finished: + return + case resultCh <- task: + failpoint.Inject("testCancelContext", func() { + IndexMergeCancelFuncForTest() + }) + select { + case <-ctx.Done(): + return + case <-finished: + return + case workCh <- task: + } + } + } +} + +// intersectionCollectWorker is used to dispatch index-merge-table-task to original workCh and resultCh. +// a kind of interceptor to control the pushed down limit restriction. (should be no performance impact) +type intersectionCollectWorker struct { + pushedLimit *plannercore.PushedDownLimit + collectCh chan *indexMergeTableTask + limitDone chan struct{} +} + +func (w *intersectionCollectWorker) doIntersectionLimitAndDispatch(ctx context.Context, workCh chan<- *indexMergeTableTask, + resultCh chan<- *indexMergeTableTask, finished <-chan struct{}) { + var ( + ok bool + task *indexMergeTableTask + ) + for { + select { + case <-ctx.Done(): + return + case <-finished: + return + case task, ok = <-w.collectCh: + if !ok { + return + } + // receive a new intersection task here, adding limit restriction logic + if w.pushedLimit != nil { + if w.pushedLimit.Count == 0 { + // close limitDone channel to notify intersectionProcessWorkers * N to exit. + close(w.limitDone) + return + } + next, handles := pushedLimitCountingDown(w.pushedLimit, task.handles) + if next { + continue + } + task.handles = handles + } + // dispatch the new task to workCh and resultCh. + select { + case <-ctx.Done(): + return + case <-finished: + return + case workCh <- task: + select { + case <-ctx.Done(): + return + case <-finished: + return + case resultCh <- task: + } + } + } + } +} + +type intersectionProcessWorker struct { + // key: parTblIdx, val: HandleMap + // Value of MemAwareHandleMap is *int to avoid extra Get(). 
+ handleMapsPerWorker map[int]*kv.MemAwareHandleMap[*int] + workerID int + workerCh chan *indexMergeTableTask + indexMerge *IndexMergeReaderExecutor + memTracker *memory.Tracker + batchSize int + + // When rowDelta == memConsumeBatchSize, Consume(memUsage) + rowDelta int64 + mapUsageDelta int64 + + partitionIDMap map[int64]int +} + +func (w *intersectionProcessWorker) consumeMemDelta() { + w.memTracker.Consume(w.mapUsageDelta + w.rowDelta*int64(unsafe.Sizeof(int(0)))) + w.mapUsageDelta = 0 + w.rowDelta = 0 +} + +// doIntersectionPerPartition fetch all the task from workerChannel, and after that, then do the intersection pruning, which +// will cause wasting a lot of time waiting for all the fetch task done. +func (w *intersectionProcessWorker) doIntersectionPerPartition(ctx context.Context, workCh chan<- *indexMergeTableTask, resultCh chan<- *indexMergeTableTask, finished, limitDone <-chan struct{}) { + failpoint.Inject("testIndexMergePanicPartitionTableIntersectionWorker", nil) + defer w.memTracker.Detach() + + for task := range w.workerCh { + var ok bool + var hMap *kv.MemAwareHandleMap[*int] + if hMap, ok = w.handleMapsPerWorker[task.parTblIdx]; !ok { + hMap = kv.NewMemAwareHandleMap[*int]() + w.handleMapsPerWorker[task.parTblIdx] = hMap + } + var mapDelta, rowDelta int64 + for _, h := range task.handles { + if w.indexMerge.hasGlobalIndex { + if ph, ok := h.(kv.PartitionHandle); ok { + if v, exists := w.partitionIDMap[ph.PartitionID]; exists { + if hMap, ok = w.handleMapsPerWorker[v]; !ok { + hMap = kv.NewMemAwareHandleMap[*int]() + w.handleMapsPerWorker[v] = hMap + } + } + } else { + h = kv.NewPartitionHandle(task.partitionTable.GetPhysicalID(), h) + } + } + // Use *int to avoid Get() again. + if cntPtr, ok := hMap.Get(h); ok { + (*cntPtr)++ + } else { + cnt := 1 + mapDelta += hMap.Set(h, &cnt) + int64(h.ExtraMemSize()) + rowDelta++ + } + } + + logutil.BgLogger().Debug("intersectionProcessWorker handle tasks", zap.Int("workerID", w.workerID), + zap.Int("task.handles", len(task.handles)), zap.Int64("rowDelta", rowDelta)) + + w.mapUsageDelta += mapDelta + w.rowDelta += rowDelta + if w.rowDelta >= int64(w.batchSize) { + w.consumeMemDelta() + } + failpoint.Inject("testIndexMergeIntersectionWorkerPanic", nil) + } + if w.rowDelta > 0 { + w.consumeMemDelta() + } + + // We assume the result of intersection is small, so no need to track memory. + intersectedMap := make(map[int][]kv.Handle, len(w.handleMapsPerWorker)) + for parTblIdx, hMap := range w.handleMapsPerWorker { + hMap.Range(func(h kv.Handle, val *int) bool { + if *(val) == len(w.indexMerge.partialPlans) { + // Means all partial paths have this handle. + intersectedMap[parTblIdx] = append(intersectedMap[parTblIdx], h) + } + return true + }) + } + + tasks := make([]*indexMergeTableTask, 0, len(w.handleMapsPerWorker)) + for parTblIdx, intersected := range intersectedMap { + // Split intersected[parTblIdx] to avoid task is too large. 
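The intersection above is essentially a counting exercise: each handle's counter is bumped once per partial path that produced it, and only counters equal to the number of partial paths survive. A compact sketch of that idea, under the assumption (which the counting relies on) that each partial path is already duplicate-free:

package main

import "fmt"

// intersect keeps the handles produced by every partial result set.
func intersect(partialResults ...[]int) []int {
	counts := make(map[int]int)
	for _, handles := range partialResults {
		for _, h := range handles {
			counts[h]++
		}
	}
	var out []int
	for h, c := range counts {
		if c == len(partialResults) {
			out = append(out, h)
		}
	}
	return out
}

func main() {
	fmt.Println(intersect(
		[]int{1, 3, 5, 7},
		[]int{3, 5, 9},
		[]int{5, 3, 11},
	)) // 3 and 5 appear in all three partial paths
}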
+ for len(intersected) > 0 { + length := w.batchSize + if length > len(intersected) { + length = len(intersected) + } + task := &indexMergeTableTask{ + lookupTableTask: lookupTableTask{ + handles: intersected[:length], + doneCh: make(chan error, 1), + }, + } + intersected = intersected[length:] + if w.indexMerge.partitionTableMode { + task.partitionTable = w.indexMerge.prunedPartitions[parTblIdx] + } + tasks = append(tasks, task) + logutil.BgLogger().Debug("intersectionProcessWorker build tasks", + zap.Int("parTblIdx", parTblIdx), zap.Int("task.handles", len(task.handles))) + } + } + failpoint.Inject("testIndexMergeProcessWorkerIntersectionHang", func(_ failpoint.Value) { + if resultCh != nil { + for i := 0; i < cap(resultCh); i++ { + select { + case resultCh <- &indexMergeTableTask{}: + default: + } + } + } + }) + for _, task := range tasks { + select { + case <-ctx.Done(): + return + case <-finished: + return + case <-limitDone: + // limitDone has signal means the collectWorker has collected enough results, shutdown process workers quickly here. + return + case workCh <- task: + // resultCh != nil means there is no collectWorker, and we should send task to resultCh too by ourselves here. + if resultCh != nil { + select { + case <-ctx.Done(): + return + case <-finished: + return + case resultCh <- task: + } + } + } + } +} + +// for every index merge process worker, it should be feed on a sortedSelectResult for every partial index plan (constructed on all +// table partition ranges results on that index plan path). Since every partial index path is a sorted select result, we can utilize +// K-way merge to accelerate the intersection process. +// +// partialIndexPlan-1 ---> SSR ---> + +// partialIndexPlan-2 ---> SSR ---> + ---> SSR K-way Merge ---> output IndexMergeTableTask +// partialIndexPlan-3 ---> SSR ---> + +// ... + +// partialIndexPlan-N ---> SSR ---> + +// +// K-way merge detail: for every partial index plan, output one row as current its representative row. Then, comparing the N representative +// rows together: +// +// Loop start: +// +// case 1: they are all the same, intersection succeed. --- Record current handle down (already in index order). +// case 2: distinguish among them, for the minimum value/values corresponded index plan/plans. --- Discard current representative row, fetch next. +// +// goto Loop start: +// +// encapsulate all the recorded handles (already in index order) as index merge table tasks, sending them out. +func (*indexMergeProcessWorker) fetchLoopIntersectionWithOrderBy(_ context.Context, _ <-chan *indexMergeTableTask, + _ chan<- *indexMergeTableTask, _ chan<- *indexMergeTableTask, _ <-chan struct{}) { + // todo: pushed sort property with partial index plan and limit. +} + +// For each partition(dynamic mode), a map is used to do intersection. Key of the map is handle, and value is the number of times it occurs. +// If the value of handle equals the number of partial paths, it should be sent to final_table_scan_worker. +// To avoid too many goroutines, each intersectionProcessWorker can handle multiple partitions. 
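As an editorial aside for readers of this patch: the per-partition intersection described above reduces to a counting problem, where a handle survives only if every partial index path produced it. A minimal standalone sketch of that idea follows; the function name intersectHandles, the plain Go map, and the int64 handles are simplifications invented for illustration (the real code uses kv.Handle and kv.MemAwareHandleMap with memory tracking), and the sketch assumes each path reports a given handle at most once.

package main

import "fmt"

// intersectHandles keeps only the handles that occur in every partial path,
// mirroring the "count == number of partial plans" check in
// doIntersectionPerPartition above.
func intersectHandles(partialPaths [][]int64) []int64 {
	counts := make(map[int64]int)
	for _, handles := range partialPaths {
		for _, h := range handles {
			counts[h]++
		}
	}
	var out []int64
	for h, c := range counts {
		if c == len(partialPaths) {
			out = append(out, h)
		}
	}
	return out
}

func main() {
	// Handles 2 and 3 appear in all three paths, so only they survive.
	fmt.Println(intersectHandles([][]int64{{1, 2, 3}, {2, 3, 4}, {3, 2, 9}}))
}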
+func (w *indexMergeProcessWorker) fetchLoopIntersection(ctx context.Context, fetchCh <-chan *indexMergeTableTask, + workCh chan<- *indexMergeTableTask, resultCh chan<- *indexMergeTableTask, finished <-chan struct{}) { + defer close(workCh) + + if w.stats != nil { + start := time.Now() + defer func() { + w.stats.IndexMergeProcess += time.Since(start) + }() + } + + failpoint.Inject("testIndexMergePanicProcessWorkerIntersection", nil) + + // One goroutine may handle one or multiple partitions. + // Max number of partition number is 8192, we use ExecutorConcurrency to avoid too many goroutines. + maxWorkerCnt := w.indexMerge.Ctx().GetSessionVars().IndexMergeIntersectionConcurrency() + maxChannelSize := atomic.LoadInt32(&LookupTableTaskChannelSize) + batchSize := w.indexMerge.Ctx().GetSessionVars().IndexLookupSize + + partCnt := 1 + // To avoid multi-threaded access the handle map, we only use one worker for indexMerge with global index. + if w.indexMerge.partitionTableMode && !w.indexMerge.hasGlobalIndex { + partCnt = len(w.indexMerge.prunedPartitions) + } + workerCnt := min(partCnt, maxWorkerCnt) + failpoint.Inject("testIndexMergeIntersectionConcurrency", func(val failpoint.Value) { + con := val.(int) + if con != workerCnt { + panic(fmt.Sprintf("unexpected workerCnt, expect %d, got %d", con, workerCnt)) + } + }) + + partitionIDMap := make(map[int64]int) + if w.indexMerge.hasGlobalIndex { + for i, p := range w.indexMerge.prunedPartitions { + partitionIDMap[p.GetPhysicalID()] = i + } + } + + workers := make([]*intersectionProcessWorker, 0, workerCnt) + var collectWorker *intersectionCollectWorker + wg := util.WaitGroupWrapper{} + wg2 := util.WaitGroupWrapper{} + errCh := make(chan bool, workerCnt) + var limitDone chan struct{} + if w.indexMerge.pushedLimit != nil { + // no memory cost for this code logic. + collectWorker = &intersectionCollectWorker{ + // same size of workCh/resultCh + collectCh: make(chan *indexMergeTableTask, atomic.LoadInt32(&LookupTableTaskChannelSize)), + pushedLimit: w.indexMerge.pushedLimit.Clone(), + limitDone: make(chan struct{}), + } + limitDone = collectWorker.limitDone + wg2.RunWithRecover(func() { + defer trace.StartRegion(ctx, "IndexMergeIntersectionProcessWorker").End() + collectWorker.doIntersectionLimitAndDispatch(ctx, workCh, resultCh, finished) + }, handleWorkerPanic(ctx, finished, nil, resultCh, errCh, partTblIntersectionWorkerType)) + } + for i := 0; i < workerCnt; i++ { + tracker := memory.NewTracker(w.indexMerge.ID(), -1) + tracker.AttachTo(w.indexMerge.memTracker) + worker := &intersectionProcessWorker{ + workerID: i, + handleMapsPerWorker: make(map[int]*kv.MemAwareHandleMap[*int]), + workerCh: make(chan *indexMergeTableTask, maxChannelSize), + indexMerge: w.indexMerge, + memTracker: tracker, + batchSize: batchSize, + partitionIDMap: partitionIDMap, + } + wg.RunWithRecover(func() { + defer trace.StartRegion(ctx, "IndexMergeIntersectionProcessWorker").End() + if collectWorker != nil { + // workflow: + // intersectionProcessWorker-1 --+ (limit restriction logic) + // intersectionProcessWorker-2 --+--------- collectCh--> intersectionCollectWorker +--> workCh --> table worker + // ... 
--+ <--- limitDone to shut inputs ------+ +-> resultCh --> upper parent + // intersectionProcessWorker-N --+ + worker.doIntersectionPerPartition(ctx, collectWorker.collectCh, nil, finished, collectWorker.limitDone) + } else { + // workflow: + // intersectionProcessWorker-1 --------------------------+--> workCh --> table worker + // intersectionProcessWorker-2 ---(same as above) +--> resultCh --> upper parent + // ... ---(same as above) + // intersectionProcessWorker-N ---(same as above) + worker.doIntersectionPerPartition(ctx, workCh, resultCh, finished, nil) + } + }, handleWorkerPanic(ctx, finished, limitDone, resultCh, errCh, partTblIntersectionWorkerType)) + workers = append(workers, worker) + } + defer func() { + for _, processWorker := range workers { + close(processWorker.workerCh) + } + wg.Wait() + // only after all the possible writer closed, can we shut down the collectCh. + if collectWorker != nil { + // you don't need to clear the channel before closing it, so discard all the remain tasks. + close(collectWorker.collectCh) + } + wg2.Wait() + }() + for { + var ok bool + var task *indexMergeTableTask + select { + case <-ctx.Done(): + return + case <-finished: + return + case task, ok = <-fetchCh: + if !ok { + return + } + } + + select { + case err := <-task.doneCh: + // If got error from partialIndexWorker/partialTableWorker, stop processing. + if err != nil { + syncErr(ctx, finished, resultCh, err) + return + } + default: + } + + select { + case <-ctx.Done(): + return + case <-finished: + return + case workers[task.parTblIdx%workerCnt].workerCh <- task: + case <-errCh: + // If got error from intersectionProcessWorker, stop processing. + return + } + } +} + +type partialIndexWorker struct { + stats *IndexMergeRuntimeStat + sc sessionctx.Context + idxID int + batchSize int + maxBatchSize int + maxChunkSize int + memTracker *memory.Tracker + partitionTableMode bool + prunedPartitions []table.PhysicalTable + byItems []*plannerutil.ByItems + scannedKeys uint64 + pushedLimit *plannercore.PushedDownLimit + dagPB *tipb.DAGRequest + plan []base.PhysicalPlan +} + +func syncErr(ctx context.Context, finished <-chan struct{}, errCh chan<- *indexMergeTableTask, err error) { + logutil.BgLogger().Error("IndexMergeReaderExecutor.syncErr", zap.Error(err)) + doneCh := make(chan error, 1) + doneCh <- err + task := &indexMergeTableTask{ + lookupTableTask: lookupTableTask{ + doneCh: doneCh, + }, + } + + // ctx.Done and finished is to avoid write channel is stuck. + select { + case <-ctx.Done(): + return + case <-finished: + return + case errCh <- task: + return + } +} + +// needPartitionHandle indicates whether we need create a partitionHandle or not. +// If the schema from planner part contains ExtraPhysTblID, +// we need create a partitionHandle, otherwise create a normal handle. +// In TableRowIDScan, the partitionHandle will be used to create key ranges. +func (w *partialIndexWorker) needPartitionHandle() (bool, error) { + cols := w.plan[0].Schema().Columns + outputOffsets := w.dagPB.OutputOffsets + col := cols[outputOffsets[len(outputOffsets)-1]] + + is := w.plan[0].(*plannercore.PhysicalIndexScan) + needPartitionHandle := w.partitionTableMode && len(w.byItems) > 0 || is.Index.Global + hasExtraCol := col.ID == model.ExtraPhysTblID + + // There will be two needPartitionHandle != hasExtraCol situations. + // Only `needPartitionHandle` == true and `hasExtraCol` == false are not allowed. + // `ExtraPhysTblID` will be used in `SelectLock` when `needPartitionHandle` == false and `hasExtraCol` == true. 
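+ // Summarized as a table (derived from the checks below):
+ //
+ //	needPartitionHandle | hasExtraCol | outcome
+ //	false               | false       | caller builds a normal handle
+ //	false               | true        | allowed; ExtraPhysTblID kept for SelectLock
+ //	true                | true        | caller builds a partition handle
+ //	true                | false       | internal error (rejected below)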
+ if needPartitionHandle && !hasExtraCol { + return needPartitionHandle, errors.Errorf("Internal error, needPartitionHandle != ret") + } + return needPartitionHandle, nil +} + +func (w *partialIndexWorker) fetchHandles( + ctx context.Context, + results []distsql.SelectResult, + exitCh <-chan struct{}, + fetchCh chan<- *indexMergeTableTask, + finished <-chan struct{}, + handleCols plannerutil.HandleCols, + partialPlanIndex int) (count int64, err error) { + tps := w.getRetTpsForIndexScan(handleCols) + chk := chunk.NewChunkWithCapacity(tps, w.maxChunkSize) + for i := 0; i < len(results); { + start := time.Now() + handles, retChunk, err := w.extractTaskHandles(ctx, chk, results[i], handleCols) + if err != nil { + syncErr(ctx, finished, fetchCh, err) + return count, err + } + if len(handles) == 0 { + i++ + continue + } + count += int64(len(handles)) + task := w.buildTableTask(handles, retChunk, i, partialPlanIndex) + if w.stats != nil { + atomic.AddInt64(&w.stats.FetchIdxTime, int64(time.Since(start))) + } + select { + case <-ctx.Done(): + return count, ctx.Err() + case <-exitCh: + return count, nil + case <-finished: + return count, nil + case fetchCh <- task: + } + } + return count, nil +} + +func (w *partialIndexWorker) getRetTpsForIndexScan(handleCols plannerutil.HandleCols) []*types.FieldType { + var tps []*types.FieldType + if len(w.byItems) != 0 { + for _, item := range w.byItems { + tps = append(tps, item.Expr.GetType(w.sc.GetExprCtx().GetEvalCtx())) + } + } + tps = append(tps, handleCols.GetFieldsTypes()...) + if ok, _ := w.needPartitionHandle(); ok { + tps = append(tps, types.NewFieldType(mysql.TypeLonglong)) + } + return tps +} + +func (w *partialIndexWorker) extractTaskHandles(ctx context.Context, chk *chunk.Chunk, idxResult distsql.SelectResult, handleCols plannerutil.HandleCols) ( + handles []kv.Handle, retChk *chunk.Chunk, err error) { + handles = make([]kv.Handle, 0, w.batchSize) + if len(w.byItems) != 0 { + retChk = chunk.NewChunkWithCapacity(w.getRetTpsForIndexScan(handleCols), w.batchSize) + } + var memUsage int64 + var chunkRowOffset int + defer w.memTracker.Consume(-memUsage) + for len(handles) < w.batchSize { + requiredRows := w.batchSize - len(handles) + if w.pushedLimit != nil { + if w.pushedLimit.Offset+w.pushedLimit.Count <= w.scannedKeys { + return handles, retChk, nil + } + requiredRows = min(int(w.pushedLimit.Offset+w.pushedLimit.Count-w.scannedKeys), requiredRows) + } + chk.SetRequiredRows(requiredRows, w.maxChunkSize) + start := time.Now() + err = errors.Trace(idxResult.Next(ctx, chk)) + if err != nil { + return nil, nil, err + } + if w.stats != nil && w.idxID != 0 { + w.sc.GetSessionVars().StmtCtx.RuntimeStatsColl.GetBasicRuntimeStats(w.idxID).Record(time.Since(start), chk.NumRows()) + } + if chk.NumRows() == 0 { + failpoint.Inject("testIndexMergeErrorPartialIndexWorker", func(v failpoint.Value) { + failpoint.Return(handles, nil, errors.New(v.(string))) + }) + return handles, retChk, nil + } + memDelta := chk.MemoryUsage() + memUsage += memDelta + w.memTracker.Consume(memDelta) + for chunkRowOffset = 0; chunkRowOffset < chk.NumRows(); chunkRowOffset++ { + if w.pushedLimit != nil { + w.scannedKeys++ + if w.scannedKeys > (w.pushedLimit.Offset + w.pushedLimit.Count) { + // Skip the handles after Offset+Count. 
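+ // (Illustration: with Offset = 5 and Count = 3, only the first 8 scanned
+ // keys can contribute handles; once scannedKeys passes 8, the remaining
+ // rows of this chunk are not converted into handles.)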
+ break + } + } + var handle kv.Handle + ok, err1 := w.needPartitionHandle() + if err1 != nil { + return nil, nil, err1 + } + if ok { + handle, err = handleCols.BuildPartitionHandleFromIndexRow(chk.GetRow(chunkRowOffset)) + } else { + handle, err = handleCols.BuildHandleFromIndexRow(chk.GetRow(chunkRowOffset)) + } + if err != nil { + return nil, nil, err + } + handles = append(handles, handle) + } + // used for order by + if len(w.byItems) != 0 { + retChk.Append(chk, 0, chunkRowOffset) + } + } + w.batchSize *= 2 + if w.batchSize > w.maxBatchSize { + w.batchSize = w.maxBatchSize + } + return handles, retChk, nil +} + +func (w *partialIndexWorker) buildTableTask(handles []kv.Handle, retChk *chunk.Chunk, parTblIdx int, partialPlanID int) *indexMergeTableTask { + task := &indexMergeTableTask{ + lookupTableTask: lookupTableTask{ + handles: handles, + idxRows: retChk, + }, + parTblIdx: parTblIdx, + partialPlanID: partialPlanID, + } + + if w.prunedPartitions != nil { + task.partitionTable = w.prunedPartitions[parTblIdx] + } + + task.doneCh = make(chan error, 1) + return task +} + +type indexMergeTableScanWorker struct { + stats *IndexMergeRuntimeStat + workCh <-chan *indexMergeTableTask + finished <-chan struct{} + indexMergeExec *IndexMergeReaderExecutor + tblPlans []base.PhysicalPlan + + // memTracker is used to track the memory usage of this executor. + memTracker *memory.Tracker +} + +func (w *indexMergeTableScanWorker) pickAndExecTask(ctx context.Context, task **indexMergeTableTask) { + var ok bool + for { + waitStart := time.Now() + select { + case <-ctx.Done(): + return + case <-w.finished: + return + case *task, ok = <-w.workCh: + if !ok { + return + } + } + // Make sure panic failpoint is after fetch task from workCh. + // Otherwise, cannot send error to task.doneCh. 
+ failpoint.Inject("testIndexMergePanicTableScanWorker", nil) + execStart := time.Now() + err := w.executeTask(ctx, *task) + if w.stats != nil { + atomic.AddInt64(&w.stats.WaitTime, int64(execStart.Sub(waitStart))) + atomic.AddInt64(&w.stats.FetchRow, int64(time.Since(execStart))) + atomic.AddInt64(&w.stats.TableTaskNum, 1) + } + failpoint.Inject("testIndexMergePickAndExecTaskPanic", nil) + select { + case <-ctx.Done(): + return + case <-w.finished: + return + case (*task).doneCh <- err: + } + } +} + +func (*indexMergeTableScanWorker) handleTableScanWorkerPanic(ctx context.Context, finished <-chan struct{}, task **indexMergeTableTask, worker string) func(r any) { + return func(r any) { + if r == nil { + logutil.BgLogger().Debug("worker finish without panic", zap.Any("worker", worker)) + return + } + + err4Panic := errors.Errorf("%s: %v", worker, r) + logutil.Logger(ctx).Error(err4Panic.Error()) + if *task != nil { + select { + case <-ctx.Done(): + return + case <-finished: + return + case (*task).doneCh <- err4Panic: + return + } + } + } +} + +func (w *indexMergeTableScanWorker) executeTask(ctx context.Context, task *indexMergeTableTask) error { + tbl := w.indexMergeExec.table + if w.indexMergeExec.partitionTableMode && task.partitionTable != nil { + tbl = task.partitionTable + } + tableReader, err := w.indexMergeExec.buildFinalTableReader(ctx, tbl, task.handles) + if err != nil { + logutil.Logger(ctx).Error("build table reader failed", zap.Error(err)) + return err + } + defer func() { terror.Log(exec.Close(tableReader)) }() + task.memTracker = w.memTracker + memUsage := int64(cap(task.handles) * 8) + task.memUsage = memUsage + task.memTracker.Consume(memUsage) + handleCnt := len(task.handles) + task.rows = make([]chunk.Row, 0, handleCnt) + for { + chk := exec.TryNewCacheChunk(tableReader) + err = exec.Next(ctx, tableReader, chk) + if err != nil { + logutil.Logger(ctx).Error("table reader fetch next chunk failed", zap.Error(err)) + return err + } + if chk.NumRows() == 0 { + break + } + memUsage = chk.MemoryUsage() + task.memUsage += memUsage + task.memTracker.Consume(memUsage) + iter := chunk.NewIterator4Chunk(chk) + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + task.rows = append(task.rows, row) + } + } + + if w.indexMergeExec.keepOrder { + // Because len(outputOffsets) == tableScan.Schema().Len(), + // so we could use row.GetInt64(idx) to get partition ID here. + // TODO: We could add plannercore.PartitionHandleCols to unify them. 
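+ // The original index order is restored below by looking each fetched row's
+ // handle up in task.indexOrder and then sorting the rows by that recorded
+ // position (sort.Sort(task)).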
+ physicalTableIDIdx := -1 + for i, c := range w.indexMergeExec.Schema().Columns { + if c.ID == model.ExtraPhysTblID { + physicalTableIDIdx = i + break + } + } + task.rowIdx = make([]int, 0, len(task.rows)) + for _, row := range task.rows { + handle, err := w.indexMergeExec.handleCols.BuildHandle(row) + if err != nil { + return err + } + if w.indexMergeExec.partitionTableMode && physicalTableIDIdx != -1 { + handle = kv.NewPartitionHandle(row.GetInt64(physicalTableIDIdx), handle) + } + rowIdx, _ := task.indexOrder.Get(handle) + task.rowIdx = append(task.rowIdx, rowIdx.(int)) + } + sort.Sort(task) + } + + memUsage = int64(cap(task.rows)) * int64(unsafe.Sizeof(chunk.Row{})) + task.memUsage += memUsage + task.memTracker.Consume(memUsage) + if handleCnt != len(task.rows) && len(w.tblPlans) == 1 { + return errors.Errorf("handle count %d isn't equal to value count %d", handleCnt, len(task.rows)) + } + return nil +} + +// IndexMergeRuntimeStat record the indexMerge runtime stat +type IndexMergeRuntimeStat struct { + IndexMergeProcess time.Duration + FetchIdxTime int64 + WaitTime int64 + FetchRow int64 + TableTaskNum int64 + Concurrency int +} + +func (e *IndexMergeRuntimeStat) String() string { + var buf bytes.Buffer + if e.FetchIdxTime != 0 { + buf.WriteString(fmt.Sprintf("index_task:{fetch_handle:%s", time.Duration(e.FetchIdxTime))) + if e.IndexMergeProcess != 0 { + buf.WriteString(fmt.Sprintf(", merge:%s", e.IndexMergeProcess)) + } + buf.WriteByte('}') + } + if e.FetchRow != 0 { + if buf.Len() > 0 { + buf.WriteByte(',') + } + fmt.Fprintf(&buf, " table_task:{num:%d, concurrency:%d, fetch_row:%s, wait_time:%s}", e.TableTaskNum, e.Concurrency, time.Duration(e.FetchRow), time.Duration(e.WaitTime)) + } + return buf.String() +} + +// Clone implements the RuntimeStats interface. +func (e *IndexMergeRuntimeStat) Clone() execdetails.RuntimeStats { + newRs := *e + return &newRs +} + +// Merge implements the RuntimeStats interface. +func (e *IndexMergeRuntimeStat) Merge(other execdetails.RuntimeStats) { + tmp, ok := other.(*IndexMergeRuntimeStat) + if !ok { + return + } + e.IndexMergeProcess += tmp.IndexMergeProcess + e.FetchIdxTime += tmp.FetchIdxTime + e.FetchRow += tmp.FetchRow + e.WaitTime += e.WaitTime + e.TableTaskNum += tmp.TableTaskNum +} + +// Tp implements the RuntimeStats interface. 
+func (*IndexMergeRuntimeStat) Tp() int { + return execdetails.TpIndexMergeRunTimeStats +} diff --git a/pkg/executor/infoschema_reader.go b/pkg/executor/infoschema_reader.go index 6b8455e281f39..6cdcbf9c2dd44 100644 --- a/pkg/executor/infoschema_reader.go +++ b/pkg/executor/infoschema_reader.go @@ -3478,7 +3478,7 @@ func (e *memtableRetriever) setDataForAttributes(ctx context.Context, sctx sessi checker := privilege.GetPrivilegeManager(sctx) rules, err := infosync.GetAllLabelRules(context.TODO()) skipValidateTable := false - failpoint.Inject("mockOutputOfAttributes", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockOutputOfAttributes")); _err_ == nil { convert := func(i any) []any { return []any{i} } @@ -3513,7 +3513,7 @@ func (e *memtableRetriever) setDataForAttributes(ctx context.Context, sctx sessi } err = nil skipValidateTable = true - }) + } if err != nil { return errors.Wrap(err, "get the label rules failed") diff --git a/pkg/executor/infoschema_reader.go__failpoint_stash__ b/pkg/executor/infoschema_reader.go__failpoint_stash__ new file mode 100644 index 0000000000000..6b8455e281f39 --- /dev/null +++ b/pkg/executor/infoschema_reader.go__failpoint_stash__ @@ -0,0 +1,3878 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
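A note on the infoschema_reader.go hunk and the accompanying __failpoint_stash__ file above: this matches the shape produced by pingcap/failpoint's source rewriter (failpoint-ctl enable), which expands the declarative failpoint.Inject marker into plain control flow around failpoint.Eval and stashes a copy of the original file so it can be restored later. Schematically, with a made-up failpoint name and body (only the rewrite pattern itself is taken from the hunk above):

	// Before rewriting: a declarative marker that is inert in normal builds.
	failpoint.Inject("mockSomething", func() {
		skipValidation = true
	})

	// After `failpoint-ctl enable`: the marker becomes an ordinary if statement.
	if _, _err_ := failpoint.Eval(_curpkg_("mockSomething")); _err_ == nil {
		skipValidation = true
	}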
+ +package executor + +import ( + "bytes" + "context" + "encoding/hex" + "encoding/json" + "fmt" + "math" + "slices" + "strconv" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/deadlock" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + rmpb "github.com/pingcap/kvproto/pkg/resource_manager" + "github.com/pingcap/tidb/pkg/ddl/label" + "github.com/pingcap/tidb/pkg/ddl/placement" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/domain/resourcegroup" + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/executor/internal/pdhelper" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/privilege" + "github.com/pingcap/tidb/pkg/privilege/privileges" + "github.com/pingcap/tidb/pkg/session/txninfo" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/statistics" + "github.com/pingcap/tidb/pkg/statistics/handle/cache" + "github.com/pingcap/tidb/pkg/store/helper" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/collate" + "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" + "github.com/pingcap/tidb/pkg/util/deadlockhistory" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/hint" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/keydecoder" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/resourcegrouptag" + "github.com/pingcap/tidb/pkg/util/sem" + "github.com/pingcap/tidb/pkg/util/servermemorylimit" + "github.com/pingcap/tidb/pkg/util/set" + "github.com/pingcap/tidb/pkg/util/stringutil" + "github.com/pingcap/tidb/pkg/util/syncutil" + "github.com/tikv/client-go/v2/tikv" + "github.com/tikv/client-go/v2/tikvrpc" + "github.com/tikv/client-go/v2/txnkv/txnlock" + pd "github.com/tikv/pd/client/http" + "go.uber.org/zap" +) + +type memtableRetriever struct { + dummyCloser + table *model.TableInfo + columns []*model.ColumnInfo + rows [][]types.Datum + rowIdx int + retrieved bool + initialized bool + extractor base.MemTablePredicateExtractor + is infoschema.InfoSchema + memTracker *memory.Tracker +} + +// retrieve implements the infoschemaRetriever interface +func (e *memtableRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { + if e.table.Name.O == infoschema.TableClusterInfo && !hasPriv(sctx, mysql.ProcessPriv) { + return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("PROCESS") + } + if e.retrieved { + return nil, nil + } + + // Cache the ret full rows in schemataRetriever + if !e.initialized { + is := sctx.GetInfoSchema().(infoschema.InfoSchema) + 
e.is = is + + var getAllSchemas = func() []model.CIStr { + dbs := is.AllSchemaNames() + slices.SortFunc(dbs, func(a, b model.CIStr) int { + return strings.Compare(a.L, b.L) + }) + return dbs + } + + var err error + switch e.table.Name.O { + case infoschema.TableSchemata: + err = e.setDataFromSchemata(sctx) + case infoschema.TableStatistics: + err = e.setDataForStatistics(ctx, sctx) + case infoschema.TableTables: + err = e.setDataFromTables(ctx, sctx) + case infoschema.TableReferConst: + dbs := getAllSchemas() + err = e.setDataFromReferConst(ctx, sctx, dbs) + case infoschema.TableSequences: + dbs := getAllSchemas() + err = e.setDataFromSequences(ctx, sctx, dbs) + case infoschema.TablePartitions: + err = e.setDataFromPartitions(ctx, sctx) + case infoschema.TableClusterInfo: + err = e.dataForTiDBClusterInfo(sctx) + case infoschema.TableAnalyzeStatus: + err = e.setDataForAnalyzeStatus(ctx, sctx) + case infoschema.TableTiDBIndexes: + dbs := getAllSchemas() + err = e.setDataFromIndexes(ctx, sctx, dbs) + case infoschema.TableViews: + dbs := getAllSchemas() + err = e.setDataFromViews(ctx, sctx, dbs) + case infoschema.TableEngines: + e.setDataFromEngines() + case infoschema.TableCharacterSets: + e.setDataFromCharacterSets() + case infoschema.TableCollations: + e.setDataFromCollations() + case infoschema.TableKeyColumn: + dbs := getAllSchemas() + err = e.setDataFromKeyColumnUsage(ctx, sctx, dbs) + case infoschema.TableMetricTables: + e.setDataForMetricTables() + case infoschema.TableProfiling: + e.setDataForPseudoProfiling(sctx) + case infoschema.TableCollationCharacterSetApplicability: + e.dataForCollationCharacterSetApplicability() + case infoschema.TableProcesslist: + e.setDataForProcessList(sctx) + case infoschema.ClusterTableProcesslist: + err = e.setDataForClusterProcessList(sctx) + case infoschema.TableUserPrivileges: + e.setDataFromUserPrivileges(sctx) + case infoschema.TableTiKVRegionStatus: + err = e.setDataForTiKVRegionStatus(ctx, sctx) + case infoschema.TableTiDBHotRegions: + err = e.setDataForTiDBHotRegions(ctx, sctx) + case infoschema.TableConstraints: + dbs := getAllSchemas() + err = e.setDataFromTableConstraints(ctx, sctx, dbs) + case infoschema.TableSessionVar: + e.rows, err = infoschema.GetDataFromSessionVariables(ctx, sctx) + case infoschema.TableTiDBServersInfo: + err = e.setDataForServersInfo(sctx) + case infoschema.TableTiFlashReplica: + dbs := getAllSchemas() + err = e.dataForTableTiFlashReplica(ctx, sctx, dbs) + case infoschema.TableTiKVStoreStatus: + err = e.dataForTiKVStoreStatus(ctx, sctx) + case infoschema.TableClientErrorsSummaryGlobal, + infoschema.TableClientErrorsSummaryByUser, + infoschema.TableClientErrorsSummaryByHost: + err = e.setDataForClientErrorsSummary(sctx, e.table.Name.O) + case infoschema.TableAttributes: + err = e.setDataForAttributes(ctx, sctx, is) + case infoschema.TablePlacementPolicies: + err = e.setDataFromPlacementPolicies(sctx) + case infoschema.TableTrxSummary: + err = e.setDataForTrxSummary(sctx) + case infoschema.ClusterTableTrxSummary: + err = e.setDataForClusterTrxSummary(sctx) + case infoschema.TableVariablesInfo: + err = e.setDataForVariablesInfo(sctx) + case infoschema.TableUserAttributes: + err = e.setDataForUserAttributes(ctx, sctx) + case infoschema.TableMemoryUsage: + err = e.setDataForMemoryUsage() + case infoschema.ClusterTableMemoryUsage: + err = e.setDataForClusterMemoryUsage(sctx) + case infoschema.TableMemoryUsageOpsHistory: + err = e.setDataForMemoryUsageOpsHistory() + case infoschema.ClusterTableMemoryUsageOpsHistory: + err = 
e.setDataForClusterMemoryUsageOpsHistory(sctx) + case infoschema.TableResourceGroups: + err = e.setDataFromResourceGroups() + case infoschema.TableRunawayWatches: + err = e.setDataFromRunawayWatches(sctx) + case infoschema.TableCheckConstraints: + dbs := getAllSchemas() + err = e.setDataFromCheckConstraints(ctx, sctx, dbs) + case infoschema.TableTiDBCheckConstraints: + dbs := getAllSchemas() + err = e.setDataFromTiDBCheckConstraints(ctx, sctx, dbs) + case infoschema.TableKeywords: + err = e.setDataFromKeywords() + case infoschema.TableTiDBIndexUsage: + dbs := getAllSchemas() + err = e.setDataFromIndexUsage(ctx, sctx, dbs) + case infoschema.ClusterTableTiDBIndexUsage: + dbs := getAllSchemas() + err = e.setDataForClusterIndexUsage(ctx, sctx, dbs) + } + if err != nil { + return nil, err + } + e.initialized = true + if e.memTracker != nil { + e.memTracker.Consume(calculateDatumsSize(e.rows)) + } + } + + // Adjust the amount of each return + maxCount := 1024 + retCount := maxCount + if e.rowIdx+maxCount > len(e.rows) { + retCount = len(e.rows) - e.rowIdx + e.retrieved = true + } + ret := make([][]types.Datum, retCount) + for i := e.rowIdx; i < e.rowIdx+retCount; i++ { + ret[i-e.rowIdx] = e.rows[i] + } + e.rowIdx += retCount + return adjustColumns(ret, e.columns, e.table), nil +} + +func getAutoIncrementID( + is infoschema.InfoSchema, + sctx sessionctx.Context, + tblInfo *model.TableInfo, +) int64 { + tbl, ok := is.TableByID(tblInfo.ID) + if !ok { + return 0 + } + return tbl.Allocators(sctx.GetTableCtx()).Get(autoid.AutoIncrementType).Base() + 1 +} + +func hasPriv(ctx sessionctx.Context, priv mysql.PrivilegeType) bool { + pm := privilege.GetPrivilegeManager(ctx) + if pm == nil { + // internal session created with createSession doesn't has the PrivilegeManager. For most experienced cases before, + // we use it like this: + // ``` + // checker := privilege.GetPrivilegeManager(ctx) + // if checker != nil && !checker.RequestVerification(ctx.GetSessionVars().ActiveRoles, schema.Name.L, table.Name.L, "", mysql.AllPrivMask) { + // continue + // } + // do something. + // ``` + // So once the privilege manager is nil, it's a signature of internal sql, so just passing the checker through. 
+ return true + } + return pm.RequestVerification(ctx.GetSessionVars().ActiveRoles, "", "", "", priv) +} + +func (e *memtableRetriever) setDataForVariablesInfo(ctx sessionctx.Context) error { + sysVars := variable.GetSysVars() + rows := make([][]types.Datum, 0, len(sysVars)) + for _, sv := range sysVars { + if infoschema.SysVarHiddenForSem(ctx, sv.Name) { + continue + } + currentVal, err := ctx.GetSessionVars().GetSessionOrGlobalSystemVar(context.Background(), sv.Name) + if err != nil { + currentVal = "" + } + isNoop := "NO" + if sv.IsNoop { + isNoop = "YES" + } + defVal := sv.Value + if sv.HasGlobalScope() { + defVal = variable.GlobalSystemVariableInitialValue(sv.Name, defVal) + } + row := types.MakeDatums( + sv.Name, // VARIABLE_NAME + sv.Scope.String(), // VARIABLE_SCOPE + defVal, // DEFAULT_VALUE + currentVal, // CURRENT_VALUE + sv.MinValue, // MIN_VALUE + sv.MaxValue, // MAX_VALUE + nil, // POSSIBLE_VALUES + isNoop, // IS_NOOP + ) + // min and max value is only supported for numeric types + if !(sv.Type == variable.TypeUnsigned || sv.Type == variable.TypeInt || sv.Type == variable.TypeFloat) { + row[4].SetNull() + row[5].SetNull() + } + if sv.Type == variable.TypeEnum { + possibleValues := strings.Join(sv.PossibleValues, ",") + row[6].SetString(possibleValues, mysql.DefaultCollationName) + } + rows = append(rows, row) + } + e.rows = rows + return nil +} + +func (e *memtableRetriever) setDataForUserAttributes(ctx context.Context, sctx sessionctx.Context) error { + exec := sctx.GetRestrictedSQLExecutor() + chunkRows, _, err := exec.ExecRestrictedSQL(ctx, nil, `SELECT user, host, JSON_UNQUOTE(JSON_EXTRACT(user_attributes, '$.metadata')) FROM mysql.user`) + if err != nil { + return err + } + if len(chunkRows) == 0 { + return nil + } + rows := make([][]types.Datum, 0, len(chunkRows)) + for _, chunkRow := range chunkRows { + if chunkRow.Len() != 3 { + continue + } + user := chunkRow.GetString(0) + host := chunkRow.GetString(1) + // Compatible with results in MySQL + var attribute any + if attribute = chunkRow.GetString(2); attribute == "" { + attribute = nil + } + row := types.MakeDatums(user, host, attribute) + rows = append(rows, row) + } + + e.rows = rows + return nil +} + +func (e *memtableRetriever) setDataFromSchemata(ctx sessionctx.Context) error { + checker := privilege.GetPrivilegeManager(ctx) + ex, ok := e.extractor.(*plannercore.InfoSchemaSchemataExtractor) + if !ok { + return errors.Errorf("wrong extractor type: %T, expected InfoSchemaSchemataExtractor", e.extractor) + } + if ex.SkipRequest { + return nil + } + schemas := ex.ListSchemas(e.is) + rows := make([][]types.Datum, 0, len(schemas)) + + for _, schemaName := range schemas { + schema, _ := e.is.SchemaByName(schemaName) + charset := mysql.DefaultCharset + collation := mysql.DefaultCollationName + + if len(schema.Charset) > 0 { + charset = schema.Charset // Overwrite default + } + + if len(schema.Collate) > 0 { + collation = schema.Collate // Overwrite default + } + var policyName any + if schema.PlacementPolicyRef != nil { + policyName = schema.PlacementPolicyRef.Name.O + } + + if checker != nil && !checker.RequestVerification(ctx.GetSessionVars().ActiveRoles, schema.Name.L, "", "", mysql.AllPrivMask) { + continue + } + record := types.MakeDatums( + infoschema.CatalogVal, // CATALOG_NAME + schema.Name.O, // SCHEMA_NAME + charset, // DEFAULT_CHARACTER_SET_NAME + collation, // DEFAULT_COLLATION_NAME + nil, // SQL_PATH + policyName, // TIDB_PLACEMENT_POLICY_NAME + ) + rows = append(rows, record) + } + e.rows = rows + return 
nil +} + +func (e *memtableRetriever) setDataForStatistics(ctx context.Context, sctx sessionctx.Context) error { + checker := privilege.GetPrivilegeManager(sctx) + ex, ok := e.extractor.(*plannercore.InfoSchemaStatisticsExtractor) + if !ok { + return errors.Errorf("wrong extractor type: %T, expected InfoSchemaStatisticsExtractor", e.extractor) + } + if ex.SkipRequest { + return nil + } + schemas, tables, err := ex.ListSchemasAndTables(ctx, e.is) + if err != nil { + return errors.Trace(err) + } + for i, table := range tables { + schema := schemas[i] + if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, table.Name.L, "", mysql.AllPrivMask) { + continue + } + e.setDataForStatisticsInTable(schema, table, ex) + } + return nil +} + +func (e *memtableRetriever) setDataForStatisticsInTable(schema model.CIStr, table *model.TableInfo, extractor *plannercore.InfoSchemaStatisticsExtractor) { + var rows [][]types.Datum + if table.PKIsHandle { + if !extractor.Filter("index_name", "primary") { + for _, col := range table.Columns { + if mysql.HasPriKeyFlag(col.GetFlag()) { + record := types.MakeDatums( + infoschema.CatalogVal, // TABLE_CATALOG + schema.O, // TABLE_SCHEMA + table.Name.O, // TABLE_NAME + "0", // NON_UNIQUE + schema.O, // INDEX_SCHEMA + "PRIMARY", // INDEX_NAME + 1, // SEQ_IN_INDEX + col.Name.O, // COLUMN_NAME + "A", // COLLATION + 0, // CARDINALITY + nil, // SUB_PART + nil, // PACKED + "", // NULLABLE + "BTREE", // INDEX_TYPE + "", // COMMENT + "", // INDEX_COMMENT + "YES", // IS_VISIBLE + nil, // Expression + ) + rows = append(rows, record) + } + } + } + } + nameToCol := make(map[string]*model.ColumnInfo, len(table.Columns)) + for _, c := range table.Columns { + nameToCol[c.Name.L] = c + } + for _, index := range table.Indices { + if extractor.Filter("index_name", index.Name.L) { + continue + } + nonUnique := "1" + if index.Unique { + nonUnique = "0" + } + for i, key := range index.Columns { + col := nameToCol[key.Name.L] + nullable := "YES" + if mysql.HasNotNullFlag(col.GetFlag()) { + nullable = "" + } + + visible := "YES" + if index.Invisible { + visible = "NO" + } + + colName := col.Name.O + var expression any + expression = nil + tblCol := table.Columns[col.Offset] + if tblCol.Hidden { + colName = "NULL" + expression = tblCol.GeneratedExprString + } + + record := types.MakeDatums( + infoschema.CatalogVal, // TABLE_CATALOG + schema.O, // TABLE_SCHEMA + table.Name.O, // TABLE_NAME + nonUnique, // NON_UNIQUE + schema.O, // INDEX_SCHEMA + index.Name.O, // INDEX_NAME + i+1, // SEQ_IN_INDEX + colName, // COLUMN_NAME + "A", // COLLATION + 0, // CARDINALITY + nil, // SUB_PART + nil, // PACKED + nullable, // NULLABLE + "BTREE", // INDEX_TYPE + "", // COMMENT + index.Comment, // INDEX_COMMENT + visible, // IS_VISIBLE + expression, // Expression + ) + rows = append(rows, record) + } + } + e.rows = append(e.rows, rows...) 
+} + +func (e *memtableRetriever) setDataFromReferConst(ctx context.Context, sctx sessionctx.Context, schemas []model.CIStr) error { + checker := privilege.GetPrivilegeManager(sctx) + var rows [][]types.Datum + extractor, ok := e.extractor.(*plannercore.InfoSchemaBaseExtractor) + if ok && extractor.SkipRequest { + return nil + } + for _, schema := range schemas { + if ok && extractor.Filter("constraint_schema", schema.L) { + continue + } + tables, err := e.is.SchemaTableInfos(ctx, schema) + if err != nil { + return errors.Trace(err) + } + for _, table := range tables { + if ok && extractor.Filter("table_name", table.Name.L) { + continue + } + if !table.IsBaseTable() { + continue + } + if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, table.Name.L, "", mysql.AllPrivMask) { + continue + } + for _, fk := range table.ForeignKeys { + if ok && extractor.Filter("constraint_name", fk.Name.L) { + continue + } + updateRule, deleteRule := "NO ACTION", "NO ACTION" + if model.ReferOptionType(fk.OnUpdate) != 0 { + updateRule = model.ReferOptionType(fk.OnUpdate).String() + } + if model.ReferOptionType(fk.OnDelete) != 0 { + deleteRule = model.ReferOptionType(fk.OnDelete).String() + } + record := types.MakeDatums( + infoschema.CatalogVal, // CONSTRAINT_CATALOG + schema.O, // CONSTRAINT_SCHEMA + fk.Name.O, // CONSTRAINT_NAME + infoschema.CatalogVal, // UNIQUE_CONSTRAINT_CATALOG + schema.O, // UNIQUE_CONSTRAINT_SCHEMA + "PRIMARY", // UNIQUE_CONSTRAINT_NAME + "NONE", // MATCH_OPTION + updateRule, // UPDATE_RULE + deleteRule, // DELETE_RULE + table.Name.O, // TABLE_NAME + fk.RefTable.O, // REFERENCED_TABLE_NAME + ) + rows = append(rows, record) + } + } + } + e.rows = rows + return nil +} + +func (e *memtableRetriever) updateStatsCacheIfNeed() bool { + for _, col := range e.columns { + // only the following columns need stats cache. 
+ if col.Name.O == "AVG_ROW_LENGTH" || col.Name.O == "DATA_LENGTH" || col.Name.O == "INDEX_LENGTH" || col.Name.O == "TABLE_ROWS" { + return true + } + } + return false +} + +func (e *memtableRetriever) setDataFromOneTable( + sctx sessionctx.Context, + loc *time.Location, + checker privilege.Manager, + schema model.CIStr, + table *model.TableInfo, + rows [][]types.Datum, + useStatsCache bool, +) ([][]types.Datum, error) { + collation := table.Collate + if collation == "" { + collation = mysql.DefaultCollationName + } + createTime := types.NewTime(types.FromGoTime(table.GetUpdateTime().In(loc)), mysql.TypeDatetime, types.DefaultFsp) + + createOptions := "" + + if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, table.Name.L, "", mysql.AllPrivMask) { + return rows, nil + } + pkType := "NONCLUSTERED" + if !table.IsView() { + if table.GetPartitionInfo() != nil { + createOptions = "partitioned" + } else if table.TableCacheStatusType == model.TableCacheStatusEnable { + createOptions = "cached=on" + } + var autoIncID any + hasAutoIncID, _ := infoschema.HasAutoIncrementColumn(table) + if hasAutoIncID { + autoIncID = getAutoIncrementID(e.is, sctx, table) + } + tableType := "BASE TABLE" + if util.IsSystemView(schema.L) { + tableType = "SYSTEM VIEW" + } + if table.IsSequence() { + tableType = "SEQUENCE" + } + if table.HasClusteredIndex() { + pkType = "CLUSTERED" + } + shardingInfo := infoschema.GetShardingInfo(schema, table) + var policyName any + if table.PlacementPolicyRef != nil { + policyName = table.PlacementPolicyRef.Name.O + } + + var rowCount, avgRowLength, dataLength, indexLength uint64 + if useStatsCache { + if table.GetPartitionInfo() == nil { + err := cache.TableRowStatsCache.UpdateByID(sctx, table.ID) + if err != nil { + return rows, err + } + } else { + // needs to update all partitions for partition table. 
+ for _, pi := range table.GetPartitionInfo().Definitions { + err := cache.TableRowStatsCache.UpdateByID(sctx, pi.ID) + if err != nil { + return rows, err + } + } + } + rowCount, avgRowLength, dataLength, indexLength = cache.TableRowStatsCache.EstimateDataLength(table) + } + + record := types.MakeDatums( + infoschema.CatalogVal, // TABLE_CATALOG + schema.O, // TABLE_SCHEMA + table.Name.O, // TABLE_NAME + tableType, // TABLE_TYPE + "InnoDB", // ENGINE + uint64(10), // VERSION + "Compact", // ROW_FORMAT + rowCount, // TABLE_ROWS + avgRowLength, // AVG_ROW_LENGTH + dataLength, // DATA_LENGTH + uint64(0), // MAX_DATA_LENGTH + indexLength, // INDEX_LENGTH + uint64(0), // DATA_FREE + autoIncID, // AUTO_INCREMENT + createTime, // CREATE_TIME + nil, // UPDATE_TIME + nil, // CHECK_TIME + collation, // TABLE_COLLATION + nil, // CHECKSUM + createOptions, // CREATE_OPTIONS + table.Comment, // TABLE_COMMENT + table.ID, // TIDB_TABLE_ID + shardingInfo, // TIDB_ROW_ID_SHARDING_INFO + pkType, // TIDB_PK_TYPE + policyName, // TIDB_PLACEMENT_POLICY_NAME + ) + rows = append(rows, record) + } else { + record := types.MakeDatums( + infoschema.CatalogVal, // TABLE_CATALOG + schema.O, // TABLE_SCHEMA + table.Name.O, // TABLE_NAME + "VIEW", // TABLE_TYPE + nil, // ENGINE + nil, // VERSION + nil, // ROW_FORMAT + nil, // TABLE_ROWS + nil, // AVG_ROW_LENGTH + nil, // DATA_LENGTH + nil, // MAX_DATA_LENGTH + nil, // INDEX_LENGTH + nil, // DATA_FREE + nil, // AUTO_INCREMENT + createTime, // CREATE_TIME + nil, // UPDATE_TIME + nil, // CHECK_TIME + nil, // TABLE_COLLATION + nil, // CHECKSUM + nil, // CREATE_OPTIONS + "VIEW", // TABLE_COMMENT + table.ID, // TIDB_TABLE_ID + nil, // TIDB_ROW_ID_SHARDING_INFO + pkType, // TIDB_PK_TYPE + nil, // TIDB_PLACEMENT_POLICY_NAME + ) + rows = append(rows, record) + } + return rows, nil +} + +func (e *memtableRetriever) setDataFromTables(ctx context.Context, sctx sessionctx.Context) error { + useStatsCache := e.updateStatsCacheIfNeed() + checker := privilege.GetPrivilegeManager(sctx) + + var rows [][]types.Datum + loc := sctx.GetSessionVars().TimeZone + if loc == nil { + loc = time.Local + } + ex, ok := e.extractor.(*plannercore.InfoSchemaTablesExtractor) + if !ok { + return errors.Errorf("wrong extractor type: %T, expected InfoSchemaTablesExtractor", e.extractor) + } + if ex.SkipRequest { + return nil + } + + schemas, tables, err := ex.ListSchemasAndTables(ctx, e.is) + if err != nil { + return errors.Trace(err) + } + for i, table := range tables { + rows, err = e.setDataFromOneTable(sctx, loc, checker, schemas[i], table, rows, useStatsCache) + if err != nil { + return errors.Trace(err) + } + } + e.rows = rows + return nil +} + +// Data for inforation_schema.CHECK_CONSTRAINTS +// This is standards (ISO/IEC 9075-11) compliant and is compatible with the implementation in MySQL as well. 
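+// Each emitted row is (CONSTRAINT_CATALOG, CONSTRAINT_SCHEMA, CONSTRAINT_NAME,
+// CHECK_CLAUSE), with the stored check expression wrapped in parentheses.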
+func (e *memtableRetriever) setDataFromCheckConstraints(ctx context.Context, sctx sessionctx.Context, schemas []model.CIStr) error { + var rows [][]types.Datum + checker := privilege.GetPrivilegeManager(sctx) + extractor, ok := e.extractor.(*plannercore.InfoSchemaBaseExtractor) + if ok && extractor.SkipRequest { + return nil + } + for _, schema := range schemas { + if ok && extractor.Filter("constraint_schema", schema.L) { + continue + } + tables, err := e.is.SchemaTableInfos(ctx, schema) + if err != nil { + return errors.Trace(err) + } + for _, table := range tables { + if len(table.Constraints) > 0 { + if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, table.Name.L, "", mysql.SelectPriv) { + continue + } + for _, constraint := range table.Constraints { + if constraint.State != model.StatePublic { + continue + } + if ok && extractor.Filter("constraint_name", constraint.Name.L) { + continue + } + record := types.MakeDatums( + infoschema.CatalogVal, // CONSTRAINT_CATALOG + schema.O, // CONSTRAINT_SCHEMA + constraint.Name.O, // CONSTRAINT_NAME + fmt.Sprintf("(%s)", constraint.ExprString), // CHECK_CLAUSE + ) + rows = append(rows, record) + } + } + } + } + e.rows = rows + return nil +} + +// Data for inforation_schema.TIDB_CHECK_CONSTRAINTS +// This has non-standard TiDB specific extensions. +func (e *memtableRetriever) setDataFromTiDBCheckConstraints(ctx context.Context, sctx sessionctx.Context, schemas []model.CIStr) error { + var rows [][]types.Datum + checker := privilege.GetPrivilegeManager(sctx) + extractor, ok := e.extractor.(*plannercore.InfoSchemaBaseExtractor) + if ok && extractor.SkipRequest { + return nil + } + for _, schema := range schemas { + if ok && extractor.Filter("constraint_schema", schema.L) { + continue + } + tables, err := e.is.SchemaTableInfos(ctx, schema) + if err != nil { + return errors.Trace(err) + } + for _, table := range tables { + if len(table.Constraints) > 0 { + if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, table.Name.L, "", mysql.SelectPriv) { + continue + } + for _, constraint := range table.Constraints { + if constraint.State != model.StatePublic { + continue + } + if ok && extractor.Filter("constraint_name", constraint.Name.L) { + continue + } + record := types.MakeDatums( + infoschema.CatalogVal, // CONSTRAINT_CATALOG + schema.O, // CONSTRAINT_SCHEMA + constraint.Name.O, // CONSTRAINT_NAME + fmt.Sprintf("(%s)", constraint.ExprString), // CHECK_CLAUSE + table.Name.O, // TABLE_NAME + table.ID, // TABLE_ID + ) + rows = append(rows, record) + } + } + } + } + e.rows = rows + return nil +} + +type hugeMemTableRetriever struct { + dummyCloser + extractor *plannercore.ColumnsTableExtractor + table *model.TableInfo + columns []*model.ColumnInfo + retrieved bool + initialized bool + rows [][]types.Datum + dbs []*model.DBInfo + curTables []*model.TableInfo + dbsIdx int + tblIdx int + viewMu syncutil.RWMutex + viewSchemaMap map[int64]*expression.Schema // table id to view schema + viewOutputNamesMap map[int64]types.NameSlice // table id to view output names + batch int + is infoschema.InfoSchema +} + +// retrieve implements the infoschemaRetriever interface +func (e *hugeMemTableRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { + if e.retrieved { + return nil, nil + } + + if !e.initialized { + e.is = sessiontxn.GetTxnManager(sctx).GetTxnInfoSchema() + dbs := e.is.AllSchemas() + slices.SortFunc(dbs, model.LessDBInfo) + e.dbs = dbs 
+ e.initialized = true + e.rows = make([][]types.Datum, 0, 1024) + e.batch = 1024 + } + + var err error + if e.table.Name.O == infoschema.TableColumns { + err = e.setDataForColumns(ctx, sctx, e.extractor) + } + if err != nil { + return nil, err + } + e.retrieved = len(e.rows) == 0 + + return adjustColumns(e.rows, e.columns, e.table), nil +} + +func (e *hugeMemTableRetriever) setDataForColumns(ctx context.Context, sctx sessionctx.Context, extractor *plannercore.ColumnsTableExtractor) error { + checker := privilege.GetPrivilegeManager(sctx) + e.rows = e.rows[:0] + for ; e.dbsIdx < len(e.dbs); e.dbsIdx++ { + schema := e.dbs[e.dbsIdx] + var table *model.TableInfo + if len(e.curTables) == 0 { + tables, err := e.is.SchemaTableInfos(ctx, schema.Name) + if err != nil { + return errors.Trace(err) + } + e.curTables = tables + } + for e.tblIdx < len(e.curTables) { + table = e.curTables[e.tblIdx] + e.tblIdx++ + if e.setDataForColumnsWithOneTable(ctx, sctx, extractor, schema, table, checker) { + return nil + } + } + e.tblIdx = 0 + e.curTables = e.curTables[:0] + } + return nil +} + +func (e *hugeMemTableRetriever) setDataForColumnsWithOneTable( + ctx context.Context, + sctx sessionctx.Context, + extractor *plannercore.ColumnsTableExtractor, + schema *model.DBInfo, + table *model.TableInfo, + checker privilege.Manager) bool { + hasPrivs := false + var priv mysql.PrivilegeType + if checker != nil { + for _, p := range mysql.AllColumnPrivs { + if checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.Name.L, table.Name.L, "", p) { + hasPrivs = true + priv |= p + } + } + if !hasPrivs { + return false + } + } + + e.dataForColumnsInTable(ctx, sctx, schema, table, priv, extractor) + return len(e.rows) >= e.batch +} + +func (e *hugeMemTableRetriever) dataForColumnsInTable(ctx context.Context, sctx sessionctx.Context, schema *model.DBInfo, tbl *model.TableInfo, priv mysql.PrivilegeType, extractor *plannercore.ColumnsTableExtractor) { + if tbl.IsView() { + e.viewMu.Lock() + _, ok := e.viewSchemaMap[tbl.ID] + if !ok { + var viewLogicalPlan base.Plan + internalCtx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnOthers) + // Build plan is not thread safe, there will be concurrency on sessionctx. 
+ if err := runWithSystemSession(internalCtx, sctx, func(s sessionctx.Context) error { + is := sessiontxn.GetTxnManager(s).GetTxnInfoSchema() + planBuilder, _ := plannercore.NewPlanBuilder().Init(s.GetPlanCtx(), is, hint.NewQBHintHandler(nil)) + var err error + viewLogicalPlan, err = planBuilder.BuildDataSourceFromView(ctx, schema.Name, tbl, nil, nil) + return errors.Trace(err) + }); err != nil { + sctx.GetSessionVars().StmtCtx.AppendWarning(err) + e.viewMu.Unlock() + return + } + e.viewSchemaMap[tbl.ID] = viewLogicalPlan.Schema() + e.viewOutputNamesMap[tbl.ID] = viewLogicalPlan.OutputNames() + } + e.viewMu.Unlock() + } + + var tableSchemaRegexp, tableNameRegexp, columnsRegexp []collate.WildcardPattern + var tableSchemaFilterEnable, + tableNameFilterEnable, columnsFilterEnable bool + if !extractor.SkipRequest { + tableSchemaFilterEnable = extractor.TableSchema.Count() > 0 + tableNameFilterEnable = extractor.TableName.Count() > 0 + columnsFilterEnable = extractor.ColumnName.Count() > 0 + if len(extractor.TableSchemaPatterns) > 0 { + tableSchemaRegexp = make([]collate.WildcardPattern, len(extractor.TableSchemaPatterns)) + for i, pattern := range extractor.TableSchemaPatterns { + tableSchemaRegexp[i] = collate.GetCollatorByID(collate.CollationName2ID(mysql.UTF8MB4DefaultCollation)).Pattern() + tableSchemaRegexp[i].Compile(pattern, byte('\\')) + } + } + if len(extractor.TableNamePatterns) > 0 { + tableNameRegexp = make([]collate.WildcardPattern, len(extractor.TableNamePatterns)) + for i, pattern := range extractor.TableNamePatterns { + tableNameRegexp[i] = collate.GetCollatorByID(collate.CollationName2ID(mysql.UTF8MB4DefaultCollation)).Pattern() + tableNameRegexp[i].Compile(pattern, byte('\\')) + } + } + if len(extractor.ColumnNamePatterns) > 0 { + columnsRegexp = make([]collate.WildcardPattern, len(extractor.ColumnNamePatterns)) + for i, pattern := range extractor.ColumnNamePatterns { + columnsRegexp[i] = collate.GetCollatorByID(collate.CollationName2ID(mysql.UTF8MB4DefaultCollation)).Pattern() + columnsRegexp[i].Compile(pattern, byte('\\')) + } + } + } + i := 0 +ForColumnsTag: + for _, col := range tbl.Columns { + if col.Hidden { + continue + } + i++ + ft := &(col.FieldType) + if tbl.IsView() { + e.viewMu.RLock() + if e.viewSchemaMap[tbl.ID] != nil { + // If this is a view, replace the column with the view column. 
+ idx := expression.FindFieldNameIdxByColName(e.viewOutputNamesMap[tbl.ID], col.Name.L) + if idx >= 0 { + col1 := e.viewSchemaMap[tbl.ID].Columns[idx] + ft = col1.GetType(sctx.GetExprCtx().GetEvalCtx()) + } + } + e.viewMu.RUnlock() + } + if !extractor.SkipRequest { + if tableSchemaFilterEnable && !extractor.TableSchema.Exist(schema.Name.L) { + continue + } + if tableNameFilterEnable && !extractor.TableName.Exist(tbl.Name.L) { + continue + } + if columnsFilterEnable && !extractor.ColumnName.Exist(col.Name.L) { + continue + } + for _, re := range tableSchemaRegexp { + if !re.DoMatch(schema.Name.L) { + continue ForColumnsTag + } + } + for _, re := range tableNameRegexp { + if !re.DoMatch(tbl.Name.L) { + continue ForColumnsTag + } + } + for _, re := range columnsRegexp { + if !re.DoMatch(col.Name.L) { + continue ForColumnsTag + } + } + } + + var charMaxLen, charOctLen, numericPrecision, numericScale, datetimePrecision any + colLen, decimal := ft.GetFlen(), ft.GetDecimal() + defaultFlen, defaultDecimal := mysql.GetDefaultFieldLengthAndDecimal(ft.GetType()) + if decimal == types.UnspecifiedLength { + decimal = defaultDecimal + } + if colLen == types.UnspecifiedLength { + colLen = defaultFlen + } + if ft.GetType() == mysql.TypeSet { + // Example: In MySQL set('a','bc','def','ghij') has length 13, because + // len('a')+len('bc')+len('def')+len('ghij')+len(ThreeComma)=13 + // Reference link: https://bugs.mysql.com/bug.php?id=22613 + colLen = 0 + for _, ele := range ft.GetElems() { + colLen += len(ele) + } + if len(ft.GetElems()) != 0 { + colLen += (len(ft.GetElems()) - 1) + } + charMaxLen = colLen + charOctLen = calcCharOctLength(colLen, ft.GetCharset()) + } else if ft.GetType() == mysql.TypeEnum { + // Example: In MySQL enum('a', 'ab', 'cdef') has length 4, because + // the longest string in the enum is 'cdef' + // Reference link: https://bugs.mysql.com/bug.php?id=22613 + colLen = 0 + for _, ele := range ft.GetElems() { + if len(ele) > colLen { + colLen = len(ele) + } + } + charMaxLen = colLen + charOctLen = calcCharOctLength(colLen, ft.GetCharset()) + } else if types.IsString(ft.GetType()) { + charMaxLen = colLen + charOctLen = calcCharOctLength(colLen, ft.GetCharset()) + } else if types.IsTypeFractionable(ft.GetType()) { + datetimePrecision = decimal + } else if types.IsTypeNumeric(ft.GetType()) { + numericPrecision = colLen + if ft.GetType() != mysql.TypeFloat && ft.GetType() != mysql.TypeDouble { + numericScale = decimal + } else if decimal != -1 { + numericScale = decimal + } + } else if ft.GetType() == mysql.TypeNull { + charMaxLen, charOctLen = 0, 0 + } + columnType := ft.InfoSchemaStr() + columnDesc := table.NewColDesc(table.ToColumn(col)) + var columnDefault any + if columnDesc.DefaultValue != nil { + columnDefault = fmt.Sprintf("%v", columnDesc.DefaultValue) + switch col.GetDefaultValue() { + case "CURRENT_TIMESTAMP": + default: + if ft.GetType() == mysql.TypeTimestamp && columnDefault != types.ZeroDatetimeStr { + timeValue, err := table.GetColDefaultValue(sctx.GetExprCtx(), col) + if err == nil { + columnDefault = timeValue.GetMysqlTime().String() + } + } + if ft.GetType() == mysql.TypeBit && !col.DefaultIsExpr { + defaultValBinaryLiteral := types.BinaryLiteral(columnDefault.(string)) + columnDefault = defaultValBinaryLiteral.ToBitLiteralString(true) + } + } + } + colType := ft.GetType() + if colType == mysql.TypeVarString { + colType = mysql.TypeVarchar + } + record := types.MakeDatums( + infoschema.CatalogVal, // TABLE_CATALOG + schema.Name.O, // TABLE_SCHEMA + tbl.Name.O, // 
TABLE_NAME + col.Name.O, // COLUMN_NAME + i, // ORDINAL_POSITION + columnDefault, // COLUMN_DEFAULT + columnDesc.Null, // IS_NULLABLE + types.TypeToStr(colType, ft.GetCharset()), // DATA_TYPE + charMaxLen, // CHARACTER_MAXIMUM_LENGTH + charOctLen, // CHARACTER_OCTET_LENGTH + numericPrecision, // NUMERIC_PRECISION + numericScale, // NUMERIC_SCALE + datetimePrecision, // DATETIME_PRECISION + columnDesc.Charset, // CHARACTER_SET_NAME + columnDesc.Collation, // COLLATION_NAME + columnType, // COLUMN_TYPE + columnDesc.Key, // COLUMN_KEY + columnDesc.Extra, // EXTRA + strings.ToLower(privileges.PrivToString(priv, mysql.AllColumnPrivs, mysql.Priv2Str)), // PRIVILEGES + columnDesc.Comment, // COLUMN_COMMENT + col.GeneratedExprString, // GENERATION_EXPRESSION + ) + e.rows = append(e.rows, record) + } +} + +func calcCharOctLength(lenInChar int, cs string) int { + lenInBytes := lenInChar + if desc, err := charset.GetCharsetInfo(cs); err == nil { + lenInBytes = desc.Maxlen * lenInChar + } + return lenInBytes +} + +func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sessionctx.Context) error { + useStatsCache := e.updateStatsCacheIfNeed() + checker := privilege.GetPrivilegeManager(sctx) + var rows [][]types.Datum + createTimeTp := mysql.TypeDatetime + + ex, ok := e.extractor.(*plannercore.InfoSchemaPartitionsExtractor) + if !ok { + return errors.Errorf("wrong extractor type: %T, expected InfoSchemaPartitionsExtractor", e.extractor) + } + if ex.SkipRequest { + return nil + } + schemas, tables, err := ex.ListSchemasAndTables(ctx, e.is) + if err != nil { + return errors.Trace(err) + } + for i, table := range tables { + schema := schemas[i] + if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, table.Name.L, "", mysql.SelectPriv) { + continue + } + createTime := types.NewTime(types.FromGoTime(table.GetUpdateTime()), createTimeTp, types.DefaultFsp) + + var rowCount, dataLength, indexLength uint64 + if useStatsCache { + if table.GetPartitionInfo() == nil { + err := cache.TableRowStatsCache.UpdateByID(sctx, table.ID) + if err != nil { + return err + } + } else { + // needs to update needed partitions for partition table. 
+ for _, pi := range table.GetPartitionInfo().Definitions { + if ex.Filter("partition_name", pi.Name.L) { + continue + } + err := cache.TableRowStatsCache.UpdateByID(sctx, pi.ID) + if err != nil { + return err + } + } + } + } + if table.GetPartitionInfo() == nil { + rowCount = cache.TableRowStatsCache.GetTableRows(table.ID) + dataLength, indexLength = cache.TableRowStatsCache.GetDataAndIndexLength(table, table.ID, rowCount) + avgRowLength := uint64(0) + if rowCount != 0 { + avgRowLength = dataLength / rowCount + } + record := types.MakeDatums( + infoschema.CatalogVal, // TABLE_CATALOG + schema.O, // TABLE_SCHEMA + table.Name.O, // TABLE_NAME + nil, // PARTITION_NAME + nil, // SUBPARTITION_NAME + nil, // PARTITION_ORDINAL_POSITION + nil, // SUBPARTITION_ORDINAL_POSITION + nil, // PARTITION_METHOD + nil, // SUBPARTITION_METHOD + nil, // PARTITION_EXPRESSION + nil, // SUBPARTITION_EXPRESSION + nil, // PARTITION_DESCRIPTION + rowCount, // TABLE_ROWS + avgRowLength, // AVG_ROW_LENGTH + dataLength, // DATA_LENGTH + nil, // MAX_DATA_LENGTH + indexLength, // INDEX_LENGTH + nil, // DATA_FREE + createTime, // CREATE_TIME + nil, // UPDATE_TIME + nil, // CHECK_TIME + nil, // CHECKSUM + nil, // PARTITION_COMMENT + nil, // NODEGROUP + nil, // TABLESPACE_NAME + nil, // TIDB_PARTITION_ID + nil, // TIDB_PLACEMENT_POLICY_NAME + ) + rows = append(rows, record) + } else { + for i, pi := range table.GetPartitionInfo().Definitions { + if ex.Filter("partition_name", pi.Name.L) { + continue + } + rowCount = cache.TableRowStatsCache.GetTableRows(pi.ID) + dataLength, indexLength = cache.TableRowStatsCache.GetDataAndIndexLength(table, pi.ID, rowCount) + avgRowLength := uint64(0) + if rowCount != 0 { + avgRowLength = dataLength / rowCount + } + + var partitionDesc string + if table.Partition.Type == model.PartitionTypeRange { + partitionDesc = strings.Join(pi.LessThan, ",") + } else if table.Partition.Type == model.PartitionTypeList { + if len(pi.InValues) > 0 { + buf := bytes.NewBuffer(nil) + for i, vs := range pi.InValues { + if i > 0 { + buf.WriteString(",") + } + if len(vs) != 1 { + buf.WriteString("(") + } + buf.WriteString(strings.Join(vs, ",")) + if len(vs) != 1 { + buf.WriteString(")") + } + } + partitionDesc = buf.String() + } + } + + partitionMethod := table.Partition.Type.String() + partitionExpr := table.Partition.Expr + if len(table.Partition.Columns) > 0 { + switch table.Partition.Type { + case model.PartitionTypeRange: + partitionMethod = "RANGE COLUMNS" + case model.PartitionTypeList: + partitionMethod = "LIST COLUMNS" + case model.PartitionTypeKey: + partitionMethod = "KEY" + default: + return errors.Errorf("Inconsistent partition type, have type %v, but with COLUMNS > 0 (%d)", table.Partition.Type, len(table.Partition.Columns)) + } + buf := bytes.NewBuffer(nil) + for i, col := range table.Partition.Columns { + if i > 0 { + buf.WriteString(",") + } + buf.WriteString("`") + buf.WriteString(col.String()) + buf.WriteString("`") + } + partitionExpr = buf.String() + } + + var policyName any + if pi.PlacementPolicyRef != nil { + policyName = pi.PlacementPolicyRef.Name.O + } + record := types.MakeDatums( + infoschema.CatalogVal, // TABLE_CATALOG + schema.O, // TABLE_SCHEMA + table.Name.O, // TABLE_NAME + pi.Name.O, // PARTITION_NAME + nil, // SUBPARTITION_NAME + i+1, // PARTITION_ORDINAL_POSITION + nil, // SUBPARTITION_ORDINAL_POSITION + partitionMethod, // PARTITION_METHOD + nil, // SUBPARTITION_METHOD + partitionExpr, // PARTITION_EXPRESSION + nil, // SUBPARTITION_EXPRESSION + partitionDesc, // 
PARTITION_DESCRIPTION + rowCount, // TABLE_ROWS + avgRowLength, // AVG_ROW_LENGTH + dataLength, // DATA_LENGTH + uint64(0), // MAX_DATA_LENGTH + indexLength, // INDEX_LENGTH + uint64(0), // DATA_FREE + createTime, // CREATE_TIME + nil, // UPDATE_TIME + nil, // CHECK_TIME + nil, // CHECKSUM + pi.Comment, // PARTITION_COMMENT + nil, // NODEGROUP + nil, // TABLESPACE_NAME + pi.ID, // TIDB_PARTITION_ID + policyName, // TIDB_PLACEMENT_POLICY_NAME + ) + rows = append(rows, record) + } + } + } + e.rows = rows + return nil +} + +func (e *memtableRetriever) setDataFromIndexes(ctx context.Context, sctx sessionctx.Context, schemas []model.CIStr) error { + checker := privilege.GetPrivilegeManager(sctx) + extractor, ok := e.extractor.(*plannercore.InfoSchemaBaseExtractor) + if ok && extractor.SkipRequest { + return nil + } + var rows [][]types.Datum + for _, schema := range schemas { + if ok && extractor.Filter("table_schema", schema.L) { + continue + } + tables, err := e.is.SchemaTableInfos(ctx, schema) + if err != nil { + return errors.Trace(err) + } + for _, tb := range tables { + if ok && extractor.Filter("table_name", tb.Name.L) { + continue + } + if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, tb.Name.L, "", mysql.AllPrivMask) { + continue + } + + if tb.PKIsHandle { + var pkCol *model.ColumnInfo + for _, col := range tb.Cols() { + if mysql.HasPriKeyFlag(col.GetFlag()) { + pkCol = col + break + } + } + record := types.MakeDatums( + schema.O, // TABLE_SCHEMA + tb.Name.O, // TABLE_NAME + 0, // NON_UNIQUE + "PRIMARY", // KEY_NAME + 1, // SEQ_IN_INDEX + pkCol.Name.O, // COLUMN_NAME + nil, // SUB_PART + "", // INDEX_COMMENT + nil, // Expression + 0, // INDEX_ID + "YES", // IS_VISIBLE + "YES", // CLUSTERED + 0, // IS_GLOBAL + ) + rows = append(rows, record) + } + for _, idxInfo := range tb.Indices { + if idxInfo.State != model.StatePublic { + continue + } + isClustered := "NO" + if tb.IsCommonHandle && idxInfo.Primary { + isClustered = "YES" + } + for i, col := range idxInfo.Columns { + nonUniq := 1 + if idxInfo.Unique { + nonUniq = 0 + } + var subPart any + if col.Length != types.UnspecifiedLength { + subPart = col.Length + } + colName := col.Name.O + var expression any + expression = nil + tblCol := tb.Columns[col.Offset] + if tblCol.Hidden { + colName = "NULL" + expression = tblCol.GeneratedExprString + } + visible := "YES" + if idxInfo.Invisible { + visible = "NO" + } + record := types.MakeDatums( + schema.O, // TABLE_SCHEMA + tb.Name.O, // TABLE_NAME + nonUniq, // NON_UNIQUE + idxInfo.Name.O, // KEY_NAME + i+1, // SEQ_IN_INDEX + colName, // COLUMN_NAME + subPart, // SUB_PART + idxInfo.Comment, // INDEX_COMMENT + expression, // Expression + idxInfo.ID, // INDEX_ID + visible, // IS_VISIBLE + isClustered, // CLUSTERED + idxInfo.Global, // IS_GLOBAL + ) + rows = append(rows, record) + } + } + } + } + e.rows = rows + return nil +} + +func (e *memtableRetriever) setDataFromViews(ctx context.Context, sctx sessionctx.Context, schemas []model.CIStr) error { + checker := privilege.GetPrivilegeManager(sctx) + extractor, ok := e.extractor.(*plannercore.InfoSchemaBaseExtractor) + if ok && extractor.SkipRequest { + return nil + } + var rows [][]types.Datum + for _, schema := range schemas { + if ok && extractor.Filter("table_schema", schema.L) { + continue + } + tables, err := e.is.SchemaTableInfos(ctx, schema) + if err != nil { + return errors.Trace(err) + } + for _, table := range tables { + if ok && extractor.Filter("table_name", table.Name.L) { + continue + } + 
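// Rows are produced only for view definitions; non-view objects are skipped below, and
// empty charset/collation metadata falls back to the server defaults.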
if !table.IsView() { + continue + } + collation := table.Collate + charset := table.Charset + if collation == "" { + collation = mysql.DefaultCollationName + } + if charset == "" { + charset = mysql.DefaultCharset + } + if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, table.Name.L, "", mysql.AllPrivMask) { + continue + } + record := types.MakeDatums( + infoschema.CatalogVal, // TABLE_CATALOG + schema.O, // TABLE_SCHEMA + table.Name.O, // TABLE_NAME + table.View.SelectStmt, // VIEW_DEFINITION + table.View.CheckOption.String(), // CHECK_OPTION + "NO", // IS_UPDATABLE + table.View.Definer.String(), // DEFINER + table.View.Security.String(), // SECURITY_TYPE + charset, // CHARACTER_SET_CLIENT + collation, // COLLATION_CONNECTION + ) + rows = append(rows, record) + } + } + e.rows = rows + return nil +} + +func (e *memtableRetriever) dataForTiKVStoreStatus(ctx context.Context, sctx sessionctx.Context) (err error) { + tikvStore, ok := sctx.GetStore().(helper.Storage) + if !ok { + return errors.New("Information about TiKV store status can be gotten only when the storage is TiKV") + } + tikvHelper := &helper.Helper{ + Store: tikvStore, + RegionCache: tikvStore.GetRegionCache(), + } + pdCli, err := tikvHelper.TryGetPDHTTPClient() + if err != nil { + return err + } + storesStat, err := pdCli.GetStores(ctx) + if err != nil { + return err + } + for _, storeStat := range storesStat.Stores { + row := make([]types.Datum, len(infoschema.TableTiKVStoreStatusCols)) + row[0].SetInt64(storeStat.Store.ID) + row[1].SetString(storeStat.Store.Address, mysql.DefaultCollationName) + row[2].SetInt64(storeStat.Store.State) + row[3].SetString(storeStat.Store.StateName, mysql.DefaultCollationName) + data, err := json.Marshal(storeStat.Store.Labels) + if err != nil { + return err + } + bj := types.BinaryJSON{} + if err = bj.UnmarshalJSON(data); err != nil { + return err + } + row[4].SetMysqlJSON(bj) + row[5].SetString(storeStat.Store.Version, mysql.DefaultCollationName) + row[6].SetString(storeStat.Status.Capacity, mysql.DefaultCollationName) + row[7].SetString(storeStat.Status.Available, mysql.DefaultCollationName) + row[8].SetInt64(storeStat.Status.LeaderCount) + row[9].SetFloat64(storeStat.Status.LeaderWeight) + row[10].SetFloat64(storeStat.Status.LeaderScore) + row[11].SetInt64(storeStat.Status.LeaderSize) + row[12].SetInt64(storeStat.Status.RegionCount) + row[13].SetFloat64(storeStat.Status.RegionWeight) + row[14].SetFloat64(storeStat.Status.RegionScore) + row[15].SetInt64(storeStat.Status.RegionSize) + startTs := types.NewTime(types.FromGoTime(storeStat.Status.StartTS), mysql.TypeDatetime, types.DefaultFsp) + row[16].SetMysqlTime(startTs) + lastHeartbeatTs := types.NewTime(types.FromGoTime(storeStat.Status.LastHeartbeatTS), mysql.TypeDatetime, types.DefaultFsp) + row[17].SetMysqlTime(lastHeartbeatTs) + row[18].SetString(storeStat.Status.Uptime, mysql.DefaultCollationName) + if sem.IsEnabled() { + // Patch out IP addresses etc if the user does not have the RESTRICTED_TABLES_ADMIN privilege + checker := privilege.GetPrivilegeManager(sctx) + if checker == nil || !checker.RequestDynamicVerification(sctx.GetSessionVars().ActiveRoles, "RESTRICTED_TABLES_ADMIN", false) { + row[1].SetString(strconv.FormatInt(storeStat.Store.ID, 10), mysql.DefaultCollationName) + row[1].SetNull() + row[6].SetNull() + row[7].SetNull() + row[16].SetNull() + row[18].SetNull() + } + } + e.rows = append(e.rows, row) + } + return nil +} + +// DDLJobsReaderExec executes DDLJobs information retrieving. 
+type DDLJobsReaderExec struct { + exec.BaseExecutor + DDLJobRetriever + + cacheJobs []*model.Job + is infoschema.InfoSchema + sess sessionctx.Context +} + +// Open implements the Executor Next interface. +func (e *DDLJobsReaderExec) Open(ctx context.Context) error { + if err := e.BaseExecutor.Open(ctx); err != nil { + return err + } + e.DDLJobRetriever.is = e.is + e.activeRoles = e.Ctx().GetSessionVars().ActiveRoles + sess, err := e.GetSysSession() + if err != nil { + return err + } + e.sess = sess + err = sessiontxn.NewTxn(context.Background(), sess) + if err != nil { + return err + } + txn, err := sess.Txn(true) + if err != nil { + return err + } + sess.GetSessionVars().SetInTxn(true) + err = e.DDLJobRetriever.initial(txn, sess) + if err != nil { + return err + } + return nil +} + +// Next implements the Executor Next interface. +func (e *DDLJobsReaderExec) Next(_ context.Context, req *chunk.Chunk) error { + req.GrowAndReset(e.MaxChunkSize()) + checker := privilege.GetPrivilegeManager(e.Ctx()) + count := 0 + + // Append running DDL jobs. + if e.cursor < len(e.runningJobs) { + num := min(req.Capacity(), len(e.runningJobs)-e.cursor) + for i := e.cursor; i < e.cursor+num; i++ { + e.appendJobToChunk(req, e.runningJobs[i], checker) + req.AppendString(12, e.runningJobs[i].Query) + if e.runningJobs[i].MultiSchemaInfo != nil { + for range e.runningJobs[i].MultiSchemaInfo.SubJobs { + req.AppendString(12, e.runningJobs[i].Query) + } + } + } + e.cursor += num + count += num + } + var err error + + // Append history DDL jobs. + if count < req.Capacity() { + e.cacheJobs, err = e.historyJobIter.GetLastJobs(req.Capacity()-count, e.cacheJobs) + if err != nil { + return err + } + for _, job := range e.cacheJobs { + e.appendJobToChunk(req, job, checker) + req.AppendString(12, job.Query) + if job.MultiSchemaInfo != nil { + for range job.MultiSchemaInfo.SubJobs { + req.AppendString(12, job.Query) + } + } + } + e.cursor += len(e.cacheJobs) + } + return nil +} + +// Close implements the Executor Close interface. 
+func (e *DDLJobsReaderExec) Close() error { + e.ReleaseSysSession(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), e.sess) + return e.BaseExecutor.Close() +} + +func (e *memtableRetriever) setDataFromEngines() { + var rows [][]types.Datum + rows = append(rows, + types.MakeDatums( + "InnoDB", // Engine + "DEFAULT", // Support + "Supports transactions, row-level locking, and foreign keys", // Comment + "YES", // Transactions + "YES", // XA + "YES", // Savepoints + ), + ) + e.rows = rows +} + +func (e *memtableRetriever) setDataFromCharacterSets() { + charsets := charset.GetSupportedCharsets() + var rows = make([][]types.Datum, 0, len(charsets)) + for _, charset := range charsets { + rows = append(rows, + types.MakeDatums(charset.Name, charset.DefaultCollation, charset.Desc, charset.Maxlen), + ) + } + e.rows = rows +} + +func (e *memtableRetriever) setDataFromCollations() { + collations := collate.GetSupportedCollations() + var rows = make([][]types.Datum, 0, len(collations)) + for _, collation := range collations { + isDefault := "" + if collation.IsDefault { + isDefault = "Yes" + } + rows = append(rows, + types.MakeDatums(collation.Name, collation.CharsetName, collation.ID, + isDefault, "Yes", collation.Sortlen, collation.PadAttribute), + ) + } + e.rows = rows +} + +func (e *memtableRetriever) dataForCollationCharacterSetApplicability() { + collations := collate.GetSupportedCollations() + var rows = make([][]types.Datum, 0, len(collations)) + for _, collation := range collations { + rows = append(rows, + types.MakeDatums(collation.Name, collation.CharsetName), + ) + } + e.rows = rows +} + +func (e *memtableRetriever) dataForTiDBClusterInfo(ctx sessionctx.Context) error { + servers, err := infoschema.GetClusterServerInfo(ctx) + if err != nil { + e.rows = nil + return err + } + rows := make([][]types.Datum, 0, len(servers)) + for _, server := range servers { + upTimeStr := "" + startTimeNative := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeDatetime, 0) + if server.StartTimestamp > 0 { + startTime := time.Unix(server.StartTimestamp, 0) + startTimeNative = types.NewTime(types.FromGoTime(startTime), mysql.TypeDatetime, 0) + upTimeStr = time.Since(startTime).String() + } + serverType := server.ServerType + if server.ServerType == kv.TiFlash.Name() && server.EngineRole == placement.EngineRoleLabelWrite { + serverType = infoschema.TiFlashWrite + } + row := types.MakeDatums( + serverType, + server.Address, + server.StatusAddr, + server.Version, + server.GitHash, + startTimeNative, + upTimeStr, + server.ServerID, + ) + if sem.IsEnabled() { + checker := privilege.GetPrivilegeManager(ctx) + if checker == nil || !checker.RequestDynamicVerification(ctx.GetSessionVars().ActiveRoles, "RESTRICTED_TABLES_ADMIN", false) { + row[1].SetString(strconv.FormatUint(server.ServerID, 10), mysql.DefaultCollationName) + row[2].SetNull() + row[5].SetNull() + row[6].SetNull() + } + } + rows = append(rows, row) + } + e.rows = rows + return nil +} + +func (e *memtableRetriever) setDataFromKeyColumnUsage(ctx context.Context, sctx sessionctx.Context, schemas []model.CIStr) error { + checker := privilege.GetPrivilegeManager(sctx) + rows := make([][]types.Datum, 0, len(schemas)) // The capacity is not accurate, but it is not a big problem. + extractor, ok := e.extractor.(*plannercore.InfoSchemaBaseExtractor) + if ok && extractor.SkipRequest { + return nil + } + for _, schema := range schemas { + // `constraint_schema` and `table_schema` are always the same in MySQL. 
+ if ok && extractor.Filter("constraint_schema", schema.L) { + continue + } + if ok && extractor.Filter("table_schema", schema.L) { + continue + } + tables, err := e.is.SchemaTableInfos(ctx, schema) + if err != nil { + return errors.Trace(err) + } + for _, table := range tables { + if ok && extractor.Filter("table_name", table.Name.L) { + continue + } + if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, table.Name.L, "", mysql.AllPrivMask) { + continue + } + rs := keyColumnUsageInTable(schema, table, extractor) + rows = append(rows, rs...) + } + } + e.rows = rows + return nil +} + +func (e *memtableRetriever) setDataForClusterProcessList(ctx sessionctx.Context) error { + e.setDataForProcessList(ctx) + rows, err := infoschema.AppendHostInfoToRows(ctx, e.rows) + if err != nil { + return err + } + e.rows = rows + return nil +} + +func (e *memtableRetriever) setDataForProcessList(ctx sessionctx.Context) { + sm := ctx.GetSessionManager() + if sm == nil { + return + } + + loginUser := ctx.GetSessionVars().User + hasProcessPriv := hasPriv(ctx, mysql.ProcessPriv) + pl := sm.ShowProcessList() + + records := make([][]types.Datum, 0, len(pl)) + for _, pi := range pl { + // If you have the PROCESS privilege, you can see all threads. + // Otherwise, you can see only your own threads. + if !hasProcessPriv && loginUser != nil && pi.User != loginUser.Username { + continue + } + + rows := pi.ToRow(ctx.GetSessionVars().StmtCtx.TimeZone()) + record := types.MakeDatums(rows...) + records = append(records, record) + } + e.rows = records +} + +func (e *memtableRetriever) setDataFromUserPrivileges(ctx sessionctx.Context) { + pm := privilege.GetPrivilegeManager(ctx) + // The results depend on the user querying the information. + e.rows = pm.UserPrivilegesTable(ctx.GetSessionVars().ActiveRoles, ctx.GetSessionVars().User.Username, ctx.GetSessionVars().User.Hostname) +} + +func (e *memtableRetriever) setDataForMetricTables() { + tables := make([]string, 0, len(infoschema.MetricTableMap)) + for name := range infoschema.MetricTableMap { + tables = append(tables, name) + } + slices.Sort(tables) + rows := make([][]types.Datum, 0, len(tables)) + for _, name := range tables { + schema := infoschema.MetricTableMap[name] + record := types.MakeDatums( + name, // METRICS_NAME + schema.PromQL, // PROMQL + strings.Join(schema.Labels, ","), // LABELS + schema.Quantile, // QUANTILE + schema.Comment, // COMMENT + ) + rows = append(rows, record) + } + e.rows = rows +} + +func keyColumnUsageInTable(schema model.CIStr, table *model.TableInfo, extractor *plannercore.InfoSchemaBaseExtractor) [][]types.Datum { + var rows [][]types.Datum + if table.PKIsHandle { + for _, col := range table.Columns { + if mysql.HasPriKeyFlag(col.GetFlag()) { + record := types.MakeDatums( + infoschema.CatalogVal, // CONSTRAINT_CATALOG + schema.O, // CONSTRAINT_SCHEMA + infoschema.PrimaryConstraint, // CONSTRAINT_NAME + infoschema.CatalogVal, // TABLE_CATALOG + schema.O, // TABLE_SCHEMA + table.Name.O, // TABLE_NAME + col.Name.O, // COLUMN_NAME + 1, // ORDINAL_POSITION + 1, // POSITION_IN_UNIQUE_CONSTRAINT + nil, // REFERENCED_TABLE_SCHEMA + nil, // REFERENCED_TABLE_NAME + nil, // REFERENCED_COLUMN_NAME + ) + rows = append(rows, record) + break + } + } + } + nameToCol := make(map[string]*model.ColumnInfo, len(table.Columns)) + for _, c := range table.Columns { + nameToCol[c.Name.L] = c + } + for _, index := range table.Indices { + var idxName string + if index.Primary { + idxName = infoschema.PrimaryConstraint + } 
else if index.Unique { + idxName = index.Name.O + } else { + // Only handle unique/primary key + continue + } + + if extractor != nil && extractor.Filter("constraint_name", idxName) { + continue + } + + for i, key := range index.Columns { + col := nameToCol[key.Name.L] + if col.Hidden { + continue + } + record := types.MakeDatums( + infoschema.CatalogVal, // CONSTRAINT_CATALOG + schema.O, // CONSTRAINT_SCHEMA + idxName, // CONSTRAINT_NAME + infoschema.CatalogVal, // TABLE_CATALOG + schema.O, // TABLE_SCHEMA + table.Name.O, // TABLE_NAME + col.Name.O, // COLUMN_NAME + i+1, // ORDINAL_POSITION, + nil, // POSITION_IN_UNIQUE_CONSTRAINT + nil, // REFERENCED_TABLE_SCHEMA + nil, // REFERENCED_TABLE_NAME + nil, // REFERENCED_COLUMN_NAME + ) + rows = append(rows, record) + } + } + for _, fk := range table.ForeignKeys { + for i, key := range fk.Cols { + fkRefCol := "" + if len(fk.RefCols) > i { + fkRefCol = fk.RefCols[i].O + } + col := nameToCol[key.L] + record := types.MakeDatums( + infoschema.CatalogVal, // CONSTRAINT_CATALOG + schema.O, // CONSTRAINT_SCHEMA + fk.Name.O, // CONSTRAINT_NAME + infoschema.CatalogVal, // TABLE_CATALOG + schema.O, // TABLE_SCHEMA + table.Name.O, // TABLE_NAME + col.Name.O, // COLUMN_NAME + i+1, // ORDINAL_POSITION, + 1, // POSITION_IN_UNIQUE_CONSTRAINT + fk.RefSchema.O, // REFERENCED_TABLE_SCHEMA + fk.RefTable.O, // REFERENCED_TABLE_NAME + fkRefCol, // REFERENCED_COLUMN_NAME + ) + rows = append(rows, record) + } + } + return rows +} + +func (e *memtableRetriever) setDataForTiKVRegionStatus(ctx context.Context, sctx sessionctx.Context) (err error) { + checker := privilege.GetPrivilegeManager(sctx) + var extractorTableIDs []int64 + tikvStore, ok := sctx.GetStore().(helper.Storage) + if !ok { + return errors.New("Information about TiKV region status can be gotten only when the storage is TiKV") + } + tikvHelper := &helper.Helper{ + Store: tikvStore, + RegionCache: tikvStore.GetRegionCache(), + } + requestByTableRange := false + var allRegionsInfo *pd.RegionsInfo + is := sctx.GetDomainInfoSchema().(infoschema.InfoSchema) + if e.extractor != nil { + extractor, ok := e.extractor.(*plannercore.TiKVRegionStatusExtractor) + if ok && len(extractor.GetTablesID()) > 0 { + extractorTableIDs = extractor.GetTablesID() + for _, tableID := range extractorTableIDs { + regionsInfo, err := e.getRegionsInfoForTable(ctx, tikvHelper, is, tableID) + if err != nil { + if errors.ErrorEqual(err, infoschema.ErrTableExists) { + continue + } + return err + } + allRegionsInfo = allRegionsInfo.Merge(regionsInfo) + } + requestByTableRange = true + } + } + if !requestByTableRange { + pdCli, err := tikvHelper.TryGetPDHTTPClient() + if err != nil { + return err + } + allRegionsInfo, err = pdCli.GetRegions(ctx) + if err != nil { + return err + } + } + tableInfos := tikvHelper.GetRegionsTableInfo(allRegionsInfo, is, nil) + for i := range allRegionsInfo.Regions { + regionTableList := tableInfos[allRegionsInfo.Regions[i].ID] + if len(regionTableList) == 0 { + e.setNewTiKVRegionStatusCol(&allRegionsInfo.Regions[i], nil) + } + for j, regionTable := range regionTableList { + if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, regionTable.DB.Name.L, regionTable.Table.Name.L, "", mysql.AllPrivMask) { + continue + } + if len(extractorTableIDs) == 0 { + e.setNewTiKVRegionStatusCol(&allRegionsInfo.Regions[i], ®ionTable) + } + if slices.Contains(extractorTableIDs, regionTableList[j].Table.ID) { + e.setNewTiKVRegionStatusCol(&allRegionsInfo.Regions[i], ®ionTable) + } + } + } + return 
nil +} + +func (e *memtableRetriever) getRegionsInfoForTable(ctx context.Context, h *helper.Helper, is infoschema.InfoSchema, tableID int64) (*pd.RegionsInfo, error) { + tbl, _ := is.TableByID(tableID) + if tbl == nil { + return nil, infoschema.ErrTableExists.GenWithStackByArgs(tableID) + } + + pt := tbl.Meta().GetPartitionInfo() + if pt == nil { + regionsInfo, err := e.getRegionsInfoForSingleTable(ctx, h, tableID) + if err != nil { + return nil, err + } + return regionsInfo, nil + } + + var allRegionsInfo *pd.RegionsInfo + for _, def := range pt.Definitions { + regionsInfo, err := e.getRegionsInfoForSingleTable(ctx, h, def.ID) + if err != nil { + return nil, err + } + allRegionsInfo = allRegionsInfo.Merge(regionsInfo) + } + return allRegionsInfo, nil +} + +func (*memtableRetriever) getRegionsInfoForSingleTable(ctx context.Context, helper *helper.Helper, tableID int64) (*pd.RegionsInfo, error) { + pdCli, err := helper.TryGetPDHTTPClient() + if err != nil { + return nil, err + } + sk, ek := tablecodec.GetTableHandleKeyRange(tableID) + sRegion, err := pdCli.GetRegionByKey(ctx, codec.EncodeBytes(nil, sk)) + if err != nil { + return nil, err + } + eRegion, err := pdCli.GetRegionByKey(ctx, codec.EncodeBytes(nil, ek)) + if err != nil { + return nil, err + } + sk, err = hex.DecodeString(sRegion.StartKey) + if err != nil { + return nil, err + } + ek, err = hex.DecodeString(eRegion.EndKey) + if err != nil { + return nil, err + } + return pdCli.GetRegionsByKeyRange(ctx, pd.NewKeyRange(sk, ek), -1) +} + +func (e *memtableRetriever) setNewTiKVRegionStatusCol(region *pd.RegionInfo, table *helper.TableInfo) { + row := make([]types.Datum, len(infoschema.TableTiKVRegionStatusCols)) + row[0].SetInt64(region.ID) + row[1].SetString(region.StartKey, mysql.DefaultCollationName) + row[2].SetString(region.EndKey, mysql.DefaultCollationName) + if table != nil { + row[3].SetInt64(table.Table.ID) + row[4].SetString(table.DB.Name.O, mysql.DefaultCollationName) + row[5].SetString(table.Table.Name.O, mysql.DefaultCollationName) + if table.IsIndex { + row[6].SetInt64(1) + row[7].SetInt64(table.Index.ID) + row[8].SetString(table.Index.Name.O, mysql.DefaultCollationName) + } else { + row[6].SetInt64(0) + } + if table.IsPartition { + row[9].SetInt64(1) + row[10].SetInt64(table.Partition.ID) + row[11].SetString(table.Partition.Name.O, mysql.DefaultCollationName) + } else { + row[9].SetInt64(0) + } + } else { + row[6].SetInt64(0) + row[9].SetInt64(0) + } + row[12].SetInt64(region.Epoch.ConfVer) + row[13].SetInt64(region.Epoch.Version) + row[14].SetUint64(region.WrittenBytes) + row[15].SetUint64(region.ReadBytes) + row[16].SetInt64(region.ApproximateSize) + row[17].SetInt64(region.ApproximateKeys) + if region.ReplicationStatus != nil { + row[18].SetString(region.ReplicationStatus.State, mysql.DefaultCollationName) + row[19].SetInt64(region.ReplicationStatus.StateID) + } + e.rows = append(e.rows, row) +} + +const ( + normalPeer = "NORMAL" + pendingPeer = "PENDING" + downPeer = "DOWN" +) + +func (e *memtableRetriever) setDataForTiDBHotRegions(ctx context.Context, sctx sessionctx.Context) error { + tikvStore, ok := sctx.GetStore().(helper.Storage) + if !ok { + return errors.New("Information about hot region can be gotten only when the storage is TiKV") + } + tikvHelper := &helper.Helper{ + Store: tikvStore, + RegionCache: tikvStore.GetRegionCache(), + } + is := sessiontxn.GetTxnManager(sctx).GetTxnInfoSchema() + metrics, err := tikvHelper.ScrapeHotInfo(ctx, helper.HotRead, is, tikvHelper.FilterMemDBs) + if err != nil { + return 
err + } + e.setDataForHotRegionByMetrics(metrics, "read") + metrics, err = tikvHelper.ScrapeHotInfo(ctx, helper.HotWrite, is, nil) + if err != nil { + return err + } + e.setDataForHotRegionByMetrics(metrics, "write") + return nil +} + +func (e *memtableRetriever) setDataForHotRegionByMetrics(metrics []helper.HotTableIndex, tp string) { + rows := make([][]types.Datum, 0, len(metrics)) + for _, tblIndex := range metrics { + row := make([]types.Datum, len(infoschema.TableTiDBHotRegionsCols)) + if tblIndex.IndexName != "" { + row[1].SetInt64(tblIndex.IndexID) + row[4].SetString(tblIndex.IndexName, mysql.DefaultCollationName) + } else { + row[1].SetNull() + row[4].SetNull() + } + row[0].SetInt64(tblIndex.TableID) + row[2].SetString(tblIndex.DbName, mysql.DefaultCollationName) + row[3].SetString(tblIndex.TableName, mysql.DefaultCollationName) + row[5].SetUint64(tblIndex.RegionID) + row[6].SetString(tp, mysql.DefaultCollationName) + if tblIndex.RegionMetric == nil { + row[7].SetNull() + row[8].SetNull() + } else { + row[7].SetInt64(int64(tblIndex.RegionMetric.MaxHotDegree)) + row[8].SetInt64(int64(tblIndex.RegionMetric.Count)) + } + row[9].SetUint64(tblIndex.RegionMetric.FlowBytes) + rows = append(rows, row) + } + e.rows = append(e.rows, rows...) +} + +// setDataFromTableConstraints constructs data for table information_schema.constraints.See https://dev.mysql.com/doc/refman/5.7/en/table-constraints-table.html +func (e *memtableRetriever) setDataFromTableConstraints(ctx context.Context, sctx sessionctx.Context, schemas []model.CIStr) error { + checker := privilege.GetPrivilegeManager(sctx) + extractor, ok := e.extractor.(*plannercore.InfoSchemaBaseExtractor) + if ok && extractor.SkipRequest { + return nil + } + var rows [][]types.Datum + for _, schema := range schemas { + if ok && extractor.Filter("constraint_schema", schema.L) { + continue + } + if ok && extractor.Filter("table_schema", schema.L) { + continue + } + tables, err := e.is.SchemaTableInfos(ctx, schema) + if err != nil { + return errors.Trace(err) + } + for _, tbl := range tables { + if ok && extractor.Filter("table_name", tbl.Name.L) { + continue + } + if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, tbl.Name.L, "", mysql.AllPrivMask) { + continue + } + + if tbl.PKIsHandle { + record := types.MakeDatums( + infoschema.CatalogVal, // CONSTRAINT_CATALOG + schema.O, // CONSTRAINT_SCHEMA + mysql.PrimaryKeyName, // CONSTRAINT_NAME + schema.O, // TABLE_SCHEMA + tbl.Name.O, // TABLE_NAME + infoschema.PrimaryKeyType, // CONSTRAINT_TYPE + ) + rows = append(rows, record) + } + + for _, idx := range tbl.Indices { + var cname, ctype string + if idx.Primary { + cname = mysql.PrimaryKeyName + ctype = infoschema.PrimaryKeyType + } else if idx.Unique { + cname = idx.Name.O + ctype = infoschema.UniqueKeyType + } else { + // The index has no constriant. + continue + } + if ok && extractor.Filter("constraint_name", cname) { + continue + } + record := types.MakeDatums( + infoschema.CatalogVal, // CONSTRAINT_CATALOG + schema.O, // CONSTRAINT_SCHEMA + cname, // CONSTRAINT_NAME + schema.O, // TABLE_SCHEMA + tbl.Name.O, // TABLE_NAME + ctype, // CONSTRAINT_TYPE + ) + rows = append(rows, record) + } + // TiDB includes foreign key information for compatibility but foreign keys are not yet enforced. 
+ for _, fk := range tbl.ForeignKeys { + record := types.MakeDatums( + infoschema.CatalogVal, // CONSTRAINT_CATALOG + schema.O, // CONSTRAINT_SCHEMA + fk.Name.O, // CONSTRAINT_NAME + schema.O, // TABLE_SCHEMA + tbl.Name.O, // TABLE_NAME + infoschema.ForeignKeyType, // CONSTRAINT_TYPE + ) + rows = append(rows, record) + } + } + } + e.rows = rows + return nil +} + +// tableStorageStatsRetriever is used to read slow log data. +type tableStorageStatsRetriever struct { + dummyCloser + table *model.TableInfo + outputCols []*model.ColumnInfo + retrieved bool + initialized bool + extractor *plannercore.TableStorageStatsExtractor + initialTables []*initialTable + curTable int + helper *helper.Helper + stats *pd.RegionStats +} + +func (e *tableStorageStatsRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { + if e.retrieved { + return nil, nil + } + if !e.initialized { + err := e.initialize(ctx, sctx) + if err != nil { + return nil, err + } + } + if len(e.initialTables) == 0 || e.curTable >= len(e.initialTables) { + e.retrieved = true + return nil, nil + } + + rows, err := e.setDataForTableStorageStats(ctx) + if err != nil { + return nil, err + } + if len(e.outputCols) == len(e.table.Columns) { + return rows, nil + } + retRows := make([][]types.Datum, len(rows)) + for i, fullRow := range rows { + row := make([]types.Datum, len(e.outputCols)) + for j, col := range e.outputCols { + row[j] = fullRow[col.Offset] + } + retRows[i] = row + } + return retRows, nil +} + +type initialTable struct { + db string + *model.TableInfo +} + +func (e *tableStorageStatsRetriever) initialize(ctx context.Context, sctx sessionctx.Context) error { + is := sctx.GetInfoSchema().(infoschema.InfoSchema) + var databases []string + schemas := e.extractor.TableSchema + tables := e.extractor.TableName + + // If not specify the table_schema, return an error to avoid traverse all schemas and their tables. + if len(schemas) == 0 { + return errors.Errorf("Please add where clause to filter the column TABLE_SCHEMA. " + + "For example, where TABLE_SCHEMA = 'xxx' or where TABLE_SCHEMA in ('xxx', 'yyy')") + } + + // Filter the sys or memory schema. + for schema := range schemas { + if !util.IsMemDB(schema) { + databases = append(databases, schema) + } + } + + // Privilege checker. + checker := func(db, table string) bool { + if pm := privilege.GetPrivilegeManager(sctx); pm != nil { + return pm.RequestVerification(sctx.GetSessionVars().ActiveRoles, db, table, "", mysql.AllPrivMask) + } + return true + } + + // Extract the tables to the initialTable. + for _, DB := range databases { + // The user didn't specified the table, extract all tables of this db to initialTable. + if len(tables) == 0 { + tbs, err := is.SchemaTableInfos(ctx, model.NewCIStr(DB)) + if err != nil { + return errors.Trace(err) + } + for _, tb := range tbs { + // For every db.table, check it's privileges. + if checker(DB, tb.Name.L) { + e.initialTables = append(e.initialTables, &initialTable{DB, tb}) + } + } + } else { + // The user specified the table, extract the specified tables of this db to initialTable. + for tb := range tables { + if tb, err := is.TableByName(context.Background(), model.NewCIStr(DB), model.NewCIStr(tb)); err == nil { + // For every db.table, check it's privileges. + if checker(DB, tb.Meta().Name.L) { + e.initialTables = append(e.initialTables, &initialTable{DB, tb.Meta()}) + } + } + } + } + } + + // Cache the helper and return an error if PD unavailable. 
+ tikvStore, ok := sctx.GetStore().(helper.Storage) + if !ok { + return errors.Errorf("Information about TiKV region status can be gotten only when the storage is TiKV") + } + e.helper = helper.NewHelper(tikvStore) + _, err := e.helper.GetPDAddr() + if err != nil { + return err + } + e.initialized = true + return nil +} + +func (e *tableStorageStatsRetriever) setDataForTableStorageStats(ctx context.Context) ([][]types.Datum, error) { + rows := make([][]types.Datum, 0, 1024) + count := 0 + for e.curTable < len(e.initialTables) && count < 1024 { + tbl := e.initialTables[e.curTable] + tblIDs := make([]int64, 0, 1) + tblIDs = append(tblIDs, tbl.ID) + if partInfo := tbl.GetPartitionInfo(); partInfo != nil { + for _, partDef := range partInfo.Definitions { + tblIDs = append(tblIDs, partDef.ID) + } + } + var err error + for _, tableID := range tblIDs { + e.stats, err = e.helper.GetPDRegionStats(ctx, tableID, false) + if err != nil { + return nil, err + } + peerCount := 0 + for _, cnt := range e.stats.StorePeerCount { + peerCount += cnt + } + + record := types.MakeDatums( + tbl.db, // TABLE_SCHEMA + tbl.Name.O, // TABLE_NAME + tableID, // TABLE_ID + peerCount, // TABLE_PEER_COUNT + e.stats.Count, // TABLE_REGION_COUNT + e.stats.EmptyCount, // TABLE_EMPTY_REGION_COUNT + e.stats.StorageSize, // TABLE_SIZE + e.stats.StorageKeys, // TABLE_KEYS + ) + rows = append(rows, record) + } + count++ + e.curTable++ + } + return rows, nil +} + +// dataForAnalyzeStatusHelper is a helper function which can be used in show_stats.go +func dataForAnalyzeStatusHelper(ctx context.Context, sctx sessionctx.Context) (rows [][]types.Datum, err error) { + const maxAnalyzeJobs = 30 + const sql = "SELECT table_schema, table_name, partition_name, job_info, processed_rows, CONVERT_TZ(start_time, @@TIME_ZONE, '+00:00'), CONVERT_TZ(end_time, @@TIME_ZONE, '+00:00'), state, fail_reason, instance, process_id FROM mysql.analyze_jobs ORDER BY update_time DESC LIMIT %?" 
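// The statement above normalizes start_time/end_time to UTC via CONVERT_TZ; the loop
// below converts them back into the session time zone when building each row, e.g.:
//
//	t, _ := chunkRow.GetTime(5).GoTime(time.UTC)
//	startTime = types.NewTime(types.FromGoTime(t.In(sctx.GetSessionVars().TimeZone)), mysql.TypeDatetime, 0)
//
// (an abridged sketch of the calls that follow; error handling omitted)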
+ exec := sctx.GetRestrictedSQLExecutor() + kctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) + chunkRows, _, err := exec.ExecRestrictedSQL(kctx, nil, sql, maxAnalyzeJobs) + if err != nil { + return nil, err + } + checker := privilege.GetPrivilegeManager(sctx) + + for _, chunkRow := range chunkRows { + dbName := chunkRow.GetString(0) + tableName := chunkRow.GetString(1) + if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, dbName, tableName, "", mysql.AllPrivMask) { + continue + } + partitionName := chunkRow.GetString(2) + jobInfo := chunkRow.GetString(3) + processedRows := chunkRow.GetInt64(4) + var startTime, endTime any + if !chunkRow.IsNull(5) { + t, err := chunkRow.GetTime(5).GoTime(time.UTC) + if err != nil { + return nil, err + } + startTime = types.NewTime(types.FromGoTime(t.In(sctx.GetSessionVars().TimeZone)), mysql.TypeDatetime, 0) + } + if !chunkRow.IsNull(6) { + t, err := chunkRow.GetTime(6).GoTime(time.UTC) + if err != nil { + return nil, err + } + endTime = types.NewTime(types.FromGoTime(t.In(sctx.GetSessionVars().TimeZone)), mysql.TypeDatetime, 0) + } + + state := chunkRow.GetEnum(7).String() + var failReason any + if !chunkRow.IsNull(8) { + failReason = chunkRow.GetString(8) + } + instance := chunkRow.GetString(9) + var procID any + if !chunkRow.IsNull(10) { + procID = chunkRow.GetUint64(10) + } + + var remainDurationStr, progressDouble, estimatedRowCntStr any + if state == statistics.AnalyzeRunning && !strings.HasPrefix(jobInfo, "merge global stats") { + startTime, ok := startTime.(types.Time) + if !ok { + return nil, errors.New("invalid start time") + } + remainingDuration, progress, estimatedRowCnt, remainDurationErr := + getRemainDurationForAnalyzeStatusHelper(ctx, sctx, &startTime, + dbName, tableName, partitionName, processedRows) + if remainDurationErr != nil { + logutil.BgLogger().Warn("get remaining duration failed", zap.Error(remainDurationErr)) + } + if remainingDuration != nil { + remainDurationStr = execdetails.FormatDuration(*remainingDuration) + } + progressDouble = progress + estimatedRowCntStr = int64(estimatedRowCnt) + } + row := types.MakeDatums( + dbName, // TABLE_SCHEMA + tableName, // TABLE_NAME + partitionName, // PARTITION_NAME + jobInfo, // JOB_INFO + processedRows, // ROW_COUNT + startTime, // START_TIME + endTime, // END_TIME + state, // STATE + failReason, // FAIL_REASON + instance, // INSTANCE + procID, // PROCESS_ID + remainDurationStr, // REMAINING_SECONDS + progressDouble, // PROGRESS + estimatedRowCntStr, // ESTIMATED_TOTAL_ROWS + ) + rows = append(rows, row) + } + return +} + +func getRemainDurationForAnalyzeStatusHelper( + ctx context.Context, + sctx sessionctx.Context, startTime *types.Time, + dbName, tableName, partitionName string, processedRows int64) (*time.Duration, float64, float64, error) { + var remainingDuration = time.Duration(0) + var percentage = 0.0 + var totalCnt = float64(0) + if startTime != nil { + start, err := startTime.GoTime(time.UTC) + if err != nil { + return nil, percentage, totalCnt, err + } + duration := time.Now().UTC().Sub(start) + if intest.InTest { + if val := ctx.Value(AnalyzeProgressTest); val != nil { + remainingDuration, percentage = calRemainInfoForAnalyzeStatus(ctx, int64(totalCnt), processedRows, duration) + return &remainingDuration, percentage, totalCnt, nil + } + } + var tid int64 + is := sessiontxn.GetTxnManager(sctx).GetTxnInfoSchema() + tb, err := is.TableByName(ctx, model.NewCIStr(dbName), model.NewCIStr(tableName)) + if err != nil { + 
return nil, percentage, totalCnt, err + } + statsHandle := domain.GetDomain(sctx).StatsHandle() + if statsHandle != nil { + var statsTbl *statistics.Table + meta := tb.Meta() + if partitionName != "" { + pt := meta.GetPartitionInfo() + tid = pt.GetPartitionIDByName(partitionName) + statsTbl = statsHandle.GetPartitionStats(meta, tid) + } else { + statsTbl = statsHandle.GetTableStats(meta) + tid = meta.ID + } + if statsTbl != nil && statsTbl.RealtimeCount != 0 { + totalCnt = float64(statsTbl.RealtimeCount) + } + } + if (tid > 0 && totalCnt == 0) || float64(processedRows) > totalCnt { + totalCnt, _ = pdhelper.GlobalPDHelper.GetApproximateTableCountFromStorage(ctx, sctx, tid, dbName, tableName, partitionName) + } + remainingDuration, percentage = calRemainInfoForAnalyzeStatus(ctx, int64(totalCnt), processedRows, duration) + } + return &remainingDuration, percentage, totalCnt, nil +} + +func calRemainInfoForAnalyzeStatus(ctx context.Context, totalCnt int64, processedRows int64, duration time.Duration) (time.Duration, float64) { + if intest.InTest { + if val := ctx.Value(AnalyzeProgressTest); val != nil { + totalCnt = 100 // But in final result, it is still 0. + processedRows = 10 + duration = 1 * time.Minute + } + } + if totalCnt == 0 { + return 0, 100.0 + } + remainLine := totalCnt - processedRows + if processedRows == 0 { + processedRows = 1 + } + if duration == 0 { + duration = 1 * time.Second + } + i := float64(remainLine) * duration.Seconds() / float64(processedRows) + persentage := float64(processedRows) / float64(totalCnt) + return time.Duration(i) * time.Second, persentage +} + +// setDataForAnalyzeStatus gets all the analyze jobs. +func (e *memtableRetriever) setDataForAnalyzeStatus(ctx context.Context, sctx sessionctx.Context) (err error) { + e.rows, err = dataForAnalyzeStatusHelper(ctx, sctx) + return +} + +// setDataForPseudoProfiling returns pseudo data for table profiling when system variable `profiling` is set to `ON`. 
+func (e *memtableRetriever) setDataForPseudoProfiling(sctx sessionctx.Context) { + if v, ok := sctx.GetSessionVars().GetSystemVar("profiling"); ok && variable.TiDBOptOn(v) { + row := types.MakeDatums( + 0, // QUERY_ID + 0, // SEQ + "", // STATE + types.NewDecFromInt(0), // DURATION + types.NewDecFromInt(0), // CPU_USER + types.NewDecFromInt(0), // CPU_SYSTEM + 0, // CONTEXT_VOLUNTARY + 0, // CONTEXT_INVOLUNTARY + 0, // BLOCK_OPS_IN + 0, // BLOCK_OPS_OUT + 0, // MESSAGES_SENT + 0, // MESSAGES_RECEIVED + 0, // PAGE_FAULTS_MAJOR + 0, // PAGE_FAULTS_MINOR + 0, // SWAPS + "", // SOURCE_FUNCTION + "", // SOURCE_FILE + 0, // SOURCE_LINE + ) + e.rows = append(e.rows, row) + } +} + +func (e *memtableRetriever) setDataForServersInfo(ctx sessionctx.Context) error { + serversInfo, err := infosync.GetAllServerInfo(context.Background()) + if err != nil { + return err + } + rows := make([][]types.Datum, 0, len(serversInfo)) + for _, info := range serversInfo { + row := types.MakeDatums( + info.ID, // DDL_ID + info.IP, // IP + int(info.Port), // PORT + int(info.StatusPort), // STATUS_PORT + info.Lease, // LEASE + info.Version, // VERSION + info.GitHash, // GIT_HASH + info.BinlogStatus, // BINLOG_STATUS + stringutil.BuildStringFromLabels(info.Labels), // LABELS + ) + if sem.IsEnabled() { + checker := privilege.GetPrivilegeManager(ctx) + if checker == nil || !checker.RequestDynamicVerification(ctx.GetSessionVars().ActiveRoles, "RESTRICTED_TABLES_ADMIN", false) { + row[1].SetNull() // clear IP + } + } + rows = append(rows, row) + } + e.rows = rows + return nil +} + +func (e *memtableRetriever) setDataFromSequences(ctx context.Context, sctx sessionctx.Context, schemas []model.CIStr) error { + checker := privilege.GetPrivilegeManager(sctx) + extractor, ok := e.extractor.(*plannercore.InfoSchemaBaseExtractor) + if ok && extractor.SkipRequest { + return nil + } + var rows [][]types.Datum + for _, schema := range schemas { + if ok && extractor.Filter("sequence_schema", schema.L) { + continue + } + tables, err := e.is.SchemaTableInfos(ctx, schema) + if err != nil { + return errors.Trace(err) + } + for _, table := range tables { + if ok && extractor.Filter("sequence_name", table.Name.L) { + continue + } + if !table.IsSequence() { + continue + } + if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, table.Name.L, "", mysql.AllPrivMask) { + continue + } + record := types.MakeDatums( + infoschema.CatalogVal, // TABLE_CATALOG + schema.O, // SEQUENCE_SCHEMA + table.Name.O, // SEQUENCE_NAME + table.Sequence.Cache, // Cache + table.Sequence.CacheValue, // CACHE_VALUE + table.Sequence.Cycle, // CYCLE + table.Sequence.Increment, // INCREMENT + table.Sequence.MaxValue, // MAXVALUE + table.Sequence.MinValue, // MINVALUE + table.Sequence.Start, // START + table.Sequence.Comment, // COMMENT + ) + rows = append(rows, record) + } + } + e.rows = rows + return nil +} + +// dataForTableTiFlashReplica constructs data for table tiflash replica info. 
+func (e *memtableRetriever) dataForTableTiFlashReplica(_ context.Context, sctx sessionctx.Context, _ []model.CIStr) error { + var ( + checker = privilege.GetPrivilegeManager(sctx) + rows [][]types.Datum + tiFlashStores map[int64]pd.StoreInfo + ) + rs := e.is.ListTablesWithSpecialAttribute(infoschema.TiFlashAttribute) + for _, schema := range rs { + for _, tbl := range schema.TableInfos { + if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.DBName, tbl.Name.L, "", mysql.AllPrivMask) { + continue + } + var progress float64 + if pi := tbl.GetPartitionInfo(); pi != nil && len(pi.Definitions) > 0 { + for _, p := range pi.Definitions { + progressOfPartition, err := infosync.MustGetTiFlashProgress(p.ID, tbl.TiFlashReplica.Count, &tiFlashStores) + if err != nil { + logutil.BgLogger().Error("dataForTableTiFlashReplica error", zap.Int64("tableID", tbl.ID), zap.Int64("partitionID", p.ID), zap.Error(err)) + } + progress += progressOfPartition + } + progress = progress / float64(len(pi.Definitions)) + } else { + var err error + progress, err = infosync.MustGetTiFlashProgress(tbl.ID, tbl.TiFlashReplica.Count, &tiFlashStores) + if err != nil { + logutil.BgLogger().Error("dataForTableTiFlashReplica error", zap.Int64("tableID", tbl.ID), zap.Error(err)) + } + } + progressString := types.TruncateFloatToString(progress, 2) + progress, _ = strconv.ParseFloat(progressString, 64) + record := types.MakeDatums( + schema.DBName, // TABLE_SCHEMA + tbl.Name.O, // TABLE_NAME + tbl.ID, // TABLE_ID + int64(tbl.TiFlashReplica.Count), // REPLICA_COUNT + strings.Join(tbl.TiFlashReplica.LocationLabels, ","), // LOCATION_LABELS + tbl.TiFlashReplica.Available, // AVAILABLE + progress, // PROGRESS + ) + rows = append(rows, record) + } + } + e.rows = rows + return nil +} + +func (e *memtableRetriever) setDataForClientErrorsSummary(ctx sessionctx.Context, tableName string) error { + // Seeing client errors should require the PROCESS privilege, with the exception of errors for your own user. + // This is similar to information_schema.processlist, which is the closest comparison. + hasProcessPriv := hasPriv(ctx, mysql.ProcessPriv) + loginUser := ctx.GetSessionVars().User + + var rows [][]types.Datum + switch tableName { + case infoschema.TableClientErrorsSummaryGlobal: + if !hasProcessPriv { + return plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("PROCESS") + } + for code, summary := range errno.GlobalStats() { + firstSeen := types.NewTime(types.FromGoTime(summary.FirstSeen), mysql.TypeTimestamp, types.DefaultFsp) + lastSeen := types.NewTime(types.FromGoTime(summary.LastSeen), mysql.TypeTimestamp, types.DefaultFsp) + row := types.MakeDatums( + int(code), // ERROR_NUMBER + errno.MySQLErrName[code].Raw, // ERROR_MESSAGE + summary.ErrorCount, // ERROR_COUNT + summary.WarningCount, // WARNING_COUNT + firstSeen, // FIRST_SEEN + lastSeen, // LAST_SEEN + ) + rows = append(rows, row) + } + case infoschema.TableClientErrorsSummaryByUser: + for user, agg := range errno.UserStats() { + for code, summary := range agg { + // Allow anyone to see their own errors. 
+ if !hasProcessPriv && loginUser != nil && loginUser.Username != user { + continue + } + firstSeen := types.NewTime(types.FromGoTime(summary.FirstSeen), mysql.TypeTimestamp, types.DefaultFsp) + lastSeen := types.NewTime(types.FromGoTime(summary.LastSeen), mysql.TypeTimestamp, types.DefaultFsp) + row := types.MakeDatums( + user, // USER + int(code), // ERROR_NUMBER + errno.MySQLErrName[code].Raw, // ERROR_MESSAGE + summary.ErrorCount, // ERROR_COUNT + summary.WarningCount, // WARNING_COUNT + firstSeen, // FIRST_SEEN + lastSeen, // LAST_SEEN + ) + rows = append(rows, row) + } + } + case infoschema.TableClientErrorsSummaryByHost: + if !hasProcessPriv { + return plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("PROCESS") + } + for host, agg := range errno.HostStats() { + for code, summary := range agg { + firstSeen := types.NewTime(types.FromGoTime(summary.FirstSeen), mysql.TypeTimestamp, types.DefaultFsp) + lastSeen := types.NewTime(types.FromGoTime(summary.LastSeen), mysql.TypeTimestamp, types.DefaultFsp) + row := types.MakeDatums( + host, // HOST + int(code), // ERROR_NUMBER + errno.MySQLErrName[code].Raw, // ERROR_MESSAGE + summary.ErrorCount, // ERROR_COUNT + summary.WarningCount, // WARNING_COUNT + firstSeen, // FIRST_SEEN + lastSeen, // LAST_SEEN + ) + rows = append(rows, row) + } + } + } + e.rows = rows + return nil +} + +func (e *memtableRetriever) setDataForTrxSummary(ctx sessionctx.Context) error { + hasProcessPriv := hasPriv(ctx, mysql.ProcessPriv) + if !hasProcessPriv { + return nil + } + rows := txninfo.Recorder.DumpTrxSummary() + e.rows = rows + return nil +} + +func (e *memtableRetriever) setDataForClusterTrxSummary(ctx sessionctx.Context) error { + err := e.setDataForTrxSummary(ctx) + if err != nil { + return err + } + rows, err := infoschema.AppendHostInfoToRows(ctx, e.rows) + if err != nil { + return err + } + e.rows = rows + return nil +} + +func (e *memtableRetriever) setDataForMemoryUsage() error { + r := memory.ReadMemStats() + currentOps, sessionKillLastDatum := types.NewDatum(nil), types.NewDatum(nil) + if memory.TriggerMemoryLimitGC.Load() || servermemorylimit.IsKilling.Load() { + currentOps.SetString("shrink", mysql.DefaultCollationName) + } + sessionKillLast := servermemorylimit.SessionKillLast.Load() + if !sessionKillLast.IsZero() { + sessionKillLastDatum.SetMysqlTime(types.NewTime(types.FromGoTime(sessionKillLast), mysql.TypeDatetime, 0)) + } + gcLast := types.NewTime(types.FromGoTime(memory.MemoryLimitGCLast.Load()), mysql.TypeDatetime, 0) + + row := []types.Datum{ + types.NewIntDatum(int64(memory.GetMemTotalIgnoreErr())), // MEMORY_TOTAL + types.NewIntDatum(int64(memory.ServerMemoryLimit.Load())), // MEMORY_LIMIT + types.NewIntDatum(int64(r.HeapInuse)), // MEMORY_CURRENT + types.NewIntDatum(int64(servermemorylimit.MemoryMaxUsed.Load())), // MEMORY_MAX_USED + currentOps, // CURRENT_OPS + sessionKillLastDatum, // SESSION_KILL_LAST + types.NewIntDatum(servermemorylimit.SessionKillTotal.Load()), // SESSION_KILL_TOTAL + types.NewTimeDatum(gcLast), // GC_LAST + types.NewIntDatum(memory.MemoryLimitGCTotal.Load()), // GC_TOTAL + types.NewDatum(GlobalDiskUsageTracker.BytesConsumed()), // DISK_USAGE + types.NewDatum(memory.QueryForceDisk.Load()), // QUERY_FORCE_DISK + } + e.rows = append(e.rows, row) + return nil +} + +func (e *memtableRetriever) setDataForClusterMemoryUsage(ctx sessionctx.Context) error { + err := e.setDataForMemoryUsage() + if err != nil { + return err + } + rows, err := infoschema.AppendHostInfoToRows(ctx, e.rows) + if err != nil { + return 
err + } + e.rows = rows + return nil +} + +func (e *memtableRetriever) setDataForMemoryUsageOpsHistory() error { + e.rows = servermemorylimit.GlobalMemoryOpsHistoryManager.GetRows() + return nil +} + +func (e *memtableRetriever) setDataForClusterMemoryUsageOpsHistory(ctx sessionctx.Context) error { + err := e.setDataForMemoryUsageOpsHistory() + if err != nil { + return err + } + rows, err := infoschema.AppendHostInfoToRows(ctx, e.rows) + if err != nil { + return err + } + e.rows = rows + return nil +} + +// tidbTrxTableRetriever is the memtable retriever for the TIDB_TRX and CLUSTER_TIDB_TRX table. +type tidbTrxTableRetriever struct { + dummyCloser + batchRetrieverHelper + table *model.TableInfo + columns []*model.ColumnInfo + txnInfo []*txninfo.TxnInfo + initialized bool +} + +func (e *tidbTrxTableRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { + if e.retrieved { + return nil, nil + } + + if !e.initialized { + e.initialized = true + + sm := sctx.GetSessionManager() + if sm == nil { + e.retrieved = true + return nil, nil + } + + loginUser := sctx.GetSessionVars().User + hasProcessPriv := hasPriv(sctx, mysql.ProcessPriv) + infoList := sm.ShowTxnList() + e.txnInfo = make([]*txninfo.TxnInfo, 0, len(infoList)) + for _, info := range infoList { + // If you have the PROCESS privilege, you can see all running transactions. + // Otherwise, you can see only your own transactions. + if !hasProcessPriv && loginUser != nil && info.Username != loginUser.Username { + continue + } + e.txnInfo = append(e.txnInfo, info) + } + + e.batchRetrieverHelper.totalRows = len(e.txnInfo) + e.batchRetrieverHelper.batchSize = 1024 + } + + sqlExec := sctx.GetRestrictedSQLExecutor() + + var err error + // The current TiDB node's address is needed by the CLUSTER_TIDB_TRX table. + var instanceAddr string + if e.table.Name.O == infoschema.ClusterTableTiDBTrx { + instanceAddr, err = infoschema.GetInstanceAddr(sctx) + if err != nil { + return nil, err + } + } + + var res [][]types.Datum + err = e.nextBatch(func(start, end int) error { + // Before getting rows, collect the SQL digests that needs to be retrieved first. + var sqlRetriever *expression.SQLDigestTextRetriever + for _, c := range e.columns { + if c.Name.O == txninfo.CurrentSQLDigestTextStr { + if sqlRetriever == nil { + sqlRetriever = expression.NewSQLDigestTextRetriever() + } + + for i := start; i < end; i++ { + sqlRetriever.SQLDigestsMap[e.txnInfo[i].CurrentSQLDigest] = "" + } + } + } + // Retrieve the SQL texts if necessary. + if sqlRetriever != nil { + err1 := sqlRetriever.RetrieveLocal(ctx, sqlExec) + if err1 != nil { + return errors.Trace(err1) + } + } + + res = make([][]types.Datum, 0, end-start) + + // Calculate rows. 
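// Most columns come straight from e.txnInfo[i].ToDatum; the instance address column,
// the current SQL digest text and the memory-buffer bytes are filled from instanceAddr,
// the sqlRetriever lookup and the session's MemDBFootprint respectively.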
+ for i := start; i < end; i++ { + row := make([]types.Datum, 0, len(e.columns)) + for _, c := range e.columns { + if c.Name.O == util.ClusterTableInstanceColumnName { + row = append(row, types.NewDatum(instanceAddr)) + } else if c.Name.O == txninfo.CurrentSQLDigestTextStr { + if text, ok := sqlRetriever.SQLDigestsMap[e.txnInfo[i].CurrentSQLDigest]; ok && len(text) != 0 { + row = append(row, types.NewDatum(text)) + } else { + row = append(row, types.NewDatum(nil)) + } + } else { + switch c.Name.O { + case txninfo.MemBufferBytesStr: + memDBFootprint := sctx.GetSessionVars().MemDBFootprint + var bytesConsumed int64 + if memDBFootprint != nil { + bytesConsumed = memDBFootprint.BytesConsumed() + } + row = append(row, types.NewDatum(bytesConsumed)) + default: + row = append(row, e.txnInfo[i].ToDatum(c.Name.O)) + } + } + } + res = append(res, row) + } + + return nil + }) + + if err != nil { + return nil, err + } + + return res, nil +} + +// dataLockWaitsTableRetriever is the memtable retriever for the DATA_LOCK_WAITS table. +type dataLockWaitsTableRetriever struct { + dummyCloser + batchRetrieverHelper + table *model.TableInfo + columns []*model.ColumnInfo + lockWaits []*deadlock.WaitForEntry + resolvingLocks []txnlock.ResolvingLock + initialized bool +} + +func (r *dataLockWaitsTableRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { + if r.retrieved { + return nil, nil + } + + if !r.initialized { + if !hasPriv(sctx, mysql.ProcessPriv) { + return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("PROCESS") + } + + r.initialized = true + var err error + r.lockWaits, err = sctx.GetStore().GetLockWaits() + tikvStore, _ := sctx.GetStore().(helper.Storage) + r.resolvingLocks = tikvStore.GetLockResolver().Resolving() + if err != nil { + r.retrieved = true + return nil, err + } + + r.batchRetrieverHelper.totalRows = len(r.lockWaits) + len(r.resolvingLocks) + r.batchRetrieverHelper.batchSize = 1024 + } + + var res [][]types.Datum + + err := r.nextBatch(func(start, end int) error { + // Before getting rows, collect the SQL digests that needs to be retrieved first. + var needDigest bool + var needSQLText bool + for _, c := range r.columns { + if c.Name.O == infoschema.DataLockWaitsColumnSQLDigestText { + needSQLText = true + } else if c.Name.O == infoschema.DataLockWaitsColumnSQLDigest { + needDigest = true + } + } + + var digests []string + if needDigest || needSQLText { + digests = make([]string, end-start) + for i, lockWait := range r.lockWaits { + digest, err := resourcegrouptag.DecodeResourceGroupTag(lockWait.ResourceGroupTag) + if err != nil { + // Ignore the error if failed to decode the digest from resource_group_tag. We still want to show + // as much information as possible even we can't retrieve some of them. + logutil.Logger(ctx).Warn("failed to decode resource group tag", zap.Error(err)) + } else { + digests[i] = hex.EncodeToString(digest) + } + } + // todo: support resourcegrouptag for resolvingLocks + } + + // Fetch the SQL Texts of the digests above if necessary. + var sqlRetriever *expression.SQLDigestTextRetriever + if needSQLText { + sqlRetriever = expression.NewSQLDigestTextRetriever() + for _, digest := range digests { + if len(digest) > 0 { + sqlRetriever.SQLDigestsMap[digest] = "" + } + } + + err := sqlRetriever.RetrieveGlobal(ctx, sctx.GetRestrictedSQLExecutor()) + if err != nil { + return errors.Trace(err) + } + } + + // Calculate rows. 
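// For example, with len(r.lockWaits) == 3, start == 2 and end == 5, the index arithmetic
// below yields r.lockWaits[2:3] followed by r.resolvingLocks[0:2].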
+ res = make([][]types.Datum, 0, end-start) + // data_lock_waits contains both lockWaits (pessimistic lock waiting) + // and resolving (optimistic lock "waiting") info + // first we'll return the lockWaits, and then resolving, so we need to + // do some index calculation here + lockWaitsStart := min(start, len(r.lockWaits)) + resolvingStart := start - lockWaitsStart + lockWaitsEnd := min(end, len(r.lockWaits)) + resolvingEnd := end - lockWaitsEnd + for rowIdx, lockWait := range r.lockWaits[lockWaitsStart:lockWaitsEnd] { + row := make([]types.Datum, 0, len(r.columns)) + + for _, col := range r.columns { + switch col.Name.O { + case infoschema.DataLockWaitsColumnKey: + row = append(row, types.NewDatum(strings.ToUpper(hex.EncodeToString(lockWait.Key)))) + case infoschema.DataLockWaitsColumnKeyInfo: + infoSchema := sctx.GetInfoSchema().(infoschema.InfoSchema) + var decodedKeyStr any + decodedKey, err := keydecoder.DecodeKey(lockWait.Key, infoSchema) + if err == nil { + decodedKeyBytes, err := json.Marshal(decodedKey) + if err != nil { + logutil.BgLogger().Warn("marshal decoded key info to JSON failed", zap.Error(err)) + } else { + decodedKeyStr = string(decodedKeyBytes) + } + } else { + logutil.Logger(ctx).Warn("decode key failed", zap.Error(err)) + } + row = append(row, types.NewDatum(decodedKeyStr)) + case infoschema.DataLockWaitsColumnTrxID: + row = append(row, types.NewDatum(lockWait.Txn)) + case infoschema.DataLockWaitsColumnCurrentHoldingTrxID: + row = append(row, types.NewDatum(lockWait.WaitForTxn)) + case infoschema.DataLockWaitsColumnSQLDigest: + digest := digests[rowIdx] + if len(digest) == 0 { + row = append(row, types.NewDatum(nil)) + } else { + row = append(row, types.NewDatum(digest)) + } + case infoschema.DataLockWaitsColumnSQLDigestText: + text := sqlRetriever.SQLDigestsMap[digests[rowIdx]] + if len(text) > 0 { + row = append(row, types.NewDatum(text)) + } else { + row = append(row, types.NewDatum(nil)) + } + default: + row = append(row, types.NewDatum(nil)) + } + } + + res = append(res, row) + } + for _, resolving := range r.resolvingLocks[resolvingStart:resolvingEnd] { + row := make([]types.Datum, 0, len(r.columns)) + + for _, col := range r.columns { + switch col.Name.O { + case infoschema.DataLockWaitsColumnKey: + row = append(row, types.NewDatum(strings.ToUpper(hex.EncodeToString(resolving.Key)))) + case infoschema.DataLockWaitsColumnKeyInfo: + infoSchema := domain.GetDomain(sctx).InfoSchema() + var decodedKeyStr any + decodedKey, err := keydecoder.DecodeKey(resolving.Key, infoSchema) + if err == nil { + decodedKeyBytes, err := json.Marshal(decodedKey) + if err != nil { + logutil.Logger(ctx).Warn("marshal decoded key info to JSON failed", zap.Error(err)) + } else { + decodedKeyStr = string(decodedKeyBytes) + } + } else { + logutil.Logger(ctx).Warn("decode key failed", zap.Error(err)) + } + row = append(row, types.NewDatum(decodedKeyStr)) + case infoschema.DataLockWaitsColumnTrxID: + row = append(row, types.NewDatum(resolving.TxnID)) + case infoschema.DataLockWaitsColumnCurrentHoldingTrxID: + row = append(row, types.NewDatum(resolving.LockTxnID)) + case infoschema.DataLockWaitsColumnSQLDigest: + // todo: support resourcegrouptag for resolvingLocks + row = append(row, types.NewDatum(nil)) + case infoschema.DataLockWaitsColumnSQLDigestText: + // todo: support resourcegrouptag for resolvingLocks + row = append(row, types.NewDatum(nil)) + default: + row = append(row, types.NewDatum(nil)) + } + } + + res = append(res, row) + } + return nil + }) + + if err != nil { + return nil, 
err + } + + return res, nil +} + +// deadlocksTableRetriever is the memtable retriever for the DEADLOCKS and CLUSTER_DEADLOCKS table. +type deadlocksTableRetriever struct { + dummyCloser + batchRetrieverHelper + + currentIdx int + currentWaitChainIdx int + + table *model.TableInfo + columns []*model.ColumnInfo + deadlocks []*deadlockhistory.DeadlockRecord + initialized bool +} + +// nextIndexPair advances a index pair (where `idx` is the index of the DeadlockRecord, and `waitChainIdx` is the index +// of the wait chain item in the `idx`-th DeadlockRecord. This function helps iterate over each wait chain item +// in all DeadlockRecords. +func (r *deadlocksTableRetriever) nextIndexPair(idx, waitChainIdx int) (a, b int) { + waitChainIdx++ + if waitChainIdx >= len(r.deadlocks[idx].WaitChain) { + waitChainIdx = 0 + idx++ + for idx < len(r.deadlocks) && len(r.deadlocks[idx].WaitChain) == 0 { + idx++ + } + } + return idx, waitChainIdx +} + +func (r *deadlocksTableRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { + if r.retrieved { + return nil, nil + } + + if !r.initialized { + if !hasPriv(sctx, mysql.ProcessPriv) { + return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("PROCESS") + } + + r.initialized = true + r.deadlocks = deadlockhistory.GlobalDeadlockHistory.GetAll() + + r.batchRetrieverHelper.totalRows = 0 + for _, d := range r.deadlocks { + r.batchRetrieverHelper.totalRows += len(d.WaitChain) + } + r.batchRetrieverHelper.batchSize = 1024 + } + + // The current TiDB node's address is needed by the CLUSTER_DEADLOCKS table. + var err error + var instanceAddr string + if r.table.Name.O == infoschema.ClusterTableDeadlocks { + instanceAddr, err = infoschema.GetInstanceAddr(sctx) + if err != nil { + return nil, err + } + } + + infoSchema := sctx.GetInfoSchema().(infoschema.InfoSchema) + + var res [][]types.Datum + + err = r.nextBatch(func(start, end int) error { + // Before getting rows, collect the SQL digests that needs to be retrieved first. + var sqlRetriever *expression.SQLDigestTextRetriever + for _, c := range r.columns { + if c.Name.O == deadlockhistory.ColCurrentSQLDigestTextStr { + if sqlRetriever == nil { + sqlRetriever = expression.NewSQLDigestTextRetriever() + } + + idx, waitChainIdx := r.currentIdx, r.currentWaitChainIdx + for i := start; i < end; i++ { + if idx >= len(r.deadlocks) { + return errors.New("reading information_schema.(cluster_)deadlocks table meets corrupted index") + } + + sqlRetriever.SQLDigestsMap[r.deadlocks[idx].WaitChain[waitChainIdx].SQLDigest] = "" + // Step to the next entry + idx, waitChainIdx = r.nextIndexPair(idx, waitChainIdx) + } + } + } + // Retrieve the SQL texts if necessary. 
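+ // Each output row corresponds to a single wait-chain item; the (record, wait-chain) cursor kept in currentIdx/currentWaitChainIdx is advanced with nextIndexPair as rows are emitted.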
+ if sqlRetriever != nil { + err1 := sqlRetriever.RetrieveGlobal(ctx, sctx.GetRestrictedSQLExecutor()) + if err1 != nil { + return errors.Trace(err1) + } + } + + res = make([][]types.Datum, 0, end-start) + + for i := start; i < end; i++ { + if r.currentIdx >= len(r.deadlocks) { + return errors.New("reading information_schema.(cluster_)deadlocks table meets corrupted index") + } + + row := make([]types.Datum, 0, len(r.columns)) + deadlock := r.deadlocks[r.currentIdx] + waitChainItem := deadlock.WaitChain[r.currentWaitChainIdx] + + for _, c := range r.columns { + if c.Name.O == util.ClusterTableInstanceColumnName { + row = append(row, types.NewDatum(instanceAddr)) + } else if c.Name.O == deadlockhistory.ColCurrentSQLDigestTextStr { + if text, ok := sqlRetriever.SQLDigestsMap[waitChainItem.SQLDigest]; ok && len(text) > 0 { + row = append(row, types.NewDatum(text)) + } else { + row = append(row, types.NewDatum(nil)) + } + } else if c.Name.O == deadlockhistory.ColKeyInfoStr { + value := types.NewDatum(nil) + if len(waitChainItem.Key) > 0 { + decodedKey, err := keydecoder.DecodeKey(waitChainItem.Key, infoSchema) + if err == nil { + decodedKeyJSON, err := json.Marshal(decodedKey) + if err != nil { + logutil.BgLogger().Warn("marshal decoded key info to JSON failed", zap.Error(err)) + } else { + value = types.NewDatum(string(decodedKeyJSON)) + } + } else { + logutil.Logger(ctx).Warn("decode key failed", zap.Error(err)) + } + } + row = append(row, value) + } else { + row = append(row, deadlock.ToDatum(r.currentWaitChainIdx, c.Name.O)) + } + } + + res = append(res, row) + // Step to the next entry + r.currentIdx, r.currentWaitChainIdx = r.nextIndexPair(r.currentIdx, r.currentWaitChainIdx) + } + + return nil + }) + + if err != nil { + return nil, err + } + + return res, nil +} + +func adjustColumns(input [][]types.Datum, outColumns []*model.ColumnInfo, table *model.TableInfo) [][]types.Datum { + if len(outColumns) == len(table.Columns) { + return input + } + rows := make([][]types.Datum, len(input)) + for i, fullRow := range input { + row := make([]types.Datum, len(outColumns)) + for j, col := range outColumns { + row[j] = fullRow[col.Offset] + } + rows[i] = row + } + return rows +} + +// TiFlashSystemTableRetriever is used to read system table from tiflash. 
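+// It pages through each TiFlash instance's system tables with LIMIT offset,count queries sent over the TiKV RPC client, advancing to the next instance once a page returns fewer rows than requested.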
+type TiFlashSystemTableRetriever struct { + dummyCloser + table *model.TableInfo + outputCols []*model.ColumnInfo + instanceCount int + instanceIdx int + instanceIDs []string + rowIdx int + retrieved bool + initialized bool + extractor *plannercore.TiFlashSystemTableExtractor +} + +func (e *TiFlashSystemTableRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { + if e.extractor.SkipRequest || e.retrieved { + return nil, nil + } + if !e.initialized { + err := e.initialize(sctx, e.extractor.TiFlashInstances) + if err != nil { + return nil, err + } + } + if e.instanceCount == 0 || e.instanceIdx >= e.instanceCount { + e.retrieved = true + return nil, nil + } + + for { + rows, err := e.dataForTiFlashSystemTables(ctx, sctx, e.extractor.TiDBDatabases, e.extractor.TiDBTables) + if err != nil { + return nil, err + } + if len(rows) > 0 || e.instanceIdx >= e.instanceCount { + return rows, nil + } + } +} + +func (e *TiFlashSystemTableRetriever) initialize(sctx sessionctx.Context, tiflashInstances set.StringSet) error { + storeInfo, err := infoschema.GetStoreServerInfo(sctx.GetStore()) + if err != nil { + return err + } + + for _, info := range storeInfo { + if info.ServerType != kv.TiFlash.Name() { + continue + } + info.ResolveLoopBackAddr() + if len(tiflashInstances) > 0 && !tiflashInstances.Exist(info.Address) { + continue + } + hostAndStatusPort := strings.Split(info.StatusAddr, ":") + if len(hostAndStatusPort) != 2 { + return errors.Errorf("node status addr: %s format illegal", info.StatusAddr) + } + e.instanceIDs = append(e.instanceIDs, info.Address) + e.instanceCount++ + } + e.initialized = true + return nil +} + +type tiFlashSQLExecuteResponseMetaColumn struct { + Name string `json:"name"` + Type string `json:"type"` +} + +type tiFlashSQLExecuteResponse struct { + Meta []tiFlashSQLExecuteResponseMetaColumn `json:"meta"` + Data [][]any `json:"data"` +} + +func (e *TiFlashSystemTableRetriever) dataForTiFlashSystemTables(ctx context.Context, sctx sessionctx.Context, tidbDatabases string, tidbTables string) ([][]types.Datum, error) { + maxCount := 1024 + targetTable := strings.ToLower(strings.Replace(e.table.Name.O, "TIFLASH", "DT", 1)) + var filters []string + if len(tidbDatabases) > 0 { + filters = append(filters, fmt.Sprintf("tidb_database IN (%s)", strings.ReplaceAll(tidbDatabases, "\"", "'"))) + } + if len(tidbTables) > 0 { + filters = append(filters, fmt.Sprintf("tidb_table IN (%s)", strings.ReplaceAll(tidbTables, "\"", "'"))) + } + sql := fmt.Sprintf("SELECT * FROM system.%s", targetTable) + if len(filters) > 0 { + sql = fmt.Sprintf("%s WHERE %s", sql, strings.Join(filters, " AND ")) + } + sql = fmt.Sprintf("%s LIMIT %d, %d", sql, e.rowIdx, maxCount) + request := tikvrpc.Request{ + Type: tikvrpc.CmdGetTiFlashSystemTable, + StoreTp: tikvrpc.TiFlash, + Req: &kvrpcpb.TiFlashSystemTableRequest{ + Sql: sql, + }, + } + + store := sctx.GetStore() + tikvStore, ok := store.(tikv.Storage) + if !ok { + return nil, errors.New("Get tiflash system tables can only run with tikv compatible storage") + } + // send request to tiflash, timeout is 1s + instanceID := e.instanceIDs[e.instanceIdx] + resp, err := tikvStore.GetTiKVClient().SendRequest(ctx, instanceID, &request, time.Second) + if err != nil { + return nil, errors.Trace(err) + } + var result tiFlashSQLExecuteResponse + tiflashResp, ok := resp.Resp.(*kvrpcpb.TiFlashSystemTableResponse) + if !ok { + return nil, errors.Errorf("Unexpected response type: %T", resp.Resp) + } + err = json.Unmarshal(tiflashResp.Data, 
&result) + if err != nil { + return nil, errors.Wrapf(err, "Failed to decode JSON from TiFlash") + } + + // Map result columns back to our columns. It is possible that some columns cannot be + // recognized and some other columns are missing. This may happen during upgrading. + outputColIndexMap := map[string]int{} // Map from TiDB Column name to Output Column Index + for idx, c := range e.outputCols { + outputColIndexMap[c.Name.L] = idx + } + tiflashColIndexMap := map[int]int{} // Map from TiFlash Column index to Output Column Index + for tiFlashColIdx, col := range result.Meta { + if outputIdx, ok := outputColIndexMap[strings.ToLower(col.Name)]; ok { + tiflashColIndexMap[tiFlashColIdx] = outputIdx + } + } + outputRows := make([][]types.Datum, 0, len(result.Data)) + for _, rowFields := range result.Data { + if len(rowFields) == 0 { + continue + } + outputRow := make([]types.Datum, len(e.outputCols)) + for tiFlashColIdx, fieldValue := range rowFields { + outputIdx, ok := tiflashColIndexMap[tiFlashColIdx] + if !ok { + // Discard this field, we don't know which output column is the destination + continue + } + if fieldValue == nil { + continue + } + valStr := fmt.Sprint(fieldValue) + column := e.outputCols[outputIdx] + if column.GetType() == mysql.TypeVarchar { + outputRow[outputIdx].SetString(valStr, mysql.DefaultCollationName) + } else if column.GetType() == mysql.TypeLonglong { + value, err := strconv.ParseInt(valStr, 10, 64) + if err != nil { + return nil, errors.Trace(err) + } + outputRow[outputIdx].SetInt64(value) + } else if column.GetType() == mysql.TypeDouble { + value, err := strconv.ParseFloat(valStr, 64) + if err != nil { + return nil, errors.Trace(err) + } + outputRow[outputIdx].SetFloat64(value) + } else { + return nil, errors.Errorf("Meet column of unknown type %v", column) + } + } + outputRow[len(e.outputCols)-1].SetString(instanceID, mysql.DefaultCollationName) + outputRows = append(outputRows, outputRow) + } + e.rowIdx += len(outputRows) + if len(outputRows) < maxCount { + e.instanceIdx++ + e.rowIdx = 0 + } + return outputRows, nil +} + +func (e *memtableRetriever) setDataForAttributes(ctx context.Context, sctx sessionctx.Context, is infoschema.InfoSchema) error { + checker := privilege.GetPrivilegeManager(sctx) + rules, err := infosync.GetAllLabelRules(context.TODO()) + skipValidateTable := false + failpoint.Inject("mockOutputOfAttributes", func() { + convert := func(i any) []any { + return []any{i} + } + rules = []*label.Rule{ + { + ID: "schema/test/test_label", + Labels: []pd.RegionLabel{{Key: "merge_option", Value: "allow"}, {Key: "db", Value: "test"}, {Key: "table", Value: "test_label"}}, + RuleType: "key-range", + Data: convert(map[string]any{ + "start_key": "7480000000000000ff395f720000000000fa", + "end_key": "7480000000000000ff3a5f720000000000fa", + }), + }, + { + ID: "invalidIDtest", + Labels: []pd.RegionLabel{{Key: "merge_option", Value: "allow"}, {Key: "db", Value: "test"}, {Key: "table", Value: "test_label"}}, + RuleType: "key-range", + Data: convert(map[string]any{ + "start_key": "7480000000000000ff395f720000000000fa", + "end_key": "7480000000000000ff3a5f720000000000fa", + }), + }, + { + ID: "schema/test/test_label", + Labels: []pd.RegionLabel{{Key: "merge_option", Value: "allow"}, {Key: "db", Value: "test"}, {Key: "table", Value: "test_label"}}, + RuleType: "key-range", + Data: convert(map[string]any{ + "start_key": "aaaaa", + "end_key": "bbbbb", + }), + }, + } + err = nil + skipValidateTable = true + }) + + if err != nil { + return errors.Wrap(err, "get the 
label rules failed") + } + + rows := make([][]types.Datum, 0, len(rules)) + for _, rule := range rules { + skip := true + dbName, tableName, partitionName, err := checkRule(rule) + if err != nil { + logutil.BgLogger().Warn("check table-rule failed", zap.String("ID", rule.ID), zap.Error(err)) + continue + } + tableID, err := decodeTableIDFromRule(rule) + if err != nil { + logutil.BgLogger().Warn("decode table ID from rule failed", zap.String("ID", rule.ID), zap.Error(err)) + continue + } + + if !skipValidateTable && tableOrPartitionNotExist(ctx, dbName, tableName, partitionName, is, tableID) { + continue + } + + if tableName != "" && dbName != "" && (checker == nil || checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, dbName, tableName, "", mysql.SelectPriv)) { + skip = false + } + if skip { + continue + } + + labels := label.RestoreRegionLabels(&rule.Labels) + var ranges []string + for _, data := range rule.Data.([]any) { + if kv, ok := data.(map[string]any); ok { + startKey := kv["start_key"] + endKey := kv["end_key"] + ranges = append(ranges, fmt.Sprintf("[%s, %s]", startKey, endKey)) + } + } + kr := strings.Join(ranges, ", ") + + row := types.MakeDatums( + rule.ID, + rule.RuleType, + labels, + kr, + ) + rows = append(rows, row) + } + e.rows = rows + return nil +} + +func (e *memtableRetriever) setDataFromPlacementPolicies(sctx sessionctx.Context) error { + is := sessiontxn.GetTxnManager(sctx).GetTxnInfoSchema() + placementPolicies := is.AllPlacementPolicies() + rows := make([][]types.Datum, 0, len(placementPolicies)) + // Get global PLACEMENT POLICIES + // Currently no privileges needed for seeing global PLACEMENT POLICIES! + for _, policy := range placementPolicies { + // Currently we skip converting syntactic sugar. We might revisit this decision still in the future + // I.e.: if PrimaryRegion or Regions are set, + // also convert them to LeaderConstraints and FollowerConstraints + // for better user experience searching for particular constraints + + // Followers == 0 means not set, so the default value 2 will be used + followerCnt := policy.PlacementSettings.Followers + if followerCnt == 0 { + followerCnt = 2 + } + + row := types.MakeDatums( + policy.ID, + infoschema.CatalogVal, // CATALOG + policy.Name.O, // Policy Name + policy.PlacementSettings.PrimaryRegion, + policy.PlacementSettings.Regions, + policy.PlacementSettings.Constraints, + policy.PlacementSettings.LeaderConstraints, + policy.PlacementSettings.FollowerConstraints, + policy.PlacementSettings.LearnerConstraints, + policy.PlacementSettings.Schedule, + followerCnt, + policy.PlacementSettings.Learners, + ) + rows = append(rows, row) + } + e.rows = rows + return nil +} + +func (e *memtableRetriever) setDataFromRunawayWatches(sctx sessionctx.Context) error { + do := domain.GetDomain(sctx) + err := do.TryToUpdateRunawayWatch() + if err != nil { + logutil.BgLogger().Warn("read runaway watch list", zap.Error(err)) + } + watches := do.GetRunawayWatchList() + rows := make([][]types.Datum, 0, len(watches)) + for _, watch := range watches { + action := watch.Action + row := types.MakeDatums( + watch.ID, + watch.ResourceGroupName, + watch.StartTime.UTC().Format(time.DateTime), + watch.EndTime.UTC().Format(time.DateTime), + rmpb.RunawayWatchType_name[int32(watch.Watch)], + watch.WatchText, + watch.Source, + rmpb.RunawayAction_name[int32(action)], + ) + if watch.EndTime.Equal(resourcegroup.NullTime) { + row[3].SetString("UNLIMITED", mysql.DefaultCollationName) + } + rows = append(rows, row) + } + e.rows = rows + return nil 
+} + +// used in resource_groups +const ( + burstableStr = "YES" + burstdisableStr = "NO" + unlimitedFillRate = "UNLIMITED" +) + +func (e *memtableRetriever) setDataFromResourceGroups() error { + resourceGroups, err := infosync.ListResourceGroups(context.TODO()) + if err != nil { + return errors.Errorf("failed to access resource group manager, error message is %s", err.Error()) + } + rows := make([][]types.Datum, 0, len(resourceGroups)) + for _, group := range resourceGroups { + //mode := "" + burstable := burstdisableStr + priority := model.PriorityValueToName(uint64(group.Priority)) + fillrate := unlimitedFillRate + // RU_PER_SEC = unlimited like the default group settings. + isDefaultInReservedSetting := group.RUSettings.RU.Settings.FillRate == math.MaxInt32 + if !isDefaultInReservedSetting { + fillrate = strconv.FormatUint(group.RUSettings.RU.Settings.FillRate, 10) + } + // convert runaway settings + limitBuilder := new(strings.Builder) + if setting := group.RunawaySettings; setting != nil { + if setting.Rule == nil { + return errors.Errorf("unexpected runaway config in resource group") + } + dur := time.Duration(setting.Rule.ExecElapsedTimeMs) * time.Millisecond + fmt.Fprintf(limitBuilder, "EXEC_ELAPSED='%s'", dur.String()) + fmt.Fprintf(limitBuilder, ", ACTION=%s", model.RunawayActionType(setting.Action).String()) + if setting.Watch != nil { + if setting.Watch.LastingDurationMs > 0 { + dur := time.Duration(setting.Watch.LastingDurationMs) * time.Millisecond + fmt.Fprintf(limitBuilder, ", WATCH=%s DURATION='%s'", model.RunawayWatchType(setting.Watch.Type).String(), dur.String()) + } else { + fmt.Fprintf(limitBuilder, ", WATCH=%s DURATION=UNLIMITED", model.RunawayWatchType(setting.Watch.Type).String()) + } + } + } + queryLimit := limitBuilder.String() + + // convert background settings + bgBuilder := new(strings.Builder) + if setting := group.BackgroundSettings; setting != nil { + fmt.Fprintf(bgBuilder, "TASK_TYPES='%s'", strings.Join(setting.JobTypes, ",")) + } + background := bgBuilder.String() + + switch group.Mode { + case rmpb.GroupMode_RUMode: + if group.RUSettings.RU.Settings.BurstLimit < 0 { + burstable = burstableStr + } + row := types.MakeDatums( + group.Name, + fillrate, + priority, + burstable, + queryLimit, + background, + ) + if len(queryLimit) == 0 { + row[4].SetNull() + } + if len(background) == 0 { + row[5].SetNull() + } + rows = append(rows, row) + default: + //mode = "UNKNOWN_MODE" + row := types.MakeDatums( + group.Name, + nil, + nil, + nil, + nil, + nil, + ) + rows = append(rows, row) + } + } + e.rows = rows + return nil +} + +func (e *memtableRetriever) setDataFromKeywords() error { + rows := make([][]types.Datum, 0, len(parser.Keywords)) + for _, kw := range parser.Keywords { + row := types.MakeDatums(kw.Word, kw.Reserved) + rows = append(rows, row) + } + e.rows = rows + return nil +} + +func (e *memtableRetriever) setDataFromIndexUsage(ctx context.Context, sctx sessionctx.Context, schemas []model.CIStr) error { + dom := domain.GetDomain(sctx) + rows := make([][]types.Datum, 0, 100) + checker := privilege.GetPrivilegeManager(sctx) + extractor, ok := e.extractor.(*plannercore.InfoSchemaBaseExtractor) + if ok && extractor.SkipRequest { + return nil + } + + for _, schema := range schemas { + if ok && extractor.Filter("table_schema", schema.L) { + continue + } + tables, err := dom.InfoSchema().SchemaTableInfos(ctx, schema) + if err != nil { + return errors.Trace(err) + } + for _, tbl := range tables { + if ok && extractor.Filter("table_name", tbl.Name.L) { + continue 
+ } + allowed := checker == nil || checker.RequestVerification( + sctx.GetSessionVars().ActiveRoles, + schema.L, tbl.Name.L, "", mysql.AllPrivMask) + if !allowed { + continue + } + + for _, idx := range tbl.Indices { + if ok && extractor.Filter("index_name", idx.Name.L) { + continue + } + row := make([]types.Datum, 0, 14) + usage := dom.StatsHandle().GetIndexUsage(tbl.ID, idx.ID) + row = append(row, types.NewStringDatum(schema.O)) + row = append(row, types.NewStringDatum(tbl.Name.O)) + row = append(row, types.NewStringDatum(idx.Name.O)) + row = append(row, types.NewIntDatum(int64(usage.QueryTotal))) + row = append(row, types.NewIntDatum(int64(usage.KvReqTotal))) + row = append(row, types.NewIntDatum(int64(usage.RowAccessTotal))) + for _, percentage := range usage.PercentageAccess { + row = append(row, types.NewIntDatum(int64(percentage))) + } + lastUsedAt := types.Datum{} + lastUsedAt.SetNull() + if !usage.LastUsedAt.IsZero() { + t := types.NewTime(types.FromGoTime(usage.LastUsedAt), mysql.TypeTimestamp, 0) + lastUsedAt = types.NewTimeDatum(t) + } + row = append(row, lastUsedAt) + rows = append(rows, row) + } + } + } + + e.rows = rows + return nil +} + +func (e *memtableRetriever) setDataForClusterIndexUsage(ctx context.Context, sctx sessionctx.Context, schemas []model.CIStr) error { + err := e.setDataFromIndexUsage(ctx, sctx, schemas) + if err != nil { + return errors.Trace(err) + } + rows, err := infoschema.AppendHostInfoToRows(sctx, e.rows) + if err != nil { + return err + } + e.rows = rows + return nil +} + +func checkRule(rule *label.Rule) (dbName, tableName string, partitionName string, err error) { + s := strings.Split(rule.ID, "/") + if len(s) < 3 { + err = errors.Errorf("invalid label rule ID: %v", rule.ID) + return + } + if rule.RuleType == "" { + err = errors.New("empty label rule type") + return + } + if rule.Labels == nil || len(rule.Labels) == 0 { + err = errors.New("the label rule has no label") + return + } + if rule.Data == nil { + err = errors.New("the label rule has no data") + return + } + dbName = s[1] + tableName = s[2] + if len(s) > 3 { + partitionName = s[3] + } + return +} + +func decodeTableIDFromRule(rule *label.Rule) (tableID int64, err error) { + datas := rule.Data.([]any) + if len(datas) == 0 { + err = fmt.Errorf("there is no data in rule %s", rule.ID) + return + } + data := datas[0] + dataMap, ok := data.(map[string]any) + if !ok { + err = fmt.Errorf("get the label rules %s failed", rule.ID) + return + } + key, err := hex.DecodeString(fmt.Sprintf("%s", dataMap["start_key"])) + if err != nil { + err = fmt.Errorf("decode key from start_key %s in rule %s failed", dataMap["start_key"], rule.ID) + return + } + _, bs, err := codec.DecodeBytes(key, nil) + if err == nil { + key = bs + } + tableID = tablecodec.DecodeTableID(key) + if tableID == 0 { + err = fmt.Errorf("decode tableID from key %s in rule %s failed", key, rule.ID) + return + } + return +} + +func tableOrPartitionNotExist(ctx context.Context, dbName string, tableName string, partitionName string, is infoschema.InfoSchema, tableID int64) (tableNotExist bool) { + if len(partitionName) == 0 { + curTable, _ := is.TableByName(ctx, model.NewCIStr(dbName), model.NewCIStr(tableName)) + if curTable == nil { + return true + } + curTableID := curTable.Meta().ID + if curTableID != tableID { + return true + } + } else { + _, _, partInfo := is.FindTableByPartitionID(tableID) + if partInfo == nil { + return true + } + } + return false +} diff --git a/pkg/executor/inspection_result.go b/pkg/executor/inspection_result.go 
index 8596f562ca720..8d8251a77b1e2 100644 --- a/pkg/executor/inspection_result.go +++ b/pkg/executor/inspection_result.go @@ -128,7 +128,7 @@ func (e *inspectionResultRetriever) retrieve(ctx context.Context, sctx sessionct sctx.GetSessionVars().InspectionTableCache = map[string]variable.TableSnapshot{} defer func() { sctx.GetSessionVars().InspectionTableCache = nil }() - failpoint.InjectContext(ctx, "mockMergeMockInspectionTables", func() { + if _, _err_ := failpoint.EvalContext(ctx, _curpkg_("mockMergeMockInspectionTables")); _err_ == nil { // Merge mock snapshots injected from failpoint for test purpose mockTables, ok := ctx.Value("__mockInspectionTables").(map[string]variable.TableSnapshot) if ok { @@ -136,7 +136,7 @@ func (e *inspectionResultRetriever) retrieve(ctx context.Context, sctx sessionct sctx.GetSessionVars().InspectionTableCache[strings.ToLower(name)] = snap } } - }) + } if e.instanceToStatusAddress == nil { // Get cluster info. diff --git a/pkg/executor/inspection_result.go__failpoint_stash__ b/pkg/executor/inspection_result.go__failpoint_stash__ new file mode 100644 index 0000000000000..8596f562ca720 --- /dev/null +++ b/pkg/executor/inspection_result.go__failpoint_stash__ @@ -0,0 +1,1248 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "cmp" + "context" + "fmt" + "math" + "slices" + "strconv" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + plannerutil "github.com/pingcap/tidb/pkg/planner/util" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/set" + "github.com/pingcap/tidb/pkg/util/size" +) + +type ( + // inspectionResult represents a abnormal diagnosis result + inspectionResult struct { + tp string + instance string + statusAddress string + // represents the diagnostics item, e.g: `ddl.lease` `raftstore.cpuusage` + item string + // diagnosis result value base on current cluster status + actual string + expected string + severity string + detail string + // degree only used for sort. 
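+ // a larger degree sorts the result earlier in the final output.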
+ degree float64 + } + + inspectionName string + + inspectionFilter struct { + set set.StringSet + timeRange plannerutil.QueryTimeRange + } + + inspectionRule interface { + name() string + inspect(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult + } +) + +func (n inspectionName) name() string { + return string(n) +} + +func (f inspectionFilter) enable(name string) bool { + return len(f.set) == 0 || f.set.Exist(name) +} + +type ( + // configInspection is used to check whether a same configuration item has a + // different value between different instance in the cluster + configInspection struct{ inspectionName } + + // versionInspection is used to check whether the same component has different + // version in the cluster + versionInspection struct{ inspectionName } + + // nodeLoadInspection is used to check the node load of memory/disk/cpu + // have reached a high-level threshold + nodeLoadInspection struct{ inspectionName } + + // criticalErrorInspection is used to check are there some critical errors + // occurred in the past + criticalErrorInspection struct{ inspectionName } + + // thresholdCheckInspection is used to check some threshold value, like CPU usage, leader count change. + thresholdCheckInspection struct{ inspectionName } +) + +var inspectionRules = []inspectionRule{ + &configInspection{inspectionName: "config"}, + &versionInspection{inspectionName: "version"}, + &nodeLoadInspection{inspectionName: "node-load"}, + &criticalErrorInspection{inspectionName: "critical-error"}, + &thresholdCheckInspection{inspectionName: "threshold-check"}, +} + +type inspectionResultRetriever struct { + dummyCloser + retrieved bool + extractor *plannercore.InspectionResultTableExtractor + timeRange plannerutil.QueryTimeRange + instanceToStatusAddress map[string]string + statusToInstanceAddress map[string]string +} + +func (e *inspectionResultRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { + if e.retrieved || e.extractor.SkipInspection { + return nil, nil + } + e.retrieved = true + + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnMeta) + // Some data of cluster-level memory tables will be retrieved many times in different inspection rules, + // and the cost of retrieving some data is expensive. We use the `TableSnapshot` to cache those data + // and obtain them lazily, and provide a consistent view of inspection tables for each inspection rules. + // All cached snapshots should be released at the end of retrieving. + sctx.GetSessionVars().InspectionTableCache = map[string]variable.TableSnapshot{} + defer func() { sctx.GetSessionVars().InspectionTableCache = nil }() + + failpoint.InjectContext(ctx, "mockMergeMockInspectionTables", func() { + // Merge mock snapshots injected from failpoint for test purpose + mockTables, ok := ctx.Value("__mockInspectionTables").(map[string]variable.TableSnapshot) + if ok { + for name, snap := range mockTables { + sctx.GetSessionVars().InspectionTableCache[strings.ToLower(name)] = snap + } + } + }) + + if e.instanceToStatusAddress == nil { + // Get cluster info. 
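+ // Build the instance <-> status-address mappings once, so results that only carry one of the two addresses can be back-filled with the other before being returned.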
+ e.instanceToStatusAddress = make(map[string]string) + e.statusToInstanceAddress = make(map[string]string) + var rows []chunk.Row + exec := sctx.GetRestrictedSQLExecutor() + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, "select instance,status_address from information_schema.cluster_info;") + if err != nil { + sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("get cluster info failed: %v", err)) + } + for _, row := range rows { + if row.Len() < 2 { + continue + } + e.instanceToStatusAddress[row.GetString(0)] = row.GetString(1) + e.statusToInstanceAddress[row.GetString(1)] = row.GetString(0) + } + } + + rules := inspectionFilter{set: e.extractor.Rules} + items := inspectionFilter{set: e.extractor.Items, timeRange: e.timeRange} + var finalRows [][]types.Datum + for _, r := range inspectionRules { + name := r.name() + if !rules.enable(name) { + continue + } + results := r.inspect(ctx, sctx, items) + if len(results) == 0 { + continue + } + // make result stable + slices.SortFunc(results, func(i, j inspectionResult) int { + if c := cmp.Compare(i.degree, j.degree); c != 0 { + return -c + } + // lhs and rhs + if c := cmp.Compare(i.item, j.item); c != 0 { + return c + } + if c := cmp.Compare(i.actual, j.actual); c != 0 { + return c + } + // lhs and rhs + if c := cmp.Compare(i.tp, j.tp); c != 0 { + return c + } + return cmp.Compare(i.instance, j.instance) + }) + for _, result := range results { + if len(result.instance) == 0 { + result.instance = e.statusToInstanceAddress[result.statusAddress] + } + if len(result.statusAddress) == 0 { + result.statusAddress = e.instanceToStatusAddress[result.instance] + } + finalRows = append(finalRows, types.MakeDatums( + name, + result.item, + result.tp, + result.instance, + result.statusAddress, + result.actual, + result.expected, + result.severity, + result.detail, + )) + } + } + return finalRows, nil +} + +func (c configInspection) inspect(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { + var results []inspectionResult + results = append(results, c.inspectDiffConfig(ctx, sctx, filter)...) + results = append(results, c.inspectCheckConfig(ctx, sctx, filter)...) + return results +} + +func (configInspection) inspectDiffConfig(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { + // check the configuration consistent + ignoreConfigKey := []string{ + // TiDB + "port", + "status.status-port", + "host", + "path", + "advertise-address", + "status.status-port", + "log.file.filename", + "log.slow-query-file", + "tmp-storage-path", + + // PD + "advertise-client-urls", + "advertise-peer-urls", + "client-urls", + "data-dir", + "log-file", + "log.file.filename", + "metric.job", + "name", + "peer-urls", + + // TiKV + "server.addr", + "server.advertise-addr", + "server.advertise-status-addr", + "server.status-addr", + "log-file", + "raftstore.raftdb-path", + "storage.data-dir", + "storage.block-cache.capacity", + } + exec := sctx.GetRestrictedSQLExecutor() + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, "select type, `key`, count(distinct value) as c from information_schema.cluster_config where `key` not in (%?) group by type, `key` having c > 1", ignoreConfigKey) + if err != nil { + sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration consistency failed: %v", err)) + } + + generateDetail := func(tp, item string) string { + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, "select value, instance from information_schema.cluster_config where type=%? 
and `key`=%?;", tp, item) + if err != nil { + sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration consistency failed: %v", err)) + return fmt.Sprintf("the cluster has different config value of %[2]s, execute the sql to see more detail: select * from information_schema.cluster_config where type='%[1]s' and `key`='%[2]s'", + tp, item) + } + m := make(map[string][]string) + for _, row := range rows { + value := row.GetString(0) + instance := row.GetString(1) + m[value] = append(m[value], instance) + } + groups := make([]string, 0, len(m)) + for k, v := range m { + slices.Sort(v) + groups = append(groups, fmt.Sprintf("%s config value is %s", strings.Join(v, ","), k)) + } + slices.Sort(groups) + return strings.Join(groups, "\n") + } + + var results []inspectionResult + for _, row := range rows { + if filter.enable(row.GetString(1)) { + detail := generateDetail(row.GetString(0), row.GetString(1)) + results = append(results, inspectionResult{ + tp: row.GetString(0), + instance: "", + item: row.GetString(1), // key + actual: "inconsistent", + expected: "consistent", + severity: "warning", + detail: detail, + }) + } + } + return results +} + +func (c configInspection) inspectCheckConfig(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { + // check the configuration in reason. + cases := []struct { + table string + tp string + key string + expect string + cond string + detail string + }{ + { + table: "cluster_config", + key: "log.slow-threshold", + expect: "> 0", + cond: "type = 'tidb' and `key` = 'log.slow-threshold' and value = '0'", + detail: "slow-threshold = 0 will record every query to slow log, it may affect performance", + }, + { + + table: "cluster_config", + key: "raftstore.sync-log", + expect: "true", + cond: "type = 'tikv' and `key` = 'raftstore.sync-log' and value = 'false'", + detail: "sync-log should be true to avoid recover region when the machine breaks down", + }, + { + table: "cluster_systeminfo", + key: "transparent_hugepage_enabled", + expect: "always madvise [never]", + cond: "system_name = 'kernel' and name = 'transparent_hugepage_enabled' and value not like '%[never]%'", + detail: "Transparent HugePages can cause memory allocation delays during runtime, TiDB recommends that you disable Transparent HugePages on all TiDB servers", + }, + } + + var results []inspectionResult + var rows []chunk.Row + sql := new(strings.Builder) + exec := sctx.GetRestrictedSQLExecutor() + for _, cas := range cases { + if !filter.enable(cas.key) { + continue + } + sql.Reset() + fmt.Fprintf(sql, "select type,instance,value from information_schema.%s where %s", cas.table, cas.cond) + stmt, err := exec.ParseWithParams(ctx, sql.String()) + if err == nil { + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) + } + if err != nil { + sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration in reason failed: %v", err)) + } + + for _, row := range rows { + results = append(results, inspectionResult{ + tp: row.GetString(0), + instance: row.GetString(1), + item: cas.key, + actual: row.GetString(2), + expected: cas.expect, + severity: "warning", + detail: cas.detail, + }) + } + } + results = append(results, c.checkTiKVBlockCacheSizeConfig(ctx, sctx, filter)...) 
+ return results +} + +func (c configInspection) checkTiKVBlockCacheSizeConfig(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { + item := "storage.block-cache.capacity" + if !filter.enable(item) { + return nil + } + exec := sctx.GetRestrictedSQLExecutor() + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, "select instance,value from information_schema.cluster_config where type='tikv' and `key` = 'storage.block-cache.capacity'") + if err != nil { + sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration in reason failed: %v", err)) + } + extractIP := func(addr string) string { + if idx := strings.Index(addr, ":"); idx > -1 { + return addr[0:idx] + } + return addr + } + + ipToBlockSize := make(map[string]uint64) + ipToCount := make(map[string]int) + for _, row := range rows { + ip := extractIP(row.GetString(0)) + size, err := c.convertReadableSizeToByteSize(row.GetString(1)) + if err != nil { + sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check TiKV block-cache configuration in reason failed: %v", err)) + return nil + } + ipToBlockSize[ip] += size + ipToCount[ip]++ + } + + rows, _, err = exec.ExecRestrictedSQL(ctx, nil, "select instance, value from metrics_schema.node_total_memory where time=now()") + if err != nil { + sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration in reason failed: %v", err)) + } + ipToMemorySize := make(map[string]float64) + for _, row := range rows { + ip := extractIP(row.GetString(0)) + size := row.GetFloat64(1) + ipToMemorySize[ip] += size + } + + var results []inspectionResult + for ip, blockSize := range ipToBlockSize { + if memorySize, ok := ipToMemorySize[ip]; ok { + if float64(blockSize) > memorySize*0.45 { + detail := fmt.Sprintf("There are %v TiKV server in %v node, the total 'storage.block-cache.capacity' of TiKV is more than (0.45 * total node memory)", + ipToCount[ip], ip) + results = append(results, inspectionResult{ + tp: "tikv", + instance: ip, + item: item, + actual: fmt.Sprintf("%v", blockSize), + expected: fmt.Sprintf("< %.0f", memorySize*0.45), + severity: "warning", + detail: detail, + }) + } + } + } + return results +} + +func (configInspection) convertReadableSizeToByteSize(sizeStr string) (uint64, error) { + rate := uint64(1) + if strings.HasSuffix(sizeStr, "KiB") { + rate = size.KB + } else if strings.HasSuffix(sizeStr, "MiB") { + rate = size.MB + } else if strings.HasSuffix(sizeStr, "GiB") { + rate = size.GB + } else if strings.HasSuffix(sizeStr, "TiB") { + rate = size.TB + } else if strings.HasSuffix(sizeStr, "PiB") { + rate = size.PB + } + if rate != 1 && len(sizeStr) > 3 { + sizeStr = sizeStr[:len(sizeStr)-3] + } + size, err := strconv.Atoi(sizeStr) + if err != nil { + return 0, errors.Trace(err) + } + return uint64(size) * rate, nil +} + +func (versionInspection) inspect(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { + exec := sctx.GetRestrictedSQLExecutor() + // check the configuration consistent + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, "select type, count(distinct git_hash) as c from information_schema.cluster_info group by type having c > 1;") + if err != nil { + sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check version consistency failed: %v", err)) + } + + const name = "git_hash" + var results []inspectionResult + for _, row := range rows { + if filter.enable(name) { + results = append(results, inspectionResult{ + tp: row.GetString(0), + instance: "", + item: name, + 
actual: "inconsistent", + expected: "consistent", + severity: "critical", + detail: fmt.Sprintf("the cluster has %[1]v different %[2]s versions, execute the sql to see more detail: select * from information_schema.cluster_info where type='%[2]s'", row.GetUint64(1), row.GetString(0)), + }) + } + } + return results +} + +func (nodeLoadInspection) inspect(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { + var rules = []ruleChecker{ + inspectCPULoad{item: "load1", tbl: "node_load1"}, + inspectCPULoad{item: "load5", tbl: "node_load5"}, + inspectCPULoad{item: "load15", tbl: "node_load15"}, + inspectVirtualMemUsage{}, + inspectSwapMemoryUsed{}, + inspectDiskUsage{}, + } + return checkRules(ctx, sctx, filter, rules) +} + +type inspectVirtualMemUsage struct{} + +func (inspectVirtualMemUsage) genSQL(timeRange plannerutil.QueryTimeRange) string { + sql := fmt.Sprintf("select instance, max(value) as max_usage from metrics_schema.node_memory_usage %s group by instance having max_usage >= 70", timeRange.Condition()) + return sql +} + +func (i inspectVirtualMemUsage) genResult(_ string, row chunk.Row) inspectionResult { + return inspectionResult{ + tp: "node", + instance: row.GetString(0), + item: i.getItem(), + actual: fmt.Sprintf("%.1f%%", row.GetFloat64(1)), + expected: "< 70%", + severity: "warning", + detail: "the memory-usage is too high", + } +} + +func (inspectVirtualMemUsage) getItem() string { + return "virtual-memory-usage" +} + +type inspectSwapMemoryUsed struct{} + +func (inspectSwapMemoryUsed) genSQL(timeRange plannerutil.QueryTimeRange) string { + sql := fmt.Sprintf("select instance, max(value) as max_used from metrics_schema.node_memory_swap_used %s group by instance having max_used > 0", timeRange.Condition()) + return sql +} + +func (i inspectSwapMemoryUsed) genResult(_ string, row chunk.Row) inspectionResult { + return inspectionResult{ + tp: "node", + instance: row.GetString(0), + item: i.getItem(), + actual: fmt.Sprintf("%.1f", row.GetFloat64(1)), + expected: "0", + severity: "warning", + } +} + +func (inspectSwapMemoryUsed) getItem() string { + return "swap-memory-used" +} + +type inspectDiskUsage struct{} + +func (inspectDiskUsage) genSQL(timeRange plannerutil.QueryTimeRange) string { + sql := fmt.Sprintf("select instance, device, max(value) as max_usage from metrics_schema.node_disk_usage %v and device like '/%%' group by instance, device having max_usage >= 70", timeRange.Condition()) + return sql +} + +func (i inspectDiskUsage) genResult(_ string, row chunk.Row) inspectionResult { + return inspectionResult{ + tp: "node", + instance: row.GetString(0), + item: i.getItem(), + actual: fmt.Sprintf("%.1f%%", row.GetFloat64(2)), + expected: "< 70%", + severity: "warning", + detail: "the disk-usage of " + row.GetString(1) + " is too high", + } +} + +func (inspectDiskUsage) getItem() string { + return "disk-usage" +} + +type inspectCPULoad struct { + item string + tbl string +} + +func (i inspectCPULoad) genSQL(timeRange plannerutil.QueryTimeRange) string { + sql := fmt.Sprintf(`select t1.instance, t1.max_load , 0.7*t2.cpu_count from + (select instance,max(value) as max_load from metrics_schema.%[1]s %[2]s group by instance) as t1 join + (select instance,max(value) as cpu_count from metrics_schema.node_virtual_cpus %[2]s group by instance) as t2 + on t1.instance=t2.instance where t1.max_load>(0.7*t2.cpu_count);`, i.tbl, timeRange.Condition()) + return sql +} + +func (i inspectCPULoad) genResult(_ string, row chunk.Row) inspectionResult { + 
return inspectionResult{ + tp: "node", + instance: row.GetString(0), + item: "cpu-" + i.item, + actual: fmt.Sprintf("%.1f", row.GetFloat64(1)), + expected: fmt.Sprintf("< %.1f", row.GetFloat64(2)), + severity: "warning", + detail: i.getItem() + " should less than (cpu_logical_cores * 0.7)", + } +} + +func (i inspectCPULoad) getItem() string { + return "cpu-" + i.item +} + +func (c criticalErrorInspection) inspect(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { + results := c.inspectError(ctx, sctx, filter) + results = append(results, c.inspectForServerDown(ctx, sctx, filter)...) + return results +} +func (criticalErrorInspection) inspectError(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { + var rules = []struct { + tp string + item string + tbl string + }{ + {tp: "tikv", item: "critical-error", tbl: "tikv_critical_error_total_count"}, + {tp: "tidb", item: "panic-count", tbl: "tidb_panic_count_total_count"}, + {tp: "tidb", item: "binlog-error", tbl: "tidb_binlog_error_total_count"}, + {tp: "tikv", item: "scheduler-is-busy", tbl: "tikv_scheduler_is_busy_total_count"}, + {tp: "tikv", item: "coprocessor-is-busy", tbl: "tikv_coprocessor_is_busy_total_count"}, + {tp: "tikv", item: "channel-is-full", tbl: "tikv_channel_full_total_count"}, + {tp: "tikv", item: "tikv_engine_write_stall", tbl: "tikv_engine_write_stall"}, + } + + condition := filter.timeRange.Condition() + var results []inspectionResult + exec := sctx.GetRestrictedSQLExecutor() + sql := new(strings.Builder) + for _, rule := range rules { + if filter.enable(rule.item) { + def, found := infoschema.MetricTableMap[rule.tbl] + if !found { + sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("metrics table: %s not found", rule.tbl)) + continue + } + sql.Reset() + fmt.Fprintf(sql, "select `%[1]s`,sum(value) as total from `%[2]s`.`%[3]s` %[4]s group by `%[1]s` having total>=1.0", + strings.Join(def.Labels, "`,`"), util.MetricSchemaName.L, rule.tbl, condition) + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql.String()) + if err != nil { + sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) + continue + } + for _, row := range rows { + var actual, detail string + var degree float64 + if rest := def.Labels[1:]; len(rest) > 0 { + values := make([]string, 0, len(rest)) + // `i+1` and `1+len(rest)` means skip the first field `instance` + for i := range rest { + values = append(values, row.GetString(i+1)) + } + // TODO: find a better way to construct the `actual` field + actual = fmt.Sprintf("%.2f(%s)", row.GetFloat64(1+len(rest)), strings.Join(values, ", ")) + degree = row.GetFloat64(1 + len(rest)) + } else { + actual = fmt.Sprintf("%.2f", row.GetFloat64(1)) + degree = row.GetFloat64(1) + } + detail = fmt.Sprintf("the total number of errors about '%s' is too many", rule.item) + result := inspectionResult{ + tp: rule.tp, + // NOTE: all tables which can be inspected here whose first label must be `instance` + statusAddress: row.GetString(0), + item: rule.item, + actual: actual, + expected: "0", + severity: "critical", + detail: detail, + degree: degree, + } + results = append(results, result) + } + } + } + return results +} + +func (criticalErrorInspection) inspectForServerDown(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { + item := "server-down" + if !filter.enable(item) { + return nil + } + condition := filter.timeRange.Condition() + exec := 
sctx.GetRestrictedSQLExecutor() + sql := new(strings.Builder) + fmt.Fprintf(sql, `select t1.job,t1.instance, t2.min_time from + (select instance,job from metrics_schema.up %[1]s group by instance,job having max(value)-min(value)>0) as t1 join + (select instance,min(time) as min_time from metrics_schema.up %[1]s and value=0 group by instance,job) as t2 on t1.instance=t2.instance order by job`, condition) + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql.String()) + if err != nil { + sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) + } + results := make([]inspectionResult, 0, len(rows)) + for _, row := range rows { + if row.Len() < 3 { + continue + } + detail := fmt.Sprintf("%s %s disconnect with prometheus around time '%s'", row.GetString(0), row.GetString(1), row.GetTime(2)) + result := inspectionResult{ + tp: row.GetString(0), + statusAddress: row.GetString(1), + item: item, + actual: "", + expected: "", + severity: "critical", + detail: detail, + degree: 10000 + float64(len(results)), + } + results = append(results, result) + } + // Check from log. + sql.Reset() + fmt.Fprintf(sql, "select type,instance,time from information_schema.cluster_log %s and level = 'info' and message like '%%Welcome to'", condition) + rows, _, err = exec.ExecRestrictedSQL(ctx, nil, sql.String()) + if err != nil { + sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) + } + for _, row := range rows { + if row.Len() < 3 { + continue + } + detail := fmt.Sprintf("%s %s restarted at time '%s'", row.GetString(0), row.GetString(1), row.GetString(2)) + result := inspectionResult{ + tp: row.GetString(0), + instance: row.GetString(1), + item: item, + actual: "", + expected: "", + severity: "critical", + detail: detail, + degree: 10000 + float64(len(results)), + } + results = append(results, result) + } + return results +} + +func (c thresholdCheckInspection) inspect(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { + inspects := []func(context.Context, sessionctx.Context, inspectionFilter) []inspectionResult{ + c.inspectThreshold1, + c.inspectThreshold2, + c.inspectThreshold3, + c.inspectForLeaderDrop, + } + //nolint: prealloc + var results []inspectionResult + for _, inspect := range inspects { + re := inspect(ctx, sctx, filter) + results = append(results, re...) 
+ } + return results +} + +func (thresholdCheckInspection) inspectThreshold1(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { + var rules = []struct { + item string + component string + configKey string + threshold float64 + }{ + { + item: "coprocessor-normal-cpu", + component: "cop_normal%", + configKey: "readpool.coprocessor.normal-concurrency", + threshold: 0.9}, + { + item: "coprocessor-high-cpu", + component: "cop_high%", + configKey: "readpool.coprocessor.high-concurrency", + threshold: 0.9, + }, + { + item: "coprocessor-low-cpu", + component: "cop_low%", + configKey: "readpool.coprocessor.low-concurrency", + threshold: 0.9, + }, + { + item: "grpc-cpu", + component: "grpc%", + configKey: "server.grpc-concurrency", + threshold: 0.9, + }, + { + item: "raftstore-cpu", + component: "raftstore_%", + configKey: "raftstore.store-pool-size", + threshold: 0.8, + }, + { + item: "apply-cpu", + component: "apply_%", + configKey: "raftstore.apply-pool-size", + threshold: 0.8, + }, + { + item: "storage-readpool-normal-cpu", + component: "store_read_norm%", + configKey: "readpool.storage.normal-concurrency", + threshold: 0.9, + }, + { + item: "storage-readpool-high-cpu", + component: "store_read_high%", + configKey: "readpool.storage.high-concurrency", + threshold: 0.9, + }, + { + item: "storage-readpool-low-cpu", + component: "store_read_low%", + configKey: "readpool.storage.low-concurrency", + threshold: 0.9, + }, + { + item: "scheduler-worker-cpu", + component: "sched_%", + configKey: "storage.scheduler-worker-pool-size", + threshold: 0.85, + }, + { + item: "split-check-cpu", + component: "split_check", + threshold: 0.9, + }, + } + + condition := filter.timeRange.Condition() + var results []inspectionResult + exec := sctx.GetRestrictedSQLExecutor() + sql := new(strings.Builder) + for _, rule := range rules { + if !filter.enable(rule.item) { + continue + } + + sql.Reset() + if len(rule.configKey) > 0 { + fmt.Fprintf(sql, `select t1.status_address, t1.cpu, (t2.value * %[2]f) as threshold, t2.value from + (select status_address, max(sum_value) as cpu from (select instance as status_address, sum(value) as sum_value from metrics_schema.tikv_thread_cpu %[4]s and name like '%[1]s' group by instance, time) as tmp group by tmp.status_address) as t1 join + (select instance, value from information_schema.cluster_config where type='tikv' and %[5]s = '%[3]s') as t2 join + (select instance,status_address from information_schema.cluster_info where type='tikv') as t3 + on t1.status_address=t3.status_address and t2.instance=t3.instance where t1.cpu > (t2.value * %[2]f)`, rule.component, rule.threshold, rule.configKey, condition, "`key`") + } else { + fmt.Fprintf(sql, `select t1.instance, t1.cpu, %[2]f from + (select instance, max(value) as cpu from metrics_schema.tikv_thread_cpu %[3]s and name like '%[1]s' group by instance) as t1 + where t1.cpu > %[2]f;`, rule.component, rule.threshold, condition) + } + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql.String()) + if err != nil { + sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) + continue + } + for _, row := range rows { + actual := fmt.Sprintf("%.2f", row.GetFloat64(1)) + degree := math.Abs(row.GetFloat64(1)-row.GetFloat64(2)) / math.Max(row.GetFloat64(1), row.GetFloat64(2)) + expected := "" + if len(rule.configKey) > 0 { + expected = fmt.Sprintf("< %.2f, config: %v=%v", row.GetFloat64(2), rule.configKey, row.GetString(3)) + } else { + expected = fmt.Sprintf("< %.2f", 
row.GetFloat64(2)) + } + detail := fmt.Sprintf("the '%s' max cpu-usage of %s tikv is too high", rule.item, row.GetString(0)) + result := inspectionResult{ + tp: "tikv", + statusAddress: row.GetString(0), + item: rule.item, + actual: actual, + expected: expected, + severity: "warning", + detail: detail, + degree: degree, + } + results = append(results, result) + } + } + return results +} + +func (thresholdCheckInspection) inspectThreshold2(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { + var rules = []struct { + tp string + item string + tbl string + condition string + threshold float64 + factor float64 + isMin bool + detail string + }{ + { + tp: "tidb", + item: "tso-duration", + tbl: "pd_tso_wait_duration", + condition: "quantile=0.999", + threshold: 0.05, + }, + { + tp: "tidb", + item: "get-token-duration", + tbl: "tidb_get_token_duration", + condition: "quantile=0.999", + threshold: 0.001, + factor: 10e5, // the unit is microsecond + }, + { + tp: "tidb", + item: "load-schema-duration", + tbl: "tidb_load_schema_duration", + condition: "quantile=0.99", + threshold: 1, + }, + { + tp: "tikv", + item: "scheduler-cmd-duration", + tbl: "tikv_scheduler_command_duration", + condition: "quantile=0.99", + threshold: 0.1, + }, + { + tp: "tikv", + item: "handle-snapshot-duration", + tbl: "tikv_handle_snapshot_duration", + threshold: 30, + }, + { + tp: "tikv", + item: "storage-write-duration", + tbl: "tikv_storage_async_request_duration", + condition: "type='write'", + threshold: 0.1, + }, + { + tp: "tikv", + item: "storage-snapshot-duration", + tbl: "tikv_storage_async_request_duration", + condition: "type='snapshot'", + threshold: 0.05, + }, + { + tp: "tikv", + item: "rocksdb-write-duration", + tbl: "tikv_engine_write_duration", + condition: "type='write_max'", + threshold: 0.1, + factor: 10e5, // the unit is microsecond + }, + { + tp: "tikv", + item: "rocksdb-get-duration", + tbl: "tikv_engine_max_get_duration", + condition: "type='get_max'", + threshold: 0.05, + factor: 10e5, + }, + { + tp: "tikv", + item: "rocksdb-seek-duration", + tbl: "tikv_engine_max_seek_duration", + condition: "type='seek_max'", + threshold: 0.05, + factor: 10e5, // the unit is microsecond + }, + { + tp: "tikv", + item: "scheduler-pending-cmd-count", + tbl: "tikv_scheduler_pending_commands", + threshold: 1000, + detail: " %s tikv scheduler has too many pending commands", + }, + { + tp: "tikv", + item: "index-block-cache-hit", + tbl: "tikv_block_index_cache_hit", + condition: "value > 0", + threshold: 0.95, + isMin: true, + }, + { + tp: "tikv", + item: "filter-block-cache-hit", + tbl: "tikv_block_filter_cache_hit", + condition: "value > 0", + threshold: 0.95, + isMin: true, + }, + { + tp: "tikv", + item: "data-block-cache-hit", + tbl: "tikv_block_data_cache_hit", + condition: "value > 0", + threshold: 0.80, + isMin: true, + }, + } + + condition := filter.timeRange.Condition() + var results []inspectionResult + sql := new(strings.Builder) + exec := sctx.GetRestrictedSQLExecutor() + for _, rule := range rules { + if !filter.enable(rule.item) { + continue + } + cond := condition + if len(rule.condition) > 0 { + cond = fmt.Sprintf("%s and %s", cond, rule.condition) + } + if rule.factor == 0 { + rule.factor = 1 + } + sql.Reset() + if rule.isMin { + fmt.Fprintf(sql, "select instance, min(value)/%.0f as min_value from metrics_schema.%s %s group by instance having min_value < %f;", rule.factor, rule.tbl, cond, rule.threshold) + } else { + fmt.Fprintf(sql, "select instance, max(value)/%.0f 
as max_value from metrics_schema.%s %s group by instance having max_value > %f;", rule.factor, rule.tbl, cond, rule.threshold) + } + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql.String()) + if err != nil { + sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) + continue + } + for _, row := range rows { + actual := fmt.Sprintf("%.3f", row.GetFloat64(1)) + degree := math.Abs(row.GetFloat64(1)-rule.threshold) / math.Max(row.GetFloat64(1), rule.threshold) + expected := "" + if rule.isMin { + expected = fmt.Sprintf("> %.3f", rule.threshold) + } else { + expected = fmt.Sprintf("< %.3f", rule.threshold) + } + detail := rule.detail + if len(detail) == 0 { + if strings.HasSuffix(rule.item, "duration") { + detail = fmt.Sprintf("max duration of %s %s %s is too slow", row.GetString(0), rule.tp, rule.item) + } else if strings.HasSuffix(rule.item, "hit") { + detail = fmt.Sprintf("min %s rate of %s %s is too low", rule.item, row.GetString(0), rule.tp) + } + } else { + detail = fmt.Sprintf(detail, row.GetString(0)) + } + result := inspectionResult{ + tp: rule.tp, + statusAddress: row.GetString(0), + item: rule.item, + actual: actual, + expected: expected, + severity: "warning", + detail: detail, + degree: degree, + } + results = append(results, result) + } + } + return results +} + +type ruleChecker interface { + genSQL(timeRange plannerutil.QueryTimeRange) string + genResult(sql string, row chunk.Row) inspectionResult + getItem() string +} + +type compareStoreStatus struct { + item string + tp string + threshold float64 +} + +func (c compareStoreStatus) genSQL(timeRange plannerutil.QueryTimeRange) string { + condition := fmt.Sprintf(`where t1.time>='%[1]s' and t1.time<='%[2]s' and + t2.time>='%[1]s' and t2.time<='%[2]s'`, timeRange.From.Format(plannerutil.MetricTableTimeFormat), + timeRange.To.Format(plannerutil.MetricTableTimeFormat)) + return fmt.Sprintf(` + SELECT t1.address, + max(t1.value), + t2.address, + min(t2.value), + max((t1.value-t2.value)/t1.value) AS ratio + FROM metrics_schema.pd_scheduler_store_status t1 + JOIN metrics_schema.pd_scheduler_store_status t2 %s + AND t1.type='%s' + AND t1.time = t2.time + AND t1.type=t2.type + AND t1.address != t2.address + AND (t1.value-t2.value)/t1.value>%v + AND t1.value > 0 + GROUP BY t1.address,t2.address + ORDER BY ratio desc`, condition, c.tp, c.threshold) +} + +func (c compareStoreStatus) genResult(_ string, row chunk.Row) inspectionResult { + addr1 := row.GetString(0) + value1 := row.GetFloat64(1) + addr2 := row.GetString(2) + value2 := row.GetFloat64(3) + ratio := row.GetFloat64(4) + detail := fmt.Sprintf("%v max %s is %.2f, much more than %v min %s %.2f", addr1, c.tp, value1, addr2, c.tp, value2) + return inspectionResult{ + tp: "tikv", + instance: addr2, + item: c.item, + actual: fmt.Sprintf("%.2f%%", ratio*100), + expected: fmt.Sprintf("< %.2f%%", c.threshold*100), + severity: "warning", + detail: detail, + degree: ratio, + } +} + +func (c compareStoreStatus) getItem() string { + return c.item +} + +type checkRegionHealth struct{} + +func (checkRegionHealth) genSQL(timeRange plannerutil.QueryTimeRange) string { + condition := timeRange.Condition() + return fmt.Sprintf(`select instance, sum(value) as sum_value from metrics_schema.pd_region_health %s and + type in ('extra-peer-region-count','learner-peer-region-count','pending-peer-region-count') having sum_value>100`, condition) +} + +func (c checkRegionHealth) genResult(_ string, row chunk.Row) inspectionResult { + detail := fmt.Sprintf("the count 
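// ---- Editor's illustration (not part of the patch) ----
// The threshold rules above rank their findings by a normalized "degree" so the
// worst offenders sort first. This is a minimal, self-contained sketch of that
// computation with plain types; the function name and sample numbers are made up.
package main

import (
	"fmt"
	"math"
)

// degree is the relative distance between the observed value and its threshold,
// normalized by the larger of the two, matching the expression used by the
// threshold inspections above.
func degree(actual, threshold float64) float64 {
	return math.Abs(actual-threshold) / math.Max(actual, threshold)
}

func main() {
	// A "max" rule: the observed value should stay below the threshold.
	fmt.Printf("tso-duration: %.3f\n", degree(0.20, 0.05)) // 0.750
	// A "min" rule: the observed value should stay above the threshold.
	fmt.Printf("data-block-cache-hit: %.3f\n", degree(0.60, 0.80)) // 0.250
}
// ---- end of editor's illustration ----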
of extra-perr and learner-peer and pending-peer are %v, it means the scheduling is too frequent or too slow", row.GetFloat64(1)) + actual := fmt.Sprintf("%.2f", row.GetFloat64(1)) + degree := math.Abs(row.GetFloat64(1)-100) / math.Max(row.GetFloat64(1), 100) + return inspectionResult{ + tp: "pd", + instance: row.GetString(0), + item: c.getItem(), + actual: actual, + expected: "< 100", + severity: "warning", + detail: detail, + degree: degree, + } +} + +func (checkRegionHealth) getItem() string { + return "region-health" +} + +type checkStoreRegionTooMuch struct{} + +func (checkStoreRegionTooMuch) genSQL(timeRange plannerutil.QueryTimeRange) string { + condition := timeRange.Condition() + return fmt.Sprintf(`select address, max(value) from metrics_schema.pd_scheduler_store_status %s and type='region_count' and value > 20000 group by address`, condition) +} + +func (c checkStoreRegionTooMuch) genResult(_ string, row chunk.Row) inspectionResult { + actual := fmt.Sprintf("%.2f", row.GetFloat64(1)) + degree := math.Abs(row.GetFloat64(1)-20000) / math.Max(row.GetFloat64(1), 20000) + return inspectionResult{ + tp: "tikv", + instance: row.GetString(0), + item: c.getItem(), + actual: actual, + expected: "<= 20000", + severity: "warning", + detail: fmt.Sprintf("%s tikv has too many regions", row.GetString(0)), + degree: degree, + } +} + +func (checkStoreRegionTooMuch) getItem() string { + return "region-count" +} + +func (thresholdCheckInspection) inspectThreshold3(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { + var rules = []ruleChecker{ + compareStoreStatus{ + item: "leader-score-balance", + tp: "leader_score", + threshold: 0.05, + }, + compareStoreStatus{ + item: "region-score-balance", + tp: "region_score", + threshold: 0.05, + }, + compareStoreStatus{ + item: "store-available-balance", + tp: "store_available", + threshold: 0.2, + }, + checkRegionHealth{}, + checkStoreRegionTooMuch{}, + } + return checkRules(ctx, sctx, filter, rules) +} + +func checkRules(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter, rules []ruleChecker) []inspectionResult { + var results []inspectionResult + exec := sctx.GetRestrictedSQLExecutor() + for _, rule := range rules { + if !filter.enable(rule.getItem()) { + continue + } + sql := rule.genSQL(filter.timeRange) + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql) + if err != nil { + sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) + continue + } + for _, row := range rows { + results = append(results, rule.genResult(sql, row)) + } + } + return results +} + +func (thresholdCheckInspection) inspectForLeaderDrop(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { + condition := filter.timeRange.Condition() + threshold := 50.0 + sql := new(strings.Builder) + fmt.Fprintf(sql, `select address,min(value) as mi,max(value) as mx from metrics_schema.pd_scheduler_store_status %s and type='leader_count' group by address having mx-mi>%v`, condition, threshold) + exec := sctx.GetRestrictedSQLExecutor() + + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql.String()) + if err != nil { + sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) + return nil + } + var results []inspectionResult + for _, row := range rows { + address := row.GetString(0) + sql.Reset() + fmt.Fprintf(sql, `select time, value from metrics_schema.pd_scheduler_store_status %s and type='leader_count' and address = '%s' 
order by time`, condition, address) + var subRows []chunk.Row + subRows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql.String()) + if err != nil { + sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) + continue + } + + lastValue := float64(0) + for i, subRows := range subRows { + v := subRows.GetFloat64(1) + if i == 0 { + lastValue = v + continue + } + if lastValue-v > threshold { + level := "warning" + if v == 0 { + level = "critical" + } + results = append(results, inspectionResult{ + tp: "tikv", + instance: address, + item: "leader-drop", + actual: fmt.Sprintf("%.0f", lastValue-v), + expected: fmt.Sprintf("<= %.0f", threshold), + severity: level, + detail: fmt.Sprintf("%s tikv has too many leader-drop around time %s, leader count from %.0f drop to %.0f", address, subRows.GetTime(0), lastValue, v), + degree: lastValue - v, + }) + break + } + lastValue = v + } + } + return results +} diff --git a/pkg/executor/internal/calibrateresource/binding__failpoint_binding__.go b/pkg/executor/internal/calibrateresource/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..ee9bcd36346bf --- /dev/null +++ b/pkg/executor/internal/calibrateresource/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package calibrateresource + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/executor/internal/calibrateresource/calibrate_resource.go b/pkg/executor/internal/calibrateresource/calibrate_resource.go index 74cb61fec5846..75378d7acf087 100644 --- a/pkg/executor/internal/calibrateresource/calibrate_resource.go +++ b/pkg/executor/internal/calibrateresource/calibrate_resource.go @@ -317,7 +317,7 @@ func (e *Executor) getTiDBQuota( return 0, err } - failpoint.Inject("mockMetricsDataFilter", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockMetricsDataFilter")); _err_ == nil { ret := make([]*timePointValue, 0) for _, point := range tikvCPUs.vals { if point.tp.After(endTs) || point.tp.Before(startTs) { @@ -342,7 +342,7 @@ func (e *Executor) getTiDBQuota( ret = append(ret, point) } rus.vals = ret - }) + } quotas := make([]float64, 0) lowCount := 0 for { @@ -505,11 +505,11 @@ func staticCalibrateTpch10(req *chunk.Chunk, clusterInfo []infoschema.ServerInfo func getTiDBTotalCPUQuota(clusterInfo []infoschema.ServerInfo) (float64, error) { cpuQuota := float64(runtime.GOMAXPROCS(0)) - failpoint.Inject("mockGOMAXPROCS", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockGOMAXPROCS")); _err_ == nil { if val != nil { cpuQuota = float64(val.(int)) } - }) + } instanceNum := count(clusterInfo, serverTypeTiDB) return cpuQuota * float64(instanceNum), nil } @@ -662,7 +662,7 @@ func fetchStoreMetrics(serversInfo []infoschema.ServerInfo, serverType string, o return err } var resp *http.Response - failpoint.Inject("mockMetricsResponse", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockMetricsResponse")); _err_ == nil { if val != nil { data, _ := base64.StdEncoding.DecodeString(val.(string)) resp = &http.Response{ @@ -672,7 +672,7 @@ func fetchStoreMetrics(serversInfo []infoschema.ServerInfo, serverType string, o }, } } - }) + } if resp == nil { var err1 error // ignore false positive go line, can't use defer here because 
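// ---- Editor's illustration (not part of the patch) ----
// A self-contained sketch of the leader-drop scan in inspectForLeaderDrop above:
// walk one store's leader_count series in time order and flag the first drop
// larger than the threshold, escalating to critical when the count reaches zero.
// The types and sample data are simplified stand-ins for the metrics rows.
package main

import "fmt"

type leaderPoint struct {
	ts    string
	count float64
}

func findLeaderDrop(points []leaderPoint, threshold float64) (detail, severity string, found bool) {
	if len(points) == 0 {
		return "", "", false
	}
	last := points[0].count
	for _, p := range points[1:] {
		if last-p.count > threshold {
			severity = "warning"
			if p.count == 0 {
				severity = "critical"
			}
			detail = fmt.Sprintf("leader count dropped from %.0f to %.0f around %s", last, p.count, p.ts)
			return detail, severity, true
		}
		last = p.count
	}
	return "", "", false
}

func main() {
	series := []leaderPoint{{"10:00", 120}, {"10:01", 118}, {"10:02", 40}}
	if detail, severity, ok := findLeaderDrop(series, 50); ok {
		fmt.Println(severity+":", detail)
	}
}
// ---- end of editor's illustration ----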
it's in a loop. diff --git a/pkg/executor/internal/calibrateresource/calibrate_resource.go__failpoint_stash__ b/pkg/executor/internal/calibrateresource/calibrate_resource.go__failpoint_stash__ new file mode 100644 index 0000000000000..74cb61fec5846 --- /dev/null +++ b/pkg/executor/internal/calibrateresource/calibrate_resource.go__failpoint_stash__ @@ -0,0 +1,704 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package calibrateresource + +import ( + "bufio" + "context" + "encoding/base64" + "fmt" + "io" + "math" + "net/http" + "runtime" + "sort" + "strconv" + "strings" + "time" + + "github.com/docker/go-units" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/duration" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/sessiontxn/staleread" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/sqlexec" + "github.com/tikv/client-go/v2/oracle" + resourceControlClient "github.com/tikv/pd/client/resource_group/controller" +) + +var ( + // workloadBaseRUCostMap contains the base resource cost rate per 1 kv cpu within 1 second, + // the data is calculated from benchmark result, these data might not be very accurate, + // but is enough here because the maximum RU capacity is depended on both the cluster and + // the workload. + workloadBaseRUCostMap = map[ast.CalibrateResourceType]*baseResourceCost{ + ast.TPCC: { + tidbToKVCPURatio: 0.6, + kvCPU: 0.15, + readBytes: units.MiB / 2, + writeBytes: units.MiB, + readReqCount: 300, + writeReqCount: 1750, + }, + ast.OLTPREADWRITE: { + tidbToKVCPURatio: 1.25, + kvCPU: 0.35, + readBytes: units.MiB * 4.25, + writeBytes: units.MiB / 3, + readReqCount: 1600, + writeReqCount: 1400, + }, + ast.OLTPREADONLY: { + tidbToKVCPURatio: 2, + kvCPU: 0.52, + readBytes: units.MiB * 28, + writeBytes: 0, + readReqCount: 4500, + writeReqCount: 0, + }, + ast.OLTPWRITEONLY: { + tidbToKVCPURatio: 1, + kvCPU: 0, + readBytes: 0, + writeBytes: units.MiB, + readReqCount: 0, + writeReqCount: 3550, + }, + } +) + +const ( + // serverTypeTiDB is tidb's instance type name + serverTypeTiDB = "tidb" + // serverTypeTiKV is tikv's instance type name + serverTypeTiKV = "tikv" + // serverTypeTiFlash is tiflash's instance type name + serverTypeTiFlash = "tiflash" +) + +// the resource cost rate of a specified workload per 1 tikv cpu. +type baseResourceCost struct { + // represents the average ratio of TiDB CPU time to TiKV CPU time, this is used to calculate whether tikv cpu + // or tidb cpu is the performance bottle neck. 
+ tidbToKVCPURatio float64 + // the kv CPU time for calculate RU, it's smaller than the actual cpu usage. The unit is seconds. + kvCPU float64 + // the read bytes rate per 1 tikv cpu. + readBytes uint64 + // the write bytes rate per 1 tikv cpu. + writeBytes uint64 + // the average tikv read request count per 1 tikv cpu. + readReqCount uint64 + // the average tikv write request count per 1 tikv cpu. + writeReqCount uint64 +} + +const ( + // valuableUsageThreshold is the threshold used to determine whether the CPU is high enough. + // The sampling point is available when the CPU utilization of tikv or tidb is higher than the valuableUsageThreshold. + valuableUsageThreshold = 0.2 + // lowUsageThreshold is the threshold used to determine whether the CPU is too low. + // When the CPU utilization of tikv or tidb is lower than lowUsageThreshold, but neither is higher than valuableUsageThreshold, the sampling point is unavailable + lowUsageThreshold = 0.1 + // For quotas computed at each point in time, the maximum and minimum portions are discarded, and discardRate is the percentage discarded + discardRate = 0.1 + + // duration Indicates the supported calibration duration + maxDuration = time.Hour * 24 + minDuration = time.Minute +) + +// Executor is used as executor of calibrate resource. +type Executor struct { + OptionList []*ast.DynamicCalibrateResourceOption + exec.BaseExecutor + WorkloadType ast.CalibrateResourceType + done bool +} + +func (e *Executor) parseTsExpr(ctx context.Context, tsExpr ast.ExprNode) (time.Time, error) { + ts, err := staleread.CalculateAsOfTsExpr(ctx, e.Ctx().GetPlanCtx(), tsExpr) + if err != nil { + return time.Time{}, err + } + return oracle.GetTimeFromTS(ts), nil +} + +func (e *Executor) parseCalibrateDuration(ctx context.Context) (startTime time.Time, endTime time.Time, err error) { + var dur time.Duration + // startTimeExpr and endTimeExpr are used to calc endTime by FuncCallExpr when duration begin with `interval`. + var startTimeExpr ast.ExprNode + var endTimeExpr ast.ExprNode + for _, op := range e.OptionList { + switch op.Tp { + case ast.CalibrateStartTime: + startTimeExpr = op.Ts + startTime, err = e.parseTsExpr(ctx, startTimeExpr) + if err != nil { + return + } + case ast.CalibrateEndTime: + endTimeExpr = op.Ts + endTime, err = e.parseTsExpr(ctx, op.Ts) + if err != nil { + return + } + } + } + for _, op := range e.OptionList { + if op.Tp != ast.CalibrateDuration { + continue + } + // string duration + if len(op.StrValue) > 0 { + dur, err = duration.ParseDuration(op.StrValue) + if err != nil { + return + } + // If startTime is not set, startTime will be now() - duration. + if startTime.IsZero() { + toTime := endTime + if toTime.IsZero() { + toTime = time.Now() + } + startTime = toTime.Add(-dur) + } + // If endTime is set, duration will be ignored. + if endTime.IsZero() { + endTime = startTime.Add(dur) + } + continue + } + // interval duration + // If startTime is not set, startTime will be now() - duration. + if startTimeExpr == nil { + toTimeExpr := endTimeExpr + if endTime.IsZero() { + toTimeExpr = &ast.FuncCallExpr{FnName: model.NewCIStr("CURRENT_TIMESTAMP")} + } + startTimeExpr = &ast.FuncCallExpr{ + FnName: model.NewCIStr("DATE_SUB"), + Args: []ast.ExprNode{ + toTimeExpr, + op.Ts, + &ast.TimeUnitExpr{Unit: op.Unit}}, + } + startTime, err = e.parseTsExpr(ctx, startTimeExpr) + if err != nil { + return + } + } + // If endTime is set, duration will be ignored. 
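// ---- Editor's illustration (not part of the patch) ----
// A minimal sketch of the window validation that parseCalibrateDuration applies a
// few lines below, once the start/end options are resolved: at least one minute,
// at most one day plus a small buffer. The constants and function name here are
// illustrative, not the real identifiers.
package main

import (
	"errors"
	"fmt"
	"time"
)

const (
	minCalibrateWindow = time.Minute
	maxCalibrateWindow = 24 * time.Hour
)

func checkCalibrateWindow(start, end time.Time) error {
	if start.IsZero() {
		return errors.New("start time should not be 0")
	}
	if end.IsZero() {
		end = time.Now()
	}
	dur := end.Sub(start)
	// Allow a one-minute buffer for a slightly enlarged duration.
	if dur > maxCalibrateWindow+time.Minute {
		return fmt.Errorf("calibration window %s is too long", dur)
	}
	if dur < minCalibrateWindow {
		return fmt.Errorf("calibration window %s is too short", dur)
	}
	return nil
}

func main() {
	now := time.Now()
	fmt.Println(checkCalibrateWindow(now.Add(-10*time.Minute), now)) // <nil>
	fmt.Println(checkCalibrateWindow(now.Add(-25*time.Hour), now))   // too long
}
// ---- end of editor's illustration ----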
+ if endTime.IsZero() { + endTime, err = e.parseTsExpr(ctx, &ast.FuncCallExpr{ + FnName: model.NewCIStr("DATE_ADD"), + Args: []ast.ExprNode{startTimeExpr, + op.Ts, + &ast.TimeUnitExpr{Unit: op.Unit}}, + }) + if err != nil { + return + } + } + } + + if startTime.IsZero() { + err = errors.Errorf("start time should not be 0") + return + } + if endTime.IsZero() { + endTime = time.Now() + } + // check the duration + dur = endTime.Sub(startTime) + // add the buffer duration + if dur > maxDuration+time.Minute { + err = errors.Errorf("the duration of calibration is too long, which could lead to inaccurate output. Please make the duration between %s and %s", minDuration.String(), maxDuration.String()) + return + } + // We only need to consider the case where the duration is slightly enlarged. + if dur < minDuration { + err = errors.Errorf("the duration of calibration is too short, which could lead to inaccurate output. Please make the duration between %s and %s", minDuration.String(), maxDuration.String()) + } + return +} + +// Next implements the interface of Executor. +func (e *Executor) Next(ctx context.Context, req *chunk.Chunk) error { + req.Reset() + if e.done { + return nil + } + e.done = true + if !variable.EnableResourceControl.Load() { + return infoschema.ErrResourceGroupSupportDisabled + } + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnOthers) + if len(e.OptionList) > 0 { + return e.dynamicCalibrate(ctx, req) + } + return e.staticCalibrate(req) +} + +var ( + errLowUsage = errors.Errorf("The workload in selected time window is too low, with which TiDB is unable to reach a capacity estimation; please select another time window with higher workload, or calibrate resource by hardware instead") + errNoCPUQuotaMetrics = errors.Normalize("There is no CPU quota metrics, %v") +) + +func (e *Executor) dynamicCalibrate(ctx context.Context, req *chunk.Chunk) error { + exec := e.Ctx().GetRestrictedSQLExecutor() + startTs, endTs, err := e.parseCalibrateDuration(ctx) + if err != nil { + return err + } + clusterInfo, err := infoschema.GetClusterServerInfo(e.Ctx()) + if err != nil { + return err + } + tidbQuota, err1 := e.getTiDBQuota(ctx, exec, clusterInfo, startTs, endTs) + tiflashQuota, err2 := e.getTiFlashQuota(ctx, exec, clusterInfo, startTs, endTs) + if err1 != nil && err2 != nil { + return err1 + } + + req.AppendUint64(0, uint64(tidbQuota+tiflashQuota)) + return nil +} + +func (e *Executor) getTiDBQuota( + ctx context.Context, + exec sqlexec.RestrictedSQLExecutor, + serverInfos []infoschema.ServerInfo, + startTs, endTs time.Time, +) (float64, error) { + startTime := startTs.In(e.Ctx().GetSessionVars().Location()).Format(time.DateTime) + endTime := endTs.In(e.Ctx().GetSessionVars().Location()).Format(time.DateTime) + + totalKVCPUQuota, err := getTiKVTotalCPUQuota(serverInfos) + if err != nil { + return 0, errNoCPUQuotaMetrics.FastGenByArgs(err.Error()) + } + totalTiDBCPU, err := getTiDBTotalCPUQuota(serverInfos) + if err != nil { + return 0, errNoCPUQuotaMetrics.FastGenByArgs(err.Error()) + } + rus, err := getRUPerSec(ctx, e.Ctx(), exec, startTime, endTime) + if err != nil { + return 0, err + } + tikvCPUs, err := getComponentCPUUsagePerSec(ctx, e.Ctx(), exec, "tikv", startTime, endTime) + if err != nil { + return 0, err + } + tidbCPUs, err := getComponentCPUUsagePerSec(ctx, e.Ctx(), exec, "tidb", startTime, endTime) + if err != nil { + return 0, err + } + + failpoint.Inject("mockMetricsDataFilter", func() { + ret := make([]*timePointValue, 0) + for _, point := range tikvCPUs.vals { + if 
point.tp.After(endTs) || point.tp.Before(startTs) { + continue + } + ret = append(ret, point) + } + tikvCPUs.vals = ret + ret = make([]*timePointValue, 0) + for _, point := range tidbCPUs.vals { + if point.tp.After(endTs) || point.tp.Before(startTs) { + continue + } + ret = append(ret, point) + } + tidbCPUs.vals = ret + ret = make([]*timePointValue, 0) + for _, point := range rus.vals { + if point.tp.After(endTs) || point.tp.Before(startTs) { + continue + } + ret = append(ret, point) + } + rus.vals = ret + }) + quotas := make([]float64, 0) + lowCount := 0 + for { + if rus.isEnd() || tikvCPUs.isEnd() || tidbCPUs.isEnd() { + break + } + // make time point match + maxTime := rus.getTime() + if tikvCPUs.getTime().After(maxTime) { + maxTime = tikvCPUs.getTime() + } + if tidbCPUs.getTime().After(maxTime) { + maxTime = tidbCPUs.getTime() + } + if !rus.advance(maxTime) || !tikvCPUs.advance(maxTime) || !tidbCPUs.advance(maxTime) { + continue + } + tikvQuota, tidbQuota := tikvCPUs.getValue()/totalKVCPUQuota, tidbCPUs.getValue()/totalTiDBCPU + // If one of the two cpu usage is greater than the `valuableUsageThreshold`, we can accept it. + // And if both are greater than the `lowUsageThreshold`, we can also accept it. + if tikvQuota > valuableUsageThreshold || tidbQuota > valuableUsageThreshold { + quotas = append(quotas, rus.getValue()/max(tikvQuota, tidbQuota)) + } else if tikvQuota < lowUsageThreshold || tidbQuota < lowUsageThreshold { + lowCount++ + } else { + quotas = append(quotas, rus.getValue()/max(tikvQuota, tidbQuota)) + } + rus.next() + tidbCPUs.next() + tikvCPUs.next() + } + quota, err := setupQuotas(quotas) + if err != nil { + return 0, err + } + return quota, nil +} + +func setupQuotas(quotas []float64) (float64, error) { + if len(quotas) < 2 { + return 0, errLowUsage + } + sort.Slice(quotas, func(i, j int) bool { + return quotas[i] > quotas[j] + }) + lowerBound := int(math.Round(float64(len(quotas)) * discardRate)) + upperBound := len(quotas) - lowerBound + sum := 0. 
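// ---- Editor's illustration (not part of the patch) ----
// A self-contained sketch of how getTiDBQuota samples per-second quotas and how
// setupQuotas (the function this hunk is in) turns them into one estimate with a
// 10% trimmed mean. Plain float64 slices stand in for the aligned time series;
// the thresholds mirror the constants defined earlier in this file.
package main

import (
	"errors"
	"fmt"
	"math"
	"sort"
)

const (
	valuableUsage = 0.2 // accept a sample when either CPU ratio exceeds this
	lowUsage      = 0.1 // discard a sample when either CPU ratio falls below this
	discardRate   = 0.1 // fraction trimmed from each end before averaging
)

// estimateQuota expects ru, tikvRatio and tidbRatio to be aligned samples of equal length.
func estimateQuota(ru, tikvRatio, tidbRatio []float64) (float64, error) {
	quotas := make([]float64, 0, len(ru))
	for i := range ru {
		kv, db := tikvRatio[i], tidbRatio[i]
		switch {
		case kv > valuableUsage || db > valuableUsage:
			quotas = append(quotas, ru[i]/math.Max(kv, db))
		case kv < lowUsage || db < lowUsage:
			// the cluster was nearly idle here; the sample is not representative
		default:
			quotas = append(quotas, ru[i]/math.Max(kv, db))
		}
	}
	if len(quotas) < 2 {
		return 0, errors.New("workload too low to estimate capacity")
	}
	sort.Sort(sort.Reverse(sort.Float64Slice(quotas)))
	lower := int(math.Round(float64(len(quotas)) * discardRate))
	upper := len(quotas) - lower
	sum := 0.0
	for _, q := range quotas[lower:upper] {
		sum += q
	}
	return sum / float64(upper-lower), nil
}

func main() {
	ru := []float64{800, 900, 1000, 50}
	kv := []float64{0.40, 0.45, 0.50, 0.05}
	db := []float64{0.30, 0.30, 0.35, 0.02}
	fmt.Println(estimateQuota(ru, kv, db)) // 2000 <nil>
}
// ---- end of editor's illustration ----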
+ for i := lowerBound; i < upperBound; i++ { + sum += quotas[i] + } + return sum / float64(upperBound-lowerBound), nil +} + +func (e *Executor) getTiFlashQuota( + ctx context.Context, + exec sqlexec.RestrictedSQLExecutor, + serverInfos []infoschema.ServerInfo, + startTs, endTs time.Time, +) (float64, error) { + startTime := startTs.In(e.Ctx().GetSessionVars().Location()).Format(time.DateTime) + endTime := endTs.In(e.Ctx().GetSessionVars().Location()).Format(time.DateTime) + + quotas := make([]float64, 0) + totalTiFlashLogicalCores, err := getTiFlashLogicalCores(serverInfos) + if err != nil { + return 0, errNoCPUQuotaMetrics.FastGenByArgs(err.Error()) + } + tiflashCPUs, err := getTiFlashCPUUsagePerSec(ctx, e.Ctx(), exec, startTime, endTime) + if err != nil { + return 0, err + } + tiflashRUs, err := getTiFlashRUPerSec(ctx, e.Ctx(), exec, startTime, endTime) + if err != nil { + return 0, err + } + for { + if tiflashRUs.isEnd() || tiflashCPUs.isEnd() { + break + } + // make time point match + maxTime := tiflashRUs.getTime() + if tiflashCPUs.getTime().After(maxTime) { + maxTime = tiflashCPUs.getTime() + } + if !tiflashRUs.advance(maxTime) || !tiflashCPUs.advance(maxTime) { + continue + } + tiflashQuota := tiflashCPUs.getValue() / totalTiFlashLogicalCores + if tiflashQuota > lowUsageThreshold { + quotas = append(quotas, tiflashRUs.getValue()/tiflashQuota) + } + tiflashRUs.next() + tiflashCPUs.next() + } + return setupQuotas(quotas) +} + +func (e *Executor) staticCalibrate(req *chunk.Chunk) error { + resourceGroupCtl := domain.GetDomain(e.Ctx()).ResourceGroupsController() + // first fetch the ru settings config. + if resourceGroupCtl == nil { + return errors.New("resource group controller is not initialized") + } + clusterInfo, err := infoschema.GetClusterServerInfo(e.Ctx()) + if err != nil { + return err + } + ruCfg := resourceGroupCtl.GetConfig() + if e.WorkloadType == ast.TPCH10 { + return staticCalibrateTpch10(req, clusterInfo, ruCfg) + } + + totalKVCPUQuota, err := getTiKVTotalCPUQuota(clusterInfo) + if err != nil { + return errNoCPUQuotaMetrics.FastGenByArgs(err.Error()) + } + totalTiDBCPUQuota, err := getTiDBTotalCPUQuota(clusterInfo) + if err != nil { + return errNoCPUQuotaMetrics.FastGenByArgs(err.Error()) + } + + // The default workload to calculate the RU capacity. + if e.WorkloadType == ast.WorkloadNone { + e.WorkloadType = ast.TPCC + } + baseCost, ok := workloadBaseRUCostMap[e.WorkloadType] + if !ok { + return errors.Errorf("unknown workload '%T'", e.WorkloadType) + } + + if totalTiDBCPUQuota/baseCost.tidbToKVCPURatio < totalKVCPUQuota { + totalKVCPUQuota = totalTiDBCPUQuota / baseCost.tidbToKVCPURatio + } + ruPerKVCPU := float64(ruCfg.ReadBaseCost)*float64(baseCost.readReqCount) + + float64(ruCfg.CPUMsCost)*baseCost.kvCPU*1000 + // convert to ms + float64(ruCfg.ReadBytesCost)*float64(baseCost.readBytes) + + float64(ruCfg.WriteBaseCost)*float64(baseCost.writeReqCount) + + float64(ruCfg.WriteBytesCost)*float64(baseCost.writeBytes) + quota := totalKVCPUQuota * ruPerKVCPU + req.AppendUint64(0, uint64(quota)) + return nil +} + +func staticCalibrateTpch10(req *chunk.Chunk, clusterInfo []infoschema.ServerInfo, ruCfg *resourceControlClient.RUConfig) error { + // TPCH10 only considers the resource usage of the TiFlash including cpu and read bytes. Others are ignored. 
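// ---- Editor's illustration (not part of the patch) ----
// A sketch of the arithmetic in staticCalibrate above: pick the bottleneck between
// the TiKV CPU quota and the TiDB CPU quota scaled by the workload ratio, then
// multiply by the RU generated per KV core per second. The workload numbers are
// taken from the TPCC row of workloadBaseRUCostMap; the RU unit costs below are
// placeholders, not the controller's real defaults.
package main

import "fmt"

type ruCosts struct {
	readBaseCost, readBytesCost, writeBaseCost, writeBytesCost, cpuMsCost float64
}

// workloadCost stands in for one workloadBaseRUCostMap entry.
type workloadCost struct {
	tidbToKVCPURatio float64
	kvCPU            float64 // seconds of KV CPU billed per core per second
	readBytes        float64
	writeBytes       float64
	readReqCount     float64
	writeReqCount    float64
}

func estimateRUCapacity(kvCPU, tidbCPU float64, w workloadCost, c ruCosts) float64 {
	if tidbCPU/w.tidbToKVCPURatio < kvCPU {
		kvCPU = tidbCPU / w.tidbToKVCPURatio // TiDB CPU is the bottleneck
	}
	ruPerKVCPU := c.readBaseCost*w.readReqCount +
		c.cpuMsCost*w.kvCPU*1000 + // seconds -> milliseconds
		c.readBytesCost*w.readBytes +
		c.writeBaseCost*w.writeReqCount +
		c.writeBytesCost*w.writeBytes
	return kvCPU * ruPerKVCPU
}

func main() {
	tpcc := workloadCost{tidbToKVCPURatio: 0.6, kvCPU: 0.15, readBytes: 512 * 1024, writeBytes: 1024 * 1024, readReqCount: 300, writeReqCount: 1750}
	cfg := ruCosts{readBaseCost: 0.25, readBytesCost: 1e-7, writeBaseCost: 1, writeBytesCost: 1e-6, cpuMsCost: 0.33}
	fmt.Printf("%.0f RU/s\n", estimateRUCapacity(32, 16, tpcc, cfg))
}
// ---- end of editor's illustration ----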
+ // cpu usage: 105494.666484 / 20 / 20 = 263.74 + // read bytes: 401799161689.0 / 20 / 20 = 1004497904.22 + const cpuTimePerCPUPerSec float64 = 263.74 + const readBytesPerCPUPerSec float64 = 1004497904.22 + ruPerCPU := float64(ruCfg.CPUMsCost)*cpuTimePerCPUPerSec + float64(ruCfg.ReadBytesCost)*readBytesPerCPUPerSec + totalTiFlashLogicalCores, err := getTiFlashLogicalCores(clusterInfo) + if err != nil { + return err + } + quota := totalTiFlashLogicalCores * ruPerCPU + req.AppendUint64(0, uint64(quota)) + return nil +} + +func getTiDBTotalCPUQuota(clusterInfo []infoschema.ServerInfo) (float64, error) { + cpuQuota := float64(runtime.GOMAXPROCS(0)) + failpoint.Inject("mockGOMAXPROCS", func(val failpoint.Value) { + if val != nil { + cpuQuota = float64(val.(int)) + } + }) + instanceNum := count(clusterInfo, serverTypeTiDB) + return cpuQuota * float64(instanceNum), nil +} + +func getTiKVTotalCPUQuota(clusterInfo []infoschema.ServerInfo) (float64, error) { + instanceNum := count(clusterInfo, serverTypeTiKV) + if instanceNum == 0 { + return 0.0, errors.New("no server with type 'tikv' is found") + } + cpuQuota, err := fetchServerCPUQuota(clusterInfo, serverTypeTiKV, "tikv_server_cpu_cores_quota") + if err != nil { + return 0.0, err + } + return cpuQuota * float64(instanceNum), nil +} + +func getTiFlashLogicalCores(clusterInfo []infoschema.ServerInfo) (float64, error) { + instanceNum := count(clusterInfo, serverTypeTiFlash) + if instanceNum == 0 { + return 0.0, nil + } + cpuQuota, err := fetchServerCPUQuota(clusterInfo, serverTypeTiFlash, "tiflash_proxy_tikv_server_cpu_cores_quota") + if err != nil { + return 0.0, err + } + return cpuQuota * float64(instanceNum), nil +} + +func getTiFlashRUPerSec(ctx context.Context, sctx sessionctx.Context, exec sqlexec.RestrictedSQLExecutor, startTime, endTime string) (*timeSeriesValues, error) { + query := fmt.Sprintf("SELECT time, value FROM METRICS_SCHEMA.tiflash_resource_manager_resource_unit where time >= '%s' and time <= '%s' ORDER BY time asc", startTime, endTime) + return getValuesFromMetrics(ctx, sctx, exec, query) +} + +func getTiFlashCPUUsagePerSec(ctx context.Context, sctx sessionctx.Context, exec sqlexec.RestrictedSQLExecutor, startTime, endTime string) (*timeSeriesValues, error) { + query := fmt.Sprintf("SELECT time, sum(value) FROM METRICS_SCHEMA.tiflash_process_cpu_usage where time >= '%s' and time <= '%s' and job = 'tiflash' GROUP BY time ORDER BY time asc", startTime, endTime) + return getValuesFromMetrics(ctx, sctx, exec, query) +} + +type timePointValue struct { + tp time.Time + val float64 +} + +type timeSeriesValues struct { + vals []*timePointValue + idx int +} + +func (t *timeSeriesValues) isEnd() bool { + return t.idx >= len(t.vals) +} + +func (t *timeSeriesValues) next() { + t.idx++ +} + +func (t *timeSeriesValues) getTime() time.Time { + return t.vals[t.idx].tp +} + +func (t *timeSeriesValues) getValue() float64 { + return t.vals[t.idx].val +} + +func (t *timeSeriesValues) advance(target time.Time) bool { + for ; t.idx < len(t.vals); t.idx++ { + // `target` is maximal time in other timeSeriesValues, + // so we should find the time which offset is less than 10s. 
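// ---- Editor's illustration (not part of the patch) ----
// A self-contained sketch of the alignment rule implemented by
// timeSeriesValues.advance just below: skip samples more than ten seconds older
// than the target, and accept the next sample only if it lies within ten seconds
// of the target on either side. Function and variable names are illustrative.
package main

import (
	"fmt"
	"time"
)

func alignTo(points []time.Time, idx int, target time.Time) (int, bool) {
	for ; idx < len(points); idx++ {
		if points[idx].Add(10 * time.Second).After(target) {
			return idx, points[idx].Add(-10 * time.Second).Before(target)
		}
	}
	return idx, false
}

func main() {
	base := time.Date(2024, 8, 7, 12, 0, 0, 0, time.UTC)
	points := []time.Time{base, base.Add(15 * time.Second), base.Add(30 * time.Second)}
	idx, ok := alignTo(points, 0, base.Add(20*time.Second))
	fmt.Println(idx, ok) // 1 true: the 15s sample is within 10s of the 20s target
}
// ---- end of editor's illustration ----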
+ if t.vals[t.idx].tp.Add(time.Second * 10).After(target) { + return t.vals[t.idx].tp.Add(-time.Second * 10).Before(target) + } + } + return false +} + +func getRUPerSec(ctx context.Context, sctx sessionctx.Context, exec sqlexec.RestrictedSQLExecutor, startTime, endTime string) (*timeSeriesValues, error) { + query := fmt.Sprintf("SELECT time, value FROM METRICS_SCHEMA.resource_manager_resource_unit where time >= '%s' and time <= '%s' ORDER BY time asc", startTime, endTime) + return getValuesFromMetrics(ctx, sctx, exec, query) +} + +func getComponentCPUUsagePerSec(ctx context.Context, sctx sessionctx.Context, exec sqlexec.RestrictedSQLExecutor, component, startTime, endTime string) (*timeSeriesValues, error) { + query := fmt.Sprintf("SELECT time, sum(value) FROM METRICS_SCHEMA.process_cpu_usage where time >= '%s' and time <= '%s' and job like '%%%s' GROUP BY time ORDER BY time asc", startTime, endTime, component) + return getValuesFromMetrics(ctx, sctx, exec, query) +} + +func getValuesFromMetrics(ctx context.Context, sctx sessionctx.Context, exec sqlexec.RestrictedSQLExecutor, query string) (*timeSeriesValues, error) { + rows, _, err := exec.ExecRestrictedSQL(ctx, []sqlexec.OptionFuncAlias{sqlexec.ExecOptionUseCurSession}, query) + if err != nil { + return nil, errors.Trace(err) + } + ret := make([]*timePointValue, 0, len(rows)) + for _, row := range rows { + if tp, err := row.GetTime(0).AdjustedGoTime(sctx.GetSessionVars().Location()); err == nil { + ret = append(ret, &timePointValue{ + tp: tp, + val: row.GetFloat64(1), + }) + } + } + return &timeSeriesValues{idx: 0, vals: ret}, nil +} + +func count(clusterInfo []infoschema.ServerInfo, ty string) int { + num := 0 + for _, e := range clusterInfo { + if e.ServerType == ty { + num++ + } + } + return num +} + +func fetchServerCPUQuota(serverInfos []infoschema.ServerInfo, serverType string, metricName string) (float64, error) { + var cpuQuota float64 + err := fetchStoreMetrics(serverInfos, serverType, func(addr string, resp *http.Response) error { + if resp.StatusCode != http.StatusOK { + return errors.Errorf("request %s failed: %s", addr, resp.Status) + } + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, metricName) { + continue + } + // the metrics format is like following: + // tikv_server_cpu_cores_quota 8 + quota, err := strconv.ParseFloat(line[len(metricName)+1:], 64) + if err == nil { + cpuQuota = quota + } + return errors.Trace(err) + } + return errors.Errorf("metrics '%s' not found from server '%s'", metricName, addr) + }) + return cpuQuota, err +} + +func fetchStoreMetrics(serversInfo []infoschema.ServerInfo, serverType string, onResp func(string, *http.Response) error) error { + var firstErr error + for _, srv := range serversInfo { + if srv.ServerType != serverType { + continue + } + if len(srv.StatusAddr) == 0 { + continue + } + url := fmt.Sprintf("%s://%s/metrics", util.InternalHTTPSchema(), srv.StatusAddr) + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return err + } + var resp *http.Response + failpoint.Inject("mockMetricsResponse", func(val failpoint.Value) { + if val != nil { + data, _ := base64.StdEncoding.DecodeString(val.(string)) + resp = &http.Response{ + StatusCode: http.StatusOK, + Body: noopCloserWrapper{ + Reader: strings.NewReader(string(data)), + }, + } + } + }) + if resp == nil { + var err1 error + // ignore false positive go line, can't use defer here because it's in a loop. 
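// ---- Editor's illustration (not part of the patch) ----
// A sketch of the metric scan in fetchServerCPUQuota above: the /metrics payload is
// read line by line and the value after the metric name is parsed. The payload
// below is a made-up example of that text format.
package main

import (
	"bufio"
	"fmt"
	"strconv"
	"strings"
)

func parseCPUQuota(metrics, metricName string) (float64, error) {
	scanner := bufio.NewScanner(strings.NewReader(metrics))
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, metricName) {
			continue
		}
		// e.g. "tikv_server_cpu_cores_quota 8"
		return strconv.ParseFloat(line[len(metricName)+1:], 64)
	}
	return 0, fmt.Errorf("metric %q not found", metricName)
}

func main() {
	payload := "# HELP tikv_server_cpu_cores_quota total cpu cores quota\ntikv_server_cpu_cores_quota 8\n"
	fmt.Println(parseCPUQuota(payload, "tikv_server_cpu_cores_quota")) // 8 <nil>
}
// ---- end of editor's illustration ----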
+ //nolint:bodyclose + resp, err1 = util.InternalHTTPClient().Do(req) + if err1 != nil { + if firstErr == nil { + firstErr = err1 + } + continue + } + } + err = onResp(srv.Address, resp) + resp.Body.Close() + return err + } + if firstErr == nil { + firstErr = errors.Errorf("no server with type '%s' is found", serverType) + } + return firstErr +} + +type noopCloserWrapper struct { + io.Reader +} + +func (noopCloserWrapper) Close() error { + return nil +} diff --git a/pkg/executor/internal/exec/binding__failpoint_binding__.go b/pkg/executor/internal/exec/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..7d0ade518b1e1 --- /dev/null +++ b/pkg/executor/internal/exec/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package exec + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/executor/internal/exec/executor.go b/pkg/executor/internal/exec/executor.go index 6feb131b67306..74f6260ac652c 100644 --- a/pkg/executor/internal/exec/executor.go +++ b/pkg/executor/internal/exec/executor.go @@ -97,9 +97,9 @@ func newExecutorChunkAllocator(vars *variable.SessionVars, retFieldTypes []*type // InitCap returns the initial capacity for chunk func (e *executorChunkAllocator) InitCap() int { - failpoint.Inject("initCap", func(val failpoint.Value) { - failpoint.Return(val.(int)) - }) + if val, _err_ := failpoint.Eval(_curpkg_("initCap")); _err_ == nil { + return val.(int) + } return e.initCap } @@ -110,9 +110,9 @@ func (e *executorChunkAllocator) SetInitCap(c int) { // MaxChunkSize returns the max chunk size. func (e *executorChunkAllocator) MaxChunkSize() int { - failpoint.Inject("maxChunkSize", func(val failpoint.Value) { - failpoint.Return(val.(int)) - }) + if val, _err_ := failpoint.Eval(_curpkg_("maxChunkSize")); _err_ == nil { + return val.(int) + } return e.maxChunkSize } diff --git a/pkg/executor/internal/exec/executor.go__failpoint_stash__ b/pkg/executor/internal/exec/executor.go__failpoint_stash__ new file mode 100644 index 0000000000000..6feb131b67306 --- /dev/null +++ b/pkg/executor/internal/exec/executor.go__failpoint_stash__ @@ -0,0 +1,468 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
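// ---- Editor's illustration (not part of the patch) ----
// The binding__failpoint_binding__.go files and the rewritten call sites in this
// patch look like the output of enabling failpoints with the pingcap/failpoint
// toolchain (failpoint-ctl or equivalent): the Inject marker form is expanded into
// an explicit failpoint.Eval on a package-qualified name. The standalone sketch
// below shows both forms for a made-up "initCap" failpoint; it assumes the
// github.com/pingcap/failpoint module is available.
package main

import (
	"fmt"
	"reflect"

	"github.com/pingcap/failpoint"
)

type failpointBinding struct{ pkgpath string }

var binding = &failpointBinding{}

func init() {
	binding.pkgpath = reflect.TypeOf(failpointBinding{}).PkgPath()
}

// _curpkg_ prefixes a failpoint name with the package path, as the generated binding files do.
func _curpkg_(name string) string {
	return binding.pkgpath + "/" + name
}

// Marker form: what stays in version control before the rewrite.
func initCapMarker(defaultCap int) int {
	failpoint.Inject("initCap", func(val failpoint.Value) {
		failpoint.Return(val.(int))
	})
	return defaultCap
}

// Generated form: what the hunks above check in.
func initCapGenerated(defaultCap int) int {
	if val, _err_ := failpoint.Eval(_curpkg_("initCap")); _err_ == nil {
		return val.(int)
	}
	return defaultCap
}

func main() {
	_ = failpoint.Enable(_curpkg_("initCap"), "return(64)")
	fmt.Println(initCapGenerated(32)) // 64
	_ = failpoint.Disable(_curpkg_("initCap"))
	fmt.Println(initCapGenerated(32)) // 32
}
// ---- end of editor's illustration ----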
+ +package exec + +import ( + "context" + "reflect" + "time" + + "github.com/ngaut/pools" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/topsql" + topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" + "github.com/pingcap/tidb/pkg/util/tracing" + "go.uber.org/atomic" +) + +// Executor is the physical implementation of an algebra operator. +// +// In TiDB, all algebra operators are implemented as iterators, i.e., they +// support a simple Open-Next-Close protocol. See this paper for more details: +// +// "Volcano-An Extensible and Parallel Query Evaluation System" +// +// Different from Volcano's execution model, a "Next" function call in TiDB will +// return a batch of rows, other than a single row in Volcano. +// NOTE: Executors must call "chk.Reset()" before appending their results to it. +type Executor interface { + NewChunk() *chunk.Chunk + NewChunkWithCapacity(fields []*types.FieldType, capacity int, maxCachesize int) *chunk.Chunk + + RuntimeStats() *execdetails.BasicRuntimeStats + + HandleSQLKillerSignal() error + RegisterSQLAndPlanInExecForTopSQL() + + AllChildren() []Executor + SetAllChildren([]Executor) + Open(context.Context) error + Next(ctx context.Context, req *chunk.Chunk) error + + // `Close()` may be called at any time after `Open()` and it may be called with `Next()` at the same time + Close() error + Schema() *expression.Schema + RetFieldTypes() []*types.FieldType + InitCap() int + MaxChunkSize() int + + // Detach detaches the current executor from the session context without considering its children. + // + // It has to make sure, no matter whether it returns true or false, both the original executor and the returning executor + // should be able to be used correctly. + Detach() (Executor, bool) +} + +var _ Executor = &BaseExecutor{} + +// executorChunkAllocator is a helper to implement `Chunk` related methods in `Executor` interface +type executorChunkAllocator struct { + AllocPool chunk.Allocator + retFieldTypes []*types.FieldType + initCap int + maxChunkSize int +} + +// newExecutorChunkAllocator creates a new `executorChunkAllocator` +func newExecutorChunkAllocator(vars *variable.SessionVars, retFieldTypes []*types.FieldType) executorChunkAllocator { + return executorChunkAllocator{ + AllocPool: vars.GetChunkAllocator(), + initCap: vars.InitChunkSize, + maxChunkSize: vars.MaxChunkSize, + retFieldTypes: retFieldTypes, + } +} + +// InitCap returns the initial capacity for chunk +func (e *executorChunkAllocator) InitCap() int { + failpoint.Inject("initCap", func(val failpoint.Value) { + failpoint.Return(val.(int)) + }) + return e.initCap +} + +// SetInitCap sets the initial capacity for chunk +func (e *executorChunkAllocator) SetInitCap(c int) { + e.initCap = c +} + +// MaxChunkSize returns the max chunk size. +func (e *executorChunkAllocator) MaxChunkSize() int { + failpoint.Inject("maxChunkSize", func(val failpoint.Value) { + failpoint.Return(val.(int)) + }) + return e.maxChunkSize +} + +// SetMaxChunkSize sets the max chunk size. 
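// ---- Editor's illustration (not part of the patch) ----
// The Executor interface above documents the Volcano-style Open-Next-Close
// protocol, with Next returning a batch of rows instead of a single row. This
// standalone sketch shows the calling convention with simplified stand-in types
// (a plain []int instead of chunk.Chunk); it is not the real executor API.
package main

import (
	"context"
	"fmt"
)

type rowBatch = []int

type volcanoExecutor interface {
	Open(ctx context.Context) error
	Next(ctx context.Context) (rowBatch, error)
	Close() error
}

// drain opens the executor, pulls batches until an empty one signals the end,
// and always closes it.
func drain(ctx context.Context, e volcanoExecutor) (total int, err error) {
	if err = e.Open(ctx); err != nil {
		return 0, err
	}
	defer func() {
		if cerr := e.Close(); err == nil {
			err = cerr
		}
	}()
	for {
		batch, nerr := e.Next(ctx)
		if nerr != nil {
			return total, nerr
		}
		if len(batch) == 0 {
			return total, nil
		}
		total += len(batch)
	}
}

// numbers emits n rows in batches of batchSize.
type numbers struct{ n, batchSize, sent int }

func (*numbers) Open(context.Context) error { return nil }
func (e *numbers) Next(context.Context) (rowBatch, error) {
	batch := rowBatch{}
	for len(batch) < e.batchSize && e.sent < e.n {
		batch = append(batch, e.sent)
		e.sent++
	}
	return batch, nil
}
func (*numbers) Close() error { return nil }

func main() {
	total, err := drain(context.Background(), &numbers{n: 10, batchSize: 4})
	fmt.Println(total, err) // 10 <nil>
}
// ---- end of editor's illustration ----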
+func (e *executorChunkAllocator) SetMaxChunkSize(size int) { + e.maxChunkSize = size +} + +// NewChunk creates a new chunk according to the executor configuration +func (e *executorChunkAllocator) NewChunk() *chunk.Chunk { + return e.NewChunkWithCapacity(e.retFieldTypes, e.InitCap(), e.MaxChunkSize()) +} + +// NewChunkWithCapacity allows the caller to allocate the chunk with any types, capacity and max size in the pool +func (e *executorChunkAllocator) NewChunkWithCapacity(fields []*types.FieldType, capacity int, maxCachesize int) *chunk.Chunk { + return e.AllocPool.Alloc(fields, capacity, maxCachesize) +} + +// executorMeta is a helper to store metadata for an execturo and implement the getter +type executorMeta struct { + schema *expression.Schema + children []Executor + retFieldTypes []*types.FieldType + id int +} + +// newExecutorMeta creates a new `executorMeta` +func newExecutorMeta(schema *expression.Schema, id int, children ...Executor) executorMeta { + e := executorMeta{ + id: id, + schema: schema, + children: children, + } + if schema != nil { + cols := schema.Columns + e.retFieldTypes = make([]*types.FieldType, len(cols)) + for i := range cols { + e.retFieldTypes[i] = cols[i].RetType + } + } + return e +} + +// NewChunkWithCapacity allows the caller to allocate the chunk with any types, capacity and max size in the pool +func (e *executorMeta) RetFieldTypes() []*types.FieldType { + return e.retFieldTypes +} + +// ID returns the id of an executor. +func (e *executorMeta) ID() int { + return e.id +} + +// AllChildren returns all children. +func (e *executorMeta) AllChildren() []Executor { + return e.children +} + +// SetAllChildren sets the children for an executor. +func (e *executorMeta) SetAllChildren(children []Executor) { + e.children = children +} + +// ChildrenLen returns the length of children. +func (e *executorMeta) ChildrenLen() int { + return len(e.children) +} + +// EmptyChildren judges whether the children is empty. +func (e *executorMeta) EmptyChildren() bool { + return len(e.children) == 0 +} + +// SetChildren sets a child for an executor. +func (e *executorMeta) SetChildren(idx int, ex Executor) { + e.children[idx] = ex +} + +// Children returns the children for an executor. +func (e *executorMeta) Children(idx int) Executor { + return e.children[idx] +} + +// Schema returns the current BaseExecutor's schema. If it is nil, then create and return a new one. +func (e *executorMeta) Schema() *expression.Schema { + if e.schema == nil { + return expression.NewSchema() + } + return e.schema +} + +// GetSchema gets the schema. 
+func (e *executorMeta) GetSchema() *expression.Schema { + return e.schema +} + +// executorStats is a helper to implement the stats related methods for `Executor` +type executorStats struct { + runtimeStats *execdetails.BasicRuntimeStats + isSQLAndPlanRegistered *atomic.Bool + sqlDigest *parser.Digest + planDigest *parser.Digest + normalizedSQL string + normalizedPlan string + inRestrictedSQL bool +} + +// newExecutorStats creates a new `executorStats` +func newExecutorStats(stmtCtx *stmtctx.StatementContext, id int) executorStats { + normalizedSQL, sqlDigest := stmtCtx.SQLDigest() + normalizedPlan, planDigest := stmtCtx.GetPlanDigest() + e := executorStats{ + isSQLAndPlanRegistered: &stmtCtx.IsSQLAndPlanRegistered, + normalizedSQL: normalizedSQL, + sqlDigest: sqlDigest, + normalizedPlan: normalizedPlan, + planDigest: planDigest, + inRestrictedSQL: stmtCtx.InRestrictedSQL, + } + + if stmtCtx.RuntimeStatsColl != nil { + if id > 0 { + e.runtimeStats = stmtCtx.RuntimeStatsColl.GetBasicRuntimeStats(id) + } + } + + return e +} + +// RuntimeStats returns the runtime stats of an executor. +func (e *executorStats) RuntimeStats() *execdetails.BasicRuntimeStats { + return e.runtimeStats +} + +// RegisterSQLAndPlanInExecForTopSQL registers the current SQL and Plan on top sql +func (e *executorStats) RegisterSQLAndPlanInExecForTopSQL() { + if topsqlstate.TopSQLEnabled() && e.isSQLAndPlanRegistered.CompareAndSwap(false, true) { + topsql.RegisterSQL(e.normalizedSQL, e.sqlDigest, e.inRestrictedSQL) + if len(e.normalizedPlan) > 0 { + topsql.RegisterPlan(e.normalizedPlan, e.planDigest) + } + } +} + +type signalHandler interface { + HandleSignal() error +} + +// executorKillerHandler is a helper to implement the killer related methods for `Executor`. +type executorKillerHandler struct { + handler signalHandler +} + +func (e *executorKillerHandler) HandleSQLKillerSignal() error { + return e.handler.HandleSignal() +} + +func newExecutorKillerHandler(handler signalHandler) executorKillerHandler { + return executorKillerHandler{handler} +} + +// BaseExecutorV2 is a simplified version of `BaseExecutor`, which doesn't contain a full session context +type BaseExecutorV2 struct { + executorMeta + executorKillerHandler + executorStats + executorChunkAllocator +} + +// NewBaseExecutorV2 creates a new BaseExecutorV2 instance. +func NewBaseExecutorV2(vars *variable.SessionVars, schema *expression.Schema, id int, children ...Executor) BaseExecutorV2 { + executorMeta := newExecutorMeta(schema, id, children...) + e := BaseExecutorV2{ + executorMeta: executorMeta, + executorStats: newExecutorStats(vars.StmtCtx, id), + executorChunkAllocator: newExecutorChunkAllocator(vars, executorMeta.RetFieldTypes()), + executorKillerHandler: newExecutorKillerHandler(&vars.SQLKiller), + } + return e +} + +// Open initializes children recursively and "childrenResults" according to children's schemas. +func (e *BaseExecutorV2) Open(ctx context.Context) error { + for _, child := range e.children { + err := Open(ctx, child) + if err != nil { + return err + } + } + return nil +} + +// Close closes all executors and release all resources. +func (e *BaseExecutorV2) Close() error { + var firstErr error + for _, src := range e.children { + if err := Close(src); err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr +} + +// Next fills multiple rows into a chunk. +func (*BaseExecutorV2) Next(_ context.Context, _ *chunk.Chunk) error { + return nil +} + +// Detach detaches the current executor from the session context. 
+func (*BaseExecutorV2) Detach() (Executor, bool) { + return nil, false +} + +// BuildNewBaseExecutorV2 builds a new `BaseExecutorV2` based on the configuration of the current base executor. +// It's used to build a new sub-executor from an existing executor. For example, the `IndexLookUpExecutor` will use +// this function to build `TableReaderExecutor` +func (e *BaseExecutorV2) BuildNewBaseExecutorV2(stmtRuntimeStatsColl *execdetails.RuntimeStatsColl, schema *expression.Schema, id int, children ...Executor) BaseExecutorV2 { + newExecutorMeta := newExecutorMeta(schema, id, children...) + + newExecutorStats := e.executorStats + if stmtRuntimeStatsColl != nil { + if id > 0 { + newExecutorStats.runtimeStats = stmtRuntimeStatsColl.GetBasicRuntimeStats(id) + } + } + + newChunkAllocator := e.executorChunkAllocator + newChunkAllocator.retFieldTypes = newExecutorMeta.RetFieldTypes() + newE := BaseExecutorV2{ + executorMeta: newExecutorMeta, + executorStats: newExecutorStats, + executorChunkAllocator: newChunkAllocator, + executorKillerHandler: e.executorKillerHandler, + } + return newE +} + +// BaseExecutor holds common information for executors. +type BaseExecutor struct { + ctx sessionctx.Context + + BaseExecutorV2 +} + +// NewBaseExecutor creates a new BaseExecutor instance. +func NewBaseExecutor(ctx sessionctx.Context, schema *expression.Schema, id int, children ...Executor) BaseExecutor { + return BaseExecutor{ + ctx: ctx, + BaseExecutorV2: NewBaseExecutorV2(ctx.GetSessionVars(), schema, id, children...), + } +} + +// Ctx return ```sessionctx.Context``` of Executor +func (e *BaseExecutor) Ctx() sessionctx.Context { + return e.ctx +} + +// UpdateDeltaForTableID updates the delta info for the table with tableID. +func (e *BaseExecutor) UpdateDeltaForTableID(id int64) { + txnCtx := e.ctx.GetSessionVars().TxnCtx + txnCtx.UpdateDeltaForTable(id, 0, 0, nil) +} + +// GetSysSession gets a system session context from executor. +func (e *BaseExecutor) GetSysSession() (sessionctx.Context, error) { + dom := domain.GetDomain(e.Ctx()) + sysSessionPool := dom.SysSessionPool() + ctx, err := sysSessionPool.Get() + if err != nil { + return nil, err + } + restrictedCtx := ctx.(sessionctx.Context) + restrictedCtx.GetSessionVars().InRestrictedSQL = true + return restrictedCtx, nil +} + +// ReleaseSysSession releases a system session context to executor. +func (e *BaseExecutor) ReleaseSysSession(ctx context.Context, sctx sessionctx.Context) { + if sctx == nil { + return + } + dom := domain.GetDomain(e.Ctx()) + sysSessionPool := dom.SysSessionPool() + if _, err := sctx.GetSQLExecutor().ExecuteInternal(ctx, "rollback"); err != nil { + sctx.(pools.Resource).Close() + return + } + sysSessionPool.Put(sctx.(pools.Resource)) +} + +// TryNewCacheChunk tries to get a cached chunk +func TryNewCacheChunk(e Executor) *chunk.Chunk { + return e.NewChunk() +} + +// RetTypes returns all output column types. +func RetTypes(e Executor) []*types.FieldType { + return e.RetFieldTypes() +} + +// NewFirstChunk creates a new chunk to buffer current executor's result. +func NewFirstChunk(e Executor) *chunk.Chunk { + return chunk.New(e.RetFieldTypes(), e.InitCap(), e.MaxChunkSize()) +} + +// Open is a wrapper function on e.Open(), it handles some common codes. +func Open(ctx context.Context, e Executor) (err error) { + defer func() { + if r := recover(); r != nil { + err = util.GetRecoverError(r) + } + }() + return e.Open(ctx) +} + +// Next is a wrapper function on e.Next(), it handles some common codes. 
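// ---- Editor's illustration (not part of the patch) ----
// The Open/Next/Close wrappers in this file recover from panics and hand them back
// as ordinary errors. This is a minimal standalone sketch of that guard pattern;
// the real code converts the recovered value with util.GetRecoverError, which is
// approximated here with a plain error wrap.
package main

import (
	"errors"
	"fmt"
)

func guard(step func() error) (err error) {
	defer func() {
		if r := recover(); r != nil {
			if e, ok := r.(error); ok {
				err = e
				return
			}
			err = fmt.Errorf("panic recovered: %v", r)
		}
	}()
	return step()
}

func main() {
	fmt.Println(guard(func() error { panic(errors.New("boom")) })) // boom
	fmt.Println(guard(func() error { return nil }))                // <nil>
}
// ---- end of editor's illustration ----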
+func Next(ctx context.Context, e Executor, req *chunk.Chunk) (err error) { + defer func() { + if r := recover(); r != nil { + err = util.GetRecoverError(r) + } + }() + if e.RuntimeStats() != nil { + start := time.Now() + defer func() { e.RuntimeStats().Record(time.Since(start), req.NumRows()) }() + } + + if err := e.HandleSQLKillerSignal(); err != nil { + return err + } + + r, ctx := tracing.StartRegionEx(ctx, reflect.TypeOf(e).String()+".Next") + defer r.End() + + e.RegisterSQLAndPlanInExecForTopSQL() + err = e.Next(ctx, req) + + if err != nil { + return err + } + // recheck whether the session/query is killed during the Next() + return e.HandleSQLKillerSignal() +} + +// Close is a wrapper function on e.Close(), it handles some common codes. +func Close(e Executor) (err error) { + defer func() { + if r := recover(); r != nil { + err = util.GetRecoverError(r) + } + }() + return e.Close() +} diff --git a/pkg/executor/internal/mpp/binding__failpoint_binding__.go b/pkg/executor/internal/mpp/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..bb1223ce4ae04 --- /dev/null +++ b/pkg/executor/internal/mpp/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package mpp + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/executor/internal/mpp/executor_with_retry.go b/pkg/executor/internal/mpp/executor_with_retry.go index fd2bd527dc48c..4d3d6d6b6578a 100644 --- a/pkg/executor/internal/mpp/executor_with_retry.go +++ b/pkg/executor/internal/mpp/executor_with_retry.go @@ -84,11 +84,11 @@ func NewExecutorWithRetry(ctx context.Context, sctx sessionctx.Context, parentTr // 3. For cached table, will not dispatch tasks to TiFlash, so no need to recovery. enableMPPRecovery := disaggTiFlashWithAutoScaler && !allowTiFlashFallback - failpoint.Inject("mpp_recovery_test_mock_enable", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mpp_recovery_test_mock_enable")); _err_ == nil { if !allowTiFlashFallback { enableMPPRecovery = true } - }) + } recoveryHandler := NewRecoveryHandler(disaggTiFlashWithAutoScaler, uint64(holdCap), enableMPPRecovery, parentTracker) @@ -176,12 +176,12 @@ func (r *ExecutorWithRetry) nextWithRecovery(ctx context.Context) error { resp, mppErr := r.coord.Next(ctx) // Mock recovery n times. - failpoint.Inject("mpp_recovery_test_max_err_times", func(forceErrCnt failpoint.Value) { + if forceErrCnt, _err_ := failpoint.Eval(_curpkg_("mpp_recovery_test_max_err_times")); _err_ == nil { forceErrCntInt := forceErrCnt.(int) if r.mppErrRecovery.RecoveryCnt() < uint32(forceErrCntInt) { mppErr = errors.New("mock mpp error") } - }) + } if mppErr != nil { recoveryErr := r.mppErrRecovery.Recovery(&RecoveryInfo{ @@ -190,14 +190,14 @@ func (r *ExecutorWithRetry) nextWithRecovery(ctx context.Context) error { }) // Mock recovery succeed, ignore no recovery handler err. 
- failpoint.Inject("mpp_recovery_test_ignore_recovery_err", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mpp_recovery_test_ignore_recovery_err")); _err_ == nil { if recoveryErr == nil { panic("mocked mpp err should got recovery err") } if strings.Contains(mppErr.Error(), "mock mpp error") && strings.Contains(recoveryErr.Error(), "no handler to recovery") { recoveryErr = nil } - }) + } if recoveryErr != nil { logutil.BgLogger().Error("recovery mpp error failed", zap.Any("mppErr", mppErr), @@ -224,14 +224,14 @@ func (r *ExecutorWithRetry) nextWithRecovery(ctx context.Context) error { r.mppErrRecovery.HoldResult(resp.(*mppResponse)) } - failpoint.Inject("mpp_recovery_test_hold_size", func(num failpoint.Value) { + if num, _err_ := failpoint.Eval(_curpkg_("mpp_recovery_test_hold_size")); _err_ == nil { // Note: this failpoint only execute once. curRows := r.mppErrRecovery.NumHoldResp() numInt := num.(int) if curRows != numInt { panic(fmt.Sprintf("unexpected holding rows, cur: %d", curRows)) } - }) + } return nil } diff --git a/pkg/executor/internal/mpp/executor_with_retry.go__failpoint_stash__ b/pkg/executor/internal/mpp/executor_with_retry.go__failpoint_stash__ new file mode 100644 index 0000000000000..fd2bd527dc48c --- /dev/null +++ b/pkg/executor/internal/mpp/executor_with_retry.go__failpoint_stash__ @@ -0,0 +1,254 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mpp + +import ( + "context" + "fmt" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/executor/mppcoordmanager" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + plannercore "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "go.uber.org/zap" +) + +// ExecutorWithRetry receive mppResponse from localMppCoordinator, +// and tries to recovery mpp err if necessary. +// The abstraction layer of reading mpp resp: +// 1. MPPGather: As part of the TiDB Volcano model executor, it is equivalent to a TableReader. +// 2. selectResult: Decode select result(mppResponse) into chunk. Also record runtime info. +// 3. ExecutorWithRetry: Recovery mpp err if possible and retry MPP Task. +// 4. localMppCoordinator: Generate MPP fragment and dispatch MPPTask. +// And receive MPP status for better err msg and correct stats for Limit. +// 5. mppIterator: Send or receive MPP RPC. +type ExecutorWithRetry struct { + coord kv.MppCoordinator + sctx sessionctx.Context + is infoschema.InfoSchema + plan plannercore.PhysicalPlan + ctx context.Context + memTracker *memory.Tracker + // mppErrRecovery is designed for the recovery of MPP errors. + // Basic idea: + // 1. It attempts to hold the results of MPP. During the holding process, if an error occurs, it starts error recovery. 
+ // If the recovery is successful, it discards held results and reconstructs the respIter, then re-executes the MPP task. + // If the recovery fails, an error is reported directly. + // 2. If the held MPP results exceed the capacity, will starts returning results to caller. + // Once the results start being returned, error recovery cannot be performed anymore. + mppErrRecovery *RecoveryHandler + planIDs []int + // Expose to let MPPGather access. + KVRanges []kv.KeyRange + queryID kv.MPPQueryID + startTS uint64 + gatherID uint64 + nodeCnt int +} + +var _ kv.Response = &ExecutorWithRetry{} + +// NewExecutorWithRetry create ExecutorWithRetry. +func NewExecutorWithRetry(ctx context.Context, sctx sessionctx.Context, parentTracker *memory.Tracker, planIDs []int, + plan plannercore.PhysicalPlan, startTS uint64, queryID kv.MPPQueryID, + is infoschema.InfoSchema) (*ExecutorWithRetry, error) { + // TODO: After add row info in tipb.DataPacket, we can use row count as capacity. + // For now, use the number of tipb.DataPacket as capacity. + const holdCap = 2 + + disaggTiFlashWithAutoScaler := config.GetGlobalConfig().DisaggregatedTiFlash && config.GetGlobalConfig().UseAutoScaler + _, allowTiFlashFallback := sctx.GetSessionVars().AllowFallbackToTiKV[kv.TiFlash] + + // 1. For now, mpp err recovery only support MemLimit, which is only useful when AutoScaler is used. + // 2. When enable fallback to tikv, the returned mpp err will be ErrTiFlashServerTimeout, + // which we cannot handle for now. Also there is no need to recovery because tikv will retry the query. + // 3. For cached table, will not dispatch tasks to TiFlash, so no need to recovery. + enableMPPRecovery := disaggTiFlashWithAutoScaler && !allowTiFlashFallback + + failpoint.Inject("mpp_recovery_test_mock_enable", func() { + if !allowTiFlashFallback { + enableMPPRecovery = true + } + }) + + recoveryHandler := NewRecoveryHandler(disaggTiFlashWithAutoScaler, + uint64(holdCap), enableMPPRecovery, parentTracker) + memTracker := memory.NewTracker(parentTracker.Label(), 0) + memTracker.AttachTo(parentTracker) + retryer := &ExecutorWithRetry{ + ctx: ctx, + sctx: sctx, + memTracker: memTracker, + planIDs: planIDs, + is: is, + plan: plan, + startTS: startTS, + queryID: queryID, + mppErrRecovery: recoveryHandler, + } + + var err error + retryer.KVRanges, err = retryer.setupMPPCoordinator(ctx, false) + return retryer, err +} + +// Next implements kv.Response interface. +func (r *ExecutorWithRetry) Next(ctx context.Context) (resp kv.ResultSubset, err error) { + if err = r.nextWithRecovery(ctx); err != nil { + return nil, err + } + + if r.mppErrRecovery.NumHoldResp() != 0 { + if resp, err = r.mppErrRecovery.PopFrontResp(); err != nil { + return nil, err + } + } else if resp, err = r.coord.Next(ctx); err != nil { + return nil, err + } + return resp, nil +} + +// Close implements kv.Response interface. +func (r *ExecutorWithRetry) Close() error { + r.mppErrRecovery.ResetHolder() + r.memTracker.Detach() + // Need to close coordinator before unregister to avoid coord.Close() takes too long. + err := r.coord.Close() + mppcoordmanager.InstanceMPPCoordinatorManager.Unregister(r.getCoordUniqueID()) + return err +} + +func (r *ExecutorWithRetry) setupMPPCoordinator(ctx context.Context, recoverying bool) ([]kv.KeyRange, error) { + if recoverying { + // Sanity check. + if r.coord == nil { + return nil, errors.New("mpp coordinator should not be nil when recoverying") + } + // Only report runtime stats when there is no error. 
+ r.coord.(*localMppCoordinator).closeWithoutReport() + mppcoordmanager.InstanceMPPCoordinatorManager.Unregister(r.getCoordUniqueID()) + } + + // Make sure gatherID is updated before build coord. + r.gatherID = allocMPPGatherID(r.sctx) + + r.coord = r.buildCoordinator() + if err := mppcoordmanager.InstanceMPPCoordinatorManager.Register(r.getCoordUniqueID(), r.coord); err != nil { + return nil, err + } + + _, kvRanges, err := r.coord.Execute(ctx) + if err != nil { + return nil, err + } + + if r.nodeCnt = r.coord.GetNodeCnt(); r.nodeCnt <= 0 { + return nil, errors.Errorf("tiflash node count should be greater than zero: %v", r.nodeCnt) + } + return kvRanges, err +} + +func (r *ExecutorWithRetry) nextWithRecovery(ctx context.Context) error { + if !r.mppErrRecovery.Enabled() { + return nil + } + + for r.mppErrRecovery.CanHoldResult() { + resp, mppErr := r.coord.Next(ctx) + + // Mock recovery n times. + failpoint.Inject("mpp_recovery_test_max_err_times", func(forceErrCnt failpoint.Value) { + forceErrCntInt := forceErrCnt.(int) + if r.mppErrRecovery.RecoveryCnt() < uint32(forceErrCntInt) { + mppErr = errors.New("mock mpp error") + } + }) + + if mppErr != nil { + recoveryErr := r.mppErrRecovery.Recovery(&RecoveryInfo{ + MPPErr: mppErr, + NodeCnt: r.nodeCnt, + }) + + // Mock recovery succeed, ignore no recovery handler err. + failpoint.Inject("mpp_recovery_test_ignore_recovery_err", func() { + if recoveryErr == nil { + panic("mocked mpp err should got recovery err") + } + if strings.Contains(mppErr.Error(), "mock mpp error") && strings.Contains(recoveryErr.Error(), "no handler to recovery") { + recoveryErr = nil + } + }) + + if recoveryErr != nil { + logutil.BgLogger().Error("recovery mpp error failed", zap.Any("mppErr", mppErr), + zap.Any("recoveryErr", recoveryErr)) + return mppErr + } + + logutil.BgLogger().Info("recovery mpp error succeed, begin next retry", + zap.Any("mppErr", mppErr), zap.Any("recoveryCnt", r.mppErrRecovery.RecoveryCnt())) + + if _, err := r.setupMPPCoordinator(r.ctx, true); err != nil { + logutil.BgLogger().Error("setup resp iter when recovery mpp err failed", zap.Any("err", err)) + return mppErr + } + r.mppErrRecovery.ResetHolder() + + continue + } + + if resp == nil { + break + } + + r.mppErrRecovery.HoldResult(resp.(*mppResponse)) + } + + failpoint.Inject("mpp_recovery_test_hold_size", func(num failpoint.Value) { + // Note: this failpoint only execute once. 
+ curRows := r.mppErrRecovery.NumHoldResp() + numInt := num.(int) + if curRows != numInt { + panic(fmt.Sprintf("unexpected holding rows, cur: %d", curRows)) + } + }) + return nil +} + +func allocMPPGatherID(ctx sessionctx.Context) uint64 { + mppQueryInfo := &ctx.GetSessionVars().StmtCtx.MPPQueryInfo + return mppQueryInfo.AllocatedMPPGatherID.Add(1) +} + +func (r *ExecutorWithRetry) buildCoordinator() kv.MppCoordinator { + _, serverAddr := mppcoordmanager.InstanceMPPCoordinatorManager.GetServerAddr() + return NewLocalMPPCoordinator(r.ctx, r.sctx, r.is, r.plan, r.planIDs, r.startTS, r.queryID, + r.gatherID, serverAddr, r.memTracker) +} + +func (r *ExecutorWithRetry) getCoordUniqueID() mppcoordmanager.CoordinatorUniqueID { + return mppcoordmanager.CoordinatorUniqueID{ + MPPQueryID: r.queryID, + GatherID: r.gatherID, + } +} diff --git a/pkg/executor/internal/mpp/local_mpp_coordinator.go b/pkg/executor/internal/mpp/local_mpp_coordinator.go index 15ef85eaa2da0..5299f27ef9a3f 100644 --- a/pkg/executor/internal/mpp/local_mpp_coordinator.go +++ b/pkg/executor/internal/mpp/local_mpp_coordinator.go @@ -363,11 +363,11 @@ func (c *localMppCoordinator) dispatchAll(ctx context.Context) { c.mu.Unlock() c.wg.Add(1) boMaxSleep := copr.CopNextMaxBackoff - failpoint.Inject("ReduceCopNextMaxBackoff", func(value failpoint.Value) { + if value, _err_ := failpoint.Eval(_curpkg_("ReduceCopNextMaxBackoff")); _err_ == nil { if value.(bool) { boMaxSleep = 2 } - }) + } bo := backoff.NewBackoffer(ctx, boMaxSleep) go func(mppTask *kv.MPPDispatchRequest) { defer func() { @@ -395,11 +395,11 @@ func (c *localMppCoordinator) sendToRespCh(resp *mppResponse) (exit bool) { }() if c.memTracker != nil { respSize := resp.MemSize() - failpoint.Inject("testMPPOOMPanic", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("testMPPOOMPanic")); _err_ == nil { if val.(bool) && respSize != 0 { respSize = 1 << 30 } - }) + } c.memTracker.Consume(respSize) defer c.memTracker.Consume(-respSize) } @@ -451,14 +451,14 @@ func (c *localMppCoordinator) handleDispatchReq(ctx context.Context, bo *backoff c.sendError(errors.New(rpcResp.Error.Msg)) return } - failpoint.Inject("mppNonRootTaskError", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mppNonRootTaskError")); _err_ == nil { if val.(bool) && !req.IsRoot { time.Sleep(1 * time.Second) atomic.CompareAndSwapUint32(&c.dispatchFailed, 0, 1) c.sendError(derr.ErrTiFlashServerTimeout) return } - }) + } if !req.IsRoot { return } @@ -747,11 +747,11 @@ func (c *localMppCoordinator) Execute(ctx context.Context) (kv.Response, []kv.Ke return nil, nil, errors.Trace(err) } } - failpoint.Inject("checkTotalMPPTasks", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("checkTotalMPPTasks")); _err_ == nil { if val.(int) != len(c.mppReqs) { - failpoint.Return(nil, nil, errors.Errorf("The number of tasks is not right, expect %d tasks but actually there are %d tasks", val.(int), len(c.mppReqs))) + return nil, nil, errors.Errorf("The number of tasks is not right, expect %d tasks but actually there are %d tasks", val.(int), len(c.mppReqs)) } - }) + } ctx = distsql.WithSQLKvExecCounterInterceptor(ctx, sctx.GetSessionVars().StmtCtx.KvExecCounter) _, allowTiFlashFallback := sctx.GetSessionVars().AllowFallbackToTiKV[kv.TiFlash] diff --git a/pkg/executor/internal/mpp/local_mpp_coordinator.go__failpoint_stash__ b/pkg/executor/internal/mpp/local_mpp_coordinator.go__failpoint_stash__ new file mode 100644 index 0000000000000..15ef85eaa2da0 --- /dev/null +++ 
b/pkg/executor/internal/mpp/local_mpp_coordinator.go__failpoint_stash__ @@ -0,0 +1,772 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mpp + +import ( + "context" + "fmt" + "io" + "sync" + "sync/atomic" + "time" + "unsafe" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/mpp" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/distsql" + "github.com/pingcap/tidb/pkg/executor/internal/builder" + "github.com/pingcap/tidb/pkg/executor/internal/util" + "github.com/pingcap/tidb/pkg/executor/metrics" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/store/copr" + "github.com/pingcap/tidb/pkg/store/driver/backoff" + derr "github.com/pingcap/tidb/pkg/store/driver/error" + util2 "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tipb/go-tipb" + clientutil "github.com/tikv/client-go/v2/util" + "go.uber.org/zap" +) + +const ( + receiveReportTimeout = 3 * time.Second +) + +// mppResponse wraps mpp data packet. +type mppResponse struct { + err error + pbResp *mpp.MPPDataPacket + detail *copr.CopRuntimeStats + respTime time.Duration + respSize int64 +} + +// GetData implements the kv.ResultSubset GetData interface. +func (m *mppResponse) GetData() []byte { + return m.pbResp.Data +} + +// GetStartKey implements the kv.ResultSubset GetStartKey interface. +func (*mppResponse) GetStartKey() kv.Key { + return nil +} + +// GetCopRuntimeStats is unavailable currently. 
+func (m *mppResponse) GetCopRuntimeStats() *copr.CopRuntimeStats { + return m.detail +} + +// MemSize returns how many bytes of memory this response use +func (m *mppResponse) MemSize() int64 { + if m.respSize != 0 { + return m.respSize + } + if m.detail != nil { + m.respSize += int64(int(unsafe.Sizeof(execdetails.ExecDetails{}))) + } + if m.pbResp != nil { + m.respSize += int64(m.pbResp.Size()) + } + return m.respSize +} + +func (m *mppResponse) RespTime() time.Duration { + return m.respTime +} + +type mppRequestReport struct { + mppReq *kv.MPPDispatchRequest + errMsg string + executionSummaries []*tipb.ExecutorExecutionSummary + receivedReport bool // if received ReportStatus from mpp task +} + +// localMppCoordinator stands for constructing and dispatching mpp tasks in local tidb server, since these work might be done remotely too +type localMppCoordinator struct { + ctx context.Context + sessionCtx sessionctx.Context + is infoschema.InfoSchema + originalPlan base.PhysicalPlan + reqMap map[int64]*mppRequestReport + + cancelFunc context.CancelFunc + + wgDoneChan chan struct{} + + memTracker *memory.Tracker + + reportStatusCh chan struct{} // used to notify inside coordinator that all reports has been received + + vars *kv.Variables + + respChan chan *mppResponse + + finishCh chan struct{} + + coordinatorAddr string // empty if coordinator service not available + firstErrMsg string + + mppReqs []*kv.MPPDispatchRequest + + planIDs []int + mppQueryID kv.MPPQueryID + + wg sync.WaitGroup + gatherID uint64 + reportedReqCount int + startTS uint64 + mu sync.Mutex + + closed uint32 + + dispatchFailed uint32 + allReportsHandled uint32 + + needTriggerFallback bool + enableCollectExecutionInfo bool + reportExecutionInfo bool // if each mpp task needs to report execution info directly to coordinator through ReportMPPTaskStatus + + // Record node cnt that involved in the mpp computation. 
+ nodeCnt int +} + +// NewLocalMPPCoordinator creates a new localMppCoordinator instance +func NewLocalMPPCoordinator(ctx context.Context, sctx sessionctx.Context, is infoschema.InfoSchema, plan base.PhysicalPlan, planIDs []int, startTS uint64, mppQueryID kv.MPPQueryID, gatherID uint64, coordinatorAddr string, memTracker *memory.Tracker) *localMppCoordinator { + if sctx.GetSessionVars().ChooseMppVersion() < kv.MppVersionV2 { + coordinatorAddr = "" + } + coord := &localMppCoordinator{ + ctx: ctx, + sessionCtx: sctx, + is: is, + originalPlan: plan, + planIDs: planIDs, + startTS: startTS, + mppQueryID: mppQueryID, + gatherID: gatherID, + coordinatorAddr: coordinatorAddr, + memTracker: memTracker, + finishCh: make(chan struct{}), + wgDoneChan: make(chan struct{}), + respChan: make(chan *mppResponse), + reportStatusCh: make(chan struct{}), + vars: sctx.GetSessionVars().KVVars, + reqMap: make(map[int64]*mppRequestReport), + } + + if len(coordinatorAddr) > 0 && needReportExecutionSummary(coord.originalPlan) { + coord.reportExecutionInfo = true + } + return coord +} + +func (c *localMppCoordinator) appendMPPDispatchReq(pf *plannercore.Fragment) error { + dagReq, err := builder.ConstructDAGReq(c.sessionCtx, []base.PhysicalPlan{pf.ExchangeSender}, kv.TiFlash) + if err != nil { + return errors.Trace(err) + } + for i := range pf.ExchangeSender.Schema().Columns { + dagReq.OutputOffsets = append(dagReq.OutputOffsets, uint32(i)) + } + if !pf.IsRoot { + dagReq.EncodeType = tipb.EncodeType_TypeCHBlock + } else { + dagReq.EncodeType = tipb.EncodeType_TypeChunk + } + for _, mppTask := range pf.ExchangeSender.Tasks { + if mppTask.PartitionTableIDs != nil { + err = util.UpdateExecutorTableID(context.Background(), dagReq.RootExecutor, true, mppTask.PartitionTableIDs) + } else if !mppTask.TiFlashStaticPrune { + // If isDisaggregatedTiFlashStaticPrune is true, it means this TableScan is under PartitionUnoin, + // tableID in TableScan is already the physical table id of this partition, no need to update again. 
+ err = util.UpdateExecutorTableID(context.Background(), dagReq.RootExecutor, true, []int64{mppTask.TableID}) + } + if err != nil { + return errors.Trace(err) + } + err = c.fixTaskForCTEStorageAndReader(dagReq.RootExecutor, mppTask.Meta) + if err != nil { + return err + } + pbData, err := dagReq.Marshal() + if err != nil { + return errors.Trace(err) + } + + rgName := c.sessionCtx.GetSessionVars().StmtCtx.ResourceGroupName + if !variable.EnableResourceControl.Load() { + rgName = "" + } + logutil.BgLogger().Info("Dispatch mpp task", zap.Uint64("timestamp", mppTask.StartTs), + zap.Int64("ID", mppTask.ID), zap.Uint64("QueryTs", mppTask.MppQueryID.QueryTs), zap.Uint64("LocalQueryId", mppTask.MppQueryID.LocalQueryID), + zap.Uint64("ServerID", mppTask.MppQueryID.ServerID), zap.String("address", mppTask.Meta.GetAddress()), + zap.String("plan", plannercore.ToString(pf.ExchangeSender)), + zap.Int64("mpp-version", mppTask.MppVersion.ToInt64()), + zap.String("exchange-compression-mode", pf.ExchangeSender.CompressionMode.Name()), + zap.Uint64("GatherID", c.gatherID), + zap.String("resource_group", rgName), + ) + req := &kv.MPPDispatchRequest{ + Data: pbData, + Meta: mppTask.Meta, + ID: mppTask.ID, + IsRoot: pf.IsRoot, + Timeout: 10, + SchemaVar: c.is.SchemaMetaVersion(), + StartTs: c.startTS, + MppQueryID: mppTask.MppQueryID, + GatherID: c.gatherID, + MppVersion: mppTask.MppVersion, + CoordinatorAddress: c.coordinatorAddr, + ReportExecutionSummary: c.reportExecutionInfo, + State: kv.MppTaskReady, + ResourceGroupName: rgName, + ConnectionID: c.sessionCtx.GetSessionVars().ConnectionID, + ConnectionAlias: c.sessionCtx.ShowProcess().SessionAlias, + } + c.reqMap[req.ID] = &mppRequestReport{mppReq: req, receivedReport: false, errMsg: "", executionSummaries: nil} + c.mppReqs = append(c.mppReqs, req) + } + return nil +} + +// fixTaskForCTEStorageAndReader fixes the upstream/downstream tasks for the producers and consumers. +// After we split the fragments. A CTE producer in the fragment will holds all the task address of the consumers. +// For example, the producer has two task on node_1 and node_2. As we know that each consumer also has two task on the same nodes(node_1 and node_2) +// We need to prune address of node_2 for producer's task on node_1 since we just want the producer task on the node_1 only send to the consumer tasks on the node_1. +// And the same for the task on the node_2. +// And the same for the consumer task. We need to prune the unnecessary task address of its producer tasks(i.e. the downstream tasks). 
+func (c *localMppCoordinator) fixTaskForCTEStorageAndReader(exec *tipb.Executor, meta kv.MPPTaskMeta) error { + children := make([]*tipb.Executor, 0, 2) + switch exec.Tp { + case tipb.ExecType_TypeTableScan, tipb.ExecType_TypePartitionTableScan, tipb.ExecType_TypeIndexScan: + case tipb.ExecType_TypeSelection: + children = append(children, exec.Selection.Child) + case tipb.ExecType_TypeAggregation, tipb.ExecType_TypeStreamAgg: + children = append(children, exec.Aggregation.Child) + case tipb.ExecType_TypeTopN: + children = append(children, exec.TopN.Child) + case tipb.ExecType_TypeLimit: + children = append(children, exec.Limit.Child) + case tipb.ExecType_TypeExchangeSender: + children = append(children, exec.ExchangeSender.Child) + if len(exec.ExchangeSender.UpstreamCteTaskMeta) == 0 { + break + } + actualUpStreamTasks := make([][]byte, 0, len(exec.ExchangeSender.UpstreamCteTaskMeta)) + actualTIDs := make([]int64, 0, len(exec.ExchangeSender.UpstreamCteTaskMeta)) + for _, tasksFromOneConsumer := range exec.ExchangeSender.UpstreamCteTaskMeta { + for _, taskBytes := range tasksFromOneConsumer.EncodedTasks { + taskMeta := &mpp.TaskMeta{} + err := taskMeta.Unmarshal(taskBytes) + if err != nil { + return err + } + if taskMeta.Address != meta.GetAddress() { + continue + } + actualUpStreamTasks = append(actualUpStreamTasks, taskBytes) + actualTIDs = append(actualTIDs, taskMeta.TaskId) + } + } + logutil.BgLogger().Warn("refine tunnel for cte producer task", zap.String("the final tunnel", fmt.Sprintf("up stream consumer tasks: %v", actualTIDs))) + exec.ExchangeSender.EncodedTaskMeta = actualUpStreamTasks + case tipb.ExecType_TypeExchangeReceiver: + if len(exec.ExchangeReceiver.OriginalCtePrdocuerTaskMeta) == 0 { + break + } + exec.ExchangeReceiver.EncodedTaskMeta = [][]byte{} + actualTIDs := make([]int64, 0, 4) + for _, taskBytes := range exec.ExchangeReceiver.OriginalCtePrdocuerTaskMeta { + taskMeta := &mpp.TaskMeta{} + err := taskMeta.Unmarshal(taskBytes) + if err != nil { + return err + } + if taskMeta.Address != meta.GetAddress() { + continue + } + exec.ExchangeReceiver.EncodedTaskMeta = append(exec.ExchangeReceiver.EncodedTaskMeta, taskBytes) + actualTIDs = append(actualTIDs, taskMeta.TaskId) + } + logutil.BgLogger().Warn("refine tunnel for cte consumer task", zap.String("the final tunnel", fmt.Sprintf("down stream producer task: %v", actualTIDs))) + case tipb.ExecType_TypeJoin: + children = append(children, exec.Join.Children...) 
+ case tipb.ExecType_TypeProjection: + children = append(children, exec.Projection.Child) + case tipb.ExecType_TypeWindow: + children = append(children, exec.Window.Child) + case tipb.ExecType_TypeSort: + children = append(children, exec.Sort.Child) + case tipb.ExecType_TypeExpand: + children = append(children, exec.Expand.Child) + case tipb.ExecType_TypeExpand2: + children = append(children, exec.Expand2.Child) + default: + return errors.Errorf("unknown new tipb protocol %d", exec.Tp) + } + for _, child := range children { + err := c.fixTaskForCTEStorageAndReader(child, meta) + if err != nil { + return err + } + } + return nil +} + +// DFS to check if plan needs report execution summary through ReportMPPTaskStatus mpp service +// Currently, return true if plan contains limit operator +func needReportExecutionSummary(plan base.PhysicalPlan) bool { + switch x := plan.(type) { + case *plannercore.PhysicalLimit: + return true + default: + for _, child := range x.Children() { + if needReportExecutionSummary(child) { + return true + } + } + } + return false +} + +func (c *localMppCoordinator) dispatchAll(ctx context.Context) { + for _, task := range c.mppReqs { + if atomic.LoadUint32(&c.closed) == 1 { + break + } + c.mu.Lock() + if task.State == kv.MppTaskReady { + task.State = kv.MppTaskRunning + } + c.mu.Unlock() + c.wg.Add(1) + boMaxSleep := copr.CopNextMaxBackoff + failpoint.Inject("ReduceCopNextMaxBackoff", func(value failpoint.Value) { + if value.(bool) { + boMaxSleep = 2 + } + }) + bo := backoff.NewBackoffer(ctx, boMaxSleep) + go func(mppTask *kv.MPPDispatchRequest) { + defer func() { + c.wg.Done() + }() + c.handleDispatchReq(ctx, bo, mppTask) + }(task) + } + c.wg.Wait() + close(c.wgDoneChan) + close(c.respChan) +} + +func (c *localMppCoordinator) sendError(err error) { + c.sendToRespCh(&mppResponse{err: err}) + c.cancelMppTasks() +} + +func (c *localMppCoordinator) sendToRespCh(resp *mppResponse) (exit bool) { + defer func() { + if r := recover(); r != nil { + logutil.BgLogger().Error("localMppCoordinator panic", zap.Stack("stack"), zap.Any("recover", r)) + c.sendError(util2.GetRecoverError(r)) + } + }() + if c.memTracker != nil { + respSize := resp.MemSize() + failpoint.Inject("testMPPOOMPanic", func(val failpoint.Value) { + if val.(bool) && respSize != 0 { + respSize = 1 << 30 + } + }) + c.memTracker.Consume(respSize) + defer c.memTracker.Consume(-respSize) + } + select { + case c.respChan <- resp: + case <-c.finishCh: + exit = true + } + return +} + +// TODO:: Consider that which way is better: +// - dispatch all Tasks at once, and connect Tasks at second. +// - dispatch Tasks and establish connection at the same time. +func (c *localMppCoordinator) handleDispatchReq(ctx context.Context, bo *backoff.Backoffer, req *kv.MPPDispatchRequest) { + var rpcResp *mpp.DispatchTaskResponse + var err error + var retry bool + for { + rpcResp, retry, err = c.sessionCtx.GetMPPClient().DispatchMPPTask( + kv.DispatchMPPTaskParam{ + Ctx: ctx, + Req: req, + EnableCollectExecutionInfo: c.enableCollectExecutionInfo, + Bo: bo.TiKVBackoffer(), + }) + if retry { + // TODO: If we want to retry, we might need to redo the plan fragment cutting and task scheduling. 
https://github.com/pingcap/tidb/issues/31015 + logutil.BgLogger().Warn("mpp dispatch meet error and retrying", zap.Error(err), zap.Uint64("timestamp", c.startTS), zap.Int64("task", req.ID), zap.Int64("mpp-version", req.MppVersion.ToInt64())) + continue + } + break + } + + if err != nil { + logutil.BgLogger().Error("mpp dispatch meet error", zap.String("error", err.Error()), zap.Uint64("timestamp", req.StartTs), zap.Int64("task", req.ID), zap.Int64("mpp-version", req.MppVersion.ToInt64())) + atomic.CompareAndSwapUint32(&c.dispatchFailed, 0, 1) + // if NeedTriggerFallback is true, we return timeout to trigger tikv's fallback + if c.needTriggerFallback { + err = derr.ErrTiFlashServerTimeout + } + c.sendError(err) + return + } + + if rpcResp.Error != nil { + logutil.BgLogger().Error("mpp dispatch response meet error", zap.String("error", rpcResp.Error.Msg), zap.Uint64("timestamp", req.StartTs), zap.Int64("task", req.ID), zap.Int64("task-mpp-version", req.MppVersion.ToInt64()), zap.Int64("error-mpp-version", rpcResp.Error.GetMppVersion())) + atomic.CompareAndSwapUint32(&c.dispatchFailed, 0, 1) + c.sendError(errors.New(rpcResp.Error.Msg)) + return + } + failpoint.Inject("mppNonRootTaskError", func(val failpoint.Value) { + if val.(bool) && !req.IsRoot { + time.Sleep(1 * time.Second) + atomic.CompareAndSwapUint32(&c.dispatchFailed, 0, 1) + c.sendError(derr.ErrTiFlashServerTimeout) + return + } + }) + if !req.IsRoot { + return + } + // only root task should establish a stream conn with tiFlash to receive result. + taskMeta := &mpp.TaskMeta{StartTs: req.StartTs, GatherId: c.gatherID, QueryTs: req.MppQueryID.QueryTs, LocalQueryId: req.MppQueryID.LocalQueryID, TaskId: req.ID, ServerId: req.MppQueryID.ServerID, + Address: req.Meta.GetAddress(), + MppVersion: req.MppVersion.ToInt64(), + ResourceGroupName: req.ResourceGroupName, + } + c.receiveResults(req, taskMeta, bo) +} + +// NOTE: We do not retry here, because retry is helpless when errors result from TiFlash or Network. If errors occur, the execution on TiFlash will finally stop after some minutes. +// This function is exclusively called, and only the first call succeeds sending Tasks and setting all Tasks as cancelled, while others will not work. +func (c *localMppCoordinator) cancelMppTasks() { + if len(c.mppReqs) == 0 { + return + } + usedStoreAddrs := make(map[string]bool) + c.mu.Lock() + // 1. One address will receive one cancel request, since cancel request cancels all mpp tasks within the same mpp gather + // 2. 
Cancel process will set all mpp task requests' states, thus if one request's state is Cancelled already, just return + if c.mppReqs[0].State == kv.MppTaskCancelled { + c.mu.Unlock() + return + } + for _, task := range c.mppReqs { + // get the store address of running tasks, + if task.State == kv.MppTaskRunning && !usedStoreAddrs[task.Meta.GetAddress()] { + usedStoreAddrs[task.Meta.GetAddress()] = true + } + task.State = kv.MppTaskCancelled + } + c.mu.Unlock() + c.sessionCtx.GetMPPClient().CancelMPPTasks(kv.CancelMPPTasksParam{StoreAddr: usedStoreAddrs, Reqs: c.mppReqs}) +} + +func (c *localMppCoordinator) receiveResults(req *kv.MPPDispatchRequest, taskMeta *mpp.TaskMeta, bo *backoff.Backoffer) { + stream, err := c.sessionCtx.GetMPPClient().EstablishMPPConns(kv.EstablishMPPConnsParam{Ctx: bo.GetCtx(), Req: req, TaskMeta: taskMeta}) + if err != nil { + // if NeedTriggerFallback is true, we return timeout to trigger tikv's fallback + if c.needTriggerFallback { + c.sendError(derr.ErrTiFlashServerTimeout) + } else { + c.sendError(err) + } + return + } + + defer stream.Close() + resp := stream.MPPDataPacket + if resp == nil { + return + } + + for { + err := c.handleMPPStreamResponse(bo, resp, req) + if err != nil { + c.sendError(err) + return + } + + resp, err = stream.Recv() + if err != nil { + if errors.Cause(err) == io.EOF { + return + } + + logutil.BgLogger().Info("mpp stream recv got error", zap.Error(err), zap.Uint64("timestamp", taskMeta.StartTs), + zap.Int64("task", taskMeta.TaskId), zap.Int64("mpp-version", taskMeta.MppVersion)) + + // if NeedTriggerFallback is true, we return timeout to trigger tikv's fallback + if c.needTriggerFallback { + c.sendError(derr.ErrTiFlashServerTimeout) + } else { + c.sendError(err) + } + return + } + } +} + +// ReportStatus implements MppCoordinator interface +func (c *localMppCoordinator) ReportStatus(info kv.ReportStatusRequest) error { + taskID := info.Request.Meta.TaskId + var errMsg string + if info.Request.Error != nil { + errMsg = info.Request.Error.Msg + } + executionInfo := new(tipb.TiFlashExecutionInfo) + err := executionInfo.Unmarshal(info.Request.GetData()) + if err != nil { + // since it is very corner case to reach here, and it won't cause forever hang due to not close reportStatusCh + return err + } + + c.mu.Lock() + defer c.mu.Unlock() + req, exists := c.reqMap[taskID] + if !exists { + return errors.Errorf("ReportMPPTaskStatus task not exists taskID: %d", taskID) + } + if req.receivedReport { + return errors.Errorf("ReportMPPTaskStatus task already received taskID: %d", taskID) + } + + req.receivedReport = true + if len(errMsg) > 0 { + req.errMsg = errMsg + if len(c.firstErrMsg) == 0 { + c.firstErrMsg = errMsg + } + } + + c.reportedReqCount++ + req.executionSummaries = executionInfo.GetExecutionSummaries() + if c.reportedReqCount == len(c.mppReqs) { + close(c.reportStatusCh) + } + return nil +} + +func (c *localMppCoordinator) handleAllReports() error { + if c.reportExecutionInfo && atomic.LoadUint32(&c.dispatchFailed) == 0 && atomic.CompareAndSwapUint32(&c.allReportsHandled, 0, 1) { + startTime := time.Now() + select { + case <-c.reportStatusCh: + metrics.MppCoordinatorLatencyRcvReport.Observe(float64(time.Since(startTime).Milliseconds())) + var recordedPlanIDs = make(map[int]int) + for _, report := range c.reqMap { + for _, detail := range report.executionSummaries { + if detail != nil && detail.TimeProcessedNs != nil && + detail.NumProducedRows != nil && detail.NumIterations != nil { + 
recordedPlanIDs[c.sessionCtx.GetSessionVars().StmtCtx.RuntimeStatsColl. + RecordOneCopTask(-1, kv.TiFlash.Name(), report.mppReq.Meta.GetAddress(), detail)] = 0 + } + } + if ruDetailsRaw := c.ctx.Value(clientutil.RUDetailsCtxKey); ruDetailsRaw != nil { + if err := execdetails.MergeTiFlashRUConsumption(report.executionSummaries, ruDetailsRaw.(*clientutil.RUDetails)); err != nil { + return err + } + } + } + distsql.FillDummySummariesForTiFlashTasks(c.sessionCtx.GetSessionVars().StmtCtx.RuntimeStatsColl, "", kv.TiFlash.Name(), c.planIDs, recordedPlanIDs) + case <-time.After(receiveReportTimeout): + metrics.MppCoordinatorStatsReportNotReceived.Inc() + logutil.BgLogger().Warn(fmt.Sprintf("Mpp coordinator not received all reports within %d seconds", int(receiveReportTimeout.Seconds())), + zap.Uint64("txnStartTS", c.startTS), + zap.Uint64("gatherID", c.gatherID), + zap.Int("expectCount", len(c.mppReqs)), + zap.Int("actualCount", c.reportedReqCount)) + } + } + return nil +} + +// IsClosed implements MppCoordinator interface +func (c *localMppCoordinator) IsClosed() bool { + return atomic.LoadUint32(&c.closed) == 1 +} + +// Close implements MppCoordinator interface +// TODO: Test the case that user cancels the query. +func (c *localMppCoordinator) Close() error { + c.closeWithoutReport() + return c.handleAllReports() +} + +func (c *localMppCoordinator) closeWithoutReport() { + if atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + close(c.finishCh) + } + c.cancelFunc() + <-c.wgDoneChan +} + +func (c *localMppCoordinator) handleMPPStreamResponse(bo *backoff.Backoffer, response *mpp.MPPDataPacket, req *kv.MPPDispatchRequest) (err error) { + if response.Error != nil { + c.mu.Lock() + firstErrMsg := c.firstErrMsg + c.mu.Unlock() + // firstErrMsg is only used when already received error response from root tasks, avoid confusing error messages + if len(firstErrMsg) > 0 { + err = errors.Errorf("other error for mpp stream: %s", firstErrMsg) + } else { + err = errors.Errorf("other error for mpp stream: %s", response.Error.Msg) + } + logutil.BgLogger().Warn("other error", + zap.Uint64("txnStartTS", req.StartTs), + zap.String("storeAddr", req.Meta.GetAddress()), + zap.Int64("mpp-version", req.MppVersion.ToInt64()), + zap.Int64("task-id", req.ID), + zap.Error(err)) + return err + } + + resp := &mppResponse{ + pbResp: response, + detail: new(copr.CopRuntimeStats), + } + + backoffTimes := bo.GetBackoffTimes() + resp.detail.BackoffTime = time.Duration(bo.GetTotalSleep()) * time.Millisecond + resp.detail.BackoffSleep = make(map[string]time.Duration, len(backoffTimes)) + resp.detail.BackoffTimes = make(map[string]int, len(backoffTimes)) + for backoff := range backoffTimes { + resp.detail.BackoffTimes[backoff] = backoffTimes[backoff] + resp.detail.BackoffSleep[backoff] = time.Duration(bo.GetBackoffSleepMS()[backoff]) * time.Millisecond + } + resp.detail.CalleeAddress = req.Meta.GetAddress() + c.sendToRespCh(resp) + return +} + +func (c *localMppCoordinator) nextImpl(ctx context.Context) (resp *mppResponse, ok bool, exit bool, err error) { + ticker := time.NewTicker(3 * time.Second) + defer ticker.Stop() + for { + select { + case resp, ok = <-c.respChan: + return + case <-ticker.C: + if c.vars != nil && c.vars.Killed != nil { + killed := atomic.LoadUint32(c.vars.Killed) + if killed != 0 { + logutil.Logger(ctx).Info( + "a killed signal is received", + zap.Uint32("signal", killed), + ) + err = derr.ErrQueryInterrupted + exit = true + return + } + } + case <-c.finishCh: + exit = true + return + case <-ctx.Done(): + if 
atomic.CompareAndSwapUint32(&c.closed, 0, 1) { + close(c.finishCh) + } + exit = true + return + } + } +} + +// Next implements MppCoordinator interface +func (c *localMppCoordinator) Next(ctx context.Context) (kv.ResultSubset, error) { + resp, ok, closed, err := c.nextImpl(ctx) + if err != nil { + return nil, errors.Trace(err) + } + if !ok || closed { + return nil, nil + } + + if resp.err != nil { + return nil, errors.Trace(resp.err) + } + + err = c.sessionCtx.GetMPPClient().CheckVisibility(c.startTS) + if err != nil { + return nil, errors.Trace(derr.ErrQueryInterrupted) + } + return resp, nil +} + +// Execute implements MppCoordinator interface +func (c *localMppCoordinator) Execute(ctx context.Context) (kv.Response, []kv.KeyRange, error) { + // TODO: Move the construct tasks logic to planner, so we can see the explain results. + sender := c.originalPlan.(*plannercore.PhysicalExchangeSender) + sctx := c.sessionCtx + frags, kvRanges, nodeInfo, err := plannercore.GenerateRootMPPTasks(sctx, c.startTS, c.gatherID, c.mppQueryID, sender, c.is) + if err != nil { + return nil, nil, errors.Trace(err) + } + if nodeInfo == nil { + return nil, nil, errors.New("node info should not be nil") + } + c.nodeCnt = len(nodeInfo) + + for _, frag := range frags { + err = c.appendMPPDispatchReq(frag) + if err != nil { + return nil, nil, errors.Trace(err) + } + } + failpoint.Inject("checkTotalMPPTasks", func(val failpoint.Value) { + if val.(int) != len(c.mppReqs) { + failpoint.Return(nil, nil, errors.Errorf("The number of tasks is not right, expect %d tasks but actually there are %d tasks", val.(int), len(c.mppReqs))) + } + }) + + ctx = distsql.WithSQLKvExecCounterInterceptor(ctx, sctx.GetSessionVars().StmtCtx.KvExecCounter) + _, allowTiFlashFallback := sctx.GetSessionVars().AllowFallbackToTiKV[kv.TiFlash] + ctx = distsql.SetTiFlashConfVarsInContext(ctx, sctx.GetDistSQLCtx()) + c.needTriggerFallback = allowTiFlashFallback + c.enableCollectExecutionInfo = config.GetGlobalConfig().Instance.EnableCollectExecutionInfo.Load() + + var ctxChild context.Context + ctxChild, c.cancelFunc = context.WithCancel(ctx) + go c.dispatchAll(ctxChild) + + return c, kvRanges, nil +} + +// GetNodeCnt returns the node count that involved in the mpp computation. 
+func (c *localMppCoordinator) GetNodeCnt() int { + return c.nodeCnt +} diff --git a/pkg/executor/internal/pdhelper/binding__failpoint_binding__.go b/pkg/executor/internal/pdhelper/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..d67d83e08b252 --- /dev/null +++ b/pkg/executor/internal/pdhelper/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package pdhelper + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/executor/internal/pdhelper/pd.go b/pkg/executor/internal/pdhelper/pd.go index 4893791d14e3f..7a4f8c099ccbf 100644 --- a/pkg/executor/internal/pdhelper/pd.go +++ b/pkg/executor/internal/pdhelper/pd.go @@ -93,13 +93,13 @@ func getApproximateTableCountFromStorage( return 0, false } regionStats, err := helper.NewHelper(tikvStore).GetPDRegionStats(ctx, tid, true) - failpoint.Inject("calcSampleRateByStorageCount", func() { + if _, _err_ := failpoint.Eval(_curpkg_("calcSampleRateByStorageCount")); _err_ == nil { // Force the TiDB thinking that there's PD and the count of region is small. err = nil regionStats.Count = 1 // Set a very large approximate count. regionStats.StorageKeys = 1000000 - }) + } if err != nil { return 0, false } diff --git a/pkg/executor/internal/pdhelper/pd.go__failpoint_stash__ b/pkg/executor/internal/pdhelper/pd.go__failpoint_stash__ new file mode 100644 index 0000000000000..4893791d14e3f --- /dev/null +++ b/pkg/executor/internal/pdhelper/pd.go__failpoint_stash__ @@ -0,0 +1,128 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pdhelper + +import ( + "context" + "strconv" + "strings" + "sync" + "time" + + "github.com/jellydator/ttlcache/v3" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/store/helper" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/sqlescape" +) + +// GlobalPDHelper is the global variable for PDHelper. +var GlobalPDHelper = defaultPDHelper() +var globalPDHelperOnce sync.Once + +// PDHelper is used to get some information from PD. 
+type PDHelper struct { + cacheForApproximateTableCountFromStorage *ttlcache.Cache[string, float64] + + getApproximateTableCountFromStorageFunc func(ctx context.Context, sctx sessionctx.Context, tid int64, dbName, tableName, partitionName string) (float64, bool) + wg util.WaitGroupWrapper +} + +func defaultPDHelper() *PDHelper { + cache := ttlcache.New[string, float64]( + ttlcache.WithTTL[string, float64](30*time.Second), + ttlcache.WithCapacity[string, float64](1024*1024), + ) + return &PDHelper{ + cacheForApproximateTableCountFromStorage: cache, + getApproximateTableCountFromStorageFunc: getApproximateTableCountFromStorage, + } +} + +// Start is used to start the background task of PDHelper. Currently, the background task is used to clean up TTL cache. +func (p *PDHelper) Start() { + globalPDHelperOnce.Do(func() { + p.wg.Run(p.cacheForApproximateTableCountFromStorage.Start) + }) +} + +// Stop stops the background task of PDHelper. +func (p *PDHelper) Stop() { + p.cacheForApproximateTableCountFromStorage.Stop() + p.wg.Wait() +} + +func approximateTableCountKey(tid int64, dbName, tableName, partitionName string) string { + return strings.Join([]string{strconv.FormatInt(tid, 10), dbName, tableName, partitionName}, "_") +} + +// GetApproximateTableCountFromStorage gets the approximate count of the table. +func (p *PDHelper) GetApproximateTableCountFromStorage( + ctx context.Context, sctx sessionctx.Context, + tid int64, dbName, tableName, partitionName string, +) (float64, bool) { + key := approximateTableCountKey(tid, dbName, tableName, partitionName) + if item := p.cacheForApproximateTableCountFromStorage.Get(key); item != nil { + return item.Value(), true + } + result, hasPD := p.getApproximateTableCountFromStorageFunc(ctx, sctx, tid, dbName, tableName, partitionName) + p.cacheForApproximateTableCountFromStorage.Set(key, result, ttlcache.DefaultTTL) + return result, hasPD +} + +func getApproximateTableCountFromStorage( + ctx context.Context, sctx sessionctx.Context, + tid int64, dbName, tableName, partitionName string, +) (float64, bool) { + tikvStore, ok := sctx.GetStore().(helper.Storage) + if !ok { + return 0, false + } + regionStats, err := helper.NewHelper(tikvStore).GetPDRegionStats(ctx, tid, true) + failpoint.Inject("calcSampleRateByStorageCount", func() { + // Force the TiDB thinking that there's PD and the count of region is small. + err = nil + regionStats.Count = 1 + // Set a very large approximate count. + regionStats.StorageKeys = 1000000 + }) + if err != nil { + return 0, false + } + // If this table is not small, we directly use the count from PD, + // since for a small table, it's possible that it's data is in the same region with part of another large table. + // Thus, we use the number of the regions of the table's table KV to decide whether the table is small. + if regionStats.Count > 2 { + return float64(regionStats.StorageKeys), true + } + // Otherwise, we use count(*) to calc it's size, since it's very small, the table data can be filled in no more than 2 regions. + sql := new(strings.Builder) + sqlescape.MustFormatSQL(sql, "select count(*) from %n.%n", dbName, tableName) + if partitionName != "" { + sqlescape.MustFormatSQL(sql, " partition(%n)", partitionName) + } + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnStats) + rows, _, err := sctx.GetRestrictedSQLExecutor().ExecRestrictedSQL(ctx, nil, sql.String()) + if err != nil { + return 0, false + } + // If the record set is nil, there's something wrong with the execution. 
The COUNT(*) would always return one row. + if len(rows) == 0 || rows[0].Len() == 0 { + return 0, false + } + return float64(rows[0].GetInt64(0)), true +} diff --git a/pkg/executor/join/base_join_probe.go b/pkg/executor/join/base_join_probe.go index 64a815bafdc51..58face22ee483 100644 --- a/pkg/executor/join/base_join_probe.go +++ b/pkg/executor/join/base_join_probe.go @@ -262,11 +262,11 @@ func (j *baseJoinProbe) finishLookupCurrentProbeRow() { func checkSQLKiller(killer *sqlkiller.SQLKiller, fpName string) error { err := killer.HandleSignal() - failpoint.Inject(fpName, func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_(fpName)); _err_ == nil { if val.(bool) { err = exeerrors.ErrQueryInterrupted } - }) + } return err } diff --git a/pkg/executor/join/base_join_probe.go__failpoint_stash__ b/pkg/executor/join/base_join_probe.go__failpoint_stash__ new file mode 100644 index 0000000000000..64a815bafdc51 --- /dev/null +++ b/pkg/executor/join/base_join_probe.go__failpoint_stash__ @@ -0,0 +1,593 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package join + +import ( + "bytes" + "hash/fnv" + "unsafe" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" + "github.com/pingcap/tidb/pkg/util/hack" + "github.com/pingcap/tidb/pkg/util/sqlkiller" +) + +type keyMode int + +const ( + // OneInt64 mean the key contains only one Int64 + OneInt64 keyMode = iota + // FixedSerializedKey mean the key has fixed length + FixedSerializedKey + // VariableSerializedKey mean the key has variable length + VariableSerializedKey +) + +const batchBuildRowSize = 32 + +func (hCtx *HashJoinCtxV2) hasOtherCondition() bool { + return hCtx.OtherCondition != nil +} + +// ProbeV2 is the interface used to do probe in hash join v2 +type ProbeV2 interface { + // SetChunkForProbe will do some pre-work when start probing a chunk + SetChunkForProbe(chunk *chunk.Chunk) error + // Probe is to probe current chunk, the result chunk is set in result.chk, and Probe need to make sure result.chk.NumRows() <= result.chk.RequiredRows() + Probe(joinResult *hashjoinWorkerResult, sqlKiller *sqlkiller.SQLKiller) (ok bool, result *hashjoinWorkerResult) + // IsCurrentChunkProbeDone returns true if current probe chunk is all probed + IsCurrentChunkProbeDone() bool + // ScanRowTable is called after all the probe chunks are probed. 
It is used in some special joins, like left outer join with left side to build, after all + // the probe side chunks are handled, it needs to scan the row table to return the un-matched rows + ScanRowTable(joinResult *hashjoinWorkerResult, sqlKiller *sqlkiller.SQLKiller) (result *hashjoinWorkerResult) + // IsScanRowTableDone returns true after scan row table is done + IsScanRowTableDone() bool + // NeedScanRowTable returns true if current join need to scan row table after all the probe side chunks are handled + NeedScanRowTable() bool + // InitForScanRowTable do some pre-work before ScanRowTable, it must be called before ScanRowTable + InitForScanRowTable() + // Return probe collsion + GetProbeCollision() uint64 + // Reset probe collsion + ResetProbeCollision() +} + +type offsetAndLength struct { + offset int + length int +} + +type matchedRowInfo struct { + // probeRowIndex mean the probe side index of current matched row + probeRowIndex int + // buildRowStart mean the build row start of the current matched row + buildRowStart uintptr + // buildRowOffset mean the current offset of current BuildRow, used to construct column data from BuildRow + buildRowOffset int +} + +func createMatchRowInfo(probeRowIndex int, buildRowStart unsafe.Pointer) *matchedRowInfo { + ret := &matchedRowInfo{probeRowIndex: probeRowIndex} + *(*unsafe.Pointer)(unsafe.Pointer(&ret.buildRowStart)) = buildRowStart + return ret +} + +type posAndHashValue struct { + hashValue uint64 + pos int +} + +type baseJoinProbe struct { + ctx *HashJoinCtxV2 + workID uint + + currentChunk *chunk.Chunk + // if currentChunk.Sel() == nil, then construct a fake selRows + selRows []int + usedRows []int + // matchedRowsHeaders, serializedKeys is indexed by logical row index + matchedRowsHeaders []uintptr // the start address of each matched rows + serializedKeys [][]byte // used for save serialized keys + // filterVector and nullKeyVector is indexed by physical row index because the return vector of VectorizedFilter is based on physical row index + filterVector []bool // if there is filter before probe, filterVector saves the filter result + nullKeyVector []bool // nullKeyVector[i] = true if any of the key is null + hashValues [][]posAndHashValue // the start address of each matched rows + currentProbeRow int + matchedRowsForCurrentProbeRow int + chunkRows int + cachedBuildRows []*matchedRowInfo + + keyIndex []int + keyTypes []*types.FieldType + hasNullableKey bool + maxChunkSize int + rightAsBuildSide bool + // lUsed/rUsed show which columns are used by father for left child and right child. + // NOTE: + // 1. lUsed/rUsed should never be nil. + // 2. no columns are used if lUsed/rUsed is not nil but the size of lUsed/rUsed is 0. 
+ lUsed, rUsed []int + lUsedInOtherCondition, rUsedInOtherCondition []int + // used when construct column from probe side + offsetAndLengthArray []offsetAndLength + // these 3 variables are used for join that has other condition, should be inited when the join has other condition + tmpChk *chunk.Chunk + rowIndexInfos []*matchedRowInfo + selected []bool + + probeCollision uint64 +} + +func (j *baseJoinProbe) GetProbeCollision() uint64 { + return j.probeCollision +} + +func (j *baseJoinProbe) ResetProbeCollision() { + j.probeCollision = 0 +} + +func (j *baseJoinProbe) IsCurrentChunkProbeDone() bool { + return j.currentChunk == nil || j.currentProbeRow >= j.chunkRows +} + +func (j *baseJoinProbe) finishCurrentLookupLoop(joinedChk *chunk.Chunk) { + if len(j.cachedBuildRows) > 0 { + j.batchConstructBuildRows(joinedChk, 0, j.ctx.hasOtherCondition()) + } + j.finishLookupCurrentProbeRow() + j.appendProbeRowToChunk(joinedChk, j.currentChunk) +} + +func (j *baseJoinProbe) SetChunkForProbe(chk *chunk.Chunk) (err error) { + if j.currentChunk != nil { + if j.currentProbeRow < j.chunkRows { + return errors.New("Previous chunk is not probed yet") + } + } + j.currentChunk = chk + logicalRows := chk.NumRows() + // if chk.sel != nil, then physicalRows is different from logicalRows + physicalRows := chk.Column(0).Rows() + j.usedRows = chk.Sel() + if j.usedRows == nil { + if cap(j.selRows) >= logicalRows { + j.selRows = j.selRows[:logicalRows] + } else { + j.selRows = make([]int, 0, logicalRows) + for i := 0; i < logicalRows; i++ { + j.selRows = append(j.selRows, i) + } + } + j.usedRows = j.selRows + } + j.chunkRows = logicalRows + if cap(j.matchedRowsHeaders) >= logicalRows { + j.matchedRowsHeaders = j.matchedRowsHeaders[:logicalRows] + } else { + j.matchedRowsHeaders = make([]uintptr, logicalRows) + } + for i := 0; i < int(j.ctx.partitionNumber); i++ { + j.hashValues[i] = j.hashValues[i][:0] + } + if j.ctx.ProbeFilter != nil { + if cap(j.filterVector) >= physicalRows { + j.filterVector = j.filterVector[:physicalRows] + } else { + j.filterVector = make([]bool, physicalRows) + } + } + if j.hasNullableKey { + if cap(j.nullKeyVector) >= physicalRows { + j.nullKeyVector = j.nullKeyVector[:physicalRows] + } else { + j.nullKeyVector = make([]bool, physicalRows) + } + for i := 0; i < physicalRows; i++ { + j.nullKeyVector[i] = false + } + } + if cap(j.serializedKeys) >= logicalRows { + j.serializedKeys = j.serializedKeys[:logicalRows] + } else { + j.serializedKeys = make([][]byte, logicalRows) + } + for i := 0; i < logicalRows; i++ { + j.serializedKeys[i] = j.serializedKeys[i][:0] + } + if j.ctx.ProbeFilter != nil { + j.filterVector, err = expression.VectorizedFilter(j.ctx.SessCtx.GetExprCtx().GetEvalCtx(), j.ctx.SessCtx.GetSessionVars().EnableVectorizedExpression, j.ctx.ProbeFilter, chunk.NewIterator4Chunk(j.currentChunk), j.filterVector) + if err != nil { + return err + } + } + + // generate serialized key + for i, index := range j.keyIndex { + err = codec.SerializeKeys(j.ctx.SessCtx.GetSessionVars().StmtCtx.TypeCtx(), j.currentChunk, j.keyTypes[i], index, j.usedRows, j.filterVector, j.nullKeyVector, j.ctx.hashTableMeta.serializeModes[i], j.serializedKeys) + if err != nil { + return err + } + } + // generate hash value + hash := fnv.New64() + for logicalRowIndex, physicalRowIndex := range j.usedRows { + if (j.filterVector != nil && !j.filterVector[physicalRowIndex]) || (j.nullKeyVector != nil && j.nullKeyVector[physicalRowIndex]) { + // explicit set the matchedRowsHeaders[logicalRowIndex] to nil to indicate there 
is no matched rows + j.matchedRowsHeaders[logicalRowIndex] = 0 + continue + } + hash.Reset() + // As the golang doc described, `Hash.Write` never returns an error. + // See https://golang.org/pkg/hash/#Hash + _, _ = hash.Write(j.serializedKeys[logicalRowIndex]) + hashValue := hash.Sum64() + partIndex := hashValue >> j.ctx.partitionMaskOffset + j.hashValues[partIndex] = append(j.hashValues[partIndex], posAndHashValue{hashValue: hashValue, pos: logicalRowIndex}) + } + j.currentProbeRow = 0 + for i := 0; i < int(j.ctx.partitionNumber); i++ { + for index := range j.hashValues[i] { + j.matchedRowsHeaders[j.hashValues[i][index].pos] = j.ctx.hashTableContext.hashTable.tables[i].lookup(j.hashValues[i][index].hashValue) + } + } + return +} + +func (j *baseJoinProbe) finishLookupCurrentProbeRow() { + if j.matchedRowsForCurrentProbeRow > 0 { + j.offsetAndLengthArray = append(j.offsetAndLengthArray, offsetAndLength{offset: j.usedRows[j.currentProbeRow], length: j.matchedRowsForCurrentProbeRow}) + } + j.matchedRowsForCurrentProbeRow = 0 +} + +func checkSQLKiller(killer *sqlkiller.SQLKiller, fpName string) error { + err := killer.HandleSignal() + failpoint.Inject(fpName, func(val failpoint.Value) { + if val.(bool) { + err = exeerrors.ErrQueryInterrupted + } + }) + return err +} + +func (j *baseJoinProbe) appendBuildRowToCachedBuildRowsAndConstructBuildRowsIfNeeded(buildRow *matchedRowInfo, chk *chunk.Chunk, currentColumnIndexInRow int, forOtherCondition bool) { + j.cachedBuildRows = append(j.cachedBuildRows, buildRow) + if len(j.cachedBuildRows) >= batchBuildRowSize { + j.batchConstructBuildRows(chk, currentColumnIndexInRow, forOtherCondition) + } +} + +func (j *baseJoinProbe) batchConstructBuildRows(chk *chunk.Chunk, currentColumnIndexInRow int, forOtherCondition bool) { + j.appendBuildRowToChunk(chk, currentColumnIndexInRow, forOtherCondition) + if forOtherCondition { + j.rowIndexInfos = append(j.rowIndexInfos, j.cachedBuildRows...) 
+ } + j.cachedBuildRows = j.cachedBuildRows[:0] +} + +func (j *baseJoinProbe) prepareForProbe(chk *chunk.Chunk) (joinedChk *chunk.Chunk, remainCap int, err error) { + j.offsetAndLengthArray = j.offsetAndLengthArray[:0] + j.cachedBuildRows = j.cachedBuildRows[:0] + j.matchedRowsForCurrentProbeRow = 0 + joinedChk = chk + if j.ctx.OtherCondition != nil { + j.tmpChk.Reset() + j.rowIndexInfos = j.rowIndexInfos[:0] + j.selected = j.selected[:0] + joinedChk = j.tmpChk + } + return joinedChk, chk.RequiredRows() - chk.NumRows(), nil +} + +func (j *baseJoinProbe) appendBuildRowToChunk(chk *chunk.Chunk, currentColumnIndexInRow int, forOtherCondition bool) { + if j.rightAsBuildSide { + if forOtherCondition { + j.appendBuildRowToChunkInternal(chk, j.rUsedInOtherCondition, true, j.currentChunk.NumCols(), currentColumnIndexInRow) + } else { + j.appendBuildRowToChunkInternal(chk, j.rUsed, false, len(j.lUsed), currentColumnIndexInRow) + } + } else { + if forOtherCondition { + j.appendBuildRowToChunkInternal(chk, j.lUsedInOtherCondition, true, 0, currentColumnIndexInRow) + } else { + j.appendBuildRowToChunkInternal(chk, j.lUsed, false, 0, currentColumnIndexInRow) + } + } +} + +func (j *baseJoinProbe) appendBuildRowToChunkInternal(chk *chunk.Chunk, usedCols []int, forOtherCondition bool, colOffset int, currentColumnInRow int) { + chkRows := chk.NumRows() + needUpdateVirtualRow := currentColumnInRow == 0 + if len(usedCols) == 0 || len(j.cachedBuildRows) == 0 { + if needUpdateVirtualRow { + chk.SetNumVirtualRows(chkRows + len(j.cachedBuildRows)) + } + return + } + for i := 0; i < len(j.cachedBuildRows); i++ { + if j.cachedBuildRows[i].buildRowOffset == 0 { + j.ctx.hashTableMeta.advanceToRowData(j.cachedBuildRows[i]) + } + } + colIndexMap := make(map[int]int) + for index, value := range usedCols { + if forOtherCondition { + colIndexMap[value] = value + colOffset + } else { + colIndexMap[value] = index + colOffset + } + } + meta := j.ctx.hashTableMeta + columnsToAppend := len(meta.rowColumnsOrder) + if forOtherCondition { + columnsToAppend = meta.columnCountNeededForOtherCondition + if j.ctx.RightAsBuildSide { + for _, value := range j.rUsed { + colIndexMap[value] = value + colOffset + } + } else { + for _, value := range j.lUsed { + colIndexMap[value] = value + colOffset + } + } + } + for columnIndex := currentColumnInRow; columnIndex < len(meta.rowColumnsOrder) && columnIndex < columnsToAppend; columnIndex++ { + indexInDstChk, ok := colIndexMap[meta.rowColumnsOrder[columnIndex]] + var currentColumn *chunk.Column + if ok { + currentColumn = chk.Column(indexInDstChk) + for index := range j.cachedBuildRows { + currentColumn.AppendNullBitmap(!meta.isColumnNull(*(*unsafe.Pointer)(unsafe.Pointer(&j.cachedBuildRows[index].buildRowStart)), columnIndex)) + j.cachedBuildRows[index].buildRowOffset = chunk.AppendCellFromRawData(currentColumn, *(*unsafe.Pointer)(unsafe.Pointer(&j.cachedBuildRows[index].buildRowStart)), j.cachedBuildRows[index].buildRowOffset) + } + } else { + // not used so don't need to insert into chk, but still need to advance rowData + if meta.columnsSize[columnIndex] < 0 { + for index := range j.cachedBuildRows { + size := *(*uint64)(unsafe.Add(*(*unsafe.Pointer)(unsafe.Pointer(&j.cachedBuildRows[index].buildRowStart)), j.cachedBuildRows[index].buildRowOffset)) + j.cachedBuildRows[index].buildRowOffset += sizeOfLengthField + int(size) + } + } else { + for index := range j.cachedBuildRows { + j.cachedBuildRows[index].buildRowOffset += meta.columnsSize[columnIndex] + } + } + } + } + if 
needUpdateVirtualRow { + chk.SetNumVirtualRows(chkRows + len(j.cachedBuildRows)) + } +} + +func (j *baseJoinProbe) appendProbeRowToChunk(chk *chunk.Chunk, probeChk *chunk.Chunk) { + if j.rightAsBuildSide { + if j.ctx.hasOtherCondition() { + j.appendProbeRowToChunkInternal(chk, probeChk, j.lUsedInOtherCondition, 0, true) + } else { + j.appendProbeRowToChunkInternal(chk, probeChk, j.lUsed, 0, false) + } + } else { + if j.ctx.hasOtherCondition() { + j.appendProbeRowToChunkInternal(chk, probeChk, j.rUsedInOtherCondition, j.ctx.hashTableMeta.totalColumnNumber, true) + } else { + j.appendProbeRowToChunkInternal(chk, probeChk, j.rUsed, len(j.lUsed), false) + } + } +} + +func (j *baseJoinProbe) appendProbeRowToChunkInternal(chk *chunk.Chunk, probeChk *chunk.Chunk, used []int, collOffset int, forOtherCondition bool) { + if len(used) == 0 || len(j.offsetAndLengthArray) == 0 { + return + } + if forOtherCondition { + usedColumnMap := make(map[int]struct{}) + for _, colIndex := range used { + if _, ok := usedColumnMap[colIndex]; !ok { + srcCol := probeChk.Column(colIndex) + dstCol := chk.Column(colIndex + collOffset) + for _, offsetAndLength := range j.offsetAndLengthArray { + dstCol.AppendCellNTimes(srcCol, offsetAndLength.offset, offsetAndLength.length) + } + usedColumnMap[colIndex] = struct{}{} + } + } + } else { + for index, colIndex := range used { + srcCol := probeChk.Column(colIndex) + dstCol := chk.Column(index + collOffset) + for _, offsetAndLength := range j.offsetAndLengthArray { + dstCol.AppendCellNTimes(srcCol, offsetAndLength.offset, offsetAndLength.length) + } + } + } +} + +func (j *baseJoinProbe) buildResultAfterOtherCondition(chk *chunk.Chunk, joinedChk *chunk.Chunk) (err error) { + // construct the return chunk based on joinedChk and selected, there are 3 kinds of columns + // 1. columns already in joinedChk + // 2. columns from build side, but not in joinedChk + // 3. 
columns from probe side, but not in joinedChk + rowCount := chk.NumRows() + probeUsedColumns, probeColOffset, probeColOffsetInJoinedChk := j.lUsed, 0, 0 + if !j.rightAsBuildSide { + probeUsedColumns, probeColOffset, probeColOffsetInJoinedChk = j.rUsed, len(j.lUsed), j.ctx.hashTableMeta.totalColumnNumber + } + + for index, colIndex := range probeUsedColumns { + dstCol := chk.Column(index + probeColOffset) + if joinedChk.Column(colIndex+probeColOffsetInJoinedChk).Rows() > 0 { + // probe column that is already in joinedChk + srcCol := joinedChk.Column(colIndex + probeColOffsetInJoinedChk) + chunk.CopySelectedRows(dstCol, srcCol, j.selected) + } else { + // probe column that is not in joinedChk + srcCol := j.currentChunk.Column(colIndex) + chunk.CopySelectedRowsWithRowIDFunc(dstCol, srcCol, j.selected, 0, len(j.selected), func(i int) int { + return j.usedRows[j.rowIndexInfos[i].probeRowIndex] + }) + } + } + buildUsedColumns, buildColOffset, buildColOffsetInJoinedChk := j.lUsed, 0, 0 + if j.rightAsBuildSide { + buildUsedColumns, buildColOffset, buildColOffsetInJoinedChk = j.rUsed, len(j.lUsed), j.currentChunk.NumCols() + } + hasRemainCols := false + for index, colIndex := range buildUsedColumns { + dstCol := chk.Column(index + buildColOffset) + srcCol := joinedChk.Column(colIndex + buildColOffsetInJoinedChk) + if srcCol.Rows() > 0 { + // build column that is already in joinedChk + chunk.CopySelectedRows(dstCol, srcCol, j.selected) + } else { + hasRemainCols = true + } + } + if hasRemainCols { + j.cachedBuildRows = j.cachedBuildRows[:0] + // build column that is not in joinedChk + for index, result := range j.selected { + if result { + j.appendBuildRowToCachedBuildRowsAndConstructBuildRowsIfNeeded(j.rowIndexInfos[index], chk, j.ctx.hashTableMeta.columnCountNeededForOtherCondition, false) + } + } + if len(j.cachedBuildRows) > 0 { + j.batchConstructBuildRows(chk, j.ctx.hashTableMeta.columnCountNeededForOtherCondition, false) + } + } + rowsAdded := 0 + for _, result := range j.selected { + if result { + rowsAdded++ + } + } + chk.SetNumVirtualRows(rowCount + rowsAdded) + return +} + +func isKeyMatched(keyMode keyMode, serializedKey []byte, rowStart unsafe.Pointer, meta *TableMeta) bool { + switch keyMode { + case OneInt64: + return *(*int64)(unsafe.Pointer(&serializedKey[0])) == *(*int64)(unsafe.Add(rowStart, meta.nullMapLength+sizeOfNextPtr)) + case FixedSerializedKey: + return bytes.Equal(serializedKey, hack.GetBytesFromPtr(unsafe.Add(rowStart, meta.nullMapLength+sizeOfNextPtr), meta.joinKeysLength)) + case VariableSerializedKey: + return bytes.Equal(serializedKey, hack.GetBytesFromPtr(unsafe.Add(rowStart, meta.nullMapLength+sizeOfNextPtr+sizeOfLengthField), int(meta.getSerializedKeyLength(rowStart)))) + default: + panic("unknown key match type") + } +} + +// NewJoinProbe create a join probe used for hash join v2 +func NewJoinProbe(ctx *HashJoinCtxV2, workID uint, joinType core.JoinType, keyIndex []int, joinedColumnTypes, probeKeyTypes []*types.FieldType, rightAsBuildSide bool) ProbeV2 { + base := baseJoinProbe{ + ctx: ctx, + workID: workID, + keyIndex: keyIndex, + keyTypes: probeKeyTypes, + maxChunkSize: ctx.SessCtx.GetSessionVars().MaxChunkSize, + lUsed: ctx.LUsed, + rUsed: ctx.RUsed, + lUsedInOtherCondition: ctx.LUsedInOtherCondition, + rUsedInOtherCondition: ctx.RUsedInOtherCondition, + rightAsBuildSide: rightAsBuildSide, + } + for i := range keyIndex { + if !mysql.HasNotNullFlag(base.keyTypes[i].GetFlag()) { + base.hasNullableKey = true + } + } + base.cachedBuildRows = make([]*matchedRowInfo, 
0, batchBuildRowSize) + base.matchedRowsHeaders = make([]uintptr, 0, chunk.InitialCapacity) + base.selRows = make([]int, 0, chunk.InitialCapacity) + for i := 0; i < chunk.InitialCapacity; i++ { + base.selRows = append(base.selRows, i) + } + base.hashValues = make([][]posAndHashValue, ctx.partitionNumber) + for i := 0; i < int(ctx.partitionNumber); i++ { + base.hashValues[i] = make([]posAndHashValue, 0, chunk.InitialCapacity) + } + base.serializedKeys = make([][]byte, 0, chunk.InitialCapacity) + if base.ctx.ProbeFilter != nil { + base.filterVector = make([]bool, 0, chunk.InitialCapacity) + } + if base.hasNullableKey { + base.nullKeyVector = make([]bool, 0, chunk.InitialCapacity) + } + if base.ctx.OtherCondition != nil { + base.tmpChk = chunk.NewChunkWithCapacity(joinedColumnTypes, chunk.InitialCapacity) + base.tmpChk.SetInCompleteChunk(true) + base.selected = make([]bool, 0, chunk.InitialCapacity) + base.rowIndexInfos = make([]*matchedRowInfo, 0, chunk.InitialCapacity) + } + switch joinType { + case core.InnerJoin: + return &innerJoinProbe{base} + case core.LeftOuterJoin: + return newOuterJoinProbe(base, !rightAsBuildSide, rightAsBuildSide) + case core.RightOuterJoin: + return newOuterJoinProbe(base, rightAsBuildSide, rightAsBuildSide) + default: + panic("unsupported join type") + } +} + +type mockJoinProbe struct { + baseJoinProbe +} + +func (*mockJoinProbe) SetChunkForProbe(*chunk.Chunk) error { + return errors.New("not supported") +} + +func (*mockJoinProbe) Probe(*hashjoinWorkerResult, *sqlkiller.SQLKiller) (ok bool, result *hashjoinWorkerResult) { + panic("not supported") +} + +func (*mockJoinProbe) ScanRowTable(*hashjoinWorkerResult, *sqlkiller.SQLKiller) (result *hashjoinWorkerResult) { + panic("not supported") +} + +func (*mockJoinProbe) IsScanRowTableDone() bool { + panic("not supported") +} + +func (*mockJoinProbe) NeedScanRowTable() bool { + panic("not supported") +} + +func (*mockJoinProbe) InitForScanRowTable() { + panic("not supported") +} + +// used for test +func newMockJoinProbe(ctx *HashJoinCtxV2) *mockJoinProbe { + base := baseJoinProbe{ + ctx: ctx, + lUsed: ctx.LUsed, + rUsed: ctx.RUsed, + lUsedInOtherCondition: ctx.LUsedInOtherCondition, + rUsedInOtherCondition: ctx.RUsedInOtherCondition, + rightAsBuildSide: false, + } + return &mockJoinProbe{base} +} diff --git a/pkg/executor/join/binding__failpoint_binding__.go b/pkg/executor/join/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..6560aaf35d7c5 --- /dev/null +++ b/pkg/executor/join/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package join + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/executor/join/hash_join_base.go b/pkg/executor/join/hash_join_base.go index 7947b96653e6c..ecf5586e34114 100644 --- a/pkg/executor/join/hash_join_base.go +++ b/pkg/executor/join/hash_join_base.go @@ -168,7 +168,7 @@ func (fetcher *probeSideTupleFetcherBase) fetchProbeSideChunks(ctx context.Conte } probeSideResult := probeSideResource.chk err := exec.Next(ctx, fetcher.ProbeSideExec, probeSideResult) - failpoint.Inject("ConsumeRandomPanic", nil) + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) if err != nil { hashJoinCtx.joinResultCh <- &hashjoinWorkerResult{ err: err, @@ -176,11 +176,11 @@ func (fetcher 
*probeSideTupleFetcherBase) fetchProbeSideChunks(ctx context.Conte return } if !hasWaitedForBuild { - failpoint.Inject("issue30289", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("issue30289")); _err_ == nil { if val.(bool) { probeSideResult.Reset() } - }) + } skipProbe := wait4BuildSide(isBuildEmpty, canSkipIfBuildEmpty, needScanAfterProbeDone, hashJoinCtx) if skipProbe { // there is no need to probe, so just return @@ -223,14 +223,14 @@ type buildWorkerBase struct { func (w *buildWorkerBase) fetchBuildSideRows(ctx context.Context, hashJoinCtx *hashJoinCtxBase, chkCh chan<- *chunk.Chunk, errCh chan<- error, doneCh <-chan struct{}) { defer close(chkCh) var err error - failpoint.Inject("issue30289", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("issue30289")); _err_ == nil { if val.(bool) { err = errors.Errorf("issue30289 build return error") errCh <- errors.Trace(err) return } - }) - failpoint.Inject("issue42662_1", func(val failpoint.Value) { + } + if val, _err_ := failpoint.Eval(_curpkg_("issue42662_1")); _err_ == nil { if val.(bool) { if hashJoinCtx.SessCtx.GetSessionVars().ConnectionID != 0 { // consume 170MB memory, this sql should be tracked into MemoryTop1Tracker @@ -238,30 +238,30 @@ func (w *buildWorkerBase) fetchBuildSideRows(ctx context.Context, hashJoinCtx *h } return } - }) + } sessVars := hashJoinCtx.SessCtx.GetSessionVars() - failpoint.Inject("issue51998", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("issue51998")); _err_ == nil { if val.(bool) { time.Sleep(2 * time.Second) } - }) + } for { if hashJoinCtx.finished.Load() { return } chk := hashJoinCtx.ChunkAllocPool.Alloc(w.BuildSideExec.RetFieldTypes(), sessVars.MaxChunkSize, sessVars.MaxChunkSize) err = exec.Next(ctx, w.BuildSideExec, chk) - failpoint.Inject("issue51998", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("issue51998")); _err_ == nil { if val.(bool) { err = errors.Errorf("issue51998 build return error") } - }) + } if err != nil { errCh <- errors.Trace(err) return } - failpoint.Inject("errorFetchBuildSideRowsMockOOMPanic", nil) - failpoint.Inject("ConsumeRandomPanic", nil) + failpoint.Eval(_curpkg_("errorFetchBuildSideRowsMockOOMPanic")) + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) if chk.NumRows() == 0 { return } diff --git a/pkg/executor/join/hash_join_base.go__failpoint_stash__ b/pkg/executor/join/hash_join_base.go__failpoint_stash__ new file mode 100644 index 0000000000000..7947b96653e6c --- /dev/null +++ b/pkg/executor/join/hash_join_base.go__failpoint_stash__ @@ -0,0 +1,379 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
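The binding file and the rewritten hunks above follow the failpoint-ctl convention: each failpoint.Inject marker is expanded into a failpoint.Eval call keyed by the package-qualified name that _curpkg_ builds (pkgpath + "/" + name), while the original source is kept in a __failpoint_stash__ copy. A minimal, self-contained sketch of the two equivalent spellings; the qualified name and package below are illustrative placeholders, not paths from this repository:

package main

import (
	"fmt"

	"github.com/pingcap/failpoint"
)

// fpName mirrors what _curpkg_("issue30289") would produce for some package.
const fpName = "github.com/example/project/pkg/demo/issue30289"

func step() {
	// Marker form (as kept in the __failpoint_stash__ files): a no-op at runtime
	// unless failpoint-ctl has rewritten the source.
	failpoint.Inject("issue30289", func(val failpoint.Value) {
		if val.(bool) {
			fmt.Println("marker form fired")
		}
	})

	// Expanded form (what the patched hash_join_base.go now contains): Eval returns
	// an error when the failpoint is disabled, so the guarded block is simply skipped.
	if val, _err_ := failpoint.Eval(fpName); _err_ == nil {
		if val.(bool) {
			fmt.Println("expanded form fired")
		}
	}
}

func main() {
	// Activate the failpoint so the expanded form runs its body.
	if err := failpoint.Enable(fpName, "return(true)"); err != nil {
		panic(err)
	}
	defer failpoint.Disable(fpName)
	step()
}

Marker-only injections with a nil body, such as failpoint.Inject("ConsumeRandomPanic", nil), become bare failpoint.Eval(_curpkg_("ConsumeRandomPanic")) calls whose result is ignored, as seen in the hunks above.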
+ +package join + +import ( + "bytes" + "context" + "fmt" + "strconv" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/disk" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/memory" +) + +// hashjoinWorkerResult stores the result of join workers, +// `src` is for Chunk reuse: the main goroutine will get the join result chunk `chk`, +// and push `chk` into `src` after processing, join worker goroutines get the empty chunk from `src` +// and push new data into this chunk. +type hashjoinWorkerResult struct { + chk *chunk.Chunk + err error + src chan<- *chunk.Chunk +} + +type hashJoinCtxBase struct { + SessCtx sessionctx.Context + ChunkAllocPool chunk.Allocator + // Concurrency is the number of partition, build and join workers. + Concurrency uint + joinResultCh chan *hashjoinWorkerResult + // closeCh add a lock for closing executor. + closeCh chan struct{} + finished atomic.Bool + IsNullEQ []bool + buildFinished chan error + JoinType plannercore.JoinType + IsNullAware bool + memTracker *memory.Tracker // track memory usage. + diskTracker *disk.Tracker // track disk usage. +} + +type probeSideTupleFetcherBase struct { + ProbeSideExec exec.Executor + probeChkResourceCh chan *probeChkResource + probeResultChs []chan *chunk.Chunk + requiredRows int64 + joinResultChannel chan *hashjoinWorkerResult +} + +func (fetcher *probeSideTupleFetcherBase) initializeForProbeBase(concurrency uint, joinResultChannel chan *hashjoinWorkerResult) { + // fetcher.probeResultChs is for transmitting the chunks which store the data of + // ProbeSideExec, it'll be written by probe side worker goroutine, and read by join + // workers. + fetcher.probeResultChs = make([]chan *chunk.Chunk, concurrency) + for i := uint(0); i < concurrency; i++ { + fetcher.probeResultChs[i] = make(chan *chunk.Chunk, 1) + } + // fetcher.probeChkResourceCh is for transmitting the used ProbeSideExec chunks from + // join workers to ProbeSideExec worker. 
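The two channel pools wired up here form a buffer-recycling handshake: the probe-side fetcher may only fill a chunk that a join worker has already handed back, so the number of chunks in flight stays bounded. A stripped-down sketch of the same pattern using plain slices; the names are illustrative and not taken from the executor:

package main

import "fmt"

// resource plays the role of probeChkResource: a reusable buffer plus the
// channel it should be delivered on once filled.
type resource struct {
	buf  []int
	dest chan<- []int
}

func main() {
	free := make(chan *resource, 2) // recycled buffers, in the spirit of probeChkResourceCh
	work := make(chan []int, 2)     // filled buffers for a worker, like one probeResultChs entry

	for i := 0; i < 2; i++ {
		free <- &resource{buf: make([]int, 0, 4), dest: work}
	}

	// Fetcher: reuse a returned buffer for each batch instead of allocating a new one.
	go func() {
		for batch := 0; batch < 3; batch++ {
			r := <-free
			r.buf = append(r.buf[:0], batch, batch+1)
			r.dest <- r.buf
		}
		close(work)
	}()

	// Worker: consume the buffer, then hand it back so the fetcher can refill it.
	for buf := range work {
		fmt.Println("worker got", buf)
		free <- &resource{buf: buf, dest: work}
	}
}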
+ fetcher.probeChkResourceCh = make(chan *probeChkResource, concurrency) + for i := uint(0); i < concurrency; i++ { + fetcher.probeChkResourceCh <- &probeChkResource{ + chk: exec.NewFirstChunk(fetcher.ProbeSideExec), + dest: fetcher.probeResultChs[i], + } + } + fetcher.joinResultChannel = joinResultChannel +} + +func (fetcher *probeSideTupleFetcherBase) handleProbeSideFetcherPanic(r any) { + for i := range fetcher.probeResultChs { + close(fetcher.probeResultChs[i]) + } + if r != nil { + fetcher.joinResultChannel <- &hashjoinWorkerResult{err: util.GetRecoverError(r)} + } +} + +type isBuildSideEmpty func() bool + +func wait4BuildSide(isBuildEmpty isBuildSideEmpty, canSkipIfBuildEmpty, needScanAfterProbeDone bool, hashJoinCtx *hashJoinCtxBase) (skipProbe bool) { + var err error + skipProbe = false + buildFinishes := false + select { + case <-hashJoinCtx.closeCh: + // current executor is closed, no need to probe + skipProbe = true + case err = <-hashJoinCtx.buildFinished: + if err != nil { + // build meet error, no need to probe + skipProbe = true + } else { + buildFinishes = true + } + } + // only check build empty if build finishes + if buildFinishes && isBuildEmpty() && canSkipIfBuildEmpty { + // if build side is empty, can skip probe if canSkipIfBuildEmpty is true(e.g. inner join) + skipProbe = true + } + if err != nil { + // if err is not nil, send out the error + hashJoinCtx.joinResultCh <- &hashjoinWorkerResult{ + err: err, + } + } else if skipProbe { + // if skipProbe is true and there is no need to scan hash table after probe, just the whole hash join is finished + if !needScanAfterProbeDone { + hashJoinCtx.finished.Store(true) + } + } + return skipProbe +} + +func (fetcher *probeSideTupleFetcherBase) getProbeSideResource(shouldLimitProbeFetchSize bool, maxChunkSize int, hashJoinCtx *hashJoinCtxBase) *probeChkResource { + if hashJoinCtx.finished.Load() { + return nil + } + + var probeSideResource *probeChkResource + var ok bool + select { + case <-hashJoinCtx.closeCh: + return nil + case probeSideResource, ok = <-fetcher.probeChkResourceCh: + if !ok { + return nil + } + } + if shouldLimitProbeFetchSize { + required := int(atomic.LoadInt64(&fetcher.requiredRows)) + probeSideResource.chk.SetRequiredRows(required, maxChunkSize) + } + return probeSideResource +} + +// fetchProbeSideChunks get chunks from fetches chunks from the big table in a background goroutine +// and sends the chunks to multiple channels which will be read by multiple join workers. 
+func (fetcher *probeSideTupleFetcherBase) fetchProbeSideChunks(ctx context.Context, maxChunkSize int, isBuildEmpty isBuildSideEmpty, canSkipIfBuildEmpty, needScanAfterProbeDone, shouldLimitProbeFetchSize bool, hashJoinCtx *hashJoinCtxBase) { + hasWaitedForBuild := false + for { + probeSideResource := fetcher.getProbeSideResource(shouldLimitProbeFetchSize, maxChunkSize, hashJoinCtx) + if probeSideResource == nil { + return + } + probeSideResult := probeSideResource.chk + err := exec.Next(ctx, fetcher.ProbeSideExec, probeSideResult) + failpoint.Inject("ConsumeRandomPanic", nil) + if err != nil { + hashJoinCtx.joinResultCh <- &hashjoinWorkerResult{ + err: err, + } + return + } + if !hasWaitedForBuild { + failpoint.Inject("issue30289", func(val failpoint.Value) { + if val.(bool) { + probeSideResult.Reset() + } + }) + skipProbe := wait4BuildSide(isBuildEmpty, canSkipIfBuildEmpty, needScanAfterProbeDone, hashJoinCtx) + if skipProbe { + // there is no need to probe, so just return + return + } + hasWaitedForBuild = true + } + + if probeSideResult.NumRows() == 0 { + return + } + + probeSideResource.dest <- probeSideResult + } +} + +type probeWorkerBase struct { + WorkerID uint + probeChkResourceCh chan *probeChkResource + joinChkResourceCh chan *chunk.Chunk + probeResultCh chan *chunk.Chunk +} + +func (worker *probeWorkerBase) initializeForProbe(probeChkResourceCh chan *probeChkResource, probeResultCh chan *chunk.Chunk, joinExec exec.Executor) { + // worker.joinChkResourceCh is for transmitting the reused join result chunks + // from the main thread to probe worker goroutines. + worker.joinChkResourceCh = make(chan *chunk.Chunk, 1) + worker.joinChkResourceCh <- exec.NewFirstChunk(joinExec) + worker.probeChkResourceCh = probeChkResourceCh + worker.probeResultCh = probeResultCh +} + +type buildWorkerBase struct { + BuildSideExec exec.Executor + BuildKeyColIdx []int +} + +// fetchBuildSideRows fetches all rows from build side executor, and append them +// to e.buildSideResult. 
+func (w *buildWorkerBase) fetchBuildSideRows(ctx context.Context, hashJoinCtx *hashJoinCtxBase, chkCh chan<- *chunk.Chunk, errCh chan<- error, doneCh <-chan struct{}) { + defer close(chkCh) + var err error + failpoint.Inject("issue30289", func(val failpoint.Value) { + if val.(bool) { + err = errors.Errorf("issue30289 build return error") + errCh <- errors.Trace(err) + return + } + }) + failpoint.Inject("issue42662_1", func(val failpoint.Value) { + if val.(bool) { + if hashJoinCtx.SessCtx.GetSessionVars().ConnectionID != 0 { + // consume 170MB memory, this sql should be tracked into MemoryTop1Tracker + hashJoinCtx.memTracker.Consume(170 * 1024 * 1024) + } + return + } + }) + sessVars := hashJoinCtx.SessCtx.GetSessionVars() + failpoint.Inject("issue51998", func(val failpoint.Value) { + if val.(bool) { + time.Sleep(2 * time.Second) + } + }) + for { + if hashJoinCtx.finished.Load() { + return + } + chk := hashJoinCtx.ChunkAllocPool.Alloc(w.BuildSideExec.RetFieldTypes(), sessVars.MaxChunkSize, sessVars.MaxChunkSize) + err = exec.Next(ctx, w.BuildSideExec, chk) + failpoint.Inject("issue51998", func(val failpoint.Value) { + if val.(bool) { + err = errors.Errorf("issue51998 build return error") + } + }) + if err != nil { + errCh <- errors.Trace(err) + return + } + failpoint.Inject("errorFetchBuildSideRowsMockOOMPanic", nil) + failpoint.Inject("ConsumeRandomPanic", nil) + if chk.NumRows() == 0 { + return + } + select { + case <-doneCh: + return + case <-hashJoinCtx.closeCh: + return + case chkCh <- chk: + } + } +} + +// probeChkResource stores the result of the join probe side fetch worker, +// `dest` is for Chunk reuse: after join workers process the probe side chunk which is read from `dest`, +// they'll store the used chunk as `chk`, and then the probe side fetch worker will put new data into `chk` and write `chk` into dest. +type probeChkResource struct { + chk *chunk.Chunk + dest chan<- *chunk.Chunk +} + +type hashJoinRuntimeStats struct { + fetchAndBuildHashTable time.Duration + hashStat hashStatistic + fetchAndProbe int64 + probe int64 + concurrent int + maxFetchAndProbe int64 +} + +func (e *hashJoinRuntimeStats) setMaxFetchAndProbeTime(t int64) { + for { + value := atomic.LoadInt64(&e.maxFetchAndProbe) + if t <= value { + return + } + if atomic.CompareAndSwapInt64(&e.maxFetchAndProbe, value, t) { + return + } + } +} + +// Tp implements the RuntimeStats interface. 
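The setMaxFetchAndProbeTime helper above uses the standard lock-free "atomic max" idiom: load, give up if the candidate is not larger, otherwise CompareAndSwap and retry on contention. A small self-contained sketch of that idiom; storeMax and maxSeen are illustrative names, not identifiers from this package:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// storeMax records v into *dst only if it is larger than the current value,
// retrying whenever another goroutine wins the CompareAndSwap race.
func storeMax(dst *int64, v int64) {
	for {
		cur := atomic.LoadInt64(dst)
		if v <= cur {
			return // an equal or larger value is already recorded
		}
		if atomic.CompareAndSwapInt64(dst, cur, v) {
			return
		}
		// CAS lost a race with another writer; reload and try again.
	}
}

func main() {
	var maxSeen int64
	var wg sync.WaitGroup
	for _, d := range []int64{7, 42, 13, 42, 5} {
		wg.Add(1)
		go func(d int64) {
			defer wg.Done()
			storeMax(&maxSeen, d)
		}(d)
	}
	wg.Wait()
	fmt.Println(atomic.LoadInt64(&maxSeen)) // prints 42
}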
+func (*hashJoinRuntimeStats) Tp() int { + return execdetails.TpHashJoinRuntimeStats +} + +func (e *hashJoinRuntimeStats) String() string { + buf := bytes.NewBuffer(make([]byte, 0, 128)) + if e.fetchAndBuildHashTable > 0 { + buf.WriteString("build_hash_table:{total:") + buf.WriteString(execdetails.FormatDuration(e.fetchAndBuildHashTable)) + buf.WriteString(", fetch:") + buf.WriteString(execdetails.FormatDuration(e.fetchAndBuildHashTable - e.hashStat.buildTableElapse)) + buf.WriteString(", build:") + buf.WriteString(execdetails.FormatDuration(e.hashStat.buildTableElapse)) + buf.WriteString("}") + } + if e.probe > 0 { + buf.WriteString(", probe:{concurrency:") + buf.WriteString(strconv.Itoa(e.concurrent)) + buf.WriteString(", total:") + buf.WriteString(execdetails.FormatDuration(time.Duration(e.fetchAndProbe))) + buf.WriteString(", max:") + buf.WriteString(execdetails.FormatDuration(time.Duration(atomic.LoadInt64(&e.maxFetchAndProbe)))) + buf.WriteString(", probe:") + buf.WriteString(execdetails.FormatDuration(time.Duration(e.probe))) + // fetch time is the time wait fetch result from its child executor, + // wait time is the time wait its parent executor to fetch the joined result + buf.WriteString(", fetch and wait:") + buf.WriteString(execdetails.FormatDuration(time.Duration(e.fetchAndProbe - e.probe))) + if e.hashStat.probeCollision > 0 { + buf.WriteString(", probe_collision:") + buf.WriteString(strconv.FormatInt(e.hashStat.probeCollision, 10)) + } + buf.WriteString("}") + } + return buf.String() +} + +func (e *hashJoinRuntimeStats) Clone() execdetails.RuntimeStats { + return &hashJoinRuntimeStats{ + fetchAndBuildHashTable: e.fetchAndBuildHashTable, + hashStat: e.hashStat, + fetchAndProbe: e.fetchAndProbe, + probe: e.probe, + concurrent: e.concurrent, + maxFetchAndProbe: e.maxFetchAndProbe, + } +} + +func (e *hashJoinRuntimeStats) Merge(rs execdetails.RuntimeStats) { + tmp, ok := rs.(*hashJoinRuntimeStats) + if !ok { + return + } + e.fetchAndBuildHashTable += tmp.fetchAndBuildHashTable + e.hashStat.buildTableElapse += tmp.hashStat.buildTableElapse + e.hashStat.probeCollision += tmp.hashStat.probeCollision + e.fetchAndProbe += tmp.fetchAndProbe + e.probe += tmp.probe + if e.maxFetchAndProbe < tmp.maxFetchAndProbe { + e.maxFetchAndProbe = tmp.maxFetchAndProbe + } +} + +type hashStatistic struct { + // NOTE: probeCollision may be accessed from multiple goroutines concurrently. 
+ probeCollision int64 + buildTableElapse time.Duration +} + +func (s *hashStatistic) String() string { + return fmt.Sprintf("probe_collision:%v, build:%v", s.probeCollision, execdetails.FormatDuration(s.buildTableElapse)) +} diff --git a/pkg/executor/join/hash_join_v1.go b/pkg/executor/join/hash_join_v1.go index 649eac1eb466b..d414a594f8116 100644 --- a/pkg/executor/join/hash_join_v1.go +++ b/pkg/executor/join/hash_join_v1.go @@ -330,7 +330,7 @@ func (w *ProbeWorkerV1) runJoinWorker() { return case probeSideResult, ok = <-w.probeResultCh: } - failpoint.Inject("ConsumeRandomPanic", nil) + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) if !ok { break } @@ -859,11 +859,11 @@ func (w *ProbeWorkerV1) join2Chunk(probeSideChk *chunk.Chunk, hCtx *HashContext, for i := range selected { err := w.HashJoinCtx.SessCtx.GetSessionVars().SQLKiller.HandleSignal() - failpoint.Inject("killedInJoin2Chunk", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("killedInJoin2Chunk")); _err_ == nil { if val.(bool) { err = exeerrors.ErrQueryInterrupted } - }) + } if err != nil { joinResult.err = err return false, waitTime, joinResult @@ -938,11 +938,11 @@ func (w *ProbeWorkerV1) join2ChunkForOuterHashJoin(probeSideChk *chunk.Chunk, hC } for i := 0; i < probeSideChk.NumRows(); i++ { err := w.HashJoinCtx.SessCtx.GetSessionVars().SQLKiller.HandleSignal() - failpoint.Inject("killedInJoin2ChunkForOuterHashJoin", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("killedInJoin2ChunkForOuterHashJoin")); _err_ == nil { if val.(bool) { err = exeerrors.ErrQueryInterrupted } - }) + } if err != nil { joinResult.err = err return false, waitTime, joinResult @@ -1073,12 +1073,12 @@ func (w *BuildWorkerV1) BuildHashTableForList(buildSideResultCh <-chan *chunk.Ch rowContainer.GetDiskTracker().SetLabel(memory.LabelForBuildSideResult) if variable.EnableTmpStorageOnOOM.Load() { actionSpill := rowContainer.ActionSpill() - failpoint.Inject("testRowContainerSpill", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("testRowContainerSpill")); _err_ == nil { if val.(bool) { actionSpill = rowContainer.rowContainer.ActionSpillForTest() defer actionSpill.(*chunk.SpillDiskAction).WaitForTest() } - }) + } w.HashJoinCtx.SessCtx.GetSessionVars().MemTracker.FallbackOldAndSetNewAction(actionSpill) } for chk := range buildSideResultCh { @@ -1101,7 +1101,7 @@ func (w *BuildWorkerV1) BuildHashTableForList(buildSideResultCh <-chan *chunk.Ch err = rowContainer.PutChunkSelected(chk, selected, w.HashJoinCtx.IsNullEQ) } } - failpoint.Inject("ConsumeRandomPanic", nil) + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) if err != nil { return err } diff --git a/pkg/executor/join/hash_join_v1.go__failpoint_stash__ b/pkg/executor/join/hash_join_v1.go__failpoint_stash__ new file mode 100644 index 0000000000000..649eac1eb466b --- /dev/null +++ b/pkg/executor/join/hash_join_v1.go__failpoint_stash__ @@ -0,0 +1,1434 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package join + +import ( + "bytes" + "context" + "fmt" + "runtime/trace" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/executor/aggregate" + "github.com/pingcap/tidb/pkg/executor/internal/applycache" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/executor/unionexec" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/parser/terror" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/bitmap" + "github.com/pingcap/tidb/pkg/util/channel" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" + "github.com/pingcap/tidb/pkg/util/disk" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/memory" +) + +var ( + _ exec.Executor = &HashJoinV1Exec{} + _ exec.Executor = &NestedLoopApplyExec{} +) + +// HashJoinCtxV1 is the context used in hash join +type HashJoinCtxV1 struct { + hashJoinCtxBase + UseOuterToBuild bool + IsOuterJoin bool + RowContainer *hashRowContainer + outerMatchedStatus []*bitmap.ConcurrentBitmap + ProbeTypes []*types.FieldType + BuildTypes []*types.FieldType + OuterFilter expression.CNFExprs + stats *hashJoinRuntimeStats +} + +// ProbeSideTupleFetcherV1 reads tuples from ProbeSideExec and send them to ProbeWorkers. +type ProbeSideTupleFetcherV1 struct { + probeSideTupleFetcherBase + *HashJoinCtxV1 +} + +// ProbeWorkerV1 is the probe side worker in hash join +type ProbeWorkerV1 struct { + probeWorkerBase + HashJoinCtx *HashJoinCtxV1 + ProbeKeyColIdx []int + ProbeNAKeyColIdx []int + // We pre-alloc and reuse the Rows and RowPtrs for each probe goroutine, to avoid allocation frequently + buildSideRows []chunk.Row + buildSideRowPtrs []chunk.RowPtr + + // We build individual joiner for each join worker when use chunk-based + // execution, to avoid the concurrency of joiner.chk and joiner.selected. + Joiner Joiner + rowIters *chunk.Iterator4Slice + rowContainerForProbe *hashRowContainer + // for every naaj probe worker, pre-allocate the int slice for store the join column index to check. + needCheckBuildColPos []int + needCheckProbeColPos []int + needCheckBuildTypes []*types.FieldType + needCheckProbeTypes []*types.FieldType +} + +// BuildWorkerV1 is the build side worker in hash join +type BuildWorkerV1 struct { + buildWorkerBase + HashJoinCtx *HashJoinCtxV1 + BuildNAKeyColIdx []int +} + +// HashJoinV1Exec implements the hash join algorithm. +type HashJoinV1Exec struct { + exec.BaseExecutor + *HashJoinCtxV1 + + ProbeSideTupleFetcher *ProbeSideTupleFetcherV1 + ProbeWorkers []*ProbeWorkerV1 + BuildWorker *BuildWorkerV1 + + workerWg util.WaitGroupWrapper + waiterWg util.WaitGroupWrapper + + Prepared bool +} + +// Close implements the Executor Close interface. 
+func (e *HashJoinV1Exec) Close() error { + if e.closeCh != nil { + close(e.closeCh) + } + e.finished.Store(true) + if e.Prepared { + if e.buildFinished != nil { + channel.Clear(e.buildFinished) + } + if e.joinResultCh != nil { + channel.Clear(e.joinResultCh) + } + if e.ProbeSideTupleFetcher.probeChkResourceCh != nil { + close(e.ProbeSideTupleFetcher.probeChkResourceCh) + channel.Clear(e.ProbeSideTupleFetcher.probeChkResourceCh) + } + for i := range e.ProbeSideTupleFetcher.probeResultChs { + channel.Clear(e.ProbeSideTupleFetcher.probeResultChs[i]) + } + for i := range e.ProbeWorkers { + close(e.ProbeWorkers[i].joinChkResourceCh) + channel.Clear(e.ProbeWorkers[i].joinChkResourceCh) + } + e.ProbeSideTupleFetcher.probeChkResourceCh = nil + terror.Call(e.RowContainer.Close) + e.HashJoinCtxV1.SessCtx.GetSessionVars().MemTracker.UnbindActionFromHardLimit(e.RowContainer.ActionSpill()) + e.waiterWg.Wait() + } + e.outerMatchedStatus = e.outerMatchedStatus[:0] + for _, w := range e.ProbeWorkers { + w.buildSideRows = nil + w.buildSideRowPtrs = nil + w.needCheckBuildColPos = nil + w.needCheckProbeColPos = nil + w.needCheckBuildTypes = nil + w.needCheckProbeTypes = nil + w.joinChkResourceCh = nil + } + + if e.stats != nil && e.RowContainer != nil { + e.stats.hashStat = *e.RowContainer.stat + } + if e.stats != nil { + defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), e.stats) + } + err := e.BaseExecutor.Close() + return err +} + +// Open implements the Executor Open interface. +func (e *HashJoinV1Exec) Open(ctx context.Context) error { + if err := e.BaseExecutor.Open(ctx); err != nil { + e.closeCh = nil + e.Prepared = false + return err + } + e.Prepared = false + if e.HashJoinCtxV1.memTracker != nil { + e.HashJoinCtxV1.memTracker.Reset() + } else { + e.HashJoinCtxV1.memTracker = memory.NewTracker(e.ID(), -1) + } + e.HashJoinCtxV1.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) + + if e.HashJoinCtxV1.diskTracker != nil { + e.HashJoinCtxV1.diskTracker.Reset() + } else { + e.HashJoinCtxV1.diskTracker = disk.NewTracker(e.ID(), -1) + } + e.HashJoinCtxV1.diskTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.DiskTracker) + + e.workerWg = util.WaitGroupWrapper{} + e.waiterWg = util.WaitGroupWrapper{} + e.closeCh = make(chan struct{}) + e.finished.Store(false) + + if e.RuntimeStats() != nil { + e.stats = &hashJoinRuntimeStats{ + concurrent: int(e.Concurrency), + } + } + return nil +} + +func (e *HashJoinV1Exec) initializeForProbe() { + e.ProbeSideTupleFetcher.HashJoinCtxV1 = e.HashJoinCtxV1 + // e.joinResultCh is for transmitting the join result chunks to the main + // thread. 
+ e.joinResultCh = make(chan *hashjoinWorkerResult, e.Concurrency+1) + e.ProbeSideTupleFetcher.initializeForProbeBase(e.Concurrency, e.joinResultCh) + + for i := uint(0); i < e.Concurrency; i++ { + e.ProbeWorkers[i].initializeForProbe(e.ProbeSideTupleFetcher.probeChkResourceCh, e.ProbeSideTupleFetcher.probeResultChs[i], e) + } +} + +func (e *HashJoinV1Exec) fetchAndProbeHashTable(ctx context.Context) { + e.initializeForProbe() + e.workerWg.RunWithRecover(func() { + defer trace.StartRegion(ctx, "HashJoinProbeSideFetcher").End() + e.ProbeSideTupleFetcher.fetchProbeSideChunks(ctx, e.MaxChunkSize(), func() bool { + return e.ProbeSideTupleFetcher.RowContainer.Len() == uint64(0) + }, e.ProbeSideTupleFetcher.JoinType == plannercore.InnerJoin || e.ProbeSideTupleFetcher.JoinType == plannercore.SemiJoin, + false, e.ProbeSideTupleFetcher.IsOuterJoin, &e.ProbeSideTupleFetcher.hashJoinCtxBase) + }, e.ProbeSideTupleFetcher.handleProbeSideFetcherPanic) + + for i := uint(0); i < e.Concurrency; i++ { + workerID := i + e.workerWg.RunWithRecover(func() { + defer trace.StartRegion(ctx, "HashJoinWorker").End() + e.ProbeWorkers[workerID].runJoinWorker() + }, e.ProbeWorkers[workerID].handleProbeWorkerPanic) + } + e.waiterWg.RunWithRecover(e.waitJoinWorkersAndCloseResultChan, nil) +} + +func (w *ProbeWorkerV1) handleProbeWorkerPanic(r any) { + if r != nil { + w.HashJoinCtx.joinResultCh <- &hashjoinWorkerResult{err: util.GetRecoverError(r)} + } +} + +func (e *HashJoinV1Exec) handleJoinWorkerPanic(r any) { + if r != nil { + e.joinResultCh <- &hashjoinWorkerResult{err: util.GetRecoverError(r)} + } +} + +// Concurrently handling unmatched rows from the hash table +func (w *ProbeWorkerV1) handleUnmatchedRowsFromHashTable() { + ok, joinResult := w.getNewJoinResult() + if !ok { + return + } + numChks := w.rowContainerForProbe.NumChunks() + for i := int(w.WorkerID); i < numChks; i += int(w.HashJoinCtx.Concurrency) { + chk, err := w.rowContainerForProbe.GetChunk(i) + if err != nil { + // Catching the error and send it + joinResult.err = err + w.HashJoinCtx.joinResultCh <- joinResult + return + } + for j := 0; j < chk.NumRows(); j++ { + if !w.HashJoinCtx.outerMatchedStatus[i].UnsafeIsSet(j) { // process unmatched Outer rows + w.Joiner.OnMissMatch(false, chk.GetRow(j), joinResult.chk) + } + if joinResult.chk.IsFull() { + w.HashJoinCtx.joinResultCh <- joinResult + ok, joinResult = w.getNewJoinResult() + if !ok { + return + } + } + } + } + + if joinResult == nil { + return + } else if joinResult.err != nil || (joinResult.chk != nil && joinResult.chk.NumRows() > 0) { + w.HashJoinCtx.joinResultCh <- joinResult + } +} + +func (e *HashJoinV1Exec) waitJoinWorkersAndCloseResultChan() { + e.workerWg.Wait() + if e.UseOuterToBuild { + // Concurrently handling unmatched rows from the hash table at the tail + for i := uint(0); i < e.Concurrency; i++ { + var workerID = i + e.workerWg.RunWithRecover(func() { e.ProbeWorkers[workerID].handleUnmatchedRowsFromHashTable() }, e.handleJoinWorkerPanic) + } + e.workerWg.Wait() + } + close(e.joinResultCh) +} + +func (w *ProbeWorkerV1) runJoinWorker() { + probeTime := int64(0) + if w.HashJoinCtx.stats != nil { + start := time.Now() + defer func() { + t := time.Since(start) + atomic.AddInt64(&w.HashJoinCtx.stats.probe, probeTime) + atomic.AddInt64(&w.HashJoinCtx.stats.fetchAndProbe, int64(t)) + w.HashJoinCtx.stats.setMaxFetchAndProbeTime(int64(t)) + }() + } + + var ( + probeSideResult *chunk.Chunk + selected = make([]bool, 0, chunk.InitialCapacity) + ) + ok, joinResult := w.getNewJoinResult() + if 
!ok { + return + } + + // Read and filter probeSideResult, and join the probeSideResult with the build side rows. + emptyProbeSideResult := &probeChkResource{ + dest: w.probeResultCh, + } + hCtx := &HashContext{ + AllTypes: w.HashJoinCtx.ProbeTypes, + KeyColIdx: w.ProbeKeyColIdx, + NaKeyColIdx: w.ProbeNAKeyColIdx, + } + for ok := true; ok; { + if w.HashJoinCtx.finished.Load() { + break + } + select { + case <-w.HashJoinCtx.closeCh: + return + case probeSideResult, ok = <-w.probeResultCh: + } + failpoint.Inject("ConsumeRandomPanic", nil) + if !ok { + break + } + start := time.Now() + // waitTime is the time cost on w.sendingResult(), it should not be added to probe time, because if + // parent executor does not call `e.Next()`, `sendingResult()` will hang, and this hang has nothing to do + // with the probe + waitTime := int64(0) + if w.HashJoinCtx.UseOuterToBuild { + ok, waitTime, joinResult = w.join2ChunkForOuterHashJoin(probeSideResult, hCtx, joinResult) + } else { + ok, waitTime, joinResult = w.join2Chunk(probeSideResult, hCtx, joinResult, selected) + } + probeTime += int64(time.Since(start)) - waitTime + if !ok { + break + } + probeSideResult.Reset() + emptyProbeSideResult.chk = probeSideResult + w.probeChkResourceCh <- emptyProbeSideResult + } + // note joinResult.chk may be nil when getNewJoinResult fails in loops + if joinResult == nil { + return + } else if joinResult.err != nil || (joinResult.chk != nil && joinResult.chk.NumRows() > 0) { + w.HashJoinCtx.joinResultCh <- joinResult + } else if joinResult.chk != nil && joinResult.chk.NumRows() == 0 { + w.joinChkResourceCh <- joinResult.chk + } +} + +func (w *ProbeWorkerV1) joinMatchedProbeSideRow2ChunkForOuterHashJoin(probeKey uint64, probeSideRow chunk.Row, hCtx *HashContext, joinResult *hashjoinWorkerResult) (bool, int64, *hashjoinWorkerResult) { + var err error + waitTime := int64(0) + oneWaitTime := int64(0) + w.buildSideRows, w.buildSideRowPtrs, err = w.rowContainerForProbe.GetMatchedRowsAndPtrs(probeKey, probeSideRow, hCtx, w.buildSideRows, w.buildSideRowPtrs, true) + buildSideRows, rowsPtrs := w.buildSideRows, w.buildSideRowPtrs + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + if len(buildSideRows) == 0 { + return true, waitTime, joinResult + } + + iter := w.rowIters + iter.Reset(buildSideRows) + var outerMatchStatus []outerRowStatusFlag + rowIdx, ok := 0, false + for iter.Begin(); iter.Current() != iter.End(); { + outerMatchStatus, err = w.Joiner.TryToMatchOuters(iter, probeSideRow, joinResult.chk, outerMatchStatus) + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + for i := range outerMatchStatus { + if outerMatchStatus[i] == outerRowMatched { + w.HashJoinCtx.outerMatchedStatus[rowsPtrs[rowIdx+i].ChkIdx].Set(int(rowsPtrs[rowIdx+i].RowIdx)) + } + } + rowIdx += len(outerMatchStatus) + if joinResult.chk.IsFull() { + ok, oneWaitTime, joinResult = w.sendingResult(joinResult) + waitTime += oneWaitTime + if !ok { + return false, waitTime, joinResult + } + } + } + return true, waitTime, joinResult +} + +// joinNAALOSJMatchProbeSideRow2Chunk implement the matching logic for NA-AntiLeftOuterSemiJoin +func (w *ProbeWorkerV1) joinNAALOSJMatchProbeSideRow2Chunk(probeKey uint64, probeKeyNullBits *bitmap.ConcurrentBitmap, probeSideRow chunk.Row, hCtx *HashContext, joinResult *hashjoinWorkerResult) (bool, int64, *hashjoinWorkerResult) { + var ( + err error + ok bool + ) + waitTime := int64(0) + oneWaitTime := int64(0) + if probeKeyNullBits == nil { + // step1: match the 
same key bucket first. + // because AntiLeftOuterSemiJoin cares about the scalar value. If we both have a match from null + // bucket and same key bucket, we should return the result as from same-key bucket + // rather than from null bucket. + w.buildSideRows, err = w.rowContainerForProbe.GetMatchedRows(probeKey, probeSideRow, hCtx, w.buildSideRows) + buildSideRows := w.buildSideRows + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + if len(buildSideRows) != 0 { + iter1 := w.rowIters + iter1.Reset(buildSideRows) + for iter1.Begin(); iter1.Current() != iter1.End(); { + matched, _, err := w.Joiner.TryToMatchInners(probeSideRow, iter1, joinResult.chk, LeftNotNullRightNotNull) + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + // here matched means: there is a valid same-key bucket row from right side. + // as said in the comment, once we meet a same key (NOT IN semantic) in CNF, we can determine the result as . + if matched { + return true, waitTime, joinResult + } + if joinResult.chk.IsFull() { + ok, oneWaitTime, joinResult = w.sendingResult(joinResult) + waitTime += oneWaitTime + if !ok { + return false, waitTime, joinResult + } + } + } + } + // step2: match the null bucket secondly. + w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes) + buildSideRows = w.buildSideRows + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + if len(buildSideRows) == 0 { + // when reach here, it means we couldn't find a valid same key match from same-key bucket yet + // and the null bucket is empty. so the result should be . + w.Joiner.OnMissMatch(false, probeSideRow, joinResult.chk) + return true, waitTime, joinResult + } + iter2 := w.rowIters + iter2.Reset(buildSideRows) + for iter2.Begin(); iter2.Current() != iter2.End(); { + matched, _, err := w.Joiner.TryToMatchInners(probeSideRow, iter2, joinResult.chk, LeftNotNullRightHasNull) + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + // here matched means: there is a valid null bucket row from right side. + // as said in the comment, once we meet a null in CNF, we can determine the result as . + if matched { + return true, waitTime, joinResult + } + if joinResult.chk.IsFull() { + ok, oneWaitTime, joinResult = w.sendingResult(joinResult) + waitTime += oneWaitTime + if !ok { + return false, waitTime, joinResult + } + } + } + // step3: if we couldn't return it quickly in null bucket and same key bucket, here means two cases: + // case1: x NOT IN (empty set): if other key bucket don't have the valid rows yet. + // case2: x NOT IN (l,m,n...): if other key bucket do have the valid rows. + // both cases mean the result should be + w.Joiner.OnMissMatch(false, probeSideRow, joinResult.chk) + return true, waitTime, joinResult + } + // when left side has null values, all we want is to find a valid build side rows (past other condition) + // so we can return it as soon as possible. here means two cases: + // case1: NOT IN (empty set): ----------------------> result is . + // case2: NOT IN (at least a valid inner row) ------------------> result is . 
+ // Step1: match null bucket (assumption that null bucket is quite smaller than all hash table bucket rows) + w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes) + buildSideRows := w.buildSideRows + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + if len(buildSideRows) != 0 { + iter1 := w.rowIters + iter1.Reset(buildSideRows) + for iter1.Begin(); iter1.Current() != iter1.End(); { + matched, _, err := w.Joiner.TryToMatchInners(probeSideRow, iter1, joinResult.chk, LeftHasNullRightHasNull) + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + // here matched means: there is a valid null bucket row from right side. (not empty) + // as said in the comment, once we found at least a valid row, we can determine the result as . + if matched { + return true, waitTime, joinResult + } + if joinResult.chk.IsFull() { + ok, oneWaitTime, joinResult = w.sendingResult(joinResult) + waitTime += oneWaitTime + if !ok { + return false, waitTime, joinResult + } + } + } + } + // Step2: match all hash table bucket build rows (use probeKeyNullBits to filter if any). + w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes) + buildSideRows = w.buildSideRows + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + if len(buildSideRows) == 0 { + // when reach here, it means we couldn't return it quickly in null bucket, and same-bucket is empty, + // which means x NOT IN (empty set) or x NOT IN (l,m,n), the result should be + w.Joiner.OnMissMatch(false, probeSideRow, joinResult.chk) + return true, waitTime, joinResult + } + iter2 := w.rowIters + iter2.Reset(buildSideRows) + for iter2.Begin(); iter2.Current() != iter2.End(); { + matched, _, err := w.Joiner.TryToMatchInners(probeSideRow, iter2, joinResult.chk, LeftHasNullRightNotNull) + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + // here matched means: there is a valid same key bucket row from right side. (not empty) + // as said in the comment, once we found at least a valid row, we can determine the result as . + if matched { + return true, waitTime, joinResult + } + if joinResult.chk.IsFull() { + ok, oneWaitTime, joinResult = w.sendingResult(joinResult) + waitTime += oneWaitTime + if !ok { + return false, waitTime, joinResult + } + } + } + // step3: if we couldn't return it quickly in null bucket and all hash bucket, here means only one cases: + // case1: NOT IN (empty set): + // empty set comes from no rows from all bucket can pass other condition. the result should be + w.Joiner.OnMissMatch(false, probeSideRow, joinResult.chk) + return true, waitTime, joinResult +} + +// joinNAASJMatchProbeSideRow2Chunk implement the matching logic for NA-AntiSemiJoin +func (w *ProbeWorkerV1) joinNAASJMatchProbeSideRow2Chunk(probeKey uint64, probeKeyNullBits *bitmap.ConcurrentBitmap, probeSideRow chunk.Row, hCtx *HashContext, joinResult *hashjoinWorkerResult) (bool, int64, *hashjoinWorkerResult) { + var ( + err error + ok bool + ) + waitTime := int64(0) + oneWaitTime := int64(0) + if probeKeyNullBits == nil { + // step1: match null bucket first. + // need fetch the "valid" rows every time. 
(nullBits map check is necessary) + w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes) + buildSideRows := w.buildSideRows + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + if len(buildSideRows) != 0 { + iter1 := w.rowIters + iter1.Reset(buildSideRows) + for iter1.Begin(); iter1.Current() != iter1.End(); { + matched, _, err := w.Joiner.TryToMatchInners(probeSideRow, iter1, joinResult.chk) + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + // here matched means: there is a valid null bucket row from right side. + // as said in the comment, once we meet a rhs null in CNF, we can determine the reject of lhs row. + if matched { + return true, waitTime, joinResult + } + if joinResult.chk.IsFull() { + ok, oneWaitTime, joinResult = w.sendingResult(joinResult) + waitTime += oneWaitTime + if !ok { + return false, waitTime, joinResult + } + } + } + } + // step2: then same key bucket. + w.buildSideRows, err = w.rowContainerForProbe.GetMatchedRows(probeKey, probeSideRow, hCtx, w.buildSideRows) + buildSideRows = w.buildSideRows + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + if len(buildSideRows) == 0 { + // when reach here, it means we couldn't return it quickly in null bucket, and same-bucket is empty, + // which means x NOT IN (empty set), accept the rhs Row. + w.Joiner.OnMissMatch(false, probeSideRow, joinResult.chk) + return true, waitTime, joinResult + } + iter2 := w.rowIters + iter2.Reset(buildSideRows) + for iter2.Begin(); iter2.Current() != iter2.End(); { + matched, _, err := w.Joiner.TryToMatchInners(probeSideRow, iter2, joinResult.chk) + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + // here matched means: there is a valid same key bucket row from right side. + // as said in the comment, once we meet a false in CNF, we can determine the reject of lhs Row. + if matched { + return true, waitTime, joinResult + } + if joinResult.chk.IsFull() { + ok, oneWaitTime, joinResult = w.sendingResult(joinResult) + waitTime += oneWaitTime + if !ok { + return false, waitTime, joinResult + } + } + } + // step3: if we couldn't return it quickly in null bucket and same key bucket, here means two cases: + // case1: x NOT IN (empty set): if other key bucket don't have the valid rows yet. + // case2: x NOT IN (l,m,n...): if other key bucket do have the valid rows. + // both cases should accept the rhs row. + w.Joiner.OnMissMatch(false, probeSideRow, joinResult.chk) + return true, waitTime, joinResult + } + // when left side has null values, all we want is to find a valid build side rows (passed from other condition) + // so we can return it as soon as possible. here means two cases: + // case1: NOT IN (empty set): ----------------------> accept rhs row. + // case2: NOT IN (at least a valid inner row) ------------------> unknown result, refuse rhs row. 
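	// (Illustrative aside, not part of the original file: the three-valued NOT IN results
	// that the null-bucket / same-key-bucket short circuits above and below rely on.
	//     2    NOT IN ()        -> true     empty build side, keep the probe row
	//     2    NOT IN (1, 3)    -> true     no key match and no nulls, keep
	//     2    NOT IN (1, 2)    -> false    same-key match, reject
	//     2    NOT IN (1, NULL) -> unknown  null-bucket match, reject
	//     NULL NOT IN (1, 2)    -> unknown  null probe key with a non-empty build side, reject
	// Anti semi join only emits rows whose predicate is true, so false and unknown are
	// handled identically here, which is why finding any one matching build row is enough
	// to stop probing.)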
+ // Step1: match null bucket (assumption that null bucket is quite smaller than all hash table bucket rows) + w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes) + buildSideRows := w.buildSideRows + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + if len(buildSideRows) != 0 { + iter1 := w.rowIters + iter1.Reset(buildSideRows) + for iter1.Begin(); iter1.Current() != iter1.End(); { + matched, _, err := w.Joiner.TryToMatchInners(probeSideRow, iter1, joinResult.chk) + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + // here matched means: there is a valid null bucket row from right side. (not empty) + // as said in the comment, once we found at least a valid row, we can determine the reject of lhs row. + if matched { + return true, waitTime, joinResult + } + if joinResult.chk.IsFull() { + ok, oneWaitTime, joinResult = w.sendingResult(joinResult) + waitTime += oneWaitTime + if !ok { + return false, waitTime, joinResult + } + } + } + } + // Step2: match all hash table bucket build rows. + w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes) + buildSideRows = w.buildSideRows + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + if len(buildSideRows) == 0 { + // when reach here, it means we couldn't return it quickly in null bucket, and same-bucket is empty, + // which means NOT IN (empty set) or NOT IN (no valid rows) accept the rhs row. + w.Joiner.OnMissMatch(false, probeSideRow, joinResult.chk) + return true, waitTime, joinResult + } + iter2 := w.rowIters + iter2.Reset(buildSideRows) + for iter2.Begin(); iter2.Current() != iter2.End(); { + matched, _, err := w.Joiner.TryToMatchInners(probeSideRow, iter2, joinResult.chk) + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + // here matched means: there is a valid key row from right side. (not empty) + // as said in the comment, once we found at least a valid row, we can determine the reject of lhs row. + if matched { + return true, waitTime, joinResult + } + if joinResult.chk.IsFull() { + ok, oneWaitTime, joinResult = w.sendingResult(joinResult) + waitTime += oneWaitTime + if !ok { + return false, waitTime, joinResult + } + } + } + // step3: if we couldn't return it quickly in null bucket and all hash bucket, here means only one cases: + // case1: NOT IN (empty set): + // empty set comes from no rows from all bucket can pass other condition. we should accept the rhs row. + w.Joiner.OnMissMatch(false, probeSideRow, joinResult.chk) + return true, waitTime, joinResult +} + +// joinNAAJMatchProbeSideRow2Chunk implement the matching priority logic for NA-AntiSemiJoin and NA-AntiLeftOuterSemiJoin +// there are some bucket-matching priority difference between them. +// +// Since NA-AntiSemiJoin don't need to append the scalar value with the left side row, there is a quick matching path. +// 1: lhs row has null: +// lhs row has null can't determine its result in advance, we should judge whether the right valid set is empty +// or not. For semantic like x NOT IN(y set), If y set is empty, the scalar result is 1; Otherwise, the result +// is 0. 
Since NA-AntiSemiJoin don't care about the scalar value, we just try to find a valid row from right side, +// once we found it then just return the left side row instantly. (same as NA-AntiLeftOuterSemiJoin) +// +// 2: lhs row without null: +// same-key bucket and null-bucket which should be the first to match? For semantic like x NOT IN(y set), once y +// set has a same key x, the scalar value is 0; else if y set has a null key, then the scalar value is null. Both +// of them lead the refuse of the lhs row without any difference. Since NA-AntiSemiJoin don't care about the scalar +// value, we can just match the null bucket first and refuse the lhs row as quickly as possible, because a null of +// yi in the CNF (x NA-EQ yi) can always determine a negative value (refuse lhs row) in advance here. +// +// For NA-AntiLeftOuterSemiJoin, we couldn't match null-bucket first, because once y set has a same key x and null +// key, we should return the result as left side row appended with a scalar value 0 which is from same key matching failure. +func (w *ProbeWorkerV1) joinNAAJMatchProbeSideRow2Chunk(probeKey uint64, probeKeyNullBits *bitmap.ConcurrentBitmap, probeSideRow chunk.Row, hCtx *HashContext, joinResult *hashjoinWorkerResult) (bool, int64, *hashjoinWorkerResult) { + naAntiSemiJoin := w.HashJoinCtx.JoinType == plannercore.AntiSemiJoin && w.HashJoinCtx.IsNullAware + naAntiLeftOuterSemiJoin := w.HashJoinCtx.JoinType == plannercore.AntiLeftOuterSemiJoin && w.HashJoinCtx.IsNullAware + if naAntiSemiJoin { + return w.joinNAASJMatchProbeSideRow2Chunk(probeKey, probeKeyNullBits, probeSideRow, hCtx, joinResult) + } + if naAntiLeftOuterSemiJoin { + return w.joinNAALOSJMatchProbeSideRow2Chunk(probeKey, probeKeyNullBits, probeSideRow, hCtx, joinResult) + } + // shouldn't be here, not a valid NAAJ. 
+ return false, 0, joinResult +} + +func (w *ProbeWorkerV1) joinMatchedProbeSideRow2Chunk(probeKey uint64, probeSideRow chunk.Row, hCtx *HashContext, + joinResult *hashjoinWorkerResult) (bool, int64, *hashjoinWorkerResult) { + var err error + waitTime := int64(0) + oneWaitTime := int64(0) + var buildSideRows []chunk.Row + if w.Joiner.isSemiJoinWithoutCondition() { + var rowPtr *chunk.Row + rowPtr, err = w.rowContainerForProbe.GetOneMatchedRow(probeKey, probeSideRow, hCtx) + if rowPtr != nil { + buildSideRows = append(buildSideRows, *rowPtr) + } + } else { + w.buildSideRows, err = w.rowContainerForProbe.GetMatchedRows(probeKey, probeSideRow, hCtx, w.buildSideRows) + buildSideRows = w.buildSideRows + } + + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + if len(buildSideRows) == 0 { + w.Joiner.OnMissMatch(false, probeSideRow, joinResult.chk) + return true, waitTime, joinResult + } + iter := w.rowIters + iter.Reset(buildSideRows) + hasMatch, hasNull, ok := false, false, false + for iter.Begin(); iter.Current() != iter.End(); { + matched, isNull, err := w.Joiner.TryToMatchInners(probeSideRow, iter, joinResult.chk) + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + hasMatch = hasMatch || matched + hasNull = hasNull || isNull + + if joinResult.chk.IsFull() { + ok, oneWaitTime, joinResult = w.sendingResult(joinResult) + waitTime += oneWaitTime + if !ok { + return false, waitTime, joinResult + } + } + } + if !hasMatch { + w.Joiner.OnMissMatch(hasNull, probeSideRow, joinResult.chk) + } + return true, waitTime, joinResult +} + +func (w *ProbeWorkerV1) getNewJoinResult() (bool, *hashjoinWorkerResult) { + joinResult := &hashjoinWorkerResult{ + src: w.joinChkResourceCh, + } + ok := true + select { + case <-w.HashJoinCtx.closeCh: + ok = false + case joinResult.chk, ok = <-w.joinChkResourceCh: + } + return ok, joinResult +} + +func (w *ProbeWorkerV1) join2Chunk(probeSideChk *chunk.Chunk, hCtx *HashContext, joinResult *hashjoinWorkerResult, + selected []bool) (ok bool, waitTime int64, _ *hashjoinWorkerResult) { + var err error + waitTime = 0 + oneWaitTime := int64(0) + selected, err = expression.VectorizedFilter(w.HashJoinCtx.SessCtx.GetExprCtx().GetEvalCtx(), w.HashJoinCtx.SessCtx.GetSessionVars().EnableVectorizedExpression, w.HashJoinCtx.OuterFilter, chunk.NewIterator4Chunk(probeSideChk), selected) + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + + numRows := probeSideChk.NumRows() + hCtx.InitHash(numRows) + // By now, path 1 and 2 won't be conducted at the same time. + // 1: write the row data of join key to hashVals. (normal EQ key should ignore the null values.) null-EQ for Except statement is an exception. + for keyIdx, i := range hCtx.KeyColIdx { + ignoreNull := len(w.HashJoinCtx.IsNullEQ) > keyIdx && w.HashJoinCtx.IsNullEQ[keyIdx] + err = codec.HashChunkSelected(w.rowContainerForProbe.sc.TypeCtx(), hCtx.HashVals, probeSideChk, hCtx.AllTypes[keyIdx], i, hCtx.Buf, hCtx.HasNull, selected, ignoreNull) + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + } + // 2: write the how data of NA join key to hashVals. (NA EQ key should collect all how including null value, store null value in a special position) + isNAAJ := len(hCtx.NaKeyColIdx) > 0 + for keyIdx, i := range hCtx.NaKeyColIdx { + // NAAJ won't ignore any null values, but collect them up to probe. 
+ err = codec.HashChunkSelected(w.rowContainerForProbe.sc.TypeCtx(), hCtx.HashVals, probeSideChk, hCtx.AllTypes[keyIdx], i, hCtx.Buf, hCtx.HasNull, selected, false) + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + // after fetch one NA column, collect the null value to null bitmap for every how. (use hasNull flag to accelerate) + // eg: if a NA Join cols is (a, b, c), for every build row here we maintained a 3-bit map to mark which column is null for them. + for rowIdx := 0; rowIdx < numRows; rowIdx++ { + if hCtx.HasNull[rowIdx] { + hCtx.naColNullBitMap[rowIdx].UnsafeSet(keyIdx) + // clean and try fetch Next NA join col. + hCtx.HasNull[rowIdx] = false + hCtx.naHasNull[rowIdx] = true + } + } + } + + for i := range selected { + err := w.HashJoinCtx.SessCtx.GetSessionVars().SQLKiller.HandleSignal() + failpoint.Inject("killedInJoin2Chunk", func(val failpoint.Value) { + if val.(bool) { + err = exeerrors.ErrQueryInterrupted + } + }) + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + if isNAAJ { + if !selected[i] { + // since this is the case of using inner to build, so for an outer row unselected, we should fill the result when it's outer join. + w.Joiner.OnMissMatch(false, probeSideChk.GetRow(i), joinResult.chk) + } + if hCtx.naHasNull[i] { + // here means the probe join connecting column has null value in it and this is special for matching all the hash buckets + // for it. (probeKey is not necessary here) + probeRow := probeSideChk.GetRow(i) + ok, oneWaitTime, joinResult = w.joinNAAJMatchProbeSideRow2Chunk(0, hCtx.naColNullBitMap[i].Clone(), probeRow, hCtx, joinResult) + waitTime += oneWaitTime + if !ok { + return false, waitTime, joinResult + } + } else { + // here means the probe join connecting column without null values, where we should match same key bucket and null bucket for it at its order. + // step1: process same key matched probe side rows + probeKey, probeRow := hCtx.HashVals[i].Sum64(), probeSideChk.GetRow(i) + ok, oneWaitTime, joinResult = w.joinNAAJMatchProbeSideRow2Chunk(probeKey, nil, probeRow, hCtx, joinResult) + waitTime += oneWaitTime + if !ok { + return false, waitTime, joinResult + } + } + } else { + // since this is the case of using inner to build, so for an outer row unselected, we should fill the result when it's outer join. 
+ if !selected[i] || hCtx.HasNull[i] { // process unmatched probe side rows + w.Joiner.OnMissMatch(false, probeSideChk.GetRow(i), joinResult.chk) + } else { // process matched probe side rows + probeKey, probeRow := hCtx.HashVals[i].Sum64(), probeSideChk.GetRow(i) + ok, oneWaitTime, joinResult = w.joinMatchedProbeSideRow2Chunk(probeKey, probeRow, hCtx, joinResult) + waitTime += oneWaitTime + if !ok { + return false, waitTime, joinResult + } + } + } + if joinResult.chk.IsFull() { + ok, oneWaitTime, joinResult = w.sendingResult(joinResult) + waitTime += oneWaitTime + if !ok { + return false, waitTime, joinResult + } + } + } + return true, waitTime, joinResult +} + +func (w *ProbeWorkerV1) sendingResult(joinResult *hashjoinWorkerResult) (ok bool, cost int64, newJoinResult *hashjoinWorkerResult) { + start := time.Now() + w.HashJoinCtx.joinResultCh <- joinResult + ok, newJoinResult = w.getNewJoinResult() + cost = int64(time.Since(start)) + return ok, cost, newJoinResult +} + +// join2ChunkForOuterHashJoin joins chunks when using the outer to build a hash table (refer to outer hash join) +func (w *ProbeWorkerV1) join2ChunkForOuterHashJoin(probeSideChk *chunk.Chunk, hCtx *HashContext, joinResult *hashjoinWorkerResult) (ok bool, waitTime int64, _ *hashjoinWorkerResult) { + waitTime = 0 + oneWaitTime := int64(0) + hCtx.InitHash(probeSideChk.NumRows()) + for keyIdx, i := range hCtx.KeyColIdx { + err := codec.HashChunkColumns(w.rowContainerForProbe.sc.TypeCtx(), hCtx.HashVals, probeSideChk, hCtx.AllTypes[keyIdx], i, hCtx.Buf, hCtx.HasNull) + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + } + for i := 0; i < probeSideChk.NumRows(); i++ { + err := w.HashJoinCtx.SessCtx.GetSessionVars().SQLKiller.HandleSignal() + failpoint.Inject("killedInJoin2ChunkForOuterHashJoin", func(val failpoint.Value) { + if val.(bool) { + err = exeerrors.ErrQueryInterrupted + } + }) + if err != nil { + joinResult.err = err + return false, waitTime, joinResult + } + probeKey, probeRow := hCtx.HashVals[i].Sum64(), probeSideChk.GetRow(i) + ok, oneWaitTime, joinResult = w.joinMatchedProbeSideRow2ChunkForOuterHashJoin(probeKey, probeRow, hCtx, joinResult) + waitTime += oneWaitTime + if !ok { + return false, waitTime, joinResult + } + if joinResult.chk.IsFull() { + ok, oneWaitTime, joinResult = w.sendingResult(joinResult) + waitTime += oneWaitTime + if !ok { + return false, waitTime, joinResult + } + } + } + return true, waitTime, joinResult +} + +// Next implements the Executor Next interface. +// hash join constructs the result following these steps: +// step 1. fetch data from build side child and build a hash table; +// step 2. fetch data from probe child in a background goroutine and probe the hash table in multiple join workers. 
+func (e *HashJoinV1Exec) Next(ctx context.Context, req *chunk.Chunk) (err error) { + if !e.Prepared { + e.buildFinished = make(chan error, 1) + hCtx := &HashContext{ + AllTypes: e.BuildTypes, + KeyColIdx: e.BuildWorker.BuildKeyColIdx, + NaKeyColIdx: e.BuildWorker.BuildNAKeyColIdx, + } + e.RowContainer = newHashRowContainer(e.Ctx(), hCtx, exec.RetTypes(e.BuildWorker.BuildSideExec)) + // we shallow copies RowContainer for each probe worker to avoid lock contention + for i := uint(0); i < e.Concurrency; i++ { + if i == 0 { + e.ProbeWorkers[i].rowContainerForProbe = e.RowContainer + } else { + e.ProbeWorkers[i].rowContainerForProbe = e.RowContainer.ShallowCopy() + } + } + for i := uint(0); i < e.Concurrency; i++ { + e.ProbeWorkers[i].rowIters = chunk.NewIterator4Slice([]chunk.Row{}) + } + e.workerWg.RunWithRecover(func() { + defer trace.StartRegion(ctx, "HashJoinHashTableBuilder").End() + e.fetchAndBuildHashTable(ctx) + }, e.handleFetchAndBuildHashTablePanic) + e.fetchAndProbeHashTable(ctx) + e.Prepared = true + } + if e.IsOuterJoin { + atomic.StoreInt64(&e.ProbeSideTupleFetcher.requiredRows, int64(req.RequiredRows())) + } + req.Reset() + + result, ok := <-e.joinResultCh + if !ok { + return nil + } + if result.err != nil { + e.finished.Store(true) + return result.err + } + req.SwapColumns(result.chk) + result.src <- result.chk + return nil +} + +func (e *HashJoinV1Exec) handleFetchAndBuildHashTablePanic(r any) { + if r != nil { + e.buildFinished <- util.GetRecoverError(r) + } + close(e.buildFinished) +} + +func (e *HashJoinV1Exec) fetchAndBuildHashTable(ctx context.Context) { + if e.stats != nil { + start := time.Now() + defer func() { + e.stats.fetchAndBuildHashTable = time.Since(start) + }() + } + // buildSideResultCh transfers build side chunk from build side fetch to build hash table. + buildSideResultCh := make(chan *chunk.Chunk, 1) + doneCh := make(chan struct{}) + fetchBuildSideRowsOk := make(chan error, 1) + e.workerWg.RunWithRecover( + func() { + defer trace.StartRegion(ctx, "HashJoinBuildSideFetcher").End() + e.BuildWorker.fetchBuildSideRows(ctx, &e.BuildWorker.HashJoinCtx.hashJoinCtxBase, buildSideResultCh, fetchBuildSideRowsOk, doneCh) + }, + func(r any) { + if r != nil { + fetchBuildSideRowsOk <- util.GetRecoverError(r) + } + close(fetchBuildSideRowsOk) + }, + ) + + // TODO: Parallel build hash table. Currently not support because `unsafeHashTable` is not thread-safe. + err := e.BuildWorker.BuildHashTableForList(buildSideResultCh) + if err != nil { + e.buildFinished <- errors.Trace(err) + close(doneCh) + } + // Wait fetchBuildSideRows be Finished. + // 1. if BuildHashTableForList fails + // 2. if probeSideResult.NumRows() == 0, fetchProbeSideChunks will not wait for the build side. + channel.Clear(buildSideResultCh) + // Check whether err is nil to avoid sending redundant error into buildFinished. + if err == nil { + if err = <-fetchBuildSideRowsOk; err != nil { + e.buildFinished <- err + } + } +} + +// BuildHashTableForList builds hash table from `list`. 
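Editor's note: fetchAndBuildHashTable above connects a fetcher goroutine to the hash table builder through a channel, adds a done channel so the fetcher stops early when building fails, and drains the channel afterwards so the fetcher can never stay blocked on a send. A stripped-down sketch of that producer/consumer shape, with hypothetical names:

package main

import (
	"errors"
	"fmt"
)

func fetch(out chan<- int, done <-chan struct{}) {
	defer close(out)
	for i := 0; i < 10; i++ {
		select {
		case out <- i: // hand one "chunk" to the builder
		case <-done: // builder failed: stop producing early
			return
		}
	}
}

func build(in <-chan int) error {
	for v := range in {
		if v == 3 {
			return errors.New("build failed") // e.g. an OOM or codec error
		}
	}
	return nil
}

func main() {
	ch := make(chan int, 1)
	done := make(chan struct{})
	go fetch(ch, done)

	err := build(ch)
	if err != nil {
		close(done) // tell the fetcher to quit
	}
	// Drain whatever the fetcher already queued so it is never stuck on send,
	// the counterpart of channel.Clear(buildSideResultCh) above.
	for range ch {
	}
	fmt.Println("build result:", err)
}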
+func (w *BuildWorkerV1) BuildHashTableForList(buildSideResultCh <-chan *chunk.Chunk) error { + var err error + var selected []bool + rowContainer := w.HashJoinCtx.RowContainer + rowContainer.GetMemTracker().AttachTo(w.HashJoinCtx.memTracker) + rowContainer.GetMemTracker().SetLabel(memory.LabelForBuildSideResult) + rowContainer.GetDiskTracker().AttachTo(w.HashJoinCtx.diskTracker) + rowContainer.GetDiskTracker().SetLabel(memory.LabelForBuildSideResult) + if variable.EnableTmpStorageOnOOM.Load() { + actionSpill := rowContainer.ActionSpill() + failpoint.Inject("testRowContainerSpill", func(val failpoint.Value) { + if val.(bool) { + actionSpill = rowContainer.rowContainer.ActionSpillForTest() + defer actionSpill.(*chunk.SpillDiskAction).WaitForTest() + } + }) + w.HashJoinCtx.SessCtx.GetSessionVars().MemTracker.FallbackOldAndSetNewAction(actionSpill) + } + for chk := range buildSideResultCh { + if w.HashJoinCtx.finished.Load() { + return nil + } + if !w.HashJoinCtx.UseOuterToBuild { + err = rowContainer.PutChunk(chk, w.HashJoinCtx.IsNullEQ) + } else { + var bitMap = bitmap.NewConcurrentBitmap(chk.NumRows()) + w.HashJoinCtx.outerMatchedStatus = append(w.HashJoinCtx.outerMatchedStatus, bitMap) + w.HashJoinCtx.memTracker.Consume(bitMap.BytesConsumed()) + if len(w.HashJoinCtx.OuterFilter) == 0 { + err = w.HashJoinCtx.RowContainer.PutChunk(chk, w.HashJoinCtx.IsNullEQ) + } else { + selected, err = expression.VectorizedFilter(w.HashJoinCtx.SessCtx.GetExprCtx().GetEvalCtx(), w.HashJoinCtx.SessCtx.GetSessionVars().EnableVectorizedExpression, w.HashJoinCtx.OuterFilter, chunk.NewIterator4Chunk(chk), selected) + if err != nil { + return err + } + err = rowContainer.PutChunkSelected(chk, selected, w.HashJoinCtx.IsNullEQ) + } + } + failpoint.Inject("ConsumeRandomPanic", nil) + if err != nil { + return err + } + } + return nil +} + +// NestedLoopApplyExec is the executor for apply. +type NestedLoopApplyExec struct { + exec.BaseExecutor + + Sctx sessionctx.Context + innerRows []chunk.Row + cursor int + InnerExec exec.Executor + OuterExec exec.Executor + InnerFilter expression.CNFExprs + OuterFilter expression.CNFExprs + + Joiner Joiner + + cache *applycache.ApplyCache + CanUseCache bool + cacheHitCounter int + cacheAccessCounter int + + OuterSchema []*expression.CorrelatedColumn + + OuterChunk *chunk.Chunk + outerChunkCursor int + outerSelected []bool + InnerList *chunk.List + InnerChunk *chunk.Chunk + innerSelected []bool + innerIter chunk.Iterator + outerRow *chunk.Row + hasMatch bool + hasNull bool + + Outer bool + + memTracker *memory.Tracker // track memory usage. +} + +// Close implements the Executor interface. +func (e *NestedLoopApplyExec) Close() error { + e.innerRows = nil + e.memTracker = nil + if e.RuntimeStats() != nil { + runtimeStats := NewJoinRuntimeStats() + if e.CanUseCache { + var hitRatio float64 + if e.cacheAccessCounter > 0 { + hitRatio = float64(e.cacheHitCounter) / float64(e.cacheAccessCounter) + } + runtimeStats.SetCacheInfo(true, hitRatio) + } else { + runtimeStats.SetCacheInfo(false, 0) + } + runtimeStats.SetConcurrencyInfo(execdetails.NewConcurrencyInfo("concurrency", 0)) + defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), runtimeStats) + } + return exec.Close(e.OuterExec) +} + +// Open implements the Executor interface. 
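Editor's note: when the outer side is used to build the table, BuildHashTableForList above allocates a concurrent bitmap per chunk (outerMatchedStatus) so probe workers can mark the build rows that found a partner, and a later pass can emit the unmatched ones. A minimal single-threaded sketch of that idea, using a plain []bool instead of a concurrent bitmap:

package main

import "fmt"

func main() {
	outerBuild := []int{1, 2, 3, 4}          // build side = outer side
	probe := []int{2, 4}                     // inner side rows
	matched := make([]bool, len(outerBuild)) // per build-row matched flag

	ht := make(map[int][]int)
	for i, k := range outerBuild {
		ht[k] = append(ht[k], i)
	}

	// Probe phase: mark every build row that found a partner.
	for _, k := range probe {
		for _, idx := range ht[k] {
			matched[idx] = true // the real code sets a bit in a concurrent bitmap
			fmt.Printf("matched: outer=%d inner=%d\n", outerBuild[idx], k)
		}
	}

	// Post-probe scan: unmatched outer rows still produce output (outer join).
	for i, ok := range matched {
		if !ok {
			fmt.Printf("unmatched outer=%d | NULL\n", outerBuild[i])
		}
	}
}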
+func (e *NestedLoopApplyExec) Open(ctx context.Context) error { + err := exec.Open(ctx, e.OuterExec) + if err != nil { + return err + } + e.cursor = 0 + e.innerRows = e.innerRows[:0] + e.OuterChunk = exec.TryNewCacheChunk(e.OuterExec) + e.InnerChunk = exec.TryNewCacheChunk(e.InnerExec) + e.InnerList = chunk.NewList(exec.RetTypes(e.InnerExec), e.InitCap(), e.MaxChunkSize()) + + e.memTracker = memory.NewTracker(e.ID(), -1) + e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) + + e.InnerList.GetMemTracker().SetLabel(memory.LabelForInnerList) + e.InnerList.GetMemTracker().AttachTo(e.memTracker) + + if e.CanUseCache { + e.cache, err = applycache.NewApplyCache(e.Sctx) + if err != nil { + return err + } + e.cacheHitCounter = 0 + e.cacheAccessCounter = 0 + e.cache.GetMemTracker().AttachTo(e.memTracker) + } + return nil +} + +// aggExecutorTreeInputEmpty checks whether the executor tree returns empty if without aggregate operators. +// Note that, the prerequisite is that this executor tree has been executed already and it returns one Row. +func aggExecutorTreeInputEmpty(e exec.Executor) bool { + children := e.AllChildren() + if len(children) == 0 { + return false + } + if len(children) > 1 { + _, ok := e.(*unionexec.UnionExec) + if !ok { + // It is a Join executor. + return false + } + for _, child := range children { + if !aggExecutorTreeInputEmpty(child) { + return false + } + } + return true + } + // Single child executors. + if aggExecutorTreeInputEmpty(children[0]) { + return true + } + if hashAgg, ok := e.(*aggregate.HashAggExec); ok { + return hashAgg.IsChildReturnEmpty + } + if streamAgg, ok := e.(*aggregate.StreamAggExec); ok { + return streamAgg.IsChildReturnEmpty + } + return false +} + +func (e *NestedLoopApplyExec) fetchSelectedOuterRow(ctx context.Context, chk *chunk.Chunk) (*chunk.Row, error) { + outerIter := chunk.NewIterator4Chunk(e.OuterChunk) + for { + if e.outerChunkCursor >= e.OuterChunk.NumRows() { + err := exec.Next(ctx, e.OuterExec, e.OuterChunk) + if err != nil { + return nil, err + } + if e.OuterChunk.NumRows() == 0 { + return nil, nil + } + e.outerSelected, err = expression.VectorizedFilter(e.Sctx.GetExprCtx().GetEvalCtx(), e.Sctx.GetSessionVars().EnableVectorizedExpression, e.OuterFilter, outerIter, e.outerSelected) + if err != nil { + return nil, err + } + // For cases like `select count(1), (select count(1) from s where s.a > t.a) as sub from t where t.a = 1`, + // if outer child has no row satisfying `t.a = 1`, `sub` should be `null` instead of `0` theoretically; however, the + // outer `count(1)` produces one row <0, null> over the empty input, we should specially mark this outer row + // as not selected, to trigger the mismatch join procedure. + if e.outerChunkCursor == 0 && e.OuterChunk.NumRows() == 1 && e.outerSelected[0] && aggExecutorTreeInputEmpty(e.OuterExec) { + e.outerSelected[0] = false + } + e.outerChunkCursor = 0 + } + outerRow := e.OuterChunk.GetRow(e.outerChunkCursor) + selected := e.outerSelected[e.outerChunkCursor] + e.outerChunkCursor++ + if selected { + return &outerRow, nil + } else if e.Outer { + e.Joiner.OnMissMatch(false, outerRow, chk) + if chk.IsFull() { + return nil, nil + } + } + } +} + +// fetchAllInners reads all data from the inner table and stores them in a List. 
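Editor's note: NestedLoopApplyExec re-evaluates the inner plan once per selected outer row, after binding the outer row's values into the correlated columns. A toy sketch of that control flow, with the "inner executor" reduced to a closure over the bound value (illustrative only, not the executor's API):

package main

import "fmt"

func main() {
	outer := []int{1, 5, 10}

	// Correlated "column": the inner side reads whatever is currently bound here,
	// like *col.Data = e.outerRow.GetDatum(...) in the executor.
	var boundOuterA int

	inner := func() []int { // stand-in for re-opening and draining InnerExec
		all := []int{2, 4, 6, 8}
		out := all[:0:0]
		for _, v := range all {
			if v > boundOuterA { // correlated filter, e.g. s.a > t.a
				out = append(out, v)
			}
		}
		return out
	}

	for _, a := range outer {
		boundOuterA = a // bind the correlated value for this outer row
		rows := inner() // fetch all inners under that binding
		fmt.Printf("outer=%d -> %d inner matches\n", a, len(rows))
	}
}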
+func (e *NestedLoopApplyExec) fetchAllInners(ctx context.Context) error { + err := exec.Open(ctx, e.InnerExec) + defer func() { terror.Log(exec.Close(e.InnerExec)) }() + if err != nil { + return err + } + + if e.CanUseCache { + // create a new one since it may be in the cache + e.InnerList = chunk.NewListWithMemTracker(exec.RetTypes(e.InnerExec), e.InitCap(), e.MaxChunkSize(), e.InnerList.GetMemTracker()) + } else { + e.InnerList.Reset() + } + innerIter := chunk.NewIterator4Chunk(e.InnerChunk) + for { + err := exec.Next(ctx, e.InnerExec, e.InnerChunk) + if err != nil { + return err + } + if e.InnerChunk.NumRows() == 0 { + return nil + } + + e.innerSelected, err = expression.VectorizedFilter(e.Sctx.GetExprCtx().GetEvalCtx(), e.Sctx.GetSessionVars().EnableVectorizedExpression, e.InnerFilter, innerIter, e.innerSelected) + if err != nil { + return err + } + for row := innerIter.Begin(); row != innerIter.End(); row = innerIter.Next() { + if e.innerSelected[row.Idx()] { + e.InnerList.AppendRow(row) + } + } + } +} + +// Next implements the Executor interface. +func (e *NestedLoopApplyExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { + req.Reset() + for { + if e.innerIter == nil || e.innerIter.Current() == e.innerIter.End() { + if e.outerRow != nil && !e.hasMatch { + e.Joiner.OnMissMatch(e.hasNull, *e.outerRow, req) + } + e.outerRow, err = e.fetchSelectedOuterRow(ctx, req) + if e.outerRow == nil || err != nil { + return err + } + e.hasMatch = false + e.hasNull = false + + if e.CanUseCache { + var key []byte + for _, col := range e.OuterSchema { + *col.Data = e.outerRow.GetDatum(col.Index, col.RetType) + key, err = codec.EncodeKey(e.Ctx().GetSessionVars().StmtCtx.TimeZone(), key, *col.Data) + err = e.Ctx().GetSessionVars().StmtCtx.HandleError(err) + if err != nil { + return err + } + } + e.cacheAccessCounter++ + value, err := e.cache.Get(key) + if err != nil { + return err + } + if value != nil { + e.InnerList = value + e.cacheHitCounter++ + } else { + err = e.fetchAllInners(ctx) + if err != nil { + return err + } + if _, err := e.cache.Set(key, e.InnerList); err != nil { + return err + } + } + } else { + for _, col := range e.OuterSchema { + *col.Data = e.outerRow.GetDatum(col.Index, col.RetType) + } + err = e.fetchAllInners(ctx) + if err != nil { + return err + } + } + e.innerIter = chunk.NewIterator4List(e.InnerList) + e.innerIter.Begin() + } + + matched, isNull, err := e.Joiner.TryToMatchInners(*e.outerRow, e.innerIter, req) + e.hasMatch = e.hasMatch || matched + e.hasNull = e.hasNull || isNull + + if err != nil || req.IsFull() { + return err + } + } +} + +// cacheInfo is used to save the concurrency information of the executor operator +type cacheInfo struct { + hitRatio float64 + useCache bool +} + +type joinRuntimeStats struct { + *execdetails.RuntimeStatsWithConcurrencyInfo + + applyCache bool + cache cacheInfo + hasHashStat bool + hashStat hashStatistic +} + +// NewJoinRuntimeStats returns a new joinRuntimeStats +func NewJoinRuntimeStats() *joinRuntimeStats { + stats := &joinRuntimeStats{ + RuntimeStatsWithConcurrencyInfo: &execdetails.RuntimeStatsWithConcurrencyInfo{}, + } + return stats +} + +// SetCacheInfo sets the cache information. Only used for apply executor. 
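Editor's note: the CanUseCache branch above keys the cache on the encoded values of the correlated columns, so two outer rows with identical correlated values reuse the already materialized inner rows; the hit ratio is what SetCacheInfo later reports. A minimal sketch with a plain map keyed by a formatted string (the real code uses codec.EncodeKey and the ApplyCache type):

package main

import "fmt"

func main() {
	outer := []int{3, 7, 3, 3} // duplicated correlated values benefit from caching

	cache := make(map[string][]int)
	hits, accesses := 0, 0

	fetchInner := func(a int) []int { // expensive correlated-subquery stand-in
		res := []int{}
		for v := 1; v <= 10; v++ {
			if v > a {
				res = append(res, v)
			}
		}
		return res
	}

	for _, a := range outer {
		key := fmt.Sprintf("a=%d", a) // stands in for codec.EncodeKey of *col.Data
		accesses++
		rows, ok := cache[key]
		if ok {
			hits++
		} else {
			rows = fetchInner(a)
			cache[key] = rows
		}
		fmt.Printf("outer=%d inner rows=%d\n", a, len(rows))
	}
	fmt.Printf("cacheHitRatio: %.3f\n", float64(hits)/float64(accesses))
}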
+func (e *joinRuntimeStats) SetCacheInfo(useCache bool, hitRatio float64) { + e.Lock() + e.applyCache = true + e.cache.useCache = useCache + e.cache.hitRatio = hitRatio + e.Unlock() +} + +func (e *joinRuntimeStats) String() string { + buf := bytes.NewBuffer(make([]byte, 0, 16)) + buf.WriteString(e.RuntimeStatsWithConcurrencyInfo.String()) + if e.applyCache { + if e.cache.useCache { + fmt.Fprintf(buf, ", cache:ON, cacheHitRatio:%.3f%%", e.cache.hitRatio*100) + } else { + buf.WriteString(", cache:OFF") + } + } + if e.hasHashStat { + buf.WriteString(", " + e.hashStat.String()) + } + return buf.String() +} + +// Tp implements the RuntimeStats interface. +func (*joinRuntimeStats) Tp() int { + return execdetails.TpJoinRuntimeStats +} + +func (e *joinRuntimeStats) Clone() execdetails.RuntimeStats { + newJRS := &joinRuntimeStats{ + RuntimeStatsWithConcurrencyInfo: e.RuntimeStatsWithConcurrencyInfo, + applyCache: e.applyCache, + cache: e.cache, + hasHashStat: e.hasHashStat, + hashStat: e.hashStat, + } + return newJRS +} diff --git a/pkg/executor/join/hash_join_v2.go b/pkg/executor/join/hash_join_v2.go index 9fbb587af4a3e..d5bcd9ce69ba1 100644 --- a/pkg/executor/join/hash_join_v2.go +++ b/pkg/executor/join/hash_join_v2.go @@ -95,7 +95,7 @@ func (htc *hashTableContext) getCurrentRowSegment(workerID, partitionID int, tab func (htc *hashTableContext) finalizeCurrentSeg(workerID, partitionID int, builder *rowTableBuilder) { seg := htc.getCurrentRowSegment(workerID, partitionID, nil, false, 0) builder.rowNumberInCurrentRowTableSeg[partitionID] = 0 - failpoint.Inject("finalizeCurrentSegPanic", nil) + failpoint.Eval(_curpkg_("finalizeCurrentSegPanic")) seg.finalized = true htc.memoryTracker.Consume(seg.totalUsedBytes()) } @@ -360,7 +360,7 @@ func (w *BuildWorkerV2) splitPartitionAndAppendToRowTable(typeCtx types.Context, for chk := range srcChkCh { start := time.Now() err = builder.processOneChunk(chk, typeCtx, w.HashJoinCtx, int(w.WorkerID)) - failpoint.Inject("splitPartitionPanic", nil) + failpoint.Eval(_curpkg_("splitPartitionPanic")) cost += int64(time.Since(start)) if err != nil { return err @@ -495,7 +495,7 @@ func (w *ProbeWorkerV2) processOneProbeChunk(probeChunk *chunk.Chunk, joinResult if !ok || joinResult.err != nil { return ok, waitTime, joinResult } - failpoint.Inject("processOneProbeChunkPanic", nil) + failpoint.Eval(_curpkg_("processOneProbeChunkPanic")) if joinResult.chk.IsFull() { waitStart := time.Now() w.HashJoinCtx.joinResultCh <- joinResult @@ -542,7 +542,7 @@ func (w *ProbeWorkerV2) runJoinWorker() { return case probeSideResult, ok = <-w.probeResultCh: } - failpoint.Inject("ConsumeRandomPanic", nil) + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) if !ok { break } @@ -648,7 +648,7 @@ func (e *HashJoinV2Exec) createTasks(buildTaskCh chan<- *buildTask, totalSegment createBuildTask := func(partIdx int, segStartIdx int, segEndIdx int) *buildTask { return &buildTask{partitionIdx: partIdx, segStartIdx: segStartIdx, segEndIdx: segEndIdx} } - failpoint.Inject("createTasksPanic", nil) + failpoint.Eval(_curpkg_("createTasksPanic")) if isBalanced { for partIdx, subTable := range subTables { @@ -831,7 +831,7 @@ func (w *BuildWorkerV2) buildHashTable(taskCh chan *buildTask) error { start := time.Now() partIdx, segStartIdx, segEndIdx := task.partitionIdx, task.segStartIdx, task.segEndIdx w.HashJoinCtx.hashTableContext.hashTable.tables[partIdx].build(segStartIdx, segEndIdx) - failpoint.Inject("buildHashTablePanic", nil) + failpoint.Eval(_curpkg_("buildHashTablePanic")) cost += 
int64(time.Since(start)) } return nil diff --git a/pkg/executor/join/hash_join_v2.go__failpoint_stash__ b/pkg/executor/join/hash_join_v2.go__failpoint_stash__ new file mode 100644 index 0000000000000..9fbb587af4a3e --- /dev/null +++ b/pkg/executor/join/hash_join_v2.go__failpoint_stash__ @@ -0,0 +1,943 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package join + +import ( + "bytes" + "context" + "math" + "runtime/trace" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/parser/mysql" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/channel" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/disk" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/memory" +) + +var ( + _ exec.Executor = &HashJoinV2Exec{} + // enableHashJoinV2 is a variable used only in test + enableHashJoinV2 = atomic.Bool{} +) + +// IsHashJoinV2Enabled return true if hash join v2 is enabled +func IsHashJoinV2Enabled() bool { + // sizeOfUintptr should always equal to sizeOfUnsafePointer, because according to golang's doc, + // a Pointer can be converted to an uintptr. 
Add this check here in case in the future go runtime + // change this + return !heapObjectsCanMove() && enableHashJoinV2.Load() && sizeOfUintptr >= sizeOfUnsafePointer +} + +// SetEnableHashJoinV2 enable/disable hash join v2 +func SetEnableHashJoinV2(enable bool) { + enableHashJoinV2.Store(enable) +} + +type hashTableContext struct { + // rowTables is used during split partition stage, each buildWorker has + // its own rowTable + rowTables [][]*rowTable + hashTable *hashTableV2 + memoryTracker *memory.Tracker +} + +func (htc *hashTableContext) reset() { + htc.rowTables = nil + htc.hashTable = nil + htc.memoryTracker.Detach() +} + +func (htc *hashTableContext) getCurrentRowSegment(workerID, partitionID int, tableMeta *TableMeta, allowCreate bool, firstSegSizeHint uint) *rowTableSegment { + if htc.rowTables[workerID][partitionID] == nil { + htc.rowTables[workerID][partitionID] = newRowTable(tableMeta) + } + segNum := len(htc.rowTables[workerID][partitionID].segments) + if segNum == 0 || htc.rowTables[workerID][partitionID].segments[segNum-1].finalized { + if !allowCreate { + panic("logical error, should not reach here") + } + // do not pre-allocate too many memory for the first seg because for query that only has a few rows, it may waste memory and may hurt the performance in high concurrency scenarios + rowSizeHint := maxRowTableSegmentSize + if segNum == 0 { + rowSizeHint = int(firstSegSizeHint) + } + seg := newRowTableSegment(uint(rowSizeHint)) + htc.rowTables[workerID][partitionID].segments = append(htc.rowTables[workerID][partitionID].segments, seg) + segNum++ + } + return htc.rowTables[workerID][partitionID].segments[segNum-1] +} + +func (htc *hashTableContext) finalizeCurrentSeg(workerID, partitionID int, builder *rowTableBuilder) { + seg := htc.getCurrentRowSegment(workerID, partitionID, nil, false, 0) + builder.rowNumberInCurrentRowTableSeg[partitionID] = 0 + failpoint.Inject("finalizeCurrentSegPanic", nil) + seg.finalized = true + htc.memoryTracker.Consume(seg.totalUsedBytes()) +} + +func (htc *hashTableContext) mergeRowTablesToHashTable(tableMeta *TableMeta, partitionNumber uint) int { + rowTables := make([]*rowTable, partitionNumber) + for i := 0; i < int(partitionNumber); i++ { + rowTables[i] = newRowTable(tableMeta) + } + totalSegmentCnt := 0 + for _, rowTablesPerWorker := range htc.rowTables { + for partIdx, rt := range rowTablesPerWorker { + if rt == nil { + continue + } + rowTables[partIdx].merge(rt) + totalSegmentCnt += len(rt.segments) + } + } + for i := 0; i < int(partitionNumber); i++ { + htc.hashTable.tables[i] = newSubTable(rowTables[i]) + } + htc.rowTables = nil + return totalSegmentCnt +} + +// HashJoinCtxV2 is the hash join ctx used in hash join v2 +type HashJoinCtxV2 struct { + hashJoinCtxBase + partitionNumber uint + partitionMaskOffset int + ProbeKeyTypes []*types.FieldType + BuildKeyTypes []*types.FieldType + stats *hashJoinRuntimeStatsV2 + + RightAsBuildSide bool + BuildFilter expression.CNFExprs + ProbeFilter expression.CNFExprs + OtherCondition expression.CNFExprs + hashTableContext *hashTableContext + hashTableMeta *TableMeta + needScanRowTableAfterProbeDone bool + + LUsed, RUsed []int + LUsedInOtherCondition, RUsedInOtherCondition []int +} + +// partitionNumber is always power of 2 +func genHashJoinPartitionNumber(partitionHint uint) uint { + prevRet := uint(16) + currentRet := uint(8) + for currentRet != 0 { + if currentRet < partitionHint { + return prevRet + } + prevRet = currentRet + currentRet = currentRet >> 1 + } + return 1 +} + +func 
getPartitionMaskOffset(partitionNumber uint) int { + getMSBPos := func(num uint64) int { + ret := 0 + for num&1 != 1 { + num = num >> 1 + ret++ + } + if num != 1 { + // partitionNumber is always pow of 2 + panic("should not reach here") + } + return ret + } + msbPos := getMSBPos(uint64(partitionNumber)) + // top MSB bits in hash value will be used to partition data + return 64 - msbPos +} + +// SetupPartitionInfo set up partitionNumber and partitionMaskOffset based on concurrency +func (hCtx *HashJoinCtxV2) SetupPartitionInfo() { + hCtx.partitionNumber = genHashJoinPartitionNumber(hCtx.Concurrency) + hCtx.partitionMaskOffset = getPartitionMaskOffset(hCtx.partitionNumber) +} + +// initHashTableContext create hashTableContext for current HashJoinCtxV2 +func (hCtx *HashJoinCtxV2) initHashTableContext() { + hCtx.hashTableContext = &hashTableContext{} + hCtx.hashTableContext.rowTables = make([][]*rowTable, hCtx.Concurrency) + for index := range hCtx.hashTableContext.rowTables { + hCtx.hashTableContext.rowTables[index] = make([]*rowTable, hCtx.partitionNumber) + } + hCtx.hashTableContext.hashTable = &hashTableV2{ + tables: make([]*subTable, hCtx.partitionNumber), + partitionNumber: uint64(hCtx.partitionNumber), + } + hCtx.hashTableContext.memoryTracker = memory.NewTracker(memory.LabelForHashTableInHashJoinV2, -1) +} + +// ProbeSideTupleFetcherV2 reads tuples from ProbeSideExec and send them to ProbeWorkers. +type ProbeSideTupleFetcherV2 struct { + probeSideTupleFetcherBase + *HashJoinCtxV2 + canSkipProbeIfHashTableIsEmpty bool +} + +// ProbeWorkerV2 is the probe worker used in hash join v2 +type ProbeWorkerV2 struct { + probeWorkerBase + HashJoinCtx *HashJoinCtxV2 + // We build individual joinProbe for each join worker when use chunk-based + // execution, to avoid the concurrency of joiner.chk and joiner.selected. + JoinProbe ProbeV2 +} + +// BuildWorkerV2 is the build worker used in hash join v2 +type BuildWorkerV2 struct { + buildWorkerBase + HashJoinCtx *HashJoinCtxV2 + BuildTypes []*types.FieldType + HasNullableKey bool + WorkerID uint +} + +// NewJoinBuildWorkerV2 create a BuildWorkerV2 +func NewJoinBuildWorkerV2(ctx *HashJoinCtxV2, workID uint, buildSideExec exec.Executor, buildKeyColIdx []int, buildTypes []*types.FieldType) *BuildWorkerV2 { + hasNullableKey := false + for _, idx := range buildKeyColIdx { + if !mysql.HasNotNullFlag(buildTypes[idx].GetFlag()) { + hasNullableKey = true + break + } + } + worker := &BuildWorkerV2{ + HashJoinCtx: ctx, + BuildTypes: buildTypes, + WorkerID: workID, + HasNullableKey: hasNullableKey, + } + worker.BuildSideExec = buildSideExec + worker.BuildKeyColIdx = buildKeyColIdx + return worker +} + +// HashJoinV2Exec implements the hash join algorithm. +type HashJoinV2Exec struct { + exec.BaseExecutor + *HashJoinCtxV2 + + ProbeSideTupleFetcher *ProbeSideTupleFetcherV2 + ProbeWorkers []*ProbeWorkerV2 + BuildWorkers []*BuildWorkerV2 + + workerWg util.WaitGroupWrapper + waiterWg util.WaitGroupWrapper + + prepared bool +} + +// Close implements the Executor Close interface. 
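Editor's note: genHashJoinPartitionNumber rounds the concurrency hint up to a power of two (capped at 16), and getPartitionMaskOffset turns that into "how many top bits of the 64-bit hash select the partition". A worked sketch of the resulting routing, with hypothetical hash values:

package main

import (
	"fmt"
	"math/bits"
)

func main() {
	const partitionNumber = 8 // always a power of two in the executor

	// offset = 64 - log2(partitionNumber): the top 3 bits pick one of 8 partitions.
	maskOffset := 64 - bits.TrailingZeros64(partitionNumber)

	hashes := []uint64{
		0x0123456789abcdef,
		0xfedcba9876543210,
		0x8000000000000000,
	}
	for _, h := range hashes {
		partIdx := h >> uint(maskOffset)
		fmt.Printf("hash=%#016x -> partition %d\n", h, partIdx)
	}
}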
+func (e *HashJoinV2Exec) Close() error { + if e.closeCh != nil { + close(e.closeCh) + } + e.finished.Store(true) + if e.prepared { + if e.buildFinished != nil { + channel.Clear(e.buildFinished) + } + if e.joinResultCh != nil { + channel.Clear(e.joinResultCh) + } + if e.ProbeSideTupleFetcher.probeChkResourceCh != nil { + close(e.ProbeSideTupleFetcher.probeChkResourceCh) + channel.Clear(e.ProbeSideTupleFetcher.probeChkResourceCh) + } + for i := range e.ProbeSideTupleFetcher.probeResultChs { + channel.Clear(e.ProbeSideTupleFetcher.probeResultChs[i]) + } + for i := range e.ProbeWorkers { + close(e.ProbeWorkers[i].joinChkResourceCh) + channel.Clear(e.ProbeWorkers[i].joinChkResourceCh) + } + e.ProbeSideTupleFetcher.probeChkResourceCh = nil + e.waiterWg.Wait() + e.hashTableContext.reset() + } + for _, w := range e.ProbeWorkers { + w.joinChkResourceCh = nil + } + + if e.stats != nil { + defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), e.stats) + } + err := e.BaseExecutor.Close() + return err +} + +// Open implements the Executor Open interface. +func (e *HashJoinV2Exec) Open(ctx context.Context) error { + if err := e.BaseExecutor.Open(ctx); err != nil { + e.closeCh = nil + e.prepared = false + return err + } + e.prepared = false + needScanRowTableAfterProbeDone := e.ProbeWorkers[0].JoinProbe.NeedScanRowTable() + e.HashJoinCtxV2.needScanRowTableAfterProbeDone = needScanRowTableAfterProbeDone + if e.RightAsBuildSide { + e.hashTableMeta = newTableMeta(e.BuildWorkers[0].BuildKeyColIdx, e.BuildWorkers[0].BuildTypes, + e.BuildKeyTypes, e.ProbeKeyTypes, e.RUsedInOtherCondition, e.RUsed, needScanRowTableAfterProbeDone) + } else { + e.hashTableMeta = newTableMeta(e.BuildWorkers[0].BuildKeyColIdx, e.BuildWorkers[0].BuildTypes, + e.BuildKeyTypes, e.ProbeKeyTypes, e.LUsedInOtherCondition, e.LUsed, needScanRowTableAfterProbeDone) + } + e.HashJoinCtxV2.ChunkAllocPool = e.AllocPool + if e.memTracker != nil { + e.memTracker.Reset() + } else { + e.memTracker = memory.NewTracker(e.ID(), -1) + } + e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) + + e.diskTracker = disk.NewTracker(e.ID(), -1) + e.diskTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.DiskTracker) + + e.workerWg = util.WaitGroupWrapper{} + e.waiterWg = util.WaitGroupWrapper{} + e.closeCh = make(chan struct{}) + e.finished.Store(false) + + if e.RuntimeStats() != nil { + e.stats = &hashJoinRuntimeStatsV2{} + e.stats.concurrent = int(e.Concurrency) + } + return nil +} + +func (fetcher *ProbeSideTupleFetcherV2) shouldLimitProbeFetchSize() bool { + if fetcher.JoinType == plannercore.LeftOuterJoin && fetcher.RightAsBuildSide { + return true + } + if fetcher.JoinType == plannercore.RightOuterJoin && !fetcher.RightAsBuildSide { + return true + } + return false +} + +func (w *BuildWorkerV2) splitPartitionAndAppendToRowTable(typeCtx types.Context, srcChkCh chan *chunk.Chunk) (err error) { + cost := int64(0) + defer func() { + if w.HashJoinCtx.stats != nil { + atomic.AddInt64(&w.HashJoinCtx.stats.partitionData, cost) + setMaxValue(&w.HashJoinCtx.stats.maxPartitionData, cost) + } + }() + partitionNumber := w.HashJoinCtx.partitionNumber + hashJoinCtx := w.HashJoinCtx + + builder := createRowTableBuilder(w.BuildKeyColIdx, hashJoinCtx.BuildKeyTypes, partitionNumber, w.HasNullableKey, hashJoinCtx.BuildFilter != nil, hashJoinCtx.needScanRowTableAfterProbeDone) + + for chk := range srcChkCh { + start := time.Now() + err = builder.processOneChunk(chk, typeCtx, w.HashJoinCtx, int(w.WorkerID)) + 
failpoint.Inject("splitPartitionPanic", nil) + cost += int64(time.Since(start)) + if err != nil { + return err + } + } + start := time.Now() + builder.appendRemainingRowLocations(int(w.WorkerID), w.HashJoinCtx.hashTableContext) + cost += int64(time.Since(start)) + return nil +} + +func (e *HashJoinV2Exec) canSkipProbeIfHashTableIsEmpty() bool { + switch e.JoinType { + case plannercore.InnerJoin: + return true + case plannercore.LeftOuterJoin: + return !e.RightAsBuildSide + case plannercore.RightOuterJoin: + return e.RightAsBuildSide + case plannercore.SemiJoin: + return e.RightAsBuildSide + default: + return false + } +} + +func (e *HashJoinV2Exec) initializeForProbe() { + e.ProbeSideTupleFetcher.HashJoinCtxV2 = e.HashJoinCtxV2 + // e.joinResultCh is for transmitting the join result chunks to the main + // thread. + e.joinResultCh = make(chan *hashjoinWorkerResult, e.Concurrency+1) + e.ProbeSideTupleFetcher.initializeForProbeBase(e.Concurrency, e.joinResultCh) + e.ProbeSideTupleFetcher.canSkipProbeIfHashTableIsEmpty = e.canSkipProbeIfHashTableIsEmpty() + + for i := uint(0); i < e.Concurrency; i++ { + e.ProbeWorkers[i].initializeForProbe(e.ProbeSideTupleFetcher.probeChkResourceCh, e.ProbeSideTupleFetcher.probeResultChs[i], e) + e.ProbeWorkers[i].JoinProbe.ResetProbeCollision() + } +} + +func (e *HashJoinV2Exec) fetchAndProbeHashTable(ctx context.Context) { + e.initializeForProbe() + fetchProbeSideChunksFunc := func() { + defer trace.StartRegion(ctx, "HashJoinProbeSideFetcher").End() + e.ProbeSideTupleFetcher.fetchProbeSideChunks( + ctx, + e.MaxChunkSize(), + func() bool { return e.ProbeSideTupleFetcher.hashTableContext.hashTable.isHashTableEmpty() }, + e.ProbeSideTupleFetcher.canSkipProbeIfHashTableIsEmpty, + e.ProbeSideTupleFetcher.needScanRowTableAfterProbeDone, + e.ProbeSideTupleFetcher.shouldLimitProbeFetchSize(), + &e.ProbeSideTupleFetcher.hashJoinCtxBase) + } + e.workerWg.RunWithRecover(fetchProbeSideChunksFunc, e.ProbeSideTupleFetcher.handleProbeSideFetcherPanic) + + for i := uint(0); i < e.Concurrency; i++ { + workerID := i + e.workerWg.RunWithRecover(func() { + defer trace.StartRegion(ctx, "HashJoinWorker").End() + e.ProbeWorkers[workerID].runJoinWorker() + }, e.ProbeWorkers[workerID].handleProbeWorkerPanic) + } + e.waiterWg.RunWithRecover(e.waitJoinWorkersAndCloseResultChan, nil) +} + +func (w *ProbeWorkerV2) handleProbeWorkerPanic(r any) { + if r != nil { + w.HashJoinCtx.joinResultCh <- &hashjoinWorkerResult{err: util.GetRecoverError(r)} + } +} + +func (e *HashJoinV2Exec) handleJoinWorkerPanic(r any) { + if r != nil { + e.joinResultCh <- &hashjoinWorkerResult{err: util.GetRecoverError(r)} + } +} + +func (e *HashJoinV2Exec) waitJoinWorkersAndCloseResultChan() { + e.workerWg.Wait() + if e.stats != nil { + for _, prober := range e.ProbeWorkers { + e.stats.hashStat.probeCollision += int64(prober.JoinProbe.GetProbeCollision()) + } + } + if e.ProbeWorkers[0] != nil && e.ProbeWorkers[0].JoinProbe.NeedScanRowTable() { + for i := uint(0); i < e.Concurrency; i++ { + var workerID = i + e.workerWg.RunWithRecover(func() { + e.ProbeWorkers[workerID].scanRowTableAfterProbeDone() + }, e.handleJoinWorkerPanic) + } + e.workerWg.Wait() + } + close(e.joinResultCh) +} + +func (w *ProbeWorkerV2) scanRowTableAfterProbeDone() { + w.JoinProbe.InitForScanRowTable() + ok, joinResult := w.getNewJoinResult() + if !ok { + return + } + for !w.JoinProbe.IsScanRowTableDone() { + joinResult = w.JoinProbe.ScanRowTable(joinResult, &w.HashJoinCtx.SessCtx.GetSessionVars().SQLKiller) + if joinResult.err != nil { + 
w.HashJoinCtx.joinResultCh <- joinResult + return + } + if joinResult.chk.IsFull() { + w.HashJoinCtx.joinResultCh <- joinResult + ok, joinResult = w.getNewJoinResult() + if !ok { + return + } + } + } + if joinResult == nil { + return + } else if joinResult.err != nil || (joinResult.chk != nil && joinResult.chk.NumRows() > 0) { + w.HashJoinCtx.joinResultCh <- joinResult + } +} + +func (w *ProbeWorkerV2) processOneProbeChunk(probeChunk *chunk.Chunk, joinResult *hashjoinWorkerResult) (ok bool, waitTime int64, _ *hashjoinWorkerResult) { + waitTime = 0 + joinResult.err = w.JoinProbe.SetChunkForProbe(probeChunk) + if joinResult.err != nil { + return false, waitTime, joinResult + } + for !w.JoinProbe.IsCurrentChunkProbeDone() { + ok, joinResult = w.JoinProbe.Probe(joinResult, &w.HashJoinCtx.SessCtx.GetSessionVars().SQLKiller) + if !ok || joinResult.err != nil { + return ok, waitTime, joinResult + } + failpoint.Inject("processOneProbeChunkPanic", nil) + if joinResult.chk.IsFull() { + waitStart := time.Now() + w.HashJoinCtx.joinResultCh <- joinResult + ok, joinResult = w.getNewJoinResult() + waitTime += int64(time.Since(waitStart)) + if !ok { + return false, waitTime, joinResult + } + } + } + return true, waitTime, joinResult +} + +func (w *ProbeWorkerV2) runJoinWorker() { + probeTime := int64(0) + if w.HashJoinCtx.stats != nil { + start := time.Now() + defer func() { + t := time.Since(start) + atomic.AddInt64(&w.HashJoinCtx.stats.probe, probeTime) + atomic.AddInt64(&w.HashJoinCtx.stats.fetchAndProbe, int64(t)) + setMaxValue(&w.HashJoinCtx.stats.maxFetchAndProbe, int64(t)) + }() + } + + var ( + probeSideResult *chunk.Chunk + ) + ok, joinResult := w.getNewJoinResult() + if !ok { + return + } + + // Read and filter probeSideResult, and join the probeSideResult with the build side rows. + emptyProbeSideResult := &probeChkResource{ + dest: w.probeResultCh, + } + for ok := true; ok; { + if w.HashJoinCtx.finished.Load() { + break + } + select { + case <-w.HashJoinCtx.closeCh: + return + case probeSideResult, ok = <-w.probeResultCh: + } + failpoint.Inject("ConsumeRandomPanic", nil) + if !ok { + break + } + + start := time.Now() + waitTime := int64(0) + ok, waitTime, joinResult = w.processOneProbeChunk(probeSideResult, joinResult) + probeTime += int64(time.Since(start)) - waitTime + if !ok { + break + } + probeSideResult.Reset() + emptyProbeSideResult.chk = probeSideResult + w.probeChkResourceCh <- emptyProbeSideResult + } + // note joinResult.chk may be nil when getNewJoinResult fails in loops + if joinResult == nil { + return + } else if joinResult.err != nil || (joinResult.chk != nil && joinResult.chk.NumRows() > 0) { + w.HashJoinCtx.joinResultCh <- joinResult + } else if joinResult.chk != nil && joinResult.chk.NumRows() == 0 { + w.joinChkResourceCh <- joinResult.chk + } +} + +func (w *ProbeWorkerV2) getNewJoinResult() (bool, *hashjoinWorkerResult) { + joinResult := &hashjoinWorkerResult{ + src: w.joinChkResourceCh, + } + ok := true + select { + case <-w.HashJoinCtx.closeCh: + ok = false + case joinResult.chk, ok = <-w.joinChkResourceCh: + } + return ok, joinResult +} + +// Next implements the Executor Next interface. +// hash join constructs the result following these steps: +// step 1. fetch data from build side child and build a hash table; +// step 2. fetch data from probe child in a background goroutine and probe the hash table in multiple join workers. 
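Editor's note: probe workers and the main thread recycle chunks between them: a worker sends a full result chunk on joinResultCh, the main Next swaps its columns into the caller's chunk and hands the buffer back through result.src, and the worker picks it up again from its resource channel. A scaled-down sketch of that recycling loop, with a hypothetical buffer type:

package main

import "fmt"

type result struct {
	buf []int
	src chan []int // where the consumer returns the buffer
}

func main() {
	resourceCh := make(chan []int, 1)
	resultCh := make(chan result, 1)
	resourceCh <- make([]int, 0, 4) // one reusable buffer, like joinChkResourceCh

	go func() { // worker: fill a buffer, send it, wait for a recycled one
		defer close(resultCh)
		for batch := 0; batch < 3; batch++ {
			buf := <-resourceCh
			buf = buf[:0]
			for i := 0; i < 4; i++ {
				buf = append(buf, batch*4+i)
			}
			resultCh <- result{buf: buf, src: resourceCh}
		}
	}()

	for r := range resultCh { // main thread: consume, then hand the buffer back
		fmt.Println("got", r.buf)
		r.src <- r.buf // counterpart of `result.src <- result.chk`
	}
}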
+func (e *HashJoinV2Exec) Next(ctx context.Context, req *chunk.Chunk) (err error) { + if !e.prepared { + e.initHashTableContext() + e.hashTableContext.memoryTracker.AttachTo(e.memTracker) + e.buildFinished = make(chan error, 1) + e.workerWg.RunWithRecover(func() { + defer trace.StartRegion(ctx, "HashJoinHashTableBuilder").End() + e.fetchAndBuildHashTable(ctx) + }, e.handleFetchAndBuildHashTablePanic) + e.fetchAndProbeHashTable(ctx) + e.prepared = true + } + if e.ProbeSideTupleFetcher.shouldLimitProbeFetchSize() { + atomic.StoreInt64(&e.ProbeSideTupleFetcher.requiredRows, int64(req.RequiredRows())) + } + req.Reset() + + result, ok := <-e.joinResultCh + if !ok { + return nil + } + if result.err != nil { + e.finished.Store(true) + return result.err + } + req.SwapColumns(result.chk) + result.src <- result.chk + return nil +} + +func (e *HashJoinV2Exec) handleFetchAndBuildHashTablePanic(r any) { + if r != nil { + e.buildFinished <- util.GetRecoverError(r) + } + close(e.buildFinished) +} + +// checkBalance checks whether the segment count of each partition is balanced. +func (e *HashJoinV2Exec) checkBalance(totalSegmentCnt int) bool { + isBalanced := e.Concurrency == e.partitionNumber + if !isBalanced { + return false + } + avgSegCnt := totalSegmentCnt / int(e.partitionNumber) + balanceThreshold := int(float64(avgSegCnt) * 0.8) + subTables := e.HashJoinCtxV2.hashTableContext.hashTable.tables + + for _, subTable := range subTables { + if math.Abs(float64(len(subTable.rowData.segments)-avgSegCnt)) > float64(balanceThreshold) { + isBalanced = false + break + } + } + return isBalanced +} + +func (e *HashJoinV2Exec) createTasks(buildTaskCh chan<- *buildTask, totalSegmentCnt int, doneCh chan struct{}) { + isBalanced := e.checkBalance(totalSegmentCnt) + segStep := max(1, totalSegmentCnt/int(e.Concurrency)) + subTables := e.HashJoinCtxV2.hashTableContext.hashTable.tables + createBuildTask := func(partIdx int, segStartIdx int, segEndIdx int) *buildTask { + return &buildTask{partitionIdx: partIdx, segStartIdx: segStartIdx, segEndIdx: segEndIdx} + } + failpoint.Inject("createTasksPanic", nil) + + if isBalanced { + for partIdx, subTable := range subTables { + segmentsLen := len(subTable.rowData.segments) + select { + case <-doneCh: + return + case buildTaskCh <- createBuildTask(partIdx, 0, segmentsLen): + } + } + return + } + + partitionStartIndex := make([]int, len(subTables)) + partitionSegmentLength := make([]int, len(subTables)) + for i := 0; i < len(subTables); i++ { + partitionStartIndex[i] = 0 + partitionSegmentLength[i] = len(subTables[i].rowData.segments) + } + + for { + hasNewTask := false + for partIdx := range subTables { + // create table by round-robin all the partitions so the build thread is likely to build different partition at the same time + if partitionStartIndex[partIdx] < partitionSegmentLength[partIdx] { + startIndex := partitionStartIndex[partIdx] + endIndex := min(startIndex+segStep, partitionSegmentLength[partIdx]) + select { + case <-doneCh: + return + case buildTaskCh <- createBuildTask(partIdx, startIndex, endIndex): + } + partitionStartIndex[partIdx] = endIndex + hasNewTask = true + } + } + if !hasNewTask { + break + } + } +} + +func (e *HashJoinV2Exec) fetchAndBuildHashTable(ctx context.Context) { + if e.stats != nil { + start := time.Now() + defer func() { + e.stats.fetchAndBuildHashTable = time.Since(start) + }() + } + + waitJobDone := func(wg *sync.WaitGroup, errCh chan error) bool { + wg.Wait() + close(errCh) + if err := <-errCh; err != nil { + e.buildFinished <- err + 
return false + } + return true + } + + wg := new(sync.WaitGroup) + errCh := make(chan error, 1+e.Concurrency) + // doneCh is used by the consumer(splitAndAppendToRowTable) to info the producer(fetchBuildSideRows) that the consumer meet error and stop consume data + doneCh := make(chan struct{}, e.Concurrency) + srcChkCh := e.fetchBuildSideRows(ctx, wg, errCh, doneCh) + e.splitAndAppendToRowTable(srcChkCh, wg, errCh, doneCh) + success := waitJobDone(wg, errCh) + if !success { + return + } + + totalSegmentCnt := e.hashTableContext.mergeRowTablesToHashTable(e.hashTableMeta, e.partitionNumber) + + wg = new(sync.WaitGroup) + errCh = make(chan error, 1+e.Concurrency) + // doneCh is used by the consumer(buildHashTable) to info the producer(createBuildTasks) that the consumer meet error and stop consume data + doneCh = make(chan struct{}, e.Concurrency) + buildTaskCh := e.createBuildTasks(totalSegmentCnt, wg, errCh, doneCh) + e.buildHashTable(buildTaskCh, wg, errCh, doneCh) + waitJobDone(wg, errCh) +} + +func (e *HashJoinV2Exec) fetchBuildSideRows(ctx context.Context, wg *sync.WaitGroup, errCh chan error, doneCh chan struct{}) chan *chunk.Chunk { + srcChkCh := make(chan *chunk.Chunk, 1) + wg.Add(1) + e.workerWg.RunWithRecover( + func() { + defer trace.StartRegion(ctx, "HashJoinBuildSideFetcher").End() + fetcher := e.BuildWorkers[0] + fetcher.fetchBuildSideRows(ctx, &fetcher.HashJoinCtx.hashJoinCtxBase, srcChkCh, errCh, doneCh) + }, + func(r any) { + if r != nil { + errCh <- util.GetRecoverError(r) + } + wg.Done() + }, + ) + return srcChkCh +} + +func (e *HashJoinV2Exec) splitAndAppendToRowTable(srcChkCh chan *chunk.Chunk, wg *sync.WaitGroup, errCh chan error, doneCh chan struct{}) { + for i := uint(0); i < e.Concurrency; i++ { + wg.Add(1) + workIndex := i + e.workerWg.RunWithRecover( + func() { + err := e.BuildWorkers[workIndex].splitPartitionAndAppendToRowTable(e.SessCtx.GetSessionVars().StmtCtx.TypeCtx(), srcChkCh) + if err != nil { + errCh <- err + doneCh <- struct{}{} + } + }, + func(r any) { + if r != nil { + errCh <- util.GetRecoverError(r) + doneCh <- struct{}{} + } + wg.Done() + }, + ) + } +} + +func (e *HashJoinV2Exec) createBuildTasks(totalSegmentCnt int, wg *sync.WaitGroup, errCh chan error, doneCh chan struct{}) chan *buildTask { + buildTaskCh := make(chan *buildTask, e.Concurrency) + wg.Add(1) + e.workerWg.RunWithRecover( + func() { e.createTasks(buildTaskCh, totalSegmentCnt, doneCh) }, + func(r any) { + if r != nil { + errCh <- util.GetRecoverError(r) + } + close(buildTaskCh) + wg.Done() + }, + ) + return buildTaskCh +} + +func (e *HashJoinV2Exec) buildHashTable(buildTaskCh chan *buildTask, wg *sync.WaitGroup, errCh chan error, doneCh chan struct{}) { + for i := uint(0); i < e.Concurrency; i++ { + wg.Add(1) + workID := i + e.workerWg.RunWithRecover( + func() { + err := e.BuildWorkers[workID].buildHashTable(buildTaskCh) + if err != nil { + errCh <- err + doneCh <- struct{}{} + } + }, + func(r any) { + if r != nil { + errCh <- util.GetRecoverError(r) + doneCh <- struct{}{} + } + wg.Done() + }, + ) + } +} + +type buildTask struct { + partitionIdx int + segStartIdx int + segEndIdx int +} + +// buildHashTableForList builds hash table from `list`. 
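Editor's note: createTasks in the hunk above either emits one task per partition when the segment counts are balanced, or walks the partitions round-robin in segStep-sized slices so build workers tend to touch different partitions at the same time. A small sketch of the round-robin slicing, with plain ints in place of row-table segments:

package main

import "fmt"

type task struct{ part, segStart, segEnd int }

func main() {
	segsPerPartition := []int{10, 2, 7} // uneven: triggers the round-robin path
	concurrency := 4
	total := 0
	for _, n := range segsPerPartition {
		total += n
	}
	segStep := max(1, total/concurrency)

	cursor := make([]int, len(segsPerPartition))
	var tasks []task
	for {
		progressed := false
		for part, n := range segsPerPartition {
			if cursor[part] < n {
				end := min(cursor[part]+segStep, n)
				tasks = append(tasks, task{part, cursor[part], end})
				cursor[part] = end
				progressed = true
			}
		}
		if !progressed {
			break
		}
	}
	for _, t := range tasks {
		fmt.Printf("partition %d segments [%d, %d)\n", t.part, t.segStart, t.segEnd)
	}
}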
+func (w *BuildWorkerV2) buildHashTable(taskCh chan *buildTask) error { + cost := int64(0) + defer func() { + if w.HashJoinCtx.stats != nil { + atomic.AddInt64(&w.HashJoinCtx.stats.buildHashTable, cost) + setMaxValue(&w.HashJoinCtx.stats.maxBuildHashTable, cost) + } + }() + for task := range taskCh { + start := time.Now() + partIdx, segStartIdx, segEndIdx := task.partitionIdx, task.segStartIdx, task.segEndIdx + w.HashJoinCtx.hashTableContext.hashTable.tables[partIdx].build(segStartIdx, segEndIdx) + failpoint.Inject("buildHashTablePanic", nil) + cost += int64(time.Since(start)) + } + return nil +} + +type hashJoinRuntimeStatsV2 struct { + hashJoinRuntimeStats + partitionData int64 + maxPartitionData int64 + buildHashTable int64 + maxBuildHashTable int64 +} + +func setMaxValue(addr *int64, currentValue int64) { + for { + value := atomic.LoadInt64(addr) + if currentValue <= value { + return + } + if atomic.CompareAndSwapInt64(addr, value, currentValue) { + return + } + } +} + +// Tp implements the RuntimeStats interface. +func (*hashJoinRuntimeStatsV2) Tp() int { + return execdetails.TpHashJoinRuntimeStats +} + +func (e *hashJoinRuntimeStatsV2) String() string { + buf := bytes.NewBuffer(make([]byte, 0, 128)) + if e.fetchAndBuildHashTable > 0 { + buf.WriteString("build_hash_table:{concurrency:") + buf.WriteString(strconv.Itoa(e.concurrent)) + buf.WriteString(", total:") + buf.WriteString(execdetails.FormatDuration(e.fetchAndBuildHashTable)) + buf.WriteString(", fetch:") + buf.WriteString(execdetails.FormatDuration(time.Duration(int64(e.fetchAndBuildHashTable) - e.maxBuildHashTable - e.maxPartitionData))) + buf.WriteString(", partition:") + buf.WriteString(execdetails.FormatDuration(time.Duration(e.partitionData))) + buf.WriteString(", max partition:") + buf.WriteString(execdetails.FormatDuration(time.Duration(e.maxPartitionData))) + buf.WriteString(", build:") + buf.WriteString(execdetails.FormatDuration(time.Duration(e.buildHashTable))) + buf.WriteString(", max build:") + buf.WriteString(execdetails.FormatDuration(time.Duration(e.maxBuildHashTable))) + buf.WriteString("}") + } + if e.probe > 0 { + buf.WriteString(", probe:{concurrency:") + buf.WriteString(strconv.Itoa(e.concurrent)) + buf.WriteString(", total:") + buf.WriteString(execdetails.FormatDuration(time.Duration(e.fetchAndProbe))) + buf.WriteString(", max:") + buf.WriteString(execdetails.FormatDuration(time.Duration(atomic.LoadInt64(&e.maxFetchAndProbe)))) + buf.WriteString(", probe:") + buf.WriteString(execdetails.FormatDuration(time.Duration(e.probe))) + buf.WriteString(", fetch and wait:") + buf.WriteString(execdetails.FormatDuration(time.Duration(e.fetchAndProbe - e.probe))) + if e.hashStat.probeCollision > 0 { + buf.WriteString(", probe_collision:") + buf.WriteString(strconv.FormatInt(e.hashStat.probeCollision, 10)) + } + buf.WriteString("}") + } + return buf.String() +} + +func (e *hashJoinRuntimeStatsV2) Clone() execdetails.RuntimeStats { + stats := hashJoinRuntimeStats{ + fetchAndBuildHashTable: e.fetchAndBuildHashTable, + hashStat: e.hashStat, + fetchAndProbe: e.fetchAndProbe, + probe: e.probe, + concurrent: e.concurrent, + maxFetchAndProbe: e.maxFetchAndProbe, + } + return &hashJoinRuntimeStatsV2{ + hashJoinRuntimeStats: stats, + partitionData: e.partitionData, + maxPartitionData: e.maxPartitionData, + buildHashTable: e.buildHashTable, + maxBuildHashTable: e.maxBuildHashTable, + } +} + +func (e *hashJoinRuntimeStatsV2) Merge(rs execdetails.RuntimeStats) { + tmp, ok := rs.(*hashJoinRuntimeStatsV2) + if !ok { + return + } + 
e.fetchAndBuildHashTable += tmp.fetchAndBuildHashTable + e.buildHashTable += tmp.buildHashTable + if e.maxBuildHashTable < tmp.maxBuildHashTable { + e.maxBuildHashTable = tmp.maxBuildHashTable + } + e.partitionData += tmp.partitionData + if e.maxPartitionData < tmp.maxPartitionData { + e.maxPartitionData = tmp.maxPartitionData + } + e.hashStat.buildTableElapse += tmp.hashStat.buildTableElapse + e.hashStat.probeCollision += tmp.hashStat.probeCollision + e.fetchAndProbe += tmp.fetchAndProbe + e.probe += tmp.probe + if e.maxFetchAndProbe < tmp.maxFetchAndProbe { + e.maxFetchAndProbe = tmp.maxFetchAndProbe + } +} diff --git a/pkg/executor/join/index_lookup_hash_join.go b/pkg/executor/join/index_lookup_hash_join.go index fd368ef77911d..7d0c7f591a237 100644 --- a/pkg/executor/join/index_lookup_hash_join.go +++ b/pkg/executor/join/index_lookup_hash_join.go @@ -338,11 +338,11 @@ func (ow *indexHashJoinOuterWorker) run(ctx context.Context) { defer trace.StartRegion(ctx, "IndexHashJoinOuterWorker").End() defer close(ow.innerCh) for { - failpoint.Inject("TestIssue30211", nil) + failpoint.Eval(_curpkg_("TestIssue30211")) task, err := ow.buildTask(ctx) - failpoint.Inject("testIndexHashJoinOuterWorkerErr", func() { + if _, _err_ := failpoint.Eval(_curpkg_("testIndexHashJoinOuterWorkerErr")); _err_ == nil { err = errors.New("mockIndexHashJoinOuterWorkerErr") - }) + } if err != nil { task = &indexHashJoinTask{err: err} if ow.keepOuterOrder { @@ -362,9 +362,9 @@ func (ow *indexHashJoinOuterWorker) run(ctx context.Context) { return } if ow.keepOuterOrder { - failpoint.Inject("testIssue20779", func() { + if _, _err_ := failpoint.Eval(_curpkg_("testIssue20779")); _err_ == nil { panic("testIssue20779") - }) + } if finished := ow.pushToChan(ctx, task, ow.taskCh); finished { return } @@ -531,9 +531,9 @@ func (iw *indexHashJoinInnerWorker) run(ctx context.Context, cancelFunc context. } } } - failpoint.Inject("testIndexHashJoinInnerWorkerErr", func() { + if _, _err_ := failpoint.Eval(_curpkg_("testIndexHashJoinInnerWorkerErr")); _err_ == nil { joinResult.err = errors.New("mockIndexHashJoinInnerWorkerErr") - }) + } // When task.KeepOuterOrder is TRUE (resultCh != iw.resultCh): // - the last joinResult will be handled when the task has been processed, // thus we DO NOT need to check it here again. @@ -572,8 +572,8 @@ func (iw *indexHashJoinInnerWorker) getNewJoinResult(ctx context.Context) (*inde } func (iw *indexHashJoinInnerWorker) buildHashTableForOuterResult(task *indexHashJoinTask, h hash.Hash64) { - failpoint.Inject("IndexHashJoinBuildHashTablePanic", nil) - failpoint.Inject("ConsumeRandomPanic", nil) + failpoint.Eval(_curpkg_("IndexHashJoinBuildHashTablePanic")) + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) if iw.stats != nil { start := time.Now() defer func() { @@ -602,9 +602,9 @@ func (iw *indexHashJoinInnerWorker) buildHashTableForOuterResult(task *indexHash } h.Reset() err := codec.HashChunkRow(iw.ctx.GetSessionVars().StmtCtx.TypeCtx(), h, row, iw.outerCtx.HashTypes, hashColIdx, buf) - failpoint.Inject("testIndexHashJoinBuildErr", func() { + if _, _err_ := failpoint.Eval(_curpkg_("testIndexHashJoinBuildErr")); _err_ == nil { err = errors.New("mockIndexHashJoinBuildErr") - }) + } if err != nil { // This panic will be recovered by the invoker. 
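Editor's note: the hunks above, together with the __failpoint_stash__ copies, look like the output of the failpoint rewrite step (presumably failpoint-ctl): failpoint.Inject markers in the source are expanded into failpoint.Eval(_curpkg_(...)) calls and the original file is kept aside as a stash. In ordinary source and tests only the marker form is written; a minimal sketch of writing and activating one, with a hypothetical failpoint name and package path:

package main

// Illustrative only: the marker below stays inert until the failpoint rewrite
// shown in this patch (failpoint.Inject -> failpoint.Eval(_curpkg_(...))) has
// been applied to the build.

import (
	"fmt"

	"github.com/pingcap/failpoint"
)

func buildStep() error {
	var err error
	failpoint.Inject("mockBuildErr", func(_ failpoint.Value) {
		err = fmt.Errorf("injected build error")
	})
	return err
}

func main() {
	// A test enables the failpoint by its full path (package path + name);
	// the path here is hypothetical.
	_ = failpoint.Enable("example.com/demo/mockBuildErr", "return(true)")
	defer func() { _ = failpoint.Disable("example.com/demo/mockBuildErr") }()

	fmt.Println("buildStep:", buildStep())
}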
panic(err.Error()) @@ -680,10 +680,10 @@ func (iw *indexHashJoinInnerWorker) handleTask(ctx context.Context, task *indexH iw.wg.Wait() // check error after wg.Wait to make sure error message can be sent to // resultCh even if panic happen in buildHashTableForOuterResult. - failpoint.Inject("IndexHashJoinFetchInnerResultsErr", func() { + if _, _err_ := failpoint.Eval(_curpkg_("IndexHashJoinFetchInnerResultsErr")); _err_ == nil { err = errors.New("IndexHashJoinFetchInnerResultsErr") - }) - failpoint.Inject("ConsumeRandomPanic", nil) + } + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) if err != nil { return err } @@ -789,7 +789,7 @@ func (iw *indexHashJoinInnerWorker) joinMatchedInnerRow2Chunk(ctx context.Contex joinResult.err = ctx.Err() return false, joinResult } - failpoint.InjectCall("joinMatchedInnerRow2Chunk") + failpoint.Call(_curpkg_("joinMatchedInnerRow2Chunk")) joinResult, ok = iw.getNewJoinResult(ctx) if !ok { return false, joinResult @@ -837,9 +837,9 @@ func (iw *indexHashJoinInnerWorker) doJoinInOrder(ctx context.Context, task *ind row := chk.GetRow(j) ptr := chunk.RowPtr{ChkIdx: uint32(i), RowIdx: uint32(j)} err = iw.collectMatchedInnerPtrs4OuterRows(row, ptr, task, h, iw.joinKeyBuf) - failpoint.Inject("TestIssue31129", func() { + if _, _err_ := failpoint.Eval(_curpkg_("TestIssue31129")); _err_ == nil { err = errors.New("TestIssue31129") - }) + } if err != nil { return err } diff --git a/pkg/executor/join/index_lookup_hash_join.go__failpoint_stash__ b/pkg/executor/join/index_lookup_hash_join.go__failpoint_stash__ new file mode 100644 index 0000000000000..fd368ef77911d --- /dev/null +++ b/pkg/executor/join/index_lookup_hash_join.go__failpoint_stash__ @@ -0,0 +1,884 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package join + +import ( + "context" + "fmt" + "hash" + "hash/fnv" + "runtime/trace" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/channel" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/ranger" +) + +// numResChkHold indicates the number of resource chunks that an inner worker +// holds at the same time. +// It's used in 2 cases individually: +// 1. IndexMergeJoin +// 2. IndexNestedLoopHashJoin: +// It's used when IndexNestedLoopHashJoin.KeepOuterOrder is true. +// Otherwise, there will be at most `concurrency` resource chunks throughout +// the execution of IndexNestedLoopHashJoin. +const numResChkHold = 4 + +// IndexNestedLoopHashJoin employs one outer worker and N inner workers to +// execute concurrently. The output order is not promised. +// +// The execution flow is very similar to IndexLookUpReader: +// 1. 
The outer worker reads N outer rows, builds a task and sends it to the +// inner worker channel. +// 2. The inner worker receives the tasks and does 3 things for every task: +// 1. builds hash table from the outer rows +// 2. builds key ranges from outer rows and fetches inner rows +// 3. probes the hash table and sends the join result to the main thread channel. +// Note: step 1 and step 2 runs concurrently. +// +// 3. The main thread receives the join results. +type IndexNestedLoopHashJoin struct { + IndexLookUpJoin + resultCh chan *indexHashJoinResult + joinChkResourceCh []chan *chunk.Chunk + // We build individual joiner for each inner worker when using chunk-based + // execution, to avoid the concurrency of joiner.chk and joiner.selected. + Joiners []Joiner + KeepOuterOrder bool + curTask *indexHashJoinTask + // taskCh is only used when `KeepOuterOrder` is true. + taskCh chan *indexHashJoinTask + + stats *indexLookUpJoinRuntimeStats + prepared bool + // panicErr records the error generated by panic recover. This is introduced to + // return the actual error message instead of `context cancelled` to the client. + panicErr error + ctxWithCancel context.Context +} + +type indexHashJoinOuterWorker struct { + outerWorker + innerCh chan *indexHashJoinTask + keepOuterOrder bool + // taskCh is only used when the outer order needs to be promised. + taskCh chan *indexHashJoinTask +} + +type indexHashJoinInnerWorker struct { + innerWorker + joiner Joiner + joinChkResourceCh chan *chunk.Chunk + // resultCh is valid only when indexNestedLoopHashJoin do not need to keep + // order. Otherwise, it will be nil. + resultCh chan *indexHashJoinResult + taskCh <-chan *indexHashJoinTask + wg *sync.WaitGroup + joinKeyBuf []byte + outerRowStatus []outerRowStatusFlag + rowIter *chunk.Iterator4Slice +} + +type indexHashJoinResult struct { + chk *chunk.Chunk + err error + src chan<- *chunk.Chunk +} + +type indexHashJoinTask struct { + *lookUpJoinTask + outerRowStatus [][]outerRowStatusFlag + lookupMap BaseHashTable + err error + keepOuterOrder bool + // resultCh is only used when the outer order needs to be promised. + resultCh chan *indexHashJoinResult + // matchedInnerRowPtrs is only valid when the outer order needs to be + // promised. Otherwise, it will be nil. + // len(matchedInnerRowPtrs) equals to + // lookUpJoinTask.outerResult.NumChunks(), and the elements of every + // matchedInnerRowPtrs[chkIdx][rowIdx] indicates the matched inner row ptrs + // of the corresponding outer row. + matchedInnerRowPtrs [][][]chunk.RowPtr +} + +// Open implements the IndexNestedLoopHashJoin Executor interface. 
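Editor's note: when KeepOuterOrder is set, each indexHashJoinTask carries its own resultCh and the main thread drains tasks strictly in the order the outer worker produced them, instead of reading one shared result channel. A condensed sketch of that ordering scheme, with toy payloads:

package main

import (
	"fmt"
	"math/rand"
	"time"
)

type task struct {
	id       int
	resultCh chan string // per-task channel, as when keepOuterOrder is true
}

func main() {
	taskCh := make(chan *task, 4)
	for i := 0; i < 4; i++ {
		t := &task{id: i, resultCh: make(chan string, 1)}
		taskCh <- t
		go func(t *task) { // inner workers may finish in any order
			time.Sleep(time.Duration(rand.Intn(30)) * time.Millisecond)
			t.resultCh <- fmt.Sprintf("result of task %d", t.id)
			close(t.resultCh)
		}(t)
	}
	close(taskCh)

	// The consumer still sees results in task (outer) order.
	for t := range taskCh {
		for r := range t.resultCh {
			fmt.Println(r)
		}
	}
}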
+func (e *IndexNestedLoopHashJoin) Open(ctx context.Context) error { + err := exec.Open(ctx, e.Children(0)) + if err != nil { + return err + } + if e.memTracker != nil { + e.memTracker.Reset() + } else { + e.memTracker = memory.NewTracker(e.ID(), -1) + } + e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) + e.cancelFunc = nil + e.innerPtrBytes = make([][]byte, 0, 8) + if e.RuntimeStats() != nil { + e.stats = &indexLookUpJoinRuntimeStats{} + } + e.Finished.Store(false) + return nil +} + +func (e *IndexNestedLoopHashJoin) startWorkers(ctx context.Context, initBatchSize int) { + concurrency := e.Ctx().GetSessionVars().IndexLookupJoinConcurrency() + if e.stats != nil { + e.stats.concurrency = concurrency + } + workerCtx, cancelFunc := context.WithCancel(ctx) + e.ctxWithCancel, e.cancelFunc = workerCtx, cancelFunc + innerCh := make(chan *indexHashJoinTask, concurrency) + if e.KeepOuterOrder { + e.taskCh = make(chan *indexHashJoinTask, concurrency) + // When `KeepOuterOrder` is true, each task holds their own `resultCh` + // individually, thus we do not need a global resultCh. + e.resultCh = nil + } else { + e.resultCh = make(chan *indexHashJoinResult, concurrency) + } + e.joinChkResourceCh = make([]chan *chunk.Chunk, concurrency) + e.WorkerWg.Add(1) + ow := e.newOuterWorker(innerCh, initBatchSize) + go util.WithRecovery(func() { ow.run(e.ctxWithCancel) }, e.finishJoinWorkers) + + for i := 0; i < concurrency; i++ { + if !e.KeepOuterOrder { + e.joinChkResourceCh[i] = make(chan *chunk.Chunk, 1) + e.joinChkResourceCh[i] <- exec.NewFirstChunk(e) + } else { + e.joinChkResourceCh[i] = make(chan *chunk.Chunk, numResChkHold) + for j := 0; j < numResChkHold; j++ { + e.joinChkResourceCh[i] <- exec.NewFirstChunk(e) + } + } + } + + e.WorkerWg.Add(concurrency) + for i := 0; i < concurrency; i++ { + workerID := i + go util.WithRecovery(func() { e.newInnerWorker(innerCh, workerID).run(e.ctxWithCancel, cancelFunc) }, e.finishJoinWorkers) + } + go e.wait4JoinWorkers() +} + +func (e *IndexNestedLoopHashJoin) finishJoinWorkers(r any) { + if r != nil { + e.IndexLookUpJoin.Finished.Store(true) + err := fmt.Errorf("%v", r) + if recoverdErr, ok := r.(error); ok { + err = recoverdErr + } + if !e.KeepOuterOrder { + e.resultCh <- &indexHashJoinResult{err: err} + } else { + task := &indexHashJoinTask{err: err} + e.taskCh <- task + } + e.panicErr = err + if e.cancelFunc != nil { + e.cancelFunc() + } + } + e.WorkerWg.Done() +} + +func (e *IndexNestedLoopHashJoin) wait4JoinWorkers() { + e.WorkerWg.Wait() + if e.resultCh != nil { + close(e.resultCh) + } + if e.taskCh != nil { + close(e.taskCh) + } +} + +// Next implements the IndexNestedLoopHashJoin Executor interface. 
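Editor's note: startWorkers launches every worker through a recovery wrapper, and finishJoinWorkers converts a panic into an error pushed onto the result or task channel so Next returns it instead of hanging. A minimal version of that wrapper (hypothetical names; the real code uses util.WithRecovery):

package main

import (
	"fmt"
	"sync"
)

func withRecovery(work func(), onExit func(r any)) {
	defer func() { onExit(recover()) }()
	work()
}

func main() {
	errCh := make(chan error, 1)
	var wg sync.WaitGroup

	wg.Add(1)
	go withRecovery(
		func() { panic("boom in join worker") },
		func(r any) {
			if r != nil {
				// Surface the panic as an error on the channel the caller reads,
				// instead of crashing the whole process.
				errCh <- fmt.Errorf("%v", r)
			}
			wg.Done()
		},
	)

	wg.Wait()
	fmt.Println("worker error:", <-errCh)
}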
+func (e *IndexNestedLoopHashJoin) Next(ctx context.Context, req *chunk.Chunk) error { + if !e.prepared { + e.startWorkers(ctx, req.RequiredRows()) + e.prepared = true + } + req.Reset() + if e.KeepOuterOrder { + return e.runInOrder(e.ctxWithCancel, req) + } + return e.runUnordered(e.ctxWithCancel, req) +} + +func (e *IndexNestedLoopHashJoin) runInOrder(ctx context.Context, req *chunk.Chunk) error { + for { + if e.isDryUpTasks(ctx) { + return e.panicErr + } + if e.curTask.err != nil { + return e.curTask.err + } + result, err := e.getResultFromChannel(ctx, e.curTask.resultCh) + if err != nil { + return err + } + if result == nil { + e.curTask = nil + continue + } + return e.handleResult(req, result) + } +} + +func (e *IndexNestedLoopHashJoin) runUnordered(ctx context.Context, req *chunk.Chunk) error { + result, err := e.getResultFromChannel(ctx, e.resultCh) + if err != nil { + return err + } + return e.handleResult(req, result) +} + +// isDryUpTasks indicates whether all the tasks have been processed. +func (e *IndexNestedLoopHashJoin) isDryUpTasks(ctx context.Context) bool { + if e.curTask != nil { + return false + } + var ok bool + select { + case e.curTask, ok = <-e.taskCh: + if !ok { + return true + } + case <-ctx.Done(): + return true + } + return false +} + +func (e *IndexNestedLoopHashJoin) getResultFromChannel(ctx context.Context, resultCh <-chan *indexHashJoinResult) (*indexHashJoinResult, error) { + var ( + result *indexHashJoinResult + ok bool + ) + select { + case result, ok = <-resultCh: + if !ok { + return nil, nil + } + if result.err != nil { + return nil, result.err + } + case <-ctx.Done(): + err := e.panicErr + if err == nil { + err = ctx.Err() + } + return nil, err + } + return result, nil +} + +func (*IndexNestedLoopHashJoin) handleResult(req *chunk.Chunk, result *indexHashJoinResult) error { + if result == nil { + return nil + } + req.SwapColumns(result.chk) + result.src <- result.chk + return nil +} + +// Close implements the IndexNestedLoopHashJoin Executor interface. +func (e *IndexNestedLoopHashJoin) Close() error { + if e.stats != nil { + defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), e.stats) + } + if e.cancelFunc != nil { + e.cancelFunc() + } + if e.resultCh != nil { + channel.Clear(e.resultCh) + e.resultCh = nil + } + if e.taskCh != nil { + channel.Clear(e.taskCh) + e.taskCh = nil + } + for i := range e.joinChkResourceCh { + close(e.joinChkResourceCh[i]) + } + e.joinChkResourceCh = nil + e.Finished.Store(false) + e.prepared = false + e.ctxWithCancel = nil + return e.BaseExecutor.Close() +} + +func (ow *indexHashJoinOuterWorker) run(ctx context.Context) { + defer trace.StartRegion(ctx, "IndexHashJoinOuterWorker").End() + defer close(ow.innerCh) + for { + failpoint.Inject("TestIssue30211", nil) + task, err := ow.buildTask(ctx) + failpoint.Inject("testIndexHashJoinOuterWorkerErr", func() { + err = errors.New("mockIndexHashJoinOuterWorkerErr") + }) + if err != nil { + task = &indexHashJoinTask{err: err} + if ow.keepOuterOrder { + // The outerBuilder and innerFetcher run concurrently, we may + // get 2 errors at simultaneously. Thus the capacity of task.resultCh + // needs to be initialized to 2 to avoid waiting. 
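The capacity-2 buffer discussed in the comment above (and allocated on the next line) matters because two producers — the outer builder and the inner fetcher — may each report an error while nobody is still reading; a buffer at least as large as the number of potential senders keeps those sends from blocking forever. A toy illustration (the producer names are invented):

package main

import (
	"errors"
	"fmt"
)

func main() {
	// Two goroutines may each report one error; a buffer of 2 guarantees
	// neither send can block, so no goroutine leaks even if the reader
	// stops after consuming the first error.
	errCh := make(chan error, 2)

	go func() { errCh <- errors.New("outer builder failed") }()
	go func() { errCh <- errors.New("inner fetcher failed") }()

	fmt.Println(<-errCh) // the reader may only ever consume one of them
}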
+ task.keepOuterOrder, task.resultCh = true, make(chan *indexHashJoinResult, 2) + ow.pushToChan(ctx, task, ow.taskCh) + } + ow.pushToChan(ctx, task, ow.innerCh) + return + } + if task == nil { + return + } + if finished := ow.pushToChan(ctx, task, ow.innerCh); finished { + return + } + if ow.keepOuterOrder { + failpoint.Inject("testIssue20779", func() { + panic("testIssue20779") + }) + if finished := ow.pushToChan(ctx, task, ow.taskCh); finished { + return + } + } + } +} + +func (ow *indexHashJoinOuterWorker) buildTask(ctx context.Context) (*indexHashJoinTask, error) { + task, err := ow.outerWorker.buildTask(ctx) + if task == nil || err != nil { + return nil, err + } + var ( + resultCh chan *indexHashJoinResult + matchedInnerRowPtrs [][][]chunk.RowPtr + ) + if ow.keepOuterOrder { + resultCh = make(chan *indexHashJoinResult, numResChkHold) + matchedInnerRowPtrs = make([][][]chunk.RowPtr, task.outerResult.NumChunks()) + for i := range matchedInnerRowPtrs { + matchedInnerRowPtrs[i] = make([][]chunk.RowPtr, task.outerResult.GetChunk(i).NumRows()) + } + } + numChks := task.outerResult.NumChunks() + outerRowStatus := make([][]outerRowStatusFlag, numChks) + for i := 0; i < numChks; i++ { + outerRowStatus[i] = make([]outerRowStatusFlag, task.outerResult.GetChunk(i).NumRows()) + } + return &indexHashJoinTask{ + lookUpJoinTask: task, + outerRowStatus: outerRowStatus, + keepOuterOrder: ow.keepOuterOrder, + resultCh: resultCh, + matchedInnerRowPtrs: matchedInnerRowPtrs, + }, nil +} + +func (*indexHashJoinOuterWorker) pushToChan(ctx context.Context, task *indexHashJoinTask, dst chan<- *indexHashJoinTask) bool { + select { + case <-ctx.Done(): + return true + case dst <- task: + } + return false +} + +func (e *IndexNestedLoopHashJoin) newOuterWorker(innerCh chan *indexHashJoinTask, initBatchSize int) *indexHashJoinOuterWorker { + maxBatchSize := e.Ctx().GetSessionVars().IndexJoinBatchSize + batchSize := min(initBatchSize, maxBatchSize) + ow := &indexHashJoinOuterWorker{ + outerWorker: outerWorker{ + OuterCtx: e.OuterCtx, + ctx: e.Ctx(), + executor: e.Children(0), + batchSize: batchSize, + maxBatchSize: maxBatchSize, + parentMemTracker: e.memTracker, + lookup: &e.IndexLookUpJoin, + }, + innerCh: innerCh, + keepOuterOrder: e.KeepOuterOrder, + taskCh: e.taskCh, + } + return ow +} + +func (e *IndexNestedLoopHashJoin) newInnerWorker(taskCh chan *indexHashJoinTask, workerID int) *indexHashJoinInnerWorker { + // Since multiple inner workers run concurrently, we should copy join's IndexRanges for every worker to avoid data race. 
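As the comment above says, each inner worker gets its own clone of the index ranges so that per-worker range handling cannot race on a shared slice. A generic sketch of the copy-per-worker idea (the keyRange type and worker count are placeholders, not TiDB types):

package main

import "fmt"

type keyRange struct{ low, high int }

func clone(rs []*keyRange) []*keyRange {
	out := make([]*keyRange, 0, len(rs))
	for _, r := range rs {
		c := *r // each worker gets its own copy it may mutate freely
		out = append(out, &c)
	}
	return out
}

func main() {
	shared := []*keyRange{{1, 10}, {20, 30}}
	for w := 0; w < 3; w++ {
		mine := clone(shared)
		mine[0].low = w // no data race: only the worker-local copy is touched
	}
	fmt.Println(*shared[0]) // the shared ranges are untouched: {1 10}
}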
+ copiedRanges := make([]*ranger.Range, 0, len(e.IndexRanges.Range())) + for _, ran := range e.IndexRanges.Range() { + copiedRanges = append(copiedRanges, ran.Clone()) + } + var innerStats *innerWorkerRuntimeStats + if e.stats != nil { + innerStats = &e.stats.innerWorker + } + iw := &indexHashJoinInnerWorker{ + innerWorker: innerWorker{ + InnerCtx: e.InnerCtx, + outerCtx: e.OuterCtx, + ctx: e.Ctx(), + executorChk: e.AllocPool.Alloc(e.InnerCtx.RowTypes, e.MaxChunkSize(), e.MaxChunkSize()), + indexRanges: copiedRanges, + keyOff2IdxOff: e.KeyOff2IdxOff, + stats: innerStats, + lookup: &e.IndexLookUpJoin, + memTracker: memory.NewTracker(memory.LabelForIndexJoinInnerWorker, -1), + }, + taskCh: taskCh, + joiner: e.Joiners[workerID], + joinChkResourceCh: e.joinChkResourceCh[workerID], + resultCh: e.resultCh, + joinKeyBuf: make([]byte, 1), + outerRowStatus: make([]outerRowStatusFlag, 0, e.MaxChunkSize()), + rowIter: chunk.NewIterator4Slice([]chunk.Row{}), + } + iw.memTracker.AttachTo(e.memTracker) + if len(copiedRanges) != 0 { + // We should not consume this memory usage in `iw.memTracker`. The + // memory usage of inner worker will be reset the end of iw.handleTask. + // While the life cycle of this memory consumption exists throughout the + // whole active period of inner worker. + e.Ctx().GetSessionVars().StmtCtx.MemTracker.Consume(2 * types.EstimatedMemUsage(copiedRanges[0].LowVal, len(copiedRanges))) + } + if e.LastColHelper != nil { + // nextCwf.TmpConstant needs to be reset for every individual + // inner worker to avoid data race when the inner workers is running + // concurrently. + nextCwf := *e.LastColHelper + nextCwf.TmpConstant = make([]*expression.Constant, len(e.LastColHelper.TmpConstant)) + for i := range e.LastColHelper.TmpConstant { + nextCwf.TmpConstant[i] = &expression.Constant{RetType: nextCwf.TargetCol.RetType} + } + iw.nextColCompareFilters = &nextCwf + } + return iw +} + +func (iw *indexHashJoinInnerWorker) run(ctx context.Context, cancelFunc context.CancelFunc) { + defer trace.StartRegion(ctx, "IndexHashJoinInnerWorker").End() + var task *indexHashJoinTask + joinResult, ok := iw.getNewJoinResult(ctx) + if !ok { + cancelFunc() + return + } + h, resultCh := fnv.New64(), iw.resultCh + for { + // The previous task has been processed, so release the occupied memory + if task != nil { + task.memTracker.Detach() + } + select { + case <-ctx.Done(): + return + case task, ok = <-iw.taskCh: + } + if !ok { + break + } + // We need to init resultCh before the err is returned. + if task.keepOuterOrder { + resultCh = task.resultCh + } + if task.err != nil { + joinResult.err = task.err + break + } + err := iw.handleTask(ctx, task, joinResult, h, resultCh) + if err != nil && !task.keepOuterOrder { + // Only need check non-keep-outer-order case because the + // `joinResult` had been sent to the `resultCh` when err != nil. + joinResult.err = err + break + } + if task.keepOuterOrder { + // We need to get a new result holder here because the old + // `joinResult` hash been sent to the `resultCh` or to the + // `joinChkResourceCh`. + joinResult, ok = iw.getNewJoinResult(ctx) + if !ok { + cancelFunc() + return + } + } + } + failpoint.Inject("testIndexHashJoinInnerWorkerErr", func() { + joinResult.err = errors.New("mockIndexHashJoinInnerWorkerErr") + }) + // When task.KeepOuterOrder is TRUE (resultCh != iw.resultCh): + // - the last joinResult will be handled when the task has been processed, + // thus we DO NOT need to check it here again. 
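getNewJoinResult and `result.src <- result.chk` in the worker loop above form a small object pool: the worker borrows a chunk from joinChkResourceCh, fills it, hands it to the reader, and the reader returns the chunk through `src` after swapping the data out. A toy version of that recycle loop (the chunk type here is a stand-in, not chunk.Chunk):

package main

import "fmt"

type chunk struct{ rows []int }

func main() {
	pool := make(chan *chunk, 1)
	pool <- &chunk{} // pre-filled, like joinChkResourceCh[i] <- exec.NewFirstChunk(e)

	results := make(chan *chunk, 1)

	// Worker: borrow a chunk, fill it, hand it to the consumer.
	go func() {
		c := <-pool
		c.rows = append(c.rows[:0], 1, 2, 3)
		results <- c
	}()

	// Consumer: copy the data out, then return the chunk to the pool so the
	// worker can reuse the same allocation for the next batch.
	c := <-results
	fmt.Println(c.rows)
	pool <- c
}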
+ // - we DO NOT check the error here neither, because: + // - if the error is from task.err, the main thread will check the error of each task + // - if the error is from handleTask, the error will be handled in handleTask + // We should not check `task != nil && !task.KeepOuterOrder` here since it's + // possible that `join.chk.NumRows > 0` is true even if task == nil. + if resultCh == iw.resultCh { + if joinResult.err != nil { + resultCh <- joinResult + return + } + if joinResult.chk != nil && joinResult.chk.NumRows() > 0 { + select { + case resultCh <- joinResult: + case <-ctx.Done(): + return + } + } + } +} + +func (iw *indexHashJoinInnerWorker) getNewJoinResult(ctx context.Context) (*indexHashJoinResult, bool) { + joinResult := &indexHashJoinResult{ + src: iw.joinChkResourceCh, + } + ok := true + select { + case joinResult.chk, ok = <-iw.joinChkResourceCh: + case <-ctx.Done(): + joinResult.err = ctx.Err() + return joinResult, false + } + return joinResult, ok +} + +func (iw *indexHashJoinInnerWorker) buildHashTableForOuterResult(task *indexHashJoinTask, h hash.Hash64) { + failpoint.Inject("IndexHashJoinBuildHashTablePanic", nil) + failpoint.Inject("ConsumeRandomPanic", nil) + if iw.stats != nil { + start := time.Now() + defer func() { + atomic.AddInt64(&iw.stats.build, int64(time.Since(start))) + }() + } + buf, numChks := make([]byte, 1), task.outerResult.NumChunks() + task.lookupMap = newUnsafeHashTable(task.outerResult.Len()) + for chkIdx := 0; chkIdx < numChks; chkIdx++ { + chk := task.outerResult.GetChunk(chkIdx) + numRows := chk.NumRows() + if iw.lookup.Finished.Load().(bool) { + return + } + OUTER: + for rowIdx := 0; rowIdx < numRows; rowIdx++ { + if task.outerMatch != nil && !task.outerMatch[chkIdx][rowIdx] { + continue + } + row := chk.GetRow(rowIdx) + hashColIdx := iw.outerCtx.HashCols + for _, i := range hashColIdx { + if row.IsNull(i) { + continue OUTER + } + } + h.Reset() + err := codec.HashChunkRow(iw.ctx.GetSessionVars().StmtCtx.TypeCtx(), h, row, iw.outerCtx.HashTypes, hashColIdx, buf) + failpoint.Inject("testIndexHashJoinBuildErr", func() { + err = errors.New("mockIndexHashJoinBuildErr") + }) + if err != nil { + // This panic will be recovered by the invoker. 
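buildHashTableForOuterResult above hashes the outer join-key columns with a 64-bit FNV hash and skips any row whose key contains NULL, since NULL can never satisfy an equality join. A stripped-down sketch using a plain map in place of the unsafe hash table (rowPtr and the single int64 key column are illustrative assumptions):

package main

import (
	"encoding/binary"
	"fmt"
	"hash/fnv"
)

type rowPtr struct{ chkIdx, rowIdx uint32 }

func main() {
	// Outer rows: a nil entry stands for a NULL join key.
	rows := []*int64{ptr(7), nil, ptr(7), ptr(42)}

	table := make(map[uint64][]rowPtr)
	buf := make([]byte, 8)
	for i, v := range rows {
		if v == nil {
			continue // NULL never matches in an equality join, so don't index it
		}
		h := fnv.New64()
		binary.LittleEndian.PutUint64(buf, uint64(*v))
		h.Write(buf)
		table[h.Sum64()] = append(table[h.Sum64()], rowPtr{chkIdx: 0, rowIdx: uint32(i)})
	}
	fmt.Println(len(table)) // 2 distinct keys indexed: 7 and 42
}

func ptr(v int64) *int64 { return &v }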
+ panic(err.Error()) + } + rowPtr := chunk.RowPtr{ChkIdx: uint32(chkIdx), RowIdx: uint32(rowIdx)} + task.lookupMap.Put(h.Sum64(), rowPtr) + } + } +} + +func (iw *indexHashJoinInnerWorker) fetchInnerResults(ctx context.Context, task *lookUpJoinTask) error { + lookUpContents, err := iw.constructLookupContent(task) + if err != nil { + return err + } + return iw.innerWorker.fetchInnerResults(ctx, task, lookUpContents) +} + +func (iw *indexHashJoinInnerWorker) handleHashJoinInnerWorkerPanic(resultCh chan *indexHashJoinResult, err error) { + defer func() { + iw.wg.Done() + iw.lookup.WorkerWg.Done() + }() + if err != nil { + resultCh <- &indexHashJoinResult{err: err} + } +} + +func (iw *indexHashJoinInnerWorker) handleTask(ctx context.Context, task *indexHashJoinTask, joinResult *indexHashJoinResult, h hash.Hash64, resultCh chan *indexHashJoinResult) (err error) { + defer func() { + iw.memTracker.Consume(-iw.memTracker.BytesConsumed()) + if task.keepOuterOrder { + if err != nil { + joinResult.err = err + select { + case <-ctx.Done(): + case resultCh <- joinResult: + } + } + close(resultCh) + } + }() + var joinStartTime time.Time + if iw.stats != nil { + start := time.Now() + defer func() { + endTime := time.Now() + atomic.AddInt64(&iw.stats.totalTime, int64(endTime.Sub(start))) + if !joinStartTime.IsZero() { + // FetchInnerResults maybe return err and return, so joinStartTime is not initialized. + atomic.AddInt64(&iw.stats.join, int64(endTime.Sub(joinStartTime))) + } + }() + } + + iw.wg = &sync.WaitGroup{} + iw.wg.Add(1) + iw.lookup.WorkerWg.Add(1) + // TODO(XuHuaiyu): we may always use the smaller side to build the hashtable. + go util.WithRecovery( + func() { + iw.buildHashTableForOuterResult(task, h) + }, + func(r any) { + var err error + if r != nil { + err = errors.Errorf("%v", r) + } + iw.handleHashJoinInnerWorkerPanic(resultCh, err) + }, + ) + err = iw.fetchInnerResults(ctx, task.lookUpJoinTask) + iw.wg.Wait() + // check error after wg.Wait to make sure error message can be sent to + // resultCh even if panic happen in buildHashTableForOuterResult. 
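handleTask above starts the hash-table build in a separate goroutine while the same worker fetches inner rows, and only inspects errors after wg.Wait so a failure in either half is reliably observed. A minimal sketch of that start / do-other-work / wait / check shape (names are invented; the real code reports the build error through the result channel rather than a shared variable):

package main

import (
	"errors"
	"fmt"
	"sync"
)

func main() {
	var (
		wg       sync.WaitGroup
		buildErr error
	)

	wg.Add(1)
	go func() {
		defer wg.Done()
		buildErr = errors.New("hash build failed") // stand-in for buildHashTableForOuterResult
	}()

	var fetchErr error // stand-in for fetchInnerResults running on this goroutine

	// Wait first: only after the builder goroutine is done is buildErr safe to
	// read, and any builder failure is guaranteed to be visible here.
	wg.Wait()
	if fetchErr != nil {
		fmt.Println("fetch:", fetchErr)
		return
	}
	if buildErr != nil {
		fmt.Println("build:", buildErr)
	}
}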
+ failpoint.Inject("IndexHashJoinFetchInnerResultsErr", func() { + err = errors.New("IndexHashJoinFetchInnerResultsErr") + }) + failpoint.Inject("ConsumeRandomPanic", nil) + if err != nil { + return err + } + + joinStartTime = time.Now() + if !task.keepOuterOrder { + return iw.doJoinUnordered(ctx, task, joinResult, h, resultCh) + } + return iw.doJoinInOrder(ctx, task, joinResult, h, resultCh) +} + +func (iw *indexHashJoinInnerWorker) doJoinUnordered(ctx context.Context, task *indexHashJoinTask, joinResult *indexHashJoinResult, h hash.Hash64, resultCh chan *indexHashJoinResult) error { + var ok bool + iter := chunk.NewIterator4List(task.innerResult) + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + ok, joinResult = iw.joinMatchedInnerRow2Chunk(ctx, row, task, joinResult, h, iw.joinKeyBuf) + if !ok { + return joinResult.err + } + } + for chkIdx, outerRowStatus := range task.outerRowStatus { + chk := task.outerResult.GetChunk(chkIdx) + for rowIdx, val := range outerRowStatus { + if val == outerRowMatched { + continue + } + iw.joiner.OnMissMatch(val == outerRowHasNull, chk.GetRow(rowIdx), joinResult.chk) + if joinResult.chk.IsFull() { + select { + case resultCh <- joinResult: + case <-ctx.Done(): + return ctx.Err() + } + joinResult, ok = iw.getNewJoinResult(ctx) + if !ok { + return errors.New("indexHashJoinInnerWorker.doJoinUnordered failed") + } + } + } + } + return nil +} + +func (iw *indexHashJoinInnerWorker) getMatchedOuterRows(innerRow chunk.Row, task *indexHashJoinTask, h hash.Hash64, buf []byte) (matchedRows []chunk.Row, matchedRowPtr []chunk.RowPtr, err error) { + h.Reset() + err = codec.HashChunkRow(iw.ctx.GetSessionVars().StmtCtx.TypeCtx(), h, innerRow, iw.HashTypes, iw.HashCols, buf) + if err != nil { + return nil, nil, err + } + matchedOuterEntry := task.lookupMap.Get(h.Sum64()) + if matchedOuterEntry == nil { + return nil, nil, nil + } + joinType := JoinerType(iw.joiner) + isSemiJoin := joinType.IsSemiJoin() + for ; matchedOuterEntry != nil; matchedOuterEntry = matchedOuterEntry.Next { + ptr := matchedOuterEntry.Ptr + outerRow := task.outerResult.GetRow(ptr) + ok, err := codec.EqualChunkRow(iw.ctx.GetSessionVars().StmtCtx.TypeCtx(), innerRow, iw.HashTypes, iw.HashCols, outerRow, iw.outerCtx.HashTypes, iw.outerCtx.HashCols) + if err != nil { + return nil, nil, err + } + if !ok || (task.outerRowStatus[ptr.ChkIdx][ptr.RowIdx] == outerRowMatched && isSemiJoin) { + continue + } + matchedRows = append(matchedRows, outerRow) + matchedRowPtr = append(matchedRowPtr, chunk.RowPtr{ChkIdx: ptr.ChkIdx, RowIdx: ptr.RowIdx}) + } + return matchedRows, matchedRowPtr, nil +} + +func (iw *indexHashJoinInnerWorker) joinMatchedInnerRow2Chunk(ctx context.Context, innerRow chunk.Row, task *indexHashJoinTask, + joinResult *indexHashJoinResult, h hash.Hash64, buf []byte) (bool, *indexHashJoinResult) { + matchedOuterRows, matchedOuterRowPtr, err := iw.getMatchedOuterRows(innerRow, task, h, buf) + if err != nil { + joinResult.err = err + return false, joinResult + } + if len(matchedOuterRows) == 0 { + return true, joinResult + } + var ok bool + cursor := 0 + iw.rowIter.Reset(matchedOuterRows) + iter := iw.rowIter + for iw.rowIter.Begin(); iter.Current() != iter.End(); { + iw.outerRowStatus, err = iw.joiner.TryToMatchOuters(iter, innerRow, joinResult.chk, iw.outerRowStatus) + if err != nil { + joinResult.err = err + return false, joinResult + } + for _, status := range iw.outerRowStatus { + chkIdx, rowIdx := matchedOuterRowPtr[cursor].ChkIdx, matchedOuterRowPtr[cursor].RowIdx + if status == 
outerRowMatched || task.outerRowStatus[chkIdx][rowIdx] == outerRowUnmatched { + task.outerRowStatus[chkIdx][rowIdx] = status + } + cursor++ + } + if joinResult.chk.IsFull() { + select { + case iw.resultCh <- joinResult: + case <-ctx.Done(): + joinResult.err = ctx.Err() + return false, joinResult + } + failpoint.InjectCall("joinMatchedInnerRow2Chunk") + joinResult, ok = iw.getNewJoinResult(ctx) + if !ok { + return false, joinResult + } + } + } + return true, joinResult +} + +func (iw *indexHashJoinInnerWorker) collectMatchedInnerPtrs4OuterRows(innerRow chunk.Row, innerRowPtr chunk.RowPtr, + task *indexHashJoinTask, h hash.Hash64, buf []byte) error { + _, matchedOuterRowIdx, err := iw.getMatchedOuterRows(innerRow, task, h, buf) + if err != nil { + return err + } + for _, outerRowPtr := range matchedOuterRowIdx { + chkIdx, rowIdx := outerRowPtr.ChkIdx, outerRowPtr.RowIdx + task.matchedInnerRowPtrs[chkIdx][rowIdx] = append(task.matchedInnerRowPtrs[chkIdx][rowIdx], innerRowPtr) + } + return nil +} + +// doJoinInOrder follows the following steps: +// 1. collect all the matched inner row ptrs for every outer row +// 2. do the join work +// 2.1 collect all the matched inner rows using the collected ptrs for every outer row +// 2.2 call TryToMatchInners for every outer row +// 2.3 call OnMissMatch when no inner rows are matched +func (iw *indexHashJoinInnerWorker) doJoinInOrder(ctx context.Context, task *indexHashJoinTask, joinResult *indexHashJoinResult, h hash.Hash64, resultCh chan *indexHashJoinResult) (err error) { + defer func() { + if err == nil && joinResult.chk != nil { + if joinResult.chk.NumRows() > 0 { + select { + case resultCh <- joinResult: + case <-ctx.Done(): + return + } + } else { + joinResult.src <- joinResult.chk + } + } + }() + for i, numChunks := 0, task.innerResult.NumChunks(); i < numChunks; i++ { + for j, chk := 0, task.innerResult.GetChunk(i); j < chk.NumRows(); j++ { + row := chk.GetRow(j) + ptr := chunk.RowPtr{ChkIdx: uint32(i), RowIdx: uint32(j)} + err = iw.collectMatchedInnerPtrs4OuterRows(row, ptr, task, h, iw.joinKeyBuf) + failpoint.Inject("TestIssue31129", func() { + err = errors.New("TestIssue31129") + }) + if err != nil { + return err + } + } + } + // TODO: matchedInnerRowPtrs and matchedInnerRows can be moved to inner worker. 
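doJoinInOrder, whose body continues below, first records for every outer row the pointers of all inner rows that matched it, and only then walks the outer rows in their original order to emit output — that is what preserves outer order despite per-inner-row probing. A compact illustration with plain ints standing in for rows and row pointers:

package main

import "fmt"

func main() {
	outer := []int{3, 1, 3, 5}
	inner := []int{1, 3, 3}

	// Step 1: for every outer position, collect indexes of matching inner rows.
	matched := make([][]int, len(outer))
	for ii, iv := range inner {
		for oi, ov := range outer {
			if iv == ov {
				matched[oi] = append(matched[oi], ii)
			}
		}
	}

	// Step 2: emit results strictly in outer order, so the join preserves the
	// order of the outer table even though probing happened per inner row.
	for oi, ptrs := range matched {
		if len(ptrs) == 0 {
			fmt.Printf("outer %d: no match\n", outer[oi])
			continue
		}
		for _, ii := range ptrs {
			fmt.Printf("outer %d joins inner[%d]=%d\n", outer[oi], ii, inner[ii])
		}
	}
}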
+ matchedInnerRows := make([]chunk.Row, 0, len(task.matchedInnerRowPtrs)) + var hasMatched, hasNull, ok bool + for chkIdx, innerRowPtrs4Chk := range task.matchedInnerRowPtrs { + for outerRowIdx, innerRowPtrs := range innerRowPtrs4Chk { + matchedInnerRows, hasMatched, hasNull = matchedInnerRows[:0], false, false + outerRow := task.outerResult.GetChunk(chkIdx).GetRow(outerRowIdx) + for _, ptr := range innerRowPtrs { + matchedInnerRows = append(matchedInnerRows, task.innerResult.GetRow(ptr)) + } + iw.rowIter.Reset(matchedInnerRows) + iter := iw.rowIter + for iter.Begin(); iter.Current() != iter.End(); { + matched, isNull, err := iw.joiner.TryToMatchInners(outerRow, iter, joinResult.chk) + if err != nil { + return err + } + hasMatched, hasNull = matched || hasMatched, isNull || hasNull + if joinResult.chk.IsFull() { + select { + case resultCh <- joinResult: + case <-ctx.Done(): + return ctx.Err() + } + joinResult, ok = iw.getNewJoinResult(ctx) + if !ok { + return errors.New("indexHashJoinInnerWorker.doJoinInOrder failed") + } + } + } + if !hasMatched { + iw.joiner.OnMissMatch(hasNull, outerRow, joinResult.chk) + } + } + } + return nil +} diff --git a/pkg/executor/join/index_lookup_join.go b/pkg/executor/join/index_lookup_join.go index e289ae7ab44b3..06b8fadfabfeb 100644 --- a/pkg/executor/join/index_lookup_join.go +++ b/pkg/executor/join/index_lookup_join.go @@ -243,11 +243,11 @@ func (e *IndexLookUpJoin) newInnerWorker(taskCh chan *lookUpJoinTask) *innerWork lookup: e, memTracker: memory.NewTracker(memory.LabelForIndexJoinInnerWorker, -1), } - failpoint.Inject("inlNewInnerPanic", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("inlNewInnerPanic")); _err_ == nil { if val.(bool) { panic("test inlNewInnerPanic") } - }) + } iw.memTracker.AttachTo(e.memTracker) if len(copiedRanges) != 0 { // We should not consume this memory usage in `iw.memTracker`. 
The @@ -389,8 +389,8 @@ func (ow *outerWorker) run(ctx context.Context, wg *sync.WaitGroup) { wg.Done() }() for { - failpoint.Inject("TestIssue30211", nil) - failpoint.Inject("ConsumeRandomPanic", nil) + failpoint.Eval(_curpkg_("TestIssue30211")) + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) task, err := ow.buildTask(ctx) if err != nil { task.doneCh <- err @@ -436,7 +436,7 @@ func (ow *outerWorker) buildTask(ctx context.Context) (*lookUpJoinTask, error) { task.memTracker = memory.NewTracker(-1, -1) task.outerResult.GetMemTracker().AttachTo(task.memTracker) task.memTracker.AttachTo(ow.parentMemTracker) - failpoint.Inject("ConsumeRandomPanic", nil) + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) ow.increaseBatchSize() requiredRows := ow.batchSize @@ -586,7 +586,7 @@ func (iw *innerWorker) constructLookupContent(task *lookUpJoinTask) ([]*IndexJoi } return nil, err } - failpoint.Inject("ConsumeRandomPanic", nil) + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) if rowIdx == 0 { iw.memTracker.Consume(types.EstimatedMemUsage(dLookUpKey, numRows)) } @@ -733,7 +733,7 @@ func (iw *innerWorker) fetchInnerResults(ctx context.Context, task *lookUpJoinTa default: } err := exec.Next(ctx, innerExec, iw.executorChk) - failpoint.Inject("ConsumeRandomPanic", nil) + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) if err != nil { return err } diff --git a/pkg/executor/join/index_lookup_join.go__failpoint_stash__ b/pkg/executor/join/index_lookup_join.go__failpoint_stash__ new file mode 100644 index 0000000000000..e289ae7ab44b3 --- /dev/null +++ b/pkg/executor/join/index_lookup_join.go__failpoint_stash__ @@ -0,0 +1,882 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package join + +import ( + "bytes" + "context" + "runtime/trace" + "slices" + "strconv" + "sync" + "sync/atomic" + "time" + "unsafe" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/collate" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/mvmap" + "github.com/pingcap/tidb/pkg/util/ranger" + "go.uber.org/zap" +) + +var _ exec.Executor = &IndexLookUpJoin{} + +// IndexLookUpJoin employs one outer worker and N innerWorkers to execute concurrently. +// It preserves the order of the outer table and support batch lookup. +// +// The execution flow is very similar to IndexLookUpReader: +// 1. 
outerWorker read N outer rows, build a task and send it to result channel and inner worker channel. +// 2. The innerWorker receives the task, builds key ranges from outer rows and fetch inner rows, builds inner row hash map. +// 3. main thread receives the task, waits for inner worker finish handling the task. +// 4. main thread join each outer row by look up the inner rows hash map in the task. +type IndexLookUpJoin struct { + exec.BaseExecutor + + resultCh <-chan *lookUpJoinTask + cancelFunc context.CancelFunc + WorkerWg *sync.WaitGroup + + OuterCtx OuterCtx + InnerCtx InnerCtx + + task *lookUpJoinTask + JoinResult *chunk.Chunk + innerIter *chunk.Iterator4Slice + + Joiner Joiner + IsOuterJoin bool + + requiredRows int64 + + IndexRanges ranger.MutableRanges + KeyOff2IdxOff []int + innerPtrBytes [][]byte + + // LastColHelper store the information for last col if there's complicated filter like col > x_col and col < x_col + 100. + LastColHelper *plannercore.ColWithCmpFuncManager + + memTracker *memory.Tracker // track memory usage. + + stats *indexLookUpJoinRuntimeStats + Finished *atomic.Value + prepared bool +} + +// OuterCtx is the outer ctx used in index lookup join +type OuterCtx struct { + RowTypes []*types.FieldType + KeyCols []int + HashTypes []*types.FieldType + HashCols []int + Filter expression.CNFExprs +} + +// IndexJoinExecutorBuilder is the interface used by index lookup join to build the executor, this interface +// is added to avoid cycle import +type IndexJoinExecutorBuilder interface { + BuildExecutorForIndexJoin(ctx context.Context, lookUpContents []*IndexJoinLookUpContent, + indexRanges []*ranger.Range, keyOff2IdxOff []int, cwc *plannercore.ColWithCmpFuncManager, canReorderHandles bool, memTracker *memory.Tracker, interruptSignal *atomic.Value) (exec.Executor, error) +} + +// InnerCtx is the inner side ctx used in index lookup join +type InnerCtx struct { + ReaderBuilder IndexJoinExecutorBuilder + RowTypes []*types.FieldType + KeyCols []int + KeyColIDs []int64 // the original ID in its table, used by dynamic partition pruning + KeyCollators []collate.Collator + HashTypes []*types.FieldType + HashCols []int + HashCollators []collate.Collator + ColLens []int + HasPrefixCol bool +} + +type lookUpJoinTask struct { + outerResult *chunk.List + outerMatch [][]bool + + innerResult *chunk.List + encodedLookUpKeys []*chunk.Chunk + lookupMap *mvmap.MVMap + matchedInners []chunk.Row + + doneCh chan error + cursor chunk.RowPtr + hasMatch bool + hasNull bool + + memTracker *memory.Tracker // track memory usage. +} + +type outerWorker struct { + OuterCtx + + lookup *IndexLookUpJoin + + ctx sessionctx.Context + executor exec.Executor + + maxBatchSize int + batchSize int + + resultCh chan<- *lookUpJoinTask + innerCh chan<- *lookUpJoinTask + + parentMemTracker *memory.Tracker +} + +type innerWorker struct { + InnerCtx + + taskCh <-chan *lookUpJoinTask + outerCtx OuterCtx + ctx sessionctx.Context + executorChk *chunk.Chunk + lookup *IndexLookUpJoin + + indexRanges []*ranger.Range + nextColCompareFilters *plannercore.ColWithCmpFuncManager + keyOff2IdxOff []int + stats *innerWorkerRuntimeStats + memTracker *memory.Tracker +} + +// Open implements the Executor interface. 
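The flow comment at the top of this file keeps outer order differently from the hash variant: every task is pushed both to the inner-worker channel (for processing) and to the result channel (for ordering), and the main thread drains the result channel FIFO, waiting on each task's doneCh before using it. A toy version of that double dispatch (joinTask and the channel sizes are illustrative):

package main

import "fmt"

type joinTask struct {
	id     int
	doneCh chan error
}

func main() {
	innerCh := make(chan *joinTask, 4)  // processed by inner workers, any order
	resultCh := make(chan *joinTask, 4) // consumed by the main thread, FIFO

	// Outer worker: every task goes to both channels.
	for i := 0; i < 3; i++ {
		t := &joinTask{id: i, doneCh: make(chan error, 1)}
		innerCh <- t
		resultCh <- t
	}
	close(innerCh)
	close(resultCh)

	// Inner worker: may finish tasks in any order; completion is signalled per task.
	go func() {
		for t := range innerCh {
			t.doneCh <- nil
		}
	}()

	// Main thread: resultCh preserves build order, and doneCh makes it wait
	// until that particular task has actually been processed.
	for t := range resultCh {
		if err := <-t.doneCh; err == nil {
			fmt.Println("task", t.id, "ready in order")
		}
	}
}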
+func (e *IndexLookUpJoin) Open(ctx context.Context) error { + err := exec.Open(ctx, e.Children(0)) + if err != nil { + return err + } + e.memTracker = memory.NewTracker(e.ID(), -1) + e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) + e.innerPtrBytes = make([][]byte, 0, 8) + e.Finished.Store(false) + if e.RuntimeStats() != nil { + e.stats = &indexLookUpJoinRuntimeStats{} + } + e.cancelFunc = nil + return nil +} + +func (e *IndexLookUpJoin) startWorkers(ctx context.Context) { + concurrency := e.Ctx().GetSessionVars().IndexLookupJoinConcurrency() + if e.stats != nil { + e.stats.concurrency = concurrency + } + resultCh := make(chan *lookUpJoinTask, concurrency) + e.resultCh = resultCh + workerCtx, cancelFunc := context.WithCancel(ctx) + e.cancelFunc = cancelFunc + innerCh := make(chan *lookUpJoinTask, concurrency) + e.WorkerWg.Add(1) + go e.newOuterWorker(resultCh, innerCh).run(workerCtx, e.WorkerWg) + for i := 0; i < concurrency; i++ { + innerWorker := e.newInnerWorker(innerCh) + e.WorkerWg.Add(1) + go innerWorker.run(workerCtx, e.WorkerWg) + } +} + +func (e *IndexLookUpJoin) newOuterWorker(resultCh, innerCh chan *lookUpJoinTask) *outerWorker { + ow := &outerWorker{ + OuterCtx: e.OuterCtx, + ctx: e.Ctx(), + executor: e.Children(0), + resultCh: resultCh, + innerCh: innerCh, + batchSize: 32, + maxBatchSize: e.Ctx().GetSessionVars().IndexJoinBatchSize, + parentMemTracker: e.memTracker, + lookup: e, + } + return ow +} + +func (e *IndexLookUpJoin) newInnerWorker(taskCh chan *lookUpJoinTask) *innerWorker { + // Since multiple inner workers run concurrently, we should copy join's IndexRanges for every worker to avoid data race. + copiedRanges := make([]*ranger.Range, 0, len(e.IndexRanges.Range())) + for _, ran := range e.IndexRanges.Range() { + copiedRanges = append(copiedRanges, ran.Clone()) + } + + var innerStats *innerWorkerRuntimeStats + if e.stats != nil { + innerStats = &e.stats.innerWorker + } + iw := &innerWorker{ + InnerCtx: e.InnerCtx, + outerCtx: e.OuterCtx, + taskCh: taskCh, + ctx: e.Ctx(), + executorChk: e.AllocPool.Alloc(e.InnerCtx.RowTypes, e.MaxChunkSize(), e.MaxChunkSize()), + indexRanges: copiedRanges, + keyOff2IdxOff: e.KeyOff2IdxOff, + stats: innerStats, + lookup: e, + memTracker: memory.NewTracker(memory.LabelForIndexJoinInnerWorker, -1), + } + failpoint.Inject("inlNewInnerPanic", func(val failpoint.Value) { + if val.(bool) { + panic("test inlNewInnerPanic") + } + }) + iw.memTracker.AttachTo(e.memTracker) + if len(copiedRanges) != 0 { + // We should not consume this memory usage in `iw.memTracker`. The + // memory usage of inner worker will be reset the end of iw.handleTask. + // While the life cycle of this memory consumption exists throughout the + // whole active period of inner worker. + e.Ctx().GetSessionVars().StmtCtx.MemTracker.Consume(2 * types.EstimatedMemUsage(copiedRanges[0].LowVal, len(copiedRanges))) + } + if e.LastColHelper != nil { + // nextCwf.TmpConstant needs to be reset for every individual + // inner worker to avoid data race when the inner workers is running + // concurrently. + nextCwf := *e.LastColHelper + nextCwf.TmpConstant = make([]*expression.Constant, len(e.LastColHelper.TmpConstant)) + for i := range e.LastColHelper.TmpConstant { + nextCwf.TmpConstant[i] = &expression.Constant{RetType: nextCwf.TargetCol.RetType} + } + iw.nextColCompareFilters = &nextCwf + } + return iw +} + +// Next implements the Executor interface. 
+func (e *IndexLookUpJoin) Next(ctx context.Context, req *chunk.Chunk) error { + if !e.prepared { + e.startWorkers(ctx) + e.prepared = true + } + if e.IsOuterJoin { + atomic.StoreInt64(&e.requiredRows, int64(req.RequiredRows())) + } + req.Reset() + e.JoinResult.Reset() + for { + task, err := e.getFinishedTask(ctx) + if err != nil { + return err + } + if task == nil { + return nil + } + startTime := time.Now() + if e.innerIter == nil || e.innerIter.Current() == e.innerIter.End() { + e.lookUpMatchedInners(task, task.cursor) + if e.innerIter == nil { + e.innerIter = chunk.NewIterator4Slice(task.matchedInners) + } + e.innerIter.Reset(task.matchedInners) + e.innerIter.Begin() + } + + outerRow := task.outerResult.GetRow(task.cursor) + if e.innerIter.Current() != e.innerIter.End() { + matched, isNull, err := e.Joiner.TryToMatchInners(outerRow, e.innerIter, req) + if err != nil { + return err + } + task.hasMatch = task.hasMatch || matched + task.hasNull = task.hasNull || isNull + } + if e.innerIter.Current() == e.innerIter.End() { + if !task.hasMatch { + e.Joiner.OnMissMatch(task.hasNull, outerRow, req) + } + task.cursor.RowIdx++ + if int(task.cursor.RowIdx) == task.outerResult.GetChunk(int(task.cursor.ChkIdx)).NumRows() { + task.cursor.ChkIdx++ + task.cursor.RowIdx = 0 + } + task.hasMatch = false + task.hasNull = false + } + if e.stats != nil { + atomic.AddInt64(&e.stats.probe, int64(time.Since(startTime))) + } + if req.IsFull() { + return nil + } + } +} + +func (e *IndexLookUpJoin) getFinishedTask(ctx context.Context) (*lookUpJoinTask, error) { + task := e.task + if task != nil && int(task.cursor.ChkIdx) < task.outerResult.NumChunks() { + return task, nil + } + + // The previous task has been processed, so release the occupied memory + if task != nil { + task.memTracker.Detach() + } + select { + case task = <-e.resultCh: + case <-ctx.Done(): + return nil, ctx.Err() + } + if task == nil { + return nil, nil + } + + select { + case err := <-task.doneCh: + if err != nil { + return nil, err + } + case <-ctx.Done(): + return nil, ctx.Err() + } + + e.task = task + return task, nil +} + +func (e *IndexLookUpJoin) lookUpMatchedInners(task *lookUpJoinTask, rowPtr chunk.RowPtr) { + outerKey := task.encodedLookUpKeys[rowPtr.ChkIdx].GetRow(int(rowPtr.RowIdx)).GetBytes(0) + e.innerPtrBytes = task.lookupMap.Get(outerKey, e.innerPtrBytes[:0]) + task.matchedInners = task.matchedInners[:0] + + for _, b := range e.innerPtrBytes { + ptr := *(*chunk.RowPtr)(unsafe.Pointer(&b[0])) + matchedInner := task.innerResult.GetRow(ptr) + task.matchedInners = append(task.matchedInners, matchedInner) + } +} + +func (ow *outerWorker) run(ctx context.Context, wg *sync.WaitGroup) { + defer trace.StartRegion(ctx, "IndexLookupJoinOuterWorker").End() + defer func() { + if r := recover(); r != nil { + ow.lookup.Finished.Store(true) + logutil.Logger(ctx).Error("outerWorker panicked", zap.Any("recover", r), zap.Stack("stack")) + task := &lookUpJoinTask{doneCh: make(chan error, 1)} + err := util.GetRecoverError(r) + task.doneCh <- err + ow.pushToChan(ctx, task, ow.resultCh) + } + close(ow.resultCh) + close(ow.innerCh) + wg.Done() + }() + for { + failpoint.Inject("TestIssue30211", nil) + failpoint.Inject("ConsumeRandomPanic", nil) + task, err := ow.buildTask(ctx) + if err != nil { + task.doneCh <- err + ow.pushToChan(ctx, task, ow.resultCh) + return + } + if task == nil { + return + } + + if finished := ow.pushToChan(ctx, task, ow.innerCh); finished { + return + } + + if finished := ow.pushToChan(ctx, task, ow.resultCh); finished { + 
return + } + } +} + +func (*outerWorker) pushToChan(ctx context.Context, task *lookUpJoinTask, dst chan<- *lookUpJoinTask) bool { + select { + case <-ctx.Done(): + return true + case dst <- task: + } + return false +} + +// newList creates a new List to buffer current executor's result. +func newList(e exec.Executor) *chunk.List { + return chunk.NewList(e.RetFieldTypes(), e.InitCap(), e.MaxChunkSize()) +} + +// buildTask builds a lookUpJoinTask and read Outer rows. +// When err is not nil, task must not be nil to send the error to the main thread via task. +func (ow *outerWorker) buildTask(ctx context.Context) (*lookUpJoinTask, error) { + task := &lookUpJoinTask{ + doneCh: make(chan error, 1), + outerResult: newList(ow.executor), + lookupMap: mvmap.NewMVMap(), + } + task.memTracker = memory.NewTracker(-1, -1) + task.outerResult.GetMemTracker().AttachTo(task.memTracker) + task.memTracker.AttachTo(ow.parentMemTracker) + failpoint.Inject("ConsumeRandomPanic", nil) + + ow.increaseBatchSize() + requiredRows := ow.batchSize + if ow.lookup.IsOuterJoin { + // If it is outerJoin, push the requiredRows down. + // Note: buildTask is triggered when `Open` is called, but + // ow.lookup.requiredRows is set when `Next` is called. Thus we check + // whether it's 0 here. + if parentRequired := int(atomic.LoadInt64(&ow.lookup.requiredRows)); parentRequired != 0 { + requiredRows = parentRequired + } + } + maxChunkSize := ow.ctx.GetSessionVars().MaxChunkSize + for requiredRows > task.outerResult.Len() { + chk := ow.executor.NewChunkWithCapacity(ow.OuterCtx.RowTypes, maxChunkSize, maxChunkSize) + chk = chk.SetRequiredRows(requiredRows, maxChunkSize) + err := exec.Next(ctx, ow.executor, chk) + if err != nil { + return task, err + } + if chk.NumRows() == 0 { + break + } + + task.outerResult.Add(chk) + } + if task.outerResult.Len() == 0 { + return nil, nil + } + numChks := task.outerResult.NumChunks() + if ow.Filter != nil { + task.outerMatch = make([][]bool, task.outerResult.NumChunks()) + var err error + exprCtx := ow.ctx.GetExprCtx() + for i := 0; i < numChks; i++ { + chk := task.outerResult.GetChunk(i) + outerMatch := make([]bool, 0, chk.NumRows()) + task.memTracker.Consume(int64(cap(outerMatch))) + task.outerMatch[i], err = expression.VectorizedFilter(exprCtx.GetEvalCtx(), ow.ctx.GetSessionVars().EnableVectorizedExpression, ow.Filter, chunk.NewIterator4Chunk(chk), outerMatch) + if err != nil { + return task, err + } + } + } + task.encodedLookUpKeys = make([]*chunk.Chunk, task.outerResult.NumChunks()) + for i := range task.encodedLookUpKeys { + task.encodedLookUpKeys[i] = ow.executor.NewChunkWithCapacity( + []*types.FieldType{types.NewFieldType(mysql.TypeBlob)}, + task.outerResult.GetChunk(i).NumRows(), + task.outerResult.GetChunk(i).NumRows(), + ) + } + return task, nil +} + +func (ow *outerWorker) increaseBatchSize() { + if ow.batchSize < ow.maxBatchSize { + ow.batchSize *= 2 + } + if ow.batchSize > ow.maxBatchSize { + ow.batchSize = ow.maxBatchSize + } +} + +func (iw *innerWorker) run(ctx context.Context, wg *sync.WaitGroup) { + defer trace.StartRegion(ctx, "IndexLookupJoinInnerWorker").End() + var task *lookUpJoinTask + defer func() { + if r := recover(); r != nil { + iw.lookup.Finished.Store(true) + logutil.Logger(ctx).Error("innerWorker panicked", zap.Any("recover", r), zap.Stack("stack")) + err := util.GetRecoverError(r) + // "task != nil" is guaranteed when panic happened. 
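buildTask above starts from a 32-row batch and increaseBatchSize doubles it up to the session's IndexJoinBatchSize, so small lookups stay cheap while long scans quickly reach the full batch. A tiny sketch of that ramp (the 25000 cap is an assumed example value, not taken from this patch):

package main

import "fmt"

func main() {
	batchSize, maxBatchSize := 32, 25000

	increase := func() {
		if batchSize < maxBatchSize {
			batchSize *= 2
		}
		if batchSize > maxBatchSize {
			batchSize = maxBatchSize // never overshoot the configured cap
		}
	}

	for i := 0; i < 12; i++ {
		increase()
	}
	fmt.Println(batchSize) // clamped at maxBatchSize after a few doublings
}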
+ task.doneCh <- err + } + wg.Done() + }() + + for ok := true; ok; { + select { + case task, ok = <-iw.taskCh: + if !ok { + return + } + case <-ctx.Done(): + return + } + + err := iw.handleTask(ctx, task) + task.doneCh <- err + } +} + +// IndexJoinLookUpContent is the content used in index lookup join +type IndexJoinLookUpContent struct { + Keys []types.Datum + Row chunk.Row + keyCols []int + KeyColIDs []int64 // the original ID in its table, used by dynamic partition pruning +} + +func (iw *innerWorker) handleTask(ctx context.Context, task *lookUpJoinTask) error { + if iw.stats != nil { + start := time.Now() + defer func() { + atomic.AddInt64(&iw.stats.totalTime, int64(time.Since(start))) + }() + } + defer func() { + iw.memTracker.Consume(-iw.memTracker.BytesConsumed()) + }() + lookUpContents, err := iw.constructLookupContent(task) + if err != nil { + return err + } + err = iw.fetchInnerResults(ctx, task, lookUpContents) + if err != nil { + return err + } + err = iw.buildLookUpMap(task) + if err != nil { + return err + } + return nil +} + +func (iw *innerWorker) constructLookupContent(task *lookUpJoinTask) ([]*IndexJoinLookUpContent, error) { + if iw.stats != nil { + start := time.Now() + defer func() { + atomic.AddInt64(&iw.stats.task, 1) + atomic.AddInt64(&iw.stats.construct, int64(time.Since(start))) + }() + } + lookUpContents := make([]*IndexJoinLookUpContent, 0, task.outerResult.Len()) + keyBuf := make([]byte, 0, 64) + for chkIdx := 0; chkIdx < task.outerResult.NumChunks(); chkIdx++ { + chk := task.outerResult.GetChunk(chkIdx) + numRows := chk.NumRows() + for rowIdx := 0; rowIdx < numRows; rowIdx++ { + dLookUpKey, dHashKey, err := iw.constructDatumLookupKey(task, chkIdx, rowIdx) + if err != nil { + if terror.ErrorEqual(err, types.ErrWrongValue) { + // We ignore rows with invalid datetime. + task.encodedLookUpKeys[chkIdx].AppendNull(0) + continue + } + return nil, err + } + failpoint.Inject("ConsumeRandomPanic", nil) + if rowIdx == 0 { + iw.memTracker.Consume(types.EstimatedMemUsage(dLookUpKey, numRows)) + } + if dHashKey == nil { + // Append null to make lookUpKeys the same length as Outer Result. + task.encodedLookUpKeys[chkIdx].AppendNull(0) + continue + } + keyBuf = keyBuf[:0] + keyBuf, err = codec.EncodeKey(iw.ctx.GetSessionVars().StmtCtx.TimeZone(), keyBuf, dHashKey...) + err = iw.ctx.GetSessionVars().StmtCtx.HandleError(err) + if err != nil { + if terror.ErrorEqual(err, types.ErrWrongValue) { + // we ignore rows with invalid datetime + task.encodedLookUpKeys[chkIdx].AppendNull(0) + continue + } + return nil, err + } + // Store the encoded lookup key in chunk, so we can use it to lookup the matched inners directly. + task.encodedLookUpKeys[chkIdx].AppendBytes(0, keyBuf) + if iw.HasPrefixCol { + for i, outerOffset := range iw.keyOff2IdxOff { + // If it's a prefix column. Try to fix it. + joinKeyColPrefixLen := iw.ColLens[outerOffset] + if joinKeyColPrefixLen != types.UnspecifiedLength { + ranger.CutDatumByPrefixLen(&dLookUpKey[i], joinKeyColPrefixLen, iw.RowTypes[iw.KeyCols[i]]) + } + } + // dLookUpKey is sorted and deduplicated at sortAndDedupLookUpContents. + // So we don't need to do it here. 
+ } + lookUpContents = append(lookUpContents, &IndexJoinLookUpContent{Keys: dLookUpKey, Row: chk.GetRow(rowIdx), keyCols: iw.KeyCols, KeyColIDs: iw.KeyColIDs}) + } + } + + for i := range task.encodedLookUpKeys { + task.memTracker.Consume(task.encodedLookUpKeys[i].MemoryUsage()) + } + lookUpContents = iw.sortAndDedupLookUpContents(lookUpContents) + return lookUpContents, nil +} + +func (iw *innerWorker) constructDatumLookupKey(task *lookUpJoinTask, chkIdx, rowIdx int) ([]types.Datum, []types.Datum, error) { + if task.outerMatch != nil && !task.outerMatch[chkIdx][rowIdx] { + return nil, nil, nil + } + outerRow := task.outerResult.GetChunk(chkIdx).GetRow(rowIdx) + sc := iw.ctx.GetSessionVars().StmtCtx + keyLen := len(iw.KeyCols) + dLookupKey := make([]types.Datum, 0, keyLen) + dHashKey := make([]types.Datum, 0, len(iw.HashCols)) + for i, hashCol := range iw.outerCtx.HashCols { + outerValue := outerRow.GetDatum(hashCol, iw.outerCtx.RowTypes[hashCol]) + // Join-on-condition can be promised to be equal-condition in + // IndexNestedLoopJoin, thus the Filter will always be false if + // outerValue is null, and we don't need to lookup it. + if outerValue.IsNull() { + return nil, nil, nil + } + innerColType := iw.RowTypes[iw.HashCols[i]] + innerValue, err := outerValue.ConvertTo(sc.TypeCtx(), innerColType) + if err != nil && !(terror.ErrorEqual(err, types.ErrTruncated) && (innerColType.GetType() == mysql.TypeSet || innerColType.GetType() == mysql.TypeEnum)) { + // If the converted outerValue overflows or invalid to innerValue, we don't need to lookup it. + if terror.ErrorEqual(err, types.ErrOverflow) || terror.ErrorEqual(err, types.ErrWarnDataOutOfRange) { + return nil, nil, nil + } + return nil, nil, err + } + cmp, err := outerValue.Compare(sc.TypeCtx(), &innerValue, iw.HashCollators[i]) + if err != nil { + return nil, nil, err + } + if cmp != 0 { + // If the converted outerValue is not equal to the origin outerValue, we don't need to lookup it. + return nil, nil, nil + } + if i < keyLen { + dLookupKey = append(dLookupKey, innerValue) + } + dHashKey = append(dHashKey, innerValue) + } + return dLookupKey, dHashKey, nil +} + +func (iw *innerWorker) sortAndDedupLookUpContents(lookUpContents []*IndexJoinLookUpContent) []*IndexJoinLookUpContent { + if len(lookUpContents) < 2 { + return lookUpContents + } + sc := iw.ctx.GetSessionVars().StmtCtx + slices.SortFunc(lookUpContents, func(i, j *IndexJoinLookUpContent) int { + cmp := compareRow(sc, i.Keys, j.Keys, iw.KeyCollators) + if cmp != 0 || iw.nextColCompareFilters == nil { + return cmp + } + return iw.nextColCompareFilters.CompareRow(i.Row, j.Row) + }) + deDupedLookupKeys := lookUpContents[:1] + for i := 1; i < len(lookUpContents); i++ { + cmp := compareRow(sc, lookUpContents[i].Keys, lookUpContents[i-1].Keys, iw.KeyCollators) + if cmp != 0 || (iw.nextColCompareFilters != nil && iw.nextColCompareFilters.CompareRow(lookUpContents[i].Row, lookUpContents[i-1].Row) != 0) { + deDupedLookupKeys = append(deDupedLookupKeys, lookUpContents[i]) + } + } + return deDupedLookupKeys +} + +func compareRow(sc *stmtctx.StatementContext, left, right []types.Datum, ctors []collate.Collator) int { + for idx := 0; idx < len(left); idx++ { + cmp, err := left[idx].Compare(sc.TypeCtx(), &right[idx], ctors[idx]) + // We only compare rows with the same type, no error to return. 
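sortAndDedupLookUpContents above sorts the lookup keys and drops adjacent duplicates so each distinct key probes the inner side only once. The same idea over plain ints (the in-place `keys[:1]` compaction mirrors how deDupedLookupKeys reuses the sorted slice):

package main

import (
	"cmp"
	"fmt"
	"slices"
)

func main() {
	keys := []int{7, 3, 7, 1, 3}

	slices.SortFunc(keys, func(a, b int) int { return cmp.Compare(a, b) })

	// Keep the first element, then only elements that differ from their
	// predecessor; duplicates are adjacent because the slice is sorted.
	deduped := keys[:1]
	for i := 1; i < len(keys); i++ {
		if keys[i] != keys[i-1] {
			deduped = append(deduped, keys[i])
		}
	}
	fmt.Println(deduped) // [1 3 7]
}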
+ terror.Log(err) + if cmp > 0 { + return 1 + } else if cmp < 0 { + return -1 + } + } + return 0 +} + +func (iw *innerWorker) fetchInnerResults(ctx context.Context, task *lookUpJoinTask, lookUpContent []*IndexJoinLookUpContent) error { + if iw.stats != nil { + start := time.Now() + defer func() { + atomic.AddInt64(&iw.stats.fetch, int64(time.Since(start))) + }() + } + innerExec, err := iw.ReaderBuilder.BuildExecutorForIndexJoin(ctx, lookUpContent, iw.indexRanges, iw.keyOff2IdxOff, iw.nextColCompareFilters, true, iw.memTracker, iw.lookup.Finished) + if innerExec != nil { + defer func() { terror.Log(exec.Close(innerExec)) }() + } + if err != nil { + return err + } + + innerResult := chunk.NewList(exec.RetTypes(innerExec), iw.ctx.GetSessionVars().MaxChunkSize, iw.ctx.GetSessionVars().MaxChunkSize) + innerResult.GetMemTracker().SetLabel(memory.LabelForBuildSideResult) + innerResult.GetMemTracker().AttachTo(task.memTracker) + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + err := exec.Next(ctx, innerExec, iw.executorChk) + failpoint.Inject("ConsumeRandomPanic", nil) + if err != nil { + return err + } + if iw.executorChk.NumRows() == 0 { + break + } + innerResult.Add(iw.executorChk) + iw.executorChk = exec.TryNewCacheChunk(innerExec) + } + task.innerResult = innerResult + return nil +} + +func (iw *innerWorker) buildLookUpMap(task *lookUpJoinTask) error { + if iw.stats != nil { + start := time.Now() + defer func() { + atomic.AddInt64(&iw.stats.build, int64(time.Since(start))) + }() + } + keyBuf := make([]byte, 0, 64) + valBuf := make([]byte, 8) + for i := 0; i < task.innerResult.NumChunks(); i++ { + chk := task.innerResult.GetChunk(i) + for j := 0; j < chk.NumRows(); j++ { + innerRow := chk.GetRow(j) + if iw.hasNullInJoinKey(innerRow) { + continue + } + + keyBuf = keyBuf[:0] + for _, keyCol := range iw.HashCols { + d := innerRow.GetDatum(keyCol, iw.RowTypes[keyCol]) + var err error + keyBuf, err = codec.EncodeKey(iw.ctx.GetSessionVars().StmtCtx.TimeZone(), keyBuf, d) + err = iw.ctx.GetSessionVars().StmtCtx.HandleError(err) + if err != nil { + return err + } + } + rowPtr := chunk.RowPtr{ChkIdx: uint32(i), RowIdx: uint32(j)} + *(*chunk.RowPtr)(unsafe.Pointer(&valBuf[0])) = rowPtr + task.lookupMap.Put(keyBuf, valBuf) + } + } + return nil +} + +func (iw *innerWorker) hasNullInJoinKey(row chunk.Row) bool { + for _, ordinal := range iw.HashCols { + if row.IsNull(ordinal) { + return true + } + } + return false +} + +// Close implements the Executor interface. 
+func (e *IndexLookUpJoin) Close() error { + if e.stats != nil { + defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), e.stats) + } + if e.cancelFunc != nil { + e.cancelFunc() + } + e.WorkerWg.Wait() + e.memTracker = nil + e.task = nil + e.Finished.Store(false) + e.prepared = false + return e.BaseExecutor.Close() +} + +type indexLookUpJoinRuntimeStats struct { + concurrency int + probe int64 + innerWorker innerWorkerRuntimeStats +} + +type innerWorkerRuntimeStats struct { + totalTime int64 + task int64 + construct int64 + fetch int64 + build int64 + join int64 +} + +func (e *indexLookUpJoinRuntimeStats) String() string { + buf := bytes.NewBuffer(make([]byte, 0, 16)) + if e.innerWorker.totalTime > 0 { + buf.WriteString("inner:{total:") + buf.WriteString(execdetails.FormatDuration(time.Duration(e.innerWorker.totalTime))) + buf.WriteString(", concurrency:") + if e.concurrency > 0 { + buf.WriteString(strconv.Itoa(e.concurrency)) + } else { + buf.WriteString("OFF") + } + buf.WriteString(", task:") + buf.WriteString(strconv.FormatInt(e.innerWorker.task, 10)) + buf.WriteString(", construct:") + buf.WriteString(execdetails.FormatDuration(time.Duration(e.innerWorker.construct))) + buf.WriteString(", fetch:") + buf.WriteString(execdetails.FormatDuration(time.Duration(e.innerWorker.fetch))) + buf.WriteString(", build:") + buf.WriteString(execdetails.FormatDuration(time.Duration(e.innerWorker.build))) + if e.innerWorker.join > 0 { + buf.WriteString(", join:") + buf.WriteString(execdetails.FormatDuration(time.Duration(e.innerWorker.join))) + } + buf.WriteString("}") + } + if e.probe > 0 { + buf.WriteString(", probe:") + buf.WriteString(execdetails.FormatDuration(time.Duration(e.probe))) + } + return buf.String() +} + +func (e *indexLookUpJoinRuntimeStats) Clone() execdetails.RuntimeStats { + return &indexLookUpJoinRuntimeStats{ + concurrency: e.concurrency, + probe: e.probe, + innerWorker: e.innerWorker, + } +} + +func (e *indexLookUpJoinRuntimeStats) Merge(rs execdetails.RuntimeStats) { + tmp, ok := rs.(*indexLookUpJoinRuntimeStats) + if !ok { + return + } + e.probe += tmp.probe + e.innerWorker.totalTime += tmp.innerWorker.totalTime + e.innerWorker.task += tmp.innerWorker.task + e.innerWorker.construct += tmp.innerWorker.construct + e.innerWorker.fetch += tmp.innerWorker.fetch + e.innerWorker.build += tmp.innerWorker.build + e.innerWorker.join += tmp.innerWorker.join +} + +// Tp implements the RuntimeStats interface. 
+func (*indexLookUpJoinRuntimeStats) Tp() int { + return execdetails.TpIndexLookUpJoinRuntimeStats +} diff --git a/pkg/executor/join/index_lookup_merge_join.go b/pkg/executor/join/index_lookup_merge_join.go index c181eb04548d7..0a8afa17f8234 100644 --- a/pkg/executor/join/index_lookup_merge_join.go +++ b/pkg/executor/join/index_lookup_merge_join.go @@ -211,9 +211,9 @@ func (e *IndexLookUpMergeJoin) newOuterWorker(resultCh, innerCh chan *lookUpMerg parentMemTracker: e.memTracker, nextColCompareFilters: e.LastColHelper, } - failpoint.Inject("testIssue18068", func() { + if _, _err_ := failpoint.Eval(_curpkg_("testIssue18068")); _err_ == nil { omw.batchSize = 1 - }) + } return omw } @@ -316,7 +316,7 @@ func (omw *outerMergeWorker) run(ctx context.Context, wg *sync.WaitGroup, cancel omw.pushToChan(ctx, task, omw.resultCh) return } - failpoint.Inject("mockIndexMergeJoinOOMPanic", nil) + failpoint.Eval(_curpkg_("mockIndexMergeJoinOOMPanic")) if task == nil { return } diff --git a/pkg/executor/join/index_lookup_merge_join.go__failpoint_stash__ b/pkg/executor/join/index_lookup_merge_join.go__failpoint_stash__ new file mode 100644 index 0000000000000..c181eb04548d7 --- /dev/null +++ b/pkg/executor/join/index_lookup_merge_join.go__failpoint_stash__ @@ -0,0 +1,743 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package join + +import ( + "context" + "runtime/trace" + "slices" + "sync" + "sync/atomic" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/channel" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/collate" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/ranger" + "go.uber.org/zap" +) + +// IndexLookUpMergeJoin realizes IndexLookUpJoin by merge join +// It preserves the order of the outer table and support batch lookup. +// +// The execution flow is very similar to IndexLookUpReader: +// 1. outerWorker read N outer rows, build a task and send it to result channel and inner worker channel. +// 2. The innerWorker receives the task, builds key ranges from outer rows and fetch inner rows, then do merge join. +// 3. main thread receives the task and fetch results from the channel in task one by one. +// 4. If channel has been closed, main thread receives the next task. 
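The hunks above (and the accompanying *__failpoint_stash__ files) show the mechanical rewrite applied when failpoints are enabled: the `failpoint.Inject("name", func() {...})` marker form is replaced by an explicit `if ... := failpoint.Eval(_curpkg_("name")); ...` guard around the same body. The sketch below is only a conceptual stand-in that activates a named hook via an environment variable — it is not the pingcap/failpoint API, and `mockBuildErr` is an invented failpoint name:

package main

import (
	"fmt"
	"os"
)

// eval is a toy stand-in for a failpoint lookup: the hook fires only when the
// named failpoint has been activated (here, via an environment variable).
func eval(name string) (string, bool) {
	return os.LookupEnv("FAILPOINT_" + name)
}

func buildTask() error {
	// Marker form in source: a no-op unless the failpoint is activated.
	// Enabled form: an explicit if-guard around the injected body.
	if _, ok := eval("mockBuildErr"); ok {
		return fmt.Errorf("mockBuildErr injected")
	}
	return nil
}

func main() {
	os.Setenv("FAILPOINT_mockBuildErr", "return(true)")
	fmt.Println(buildTask()) // the activated failpoint forces the error path
}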
+type IndexLookUpMergeJoin struct { + exec.BaseExecutor + + resultCh <-chan *lookUpMergeJoinTask + cancelFunc context.CancelFunc + WorkerWg *sync.WaitGroup + + OuterMergeCtx OuterMergeCtx + InnerMergeCtx InnerMergeCtx + + Joiners []Joiner + joinChkResourceCh []chan *chunk.Chunk + IsOuterJoin bool + + requiredRows int64 + + task *lookUpMergeJoinTask + + IndexRanges ranger.MutableRanges + KeyOff2IdxOff []int + + // LastColHelper store the information for last col if there's complicated filter like col > x_col and col < x_col + 100. + LastColHelper *plannercore.ColWithCmpFuncManager + + memTracker *memory.Tracker // track memory usage + prepared bool +} + +// OuterMergeCtx is the outer side ctx of merge join +type OuterMergeCtx struct { + RowTypes []*types.FieldType + JoinKeys []*expression.Column + KeyCols []int + Filter expression.CNFExprs + NeedOuterSort bool + CompareFuncs []expression.CompareFunc +} + +// InnerMergeCtx is the inner side ctx of merge join +type InnerMergeCtx struct { + ReaderBuilder IndexJoinExecutorBuilder + RowTypes []*types.FieldType + JoinKeys []*expression.Column + KeyCols []int + KeyCollators []collate.Collator + CompareFuncs []expression.CompareFunc + ColLens []int + Desc bool + KeyOff2KeyOffOrderByIdx []int +} + +type lookUpMergeJoinTask struct { + outerResult *chunk.List + outerMatch [][]bool + outerOrderIdx []chunk.RowPtr + + innerResult *chunk.Chunk + innerIter chunk.Iterator + + sameKeyInnerRows []chunk.Row + sameKeyIter chunk.Iterator + + doneErr error + results chan *indexMergeJoinResult + + memTracker *memory.Tracker +} + +type outerMergeWorker struct { + OuterMergeCtx + + lookup *IndexLookUpMergeJoin + + ctx sessionctx.Context + executor exec.Executor + + maxBatchSize int + batchSize int + + nextColCompareFilters *plannercore.ColWithCmpFuncManager + + resultCh chan<- *lookUpMergeJoinTask + innerCh chan<- *lookUpMergeJoinTask + + parentMemTracker *memory.Tracker +} + +type innerMergeWorker struct { + InnerMergeCtx + + taskCh <-chan *lookUpMergeJoinTask + joinChkResourceCh chan *chunk.Chunk + outerMergeCtx OuterMergeCtx + ctx sessionctx.Context + innerExec exec.Executor + joiner Joiner + retFieldTypes []*types.FieldType + + maxChunkSize int + indexRanges []*ranger.Range + nextColCompareFilters *plannercore.ColWithCmpFuncManager + keyOff2IdxOff []int +} + +type indexMergeJoinResult struct { + chk *chunk.Chunk + src chan<- *chunk.Chunk +} + +// Open implements the Executor interface +func (e *IndexLookUpMergeJoin) Open(ctx context.Context) error { + err := exec.Open(ctx, e.Children(0)) + if err != nil { + return err + } + e.memTracker = memory.NewTracker(e.ID(), -1) + e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) + return nil +} + +func (e *IndexLookUpMergeJoin) startWorkers(ctx context.Context) { + // TODO: consider another session currency variable for index merge join. + // Because its parallelization is not complete. 
+ concurrency := e.Ctx().GetSessionVars().IndexLookupJoinConcurrency() + if e.RuntimeStats() != nil { + runtimeStats := &execdetails.RuntimeStatsWithConcurrencyInfo{} + runtimeStats.SetConcurrencyInfo(execdetails.NewConcurrencyInfo("Concurrency", concurrency)) + e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), runtimeStats) + } + + resultCh := make(chan *lookUpMergeJoinTask, concurrency) + e.resultCh = resultCh + e.joinChkResourceCh = make([]chan *chunk.Chunk, concurrency) + for i := 0; i < concurrency; i++ { + e.joinChkResourceCh[i] = make(chan *chunk.Chunk, numResChkHold) + for j := 0; j < numResChkHold; j++ { + e.joinChkResourceCh[i] <- chunk.NewChunkWithCapacity(e.RetFieldTypes(), e.MaxChunkSize()) + } + } + workerCtx, cancelFunc := context.WithCancel(ctx) + e.cancelFunc = cancelFunc + innerCh := make(chan *lookUpMergeJoinTask, concurrency) + e.WorkerWg.Add(1) + go e.newOuterWorker(resultCh, innerCh).run(workerCtx, e.WorkerWg, e.cancelFunc) + e.WorkerWg.Add(concurrency) + for i := 0; i < concurrency; i++ { + go e.newInnerMergeWorker(innerCh, i).run(workerCtx, e.WorkerWg, e.cancelFunc) + } +} + +func (e *IndexLookUpMergeJoin) newOuterWorker(resultCh, innerCh chan *lookUpMergeJoinTask) *outerMergeWorker { + omw := &outerMergeWorker{ + OuterMergeCtx: e.OuterMergeCtx, + ctx: e.Ctx(), + lookup: e, + executor: e.Children(0), + resultCh: resultCh, + innerCh: innerCh, + batchSize: 32, + maxBatchSize: e.Ctx().GetSessionVars().IndexJoinBatchSize, + parentMemTracker: e.memTracker, + nextColCompareFilters: e.LastColHelper, + } + failpoint.Inject("testIssue18068", func() { + omw.batchSize = 1 + }) + return omw +} + +func (e *IndexLookUpMergeJoin) newInnerMergeWorker(taskCh chan *lookUpMergeJoinTask, workID int) *innerMergeWorker { + // Since multiple inner workers run concurrently, we should copy join's IndexRanges for every worker to avoid data race. + copiedRanges := make([]*ranger.Range, 0, len(e.IndexRanges.Range())) + for _, ran := range e.IndexRanges.Range() { + copiedRanges = append(copiedRanges, ran.Clone()) + } + imw := &innerMergeWorker{ + InnerMergeCtx: e.InnerMergeCtx, + outerMergeCtx: e.OuterMergeCtx, + taskCh: taskCh, + ctx: e.Ctx(), + indexRanges: copiedRanges, + keyOff2IdxOff: e.KeyOff2IdxOff, + joiner: e.Joiners[workID], + joinChkResourceCh: e.joinChkResourceCh[workID], + retFieldTypes: e.RetFieldTypes(), + maxChunkSize: e.MaxChunkSize(), + } + if e.LastColHelper != nil { + // nextCwf.TmpConstant needs to be reset for every individual + // inner worker to avoid data race when the inner workers is running + // concurrently. 
+ nextCwf := *e.LastColHelper + nextCwf.TmpConstant = make([]*expression.Constant, len(e.LastColHelper.TmpConstant)) + for i := range e.LastColHelper.TmpConstant { + nextCwf.TmpConstant[i] = &expression.Constant{RetType: nextCwf.TargetCol.RetType} + } + imw.nextColCompareFilters = &nextCwf + } + return imw +} + +// Next implements the Executor interface +func (e *IndexLookUpMergeJoin) Next(ctx context.Context, req *chunk.Chunk) error { + if !e.prepared { + e.startWorkers(ctx) + e.prepared = true + } + if e.IsOuterJoin { + atomic.StoreInt64(&e.requiredRows, int64(req.RequiredRows())) + } + req.Reset() + if e.task == nil { + e.loadFinishedTask(ctx) + } + for e.task != nil { + select { + case result, ok := <-e.task.results: + if !ok { + if e.task.doneErr != nil { + return e.task.doneErr + } + e.loadFinishedTask(ctx) + continue + } + req.SwapColumns(result.chk) + result.src <- result.chk + return nil + case <-ctx.Done(): + return ctx.Err() + } + } + + return nil +} + +// TODO: reuse the Finished task memory to build tasks. +func (e *IndexLookUpMergeJoin) loadFinishedTask(ctx context.Context) { + select { + case e.task = <-e.resultCh: + case <-ctx.Done(): + e.task = nil + } +} + +func (omw *outerMergeWorker) run(ctx context.Context, wg *sync.WaitGroup, cancelFunc context.CancelFunc) { + defer trace.StartRegion(ctx, "IndexLookupMergeJoinOuterWorker").End() + defer func() { + if r := recover(); r != nil { + task := &lookUpMergeJoinTask{ + doneErr: util.GetRecoverError(r), + results: make(chan *indexMergeJoinResult, numResChkHold), + } + close(task.results) + omw.resultCh <- task + cancelFunc() + } + close(omw.resultCh) + close(omw.innerCh) + wg.Done() + }() + for { + task, err := omw.buildTask(ctx) + if err != nil { + task.doneErr = err + close(task.results) + omw.pushToChan(ctx, task, omw.resultCh) + return + } + failpoint.Inject("mockIndexMergeJoinOOMPanic", nil) + if task == nil { + return + } + + if finished := omw.pushToChan(ctx, task, omw.innerCh); finished { + return + } + + if finished := omw.pushToChan(ctx, task, omw.resultCh); finished { + return + } + } +} + +func (*outerMergeWorker) pushToChan(ctx context.Context, task *lookUpMergeJoinTask, dst chan<- *lookUpMergeJoinTask) (finished bool) { + select { + case <-ctx.Done(): + return true + case dst <- task: + } + return false +} + +// buildTask builds a lookUpMergeJoinTask and read Outer rows. 
+// When err is not nil, task must not be nil to send the error to the main thread via task +func (omw *outerMergeWorker) buildTask(ctx context.Context) (*lookUpMergeJoinTask, error) { + task := &lookUpMergeJoinTask{ + results: make(chan *indexMergeJoinResult, numResChkHold), + outerResult: chunk.NewList(omw.RowTypes, omw.executor.InitCap(), omw.executor.MaxChunkSize()), + } + task.memTracker = memory.NewTracker(memory.LabelForSimpleTask, -1) + task.memTracker.AttachTo(omw.parentMemTracker) + + omw.increaseBatchSize() + requiredRows := omw.batchSize + if omw.lookup.IsOuterJoin { + requiredRows = int(atomic.LoadInt64(&omw.lookup.requiredRows)) + } + if requiredRows <= 0 || requiredRows > omw.maxBatchSize { + requiredRows = omw.maxBatchSize + } + for requiredRows > 0 { + execChk := exec.TryNewCacheChunk(omw.executor) + err := exec.Next(ctx, omw.executor, execChk) + if err != nil { + return task, err + } + if execChk.NumRows() == 0 { + break + } + + task.outerResult.Add(execChk) + requiredRows -= execChk.NumRows() + task.memTracker.Consume(execChk.MemoryUsage()) + } + + if task.outerResult.Len() == 0 { + return nil, nil + } + + return task, nil +} + +func (omw *outerMergeWorker) increaseBatchSize() { + if omw.batchSize < omw.maxBatchSize { + omw.batchSize *= 2 + } + if omw.batchSize > omw.maxBatchSize { + omw.batchSize = omw.maxBatchSize + } +} + +func (imw *innerMergeWorker) run(ctx context.Context, wg *sync.WaitGroup, cancelFunc context.CancelFunc) { + defer trace.StartRegion(ctx, "IndexLookupMergeJoinInnerWorker").End() + var task *lookUpMergeJoinTask + defer func() { + wg.Done() + if r := recover(); r != nil { + if task != nil { + task.doneErr = util.GetRecoverError(r) + close(task.results) + } + logutil.Logger(ctx).Error("innerMergeWorker panicked", zap.Any("recover", r), zap.Stack("stack")) + cancelFunc() + } + }() + + for ok := true; ok; { + select { + case task, ok = <-imw.taskCh: + if !ok { + return + } + case <-ctx.Done(): + return + } + + err := imw.handleTask(ctx, task) + task.doneErr = err + close(task.results) + } +} + +func (imw *innerMergeWorker) handleTask(ctx context.Context, task *lookUpMergeJoinTask) (err error) { + numOuterChks := task.outerResult.NumChunks() + if imw.outerMergeCtx.Filter != nil { + task.outerMatch = make([][]bool, numOuterChks) + exprCtx := imw.ctx.GetExprCtx() + for i := 0; i < numOuterChks; i++ { + chk := task.outerResult.GetChunk(i) + task.outerMatch[i] = make([]bool, chk.NumRows()) + task.outerMatch[i], err = expression.VectorizedFilter(exprCtx.GetEvalCtx(), imw.ctx.GetSessionVars().EnableVectorizedExpression, imw.outerMergeCtx.Filter, chunk.NewIterator4Chunk(chk), task.outerMatch[i]) + if err != nil { + return err + } + } + } + task.memTracker.Consume(int64(cap(task.outerMatch))) + task.outerOrderIdx = make([]chunk.RowPtr, 0, task.outerResult.Len()) + for i := 0; i < numOuterChks; i++ { + numRow := task.outerResult.GetChunk(i).NumRows() + for j := 0; j < numRow; j++ { + task.outerOrderIdx = append(task.outerOrderIdx, chunk.RowPtr{ChkIdx: uint32(i), RowIdx: uint32(j)}) + } + } + task.memTracker.Consume(int64(cap(task.outerOrderIdx))) + // NeedOuterSort means the outer side property items can't guarantee the order of join keys. + // Because the necessary condition of merge join is both outer and inner keep order of join keys. + // In this case, we need sort the outer side. 
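+	// Illustrative example: with join key a and the outer child ordered by (b, a),
+	// rows may arrive as (b=1,a=3), (b=2,a=1), so a is not monotone within the
+	// task and the outer rows must be re-sorted on the join key before merging.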
+ if imw.outerMergeCtx.NeedOuterSort { + exprCtx := imw.ctx.GetExprCtx() + slices.SortFunc(task.outerOrderIdx, func(idxI, idxJ chunk.RowPtr) int { + rowI, rowJ := task.outerResult.GetRow(idxI), task.outerResult.GetRow(idxJ) + var c int64 + var err error + for _, keyOff := range imw.KeyOff2KeyOffOrderByIdx { + joinKey := imw.outerMergeCtx.JoinKeys[keyOff] + c, _, err = imw.outerMergeCtx.CompareFuncs[keyOff](exprCtx.GetEvalCtx(), joinKey, joinKey, rowI, rowJ) + terror.Log(err) + if c != 0 { + break + } + } + if c != 0 || imw.nextColCompareFilters == nil { + if imw.Desc { + return int(-c) + } + return int(c) + } + c = int64(imw.nextColCompareFilters.CompareRow(rowI, rowJ)) + if imw.Desc { + return int(-c) + } + return int(c) + }) + } + dLookUpKeys, err := imw.constructDatumLookupKeys(task) + if err != nil { + return err + } + dLookUpKeys = imw.dedupDatumLookUpKeys(dLookUpKeys) + // If the order requires descending, the deDupedLookUpContents is keep descending order before. + // So at the end, we should generate the ascending deDupedLookUpContents to build the correct range for inner read. + if imw.Desc { + lenKeys := len(dLookUpKeys) + for i := 0; i < lenKeys/2; i++ { + dLookUpKeys[i], dLookUpKeys[lenKeys-i-1] = dLookUpKeys[lenKeys-i-1], dLookUpKeys[i] + } + } + imw.innerExec, err = imw.ReaderBuilder.BuildExecutorForIndexJoin(ctx, dLookUpKeys, imw.indexRanges, imw.keyOff2IdxOff, imw.nextColCompareFilters, false, nil, nil) + if imw.innerExec != nil { + defer func() { terror.Log(exec.Close(imw.innerExec)) }() + } + if err != nil { + return err + } + _, err = imw.fetchNextInnerResult(ctx, task) + if err != nil { + return err + } + err = imw.doMergeJoin(ctx, task) + return err +} + +func (imw *innerMergeWorker) fetchNewChunkWhenFull(ctx context.Context, task *lookUpMergeJoinTask, chk **chunk.Chunk) (continueJoin bool) { + if !(*chk).IsFull() { + return true + } + select { + case task.results <- &indexMergeJoinResult{*chk, imw.joinChkResourceCh}: + case <-ctx.Done(): + return false + } + var ok bool + select { + case *chk, ok = <-imw.joinChkResourceCh: + if !ok { + return false + } + case <-ctx.Done(): + return false + } + (*chk).Reset() + return true +} + +func (imw *innerMergeWorker) doMergeJoin(ctx context.Context, task *lookUpMergeJoinTask) (err error) { + var chk *chunk.Chunk + select { + case chk = <-imw.joinChkResourceCh: + case <-ctx.Done(): + return + } + defer func() { + if chk == nil { + return + } + if chk.NumRows() > 0 { + select { + case task.results <- &indexMergeJoinResult{chk, imw.joinChkResourceCh}: + case <-ctx.Done(): + return + } + } else { + imw.joinChkResourceCh <- chk + } + }() + + initCmpResult := 1 + if imw.InnerMergeCtx.Desc { + initCmpResult = -1 + } + noneInnerRowsRemain := task.innerResult.NumRows() == 0 + + for _, outerIdx := range task.outerOrderIdx { + outerRow := task.outerResult.GetRow(outerIdx) + hasMatch, hasNull, cmpResult := false, false, initCmpResult + if task.outerMatch != nil && !task.outerMatch[outerIdx.ChkIdx][outerIdx.RowIdx] { + goto missMatch + } + // If it has iterated out all inner rows and the inner rows with same key is empty, + // that means the Outer Row needn't match any inner rows. 
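+		// We still jump to missMatch below so the joiner can emit the unmatched
+		// outer row when the join type requires it (e.g. LEFT OUTER JOIN).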
+ if noneInnerRowsRemain && len(task.sameKeyInnerRows) == 0 { + goto missMatch + } + if len(task.sameKeyInnerRows) > 0 { + cmpResult, err = imw.compare(outerRow, task.sameKeyIter.Begin()) + if err != nil { + return err + } + } + if (cmpResult > 0 && !imw.InnerMergeCtx.Desc) || (cmpResult < 0 && imw.InnerMergeCtx.Desc) { + if noneInnerRowsRemain { + task.sameKeyInnerRows = task.sameKeyInnerRows[:0] + goto missMatch + } + noneInnerRowsRemain, err = imw.fetchInnerRowsWithSameKey(ctx, task, outerRow) + if err != nil { + return err + } + } + + for task.sameKeyIter.Current() != task.sameKeyIter.End() { + matched, isNull, err := imw.joiner.TryToMatchInners(outerRow, task.sameKeyIter, chk) + if err != nil { + return err + } + hasMatch = hasMatch || matched + hasNull = hasNull || isNull + if !imw.fetchNewChunkWhenFull(ctx, task, &chk) { + return nil + } + } + + missMatch: + if !hasMatch { + imw.joiner.OnMissMatch(hasNull, outerRow, chk) + if !imw.fetchNewChunkWhenFull(ctx, task, &chk) { + return nil + } + } + } + + return nil +} + +// fetchInnerRowsWithSameKey collects the inner rows having the same key with one outer row. +func (imw *innerMergeWorker) fetchInnerRowsWithSameKey(ctx context.Context, task *lookUpMergeJoinTask, key chunk.Row) (noneInnerRows bool, err error) { + task.sameKeyInnerRows = task.sameKeyInnerRows[:0] + curRow := task.innerIter.Current() + var cmpRes int + for cmpRes, err = imw.compare(key, curRow); ((cmpRes >= 0 && !imw.Desc) || (cmpRes <= 0 && imw.Desc)) && err == nil; cmpRes, err = imw.compare(key, curRow) { + if cmpRes == 0 { + task.sameKeyInnerRows = append(task.sameKeyInnerRows, curRow) + } + curRow = task.innerIter.Next() + if curRow == task.innerIter.End() { + curRow, err = imw.fetchNextInnerResult(ctx, task) + if err != nil || task.innerResult.NumRows() == 0 { + break + } + } + } + task.sameKeyIter = chunk.NewIterator4Slice(task.sameKeyInnerRows) + task.sameKeyIter.Begin() + noneInnerRows = task.innerResult.NumRows() == 0 + return +} + +func (imw *innerMergeWorker) compare(outerRow, innerRow chunk.Row) (int, error) { + exprCtx := imw.ctx.GetExprCtx() + for _, keyOff := range imw.InnerMergeCtx.KeyOff2KeyOffOrderByIdx { + cmp, _, err := imw.InnerMergeCtx.CompareFuncs[keyOff](exprCtx.GetEvalCtx(), imw.outerMergeCtx.JoinKeys[keyOff], imw.InnerMergeCtx.JoinKeys[keyOff], outerRow, innerRow) + if err != nil || cmp != 0 { + return int(cmp), err + } + } + return 0, nil +} + +func (imw *innerMergeWorker) constructDatumLookupKeys(task *lookUpMergeJoinTask) ([]*IndexJoinLookUpContent, error) { + numRows := len(task.outerOrderIdx) + dLookUpKeys := make([]*IndexJoinLookUpContent, 0, numRows) + for i := 0; i < numRows; i++ { + dLookUpKey, err := imw.constructDatumLookupKey(task, task.outerOrderIdx[i]) + if err != nil { + return nil, err + } + if dLookUpKey == nil { + continue + } + dLookUpKeys = append(dLookUpKeys, dLookUpKey) + } + + return dLookUpKeys, nil +} + +func (imw *innerMergeWorker) constructDatumLookupKey(task *lookUpMergeJoinTask, idx chunk.RowPtr) (*IndexJoinLookUpContent, error) { + if task.outerMatch != nil && !task.outerMatch[idx.ChkIdx][idx.RowIdx] { + return nil, nil + } + outerRow := task.outerResult.GetRow(idx) + sc := imw.ctx.GetSessionVars().StmtCtx + keyLen := len(imw.KeyCols) + dLookupKey := make([]types.Datum, 0, keyLen) + for i, keyCol := range imw.outerMergeCtx.KeyCols { + outerValue := outerRow.GetDatum(keyCol, imw.outerMergeCtx.RowTypes[keyCol]) + // Join-on-condition can be promised to be equal-condition in + // IndexNestedLoopJoin, thus the Filter 
will always be false if + // outerValue is null, and we don't need to lookup it. + if outerValue.IsNull() { + return nil, nil + } + innerColType := imw.RowTypes[imw.KeyCols[i]] + innerValue, err := outerValue.ConvertTo(sc.TypeCtx(), innerColType) + if err != nil { + // If the converted outerValue overflows, we don't need to lookup it. + if terror.ErrorEqual(err, types.ErrOverflow) || terror.ErrorEqual(err, types.ErrWarnDataOutOfRange) { + return nil, nil + } + if terror.ErrorEqual(err, types.ErrTruncated) && (innerColType.GetType() == mysql.TypeSet || innerColType.GetType() == mysql.TypeEnum) { + return nil, nil + } + return nil, err + } + cmp, err := outerValue.Compare(sc.TypeCtx(), &innerValue, imw.KeyCollators[i]) + if err != nil { + return nil, err + } + if cmp != 0 { + // If the converted outerValue is not equal to the origin outerValue, we don't need to lookup it. + return nil, nil + } + dLookupKey = append(dLookupKey, innerValue) + } + return &IndexJoinLookUpContent{Keys: dLookupKey, Row: task.outerResult.GetRow(idx)}, nil +} + +func (imw *innerMergeWorker) dedupDatumLookUpKeys(lookUpContents []*IndexJoinLookUpContent) []*IndexJoinLookUpContent { + if len(lookUpContents) < 2 { + return lookUpContents + } + sc := imw.ctx.GetSessionVars().StmtCtx + deDupedLookUpContents := lookUpContents[:1] + for i := 1; i < len(lookUpContents); i++ { + cmp := compareRow(sc, lookUpContents[i].Keys, lookUpContents[i-1].Keys, imw.KeyCollators) + if cmp != 0 || (imw.nextColCompareFilters != nil && imw.nextColCompareFilters.CompareRow(lookUpContents[i].Row, lookUpContents[i-1].Row) != 0) { + deDupedLookUpContents = append(deDupedLookUpContents, lookUpContents[i]) + } + } + return deDupedLookUpContents +} + +// fetchNextInnerResult collects a chunk of inner results from inner child executor. +func (imw *innerMergeWorker) fetchNextInnerResult(ctx context.Context, task *lookUpMergeJoinTask) (beginRow chunk.Row, err error) { + task.innerResult = imw.innerExec.NewChunkWithCapacity(imw.innerExec.RetFieldTypes(), imw.innerExec.MaxChunkSize(), imw.innerExec.MaxChunkSize()) + err = exec.Next(ctx, imw.innerExec, task.innerResult) + task.innerIter = chunk.NewIterator4Chunk(task.innerResult) + beginRow = task.innerIter.Begin() + return +} + +// Close implements the Executor interface. +func (e *IndexLookUpMergeJoin) Close() error { + if e.RuntimeStats() != nil { + defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), e.RuntimeStats()) + } + if e.cancelFunc != nil { + e.cancelFunc() + e.cancelFunc = nil + } + if e.resultCh != nil { + channel.Clear(e.resultCh) + e.resultCh = nil + } + e.joinChkResourceCh = nil + // joinChkResourceCh is to recycle result chunks, used by inner worker. + // resultCh is the main thread get the results, used by main thread and inner worker. + // cancelFunc control the outer worker and outer worker close the task channel. 
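+	// Shutdown order: the context was cancelled and resultCh drained above; now
+	// wait for every worker to exit before releasing the memory tracker.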
+ e.WorkerWg.Wait() + e.memTracker = nil + e.prepared = false + return e.BaseExecutor.Close() +} diff --git a/pkg/executor/join/merge_join.go b/pkg/executor/join/merge_join.go index 7dc647ce6c592..0be8ff35fae1e 100644 --- a/pkg/executor/join/merge_join.go +++ b/pkg/executor/join/merge_join.go @@ -103,11 +103,11 @@ func (t *MergeJoinTable) init(executor *MergeJoinExec) { t.rowContainer.GetDiskTracker().SetLabel(memory.LabelForInnerTable) if variable.EnableTmpStorageOnOOM.Load() { actionSpill := t.rowContainer.ActionSpill() - failpoint.Inject("testMergeJoinRowContainerSpill", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("testMergeJoinRowContainerSpill")); _err_ == nil { if val.(bool) { actionSpill = t.rowContainer.ActionSpillForTest() } - }) + } executor.Ctx().GetSessionVars().MemTracker.FallbackOldAndSetNewAction(actionSpill) } t.memTracker = memory.NewTracker(memory.LabelForInnerTable, -1) @@ -128,12 +128,12 @@ func (t *MergeJoinTable) finish() error { t.memTracker.Consume(-t.childChunk.MemoryUsage()) if t.IsInner { - failpoint.Inject("testMergeJoinRowContainerSpill", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("testMergeJoinRowContainerSpill")); _err_ == nil { if val.(bool) { actionSpill := t.rowContainer.ActionSpill() actionSpill.WaitForTest() } - }) + } if err := t.rowContainer.Close(); err != nil { return err } @@ -330,7 +330,7 @@ func (e *MergeJoinExec) Next(ctx context.Context, req *chunk.Chunk) (err error) innerIter := e.InnerTable.groupRowsIter outerIter := e.OuterTable.groupRowsIter for !req.IsFull() { - failpoint.Inject("ConsumeRandomPanic", nil) + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) if innerIter.Current() == innerIter.End() { if err := e.InnerTable.fetchNextInnerGroup(ctx, e); err != nil { return err diff --git a/pkg/executor/join/merge_join.go__failpoint_stash__ b/pkg/executor/join/merge_join.go__failpoint_stash__ new file mode 100644 index 0000000000000..7dc647ce6c592 --- /dev/null +++ b/pkg/executor/join/merge_join.go__failpoint_stash__ @@ -0,0 +1,420 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package join + +import ( + "context" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/executor/internal/vecgroupchecker" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/disk" + "github.com/pingcap/tidb/pkg/util/memory" +) + +var ( + _ exec.Executor = &MergeJoinExec{} +) + +// MergeJoinExec implements the merge join algorithm. +// This operator assumes that two iterators of both sides +// will provide required order on join condition: +// 1. For equal-join, one of the join key from each side +// matches the order given. +// 2. For other cases its preferred not to use SMJ and operator +// will throw error. 
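+// For example (illustrative), `SELECT /*+ MERGE_JOIN(t1, t2) */ * FROM t1 JOIN t2 ON t1.a = t2.a`
+// can be served by this operator when both children return rows ordered by `a`.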
+type MergeJoinExec struct { + exec.BaseExecutor + + StmtCtx *stmtctx.StatementContext + CompareFuncs []expression.CompareFunc + Joiner Joiner + IsOuterJoin bool + Desc bool + + InnerTable *MergeJoinTable + OuterTable *MergeJoinTable + + hasMatch bool + hasNull bool + + memTracker *memory.Tracker + diskTracker *disk.Tracker +} + +// MergeJoinTable is used for merge join +type MergeJoinTable struct { + inited bool + IsInner bool + ChildIndex int + JoinKeys []*expression.Column + Filters []expression.Expression + + executed bool + childChunk *chunk.Chunk + childChunkIter *chunk.Iterator4Chunk + groupChecker *vecgroupchecker.VecGroupChecker + groupRowsSelected []int + groupRowsIter chunk.Iterator + + // for inner table, an unbroken group may refer many chunks + rowContainer *chunk.RowContainer + + // for outer table, save result of filters + filtersSelected []bool + + memTracker *memory.Tracker +} + +func (t *MergeJoinTable) init(executor *MergeJoinExec) { + child := executor.Children(t.ChildIndex) + t.childChunk = exec.TryNewCacheChunk(child) + t.childChunkIter = chunk.NewIterator4Chunk(t.childChunk) + + items := make([]expression.Expression, 0, len(t.JoinKeys)) + for _, col := range t.JoinKeys { + items = append(items, col) + } + vecEnabled := executor.Ctx().GetSessionVars().EnableVectorizedExpression + t.groupChecker = vecgroupchecker.NewVecGroupChecker(executor.Ctx().GetExprCtx().GetEvalCtx(), vecEnabled, items) + t.groupRowsIter = chunk.NewIterator4Chunk(t.childChunk) + + if t.IsInner { + t.rowContainer = chunk.NewRowContainer(child.RetFieldTypes(), t.childChunk.Capacity()) + t.rowContainer.GetMemTracker().AttachTo(executor.memTracker) + t.rowContainer.GetMemTracker().SetLabel(memory.LabelForInnerTable) + t.rowContainer.GetDiskTracker().AttachTo(executor.diskTracker) + t.rowContainer.GetDiskTracker().SetLabel(memory.LabelForInnerTable) + if variable.EnableTmpStorageOnOOM.Load() { + actionSpill := t.rowContainer.ActionSpill() + failpoint.Inject("testMergeJoinRowContainerSpill", func(val failpoint.Value) { + if val.(bool) { + actionSpill = t.rowContainer.ActionSpillForTest() + } + }) + executor.Ctx().GetSessionVars().MemTracker.FallbackOldAndSetNewAction(actionSpill) + } + t.memTracker = memory.NewTracker(memory.LabelForInnerTable, -1) + } else { + t.filtersSelected = make([]bool, 0, executor.MaxChunkSize()) + t.memTracker = memory.NewTracker(memory.LabelForOuterTable, -1) + } + + t.memTracker.AttachTo(executor.memTracker) + t.inited = true + t.memTracker.Consume(t.childChunk.MemoryUsage()) +} + +func (t *MergeJoinTable) finish() error { + if !t.inited { + return nil + } + t.memTracker.Consume(-t.childChunk.MemoryUsage()) + + if t.IsInner { + failpoint.Inject("testMergeJoinRowContainerSpill", func(val failpoint.Value) { + if val.(bool) { + actionSpill := t.rowContainer.ActionSpill() + actionSpill.WaitForTest() + } + }) + if err := t.rowContainer.Close(); err != nil { + return err + } + } + + t.executed = false + t.childChunk = nil + t.childChunkIter = nil + t.groupChecker = nil + t.groupRowsSelected = nil + t.groupRowsIter = nil + t.rowContainer = nil + t.filtersSelected = nil + t.memTracker = nil + return nil +} + +func (t *MergeJoinTable) selectNextGroup() { + t.groupRowsSelected = t.groupRowsSelected[:0] + begin, end := t.groupChecker.GetNextGroup() + if t.IsInner && t.hasNullInJoinKey(t.childChunk.GetRow(begin)) { + return + } + + for i := begin; i < end; i++ { + t.groupRowsSelected = append(t.groupRowsSelected, i) + } + t.childChunk.SetSel(t.groupRowsSelected) +} + +func (t 
*MergeJoinTable) fetchNextChunk(ctx context.Context, executor *MergeJoinExec) error { + oldMemUsage := t.childChunk.MemoryUsage() + err := exec.Next(ctx, executor.Children(t.ChildIndex), t.childChunk) + t.memTracker.Consume(t.childChunk.MemoryUsage() - oldMemUsage) + if err != nil { + return err + } + t.executed = t.childChunk.NumRows() == 0 + return nil +} + +func (t *MergeJoinTable) fetchNextInnerGroup(ctx context.Context, exec *MergeJoinExec) error { + t.childChunk.SetSel(nil) + if err := t.rowContainer.Reset(); err != nil { + return err + } + +fetchNext: + if t.executed && t.groupChecker.IsExhausted() { + // Ensure iter at the end, since sel of childChunk has been cleared. + t.groupRowsIter.ReachEnd() + return nil + } + + isEmpty := true + // For inner table, rows have null in join keys should be skip by selectNextGroup. + for isEmpty && !t.groupChecker.IsExhausted() { + t.selectNextGroup() + isEmpty = len(t.groupRowsSelected) == 0 + } + + // For inner table, all the rows have the same join keys should be put into one group. + for !t.executed && t.groupChecker.IsExhausted() { + if !isEmpty { + // Group is not empty, hand over the management of childChunk to t.RowContainer. + if err := t.rowContainer.Add(t.childChunk); err != nil { + return err + } + t.memTracker.Consume(-t.childChunk.MemoryUsage()) + t.groupRowsSelected = nil + + t.childChunk = t.rowContainer.AllocChunk() + t.childChunkIter = chunk.NewIterator4Chunk(t.childChunk) + t.memTracker.Consume(t.childChunk.MemoryUsage()) + } + + if err := t.fetchNextChunk(ctx, exec); err != nil { + return err + } + if t.executed { + break + } + + isFirstGroupSameAsPrev, err := t.groupChecker.SplitIntoGroups(t.childChunk) + if err != nil { + return err + } + if isFirstGroupSameAsPrev && !isEmpty { + t.selectNextGroup() + } + } + if isEmpty { + goto fetchNext + } + + // iterate all data in t.RowContainer and t.childChunk + var iter chunk.Iterator + if t.rowContainer.NumChunks() != 0 { + iter = chunk.NewIterator4RowContainer(t.rowContainer) + } + if len(t.groupRowsSelected) != 0 { + if iter != nil { + iter = chunk.NewMultiIterator(iter, t.childChunkIter) + } else { + iter = t.childChunkIter + } + } + t.groupRowsIter = iter + t.groupRowsIter.Begin() + return nil +} + +func (t *MergeJoinTable) fetchNextOuterGroup(ctx context.Context, exec *MergeJoinExec, requiredRows int) error { + if t.executed && t.groupChecker.IsExhausted() { + return nil + } + + if !t.executed && t.groupChecker.IsExhausted() { + // It's hard to calculate selectivity if there is any filter or it's inner join, + // so we just push the requiredRows down when it's outer join and has no filter. + if exec.IsOuterJoin && len(t.Filters) == 0 { + t.childChunk.SetRequiredRows(requiredRows, exec.MaxChunkSize()) + } + err := t.fetchNextChunk(ctx, exec) + if err != nil || t.executed { + return err + } + + t.childChunkIter.Begin() + t.filtersSelected, err = expression.VectorizedFilter(exec.Ctx().GetExprCtx().GetEvalCtx(), exec.Ctx().GetSessionVars().EnableVectorizedExpression, t.Filters, t.childChunkIter, t.filtersSelected) + if err != nil { + return err + } + + _, err = t.groupChecker.SplitIntoGroups(t.childChunk) + if err != nil { + return err + } + } + + t.selectNextGroup() + t.groupRowsIter.Begin() + return nil +} + +func (t *MergeJoinTable) hasNullInJoinKey(row chunk.Row) bool { + for _, col := range t.JoinKeys { + ordinal := col.Index + if row.IsNull(ordinal) { + return true + } + } + return false +} + +// Close implements the Executor Close interface. 
+func (e *MergeJoinExec) Close() error { + if err := e.InnerTable.finish(); err != nil { + return err + } + if err := e.OuterTable.finish(); err != nil { + return err + } + + e.hasMatch = false + e.hasNull = false + e.memTracker = nil + e.diskTracker = nil + return e.BaseExecutor.Close() +} + +// Open implements the Executor Open interface. +func (e *MergeJoinExec) Open(ctx context.Context) error { + if err := e.BaseExecutor.Open(ctx); err != nil { + return err + } + + e.memTracker = memory.NewTracker(e.ID(), -1) + e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) + e.diskTracker = disk.NewTracker(e.ID(), -1) + e.diskTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.DiskTracker) + + e.InnerTable.init(e) + e.OuterTable.init(e) + return nil +} + +// Next implements the Executor Next interface. +// Note the inner group collects all identical keys in a group across multiple chunks, but the outer group just covers +// the identical keys within a chunk, so identical keys may cover more than one chunk. +func (e *MergeJoinExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { + req.Reset() + + innerIter := e.InnerTable.groupRowsIter + outerIter := e.OuterTable.groupRowsIter + for !req.IsFull() { + failpoint.Inject("ConsumeRandomPanic", nil) + if innerIter.Current() == innerIter.End() { + if err := e.InnerTable.fetchNextInnerGroup(ctx, e); err != nil { + return err + } + innerIter = e.InnerTable.groupRowsIter + } + if outerIter.Current() == outerIter.End() { + if err := e.OuterTable.fetchNextOuterGroup(ctx, e, req.RequiredRows()-req.NumRows()); err != nil { + return err + } + outerIter = e.OuterTable.groupRowsIter + if e.OuterTable.executed { + return nil + } + } + + cmpResult := -1 + if e.Desc { + cmpResult = 1 + } + if innerIter.Current() != innerIter.End() { + cmpResult, err = e.compare(outerIter.Current(), innerIter.Current()) + if err != nil { + return err + } + } + // the inner group falls behind + if (cmpResult > 0 && !e.Desc) || (cmpResult < 0 && e.Desc) { + innerIter.ReachEnd() + continue + } + // the Outer group falls behind + if (cmpResult < 0 && !e.Desc) || (cmpResult > 0 && e.Desc) { + for row := outerIter.Current(); row != outerIter.End() && !req.IsFull(); row = outerIter.Next() { + e.Joiner.OnMissMatch(false, row, req) + } + continue + } + + for row := outerIter.Current(); row != outerIter.End() && !req.IsFull(); row = outerIter.Next() { + if !e.OuterTable.filtersSelected[row.Idx()] { + e.Joiner.OnMissMatch(false, row, req) + continue + } + // compare each Outer item with each inner item + // the inner maybe not exhausted at one time + for innerIter.Current() != innerIter.End() { + matched, isNull, err := e.Joiner.TryToMatchInners(row, innerIter, req) + if err != nil { + return err + } + e.hasMatch = e.hasMatch || matched + e.hasNull = e.hasNull || isNull + if req.IsFull() { + if innerIter.Current() == innerIter.End() { + break + } + return nil + } + } + + if !e.hasMatch { + e.Joiner.OnMissMatch(e.hasNull, row, req) + } + e.hasMatch = false + e.hasNull = false + innerIter.Begin() + } + } + return nil +} + +func (e *MergeJoinExec) compare(outerRow, innerRow chunk.Row) (int, error) { + outerJoinKeys := e.OuterTable.JoinKeys + innerJoinKeys := e.InnerTable.JoinKeys + for i := range outerJoinKeys { + cmp, _, err := e.CompareFuncs[i](e.Ctx().GetExprCtx().GetEvalCtx(), outerJoinKeys[i], innerJoinKeys[i], outerRow, innerRow) + if err != nil { + return 0, err + } + + if cmp != 0 { + return int(cmp), nil + } + } + return 0, nil +} diff --git 
a/pkg/executor/load_data.go b/pkg/executor/load_data.go index e6471546d36a5..91c4fc40ee0a5 100644 --- a/pkg/executor/load_data.go +++ b/pkg/executor/load_data.go @@ -242,7 +242,7 @@ func (e *LoadDataWorker) load(ctx context.Context, readerInfos []importer.LoadDa }) // commitWork goroutines. group.Go(func() error { - failpoint.Inject("BeforeCommitWork", nil) + failpoint.Eval(_curpkg_("BeforeCommitWork")) return committer.commitWork(groupCtx, commitTaskCh) }) @@ -620,9 +620,9 @@ func (w *commitWorker) commitOneTask(ctx context.Context, task commitTask) error logutil.Logger(ctx).Error("commit error CheckAndInsert", zap.Error(err)) return err } - failpoint.Inject("commitOneTaskErr", func() { - failpoint.Return(errors.New("mock commit one task error")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("commitOneTaskErr")); _err_ == nil { + return errors.New("mock commit one task error") + } return nil } diff --git a/pkg/executor/load_data.go__failpoint_stash__ b/pkg/executor/load_data.go__failpoint_stash__ new file mode 100644 index 0000000000000..e6471546d36a5 --- /dev/null +++ b/pkg/executor/load_data.go__failpoint_stash__ @@ -0,0 +1,780 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "context" + "fmt" + "io" + "math" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/pkg/errctx" + "github.com/pingcap/tidb/pkg/executor/importer" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/lightning/mydump" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + contextutil "github.com/pingcap/tidb/pkg/util/context" + "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/sqlkiller" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" +) + +// LoadDataVarKey is a variable key for load data. +const LoadDataVarKey loadDataVarKeyType = 0 + +// LoadDataReaderBuilderKey stores the reader channel that reads from the connection. +const LoadDataReaderBuilderKey loadDataVarKeyType = 1 + +var ( + taskQueueSize = 16 // the maximum number of pending tasks to commit in queue +) + +// LoadDataReaderBuilder is a function type that builds a reader from a file path. +type LoadDataReaderBuilder func(filepath string) ( + r io.ReadCloser, err error, +) + +// LoadDataExec represents a load data executor. 
+type LoadDataExec struct { + exec.BaseExecutor + + FileLocRef ast.FileLocRefTp + loadDataWorker *LoadDataWorker + + // fields for loading local file + infileReader io.ReadCloser +} + +// Open implements the Executor interface. +func (e *LoadDataExec) Open(_ context.Context) error { + if rb, ok := e.Ctx().Value(LoadDataReaderBuilderKey).(LoadDataReaderBuilder); ok { + var err error + e.infileReader, err = rb(e.loadDataWorker.GetInfilePath()) + if err != nil { + return err + } + } + return nil +} + +// Close implements the Executor interface. +func (e *LoadDataExec) Close() error { + return e.closeLocalReader(nil) +} + +func (e *LoadDataExec) closeLocalReader(originalErr error) error { + err := originalErr + if e.infileReader != nil { + if err2 := e.infileReader.Close(); err2 != nil { + logutil.BgLogger().Error( + "close local reader failed", zap.Error(err2), + zap.NamedError("original error", originalErr), + ) + if err == nil { + err = err2 + } + } + e.infileReader = nil + } + return err +} + +// Next implements the Executor Next interface. +func (e *LoadDataExec) Next(ctx context.Context, _ *chunk.Chunk) (err error) { + switch e.FileLocRef { + case ast.FileLocServerOrRemote: + return e.loadDataWorker.loadRemote(ctx) + case ast.FileLocClient: + // This is for legacy test only + // TODO: adjust tests to remove LoadDataVarKey + sctx := e.loadDataWorker.UserSctx + sctx.SetValue(LoadDataVarKey, e.loadDataWorker) + + err = e.loadDataWorker.LoadLocal(ctx, e.infileReader) + if err != nil { + logutil.Logger(ctx).Error("load local data failed", zap.Error(err)) + err = e.closeLocalReader(err) + return err + } + } + return nil +} + +type planInfo struct { + ID int + Columns []*ast.ColumnName + GenColExprs []expression.Expression +} + +// LoadDataWorker does a LOAD DATA job. +type LoadDataWorker struct { + UserSctx sessionctx.Context + + controller *importer.LoadDataController + planInfo planInfo + + table table.Table +} + +func setNonRestrictiveFlags(stmtCtx *stmtctx.StatementContext) { + // TODO: DupKeyAsWarning represents too many "ignore error" paths, the + // meaning of this flag is not clear. I can only reuse it here. + levels := stmtCtx.ErrLevels() + levels[errctx.ErrGroupDupKey] = errctx.LevelWarn + levels[errctx.ErrGroupBadNull] = errctx.LevelWarn + stmtCtx.SetErrLevels(levels) + stmtCtx.SetTypeFlags(stmtCtx.TypeFlags().WithTruncateAsWarning(true)) +} + +// NewLoadDataWorker creates a new LoadDataWorker that is ready to work. +func NewLoadDataWorker( + userSctx sessionctx.Context, + plan *plannercore.LoadData, + tbl table.Table, +) (w *LoadDataWorker, err error) { + importPlan, err := importer.NewPlanFromLoadDataPlan(userSctx, plan) + if err != nil { + return nil, err + } + astArgs := importer.ASTArgsFromPlan(plan) + controller, err := importer.NewLoadDataController(importPlan, tbl, astArgs) + if err != nil { + return nil, err + } + + if !controller.Restrictive { + setNonRestrictiveFlags(userSctx.GetSessionVars().StmtCtx) + } + + loadDataWorker := &LoadDataWorker{ + UserSctx: userSctx, + table: tbl, + controller: controller, + planInfo: planInfo{ + ID: plan.ID(), + Columns: plan.Columns, + GenColExprs: plan.GenCols.Exprs, + }, + } + return loadDataWorker, nil +} + +func (e *LoadDataWorker) loadRemote(ctx context.Context) error { + if err2 := e.controller.InitDataFiles(ctx); err2 != nil { + return err2 + } + return e.load(ctx, e.controller.GetLoadDataReaderInfos()) +} + +// LoadLocal reads from client connection and do load data job. 
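+// A typical statement served by this path (illustrative):
+//
+//	LOAD DATA LOCAL INFILE '/tmp/data.csv' INTO TABLE t
+//	FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n';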
+func (e *LoadDataWorker) LoadLocal(ctx context.Context, r io.ReadCloser) error { + if r == nil { + return errors.New("load local data, reader is nil") + } + + compressTp := mydump.ParseCompressionOnFileExtension(e.GetInfilePath()) + compressTp2, err := mydump.ToStorageCompressType(compressTp) + if err != nil { + return err + } + readers := []importer.LoadDataReaderInfo{{ + Opener: func(_ context.Context) (io.ReadSeekCloser, error) { + addedSeekReader := NewSimpleSeekerOnReadCloser(r) + return storage.InterceptDecompressReader(addedSeekReader, compressTp2, storage.DecompressConfig{ + ZStdDecodeConcurrency: 1, + }) + }}} + return e.load(ctx, readers) +} + +func (e *LoadDataWorker) load(ctx context.Context, readerInfos []importer.LoadDataReaderInfo) error { + group, groupCtx := errgroup.WithContext(ctx) + + encoder, committer, err := initEncodeCommitWorkers(e) + if err != nil { + return err + } + + // main goroutine -> readerInfoCh -> processOneStream goroutines + readerInfoCh := make(chan importer.LoadDataReaderInfo, 1) + // processOneStream goroutines -> commitTaskCh -> commitWork goroutines + commitTaskCh := make(chan commitTask, taskQueueSize) + // commitWork goroutines -> done -> UpdateJobProgress goroutine + + // processOneStream goroutines. + group.Go(func() error { + err2 := encoder.processStream(groupCtx, readerInfoCh, commitTaskCh) + if err2 == nil { + close(commitTaskCh) + } + return err2 + }) + // commitWork goroutines. + group.Go(func() error { + failpoint.Inject("BeforeCommitWork", nil) + return committer.commitWork(groupCtx, commitTaskCh) + }) + +sendReaderInfoLoop: + for _, info := range readerInfos { + select { + case <-groupCtx.Done(): + break sendReaderInfoLoop + case readerInfoCh <- info: + } + } + close(readerInfoCh) + err = group.Wait() + e.setResult(encoder.exprWarnings) + return err +} + +func (e *LoadDataWorker) setResult(colAssignExprWarnings []contextutil.SQLWarn) { + stmtCtx := e.UserSctx.GetSessionVars().StmtCtx + numWarnings := uint64(stmtCtx.WarningCount()) + numRecords := stmtCtx.RecordRows() + numDeletes := stmtCtx.DeletedRows() + numSkipped := stmtCtx.RecordRows() - stmtCtx.CopiedRows() + + // col assign expr warnings is generated during init, it's static + // we need to generate it for each row processed. + numWarnings += numRecords * uint64(len(colAssignExprWarnings)) + + if numWarnings > math.MaxUint16 { + numWarnings = math.MaxUint16 + } + + msg := fmt.Sprintf(mysql.MySQLErrName[mysql.ErrLoadInfo].Raw, numRecords, numDeletes, numSkipped, numWarnings) + warns := make([]contextutil.SQLWarn, numWarnings) + n := copy(warns, stmtCtx.GetWarnings()) + for i := 0; i < int(numRecords) && n < len(warns); i++ { + n += copy(warns[n:], colAssignExprWarnings) + } + + stmtCtx.SetMessage(msg) + stmtCtx.SetWarnings(warns) +} + +func initEncodeCommitWorkers(e *LoadDataWorker) (*encodeWorker, *commitWorker, error) { + insertValues, err2 := createInsertValues(e) + if err2 != nil { + return nil, nil, err2 + } + colAssignExprs, exprWarnings, err2 := e.controller.CreateColAssignExprs(insertValues.Ctx()) + if err2 != nil { + return nil, nil, err2 + } + enc := &encodeWorker{ + InsertValues: insertValues, + controller: e.controller, + colAssignExprs: colAssignExprs, + exprWarnings: exprWarnings, + killer: &e.UserSctx.GetSessionVars().SQLKiller, + } + enc.resetBatch() + com := &commitWorker{ + InsertValues: insertValues, + controller: e.controller, + } + return enc, com, nil +} + +// createInsertValues creates InsertValues from userSctx. 
+func createInsertValues(e *LoadDataWorker) (insertVal *InsertValues, err error) { + insertColumns := e.controller.InsertColumns + hasExtraHandle := false + for _, col := range insertColumns { + if col.Name.L == model.ExtraHandleName.L { + if !e.UserSctx.GetSessionVars().AllowWriteRowID { + return nil, errors.Errorf("load data statement for _tidb_rowid are not supported") + } + hasExtraHandle = true + break + } + } + ret := &InsertValues{ + BaseExecutor: exec.NewBaseExecutor(e.UserSctx, nil, e.planInfo.ID), + Table: e.table, + Columns: e.planInfo.Columns, + GenExprs: e.planInfo.GenColExprs, + maxRowsInBatch: 1000, + insertColumns: insertColumns, + rowLen: len(insertColumns), + hasExtraHandle: hasExtraHandle, + } + if len(insertColumns) > 0 { + ret.initEvalBuffer() + } + ret.collectRuntimeStatsEnabled() + return ret, nil +} + +// encodeWorker is a sub-worker of LoadDataWorker that dedicated to encode data. +type encodeWorker struct { + *InsertValues + controller *importer.LoadDataController + colAssignExprs []expression.Expression + // sessionCtx generate warnings when rewrite AST node into expression. + // we should generate such warnings for each row encoded. + exprWarnings []contextutil.SQLWarn + killer *sqlkiller.SQLKiller + rows [][]types.Datum +} + +// commitTask is used for passing data from processStream goroutine to commitWork goroutine. +type commitTask struct { + cnt uint64 + rows [][]types.Datum +} + +// processStream always tries to build a parser from channel and process it. When +// it returns nil, it means all data is read. +func (w *encodeWorker) processStream( + ctx context.Context, + inCh <-chan importer.LoadDataReaderInfo, + outCh chan<- commitTask, +) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case readerInfo, ok := <-inCh: + if !ok { + return nil + } + dataParser, err := w.controller.GetParser(ctx, readerInfo) + if err != nil { + return err + } + if err = w.controller.HandleSkipNRows(dataParser); err != nil { + return err + } + err = w.processOneStream(ctx, dataParser, outCh) + terror.Log(dataParser.Close()) + if err != nil { + return err + } + } + } +} + +// processOneStream process input stream from parser. When returns nil, it means +// all data is read. +func (w *encodeWorker) processOneStream( + ctx context.Context, + parser mydump.Parser, + outCh chan<- commitTask, +) (err error) { + defer func() { + r := recover() + if r != nil { + logutil.Logger(ctx).Error("process routine panicked", + zap.Any("r", r), + zap.Stack("stack")) + err = util.GetRecoverError(r) + } + }() + + checkKilled := time.NewTicker(30 * time.Second) + defer checkKilled.Stop() + + for { + // prepare batch and enqueue task + if err = w.readOneBatchRows(ctx, parser); err != nil { + return + } + if w.curBatchCnt == 0 { + return + } + + TrySendTask: + select { + case <-ctx.Done(): + return ctx.Err() + case <-checkKilled.C: + if err := w.killer.HandleSignal(); err != nil { + logutil.Logger(ctx).Info("load data query interrupted quit data processing") + return err + } + goto TrySendTask + case outCh <- commitTask{ + cnt: w.curBatchCnt, + rows: w.rows, + }: + } + // reset rows buffer, will reallocate buffer but NOT reuse + w.resetBatch() + } +} + +func (w *encodeWorker) resetBatch() { + w.rows = make([][]types.Datum, 0, w.maxRowsInBatch) + w.curBatchCnt = 0 +} + +// readOneBatchRows reads rows from parser. When parser's reader meet EOF, it +// will return nil. For other errors it will return directly. When the rows +// batch is full it will also return nil. 
+// The result rows are saved in w.rows and update some members, caller can check +// if curBatchCnt == 0 to know if reached EOF. +func (w *encodeWorker) readOneBatchRows(ctx context.Context, parser mydump.Parser) error { + for { + if err := parser.ReadRow(); err != nil { + if errors.Cause(err) == io.EOF { + return nil + } + return exeerrors.ErrLoadDataCantRead.GenWithStackByArgs( + err.Error(), + "Only the following formats delimited text file (csv, tsv), parquet, sql are supported. Please provide the valid source file(s)", + ) + } + // rowCount will be used in fillRow(), last insert ID will be assigned according to the rowCount = 1. + // So should add first here. + w.rowCount++ + r, err := w.parserData2TableData(ctx, parser.LastRow().Row) + if err != nil { + return err + } + parser.RecycleRow(parser.LastRow()) + w.rows = append(w.rows, r) + w.curBatchCnt++ + if w.maxRowsInBatch != 0 && w.rowCount%w.maxRowsInBatch == 0 { + logutil.Logger(ctx).Info("batch limit hit when inserting rows", zap.Int("maxBatchRows", w.MaxChunkSize()), + zap.Uint64("totalRows", w.rowCount)) + return nil + } + } +} + +// parserData2TableData encodes the data of parser output. +func (w *encodeWorker) parserData2TableData( + ctx context.Context, + parserData []types.Datum, +) ([]types.Datum, error) { + var errColNumMismatch error + switch { + case len(parserData) < w.controller.GetFieldCount(): + errColNumMismatch = exeerrors.ErrWarnTooFewRecords.GenWithStackByArgs(w.rowCount) + case len(parserData) > w.controller.GetFieldCount(): + errColNumMismatch = exeerrors.ErrWarnTooManyRecords.GenWithStackByArgs(w.rowCount) + } + + if errColNumMismatch != nil { + if w.controller.Restrictive { + return nil, errColNumMismatch + } + w.handleWarning(errColNumMismatch) + } + + row := make([]types.Datum, 0, len(w.insertColumns)) + sessionVars := w.Ctx().GetSessionVars() + setVar := func(name string, col *types.Datum) { + // User variable names are not case-sensitive + // https://dev.mysql.com/doc/refman/8.0/en/user-variables.html + name = strings.ToLower(name) + if col == nil || col.IsNull() { + sessionVars.UnsetUserVar(name) + } else { + sessionVars.SetUserVarVal(name, *col) + } + } + + fieldMappings := w.controller.FieldMappings + for i := 0; i < len(fieldMappings); i++ { + if i >= len(parserData) { + if fieldMappings[i].Column == nil { + setVar(fieldMappings[i].UserVar.Name, nil) + continue + } + + // If some columns is missing and their type is time and has not null flag, they should be set as current time. + if types.IsTypeTime(fieldMappings[i].Column.GetType()) && mysql.HasNotNullFlag(fieldMappings[i].Column.GetFlag()) { + row = append(row, types.NewTimeDatum(types.CurrentTime(fieldMappings[i].Column.GetType()))) + continue + } + + row = append(row, types.NewDatum(nil)) + continue + } + + if fieldMappings[i].Column == nil { + setVar(fieldMappings[i].UserVar.Name, &parserData[i]) + continue + } + + // Don't set the value for generated columns. 
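+		// A NULL placeholder is appended instead; the actual value is expected to
+		// be produced later from the column's generation expression (see GenExprs).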
+ if fieldMappings[i].Column.IsGenerated() { + row = append(row, types.NewDatum(nil)) + continue + } + + row = append(row, parserData[i]) + } + for i := 0; i < len(w.colAssignExprs); i++ { + // eval expression of `SET` clause + d, err := w.colAssignExprs[i].Eval(w.Ctx().GetExprCtx().GetEvalCtx(), chunk.Row{}) + if err != nil { + if w.controller.Restrictive { + return nil, err + } + w.handleWarning(err) + } + row = append(row, d) + } + + // a new row buffer will be allocated in getRow + newRow, err := w.getRow(ctx, row) + if err != nil { + if w.controller.Restrictive { + return nil, err + } + w.handleWarning(err) + logutil.Logger(ctx).Error("failed to get row", zap.Error(err)) + // TODO: should not return nil! caller will panic when lookup index + return nil, nil + } + + return newRow, nil +} + +// commitWorker is a sub-worker of LoadDataWorker that dedicated to commit data. +type commitWorker struct { + *InsertValues + controller *importer.LoadDataController +} + +// commitWork commit batch sequentially. When returns nil, it means the job is +// finished. +func (w *commitWorker) commitWork(ctx context.Context, inCh <-chan commitTask) (err error) { + defer func() { + r := recover() + if r != nil { + logutil.Logger(ctx).Error("commitWork panicked", + zap.Any("r", r), + zap.Stack("stack")) + err = util.GetRecoverError(r) + } + }() + + var ( + taskCnt uint64 + ) + for { + select { + case <-ctx.Done(): + return ctx.Err() + case task, ok := <-inCh: + if !ok { + return nil + } + start := time.Now() + if err = w.commitOneTask(ctx, task); err != nil { + return err + } + taskCnt++ + logutil.Logger(ctx).Info("commit one task success", + zap.Duration("commit time usage", time.Since(start)), + zap.Uint64("keys processed", task.cnt), + zap.Uint64("taskCnt processed", taskCnt), + ) + } + } +} + +// commitOneTask insert Data from LoadDataWorker.rows, then commit the modification +// like a statement. +func (w *commitWorker) commitOneTask(ctx context.Context, task commitTask) error { + err := w.checkAndInsertOneBatch(ctx, task.rows, task.cnt) + if err != nil { + logutil.Logger(ctx).Error("commit error CheckAndInsert", zap.Error(err)) + return err + } + failpoint.Inject("commitOneTaskErr", func() { + failpoint.Return(errors.New("mock commit one task error")) + }) + return nil +} + +func (w *commitWorker) checkAndInsertOneBatch(ctx context.Context, rows [][]types.Datum, cnt uint64) error { + if w.stats != nil && w.stats.BasicRuntimeStats != nil { + // Since this method will not call by executor Next, + // so we need record the basic executor runtime stats by ourselves. 
+ start := time.Now() + defer func() { + w.stats.BasicRuntimeStats.Record(time.Since(start), 0) + }() + } + var err error + if cnt == 0 { + return err + } + w.Ctx().GetSessionVars().StmtCtx.AddRecordRows(cnt) + + switch w.controller.OnDuplicate { + case ast.OnDuplicateKeyHandlingReplace: + return w.batchCheckAndInsert(ctx, rows[0:cnt], w.addRecordLD, true) + case ast.OnDuplicateKeyHandlingIgnore: + return w.batchCheckAndInsert(ctx, rows[0:cnt], w.addRecordLD, false) + case ast.OnDuplicateKeyHandlingError: + for i, row := range rows[0:cnt] { + sizeHintStep := int(w.Ctx().GetSessionVars().ShardAllocateStep) + if sizeHintStep > 0 && i%sizeHintStep == 0 { + sizeHint := sizeHintStep + remain := len(rows[0:cnt]) - i + if sizeHint > remain { + sizeHint = remain + } + err = w.addRecordWithAutoIDHint(ctx, row, sizeHint, table.DupKeyCheckDefault) + } else { + err = w.addRecord(ctx, row, table.DupKeyCheckDefault) + } + if err != nil { + return err + } + w.Ctx().GetSessionVars().StmtCtx.AddCopiedRows(1) + } + return nil + default: + return errors.Errorf("unknown on duplicate key handling: %v", w.controller.OnDuplicate) + } +} + +func (w *commitWorker) addRecordLD(ctx context.Context, row []types.Datum, dupKeyCheck table.DupKeyCheckMode) error { + if row == nil { + return nil + } + return w.addRecord(ctx, row, dupKeyCheck) +} + +// GetInfilePath get infile path. +func (e *LoadDataWorker) GetInfilePath() string { + return e.controller.Path +} + +// GetController get load data controller. +// used in unit test. +func (e *LoadDataWorker) GetController() *importer.LoadDataController { + return e.controller +} + +// TestLoadLocal is a helper function for unit test. +func (e *LoadDataWorker) TestLoadLocal(parser mydump.Parser) error { + if err := ResetContextOfStmt(e.UserSctx, &ast.LoadDataStmt{}); err != nil { + return err + } + setNonRestrictiveFlags(e.UserSctx.GetSessionVars().StmtCtx) + encoder, committer, err := initEncodeCommitWorkers(e) + if err != nil { + return err + } + + ctx := context.Background() + err = sessiontxn.NewTxn(ctx, e.UserSctx) + if err != nil { + return err + } + + for i := uint64(0); i < e.controller.IgnoreLines; i++ { + //nolint: errcheck + _ = parser.ReadRow() + } + + err = encoder.readOneBatchRows(ctx, parser) + if err != nil { + return err + } + + err = committer.checkAndInsertOneBatch( + ctx, + encoder.rows, + encoder.curBatchCnt) + if err != nil { + return err + } + encoder.resetBatch() + committer.Ctx().StmtCommit(ctx) + err = committer.Ctx().CommitTxn(ctx) + if err != nil { + return err + } + e.setResult(encoder.exprWarnings) + return nil +} + +var _ io.ReadSeekCloser = (*SimpleSeekerOnReadCloser)(nil) + +// SimpleSeekerOnReadCloser provides Seek(0, SeekCurrent) on ReadCloser. +type SimpleSeekerOnReadCloser struct { + r io.ReadCloser + pos int +} + +// NewSimpleSeekerOnReadCloser creates a SimpleSeekerOnReadCloser. +func NewSimpleSeekerOnReadCloser(r io.ReadCloser) *SimpleSeekerOnReadCloser { + return &SimpleSeekerOnReadCloser{r: r} +} + +// Read implements io.Reader. +func (s *SimpleSeekerOnReadCloser) Read(p []byte) (n int, err error) { + n, err = s.r.Read(p) + s.pos += n + return +} + +// Seek implements io.Seeker. +func (s *SimpleSeekerOnReadCloser) Seek(offset int64, whence int) (int64, error) { + // only support get reader's current offset + if offset == 0 && whence == io.SeekCurrent { + return int64(s.pos), nil + } + return 0, errors.Errorf("unsupported seek on SimpleSeekerOnReadCloser, offset: %d whence: %d", offset, whence) +} + +// Close implements io.Closer. 
+func (s *SimpleSeekerOnReadCloser) Close() error { + return s.r.Close() +} + +// GetFileSize implements storage.ExternalFileReader. +func (*SimpleSeekerOnReadCloser) GetFileSize() (int64, error) { + return 0, errors.Errorf("unsupported GetFileSize on SimpleSeekerOnReadCloser") +} + +// loadDataVarKeyType is a dummy type to avoid naming collision in context. +type loadDataVarKeyType int + +// String defines a Stringer function for debugging and pretty printing. +func (loadDataVarKeyType) String() string { + return "load_data_var" +} diff --git a/pkg/executor/memtable_reader.go b/pkg/executor/memtable_reader.go index 316439a638b99..286ac5c9bb6f9 100644 --- a/pkg/executor/memtable_reader.go +++ b/pkg/executor/memtable_reader.go @@ -170,12 +170,12 @@ func fetchClusterConfig(sctx sessionctx.Context, nodeTypes, nodeAddrs set.String return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("CONFIG") } serversInfo, err := infoschema.GetClusterServerInfo(sctx) - failpoint.Inject("mockClusterConfigServerInfo", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockClusterConfigServerInfo")); _err_ == nil { if s := val.(string); len(s) > 0 { // erase the error serversInfo, err = parseFailpointServerInfo(s), nil } - }) + } if err != nil { return nil, err } @@ -394,13 +394,13 @@ func (e *clusterLogRetriever) initialize(ctx context.Context, sctx sessionctx.Co return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("PROCESS") } serversInfo, err := infoschema.GetClusterServerInfo(sctx) - failpoint.Inject("mockClusterLogServerInfo", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockClusterLogServerInfo")); _err_ == nil { // erase the error err = nil if s := val.(string); len(s) > 0 { serversInfo = parseFailpointServerInfo(s) } - }) + } if err != nil { return nil, err } diff --git a/pkg/executor/memtable_reader.go__failpoint_stash__ b/pkg/executor/memtable_reader.go__failpoint_stash__ new file mode 100644 index 0000000000000..316439a638b99 --- /dev/null +++ b/pkg/executor/memtable_reader.go__failpoint_stash__ @@ -0,0 +1,1009 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
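+// NOTE: this __failpoint_stash__ file appears to be the untouched copy kept by the
+// failpoint tooling, while the instrumented memtable_reader.go above has its
+// failpoint.Inject markers expanded into failpoint.Eval(_curpkg_(...)) calls.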
+ +package executor + +import ( + "bytes" + "cmp" + "container/heap" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "slices" + "strings" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/diagnosticspb" + "github.com/pingcap/sysutil" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/store/helper" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/set" + pd "github.com/tikv/pd/client/http" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" +) + +const clusterLogBatchSize = 256 +const hotRegionsHistoryBatchSize = 256 + +type dummyCloser struct{} + +func (dummyCloser) close() error { return nil } + +func (dummyCloser) getRuntimeStats() execdetails.RuntimeStats { return nil } + +type memTableRetriever interface { + retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) + close() error + getRuntimeStats() execdetails.RuntimeStats +} + +// MemTableReaderExec executes memTable information retrieving from the MemTable components +type MemTableReaderExec struct { + exec.BaseExecutor + table *model.TableInfo + retriever memTableRetriever + // cacheRetrieved is used to indicate whether has the parent executor retrieved + // from inspection cache in inspection mode. + cacheRetrieved bool +} + +func (*MemTableReaderExec) isInspectionCacheableTable(tblName string) bool { + switch tblName { + case strings.ToLower(infoschema.TableClusterConfig), + strings.ToLower(infoschema.TableClusterInfo), + strings.ToLower(infoschema.TableClusterSystemInfo), + strings.ToLower(infoschema.TableClusterLoad), + strings.ToLower(infoschema.TableClusterHardware): + return true + default: + return false + } +} + +// Next implements the Executor Next interface. +func (e *MemTableReaderExec) Next(ctx context.Context, req *chunk.Chunk) error { + var ( + rows [][]types.Datum + err error + ) + + // The `InspectionTableCache` will be assigned in the begin of retrieving` and be + // cleaned at the end of retrieving, so nil represents currently in non-inspection mode. + if cache, tbl := e.Ctx().GetSessionVars().InspectionTableCache, e.table.Name.L; cache != nil && + e.isInspectionCacheableTable(tbl) { + // TODO: cached rows will be returned fully, we should refactor this part. + if !e.cacheRetrieved { + // Obtain data from cache first. 
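+			// Cache miss: run the retriever once and memoize both rows and err so
+			// repeated reads during the same inspection query see one consistent
+			// snapshot (including any error).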
+ cached, found := cache[tbl] + if !found { + rows, err := e.retriever.retrieve(ctx, e.Ctx()) + cached = variable.TableSnapshot{Rows: rows, Err: err} + cache[tbl] = cached + } + e.cacheRetrieved = true + rows, err = cached.Rows, cached.Err + } + } else { + rows, err = e.retriever.retrieve(ctx, e.Ctx()) + } + if err != nil { + return err + } + + if len(rows) == 0 { + req.Reset() + return nil + } + + req.GrowAndReset(len(rows)) + mutableRow := chunk.MutRowFromTypes(exec.RetTypes(e)) + for _, row := range rows { + mutableRow.SetDatums(row...) + req.AppendRow(mutableRow.ToRow()) + } + return nil +} + +// Close implements the Executor Close interface. +func (e *MemTableReaderExec) Close() error { + if stats := e.retriever.getRuntimeStats(); stats != nil && e.RuntimeStats() != nil { + defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), stats) + } + return e.retriever.close() +} + +type clusterConfigRetriever struct { + dummyCloser + retrieved bool + extractor *plannercore.ClusterTableExtractor +} + +// retrieve implements the memTableRetriever interface +func (e *clusterConfigRetriever) retrieve(_ context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { + if e.extractor.SkipRequest || e.retrieved { + return nil, nil + } + e.retrieved = true + return fetchClusterConfig(sctx, e.extractor.NodeTypes, e.extractor.Instances) +} + +func fetchClusterConfig(sctx sessionctx.Context, nodeTypes, nodeAddrs set.StringSet) ([][]types.Datum, error) { + type result struct { + idx int + rows [][]types.Datum + err error + } + if !hasPriv(sctx, mysql.ConfigPriv) { + return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("CONFIG") + } + serversInfo, err := infoschema.GetClusterServerInfo(sctx) + failpoint.Inject("mockClusterConfigServerInfo", func(val failpoint.Value) { + if s := val.(string); len(s) > 0 { + // erase the error + serversInfo, err = parseFailpointServerInfo(s), nil + } + }) + if err != nil { + return nil, err + } + serversInfo = infoschema.FilterClusterServerInfo(serversInfo, nodeTypes, nodeAddrs) + //nolint: prealloc + var finalRows [][]types.Datum + wg := sync.WaitGroup{} + ch := make(chan result, len(serversInfo)) + for i, srv := range serversInfo { + typ := srv.ServerType + address := srv.Address + statusAddr := srv.StatusAddr + if len(statusAddr) == 0 { + sctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("%s node %s does not contain status address", typ, address)) + continue + } + wg.Add(1) + go func(index int) { + util.WithRecovery(func() { + defer wg.Done() + var url string + switch typ { + case "pd": + url = fmt.Sprintf("%s://%s%s", util.InternalHTTPSchema(), statusAddr, pd.Config) + case "tikv", "tidb", "tiflash": + url = fmt.Sprintf("%s://%s/config", util.InternalHTTPSchema(), statusAddr) + case "tiproxy": + url = fmt.Sprintf("%s://%s/api/admin/config?format=json", util.InternalHTTPSchema(), statusAddr) + case "ticdc": + url = fmt.Sprintf("%s://%s/config", util.InternalHTTPSchema(), statusAddr) + case "tso": + url = fmt.Sprintf("%s://%s/tso/api/v1/config", util.InternalHTTPSchema(), statusAddr) + case "scheduling": + url = fmt.Sprintf("%s://%s/scheduling/api/v1/config", util.InternalHTTPSchema(), statusAddr) + default: + ch <- result{err: errors.Errorf("currently we do not support get config from node type: %s(%s)", typ, address)} + return + } + + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + ch <- result{err: errors.Trace(err)} + return + } + req.Header.Add("PD-Allow-follower-handle", "true") + 
resp, err := util.InternalHTTPClient().Do(req) + if err != nil { + ch <- result{err: errors.Trace(err)} + return + } + defer func() { + terror.Log(resp.Body.Close()) + }() + if resp.StatusCode != http.StatusOK { + ch <- result{err: errors.Errorf("request %s failed: %s", url, resp.Status)} + return + } + var nested map[string]any + if err = json.NewDecoder(resp.Body).Decode(&nested); err != nil { + ch <- result{err: errors.Trace(err)} + return + } + data := config.FlattenConfigItems(nested) + type item struct { + key string + val string + } + var items []item + for key, val := range data { + if config.ContainHiddenConfig(key) { + continue + } + var str string + switch val := val.(type) { + case string: // remove quotes + str = val + default: + tmp, err := json.Marshal(val) + if err != nil { + ch <- result{err: errors.Trace(err)} + return + } + str = string(tmp) + } + items = append(items, item{key: key, val: str}) + } + slices.SortFunc(items, func(i, j item) int { return cmp.Compare(i.key, j.key) }) + var rows [][]types.Datum + for _, item := range items { + rows = append(rows, types.MakeDatums( + typ, + address, + item.key, + item.val, + )) + } + ch <- result{idx: index, rows: rows} + }, nil) + }(i) + } + + wg.Wait() + close(ch) + + // Keep the original order to make the result more stable + var results []result //nolint: prealloc + for result := range ch { + if result.err != nil { + sctx.GetSessionVars().StmtCtx.AppendWarning(result.err) + continue + } + results = append(results, result) + } + slices.SortFunc(results, func(i, j result) int { return cmp.Compare(i.idx, j.idx) }) + for _, result := range results { + finalRows = append(finalRows, result.rows...) + } + return finalRows, nil +} + +type clusterServerInfoRetriever struct { + dummyCloser + extractor *plannercore.ClusterTableExtractor + serverInfoType diagnosticspb.ServerInfoType + retrieved bool +} + +// retrieve implements the memTableRetriever interface +func (e *clusterServerInfoRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { + switch e.serverInfoType { + case diagnosticspb.ServerInfoType_LoadInfo, + diagnosticspb.ServerInfoType_SystemInfo: + if !hasPriv(sctx, mysql.ProcessPriv) { + return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("PROCESS") + } + case diagnosticspb.ServerInfoType_HardwareInfo: + if !hasPriv(sctx, mysql.ConfigPriv) { + return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("CONFIG") + } + } + if e.extractor.SkipRequest || e.retrieved { + return nil, nil + } + e.retrieved = true + serversInfo, err := infoschema.GetClusterServerInfo(sctx) + if err != nil { + return nil, err + } + serversInfo = infoschema.FilterClusterServerInfo(serversInfo, e.extractor.NodeTypes, e.extractor.Instances) + return infoschema.FetchClusterServerInfoWithoutPrivilegeCheck(ctx, sctx.GetSessionVars(), serversInfo, e.serverInfoType, true) +} + +func parseFailpointServerInfo(s string) []infoschema.ServerInfo { + servers := strings.Split(s, ";") + serversInfo := make([]infoschema.ServerInfo, 0, len(servers)) + for _, server := range servers { + parts := strings.Split(server, ",") + serversInfo = append(serversInfo, infoschema.ServerInfo{ + StatusAddr: parts[2], + Address: parts[1], + ServerType: parts[0], + }) + } + return serversInfo +} + +type clusterLogRetriever struct { + isDrained bool + retrieving bool + heap *logResponseHeap + extractor *plannercore.ClusterLogTableExtractor + cancel context.CancelFunc +} + +type logStreamResult struct { + // Read the next 
stream result while current messages is drained + next chan logStreamResult + + addr string + typ string + messages []*diagnosticspb.LogMessage + err error +} + +type logResponseHeap []logStreamResult + +func (h logResponseHeap) Len() int { + return len(h) +} + +func (h logResponseHeap) Less(i, j int) bool { + if lhs, rhs := h[i].messages[0].Time, h[j].messages[0].Time; lhs != rhs { + return lhs < rhs + } + return h[i].typ < h[j].typ +} + +func (h logResponseHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +func (h *logResponseHeap) Push(x any) { + *h = append(*h, x.(logStreamResult)) +} + +func (h *logResponseHeap) Pop() any { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} + +func (e *clusterLogRetriever) initialize(ctx context.Context, sctx sessionctx.Context) ([]chan logStreamResult, error) { + if !hasPriv(sctx, mysql.ProcessPriv) { + return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("PROCESS") + } + serversInfo, err := infoschema.GetClusterServerInfo(sctx) + failpoint.Inject("mockClusterLogServerInfo", func(val failpoint.Value) { + // erase the error + err = nil + if s := val.(string); len(s) > 0 { + serversInfo = parseFailpointServerInfo(s) + } + }) + if err != nil { + return nil, err + } + + instances := e.extractor.Instances + nodeTypes := e.extractor.NodeTypes + serversInfo = infoschema.FilterClusterServerInfo(serversInfo, nodeTypes, instances) + + var levels = make([]diagnosticspb.LogLevel, 0, len(e.extractor.LogLevels)) + for l := range e.extractor.LogLevels { + levels = append(levels, sysutil.ParseLogLevel(l)) + } + + // To avoid search log interface overload, the user should specify the time range, and at least one pattern + // in normally SQL. + if e.extractor.StartTime == 0 { + return nil, errors.New("denied to scan logs, please specified the start time, such as `time > '2020-01-01 00:00:00'`") + } + if e.extractor.EndTime == 0 { + return nil, errors.New("denied to scan logs, please specified the end time, such as `time < '2020-01-01 00:00:00'`") + } + patterns := e.extractor.Patterns + if len(patterns) == 0 && len(levels) == 0 && len(instances) == 0 && len(nodeTypes) == 0 { + return nil, errors.New("denied to scan full logs (use `SELECT * FROM cluster_log WHERE message LIKE '%'` explicitly if intentionally)") + } + + req := &diagnosticspb.SearchLogRequest{ + StartTime: e.extractor.StartTime, + EndTime: e.extractor.EndTime, + Levels: levels, + Patterns: patterns, + } + + return e.startRetrieving(ctx, sctx, serversInfo, req) +} + +func (e *clusterLogRetriever) startRetrieving( + ctx context.Context, + sctx sessionctx.Context, + serversInfo []infoschema.ServerInfo, + req *diagnosticspb.SearchLogRequest) ([]chan logStreamResult, error) { + // gRPC options + opt := grpc.WithTransportCredentials(insecure.NewCredentials()) + security := config.GetGlobalConfig().Security + if len(security.ClusterSSLCA) != 0 { + clusterSecurity := security.ClusterSecurity() + tlsConfig, err := clusterSecurity.ToTLSConfig() + if err != nil { + return nil, errors.Trace(err) + } + opt = grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)) + } + + // The retrieve progress may be abort + ctx, e.cancel = context.WithCancel(ctx) + + var results []chan logStreamResult //nolint: prealloc + for _, srv := range serversInfo { + typ := srv.ServerType + address := srv.Address + statusAddr := srv.StatusAddr + if len(statusAddr) == 0 { + sctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("%s node %s does not contain status address", typ, 
address)) + continue + } + ch := make(chan logStreamResult) + results = append(results, ch) + + go func(ch chan logStreamResult, serverType, address, statusAddr string) { + util.WithRecovery(func() { + defer close(ch) + + // TiDB and TiProxy provide diagnostics service via status address + remote := address + if serverType == "tidb" || serverType == "tiproxy" { + remote = statusAddr + } + conn, err := grpc.Dial(remote, opt) + if err != nil { + ch <- logStreamResult{addr: address, typ: serverType, err: err} + return + } + defer terror.Call(conn.Close) + + cli := diagnosticspb.NewDiagnosticsClient(conn) + stream, err := cli.SearchLog(ctx, req) + if err != nil { + ch <- logStreamResult{addr: address, typ: serverType, err: err} + return + } + + for { + res, err := stream.Recv() + if err != nil && err == io.EOF { + return + } + if err != nil { + select { + case ch <- logStreamResult{addr: address, typ: serverType, err: err}: + case <-ctx.Done(): + } + return + } + + result := logStreamResult{next: ch, addr: address, typ: serverType, messages: res.Messages} + select { + case ch <- result: + case <-ctx.Done(): + return + } + } + }, nil) + }(ch, typ, address, statusAddr) + } + + return results, nil +} + +func (e *clusterLogRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { + if e.extractor.SkipRequest || e.isDrained { + return nil, nil + } + + if !e.retrieving { + e.retrieving = true + results, err := e.initialize(ctx, sctx) + if err != nil { + e.isDrained = true + return nil, err + } + + // initialize the heap + e.heap = &logResponseHeap{} + for _, ch := range results { + result := <-ch + if result.err != nil || len(result.messages) == 0 { + if result.err != nil { + sctx.GetSessionVars().StmtCtx.AppendWarning(result.err) + } + continue + } + *e.heap = append(*e.heap, result) + } + heap.Init(e.heap) + } + + // Merge the results + var finalRows [][]types.Datum + for e.heap.Len() > 0 && len(finalRows) < clusterLogBatchSize { + minTimeItem := heap.Pop(e.heap).(logStreamResult) + headMessage := minTimeItem.messages[0] + loggingTime := time.UnixMilli(headMessage.Time) + finalRows = append(finalRows, types.MakeDatums( + loggingTime.Format("2006/01/02 15:04:05.000"), + minTimeItem.typ, + minTimeItem.addr, + strings.ToUpper(headMessage.Level.String()), + headMessage.Message, + )) + minTimeItem.messages = minTimeItem.messages[1:] + // Current streaming result is drained, read the next to supply. 
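+		// The heap is ordered by each stream's earliest message timestamp, so popping
+		// one message at a time and re-pushing the stream keeps the merged output in
+		// global time order while holding at most one pending batch per node.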
+ if len(minTimeItem.messages) == 0 { + result := <-minTimeItem.next + if result.err != nil { + sctx.GetSessionVars().StmtCtx.AppendWarning(result.err) + continue + } + if len(result.messages) > 0 { + heap.Push(e.heap, result) + } + } else { + heap.Push(e.heap, minTimeItem) + } + } + + // All streams are drained + e.isDrained = e.heap.Len() == 0 + + return finalRows, nil +} + +func (e *clusterLogRetriever) close() error { + if e.cancel != nil { + e.cancel() + } + return nil +} + +func (*clusterLogRetriever) getRuntimeStats() execdetails.RuntimeStats { + return nil +} + +type hotRegionsResult struct { + addr string + messages *HistoryHotRegions + err error +} + +type hotRegionsResponseHeap []hotRegionsResult + +func (h hotRegionsResponseHeap) Len() int { + return len(h) +} + +func (h hotRegionsResponseHeap) Less(i, j int) bool { + lhs, rhs := h[i].messages.HistoryHotRegion[0], h[j].messages.HistoryHotRegion[0] + if lhs.UpdateTime != rhs.UpdateTime { + return lhs.UpdateTime < rhs.UpdateTime + } + return lhs.HotDegree < rhs.HotDegree +} + +func (h hotRegionsResponseHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +func (h *hotRegionsResponseHeap) Push(x any) { + *h = append(*h, x.(hotRegionsResult)) +} + +func (h *hotRegionsResponseHeap) Pop() any { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} + +type hotRegionsHistoryRetriver struct { + dummyCloser + isDrained bool + retrieving bool + heap *hotRegionsResponseHeap + extractor *plannercore.HotRegionsHistoryTableExtractor +} + +// HistoryHotRegionsRequest wrap conditions push down to PD. +type HistoryHotRegionsRequest struct { + StartTime int64 `json:"start_time,omitempty"` + EndTime int64 `json:"end_time,omitempty"` + RegionIDs []uint64 `json:"region_ids,omitempty"` + StoreIDs []uint64 `json:"store_ids,omitempty"` + PeerIDs []uint64 `json:"peer_ids,omitempty"` + IsLearners []bool `json:"is_learners,omitempty"` + IsLeaders []bool `json:"is_leaders,omitempty"` + HotRegionTypes []string `json:"hot_region_type,omitempty"` +} + +// HistoryHotRegions records filtered hot regions stored in each PD. +// it's the response of PD. +type HistoryHotRegions struct { + HistoryHotRegion []*HistoryHotRegion `json:"history_hot_region"` +} + +// HistoryHotRegion records each hot region's statistics. +// it's the response of PD. +type HistoryHotRegion struct { + UpdateTime int64 `json:"update_time"` + RegionID uint64 `json:"region_id"` + StoreID uint64 `json:"store_id"` + PeerID uint64 `json:"peer_id"` + IsLearner bool `json:"is_learner"` + IsLeader bool `json:"is_leader"` + HotRegionType string `json:"hot_region_type"` + HotDegree int64 `json:"hot_degree"` + FlowBytes float64 `json:"flow_bytes"` + KeyRate float64 `json:"key_rate"` + QueryRate float64 `json:"query_rate"` + StartKey string `json:"start_key"` + EndKey string `json:"end_key"` +} + +func (e *hotRegionsHistoryRetriver) initialize(_ context.Context, sctx sessionctx.Context) ([]chan hotRegionsResult, error) { + if !hasPriv(sctx, mysql.ProcessPriv) { + return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("PROCESS") + } + pdServers, err := infoschema.GetPDServerInfo(sctx) + if err != nil { + return nil, err + } + + // To avoid search hot regions interface overload, the user should specify the time range in normally SQL. 
+ if e.extractor.StartTime == 0 { + return nil, errors.New("denied to scan hot regions, please specified the start time, such as `update_time > '2020-01-01 00:00:00'`") + } + if e.extractor.EndTime == 0 { + return nil, errors.New("denied to scan hot regions, please specified the end time, such as `update_time < '2020-01-01 00:00:00'`") + } + + historyHotRegionsRequest := &HistoryHotRegionsRequest{ + StartTime: e.extractor.StartTime, + EndTime: e.extractor.EndTime, + RegionIDs: e.extractor.RegionIDs, + StoreIDs: e.extractor.StoreIDs, + PeerIDs: e.extractor.PeerIDs, + IsLearners: e.extractor.IsLearners, + IsLeaders: e.extractor.IsLeaders, + } + + return e.startRetrieving(pdServers, historyHotRegionsRequest) +} + +func (e *hotRegionsHistoryRetriver) startRetrieving( + pdServers []infoschema.ServerInfo, + req *HistoryHotRegionsRequest, +) ([]chan hotRegionsResult, error) { + var results []chan hotRegionsResult + for _, srv := range pdServers { + for typ := range e.extractor.HotRegionTypes { + req.HotRegionTypes = []string{typ} + jsonBody, err := json.Marshal(req) + if err != nil { + return nil, err + } + body := bytes.NewBuffer(jsonBody) + ch := make(chan hotRegionsResult) + results = append(results, ch) + go func(ch chan hotRegionsResult, address string, body *bytes.Buffer) { + util.WithRecovery(func() { + defer close(ch) + url := fmt.Sprintf("%s://%s%s", util.InternalHTTPSchema(), address, pd.HotHistory) + req, err := http.NewRequest(http.MethodGet, url, body) + if err != nil { + ch <- hotRegionsResult{err: errors.Trace(err)} + return + } + req.Header.Add("PD-Allow-follower-handle", "true") + resp, err := util.InternalHTTPClient().Do(req) + if err != nil { + ch <- hotRegionsResult{err: errors.Trace(err)} + return + } + defer func() { + terror.Log(resp.Body.Close()) + }() + if resp.StatusCode != http.StatusOK { + ch <- hotRegionsResult{err: errors.Errorf("request %s failed: %s", url, resp.Status)} + return + } + var historyHotRegions HistoryHotRegions + if err = json.NewDecoder(resp.Body).Decode(&historyHotRegions); err != nil { + ch <- hotRegionsResult{err: errors.Trace(err)} + return + } + ch <- hotRegionsResult{addr: address, messages: &historyHotRegions} + }, nil) + }(ch, srv.StatusAddr, body) + } + } + return results, nil +} + +func (e *hotRegionsHistoryRetriver) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { + if e.extractor.SkipRequest || e.isDrained { + return nil, nil + } + + if !e.retrieving { + e.retrieving = true + results, err := e.initialize(ctx, sctx) + if err != nil { + e.isDrained = true + return nil, err + } + // Initialize the heap + e.heap = &hotRegionsResponseHeap{} + for _, ch := range results { + result := <-ch + if result.err != nil || len(result.messages.HistoryHotRegion) == 0 { + if result.err != nil { + sctx.GetSessionVars().StmtCtx.AppendWarning(result.err) + } + continue + } + *e.heap = append(*e.heap, result) + } + heap.Init(e.heap) + } + // Merge the results + var finalRows [][]types.Datum + tikvStore, ok := sctx.GetStore().(helper.Storage) + if !ok { + return nil, errors.New("Information about hot region can be gotten only when the storage is TiKV") + } + tikvHelper := &helper.Helper{ + Store: tikvStore, + RegionCache: tikvStore.GetRegionCache(), + } + tz := sctx.GetSessionVars().Location() + is := sessiontxn.GetTxnManager(sctx).GetTxnInfoSchema() + tables := tikvHelper.GetTablesInfoWithKeyRange(is, tikvHelper.FilterMemDBs) + for e.heap.Len() > 0 && len(finalRows) < hotRegionsHistoryBatchSize { + minTimeItem := 
heap.Pop(e.heap).(hotRegionsResult) + rows, err := e.getHotRegionRowWithSchemaInfo(minTimeItem.messages.HistoryHotRegion[0], tikvHelper, tables, tz) + if err != nil { + return nil, err + } + if rows != nil { + finalRows = append(finalRows, rows...) + } + minTimeItem.messages.HistoryHotRegion = minTimeItem.messages.HistoryHotRegion[1:] + // Fetch next message item + if len(minTimeItem.messages.HistoryHotRegion) != 0 { + heap.Push(e.heap, minTimeItem) + } + } + // All streams are drained + e.isDrained = e.heap.Len() == 0 + return finalRows, nil +} + +func (*hotRegionsHistoryRetriver) getHotRegionRowWithSchemaInfo( + hisHotRegion *HistoryHotRegion, + tikvHelper *helper.Helper, + tables []helper.TableInfoWithKeyRange, + tz *time.Location, +) ([][]types.Datum, error) { + regionsInfo := []*pd.RegionInfo{ + { + ID: int64(hisHotRegion.RegionID), + StartKey: hisHotRegion.StartKey, + EndKey: hisHotRegion.EndKey, + }} + regionsTableInfos := tikvHelper.ParseRegionsTableInfos(regionsInfo, tables) + + var rows [][]types.Datum + // Ignore row without corresponding schema. + if tableInfos, ok := regionsTableInfos[int64(hisHotRegion.RegionID)]; ok { + for _, tableInfo := range tableInfos { + updateTimestamp := time.UnixMilli(hisHotRegion.UpdateTime) + if updateTimestamp.Location() != tz { + updateTimestamp.In(tz) + } + updateTime := types.NewTime(types.FromGoTime(updateTimestamp), mysql.TypeTimestamp, types.MinFsp) + row := make([]types.Datum, len(infoschema.GetTableTiDBHotRegionsHistoryCols())) + row[0].SetMysqlTime(updateTime) + row[1].SetString(strings.ToUpper(tableInfo.DB.Name.O), mysql.DefaultCollationName) + row[2].SetString(strings.ToUpper(tableInfo.Table.Name.O), mysql.DefaultCollationName) + row[3].SetInt64(tableInfo.Table.ID) + if tableInfo.IsIndex { + row[4].SetString(strings.ToUpper(tableInfo.Index.Name.O), mysql.DefaultCollationName) + row[5].SetInt64(tableInfo.Index.ID) + } else { + row[4].SetNull() + row[5].SetNull() + } + row[6].SetInt64(int64(hisHotRegion.RegionID)) + row[7].SetInt64(int64(hisHotRegion.StoreID)) + row[8].SetInt64(int64(hisHotRegion.PeerID)) + if hisHotRegion.IsLearner { + row[9].SetInt64(1) + } else { + row[9].SetInt64(0) + } + if hisHotRegion.IsLeader { + row[10].SetInt64(1) + } else { + row[10].SetInt64(0) + } + row[11].SetString(strings.ToUpper(hisHotRegion.HotRegionType), mysql.DefaultCollationName) + row[12].SetInt64(hisHotRegion.HotDegree) + row[13].SetFloat64(hisHotRegion.FlowBytes) + row[14].SetFloat64(hisHotRegion.KeyRate) + row[15].SetFloat64(hisHotRegion.QueryRate) + rows = append(rows, row) + } + } + + return rows, nil +} + +type tikvRegionPeersRetriever struct { + dummyCloser + extractor *plannercore.TikvRegionPeersExtractor + retrieved bool +} + +func (e *tikvRegionPeersRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { + if e.extractor.SkipRequest || e.retrieved { + return nil, nil + } + e.retrieved = true + tikvStore, ok := sctx.GetStore().(helper.Storage) + if !ok { + return nil, errors.New("Information about hot region can be gotten only when the storage is TiKV") + } + tikvHelper := &helper.Helper{ + Store: tikvStore, + RegionCache: tikvStore.GetRegionCache(), + } + pdCli, err := tikvHelper.TryGetPDHTTPClient() + if err != nil { + return nil, err + } + + var regionsInfo, regionsInfoByStoreID []pd.RegionInfo + regionMap := make(map[int64]*pd.RegionInfo) + storeMap := make(map[int64]struct{}) + + if len(e.extractor.StoreIDs) == 0 && len(e.extractor.RegionIDs) == 0 { + regionsInfo, err := pdCli.GetRegions(ctx) + if 
err != nil { + return nil, err + } + return e.packTiKVRegionPeersRows(regionsInfo.Regions, storeMap) + } + + for _, storeID := range e.extractor.StoreIDs { + // if a region_id located in 1, 4, 7 store we will get all of them when request any store_id, + // storeMap is used to filter peers on unexpected stores. + storeMap[int64(storeID)] = struct{}{} + storeRegionsInfo, err := pdCli.GetRegionsByStoreID(ctx, storeID) + if err != nil { + return nil, err + } + for i, regionInfo := range storeRegionsInfo.Regions { + // regionMap is used to remove dup regions and record the region in regionsInfoByStoreID. + if _, ok := regionMap[regionInfo.ID]; !ok { + regionsInfoByStoreID = append(regionsInfoByStoreID, regionInfo) + regionMap[regionInfo.ID] = &storeRegionsInfo.Regions[i] + } + } + } + + if len(e.extractor.RegionIDs) == 0 { + return e.packTiKVRegionPeersRows(regionsInfoByStoreID, storeMap) + } + + for _, regionID := range e.extractor.RegionIDs { + regionInfoByStoreID, ok := regionMap[int64(regionID)] + if !ok { + // if there is storeIDs, target region_id is fetched by storeIDs, + // otherwise we need to fetch it from PD. + if len(e.extractor.StoreIDs) == 0 { + regionInfo, err := pdCli.GetRegionByID(ctx, regionID) + if err != nil { + return nil, err + } + regionsInfo = append(regionsInfo, *regionInfo) + } + } else { + regionsInfo = append(regionsInfo, *regionInfoByStoreID) + } + } + + return e.packTiKVRegionPeersRows(regionsInfo, storeMap) +} + +func (e *tikvRegionPeersRetriever) isUnexpectedStoreID(storeID int64, storeMap map[int64]struct{}) bool { + if len(e.extractor.StoreIDs) == 0 { + return false + } + if _, ok := storeMap[storeID]; ok { + return false + } + return true +} + +func (e *tikvRegionPeersRetriever) packTiKVRegionPeersRows( + regionsInfo []pd.RegionInfo, storeMap map[int64]struct{}) ([][]types.Datum, error) { + //nolint: prealloc + var rows [][]types.Datum + for _, region := range regionsInfo { + records := make([][]types.Datum, 0, len(region.Peers)) + pendingPeerIDSet := set.NewInt64Set() + for _, peer := range region.PendingPeers { + pendingPeerIDSet.Insert(peer.ID) + } + downPeerMap := make(map[int64]int64, len(region.DownPeers)) + for _, peerStat := range region.DownPeers { + downPeerMap[peerStat.Peer.ID] = peerStat.DownSec + } + for _, peer := range region.Peers { + // isUnexpectedStoreID return true if we should filter this peer. + if e.isUnexpectedStoreID(peer.StoreID, storeMap) { + continue + } + + row := make([]types.Datum, len(infoschema.GetTableTiKVRegionPeersCols())) + row[0].SetInt64(region.ID) + row[1].SetInt64(peer.ID) + row[2].SetInt64(peer.StoreID) + if peer.IsLearner { + row[3].SetInt64(1) + } else { + row[3].SetInt64(0) + } + if peer.ID == region.Leader.ID { + row[4].SetInt64(1) + } else { + row[4].SetInt64(0) + } + if downSec, ok := downPeerMap[peer.ID]; ok { + row[5].SetString(downPeer, mysql.DefaultCollationName) + row[6].SetInt64(downSec) + } else if pendingPeerIDSet.Exist(peer.ID) { + row[5].SetString(pendingPeer, mysql.DefaultCollationName) + } else { + row[5].SetString(normalPeer, mysql.DefaultCollationName) + } + records = append(records, row) + } + rows = append(rows, records...) 
+ } + return rows, nil +} diff --git a/pkg/executor/metrics_reader.go b/pkg/executor/metrics_reader.go index 31a0073148584..3bfaa6907d568 100644 --- a/pkg/executor/metrics_reader.go +++ b/pkg/executor/metrics_reader.go @@ -57,12 +57,12 @@ func (e *MetricRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) } e.retrieved = true - failpoint.InjectContext(ctx, "mockMetricsTableData", func() { + if _, _err_ := failpoint.EvalContext(ctx, _curpkg_("mockMetricsTableData")); _err_ == nil { m, ok := ctx.Value("__mockMetricsTableData").(map[string][][]types.Datum) if ok && m[e.table.Name.L] != nil { - failpoint.Return(m[e.table.Name.L], nil) + return m[e.table.Name.L], nil } - }) + } tblDef, err := infoschema.GetMetricTableDef(e.table.Name.L) if err != nil { @@ -94,9 +94,9 @@ func (e *MetricRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) type MockMetricsPromDataKey struct{} func (e *MetricRetriever) queryMetric(ctx context.Context, sctx sessionctx.Context, queryRange promv1.Range, quantile float64) (result pmodel.Value, err error) { - failpoint.InjectContext(ctx, "mockMetricsPromData", func() { - failpoint.Return(ctx.Value(MockMetricsPromDataKey{}).(pmodel.Matrix), nil) - }) + if _, _err_ := failpoint.EvalContext(ctx, _curpkg_("mockMetricsPromData")); _err_ == nil { + return ctx.Value(MockMetricsPromDataKey{}).(pmodel.Matrix), nil + } // Add retry to avoid network error. var prometheusAddr string diff --git a/pkg/executor/metrics_reader.go__failpoint_stash__ b/pkg/executor/metrics_reader.go__failpoint_stash__ new file mode 100644 index 0000000000000..31a0073148584 --- /dev/null +++ b/pkg/executor/metrics_reader.go__failpoint_stash__ @@ -0,0 +1,365 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "context" + "fmt" + "math" + "slices" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + plannerutil "github.com/pingcap/tidb/pkg/planner/util" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" + "github.com/prometheus/client_golang/api" + promv1 "github.com/prometheus/client_golang/api/prometheus/v1" + pmodel "github.com/prometheus/common/model" +) + +const promReadTimeout = time.Second * 10 + +// MetricRetriever uses to read metric data. 
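+// It turns a read of a METRICS_SCHEMA table into a PromQL range query (built by
+// the table definition's GenPromQL) against the Prometheus address resolved via
+// infosync.GetPrometheusAddr, and converts the returned matrix into datum rows.
+// Tests bypass Prometheus through the "mockMetricsTableData" failpoint, roughly
+// as sketched below (the failpoint path and the mockRows helper are assumptions,
+// not part of this file):
+//
+//	_ = failpoint.Enable("github.com/pingcap/tidb/pkg/executor/mockMetricsTableData", "return")
+//	ctx := failpoint.WithHook(context.Background(), func(_ context.Context, _ string) bool { return true })
+//	ctx = context.WithValue(ctx, "__mockMetricsTableData", mockRows) // map[string][][]types.Datum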
+type MetricRetriever struct { + dummyCloser + table *model.TableInfo + tblDef *infoschema.MetricTableDef + extractor *plannercore.MetricTableExtractor + retrieved bool +} + +func (e *MetricRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { + if e.retrieved || e.extractor.SkipRequest { + return nil, nil + } + e.retrieved = true + + failpoint.InjectContext(ctx, "mockMetricsTableData", func() { + m, ok := ctx.Value("__mockMetricsTableData").(map[string][][]types.Datum) + if ok && m[e.table.Name.L] != nil { + failpoint.Return(m[e.table.Name.L], nil) + } + }) + + tblDef, err := infoschema.GetMetricTableDef(e.table.Name.L) + if err != nil { + return nil, err + } + e.tblDef = tblDef + queryRange := e.getQueryRange(sctx) + totalRows := make([][]types.Datum, 0) + quantiles := e.extractor.Quantiles + if len(quantiles) == 0 { + quantiles = []float64{tblDef.Quantile} + } + for _, quantile := range quantiles { + var queryValue pmodel.Value + queryValue, err = e.queryMetric(ctx, sctx, queryRange, quantile) + if err != nil { + if err1, ok := err.(*promv1.Error); ok { + return nil, errors.Errorf("query metric error, msg: %v, detail: %v", err1.Msg, err1.Detail) + } + return nil, errors.Errorf("query metric error: %v", err.Error()) + } + partRows := e.genRows(queryValue, quantile) + totalRows = append(totalRows, partRows...) + } + return totalRows, nil +} + +// MockMetricsPromDataKey is for test +type MockMetricsPromDataKey struct{} + +func (e *MetricRetriever) queryMetric(ctx context.Context, sctx sessionctx.Context, queryRange promv1.Range, quantile float64) (result pmodel.Value, err error) { + failpoint.InjectContext(ctx, "mockMetricsPromData", func() { + failpoint.Return(ctx.Value(MockMetricsPromDataKey{}).(pmodel.Matrix), nil) + }) + + // Add retry to avoid network error. + var prometheusAddr string + for i := 0; i < 5; i++ { + //TODO: the prometheus will be Integrated into the PD, then we need to query the prometheus in PD directly, which need change the quire API + prometheusAddr, err = infosync.GetPrometheusAddr() + if err == nil || err == infosync.ErrPrometheusAddrIsNotSet { + break + } + time.Sleep(100 * time.Millisecond) + } + if err != nil { + return nil, err + } + promClient, err := api.NewClient(api.Config{ + Address: prometheusAddr, + }) + if err != nil { + return nil, err + } + promQLAPI := promv1.NewAPI(promClient) + ctx, cancel := context.WithTimeout(ctx, promReadTimeout) + defer cancel() + promQL := e.tblDef.GenPromQL(sctx.GetSessionVars().MetricSchemaRangeDuration, e.extractor.LabelConditions, quantile) + + // Add retry to avoid network error. 
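+	// Up to five attempts with a 100ms pause between them; the error from the
+	// final attempt is returned if none succeeds.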
+ for i := 0; i < 5; i++ { + result, _, err = promQLAPI.QueryRange(ctx, promQL, queryRange) + if err == nil { + break + } + time.Sleep(100 * time.Millisecond) + } + return result, err +} + +type promQLQueryRange = promv1.Range + +func (e *MetricRetriever) getQueryRange(sctx sessionctx.Context) promQLQueryRange { + startTime, endTime := e.extractor.StartTime, e.extractor.EndTime + step := time.Second * time.Duration(sctx.GetSessionVars().MetricSchemaStep) + return promQLQueryRange{Start: startTime, End: endTime, Step: step} +} + +func (e *MetricRetriever) genRows(value pmodel.Value, quantile float64) [][]types.Datum { + var rows [][]types.Datum + if value.Type() == pmodel.ValMatrix { + matrix := value.(pmodel.Matrix) + for _, m := range matrix { + for _, v := range m.Values { + record := e.genRecord(m.Metric, v, quantile) + rows = append(rows, record) + } + } + } + return rows +} + +func (e *MetricRetriever) genRecord(metric pmodel.Metric, pair pmodel.SamplePair, quantile float64) []types.Datum { + record := make([]types.Datum, 0, 2+len(e.tblDef.Labels)+1) + // Record order should keep same with genColumnInfos. + record = append(record, types.NewTimeDatum(types.NewTime( + types.FromGoTime(time.UnixMilli(int64(pair.Timestamp))), + mysql.TypeDatetime, + types.MaxFsp, + ))) + for _, label := range e.tblDef.Labels { + v := "" + if metric != nil { + v = string(metric[pmodel.LabelName(label)]) + } + if len(v) == 0 { + v = infoschema.GenLabelConditionValues(e.extractor.LabelConditions[strings.ToLower(label)]) + } + record = append(record, types.NewStringDatum(v)) + } + if e.tblDef.Quantile > 0 { + record = append(record, types.NewFloat64Datum(quantile)) + } + if math.IsNaN(float64(pair.Value)) { + record = append(record, types.NewDatum(nil)) + } else { + record = append(record, types.NewFloat64Datum(float64(pair.Value))) + } + return record +} + +// MetricsSummaryRetriever uses to read metric data. 
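+// For every table in infoschema.MetricTableMap that passes the name filter, it
+// issues an internal SQL query aggregating sum/avg/min/max of `value` over the
+// requested time range (per quantile where the table defines one) and returns
+// one summary row per aggregate.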
+type MetricsSummaryRetriever struct { + dummyCloser + table *model.TableInfo + extractor *plannercore.MetricSummaryTableExtractor + timeRange plannerutil.QueryTimeRange + retrieved bool +} + +func (e *MetricsSummaryRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { + if !hasPriv(sctx, mysql.ProcessPriv) { + return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("PROCESS") + } + if e.retrieved || e.extractor.SkipRequest { + return nil, nil + } + e.retrieved = true + totalRows := make([][]types.Datum, 0, len(infoschema.MetricTableMap)) + tables := make([]string, 0, len(infoschema.MetricTableMap)) + for name := range infoschema.MetricTableMap { + tables = append(tables, name) + } + slices.Sort(tables) + + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnOthers) + filter := inspectionFilter{set: e.extractor.MetricsNames} + condition := e.timeRange.Condition() + for _, name := range tables { + if !filter.enable(name) { + continue + } + def, found := infoschema.MetricTableMap[name] + if !found { + sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("metrics table: %s not found", name)) + continue + } + var sql string + if def.Quantile > 0 { + var qs []string + if len(e.extractor.Quantiles) > 0 { + for _, q := range e.extractor.Quantiles { + qs = append(qs, fmt.Sprintf("%f", q)) + } + } else { + qs = []string{"0.99"} + } + sql = fmt.Sprintf("select sum(value),avg(value),min(value),max(value),quantile from `%[2]s`.`%[1]s` %[3]s and quantile in (%[4]s) group by quantile order by quantile", + name, util.MetricSchemaName.L, condition, strings.Join(qs, ",")) + } else { + sql = fmt.Sprintf("select sum(value),avg(value),min(value),max(value) from `%[2]s`.`%[1]s` %[3]s", + name, util.MetricSchemaName.L, condition) + } + + exec := sctx.GetRestrictedSQLExecutor() + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql) + if err != nil { + return nil, errors.Errorf("execute '%s' failed: %v", sql, err) + } + for _, row := range rows { + var quantile any + if def.Quantile > 0 { + quantile = row.GetFloat64(row.Len() - 1) + } + totalRows = append(totalRows, types.MakeDatums( + name, + quantile, + row.GetFloat64(0), + row.GetFloat64(1), + row.GetFloat64(2), + row.GetFloat64(3), + def.Comment, + )) + } + } + return totalRows, nil +} + +// MetricsSummaryByLabelRetriever uses to read metric detail data. 
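+// It computes the same sum/avg/min/max aggregation as MetricsSummaryRetriever,
+// but grouped by each metric's label columns (and by quantile where defined),
+// reporting the instance label as a separate column.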
+type MetricsSummaryByLabelRetriever struct { + dummyCloser + table *model.TableInfo + extractor *plannercore.MetricSummaryTableExtractor + timeRange plannerutil.QueryTimeRange + retrieved bool +} + +func (e *MetricsSummaryByLabelRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { + if !hasPriv(sctx, mysql.ProcessPriv) { + return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("PROCESS") + } + if e.retrieved || e.extractor.SkipRequest { + return nil, nil + } + e.retrieved = true + totalRows := make([][]types.Datum, 0, len(infoschema.MetricTableMap)) + tables := make([]string, 0, len(infoschema.MetricTableMap)) + for name := range infoschema.MetricTableMap { + tables = append(tables, name) + } + slices.Sort(tables) + + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnOthers) + filter := inspectionFilter{set: e.extractor.MetricsNames} + condition := e.timeRange.Condition() + for _, name := range tables { + if !filter.enable(name) { + continue + } + def, found := infoschema.MetricTableMap[name] + if !found { + sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("metrics table: %s not found", name)) + continue + } + cols := def.Labels + cond := condition + if def.Quantile > 0 { + cols = append(cols, "quantile") + if len(e.extractor.Quantiles) > 0 { + qs := make([]string, len(e.extractor.Quantiles)) + for i, q := range e.extractor.Quantiles { + qs[i] = fmt.Sprintf("%f", q) + } + cond += " and quantile in (" + strings.Join(qs, ",") + ")" + } else { + cond += " and quantile=0.99" + } + } + var sql string + if len(cols) > 0 { + sql = fmt.Sprintf("select sum(value),avg(value),min(value),max(value),`%s` from `%s`.`%s` %s group by `%[1]s` order by `%[1]s`", + strings.Join(cols, "`,`"), util.MetricSchemaName.L, name, cond) + } else { + sql = fmt.Sprintf("select sum(value),avg(value),min(value),max(value) from `%s`.`%s` %s", + util.MetricSchemaName.L, name, cond) + } + exec := sctx.GetRestrictedSQLExecutor() + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql) + if err != nil { + return nil, errors.Errorf("execute '%s' failed: %v", sql, err) + } + nonInstanceLabelIndex := 0 + if len(def.Labels) > 0 && def.Labels[0] == "instance" { + nonInstanceLabelIndex = 1 + } + // skip sum/avg/min/max + const skipCols = 4 + for _, row := range rows { + instance := "" + if nonInstanceLabelIndex > 0 { + instance = row.GetString(skipCols) // sum/avg/min/max + } + var labels []string + for i, label := range def.Labels[nonInstanceLabelIndex:] { + // skip min/max/avg/instance + val := row.GetString(skipCols + nonInstanceLabelIndex + i) + if label == "store" || label == "store_id" { + val = fmt.Sprintf("store_id:%s", val) + } + labels = append(labels, val) + } + var quantile any + if def.Quantile > 0 { + quantile = row.GetFloat64(row.Len() - 1) // quantile will be the last column + } + totalRows = append(totalRows, types.MakeDatums( + instance, + name, + strings.Join(labels, ", "), + quantile, + row.GetFloat64(0), // sum + row.GetFloat64(1), // avg + row.GetFloat64(2), // min + row.GetFloat64(3), // max + def.Comment, + )) + } + } + return totalRows, nil +} diff --git a/pkg/executor/parallel_apply.go b/pkg/executor/parallel_apply.go index 03820652df933..f855b991e72a5 100644 --- a/pkg/executor/parallel_apply.go +++ b/pkg/executor/parallel_apply.go @@ -208,7 +208,7 @@ func (e *ParallelNestedLoopApplyExec) outerWorker(ctx context.Context) { var selected []bool var err error for { - failpoint.Inject("parallelApplyOuterWorkerPanic", nil) + 
failpoint.Eval(_curpkg_("parallelApplyOuterWorkerPanic")) chk := exec.TryNewCacheChunk(e.outerExec) if err := exec.Next(ctx, e.outerExec, chk); err != nil { e.putResult(nil, err) @@ -246,7 +246,7 @@ func (e *ParallelNestedLoopApplyExec) innerWorker(ctx context.Context, id int) { case <-e.exit: return } - failpoint.Inject("parallelApplyInnerWorkerPanic", nil) + failpoint.Eval(_curpkg_("parallelApplyInnerWorkerPanic")) err := e.fillInnerChunk(ctx, id, chk) if err == nil && chk.NumRows() == 0 { // no more data, this goroutine can exit return @@ -292,7 +292,7 @@ func (e *ParallelNestedLoopApplyExec) fetchAllInners(ctx context.Context, id int } if e.useCache { // look up the cache atomic.AddInt64(&e.cacheAccessCounter, 1) - failpoint.Inject("parallelApplyGetCachePanic", nil) + failpoint.Eval(_curpkg_("parallelApplyGetCachePanic")) value, err := e.cache.Get(key) if err != nil { return err @@ -339,7 +339,7 @@ func (e *ParallelNestedLoopApplyExec) fetchAllInners(ctx context.Context, id int } if e.useCache { // update the cache - failpoint.Inject("parallelApplySetCachePanic", nil) + failpoint.Eval(_curpkg_("parallelApplySetCachePanic")) if _, err := e.cache.Set(key, e.innerList[id]); err != nil { return err } diff --git a/pkg/executor/parallel_apply.go__failpoint_stash__ b/pkg/executor/parallel_apply.go__failpoint_stash__ new file mode 100644 index 0000000000000..03820652df933 --- /dev/null +++ b/pkg/executor/parallel_apply.go__failpoint_stash__ @@ -0,0 +1,405 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "context" + "runtime/trace" + "sync" + "sync/atomic" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/executor/internal/applycache" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/executor/join" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "go.uber.org/zap" +) + +type result struct { + chk *chunk.Chunk + err error +} + +type outerRow struct { + row *chunk.Row + selected bool // if this row is selected by the outer side +} + +// ParallelNestedLoopApplyExec is the executor for apply. 
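+// One outer worker feeds qualifying outer rows into outerRowCh while
+// `concurrency` inner workers each rebuild the inner side per outer row,
+// optionally consulting the shared apply cache keyed by the encoded correlated
+// column values, and deliver joined chunks through resultChkCh.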
+type ParallelNestedLoopApplyExec struct { + exec.BaseExecutor + + // outer-side fields + outerExec exec.Executor + outerFilter expression.CNFExprs + outerList *chunk.List + outer bool + + // inner-side fields + // use slices since the inner side is paralleled + corCols [][]*expression.CorrelatedColumn + innerFilter []expression.CNFExprs + innerExecs []exec.Executor + innerList []*chunk.List + innerChunk []*chunk.Chunk + innerSelected [][]bool + innerIter []chunk.Iterator + outerRow []*chunk.Row + hasMatch []bool + hasNull []bool + joiners []join.Joiner + + // fields about concurrency control + concurrency int + started uint32 + drained uint32 // drained == true indicates there is no more data + freeChkCh chan *chunk.Chunk + resultChkCh chan result + outerRowCh chan outerRow + exit chan struct{} + workerWg sync.WaitGroup + notifyWg sync.WaitGroup + + // fields about cache + cache *applycache.ApplyCache + useCache bool + cacheHitCounter int64 + cacheAccessCounter int64 + + memTracker *memory.Tracker // track memory usage. +} + +// Open implements the Executor interface. +func (e *ParallelNestedLoopApplyExec) Open(ctx context.Context) error { + err := exec.Open(ctx, e.outerExec) + if err != nil { + return err + } + e.memTracker = memory.NewTracker(e.ID(), -1) + e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) + + e.outerList = chunk.NewList(exec.RetTypes(e.outerExec), e.InitCap(), e.MaxChunkSize()) + e.outerList.GetMemTracker().SetLabel(memory.LabelForOuterList) + e.outerList.GetMemTracker().AttachTo(e.memTracker) + + e.innerList = make([]*chunk.List, e.concurrency) + e.innerChunk = make([]*chunk.Chunk, e.concurrency) + e.innerSelected = make([][]bool, e.concurrency) + e.innerIter = make([]chunk.Iterator, e.concurrency) + e.outerRow = make([]*chunk.Row, e.concurrency) + e.hasMatch = make([]bool, e.concurrency) + e.hasNull = make([]bool, e.concurrency) + for i := 0; i < e.concurrency; i++ { + e.innerChunk[i] = exec.TryNewCacheChunk(e.innerExecs[i]) + e.innerList[i] = chunk.NewList(exec.RetTypes(e.innerExecs[i]), e.InitCap(), e.MaxChunkSize()) + e.innerList[i].GetMemTracker().SetLabel(memory.LabelForInnerList) + e.innerList[i].GetMemTracker().AttachTo(e.memTracker) + } + + e.freeChkCh = make(chan *chunk.Chunk, e.concurrency) + e.resultChkCh = make(chan result, e.concurrency+1) // innerWorkers + outerWorker + e.outerRowCh = make(chan outerRow) + e.exit = make(chan struct{}) + for i := 0; i < e.concurrency; i++ { + e.freeChkCh <- exec.NewFirstChunk(e) + } + + if e.useCache { + if e.cache, err = applycache.NewApplyCache(e.Ctx()); err != nil { + return err + } + e.cache.GetMemTracker().AttachTo(e.memTracker) + } + return nil +} + +// Next implements the Executor interface. +func (e *ParallelNestedLoopApplyExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { + if atomic.LoadUint32(&e.drained) == 1 { + req.Reset() + return nil + } + + if atomic.CompareAndSwapUint32(&e.started, 0, 1) { + e.workerWg.Add(1) + go e.outerWorker(ctx) + for i := 0; i < e.concurrency; i++ { + e.workerWg.Add(1) + workID := i + go e.innerWorker(ctx, workID) + } + e.notifyWg.Add(1) + go e.notifyWorker(ctx) + } + result := <-e.resultChkCh + if result.err != nil { + return result.err + } + if result.chk == nil { // no more data + req.Reset() + atomic.StoreUint32(&e.drained, 1) + return nil + } + req.SwapColumns(result.chk) + e.freeChkCh <- result.chk + return nil +} + +// Close implements the Executor interface. 
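+// It signals all workers through the exit channel, waits for the notify
+// goroutine, and records the cache hit ratio and worker concurrency in the
+// statement's runtime statistics.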
+func (e *ParallelNestedLoopApplyExec) Close() error { + e.memTracker = nil + if atomic.LoadUint32(&e.started) == 1 { + close(e.exit) + e.notifyWg.Wait() + e.started = 0 + } + // Wait all workers to finish before Close() is called. + // Otherwise we may got data race. + err := exec.Close(e.outerExec) + + if e.RuntimeStats() != nil { + runtimeStats := join.NewJoinRuntimeStats() + if e.useCache { + var hitRatio float64 + if e.cacheAccessCounter > 0 { + hitRatio = float64(e.cacheHitCounter) / float64(e.cacheAccessCounter) + } + runtimeStats.SetCacheInfo(true, hitRatio) + } else { + runtimeStats.SetCacheInfo(false, 0) + } + runtimeStats.SetConcurrencyInfo(execdetails.NewConcurrencyInfo("Concurrency", e.concurrency)) + defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), runtimeStats) + } + return err +} + +// notifyWorker waits for all inner/outer-workers finishing and then put an empty +// chunk into the resultCh to notify the upper executor there is no more data. +func (e *ParallelNestedLoopApplyExec) notifyWorker(ctx context.Context) { + defer e.handleWorkerPanic(ctx, &e.notifyWg) + e.workerWg.Wait() + e.putResult(nil, nil) +} + +func (e *ParallelNestedLoopApplyExec) outerWorker(ctx context.Context) { + defer trace.StartRegion(ctx, "ParallelApplyOuterWorker").End() + defer e.handleWorkerPanic(ctx, &e.workerWg) + var selected []bool + var err error + for { + failpoint.Inject("parallelApplyOuterWorkerPanic", nil) + chk := exec.TryNewCacheChunk(e.outerExec) + if err := exec.Next(ctx, e.outerExec, chk); err != nil { + e.putResult(nil, err) + return + } + if chk.NumRows() == 0 { + close(e.outerRowCh) + return + } + e.outerList.Add(chk) + outerIter := chunk.NewIterator4Chunk(chk) + selected, err = expression.VectorizedFilter(e.Ctx().GetExprCtx().GetEvalCtx(), e.Ctx().GetSessionVars().EnableVectorizedExpression, e.outerFilter, outerIter, selected) + if err != nil { + e.putResult(nil, err) + return + } + for i := 0; i < chk.NumRows(); i++ { + row := chk.GetRow(i) + select { + case e.outerRowCh <- outerRow{&row, selected[i]}: + case <-e.exit: + return + } + } + } +} + +func (e *ParallelNestedLoopApplyExec) innerWorker(ctx context.Context, id int) { + defer trace.StartRegion(ctx, "ParallelApplyInnerWorker").End() + defer e.handleWorkerPanic(ctx, &e.workerWg) + for { + var chk *chunk.Chunk + select { + case chk = <-e.freeChkCh: + case <-e.exit: + return + } + failpoint.Inject("parallelApplyInnerWorkerPanic", nil) + err := e.fillInnerChunk(ctx, id, chk) + if err == nil && chk.NumRows() == 0 { // no more data, this goroutine can exit + return + } + if e.putResult(chk, err) { + return + } + } +} + +func (e *ParallelNestedLoopApplyExec) putResult(chk *chunk.Chunk, err error) (exit bool) { + select { + case e.resultChkCh <- result{chk, err}: + return false + case <-e.exit: + return true + } +} + +func (e *ParallelNestedLoopApplyExec) handleWorkerPanic(ctx context.Context, wg *sync.WaitGroup) { + if r := recover(); r != nil { + err := util.GetRecoverError(r) + logutil.Logger(ctx).Error("parallel nested loop join worker panicked", zap.Error(err), zap.Stack("stack")) + e.resultChkCh <- result{nil, err} + } + if wg != nil { + wg.Done() + } +} + +// fetchAllInners reads all data from the inner table and stores them in a List. 
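+// When the apply cache is enabled, the cache key is built by encoding the
+// correlated column values of the current outer row; a hit reuses the cached
+// inner list, while a miss re-executes the inner plan and stores the filtered
+// rows for later outer rows with the same key.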
+func (e *ParallelNestedLoopApplyExec) fetchAllInners(ctx context.Context, id int) (err error) { + var key []byte + for _, col := range e.corCols[id] { + *col.Data = e.outerRow[id].GetDatum(col.Index, col.RetType) + if e.useCache { + key, err = codec.EncodeKey(e.Ctx().GetSessionVars().StmtCtx.TimeZone(), key, *col.Data) + err = e.Ctx().GetSessionVars().StmtCtx.HandleError(err) + if err != nil { + return err + } + } + } + if e.useCache { // look up the cache + atomic.AddInt64(&e.cacheAccessCounter, 1) + failpoint.Inject("parallelApplyGetCachePanic", nil) + value, err := e.cache.Get(key) + if err != nil { + return err + } + if value != nil { + e.innerList[id] = value + atomic.AddInt64(&e.cacheHitCounter, 1) + return nil + } + } + + err = exec.Open(ctx, e.innerExecs[id]) + defer func() { terror.Log(exec.Close(e.innerExecs[id])) }() + if err != nil { + return err + } + + if e.useCache { + // create a new one in this case since it may be in the cache + e.innerList[id] = chunk.NewList(exec.RetTypes(e.innerExecs[id]), e.InitCap(), e.MaxChunkSize()) + } else { + e.innerList[id].Reset() + } + + innerIter := chunk.NewIterator4Chunk(e.innerChunk[id]) + for { + err := exec.Next(ctx, e.innerExecs[id], e.innerChunk[id]) + if err != nil { + return err + } + if e.innerChunk[id].NumRows() == 0 { + break + } + + e.innerSelected[id], err = expression.VectorizedFilter(e.Ctx().GetExprCtx().GetEvalCtx(), e.Ctx().GetSessionVars().EnableVectorizedExpression, e.innerFilter[id], innerIter, e.innerSelected[id]) + if err != nil { + return err + } + for row := innerIter.Begin(); row != innerIter.End(); row = innerIter.Next() { + if e.innerSelected[id][row.Idx()] { + e.innerList[id].AppendRow(row) + } + } + } + + if e.useCache { // update the cache + failpoint.Inject("parallelApplySetCachePanic", nil) + if _, err := e.cache.Set(key, e.innerList[id]); err != nil { + return err + } + } + return nil +} + +func (e *ParallelNestedLoopApplyExec) fetchNextOuterRow(id int, req *chunk.Chunk) (row *chunk.Row, exit bool) { + for { + select { + case outerRow, ok := <-e.outerRowCh: + if !ok { // no more data + return nil, false + } + if !outerRow.selected { + if e.outer { + e.joiners[id].OnMissMatch(false, *outerRow.row, req) + if req.IsFull() { + return nil, false + } + } + continue // try the next outer row + } + return outerRow.row, false + case <-e.exit: + return nil, true + } + } +} + +func (e *ParallelNestedLoopApplyExec) fillInnerChunk(ctx context.Context, id int, req *chunk.Chunk) (err error) { + req.Reset() + for { + if e.innerIter[id] == nil || e.innerIter[id].Current() == e.innerIter[id].End() { + if e.outerRow[id] != nil && !e.hasMatch[id] { + e.joiners[id].OnMissMatch(e.hasNull[id], *e.outerRow[id], req) + } + var exit bool + e.outerRow[id], exit = e.fetchNextOuterRow(id, req) + if exit || req.IsFull() || e.outerRow[id] == nil { + return nil + } + + e.hasMatch[id] = false + e.hasNull[id] = false + + err = e.fetchAllInners(ctx, id) + if err != nil { + return err + } + e.innerIter[id] = chunk.NewIterator4List(e.innerList[id]) + e.innerIter[id].Begin() + } + + matched, isNull, err := e.joiners[id].TryToMatchInners(*e.outerRow[id], e.innerIter[id], req) + e.hasMatch[id] = e.hasMatch[id] || matched + e.hasNull[id] = e.hasNull[id] || isNull + + if err != nil || req.IsFull() { + return err + } + } +} diff --git a/pkg/executor/point_get.go b/pkg/executor/point_get.go index ee3b7047aa89e..c2af1a0332290 100644 --- a/pkg/executor/point_get.go +++ b/pkg/executor/point_get.go @@ -103,12 +103,12 @@ func (b *executorBuilder) 
buildPointGet(p *plannercore.PointGetPlan) exec.Execut sctx.IndexNames = append(sctx.IndexNames, p.TblInfo.Name.O+":"+p.IndexInfo.Name.O) } - failpoint.Inject("assertPointReplicaOption", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("assertPointReplicaOption")); _err_ == nil { assertScope := val.(string) if e.Ctx().GetSessionVars().GetReplicaRead().IsClosestRead() && assertScope != e.readReplicaScope { panic("point get replica option fail") } - }) + } snapshotTS, err := b.getSnapshotTS() if err != nil { @@ -340,14 +340,14 @@ func (e *PointGetExecutor) Next(ctx context.Context, req *chunk.Chunk) error { // 2. Session B create an UPDATE query to update the record that will be obtained in step 1 // 3. Then point get retrieve data from backend after step 2 finished // 4. Check the result - failpoint.InjectContext(ctx, "pointGetRepeatableReadTest-step1", func() { + if _, _err_ := failpoint.EvalContext(ctx, _curpkg_("pointGetRepeatableReadTest-step1")); _err_ == nil { if ch, ok := ctx.Value("pointGetRepeatableReadTest").(chan struct{}); ok { // Make `UPDATE` continue close(ch) } // Wait `UPDATE` finished - failpoint.InjectContext(ctx, "pointGetRepeatableReadTest-step2", nil) - }) + failpoint.EvalContext(ctx, _curpkg_("pointGetRepeatableReadTest-step2")) + } if e.idxInfo.Global { _, pid, err := codec.DecodeInt(tablecodec.SplitIndexValue(e.handleVal).PartitionID) if err != nil { diff --git a/pkg/executor/point_get.go__failpoint_stash__ b/pkg/executor/point_get.go__failpoint_stash__ new file mode 100644 index 0000000000000..ee3b7047aa89e --- /dev/null +++ b/pkg/executor/point_get.go__failpoint_stash__ @@ -0,0 +1,824 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package executor + +import ( + "context" + "fmt" + "sort" + "strconv" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/distsql" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/logutil/consistency" + "github.com/pingcap/tidb/pkg/util/rowcodec" + "github.com/tikv/client-go/v2/tikvrpc" + "github.com/tikv/client-go/v2/txnkv/txnsnapshot" +) + +func (b *executorBuilder) buildPointGet(p *plannercore.PointGetPlan) exec.Executor { + var err error + if err = b.validCanReadTemporaryOrCacheTable(p.TblInfo); err != nil { + b.err = err + return nil + } + + if p.PrunePartitions(b.ctx) { + // no matching partitions + return &TableDualExec{ + BaseExecutorV2: exec.NewBaseExecutorV2(b.ctx.GetSessionVars(), p.Schema(), p.ID()), + numDualRows: 0, + numReturned: 0, + } + } + + if p.Lock && !b.inSelectLockStmt { + b.inSelectLockStmt = true + defer func() { + b.inSelectLockStmt = false + }() + } + + e := &PointGetExecutor{ + BaseExecutor: exec.NewBaseExecutor(b.ctx, p.Schema(), p.ID()), + indexUsageReporter: b.buildIndexUsageReporter(p), + txnScope: b.txnScope, + readReplicaScope: b.readReplicaScope, + isStaleness: b.isStaleness, + partitionNames: p.PartitionNames, + } + + e.SetInitCap(1) + e.SetMaxChunkSize(1) + e.Init(p) + + e.snapshot, err = b.getSnapshot() + if err != nil { + b.err = err + return nil + } + if b.ctx.GetSessionVars().IsReplicaReadClosestAdaptive() { + e.snapshot.SetOption(kv.ReplicaReadAdjuster, newReplicaReadAdjuster(e.Ctx(), p.GetAvgRowSize())) + } + if e.RuntimeStats() != nil { + snapshotStats := &txnsnapshot.SnapshotRuntimeStats{} + e.stats = &runtimeStatsWithSnapshot{ + SnapshotRuntimeStats: snapshotStats, + } + e.snapshot.SetOption(kv.CollectRuntimeStats, snapshotStats) + } + + if p.IndexInfo != nil { + sctx := b.ctx.GetSessionVars().StmtCtx + sctx.IndexNames = append(sctx.IndexNames, p.TblInfo.Name.O+":"+p.IndexInfo.Name.O) + } + + failpoint.Inject("assertPointReplicaOption", func(val failpoint.Value) { + assertScope := val.(string) + if e.Ctx().GetSessionVars().GetReplicaRead().IsClosestRead() && assertScope != e.readReplicaScope { + panic("point get replica option fail") + } + }) + + snapshotTS, err := b.getSnapshotTS() + if err != nil { + b.err = err + return nil + } + if p.TblInfo.TableCacheStatusType == model.TableCacheStatusEnable { + if cacheTable := b.getCacheTable(p.TblInfo, snapshotTS); cacheTable != nil { + e.snapshot = cacheTableSnapshot{e.snapshot, cacheTable} + } + } + + if e.lock { + b.hasLock = true + } + + return e +} + +// PointGetExecutor executes point select query. 
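+// A unique-index lookup needs at most two point reads: one to resolve the index
+// key to a row handle and one to read the row by that handle; common-handle and
+// plain-handle plans skip the first read. Keys are locked along the way when the
+// plan requires locking.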
+type PointGetExecutor struct { + exec.BaseExecutor + indexUsageReporter *exec.IndexUsageReporter + + tblInfo *model.TableInfo + handle kv.Handle + idxInfo *model.IndexInfo + partitionDefIdx *int + partitionNames []model.CIStr + idxKey kv.Key + handleVal []byte + idxVals []types.Datum + txnScope string + readReplicaScope string + isStaleness bool + txn kv.Transaction + snapshot kv.Snapshot + done bool + lock bool + lockWaitTime int64 + rowDecoder *rowcodec.ChunkDecoder + + columns []*model.ColumnInfo + // virtualColumnIndex records all the indices of virtual columns and sort them in definition + // to make sure we can compute the virtual column in right order. + virtualColumnIndex []int + + // virtualColumnRetFieldTypes records the RetFieldTypes of virtual columns. + virtualColumnRetFieldTypes []*types.FieldType + + stats *runtimeStatsWithSnapshot +} + +// GetPhysID returns the physical id used, either the table's id or a partition's ID +func GetPhysID(tblInfo *model.TableInfo, idx *int) int64 { + if idx != nil { + if *idx < 0 { + intest.Assert(false) + } else { + if pi := tblInfo.GetPartitionInfo(); pi != nil { + return pi.Definitions[*idx].ID + } + } + } + return tblInfo.ID +} + +func matchPartitionNames(pid int64, partitionNames []model.CIStr, pi *model.PartitionInfo) bool { + if len(partitionNames) == 0 { + return true + } + defs := pi.Definitions + for i := range defs { + // TODO: create a map from id to partition definition index + if defs[i].ID == pid { + for _, name := range partitionNames { + if defs[i].Name.L == name.L { + return true + } + } + // Only one partition can match pid + return false + } + } + return false +} + +// Init set fields needed for PointGetExecutor reuse, this does NOT change baseExecutor field +func (e *PointGetExecutor) Init(p *plannercore.PointGetPlan) { + decoder := NewRowDecoder(e.Ctx(), p.Schema(), p.TblInfo) + e.tblInfo = p.TblInfo + e.handle = p.Handle + e.idxInfo = p.IndexInfo + e.idxVals = p.IndexValues + e.done = false + if e.tblInfo.TempTableType == model.TempTableNone { + e.lock = p.Lock + e.lockWaitTime = p.LockWaitTime + } else { + // Temporary table should not do any lock operations + e.lock = false + e.lockWaitTime = 0 + } + e.rowDecoder = decoder + e.partitionDefIdx = p.PartitionIdx + e.columns = p.Columns + e.buildVirtualColumnInfo() +} + +// buildVirtualColumnInfo saves virtual column indices and sort them in definition order +func (e *PointGetExecutor) buildVirtualColumnInfo() { + e.virtualColumnIndex = buildVirtualColumnIndex(e.Schema(), e.columns) + if len(e.virtualColumnIndex) > 0 { + e.virtualColumnRetFieldTypes = make([]*types.FieldType, len(e.virtualColumnIndex)) + for i, idx := range e.virtualColumnIndex { + e.virtualColumnRetFieldTypes[i] = e.Schema().Columns[idx].RetType + } + } +} + +// Open implements the Executor interface. +func (e *PointGetExecutor) Open(context.Context) error { + var err error + e.txn, err = e.Ctx().Txn(false) + if err != nil { + return err + } + if err := e.verifyTxnScope(); err != nil { + return err + } + setOptionForTopSQL(e.Ctx().GetSessionVars().StmtCtx, e.snapshot) + return nil +} + +// Close implements the Executor interface. 
+func (e *PointGetExecutor) Close() error { + if e.stats != nil { + defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), e.stats) + } + if e.RuntimeStats() != nil && e.snapshot != nil { + e.snapshot.SetOption(kv.CollectRuntimeStats, nil) + } + if e.indexUsageReporter != nil && e.idxInfo != nil { + tableID := e.tblInfo.ID + physicalTableID := GetPhysID(e.tblInfo, e.partitionDefIdx) + kvReqTotal := e.stats.SnapshotRuntimeStats.GetCmdRPCCount(tikvrpc.CmdGet) + e.indexUsageReporter.ReportPointGetIndexUsage(tableID, physicalTableID, e.idxInfo.ID, e.ID(), kvReqTotal) + } + e.done = false + return nil +} + +// Next implements the Executor interface. +func (e *PointGetExecutor) Next(ctx context.Context, req *chunk.Chunk) error { + req.Reset() + if e.done { + return nil + } + e.done = true + + var err error + tblID := GetPhysID(e.tblInfo, e.partitionDefIdx) + if e.lock { + e.UpdateDeltaForTableID(tblID) + } + if e.idxInfo != nil { + if isCommonHandleRead(e.tblInfo, e.idxInfo) { + handleBytes, err := plannercore.EncodeUniqueIndexValuesForKey(e.Ctx(), e.tblInfo, e.idxInfo, e.idxVals) + if err != nil { + if kv.ErrNotExist.Equal(err) { + return nil + } + return err + } + e.handle, err = kv.NewCommonHandle(handleBytes) + if err != nil { + return err + } + } else { + e.idxKey, err = plannercore.EncodeUniqueIndexKey(e.Ctx(), e.tblInfo, e.idxInfo, e.idxVals, tblID) + if err != nil && !kv.ErrNotExist.Equal(err) { + return err + } + + // lockNonExistIdxKey indicates the key will be locked regardless of its existence. + lockNonExistIdxKey := !e.Ctx().GetSessionVars().IsPessimisticReadConsistency() + // Non-exist keys are also locked if the isolation level is not read consistency, + // lock it before read here, then it's able to read from pessimistic lock cache. + if lockNonExistIdxKey { + err = e.lockKeyIfNeeded(ctx, e.idxKey) + if err != nil { + return err + } + e.handleVal, err = e.get(ctx, e.idxKey) + if err != nil { + if !kv.ErrNotExist.Equal(err) { + return err + } + } + } else { + if e.lock { + e.handleVal, err = e.lockKeyIfExists(ctx, e.idxKey) + if err != nil { + return err + } + } else { + e.handleVal, err = e.get(ctx, e.idxKey) + if err != nil { + if !kv.ErrNotExist.Equal(err) { + return err + } + } + } + } + + if len(e.handleVal) == 0 { + return nil + } + + var iv kv.Handle + iv, err = tablecodec.DecodeHandleInIndexValue(e.handleVal) + if err != nil { + return err + } + e.handle = iv + + // The injection is used to simulate following scenario: + // 1. Session A create a point get query but pause before second time `GET` kv from backend + // 2. Session B create an UPDATE query to update the record that will be obtained in step 1 + // 3. Then point get retrieve data from backend after step 2 finished + // 4. 
Check the result + failpoint.InjectContext(ctx, "pointGetRepeatableReadTest-step1", func() { + if ch, ok := ctx.Value("pointGetRepeatableReadTest").(chan struct{}); ok { + // Make `UPDATE` continue + close(ch) + } + // Wait `UPDATE` finished + failpoint.InjectContext(ctx, "pointGetRepeatableReadTest-step2", nil) + }) + if e.idxInfo.Global { + _, pid, err := codec.DecodeInt(tablecodec.SplitIndexValue(e.handleVal).PartitionID) + if err != nil { + return err + } + tblID = pid + if !matchPartitionNames(tblID, e.partitionNames, e.tblInfo.GetPartitionInfo()) { + return nil + } + } + } + } + + key := tablecodec.EncodeRowKeyWithHandle(tblID, e.handle) + val, err := e.getAndLock(ctx, key) + if err != nil { + return err + } + if len(val) == 0 { + if e.idxInfo != nil && !isCommonHandleRead(e.tblInfo, e.idxInfo) && + !e.Ctx().GetSessionVars().StmtCtx.WeakConsistency { + return (&consistency.Reporter{ + HandleEncode: func(kv.Handle) kv.Key { + return key + }, + IndexEncode: func(*consistency.RecordData) kv.Key { + return e.idxKey + }, + Tbl: e.tblInfo, + Idx: e.idxInfo, + EnableRedactLog: e.Ctx().GetSessionVars().EnableRedactLog, + Storage: e.Ctx().GetStore(), + }).ReportLookupInconsistent(ctx, + 1, 0, + []kv.Handle{e.handle}, + []kv.Handle{e.handle}, + []consistency.RecordData{{}}, + ) + } + return nil + } + + sctx := e.BaseExecutor.Ctx() + schema := e.Schema() + err = DecodeRowValToChunk(sctx, schema, e.tblInfo, e.handle, val, req, e.rowDecoder) + if err != nil { + return err + } + + err = fillRowChecksum(sctx, 0, 1, schema, e.tblInfo, [][]byte{val}, []kv.Handle{e.handle}, req, nil) + if err != nil { + return err + } + + err = table.FillVirtualColumnValue(e.virtualColumnRetFieldTypes, e.virtualColumnIndex, + schema.Columns, e.columns, sctx.GetExprCtx(), req) + if err != nil { + return err + } + return nil +} + +func shouldFillRowChecksum(schema *expression.Schema) (int, bool) { + for idx, col := range schema.Columns { + if col.ID == model.ExtraRowChecksumID { + return idx, true + } + } + return 0, false +} + +func fillRowChecksum( + sctx sessionctx.Context, + start, end int, + schema *expression.Schema, tblInfo *model.TableInfo, + values [][]byte, handles []kv.Handle, + req *chunk.Chunk, buf []byte, +) error { + checksumColumnIndex, ok := shouldFillRowChecksum(schema) + if !ok { + return nil + } + + var handleColIDs []int64 + if tblInfo.PKIsHandle { + colInfo := tblInfo.GetPkColInfo() + handleColIDs = []int64{colInfo.ID} + } else if tblInfo.IsCommonHandle { + pkIdx := tables.FindPrimaryIndex(tblInfo) + for _, col := range pkIdx.Columns { + colInfo := tblInfo.Columns[col.Offset] + handleColIDs = append(handleColIDs, colInfo.ID) + } + } + + columnFt := make(map[int64]*types.FieldType) + for idx := range tblInfo.Columns { + col := tblInfo.Columns[idx] + columnFt[col.ID] = &col.FieldType + } + tz := sctx.GetSessionVars().TimeZone + ft := []*types.FieldType{schema.Columns[checksumColumnIndex].GetType(sctx.GetExprCtx().GetEvalCtx())} + checksumCols := chunk.NewChunkWithCapacity(ft, req.Capacity()) + for i := start; i < end; i++ { + handle, val := handles[i], values[i] + if !rowcodec.IsNewFormat(val) { + checksumCols.AppendNull(0) + continue + } + datums, err := tablecodec.DecodeRowWithMapNew(val, columnFt, tz, nil) + if err != nil { + return err + } + datums, err = tablecodec.DecodeHandleToDatumMap(handle, handleColIDs, columnFt, tz, datums) + if err != nil { + return err + } + for _, col := range tblInfo.Columns { + // cannot found from the datums, which means the data is not stored, this + // may happen 
after `add column` executed, filling with the default value. + _, ok := datums[col.ID] + if !ok { + colInfo := getColInfoByID(tblInfo, col.ID) + d, err := table.GetColOriginDefaultValue(sctx.GetExprCtx(), colInfo) + if err != nil { + return err + } + datums[col.ID] = d + } + } + + colData := make([]rowcodec.ColData, len(tblInfo.Columns)) + for idx, col := range tblInfo.Columns { + d := datums[col.ID] + data := rowcodec.ColData{ + ColumnInfo: col, + Datum: &d, + } + colData[idx] = data + } + row := rowcodec.RowData{ + Cols: colData, + Data: buf, + } + if !sort.IsSorted(row) { + sort.Sort(row) + } + checksum, err := row.Checksum(tz) + if err != nil { + return err + } + checksumCols.AppendString(0, strconv.FormatUint(uint64(checksum), 10)) + } + req.SetCol(checksumColumnIndex, checksumCols.Column(0)) + return nil +} + +func (e *PointGetExecutor) getAndLock(ctx context.Context, key kv.Key) (val []byte, err error) { + if e.Ctx().GetSessionVars().IsPessimisticReadConsistency() { + // Only Lock the existing keys in RC isolation. + if e.lock { + val, err = e.lockKeyIfExists(ctx, key) + if err != nil { + return nil, err + } + } else { + val, err = e.get(ctx, key) + if err != nil { + if !kv.ErrNotExist.Equal(err) { + return nil, err + } + return nil, nil + } + } + return val, nil + } + // Lock the key before get in RR isolation, then get will get the value from the cache. + err = e.lockKeyIfNeeded(ctx, key) + if err != nil { + return nil, err + } + val, err = e.get(ctx, key) + if err != nil { + if !kv.ErrNotExist.Equal(err) { + return nil, err + } + return nil, nil + } + return val, nil +} + +func (e *PointGetExecutor) lockKeyIfNeeded(ctx context.Context, key []byte) error { + _, err := e.lockKeyBase(ctx, key, false) + return err +} + +// lockKeyIfExists locks the key if needed, but won't lock the key if it doesn't exis. +// Returns the value of the key if the key exist. +func (e *PointGetExecutor) lockKeyIfExists(ctx context.Context, key []byte) ([]byte, error) { + return e.lockKeyBase(ctx, key, true) +} + +func (e *PointGetExecutor) lockKeyBase(ctx context.Context, + key []byte, + lockOnlyIfExists bool) ([]byte, error) { + if len(key) == 0 { + return nil, nil + } + + if e.lock { + seVars := e.Ctx().GetSessionVars() + lockCtx, err := newLockCtx(e.Ctx(), e.lockWaitTime, 1) + if err != nil { + return nil, err + } + lockCtx.LockOnlyIfExists = lockOnlyIfExists + lockCtx.InitReturnValues(1) + err = doLockKeys(ctx, e.Ctx(), lockCtx, key) + if err != nil { + return nil, err + } + lockCtx.IterateValuesNotLocked(func(k, v []byte) { + seVars.TxnCtx.SetPessimisticLockCache(k, v) + }) + if len(e.handleVal) > 0 { + seVars.TxnCtx.SetPessimisticLockCache(e.idxKey, e.handleVal) + } + if lockOnlyIfExists { + return e.getValueFromLockCtx(ctx, lockCtx, key) + } + } + + return nil, nil +} + +func (e *PointGetExecutor) getValueFromLockCtx(ctx context.Context, + lockCtx *kv.LockCtx, + key []byte) ([]byte, error) { + if val, ok := lockCtx.Values[string(key)]; ok { + if val.Exists { + return val.Value, nil + } else if val.AlreadyLocked { + val, err := e.get(ctx, key) + if err != nil { + if !kv.ErrNotExist.Equal(err) { + return nil, err + } + return nil, nil + } + return val, nil + } + } + + return nil, nil +} + +// get will first try to get from txn buffer, then check the pessimistic lock cache, +// then the store. 
Kv.ErrNotExist will be returned if key is not found +func (e *PointGetExecutor) get(ctx context.Context, key kv.Key) ([]byte, error) { + if len(key) == 0 { + return nil, kv.ErrNotExist + } + + var ( + val []byte + err error + ) + + if e.txn.Valid() && !e.txn.IsReadOnly() { + // We cannot use txn.Get directly here because the snapshot in txn and the snapshot of e.snapshot may be + // different for pessimistic transaction. + val, err = e.txn.GetMemBuffer().Get(ctx, key) + if err == nil { + return val, err + } + if !kv.IsErrNotFound(err) { + return nil, err + } + // key does not exist in mem buffer, check the lock cache + if e.lock { + var ok bool + val, ok = e.Ctx().GetSessionVars().TxnCtx.GetKeyInPessimisticLockCache(key) + if ok { + return val, nil + } + } + // fallthrough to snapshot get. + } + + lock := e.tblInfo.Lock + if lock != nil && (lock.Tp == model.TableLockRead || lock.Tp == model.TableLockReadOnly) { + if e.Ctx().GetSessionVars().EnablePointGetCache { + cacheDB := e.Ctx().GetStore().GetMemCache() + val, err = cacheDB.UnionGet(ctx, e.tblInfo.ID, e.snapshot, key) + if err != nil { + return nil, err + } + return val, nil + } + } + // if not read lock or table was unlock then snapshot get + return e.snapshot.Get(ctx, key) +} + +func (e *PointGetExecutor) verifyTxnScope() error { + if e.txnScope == "" || e.txnScope == kv.GlobalTxnScope { + return nil + } + + var partName string + is := e.Ctx().GetInfoSchema().(infoschema.InfoSchema) + tblInfo, _ := is.TableByID((e.tblInfo.ID)) + tblName := tblInfo.Meta().Name.String() + tblID := GetPhysID(tblInfo.Meta(), e.partitionDefIdx) + if tblID != tblInfo.Meta().ID { + partName = tblInfo.Meta().GetPartitionInfo().Definitions[*e.partitionDefIdx].Name.String() + } + valid := distsql.VerifyTxnScope(e.txnScope, tblID, is) + if valid { + return nil + } + if len(partName) > 0 { + return dbterror.ErrInvalidPlacementPolicyCheck.GenWithStackByArgs( + fmt.Sprintf("table %v's partition %v can not be read by %v txn_scope", tblName, partName, e.txnScope)) + } + return dbterror.ErrInvalidPlacementPolicyCheck.GenWithStackByArgs( + fmt.Sprintf("table %v can not be read by %v txn_scope", tblName, e.txnScope)) +} + +// DecodeRowValToChunk decodes row value into chunk checking row format used. 
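// The lookup order documented for (*PointGetExecutor).get above (transaction
// membuffer, then pessimistic lock cache, then snapshot, where "not found"
// falls through and any other error aborts) boils down to a simple layered
// read. A minimal sketch of that fall-through pattern; lookupLayer,
// errKeyNotFound and getThrough are invented for illustration and are not
// part of the TiDB code.
package main

import (
	"errors"
	"fmt"
)

var errKeyNotFound = errors.New("key not found")

// lookupLayer stands in for the membuffer, lock cache and snapshot layers.
type lookupLayer func(key string) ([]byte, error)

// getThrough tries each layer in order; a not-found result falls through to
// the next layer, while any other error stops the lookup immediately.
func getThrough(key string, layers ...lookupLayer) ([]byte, error) {
	for _, layer := range layers {
		val, err := layer(key)
		if err == nil {
			return val, nil
		}
		if !errors.Is(err, errKeyNotFound) {
			return nil, err
		}
	}
	return nil, errKeyNotFound
}

func main() {
	memBuffer := func(string) ([]byte, error) { return nil, errKeyNotFound }
	lockCache := func(string) ([]byte, error) { return nil, errKeyNotFound }
	snapshot := func(string) ([]byte, error) { return []byte("row"), nil }

	val, err := getThrough("k1", memBuffer, lockCache, snapshot)
	fmt.Println(string(val), err) // row <nil>
}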
+func DecodeRowValToChunk(sctx sessionctx.Context, schema *expression.Schema, tblInfo *model.TableInfo, + handle kv.Handle, rowVal []byte, chk *chunk.Chunk, rd *rowcodec.ChunkDecoder) error { + if rowcodec.IsNewFormat(rowVal) { + return rd.DecodeToChunk(rowVal, handle, chk) + } + return decodeOldRowValToChunk(sctx, schema, tblInfo, handle, rowVal, chk) +} + +func decodeOldRowValToChunk(sctx sessionctx.Context, schema *expression.Schema, tblInfo *model.TableInfo, handle kv.Handle, + rowVal []byte, chk *chunk.Chunk) error { + pkCols := tables.TryGetCommonPkColumnIds(tblInfo) + prefixColIDs := tables.PrimaryPrefixColumnIDs(tblInfo) + colID2CutPos := make(map[int64]int, schema.Len()) + for _, col := range schema.Columns { + if _, ok := colID2CutPos[col.ID]; !ok { + colID2CutPos[col.ID] = len(colID2CutPos) + } + } + cutVals, err := tablecodec.CutRowNew(rowVal, colID2CutPos) + if err != nil { + return err + } + if cutVals == nil { + cutVals = make([][]byte, len(colID2CutPos)) + } + decoder := codec.NewDecoder(chk, sctx.GetSessionVars().Location()) + for i, col := range schema.Columns { + // fill the virtual column value after row calculation + if col.VirtualExpr != nil { + chk.AppendNull(i) + continue + } + ok, err := tryDecodeFromHandle(tblInfo, i, col, handle, chk, decoder, pkCols, prefixColIDs) + if err != nil { + return err + } + if ok { + continue + } + cutPos := colID2CutPos[col.ID] + if len(cutVals[cutPos]) == 0 { + colInfo := getColInfoByID(tblInfo, col.ID) + d, err1 := table.GetColOriginDefaultValue(sctx.GetExprCtx(), colInfo) + if err1 != nil { + return err1 + } + chk.AppendDatum(i, &d) + continue + } + _, err = decoder.DecodeOne(cutVals[cutPos], i, col.RetType) + if err != nil { + return err + } + } + return nil +} + +func tryDecodeFromHandle(tblInfo *model.TableInfo, schemaColIdx int, col *expression.Column, handle kv.Handle, chk *chunk.Chunk, + decoder *codec.Decoder, pkCols []int64, prefixColIDs []int64) (bool, error) { + if tblInfo.PKIsHandle && mysql.HasPriKeyFlag(col.RetType.GetFlag()) { + chk.AppendInt64(schemaColIdx, handle.IntValue()) + return true, nil + } + if col.ID == model.ExtraHandleID { + chk.AppendInt64(schemaColIdx, handle.IntValue()) + return true, nil + } + if types.NeedRestoredData(col.RetType) { + return false, nil + } + // Try to decode common handle. + if mysql.HasPriKeyFlag(col.RetType.GetFlag()) { + for i, hid := range pkCols { + if col.ID == hid && notPKPrefixCol(hid, prefixColIDs) { + _, err := decoder.DecodeOne(handle.EncodedCol(i), schemaColIdx, col.RetType) + if err != nil { + return false, errors.Trace(err) + } + return true, nil + } + } + } + return false, nil +} + +func notPKPrefixCol(colID int64, prefixColIDs []int64) bool { + for _, pCol := range prefixColIDs { + if pCol == colID { + return false + } + } + return true +} + +func getColInfoByID(tbl *model.TableInfo, colID int64) *model.ColumnInfo { + for _, col := range tbl.Columns { + if col.ID == colID { + return col + } + } + return nil +} + +type runtimeStatsWithSnapshot struct { + *txnsnapshot.SnapshotRuntimeStats +} + +func (e *runtimeStatsWithSnapshot) String() string { + if e.SnapshotRuntimeStats != nil { + return e.SnapshotRuntimeStats.String() + } + return "" +} + +// Clone implements the RuntimeStats interface. 
+func (e *runtimeStatsWithSnapshot) Clone() execdetails.RuntimeStats { + newRs := &runtimeStatsWithSnapshot{} + if e.SnapshotRuntimeStats != nil { + snapshotStats := e.SnapshotRuntimeStats.Clone() + newRs.SnapshotRuntimeStats = snapshotStats + } + return newRs +} + +// Merge implements the RuntimeStats interface. +func (e *runtimeStatsWithSnapshot) Merge(other execdetails.RuntimeStats) { + tmp, ok := other.(*runtimeStatsWithSnapshot) + if !ok { + return + } + if tmp.SnapshotRuntimeStats != nil { + if e.SnapshotRuntimeStats == nil { + snapshotStats := tmp.SnapshotRuntimeStats.Clone() + e.SnapshotRuntimeStats = snapshotStats + return + } + e.SnapshotRuntimeStats.Merge(tmp.SnapshotRuntimeStats) + } +} + +// Tp implements the RuntimeStats interface. +func (*runtimeStatsWithSnapshot) Tp() int { + return execdetails.TpRuntimeStatsWithSnapshot +} diff --git a/pkg/executor/projection.go b/pkg/executor/projection.go index ce1113dde4776..632e8bf586a5e 100644 --- a/pkg/executor/projection.go +++ b/pkg/executor/projection.go @@ -106,11 +106,11 @@ func (e *ProjectionExec) Open(ctx context.Context) error { if err := e.BaseExecutorV2.Open(ctx); err != nil { return err } - failpoint.Inject("mockProjectionExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockProjectionExecBaseExecutorOpenReturnedError")); _err_ == nil { if val.(bool) { - failpoint.Return(errors.New("mock ProjectionExec.baseExecutor.Open returned error")) + return errors.New("mock ProjectionExec.baseExecutor.Open returned error") } - }) + } return e.open(ctx) } @@ -216,7 +216,7 @@ func (e *ProjectionExec) unParallelExecute(ctx context.Context, chk *chunk.Chunk e.childResult.SetRequiredRows(chk.RequiredRows(), e.MaxChunkSize()) mSize := e.childResult.MemoryUsage() err := exec.Next(ctx, e.Children(0), e.childResult) - failpoint.Inject("ConsumeRandomPanic", nil) + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) e.memTracker.Consume(e.childResult.MemoryUsage() - mSize) if err != nil { return err @@ -246,7 +246,7 @@ func (e *ProjectionExec) parallelExecute(ctx context.Context, chk *chunk.Chunk) } mSize := output.chk.MemoryUsage() chk.SwapColumns(output.chk) - failpoint.Inject("ConsumeRandomPanic", nil) + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) e.memTracker.Consume(output.chk.MemoryUsage() - mSize) e.fetcher.outputCh <- output return nil @@ -280,7 +280,7 @@ func (e *ProjectionExec) prepare(ctx context.Context) { }) inputChk := exec.NewFirstChunk(e.Children(0)) - failpoint.Inject("ConsumeRandomPanic", nil) + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) e.memTracker.Consume(inputChk.MemoryUsage()) e.fetcher.inputCh <- &projectionInput{ chk: inputChk, @@ -408,7 +408,7 @@ func (f *projectionInputFetcher) run(ctx context.Context) { input.chk.SetRequiredRows(int(requiredRows), f.proj.MaxChunkSize()) mSize := input.chk.MemoryUsage() err := exec.Next(ctx, f.child, input.chk) - failpoint.Inject("ConsumeRandomPanic", nil) + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) f.proj.memTracker.Consume(input.chk.MemoryUsage() - mSize) if err != nil || input.chk.NumRows() == 0 { output.done <- err @@ -469,7 +469,7 @@ func (w *projectionWorker) run(ctx context.Context) { mSize := output.chk.MemoryUsage() + input.chk.MemoryUsage() err := w.evaluatorSuit.Run(w.ctx.evalCtx, w.ctx.enableVectorizedExpression, input.chk, output.chk) - failpoint.Inject("ConsumeRandomPanic", nil) + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) w.proj.memTracker.Consume(output.chk.MemoryUsage() + input.chk.MemoryUsage() - 
mSize) output.done <- err diff --git a/pkg/executor/projection.go__failpoint_stash__ b/pkg/executor/projection.go__failpoint_stash__ new file mode 100644 index 0000000000000..ce1113dde4776 --- /dev/null +++ b/pkg/executor/projection.go__failpoint_stash__ @@ -0,0 +1,501 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "context" + "fmt" + "runtime/trace" + "sync" + "sync/atomic" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "go.uber.org/zap" +) + +// This file contains the implementation of the physical Projection Operator: +// https://en.wikipedia.org/wiki/Projection_(relational_algebra) +// +// NOTE: +// 1. The number of "projectionWorker" is controlled by the global session +// variable "tidb_projection_concurrency". +// 2. Unparallel version is used when one of the following situations occurs: +// a. "tidb_projection_concurrency" is set to 0. +// b. The estimated input size is smaller than "tidb_max_chunk_size". +// c. This projection can not be executed vectorially. + +type projectionInput struct { + chk *chunk.Chunk + targetWorker *projectionWorker +} + +type projectionOutput struct { + chk *chunk.Chunk + done chan error +} + +// projectionExecutorContext is the execution context for the `ProjectionExec` +type projectionExecutorContext struct { + stmtMemTracker *memory.Tracker + stmtRuntimeStatsColl *execdetails.RuntimeStatsColl + evalCtx expression.EvalContext + enableVectorizedExpression bool +} + +func newProjectionExecutorContext(sctx sessionctx.Context) projectionExecutorContext { + return projectionExecutorContext{ + stmtMemTracker: sctx.GetSessionVars().StmtCtx.MemTracker, + stmtRuntimeStatsColl: sctx.GetSessionVars().StmtCtx.RuntimeStatsColl, + evalCtx: sctx.GetExprCtx().GetEvalCtx(), + enableVectorizedExpression: sctx.GetSessionVars().EnableVectorizedExpression, + } +} + +// ProjectionExec implements the physical Projection Operator: +// https://en.wikipedia.org/wiki/Projection_(relational_algebra) +type ProjectionExec struct { + projectionExecutorContext + exec.BaseExecutorV2 + + evaluatorSuit *expression.EvaluatorSuite + + finishCh chan struct{} + outputCh chan *projectionOutput + fetcher projectionInputFetcher + numWorkers int64 + workers []*projectionWorker + childResult *chunk.Chunk + + // parentReqRows indicates how many rows the parent executor is + // requiring. It is set when parallelExecute() is called and used by the + // concurrent projectionInputFetcher. + // + // NOTE: It should be protected by atomic operations. 
+ parentReqRows int64 + + memTracker *memory.Tracker + wg *sync.WaitGroup + + calculateNoDelay bool + prepared bool +} + +// Open implements the Executor Open interface. +func (e *ProjectionExec) Open(ctx context.Context) error { + if err := e.BaseExecutorV2.Open(ctx); err != nil { + return err + } + failpoint.Inject("mockProjectionExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(errors.New("mock ProjectionExec.baseExecutor.Open returned error")) + } + }) + return e.open(ctx) +} + +func (e *ProjectionExec) open(_ context.Context) error { + e.prepared = false + e.parentReqRows = int64(e.MaxChunkSize()) + + if e.memTracker != nil { + e.memTracker.Reset() + } else { + e.memTracker = memory.NewTracker(e.ID(), -1) + } + e.memTracker.AttachTo(e.stmtMemTracker) + + // For now a Projection can not be executed vectorially only because it + // contains "SetVar" or "GetVar" functions, in this scenario this + // Projection can not be executed parallelly. + if e.numWorkers > 0 && !e.evaluatorSuit.Vectorizable() { + e.numWorkers = 0 + } + + if e.isUnparallelExec() { + e.childResult = exec.TryNewCacheChunk(e.Children(0)) + e.memTracker.Consume(e.childResult.MemoryUsage()) + } + + e.wg = &sync.WaitGroup{} + + return nil +} + +// Next implements the Executor Next interface. +// +// Here we explain the execution flow of the parallel projection implementation. +// There are 3 main components: +// 1. "projectionInputFetcher": Fetch input "Chunk" from child. +// 2. "projectionWorker": Do the projection work. +// 3. "ProjectionExec.Next": Return result to parent. +// +// 1. "projectionInputFetcher" gets its input and output resources from its +// "inputCh" and "outputCh" channel, once the input and output resources are +// obtained, it fetches child's result into "input.chk" and: +// a. Dispatches this input to the worker specified in "input.targetWorker" +// b. Dispatches this output to the main thread: "ProjectionExec.Next" +// c. Dispatches this output to the worker specified in "input.targetWorker" +// It is finished and exited once: +// a. There is no more input from child. +// b. "ProjectionExec" close the "globalFinishCh" +// +// 2. "projectionWorker" gets its input and output resources from its +// "inputCh" and "outputCh" channel, once the input and output resources are +// abtained, it calculates the projection result use "input.chk" as the input +// and "output.chk" as the output, once the calculation is done, it: +// a. Sends "nil" or error to "output.done" to mark this input is finished. +// b. Returns the "input" resource to "projectionInputFetcher.inputCh" +// They are finished and exited once: +// a. "ProjectionExec" closes the "globalFinishCh" +// +// 3. "ProjectionExec.Next" gets its output resources from its "outputCh" channel. +// After receiving an output from "outputCh", it should wait to receive a "nil" +// or error from "output.done" channel. Once a "nil" or error is received: +// a. Returns this output to its parent +// b. Returns the "output" resource to "projectionInputFetcher.outputCh" +/* + +-----------+----------------------+--------------------------+ + | | | | + | +--------+---------+ +--------+---------+ +--------+---------+ + | | projectionWorker | + projectionWorker | ... 
+ projectionWorker | + | +------------------+ +------------------+ +------------------+ + | ^ ^ ^ ^ ^ ^ + | | | | | | | + | inputCh outputCh inputCh outputCh inputCh outputCh + | ^ ^ ^ ^ ^ ^ + | | | | | | | + | | | + | | +----------------->outputCh + | | | | + | | | v + | +-------+-------+--------+ +---------------------+ + | | projectionInputFetcher | | ProjectionExec.Next | + | +------------------------+ +---------+-----------+ + | ^ ^ | + | | | | + | inputCh outputCh | + | ^ ^ | + | | | | + +------------------------------+ +----------------------+ +*/ +func (e *ProjectionExec) Next(ctx context.Context, req *chunk.Chunk) error { + req.GrowAndReset(e.MaxChunkSize()) + if e.isUnparallelExec() { + return e.unParallelExecute(ctx, req) + } + return e.parallelExecute(ctx, req) +} + +func (e *ProjectionExec) isUnparallelExec() bool { + return e.numWorkers <= 0 +} + +func (e *ProjectionExec) unParallelExecute(ctx context.Context, chk *chunk.Chunk) error { + // transmit the requiredRows + e.childResult.SetRequiredRows(chk.RequiredRows(), e.MaxChunkSize()) + mSize := e.childResult.MemoryUsage() + err := exec.Next(ctx, e.Children(0), e.childResult) + failpoint.Inject("ConsumeRandomPanic", nil) + e.memTracker.Consume(e.childResult.MemoryUsage() - mSize) + if err != nil { + return err + } + if e.childResult.NumRows() == 0 { + return nil + } + err = e.evaluatorSuit.Run(e.evalCtx, e.enableVectorizedExpression, e.childResult, chk) + return err +} + +func (e *ProjectionExec) parallelExecute(ctx context.Context, chk *chunk.Chunk) error { + atomic.StoreInt64(&e.parentReqRows, int64(chk.RequiredRows())) + if !e.prepared { + e.prepare(ctx) + e.prepared = true + } + + output, ok := <-e.outputCh + if !ok { + return nil + } + + err := <-output.done + if err != nil { + return err + } + mSize := output.chk.MemoryUsage() + chk.SwapColumns(output.chk) + failpoint.Inject("ConsumeRandomPanic", nil) + e.memTracker.Consume(output.chk.MemoryUsage() - mSize) + e.fetcher.outputCh <- output + return nil +} + +func (e *ProjectionExec) prepare(ctx context.Context) { + e.finishCh = make(chan struct{}) + e.outputCh = make(chan *projectionOutput, e.numWorkers) + + // Initialize projectionInputFetcher. + e.fetcher = projectionInputFetcher{ + proj: e, + child: e.Children(0), + globalFinishCh: e.finishCh, + globalOutputCh: e.outputCh, + inputCh: make(chan *projectionInput, e.numWorkers), + outputCh: make(chan *projectionOutput, e.numWorkers), + } + + // Initialize projectionWorker. 
+ e.workers = make([]*projectionWorker, 0, e.numWorkers) + for i := int64(0); i < e.numWorkers; i++ { + e.workers = append(e.workers, &projectionWorker{ + proj: e, + ctx: e.projectionExecutorContext, + evaluatorSuit: e.evaluatorSuit, + globalFinishCh: e.finishCh, + inputGiveBackCh: e.fetcher.inputCh, + inputCh: make(chan *projectionInput, 1), + outputCh: make(chan *projectionOutput, 1), + }) + + inputChk := exec.NewFirstChunk(e.Children(0)) + failpoint.Inject("ConsumeRandomPanic", nil) + e.memTracker.Consume(inputChk.MemoryUsage()) + e.fetcher.inputCh <- &projectionInput{ + chk: inputChk, + targetWorker: e.workers[i], + } + + outputChk := exec.NewFirstChunk(e) + e.memTracker.Consume(outputChk.MemoryUsage()) + e.fetcher.outputCh <- &projectionOutput{ + chk: outputChk, + done: make(chan error, 1), + } + } + + e.wg.Add(1) + go e.fetcher.run(ctx) + + for i := range e.workers { + e.wg.Add(1) + go e.workers[i].run(ctx) + } +} + +func (e *ProjectionExec) drainInputCh(ch chan *projectionInput) { + close(ch) + for item := range ch { + if item.chk != nil { + e.memTracker.Consume(-item.chk.MemoryUsage()) + } + } +} + +func (e *ProjectionExec) drainOutputCh(ch chan *projectionOutput) { + close(ch) + for item := range ch { + if item.chk != nil { + e.memTracker.Consume(-item.chk.MemoryUsage()) + } + } +} + +// Close implements the Executor Close interface. +func (e *ProjectionExec) Close() error { + // if e.BaseExecutor.Open returns error, e.childResult will be nil, see https://github.com/pingcap/tidb/issues/24210 + // for more information + if e.isUnparallelExec() && e.childResult != nil { + e.memTracker.Consume(-e.childResult.MemoryUsage()) + e.childResult = nil + } + if e.prepared { + close(e.finishCh) + e.wg.Wait() // Wait for fetcher and workers to finish and exit. + + // clear fetcher + e.drainInputCh(e.fetcher.inputCh) + e.drainOutputCh(e.fetcher.outputCh) + + // clear workers + for _, w := range e.workers { + e.drainInputCh(w.inputCh) + e.drainOutputCh(w.outputCh) + } + } + if e.BaseExecutorV2.RuntimeStats() != nil { + runtimeStats := &execdetails.RuntimeStatsWithConcurrencyInfo{} + if e.isUnparallelExec() { + runtimeStats.SetConcurrencyInfo(execdetails.NewConcurrencyInfo("Concurrency", 0)) + } else { + runtimeStats.SetConcurrencyInfo(execdetails.NewConcurrencyInfo("Concurrency", int(e.numWorkers))) + } + e.stmtRuntimeStatsColl.RegisterStats(e.ID(), runtimeStats) + } + return e.BaseExecutorV2.Close() +} + +type projectionInputFetcher struct { + proj *ProjectionExec + child exec.Executor + globalFinishCh <-chan struct{} + globalOutputCh chan<- *projectionOutput + + inputCh chan *projectionInput + outputCh chan *projectionOutput +} + +// run gets projectionInputFetcher's input and output resources from its +// "inputCh" and "outputCh" channel, once the input and output resources are +// abtained, it fetches child's result into "input.chk" and: +// +// a. Dispatches this input to the worker specified in "input.targetWorker" +// b. Dispatches this output to the main thread: "ProjectionExec.Next" +// c. Dispatches this output to the worker specified in "input.targetWorker" +// +// It is finished and exited once: +// +// a. There is no more input from child. +// b. 
"ProjectionExec" close the "globalFinishCh" +func (f *projectionInputFetcher) run(ctx context.Context) { + defer trace.StartRegion(ctx, "ProjectionFetcher").End() + var output *projectionOutput + defer func() { + if r := recover(); r != nil { + recoveryProjection(output, r) + } + close(f.globalOutputCh) + f.proj.wg.Done() + }() + + for { + input, isNil := readProjection[*projectionInput](f.inputCh, f.globalFinishCh) + if isNil { + return + } + targetWorker := input.targetWorker + + output, isNil = readProjection[*projectionOutput](f.outputCh, f.globalFinishCh) + if isNil { + f.proj.memTracker.Consume(-input.chk.MemoryUsage()) + return + } + + f.globalOutputCh <- output + + requiredRows := atomic.LoadInt64(&f.proj.parentReqRows) + input.chk.SetRequiredRows(int(requiredRows), f.proj.MaxChunkSize()) + mSize := input.chk.MemoryUsage() + err := exec.Next(ctx, f.child, input.chk) + failpoint.Inject("ConsumeRandomPanic", nil) + f.proj.memTracker.Consume(input.chk.MemoryUsage() - mSize) + if err != nil || input.chk.NumRows() == 0 { + output.done <- err + f.proj.memTracker.Consume(-input.chk.MemoryUsage()) + return + } + + targetWorker.inputCh <- input + targetWorker.outputCh <- output + } +} + +type projectionWorker struct { + proj *ProjectionExec + ctx projectionExecutorContext + evaluatorSuit *expression.EvaluatorSuite + globalFinishCh <-chan struct{} + inputGiveBackCh chan<- *projectionInput + + // channel "input" and "output" is : + // a. initialized by "ProjectionExec.prepare" + // b. written by "projectionInputFetcher.run" + // c. read by "projectionWorker.run" + inputCh chan *projectionInput + outputCh chan *projectionOutput +} + +// run gets projectionWorker's input and output resources from its +// "inputCh" and "outputCh" channel, once the input and output resources are +// abtained, it calculate the projection result use "input.chk" as the input +// and "output.chk" as the output, once the calculation is done, it: +// +// a. Sends "nil" or error to "output.done" to mark this input is finished. +// b. Returns the "input" resource to "projectionInputFetcher.inputCh". +// +// It is finished and exited once: +// +// a. "ProjectionExec" closes the "globalFinishCh". 
+func (w *projectionWorker) run(ctx context.Context) { + defer trace.StartRegion(ctx, "ProjectionWorker").End() + var output *projectionOutput + defer func() { + if r := recover(); r != nil { + recoveryProjection(output, r) + } + w.proj.wg.Done() + }() + for { + input, isNil := readProjection[*projectionInput](w.inputCh, w.globalFinishCh) + if isNil { + return + } + + output, isNil = readProjection[*projectionOutput](w.outputCh, w.globalFinishCh) + if isNil { + return + } + + mSize := output.chk.MemoryUsage() + input.chk.MemoryUsage() + err := w.evaluatorSuit.Run(w.ctx.evalCtx, w.ctx.enableVectorizedExpression, input.chk, output.chk) + failpoint.Inject("ConsumeRandomPanic", nil) + w.proj.memTracker.Consume(output.chk.MemoryUsage() + input.chk.MemoryUsage() - mSize) + output.done <- err + + if err != nil { + return + } + + w.inputGiveBackCh <- input + } +} + +func recoveryProjection(output *projectionOutput, r any) { + if output != nil { + output.done <- util.GetRecoverError(r) + } + logutil.BgLogger().Error("projection executor panicked", zap.String("error", fmt.Sprintf("%v", r)), zap.Stack("stack")) +} + +func readProjection[T any](ch <-chan T, finishCh <-chan struct{}) (t T, isNil bool) { + select { + case <-finishCh: + return t, true + case t, ok := <-ch: + if !ok { + return t, true + } + return t, false + } +} diff --git a/pkg/executor/shuffle.go b/pkg/executor/shuffle.go index 9c0ee0f8050b2..fde7eb34e5d28 100644 --- a/pkg/executor/shuffle.go +++ b/pkg/executor/shuffle.go @@ -231,11 +231,11 @@ func (e *ShuffleExec) Next(ctx context.Context, req *chunk.Chunk) error { e.prepared = true } - failpoint.Inject("shuffleError", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("shuffleError")); _err_ == nil { if val.(bool) { - failpoint.Return(errors.New("ShuffleExec.Next error")) + return errors.New("ShuffleExec.Next error") } - }) + } if e.executed { return nil @@ -279,12 +279,12 @@ func (e *ShuffleExec) fetchDataAndSplit(ctx context.Context, dataSourceIndex int waitGroup.Done() }() - failpoint.Inject("shuffleExecFetchDataAndSplit", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("shuffleExecFetchDataAndSplit")); _err_ == nil { if val.(bool) { time.Sleep(100 * time.Millisecond) panic("shuffleExecFetchDataAndSplitPanic") } - }) + } for { err = exec.Next(ctx, e.dataSources[dataSourceIndex], chk) @@ -400,7 +400,7 @@ func (e *shuffleWorker) run(ctx context.Context, waitGroup *sync.WaitGroup) { waitGroup.Done() }() - failpoint.Inject("shuffleWorkerRun", nil) + failpoint.Eval(_curpkg_("shuffleWorkerRun")) for { select { case <-e.finishCh: diff --git a/pkg/executor/shuffle.go__failpoint_stash__ b/pkg/executor/shuffle.go__failpoint_stash__ new file mode 100644 index 0000000000000..9c0ee0f8050b2 --- /dev/null +++ b/pkg/executor/shuffle.go__failpoint_stash__ @@ -0,0 +1,492 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package executor + +import ( + "context" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/executor/aggregate" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/executor/internal/vecgroupchecker" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/channel" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/twmb/murmur3" + "go.uber.org/zap" +) + +// ShuffleExec is the executor to run other executors in a parallel manner. +// +// 1. It fetches chunks from M `DataSources` (value of M depends on the actual executor, e.g. M = 1 for WindowExec, M = 2 for MergeJoinExec). +// +// 2. It splits tuples from each `DataSource` into N partitions (Only "split by hash" is implemented so far). +// +// 3. It invokes N workers in parallel, each one has M `receiver` to receive partitions from `DataSources` +// +// 4. It assigns partitions received as input to each worker and executes child executors. +// +// 5. It collects outputs from each worker, then sends outputs to its parent. +// +// +-------------+ +// +-------| Main Thread | +// | +------+------+ +// | ^ +// | | +// | + +// v +++ +// outputHolderCh | | outputCh (1 x Concurrency) +// v +++ +// | ^ +// | | +// | +-------+-------+ +// v | | +// +--------------+ +--------------+ +// +----- | worker | ....... | worker | worker (N Concurrency): child executor, eg. WindowExec (+SortExec) +// | +------------+-+ +-+------------+ +// | ^ ^ +// | | | +// | +-+ +-+ ...... +-+ +// | | | | | | | +// | ... ... ... inputCh (Concurrency x 1) +// v | | | | | | +// inputHolderCh +++ +++ +++ +// v ^ ^ ^ +// | | | | +// | +------o----+ | +// | | +-----------------+-----+ +// | | | +// | +---+------------+------------+----+-----------+ +// | | Partition Splitter | +// | +--------------+-+------------+-+--------------+ +// | ^ +// | | +// | +---------------v-----------------+ +// +----------> | fetch data from DataSource | +// +---------------------------------+ +type ShuffleExec struct { + exec.BaseExecutor + concurrency int + workers []*shuffleWorker + + prepared bool + executed bool + + // each dataSource has a corresponding spliter + splitters []partitionSplitter + dataSources []exec.Executor + + finishCh chan struct{} + outputCh chan *shuffleOutput +} + +type shuffleOutput struct { + chk *chunk.Chunk + err error + giveBackCh chan *chunk.Chunk +} + +// Open implements the Executor Open interface. 
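// The ShuffleExec comment above explains that tuples from each data source
// are split into N partitions, currently only by hash. A minimal sketch of
// that split step, reusing the murmur3 hash that partitionHashSplitter
// applies later in this file; the plain string keys here are a stand-in for
// the encoded group-by key of each row.
package main

import (
	"fmt"

	"github.com/twmb/murmur3"
)

// splitByHash maps every key to one of numWorkers partitions.
func splitByHash(keys []string, numWorkers int) []int {
	workerIndices := make([]int, 0, len(keys))
	for _, k := range keys {
		idx := int(murmur3.Sum32([]byte(k))) % numWorkers
		workerIndices = append(workerIndices, idx)
	}
	return workerIndices
}

func main() {
	keys := []string{"a", "b", "a", "c", "b"}
	// Rows with equal keys always land on the same worker; for example both
	// "a" rows share one partition index.
	fmt.Println(splitByHash(keys, 4))
}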
+func (e *ShuffleExec) Open(ctx context.Context) error { + for _, s := range e.dataSources { + if err := exec.Open(ctx, s); err != nil { + return err + } + } + if err := e.BaseExecutor.Open(ctx); err != nil { + return err + } + + e.prepared = false + e.finishCh = make(chan struct{}, 1) + e.outputCh = make(chan *shuffleOutput, e.concurrency+len(e.dataSources)) + + for _, w := range e.workers { + w.finishCh = e.finishCh + + for _, r := range w.receivers { + r.inputCh = make(chan *chunk.Chunk, 1) + r.inputHolderCh = make(chan *chunk.Chunk, 1) + } + + w.outputCh = e.outputCh + w.outputHolderCh = make(chan *chunk.Chunk, 1) + + if err := exec.Open(ctx, w.childExec); err != nil { + return err + } + + for i, r := range w.receivers { + r.inputHolderCh <- exec.NewFirstChunk(e.dataSources[i]) + } + w.outputHolderCh <- exec.NewFirstChunk(e) + } + + return nil +} + +// Close implements the Executor Close interface. +func (e *ShuffleExec) Close() error { + var firstErr error + if !e.prepared { + for _, w := range e.workers { + for _, r := range w.receivers { + if r.inputHolderCh != nil { + close(r.inputHolderCh) + } + if r.inputCh != nil { + close(r.inputCh) + } + } + if w.outputHolderCh != nil { + close(w.outputHolderCh) + } + } + if e.outputCh != nil { + close(e.outputCh) + } + } + if e.finishCh != nil { + close(e.finishCh) + } + for _, w := range e.workers { + for _, r := range w.receivers { + if r.inputCh != nil { + channel.Clear(r.inputCh) + } + } + // close child executor of each worker + if err := exec.Close(w.childExec); err != nil && firstErr == nil { + firstErr = err + } + } + if e.outputCh != nil { + channel.Clear(e.outputCh) + } + e.executed = false + + if e.RuntimeStats() != nil { + runtimeStats := &execdetails.RuntimeStatsWithConcurrencyInfo{} + runtimeStats.SetConcurrencyInfo(execdetails.NewConcurrencyInfo("ShuffleConcurrency", e.concurrency)) + e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), runtimeStats) + } + + // close dataSources + for _, dataSource := range e.dataSources { + if err := exec.Close(dataSource); err != nil && firstErr == nil { + firstErr = err + } + } + // close baseExecutor + if err := e.BaseExecutor.Close(); err != nil && firstErr == nil { + firstErr = err + } + return errors.Trace(firstErr) +} + +func (e *ShuffleExec) prepare4ParallelExec(ctx context.Context) { + waitGroup := &sync.WaitGroup{} + waitGroup.Add(len(e.workers) + len(e.dataSources)) + // create a goroutine for each dataSource to fetch and split data + for i := range e.dataSources { + go e.fetchDataAndSplit(ctx, i, waitGroup) + } + + for _, w := range e.workers { + go w.run(ctx, waitGroup) + } + + go e.waitWorkerAndCloseOutput(waitGroup) +} + +func (e *ShuffleExec) waitWorkerAndCloseOutput(waitGroup *sync.WaitGroup) { + waitGroup.Wait() + close(e.outputCh) +} + +// Next implements the Executor Next interface. +func (e *ShuffleExec) Next(ctx context.Context, req *chunk.Chunk) error { + req.Reset() + if !e.prepared { + e.prepare4ParallelExec(ctx) + e.prepared = true + } + + failpoint.Inject("shuffleError", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(errors.New("ShuffleExec.Next error")) + } + }) + + if e.executed { + return nil + } + + result, ok := <-e.outputCh + if !ok { + e.executed = true + return nil + } + if result.err != nil { + return result.err + } + req.SwapColumns(result.chk) // `shuffleWorker` will not send an empty `result.chk` to `e.outputCh`. 
+ result.giveBackCh <- result.chk + + return nil +} + +func recoveryShuffleExec(output chan *shuffleOutput, r any) { + err := util.GetRecoverError(r) + output <- &shuffleOutput{err: util.GetRecoverError(r)} + logutil.BgLogger().Error("shuffle panicked", zap.Error(err), zap.Stack("stack")) +} + +func (e *ShuffleExec) fetchDataAndSplit(ctx context.Context, dataSourceIndex int, waitGroup *sync.WaitGroup) { + var ( + err error + workerIndices []int + ) + results := make([]*chunk.Chunk, len(e.workers)) + chk := exec.TryNewCacheChunk(e.dataSources[dataSourceIndex]) + + defer func() { + if r := recover(); r != nil { + recoveryShuffleExec(e.outputCh, r) + } + for _, w := range e.workers { + close(w.receivers[dataSourceIndex].inputCh) + } + waitGroup.Done() + }() + + failpoint.Inject("shuffleExecFetchDataAndSplit", func(val failpoint.Value) { + if val.(bool) { + time.Sleep(100 * time.Millisecond) + panic("shuffleExecFetchDataAndSplitPanic") + } + }) + + for { + err = exec.Next(ctx, e.dataSources[dataSourceIndex], chk) + if err != nil { + e.outputCh <- &shuffleOutput{err: err} + return + } + if chk.NumRows() == 0 { + break + } + + workerIndices, err = e.splitters[dataSourceIndex].split(e.Ctx(), chk, workerIndices) + if err != nil { + e.outputCh <- &shuffleOutput{err: err} + return + } + numRows := chk.NumRows() + for i := 0; i < numRows; i++ { + workerIdx := workerIndices[i] + w := e.workers[workerIdx] + + if results[workerIdx] == nil { + select { + case <-e.finishCh: + return + case results[workerIdx] = <-w.receivers[dataSourceIndex].inputHolderCh: + //nolint: revive + break + } + } + results[workerIdx].AppendRow(chk.GetRow(i)) + if results[workerIdx].IsFull() { + w.receivers[dataSourceIndex].inputCh <- results[workerIdx] + results[workerIdx] = nil + } + } + } + for i, w := range e.workers { + if results[i] != nil { + w.receivers[dataSourceIndex].inputCh <- results[i] + results[i] = nil + } + } +} + +var _ exec.Executor = &shuffleReceiver{} + +// shuffleReceiver receives chunk from dataSource through inputCh +type shuffleReceiver struct { + exec.BaseExecutor + + finishCh <-chan struct{} + executed bool + + inputCh chan *chunk.Chunk + inputHolderCh chan *chunk.Chunk +} + +// Open implements the Executor Open interface. +func (e *shuffleReceiver) Open(ctx context.Context) error { + if err := e.BaseExecutor.Open(ctx); err != nil { + return err + } + e.executed = false + return nil +} + +// Close implements the Executor Close interface. +func (e *shuffleReceiver) Close() error { + return errors.Trace(e.BaseExecutor.Close()) +} + +// Next implements the Executor Next interface. +// It is called by `Tail` executor within "shuffle", to fetch data from `DataSource` by `inputCh`. +func (e *shuffleReceiver) Next(_ context.Context, req *chunk.Chunk) error { + req.Reset() + if e.executed { + return nil + } + select { + case <-e.finishCh: + e.executed = true + return nil + case result, ok := <-e.inputCh: + if !ok || result.NumRows() == 0 { + e.executed = true + return nil + } + req.SwapColumns(result) + e.inputHolderCh <- result + return nil + } +} + +// shuffleWorker is the multi-thread worker executing child executors within "partition". 
+type shuffleWorker struct { + childExec exec.Executor + + finishCh <-chan struct{} + + // each receiver corresponse to a dataSource + receivers []*shuffleReceiver + + outputCh chan *shuffleOutput + outputHolderCh chan *chunk.Chunk +} + +func (e *shuffleWorker) run(ctx context.Context, waitGroup *sync.WaitGroup) { + defer func() { + if r := recover(); r != nil { + recoveryShuffleExec(e.outputCh, r) + } + waitGroup.Done() + }() + + failpoint.Inject("shuffleWorkerRun", nil) + for { + select { + case <-e.finishCh: + return + case chk := <-e.outputHolderCh: + if err := exec.Next(ctx, e.childExec, chk); err != nil { + e.outputCh <- &shuffleOutput{err: err} + return + } + + // Should not send an empty `chk` to `e.outputCh`. + if chk.NumRows() == 0 { + return + } + e.outputCh <- &shuffleOutput{chk: chk, giveBackCh: e.outputHolderCh} + } + } +} + +var _ partitionSplitter = &partitionHashSplitter{} +var _ partitionSplitter = &partitionRangeSplitter{} + +type partitionSplitter interface { + split(ctx sessionctx.Context, input *chunk.Chunk, workerIndices []int) ([]int, error) +} + +type partitionHashSplitter struct { + byItems []expression.Expression + numWorkers int + hashKeys [][]byte +} + +func (s *partitionHashSplitter) split(ctx sessionctx.Context, input *chunk.Chunk, workerIndices []int) ([]int, error) { + var err error + s.hashKeys, err = aggregate.GetGroupKey(ctx, input, s.hashKeys, s.byItems) + if err != nil { + return workerIndices, err + } + workerIndices = workerIndices[:0] + numRows := input.NumRows() + for i := 0; i < numRows; i++ { + workerIndices = append(workerIndices, int(murmur3.Sum32(s.hashKeys[i]))%s.numWorkers) + } + return workerIndices, nil +} + +func buildPartitionHashSplitter(concurrency int, byItems []expression.Expression) *partitionHashSplitter { + return &partitionHashSplitter{ + byItems: byItems, + numWorkers: concurrency, + } +} + +type partitionRangeSplitter struct { + byItems []expression.Expression + numWorkers int + groupChecker *vecgroupchecker.VecGroupChecker + idx int +} + +func buildPartitionRangeSplitter(ctx sessionctx.Context, concurrency int, byItems []expression.Expression) *partitionRangeSplitter { + return &partitionRangeSplitter{ + byItems: byItems, + numWorkers: concurrency, + groupChecker: vecgroupchecker.NewVecGroupChecker(ctx.GetExprCtx().GetEvalCtx(), ctx.GetSessionVars().EnableVectorizedExpression, byItems), + idx: 0, + } +} + +// This method is supposed to be used for shuffle with sorted `dataSource` +// the caller of this method should guarantee that `input` is grouped, +// which means that rows with the same byItems should be continuous, the order does not matter. 
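// The comment above states the contract of the range splitter: the input is
// already grouped on the by-items, so the splitter only has to find group
// boundaries and hand out whole groups to workers round-robin. A minimal
// sketch of that idea on a pre-grouped slice of keys, for illustration only;
// the real implementation, shown right below, detects boundaries with a
// VecGroupChecker over chunks.
package main

import "fmt"

// splitGroupsRoundRobin assigns one worker index per row, keeping each run of
// equal keys on the same worker and rotating workers between runs.
func splitGroupsRoundRobin(groupedKeys []string, numWorkers int) []int {
	indices := make([]int, 0, len(groupedKeys))
	worker := 0
	for i, k := range groupedKeys {
		if i > 0 && k != groupedKeys[i-1] {
			worker = (worker + 1) % numWorkers // new group -> next worker
		}
		indices = append(indices, worker)
	}
	return indices
}

func main() {
	// The input must already be grouped: equal keys are adjacent.
	keys := []string{"a", "a", "b", "c", "c", "c"}
	fmt.Println(splitGroupsRoundRobin(keys, 2)) // [0 0 1 0 0 0]
}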
+func (s *partitionRangeSplitter) split(_ sessionctx.Context, input *chunk.Chunk, workerIndices []int) ([]int, error) { + _, err := s.groupChecker.SplitIntoGroups(input) + if err != nil { + return workerIndices, err + } + + workerIndices = workerIndices[:0] + for !s.groupChecker.IsExhausted() { + begin, end := s.groupChecker.GetNextGroup() + for i := begin; i < end; i++ { + workerIndices = append(workerIndices, s.idx) + } + s.idx = (s.idx + 1) % s.numWorkers + } + + return workerIndices, nil +} diff --git a/pkg/executor/slow_query.go b/pkg/executor/slow_query.go index 203d72d8f0e61..87da1791d0b19 100644 --- a/pkg/executor/slow_query.go +++ b/pkg/executor/slow_query.go @@ -468,13 +468,13 @@ func (e *slowQueryRetriever) parseSlowLog(ctx context.Context, sctx sessionctx.C if e.stats != nil { e.stats.readFile += time.Since(startTime) } - failpoint.Inject("mockReadSlowLogSlow", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockReadSlowLogSlow")); _err_ == nil { if val.(bool) { signals := ctx.Value(signalsKey{}).([]chan int) signals[0] <- 1 <-signals[1] } - }) + } for i := range logs { log := logs[i] t := slowLogTask{} @@ -636,11 +636,11 @@ func (e *slowQueryRetriever) parseLog(ctx context.Context, sctx sessionctx.Conte } }() e.memConsume(logSize) - failpoint.Inject("errorMockParseSlowLogPanic", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("errorMockParseSlowLogPanic")); _err_ == nil { if val.(bool) { panic("panic test") } - }) + } var row []types.Datum user := "" tz := sctx.GetSessionVars().Location() diff --git a/pkg/executor/slow_query.go__failpoint_stash__ b/pkg/executor/slow_query.go__failpoint_stash__ new file mode 100644 index 0000000000000..203d72d8f0e61 --- /dev/null +++ b/pkg/executor/slow_query.go__failpoint_stash__ @@ -0,0 +1,1259 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
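// Earlier hunks in this patch rewrite failpoint.InjectContext markers into
// failpoint.EvalContext calls (for example the pointGetRepeatableReadTest
// steps in point_get.go), so a test can scope a failpoint to one request via
// its context. A minimal sketch of driving such a context-scoped failpoint,
// assuming the public Enable/Disable/WithHook/EvalContext functions of
// github.com/pingcap/failpoint; the name "demo/ctxPoint" is invented for
// illustration.
package main

import (
	"context"
	"fmt"

	"github.com/pingcap/failpoint"
)

// handle mirrors the EvalContext call sites in this patch: the failpoint is
// consulted through the request context rather than globally.
func handle(ctx context.Context) string {
	if _, err := failpoint.EvalContext(ctx, "demo/ctxPoint"); err == nil {
		return "injected path"
	}
	return "normal path"
}

func main() {
	// Enable the failpoint and attach a permissive hook to the context; the
	// hook is what scopes the failpoint to requests carrying this context.
	if err := failpoint.Enable("demo/ctxPoint", `return(true)`); err != nil {
		panic(err)
	}
	defer func() { _ = failpoint.Disable("demo/ctxPoint") }()

	ctx := failpoint.WithHook(context.Background(), func(_ context.Context, _ string) bool {
		return true
	})
	fmt.Println(handle(ctx)) // injected path
}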
+ +package executor + +import ( + "bufio" + "compress/gzip" + "context" + "fmt" + "io" + "os" + "path/filepath" + "runtime" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/parser/auth" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/privilege" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/hack" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/plancodec" + "go.uber.org/zap" +) + +type signalsKey struct{} + +// ParseSlowLogBatchSize is the batch size of slow-log lines for a worker to parse, exported for testing. +var ParseSlowLogBatchSize = 64 + +// slowQueryRetriever is used to read slow log data. +type slowQueryRetriever struct { + table *model.TableInfo + outputCols []*model.ColumnInfo + initialized bool + extractor *plannercore.SlowQueryExtractor + files []logFile + fileIdx int + fileLine int + checker *slowLogChecker + columnValueFactoryMap map[string]slowQueryColumnValueFactory + instanceFactory func([]types.Datum) + + taskList chan slowLogTask + stats *slowQueryRuntimeStats + memTracker *memory.Tracker + lastFetchSize int64 + cancel context.CancelFunc + wg sync.WaitGroup +} + +func (e *slowQueryRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { + if !e.initialized { + err := e.initialize(ctx, sctx) + if err != nil { + return nil, err + } + ctx, e.cancel = context.WithCancel(ctx) + e.initializeAsyncParsing(ctx, sctx) + } + return e.dataForSlowLog(ctx) +} + +func (e *slowQueryRetriever) initialize(ctx context.Context, sctx sessionctx.Context) error { + var err error + var hasProcessPriv bool + if pm := privilege.GetPrivilegeManager(sctx); pm != nil { + hasProcessPriv = pm.RequestVerification(sctx.GetSessionVars().ActiveRoles, "", "", "", mysql.ProcessPriv) + } + // initialize column value factories. + e.columnValueFactoryMap = make(map[string]slowQueryColumnValueFactory, len(e.outputCols)) + for idx, col := range e.outputCols { + if col.Name.O == util.ClusterTableInstanceColumnName { + e.instanceFactory, err = getInstanceColumnValueFactory(sctx, idx) + if err != nil { + return err + } + continue + } + factory, err := getColumnValueFactoryByName(col.Name.O, idx) + if err != nil { + return err + } + if factory == nil { + panic(fmt.Sprintf("should never happen, should register new column %v into getColumnValueFactoryByName function", col.Name.O)) + } + e.columnValueFactoryMap[col.Name.O] = factory + } + // initialize checker. 
+ e.checker = &slowLogChecker{ + hasProcessPriv: hasProcessPriv, + user: sctx.GetSessionVars().User, + } + e.stats = &slowQueryRuntimeStats{} + if e.extractor != nil { + e.checker.enableTimeCheck = e.extractor.Enable + for _, tr := range e.extractor.TimeRanges { + startTime := types.NewTime(types.FromGoTime(tr.StartTime.In(sctx.GetSessionVars().Location())), mysql.TypeDatetime, types.MaxFsp) + endTime := types.NewTime(types.FromGoTime(tr.EndTime.In(sctx.GetSessionVars().Location())), mysql.TypeDatetime, types.MaxFsp) + timeRange := &timeRange{ + startTime: startTime, + endTime: endTime, + } + e.checker.timeRanges = append(e.checker.timeRanges, timeRange) + } + } else { + e.extractor = &plannercore.SlowQueryExtractor{} + } + e.initialized = true + e.files, err = e.getAllFiles(ctx, sctx, sctx.GetSessionVars().SlowQueryFile) + if e.extractor.Desc { + slices.Reverse(e.files) + } + return err +} + +func (e *slowQueryRetriever) close() error { + for _, f := range e.files { + err := f.file.Close() + if err != nil { + logutil.BgLogger().Error("close slow log file failed.", zap.Error(err)) + } + } + if e.cancel != nil { + e.cancel() + } + e.wg.Wait() + return nil +} + +type parsedSlowLog struct { + rows [][]types.Datum + err error +} + +func (e *slowQueryRetriever) getNextFile() *logFile { + if e.fileIdx >= len(e.files) { + return nil + } + ret := &e.files[e.fileIdx] + file := e.files[e.fileIdx].file + e.fileIdx++ + if e.stats != nil { + stat, err := file.Stat() + if err == nil { + // ignore the err will be ok. + e.stats.readFileSize += stat.Size() + e.stats.readFileNum++ + } + } + return ret +} + +func (e *slowQueryRetriever) getPreviousReader() (*bufio.Reader, error) { + fileIdx := e.fileIdx + // fileIdx refer to the next file which should be read + // so we need to set fileIdx to fileIdx - 2 to get the previous file. 
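	// For example, if fileIdx is 3, then files[2] is the file currently being
	// read (getNextFile has already advanced the index), so the previous file
	// is files[1], i.e. files[fileIdx-2].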
+ fileIdx = fileIdx - 2 + if fileIdx < 0 { + return nil, nil + } + file := e.files[fileIdx] + _, err := file.file.Seek(0, io.SeekStart) + if err != nil { + return nil, err + } + var reader *bufio.Reader + if !file.compressed { + reader = bufio.NewReader(file.file) + } else { + gr, err := gzip.NewReader(file.file) + if err != nil { + return nil, err + } + reader = bufio.NewReader(gr) + } + return reader, nil +} + +func (e *slowQueryRetriever) getNextReader() (*bufio.Reader, error) { + file := e.getNextFile() + if file == nil { + return nil, nil + } + var reader *bufio.Reader + if !file.compressed { + reader = bufio.NewReader(file.file) + } else { + gr, err := gzip.NewReader(file.file) + if err != nil { + return nil, err + } + reader = bufio.NewReader(gr) + } + return reader, nil +} + +func (e *slowQueryRetriever) parseDataForSlowLog(ctx context.Context, sctx sessionctx.Context) { + defer e.wg.Done() + reader, _ := e.getNextReader() + if reader == nil { + close(e.taskList) + return + } + e.parseSlowLog(ctx, sctx, reader, ParseSlowLogBatchSize) +} + +func (e *slowQueryRetriever) dataForSlowLog(ctx context.Context) ([][]types.Datum, error) { + var ( + task slowLogTask + ok bool + ) + e.memConsume(-e.lastFetchSize) + e.lastFetchSize = 0 + for { + select { + case task, ok = <-e.taskList: + case <-ctx.Done(): + return nil, ctx.Err() + } + if !ok { + return nil, nil + } + result := <-task.resultCh + rows, err := result.rows, result.err + if err != nil { + return nil, err + } + if len(rows) == 0 { + continue + } + if e.instanceFactory != nil { + for i := range rows { + e.instanceFactory(rows[i]) + } + } + e.lastFetchSize = calculateDatumsSize(rows) + return rows, nil + } +} + +type slowLogChecker struct { + // Below fields is used to check privilege. + hasProcessPriv bool + user *auth.UserIdentity + // Below fields is used to check slow log time valid. 
+ enableTimeCheck bool + timeRanges []*timeRange +} + +type timeRange struct { + startTime types.Time + endTime types.Time +} + +func (sc *slowLogChecker) hasPrivilege(userName string) bool { + return sc.hasProcessPriv || sc.user == nil || userName == sc.user.Username +} + +func (sc *slowLogChecker) isTimeValid(t types.Time) bool { + for _, tr := range sc.timeRanges { + if sc.enableTimeCheck && (t.Compare(tr.startTime) >= 0 && t.Compare(tr.endTime) <= 0) { + return true + } + } + return !sc.enableTimeCheck +} + +func getOneLine(reader *bufio.Reader) ([]byte, error) { + return util.ReadLine(reader, int(variable.MaxOfMaxAllowedPacket)) +} + +type offset struct { + offset int + length int +} + +type slowLogTask struct { + resultCh chan parsedSlowLog +} + +type slowLogBlock []string + +func (e *slowQueryRetriever) getBatchLog(ctx context.Context, reader *bufio.Reader, offset *offset, num int) ([][]string, error) { + var line string + log := make([]string, 0, num) + var err error + for i := 0; i < num; i++ { + for { + if isCtxDone(ctx) { + return nil, ctx.Err() + } + e.fileLine++ + lineByte, err := getOneLine(reader) + if err != nil { + if err == io.EOF { + e.fileLine = 0 + newReader, err := e.getNextReader() + if newReader == nil || err != nil { + return [][]string{log}, err + } + offset.length = len(log) + reader.Reset(newReader) + continue + } + return [][]string{log}, err + } + line = string(hack.String(lineByte)) + log = append(log, line) + if strings.HasSuffix(line, variable.SlowLogSQLSuffixStr) { + if strings.HasPrefix(line, "use") || strings.HasPrefix(line, variable.SlowLogRowPrefixStr) { + continue + } + break + } + } + } + return [][]string{log}, err +} + +func (e *slowQueryRetriever) getBatchLogForReversedScan(ctx context.Context, reader *bufio.Reader, offset *offset, num int) ([][]string, error) { + // reader maybe change when read previous file. + inputReader := reader + defer func() { + newReader, _ := e.getNextReader() + if newReader != nil { + inputReader.Reset(newReader) + } + }() + var line string + var logs []slowLogBlock + var log []string + var err error + hasStartFlag := false + scanPreviousFile := false + for { + if isCtxDone(ctx) { + return nil, ctx.Err() + } + e.fileLine++ + lineByte, err := getOneLine(reader) + if err != nil { + if err == io.EOF { + if len(log) == 0 { + decomposedSlowLogTasks := decomposeToSlowLogTasks(logs, num) + offset.length = len(decomposedSlowLogTasks) + return decomposedSlowLogTasks, nil + } + e.fileLine = 0 + reader, err = e.getPreviousReader() + if reader == nil || err != nil { + return decomposeToSlowLogTasks(logs, num), nil + } + scanPreviousFile = true + continue + } + return nil, err + } + line = string(hack.String(lineByte)) + if !hasStartFlag && strings.HasPrefix(line, variable.SlowLogStartPrefixStr) { + hasStartFlag = true + } + if hasStartFlag { + log = append(log, line) + if strings.HasSuffix(line, variable.SlowLogSQLSuffixStr) { + if strings.HasPrefix(line, "use") || strings.HasPrefix(line, variable.SlowLogRowPrefixStr) { + continue + } + logs = append(logs, log) + if scanPreviousFile { + break + } + log = make([]string, 0, 8) + hasStartFlag = false + } + } + } + return decomposeToSlowLogTasks(logs, num), err +} + +func decomposeToSlowLogTasks(logs []slowLogBlock, num int) [][]string { + if len(logs) == 0 { + return nil + } + + //In reversed scan, We should reverse the blocks. 
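	// For example, blocks [b0, b1, b2] become [b2, b1, b0]: the swap below
	// reverses the slice in place before the blocks are concatenated into the
	// batches handed to the parse workers.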
+ last := len(logs) - 1 + for i := 0; i < len(logs)/2; i++ { + logs[i], logs[last-i] = logs[last-i], logs[i] + } + + decomposedSlowLogTasks := make([][]string, 0) + log := make([]string, 0, num*len(logs[0])) + for i := range logs { + log = append(log, logs[i]...) + if i > 0 && i%num == 0 { + decomposedSlowLogTasks = append(decomposedSlowLogTasks, log) + log = make([]string, 0, len(log)) + } + } + if len(log) > 0 { + decomposedSlowLogTasks = append(decomposedSlowLogTasks, log) + } + return decomposedSlowLogTasks +} + +func (e *slowQueryRetriever) parseSlowLog(ctx context.Context, sctx sessionctx.Context, reader *bufio.Reader, logNum int) { + defer close(e.taskList) + offset := offset{offset: 0, length: 0} + // To limit the num of go routine + concurrent := sctx.GetSessionVars().Concurrency.DistSQLScanConcurrency() + ch := make(chan int, concurrent) + if e.stats != nil { + e.stats.concurrent = concurrent + } + defer close(ch) + for { + startTime := time.Now() + var logs [][]string + var err error + if !e.extractor.Desc { + logs, err = e.getBatchLog(ctx, reader, &offset, logNum) + } else { + logs, err = e.getBatchLogForReversedScan(ctx, reader, &offset, logNum) + } + if err != nil { + t := slowLogTask{} + t.resultCh = make(chan parsedSlowLog, 1) + select { + case <-ctx.Done(): + return + case e.taskList <- t: + } + e.sendParsedSlowLogCh(t, parsedSlowLog{nil, err}) + } + if len(logs) == 0 || len(logs[0]) == 0 { + break + } + if e.stats != nil { + e.stats.readFile += time.Since(startTime) + } + failpoint.Inject("mockReadSlowLogSlow", func(val failpoint.Value) { + if val.(bool) { + signals := ctx.Value(signalsKey{}).([]chan int) + signals[0] <- 1 + <-signals[1] + } + }) + for i := range logs { + log := logs[i] + t := slowLogTask{} + t.resultCh = make(chan parsedSlowLog, 1) + start := offset + ch <- 1 + select { + case <-ctx.Done(): + return + case e.taskList <- t: + } + e.wg.Add(1) + go func() { + defer e.wg.Done() + result, err := e.parseLog(ctx, sctx, log, start) + e.sendParsedSlowLogCh(t, parsedSlowLog{result, err}) + <-ch + }() + offset.offset = e.fileLine + offset.length = 0 + select { + case <-ctx.Done(): + return + default: + } + } + } +} + +func (*slowQueryRetriever) sendParsedSlowLogCh(t slowLogTask, re parsedSlowLog) { + select { + case t.resultCh <- re: + default: + return + } +} + +func getLineIndex(offset offset, index int) int { + var fileLine int + if offset.length <= index { + fileLine = index - offset.length + 1 + } else { + fileLine = offset.offset + index + 1 + } + return fileLine +} + +// findMatchedRightBracket returns the rightBracket index which matchs line[leftBracketIdx] +// leftBracketIdx should be valid string index for line +// Returns -1 if invalid inputs are given +func findMatchedRightBracket(line string, leftBracketIdx int) int { + leftBracket := line[leftBracketIdx] + rightBracket := byte('}') + if leftBracket == '[' { + rightBracket = ']' + } else if leftBracket != '{' { + return -1 + } + lineLength := len(line) + current := leftBracketIdx + leftBracketCnt := 0 + for current < lineLength { + b := line[current] + if b == leftBracket { + leftBracketCnt++ + current++ + } else if b == rightBracket { + leftBracketCnt-- + if leftBracketCnt > 0 { + current++ + } else if leftBracketCnt == 0 { + if current+1 < lineLength && line[current+1] != ' ' { + return -1 + } + return current + } else { + return -1 + } + } else { + current++ + } + } + return -1 +} + +func isLetterOrNumeric(b byte) bool { + return ('A' <= b && b <= 'Z') || ('a' <= b && b <= 'z') || ('0' <= b && b <= 
'9') +} + +// splitByColon split a line like "field: value field: value..." +// Note: +// 1. field string's first character can only be ASCII letters or digits, and can't contain ':' +// 2. value string may be surrounded by brackets, allowed brackets includes "[]" and "{}", like {key: value,{key: value}} +// "[]" can only be nested inside "[]"; "{}" can only be nested inside "{}" +// 3. value string can't contain ' ' character unless it is inside brackets +func splitByColon(line string) (fields []string, values []string) { + fields = make([]string, 0, 1) + values = make([]string, 0, 1) + + lineLength := len(line) + parseKey := true + start := 0 + errMsg := "" + for current := 0; current < lineLength; { + if parseKey { + // Find key start + for current < lineLength && !isLetterOrNumeric(line[current]) { + current++ + } + start = current + if current >= lineLength { + break + } + for current < lineLength && line[current] != ':' { + current++ + } + fields = append(fields, line[start:current]) + parseKey = false + current += 2 // bypass ": " + } else { + start = current + if current < lineLength && (line[current] == '{' || line[current] == '[') { + rBraceIdx := findMatchedRightBracket(line, current) + if rBraceIdx == -1 { + errMsg = "Braces matched error" + break + } + current = rBraceIdx + 1 + } else { + for current < lineLength && line[current] != ' ' { + current++ + } + } + values = append(values, line[start:min(current, len(line))]) + parseKey = true + } + } + if len(errMsg) > 0 { + logutil.BgLogger().Warn("slow query parse slow log error", zap.String("Error", errMsg), zap.String("Log", line)) + return nil, nil + } + return fields, values +} + +func (e *slowQueryRetriever) parseLog(ctx context.Context, sctx sessionctx.Context, log []string, offset offset) (data [][]types.Datum, err error) { + start := time.Now() + logSize := calculateLogSize(log) + defer e.memConsume(-logSize) + defer func() { + if r := recover(); r != nil { + err = util.GetRecoverError(r) + buf := make([]byte, 4096) + stackSize := runtime.Stack(buf, false) + buf = buf[:stackSize] + logutil.BgLogger().Warn("slow query parse slow log panic", zap.Error(err), zap.String("stack", string(buf))) + } + if e.stats != nil { + atomic.AddInt64(&e.stats.parseLog, int64(time.Since(start))) + } + }() + e.memConsume(logSize) + failpoint.Inject("errorMockParseSlowLogPanic", func(val failpoint.Value) { + if val.(bool) { + panic("panic test") + } + }) + var row []types.Datum + user := "" + tz := sctx.GetSessionVars().Location() + startFlag := false + for index, line := range log { + if isCtxDone(ctx) { + return nil, ctx.Err() + } + fileLine := getLineIndex(offset, index) + if !startFlag && strings.HasPrefix(line, variable.SlowLogStartPrefixStr) { + row = make([]types.Datum, len(e.outputCols)) + user = "" + valid := e.setColumnValue(sctx, row, tz, variable.SlowLogTimeStr, line[len(variable.SlowLogStartPrefixStr):], e.checker, fileLine) + if valid { + startFlag = true + } + continue + } + if startFlag { + if strings.HasPrefix(line, variable.SlowLogRowPrefixStr) { + line = line[len(variable.SlowLogRowPrefixStr):] + valid := true + if strings.HasPrefix(line, variable.SlowLogPrevStmtPrefix) { + valid = e.setColumnValue(sctx, row, tz, variable.SlowLogPrevStmt, line[len(variable.SlowLogPrevStmtPrefix):], e.checker, fileLine) + } else if strings.HasPrefix(line, variable.SlowLogUserAndHostStr+variable.SlowLogSpaceMarkStr) { + value := line[len(variable.SlowLogUserAndHostStr+variable.SlowLogSpaceMarkStr):] + fields := strings.SplitN(value, "@", 2) + 
if len(fields) < 2 { + continue + } + user = parseUserOrHostValue(fields[0]) + if e.checker != nil && !e.checker.hasPrivilege(user) { + startFlag = false + continue + } + valid = e.setColumnValue(sctx, row, tz, variable.SlowLogUserStr, user, e.checker, fileLine) + if !valid { + startFlag = false + continue + } + host := parseUserOrHostValue(fields[1]) + valid = e.setColumnValue(sctx, row, tz, variable.SlowLogHostStr, host, e.checker, fileLine) + } else if strings.HasPrefix(line, variable.SlowLogCopBackoffPrefix) { + valid = e.setColumnValue(sctx, row, tz, variable.SlowLogBackoffDetail, line, e.checker, fileLine) + } else if strings.HasPrefix(line, variable.SlowLogWarnings) { + line = line[len(variable.SlowLogWarnings+variable.SlowLogSpaceMarkStr):] + valid = e.setColumnValue(sctx, row, tz, variable.SlowLogWarnings, line, e.checker, fileLine) + } else { + fields, values := splitByColon(line) + for i := 0; i < len(fields); i++ { + valid := e.setColumnValue(sctx, row, tz, fields[i], values[i], e.checker, fileLine) + if !valid { + startFlag = false + break + } + } + } + if !valid { + startFlag = false + } + } else if strings.HasSuffix(line, variable.SlowLogSQLSuffixStr) { + if strings.HasPrefix(line, "use") { + // `use DB` statements in the slow log is used to keep it be compatible with MySQL, + // since we already get the current DB from the `# DB` field, we can ignore it here, + // please see https://github.com/pingcap/tidb/issues/17846 for more details. + continue + } + if e.checker != nil && !e.checker.hasPrivilege(user) { + startFlag = false + continue + } + // Get the sql string, and mark the start flag to false. + _ = e.setColumnValue(sctx, row, tz, variable.SlowLogQuerySQLStr, string(hack.Slice(line)), e.checker, fileLine) + e.setDefaultValue(row) + e.memConsume(types.EstimatedMemUsage(row, 1)) + data = append(data, row) + startFlag = false + } else { + startFlag = false + } + } + } + return data, nil +} + +func (e *slowQueryRetriever) setColumnValue(sctx sessionctx.Context, row []types.Datum, tz *time.Location, field, value string, checker *slowLogChecker, lineNum int) bool { + factory := e.columnValueFactoryMap[field] + if factory == nil { + // Fix issue 34320, when slow log time is not in the output columns, the time filter condition is mistakenly discard. 
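	// So even though there is no value factory for the Time column here (it is
	// not among the selected output columns), still parse the timestamp and
	// keep the row only if it falls inside the requested time ranges.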
+ if field == variable.SlowLogTimeStr && checker != nil { + t, err := ParseTime(value) + if err != nil { + err = fmt.Errorf("Parse slow log at line %v, failed field is %v, failed value is %v, error is %v", lineNum, field, value, err) + sctx.GetSessionVars().StmtCtx.AppendWarning(err) + return false + } + timeValue := types.NewTime(types.FromGoTime(t), mysql.TypeTimestamp, types.MaxFsp) + return checker.isTimeValid(timeValue) + } + return true + } + valid, err := factory(row, value, tz, checker) + if err != nil { + err = fmt.Errorf("Parse slow log at line %v, failed field is %v, failed value is %v, error is %v", lineNum, field, value, err) + sctx.GetSessionVars().StmtCtx.AppendWarning(err) + return true + } + return valid +} + +func (e *slowQueryRetriever) setDefaultValue(row []types.Datum) { + for i := range row { + if !row[i].IsNull() { + continue + } + row[i] = table.GetZeroValue(e.outputCols[i]) + } +} + +type slowQueryColumnValueFactory func(row []types.Datum, value string, _ *time.Location, _ *slowLogChecker) (valid bool, err error) + +func parseUserOrHostValue(value string) string { + // the new User&Host format: root[root] @ localhost [127.0.0.1] + tmp := strings.Split(value, "[") + return strings.TrimSpace(tmp[0]) +} + +func getColumnValueFactoryByName(colName string, columnIdx int) (slowQueryColumnValueFactory, error) { + switch colName { + case variable.SlowLogTimeStr: + return func(row []types.Datum, value string, tz *time.Location, checker *slowLogChecker) (bool, error) { + t, err := ParseTime(value) + if err != nil { + return false, err + } + timeValue := types.NewTime(types.FromGoTime(t.In(tz)), mysql.TypeTimestamp, types.MaxFsp) + if checker != nil { + valid := checker.isTimeValid(timeValue) + if !valid { + return valid, nil + } + } + row[columnIdx] = types.NewTimeDatum(timeValue) + return true, nil + }, nil + case variable.SlowLogBackoffDetail: + return func(row []types.Datum, value string, _ *time.Location, _ *slowLogChecker) (bool, error) { + backoffDetail := row[columnIdx].GetString() + if len(backoffDetail) > 0 { + backoffDetail += " " + } + backoffDetail += value + row[columnIdx] = types.NewStringDatum(backoffDetail) + return true, nil + }, nil + case variable.SlowLogPlan: + return func(row []types.Datum, value string, _ *time.Location, _ *slowLogChecker) (bool, error) { + plan := parsePlan(value) + row[columnIdx] = types.NewStringDatum(plan) + return true, nil + }, nil + case variable.SlowLogBinaryPlan: + return func(row []types.Datum, value string, _ *time.Location, _ *slowLogChecker) (bool, error) { + if strings.HasPrefix(value, variable.SlowLogBinaryPlanPrefix) { + value = value[len(variable.SlowLogBinaryPlanPrefix) : len(value)-len(variable.SlowLogPlanSuffix)] + } + row[columnIdx] = types.NewStringDatum(value) + return true, nil + }, nil + case variable.SlowLogConnIDStr, variable.SlowLogExecRetryCount, variable.SlowLogPreprocSubQueriesStr, + execdetails.WriteKeysStr, execdetails.WriteSizeStr, execdetails.PrewriteRegionStr, execdetails.TxnRetryStr, + execdetails.RequestCountStr, execdetails.TotalKeysStr, execdetails.ProcessKeysStr, + execdetails.RocksdbDeleteSkippedCountStr, execdetails.RocksdbKeySkippedCountStr, + execdetails.RocksdbBlockCacheHitCountStr, execdetails.RocksdbBlockReadCountStr, + variable.SlowLogTxnStartTSStr, execdetails.RocksdbBlockReadByteStr: + return func(row []types.Datum, value string, _ *time.Location, _ *slowLogChecker) (valid bool, err error) { + v, err := strconv.ParseUint(value, 10, 64) + if err != nil { + return false, err + } + 
row[columnIdx] = types.NewUintDatum(v) + return true, nil + }, nil + case variable.SlowLogExecRetryTime, variable.SlowLogQueryTimeStr, variable.SlowLogParseTimeStr, + variable.SlowLogCompileTimeStr, variable.SlowLogRewriteTimeStr, variable.SlowLogPreProcSubQueryTimeStr, + variable.SlowLogOptimizeTimeStr, variable.SlowLogWaitTSTimeStr, execdetails.PreWriteTimeStr, + execdetails.WaitPrewriteBinlogTimeStr, execdetails.CommitTimeStr, execdetails.GetCommitTSTimeStr, + execdetails.CommitBackoffTimeStr, execdetails.ResolveLockTimeStr, execdetails.LocalLatchWaitTimeStr, + execdetails.CopTimeStr, execdetails.ProcessTimeStr, execdetails.WaitTimeStr, execdetails.BackoffTimeStr, + execdetails.LockKeysTimeStr, variable.SlowLogCopProcAvg, variable.SlowLogCopProcP90, variable.SlowLogCopProcMax, + variable.SlowLogCopWaitAvg, variable.SlowLogCopWaitP90, variable.SlowLogCopWaitMax, variable.SlowLogKVTotal, + variable.SlowLogPDTotal, variable.SlowLogBackoffTotal, variable.SlowLogWriteSQLRespTotal, variable.SlowLogRRU, + variable.SlowLogWRU, variable.SlowLogWaitRUDuration: + return func(row []types.Datum, value string, _ *time.Location, _ *slowLogChecker) (valid bool, err error) { + v, err := strconv.ParseFloat(value, 64) + if err != nil { + return false, err + } + row[columnIdx] = types.NewFloat64Datum(v) + return true, nil + }, nil + case variable.SlowLogUserStr, variable.SlowLogHostStr, execdetails.BackoffTypesStr, variable.SlowLogDBStr, variable.SlowLogIndexNamesStr, variable.SlowLogDigestStr, + variable.SlowLogStatsInfoStr, variable.SlowLogCopProcAddr, variable.SlowLogCopWaitAddr, variable.SlowLogPlanDigest, + variable.SlowLogPrevStmt, variable.SlowLogQuerySQLStr, variable.SlowLogWarnings, variable.SlowLogSessAliasStr, + variable.SlowLogResourceGroup: + return func(row []types.Datum, value string, _ *time.Location, _ *slowLogChecker) (valid bool, err error) { + row[columnIdx] = types.NewStringDatum(value) + return true, nil + }, nil + case variable.SlowLogMemMax, variable.SlowLogDiskMax, variable.SlowLogResultRows: + return func(row []types.Datum, value string, _ *time.Location, _ *slowLogChecker) (valid bool, err error) { + v, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return false, err + } + row[columnIdx] = types.NewIntDatum(v) + return true, nil + }, nil + case variable.SlowLogPrepared, variable.SlowLogSucc, variable.SlowLogPlanFromCache, variable.SlowLogPlanFromBinding, + variable.SlowLogIsInternalStr, variable.SlowLogIsExplicitTxn, variable.SlowLogIsWriteCacheTable, variable.SlowLogHasMoreResults: + return func(row []types.Datum, value string, _ *time.Location, _ *slowLogChecker) (valid bool, err error) { + v, err := strconv.ParseBool(value) + if err != nil { + return false, err + } + row[columnIdx] = types.NewDatum(v) + return true, nil + }, nil + } + return nil, nil +} + +func getInstanceColumnValueFactory(sctx sessionctx.Context, columnIdx int) (func(row []types.Datum), error) { + instanceAddr, err := infoschema.GetInstanceAddr(sctx) + if err != nil { + return nil, err + } + return func(row []types.Datum) { + row[columnIdx] = types.NewStringDatum(instanceAddr) + }, nil +} + +func parsePlan(planString string) string { + if len(planString) <= len(variable.SlowLogPlanPrefix)+len(variable.SlowLogPlanSuffix) { + return planString + } + planString = planString[len(variable.SlowLogPlanPrefix) : len(planString)-len(variable.SlowLogPlanSuffix)] + decodePlanString, err := plancodec.DecodePlan(planString) + if err == nil { + planString = decodePlanString + } else { + 
logutil.BgLogger().Error("decode plan in slow log failed", zap.String("plan", planString), zap.Error(err)) + } + return planString +} + +// ParseTime exports for testing. +func ParseTime(s string) (time.Time, error) { + t, err := time.Parse(logutil.SlowLogTimeFormat, s) + if err != nil { + // This is for compatibility. + t, err = time.Parse(logutil.OldSlowLogTimeFormat, s) + if err != nil { + err = errors.Errorf("string \"%v\" doesn't has a prefix that matches format \"%v\", err: %v", s, logutil.SlowLogTimeFormat, err) + } + } + return t, err +} + +type logFile struct { + file *os.File // The opened file handle + start time.Time // The start time of the log file + compressed bool // The file is compressed or not +} + +// getAllFiles is used to get all slow-log needed to parse, it is exported for test. +func (e *slowQueryRetriever) getAllFiles(ctx context.Context, sctx sessionctx.Context, logFilePath string) ([]logFile, error) { + totalFileNum := 0 + if e.stats != nil { + startTime := time.Now() + defer func() { + e.stats.initialize = time.Since(startTime) + e.stats.totalFileNum = totalFileNum + }() + } + if e.extractor == nil || !e.extractor.Enable { + totalFileNum = 1 + //nolint: gosec + file, err := os.Open(logFilePath) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + return []logFile{{file: file}}, nil + } + var logFiles []logFile + logDir := filepath.Dir(logFilePath) + ext := filepath.Ext(logFilePath) + prefix := logFilePath[:len(logFilePath)-len(ext)] + handleErr := func(err error) error { + // Ignore the error and append warning for usability. + if err != io.EOF { + sctx.GetSessionVars().StmtCtx.AppendWarning(err) + } + return nil + } + files, err := os.ReadDir(logDir) + if err != nil { + return nil, err + } + walkFn := func(path string, info os.DirEntry) error { + if info.IsDir() { + return nil + } + // All rotated log files have the same prefix with the original file. + if !strings.HasPrefix(path, prefix) { + return nil + } + compressed := strings.HasSuffix(path, ".gz") + if isCtxDone(ctx) { + return ctx.Err() + } + totalFileNum++ + file, err := os.OpenFile(path, os.O_RDONLY, os.ModePerm) + if err != nil { + return handleErr(err) + } + skip := false + defer func() { + if !skip { + terror.Log(file.Close()) + } + }() + // Get the file start time. + fileStartTime, err := e.getFileStartTime(ctx, file, compressed) + if err != nil { + return handleErr(err) + } + start := types.NewTime(types.FromGoTime(fileStartTime), mysql.TypeDatetime, types.MaxFsp) + notInAllTimeRanges := true + for _, tr := range e.checker.timeRanges { + if start.Compare(tr.endTime) <= 0 { + notInAllTimeRanges = false + break + } + } + if notInAllTimeRanges { + return nil + } + + // If we want to get the end time from a compressed file, + // we need uncompress the whole file which is very slow and consume a lot of memory. + if !compressed { + // Get the file end time. 
+ fileEndTime, err := e.getFileEndTime(ctx, file) + if err != nil { + return handleErr(err) + } + end := types.NewTime(types.FromGoTime(fileEndTime), mysql.TypeDatetime, types.MaxFsp) + inTimeRanges := false + for _, tr := range e.checker.timeRanges { + if !(start.Compare(tr.endTime) > 0 || end.Compare(tr.startTime) < 0) { + inTimeRanges = true + break + } + } + if !inTimeRanges { + return nil + } + } + _, err = file.Seek(0, io.SeekStart) + if err != nil { + return handleErr(err) + } + logFiles = append(logFiles, logFile{ + file: file, + start: fileStartTime, + compressed: compressed, + }) + skip = true + return nil + } + for _, file := range files { + err := walkFn(filepath.Join(logDir, file.Name()), file) + if err != nil { + return nil, err + } + } + // Sort by start time + slices.SortFunc(logFiles, func(i, j logFile) int { + return i.start.Compare(j.start) + }) + // Assume no time range overlap in log files and remove unnecessary log files for compressed files. + var ret []logFile + for i, file := range logFiles { + if i == len(logFiles)-1 || !file.compressed { + ret = append(ret, file) + continue + } + start := types.NewTime(types.FromGoTime(logFiles[i].start), mysql.TypeDatetime, types.MaxFsp) + // use next file.start as endTime + end := types.NewTime(types.FromGoTime(logFiles[i+1].start), mysql.TypeDatetime, types.MaxFsp) + inTimeRanges := false + for _, tr := range e.checker.timeRanges { + if !(start.Compare(tr.endTime) > 0 || end.Compare(tr.startTime) < 0) { + inTimeRanges = true + break + } + } + if inTimeRanges { + ret = append(ret, file) + } + } + return ret, err +} + +func (*slowQueryRetriever) getFileStartTime(ctx context.Context, file *os.File, compressed bool) (time.Time, error) { + var t time.Time + _, err := file.Seek(0, io.SeekStart) + if err != nil { + return t, err + } + var reader *bufio.Reader + if !compressed { + reader = bufio.NewReader(file) + } else { + gr, err := gzip.NewReader(file) + if err != nil { + return t, err + } + reader = bufio.NewReader(gr) + } + maxNum := 128 + for { + lineByte, err := getOneLine(reader) + if err != nil { + return t, err + } + line := string(lineByte) + if strings.HasPrefix(line, variable.SlowLogStartPrefixStr) { + return ParseTime(line[len(variable.SlowLogStartPrefixStr):]) + } + maxNum-- + if maxNum <= 0 { + break + } + if isCtxDone(ctx) { + return t, ctx.Err() + } + } + return t, errors.Errorf("malform slow query file %v", file.Name()) +} + +func (e *slowQueryRetriever) getRuntimeStats() execdetails.RuntimeStats { + return e.stats +} + +type slowQueryRuntimeStats struct { + totalFileNum int + readFileNum int + readFile time.Duration + initialize time.Duration + readFileSize int64 + parseLog int64 + concurrent int +} + +// String implements the RuntimeStats interface. +func (s *slowQueryRuntimeStats) String() string { + return fmt.Sprintf("initialize: %s, read_file: %s, parse_log: {time:%s, concurrency:%v}, total_file: %v, read_file: %v, read_size: %s", + execdetails.FormatDuration(s.initialize), execdetails.FormatDuration(s.readFile), + execdetails.FormatDuration(time.Duration(s.parseLog)), s.concurrent, + s.totalFileNum, s.readFileNum, memory.FormatBytes(s.readFileSize)) +} + +// Merge implements the RuntimeStats interface. 
+func (s *slowQueryRuntimeStats) Merge(rs execdetails.RuntimeStats) { + tmp, ok := rs.(*slowQueryRuntimeStats) + if !ok { + return + } + s.totalFileNum += tmp.totalFileNum + s.readFileNum += tmp.readFileNum + s.readFile += tmp.readFile + s.initialize += tmp.initialize + s.readFileSize += tmp.readFileSize + s.parseLog += tmp.parseLog +} + +// Clone implements the RuntimeStats interface. +func (s *slowQueryRuntimeStats) Clone() execdetails.RuntimeStats { + newRs := *s + return &newRs +} + +// Tp implements the RuntimeStats interface. +func (*slowQueryRuntimeStats) Tp() int { + return execdetails.TpSlowQueryRuntimeStat +} + +func (*slowQueryRetriever) getFileEndTime(ctx context.Context, file *os.File) (time.Time, error) { + var t time.Time + var tried int + stat, err := file.Stat() + if err != nil { + return t, err + } + endCursor := stat.Size() + maxLineNum := 128 + for { + lines, readBytes, err := readLastLines(ctx, file, endCursor) + if err != nil { + return t, err + } + // read out the file + if readBytes == 0 { + break + } + endCursor -= int64(readBytes) + for i := len(lines) - 1; i >= 0; i-- { + if strings.HasPrefix(lines[i], variable.SlowLogStartPrefixStr) { + return ParseTime(lines[i][len(variable.SlowLogStartPrefixStr):]) + } + } + tried += len(lines) + if tried >= maxLineNum { + break + } + if isCtxDone(ctx) { + return t, ctx.Err() + } + } + return t, errors.Errorf("invalid slow query file %v", file.Name()) +} + +const maxReadCacheSize = 1024 * 1024 * 64 + +// Read lines from the end of a file +// endCursor initial value should be the filesize +func readLastLines(ctx context.Context, file *os.File, endCursor int64) ([]string, int, error) { + var lines []byte + var firstNonNewlinePos int + var cursor = endCursor + var size int64 = 2048 + for { + // stop if we are at the beginning + // check it in the start to avoid read beyond the size + if cursor <= 0 { + break + } + if size < maxReadCacheSize { + size = size * 2 + } + if cursor < size { + size = cursor + } + cursor -= size + + _, err := file.Seek(cursor, io.SeekStart) + if err != nil { + return nil, 0, err + } + chars := make([]byte, size) + _, err = file.Read(chars) + if err != nil { + return nil, 0, err + } + lines = append(chars, lines...) 
// nozero + + // find first '\n' or '\r' + for i := 0; i < len(chars)-1; i++ { + if (chars[i] == '\n' || chars[i] == '\r') && chars[i+1] != '\n' && chars[i+1] != '\r' { + firstNonNewlinePos = i + 1 + break + } + } + if firstNonNewlinePos > 0 { + break + } + if isCtxDone(ctx) { + return nil, 0, ctx.Err() + } + } + finalStr := string(lines[firstNonNewlinePos:]) + return strings.Split(strings.ReplaceAll(finalStr, "\r\n", "\n"), "\n"), len(finalStr), nil +} + +func (e *slowQueryRetriever) initializeAsyncParsing(ctx context.Context, sctx sessionctx.Context) { + e.taskList = make(chan slowLogTask, 1) + e.wg.Add(1) + go e.parseDataForSlowLog(ctx, sctx) +} + +func calculateLogSize(log []string) int64 { + size := 0 + for _, line := range log { + size += len(line) + } + return int64(size) +} + +func calculateDatumsSize(rows [][]types.Datum) int64 { + size := int64(0) + for _, row := range rows { + size += types.EstimatedMemUsage(row, 1) + } + return size +} + +func (e *slowQueryRetriever) memConsume(bytes int64) { + if e.memTracker != nil { + e.memTracker.Consume(bytes) + } +} diff --git a/pkg/executor/sortexec/binding__failpoint_binding__.go b/pkg/executor/sortexec/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..1fe00bd1a09a7 --- /dev/null +++ b/pkg/executor/sortexec/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package sortexec + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/executor/sortexec/parallel_sort_worker.go b/pkg/executor/sortexec/parallel_sort_worker.go index 1d7a4d7aaa7b6..967f4754f7575 100644 --- a/pkg/executor/sortexec/parallel_sort_worker.go +++ b/pkg/executor/sortexec/parallel_sort_worker.go @@ -87,7 +87,7 @@ func (p *parallelSortWorker) reset() { func (p *parallelSortWorker) injectFailPointForParallelSortWorker(triggerFactor int32) { injectParallelSortRandomFail(triggerFactor) - failpoint.Inject("SlowSomeWorkers", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("SlowSomeWorkers")); _err_ == nil { if val.(bool) { if p.workerIDForTest%2 == 0 { randNum := rand.Int31n(10000) @@ -96,7 +96,7 @@ func (p *parallelSortWorker) injectFailPointForParallelSortWorker(triggerFactor } } } - }) + } } func (p *parallelSortWorker) multiWayMergeLocalSortedRows() ([]chunk.Row, error) { @@ -208,11 +208,11 @@ func (p *parallelSortWorker) keyColumnsLess(i, j chunk.Row) int { p.timesOfRowCompare = 0 } - failpoint.Inject("SignalCheckpointForSort", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("SignalCheckpointForSort")); _err_ == nil { if val.(bool) { p.timesOfRowCompare += 1024 } - }) + } p.timesOfRowCompare++ return p.lessRowFunc(i, j) diff --git a/pkg/executor/sortexec/parallel_sort_worker.go__failpoint_stash__ b/pkg/executor/sortexec/parallel_sort_worker.go__failpoint_stash__ new file mode 100644 index 0000000000000..1d7a4d7aaa7b6 --- /dev/null +++ b/pkg/executor/sortexec/parallel_sort_worker.go__failpoint_stash__ @@ -0,0 +1,229 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sortexec + +import ( + "math/rand" + "slices" + "sync" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/memory" +) + +// SignalCheckpointForSort indicates the times of row comparation that a signal detection will be triggered. +const SignalCheckpointForSort uint = 20000 + +type parallelSortWorker struct { + workerIDForTest int + + chunkChannel chan *chunkWithMemoryUsage + fetcherAndWorkerSyncer *sync.WaitGroup + errOutputChan chan rowWithError + finishCh chan struct{} + + lessRowFunc func(chunk.Row, chunk.Row) int + timesOfRowCompare uint + + memTracker *memory.Tracker + totalMemoryUsage int64 + + spillHelper *parallelSortSpillHelper + + localSortedRows []*chunk.Iterator4Slice + sortedRowsIter *chunk.Iterator4Slice + maxSortedRowsLimit int + chunkIters []*chunk.Iterator4Chunk + rowNumInChunkIters int + merger *multiWayMerger +} + +func newParallelSortWorker( + workerIDForTest int, + lessRowFunc func(chunk.Row, chunk.Row) int, + chunkChannel chan *chunkWithMemoryUsage, + fetcherAndWorkerSyncer *sync.WaitGroup, + errOutputChan chan rowWithError, + finishCh chan struct{}, + memTracker *memory.Tracker, + sortedRowsIter *chunk.Iterator4Slice, + maxChunkSize int, + spillHelper *parallelSortSpillHelper) *parallelSortWorker { + return ¶llelSortWorker{ + workerIDForTest: workerIDForTest, + lessRowFunc: lessRowFunc, + chunkChannel: chunkChannel, + fetcherAndWorkerSyncer: fetcherAndWorkerSyncer, + errOutputChan: errOutputChan, + finishCh: finishCh, + timesOfRowCompare: 0, + memTracker: memTracker, + sortedRowsIter: sortedRowsIter, + maxSortedRowsLimit: maxChunkSize * 30, + spillHelper: spillHelper, + } +} + +func (p *parallelSortWorker) reset() { + p.localSortedRows = nil + p.sortedRowsIter = nil + p.merger = nil + p.memTracker.ReplaceBytesUsed(0) +} + +func (p *parallelSortWorker) injectFailPointForParallelSortWorker(triggerFactor int32) { + injectParallelSortRandomFail(triggerFactor) + failpoint.Inject("SlowSomeWorkers", func(val failpoint.Value) { + if val.(bool) { + if p.workerIDForTest%2 == 0 { + randNum := rand.Int31n(10000) + if randNum < 10 { + time.Sleep(1 * time.Millisecond) + } + } + } + }) +} + +func (p *parallelSortWorker) multiWayMergeLocalSortedRows() ([]chunk.Row, error) { + totalRowNum := 0 + for _, rows := range p.localSortedRows { + totalRowNum += rows.Len() + } + resultSortedRows := make([]chunk.Row, 0, totalRowNum) + source := &memorySource{sortedRowsIters: p.localSortedRows} + p.merger = newMultiWayMerger(source, p.lessRowFunc) + err := p.merger.init() + if err != nil { + return nil, err + } + + for { + // It's impossible to return error here as rows are in memory + row, _ := p.merger.next() + if row.IsEmpty() { + break + } + resultSortedRows = append(resultSortedRows, row) + } + p.localSortedRows = nil + return resultSortedRows, nil +} + +func (p *parallelSortWorker) convertChunksToRows() []chunk.Row { + rows := make([]chunk.Row, 0, p.rowNumInChunkIters) + for _, iter := range p.chunkIters { + row := iter.Begin() + for !row.IsEmpty() { + rows = append(rows, row) + row = 
iter.Next() + } + } + p.chunkIters = p.chunkIters[:0] + p.rowNumInChunkIters = 0 + return rows +} + +func (p *parallelSortWorker) sortBatchRows() { + rows := p.convertChunksToRows() + slices.SortFunc(rows, p.keyColumnsLess) + p.localSortedRows = append(p.localSortedRows, chunk.NewIterator4Slice(rows)) +} + +func (p *parallelSortWorker) sortLocalRows() ([]chunk.Row, error) { + // Handle Remaining batchRows whose row number is not over the `maxSortedRowsLimit` + if p.rowNumInChunkIters > 0 { + p.sortBatchRows() + } + + return p.multiWayMergeLocalSortedRows() +} + +func (p *parallelSortWorker) saveChunk(chk *chunk.Chunk) { + chkIter := chunk.NewIterator4Chunk(chk) + p.chunkIters = append(p.chunkIters, chkIter) + p.rowNumInChunkIters += chkIter.Len() +} + +// Fetching a bunch of chunks from chunkChannel and sort them. +// After receiving all chunks, we will get several sorted rows slices and we use k-way merge to sort them. +func (p *parallelSortWorker) fetchChunksAndSort() { + for p.fetchChunksAndSortImpl() { + } +} + +func (p *parallelSortWorker) fetchChunksAndSortImpl() bool { + var ( + chk *chunkWithMemoryUsage + ok bool + ) + select { + case <-p.finishCh: + return false + case chk, ok = <-p.chunkChannel: + // Memory usage of the chunk has been consumed at the chunk fetcher + if !ok { + p.injectFailPointForParallelSortWorker(100) + // Put local sorted rows into this iter who will be read by sort executor + sortedRows, err := p.sortLocalRows() + if err != nil { + p.errOutputChan <- rowWithError{err: err} + return false + } + p.sortedRowsIter.Reset(sortedRows) + return false + } + defer p.fetcherAndWorkerSyncer.Done() + p.totalMemoryUsage += chk.MemoryUsage + } + + p.saveChunk(chk.Chk) + + if p.rowNumInChunkIters >= p.maxSortedRowsLimit { + p.sortBatchRows() + } + + p.injectFailPointForParallelSortWorker(3) + return true +} + +func (p *parallelSortWorker) keyColumnsLess(i, j chunk.Row) int { + if p.timesOfRowCompare >= SignalCheckpointForSort { + // Trigger Consume for checking the NeedKill signal + p.memTracker.Consume(1) + p.timesOfRowCompare = 0 + } + + failpoint.Inject("SignalCheckpointForSort", func(val failpoint.Value) { + if val.(bool) { + p.timesOfRowCompare += 1024 + } + }) + p.timesOfRowCompare++ + + return p.lessRowFunc(i, j) +} + +func (p *parallelSortWorker) run() { + defer func() { + if r := recover(); r != nil { + processPanicAndLog(p.errOutputChan, r) + } + }() + + p.fetchChunksAndSort() +} diff --git a/pkg/executor/sortexec/sort.go b/pkg/executor/sortexec/sort.go index 9f581cb699a9a..5d1f79a554040 100644 --- a/pkg/executor/sortexec/sort.go +++ b/pkg/executor/sortexec/sort.go @@ -615,28 +615,28 @@ func (e *SortExec) fetchChunksUnparallel(ctx context.Context) error { return err } - failpoint.Inject("unholdSyncLock", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("unholdSyncLock")); _err_ == nil { if val.(bool) { // Ensure that spill can get `syncLock`. time.Sleep(1 * time.Millisecond) } - }) + } } - failpoint.Inject("waitForSpill", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("waitForSpill")); _err_ == nil { if val.(bool) { // Ensure that spill is triggered before returning data. 
time.Sleep(50 * time.Millisecond) } - }) + } - failpoint.Inject("SignalCheckpointForSort", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("SignalCheckpointForSort")); _err_ == nil { if val.(bool) { if e.Ctx().GetSessionVars().ConnectionID == 123456 { e.Ctx().GetSessionVars().MemTracker.Killer.SendKillSignal(sqlkiller.QueryMemoryExceeded) } } - }) + } err = e.handleCurrentPartitionBeforeExit() if err != nil { @@ -700,13 +700,13 @@ func (e *SortExec) fetchChunksFromChild(ctx context.Context) { e.Parallel.resultChannel <- rowWithError{err: err} } - failpoint.Inject("SignalCheckpointForSort", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("SignalCheckpointForSort")); _err_ == nil { if val.(bool) { if e.Ctx().GetSessionVars().ConnectionID == 123456 { e.Ctx().GetSessionVars().MemTracker.Killer.SendKillSignal(sqlkiller.QueryMemoryExceeded) } } - }) + } // We must place it after the spill as workers will process its received // chunks after channel is closed and this will cause data race. diff --git a/pkg/executor/sortexec/sort.go__failpoint_stash__ b/pkg/executor/sortexec/sort.go__failpoint_stash__ new file mode 100644 index 0000000000000..9f581cb699a9a --- /dev/null +++ b/pkg/executor/sortexec/sort.go__failpoint_stash__ @@ -0,0 +1,845 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sortexec + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/expression" + plannerutil "github.com/pingcap/tidb/pkg/planner/util" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/channel" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/disk" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/sqlkiller" +) + +// SortExec represents sorting executor. +type SortExec struct { + exec.BaseExecutor + + ByItems []*plannerutil.ByItems + fetched *atomic.Bool + ExecSchema *expression.Schema + + // keyColumns is the column index of the by items. + keyColumns []int + // keyCmpFuncs is used to compare each ByItem. + keyCmpFuncs []chunk.CompareFunc + + curPartition *sortPartition + + // We can't spill if size of data is lower than the limit + spillLimit int64 + + memTracker *memory.Tracker + diskTracker *disk.Tracker + + // TODO delete this variable in the future and remove the unparallel sort + IsUnparallel bool + + finishCh chan struct{} + + // multiWayMerge uses multi-way merge for spill disk. + // The multi-way merge algorithm can refer to https://en.wikipedia.org/wiki/K-way_merge_algorithm + multiWayMerge *multiWayMerger + + Unparallel struct { + Idx int + + // sortPartitions is the chunks to store row values for partitions. Every partition is a sorted list. 
+ sortPartitions []*sortPartition + + spillAction *sortPartitionSpillDiskAction + } + + Parallel struct { + chunkChannel chan *chunkWithMemoryUsage + // It's useful when spill is triggered and the fetcher could know when workers finish their works. + fetcherAndWorkerSyncer *sync.WaitGroup + workers []*parallelSortWorker + + // Each worker will put their results into the given iter + sortedRowsIters []*chunk.Iterator4Slice + merger *multiWayMerger + + resultChannel chan rowWithError + + // Ensure that workers and fetcher have exited + closeSync chan struct{} + + spillHelper *parallelSortSpillHelper + spillAction *parallelSortSpillAction + } + + enableTmpStorageOnOOM bool +} + +// Close implements the Executor Close interface. +func (e *SortExec) Close() error { + // TopN not initializes `e.finishCh` but it will call the Close function + if e.finishCh != nil { + close(e.finishCh) + } + if e.Unparallel.spillAction != nil { + e.Unparallel.spillAction.SetFinished() + } + + if e.IsUnparallel { + for _, partition := range e.Unparallel.sortPartitions { + partition.close() + } + } else if e.finishCh != nil { + if e.fetched.CompareAndSwap(false, true) { + close(e.Parallel.resultChannel) + close(e.Parallel.chunkChannel) + } else { + for range e.Parallel.chunkChannel { + e.Parallel.fetcherAndWorkerSyncer.Done() + } + <-e.Parallel.closeSync + } + + // Ensure that `generateResult()` has exited, + // or data race may happen as `generateResult()` + // will use `e.Parallel.workers` and `e.Parallel.merger`. + channel.Clear(e.Parallel.resultChannel) + for i := range e.Parallel.workers { + if e.Parallel.workers[i] != nil { + e.Parallel.workers[i].reset() + } + } + e.Parallel.merger = nil + if e.Parallel.spillAction != nil { + e.Parallel.spillAction.SetFinished() + } + e.Parallel.spillHelper.close() + } + + if e.memTracker != nil { + e.memTracker.ReplaceBytesUsed(0) + } + + return exec.Close(e.Children(0)) +} + +// Open implements the Executor Open interface. +func (e *SortExec) Open(ctx context.Context) error { + e.fetched = &atomic.Bool{} + e.fetched.Store(false) + e.enableTmpStorageOnOOM = variable.EnableTmpStorageOnOOM.Load() + e.finishCh = make(chan struct{}, 1) + + // To avoid duplicated initialization for TopNExec. 
+ if e.memTracker == nil { + e.memTracker = memory.NewTracker(e.ID(), -1) + e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) + e.spillLimit = e.Ctx().GetSessionVars().MemTracker.GetBytesLimit() / 10 + e.diskTracker = disk.NewTracker(e.ID(), -1) + e.diskTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.DiskTracker) + } + + e.IsUnparallel = false + if e.IsUnparallel { + e.Unparallel.Idx = 0 + e.Unparallel.sortPartitions = e.Unparallel.sortPartitions[:0] + } else { + e.Parallel.workers = make([]*parallelSortWorker, e.Ctx().GetSessionVars().ExecutorConcurrency) + e.Parallel.chunkChannel = make(chan *chunkWithMemoryUsage, e.Ctx().GetSessionVars().ExecutorConcurrency) + e.Parallel.fetcherAndWorkerSyncer = &sync.WaitGroup{} + e.Parallel.sortedRowsIters = make([]*chunk.Iterator4Slice, len(e.Parallel.workers)) + e.Parallel.resultChannel = make(chan rowWithError, 10) + e.Parallel.closeSync = make(chan struct{}) + e.Parallel.merger = newMultiWayMerger(&memorySource{sortedRowsIters: e.Parallel.sortedRowsIters}, e.lessRow) + e.Parallel.spillHelper = newParallelSortSpillHelper(e, exec.RetTypes(e), e.finishCh, e.lessRow, e.Parallel.resultChannel) + e.Parallel.spillAction = newParallelSortSpillDiskAction(e.Parallel.spillHelper) + for i := range e.Parallel.sortedRowsIters { + e.Parallel.sortedRowsIters[i] = chunk.NewIterator4Slice(nil) + } + if e.enableTmpStorageOnOOM { + e.Ctx().GetSessionVars().MemTracker.FallbackOldAndSetNewAction(e.Parallel.spillAction) + } + } + + return exec.Open(ctx, e.Children(0)) +} + +// InitUnparallelModeForTest is for unit test +func (e *SortExec) InitUnparallelModeForTest() { + e.Unparallel.Idx = 0 + e.Unparallel.sortPartitions = e.Unparallel.sortPartitions[:0] +} + +// Next implements the Executor Next interface. +// Sort constructs the result following these step in unparallel mode: +// 1. Read as mush as rows into memory. +// 2. If memory quota is triggered, sort these rows in memory and put them into disk as partition 1, then reset +// the memory quota trigger and return to step 1 +// 3. If memory quota is not triggered and child is consumed, sort these rows in memory as partition N. +// 4. Merge sort if the count of partitions is larger than 1. If there is only one partition in step 4, it works +// just like in-memory sort before. +// +// Here we explain the execution flow of the parallel sort implementation. +// There are 3 main components: +// 1. Chunks Fetcher: Fetcher is responsible for fetching chunks from child and send them to channel. +// 2. Parallel Sort Worker: Worker receives chunks from channel it will sort these chunks after the +// number of rows in these chunks exceeds limit, we call them as sorted rows after chunks are sorted. +// Then each worker will have several sorted rows, we use multi-way merge to sort them and each worker +// will have only one sorted rows in the end. +// 3. Result Generator: Generator gets n sorted rows from n workers, it will use multi-way merge to sort +// these rows, once it gets the next row, it will send it into `resultChannel` and the goroutine who +// calls `Next()` will fetch result from `resultChannel`. +/* + ┌─────────┐ + │ Child │ + └────▲────┘ + │ + Fetch + │ + ┌───────┴───────┐ + │ Chunk Fetcher │ + └───────┬───────┘ + │ + Push + │ + ▼ + ┌────────────────►Channel◄───────────────────┐ + │ ▲ │ + │ │ │ + Fetch Fetch Fetch + │ │ │ + ┌────┴───┐ ┌───┴────┐ ┌───┴────┐ + │ Worker │ │ Worker │ ...... 
│ Worker │ + └────┬───┘ └───┬────┘ └───┬────┘ + │ │ │ + │ │ │ + Sort Sort Sort + │ │ │ + │ │ │ + ┌──────┴──────┐ ┌──────┴──────┐ ┌──────┴──────┐ + │ Sorted Rows │ │ Sorted Rows │ ...... │ Sorted Rows │ + └──────▲──────┘ └──────▲──────┘ └──────▲──────┘ + │ │ │ + Pull Pull Pull + │ │ │ + └────────────────────┼───────────────────────┘ + │ + Multi-way Merge + │ + ┌──────┴──────┐ + │ Generator │ + └──────┬──────┘ + │ + Push + │ + ▼ + resultChannel +*/ +func (e *SortExec) Next(ctx context.Context, req *chunk.Chunk) error { + if e.fetched.CompareAndSwap(false, true) { + err := e.initCompareFuncs(e.Ctx().GetExprCtx().GetEvalCtx()) + if err != nil { + return err + } + + e.buildKeyColumns() + err = e.fetchChunks(ctx) + if err != nil { + return err + } + } + + req.Reset() + if e.IsUnparallel { + return e.appendResultToChunkInUnparallelMode(req) + } + return e.appendResultToChunkInParallelMode(req) +} + +func (e *SortExec) appendResultToChunkInParallelMode(req *chunk.Chunk) error { + for !req.IsFull() { + row, ok := <-e.Parallel.resultChannel + if row.err != nil { + return row.err + } + if !ok { + return nil + } + req.AppendRow(row.row) + } + return nil +} + +func (e *SortExec) appendResultToChunkInUnparallelMode(req *chunk.Chunk) error { + sortPartitionListLen := len(e.Unparallel.sortPartitions) + if sortPartitionListLen == 0 { + return nil + } + + if sortPartitionListLen == 1 { + if err := e.onePartitionSorting(req); err != nil { + return err + } + } else { + if err := e.externalSorting(req); err != nil { + return err + } + } + return nil +} + +func (e *SortExec) generateResultWithMultiWayMerge() error { + multiWayMerge := newMultiWayMerger(&diskSource{sortedRowsInDisk: e.Parallel.spillHelper.sortedRowsInDisk}, e.lessRow) + + err := multiWayMerge.init() + if err != nil { + return err + } + + for { + row, err := multiWayMerge.next() + if err != nil { + return err + } + + if row.IsEmpty() { + return nil + } + + select { + case <-e.finishCh: + return nil + case e.Parallel.resultChannel <- rowWithError{row: row}: + } + injectParallelSortRandomFail(1) + } +} + +// We call this function when sorted rows are in disk +func (e *SortExec) generateResultFromDisk() error { + inDiskNum := len(e.Parallel.spillHelper.sortedRowsInDisk) + if inDiskNum == 0 { + return nil + } + + // Spill is triggered only once + if inDiskNum == 1 { + inDisk := e.Parallel.spillHelper.sortedRowsInDisk[0] + chunkNum := inDisk.NumChunks() + for i := 0; i < chunkNum; i++ { + chk, err := inDisk.GetChunk(i) + if err != nil { + return err + } + + injectParallelSortRandomFail(1) + + rowNum := chk.NumRows() + for j := 0; j < rowNum; j++ { + select { + case <-e.finishCh: + return nil + case e.Parallel.resultChannel <- rowWithError{row: chk.GetRow(j)}: + } + } + } + return nil + } + return e.generateResultWithMultiWayMerge() +} + +// We call this function to generate result when sorted rows are in memory +// Return true when spill is triggered +func (e *SortExec) generateResultFromMemory() (bool, error) { + if e.Parallel.merger == nil { + // Sort has been closed + return false, nil + } + err := e.Parallel.merger.init() + if err != nil { + return false, err + } + + maxChunkSize := e.MaxChunkSize() + resBuf := make([]rowWithError, 0, 3) + idx := int64(0) + var row chunk.Row + for { + resBuf = resBuf[:0] + for i := 0; i < maxChunkSize; i++ { + // It's impossible to return error here as rows are in memory + row, _ = e.Parallel.merger.next() + if row.IsEmpty() { + break + } + resBuf = append(resBuf, rowWithError{row: row, err: nil}) + } + + if 
len(resBuf) == 0 { + return false, nil + } + + for _, row := range resBuf { + select { + case <-e.finishCh: + return false, nil + case e.Parallel.resultChannel <- row: + } + } + + injectParallelSortRandomFail(3) + + if idx%1000 == 0 && e.Parallel.spillHelper.isSpillNeeded() { + return true, nil + } + } +} + +func (e *SortExec) generateResult(waitGroups ...*util.WaitGroupWrapper) { + for _, waitGroup := range waitGroups { + waitGroup.Wait() + } + close(e.Parallel.closeSync) + + defer func() { + if r := recover(); r != nil { + processPanicAndLog(e.Parallel.resultChannel, r) + } + + for i := range e.Parallel.sortedRowsIters { + e.Parallel.sortedRowsIters[i].Reset(nil) + } + e.Parallel.merger = nil + close(e.Parallel.resultChannel) + }() + + if !e.Parallel.spillHelper.isSpillTriggered() { + spillTriggered, err := e.generateResultFromMemory() + if err != nil { + e.Parallel.resultChannel <- rowWithError{err: err} + return + } + + if !spillTriggered { + return + } + + err = e.spillSortedRowsInMemory() + if err != nil { + e.Parallel.resultChannel <- rowWithError{err: err} + return + } + } + + err := e.generateResultFromDisk() + if err != nil { + e.Parallel.resultChannel <- rowWithError{err: err} + } +} + +// Spill rows that are in memory +func (e *SortExec) spillSortedRowsInMemory() error { + return e.Parallel.spillHelper.spillImpl(e.Parallel.merger) +} + +func (e *SortExec) onePartitionSorting(req *chunk.Chunk) (err error) { + err = e.Unparallel.sortPartitions[0].checkError() + if err != nil { + return err + } + + for !req.IsFull() { + row, err := e.Unparallel.sortPartitions[0].getNextSortedRow() + if err != nil { + return err + } + + if row.IsEmpty() { + return nil + } + + req.AppendRow(row) + } + return nil +} + +func (e *SortExec) externalSorting(req *chunk.Chunk) (err error) { + // We only need to check error for the last partition as previous partitions + // have been checked when we call `switchToNewSortPartition` function. 
+ err = e.Unparallel.sortPartitions[len(e.Unparallel.sortPartitions)-1].checkError() + if err != nil { + return err + } + + if e.multiWayMerge == nil { + e.multiWayMerge = newMultiWayMerger(&sortPartitionSource{sortPartitions: e.Unparallel.sortPartitions}, e.lessRow) + err := e.multiWayMerge.init() + if err != nil { + return err + } + } + + for !req.IsFull() { + row, err := e.multiWayMerge.next() + if err != nil { + return err + } + if row.IsEmpty() { + return nil + } + req.AppendRow(row) + } + return nil +} + +func (e *SortExec) fetchChunks(ctx context.Context) error { + if e.IsUnparallel { + return e.fetchChunksUnparallel(ctx) + } + return e.fetchChunksParallel(ctx) +} + +func (e *SortExec) switchToNewSortPartition(fields []*types.FieldType, byItemsDesc []bool, appendPartition bool) error { + if appendPartition { + // Put the full partition into list + e.Unparallel.sortPartitions = append(e.Unparallel.sortPartitions, e.curPartition) + } + + if e.curPartition != nil { + err := e.curPartition.checkError() + if err != nil { + return err + } + } + + e.curPartition = newSortPartition(fields, byItemsDesc, e.keyColumns, e.keyCmpFuncs, e.spillLimit) + e.curPartition.getMemTracker().AttachTo(e.memTracker) + e.curPartition.getMemTracker().SetLabel(memory.LabelForRowChunks) + e.Unparallel.spillAction = e.curPartition.actionSpill() + if e.enableTmpStorageOnOOM { + e.curPartition.getDiskTracker().AttachTo(e.diskTracker) + e.curPartition.getDiskTracker().SetLabel(memory.LabelForRowChunks) + e.Ctx().GetSessionVars().MemTracker.FallbackOldAndSetNewAction(e.Unparallel.spillAction) + } + return nil +} + +func (e *SortExec) checkError() error { + for _, partition := range e.Unparallel.sortPartitions { + err := partition.checkError() + if err != nil { + return err + } + } + return nil +} + +func (e *SortExec) storeChunk(chk *chunk.Chunk, fields []*types.FieldType, byItemsDesc []bool) error { + err := e.curPartition.checkError() + if err != nil { + return err + } + + if !e.curPartition.add(chk) { + err := e.switchToNewSortPartition(fields, byItemsDesc, true) + if err != nil { + return err + } + + if !e.curPartition.add(chk) { + return errFailToAddChunk + } + } + return nil +} + +func (e *SortExec) handleCurrentPartitionBeforeExit() error { + err := e.checkError() + if err != nil { + return err + } + + err = e.curPartition.sort() + if err != nil { + return err + } + + return nil +} + +func (e *SortExec) fetchChunksUnparallel(ctx context.Context) error { + fields := exec.RetTypes(e) + byItemsDesc := make([]bool, len(e.ByItems)) + for i, byItem := range e.ByItems { + byItemsDesc[i] = byItem.Desc + } + + err := e.switchToNewSortPartition(fields, byItemsDesc, false) + if err != nil { + return err + } + + for { + chk := exec.TryNewCacheChunk(e.Children(0)) + err := exec.Next(ctx, e.Children(0), chk) + if err != nil { + return err + } + if chk.NumRows() == 0 { + break + } + + err = e.storeChunk(chk, fields, byItemsDesc) + if err != nil { + return err + } + + failpoint.Inject("unholdSyncLock", func(val failpoint.Value) { + if val.(bool) { + // Ensure that spill can get `syncLock`. + time.Sleep(1 * time.Millisecond) + } + }) + } + + failpoint.Inject("waitForSpill", func(val failpoint.Value) { + if val.(bool) { + // Ensure that spill is triggered before returning data. 
+ time.Sleep(50 * time.Millisecond) + } + }) + + failpoint.Inject("SignalCheckpointForSort", func(val failpoint.Value) { + if val.(bool) { + if e.Ctx().GetSessionVars().ConnectionID == 123456 { + e.Ctx().GetSessionVars().MemTracker.Killer.SendKillSignal(sqlkiller.QueryMemoryExceeded) + } + } + }) + + err = e.handleCurrentPartitionBeforeExit() + if err != nil { + return err + } + + e.Unparallel.sortPartitions = append(e.Unparallel.sortPartitions, e.curPartition) + e.curPartition = nil + return nil +} + +func (e *SortExec) fetchChunksParallel(ctx context.Context) error { + // Wait for the finish of all workers + workersWaiter := util.WaitGroupWrapper{} + // Wait for the finish of chunk fetcher + fetcherWaiter := util.WaitGroupWrapper{} + + for i := range e.Parallel.workers { + e.Parallel.workers[i] = newParallelSortWorker(i, e.lessRow, e.Parallel.chunkChannel, e.Parallel.fetcherAndWorkerSyncer, e.Parallel.resultChannel, e.finishCh, e.memTracker, e.Parallel.sortedRowsIters[i], e.MaxChunkSize(), e.Parallel.spillHelper) + worker := e.Parallel.workers[i] + workersWaiter.Run(func() { + worker.run() + }) + } + + // Fetch chunks from child and put chunks into chunkChannel + fetcherWaiter.Run(func() { + e.fetchChunksFromChild(ctx) + }) + + go e.generateResult(&workersWaiter, &fetcherWaiter) + return nil +} + +func (e *SortExec) spillRemainingRowsWhenNeeded() error { + if e.Parallel.spillHelper.isSpillTriggered() { + return e.Parallel.spillHelper.spill() + } + return nil +} + +func (e *SortExec) checkSpillAndExecute() error { + if e.Parallel.spillHelper.isSpillNeeded() { + // Wait for the stop of all workers + e.Parallel.fetcherAndWorkerSyncer.Wait() + return e.Parallel.spillHelper.spill() + } + return nil +} + +// Fetch chunks from child and put chunks into chunkChannel +func (e *SortExec) fetchChunksFromChild(ctx context.Context) { + defer func() { + if r := recover(); r != nil { + processPanicAndLog(e.Parallel.resultChannel, r) + } + + e.Parallel.fetcherAndWorkerSyncer.Wait() + err := e.spillRemainingRowsWhenNeeded() + if err != nil { + e.Parallel.resultChannel <- rowWithError{err: err} + } + + failpoint.Inject("SignalCheckpointForSort", func(val failpoint.Value) { + if val.(bool) { + if e.Ctx().GetSessionVars().ConnectionID == 123456 { + e.Ctx().GetSessionVars().MemTracker.Killer.SendKillSignal(sqlkiller.QueryMemoryExceeded) + } + } + }) + + // We must place it after the spill as workers will process its received + // chunks after channel is closed and this will cause data race. 
+ close(e.Parallel.chunkChannel) + }() + + for { + chk := exec.TryNewCacheChunk(e.Children(0)) + err := exec.Next(ctx, e.Children(0), chk) + if err != nil { + e.Parallel.resultChannel <- rowWithError{err: err} + return + } + + rowCount := chk.NumRows() + if rowCount == 0 { + break + } + + chkWithMemoryUsage := &chunkWithMemoryUsage{ + Chk: chk, + MemoryUsage: chk.MemoryUsage() + chunk.RowSize*int64(rowCount), + } + + e.memTracker.Consume(chkWithMemoryUsage.MemoryUsage) + + e.Parallel.fetcherAndWorkerSyncer.Add(1) + + select { + case <-e.finishCh: + e.Parallel.fetcherAndWorkerSyncer.Done() + return + case e.Parallel.chunkChannel <- chkWithMemoryUsage: + } + + err = e.checkSpillAndExecute() + if err != nil { + e.Parallel.resultChannel <- rowWithError{err: err} + return + } + injectParallelSortRandomFail(3) + } +} + +func (e *SortExec) initCompareFuncs(ctx expression.EvalContext) error { + e.keyCmpFuncs = make([]chunk.CompareFunc, len(e.ByItems)) + for i := range e.ByItems { + keyType := e.ByItems[i].Expr.GetType(ctx) + e.keyCmpFuncs[i] = chunk.GetCompareFunc(keyType) + if e.keyCmpFuncs[i] == nil { + return errors.Errorf("Sort executor not supports type %s", types.TypeStr(keyType.GetType())) + } + } + return nil +} + +func (e *SortExec) buildKeyColumns() { + e.keyColumns = make([]int, 0, len(e.ByItems)) + for _, by := range e.ByItems { + col := by.Expr.(*expression.Column) + e.keyColumns = append(e.keyColumns, col.Index) + } +} + +func (e *SortExec) lessRow(rowI, rowJ chunk.Row) int { + for i, colIdx := range e.keyColumns { + cmpFunc := e.keyCmpFuncs[i] + cmp := cmpFunc(rowI, colIdx, rowJ, colIdx) + if e.ByItems[i].Desc { + cmp = -cmp + } + if cmp != 0 { + return cmp + } + } + return 0 +} + +func (e *SortExec) compareRow(rowI, rowJ chunk.Row) int { + for i, colIdx := range e.keyColumns { + cmpFunc := e.keyCmpFuncs[i] + cmp := cmpFunc(rowI, colIdx, rowJ, colIdx) + if e.ByItems[i].Desc { + cmp = -cmp + } + if cmp != 0 { + return cmp + } + } + return 0 +} + +// IsSpillTriggeredInParallelSortForTest tells if spill is triggered in parallel sort. +func (e *SortExec) IsSpillTriggeredInParallelSortForTest() bool { + return e.Parallel.spillHelper.isSpillTriggered() +} + +// GetSpilledRowNumInParallelSortForTest tells if spill is triggered in parallel sort. +func (e *SortExec) GetSpilledRowNumInParallelSortForTest() int64 { + totalSpilledRows := int64(0) + for _, disk := range e.Parallel.spillHelper.sortedRowsInDisk { + totalSpilledRows += disk.NumRows() + } + return totalSpilledRows +} + +// IsSpillTriggeredInOnePartitionForTest tells if spill is triggered in a specific partition, it's only used in test. +func (e *SortExec) IsSpillTriggeredInOnePartitionForTest(idx int) bool { + return e.Unparallel.sortPartitions[idx].isSpillTriggered() +} + +// GetRowNumInOnePartitionDiskForTest returns number of rows a partition holds in disk, it's only used in test. +func (e *SortExec) GetRowNumInOnePartitionDiskForTest(idx int) int64 { + return e.Unparallel.sortPartitions[idx].numRowInDiskForTest() +} + +// GetRowNumInOnePartitionMemoryForTest returns number of rows a partition holds in memory, it's only used in test. +func (e *SortExec) GetRowNumInOnePartitionMemoryForTest(idx int) int64 { + return e.Unparallel.sortPartitions[idx].numRowInMemoryForTest() +} + +// GetSortPartitionListLenForTest returns the number of partitions, it's only used in test. 
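The lessRow/compareRow pair above walks the ByItems keys in order, negates the per-column result for descending keys, and only falls through to the next key on a tie. A minimal, dependency-free sketch of that multi-key comparison (compareByKeys is a hypothetical helper, not part of this patch):

func compareByKeys(a, b []int, desc []bool) int {
	// The first key that differs decides the order.
	for k := range a {
		cmp := 0
		switch {
		case a[k] < b[k]:
			cmp = -1
		case a[k] > b[k]:
			cmp = 1
		}
		if desc[k] {
			// A descending key simply negates the comparison result.
			cmp = -cmp
		}
		if cmp != 0 {
			return cmp
		}
	}
	return 0
}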
+func (e *SortExec) GetSortPartitionListLenForTest() int { + return len(e.Unparallel.sortPartitions) +} + +// GetSortMetaForTest returns some sort meta, it's only used in test. +func (e *SortExec) GetSortMetaForTest() (keyColumns []int, keyCmpFuncs []chunk.CompareFunc, byItemsDesc []bool) { + keyColumns = e.keyColumns + keyCmpFuncs = e.keyCmpFuncs + byItemsDesc = make([]bool, len(e.ByItems)) + for i, byItem := range e.ByItems { + byItemsDesc[i] = byItem.Desc + } + return +} diff --git a/pkg/executor/sortexec/sort_partition.go b/pkg/executor/sortexec/sort_partition.go index 7c798f6385e95..82e1e3ee07822 100644 --- a/pkg/executor/sortexec/sort_partition.go +++ b/pkg/executor/sortexec/sort_partition.go @@ -141,11 +141,11 @@ func (s *sortPartition) sortNoLock() (ret error) { return } - failpoint.Inject("errorDuringSortRowContainer", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("errorDuringSortRowContainer")); _err_ == nil { if val.(bool) { panic("sort meet error") } - }) + } sort.Slice(s.savedRows, s.keyColumnsLess) s.isSorted = true @@ -297,11 +297,11 @@ func (s *sortPartition) keyColumnsLess(i, j int) bool { s.timesOfRowCompare = 0 } - failpoint.Inject("SignalCheckpointForSort", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("SignalCheckpointForSort")); _err_ == nil { if val.(bool) { s.timesOfRowCompare += 1024 } - }) + } s.timesOfRowCompare++ return s.lessRow(s.savedRows[i], s.savedRows[j]) diff --git a/pkg/executor/sortexec/sort_partition.go__failpoint_stash__ b/pkg/executor/sortexec/sort_partition.go__failpoint_stash__ new file mode 100644 index 0000000000000..7c798f6385e95 --- /dev/null +++ b/pkg/executor/sortexec/sort_partition.go__failpoint_stash__ @@ -0,0 +1,367 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sortexec + +import ( + "sort" + "sync" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/disk" + "github.com/pingcap/tidb/pkg/util/memory" +) + +type sortPartition struct { + // cond is only used for protecting spillStatus + cond *sync.Cond + spillStatus int + + // syncLock is used to protect variables except `spillStatus` + syncLock sync.Mutex + + // Data are stored in savedRows + savedRows []chunk.Row + sliceIter *chunk.Iterator4Slice + isSorted bool + + // cursor iterates the spilled chunks. + cursor *dataCursor + inDisk *chunk.DataInDiskByChunks + + spillError error + closed bool + + fieldTypes []*types.FieldType + + memTracker *memory.Tracker + diskTracker *disk.Tracker + spillAction *sortPartitionSpillDiskAction + + // We can't spill if size of data is lower than the limit + spillLimit int64 + + byItemsDesc []bool + // keyColumns is the column index of the by items. + keyColumns []int + // keyCmpFuncs is used to compare each ByItem. 
+ keyCmpFuncs []chunk.CompareFunc + + // Sort is a time-consuming operation, we need to set a checkpoint to detect + // the outside signal periodically. + timesOfRowCompare uint +} + +// Creates a new SortPartition in memory. +func newSortPartition(fieldTypes []*types.FieldType, byItemsDesc []bool, + keyColumns []int, keyCmpFuncs []chunk.CompareFunc, spillLimit int64) *sortPartition { + lock := new(sync.Mutex) + retVal := &sortPartition{ + cond: sync.NewCond(lock), + spillError: nil, + spillStatus: notSpilled, + fieldTypes: fieldTypes, + savedRows: make([]chunk.Row, 0), + isSorted: false, + inDisk: nil, // It's initialized only when spill is triggered + memTracker: memory.NewTracker(memory.LabelForSortPartition, -1), + diskTracker: disk.NewTracker(memory.LabelForSortPartition, -1), + spillAction: nil, // It's set in `actionSpill` function + spillLimit: spillLimit, + byItemsDesc: byItemsDesc, + keyColumns: keyColumns, + keyCmpFuncs: keyCmpFuncs, + cursor: NewDataCursor(), + closed: false, + } + + return retVal +} + +func (s *sortPartition) close() { + s.syncLock.Lock() + defer s.syncLock.Unlock() + s.closed = true + if s.inDisk != nil { + s.inDisk.Close() + } + s.getMemTracker().ReplaceBytesUsed(0) +} + +// Return false if the spill is triggered in this partition. +func (s *sortPartition) add(chk *chunk.Chunk) bool { + rowNum := chk.NumRows() + consumedBytesNum := chunk.RowSize*int64(rowNum) + chk.MemoryUsage() + + s.syncLock.Lock() + defer s.syncLock.Unlock() + if s.isSpillTriggered() { + return false + } + + // Convert chunk to rows + for i := 0; i < rowNum; i++ { + s.savedRows = append(s.savedRows, chk.GetRow(i)) + } + + s.getMemTracker().Consume(consumedBytesNum) + return true +} + +func (s *sortPartition) sort() error { + s.syncLock.Lock() + defer s.syncLock.Unlock() + return s.sortNoLock() +} + +func (s *sortPartition) sortNoLock() (ret error) { + ret = nil + defer func() { + if r := recover(); r != nil { + ret = util.GetRecoverError(r) + } + }() + + if s.isSorted { + return + } + + failpoint.Inject("errorDuringSortRowContainer", func(val failpoint.Value) { + if val.(bool) { + panic("sort meet error") + } + }) + + sort.Slice(s.savedRows, s.keyColumnsLess) + s.isSorted = true + s.sliceIter = chunk.NewIterator4Slice(s.savedRows) + return +} + +func (s *sortPartition) spillToDiskImpl() (err error) { + defer func() { + if r := recover(); r != nil { + err = util.GetRecoverError(r) + } + }() + + if s.closed { + return nil + } + + s.inDisk = chunk.NewDataInDiskByChunks(s.fieldTypes) + s.inDisk.GetDiskTracker().AttachTo(s.diskTracker) + tmpChk := chunk.NewChunkWithCapacity(s.fieldTypes, spillChunkSize) + + rowNum := len(s.savedRows) + if rowNum == 0 { + return errSpillEmptyChunk + } + + for row := s.sliceIter.Next(); !row.IsEmpty(); row = s.sliceIter.Next() { + tmpChk.AppendRow(row) + if tmpChk.IsFull() { + err := s.inDisk.Add(tmpChk) + if err != nil { + return err + } + tmpChk.Reset() + s.getMemTracker().HandleKillSignal() + } + } + + // Spill the remaining data in tmpChk. + // Do not spill when tmpChk is empty as `Add` function requires a non-empty chunk + if tmpChk.NumRows() > 0 { + err := s.inDisk.Add(tmpChk) + if err != nil { + return err + } + } + + // Release memory as all data have been spilled to disk + s.savedRows = nil + s.sliceIter = nil + s.getMemTracker().ReplaceBytesUsed(0) + return nil +} + +// We can only call this function under the protection of `syncLock`. 
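spillToDiskImpl above turns the sorted in-memory rows into fixed-size chunks: it appends rows into a temporary chunk, writes the chunk out whenever it fills up, and finally flushes the non-empty remainder (an empty chunk is never written). A dependency-free sketch of that batch-and-flush loop, where the flush callback is a hypothetical stand-in for DataInDiskByChunks.Add:

func flushInBatches(rows []int, batchSize int, flush func(batch []int) error) error {
	buf := make([]int, 0, batchSize)
	for _, r := range rows {
		buf = append(buf, r)
		if len(buf) == batchSize {
			// A full batch is written out and the buffer is reused.
			if err := flush(buf); err != nil {
				return err
			}
			buf = buf[:0]
		}
	}
	// Flush the remainder, mirroring the "never spill an empty chunk" rule.
	if len(buf) > 0 {
		return flush(buf)
	}
	return nil
}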
+func (s *sortPartition) spillToDisk() error { + s.syncLock.Lock() + defer s.syncLock.Unlock() + if s.isSpillTriggered() { + return nil + } + + err := s.sortNoLock() + if err != nil { + return err + } + + s.setIsSpilling() + defer s.cond.Broadcast() + defer s.setSpillTriggered() + + err = s.spillToDiskImpl() + return err +} + +func (s *sortPartition) getNextSortedRow() (chunk.Row, error) { + s.syncLock.Lock() + defer s.syncLock.Unlock() + if s.isSpillTriggered() { + row := s.cursor.next() + if row.IsEmpty() { + success, err := reloadCursor(s.cursor, s.inDisk) + if err != nil { + return chunk.Row{}, err + } + if !success { + // All data has been consumed + return chunk.Row{}, nil + } + + row = s.cursor.begin() + if row.IsEmpty() { + return chunk.Row{}, errors.New("Get an empty row") + } + } + return row, nil + } + + row := s.sliceIter.Next() + return row, nil +} + +func (s *sortPartition) actionSpill() *sortPartitionSpillDiskAction { + if s.spillAction == nil { + s.spillAction = &sortPartitionSpillDiskAction{ + partition: s, + } + } + return s.spillAction +} + +func (s *sortPartition) getMemTracker() *memory.Tracker { + return s.memTracker +} + +func (s *sortPartition) getDiskTracker() *disk.Tracker { + return s.diskTracker +} + +func (s *sortPartition) hasEnoughDataToSpill() bool { + // Guarantee that each partition size is not too small, to avoid opening too many files. + return s.getMemTracker().BytesConsumed() > s.spillLimit +} + +func (s *sortPartition) lessRow(rowI, rowJ chunk.Row) bool { + for i, colIdx := range s.keyColumns { + cmpFunc := s.keyCmpFuncs[i] + if cmpFunc != nil { + cmp := cmpFunc(rowI, colIdx, rowJ, colIdx) + if s.byItemsDesc[i] { + cmp = -cmp + } + if cmp < 0 { + return true + } else if cmp > 0 { + return false + } + } + } + return false +} + +// keyColumnsLess is the less function for key columns. +func (s *sortPartition) keyColumnsLess(i, j int) bool { + if s.timesOfRowCompare >= signalCheckpointForSort { + // Trigger Consume for checking the NeedKill signal + s.memTracker.HandleKillSignal() + s.timesOfRowCompare = 0 + } + + failpoint.Inject("SignalCheckpointForSort", func(val failpoint.Value) { + if val.(bool) { + s.timesOfRowCompare += 1024 + } + }) + + s.timesOfRowCompare++ + return s.lessRow(s.savedRows[i], s.savedRows[j]) +} + +func (s *sortPartition) isSpillTriggered() bool { + s.cond.L.Lock() + defer s.cond.L.Unlock() + return s.spillStatus == spillTriggered +} + +func (s *sortPartition) isSpillTriggeredNoLock() bool { + return s.spillStatus == spillTriggered +} + +func (s *sortPartition) setSpillTriggered() { + s.cond.L.Lock() + defer s.cond.L.Unlock() + s.spillStatus = spillTriggered +} + +func (s *sortPartition) setIsSpilling() { + s.cond.L.Lock() + defer s.cond.L.Unlock() + s.spillStatus = inSpilling +} + +func (s *sortPartition) getIsSpillingNoLock() bool { + return s.spillStatus == inSpilling +} + +func (s *sortPartition) setError(err error) { + s.syncLock.Lock() + defer s.syncLock.Unlock() + s.spillError = err +} + +func (s *sortPartition) checkError() error { + s.syncLock.Lock() + defer s.syncLock.Unlock() + return s.spillError +} + +func (s *sortPartition) numRowInDiskForTest() int64 { + if s.inDisk != nil { + return s.inDisk.NumRows() + } + return 0 +} + +func (s *sortPartition) numRowInMemoryForTest() int64 { + if s.sliceIter != nil { + if s.sliceIter.Len() != len(s.savedRows) { + panic("length of sliceIter should be equal to savedRows") + } + } + return int64(len(s.savedRows)) +} + +// SetSmallSpillChunkSizeForTest set spill chunk size for test. 
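keyColumnsLess above uses a checkpoint counter: row comparison is the hot path of the sort, so the kill signal is only checked once every signalCheckpointForSort comparisons instead of on every call. A small sketch of that pattern with hypothetical names:

// checkpointCounter runs a cheap counter on the hot path and invokes the
// expensive check (kill-signal handling in the real code) once per interval.
type checkpointCounter struct {
	count    uint
	interval uint
	onCheck  func()
}

func (c *checkpointCounter) tick() {
	c.count++
	if c.count >= c.interval {
		c.onCheck()
		c.count = 0
	}
}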
+func SetSmallSpillChunkSizeForTest() { + spillChunkSize = 16 +} diff --git a/pkg/executor/sortexec/sort_util.go b/pkg/executor/sortexec/sort_util.go index 59ef17f90da2c..8a4aaee944499 100644 --- a/pkg/executor/sortexec/sort_util.go +++ b/pkg/executor/sortexec/sort_util.go @@ -66,14 +66,14 @@ type rowWithError struct { } func injectParallelSortRandomFail(triggerFactor int32) { - failpoint.Inject("ParallelSortRandomFail", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("ParallelSortRandomFail")); _err_ == nil { if val.(bool) { randNum := rand.Int31n(10000) if randNum < triggerFactor { panic("panic is triggered by random fail") } } - }) + } } // It's used only when spill is triggered diff --git a/pkg/executor/sortexec/sort_util.go__failpoint_stash__ b/pkg/executor/sortexec/sort_util.go__failpoint_stash__ new file mode 100644 index 0000000000000..59ef17f90da2c --- /dev/null +++ b/pkg/executor/sortexec/sort_util.go__failpoint_stash__ @@ -0,0 +1,124 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sortexec + +import ( + "math/rand" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.uber.org/zap" +) + +var errSpillEmptyChunk = errors.New("can not spill empty chunk to disk") +var errFailToAddChunk = errors.New("fail to add chunk") + +// It should be const, but we need to modify it for test. +var spillChunkSize = 1024 + +// signalCheckpointForSort indicates the times of row comparation that a signal detection will be triggered. +const signalCheckpointForSort uint = 10240 + +const ( + notSpilled = iota + needSpill + inSpilling + spillTriggered +) + +type rowWithPartition struct { + row chunk.Row + partitionID int +} + +func processPanicAndLog(errOutputChan chan<- rowWithError, r any) { + err := util.GetRecoverError(r) + errOutputChan <- rowWithError{err: err} + logutil.BgLogger().Error("executor panicked", zap.Error(err), zap.Stack("stack")) +} + +// chunkWithMemoryUsage contains chunk and memory usage. +// However, some of memory usage may also come from other place, +// not only the chunk's memory usage. 
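The sort_util.go hunk above follows the same mechanical rewrite as the other failpoint changes in this patch: each failpoint.Inject(name, fn) marker becomes an explicit failpoint.Eval(_curpkg_(name)) call guarded by an if statement, and the pre-rewrite file is checked in next to it as a __failpoint_stash__ copy. A rough, dependency-free sketch of what an enabled call site boils down to, using a hypothetical registry rather than the real pingcap/failpoint API:

// enabledFailpoints and evalFailpoint are hypothetical stand-ins for the
// failpoint registry; they are not part of this patch or of pingcap/failpoint.
var enabledFailpoints = map[string]any{}

func evalFailpoint(name string) (any, bool) {
	v, ok := enabledFailpoints[name]
	return v, ok
}

// A rewritten call site: the closure body of the original Inject marker now
// sits inside a plain if branch that only runs when the failpoint is enabled.
func randomFailSketch() {
	if val, ok := evalFailpoint("ParallelSortRandomFail"); ok {
		if val.(bool) {
			panic("panic is triggered by random fail")
		}
	}
}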
+type chunkWithMemoryUsage struct { + Chk *chunk.Chunk + MemoryUsage int64 +} + +type rowWithError struct { + row chunk.Row + err error +} + +func injectParallelSortRandomFail(triggerFactor int32) { + failpoint.Inject("ParallelSortRandomFail", func(val failpoint.Value) { + if val.(bool) { + randNum := rand.Int31n(10000) + if randNum < triggerFactor { + panic("panic is triggered by random fail") + } + } + }) +} + +// It's used only when spill is triggered +type dataCursor struct { + chkID int + chunkIter *chunk.Iterator4Chunk +} + +// NewDataCursor creates a new dataCursor +func NewDataCursor() *dataCursor { + return &dataCursor{ + chkID: -1, + chunkIter: chunk.NewIterator4Chunk(nil), + } +} + +func (d *dataCursor) getChkID() int { + return d.chkID +} + +func (d *dataCursor) begin() chunk.Row { + return d.chunkIter.Begin() +} + +func (d *dataCursor) next() chunk.Row { + return d.chunkIter.Next() +} + +func (d *dataCursor) setChunk(chk *chunk.Chunk, chkID int) { + d.chkID = chkID + d.chunkIter.ResetChunk(chk) +} + +func reloadCursor(cursor *dataCursor, inDisk *chunk.DataInDiskByChunks) (bool, error) { + spilledChkNum := inDisk.NumChunks() + restoredChkID := cursor.getChkID() + 1 + if restoredChkID >= spilledChkNum { + // All data has been consumed + return false, nil + } + + chk, err := inDisk.GetChunk(restoredChkID) + if err != nil { + return false, err + } + cursor.setChunk(chk, restoredChkID) + return true, nil +} diff --git a/pkg/executor/sortexec/topn.go b/pkg/executor/sortexec/topn.go index 5d49cb80703d4..a3e48c340e847 100644 --- a/pkg/executor/sortexec/topn.go +++ b/pkg/executor/sortexec/topn.go @@ -613,14 +613,14 @@ func (e *TopNExec) GetInMemoryThenSpillFlagForTest() bool { } func injectTopNRandomFail(triggerFactor int32) { - failpoint.Inject("TopNRandomFail", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("TopNRandomFail")); _err_ == nil { if val.(bool) { randNum := rand.Int31n(10000) if randNum < triggerFactor { panic("panic is triggered by random fail") } } - }) + } } // InitTopNExecForTest initializes TopN executors, only for test. diff --git a/pkg/executor/sortexec/topn.go__failpoint_stash__ b/pkg/executor/sortexec/topn.go__failpoint_stash__ new file mode 100644 index 0000000000000..5d49cb80703d4 --- /dev/null +++ b/pkg/executor/sortexec/topn.go__failpoint_stash__ @@ -0,0 +1,647 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sortexec + +import ( + "container/heap" + "context" + "math/rand" + "slices" + "sync" + "sync/atomic" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/channel" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/disk" + "github.com/pingcap/tidb/pkg/util/memory" +) + +// TopNExec implements a Top-N algorithm and it is built from a SELECT statement with ORDER BY and LIMIT. +// Instead of sorting all the rows fetched from the table, it keeps the Top-N elements only in a heap to reduce memory usage. +type TopNExec struct { + SortExec + Limit *plannercore.PhysicalLimit + + // It's useful when spill is triggered and the fetcher could know when workers finish their works. + fetcherAndWorkerSyncer *sync.WaitGroup + resultChannel chan rowWithError + chunkChannel chan *chunk.Chunk + + finishCh chan struct{} + + chkHeap *topNChunkHeap + + spillHelper *topNSpillHelper + spillAction *topNSpillAction + + // Normally, heap will be stored in memory after it has been built. + // However, other executors may trigger topn spill after the heap is built + // and inMemoryThenSpillFlag will be set to true at this time. + inMemoryThenSpillFlag bool + + // Topn executor has two stage: + // 1. Building heap, in this stage all received rows will be inserted into heap. + // 2. Updating heap, in this stage only rows that is smaller than the heap top could be inserted and we will drop the heap top. + // + // This variable is only used for test. + isSpillTriggeredInStage1ForTest bool + isSpillTriggeredInStage2ForTest bool + + Concurrency int +} + +// Open implements the Executor Open interface. 
+func (e *TopNExec) Open(ctx context.Context) error { + e.memTracker = memory.NewTracker(e.ID(), -1) + e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) + + e.fetched = &atomic.Bool{} + e.fetched.Store(false) + e.chkHeap = &topNChunkHeap{memTracker: e.memTracker} + e.chkHeap.idx = 0 + + e.finishCh = make(chan struct{}, 1) + e.resultChannel = make(chan rowWithError, e.MaxChunkSize()) + e.chunkChannel = make(chan *chunk.Chunk, e.Concurrency) + e.inMemoryThenSpillFlag = false + e.isSpillTriggeredInStage1ForTest = false + e.isSpillTriggeredInStage2ForTest = false + + if variable.EnableTmpStorageOnOOM.Load() { + e.diskTracker = disk.NewTracker(e.ID(), -1) + diskTracker := e.Ctx().GetSessionVars().StmtCtx.DiskTracker + if diskTracker != nil { + e.diskTracker.AttachTo(diskTracker) + } + e.fetcherAndWorkerSyncer = &sync.WaitGroup{} + + workers := make([]*topNWorker, e.Concurrency) + for i := range workers { + chkHeap := &topNChunkHeap{} + // Offset of heap in worker should be 0, as we need to spill all data + chkHeap.init(e, e.memTracker, e.Limit.Offset+e.Limit.Count, 0, e.greaterRow, e.RetFieldTypes()) + workers[i] = newTopNWorker(i, e.chunkChannel, e.fetcherAndWorkerSyncer, e.resultChannel, e.finishCh, e, chkHeap, e.memTracker) + } + + e.spillHelper = newTopNSpillerHelper( + e, + e.finishCh, + e.resultChannel, + e.memTracker, + e.diskTracker, + exec.RetTypes(e), + workers, + e.Concurrency, + ) + e.spillAction = &topNSpillAction{spillHelper: e.spillHelper} + e.Ctx().GetSessionVars().MemTracker.FallbackOldAndSetNewAction(e.spillAction) + } else { + e.spillHelper = newTopNSpillerHelper(e, nil, nil, nil, nil, nil, nil, 0) + } + + return exec.Open(ctx, e.Children(0)) +} + +// Close implements the Executor Close interface. +func (e *TopNExec) Close() error { + // `e.finishCh == nil` means that `Open` is not called. + if e.finishCh == nil { + return exec.Close(e.Children(0)) + } + + close(e.finishCh) + if e.fetched.CompareAndSwap(false, true) { + close(e.resultChannel) + return exec.Close(e.Children(0)) + } + + // Wait for the finish of all tasks + channel.Clear(e.resultChannel) + + e.chkHeap = nil + e.spillAction = nil + + if e.spillHelper != nil { + e.spillHelper.close() + e.spillHelper = nil + } + + if e.memTracker != nil { + e.memTracker.ReplaceBytesUsed(0) + } + + return exec.Close(e.Children(0)) +} + +func (e *TopNExec) greaterRow(rowI, rowJ chunk.Row) bool { + for i, colIdx := range e.keyColumns { + cmpFunc := e.keyCmpFuncs[i] + cmp := cmpFunc(rowI, colIdx, rowJ, colIdx) + if e.ByItems[i].Desc { + cmp = -cmp + } + if cmp > 0 { + return true + } else if cmp < 0 { + return false + } + } + return false +} + +// Next implements the Executor Next interface. +// +// The following picture shows the procedure of topn when spill is triggered. +/* +Spill Stage: + ┌─────────┐ + │ Child │ + └────▲────┘ + │ + Fetch + │ + ┌───────┴───────┐ + │ Chunk Fetcher │ + └───────┬───────┘ + │ + │ + ▼ + Check Spill──────►Spill Triggered─────────►Spill + │ │ + ▼ │ + Spill Not Triggered │ + │ │ + ▼ │ + Push Chunk◄─────────────────────────────────┘ + │ + ▼ + ┌────────────────►Channel◄───────────────────┐ + │ ▲ │ + │ │ │ + Fetch Fetch Fetch + │ │ │ + ┌────┴───┐ ┌───┴────┐ ┌───┴────┐ + │ Worker │ │ Worker │ ...... │ Worker │ + └────┬───┘ └───┬────┘ └───┬────┘ + │ │ │ + │ │ │ + │ ▼ │ + └───────────► Multi-way Merge◄───────────────┘ + │ + │ + ▼ + Output + +Restore Stage: + ┌────────┐ ┌────────┐ ┌────────┐ + │ Heap │ │ Heap │ ...... 
│ Heap │ + └────┬───┘ └───┬────┘ └───┬────┘ + │ │ │ + │ │ │ + │ ▼ │ + └───────────► Multi-way Merge◄───────────────┘ + │ + │ + ▼ + Output + +*/ +func (e *TopNExec) Next(ctx context.Context, req *chunk.Chunk) error { + req.Reset() + if e.fetched.CompareAndSwap(false, true) { + err := e.fetchChunks(ctx) + if err != nil { + return err + } + } + + if !req.IsFull() { + numToAppend := req.RequiredRows() - req.NumRows() + for i := 0; i < numToAppend; i++ { + row, ok := <-e.resultChannel + if !ok || row.err != nil { + return row.err + } + req.AppendRow(row.row) + } + } + return nil +} + +func (e *TopNExec) fetchChunks(ctx context.Context) error { + defer func() { + if r := recover(); r != nil { + processPanicAndLog(e.resultChannel, r) + close(e.resultChannel) + } + }() + + err := e.loadChunksUntilTotalLimit(ctx) + if err != nil { + close(e.resultChannel) + return err + } + go e.executeTopN(ctx) + return nil +} + +func (e *TopNExec) loadChunksUntilTotalLimit(ctx context.Context) error { + err := e.initCompareFuncs(e.Ctx().GetExprCtx().GetEvalCtx()) + if err != nil { + return err + } + + e.buildKeyColumns() + e.chkHeap.init(e, e.memTracker, e.Limit.Offset+e.Limit.Count, int(e.Limit.Offset), e.greaterRow, e.RetFieldTypes()) + for uint64(e.chkHeap.rowChunks.Len()) < e.chkHeap.totalLimit { + srcChk := exec.TryNewCacheChunk(e.Children(0)) + // adjust required rows by total limit + srcChk.SetRequiredRows(int(e.chkHeap.totalLimit-uint64(e.chkHeap.rowChunks.Len())), e.MaxChunkSize()) + err := exec.Next(ctx, e.Children(0), srcChk) + if err != nil { + return err + } + if srcChk.NumRows() == 0 { + break + } + e.chkHeap.rowChunks.Add(srcChk) + if e.spillHelper.isSpillNeeded() { + e.isSpillTriggeredInStage1ForTest = true + break + } + + injectTopNRandomFail(1) + } + + e.chkHeap.initPtrs() + return nil +} + +const topNCompactionFactor = 4 + +func (e *TopNExec) executeTopNWhenNoSpillTriggered(ctx context.Context) error { + if e.spillHelper.isSpillNeeded() { + e.isSpillTriggeredInStage2ForTest = true + return nil + } + + childRowChk := exec.TryNewCacheChunk(e.Children(0)) + for { + if e.spillHelper.isSpillNeeded() { + e.isSpillTriggeredInStage2ForTest = true + return nil + } + + err := exec.Next(ctx, e.Children(0), childRowChk) + if err != nil { + return err + } + + if childRowChk.NumRows() == 0 { + break + } + + e.chkHeap.processChk(childRowChk) + + if e.chkHeap.rowChunks.Len() > len(e.chkHeap.rowPtrs)*topNCompactionFactor { + err = e.chkHeap.doCompaction(e) + if err != nil { + return err + } + } + injectTopNRandomFail(10) + } + + slices.SortFunc(e.chkHeap.rowPtrs, e.chkHeap.keyColumnsCompare) + return nil +} + +func (e *TopNExec) spillRemainingRowsWhenNeeded() error { + if e.spillHelper.isSpillTriggered() { + return e.spillHelper.spill() + } + return nil +} + +func (e *TopNExec) checkSpillAndExecute() error { + if e.spillHelper.isSpillNeeded() { + // Wait for the stop of all workers + e.fetcherAndWorkerSyncer.Wait() + return e.spillHelper.spill() + } + return nil +} + +func (e *TopNExec) fetchChunksFromChild(ctx context.Context) { + defer func() { + if r := recover(); r != nil { + processPanicAndLog(e.resultChannel, r) + } + + e.fetcherAndWorkerSyncer.Wait() + err := e.spillRemainingRowsWhenNeeded() + if err != nil { + e.resultChannel <- rowWithError{err: err} + } + + close(e.chunkChannel) + }() + + for { + chk := exec.TryNewCacheChunk(e.Children(0)) + err := exec.Next(ctx, e.Children(0), chk) + if err != nil { + e.resultChannel <- rowWithError{err: err} + return + } + + rowCount := chk.NumRows() + if rowCount 
== 0 { + break + } + + e.fetcherAndWorkerSyncer.Add(1) + select { + case <-e.finishCh: + e.fetcherAndWorkerSyncer.Done() + return + case e.chunkChannel <- chk: + } + + injectTopNRandomFail(10) + + err = e.checkSpillAndExecute() + if err != nil { + e.resultChannel <- rowWithError{err: err} + return + } + } +} + +// Spill the heap which is in TopN executor +func (e *TopNExec) spillTopNExecHeap() error { + e.spillHelper.setInSpilling() + defer e.spillHelper.cond.Broadcast() + defer e.spillHelper.setNotSpilled() + + err := e.spillHelper.spillHeap(e.chkHeap) + if err != nil { + return err + } + return nil +} + +func (e *TopNExec) executeTopNWhenSpillTriggered(ctx context.Context) error { + // idx need to be set to 0 as we need to spill all data + e.chkHeap.idx = 0 + err := e.spillTopNExecHeap() + if err != nil { + return err + } + + // Wait for the finish of chunk fetcher + fetcherWaiter := util.WaitGroupWrapper{} + // Wait for the finish of all workers + workersWaiter := util.WaitGroupWrapper{} + + for i := range e.spillHelper.workers { + worker := e.spillHelper.workers[i] + worker.initWorker() + workersWaiter.Run(func() { + worker.run() + }) + } + + // Fetch chunks from child and put chunks into chunkChannel + fetcherWaiter.Run(func() { + e.fetchChunksFromChild(ctx) + }) + + fetcherWaiter.Wait() + workersWaiter.Wait() + return nil +} + +func (e *TopNExec) executeTopN(ctx context.Context) { + defer func() { + if r := recover(); r != nil { + processPanicAndLog(e.resultChannel, r) + } + + close(e.resultChannel) + }() + + heap.Init(e.chkHeap) + for uint64(len(e.chkHeap.rowPtrs)) > e.chkHeap.totalLimit { + // The number of rows we loaded may exceeds total limit, remove greatest rows by Pop. + heap.Pop(e.chkHeap) + } + + if err := e.executeTopNWhenNoSpillTriggered(ctx); err != nil { + e.resultChannel <- rowWithError{err: err} + return + } + + if e.spillHelper.isSpillNeeded() { + if err := e.executeTopNWhenSpillTriggered(ctx); err != nil { + e.resultChannel <- rowWithError{err: err} + return + } + } + + e.generateTopNResults() +} + +// Return true when spill is triggered +func (e *TopNExec) generateTopNResultsWhenNoSpillTriggered() bool { + rowPtrNum := len(e.chkHeap.rowPtrs) + for ; e.chkHeap.idx < rowPtrNum; e.chkHeap.idx++ { + if e.chkHeap.idx%10 == 0 && e.spillHelper.isSpillNeeded() { + return true + } + e.resultChannel <- rowWithError{row: e.chkHeap.rowChunks.GetRow(e.chkHeap.rowPtrs[e.chkHeap.idx])} + } + return false +} + +func (e *TopNExec) generateResultWithMultiWayMerge(offset int64, limit int64) error { + multiWayMerge := newMultiWayMerger(&diskSource{sortedRowsInDisk: e.spillHelper.sortedRowsInDisk}, e.lessRow) + + err := multiWayMerge.init() + if err != nil { + return err + } + + outputRowNum := int64(0) + for { + if outputRowNum >= limit { + return nil + } + + row, err := multiWayMerge.next() + if err != nil { + return err + } + + if row.IsEmpty() { + return nil + } + + if outputRowNum >= offset { + select { + case <-e.finishCh: + return nil + case e.resultChannel <- rowWithError{row: row}: + } + } + outputRowNum++ + injectParallelSortRandomFail(1) + } +} + +// GenerateTopNResultsWhenSpillOnlyOnce generates results with this function when we trigger spill only once. +// It's a public function as we need to test it in ut. 
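executeTopN above shows the bounded-heap idea behind TopN: heap.Init builds a max-heap over the rows loaded so far and heap.Pop keeps discarding the current greatest row until only offset+count rows remain, after which a new row is only worth keeping if it beats the heap top. A minimal, dependency-free sketch of the same behavior (topNSketch is hypothetical; a linear scan stands in for the heap, which changes the complexity but not the result):

func topNSketch(limit int, values []int) []int {
	kept := make([]int, 0, limit)
	for _, v := range values {
		if len(kept) < limit {
			kept = append(kept, v)
			continue
		}
		// Find the current worst (largest) kept value; the real code keeps it
		// at the top of a max-heap, so finding it is O(1) and replacing it is
		// O(log n).
		worst := 0
		for i := range kept {
			if kept[i] > kept[worst] {
				worst = i
			}
		}
		if v < kept[worst] {
			kept[worst] = v
		}
	}
	return kept
}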
+func (e *TopNExec) GenerateTopNResultsWhenSpillOnlyOnce() error { + inDisk := e.spillHelper.sortedRowsInDisk[0] + chunkNum := inDisk.NumChunks() + skippedRowNum := uint64(0) + offset := e.Limit.Offset + for i := 0; i < chunkNum; i++ { + chk, err := inDisk.GetChunk(i) + if err != nil { + return err + } + + injectTopNRandomFail(10) + + rowNum := chk.NumRows() + j := 0 + if !e.inMemoryThenSpillFlag { + // When e.inMemoryThenSpillFlag == false, we need to manually set j + // because rows that should be ignored before offset have also been + // spilled to disk. + if skippedRowNum < offset { + rowNumNeedSkip := offset - skippedRowNum + if rowNum <= int(rowNumNeedSkip) { + // All rows in this chunk should be skipped + skippedRowNum += uint64(rowNum) + continue + } + j += int(rowNumNeedSkip) + skippedRowNum += rowNumNeedSkip + } + } + + for ; j < rowNum; j++ { + select { + case <-e.finishCh: + return nil + case e.resultChannel <- rowWithError{row: chk.GetRow(j)}: + } + } + } + return nil +} + +func (e *TopNExec) generateTopNResultsWhenSpillTriggered() error { + inDiskNum := len(e.spillHelper.sortedRowsInDisk) + if inDiskNum == 0 { + panic("inDiskNum can't be 0 when we generate result with spill triggered") + } + + if inDiskNum == 1 { + return e.GenerateTopNResultsWhenSpillOnlyOnce() + } + return e.generateResultWithMultiWayMerge(int64(e.Limit.Offset), int64(e.Limit.Offset+e.Limit.Count)) +} + +func (e *TopNExec) generateTopNResults() { + if !e.spillHelper.isSpillTriggered() { + if !e.generateTopNResultsWhenNoSpillTriggered() { + return + } + + err := e.spillTopNExecHeap() + if err != nil { + e.resultChannel <- rowWithError{err: err} + } + + e.inMemoryThenSpillFlag = true + } + + err := e.generateTopNResultsWhenSpillTriggered() + if err != nil { + e.resultChannel <- rowWithError{err: err} + } +} + +// IsSpillTriggeredForTest shows if spill is triggered, used for test. +func (e *TopNExec) IsSpillTriggeredForTest() bool { + return e.spillHelper.isSpillTriggered() +} + +// GetIsSpillTriggeredInStage1ForTest shows if spill is triggered in stage 1, only used for test. +func (e *TopNExec) GetIsSpillTriggeredInStage1ForTest() bool { + return e.isSpillTriggeredInStage1ForTest +} + +// GetIsSpillTriggeredInStage2ForTest shows if spill is triggered in stage 2, only used for test. +func (e *TopNExec) GetIsSpillTriggeredInStage2ForTest() bool { + return e.isSpillTriggeredInStage2ForTest +} + +// GetInMemoryThenSpillFlagForTest shows if results are in memory before they are spilled, only used for test +func (e *TopNExec) GetInMemoryThenSpillFlagForTest() bool { + return e.inMemoryThenSpillFlag +} + +func injectTopNRandomFail(triggerFactor int32) { + failpoint.Inject("TopNRandomFail", func(val failpoint.Value) { + if val.(bool) { + randNum := rand.Int31n(10000) + if randNum < triggerFactor { + panic("panic is triggered by random fail") + } + } + }) +} + +// InitTopNExecForTest initializes TopN executors, only for test. +func InitTopNExecForTest(topnExec *TopNExec, offset uint64, sortedRowsInDisk *chunk.DataInDiskByChunks) { + topnExec.inMemoryThenSpillFlag = false + topnExec.finishCh = make(chan struct{}, 1) + topnExec.resultChannel = make(chan rowWithError, 10000) + topnExec.Limit.Offset = offset + topnExec.spillHelper = &topNSpillHelper{} + topnExec.spillHelper.sortedRowsInDisk = []*chunk.DataInDiskByChunks{sortedRowsInDisk} +} + +// GetResultForTest gets result, only for test. 
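GenerateTopNResultsWhenSpillOnlyOnce above still has to honor Limit.Offset when the heap was spilled before the offset rows could be dropped, so it counts skipped rows across chunk boundaries and starts emitting partway through the first chunk that crosses the offset. A small sketch of that bookkeeping (firstRowToEmit is a hypothetical helper that works on chunk sizes only):

func firstRowToEmit(chunkSizes []int, offset int) []int {
	starts := make([]int, len(chunkSizes))
	skipped := 0
	for i, n := range chunkSizes {
		switch {
		case skipped >= offset:
			starts[i] = 0 // nothing left to skip, emit the whole chunk
		case offset-skipped >= n:
			starts[i] = n // the whole chunk falls before the offset
			skipped += n
		default:
			starts[i] = offset - skipped // emit from this row onwards
			skipped = offset
		}
	}
	return starts
}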
+func GetResultForTest(topnExec *TopNExec) []int64 { + close(topnExec.resultChannel) + result := make([]int64, 0, 100) + for { + row, ok := <-topnExec.resultChannel + if !ok { + return result + } + result = append(result, row.row.GetInt64(0)) + } +} diff --git a/pkg/executor/sortexec/topn_worker.go b/pkg/executor/sortexec/topn_worker.go index 527dcd42977e9..45b2b75843ff8 100644 --- a/pkg/executor/sortexec/topn_worker.go +++ b/pkg/executor/sortexec/topn_worker.go @@ -117,7 +117,7 @@ func (t *topNWorker) run() { func (t *topNWorker) injectFailPointForTopNWorker(triggerFactor int32) { injectTopNRandomFail(triggerFactor) - failpoint.Inject("SlowSomeWorkers", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("SlowSomeWorkers")); _err_ == nil { if val.(bool) { if t.workerIDForTest%2 == 0 { randNum := rand.Int31n(10000) @@ -126,5 +126,5 @@ func (t *topNWorker) injectFailPointForTopNWorker(triggerFactor int32) { } } } - }) + } } diff --git a/pkg/executor/sortexec/topn_worker.go__failpoint_stash__ b/pkg/executor/sortexec/topn_worker.go__failpoint_stash__ new file mode 100644 index 0000000000000..527dcd42977e9 --- /dev/null +++ b/pkg/executor/sortexec/topn_worker.go__failpoint_stash__ @@ -0,0 +1,130 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sortexec + +import ( + "container/heap" + "math/rand" + "sync" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/memory" +) + +// topNWorker is used only when topn spill is triggered +type topNWorker struct { + workerIDForTest int + + chunkChannel <-chan *chunk.Chunk + fetcherAndWorkerSyncer *sync.WaitGroup + errOutputChan chan<- rowWithError + finishChan <-chan struct{} + + topn *TopNExec + chkHeap *topNChunkHeap + memTracker *memory.Tracker +} + +func newTopNWorker( + idForTest int, + chunkChannel <-chan *chunk.Chunk, + fetcherAndWorkerSyncer *sync.WaitGroup, + errOutputChan chan<- rowWithError, + finishChan <-chan struct{}, + topn *TopNExec, + chkHeap *topNChunkHeap, + memTracker *memory.Tracker) *topNWorker { + return &topNWorker{ + workerIDForTest: idForTest, + chunkChannel: chunkChannel, + fetcherAndWorkerSyncer: fetcherAndWorkerSyncer, + errOutputChan: errOutputChan, + finishChan: finishChan, + chkHeap: chkHeap, + topn: topn, + memTracker: memTracker, + } +} + +func (t *topNWorker) initWorker() { + // Offset of heap in worker should be 0, as we need to spill all data + t.chkHeap.init(t.topn, t.memTracker, t.topn.Limit.Offset+t.topn.Limit.Count, 0, t.topn.greaterRow, t.topn.RetFieldTypes()) +} + +func (t *topNWorker) fetchChunksAndProcess() { + for t.fetchChunksAndProcessImpl() { + } +} + +func (t *topNWorker) fetchChunksAndProcessImpl() bool { + select { + case <-t.finishChan: + return false + case chk, ok := <-t.chunkChannel: + if !ok { + return false + } + defer func() { + t.fetcherAndWorkerSyncer.Done() + }() + + t.injectFailPointForTopNWorker(3) + + if uint64(t.chkHeap.rowChunks.Len()) < t.chkHeap.totalLimit { + if !t.chkHeap.isInitialized { + t.chkHeap.init(t.topn, t.memTracker, t.topn.Limit.Offset+t.topn.Limit.Count, 0, t.topn.greaterRow, t.topn.RetFieldTypes()) + } + t.chkHeap.rowChunks.Add(chk) + } else { + if !t.chkHeap.isRowPtrsInit { + t.chkHeap.initPtrs() + heap.Init(t.chkHeap) + } + t.chkHeap.processChk(chk) + } + } + return true +} + +func (t *topNWorker) run() { + defer func() { + if r := recover(); r != nil { + processPanicAndLog(t.errOutputChan, r) + } + + // Consume all chunks to avoid hang of fetcher + for range t.chunkChannel { + t.fetcherAndWorkerSyncer.Done() + } + }() + + t.fetchChunksAndProcess() +} + +func (t *topNWorker) injectFailPointForTopNWorker(triggerFactor int32) { + injectTopNRandomFail(triggerFactor) + failpoint.Inject("SlowSomeWorkers", func(val failpoint.Value) { + if val.(bool) { + if t.workerIDForTest%2 == 0 { + randNum := rand.Int31n(10000) + if randNum < 10 { + time.Sleep(1 * time.Millisecond) + } + } + } + }) +} diff --git a/pkg/executor/table_reader.go b/pkg/executor/table_reader.go index 60c83af717526..f59384d029d81 100644 --- a/pkg/executor/table_reader.go +++ b/pkg/executor/table_reader.go @@ -220,10 +220,10 @@ func (e *TableReaderExecutor) memUsage() int64 { func (e *TableReaderExecutor) Open(ctx context.Context) error { r, ctx := tracing.StartRegionEx(ctx, "TableReaderExecutor.Open") defer r.End() - failpoint.Inject("mockSleepInTableReaderNext", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("mockSleepInTableReaderNext")); _err_ == nil { ms := v.(int) time.Sleep(time.Millisecond * time.Duration(ms)) - }) + } if e.memTracker != nil { e.memTracker.Reset() diff --git a/pkg/executor/table_reader.go__failpoint_stash__ b/pkg/executor/table_reader.go__failpoint_stash__ new file mode 100644 index 0000000000000..60c83af717526 --- /dev/null 
+++ b/pkg/executor/table_reader.go__failpoint_stash__ @@ -0,0 +1,632 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "bytes" + "cmp" + "context" + "slices" + "time" + "unsafe" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/distsql" + distsqlctx "github.com/pingcap/tidb/pkg/distsql/context" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/executor/internal/builder" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + internalutil "github.com/pingcap/tidb/pkg/executor/internal/util" + "github.com/pingcap/tidb/pkg/expression" + exprctx "github.com/pingcap/tidb/pkg/expression/context" + "github.com/pingcap/tidb/pkg/infoschema" + isctx "github.com/pingcap/tidb/pkg/infoschema/context" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/model" + planctx "github.com/pingcap/tidb/pkg/planner/context" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/util" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/ranger" + rangerctx "github.com/pingcap/tidb/pkg/util/ranger/context" + "github.com/pingcap/tidb/pkg/util/size" + "github.com/pingcap/tidb/pkg/util/stringutil" + "github.com/pingcap/tidb/pkg/util/tracing" + "github.com/pingcap/tipb/go-tipb" +) + +// make sure `TableReaderExecutor` implements `Executor`. +var _ exec.Executor = &TableReaderExecutor{} + +// selectResultHook is used to hack distsql.SelectWithRuntimeStats safely for testing. 
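The topNWorker.run loop above (like its parallel-sort counterpart) relies on a simple hand-off protocol: the fetcher pushes chunks into a channel, each worker drains the channel with range, and closing the channel is what lets every worker exit, while the deferred drain keeps the fetcher/worker synchronization balanced even after a panic. A stripped-down, dependency-free sketch of that shape (runWorkersSketch is hypothetical and uses a done channel instead of the WaitGroup):

func runWorkersSketch(work []int, workerCount int, process func(int)) {
	ch := make(chan int)
	done := make(chan struct{})
	for i := 0; i < workerCount; i++ {
		go func() {
			// range exits once the producer closes ch.
			for w := range ch {
				process(w)
			}
			done <- struct{}{}
		}()
	}
	for _, w := range work {
		ch <- w
	}
	close(ch)
	// Wait for every worker to finish before returning.
	for i := 0; i < workerCount; i++ {
		<-done
	}
}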
+type selectResultHook struct { + selectResultFunc func(ctx context.Context, dctx *distsqlctx.DistSQLContext, kvReq *kv.Request, + fieldTypes []*types.FieldType, copPlanIDs []int) (distsql.SelectResult, error) +} + +func (sr selectResultHook) SelectResult(ctx context.Context, dctx *distsqlctx.DistSQLContext, kvReq *kv.Request, + fieldTypes []*types.FieldType, copPlanIDs []int, rootPlanID int) (distsql.SelectResult, error) { + if sr.selectResultFunc == nil { + return distsql.SelectWithRuntimeStats(ctx, dctx, kvReq, fieldTypes, copPlanIDs, rootPlanID) + } + return sr.selectResultFunc(ctx, dctx, kvReq, fieldTypes, copPlanIDs) +} + +type kvRangeBuilder interface { + buildKeyRange(dctx *distsqlctx.DistSQLContext, ranges []*ranger.Range) ([][]kv.KeyRange, error) + buildKeyRangeSeparately(dctx *distsqlctx.DistSQLContext, ranges []*ranger.Range) ([]int64, [][]kv.KeyRange, error) +} + +// tableReaderExecutorContext is the execution context for the `TableReaderExecutor` +type tableReaderExecutorContext struct { + dctx *distsqlctx.DistSQLContext + rctx *rangerctx.RangerContext + buildPBCtx *planctx.BuildPBContext + ectx exprctx.BuildContext + + stmtMemTracker *memory.Tracker + + infoSchema isctx.MetaOnlyInfoSchema + getDDLOwner func(context.Context) (*infosync.ServerInfo, error) +} + +func (treCtx *tableReaderExecutorContext) GetInfoSchema() isctx.MetaOnlyInfoSchema { + return treCtx.infoSchema +} + +func (treCtx *tableReaderExecutorContext) GetDDLOwner(ctx context.Context) (*infosync.ServerInfo, error) { + if treCtx.getDDLOwner != nil { + return treCtx.getDDLOwner(ctx) + } + + return nil, errors.New("GetDDLOwner in a context without DDL") +} + +func newTableReaderExecutorContext(sctx sessionctx.Context) tableReaderExecutorContext { + // Explicitly get `ownerManager` out of the closure to show that the `tableReaderExecutorContext` itself doesn't + // depend on `sctx` directly. + // The context of some tests don't have `DDL`, so make it optional + var getDDLOwner func(ctx context.Context) (*infosync.ServerInfo, error) + ddl := domain.GetDomain(sctx).DDL() + if ddl != nil { + ownerManager := ddl.OwnerManager() + getDDLOwner = func(ctx context.Context) (*infosync.ServerInfo, error) { + ddlOwnerID, err := ownerManager.GetOwnerID(ctx) + if err != nil { + return nil, err + } + return infosync.GetServerInfoByID(ctx, ddlOwnerID) + } + } + + pctx := sctx.GetPlanCtx() + return tableReaderExecutorContext{ + dctx: sctx.GetDistSQLCtx(), + rctx: pctx.GetRangerCtx(), + buildPBCtx: pctx.GetBuildPBCtx(), + ectx: sctx.GetExprCtx(), + stmtMemTracker: sctx.GetSessionVars().StmtCtx.MemTracker, + infoSchema: pctx.GetInfoSchema(), + getDDLOwner: getDDLOwner, + } +} + +// TableReaderExecutor sends DAG request and reads table data from kv layer. +type TableReaderExecutor struct { + tableReaderExecutorContext + exec.BaseExecutorV2 + + table table.Table + + // The source of key ranges varies from case to case. + // It may be calculated from PhysicalPlan by executorBuilder, or calculated from argument by dataBuilder; + // It may be calculated from ranger.Ranger, or calculated from handles. + // The table ID may also change because of the partition table, and causes the key range to change. + // So instead of keeping a `range` struct field, it's better to define a interface. + kvRangeBuilder + // TODO: remove this field, use the kvRangeBuilder interface. + ranges []*ranger.Range + + // kvRanges are only use for union scan. 
+ kvRanges []kv.KeyRange + dagPB *tipb.DAGRequest + startTS uint64 + txnScope string + readReplicaScope string + isStaleness bool + // FIXME: in some cases the data size can be more accurate after get the handles count, + // but we keep things simple as it needn't to be that accurate for now. + netDataSize float64 + // columns are only required by union scan and virtual column. + columns []*model.ColumnInfo + + // resultHandler handles the order of the result. Since (MAXInt64, MAXUint64] stores before [0, MaxInt64] physically + // for unsigned int. + resultHandler *tableResultHandler + plans []base.PhysicalPlan + tablePlan base.PhysicalPlan + + memTracker *memory.Tracker + selectResultHook // for testing + + keepOrder bool + desc bool + // byItems only for partition table with orderBy + pushedLimit + byItems []*util.ByItems + paging bool + storeType kv.StoreType + // corColInFilter tells whether there's correlated column in filter (both conditions in PhysicalSelection and LateMaterializationFilterCondition in PhysicalTableScan) + // If true, we will need to revise the dagPB (fill correlated column value in filter) each time call Open(). + corColInFilter bool + // corColInAccess tells whether there's correlated column in access conditions. + corColInAccess bool + // virtualColumnIndex records all the indices of virtual columns and sort them in definition + // to make sure we can compute the virtual column in right order. + virtualColumnIndex []int + // virtualColumnRetFieldTypes records the RetFieldTypes of virtual columns. + virtualColumnRetFieldTypes []*types.FieldType + // batchCop indicates whether use super batch coprocessor request, only works for TiFlash engine. + batchCop bool + + // If dummy flag is set, this is not a real TableReader, it just provides the KV ranges for UnionScan. + // Used by the temporary table, cached table. + dummy bool +} + +// Table implements the dataSourceExecutor interface. +func (e *TableReaderExecutor) Table() table.Table { + return e.table +} + +func (e *TableReaderExecutor) setDummy() { + e.dummy = true +} + +func (e *TableReaderExecutor) memUsage() int64 { + const sizeofTableReaderExecutor = int64(unsafe.Sizeof(*(*TableReaderExecutor)(nil))) + + res := sizeofTableReaderExecutor + res += size.SizeOfPointer * int64(cap(e.ranges)) + for _, v := range e.ranges { + res += v.MemUsage() + } + res += kv.KeyRangeSliceMemUsage(e.kvRanges) + res += int64(e.dagPB.Size()) + // TODO: add more statistics + return res +} + +// Open initializes necessary variables for using this executor. 
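Open below splits the requested ranges with distsql.SplitRangesAcrossInt64Boundary because, as the resultHandler comment above notes, (MaxInt64, MaxUint64] is stored physically before [0, MaxInt64] for unsigned handles, so one logically increasing unsigned range may need two key ranges. A hedged, simplified sketch of that split (splitUnsignedRange is hypothetical and ignores descending scans and common handles):

func splitUnsignedRange(lo, hi uint64) (first, second [2]uint64, needSecond bool) {
	const maxInt64 = uint64(1<<63 - 1)
	// The range stays within one physical region: no split is needed.
	if lo > maxInt64 || hi <= maxInt64 {
		return [2]uint64{lo, hi}, [2]uint64{}, false
	}
	// The range crosses the boundary: scan [lo, MaxInt64] and
	// [MaxInt64+1, hi] as two separate key ranges.
	return [2]uint64{lo, maxInt64}, [2]uint64{maxInt64 + 1, hi}, true
}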
+func (e *TableReaderExecutor) Open(ctx context.Context) error { + r, ctx := tracing.StartRegionEx(ctx, "TableReaderExecutor.Open") + defer r.End() + failpoint.Inject("mockSleepInTableReaderNext", func(v failpoint.Value) { + ms := v.(int) + time.Sleep(time.Millisecond * time.Duration(ms)) + }) + + if e.memTracker != nil { + e.memTracker.Reset() + } else { + e.memTracker = memory.NewTracker(e.ID(), -1) + } + e.memTracker.AttachTo(e.stmtMemTracker) + + var err error + if e.corColInFilter { + // If there's correlated column in filter, need to rewrite dagPB + if e.storeType == kv.TiFlash { + execs, err := builder.ConstructTreeBasedDistExec(e.buildPBCtx, e.tablePlan) + if err != nil { + return err + } + e.dagPB.RootExecutor = execs[0] + } else { + e.dagPB.Executors, err = builder.ConstructListBasedDistExec(e.buildPBCtx, e.plans) + if err != nil { + return err + } + } + } + if e.dctx.RuntimeStatsColl != nil { + collExec := true + e.dagPB.CollectExecutionSummaries = &collExec + } + if e.corColInAccess { + ts := e.plans[0].(*plannercore.PhysicalTableScan) + e.ranges, err = ts.ResolveCorrelatedColumns() + if err != nil { + return err + } + } + + e.resultHandler = &tableResultHandler{} + + firstPartRanges, secondPartRanges := distsql.SplitRangesAcrossInt64Boundary(e.ranges, e.keepOrder, e.desc, e.table.Meta() != nil && e.table.Meta().IsCommonHandle) + + // Treat temporary table as dummy table, avoid sending distsql request to TiKV. + // Calculate the kv ranges here, UnionScan rely on this kv ranges. + // cached table and temporary table are similar + if e.dummy { + if e.desc && len(secondPartRanges) != 0 { + // TiKV support reverse scan and the `resultHandler` process the range order. + // While in UnionScan, it doesn't use reverse scan and reverse the final result rows manually. + // So things are differ, we need to reverse the kv range here. + // TODO: If we refactor UnionScan to use reverse scan, update the code here. + // [9734095886065816708 9734095886065816709] | [1 3] [65535 9734095886065816707] => before the following change + // [1 3] [65535 9734095886065816707] | [9734095886065816708 9734095886065816709] => ranges part reverse here + // [1 3 65535 9734095886065816707 9734095886065816708 9734095886065816709] => scan (normal order) in UnionScan + // [9734095886065816709 9734095886065816708 9734095886065816707 65535 3 1] => rows reverse in UnionScan + firstPartRanges, secondPartRanges = secondPartRanges, firstPartRanges + } + kvReq, err := e.buildKVReq(ctx, firstPartRanges) + if err != nil { + return err + } + e.kvRanges = kvReq.KeyRanges.AppendSelfTo(e.kvRanges) + if len(secondPartRanges) != 0 { + kvReq, err = e.buildKVReq(ctx, secondPartRanges) + if err != nil { + return err + } + e.kvRanges = kvReq.KeyRanges.AppendSelfTo(e.kvRanges) + } + return nil + } + + firstResult, err := e.buildResp(ctx, firstPartRanges) + if err != nil { + return err + } + if len(secondPartRanges) == 0 { + e.resultHandler.open(nil, firstResult) + return nil + } + var secondResult distsql.SelectResult + secondResult, err = e.buildResp(ctx, secondPartRanges) + if err != nil { + return err + } + e.resultHandler.open(firstResult, secondResult) + return nil +} + +// Next fills data into the chunk passed by its caller. +// The task was actually done by tableReaderHandler. +func (e *TableReaderExecutor) Next(ctx context.Context, req *chunk.Chunk) error { + if e.dummy { + // Treat temporary table as dummy table, avoid sending distsql request to TiKV. 
+ req.Reset() + return nil + } + + logutil.Eventf(ctx, "table scan table: %s, range: %v", stringutil.MemoizeStr(func() string { + var tableName string + if meta := e.table.Meta(); meta != nil { + tableName = meta.Name.L + } + return tableName + }), e.ranges) + if err := e.resultHandler.nextChunk(ctx, req); err != nil { + return err + } + + err := table.FillVirtualColumnValue(e.virtualColumnRetFieldTypes, e.virtualColumnIndex, e.Schema().Columns, e.columns, e.ectx, req) + if err != nil { + return err + } + + return nil +} + +// Close implements the Executor Close interface. +func (e *TableReaderExecutor) Close() error { + var err error + if e.resultHandler != nil { + err = e.resultHandler.Close() + } + e.kvRanges = e.kvRanges[:0] + if e.dummy { + return nil + } + return err +} + +// buildResp first builds request and sends it to tikv using distsql.Select. It uses SelectResult returned by the callee +// to fetch all results. +func (e *TableReaderExecutor) buildResp(ctx context.Context, ranges []*ranger.Range) (distsql.SelectResult, error) { + if e.storeType == kv.TiFlash && e.kvRangeBuilder != nil { + if !e.batchCop { + // TiFlash cannot support to access multiple tables/partitions within one KVReq, so we have to build KVReq for each partition separately. + kvReqs, err := e.buildKVReqSeparately(ctx, ranges) + if err != nil { + return nil, err + } + var results []distsql.SelectResult + for _, kvReq := range kvReqs { + result, err := e.SelectResult(ctx, e.dctx, kvReq, exec.RetTypes(e), getPhysicalPlanIDs(e.plans), e.ID()) + if err != nil { + return nil, err + } + results = append(results, result) + } + return distsql.NewSerialSelectResults(results), nil + } + // Use PartitionTable Scan + kvReq, err := e.buildKVReqForPartitionTableScan(ctx, ranges) + if err != nil { + return nil, err + } + result, err := e.SelectResult(ctx, e.dctx, kvReq, exec.RetTypes(e), getPhysicalPlanIDs(e.plans), e.ID()) + if err != nil { + return nil, err + } + return result, nil + } + + // use sortedSelectResults here when pushDown limit for partition table. + if e.kvRangeBuilder != nil && e.byItems != nil { + kvReqs, err := e.buildKVReqSeparately(ctx, ranges) + if err != nil { + return nil, err + } + var results []distsql.SelectResult + for _, kvReq := range kvReqs { + result, err := e.SelectResult(ctx, e.dctx, kvReq, exec.RetTypes(e), getPhysicalPlanIDs(e.plans), e.ID()) + if err != nil { + return nil, err + } + results = append(results, result) + } + if len(results) == 1 { + return results[0], nil + } + return distsql.NewSortedSelectResults(e.ectx.GetEvalCtx(), results, e.Schema(), e.byItems, e.memTracker), nil + } + + kvReq, err := e.buildKVReq(ctx, ranges) + if err != nil { + return nil, err + } + kvReq.KeyRanges.SortByFunc(func(i, j kv.KeyRange) int { + return bytes.Compare(i.StartKey, j.StartKey) + }) + e.kvRanges = kvReq.KeyRanges.AppendSelfTo(e.kvRanges) + + result, err := e.SelectResult(ctx, e.dctx, kvReq, exec.RetTypes(e), getPhysicalPlanIDs(e.plans), e.ID()) + if err != nil { + return nil, err + } + return result, nil +} + +func (e *TableReaderExecutor) buildKVReqSeparately(ctx context.Context, ranges []*ranger.Range) ([]*kv.Request, error) { + pids, kvRanges, err := e.kvRangeBuilder.buildKeyRangeSeparately(e.dctx, ranges) + if err != nil { + return nil, err + } + kvReqs := make([]*kv.Request, 0, len(kvRanges)) + for i, kvRange := range kvRanges { + e.kvRanges = append(e.kvRanges, kvRange...) 
+ if err := internalutil.UpdateExecutorTableID(ctx, e.dagPB.RootExecutor, true, []int64{pids[i]}); err != nil { + return nil, err + } + var builder distsql.RequestBuilder + reqBuilder := builder.SetKeyRanges(kvRange) + kvReq, err := reqBuilder. + SetDAGRequest(e.dagPB). + SetStartTS(e.startTS). + SetDesc(e.desc). + SetKeepOrder(e.keepOrder). + SetTxnScope(e.txnScope). + SetReadReplicaScope(e.readReplicaScope). + SetFromSessionVars(e.dctx). + SetFromInfoSchema(e.GetInfoSchema()). + SetMemTracker(e.memTracker). + SetStoreType(e.storeType). + SetPaging(e.paging). + SetAllowBatchCop(e.batchCop). + SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.dctx, &reqBuilder.Request, e.netDataSize)). + SetConnIDAndConnAlias(e.dctx.ConnectionID, e.dctx.SessionAlias). + Build() + if err != nil { + return nil, err + } + kvReqs = append(kvReqs, kvReq) + } + return kvReqs, nil +} + +func (e *TableReaderExecutor) buildKVReqForPartitionTableScan(ctx context.Context, ranges []*ranger.Range) (*kv.Request, error) { + pids, kvRanges, err := e.kvRangeBuilder.buildKeyRangeSeparately(e.dctx, ranges) + if err != nil { + return nil, err + } + partitionIDAndRanges := make([]kv.PartitionIDAndRanges, 0, len(pids)) + for i, kvRange := range kvRanges { + e.kvRanges = append(e.kvRanges, kvRange...) + partitionIDAndRanges = append(partitionIDAndRanges, kv.PartitionIDAndRanges{ + ID: pids[i], + KeyRanges: kvRange, + }) + } + if err := internalutil.UpdateExecutorTableID(ctx, e.dagPB.RootExecutor, true, pids); err != nil { + return nil, err + } + var builder distsql.RequestBuilder + reqBuilder := builder.SetPartitionIDAndRanges(partitionIDAndRanges) + kvReq, err := reqBuilder. + SetDAGRequest(e.dagPB). + SetStartTS(e.startTS). + SetDesc(e.desc). + SetKeepOrder(e.keepOrder). + SetTxnScope(e.txnScope). + SetReadReplicaScope(e.readReplicaScope). + SetFromSessionVars(e.dctx). + SetFromInfoSchema(e.GetInfoSchema()). + SetMemTracker(e.memTracker). + SetStoreType(e.storeType). + SetPaging(e.paging). + SetAllowBatchCop(e.batchCop). + SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.dctx, &reqBuilder.Request, e.netDataSize)). + SetConnIDAndConnAlias(e.dctx.ConnectionID, e.dctx.SessionAlias). + Build() + if err != nil { + return nil, err + } + return kvReq, nil +} + +func (e *TableReaderExecutor) buildKVReq(ctx context.Context, ranges []*ranger.Range) (*kv.Request, error) { + var builder distsql.RequestBuilder + var reqBuilder *distsql.RequestBuilder + if e.kvRangeBuilder != nil { + kvRange, err := e.kvRangeBuilder.buildKeyRange(e.dctx, ranges) + if err != nil { + return nil, err + } + reqBuilder = builder.SetPartitionKeyRanges(kvRange) + } else { + reqBuilder = builder.SetHandleRanges(e.dctx, getPhysicalTableID(e.table), e.table.Meta() != nil && e.table.Meta().IsCommonHandle, ranges) + } + if e.table != nil && e.table.Type().IsClusterTable() { + copDestination := infoschema.GetClusterTableCopDestination(e.table.Meta().Name.L) + if copDestination == infoschema.DDLOwner { + serverInfo, err := e.GetDDLOwner(ctx) + if err != nil { + return nil, err + } + reqBuilder.SetTiDBServerID(serverInfo.ServerIDGetter()) + } + } + reqBuilder. + SetDAGRequest(e.dagPB). + SetStartTS(e.startTS). + SetDesc(e.desc). + SetKeepOrder(e.keepOrder). + SetTxnScope(e.txnScope). + SetReadReplicaScope(e.readReplicaScope). + SetIsStaleness(e.isStaleness). + SetFromSessionVars(e.dctx). + SetFromInfoSchema(e.GetInfoSchema()). + SetMemTracker(e.memTracker). + SetStoreType(e.storeType). + SetAllowBatchCop(e.batchCop). 
+		SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.dctx, &reqBuilder.Request, e.netDataSize)).
+		SetPaging(e.paging).
+		SetConnIDAndConnAlias(e.dctx.ConnectionID, e.dctx.SessionAlias)
+	return reqBuilder.Build()
+}
+
+func buildVirtualColumnIndex(schema *expression.Schema, columns []*model.ColumnInfo) []int {
+	virtualColumnIndex := make([]int, 0, len(columns))
+	for i, col := range schema.Columns {
+		if col.VirtualExpr != nil {
+			virtualColumnIndex = append(virtualColumnIndex, i)
+		}
+	}
+	slices.SortFunc(virtualColumnIndex, func(i, j int) int {
+		return cmp.Compare(plannercore.FindColumnInfoByID(columns, schema.Columns[i].ID).Offset,
+			plannercore.FindColumnInfoByID(columns, schema.Columns[j].ID).Offset)
+	})
+	return virtualColumnIndex
+}
+
+// buildVirtualColumnInfo saves virtual column indices and sorts them in definition order.
+func (e *TableReaderExecutor) buildVirtualColumnInfo() {
+	e.virtualColumnIndex, e.virtualColumnRetFieldTypes = buildVirtualColumnInfo(e.Schema(), e.columns)
+}
+
+// buildVirtualColumnInfo saves virtual column indices and sorts them in definition order.
+func buildVirtualColumnInfo(schema *expression.Schema, columns []*model.ColumnInfo) (colIndexs []int, retTypes []*types.FieldType) {
+	colIndexs = buildVirtualColumnIndex(schema, columns)
+	if len(colIndexs) > 0 {
+		retTypes = make([]*types.FieldType, len(colIndexs))
+		for i, idx := range colIndexs {
+			retTypes[i] = schema.Columns[idx].RetType
+		}
+	}
+	return colIndexs, retTypes
+}
+
+type tableResultHandler struct {
+	// If the pk is unsigned and we have KeepOrder=true and want ascending order,
+	// `optionalResult` handles the requests whose ranges are within the signed int range, and
+	// `result` handles the requests whose ranges exceed the signed int range.
+	// If we want descending order, `optionalResult` handles the requests whose ranges exceed the signed int range,
+	// and `result` handles the requests whose ranges are within the signed int range.
+	// Otherwise, we just set `optionalFinished` to true and `result` handles all the ranges.
+ optionalResult distsql.SelectResult + result distsql.SelectResult + + optionalFinished bool +} + +func (tr *tableResultHandler) open(optionalResult, result distsql.SelectResult) { + if optionalResult == nil { + tr.optionalFinished = true + tr.result = result + return + } + tr.optionalResult = optionalResult + tr.result = result + tr.optionalFinished = false +} + +func (tr *tableResultHandler) nextChunk(ctx context.Context, chk *chunk.Chunk) error { + if !tr.optionalFinished { + err := tr.optionalResult.Next(ctx, chk) + if err != nil { + return err + } + if chk.NumRows() > 0 { + return nil + } + tr.optionalFinished = true + } + return tr.result.Next(ctx, chk) +} + +func (tr *tableResultHandler) nextRaw(ctx context.Context) (data []byte, err error) { + if !tr.optionalFinished { + data, err = tr.optionalResult.NextRaw(ctx) + if err != nil { + return nil, err + } + if data != nil { + return data, nil + } + tr.optionalFinished = true + } + data, err = tr.result.NextRaw(ctx) + if err != nil { + return nil, err + } + return data, nil +} + +func (tr *tableResultHandler) Close() error { + err := closeAll(tr.optionalResult, tr.result) + tr.optionalResult, tr.result = nil, nil + return err +} diff --git a/pkg/executor/unionexec/binding__failpoint_binding__.go b/pkg/executor/unionexec/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..80646b919b715 --- /dev/null +++ b/pkg/executor/unionexec/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package unionexec + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/executor/unionexec/union.go b/pkg/executor/unionexec/union.go index 0b6326dc0261a..d952c4a65418a 100644 --- a/pkg/executor/unionexec/union.go +++ b/pkg/executor/unionexec/union.go @@ -149,9 +149,9 @@ func (e *UnionExec) resultPuller(ctx context.Context, workerID int) { e.stopFetchData.Store(true) e.resultPool <- result } - failpoint.Inject("issue21441", func() { + if _, _err_ := failpoint.Eval(_curpkg_("issue21441")); _err_ == nil { atomic.AddInt32(&e.childInFlightForTest, 1) - }) + } for { if e.stopFetchData.Load().(bool) { return @@ -166,20 +166,20 @@ func (e *UnionExec) resultPuller(ctx context.Context, workerID int) { e.resourcePools[workerID] <- result.chk break } - failpoint.Inject("issue21441", func() { + if _, _err_ := failpoint.Eval(_curpkg_("issue21441")); _err_ == nil { if int(atomic.LoadInt32(&e.childInFlightForTest)) > e.Concurrency { panic("the count of child in flight is larger than e.concurrency unexpectedly") } - }) + } e.resultPool <- result if result.err != nil { e.stopFetchData.Store(true) return } } - failpoint.Inject("issue21441", func() { + if _, _err_ := failpoint.Eval(_curpkg_("issue21441")); _err_ == nil { atomic.AddInt32(&e.childInFlightForTest, -1) - }) + } } } diff --git a/pkg/executor/unionexec/union.go__failpoint_stash__ b/pkg/executor/unionexec/union.go__failpoint_stash__ new file mode 100644 index 0000000000000..0b6326dc0261a --- /dev/null +++ b/pkg/executor/unionexec/union.go__failpoint_stash__ @@ -0,0 +1,232 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package unionexec + +import ( + "context" + "sync" + "sync/atomic" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/executor/internal/exec" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/channel" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/syncutil" + "go.uber.org/zap" +) + +var ( + _ exec.Executor = &UnionExec{} +) + +// UnionExec pulls all it's children's result and returns to its parent directly. +// A "resultPuller" is started for every child to pull result from that child and push it to the "resultPool", the used +// "Chunk" is obtained from the corresponding "resourcePool". All resultPullers are running concurrently. +// +// +----------------+ +// +---> resourcePool 1 ---> | resultPuller 1 |-----+ +// | +----------------+ | +// | | +// | +----------------+ v +// +---> resourcePool 2 ---> | resultPuller 2 |-----> resultPool ---+ +// | +----------------+ ^ | +// | ...... | | +// | +----------------+ | | +// +---> resourcePool n ---> | resultPuller n |-----+ | +// | +----------------+ | +// | | +// | +-------------+ | +// |--------------------------| main thread | <---------------------+ +// +-------------+ +type UnionExec struct { + exec.BaseExecutor + Concurrency int + childIDChan chan int + + stopFetchData atomic.Value + + finished chan struct{} + resourcePools []chan *chunk.Chunk + resultPool chan *unionWorkerResult + + results []*chunk.Chunk + wg sync.WaitGroup + initialized bool + mu struct { + *syncutil.Mutex + maxOpenedChildID int + } + + childInFlightForTest int32 +} + +// unionWorkerResult stores the result for a union worker. +// A "resultPuller" is started for every child to pull result from that child, unionWorkerResult is used to store that pulled result. +// "src" is used for Chunk reuse: after pulling result from "resultPool", main-thread must push a valid unused Chunk to "src" to +// enable the corresponding "resultPuller" continue to work. +type unionWorkerResult struct { + chk *chunk.Chunk + err error + src chan<- *chunk.Chunk +} + +func (e *UnionExec) waitAllFinished() { + e.wg.Wait() + close(e.resultPool) +} + +// Open implements the Executor Open interface. 
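[Editor's note] The unionWorkerResult comment above describes a chunk-recycling handshake: a resultPuller can only produce its next chunk after the consumer pushes the previous one back through result.src. A standalone analogy of that handshake (illustration only; plain slices instead of chunk.Chunk, and the names result/resource are made up here):

package main

import "fmt"

type result struct {
	data []int
	src  chan<- []int // give the buffer back here to let the producer continue
}

func main() {
	resource := make(chan []int, 1)
	results := make(chan result, 1)
	resource <- make([]int, 0, 4) // seed the single reusable buffer

	go func() {
		for batch := 0; batch < 2; batch++ {
			buf := <-resource // blocks until the consumer returns the buffer
			buf = append(buf[:0], batch, batch+1)
			results <- result{data: buf, src: resource}
		}
		close(results)
	}()

	for r := range results {
		fmt.Println(r.data)
		r.src <- r.data // hand the buffer back, mirroring `result.src <- result.chk` in UnionExec.Next
	}
}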
+func (e *UnionExec) Open(context.Context) error { + e.stopFetchData.Store(false) + e.initialized = false + e.finished = make(chan struct{}) + e.mu.Mutex = &syncutil.Mutex{} + e.mu.maxOpenedChildID = -1 + return nil +} + +func (e *UnionExec) initialize(ctx context.Context) { + if e.Concurrency > e.ChildrenLen() { + e.Concurrency = e.ChildrenLen() + } + for i := 0; i < e.Concurrency; i++ { + e.results = append(e.results, exec.NewFirstChunk(e.Children(0))) + } + e.resultPool = make(chan *unionWorkerResult, e.Concurrency) + e.resourcePools = make([]chan *chunk.Chunk, e.Concurrency) + e.childIDChan = make(chan int, e.ChildrenLen()) + for i := 0; i < e.Concurrency; i++ { + e.resourcePools[i] = make(chan *chunk.Chunk, 1) + e.resourcePools[i] <- e.results[i] + e.wg.Add(1) + go e.resultPuller(ctx, i) + } + for i := 0; i < e.ChildrenLen(); i++ { + e.childIDChan <- i + } + close(e.childIDChan) + go e.waitAllFinished() +} + +func (e *UnionExec) resultPuller(ctx context.Context, workerID int) { + result := &unionWorkerResult{ + err: nil, + chk: nil, + src: e.resourcePools[workerID], + } + defer func() { + if r := recover(); r != nil { + logutil.Logger(ctx).Error("resultPuller panicked", zap.Any("recover", r), zap.Stack("stack")) + result.err = util.GetRecoverError(r) + e.resultPool <- result + e.stopFetchData.Store(true) + } + e.wg.Done() + }() + for childID := range e.childIDChan { + e.mu.Lock() + if childID > e.mu.maxOpenedChildID { + e.mu.maxOpenedChildID = childID + } + e.mu.Unlock() + if err := exec.Open(ctx, e.Children(childID)); err != nil { + result.err = err + e.stopFetchData.Store(true) + e.resultPool <- result + } + failpoint.Inject("issue21441", func() { + atomic.AddInt32(&e.childInFlightForTest, 1) + }) + for { + if e.stopFetchData.Load().(bool) { + return + } + select { + case <-e.finished: + return + case result.chk = <-e.resourcePools[workerID]: + } + result.err = exec.Next(ctx, e.Children(childID), result.chk) + if result.err == nil && result.chk.NumRows() == 0 { + e.resourcePools[workerID] <- result.chk + break + } + failpoint.Inject("issue21441", func() { + if int(atomic.LoadInt32(&e.childInFlightForTest)) > e.Concurrency { + panic("the count of child in flight is larger than e.concurrency unexpectedly") + } + }) + e.resultPool <- result + if result.err != nil { + e.stopFetchData.Store(true) + return + } + } + failpoint.Inject("issue21441", func() { + atomic.AddInt32(&e.childInFlightForTest, -1) + }) + } +} + +// Next implements the Executor Next interface. +func (e *UnionExec) Next(ctx context.Context, req *chunk.Chunk) error { + req.GrowAndReset(e.MaxChunkSize()) + if !e.initialized { + e.initialize(ctx) + e.initialized = true + } + result, ok := <-e.resultPool + if !ok { + return nil + } + if result.err != nil { + return errors.Trace(result.err) + } + + if result.chk.NumCols() != req.NumCols() { + return errors.Errorf("Internal error: UnionExec chunk column count mismatch, req: %d, result: %d", + req.NumCols(), result.chk.NumCols()) + } + req.SwapColumns(result.chk) + result.src <- result.chk + return nil +} + +// Close implements the Executor Close interface. +func (e *UnionExec) Close() error { + if e.finished != nil { + close(e.finished) + } + e.results = nil + if e.resultPool != nil { + channel.Clear(e.resultPool) + } + e.resourcePools = nil + if e.childIDChan != nil { + channel.Clear(e.childIDChan) + } + // We do not need to acquire the e.mu.Lock since all the resultPuller can be + // promised to exit when reaching here (e.childIDChan been closed). 
+ var firstErr error + for i := 0; i <= e.mu.maxOpenedChildID; i++ { + if err := exec.Close(e.Children(i)); err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr +} diff --git a/pkg/expression/aggregation/binding__failpoint_binding__.go b/pkg/expression/aggregation/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..df733f383834f --- /dev/null +++ b/pkg/expression/aggregation/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package aggregation + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/expression/aggregation/explain.go b/pkg/expression/aggregation/explain.go index d89fc08d88dc6..d63fd9d4e8d7d 100644 --- a/pkg/expression/aggregation/explain.go +++ b/pkg/expression/aggregation/explain.go @@ -27,11 +27,11 @@ import ( func ExplainAggFunc(ctx expression.EvalContext, agg *AggFuncDesc, normalized bool) string { var buffer bytes.Buffer showMode := false - failpoint.Inject("show-agg-mode", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("show-agg-mode")); _err_ == nil { if v.(bool) { showMode = true } - }) + } if showMode { fmt.Fprintf(&buffer, "%s(%s,", agg.Name, agg.Mode.ToString()) } else { diff --git a/pkg/expression/aggregation/explain.go__failpoint_stash__ b/pkg/expression/aggregation/explain.go__failpoint_stash__ new file mode 100644 index 0000000000000..d89fc08d88dc6 --- /dev/null +++ b/pkg/expression/aggregation/explain.go__failpoint_stash__ @@ -0,0 +1,80 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggregation + +import ( + "bytes" + "fmt" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/parser/ast" +) + +// ExplainAggFunc generates explain information for a aggregation function. 
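[Editor's note] The binding__failpoint_binding__.go files and the Inject-to-Eval rewrites in this patch appear consistent with the pingcap/failpoint code-generation workflow: failpoint-ctl enable expands each marker call in place, stashes the original file as *__failpoint_stash__, and _curpkg_ prefixes the failpoint name with the package path. For reference, the two shapes of the same failpoint as they appear in the explain.go hunk above (excerpts from this patch, not new code):

	// Marker form, preserved in the stash file:
	failpoint.Inject("show-agg-mode", func(v failpoint.Value) {
		if v.(bool) {
			showMode = true
		}
	})

	// Expanded form, checked in by this patch:
	if v, _err_ := failpoint.Eval(_curpkg_("show-agg-mode")); _err_ == nil {
		if v.(bool) {
			showMode = true
		}
	}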
+func ExplainAggFunc(ctx expression.EvalContext, agg *AggFuncDesc, normalized bool) string { + var buffer bytes.Buffer + showMode := false + failpoint.Inject("show-agg-mode", func(v failpoint.Value) { + if v.(bool) { + showMode = true + } + }) + if showMode { + fmt.Fprintf(&buffer, "%s(%s,", agg.Name, agg.Mode.ToString()) + } else { + fmt.Fprintf(&buffer, "%s(", agg.Name) + } + + if agg.HasDistinct { + buffer.WriteString("distinct ") + } + for i, arg := range agg.Args { + if agg.Name == ast.AggFuncGroupConcat && i == len(agg.Args)-1 { + if len(agg.OrderByItems) > 0 { + buffer.WriteString(" order by ") + for i, item := range agg.OrderByItems { + if item.Desc { + if normalized { + fmt.Fprintf(&buffer, "%s desc", item.Expr.ExplainNormalizedInfo()) + } else { + fmt.Fprintf(&buffer, "%s desc", item.Expr.ExplainInfo(ctx)) + } + } else { + if normalized { + fmt.Fprintf(&buffer, "%s", item.Expr.ExplainNormalizedInfo()) + } else { + fmt.Fprintf(&buffer, "%s", item.Expr.ExplainInfo(ctx)) + } + } + + if i+1 < len(agg.OrderByItems) { + buffer.WriteString(", ") + } + } + } + buffer.WriteString(" separator ") + } else if i != 0 { + buffer.WriteString(", ") + } + if normalized { + buffer.WriteString(arg.ExplainNormalizedInfo()) + } else { + buffer.WriteString(arg.ExplainInfo(ctx)) + } + } + buffer.WriteString(")") + return buffer.String() +} diff --git a/pkg/expression/binding__failpoint_binding__.go b/pkg/expression/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..7464c1d08001f --- /dev/null +++ b/pkg/expression/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package expression + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/expression/builtin_json.go b/pkg/expression/builtin_json.go index a6acf9844cfc5..6f62e0a8981ea 100644 --- a/pkg/expression/builtin_json.go +++ b/pkg/expression/builtin_json.go @@ -1881,9 +1881,9 @@ func (b *builtinJSONSchemaValidSig) evalInt(ctx EvalContext, row chunk.Row) (res if b.args[0].ConstLevel() >= ConstOnlyInContext { schema, err = b.schemaCache.getOrInitCache(ctx, func() (jsonschema.Schema, error) { - failpoint.Inject("jsonSchemaValidDisableCacheRefresh", func() { - failpoint.Return(jsonschema.Schema{}, errors.New("Cache refresh disabled by failpoint")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("jsonSchemaValidDisableCacheRefresh")); _err_ == nil { + return jsonschema.Schema{}, errors.New("Cache refresh disabled by failpoint") + } dataBin, err := schemaData.MarshalJSON() if err != nil { return jsonschema.Schema{}, err diff --git a/pkg/expression/builtin_json.go__failpoint_stash__ b/pkg/expression/builtin_json.go__failpoint_stash__ new file mode 100644 index 0000000000000..a6acf9844cfc5 --- /dev/null +++ b/pkg/expression/builtin_json.go__failpoint_stash__ @@ -0,0 +1,1940 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "bytes" + "context" + goJSON "encoding/json" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/hack" + "github.com/pingcap/tipb/go-tipb" + "github.com/qri-io/jsonschema" +) + +var ( + _ functionClass = &jsonTypeFunctionClass{} + _ functionClass = &jsonExtractFunctionClass{} + _ functionClass = &jsonUnquoteFunctionClass{} + _ functionClass = &jsonQuoteFunctionClass{} + _ functionClass = &jsonSetFunctionClass{} + _ functionClass = &jsonInsertFunctionClass{} + _ functionClass = &jsonReplaceFunctionClass{} + _ functionClass = &jsonRemoveFunctionClass{} + _ functionClass = &jsonMergeFunctionClass{} + _ functionClass = &jsonObjectFunctionClass{} + _ functionClass = &jsonArrayFunctionClass{} + _ functionClass = &jsonMemberOfFunctionClass{} + _ functionClass = &jsonContainsFunctionClass{} + _ functionClass = &jsonOverlapsFunctionClass{} + _ functionClass = &jsonContainsPathFunctionClass{} + _ functionClass = &jsonValidFunctionClass{} + _ functionClass = &jsonArrayAppendFunctionClass{} + _ functionClass = &jsonArrayInsertFunctionClass{} + _ functionClass = &jsonMergePatchFunctionClass{} + _ functionClass = &jsonMergePreserveFunctionClass{} + _ functionClass = &jsonPrettyFunctionClass{} + _ functionClass = &jsonQuoteFunctionClass{} + _ functionClass = &jsonSchemaValidFunctionClass{} + _ functionClass = &jsonSearchFunctionClass{} + _ functionClass = &jsonStorageSizeFunctionClass{} + _ functionClass = &jsonDepthFunctionClass{} + _ functionClass = &jsonKeysFunctionClass{} + _ functionClass = &jsonLengthFunctionClass{} + + _ builtinFunc = &builtinJSONTypeSig{} + _ builtinFunc = &builtinJSONQuoteSig{} + _ builtinFunc = &builtinJSONUnquoteSig{} + _ builtinFunc = &builtinJSONArraySig{} + _ builtinFunc = &builtinJSONArrayAppendSig{} + _ builtinFunc = &builtinJSONArrayInsertSig{} + _ builtinFunc = &builtinJSONObjectSig{} + _ builtinFunc = &builtinJSONExtractSig{} + _ builtinFunc = &builtinJSONSetSig{} + _ builtinFunc = &builtinJSONInsertSig{} + _ builtinFunc = &builtinJSONReplaceSig{} + _ builtinFunc = &builtinJSONRemoveSig{} + _ builtinFunc = &builtinJSONMergeSig{} + _ builtinFunc = &builtinJSONMemberOfSig{} + _ builtinFunc = &builtinJSONContainsSig{} + _ builtinFunc = &builtinJSONOverlapsSig{} + _ builtinFunc = &builtinJSONStorageSizeSig{} + _ builtinFunc = &builtinJSONDepthSig{} + _ builtinFunc = &builtinJSONSchemaValidSig{} + _ builtinFunc = &builtinJSONSearchSig{} + _ builtinFunc = &builtinJSONKeysSig{} + _ builtinFunc = &builtinJSONKeys2ArgsSig{} + _ builtinFunc = &builtinJSONLengthSig{} + _ builtinFunc = &builtinJSONValidJSONSig{} + _ builtinFunc = &builtinJSONValidStringSig{} + _ builtinFunc = &builtinJSONValidOthersSig{} +) + +type jsonTypeFunctionClass struct { + baseFunctionClass +} + +type builtinJSONTypeSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONTypeSig) 
Clone() builtinFunc { + newSig := &builtinJSONTypeSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonTypeFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETJson) + if err != nil { + return nil, err + } + charset, collate := ctx.GetCharsetInfo() + bf.tp.SetCharset(charset) + bf.tp.SetCollate(collate) + bf.tp.SetFlen(51) // flen of JSON_TYPE is length of UNSIGNED INTEGER. + bf.tp.AddFlag(mysql.BinaryFlag) + sig := &builtinJSONTypeSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonTypeSig) + return sig, nil +} + +func (b *builtinJSONTypeSig) evalString(ctx EvalContext, row chunk.Row) (val string, isNull bool, err error) { + var j types.BinaryJSON + j, isNull, err = b.args[0].EvalJSON(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + return j.Type(), false, nil +} + +type jsonExtractFunctionClass struct { + baseFunctionClass +} + +type builtinJSONExtractSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONExtractSig) Clone() builtinFunc { + newSig := &builtinJSONExtractSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonExtractFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { + if err := c.baseFunctionClass.verifyArgs(args); err != nil { + return err + } + if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { + return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_extract") + } + return nil +} + +func (c *jsonExtractFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { + return nil, err + } + argTps := make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETJson) + for range args[1:] { + argTps = append(argTps, types.ETString) + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) 
+ if err != nil { + return nil, err + } + sig := &builtinJSONExtractSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonExtractSig) + return sig, nil +} + +func (b *builtinJSONExtractSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { + res, isNull, err = b.args[0].EvalJSON(ctx, row) + if isNull || err != nil { + return + } + pathExprs := make([]types.JSONPathExpression, 0, len(b.args)-1) + for _, arg := range b.args[1:] { + var s string + s, isNull, err = arg.EvalString(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + pathExpr, err := types.ParseJSONPathExpr(s) + if err != nil { + return res, true, err + } + pathExprs = append(pathExprs, pathExpr) + } + var found bool + if res, found = res.Extract(pathExprs); !found { + return res, true, nil + } + return res, false, nil +} + +type jsonUnquoteFunctionClass struct { + baseFunctionClass +} + +type builtinJSONUnquoteSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONUnquoteSig) Clone() builtinFunc { + newSig := &builtinJSONUnquoteSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonUnquoteFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { + if err := c.baseFunctionClass.verifyArgs(args); err != nil { + return err + } + if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { + return ErrIncorrectType.GenWithStackByArgs("1", "json_unquote") + } + return nil +} + +func (c *jsonUnquoteFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETString) + if err != nil { + return nil, err + } + bf.tp.SetFlen(args[0].GetType(ctx.GetEvalCtx()).GetFlen()) + bf.tp.AddFlag(mysql.BinaryFlag) + DisableParseJSONFlag4Expr(ctx.GetEvalCtx(), args[0]) + sig := &builtinJSONUnquoteSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonUnquoteSig) + return sig, nil +} + +func (b *builtinJSONUnquoteSig) evalString(ctx EvalContext, row chunk.Row) (str string, isNull bool, err error) { + str, isNull, err = b.args[0].EvalString(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + if len(str) >= 2 && str[0] == '"' && str[len(str)-1] == '"' && !goJSON.Valid([]byte(str)) { + return "", false, types.ErrInvalidJSONText.GenWithStackByArgs("The document root must not be followed by other values.") + } + str, err = types.UnquoteString(str) + if err != nil { + return "", false, err + } + return str, false, nil +} + +type jsonSetFunctionClass struct { + baseFunctionClass +} + +type builtinJSONSetSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONSetSig) Clone() builtinFunc { + newSig := &builtinJSONSetSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonSetFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + if len(args)&1 != 1 { + return nil, ErrIncorrectParameterCount.GenWithStackByArgs(c.funcName) + } + argTps := make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETJson) + for i := 1; i < len(args)-1; i += 2 { + argTps = append(argTps, types.ETString, types.ETJson) + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) 
+ if err != nil { + return nil, err + } + for i := 2; i < len(args); i += 2 { + DisableParseJSONFlag4Expr(ctx.GetEvalCtx(), args[i]) + } + sig := &builtinJSONSetSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonSetSig) + return sig, nil +} + +func (b *builtinJSONSetSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { + res, isNull, err = jsonModify(ctx, b.args, row, types.JSONModifySet) + return res, isNull, err +} + +type jsonInsertFunctionClass struct { + baseFunctionClass +} + +type builtinJSONInsertSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONInsertSig) Clone() builtinFunc { + newSig := &builtinJSONInsertSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonInsertFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + if len(args)&1 != 1 { + return nil, ErrIncorrectParameterCount.GenWithStackByArgs(c.funcName) + } + argTps := make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETJson) + for i := 1; i < len(args)-1; i += 2 { + argTps = append(argTps, types.ETString, types.ETJson) + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) + if err != nil { + return nil, err + } + for i := 2; i < len(args); i += 2 { + DisableParseJSONFlag4Expr(ctx.GetEvalCtx(), args[i]) + } + sig := &builtinJSONInsertSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonInsertSig) + return sig, nil +} + +func (b *builtinJSONInsertSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { + res, isNull, err = jsonModify(ctx, b.args, row, types.JSONModifyInsert) + return res, isNull, err +} + +type jsonReplaceFunctionClass struct { + baseFunctionClass +} + +type builtinJSONReplaceSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONReplaceSig) Clone() builtinFunc { + newSig := &builtinJSONReplaceSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonReplaceFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + if len(args)&1 != 1 { + return nil, ErrIncorrectParameterCount.GenWithStackByArgs(c.funcName) + } + argTps := make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETJson) + for i := 1; i < len(args)-1; i += 2 { + argTps = append(argTps, types.ETString, types.ETJson) + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) 
+ if err != nil { + return nil, err + } + for i := 2; i < len(args); i += 2 { + DisableParseJSONFlag4Expr(ctx.GetEvalCtx(), args[i]) + } + sig := &builtinJSONReplaceSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonReplaceSig) + return sig, nil +} + +func (b *builtinJSONReplaceSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { + res, isNull, err = jsonModify(ctx, b.args, row, types.JSONModifyReplace) + return res, isNull, err +} + +type jsonRemoveFunctionClass struct { + baseFunctionClass +} + +type builtinJSONRemoveSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONRemoveSig) Clone() builtinFunc { + newSig := &builtinJSONRemoveSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonRemoveFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + argTps := make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETJson) + for range args[1:] { + argTps = append(argTps, types.ETString) + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) + if err != nil { + return nil, err + } + sig := &builtinJSONRemoveSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonRemoveSig) + return sig, nil +} + +func (b *builtinJSONRemoveSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { + res, isNull, err = b.args[0].EvalJSON(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + pathExprs := make([]types.JSONPathExpression, 0, len(b.args)-1) + for _, arg := range b.args[1:] { + var s string + s, isNull, err = arg.EvalString(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + var pathExpr types.JSONPathExpression + pathExpr, err = types.ParseJSONPathExpr(s) + if err != nil { + return res, true, err + } + pathExprs = append(pathExprs, pathExpr) + } + res, err = res.Remove(pathExprs) + if err != nil { + return res, true, err + } + return res, false, nil +} + +type jsonMergeFunctionClass struct { + baseFunctionClass +} + +func (c *jsonMergeFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { + if err := c.baseFunctionClass.verifyArgs(args); err != nil { + return err + } + for i, arg := range args { + if evalType := arg.GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { + return ErrInvalidTypeForJSON.GenWithStackByArgs(i+1, "json_merge") + } + } + return nil +} + +type builtinJSONMergeSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONMergeSig) Clone() builtinFunc { + newSig := &builtinJSONMergeSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonMergeFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { + return nil, err + } + argTps := make([]types.EvalType, 0, len(args)) + for range args { + argTps = append(argTps, types.ETJson) + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) 
+ if err != nil { + return nil, err + } + sig := &builtinJSONMergeSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonMergeSig) + return sig, nil +} + +func (b *builtinJSONMergeSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { + values := make([]types.BinaryJSON, 0, len(b.args)) + for _, arg := range b.args { + var value types.BinaryJSON + value, isNull, err = arg.EvalJSON(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + values = append(values, value) + } + res = types.MergeBinaryJSON(values) + // function "JSON_MERGE" is deprecated since MySQL 5.7.22. Synonym for function "JSON_MERGE_PRESERVE". + // See https://dev.mysql.com/doc/refman/5.7/en/json-modification-functions.html#function_json-merge + if b.pbCode == tipb.ScalarFuncSig_JsonMergeSig { + tc := typeCtx(ctx) + tc.AppendWarning(errDeprecatedSyntaxNoReplacement.FastGenByArgs("JSON_MERGE", "")) + } + return res, false, nil +} + +type jsonObjectFunctionClass struct { + baseFunctionClass +} + +type builtinJSONObjectSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONObjectSig) Clone() builtinFunc { + newSig := &builtinJSONObjectSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonObjectFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + if len(args)&1 != 0 { + return nil, ErrIncorrectParameterCount.GenWithStackByArgs(c.funcName) + } + argTps := make([]types.EvalType, 0, len(args)) + for i := 0; i < len(args)-1; i += 2 { + if args[i].GetType(ctx.GetEvalCtx()).EvalType() == types.ETString && args[i].GetType(ctx.GetEvalCtx()).GetCharset() == charset.CharsetBin { + return nil, types.ErrInvalidJSONCharset.GenWithStackByArgs(args[i].GetType(ctx.GetEvalCtx()).GetCharset()) + } + argTps = append(argTps, types.ETString, types.ETJson) + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) 
+ if err != nil { + return nil, err + } + for i := 1; i < len(args); i += 2 { + DisableParseJSONFlag4Expr(ctx.GetEvalCtx(), args[i]) + } + sig := &builtinJSONObjectSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonObjectSig) + return sig, nil +} + +func (b *builtinJSONObjectSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { + if len(b.args)&1 == 1 { + err = ErrIncorrectParameterCount.GenWithStackByArgs(ast.JSONObject) + return res, true, err + } + jsons := make(map[string]any, len(b.args)>>1) + var key string + var value types.BinaryJSON + for i, arg := range b.args { + if i&1 == 0 { + key, isNull, err = arg.EvalString(ctx, row) + if err != nil { + return res, true, err + } + if isNull { + return res, true, types.ErrJSONDocumentNULLKey + } + } else { + value, isNull, err = arg.EvalJSON(ctx, row) + if err != nil { + return res, true, err + } + if isNull { + value = types.CreateBinaryJSON(nil) + } + jsons[key] = value + } + } + bj, err := types.CreateBinaryJSONWithCheck(jsons) + if err != nil { + return res, true, err + } + return bj, false, nil +} + +type jsonArrayFunctionClass struct { + baseFunctionClass +} + +type builtinJSONArraySig struct { + baseBuiltinFunc +} + +func (b *builtinJSONArraySig) Clone() builtinFunc { + newSig := &builtinJSONArraySig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonArrayFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + argTps := make([]types.EvalType, 0, len(args)) + for range args { + argTps = append(argTps, types.ETJson) + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) + if err != nil { + return nil, err + } + for i := range args { + DisableParseJSONFlag4Expr(ctx.GetEvalCtx(), args[i]) + } + sig := &builtinJSONArraySig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonArraySig) + return sig, nil +} + +func (b *builtinJSONArraySig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { + jsons := make([]any, 0, len(b.args)) + for _, arg := range b.args { + j, isNull, err := arg.EvalJSON(ctx, row) + if err != nil { + return res, true, err + } + if isNull { + j = types.CreateBinaryJSON(nil) + } + jsons = append(jsons, j) + } + bj, err := types.CreateBinaryJSONWithCheck(jsons) + if err != nil { + return res, true, err + } + return bj, false, nil +} + +type jsonContainsPathFunctionClass struct { + baseFunctionClass +} + +type builtinJSONContainsPathSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONContainsPathSig) Clone() builtinFunc { + newSig := &builtinJSONContainsPathSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonContainsPathFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { + if err := c.baseFunctionClass.verifyArgs(args); err != nil { + return err + } + if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { + return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_contains_path") + } + return nil +} + +func (c *jsonContainsPathFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { + return nil, err + } + argTps := []types.EvalType{types.ETJson, types.ETString} + for i := 3; i <= len(args); i++ { + argTps = append(argTps, types.ETString) + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, 
argTps...) + if err != nil { + return nil, err + } + sig := &builtinJSONContainsPathSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonContainsPathSig) + return sig, nil +} + +func (b *builtinJSONContainsPathSig) evalInt(ctx EvalContext, row chunk.Row) (res int64, isNull bool, err error) { + obj, isNull, err := b.args[0].EvalJSON(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + containType, isNull, err := b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + containType = strings.ToLower(containType) + if containType != types.JSONContainsPathAll && containType != types.JSONContainsPathOne { + return res, true, types.ErrJSONBadOneOrAllArg.GenWithStackByArgs("json_contains_path") + } + var pathExpr types.JSONPathExpression + contains := int64(1) + for i := 2; i < len(b.args); i++ { + path, isNull, err := b.args[i].EvalString(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + if pathExpr, err = types.ParseJSONPathExpr(path); err != nil { + return res, true, err + } + _, exists := obj.Extract([]types.JSONPathExpression{pathExpr}) + switch { + case exists && containType == types.JSONContainsPathOne: + return 1, false, nil + case !exists && containType == types.JSONContainsPathOne: + contains = 0 + case !exists && containType == types.JSONContainsPathAll: + return 0, false, nil + } + } + return contains, false, nil +} + +func jsonModify(ctx EvalContext, args []Expression, row chunk.Row, mt types.JSONModifyType) (res types.BinaryJSON, isNull bool, err error) { + res, isNull, err = args[0].EvalJSON(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + pathExprs := make([]types.JSONPathExpression, 0, (len(args)-1)/2+1) + for i := 1; i < len(args); i += 2 { + // TODO: We can cache pathExprs if args are constants. + var s string + s, isNull, err = args[i].EvalString(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + var pathExpr types.JSONPathExpression + pathExpr, err = types.ParseJSONPathExpr(s) + if err != nil { + return res, true, err + } + pathExprs = append(pathExprs, pathExpr) + } + values := make([]types.BinaryJSON, 0, (len(args)-1)/2+1) + for i := 2; i < len(args); i += 2 { + var value types.BinaryJSON + value, isNull, err = args[i].EvalJSON(ctx, row) + if err != nil { + return res, true, err + } + if isNull { + value = types.CreateBinaryJSON(nil) + } + values = append(values, value) + } + res, err = res.Modify(pathExprs, values, mt) + if err != nil { + return res, true, err + } + return res, false, nil +} + +type jsonMemberOfFunctionClass struct { + baseFunctionClass +} + +type builtinJSONMemberOfSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONMemberOfSig) Clone() builtinFunc { + newSig := &builtinJSONMemberOfSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonMemberOfFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { + if err := c.baseFunctionClass.verifyArgs(args); err != nil { + return err + } + if evalType := args[1].GetType(ctx).EvalType(); evalType != types.ETJson && evalType != types.ETString { + return ErrInvalidTypeForJSON.GenWithStackByArgs(2, "member of") + } + return nil +} + +func (c *jsonMemberOfFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { + return nil, err + } + argTps := []types.EvalType{types.ETJson, types.ETJson} + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, argTps...) 
+ if err != nil { + return nil, err + } + DisableParseJSONFlag4Expr(ctx.GetEvalCtx(), args[0]) + sig := &builtinJSONMemberOfSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonMemberOfSig) + return sig, nil +} + +func (b *builtinJSONMemberOfSig) evalInt(ctx EvalContext, row chunk.Row) (res int64, isNull bool, err error) { + target, isNull, err := b.args[0].EvalJSON(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + obj, isNull, err := b.args[1].EvalJSON(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + if obj.TypeCode != types.JSONTypeCodeArray { + return boolToInt64(types.CompareBinaryJSON(obj, target) == 0), false, nil + } + + elemCount := obj.GetElemCount() + for i := 0; i < elemCount; i++ { + if types.CompareBinaryJSON(obj.ArrayGetElem(i), target) == 0 { + return 1, false, nil + } + } + + return 0, false, nil +} + +type jsonContainsFunctionClass struct { + baseFunctionClass +} + +type builtinJSONContainsSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONContainsSig) Clone() builtinFunc { + newSig := &builtinJSONContainsSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonContainsFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { + if err := c.baseFunctionClass.verifyArgs(args); err != nil { + return err + } + if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETJson && evalType != types.ETString { + return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_contains") + } + if evalType := args[1].GetType(ctx).EvalType(); evalType != types.ETJson && evalType != types.ETString { + return ErrInvalidTypeForJSON.GenWithStackByArgs(2, "json_contains") + } + return nil +} + +func (c *jsonContainsFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { + return nil, err + } + + argTps := []types.EvalType{types.ETJson, types.ETJson} + if len(args) == 3 { + argTps = append(argTps, types.ETString) + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, argTps...) 
+ if err != nil { + return nil, err + } + sig := &builtinJSONContainsSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonContainsSig) + return sig, nil +} + +func (b *builtinJSONContainsSig) evalInt(ctx EvalContext, row chunk.Row) (res int64, isNull bool, err error) { + obj, isNull, err := b.args[0].EvalJSON(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + target, isNull, err := b.args[1].EvalJSON(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + var pathExpr types.JSONPathExpression + if len(b.args) == 3 { + path, isNull, err := b.args[2].EvalString(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + pathExpr, err = types.ParseJSONPathExpr(path) + if err != nil { + return res, true, err + } + if pathExpr.CouldMatchMultipleValues() { + return res, true, types.ErrInvalidJSONPathMultipleSelection + } + var exists bool + obj, exists = obj.Extract([]types.JSONPathExpression{pathExpr}) + if !exists { + return res, true, nil + } + } + + if types.ContainsBinaryJSON(obj, target) { + return 1, false, nil + } + return 0, false, nil +} + +type jsonOverlapsFunctionClass struct { + baseFunctionClass +} + +type builtinJSONOverlapsSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONOverlapsSig) Clone() builtinFunc { + newSig := &builtinJSONOverlapsSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonOverlapsFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { + if err := c.baseFunctionClass.verifyArgs(args); err != nil { + return err + } + if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETJson && evalType != types.ETString { + return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_overlaps") + } + if evalType := args[1].GetType(ctx).EvalType(); evalType != types.ETJson && evalType != types.ETString { + return ErrInvalidTypeForJSON.GenWithStackByArgs(2, "json_overlaps") + } + return nil +} + +func (c *jsonOverlapsFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { + return nil, err + } + + argTps := []types.EvalType{types.ETJson, types.ETJson} + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, argTps...) 
+ if err != nil { + return nil, err + } + sig := &builtinJSONOverlapsSig{bf} + return sig, nil +} + +func (b *builtinJSONOverlapsSig) evalInt(ctx EvalContext, row chunk.Row) (res int64, isNull bool, err error) { + obj, isNull, err := b.args[0].EvalJSON(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + target, isNull, err := b.args[1].EvalJSON(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + if types.OverlapsBinaryJSON(obj, target) { + return 1, false, nil + } + return 0, false, nil +} + +type jsonValidFunctionClass struct { + baseFunctionClass +} + +func (c *jsonValidFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + + var sig builtinFunc + argType := args[0].GetType(ctx.GetEvalCtx()).EvalType() + switch argType { + case types.ETJson: + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETJson) + if err != nil { + return nil, err + } + sig = &builtinJSONValidJSONSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonValidJsonSig) + case types.ETString: + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETString) + if err != nil { + return nil, err + } + sig = &builtinJSONValidStringSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonValidStringSig) + default: + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, argType) + if err != nil { + return nil, err + } + sig = &builtinJSONValidOthersSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonValidOthersSig) + } + return sig, nil +} + +type builtinJSONValidJSONSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONValidJSONSig) Clone() builtinFunc { + newSig := &builtinJSONValidJSONSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals a builtinJSONValidJSONSig. +// See https://dev.mysql.com/doc/refman/5.7/en/json-attribute-functions.html#function_json-valid +func (b *builtinJSONValidJSONSig) evalInt(ctx EvalContext, row chunk.Row) (val int64, isNull bool, err error) { + _, isNull, err = b.args[0].EvalJSON(ctx, row) + return 1, isNull, err +} + +type builtinJSONValidStringSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONValidStringSig) Clone() builtinFunc { + newSig := &builtinJSONValidStringSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals a builtinJSONValidStringSig. +// See https://dev.mysql.com/doc/refman/5.7/en/json-attribute-functions.html#function_json-valid +func (b *builtinJSONValidStringSig) evalInt(ctx EvalContext, row chunk.Row) (res int64, isNull bool, err error) { + val, isNull, err := b.args[0].EvalString(ctx, row) + if err != nil || isNull { + return 0, isNull, err + } + + data := hack.Slice(val) + if goJSON.Valid(data) { + res = 1 + } + return res, false, nil +} + +type builtinJSONValidOthersSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONValidOthersSig) Clone() builtinFunc { + newSig := &builtinJSONValidOthersSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals a builtinJSONValidOthersSig. 
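[Editor's note] builtinJSONValidStringSig above delegates the string case to encoding/json's Valid (via the goJSON alias), while the JSON-typed and "others" signatures return 1 and 0 unconditionally. A standalone check of the string path (illustration only):

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	fmt.Println(json.Valid([]byte(`{"a": 1}`))) // true  -> JSON_VALID returns 1
	fmt.Println(json.Valid([]byte(`hello`)))    // false -> JSON_VALID returns 0
}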
+// See https://dev.mysql.com/doc/refman/5.7/en/json-attribute-functions.html#function_json-valid +func (b *builtinJSONValidOthersSig) evalInt(ctx EvalContext, row chunk.Row) (val int64, isNull bool, err error) { + return 0, false, nil +} + +type jsonArrayAppendFunctionClass struct { + baseFunctionClass +} + +type builtinJSONArrayAppendSig struct { + baseBuiltinFunc +} + +func (c *jsonArrayAppendFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { + if len(args) < 3 || (len(args)&1 != 1) { + return ErrIncorrectParameterCount.GenWithStackByArgs(c.funcName) + } + return nil +} + +func (c *jsonArrayAppendFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { + return nil, err + } + argTps := make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETJson) + for i := 1; i < len(args)-1; i += 2 { + argTps = append(argTps, types.ETString, types.ETJson) + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) + if err != nil { + return nil, err + } + for i := 2; i < len(args); i += 2 { + DisableParseJSONFlag4Expr(ctx.GetEvalCtx(), args[i]) + } + sig := &builtinJSONArrayAppendSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonArrayAppendSig) + return sig, nil +} + +func (b *builtinJSONArrayAppendSig) Clone() builtinFunc { + newSig := &builtinJSONArrayAppendSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinJSONArrayAppendSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { + res, isNull, err = b.args[0].EvalJSON(ctx, row) + if err != nil || isNull { + return res, true, err + } + + for i := 1; i < len(b.args)-1; i += 2 { + // If JSON path is NULL, MySQL breaks and returns NULL. + s, sNull, err := b.args[i].EvalString(ctx, row) + if sNull || err != nil { + return res, true, err + } + value, vNull, err := b.args[i+1].EvalJSON(ctx, row) + if err != nil { + return res, true, err + } + if vNull { + value = types.CreateBinaryJSON(nil) + } + res, isNull, err = b.appendJSONArray(res, s, value) + if isNull || err != nil { + return res, isNull, err + } + } + return res, false, nil +} + +func (b *builtinJSONArrayAppendSig) appendJSONArray(res types.BinaryJSON, p string, v types.BinaryJSON) (types.BinaryJSON, bool, error) { + // We should do the following checks to get correct values in res.Extract + pathExpr, err := types.ParseJSONPathExpr(p) + if err != nil { + return res, true, err + } + if pathExpr.CouldMatchMultipleValues() { + return res, true, types.ErrInvalidJSONPathMultipleSelection + } + + obj, exists := res.Extract([]types.JSONPathExpression{pathExpr}) + if !exists { + // If path not exists, just do nothing and no errors. + return res, false, nil + } + + if obj.TypeCode != types.JSONTypeCodeArray { + // res.Extract will return a json object instead of an array if there is an object at path pathExpr. + // JSON_ARRAY_APPEND({"a": "b"}, "$", {"b": "c"}) => [{"a": "b"}, {"b", "c"}] + // We should wrap them to a single array first. 
+ obj, err = types.CreateBinaryJSONWithCheck([]any{obj}) + if err != nil { + return res, true, err + } + } + + obj = types.MergeBinaryJSON([]types.BinaryJSON{obj, v}) + res, err = res.Modify([]types.JSONPathExpression{pathExpr}, []types.BinaryJSON{obj}, types.JSONModifySet) + return res, false, err +} + +type jsonArrayInsertFunctionClass struct { + baseFunctionClass +} + +type builtinJSONArrayInsertSig struct { + baseBuiltinFunc +} + +func (c *jsonArrayInsertFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + if len(args)&1 != 1 { + return nil, ErrIncorrectParameterCount.GenWithStackByArgs(c.funcName) + } + + argTps := make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETJson) + for i := 1; i < len(args)-1; i += 2 { + argTps = append(argTps, types.ETString, types.ETJson) + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) + if err != nil { + return nil, err + } + for i := 2; i < len(args); i += 2 { + DisableParseJSONFlag4Expr(ctx.GetEvalCtx(), args[i]) + } + sig := &builtinJSONArrayInsertSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonArrayInsertSig) + return sig, nil +} + +func (b *builtinJSONArrayInsertSig) Clone() builtinFunc { + newSig := &builtinJSONArrayInsertSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinJSONArrayInsertSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { + res, isNull, err = b.args[0].EvalJSON(ctx, row) + if err != nil || isNull { + return res, true, err + } + + for i := 1; i < len(b.args)-1; i += 2 { + // If JSON path is NULL, MySQL breaks and returns NULL. + s, isNull, err := b.args[i].EvalString(ctx, row) + if err != nil || isNull { + return res, true, err + } + + pathExpr, err := types.ParseJSONPathExpr(s) + if err != nil { + return res, true, err + } + if pathExpr.CouldMatchMultipleValues() { + return res, true, types.ErrInvalidJSONPathMultipleSelection + } + + value, isnull, err := b.args[i+1].EvalJSON(ctx, row) + if err != nil { + return res, true, err + } + + if isnull { + value = types.CreateBinaryJSON(nil) + } + + res, err = res.ArrayInsert(pathExpr, value) + if err != nil { + return res, true, err + } + } + return res, false, nil +} + +type jsonMergePatchFunctionClass struct { + baseFunctionClass +} + +func (c *jsonMergePatchFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { + if err := c.baseFunctionClass.verifyArgs(args); err != nil { + return err + } + for i, arg := range args { + if evalType := arg.GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { + return ErrInvalidTypeForJSON.GenWithStackByArgs(i+1, "json_merge_patch") + } + } + return nil +} + +func (c *jsonMergePatchFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { + return nil, err + } + argTps := make([]types.EvalType, 0, len(args)) + for range args { + argTps = append(argTps, types.ETJson) + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) 
+ if err != nil { + return nil, err + } + sig := &builtinJSONMergePatchSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonMergePatchSig) + return sig, nil +} + +type builtinJSONMergePatchSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONMergePatchSig) Clone() builtinFunc { + newSig := &builtinJSONMergePatchSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinJSONMergePatchSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { + values := make([]*types.BinaryJSON, 0, len(b.args)) + for _, arg := range b.args { + var value types.BinaryJSON + value, isNull, err = arg.EvalJSON(ctx, row) + if err != nil { + return + } + if isNull { + values = append(values, nil) + } else { + values = append(values, &value) + } + } + tmpRes, err := types.MergePatchBinaryJSON(values) + if err != nil { + return + } + if tmpRes != nil { + res = *tmpRes + } else { + isNull = true + } + return res, isNull, nil +} + +type jsonMergePreserveFunctionClass struct { + baseFunctionClass +} + +func (c *jsonMergePreserveFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { + if err := c.baseFunctionClass.verifyArgs(args); err != nil { + return err + } + for i, arg := range args { + if evalType := arg.GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { + return ErrInvalidTypeForJSON.GenWithStackByArgs(i+1, "json_merge_preserve") + } + } + return nil +} + +func (c *jsonMergePreserveFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { + return nil, err + } + argTps := make([]types.EvalType, 0, len(args)) + for range args { + argTps = append(argTps, types.ETJson) + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) 
+ if err != nil { + return nil, err + } + sig := &builtinJSONMergeSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonMergePreserveSig) + return sig, nil +} + +type jsonPrettyFunctionClass struct { + baseFunctionClass +} + +type builtinJSONSPrettySig struct { + baseBuiltinFunc +} + +func (b *builtinJSONSPrettySig) Clone() builtinFunc { + newSig := &builtinJSONSPrettySig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonPrettyFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETJson) + if err != nil { + return nil, err + } + bf.tp.AddFlag(mysql.BinaryFlag) + bf.tp.SetFlen(mysql.MaxBlobWidth * 4) + sig := &builtinJSONSPrettySig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonPrettySig) + return sig, nil +} + +func (b *builtinJSONSPrettySig) evalString(ctx EvalContext, row chunk.Row) (res string, isNull bool, err error) { + obj, isNull, err := b.args[0].EvalJSON(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + buf, err := obj.MarshalJSON() + if err != nil { + return res, isNull, err + } + var resBuf bytes.Buffer + if err = goJSON.Indent(&resBuf, buf, "", " "); err != nil { + return res, isNull, err + } + return resBuf.String(), false, nil +} + +type jsonQuoteFunctionClass struct { + baseFunctionClass +} + +type builtinJSONQuoteSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONQuoteSig) Clone() builtinFunc { + newSig := &builtinJSONQuoteSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonQuoteFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { + if err := c.baseFunctionClass.verifyArgs(args); err != nil { + return err + } + if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString { + return ErrIncorrectType.GenWithStackByArgs("1", "json_quote") + } + return nil +} + +func (c *jsonQuoteFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETString) + if err != nil { + return nil, err + } + DisableParseJSONFlag4Expr(ctx.GetEvalCtx(), args[0]) + bf.tp.AddFlag(mysql.BinaryFlag) + bf.tp.SetFlen(args[0].GetType(ctx.GetEvalCtx()).GetFlen()*6 + 2) + sig := &builtinJSONQuoteSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonQuoteSig) + return sig, nil +} + +func (b *builtinJSONQuoteSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { + str, isNull, err := b.args[0].EvalString(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + buffer := &bytes.Buffer{} + encoder := goJSON.NewEncoder(buffer) + encoder.SetEscapeHTML(false) + err = encoder.Encode(str) + if err != nil { + return "", isNull, err + } + return string(bytes.TrimSuffix(buffer.Bytes(), []byte("\n"))), false, nil +} + +type jsonSearchFunctionClass struct { + baseFunctionClass +} + +type builtinJSONSearchSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONSearchSig) Clone() builtinFunc { + newSig := &builtinJSONSearchSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonSearchFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { + if err := c.baseFunctionClass.verifyArgs(args); err != nil { + return err + } + if evalType := args[0].GetType(ctx).EvalType(); evalType != 
types.ETString && evalType != types.ETJson { + return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_search") + } + return nil +} + +func (c *jsonSearchFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { + return nil, err + } + // json_doc, one_or_all, search_str[, escape_char[, path] ...]) + argTps := make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETJson) + for range args[1:] { + argTps = append(argTps, types.ETString) + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) + if err != nil { + return nil, err + } + sig := &builtinJSONSearchSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonSearchSig) + return sig, nil +} + +func (b *builtinJSONSearchSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { + // json_doc + var obj types.BinaryJSON + obj, isNull, err = b.args[0].EvalJSON(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + // one_or_all + var containType string + containType, isNull, err = b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + containType = strings.ToLower(containType) + if containType != types.JSONContainsPathAll && containType != types.JSONContainsPathOne { + return res, true, errors.AddStack(types.ErrInvalidJSONContainsPathType) + } + + // search_str & escape_char + var searchStr string + searchStr, isNull, err = b.args[2].EvalString(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + escape := byte('\\') + if len(b.args) >= 4 { + var escapeStr string + escapeStr, isNull, err = b.args[3].EvalString(ctx, row) + if err != nil { + return res, isNull, err + } + if isNull || len(escapeStr) == 0 { + escape = byte('\\') + } else if len(escapeStr) == 1 { + escape = escapeStr[0] + } else { + return res, true, errIncorrectArgs.GenWithStackByArgs("ESCAPE") + } + } + if len(b.args) >= 5 { // path... 
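+ // Each remaining argument is a JSON path that restricts where we search,
+ // e.g. JSON_SEARCH('{"a": "x", "b": "x"}', 'all', 'x', NULL, '$.a') only inspects the value at $.a.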
+ pathExprs := make([]types.JSONPathExpression, 0, len(b.args)-4) + for i := 4; i < len(b.args); i++ { + var s string + s, isNull, err = b.args[i].EvalString(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + var pathExpr types.JSONPathExpression + pathExpr, err = types.ParseJSONPathExpr(s) + if err != nil { + return res, true, err + } + pathExprs = append(pathExprs, pathExpr) + } + return obj.Search(containType, searchStr, escape, pathExprs) + } + return obj.Search(containType, searchStr, escape, nil) +} + +type jsonStorageFreeFunctionClass struct { + baseFunctionClass +} + +type builtinJSONStorageFreeSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONStorageFreeSig) Clone() builtinFunc { + newSig := &builtinJSONStorageFreeSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonStorageFreeFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETJson) + if err != nil { + return nil, err + } + sig := &builtinJSONStorageFreeSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonStorageFreeSig) + return sig, nil +} + +func (b *builtinJSONStorageFreeSig) evalInt(ctx EvalContext, row chunk.Row) (res int64, isNull bool, err error) { + _, isNull, err = b.args[0].EvalJSON(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + return 0, false, nil +} + +type jsonStorageSizeFunctionClass struct { + baseFunctionClass +} + +type builtinJSONStorageSizeSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONStorageSizeSig) Clone() builtinFunc { + newSig := &builtinJSONStorageSizeSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonStorageSizeFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETJson) + if err != nil { + return nil, err + } + sig := &builtinJSONStorageSizeSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonStorageSizeSig) + return sig, nil +} + +func (b *builtinJSONStorageSizeSig) evalInt(ctx EvalContext, row chunk.Row) (res int64, isNull bool, err error) { + obj, isNull, err := b.args[0].EvalJSON(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + // returns the length of obj value plus 1 (the TypeCode) + return int64(len(obj.Value)) + 1, false, nil +} + +type jsonDepthFunctionClass struct { + baseFunctionClass +} + +type builtinJSONDepthSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONDepthSig) Clone() builtinFunc { + newSig := &builtinJSONDepthSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonDepthFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETJson) + if err != nil { + return nil, err + } + sig := &builtinJSONDepthSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonDepthSig) + return sig, nil +} + +func (b *builtinJSONDepthSig) evalInt(ctx EvalContext, row chunk.Row) (res int64, isNull bool, err error) { + // as TiDB doesn't support partial update json value, so only check the + // json format and whether it's NULL. 
For NULL return NULL, for invalid json, return + // an error, otherwise return 0 + + obj, isNull, err := b.args[0].EvalJSON(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + return int64(obj.GetElemDepth()), false, nil +} + +type jsonKeysFunctionClass struct { + baseFunctionClass +} + +func (c *jsonKeysFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { + if err := c.baseFunctionClass.verifyArgs(args); err != nil { + return err + } + if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { + return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_keys") + } + return nil +} + +func (c *jsonKeysFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { + return nil, err + } + argTps := []types.EvalType{types.ETJson} + if len(args) == 2 { + argTps = append(argTps, types.ETString) + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) + if err != nil { + return nil, err + } + var sig builtinFunc + switch len(args) { + case 1: + sig = &builtinJSONKeysSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonKeysSig) + case 2: + sig = &builtinJSONKeys2ArgsSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonKeys2ArgsSig) + } + return sig, nil +} + +type builtinJSONKeysSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONKeysSig) Clone() builtinFunc { + newSig := &builtinJSONKeysSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinJSONKeysSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { + res, isNull, err = b.args[0].EvalJSON(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + if res.TypeCode != types.JSONTypeCodeObject { + return res, true, nil + } + return res.GetKeys(), false, nil +} + +type builtinJSONKeys2ArgsSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONKeys2ArgsSig) Clone() builtinFunc { + newSig := &builtinJSONKeys2ArgsSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinJSONKeys2ArgsSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { + res, isNull, err = b.args[0].EvalJSON(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + path, isNull, err := b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + pathExpr, err := types.ParseJSONPathExpr(path) + if err != nil { + return res, true, err + } + if pathExpr.CouldMatchMultipleValues() { + return res, true, types.ErrInvalidJSONPathMultipleSelection + } + + res, exists := res.Extract([]types.JSONPathExpression{pathExpr}) + if !exists { + return res, true, nil + } + if res.TypeCode != types.JSONTypeCodeObject { + return res, true, nil + } + + return res.GetKeys(), false, nil +} + +type jsonLengthFunctionClass struct { + baseFunctionClass +} + +type builtinJSONLengthSig struct { + baseBuiltinFunc +} + +func (b *builtinJSONLengthSig) Clone() builtinFunc { + newSig := &builtinJSONLengthSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (c *jsonLengthFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + + argTps := make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETJson) + if len(args) == 2 { + argTps = append(argTps, types.ETString) + } + + bf, err := 
newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, argTps...) + if err != nil { + return nil, err + } + sig := &builtinJSONLengthSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_JsonLengthSig) + return sig, nil +} + +func (b *builtinJSONLengthSig) evalInt(ctx EvalContext, row chunk.Row) (res int64, isNull bool, err error) { + obj, isNull, err := b.args[0].EvalJSON(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + if len(b.args) == 2 { + path, isNull, err := b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + pathExpr, err := types.ParseJSONPathExpr(path) + if err != nil { + return res, true, err + } + if pathExpr.CouldMatchMultipleValues() { + return res, true, types.ErrInvalidJSONPathMultipleSelection + } + + var exists bool + obj, exists = obj.Extract([]types.JSONPathExpression{pathExpr}) + if !exists { + return res, true, nil + } + } + + if obj.TypeCode != types.JSONTypeCodeObject && obj.TypeCode != types.JSONTypeCodeArray { + return 1, false, nil + } + return int64(obj.GetElemCount()), false, nil +} + +type jsonSchemaValidFunctionClass struct { + baseFunctionClass +} + +func (c *jsonSchemaValidFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { + if err := c.baseFunctionClass.verifyArgs(args); err != nil { + return err + } + if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { + return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_schema_valid") + } + if evalType := args[1].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { + return ErrInvalidTypeForJSON.GenWithStackByArgs(2, "json_schema_valid") + } + if c, ok := args[0].(*Constant); ok { + // If args[0] is NULL, then don't check the length of *both* arguments. 
+ // JSON_SCHEMA_VALID(NULL,NULL) -> NULL + // JSON_SCHEMA_VALID(NULL,'') -> NULL + // JSON_SCHEMA_VALID('',NULL) -> ErrInvalidJSONTextInParam + if !c.Value.IsNull() { + if len(c.Value.GetBytes()) == 0 { + return types.ErrInvalidJSONTextInParam.GenWithStackByArgs( + 1, "json_schema_valid", "The document is empty.", 0) + } + if c1, ok := args[1].(*Constant); ok { + if !c1.Value.IsNull() && len(c1.Value.GetBytes()) == 0 { + return types.ErrInvalidJSONTextInParam.GenWithStackByArgs( + 2, "json_schema_valid", "The document is empty.", 0) + } + } + } + } + return nil +} + +func (c *jsonSchemaValidFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETJson, types.ETJson) + if err != nil { + return nil, err + } + + sig := &builtinJSONSchemaValidSig{baseBuiltinFunc: bf} + return sig, nil +} + +type builtinJSONSchemaValidSig struct { + baseBuiltinFunc + + schemaCache builtinFuncCache[jsonschema.Schema] +} + +func (b *builtinJSONSchemaValidSig) Clone() builtinFunc { + newSig := &builtinJSONSchemaValidSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinJSONSchemaValidSig) evalInt(ctx EvalContext, row chunk.Row) (res int64, isNull bool, err error) { + var schema jsonschema.Schema + + // First argument is the schema + schemaData, schemaIsNull, err := b.args[0].EvalJSON(ctx, row) + if err != nil { + return res, false, err + } + if schemaIsNull { + return res, true, err + } + + if b.args[0].ConstLevel() >= ConstOnlyInContext { + schema, err = b.schemaCache.getOrInitCache(ctx, func() (jsonschema.Schema, error) { + failpoint.Inject("jsonSchemaValidDisableCacheRefresh", func() { + failpoint.Return(jsonschema.Schema{}, errors.New("Cache refresh disabled by failpoint")) + }) + dataBin, err := schemaData.MarshalJSON() + if err != nil { + return jsonschema.Schema{}, err + } + if err := goJSON.Unmarshal(dataBin, &schema); err != nil { + if _, ok := err.(*goJSON.UnmarshalTypeError); ok { + return jsonschema.Schema{}, + types.ErrInvalidJSONType.GenWithStackByArgs(1, "json_schema_valid", "object") + } + return jsonschema.Schema{}, + types.ErrInvalidJSONType.GenWithStackByArgs(1, "json_schema_valid", err) + } + return schema, nil + }) + if err != nil { + return res, false, err + } + } else { + dataBin, err := schemaData.MarshalJSON() + if err != nil { + return res, false, err + } + if err := goJSON.Unmarshal(dataBin, &schema); err != nil { + if _, ok := err.(*goJSON.UnmarshalTypeError); ok { + return res, false, + types.ErrInvalidJSONType.GenWithStackByArgs(1, "json_schema_valid", "object") + } + return res, false, + types.ErrInvalidJSONType.GenWithStackByArgs(1, "json_schema_valid", err) + } + } + + // Second argument is the JSON document + docData, docIsNull, err := b.args[1].EvalJSON(ctx, row) + if err != nil { + return res, false, err + } + if docIsNull { + return res, true, err + } + docDataBin, err := docData.MarshalJSON() + if err != nil { + return res, false, err + } + errs, err := schema.ValidateBytes(context.Background(), docDataBin) + if err != nil { + return res, false, err + } + if len(errs) > 0 { + return res, false, nil + } + res = 1 + return res, false, nil +} diff --git a/pkg/expression/builtin_time.go b/pkg/expression/builtin_time.go index e2c58fd32699c..8af6b1796dfce 100644 --- a/pkg/expression/builtin_time.go +++ b/pkg/expression/builtin_time.go @@ -2507,9 +2507,9 @@ 
func evalNowWithFsp(ctx EvalContext, fsp int) (types.Time, bool, error) { return types.ZeroTime, true, err } - failpoint.Inject("injectNow", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("injectNow")); _err_ == nil { nowTs = time.Unix(int64(val.(int)), 0) - }) + } // In MySQL's implementation, now() will truncate the result instead of rounding it. // Results below are from MySQL 5.7, which can prove it. @@ -6723,10 +6723,10 @@ func GetStmtMinSafeTime(sc *stmtctx.StatementContext, store kv.Storage, timeZone minSafeTS = store.GetMinSafeTS(txnScope) } // Inject mocked SafeTS for test. - failpoint.Inject("injectSafeTS", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("injectSafeTS")); _err_ == nil { injectTS := val.(int) minSafeTS = uint64(injectTS) - }) + } // Try to get from the stmt cache to make sure this function is deterministic. minSafeTS = sc.GetOrStoreStmtCache(stmtctx.StmtSafeTSCacheKey, minSafeTS).(uint64) return oracle.GetTimeFromTS(minSafeTS).In(timeZone) diff --git a/pkg/expression/builtin_time.go__failpoint_stash__ b/pkg/expression/builtin_time.go__failpoint_stash__ new file mode 100644 index 0000000000000..e2c58fd32699c --- /dev/null +++ b/pkg/expression/builtin_time.go__failpoint_stash__ @@ -0,0 +1,6832 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright 2013 The ql Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSES/QL-LICENSE file. + +package expression + +import ( + "context" + "fmt" + "math" + "regexp" + "strconv" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/errctx" + "github.com/pingcap/tidb/pkg/expression/contextopt" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/mathutil" + "github.com/pingcap/tidb/pkg/util/parser" + "github.com/pingcap/tipb/go-tipb" + "github.com/tikv/client-go/v2/oracle" + "go.uber.org/zap" +) + +const ( // GET_FORMAT first argument. + dateFormat = "DATE" + datetimeFormat = "DATETIME" + timestampFormat = "TIMESTAMP" + timeFormat = "TIME" +) + +const ( // GET_FORMAT location. + usaLocation = "USA" + jisLocation = "JIS" + isoLocation = "ISO" + eurLocation = "EUR" + internalLocation = "INTERNAL" +) + +var ( + // durationPattern checks whether a string matches the format of duration. + durationPattern = regexp.MustCompile(`^\s*[-]?(((\d{1,2}\s+)?0*\d{0,3}(:0*\d{1,2}){0,2})|(\d{1,7}))?(\.\d*)?\s*$`) + + // timestampPattern checks whether a string matches the format of timestamp. 
+ timestampPattern = regexp.MustCompile(`^\s*0*\d{1,4}([^\d]0*\d{1,2}){2}\s+(0*\d{0,2}([^\d]0*\d{1,2}){2})?(\.\d*)?\s*$`) + + // datePattern determine whether to match the format of date. + datePattern = regexp.MustCompile(`^\s*((0*\d{1,4}([^\d]0*\d{1,2}){2})|(\d{2,4}(\d{2}){2}))\s*$`) +) + +var ( + _ functionClass = &dateFunctionClass{} + _ functionClass = &dateLiteralFunctionClass{} + _ functionClass = &dateDiffFunctionClass{} + _ functionClass = &timeDiffFunctionClass{} + _ functionClass = &dateFormatFunctionClass{} + _ functionClass = &hourFunctionClass{} + _ functionClass = &minuteFunctionClass{} + _ functionClass = &secondFunctionClass{} + _ functionClass = µSecondFunctionClass{} + _ functionClass = &monthFunctionClass{} + _ functionClass = &monthNameFunctionClass{} + _ functionClass = &nowFunctionClass{} + _ functionClass = &dayNameFunctionClass{} + _ functionClass = &dayOfMonthFunctionClass{} + _ functionClass = &dayOfWeekFunctionClass{} + _ functionClass = &dayOfYearFunctionClass{} + _ functionClass = &weekFunctionClass{} + _ functionClass = &weekDayFunctionClass{} + _ functionClass = &weekOfYearFunctionClass{} + _ functionClass = &yearFunctionClass{} + _ functionClass = &yearWeekFunctionClass{} + _ functionClass = &fromUnixTimeFunctionClass{} + _ functionClass = &getFormatFunctionClass{} + _ functionClass = &strToDateFunctionClass{} + _ functionClass = &sysDateFunctionClass{} + _ functionClass = ¤tDateFunctionClass{} + _ functionClass = ¤tTimeFunctionClass{} + _ functionClass = &timeFunctionClass{} + _ functionClass = &timeLiteralFunctionClass{} + _ functionClass = &utcDateFunctionClass{} + _ functionClass = &utcTimestampFunctionClass{} + _ functionClass = &extractFunctionClass{} + _ functionClass = &unixTimestampFunctionClass{} + _ functionClass = &addTimeFunctionClass{} + _ functionClass = &convertTzFunctionClass{} + _ functionClass = &makeDateFunctionClass{} + _ functionClass = &makeTimeFunctionClass{} + _ functionClass = &periodAddFunctionClass{} + _ functionClass = &periodDiffFunctionClass{} + _ functionClass = &quarterFunctionClass{} + _ functionClass = &secToTimeFunctionClass{} + _ functionClass = &subTimeFunctionClass{} + _ functionClass = &timeFormatFunctionClass{} + _ functionClass = &timeToSecFunctionClass{} + _ functionClass = ×tampAddFunctionClass{} + _ functionClass = &toDaysFunctionClass{} + _ functionClass = &toSecondsFunctionClass{} + _ functionClass = &utcTimeFunctionClass{} + _ functionClass = ×tampFunctionClass{} + _ functionClass = ×tampLiteralFunctionClass{} + _ functionClass = &lastDayFunctionClass{} + _ functionClass = &addSubDateFunctionClass{} +) + +var ( + _ builtinFunc = &builtinDateSig{} + _ builtinFunc = &builtinDateLiteralSig{} + _ builtinFunc = &builtinDateDiffSig{} + _ builtinFunc = &builtinNullTimeDiffSig{} + _ builtinFunc = &builtinTimeStringTimeDiffSig{} + _ builtinFunc = &builtinDurationStringTimeDiffSig{} + _ builtinFunc = &builtinDurationDurationTimeDiffSig{} + _ builtinFunc = &builtinStringTimeTimeDiffSig{} + _ builtinFunc = &builtinStringDurationTimeDiffSig{} + _ builtinFunc = &builtinStringStringTimeDiffSig{} + _ builtinFunc = &builtinTimeTimeTimeDiffSig{} + _ builtinFunc = &builtinDateFormatSig{} + _ builtinFunc = &builtinHourSig{} + _ builtinFunc = &builtinMinuteSig{} + _ builtinFunc = &builtinSecondSig{} + _ builtinFunc = &builtinMicroSecondSig{} + _ builtinFunc = &builtinMonthSig{} + _ builtinFunc = &builtinMonthNameSig{} + _ builtinFunc = &builtinNowWithArgSig{} + _ builtinFunc = &builtinNowWithoutArgSig{} + _ builtinFunc = 
&builtinDayNameSig{} + _ builtinFunc = &builtinDayOfMonthSig{} + _ builtinFunc = &builtinDayOfWeekSig{} + _ builtinFunc = &builtinDayOfYearSig{} + _ builtinFunc = &builtinWeekWithModeSig{} + _ builtinFunc = &builtinWeekWithoutModeSig{} + _ builtinFunc = &builtinWeekDaySig{} + _ builtinFunc = &builtinWeekOfYearSig{} + _ builtinFunc = &builtinYearSig{} + _ builtinFunc = &builtinYearWeekWithModeSig{} + _ builtinFunc = &builtinYearWeekWithoutModeSig{} + _ builtinFunc = &builtinGetFormatSig{} + _ builtinFunc = &builtinSysDateWithFspSig{} + _ builtinFunc = &builtinSysDateWithoutFspSig{} + _ builtinFunc = &builtinCurrentDateSig{} + _ builtinFunc = &builtinCurrentTime0ArgSig{} + _ builtinFunc = &builtinCurrentTime1ArgSig{} + _ builtinFunc = &builtinTimeSig{} + _ builtinFunc = &builtinTimeLiteralSig{} + _ builtinFunc = &builtinUTCDateSig{} + _ builtinFunc = &builtinUTCTimestampWithArgSig{} + _ builtinFunc = &builtinUTCTimestampWithoutArgSig{} + _ builtinFunc = &builtinAddDatetimeAndDurationSig{} + _ builtinFunc = &builtinAddDatetimeAndStringSig{} + _ builtinFunc = &builtinAddTimeDateTimeNullSig{} + _ builtinFunc = &builtinAddStringAndDurationSig{} + _ builtinFunc = &builtinAddStringAndStringSig{} + _ builtinFunc = &builtinAddTimeStringNullSig{} + _ builtinFunc = &builtinAddDurationAndDurationSig{} + _ builtinFunc = &builtinAddDurationAndStringSig{} + _ builtinFunc = &builtinAddTimeDurationNullSig{} + _ builtinFunc = &builtinAddDateAndDurationSig{} + _ builtinFunc = &builtinAddDateAndStringSig{} + _ builtinFunc = &builtinSubDatetimeAndDurationSig{} + _ builtinFunc = &builtinSubDatetimeAndStringSig{} + _ builtinFunc = &builtinSubTimeDateTimeNullSig{} + _ builtinFunc = &builtinSubStringAndDurationSig{} + _ builtinFunc = &builtinSubStringAndStringSig{} + _ builtinFunc = &builtinSubTimeStringNullSig{} + _ builtinFunc = &builtinSubDurationAndDurationSig{} + _ builtinFunc = &builtinSubDurationAndStringSig{} + _ builtinFunc = &builtinSubTimeDurationNullSig{} + _ builtinFunc = &builtinSubDateAndDurationSig{} + _ builtinFunc = &builtinSubDateAndStringSig{} + _ builtinFunc = &builtinUnixTimestampCurrentSig{} + _ builtinFunc = &builtinUnixTimestampIntSig{} + _ builtinFunc = &builtinUnixTimestampDecSig{} + _ builtinFunc = &builtinConvertTzSig{} + _ builtinFunc = &builtinMakeDateSig{} + _ builtinFunc = &builtinMakeTimeSig{} + _ builtinFunc = &builtinPeriodAddSig{} + _ builtinFunc = &builtinPeriodDiffSig{} + _ builtinFunc = &builtinQuarterSig{} + _ builtinFunc = &builtinSecToTimeSig{} + _ builtinFunc = &builtinTimeToSecSig{} + _ builtinFunc = &builtinTimestampAddSig{} + _ builtinFunc = &builtinToDaysSig{} + _ builtinFunc = &builtinToSecondsSig{} + _ builtinFunc = &builtinUTCTimeWithArgSig{} + _ builtinFunc = &builtinUTCTimeWithoutArgSig{} + _ builtinFunc = &builtinTimestamp1ArgSig{} + _ builtinFunc = &builtinTimestamp2ArgsSig{} + _ builtinFunc = &builtinTimestampLiteralSig{} + _ builtinFunc = &builtinLastDaySig{} + _ builtinFunc = &builtinStrToDateDateSig{} + _ builtinFunc = &builtinStrToDateDatetimeSig{} + _ builtinFunc = &builtinStrToDateDurationSig{} + _ builtinFunc = &builtinFromUnixTime1ArgSig{} + _ builtinFunc = &builtinFromUnixTime2ArgSig{} + _ builtinFunc = &builtinExtractDatetimeFromStringSig{} + _ builtinFunc = &builtinExtractDatetimeSig{} + _ builtinFunc = &builtinExtractDurationSig{} + _ builtinFunc = &builtinAddSubDateAsStringSig{} + _ builtinFunc = &builtinAddSubDateDatetimeAnySig{} + _ builtinFunc = &builtinAddSubDateDurationAnySig{} +) + +func convertTimeToMysqlTime(t time.Time, fsp int, roundMode 
types.RoundMode) (types.Time, error) { + var tr time.Time + var err error + if roundMode == types.ModeTruncate { + tr, err = types.TruncateFrac(t, fsp) + } else { + tr, err = types.RoundFrac(t, fsp) + } + if err != nil { + return types.ZeroTime, err + } + + return types.NewTime(types.FromGoTime(tr), mysql.TypeDatetime, fsp), nil +} + +type dateFunctionClass struct { + baseFunctionClass +} + +func (c *dateFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, types.ETDatetime) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForDate() + sig := &builtinDateSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_Date) + return sig, nil +} + +type builtinDateSig struct { + baseBuiltinFunc +} + +func (b *builtinDateSig) Clone() builtinFunc { + newSig := &builtinDateSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals DATE(expr). +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date +func (b *builtinDateSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + expr, isNull, err := b.args[0].EvalTime(ctx, row) + if isNull || err != nil { + return types.ZeroTime, true, handleInvalidTimeError(ctx, err) + } + + if expr.IsZero() && sqlMode(ctx).HasNoZeroDateMode() { + return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, expr.String())) + } + + expr.SetCoreTime(types.FromDate(expr.Year(), expr.Month(), expr.Day(), 0, 0, 0, 0)) + expr.SetType(mysql.TypeDate) + return expr, false, nil +} + +type dateLiteralFunctionClass struct { + baseFunctionClass +} + +func (c *dateLiteralFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + con, ok := args[0].(*Constant) + if !ok { + panic("Unexpected parameter for date literal") + } + dt, err := con.Eval(ctx.GetEvalCtx(), chunk.Row{}) + if err != nil { + return nil, err + } + str := dt.GetString() + if !datePattern.MatchString(str) { + return nil, types.ErrWrongValue.GenWithStackByArgs(types.DateStr, str) + } + tm, err := types.ParseDate(ctx.GetEvalCtx().TypeCtx(), str) + if err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, []Expression{}, types.ETDatetime) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForDate() + sig := &builtinDateLiteralSig{bf, tm} + return sig, nil +} + +type builtinDateLiteralSig struct { + baseBuiltinFunc + literal types.Time +} + +func (b *builtinDateLiteralSig) Clone() builtinFunc { + newSig := &builtinDateLiteralSig{literal: b.literal} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals DATE 'stringLit'. 
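+// The literal was already parsed when the function was built; evaluation only re-checks it
+// against the NO_ZERO_DATE / NO_ZERO_IN_DATE SQL modes of the current session.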
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-literals.html +func (b *builtinDateLiteralSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + mode := sqlMode(ctx) + if mode.HasNoZeroDateMode() && b.literal.IsZero() { + return b.literal, true, types.ErrWrongValue.GenWithStackByArgs(types.DateStr, b.literal.String()) + } + if mode.HasNoZeroInDateMode() && (b.literal.InvalidZero() && !b.literal.IsZero()) { + return b.literal, true, types.ErrWrongValue.GenWithStackByArgs(types.DateStr, b.literal.String()) + } + return b.literal, false, nil +} + +type dateDiffFunctionClass struct { + baseFunctionClass +} + +func (c *dateDiffFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDatetime, types.ETDatetime) + if err != nil { + return nil, err + } + sig := &builtinDateDiffSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_DateDiff) + return sig, nil +} + +type builtinDateDiffSig struct { + baseBuiltinFunc +} + +func (b *builtinDateDiffSig) Clone() builtinFunc { + newSig := &builtinDateDiffSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals a builtinDateDiffSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_datediff +func (b *builtinDateDiffSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + lhs, isNull, err := b.args[0].EvalTime(ctx, row) + if isNull || err != nil { + return 0, true, handleInvalidTimeError(ctx, err) + } + rhs, isNull, err := b.args[1].EvalTime(ctx, row) + if isNull || err != nil { + return 0, true, handleInvalidTimeError(ctx, err) + } + if invalidLHS, invalidRHS := lhs.InvalidZero(), rhs.InvalidZero(); invalidLHS || invalidRHS { + if invalidLHS { + err = handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, lhs.String())) + } + if invalidRHS { + err = handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, rhs.String())) + } + return 0, true, err + } + return int64(types.DateDiff(lhs.CoreTime(), rhs.CoreTime())), false, nil +} + +type timeDiffFunctionClass struct { + baseFunctionClass +} + +func (c *timeDiffFunctionClass) getArgEvalTp(fieldTp *types.FieldType) types.EvalType { + argTp := types.ETString + switch tp := fieldTp.EvalType(); tp { + case types.ETDuration, types.ETDatetime, types.ETTimestamp: + argTp = tp + } + return argTp +} + +func (c *timeDiffFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + + arg0FieldTp, arg1FieldTp := args[0].GetType(ctx.GetEvalCtx()), args[1].GetType(ctx.GetEvalCtx()) + arg0Tp, arg1Tp := c.getArgEvalTp(arg0FieldTp), c.getArgEvalTp(arg1FieldTp) + arg0Dec, err := getExpressionFsp(ctx, args[0]) + if err != nil { + return nil, err + } + arg1Dec, err := getExpressionFsp(ctx, args[1]) + if err != nil { + return nil, err + } + fsp := mathutil.Max(arg0Dec, arg1Dec) + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDuration, arg0Tp, arg1Tp) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForTime(fsp) + + var sig builtinFunc + // arg0 and arg1 must be the same time type(compatible), or timediff will return NULL. 
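+ // e.g. TIMEDIFF('2024-01-02 00:00:00', '2024-01-01 00:00:00') -> '24:00:00',
+ // while a DURATION argument paired with a DATETIME argument picks builtinNullTimeDiffSig and yields NULL.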
+ switch arg0Tp { + case types.ETDuration: + switch arg1Tp { + case types.ETDuration: + sig = &builtinDurationDurationTimeDiffSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_DurationDurationTimeDiff) + case types.ETDatetime, types.ETTimestamp: + sig = &builtinNullTimeDiffSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_NullTimeDiff) + default: + sig = &builtinDurationStringTimeDiffSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_DurationStringTimeDiff) + } + case types.ETDatetime, types.ETTimestamp: + switch arg1Tp { + case types.ETDuration: + sig = &builtinNullTimeDiffSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_NullTimeDiff) + case types.ETDatetime, types.ETTimestamp: + sig = &builtinTimeTimeTimeDiffSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_TimeTimeTimeDiff) + default: + sig = &builtinTimeStringTimeDiffSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_TimeStringTimeDiff) + } + default: + switch arg1Tp { + case types.ETDuration: + sig = &builtinStringDurationTimeDiffSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_StringDurationTimeDiff) + case types.ETDatetime, types.ETTimestamp: + sig = &builtinStringTimeTimeDiffSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_StringTimeTimeDiff) + default: + sig = &builtinStringStringTimeDiffSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_StringStringTimeDiff) + } + } + return sig, nil +} + +type builtinDurationDurationTimeDiffSig struct { + baseBuiltinFunc +} + +func (b *builtinDurationDurationTimeDiffSig) Clone() builtinFunc { + newSig := &builtinDurationDurationTimeDiffSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalDuration evals a builtinDurationDurationTimeDiffSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_timediff +func (b *builtinDurationDurationTimeDiffSig) evalDuration(ctx EvalContext, row chunk.Row) (d types.Duration, isNull bool, err error) { + lhs, isNull, err := b.args[0].EvalDuration(ctx, row) + if isNull || err != nil { + return d, isNull, err + } + + rhs, isNull, err := b.args[1].EvalDuration(ctx, row) + if isNull || err != nil { + return d, isNull, err + } + + d, isNull, err = calculateDurationTimeDiff(ctx, lhs, rhs) + return d, isNull, err +} + +type builtinTimeTimeTimeDiffSig struct { + baseBuiltinFunc +} + +func (b *builtinTimeTimeTimeDiffSig) Clone() builtinFunc { + newSig := &builtinTimeTimeTimeDiffSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalDuration evals a builtinTimeTimeTimeDiffSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_timediff +func (b *builtinTimeTimeTimeDiffSig) evalDuration(ctx EvalContext, row chunk.Row) (d types.Duration, isNull bool, err error) { + lhs, isNull, err := b.args[0].EvalTime(ctx, row) + if isNull || err != nil { + return d, isNull, err + } + + rhs, isNull, err := b.args[1].EvalTime(ctx, row) + if isNull || err != nil { + return d, isNull, err + } + + tc := typeCtx(ctx) + d, isNull, err = calculateTimeDiff(tc, lhs, rhs) + return d, isNull, err +} + +type builtinDurationStringTimeDiffSig struct { + baseBuiltinFunc +} + +func (b *builtinDurationStringTimeDiffSig) Clone() builtinFunc { + newSig := &builtinDurationStringTimeDiffSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalDuration evals a builtinDurationStringTimeDiffSig. 
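+// The string operand must convert to a DURATION; if it parses as a DATETIME instead, the result is NULL.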
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_timediff +func (b *builtinDurationStringTimeDiffSig) evalDuration(ctx EvalContext, row chunk.Row) (d types.Duration, isNull bool, err error) { + lhs, isNull, err := b.args[0].EvalDuration(ctx, row) + if isNull || err != nil { + return d, isNull, err + } + + rhsStr, isNull, err := b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return d, isNull, err + } + + tc := typeCtx(ctx) + rhs, _, isDuration, err := convertStringToDuration(tc, rhsStr, b.tp.GetDecimal()) + if err != nil || !isDuration { + return d, true, err + } + + d, isNull, err = calculateDurationTimeDiff(ctx, lhs, rhs) + return d, isNull, err +} + +type builtinStringDurationTimeDiffSig struct { + baseBuiltinFunc +} + +func (b *builtinStringDurationTimeDiffSig) Clone() builtinFunc { + newSig := &builtinStringDurationTimeDiffSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalDuration evals a builtinStringDurationTimeDiffSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_timediff +func (b *builtinStringDurationTimeDiffSig) evalDuration(ctx EvalContext, row chunk.Row) (d types.Duration, isNull bool, err error) { + lhsStr, isNull, err := b.args[0].EvalString(ctx, row) + if isNull || err != nil { + return d, isNull, err + } + + rhs, isNull, err := b.args[1].EvalDuration(ctx, row) + if isNull || err != nil { + return d, isNull, err + } + + tc := typeCtx(ctx) + lhs, _, isDuration, err := convertStringToDuration(tc, lhsStr, b.tp.GetDecimal()) + if err != nil || !isDuration { + return d, true, err + } + + d, isNull, err = calculateDurationTimeDiff(ctx, lhs, rhs) + return d, isNull, err +} + +// calculateTimeDiff calculates interval difference of two types.Time. +func calculateTimeDiff(tc types.Context, lhs, rhs types.Time) (d types.Duration, isNull bool, err error) { + d = lhs.Sub(tc, &rhs) + d.Duration, err = types.TruncateOverflowMySQLTime(d.Duration) + if types.ErrTruncatedWrongVal.Equal(err) { + err = tc.HandleTruncate(err) + } + return d, err != nil, err +} + +// calculateDurationTimeDiff calculates interval difference of two types.Duration. +func calculateDurationTimeDiff(ctx EvalContext, lhs, rhs types.Duration) (d types.Duration, isNull bool, err error) { + d, err = lhs.Sub(rhs) + if err != nil { + return d, true, err + } + + d.Duration, err = types.TruncateOverflowMySQLTime(d.Duration) + if types.ErrTruncatedWrongVal.Equal(err) { + tc := typeCtx(ctx) + err = tc.HandleTruncate(err) + } + return d, err != nil, err +} + +type builtinTimeStringTimeDiffSig struct { + baseBuiltinFunc +} + +func (b *builtinTimeStringTimeDiffSig) Clone() builtinFunc { + newSig := &builtinTimeStringTimeDiffSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalDuration evals a builtinTimeStringTimeDiffSig. 
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_timediff +func (b *builtinTimeStringTimeDiffSig) evalDuration(ctx EvalContext, row chunk.Row) (d types.Duration, isNull bool, err error) { + lhs, isNull, err := b.args[0].EvalTime(ctx, row) + if isNull || err != nil { + return d, isNull, err + } + + rhsStr, isNull, err := b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return d, isNull, err + } + + tc := typeCtx(ctx) + _, rhs, isDuration, err := convertStringToDuration(tc, rhsStr, b.tp.GetDecimal()) + if err != nil || isDuration { + return d, true, err + } + + d, isNull, err = calculateTimeDiff(tc, lhs, rhs) + return d, isNull, err +} + +type builtinStringTimeTimeDiffSig struct { + baseBuiltinFunc +} + +func (b *builtinStringTimeTimeDiffSig) Clone() builtinFunc { + newSig := &builtinStringTimeTimeDiffSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalDuration evals a builtinStringTimeTimeDiffSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_timediff +func (b *builtinStringTimeTimeDiffSig) evalDuration(ctx EvalContext, row chunk.Row) (d types.Duration, isNull bool, err error) { + lhsStr, isNull, err := b.args[0].EvalString(ctx, row) + if isNull || err != nil { + return d, isNull, err + } + + rhs, isNull, err := b.args[1].EvalTime(ctx, row) + if isNull || err != nil { + return d, isNull, err + } + + tc := typeCtx(ctx) + _, lhs, isDuration, err := convertStringToDuration(tc, lhsStr, b.tp.GetDecimal()) + if err != nil || isDuration { + return d, true, err + } + + d, isNull, err = calculateTimeDiff(tc, lhs, rhs) + return d, isNull, err +} + +type builtinStringStringTimeDiffSig struct { + baseBuiltinFunc +} + +func (b *builtinStringStringTimeDiffSig) Clone() builtinFunc { + newSig := &builtinStringStringTimeDiffSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalDuration evals a builtinStringStringTimeDiffSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_timediff +func (b *builtinStringStringTimeDiffSig) evalDuration(ctx EvalContext, row chunk.Row) (d types.Duration, isNull bool, err error) { + lhs, isNull, err := b.args[0].EvalString(ctx, row) + if isNull || err != nil { + return d, isNull, err + } + + rhs, isNull, err := b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return d, isNull, err + } + + tc := typeCtx(ctx) + fsp := b.tp.GetDecimal() + lhsDur, lhsTime, lhsIsDuration, err := convertStringToDuration(tc, lhs, fsp) + if err != nil { + return d, true, err + } + + rhsDur, rhsTime, rhsIsDuration, err := convertStringToDuration(tc, rhs, fsp) + if err != nil { + return d, true, err + } + + if lhsIsDuration != rhsIsDuration { + return d, true, nil + } + + if lhsIsDuration { + d, isNull, err = calculateDurationTimeDiff(ctx, lhsDur, rhsDur) + } else { + d, isNull, err = calculateTimeDiff(tc, lhsTime, rhsTime) + } + + return d, isNull, err +} + +type builtinNullTimeDiffSig struct { + baseBuiltinFunc +} + +func (b *builtinNullTimeDiffSig) Clone() builtinFunc { + newSig := &builtinNullTimeDiffSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalDuration evals a builtinNullTimeDiffSig. 
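+// Chosen when the argument types can never be diffed (e.g. DURATION vs. DATETIME), so it always returns NULL.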
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_timediff +func (b *builtinNullTimeDiffSig) evalDuration(ctx EvalContext, row chunk.Row) (d types.Duration, isNull bool, err error) { + return d, true, nil +} + +// convertStringToDuration converts string to duration, it return types.Time because in some case +// it will converts string to datetime. +func convertStringToDuration(tc types.Context, str string, fsp int) (d types.Duration, t types.Time, + isDuration bool, err error) { + if n := strings.IndexByte(str, '.'); n >= 0 { + lenStrFsp := len(str[n+1:]) + if lenStrFsp <= types.MaxFsp { + fsp = mathutil.Max(lenStrFsp, fsp) + } + } + return types.StrToDuration(tc, str, fsp) +} + +type dateFormatFunctionClass struct { + baseFunctionClass +} + +func (c *dateFormatFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETDatetime, types.ETString) + if err != nil { + return nil, err + } + // worst case: formatMask=%r%r%r...%r, each %r takes 11 characters + bf.tp.SetFlen((args[1].GetType(ctx.GetEvalCtx()).GetFlen() + 1) / 2 * 11) + sig := &builtinDateFormatSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_DateFormatSig) + return sig, nil +} + +type builtinDateFormatSig struct { + baseBuiltinFunc +} + +func (b *builtinDateFormatSig) Clone() builtinFunc { + newSig := &builtinDateFormatSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalString evals a builtinDateFormatSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-format +func (b *builtinDateFormatSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { + t, isNull, err := b.args[0].EvalTime(ctx, row) + if isNull || err != nil { + return "", isNull, handleInvalidTimeError(ctx, err) + } + formatMask, isNull, err := b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + // MySQL compatibility, #11203 + // If format mask is 0 then return 0 without warnings + if formatMask == "0" { + return "0", false, nil + } + + if t.InvalidZero() { + // MySQL compatibility, #11203 + // 0 | 0.0 should be converted to null without warnings + n, err := t.ToNumber().ToInt() + isOriginalIntOrDecimalZero := err == nil && n == 0 + // Args like "0000-00-00", "0000-00-00 00:00:00" set Fsp to 6 + isOriginalStringZero := t.Fsp() > 0 + if isOriginalIntOrDecimalZero && !isOriginalStringZero { + return "", true, nil + } + return "", true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, t.String())) + } + + res, err := t.DateFormat(formatMask) + return res, isNull, err +} + +type fromDaysFunctionClass struct { + baseFunctionClass +} + +func (c *fromDaysFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, types.ETInt) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForDate() + sig := &builtinFromDaysSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_FromDays) + return sig, nil +} + +type builtinFromDaysSig struct { + baseBuiltinFunc +} + +func (b *builtinFromDaysSig) Clone() builtinFunc { + newSig := &builtinFromDaysSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals FROM_DAYS(N). 
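+// Day numbers whose resulting year would exceed 9999 produce NULL; otherwise the day number is converted back to a calendar date.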
+// See https://dev.mysql.com/doc/refman/8.0/en/date-and-time-functions.html#function_from-days +func (b *builtinFromDaysSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + n, isNull, err := b.args[0].EvalInt(ctx, row) + if isNull || err != nil { + return types.ZeroTime, true, err + } + ret := types.TimeFromDays(n) + // the maximum date value is 9999-12-31 in mysql 5.8. + if ret.Year() > 9999 { + return types.ZeroTime, true, nil + } + return ret, false, nil +} + +type hourFunctionClass struct { + baseFunctionClass +} + +func (c *hourFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDuration) + if err != nil { + return nil, err + } + bf.tp.SetFlen(3) + bf.tp.SetDecimal(0) + sig := &builtinHourSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_Hour) + return sig, nil +} + +type builtinHourSig struct { + baseBuiltinFunc +} + +func (b *builtinHourSig) Clone() builtinFunc { + newSig := &builtinHourSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals HOUR(time). +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_hour +func (b *builtinHourSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + dur, isNull, err := b.args[0].EvalDuration(ctx, row) + // ignore error and return NULL + if isNull || err != nil { + return 0, true, nil + } + return int64(dur.Hour()), false, nil +} + +type minuteFunctionClass struct { + baseFunctionClass +} + +func (c *minuteFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDuration) + if err != nil { + return nil, err + } + bf.tp.SetFlen(2) + bf.tp.SetDecimal(0) + sig := &builtinMinuteSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_Minute) + return sig, nil +} + +type builtinMinuteSig struct { + baseBuiltinFunc +} + +func (b *builtinMinuteSig) Clone() builtinFunc { + newSig := &builtinMinuteSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals MINUTE(time). +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_minute +func (b *builtinMinuteSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + dur, isNull, err := b.args[0].EvalDuration(ctx, row) + // ignore error and return NULL + if isNull || err != nil { + return 0, true, nil + } + return int64(dur.Minute()), false, nil +} + +type secondFunctionClass struct { + baseFunctionClass +} + +func (c *secondFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDuration) + if err != nil { + return nil, err + } + bf.tp.SetFlen(2) + bf.tp.SetDecimal(0) + sig := &builtinSecondSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_Second) + return sig, nil +} + +type builtinSecondSig struct { + baseBuiltinFunc +} + +func (b *builtinSecondSig) Clone() builtinFunc { + newSig := &builtinSecondSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals SECOND(time). 
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_second +func (b *builtinSecondSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + dur, isNull, err := b.args[0].EvalDuration(ctx, row) + // ignore error and return NULL + if isNull || err != nil { + return 0, true, nil + } + return int64(dur.Second()), false, nil +} + +type microSecondFunctionClass struct { + baseFunctionClass +} + +func (c *microSecondFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDuration) + if err != nil { + return nil, err + } + bf.tp.SetFlen(6) + bf.tp.SetDecimal(0) + sig := &builtinMicroSecondSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_MicroSecond) + return sig, nil +} + +type builtinMicroSecondSig struct { + baseBuiltinFunc +} + +func (b *builtinMicroSecondSig) Clone() builtinFunc { + newSig := &builtinMicroSecondSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals MICROSECOND(expr). +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_microsecond +func (b *builtinMicroSecondSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + dur, isNull, err := b.args[0].EvalDuration(ctx, row) + // ignore error and return NULL + if isNull || err != nil { + return 0, true, nil + } + return int64(dur.MicroSecond()), false, nil +} + +type monthFunctionClass struct { + baseFunctionClass +} + +func (c *monthFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDatetime) + if err != nil { + return nil, err + } + bf.tp.SetFlen(2) + bf.tp.SetDecimal(0) + sig := &builtinMonthSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_Month) + return sig, nil +} + +type builtinMonthSig struct { + baseBuiltinFunc +} + +func (b *builtinMonthSig) Clone() builtinFunc { + newSig := &builtinMonthSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals MONTH(date). 
+// see: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_month +func (b *builtinMonthSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + date, isNull, err := b.args[0].EvalTime(ctx, row) + + if isNull || err != nil { + return 0, true, handleInvalidTimeError(ctx, err) + } + + return int64(date.Month()), false, nil +} + +// monthNameFunctionClass see https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_monthname +type monthNameFunctionClass struct { + baseFunctionClass +} + +func (c *monthNameFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETDatetime) + if err != nil { + return nil, err + } + charset, collate := ctx.GetCharsetInfo() + bf.tp.SetCharset(charset) + bf.tp.SetCollate(collate) + bf.tp.SetFlen(10) + sig := &builtinMonthNameSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_MonthName) + return sig, nil +} + +type builtinMonthNameSig struct { + baseBuiltinFunc +} + +func (b *builtinMonthNameSig) Clone() builtinFunc { + newSig := &builtinMonthNameSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinMonthNameSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { + arg, isNull, err := b.args[0].EvalTime(ctx, row) + if isNull || err != nil { + return "", true, handleInvalidTimeError(ctx, err) + } + mon := arg.Month() + if (arg.IsZero() && sqlMode(ctx).HasNoZeroDateMode()) || mon < 0 || mon > len(types.MonthNames) { + return "", true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, arg.String())) + } else if mon == 0 || arg.IsZero() { + return "", true, nil + } + return types.MonthNames[mon-1], false, nil +} + +type dayNameFunctionClass struct { + baseFunctionClass +} + +func (c *dayNameFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETDatetime) + if err != nil { + return nil, err + } + charset, collate := ctx.GetCharsetInfo() + bf.tp.SetCharset(charset) + bf.tp.SetCollate(collate) + bf.tp.SetFlen(10) + sig := &builtinDayNameSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_DayName) + return sig, nil +} + +type builtinDayNameSig struct { + baseBuiltinFunc +} + +func (b *builtinDayNameSig) Clone() builtinFunc { + newSig := &builtinDayNameSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinDayNameSig) evalIndex(ctx EvalContext, row chunk.Row) (int64, bool, error) { + arg, isNull, err := b.args[0].EvalTime(ctx, row) + if isNull || err != nil { + return 0, isNull, err + } + if arg.InvalidZero() { + return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, arg.String())) + } + // Monday is 0, ... Sunday = 6 in MySQL + // but in go, Sunday is 0, ... Saturday is 6 + // w will do a conversion. + res := (int64(arg.Weekday()) + 6) % 7 + return res, false, nil +} + +// evalString evals a builtinDayNameSig. 
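The (Weekday()+6)%7 conversion in evalIndex above maps Go's Sunday-first numbering onto the Monday-first index used by WEEKDAY and the types.WeekdayNames table (Monday = 0 ... Sunday = 6). A standalone sketch of the same mapping, standard library only and not part of the patch:

package main

import (
	"fmt"
	"time"
)

// mysqlWeekdayIndex converts Go's time.Weekday (Sunday=0 ... Saturday=6)
// into the Monday-first index used by WEEKDAY() and the WeekdayNames table.
func mysqlWeekdayIndex(w time.Weekday) int {
	return (int(w) + 6) % 7
}

func main() {
	d := time.Date(2024, 8, 7, 0, 0, 0, 0, time.UTC)          // a Wednesday
	fmt.Println(d.Weekday(), mysqlWeekdayIndex(d.Weekday())) // Wednesday 2
}
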
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_dayname +func (b *builtinDayNameSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { + idx, isNull, err := b.evalIndex(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + return types.WeekdayNames[idx], false, nil +} + +func (b *builtinDayNameSig) evalReal(ctx EvalContext, row chunk.Row) (float64, bool, error) { + idx, isNull, err := b.evalIndex(ctx, row) + if isNull || err != nil { + return 0, isNull, err + } + return float64(idx), false, nil +} + +func (b *builtinDayNameSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + idx, isNull, err := b.evalIndex(ctx, row) + if isNull || err != nil { + return 0, isNull, err + } + return idx, false, nil +} + +type dayOfMonthFunctionClass struct { + baseFunctionClass +} + +func (c *dayOfMonthFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDatetime) + if err != nil { + return nil, err + } + bf.tp.SetFlen(2) + sig := &builtinDayOfMonthSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_DayOfMonth) + return sig, nil +} + +type builtinDayOfMonthSig struct { + baseBuiltinFunc +} + +func (b *builtinDayOfMonthSig) Clone() builtinFunc { + newSig := &builtinDayOfMonthSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals a builtinDayOfMonthSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_dayofmonth +func (b *builtinDayOfMonthSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + arg, isNull, err := b.args[0].EvalTime(ctx, row) + if isNull || err != nil { + return 0, true, handleInvalidTimeError(ctx, err) + } + return int64(arg.Day()), false, nil +} + +type dayOfWeekFunctionClass struct { + baseFunctionClass +} + +func (c *dayOfWeekFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDatetime) + if err != nil { + return nil, err + } + bf.tp.SetFlen(1) + sig := &builtinDayOfWeekSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_DayOfWeek) + return sig, nil +} + +type builtinDayOfWeekSig struct { + baseBuiltinFunc +} + +func (b *builtinDayOfWeekSig) Clone() builtinFunc { + newSig := &builtinDayOfWeekSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals a builtinDayOfWeekSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_dayofweek +func (b *builtinDayOfWeekSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + arg, isNull, err := b.args[0].EvalTime(ctx, row) + if isNull || err != nil { + return 0, true, handleInvalidTimeError(ctx, err) + } + if arg.InvalidZero() { + return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, arg.String())) + } + // 1 is Sunday, 2 is Monday, .... 
7 is Saturday + return int64(arg.Weekday() + 1), false, nil +} + +type dayOfYearFunctionClass struct { + baseFunctionClass +} + +func (c *dayOfYearFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDatetime) + if err != nil { + return nil, err + } + bf.tp.SetFlen(3) + sig := &builtinDayOfYearSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_DayOfYear) + return sig, nil +} + +type builtinDayOfYearSig struct { + baseBuiltinFunc +} + +func (b *builtinDayOfYearSig) Clone() builtinFunc { + newSig := &builtinDayOfYearSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals a builtinDayOfYearSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_dayofyear +func (b *builtinDayOfYearSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + arg, isNull, err := b.args[0].EvalTime(ctx, row) + if isNull || err != nil { + return 0, isNull, handleInvalidTimeError(ctx, err) + } + if arg.InvalidZero() { + return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, arg.String())) + } + + return int64(arg.YearDay()), false, nil +} + +type weekFunctionClass struct { + baseFunctionClass +} + +func (c *weekFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + + argTps := []types.EvalType{types.ETDatetime} + if len(args) == 2 { + argTps = append(argTps, types.ETInt) + } + + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, argTps...) + if err != nil { + return nil, err + } + bf.tp.SetFlen(2) + bf.tp.SetDecimal(0) + + var sig builtinFunc + if len(args) == 2 { + sig = &builtinWeekWithModeSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_WeekWithMode) + } else { + sig = &builtinWeekWithoutModeSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_WeekWithoutMode) + } + return sig, nil +} + +type builtinWeekWithModeSig struct { + baseBuiltinFunc +} + +func (b *builtinWeekWithModeSig) Clone() builtinFunc { + newSig := &builtinWeekWithModeSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals WEEK(date, mode). +// see: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_week +func (b *builtinWeekWithModeSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + date, isNull, err := b.args[0].EvalTime(ctx, row) + + if isNull || err != nil { + return 0, true, handleInvalidTimeError(ctx, err) + } + + if date.IsZero() || date.InvalidZero() { + return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, date.String())) + } + + mode, isNull, err := b.args[1].EvalInt(ctx, row) + if isNull || err != nil { + return 0, isNull, err + } + + week := date.Week(int(mode)) + return int64(week), false, nil +} + +type builtinWeekWithoutModeSig struct { + baseBuiltinFunc +} + +func (b *builtinWeekWithoutModeSig) Clone() builtinFunc { + newSig := &builtinWeekWithoutModeSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals WEEK(date). 
+// see: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_week +func (b *builtinWeekWithoutModeSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + date, isNull, err := b.args[0].EvalTime(ctx, row) + + if isNull || err != nil { + return 0, true, handleInvalidTimeError(ctx, err) + } + + if date.IsZero() || date.InvalidZero() { + return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, date.String())) + } + + mode := 0 + if modeStr := ctx.GetDefaultWeekFormatMode(); modeStr != "" { + mode, err = strconv.Atoi(modeStr) + if err != nil { + return 0, true, handleInvalidTimeError(ctx, types.ErrInvalidWeekModeFormat.GenWithStackByArgs(modeStr)) + } + } + + week := date.Week(mode) + return int64(week), false, nil +} + +type weekDayFunctionClass struct { + baseFunctionClass +} + +func (c *weekDayFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDatetime) + if err != nil { + return nil, err + } + bf.tp.SetFlen(1) + + sig := &builtinWeekDaySig{bf} + sig.setPbCode(tipb.ScalarFuncSig_WeekDay) + return sig, nil +} + +type builtinWeekDaySig struct { + baseBuiltinFunc +} + +func (b *builtinWeekDaySig) Clone() builtinFunc { + newSig := &builtinWeekDaySig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals WEEKDAY(date). +func (b *builtinWeekDaySig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + date, isNull, err := b.args[0].EvalTime(ctx, row) + if isNull || err != nil { + return 0, true, handleInvalidTimeError(ctx, err) + } + + if date.IsZero() || date.InvalidZero() { + return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, date.String())) + } + + return int64(date.Weekday()+6) % 7, false, nil +} + +type weekOfYearFunctionClass struct { + baseFunctionClass +} + +func (c *weekOfYearFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDatetime) + if err != nil { + return nil, err + } + bf.tp.SetFlen(2) + bf.tp.SetDecimal(0) + sig := &builtinWeekOfYearSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_WeekOfYear) + return sig, nil +} + +type builtinWeekOfYearSig struct { + baseBuiltinFunc +} + +func (b *builtinWeekOfYearSig) Clone() builtinFunc { + newSig := &builtinWeekOfYearSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals WEEKOFYEAR(date). 
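The implementation that follows computes WEEKOFYEAR as date.Week(3); mode 3 is the Monday-first ISO-8601 numbering, which the Go standard library also exposes directly. A minimal sketch (stdlib only, not part of the patch):

package main

import (
	"fmt"
	"time"
)

func main() {
	// WEEKOFYEAR('2024-08-07') — mode 3 matches ISO-8601 week numbering.
	_, week := time.Date(2024, 8, 7, 0, 0, 0, 0, time.UTC).ISOWeek()
	fmt.Println(week) // 32
}
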
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_weekofyear +func (b *builtinWeekOfYearSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + date, isNull, err := b.args[0].EvalTime(ctx, row) + + if isNull || err != nil { + return 0, true, handleInvalidTimeError(ctx, err) + } + + if date.IsZero() || date.InvalidZero() { + return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, date.String())) + } + + week := date.Week(3) + return int64(week), false, nil +} + +type yearFunctionClass struct { + baseFunctionClass +} + +func (c *yearFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDatetime) + if err != nil { + return nil, err + } + bf.tp.SetFlen(4) + bf.tp.SetDecimal(0) + sig := &builtinYearSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_Year) + return sig, nil +} + +type builtinYearSig struct { + baseBuiltinFunc +} + +func (b *builtinYearSig) Clone() builtinFunc { + newSig := &builtinYearSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals YEAR(date). +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_year +func (b *builtinYearSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + date, isNull, err := b.args[0].EvalTime(ctx, row) + + if isNull || err != nil { + return 0, true, handleInvalidTimeError(ctx, err) + } + return int64(date.Year()), false, nil +} + +type yearWeekFunctionClass struct { + baseFunctionClass +} + +func (c *yearWeekFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + argTps := []types.EvalType{types.ETDatetime} + if len(args) == 2 { + argTps = append(argTps, types.ETInt) + } + + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, argTps...) + if err != nil { + return nil, err + } + + bf.tp.SetFlen(6) + bf.tp.SetDecimal(0) + + var sig builtinFunc + if len(args) == 2 { + sig = &builtinYearWeekWithModeSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_YearWeekWithMode) + } else { + sig = &builtinYearWeekWithoutModeSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_YearWeekWithoutMode) + } + return sig, nil +} + +type builtinYearWeekWithModeSig struct { + baseBuiltinFunc +} + +func (b *builtinYearWeekWithModeSig) Clone() builtinFunc { + newSig := &builtinYearWeekWithModeSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals YEARWEEK(date,mode). 
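Both YEARWEEK signatures that follow pack the (year, week) pair returned by date.YearWeek into a single integer as year*100 + week. A trivial illustration (not part of the patch):

package main

import "fmt"

// yearWeek packs (year, week) the way YEARWEEK() does.
func yearWeek(year, week int) int64 {
	return int64(year)*100 + int64(week)
}

func main() {
	fmt.Println(yearWeek(2024, 32)) // 202432
}
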
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_yearweek +func (b *builtinYearWeekWithModeSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + date, isNull, err := b.args[0].EvalTime(ctx, row) + if isNull || err != nil { + return 0, isNull, handleInvalidTimeError(ctx, err) + } + if date.IsZero() || date.InvalidZero() { + return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, date.String())) + } + + mode, isNull, err := b.args[1].EvalInt(ctx, row) + if err != nil { + return 0, true, err + } + if isNull { + mode = 0 + } + + year, week := date.YearWeek(int(mode)) + result := int64(week + year*100) + if result < 0 { + return int64(math.MaxUint32), false, nil + } + return result, false, nil +} + +type builtinYearWeekWithoutModeSig struct { + baseBuiltinFunc +} + +func (b *builtinYearWeekWithoutModeSig) Clone() builtinFunc { + newSig := &builtinYearWeekWithoutModeSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals YEARWEEK(date). +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_yearweek +func (b *builtinYearWeekWithoutModeSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + date, isNull, err := b.args[0].EvalTime(ctx, row) + if isNull || err != nil { + return 0, true, handleInvalidTimeError(ctx, err) + } + + if date.InvalidZero() { + return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, date.String())) + } + + year, week := date.YearWeek(0) + result := int64(week + year*100) + if result < 0 { + return int64(math.MaxUint32), false, nil + } + return result, false, nil +} + +type fromUnixTimeFunctionClass struct { + baseFunctionClass +} + +func (c *fromUnixTimeFunctionClass) getFunction(ctx BuildContext, args []Expression) (sig builtinFunc, err error) { + if err = c.verifyArgs(args); err != nil { + return nil, err + } + + retTp, argTps := types.ETDatetime, make([]types.EvalType, 0, len(args)) + argTps = append(argTps, types.ETDecimal) + if len(args) == 2 { + retTp = types.ETString + argTps = append(argTps, types.ETString) + } + + arg0Tp := args[0].GetType(ctx.GetEvalCtx()) + isArg0Str := arg0Tp.EvalType() == types.ETString + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, retTp, argTps...) + if err != nil { + return nil, err + } + + if fieldString(arg0Tp.GetType()) { + //Improve string cast Unix Time precision + x, ok := (bf.getArgs()[0]).(*ScalarFunction) + if ok { + //used to adjust FromUnixTime precision #Fixbug35184 + if x.FuncName.L == ast.Cast { + if x.RetType.GetDecimal() == 0 && (x.RetType.GetType() == mysql.TypeNewDecimal) { + x.RetType.SetDecimal(6) + fieldLen := mathutil.Min(x.RetType.GetFlen()+6, mysql.MaxDecimalWidth) + x.RetType.SetFlen(fieldLen) + } + } + } + } + + if len(args) > 1 { + sig = &builtinFromUnixTime2ArgSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_FromUnixTime2Arg) + return sig, nil + } + + // Calculate the time fsp. 
+ fsp := types.MaxFsp + if !isArg0Str { + if arg0Tp.GetDecimal() != types.UnspecifiedLength { + fsp = mathutil.Min(bf.tp.GetDecimal(), arg0Tp.GetDecimal()) + } + } + bf.setDecimalAndFlenForDatetime(fsp) + + sig = &builtinFromUnixTime1ArgSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_FromUnixTime1Arg) + return sig, nil +} + +func evalFromUnixTime(ctx EvalContext, fsp int, unixTimeStamp *types.MyDecimal) (res types.Time, isNull bool, err error) { + // 0 <= unixTimeStamp <= 32536771199.999999 + if unixTimeStamp.IsNegative() { + return res, true, nil + } + integralPart, err := unixTimeStamp.ToInt() + if err != nil && !terror.ErrorEqual(err, types.ErrTruncated) && !terror.ErrorEqual(err, types.ErrOverflow) { + return res, true, err + } + // The max integralPart should not be larger than 32536771199. + // Refer to https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-28.html + if integralPart > 32536771199 { + return res, true, nil + } + // Split the integral part and fractional part of a decimal timestamp. + // e.g. for timestamp 12345.678, + // first get the integral part 12345, + // then (12345.678 - 12345) * (10^9) to get the decimal part and convert it to nanosecond precision. + integerDecimalTp := new(types.MyDecimal).FromInt(integralPart) + fracDecimalTp := new(types.MyDecimal) + err = types.DecimalSub(unixTimeStamp, integerDecimalTp, fracDecimalTp) + if err != nil { + return res, true, err + } + nano := new(types.MyDecimal).FromInt(int64(time.Second)) + x := new(types.MyDecimal) + err = types.DecimalMul(fracDecimalTp, nano, x) + if err != nil { + return res, true, err + } + fractionalPart, err := x.ToInt() // here fractionalPart is result multiplying the original fractional part by 10^9. + if err != nil && !terror.ErrorEqual(err, types.ErrTruncated) { + return res, true, err + } + if fsp < 0 { + fsp = types.MaxFsp + } + + tc := typeCtx(ctx) + tmp := time.Unix(integralPart, fractionalPart).In(tc.Location()) + t, err := convertTimeToMysqlTime(tmp, fsp, types.ModeHalfUp) + if err != nil { + return res, true, err + } + return t, false, nil +} + +// fieldString returns true if precision cannot be determined +func fieldString(fieldType byte) bool { + switch fieldType { + case mysql.TypeString, mysql.TypeVarchar, mysql.TypeTinyBlob, + mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeBlob: + return true + default: + return false + } +} + +type builtinFromUnixTime1ArgSig struct { + baseBuiltinFunc +} + +func (b *builtinFromUnixTime1ArgSig) Clone() builtinFunc { + newSig := &builtinFromUnixTime1ArgSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals a builtinFromUnixTime1ArgSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_from-unixtime +func (b *builtinFromUnixTime1ArgSig) evalTime(ctx EvalContext, row chunk.Row) (res types.Time, isNull bool, err error) { + unixTimeStamp, isNull, err := b.args[0].EvalDecimal(ctx, row) + if err != nil || isNull { + return res, isNull, err + } + return evalFromUnixTime(ctx, b.tp.GetDecimal(), unixTimeStamp) +} + +type builtinFromUnixTime2ArgSig struct { + baseBuiltinFunc +} + +func (b *builtinFromUnixTime2ArgSig) Clone() builtinFunc { + newSig := &builtinFromUnixTime2ArgSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalString evals a builtinFromUnixTime2ArgSig. 
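evalFromUnixTime above splits the decimal timestamp into whole seconds and a nanosecond remainder before converting through time.Unix. A rough stdlib-only sketch of that split, using float64 instead of MyDecimal, so the last few nanoseconds may differ (not part of the patch):

package main

import (
	"fmt"
	"time"
)

func main() {
	// FROM_UNIXTIME(12345.678): 12345 whole seconds plus 0.678s expressed in nanoseconds.
	ts := 12345.678
	sec := int64(ts)                         // 12345
	nsec := int64((ts - float64(sec)) * 1e9) // ≈ 678000000 (float rounding)
	fmt.Println(time.Unix(sec, nsec).UTC())  // ≈ 1970-01-01 03:25:45.678 +0000 UTC
}
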
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_from-unixtime +func (b *builtinFromUnixTime2ArgSig) evalString(ctx EvalContext, row chunk.Row) (res string, isNull bool, err error) { + format, isNull, err := b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return "", true, err + } + unixTimeStamp, isNull, err := b.args[0].EvalDecimal(ctx, row) + if err != nil || isNull { + return "", isNull, err + } + t, isNull, err := evalFromUnixTime(ctx, b.tp.GetDecimal(), unixTimeStamp) + if isNull || err != nil { + return "", isNull, err + } + res, err = t.DateFormat(format) + return res, err != nil, err +} + +type getFormatFunctionClass struct { + baseFunctionClass +} + +func (c *getFormatFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETString, types.ETString) + if err != nil { + return nil, err + } + bf.tp.SetFlen(17) + sig := &builtinGetFormatSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_GetFormat) + return sig, nil +} + +type builtinGetFormatSig struct { + baseBuiltinFunc +} + +func (b *builtinGetFormatSig) Clone() builtinFunc { + newSig := &builtinGetFormatSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalString evals a builtinGetFormatSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_get-format +func (b *builtinGetFormatSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { + t, isNull, err := b.args[0].EvalString(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + l, isNull, err := b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + + res := b.getFormat(t, l) + return res, false, nil +} + +type strToDateFunctionClass struct { + baseFunctionClass +} + +func (c *strToDateFunctionClass) getRetTp(ctx BuildContext, arg Expression) (tp byte, fsp int) { + tp = mysql.TypeDatetime + if _, ok := arg.(*Constant); !ok { + return tp, types.MaxFsp + } + strArg := WrapWithCastAsString(ctx, arg) + format, isNull, err := strArg.EvalString(ctx.GetEvalCtx(), chunk.Row{}) + if err != nil || isNull { + return + } + + isDuration, isDate := types.GetFormatType(format) + if isDuration && !isDate { + tp = mysql.TypeDuration + } else if !isDuration && isDate { + tp = mysql.TypeDate + } + if strings.Contains(format, "%f") { + fsp = types.MaxFsp + } + return +} + +// getFunction see https://dev.mysql.com/doc/refman/5.5/en/date-and-time-functions.html#function_str-to-date +func (c *strToDateFunctionClass) getFunction(ctx BuildContext, args []Expression) (sig builtinFunc, err error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + retTp, fsp := c.getRetTp(ctx, args[1]) + switch retTp { + case mysql.TypeDate: + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, types.ETString, types.ETString) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForDate() + sig = &builtinStrToDateDateSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_StrToDateDate) + case mysql.TypeDatetime: + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, types.ETString, types.ETString) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForDatetime(fsp) + sig = &builtinStrToDateDatetimeSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_StrToDateDatetime) + case mysql.TypeDuration: + bf, err := 
newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDuration, types.ETString, types.ETString) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForTime(fsp) + sig = &builtinStrToDateDurationSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_StrToDateDuration) + } + return sig, nil +} + +type builtinStrToDateDateSig struct { + baseBuiltinFunc +} + +func (b *builtinStrToDateDateSig) Clone() builtinFunc { + newSig := &builtinStrToDateDateSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinStrToDateDateSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + date, isNull, err := b.args[0].EvalString(ctx, row) + if isNull || err != nil { + return types.ZeroTime, isNull, err + } + format, isNull, err := b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return types.ZeroTime, isNull, err + } + var t types.Time + tc := typeCtx(ctx) + succ := t.StrToDate(tc, date, format) + if !succ { + return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, t.String())) + } + if sqlMode(ctx).HasNoZeroDateMode() && (t.Year() == 0 || t.Month() == 0 || t.Day() == 0) { + return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrWrongValueForType.GenWithStackByArgs(types.DateTimeStr, date, ast.StrToDate)) + } + t.SetType(mysql.TypeDate) + t.SetFsp(types.MinFsp) + return t, false, nil +} + +type builtinStrToDateDatetimeSig struct { + baseBuiltinFunc +} + +func (b *builtinStrToDateDatetimeSig) Clone() builtinFunc { + newSig := &builtinStrToDateDatetimeSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinStrToDateDatetimeSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + date, isNull, err := b.args[0].EvalString(ctx, row) + if isNull || err != nil { + return types.ZeroTime, isNull, err + } + format, isNull, err := b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return types.ZeroTime, isNull, err + } + var t types.Time + tc := typeCtx(ctx) + succ := t.StrToDate(tc, date, format) + if !succ { + return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, t.String())) + } + if sqlMode(ctx).HasNoZeroDateMode() && (t.Year() == 0 || t.Month() == 0 || t.Day() == 0) { + return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, t.String())) + } + t.SetType(mysql.TypeDatetime) + t.SetFsp(b.tp.GetDecimal()) + return t, false, nil +} + +type builtinStrToDateDurationSig struct { + baseBuiltinFunc +} + +func (b *builtinStrToDateDurationSig) Clone() builtinFunc { + newSig := &builtinStrToDateDurationSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalDuration +// TODO: If the NO_ZERO_DATE or NO_ZERO_IN_DATE SQL mode is enabled, zero dates or part of dates are disallowed. +// In that case, STR_TO_DATE() returns NULL and generates a warning. 
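getRetTp above chooses the result type of STR_TO_DATE from the format constant: a duration-only format yields TIME, a date-only format yields DATE, everything else yields DATETIME, and a "%f" specifier forces the maximum fsp. A rough standalone illustration of that decision rule; classifyFormat is a hypothetical, simplified stand-in for types.GetFormatType (not part of the patch):

package main

import (
	"fmt"
	"strings"
)

// classifyFormat is a simplified stand-in for types.GetFormatType: it only
// checks a few common specifiers to decide whether the format carries date
// parts, time parts, or both.
func classifyFormat(format string) string {
	isDate := strings.Contains(format, "%Y") || strings.Contains(format, "%m") || strings.Contains(format, "%d")
	isTime := strings.Contains(format, "%H") || strings.Contains(format, "%i") || strings.Contains(format, "%s")
	switch {
	case isTime && !isDate:
		return "DURATION"
	case isDate && !isTime:
		return "DATE"
	default:
		return "DATETIME"
	}
}

func main() {
	fmt.Println(classifyFormat("%H:%i:%s"))          // DURATION
	fmt.Println(classifyFormat("%Y-%m-%d"))          // DATE
	fmt.Println(classifyFormat("%Y-%m-%d %H:%i:%s")) // DATETIME
}
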
+func (b *builtinStrToDateDurationSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { + date, isNull, err := b.args[0].EvalString(ctx, row) + if isNull || err != nil { + return types.Duration{}, isNull, err + } + format, isNull, err := b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return types.Duration{}, isNull, err + } + var t types.Time + tc := typeCtx(ctx) + succ := t.StrToDate(tc, date, format) + if !succ { + return types.Duration{}, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, t.String())) + } + t.SetFsp(b.tp.GetDecimal()) + dur, err := t.ConvertToDuration() + return dur, err != nil, err +} + +type sysDateFunctionClass struct { + baseFunctionClass +} + +func (c *sysDateFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + fsp, err := getFspByIntArg(ctx, args) + if err != nil { + return nil, err + } + var argTps = make([]types.EvalType, 0) + if len(args) == 1 { + argTps = append(argTps, types.ETInt) + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, argTps...) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForDatetime(fsp) + // Illegal parameters have been filtered out in the parser, so the result is always not null. + bf.tp.SetFlag(bf.tp.GetFlag() | mysql.NotNullFlag) + + var sig builtinFunc + if len(args) == 1 { + sig = &builtinSysDateWithFspSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_SysDateWithFsp) + } else { + sig = &builtinSysDateWithoutFspSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_SysDateWithoutFsp) + } + return sig, nil +} + +type builtinSysDateWithFspSig struct { + baseBuiltinFunc +} + +func (b *builtinSysDateWithFspSig) Clone() builtinFunc { + newSig := &builtinSysDateWithFspSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals SYSDATE(fsp). +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_sysdate +func (b *builtinSysDateWithFspSig) evalTime(ctx EvalContext, row chunk.Row) (val types.Time, isNull bool, err error) { + fsp, isNull, err := b.args[0].EvalInt(ctx, row) + if isNull || err != nil { + return types.ZeroTime, isNull, err + } + + loc := location(ctx) + now := time.Now().In(loc) + result, err := convertTimeToMysqlTime(now, int(fsp), types.ModeHalfUp) + if err != nil { + return types.ZeroTime, true, err + } + return result, false, nil +} + +type builtinSysDateWithoutFspSig struct { + baseBuiltinFunc +} + +func (b *builtinSysDateWithoutFspSig) Clone() builtinFunc { + newSig := &builtinSysDateWithoutFspSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals SYSDATE(). 
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_sysdate +func (b *builtinSysDateWithoutFspSig) evalTime(ctx EvalContext, row chunk.Row) (val types.Time, isNull bool, err error) { + tz := location(ctx) + now := time.Now().In(tz) + result, err := convertTimeToMysqlTime(now, 0, types.ModeHalfUp) + if err != nil { + return types.ZeroTime, true, err + } + return result, false, nil +} + +type currentDateFunctionClass struct { + baseFunctionClass +} + +func (c *currentDateFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForDate() + sig := &builtinCurrentDateSig{bf} + return sig, nil +} + +type builtinCurrentDateSig struct { + baseBuiltinFunc +} + +func (b *builtinCurrentDateSig) Clone() builtinFunc { + newSig := &builtinCurrentDateSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals CURDATE(). +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_curdate +func (b *builtinCurrentDateSig) evalTime(ctx EvalContext, row chunk.Row) (val types.Time, isNull bool, err error) { + tz := location(ctx) + nowTs, err := getStmtTimestamp(ctx) + if err != nil { + return types.ZeroTime, true, err + } + year, month, day := nowTs.In(tz).Date() + result := types.NewTime(types.FromDate(year, int(month), day, 0, 0, 0, 0), mysql.TypeDate, 0) + return result, false, nil +} + +type currentTimeFunctionClass struct { + baseFunctionClass +} + +func (c *currentTimeFunctionClass) getFunction(ctx BuildContext, args []Expression) (sig builtinFunc, err error) { + if err = c.verifyArgs(args); err != nil { + return nil, err + } + + fsp, err := getFspByIntArg(ctx, args) + if err != nil { + return nil, err + } + var argTps = make([]types.EvalType, 0) + if len(args) == 1 { + argTps = append(argTps, types.ETInt) + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDuration, argTps...) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForTime(fsp) + // 1. no sign. + // 2. hour is in the 2-digit range. 
+ bf.tp.SetFlen(bf.tp.GetFlen() - 2) + if len(args) == 0 { + sig = &builtinCurrentTime0ArgSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_CurrentTime0Arg) + return sig, nil + } + sig = &builtinCurrentTime1ArgSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_CurrentTime1Arg) + return sig, nil +} + +type builtinCurrentTime0ArgSig struct { + baseBuiltinFunc +} + +func (b *builtinCurrentTime0ArgSig) Clone() builtinFunc { + newSig := &builtinCurrentTime0ArgSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinCurrentTime0ArgSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { + tz := location(ctx) + nowTs, err := getStmtTimestamp(ctx) + if err != nil { + return types.Duration{}, true, err + } + dur := nowTs.In(tz).Format(types.TimeFormat) + res, _, err := types.ParseDuration(typeCtx(ctx), dur, types.MinFsp) + if err != nil { + return types.Duration{}, true, err + } + return res, false, nil +} + +type builtinCurrentTime1ArgSig struct { + baseBuiltinFunc +} + +func (b *builtinCurrentTime1ArgSig) Clone() builtinFunc { + newSig := &builtinCurrentTime1ArgSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinCurrentTime1ArgSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { + fsp, _, err := b.args[0].EvalInt(ctx, row) + if err != nil { + return types.Duration{}, true, err + } + tz := location(ctx) + nowTs, err := getStmtTimestamp(ctx) + if err != nil { + return types.Duration{}, true, err + } + dur := nowTs.In(tz).Format(types.TimeFSPFormat) + tc := typeCtx(ctx) + res, _, err := types.ParseDuration(tc, dur, int(fsp)) + if err != nil { + return types.Duration{}, true, err + } + return res, false, nil +} + +type timeFunctionClass struct { + baseFunctionClass +} + +func (c *timeFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + err := c.verifyArgs(args) + if err != nil { + return nil, err + } + fsp, err := getExpressionFsp(ctx, args[0]) + if err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDuration, types.ETString) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForTime(fsp) + sig := &builtinTimeSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_Time) + return sig, nil +} + +type builtinTimeSig struct { + baseBuiltinFunc +} + +func (b *builtinTimeSig) Clone() builtinFunc { + newSig := &builtinTimeSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalDuration evals a builtinTimeSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_time. 
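The evalDuration body of builtinTimeSig that follows derives the fsp of TIME(expr) from the number of characters after the decimal point and then validates it via types.CheckFsp. A minimal sketch of the derivation (stdlib only; the cap at 6 here stands in for whatever types.CheckFsp enforces; not part of the patch):

package main

import (
	"fmt"
	"strings"
)

// fspOf returns the fractional-second precision implied by a time literal.
func fspOf(expr string) int {
	idx := strings.Index(expr, ".")
	if idx == -1 {
		return 0
	}
	fsp := len(expr) - idx - 1
	if fsp > 6 { // the patch routes this through types.CheckFsp instead
		fsp = 6
	}
	return fsp
}

func main() {
	fmt.Println(fspOf("10:11:12"))     // 0
	fmt.Println(fspOf("10:11:12.345")) // 3
}
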
+func (b *builtinTimeSig) evalDuration(ctx EvalContext, row chunk.Row) (res types.Duration, isNull bool, err error) { + expr, isNull, err := b.args[0].EvalString(ctx, row) + if isNull || err != nil { + return res, isNull, err + } + + fsp := 0 + if idx := strings.Index(expr, "."); idx != -1 { + fsp = len(expr) - idx - 1 + } + + var tmpFsp int + if tmpFsp, err = types.CheckFsp(fsp); err != nil { + return res, isNull, err + } + fsp = tmpFsp + + tc := typeCtx(ctx) + res, _, err = types.ParseDuration(tc, expr, fsp) + if types.ErrTruncatedWrongVal.Equal(err) { + err = tc.HandleTruncate(err) + } + return res, isNull, err +} + +type timeLiteralFunctionClass struct { + baseFunctionClass +} + +func (c *timeLiteralFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + con, ok := args[0].(*Constant) + if !ok { + panic("Unexpected parameter for time literal") + } + dt, err := con.Eval(ctx.GetEvalCtx(), chunk.Row{}) + if err != nil { + return nil, err + } + str := dt.GetString() + if !isDuration(str) { + return nil, types.ErrWrongValue.GenWithStackByArgs(types.TimeStr, str) + } + duration, _, err := types.ParseDuration(ctx.GetEvalCtx().TypeCtx(), str, types.GetFsp(str)) + if err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, []Expression{}, types.ETDuration) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForTime(duration.Fsp) + sig := &builtinTimeLiteralSig{bf, duration} + return sig, nil +} + +type builtinTimeLiteralSig struct { + baseBuiltinFunc + duration types.Duration +} + +func (b *builtinTimeLiteralSig) Clone() builtinFunc { + newSig := &builtinTimeLiteralSig{duration: b.duration} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalDuration evals TIME 'stringLit'. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-literals.html +func (b *builtinTimeLiteralSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { + return b.duration, false, nil +} + +type utcDateFunctionClass struct { + baseFunctionClass +} + +func (c *utcDateFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForDate() + sig := &builtinUTCDateSig{bf} + return sig, nil +} + +type builtinUTCDateSig struct { + baseBuiltinFunc +} + +func (b *builtinUTCDateSig) Clone() builtinFunc { + newSig := &builtinUTCDateSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals UTC_DATE, UTC_DATE(). 
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_utc-date +func (b *builtinUTCDateSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + nowTs, err := getStmtTimestamp(ctx) + if err != nil { + return types.ZeroTime, true, err + } + year, month, day := nowTs.UTC().Date() + result := types.NewTime(types.FromGoTime(time.Date(year, month, day, 0, 0, 0, 0, time.UTC)), mysql.TypeDate, types.UnspecifiedFsp) + return result, false, nil +} + +type utcTimestampFunctionClass struct { + baseFunctionClass +} + +func (c *utcTimestampFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + argTps := make([]types.EvalType, 0, 1) + if len(args) == 1 { + argTps = append(argTps, types.ETInt) + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, argTps...) + if err != nil { + return nil, err + } + + fsp, err := getFspByIntArg(ctx, args) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForDatetime(fsp) + var sig builtinFunc + if len(args) == 1 { + sig = &builtinUTCTimestampWithArgSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_UTCTimestampWithArg) + } else { + sig = &builtinUTCTimestampWithoutArgSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_UTCTimestampWithoutArg) + } + return sig, nil +} + +func evalUTCTimestampWithFsp(ctx EvalContext, fsp int) (types.Time, bool, error) { + nowTs, err := getStmtTimestamp(ctx) + if err != nil { + return types.ZeroTime, true, err + } + result, err := convertTimeToMysqlTime(nowTs.UTC(), fsp, types.ModeHalfUp) + if err != nil { + return types.ZeroTime, true, err + } + return result, false, nil +} + +type builtinUTCTimestampWithArgSig struct { + baseBuiltinFunc +} + +func (b *builtinUTCTimestampWithArgSig) Clone() builtinFunc { + newSig := &builtinUTCTimestampWithArgSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals UTC_TIMESTAMP(fsp). +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_utc-timestamp +func (b *builtinUTCTimestampWithArgSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + num, isNull, err := b.args[0].EvalInt(ctx, row) + if err != nil { + return types.ZeroTime, true, err + } + + if !isNull && num > int64(types.MaxFsp) { + return types.ZeroTime, true, errors.Errorf("Too-big precision %v specified for 'utc_timestamp'. Maximum is %v", num, types.MaxFsp) + } + if !isNull && num < int64(types.MinFsp) { + return types.ZeroTime, true, errors.Errorf("Invalid negative %d specified, must in [0, 6]", num) + } + + result, isNull, err := evalUTCTimestampWithFsp(ctx, int(num)) + return result, isNull, err +} + +type builtinUTCTimestampWithoutArgSig struct { + baseBuiltinFunc +} + +func (b *builtinUTCTimestampWithoutArgSig) Clone() builtinFunc { + newSig := &builtinUTCTimestampWithoutArgSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals UTC_TIMESTAMP(). 
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_utc-timestamp +func (b *builtinUTCTimestampWithoutArgSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + result, isNull, err := evalUTCTimestampWithFsp(ctx, 0) + return result, isNull, err +} + +type nowFunctionClass struct { + baseFunctionClass +} + +func (c *nowFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + argTps := make([]types.EvalType, 0, 1) + if len(args) == 1 { + argTps = append(argTps, types.ETInt) + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, argTps...) + if err != nil { + return nil, err + } + + fsp, err := getFspByIntArg(ctx, args) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForDatetime(fsp) + + var sig builtinFunc + if len(args) == 1 { + sig = &builtinNowWithArgSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_NowWithArg) + } else { + sig = &builtinNowWithoutArgSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_NowWithoutArg) + } + return sig, nil +} + +// GetStmtTimestamp directly calls getTimeZone with timezone +func GetStmtTimestamp(ctx EvalContext) (time.Time, error) { + tz := getTimeZone(ctx) + tVal, err := getStmtTimestamp(ctx) + if err != nil { + return tVal, err + } + return tVal.In(tz), nil +} + +func evalNowWithFsp(ctx EvalContext, fsp int) (types.Time, bool, error) { + nowTs, err := getStmtTimestamp(ctx) + if err != nil { + return types.ZeroTime, true, err + } + + failpoint.Inject("injectNow", func(val failpoint.Value) { + nowTs = time.Unix(int64(val.(int)), 0) + }) + + // In MySQL's implementation, now() will truncate the result instead of rounding it. + // Results below are from MySQL 5.7, which can prove it. + // mysql> select now(6), now(3), now(); + // +----------------------------+-------------------------+---------------------+ + // | now(6) | now(3) | now() | + // +----------------------------+-------------------------+---------------------+ + // | 2019-03-25 15:57:56.612966 | 2019-03-25 15:57:56.612 | 2019-03-25 15:57:56 | + // +----------------------------+-------------------------+---------------------+ + result, err := convertTimeToMysqlTime(nowTs, fsp, types.ModeTruncate) + if err != nil { + return types.ZeroTime, true, err + } + + err = result.ConvertTimeZone(time.Local, location(ctx)) + if err != nil { + return types.ZeroTime, true, err + } + + return result, false, nil +} + +type builtinNowWithArgSig struct { + baseBuiltinFunc +} + +func (b *builtinNowWithArgSig) Clone() builtinFunc { + newSig := &builtinNowWithArgSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals NOW(fsp) +// see: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_now +func (b *builtinNowWithArgSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + fsp, isNull, err := b.args[0].EvalInt(ctx, row) + + if err != nil { + return types.ZeroTime, true, err + } + + if isNull { + fsp = 0 + } else if fsp > int64(types.MaxFsp) { + return types.ZeroTime, true, errors.Errorf("Too-big precision %v specified for 'now'. 
Maximum is %v", fsp, types.MaxFsp) + } else if fsp < int64(types.MinFsp) { + return types.ZeroTime, true, errors.Errorf("Invalid negative %d specified, must in [0, 6]", fsp) + } + + result, isNull, err := evalNowWithFsp(ctx, int(fsp)) + return result, isNull, err +} + +type builtinNowWithoutArgSig struct { + baseBuiltinFunc +} + +func (b *builtinNowWithoutArgSig) Clone() builtinFunc { + newSig := &builtinNowWithoutArgSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals NOW() +// see: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_now +func (b *builtinNowWithoutArgSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + result, isNull, err := evalNowWithFsp(ctx, 0) + return result, isNull, err +} + +type extractFunctionClass struct { + baseFunctionClass +} + +func (c *extractFunctionClass) getFunction(ctx BuildContext, args []Expression) (sig builtinFunc, err error) { + if err = c.verifyArgs(args); err != nil { + return nil, err + } + + args[0] = WrapWithCastAsString(ctx, args[0]) + unit, _, err := args[0].EvalString(ctx.GetEvalCtx(), chunk.Row{}) + if err != nil { + return nil, err + } + isClockUnit := types.IsClockUnit(unit) + isDateUnit := types.IsDateUnit(unit) + var bf baseBuiltinFunc + if isClockUnit && isDateUnit { + // For unit DAY_MICROSECOND/DAY_SECOND/DAY_MINUTE/DAY_HOUR, the interpretation of the second argument depends on its evaluation type: + // 1. Datetime/timestamp are interpreted as datetime. For example: + // extract(day_second from datetime('2001-01-01 02:03:04')) = 120304 + // Note that MySQL 5.5+ has a bug of no day portion in the result (20304) for this case, see https://bugs.mysql.com/bug.php?id=73240. + // 2. Time is interpreted as is. For example: + // extract(day_second from time('02:03:04')) = 20304 + // Note that time shouldn't be implicitly cast to datetime, or else the date portion will be padded with the current date and this will adjust time portion accordingly. + // 3. Otherwise, string/int/float are interpreted as arbitrarily either datetime or time, depending on which fits. For example: + // extract(day_second from '2001-01-01 02:03:04') = 1020304 // datetime + // extract(day_second from 20010101020304) = 1020304 // datetime + // extract(day_second from '01 02:03:04') = 260304 // time + if args[1].GetType(ctx.GetEvalCtx()).EvalType() == types.ETDatetime || args[1].GetType(ctx.GetEvalCtx()).EvalType() == types.ETTimestamp { + bf, err = newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETString, types.ETDatetime) + if err != nil { + return nil, err + } + sig = &builtinExtractDatetimeSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_ExtractDatetime) + } else if args[1].GetType(ctx.GetEvalCtx()).EvalType() == types.ETDuration { + bf, err = newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETString, types.ETDuration) + if err != nil { + return nil, err + } + sig = &builtinExtractDurationSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_ExtractDuration) + } else { + bf, err = newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETString, types.ETString) + if err != nil { + return nil, err + } + bf.args[1].GetType(ctx.GetEvalCtx()).SetDecimal(int(types.MaxFsp)) + sig = &builtinExtractDatetimeFromStringSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_ExtractDatetimeFromString) + } + } else if isClockUnit { + // Clock units interpret the second argument as time. 
+ bf, err = newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETString, types.ETDuration) + if err != nil { + return nil, err + } + sig = &builtinExtractDurationSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_ExtractDuration) + } else { + // Date units interpret the second argument as datetime. + bf, err = newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETString, types.ETDatetime) + if err != nil { + return nil, err + } + sig = &builtinExtractDatetimeSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_ExtractDatetime) + } + return sig, nil +} + +type builtinExtractDatetimeFromStringSig struct { + baseBuiltinFunc +} + +func (b *builtinExtractDatetimeFromStringSig) Clone() builtinFunc { + newSig := &builtinExtractDatetimeFromStringSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals a builtinExtractDatetimeFromStringSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_extract +func (b *builtinExtractDatetimeFromStringSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + unit, isNull, err := b.args[0].EvalString(ctx, row) + if isNull || err != nil { + return 0, isNull, err + } + dtStr, isNull, err := b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return 0, isNull, err + } + tc := typeCtx(ctx) + if types.IsClockUnit(unit) && types.IsDateUnit(unit) { + dur, _, err := types.ParseDuration(tc, dtStr, types.GetFsp(dtStr)) + if err != nil { + return 0, true, err + } + res, err := types.ExtractDurationNum(&dur, unit) + if err != nil { + return 0, true, err + } + dt, err := types.ParseDatetime(tc, dtStr) + if err != nil { + return res, false, nil + } + if dt.Hour() == dur.Hour() && dt.Minute() == dur.Minute() && dt.Second() == dur.Second() && dt.Year() > 0 { + res, err = types.ExtractDatetimeNum(&dt, unit) + } + return res, err != nil, err + } + + panic("Unexpected unit for extract") +} + +type builtinExtractDatetimeSig struct { + baseBuiltinFunc +} + +func (b *builtinExtractDatetimeSig) Clone() builtinFunc { + newSig := &builtinExtractDatetimeSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals a builtinExtractDatetimeSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_extract +func (b *builtinExtractDatetimeSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + unit, isNull, err := b.args[0].EvalString(ctx, row) + if isNull || err != nil { + return 0, isNull, err + } + dt, isNull, err := b.args[1].EvalTime(ctx, row) + if isNull || err != nil { + return 0, isNull, err + } + res, err := types.ExtractDatetimeNum(&dt, unit) + return res, err != nil, err +} + +type builtinExtractDurationSig struct { + baseBuiltinFunc +} + +func (b *builtinExtractDurationSig) Clone() builtinFunc { + newSig := &builtinExtractDurationSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals a builtinExtractDurationSig. 
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_extract +func (b *builtinExtractDurationSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + unit, isNull, err := b.args[0].EvalString(ctx, row) + if isNull || err != nil { + return 0, isNull, err + } + dur, isNull, err := b.args[1].EvalDuration(ctx, row) + if isNull || err != nil { + return 0, isNull, err + } + res, err := types.ExtractDurationNum(&dur, unit) + return res, err != nil, err +} + +// baseDateArithmetical is the base class for all "builtinAddDateXXXSig" and "builtinSubDateXXXSig", +// which provides parameter getter and date arithmetical calculate functions. +type baseDateArithmetical struct { + // intervalRegexp is "*Regexp" used to extract string interval for "DAY" unit. + intervalRegexp *regexp.Regexp +} + +func newDateArithmeticalUtil() baseDateArithmetical { + return baseDateArithmetical{ + intervalRegexp: regexp.MustCompile(`^[+-]?[\d]+`), + } +} + +func (du *baseDateArithmetical) getDateFromString(ctx EvalContext, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) { + dateStr, isNull, err := args[0].EvalString(ctx, row) + if isNull || err != nil { + return types.ZeroTime, true, err + } + + dateTp := mysql.TypeDate + if !types.IsDateFormat(dateStr) || types.IsClockUnit(unit) { + dateTp = mysql.TypeDatetime + } + + tc := typeCtx(ctx) + date, err := types.ParseTime(tc, dateStr, dateTp, types.MaxFsp) + if err != nil { + err = handleInvalidTimeError(ctx, err) + if err != nil { + return types.ZeroTime, true, err + } + return date, true, handleInvalidTimeError(ctx, err) + } else if sqlMode(ctx).HasNoZeroDateMode() && (date.Year() == 0 || date.Month() == 0 || date.Day() == 0) { + return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, dateStr)) + } + return date, false, handleInvalidTimeError(ctx, err) +} + +func (du *baseDateArithmetical) getDateFromInt(ctx EvalContext, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) { + dateInt, isNull, err := args[0].EvalInt(ctx, row) + if isNull || err != nil { + return types.ZeroTime, true, err + } + + tc := typeCtx(ctx) + date, err := types.ParseTimeFromInt64(tc, dateInt) + if err != nil { + return types.ZeroTime, true, handleInvalidTimeError(ctx, err) + } + + // The actual date.Type() might be date or datetime. + // When the unit contains clock, the date part is treated as datetime even though it might be actually a date. + if types.IsClockUnit(unit) { + date.SetType(mysql.TypeDatetime) + } + return date, false, nil +} + +func (du *baseDateArithmetical) getDateFromReal(ctx EvalContext, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) { + dateReal, isNull, err := args[0].EvalReal(ctx, row) + if isNull || err != nil { + return types.ZeroTime, true, err + } + + tc := typeCtx(ctx) + date, err := types.ParseTimeFromFloat64(tc, dateReal) + if err != nil { + return types.ZeroTime, true, handleInvalidTimeError(ctx, err) + } + + // The actual date.Type() might be date or datetime. + // When the unit contains clock, the date part is treated as datetime even though it might be actually a date. 
+ if types.IsClockUnit(unit) { + date.SetType(mysql.TypeDatetime) + } + return date, false, nil +} + +func (du *baseDateArithmetical) getDateFromDecimal(ctx EvalContext, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) { + dateDec, isNull, err := args[0].EvalDecimal(ctx, row) + if isNull || err != nil { + return types.ZeroTime, true, err + } + + tc := typeCtx(ctx) + date, err := types.ParseTimeFromDecimal(tc, dateDec) + if err != nil { + return types.ZeroTime, true, handleInvalidTimeError(ctx, err) + } + + // The actual date.Type() might be date or datetime. + // When the unit contains clock, the date part is treated as datetime even though it might be actually a date. + if types.IsClockUnit(unit) { + date.SetType(mysql.TypeDatetime) + } + return date, false, nil +} + +func (du *baseDateArithmetical) getDateFromDatetime(ctx EvalContext, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) { + date, isNull, err := args[0].EvalTime(ctx, row) + if isNull || err != nil { + return types.ZeroTime, true, err + } + + // The actual date.Type() might be date, datetime or timestamp. + // Datetime is treated as is. + // Timestamp is treated as datetime, as MySQL manual says: https://dev.mysql.com/doc/refman/8.0/en/date-and-time-functions.html#function_date-add + // When the unit contains clock, the date part is treated as datetime even though it might be actually a date. + if types.IsClockUnit(unit) || date.Type() == mysql.TypeTimestamp { + date.SetType(mysql.TypeDatetime) + } + return date, false, nil +} + +func (du *baseDateArithmetical) getIntervalFromString(ctx EvalContext, args []Expression, row chunk.Row, unit string) (string, bool, error) { + interval, isNull, err := args[1].EvalString(ctx, row) + if isNull || err != nil { + return "", true, err + } + + ec := errCtx(ctx) + interval, err = du.intervalReformatString(ec, interval, unit) + return interval, false, err +} + +func (du *baseDateArithmetical) intervalReformatString(ec errctx.Context, str string, unit string) (interval string, err error) { + switch strings.ToUpper(unit) { + case "MICROSECOND", "MINUTE", "HOUR", "DAY", "WEEK", "MONTH", "QUARTER", "YEAR": + str = strings.TrimSpace(str) + // a single unit value has to be specially handled. 
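For the single units listed above (MICROSECOND, MINUTE, HOUR, DAY, WEEK, MONTH, QUARTER, YEAR), only the leading signed integer of a string interval is kept — that is what the `^[+-]?[\d]+` regexp and the FindString call that follows do — while SECOND parses the full decimal, so INTERVAL '1e2' MINUTE is 1 MINUTE but INTERVAL '1e2' SECOND is 100 SECOND. A small stdlib sketch of the two behaviours (not part of the patch):

package main

import (
	"fmt"
	"regexp"
	"strconv"
)

var leadingInt = regexp.MustCompile(`^[+-]?[\d]+`)

// leadingInterval mimics the DAY/HOUR/... handling: keep only the leading
// signed integer and fall back to "0" when there is none.
func leadingInterval(s string) string {
	if m := leadingInt.FindString(s); m != "" {
		return m
	}
	return "0"
}

func main() {
	fmt.Println(leadingInterval("1e2")) // "1"  -> INTERVAL '1e2' MINUTE is 1 MINUTE
	fmt.Println(leadingInterval("abc")) // "0"

	// SECOND keeps the full decimal value instead.
	f, _ := strconv.ParseFloat("1e2", 64)
	fmt.Println(f) // 100 -> INTERVAL '1e2' SECOND is 100 SECOND
}
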
+ interval = du.intervalRegexp.FindString(str) + if interval == "" { + interval = "0" + } + + if interval != str { + err = ec.HandleError(types.ErrTruncatedWrongVal.GenWithStackByArgs("DECIMAL", str)) + } + case "SECOND": + // The unit SECOND is specially handled, for example: + // date + INTERVAL "1e2" SECOND = date + INTERVAL 100 second + // date + INTERVAL "1.6" SECOND = date + INTERVAL 1.6 second + // But: + // date + INTERVAL "1e2" MINUTE = date + INTERVAL 1 MINUTE + // date + INTERVAL "1.6" MINUTE = date + INTERVAL 1 MINUTE + var dec types.MyDecimal + if err = dec.FromString([]byte(str)); err != nil { + truncatedErr := types.ErrTruncatedWrongVal.GenWithStackByArgs("DECIMAL", str) + err = ec.HandleErrorWithAlias(err, truncatedErr, truncatedErr) + } + interval = string(dec.ToString()) + default: + interval = str + } + return interval, err +} + +func (du *baseDateArithmetical) intervalDecimalToString(ec errctx.Context, dec *types.MyDecimal) (string, error) { + var rounded types.MyDecimal + err := dec.Round(&rounded, 0, types.ModeHalfUp) + if err != nil { + return "", err + } + + intVal, err := rounded.ToInt() + if err != nil { + if err = ec.HandleError(types.ErrTruncatedWrongVal.GenWithStackByArgs("DECIMAL", dec.String())); err != nil { + return "", err + } + } + + return strconv.FormatInt(intVal, 10), nil +} + +func (du *baseDateArithmetical) getIntervalFromDecimal(ctx EvalContext, args []Expression, row chunk.Row, unit string) (string, bool, error) { + decimal, isNull, err := args[1].EvalDecimal(ctx, row) + if isNull || err != nil { + return "", true, err + } + interval := decimal.String() + + switch strings.ToUpper(unit) { + case "HOUR_MINUTE", "MINUTE_SECOND", "YEAR_MONTH", "DAY_HOUR", "DAY_MINUTE", + "DAY_SECOND", "DAY_MICROSECOND", "HOUR_MICROSECOND", "HOUR_SECOND", "MINUTE_MICROSECOND", "SECOND_MICROSECOND": + neg := false + if interval != "" && interval[0] == '-' { + neg = true + interval = interval[1:] + } + switch strings.ToUpper(unit) { + case "HOUR_MINUTE", "MINUTE_SECOND": + interval = strings.ReplaceAll(interval, ".", ":") + case "YEAR_MONTH": + interval = strings.ReplaceAll(interval, ".", "-") + case "DAY_HOUR": + interval = strings.ReplaceAll(interval, ".", " ") + case "DAY_MINUTE": + interval = "0 " + strings.ReplaceAll(interval, ".", ":") + case "DAY_SECOND": + interval = "0 00:" + strings.ReplaceAll(interval, ".", ":") + case "DAY_MICROSECOND": + interval = "0 00:00:" + interval + case "HOUR_MICROSECOND": + interval = "00:00:" + interval + case "HOUR_SECOND": + interval = "00:" + strings.ReplaceAll(interval, ".", ":") + case "MINUTE_MICROSECOND": + interval = "00:" + interval + case "SECOND_MICROSECOND": + /* keep interval as original decimal */ + } + if neg { + interval = "-" + interval + } + case "SECOND": + // interval is already like the %f format. 
+ default: + // YEAR, QUARTER, MONTH, WEEK, DAY, HOUR, MINUTE, MICROSECOND + ec := errCtx(ctx) + interval, err = du.intervalDecimalToString(ec, decimal) + if err != nil { + return "", true, err + } + } + + return interval, false, nil +} + +func (du *baseDateArithmetical) getIntervalFromInt(ctx EvalContext, args []Expression, row chunk.Row, unit string) (string, bool, error) { + interval, isNull, err := args[1].EvalInt(ctx, row) + if isNull || err != nil { + return "", true, err + } + + if mysql.HasUnsignedFlag(args[1].GetType(ctx).GetFlag()) { + return strconv.FormatUint(uint64(interval), 10), false, nil + } + + return strconv.FormatInt(interval, 10), false, nil +} + +func (du *baseDateArithmetical) getIntervalFromReal(ctx EvalContext, args []Expression, row chunk.Row, unit string) (string, bool, error) { + interval, isNull, err := args[1].EvalReal(ctx, row) + if isNull || err != nil { + return "", true, err + } + return strconv.FormatFloat(interval, 'f', args[1].GetType(ctx).GetDecimal(), 64), false, nil +} + +func (du *baseDateArithmetical) add(ctx EvalContext, date types.Time, interval, unit string, resultFsp int) (types.Time, bool, error) { + year, month, day, nano, _, err := types.ParseDurationValue(unit, interval) + if err := handleInvalidTimeError(ctx, err); err != nil { + return types.ZeroTime, true, err + } + return du.addDate(ctx, date, year, month, day, nano, resultFsp) +} + +func (du *baseDateArithmetical) addDate(ctx EvalContext, date types.Time, year, month, day, nano int64, resultFsp int) (types.Time, bool, error) { + goTime, err := date.GoTime(time.UTC) + if err := handleInvalidTimeError(ctx, err); err != nil { + return types.ZeroTime, true, err + } + + goTime = goTime.Add(time.Duration(nano)) + goTime, err = types.AddDate(year, month, day, goTime) + if err != nil { + return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime")) + } + + // Adjust fsp as required by outer - always respect type inference. + date.SetFsp(resultFsp) + + // fix https://github.com/pingcap/tidb/issues/11329 + if goTime.Year() == 0 { + hour, minute, second := goTime.Clock() + date.SetCoreTime(types.FromDate(0, 0, 0, hour, minute, second, goTime.Nanosecond()/1000)) + return date, false, nil + } + + date.SetCoreTime(types.FromGoTime(goTime)) + tc := typeCtx(ctx) + overflow, err := types.DateTimeIsOverflow(tc, date) + if err := handleInvalidTimeError(ctx, err); err != nil { + return types.ZeroTime, true, err + } + if overflow { + return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime")) + } + return date, false, nil +} + +type funcDurationOp func(d, interval types.Duration) (types.Duration, error) + +func (du *baseDateArithmetical) opDuration(ctx EvalContext, op funcDurationOp, d types.Duration, interval string, unit string, resultFsp int) (types.Duration, bool, error) { + dur, err := types.ExtractDurationValue(unit, interval) + if err != nil { + return types.ZeroDuration, true, handleInvalidTimeError(ctx, err) + } + retDur, err := op(d, dur) + if err != nil { + return types.ZeroDuration, true, err + } + // Adjust fsp as required by outer - always respect type inference. 
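	// Editor's note (illustrative, not part of the original patch): resultFsp is
	// decided by the type inference in addSubDateFunctionClass.getFunction below,
	// e.g. a MICROSECOND unit forces types.MaxFsp while DAY keeps the fsp of the
	// first argument.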
+ retDur.Fsp = resultFsp + return retDur, false, nil +} + +func (du *baseDateArithmetical) addDuration(ctx EvalContext, d types.Duration, interval string, unit string, resultFsp int) (types.Duration, bool, error) { + add := func(d, interval types.Duration) (types.Duration, error) { + return d.Add(interval) + } + return du.opDuration(ctx, add, d, interval, unit, resultFsp) +} + +func (du *baseDateArithmetical) subDuration(ctx EvalContext, d types.Duration, interval string, unit string, resultFsp int) (types.Duration, bool, error) { + sub := func(d, interval types.Duration) (types.Duration, error) { + return d.Sub(interval) + } + return du.opDuration(ctx, sub, d, interval, unit, resultFsp) +} + +func (du *baseDateArithmetical) sub(ctx EvalContext, date types.Time, interval string, unit string, resultFsp int) (types.Time, bool, error) { + year, month, day, nano, _, err := types.ParseDurationValue(unit, interval) + if err := handleInvalidTimeError(ctx, err); err != nil { + return types.ZeroTime, true, err + } + return du.addDate(ctx, date, -year, -month, -day, -nano, resultFsp) +} + +func (du *baseDateArithmetical) vecGetDateFromInt(b *baseBuiltinFunc, ctx EvalContext, input *chunk.Chunk, unit string, result *chunk.Column) error { + n := input.NumRows() + buf, err := b.bufAllocator.get() + if err != nil { + return err + } + defer b.bufAllocator.put(buf) + if err := b.args[0].VecEvalInt(ctx, input, buf); err != nil { + return err + } + + result.ResizeTime(n, false) + result.MergeNulls(buf) + dates := result.Times() + i64s := buf.Int64s() + tc := typeCtx(ctx) + isClockUnit := types.IsClockUnit(unit) + for i := 0; i < n; i++ { + if result.IsNull(i) { + continue + } + + date, err := types.ParseTimeFromInt64(tc, i64s[i]) + if err != nil { + err = handleInvalidTimeError(ctx, err) + if err != nil { + return err + } + result.SetNull(i, true) + continue + } + + // The actual date.Type() might be date or datetime. + // When the unit contains clock, the date part is treated as datetime even though it might be actually a date. + if isClockUnit { + date.SetType(mysql.TypeDatetime) + } + dates[i] = date + } + return nil +} + +func (du *baseDateArithmetical) vecGetDateFromReal(b *baseBuiltinFunc, ctx EvalContext, input *chunk.Chunk, unit string, result *chunk.Column) error { + n := input.NumRows() + buf, err := b.bufAllocator.get() + if err != nil { + return err + } + defer b.bufAllocator.put(buf) + if err := b.args[0].VecEvalReal(ctx, input, buf); err != nil { + return err + } + + result.ResizeTime(n, false) + result.MergeNulls(buf) + dates := result.Times() + f64s := buf.Float64s() + tc := typeCtx(ctx) + isClockUnit := types.IsClockUnit(unit) + for i := 0; i < n; i++ { + if result.IsNull(i) { + continue + } + + date, err := types.ParseTimeFromFloat64(tc, f64s[i]) + if err != nil { + err = handleInvalidTimeError(ctx, err) + if err != nil { + return err + } + result.SetNull(i, true) + continue + } + + // The actual date.Type() might be date or datetime. + // When the unit contains clock, the date part is treated as datetime even though it might be actually a date. 
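		// Editor's note (illustrative, not part of the original patch): the numeric
		// forms accepted above mean that, for example, 20071108 would come back as a
		// DATE while 20071108093000 would come back as a DATETIME; the clock-unit
		// promotion below then decides the final type.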
+ if isClockUnit { + date.SetType(mysql.TypeDatetime) + } + dates[i] = date + } + return nil +} + +func (du *baseDateArithmetical) vecGetDateFromDecimal(b *baseBuiltinFunc, ctx EvalContext, input *chunk.Chunk, unit string, result *chunk.Column) error { + n := input.NumRows() + buf, err := b.bufAllocator.get() + if err != nil { + return err + } + defer b.bufAllocator.put(buf) + if err := b.args[0].VecEvalDecimal(ctx, input, buf); err != nil { + return err + } + + result.ResizeTime(n, false) + result.MergeNulls(buf) + dates := result.Times() + tc := typeCtx(ctx) + isClockUnit := types.IsClockUnit(unit) + for i := 0; i < n; i++ { + if result.IsNull(i) { + continue + } + + dec := buf.GetDecimal(i) + date, err := types.ParseTimeFromDecimal(tc, dec) + if err != nil { + err = handleInvalidTimeError(ctx, err) + if err != nil { + return err + } + result.SetNull(i, true) + continue + } + + // The actual date.Type() might be date or datetime. + // When the unit contains clock, the date part is treated as datetime even though it might be actually a date. + if isClockUnit { + date.SetType(mysql.TypeDatetime) + } + dates[i] = date + } + return nil +} + +func (du *baseDateArithmetical) vecGetDateFromString(b *baseBuiltinFunc, ctx EvalContext, input *chunk.Chunk, unit string, result *chunk.Column) error { + n := input.NumRows() + buf, err := b.bufAllocator.get() + if err != nil { + return err + } + defer b.bufAllocator.put(buf) + if err := b.args[0].VecEvalString(ctx, input, buf); err != nil { + return err + } + + result.ResizeTime(n, false) + result.MergeNulls(buf) + dates := result.Times() + tc := typeCtx(ctx) + isClockUnit := types.IsClockUnit(unit) + for i := 0; i < n; i++ { + if result.IsNull(i) { + continue + } + + dateStr := buf.GetString(i) + dateTp := mysql.TypeDate + if !types.IsDateFormat(dateStr) || isClockUnit { + dateTp = mysql.TypeDatetime + } + + date, err := types.ParseTime(tc, dateStr, dateTp, types.MaxFsp) + if err != nil { + err = handleInvalidTimeError(ctx, err) + if err != nil { + return err + } + result.SetNull(i, true) + } else if sqlMode(ctx).HasNoZeroDateMode() && (date.Year() == 0 || date.Month() == 0 || date.Day() == 0) { + err = handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, dateStr)) + if err != nil { + return err + } + result.SetNull(i, true) + } else { + dates[i] = date + } + } + return nil +} + +func (du *baseDateArithmetical) vecGetDateFromDatetime(b *baseBuiltinFunc, ctx EvalContext, input *chunk.Chunk, unit string, result *chunk.Column) error { + n := input.NumRows() + result.ResizeTime(n, false) + if err := b.args[0].VecEvalTime(ctx, input, result); err != nil { + return err + } + + dates := result.Times() + isClockUnit := types.IsClockUnit(unit) + for i := 0; i < n; i++ { + if result.IsNull(i) { + continue + } + + // The actual date[i].Type() might be date, datetime or timestamp. + // Datetime is treated as is. + // Timestamp is treated as datetime, as MySQL manual says: https://dev.mysql.com/doc/refman/8.0/en/date-and-time-functions.html#function_date-add + // When the unit contains clock, the date part is treated as datetime even though it might be actually a date. 
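		// Editor's note (illustrative, not part of the original patch):
		//   DATE_ADD(ts_col,   INTERVAL 1 DAY)  -- a TIMESTAMP input is handled as DATETIME
		//   DATE_ADD(date_col, INTERVAL 1 HOUR) -- a clock unit promotes a DATE to DATETIME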
+ if isClockUnit || dates[i].Type() == mysql.TypeTimestamp { + dates[i].SetType(mysql.TypeDatetime) + } + } + return nil +} + +func (du *baseDateArithmetical) vecGetIntervalFromString(b *baseBuiltinFunc, ctx EvalContext, input *chunk.Chunk, unit string, result *chunk.Column) error { + n := input.NumRows() + buf, err := b.bufAllocator.get() + if err != nil { + return err + } + defer b.bufAllocator.put(buf) + if err := b.args[1].VecEvalString(ctx, input, buf); err != nil { + return err + } + + ec := errCtx(ctx) + result.ReserveString(n) + for i := 0; i < n; i++ { + if buf.IsNull(i) { + result.AppendNull() + continue + } + + interval, err := du.intervalReformatString(ec, buf.GetString(i), unit) + if err != nil { + return err + } + result.AppendString(interval) + } + return nil +} + +func (du *baseDateArithmetical) vecGetIntervalFromDecimal(b *baseBuiltinFunc, ctx EvalContext, input *chunk.Chunk, unit string, result *chunk.Column) error { + n := input.NumRows() + buf, err := b.bufAllocator.get() + if err != nil { + return err + } + defer b.bufAllocator.put(buf) + if err := b.args[1].VecEvalDecimal(ctx, input, buf); err != nil { + return err + } + + isCompoundUnit := false + amendInterval := func(val string, row *chunk.Row) (string, bool, error) { + return val, false, nil + } + switch unitUpper := strings.ToUpper(unit); unitUpper { + case "HOUR_MINUTE", "MINUTE_SECOND", "YEAR_MONTH", "DAY_HOUR", "DAY_MINUTE", + "DAY_SECOND", "DAY_MICROSECOND", "HOUR_MICROSECOND", "HOUR_SECOND", "MINUTE_MICROSECOND", "SECOND_MICROSECOND": + isCompoundUnit = true + switch strings.ToUpper(unit) { + case "HOUR_MINUTE", "MINUTE_SECOND": + amendInterval = func(val string, _ *chunk.Row) (string, bool, error) { + return strings.ReplaceAll(val, ".", ":"), false, nil + } + case "YEAR_MONTH": + amendInterval = func(val string, _ *chunk.Row) (string, bool, error) { + return strings.ReplaceAll(val, ".", "-"), false, nil + } + case "DAY_HOUR": + amendInterval = func(val string, _ *chunk.Row) (string, bool, error) { + return strings.ReplaceAll(val, ".", " "), false, nil + } + case "DAY_MINUTE": + amendInterval = func(val string, _ *chunk.Row) (string, bool, error) { + return "0 " + strings.ReplaceAll(val, ".", ":"), false, nil + } + case "DAY_SECOND": + amendInterval = func(val string, _ *chunk.Row) (string, bool, error) { + return "0 00:" + strings.ReplaceAll(val, ".", ":"), false, nil + } + case "DAY_MICROSECOND": + amendInterval = func(val string, _ *chunk.Row) (string, bool, error) { + return "0 00:00:" + val, false, nil + } + case "HOUR_MICROSECOND": + amendInterval = func(val string, _ *chunk.Row) (string, bool, error) { + return "00:00:" + val, false, nil + } + case "HOUR_SECOND": + amendInterval = func(val string, _ *chunk.Row) (string, bool, error) { + return "00:" + strings.ReplaceAll(val, ".", ":"), false, nil + } + case "MINUTE_MICROSECOND": + amendInterval = func(val string, _ *chunk.Row) (string, bool, error) { + return "00:" + val, false, nil + } + case "SECOND_MICROSECOND": + /* keep interval as original decimal */ + } + case "SECOND": + /* keep interval as original decimal */ + default: + // YEAR, QUARTER, MONTH, WEEK, DAY, HOUR, MINUTE, MICROSECOND + amendInterval = func(_ string, row *chunk.Row) (string, bool, error) { + dec, isNull, err := b.args[1].EvalDecimal(ctx, *row) + if isNull || err != nil { + return "", true, err + } + + str, err := du.intervalDecimalToString(errCtx(ctx), dec) + if err != nil { + return "", true, err + } + + return str, false, nil + } + } + + result.ReserveString(n) + decs := 
buf.Decimals() + for i := 0; i < n; i++ { + if buf.IsNull(i) { + result.AppendNull() + continue + } + + interval := decs[i].String() + row := input.GetRow(i) + isNeg := false + if isCompoundUnit && interval != "" && interval[0] == '-' { + isNeg = true + interval = interval[1:] + } + interval, isNull, err := amendInterval(interval, &row) + if err != nil { + return err + } + if isNull { + result.AppendNull() + continue + } + if isCompoundUnit && isNeg { + interval = "-" + interval + } + result.AppendString(interval) + } + return nil +} + +func (du *baseDateArithmetical) vecGetIntervalFromInt(b *baseBuiltinFunc, ctx EvalContext, input *chunk.Chunk, unit string, result *chunk.Column) error { + n := input.NumRows() + buf, err := b.bufAllocator.get() + if err != nil { + return err + } + defer b.bufAllocator.put(buf) + if err := b.args[1].VecEvalInt(ctx, input, buf); err != nil { + return err + } + + result.ReserveString(n) + i64s := buf.Int64s() + unsigned := mysql.HasUnsignedFlag(b.args[1].GetType(ctx).GetFlag()) + for i := 0; i < n; i++ { + if buf.IsNull(i) { + result.AppendNull() + } else if unsigned { + result.AppendString(strconv.FormatUint(uint64(i64s[i]), 10)) + } else { + result.AppendString(strconv.FormatInt(i64s[i], 10)) + } + } + return nil +} + +func (du *baseDateArithmetical) vecGetIntervalFromReal(b *baseBuiltinFunc, ctx EvalContext, input *chunk.Chunk, unit string, result *chunk.Column) error { + n := input.NumRows() + buf, err := b.bufAllocator.get() + if err != nil { + return err + } + defer b.bufAllocator.put(buf) + if err := b.args[1].VecEvalReal(ctx, input, buf); err != nil { + return err + } + + result.ReserveString(n) + f64s := buf.Float64s() + prec := b.args[1].GetType(ctx).GetDecimal() + for i := 0; i < n; i++ { + if buf.IsNull(i) { + result.AppendNull() + } else { + result.AppendString(strconv.FormatFloat(f64s[i], 'f', prec, 64)) + } + } + return nil +} + +type funcTimeOpForDateAddSub func(da *baseDateArithmetical, ctx EvalContext, date types.Time, interval, unit string, resultFsp int) (types.Time, bool, error) + +func addTime(da *baseDateArithmetical, ctx EvalContext, date types.Time, interval, unit string, resultFsp int) (types.Time, bool, error) { + return da.add(ctx, date, interval, unit, resultFsp) +} + +func subTime(da *baseDateArithmetical, ctx EvalContext, date types.Time, interval, unit string, resultFsp int) (types.Time, bool, error) { + return da.sub(ctx, date, interval, unit, resultFsp) +} + +type funcDurationOpForDateAddSub func(da *baseDateArithmetical, ctx EvalContext, d types.Duration, interval, unit string, resultFsp int) (types.Duration, bool, error) + +func addDuration(da *baseDateArithmetical, ctx EvalContext, d types.Duration, interval, unit string, resultFsp int) (types.Duration, bool, error) { + return da.addDuration(ctx, d, interval, unit, resultFsp) +} + +func subDuration(da *baseDateArithmetical, ctx EvalContext, d types.Duration, interval, unit string, resultFsp int) (types.Duration, bool, error) { + return da.subDuration(ctx, d, interval, unit, resultFsp) +} + +type funcSetPbCodeOp func(b builtinFunc, add, sub tipb.ScalarFuncSig) + +func setAdd(b builtinFunc, add, sub tipb.ScalarFuncSig) { + b.setPbCode(add) +} + +func setSub(b builtinFunc, add, sub tipb.ScalarFuncSig) { + b.setPbCode(sub) +} + +type funcGetDateForDateAddSub func(da *baseDateArithmetical, ctx EvalContext, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) + +func getDateFromString(da *baseDateArithmetical, ctx EvalContext, args []Expression, row 
chunk.Row, unit string) (types.Time, bool, error) { + return da.getDateFromString(ctx, args, row, unit) +} + +func getDateFromInt(da *baseDateArithmetical, ctx EvalContext, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) { + return da.getDateFromInt(ctx, args, row, unit) +} + +func getDateFromReal(da *baseDateArithmetical, ctx EvalContext, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) { + return da.getDateFromReal(ctx, args, row, unit) +} + +func getDateFromDecimal(da *baseDateArithmetical, ctx EvalContext, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) { + return da.getDateFromDecimal(ctx, args, row, unit) +} + +type funcVecGetDateForDateAddSub func(da *baseDateArithmetical, ctx EvalContext, b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error + +func vecGetDateFromString(da *baseDateArithmetical, ctx EvalContext, b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { + return da.vecGetDateFromString(b, ctx, input, unit, result) +} + +func vecGetDateFromInt(da *baseDateArithmetical, ctx EvalContext, b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { + return da.vecGetDateFromInt(b, ctx, input, unit, result) +} + +func vecGetDateFromReal(da *baseDateArithmetical, ctx EvalContext, b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { + return da.vecGetDateFromReal(b, ctx, input, unit, result) +} + +func vecGetDateFromDecimal(da *baseDateArithmetical, ctx EvalContext, b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { + return da.vecGetDateFromDecimal(b, ctx, input, unit, result) +} + +type funcGetIntervalForDateAddSub func(da *baseDateArithmetical, ctx EvalContext, args []Expression, row chunk.Row, unit string) (string, bool, error) + +func getIntervalFromString(da *baseDateArithmetical, ctx EvalContext, args []Expression, row chunk.Row, unit string) (string, bool, error) { + return da.getIntervalFromString(ctx, args, row, unit) +} + +func getIntervalFromInt(da *baseDateArithmetical, ctx EvalContext, args []Expression, row chunk.Row, unit string) (string, bool, error) { + return da.getIntervalFromInt(ctx, args, row, unit) +} + +func getIntervalFromReal(da *baseDateArithmetical, ctx EvalContext, args []Expression, row chunk.Row, unit string) (string, bool, error) { + return da.getIntervalFromReal(ctx, args, row, unit) +} + +func getIntervalFromDecimal(da *baseDateArithmetical, ctx EvalContext, args []Expression, row chunk.Row, unit string) (string, bool, error) { + return da.getIntervalFromDecimal(ctx, args, row, unit) +} + +type funcVecGetIntervalForDateAddSub func(da *baseDateArithmetical, ctx EvalContext, b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error + +func vecGetIntervalFromString(da *baseDateArithmetical, ctx EvalContext, b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { + return da.vecGetIntervalFromString(b, ctx, input, unit, result) +} + +func vecGetIntervalFromInt(da *baseDateArithmetical, ctx EvalContext, b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { + return da.vecGetIntervalFromInt(b, ctx, input, unit, result) +} + +func vecGetIntervalFromReal(da *baseDateArithmetical, ctx EvalContext, b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { + return da.vecGetIntervalFromReal(b, ctx, input, unit, result) +} + +func 
vecGetIntervalFromDecimal(da *baseDateArithmetical, ctx EvalContext, b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { + return da.vecGetIntervalFromDecimal(b, ctx, input, unit, result) +} + +type addSubDateFunctionClass struct { + baseFunctionClass + timeOp funcTimeOpForDateAddSub + durationOp funcDurationOpForDateAddSub + setPbCodeOp funcSetPbCodeOp +} + +func (c *addSubDateFunctionClass) getFunction(ctx BuildContext, args []Expression) (sig builtinFunc, err error) { + if err = c.verifyArgs(args); err != nil { + return nil, err + } + + dateEvalTp := args[0].GetType(ctx.GetEvalCtx()).EvalType() + // Some special evaluation type treatment. + // Note that it could be more elegant if we always evaluate datetime for int, real, decimal and string, by leveraging existing implicit casts. + // However, MySQL has a weird behavior for date_add(string, ...), whose result depends on the content of the first argument. + // E.g., date_add('2000-01-02 00:00:00', interval 1 day) evaluates to '2021-01-03 00:00:00' (which is normal), + // whereas date_add('2000-01-02', interval 1 day) evaluates to '2000-01-03' instead of '2021-01-03 00:00:00'. + // This requires a customized parsing of the content of the first argument, by recognizing if it is a pure date format or contains HMS part. + // So implicit casts are not viable here. + if dateEvalTp == types.ETTimestamp { + dateEvalTp = types.ETDatetime + } else if dateEvalTp == types.ETJson { + dateEvalTp = types.ETString + } + + intervalEvalTp := args[1].GetType(ctx.GetEvalCtx()).EvalType() + if intervalEvalTp == types.ETJson { + intervalEvalTp = types.ETString + } else if intervalEvalTp != types.ETString && intervalEvalTp != types.ETDecimal && intervalEvalTp != types.ETReal { + intervalEvalTp = types.ETInt + } + + unit, _, err := args[2].EvalString(ctx.GetEvalCtx(), chunk.Row{}) + if err != nil { + return nil, err + } + + resultTp := mysql.TypeVarString + resultEvalTp := types.ETString + if args[0].GetType(ctx.GetEvalCtx()).GetType() == mysql.TypeDate { + if !types.IsClockUnit(unit) { + // First arg is date and unit contains no HMS, return date. + resultTp = mysql.TypeDate + resultEvalTp = types.ETDatetime + } else { + // First arg is date and unit contains HMS, return datetime. + resultTp = mysql.TypeDatetime + resultEvalTp = types.ETDatetime + } + } else if dateEvalTp == types.ETDuration { + if types.IsDateUnit(unit) && unit != "DAY_MICROSECOND" { + // First arg is time and unit contains YMD (except DAY_MICROSECOND), return datetime. + resultTp = mysql.TypeDatetime + resultEvalTp = types.ETDatetime + } else { + // First arg is time and unit contains no YMD or is DAY_MICROSECOND, return time. + resultTp = mysql.TypeDuration + resultEvalTp = types.ETDuration + } + } else if dateEvalTp == types.ETDatetime { + // First arg is datetime or timestamp, return datetime. + resultTp = mysql.TypeDatetime + resultEvalTp = types.ETDatetime + } + + argTps := []types.EvalType{dateEvalTp, intervalEvalTp, types.ETString} + var bf baseBuiltinFunc + bf, err = newBaseBuiltinFuncWithTp(ctx, c.funcName, args, resultEvalTp, argTps...) 
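	// Editor's note (illustrative summary, not part of the original patch) of the
	// result-type rules chosen above:
	//   first argument       unit                                result
	//   DATE                 no clock part                       DATE
	//   DATE                 clock part                          DATETIME
	//   TIME (duration)      Y/M/D part (except DAY_MICROSECOND) DATETIME
	//   TIME (duration)      clock-only unit                     TIME
	//   DATETIME/TIMESTAMP   any                                 DATETIME
	//   anything else        any                                 string, e.g. DATE_ADD('2000-01-02', INTERVAL 1 DAY) = '2000-01-03'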
+ if err != nil { + return nil, err + } + bf.tp.SetType(resultTp) + + var resultFsp int + if types.IsMicrosecondUnit(unit) { + resultFsp = types.MaxFsp + } else { + intervalFsp := types.MinFsp + if unit == "SECOND" { + if intervalEvalTp == types.ETString || intervalEvalTp == types.ETReal { + intervalFsp = types.MaxFsp + } else { + intervalFsp = mathutil.Min(types.MaxFsp, args[1].GetType(ctx.GetEvalCtx()).GetDecimal()) + } + } + resultFsp = mathutil.Min(types.MaxFsp, mathutil.Max(args[0].GetType(ctx.GetEvalCtx()).GetDecimal(), intervalFsp)) + } + switch resultTp { + case mysql.TypeDate: + bf.setDecimalAndFlenForDate() + case mysql.TypeDuration: + bf.setDecimalAndFlenForTime(resultFsp) + case mysql.TypeDatetime: + bf.setDecimalAndFlenForDatetime(resultFsp) + case mysql.TypeVarString: + bf.tp.SetFlen(mysql.MaxDatetimeFullWidth) + bf.tp.SetDecimal(types.MinFsp) + } + + switch { + case dateEvalTp == types.ETString && intervalEvalTp == types.ETString: + sig = &builtinAddSubDateAsStringSig{ + baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getDate: getDateFromString, + vecGetDate: vecGetDateFromString, + getInterval: getIntervalFromString, + vecGetInterval: vecGetIntervalFromString, + timeOp: c.timeOp, + } + c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateStringString, tipb.ScalarFuncSig_SubDateStringString) + case dateEvalTp == types.ETString && intervalEvalTp == types.ETInt: + sig = &builtinAddSubDateAsStringSig{ + baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getDate: getDateFromString, + vecGetDate: vecGetDateFromString, + getInterval: getIntervalFromInt, + vecGetInterval: vecGetIntervalFromInt, + timeOp: c.timeOp, + } + c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateStringInt, tipb.ScalarFuncSig_SubDateStringInt) + case dateEvalTp == types.ETString && intervalEvalTp == types.ETReal: + sig = &builtinAddSubDateAsStringSig{ + baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getDate: getDateFromString, + vecGetDate: vecGetDateFromString, + getInterval: getIntervalFromReal, + vecGetInterval: vecGetIntervalFromReal, + timeOp: c.timeOp, + } + c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateStringReal, tipb.ScalarFuncSig_SubDateStringReal) + case dateEvalTp == types.ETString && intervalEvalTp == types.ETDecimal: + sig = &builtinAddSubDateAsStringSig{ + baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getDate: getDateFromString, + vecGetDate: vecGetDateFromString, + getInterval: getIntervalFromDecimal, + vecGetInterval: vecGetIntervalFromDecimal, + timeOp: c.timeOp, + } + c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateStringDecimal, tipb.ScalarFuncSig_SubDateStringDecimal) + case dateEvalTp == types.ETInt && intervalEvalTp == types.ETString: + sig = &builtinAddSubDateAsStringSig{ + baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getDate: getDateFromInt, + vecGetDate: vecGetDateFromInt, + getInterval: getIntervalFromString, + vecGetInterval: vecGetIntervalFromString, + timeOp: c.timeOp, + } + c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateIntString, tipb.ScalarFuncSig_SubDateIntString) + case dateEvalTp == types.ETInt && intervalEvalTp == types.ETInt: + sig = &builtinAddSubDateAsStringSig{ + baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getDate: getDateFromInt, + vecGetDate: vecGetDateFromInt, + getInterval: getIntervalFromInt, + vecGetInterval: vecGetIntervalFromInt, + timeOp: c.timeOp, + } + c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateIntInt, 
tipb.ScalarFuncSig_SubDateIntInt) + case dateEvalTp == types.ETInt && intervalEvalTp == types.ETReal: + sig = &builtinAddSubDateAsStringSig{ + baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getDate: getDateFromInt, + vecGetDate: vecGetDateFromInt, + getInterval: getIntervalFromReal, + vecGetInterval: vecGetIntervalFromReal, + timeOp: c.timeOp, + } + c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateIntReal, tipb.ScalarFuncSig_SubDateIntReal) + case dateEvalTp == types.ETInt && intervalEvalTp == types.ETDecimal: + sig = &builtinAddSubDateAsStringSig{ + baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getDate: getDateFromInt, + vecGetDate: vecGetDateFromInt, + getInterval: getIntervalFromDecimal, + vecGetInterval: vecGetIntervalFromDecimal, + timeOp: c.timeOp, + } + c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateIntDecimal, tipb.ScalarFuncSig_SubDateIntDecimal) + case dateEvalTp == types.ETReal && intervalEvalTp == types.ETString: + sig = &builtinAddSubDateAsStringSig{ + baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getDate: getDateFromReal, + vecGetDate: vecGetDateFromReal, + getInterval: getIntervalFromString, + vecGetInterval: vecGetIntervalFromString, + timeOp: c.timeOp, + } + c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateRealString, tipb.ScalarFuncSig_SubDateRealString) + case dateEvalTp == types.ETReal && intervalEvalTp == types.ETInt: + sig = &builtinAddSubDateAsStringSig{ + baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getDate: getDateFromReal, + vecGetDate: vecGetDateFromReal, + getInterval: getIntervalFromInt, + vecGetInterval: vecGetIntervalFromInt, + timeOp: c.timeOp, + } + c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateRealInt, tipb.ScalarFuncSig_SubDateRealInt) + case dateEvalTp == types.ETReal && intervalEvalTp == types.ETReal: + sig = &builtinAddSubDateAsStringSig{ + baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getDate: getDateFromReal, + vecGetDate: vecGetDateFromReal, + getInterval: getIntervalFromReal, + vecGetInterval: vecGetIntervalFromReal, + timeOp: c.timeOp, + } + c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateRealReal, tipb.ScalarFuncSig_SubDateRealReal) + case dateEvalTp == types.ETReal && intervalEvalTp == types.ETDecimal: + sig = &builtinAddSubDateAsStringSig{ + baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getDate: getDateFromReal, + vecGetDate: vecGetDateFromReal, + getInterval: getIntervalFromDecimal, + vecGetInterval: vecGetIntervalFromDecimal, + timeOp: c.timeOp, + } + c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateRealDecimal, tipb.ScalarFuncSig_SubDateRealDecimal) + case dateEvalTp == types.ETDecimal && intervalEvalTp == types.ETString: + sig = &builtinAddSubDateAsStringSig{ + baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getDate: getDateFromDecimal, + vecGetDate: vecGetDateFromDecimal, + getInterval: getIntervalFromString, + vecGetInterval: vecGetIntervalFromString, + timeOp: c.timeOp, + } + c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateDecimalString, tipb.ScalarFuncSig_SubDateDecimalString) + case dateEvalTp == types.ETDecimal && intervalEvalTp == types.ETInt: + sig = &builtinAddSubDateAsStringSig{ + baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getDate: getDateFromDecimal, + vecGetDate: vecGetDateFromDecimal, + getInterval: getIntervalFromInt, + vecGetInterval: vecGetIntervalFromInt, + timeOp: c.timeOp, + } + c.setPbCodeOp(sig, 
tipb.ScalarFuncSig_AddDateDecimalInt, tipb.ScalarFuncSig_SubDateDecimalInt) + case dateEvalTp == types.ETDecimal && intervalEvalTp == types.ETReal: + sig = &builtinAddSubDateAsStringSig{ + baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getDate: getDateFromDecimal, + vecGetDate: vecGetDateFromDecimal, + getInterval: getIntervalFromReal, + vecGetInterval: vecGetIntervalFromReal, + timeOp: c.timeOp, + } + c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateDecimalReal, tipb.ScalarFuncSig_SubDateDecimalReal) + case dateEvalTp == types.ETDecimal && intervalEvalTp == types.ETDecimal: + sig = &builtinAddSubDateAsStringSig{ + baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getDate: getDateFromDecimal, + vecGetDate: vecGetDateFromDecimal, + getInterval: getIntervalFromDecimal, + vecGetInterval: vecGetIntervalFromDecimal, + timeOp: c.timeOp, + } + c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateDecimalDecimal, tipb.ScalarFuncSig_SubDateDecimalDecimal) + case dateEvalTp == types.ETDatetime && intervalEvalTp == types.ETString: + sig = &builtinAddSubDateDatetimeAnySig{ + baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getInterval: getIntervalFromString, + vecGetInterval: vecGetIntervalFromString, + timeOp: c.timeOp, + } + c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateDatetimeString, tipb.ScalarFuncSig_SubDateDatetimeString) + case dateEvalTp == types.ETDatetime && intervalEvalTp == types.ETInt: + sig = &builtinAddSubDateDatetimeAnySig{ + baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getInterval: getIntervalFromInt, + vecGetInterval: vecGetIntervalFromInt, + timeOp: c.timeOp, + } + c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateDatetimeInt, tipb.ScalarFuncSig_SubDateDatetimeInt) + case dateEvalTp == types.ETDatetime && intervalEvalTp == types.ETReal: + sig = &builtinAddSubDateDatetimeAnySig{ + baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getInterval: getIntervalFromReal, + vecGetInterval: vecGetIntervalFromReal, + timeOp: c.timeOp, + } + c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateDatetimeReal, tipb.ScalarFuncSig_SubDateDatetimeReal) + case dateEvalTp == types.ETDatetime && intervalEvalTp == types.ETDecimal: + sig = &builtinAddSubDateDatetimeAnySig{ + baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getInterval: getIntervalFromDecimal, + vecGetInterval: vecGetIntervalFromDecimal, + timeOp: c.timeOp, + } + c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateDatetimeDecimal, tipb.ScalarFuncSig_SubDateDatetimeDecimal) + case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETString: + sig = &builtinAddSubDateDurationAnySig{ + baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getInterval: getIntervalFromString, + vecGetInterval: vecGetIntervalFromString, + timeOp: c.timeOp, + durationOp: c.durationOp, + } + c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateDurationString, tipb.ScalarFuncSig_SubDateDurationString) + case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETInt: + sig = &builtinAddSubDateDurationAnySig{ + baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getInterval: getIntervalFromInt, + vecGetInterval: vecGetIntervalFromInt, + timeOp: c.timeOp, + durationOp: c.durationOp, + } + c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateDurationInt, tipb.ScalarFuncSig_SubDateDurationInt) + case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETReal: + sig = &builtinAddSubDateDurationAnySig{ + 
baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getInterval: getIntervalFromReal, + vecGetInterval: vecGetIntervalFromReal, + timeOp: c.timeOp, + durationOp: c.durationOp, + } + c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateDurationReal, tipb.ScalarFuncSig_SubDateDurationReal) + case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETDecimal: + sig = &builtinAddSubDateDurationAnySig{ + baseBuiltinFunc: bf, + baseDateArithmetical: newDateArithmeticalUtil(), + getInterval: getIntervalFromDecimal, + vecGetInterval: vecGetIntervalFromDecimal, + timeOp: c.timeOp, + durationOp: c.durationOp, + } + c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateDurationDecimal, tipb.ScalarFuncSig_SubDateDurationDecimal) + } + return sig, nil +} + +type builtinAddSubDateAsStringSig struct { + baseBuiltinFunc + baseDateArithmetical + getDate funcGetDateForDateAddSub + vecGetDate funcVecGetDateForDateAddSub + getInterval funcGetIntervalForDateAddSub + vecGetInterval funcVecGetIntervalForDateAddSub + timeOp funcTimeOpForDateAddSub +} + +func (b *builtinAddSubDateAsStringSig) Clone() builtinFunc { + newSig := &builtinAddSubDateAsStringSig{ + baseDateArithmetical: b.baseDateArithmetical, + getDate: b.getDate, + vecGetDate: b.vecGetDate, + getInterval: b.getInterval, + vecGetInterval: b.vecGetInterval, + timeOp: b.timeOp, + } + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinAddSubDateAsStringSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { + unit, isNull, err := b.args[2].EvalString(ctx, row) + if isNull || err != nil { + return types.ZeroTime.String(), true, err + } + + date, isNull, err := b.getDate(&b.baseDateArithmetical, ctx, b.args, row, unit) + if isNull || err != nil { + return types.ZeroTime.String(), true, err + } + if date.InvalidZero() { + return types.ZeroTime.String(), true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, date.String())) + } + + interval, isNull, err := b.getInterval(&b.baseDateArithmetical, ctx, b.args, row, unit) + if isNull || err != nil { + return types.ZeroTime.String(), true, err + } + + result, isNull, err := b.timeOp(&b.baseDateArithmetical, ctx, date, interval, unit, b.tp.GetDecimal()) + if result.Microsecond() == 0 { + result.SetFsp(types.MinFsp) + } else { + result.SetFsp(types.MaxFsp) + } + + return result.String(), isNull, err +} + +type builtinAddSubDateDatetimeAnySig struct { + baseBuiltinFunc + baseDateArithmetical + getInterval funcGetIntervalForDateAddSub + vecGetInterval funcVecGetIntervalForDateAddSub + timeOp funcTimeOpForDateAddSub +} + +func (b *builtinAddSubDateDatetimeAnySig) Clone() builtinFunc { + newSig := &builtinAddSubDateDatetimeAnySig{ + baseDateArithmetical: b.baseDateArithmetical, + getInterval: b.getInterval, + vecGetInterval: b.vecGetInterval, + timeOp: b.timeOp, + } + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinAddSubDateDatetimeAnySig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + unit, isNull, err := b.args[2].EvalString(ctx, row) + if isNull || err != nil { + return types.ZeroTime, true, err + } + + date, isNull, err := b.getDateFromDatetime(ctx, b.args, row, unit) + if isNull || err != nil { + return types.ZeroTime, true, err + } + + interval, isNull, err := b.getInterval(&b.baseDateArithmetical, ctx, b.args, row, unit) + if isNull || err != nil { + return types.ZeroTime, true, err + } + + result, isNull, err := b.timeOp(&b.baseDateArithmetical, ctx, date, interval, unit, 
b.tp.GetDecimal()) + return result, isNull || err != nil, err +} + +type builtinAddSubDateDurationAnySig struct { + baseBuiltinFunc + baseDateArithmetical + getInterval funcGetIntervalForDateAddSub + vecGetInterval funcVecGetIntervalForDateAddSub + timeOp funcTimeOpForDateAddSub + durationOp funcDurationOpForDateAddSub +} + +func (b *builtinAddSubDateDurationAnySig) Clone() builtinFunc { + newSig := &builtinAddSubDateDurationAnySig{ + baseDateArithmetical: b.baseDateArithmetical, + getInterval: b.getInterval, + vecGetInterval: b.vecGetInterval, + timeOp: b.timeOp, + durationOp: b.durationOp, + } + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinAddSubDateDurationAnySig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + unit, isNull, err := b.args[2].EvalString(ctx, row) + if isNull || err != nil { + return types.ZeroTime, true, err + } + + d, isNull, err := b.args[0].EvalDuration(ctx, row) + if isNull || err != nil { + return types.ZeroTime, true, err + } + + interval, isNull, err := b.getInterval(&b.baseDateArithmetical, ctx, b.args, row, unit) + if isNull || err != nil { + return types.ZeroTime, true, err + } + + tc := typeCtx(ctx) + t, err := d.ConvertToTime(tc, mysql.TypeDatetime) + if err != nil { + return types.ZeroTime, true, err + } + result, isNull, err := b.timeOp(&b.baseDateArithmetical, ctx, t, interval, unit, b.tp.GetDecimal()) + return result, isNull || err != nil, err +} + +func (b *builtinAddSubDateDurationAnySig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { + unit, isNull, err := b.args[2].EvalString(ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + dur, isNull, err := b.args[0].EvalDuration(ctx, row) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + interval, isNull, err := b.getInterval(&b.baseDateArithmetical, ctx, b.args, row, unit) + if isNull || err != nil { + return types.ZeroDuration, true, err + } + + result, isNull, err := b.durationOp(&b.baseDateArithmetical, ctx, dur, interval, unit, b.tp.GetDecimal()) + return result, isNull || err != nil, err +} + +type timestampDiffFunctionClass struct { + baseFunctionClass +} + +func (c *timestampDiffFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETString, types.ETDatetime, types.ETDatetime) + if err != nil { + return nil, err + } + sig := &builtinTimestampDiffSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_TimestampDiff) + return sig, nil +} + +type builtinTimestampDiffSig struct { + baseBuiltinFunc +} + +func (b *builtinTimestampDiffSig) Clone() builtinFunc { + newSig := &builtinTimestampDiffSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals a builtinTimestampDiffSig. 
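// For example (editor's note, illustrative): TIMESTAMPDIFF(MONTH, '2003-02-01', '2003-05-01') = 3;
// if either argument is an invalid (zero) datetime, the result is NULL.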
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_timestampdiff +func (b *builtinTimestampDiffSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + unit, isNull, err := b.args[0].EvalString(ctx, row) + if isNull || err != nil { + return 0, isNull, err + } + lhs, isNull, err := b.args[1].EvalTime(ctx, row) + if isNull || err != nil { + return 0, isNull, handleInvalidTimeError(ctx, err) + } + rhs, isNull, err := b.args[2].EvalTime(ctx, row) + if isNull || err != nil { + return 0, isNull, handleInvalidTimeError(ctx, err) + } + if invalidLHS, invalidRHS := lhs.InvalidZero(), rhs.InvalidZero(); invalidLHS || invalidRHS { + if invalidLHS { + err = handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, lhs.String())) + } + if invalidRHS { + err = handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, rhs.String())) + } + return 0, true, err + } + return types.TimestampDiff(unit, lhs, rhs), false, nil +} + +type unixTimestampFunctionClass struct { + baseFunctionClass +} + +func (c *unixTimestampFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + var ( + argTps []types.EvalType + retTp types.EvalType + retFLen, retDecimal int + ) + + if len(args) == 0 { + retTp, retDecimal = types.ETInt, 0 + } else { + argTps = []types.EvalType{types.ETDatetime} + argType := args[0].GetType(ctx.GetEvalCtx()) + argEvaltp := argType.EvalType() + if argEvaltp == types.ETString { + // Treat types.ETString as unspecified decimal. + retDecimal = types.UnspecifiedLength + if cnst, ok := args[0].(*Constant); ok { + tmpStr, _, err := cnst.EvalString(ctx.GetEvalCtx(), chunk.Row{}) + if err != nil { + return nil, err + } + retDecimal = 0 + if dotIdx := strings.LastIndex(tmpStr, "."); dotIdx >= 0 { + retDecimal = len(tmpStr) - dotIdx - 1 + } + } + } else { + retDecimal = argType.GetDecimal() + } + if retDecimal > 6 || retDecimal == types.UnspecifiedLength { + retDecimal = 6 + } + if retDecimal == 0 { + retTp = types.ETInt + } else { + retTp = types.ETDecimal + } + } + if retTp == types.ETInt { + retFLen = 11 + } else if retTp == types.ETDecimal { + retFLen = 12 + retDecimal + } else { + panic("Unexpected retTp") + } + + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, retTp, argTps...) + if err != nil { + return nil, err + } + bf.tp.SetFlenUnderLimit(retFLen) + bf.tp.SetDecimalUnderLimit(retDecimal) + + var sig builtinFunc + if len(args) == 0 { + sig = &builtinUnixTimestampCurrentSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_UnixTimestampCurrent) + } else if retTp == types.ETInt { + sig = &builtinUnixTimestampIntSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_UnixTimestampInt) + } else if retTp == types.ETDecimal { + sig = &builtinUnixTimestampDecSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_UnixTimestampDec) + } + return sig, nil +} + +// goTimeToMysqlUnixTimestamp converts go time into MySQL's Unix timestamp. +// MySQL's Unix timestamp ranges from '1970-01-01 00:00:01.000000' UTC to '3001-01-18 23:59:59.999999' UTC. Values out of range should be rewritten to 0. 
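// For example (editor's note, illustrative): UNIX_TIMESTAMP('1969-12-31 23:59:59') and
// UNIX_TIMESTAMP('3001-01-19 00:00:00') both fall outside this range and evaluate to 0,
// and fractional seconds are truncated rather than rounded (see the 5.7 sample below).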
+// https://dev.mysql.com/doc/refman/8.0/en/date-and-time-functions.html#function_unix-timestamp +func goTimeToMysqlUnixTimestamp(t time.Time, decimal int) (*types.MyDecimal, error) { + microSeconds := t.UnixMicro() + // Prior to MySQL 8.0.28 (or any 32-bit platform), the valid range of argument values is the same as for the TIMESTAMP data type: + // '1970-01-01 00:00:01.000000' UTC to '2038-01-19 03:14:07.999999' UTC. + // After 8.0.28, the range has been extended to '1970-01-01 00:00:01.000000' UTC to '3001-01-18 23:59:59.999999' UTC + // The magic value of '3001-01-18 23:59:59.999999' comes from the maximum supported timestamp on windows. Though TiDB + // doesn't support windows, this value is used here to keep the compatibility with MySQL + if microSeconds < 1e6 || microSeconds > 32536771199999999 { + return new(types.MyDecimal), nil + } + dec := new(types.MyDecimal) + // Here we don't use float to prevent precision lose. + dec.FromUint(uint64(microSeconds)) + err := dec.Shift(-6) + if err != nil { + return nil, err + } + + // In MySQL's implementation, unix_timestamp() will truncate the result instead of rounding it. + // Results below are from MySQL 5.7, which can prove it. + // mysql> select unix_timestamp(), unix_timestamp(now(0)), now(0), unix_timestamp(now(3)), now(3), now(6); + // +------------------+------------------------+---------------------+------------------------+-------------------------+----------------------------+ + // | unix_timestamp() | unix_timestamp(now(0)) | now(0) | unix_timestamp(now(3)) | now(3) | now(6) | + // +------------------+------------------------+---------------------+------------------------+-------------------------+----------------------------+ + // | 1553503194 | 1553503194 | 2019-03-25 16:39:54 | 1553503194.992 | 2019-03-25 16:39:54.992 | 2019-03-25 16:39:54.992969 | + // +------------------+------------------------+---------------------+------------------------+-------------------------+----------------------------+ + err = dec.Round(dec, decimal, types.ModeTruncate) + return dec, err +} + +type builtinUnixTimestampCurrentSig struct { + baseBuiltinFunc +} + +func (b *builtinUnixTimestampCurrentSig) Clone() builtinFunc { + newSig := &builtinUnixTimestampCurrentSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals a UNIX_TIMESTAMP(). +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_unix-timestamp +func (b *builtinUnixTimestampCurrentSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + nowTs, err := getStmtTimestamp(ctx) + if err != nil { + return 0, true, err + } + dec, err := goTimeToMysqlUnixTimestamp(nowTs, 1) + if err != nil { + return 0, true, err + } + intVal, err := dec.ToInt() + if !terror.ErrorEqual(err, types.ErrTruncated) { + terror.Log(err) + } + return intVal, false, nil +} + +type builtinUnixTimestampIntSig struct { + baseBuiltinFunc +} + +func (b *builtinUnixTimestampIntSig) Clone() builtinFunc { + newSig := &builtinUnixTimestampIntSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals a UNIX_TIMESTAMP(time). +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_unix-timestamp +func (b *builtinUnixTimestampIntSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + val, isNull, err := b.args[0].EvalTime(ctx, row) + if err != nil && terror.ErrorEqual(types.ErrWrongValue.GenWithStackByArgs(types.TimeStr, val), err) { + // Return 0 for invalid date time. 
+ return 0, false, nil + } + if isNull { + return 0, true, nil + } + + tz := location(ctx) + t, err := val.AdjustedGoTime(tz) + if err != nil { + return 0, false, nil + } + dec, err := goTimeToMysqlUnixTimestamp(t, 1) + if err != nil { + return 0, true, err + } + intVal, err := dec.ToInt() + if !terror.ErrorEqual(err, types.ErrTruncated) { + terror.Log(err) + } + return intVal, false, nil +} + +type builtinUnixTimestampDecSig struct { + baseBuiltinFunc +} + +func (b *builtinUnixTimestampDecSig) Clone() builtinFunc { + newSig := &builtinUnixTimestampDecSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalDecimal evals a UNIX_TIMESTAMP(time). +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_unix-timestamp +func (b *builtinUnixTimestampDecSig) evalDecimal(ctx EvalContext, row chunk.Row) (*types.MyDecimal, bool, error) { + val, isNull, err := b.args[0].EvalTime(ctx, row) + if isNull || err != nil { + // Return 0 for invalid date time. + return new(types.MyDecimal), isNull, nil + } + t, err := val.GoTime(getTimeZone(ctx)) + if err != nil { + return new(types.MyDecimal), false, nil + } + result, err := goTimeToMysqlUnixTimestamp(t, b.tp.GetDecimal()) + return result, err != nil, err +} + +type timestampFunctionClass struct { + baseFunctionClass +} + +func (c *timestampFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + evalTps, argLen := []types.EvalType{types.ETString}, len(args) + if argLen == 2 { + evalTps = append(evalTps, types.ETString) + } + fsp, err := getExpressionFsp(ctx, args[0]) + if err != nil { + return nil, err + } + if argLen == 2 { + fsp2, err := getExpressionFsp(ctx, args[1]) + if err != nil { + return nil, err + } + if fsp2 > fsp { + fsp = fsp2 + } + } + isFloat := false + switch args[0].GetType(ctx.GetEvalCtx()).GetType() { + case mysql.TypeFloat, mysql.TypeDouble, mysql.TypeNewDecimal, mysql.TypeLonglong: + isFloat = true + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, evalTps...) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForDatetime(fsp) + var sig builtinFunc + if argLen == 2 { + sig = &builtinTimestamp2ArgsSig{bf, isFloat} + sig.setPbCode(tipb.ScalarFuncSig_Timestamp2Args) + } else { + sig = &builtinTimestamp1ArgSig{bf, isFloat} + sig.setPbCode(tipb.ScalarFuncSig_Timestamp1Arg) + } + return sig, nil +} + +type builtinTimestamp1ArgSig struct { + baseBuiltinFunc + + isFloat bool +} + +func (b *builtinTimestamp1ArgSig) Clone() builtinFunc { + newSig := &builtinTimestamp1ArgSig{isFloat: b.isFloat} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals a builtinTimestamp1ArgSig. 
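// For example (editor's note, illustrative): TIMESTAMP('2003-12-31') = '2003-12-31 00:00:00',
// while an unparsable argument yields NULL plus an invalid-time warning (or an error,
// depending on the SQL mode).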
+// See https://dev.mysql.com/doc/refman/5.5/en/date-and-time-functions.html#function_timestamp +func (b *builtinTimestamp1ArgSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + s, isNull, err := b.args[0].EvalString(ctx, row) + if isNull || err != nil { + return types.ZeroTime, isNull, err + } + var tm types.Time + tc := typeCtx(ctx) + if b.isFloat { + tm, err = types.ParseTimeFromFloatString(tc, s, mysql.TypeDatetime, types.GetFsp(s)) + } else { + tm, err = types.ParseTime(tc, s, mysql.TypeDatetime, types.GetFsp(s)) + } + if err != nil { + return types.ZeroTime, true, handleInvalidTimeError(ctx, err) + } + return tm, false, nil +} + +type builtinTimestamp2ArgsSig struct { + baseBuiltinFunc + + isFloat bool +} + +func (b *builtinTimestamp2ArgsSig) Clone() builtinFunc { + newSig := &builtinTimestamp2ArgsSig{isFloat: b.isFloat} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals a builtinTimestamp2ArgsSig. +// See https://dev.mysql.com/doc/refman/5.5/en/date-and-time-functions.html#function_timestamp +func (b *builtinTimestamp2ArgsSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + arg0, isNull, err := b.args[0].EvalString(ctx, row) + if isNull || err != nil { + return types.ZeroTime, isNull, err + } + var tm types.Time + tc := typeCtx(ctx) + if b.isFloat { + tm, err = types.ParseTimeFromFloatString(tc, arg0, mysql.TypeDatetime, types.GetFsp(arg0)) + } else { + tm, err = types.ParseTime(tc, arg0, mysql.TypeDatetime, types.GetFsp(arg0)) + } + if err != nil { + return types.ZeroTime, true, handleInvalidTimeError(ctx, err) + } + if tm.Year() == 0 { + // MySQL won't evaluate add for date with zero year. + // See https://github.com/mysql/mysql-server/blob/5.7/sql/item_timefunc.cc#L2805 + return types.ZeroTime, true, nil + } + arg1, isNull, err := b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return types.ZeroTime, isNull, err + } + if !isDuration(arg1) { + return types.ZeroTime, true, nil + } + duration, _, err := types.ParseDuration(tc, arg1, types.GetFsp(arg1)) + if err != nil { + return types.ZeroTime, true, handleInvalidTimeError(ctx, err) + } + tmp, err := tm.Add(tc, duration) + if err != nil { + return types.ZeroTime, true, err + } + return tmp, false, nil +} + +type timestampLiteralFunctionClass struct { + baseFunctionClass +} + +func (c *timestampLiteralFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + con, ok := args[0].(*Constant) + if !ok { + panic("Unexpected parameter for timestamp literal") + } + dt, err := con.Eval(ctx.GetEvalCtx(), chunk.Row{}) + if err != nil { + return nil, err + } + str, err := dt.ToString() + if err != nil { + return nil, err + } + if !timestampPattern.MatchString(str) { + return nil, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, str) + } + tm, err := types.ParseTime(ctx.GetEvalCtx().TypeCtx(), str, mysql.TypeDatetime, types.GetFsp(str)) + if err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, []Expression{}, types.ETDatetime) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForDatetime(tm.Fsp()) + sig := &builtinTimestampLiteralSig{bf, tm} + return sig, nil +} + +type builtinTimestampLiteralSig struct { + baseBuiltinFunc + tm types.Time +} + +func (b *builtinTimestampLiteralSig) Clone() builtinFunc { + newSig := &builtinTimestampLiteralSig{tm: b.tm} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig 
+} + +// evalTime evals TIMESTAMP 'stringLit'. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-literals.html +func (b *builtinTimestampLiteralSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + return b.tm, false, nil +} + +// getFsp4TimeAddSub is used to in function 'ADDTIME' and 'SUBTIME' to evaluate `fsp` for the +// second parameter. It's used only if the second parameter is of string type. It's different +// from getFsp in that the result of getFsp4TimeAddSub is either 6 or 0. +func getFsp4TimeAddSub(s string) int { + if len(s)-strings.Index(s, ".")-1 == len(s) { + return types.MinFsp + } + for _, c := range s[strings.Index(s, ".")+1:] { + if c != '0' { + return types.MaxFsp + } + } + return types.MinFsp +} + +// getBf4TimeAddSub parses input types, generates baseBuiltinFunc and set related attributes for +// builtin function 'ADDTIME' and 'SUBTIME' +func getBf4TimeAddSub(ctx BuildContext, funcName string, args []Expression) (tp1, tp2 *types.FieldType, bf baseBuiltinFunc, err error) { + tp1, tp2 = args[0].GetType(ctx.GetEvalCtx()), args[1].GetType(ctx.GetEvalCtx()) + var argTp1, argTp2, retTp types.EvalType + switch tp1.GetType() { + case mysql.TypeDatetime, mysql.TypeTimestamp: + argTp1, retTp = types.ETDatetime, types.ETDatetime + case mysql.TypeDuration: + argTp1, retTp = types.ETDuration, types.ETDuration + case mysql.TypeDate: + argTp1, retTp = types.ETDuration, types.ETString + default: + argTp1, retTp = types.ETString, types.ETString + } + switch tp2.GetType() { + case mysql.TypeDatetime, mysql.TypeDuration: + argTp2 = types.ETDuration + default: + argTp2 = types.ETString + } + arg0Dec, err := getExpressionFsp(ctx, args[0]) + if err != nil { + return + } + arg1Dec, err := getExpressionFsp(ctx, args[1]) + if err != nil { + return + } + + bf, err = newBaseBuiltinFuncWithTp(ctx, funcName, args, retTp, argTp1, argTp2) + if err != nil { + return + } + switch retTp { + case types.ETDatetime: + bf.setDecimalAndFlenForDatetime(mathutil.Min(mathutil.Max(arg0Dec, arg1Dec), types.MaxFsp)) + case types.ETDuration: + bf.setDecimalAndFlenForTime(mathutil.Min(mathutil.Max(arg0Dec, arg1Dec), types.MaxFsp)) + case types.ETString: + bf.tp.SetType(mysql.TypeString) + bf.tp.SetFlen(mysql.MaxDatetimeWidthWithFsp) + bf.tp.SetDecimal(types.UnspecifiedLength) + } + return +} + +func getTimeZone(ctx EvalContext) *time.Location { + ret := location(ctx) + if ret == nil { + ret = time.Local + } + return ret +} + +// isDuration returns a boolean indicating whether the str matches the format of duration. +// See https://dev.mysql.com/doc/refman/5.7/en/time.html +func isDuration(str string) bool { + return durationPattern.MatchString(str) +} + +// strDatetimeAddDuration adds duration to datetime string, returns a string value. +func strDatetimeAddDuration(tc types.Context, d string, arg1 types.Duration) (result string, isNull bool, err error) { + arg0, err := types.ParseTime(tc, d, mysql.TypeDatetime, types.MaxFsp) + if err != nil { + // Return a warning regardless of the sql_mode, this is compatible with MySQL. + tc.AppendWarning(err) + return "", true, nil + } + ret, err := arg0.Add(tc, arg1) + if err != nil { + return "", false, err + } + fsp := types.MaxFsp + if ret.Microsecond() == 0 { + fsp = types.MinFsp + } + ret.SetFsp(fsp) + return ret.String(), false, nil +} + +// strDurationAddDuration adds duration to duration string, returns a string value. 
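// For example (editor's note, illustrative): adding '01:30:00' to '10:00:00' yields
// '11:30:00'; the fractional part is dropped whenever the resulting microseconds are zero.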
+func strDurationAddDuration(tc types.Context, d string, arg1 types.Duration) (string, error) { + arg0, _, err := types.ParseDuration(tc, d, types.MaxFsp) + if err != nil { + return "", err + } + tmpDuration, err := arg0.Add(arg1) + if err != nil { + return "", err + } + tmpDuration.Fsp = types.MaxFsp + if tmpDuration.MicroSecond() == 0 { + tmpDuration.Fsp = types.MinFsp + } + return tmpDuration.String(), nil +} + +// strDatetimeSubDuration subtracts duration from datetime string, returns a string value. +func strDatetimeSubDuration(tc types.Context, d string, arg1 types.Duration) (result string, isNull bool, err error) { + arg0, err := types.ParseTime(tc, d, mysql.TypeDatetime, types.MaxFsp) + if err != nil { + // Return a warning regardless of the sql_mode, this is compatible with MySQL. + tc.AppendWarning(err) + return "", true, nil + } + resultTime, err := arg0.Add(tc, arg1.Neg()) + if err != nil { + return "", false, err + } + fsp := types.MaxFsp + if resultTime.Microsecond() == 0 { + fsp = types.MinFsp + } + resultTime.SetFsp(fsp) + return resultTime.String(), false, nil +} + +// strDurationSubDuration subtracts duration from duration string, returns a string value. +func strDurationSubDuration(tc types.Context, d string, arg1 types.Duration) (string, error) { + arg0, _, err := types.ParseDuration(tc, d, types.MaxFsp) + if err != nil { + return "", err + } + tmpDuration, err := arg0.Sub(arg1) + if err != nil { + return "", err + } + tmpDuration.Fsp = types.MaxFsp + if tmpDuration.MicroSecond() == 0 { + tmpDuration.Fsp = types.MinFsp + } + return tmpDuration.String(), nil +} + +type addTimeFunctionClass struct { + baseFunctionClass +} + +func (c *addTimeFunctionClass) getFunction(ctx BuildContext, args []Expression) (sig builtinFunc, err error) { + if err = c.verifyArgs(args); err != nil { + return nil, err + } + tp1, tp2, bf, err := getBf4TimeAddSub(ctx, c.funcName, args) + if err != nil { + return nil, err + } + switch tp1.GetType() { + case mysql.TypeDatetime, mysql.TypeTimestamp: + switch tp2.GetType() { + case mysql.TypeDuration: + sig = &builtinAddDatetimeAndDurationSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_AddDatetimeAndDuration) + case mysql.TypeDatetime, mysql.TypeTimestamp: + sig = &builtinAddTimeDateTimeNullSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_AddTimeDateTimeNull) + default: + sig = &builtinAddDatetimeAndStringSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_AddDatetimeAndString) + } + case mysql.TypeDate: + charset, collate := ctx.GetCharsetInfo() + bf.tp.SetCharset(charset) + bf.tp.SetCollate(collate) + switch tp2.GetType() { + case mysql.TypeDuration: + sig = &builtinAddDateAndDurationSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_AddDateAndDuration) + case mysql.TypeDatetime, mysql.TypeTimestamp: + sig = &builtinAddTimeStringNullSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_AddTimeStringNull) + default: + sig = &builtinAddDateAndStringSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_AddDateAndString) + } + case mysql.TypeDuration: + switch tp2.GetType() { + case mysql.TypeDuration: + sig = &builtinAddDurationAndDurationSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_AddDurationAndDuration) + case mysql.TypeDatetime, mysql.TypeTimestamp: + sig = &builtinAddTimeDurationNullSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_AddTimeDurationNull) + default: + sig = &builtinAddDurationAndStringSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_AddDurationAndString) + } + default: + switch tp2.GetType() { + case mysql.TypeDuration: + sig = &builtinAddStringAndDurationSig{bf} + 
sig.setPbCode(tipb.ScalarFuncSig_AddStringAndDuration) + case mysql.TypeDatetime, mysql.TypeTimestamp: + sig = &builtinAddTimeStringNullSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_AddTimeStringNull) + default: + sig = &builtinAddStringAndStringSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_AddStringAndString) + } + } + return sig, nil +} + +type builtinAddTimeDateTimeNullSig struct { + baseBuiltinFunc +} + +func (b *builtinAddTimeDateTimeNullSig) Clone() builtinFunc { + newSig := &builtinAddTimeDateTimeNullSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals a builtinAddTimeDateTimeNullSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_addtime +func (b *builtinAddTimeDateTimeNullSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + return types.ZeroDatetime, true, nil +} + +type builtinAddDatetimeAndDurationSig struct { + baseBuiltinFunc +} + +func (b *builtinAddDatetimeAndDurationSig) Clone() builtinFunc { + newSig := &builtinAddDatetimeAndDurationSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals a builtinAddDatetimeAndDurationSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_addtime +func (b *builtinAddDatetimeAndDurationSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + arg0, isNull, err := b.args[0].EvalTime(ctx, row) + if isNull || err != nil { + return types.ZeroDatetime, isNull, err + } + arg1, isNull, err := b.args[1].EvalDuration(ctx, row) + if isNull || err != nil { + return types.ZeroDatetime, isNull, err + } + result, err := arg0.Add(typeCtx(ctx), arg1) + return result, err != nil, err +} + +type builtinAddDatetimeAndStringSig struct { + baseBuiltinFunc +} + +func (b *builtinAddDatetimeAndStringSig) Clone() builtinFunc { + newSig := &builtinAddDatetimeAndStringSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals a builtinAddDatetimeAndStringSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_addtime +func (b *builtinAddDatetimeAndStringSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + arg0, isNull, err := b.args[0].EvalTime(ctx, row) + if isNull || err != nil { + return types.ZeroDatetime, isNull, err + } + s, isNull, err := b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return types.ZeroDatetime, isNull, err + } + if !isDuration(s) { + return types.ZeroDatetime, true, nil + } + tc := typeCtx(ctx) + arg1, _, err := types.ParseDuration(tc, s, types.GetFsp(s)) + if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + tc.AppendWarning(err) + return types.ZeroDatetime, true, nil + } + return types.ZeroDatetime, true, err + } + result, err := arg0.Add(tc, arg1) + return result, err != nil, err +} + +type builtinAddTimeDurationNullSig struct { + baseBuiltinFunc +} + +func (b *builtinAddTimeDurationNullSig) Clone() builtinFunc { + newSig := &builtinAddTimeDurationNullSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalDuration evals a builtinAddTimeDurationNullSig. 
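// Descriptive aside, not part of the original patch: the addTimeFunctionClass
// dispatch above selects a signature from the two argument types resolved by
// getBf4TimeAddSub:
//
//	DATETIME/TIMESTAMP + DURATION           -> builtinAddDatetimeAndDurationSig
//	DATETIME/TIMESTAMP + DATETIME/TIMESTAMP -> builtinAddTimeDateTimeNullSig (always NULL)
//	DATETIME/TIMESTAMP + other              -> builtinAddDatetimeAndStringSig (parses the string at run time)
//	DURATION + DURATION                     -> builtinAddDurationAndDurationSig
//	DURATION + DATETIME/TIMESTAMP           -> builtinAddTimeDurationNullSig (always NULL)
//	DURATION + other                        -> builtinAddDurationAndStringSig
//
// The DATE and string branches follow the same pattern, falling back to the
// *AndString signatures whose second argument is validated with isDuration.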
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_addtime +func (b *builtinAddTimeDurationNullSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { + return types.ZeroDuration, true, nil +} + +type builtinAddDurationAndDurationSig struct { + baseBuiltinFunc +} + +func (b *builtinAddDurationAndDurationSig) Clone() builtinFunc { + newSig := &builtinAddDurationAndDurationSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalDuration evals a builtinAddDurationAndDurationSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_addtime +func (b *builtinAddDurationAndDurationSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { + arg0, isNull, err := b.args[0].EvalDuration(ctx, row) + if isNull || err != nil { + return types.ZeroDuration, isNull, err + } + arg1, isNull, err := b.args[1].EvalDuration(ctx, row) + if isNull || err != nil { + return types.ZeroDuration, isNull, err + } + result, err := arg0.Add(arg1) + if err != nil { + return types.ZeroDuration, true, err + } + return result, false, nil +} + +type builtinAddDurationAndStringSig struct { + baseBuiltinFunc +} + +func (b *builtinAddDurationAndStringSig) Clone() builtinFunc { + newSig := &builtinAddDurationAndStringSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalDuration evals a builtinAddDurationAndStringSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_addtime +func (b *builtinAddDurationAndStringSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { + arg0, isNull, err := b.args[0].EvalDuration(ctx, row) + if isNull || err != nil { + return types.ZeroDuration, isNull, err + } + s, isNull, err := b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return types.ZeroDuration, isNull, err + } + if !isDuration(s) { + return types.ZeroDuration, true, nil + } + tc := typeCtx(ctx) + arg1, _, err := types.ParseDuration(tc, s, types.GetFsp(s)) + if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + tc.AppendWarning(err) + return types.ZeroDuration, true, nil + } + return types.ZeroDuration, true, err + } + result, err := arg0.Add(arg1) + if err != nil { + return types.ZeroDuration, true, err + } + return result, false, nil +} + +type builtinAddTimeStringNullSig struct { + baseBuiltinFunc +} + +func (b *builtinAddTimeStringNullSig) Clone() builtinFunc { + newSig := &builtinAddTimeStringNullSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalString evals a builtinAddDurationAndDurationSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_addtime +func (b *builtinAddTimeStringNullSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { + return "", true, nil +} + +type builtinAddStringAndDurationSig struct { + baseBuiltinFunc +} + +func (b *builtinAddStringAndDurationSig) Clone() builtinFunc { + newSig := &builtinAddStringAndDurationSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalString evals a builtinAddStringAndDurationSig. 
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_addtime +func (b *builtinAddStringAndDurationSig) evalString(ctx EvalContext, row chunk.Row) (result string, isNull bool, err error) { + var ( + arg0 string + arg1 types.Duration + ) + arg0, isNull, err = b.args[0].EvalString(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + arg1, isNull, err = b.args[1].EvalDuration(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + tc := typeCtx(ctx) + if isDuration(arg0) { + result, err = strDurationAddDuration(tc, arg0, arg1) + if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + tc.AppendWarning(err) + return "", true, nil + } + return "", true, err + } + return result, false, nil + } + result, isNull, err = strDatetimeAddDuration(tc, arg0, arg1) + return result, isNull, err +} + +type builtinAddStringAndStringSig struct { + baseBuiltinFunc +} + +func (b *builtinAddStringAndStringSig) Clone() builtinFunc { + newSig := &builtinAddStringAndStringSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalString evals a builtinAddStringAndStringSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_addtime +func (b *builtinAddStringAndStringSig) evalString(ctx EvalContext, row chunk.Row) (result string, isNull bool, err error) { + var ( + arg0, arg1Str string + arg1 types.Duration + ) + arg0, isNull, err = b.args[0].EvalString(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + arg1Type := b.args[1].GetType(ctx) + if mysql.HasBinaryFlag(arg1Type.GetFlag()) { + return "", true, nil + } + arg1Str, isNull, err = b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + tc := typeCtx(ctx) + arg1, _, err = types.ParseDuration(tc, arg1Str, getFsp4TimeAddSub(arg1Str)) + if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + tc.AppendWarning(err) + return "", true, nil + } + return "", true, err + } + + check := arg1Str + _, check, err = parser.Number(parser.Space0(check)) + if err == nil { + check, err = parser.Char(check, '-') + if strings.Compare(check, "") != 0 && err == nil { + return "", true, nil + } + } + + if isDuration(arg0) { + result, err = strDurationAddDuration(tc, arg0, arg1) + if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + tc.AppendWarning(err) + return "", true, nil + } + return "", true, err + } + return result, false, nil + } + result, isNull, err = strDatetimeAddDuration(tc, arg0, arg1) + return result, isNull, err +} + +type builtinAddDateAndDurationSig struct { + baseBuiltinFunc +} + +func (b *builtinAddDateAndDurationSig) Clone() builtinFunc { + newSig := &builtinAddDateAndDurationSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalString evals a builtinAddDurationAndDurationSig. 
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_addtime +func (b *builtinAddDateAndDurationSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { + arg0, isNull, err := b.args[0].EvalDuration(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + arg1, isNull, err := b.args[1].EvalDuration(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + result, err := arg0.Add(arg1) + return result.String(), err != nil, err +} + +type builtinAddDateAndStringSig struct { + baseBuiltinFunc +} + +func (b *builtinAddDateAndStringSig) Clone() builtinFunc { + newSig := &builtinAddDateAndStringSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalString evals a builtinAddDateAndStringSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_addtime +func (b *builtinAddDateAndStringSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { + arg0, isNull, err := b.args[0].EvalDuration(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + s, isNull, err := b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + if !isDuration(s) { + return "", true, nil + } + tc := typeCtx(ctx) + arg1, _, err := types.ParseDuration(tc, s, getFsp4TimeAddSub(s)) + if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + tc.AppendWarning(err) + return "", true, nil + } + return "", true, err + } + result, err := arg0.Add(arg1) + return result.String(), err != nil, err +} + +type convertTzFunctionClass struct { + baseFunctionClass +} + +func (c *convertTzFunctionClass) getDecimal(ctx BuildContext, arg Expression) int { + decimal := types.MaxFsp + if dt, isConstant := arg.(*Constant); isConstant { + switch arg.GetType(ctx.GetEvalCtx()).EvalType() { + case types.ETInt: + decimal = 0 + case types.ETReal, types.ETDecimal: + decimal = arg.GetType(ctx.GetEvalCtx()).GetDecimal() + case types.ETString: + str, isNull, err := dt.EvalString(ctx.GetEvalCtx(), chunk.Row{}) + if err == nil && !isNull { + decimal = types.DateFSP(str) + } + } + } + if decimal > types.MaxFsp { + return types.MaxFsp + } + if decimal < types.MinFsp { + return types.MinFsp + } + return decimal +} + +func (c *convertTzFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + // tzRegex holds the regex to check whether a string is a time zone. + tzRegex, err := regexp.Compile(`(^[-+](0?[0-9]|1[0-3]):[0-5]?\d$)|(^\+14:00?$)`) + if err != nil { + return nil, err + } + + decimal := c.getDecimal(ctx, args[0]) + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, types.ETDatetime, types.ETString, types.ETString) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForDatetime(decimal) + sig := &builtinConvertTzSig{ + baseBuiltinFunc: bf, + timezoneRegex: tzRegex, + } + sig.setPbCode(tipb.ScalarFuncSig_ConvertTz) + return sig, nil +} + +type builtinConvertTzSig struct { + baseBuiltinFunc + timezoneRegex *regexp.Regexp +} + +func (b *builtinConvertTzSig) Clone() builtinFunc { + newSig := &builtinConvertTzSig{timezoneRegex: b.timezoneRegex} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals CONVERT_TZ(dt,from_tz,to_tz). +// `CONVERT_TZ` function returns NULL if the arguments are invalid. 
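// Illustrative aside, not part of the original patch: the CONVERT_TZ evaluator
// that follows accepts either a named zone (loaded via time.LoadLocation) or a
// "+HH:MM"/"-HH:MM" offset matched by the regex built in getFunction above. A
// minimal sketch of the offset-to-offset path, using the documented MySQL
// example CONVERT_TZ('2004-01-01 12:00:00','+00:00','+10:00'):
package main

import (
	"fmt"
	"time"
)

func main() {
	from := time.FixedZone("+00:00", 0)
	to := time.FixedZone("+10:00", 10*60*60)
	t, err := time.ParseInLocation("2006-01-02 15:04:05", "2004-01-01 12:00:00", from)
	if err != nil {
		panic(err)
	}
	fmt.Println(t.In(to).Format("2006-01-02 15:04:05")) // 2004-01-01 22:00:00
}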
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_convert-tz +func (b *builtinConvertTzSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + dt, isNull, err := b.args[0].EvalTime(ctx, row) + if isNull || err != nil { + return types.ZeroTime, true, nil + } + if dt.InvalidZero() { + return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, dt.String())) + } + fromTzStr, isNull, err := b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return types.ZeroTime, true, nil + } + + toTzStr, isNull, err := b.args[2].EvalString(ctx, row) + if isNull || err != nil { + return types.ZeroTime, true, nil + } + + return b.convertTz(dt, fromTzStr, toTzStr) +} + +func (b *builtinConvertTzSig) convertTz(dt types.Time, fromTzStr, toTzStr string) (types.Time, bool, error) { + if fromTzStr == "" || toTzStr == "" { + return types.ZeroTime, true, nil + } + fromTzMatched := b.timezoneRegex.MatchString(fromTzStr) + toTzMatched := b.timezoneRegex.MatchString(toTzStr) + + var fromTz, toTz *time.Location + var err error + + if fromTzMatched { + fromTz = time.FixedZone(fromTzStr, timeZone2int(fromTzStr)) + } else { + if strings.EqualFold(fromTzStr, "SYSTEM") { + fromTzStr = "Local" + } + fromTz, err = time.LoadLocation(fromTzStr) + if err != nil { + return types.ZeroTime, true, nil + } + } + + t, err := dt.AdjustedGoTime(fromTz) + if err != nil { + return types.ZeroTime, true, nil + } + t = t.In(time.UTC) + + if toTzMatched { + toTz = time.FixedZone(toTzStr, timeZone2int(toTzStr)) + } else { + if strings.EqualFold(toTzStr, "SYSTEM") { + toTzStr = "Local" + } + toTz, err = time.LoadLocation(toTzStr) + if err != nil { + return types.ZeroTime, true, nil + } + } + + return types.NewTime(types.FromGoTime(t.In(toTz)), mysql.TypeDatetime, b.tp.GetDecimal()), false, nil +} + +type makeDateFunctionClass struct { + baseFunctionClass +} + +func (c *makeDateFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, types.ETInt, types.ETInt) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForDate() + sig := &builtinMakeDateSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_MakeDate) + return sig, nil +} + +type builtinMakeDateSig struct { + baseBuiltinFunc +} + +func (b *builtinMakeDateSig) Clone() builtinFunc { + newSig := &builtinMakeDateSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evaluates a builtinMakeDateSig. 
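// Illustrative aside, not part of the original patch: the MAKEDATE evaluator
// below promotes two-digit years (year < 70 becomes 20xx, 70..99 becomes 19xx)
// and then adds dayOfYear-1 days to January 1st; dayOfYear <= 0 yields NULL.
// A sketch of the happy path with the documented example MAKEDATE(2011, 32):
package main

import (
	"fmt"
	"time"
)

func main() {
	year, dayOfYear := 2011, 32
	start := time.Date(year, 1, 1, 0, 0, 0, 0, time.UTC)
	fmt.Println(start.AddDate(0, 0, dayOfYear-1).Format("2006-01-02")) // 2011-02-01
}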
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_makedate +func (b *builtinMakeDateSig) evalTime(ctx EvalContext, row chunk.Row) (d types.Time, isNull bool, err error) { + args := b.getArgs() + var year, dayOfYear int64 + year, isNull, err = args[0].EvalInt(ctx, row) + if isNull || err != nil { + return d, true, err + } + dayOfYear, isNull, err = args[1].EvalInt(ctx, row) + if isNull || err != nil { + return d, true, err + } + if dayOfYear <= 0 || year < 0 || year > 9999 { + return d, true, nil + } + if year < 70 { + year += 2000 + } else if year < 100 { + year += 1900 + } + startTime := types.NewTime(types.FromDate(int(year), 1, 1, 0, 0, 0, 0), mysql.TypeDate, 0) + retTimestamp := types.TimestampDiff("DAY", types.ZeroDate, startTime) + if retTimestamp == 0 { + return d, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, startTime.String())) + } + ret := types.TimeFromDays(retTimestamp + dayOfYear - 1) + if ret.IsZero() || ret.Year() > 9999 { + return d, true, nil + } + return ret, false, nil +} + +type makeTimeFunctionClass struct { + baseFunctionClass +} + +func (c *makeTimeFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + tp, decimal := args[2].GetType(ctx.GetEvalCtx()).EvalType(), 0 + switch tp { + case types.ETInt: + case types.ETReal, types.ETDecimal: + decimal = args[2].GetType(ctx.GetEvalCtx()).GetDecimal() + if decimal > 6 || decimal == types.UnspecifiedLength { + decimal = 6 + } + default: + decimal = 6 + } + // MySQL will cast the first and second arguments to INT, and the third argument to DECIMAL. + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDuration, types.ETInt, types.ETInt, types.ETReal) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForTime(decimal) + sig := &builtinMakeTimeSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_MakeTime) + return sig, nil +} + +type builtinMakeTimeSig struct { + baseBuiltinFunc +} + +func (b *builtinMakeTimeSig) Clone() builtinFunc { + newSig := &builtinMakeTimeSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinMakeTimeSig) makeTime(ctx types.Context, hour int64, minute int64, second float64, hourUnsignedFlag bool) (types.Duration, error) { + var overflow bool + // MySQL TIME datatype: https://dev.mysql.com/doc/refman/5.7/en/time.html + // ranges from '-838:59:59.000000' to '838:59:59.000000' + if hour < 0 && hourUnsignedFlag { + hour = 838 + overflow = true + } + if hour < -838 { + hour = -838 + overflow = true + } else if hour > 838 { + hour = 838 + overflow = true + } + if (hour == -838 || hour == 838) && minute == 59 && second > 59 { + overflow = true + } + if overflow { + minute = 59 + second = 59 + } + fsp := b.tp.GetDecimal() + d, _, err := types.ParseDuration(ctx, fmt.Sprintf("%02d:%02d:%v", hour, minute, second), fsp) + return d, err +} + +// evalDuration evals a builtinMakeTimeIntSig. 
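// Descriptive aside, not part of the original patch: the makeTime helper above
// clamps the hour to the MySQL TIME range [-838, 838]; whenever it clamps (or
// when a negative hour is passed for an unsigned first argument), minute and
// second are forced to 59, so for example MAKETIME(850, 10, 10) evaluates to
// '838:59:59'. Minutes and seconds outside [0, 60) are rejected earlier in
// evalDuration below and produce NULL.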
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_maketime +func (b *builtinMakeTimeSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { + dur := types.ZeroDuration + dur.Fsp = types.MaxFsp + hour, isNull, err := b.args[0].EvalInt(ctx, row) + if isNull || err != nil { + return dur, isNull, err + } + minute, isNull, err := b.args[1].EvalInt(ctx, row) + if isNull || err != nil { + return dur, isNull, err + } + if minute < 0 || minute >= 60 { + return dur, true, nil + } + second, isNull, err := b.args[2].EvalReal(ctx, row) + if isNull || err != nil { + return dur, isNull, err + } + if second < 0 || second >= 60 { + return dur, true, nil + } + dur, err = b.makeTime(typeCtx(ctx), hour, minute, second, mysql.HasUnsignedFlag(b.args[0].GetType(ctx).GetFlag())) + if err != nil { + return dur, true, err + } + return dur, false, nil +} + +type periodAddFunctionClass struct { + baseFunctionClass +} + +func (c *periodAddFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETInt, types.ETInt) + if err != nil { + return nil, err + } + bf.tp.SetFlen(6) + sig := &builtinPeriodAddSig{bf} + return sig, nil +} + +// validPeriod checks if this period is valid, it comes from MySQL 8.0+. +func validPeriod(p int64) bool { + return !(p < 0 || p%100 == 0 || p%100 > 12) +} + +// period2Month converts a period to months, in which period is represented in the format of YYMM or YYYYMM. +// Note that the period argument is not a date value. +func period2Month(period uint64) uint64 { + if period == 0 { + return 0 + } + + year, month := period/100, period%100 + if year < 70 { + year += 2000 + } else if year < 100 { + year += 1900 + } + + return year*12 + month - 1 +} + +// month2Period converts a month to a period. +func month2Period(month uint64) uint64 { + if month == 0 { + return 0 + } + + year := month / 12 + if year < 70 { + year += 2000 + } else if year < 100 { + year += 1900 + } + + return year*100 + month%12 + 1 +} + +type builtinPeriodAddSig struct { + baseBuiltinFunc +} + +func (b *builtinPeriodAddSig) Clone() builtinFunc { + newSig := &builtinPeriodAddSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals PERIOD_ADD(P,N). +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_period-add +func (b *builtinPeriodAddSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + p, isNull, err := b.args[0].EvalInt(ctx, row) + if isNull || err != nil { + return 0, true, err + } + + n, isNull, err := b.args[1].EvalInt(ctx, row) + if isNull || err != nil { + return 0, true, err + } + + // in MySQL, if p is invalid but n is NULL, the result is NULL, so we have to check if n is NULL first. 
+ if !validPeriod(p) { + return 0, false, errIncorrectArgs.GenWithStackByArgs("period_add") + } + + sumMonth := int64(period2Month(uint64(p))) + n + return int64(month2Period(uint64(sumMonth))), false, nil +} + +type periodDiffFunctionClass struct { + baseFunctionClass +} + +func (c *periodDiffFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETInt, types.ETInt) + if err != nil { + return nil, err + } + bf.tp.SetFlen(6) + sig := &builtinPeriodDiffSig{bf} + return sig, nil +} + +type builtinPeriodDiffSig struct { + baseBuiltinFunc +} + +func (b *builtinPeriodDiffSig) Clone() builtinFunc { + newSig := &builtinPeriodDiffSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals PERIOD_DIFF(P1,P2). +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_period-diff +func (b *builtinPeriodDiffSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + p1, isNull, err := b.args[0].EvalInt(ctx, row) + if isNull || err != nil { + return 0, isNull, err + } + + p2, isNull, err := b.args[1].EvalInt(ctx, row) + if isNull || err != nil { + return 0, isNull, err + } + + if !validPeriod(p1) { + return 0, false, errIncorrectArgs.GenWithStackByArgs("period_diff") + } + + if !validPeriod(p2) { + return 0, false, errIncorrectArgs.GenWithStackByArgs("period_diff") + } + + return int64(period2Month(uint64(p1)) - period2Month(uint64(p2))), false, nil +} + +type quarterFunctionClass struct { + baseFunctionClass +} + +func (c *quarterFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDatetime) + if err != nil { + return nil, err + } + bf.tp.SetFlen(1) + + sig := &builtinQuarterSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_Quarter) + return sig, nil +} + +type builtinQuarterSig struct { + baseBuiltinFunc +} + +func (b *builtinQuarterSig) Clone() builtinFunc { + newSig := &builtinQuarterSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals QUARTER(date). 
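// Illustrative aside, not part of the original patch: a worked example for the
// PERIOD_ADD/PERIOD_DIFF helpers above, before moving on to QUARTER. The
// functions below are copies of the unexported helpers, reproduced here only so
// the sketch is self-contained and runnable:
package main

import "fmt"

func period2Month(period uint64) uint64 {
	if period == 0 {
		return 0
	}
	year, month := period/100, period%100
	if year < 70 {
		year += 2000
	} else if year < 100 {
		year += 1900
	}
	return year*12 + month - 1
}

func month2Period(month uint64) uint64 {
	if month == 0 {
		return 0
	}
	year := month / 12
	if year < 70 {
		year += 2000
	} else if year < 100 {
		year += 1900
	}
	return year*100 + month%12 + 1
}

func main() {
	// PERIOD_ADD(200801, 2): 200801 -> 24096 months, +2 -> 24098 -> 200803.
	fmt.Println(month2Period(period2Month(200801) + 2)) // 200803
}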
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_quarter +func (b *builtinQuarterSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + date, isNull, err := b.args[0].EvalTime(ctx, row) + if isNull || err != nil { + return 0, true, handleInvalidTimeError(ctx, err) + } + + return int64((date.Month() + 2) / 3), false, nil +} + +type secToTimeFunctionClass struct { + baseFunctionClass +} + +func (c *secToTimeFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + var retFsp int + argType := args[0].GetType(ctx.GetEvalCtx()) + argEvalTp := argType.EvalType() + if argEvalTp == types.ETString { + retFsp = types.UnspecifiedLength + } else { + retFsp = argType.GetDecimal() + } + if retFsp > types.MaxFsp || retFsp == types.UnspecifiedFsp { + retFsp = types.MaxFsp + } else if retFsp < types.MinFsp { + retFsp = types.MinFsp + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDuration, types.ETReal) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForTime(retFsp) + sig := &builtinSecToTimeSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_SecToTime) + return sig, nil +} + +type builtinSecToTimeSig struct { + baseBuiltinFunc +} + +func (b *builtinSecToTimeSig) Clone() builtinFunc { + newSig := &builtinSecToTimeSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalDuration evals SEC_TO_TIME(seconds). +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_sec-to-time +func (b *builtinSecToTimeSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { + secondsFloat, isNull, err := b.args[0].EvalReal(ctx, row) + if isNull || err != nil { + return types.Duration{}, isNull, err + } + var ( + hour uint64 + minute uint64 + second uint64 + demical float64 + secondDemical float64 + negative string + ) + + if secondsFloat < 0 { + negative = "-" + secondsFloat = math.Abs(secondsFloat) + } + seconds := uint64(secondsFloat) + demical = secondsFloat - float64(seconds) + + hour = seconds / 3600 + if hour > 838 { + hour = 838 + minute = 59 + second = 59 + demical = 0 + tc := typeCtx(ctx) + err = tc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("time", strconv.FormatFloat(secondsFloat, 'f', -1, 64))) + if err != nil { + return types.Duration{}, err != nil, err + } + } else { + minute = seconds % 3600 / 60 + second = seconds % 60 + } + secondDemical = float64(second) + demical + + var dur types.Duration + dur, _, err = types.ParseDuration(typeCtx(ctx), fmt.Sprintf("%s%02d:%02d:%s", negative, hour, minute, strconv.FormatFloat(secondDemical, 'f', -1, 64)), b.tp.GetDecimal()) + if err != nil { + return types.Duration{}, err != nil, err + } + return dur, false, nil +} + +type subTimeFunctionClass struct { + baseFunctionClass +} + +func (c *subTimeFunctionClass) getFunction(ctx BuildContext, args []Expression) (sig builtinFunc, err error) { + if err = c.verifyArgs(args); err != nil { + return nil, err + } + tp1, tp2, bf, err := getBf4TimeAddSub(ctx, c.funcName, args) + if err != nil { + return nil, err + } + switch tp1.GetType() { + case mysql.TypeDatetime, mysql.TypeTimestamp: + switch tp2.GetType() { + case mysql.TypeDuration: + sig = &builtinSubDatetimeAndDurationSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_SubDatetimeAndDuration) + case mysql.TypeDatetime, mysql.TypeTimestamp: + sig = &builtinSubTimeDateTimeNullSig{bf} + 
sig.setPbCode(tipb.ScalarFuncSig_SubTimeDateTimeNull) + default: + sig = &builtinSubDatetimeAndStringSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_SubDatetimeAndString) + } + case mysql.TypeDate: + charset, collate := ctx.GetCharsetInfo() + bf.tp.SetCharset(charset) + bf.tp.SetCollate(collate) + switch tp2.GetType() { + case mysql.TypeDuration: + sig = &builtinSubDateAndDurationSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_SubDateAndDuration) + case mysql.TypeDatetime, mysql.TypeTimestamp: + sig = &builtinSubTimeStringNullSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_SubTimeStringNull) + default: + sig = &builtinSubDateAndStringSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_SubDateAndString) + } + case mysql.TypeDuration: + switch tp2.GetType() { + case mysql.TypeDuration: + sig = &builtinSubDurationAndDurationSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_SubDurationAndDuration) + case mysql.TypeDatetime, mysql.TypeTimestamp: + sig = &builtinSubTimeDurationNullSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_SubTimeDurationNull) + default: + sig = &builtinSubDurationAndStringSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_SubDurationAndString) + } + default: + switch tp2.GetType() { + case mysql.TypeDuration: + sig = &builtinSubStringAndDurationSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_SubStringAndDuration) + case mysql.TypeDatetime, mysql.TypeTimestamp: + sig = &builtinSubTimeStringNullSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_SubTimeStringNull) + default: + sig = &builtinSubStringAndStringSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_SubStringAndString) + } + } + return sig, nil +} + +type builtinSubDatetimeAndDurationSig struct { + baseBuiltinFunc +} + +func (b *builtinSubDatetimeAndDurationSig) Clone() builtinFunc { + newSig := &builtinSubDatetimeAndDurationSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals a builtinSubDatetimeAndDurationSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subtime +func (b *builtinSubDatetimeAndDurationSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + arg0, isNull, err := b.args[0].EvalTime(ctx, row) + if isNull || err != nil { + return types.ZeroDatetime, isNull, err + } + arg1, isNull, err := b.args[1].EvalDuration(ctx, row) + if isNull || err != nil { + return types.ZeroDatetime, isNull, err + } + tc := typeCtx(ctx) + result, err := arg0.Add(tc, arg1.Neg()) + return result, err != nil, err +} + +type builtinSubDatetimeAndStringSig struct { + baseBuiltinFunc +} + +func (b *builtinSubDatetimeAndStringSig) Clone() builtinFunc { + newSig := &builtinSubDatetimeAndStringSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals a builtinSubDatetimeAndStringSig. 
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subtime +func (b *builtinSubDatetimeAndStringSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + arg0, isNull, err := b.args[0].EvalTime(ctx, row) + if isNull || err != nil { + return types.ZeroDatetime, isNull, err + } + s, isNull, err := b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return types.ZeroDatetime, isNull, err + } + if !isDuration(s) { + return types.ZeroDatetime, true, nil + } + tc := typeCtx(ctx) + arg1, _, err := types.ParseDuration(tc, s, types.GetFsp(s)) + if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + tc.AppendWarning(err) + return types.ZeroDatetime, true, nil + } + return types.ZeroDatetime, true, err + } + result, err := arg0.Add(tc, arg1.Neg()) + return result, err != nil, err +} + +type builtinSubTimeDateTimeNullSig struct { + baseBuiltinFunc +} + +func (b *builtinSubTimeDateTimeNullSig) Clone() builtinFunc { + newSig := &builtinSubTimeDateTimeNullSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals a builtinSubTimeDateTimeNullSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subtime +func (b *builtinSubTimeDateTimeNullSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + return types.ZeroDatetime, true, nil +} + +type builtinSubStringAndDurationSig struct { + baseBuiltinFunc +} + +func (b *builtinSubStringAndDurationSig) Clone() builtinFunc { + newSig := &builtinSubStringAndDurationSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalString evals a builtinSubStringAndDurationSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subtime +func (b *builtinSubStringAndDurationSig) evalString(ctx EvalContext, row chunk.Row) (result string, isNull bool, err error) { + var ( + arg0 string + arg1 types.Duration + ) + arg0, isNull, err = b.args[0].EvalString(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + arg1, isNull, err = b.args[1].EvalDuration(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + tc := typeCtx(ctx) + if isDuration(arg0) { + result, err = strDurationSubDuration(tc, arg0, arg1) + if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + tc.AppendWarning(err) + return "", true, nil + } + return "", true, err + } + return result, false, nil + } + result, isNull, err = strDatetimeSubDuration(tc, arg0, arg1) + return result, isNull, err +} + +type builtinSubStringAndStringSig struct { + baseBuiltinFunc +} + +func (b *builtinSubStringAndStringSig) Clone() builtinFunc { + newSig := &builtinSubStringAndStringSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalString evals a builtinSubStringAndStringSig. 
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subtime +func (b *builtinSubStringAndStringSig) evalString(ctx EvalContext, row chunk.Row) (result string, isNull bool, err error) { + var ( + s, arg0 string + arg1 types.Duration + ) + arg0, isNull, err = b.args[0].EvalString(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + arg1Type := b.args[1].GetType(ctx) + if mysql.HasBinaryFlag(arg1Type.GetFlag()) { + return "", true, nil + } + s, isNull, err = b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + tc := typeCtx(ctx) + arg1, _, err = types.ParseDuration(tc, s, getFsp4TimeAddSub(s)) + if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + tc.AppendWarning(err) + return "", true, nil + } + return "", true, err + } + if isDuration(arg0) { + result, err = strDurationSubDuration(tc, arg0, arg1) + if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + tc.AppendWarning(err) + return "", true, nil + } + return "", true, err + } + return result, false, nil + } + result, isNull, err = strDatetimeSubDuration(tc, arg0, arg1) + return result, isNull, err +} + +type builtinSubTimeStringNullSig struct { + baseBuiltinFunc +} + +func (b *builtinSubTimeStringNullSig) Clone() builtinFunc { + newSig := &builtinSubTimeStringNullSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalString evals a builtinSubTimeStringNullSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subtime +func (b *builtinSubTimeStringNullSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { + return "", true, nil +} + +type builtinSubDurationAndDurationSig struct { + baseBuiltinFunc +} + +func (b *builtinSubDurationAndDurationSig) Clone() builtinFunc { + newSig := &builtinSubDurationAndDurationSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalDuration evals a builtinSubDurationAndDurationSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subtime +func (b *builtinSubDurationAndDurationSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { + arg0, isNull, err := b.args[0].EvalDuration(ctx, row) + if isNull || err != nil { + return types.ZeroDuration, isNull, err + } + arg1, isNull, err := b.args[1].EvalDuration(ctx, row) + if isNull || err != nil { + return types.ZeroDuration, isNull, err + } + result, err := arg0.Sub(arg1) + if err != nil { + return types.ZeroDuration, true, err + } + return result, false, nil +} + +type builtinSubDurationAndStringSig struct { + baseBuiltinFunc +} + +func (b *builtinSubDurationAndStringSig) Clone() builtinFunc { + newSig := &builtinSubDurationAndStringSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalDuration evals a builtinSubDurationAndStringSig. 
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subtime +func (b *builtinSubDurationAndStringSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { + arg0, isNull, err := b.args[0].EvalDuration(ctx, row) + if isNull || err != nil { + return types.ZeroDuration, isNull, err + } + s, isNull, err := b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return types.ZeroDuration, isNull, err + } + if !isDuration(s) { + return types.ZeroDuration, true, nil + } + tc := typeCtx(ctx) + arg1, _, err := types.ParseDuration(tc, s, types.GetFsp(s)) + if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + tc.AppendWarning(err) + return types.ZeroDuration, true, nil + } + return types.ZeroDuration, true, err + } + result, err := arg0.Sub(arg1) + return result, err != nil, err +} + +type builtinSubTimeDurationNullSig struct { + baseBuiltinFunc +} + +func (b *builtinSubTimeDurationNullSig) Clone() builtinFunc { + newSig := &builtinSubTimeDurationNullSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalDuration evals a builtinSubTimeDurationNullSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subtime +func (b *builtinSubTimeDurationNullSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { + return types.ZeroDuration, true, nil +} + +type builtinSubDateAndDurationSig struct { + baseBuiltinFunc +} + +func (b *builtinSubDateAndDurationSig) Clone() builtinFunc { + newSig := &builtinSubDateAndDurationSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalString evals a builtinSubDateAndDurationSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subtime +func (b *builtinSubDateAndDurationSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { + arg0, isNull, err := b.args[0].EvalDuration(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + arg1, isNull, err := b.args[1].EvalDuration(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + result, err := arg0.Sub(arg1) + return result.String(), err != nil, err +} + +type builtinSubDateAndStringSig struct { + baseBuiltinFunc +} + +func (b *builtinSubDateAndStringSig) Clone() builtinFunc { + newSig := &builtinSubDateAndStringSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalString evals a builtinSubDateAndStringSig. 
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subtime +func (b *builtinSubDateAndStringSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { + arg0, isNull, err := b.args[0].EvalDuration(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + s, isNull, err := b.args[1].EvalString(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + if !isDuration(s) { + return "", true, nil + } + tc := typeCtx(ctx) + arg1, _, err := types.ParseDuration(tc, s, getFsp4TimeAddSub(s)) + if err != nil { + if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { + tc.AppendWarning(err) + return "", true, nil + } + return "", true, err + } + result, err := arg0.Sub(arg1) + if err != nil { + return "", true, err + } + return result.String(), false, nil +} + +type timeFormatFunctionClass struct { + baseFunctionClass +} + +func (c *timeFormatFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETDuration, types.ETString) + if err != nil { + return nil, err + } + // worst case: formatMask=%r%r%r...%r, each %r takes 11 characters + bf.tp.SetFlen((args[1].GetType(ctx.GetEvalCtx()).GetFlen() + 1) / 2 * 11) + sig := &builtinTimeFormatSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_TimeFormat) + return sig, nil +} + +type builtinTimeFormatSig struct { + baseBuiltinFunc +} + +func (b *builtinTimeFormatSig) Clone() builtinFunc { + newSig := &builtinTimeFormatSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalString evals a builtinTimeFormatSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_time-format +func (b *builtinTimeFormatSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { + dur, isNull, err := b.args[0].EvalDuration(ctx, row) + // if err != nil, then dur is ZeroDuration, outputs 00:00:00 in this case which follows the behavior of mysql. + if err != nil { + logutil.BgLogger().Warn("time_format.args[0].EvalDuration failed", zap.Error(err)) + } + if isNull { + return "", isNull, err + } + formatMask, isNull, err := b.args[1].EvalString(ctx, row) + if err != nil || isNull { + return "", isNull, err + } + res, err := b.formatTime(dur, formatMask) + return res, isNull, err +} + +// formatTime see https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_time-format +func (b *builtinTimeFormatSig) formatTime(t types.Duration, formatMask string) (res string, err error) { + return t.DurationFormat(formatMask) +} + +type timeToSecFunctionClass struct { + baseFunctionClass +} + +func (c *timeToSecFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDuration) + if err != nil { + return nil, err + } + bf.tp.SetFlen(10) + sig := &builtinTimeToSecSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_TimeToSec) + return sig, nil +} + +type builtinTimeToSecSig struct { + baseBuiltinFunc +} + +func (b *builtinTimeToSecSig) Clone() builtinFunc { + newSig := &builtinTimeToSecSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals TIME_TO_SEC(time). 
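// Descriptive aside, not part of the original patch: the evalInt below applies
// the sign of the whole duration to the converted value, e.g.
// TIME_TO_SEC('22:23:00') = 22*3600 + 23*60 = 80580 and
// TIME_TO_SEC('-00:39:38') = -(39*60 + 38) = -2378.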
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_time-to-sec +func (b *builtinTimeToSecSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + duration, isNull, err := b.args[0].EvalDuration(ctx, row) + if isNull || err != nil { + return 0, isNull, err + } + var sign int + if duration.Duration >= 0 { + sign = 1 + } else { + sign = -1 + } + return int64(sign * (duration.Hour()*3600 + duration.Minute()*60 + duration.Second())), false, nil +} + +type timestampAddFunctionClass struct { + baseFunctionClass +} + +func (c *timestampAddFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETString, types.ETReal, types.ETDatetime) + if err != nil { + return nil, err + } + flen := mysql.MaxDatetimeWidthNoFsp + con, ok := args[0].(*Constant) + if !ok { + return nil, errors.New("should not happened") + } + unit, null, err := con.EvalString(ctx.GetEvalCtx(), chunk.Row{}) + if null || err != nil { + return nil, errors.New("should not happened") + } + if unit == ast.TimeUnitMicrosecond.String() { + flen = mysql.MaxDatetimeWidthWithFsp + } + + bf.tp.SetFlen(flen) + sig := &builtinTimestampAddSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_TimestampAdd) + return sig, nil +} + +type builtinTimestampAddSig struct { + baseBuiltinFunc +} + +func (b *builtinTimestampAddSig) Clone() builtinFunc { + newSig := &builtinTimestampAddSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +var ( + minDatetimeInGoTime, _ = types.MinDatetime.GoTime(time.Local) + minDatetimeNanos = float64(minDatetimeInGoTime.Unix())*1e9 + float64(minDatetimeInGoTime.Nanosecond()) + maxDatetimeInGoTime, _ = types.MaxDatetime.GoTime(time.Local) + maxDatetimeNanos = float64(maxDatetimeInGoTime.Unix())*1e9 + float64(maxDatetimeInGoTime.Nanosecond()) + minDatetimeMonths = float64(types.MinDatetime.Year()*12 + types.MinDatetime.Month() - 1) // 0001-01-01 00:00:00 + maxDatetimeMonths = float64(types.MaxDatetime.Year()*12 + types.MaxDatetime.Month() - 1) // 9999-12-31 00:00:00 +) + +func validAddTime(nano1 float64, nano2 float64) bool { + return nano1+nano2 >= minDatetimeNanos && nano1+nano2 <= maxDatetimeNanos +} + +func validAddMonth(month1 float64, year, month int) bool { + tmp := month1 + float64(year)*12 + float64(month-1) + return tmp >= minDatetimeMonths && tmp <= maxDatetimeMonths +} + +func addUnitToTime(unit string, t time.Time, v float64) (time.Time, bool, error) { + s := math.Trunc(v * 1000000) + // round to the nearest int + v = math.Round(v) + var tb time.Time + nano := float64(t.Unix())*1e9 + float64(t.Nanosecond()) + switch unit { + case "MICROSECOND": + if !validAddTime(v*float64(time.Microsecond), nano) { + return tb, true, nil + } + tb = t.Add(time.Duration(v) * time.Microsecond) + case "SECOND": + if !validAddTime(s*float64(time.Microsecond), nano) { + return tb, true, nil + } + tb = t.Add(time.Duration(s) * time.Microsecond) + case "MINUTE": + if !validAddTime(v*float64(time.Minute), nano) { + return tb, true, nil + } + tb = t.Add(time.Duration(v) * time.Minute) + case "HOUR": + if !validAddTime(v*float64(time.Hour), nano) { + return tb, true, nil + } + tb = t.Add(time.Duration(v) * time.Hour) + case "DAY": + if !validAddTime(v*24*float64(time.Hour), nano) { + return tb, true, nil + } + tb = t.AddDate(0, 0, int(v)) + case "WEEK": + if !validAddTime(v*24*7*float64(time.Hour), nano) { + 
return tb, true, nil + } + tb = t.AddDate(0, 0, 7*int(v)) + case "MONTH": + if !validAddMonth(v, t.Year(), int(t.Month())) { + return tb, true, nil + } + + var err error + tb, err = types.AddDate(0, int64(v), 0, t) + if err != nil { + return tb, false, err + } + case "QUARTER": + if !validAddMonth(v*3, t.Year(), int(t.Month())) { + return tb, true, nil + } + tb = t.AddDate(0, 3*int(v), 0) + case "YEAR": + if !validAddMonth(v*12, t.Year(), int(t.Month())) { + return tb, true, nil + } + tb = t.AddDate(int(v), 0, 0) + default: + return tb, false, types.ErrWrongValue.GenWithStackByArgs(types.TimeStr, unit) + } + return tb, false, nil +} + +// evalString evals a builtinTimestampAddSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_timestampadd +func (b *builtinTimestampAddSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { + unit, isNull, err := b.args[0].EvalString(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + v, isNull, err := b.args[1].EvalReal(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + arg, isNull, err := b.args[2].EvalTime(ctx, row) + if isNull || err != nil { + return "", isNull, err + } + tm1, err := arg.GoTime(time.Local) + if err != nil { + tc := typeCtx(ctx) + tc.AppendWarning(err) + return "", true, nil + } + tb, overflow, err := addUnitToTime(unit, tm1, v) + if err != nil { + return "", true, err + } + if overflow { + return "", true, handleInvalidTimeError(ctx, types.ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime")) + } + fsp := types.DefaultFsp + // use MaxFsp when microsecond is not zero + if tb.Nanosecond()/1000 != 0 { + fsp = types.MaxFsp + } + r := types.NewTime(types.FromGoTime(tb), b.resolveType(arg.Type(), unit), fsp) + if err = r.Check(typeCtx(ctx)); err != nil { + return "", true, handleInvalidTimeError(ctx, err) + } + return r.String(), false, nil +} + +func (b *builtinTimestampAddSig) resolveType(typ uint8, unit string) uint8 { + // The approach below is from MySQL. + // The field type for the result of an Item_date function is defined as + // follows: + // + //- If first arg is a MYSQL_TYPE_DATETIME result is MYSQL_TYPE_DATETIME + //- If first arg is a MYSQL_TYPE_DATE and the interval type uses hours, + // minutes, seconds or microsecond then type is MYSQL_TYPE_DATETIME. + //- Otherwise the result is MYSQL_TYPE_STRING + // (This is because you can't know if the string contains a DATE, MYSQL_TIME + // or DATETIME argument) + if typ == mysql.TypeDate && (unit == "HOUR" || unit == "MINUTE" || unit == "SECOND" || unit == "MICROSECOND") { + return mysql.TypeDatetime + } + return typ +} + +type toDaysFunctionClass struct { + baseFunctionClass +} + +func (c *toDaysFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDatetime) + if err != nil { + return nil, err + } + sig := &builtinToDaysSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_ToDays) + return sig, nil +} + +type builtinToDaysSig struct { + baseBuiltinFunc +} + +func (b *builtinToDaysSig) Clone() builtinFunc { + newSig := &builtinToDaysSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals a builtinToDaysSig. 
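// Illustrative aside, not part of the original patch: addUnitToTime above
// rejects results outside [types.MinDatetime, types.MaxDatetime], and the
// evalString above turns that overflow into a NULL result through
// handleInvalidTimeError. resolveType keeps MySQL's typing rule: a DATE first
// argument combined with HOUR/MINUTE/SECOND/MICROSECOND yields a DATETIME. A
// sketch of the MINUTE branch with the documented example
// TIMESTAMPADD(MINUTE, 1, '2003-01-02'):
package main

import (
	"fmt"
	"time"
)

func main() {
	t := time.Date(2003, 1, 2, 0, 0, 0, 0, time.UTC)
	fmt.Println(t.Add(1 * time.Minute).Format("2006-01-02 15:04:05")) // 2003-01-02 00:01:00
}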
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_to-days +func (b *builtinToDaysSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + arg, isNull, err := b.args[0].EvalTime(ctx, row) + + if isNull || err != nil { + return 0, true, handleInvalidTimeError(ctx, err) + } + if arg.InvalidZero() { + return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, arg.String())) + } + ret := types.TimestampDiff("DAY", types.ZeroDate, arg) + if ret == 0 { + return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, arg.String())) + } + return ret, false, nil +} + +type toSecondsFunctionClass struct { + baseFunctionClass +} + +func (c *toSecondsFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDatetime) + if err != nil { + return nil, err + } + sig := &builtinToSecondsSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_ToSeconds) + return sig, nil +} + +type builtinToSecondsSig struct { + baseBuiltinFunc +} + +func (b *builtinToSecondsSig) Clone() builtinFunc { + newSig := &builtinToSecondsSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalInt evals a builtinToSecondsSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_to-seconds +func (b *builtinToSecondsSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + arg, isNull, err := b.args[0].EvalTime(ctx, row) + if isNull || err != nil { + return 0, true, handleInvalidTimeError(ctx, err) + } + if arg.InvalidZero() { + return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, arg.String())) + } + ret := types.TimestampDiff("SECOND", types.ZeroDate, arg) + if ret == 0 { + return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, arg.String())) + } + return ret, false, nil +} + +type utcTimeFunctionClass struct { + baseFunctionClass +} + +func (c *utcTimeFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + argTps := make([]types.EvalType, 0, 1) + if len(args) == 1 { + argTps = append(argTps, types.ETInt) + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDuration, argTps...) + if err != nil { + return nil, err + } + fsp, err := getFspByIntArg(ctx, args) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForTime(fsp) + // 1. no sign. + // 2. hour is in the 2-digit range. + bf.tp.SetFlen(bf.tp.GetFlen() - 2) + + var sig builtinFunc + if len(args) == 1 { + sig = &builtinUTCTimeWithArgSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_UTCTimeWithArg) + } else { + sig = &builtinUTCTimeWithoutArgSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_UTCTimeWithoutArg) + } + return sig, nil +} + +type builtinUTCTimeWithoutArgSig struct { + baseBuiltinFunc +} + +func (b *builtinUTCTimeWithoutArgSig) Clone() builtinFunc { + newSig := &builtinUTCTimeWithoutArgSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalDuration evals a builtinUTCTimeWithoutArgSig. 
+// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_utc-time +func (b *builtinUTCTimeWithoutArgSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { + nowTs, err := getStmtTimestamp(ctx) + if err != nil { + return types.Duration{}, true, err + } + v, _, err := types.ParseDuration(typeCtx(ctx), nowTs.UTC().Format(types.TimeFormat), 0) + return v, false, err +} + +type builtinUTCTimeWithArgSig struct { + baseBuiltinFunc +} + +func (b *builtinUTCTimeWithArgSig) Clone() builtinFunc { + newSig := &builtinUTCTimeWithArgSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalDuration evals a builtinUTCTimeWithArgSig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_utc-time +func (b *builtinUTCTimeWithArgSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { + fsp, isNull, err := b.args[0].EvalInt(ctx, row) + if isNull || err != nil { + return types.Duration{}, isNull, err + } + if fsp > int64(types.MaxFsp) { + return types.Duration{}, true, errors.Errorf("Too-big precision %v specified for 'utc_time'. Maximum is %v", fsp, types.MaxFsp) + } + if fsp < int64(types.MinFsp) { + return types.Duration{}, true, errors.Errorf("Invalid negative %d specified, must in [0, 6]", fsp) + } + nowTs, err := getStmtTimestamp(ctx) + if err != nil { + return types.Duration{}, true, err + } + v, _, err := types.ParseDuration(typeCtx(ctx), nowTs.UTC().Format(types.TimeFSPFormat), int(fsp)) + return v, false, err +} + +type lastDayFunctionClass struct { + baseFunctionClass +} + +func (c *lastDayFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, types.ETDatetime) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForDate() + sig := &builtinLastDaySig{bf} + sig.setPbCode(tipb.ScalarFuncSig_LastDay) + return sig, nil +} + +type builtinLastDaySig struct { + baseBuiltinFunc +} + +func (b *builtinLastDaySig) Clone() builtinFunc { + newSig := &builtinLastDaySig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals a builtinLastDaySig. +// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_last-day +func (b *builtinLastDaySig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + arg, isNull, err := b.args[0].EvalTime(ctx, row) + if isNull || err != nil { + return types.ZeroTime, true, handleInvalidTimeError(ctx, err) + } + tm := arg + year, month := tm.Year(), tm.Month() + if tm.Month() == 0 || (tm.Day() == 0 && sqlMode(ctx).HasNoZeroDateMode()) { + return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, arg.String())) + } + lastDay := types.GetLastDay(year, month) + ret := types.NewTime(types.FromDate(year, month, lastDay, 0, 0, 0, 0), mysql.TypeDate, types.DefaultFsp) + return ret, false, nil +} + +// getExpressionFsp calculates the fsp from given expression. +// This function must by called before calling newBaseBuiltinFuncWithTp. 
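// Descriptive aside, not part of the original patch: the LAST_DAY evaluator
// above returns NULL for a zero month, and also for a zero day when
// NO_ZERO_DATE is in effect; otherwise it keeps the year and month and
// substitutes types.GetLastDay, e.g. LAST_DAY('2003-02-05') -> '2003-02-28'.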
+func getExpressionFsp(ctx BuildContext, expression Expression) (int, error) { + constExp, isConstant := expression.(*Constant) + if isConstant { + str, isNil, err := constExp.EvalString(ctx.GetEvalCtx(), chunk.Row{}) + if isNil || err != nil { + return 0, err + } + return types.GetFsp(str), nil + } + warpExpr := WrapWithCastAsTime(ctx, expression, types.NewFieldType(mysql.TypeDatetime)) + return mathutil.Min(warpExpr.GetType(ctx.GetEvalCtx()).GetDecimal(), types.MaxFsp), nil +} + +// tidbParseTsoFunctionClass extracts physical time from a tso +type tidbParseTsoFunctionClass struct { + baseFunctionClass +} + +func (c *tidbParseTsoFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + argTp := args[0].GetType(ctx.GetEvalCtx()).EvalType() + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, argTp, types.ETInt) + if err != nil { + return nil, err + } + + bf.tp.SetType(mysql.TypeDatetime) + bf.tp.SetFlen(mysql.MaxDateWidth) + bf.tp.SetDecimal(types.DefaultFsp) + sig := &builtinTidbParseTsoSig{bf} + return sig, nil +} + +type builtinTidbParseTsoSig struct { + baseBuiltinFunc +} + +func (b *builtinTidbParseTsoSig) Clone() builtinFunc { + newSig := &builtinTidbParseTsoSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals a builtinTidbParseTsoSig. +func (b *builtinTidbParseTsoSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + arg, isNull, err := b.args[0].EvalInt(ctx, row) + if isNull || err != nil || arg <= 0 { + return types.ZeroTime, true, handleInvalidTimeError(ctx, err) + } + + t := oracle.GetTimeFromTS(uint64(arg)) + result := types.NewTime(types.FromGoTime(t), mysql.TypeDatetime, types.MaxFsp) + err = result.ConvertTimeZone(time.Local, location(ctx)) + if err != nil { + return types.ZeroTime, true, err + } + return result, false, nil +} + +// tidbParseTsoFunctionClass extracts logical time from a tso +type tidbParseTsoLogicalFunctionClass struct { + baseFunctionClass +} + +func (c *tidbParseTsoLogicalFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETInt) + if err != nil { + return nil, err + } + + sig := &builtinTidbParseTsoLogicalSig{bf} + return sig, nil +} + +type builtinTidbParseTsoLogicalSig struct { + baseBuiltinFunc +} + +func (b *builtinTidbParseTsoLogicalSig) Clone() builtinFunc { + newSig := &builtinTidbParseTsoLogicalSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalTime evals a builtinTidbParseTsoLogicalSig. +func (b *builtinTidbParseTsoLogicalSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { + arg, isNull, err := b.args[0].EvalInt(ctx, row) + if isNull || err != nil || arg <= 0 { + return 0, true, err + } + + t := oracle.ExtractLogical(uint64(arg)) + return t, false, nil +} + +// tidbBoundedStalenessFunctionClass reads a time window [a, b] and compares it with the latest SafeTS +// to determine which TS to use in a read only transaction. 
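Both TSO builtins above split a 64-bit timestamp-oracle value into a physical and a logical half. A standalone sketch of that split, assuming the conventional PD layout in which the low 18 bits hold the logical counter and the remaining high bits hold a physical timestamp in milliseconds (the 18-bit width is an assumption of the sketch):

package main

import (
	"fmt"
	"time"
)

const logicalBits = 18 // assumed PD TSO layout: low 18 bits = logical counter

// composeTSO packs a wall-clock time and a logical counter into one TSO.
func composeTSO(physical time.Time, logical int64) uint64 {
	return uint64(physical.UnixMilli())<<logicalBits | uint64(logical)
}

// parseTSO undoes composeTSO, mirroring what TIDB_PARSE_TSO and the logical
// variant extract from the same input.
func parseTSO(ts uint64) (physical time.Time, logical int64) {
	return time.UnixMilli(int64(ts >> logicalBits)), int64(ts & ((1 << logicalBits) - 1))
}

func main() {
	ts := composeTSO(time.Date(2024, 8, 7, 9, 39, 51, 0, time.UTC), 7)
	physical, logical := parseTSO(ts)
	fmt.Println(physical.UTC(), logical) // 2024-08-07 09:39:51 +0000 UTC 7
}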
+type tidbBoundedStalenessFunctionClass struct { + baseFunctionClass +} + +func (c *tidbBoundedStalenessFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, types.ETDatetime, types.ETDatetime) + if err != nil { + return nil, err + } + bf.setDecimalAndFlenForDatetime(3) + sig := &builtinTiDBBoundedStalenessSig{baseBuiltinFunc: bf} + return sig, nil +} + +type builtinTiDBBoundedStalenessSig struct { + baseBuiltinFunc + contextopt.SessionVarsPropReader + contextopt.KVStorePropReader +} + +// RequiredOptionalEvalProps implements the RequireOptionalEvalProps interface. +func (b *builtinTiDBBoundedStalenessSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet { + return b.SessionVarsPropReader.RequiredOptionalEvalProps() | + b.KVStorePropReader.RequiredOptionalEvalProps() +} + +func (b *builtinTiDBBoundedStalenessSig) Clone() builtinFunc { + newSig := &builtinTidbParseTsoSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinTiDBBoundedStalenessSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { + store, err := b.GetKVStore(ctx) + if err != nil { + return types.ZeroTime, true, err + } + + vars, err := b.GetSessionVars(ctx) + if err != nil { + return types.ZeroTime, true, err + } + + leftTime, isNull, err := b.args[0].EvalTime(ctx, row) + if isNull || err != nil { + return types.ZeroTime, true, handleInvalidTimeError(ctx, err) + } + rightTime, isNull, err := b.args[1].EvalTime(ctx, row) + if isNull || err != nil { + return types.ZeroTime, true, handleInvalidTimeError(ctx, err) + } + if invalidLeftTime, invalidRightTime := leftTime.InvalidZero(), rightTime.InvalidZero(); invalidLeftTime || invalidRightTime { + if invalidLeftTime { + err = handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, leftTime.String())) + } + if invalidRightTime { + err = handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, rightTime.String())) + } + return types.ZeroTime, true, err + } + timeZone := getTimeZone(ctx) + minTime, err := leftTime.GoTime(timeZone) + if err != nil { + return types.ZeroTime, true, err + } + maxTime, err := rightTime.GoTime(timeZone) + if err != nil { + return types.ZeroTime, true, err + } + if minTime.After(maxTime) { + return types.ZeroTime, true, nil + } + // Because the minimum unit of a TSO is millisecond, so we only need fsp to be 3. + return types.NewTime(types.FromGoTime(calAppropriateTime(minTime, maxTime, GetStmtMinSafeTime(vars.StmtCtx, store, timeZone))), mysql.TypeDatetime, 3), false, nil +} + +// GetStmtMinSafeTime get minSafeTime +func GetStmtMinSafeTime(sc *stmtctx.StatementContext, store kv.Storage, timeZone *time.Location) time.Time { + var minSafeTS uint64 + txnScope := config.GetTxnScopeFromConfig() + if store != nil { + minSafeTS = store.GetMinSafeTS(txnScope) + } + // Inject mocked SafeTS for test. + failpoint.Inject("injectSafeTS", func(val failpoint.Value) { + injectTS := val.(int) + minSafeTS = uint64(injectTS) + }) + // Try to get from the stmt cache to make sure this function is deterministic. 
+ minSafeTS = sc.GetOrStoreStmtCache(stmtctx.StmtSafeTSCacheKey, minSafeTS).(uint64) + return oracle.GetTimeFromTS(minSafeTS).In(timeZone) +} + +// CalAppropriateTime directly calls calAppropriateTime +func CalAppropriateTime(minTime, maxTime, minSafeTime time.Time) time.Time { + return calAppropriateTime(minTime, maxTime, minSafeTime) +} + +// For a SafeTS t and a time range [t1, t2]: +// 1. If t < t1, we will use t1 as the result, +// and with it, a read request may fail because it's an unreached SafeTS. +// 2. If t1 <= t <= t2, we will use t as the result, and with it, +// a read request won't fail. +// 2. If t2 < t, we will use t2 as the result, +// and with it, a read request won't fail because it's bigger than the latest SafeTS. +func calAppropriateTime(minTime, maxTime, minSafeTime time.Time) time.Time { + if minSafeTime.Before(minTime) || minSafeTime.After(maxTime) { + logutil.BgLogger().Debug("calAppropriateTime", + zap.Time("minTime", minTime), + zap.Time("maxTime", maxTime), + zap.Time("minSafeTime", minSafeTime)) + if minSafeTime.Before(minTime) { + return minTime + } else if minSafeTime.After(maxTime) { + return maxTime + } + } + logutil.BgLogger().Debug("calAppropriateTime", + zap.Time("minTime", minTime), + zap.Time("maxTime", maxTime), + zap.Time("minSafeTime", minSafeTime)) + return minSafeTime +} + +// getFspByIntArg is used by some time functions to get the result fsp. If len(expr) == 0, then the fsp is not explicit set, use 0 as default. +func getFspByIntArg(ctx BuildContext, exps []Expression) (int, error) { + if len(exps) == 0 { + return 0, nil + } + if len(exps) != 1 { + return 0, errors.Errorf("Should not happen, the num of argument should be 1, but got %d", len(exps)) + } + _, ok := exps[0].(*Constant) + if ok { + fsp, isNuLL, err := exps[0].EvalInt(ctx.GetEvalCtx(), chunk.Row{}) + if err != nil || isNuLL { + // If isNULL, it may be a bug of parser. Return 0 to be compatible with old version. + return 0, err + } + if fsp > int64(types.MaxFsp) { + return 0, errors.Errorf("Too-big precision %v specified for 'curtime'. Maximum is %v", fsp, types.MaxFsp) + } else if fsp < int64(types.MinFsp) { + return 0, errors.Errorf("Invalid negative %d specified, must in [0, 6]", fsp) + } + return int(fsp), nil + } + // Should no happen. But our tests may generate non-constant input. + return 0, nil +} + +type tidbCurrentTsoFunctionClass struct { + baseFunctionClass +} + +func (c *tidbCurrentTsoFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, err + } + bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt) + if err != nil { + return nil, err + } + sig := &builtinTiDBCurrentTsoSig{baseBuiltinFunc: bf} + return sig, nil +} + +type builtinTiDBCurrentTsoSig struct { + baseBuiltinFunc + contextopt.SessionVarsPropReader +} + +func (b *builtinTiDBCurrentTsoSig) Clone() builtinFunc { + newSig := &builtinTiDBCurrentTsoSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +func (b *builtinTiDBCurrentTsoSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet { + return b.SessionVarsPropReader.RequiredOptionalEvalProps() +} + +// evalInt evals currentTSO(). 
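calAppropriateTime above is a three-way clamp of the SafeTS into the requested window (the last case in its comment is numbered "2." but is really the third case). A minimal restatement that makes the cases easy to exercise in isolation:

package main

import (
	"fmt"
	"time"
)

// clampSafeTime mirrors calAppropriateTime's rule: the resolved read time is
// minSafeTime clamped into [minTime, maxTime].
func clampSafeTime(minTime, maxTime, minSafeTime time.Time) time.Time {
	if minSafeTime.Before(minTime) {
		return minTime // SafeTS lags the window; such a read may have to wait or fail
	}
	if minSafeTime.After(maxTime) {
		return maxTime // SafeTS is ahead of the window; the newest allowed bound wins
	}
	return minSafeTime
}

func main() {
	lo := time.Date(2024, 8, 7, 9, 0, 0, 0, time.UTC)
	hi := lo.Add(10 * time.Second)
	fmt.Println(clampSafeTime(lo, hi, lo.Add(-time.Second)).Equal(lo)) // true
	fmt.Println(clampSafeTime(lo, hi, lo.Add(3*time.Second)))          // 2024-08-07 09:00:03 +0000 UTC
	fmt.Println(clampSafeTime(lo, hi, hi.Add(time.Minute)).Equal(hi))  // true
}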
+func (b *builtinTiDBCurrentTsoSig) evalInt(ctx EvalContext, row chunk.Row) (val int64, isNull bool, err error) { + sessionVars, err := b.GetSessionVars(ctx) + if err != nil { + return 0, true, err + } + tso, _ := sessionVars.GetSessionOrGlobalSystemVar(context.Background(), "tidb_current_ts") + itso, _ := strconv.ParseInt(tso, 10, 64) + return itso, false, nil +} diff --git a/pkg/expression/expr_to_pb.go b/pkg/expression/expr_to_pb.go index 4c74eedb5e441..c4b853f85919e 100644 --- a/pkg/expression/expr_to_pb.go +++ b/pkg/expression/expr_to_pb.go @@ -250,9 +250,9 @@ func (pc PbConverter) scalarFuncToPBExpr(expr *ScalarFunction) *tipb.Expr { // Check whether this function has ProtoBuf signature. pbCode := expr.Function.PbCode() if pbCode <= tipb.ScalarFuncSig_Unspecified { - failpoint.Inject("PanicIfPbCodeUnspecified", func() { + if _, _err_ := failpoint.Eval(_curpkg_("PanicIfPbCodeUnspecified")); _err_ == nil { panic(errors.Errorf("unspecified PbCode: %T", expr.Function)) - }) + } return nil } diff --git a/pkg/expression/expr_to_pb.go__failpoint_stash__ b/pkg/expression/expr_to_pb.go__failpoint_stash__ new file mode 100644 index 0000000000000..4c74eedb5e441 --- /dev/null +++ b/pkg/expression/expr_to_pb.go__failpoint_stash__ @@ -0,0 +1,319 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "strconv" + + "github.com/gogo/protobuf/proto" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/mysql" + ast "github.com/pingcap/tidb/pkg/parser/types" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/collate" + "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tipb/go-tipb" + "go.uber.org/zap" +) + +// ExpressionsToPBList converts expressions to tipb.Expr list for new plan. +func ExpressionsToPBList(ctx EvalContext, exprs []Expression, client kv.Client) (pbExpr []*tipb.Expr, err error) { + pc := PbConverter{client: client, ctx: ctx} + for _, expr := range exprs { + v := pc.ExprToPB(expr) + if v == nil { + return nil, plannererrors.ErrInternal.GenWithStack("expression %v cannot be pushed down", expr.StringWithCtx(ctx, errors.RedactLogDisable)) + } + pbExpr = append(pbExpr, v) + } + return +} + +// ProjectionExpressionsToPBList converts PhysicalProjection's expressions to tipb.Expr list for new plan. 
+// It doesn't check type for top level column expression, since top level column expression doesn't imply any calculations +func ProjectionExpressionsToPBList(ctx EvalContext, exprs []Expression, client kv.Client) (pbExpr []*tipb.Expr, err error) { + pc := PbConverter{client: client, ctx: ctx} + for _, expr := range exprs { + var v *tipb.Expr + if column, ok := expr.(*Column); ok { + v = pc.columnToPBExpr(column, false) + } else { + v = pc.ExprToPB(expr) + } + if v == nil { + return nil, plannererrors.ErrInternal.GenWithStack("expression %v cannot be pushed down", expr.StringWithCtx(ctx, errors.RedactLogDisable)) + } + pbExpr = append(pbExpr, v) + } + return +} + +// PbConverter supplies methods to convert TiDB expressions to TiPB. +type PbConverter struct { + client kv.Client + ctx EvalContext +} + +// NewPBConverter creates a PbConverter. +func NewPBConverter(client kv.Client, ctx EvalContext) PbConverter { + return PbConverter{client: client, ctx: ctx} +} + +// ExprToPB converts Expression to TiPB. +func (pc PbConverter) ExprToPB(expr Expression) *tipb.Expr { + switch x := expr.(type) { + case *Constant: + pbExpr := pc.conOrCorColToPBExpr(expr) + if pbExpr == nil { + return nil + } + return pbExpr + case *CorrelatedColumn: + return pc.conOrCorColToPBExpr(expr) + case *Column: + return pc.columnToPBExpr(x, true) + case *ScalarFunction: + return pc.scalarFuncToPBExpr(x) + } + return nil +} + +func (pc PbConverter) conOrCorColToPBExpr(expr Expression) *tipb.Expr { + ft := expr.GetType(pc.ctx) + d, err := expr.Eval(pc.ctx, chunk.Row{}) + if err != nil { + logutil.BgLogger().Error("eval constant or correlated column", zap.String("expression", expr.ExplainInfo(pc.ctx)), zap.Error(err)) + return nil + } + tp, val, ok := pc.encodeDatum(ft, d) + if !ok { + return nil + } + + if !pc.client.IsRequestTypeSupported(kv.ReqTypeSelect, int64(tp)) { + return nil + } + return &tipb.Expr{Tp: tp, Val: val, FieldType: ToPBFieldType(ft)} +} + +func (pc *PbConverter) encodeDatum(ft *types.FieldType, d types.Datum) (tipb.ExprType, []byte, bool) { + var ( + tp tipb.ExprType + val []byte + ) + switch d.Kind() { + case types.KindNull: + tp = tipb.ExprType_Null + case types.KindInt64: + tp = tipb.ExprType_Int64 + val = codec.EncodeInt(nil, d.GetInt64()) + case types.KindUint64: + tp = tipb.ExprType_Uint64 + val = codec.EncodeUint(nil, d.GetUint64()) + case types.KindString, types.KindBinaryLiteral: + tp = tipb.ExprType_String + val = d.GetBytes() + case types.KindMysqlBit: + tp = tipb.ExprType_MysqlBit + val = d.GetBytes() + case types.KindBytes: + tp = tipb.ExprType_Bytes + val = d.GetBytes() + case types.KindFloat32: + tp = tipb.ExprType_Float32 + val = codec.EncodeFloat(nil, d.GetFloat64()) + case types.KindFloat64: + tp = tipb.ExprType_Float64 + val = codec.EncodeFloat(nil, d.GetFloat64()) + case types.KindMysqlDuration: + tp = tipb.ExprType_MysqlDuration + val = codec.EncodeInt(nil, int64(d.GetMysqlDuration().Duration)) + case types.KindMysqlDecimal: + tp = tipb.ExprType_MysqlDecimal + var err error + // Use precision and fraction from MyDecimal instead of the ones in datum itself. + // These two set of parameters are not the same. MyDecimal is compatible with MySQL + // so the precision and fraction from MyDecimal are consistent with MySQL. The other + // ones come from the column type which belongs to the output schema. Here the datum + // are encoded into protobuf and will be used to do calculation so it should use the + // MyDecimal precision and fraction otherwise there may be a loss of accuracy. 
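encodeDatum above pairs each datum kind with a tipb.ExprType tag plus a comparably encoded payload. A minimal sketch of that pairing for the simplest case, an int64 constant, reusing the same codec and tipb calls that appear in this hunk (the full conOrCorColToPBExpr path additionally attaches a FieldType via ToPBFieldType and checks client support):

package demo

import (
	"github.com/pingcap/tidb/pkg/util/codec"
	"github.com/pingcap/tipb/go-tipb"
)

// intConstToPB builds the (tp, val) pair that encodeDatum produces for an
// int64 datum: ExprType_Int64 plus the comparable integer encoding.
func intConstToPB(v int64) *tipb.Expr {
	return &tipb.Expr{
		Tp:  tipb.ExprType_Int64,
		Val: codec.EncodeInt(nil, v),
	}
}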
+ val, err = codec.EncodeDecimal(nil, d.GetMysqlDecimal(), 0, 0) + if err != nil { + logutil.BgLogger().Error("encode decimal", zap.Error(err)) + return tp, nil, false + } + case types.KindMysqlTime: + if pc.client.IsRequestTypeSupported(kv.ReqTypeDAG, int64(tipb.ExprType_MysqlTime)) { + tp = tipb.ExprType_MysqlTime + tc, ec := typeCtx(pc.ctx), errCtx(pc.ctx) + val, err := codec.EncodeMySQLTime(tc.Location(), d.GetMysqlTime(), ft.GetType(), nil) + err = ec.HandleError(err) + if err != nil { + logutil.BgLogger().Error("encode mysql time", zap.Error(err)) + return tp, nil, false + } + return tp, val, true + } + return tp, nil, false + case types.KindMysqlEnum: + tp = tipb.ExprType_MysqlEnum + val = codec.EncodeUint(nil, d.GetUint64()) + default: + return tp, nil, false + } + return tp, val, true +} + +// ToPBFieldType converts *types.FieldType to *tipb.FieldType. +func ToPBFieldType(ft *types.FieldType) *tipb.FieldType { + return &tipb.FieldType{ + Tp: int32(ft.GetType()), + Flag: uint32(ft.GetFlag()), + Flen: int32(ft.GetFlen()), + Decimal: int32(ft.GetDecimal()), + Charset: ft.GetCharset(), + Collate: collate.CollationToProto(ft.GetCollate()), + Elems: ft.GetElems(), + } +} + +// ToPBFieldTypeWithCheck converts *types.FieldType to *tipb.FieldType with checking the valid decimal for TiFlash +func ToPBFieldTypeWithCheck(ft *types.FieldType, storeType kv.StoreType) (*tipb.FieldType, error) { + if storeType == kv.TiFlash && !ft.IsDecimalValid() { + return nil, errors.New(ft.String() + " can not be pushed to TiFlash because it contains invalid decimal('" + strconv.Itoa(ft.GetFlen()) + "','" + strconv.Itoa(ft.GetDecimal()) + "').") + } + return ToPBFieldType(ft), nil +} + +// FieldTypeFromPB converts *tipb.FieldType to *types.FieldType. +func FieldTypeFromPB(ft *tipb.FieldType) *types.FieldType { + ft1 := types.NewFieldTypeBuilder().SetType(byte(ft.Tp)).SetFlag(uint(ft.Flag)).SetFlen(int(ft.Flen)).SetDecimal(int(ft.Decimal)).SetCharset(ft.Charset).SetCollate(collate.ProtoToCollation(ft.Collate)).BuildP() + ft1.SetElems(ft.Elems) + return ft1 +} + +func (pc PbConverter) columnToPBExpr(column *Column, checkType bool) *tipb.Expr { + if !pc.client.IsRequestTypeSupported(kv.ReqTypeSelect, int64(tipb.ExprType_ColumnRef)) { + return nil + } + if checkType { + switch column.GetType(pc.ctx).GetType() { + case mysql.TypeBit: + if !IsPushDownEnabled(ast.TypeStr(mysql.TypeBit), kv.TiKV) { + return nil + } + case mysql.TypeSet, mysql.TypeGeometry, mysql.TypeUnspecified: + return nil + case mysql.TypeEnum: + if !IsPushDownEnabled("enum", kv.UnSpecified) { + return nil + } + } + } + + if pc.client.IsRequestTypeSupported(kv.ReqTypeDAG, kv.ReqSubTypeBasic) { + return &tipb.Expr{ + Tp: tipb.ExprType_ColumnRef, + Val: codec.EncodeInt(nil, int64(column.Index)), + FieldType: ToPBFieldType(column.RetType), + } + } + id := column.ID + // Zero Column ID is not a column from table, can not support for now. + if id == 0 || id == -1 { + return nil + } + + return &tipb.Expr{ + Tp: tipb.ExprType_ColumnRef, + Val: codec.EncodeInt(nil, id)} +} + +func (pc PbConverter) scalarFuncToPBExpr(expr *ScalarFunction) *tipb.Expr { + // Check whether this function has ProtoBuf signature. + pbCode := expr.Function.PbCode() + if pbCode <= tipb.ScalarFuncSig_Unspecified { + failpoint.Inject("PanicIfPbCodeUnspecified", func() { + panic(errors.Errorf("unspecified PbCode: %T", expr.Function)) + }) + return nil + } + + // Check whether this function can be pushed. 
+ if !canFuncBePushed(pc.ctx, expr, kv.UnSpecified) { + return nil + } + + // Check whether all of its parameters can be pushed. + children := make([]*tipb.Expr, 0, len(expr.GetArgs())) + for _, arg := range expr.GetArgs() { + pbArg := pc.ExprToPB(arg) + if pbArg == nil { + return nil + } + children = append(children, pbArg) + } + + var encoded []byte + if metadata := expr.Function.metadata(); metadata != nil { + var err error + encoded, err = proto.Marshal(metadata) + if err != nil { + logutil.BgLogger().Error("encode metadata", zap.Any("metadata", metadata), zap.Error(err)) + return nil + } + } + + // put collation information into the RetType enforcedly and push it down to TiKV/MockTiKV + tp := *expr.RetType + if collate.NewCollationEnabled() { + _, str1 := expr.CharsetAndCollation() + tp.SetCollate(str1) + } + + // Construct expression ProtoBuf. + return &tipb.Expr{ + Tp: tipb.ExprType_ScalarFunc, + Val: encoded, + Sig: pbCode, + Children: children, + FieldType: ToPBFieldType(&tp), + } +} + +// GroupByItemToPB converts group by items to pb. +func GroupByItemToPB(ctx EvalContext, client kv.Client, expr Expression) *tipb.ByItem { + pc := PbConverter{client: client, ctx: ctx} + e := pc.ExprToPB(expr) + if e == nil { + return nil + } + return &tipb.ByItem{Expr: e} +} + +// SortByItemToPB converts order by items to pb. +func SortByItemToPB(ctx EvalContext, client kv.Client, expr Expression, desc bool) *tipb.ByItem { + pc := PbConverter{client: client, ctx: ctx} + e := pc.ExprToPB(expr) + if e == nil { + return nil + } + return &tipb.ByItem{Expr: e, Desc: desc} +} diff --git a/pkg/expression/helper.go b/pkg/expression/helper.go index 154d599142d95..26e4c11443098 100644 --- a/pkg/expression/helper.go +++ b/pkg/expression/helper.go @@ -162,9 +162,9 @@ func GetTimeValue(ctx BuildContext, v any, tp byte, fsp int, explicitTz *time.Lo // if timestamp session variable set, use session variable as current time, otherwise use cached time // during one sql statement, the "current_time" should be the same func getStmtTimestamp(ctx EvalContext) (time.Time, error) { - failpoint.Inject("injectNow", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("injectNow")); _err_ == nil { v := time.Unix(int64(val.(int)), 0) - failpoint.Return(v, nil) - }) + return v, nil + } return ctx.CurrentTime() } diff --git a/pkg/expression/helper.go__failpoint_stash__ b/pkg/expression/helper.go__failpoint_stash__ new file mode 100644 index 0000000000000..154d599142d95 --- /dev/null +++ b/pkg/expression/helper.go__failpoint_stash__ @@ -0,0 +1,170 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
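The helper.go hunk above, together with the *__failpoint_stash__ copies added alongside it, looks like the standard failpoint-ctl expansion: every failpoint.Inject marker is rewritten into an explicit failpoint.Eval check so that a disabled point costs only a registry lookup. A sketch of how a test would drive the injectNow point after this rewrite; the full failpoint path and the test name are illustrative, not taken from this patch:

package expression_test

import (
	"testing"

	"github.com/pingcap/failpoint"
	"github.com/stretchr/testify/require"
)

func TestStmtTimestampIsInjectable(t *testing.T) {
	fp := "github.com/pingcap/tidb/pkg/expression/injectNow"
	require.NoError(t, failpoint.Enable(fp, `return(1)`))
	defer func() { require.NoError(t, failpoint.Disable(fp)) }()
	// While enabled, getStmtTimestamp takes the failpoint branch and returns
	// time.Unix(1, 0) instead of the cached statement time.
}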
+ +package expression + +import ( + "math" + "strings" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/types" + driver "github.com/pingcap/tidb/pkg/types/parser_driver" +) + +func boolToInt64(v bool) int64 { + if v { + return 1 + } + return 0 +} + +// IsValidCurrentTimestampExpr returns true if exprNode is a valid CurrentTimestamp expression. +// Here `valid` means it is consistent with the given fieldType's decimal. +func IsValidCurrentTimestampExpr(exprNode ast.ExprNode, fieldType *types.FieldType) bool { + fn, isFuncCall := exprNode.(*ast.FuncCallExpr) + if !isFuncCall || fn.FnName.L != ast.CurrentTimestamp { + return false + } + + containsArg := len(fn.Args) > 0 + // Fsp represents fractional seconds precision. + containsFsp := fieldType != nil && fieldType.GetDecimal() > 0 + var isConsistent bool + if containsArg { + v, ok := fn.Args[0].(*driver.ValueExpr) + isConsistent = ok && fieldType != nil && v.Datum.GetInt64() == int64(fieldType.GetDecimal()) + } + + return (containsArg && isConsistent) || (!containsArg && !containsFsp) +} + +// GetTimeCurrentTimestamp is used for generating a timestamp for some special cases: cast null value to timestamp type with not null flag. +func GetTimeCurrentTimestamp(ctx EvalContext, tp byte, fsp int) (d types.Datum, err error) { + var t types.Time + t, err = getTimeCurrentTimeStamp(ctx, tp, fsp) + if err != nil { + return d, err + } + d.SetMysqlTime(t) + return d, nil +} + +func getTimeCurrentTimeStamp(ctx EvalContext, tp byte, fsp int) (t types.Time, err error) { + value := types.NewTime(types.ZeroCoreTime, tp, fsp) + defaultTime, err := getStmtTimestamp(ctx) + if err != nil { + return value, err + } + value.SetCoreTime(types.FromGoTime(defaultTime.Truncate(time.Duration(math.Pow10(9-fsp)) * time.Nanosecond))) + if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime || tp == mysql.TypeDate { + err = value.ConvertTimeZone(time.Local, ctx.Location()) + if err != nil { + return value, err + } + } + return value, nil +} + +// GetTimeValue gets the time value with type tp. 
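getTimeCurrentTimeStamp above rounds the statement timestamp down to the requested fractional-second precision by truncating to 10^(9-fsp) nanoseconds. A standalone sketch of just that step:

package main

import (
	"fmt"
	"math"
	"time"
)

// truncateToFsp keeps only fsp fractional-second digits, using the same
// Truncate(time.Duration(math.Pow10(9-fsp)) * time.Nanosecond) trick as above.
func truncateToFsp(t time.Time, fsp int) time.Time {
	return t.Truncate(time.Duration(math.Pow10(9-fsp)) * time.Nanosecond)
}

func main() {
	t := time.Date(2024, 8, 7, 9, 39, 51, 123456789, time.UTC)
	fmt.Println(truncateToFsp(t, 3)) // 2024-08-07 09:39:51.123 +0000 UTC
	fmt.Println(truncateToFsp(t, 0)) // 2024-08-07 09:39:51 +0000 UTC
}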
+func GetTimeValue(ctx BuildContext, v any, tp byte, fsp int, explicitTz *time.Location) (d types.Datum, err error) { + var value types.Time + tc := ctx.GetEvalCtx().TypeCtx() + if explicitTz != nil { + tc = tc.WithLocation(explicitTz) + } + + switch x := v.(type) { + case string: + lowerX := strings.ToLower(x) + switch lowerX { + case ast.CurrentTimestamp: + if value, err = getTimeCurrentTimeStamp(ctx.GetEvalCtx(), tp, fsp); err != nil { + return d, err + } + case ast.CurrentDate: + if value, err = getTimeCurrentTimeStamp(ctx.GetEvalCtx(), tp, fsp); err != nil { + return d, err + } + yy, mm, dd := value.Year(), value.Month(), value.Day() + truncated := types.FromDate(yy, mm, dd, 0, 0, 0, 0) + value.SetCoreTime(truncated) + case types.ZeroDatetimeStr: + value, err = types.ParseTimeFromNum(tc, 0, tp, fsp) + terror.Log(err) + default: + value, err = types.ParseTime(tc, x, tp, fsp) + if err != nil { + return d, err + } + } + case *driver.ValueExpr: + switch x.Kind() { + case types.KindString: + value, err = types.ParseTime(tc, x.GetString(), tp, fsp) + if err != nil { + return d, err + } + case types.KindInt64: + value, err = types.ParseTimeFromNum(tc, x.GetInt64(), tp, fsp) + if err != nil { + return d, err + } + case types.KindNull: + return d, nil + default: + return d, errDefaultValue + } + case *ast.FuncCallExpr: + if x.FnName.L == ast.CurrentTimestamp || x.FnName.L == ast.CurrentDate { + d.SetString(strings.ToUpper(x.FnName.L), mysql.DefaultCollationName) + return d, nil + } + return d, errDefaultValue + case *ast.UnaryOperationExpr: + // support some expression, like `-1` + v, err := EvalSimpleAst(ctx, x) + if err != nil { + return d, err + } + ft := types.NewFieldType(mysql.TypeLonglong) + xval, err := v.ConvertTo(tc, ft) + if err != nil { + return d, err + } + + value, err = types.ParseTimeFromNum(tc, xval.GetInt64(), tp, fsp) + if err != nil { + return d, err + } + default: + return d, nil + } + d.SetMysqlTime(value) + return d, nil +} + +// if timestamp session variable set, use session variable as current time, otherwise use cached time +// during one sql statement, the "current_time" should be the same +func getStmtTimestamp(ctx EvalContext) (time.Time, error) { + failpoint.Inject("injectNow", func(val failpoint.Value) { + v := time.Unix(int64(val.(int)), 0) + failpoint.Return(v, nil) + }) + return ctx.CurrentTime() +} diff --git a/pkg/expression/infer_pushdown.go b/pkg/expression/infer_pushdown.go index 4af5c0e912a8d..9ab3672594ad1 100644 --- a/pkg/expression/infer_pushdown.go +++ b/pkg/expression/infer_pushdown.go @@ -47,19 +47,19 @@ func canFuncBePushed(ctx EvalContext, sf *ScalarFunction, storeType kv.StoreType // Push down all expression if the `failpoint expression` is `all`, otherwise, check // whether scalar function's name is contained in the enabled expression list (e.g.`ne,eq,lt`). // If neither of the above is true, switch to original logic. - failpoint.Inject("PushDownTestSwitcher", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("PushDownTestSwitcher")); _err_ == nil { enabled := val.(string) if enabled == "all" { - failpoint.Return(true) + return true } exprs := strings.Split(enabled, ",") for _, expr := range exprs { if strings.ToLower(strings.TrimSpace(expr)) == sf.FuncName.L { - failpoint.Return(true) + return true } } - failpoint.Return(false) - }) + return false + } ret := false @@ -87,9 +87,9 @@ func canScalarFuncPushDown(ctx PushDownContext, scalarFunc *ScalarFunction, stor // Check whether this function can be pushed. 
if unspecified := pbCode <= tipb.ScalarFuncSig_Unspecified; unspecified || !canFuncBePushed(ctx.EvalCtx(), scalarFunc, storeType) { if unspecified { - failpoint.Inject("PanicIfPbCodeUnspecified", func() { + if _, _err_ := failpoint.Eval(_curpkg_("PanicIfPbCodeUnspecified")); _err_ == nil { panic(errors.Errorf("unspecified PbCode: %T", scalarFunc.Function)) - }) + } } storageName := storeType.Name() if storeType == kv.UnSpecified { diff --git a/pkg/expression/infer_pushdown.go__failpoint_stash__ b/pkg/expression/infer_pushdown.go__failpoint_stash__ new file mode 100644 index 0000000000000..4af5c0e912a8d --- /dev/null +++ b/pkg/expression/infer_pushdown.go__failpoint_stash__ @@ -0,0 +1,536 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "fmt" + "strconv" + "strings" + "sync/atomic" + + "github.com/gogo/protobuf/proto" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/types" + contextutil "github.com/pingcap/tidb/pkg/util/context" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tipb/go-tipb" + "go.uber.org/zap" +) + +// DefaultExprPushDownBlacklist indicates the expressions which can not be pushed down to TiKV. +var DefaultExprPushDownBlacklist *atomic.Value + +// ExprPushDownBlackListReloadTimeStamp is used to record the last time when the push-down black list is reloaded. +// This is for plan cache, when the push-down black list is updated, we invalid all cached plans to avoid error. +var ExprPushDownBlackListReloadTimeStamp *atomic.Int64 + +func canFuncBePushed(ctx EvalContext, sf *ScalarFunction, storeType kv.StoreType) bool { + // Use the failpoint to control whether to push down an expression in the integration test. + // Push down all expression if the `failpoint expression` is `all`, otherwise, check + // whether scalar function's name is contained in the enabled expression list (e.g.`ne,eq,lt`). + // If neither of the above is true, switch to original logic. 
+ failpoint.Inject("PushDownTestSwitcher", func(val failpoint.Value) { + enabled := val.(string) + if enabled == "all" { + failpoint.Return(true) + } + exprs := strings.Split(enabled, ",") + for _, expr := range exprs { + if strings.ToLower(strings.TrimSpace(expr)) == sf.FuncName.L { + failpoint.Return(true) + } + } + failpoint.Return(false) + }) + + ret := false + + switch storeType { + case kv.TiFlash: + ret = scalarExprSupportedByFlash(ctx, sf) + case kv.TiKV: + ret = scalarExprSupportedByTiKV(ctx, sf) + case kv.TiDB: + ret = scalarExprSupportedByTiDB(ctx, sf) + case kv.UnSpecified: + ret = scalarExprSupportedByTiDB(ctx, sf) || scalarExprSupportedByTiKV(ctx, sf) || scalarExprSupportedByFlash(ctx, sf) + } + + if ret { + funcFullName := fmt.Sprintf("%s.%s", sf.FuncName.L, strings.ToLower(sf.Function.PbCode().String())) + // Aside from checking function name, also check the pb name in case only the specific push down is disabled. + ret = IsPushDownEnabled(sf.FuncName.L, storeType) && IsPushDownEnabled(funcFullName, storeType) + } + return ret +} + +func canScalarFuncPushDown(ctx PushDownContext, scalarFunc *ScalarFunction, storeType kv.StoreType) bool { + pbCode := scalarFunc.Function.PbCode() + // Check whether this function can be pushed. + if unspecified := pbCode <= tipb.ScalarFuncSig_Unspecified; unspecified || !canFuncBePushed(ctx.EvalCtx(), scalarFunc, storeType) { + if unspecified { + failpoint.Inject("PanicIfPbCodeUnspecified", func() { + panic(errors.Errorf("unspecified PbCode: %T", scalarFunc.Function)) + }) + } + storageName := storeType.Name() + if storeType == kv.UnSpecified { + storageName = "storage layer" + } + warnErr := errors.NewNoStackError("Scalar function '" + scalarFunc.FuncName.L + "'(signature: " + scalarFunc.Function.PbCode().String() + ", return type: " + scalarFunc.RetType.CompactStr() + ") is not supported to push down to " + storageName + " now.") + + ctx.AppendWarning(warnErr) + return false + } + canEnumPush := canEnumPushdownPreliminarily(scalarFunc) + // Check whether all of its parameters can be pushed. 
+ for _, arg := range scalarFunc.GetArgs() { + if !canExprPushDown(ctx, arg, storeType, canEnumPush) { + return false + } + } + + if metadata := scalarFunc.Function.metadata(); metadata != nil { + var err error + _, err = proto.Marshal(metadata) + if err != nil { + logutil.BgLogger().Error("encode metadata", zap.Any("metadata", metadata), zap.Error(err)) + return false + } + } + return true +} + +func canExprPushDown(ctx PushDownContext, expr Expression, storeType kv.StoreType, canEnumPush bool) bool { + pc := ctx.PbConverter() + if storeType == kv.TiFlash { + switch expr.GetType(ctx.EvalCtx()).GetType() { + case mysql.TypeEnum, mysql.TypeBit, mysql.TypeSet, mysql.TypeGeometry, mysql.TypeUnspecified: + if expr.GetType(ctx.EvalCtx()).GetType() == mysql.TypeEnum && canEnumPush { + break + } + warnErr := errors.NewNoStackError("Expression about '" + expr.StringWithCtx(ctx.EvalCtx(), errors.RedactLogDisable) + "' can not be pushed to TiFlash because it contains unsupported calculation of type '" + types.TypeStr(expr.GetType(ctx.EvalCtx()).GetType()) + "'.") + ctx.AppendWarning(warnErr) + return false + case mysql.TypeNewDecimal: + if !expr.GetType(ctx.EvalCtx()).IsDecimalValid() { + warnErr := errors.NewNoStackError("Expression about '" + expr.StringWithCtx(ctx.EvalCtx(), errors.RedactLogDisable) + "' can not be pushed to TiFlash because it contains invalid decimal('" + strconv.Itoa(expr.GetType(ctx.EvalCtx()).GetFlen()) + "','" + strconv.Itoa(expr.GetType(ctx.EvalCtx()).GetDecimal()) + "').") + ctx.AppendWarning(warnErr) + return false + } + } + } + switch x := expr.(type) { + case *CorrelatedColumn: + return pc.conOrCorColToPBExpr(expr) != nil && pc.columnToPBExpr(&x.Column, true) != nil + case *Constant: + return pc.conOrCorColToPBExpr(expr) != nil + case *Column: + return pc.columnToPBExpr(x, true) != nil + case *ScalarFunction: + return canScalarFuncPushDown(ctx, x, storeType) + } + return false +} + +func scalarExprSupportedByTiDB(ctx EvalContext, function *ScalarFunction) bool { + // TiDB can support all functions, but TiPB may not include some functions. + return scalarExprSupportedByTiKV(ctx, function) || scalarExprSupportedByFlash(ctx, function) +} + +// supported functions tracked by https://github.com/tikv/tikv/issues/5751 +func scalarExprSupportedByTiKV(ctx EvalContext, sf *ScalarFunction) bool { + switch sf.FuncName.L { + case + // op functions. + ast.LogicAnd, ast.LogicOr, ast.LogicXor, ast.UnaryNot, ast.And, ast.Or, ast.Xor, ast.BitNeg, ast.LeftShift, ast.RightShift, ast.UnaryMinus, + + // compare functions. + ast.LT, ast.LE, ast.EQ, ast.NE, ast.GE, ast.GT, ast.NullEQ, ast.In, ast.IsNull, ast.Like, ast.IsTruthWithoutNull, ast.IsTruthWithNull, ast.IsFalsity, + // ast.Greatest, ast.Least, ast.Interval + + // arithmetical functions. + ast.PI, /* ast.Truncate */ + ast.Plus, ast.Minus, ast.Mul, ast.Div, ast.Abs, ast.Mod, ast.IntDiv, + + // math functions. + ast.Ceil, ast.Ceiling, ast.Floor, ast.Sqrt, ast.Sign, ast.Ln, ast.Log, ast.Log2, ast.Log10, ast.Exp, ast.Pow, ast.Power, + + // Rust use the llvm math functions, which have different precision with Golang/MySQL(cmath) + // open the following switchers if we implement them in coprocessor via `cmath` + ast.Sin, ast.Asin, ast.Cos, ast.Acos /* ast.Tan */, ast.Atan, ast.Atan2, ast.Cot, + ast.Radians, ast.Degrees, ast.CRC32, + + // control flow functions. + ast.Case, ast.If, ast.Ifnull, ast.Coalesce, + + // string functions. 
+ // ast.Bin, ast.Unhex, ast.Locate, ast.Ord, ast.Lpad, ast.Rpad, + // ast.Trim, ast.FromBase64, ast.ToBase64, ast.InsertFunc, + // ast.MakeSet, ast.SubstringIndex, ast.Instr, ast.Quote, ast.Oct, + // ast.FindInSet, ast.Repeat, + ast.Upper, ast.Lower, + ast.Length, ast.BitLength, ast.Concat, ast.ConcatWS, ast.Replace, ast.ASCII, ast.Hex, + ast.Reverse, ast.LTrim, ast.RTrim, ast.Strcmp, ast.Space, ast.Elt, ast.Field, + InternalFuncFromBinary, InternalFuncToBinary, ast.Mid, ast.Substring, ast.Substr, ast.CharLength, + ast.Right, /* ast.Left */ + + // json functions. + ast.JSONType, ast.JSONExtract, ast.JSONObject, ast.JSONArray, ast.JSONMerge, ast.JSONSet, + ast.JSONInsert, ast.JSONReplace, ast.JSONRemove, ast.JSONLength, ast.JSONMergePatch, + ast.JSONUnquote, ast.JSONContains, ast.JSONValid, ast.JSONMemberOf, ast.JSONArrayAppend, + + // date functions. + ast.Date, ast.Week /* ast.YearWeek, ast.ToSeconds */, ast.DateDiff, + /* ast.TimeDiff, ast.AddTime, ast.SubTime, */ + ast.MonthName, ast.MakeDate, ast.TimeToSec, ast.MakeTime, + ast.DateFormat, + ast.Hour, ast.Minute, ast.Second, ast.MicroSecond, ast.Month, + /* ast.DayName */ ast.DayOfMonth, ast.DayOfWeek, ast.DayOfYear, + /* ast.Weekday */ ast.WeekOfYear, ast.Year, + ast.FromDays, /* ast.ToDays */ + ast.PeriodAdd, ast.PeriodDiff, /*ast.TimestampDiff, ast.DateAdd, ast.FromUnixTime,*/ + /* ast.LastDay */ + ast.Sysdate, + + // encryption functions. + ast.MD5, ast.SHA1, ast.UncompressedLength, + + ast.Cast, + + // misc functions. + // TODO(#26942): enable functions below after them are fully tested in TiKV. + /*ast.InetNtoa, ast.InetAton, ast.Inet6Ntoa, ast.Inet6Aton, ast.IsIPv4, ast.IsIPv4Compat, ast.IsIPv4Mapped, ast.IsIPv6,*/ + ast.UUID: + + return true + // Rust use the llvm math functions, which have different precision with Golang/MySQL(cmath) + // open the following switchers if we implement them in coprocessor via `cmath` + case ast.Conv: + arg0 := sf.GetArgs()[0] + // To be aligned with MySQL, tidb handles hybrid type argument and binary literal specially, tikv can't be consistent with tidb now. 
+ if f, ok := arg0.(*ScalarFunction); ok { + if f.FuncName.L == ast.Cast && (f.GetArgs()[0].GetType(ctx).Hybrid() || IsBinaryLiteral(f.GetArgs()[0])) { + return false + } + } + return true + case ast.Round: + switch sf.Function.PbCode() { + case tipb.ScalarFuncSig_RoundReal, tipb.ScalarFuncSig_RoundInt, tipb.ScalarFuncSig_RoundDec: + // We don't push round with frac due to mysql's round with frac has its special behavior: + // https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_round + return true + } + case ast.Rand: + switch sf.Function.PbCode() { + case tipb.ScalarFuncSig_RandWithSeedFirstGen: + return true + } + case ast.Regexp, ast.RegexpLike, ast.RegexpSubstr, ast.RegexpInStr, ast.RegexpReplace: + funcCharset, funcCollation := sf.Function.CharsetAndCollation() + if funcCharset == charset.CharsetBin && funcCollation == charset.CollationBin { + return false + } + return true + } + return false +} + +func scalarExprSupportedByFlash(ctx EvalContext, function *ScalarFunction) bool { + switch function.FuncName.L { + case ast.Floor, ast.Ceil, ast.Ceiling: + switch function.Function.PbCode() { + case tipb.ScalarFuncSig_FloorIntToDec, tipb.ScalarFuncSig_CeilIntToDec: + return false + default: + return true + } + case + ast.LogicOr, ast.LogicAnd, ast.UnaryNot, ast.BitNeg, ast.Xor, ast.And, ast.Or, ast.RightShift, ast.LeftShift, + ast.GE, ast.LE, ast.EQ, ast.NE, ast.LT, ast.GT, ast.In, ast.IsNull, ast.Like, ast.Ilike, ast.Strcmp, + ast.Plus, ast.Minus, ast.Div, ast.Mul, ast.Abs, ast.Mod, + ast.If, ast.Ifnull, ast.Case, + ast.Concat, ast.ConcatWS, + ast.Date, ast.Year, ast.Month, ast.Day, ast.Quarter, ast.DayName, ast.MonthName, + ast.DateDiff, ast.TimestampDiff, ast.DateFormat, ast.FromUnixTime, + ast.DayOfWeek, ast.DayOfMonth, ast.DayOfYear, ast.LastDay, ast.WeekOfYear, ast.ToSeconds, + ast.FromDays, ast.ToDays, + + ast.Sqrt, ast.Log, ast.Log2, ast.Log10, ast.Ln, ast.Exp, ast.Pow, ast.Power, ast.Sign, + ast.Radians, ast.Degrees, ast.Conv, ast.CRC32, + ast.JSONLength, ast.JSONDepth, ast.JSONExtract, ast.JSONUnquote, ast.JSONArray, ast.JSONContainsPath, ast.JSONValid, ast.JSONKeys, + ast.Repeat, ast.InetNtoa, ast.InetAton, ast.Inet6Ntoa, ast.Inet6Aton, + ast.Coalesce, ast.ASCII, ast.Length, ast.Trim, ast.Position, ast.Format, ast.Elt, + ast.LTrim, ast.RTrim, ast.Lpad, ast.Rpad, + ast.Hour, ast.Minute, ast.Second, ast.MicroSecond, + ast.TimeToSec: + switch function.Function.PbCode() { + case tipb.ScalarFuncSig_InDuration, + tipb.ScalarFuncSig_CoalesceDuration, + tipb.ScalarFuncSig_IfNullDuration, + tipb.ScalarFuncSig_IfDuration, + tipb.ScalarFuncSig_CaseWhenDuration: + return false + } + return true + case ast.Regexp, ast.RegexpLike, ast.RegexpInStr, ast.RegexpSubstr, ast.RegexpReplace: + funcCharset, funcCollation := function.Function.CharsetAndCollation() + if funcCharset == charset.CharsetBin && funcCollation == charset.CollationBin { + return false + } + return true + case ast.Substr, ast.Substring, ast.Left, ast.Right, ast.CharLength, ast.SubstringIndex, ast.Reverse: + switch function.Function.PbCode() { + case + tipb.ScalarFuncSig_LeftUTF8, + tipb.ScalarFuncSig_RightUTF8, + tipb.ScalarFuncSig_CharLengthUTF8, + tipb.ScalarFuncSig_Substring2ArgsUTF8, + tipb.ScalarFuncSig_Substring3ArgsUTF8, + tipb.ScalarFuncSig_SubstringIndex, + tipb.ScalarFuncSig_ReverseUTF8, + tipb.ScalarFuncSig_Reverse: + return true + } + case ast.Cast: + sourceType := function.GetArgs()[0].GetType(ctx) + retType := function.RetType + switch function.Function.PbCode() { + case 
tipb.ScalarFuncSig_CastDecimalAsInt, tipb.ScalarFuncSig_CastIntAsInt, tipb.ScalarFuncSig_CastRealAsInt, tipb.ScalarFuncSig_CastTimeAsInt, + tipb.ScalarFuncSig_CastStringAsInt /*, tipb.ScalarFuncSig_CastDurationAsInt, tipb.ScalarFuncSig_CastJsonAsInt*/ : + // TiFlash cast only support cast to Int64 or the source type is the same as the target type + return (sourceType.GetType() == retType.GetType() && mysql.HasUnsignedFlag(sourceType.GetFlag()) == mysql.HasUnsignedFlag(retType.GetFlag())) || retType.GetType() == mysql.TypeLonglong + case tipb.ScalarFuncSig_CastIntAsReal, tipb.ScalarFuncSig_CastRealAsReal, tipb.ScalarFuncSig_CastStringAsReal, tipb.ScalarFuncSig_CastTimeAsReal, tipb.ScalarFuncSig_CastDecimalAsReal: /* + tipb.ScalarFuncSig_CastDurationAsReal, tipb.ScalarFuncSig_CastJsonAsReal*/ + // TiFlash cast only support cast to Float64 or the source type is the same as the target type + return sourceType.GetType() == retType.GetType() || retType.GetType() == mysql.TypeDouble + case tipb.ScalarFuncSig_CastDecimalAsDecimal, tipb.ScalarFuncSig_CastIntAsDecimal, tipb.ScalarFuncSig_CastRealAsDecimal, tipb.ScalarFuncSig_CastTimeAsDecimal, + tipb.ScalarFuncSig_CastStringAsDecimal /*, tipb.ScalarFuncSig_CastDurationAsDecimal, tipb.ScalarFuncSig_CastJsonAsDecimal*/ : + return function.RetType.IsDecimalValid() + case tipb.ScalarFuncSig_CastDecimalAsString, tipb.ScalarFuncSig_CastIntAsString, tipb.ScalarFuncSig_CastRealAsString, tipb.ScalarFuncSig_CastTimeAsString, + tipb.ScalarFuncSig_CastStringAsString, tipb.ScalarFuncSig_CastJsonAsString /*, tipb.ScalarFuncSig_CastDurationAsString*/ : + return true + case tipb.ScalarFuncSig_CastDecimalAsTime, tipb.ScalarFuncSig_CastIntAsTime, tipb.ScalarFuncSig_CastRealAsTime, tipb.ScalarFuncSig_CastTimeAsTime, + tipb.ScalarFuncSig_CastStringAsTime /*, tipb.ScalarFuncSig_CastDurationAsTime, tipb.ScalarFuncSig_CastJsonAsTime*/ : + // ban the function of casting year type as time type pushing down to tiflash because of https://github.com/pingcap/tidb/issues/26215 + return function.GetArgs()[0].GetType(ctx).GetType() != mysql.TypeYear + case tipb.ScalarFuncSig_CastTimeAsDuration: + return retType.GetType() == mysql.TypeDuration + case tipb.ScalarFuncSig_CastIntAsJson, tipb.ScalarFuncSig_CastRealAsJson, tipb.ScalarFuncSig_CastDecimalAsJson, tipb.ScalarFuncSig_CastStringAsJson, + tipb.ScalarFuncSig_CastTimeAsJson, tipb.ScalarFuncSig_CastDurationAsJson, tipb.ScalarFuncSig_CastJsonAsJson: + return true + } + case ast.DateAdd, ast.AddDate: + switch function.Function.PbCode() { + case tipb.ScalarFuncSig_AddDateDatetimeInt, tipb.ScalarFuncSig_AddDateStringInt, tipb.ScalarFuncSig_AddDateStringReal: + return true + } + case ast.DateSub, ast.SubDate: + switch function.Function.PbCode() { + case tipb.ScalarFuncSig_SubDateDatetimeInt, tipb.ScalarFuncSig_SubDateStringInt, tipb.ScalarFuncSig_SubDateStringReal: + return true + } + case ast.UnixTimestamp: + switch function.Function.PbCode() { + case tipb.ScalarFuncSig_UnixTimestampInt, tipb.ScalarFuncSig_UnixTimestampDec: + return true + } + case ast.Round: + switch function.Function.PbCode() { + case tipb.ScalarFuncSig_RoundInt, tipb.ScalarFuncSig_RoundReal, tipb.ScalarFuncSig_RoundDec, + tipb.ScalarFuncSig_RoundWithFracInt, tipb.ScalarFuncSig_RoundWithFracReal, tipb.ScalarFuncSig_RoundWithFracDec: + return true + } + case ast.Extract: + switch function.Function.PbCode() { + case tipb.ScalarFuncSig_ExtractDatetime, tipb.ScalarFuncSig_ExtractDuration: + return true + } + case ast.Replace: + switch function.Function.PbCode() { + 
case tipb.ScalarFuncSig_Replace: + return true + } + case ast.StrToDate: + switch function.Function.PbCode() { + case + tipb.ScalarFuncSig_StrToDateDate, + tipb.ScalarFuncSig_StrToDateDatetime: + return true + default: + return false + } + case ast.Upper, ast.Ucase, ast.Lower, ast.Lcase, ast.Space: + return true + case ast.Sysdate: + return true + case ast.Least, ast.Greatest: + switch function.Function.PbCode() { + case tipb.ScalarFuncSig_GreatestInt, tipb.ScalarFuncSig_GreatestReal, + tipb.ScalarFuncSig_LeastInt, tipb.ScalarFuncSig_LeastReal, tipb.ScalarFuncSig_LeastString, tipb.ScalarFuncSig_GreatestString: + return true + } + case ast.IsTruthWithNull, ast.IsTruthWithoutNull, ast.IsFalsity: + return true + case ast.Hex, ast.Unhex, ast.Bin: + return true + case ast.GetFormat: + return true + case ast.IsIPv4, ast.IsIPv6: + return true + case ast.Grouping: // grouping function for grouping sets identification. + return true + } + return false +} + +func canEnumPushdownPreliminarily(scalarFunc *ScalarFunction) bool { + switch scalarFunc.FuncName.L { + case ast.Cast: + return scalarFunc.RetType.EvalType() == types.ETInt || scalarFunc.RetType.EvalType() == types.ETReal || scalarFunc.RetType.EvalType() == types.ETDecimal + default: + return false + } +} + +// IsPushDownEnabled returns true if the input expr is not in the expr_pushdown_blacklist +func IsPushDownEnabled(name string, storeType kv.StoreType) bool { + value, exists := DefaultExprPushDownBlacklist.Load().(map[string]uint32)[name] + if exists { + mask := storeTypeMask(storeType) + return !(value&mask == mask) + } + + if storeType != kv.TiFlash && name == ast.AggFuncApproxCountDistinct { + // Can not push down approx_count_distinct to other store except tiflash by now. + return false + } + + return true +} + +// PushDownContext is the context used for push down expressions +type PushDownContext struct { + evalCtx EvalContext + client kv.Client + warnHandler contextutil.WarnAppender + groupConcatMaxLen uint64 +} + +// NewPushDownContext returns a new PushDownContext +func NewPushDownContext(evalCtx EvalContext, client kv.Client, inExplainStmt bool, + warnHandler contextutil.WarnAppender, extraWarnHandler contextutil.WarnAppender, groupConcatMaxLen uint64) PushDownContext { + var newWarnHandler contextutil.WarnAppender + if warnHandler != nil && extraWarnHandler != nil { + if inExplainStmt { + newWarnHandler = warnHandler + } else { + newWarnHandler = extraWarnHandler + } + } + + return PushDownContext{ + evalCtx: evalCtx, + client: client, + warnHandler: newWarnHandler, + groupConcatMaxLen: groupConcatMaxLen, + } +} + +// NewPushDownContextFromSessionVars builds a new PushDownContext from session vars. 
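IsPushDownEnabled above reads the blacklist value as a bitmask: an entry disables push-down for a store only when every bit of that store's mask is set in the stored value. A sketch of that check with illustrative mask constants (the concrete bit assignments come from storeTypeMask over kv.StoreType and are assumed here; only the value&mask == mask rule is taken from the hunk):

package demo

// Illustrative stand-ins for storeTypeMask results; the bit positions are an
// assumption of this sketch.
const (
	maskTiKV    uint32 = 1 << 0
	maskTiFlash uint32 = 1 << 1
	maskTiDB    uint32 = 1 << 2
	maskAll            = maskTiKV | maskTiFlash | maskTiDB // kv.UnSpecified covers every store
)

// pushDownDisabled reports whether a blacklist entry blocks push-down for a store mask.
func pushDownDisabled(blacklistValue, storeMask uint32) bool {
	return blacklistValue&storeMask == storeMask
}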
+func NewPushDownContextFromSessionVars(evalCtx EvalContext, sessVars *variable.SessionVars, client kv.Client) PushDownContext { + return NewPushDownContext( + evalCtx, + client, + sessVars.StmtCtx.InExplainStmt, + sessVars.StmtCtx.WarnHandler, + sessVars.StmtCtx.ExtraWarnHandler, + sessVars.GroupConcatMaxLen) +} + +// EvalCtx returns the eval context +func (ctx PushDownContext) EvalCtx() EvalContext { + return ctx.evalCtx +} + +// PbConverter returns a new PbConverter +func (ctx PushDownContext) PbConverter() PbConverter { + return NewPBConverter(ctx.client, ctx.evalCtx) +} + +// Client returns the kv client +func (ctx PushDownContext) Client() kv.Client { + return ctx.client +} + +// GetGroupConcatMaxLen returns the max length of group_concat +func (ctx PushDownContext) GetGroupConcatMaxLen() uint64 { + return ctx.groupConcatMaxLen +} + +// AppendWarning appends a warning to be handled by the internal handler +func (ctx PushDownContext) AppendWarning(err error) { + if ctx.warnHandler != nil { + ctx.warnHandler.AppendWarning(err) + } +} + +// PushDownExprsWithExtraInfo split the input exprs into pushed and remained, pushed include all the exprs that can be pushed down +func PushDownExprsWithExtraInfo(ctx PushDownContext, exprs []Expression, storeType kv.StoreType, canEnumPush bool) (pushed []Expression, remained []Expression) { + for _, expr := range exprs { + if canExprPushDown(ctx, expr, storeType, canEnumPush) { + pushed = append(pushed, expr) + } else { + remained = append(remained, expr) + } + } + return +} + +// PushDownExprs split the input exprs into pushed and remained, pushed include all the exprs that can be pushed down +func PushDownExprs(ctx PushDownContext, exprs []Expression, storeType kv.StoreType) (pushed []Expression, remained []Expression) { + return PushDownExprsWithExtraInfo(ctx, exprs, storeType, false) +} + +// CanExprsPushDownWithExtraInfo return true if all the expr in exprs can be pushed down +func CanExprsPushDownWithExtraInfo(ctx PushDownContext, exprs []Expression, storeType kv.StoreType, canEnumPush bool) bool { + _, remained := PushDownExprsWithExtraInfo(ctx, exprs, storeType, canEnumPush) + return len(remained) == 0 +} + +// CanExprsPushDown return true if all the expr in exprs can be pushed down +func CanExprsPushDown(ctx PushDownContext, exprs []Expression, storeType kv.StoreType) bool { + return CanExprsPushDownWithExtraInfo(ctx, exprs, storeType, false) +} + +func storeTypeMask(storeType kv.StoreType) uint32 { + if storeType == kv.UnSpecified { + return 1<= 0; i-- { + if filter(input[i]) { + filteredOut = append(filteredOut, input[i]) + input = append(input[:i], input[i+1:]...) + } + } + return input, filteredOut +} + +// ExtractDependentColumns extracts all dependent columns from a virtual column. +func ExtractDependentColumns(expr Expression) []*Column { + // Pre-allocate a slice to reduce allocation, 8 doesn't have special meaning. + result := make([]*Column, 0, 8) + return extractDependentColumns(result, expr) +} + +func extractDependentColumns(result []*Column, expr Expression) []*Column { + switch v := expr.(type) { + case *Column: + result = append(result, v) + if v.VirtualExpr != nil { + result = extractDependentColumns(result, v.VirtualExpr) + } + case *ScalarFunction: + for _, arg := range v.GetArgs() { + result = extractDependentColumns(result, arg) + } + } + return result +} + +// ExtractColumns extracts all columns from an expression. 
+func ExtractColumns(expr Expression) []*Column { + // Pre-allocate a slice to reduce allocation, 8 doesn't have special meaning. + result := make([]*Column, 0, 8) + return extractColumns(result, expr, nil) +} + +// ExtractCorColumns extracts correlated column from given expression. +func ExtractCorColumns(expr Expression) (cols []*CorrelatedColumn) { + switch v := expr.(type) { + case *CorrelatedColumn: + return []*CorrelatedColumn{v} + case *ScalarFunction: + for _, arg := range v.GetArgs() { + cols = append(cols, ExtractCorColumns(arg)...) + } + } + return +} + +// ExtractColumnsFromExpressions is a more efficient version of ExtractColumns for batch operation. +// filter can be nil, or a function to filter the result column. +// It's often observed that the pattern of the caller like this: +// +// cols := ExtractColumns(...) +// +// for _, col := range cols { +// if xxx(col) {...} +// } +// +// Provide an additional filter argument, this can be done in one step. +// To avoid allocation for cols that not need. +func ExtractColumnsFromExpressions(result []*Column, exprs []Expression, filter func(*Column) bool) []*Column { + for _, expr := range exprs { + result = extractColumns(result, expr, filter) + } + return result +} + +func extractColumns(result []*Column, expr Expression, filter func(*Column) bool) []*Column { + switch v := expr.(type) { + case *Column: + if filter == nil || filter(v) { + result = append(result, v) + } + case *ScalarFunction: + for _, arg := range v.GetArgs() { + result = extractColumns(result, arg, filter) + } + } + return result +} + +// ExtractEquivalenceColumns detects the equivalence from CNF exprs. +func ExtractEquivalenceColumns(result [][]Expression, exprs []Expression) [][]Expression { + // exprs are CNF expressions, EQ condition only make sense in the top level of every expr. + for _, expr := range exprs { + result = extractEquivalenceColumns(result, expr) + } + return result +} + +// FindUpperBound looks for column < constant or column <= constant and returns both the column +// and constant. It return nil, 0 if the expression is not of this form. +// It is used by derived Top N pattern and it is put here since it looks like +// a general purpose routine. Similar routines can be added to find lower bound as well. +func FindUpperBound(expr Expression) (*Column, int64) { + scalarFunction, scalarFunctionOk := expr.(*ScalarFunction) + if scalarFunctionOk { + args := scalarFunction.GetArgs() + if len(args) == 2 { + col, colOk := args[0].(*Column) + constant, constantOk := args[1].(*Constant) + if colOk && constantOk && (scalarFunction.FuncName.L == ast.LT || scalarFunction.FuncName.L == ast.LE) { + value, valueOk := constant.Value.GetValue().(int64) + if valueOk { + if scalarFunction.FuncName.L == ast.LT { + return col, value - 1 + } + return col, value + } + } + } + } + return nil, 0 +} + +func extractEquivalenceColumns(result [][]Expression, expr Expression) [][]Expression { + switch v := expr.(type) { + case *ScalarFunction: + // a==b, a<=>b, the latter one is evaluated to true when a,b are both null. 
+ if v.FuncName.L == ast.EQ || v.FuncName.L == ast.NullEQ { + args := v.GetArgs() + if len(args) == 2 { + col1, ok1 := args[0].(*Column) + col2, ok2 := args[1].(*Column) + if ok1 && ok2 { + result = append(result, []Expression{col1, col2}) + } + col, ok1 := args[0].(*Column) + scl, ok2 := args[1].(*ScalarFunction) + if ok1 && ok2 { + result = append(result, []Expression{col, scl}) + } + col, ok1 = args[1].(*Column) + scl, ok2 = args[0].(*ScalarFunction) + if ok1 && ok2 { + result = append(result, []Expression{col, scl}) + } + } + return result + } + if v.FuncName.L == ast.In { + args := v.GetArgs() + // only `col in (only 1 element)`, can we build an equivalence here. + if len(args[1:]) == 1 { + col1, ok1 := args[0].(*Column) + col2, ok2 := args[1].(*Column) + if ok1 && ok2 { + result = append(result, []Expression{col1, col2}) + } + col, ok1 := args[0].(*Column) + scl, ok2 := args[1].(*ScalarFunction) + if ok1 && ok2 { + result = append(result, []Expression{col, scl}) + } + col, ok1 = args[1].(*Column) + scl, ok2 = args[0].(*ScalarFunction) + if ok1 && ok2 { + result = append(result, []Expression{col, scl}) + } + } + return result + } + // For Non-EQ function, we don't have to traverse down. + // eg: (a=b or c=d) doesn't make any definitely equivalence assertion. + } + return result +} + +// extractColumnsAndCorColumns extracts columns and correlated columns from `expr` and append them to `result`. +func extractColumnsAndCorColumns(result []*Column, expr Expression) []*Column { + switch v := expr.(type) { + case *Column: + result = append(result, v) + case *CorrelatedColumn: + result = append(result, &v.Column) + case *ScalarFunction: + for _, arg := range v.GetArgs() { + result = extractColumnsAndCorColumns(result, arg) + } + } + return result +} + +// ExtractConstantEqColumnsOrScalar detects the constant equal relationship from CNF exprs. +func ExtractConstantEqColumnsOrScalar(ctx BuildContext, result []Expression, exprs []Expression) []Expression { + // exprs are CNF expressions, EQ condition only make sense in the top level of every expr. + for _, expr := range exprs { + result = extractConstantEqColumnsOrScalar(ctx, result, expr) + } + return result +} + +func extractConstantEqColumnsOrScalar(ctx BuildContext, result []Expression, expr Expression) []Expression { + switch v := expr.(type) { + case *ScalarFunction: + if v.FuncName.L == ast.EQ || v.FuncName.L == ast.NullEQ { + args := v.GetArgs() + if len(args) == 2 { + col, ok1 := args[0].(*Column) + _, ok2 := args[1].(*Constant) + if ok1 && ok2 { + result = append(result, col) + } + col, ok1 = args[1].(*Column) + _, ok2 = args[0].(*Constant) + if ok1 && ok2 { + result = append(result, col) + } + // take the correlated column as constant here. + col, ok1 = args[0].(*Column) + _, ok2 = args[1].(*CorrelatedColumn) + if ok1 && ok2 { + result = append(result, col) + } + col, ok1 = args[1].(*Column) + _, ok2 = args[0].(*CorrelatedColumn) + if ok1 && ok2 { + result = append(result, col) + } + scl, ok1 := args[0].(*ScalarFunction) + _, ok2 = args[1].(*Constant) + if ok1 && ok2 { + result = append(result, scl) + } + scl, ok1 = args[1].(*ScalarFunction) + _, ok2 = args[0].(*Constant) + if ok1 && ok2 { + result = append(result, scl) + } + // take the correlated column as constant here. 
+ scl, ok1 = args[0].(*ScalarFunction) + _, ok2 = args[1].(*CorrelatedColumn) + if ok1 && ok2 { + result = append(result, scl) + } + scl, ok1 = args[1].(*ScalarFunction) + _, ok2 = args[0].(*CorrelatedColumn) + if ok1 && ok2 { + result = append(result, scl) + } + } + return result + } + if v.FuncName.L == ast.In { + args := v.GetArgs() + allArgsIsConst := true + // only `col in (all same const)`, can col be the constant column. + // eg: a in (1, "1") does, while a in (1, '2') doesn't. + guard := args[1] + for i, v := range args[1:] { + if _, ok := v.(*Constant); !ok { + allArgsIsConst = false + break + } + if i == 0 { + continue + } + if !guard.Equal(ctx.GetEvalCtx(), v) { + allArgsIsConst = false + break + } + } + if allArgsIsConst { + if col, ok := args[0].(*Column); ok { + result = append(result, col) + } else if scl, ok := args[0].(*ScalarFunction); ok { + result = append(result, scl) + } + } + return result + } + // For Non-EQ function, we don't have to traverse down. + } + return result +} + +// ExtractColumnsAndCorColumnsFromExpressions extracts columns and correlated columns from expressions and append them to `result`. +func ExtractColumnsAndCorColumnsFromExpressions(result []*Column, list []Expression) []*Column { + for _, expr := range list { + result = extractColumnsAndCorColumns(result, expr) + } + return result +} + +// ExtractColumnSet extracts the different values of `UniqueId` for columns in expressions. +func ExtractColumnSet(exprs ...Expression) intset.FastIntSet { + set := intset.NewFastIntSet() + for _, expr := range exprs { + extractColumnSet(expr, &set) + } + return set +} + +func extractColumnSet(expr Expression, set *intset.FastIntSet) { + switch v := expr.(type) { + case *Column: + set.Insert(int(v.UniqueID)) + case *ScalarFunction: + for _, arg := range v.GetArgs() { + extractColumnSet(arg, set) + } + } +} + +// SetExprColumnInOperand is used to set columns in expr as InOperand. +func SetExprColumnInOperand(expr Expression) Expression { + switch v := expr.(type) { + case *Column: + col := v.Clone().(*Column) + col.InOperand = true + return col + case *ScalarFunction: + args := v.GetArgs() + for i, arg := range args { + args[i] = SetExprColumnInOperand(arg) + } + } + return expr +} + +// ColumnSubstitute substitutes the columns in filter to expressions in select fields. +// e.g. select * from (select b as a from t) k where a < 10 => select * from (select b as a from t where b < 10) k. +// TODO: remove this function and only use ColumnSubstituteImpl since this function swallows the error, which seems unsafe. +func ColumnSubstitute(ctx BuildContext, expr Expression, schema *Schema, newExprs []Expression) Expression { + _, _, resExpr := ColumnSubstituteImpl(ctx, expr, schema, newExprs, false) + return resExpr +} + +// ColumnSubstituteAll substitutes the columns just like ColumnSubstitute, but we don't accept partial substitution. +// Only accept: +// +// 1: substitute them all once find col in schema. +// 2: nothing in expr can be substituted. +func ColumnSubstituteAll(ctx BuildContext, expr Expression, schema *Schema, newExprs []Expression) (bool, Expression) { + _, hasFail, resExpr := ColumnSubstituteImpl(ctx, expr, schema, newExprs, true) + return hasFail, resExpr +} + +// ColumnSubstituteImpl tries to substitute column expr using newExprs, +// the newFunctionInternal is only called if its child is substituted +// @return bool means whether the expr has changed. 
+// @return bool means whether the expr should change (has the dependency in schema, while the corresponding expr has some compatibility), but finally fallback. +// @return Expression, the original expr or the changed expr, it depends on the first @return bool. +func ColumnSubstituteImpl(ctx BuildContext, expr Expression, schema *Schema, newExprs []Expression, fail1Return bool) (bool, bool, Expression) { + switch v := expr.(type) { + case *Column: + id := schema.ColumnIndex(v) + if id == -1 { + return false, false, v + } + newExpr := newExprs[id] + if v.InOperand { + newExpr = SetExprColumnInOperand(newExpr) + } + return true, false, newExpr + case *ScalarFunction: + substituted := false + hasFail := false + if v.FuncName.L == ast.Cast || v.FuncName.L == ast.Grouping { + var newArg Expression + substituted, hasFail, newArg = ColumnSubstituteImpl(ctx, v.GetArgs()[0], schema, newExprs, fail1Return) + if fail1Return && hasFail { + return substituted, hasFail, v + } + if substituted { + flag := v.RetType.GetFlag() + var e Expression + if v.FuncName.L == ast.Cast { + e = BuildCastFunction(ctx, newArg, v.RetType) + } else { + // for grouping function recreation, use clone (meta included) instead of newFunction + e = v.Clone() + e.(*ScalarFunction).Function.getArgs()[0] = newArg + } + e.SetCoercibility(v.Coercibility()) + e.GetType(ctx.GetEvalCtx()).SetFlag(flag) + return true, false, e + } + return false, false, v + } + // If the collation of the column is PAD SPACE, + // we can't propagate the constant to the length function. + // For example, schema = ['name'], newExprs = ['a'], v = length(name). + // We can't substitute name with 'a' in length(name) because the collation of name is PAD SPACE. + // TODO: We will fix it here temporarily, and redesign the logic if we encounter more similar functions or situations later. + // Fixed issue #53730 + if ctx.IsConstantPropagateCheck() && v.FuncName.L == ast.Length { + arg0, isColumn := v.GetArgs()[0].(*Column) + if isColumn { + id := schema.ColumnIndex(arg0) + if id != -1 { + _, isConstant := newExprs[id].(*Constant) + if isConstant { + mappedNewColumnCollate := schema.Columns[id].GetStaticType().GetCollate() + if mappedNewColumnCollate == charset.CollationUTF8MB4 || + mappedNewColumnCollate == charset.CollationUTF8 { + return false, false, v + } + } + } + } + } + // cowExprRef is a copy-on-write util, args array allocation happens only + // when expr in args is changed + refExprArr := cowExprRef{v.GetArgs(), nil} + oldCollEt, err := CheckAndDeriveCollationFromExprs(ctx, v.FuncName.L, v.RetType.EvalType(), v.GetArgs()...) + if err != nil { + logutil.BgLogger().Error("Unexpected error happened during ColumnSubstitution", zap.Stack("stack")) + return false, false, v + } + var tmpArgForCollCheck []Expression + if collate.NewCollationEnabled() { + tmpArgForCollCheck = make([]Expression, len(v.GetArgs())) + } + for idx, arg := range v.GetArgs() { + changed, failed, newFuncExpr := ColumnSubstituteImpl(ctx, arg, schema, newExprs, fail1Return) + if fail1Return && failed { + return changed, failed, v + } + oldChanged := changed + if collate.NewCollationEnabled() && changed { + // Make sure the collation used by the ScalarFunction isn't changed and its result collation is not weaker than the collation used by the ScalarFunction. + changed = false + copy(tmpArgForCollCheck, refExprArr.Result()) + tmpArgForCollCheck[idx] = newFuncExpr + newCollEt, err := CheckAndDeriveCollationFromExprs(ctx, v.FuncName.L, v.RetType.EvalType(), tmpArgForCollCheck...) 
+ if err != nil { + logutil.BgLogger().Error("Unexpected error happened during ColumnSubstitution", zap.Stack("stack")) + return false, failed, v + } + if oldCollEt.Collation == newCollEt.Collation { + if newFuncExpr.GetType(ctx.GetEvalCtx()).GetCollate() == arg.GetType(ctx.GetEvalCtx()).GetCollate() && newFuncExpr.Coercibility() == arg.Coercibility() { + // It's safe to use the new expression, otherwise some cases in projection push-down will be wrong. + changed = true + } else { + changed = checkCollationStrictness(oldCollEt.Collation, newFuncExpr.GetType(ctx.GetEvalCtx()).GetCollate()) + } + } + } + hasFail = hasFail || failed || oldChanged != changed + if fail1Return && oldChanged != changed { + // Only when the oldChanged is true and changed is false, we will get here. + // And this means there some dependency in this arg can be substituted with + // given expressions, while it has some collation compatibility, finally we + // fall back to use the origin args. (commonly used in projection elimination + // in which fallback usage is unacceptable) + return changed, true, v + } + refExprArr.Set(idx, changed, newFuncExpr) + if changed { + substituted = true + } + } + if substituted { + newFunc, err := NewFunction(ctx, v.FuncName.L, v.RetType, refExprArr.Result()...) + if err != nil { + return true, true, v + } + return true, hasFail, newFunc + } + } + return false, false, expr +} + +// checkCollationStrictness check collation strictness-ship between `coll` and `newFuncColl` +// return true iff `newFuncColl` is not weaker than `coll` +func checkCollationStrictness(coll, newFuncColl string) bool { + collGroupID, ok1 := CollationStrictnessGroup[coll] + newFuncCollGroupID, ok2 := CollationStrictnessGroup[newFuncColl] + + if ok1 && ok2 { + if collGroupID == newFuncCollGroupID { + return true + } + + for _, id := range CollationStrictness[collGroupID] { + if newFuncCollGroupID == id { + return true + } + } + } + + return false +} + +// getValidPrefix gets a prefix of string which can parsed to a number with base. the minimum base is 2 and the maximum is 36. +func getValidPrefix(s string, base int64) string { + var ( + validLen int + upper rune + ) + switch { + case base >= 2 && base <= 9: + upper = rune('0' + base) + case base <= 36: + upper = rune('A' + base - 10) + default: + return "" + } +Loop: + for i := 0; i < len(s); i++ { + c := rune(s[i]) + switch { + case unicode.IsDigit(c) || unicode.IsLower(c) || unicode.IsUpper(c): + c = unicode.ToUpper(c) + if c >= upper { + break Loop + } + validLen = i + 1 + case c == '+' || c == '-': + if i != 0 { + break Loop + } + default: + break Loop + } + } + if validLen > 1 && s[0] == '+' { + return s[1:validLen] + } + return s[:validLen] +} + +// SubstituteCorCol2Constant will substitute correlated column to constant value which it contains. +// If the args of one scalar function are all constant, we will substitute it to constant. 
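+// For example (hypothetical columns), if the correlated column t1.a currently holds the
+// value 1, the condition `t2.b > t1.a + 1` is rewritten to `t2.b > 2`: the correlated
+// reference becomes a constant and the now fully-constant subexpression is folded.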
+func SubstituteCorCol2Constant(ctx BuildContext, expr Expression) (Expression, error) { + switch x := expr.(type) { + case *ScalarFunction: + allConstant := true + newArgs := make([]Expression, 0, len(x.GetArgs())) + for _, arg := range x.GetArgs() { + newArg, err := SubstituteCorCol2Constant(ctx, arg) + if err != nil { + return nil, err + } + _, ok := newArg.(*Constant) + newArgs = append(newArgs, newArg) + allConstant = allConstant && ok + } + if allConstant { + val, err := x.Eval(ctx.GetEvalCtx(), chunk.Row{}) + if err != nil { + return nil, err + } + return &Constant{Value: val, RetType: x.GetType(ctx.GetEvalCtx())}, nil + } + var ( + err error + newSf Expression + ) + if x.FuncName.L == ast.Cast { + newSf = BuildCastFunction(ctx, newArgs[0], x.RetType) + } else if x.FuncName.L == ast.Grouping { + newSf = x.Clone() + newSf.(*ScalarFunction).GetArgs()[0] = newArgs[0] + } else { + newSf, err = NewFunction(ctx, x.FuncName.L, x.GetType(ctx.GetEvalCtx()), newArgs...) + } + return newSf, err + case *CorrelatedColumn: + return &Constant{Value: *x.Data, RetType: x.GetType(ctx.GetEvalCtx())}, nil + case *Constant: + if x.DeferredExpr != nil { + newExpr := FoldConstant(ctx, x) + return &Constant{Value: newExpr.(*Constant).Value, RetType: x.GetType(ctx.GetEvalCtx())}, nil + } + } + return expr, nil +} + +func locateStringWithCollation(str, substr, coll string) int64 { + collator := collate.GetCollator(coll) + strKey := collator.KeyWithoutTrimRightSpace(str) + subStrKey := collator.KeyWithoutTrimRightSpace(substr) + + index := bytes.Index(strKey, subStrKey) + if index == -1 || index == 0 { + return int64(index + 1) + } + + // todo: we can use binary search to make it faster. + count := int64(0) + for { + r, size := utf8.DecodeRuneInString(str) + count++ + index -= len(collator.KeyWithoutTrimRightSpace(string(r))) + if index <= 0 { + return count + 1 + } + str = str[size:] + } +} + +// timeZone2Duration converts timezone whose format should satisfy the regular condition +// `(^(+|-)(0?[0-9]|1[0-2]):[0-5]?\d$)|(^+13:00$)` to int for use by time.FixedZone(). +func timeZone2int(tz string) int { + sign := 1 + if strings.HasPrefix(tz, "-") { + sign = -1 + } + + i := strings.Index(tz, ":") + h, err := strconv.Atoi(tz[1:i]) + terror.Log(err) + m, err := strconv.Atoi(tz[i+1:]) + terror.Log(err) + return sign * ((h * 3600) + (m * 60)) +} + +var logicalOps = map[string]struct{}{ + ast.LT: {}, + ast.GE: {}, + ast.GT: {}, + ast.LE: {}, + ast.EQ: {}, + ast.NE: {}, + ast.UnaryNot: {}, + ast.LogicAnd: {}, + ast.LogicOr: {}, + ast.LogicXor: {}, + ast.In: {}, + ast.IsNull: {}, + ast.IsTruthWithoutNull: {}, + ast.IsFalsity: {}, + ast.Like: {}, +} + +var oppositeOp = map[string]string{ + ast.LT: ast.GE, + ast.GE: ast.LT, + ast.GT: ast.LE, + ast.LE: ast.GT, + ast.EQ: ast.NE, + ast.NE: ast.EQ, + ast.LogicOr: ast.LogicAnd, + ast.LogicAnd: ast.LogicOr, +} + +// a op b is equal to b symmetricOp a +var symmetricOp = map[opcode.Op]opcode.Op{ + opcode.LT: opcode.GT, + opcode.GE: opcode.LE, + opcode.GT: opcode.LT, + opcode.LE: opcode.GE, + opcode.EQ: opcode.EQ, + opcode.NE: opcode.NE, + opcode.NullEQ: opcode.NullEQ, +} + +func pushNotAcrossArgs(ctx BuildContext, exprs []Expression, not bool) ([]Expression, bool) { + newExprs := make([]Expression, 0, len(exprs)) + flag := false + for _, expr := range exprs { + newExpr, changed := pushNotAcrossExpr(ctx, expr, not) + flag = changed || flag + newExprs = append(newExprs, newExpr) + } + return newExprs, flag +} + +// todo: consider more no precision-loss downcast cases. 
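+// For example, casting a varchar(10) column to varchar(20) with a compatible collation,
+// or an int column to bigint with the same signedness, is treated as lossless here;
+// casting varchar(20) down to varchar(10), or changing the unsigned flag, is not.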
+func noPrecisionLossCastCompatible(cast, argCol *types.FieldType) bool { + // now only consider varchar type and integer. + if !(types.IsTypeVarchar(cast.GetType()) && types.IsTypeVarchar(argCol.GetType())) && + !(mysql.IsIntegerType(cast.GetType()) && mysql.IsIntegerType(argCol.GetType())) { + // varchar type and integer on the storage layer is quite same, while the char type has its padding suffix. + return false + } + if types.IsTypeVarchar(cast.GetType()) { + // cast varchar function only bear the flen extension. + if cast.GetFlen() < argCol.GetFlen() { + return false + } + if !collate.CompatibleCollate(cast.GetCollate(), argCol.GetCollate()) { + return false + } + } else { + // For integers, we should ignore the potential display length represented by flen, using the default flen of the type. + castFlen, _ := mysql.GetDefaultFieldLengthAndDecimal(cast.GetType()) + originFlen, _ := mysql.GetDefaultFieldLengthAndDecimal(argCol.GetType()) + // cast integer function only bear the flen extension and signed symbol unchanged. + if castFlen < originFlen { + return false + } + if mysql.HasUnsignedFlag(cast.GetFlag()) != mysql.HasUnsignedFlag(argCol.GetFlag()) { + return false + } + } + return true +} + +func unwrapCast(sctx BuildContext, parentF *ScalarFunction, castOffset int) (Expression, bool) { + _, collation := parentF.CharsetAndCollation() + cast, ok := parentF.GetArgs()[castOffset].(*ScalarFunction) + if !ok || cast.FuncName.L != ast.Cast { + return parentF, false + } + // eg: if (cast(A) EQ const) with incompatible collation, even if cast is eliminated, the condition still can not be used to build range. + if cast.RetType.EvalType() == types.ETString && !collate.CompatibleCollate(cast.RetType.GetCollate(), collation) { + return parentF, false + } + // 1-castOffset should be constant + if _, ok := parentF.GetArgs()[1-castOffset].(*Constant); !ok { + return parentF, false + } + + // the direct args of cast function should be column. + c, ok := cast.GetArgs()[0].(*Column) + if !ok { + return parentF, false + } + + // current only consider varchar and integer + if !noPrecisionLossCastCompatible(cast.RetType, c.RetType) { + return parentF, false + } + + // the column is covered by indexes, deconstructing it out. + if castOffset == 0 { + return NewFunctionInternal(sctx, parentF.FuncName.L, parentF.RetType, c, parentF.GetArgs()[1]), true + } + return NewFunctionInternal(sctx, parentF.FuncName.L, parentF.RetType, parentF.GetArgs()[0], c), true +} + +// eliminateCastFunction will detect the original arg before and the cast type after, once upon +// there is no precision loss between them, current cast wrapper can be eliminated. For string +// type, collation is also taken into consideration. (mainly used to build range or point) +func eliminateCastFunction(sctx BuildContext, expr Expression) (_ Expression, changed bool) { + f, ok := expr.(*ScalarFunction) + if !ok { + return expr, false + } + _, collation := expr.CharsetAndCollation() + switch f.FuncName.L { + case ast.LogicOr: + dnfItems := FlattenDNFConditions(f) + rmCast := false + rmCastItems := make([]Expression, len(dnfItems)) + for i, dnfItem := range dnfItems { + newExpr, curDowncast := eliminateCastFunction(sctx, dnfItem) + rmCastItems[i] = newExpr + if curDowncast { + rmCast = true + } + } + if rmCast { + // compose the new DNF expression. 
+ return ComposeDNFCondition(sctx, rmCastItems...), true + } + return expr, false + case ast.LogicAnd: + cnfItems := FlattenCNFConditions(f) + rmCast := false + rmCastItems := make([]Expression, len(cnfItems)) + for i, cnfItem := range cnfItems { + newExpr, curDowncast := eliminateCastFunction(sctx, cnfItem) + rmCastItems[i] = newExpr + if curDowncast { + rmCast = true + } + } + if rmCast { + // compose the new CNF expression. + return ComposeCNFCondition(sctx, rmCastItems...), true + } + return expr, false + case ast.EQ, ast.NullEQ, ast.LE, ast.GE, ast.LT, ast.GT: + // for case: eq(cast(test.t2.a, varchar(100), "aaaaa"), once t2.a is covered by index or pk, try deconstructing it out. + if newF, ok := unwrapCast(sctx, f, 0); ok { + return newF, true + } + // for case: eq("aaaaa", cast(test.t2.a, varchar(100)), once t2.a is covered by index or pk, try deconstructing it out. + if newF, ok := unwrapCast(sctx, f, 1); ok { + return newF, true + } + case ast.In: + // case for: cast(a as bigint) in (1,2,3), we could deconstruct column 'a out directly. + cast, ok := f.GetArgs()[0].(*ScalarFunction) + if !ok || cast.FuncName.L != ast.Cast { + return expr, false + } + // eg: if (cast(A) IN {const}) with incompatible collation, even if cast is eliminated, the condition still can not be used to build range. + if cast.RetType.EvalType() == types.ETString && !collate.CompatibleCollate(cast.RetType.GetCollate(), collation) { + return expr, false + } + for _, arg := range f.GetArgs()[1:] { + if _, ok := arg.(*Constant); !ok { + return expr, false + } + } + // the direct args of cast function should be column. + c, ok := cast.GetArgs()[0].(*Column) + if !ok { + return expr, false + } + // current only consider varchar and integer + if !noPrecisionLossCastCompatible(cast.RetType, c.RetType) { + return expr, false + } + newArgs := []Expression{c} + newArgs = append(newArgs, f.GetArgs()[1:]...) + return NewFunctionInternal(sctx, f.FuncName.L, f.RetType, newArgs...), true + } + return expr, false +} + +// pushNotAcrossExpr try to eliminate the NOT expr in expression tree. +// Input `not` indicates whether there's a `NOT` be pushed down. +// Output `changed` indicates whether the output expression differs from the +// input `expr` because of the pushed-down-not. 
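+// For example (hypothetical columns), with `not` set to true, `a < 1` is rewritten to
+// `a >= 1`, and `a < 1 AND b > 2` is rewritten to `a >= 1 OR b <= 2` (De Morgan plus
+// flipping each comparison via oppositeOp).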
+func pushNotAcrossExpr(ctx BuildContext, expr Expression, not bool) (_ Expression, changed bool) { + if f, ok := expr.(*ScalarFunction); ok { + switch f.FuncName.L { + case ast.UnaryNot: + child, err := wrapWithIsTrue(ctx, true, f.GetArgs()[0], true) + if err != nil { + return expr, false + } + var childExpr Expression + childExpr, changed = pushNotAcrossExpr(ctx, child, !not) + if !changed && !not { + return expr, false + } + return childExpr, true + case ast.LT, ast.GE, ast.GT, ast.LE, ast.EQ, ast.NE: + if not { + return NewFunctionInternal(ctx, oppositeOp[f.FuncName.L], f.GetType(ctx.GetEvalCtx()), f.GetArgs()...), true + } + newArgs, changed := pushNotAcrossArgs(ctx, f.GetArgs(), false) + if !changed { + return f, false + } + return NewFunctionInternal(ctx, f.FuncName.L, f.GetType(ctx.GetEvalCtx()), newArgs...), true + case ast.LogicAnd, ast.LogicOr: + var ( + newArgs []Expression + changed bool + ) + funcName := f.FuncName.L + if not { + newArgs, _ = pushNotAcrossArgs(ctx, f.GetArgs(), true) + funcName = oppositeOp[f.FuncName.L] + changed = true + } else { + newArgs, changed = pushNotAcrossArgs(ctx, f.GetArgs(), false) + } + if !changed { + return f, false + } + return NewFunctionInternal(ctx, funcName, f.GetType(ctx.GetEvalCtx()), newArgs...), true + } + } + if not { + expr = NewFunctionInternal(ctx, ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), expr) + } + return expr, not +} + +// GetExprInsideIsTruth get the expression inside the `istrue_with_null` and `istrue`. +// This is useful when handling expressions from "not" or "!", because we might wrap `istrue_with_null` or `istrue` +// when handling them. See pushNotAcrossExpr() and wrapWithIsTrue() for details. +func GetExprInsideIsTruth(expr Expression) Expression { + if f, ok := expr.(*ScalarFunction); ok { + switch f.FuncName.L { + case ast.IsTruthWithNull, ast.IsTruthWithoutNull: + return GetExprInsideIsTruth(f.GetArgs()[0]) + default: + return expr + } + } + return expr +} + +// PushDownNot pushes the `not` function down to the expression's arguments. +func PushDownNot(ctx BuildContext, expr Expression) Expression { + newExpr, _ := pushNotAcrossExpr(ctx, expr, false) + return newExpr +} + +// EliminateNoPrecisionLossCast remove the redundant cast function for range build convenience. +// 1: deeper cast embedded in other complicated function will not be considered. +// 2: cast args should be one for original base column and one for constant. +// 3: some collation compatibility and precision loss will be considered when remove this cast func. +func EliminateNoPrecisionLossCast(sctx BuildContext, expr Expression) Expression { + newExpr, _ := eliminateCastFunction(sctx, expr) + return newExpr +} + +// ContainOuterNot checks if there is an outer `not`. +func ContainOuterNot(expr Expression) bool { + return containOuterNot(expr, false) +} + +// containOuterNot checks if there is an outer `not`. +// Input `not` means whether there is `not` outside `expr` +// +// eg. 
+// +// not(0+(t.a == 1 and t.b == 2)) returns true +// not(t.a) and not(t.b) returns false +func containOuterNot(expr Expression, not bool) bool { + if f, ok := expr.(*ScalarFunction); ok { + switch f.FuncName.L { + case ast.UnaryNot: + return containOuterNot(f.GetArgs()[0], true) + case ast.IsTruthWithNull, ast.IsNull: + return containOuterNot(f.GetArgs()[0], not) + default: + if not { + return true + } + hasNot := false + for _, expr := range f.GetArgs() { + hasNot = hasNot || containOuterNot(expr, not) + if hasNot { + return hasNot + } + } + return hasNot + } + } + return false +} + +// Contains tests if `exprs` contains `e`. +func Contains(ectx EvalContext, exprs []Expression, e Expression) bool { + for _, expr := range exprs { + // Check string equivalence if one of the expressions is a clone. + sameString := false + if e != nil && expr != nil { + sameString = (e.StringWithCtx(ectx, errors.RedactLogDisable) == expr.StringWithCtx(ectx, errors.RedactLogDisable)) + } + if e == expr || sameString { + return true + } + } + return false +} + +// ExtractFiltersFromDNFs checks whether the cond is DNF. If so, it will get the extracted part and the remained part. +// The original DNF will be replaced by the remained part or just be deleted if remained part is nil. +// And the extracted part will be appended to the end of the original slice. +func ExtractFiltersFromDNFs(ctx BuildContext, conditions []Expression) []Expression { + var allExtracted []Expression + for i := len(conditions) - 1; i >= 0; i-- { + if sf, ok := conditions[i].(*ScalarFunction); ok && sf.FuncName.L == ast.LogicOr { + extracted, remained := extractFiltersFromDNF(ctx, sf) + allExtracted = append(allExtracted, extracted...) + if remained == nil { + conditions = append(conditions[:i], conditions[i+1:]...) + } else { + conditions[i] = remained + } + } + } + return append(conditions, allExtracted...) +} + +// extractFiltersFromDNF extracts the same condition that occurs in every DNF item and remove them from dnf leaves. +func extractFiltersFromDNF(ctx BuildContext, dnfFunc *ScalarFunction) ([]Expression, Expression) { + dnfItems := FlattenDNFConditions(dnfFunc) + codeMap := make(map[string]int) + hashcode2Expr := make(map[string]Expression) + for i, dnfItem := range dnfItems { + innerMap := make(map[string]struct{}) + cnfItems := SplitCNFItems(dnfItem) + for _, cnfItem := range cnfItems { + code := cnfItem.HashCode() + if i == 0 { + codeMap[string(code)] = 1 + hashcode2Expr[string(code)] = cnfItem + } else if _, ok := codeMap[string(code)]; ok { + // We need this check because there may be the case like `select * from t, t1 where (t.a=t1.a and t.a=t1.a) or (something). + // We should make sure that the two `t.a=t1.a` contributes only once. + // TODO: do this out of this function. + if _, ok = innerMap[string(code)]; !ok { + codeMap[string(code)]++ + innerMap[string(code)] = struct{}{} + } + } + } + } + // We should make sure that this item occurs in every DNF item. 
+ for hashcode, cnt := range codeMap { + if cnt < len(dnfItems) { + delete(hashcode2Expr, hashcode) + } + } + if len(hashcode2Expr) == 0 { + return nil, dnfFunc + } + newDNFItems := make([]Expression, 0, len(dnfItems)) + onlyNeedExtracted := false + for _, dnfItem := range dnfItems { + cnfItems := SplitCNFItems(dnfItem) + newCNFItems := make([]Expression, 0, len(cnfItems)) + for _, cnfItem := range cnfItems { + code := cnfItem.HashCode() + _, ok := hashcode2Expr[string(code)] + if !ok { + newCNFItems = append(newCNFItems, cnfItem) + } + } + // If the extracted part is just one leaf of the DNF expression. Then the value of the total DNF expression is + // always the same with the value of the extracted part. + if len(newCNFItems) == 0 { + onlyNeedExtracted = true + break + } + newDNFItems = append(newDNFItems, ComposeCNFCondition(ctx, newCNFItems...)) + } + extractedExpr := make([]Expression, 0, len(hashcode2Expr)) + for _, expr := range hashcode2Expr { + extractedExpr = append(extractedExpr, expr) + } + if onlyNeedExtracted { + return extractedExpr, nil + } + return extractedExpr, ComposeDNFCondition(ctx, newDNFItems...) +} + +// DeriveRelaxedFiltersFromDNF given a DNF expression, derive a relaxed DNF expression which only contains columns +// in specified schema; the derived expression is a superset of original expression, i.e, any tuple satisfying +// the original expression must satisfy the derived expression. Return nil when the derived expression is universal set. +// A running example is: for schema of t1, `(t1.a=1 and t2.a=1) or (t1.a=2 and t2.a=2)` would be derived as +// `t1.a=1 or t1.a=2`, while `t1.a=1 or t2.a=1` would get nil. +func DeriveRelaxedFiltersFromDNF(ctx BuildContext, expr Expression, schema *Schema) Expression { + sf, ok := expr.(*ScalarFunction) + if !ok || sf.FuncName.L != ast.LogicOr { + return nil + } + dnfItems := FlattenDNFConditions(sf) + newDNFItems := make([]Expression, 0, len(dnfItems)) + for _, dnfItem := range dnfItems { + cnfItems := SplitCNFItems(dnfItem) + newCNFItems := make([]Expression, 0, len(cnfItems)) + for _, cnfItem := range cnfItems { + if itemSF, ok := cnfItem.(*ScalarFunction); ok && itemSF.FuncName.L == ast.LogicOr { + relaxedCNFItem := DeriveRelaxedFiltersFromDNF(ctx, cnfItem, schema) + if relaxedCNFItem != nil { + newCNFItems = append(newCNFItems, relaxedCNFItem) + } + // If relaxed expression for embedded DNF is universal set, just drop this CNF item + continue + } + // This cnfItem must be simple expression now + // If it cannot be fully covered by schema, just drop this CNF item + if ExprFromSchema(cnfItem, schema) { + newCNFItems = append(newCNFItems, cnfItem) + } + } + // If this DNF item involves no column of specified schema, the relaxed expression must be universal set + if len(newCNFItems) == 0 { + return nil + } + newDNFItems = append(newDNFItems, ComposeCNFCondition(ctx, newCNFItems...)) + } + return ComposeDNFCondition(ctx, newDNFItems...) +} + +// GetRowLen gets the length if the func is row, returns 1 if not row. +func GetRowLen(e Expression) int { + if f, ok := e.(*ScalarFunction); ok && f.FuncName.L == ast.RowFunc { + return len(f.GetArgs()) + } + return 1 +} + +// CheckArgsNotMultiColumnRow checks the args are not multi-column row. +func CheckArgsNotMultiColumnRow(args ...Expression) error { + for _, arg := range args { + if GetRowLen(arg) != 1 { + return ErrOperandColumns.GenWithStackByArgs(1) + } + } + return nil +} + +// GetFuncArg gets the argument of the function at idx. 
+func GetFuncArg(e Expression, idx int) Expression { + if f, ok := e.(*ScalarFunction); ok { + return f.GetArgs()[idx] + } + return nil +} + +// PopRowFirstArg pops the first element and returns the rest of row. +// e.g. After this function (1, 2, 3) becomes (2, 3). +func PopRowFirstArg(ctx BuildContext, e Expression) (ret Expression, err error) { + if f, ok := e.(*ScalarFunction); ok && f.FuncName.L == ast.RowFunc { + args := f.GetArgs() + if len(args) == 2 { + return args[1], nil + } + ret, err = NewFunction(ctx, ast.RowFunc, f.GetType(ctx.GetEvalCtx()), args[1:]...) + return ret, err + } + return +} + +// DatumToConstant generates a Constant expression from a Datum. +func DatumToConstant(d types.Datum, tp byte, flag uint) *Constant { + t := types.NewFieldType(tp) + t.AddFlag(flag) + return &Constant{Value: d, RetType: t} +} + +// ParamMarkerExpression generate a getparam function expression. +func ParamMarkerExpression(ctx variable.SessionVarsProvider, v *driver.ParamMarkerExpr, needParam bool) (*Constant, error) { + useCache := ctx.GetSessionVars().StmtCtx.UseCache() + tp := types.NewFieldType(mysql.TypeUnspecified) + types.InferParamTypeFromDatum(&v.Datum, tp) + value := &Constant{Value: v.Datum, RetType: tp} + if useCache || needParam { + value.ParamMarker = &ParamMarker{ + order: v.Order, + } + } + return value, nil +} + +// ParamMarkerInPrepareChecker checks whether the given ast tree has paramMarker and is in prepare statement. +type ParamMarkerInPrepareChecker struct { + InPrepareStmt bool +} + +// Enter implements Visitor Interface. +func (pc *ParamMarkerInPrepareChecker) Enter(in ast.Node) (out ast.Node, skipChildren bool) { + switch v := in.(type) { + case *driver.ParamMarkerExpr: + pc.InPrepareStmt = !v.InExecute + return v, true + } + return in, false +} + +// Leave implements Visitor Interface. +func (pc *ParamMarkerInPrepareChecker) Leave(in ast.Node) (out ast.Node, ok bool) { + return in, true +} + +// DisableParseJSONFlag4Expr disables ParseToJSONFlag for `expr` except Column. +// We should not *PARSE* a string as JSON under some scenarios. ParseToJSONFlag +// is 0 for JSON column yet(as well as JSON correlated column), so we can skip +// it. Moreover, Column.RetType refers to the infoschema, if we modify it, data +// race may happen if another goroutine read from the infoschema at the same +// time. +func DisableParseJSONFlag4Expr(ctx EvalContext, expr Expression) { + if _, isColumn := expr.(*Column); isColumn { + return + } + if _, isCorCol := expr.(*CorrelatedColumn); isCorCol { + return + } + expr.GetType(ctx).SetFlag(expr.GetType(ctx).GetFlag() & ^mysql.ParseToJSONFlag) +} + +// ConstructPositionExpr constructs PositionExpr with the given ParamMarkerExpr. +func ConstructPositionExpr(p *driver.ParamMarkerExpr) *ast.PositionExpr { + return &ast.PositionExpr{P: p} +} + +// PosFromPositionExpr generates a position value from PositionExpr. +func PosFromPositionExpr(ctx BuildContext, vars variable.SessionVarsProvider, v *ast.PositionExpr) (int, bool, error) { + if v.P == nil { + return v.N, false, nil + } + value, err := ParamMarkerExpression(vars, v.P.(*driver.ParamMarkerExpr), false) + if err != nil { + return 0, true, err + } + pos, isNull, err := GetIntFromConstant(ctx.GetEvalCtx(), value) + if err != nil || isNull { + return 0, true, err + } + return pos, false, nil +} + +// GetStringFromConstant gets a string value from the Constant expression. 
+func GetStringFromConstant(ctx EvalContext, value Expression) (string, bool, error) { + con, ok := value.(*Constant) + if !ok { + err := errors.Errorf("Not a Constant expression %+v", value) + return "", true, err + } + str, isNull, err := con.EvalString(ctx, chunk.Row{}) + if err != nil || isNull { + return "", true, err + } + return str, false, nil +} + +// GetIntFromConstant gets an integer value from the Constant expression. +func GetIntFromConstant(ctx EvalContext, value Expression) (int, bool, error) { + str, isNull, err := GetStringFromConstant(ctx, value) + if err != nil || isNull { + return 0, true, err + } + intNum, err := strconv.Atoi(str) + if err != nil { + return 0, true, nil + } + return intNum, false, nil +} + +// BuildNotNullExpr wraps up `not(isnull())` for given expression. +func BuildNotNullExpr(ctx BuildContext, expr Expression) Expression { + isNull := NewFunctionInternal(ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), expr) + notNull := NewFunctionInternal(ctx, ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), isNull) + return notNull +} + +// IsRuntimeConstExpr checks if a expr can be treated as a constant in **executor**. +func IsRuntimeConstExpr(expr Expression) bool { + switch x := expr.(type) { + case *ScalarFunction: + if _, ok := unFoldableFunctions[x.FuncName.L]; ok { + return false + } + for _, arg := range x.GetArgs() { + if !IsRuntimeConstExpr(arg) { + return false + } + } + return true + case *Column: + return false + case *Constant, *CorrelatedColumn: + return true + } + return false +} + +// CheckNonDeterministic checks whether the current expression contains a non-deterministic func. +func CheckNonDeterministic(e Expression) bool { + switch x := e.(type) { + case *Constant, *Column, *CorrelatedColumn: + return false + case *ScalarFunction: + if _, ok := unFoldableFunctions[x.FuncName.L]; ok { + return true + } + for _, arg := range x.GetArgs() { + if CheckNonDeterministic(arg) { + return true + } + } + } + return false +} + +// CheckFuncInExpr checks whether there's a given function in the expression. +func CheckFuncInExpr(e Expression, funcName string) bool { + switch x := e.(type) { + case *Constant, *Column, *CorrelatedColumn: + return false + case *ScalarFunction: + if x.FuncName.L == funcName { + return true + } + for _, arg := range x.GetArgs() { + if CheckFuncInExpr(arg, funcName) { + return true + } + } + } + return false +} + +// IsMutableEffectsExpr checks if expr contains function which is mutable or has side effects. +func IsMutableEffectsExpr(expr Expression) bool { + switch x := expr.(type) { + case *ScalarFunction: + if _, ok := mutableEffectsFunctions[x.FuncName.L]; ok { + return true + } + for _, arg := range x.GetArgs() { + if IsMutableEffectsExpr(arg) { + return true + } + } + case *Column: + case *Constant: + if x.DeferredExpr != nil { + return IsMutableEffectsExpr(x.DeferredExpr) + } + } + return false +} + +// IsImmutableFunc checks whether this expression only consists of foldable functions. +// This expression can be evaluated by using `expr.Eval(chunk.Row{})` directly and the result won't change if it's immutable. 
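+// For example, an expression built only from constants and foldable, side-effect-free
+// functions (such as `1 + 2`) is immutable, while any expression containing an
+// unfoldable or mutable-effects function is not.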
+func IsImmutableFunc(expr Expression) bool { + switch x := expr.(type) { + case *ScalarFunction: + if _, ok := unFoldableFunctions[x.FuncName.L]; ok { + return false + } + if _, ok := mutableEffectsFunctions[x.FuncName.L]; ok { + return false + } + for _, arg := range x.GetArgs() { + if !IsImmutableFunc(arg) { + return false + } + } + return true + default: + return true + } +} + +// RemoveDupExprs removes identical exprs. Not that if expr contains functions which +// are mutable or have side effects, we cannot remove it even if it has duplicates; +// if the plan is going to be cached, we cannot remove expressions containing `?` neither. +func RemoveDupExprs(exprs []Expression) []Expression { + res := make([]Expression, 0, len(exprs)) + exists := make(map[string]struct{}, len(exprs)) + for _, expr := range exprs { + key := string(expr.HashCode()) + if _, ok := exists[key]; !ok || IsMutableEffectsExpr(expr) { + res = append(res, expr) + exists[key] = struct{}{} + } + } + return res +} + +// GetUint64FromConstant gets a uint64 from constant expression. +func GetUint64FromConstant(ctx EvalContext, expr Expression) (uint64, bool, bool) { + con, ok := expr.(*Constant) + if !ok { + logutil.BgLogger().Warn("not a constant expression", zap.String("expression", expr.ExplainInfo(ctx))) + return 0, false, false + } + dt := con.Value + if con.ParamMarker != nil { + var err error + dt, err = con.ParamMarker.GetUserVar(ctx) + if err != nil { + logutil.BgLogger().Warn("get param failed", zap.Error(err)) + return 0, false, false + } + } else if con.DeferredExpr != nil { + var err error + dt, err = con.DeferredExpr.Eval(ctx, chunk.Row{}) + if err != nil { + logutil.BgLogger().Warn("eval deferred expr failed", zap.Error(err)) + return 0, false, false + } + } + switch dt.Kind() { + case types.KindNull: + return 0, true, true + case types.KindInt64: + val := dt.GetInt64() + if val < 0 { + return 0, false, false + } + return uint64(val), false, true + case types.KindUint64: + return dt.GetUint64(), false, true + } + return 0, false, false +} + +// ContainVirtualColumn checks if the expressions contain a virtual column +func ContainVirtualColumn(exprs []Expression) bool { + for _, expr := range exprs { + switch v := expr.(type) { + case *Column: + if v.VirtualExpr != nil { + return true + } + case *ScalarFunction: + if ContainVirtualColumn(v.GetArgs()) { + return true + } + } + } + return false +} + +// ContainCorrelatedColumn checks if the expressions contain a correlated column +func ContainCorrelatedColumn(exprs []Expression) bool { + for _, expr := range exprs { + switch v := expr.(type) { + case *CorrelatedColumn: + return true + case *ScalarFunction: + if ContainCorrelatedColumn(v.GetArgs()) { + return true + } + } + } + return false +} + +func jsonUnquoteFunctionBenefitsFromPushedDown(sf *ScalarFunction) bool { + arg0 := sf.GetArgs()[0] + // Only `->>` which parsed to JSONUnquote(CAST(JSONExtract() AS string)) can be pushed down to tikv + if fChild, ok := arg0.(*ScalarFunction); ok { + if fChild.FuncName.L == ast.Cast { + if fGrand, ok := fChild.GetArgs()[0].(*ScalarFunction); ok { + if fGrand.FuncName.L == ast.JSONExtract { + return true + } + } + } + } + return false +} + +// ProjectionBenefitsFromPushedDown evaluates if the expressions can improve performance when pushed down to TiKV +// Projections are not pushed down to tikv by default, thus we need to check strictly here to avoid potential performance degradation. 
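+// For example, a projection that merely prunes columns (keeping, say, 2 of 10 input
+// columns) qualifies, and so does one built from the JSON functions listed below;
+// a projection computing `a + b` does not.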
+// Note: virtual column is not considered here, since this function cares performance instead of functionality +func ProjectionBenefitsFromPushedDown(exprs []Expression, inputSchemaLen int) bool { + allColRef := true + colRefCount := 0 + for _, expr := range exprs { + switch v := expr.(type) { + case *Column: + colRefCount = colRefCount + 1 + continue + case *ScalarFunction: + allColRef = false + switch v.FuncName.L { + case ast.JSONDepth, ast.JSONLength, ast.JSONType, ast.JSONValid, ast.JSONContains, ast.JSONContainsPath, + ast.JSONExtract, ast.JSONKeys, ast.JSONSearch, ast.JSONMemberOf, ast.JSONOverlaps: + continue + case ast.JSONUnquote: + if jsonUnquoteFunctionBenefitsFromPushedDown(v) { + continue + } + return false + default: + return false + } + default: + return false + } + } + // For all col refs, only push down column pruning projections + if allColRef { + return colRefCount < inputSchemaLen + } + return true +} + +// MaybeOverOptimized4PlanCache used to check whether an optimization can work +// for the statement when we enable the plan cache. +// In some situations, some optimizations maybe over-optimize and cache an +// overOptimized plan. The cached plan may not get the correct result when we +// reuse the plan for other statements. +// For example, `pk>=$a and pk<=$b` can be optimized to a PointGet when +// `$a==$b`, but it will cause wrong results when `$a!=$b`. +// So we need to do the check here. The check includes the following aspects: +// 1. Whether the plan cache switch is enable. +// 2. Whether the statement can be cached. +// 3. Whether the expressions contain a lazy constant. +// TODO: Do more careful check here. +func MaybeOverOptimized4PlanCache(ctx BuildContext, exprs []Expression) bool { + // If we do not enable plan cache, all the optimization can work correctly. + if !ctx.IsUseCache() { + return false + } + return containMutableConst(ctx.GetEvalCtx(), exprs) +} + +// containMutableConst checks if the expressions contain a lazy constant. +func containMutableConst(ctx EvalContext, exprs []Expression) bool { + for _, expr := range exprs { + switch v := expr.(type) { + case *Constant: + if v.ParamMarker != nil || v.DeferredExpr != nil { + return true + } + case *ScalarFunction: + if containMutableConst(ctx, v.GetArgs()) { + return true + } + } + } + return false +} + +// RemoveMutableConst used to remove the `ParamMarker` and `DeferredExpr` in the `Constant` expr. +func RemoveMutableConst(ctx BuildContext, exprs []Expression) (err error) { + for _, expr := range exprs { + switch v := expr.(type) { + case *Constant: + v.ParamMarker = nil + if v.DeferredExpr != nil { // evaluate and update v.Value to convert v to a complete immutable constant. + // TODO: remove or hide DeferredExpr since it's too dangerous (hard to be consistent with v.Value all the time). + v.Value, err = v.DeferredExpr.Eval(ctx.GetEvalCtx(), chunk.Row{}) + if err != nil { + return err + } + v.DeferredExpr = nil + } + v.DeferredExpr = nil // do nothing since v.Value has already been evaluated in this case. + case *ScalarFunction: + return RemoveMutableConst(ctx, v.GetArgs()) + } + } + return nil +} + +const ( + _ = iota + kib = 1 << (10 * iota) + mib = 1 << (10 * iota) + gib = 1 << (10 * iota) + tib = 1 << (10 * iota) + pib = 1 << (10 * iota) + eib = 1 << (10 * iota) +) + +const ( + nano = 1 + micro = 1000 * nano + milli = 1000 * micro + sec = 1000 * milli + min = 60 * sec + hour = 60 * min + dayTime = 24 * hour +) + +// GetFormatBytes convert byte count to value with units. 
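+// For example, 1536 is formatted as "1.50 KiB" and 512 as "512 bytes".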
+func GetFormatBytes(bytes float64) string { + var divisor float64 + var unit string + + bytesAbs := math.Abs(bytes) + if bytesAbs >= eib { + divisor = eib + unit = "EiB" + } else if bytesAbs >= pib { + divisor = pib + unit = "PiB" + } else if bytesAbs >= tib { + divisor = tib + unit = "TiB" + } else if bytesAbs >= gib { + divisor = gib + unit = "GiB" + } else if bytesAbs >= mib { + divisor = mib + unit = "MiB" + } else if bytesAbs >= kib { + divisor = kib + unit = "KiB" + } else { + divisor = 1 + unit = "bytes" + } + + if divisor == 1 { + return strconv.FormatFloat(bytes, 'f', 0, 64) + " " + unit + } + value := bytes / divisor + if math.Abs(value) >= 100000.0 { + return strconv.FormatFloat(value, 'e', 2, 64) + " " + unit + } + return strconv.FormatFloat(value, 'f', 2, 64) + " " + unit +} + +// GetFormatNanoTime convert time in nanoseconds to value with units. +func GetFormatNanoTime(time float64) string { + var divisor float64 + var unit string + + timeAbs := math.Abs(time) + if timeAbs >= dayTime { + divisor = dayTime + unit = "d" + } else if timeAbs >= hour { + divisor = hour + unit = "h" + } else if timeAbs >= min { + divisor = min + unit = "min" + } else if timeAbs >= sec { + divisor = sec + unit = "s" + } else if timeAbs >= milli { + divisor = milli + unit = "ms" + } else if timeAbs >= micro { + divisor = micro + unit = "us" + } else { + divisor = 1 + unit = "ns" + } + + if divisor == 1 { + return strconv.FormatFloat(time, 'f', 0, 64) + " " + unit + } + value := time / divisor + if math.Abs(value) >= 100000.0 { + return strconv.FormatFloat(value, 'e', 2, 64) + " " + unit + } + return strconv.FormatFloat(value, 'f', 2, 64) + " " + unit +} + +// SQLDigestTextRetriever is used to find the normalized SQL statement text by SQL digests in statements_summary table. +// It's exported for test purposes. It's used by the `tidb_decode_sql_digests` builtin function, but also exposed to +// be used in other modules. +type SQLDigestTextRetriever struct { + // SQLDigestsMap is the place to put the digests that's requested for getting SQL text and also the place to put + // the query result. + SQLDigestsMap map[string]string + + // Replace querying for test purposes. + mockLocalData map[string]string + mockGlobalData map[string]string + // There are two ways for querying information: 1) query specified digests by WHERE IN query, or 2) query all + // information to avoid the too long WHERE IN clause. If there are more than `fetchAllLimit` digests needs to be + // queried, the second way will be chosen; otherwise, the first way will be chosen. + fetchAllLimit int +} + +// NewSQLDigestTextRetriever creates a new SQLDigestTextRetriever. +func NewSQLDigestTextRetriever() *SQLDigestTextRetriever { + return &SQLDigestTextRetriever{ + SQLDigestsMap: make(map[string]string), + fetchAllLimit: 512, + } +} + +func (r *SQLDigestTextRetriever) runMockQuery(data map[string]string, inValues []any) (map[string]string, error) { + if len(inValues) == 0 { + return data, nil + } + res := make(map[string]string, len(inValues)) + for _, digest := range inValues { + if text, ok := data[digest.(string)]; ok { + res[digest.(string)] = text + } + } + return res, nil +} + +// runFetchDigestQuery runs query to the system tables to fetch the kv mapping of SQL digests and normalized SQL texts +// of the given SQL digests, if `inValues` is given, or all these mappings otherwise. 
If `queryGlobal` is false, it +// queries information_schema.statements_summary and information_schema.statements_summary_history; otherwise, it +// queries the cluster version of these two tables. +func (r *SQLDigestTextRetriever) runFetchDigestQuery(ctx context.Context, exec contextopt.SQLExecutor, queryGlobal bool, inValues []any) (map[string]string, error) { + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnOthers) + // If mock data is set, query the mock data instead of the real statements_summary tables. + if !queryGlobal && r.mockLocalData != nil { + return r.runMockQuery(r.mockLocalData, inValues) + } else if queryGlobal && r.mockGlobalData != nil { + return r.runMockQuery(r.mockGlobalData, inValues) + } + + // Information in statements_summary will be periodically moved to statements_summary_history. Union them together + // to avoid missing information when statements_summary is just cleared. + stmt := "select digest, digest_text from information_schema.statements_summary union distinct " + + "select digest, digest_text from information_schema.statements_summary_history" + if queryGlobal { + stmt = "select digest, digest_text from information_schema.cluster_statements_summary union distinct " + + "select digest, digest_text from information_schema.cluster_statements_summary_history" + } + // Add the where clause if `inValues` is specified. + if len(inValues) > 0 { + stmt += " where digest in (" + strings.Repeat("%?,", len(inValues)-1) + "%?)" + } + + rows, _, err := exec.ExecRestrictedSQL(ctx, nil, stmt, inValues...) + if err != nil { + return nil, err + } + + res := make(map[string]string, len(rows)) + for _, row := range rows { + res[row.GetString(0)] = row.GetString(1) + } + return res, nil +} + +func (r *SQLDigestTextRetriever) updateDigestInfo(queryResult map[string]string) { + for digest, text := range r.SQLDigestsMap { + if len(text) > 0 { + // The text of this digest is already known + continue + } + sqlText, ok := queryResult[digest] + if ok { + r.SQLDigestsMap[digest] = sqlText + } + } +} + +// RetrieveLocal tries to retrieve the SQL text of the SQL digests from local information. +func (r *SQLDigestTextRetriever) RetrieveLocal(ctx context.Context, exec contextopt.SQLExecutor) error { + if len(r.SQLDigestsMap) == 0 { + return nil + } + + var queryResult map[string]string + if len(r.SQLDigestsMap) <= r.fetchAllLimit { + inValues := make([]any, 0, len(r.SQLDigestsMap)) + for key := range r.SQLDigestsMap { + inValues = append(inValues, key) + } + var err error + queryResult, err = r.runFetchDigestQuery(ctx, exec, false, inValues) + if err != nil { + return errors.Trace(err) + } + + if len(queryResult) == len(r.SQLDigestsMap) { + r.SQLDigestsMap = queryResult + return nil + } + } else { + var err error + queryResult, err = r.runFetchDigestQuery(ctx, exec, false, nil) + if err != nil { + return errors.Trace(err) + } + } + + r.updateDigestInfo(queryResult) + return nil +} + +// RetrieveGlobal tries to retrieve the SQL text of the SQL digests from the information of the whole cluster. +func (r *SQLDigestTextRetriever) RetrieveGlobal(ctx context.Context, exec contextopt.SQLExecutor) error { + err := r.RetrieveLocal(ctx, exec) + if err != nil { + return errors.Trace(err) + } + + // In some unit test environments it's unable to retrieve global info, and this function blocks it for tens of + // seconds, which wastes much time during unit test. In this case, enable this failpoint to bypass retrieving + // globally. 
+ failpoint.Inject("sqlDigestRetrieverSkipRetrieveGlobal", func() { + failpoint.Return(nil) + }) + + var unknownDigests []any + for k, v := range r.SQLDigestsMap { + if len(v) == 0 { + unknownDigests = append(unknownDigests, k) + } + } + + if len(unknownDigests) == 0 { + return nil + } + + var queryResult map[string]string + if len(r.SQLDigestsMap) <= r.fetchAllLimit { + queryResult, err = r.runFetchDigestQuery(ctx, exec, true, unknownDigests) + if err != nil { + return errors.Trace(err) + } + } else { + queryResult, err = r.runFetchDigestQuery(ctx, exec, true, nil) + if err != nil { + return errors.Trace(err) + } + } + + r.updateDigestInfo(queryResult) + return nil +} + +// ExprsToStringsForDisplay convert a slice of Expression to a slice of string using Expression.String(), and +// to make it better for display and debug, it also escapes the string to corresponding golang string literal, +// which means using \t, \n, \x??, \u????, ... to represent newline, control character, non-printable character, +// invalid utf-8 bytes and so on. +func ExprsToStringsForDisplay(ctx EvalContext, exprs []Expression) []string { + strs := make([]string, len(exprs)) + for i, cond := range exprs { + quote := `"` + // We only need the escape functionality of strconv.Quote, the quoting is not needed, + // so we trim the \" prefix and suffix here. + strs[i] = strings.TrimSuffix( + strings.TrimPrefix( + strconv.Quote(cond.StringWithCtx(ctx, errors.RedactLogDisable)), + quote), + quote) + } + return strs +} + +// ConstExprConsiderPlanCache indicates whether the expression can be considered as a constant expression considering planCache. +// If the expression is in plan cache, it should have a const level `ConstStrict` because it can be shared across statements. +// If the expression is not in plan cache, `ConstOnlyInContext` is enough because it is only used in one statement. +// Please notice that if the expression may be cached in other ways except plan cache, we should not use this function. +func ConstExprConsiderPlanCache(expr Expression, inPlanCache bool) bool { + switch expr.ConstLevel() { + case ConstStrict: + return true + case ConstOnlyInContext: + return !inPlanCache + default: + return false + } +} + +// ExprsHasSideEffects checks if any of the expressions has side effects. +func ExprsHasSideEffects(exprs []Expression) bool { + for _, expr := range exprs { + if ExprHasSetVarOrSleep(expr) { + return true + } + } + return false +} + +// ExprHasSetVarOrSleep checks if the expression has SetVar function or Sleep function. 
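+// For example, it returns true for `sleep(1)` and for expressions that embed it,
+// such as `if(a > 0, sleep(1), 0)` (hypothetical column a).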
+func ExprHasSetVarOrSleep(expr Expression) bool { + scalaFunc, isScalaFunc := expr.(*ScalarFunction) + if !isScalaFunc { + return false + } + if scalaFunc.FuncName.L == ast.SetVar || scalaFunc.FuncName.L == ast.Sleep { + return true + } + for _, arg := range scalaFunc.GetArgs() { + if ExprHasSetVarOrSleep(arg) { + return true + } + } + return false +} diff --git a/pkg/infoschema/binding__failpoint_binding__.go b/pkg/infoschema/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..5fad7f8a11a71 --- /dev/null +++ b/pkg/infoschema/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package infoschema + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/infoschema/builder.go b/pkg/infoschema/builder.go index 6e489d29b4721..81e30ea733338 100644 --- a/pkg/infoschema/builder.go +++ b/pkg/infoschema/builder.go @@ -670,13 +670,13 @@ func applyCreateTable(b *Builder, m *meta.Meta, dbInfo *model.DBInfo, tableID in // Failpoint check whether tableInfo should be added to repairInfo. // Typically used in repair table test to load mock `bad` tableInfo into repairInfo. - failpoint.Inject("repairFetchCreateTable", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("repairFetchCreateTable")); _err_ == nil { if val.(bool) { if domainutil.RepairInfo.InRepairMode() && tp != model.ActionRepairTable && domainutil.RepairInfo.CheckAndFetchRepairedTable(dbInfo, tblInfo) { - failpoint.Return(nil, nil) + return nil, nil } } - }) + } ConvertCharsetCollateToLowerCaseIfNeed(tblInfo) ConvertOldVersionUTF8ToUTF8MB4IfNeed(tblInfo) diff --git a/pkg/infoschema/builder.go__failpoint_stash__ b/pkg/infoschema/builder.go__failpoint_stash__ new file mode 100644 index 0000000000000..6e489d29b4721 --- /dev/null +++ b/pkg/infoschema/builder.go__failpoint_stash__ @@ -0,0 +1,1040 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package infoschema + +import ( + "cmp" + "context" + "fmt" + "maps" + "slices" + "strings" + + "github.com/ngaut/pools" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/util/domainutil" + "github.com/pingcap/tidb/pkg/util/intest" +) + +// Builder builds a new InfoSchema. 
+type Builder struct { + enableV2 bool + infoschemaV2 + // dbInfos do not need to be copied everytime applying a diff, instead, + // they can be copied only once over the whole lifespan of Builder. + // This map will indicate which DB has been copied, so that they + // don't need to be copied again. + dirtyDB map[string]bool + + // Used by autoid allocators + autoid.Requirement + + factory func() (pools.Resource, error) + bundleInfoBuilder + infoData *Data + store kv.Storage +} + +// ApplyDiff applies SchemaDiff to the new InfoSchema. +// Return the detail updated table IDs that are produced from SchemaDiff and an error. +func (b *Builder) ApplyDiff(m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { + b.schemaMetaVersion = diff.Version + switch diff.Type { + case model.ActionCreateSchema: + return nil, applyCreateSchema(b, m, diff) + case model.ActionDropSchema: + return applyDropSchema(b, diff), nil + case model.ActionRecoverSchema: + return applyRecoverSchema(b, m, diff) + case model.ActionModifySchemaCharsetAndCollate: + return nil, applyModifySchemaCharsetAndCollate(b, m, diff) + case model.ActionModifySchemaDefaultPlacement: + return nil, applyModifySchemaDefaultPlacement(b, m, diff) + case model.ActionCreatePlacementPolicy: + return nil, applyCreatePolicy(b, m, diff) + case model.ActionDropPlacementPolicy: + return applyDropPolicy(b, diff.SchemaID), nil + case model.ActionAlterPlacementPolicy: + return applyAlterPolicy(b, m, diff) + case model.ActionCreateResourceGroup: + return nil, applyCreateOrAlterResourceGroup(b, m, diff) + case model.ActionAlterResourceGroup: + return nil, applyCreateOrAlterResourceGroup(b, m, diff) + case model.ActionDropResourceGroup: + return applyDropResourceGroup(b, m, diff), nil + case model.ActionTruncateTablePartition, model.ActionTruncateTable: + return applyTruncateTableOrPartition(b, m, diff) + case model.ActionDropTable, model.ActionDropTablePartition: + return applyDropTableOrPartition(b, m, diff) + case model.ActionRecoverTable: + return applyRecoverTable(b, m, diff) + case model.ActionCreateTables: + return applyCreateTables(b, m, diff) + case model.ActionReorganizePartition, model.ActionRemovePartitioning, + model.ActionAlterTablePartitioning: + return applyReorganizePartition(b, m, diff) + case model.ActionExchangeTablePartition: + return applyExchangeTablePartition(b, m, diff) + case model.ActionFlashbackCluster: + return []int64{-1}, nil + default: + return applyDefaultAction(b, m, diff) + } +} + +func (b *Builder) applyCreateTables(m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { + return b.applyAffectedOpts(m, make([]int64, 0, len(diff.AffectedOpts)), diff, model.ActionCreateTable) +} + +func applyTruncateTableOrPartition(b *Builder, m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { + tblIDs, err := applyTableUpdate(b, m, diff) + if err != nil { + return nil, errors.Trace(err) + } + + // bundle ops + if diff.Type == model.ActionTruncateTable { + b.deleteBundle(b.infoSchema, diff.OldTableID) + b.markTableBundleShouldUpdate(diff.TableID) + } + + for _, opt := range diff.AffectedOpts { + if diff.Type == model.ActionTruncateTablePartition { + // Reduce the impact on DML when executing partition DDL. eg. + // While session 1 performs the DML operation associated with partition 1, + // the TRUNCATE operation of session 2 on partition 2 does not cause the operation of session 1 to fail. 
+ tblIDs = append(tblIDs, opt.OldTableID) + b.markPartitionBundleShouldUpdate(opt.TableID) + } + b.deleteBundle(b.infoSchema, opt.OldTableID) + } + return tblIDs, nil +} + +func applyDropTableOrPartition(b *Builder, m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { + tblIDs, err := applyTableUpdate(b, m, diff) + if err != nil { + return nil, errors.Trace(err) + } + + // bundle ops + b.markTableBundleShouldUpdate(diff.TableID) + for _, opt := range diff.AffectedOpts { + b.deleteBundle(b.infoSchema, opt.OldTableID) + } + return tblIDs, nil +} + +func applyReorganizePartition(b *Builder, m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { + tblIDs, err := applyTableUpdate(b, m, diff) + if err != nil { + return nil, errors.Trace(err) + } + + // bundle ops + for _, opt := range diff.AffectedOpts { + if opt.OldTableID != 0 { + b.deleteBundle(b.infoSchema, opt.OldTableID) + } + if opt.TableID != 0 { + b.markTableBundleShouldUpdate(opt.TableID) + } + // TODO: Should we also check markPartitionBundleShouldUpdate?!? + } + return tblIDs, nil +} + +func applyExchangeTablePartition(b *Builder, m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { + // It is not in StatePublic. + if diff.OldTableID == diff.TableID && diff.OldSchemaID == diff.SchemaID { + ntIDs, err := applyTableUpdate(b, m, diff) + if err != nil { + return nil, errors.Trace(err) + } + if diff.AffectedOpts == nil || diff.AffectedOpts[0].OldSchemaID == 0 { + return ntIDs, err + } + // Reload parition tabe. + ptSchemaID := diff.AffectedOpts[0].OldSchemaID + ptID := diff.AffectedOpts[0].TableID + ptDiff := &model.SchemaDiff{ + Type: diff.Type, + Version: diff.Version, + TableID: ptID, + SchemaID: ptSchemaID, + OldTableID: ptID, + OldSchemaID: ptSchemaID, + } + ptIDs, err := applyTableUpdate(b, m, ptDiff) + if err != nil { + return nil, errors.Trace(err) + } + return append(ptIDs, ntIDs...), nil + } + ntSchemaID := diff.OldSchemaID + ntID := diff.OldTableID + ptSchemaID := diff.SchemaID + ptID := diff.TableID + partID := diff.TableID + if len(diff.AffectedOpts) > 0 { + ptID = diff.AffectedOpts[0].TableID + if diff.AffectedOpts[0].SchemaID != 0 { + ptSchemaID = diff.AffectedOpts[0].SchemaID + } + } + // The normal table needs to be updated first: + // Just update the tables separately + currDiff := &model.SchemaDiff{ + // This is only for the case since https://github.com/pingcap/tidb/pull/45877 + // Fixed now, by adding back the AffectedOpts + // to carry the partitioned Table ID. + Type: diff.Type, + Version: diff.Version, + TableID: ntID, + SchemaID: ntSchemaID, + } + if ptID != partID { + currDiff.TableID = partID + currDiff.OldTableID = ntID + currDiff.OldSchemaID = ntSchemaID + } + ntIDs, err := applyTableUpdate(b, m, currDiff) + if err != nil { + return nil, errors.Trace(err) + } + // partID is the new id for the non-partitioned table! + b.markTableBundleShouldUpdate(partID) + // Then the partitioned table, will re-read the whole table, including all partitions! + currDiff.TableID = ptID + currDiff.SchemaID = ptSchemaID + currDiff.OldTableID = ptID + currDiff.OldSchemaID = ptSchemaID + ptIDs, err := applyTableUpdate(b, m, currDiff) + if err != nil { + return nil, errors.Trace(err) + } + // ntID is the new id for the partition! 
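+	// Illustrative ID swap (made-up numbers): if the normal table nt had ID 100
+	// and partition p0 of pt had ID 113, then after EXCHANGE PARTITION nt is
+	// served by ID 113 (partID) and p0 by ID 100 (ntID). No row data moves;
+	// only metadata follows the swapped IDs, which is why the bundles are
+	// re-marked here and the auto-ID bases are aligned just below.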
+ b.markPartitionBundleShouldUpdate(ntID) + err = updateAutoIDForExchangePartition(b.Requirement.Store(), ptSchemaID, ptID, ntSchemaID, ntID) + if err != nil { + return nil, errors.Trace(err) + } + return append(ptIDs, ntIDs...), nil +} + +func applyRecoverTable(b *Builder, m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { + tblIDs, err := applyTableUpdate(b, m, diff) + if err != nil { + return nil, errors.Trace(err) + } + + // bundle ops + for _, opt := range diff.AffectedOpts { + b.markTableBundleShouldUpdate(opt.TableID) + } + return tblIDs, nil +} + +func updateAutoIDForExchangePartition(store kv.Storage, ptSchemaID, ptID, ntSchemaID, ntID int64) error { + err := kv.RunInNewTxn(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), store, true, func(ctx context.Context, txn kv.Transaction) error { + t := meta.NewMeta(txn) + ptAutoIDs, err := t.GetAutoIDAccessors(ptSchemaID, ptID).Get() + if err != nil { + return err + } + + // non-partition table auto IDs. + ntAutoIDs, err := t.GetAutoIDAccessors(ntSchemaID, ntID).Get() + if err != nil { + return err + } + + // Set both tables to the maximum auto IDs between normal table and partitioned table. + newAutoIDs := meta.AutoIDGroup{ + RowID: max(ptAutoIDs.RowID, ntAutoIDs.RowID), + IncrementID: max(ptAutoIDs.IncrementID, ntAutoIDs.IncrementID), + RandomID: max(ptAutoIDs.RandomID, ntAutoIDs.RandomID), + } + err = t.GetAutoIDAccessors(ptSchemaID, ptID).Put(newAutoIDs) + if err != nil { + return err + } + err = t.GetAutoIDAccessors(ntSchemaID, ntID).Put(newAutoIDs) + if err != nil { + return err + } + return nil + }) + + return err +} + +func (b *Builder) applyAffectedOpts(m *meta.Meta, tblIDs []int64, diff *model.SchemaDiff, tp model.ActionType) ([]int64, error) { + if diff.AffectedOpts != nil { + for _, opt := range diff.AffectedOpts { + affectedDiff := &model.SchemaDiff{ + Version: diff.Version, + Type: tp, + SchemaID: opt.SchemaID, + TableID: opt.TableID, + OldSchemaID: opt.OldSchemaID, + OldTableID: opt.OldTableID, + } + affectedIDs, err := b.ApplyDiff(m, affectedDiff) + if err != nil { + return nil, errors.Trace(err) + } + tblIDs = append(tblIDs, affectedIDs...) + } + } + return tblIDs, nil +} + +func applyDefaultAction(b *Builder, m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { + tblIDs, err := applyTableUpdate(b, m, diff) + if err != nil { + return nil, errors.Trace(err) + } + + return b.applyAffectedOpts(m, tblIDs, diff, diff.Type) +} + +func (b *Builder) getTableIDs(diff *model.SchemaDiff) (oldTableID, newTableID int64) { + switch diff.Type { + case model.ActionCreateSequence, model.ActionRecoverTable: + newTableID = diff.TableID + case model.ActionCreateTable: + // WARN: when support create table with foreign key in https://github.com/pingcap/tidb/pull/37148, + // create table with foreign key requires a multi-step state change(none -> write-only -> public), + // when the table's state changes from write-only to public, infoSchema need to drop the old table + // which state is write-only, otherwise, infoSchema.sortedTablesBuckets will contain 2 table both + // have the same ID, but one state is write-only, another table's state is public, it's unexpected. + // + // WARN: this change will break the compatibility if execute create table with foreign key DDL when upgrading TiDB, + // since old-version TiDB doesn't know to delete the old table. + // Since the cluster-index feature also has similar problem, we chose to prevent DDL execution during the upgrade process to avoid this issue. 
+ oldTableID = diff.OldTableID + newTableID = diff.TableID + case model.ActionDropTable, model.ActionDropView, model.ActionDropSequence: + oldTableID = diff.TableID + case model.ActionTruncateTable, model.ActionCreateView, + model.ActionExchangeTablePartition, model.ActionAlterTablePartitioning, + model.ActionRemovePartitioning: + oldTableID = diff.OldTableID + newTableID = diff.TableID + default: + oldTableID = diff.TableID + newTableID = diff.TableID + } + return +} + +func (b *Builder) updateBundleForTableUpdate(diff *model.SchemaDiff, newTableID, oldTableID int64) { + // handle placement rule cache + switch diff.Type { + case model.ActionCreateTable: + b.markTableBundleShouldUpdate(newTableID) + case model.ActionDropTable: + b.deleteBundle(b.infoSchema, oldTableID) + case model.ActionTruncateTable: + b.deleteBundle(b.infoSchema, oldTableID) + b.markTableBundleShouldUpdate(newTableID) + case model.ActionRecoverTable: + b.markTableBundleShouldUpdate(newTableID) + case model.ActionAlterTablePlacement: + b.markTableBundleShouldUpdate(newTableID) + } +} + +func dropTableForUpdate(b *Builder, newTableID, oldTableID int64, dbInfo *model.DBInfo, diff *model.SchemaDiff) ([]int64, autoid.Allocators, error) { + tblIDs := make([]int64, 0, 2) + var keptAllocs autoid.Allocators + // We try to reuse the old allocator, so the cached auto ID can be reused. + if tableIDIsValid(oldTableID) { + if oldTableID == newTableID && + // For rename table, keep the old alloc. + + // For repairing table in TiDB cluster, given 2 normal node and 1 repair node. + // For normal node's information schema, repaired table is existed. + // For repair node's information schema, repaired table is filtered (couldn't find it in `is`). + // So here skip to reserve the allocators when repairing table. + diff.Type != model.ActionRepairTable && + // Alter sequence will change the sequence info in the allocator, so the old allocator is not valid any more. + diff.Type != model.ActionAlterSequence { + // TODO: Check how this would work with ADD/REMOVE Partitioning, + // which may have AutoID not connected to tableID + // TODO: can there be _tidb_rowid AutoID per partition? + oldAllocs, _ := allocByID(b, oldTableID) + keptAllocs = getKeptAllocators(diff, oldAllocs) + } + + tmpIDs := tblIDs + if (diff.Type == model.ActionRenameTable || diff.Type == model.ActionRenameTables) && diff.OldSchemaID != diff.SchemaID { + oldDBInfo, ok := oldSchemaInfo(b, diff) + if !ok { + return nil, keptAllocs, ErrDatabaseNotExists.GenWithStackByArgs( + fmt.Sprintf("(Schema ID %d)", diff.OldSchemaID), + ) + } + tmpIDs = applyDropTable(b, diff, oldDBInfo, oldTableID, tmpIDs) + } else { + tmpIDs = applyDropTable(b, diff, dbInfo, oldTableID, tmpIDs) + } + + if oldTableID != newTableID { + // Update tblIDs only when oldTableID != newTableID because applyCreateTable() also updates tblIDs. 
+ tblIDs = tmpIDs + } + } + return tblIDs, keptAllocs, nil +} + +func (b *Builder) applyTableUpdate(m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { + roDBInfo, ok := b.infoSchema.SchemaByID(diff.SchemaID) + if !ok { + return nil, ErrDatabaseNotExists.GenWithStackByArgs( + fmt.Sprintf("(Schema ID %d)", diff.SchemaID), + ) + } + dbInfo := b.getSchemaAndCopyIfNecessary(roDBInfo.Name.L) + oldTableID, newTableID := b.getTableIDs(diff) + b.updateBundleForTableUpdate(diff, newTableID, oldTableID) + b.copySortedTables(oldTableID, newTableID) + + tblIDs, allocs, err := dropTableForUpdate(b, newTableID, oldTableID, dbInfo, diff) + if err != nil { + return nil, err + } + + if tableIDIsValid(newTableID) { + // All types except DropTableOrView. + var err error + tblIDs, err = applyCreateTable(b, m, dbInfo, newTableID, allocs, diff.Type, tblIDs, diff.Version) + if err != nil { + return nil, errors.Trace(err) + } + } + return tblIDs, nil +} + +// getKeptAllocators get allocators that is not changed by the DDL. +func getKeptAllocators(diff *model.SchemaDiff, oldAllocs autoid.Allocators) autoid.Allocators { + var autoIDChanged, autoRandomChanged bool + switch diff.Type { + case model.ActionRebaseAutoID, model.ActionModifyTableAutoIdCache: + autoIDChanged = true + case model.ActionRebaseAutoRandomBase: + autoRandomChanged = true + case model.ActionMultiSchemaChange: + for _, t := range diff.SubActionTypes { + switch t { + case model.ActionRebaseAutoID, model.ActionModifyTableAutoIdCache: + autoIDChanged = true + case model.ActionRebaseAutoRandomBase: + autoRandomChanged = true + } + } + } + var newAllocs autoid.Allocators + switch { + case autoIDChanged: + // Only drop auto-increment allocator. + newAllocs = oldAllocs.Filter(func(a autoid.Allocator) bool { + tp := a.GetType() + return tp != autoid.RowIDAllocType && tp != autoid.AutoIncrementType + }) + case autoRandomChanged: + // Only drop auto-random allocator. + newAllocs = oldAllocs.Filter(func(a autoid.Allocator) bool { + tp := a.GetType() + return tp != autoid.AutoRandomType + }) + default: + // Keep all allocators. + newAllocs = oldAllocs + } + return newAllocs +} + +func appendAffectedIDs(affected []int64, tblInfo *model.TableInfo) []int64 { + affected = append(affected, tblInfo.ID) + if pi := tblInfo.GetPartitionInfo(); pi != nil { + for _, def := range pi.Definitions { + affected = append(affected, def.ID) + } + } + return affected +} + +func (b *Builder) applyCreateSchema(m *meta.Meta, diff *model.SchemaDiff) error { + di, err := m.GetDatabase(diff.SchemaID) + if err != nil { + return errors.Trace(err) + } + if di == nil { + // When we apply an old schema diff, the database may has been dropped already, so we need to fall back to + // full load. + return ErrDatabaseNotExists.GenWithStackByArgs( + fmt.Sprintf("(Schema ID %d)", diff.SchemaID), + ) + } + b.addDB(diff.Version, di, &schemaTables{dbInfo: di, tables: make(map[string]table.Table)}) + return nil +} + +func (b *Builder) applyModifySchemaCharsetAndCollate(m *meta.Meta, diff *model.SchemaDiff) error { + di, err := m.GetDatabase(diff.SchemaID) + if err != nil { + return errors.Trace(err) + } + if di == nil { + // This should never happen. 
+ return ErrDatabaseNotExists.GenWithStackByArgs( + fmt.Sprintf("(Schema ID %d)", diff.SchemaID), + ) + } + newDbInfo := b.getSchemaAndCopyIfNecessary(di.Name.L) + newDbInfo.Charset = di.Charset + newDbInfo.Collate = di.Collate + return nil +} + +func (b *Builder) applyModifySchemaDefaultPlacement(m *meta.Meta, diff *model.SchemaDiff) error { + di, err := m.GetDatabase(diff.SchemaID) + if err != nil { + return errors.Trace(err) + } + if di == nil { + // This should never happen. + return ErrDatabaseNotExists.GenWithStackByArgs( + fmt.Sprintf("(Schema ID %d)", diff.SchemaID), + ) + } + newDbInfo := b.getSchemaAndCopyIfNecessary(di.Name.L) + newDbInfo.PlacementPolicyRef = di.PlacementPolicyRef + return nil +} + +func (b *Builder) applyDropSchema(diff *model.SchemaDiff) []int64 { + di, ok := b.infoSchema.SchemaByID(diff.SchemaID) + if !ok { + return nil + } + b.infoSchema.delSchema(di) + + // Copy the sortedTables that contain the table we are going to drop. + tableIDs := make([]int64, 0, len(di.Deprecated.Tables)) + bucketIdxMap := make(map[int]struct{}, len(di.Deprecated.Tables)) + for _, tbl := range di.Deprecated.Tables { + bucketIdxMap[tableBucketIdx(tbl.ID)] = struct{}{} + // TODO: If the table ID doesn't exist. + tableIDs = appendAffectedIDs(tableIDs, tbl) + } + for bucketIdx := range bucketIdxMap { + b.copySortedTablesBucket(bucketIdx) + } + + di = di.Clone() + for _, id := range tableIDs { + b.deleteBundle(b.infoSchema, id) + b.applyDropTable(diff, di, id, nil) + } + return tableIDs +} + +func (b *Builder) applyRecoverSchema(m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { + if di, ok := b.infoSchema.SchemaByID(diff.SchemaID); ok { + return nil, ErrDatabaseExists.GenWithStackByArgs( + fmt.Sprintf("(Schema ID %d)", di.ID), + ) + } + di, err := m.GetDatabase(diff.SchemaID) + if err != nil { + return nil, errors.Trace(err) + } + b.infoSchema.addSchema(&schemaTables{ + dbInfo: di, + tables: make(map[string]table.Table, len(diff.AffectedOpts)), + }) + return applyCreateTables(b, m, diff) +} + +// copySortedTables copies sortedTables for old table and new table for later modification. 
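+// The buckets are the ID-sharded slices used by infoschema V1 (tableBucketIdx
+// picks the bucket, and each bucket stays sorted by table ID). Copying a bucket
+// before touching it is what keeps the builder copy-on-write: the old
+// InfoSchema keeps serving reads from its own slice while the new version
+// mutates a private copy.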
+func (b *Builder) copySortedTables(oldTableID, newTableID int64) { + if tableIDIsValid(oldTableID) { + b.copySortedTablesBucket(tableBucketIdx(oldTableID)) + } + if tableIDIsValid(newTableID) && newTableID != oldTableID { + b.copySortedTablesBucket(tableBucketIdx(newTableID)) + } +} + +func (b *Builder) copySortedTablesBucket(bucketIdx int) { + oldSortedTables := b.infoSchema.sortedTablesBuckets[bucketIdx] + newSortedTables := make(sortedTables, len(oldSortedTables)) + copy(newSortedTables, oldSortedTables) + b.infoSchema.sortedTablesBuckets[bucketIdx] = newSortedTables +} + +func (b *Builder) updateBundleForCreateTable(tblInfo *model.TableInfo, tp model.ActionType) { + switch tp { + case model.ActionDropTablePartition: + case model.ActionTruncateTablePartition: + // ReorganizePartition handle the bundles in applyReorganizePartition + case model.ActionReorganizePartition, model.ActionRemovePartitioning, + model.ActionAlterTablePartitioning: + default: + pi := tblInfo.GetPartitionInfo() + if pi != nil { + for _, partition := range pi.Definitions { + b.markPartitionBundleShouldUpdate(partition.ID) + } + } + } +} + +func (b *Builder) buildAllocsForCreateTable(tp model.ActionType, dbInfo *model.DBInfo, tblInfo *model.TableInfo, allocs autoid.Allocators) autoid.Allocators { + if len(allocs.Allocs) != 0 { + tblVer := autoid.AllocOptionTableInfoVersion(tblInfo.Version) + switch tp { + case model.ActionRebaseAutoID, model.ActionModifyTableAutoIdCache: + idCacheOpt := autoid.CustomAutoIncCacheOption(tblInfo.AutoIdCache) + // If the allocator type might be AutoIncrementType, create both AutoIncrementType + // and RowIDAllocType allocator for it. Because auto id and row id could share the same allocator. + // Allocate auto id may route to allocate row id, if row id allocator is nil, the program panic! + for _, tp := range [2]autoid.AllocatorType{autoid.AutoIncrementType, autoid.RowIDAllocType} { + newAlloc := autoid.NewAllocator(b.Requirement, dbInfo.ID, tblInfo.ID, tblInfo.IsAutoIncColUnsigned(), tp, tblVer, idCacheOpt) + allocs = allocs.Append(newAlloc) + } + case model.ActionRebaseAutoRandomBase: + newAlloc := autoid.NewAllocator(b.Requirement, dbInfo.ID, tblInfo.ID, tblInfo.IsAutoRandomBitColUnsigned(), autoid.AutoRandomType, tblVer) + allocs = allocs.Append(newAlloc) + case model.ActionModifyColumn: + // Change column attribute from auto_increment to auto_random. + if tblInfo.ContainsAutoRandomBits() && allocs.Get(autoid.AutoRandomType) == nil { + // Remove auto_increment allocator. + allocs = allocs.Filter(func(a autoid.Allocator) bool { + return a.GetType() != autoid.AutoIncrementType && a.GetType() != autoid.RowIDAllocType + }) + newAlloc := autoid.NewAllocator(b.Requirement, dbInfo.ID, tblInfo.ID, tblInfo.IsAutoRandomBitColUnsigned(), autoid.AutoRandomType, tblVer) + allocs = allocs.Append(newAlloc) + } + } + return allocs + } + return autoid.NewAllocatorsFromTblInfo(b.Requirement, dbInfo.ID, tblInfo) +} + +func applyCreateTable(b *Builder, m *meta.Meta, dbInfo *model.DBInfo, tableID int64, allocs autoid.Allocators, tp model.ActionType, affected []int64, schemaVersion int64) ([]int64, error) { + tblInfo, err := m.GetTable(dbInfo.ID, tableID) + if err != nil { + return nil, errors.Trace(err) + } + if tblInfo == nil { + // When we apply an old schema diff, the table may has been dropped already, so we need to fall back to + // full load. 
+ return nil, ErrTableNotExists.FastGenByArgs( + fmt.Sprintf("(Schema ID %d)", dbInfo.ID), + fmt.Sprintf("(Table ID %d)", tableID), + ) + } + + b.updateBundleForCreateTable(tblInfo, tp) + + if tp != model.ActionTruncateTablePartition { + affected = appendAffectedIDs(affected, tblInfo) + } + + // Failpoint check whether tableInfo should be added to repairInfo. + // Typically used in repair table test to load mock `bad` tableInfo into repairInfo. + failpoint.Inject("repairFetchCreateTable", func(val failpoint.Value) { + if val.(bool) { + if domainutil.RepairInfo.InRepairMode() && tp != model.ActionRepairTable && domainutil.RepairInfo.CheckAndFetchRepairedTable(dbInfo, tblInfo) { + failpoint.Return(nil, nil) + } + } + }) + + ConvertCharsetCollateToLowerCaseIfNeed(tblInfo) + ConvertOldVersionUTF8ToUTF8MB4IfNeed(tblInfo) + + allocs = b.buildAllocsForCreateTable(tp, dbInfo, tblInfo, allocs) + + tbl, err := tableFromMeta(allocs, b.factory, tblInfo) + if err != nil { + return nil, errors.Trace(err) + } + + b.infoSchema.addReferredForeignKeys(dbInfo.Name, tblInfo) + + if !b.enableV2 { + tableNames := b.infoSchema.schemaMap[dbInfo.Name.L] + tableNames.tables[tblInfo.Name.L] = tbl + } + b.addTable(schemaVersion, dbInfo, tblInfo, tbl) + + bucketIdx := tableBucketIdx(tableID) + slices.SortFunc(b.infoSchema.sortedTablesBuckets[bucketIdx], func(i, j table.Table) int { + return cmp.Compare(i.Meta().ID, j.Meta().ID) + }) + + if tblInfo.TempTableType != model.TempTableNone { + b.addTemporaryTable(tableID) + } + + newTbl, ok := b.infoSchema.TableByID(tableID) + if ok { + dbInfo.Deprecated.Tables = append(dbInfo.Deprecated.Tables, newTbl.Meta()) + } + return affected, nil +} + +// ConvertCharsetCollateToLowerCaseIfNeed convert the charset / collation of table and its columns to lower case, +// if the table's version is prior to TableInfoVersion3. +func ConvertCharsetCollateToLowerCaseIfNeed(tbInfo *model.TableInfo) { + if tbInfo.Version >= model.TableInfoVersion3 { + return + } + tbInfo.Charset = strings.ToLower(tbInfo.Charset) + tbInfo.Collate = strings.ToLower(tbInfo.Collate) + for _, col := range tbInfo.Columns { + col.SetCharset(strings.ToLower(col.GetCharset())) + col.SetCollate(strings.ToLower(col.GetCollate())) + } +} + +// ConvertOldVersionUTF8ToUTF8MB4IfNeed convert old version UTF8 to UTF8MB4 if config.TreatOldVersionUTF8AsUTF8MB4 is enable. +func ConvertOldVersionUTF8ToUTF8MB4IfNeed(tbInfo *model.TableInfo) { + if tbInfo.Version >= model.TableInfoVersion2 || !config.GetGlobalConfig().TreatOldVersionUTF8AsUTF8MB4 { + return + } + if tbInfo.Charset == charset.CharsetUTF8 { + tbInfo.Charset = charset.CharsetUTF8MB4 + tbInfo.Collate = charset.CollationUTF8MB4 + } + for _, col := range tbInfo.Columns { + if col.Version < model.ColumnInfoVersion2 && col.GetCharset() == charset.CharsetUTF8 { + col.SetCharset(charset.CharsetUTF8MB4) + col.SetCollate(charset.CollationUTF8MB4) + } + } +} + +func (b *Builder) applyDropTable(diff *model.SchemaDiff, dbInfo *model.DBInfo, tableID int64, affected []int64) []int64 { + bucketIdx := tableBucketIdx(tableID) + sortedTbls := b.infoSchema.sortedTablesBuckets[bucketIdx] + idx := sortedTbls.searchTable(tableID) + if idx == -1 { + return affected + } + if tableNames, ok := b.infoSchema.schemaMap[dbInfo.Name.L]; ok { + tblInfo := sortedTbls[idx].Meta() + delete(tableNames.tables, tblInfo.Name.L) + affected = appendAffectedIDs(affected, tblInfo) + } + // Remove the table in sorted table slice. 
+ b.infoSchema.sortedTablesBuckets[bucketIdx] = append(sortedTbls[0:idx], sortedTbls[idx+1:]...) + + // Remove the table in temporaryTables + if b.infoSchema.temporaryTableIDs != nil { + delete(b.infoSchema.temporaryTableIDs, tableID) + } + // The old DBInfo still holds a reference to old table info, we need to remove it. + b.deleteReferredForeignKeys(dbInfo, tableID) + return affected +} + +func (b *Builder) deleteReferredForeignKeys(dbInfo *model.DBInfo, tableID int64) { + tables := dbInfo.Deprecated.Tables + for i, tblInfo := range tables { + if tblInfo.ID == tableID { + if i == len(tables)-1 { + tables = tables[:i] + } else { + tables = append(tables[:i], tables[i+1:]...) + } + b.infoSchema.deleteReferredForeignKeys(dbInfo.Name, tblInfo) + break + } + } + dbInfo.Deprecated.Tables = tables +} + +// Build builds and returns the built infoschema. +func (b *Builder) Build(schemaTS uint64) InfoSchema { + if b.enableV2 { + b.infoschemaV2.ts = schemaTS + updateInfoSchemaBundles(b) + return &b.infoschemaV2 + } + updateInfoSchemaBundles(b) + return b.infoSchema +} + +// InitWithOldInfoSchema initializes an empty new InfoSchema by copies all the data from old InfoSchema. +func (b *Builder) InitWithOldInfoSchema(oldSchema InfoSchema) error { + // Do not mix infoschema v1 and infoschema v2 building, this can simplify the logic. + // If we want to build infoschema v2, but the old infoschema is v1, just return error to trigger a full load. + isV2, _ := IsV2(oldSchema) + if b.enableV2 != isV2 { + return errors.Errorf("builder's (v2=%v) infoschema mismatch, return error to trigger full reload", b.enableV2) + } + + if schemaV2, ok := oldSchema.(*infoschemaV2); ok { + b.infoschemaV2.ts = schemaV2.ts + } + oldIS := oldSchema.base() + b.initBundleInfoBuilder() + b.infoSchema.schemaMetaVersion = oldIS.schemaMetaVersion + b.infoSchema.schemaMap = maps.Clone(oldIS.schemaMap) + b.infoSchema.schemaID2Name = maps.Clone(oldIS.schemaID2Name) + b.infoSchema.ruleBundleMap = maps.Clone(oldIS.ruleBundleMap) + b.infoSchema.policyMap = oldIS.ClonePlacementPolicies() + b.infoSchema.resourceGroupMap = oldIS.CloneResourceGroups() + b.infoSchema.temporaryTableIDs = maps.Clone(oldIS.temporaryTableIDs) + b.infoSchema.referredForeignKeyMap = maps.Clone(oldIS.referredForeignKeyMap) + + copy(b.infoSchema.sortedTablesBuckets, oldIS.sortedTablesBuckets) + return nil +} + +// getSchemaAndCopyIfNecessary creates a new schemaTables instance when a table in the database has changed. +// It also does modifications on the new one because old schemaTables must be read-only. +// And it will only copy the changed database once in the lifespan of the Builder. +// NOTE: please make sure the dbName is in lowercase. +func (b *Builder) getSchemaAndCopyIfNecessary(dbName string) *model.DBInfo { + if !b.dirtyDB[dbName] { + b.dirtyDB[dbName] = true + oldSchemaTables := b.infoSchema.schemaMap[dbName] + newSchemaTables := &schemaTables{ + dbInfo: oldSchemaTables.dbInfo.Copy(), + tables: maps.Clone(oldSchemaTables.tables), + } + b.infoSchema.addSchema(newSchemaTables) + return newSchemaTables.dbInfo + } + return b.infoSchema.schemaMap[dbName].dbInfo +} + +func (b *Builder) initVirtualTables(schemaVersion int64) error { + // Initialize virtual tables. 
+ for _, driver := range drivers { + err := b.createSchemaTablesForDB(driver.DBInfo, driver.TableFromMeta, schemaVersion) + if err != nil { + return errors.Trace(err) + } + } + return nil +} + +func (b *Builder) sortAllTablesByID() { + // Sort all tables by `ID` + for _, v := range b.infoSchema.sortedTablesBuckets { + slices.SortFunc(v, func(a, b table.Table) int { + return cmp.Compare(a.Meta().ID, b.Meta().ID) + }) + } +} + +// InitWithDBInfos initializes an empty new InfoSchema with a slice of DBInfo, all placement rules, and schema version. +func (b *Builder) InitWithDBInfos(dbInfos []*model.DBInfo, policies []*model.PolicyInfo, resourceGroups []*model.ResourceGroupInfo, schemaVersion int64) error { + info := b.infoSchema + info.schemaMetaVersion = schemaVersion + + b.initBundleInfoBuilder() + + b.initMisc(dbInfos, policies, resourceGroups) + + if b.enableV2 { + // We must not clear the historial versions like b.infoData = NewData(), because losing + // the historial versions would cause applyDiff get db not exist error and fail, then + // infoschema reloading retries with full load every time. + // See https://github.com/pingcap/tidb/issues/53442 + // + // We must reset it, otherwise the stale tables remain and cause bugs later. + // For example, schema version 59: + // 107: t1 + // 112: t2 (partitions p0=113, p1=114, p2=115) + // operation: alter table t2 exchange partition p0 with table t1 + // schema version 60 if we do not reset: + // 107: t1 <- stale + // 112: t2 (partition p0=107, p1=114, p2=115) + // 113: t1 + // See https://github.com/pingcap/tidb/issues/54796 + b.infoData.resetBeforeFullLoad(schemaVersion) + } + + for _, di := range dbInfos { + err := b.createSchemaTablesForDB(di, tableFromMeta, schemaVersion) + if err != nil { + return errors.Trace(err) + } + } + + err := b.initVirtualTables(schemaVersion) + if err != nil { + return err + } + + b.sortAllTablesByID() + + return nil +} + +func tableFromMeta(alloc autoid.Allocators, factory func() (pools.Resource, error), tblInfo *model.TableInfo) (table.Table, error) { + ret, err := tables.TableFromMeta(alloc, tblInfo) + if err != nil { + return nil, errors.Trace(err) + } + if t, ok := ret.(table.CachedTable); ok { + var tmp pools.Resource + tmp, err = factory() + if err != nil { + return nil, errors.Trace(err) + } + + err = t.Init(tmp.(sessionctx.Context).GetSQLExecutor()) + if err != nil { + return nil, errors.Trace(err) + } + } + return ret, nil +} + +type tableFromMetaFunc func(alloc autoid.Allocators, factory func() (pools.Resource, error), tblInfo *model.TableInfo) (table.Table, error) + +func (b *Builder) createSchemaTablesForDB(di *model.DBInfo, tableFromMeta tableFromMetaFunc, schemaVersion int64) error { + schTbls := &schemaTables{ + dbInfo: di, + tables: make(map[string]table.Table, len(di.Deprecated.Tables)), + } + for _, t := range di.Deprecated.Tables { + allocs := autoid.NewAllocatorsFromTblInfo(b.Requirement, di.ID, t) + var tbl table.Table + tbl, err := tableFromMeta(allocs, b.factory, t) + if err != nil { + return errors.Wrap(err, fmt.Sprintf("Build table `%s`.`%s` schema failed", di.Name.O, t.Name.O)) + } + + schTbls.tables[t.Name.L] = tbl + b.addTable(schemaVersion, di, t, tbl) + if len(di.TableName2ID) > 0 { + delete(di.TableName2ID, t.Name.L) + } + + if tblInfo := tbl.Meta(); tblInfo.TempTableType != model.TempTableNone { + b.addTemporaryTable(tblInfo.ID) + } + } + // Add the rest name to ID mappings. 
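+	// For infoschema v2, the loop above deleted every fully-loaded table from
+	// di.TableName2ID, so the names left here exist only as name => ID
+	// mappings. Registering them in byID/byName is what lets TableByName and
+	// TableByID lazy-load the full TableInfo from meta KV later instead of
+	// keeping every table resident in memory.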
+ if b.enableV2 { + for name, id := range di.TableName2ID { + item := tableItem{ + dbName: di.Name.L, + dbID: di.ID, + tableName: name, + tableID: id, + schemaVersion: schemaVersion, + } + b.infoData.byID.Set(item) + b.infoData.byName.Set(item) + } + } + b.addDB(schemaVersion, di, schTbls) + + return nil +} + +func (b *Builder) addDB(schemaVersion int64, di *model.DBInfo, schTbls *schemaTables) { + if b.enableV2 { + if IsSpecialDB(di.Name.L) { + b.infoData.addSpecialDB(di, schTbls) + } else { + b.infoData.addDB(schemaVersion, di) + } + } else { + b.infoSchema.addSchema(schTbls) + } +} + +func (b *Builder) addTable(schemaVersion int64, di *model.DBInfo, tblInfo *model.TableInfo, tbl table.Table) { + if b.enableV2 { + b.infoData.add(tableItem{ + dbName: di.Name.L, + dbID: di.ID, + tableName: tblInfo.Name.L, + tableID: tblInfo.ID, + schemaVersion: schemaVersion, + }, tbl) + } else { + sortedTbls := b.infoSchema.sortedTablesBuckets[tableBucketIdx(tblInfo.ID)] + b.infoSchema.sortedTablesBuckets[tableBucketIdx(tblInfo.ID)] = append(sortedTbls, tbl) + } +} + +type virtualTableDriver struct { + *model.DBInfo + TableFromMeta tableFromMetaFunc +} + +var drivers []*virtualTableDriver + +// RegisterVirtualTable register virtual tables to the builder. +func RegisterVirtualTable(dbInfo *model.DBInfo, tableFromMeta tableFromMetaFunc) { + drivers = append(drivers, &virtualTableDriver{dbInfo, tableFromMeta}) +} + +// NewBuilder creates a new Builder with a Handle. +func NewBuilder(r autoid.Requirement, factory func() (pools.Resource, error), infoData *Data, useV2 bool) *Builder { + builder := &Builder{ + Requirement: r, + infoschemaV2: NewInfoSchemaV2(r, factory, infoData), + dirtyDB: make(map[string]bool), + factory: factory, + infoData: infoData, + enableV2: useV2, + } + schemaCacheSize := variable.SchemaCacheSize.Load() + if schemaCacheSize > 0 { + infoData.tableCache.SetCapacity(schemaCacheSize) + } + return builder +} + +// WithStore attaches the given store to builder. +func (b *Builder) WithStore(s kv.Storage) *Builder { + b.store = s + return b +} + +func tableBucketIdx(tableID int64) int { + intest.Assert(tableID > 0) + return int(tableID % bucketCount) +} + +func tableIDIsValid(tableID int64) bool { + return tableID > 0 +} diff --git a/pkg/infoschema/infoschema_v2.go b/pkg/infoschema/infoschema_v2.go index c1ffe26d2197d..87f9725634901 100644 --- a/pkg/infoschema/infoschema_v2.go +++ b/pkg/infoschema/infoschema_v2.go @@ -1012,9 +1012,9 @@ func (is *infoschemaV2) SchemaByID(id int64) (*model.DBInfo, bool) { func (is *infoschemaV2) loadTableInfo(ctx context.Context, tblID, dbID int64, ts uint64, schemaVersion int64) (table.Table, error) { defer tracing.StartRegion(ctx, "infoschema.loadTableInfo").End() - failpoint.Inject("mockLoadTableInfoError", func(_ failpoint.Value) { - failpoint.Return(nil, errors.New("mockLoadTableInfoError")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("mockLoadTableInfoError")); _err_ == nil { + return nil, errors.New("mockLoadTableInfoError") + } // Try to avoid repeated concurrency loading. res, err, _ := loadTableSF.Do(fmt.Sprintf("%d-%d-%d", dbID, tblID, schemaVersion), func() (any, error) { retry: diff --git a/pkg/infoschema/infoschema_v2.go__failpoint_stash__ b/pkg/infoschema/infoschema_v2.go__failpoint_stash__ new file mode 100644 index 0000000000000..c1ffe26d2197d --- /dev/null +++ b/pkg/infoschema/infoschema_v2.go__failpoint_stash__ @@ -0,0 +1,1456 @@ +// Copyright 2024 PingCAP, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package infoschema + +import ( + "context" + "fmt" + "math" + "strings" + "sync" + "time" + + "github.com/ngaut/pools" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/ddl/placement" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/size" + "github.com/pingcap/tidb/pkg/util/tracing" + "github.com/tidwall/btree" + "golang.org/x/sync/singleflight" +) + +// tableItem is the btree item sorted by name or by id. +type tableItem struct { + dbName string + dbID int64 + tableName string + tableID int64 + schemaVersion int64 + tomb bool +} + +type schemaItem struct { + schemaVersion int64 + dbInfo *model.DBInfo + tomb bool +} + +type schemaIDName struct { + schemaVersion int64 + id int64 + name string + tomb bool +} + +func (si *schemaItem) Name() string { + return si.dbInfo.Name.L +} + +// versionAndTimestamp is the tuple of schema version and timestamp. +type versionAndTimestamp struct { + schemaVersion int64 + timestamp uint64 +} + +// Data is the core data struct of infoschema V2. +type Data struct { + // For the TableByName API, sorted by {dbName, tableName, schemaVersion} => tableID + // + // If the schema version +1 but a specific table does not change, the old record is + // kept and no new {dbName, tableName, schemaVersion+1} => tableID record been added. + // + // It means as long as we can find an item in it, the item is available, even through the + // schema version maybe smaller than required. + byName *btree.BTreeG[tableItem] + + // For the TableByID API, sorted by {tableID, schemaVersion} => dbID + // To reload model.TableInfo, we need both table ID and database ID for meta kv API. + // It provides the tableID => databaseID mapping. + // This mapping MUST be synced with byName. + byID *btree.BTreeG[tableItem] + + // For the SchemaByName API, sorted by {dbName, schemaVersion} => model.DBInfo + // Stores the full data in memory. + schemaMap *btree.BTreeG[schemaItem] + + // For the SchemaByID API, sorted by {id, schemaVersion} + // Stores only id, name and schemaVersion in memory. + schemaID2Name *btree.BTreeG[schemaIDName] + + tableCache *Sieve[tableCacheKey, table.Table] + + // sorted by both SchemaVersion and timestamp in descending order, assume they have same order + mu struct { + sync.RWMutex + versionTimestamps []versionAndTimestamp + } + + // For information_schema/metrics_schema/performance_schema etc + specials sync.Map + + // pid2tid is used by FindTableInfoByPartitionID, it stores {partitionID, schemaVersion} => table ID + // Need full data in memory! 
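+	// Like byID/byName above, this is a multi-version index: the key embeds the
+	// schemaVersion, updates add new items, and deletions add tomb items. A
+	// reader at schema version v descends from {key, MaxInt64} and takes the
+	// first item whose version is <= v, treating a tomb as "not found". Rough
+	// sketch of that read path (the helper name is illustrative):
+	//
+	//	func latestPartition(bt *btree.BTreeG[partitionItem], pid, v int64) (partitionItem, bool) {
+	//		var res partitionItem
+	//		var found bool
+	//		bt.Descend(partitionItem{partitionID: pid, schemaVersion: math.MaxInt64},
+	//			func(item partitionItem) bool {
+	//				if item.partitionID != pid {
+	//					return false // walked past this key
+	//				}
+	//				if item.schemaVersion > v {
+	//					return true // newer than the snapshot, keep descending
+	//				}
+	//				res, found = item, !item.tomb
+	//				return false // first version <= v is the visible one
+	//			})
+	//		return res, found
+	//	}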
+ pid2tid *btree.BTreeG[partitionItem] + + // tableInfoResident stores {dbName, tableID, schemaVersion} => model.TableInfo + // It is part of the model.TableInfo data kept in memory to accelerate the list tables API. + // We observe the pattern that list table API always come with filter. + // All model.TableInfo with special attributes are here, currently the special attributes including: + // TTLInfo, TiFlashReplica + // PlacementPolicyRef, Partition might be added later, and also ForeignKeys, TableLock etc + tableInfoResident *btree.BTreeG[tableInfoItem] +} + +type tableInfoItem struct { + dbName string + tableID int64 + schemaVersion int64 + tableInfo *model.TableInfo + tomb bool +} + +type partitionItem struct { + partitionID int64 + schemaVersion int64 + tableID int64 + tomb bool +} + +func (isd *Data) getVersionByTS(ts uint64) (int64, bool) { + isd.mu.RLock() + defer isd.mu.RUnlock() + return isd.getVersionByTSNoLock(ts) +} + +func (isd *Data) getVersionByTSNoLock(ts uint64) (int64, bool) { + // search one by one instead of binary search, because the timestamp of a schema could be 0 + // this is ok because the size of h.tableCache is small (currently set to 16) + // moreover, the most likely hit element in the array is the first one in steady mode + // thus it may have better performance than binary search + for i, vt := range isd.mu.versionTimestamps { + if vt.timestamp == 0 || ts < vt.timestamp { + // is.timestamp == 0 means the schema ts is unknown, so we can't use it, then just skip it. + // ts < is.timestamp means the schema is newer than ts, so we can't use it too, just skip it to find the older one. + continue + } + // ts >= is.timestamp must be true after the above condition. + if i == 0 { + // the first element is the latest schema, so we can return it directly. + return vt.schemaVersion, true + } + if isd.mu.versionTimestamps[i-1].schemaVersion == vt.schemaVersion+1 && isd.mu.versionTimestamps[i-1].timestamp > ts { + // This first condition is to make sure the schema version is continuous. If last(cache[i-1]) schema-version is 10, + // but current(cache[i]) schema-version is not 9, then current schema is not suitable for ts. + // The second condition is to make sure the cache[i-1].timestamp > ts >= cache[i].timestamp, then the current schema is suitable for ts. + return vt.schemaVersion, true + } + // current schema is not suitable for ts, then break the loop to avoid the unnecessary search. + break + } + + return 0, false +} + +type tableCacheKey struct { + tableID int64 + schemaVersion int64 +} + +// NewData creates an infoschema V2 data struct. +func NewData() *Data { + ret := &Data{ + byID: btree.NewBTreeG[tableItem](compareByID), + byName: btree.NewBTreeG[tableItem](compareByName), + schemaMap: btree.NewBTreeG[schemaItem](compareSchemaItem), + schemaID2Name: btree.NewBTreeG[schemaIDName](compareSchemaByID), + tableCache: newSieve[tableCacheKey, table.Table](1024 * 1024 * size.MB), + pid2tid: btree.NewBTreeG[partitionItem](comparePartitionItem), + tableInfoResident: btree.NewBTreeG[tableInfoItem](compareTableInfoItem), + } + ret.tableCache.SetStatusHook(newSieveStatusHookImpl()) + return ret +} + +// CacheCapacity is exported for testing. +func (isd *Data) CacheCapacity() uint64 { + return isd.tableCache.Capacity() +} + +// SetCacheCapacity sets the cache capacity size in bytes. 
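+// Only tableCache (table.Table values keyed by {tableID, schemaVersion}) is
+// bounded by this capacity; the byID/byName/schema indexes always stay in
+// memory, so an evicted table is transparently reloaded from meta KV by
+// loadTableInfo on its next access. Presumably this is the hook behind the
+// tidb_schema_cache_size system variable (NewBuilder reads SchemaCacheSize).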
+func (isd *Data) SetCacheCapacity(capacity uint64) { + isd.tableCache.SetCapacityAndWaitEvict(capacity) +} + +func (isd *Data) add(item tableItem, tbl table.Table) { + isd.byID.Set(item) + isd.byName.Set(item) + isd.tableCache.Set(tableCacheKey{item.tableID, item.schemaVersion}, tbl) + ti := tbl.Meta() + if pi := ti.GetPartitionInfo(); pi != nil { + for _, def := range pi.Definitions { + isd.pid2tid.Set(partitionItem{def.ID, item.schemaVersion, tbl.Meta().ID, false}) + } + } + if hasSpecialAttributes(ti) { + isd.tableInfoResident.Set(tableInfoItem{ + dbName: item.dbName, + tableID: item.tableID, + schemaVersion: item.schemaVersion, + tableInfo: ti, + tomb: false}) + } +} + +func (isd *Data) addSpecialDB(di *model.DBInfo, tables *schemaTables) { + isd.specials.LoadOrStore(di.Name.L, tables) +} + +func (isd *Data) addDB(schemaVersion int64, dbInfo *model.DBInfo) { + dbInfo.Deprecated.Tables = nil + isd.schemaID2Name.Set(schemaIDName{schemaVersion: schemaVersion, id: dbInfo.ID, name: dbInfo.Name.O}) + isd.schemaMap.Set(schemaItem{schemaVersion: schemaVersion, dbInfo: dbInfo}) +} + +func (isd *Data) remove(item tableItem) { + item.tomb = true + isd.byID.Set(item) + isd.byName.Set(item) + isd.tableInfoResident.Set(tableInfoItem{ + dbName: item.dbName, + tableID: item.tableID, + schemaVersion: item.schemaVersion, + tableInfo: nil, + tomb: true}) + isd.tableCache.Remove(tableCacheKey{item.tableID, item.schemaVersion}) +} + +func (isd *Data) deleteDB(dbInfo *model.DBInfo, schemaVersion int64) { + item := schemaItem{schemaVersion: schemaVersion, dbInfo: dbInfo, tomb: true} + isd.schemaMap.Set(item) + isd.schemaID2Name.Set(schemaIDName{schemaVersion: schemaVersion, id: dbInfo.ID, name: dbInfo.Name.O, tomb: true}) +} + +// resetBeforeFullLoad is called before a full recreate operation within builder.InitWithDBInfos(). +// TODO: write a generics version to avoid repeated code. 
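+// Each reset helper below walks its btree from the maximum key downwards,
+// keeps only the newest version of every key, and re-inserts that key as a
+// tomb at the new schemaVersion, so a full reload starts from "everything
+// deleted" without throwing away older snapshots. One possible shape for the
+// generic version this TODO asks for (a sketch only, ignoring the batching the
+// real helpers do):
+//
+//	func resetBeforeFullLoadG[T any](bt *btree.BTreeG[T], sameKey func(a, b T) bool, tomb func(T) T) {
+//		pivot, ok := bt.Max()
+//		if !ok {
+//			return
+//		}
+//		items := []T{pivot}
+//		bt.Descend(pivot, func(item T) bool {
+//			if sameKey(pivot, item) {
+//				return true // older MVCC version of the same key, skip it
+//			}
+//			pivot = item
+//			items = append(items, pivot)
+//			return true
+//		})
+//		for _, item := range items {
+//			bt.Set(tomb(item)) // tomb carries the new schemaVersion
+//		}
+//	}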
+func (isd *Data) resetBeforeFullLoad(schemaVersion int64) { + resetTableInfoResidentBeforeFullLoad(isd.tableInfoResident, schemaVersion) + + resetByIDBeforeFullLoad(isd.byID, schemaVersion) + resetByNameBeforeFullLoad(isd.byName, schemaVersion) + + resetSchemaMapBeforeFullLoad(isd.schemaMap, schemaVersion) + resetSchemaID2NameBeforeFullLoad(isd.schemaID2Name, schemaVersion) + + resetPID2TIDBeforeFullLoad(isd.pid2tid, schemaVersion) +} + +func resetByIDBeforeFullLoad(bt *btree.BTreeG[tableItem], schemaVersion int64) { + pivot, ok := bt.Max() + if !ok { + return + } + + batchSize := 1000 + if bt.Len() < batchSize { + batchSize = bt.Len() + } + items := make([]tableItem, 0, batchSize) + items = append(items, pivot) + for { + bt.Descend(pivot, func(item tableItem) bool { + if pivot.tableID == item.tableID { + return true // skip MVCC version + } + pivot = item + items = append(items, pivot) + return len(items) < cap(items) + }) + if len(items) == 0 { + break + } + for _, item := range items { + bt.Set(tableItem{ + dbName: item.dbName, + dbID: item.dbID, + tableName: item.tableName, + tableID: item.tableID, + schemaVersion: schemaVersion, + tomb: true, + }) + } + items = items[:0] + } +} + +func resetByNameBeforeFullLoad(bt *btree.BTreeG[tableItem], schemaVersion int64) { + pivot, ok := bt.Max() + if !ok { + return + } + + batchSize := 1000 + if bt.Len() < batchSize { + batchSize = bt.Len() + } + items := make([]tableItem, 0, batchSize) + items = append(items, pivot) + for { + bt.Descend(pivot, func(item tableItem) bool { + if pivot.dbName == item.dbName && pivot.tableName == item.tableName { + return true // skip MVCC version + } + pivot = item + items = append(items, pivot) + return len(items) < cap(items) + }) + if len(items) == 0 { + break + } + for _, item := range items { + bt.Set(tableItem{ + dbName: item.dbName, + dbID: item.dbID, + tableName: item.tableName, + tableID: item.tableID, + schemaVersion: schemaVersion, + tomb: true, + }) + } + items = items[:0] + } +} + +func resetTableInfoResidentBeforeFullLoad(bt *btree.BTreeG[tableInfoItem], schemaVersion int64) { + pivot, ok := bt.Max() + if !ok { + return + } + items := make([]tableInfoItem, 0, bt.Len()) + items = append(items, pivot) + bt.Descend(pivot, func(item tableInfoItem) bool { + if pivot.dbName == item.dbName && pivot.tableID == item.tableID { + return true // skip MVCC version + } + pivot = item + items = append(items, pivot) + return true + }) + for _, item := range items { + bt.Set(tableInfoItem{ + dbName: item.dbName, + tableID: item.tableID, + schemaVersion: schemaVersion, + tomb: true, + }) + } +} + +func resetSchemaMapBeforeFullLoad(bt *btree.BTreeG[schemaItem], schemaVersion int64) { + pivot, ok := bt.Max() + if !ok { + return + } + items := make([]schemaItem, 0, bt.Len()) + items = append(items, pivot) + bt.Descend(pivot, func(item schemaItem) bool { + if pivot.Name() == item.Name() { + return true // skip MVCC version + } + pivot = item + items = append(items, pivot) + return true + }) + for _, item := range items { + bt.Set(schemaItem{ + dbInfo: item.dbInfo, + schemaVersion: schemaVersion, + tomb: true, + }) + } +} + +func resetSchemaID2NameBeforeFullLoad(bt *btree.BTreeG[schemaIDName], schemaVersion int64) { + pivot, ok := bt.Max() + if !ok { + return + } + items := make([]schemaIDName, 0, bt.Len()) + items = append(items, pivot) + bt.Descend(pivot, func(item schemaIDName) bool { + if pivot.id == item.id { + return true // skip MVCC version + } + pivot = item + items = append(items, pivot) + return true + }) + for 
_, item := range items { + bt.Set(schemaIDName{ + id: item.id, + name: item.name, + schemaVersion: schemaVersion, + tomb: true, + }) + } +} + +func resetPID2TIDBeforeFullLoad(bt *btree.BTreeG[partitionItem], schemaVersion int64) { + pivot, ok := bt.Max() + if !ok { + return + } + + batchSize := 1000 + if bt.Len() < batchSize { + batchSize = bt.Len() + } + items := make([]partitionItem, 0, batchSize) + items = append(items, pivot) + for { + bt.Descend(pivot, func(item partitionItem) bool { + if pivot.partitionID == item.partitionID { + return true // skip MVCC version + } + pivot = item + items = append(items, pivot) + return len(items) < cap(items) + }) + if len(items) == 0 { + break + } + for _, item := range items { + bt.Set(partitionItem{ + partitionID: item.partitionID, + tableID: item.tableID, + schemaVersion: schemaVersion, + tomb: true, + }) + } + items = items[:0] + } +} + +func compareByID(a, b tableItem) bool { + if a.tableID < b.tableID { + return true + } + if a.tableID > b.tableID { + return false + } + + return a.schemaVersion < b.schemaVersion +} + +func compareByName(a, b tableItem) bool { + if a.dbName < b.dbName { + return true + } + if a.dbName > b.dbName { + return false + } + + if a.tableName < b.tableName { + return true + } + if a.tableName > b.tableName { + return false + } + + return a.schemaVersion < b.schemaVersion +} + +func compareTableInfoItem(a, b tableInfoItem) bool { + if a.dbName < b.dbName { + return true + } + if a.dbName > b.dbName { + return false + } + + if a.tableID < b.tableID { + return true + } + if a.tableID > b.tableID { + return false + } + return a.schemaVersion < b.schemaVersion +} + +func comparePartitionItem(a, b partitionItem) bool { + if a.partitionID < b.partitionID { + return true + } + if a.partitionID > b.partitionID { + return false + } + return a.schemaVersion < b.schemaVersion +} + +func compareSchemaItem(a, b schemaItem) bool { + if a.Name() < b.Name() { + return true + } + if a.Name() > b.Name() { + return false + } + return a.schemaVersion < b.schemaVersion +} + +func compareSchemaByID(a, b schemaIDName) bool { + if a.id < b.id { + return true + } + if a.id > b.id { + return false + } + return a.schemaVersion < b.schemaVersion +} + +var _ InfoSchema = &infoschemaV2{} + +type infoschemaV2 struct { + *infoSchema // in fact, we only need the infoSchemaMisc inside it, but the builder rely on it. + r autoid.Requirement + factory func() (pools.Resource, error) + ts uint64 + *Data +} + +// NewInfoSchemaV2 create infoschemaV2. +func NewInfoSchemaV2(r autoid.Requirement, factory func() (pools.Resource, error), infoData *Data) infoschemaV2 { + return infoschemaV2{ + infoSchema: newInfoSchema(), + Data: infoData, + r: r, + factory: factory, + } +} + +func search(bt *btree.BTreeG[tableItem], schemaVersion int64, end tableItem, matchFn func(a, b *tableItem) bool) (tableItem, bool) { + var ok bool + var target tableItem + // Iterate through the btree, find the query item whose schema version is the largest one (latest). + bt.Descend(end, func(item tableItem) bool { + if !matchFn(&end, &item) { + return false + } + if item.schemaVersion > schemaVersion { + // We're seaching historical snapshot, and this record is newer than us, we can't use it. + // Skip the record. + return true + } + // schema version of the items should <= query's schema version. + if !ok { // The first one found. 
+ ok = true + target = item + } else { // The latest one + if item.schemaVersion > target.schemaVersion { + target = item + } + } + return true + }) + if ok && target.tomb { + // If the item is a tomb record, the table is dropped. + ok = false + } + return target, ok +} + +func (is *infoschemaV2) base() *infoSchema { + return is.infoSchema +} + +func (is *infoschemaV2) CloneAndUpdateTS(startTS uint64) *infoschemaV2 { + tmp := *is + tmp.ts = startTS + return &tmp +} + +func (is *infoschemaV2) TableByID(id int64) (val table.Table, ok bool) { + return is.tableByID(id, true) +} + +func (is *infoschemaV2) tableByID(id int64, noRefill bool) (val table.Table, ok bool) { + if !tableIDIsValid(id) { + return + } + + // Get from the cache. + key := tableCacheKey{id, is.infoSchema.schemaMetaVersion} + tbl, found := is.tableCache.Get(key) + if found && tbl != nil { + return tbl, true + } + + eq := func(a, b *tableItem) bool { return a.tableID == b.tableID } + itm, ok := search(is.byID, is.infoSchema.schemaMetaVersion, tableItem{tableID: id, schemaVersion: math.MaxInt64}, eq) + if !ok { + return nil, false + } + + if isTableVirtual(id) { + if raw, exist := is.Data.specials.Load(itm.dbName); exist { + schTbls := raw.(*schemaTables) + val, ok = schTbls.tables[itm.tableName] + return + } + return nil, false + } + // get cache with old key + oldKey := tableCacheKey{itm.tableID, itm.schemaVersion} + tbl, found = is.tableCache.Get(oldKey) + if found && tbl != nil { + if !noRefill { + is.tableCache.Set(key, tbl) + } + return tbl, true + } + + // Maybe the table is evicted? need to reload. + ret, err := is.loadTableInfo(context.Background(), id, itm.dbID, is.ts, is.infoSchema.schemaMetaVersion) + if err != nil || ret == nil { + return nil, false + } + + if !noRefill { + is.tableCache.Set(oldKey, ret) + } + return ret, true +} + +// IsSpecialDB tells whether the database is a special database. +func IsSpecialDB(dbName string) bool { + return dbName == util.InformationSchemaName.L || + dbName == util.PerformanceSchemaName.L || + dbName == util.MetricSchemaName.L +} + +// EvictTable is exported for testing only. +func (is *infoschemaV2) EvictTable(schema, tbl string) { + eq := func(a, b *tableItem) bool { return a.dbName == b.dbName && a.tableName == b.tableName } + itm, ok := search(is.byName, is.infoSchema.schemaMetaVersion, tableItem{dbName: schema, tableName: tbl, schemaVersion: math.MaxInt64}, eq) + if !ok { + return + } + is.tableCache.Remove(tableCacheKey{itm.tableID, is.infoSchema.schemaMetaVersion}) + is.tableCache.Remove(tableCacheKey{itm.tableID, itm.schemaVersion}) +} + +type tableByNameHelper struct { + end tableItem + schemaVersion int64 + found bool + res tableItem +} + +func (h *tableByNameHelper) onItem(item tableItem) bool { + if item.dbName != h.end.dbName || item.tableName != h.end.tableName { + h.found = false + return false + } + if item.schemaVersion <= h.schemaVersion { + if !item.tomb { // If the item is a tomb record, the database is dropped. 
+ h.found = true + h.res = item + } + return false + } + return true +} + +func (is *infoschemaV2) TableByName(ctx context.Context, schema, tbl model.CIStr) (t table.Table, err error) { + if IsSpecialDB(schema.L) { + if raw, ok := is.specials.Load(schema.L); ok { + tbNames := raw.(*schemaTables) + if t, ok = tbNames.tables[tbl.L]; ok { + return + } + } + return nil, ErrTableNotExists.FastGenByArgs(schema, tbl) + } + + start := time.Now() + + var h tableByNameHelper + h.end = tableItem{dbName: schema.L, tableName: tbl.L, schemaVersion: math.MaxInt64} + h.schemaVersion = is.infoSchema.schemaMetaVersion + is.byName.Descend(h.end, h.onItem) + + if !h.found { + return nil, ErrTableNotExists.FastGenByArgs(schema, tbl) + } + itm := h.res + + // Get from the cache with old key + oldKey := tableCacheKey{itm.tableID, itm.schemaVersion} + res, found := is.tableCache.Get(oldKey) + if found && res != nil { + metrics.TableByNameHitDuration.Observe(float64(time.Since(start))) + return res, nil + } + + // Maybe the table is evicted? need to reload. + ret, err := is.loadTableInfo(ctx, itm.tableID, itm.dbID, is.ts, is.infoSchema.schemaMetaVersion) + if err != nil { + return nil, errors.Trace(err) + } + is.tableCache.Set(oldKey, ret) + metrics.TableByNameMissDuration.Observe(float64(time.Since(start))) + return ret, nil +} + +// TableInfoByName implements InfoSchema.TableInfoByName +func (is *infoschemaV2) TableInfoByName(schema, table model.CIStr) (*model.TableInfo, error) { + tbl, err := is.TableByName(context.Background(), schema, table) + return getTableInfo(tbl), err +} + +// TableInfoByID implements InfoSchema.TableInfoByID +func (is *infoschemaV2) TableInfoByID(id int64) (*model.TableInfo, bool) { + tbl, ok := is.TableByID(id) + return getTableInfo(tbl), ok +} + +// SchemaTableInfos implements MetaOnlyInfoSchema. +func (is *infoschemaV2) SchemaTableInfos(ctx context.Context, schema model.CIStr) ([]*model.TableInfo, error) { + if IsSpecialDB(schema.L) { + raw, ok := is.Data.specials.Load(schema.L) + if ok { + schTbls := raw.(*schemaTables) + tables := make([]table.Table, 0, len(schTbls.tables)) + for _, tbl := range schTbls.tables { + tables = append(tables, tbl) + } + return getTableInfoList(tables), nil + } + return nil, nil // something wrong? + } + +retry: + dbInfo, ok := is.SchemaByName(schema) + if !ok { + return nil, nil + } + snapshot := is.r.Store().GetSnapshot(kv.NewVersion(is.ts)) + // Using the KV timeout read feature to address the issue of potential DDL lease expiration when + // the meta region leader is slow. + snapshot.SetOption(kv.TiKVClientReadTimeout, uint64(3000)) // 3000ms. + m := meta.NewSnapshotMeta(snapshot) + tblInfos, err := m.ListTables(dbInfo.ID) + if err != nil { + if meta.ErrDBNotExists.Equal(err) { + return nil, nil + } + // Flashback statement could cause such kind of error. + // In theory that error should be handled in the lower layer, like client-go. + // But it's not done, so we retry here. + if strings.Contains(err.Error(), "in flashback progress") { + select { + case <-time.After(200 * time.Millisecond): + case <-ctx.Done(): + return nil, ctx.Err() + } + goto retry + } + return nil, errors.Trace(err) + } + return tblInfos, nil +} + +// SchemaSimpleTableInfos implements MetaOnlyInfoSchema. 
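+// Unlike SchemaTableInfos above, this only needs table names and IDs, so it is
+// answered entirely from the in-memory byName index (deduplicating MVCC
+// versions and skipping tombs) and never has to visit meta KV or materialize
+// full TableInfo structs.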
+func (is *infoschemaV2) SchemaSimpleTableInfos(ctx context.Context, schema model.CIStr) ([]*model.TableNameInfo, error) { + if IsSpecialDB(schema.L) { + raw, ok := is.Data.specials.Load(schema.L) + if ok { + schTbls := raw.(*schemaTables) + ret := make([]*model.TableNameInfo, 0, len(schTbls.tables)) + for _, tbl := range schTbls.tables { + ret = append(ret, &model.TableNameInfo{ + ID: tbl.Meta().ID, + Name: tbl.Meta().Name, + }) + } + return ret, nil + } + return nil, nil // something wrong? + } + + // Ascend is much more difficult than Descend. + // So the data is taken out first and then dedup in Descend order. + var tableItems []tableItem + is.byName.Ascend(tableItem{dbName: schema.L}, func(item tableItem) bool { + if item.dbName != schema.L { + return false + } + if is.infoSchema.schemaMetaVersion >= item.schemaVersion { + tableItems = append(tableItems, item) + } + return true + }) + if len(tableItems) == 0 { + return nil, nil + } + tblInfos := make([]*model.TableNameInfo, 0, len(tableItems)) + var curr *tableItem + for i := len(tableItems) - 1; i >= 0; i-- { + item := &tableItems[i] + if curr == nil || curr.tableName != tableItems[i].tableName { + curr = item + if !item.tomb { + tblInfos = append(tblInfos, &model.TableNameInfo{ + ID: item.tableID, + Name: model.NewCIStr(item.tableName), + }) + } + } + } + return tblInfos, nil +} + +// FindTableInfoByPartitionID implements InfoSchema.FindTableInfoByPartitionID +func (is *infoschemaV2) FindTableInfoByPartitionID( + partitionID int64, +) (*model.TableInfo, *model.DBInfo, *model.PartitionDefinition) { + tbl, db, partDef := is.FindTableByPartitionID(partitionID) + return getTableInfo(tbl), db, partDef +} + +func (is *infoschemaV2) SchemaByName(schema model.CIStr) (val *model.DBInfo, ok bool) { + if IsSpecialDB(schema.L) { + raw, ok := is.Data.specials.Load(schema.L) + if !ok { + return nil, false + } + schTbls, ok := raw.(*schemaTables) + return schTbls.dbInfo, ok + } + + var dbInfo model.DBInfo + dbInfo.Name = schema + is.Data.schemaMap.Descend(schemaItem{ + dbInfo: &dbInfo, + schemaVersion: math.MaxInt64, + }, func(item schemaItem) bool { + if item.Name() != schema.L { + ok = false + return false + } + if item.schemaVersion <= is.infoSchema.schemaMetaVersion { + if !item.tomb { // If the item is a tomb record, the database is dropped. + ok = true + val = item.dbInfo + } + return false + } + return true + }) + return +} + +func (is *infoschemaV2) allSchemas(visit func(*model.DBInfo)) { + var last *model.DBInfo + is.Data.schemaMap.Reverse(func(item schemaItem) bool { + if item.schemaVersion > is.infoSchema.schemaMetaVersion { + // Skip the versions that we are not looking for. + return true + } + + // Dedup the same db record of different versions. 
+ if last != nil && last.Name == item.dbInfo.Name { + return true + } + last = item.dbInfo + + if !item.tomb { + visit(item.dbInfo) + } + return true + }) + is.Data.specials.Range(func(key, value any) bool { + sc := value.(*schemaTables) + visit(sc.dbInfo) + return true + }) +} + +func (is *infoschemaV2) AllSchemas() (schemas []*model.DBInfo) { + is.allSchemas(func(di *model.DBInfo) { + schemas = append(schemas, di) + }) + return +} + +func (is *infoschemaV2) AllSchemaNames() []model.CIStr { + rs := make([]model.CIStr, 0, is.Data.schemaMap.Len()) + is.allSchemas(func(di *model.DBInfo) { + rs = append(rs, di.Name) + }) + return rs +} + +func (is *infoschemaV2) SchemaExists(schema model.CIStr) bool { + _, ok := is.SchemaByName(schema) + return ok +} + +func (is *infoschemaV2) FindTableByPartitionID(partitionID int64) (table.Table, *model.DBInfo, *model.PartitionDefinition) { + var ok bool + var pi partitionItem + is.pid2tid.Descend(partitionItem{partitionID: partitionID, schemaVersion: math.MaxInt64}, + func(item partitionItem) bool { + if item.partitionID != partitionID { + return false + } + if item.schemaVersion > is.infoSchema.schemaMetaVersion { + // Skip the record. + return true + } + if item.schemaVersion <= is.infoSchema.schemaMetaVersion { + ok = !item.tomb + pi = item + return false + } + return true + }) + if !ok { + return nil, nil, nil + } + + tbl, ok := is.TableByID(pi.tableID) + if !ok { + // something wrong? + return nil, nil, nil + } + + dbID := tbl.Meta().DBID + dbInfo, ok := is.SchemaByID(dbID) + if !ok { + // something wrong? + return nil, nil, nil + } + + partInfo := tbl.Meta().GetPartitionInfo() + var def *model.PartitionDefinition + for i := 0; i < len(partInfo.Definitions); i++ { + pdef := &partInfo.Definitions[i] + if pdef.ID == partitionID { + def = pdef + break + } + } + + return tbl, dbInfo, def +} + +func (is *infoschemaV2) TableExists(schema, table model.CIStr) bool { + _, err := is.TableByName(context.Background(), schema, table) + return err == nil +} + +func (is *infoschemaV2) SchemaByID(id int64) (*model.DBInfo, bool) { + if isTableVirtual(id) { + var st *schemaTables + is.Data.specials.Range(func(key, value any) bool { + tmp := value.(*schemaTables) + if tmp.dbInfo.ID == id { + st = tmp + return false + } + return true + }) + if st == nil { + return nil, false + } + return st.dbInfo, true + } + var ok bool + var name string + is.Data.schemaID2Name.Descend(schemaIDName{ + id: id, + schemaVersion: math.MaxInt64, + }, func(item schemaIDName) bool { + if item.id != id { + ok = false + return false + } + if item.schemaVersion <= is.infoSchema.schemaMetaVersion { + if !item.tomb { // If the item is a tomb record, the database is dropped. + ok = true + name = item.name + } + return false + } + return true + }) + if !ok { + return nil, false + } + return is.SchemaByName(model.NewCIStr(name)) +} + +func (is *infoschemaV2) loadTableInfo(ctx context.Context, tblID, dbID int64, ts uint64, schemaVersion int64) (table.Table, error) { + defer tracing.StartRegion(ctx, "infoschema.loadTableInfo").End() + failpoint.Inject("mockLoadTableInfoError", func(_ failpoint.Value) { + failpoint.Return(nil, errors.New("mockLoadTableInfoError")) + }) + // Try to avoid repeated concurrency loading. 
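+	// loadTableSF is a golang.org/x/sync/singleflight.Group: concurrent callers
+	// asking for the same {dbID, tableID, schemaVersion} share a single snapshot
+	// read instead of each hitting meta KV. Minimal illustration of the pattern
+	// (loadFromKV is a stand-in, not a real function here):
+	//
+	//	var g singleflight.Group
+	//	v, err, shared := g.Do("42-7-100", func() (any, error) {
+	//		return loadFromKV() // runs once; duplicate callers block and share v/err
+	//	})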
+ res, err, _ := loadTableSF.Do(fmt.Sprintf("%d-%d-%d", dbID, tblID, schemaVersion), func() (any, error) { + retry: + snapshot := is.r.Store().GetSnapshot(kv.NewVersion(ts)) + // Using the KV timeout read feature to address the issue of potential DDL lease expiration when + // the meta region leader is slow. + snapshot.SetOption(kv.TiKVClientReadTimeout, uint64(3000)) // 3000ms. + m := meta.NewSnapshotMeta(snapshot) + + tblInfo, err := m.GetTable(dbID, tblID) + if err != nil { + // Flashback statement could cause such kind of error. + // In theory that error should be handled in the lower layer, like client-go. + // But it's not done, so we retry here. + if strings.Contains(err.Error(), "in flashback progress") { + time.Sleep(200 * time.Millisecond) + goto retry + } + + // TODO load table panic!!! + panic(err) + } + + // table removed. + if tblInfo == nil { + return nil, errors.Trace(ErrTableNotExists.FastGenByArgs( + fmt.Sprintf("(Schema ID %d)", dbID), + fmt.Sprintf("(Table ID %d)", tblID), + )) + } + + ConvertCharsetCollateToLowerCaseIfNeed(tblInfo) + ConvertOldVersionUTF8ToUTF8MB4IfNeed(tblInfo) + allocs := autoid.NewAllocatorsFromTblInfo(is.r, dbID, tblInfo) + ret, err := tableFromMeta(allocs, is.factory, tblInfo) + if err != nil { + return nil, errors.Trace(err) + } + return ret, err + }) + + if err != nil { + return nil, errors.Trace(err) + } + if res == nil { + return nil, errors.Trace(ErrTableNotExists.FastGenByArgs( + fmt.Sprintf("(Schema ID %d)", dbID), + fmt.Sprintf("(Table ID %d)", tblID), + )) + } + return res.(table.Table), nil +} + +var loadTableSF = &singleflight.Group{} + +func isTableVirtual(id int64) bool { + // some kind of magic number... + // we use special ids for tables in INFORMATION_SCHEMA/PERFORMANCE_SCHEMA/METRICS_SCHEMA + // See meta/autoid/autoid.go for those definitions. + return (id & autoid.SystemSchemaIDFlag) > 0 +} + +// IsV2 tells whether an InfoSchema is v2 or not. 
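+// The returned *infoschemaV2 is non-nil only when the first return value is true.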
+func IsV2(is InfoSchema) (bool, *infoschemaV2) { + ret, ok := is.(*infoschemaV2) + return ok, ret +} + +func applyTableUpdate(b *Builder, m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { + if b.enableV2 { + return b.applyTableUpdateV2(m, diff) + } + return b.applyTableUpdate(m, diff) +} + +func applyCreateSchema(b *Builder, m *meta.Meta, diff *model.SchemaDiff) error { + return b.applyCreateSchema(m, diff) +} + +func applyDropSchema(b *Builder, diff *model.SchemaDiff) []int64 { + if b.enableV2 { + return b.applyDropSchemaV2(diff) + } + return b.applyDropSchema(diff) +} + +func applyRecoverSchema(b *Builder, m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { + if diff.ReadTableFromMeta { + // recover tables under the database and set them to diff.AffectedOpts + s := b.store.GetSnapshot(kv.MaxVersion) + recoverMeta := meta.NewSnapshotMeta(s) + tables, err := recoverMeta.ListSimpleTables(diff.SchemaID) + if err != nil { + return nil, err + } + diff.AffectedOpts = make([]*model.AffectedOption, 0, len(tables)) + for _, t := range tables { + diff.AffectedOpts = append(diff.AffectedOpts, &model.AffectedOption{ + SchemaID: diff.SchemaID, + OldSchemaID: diff.SchemaID, + TableID: t.ID, + OldTableID: t.ID, + }) + } + } + + if b.enableV2 { + return b.applyRecoverSchemaV2(m, diff) + } + return b.applyRecoverSchema(m, diff) +} + +func (b *Builder) applyRecoverSchemaV2(m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { + if di, ok := b.infoschemaV2.SchemaByID(diff.SchemaID); ok { + return nil, ErrDatabaseExists.GenWithStackByArgs( + fmt.Sprintf("(Schema ID %d)", di.ID), + ) + } + di, err := m.GetDatabase(diff.SchemaID) + if err != nil { + return nil, errors.Trace(err) + } + b.infoschemaV2.addDB(diff.Version, di) + return applyCreateTables(b, m, diff) +} + +func applyModifySchemaCharsetAndCollate(b *Builder, m *meta.Meta, diff *model.SchemaDiff) error { + if b.enableV2 { + return b.applyModifySchemaCharsetAndCollateV2(m, diff) + } + return b.applyModifySchemaCharsetAndCollate(m, diff) +} + +func applyModifySchemaDefaultPlacement(b *Builder, m *meta.Meta, diff *model.SchemaDiff) error { + if b.enableV2 { + return b.applyModifySchemaDefaultPlacementV2(m, diff) + } + return b.applyModifySchemaDefaultPlacement(m, diff) +} + +func applyDropTable(b *Builder, diff *model.SchemaDiff, dbInfo *model.DBInfo, tableID int64, affected []int64) []int64 { + if b.enableV2 { + return b.applyDropTableV2(diff, dbInfo, tableID, affected) + } + return b.applyDropTable(diff, dbInfo, tableID, affected) +} + +func applyCreateTables(b *Builder, m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { + return b.applyCreateTables(m, diff) +} + +func updateInfoSchemaBundles(b *Builder) { + if b.enableV2 { + b.updateInfoSchemaBundlesV2(&b.infoschemaV2) + } else { + b.updateInfoSchemaBundles(b.infoSchema) + } +} + +func oldSchemaInfo(b *Builder, diff *model.SchemaDiff) (*model.DBInfo, bool) { + if b.enableV2 { + return b.infoschemaV2.SchemaByID(diff.OldSchemaID) + } + + oldRoDBInfo, ok := b.infoSchema.SchemaByID(diff.OldSchemaID) + if ok { + oldRoDBInfo = b.getSchemaAndCopyIfNecessary(oldRoDBInfo.Name.L) + } + return oldRoDBInfo, ok +} + +// allocByID returns the Allocators of a table. +func allocByID(b *Builder, id int64) (autoid.Allocators, bool) { + var is InfoSchema + if b.enableV2 { + is = &b.infoschemaV2 + } else { + is = b.infoSchema + } + tbl, ok := is.TableByID(id) + if !ok { + return autoid.Allocators{}, false + } + return tbl.Allocators(nil), true +} + +// TODO: more UT to check the correctness. 
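+//
+// applyTableUpdateV2 is the infoschema v2 counterpart of applyTableUpdate: it
+// drops the old version of the table (if any) and, for every diff type except
+// DropTable/DropView, re-creates the table at the new schema version.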
+func (b *Builder) applyTableUpdateV2(m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { + oldDBInfo, ok := b.infoschemaV2.SchemaByID(diff.SchemaID) + if !ok { + return nil, ErrDatabaseNotExists.GenWithStackByArgs( + fmt.Sprintf("(Schema ID %d)", diff.SchemaID), + ) + } + + oldTableID, newTableID := b.getTableIDs(diff) + b.updateBundleForTableUpdate(diff, newTableID, oldTableID) + + tblIDs, allocs, err := dropTableForUpdate(b, newTableID, oldTableID, oldDBInfo, diff) + if err != nil { + return nil, err + } + + if tableIDIsValid(newTableID) { + // All types except DropTableOrView. + var err error + tblIDs, err = applyCreateTable(b, m, oldDBInfo, newTableID, allocs, diff.Type, tblIDs, diff.Version) + if err != nil { + return nil, errors.Trace(err) + } + } + return tblIDs, nil +} + +func (b *Builder) applyDropSchemaV2(diff *model.SchemaDiff) []int64 { + di, ok := b.infoschemaV2.SchemaByID(diff.SchemaID) + if !ok { + return nil + } + + tableIDs := make([]int64, 0, len(di.Deprecated.Tables)) + tables, err := b.infoschemaV2.SchemaTableInfos(context.Background(), di.Name) + terror.Log(err) + for _, tbl := range tables { + tableIDs = appendAffectedIDs(tableIDs, tbl) + } + + for _, id := range tableIDs { + b.deleteBundle(b.infoSchema, id) + b.applyDropTableV2(diff, di, id, nil) + } + b.infoData.deleteDB(di, diff.Version) + return tableIDs +} + +func (b *Builder) applyDropTableV2(diff *model.SchemaDiff, dbInfo *model.DBInfo, tableID int64, affected []int64) []int64 { + // Remove the table in temporaryTables + if b.infoSchemaMisc.temporaryTableIDs != nil { + delete(b.infoSchemaMisc.temporaryTableIDs, tableID) + } + + table, ok := b.infoschemaV2.TableByID(tableID) + if !ok { + return nil + } + + // The old DBInfo still holds a reference to old table info, we need to remove it. + b.infoSchema.deleteReferredForeignKeys(dbInfo.Name, table.Meta()) + + if pi := table.Meta().GetPartitionInfo(); pi != nil { + for _, def := range pi.Definitions { + b.infoData.pid2tid.Set(partitionItem{def.ID, diff.Version, table.Meta().ID, true}) + } + } + + b.infoData.remove(tableItem{ + dbName: dbInfo.Name.L, + dbID: dbInfo.ID, + tableName: table.Meta().Name.L, + tableID: table.Meta().ID, + schemaVersion: diff.Version, + }) + + return affected +} + +func (b *Builder) applyModifySchemaCharsetAndCollateV2(m *meta.Meta, diff *model.SchemaDiff) error { + di, err := m.GetDatabase(diff.SchemaID) + if err != nil { + return errors.Trace(err) + } + if di == nil { + // This should never happen. + return ErrDatabaseNotExists.GenWithStackByArgs( + fmt.Sprintf("(Schema ID %d)", diff.SchemaID), + ) + } + newDBInfo, _ := b.infoschemaV2.SchemaByID(diff.SchemaID) + newDBInfo.Charset = di.Charset + newDBInfo.Collate = di.Collate + b.infoschemaV2.deleteDB(di, diff.Version) + b.infoschemaV2.addDB(diff.Version, newDBInfo) + return nil +} + +func (b *Builder) applyModifySchemaDefaultPlacementV2(m *meta.Meta, diff *model.SchemaDiff) error { + di, err := m.GetDatabase(diff.SchemaID) + if err != nil { + return errors.Trace(err) + } + if di == nil { + // This should never happen. 
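+ // The schema referenced by the diff is missing from meta, so report it explicitly.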
+ return ErrDatabaseNotExists.GenWithStackByArgs( + fmt.Sprintf("(Schema ID %d)", diff.SchemaID), + ) + } + newDBInfo, _ := b.infoschemaV2.SchemaByID(diff.SchemaID) + newDBInfo.PlacementPolicyRef = di.PlacementPolicyRef + b.infoschemaV2.deleteDB(di, diff.Version) + b.infoschemaV2.addDB(diff.Version, newDBInfo) + return nil +} + +func (b *bundleInfoBuilder) updateInfoSchemaBundlesV2(is *infoschemaV2) { + if b.deltaUpdate { + b.completeUpdateTablesV2(is) + for tblID := range b.updateTables { + b.updateTableBundles(is, tblID) + } + return + } + + // do full update bundles + is.ruleBundleMap = make(map[int64]*placement.Bundle) + tmp := is.ListTablesWithSpecialAttribute(PlacementPolicyAttribute) + for _, v := range tmp { + for _, tbl := range v.TableInfos { + b.updateTableBundles(is, tbl.ID) + } + } +} + +func (b *bundleInfoBuilder) completeUpdateTablesV2(is *infoschemaV2) { + if len(b.updatePolicies) == 0 && len(b.updatePartitions) == 0 { + return + } + + dbs := is.ListTablesWithSpecialAttribute(AllSpecialAttribute) + for _, db := range dbs { + for _, tbl := range db.TableInfos { + tblInfo := tbl + if tblInfo.PlacementPolicyRef != nil { + if _, ok := b.updatePolicies[tblInfo.PlacementPolicyRef.ID]; ok { + b.markTableBundleShouldUpdate(tblInfo.ID) + } + } + + if tblInfo.Partition != nil { + for _, par := range tblInfo.Partition.Definitions { + if _, ok := b.updatePartitions[par.ID]; ok { + b.markTableBundleShouldUpdate(tblInfo.ID) + } + } + } + } + } +} + +type specialAttributeFilter func(*model.TableInfo) bool + +// TTLAttribute is the TTL attribute filter used by ListTablesWithSpecialAttribute. +var TTLAttribute specialAttributeFilter = func(t *model.TableInfo) bool { + return t.State == model.StatePublic && t.TTLInfo != nil +} + +// TiFlashAttribute is the TiFlashReplica attribute filter used by ListTablesWithSpecialAttribute. +var TiFlashAttribute specialAttributeFilter = func(t *model.TableInfo) bool { + return t.TiFlashReplica != nil +} + +// PlacementPolicyAttribute is the Placement Policy attribute filter used by ListTablesWithSpecialAttribute. +var PlacementPolicyAttribute specialAttributeFilter = func(t *model.TableInfo) bool { + if t.PlacementPolicyRef != nil { + return true + } + if parInfo := t.GetPartitionInfo(); parInfo != nil { + for _, def := range parInfo.Definitions { + if def.PlacementPolicyRef != nil { + return true + } + } + } + return false +} + +// TableLockAttribute is the Table Lock attribute filter used by ListTablesWithSpecialAttribute. +var TableLockAttribute specialAttributeFilter = func(t *model.TableInfo) bool { + return t.Lock != nil +} + +// ForeignKeysAttribute is the ForeignKeys attribute filter used by ListTablesWithSpecialAttribute. +var ForeignKeysAttribute specialAttributeFilter = func(t *model.TableInfo) bool { + return len(t.ForeignKeys) > 0 +} + +// PartitionAttribute is the Partition attribute filter used by ListTablesWithSpecialAttribute. +var PartitionAttribute specialAttributeFilter = func(t *model.TableInfo) bool { + return t.GetPartitionInfo() != nil +} + +func hasSpecialAttributes(t *model.TableInfo) bool { + return TTLAttribute(t) || TiFlashAttribute(t) || PlacementPolicyAttribute(t) || PartitionAttribute(t) || TableLockAttribute(t) || ForeignKeysAttribute(t) +} + +// AllSpecialAttribute marks a model.TableInfo with any special attributes. 
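+// It matches tables that have TTL info, a TiFlash replica, a placement policy,
+// partitions, a table lock, or foreign keys. A minimal usage sketch (assuming
+// is is an *infoschemaV2):
+//
+//	for _, db := range is.ListTablesWithSpecialAttribute(AllSpecialAttribute) {
+//		for _, tblInfo := range db.TableInfos {
+//			_ = tblInfo // inspect db.DBName and tblInfo here
+//		}
+//	}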
+var AllSpecialAttribute specialAttributeFilter = hasSpecialAttributes + +func (is *infoschemaV2) ListTablesWithSpecialAttribute(filter specialAttributeFilter) []tableInfoResult { + ret := make([]tableInfoResult, 0, 10) + var currDB string + var lastTableID int64 + var res tableInfoResult + is.Data.tableInfoResident.Reverse(func(item tableInfoItem) bool { + if item.schemaVersion > is.infoSchema.schemaMetaVersion { + // Skip the versions that we are not looking for. + return true + } + // Dedup the same record of different versions. + if lastTableID != 0 && lastTableID == item.tableID { + return true + } + lastTableID = item.tableID + + if item.tomb { + return true + } + + if !filter(item.tableInfo) { + return true + } + + if currDB == "" { + currDB = item.dbName + res = tableInfoResult{DBName: item.dbName} + res.TableInfos = append(res.TableInfos, item.tableInfo) + } else if currDB == item.dbName { + res.TableInfos = append(res.TableInfos, item.tableInfo) + } else { + ret = append(ret, res) + res = tableInfoResult{DBName: item.dbName} + res.TableInfos = append(res.TableInfos, item.tableInfo) + } + return true + }) + if len(res.TableInfos) > 0 { + ret = append(ret, res) + } + return ret +} diff --git a/pkg/infoschema/perfschema/binding__failpoint_binding__.go b/pkg/infoschema/perfschema/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..fe39572d084c5 --- /dev/null +++ b/pkg/infoschema/perfschema/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package perfschema + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/infoschema/perfschema/tables.go b/pkg/infoschema/perfschema/tables.go index d242bd7560a71..d4d85b0a26f81 100644 --- a/pkg/infoschema/perfschema/tables.go +++ b/pkg/infoschema/perfschema/tables.go @@ -309,7 +309,7 @@ func dataForRemoteProfile(ctx sessionctx.Context, nodeType, uri string, isGorout default: return nil, errors.Errorf("%s does not support profile remote component", nodeType) } - failpoint.Inject("mockRemoteNodeStatusAddress", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockRemoteNodeStatusAddress")); _err_ == nil { // The cluster topology is injected by `failpoint` expression and // there is no extra checks for it. (let the test fail if the expression invalid) if s := val.(string); len(s) > 0 { @@ -328,7 +328,7 @@ func dataForRemoteProfile(ctx sessionctx.Context, nodeType, uri string, isGorout // erase error err = nil } - }) + } if err != nil { return nil, errors.Trace(err) } diff --git a/pkg/infoschema/perfschema/tables.go__failpoint_stash__ b/pkg/infoschema/perfschema/tables.go__failpoint_stash__ new file mode 100644 index 0000000000000..d242bd7560a71 --- /dev/null +++ b/pkg/infoschema/perfschema/tables.go__failpoint_stash__ @@ -0,0 +1,415 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package perfschema + +import ( + "cmp" + "context" + "fmt" + "net/http" + "slices" + "strings" + "sync" + "time" + + "github.com/ngaut/pools" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/profile" + pd "github.com/tikv/pd/client/http" +) + +const ( + tableNameGlobalStatus = "global_status" + tableNameSessionStatus = "session_status" + tableNameSetupActors = "setup_actors" + tableNameSetupObjects = "setup_objects" + tableNameSetupInstruments = "setup_instruments" + tableNameSetupConsumers = "setup_consumers" + tableNameEventsStatementsCurrent = "events_statements_current" + tableNameEventsStatementsHistory = "events_statements_history" + tableNameEventsStatementsHistoryLong = "events_statements_history_long" + tableNamePreparedStatementsInstances = "prepared_statements_instances" + tableNameEventsTransactionsCurrent = "events_transactions_current" + tableNameEventsTransactionsHistory = "events_transactions_history" + tableNameEventsTransactionsHistoryLong = "events_transactions_history_long" + tableNameEventsStagesCurrent = "events_stages_current" + tableNameEventsStagesHistory = "events_stages_history" + tableNameEventsStagesHistoryLong = "events_stages_history_long" + tableNameEventsStatementsSummaryByDigest = "events_statements_summary_by_digest" + tableNameTiDBProfileCPU = "tidb_profile_cpu" + tableNameTiDBProfileMemory = "tidb_profile_memory" + tableNameTiDBProfileMutex = "tidb_profile_mutex" + tableNameTiDBProfileAllocs = "tidb_profile_allocs" + tableNameTiDBProfileBlock = "tidb_profile_block" + tableNameTiDBProfileGoroutines = "tidb_profile_goroutines" + tableNameTiKVProfileCPU = "tikv_profile_cpu" + tableNamePDProfileCPU = "pd_profile_cpu" + tableNamePDProfileMemory = "pd_profile_memory" + tableNamePDProfileMutex = "pd_profile_mutex" + tableNamePDProfileAllocs = "pd_profile_allocs" + tableNamePDProfileBlock = "pd_profile_block" + tableNamePDProfileGoroutines = "pd_profile_goroutines" + tableNameSessionAccountConnectAttrs = "session_account_connect_attrs" + tableNameSessionConnectAttrs = "session_connect_attrs" + tableNameSessionVariables = "session_variables" +) + +var tableIDMap = map[string]int64{ + tableNameGlobalStatus: autoid.PerformanceSchemaDBID + 1, + tableNameSessionStatus: autoid.PerformanceSchemaDBID + 2, + tableNameSetupActors: autoid.PerformanceSchemaDBID + 3, + tableNameSetupObjects: autoid.PerformanceSchemaDBID + 4, + tableNameSetupInstruments: autoid.PerformanceSchemaDBID + 5, + tableNameSetupConsumers: autoid.PerformanceSchemaDBID + 6, + tableNameEventsStatementsCurrent: autoid.PerformanceSchemaDBID + 7, + tableNameEventsStatementsHistory: autoid.PerformanceSchemaDBID + 8, + 
tableNameEventsStatementsHistoryLong: autoid.PerformanceSchemaDBID + 9, + tableNamePreparedStatementsInstances: autoid.PerformanceSchemaDBID + 10, + tableNameEventsTransactionsCurrent: autoid.PerformanceSchemaDBID + 11, + tableNameEventsTransactionsHistory: autoid.PerformanceSchemaDBID + 12, + tableNameEventsTransactionsHistoryLong: autoid.PerformanceSchemaDBID + 13, + tableNameEventsStagesCurrent: autoid.PerformanceSchemaDBID + 14, + tableNameEventsStagesHistory: autoid.PerformanceSchemaDBID + 15, + tableNameEventsStagesHistoryLong: autoid.PerformanceSchemaDBID + 16, + tableNameEventsStatementsSummaryByDigest: autoid.PerformanceSchemaDBID + 17, + tableNameTiDBProfileCPU: autoid.PerformanceSchemaDBID + 18, + tableNameTiDBProfileMemory: autoid.PerformanceSchemaDBID + 19, + tableNameTiDBProfileMutex: autoid.PerformanceSchemaDBID + 20, + tableNameTiDBProfileAllocs: autoid.PerformanceSchemaDBID + 21, + tableNameTiDBProfileBlock: autoid.PerformanceSchemaDBID + 22, + tableNameTiDBProfileGoroutines: autoid.PerformanceSchemaDBID + 23, + tableNameTiKVProfileCPU: autoid.PerformanceSchemaDBID + 24, + tableNamePDProfileCPU: autoid.PerformanceSchemaDBID + 25, + tableNamePDProfileMemory: autoid.PerformanceSchemaDBID + 26, + tableNamePDProfileMutex: autoid.PerformanceSchemaDBID + 27, + tableNamePDProfileAllocs: autoid.PerformanceSchemaDBID + 28, + tableNamePDProfileBlock: autoid.PerformanceSchemaDBID + 29, + tableNamePDProfileGoroutines: autoid.PerformanceSchemaDBID + 30, + tableNameSessionVariables: autoid.PerformanceSchemaDBID + 31, + tableNameSessionConnectAttrs: autoid.PerformanceSchemaDBID + 32, + tableNameSessionAccountConnectAttrs: autoid.PerformanceSchemaDBID + 33, +} + +// perfSchemaTable stands for the fake table all its data is in the memory. +type perfSchemaTable struct { + infoschema.VirtualTable + meta *model.TableInfo + cols []*table.Column + tp table.Type + indices []table.Index +} + +var pluginTable = make(map[string]func(autoid.Allocators, *model.TableInfo) (table.Table, error)) + +// IsPredefinedTable judges whether this table is predefined. +func IsPredefinedTable(tableName string) bool { + _, ok := tableIDMap[strings.ToLower(tableName)] + return ok +} + +func tableFromMeta(allocs autoid.Allocators, _ func() (pools.Resource, error), meta *model.TableInfo) (table.Table, error) { + if f, ok := pluginTable[meta.Name.L]; ok { + ret, err := f(allocs, meta) + return ret, err + } + return createPerfSchemaTable(meta) +} + +// createPerfSchemaTable creates all perfSchemaTables +func createPerfSchemaTable(meta *model.TableInfo) (*perfSchemaTable, error) { + columns := make([]*table.Column, 0, len(meta.Columns)) + for _, colInfo := range meta.Columns { + col := table.ToColumn(colInfo) + columns = append(columns, col) + } + tp := table.VirtualTable + t := &perfSchemaTable{ + meta: meta, + cols: columns, + tp: tp, + } + if err := initTableIndices(t); err != nil { + return nil, err + } + return t, nil +} + +// Cols implements table.Table Type interface. +func (vt *perfSchemaTable) Cols() []*table.Column { + return vt.cols +} + +// VisibleCols implements table.Table VisibleCols interface. +func (vt *perfSchemaTable) VisibleCols() []*table.Column { + return vt.cols +} + +// HiddenCols implements table.Table HiddenCols interface. +func (vt *perfSchemaTable) HiddenCols() []*table.Column { + return nil +} + +// WritableCols implements table.Table Type interface. +func (vt *perfSchemaTable) WritableCols() []*table.Column { + return vt.cols +} + +// DeletableCols implements table.Table Type interface. 
+func (vt *perfSchemaTable) DeletableCols() []*table.Column { + return vt.cols +} + +// FullHiddenColsAndVisibleCols implements table FullHiddenColsAndVisibleCols interface. +func (vt *perfSchemaTable) FullHiddenColsAndVisibleCols() []*table.Column { + return vt.cols +} + +// GetPhysicalID implements table.Table GetID interface. +func (vt *perfSchemaTable) GetPhysicalID() int64 { + return vt.meta.ID +} + +// Meta implements table.Table Type interface. +func (vt *perfSchemaTable) Meta() *model.TableInfo { + return vt.meta +} + +// Type implements table.Table Type interface. +func (vt *perfSchemaTable) Type() table.Type { + return vt.tp +} + +// Indices implements table.Table Indices interface. +func (vt *perfSchemaTable) Indices() []table.Index { + return vt.indices +} + +// GetPartitionedTable implements table.Table GetPartitionedTable interface. +func (vt *perfSchemaTable) GetPartitionedTable() table.PartitionedTable { + return nil +} + +// initTableIndices initializes the indices of the perfSchemaTable. +func initTableIndices(t *perfSchemaTable) error { + tblInfo := t.meta + for _, idxInfo := range tblInfo.Indices { + if idxInfo.State == model.StateNone { + return table.ErrIndexStateCantNone.GenWithStackByArgs(idxInfo.Name) + } + idx := tables.NewIndex(t.meta.ID, tblInfo, idxInfo) + t.indices = append(t.indices, idx) + } + return nil +} + +func (vt *perfSchemaTable) getRows(ctx context.Context, sctx sessionctx.Context, cols []*table.Column) (fullRows [][]types.Datum, err error) { + switch vt.meta.Name.O { + case tableNameTiDBProfileCPU: + fullRows, err = (&profile.Collector{}).ProfileGraph("cpu") + case tableNameTiDBProfileMemory: + fullRows, err = (&profile.Collector{}).ProfileGraph("heap") + case tableNameTiDBProfileMutex: + fullRows, err = (&profile.Collector{}).ProfileGraph("mutex") + case tableNameTiDBProfileAllocs: + fullRows, err = (&profile.Collector{}).ProfileGraph("allocs") + case tableNameTiDBProfileBlock: + fullRows, err = (&profile.Collector{}).ProfileGraph("block") + case tableNameTiDBProfileGoroutines: + fullRows, err = (&profile.Collector{}).ProfileGraph("goroutine") + case tableNameTiKVProfileCPU: + interval := fmt.Sprintf("%d", profile.CPUProfileInterval/time.Second) + fullRows, err = dataForRemoteProfile(sctx, "tikv", "/debug/pprof/profile?seconds="+interval, false) + case tableNamePDProfileCPU: + fullRows, err = dataForRemoteProfile(sctx, "pd", pd.PProfProfileAPIWithInterval(profile.CPUProfileInterval), false) + case tableNamePDProfileMemory: + fullRows, err = dataForRemoteProfile(sctx, "pd", pd.PProfHeap, false) + case tableNamePDProfileMutex: + fullRows, err = dataForRemoteProfile(sctx, "pd", pd.PProfMutex, false) + case tableNamePDProfileAllocs: + fullRows, err = dataForRemoteProfile(sctx, "pd", pd.PProfAllocs, false) + case tableNamePDProfileBlock: + fullRows, err = dataForRemoteProfile(sctx, "pd", pd.PProfBlock, false) + case tableNamePDProfileGoroutines: + fullRows, err = dataForRemoteProfile(sctx, "pd", pd.PProfGoroutineWithDebugLevel(2), true) + case tableNameSessionVariables: + fullRows, err = infoschema.GetDataFromSessionVariables(ctx, sctx) + case tableNameSessionConnectAttrs: + fullRows, err = infoschema.GetDataFromSessionConnectAttrs(sctx, false) + case tableNameSessionAccountConnectAttrs: + fullRows, err = infoschema.GetDataFromSessionConnectAttrs(sctx, true) + } + if err != nil { + return + } + if len(cols) == len(vt.cols) { + return + } + rows := make([][]types.Datum, len(fullRows)) + for i, fullRow := range fullRows { + row := make([]types.Datum, 
len(cols)) + for j, col := range cols { + row[j] = fullRow[col.Offset] + } + rows[i] = row + } + return rows, nil +} + +// IterRecords implements table.Table IterRecords interface. +func (vt *perfSchemaTable) IterRecords(ctx context.Context, sctx sessionctx.Context, cols []*table.Column, fn table.RecordIterFunc) error { + rows, err := vt.getRows(ctx, sctx, cols) + if err != nil { + return err + } + for i, row := range rows { + more, err := fn(kv.IntHandle(i), row, cols) + if err != nil { + return err + } + if !more { + break + } + } + return nil +} + +func dataForRemoteProfile(ctx sessionctx.Context, nodeType, uri string, isGoroutine bool) ([][]types.Datum, error) { + var ( + servers []infoschema.ServerInfo + err error + ) + switch nodeType { + case "tikv": + servers, err = infoschema.GetStoreServerInfo(ctx.GetStore()) + case "pd": + servers, err = infoschema.GetPDServerInfo(ctx) + default: + return nil, errors.Errorf("%s does not support profile remote component", nodeType) + } + failpoint.Inject("mockRemoteNodeStatusAddress", func(val failpoint.Value) { + // The cluster topology is injected by `failpoint` expression and + // there is no extra checks for it. (let the test fail if the expression invalid) + if s := val.(string); len(s) > 0 { + servers = servers[:0] + for _, server := range strings.Split(s, ";") { + parts := strings.Split(server, ",") + if parts[0] != nodeType { + continue + } + servers = append(servers, infoschema.ServerInfo{ + ServerType: parts[0], + Address: parts[1], + StatusAddr: parts[2], + }) + } + // erase error + err = nil + } + }) + if err != nil { + return nil, errors.Trace(err) + } + + type result struct { + addr string + rows [][]types.Datum + err error + } + + wg := sync.WaitGroup{} + ch := make(chan result, len(servers)) + for _, server := range servers { + statusAddr := server.StatusAddr + if len(statusAddr) == 0 { + ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("TiKV node %s does not contain status address", server.Address)) + continue + } + + wg.Add(1) + go func(address string) { + util.WithRecovery(func() { + defer wg.Done() + url := fmt.Sprintf("%s://%s%s", util.InternalHTTPSchema(), statusAddr, uri) + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + ch <- result{err: errors.Trace(err)} + return + } + // Forbidden PD follower proxy + req.Header.Add("PD-Allow-follower-handle", "true") + // TiKV output svg format in default + req.Header.Add("Content-Type", "application/protobuf") + resp, err := util.InternalHTTPClient().Do(req) + if err != nil { + ch <- result{err: errors.Trace(err)} + return + } + defer func() { + terror.Log(resp.Body.Close()) + }() + if resp.StatusCode != http.StatusOK { + ch <- result{err: errors.Errorf("request %s failed: %s", url, resp.Status)} + return + } + collector := profile.Collector{} + var rows [][]types.Datum + if isGoroutine { + rows, err = collector.ParseGoroutines(resp.Body) + } else { + rows, err = collector.ProfileReaderToDatums(resp.Body) + } + if err != nil { + ch <- result{err: errors.Trace(err)} + return + } + ch <- result{addr: address, rows: rows} + }, nil) + }(statusAddr) + } + + wg.Wait() + close(ch) + + // Keep the original order to make the result more stable + var results []result //nolint: prealloc + for result := range ch { + if result.err != nil { + ctx.GetSessionVars().StmtCtx.AppendWarning(result.err) + continue + } + results = append(results, result) + } + slices.SortFunc(results, func(i, j result) int { return cmp.Compare(i.addr, j.addr) }) + var finalRows 
[][]types.Datum + for _, result := range results { + addr := types.NewStringDatum(result.addr) + for _, row := range result.rows { + // Insert the node address in front of rows + finalRows = append(finalRows, append([]types.Datum{addr}, row...)) + } + } + return finalRows, nil +} diff --git a/pkg/infoschema/sieve.go b/pkg/infoschema/sieve.go index 01100ebf359d6..2041c908432e7 100644 --- a/pkg/infoschema/sieve.go +++ b/pkg/infoschema/sieve.go @@ -151,10 +151,10 @@ func (s *Sieve[K, V]) Set(key K, value V) { } func (s *Sieve[K, V]) Get(key K) (value V, ok bool) { - failpoint.Inject("skipGet", func() { + if _, _err_ := failpoint.Eval(_curpkg_("skipGet")); _err_ == nil { var v V - failpoint.Return(v, false) - }) + return v, false + } s.mu.Lock() defer s.mu.Unlock() if e, ok := s.items[key]; ok { diff --git a/pkg/infoschema/sieve.go__failpoint_stash__ b/pkg/infoschema/sieve.go__failpoint_stash__ new file mode 100644 index 0000000000000..01100ebf359d6 --- /dev/null +++ b/pkg/infoschema/sieve.go__failpoint_stash__ @@ -0,0 +1,272 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package infoschema + +import ( + "container/list" + "context" + "sync" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/infoschema/internal" +) + +// entry holds the key and value of a cache entry. +type entry[K comparable, V any] struct { + key K + value V + visited bool + element *list.Element + size uint64 +} + +func (t *entry[K, V]) Size() uint64 { + if t.size == 0 { + size := internal.Sizeof(t) + if size > 0 { + t.size = uint64(size) + } + } + return t.size +} + +// Sieve is an efficient turn-Key eviction algorithm for web caches. 
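+// Each cache entry carries a visited bit; eviction sweeps a hand through the list,
+// clearing bits as it passes, and removes the first entry whose bit is already clear.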
+// See blog post https://cachemon.github.io/SIEVE-website/blog/2023/12/17/sieve-is-simpler-than-lru/ +// and also the academic paper "SIEVE is simpler than LRU" +type Sieve[K comparable, V any] struct { + ctx context.Context + cancel context.CancelFunc + mu sync.Mutex + size uint64 + capacity uint64 + items map[K]*entry[K, V] + ll *list.List + hand *list.Element + + hook sieveStatusHook +} + +type sieveStatusHook interface { + onHit() + onMiss() + onEvict() + onUpdateSize(size uint64) + onUpdateLimit(limit uint64) +} + +type emptySieveStatusHook struct{} + +func (e *emptySieveStatusHook) onHit() {} + +func (e *emptySieveStatusHook) onMiss() {} + +func (e *emptySieveStatusHook) onEvict() {} + +func (e *emptySieveStatusHook) onUpdateSize(_ uint64) {} + +func (e *emptySieveStatusHook) onUpdateLimit(_ uint64) {} + +func newSieve[K comparable, V any](capacity uint64) *Sieve[K, V] { + ctx, cancel := context.WithCancel(context.Background()) + + cache := &Sieve[K, V]{ + ctx: ctx, + cancel: cancel, + capacity: capacity, + items: make(map[K]*entry[K, V]), + ll: list.New(), + hook: &emptySieveStatusHook{}, + } + + return cache +} + +func (s *Sieve[K, V]) SetStatusHook(hook sieveStatusHook) { + s.hook = hook +} + +func (s *Sieve[K, V]) SetCapacity(capacity uint64) { + s.mu.Lock() + defer s.mu.Unlock() + s.capacity = capacity + s.hook.onUpdateLimit(capacity) +} + +func (s *Sieve[K, V]) SetCapacityAndWaitEvict(capacity uint64) { + s.SetCapacity(capacity) + for { + s.mu.Lock() + if s.size <= s.capacity { + s.mu.Unlock() + break + } + for i := 0; s.size > s.capacity && i < 10; i++ { + s.evict() + } + s.mu.Unlock() + } +} + +func (s *Sieve[K, V]) Capacity() uint64 { + s.mu.Lock() + defer s.mu.Unlock() + return s.capacity +} + +func (s *Sieve[K, V]) Set(key K, value V) { + s.mu.Lock() + defer s.mu.Unlock() + + if e, ok := s.items[key]; ok { + e.value = value + e.visited = true + return + } + + for i := 0; s.size > s.capacity && i < 10; i++ { + s.evict() + } + + e := &entry[K, V]{ + key: key, + value: value, + } + s.size += e.Size() // calculate the size first without putting to the list. + s.hook.onUpdateSize(s.size) + e.element = s.ll.PushFront(key) + + s.items[key] = e +} + +func (s *Sieve[K, V]) Get(key K) (value V, ok bool) { + failpoint.Inject("skipGet", func() { + var v V + failpoint.Return(v, false) + }) + s.mu.Lock() + defer s.mu.Unlock() + if e, ok := s.items[key]; ok { + e.visited = true + s.hook.onHit() + return e.value, true + } + s.hook.onMiss() + return +} + +func (s *Sieve[K, V]) Remove(key K) (ok bool) { + s.mu.Lock() + defer s.mu.Unlock() + + if e, ok := s.items[key]; ok { + // if the element to be removed is the hand, + // then move the hand to the previous one. 
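+ // Otherwise the hand would keep pointing at a list element that is about to be removed.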
+ if e.element == s.hand { + s.hand = s.hand.Prev() + } + + s.removeEntry(e) + return true + } + + return false +} + +func (s *Sieve[K, V]) Contains(key K) (ok bool) { + s.mu.Lock() + defer s.mu.Unlock() + _, ok = s.items[key] + return +} + +func (s *Sieve[K, V]) Peek(key K) (value V, ok bool) { + s.mu.Lock() + defer s.mu.Unlock() + + if e, ok := s.items[key]; ok { + return e.value, true + } + + return +} + +func (s *Sieve[K, V]) Size() uint64 { + s.mu.Lock() + defer s.mu.Unlock() + + return s.size +} + +func (s *Sieve[K, V]) Len() int { + s.mu.Lock() + defer s.mu.Unlock() + + return s.ll.Len() +} + +func (s *Sieve[K, V]) Purge() { + s.mu.Lock() + defer s.mu.Unlock() + + for _, e := range s.items { + s.removeEntry(e) + } + + s.ll.Init() +} + +func (s *Sieve[K, V]) Close() { + s.Purge() + s.mu.Lock() + s.cancel() + s.mu.Unlock() +} + +func (s *Sieve[K, V]) removeEntry(e *entry[K, V]) { + s.ll.Remove(e.element) + delete(s.items, e.key) + s.size -= e.Size() + s.hook.onUpdateSize(s.size) +} + +func (s *Sieve[K, V]) evict() { + o := s.hand + // if o is nil, then assign it to the tail element in the list + if o == nil { + o = s.ll.Back() + } + + el, ok := s.items[o.Value.(K)] + if !ok { + panic("sieve: evicting non-existent element") + } + + for el.visited { + el.visited = false + o = o.Prev() + if o == nil { + o = s.ll.Back() + } + + el, ok = s.items[o.Value.(K)] + if !ok { + panic("sieve: evicting non-existent element") + } + } + + s.hand = o.Prev() + s.removeEntry(el) + s.hook.onEvict() +} diff --git a/pkg/infoschema/tables.go b/pkg/infoschema/tables.go index 6168001a813e7..b032c1c4c201b 100644 --- a/pkg/infoschema/tables.go +++ b/pkg/infoschema/tables.go @@ -1791,7 +1791,7 @@ func (s *ServerInfo) ResolveLoopBackAddr() { // GetClusterServerInfo returns all components information of cluster func GetClusterServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) { - failpoint.Inject("mockClusterInfo", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockClusterInfo")); _err_ == nil { // The cluster topology is injected by `failpoint` expression and // there is no extra checks for it. (let the test fail if the expression invalid) if s := val.(string); len(s) > 0 { @@ -1811,9 +1811,9 @@ func GetClusterServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) { ServerID: serverID, }) } - failpoint.Return(servers, nil) + return servers, nil } - }) + } type retriever func(ctx sessionctx.Context) ([]ServerInfo, error) retrievers := []retriever{GetTiDBServerInfo, GetPDServerInfo, func(ctx sessionctx.Context) ([]ServerInfo, error) { @@ -2069,7 +2069,7 @@ func isTiFlashWriteNode(store *metapb.Store) bool { // GetStoreServerInfo returns all store nodes(TiKV or TiFlash) cluster information func GetStoreServerInfo(store kv.Storage) ([]ServerInfo, error) { - failpoint.Inject("mockStoreServerInfo", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockStoreServerInfo")); _err_ == nil { if s := val.(string); len(s) > 0 { var servers []ServerInfo for _, server := range strings.Split(s, ";") { @@ -2083,9 +2083,9 @@ func GetStoreServerInfo(store kv.Storage) ([]ServerInfo, error) { StartTimestamp: 0, }) } - failpoint.Return(servers, nil) + return servers, nil } - }) + } // Get TiKV servers info. 
tikvStore, ok := store.(tikv.Storage) @@ -2102,11 +2102,11 @@ func GetStoreServerInfo(store kv.Storage) ([]ServerInfo, error) { } servers := make([]ServerInfo, 0, len(stores)) for _, store := range stores { - failpoint.Inject("mockStoreTombstone", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockStoreTombstone")); _err_ == nil { if val.(bool) { store.State = metapb.StoreState_Tombstone } - }) + } if store.GetState() == metapb.StoreState_Tombstone { continue @@ -2144,11 +2144,11 @@ func FormatStoreServerVersion(version string) string { // GetTiFlashStoreCount returns the count of tiflash server. func GetTiFlashStoreCount(ctx sessionctx.Context) (cnt uint64, err error) { - failpoint.Inject("mockTiFlashStoreCount", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockTiFlashStoreCount")); _err_ == nil { if val.(bool) { - failpoint.Return(uint64(10), nil) + return uint64(10), nil } - }) + } stores, err := GetStoreServerInfo(ctx.GetStore()) if err != nil { diff --git a/pkg/infoschema/tables.go__failpoint_stash__ b/pkg/infoschema/tables.go__failpoint_stash__ new file mode 100644 index 0000000000000..6168001a813e7 --- /dev/null +++ b/pkg/infoschema/tables.go__failpoint_stash__ @@ -0,0 +1,2694 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package infoschema + +import ( + "cmp" + "context" + "encoding/json" + "fmt" + "net" + "net/http" + "slices" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/ngaut/pools" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/diagnosticspb" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/log" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl/placement" + "github.com/pingcap/tidb/pkg/ddl/resourcegroup" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/parser/auth" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/privilege" + "github.com/pingcap/tidb/pkg/session/txninfo" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/deadlockhistory" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/sem" + "github.com/pingcap/tidb/pkg/util/set" + "github.com/pingcap/tidb/pkg/util/stmtsummary" + "github.com/tikv/client-go/v2/tikv" + pd "github.com/tikv/pd/client/http" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" +) + +const ( + // TableSchemata is the string constant of infoschema table. + TableSchemata = "SCHEMATA" + // TableTables is the string constant of infoschema table. + TableTables = "TABLES" + // TableColumns is the string constant of infoschema table + TableColumns = "COLUMNS" + tableColumnStatistics = "COLUMN_STATISTICS" + // TableStatistics is the string constant of infoschema table + TableStatistics = "STATISTICS" + // TableCharacterSets is the string constant of infoschema charactersets memory table + TableCharacterSets = "CHARACTER_SETS" + // TableCollations is the string constant of infoschema collations memory table. + TableCollations = "COLLATIONS" + tableFiles = "FILES" + // CatalogVal is the string constant of TABLE_CATALOG. + CatalogVal = "def" + // TableProfiling is the string constant of infoschema table. + TableProfiling = "PROFILING" + // TablePartitions is the string constant of infoschema table. + TablePartitions = "PARTITIONS" + // TableKeyColumn is the string constant of KEY_COLUMN_USAGE. + TableKeyColumn = "KEY_COLUMN_USAGE" + // TableReferConst is the string constant of REFERENTIAL_CONSTRAINTS. + TableReferConst = "REFERENTIAL_CONSTRAINTS" + // TableSessionVar is the string constant of SESSION_VARIABLES. + TableSessionVar = "SESSION_VARIABLES" + tablePlugins = "PLUGINS" + // TableConstraints is the string constant of TABLE_CONSTRAINTS. + TableConstraints = "TABLE_CONSTRAINTS" + tableTriggers = "TRIGGERS" + // TableUserPrivileges is the string constant of infoschema user privilege table. + TableUserPrivileges = "USER_PRIVILEGES" + tableSchemaPrivileges = "SCHEMA_PRIVILEGES" + tableTablePrivileges = "TABLE_PRIVILEGES" + tableColumnPrivileges = "COLUMN_PRIVILEGES" + // TableEngines is the string constant of infoschema table. + TableEngines = "ENGINES" + // TableViews is the string constant of infoschema table. 
+ TableViews = "VIEWS" + tableRoutines = "ROUTINES" + tableParameters = "PARAMETERS" + tableEvents = "EVENTS" + tableGlobalStatus = "GLOBAL_STATUS" + tableGlobalVariables = "GLOBAL_VARIABLES" + tableSessionStatus = "SESSION_STATUS" + tableOptimizerTrace = "OPTIMIZER_TRACE" + tableTableSpaces = "TABLESPACES" + // TableCollationCharacterSetApplicability is the string constant of infoschema memory table. + TableCollationCharacterSetApplicability = "COLLATION_CHARACTER_SET_APPLICABILITY" + // TableProcesslist is the string constant of infoschema table. + TableProcesslist = "PROCESSLIST" + // TableTiDBIndexes is the string constant of infoschema table + TableTiDBIndexes = "TIDB_INDEXES" + // TableTiDBHotRegions is the string constant of infoschema table + TableTiDBHotRegions = "TIDB_HOT_REGIONS" + // TableTiDBHotRegionsHistory is the string constant of infoschema table + TableTiDBHotRegionsHistory = "TIDB_HOT_REGIONS_HISTORY" + // TableTiKVStoreStatus is the string constant of infoschema table + TableTiKVStoreStatus = "TIKV_STORE_STATUS" + // TableAnalyzeStatus is the string constant of Analyze Status + TableAnalyzeStatus = "ANALYZE_STATUS" + // TableTiKVRegionStatus is the string constant of infoschema table + TableTiKVRegionStatus = "TIKV_REGION_STATUS" + // TableTiKVRegionPeers is the string constant of infoschema table + TableTiKVRegionPeers = "TIKV_REGION_PEERS" + // TableTiDBServersInfo is the string constant of TiDB server information table. + TableTiDBServersInfo = "TIDB_SERVERS_INFO" + // TableSlowQuery is the string constant of slow query memory table. + TableSlowQuery = "SLOW_QUERY" + // TableClusterInfo is the string constant of cluster info memory table. + TableClusterInfo = "CLUSTER_INFO" + // TableClusterConfig is the string constant of cluster configuration memory table. + TableClusterConfig = "CLUSTER_CONFIG" + // TableClusterLog is the string constant of cluster log memory table. + TableClusterLog = "CLUSTER_LOG" + // TableClusterLoad is the string constant of cluster load memory table. + TableClusterLoad = "CLUSTER_LOAD" + // TableClusterHardware is the string constant of cluster hardware table. + TableClusterHardware = "CLUSTER_HARDWARE" + // TableClusterSystemInfo is the string constant of cluster system info table. + TableClusterSystemInfo = "CLUSTER_SYSTEMINFO" + // TableTiFlashReplica is the string constant of tiflash replica table. + TableTiFlashReplica = "TIFLASH_REPLICA" + // TableInspectionResult is the string constant of inspection result table. + TableInspectionResult = "INSPECTION_RESULT" + // TableMetricTables is a table that contains all metrics table definition. + TableMetricTables = "METRICS_TABLES" + // TableMetricSummary is a summary table that contains all metrics. + TableMetricSummary = "METRICS_SUMMARY" + // TableMetricSummaryByLabel is a metric table that contains all metrics that group by label info. + TableMetricSummaryByLabel = "METRICS_SUMMARY_BY_LABEL" + // TableInspectionSummary is the string constant of inspection summary table. + TableInspectionSummary = "INSPECTION_SUMMARY" + // TableInspectionRules is the string constant of currently implemented inspection and summary rules. + TableInspectionRules = "INSPECTION_RULES" + // TableDDLJobs is the string constant of DDL job table. + TableDDLJobs = "DDL_JOBS" + // TableSequences is the string constant of all sequences created by user. + TableSequences = "SEQUENCES" + // TableStatementsSummary is the string constant of statement summary table. 
+ TableStatementsSummary = "STATEMENTS_SUMMARY" + // TableStatementsSummaryHistory is the string constant of statements summary history table. + TableStatementsSummaryHistory = "STATEMENTS_SUMMARY_HISTORY" + // TableStatementsSummaryEvicted is the string constant of statements summary evicted table. + TableStatementsSummaryEvicted = "STATEMENTS_SUMMARY_EVICTED" + // TableStorageStats is a table that contains all tables disk usage + TableStorageStats = "TABLE_STORAGE_STATS" + // TableTiFlashTables is the string constant of tiflash tables table. + TableTiFlashTables = "TIFLASH_TABLES" + // TableTiFlashSegments is the string constant of tiflash segments table. + TableTiFlashSegments = "TIFLASH_SEGMENTS" + // TableClientErrorsSummaryGlobal is the string constant of client errors table. + TableClientErrorsSummaryGlobal = "CLIENT_ERRORS_SUMMARY_GLOBAL" + // TableClientErrorsSummaryByUser is the string constant of client errors table. + TableClientErrorsSummaryByUser = "CLIENT_ERRORS_SUMMARY_BY_USER" + // TableClientErrorsSummaryByHost is the string constant of client errors table. + TableClientErrorsSummaryByHost = "CLIENT_ERRORS_SUMMARY_BY_HOST" + // TableTiDBTrx is current running transaction status table. + TableTiDBTrx = "TIDB_TRX" + // TableDeadlocks is the string constant of deadlock table. + TableDeadlocks = "DEADLOCKS" + // TableDataLockWaits is current lock waiting status table. + TableDataLockWaits = "DATA_LOCK_WAITS" + // TableAttributes is the string constant of attributes table. + TableAttributes = "ATTRIBUTES" + // TablePlacementPolicies is the string constant of placement policies table. + TablePlacementPolicies = "PLACEMENT_POLICIES" + // TableTrxSummary is the string constant of transaction summary table. + TableTrxSummary = "TRX_SUMMARY" + // TableVariablesInfo is the string constant of variables_info table. + TableVariablesInfo = "VARIABLES_INFO" + // TableUserAttributes is the string constant of user_attributes view. + TableUserAttributes = "USER_ATTRIBUTES" + // TableMemoryUsage is the memory usage status of tidb instance. + TableMemoryUsage = "MEMORY_USAGE" + // TableMemoryUsageOpsHistory is the memory control operators history. + TableMemoryUsageOpsHistory = "MEMORY_USAGE_OPS_HISTORY" + // TableResourceGroups is the metadata of resource groups. + TableResourceGroups = "RESOURCE_GROUPS" + // TableRunawayWatches is the query list of runaway watch. + TableRunawayWatches = "RUNAWAY_WATCHES" + // TableCheckConstraints is the list of CHECK constraints. + TableCheckConstraints = "CHECK_CONSTRAINTS" + // TableTiDBCheckConstraints is the list of CHECK constraints, with non-standard TiDB extensions. + TableTiDBCheckConstraints = "TIDB_CHECK_CONSTRAINTS" + // TableKeywords is the list of keywords. + TableKeywords = "KEYWORDS" + // TableTiDBIndexUsage is a table to show the usage stats of indexes in the current instance. + TableTiDBIndexUsage = "TIDB_INDEX_USAGE" +) + +const ( + // DataLockWaitsColumnKey is the name of the KEY column of the DATA_LOCK_WAITS table. + DataLockWaitsColumnKey = "KEY" + // DataLockWaitsColumnKeyInfo is the name of the KEY_INFO column of the DATA_LOCK_WAITS table. + DataLockWaitsColumnKeyInfo = "KEY_INFO" + // DataLockWaitsColumnTrxID is the name of the TRX_ID column of the DATA_LOCK_WAITS table. + DataLockWaitsColumnTrxID = "TRX_ID" + // DataLockWaitsColumnCurrentHoldingTrxID is the name of the CURRENT_HOLDING_TRX_ID column of the DATA_LOCK_WAITS table. 
+ DataLockWaitsColumnCurrentHoldingTrxID = "CURRENT_HOLDING_TRX_ID" + // DataLockWaitsColumnSQLDigest is the name of the SQL_DIGEST column of the DATA_LOCK_WAITS table. + DataLockWaitsColumnSQLDigest = "SQL_DIGEST" + // DataLockWaitsColumnSQLDigestText is the name of the SQL_DIGEST_TEXT column of the DATA_LOCK_WAITS table. + DataLockWaitsColumnSQLDigestText = "SQL_DIGEST_TEXT" +) + +// The following variables will only be used when PD in the microservice mode. +const ( + // tsoServiceName is the name of TSO service. + tsoServiceName = "tso" + // schedulingServiceName is the name of scheduling service. + schedulingServiceName = "scheduling" +) + +var tableIDMap = map[string]int64{ + TableSchemata: autoid.InformationSchemaDBID + 1, + TableTables: autoid.InformationSchemaDBID + 2, + TableColumns: autoid.InformationSchemaDBID + 3, + tableColumnStatistics: autoid.InformationSchemaDBID + 4, + TableStatistics: autoid.InformationSchemaDBID + 5, + TableCharacterSets: autoid.InformationSchemaDBID + 6, + TableCollations: autoid.InformationSchemaDBID + 7, + tableFiles: autoid.InformationSchemaDBID + 8, + CatalogVal: autoid.InformationSchemaDBID + 9, + TableProfiling: autoid.InformationSchemaDBID + 10, + TablePartitions: autoid.InformationSchemaDBID + 11, + TableKeyColumn: autoid.InformationSchemaDBID + 12, + TableReferConst: autoid.InformationSchemaDBID + 13, + TableSessionVar: autoid.InformationSchemaDBID + 14, + tablePlugins: autoid.InformationSchemaDBID + 15, + TableConstraints: autoid.InformationSchemaDBID + 16, + tableTriggers: autoid.InformationSchemaDBID + 17, + TableUserPrivileges: autoid.InformationSchemaDBID + 18, + tableSchemaPrivileges: autoid.InformationSchemaDBID + 19, + tableTablePrivileges: autoid.InformationSchemaDBID + 20, + tableColumnPrivileges: autoid.InformationSchemaDBID + 21, + TableEngines: autoid.InformationSchemaDBID + 22, + TableViews: autoid.InformationSchemaDBID + 23, + tableRoutines: autoid.InformationSchemaDBID + 24, + tableParameters: autoid.InformationSchemaDBID + 25, + tableEvents: autoid.InformationSchemaDBID + 26, + tableGlobalStatus: autoid.InformationSchemaDBID + 27, + tableGlobalVariables: autoid.InformationSchemaDBID + 28, + tableSessionStatus: autoid.InformationSchemaDBID + 29, + tableOptimizerTrace: autoid.InformationSchemaDBID + 30, + tableTableSpaces: autoid.InformationSchemaDBID + 31, + TableCollationCharacterSetApplicability: autoid.InformationSchemaDBID + 32, + TableProcesslist: autoid.InformationSchemaDBID + 33, + TableTiDBIndexes: autoid.InformationSchemaDBID + 34, + TableSlowQuery: autoid.InformationSchemaDBID + 35, + TableTiDBHotRegions: autoid.InformationSchemaDBID + 36, + TableTiKVStoreStatus: autoid.InformationSchemaDBID + 37, + TableAnalyzeStatus: autoid.InformationSchemaDBID + 38, + TableTiKVRegionStatus: autoid.InformationSchemaDBID + 39, + TableTiKVRegionPeers: autoid.InformationSchemaDBID + 40, + TableTiDBServersInfo: autoid.InformationSchemaDBID + 41, + TableClusterInfo: autoid.InformationSchemaDBID + 42, + TableClusterConfig: autoid.InformationSchemaDBID + 43, + TableClusterLoad: autoid.InformationSchemaDBID + 44, + TableTiFlashReplica: autoid.InformationSchemaDBID + 45, + ClusterTableSlowLog: autoid.InformationSchemaDBID + 46, + ClusterTableProcesslist: autoid.InformationSchemaDBID + 47, + TableClusterLog: autoid.InformationSchemaDBID + 48, + TableClusterHardware: autoid.InformationSchemaDBID + 49, + TableClusterSystemInfo: autoid.InformationSchemaDBID + 50, + TableInspectionResult: autoid.InformationSchemaDBID + 51, + TableMetricSummary: 
autoid.InformationSchemaDBID + 52, + TableMetricSummaryByLabel: autoid.InformationSchemaDBID + 53, + TableMetricTables: autoid.InformationSchemaDBID + 54, + TableInspectionSummary: autoid.InformationSchemaDBID + 55, + TableInspectionRules: autoid.InformationSchemaDBID + 56, + TableDDLJobs: autoid.InformationSchemaDBID + 57, + TableSequences: autoid.InformationSchemaDBID + 58, + TableStatementsSummary: autoid.InformationSchemaDBID + 59, + TableStatementsSummaryHistory: autoid.InformationSchemaDBID + 60, + ClusterTableStatementsSummary: autoid.InformationSchemaDBID + 61, + ClusterTableStatementsSummaryHistory: autoid.InformationSchemaDBID + 62, + TableStorageStats: autoid.InformationSchemaDBID + 63, + TableTiFlashTables: autoid.InformationSchemaDBID + 64, + TableTiFlashSegments: autoid.InformationSchemaDBID + 65, + // Removed, see https://github.com/pingcap/tidb/issues/28890 + //TablePlacementPolicy: autoid.InformationSchemaDBID + 66, + TableClientErrorsSummaryGlobal: autoid.InformationSchemaDBID + 67, + TableClientErrorsSummaryByUser: autoid.InformationSchemaDBID + 68, + TableClientErrorsSummaryByHost: autoid.InformationSchemaDBID + 69, + TableTiDBTrx: autoid.InformationSchemaDBID + 70, + ClusterTableTiDBTrx: autoid.InformationSchemaDBID + 71, + TableDeadlocks: autoid.InformationSchemaDBID + 72, + ClusterTableDeadlocks: autoid.InformationSchemaDBID + 73, + TableDataLockWaits: autoid.InformationSchemaDBID + 74, + TableStatementsSummaryEvicted: autoid.InformationSchemaDBID + 75, + ClusterTableStatementsSummaryEvicted: autoid.InformationSchemaDBID + 76, + TableAttributes: autoid.InformationSchemaDBID + 77, + TableTiDBHotRegionsHistory: autoid.InformationSchemaDBID + 78, + TablePlacementPolicies: autoid.InformationSchemaDBID + 79, + TableTrxSummary: autoid.InformationSchemaDBID + 80, + ClusterTableTrxSummary: autoid.InformationSchemaDBID + 81, + TableVariablesInfo: autoid.InformationSchemaDBID + 82, + TableUserAttributes: autoid.InformationSchemaDBID + 83, + TableMemoryUsage: autoid.InformationSchemaDBID + 84, + TableMemoryUsageOpsHistory: autoid.InformationSchemaDBID + 85, + ClusterTableMemoryUsage: autoid.InformationSchemaDBID + 86, + ClusterTableMemoryUsageOpsHistory: autoid.InformationSchemaDBID + 87, + TableResourceGroups: autoid.InformationSchemaDBID + 88, + TableRunawayWatches: autoid.InformationSchemaDBID + 89, + TableCheckConstraints: autoid.InformationSchemaDBID + 90, + TableTiDBCheckConstraints: autoid.InformationSchemaDBID + 91, + TableKeywords: autoid.InformationSchemaDBID + 92, + TableTiDBIndexUsage: autoid.InformationSchemaDBID + 93, + ClusterTableTiDBIndexUsage: autoid.InformationSchemaDBID + 94, +} + +// columnInfo represents the basic column information of all kinds of INFORMATION_SCHEMA tables +type columnInfo struct { + // name of column + name string + // tp is column type + tp byte + // represent size of bytes of the column + size int + // represent decimal length of the column + decimal int + // flag represent NotNull, Unsigned, PriKey flags etc. 
+ flag uint + // deflt is default value + deflt any + // comment for the column + comment string + // enumElems represent all possible literal string values of an enum column + enumElems []string +} + +func buildColumnInfo(col columnInfo) *model.ColumnInfo { + mCharset := charset.CharsetBin + mCollation := charset.CharsetBin + if col.tp == mysql.TypeVarchar || col.tp == mysql.TypeBlob || col.tp == mysql.TypeLongBlob || col.tp == mysql.TypeEnum { + mCharset = charset.CharsetUTF8MB4 + mCollation = charset.CollationUTF8MB4 + } + fieldType := types.FieldType{} + fieldType.SetType(col.tp) + fieldType.SetCharset(mCharset) + fieldType.SetCollate(mCollation) + fieldType.SetFlen(col.size) + fieldType.SetDecimal(col.decimal) + fieldType.SetFlag(col.flag) + fieldType.SetElems(col.enumElems) + return &model.ColumnInfo{ + Name: model.NewCIStr(col.name), + FieldType: fieldType, + State: model.StatePublic, + DefaultValue: col.deflt, + Comment: col.comment, + } +} + +func buildTableMeta(tableName string, cs []columnInfo) *model.TableInfo { + cols := make([]*model.ColumnInfo, 0, len(cs)) + primaryIndices := make([]*model.IndexInfo, 0, 1) + tblInfo := &model.TableInfo{ + Name: model.NewCIStr(tableName), + State: model.StatePublic, + Charset: mysql.DefaultCharset, + Collate: mysql.DefaultCollationName, + } + for offset, c := range cs { + if tblInfo.Name.O == ClusterTableSlowLog && mysql.HasPriKeyFlag(c.flag) { + switch c.tp { + case mysql.TypeLong, mysql.TypeLonglong, + mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24: + tblInfo.PKIsHandle = true + default: + tblInfo.IsCommonHandle = true + tblInfo.CommonHandleVersion = 1 + index := &model.IndexInfo{ + Name: model.NewCIStr("primary"), + State: model.StatePublic, + Primary: true, + Unique: true, + Columns: []*model.IndexColumn{ + {Name: model.NewCIStr(c.name), Offset: offset, Length: types.UnspecifiedLength}}, + } + primaryIndices = append(primaryIndices, index) + tblInfo.Indices = primaryIndices + } + } + cols = append(cols, buildColumnInfo(c)) + } + for i, col := range cols { + col.Offset = i + } + tblInfo.Columns = cols + return tblInfo +} + +var schemataCols = []columnInfo{ + {name: "CATALOG_NAME", tp: mysql.TypeVarchar, size: 512}, + {name: "SCHEMA_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "DEFAULT_CHARACTER_SET_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "DEFAULT_COLLATION_NAME", tp: mysql.TypeVarchar, size: 32}, + {name: "SQL_PATH", tp: mysql.TypeVarchar, size: 512}, + {name: "TIDB_PLACEMENT_POLICY_NAME", tp: mysql.TypeVarchar, size: 64}, +} + +var tablesCols = []columnInfo{ + {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 512}, + {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_TYPE", tp: mysql.TypeVarchar, size: 64}, + {name: "ENGINE", tp: mysql.TypeVarchar, size: 64}, + {name: "VERSION", tp: mysql.TypeLonglong, size: 21}, + {name: "ROW_FORMAT", tp: mysql.TypeVarchar, size: 10}, + {name: "TABLE_ROWS", tp: mysql.TypeLonglong, size: 21}, + {name: "AVG_ROW_LENGTH", tp: mysql.TypeLonglong, size: 21}, + {name: "DATA_LENGTH", tp: mysql.TypeLonglong, size: 21}, + {name: "MAX_DATA_LENGTH", tp: mysql.TypeLonglong, size: 21}, + {name: "INDEX_LENGTH", tp: mysql.TypeLonglong, size: 21}, + {name: "DATA_FREE", tp: mysql.TypeLonglong, size: 21}, + {name: "AUTO_INCREMENT", tp: mysql.TypeLonglong, size: 21}, + {name: "CREATE_TIME", tp: mysql.TypeDatetime, size: 19}, + {name: "UPDATE_TIME", tp: mysql.TypeDatetime, size: 19}, + {name: "CHECK_TIME", tp: 
mysql.TypeDatetime, size: 19}, + {name: "TABLE_COLLATION", tp: mysql.TypeVarchar, size: 32, deflt: mysql.DefaultCollationName}, + {name: "CHECKSUM", tp: mysql.TypeLonglong, size: 21}, + {name: "CREATE_OPTIONS", tp: mysql.TypeVarchar, size: 255}, + {name: "TABLE_COMMENT", tp: mysql.TypeVarchar, size: 2048}, + {name: "TIDB_TABLE_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "TIDB_ROW_ID_SHARDING_INFO", tp: mysql.TypeVarchar, size: 255}, + {name: "TIDB_PK_TYPE", tp: mysql.TypeVarchar, size: 64}, + {name: "TIDB_PLACEMENT_POLICY_NAME", tp: mysql.TypeVarchar, size: 64}, +} + +// See: http://dev.mysql.com/doc/refman/5.7/en/information-schema-columns-table.html +var columnsCols = []columnInfo{ + {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 512}, + {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "COLUMN_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "ORDINAL_POSITION", tp: mysql.TypeLonglong, size: 64}, + {name: "COLUMN_DEFAULT", tp: mysql.TypeBlob, size: 196606}, + {name: "IS_NULLABLE", tp: mysql.TypeVarchar, size: 3}, + {name: "DATA_TYPE", tp: mysql.TypeVarchar, size: 64}, + {name: "CHARACTER_MAXIMUM_LENGTH", tp: mysql.TypeLonglong, size: 21}, + {name: "CHARACTER_OCTET_LENGTH", tp: mysql.TypeLonglong, size: 21}, + {name: "NUMERIC_PRECISION", tp: mysql.TypeLonglong, size: 21}, + {name: "NUMERIC_SCALE", tp: mysql.TypeLonglong, size: 21}, + {name: "DATETIME_PRECISION", tp: mysql.TypeLonglong, size: 21}, + {name: "CHARACTER_SET_NAME", tp: mysql.TypeVarchar, size: 32}, + {name: "COLLATION_NAME", tp: mysql.TypeVarchar, size: 32}, + {name: "COLUMN_TYPE", tp: mysql.TypeBlob, size: 196606}, + {name: "COLUMN_KEY", tp: mysql.TypeVarchar, size: 3}, + {name: "EXTRA", tp: mysql.TypeVarchar, size: 45}, + {name: "PRIVILEGES", tp: mysql.TypeVarchar, size: 80}, + {name: "COLUMN_COMMENT", tp: mysql.TypeVarchar, size: 1024}, + {name: "GENERATION_EXPRESSION", tp: mysql.TypeBlob, size: 589779, flag: mysql.NotNullFlag}, +} + +var columnStatisticsCols = []columnInfo{ + {name: "SCHEMA_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "COLUMN_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "HISTOGRAM", tp: mysql.TypeJSON, size: 51}, +} + +var statisticsCols = []columnInfo{ + {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 512}, + {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "NON_UNIQUE", tp: mysql.TypeVarchar, size: 1}, + {name: "INDEX_SCHEMA", tp: mysql.TypeVarchar, size: 64}, + {name: "INDEX_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "SEQ_IN_INDEX", tp: mysql.TypeLonglong, size: 2}, + {name: "COLUMN_NAME", tp: mysql.TypeVarchar, size: 21}, + {name: "COLLATION", tp: mysql.TypeVarchar, size: 1}, + {name: "CARDINALITY", tp: mysql.TypeLonglong, size: 21}, + {name: "SUB_PART", tp: mysql.TypeLonglong, size: 3}, + {name: "PACKED", tp: mysql.TypeVarchar, size: 10}, + {name: "NULLABLE", tp: mysql.TypeVarchar, size: 3}, + {name: "INDEX_TYPE", tp: mysql.TypeVarchar, size: 16}, + {name: "COMMENT", tp: mysql.TypeVarchar, size: 16}, + {name: "INDEX_COMMENT", tp: mysql.TypeVarchar, size: 1024}, + {name: "IS_VISIBLE", tp: mysql.TypeVarchar, size: 3}, + {name: "Expression", tp: mysql.TypeVarchar, size: 64}, +} + +var profilingCols = []columnInfo{ + {name: "QUERY_ID", tp: mysql.TypeLong, size: 20}, + {name: "SEQ", tp: 
mysql.TypeLong, size: 20}, + {name: "STATE", tp: mysql.TypeVarchar, size: 30}, + {name: "DURATION", tp: mysql.TypeNewDecimal, size: 9}, + {name: "CPU_USER", tp: mysql.TypeNewDecimal, size: 9}, + {name: "CPU_SYSTEM", tp: mysql.TypeNewDecimal, size: 9}, + {name: "CONTEXT_VOLUNTARY", tp: mysql.TypeLong, size: 20}, + {name: "CONTEXT_INVOLUNTARY", tp: mysql.TypeLong, size: 20}, + {name: "BLOCK_OPS_IN", tp: mysql.TypeLong, size: 20}, + {name: "BLOCK_OPS_OUT", tp: mysql.TypeLong, size: 20}, + {name: "MESSAGES_SENT", tp: mysql.TypeLong, size: 20}, + {name: "MESSAGES_RECEIVED", tp: mysql.TypeLong, size: 20}, + {name: "PAGE_FAULTS_MAJOR", tp: mysql.TypeLong, size: 20}, + {name: "PAGE_FAULTS_MINOR", tp: mysql.TypeLong, size: 20}, + {name: "SWAPS", tp: mysql.TypeLong, size: 20}, + {name: "SOURCE_FUNCTION", tp: mysql.TypeVarchar, size: 30}, + {name: "SOURCE_FILE", tp: mysql.TypeVarchar, size: 20}, + {name: "SOURCE_LINE", tp: mysql.TypeLong, size: 20}, +} + +var charsetCols = []columnInfo{ + {name: "CHARACTER_SET_NAME", tp: mysql.TypeVarchar, size: 32}, + {name: "DEFAULT_COLLATE_NAME", tp: mysql.TypeVarchar, size: 32}, + {name: "DESCRIPTION", tp: mysql.TypeVarchar, size: 60}, + {name: "MAXLEN", tp: mysql.TypeLonglong, size: 3}, +} + +var collationsCols = []columnInfo{ + {name: "COLLATION_NAME", tp: mysql.TypeVarchar, size: 32}, + {name: "CHARACTER_SET_NAME", tp: mysql.TypeVarchar, size: 32}, + {name: "ID", tp: mysql.TypeLonglong, size: 11}, + {name: "IS_DEFAULT", tp: mysql.TypeVarchar, size: 3}, + {name: "IS_COMPILED", tp: mysql.TypeVarchar, size: 3}, + {name: "SORTLEN", tp: mysql.TypeLonglong, size: 3}, + {name: "PAD_ATTRIBUTE", tp: mysql.TypeVarchar, size: 9}, +} + +var keyColumnUsageCols = []columnInfo{ + {name: "CONSTRAINT_CATALOG", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag}, + {name: "CONSTRAINT_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "CONSTRAINT_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag}, + {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "COLUMN_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "ORDINAL_POSITION", tp: mysql.TypeLonglong, size: 10, flag: mysql.NotNullFlag}, + {name: "POSITION_IN_UNIQUE_CONSTRAINT", tp: mysql.TypeLonglong, size: 10}, + {name: "REFERENCED_TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64}, + {name: "REFERENCED_TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "REFERENCED_COLUMN_NAME", tp: mysql.TypeVarchar, size: 64}, +} + +// See http://dev.mysql.com/doc/refman/5.7/en/information-schema-referential-constraints-table.html +var referConstCols = []columnInfo{ + {name: "CONSTRAINT_CATALOG", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag}, + {name: "CONSTRAINT_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "CONSTRAINT_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "UNIQUE_CONSTRAINT_CATALOG", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag}, + {name: "UNIQUE_CONSTRAINT_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "UNIQUE_CONSTRAINT_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "MATCH_OPTION", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "UPDATE_RULE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: 
"DELETE_RULE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "REFERENCED_TABLE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, +} + +// See http://dev.mysql.com/doc/refman/5.7/en/information-schema-variables-table.html +var sessionVarCols = []columnInfo{ + {name: "VARIABLE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "VARIABLE_VALUE", tp: mysql.TypeVarchar, size: 1024}, +} + +// See https://dev.mysql.com/doc/refman/5.7/en/information-schema-plugins-table.html +var pluginsCols = []columnInfo{ + {name: "PLUGIN_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "PLUGIN_VERSION", tp: mysql.TypeVarchar, size: 20}, + {name: "PLUGIN_STATUS", tp: mysql.TypeVarchar, size: 10}, + {name: "PLUGIN_TYPE", tp: mysql.TypeVarchar, size: 80}, + {name: "PLUGIN_TYPE_VERSION", tp: mysql.TypeVarchar, size: 20}, + {name: "PLUGIN_LIBRARY", tp: mysql.TypeVarchar, size: 64}, + {name: "PLUGIN_LIBRARY_VERSION", tp: mysql.TypeVarchar, size: 20}, + {name: "PLUGIN_AUTHOR", tp: mysql.TypeVarchar, size: 64}, + {name: "PLUGIN_DESCRIPTION", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, + {name: "PLUGIN_LICENSE", tp: mysql.TypeVarchar, size: 80}, + {name: "LOAD_OPTION", tp: mysql.TypeVarchar, size: 64}, +} + +// See https://dev.mysql.com/doc/refman/5.7/en/information-schema-partitions-table.html +var partitionsCols = []columnInfo{ + {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 512}, + {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "PARTITION_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "SUBPARTITION_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "PARTITION_ORDINAL_POSITION", tp: mysql.TypeLonglong, size: 21}, + {name: "SUBPARTITION_ORDINAL_POSITION", tp: mysql.TypeLonglong, size: 21}, + {name: "PARTITION_METHOD", tp: mysql.TypeVarchar, size: 18}, + {name: "SUBPARTITION_METHOD", tp: mysql.TypeVarchar, size: 12}, + {name: "PARTITION_EXPRESSION", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, + {name: "SUBPARTITION_EXPRESSION", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, + {name: "PARTITION_DESCRIPTION", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, + {name: "TABLE_ROWS", tp: mysql.TypeLonglong, size: 21}, + {name: "AVG_ROW_LENGTH", tp: mysql.TypeLonglong, size: 21}, + {name: "DATA_LENGTH", tp: mysql.TypeLonglong, size: 21}, + {name: "MAX_DATA_LENGTH", tp: mysql.TypeLonglong, size: 21}, + {name: "INDEX_LENGTH", tp: mysql.TypeLonglong, size: 21}, + {name: "DATA_FREE", tp: mysql.TypeLonglong, size: 21}, + {name: "CREATE_TIME", tp: mysql.TypeDatetime}, + {name: "UPDATE_TIME", tp: mysql.TypeDatetime}, + {name: "CHECK_TIME", tp: mysql.TypeDatetime}, + {name: "CHECKSUM", tp: mysql.TypeLonglong, size: 21}, + {name: "PARTITION_COMMENT", tp: mysql.TypeVarchar, size: 80}, + {name: "NODEGROUP", tp: mysql.TypeVarchar, size: 12}, + {name: "TABLESPACE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "TIDB_PARTITION_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "TIDB_PLACEMENT_POLICY_NAME", tp: mysql.TypeVarchar, size: 64}, +} + +var tableConstraintsCols = []columnInfo{ + {name: "CONSTRAINT_CATALOG", tp: mysql.TypeVarchar, size: 512}, + {name: "CONSTRAINT_SCHEMA", tp: mysql.TypeVarchar, size: 64}, + {name: "CONSTRAINT_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: 
"CONSTRAINT_TYPE", tp: mysql.TypeVarchar, size: 64}, +} + +var tableTriggersCols = []columnInfo{ + {name: "TRIGGER_CATALOG", tp: mysql.TypeVarchar, size: 512}, + {name: "TRIGGER_SCHEMA", tp: mysql.TypeVarchar, size: 64}, + {name: "TRIGGER_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "EVENT_MANIPULATION", tp: mysql.TypeVarchar, size: 6}, + {name: "EVENT_OBJECT_CATALOG", tp: mysql.TypeVarchar, size: 512}, + {name: "EVENT_OBJECT_SCHEMA", tp: mysql.TypeVarchar, size: 64}, + {name: "EVENT_OBJECT_TABLE", tp: mysql.TypeVarchar, size: 64}, + {name: "ACTION_ORDER", tp: mysql.TypeLonglong, size: 4}, + {name: "ACTION_CONDITION", tp: mysql.TypeBlob, size: -1}, + {name: "ACTION_STATEMENT", tp: mysql.TypeBlob, size: -1}, + {name: "ACTION_ORIENTATION", tp: mysql.TypeVarchar, size: 9}, + {name: "ACTION_TIMING", tp: mysql.TypeVarchar, size: 6}, + {name: "ACTION_REFERENCE_OLD_TABLE", tp: mysql.TypeVarchar, size: 64}, + {name: "ACTION_REFERENCE_NEW_TABLE", tp: mysql.TypeVarchar, size: 64}, + {name: "ACTION_REFERENCE_OLD_ROW", tp: mysql.TypeVarchar, size: 3}, + {name: "ACTION_REFERENCE_NEW_ROW", tp: mysql.TypeVarchar, size: 3}, + {name: "CREATED", tp: mysql.TypeDatetime, size: 2}, + {name: "SQL_MODE", tp: mysql.TypeVarchar, size: 8192}, + {name: "DEFINER", tp: mysql.TypeVarchar, size: 77}, + {name: "CHARACTER_SET_CLIENT", tp: mysql.TypeVarchar, size: 32}, + {name: "COLLATION_CONNECTION", tp: mysql.TypeVarchar, size: 32}, + {name: "DATABASE_COLLATION", tp: mysql.TypeVarchar, size: 32}, +} + +var tableUserPrivilegesCols = []columnInfo{ + {name: "GRANTEE", tp: mysql.TypeVarchar, size: 81}, + {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 512}, + {name: "PRIVILEGE_TYPE", tp: mysql.TypeVarchar, size: 64}, + {name: "IS_GRANTABLE", tp: mysql.TypeVarchar, size: 3}, +} + +var tableSchemaPrivilegesCols = []columnInfo{ + {name: "GRANTEE", tp: mysql.TypeVarchar, size: 81, flag: mysql.NotNullFlag}, + {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag}, + {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "PRIVILEGE_TYPE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "IS_GRANTABLE", tp: mysql.TypeVarchar, size: 3, flag: mysql.NotNullFlag}, +} + +var tableTablePrivilegesCols = []columnInfo{ + {name: "GRANTEE", tp: mysql.TypeVarchar, size: 81, flag: mysql.NotNullFlag}, + {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag}, + {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "PRIVILEGE_TYPE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "IS_GRANTABLE", tp: mysql.TypeVarchar, size: 3, flag: mysql.NotNullFlag}, +} + +var tableColumnPrivilegesCols = []columnInfo{ + {name: "GRANTEE", tp: mysql.TypeVarchar, size: 81, flag: mysql.NotNullFlag}, + {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag}, + {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "COLUMN_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "PRIVILEGE_TYPE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "IS_GRANTABLE", tp: mysql.TypeVarchar, size: 3, flag: mysql.NotNullFlag}, +} + +var tableEnginesCols = []columnInfo{ + {name: "ENGINE", tp: mysql.TypeVarchar, size: 64}, + {name: "SUPPORT", tp: 
mysql.TypeVarchar, size: 8}, + {name: "COMMENT", tp: mysql.TypeVarchar, size: 80}, + {name: "TRANSACTIONS", tp: mysql.TypeVarchar, size: 3}, + {name: "XA", tp: mysql.TypeVarchar, size: 3}, + {name: "SAVEPOINTS", tp: mysql.TypeVarchar, size: 3}, +} + +var tableViewsCols = []columnInfo{ + {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag}, + {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "VIEW_DEFINITION", tp: mysql.TypeLongBlob, flag: mysql.NotNullFlag}, + {name: "CHECK_OPTION", tp: mysql.TypeVarchar, size: 8, flag: mysql.NotNullFlag}, + {name: "IS_UPDATABLE", tp: mysql.TypeVarchar, size: 3, flag: mysql.NotNullFlag}, + {name: "DEFINER", tp: mysql.TypeVarchar, size: 77, flag: mysql.NotNullFlag}, + {name: "SECURITY_TYPE", tp: mysql.TypeVarchar, size: 7, flag: mysql.NotNullFlag}, + {name: "CHARACTER_SET_CLIENT", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, + {name: "COLLATION_CONNECTION", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, +} + +var tableRoutinesCols = []columnInfo{ + {name: "SPECIFIC_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "ROUTINE_CATALOG", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag}, + {name: "ROUTINE_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "ROUTINE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "ROUTINE_TYPE", tp: mysql.TypeVarchar, size: 9, flag: mysql.NotNullFlag}, + {name: "DATA_TYPE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "CHARACTER_MAXIMUM_LENGTH", tp: mysql.TypeLong, size: 21}, + {name: "CHARACTER_OCTET_LENGTH", tp: mysql.TypeLong, size: 21}, + {name: "NUMERIC_PRECISION", tp: mysql.TypeLonglong, size: 21}, + {name: "NUMERIC_SCALE", tp: mysql.TypeLong, size: 21}, + {name: "DATETIME_PRECISION", tp: mysql.TypeLonglong, size: 21}, + {name: "CHARACTER_SET_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "COLLATION_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "DTD_IDENTIFIER", tp: mysql.TypeLongBlob}, + {name: "ROUTINE_BODY", tp: mysql.TypeVarchar, size: 8, flag: mysql.NotNullFlag}, + {name: "ROUTINE_DEFINITION", tp: mysql.TypeLongBlob}, + {name: "EXTERNAL_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "EXTERNAL_LANGUAGE", tp: mysql.TypeVarchar, size: 64}, + {name: "PARAMETER_STYLE", tp: mysql.TypeVarchar, size: 8, flag: mysql.NotNullFlag}, + {name: "IS_DETERMINISTIC", tp: mysql.TypeVarchar, size: 3, flag: mysql.NotNullFlag}, + {name: "SQL_DATA_ACCESS", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "SQL_PATH", tp: mysql.TypeVarchar, size: 64}, + {name: "SECURITY_TYPE", tp: mysql.TypeVarchar, size: 7, flag: mysql.NotNullFlag}, + {name: "CREATED", tp: mysql.TypeDatetime, flag: mysql.NotNullFlag, deflt: "0000-00-00 00:00:00"}, + {name: "LAST_ALTERED", tp: mysql.TypeDatetime, flag: mysql.NotNullFlag, deflt: "0000-00-00 00:00:00"}, + {name: "SQL_MODE", tp: mysql.TypeVarchar, size: 8192, flag: mysql.NotNullFlag}, + {name: "ROUTINE_COMMENT", tp: mysql.TypeLongBlob}, + {name: "DEFINER", tp: mysql.TypeVarchar, size: 77, flag: mysql.NotNullFlag}, + {name: "CHARACTER_SET_CLIENT", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, + {name: "COLLATION_CONNECTION", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, + {name: "DATABASE_COLLATION", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, +} + +var 
tableParametersCols = []columnInfo{ + {name: "SPECIFIC_CATALOG", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag}, + {name: "SPECIFIC_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "SPECIFIC_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "ORDINAL_POSITION", tp: mysql.TypeVarchar, size: 21, flag: mysql.NotNullFlag}, + {name: "PARAMETER_MODE", tp: mysql.TypeVarchar, size: 5}, + {name: "PARAMETER_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "DATA_TYPE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "CHARACTER_MAXIMUM_LENGTH", tp: mysql.TypeVarchar, size: 21}, + {name: "CHARACTER_OCTET_LENGTH", tp: mysql.TypeVarchar, size: 21}, + {name: "NUMERIC_PRECISION", tp: mysql.TypeVarchar, size: 21}, + {name: "NUMERIC_SCALE", tp: mysql.TypeVarchar, size: 21}, + {name: "DATETIME_PRECISION", tp: mysql.TypeVarchar, size: 21}, + {name: "CHARACTER_SET_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "COLLATION_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "DTD_IDENTIFIER", tp: mysql.TypeLongBlob, flag: mysql.NotNullFlag}, + {name: "ROUTINE_TYPE", tp: mysql.TypeVarchar, size: 9, flag: mysql.NotNullFlag}, +} + +var tableEventsCols = []columnInfo{ + {name: "EVENT_CATALOG", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "EVENT_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "EVENT_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "DEFINER", tp: mysql.TypeVarchar, size: 77, flag: mysql.NotNullFlag}, + {name: "TIME_ZONE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "EVENT_BODY", tp: mysql.TypeVarchar, size: 8, flag: mysql.NotNullFlag}, + {name: "EVENT_DEFINITION", tp: mysql.TypeLongBlob}, + {name: "EVENT_TYPE", tp: mysql.TypeVarchar, size: 9, flag: mysql.NotNullFlag}, + {name: "EXECUTE_AT", tp: mysql.TypeDatetime}, + {name: "INTERVAL_VALUE", tp: mysql.TypeVarchar, size: 256}, + {name: "INTERVAL_FIELD", tp: mysql.TypeVarchar, size: 18}, + {name: "SQL_MODE", tp: mysql.TypeVarchar, size: 8192, flag: mysql.NotNullFlag}, + {name: "STARTS", tp: mysql.TypeDatetime}, + {name: "ENDS", tp: mysql.TypeDatetime}, + {name: "STATUS", tp: mysql.TypeVarchar, size: 18, flag: mysql.NotNullFlag}, + {name: "ON_COMPLETION", tp: mysql.TypeVarchar, size: 12, flag: mysql.NotNullFlag}, + {name: "CREATED", tp: mysql.TypeDatetime, flag: mysql.NotNullFlag, deflt: "0000-00-00 00:00:00"}, + {name: "LAST_ALTERED", tp: mysql.TypeDatetime, flag: mysql.NotNullFlag, deflt: "0000-00-00 00:00:00"}, + {name: "LAST_EXECUTED", tp: mysql.TypeDatetime}, + {name: "EVENT_COMMENT", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "ORIGINATOR", tp: mysql.TypeLong, size: 10, flag: mysql.NotNullFlag, deflt: 0}, + {name: "CHARACTER_SET_CLIENT", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, + {name: "COLLATION_CONNECTION", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, + {name: "DATABASE_COLLATION", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, +} + +var tableGlobalStatusCols = []columnInfo{ + {name: "VARIABLE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "VARIABLE_VALUE", tp: mysql.TypeVarchar, size: 1024}, +} + +var tableGlobalVariablesCols = []columnInfo{ + {name: "VARIABLE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "VARIABLE_VALUE", tp: mysql.TypeVarchar, size: 1024}, +} + +var tableSessionStatusCols = []columnInfo{ + {name: "VARIABLE_NAME", tp: 
mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "VARIABLE_VALUE", tp: mysql.TypeVarchar, size: 1024}, +} + +var tableOptimizerTraceCols = []columnInfo{ + {name: "QUERY", tp: mysql.TypeLongBlob, flag: mysql.NotNullFlag, deflt: ""}, + {name: "TRACE", tp: mysql.TypeLongBlob, flag: mysql.NotNullFlag, deflt: ""}, + {name: "MISSING_BYTES_BEYOND_MAX_MEM_SIZE", tp: mysql.TypeShort, size: 20, flag: mysql.NotNullFlag, deflt: 0}, + {name: "INSUFFICIENT_PRIVILEGES", tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, deflt: 0}, +} + +var tableTableSpacesCols = []columnInfo{ + {name: "TABLESPACE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag, deflt: ""}, + {name: "ENGINE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag, deflt: ""}, + {name: "TABLESPACE_TYPE", tp: mysql.TypeVarchar, size: 64}, + {name: "LOGFILE_GROUP_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "EXTENT_SIZE", tp: mysql.TypeLonglong, size: 21}, + {name: "AUTOEXTEND_SIZE", tp: mysql.TypeLonglong, size: 21}, + {name: "MAXIMUM_SIZE", tp: mysql.TypeLonglong, size: 21}, + {name: "NODEGROUP_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "TABLESPACE_COMMENT", tp: mysql.TypeVarchar, size: 2048}, +} + +var tableCollationCharacterSetApplicabilityCols = []columnInfo{ + {name: "COLLATION_NAME", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, + {name: "CHARACTER_SET_NAME", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, +} + +var tableProcesslistCols = []columnInfo{ + {name: "ID", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag | mysql.UnsignedFlag, deflt: 0}, + {name: "USER", tp: mysql.TypeVarchar, size: 16, flag: mysql.NotNullFlag, deflt: ""}, + {name: "HOST", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag, deflt: ""}, + {name: "DB", tp: mysql.TypeVarchar, size: 64}, + {name: "COMMAND", tp: mysql.TypeVarchar, size: 16, flag: mysql.NotNullFlag, deflt: ""}, + {name: "TIME", tp: mysql.TypeLong, size: 7, flag: mysql.NotNullFlag, deflt: 0}, + {name: "STATE", tp: mysql.TypeVarchar, size: 7}, + {name: "INFO", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, + {name: "DIGEST", tp: mysql.TypeVarchar, size: 64, deflt: ""}, + {name: "MEM", tp: mysql.TypeLonglong, size: 21, flag: mysql.UnsignedFlag}, + {name: "DISK", tp: mysql.TypeLonglong, size: 21, flag: mysql.UnsignedFlag}, + {name: "TxnStart", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag, deflt: ""}, + {name: "RESOURCE_GROUP", tp: mysql.TypeVarchar, size: resourcegroup.MaxGroupNameLength, flag: mysql.NotNullFlag, deflt: ""}, + {name: "SESSION_ALIAS", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag, deflt: ""}, + {name: "CURRENT_AFFECTED_ROWS", tp: mysql.TypeLonglong, size: 21, flag: mysql.UnsignedFlag}, +} + +var tableTiDBIndexesCols = []columnInfo{ + {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "NON_UNIQUE", tp: mysql.TypeLonglong, size: 21}, + {name: "KEY_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "SEQ_IN_INDEX", tp: mysql.TypeLonglong, size: 21}, + {name: "COLUMN_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "SUB_PART", tp: mysql.TypeLonglong, size: 21}, + {name: "INDEX_COMMENT", tp: mysql.TypeVarchar, size: 1024}, + {name: "Expression", tp: mysql.TypeVarchar, size: 64}, + {name: "INDEX_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "IS_VISIBLE", tp: mysql.TypeVarchar, size: 64}, + {name: "CLUSTERED", tp: mysql.TypeVarchar, size: 64}, + {name: "IS_GLOBAL", tp: mysql.TypeLonglong, size: 21}, +} 
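(For orientation, not part of the patch: the column slices above are pure metadata, and buildTableMeta, defined earlier in this hunk, is what pairs a slice with a table name to produce the *model.TableInfo that the information schema serves. A minimal sketch of that wiring follows, assuming it lives in the same package as the definitions above; the literal "TIDB_INDEXES" string is written out here only for illustration.)

    func exampleTiDBIndexesMeta() *model.TableInfo {
        // Pair the column layout defined above with the table name it backs.
        // buildTableMeta runs each entry through buildColumnInfo, so every column
        // comes back as a model.ColumnInfo in StatePublic, with charset/collation
        // derived from its type (utf8mb4 for varchar/blob/enum, binary otherwise).
        return buildTableMeta("TIDB_INDEXES", tableTiDBIndexesCols)
    }

(One design point worth noting here: buildTableMeta special-cases columns of ClusterTableSlowLog that carry mysql.PriKeyFlag, setting PKIsHandle or a common handle, which is why the slowQueryCols slice that follows marks its Time column with mysql.PriKeyFlag while every other slice in this hunk leaves the flag unset.)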
+ +var slowQueryCols = []columnInfo{ + {name: variable.SlowLogTimeStr, tp: mysql.TypeTimestamp, size: 26, decimal: 6, flag: mysql.PriKeyFlag | mysql.NotNullFlag | mysql.BinaryFlag}, + {name: variable.SlowLogTxnStartTSStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, + {name: variable.SlowLogUserStr, tp: mysql.TypeVarchar, size: 64}, + {name: variable.SlowLogHostStr, tp: mysql.TypeVarchar, size: 64}, + {name: variable.SlowLogConnIDStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, + {name: variable.SlowLogSessAliasStr, tp: mysql.TypeVarchar, size: 64}, + {name: variable.SlowLogExecRetryCount, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, + {name: variable.SlowLogExecRetryTime, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogQueryTimeStr, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogParseTimeStr, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogCompileTimeStr, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogRewriteTimeStr, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogPreprocSubQueriesStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, + {name: variable.SlowLogPreProcSubQueryTimeStr, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogOptimizeTimeStr, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogWaitTSTimeStr, tp: mysql.TypeDouble, size: 22}, + {name: execdetails.PreWriteTimeStr, tp: mysql.TypeDouble, size: 22}, + {name: execdetails.WaitPrewriteBinlogTimeStr, tp: mysql.TypeDouble, size: 22}, + {name: execdetails.CommitTimeStr, tp: mysql.TypeDouble, size: 22}, + {name: execdetails.GetCommitTSTimeStr, tp: mysql.TypeDouble, size: 22}, + {name: execdetails.CommitBackoffTimeStr, tp: mysql.TypeDouble, size: 22}, + {name: execdetails.BackoffTypesStr, tp: mysql.TypeVarchar, size: 64}, + {name: execdetails.ResolveLockTimeStr, tp: mysql.TypeDouble, size: 22}, + {name: execdetails.LocalLatchWaitTimeStr, tp: mysql.TypeDouble, size: 22}, + {name: execdetails.WriteKeysStr, tp: mysql.TypeLonglong, size: 22}, + {name: execdetails.WriteSizeStr, tp: mysql.TypeLonglong, size: 22}, + {name: execdetails.PrewriteRegionStr, tp: mysql.TypeLonglong, size: 22}, + {name: execdetails.TxnRetryStr, tp: mysql.TypeLonglong, size: 22}, + {name: execdetails.CopTimeStr, tp: mysql.TypeDouble, size: 22}, + {name: execdetails.ProcessTimeStr, tp: mysql.TypeDouble, size: 22}, + {name: execdetails.WaitTimeStr, tp: mysql.TypeDouble, size: 22}, + {name: execdetails.BackoffTimeStr, tp: mysql.TypeDouble, size: 22}, + {name: execdetails.LockKeysTimeStr, tp: mysql.TypeDouble, size: 22}, + {name: execdetails.RequestCountStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, + {name: execdetails.TotalKeysStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, + {name: execdetails.ProcessKeysStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, + {name: execdetails.RocksdbDeleteSkippedCountStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, + {name: execdetails.RocksdbKeySkippedCountStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, + {name: execdetails.RocksdbBlockCacheHitCountStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, + {name: execdetails.RocksdbBlockReadCountStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, + {name: execdetails.RocksdbBlockReadByteStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, + {name: variable.SlowLogDBStr, tp: mysql.TypeVarchar, size: 64}, + {name: 
variable.SlowLogIndexNamesStr, tp: mysql.TypeVarchar, size: 100}, + {name: variable.SlowLogIsInternalStr, tp: mysql.TypeTiny, size: 1}, + {name: variable.SlowLogDigestStr, tp: mysql.TypeVarchar, size: 64}, + {name: variable.SlowLogStatsInfoStr, tp: mysql.TypeVarchar, size: 512}, + {name: variable.SlowLogCopProcAvg, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogCopProcP90, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogCopProcMax, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogCopProcAddr, tp: mysql.TypeVarchar, size: 64}, + {name: variable.SlowLogCopWaitAvg, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogCopWaitP90, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogCopWaitMax, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogCopWaitAddr, tp: mysql.TypeVarchar, size: 64}, + {name: variable.SlowLogMemMax, tp: mysql.TypeLonglong, size: 20}, + {name: variable.SlowLogDiskMax, tp: mysql.TypeLonglong, size: 20}, + {name: variable.SlowLogKVTotal, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogPDTotal, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogBackoffTotal, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogWriteSQLRespTotal, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogResultRows, tp: mysql.TypeLonglong, size: 22}, + {name: variable.SlowLogWarnings, tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, + {name: variable.SlowLogBackoffDetail, tp: mysql.TypeVarchar, size: 4096}, + {name: variable.SlowLogPrepared, tp: mysql.TypeTiny, size: 1}, + {name: variable.SlowLogSucc, tp: mysql.TypeTiny, size: 1}, + {name: variable.SlowLogIsExplicitTxn, tp: mysql.TypeTiny, size: 1}, + {name: variable.SlowLogIsWriteCacheTable, tp: mysql.TypeTiny, size: 1}, + {name: variable.SlowLogPlanFromCache, tp: mysql.TypeTiny, size: 1}, + {name: variable.SlowLogPlanFromBinding, tp: mysql.TypeTiny, size: 1}, + {name: variable.SlowLogHasMoreResults, tp: mysql.TypeTiny, size: 1}, + {name: variable.SlowLogResourceGroup, tp: mysql.TypeVarchar, size: 64}, + {name: variable.SlowLogRRU, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogWRU, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogWaitRUDuration, tp: mysql.TypeDouble, size: 22}, + {name: variable.SlowLogPlan, tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, + {name: variable.SlowLogPlanDigest, tp: mysql.TypeVarchar, size: 128}, + {name: variable.SlowLogBinaryPlan, tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, + {name: variable.SlowLogPrevStmt, tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, + {name: variable.SlowLogQuerySQLStr, tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, +} + +// TableTiDBHotRegionsCols is TiDB hot region mem table columns. +var TableTiDBHotRegionsCols = []columnInfo{ + {name: "TABLE_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "INDEX_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "DB_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "INDEX_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "REGION_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "TYPE", tp: mysql.TypeVarchar, size: 64}, + {name: "MAX_HOT_DEGREE", tp: mysql.TypeLonglong, size: 21}, + {name: "REGION_COUNT", tp: mysql.TypeLonglong, size: 21}, + {name: "FLOW_BYTES", tp: mysql.TypeLonglong, size: 21}, +} + +// TableTiDBHotRegionsHistoryCols is TiDB hot region history mem table columns. 
+var TableTiDBHotRegionsHistoryCols = []columnInfo{ + {name: "UPDATE_TIME", tp: mysql.TypeTimestamp, size: 26, decimal: 6}, + {name: "DB_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "INDEX_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "INDEX_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "REGION_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "STORE_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "PEER_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "IS_LEARNER", tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, deflt: 0}, + {name: "IS_LEADER", tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, deflt: 0}, + {name: "TYPE", tp: mysql.TypeVarchar, size: 64}, + {name: "HOT_DEGREE", tp: mysql.TypeLonglong, size: 21}, + {name: "FLOW_BYTES", tp: mysql.TypeDouble, size: 22}, + {name: "KEY_RATE", tp: mysql.TypeDouble, size: 22}, + {name: "QUERY_RATE", tp: mysql.TypeDouble, size: 22}, +} + +// GetTableTiDBHotRegionsHistoryCols is to get TableTiDBHotRegionsHistoryCols. +// It is an optimization because Go does’t support const arrays. The solution is to use initialization functions. +// It is useful in the BCE optimization. +// https://go101.org/article/bounds-check-elimination.html +func GetTableTiDBHotRegionsHistoryCols() []columnInfo { + return TableTiDBHotRegionsHistoryCols +} + +// TableTiKVStoreStatusCols is TiDB kv store status columns. +var TableTiKVStoreStatusCols = []columnInfo{ + {name: "STORE_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "ADDRESS", tp: mysql.TypeVarchar, size: 64}, + {name: "STORE_STATE", tp: mysql.TypeLonglong, size: 21}, + {name: "STORE_STATE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "LABEL", tp: mysql.TypeJSON, size: 51}, + {name: "VERSION", tp: mysql.TypeVarchar, size: 64}, + {name: "CAPACITY", tp: mysql.TypeVarchar, size: 64}, + {name: "AVAILABLE", tp: mysql.TypeVarchar, size: 64}, + {name: "LEADER_COUNT", tp: mysql.TypeLonglong, size: 21}, + {name: "LEADER_WEIGHT", tp: mysql.TypeDouble, size: 22}, + {name: "LEADER_SCORE", tp: mysql.TypeDouble, size: 22}, + {name: "LEADER_SIZE", tp: mysql.TypeLonglong, size: 21}, + {name: "REGION_COUNT", tp: mysql.TypeLonglong, size: 21}, + {name: "REGION_WEIGHT", tp: mysql.TypeDouble, size: 22}, + {name: "REGION_SCORE", tp: mysql.TypeDouble, size: 22}, + {name: "REGION_SIZE", tp: mysql.TypeLonglong, size: 21}, + {name: "START_TS", tp: mysql.TypeDatetime}, + {name: "LAST_HEARTBEAT_TS", tp: mysql.TypeDatetime}, + {name: "UPTIME", tp: mysql.TypeVarchar, size: 64}, +} + +var tableAnalyzeStatusCols = []columnInfo{ + {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "PARTITION_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "JOB_INFO", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, + {name: "PROCESSED_ROWS", tp: mysql.TypeLonglong, size: 64, flag: mysql.UnsignedFlag}, + {name: "START_TIME", tp: mysql.TypeDatetime}, + {name: "END_TIME", tp: mysql.TypeDatetime}, + {name: "STATE", tp: mysql.TypeVarchar, size: 64}, + {name: "FAIL_REASON", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, + {name: "INSTANCE", tp: mysql.TypeVarchar, size: 512}, + {name: "PROCESS_ID", tp: mysql.TypeLonglong, size: 64, flag: mysql.UnsignedFlag}, + {name: "REMAINING_SECONDS", tp: mysql.TypeVarchar, size: 512}, + {name: "PROGRESS", tp: mysql.TypeDouble, size: 22, decimal: 6}, + {name: "ESTIMATED_TOTAL_ROWS", tp: mysql.TypeLonglong, 
size: 64, flag: mysql.UnsignedFlag}, +} + +// TableTiKVRegionStatusCols is TiKV region status mem table columns. +var TableTiKVRegionStatusCols = []columnInfo{ + {name: "REGION_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "START_KEY", tp: mysql.TypeBlob, size: types.UnspecifiedLength}, + {name: "END_KEY", tp: mysql.TypeBlob, size: types.UnspecifiedLength}, + {name: "TABLE_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "DB_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "IS_INDEX", tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, deflt: 0}, + {name: "INDEX_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "INDEX_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "IS_PARTITION", tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, deflt: 0}, + {name: "PARTITION_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "PARTITION_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "EPOCH_CONF_VER", tp: mysql.TypeLonglong, size: 21}, + {name: "EPOCH_VERSION", tp: mysql.TypeLonglong, size: 21}, + {name: "WRITTEN_BYTES", tp: mysql.TypeLonglong, size: 21}, + {name: "READ_BYTES", tp: mysql.TypeLonglong, size: 21}, + {name: "APPROXIMATE_SIZE", tp: mysql.TypeLonglong, size: 21}, + {name: "APPROXIMATE_KEYS", tp: mysql.TypeLonglong, size: 21}, + {name: "REPLICATIONSTATUS_STATE", tp: mysql.TypeVarchar, size: 64}, + {name: "REPLICATIONSTATUS_STATEID", tp: mysql.TypeLonglong, size: 21}, +} + +// TableTiKVRegionPeersCols is TiKV region peers mem table columns. +var TableTiKVRegionPeersCols = []columnInfo{ + {name: "REGION_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "PEER_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "STORE_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "IS_LEARNER", tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, deflt: 0}, + {name: "IS_LEADER", tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, deflt: 0}, + {name: "STATUS", tp: mysql.TypeVarchar, size: 10, deflt: 0}, + {name: "DOWN_SECONDS", tp: mysql.TypeLonglong, size: 21, deflt: 0}, +} + +// GetTableTiKVRegionPeersCols is to get TableTiKVRegionPeersCols. +// It is an optimization because Go does’t support const arrays. The solution is to use initialization functions. +// It is useful in the BCE optimization. 
+// https://go101.org/article/bounds-check-elimination.html +func GetTableTiKVRegionPeersCols() []columnInfo { + return TableTiKVRegionPeersCols +} + +var tableTiDBServersInfoCols = []columnInfo{ + {name: "DDL_ID", tp: mysql.TypeVarchar, size: 64}, + {name: "IP", tp: mysql.TypeVarchar, size: 64}, + {name: "PORT", tp: mysql.TypeLonglong, size: 21}, + {name: "STATUS_PORT", tp: mysql.TypeLonglong, size: 21}, + {name: "LEASE", tp: mysql.TypeVarchar, size: 64}, + {name: "VERSION", tp: mysql.TypeVarchar, size: 64}, + {name: "GIT_HASH", tp: mysql.TypeVarchar, size: 64}, + {name: "BINLOG_STATUS", tp: mysql.TypeVarchar, size: 64}, + {name: "LABELS", tp: mysql.TypeVarchar, size: 128}, +} + +var tableClusterConfigCols = []columnInfo{ + {name: "TYPE", tp: mysql.TypeVarchar, size: 64}, + {name: "INSTANCE", tp: mysql.TypeVarchar, size: 64}, + {name: "KEY", tp: mysql.TypeVarchar, size: 256}, + {name: "VALUE", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, +} + +var tableClusterLogCols = []columnInfo{ + {name: "TIME", tp: mysql.TypeVarchar, size: 32}, + {name: "TYPE", tp: mysql.TypeVarchar, size: 64}, + {name: "INSTANCE", tp: mysql.TypeVarchar, size: 64}, + {name: "LEVEL", tp: mysql.TypeVarchar, size: 8}, + {name: "MESSAGE", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, +} + +var tableClusterLoadCols = []columnInfo{ + {name: "TYPE", tp: mysql.TypeVarchar, size: 64}, + {name: "INSTANCE", tp: mysql.TypeVarchar, size: 64}, + {name: "DEVICE_TYPE", tp: mysql.TypeVarchar, size: 64}, + {name: "DEVICE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "NAME", tp: mysql.TypeVarchar, size: 256}, + {name: "VALUE", tp: mysql.TypeVarchar, size: 128}, +} + +var tableClusterHardwareCols = []columnInfo{ + {name: "TYPE", tp: mysql.TypeVarchar, size: 64}, + {name: "INSTANCE", tp: mysql.TypeVarchar, size: 64}, + {name: "DEVICE_TYPE", tp: mysql.TypeVarchar, size: 64}, + {name: "DEVICE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "NAME", tp: mysql.TypeVarchar, size: 256}, + {name: "VALUE", tp: mysql.TypeVarchar, size: 128}, +} + +var tableClusterSystemInfoCols = []columnInfo{ + {name: "TYPE", tp: mysql.TypeVarchar, size: 64}, + {name: "INSTANCE", tp: mysql.TypeVarchar, size: 64}, + {name: "SYSTEM_TYPE", tp: mysql.TypeVarchar, size: 64}, + {name: "SYSTEM_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "NAME", tp: mysql.TypeVarchar, size: 256}, + {name: "VALUE", tp: mysql.TypeVarchar, size: 128}, +} + +var filesCols = []columnInfo{ + {name: "FILE_ID", tp: mysql.TypeLonglong, size: 4}, + {name: "FILE_NAME", tp: mysql.TypeVarchar, size: 4000}, + {name: "FILE_TYPE", tp: mysql.TypeVarchar, size: 20}, + {name: "TABLESPACE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "LOGFILE_GROUP_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "LOGFILE_GROUP_NUMBER", tp: mysql.TypeLonglong, size: 32}, + {name: "ENGINE", tp: mysql.TypeVarchar, size: 64}, + {name: "FULLTEXT_KEYS", tp: mysql.TypeVarchar, size: 64}, + {name: "DELETED_ROWS", tp: mysql.TypeLonglong, size: 4}, + {name: "UPDATE_COUNT", tp: mysql.TypeLonglong, size: 4}, + {name: "FREE_EXTENTS", tp: mysql.TypeLonglong, size: 4}, + {name: "TOTAL_EXTENTS", tp: mysql.TypeLonglong, size: 4}, + {name: "EXTENT_SIZE", tp: mysql.TypeLonglong, size: 4}, + {name: "INITIAL_SIZE", tp: mysql.TypeLonglong, size: 21}, + {name: "MAXIMUM_SIZE", tp: mysql.TypeLonglong, size: 21}, + {name: 
"AUTOEXTEND_SIZE", tp: mysql.TypeLonglong, size: 21}, + {name: "CREATION_TIME", tp: mysql.TypeDatetime, size: -1}, + {name: "LAST_UPDATE_TIME", tp: mysql.TypeDatetime, size: -1}, + {name: "LAST_ACCESS_TIME", tp: mysql.TypeDatetime, size: -1}, + {name: "RECOVER_TIME", tp: mysql.TypeLonglong, size: 4}, + {name: "TRANSACTION_COUNTER", tp: mysql.TypeLonglong, size: 4}, + {name: "VERSION", tp: mysql.TypeLonglong, size: 21}, + {name: "ROW_FORMAT", tp: mysql.TypeVarchar, size: 10}, + {name: "TABLE_ROWS", tp: mysql.TypeLonglong, size: 21}, + {name: "AVG_ROW_LENGTH", tp: mysql.TypeLonglong, size: 21}, + {name: "DATA_LENGTH", tp: mysql.TypeLonglong, size: 21}, + {name: "MAX_DATA_LENGTH", tp: mysql.TypeLonglong, size: 21}, + {name: "INDEX_LENGTH", tp: mysql.TypeLonglong, size: 21}, + {name: "DATA_FREE", tp: mysql.TypeLonglong, size: 21}, + {name: "CREATE_TIME", tp: mysql.TypeDatetime, size: -1}, + {name: "UPDATE_TIME", tp: mysql.TypeDatetime, size: -1}, + {name: "CHECK_TIME", tp: mysql.TypeDatetime, size: -1}, + {name: "CHECKSUM", tp: mysql.TypeLonglong, size: 21}, + {name: "STATUS", tp: mysql.TypeVarchar, size: 20}, + {name: "EXTRA", tp: mysql.TypeVarchar, size: 255}, +} + +var tableClusterInfoCols = []columnInfo{ + {name: "TYPE", tp: mysql.TypeVarchar, size: 64}, + {name: "INSTANCE", tp: mysql.TypeVarchar, size: 64}, + {name: "STATUS_ADDRESS", tp: mysql.TypeVarchar, size: 64}, + {name: "VERSION", tp: mysql.TypeVarchar, size: 64}, + {name: "GIT_HASH", tp: mysql.TypeVarchar, size: 64}, + {name: "START_TIME", tp: mysql.TypeDatetime, size: 19}, + {name: "UPTIME", tp: mysql.TypeVarchar, size: 32}, + {name: "SERVER_ID", tp: mysql.TypeLonglong, size: 21, comment: "invalid if the configuration item `enable-global-kill` is set to FALSE"}, +} + +var tableTableTiFlashReplicaCols = []columnInfo{ + {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "REPLICA_COUNT", tp: mysql.TypeLonglong, size: 64}, + {name: "LOCATION_LABELS", tp: mysql.TypeVarchar, size: 64}, + {name: "AVAILABLE", tp: mysql.TypeTiny, size: 1}, + {name: "PROGRESS", tp: mysql.TypeDouble, size: 22}, +} + +var tableInspectionResultCols = []columnInfo{ + {name: "RULE", tp: mysql.TypeVarchar, size: 64}, + {name: "ITEM", tp: mysql.TypeVarchar, size: 64}, + {name: "TYPE", tp: mysql.TypeVarchar, size: 64}, + {name: "INSTANCE", tp: mysql.TypeVarchar, size: 64}, + {name: "STATUS_ADDRESS", tp: mysql.TypeVarchar, size: 64}, + {name: "VALUE", tp: mysql.TypeVarchar, size: 64}, + {name: "REFERENCE", tp: mysql.TypeVarchar, size: 64}, + {name: "SEVERITY", tp: mysql.TypeVarchar, size: 64}, + {name: "DETAILS", tp: mysql.TypeVarchar, size: 256}, +} + +var tableInspectionSummaryCols = []columnInfo{ + {name: "RULE", tp: mysql.TypeVarchar, size: 64}, + {name: "INSTANCE", tp: mysql.TypeVarchar, size: 64}, + {name: "METRICS_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "LABEL", tp: mysql.TypeVarchar, size: 64}, + {name: "QUANTILE", tp: mysql.TypeDouble, size: 22}, + {name: "AVG_VALUE", tp: mysql.TypeDouble, size: 22, decimal: 6}, + {name: "MIN_VALUE", tp: mysql.TypeDouble, size: 22, decimal: 6}, + {name: "MAX_VALUE", tp: mysql.TypeDouble, size: 22, decimal: 6}, + {name: "COMMENT", tp: mysql.TypeVarchar, size: 256}, +} + +var tableInspectionRulesCols = []columnInfo{ + {name: "NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "TYPE", tp: mysql.TypeVarchar, size: 64}, + {name: "COMMENT", tp: mysql.TypeVarchar, size: 256}, +} + +var 
tableMetricTablesCols = []columnInfo{ + {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "PROMQL", tp: mysql.TypeVarchar, size: 64}, + {name: "LABELS", tp: mysql.TypeVarchar, size: 64}, + {name: "QUANTILE", tp: mysql.TypeDouble, size: 22}, + {name: "COMMENT", tp: mysql.TypeVarchar, size: 256}, +} + +var tableMetricSummaryCols = []columnInfo{ + {name: "METRICS_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "QUANTILE", tp: mysql.TypeDouble, size: 22}, + {name: "SUM_VALUE", tp: mysql.TypeDouble, size: 22, decimal: 6}, + {name: "AVG_VALUE", tp: mysql.TypeDouble, size: 22, decimal: 6}, + {name: "MIN_VALUE", tp: mysql.TypeDouble, size: 22, decimal: 6}, + {name: "MAX_VALUE", tp: mysql.TypeDouble, size: 22, decimal: 6}, + {name: "COMMENT", tp: mysql.TypeVarchar, size: 256}, +} + +var tableMetricSummaryByLabelCols = []columnInfo{ + {name: "INSTANCE", tp: mysql.TypeVarchar, size: 64}, + {name: "METRICS_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "LABEL", tp: mysql.TypeVarchar, size: 64}, + {name: "QUANTILE", tp: mysql.TypeDouble, size: 22}, + {name: "SUM_VALUE", tp: mysql.TypeDouble, size: 22, decimal: 6}, + {name: "AVG_VALUE", tp: mysql.TypeDouble, size: 22, decimal: 6}, + {name: "MIN_VALUE", tp: mysql.TypeDouble, size: 22, decimal: 6}, + {name: "MAX_VALUE", tp: mysql.TypeDouble, size: 22, decimal: 6}, + {name: "COMMENT", tp: mysql.TypeVarchar, size: 256}, +} + +var tableDDLJobsCols = []columnInfo{ + {name: "JOB_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "DB_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "JOB_TYPE", tp: mysql.TypeVarchar, size: 64}, + {name: "SCHEMA_STATE", tp: mysql.TypeVarchar, size: 64}, + {name: "SCHEMA_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "TABLE_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "ROW_COUNT", tp: mysql.TypeLonglong, size: 21}, + {name: "CREATE_TIME", tp: mysql.TypeDatetime, size: 26, decimal: 6}, + {name: "START_TIME", tp: mysql.TypeDatetime, size: 26, decimal: 6}, + {name: "END_TIME", tp: mysql.TypeDatetime, size: 26, decimal: 6}, + {name: "STATE", tp: mysql.TypeVarchar, size: 64}, + {name: "QUERY", tp: mysql.TypeBlob, size: types.UnspecifiedLength}, +} + +var tableSequencesCols = []columnInfo{ + {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag}, + {name: "SEQUENCE_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "SEQUENCE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "CACHE", tp: mysql.TypeTiny, flag: mysql.NotNullFlag}, + {name: "CACHE_VALUE", tp: mysql.TypeLonglong, size: 21}, + {name: "CYCLE", tp: mysql.TypeTiny, flag: mysql.NotNullFlag}, + {name: "INCREMENT", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag}, + {name: "MAX_VALUE", tp: mysql.TypeLonglong, size: 21}, + {name: "MIN_VALUE", tp: mysql.TypeLonglong, size: 21}, + {name: "START", tp: mysql.TypeLonglong, size: 21}, + {name: "COMMENT", tp: mysql.TypeVarchar, size: 64}, +} + +var tableStatementsSummaryCols = []columnInfo{ + {name: stmtsummary.SummaryBeginTimeStr, tp: mysql.TypeTimestamp, size: 26, flag: mysql.NotNullFlag, comment: "Begin time of this summary"}, + {name: stmtsummary.SummaryEndTimeStr, tp: mysql.TypeTimestamp, size: 26, flag: mysql.NotNullFlag, comment: "End time of this summary"}, + {name: stmtsummary.StmtTypeStr, tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag, comment: "Statement type"}, + {name: stmtsummary.SchemaNameStr, tp: mysql.TypeVarchar, size: 64, comment: "Current 
schema"}, + {name: stmtsummary.DigestStr, tp: mysql.TypeVarchar, size: 64}, + {name: stmtsummary.DigestTextStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, flag: mysql.NotNullFlag, comment: "Normalized statement"}, + {name: stmtsummary.TableNamesStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "Involved tables"}, + {name: stmtsummary.IndexNamesStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "Used indices"}, + {name: stmtsummary.SampleUserStr, tp: mysql.TypeVarchar, size: 64, comment: "Sampled user who executed these statements"}, + {name: stmtsummary.ExecCountStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Count of executions"}, + {name: stmtsummary.SumErrorsStr, tp: mysql.TypeLong, size: 11, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Sum of errors"}, + {name: stmtsummary.SumWarningsStr, tp: mysql.TypeLong, size: 11, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Sum of warnings"}, + {name: stmtsummary.SumLatencyStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Sum latency of these statements"}, + {name: stmtsummary.MaxLatencyStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max latency of these statements"}, + {name: stmtsummary.MinLatencyStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Min latency of these statements"}, + {name: stmtsummary.AvgLatencyStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average latency of these statements"}, + {name: stmtsummary.AvgParseLatencyStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average latency of parsing"}, + {name: stmtsummary.MaxParseLatencyStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max latency of parsing"}, + {name: stmtsummary.AvgCompileLatencyStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average latency of compiling"}, + {name: stmtsummary.MaxCompileLatencyStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max latency of compiling"}, + {name: stmtsummary.SumCopTaskNumStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Total number of CopTasks"}, + {name: stmtsummary.MaxCopProcessTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max processing time of CopTasks"}, + {name: stmtsummary.MaxCopProcessAddressStr, tp: mysql.TypeVarchar, size: 256, comment: "Address of the CopTask with max processing time"}, + {name: stmtsummary.MaxCopWaitTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max waiting time of CopTasks"}, + {name: stmtsummary.MaxCopWaitAddressStr, tp: mysql.TypeVarchar, size: 256, comment: "Address of the CopTask with max waiting time"}, + {name: stmtsummary.AvgProcessTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average processing time in TiKV"}, + {name: stmtsummary.MaxProcessTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max processing time in TiKV"}, + {name: stmtsummary.AvgWaitTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average waiting time in TiKV"}, + {name: 
stmtsummary.MaxWaitTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max waiting time in TiKV"}, + {name: stmtsummary.AvgBackoffTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average waiting time before retry"}, + {name: stmtsummary.MaxBackoffTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max waiting time before retry"}, + {name: stmtsummary.AvgTotalKeysStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average number of scanned keys"}, + {name: stmtsummary.MaxTotalKeysStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max number of scanned keys"}, + {name: stmtsummary.AvgProcessedKeysStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average number of processed keys"}, + {name: stmtsummary.MaxProcessedKeysStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max number of processed keys"}, + {name: stmtsummary.AvgRocksdbDeleteSkippedCountStr, tp: mysql.TypeDouble, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average number of rocksdb delete skipped count"}, + {name: stmtsummary.MaxRocksdbDeleteSkippedCountStr, tp: mysql.TypeLong, size: 11, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max number of rocksdb delete skipped count"}, + {name: stmtsummary.AvgRocksdbKeySkippedCountStr, tp: mysql.TypeDouble, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average number of rocksdb key skipped count"}, + {name: stmtsummary.MaxRocksdbKeySkippedCountStr, tp: mysql.TypeLong, size: 11, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max number of rocksdb key skipped count"}, + {name: stmtsummary.AvgRocksdbBlockCacheHitCountStr, tp: mysql.TypeDouble, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average number of rocksdb block cache hit count"}, + {name: stmtsummary.MaxRocksdbBlockCacheHitCountStr, tp: mysql.TypeLong, size: 11, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max number of rocksdb block cache hit count"}, + {name: stmtsummary.AvgRocksdbBlockReadCountStr, tp: mysql.TypeDouble, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average number of rocksdb block read count"}, + {name: stmtsummary.MaxRocksdbBlockReadCountStr, tp: mysql.TypeLong, size: 11, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max number of rocksdb block read count"}, + {name: stmtsummary.AvgRocksdbBlockReadByteStr, tp: mysql.TypeDouble, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average number of rocksdb block read byte"}, + {name: stmtsummary.MaxRocksdbBlockReadByteStr, tp: mysql.TypeLong, size: 11, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max number of rocksdb block read byte"}, + {name: stmtsummary.AvgPrewriteTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average time of prewrite phase"}, + {name: stmtsummary.MaxPrewriteTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max time of prewrite phase"}, + {name: stmtsummary.AvgCommitTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average time of commit phase"}, + {name: stmtsummary.MaxCommitTimeStr, tp: mysql.TypeLonglong, size: 20, flag: 
mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max time of commit phase"}, + {name: stmtsummary.AvgGetCommitTsTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average time of getting commit_ts"}, + {name: stmtsummary.MaxGetCommitTsTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max time of getting commit_ts"}, + {name: stmtsummary.AvgCommitBackoffTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average time before retry during commit phase"}, + {name: stmtsummary.MaxCommitBackoffTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max time before retry during commit phase"}, + {name: stmtsummary.AvgResolveLockTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average time for resolving locks"}, + {name: stmtsummary.MaxResolveLockTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max time for resolving locks"}, + {name: stmtsummary.AvgLocalLatchWaitTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average waiting time of local transaction"}, + {name: stmtsummary.MaxLocalLatchWaitTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max waiting time of local transaction"}, + {name: stmtsummary.AvgWriteKeysStr, tp: mysql.TypeDouble, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average count of written keys"}, + {name: stmtsummary.MaxWriteKeysStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max count of written keys"}, + {name: stmtsummary.AvgWriteSizeStr, tp: mysql.TypeDouble, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average amount of written bytes"}, + {name: stmtsummary.MaxWriteSizeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max amount of written bytes"}, + {name: stmtsummary.AvgPrewriteRegionsStr, tp: mysql.TypeDouble, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average number of involved regions in prewrite phase"}, + {name: stmtsummary.MaxPrewriteRegionsStr, tp: mysql.TypeLong, size: 11, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max number of involved regions in prewrite phase"}, + {name: stmtsummary.AvgTxnRetryStr, tp: mysql.TypeDouble, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average number of transaction retries"}, + {name: stmtsummary.MaxTxnRetryStr, tp: mysql.TypeLong, size: 11, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max number of transaction retries"}, + {name: stmtsummary.SumExecRetryStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Sum number of execution retries in pessimistic transactions"}, + {name: stmtsummary.SumExecRetryTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Sum time of execution retries in pessimistic transactions"}, + {name: stmtsummary.SumBackoffTimesStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Sum of retries"}, + {name: stmtsummary.BackoffTypesStr, tp: mysql.TypeVarchar, size: 1024, comment: "Types of errors and the number of retries for each type"}, + {name: stmtsummary.AvgMemStr, tp: mysql.TypeLonglong, size: 20, 
flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average memory(byte) used"}, + {name: stmtsummary.MaxMemStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max memory(byte) used"}, + {name: stmtsummary.AvgDiskStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average disk space(byte) used"}, + {name: stmtsummary.MaxDiskStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max disk space(byte) used"}, + {name: stmtsummary.AvgKvTimeStr, tp: mysql.TypeLonglong, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average time of TiKV used"}, + {name: stmtsummary.AvgPdTimeStr, tp: mysql.TypeLonglong, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average time of PD used"}, + {name: stmtsummary.AvgBackoffTotalTimeStr, tp: mysql.TypeLonglong, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average time of Backoff used"}, + {name: stmtsummary.AvgWriteSQLRespTimeStr, tp: mysql.TypeLonglong, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average time of write sql resp used"}, + {name: stmtsummary.MaxResultRowsStr, tp: mysql.TypeLonglong, size: 22, flag: mysql.NotNullFlag, comment: "Max count of sql result rows"}, + {name: stmtsummary.MinResultRowsStr, tp: mysql.TypeLonglong, size: 22, flag: mysql.NotNullFlag, comment: "Min count of sql result rows"}, + {name: stmtsummary.AvgResultRowsStr, tp: mysql.TypeLonglong, size: 22, flag: mysql.NotNullFlag, comment: "Average count of sql result rows"}, + {name: stmtsummary.PreparedStr, tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, comment: "Whether prepared"}, + {name: stmtsummary.AvgAffectedRowsStr, tp: mysql.TypeDouble, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average number of rows affected"}, + {name: stmtsummary.FirstSeenStr, tp: mysql.TypeTimestamp, size: 26, flag: mysql.NotNullFlag, comment: "The time these statements are seen for the first time"}, + {name: stmtsummary.LastSeenStr, tp: mysql.TypeTimestamp, size: 26, flag: mysql.NotNullFlag, comment: "The time these statements are seen for the last time"}, + {name: stmtsummary.PlanInCacheStr, tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, comment: "Whether the last statement hit plan cache"}, + {name: stmtsummary.PlanCacheHitsStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag, comment: "The number of times these statements hit plan cache"}, + {name: stmtsummary.PlanInBindingStr, tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, comment: "Whether the last statement is matched with the hints in the binding"}, + {name: stmtsummary.QuerySampleTextStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "Sampled original statement"}, + {name: stmtsummary.PrevSampleTextStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "The previous statement before commit"}, + {name: stmtsummary.PlanDigestStr, tp: mysql.TypeVarchar, size: 64, comment: "Digest of its execution plan"}, + {name: stmtsummary.PlanStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "Sampled execution plan"}, + {name: stmtsummary.BinaryPlan, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "Sampled binary plan"}, + {name: stmtsummary.Charset, tp: mysql.TypeVarchar, size: 64, comment: "Sampled charset"}, + {name: stmtsummary.Collation, tp: mysql.TypeVarchar, size: 64, comment: "Sampled collation"}, + {name: stmtsummary.PlanHint, tp: 
mysql.TypeVarchar, size: 64, comment: "Sampled plan hint"}, + {name: stmtsummary.MaxRequestUnitReadStr, tp: mysql.TypeDouble, flag: mysql.NotNullFlag | mysql.UnsignedFlag, size: 22, comment: "Max read request-unit cost of these statements"}, + {name: stmtsummary.AvgRequestUnitReadStr, tp: mysql.TypeDouble, flag: mysql.NotNullFlag | mysql.UnsignedFlag, size: 22, comment: "Average read request-unit cost of these statements"}, + {name: stmtsummary.MaxRequestUnitWriteStr, tp: mysql.TypeDouble, flag: mysql.NotNullFlag | mysql.UnsignedFlag, size: 22, comment: "Max write request-unit cost of these statements"}, + {name: stmtsummary.AvgRequestUnitWriteStr, tp: mysql.TypeDouble, flag: mysql.NotNullFlag | mysql.UnsignedFlag, size: 22, comment: "Average write request-unit cost of these statements"}, + {name: stmtsummary.MaxQueuedRcTimeStr, tp: mysql.TypeLonglong, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max time of waiting for available request-units"}, + {name: stmtsummary.AvgQueuedRcTimeStr, tp: mysql.TypeLonglong, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max time of waiting for available request-units"}, + {name: stmtsummary.ResourceGroupName, tp: mysql.TypeVarchar, size: 64, comment: "Bind resource group name"}, + {name: stmtsummary.PlanCacheUnqualifiedStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag, comment: "The number of times that these statements are not supported by the plan cache"}, + {name: stmtsummary.PlanCacheUnqualifiedLastReasonStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "The last reason why the statement is not supported by the plan cache"}, +} + +var tableStorageStatsCols = []columnInfo{ + {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "PEER_COUNT", tp: mysql.TypeLonglong, size: 21}, + {name: "REGION_COUNT", tp: mysql.TypeLonglong, size: 21, comment: "The region count of single replica of the table"}, + {name: "EMPTY_REGION_COUNT", tp: mysql.TypeLonglong, size: 21, comment: "The region count of single replica of the table"}, + {name: "TABLE_SIZE", tp: mysql.TypeLonglong, size: 64, comment: "The disk usage(MB) of single replica of the table, if the table size is empty or less than 1MB, it would show 1MB "}, + {name: "TABLE_KEYS", tp: mysql.TypeLonglong, size: 64, comment: "The count of keys of single replica of the table"}, +} + +var tableTableTiFlashTablesCols = []columnInfo{ + {name: "DATABASE", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE", tp: mysql.TypeVarchar, size: 64}, + {name: "TIDB_DATABASE", tp: mysql.TypeVarchar, size: 64}, + {name: "TIDB_TABLE", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "IS_TOMBSTONE", tp: mysql.TypeLonglong, size: 64}, + {name: "SEGMENT_COUNT", tp: mysql.TypeLonglong, size: 64}, + {name: "TOTAL_ROWS", tp: mysql.TypeLonglong, size: 64}, + {name: "TOTAL_SIZE", tp: mysql.TypeLonglong, size: 64}, + {name: "TOTAL_DELETE_RANGES", tp: mysql.TypeLonglong, size: 64}, + {name: "DELTA_RATE_ROWS", tp: mysql.TypeDouble, size: 64}, + {name: "DELTA_RATE_SEGMENTS", tp: mysql.TypeDouble, size: 64}, + {name: "DELTA_PLACED_RATE", tp: mysql.TypeDouble, size: 64}, + {name: "DELTA_CACHE_SIZE", tp: mysql.TypeLonglong, size: 64}, + {name: "DELTA_CACHE_RATE", tp: mysql.TypeDouble, size: 64}, + {name: "DELTA_CACHE_WASTED_RATE", tp: mysql.TypeDouble, size: 64}, + {name: "DELTA_INDEX_SIZE", tp: 
mysql.TypeLonglong, size: 64}, + {name: "AVG_SEGMENT_ROWS", tp: mysql.TypeDouble, size: 64}, + {name: "AVG_SEGMENT_SIZE", tp: mysql.TypeDouble, size: 64}, + {name: "DELTA_COUNT", tp: mysql.TypeLonglong, size: 64}, + {name: "TOTAL_DELTA_ROWS", tp: mysql.TypeLonglong, size: 64}, + {name: "TOTAL_DELTA_SIZE", tp: mysql.TypeLonglong, size: 64}, + {name: "AVG_DELTA_ROWS", tp: mysql.TypeDouble, size: 64}, + {name: "AVG_DELTA_SIZE", tp: mysql.TypeDouble, size: 64}, + {name: "AVG_DELTA_DELETE_RANGES", tp: mysql.TypeDouble, size: 64}, + {name: "STABLE_COUNT", tp: mysql.TypeLonglong, size: 64}, + {name: "TOTAL_STABLE_ROWS", tp: mysql.TypeLonglong, size: 64}, + {name: "TOTAL_STABLE_SIZE", tp: mysql.TypeLonglong, size: 64}, + {name: "TOTAL_STABLE_SIZE_ON_DISK", tp: mysql.TypeLonglong, size: 64}, + {name: "AVG_STABLE_ROWS", tp: mysql.TypeDouble, size: 64}, + {name: "AVG_STABLE_SIZE", tp: mysql.TypeDouble, size: 64}, + {name: "TOTAL_PACK_COUNT_IN_DELTA", tp: mysql.TypeLonglong, size: 64}, + {name: "MAX_PACK_COUNT_IN_DELTA", tp: mysql.TypeLonglong, size: 64}, + {name: "AVG_PACK_COUNT_IN_DELTA", tp: mysql.TypeDouble, size: 64}, + {name: "AVG_PACK_ROWS_IN_DELTA", tp: mysql.TypeDouble, size: 64}, + {name: "AVG_PACK_SIZE_IN_DELTA", tp: mysql.TypeDouble, size: 64}, + {name: "TOTAL_PACK_COUNT_IN_STABLE", tp: mysql.TypeLonglong, size: 64}, + {name: "AVG_PACK_COUNT_IN_STABLE", tp: mysql.TypeDouble, size: 64}, + {name: "AVG_PACK_ROWS_IN_STABLE", tp: mysql.TypeDouble, size: 64}, + {name: "AVG_PACK_SIZE_IN_STABLE", tp: mysql.TypeDouble, size: 64}, + {name: "STORAGE_STABLE_NUM_SNAPSHOTS", tp: mysql.TypeLonglong, size: 64}, + {name: "STORAGE_STABLE_OLDEST_SNAPSHOT_LIFETIME", tp: mysql.TypeDouble, size: 64}, + {name: "STORAGE_STABLE_OLDEST_SNAPSHOT_THREAD_ID", tp: mysql.TypeLonglong, size: 64}, + {name: "STORAGE_STABLE_OLDEST_SNAPSHOT_TRACING_ID", tp: mysql.TypeVarchar, size: 128}, + {name: "STORAGE_DELTA_NUM_SNAPSHOTS", tp: mysql.TypeLonglong, size: 64}, + {name: "STORAGE_DELTA_OLDEST_SNAPSHOT_LIFETIME", tp: mysql.TypeDouble, size: 64}, + {name: "STORAGE_DELTA_OLDEST_SNAPSHOT_THREAD_ID", tp: mysql.TypeLonglong, size: 64}, + {name: "STORAGE_DELTA_OLDEST_SNAPSHOT_TRACING_ID", tp: mysql.TypeVarchar, size: 128}, + {name: "STORAGE_META_NUM_SNAPSHOTS", tp: mysql.TypeLonglong, size: 64}, + {name: "STORAGE_META_OLDEST_SNAPSHOT_LIFETIME", tp: mysql.TypeDouble, size: 64}, + {name: "STORAGE_META_OLDEST_SNAPSHOT_THREAD_ID", tp: mysql.TypeLonglong, size: 64}, + {name: "STORAGE_META_OLDEST_SNAPSHOT_TRACING_ID", tp: mysql.TypeVarchar, size: 128}, + {name: "BACKGROUND_TASKS_LENGTH", tp: mysql.TypeLonglong, size: 64}, + {name: "TIFLASH_INSTANCE", tp: mysql.TypeVarchar, size: 64}, +} + +var tableTableTiFlashSegmentsCols = []columnInfo{ + {name: "DATABASE", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE", tp: mysql.TypeVarchar, size: 64}, + {name: "TIDB_DATABASE", tp: mysql.TypeVarchar, size: 64}, + {name: "TIDB_TABLE", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_ID", tp: mysql.TypeLonglong, size: 21}, + {name: "IS_TOMBSTONE", tp: mysql.TypeLonglong, size: 64}, + {name: "SEGMENT_ID", tp: mysql.TypeLonglong, size: 64}, + {name: "RANGE", tp: mysql.TypeVarchar, size: 64}, + {name: "EPOCH", tp: mysql.TypeLonglong, size: 64}, + {name: "ROWS", tp: mysql.TypeLonglong, size: 64}, + {name: "SIZE", tp: mysql.TypeLonglong, size: 64}, + {name: "DELTA_RATE", tp: mysql.TypeDouble, size: 64}, + {name: "DELTA_MEMTABLE_ROWS", tp: mysql.TypeLonglong, size: 64}, + {name: "DELTA_MEMTABLE_SIZE", tp: mysql.TypeLonglong, size: 64}, + {name: 
"DELTA_MEMTABLE_COLUMN_FILES", tp: mysql.TypeLonglong, size: 64}, + {name: "DELTA_MEMTABLE_DELETE_RANGES", tp: mysql.TypeLonglong, size: 64}, + {name: "DELTA_PERSISTED_PAGE_ID", tp: mysql.TypeLonglong, size: 64}, + {name: "DELTA_PERSISTED_ROWS", tp: mysql.TypeLonglong, size: 64}, + {name: "DELTA_PERSISTED_SIZE", tp: mysql.TypeLonglong, size: 64}, + {name: "DELTA_PERSISTED_COLUMN_FILES", tp: mysql.TypeLonglong, size: 64}, + {name: "DELTA_PERSISTED_DELETE_RANGES", tp: mysql.TypeLonglong, size: 64}, + {name: "DELTA_CACHE_SIZE", tp: mysql.TypeLonglong, size: 64}, + {name: "DELTA_INDEX_SIZE", tp: mysql.TypeLonglong, size: 64}, + {name: "STABLE_PAGE_ID", tp: mysql.TypeLonglong, size: 64}, + {name: "STABLE_ROWS", tp: mysql.TypeLonglong, size: 64}, + {name: "STABLE_SIZE", tp: mysql.TypeLonglong, size: 64}, + {name: "STABLE_DMFILES", tp: mysql.TypeLonglong, size: 64}, + {name: "STABLE_DMFILES_ID_0", tp: mysql.TypeLonglong, size: 64}, + {name: "STABLE_DMFILES_ROWS", tp: mysql.TypeLonglong, size: 64}, + {name: "STABLE_DMFILES_SIZE", tp: mysql.TypeLonglong, size: 64}, + {name: "STABLE_DMFILES_SIZE_ON_DISK", tp: mysql.TypeLonglong, size: 64}, + {name: "STABLE_DMFILES_PACKS", tp: mysql.TypeLonglong, size: 64}, + {name: "TIFLASH_INSTANCE", tp: mysql.TypeVarchar, size: 64}, +} + +var tableClientErrorsSummaryGlobalCols = []columnInfo{ + {name: "ERROR_NUMBER", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag}, + {name: "ERROR_MESSAGE", tp: mysql.TypeVarchar, size: 1024, flag: mysql.NotNullFlag}, + {name: "ERROR_COUNT", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag}, + {name: "WARNING_COUNT", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag}, + {name: "FIRST_SEEN", tp: mysql.TypeTimestamp, size: 26}, + {name: "LAST_SEEN", tp: mysql.TypeTimestamp, size: 26}, +} + +var tableClientErrorsSummaryByUserCols = []columnInfo{ + {name: "USER", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "ERROR_NUMBER", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag}, + {name: "ERROR_MESSAGE", tp: mysql.TypeVarchar, size: 1024, flag: mysql.NotNullFlag}, + {name: "ERROR_COUNT", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag}, + {name: "WARNING_COUNT", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag}, + {name: "FIRST_SEEN", tp: mysql.TypeTimestamp, size: 26}, + {name: "LAST_SEEN", tp: mysql.TypeTimestamp, size: 26}, +} + +var tableClientErrorsSummaryByHostCols = []columnInfo{ + {name: "HOST", tp: mysql.TypeVarchar, size: 255, flag: mysql.NotNullFlag}, + {name: "ERROR_NUMBER", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag}, + {name: "ERROR_MESSAGE", tp: mysql.TypeVarchar, size: 1024, flag: mysql.NotNullFlag}, + {name: "ERROR_COUNT", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag}, + {name: "WARNING_COUNT", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag}, + {name: "FIRST_SEEN", tp: mysql.TypeTimestamp, size: 26}, + {name: "LAST_SEEN", tp: mysql.TypeTimestamp, size: 26}, +} + +var tableTiDBTrxCols = []columnInfo{ + {name: txninfo.IDStr, tp: mysql.TypeLonglong, size: 21, flag: mysql.PriKeyFlag | mysql.NotNullFlag | mysql.UnsignedFlag}, + {name: txninfo.StartTimeStr, tp: mysql.TypeTimestamp, decimal: 6, size: 26, comment: "Start time of the transaction"}, + {name: txninfo.CurrentSQLDigestStr, tp: mysql.TypeVarchar, size: 64, comment: "Digest of the sql the transaction are currently running"}, + {name: txninfo.CurrentSQLDigestTextStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "The normalized sql the 
transaction are currently running"}, + {name: txninfo.StateStr, tp: mysql.TypeEnum, size: 16, enumElems: txninfo.TxnRunningStateStrs, comment: "Current running state of the transaction"}, + {name: txninfo.WaitingStartTimeStr, tp: mysql.TypeTimestamp, decimal: 6, size: 26, comment: "Current lock waiting's start time"}, + {name: txninfo.MemBufferKeysStr, tp: mysql.TypeLonglong, size: 64, comment: "How many entries are in MemDB"}, + {name: txninfo.MemBufferBytesStr, tp: mysql.TypeLonglong, size: 64, comment: "MemDB used memory"}, + {name: txninfo.SessionIDStr, tp: mysql.TypeLonglong, size: 21, flag: mysql.UnsignedFlag, comment: "Which session this transaction belongs to"}, + {name: txninfo.UserStr, tp: mysql.TypeVarchar, size: 16, comment: "The user who open this session"}, + {name: txninfo.DBStr, tp: mysql.TypeVarchar, size: 64, comment: "The schema this transaction works on"}, + {name: txninfo.AllSQLDigestsStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "A list of the digests of SQL statements that the transaction has executed"}, + {name: txninfo.RelatedTableIDsStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "A list of the table IDs that the transaction has accessed"}, + {name: txninfo.WaitingTimeStr, tp: mysql.TypeDouble, size: 22, comment: "Current lock waiting time"}, +} + +var tableDeadlocksCols = []columnInfo{ + {name: deadlockhistory.ColDeadlockIDStr, tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag, comment: "The ID to distinguish different deadlock events"}, + {name: deadlockhistory.ColOccurTimeStr, tp: mysql.TypeTimestamp, decimal: 6, size: 26, comment: "The physical time when the deadlock occurs"}, + {name: deadlockhistory.ColRetryableStr, tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, comment: "Whether the deadlock is retryable. 
Retryable deadlocks are usually not reported to the client"}, + {name: deadlockhistory.ColTryLockTrxIDStr, tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "The transaction ID (start ts) of the transaction that's trying to acquire the lock"}, + {name: deadlockhistory.ColCurrentSQLDigestStr, tp: mysql.TypeVarchar, size: 64, comment: "The digest of the SQL that's being blocked"}, + {name: deadlockhistory.ColCurrentSQLDigestTextStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "The normalized SQL that's being blocked"}, + {name: deadlockhistory.ColKeyStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "The key on which a transaction is waiting for another"}, + {name: deadlockhistory.ColKeyInfoStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "Information of the key"}, + {name: deadlockhistory.ColTrxHoldingLockStr, tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "The transaction ID (start ts) of the transaction that's currently holding the lock"}, +} + +var tableDataLockWaitsCols = []columnInfo{ + {name: DataLockWaitsColumnKey, tp: mysql.TypeBlob, size: types.UnspecifiedLength, flag: mysql.NotNullFlag, comment: "The key that's being waiting on"}, + {name: DataLockWaitsColumnKeyInfo, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "Information of the key"}, + {name: DataLockWaitsColumnTrxID, tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Current transaction that's waiting for the lock"}, + {name: DataLockWaitsColumnCurrentHoldingTrxID, tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "The transaction that's holding the lock and blocks the current transaction"}, + {name: DataLockWaitsColumnSQLDigest, tp: mysql.TypeVarchar, size: 64, comment: "Digest of the SQL that's trying to acquire the lock"}, + {name: DataLockWaitsColumnSQLDigestText, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "Digest of the SQL that's trying to acquire the lock"}, +} + +var tableStatementsSummaryEvictedCols = []columnInfo{ + {name: "BEGIN_TIME", tp: mysql.TypeTimestamp, size: 26}, + {name: "END_TIME", tp: mysql.TypeTimestamp, size: 26}, + {name: "EVICTED_COUNT", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag}, +} + +var tableAttributesCols = []columnInfo{ + {name: "ID", tp: mysql.TypeVarchar, size: types.UnspecifiedLength, flag: mysql.NotNullFlag}, + {name: "TYPE", tp: mysql.TypeVarchar, size: 16, flag: mysql.NotNullFlag}, + {name: "ATTRIBUTES", tp: mysql.TypeVarchar, size: types.UnspecifiedLength}, + {name: "RANGES", tp: mysql.TypeBlob, size: types.UnspecifiedLength}, +} + +var tableTrxSummaryCols = []columnInfo{ + {name: "DIGEST", tp: mysql.TypeVarchar, size: 16, flag: mysql.NotNullFlag, comment: "Digest of a transaction"}, + {name: txninfo.AllSQLDigestsStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "A list of the digests of SQL statements that the transaction has executed"}, +} + +var tablePlacementPoliciesCols = []columnInfo{ + {name: "POLICY_ID", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag}, + {name: "CATALOG_NAME", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag}, + {name: "POLICY_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, // Catalog wide policy + {name: "PRIMARY_REGION", tp: mysql.TypeVarchar, size: 1024}, + {name: "REGIONS", tp: mysql.TypeVarchar, size: 1024}, + {name: "CONSTRAINTS", tp: mysql.TypeVarchar, 
size: 1024}, + {name: "LEADER_CONSTRAINTS", tp: mysql.TypeVarchar, size: 1024}, + {name: "FOLLOWER_CONSTRAINTS", tp: mysql.TypeVarchar, size: 1024}, + {name: "LEARNER_CONSTRAINTS", tp: mysql.TypeVarchar, size: 1024}, + {name: "SCHEDULE", tp: mysql.TypeVarchar, size: 20}, // EVEN or MAJORITY_IN_PRIMARY + {name: "FOLLOWERS", tp: mysql.TypeLonglong, size: 64}, + {name: "LEARNERS", tp: mysql.TypeLonglong, size: 64}, +} + +var tableVariablesInfoCols = []columnInfo{ + {name: "VARIABLE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "VARIABLE_SCOPE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "DEFAULT_VALUE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "CURRENT_VALUE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "MIN_VALUE", tp: mysql.TypeLonglong, size: 64}, + {name: "MAX_VALUE", tp: mysql.TypeLonglong, size: 64, flag: mysql.UnsignedFlag}, + {name: "POSSIBLE_VALUES", tp: mysql.TypeVarchar, size: 256}, + {name: "IS_NOOP", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, +} + +var tableUserAttributesCols = []columnInfo{ + {name: "USER", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, + {name: "HOST", tp: mysql.TypeVarchar, size: 255, flag: mysql.NotNullFlag}, + {name: "ATTRIBUTE", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, +} + +var tableMemoryUsageCols = []columnInfo{ + {name: "MEMORY_TOTAL", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag}, + {name: "MEMORY_LIMIT", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag}, + {name: "MEMORY_CURRENT", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag}, + {name: "MEMORY_MAX_USED", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag}, + {name: "CURRENT_OPS", tp: mysql.TypeVarchar, size: 50}, + {name: "SESSION_KILL_LAST", tp: mysql.TypeDatetime}, + {name: "SESSION_KILL_TOTAL", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag}, + {name: "GC_LAST", tp: mysql.TypeDatetime}, + {name: "GC_TOTAL", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag}, + {name: "DISK_USAGE", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag}, + {name: "QUERY_FORCE_DISK", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag}, +} + +var tableMemoryUsageOpsHistoryCols = []columnInfo{ + {name: "TIME", tp: mysql.TypeDatetime, size: 64, flag: mysql.NotNullFlag}, + {name: "OPS", tp: mysql.TypeVarchar, size: 20, flag: mysql.NotNullFlag}, + {name: "MEMORY_LIMIT", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag}, + {name: "MEMORY_CURRENT", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag}, + {name: "PROCESSID", tp: mysql.TypeLonglong, size: 21, flag: mysql.UnsignedFlag}, + {name: "MEM", tp: mysql.TypeLonglong, size: 21, flag: mysql.UnsignedFlag}, + {name: "DISK", tp: mysql.TypeLonglong, size: 21, flag: mysql.UnsignedFlag}, + {name: "CLIENT", tp: mysql.TypeVarchar, size: 64}, + {name: "DB", tp: mysql.TypeVarchar, size: 64}, + {name: "USER", tp: mysql.TypeVarchar, size: 16}, + {name: "SQL_DIGEST", tp: mysql.TypeVarchar, size: 64}, + {name: "SQL_TEXT", tp: mysql.TypeVarchar, size: 256}, +} + +var tableResourceGroupsCols = []columnInfo{ + {name: "NAME", tp: mysql.TypeVarchar, size: resourcegroup.MaxGroupNameLength, flag: mysql.NotNullFlag}, + {name: "RU_PER_SEC", tp: mysql.TypeVarchar, size: 21}, + {name: "PRIORITY", tp: mysql.TypeVarchar, size: 6}, + {name: "BURSTABLE", tp: mysql.TypeVarchar, size: 3}, + {name: "QUERY_LIMIT", tp: mysql.TypeVarchar, size: 256}, + {name: "BACKGROUND", 
tp: mysql.TypeVarchar, size: 256}, +} + +var tableRunawayWatchListCols = []columnInfo{ + {name: "ID", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag}, + {name: "RESOURCE_GROUP_NAME", tp: mysql.TypeVarchar, size: resourcegroup.MaxGroupNameLength, flag: mysql.NotNullFlag}, + {name: "START_TIME", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, + {name: "END_TIME", tp: mysql.TypeVarchar, size: 32}, + {name: "WATCH", tp: mysql.TypeVarchar, size: 12, flag: mysql.NotNullFlag}, + {name: "WATCH_TEXT", tp: mysql.TypeBlob, size: types.UnspecifiedLength, flag: mysql.NotNullFlag}, + {name: "SOURCE", tp: mysql.TypeVarchar, size: 128, flag: mysql.NotNullFlag}, + {name: "ACTION", tp: mysql.TypeVarchar, size: 12, flag: mysql.NotNullFlag}, +} + +// information_schema.CHECK_CONSTRAINTS +var tableCheckConstraintsCols = []columnInfo{ + {name: "CONSTRAINT_CATALOG", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "CONSTRAINT_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "CONSTRAINT_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "CHECK_CLAUSE", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength, flag: mysql.NotNullFlag}, +} + +// information_schema.TIDB_CHECK_CONSTRAINTS +var tableTiDBCheckConstraintsCols = []columnInfo{ + {name: "CONSTRAINT_CATALOG", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "CONSTRAINT_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "CONSTRAINT_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, + {name: "CHECK_CLAUSE", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength, flag: mysql.NotNullFlag}, + {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_ID", tp: mysql.TypeLonglong, size: 21}, +} + +var tableKeywords = []columnInfo{ + {name: "WORD", tp: mysql.TypeVarchar, size: 128}, + {name: "RESERVED", tp: mysql.TypeLong, size: 11}, +} + +var tableTiDBIndexUsage = []columnInfo{ + {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64}, + {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "INDEX_NAME", tp: mysql.TypeVarchar, size: 64}, + {name: "QUERY_TOTAL", tp: mysql.TypeLonglong, size: 21}, + {name: "KV_REQ_TOTAL", tp: mysql.TypeLonglong, size: 21}, + {name: "ROWS_ACCESS_TOTAL", tp: mysql.TypeLonglong, size: 21}, + {name: "PERCENTAGE_ACCESS_0", tp: mysql.TypeLonglong, size: 21}, + {name: "PERCENTAGE_ACCESS_0_1", tp: mysql.TypeLonglong, size: 21}, + {name: "PERCENTAGE_ACCESS_1_10", tp: mysql.TypeLonglong, size: 21}, + {name: "PERCENTAGE_ACCESS_10_20", tp: mysql.TypeLonglong, size: 21}, + {name: "PERCENTAGE_ACCESS_20_50", tp: mysql.TypeLonglong, size: 21}, + {name: "PERCENTAGE_ACCESS_50_100", tp: mysql.TypeLonglong, size: 21}, + {name: "PERCENTAGE_ACCESS_100", tp: mysql.TypeLonglong, size: 21}, + {name: "LAST_ACCESS_TIME", tp: mysql.TypeDatetime, size: 21}, +} + +// GetShardingInfo returns a nil or description string for the sharding information of given TableInfo. +// The returned description string may be: +// - "NOT_SHARDED": for tables that SHARD_ROW_ID_BITS is not specified. +// - "NOT_SHARDED(PK_IS_HANDLE)": for tables of which primary key is row id. +// - "PK_AUTO_RANDOM_BITS={bit_number}, RANGE BITS={bit_number}": for tables of which primary key is sharded row id. +// - "SHARD_BITS={bit_number}": for tables that with SHARD_ROW_ID_BITS. +// +// The returned nil indicates that sharding information is not suitable for the table(for example, when the table is a View). 
+// This function is exported for unit test. +func GetShardingInfo(dbInfo model.CIStr, tableInfo *model.TableInfo) any { + if tableInfo == nil || tableInfo.IsView() || util.IsMemOrSysDB(dbInfo.L) { + return nil + } + shardingInfo := "NOT_SHARDED" + if tableInfo.ContainsAutoRandomBits() { + shardingInfo = "PK_AUTO_RANDOM_BITS=" + strconv.Itoa(int(tableInfo.AutoRandomBits)) + rangeBits := tableInfo.AutoRandomRangeBits + if rangeBits != 0 && rangeBits != autoid.AutoRandomRangeBitsDefault { + shardingInfo = fmt.Sprintf("%s, RANGE BITS=%d", shardingInfo, rangeBits) + } + } else if tableInfo.ShardRowIDBits > 0 { + shardingInfo = "SHARD_BITS=" + strconv.Itoa(int(tableInfo.ShardRowIDBits)) + } else if tableInfo.PKIsHandle { + shardingInfo = "NOT_SHARDED(PK_IS_HANDLE)" + } + return shardingInfo +} + +const ( + // PrimaryKeyType is the string constant of PRIMARY KEY. + PrimaryKeyType = "PRIMARY KEY" + // PrimaryConstraint is the string constant of PRIMARY. + PrimaryConstraint = "PRIMARY" + // UniqueKeyType is the string constant of UNIQUE. + UniqueKeyType = "UNIQUE" + // ForeignKeyType is the string constant of Foreign Key. + ForeignKeyType = "FOREIGN KEY" +) + +const ( + // TiFlashWrite is the TiFlash write node in disaggregated mode. + TiFlashWrite = "tiflash_write" +) + +// ServerInfo represents the basic server information of single cluster component +type ServerInfo struct { + ServerType string + Address string + StatusAddr string + Version string + GitHash string + StartTimestamp int64 + ServerID uint64 + EngineRole string +} + +func (s *ServerInfo) isLoopBackOrUnspecifiedAddr(addr string) bool { + tcpAddr, err := net.ResolveTCPAddr("", addr) + if err != nil { + return false + } + ip := net.ParseIP(tcpAddr.IP.String()) + return ip != nil && (ip.IsUnspecified() || ip.IsLoopback()) +} + +// ResolveLoopBackAddr exports for testing. +func (s *ServerInfo) ResolveLoopBackAddr() { + if s.isLoopBackOrUnspecifiedAddr(s.Address) && !s.isLoopBackOrUnspecifiedAddr(s.StatusAddr) { + addr, err1 := net.ResolveTCPAddr("", s.Address) + statusAddr, err2 := net.ResolveTCPAddr("", s.StatusAddr) + if err1 == nil && err2 == nil { + addr.IP = statusAddr.IP + s.Address = addr.String() + } + } else if !s.isLoopBackOrUnspecifiedAddr(s.Address) && s.isLoopBackOrUnspecifiedAddr(s.StatusAddr) { + addr, err1 := net.ResolveTCPAddr("", s.Address) + statusAddr, err2 := net.ResolveTCPAddr("", s.StatusAddr) + if err1 == nil && err2 == nil { + statusAddr.IP = addr.IP + s.StatusAddr = statusAddr.String() + } + } +} + +// GetClusterServerInfo returns all components information of cluster +func GetClusterServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) { + failpoint.Inject("mockClusterInfo", func(val failpoint.Value) { + // The cluster topology is injected by `failpoint` expression and + // there is no extra checks for it. 
(let the test fail if the expression invalid) + if s := val.(string); len(s) > 0 { + var servers []ServerInfo + for _, server := range strings.Split(s, ";") { + parts := strings.Split(server, ",") + serverID, err := strconv.ParseUint(parts[5], 10, 64) + if err != nil { + panic("convert parts[5] to uint64 failed") + } + servers = append(servers, ServerInfo{ + ServerType: parts[0], + Address: parts[1], + StatusAddr: parts[2], + Version: parts[3], + GitHash: parts[4], + ServerID: serverID, + }) + } + failpoint.Return(servers, nil) + } + }) + + type retriever func(ctx sessionctx.Context) ([]ServerInfo, error) + retrievers := []retriever{GetTiDBServerInfo, GetPDServerInfo, func(ctx sessionctx.Context) ([]ServerInfo, error) { + return GetStoreServerInfo(ctx.GetStore()) + }, GetTiProxyServerInfo, GetTiCDCServerInfo, GetTSOServerInfo, GetSchedulingServerInfo} + //nolint: prealloc + var servers []ServerInfo + for _, r := range retrievers { + nodes, err := r(ctx) + if err != nil { + return nil, err + } + for i := range nodes { + nodes[i].ResolveLoopBackAddr() + } + servers = append(servers, nodes...) + } + return servers, nil +} + +// GetTiDBServerInfo returns all TiDB nodes information of cluster +func GetTiDBServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) { + // Get TiDB servers info. + tidbNodes, err := infosync.GetAllServerInfo(context.Background()) + if err != nil { + return nil, errors.Trace(err) + } + var isDefaultVersion bool + if len(config.GetGlobalConfig().ServerVersion) == 0 { + isDefaultVersion = true + } + var servers = make([]ServerInfo, 0, len(tidbNodes)) + for _, node := range tidbNodes { + servers = append(servers, ServerInfo{ + ServerType: "tidb", + Address: net.JoinHostPort(node.IP, strconv.Itoa(int(node.Port))), + StatusAddr: net.JoinHostPort(node.IP, strconv.Itoa(int(node.StatusPort))), + Version: FormatTiDBVersion(node.Version, isDefaultVersion), + GitHash: node.GitHash, + StartTimestamp: node.StartTimestamp, + ServerID: node.ServerIDGetter(), + }) + } + return servers, nil +} + +// FormatTiDBVersion make TiDBVersion consistent to TiKV and PD. +// The default TiDBVersion is 5.7.25-TiDB-${TiDBReleaseVersion}. +func FormatTiDBVersion(TiDBVersion string, isDefaultVersion bool) string { + var version, nodeVersion string + + // The user hasn't set the config 'ServerVersion'. + if isDefaultVersion { + nodeVersion = TiDBVersion[strings.Index(TiDBVersion, "TiDB-")+len("TiDB-"):] + if len(nodeVersion) > 0 && nodeVersion[0] == 'v' { + nodeVersion = nodeVersion[1:] + } + nodeVersions := strings.SplitN(nodeVersion, "-", 2) + if len(nodeVersions) == 1 { + version = nodeVersions[0] + } else if len(nodeVersions) >= 2 { + version = fmt.Sprintf("%s-%s", nodeVersions[0], nodeVersions[1]) + } + } else { // The user has already set the config 'ServerVersion',it would be a complex scene, so just use the 'ServerVersion' as version. + version = TiDBVersion + } + + return version +} + +// GetPDServerInfo returns all PD nodes information of cluster +func GetPDServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) { + // Get PD servers info. + members, err := getEtcdMembers(ctx) + if err != nil { + return nil, err + } + // TODO: maybe we should unify the PD API request interface. + var ( + memberNum = len(members) + servers = make([]ServerInfo, 0, memberNum) + errs = make([]error, 0, memberNum) + ) + if memberNum == 0 { + return servers, nil + } + // Try on each member until one succeeds or all fail. 
+ for _, addr := range members { + // Get PD version, git_hash + url := fmt.Sprintf("%s://%s%s", util.InternalHTTPSchema(), addr, pd.Status) + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + ctx.GetSessionVars().StmtCtx.AppendWarning(err) + logutil.BgLogger().Warn("create pd server info request error", zap.String("url", url), zap.Error(err)) + errs = append(errs, err) + continue + } + req.Header.Add("PD-Allow-follower-handle", "true") + resp, err := util.InternalHTTPClient().Do(req) + if err != nil { + ctx.GetSessionVars().StmtCtx.AppendWarning(err) + logutil.BgLogger().Warn("request pd server info error", zap.String("url", url), zap.Error(err)) + errs = append(errs, err) + continue + } + var content = struct { + Version string `json:"version"` + GitHash string `json:"git_hash"` + StartTimestamp int64 `json:"start_timestamp"` + }{} + err = json.NewDecoder(resp.Body).Decode(&content) + terror.Log(resp.Body.Close()) + if err != nil { + ctx.GetSessionVars().StmtCtx.AppendWarning(err) + logutil.BgLogger().Warn("close pd server info request error", zap.String("url", url), zap.Error(err)) + errs = append(errs, err) + continue + } + if len(content.Version) > 0 && content.Version[0] == 'v' { + content.Version = content.Version[1:] + } + + servers = append(servers, ServerInfo{ + ServerType: "pd", + Address: addr, + StatusAddr: addr, + Version: content.Version, + GitHash: content.GitHash, + StartTimestamp: content.StartTimestamp, + }) + } + // Return the errors if all members' requests fail. + if len(errs) == memberNum { + errorMsg := "" + for idx, err := range errs { + errorMsg += err.Error() + if idx < memberNum-1 { + errorMsg += "; " + } + } + return nil, errors.Trace(fmt.Errorf("%s", errorMsg)) + } + return servers, nil +} + +// GetTSOServerInfo returns all TSO nodes information of cluster +func GetTSOServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) { + return getMicroServiceServerInfo(ctx, tsoServiceName) +} + +// GetSchedulingServerInfo returns all scheduling nodes information of cluster +func GetSchedulingServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) { + return getMicroServiceServerInfo(ctx, schedulingServiceName) +} + +func getMicroServiceServerInfo(ctx sessionctx.Context, serviceName string) ([]ServerInfo, error) { + members, err := getEtcdMembers(ctx) + if err != nil { + return nil, err + } + // TODO: maybe we should unify the PD API request interface. + var servers []ServerInfo + + if len(members) == 0 { + return servers, nil + } + // Try on each member until one succeeds or all fail. 
+ for _, addr := range members { + // Get members + url := fmt.Sprintf("%s://%s%s/%s", util.InternalHTTPSchema(), addr, "/pd/api/v2/ms/members", serviceName) + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + ctx.GetSessionVars().StmtCtx.AppendWarning(err) + logutil.BgLogger().Warn("create microservice server info request error", zap.String("service", serviceName), zap.String("url", url), zap.Error(err)) + continue + } + req.Header.Add("PD-Allow-follower-handle", "true") + resp, err := util.InternalHTTPClient().Do(req) + if err != nil { + ctx.GetSessionVars().StmtCtx.AppendWarning(err) + logutil.BgLogger().Warn("request microservice server info error", zap.String("service", serviceName), zap.String("url", url), zap.Error(err)) + continue + } + if resp.StatusCode != http.StatusOK { + terror.Log(resp.Body.Close()) + continue + } + var content = []struct { + ServiceAddr string `json:"service-addr"` + Version string `json:"version"` + GitHash string `json:"git-hash"` + DeployPath string `json:"deploy-path"` + StartTimestamp int64 `json:"start-timestamp"` + }{} + err = json.NewDecoder(resp.Body).Decode(&content) + terror.Log(resp.Body.Close()) + if err != nil { + ctx.GetSessionVars().StmtCtx.AppendWarning(err) + logutil.BgLogger().Warn("close microservice server info request error", zap.String("service", serviceName), zap.String("url", url), zap.Error(err)) + continue + } + + for _, c := range content { + addr := strings.TrimPrefix(c.ServiceAddr, "http://") + addr = strings.TrimPrefix(addr, "https://") + if len(c.Version) > 0 && c.Version[0] == 'v' { + c.Version = c.Version[1:] + } + servers = append(servers, ServerInfo{ + ServerType: serviceName, + Address: addr, + StatusAddr: addr, + Version: c.Version, + GitHash: c.GitHash, + StartTimestamp: c.StartTimestamp, + }) + } + return servers, nil + } + return servers, nil +} + +func getEtcdMembers(ctx sessionctx.Context) ([]string, error) { + store := ctx.GetStore() + etcd, ok := store.(kv.EtcdBackend) + if !ok { + return nil, errors.Errorf("%T not an etcd backend", store) + } + members, err := etcd.EtcdAddrs() + if err != nil { + return nil, errors.Trace(err) + } + return members, nil +} + +func isTiFlashStore(store *metapb.Store) bool { + for _, label := range store.Labels { + if label.GetKey() == placement.EngineLabelKey && label.GetValue() == placement.EngineLabelTiFlash { + return true + } + } + return false +} + +func isTiFlashWriteNode(store *metapb.Store) bool { + for _, label := range store.Labels { + if label.GetKey() == placement.EngineRoleLabelKey && label.GetValue() == placement.EngineRoleLabelWrite { + return true + } + } + return false +} + +// GetStoreServerInfo returns all store nodes(TiKV or TiFlash) cluster information +func GetStoreServerInfo(store kv.Storage) ([]ServerInfo, error) { + failpoint.Inject("mockStoreServerInfo", func(val failpoint.Value) { + if s := val.(string); len(s) > 0 { + var servers []ServerInfo + for _, server := range strings.Split(s, ";") { + parts := strings.Split(server, ",") + servers = append(servers, ServerInfo{ + ServerType: parts[0], + Address: parts[1], + StatusAddr: parts[2], + Version: parts[3], + GitHash: parts[4], + StartTimestamp: 0, + }) + } + failpoint.Return(servers, nil) + } + }) + + // Get TiKV servers info. 
+ tikvStore, ok := store.(tikv.Storage) + if !ok { + return nil, errors.Errorf("%T is not an TiKV or TiFlash store instance", store) + } + pdClient := tikvStore.GetRegionCache().PDClient() + if pdClient == nil { + return nil, errors.New("pd unavailable") + } + stores, err := pdClient.GetAllStores(context.Background()) + if err != nil { + return nil, errors.Trace(err) + } + servers := make([]ServerInfo, 0, len(stores)) + for _, store := range stores { + failpoint.Inject("mockStoreTombstone", func(val failpoint.Value) { + if val.(bool) { + store.State = metapb.StoreState_Tombstone + } + }) + + if store.GetState() == metapb.StoreState_Tombstone { + continue + } + var tp string + if isTiFlashStore(store) { + tp = kv.TiFlash.Name() + } else { + tp = tikv.GetStoreTypeByMeta(store).Name() + } + var engineRole string + if isTiFlashWriteNode(store) { + engineRole = placement.EngineRoleLabelWrite + } + servers = append(servers, ServerInfo{ + ServerType: tp, + Address: store.Address, + StatusAddr: store.StatusAddress, + Version: FormatStoreServerVersion(store.Version), + GitHash: store.GitHash, + StartTimestamp: store.StartTimestamp, + EngineRole: engineRole, + }) + } + return servers, nil +} + +// FormatStoreServerVersion format version of store servers(Tikv or TiFlash) +func FormatStoreServerVersion(version string) string { + if len(version) >= 1 && version[0] == 'v' { + version = version[1:] + } + return version +} + +// GetTiFlashStoreCount returns the count of tiflash server. +func GetTiFlashStoreCount(ctx sessionctx.Context) (cnt uint64, err error) { + failpoint.Inject("mockTiFlashStoreCount", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(uint64(10), nil) + } + }) + + stores, err := GetStoreServerInfo(ctx.GetStore()) + if err != nil { + return cnt, err + } + for _, store := range stores { + if store.ServerType == kv.TiFlash.Name() { + cnt++ + } + } + return cnt, nil +} + +// GetTiProxyServerInfo gets server info of TiProxy from PD. +func GetTiProxyServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) { + tiproxyNodes, err := infosync.GetTiProxyServerInfo(context.Background()) + if err != nil { + return nil, errors.Trace(err) + } + var servers = make([]ServerInfo, 0, len(tiproxyNodes)) + for _, node := range tiproxyNodes { + servers = append(servers, ServerInfo{ + ServerType: "tiproxy", + Address: net.JoinHostPort(node.IP, node.Port), + StatusAddr: net.JoinHostPort(node.IP, node.StatusPort), + Version: node.Version, + GitHash: node.GitHash, + StartTimestamp: node.StartTimestamp, + }) + } + return servers, nil +} + +// GetTiCDCServerInfo gets server info of TiCDC from PD. +func GetTiCDCServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) { + ticdcNodes, err := infosync.GetTiCDCServerInfo(context.Background()) + if err != nil { + return nil, errors.Trace(err) + } + var servers = make([]ServerInfo, 0, len(ticdcNodes)) + for _, node := range ticdcNodes { + servers = append(servers, ServerInfo{ + ServerType: "ticdc", + Address: node.Address, + StatusAddr: node.Address, + Version: node.Version, + GitHash: node.GitHash, + StartTimestamp: node.StartTimestamp, + }) + } + return servers, nil +} + +// SysVarHiddenForSem checks if a given sysvar is hidden according to SEM and privileges. 
+func SysVarHiddenForSem(ctx sessionctx.Context, sysVarNameInLower string) bool { + if !sem.IsEnabled() || !sem.IsInvisibleSysVar(sysVarNameInLower) { + return false + } + checker := privilege.GetPrivilegeManager(ctx) + if checker == nil || checker.RequestDynamicVerification(ctx.GetSessionVars().ActiveRoles, "RESTRICTED_VARIABLES_ADMIN", false) { + return false + } + return true +} + +// GetDataFromSessionVariables return the [name, value] of all session variables +func GetDataFromSessionVariables(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { + sessionVars := sctx.GetSessionVars() + sysVars := variable.GetSysVars() + rows := make([][]types.Datum, 0, len(sysVars)) + for _, v := range sysVars { + if SysVarHiddenForSem(sctx, v.Name) { + continue + } + var value string + value, err := sessionVars.GetSessionOrGlobalSystemVar(ctx, v.Name) + if err != nil { + return nil, err + } + row := types.MakeDatums(v.Name, value) + rows = append(rows, row) + } + return rows, nil +} + +// GetDataFromSessionConnectAttrs produces the rows for the session_connect_attrs table. +func GetDataFromSessionConnectAttrs(sctx sessionctx.Context, sameAccount bool) ([][]types.Datum, error) { + sm := sctx.GetSessionManager() + if sm == nil { + return nil, nil + } + var user *auth.UserIdentity + if sameAccount { + user = sctx.GetSessionVars().User + } + allAttrs := sm.GetConAttrs(user) + rows := make([][]types.Datum, 0, len(allAttrs)*10) // 10 Attributes per connection + for pid, attrs := range allAttrs { // Note: PID is not ordered. + // Sorts the attributes by key and gives ORDINAL_POSITION based on this. This is needed as we didn't store the + // ORDINAL_POSITION and a map doesn't have a guaranteed sort order. This is needed to keep the ORDINAL_POSITION + // stable over multiple queries. 
+ attrnames := make([]string, 0, len(attrs)) + for attrname := range attrs { + attrnames = append(attrnames, attrname) + } + sort.Strings(attrnames) + + for ord, attrkey := range attrnames { + row := types.MakeDatums( + pid, + attrkey, + attrs[attrkey], + ord, + ) + rows = append(rows, row) + } + } + return rows, nil +} + +var tableNameToColumns = map[string][]columnInfo{ + TableSchemata: schemataCols, + TableTables: tablesCols, + TableColumns: columnsCols, + tableColumnStatistics: columnStatisticsCols, + TableStatistics: statisticsCols, + TableCharacterSets: charsetCols, + TableCollations: collationsCols, + tableFiles: filesCols, + TableProfiling: profilingCols, + TablePartitions: partitionsCols, + TableKeyColumn: keyColumnUsageCols, + TableReferConst: referConstCols, + TableSessionVar: sessionVarCols, + tablePlugins: pluginsCols, + TableConstraints: tableConstraintsCols, + tableTriggers: tableTriggersCols, + TableUserPrivileges: tableUserPrivilegesCols, + tableSchemaPrivileges: tableSchemaPrivilegesCols, + tableTablePrivileges: tableTablePrivilegesCols, + tableColumnPrivileges: tableColumnPrivilegesCols, + TableEngines: tableEnginesCols, + TableViews: tableViewsCols, + tableRoutines: tableRoutinesCols, + tableParameters: tableParametersCols, + tableEvents: tableEventsCols, + tableGlobalStatus: tableGlobalStatusCols, + tableGlobalVariables: tableGlobalVariablesCols, + tableSessionStatus: tableSessionStatusCols, + tableOptimizerTrace: tableOptimizerTraceCols, + tableTableSpaces: tableTableSpacesCols, + TableCollationCharacterSetApplicability: tableCollationCharacterSetApplicabilityCols, + TableProcesslist: tableProcesslistCols, + TableTiDBIndexes: tableTiDBIndexesCols, + TableSlowQuery: slowQueryCols, + TableTiDBHotRegions: TableTiDBHotRegionsCols, + TableTiDBHotRegionsHistory: TableTiDBHotRegionsHistoryCols, + TableTiKVStoreStatus: TableTiKVStoreStatusCols, + TableAnalyzeStatus: tableAnalyzeStatusCols, + TableTiKVRegionStatus: TableTiKVRegionStatusCols, + TableTiKVRegionPeers: TableTiKVRegionPeersCols, + TableTiDBServersInfo: tableTiDBServersInfoCols, + TableClusterInfo: tableClusterInfoCols, + TableClusterConfig: tableClusterConfigCols, + TableClusterLog: tableClusterLogCols, + TableClusterLoad: tableClusterLoadCols, + TableTiFlashReplica: tableTableTiFlashReplicaCols, + TableClusterHardware: tableClusterHardwareCols, + TableClusterSystemInfo: tableClusterSystemInfoCols, + TableInspectionResult: tableInspectionResultCols, + TableMetricSummary: tableMetricSummaryCols, + TableMetricSummaryByLabel: tableMetricSummaryByLabelCols, + TableMetricTables: tableMetricTablesCols, + TableInspectionSummary: tableInspectionSummaryCols, + TableInspectionRules: tableInspectionRulesCols, + TableDDLJobs: tableDDLJobsCols, + TableSequences: tableSequencesCols, + TableStatementsSummary: tableStatementsSummaryCols, + TableStatementsSummaryHistory: tableStatementsSummaryCols, + TableStatementsSummaryEvicted: tableStatementsSummaryEvictedCols, + TableStorageStats: tableStorageStatsCols, + TableTiFlashTables: tableTableTiFlashTablesCols, + TableTiFlashSegments: tableTableTiFlashSegmentsCols, + TableClientErrorsSummaryGlobal: tableClientErrorsSummaryGlobalCols, + TableClientErrorsSummaryByUser: tableClientErrorsSummaryByUserCols, + TableClientErrorsSummaryByHost: tableClientErrorsSummaryByHostCols, + TableTiDBTrx: tableTiDBTrxCols, + TableDeadlocks: tableDeadlocksCols, + TableDataLockWaits: tableDataLockWaitsCols, + TableAttributes: tableAttributesCols, + TablePlacementPolicies: tablePlacementPoliciesCols, + 
TableTrxSummary: tableTrxSummaryCols, + TableVariablesInfo: tableVariablesInfoCols, + TableUserAttributes: tableUserAttributesCols, + TableMemoryUsage: tableMemoryUsageCols, + TableMemoryUsageOpsHistory: tableMemoryUsageOpsHistoryCols, + TableResourceGroups: tableResourceGroupsCols, + TableRunawayWatches: tableRunawayWatchListCols, + TableCheckConstraints: tableCheckConstraintsCols, + TableTiDBCheckConstraints: tableTiDBCheckConstraintsCols, + TableKeywords: tableKeywords, + TableTiDBIndexUsage: tableTiDBIndexUsage, +} + +func createInfoSchemaTable(_ autoid.Allocators, _ func() (pools.Resource, error), meta *model.TableInfo) (table.Table, error) { + columns := make([]*table.Column, len(meta.Columns)) + for i, col := range meta.Columns { + columns[i] = table.ToColumn(col) + } + tp := table.VirtualTable + if isClusterTableByName(util.InformationSchemaName.O, meta.Name.O) { + tp = table.ClusterTable + } + return &infoschemaTable{meta: meta, cols: columns, tp: tp}, nil +} + +type infoschemaTable struct { + meta *model.TableInfo + cols []*table.Column + tp table.Type +} + +// IterRecords implements table.Table IterRecords interface. +func (*infoschemaTable) IterRecords(ctx context.Context, sctx sessionctx.Context, cols []*table.Column, fn table.RecordIterFunc) error { + return nil +} + +// Cols implements table.Table Cols interface. +func (it *infoschemaTable) Cols() []*table.Column { + return it.cols +} + +// VisibleCols implements table.Table VisibleCols interface. +func (it *infoschemaTable) VisibleCols() []*table.Column { + return it.cols +} + +// HiddenCols implements table.Table HiddenCols interface. +func (it *infoschemaTable) HiddenCols() []*table.Column { + return nil +} + +// WritableCols implements table.Table WritableCols interface. +func (it *infoschemaTable) WritableCols() []*table.Column { + return it.cols +} + +// DeletableCols implements table.Table WritableCols interface. +func (it *infoschemaTable) DeletableCols() []*table.Column { + return it.cols +} + +// FullHiddenColsAndVisibleCols implements table FullHiddenColsAndVisibleCols interface. +func (it *infoschemaTable) FullHiddenColsAndVisibleCols() []*table.Column { + return it.cols +} + +// Indices implements table.Table Indices interface. +func (it *infoschemaTable) Indices() []table.Index { + return nil +} + +// WritableConstraint implements table.Table WritableConstraint interface. +func (it *infoschemaTable) WritableConstraint() []*table.Constraint { + return nil +} + +// RecordPrefix implements table.Table RecordPrefix interface. +func (it *infoschemaTable) RecordPrefix() kv.Key { + return nil +} + +// IndexPrefix implements table.Table IndexPrefix interface. +func (it *infoschemaTable) IndexPrefix() kv.Key { + return nil +} + +// AddRecord implements table.Table AddRecord interface. +func (it *infoschemaTable) AddRecord(ctx table.MutateContext, r []types.Datum, opts ...table.AddRecordOption) (recordID kv.Handle, err error) { + return nil, table.ErrUnsupportedOp +} + +// RemoveRecord implements table.Table RemoveRecord interface. +func (it *infoschemaTable) RemoveRecord(ctx table.MutateContext, h kv.Handle, r []types.Datum) error { + return table.ErrUnsupportedOp +} + +// UpdateRecord implements table.Table UpdateRecord interface. +func (it *infoschemaTable) UpdateRecord(ctx table.MutateContext, h kv.Handle, oldData, newData []types.Datum, touched []bool, opts ...table.UpdateRecordOption) error { + return table.ErrUnsupportedOp +} + +// Allocators implements table.Table Allocators interface. 
+func (it *infoschemaTable) Allocators(_ table.AllocatorContext) autoid.Allocators { + return autoid.Allocators{} +} + +// Meta implements table.Table Meta interface. +func (it *infoschemaTable) Meta() *model.TableInfo { + return it.meta +} + +// GetPhysicalID implements table.Table GetPhysicalID interface. +func (it *infoschemaTable) GetPhysicalID() int64 { + return it.meta.ID +} + +// Type implements table.Table Type interface. +func (it *infoschemaTable) Type() table.Type { + return it.tp +} + +// GetPartitionedTable implements table.Table GetPartitionedTable interface. +func (it *infoschemaTable) GetPartitionedTable() table.PartitionedTable { + return nil +} + +// VirtualTable is a dummy table.Table implementation. +type VirtualTable struct{} + +// Cols implements table.Table Cols interface. +func (vt *VirtualTable) Cols() []*table.Column { + return nil +} + +// VisibleCols implements table.Table VisibleCols interface. +func (vt *VirtualTable) VisibleCols() []*table.Column { + return nil +} + +// HiddenCols implements table.Table HiddenCols interface. +func (vt *VirtualTable) HiddenCols() []*table.Column { + return nil +} + +// WritableCols implements table.Table WritableCols interface. +func (vt *VirtualTable) WritableCols() []*table.Column { + return nil +} + +// DeletableCols implements table.Table WritableCols interface. +func (vt *VirtualTable) DeletableCols() []*table.Column { + return nil +} + +// FullHiddenColsAndVisibleCols implements table FullHiddenColsAndVisibleCols interface. +func (vt *VirtualTable) FullHiddenColsAndVisibleCols() []*table.Column { + return nil +} + +// Indices implements table.Table Indices interface. +func (vt *VirtualTable) Indices() []table.Index { + return nil +} + +// WritableConstraint implements table.Table WritableConstraint interface. +func (vt *VirtualTable) WritableConstraint() []*table.Constraint { + return nil +} + +// RecordPrefix implements table.Table RecordPrefix interface. +func (vt *VirtualTable) RecordPrefix() kv.Key { + return nil +} + +// IndexPrefix implements table.Table IndexPrefix interface. +func (vt *VirtualTable) IndexPrefix() kv.Key { + return nil +} + +// AddRecord implements table.Table AddRecord interface. +func (vt *VirtualTable) AddRecord(ctx table.MutateContext, r []types.Datum, opts ...table.AddRecordOption) (recordID kv.Handle, err error) { + return nil, table.ErrUnsupportedOp +} + +// RemoveRecord implements table.Table RemoveRecord interface. +func (vt *VirtualTable) RemoveRecord(ctx table.MutateContext, h kv.Handle, r []types.Datum) error { + return table.ErrUnsupportedOp +} + +// UpdateRecord implements table.Table UpdateRecord interface. +func (vt *VirtualTable) UpdateRecord(ctx table.MutateContext, h kv.Handle, oldData, newData []types.Datum, touched []bool, opts ...table.UpdateRecordOption) error { + return table.ErrUnsupportedOp +} + +// Allocators implements table.Table Allocators interface. +func (vt *VirtualTable) Allocators(_ table.AllocatorContext) autoid.Allocators { + return autoid.Allocators{} +} + +// Meta implements table.Table Meta interface. +func (vt *VirtualTable) Meta() *model.TableInfo { + return nil +} + +// GetPhysicalID implements table.Table GetPhysicalID interface. +func (vt *VirtualTable) GetPhysicalID() int64 { + return 0 +} + +// Type implements table.Table Type interface. 
+func (vt *VirtualTable) Type() table.Type { + return table.VirtualTable +} + +// GetTiFlashServerInfo returns all TiFlash server infos +func GetTiFlashServerInfo(store kv.Storage) ([]ServerInfo, error) { + if config.GetGlobalConfig().DisaggregatedTiFlash { + return nil, table.ErrUnsupportedOp + } + serversInfo, err := GetStoreServerInfo(store) + if err != nil { + return nil, err + } + serversInfo = FilterClusterServerInfo(serversInfo, set.NewStringSet(kv.TiFlash.Name()), set.NewStringSet()) + return serversInfo, nil +} + +// FetchClusterServerInfoWithoutPrivilegeCheck fetches cluster server information +func FetchClusterServerInfoWithoutPrivilegeCheck(ctx context.Context, vars *variable.SessionVars, serversInfo []ServerInfo, serverInfoType diagnosticspb.ServerInfoType, recordWarningInStmtCtx bool) ([][]types.Datum, error) { + type result struct { + idx int + rows [][]types.Datum + err error + } + wg := sync.WaitGroup{} + ch := make(chan result, len(serversInfo)) + infoTp := serverInfoType + finalRows := make([][]types.Datum, 0, len(serversInfo)*10) + for i, srv := range serversInfo { + address := srv.Address + remote := address + if srv.ServerType == "tidb" || srv.ServerType == "tiproxy" { + remote = srv.StatusAddr + } + wg.Add(1) + go func(index int, remote, address, serverTP string) { + util.WithRecovery(func() { + defer wg.Done() + items, err := getServerInfoByGRPC(ctx, remote, infoTp) + if err != nil { + ch <- result{idx: index, err: err} + return + } + partRows := serverInfoItemToRows(items, serverTP, address) + ch <- result{idx: index, rows: partRows} + }, nil) + }(i, remote, address, srv.ServerType) + } + wg.Wait() + close(ch) + // Keep the original order to make the result more stable + var results []result //nolint: prealloc + for result := range ch { + if result.err != nil { + if recordWarningInStmtCtx { + vars.StmtCtx.AppendWarning(result.err) + } else { + log.Warn(result.err.Error()) + } + continue + } + results = append(results, result) + } + slices.SortFunc(results, func(i, j result) int { return cmp.Compare(i.idx, j.idx) }) + for _, result := range results { + finalRows = append(finalRows, result.rows...) 
+ } + return finalRows, nil +} + +func serverInfoItemToRows(items []*diagnosticspb.ServerInfoItem, tp, addr string) [][]types.Datum { + rows := make([][]types.Datum, 0, len(items)) + for _, v := range items { + for _, item := range v.Pairs { + row := types.MakeDatums( + tp, + addr, + v.Tp, + v.Name, + item.Key, + item.Value, + ) + rows = append(rows, row) + } + } + return rows +} + +func getServerInfoByGRPC(ctx context.Context, address string, tp diagnosticspb.ServerInfoType) ([]*diagnosticspb.ServerInfoItem, error) { + opt := grpc.WithTransportCredentials(insecure.NewCredentials()) + security := config.GetGlobalConfig().Security + if len(security.ClusterSSLCA) != 0 { + clusterSecurity := security.ClusterSecurity() + tlsConfig, err := clusterSecurity.ToTLSConfig() + if err != nil { + return nil, errors.Trace(err) + } + opt = grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)) + } + conn, err := grpc.Dial(address, opt) + if err != nil { + return nil, err + } + defer func() { + err := conn.Close() + if err != nil { + log.Error("close grpc connection error", zap.Error(err)) + } + }() + + cli := diagnosticspb.NewDiagnosticsClient(conn) + ctx, cancel := context.WithTimeout(ctx, time.Second*10) + defer cancel() + r, err := cli.ServerInfo(ctx, &diagnosticspb.ServerInfoRequest{Tp: tp}) + if err != nil { + return nil, err + } + return r.Items, nil +} + +// FilterClusterServerInfo filters serversInfo by nodeTypes and addresses +func FilterClusterServerInfo(serversInfo []ServerInfo, nodeTypes, addresses set.StringSet) []ServerInfo { + if len(nodeTypes) == 0 && len(addresses) == 0 { + return serversInfo + } + + filterServers := make([]ServerInfo, 0, len(serversInfo)) + for _, srv := range serversInfo { + // Skip some node type which has been filtered in WHERE clause + // e.g: SELECT * FROM cluster_config WHERE type='tikv' + if len(nodeTypes) > 0 && !nodeTypes.Exist(srv.ServerType) { + continue + } + // Skip some node address which has been filtered in WHERE clause + // e.g: SELECT * FROM cluster_config WHERE address='192.16.8.12:2379' + if len(addresses) > 0 && !addresses.Exist(srv.Address) { + continue + } + filterServers = append(filterServers, srv) + } + return filterServers +} diff --git a/pkg/kv/binding__failpoint_binding__.go b/pkg/kv/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..91ba6650c6d47 --- /dev/null +++ b/pkg/kv/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package kv + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/kv/txn.go b/pkg/kv/txn.go index 5e0d3399f9d1e..64acfa199de55 100644 --- a/pkg/kv/txn.go +++ b/pkg/kv/txn.go @@ -143,7 +143,7 @@ func RunInNewTxn(ctx context.Context, store Storage, retryable bool, f func(ctx return err } - failpoint.Inject("mockCommitErrorInNewTxn", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockCommitErrorInNewTxn")); _err_ == nil { if v := val.(string); len(v) > 0 { switch v { case "retry_once": @@ -152,10 +152,10 @@ func RunInNewTxn(ctx context.Context, store Storage, retryable bool, f func(ctx err = ErrTxnRetryable } case "no_retry": - failpoint.Return(errors.New("mock commit error")) + return errors.New("mock commit error") } } - }) + } if err == nil { err = txn.Commit(ctx) @@ -223,7 
+223,7 @@ func setRequestSourceForInnerTxn(ctx context.Context, txn Transaction) { // SetTxnResourceGroup update the resource group name of target txn. func SetTxnResourceGroup(txn Transaction, name string) { txn.SetOption(ResourceGroupName, name) - failpoint.Inject("TxnResourceGroupChecker", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("TxnResourceGroupChecker")); _err_ == nil { expectedRgName := val.(string) validateRNameInterceptor := func(next interceptor.RPCInterceptorFunc) interceptor.RPCInterceptorFunc { return func(target string, req *tikvrpc.Request) (*tikvrpc.Response, error) { @@ -243,5 +243,5 @@ func SetTxnResourceGroup(txn Transaction, name string) { } } txn.SetOption(RPCInterceptor, interceptor.NewRPCInterceptor("test-validate-rg-name", validateRNameInterceptor)) - }) + } } diff --git a/pkg/kv/txn.go__failpoint_stash__ b/pkg/kv/txn.go__failpoint_stash__ new file mode 100644 index 0000000000000..5e0d3399f9d1e --- /dev/null +++ b/pkg/kv/txn.go__failpoint_stash__ @@ -0,0 +1,247 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kv + +import ( + "context" + "errors" + "fmt" + "math" + "math/rand" + "sync" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/tikv/client-go/v2/oracle" + "github.com/tikv/client-go/v2/tikvrpc" + "github.com/tikv/client-go/v2/tikvrpc/interceptor" + "go.uber.org/zap" +) + +const ( + // TimeToPrintLongTimeInternalTxn is the duration if the internal transaction lasts more than it, + // TiDB prints a log message. + TimeToPrintLongTimeInternalTxn = time.Minute * 5 +) + +var globalInnerTxnTsBox = innerTxnStartTsBox{ + innerTSLock: sync.Mutex{}, + innerTxnStartTsMap: make(map[uint64]struct{}, 256), +} + +type innerTxnStartTsBox struct { + innerTSLock sync.Mutex + innerTxnStartTsMap map[uint64]struct{} +} + +func (ib *innerTxnStartTsBox) storeInnerTxnTS(startTS uint64) { + ib.innerTSLock.Lock() + ib.innerTxnStartTsMap[startTS] = struct{}{} + ib.innerTSLock.Unlock() +} + +func (ib *innerTxnStartTsBox) deleteInnerTxnTS(startTS uint64) { + ib.innerTSLock.Lock() + delete(ib.innerTxnStartTsMap, startTS) + ib.innerTSLock.Unlock() +} + +// GetMinInnerTxnStartTS get the min StartTS between startTSLowerLimit and curMinStartTS in globalInnerTxnTsBox. 
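+// It scans every start TS recorded in innerTxnStartTsMap, logging any internal
+// transaction that has been running longer than TimeToPrintLongTimeInternalTxn,
+// and returns the smallest TS above startTSLowerLimit, or curMinStartTS when no
+// tracked TS is smaller.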
+func GetMinInnerTxnStartTS(now time.Time, startTSLowerLimit uint64, + curMinStartTS uint64) uint64 { + return globalInnerTxnTsBox.getMinStartTS(now, startTSLowerLimit, curMinStartTS) +} + +func (ib *innerTxnStartTsBox) getMinStartTS(now time.Time, startTSLowerLimit uint64, + curMinStartTS uint64) uint64 { + minStartTS := curMinStartTS + ib.innerTSLock.Lock() + for innerTS := range ib.innerTxnStartTsMap { + PrintLongTimeInternalTxn(now, innerTS, true) + if innerTS > startTSLowerLimit && innerTS < minStartTS { + minStartTS = innerTS + } + } + ib.innerTSLock.Unlock() + return minStartTS +} + +// PrintLongTimeInternalTxn print the internal transaction information. +// runByFunction true means the transaction is run by `RunInNewTxn`, +// +// false means the transaction is run by internal session. +func PrintLongTimeInternalTxn(now time.Time, startTS uint64, runByFunction bool) { + if startTS > 0 { + innerTxnStartTime := oracle.GetTimeFromTS(startTS) + if now.Sub(innerTxnStartTime) > TimeToPrintLongTimeInternalTxn { + callerName := "internal session" + if runByFunction { + callerName = "RunInNewTxn" + } + infoHeader := fmt.Sprintf("An internal transaction running by %s lasts long time", callerName) + + logutil.BgLogger().Info(infoHeader, + zap.Duration("time", now.Sub(innerTxnStartTime)), zap.Uint64("startTS", startTS), + zap.Time("start time", innerTxnStartTime)) + } + } +} + +// RunInNewTxn will run the f in a new transaction environment, should be used by inner txn only. +func RunInNewTxn(ctx context.Context, store Storage, retryable bool, f func(ctx context.Context, txn Transaction) error) error { + var ( + err error + originalTxnTS uint64 + txn Transaction + ) + + defer func() { + globalInnerTxnTsBox.deleteInnerTxnTS(originalTxnTS) + }() + + for i := uint(0); i < MaxRetryCnt; i++ { + txn, err = store.Begin() + if err != nil { + logutil.BgLogger().Error("RunInNewTxn", zap.Error(err)) + return err + } + setRequestSourceForInnerTxn(ctx, txn) + + // originalTxnTS is used to trace the original transaction when the function is retryable. + if i == 0 { + originalTxnTS = txn.StartTS() + globalInnerTxnTsBox.storeInnerTxnTS(originalTxnTS) + } + + err = f(ctx, txn) + if err != nil { + err1 := txn.Rollback() + terror.Log(err1) + if retryable && IsTxnRetryableError(err) { + logutil.BgLogger().Warn("RunInNewTxn", + zap.Uint64("retry txn", txn.StartTS()), + zap.Uint64("original txn", originalTxnTS), + zap.Error(err)) + continue + } + return err + } + + failpoint.Inject("mockCommitErrorInNewTxn", func(val failpoint.Value) { + if v := val.(string); len(v) > 0 { + switch v { + case "retry_once": + //nolint:noloopclosure + if i == 0 { + err = ErrTxnRetryable + } + case "no_retry": + failpoint.Return(errors.New("mock commit error")) + } + } + }) + + if err == nil { + err = txn.Commit(ctx) + if err == nil { + break + } + } + if retryable && IsTxnRetryableError(err) { + logutil.BgLogger().Warn("RunInNewTxn", + zap.Uint64("retry txn", txn.StartTS()), + zap.Uint64("original txn", originalTxnTS), + zap.Error(err)) + BackOff(i) + continue + } + return err + } + return err +} + +var ( + // MaxRetryCnt represents maximum retry times. + MaxRetryCnt uint = 100 + // retryBackOffBase is the initial duration, in microsecond, a failed transaction stays dormancy before it retries + retryBackOffBase = 1 + // retryBackOffCap is the max amount of duration, in microsecond, a failed transaction stays dormancy before it retries + retryBackOffCap = 100 +) + +// BackOff Implements exponential backoff with full jitter. 
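+// The upper bound for attempt i is min(retryBackOffCap, retryBackOffBase*2^i),
+// and the actual sleep duration is drawn uniformly at random below that bound.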
+// Returns real back off time in microsecond. +// See http://www.awsarchitectureblog.com/2015/03/backoff.html. +func BackOff(attempts uint) int { + upper := int(math.Min(float64(retryBackOffCap), float64(retryBackOffBase)*math.Pow(2.0, float64(attempts)))) + sleep := time.Duration(rand.Intn(upper)) * time.Millisecond // #nosec G404 + time.Sleep(sleep) + return int(sleep) +} + +func setRequestSourceForInnerTxn(ctx context.Context, txn Transaction) { + if source := ctx.Value(RequestSourceKey); source != nil { + requestSource := source.(RequestSource) + if requestSource.RequestSourceType != "" { + if !requestSource.RequestSourceInternal { + logutil.Logger(ctx).Warn("`RunInNewTxn` should be used by inner txn only") + } + txn.SetOption(RequestSourceInternal, requestSource.RequestSourceInternal) + txn.SetOption(RequestSourceType, requestSource.RequestSourceType) + if requestSource.ExplicitRequestSourceType != "" { + txn.SetOption(ExplicitRequestSourceType, requestSource.ExplicitRequestSourceType) + } + return + } + } + // panic in test mode in case there are requests without source in the future. + // log warnings in production mode. + if intest.InTest { + panic("unexpected no source type context, if you see this error, " + + "the `RequestSourceTypeKey` is missing in your context") + } + logutil.Logger(ctx).Warn("unexpected no source type context, if you see this warning, " + + "the `RequestSourceTypeKey` is missing in the context") +} + +// SetTxnResourceGroup update the resource group name of target txn. +func SetTxnResourceGroup(txn Transaction, name string) { + txn.SetOption(ResourceGroupName, name) + failpoint.Inject("TxnResourceGroupChecker", func(val failpoint.Value) { + expectedRgName := val.(string) + validateRNameInterceptor := func(next interceptor.RPCInterceptorFunc) interceptor.RPCInterceptorFunc { + return func(target string, req *tikvrpc.Request) (*tikvrpc.Response, error) { + var rgName *string + switch r := req.Req.(type) { + case *kvrpcpb.PrewriteRequest: + rgName = &r.Context.ResourceControlContext.ResourceGroupName + case *kvrpcpb.CommitRequest: + rgName = &r.Context.ResourceControlContext.ResourceGroupName + case *kvrpcpb.PessimisticLockRequest: + rgName = &r.Context.ResourceControlContext.ResourceGroupName + } + if rgName != nil && *rgName != expectedRgName { + panic(fmt.Sprintf("resource group name not match, expected: %s, actual: %s", expectedRgName, *rgName)) + } + return next(target, req) + } + } + txn.SetOption(RPCInterceptor, interceptor.NewRPCInterceptor("test-validate-rg-name", validateRNameInterceptor)) + }) +} diff --git a/pkg/lightning/backend/backend.go b/pkg/lightning/backend/backend.go index 878b556c7e460..ddc56c94665c1 100644 --- a/pkg/lightning/backend/backend.go +++ b/pkg/lightning/backend/backend.go @@ -267,7 +267,7 @@ func (be EngineManager) OpenEngine( logger.Info("open engine") - failpoint.Inject("FailIfEngineCountExceeds", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("FailIfEngineCountExceeds")); _err_ == nil { if m, ok := metric.FromContext(ctx); ok { closedCounter := m.ImporterEngineCounter.WithLabelValues("closed") openCounter := m.ImporterEngineCounter.WithLabelValues("open") @@ -280,7 +280,7 @@ func (be EngineManager) OpenEngine( openCount, closedCount, injectValue)) } } - }) + } return &OpenedEngine{ engine: engine{ diff --git a/pkg/lightning/backend/backend.go__failpoint_stash__ b/pkg/lightning/backend/backend.go__failpoint_stash__ new file mode 100644 index 0000000000000..878b556c7e460 --- /dev/null +++ 
b/pkg/lightning/backend/backend.go__failpoint_stash__ @@ -0,0 +1,439 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package backend + +import ( + "context" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/lightning/backend/encode" + "github.com/pingcap/tidb/pkg/lightning/checkpoints" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/lightning/metric" + "github.com/pingcap/tidb/pkg/lightning/mydump" + "github.com/pingcap/tidb/pkg/parser/model" + "go.uber.org/zap" +) + +const ( + importMaxRetryTimes = 3 // tikv-importer has done retry internally. so we don't retry many times. +) + +func makeTag(tableName string, engineID int64) string { + return fmt.Sprintf("%s:%d", tableName, engineID) +} + +func makeLogger(logger log.Logger, tag string, engineUUID uuid.UUID) log.Logger { + return logger.With( + zap.String("engineTag", tag), + zap.Stringer("engineUUID", engineUUID), + ) +} + +// MakeUUID generates a UUID for the engine and a tag for the engine. +func MakeUUID(tableName string, engineID int64) (string, uuid.UUID) { + tag := makeTag(tableName, engineID) + engineUUID := uuid.NewSHA1(engineNamespace, []byte(tag)) + return tag, engineUUID +} + +var engineNamespace = uuid.MustParse("d68d6abe-c59e-45d6-ade8-e2b0ceb7bedf") + +// EngineFileSize represents the size of an engine on disk and in memory. +type EngineFileSize struct { + // UUID is the engine's UUID. + UUID uuid.UUID + // DiskSize is the estimated total file size on disk right now. + DiskSize int64 + // MemSize is the total memory size used by the engine. This is the + // estimated additional size saved onto disk after calling Flush(). + MemSize int64 + // IsImporting indicates whether the engine performing Import(). + IsImporting bool +} + +// LocalWriterConfig defines the configuration to open a LocalWriter +type LocalWriterConfig struct { + // Local backend specified configuration + Local struct { + // is the chunk KV written to this LocalWriter sent in order + IsKVSorted bool + // MemCacheSize specifies the estimated memory cache limit used by this local + // writer. It has higher priority than BackendConfig.LocalWriterMemCacheSize if + // set. + MemCacheSize int64 + } + // TiDB backend specified configuration + TiDB struct { + TableName string + } +} + +// EngineConfig defines configuration used for open engine +type EngineConfig struct { + // TableInfo is the corresponding tidb table info + TableInfo *checkpoints.TidbTableInfo + // local backend specified configuration + Local LocalEngineConfig + // local backend external engine specified configuration + External *ExternalEngineConfig + // KeepSortDir indicates whether to keep the temporary sort directory + // when opening the engine, instead of removing it. + KeepSortDir bool + // TS is the preset timestamp of data in the engine. 
When it's 0, the used TS + // will be set lazily. + TS uint64 +} + +// LocalEngineConfig is the configuration used for local backend in OpenEngine. +type LocalEngineConfig struct { + // compact small SSTs before ingest into pebble + Compact bool + // raw kvs size threshold to trigger compact + CompactThreshold int64 + // compact routine concurrency + CompactConcurrency int + + // blocksize + BlockSize int +} + +// ExternalEngineConfig is the configuration used for local backend external engine. +type ExternalEngineConfig struct { + StorageURI string + DataFiles []string + StatFiles []string + StartKey []byte + EndKey []byte + SplitKeys [][]byte + RegionSplitSize int64 + // TotalFileSize can be an estimated value. + TotalFileSize int64 + // TotalKVCount can be an estimated value. + TotalKVCount int64 + CheckHotspot bool +} + +// CheckCtx contains all parameters used in CheckRequirements +type CheckCtx struct { + DBMetas []*mydump.MDDatabaseMeta +} + +// TargetInfoGetter defines the interfaces to get target information. +type TargetInfoGetter interface { + // FetchRemoteDBModels obtains the models of all databases. Currently, only + // the database name is filled. + FetchRemoteDBModels(ctx context.Context) ([]*model.DBInfo, error) + + // FetchRemoteTableModels obtains the models of all tables given the schema + // name. The returned table info does not need to be precise if the encoder, + // is not requiring them, but must at least fill in the following fields for + // TablesFromMeta to succeed: + // - Name + // - State (must be model.StatePublic) + // - ID + // - Columns + // * Name + // * State (must be model.StatePublic) + // * Offset (must be 0, 1, 2, ...) + // - PKIsHandle (true = do not generate _tidb_rowid) + FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) + + // CheckRequirements performs the check whether the backend satisfies the version requirements + CheckRequirements(ctx context.Context, checkCtx *CheckCtx) error +} + +// Backend defines the interface for a backend. +// Implementations of this interface must be goroutine safe: you can share an +// instance and execute any method anywhere. +// Usual workflow: +// 1. Create a `Backend` for the whole process. +// 2. For each table, +// i. Split into multiple "batches" consisting of data files with roughly equal total size. +// ii. For each batch, +// a. Create an `OpenedEngine` via `backend.OpenEngine()` +// b. For each chunk, deliver data into the engine via `engine.WriteRows()` +// c. When all chunks are written, obtain a `ClosedEngine` via `engine.Close()` +// d. Import data via `engine.Import()` +// e. Cleanup via `engine.Cleanup()` +// 3. Close the connection via `backend.Close()` +type Backend interface { + // Close the connection to the backend. + Close() + + // RetryImportDelay returns the duration to sleep when retrying an import + RetryImportDelay() time.Duration + + // ShouldPostProcess returns whether KV-specific post-processing should be + // performed for this backend. Post-processing includes checksum and analyze. + ShouldPostProcess() bool + + OpenEngine(ctx context.Context, config *EngineConfig, engineUUID uuid.UUID) error + + CloseEngine(ctx context.Context, config *EngineConfig, engineUUID uuid.UUID) error + + // ImportEngine imports engine data to the backend. If it returns ErrDuplicateDetected, + // it means there is duplicate detected. For this situation, all data in the engine must be imported. + // It's safe to reset or cleanup this engine. 
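+	// Callers such as ClosedEngine.Import retry this call on retryable errors,
+	// so it may be invoked more than once for the same engine UUID.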
+ ImportEngine(ctx context.Context, engineUUID uuid.UUID, regionSplitSize, regionSplitKeys int64) error + + CleanupEngine(ctx context.Context, engineUUID uuid.UUID) error + + // FlushEngine ensures all KV pairs written to an open engine has been + // synchronized, such that kill-9'ing Lightning afterwards and resuming from + // checkpoint can recover the exact same content. + // + // This method is only relevant for local backend, and is no-op for all + // other backends. + FlushEngine(ctx context.Context, engineUUID uuid.UUID) error + + // FlushAllEngines performs FlushEngine on all opened engines. This is a + // very expensive operation and should only be used in some rare situation + // (e.g. preparing to resolve a disk quota violation). + FlushAllEngines(ctx context.Context) error + + // ResetEngine clears all written KV pairs in this opened engine. + ResetEngine(ctx context.Context, engineUUID uuid.UUID) error + + // LocalWriter obtains a thread-local EngineWriter for writing rows into the given engine. + LocalWriter(ctx context.Context, cfg *LocalWriterConfig, engineUUID uuid.UUID) (EngineWriter, error) +} + +// EngineManager is the manager of engines. +// this is a wrapper of Backend, which provides some common methods for managing engines. +// and it has no states, can be created on demand +type EngineManager struct { + backend Backend +} + +type engine struct { + backend Backend + logger log.Logger + uuid uuid.UUID + // id of the engine, used to generate uuid and stored in checkpoint + // for index engine it's -1 + id int32 +} + +// OpenedEngine is an opened engine, allowing data to be written via WriteRows. +// This type is goroutine safe: you can share an instance and execute any method +// anywhere. +type OpenedEngine struct { + engine + tableName string + config *EngineConfig +} + +// MakeEngineManager creates a new Backend from an Backend. +func MakeEngineManager(ab Backend) EngineManager { + return EngineManager{backend: ab} +} + +// OpenEngine opens an engine with the given table name and engine ID. +func (be EngineManager) OpenEngine( + ctx context.Context, + config *EngineConfig, + tableName string, + engineID int32, +) (*OpenedEngine, error) { + tag, engineUUID := MakeUUID(tableName, int64(engineID)) + logger := makeLogger(log.FromContext(ctx), tag, engineUUID) + + if err := be.backend.OpenEngine(ctx, config, engineUUID); err != nil { + return nil, err + } + + if m, ok := metric.FromContext(ctx); ok { + openCounter := m.ImporterEngineCounter.WithLabelValues("open") + openCounter.Inc() + } + + logger.Info("open engine") + + failpoint.Inject("FailIfEngineCountExceeds", func(val failpoint.Value) { + if m, ok := metric.FromContext(ctx); ok { + closedCounter := m.ImporterEngineCounter.WithLabelValues("closed") + openCounter := m.ImporterEngineCounter.WithLabelValues("open") + openCount := metric.ReadCounter(openCounter) + + closedCount := metric.ReadCounter(closedCounter) + if injectValue := val.(int); openCount-closedCount > float64(injectValue) { + panic(fmt.Sprintf( + "forcing failure due to FailIfEngineCountExceeds: %v - %v >= %d", + openCount, closedCount, injectValue)) + } + } + }) + + return &OpenedEngine{ + engine: engine{ + backend: be.backend, + logger: logger, + uuid: engineUUID, + id: engineID, + }, + tableName: tableName, + config: config, + }, nil +} + +// Close the opened engine to prepare it for importing. 
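+// On success the importer's "closed" engine counter metric is incremented.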
+func (engine *OpenedEngine) Close(ctx context.Context) (*ClosedEngine, error) { + closedEngine, err := engine.unsafeClose(ctx, engine.config) + if err == nil { + if m, ok := metric.FromContext(ctx); ok { + m.ImporterEngineCounter.WithLabelValues("closed").Inc() + } + } + return closedEngine, err +} + +// Flush current written data for local backend +func (engine *OpenedEngine) Flush(ctx context.Context) error { + return engine.backend.FlushEngine(ctx, engine.uuid) +} + +// LocalWriter returns a writer that writes to the local backend. +func (engine *OpenedEngine) LocalWriter(ctx context.Context, cfg *LocalWriterConfig) (EngineWriter, error) { + return engine.backend.LocalWriter(ctx, cfg, engine.uuid) +} + +// SetTS sets the TS of the engine. In most cases if the caller wants to specify +// TS it should use the TS field in EngineConfig. This method is only used after +// a ResetEngine. +func (engine *OpenedEngine) SetTS(ts uint64) { + engine.config.TS = ts +} + +// UnsafeCloseEngine closes the engine without first opening it. +// This method is "unsafe" as it does not follow the normal operation sequence +// (Open -> Write -> Close -> Import). This method should only be used when one +// knows via other ways that the engine has already been opened, e.g. when +// resuming from a checkpoint. +func (be EngineManager) UnsafeCloseEngine(ctx context.Context, cfg *EngineConfig, + tableName string, engineID int32) (*ClosedEngine, error) { + tag, engineUUID := MakeUUID(tableName, int64(engineID)) + return be.UnsafeCloseEngineWithUUID(ctx, cfg, tag, engineUUID, engineID) +} + +// UnsafeCloseEngineWithUUID closes the engine without first opening it. +// This method is "unsafe" as it does not follow the normal operation sequence +// (Open -> Write -> Close -> Import). This method should only be used when one +// knows via other ways that the engine has already been opened, e.g. when +// resuming from a checkpoint. +func (be EngineManager) UnsafeCloseEngineWithUUID(ctx context.Context, cfg *EngineConfig, tag string, + engineUUID uuid.UUID, id int32) (*ClosedEngine, error) { + return engine{ + backend: be.backend, + logger: makeLogger(log.FromContext(ctx), tag, engineUUID), + uuid: engineUUID, + id: id, + }.unsafeClose(ctx, cfg) +} + +func (en engine) unsafeClose(ctx context.Context, cfg *EngineConfig) (*ClosedEngine, error) { + task := en.logger.Begin(zap.InfoLevel, "engine close") + err := en.backend.CloseEngine(ctx, cfg, en.uuid) + task.End(zap.ErrorLevel, err) + if err != nil { + return nil, err + } + return &ClosedEngine{engine: en}, nil +} + +// GetID get engine id. +func (en engine) GetID() int32 { + return en.id +} + +func (en engine) GetUUID() uuid.UUID { + return en.uuid +} + +// ClosedEngine represents a closed engine, allowing ingestion into the target. +// This type is goroutine safe: you can share an instance and execute any method +// anywhere. +type ClosedEngine struct { + engine +} + +// NewClosedEngine creates a new ClosedEngine. +func NewClosedEngine(backend Backend, logger log.Logger, uuid uuid.UUID, id int32) *ClosedEngine { + return &ClosedEngine{ + engine: engine{ + backend: backend, + logger: logger, + uuid: uuid, + id: id, + }, + } +} + +// Import the data written to the engine into the target. 
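+// It retries up to importMaxRetryTimes on retryable errors, sleeping
+// RetryImportDelay between attempts; non-retryable errors (including
+// ErrFoundDuplicateKeys) are returned immediately.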
+func (engine *ClosedEngine) Import(ctx context.Context, regionSplitSize, regionSplitKeys int64) error { + var err error + + for i := 0; i < importMaxRetryTimes; i++ { + task := engine.logger.With(zap.Int("retryCnt", i)).Begin(zap.InfoLevel, "import") + err = engine.backend.ImportEngine(ctx, engine.uuid, regionSplitSize, regionSplitKeys) + if !common.IsRetryableError(err) { + if common.ErrFoundDuplicateKeys.Equal(err) { + task.End(zap.WarnLevel, err) + } else { + task.End(zap.ErrorLevel, err) + } + return err + } + task.Warn("import spuriously failed, going to retry again", log.ShortError(err)) + time.Sleep(engine.backend.RetryImportDelay()) + } + + return errors.Annotatef(err, "[%s] import reach max retry %d and still failed", engine.uuid, importMaxRetryTimes) +} + +// Cleanup deletes the intermediate data from target. +func (engine *ClosedEngine) Cleanup(ctx context.Context) error { + task := engine.logger.Begin(zap.InfoLevel, "cleanup") + err := engine.backend.CleanupEngine(ctx, engine.uuid) + task.End(zap.WarnLevel, err) + return err +} + +// Logger returns the logger for the engine. +func (engine *ClosedEngine) Logger() log.Logger { + return engine.logger +} + +// ChunkFlushStatus is the status of a chunk flush. +type ChunkFlushStatus interface { + Flushed() bool +} + +// EngineWriter is the interface for writing data to an engine. +type EngineWriter interface { + AppendRows(ctx context.Context, columnNames []string, rows encode.Rows) error + IsSynced() bool + Close(ctx context.Context) (ChunkFlushStatus, error) +} + +// GetEngineUUID returns the engine UUID. +func (engine *OpenedEngine) GetEngineUUID() uuid.UUID { + return engine.uuid +} diff --git a/pkg/lightning/backend/binding__failpoint_binding__.go b/pkg/lightning/backend/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..6a726164e38ec --- /dev/null +++ b/pkg/lightning/backend/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package backend + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/lightning/backend/external/binding__failpoint_binding__.go b/pkg/lightning/backend/external/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..f0c178da70e4b --- /dev/null +++ b/pkg/lightning/backend/external/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package external + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/lightning/backend/external/byte_reader.go b/pkg/lightning/backend/external/byte_reader.go index d2d46939f6aab..473b51a7ccfe0 100644 --- a/pkg/lightning/backend/external/byte_reader.go +++ b/pkg/lightning/backend/external/byte_reader.go @@ -327,11 +327,11 @@ func (r *byteReader) closeConcurrentReader() (reloadCnt, offsetInOldBuffer int) zap.Int("dropBytes", r.concurrentReader.bufSizePerConc*(len(r.curBuf)-r.curBufIdx)-r.curBufOffset), zap.Int("curBufIdx", r.curBufIdx), ) - failpoint.Inject("assertReloadAtMostOnce", func() { + if _, _err_ := 
failpoint.Eval(_curpkg_("assertReloadAtMostOnce")); _err_ == nil { if r.concurrentReader.reloadCnt > 1 { panic(fmt.Sprintf("reloadCnt is %d", r.concurrentReader.reloadCnt)) } - }) + } r.concurrentReader.largeBufferPool.Destroy() r.concurrentReader.largeBuf = nil r.concurrentReader.now = false diff --git a/pkg/lightning/backend/external/byte_reader.go__failpoint_stash__ b/pkg/lightning/backend/external/byte_reader.go__failpoint_stash__ new file mode 100644 index 0000000000000..d2d46939f6aab --- /dev/null +++ b/pkg/lightning/backend/external/byte_reader.go__failpoint_stash__ @@ -0,0 +1,351 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package external + +import ( + "context" + "fmt" + "io" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/membuf" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/size" + "go.uber.org/zap" +) + +var ( + // ConcurrentReaderBufferSizePerConc is the buffer size for concurrent reader per + // concurrency. + ConcurrentReaderBufferSizePerConc = int(8 * size.MB) + // in readAllData, expected concurrency less than this value will not use + // concurrent reader. + readAllDataConcThreshold = uint64(4) +) + +// byteReader provides structured reading on a byte stream of external storage. +// It can also switch to concurrent reading mode and fetch a larger amount of +// data to improve throughput. +type byteReader struct { + ctx context.Context + storageReader storage.ExternalFileReader + + // curBuf is either smallBuf or concurrentReader.largeBuf. + curBuf [][]byte + curBufIdx int // invariant: 0 <= curBufIdx < len(curBuf) when curBuf contains unread data + curBufOffset int // invariant: 0 <= curBufOffset < len(curBuf[curBufIdx]) if curBufIdx < len(curBuf) + smallBuf []byte + + concurrentReader struct { + largeBufferPool *membuf.Buffer + store storage.ExternalStorage + filename string + concurrency int + bufSizePerConc int + + now bool + expected bool + largeBuf [][]byte + reader *concurrentFileReader + reloadCnt int + } + + logger *zap.Logger +} + +func openStoreReaderAndSeek( + ctx context.Context, + store storage.ExternalStorage, + name string, + initFileOffset uint64, + prefetchSize int, +) (storage.ExternalFileReader, error) { + storageReader, err := store.Open(ctx, name, &storage.ReaderOption{PrefetchSize: prefetchSize}) + if err != nil { + return nil, err + } + _, err = storageReader.Seek(int64(initFileOffset), io.SeekStart) + if err != nil { + return nil, err + } + return storageReader, nil +} + +// newByteReader wraps readNBytes functionality to storageReader. If store and +// filename are also given, this reader can use switchConcurrentMode to switch to +// concurrent reading mode. 
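+// The reader starts in single-buffer mode with a bufSize-byte buffer and
+// performs an initial reload, so construction already surfaces read errors.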
+func newByteReader( + ctx context.Context, + storageReader storage.ExternalFileReader, + bufSize int, +) (r *byteReader, err error) { + defer func() { + if err != nil && r != nil { + _ = r.Close() + } + }() + r = &byteReader{ + ctx: ctx, + storageReader: storageReader, + smallBuf: make([]byte, bufSize), + curBufOffset: 0, + } + r.curBuf = [][]byte{r.smallBuf} + r.logger = logutil.Logger(r.ctx) + return r, r.reload() +} + +func (r *byteReader) enableConcurrentRead( + store storage.ExternalStorage, + filename string, + concurrency int, + bufSizePerConc int, + bufferPool *membuf.Buffer, +) { + r.concurrentReader.store = store + r.concurrentReader.filename = filename + r.concurrentReader.concurrency = concurrency + r.concurrentReader.bufSizePerConc = bufSizePerConc + r.concurrentReader.largeBufferPool = bufferPool +} + +// switchConcurrentMode is used to help implement sortedReader.switchConcurrentMode. +// See the comment of the interface. +func (r *byteReader) switchConcurrentMode(useConcurrent bool) error { + readerFields := &r.concurrentReader + if readerFields.store == nil { + r.logger.Warn("concurrent reader is not enabled, skip switching") + // caller don't need to care about it. + return nil + } + // need to set it before reload() + readerFields.expected = useConcurrent + // concurrent reader will be lazily initialized when reload() + if useConcurrent { + return nil + } + + // no change + if !readerFields.now { + return nil + } + + // rest cases is caller want to turn off concurrent reader. We should turn off + // immediately to release memory. + reloadCnt, offsetInOldBuf := r.closeConcurrentReader() + // here we can assume largeBuf is always fully loaded, because the only exception + // is it's the end of file. When it's the end of the file, caller will see EOF + // and no further switchConcurrentMode should be called. + largeBufSize := readerFields.bufSizePerConc * readerFields.concurrency + delta := int64(offsetInOldBuf + (reloadCnt-1)*largeBufSize) + + if _, err := r.storageReader.Seek(delta, io.SeekCurrent); err != nil { + return err + } + err := r.reload() + if err != nil && err == io.EOF { + // ignore EOF error, let readNBytes handle it + return nil + } + return err +} + +func (r *byteReader) switchToConcurrentReader() error { + // because it will be called only when buffered data of storageReader is used + // up, we can use seek(0, io.SeekCurrent) to get the offset for concurrent + // reader + currOffset, err := r.storageReader.Seek(0, io.SeekCurrent) + if err != nil { + return err + } + fileSize, err := r.storageReader.GetFileSize() + if err != nil { + return err + } + readerFields := &r.concurrentReader + readerFields.reader, err = newConcurrentFileReader( + r.ctx, + readerFields.store, + readerFields.filename, + currOffset, + fileSize, + readerFields.concurrency, + readerFields.bufSizePerConc, + ) + if err != nil { + return err + } + + readerFields.largeBuf = make([][]byte, readerFields.concurrency) + for i := range readerFields.largeBuf { + readerFields.largeBuf[i] = readerFields.largeBufferPool.AllocBytes(readerFields.bufSizePerConc) + if readerFields.largeBuf[i] == nil { + return errors.Errorf("alloc large buffer failed, size %d", readerFields.bufSizePerConc) + } + } + + r.curBuf = readerFields.largeBuf + r.curBufOffset = 0 + readerFields.now = true + return nil +} + +// readNBytes reads the next n bytes from the reader and returns a buffer slice +// containing those bytes. The content of returned slice may be changed after +// next call. 
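+// When the requested bytes span multiple internal buffers they are copied into
+// a freshly allocated auxiliary slice; io.ErrUnexpectedEOF is returned if the
+// stream ends after some bytes have already been read.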
+func (r *byteReader) readNBytes(n int) ([]byte, error) { + if n <= 0 { + return nil, errors.Errorf("illegal n (%d) when reading from external storage", n) + } + if n > int(size.GB) { + return nil, errors.Errorf("read %d bytes from external storage, exceed max limit %d", n, size.GB) + } + + readLen, bs := r.next(n) + if readLen == n && len(bs) == 1 { + return bs[0], nil + } + // need to flatten bs + auxBuf := make([]byte, n) + for _, b := range bs { + copy(auxBuf[len(auxBuf)-n:], b) + n -= len(b) + } + hasRead := readLen > 0 + for n > 0 { + err := r.reload() + switch err { + case nil: + case io.EOF: + // EOF is only allowed when we have not read any data + if hasRead { + return nil, io.ErrUnexpectedEOF + } + return nil, err + default: + return nil, err + } + readLen, bs = r.next(n) + hasRead = hasRead || readLen > 0 + for _, b := range bs { + copy(auxBuf[len(auxBuf)-n:], b) + n -= len(b) + } + } + return auxBuf, nil +} + +func (r *byteReader) next(n int) (int, [][]byte) { + retCnt := 0 + // TODO(lance6716): heap escape performance? + ret := make([][]byte, 0, len(r.curBuf)-r.curBufIdx+1) + for r.curBufIdx < len(r.curBuf) && n > 0 { + cur := r.curBuf[r.curBufIdx] + if r.curBufOffset+n <= len(cur) { + ret = append(ret, cur[r.curBufOffset:r.curBufOffset+n]) + retCnt += n + r.curBufOffset += n + if r.curBufOffset == len(cur) { + r.curBufIdx++ + r.curBufOffset = 0 + } + break + } + ret = append(ret, cur[r.curBufOffset:]) + retCnt += len(cur) - r.curBufOffset + n -= len(cur) - r.curBufOffset + r.curBufIdx++ + r.curBufOffset = 0 + } + + return retCnt, ret +} + +func (r *byteReader) reload() error { + to := r.concurrentReader.expected + now := r.concurrentReader.now + // in read only false -> true is possible + if !now && to { + r.logger.Info("switch reader mode", zap.Bool("use concurrent mode", true)) + err := r.switchToConcurrentReader() + if err != nil { + return err + } + } + + if r.concurrentReader.now { + r.concurrentReader.reloadCnt++ + buffers, err := r.concurrentReader.reader.read(r.concurrentReader.largeBuf) + if err != nil { + return err + } + r.curBuf = buffers + r.curBufIdx = 0 + r.curBufOffset = 0 + return nil + } + // when not using concurrentReader, len(curBuf) == 1 + n, err := io.ReadFull(r.storageReader, r.curBuf[0][0:]) + if err != nil { + switch err { + case io.EOF: + // move curBufIdx so following read will also find EOF + r.curBufIdx = len(r.curBuf) + return err + case io.ErrUnexpectedEOF: + // The last batch. 
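+			// ReadFull got fewer bytes than the buffer holds, so shrink the
+			// buffer to the n bytes actually read before serving them.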
+ r.curBuf[0] = r.curBuf[0][:n] + case context.Canceled: + return err + default: + r.logger.Warn("other error during read", zap.Error(err)) + return err + } + } + r.curBufIdx = 0 + r.curBufOffset = 0 + return nil +} + +func (r *byteReader) closeConcurrentReader() (reloadCnt, offsetInOldBuffer int) { + r.logger.Info("drop data in closeConcurrentReader", + zap.Int("reloadCnt", r.concurrentReader.reloadCnt), + zap.Int("dropBytes", r.concurrentReader.bufSizePerConc*(len(r.curBuf)-r.curBufIdx)-r.curBufOffset), + zap.Int("curBufIdx", r.curBufIdx), + ) + failpoint.Inject("assertReloadAtMostOnce", func() { + if r.concurrentReader.reloadCnt > 1 { + panic(fmt.Sprintf("reloadCnt is %d", r.concurrentReader.reloadCnt)) + } + }) + r.concurrentReader.largeBufferPool.Destroy() + r.concurrentReader.largeBuf = nil + r.concurrentReader.now = false + reloadCnt = r.concurrentReader.reloadCnt + r.concurrentReader.reloadCnt = 0 + r.curBuf = [][]byte{r.smallBuf} + offsetInOldBuffer = r.curBufOffset + r.curBufIdx*r.concurrentReader.bufSizePerConc + r.curBufOffset = 0 + return +} + +func (r *byteReader) Close() error { + if r.concurrentReader.now { + r.closeConcurrentReader() + } + return r.storageReader.Close() +} diff --git a/pkg/lightning/backend/external/engine.go b/pkg/lightning/backend/external/engine.go index 7a42354e20acc..149068f52b76e 100644 --- a/pkg/lightning/backend/external/engine.go +++ b/pkg/lightning/backend/external/engine.go @@ -357,9 +357,9 @@ func (e *Engine) LoadIngestData( ) error { // try to make every worker busy for each batch regionBatchSize := e.workerConcurrency - failpoint.Inject("LoadIngestDataBatchSize", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("LoadIngestDataBatchSize")); _err_ == nil { regionBatchSize = val.(int) - }) + } for i := 0; i < len(regionRanges); i += regionBatchSize { err := e.loadBatchRegionData(ctx, regionRanges[i].Start, regionRanges[min(i+regionBatchSize, len(regionRanges))-1].End, outCh) if err != nil { diff --git a/pkg/lightning/backend/external/engine.go__failpoint_stash__ b/pkg/lightning/backend/external/engine.go__failpoint_stash__ new file mode 100644 index 0000000000000..7a42354e20acc --- /dev/null +++ b/pkg/lightning/backend/external/engine.go__failpoint_stash__ @@ -0,0 +1,732 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package external + +import ( + "bytes" + "context" + "sort" + "sync" + "time" + + "github.com/cockroachdb/pebble" + "github.com/docker/go-units" + "github.com/jfcg/sorty/v2" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/membuf" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.uber.org/atomic" + "go.uber.org/zap" +) + +// during test on ks3, we found that we can open about 8000 connections to ks3, +// bigger than that, we might receive "connection reset by peer" error, and +// the read speed will be very slow, still investigating the reason. +// Also open too many connections will take many memory in kernel, and the +// test is based on k8s pod, not sure how it will behave on EC2. +// but, ks3 supporter says there's no such limit on connections. +// And our target for global sort is AWS s3, this default value might not fit well. +// TODO: adjust it according to cloud storage. +const maxCloudStorageConnections = 1000 + +type memKVsAndBuffers struct { + mu sync.Mutex + keys [][]byte + values [][]byte + // memKVBuffers contains two types of buffer, first half are used for small block + // buffer, second half are used for large one. + memKVBuffers []*membuf.Buffer + size int + droppedSize int + + // temporary fields to store KVs to reduce slice allocations. + keysPerFile [][][]byte + valuesPerFile [][][]byte + droppedSizePerFile []int +} + +func (b *memKVsAndBuffers) build(ctx context.Context) { + sumKVCnt := 0 + for _, keys := range b.keysPerFile { + sumKVCnt += len(keys) + } + b.droppedSize = 0 + for _, size := range b.droppedSizePerFile { + b.droppedSize += size + } + b.droppedSizePerFile = nil + + logutil.Logger(ctx).Info("building memKVsAndBuffers", + zap.Int("sumKVCnt", sumKVCnt), + zap.Int("droppedSize", b.droppedSize)) + + b.keys = make([][]byte, 0, sumKVCnt) + b.values = make([][]byte, 0, sumKVCnt) + for i := range b.keysPerFile { + b.keys = append(b.keys, b.keysPerFile[i]...) + b.keysPerFile[i] = nil + b.values = append(b.values, b.valuesPerFile[i]...) + b.valuesPerFile[i] = nil + } + b.keysPerFile = nil + b.valuesPerFile = nil +} + +// Engine stored sorted key/value pairs in an external storage. +type Engine struct { + storage storage.ExternalStorage + dataFiles []string + statsFiles []string + startKey []byte + endKey []byte + splitKeys [][]byte + regionSplitSize int64 + smallBlockBufPool *membuf.Pool + largeBlockBufPool *membuf.Pool + + memKVsAndBuffers memKVsAndBuffers + + // checkHotspot is true means we will check hotspot file when using MergeKVIter. + // if hotspot file is detected, we will use multiple readers to read data. + // if it's false, MergeKVIter will read each file using 1 reader. 
+ // this flag also affects the strategy of loading data, either: + // less load routine + check and read hotspot file concurrently (add-index uses this one) + // more load routine + read each file using 1 reader (import-into uses this one) + checkHotspot bool + mergerIterConcurrency int + + keyAdapter common.KeyAdapter + duplicateDetection bool + duplicateDB *pebble.DB + dupDetectOpt common.DupDetectOpt + workerConcurrency int + ts uint64 + + totalKVSize int64 + totalKVCount int64 + + importedKVSize *atomic.Int64 + importedKVCount *atomic.Int64 +} + +const ( + memLimit = 12 * units.GiB + smallBlockSize = units.MiB +) + +// NewExternalEngine creates an (external) engine. +func NewExternalEngine( + storage storage.ExternalStorage, + dataFiles []string, + statsFiles []string, + startKey []byte, + endKey []byte, + splitKeys [][]byte, + regionSplitSize int64, + keyAdapter common.KeyAdapter, + duplicateDetection bool, + duplicateDB *pebble.DB, + dupDetectOpt common.DupDetectOpt, + workerConcurrency int, + ts uint64, + totalKVSize int64, + totalKVCount int64, + checkHotspot bool, +) common.Engine { + memLimiter := membuf.NewLimiter(memLimit) + return &Engine{ + storage: storage, + dataFiles: dataFiles, + statsFiles: statsFiles, + startKey: startKey, + endKey: endKey, + splitKeys: splitKeys, + regionSplitSize: regionSplitSize, + smallBlockBufPool: membuf.NewPool( + membuf.WithBlockNum(0), + membuf.WithPoolMemoryLimiter(memLimiter), + membuf.WithBlockSize(smallBlockSize), + ), + largeBlockBufPool: membuf.NewPool( + membuf.WithBlockNum(0), + membuf.WithPoolMemoryLimiter(memLimiter), + membuf.WithBlockSize(ConcurrentReaderBufferSizePerConc), + ), + checkHotspot: checkHotspot, + keyAdapter: keyAdapter, + duplicateDetection: duplicateDetection, + duplicateDB: duplicateDB, + dupDetectOpt: dupDetectOpt, + workerConcurrency: workerConcurrency, + ts: ts, + totalKVSize: totalKVSize, + totalKVCount: totalKVCount, + importedKVSize: atomic.NewInt64(0), + importedKVCount: atomic.NewInt64(0), + } +} + +func split[T any](in []T, groupNum int) [][]T { + if len(in) == 0 { + return nil + } + if groupNum <= 0 { + groupNum = 1 + } + ceil := (len(in) + groupNum - 1) / groupNum + ret := make([][]T, 0, groupNum) + l := len(in) + for i := 0; i < l; i += ceil { + if i+ceil > l { + ret = append(ret, in[i:]) + } else { + ret = append(ret, in[i:i+ceil]) + } + } + return ret +} + +func (e *Engine) getAdjustedConcurrency() int { + if e.checkHotspot { + // estimate we will open at most 8000 files, so if e.dataFiles is small we can + // try to concurrently process ranges. + adjusted := maxCloudStorageConnections / len(e.dataFiles) + if adjusted == 0 { + return 1 + } + return min(adjusted, 8) + } + adjusted := min(e.workerConcurrency, maxCloudStorageConnections/len(e.dataFiles)) + return max(adjusted, 1) +} + +func getFilesReadConcurrency( + ctx context.Context, + storage storage.ExternalStorage, + statsFiles []string, + startKey, endKey []byte, +) ([]uint64, []uint64, error) { + result := make([]uint64, len(statsFiles)) + offsets, err := seekPropsOffsets(ctx, []kv.Key{startKey, endKey}, statsFiles, storage) + if err != nil { + return nil, nil, err + } + startOffs, endOffs := offsets[0], offsets[1] + for i := range statsFiles { + expectedConc := (endOffs[i] - startOffs[i]) / uint64(ConcurrentReaderBufferSizePerConc) + // let the stat internals cover the [startKey, endKey) since seekPropsOffsets + // always return an offset that is less than or equal to the key. 
+ expectedConc += 1 + // readAllData will enable concurrent read and use large buffer if result[i] > 1 + // when expectedConc < readAllDataConcThreshold, we don't use concurrent read to + // reduce overhead + if expectedConc >= readAllDataConcThreshold { + result[i] = expectedConc + } else { + result[i] = 1 + } + // only log for files with expected concurrency > 1, to avoid too many logs + if expectedConc > 1 { + logutil.Logger(ctx).Info("found hotspot file in getFilesReadConcurrency", + zap.String("filename", statsFiles[i]), + zap.Uint64("startOffset", startOffs[i]), + zap.Uint64("endOffset", endOffs[i]), + zap.Uint64("expectedConc", expectedConc), + zap.Uint64("concurrency", result[i]), + ) + } + } + return result, startOffs, nil +} + +func (e *Engine) loadBatchRegionData(ctx context.Context, startKey, endKey []byte, outCh chan<- common.DataAndRange) error { + readAndSortRateHist := metrics.GlobalSortReadFromCloudStorageRate.WithLabelValues("read_and_sort") + readAndSortDurHist := metrics.GlobalSortReadFromCloudStorageDuration.WithLabelValues("read_and_sort") + readRateHist := metrics.GlobalSortReadFromCloudStorageRate.WithLabelValues("read") + readDurHist := metrics.GlobalSortReadFromCloudStorageDuration.WithLabelValues("read") + sortRateHist := metrics.GlobalSortReadFromCloudStorageRate.WithLabelValues("sort") + sortDurHist := metrics.GlobalSortReadFromCloudStorageDuration.WithLabelValues("sort") + + readStart := time.Now() + readDtStartKey := e.keyAdapter.Encode(nil, startKey, common.MinRowID) + readDtEndKey := e.keyAdapter.Encode(nil, endKey, common.MinRowID) + err := readAllData( + ctx, + e.storage, + e.dataFiles, + e.statsFiles, + readDtStartKey, + readDtEndKey, + e.smallBlockBufPool, + e.largeBlockBufPool, + &e.memKVsAndBuffers, + ) + if err != nil { + return err + } + e.memKVsAndBuffers.build(ctx) + + readSecond := time.Since(readStart).Seconds() + readDurHist.Observe(readSecond) + logutil.Logger(ctx).Info("reading external storage in loadBatchRegionData", + zap.Duration("cost time", time.Since(readStart)), + zap.Int("droppedSize", e.memKVsAndBuffers.droppedSize)) + + sortStart := time.Now() + oldSortyGor := sorty.MaxGor + sorty.MaxGor = uint64(e.workerConcurrency * 2) + sorty.Sort(len(e.memKVsAndBuffers.keys), func(i, k, r, s int) bool { + if bytes.Compare(e.memKVsAndBuffers.keys[i], e.memKVsAndBuffers.keys[k]) < 0 { // strict comparator like < or > + if r != s { + e.memKVsAndBuffers.keys[r], e.memKVsAndBuffers.keys[s] = e.memKVsAndBuffers.keys[s], e.memKVsAndBuffers.keys[r] + e.memKVsAndBuffers.values[r], e.memKVsAndBuffers.values[s] = e.memKVsAndBuffers.values[s], e.memKVsAndBuffers.values[r] + } + return true + } + return false + }) + sorty.MaxGor = oldSortyGor + sortSecond := time.Since(sortStart).Seconds() + sortDurHist.Observe(sortSecond) + logutil.Logger(ctx).Info("sorting in loadBatchRegionData", + zap.Duration("cost time", time.Since(sortStart))) + + readAndSortSecond := time.Since(readStart).Seconds() + readAndSortDurHist.Observe(readAndSortSecond) + + size := e.memKVsAndBuffers.size + readAndSortRateHist.Observe(float64(size) / 1024.0 / 1024.0 / readAndSortSecond) + readRateHist.Observe(float64(size) / 1024.0 / 1024.0 / readSecond) + sortRateHist.Observe(float64(size) / 1024.0 / 1024.0 / sortSecond) + + data := e.buildIngestData( + e.memKVsAndBuffers.keys, + e.memKVsAndBuffers.values, + e.memKVsAndBuffers.memKVBuffers, + ) + + // release the reference of e.memKVsAndBuffers + e.memKVsAndBuffers.keys = nil + e.memKVsAndBuffers.values = nil + 
e.memKVsAndBuffers.memKVBuffers = nil + e.memKVsAndBuffers.size = 0 + + sendFn := func(dr common.DataAndRange) error { + select { + case <-ctx.Done(): + return ctx.Err() + case outCh <- dr: + } + return nil + } + return sendFn(common.DataAndRange{ + Data: data, + Range: common.Range{ + Start: startKey, + End: endKey, + }, + }) +} + +// LoadIngestData loads the data from the external storage to memory in [start, +// end) range, so local backend can ingest it. The used byte slice of ingest data +// are allocated from Engine.bufPool and must be released by +// MemoryIngestData.DecRef(). +func (e *Engine) LoadIngestData( + ctx context.Context, + regionRanges []common.Range, + outCh chan<- common.DataAndRange, +) error { + // try to make every worker busy for each batch + regionBatchSize := e.workerConcurrency + failpoint.Inject("LoadIngestDataBatchSize", func(val failpoint.Value) { + regionBatchSize = val.(int) + }) + for i := 0; i < len(regionRanges); i += regionBatchSize { + err := e.loadBatchRegionData(ctx, regionRanges[i].Start, regionRanges[min(i+regionBatchSize, len(regionRanges))-1].End, outCh) + if err != nil { + return err + } + } + return nil +} + +func (e *Engine) buildIngestData(keys, values [][]byte, buf []*membuf.Buffer) *MemoryIngestData { + return &MemoryIngestData{ + keyAdapter: e.keyAdapter, + duplicateDetection: e.duplicateDetection, + duplicateDB: e.duplicateDB, + dupDetectOpt: e.dupDetectOpt, + keys: keys, + values: values, + ts: e.ts, + memBuf: buf, + refCnt: atomic.NewInt64(0), + importedKVSize: e.importedKVSize, + importedKVCount: e.importedKVCount, + } +} + +// LargeRegionSplitDataThreshold is exposed for test. +var LargeRegionSplitDataThreshold = int(config.SplitRegionSize) + +// KVStatistics returns the total kv size and total kv count. +func (e *Engine) KVStatistics() (totalKVSize int64, totalKVCount int64) { + return e.totalKVSize, e.totalKVCount +} + +// ImportedStatistics returns the imported kv size and imported kv count. +func (e *Engine) ImportedStatistics() (importedSize int64, importedKVCount int64) { + return e.importedKVSize.Load(), e.importedKVCount.Load() +} + +// ID is the identifier of an engine. +func (e *Engine) ID() string { + return "external" +} + +// GetKeyRange implements common.Engine. +func (e *Engine) GetKeyRange() (startKey []byte, endKey []byte, err error) { + if _, ok := e.keyAdapter.(common.NoopKeyAdapter); ok { + return e.startKey, e.endKey, nil + } + + // when duplicate detection feature is enabled, the end key comes from + // DupDetectKeyAdapter.Encode or Key.Next(). We try to decode it and check the + // error. + + start, err := e.keyAdapter.Decode(nil, e.startKey) + if err != nil { + return nil, nil, err + } + end, err := e.keyAdapter.Decode(nil, e.endKey) + if err == nil { + return start, end, nil + } + // handle the case that end key is from Key.Next() + if e.endKey[len(e.endKey)-1] != 0 { + return nil, nil, err + } + endEncoded := e.endKey[:len(e.endKey)-1] + end, err = e.keyAdapter.Decode(nil, endEncoded) + if err != nil { + return nil, nil, err + } + return start, kv.Key(end).Next(), nil +} + +// SplitRanges split the ranges by split keys provided by external engine. 
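+// The split keys are first decoded with the engine's keyAdapter, then the
+// [startKey, endKey) span is cut at every split key, yielding len(splitKeys)+1
+// ranges.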
+func (e *Engine) SplitRanges( + startKey, endKey []byte, + _, _ int64, + _ log.Logger, +) ([]common.Range, error) { + splitKeys := e.splitKeys + for i, k := range e.splitKeys { + var err error + splitKeys[i], err = e.keyAdapter.Decode(nil, k) + if err != nil { + return nil, err + } + } + ranges := make([]common.Range, 0, len(splitKeys)+1) + ranges = append(ranges, common.Range{Start: startKey}) + for i := 0; i < len(splitKeys); i++ { + ranges[len(ranges)-1].End = splitKeys[i] + var endK []byte + if i < len(splitKeys)-1 { + endK = splitKeys[i+1] + } + ranges = append(ranges, common.Range{Start: splitKeys[i], End: endK}) + } + ranges[len(ranges)-1].End = endKey + return ranges, nil +} + +// Close implements common.Engine. +func (e *Engine) Close() error { + if e.smallBlockBufPool != nil { + e.smallBlockBufPool.Destroy() + e.smallBlockBufPool = nil + } + if e.largeBlockBufPool != nil { + e.largeBlockBufPool.Destroy() + e.largeBlockBufPool = nil + } + e.storage.Close() + return nil +} + +// Reset resets the memory buffer pool. +func (e *Engine) Reset() error { + memLimiter := membuf.NewLimiter(memLimit) + if e.smallBlockBufPool != nil { + e.smallBlockBufPool.Destroy() + e.smallBlockBufPool = membuf.NewPool( + membuf.WithBlockNum(0), + membuf.WithPoolMemoryLimiter(memLimiter), + membuf.WithBlockSize(smallBlockSize), + ) + } + if e.largeBlockBufPool != nil { + e.largeBlockBufPool.Destroy() + e.largeBlockBufPool = membuf.NewPool( + membuf.WithBlockNum(0), + membuf.WithPoolMemoryLimiter(memLimiter), + membuf.WithBlockSize(ConcurrentReaderBufferSizePerConc), + ) + } + return nil +} + +// MemoryIngestData is the in-memory implementation of IngestData. +type MemoryIngestData struct { + keyAdapter common.KeyAdapter + duplicateDetection bool + duplicateDB *pebble.DB + dupDetectOpt common.DupDetectOpt + + keys [][]byte + values [][]byte + ts uint64 + + memBuf []*membuf.Buffer + refCnt *atomic.Int64 + importedKVSize *atomic.Int64 + importedKVCount *atomic.Int64 +} + +var _ common.IngestData = (*MemoryIngestData)(nil) + +func (m *MemoryIngestData) firstAndLastKeyIndex(lowerBound, upperBound []byte) (int, int) { + firstKeyIdx := 0 + if len(lowerBound) > 0 { + lowerBound = m.keyAdapter.Encode(nil, lowerBound, common.MinRowID) + firstKeyIdx = sort.Search(len(m.keys), func(i int) bool { + return bytes.Compare(lowerBound, m.keys[i]) <= 0 + }) + if firstKeyIdx == len(m.keys) { + return -1, -1 + } + } + + lastKeyIdx := len(m.keys) - 1 + if len(upperBound) > 0 { + upperBound = m.keyAdapter.Encode(nil, upperBound, common.MinRowID) + i := sort.Search(len(m.keys), func(i int) bool { + reverseIdx := len(m.keys) - 1 - i + return bytes.Compare(upperBound, m.keys[reverseIdx]) > 0 + }) + if i == len(m.keys) { + // should not happen + return -1, -1 + } + lastKeyIdx = len(m.keys) - 1 - i + } + return firstKeyIdx, lastKeyIdx +} + +// GetFirstAndLastKey implements IngestData.GetFirstAndLastKey. 
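+// Both bounds are encoded with the keyAdapter before the binary search, and a
+// nil, nil, nil result means no key falls inside the requested range.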
+func (m *MemoryIngestData) GetFirstAndLastKey(lowerBound, upperBound []byte) ([]byte, []byte, error) { + firstKeyIdx, lastKeyIdx := m.firstAndLastKeyIndex(lowerBound, upperBound) + if firstKeyIdx < 0 || firstKeyIdx > lastKeyIdx { + return nil, nil, nil + } + firstKey, err := m.keyAdapter.Decode(nil, m.keys[firstKeyIdx]) + if err != nil { + return nil, nil, err + } + lastKey, err := m.keyAdapter.Decode(nil, m.keys[lastKeyIdx]) + if err != nil { + return nil, nil, err + } + return firstKey, lastKey, nil +} + +type memoryDataIter struct { + keys [][]byte + values [][]byte + + firstKeyIdx int + lastKeyIdx int + curIdx int +} + +// First implements ForwardIter. +func (m *memoryDataIter) First() bool { + if m.firstKeyIdx < 0 { + return false + } + m.curIdx = m.firstKeyIdx + return true +} + +// Valid implements ForwardIter. +func (m *memoryDataIter) Valid() bool { + return m.firstKeyIdx <= m.curIdx && m.curIdx <= m.lastKeyIdx +} + +// Next implements ForwardIter. +func (m *memoryDataIter) Next() bool { + m.curIdx++ + return m.Valid() +} + +// Key implements ForwardIter. +func (m *memoryDataIter) Key() []byte { + return m.keys[m.curIdx] +} + +// Value implements ForwardIter. +func (m *memoryDataIter) Value() []byte { + return m.values[m.curIdx] +} + +// Close implements ForwardIter. +func (m *memoryDataIter) Close() error { + return nil +} + +// Error implements ForwardIter. +func (m *memoryDataIter) Error() error { + return nil +} + +// ReleaseBuf implements ForwardIter. +func (m *memoryDataIter) ReleaseBuf() {} + +type memoryDataDupDetectIter struct { + iter *memoryDataIter + dupDetector *common.DupDetector + err error + curKey, curVal []byte + buf *membuf.Buffer +} + +// First implements ForwardIter. +func (m *memoryDataDupDetectIter) First() bool { + if m.err != nil || !m.iter.First() { + return false + } + m.curKey, m.curVal, m.err = m.dupDetector.Init(m.iter) + return m.Valid() +} + +// Valid implements ForwardIter. +func (m *memoryDataDupDetectIter) Valid() bool { + return m.err == nil && m.iter.Valid() +} + +// Next implements ForwardIter. +func (m *memoryDataDupDetectIter) Next() bool { + if m.err != nil { + return false + } + key, val, ok, err := m.dupDetector.Next(m.iter) + if err != nil { + m.err = err + return false + } + if !ok { + return false + } + m.curKey, m.curVal = key, val + return true +} + +// Key implements ForwardIter. +func (m *memoryDataDupDetectIter) Key() []byte { + return m.buf.AddBytes(m.curKey) +} + +// Value implements ForwardIter. +func (m *memoryDataDupDetectIter) Value() []byte { + return m.buf.AddBytes(m.curVal) +} + +// Close implements ForwardIter. +func (m *memoryDataDupDetectIter) Close() error { + m.buf.Destroy() + return m.dupDetector.Close() +} + +// Error implements ForwardIter. +func (m *memoryDataDupDetectIter) Error() error { + return m.err +} + +// ReleaseBuf implements ForwardIter. +func (m *memoryDataDupDetectIter) ReleaseBuf() { + m.buf.Reset() +} + +// NewIter implements IngestData.NewIter. 
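+// Without duplicate detection it returns a plain memoryDataIter over the
+// already-sorted keys; otherwise the iterator is wrapped in a DupDetector that
+// copies each returned key/value into a buffer taken from bufPool.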
+func (m *MemoryIngestData) NewIter( + ctx context.Context, + lowerBound, upperBound []byte, + bufPool *membuf.Pool, +) common.ForwardIter { + firstKeyIdx, lastKeyIdx := m.firstAndLastKeyIndex(lowerBound, upperBound) + iter := &memoryDataIter{ + keys: m.keys, + values: m.values, + firstKeyIdx: firstKeyIdx, + lastKeyIdx: lastKeyIdx, + } + if !m.duplicateDetection { + return iter + } + logger := log.FromContext(ctx) + detector := common.NewDupDetector(m.keyAdapter, m.duplicateDB.NewBatch(), logger, m.dupDetectOpt) + return &memoryDataDupDetectIter{ + iter: iter, + dupDetector: detector, + buf: bufPool.NewBuffer(), + } +} + +// GetTS implements IngestData.GetTS. +func (m *MemoryIngestData) GetTS() uint64 { + return m.ts +} + +// IncRef implements IngestData.IncRef. +func (m *MemoryIngestData) IncRef() { + m.refCnt.Inc() +} + +// DecRef implements IngestData.DecRef. +func (m *MemoryIngestData) DecRef() { + if m.refCnt.Dec() == 0 { + m.keys = nil + m.values = nil + for _, b := range m.memBuf { + b.Destroy() + } + } +} + +// Finish implements IngestData.Finish. +func (m *MemoryIngestData) Finish(totalBytes, totalCount int64) { + m.importedKVSize.Add(totalBytes) + m.importedKVCount.Add(totalCount) + +} diff --git a/pkg/lightning/backend/external/merge_v2.go b/pkg/lightning/backend/external/merge_v2.go index af32569d6745d..49de736b26ade 100644 --- a/pkg/lightning/backend/external/merge_v2.go +++ b/pkg/lightning/backend/external/merge_v2.go @@ -67,9 +67,9 @@ func MergeOverlappingFilesV2( }() rangesGroupSize := 4 * size.GB - failpoint.Inject("mockRangesGroupSize", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockRangesGroupSize")); _err_ == nil { rangesGroupSize = uint64(val.(int)) - }) + } splitter, err := NewRangeSplitter( ctx, diff --git a/pkg/lightning/backend/external/merge_v2.go__failpoint_stash__ b/pkg/lightning/backend/external/merge_v2.go__failpoint_stash__ new file mode 100644 index 0000000000000..af32569d6745d --- /dev/null +++ b/pkg/lightning/backend/external/merge_v2.go__failpoint_stash__ @@ -0,0 +1,183 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package external + +import ( + "bytes" + "context" + "math" + "time" + + "github.com/jfcg/sorty/v2" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/membuf" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/size" + "go.uber.org/zap" +) + +// MergeOverlappingFilesV2 reads from given files whose key range may overlap +// and writes to new sorted, nonoverlapping files. +// Using 1 readAllData and 1 writer. 
+func MergeOverlappingFilesV2( + ctx context.Context, + multiFileStat []MultipleFilesStat, + store storage.ExternalStorage, + startKey []byte, + endKey []byte, + partSize int64, + newFilePrefix string, + writerID string, + blockSize int, + writeBatchCount uint64, + propSizeDist uint64, + propKeysDist uint64, + onClose OnCloseFunc, + concurrency int, + checkHotspot bool, +) (err error) { + fileCnt := 0 + for _, m := range multiFileStat { + fileCnt += len(m.Filenames) + } + task := log.BeginTask(logutil.Logger(ctx).With( + zap.Int("file-count", fileCnt), + zap.Binary("start-key", startKey), + zap.Binary("end-key", endKey), + zap.String("new-file-prefix", newFilePrefix), + zap.Int("concurrency", concurrency), + ), "merge overlapping files") + defer func() { + task.End(zap.ErrorLevel, err) + }() + + rangesGroupSize := 4 * size.GB + failpoint.Inject("mockRangesGroupSize", func(val failpoint.Value) { + rangesGroupSize = uint64(val.(int)) + }) + + splitter, err := NewRangeSplitter( + ctx, + multiFileStat, + store, + int64(rangesGroupSize), + math.MaxInt64, + int64(4*size.GB), + math.MaxInt64, + ) + if err != nil { + return err + } + + writer := NewWriterBuilder(). + SetMemorySizeLimit(DefaultMemSizeLimit). + SetBlockSize(blockSize). + SetPropKeysDistance(propKeysDist). + SetPropSizeDistance(propSizeDist). + SetOnCloseFunc(onClose). + BuildOneFile(store, newFilePrefix, writerID) + defer func() { + err = splitter.Close() + if err != nil { + logutil.Logger(ctx).Warn("close range splitter failed", zap.Error(err)) + } + err = writer.Close(ctx) + if err != nil { + logutil.Logger(ctx).Warn("close writer failed", zap.Error(err)) + } + }() + + err = writer.Init(ctx, partSize) + if err != nil { + logutil.Logger(ctx).Warn("init writer failed", zap.Error(err)) + return + } + + bufPool := membuf.NewPool() + loaded := &memKVsAndBuffers{} + curStart := kv.Key(startKey).Clone() + var curEnd kv.Key + + for { + endKeyOfGroup, dataFilesOfGroup, statFilesOfGroup, _, err1 := splitter.SplitOneRangesGroup() + if err1 != nil { + logutil.Logger(ctx).Warn("split one ranges group failed", zap.Error(err1)) + return + } + curEnd = kv.Key(endKeyOfGroup).Clone() + if len(endKeyOfGroup) == 0 { + curEnd = kv.Key(endKey).Clone() + } + now := time.Now() + err1 = readAllData( + ctx, + store, + dataFilesOfGroup, + statFilesOfGroup, + curStart, + curEnd, + bufPool, + bufPool, + loaded, + ) + if err1 != nil { + logutil.Logger(ctx).Warn("read all data failed", zap.Error(err1)) + return + } + loaded.build(ctx) + readTime := time.Since(now) + now = time.Now() + sorty.MaxGor = uint64(concurrency) + sorty.Sort(len(loaded.keys), func(i, k, r, s int) bool { + if bytes.Compare(loaded.keys[i], loaded.keys[k]) < 0 { // strict comparator like < or > + if r != s { + loaded.keys[r], loaded.keys[s] = loaded.keys[s], loaded.keys[r] + loaded.values[r], loaded.values[s] = loaded.values[s], loaded.values[r] + } + return true + } + return false + }) + sortTime := time.Since(now) + now = time.Now() + for i, key := range loaded.keys { + err1 = writer.WriteRow(ctx, key, loaded.values[i]) + if err1 != nil { + logutil.Logger(ctx).Warn("write one row to writer failed", zap.Error(err1)) + return + } + } + writeTime := time.Since(now) + logutil.Logger(ctx).Info("sort one group in MergeOverlappingFiles", + zap.Duration("read time", readTime), + zap.Duration("sort time", sortTime), + zap.Duration("write time", writeTime), + zap.Int("key len", len(loaded.keys))) + + curStart = curEnd.Clone() + loaded.keys = nil + loaded.values = nil + loaded.memKVBuffers = nil + 
loaded.size = 0 + + if len(endKeyOfGroup) == 0 { + break + } + } + return +} diff --git a/pkg/lightning/backend/local/binding__failpoint_binding__.go b/pkg/lightning/backend/local/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..1ae5c5095df43 --- /dev/null +++ b/pkg/lightning/backend/local/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package local + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/lightning/backend/local/checksum.go b/pkg/lightning/backend/local/checksum.go index 967e32d37ea46..206717d957393 100644 --- a/pkg/lightning/backend/local/checksum.go +++ b/pkg/lightning/backend/local/checksum.go @@ -241,7 +241,7 @@ func increaseGCLifeTime(ctx context.Context, manager *gcLifeTimeManager, db *sql } } - failpoint.Inject("IncreaseGCUpdateDuration", nil) + failpoint.Eval(_curpkg_("IncreaseGCUpdateDuration")) return nil } diff --git a/pkg/lightning/backend/local/checksum.go__failpoint_stash__ b/pkg/lightning/backend/local/checksum.go__failpoint_stash__ new file mode 100644 index 0000000000000..967e32d37ea46 --- /dev/null +++ b/pkg/lightning/backend/local/checksum.go__failpoint_stash__ @@ -0,0 +1,517 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package local + +import ( + "container/heap" + "context" + "database/sql" + "fmt" + "sync" + "time" + + "github.com/google/uuid" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/checksum" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/lightning/checkpoints" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/lightning/metric" + "github.com/pingcap/tidb/pkg/lightning/verification" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tipb/go-tipb" + tikvstore "github.com/tikv/client-go/v2/kv" + "github.com/tikv/client-go/v2/oracle" + pd "github.com/tikv/pd/client" + pderrs "github.com/tikv/pd/client/errs" + "go.uber.org/atomic" + "go.uber.org/zap" +) + +const ( + preUpdateServiceSafePointFactor = 3 + maxErrorRetryCount = 3 + defaultGCLifeTime = 100 * time.Hour +) + +var ( + serviceSafePointTTL int64 = 10 * 60 // 10 min in seconds + + // MinDistSQLScanConcurrency is the minimum value of tidb_distsql_scan_concurrency. + MinDistSQLScanConcurrency = 4 + + // DefaultBackoffWeight is the default value of tidb_backoff_weight for checksum. + // RegionRequestSender will retry within a maxSleep time, default is 2 * 20 = 40 seconds. + // When TiKV client encounters an error of "region not leader", it will keep + // retrying every 500 ms, if it still fails after maxSleep, it will return "region unavailable". 
+ // When there are many pending compaction bytes, TiKV might not respond within 1m, + // and report "rpcError:wait recvLoop timeout,timeout:1m0s", and retry might + // time out again. + // so we enlarge it to 30 * 20 = 10 minutes. + DefaultBackoffWeight = 15 * tikvstore.DefBackOffWeight +) + +// RemoteChecksum represents a checksum result got from tidb. +type RemoteChecksum struct { + Schema string + Table string + Checksum uint64 + TotalKVs uint64 + TotalBytes uint64 +} + +// IsEqual checks whether the checksum is equal to the other. +func (rc *RemoteChecksum) IsEqual(other *verification.KVChecksum) bool { + return rc.Checksum == other.Sum() && + rc.TotalKVs == other.SumKVS() && + rc.TotalBytes == other.SumSize() +} + +// ChecksumManager is a manager that manages checksums. +type ChecksumManager interface { + Checksum(ctx context.Context, tableInfo *checkpoints.TidbTableInfo) (*RemoteChecksum, error) +} + +// fetch checksum for tidb sql client +type tidbChecksumExecutor struct { + db *sql.DB + manager *gcLifeTimeManager +} + +var _ ChecksumManager = (*tidbChecksumExecutor)(nil) + +// NewTiDBChecksumExecutor creates a new tidb checksum executor. +func NewTiDBChecksumExecutor(db *sql.DB) ChecksumManager { + return &tidbChecksumExecutor{ + db: db, + manager: newGCLifeTimeManager(), + } +} + +func (e *tidbChecksumExecutor) Checksum(ctx context.Context, tableInfo *checkpoints.TidbTableInfo) (*RemoteChecksum, error) { + var err error + if err = e.manager.addOneJob(ctx, e.db); err != nil { + return nil, err + } + + // set it back finally + defer e.manager.removeOneJob(ctx, e.db) + + tableName := common.UniqueTable(tableInfo.DB, tableInfo.Name) + + task := log.FromContext(ctx).With(zap.String("table", tableName)).Begin(zap.InfoLevel, "remote checksum") + + conn, err := e.db.Conn(ctx) + if err != nil { + return nil, errors.Trace(err) + } + defer func() { + if err := conn.Close(); err != nil { + task.Warn("close connection failed", zap.Error(err)) + } + }() + // ADMIN CHECKSUM TABLE
<table>,<table>
example. + // mysql> admin checksum table test.t; + // +---------+------------+---------------------+-----------+-------------+ + // | Db_name | Table_name | Checksum_crc64_xor | Total_kvs | Total_bytes | + // +---------+------------+---------------------+-----------+-------------+ + // | test | t | 8520875019404689597 | 7296873 | 357601387 | + // +---------+------------+---------------------+-----------+-------------+ + backoffWeight, err := common.GetBackoffWeightFromDB(ctx, e.db) + if err == nil && backoffWeight < DefaultBackoffWeight { + task.Info("increase tidb_backoff_weight", zap.Int("original", backoffWeight), zap.Int("new", DefaultBackoffWeight)) + // increase backoff weight + if _, err := conn.ExecContext(ctx, fmt.Sprintf("SET SESSION %s = '%d';", variable.TiDBBackOffWeight, DefaultBackoffWeight)); err != nil { + task.Warn("set tidb_backoff_weight failed", zap.Error(err)) + } else { + defer func() { + if _, err := conn.ExecContext(ctx, fmt.Sprintf("SET SESSION %s = '%d';", variable.TiDBBackOffWeight, backoffWeight)); err != nil { + task.Warn("recover tidb_backoff_weight failed", zap.Error(err)) + } + }() + } + } + + cs := RemoteChecksum{} + err = common.SQLWithRetry{DB: conn, Logger: task.Logger}.QueryRow(ctx, "compute remote checksum", + "ADMIN CHECKSUM TABLE "+tableName, &cs.Schema, &cs.Table, &cs.Checksum, &cs.TotalKVs, &cs.TotalBytes, + ) + dur := task.End(zap.ErrorLevel, err) + if m, ok := metric.FromContext(ctx); ok { + m.ChecksumSecondsHistogram.Observe(dur.Seconds()) + } + if err != nil { + return nil, errors.Trace(err) + } + return &cs, nil +} + +type gcLifeTimeManager struct { + runningJobsLock sync.Mutex + runningJobs int + oriGCLifeTime string +} + +func newGCLifeTimeManager() *gcLifeTimeManager { + // Default values of three member are enough to initialize this struct + return &gcLifeTimeManager{} +} + +// Pre- and post-condition: +// if m.runningJobs == 0, GC life time has not been increased. +// if m.runningJobs > 0, GC life time has been increased. +// m.runningJobs won't be negative(overflow) since index concurrency is relatively small +func (m *gcLifeTimeManager) addOneJob(ctx context.Context, db *sql.DB) error { + m.runningJobsLock.Lock() + defer m.runningJobsLock.Unlock() + + if m.runningJobs == 0 { + oriGCLifeTime, err := obtainGCLifeTime(ctx, db) + if err != nil { + return err + } + m.oriGCLifeTime = oriGCLifeTime + err = increaseGCLifeTime(ctx, m, db) + if err != nil { + return err + } + } + m.runningJobs++ + return nil +} + +// Pre- and post-condition: +// if m.runningJobs == 0, GC life time has been tried to recovered. If this try fails, a warning will be printed. +// if m.runningJobs > 0, GC life time has not been recovered. +// m.runningJobs won't minus to negative since removeOneJob follows a successful addOneJob. 
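As a reduced illustration (with assumed names, not the manager itself) of the first-in/last-out reference counting that addOneJob above and removeOneJob below implement: only the first concurrent job changes the shared GC setting, and only the last finishing job restores it.

package main

import (
	"fmt"
	"sync"
)

type refGuard struct {
	mu      sync.Mutex
	running int
	restore func()
}

func (g *refGuard) acquire(apply func() func()) {
	g.mu.Lock()
	defer g.mu.Unlock()
	if g.running == 0 {
		g.restore = apply() // apply returns how to undo the change
	}
	g.running++
}

func (g *refGuard) release() {
	g.mu.Lock()
	defer g.mu.Unlock()
	g.running--
	if g.running == 0 && g.restore != nil {
		g.restore()
	}
}

func main() {
	g := &refGuard{}
	apply := func() func() {
		fmt.Println("increase GC life time")
		return func() { fmt.Println("restore GC life time") }
	}
	g.acquire(apply)
	g.acquire(apply) // second concurrent job: no change
	g.release()
	g.release() // last job done: restore
}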
+func (m *gcLifeTimeManager) removeOneJob(ctx context.Context, db *sql.DB) { + m.runningJobsLock.Lock() + defer m.runningJobsLock.Unlock() + + m.runningJobs-- + if m.runningJobs == 0 { + err := updateGCLifeTime(ctx, db, m.oriGCLifeTime) + if err != nil { + query := fmt.Sprintf( + "UPDATE mysql.tidb SET VARIABLE_VALUE = '%s' WHERE VARIABLE_NAME = 'tikv_gc_life_time'", + m.oriGCLifeTime, + ) + log.FromContext(ctx).Warn("revert GC lifetime failed, please reset the GC lifetime manually after Lightning completed", + zap.String("query", query), + log.ShortError(err), + ) + } + } +} + +func increaseGCLifeTime(ctx context.Context, manager *gcLifeTimeManager, db *sql.DB) (err error) { + // checksum command usually takes a long time to execute, + // so here need to increase the gcLifeTime for single transaction. + var increaseGCLifeTime bool + if manager.oriGCLifeTime != "" { + ori, err := time.ParseDuration(manager.oriGCLifeTime) + if err != nil { + return errors.Trace(err) + } + if ori < defaultGCLifeTime { + increaseGCLifeTime = true + } + } else { + increaseGCLifeTime = true + } + + if increaseGCLifeTime { + err = updateGCLifeTime(ctx, db, defaultGCLifeTime.String()) + if err != nil { + return err + } + } + + failpoint.Inject("IncreaseGCUpdateDuration", nil) + + return nil +} + +// obtainGCLifeTime obtains the current GC lifetime. +func obtainGCLifeTime(ctx context.Context, db *sql.DB) (string, error) { + var gcLifeTime string + err := common.SQLWithRetry{DB: db, Logger: log.FromContext(ctx)}.QueryRow( + ctx, + "obtain GC lifetime", + "SELECT VARIABLE_VALUE FROM mysql.tidb WHERE VARIABLE_NAME = 'tikv_gc_life_time'", + &gcLifeTime, + ) + return gcLifeTime, err +} + +// updateGCLifeTime updates the current GC lifetime. +func updateGCLifeTime(ctx context.Context, db *sql.DB, gcLifeTime string) error { + sql := common.SQLWithRetry{ + DB: db, + Logger: log.FromContext(ctx).With(zap.String("gcLifeTime", gcLifeTime)), + } + return sql.Exec(ctx, "update GC lifetime", + "UPDATE mysql.tidb SET VARIABLE_VALUE = ? WHERE VARIABLE_NAME = 'tikv_gc_life_time'", + gcLifeTime, + ) +} + +// TiKVChecksumManager is a manager that can compute checksum of a table using TiKV. +type TiKVChecksumManager struct { + client kv.Client + manager gcTTLManager + distSQLScanConcurrency uint + backoffWeight int + resourceGroupName string + explicitRequestSourceType string +} + +var _ ChecksumManager = &TiKVChecksumManager{} + +// NewTiKVChecksumManager return a new tikv checksum manager +func NewTiKVChecksumManager(client kv.Client, pdClient pd.Client, distSQLScanConcurrency uint, backoffWeight int, resourceGroupName, explicitRequestSourceType string) *TiKVChecksumManager { + return &TiKVChecksumManager{ + client: client, + manager: newGCTTLManager(pdClient), + distSQLScanConcurrency: distSQLScanConcurrency, + backoffWeight: backoffWeight, + resourceGroupName: resourceGroupName, + explicitRequestSourceType: explicitRequestSourceType, + } +} + +func (e *TiKVChecksumManager) checksumDB(ctx context.Context, tableInfo *checkpoints.TidbTableInfo, ts uint64) (*RemoteChecksum, error) { + executor, err := checksum.NewExecutorBuilder(tableInfo.Core, ts). + SetConcurrency(e.distSQLScanConcurrency). + SetBackoffWeight(e.backoffWeight). + SetResourceGroupName(e.resourceGroupName). + SetExplicitRequestSourceType(e.explicitRequestSourceType). 
+ Build() + if err != nil { + return nil, errors.Trace(err) + } + + distSQLScanConcurrency := int(e.distSQLScanConcurrency) + for i := 0; i < maxErrorRetryCount; i++ { + _ = executor.Each(func(request *kv.Request) error { + request.Concurrency = distSQLScanConcurrency + return nil + }) + var execRes *tipb.ChecksumResponse + execRes, err = executor.Execute(ctx, e.client, func() {}) + if err == nil { + return &RemoteChecksum{ + Schema: tableInfo.DB, + Table: tableInfo.Name, + Checksum: execRes.Checksum, + TotalBytes: execRes.TotalBytes, + TotalKVs: execRes.TotalKvs, + }, nil + } + + log.FromContext(ctx).Warn("remote checksum failed", zap.String("db", tableInfo.DB), + zap.String("table", tableInfo.Name), zap.Error(err), + zap.Int("concurrency", distSQLScanConcurrency), zap.Int("retry", i)) + + // do not retry context.Canceled error + if !common.IsRetryableError(err) { + break + } + if distSQLScanConcurrency > MinDistSQLScanConcurrency { + distSQLScanConcurrency = max(distSQLScanConcurrency/2, MinDistSQLScanConcurrency) + } + } + + return nil, err +} + +var retryGetTSInterval = time.Second + +// Checksum implements the ChecksumManager interface. +func (e *TiKVChecksumManager) Checksum(ctx context.Context, tableInfo *checkpoints.TidbTableInfo) (*RemoteChecksum, error) { + tbl := common.UniqueTable(tableInfo.DB, tableInfo.Name) + var ( + physicalTS, logicalTS int64 + err error + retryTime int + ) + physicalTS, logicalTS, err = e.manager.pdClient.GetTS(ctx) + for err != nil { + if !pderrs.IsLeaderChange(errors.Cause(err)) { + return nil, errors.Annotate(err, "fetch tso from pd failed") + } + retryTime++ + if retryTime%60 == 0 { + log.FromContext(ctx).Warn("fetch tso from pd failed and retrying", + zap.Int("retryTime", retryTime), + zap.Error(err)) + } + select { + case <-ctx.Done(): + err = ctx.Err() + case <-time.After(retryGetTSInterval): + physicalTS, logicalTS, err = e.manager.pdClient.GetTS(ctx) + } + } + ts := oracle.ComposeTS(physicalTS, logicalTS) + if err := e.manager.addOneJob(ctx, tbl, ts); err != nil { + return nil, errors.Trace(err) + } + defer e.manager.removeOneJob(tbl) + + return e.checksumDB(ctx, tableInfo, ts) +} + +type tableChecksumTS struct { + table string + gcSafeTS uint64 +} + +// following function are for implement `heap.Interface` + +func (m *gcTTLManager) Len() int { + return len(m.tableGCSafeTS) +} + +func (m *gcTTLManager) Less(i, j int) bool { + return m.tableGCSafeTS[i].gcSafeTS < m.tableGCSafeTS[j].gcSafeTS +} + +func (m *gcTTLManager) Swap(i, j int) { + m.tableGCSafeTS[i], m.tableGCSafeTS[j] = m.tableGCSafeTS[j], m.tableGCSafeTS[i] +} + +func (m *gcTTLManager) Push(x any) { + m.tableGCSafeTS = append(m.tableGCSafeTS, x.(*tableChecksumTS)) +} + +func (m *gcTTLManager) Pop() any { + i := m.tableGCSafeTS[len(m.tableGCSafeTS)-1] + m.tableGCSafeTS = m.tableGCSafeTS[:len(m.tableGCSafeTS)-1] + return i +} + +type gcTTLManager struct { + lock sync.Mutex + pdClient pd.Client + // tableGCSafeTS is a binary heap that stored active checksum jobs GC safe point ts + tableGCSafeTS []*tableChecksumTS + currentTS uint64 + serviceID string + // 0 for not start, otherwise started + started atomic.Bool +} + +func newGCTTLManager(pdClient pd.Client) gcTTLManager { + return gcTTLManager{ + pdClient: pdClient, + serviceID: fmt.Sprintf("lightning-%s", uuid.New()), + } +} + +func (m *gcTTLManager) addOneJob(ctx context.Context, table string, ts uint64) error { + // start gc ttl loop if not started yet. 
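A minimal sketch of the retry pattern used by checksumDB above, assuming a placeholder doChecksum and made-up numbers: each retryable failure halves the scan concurrency, but never below a fixed floor.

package main

import (
	"errors"
	"fmt"
)

const (
	maxRetries     = 3
	minConcurrency = 4
)

// doChecksum is a placeholder: pretend attempts time out under high concurrency.
func doChecksum(concurrency int) error {
	if concurrency > 8 {
		return errors.New("timeout")
	}
	return nil
}

func main() {
	concurrency := 15
	var err error
	for i := 0; i < maxRetries; i++ {
		if err = doChecksum(concurrency); err == nil {
			fmt.Println("checksum ok with concurrency", concurrency)
			return
		}
		fmt.Println("retrying, halving concurrency from", concurrency)
		concurrency = max(concurrency/2, minConcurrency)
	}
	fmt.Println("checksum failed:", err)
}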
+ if m.started.CompareAndSwap(false, true) { + m.start(ctx) + } + m.lock.Lock() + defer m.lock.Unlock() + var curTS uint64 + if len(m.tableGCSafeTS) > 0 { + curTS = m.tableGCSafeTS[0].gcSafeTS + } + m.Push(&tableChecksumTS{table: table, gcSafeTS: ts}) + heap.Fix(m, len(m.tableGCSafeTS)-1) + m.currentTS = m.tableGCSafeTS[0].gcSafeTS + if curTS == 0 || m.currentTS < curTS { + return m.doUpdateGCTTL(ctx, m.currentTS) + } + return nil +} + +func (m *gcTTLManager) removeOneJob(table string) { + m.lock.Lock() + defer m.lock.Unlock() + idx := -1 + for i := 0; i < len(m.tableGCSafeTS); i++ { + if m.tableGCSafeTS[i].table == table { + idx = i + break + } + } + + if idx >= 0 { + l := len(m.tableGCSafeTS) + m.tableGCSafeTS[idx] = m.tableGCSafeTS[l-1] + m.tableGCSafeTS = m.tableGCSafeTS[:l-1] + if l > 1 && idx < l-1 { + heap.Fix(m, idx) + } + } + + var newTS uint64 + if len(m.tableGCSafeTS) > 0 { + newTS = m.tableGCSafeTS[0].gcSafeTS + } + m.currentTS = newTS +} + +func (m *gcTTLManager) updateGCTTL(ctx context.Context) error { + m.lock.Lock() + currentTS := m.currentTS + m.lock.Unlock() + return m.doUpdateGCTTL(ctx, currentTS) +} + +func (m *gcTTLManager) doUpdateGCTTL(ctx context.Context, ts uint64) error { + log.FromContext(ctx).Debug("update PD safePoint limit with TTL", + zap.Uint64("currnet_ts", ts)) + var err error + if ts > 0 { + _, err = m.pdClient.UpdateServiceGCSafePoint(ctx, + m.serviceID, serviceSafePointTTL, ts) + } + return err +} + +func (m *gcTTLManager) start(ctx context.Context) { + // It would be OK since TTL won't be zero, so gapTime should > `0. + updateGapTime := time.Duration(serviceSafePointTTL) * time.Second / preUpdateServiceSafePointFactor + + updateTick := time.NewTicker(updateGapTime) + + updateGCTTL := func() { + if err := m.updateGCTTL(ctx); err != nil { + log.FromContext(ctx).Warn("failed to update service safe point, checksum may fail if gc triggered", zap.Error(err)) + } + } + + // trigger a service gc ttl at start + updateGCTTL() + go func() { + defer updateTick.Stop() + for { + select { + case <-ctx.Done(): + log.FromContext(ctx).Info("service safe point keeper exited") + return + case <-updateTick.C: + updateGCTTL() + } + } + }() +} diff --git a/pkg/lightning/backend/local/engine.go b/pkg/lightning/backend/local/engine.go index 9374b2ade74ff..a6bb4c8c7c887 100644 --- a/pkg/lightning/backend/local/engine.go +++ b/pkg/lightning/backend/local/engine.go @@ -1026,9 +1026,9 @@ func (e *Engine) GetFirstAndLastKey(lowerBound, upperBound []byte) ([]byte, []by LowerBound: lowerBound, UpperBound: upperBound, } - failpoint.Inject("mockGetFirstAndLastKey", func() { - failpoint.Return(lowerBound, upperBound, nil) - }) + if _, _err_ := failpoint.Eval(_curpkg_("mockGetFirstAndLastKey")); _err_ == nil { + return lowerBound, upperBound, nil + } iter := e.newKVIter(context.Background(), opt, nil) //nolint: errcheck @@ -1332,13 +1332,13 @@ func (w *Writer) flushKVs(ctx context.Context) error { return errors.Trace(err) } - failpoint.Inject("orphanWriterGoRoutine", func() { + if _, _err_ := failpoint.Eval(_curpkg_("orphanWriterGoRoutine")); _err_ == nil { _ = common.KillMySelf() // mimic we meet context cancel error when `addSST` <-ctx.Done() time.Sleep(5 * time.Second) - failpoint.Return(errors.Trace(ctx.Err())) - }) + return errors.Trace(ctx.Err()) + } err = w.addSST(ctx, meta) if err != nil { diff --git a/pkg/lightning/backend/local/engine.go__failpoint_stash__ b/pkg/lightning/backend/local/engine.go__failpoint_stash__ new file mode 100644 index 0000000000000..9374b2ade74ff --- 
/dev/null +++ b/pkg/lightning/backend/local/engine.go__failpoint_stash__ @@ -0,0 +1,1682 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package local + +import ( + "bytes" + "container/heap" + "context" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "slices" + "sync" + "time" + "unsafe" + + "github.com/cockroachdb/pebble" + "github.com/cockroachdb/pebble/objstorage/objstorageprovider" + "github.com/cockroachdb/pebble/sstable" + "github.com/cockroachdb/pebble/vfs" + "github.com/google/btree" + "github.com/google/uuid" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/membuf" + "github.com/pingcap/tidb/pkg/lightning/backend" + "github.com/pingcap/tidb/pkg/lightning/backend/encode" + "github.com/pingcap/tidb/pkg/lightning/backend/kv" + "github.com/pingcap/tidb/pkg/lightning/checkpoints" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/util/hack" + "github.com/tikv/client-go/v2/tikv" + "go.uber.org/atomic" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" +) + +var ( + engineMetaKey = []byte{0, 'm', 'e', 't', 'a'} + normalIterStartKey = []byte{1} +) + +type importMutexState uint32 + +const ( + importMutexStateImport importMutexState = 1 << iota + importMutexStateClose + // importMutexStateReadLock is a special state because in this state we lock engine with read lock + // and add isImportingAtomic with this value. In other state, we directly store with the state value. + // so this must always the last value of this enum. + importMutexStateReadLock + // we need to lock the engine when it's open as we do when it's close, otherwise GetEngienSize may race with OpenEngine + importMutexStateOpen +) + +const ( + // DupDetectDirSuffix is used by pre-deduplication to store the encoded index KV. + DupDetectDirSuffix = ".dupdetect" + // DupResultDirSuffix is used by pre-deduplication to store the duplicated row ID. + DupResultDirSuffix = ".dupresult" +) + +// engineMeta contains some field that is necessary to continue the engine restore/import process. +// These field should be written to disk when we update chunk checkpoint +type engineMeta struct { + TS uint64 `json:"ts"` + // Length is the number of KV pairs stored by the engine. + Length atomic.Int64 `json:"length"` + // TotalSize is the total pre-compressed KV byte size stored by engine. + TotalSize atomic.Int64 `json:"total_size"` +} + +type syncedRanges struct { + sync.Mutex + ranges []common.Range +} + +func (r *syncedRanges) add(g common.Range) { + r.Lock() + r.ranges = append(r.ranges, g) + r.Unlock() +} + +func (r *syncedRanges) reset() { + r.Lock() + r.ranges = r.ranges[:0] + r.Unlock() +} + +// Engine is a local engine. 
+type Engine struct { + engineMeta + closed atomic.Bool + db atomic.Pointer[pebble.DB] + UUID uuid.UUID + localWriters sync.Map + + // isImportingAtomic is an atomic variable indicating whether this engine is importing. + // This should not be used as a "spin lock" indicator. + isImportingAtomic atomic.Uint32 + // flush and ingest sst hold the rlock, other operation hold the wlock. + mutex sync.RWMutex + + ctx context.Context + cancel context.CancelFunc + sstDir string + sstMetasChan chan metaOrFlush + ingestErr common.OnceError + wg sync.WaitGroup + sstIngester sstIngester + + // sst seq lock + seqLock sync.Mutex + // seq number for incoming sst meta + nextSeq int32 + // max seq of sst metas ingested into pebble + finishedMetaSeq atomic.Int32 + + config backend.LocalEngineConfig + tableInfo *checkpoints.TidbTableInfo + + dupDetectOpt common.DupDetectOpt + + // total size of SST files waiting to be ingested + pendingFileSize atomic.Int64 + + // statistics for pebble kv iter. + importedKVSize atomic.Int64 + importedKVCount atomic.Int64 + + keyAdapter common.KeyAdapter + duplicateDetection bool + duplicateDB *pebble.DB + + logger log.Logger +} + +func (e *Engine) setError(err error) { + if err != nil { + e.ingestErr.Set(err) + e.cancel() + } +} + +func (e *Engine) getDB() *pebble.DB { + return e.db.Load() +} + +// Close closes the engine and release all resources. +func (e *Engine) Close() error { + e.logger.Debug("closing local engine", zap.Stringer("engine", e.UUID), zap.Stack("stack")) + db := e.getDB() + if db == nil { + return nil + } + err := errors.Trace(db.Close()) + e.db.Store(nil) + return err +} + +// Cleanup remove meta, db and duplicate detection files +func (e *Engine) Cleanup(dataDir string) error { + if err := os.RemoveAll(e.sstDir); err != nil { + return errors.Trace(err) + } + uuid := e.UUID.String() + if err := os.RemoveAll(filepath.Join(dataDir, uuid+DupDetectDirSuffix)); err != nil { + return errors.Trace(err) + } + if err := os.RemoveAll(filepath.Join(dataDir, uuid+DupResultDirSuffix)); err != nil { + return errors.Trace(err) + } + + dbPath := filepath.Join(dataDir, uuid) + return errors.Trace(os.RemoveAll(dbPath)) +} + +// Exist checks if db folder existing (meta sometimes won't flush before lightning exit) +func (e *Engine) Exist(dataDir string) error { + dbPath := filepath.Join(dataDir, e.UUID.String()) + if _, err := os.Stat(dbPath); err != nil { + return err + } + return nil +} + +func isStateLocked(state importMutexState) bool { + return state&(importMutexStateClose|importMutexStateImport) != 0 +} + +func (e *Engine) isLocked() bool { + // the engine is locked only in import or close state. + return isStateLocked(importMutexState(e.isImportingAtomic.Load())) +} + +// rLock locks the local file with shard read state. Only used for flush and ingest SST files. +func (e *Engine) rLock() { + e.mutex.RLock() + e.isImportingAtomic.Add(uint32(importMutexStateReadLock)) +} + +func (e *Engine) rUnlock() { + if e == nil { + return + } + + e.isImportingAtomic.Sub(uint32(importMutexStateReadLock)) + e.mutex.RUnlock() +} + +// lock locks the local file for importing. +func (e *Engine) lock(state importMutexState) { + e.mutex.Lock() + e.isImportingAtomic.Store(uint32(state)) +} + +// lockUnless tries to lock the local file unless it is already locked into the state given by +// ignoreStateMask. Returns whether the lock is successful. 
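A tiny illustration of the bit-flag state checks used by isStateLocked above and the lock helpers below; the constants mirror the idea rather than the engine's exact values.

package main

import "fmt"

type state uint32

const (
	stateImport state = 1 << iota
	stateClose
	stateReadLock
	stateOpen
)

// isLocked answers "is the engine importing or closing?" with one mask test.
func isLocked(s state) bool {
	return s&(stateImport|stateClose) != 0
}

func main() {
	fmt.Println(isLocked(stateImport))               // true
	fmt.Println(isLocked(stateOpen | stateReadLock)) // false
}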
+func (e *Engine) lockUnless(newState, ignoreStateMask importMutexState) bool { + curState := e.isImportingAtomic.Load() + if curState&uint32(ignoreStateMask) != 0 { + return false + } + e.lock(newState) + return true +} + +// tryRLock tries to read-lock the local file unless it is already write locked. +// Returns whether the lock is successful. +func (e *Engine) tryRLock() bool { + curState := e.isImportingAtomic.Load() + // engine is in import/close state. + if isStateLocked(importMutexState(curState)) { + return false + } + e.rLock() + return true +} + +func (e *Engine) unlock() { + if e == nil { + return + } + e.isImportingAtomic.Store(0) + e.mutex.Unlock() +} + +var sizeOfKVPair = int64(unsafe.Sizeof(common.KvPair{})) + +// TotalMemorySize returns the total memory size of the engine. +func (e *Engine) TotalMemorySize() int64 { + var memSize int64 + e.localWriters.Range(func(k, _ any) bool { + w := k.(*Writer) + if w.kvBuffer != nil { + w.Lock() + memSize += w.kvBuffer.TotalSize() + w.Unlock() + } + w.Lock() + memSize += sizeOfKVPair * int64(cap(w.writeBatch)) + w.Unlock() + return true + }) + return memSize +} + +// KVStatistics returns the total kv size and total kv count. +func (e *Engine) KVStatistics() (totalSize int64, totalKVCount int64) { + return e.TotalSize.Load(), e.Length.Load() +} + +// ImportedStatistics returns the imported kv size and imported kv count. +func (e *Engine) ImportedStatistics() (importedSize int64, importedKVCount int64) { + return e.importedKVSize.Load(), e.importedKVCount.Load() +} + +// ID is the identifier of an engine. +func (e *Engine) ID() string { + return e.UUID.String() +} + +// GetKeyRange implements common.Engine. +func (e *Engine) GetKeyRange() (startKey []byte, endKey []byte, err error) { + firstLey, lastKey, err := e.GetFirstAndLastKey(nil, nil) + if err != nil { + return nil, nil, errors.Trace(err) + } + return firstLey, nextKey(lastKey), nil +} + +// SplitRanges gets size properties from pebble and split ranges according to size/keys limit. +func (e *Engine) SplitRanges( + startKey, endKey []byte, + sizeLimit, keysLimit int64, + logger log.Logger, +) ([]common.Range, error) { + sizeProps, err := getSizePropertiesFn(logger, e.getDB(), e.keyAdapter) + if err != nil { + return nil, errors.Trace(err) + } + + ranges := splitRangeBySizeProps( + common.Range{Start: startKey, End: endKey}, + sizeProps, + sizeLimit, + keysLimit, + ) + return ranges, nil +} + +type rangeOffsets struct { + Size uint64 + Keys uint64 +} + +type rangeProperty struct { + Key []byte + rangeOffsets +} + +// Less implements btree.Item interface. +func (r *rangeProperty) Less(than btree.Item) bool { + ta := than.(*rangeProperty) + return bytes.Compare(r.Key, ta.Key) < 0 +} + +var _ btree.Item = &rangeProperty{} + +type rangeProperties []rangeProperty + +// Encode encodes the range properties into a byte slice. +func (r rangeProperties) Encode() []byte { + b := make([]byte, 0, 1024) + idx := 0 + for _, p := range r { + b = append(b, 0, 0, 0, 0) + binary.BigEndian.PutUint32(b[idx:], uint32(len(p.Key))) + idx += 4 + b = append(b, p.Key...) + idx += len(p.Key) + + b = append(b, 0, 0, 0, 0, 0, 0, 0, 0) + binary.BigEndian.PutUint64(b[idx:], p.Size) + idx += 8 + + b = append(b, 0, 0, 0, 0, 0, 0, 0, 0) + binary.BigEndian.PutUint64(b[idx:], p.Keys) + idx += 8 + } + return b +} + +// RangePropertiesCollector collects range properties for each range. 
+type RangePropertiesCollector struct { + props rangeProperties + lastOffsets rangeOffsets + lastKey []byte + currentOffsets rangeOffsets + propSizeIdxDistance uint64 + propKeysIdxDistance uint64 +} + +func newRangePropertiesCollector() pebble.TablePropertyCollector { + return &RangePropertiesCollector{ + props: make([]rangeProperty, 0, 1024), + propSizeIdxDistance: defaultPropSizeIndexDistance, + propKeysIdxDistance: defaultPropKeysIndexDistance, + } +} + +func (c *RangePropertiesCollector) sizeInLastRange() uint64 { + return c.currentOffsets.Size - c.lastOffsets.Size +} + +func (c *RangePropertiesCollector) keysInLastRange() uint64 { + return c.currentOffsets.Keys - c.lastOffsets.Keys +} + +func (c *RangePropertiesCollector) insertNewPoint(key []byte) { + c.lastOffsets = c.currentOffsets + c.props = append(c.props, rangeProperty{Key: append([]byte{}, key...), rangeOffsets: c.currentOffsets}) +} + +// Add implements `pebble.TablePropertyCollector`. +// Add implements `TablePropertyCollector.Add`. +func (c *RangePropertiesCollector) Add(key pebble.InternalKey, value []byte) error { + if key.Kind() != pebble.InternalKeyKindSet || bytes.Equal(key.UserKey, engineMetaKey) { + return nil + } + c.currentOffsets.Size += uint64(len(value)) + uint64(len(key.UserKey)) + c.currentOffsets.Keys++ + if len(c.lastKey) == 0 || c.sizeInLastRange() >= c.propSizeIdxDistance || + c.keysInLastRange() >= c.propKeysIdxDistance { + c.insertNewPoint(key.UserKey) + } + c.lastKey = append(c.lastKey[:0], key.UserKey...) + return nil +} + +// Finish implements `pebble.TablePropertyCollector`. +func (c *RangePropertiesCollector) Finish(userProps map[string]string) error { + if c.sizeInLastRange() > 0 || c.keysInLastRange() > 0 { + c.insertNewPoint(c.lastKey) + } + + userProps[propRangeIndex] = string(c.props.Encode()) + return nil +} + +// Name implements `pebble.TablePropertyCollector`. 
+func (*RangePropertiesCollector) Name() string { + return propRangeIndex +} + +type sizeProperties struct { + totalSize uint64 + indexHandles *btree.BTree +} + +func newSizeProperties() *sizeProperties { + return &sizeProperties{indexHandles: btree.New(32)} +} + +func (s *sizeProperties) add(item *rangeProperty) { + if old := s.indexHandles.ReplaceOrInsert(item); old != nil { + o := old.(*rangeProperty) + item.Keys += o.Keys + item.Size += o.Size + } +} + +func (s *sizeProperties) addAll(props rangeProperties) { + prevRange := rangeOffsets{} + for _, r := range props { + s.add(&rangeProperty{ + Key: r.Key, + rangeOffsets: rangeOffsets{Keys: r.Keys - prevRange.Keys, Size: r.Size - prevRange.Size}, + }) + prevRange = r.rangeOffsets + } + if len(props) > 0 { + s.totalSize += props[len(props)-1].Size + } +} + +// iter the tree until f return false +func (s *sizeProperties) iter(f func(p *rangeProperty) bool) { + s.indexHandles.Ascend(func(i btree.Item) bool { + prop := i.(*rangeProperty) + return f(prop) + }) +} + +func decodeRangeProperties(data []byte, keyAdapter common.KeyAdapter) (rangeProperties, error) { + r := make(rangeProperties, 0, 16) + for len(data) > 0 { + if len(data) < 4 { + return nil, io.ErrUnexpectedEOF + } + keyLen := int(binary.BigEndian.Uint32(data[:4])) + data = data[4:] + if len(data) < keyLen+8*2 { + return nil, io.ErrUnexpectedEOF + } + key := data[:keyLen] + data = data[keyLen:] + size := binary.BigEndian.Uint64(data[:8]) + keys := binary.BigEndian.Uint64(data[8:]) + data = data[16:] + if !bytes.Equal(key, engineMetaKey) { + userKey, err := keyAdapter.Decode(nil, key) + if err != nil { + return nil, errors.Annotate(err, "failed to decode key with keyAdapter") + } + r = append(r, rangeProperty{Key: userKey, rangeOffsets: rangeOffsets{Size: size, Keys: keys}}) + } + } + + return r, nil +} + +// getSizePropertiesFn is used to let unit test replace the real function. 
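A self-contained sketch of the byte layout produced by rangeProperties.Encode and consumed by decodeRangeProperties above: per entry, a 4-byte big-endian key length, the key bytes, then two 8-byte big-endian counters (size and keys). The entry type here is a stand-in.

package main

import (
	"encoding/binary"
	"fmt"
)

type entry struct {
	key        []byte
	size, keys uint64
}

func encode(entries []entry) []byte {
	var b []byte
	for _, e := range entries {
		b = binary.BigEndian.AppendUint32(b, uint32(len(e.key)))
		b = append(b, e.key...)
		b = binary.BigEndian.AppendUint64(b, e.size)
		b = binary.BigEndian.AppendUint64(b, e.keys)
	}
	return b
}

func decode(b []byte) ([]entry, error) {
	var out []entry
	for len(b) > 0 {
		if len(b) < 4 {
			return nil, fmt.Errorf("truncated key length")
		}
		kl := int(binary.BigEndian.Uint32(b))
		b = b[4:]
		if len(b) < kl+16 {
			return nil, fmt.Errorf("truncated entry")
		}
		e := entry{key: append([]byte{}, b[:kl]...)}
		b = b[kl:]
		e.size = binary.BigEndian.Uint64(b)
		e.keys = binary.BigEndian.Uint64(b[8:])
		b = b[16:]
		out = append(out, e)
	}
	return out, nil
}

func main() {
	in := []entry{{key: []byte("k1"), size: 128, keys: 3}}
	out, err := decode(encode(in))
	fmt.Println(out, err)
}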
+var getSizePropertiesFn = getSizeProperties + +func getSizeProperties(logger log.Logger, db *pebble.DB, keyAdapter common.KeyAdapter) (*sizeProperties, error) { + sstables, err := db.SSTables(pebble.WithProperties()) + if err != nil { + logger.Warn("get sst table properties failed", log.ShortError(err)) + return nil, errors.Trace(err) + } + + sizeProps := newSizeProperties() + for _, level := range sstables { + for _, info := range level { + if prop, ok := info.Properties.UserProperties[propRangeIndex]; ok { + data := hack.Slice(prop) + rangeProps, err := decodeRangeProperties(data, keyAdapter) + if err != nil { + logger.Warn("decodeRangeProperties failed", + zap.Stringer("fileNum", info.FileNum), log.ShortError(err)) + return nil, errors.Trace(err) + } + sizeProps.addAll(rangeProps) + } + } + } + + return sizeProps, nil +} + +func (e *Engine) getEngineFileSize() backend.EngineFileSize { + db := e.getDB() + + var total pebble.LevelMetrics + if db != nil { + metrics := db.Metrics() + total = metrics.Total() + } + var memSize int64 + e.localWriters.Range(func(k, _ any) bool { + w := k.(*Writer) + memSize += int64(w.EstimatedSize()) + return true + }) + + pendingSize := e.pendingFileSize.Load() + // TODO: should also add the in-processing compaction sst writer size into MemSize + return backend.EngineFileSize{ + UUID: e.UUID, + DiskSize: total.Size + pendingSize, + MemSize: memSize, + IsImporting: e.isLocked(), + } +} + +// either a sstMeta or a flush message +type metaOrFlush struct { + meta *sstMeta + flushCh chan struct{} +} + +type metaSeq struct { + // the sequence for this flush message, a flush call can return only if + // all the other flush will lower `flushSeq` are done + flushSeq int32 + // the max sstMeta sequence number in this flush, after the flush is done (all SSTs are ingested), + // we can save chunks will a lower meta sequence number safely. + metaSeq int32 +} + +type metaSeqHeap struct { + arr []metaSeq +} + +// Len returns the number of items in the priority queue. +func (h *metaSeqHeap) Len() int { + return len(h.arr) +} + +// Less reports whether the item in the priority queue with +func (h *metaSeqHeap) Less(i, j int) bool { + return h.arr[i].flushSeq < h.arr[j].flushSeq +} + +// Swap swaps the items at the passed indices. +func (h *metaSeqHeap) Swap(i, j int) { + h.arr[i], h.arr[j] = h.arr[j], h.arr[i] +} + +// Push pushes the item onto the priority queue. +func (h *metaSeqHeap) Push(x any) { + h.arr = append(h.arr, x.(metaSeq)) +} + +// Pop removes the minimum item (according to Less) from the priority queue +func (h *metaSeqHeap) Pop() any { + item := h.arr[len(h.arr)-1] + h.arr = h.arr[:len(h.arr)-1] + return item +} + +func (e *Engine) ingestSSTLoop() { + defer e.wg.Done() + + type flushSeq struct { + seq int32 + ch chan struct{} + } + + seq := atomic.NewInt32(0) + finishedSeq := atomic.NewInt32(0) + var seqLock sync.Mutex + // a flush is finished iff all the compaction&ingest tasks with a lower seq number are finished. + flushQueue := make([]flushSeq, 0) + // inSyncSeqs is a heap that stores all the finished compaction tasks whose seq is bigger than `finishedSeq + 1` + // this mean there are still at lease one compaction task with a lower seq unfinished. 
+ inSyncSeqs := &metaSeqHeap{arr: make([]metaSeq, 0)} + + type metaAndSeq struct { + metas []*sstMeta + seq int32 + } + + concurrency := e.config.CompactConcurrency + // when compaction is disabled, ingest is an serial action, so 1 routine is enough + if !e.config.Compact { + concurrency = 1 + } + metaChan := make(chan metaAndSeq, concurrency) + for i := 0; i < concurrency; i++ { + e.wg.Add(1) + go func() { + defer func() { + if e.ingestErr.Get() != nil { + seqLock.Lock() + for _, f := range flushQueue { + f.ch <- struct{}{} + } + flushQueue = flushQueue[:0] + seqLock.Unlock() + } + e.wg.Done() + }() + for { + select { + case <-e.ctx.Done(): + return + case metas, ok := <-metaChan: + if !ok { + return + } + ingestMetas := metas.metas + if e.config.Compact { + newMeta, err := e.sstIngester.mergeSSTs(metas.metas, e.sstDir, e.config.BlockSize) + if err != nil { + e.setError(err) + return + } + ingestMetas = []*sstMeta{newMeta} + } + // batchIngestSSTs will change ingestMetas' order, so we record the max seq here + metasMaxSeq := ingestMetas[len(ingestMetas)-1].seq + + if err := e.batchIngestSSTs(ingestMetas); err != nil { + e.setError(err) + return + } + seqLock.Lock() + finSeq := finishedSeq.Load() + if metas.seq == finSeq+1 { + finSeq = metas.seq + finMetaSeq := metasMaxSeq + for len(inSyncSeqs.arr) > 0 { + if inSyncSeqs.arr[0].flushSeq != finSeq+1 { + break + } + finSeq++ + finMetaSeq = inSyncSeqs.arr[0].metaSeq + heap.Remove(inSyncSeqs, 0) + } + + var flushChans []chan struct{} + for _, seq := range flushQueue { + if seq.seq > finSeq { + break + } + flushChans = append(flushChans, seq.ch) + } + flushQueue = flushQueue[len(flushChans):] + finishedSeq.Store(finSeq) + e.finishedMetaSeq.Store(finMetaSeq) + seqLock.Unlock() + for _, c := range flushChans { + c <- struct{}{} + } + } else { + heap.Push(inSyncSeqs, metaSeq{flushSeq: metas.seq, metaSeq: metasMaxSeq}) + seqLock.Unlock() + } + } + } + }() + } + + compactAndIngestSSTs := func(metas []*sstMeta) { + if len(metas) > 0 { + seqLock.Lock() + metaSeq := seq.Add(1) + seqLock.Unlock() + select { + case <-e.ctx.Done(): + case metaChan <- metaAndSeq{metas: metas, seq: metaSeq}: + } + } + } + + pendingMetas := make([]*sstMeta, 0, 16) + totalSize := int64(0) + metasTmp := make([]*sstMeta, 0) + addMetas := func() { + if len(metasTmp) == 0 { + return + } + metas := metasTmp + metasTmp = make([]*sstMeta, 0, len(metas)) + if !e.config.Compact { + compactAndIngestSSTs(metas) + return + } + for _, m := range metas { + if m.totalCount > 0 { + pendingMetas = append(pendingMetas, m) + totalSize += m.totalSize + if totalSize >= e.config.CompactThreshold { + compactMetas := pendingMetas + pendingMetas = make([]*sstMeta, 0, len(pendingMetas)) + totalSize = 0 + compactAndIngestSSTs(compactMetas) + } + } + } + } +readMetaLoop: + for { + closed := false + select { + case <-e.ctx.Done(): + close(metaChan) + return + case m, ok := <-e.sstMetasChan: + if !ok { + closed = true + break + } + if m.flushCh != nil { + // meet a flush event, we should trigger a ingest task if there are pending metas, + // and then waiting for all the running flush tasks to be done. 
+ if len(metasTmp) > 0 { + addMetas() + } + if len(pendingMetas) > 0 { + seqLock.Lock() + metaSeq := seq.Add(1) + flushQueue = append(flushQueue, flushSeq{ch: m.flushCh, seq: metaSeq}) + seqLock.Unlock() + select { + case metaChan <- metaAndSeq{metas: pendingMetas, seq: metaSeq}: + case <-e.ctx.Done(): + close(metaChan) + return + } + + pendingMetas = make([]*sstMeta, 0, len(pendingMetas)) + totalSize = 0 + } else { + // none remaining metas needed to be ingested + seqLock.Lock() + curSeq := seq.Load() + finSeq := finishedSeq.Load() + // if all pending SST files are written, directly do a db.Flush + if curSeq == finSeq { + seqLock.Unlock() + m.flushCh <- struct{}{} + } else { + // waiting for pending compaction tasks + flushQueue = append(flushQueue, flushSeq{ch: m.flushCh, seq: curSeq}) + seqLock.Unlock() + } + } + continue readMetaLoop + } + metasTmp = append(metasTmp, m.meta) + // try to drain all the sst meta from the chan to make sure all the SSTs are processed before handle a flush msg. + if len(e.sstMetasChan) > 0 { + continue readMetaLoop + } + + addMetas() + } + if closed { + compactAndIngestSSTs(pendingMetas) + close(metaChan) + return + } + } +} + +func (e *Engine) addSST(ctx context.Context, m *sstMeta) (int32, error) { + // set pending size after SST file is generated + e.pendingFileSize.Add(m.fileSize) + // make sure sstMeta is sent into the chan in order + e.seqLock.Lock() + defer e.seqLock.Unlock() + e.nextSeq++ + seq := e.nextSeq + m.seq = seq + select { + case e.sstMetasChan <- metaOrFlush{meta: m}: + case <-ctx.Done(): + return 0, ctx.Err() + case <-e.ctx.Done(): + } + return seq, e.ingestErr.Get() +} + +func (e *Engine) batchIngestSSTs(metas []*sstMeta) error { + if len(metas) == 0 { + return nil + } + slices.SortFunc(metas, func(i, j *sstMeta) int { + return bytes.Compare(i.minKey, j.minKey) + }) + + // non overlapping sst is grouped, and ingested in that order + metaLevels := make([][]*sstMeta, 0) + for _, meta := range metas { + inserted := false + for i, l := range metaLevels { + if bytes.Compare(l[len(l)-1].maxKey, meta.minKey) >= 0 { + continue + } + metaLevels[i] = append(l, meta) + inserted = true + break + } + if !inserted { + metaLevels = append(metaLevels, []*sstMeta{meta}) + } + } + + for _, l := range metaLevels { + if err := e.ingestSSTs(l); err != nil { + return err + } + } + return nil +} + +func (e *Engine) ingestSSTs(metas []*sstMeta) error { + // use raw RLock to avoid change the lock state during flushing. 
+ e.mutex.RLock() + defer e.mutex.RUnlock() + if e.closed.Load() { + return errorEngineClosed + } + totalSize := int64(0) + totalCount := int64(0) + fileSize := int64(0) + for _, m := range metas { + totalSize += m.totalSize + totalCount += m.totalCount + fileSize += m.fileSize + } + e.logger.Info("write data to local DB", + zap.Int64("size", totalSize), + zap.Int64("kvs", totalCount), + zap.Int("files", len(metas)), + zap.Int64("sstFileSize", fileSize), + zap.String("file", metas[0].path), + logutil.Key("firstKey", metas[0].minKey), + logutil.Key("lastKey", metas[len(metas)-1].maxKey)) + if err := e.sstIngester.ingest(metas); err != nil { + return errors.Trace(err) + } + count := int64(0) + size := int64(0) + for _, m := range metas { + count += m.totalCount + size += m.totalSize + } + e.Length.Add(count) + e.TotalSize.Add(size) + return nil +} + +func (e *Engine) flushLocalWriters(parentCtx context.Context) error { + eg, ctx := errgroup.WithContext(parentCtx) + e.localWriters.Range(func(k, _ any) bool { + eg.Go(func() error { + w := k.(*Writer) + return w.flush(ctx) + }) + return true + }) + return eg.Wait() +} + +func (e *Engine) flushEngineWithoutLock(ctx context.Context) error { + if err := e.flushLocalWriters(ctx); err != nil { + return err + } + flushChan := make(chan struct{}, 1) + select { + case e.sstMetasChan <- metaOrFlush{flushCh: flushChan}: + case <-ctx.Done(): + return ctx.Err() + case <-e.ctx.Done(): + return e.ctx.Err() + } + + select { + case <-flushChan: + case <-ctx.Done(): + return ctx.Err() + case <-e.ctx.Done(): + return e.ctx.Err() + } + if err := e.ingestErr.Get(); err != nil { + return errors.Trace(err) + } + if err := e.saveEngineMeta(); err != nil { + return err + } + + flushFinishedCh, err := e.getDB().AsyncFlush() + if err != nil { + return errors.Trace(err) + } + select { + case <-flushFinishedCh: + return nil + case <-ctx.Done(): + return ctx.Err() + case <-e.ctx.Done(): + return e.ctx.Err() + } +} + +func saveEngineMetaToDB(meta *engineMeta, db *pebble.DB) error { + jsonBytes, err := json.Marshal(meta) + if err != nil { + return errors.Trace(err) + } + // note: we can't set Sync to true since we disabled WAL. + return db.Set(engineMetaKey, jsonBytes, &pebble.WriteOptions{Sync: false}) +} + +// saveEngineMeta saves the metadata about the DB into the DB itself. 
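A minimal sketch, with a map standing in for the pebble DB, of the persistence approach used by saveEngineMetaToDB above and loadEngineMeta below: the engine counters are JSON-encoded and stored under a reserved key alongside the data.

package main

import (
	"encoding/json"
	"fmt"
)

type meta struct {
	TS        uint64 `json:"ts"`
	Length    int64  `json:"length"`
	TotalSize int64  `json:"total_size"`
}

func main() {
	store := map[string][]byte{} // stand-in for the pebble DB
	metaKey := string([]byte{0, 'm', 'e', 't', 'a'})

	in := meta{TS: 42, Length: 1000, TotalSize: 1 << 20}
	buf, _ := json.Marshal(in)
	store[metaKey] = buf // save

	var out meta
	_ = json.Unmarshal(store[metaKey], &out) // load
	fmt.Printf("%+v\n", out)
}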
+// This method should be followed by a Flush to ensure the data is actually synchronized +func (e *Engine) saveEngineMeta() error { + e.logger.Debug("save engine meta", zap.Stringer("uuid", e.UUID), zap.Int64("count", e.Length.Load()), + zap.Int64("size", e.TotalSize.Load())) + return errors.Trace(saveEngineMetaToDB(&e.engineMeta, e.getDB())) +} + +func (e *Engine) loadEngineMeta() error { + jsonBytes, closer, err := e.getDB().Get(engineMetaKey) + if err != nil { + if err == pebble.ErrNotFound { + e.logger.Debug("local db missing engine meta", zap.Stringer("uuid", e.UUID), log.ShortError(err)) + return nil + } + return err + } + //nolint: errcheck + defer closer.Close() + + if err = json.Unmarshal(jsonBytes, &e.engineMeta); err != nil { + e.logger.Warn("local db failed to deserialize meta", zap.Stringer("uuid", e.UUID), zap.ByteString("content", jsonBytes), zap.Error(err)) + return err + } + e.logger.Debug("load engine meta", zap.Stringer("uuid", e.UUID), zap.Int64("count", e.Length.Load()), + zap.Int64("size", e.TotalSize.Load())) + return nil +} + +func (e *Engine) newKVIter(ctx context.Context, opts *pebble.IterOptions, buf *membuf.Buffer) IngestLocalEngineIter { + if bytes.Compare(opts.LowerBound, normalIterStartKey) < 0 { + newOpts := *opts + newOpts.LowerBound = normalIterStartKey + opts = &newOpts + } + if !e.duplicateDetection { + iter, err := e.getDB().NewIter(opts) + if err != nil { + e.logger.Panic("fail to create iterator") + return nil + } + return &pebbleIter{Iterator: iter, buf: buf} + } + logger := log.FromContext(ctx).With( + zap.String("table", common.UniqueTable(e.tableInfo.DB, e.tableInfo.Name)), + zap.Int64("tableID", e.tableInfo.ID), + zap.Stringer("engineUUID", e.UUID)) + return newDupDetectIter( + e.getDB(), + e.keyAdapter, + opts, + e.duplicateDB, + logger, + e.dupDetectOpt, + buf, + ) +} + +var _ common.IngestData = (*Engine)(nil) + +// GetFirstAndLastKey reads the first and last key in range [lowerBound, upperBound) +// in the engine. Empty upperBound means unbounded. +func (e *Engine) GetFirstAndLastKey(lowerBound, upperBound []byte) ([]byte, []byte, error) { + if len(upperBound) == 0 { + // we use empty slice for unbounded upper bound, but it means max value in pebble + // so reset to nil + upperBound = nil + } + opt := &pebble.IterOptions{ + LowerBound: lowerBound, + UpperBound: upperBound, + } + failpoint.Inject("mockGetFirstAndLastKey", func() { + failpoint.Return(lowerBound, upperBound, nil) + }) + + iter := e.newKVIter(context.Background(), opt, nil) + //nolint: errcheck + defer iter.Close() + // Needs seek to first because NewIter returns an iterator that is unpositioned + hasKey := iter.First() + if iter.Error() != nil { + return nil, nil, errors.Annotate(iter.Error(), "failed to read the first key") + } + if !hasKey { + return nil, nil, nil + } + firstKey := append([]byte{}, iter.Key()...) + iter.Last() + if iter.Error() != nil { + return nil, nil, errors.Annotate(iter.Error(), "failed to seek to the last key") + } + lastKey := append([]byte{}, iter.Key()...) + return firstKey, lastKey, nil +} + +// NewIter implements IngestData interface. +func (e *Engine) NewIter( + ctx context.Context, + lowerBound, upperBound []byte, + bufPool *membuf.Pool, +) common.ForwardIter { + return e.newKVIter( + ctx, + &pebble.IterOptions{LowerBound: lowerBound, UpperBound: upperBound}, + bufPool.NewBuffer(), + ) +} + +// GetTS implements IngestData interface. +func (e *Engine) GetTS() uint64 { + return e.TS +} + +// IncRef implements IngestData interface. 
+func (*Engine) IncRef() {} + +// DecRef implements IngestData interface. +func (*Engine) DecRef() {} + +// Finish implements IngestData interface. +func (e *Engine) Finish(totalBytes, totalCount int64) { + e.importedKVSize.Add(totalBytes) + e.importedKVCount.Add(totalCount) +} + +// LoadIngestData return (local) Engine itself because Engine has implemented +// IngestData interface. +func (e *Engine) LoadIngestData( + ctx context.Context, + regionRanges []common.Range, + outCh chan<- common.DataAndRange, +) error { + for _, r := range regionRanges { + select { + case <-ctx.Done(): + return ctx.Err() + case outCh <- common.DataAndRange{Data: e, Range: r}: + } + } + return nil +} + +type sstMeta struct { + path string + minKey []byte + maxKey []byte + totalSize int64 + totalCount int64 + // used for calculate disk-quota + fileSize int64 + seq int32 +} + +// Writer is used to write data into a SST file. +type Writer struct { + sync.Mutex + engine *Engine + memtableSizeLimit int64 + + // if the KVs are append in order, we can directly write the into SST file, + // else we must first store them in writeBatch and then batch flush into SST file. + isKVSorted bool + writer atomic.Pointer[sstWriter] + writerSize atomic.Uint64 + + // bytes buffer for writeBatch + kvBuffer *membuf.Buffer + writeBatch []common.KvPair + // if the kvs in writeBatch are in order, we can avoid doing a `sort.Slice` which + // is quite slow. in our bench, the sort operation eats about 5% of total CPU + isWriteBatchSorted bool + sortedKeyBuf []byte + + batchCount int + batchSize atomic.Int64 + + lastMetaSeq int32 + + tikvCodec tikv.Codec +} + +func (w *Writer) appendRowsSorted(kvs []common.KvPair) (err error) { + writer := w.writer.Load() + if writer == nil { + writer, err = w.createSSTWriter() + if err != nil { + return errors.Trace(err) + } + w.writer.Store(writer) + } + + keyAdapter := w.engine.keyAdapter + totalKeySize := 0 + for i := 0; i < len(kvs); i++ { + keySize := keyAdapter.EncodedLen(kvs[i].Key, kvs[i].RowID) + w.batchSize.Add(int64(keySize + len(kvs[i].Val))) + totalKeySize += keySize + } + w.batchCount += len(kvs) + // NoopKeyAdapter doesn't really change the key, + // skipping the encoding to avoid unnecessary alloc and copy. 
+ if _, ok := keyAdapter.(common.NoopKeyAdapter); !ok { + if cap(w.sortedKeyBuf) < totalKeySize { + w.sortedKeyBuf = make([]byte, totalKeySize) + } + buf := w.sortedKeyBuf[:0] + newKvs := make([]common.KvPair, len(kvs)) + for i := 0; i < len(kvs); i++ { + buf = keyAdapter.Encode(buf, kvs[i].Key, kvs[i].RowID) + newKvs[i] = common.KvPair{Key: buf, Val: kvs[i].Val} + buf = buf[len(buf):] + } + kvs = newKvs + } + if err := writer.writeKVs(kvs); err != nil { + return err + } + w.writerSize.Store(writer.writer.EstimatedSize()) + return nil +} + +func (w *Writer) appendRowsUnsorted(ctx context.Context, kvs []common.KvPair) error { + l := len(w.writeBatch) + cnt := w.batchCount + var lastKey []byte + if cnt > 0 { + lastKey = w.writeBatch[cnt-1].Key + } + keyAdapter := w.engine.keyAdapter + for _, pair := range kvs { + if w.isWriteBatchSorted && bytes.Compare(lastKey, pair.Key) > 0 { + w.isWriteBatchSorted = false + } + lastKey = pair.Key + w.batchSize.Add(int64(len(pair.Key) + len(pair.Val))) + buf := w.kvBuffer.AllocBytes(keyAdapter.EncodedLen(pair.Key, pair.RowID)) + key := keyAdapter.Encode(buf[:0], pair.Key, pair.RowID) + val := w.kvBuffer.AddBytes(pair.Val) + if cnt < l { + w.writeBatch[cnt].Key = key + w.writeBatch[cnt].Val = val + } else { + w.writeBatch = append(w.writeBatch, common.KvPair{Key: key, Val: val}) + } + cnt++ + } + w.batchCount = cnt + + if w.batchSize.Load() > w.memtableSizeLimit { + if err := w.flushKVs(ctx); err != nil { + return err + } + } + return nil +} + +// AppendRows appends rows to the SST file. +func (w *Writer) AppendRows(ctx context.Context, columnNames []string, rows encode.Rows) error { + kvs := kv.Rows2KvPairs(rows) + if len(kvs) == 0 { + return nil + } + + if w.engine.closed.Load() { + return errorEngineClosed + } + + for i := range kvs { + kvs[i].Key = w.tikvCodec.EncodeKey(kvs[i].Key) + } + + w.Lock() + defer w.Unlock() + + // if chunk has _tidb_rowid field, we can't ensure that the rows are sorted. + if w.isKVSorted && w.writer.Load() == nil { + for _, c := range columnNames { + if c == model.ExtraHandleName.L { + w.isKVSorted = false + } + } + } + + if w.isKVSorted { + return w.appendRowsSorted(kvs) + } + return w.appendRowsUnsorted(ctx, kvs) +} + +func (w *Writer) flush(ctx context.Context) error { + w.Lock() + defer w.Unlock() + if w.batchCount == 0 { + return nil + } + + if len(w.writeBatch) > 0 { + if err := w.flushKVs(ctx); err != nil { + return errors.Trace(err) + } + } + + writer := w.writer.Load() + if writer != nil { + meta, err := writer.close() + if err != nil { + return errors.Trace(err) + } + w.writer.Store(nil) + w.writerSize.Store(0) + w.batchCount = 0 + if meta != nil && meta.totalSize > 0 { + return w.addSST(ctx, meta) + } + } + + return nil +} + +// EstimatedSize returns the estimated size of the SST file. +func (w *Writer) EstimatedSize() uint64 { + if size := w.writerSize.Load(); size > 0 { + return size + } + // if kvs are still in memory, only calculate half of the total size + // in our tests, SST file size is about 50% of the raw kv size + return uint64(w.batchSize.Load()) / 2 +} + +type flushStatus struct { + local *Engine + seq int32 +} + +// Flushed implements backend.ChunkFlushStatus. +func (f flushStatus) Flushed() bool { + return f.seq <= f.local.finishedMetaSeq.Load() +} + +// Close implements backend.ChunkFlushStatus. 
+func (w *Writer) Close(ctx context.Context) (backend.ChunkFlushStatus, error) { + defer w.kvBuffer.Destroy() + defer w.engine.localWriters.Delete(w) + err := w.flush(ctx) + // FIXME: in theory this line is useless, but In our benchmark with go1.15 + // this can resolve the memory consistently increasing issue. + // maybe this is a bug related to go GC mechanism. + w.writeBatch = nil + return flushStatus{local: w.engine, seq: w.lastMetaSeq}, err +} + +// IsSynced implements backend.ChunkFlushStatus. +func (w *Writer) IsSynced() bool { + return w.batchCount == 0 && w.lastMetaSeq <= w.engine.finishedMetaSeq.Load() +} + +func (w *Writer) flushKVs(ctx context.Context) error { + writer, err := w.createSSTWriter() + if err != nil { + return errors.Trace(err) + } + if !w.isWriteBatchSorted { + slices.SortFunc(w.writeBatch[:w.batchCount], func(i, j common.KvPair) int { + return bytes.Compare(i.Key, j.Key) + }) + w.isWriteBatchSorted = true + } + + err = writer.writeKVs(w.writeBatch[:w.batchCount]) + if err != nil { + return errors.Trace(err) + } + meta, err := writer.close() + if err != nil { + return errors.Trace(err) + } + + failpoint.Inject("orphanWriterGoRoutine", func() { + _ = common.KillMySelf() + // mimic we meet context cancel error when `addSST` + <-ctx.Done() + time.Sleep(5 * time.Second) + failpoint.Return(errors.Trace(ctx.Err())) + }) + + err = w.addSST(ctx, meta) + if err != nil { + return errors.Trace(err) + } + + w.batchSize.Store(0) + w.batchCount = 0 + w.kvBuffer.Reset() + return nil +} + +func (w *Writer) addSST(ctx context.Context, meta *sstMeta) error { + seq, err := w.engine.addSST(ctx, meta) + if err != nil { + return err + } + w.lastMetaSeq = seq + return nil +} + +func (w *Writer) createSSTWriter() (*sstWriter, error) { + path := filepath.Join(w.engine.sstDir, uuid.New().String()+".sst") + writer, err := newSSTWriter(path, w.engine.config.BlockSize) + if err != nil { + return nil, err + } + sw := &sstWriter{sstMeta: &sstMeta{path: path}, writer: writer, logger: w.engine.logger} + return sw, nil +} + +var errorUnorderedSSTInsertion = errors.New("inserting KVs into SST without order") + +type sstWriter struct { + *sstMeta + writer *sstable.Writer + + // To dedup keys before write them into the SST file. + // NOTE: keys should be sorted and deduped when construct one SST file. + lastKey []byte + + logger log.Logger +} + +func newSSTWriter(path string, blockSize int) (*sstable.Writer, error) { + f, err := vfs.Default.Create(path) + if err != nil { + return nil, errors.Trace(err) + } + writable := objstorageprovider.NewFileWritable(f) + writer := sstable.NewWriter(writable, sstable.WriterOptions{ + TablePropertyCollectors: []func() pebble.TablePropertyCollector{ + newRangePropertiesCollector, + }, + BlockSize: blockSize, + }) + return writer, nil +} + +func (sw *sstWriter) writeKVs(kvs []common.KvPair) error { + if len(kvs) == 0 { + return nil + } + if len(sw.minKey) == 0 { + sw.minKey = append([]byte{}, kvs[0].Key...) 
+ } + if bytes.Compare(kvs[0].Key, sw.maxKey) <= 0 { + return errorUnorderedSSTInsertion + } + + internalKey := sstable.InternalKey{ + Trailer: uint64(sstable.InternalKeyKindSet), + } + for _, p := range kvs { + if sw.lastKey != nil && bytes.Equal(p.Key, sw.lastKey) { + sw.logger.Warn("duplicated key found, skip write", logutil.Key("key", p.Key)) + continue + } + internalKey.UserKey = p.Key + if err := sw.writer.Add(internalKey, p.Val); err != nil { + return errors.Trace(err) + } + sw.totalSize += int64(len(p.Key)) + int64(len(p.Val)) + sw.lastKey = p.Key + } + sw.totalCount += int64(len(kvs)) + sw.maxKey = append(sw.maxKey[:0], sw.lastKey...) + return nil +} + +func (sw *sstWriter) close() (*sstMeta, error) { + if err := sw.writer.Close(); err != nil { + return nil, errors.Trace(err) + } + meta, err := sw.writer.Metadata() + if err != nil { + return nil, errors.Trace(err) + } + sw.fileSize = int64(meta.Size) + return sw.sstMeta, nil +} + +type sstIter struct { + name string + key []byte + val []byte + iter sstable.Iterator + reader *sstable.Reader + valid bool +} + +// Close implements common.Iterator. +func (i *sstIter) Close() error { + if err := i.iter.Close(); err != nil { + return errors.Trace(err) + } + err := i.reader.Close() + return errors.Trace(err) +} + +type sstIterHeap struct { + iters []*sstIter +} + +// Len implements heap.Interface. +func (h *sstIterHeap) Len() int { + return len(h.iters) +} + +// Less implements heap.Interface. +func (h *sstIterHeap) Less(i, j int) bool { + return bytes.Compare(h.iters[i].key, h.iters[j].key) < 0 +} + +// Swap implements heap.Interface. +func (h *sstIterHeap) Swap(i, j int) { + h.iters[i], h.iters[j] = h.iters[j], h.iters[i] +} + +// Push implements heap.Interface. +func (h *sstIterHeap) Push(x any) { + h.iters = append(h.iters, x.(*sstIter)) +} + +// Pop implements heap.Interface. +func (h *sstIterHeap) Pop() any { + item := h.iters[len(h.iters)-1] + h.iters = h.iters[:len(h.iters)-1] + return item +} + +// Next implements common.Iterator. +func (h *sstIterHeap) Next() ([]byte, []byte, error) { + for { + if len(h.iters) == 0 { + return nil, nil, nil + } + + iter := h.iters[0] + if iter.valid { + iter.valid = false + return iter.key, iter.val, iter.iter.Error() + } + + var k *pebble.InternalKey + var v pebble.LazyValue + k, v = iter.iter.Next() + + if k != nil { + vBytes, _, err := v.Value(nil) + if err != nil { + return nil, nil, errors.Trace(err) + } + iter.key = k.UserKey + iter.val = vBytes + iter.valid = true + heap.Fix(h, 0) + } else { + err := iter.Close() + heap.Remove(h, 0) + if err != nil { + return nil, nil, errors.Trace(err) + } + } + } +} + +// sstIngester is a interface used to merge and ingest SST files. 
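// A minimal sketch of the k-way merge idea that sstIterHeap implements above:
// container/heap keeps the smallest current key of every sorted input at the
// top, and repeatedly advancing the top cursor yields one globally sorted
// stream while duplicate keys are dropped, mirroring mergeSSTs. Inputs here
// are plain in-memory sorted slices instead of SST iterators; all names are
// illustrative.
package sketch

import (
	"bytes"
	"container/heap"
)

type kvPair struct {
	Key, Val []byte
}

// cursor points at the next unread pair of one sorted input.
type cursor struct {
	pairs []kvPair
	idx   int
}

type mergeHeap []*cursor

func (h mergeHeap) Len() int { return len(h) }
func (h mergeHeap) Less(i, j int) bool {
	return bytes.Compare(h[i].pairs[h[i].idx].Key, h[j].pairs[h[j].idx].Key) < 0
}
func (h mergeHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
func (h *mergeHeap) Push(x any)   { *h = append(*h, x.(*cursor)) }
func (h *mergeHeap) Pop() any {
	old := *h
	n := len(old)
	item := old[n-1]
	*h = old[:n-1]
	return item
}

// mergeSorted merges several individually sorted inputs into one sorted
// output, skipping keys equal to the previously emitted key.
func mergeSorted(inputs ...[]kvPair) []kvPair {
	h := &mergeHeap{}
	for _, in := range inputs {
		if len(in) > 0 {
			*h = append(*h, &cursor{pairs: in})
		}
	}
	heap.Init(h)

	var out []kvPair
	var lastKey []byte
	for h.Len() > 0 {
		c := (*h)[0]
		p := c.pairs[c.idx]
		if lastKey == nil || !bytes.Equal(lastKey, p.Key) {
			out = append(out, p)
			lastKey = p.Key
		}
		c.idx++
		if c.idx == len(c.pairs) {
			heap.Remove(h, 0)
		} else {
			heap.Fix(h, 0)
		}
	}
	return out
}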
+// it's a interface mainly used for test convenience +type sstIngester interface { + mergeSSTs(metas []*sstMeta, dir string, blockSize int) (*sstMeta, error) + ingest([]*sstMeta) error +} + +type dbSSTIngester struct { + e *Engine +} + +func (i dbSSTIngester) mergeSSTs(metas []*sstMeta, dir string, blockSize int) (*sstMeta, error) { + if len(metas) == 0 { + return nil, errors.New("sst metas is empty") + } else if len(metas) == 1 { + return metas[0], nil + } + + start := time.Now() + newMeta := &sstMeta{ + seq: metas[len(metas)-1].seq, + } + mergeIter := &sstIterHeap{ + iters: make([]*sstIter, 0, len(metas)), + } + + for _, p := range metas { + f, err := vfs.Default.Open(p.path) + if err != nil { + return nil, errors.Trace(err) + } + readable, err := sstable.NewSimpleReadable(f) + if err != nil { + return nil, errors.Trace(err) + } + reader, err := sstable.NewReader(readable, sstable.ReaderOptions{}) + if err != nil { + return nil, errors.Trace(err) + } + iter, err := reader.NewIter(nil, nil) + if err != nil { + return nil, errors.Trace(err) + } + key, val := iter.Next() + if key == nil { + continue + } + valBytes, _, err := val.Value(nil) + if err != nil { + return nil, errors.Trace(err) + } + if iter.Error() != nil { + return nil, errors.Trace(iter.Error()) + } + mergeIter.iters = append(mergeIter.iters, &sstIter{ + name: p.path, + iter: iter, + key: key.UserKey, + val: valBytes, + reader: reader, + valid: true, + }) + newMeta.totalSize += p.totalSize + newMeta.totalCount += p.totalCount + } + heap.Init(mergeIter) + + name := filepath.Join(dir, fmt.Sprintf("%s.sst", uuid.New())) + writer, err := newSSTWriter(name, blockSize) + if err != nil { + return nil, errors.Trace(err) + } + newMeta.path = name + + internalKey := sstable.InternalKey{ + Trailer: uint64(sstable.InternalKeyKindSet), + } + key, val, err := mergeIter.Next() + if err != nil { + return nil, err + } + if key == nil { + return nil, errors.New("all ssts are empty") + } + newMeta.minKey = append(newMeta.minKey[:0], key...) + lastKey := make([]byte, 0) + for { + if bytes.Equal(lastKey, key) { + i.e.logger.Warn("duplicated key found, skipped", zap.Binary("key", lastKey)) + newMeta.totalCount-- + newMeta.totalSize -= int64(len(key) + len(val)) + + goto nextKey + } + internalKey.UserKey = key + err = writer.Add(internalKey, val) + if err != nil { + return nil, err + } + lastKey = append(lastKey[:0], key...) + nextKey: + key, val, err = mergeIter.Next() + if err != nil { + return nil, err + } + if key == nil { + break + } + } + err = writer.Close() + if err != nil { + return nil, errors.Trace(err) + } + meta, err := writer.Metadata() + if err != nil { + return nil, errors.Trace(err) + } + newMeta.maxKey = lastKey + newMeta.fileSize = int64(meta.Size) + + dur := time.Since(start) + i.e.logger.Info("compact sst", zap.Int("fileCount", len(metas)), zap.Int64("size", newMeta.totalSize), + zap.Int64("count", newMeta.totalCount), zap.Duration("cost", dur), zap.String("file", name)) + + // async clean raw SSTs. 
+ go func() { + totalSize := int64(0) + for _, m := range metas { + totalSize += m.fileSize + if err := os.Remove(m.path); err != nil { + i.e.logger.Warn("async cleanup sst file failed", zap.Error(err)) + } + } + // decrease the pending size after clean up + i.e.pendingFileSize.Sub(totalSize) + }() + + return newMeta, err +} + +func (i dbSSTIngester) ingest(metas []*sstMeta) error { + if len(metas) == 0 { + return nil + } + paths := make([]string, 0, len(metas)) + for _, m := range metas { + paths = append(paths, m.path) + } + db := i.e.getDB() + if db == nil { + return errorEngineClosed + } + return db.Ingest(paths) +} diff --git a/pkg/lightning/backend/local/engine_mgr.go b/pkg/lightning/backend/local/engine_mgr.go index 28b6107a3d3e5..b796fa736ee65 100644 --- a/pkg/lightning/backend/local/engine_mgr.go +++ b/pkg/lightning/backend/local/engine_mgr.go @@ -630,9 +630,9 @@ func openDuplicateDB(storeDir string) (*pebble.DB, error) { newRangePropertiesCollector, }, } - failpoint.Inject("slowCreateFS", func() { + if _, _err_ := failpoint.Eval(_curpkg_("slowCreateFS")); _err_ == nil { opts.FS = slowCreateFS{vfs.Default} - }) + } return pebble.Open(dbPath, opts) } diff --git a/pkg/lightning/backend/local/engine_mgr.go__failpoint_stash__ b/pkg/lightning/backend/local/engine_mgr.go__failpoint_stash__ new file mode 100644 index 0000000000000..28b6107a3d3e5 --- /dev/null +++ b/pkg/lightning/backend/local/engine_mgr.go__failpoint_stash__ @@ -0,0 +1,658 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package local + +import ( + "context" + "math" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/cockroachdb/pebble" + "github.com/cockroachdb/pebble/vfs" + "github.com/docker/go-units" + "github.com/google/uuid" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/membuf" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/pkg/lightning/backend" + "github.com/pingcap/tidb/pkg/lightning/backend/external" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/lightning/manual" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/tikv/client-go/v2/oracle" + tikvclient "github.com/tikv/client-go/v2/tikv" + "go.uber.org/atomic" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" +) + +var ( + // RunInTest indicates whether the current process is running in test. + RunInTest bool + // LastAlloc is the last ID allocator. + LastAlloc manual.Allocator +) + +// StoreHelper have some api to help encode or store KV data +type StoreHelper interface { + GetTS(ctx context.Context) (physical, logical int64, err error) + GetTiKVCodec() tikvclient.Codec +} + +// engineManager manages all engines, either local or external. 
+type engineManager struct { + BackendConfig + StoreHelper + engines sync.Map // sync version of map[uuid.UUID]*Engine + externalEngine map[uuid.UUID]common.Engine + bufferPool *membuf.Pool + duplicateDB *pebble.DB + keyAdapter common.KeyAdapter + logger log.Logger +} + +var inMemTest = false + +func newEngineManager(config BackendConfig, storeHelper StoreHelper, logger log.Logger) (_ *engineManager, err error) { + var duplicateDB *pebble.DB + defer func() { + if err != nil && duplicateDB != nil { + _ = duplicateDB.Close() + } + }() + + if err = prepareSortDir(config); err != nil { + return nil, err + } + + keyAdapter := common.KeyAdapter(common.NoopKeyAdapter{}) + if config.DupeDetectEnabled { + duplicateDB, err = openDuplicateDB(config.LocalStoreDir) + if err != nil { + return nil, common.ErrOpenDuplicateDB.Wrap(err).GenWithStackByArgs() + } + keyAdapter = common.DupDetectKeyAdapter{} + } + alloc := manual.Allocator{} + if RunInTest { + alloc.RefCnt = new(atomic.Int64) + LastAlloc = alloc + } + var opts = make([]membuf.Option, 0, 1) + if !inMemTest { + // otherwise, we use the default allocator that can be tracked by golang runtime. + opts = append(opts, membuf.WithAllocator(alloc)) + } + return &engineManager{ + BackendConfig: config, + StoreHelper: storeHelper, + engines: sync.Map{}, + externalEngine: map[uuid.UUID]common.Engine{}, + bufferPool: membuf.NewPool(opts...), + duplicateDB: duplicateDB, + keyAdapter: keyAdapter, + logger: logger, + }, nil +} + +// rlock read locks a local file and returns the Engine instance if it exists. +func (em *engineManager) rLockEngine(engineID uuid.UUID) *Engine { + if e, ok := em.engines.Load(engineID); ok { + engine := e.(*Engine) + engine.rLock() + return engine + } + return nil +} + +// lock locks a local file and returns the Engine instance if it exists. +func (em *engineManager) lockEngine(engineID uuid.UUID, state importMutexState) *Engine { + if e, ok := em.engines.Load(engineID); ok { + engine := e.(*Engine) + engine.lock(state) + return engine + } + return nil +} + +// tryRLockAllEngines tries to read lock all engines, return all `Engine`s that are successfully locked. +func (em *engineManager) tryRLockAllEngines() []*Engine { + var allEngines []*Engine + em.engines.Range(func(_, v any) bool { + engine := v.(*Engine) + // skip closed engine + if engine.tryRLock() { + if !engine.closed.Load() { + allEngines = append(allEngines, engine) + } else { + engine.rUnlock() + } + } + return true + }) + return allEngines +} + +// lockAllEnginesUnless tries to lock all engines, unless those which are already locked in the +// state given by ignoreStateMask. Returns the list of locked engines. +func (em *engineManager) lockAllEnginesUnless(newState, ignoreStateMask importMutexState) []*Engine { + var allEngines []*Engine + em.engines.Range(func(_, v any) bool { + engine := v.(*Engine) + if engine.lockUnless(newState, ignoreStateMask) { + allEngines = append(allEngines, engine) + } + return true + }) + return allEngines +} + +// flushEngine ensure the written data is saved successfully, to make sure no data lose after restart +func (em *engineManager) flushEngine(ctx context.Context, engineID uuid.UUID) error { + engine := em.rLockEngine(engineID) + + // the engine cannot be deleted after while we've acquired the lock identified by UUID. 
+ if engine == nil { + return errors.Errorf("engine '%s' not found", engineID) + } + defer engine.rUnlock() + if engine.closed.Load() { + return nil + } + return engine.flushEngineWithoutLock(ctx) +} + +// flushAllEngines flush all engines. +func (em *engineManager) flushAllEngines(parentCtx context.Context) (err error) { + allEngines := em.tryRLockAllEngines() + defer func() { + for _, engine := range allEngines { + engine.rUnlock() + } + }() + + eg, ctx := errgroup.WithContext(parentCtx) + for _, engine := range allEngines { + e := engine + eg.Go(func() error { + return e.flushEngineWithoutLock(ctx) + }) + } + return eg.Wait() +} + +func (em *engineManager) openEngineDB(engineUUID uuid.UUID, readOnly bool) (*pebble.DB, error) { + opt := &pebble.Options{ + MemTableSize: uint64(em.MemTableSize), + // the default threshold value may cause write stall. + MemTableStopWritesThreshold: 8, + MaxConcurrentCompactions: func() int { return 16 }, + // set threshold to half of the max open files to avoid trigger compaction + L0CompactionThreshold: math.MaxInt32, + L0StopWritesThreshold: math.MaxInt32, + LBaseMaxBytes: 16 * units.TiB, + MaxOpenFiles: em.MaxOpenFiles, + DisableWAL: true, + ReadOnly: readOnly, + TablePropertyCollectors: []func() pebble.TablePropertyCollector{ + newRangePropertiesCollector, + }, + DisableAutomaticCompactions: em.DisableAutomaticCompactions, + } + // set level target file size to avoid pebble auto triggering compaction that split ingest SST files into small SST. + opt.Levels = []pebble.LevelOptions{ + { + TargetFileSize: 16 * units.GiB, + BlockSize: em.BlockSize, + }, + } + + dbPath := filepath.Join(em.LocalStoreDir, engineUUID.String()) + db, err := pebble.Open(dbPath, opt) + return db, errors.Trace(err) +} + +// openEngine must be called with holding mutex of Engine. +func (em *engineManager) openEngine(ctx context.Context, cfg *backend.EngineConfig, engineUUID uuid.UUID) error { + db, err := em.openEngineDB(engineUUID, false) + if err != nil { + return err + } + + sstDir := engineSSTDir(em.LocalStoreDir, engineUUID) + if !cfg.KeepSortDir { + if err := os.RemoveAll(sstDir); err != nil { + return errors.Trace(err) + } + } + if !common.IsDirExists(sstDir) { + if err := os.Mkdir(sstDir, 0o750); err != nil { + return errors.Trace(err) + } + } + engineCtx, cancel := context.WithCancel(ctx) + + e, _ := em.engines.LoadOrStore(engineUUID, &Engine{ + UUID: engineUUID, + sstDir: sstDir, + sstMetasChan: make(chan metaOrFlush, 64), + ctx: engineCtx, + cancel: cancel, + config: cfg.Local, + tableInfo: cfg.TableInfo, + duplicateDetection: em.DupeDetectEnabled, + dupDetectOpt: em.DuplicateDetectOpt, + duplicateDB: em.duplicateDB, + keyAdapter: em.keyAdapter, + logger: log.FromContext(ctx), + }) + engine := e.(*Engine) + engine.lock(importMutexStateOpen) + defer engine.unlock() + engine.db.Store(db) + engine.sstIngester = dbSSTIngester{e: engine} + if err = engine.loadEngineMeta(); err != nil { + return errors.Trace(err) + } + if engine.TS == 0 && cfg.TS > 0 { + engine.TS = cfg.TS + // we don't saveEngineMeta here, we can rely on the caller use the same TS to + // open the engine again. + } + if err = em.allocateTSIfNotExists(ctx, engine); err != nil { + return errors.Trace(err) + } + engine.wg.Add(1) + go engine.ingestSSTLoop() + return nil +} + +// closeEngine closes backend engine by uuid. 
+func (em *engineManager) closeEngine( + ctx context.Context, + cfg *backend.EngineConfig, + engineUUID uuid.UUID, +) (errRet error) { + if externalCfg := cfg.External; externalCfg != nil { + storeBackend, err := storage.ParseBackend(externalCfg.StorageURI, nil) + if err != nil { + return err + } + store, err := storage.NewWithDefaultOpt(ctx, storeBackend) + if err != nil { + return err + } + defer func() { + if errRet != nil { + store.Close() + } + }() + ts := cfg.TS + if ts == 0 { + physical, logical, err := em.GetTS(ctx) + if err != nil { + return err + } + ts = oracle.ComposeTS(physical, logical) + } + externalEngine := external.NewExternalEngine( + store, + externalCfg.DataFiles, + externalCfg.StatFiles, + externalCfg.StartKey, + externalCfg.EndKey, + externalCfg.SplitKeys, + externalCfg.RegionSplitSize, + em.keyAdapter, + em.DupeDetectEnabled, + em.duplicateDB, + em.DuplicateDetectOpt, + em.WorkerConcurrency, + ts, + externalCfg.TotalFileSize, + externalCfg.TotalKVCount, + externalCfg.CheckHotspot, + ) + em.externalEngine[engineUUID] = externalEngine + return nil + } + + // flush mem table to storage, to free memory, + // ask others' advise, looks like unnecessary, but with this we can control memory precisely. + engineI, ok := em.engines.Load(engineUUID) + if !ok { + // recovery mode, we should reopen this engine file + db, err := em.openEngineDB(engineUUID, true) + if err != nil { + return err + } + engine := &Engine{ + UUID: engineUUID, + sstMetasChan: make(chan metaOrFlush), + tableInfo: cfg.TableInfo, + keyAdapter: em.keyAdapter, + duplicateDetection: em.DupeDetectEnabled, + dupDetectOpt: em.DuplicateDetectOpt, + duplicateDB: em.duplicateDB, + logger: log.FromContext(ctx), + } + engine.db.Store(db) + engine.sstIngester = dbSSTIngester{e: engine} + if err = engine.loadEngineMeta(); err != nil { + return errors.Trace(err) + } + em.engines.Store(engineUUID, engine) + return nil + } + + engine := engineI.(*Engine) + engine.rLock() + if engine.closed.Load() { + engine.rUnlock() + return nil + } + + err := engine.flushEngineWithoutLock(ctx) + engine.rUnlock() + + // use mutex to make sure we won't close sstMetasChan while other routines + // trying to do flush. + engine.lock(importMutexStateClose) + engine.closed.Store(true) + close(engine.sstMetasChan) + engine.unlock() + if err != nil { + return errors.Trace(err) + } + engine.wg.Wait() + return engine.ingestErr.Get() +} + +// getImportedKVCount returns the number of imported KV pairs of some engine. +func (em *engineManager) getImportedKVCount(engineUUID uuid.UUID) int64 { + v, ok := em.engines.Load(engineUUID) + if !ok { + // we get it after import, but before clean up, so this should not happen + // todo: return error + return 0 + } + e := v.(*Engine) + return e.importedKVCount.Load() +} + +// getExternalEngineKVStatistics returns kv statistics of some engine. +func (em *engineManager) getExternalEngineKVStatistics(engineUUID uuid.UUID) ( + totalKVSize int64, totalKVCount int64) { + v, ok := em.externalEngine[engineUUID] + if !ok { + return 0, 0 + } + return v.ImportedStatistics() +} + +// resetEngine reset the engine and reclaim the space. 
+func (em *engineManager) resetEngine( + ctx context.Context, + engineUUID uuid.UUID, + skipAllocTS bool, +) error { + // the only way to reset the engine + reclaim the space is to delete and reopen it 🤷 + localEngine := em.lockEngine(engineUUID, importMutexStateClose) + if localEngine == nil { + if engineI, ok := em.externalEngine[engineUUID]; ok { + extEngine := engineI.(*external.Engine) + return extEngine.Reset() + } + + log.FromContext(ctx).Warn("could not find engine in cleanupEngine", zap.Stringer("uuid", engineUUID)) + return nil + } + defer localEngine.unlock() + if err := localEngine.Close(); err != nil { + return err + } + if err := localEngine.Cleanup(em.LocalStoreDir); err != nil { + return err + } + db, err := em.openEngineDB(engineUUID, false) + if err == nil { + localEngine.db.Store(db) + localEngine.engineMeta = engineMeta{} + if !common.IsDirExists(localEngine.sstDir) { + if err := os.Mkdir(localEngine.sstDir, 0o750); err != nil { + return errors.Trace(err) + } + } + if !skipAllocTS { + if err = em.allocateTSIfNotExists(ctx, localEngine); err != nil { + return errors.Trace(err) + } + } + } + localEngine.pendingFileSize.Store(0) + + return err +} + +func (em *engineManager) allocateTSIfNotExists(ctx context.Context, engine *Engine) error { + if engine.TS > 0 { + return nil + } + physical, logical, err := em.GetTS(ctx) + if err != nil { + return err + } + ts := oracle.ComposeTS(physical, logical) + engine.TS = ts + return engine.saveEngineMeta() +} + +// cleanupEngine cleanup the engine and reclaim the space. +func (em *engineManager) cleanupEngine(ctx context.Context, engineUUID uuid.UUID) error { + localEngine := em.lockEngine(engineUUID, importMutexStateClose) + // release this engine after import success + if localEngine == nil { + if extEngine, ok := em.externalEngine[engineUUID]; ok { + retErr := extEngine.Close() + delete(em.externalEngine, engineUUID) + return retErr + } + log.FromContext(ctx).Warn("could not find engine in cleanupEngine", zap.Stringer("uuid", engineUUID)) + return nil + } + defer localEngine.unlock() + + // since closing the engine causes all subsequent operations on it panic, + // we make sure to delete it from the engine map before calling Close(). + // (note that Close() returning error does _not_ mean the pebble DB + // remains open/usable.) + em.engines.Delete(engineUUID) + err := localEngine.Close() + if err != nil { + return err + } + err = localEngine.Cleanup(em.LocalStoreDir) + if err != nil { + return err + } + localEngine.TotalSize.Store(0) + localEngine.Length.Store(0) + return nil +} + +// LocalWriter returns a new local writer. 
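// A minimal sketch of how the engine timestamp in allocateTSIfNotExists above
// is derived: fetch a physical/logical pair from PD and pack it into a single
// TSO with oracle.ComposeTS. getTS is an illustrative stand-in for
// StoreHelper.GetTS.
package sketch

import (
	"context"

	"github.com/tikv/client-go/v2/oracle"
)

func engineTS(ctx context.Context, getTS func(context.Context) (physical, logical int64, err error)) (uint64, error) {
	physical, logical, err := getTS(ctx)
	if err != nil {
		return 0, err
	}
	// ComposeTS packs the physical (milliseconds) and logical parts into one
	// 64-bit timestamp, the same value stored as the engine's TS.
	return oracle.ComposeTS(physical, logical), nil
}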
+func (em *engineManager) localWriter(_ context.Context, cfg *backend.LocalWriterConfig, engineUUID uuid.UUID) (backend.EngineWriter, error) { + e, ok := em.engines.Load(engineUUID) + if !ok { + return nil, errors.Errorf("could not find engine for %s", engineUUID.String()) + } + engine := e.(*Engine) + memCacheSize := em.LocalWriterMemCacheSize + if cfg.Local.MemCacheSize > 0 { + memCacheSize = cfg.Local.MemCacheSize + } + return openLocalWriter(cfg, engine, em.GetTiKVCodec(), memCacheSize, em.bufferPool.NewBuffer()) +} + +func (em *engineManager) engineFileSizes() (res []backend.EngineFileSize) { + em.engines.Range(func(_, v any) bool { + engine := v.(*Engine) + res = append(res, engine.getEngineFileSize()) + return true + }) + return +} + +func (em *engineManager) close() { + for _, e := range em.externalEngine { + _ = e.Close() + } + em.externalEngine = map[uuid.UUID]common.Engine{} + allLocalEngines := em.lockAllEnginesUnless(importMutexStateClose, 0) + for _, e := range allLocalEngines { + _ = e.Close() + e.unlock() + } + em.engines = sync.Map{} + em.bufferPool.Destroy() + + if em.duplicateDB != nil { + // Check if there are duplicates that are not collected. + iter, err := em.duplicateDB.NewIter(&pebble.IterOptions{}) + if err != nil { + em.logger.Panic("fail to create iterator") + } + hasDuplicates := iter.First() + allIsWell := true + if err := iter.Error(); err != nil { + em.logger.Warn("iterate duplicate db failed", zap.Error(err)) + allIsWell = false + } + if err := iter.Close(); err != nil { + em.logger.Warn("close duplicate db iter failed", zap.Error(err)) + allIsWell = false + } + if err := em.duplicateDB.Close(); err != nil { + em.logger.Warn("close duplicate db failed", zap.Error(err)) + allIsWell = false + } + // If checkpoint is disabled, or we don't detect any duplicate, then this duplicate + // db dir will be useless, so we clean up this dir. + if allIsWell && (!em.CheckpointEnabled || !hasDuplicates) { + if err := os.RemoveAll(filepath.Join(em.LocalStoreDir, duplicateDBName)); err != nil { + em.logger.Warn("remove duplicate db file failed", zap.Error(err)) + } + } + em.duplicateDB = nil + } + + // if checkpoint is disabled, or we finish load all data successfully, then files in this + // dir will be useless, so we clean up this dir and all files in it. + if !em.CheckpointEnabled || common.IsEmptyDir(em.LocalStoreDir) { + err := os.RemoveAll(em.LocalStoreDir) + if err != nil { + em.logger.Warn("remove local db file failed", zap.Error(err)) + } + } +} + +func (em *engineManager) getExternalEngine(uuid uuid.UUID) (common.Engine, bool) { + e, ok := em.externalEngine[uuid] + return e, ok +} + +func (em *engineManager) totalMemoryConsume() int64 { + var memConsume int64 + em.engines.Range(func(_, v any) bool { + e := v.(*Engine) + if e != nil { + memConsume += e.TotalMemorySize() + } + return true + }) + return memConsume + em.bufferPool.TotalSize() +} + +func (em *engineManager) getDuplicateDB() *pebble.DB { + return em.duplicateDB +} + +func (em *engineManager) getKeyAdapter() common.KeyAdapter { + return em.keyAdapter +} + +func (em *engineManager) getBufferPool() *membuf.Pool { + return em.bufferPool +} + +// only used in tests +type slowCreateFS struct { + vfs.FS +} + +// WaitRMFolderChForTest is a channel for testing. 
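// A minimal sketch of the "does this pebble DB contain anything?" probe that
// close() performs above on the duplicate-detection DB: open an iterator,
// seek to the first key, and fold the iterator and close errors together.
// The function name is illustrative.
package sketch

import (
	"github.com/cockroachdb/pebble"
)

// hasAnyKey reports whether db currently stores at least one key.
func hasAnyKey(db *pebble.DB) (bool, error) {
	iter, err := db.NewIter(&pebble.IterOptions{})
	if err != nil {
		return false, err
	}
	found := iter.First()
	if err := iter.Error(); err != nil {
		_ = iter.Close()
		return false, err
	}
	if err := iter.Close(); err != nil {
		return false, err
	}
	return found, nil
}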
+var WaitRMFolderChForTest = make(chan struct{}) + +func (s slowCreateFS) Create(name string) (vfs.File, error) { + if strings.Contains(name, "temporary") { + select { + case <-WaitRMFolderChForTest: + case <-time.After(1 * time.Second): + logutil.BgLogger().Info("no one removes folder") + } + } + return s.FS.Create(name) +} + +func openDuplicateDB(storeDir string) (*pebble.DB, error) { + dbPath := filepath.Join(storeDir, duplicateDBName) + // TODO: Optimize the opts for better write. + opts := &pebble.Options{ + TablePropertyCollectors: []func() pebble.TablePropertyCollector{ + newRangePropertiesCollector, + }, + } + failpoint.Inject("slowCreateFS", func() { + opts.FS = slowCreateFS{vfs.Default} + }) + return pebble.Open(dbPath, opts) +} + +func prepareSortDir(config BackendConfig) error { + shouldCreate := true + if config.CheckpointEnabled { + if info, err := os.Stat(config.LocalStoreDir); err != nil { + if !os.IsNotExist(err) { + return err + } + } else if info.IsDir() { + shouldCreate = false + } + } + + if shouldCreate { + err := os.Mkdir(config.LocalStoreDir, 0o700) + if err != nil { + return common.ErrInvalidSortedKVDir.Wrap(err).GenWithStackByArgs(config.LocalStoreDir) + } + } + return nil +} diff --git a/pkg/lightning/backend/local/local.go b/pkg/lightning/backend/local/local.go index c38a01936e504..6b2678b71ab1c 100644 --- a/pkg/lightning/backend/local/local.go +++ b/pkg/lightning/backend/local/local.go @@ -183,7 +183,7 @@ func (f *importClientFactoryImpl) makeConn(ctx context.Context, storeID uint64) return nil, common.ErrInvalidConfig.GenWithStack("unsupported compression type %s", f.compressionType) } - failpoint.Inject("LoggingImportBytes", func() { + if _, _err_ := failpoint.Eval(_curpkg_("LoggingImportBytes")); _err_ == nil { opts = append(opts, grpc.WithContextDialer(func(ctx context.Context, target string) (net.Conn, error) { conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", target) if err != nil { @@ -191,7 +191,7 @@ func (f *importClientFactoryImpl) makeConn(ctx context.Context, storeID uint64) } return &loggingConn{Conn: conn}, nil })) - }) + } conn, err := grpc.DialContext(ctx, addr, opts...) if err != nil { @@ -891,18 +891,18 @@ func (local *Backend) prepareAndSendJob( // the table when table is created. 
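// A minimal sketch of the failpoint pattern these hunks rewrite: developers
// write the failpoint.Inject marker form (kept in the *__failpoint_stash__
// files), and failpoint-ctl expands it into the explicit failpoint.Eval form
// seen in the diff, with _curpkg_ prepending the package path. The failpoint
// name and package path below are illustrative assumptions.
package sketch

import (
	"fmt"

	"github.com/pingcap/failpoint"
)

func doWork() string {
	// Marker form: a no-op unless the package was rewritten by failpoint-ctl
	// and the failpoint was enabled at run time.
	failpoint.Inject("mockSlowPath", func() {
		fmt.Println("failpoint mockSlowPath triggered")
	})
	return "done"
}

// exampleEnable shows how a test would activate the failpoint; the
// fully-qualified name is "<package path>/mockSlowPath", assumed here.
func exampleEnable() error {
	if err := failpoint.Enable("example.com/pkg/sketch/mockSlowPath", "return(true)"); err != nil {
		return err
	}
	defer func() {
		_ = failpoint.Disable("example.com/pkg/sketch/mockSlowPath")
	}()
	_ = doWork()
	return nil
}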
needSplit := len(initialSplitRanges) > 1 || lfTotalSize > regionSplitSize || lfLength > regionSplitKeys // split region by given ranges - failpoint.Inject("failToSplit", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("failToSplit")); _err_ == nil { needSplit = true - }) + } if needSplit { var err error logger := log.FromContext(ctx).With(zap.String("uuid", engine.ID())).Begin(zap.InfoLevel, "split and scatter ranges") backOffTime := 10 * time.Second maxbackoffTime := 120 * time.Second for i := 0; i < maxRetryTimes; i++ { - failpoint.Inject("skipSplitAndScatter", func() { - failpoint.Break() - }) + if _, _err_ := failpoint.Eval(_curpkg_("skipSplitAndScatter")); _err_ == nil { + break + } err = local.SplitAndScatterRegionInBatches(ctx, initialSplitRanges, maxBatchSplitRanges) if err == nil || common.IsContextCanceledError(err) { @@ -987,13 +987,13 @@ func (local *Backend) generateAndSendJob( return nil } - failpoint.Inject("beforeGenerateJob", nil) - failpoint.Inject("sendDummyJob", func(_ failpoint.Value) { + failpoint.Eval(_curpkg_("beforeGenerateJob")) + if _, _err_ := failpoint.Eval(_curpkg_("sendDummyJob")); _err_ == nil { // this is used to trigger worker failure, used together // with WriteToTiKVNotEnoughDiskSpace jobToWorkerCh <- ®ionJob{} time.Sleep(5 * time.Second) - }) + } jobs, err := local.generateJobForRange(egCtx, p.Data, p.Range, regionSplitSize, regionSplitKeys) if err != nil { if common.IsContextCanceledError(err) { @@ -1045,9 +1045,9 @@ func (local *Backend) generateJobForRange( keyRange common.Range, regionSplitSize, regionSplitKeys int64, ) ([]*regionJob, error) { - failpoint.Inject("fakeRegionJobs", func() { + if _, _err_ := failpoint.Eval(_curpkg_("fakeRegionJobs")); _err_ == nil { if ctx.Err() != nil { - failpoint.Return(nil, ctx.Err()) + return nil, ctx.Err() } key := [2]string{string(keyRange.Start), string(keyRange.End)} injected := fakeRegionJobs[key] @@ -1056,8 +1056,8 @@ func (local *Backend) generateJobForRange( for _, job := range injected.jobs { job.stage = regionScanned } - failpoint.Return(injected.jobs, injected.err) - }) + return injected.jobs, injected.err + } start, end := keyRange.Start, keyRange.End pairStart, pairEnd, err := data.GetFirstAndLastKey(start, end) @@ -1226,10 +1226,9 @@ func (local *Backend) executeJob( ctx context.Context, job *regionJob, ) error { - failpoint.Inject("WriteToTiKVNotEnoughDiskSpace", func(_ failpoint.Value) { - failpoint.Return( - errors.New("the remaining storage capacity of TiKV is less than 10%%; please increase the storage capacity of TiKV and try again")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("WriteToTiKVNotEnoughDiskSpace")); _err_ == nil { + return errors.New("the remaining storage capacity of TiKV is less than 10%%; please increase the storage capacity of TiKV and try again") + } if local.ShouldCheckTiKV { for _, peer := range job.region.Region.GetPeers() { store, err := local.pdHTTPCli.GetStore(ctx, peer.StoreId) @@ -1371,7 +1370,7 @@ func (local *Backend) ImportEngine( zap.Int64("count", lfLength), zap.Int64("size", lfTotalSize)) - failpoint.Inject("ReadyForImportEngine", func() {}) + failpoint.Eval(_curpkg_("ReadyForImportEngine")) err = local.doImport(ctx, e, regionRanges, regionSplitSize, regionSplitKeys) if err == nil { @@ -1421,10 +1420,10 @@ func (local *Backend) doImport(ctx context.Context, engine common.Engine, region ) defer workerCancel() - failpoint.Inject("injectVariables", func() { + if _, _err_ := failpoint.Eval(_curpkg_("injectVariables")); _err_ == nil { 
jobToWorkerCh = testJobToWorkerCh testJobWg = &jobWg - }) + } retryer := startRegionJobRetryer(workerCtx, jobToWorkerCh, &jobWg) @@ -1476,17 +1475,16 @@ func (local *Backend) doImport(ctx context.Context, engine common.Engine, region } }() - failpoint.Inject("skipStartWorker", func() { - failpoint.Goto("afterStartWorker") - }) + if _, _err_ := failpoint.Eval(_curpkg_("skipStartWorker")); _err_ == nil { + goto afterStartWorker + } for i := 0; i < local.WorkerConcurrency; i++ { workGroup.Go(func() error { return local.startWorker(workerCtx, jobToWorkerCh, jobFromWorkerCh, &jobWg) }) } - - failpoint.Label("afterStartWorker") +afterStartWorker: workGroup.Go(func() error { err := local.prepareAndSendJob( diff --git a/pkg/lightning/backend/local/local.go__failpoint_stash__ b/pkg/lightning/backend/local/local.go__failpoint_stash__ new file mode 100644 index 0000000000000..c38a01936e504 --- /dev/null +++ b/pkg/lightning/backend/local/local.go__failpoint_stash__ @@ -0,0 +1,1754 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package local + +import ( + "bytes" + "context" + "database/sql" + "encoding/hex" + "io" + "math" + "net" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/coreos/go-semver/semver" + "github.com/docker/go-units" + "github.com/google/uuid" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + sst "github.com/pingcap/kvproto/pkg/import_sstpb" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/membuf" + "github.com/pingcap/tidb/br/pkg/pdutil" + "github.com/pingcap/tidb/br/pkg/restore/split" + "github.com/pingcap/tidb/br/pkg/version" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/lightning/backend" + "github.com/pingcap/tidb/pkg/lightning/backend/encode" + "github.com/pingcap/tidb/pkg/lightning/backend/external" + "github.com/pingcap/tidb/pkg/lightning/backend/kv" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/lightning/errormanager" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/lightning/metric" + "github.com/pingcap/tidb/pkg/lightning/tikv" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/engine" + tikvclient "github.com/tikv/client-go/v2/tikv" + pd "github.com/tikv/pd/client" + pdhttp "github.com/tikv/pd/client/http" + "github.com/tikv/pd/client/retry" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/backoff" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/status" +) + +const ( + dialTimeout = 5 * time.Minute + maxRetryTimes = 20 + defaultRetryBackoffTime = 
3 * time.Second + // maxWriteAndIngestRetryTimes is the max retry times for write and ingest. + // A large retry times is for tolerating tikv cluster failures. + maxWriteAndIngestRetryTimes = 30 + + gRPCKeepAliveTime = 10 * time.Minute + gRPCKeepAliveTimeout = 5 * time.Minute + gRPCBackOffMaxDelay = 10 * time.Minute + + // The max ranges count in a batch to split and scatter. + maxBatchSplitRanges = 4096 + + propRangeIndex = "tikv.range_index" + + defaultPropSizeIndexDistance = 4 * units.MiB + defaultPropKeysIndexDistance = 40 * 1024 + + // the lower threshold of max open files for pebble db. + openFilesLowerThreshold = 128 + + duplicateDBName = "duplicates" + scanRegionLimit = 128 +) + +var ( + // Local backend is compatible with TiDB [4.0.0, NextMajorVersion). + localMinTiDBVersion = *semver.New("4.0.0") + localMinTiKVVersion = *semver.New("4.0.0") + localMinPDVersion = *semver.New("4.0.0") + localMaxTiDBVersion = version.NextMajorVersion() + localMaxTiKVVersion = version.NextMajorVersion() + localMaxPDVersion = version.NextMajorVersion() + tiFlashMinVersion = *semver.New("4.0.5") + tikvSideFreeSpaceCheck = *semver.New("8.0.0") + + errorEngineClosed = errors.New("engine is closed") + maxRetryBackoffSecond = 30 +) + +// ImportClientFactory is factory to create new import client for specific store. +type ImportClientFactory interface { + Create(ctx context.Context, storeID uint64) (sst.ImportSSTClient, error) + Close() +} + +type importClientFactoryImpl struct { + conns *common.GRPCConns + splitCli split.SplitClient + tls *common.TLS + tcpConcurrency int + compressionType config.CompressionType +} + +func newImportClientFactoryImpl( + splitCli split.SplitClient, + tls *common.TLS, + tcpConcurrency int, + compressionType config.CompressionType, +) *importClientFactoryImpl { + return &importClientFactoryImpl{ + conns: common.NewGRPCConns(), + splitCli: splitCli, + tls: tls, + tcpConcurrency: tcpConcurrency, + compressionType: compressionType, + } +} + +func (f *importClientFactoryImpl) makeConn(ctx context.Context, storeID uint64) (*grpc.ClientConn, error) { + store, err := f.splitCli.GetStore(ctx, storeID) + if err != nil { + return nil, errors.Trace(err) + } + var opts []grpc.DialOption + if f.tls.TLSConfig() != nil { + opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(f.tls.TLSConfig()))) + } else { + opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) + } + ctx, cancel := context.WithTimeout(ctx, dialTimeout) + defer cancel() + + bfConf := backoff.DefaultConfig + bfConf.MaxDelay = gRPCBackOffMaxDelay + // we should use peer address for tiflash. for tikv, peer address is empty + addr := store.GetPeerAddress() + if addr == "" { + addr = store.GetAddress() + } + opts = append(opts, + grpc.WithConnectParams(grpc.ConnectParams{Backoff: bfConf}), + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: gRPCKeepAliveTime, + Timeout: gRPCKeepAliveTimeout, + PermitWithoutStream: true, + }), + ) + switch f.compressionType { + case config.CompressionNone: + // do nothing + case config.CompressionGzip: + // Use custom compressor/decompressor to speed up compression/decompression. + // Note that here we don't use grpc.UseCompressor option although it's the recommended way. + // Because gprc-go uses a global registry to store compressor/decompressor, we can't make sure + // the compressor/decompressor is not registered by other components. 
+ opts = append(opts, grpc.WithCompressor(&gzipCompressor{}), grpc.WithDecompressor(&gzipDecompressor{})) + default: + return nil, common.ErrInvalidConfig.GenWithStack("unsupported compression type %s", f.compressionType) + } + + failpoint.Inject("LoggingImportBytes", func() { + opts = append(opts, grpc.WithContextDialer(func(ctx context.Context, target string) (net.Conn, error) { + conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", target) + if err != nil { + return nil, err + } + return &loggingConn{Conn: conn}, nil + })) + }) + + conn, err := grpc.DialContext(ctx, addr, opts...) + if err != nil { + return nil, errors.Trace(err) + } + return conn, nil +} + +func (f *importClientFactoryImpl) getGrpcConn(ctx context.Context, storeID uint64) (*grpc.ClientConn, error) { + return f.conns.GetGrpcConn(ctx, storeID, f.tcpConcurrency, + func(ctx context.Context) (*grpc.ClientConn, error) { + return f.makeConn(ctx, storeID) + }) +} + +// Create creates a new import client for specific store. +func (f *importClientFactoryImpl) Create(ctx context.Context, storeID uint64) (sst.ImportSSTClient, error) { + conn, err := f.getGrpcConn(ctx, storeID) + if err != nil { + return nil, err + } + return sst.NewImportSSTClient(conn), nil +} + +// Close closes the factory. +func (f *importClientFactoryImpl) Close() { + f.conns.Close() +} + +type loggingConn struct { + net.Conn +} + +// Write implements net.Conn.Write +func (c loggingConn) Write(b []byte) (int, error) { + log.L().Debug("import write", zap.Int("bytes", len(b))) + return c.Conn.Write(b) +} + +type encodingBuilder struct { + metrics *metric.Metrics +} + +// NewEncodingBuilder creates an KVEncodingBuilder with local backend implementation. +func NewEncodingBuilder(ctx context.Context) encode.EncodingBuilder { + result := new(encodingBuilder) + if m, ok := metric.FromContext(ctx); ok { + result.metrics = m + } + return result +} + +// NewEncoder creates a KV encoder. +// It implements the `backend.EncodingBuilder` interface. +func (b *encodingBuilder) NewEncoder(_ context.Context, config *encode.EncodingConfig) (encode.Encoder, error) { + return kv.NewTableKVEncoder(config, b.metrics) +} + +// MakeEmptyRows creates an empty KV rows. +// It implements the `backend.EncodingBuilder` interface. +func (*encodingBuilder) MakeEmptyRows() encode.Rows { + return kv.MakeRowsFromKvPairs(nil) +} + +type targetInfoGetter struct { + tls *common.TLS + targetDB *sql.DB + pdHTTPCli pdhttp.Client +} + +// NewTargetInfoGetter creates an TargetInfoGetter with local backend +// implementation. `pdHTTPCli` should not be nil when need to check component +// versions in CheckRequirements. +func NewTargetInfoGetter( + tls *common.TLS, + db *sql.DB, + pdHTTPCli pdhttp.Client, +) backend.TargetInfoGetter { + return &targetInfoGetter{ + tls: tls, + targetDB: db, + pdHTTPCli: pdHTTPCli, + } +} + +// FetchRemoteDBModels implements the `backend.TargetInfoGetter` interface. +func (g *targetInfoGetter) FetchRemoteDBModels(ctx context.Context) ([]*model.DBInfo, error) { + return tikv.FetchRemoteDBModelsFromTLS(ctx, g.tls) +} + +// FetchRemoteTableModels obtains the models of all tables given the schema name. +// It implements the `TargetInfoGetter` interface. +func (g *targetInfoGetter) FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) { + return tikv.FetchRemoteTableModelsFromTLS(ctx, g.tls, schemaName) +} + +// CheckRequirements performs the check whether the backend satisfies the version requirements. 
+// It implements the `TargetInfoGetter` interface. +func (g *targetInfoGetter) CheckRequirements(ctx context.Context, checkCtx *backend.CheckCtx) error { + // TODO: support lightning via SQL + versionStr, err := version.FetchVersion(ctx, g.targetDB) + if err != nil { + return errors.Trace(err) + } + if err := checkTiDBVersion(ctx, versionStr, localMinTiDBVersion, localMaxTiDBVersion); err != nil { + return err + } + if g.pdHTTPCli == nil { + return common.ErrUnknown.GenWithStack("pd HTTP client is required for component version check in local backend") + } + if err := tikv.CheckPDVersion(ctx, g.pdHTTPCli, localMinPDVersion, localMaxPDVersion); err != nil { + return err + } + if err := tikv.CheckTiKVVersion(ctx, g.pdHTTPCli, localMinTiKVVersion, localMaxTiKVVersion); err != nil { + return err + } + + serverInfo := version.ParseServerInfo(versionStr) + return checkTiFlashVersion(ctx, g.targetDB, checkCtx, *serverInfo.ServerVersion) +} + +func checkTiDBVersion(_ context.Context, versionStr string, requiredMinVersion, requiredMaxVersion semver.Version) error { + return version.CheckTiDBVersion(versionStr, requiredMinVersion, requiredMaxVersion) +} + +var tiFlashReplicaQuery = "SELECT TABLE_SCHEMA, TABLE_NAME FROM information_schema.TIFLASH_REPLICA WHERE REPLICA_COUNT > 0;" + +// TiFlashReplicaQueryForTest is only used for tests. +var TiFlashReplicaQueryForTest = tiFlashReplicaQuery + +type tblName struct { + schema string + name string +} + +type tblNames []tblName + +// String implements fmt.Stringer +func (t tblNames) String() string { + var b strings.Builder + b.WriteByte('[') + for i, n := range t { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(common.UniqueTable(n.schema, n.name)) + } + b.WriteByte(']') + return b.String() +} + +// CheckTiFlashVersionForTest is only used for tests. +var CheckTiFlashVersionForTest = checkTiFlashVersion + +// check TiFlash replicas. +// local backend doesn't support TiFlash before tidb v4.0.5 +func checkTiFlashVersion(ctx context.Context, db *sql.DB, checkCtx *backend.CheckCtx, tidbVersion semver.Version) error { + if tidbVersion.Compare(tiFlashMinVersion) >= 0 { + return nil + } + + exec := common.SQLWithRetry{ + DB: db, + Logger: log.FromContext(ctx), + } + + res, err := exec.QueryStringRows(ctx, "fetch tiflash replica info", tiFlashReplicaQuery) + if err != nil { + return errors.Annotate(err, "fetch tiflash replica info failed") + } + + tiFlashTablesMap := make(map[tblName]struct{}, len(res)) + for _, tblInfo := range res { + name := tblName{schema: tblInfo[0], name: tblInfo[1]} + tiFlashTablesMap[name] = struct{}{} + } + + tiFlashTables := make(tblNames, 0) + for _, dbMeta := range checkCtx.DBMetas { + for _, tblMeta := range dbMeta.Tables { + if len(tblMeta.DataFiles) == 0 { + continue + } + name := tblName{schema: tblMeta.DB, name: tblMeta.Name} + if _, ok := tiFlashTablesMap[name]; ok { + tiFlashTables = append(tiFlashTables, name) + } + } + } + + if len(tiFlashTables) > 0 { + helpInfo := "Please either upgrade TiDB to version >= 4.0.5 or add TiFlash replica after load data." + return errors.Errorf("lightning local backend doesn't support TiFlash in this TiDB version. conflict tables: %s. "+helpInfo, tiFlashTables) + } + return nil +} + +// BackendConfig is the config for local backend. +type BackendConfig struct { + // comma separated list of PD endpoints. + PDAddr string + LocalStoreDir string + // max number of cached grpc.ClientConn to a store. 
+ // note: this is not the limit of actual connections, each grpc.ClientConn can have one or more of it. + MaxConnPerStore int + // compress type when write or ingest into tikv + ConnCompressType config.CompressionType + // concurrency of generateJobForRange and import(write & ingest) workers + WorkerConcurrency int + // batch kv size when writing to TiKV + KVWriteBatchSize int64 + RegionSplitBatchSize int + RegionSplitConcurrency int + CheckpointEnabled bool + // memory table size of pebble. since pebble can have multiple mem tables, the max memory used is + // MemTableSize * MemTableStopWritesThreshold, see pebble.Options for more details. + MemTableSize int + // LocalWriterMemCacheSize is the memory threshold for one local writer of + // engines. If the KV payload size exceeds LocalWriterMemCacheSize, local writer + // will flush them into the engine. + // + // It has lower priority than LocalWriterConfig.Local.MemCacheSize. + LocalWriterMemCacheSize int64 + // whether check TiKV capacity before write & ingest. + ShouldCheckTiKV bool + DupeDetectEnabled bool + DuplicateDetectOpt common.DupDetectOpt + // max write speed in bytes per second to each store(burst is allowed), 0 means no limit + StoreWriteBWLimit int + // When TiKV is in normal mode, ingesting too many SSTs will cause TiKV write stall. + // To avoid this, we should check write stall before ingesting SSTs. Note that, we + // must check both leader node and followers in client side, because followers will + // not check write stall as long as ingest command is accepted by leader. + ShouldCheckWriteStall bool + // soft limit on the number of open files that can be used by pebble DB. + // the minimum value is 128. + MaxOpenFiles int + KeyspaceName string + // the scope when pause PD schedulers. + PausePDSchedulerScope config.PausePDSchedulerScope + ResourceGroupName string + TaskType string + RaftKV2SwitchModeDuration time.Duration + // whether disable automatic compactions of pebble db of engine. + // deduplicate pebble db is not affected by this option. + // see DisableAutomaticCompactions of pebble.Options for more details. + // default true. + DisableAutomaticCompactions bool + BlockSize int +} + +// NewBackendConfig creates a new BackendConfig. 
+func NewBackendConfig(cfg *config.Config, maxOpenFiles int, keyspaceName, resourceGroupName, taskType string, raftKV2SwitchModeDuration time.Duration) BackendConfig { + return BackendConfig{ + PDAddr: cfg.TiDB.PdAddr, + LocalStoreDir: cfg.TikvImporter.SortedKVDir, + MaxConnPerStore: cfg.TikvImporter.RangeConcurrency, + ConnCompressType: cfg.TikvImporter.CompressKVPairs, + WorkerConcurrency: cfg.TikvImporter.RangeConcurrency * 2, + BlockSize: int(cfg.TikvImporter.BlockSize), + KVWriteBatchSize: int64(cfg.TikvImporter.SendKVSize), + RegionSplitBatchSize: cfg.TikvImporter.RegionSplitBatchSize, + RegionSplitConcurrency: cfg.TikvImporter.RegionSplitConcurrency, + CheckpointEnabled: cfg.Checkpoint.Enable, + MemTableSize: int(cfg.TikvImporter.EngineMemCacheSize), + LocalWriterMemCacheSize: int64(cfg.TikvImporter.LocalWriterMemCacheSize), + ShouldCheckTiKV: cfg.App.CheckRequirements, + DupeDetectEnabled: cfg.Conflict.Strategy != config.NoneOnDup, + DuplicateDetectOpt: common.DupDetectOpt{ReportErrOnDup: cfg.Conflict.Strategy == config.ErrorOnDup}, + StoreWriteBWLimit: int(cfg.TikvImporter.StoreWriteBWLimit), + ShouldCheckWriteStall: cfg.Cron.SwitchMode.Duration == 0, + MaxOpenFiles: maxOpenFiles, + KeyspaceName: keyspaceName, + PausePDSchedulerScope: cfg.TikvImporter.PausePDSchedulerScope, + ResourceGroupName: resourceGroupName, + TaskType: taskType, + RaftKV2SwitchModeDuration: raftKV2SwitchModeDuration, + DisableAutomaticCompactions: true, + } +} + +func (c *BackendConfig) adjust() { + c.MaxOpenFiles = max(c.MaxOpenFiles, openFilesLowerThreshold) +} + +// Backend is a local backend. +type Backend struct { + pdCli pd.Client + pdHTTPCli pdhttp.Client + splitCli split.SplitClient + tikvCli *tikvclient.KVStore + tls *common.TLS + tikvCodec tikvclient.Codec + + BackendConfig + engineMgr *engineManager + + supportMultiIngest bool + importClientFactory ImportClientFactory + + metrics *metric.Common + writeLimiter StoreWriteLimiter + logger log.Logger + // This mutex is used to do some mutual exclusion work in the backend, flushKVs() in writer for now. + mu sync.Mutex +} + +var _ DiskUsage = (*Backend)(nil) +var _ StoreHelper = (*Backend)(nil) +var _ backend.Backend = (*Backend)(nil) + +const ( + pdCliMaxMsgSize = int(128 * units.MiB) // pd.ScanRegion may return a large response +) + +var ( + maxCallMsgSize = []grpc.DialOption{ + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(pdCliMaxMsgSize)), + grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(pdCliMaxMsgSize)), + } +) + +// NewBackend creates new connections to tikv. +func NewBackend( + ctx context.Context, + tls *common.TLS, + config BackendConfig, + pdSvcDiscovery pd.ServiceDiscovery, +) (b *Backend, err error) { + var ( + pdCli pd.Client + spkv *tikvclient.EtcdSafePointKV + pdCliForTiKV *tikvclient.CodecPDClient + rpcCli tikvclient.Client + tikvCli *tikvclient.KVStore + pdHTTPCli pdhttp.Client + importClientFactory *importClientFactoryImpl + multiIngestSupported bool + ) + defer func() { + if err == nil { + return + } + if importClientFactory != nil { + importClientFactory.Close() + } + if pdHTTPCli != nil { + pdHTTPCli.Close() + } + if tikvCli != nil { + // tikvCli uses pdCliForTiKV(which wraps pdCli) , spkv and rpcCli, so + // close tikvCli will close all of them. 
+ _ = tikvCli.Close() + } else { + if rpcCli != nil { + _ = rpcCli.Close() + } + if spkv != nil { + _ = spkv.Close() + } + // pdCliForTiKV wraps pdCli, so we only need close pdCli + if pdCli != nil { + pdCli.Close() + } + } + }() + config.adjust() + var pdAddrs []string + if pdSvcDiscovery != nil { + pdAddrs = pdSvcDiscovery.GetServiceURLs() + // TODO(lance6716): if PD client can support creating a client with external + // service discovery, we can directly pass pdSvcDiscovery. + } else { + pdAddrs = strings.Split(config.PDAddr, ",") + } + pdCli, err = pd.NewClientWithContext( + ctx, pdAddrs, tls.ToPDSecurityOption(), + pd.WithGRPCDialOptions(maxCallMsgSize...), + // If the time too short, we may scatter a region many times, because + // the interface `ScatterRegions` may time out. + pd.WithCustomTimeoutOption(60*time.Second), + ) + if err != nil { + return nil, common.NormalizeOrWrapErr(common.ErrCreatePDClient, err) + } + + // The following copies tikv.NewTxnClient without creating yet another pdClient. + spkv, err = tikvclient.NewEtcdSafePointKV(strings.Split(config.PDAddr, ","), tls.TLSConfig()) + if err != nil { + return nil, common.ErrCreateKVClient.Wrap(err).GenWithStackByArgs() + } + + if config.KeyspaceName == "" { + pdCliForTiKV = tikvclient.NewCodecPDClient(tikvclient.ModeTxn, pdCli) + } else { + pdCliForTiKV, err = tikvclient.NewCodecPDClientWithKeyspace(tikvclient.ModeTxn, pdCli, config.KeyspaceName) + if err != nil { + return nil, common.ErrCreatePDClient.Wrap(err).GenWithStackByArgs() + } + } + + tikvCodec := pdCliForTiKV.GetCodec() + rpcCli = tikvclient.NewRPCClient(tikvclient.WithSecurity(tls.ToTiKVSecurityConfig()), tikvclient.WithCodec(tikvCodec)) + tikvCli, err = tikvclient.NewKVStore("lightning-local-backend", pdCliForTiKV, spkv, rpcCli) + if err != nil { + return nil, common.ErrCreateKVClient.Wrap(err).GenWithStackByArgs() + } + pdHTTPCli = pdhttp.NewClientWithServiceDiscovery( + "lightning", + pdCli.GetServiceDiscovery(), + pdhttp.WithTLSConfig(tls.TLSConfig()), + ).WithBackoffer(retry.InitialBackoffer(time.Second, time.Second, pdutil.PDRequestRetryTime*time.Second)) + splitCli := split.NewClient(pdCli, pdHTTPCli, tls.TLSConfig(), config.RegionSplitBatchSize, config.RegionSplitConcurrency) + importClientFactory = newImportClientFactoryImpl(splitCli, tls, config.MaxConnPerStore, config.ConnCompressType) + + multiIngestSupported, err = checkMultiIngestSupport(ctx, pdCli, importClientFactory) + if err != nil { + return nil, common.ErrCheckMultiIngest.Wrap(err).GenWithStackByArgs() + } + + var writeLimiter StoreWriteLimiter + if config.StoreWriteBWLimit > 0 { + writeLimiter = newStoreWriteLimiter(config.StoreWriteBWLimit) + } else { + writeLimiter = noopStoreWriteLimiter{} + } + local := &Backend{ + pdCli: pdCli, + pdHTTPCli: pdHTTPCli, + splitCli: splitCli, + tikvCli: tikvCli, + tls: tls, + tikvCodec: tikvCodec, + + BackendConfig: config, + + supportMultiIngest: multiIngestSupported, + importClientFactory: importClientFactory, + writeLimiter: writeLimiter, + logger: log.FromContext(ctx), + } + local.engineMgr, err = newEngineManager(config, local, local.logger) + if err != nil { + return nil, err + } + if m, ok := metric.GetCommonMetric(ctx); ok { + local.metrics = m + } + local.tikvSideCheckFreeSpace(ctx) + + return local, nil +} + +// NewBackendForTest creates a new Backend for test. 
+func NewBackendForTest(ctx context.Context, config BackendConfig, storeHelper StoreHelper) (*Backend, error) { + config.adjust() + + logger := log.FromContext(ctx) + engineMgr, err := newEngineManager(config, storeHelper, logger) + if err != nil { + return nil, err + } + local := &Backend{ + BackendConfig: config, + logger: logger, + engineMgr: engineMgr, + } + if m, ok := metric.GetCommonMetric(ctx); ok { + local.metrics = m + } + + return local, nil +} + +// TotalMemoryConsume returns the total memory usage of the local backend. +func (local *Backend) TotalMemoryConsume() int64 { + return local.engineMgr.totalMemoryConsume() +} + +func checkMultiIngestSupport(ctx context.Context, pdCli pd.Client, importClientFactory ImportClientFactory) (bool, error) { + stores, err := pdCli.GetAllStores(ctx, pd.WithExcludeTombstone()) + if err != nil { + return false, errors.Trace(err) + } + + hasTiFlash := false + for _, s := range stores { + if s.State == metapb.StoreState_Up && engine.IsTiFlash(s) { + hasTiFlash = true + break + } + } + + for _, s := range stores { + // skip stores that are not online + if s.State != metapb.StoreState_Up || engine.IsTiFlash(s) { + continue + } + var err error + for i := 0; i < maxRetryTimes; i++ { + if i > 0 { + select { + case <-time.After(100 * time.Millisecond): + case <-ctx.Done(): + return false, ctx.Err() + } + } + client, err1 := importClientFactory.Create(ctx, s.Id) + if err1 != nil { + err = err1 + log.FromContext(ctx).Warn("get import client failed", zap.Error(err), zap.String("store", s.Address)) + continue + } + _, err = client.MultiIngest(ctx, &sst.MultiIngestRequest{}) + if err == nil { + break + } + if st, ok := status.FromError(err); ok { + if st.Code() == codes.Unimplemented { + log.FromContext(ctx).Info("multi ingest not support", zap.Any("unsupported store", s)) + return false, nil + } + } + log.FromContext(ctx).Warn("check multi ingest support failed", zap.Error(err), zap.String("store", s.Address), + zap.Int("retry", i)) + } + if err != nil { + // if the cluster contains no TiFlash store, we don't need the multi-ingest feature, + // so in this condition, downgrade the logic instead of return an error. + if hasTiFlash { + return false, errors.Trace(err) + } + log.FromContext(ctx).Warn("check multi failed all retry, fallback to false", log.ShortError(err)) + return false, nil + } + } + + log.FromContext(ctx).Info("multi ingest support") + return true, nil +} + +func (local *Backend) tikvSideCheckFreeSpace(ctx context.Context) { + if !local.ShouldCheckTiKV { + return + } + err := tikv.ForTiKVVersions( + ctx, + local.pdHTTPCli, + func(version *semver.Version, addrMsg string) error { + if version.Compare(tikvSideFreeSpaceCheck) < 0 { + return errors.Errorf( + "%s has version %s, it does not support server side free space check", + addrMsg, version, + ) + } + return nil + }, + ) + if err == nil { + local.logger.Info("TiKV server side free space check is enabled, so lightning will turn it off") + local.ShouldCheckTiKV = false + } else { + local.logger.Info("", zap.Error(err)) + } +} + +// Close the local backend. 
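// A minimal sketch of the capability probe used by checkMultiIngestSupport
// above: issue the RPC once and treat a gRPC Unimplemented status as "feature
// not supported" rather than as a hard error. The call parameter is an
// illustrative stand-in for the MultiIngest request.
package sketch

import (
	"context"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// probeSupport returns (supported, err). A codes.Unimplemented response means
// the store is reachable but does not implement the RPC, so the caller can
// fall back instead of failing.
func probeSupport(ctx context.Context, call func(context.Context) error) (bool, error) {
	err := call(ctx)
	if err == nil {
		return true, nil
	}
	if st, ok := status.FromError(err); ok && st.Code() == codes.Unimplemented {
		return false, nil
	}
	return false, err
}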
+func (local *Backend) Close() { + local.engineMgr.close() + local.importClientFactory.Close() + + _ = local.tikvCli.Close() + local.pdHTTPCli.Close() + local.pdCli.Close() +} + +// FlushEngine ensure the written data is saved successfully, to make sure no data lose after restart +func (local *Backend) FlushEngine(ctx context.Context, engineID uuid.UUID) error { + return local.engineMgr.flushEngine(ctx, engineID) +} + +// FlushAllEngines flush all engines. +func (local *Backend) FlushAllEngines(parentCtx context.Context) (err error) { + return local.engineMgr.flushAllEngines(parentCtx) +} + +// RetryImportDelay returns the delay time before retrying to import a file. +func (*Backend) RetryImportDelay() time.Duration { + return defaultRetryBackoffTime +} + +// ShouldPostProcess returns true if the backend should post process the data. +func (*Backend) ShouldPostProcess() bool { + return true +} + +// OpenEngine must be called with holding mutex of Engine. +func (local *Backend) OpenEngine(ctx context.Context, cfg *backend.EngineConfig, engineUUID uuid.UUID) error { + return local.engineMgr.openEngine(ctx, cfg, engineUUID) +} + +// CloseEngine closes backend engine by uuid. +func (local *Backend) CloseEngine(ctx context.Context, cfg *backend.EngineConfig, engineUUID uuid.UUID) error { + return local.engineMgr.closeEngine(ctx, cfg, engineUUID) +} + +func (local *Backend) getImportClient(ctx context.Context, storeID uint64) (sst.ImportSSTClient, error) { + return local.importClientFactory.Create(ctx, storeID) +} + +func splitRangeBySizeProps(fullRange common.Range, sizeProps *sizeProperties, sizeLimit int64, keysLimit int64) []common.Range { + ranges := make([]common.Range, 0, sizeProps.totalSize/uint64(sizeLimit)) + curSize := uint64(0) + curKeys := uint64(0) + curKey := fullRange.Start + + sizeProps.iter(func(p *rangeProperty) bool { + if bytes.Compare(p.Key, curKey) <= 0 { + return true + } + if bytes.Compare(p.Key, fullRange.End) > 0 { + return false + } + curSize += p.Size + curKeys += p.Keys + if int64(curSize) >= sizeLimit || int64(curKeys) >= keysLimit { + ranges = append(ranges, common.Range{Start: curKey, End: p.Key}) + curKey = p.Key + curSize = 0 + curKeys = 0 + } + return true + }) + + if bytes.Compare(curKey, fullRange.End) < 0 { + // If the remaining range is too small, append it to last range. 
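	// Note that the tail is folded into the previous range only when no keys have
	// been accumulated since the last cut (curKeys == 0); otherwise it is emitted
	// as its own, possibly undersized, final range ending at fullRange.End.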
+ if len(ranges) > 0 && curKeys == 0 { + ranges[len(ranges)-1].End = fullRange.End + } else { + ranges = append(ranges, common.Range{Start: curKey, End: fullRange.End}) + } + } + return ranges +} + +func readAndSplitIntoRange( + ctx context.Context, + engine common.Engine, + sizeLimit int64, + keysLimit int64, +) ([]common.Range, error) { + startKey, endKey, err := engine.GetKeyRange() + if err != nil { + return nil, err + } + if startKey == nil { + return nil, errors.New("could not find first pair") + } + + engineFileTotalSize, engineFileLength := engine.KVStatistics() + + if engineFileTotalSize <= sizeLimit && engineFileLength <= keysLimit { + ranges := []common.Range{{Start: startKey, End: endKey}} + return ranges, nil + } + + logger := log.FromContext(ctx).With(zap.String("engine", engine.ID())) + ranges, err := engine.SplitRanges(startKey, endKey, sizeLimit, keysLimit, logger) + logger.Info("split engine key ranges", + zap.Int64("totalSize", engineFileTotalSize), zap.Int64("totalCount", engineFileLength), + logutil.Key("startKey", startKey), logutil.Key("endKey", endKey), + zap.Int("ranges", len(ranges)), zap.Error(err)) + return ranges, err +} + +// prepareAndSendJob will read the engine to get estimated key range, +// then split and scatter regions for these range and send region jobs to jobToWorkerCh. +// NOTE when ctx is Done, this function will NOT return error even if it hasn't sent +// all the jobs to jobToWorkerCh. This is because the "first error" can only be +// found by checking the work group LATER, we don't want to return an error to +// seize the "first" error. +func (local *Backend) prepareAndSendJob( + ctx context.Context, + engine common.Engine, + initialSplitRanges []common.Range, + regionSplitSize, regionSplitKeys int64, + jobToWorkerCh chan<- *regionJob, + jobWg *sync.WaitGroup, +) error { + lfTotalSize, lfLength := engine.KVStatistics() + log.FromContext(ctx).Info("import engine ranges", zap.Int("count", len(initialSplitRanges))) + if len(initialSplitRanges) == 0 { + return nil + } + + // if all the kv can fit in one region, skip split regions. TiDB will split one region for + // the table when table is created. + needSplit := len(initialSplitRanges) > 1 || lfTotalSize > regionSplitSize || lfLength > regionSplitKeys + // split region by given ranges + failpoint.Inject("failToSplit", func(_ failpoint.Value) { + needSplit = true + }) + if needSplit { + var err error + logger := log.FromContext(ctx).With(zap.String("uuid", engine.ID())).Begin(zap.InfoLevel, "split and scatter ranges") + backOffTime := 10 * time.Second + maxbackoffTime := 120 * time.Second + for i := 0; i < maxRetryTimes; i++ { + failpoint.Inject("skipSplitAndScatter", func() { + failpoint.Break() + }) + + err = local.SplitAndScatterRegionInBatches(ctx, initialSplitRanges, maxBatchSplitRanges) + if err == nil || common.IsContextCanceledError(err) { + break + } + + log.FromContext(ctx).Warn("split and scatter failed in retry", zap.String("engine ID", engine.ID()), + log.ShortError(err), zap.Int("retry", i)) + select { + case <-time.After(backOffTime): + case <-ctx.Done(): + return ctx.Err() + } + backOffTime *= 2 + if backOffTime > maxbackoffTime { + backOffTime = maxbackoffTime + } + } + logger.End(zap.ErrorLevel, err) + if err != nil { + return err + } + } + + return local.generateAndSendJob( + ctx, + engine, + initialSplitRanges, + regionSplitSize, + regionSplitKeys, + jobToWorkerCh, + jobWg, + ) +} + +// generateAndSendJob scans the region in ranges and send region jobs to jobToWorkerCh. 
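// generateAndSendJob relies on unbuffered channels for backpressure: jobToWorkerCh
// is created unbuffered in doImport and the internal dataAndRangeCh below is also
// unbuffered, so job generation is throttled by how quickly the workers drain them.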
+func (local *Backend) generateAndSendJob( + ctx context.Context, + engine common.Engine, + jobRanges []common.Range, + regionSplitSize, regionSplitKeys int64, + jobToWorkerCh chan<- *regionJob, + jobWg *sync.WaitGroup, +) error { + logger := log.FromContext(ctx) + // for external engine, it will split into smaller data inside LoadIngestData + if localEngine, ok := engine.(*Engine); ok { + // when use dynamic region feature, the region may be very big, we need + // to split to smaller ranges to increase the concurrency. + if regionSplitSize > 2*int64(config.SplitRegionSize) { + start := jobRanges[0].Start + end := jobRanges[len(jobRanges)-1].End + sizeLimit := int64(config.SplitRegionSize) + keysLimit := int64(config.SplitRegionKeys) + jrs, err := localEngine.SplitRanges(start, end, sizeLimit, keysLimit, logger) + if err != nil { + return errors.Trace(err) + } + jobRanges = jrs + } + } + + logger.Debug("the ranges length write to tikv", zap.Int("length", len(jobRanges))) + + eg, egCtx := util.NewErrorGroupWithRecoverWithCtx(ctx) + + dataAndRangeCh := make(chan common.DataAndRange) + conn := local.WorkerConcurrency + if _, ok := engine.(*external.Engine); ok { + // currently external engine will generate a large IngestData, se we lower the + // concurrency to pass backpressure to the LoadIngestData goroutine to avoid OOM + conn = 1 + } + for i := 0; i < conn; i++ { + eg.Go(func() error { + for { + select { + case <-egCtx.Done(): + return nil + case p, ok := <-dataAndRangeCh: + if !ok { + return nil + } + + failpoint.Inject("beforeGenerateJob", nil) + failpoint.Inject("sendDummyJob", func(_ failpoint.Value) { + // this is used to trigger worker failure, used together + // with WriteToTiKVNotEnoughDiskSpace + jobToWorkerCh <- ®ionJob{} + time.Sleep(5 * time.Second) + }) + jobs, err := local.generateJobForRange(egCtx, p.Data, p.Range, regionSplitSize, regionSplitKeys) + if err != nil { + if common.IsContextCanceledError(err) { + return nil + } + return err + } + for _, job := range jobs { + job.ref(jobWg) + select { + case <-egCtx.Done(): + // this job is not put into jobToWorkerCh + job.done(jobWg) + // if the context is canceled, it means worker has error, the first error can be + // found by worker's error group LATER. if this function returns an error it will + // seize the "first error". + return nil + case jobToWorkerCh <- job: + } + } + } + } + }) + } + + eg.Go(func() error { + err := engine.LoadIngestData(egCtx, jobRanges, dataAndRangeCh) + if err != nil { + return errors.Trace(err) + } + close(dataAndRangeCh) + return nil + }) + + return eg.Wait() +} + +// fakeRegionJobs is used in test, the injected job can be found by (startKey, endKey). +var fakeRegionJobs map[[2]string]struct { + jobs []*regionJob + err error +} + +// generateJobForRange will scan the region in `keyRange` and generate region jobs. +// It will retry internally when scan region meet error. +func (local *Backend) generateJobForRange( + ctx context.Context, + data common.IngestData, + keyRange common.Range, + regionSplitSize, regionSplitKeys int64, +) ([]*regionJob, error) { + failpoint.Inject("fakeRegionJobs", func() { + if ctx.Err() != nil { + failpoint.Return(nil, ctx.Err()) + } + key := [2]string{string(keyRange.Start), string(keyRange.End)} + injected := fakeRegionJobs[key] + // overwrite the stage to regionScanned, because some time same keyRange + // will be generated more than once. 
+ for _, job := range injected.jobs { + job.stage = regionScanned + } + failpoint.Return(injected.jobs, injected.err) + }) + + start, end := keyRange.Start, keyRange.End + pairStart, pairEnd, err := data.GetFirstAndLastKey(start, end) + if err != nil { + return nil, err + } + if pairStart == nil { + logFn := log.FromContext(ctx).Info + if _, ok := data.(*external.MemoryIngestData); ok { + logFn = log.FromContext(ctx).Warn + } + logFn("There is no pairs in range", + logutil.Key("start", start), + logutil.Key("end", end)) + // trigger cleanup + data.IncRef() + data.DecRef() + return nil, nil + } + + startKey := codec.EncodeBytes([]byte{}, pairStart) + endKey := codec.EncodeBytes([]byte{}, nextKey(pairEnd)) + regions, err := split.PaginateScanRegion(ctx, local.splitCli, startKey, endKey, scanRegionLimit) + if err != nil { + log.FromContext(ctx).Error("scan region failed", + log.ShortError(err), zap.Int("region_len", len(regions)), + logutil.Key("startKey", startKey), + logutil.Key("endKey", endKey)) + return nil, err + } + + jobs := make([]*regionJob, 0, len(regions)) + for _, region := range regions { + log.FromContext(ctx).Debug("get region", + zap.Binary("startKey", startKey), + zap.Binary("endKey", endKey), + zap.Uint64("id", region.Region.GetId()), + zap.Stringer("epoch", region.Region.GetRegionEpoch()), + zap.Binary("start", region.Region.GetStartKey()), + zap.Binary("end", region.Region.GetEndKey()), + zap.Reflect("peers", region.Region.GetPeers())) + + jobs = append(jobs, ®ionJob{ + keyRange: intersectRange(region.Region, common.Range{Start: start, End: end}), + region: region, + stage: regionScanned, + ingestData: data, + regionSplitSize: regionSplitSize, + regionSplitKeys: regionSplitKeys, + metrics: local.metrics, + }) + } + return jobs, nil +} + +// startWorker creates a worker that reads from the job channel and processes. +// startWorker will return nil if it's expected to stop, where the only case is +// the context canceled. It will return not nil error when it actively stops. +// startWorker must Done the jobWg if it does not put the job into jobOutCh. +func (local *Backend) startWorker( + ctx context.Context, + jobInCh, jobOutCh chan *regionJob, + jobWg *sync.WaitGroup, +) error { + metrics.GlobalSortIngestWorkerCnt.WithLabelValues("execute job").Set(0) + for { + select { + case <-ctx.Done(): + return nil + case job, ok := <-jobInCh: + if !ok { + // In fact we don't use close input channel to notify worker to + // exit, because there's a cycle in workflow. + return nil + } + + metrics.GlobalSortIngestWorkerCnt.WithLabelValues("execute job").Inc() + err := local.executeJob(ctx, job) + metrics.GlobalSortIngestWorkerCnt.WithLabelValues("execute job").Dec() + switch job.stage { + case regionScanned, wrote, ingested: + jobOutCh <- job + case needRescan: + jobs, err2 := local.generateJobForRange( + ctx, + job.ingestData, + job.keyRange, + job.regionSplitSize, + job.regionSplitKeys, + ) + if err2 != nil { + // Don't need to put the job back to retry, because generateJobForRange + // has done the retry internally. Here just done for the "needRescan" + // job and exit directly. + job.done(jobWg) + return err2 + } + // 1 "needRescan" job becomes len(jobs) "regionScanned" jobs. 
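				// The original needRescan job already accounts for one reference
				// in jobWg, so only len(jobs)-1 extra references are taken before
				// the rescanned jobs are forwarded downstream.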
+ newJobCnt := len(jobs) - 1 + for newJobCnt > 0 { + job.ref(jobWg) + newJobCnt-- + } + for _, j := range jobs { + j.lastRetryableErr = job.lastRetryableErr + jobOutCh <- j + } + } + + if err != nil { + return err + } + } + } +} + +func (*Backend) isRetryableImportTiKVError(err error) bool { + err = errors.Cause(err) + // io.EOF is not retryable in normal case + // but on TiKV restart, if we're writing to TiKV(through GRPC) + // it might return io.EOF(it's GRPC Unavailable in most case), + // we need to retry on this error. + // see SendMsg in https://pkg.go.dev/google.golang.org/grpc#ClientStream + if err == io.EOF { + return true + } + return common.IsRetryableError(err) +} + +func checkDiskAvail(ctx context.Context, store *pdhttp.StoreInfo) error { + logger := log.FromContext(ctx) + capacity, err := units.RAMInBytes(store.Status.Capacity) + if err != nil { + logger.Warn("failed to parse capacity", + zap.String("capacity", store.Status.Capacity), zap.Error(err)) + return nil + } + if capacity <= 0 { + // PD will return a zero value StoreInfo if heartbeat is not received after + // startup, skip temporarily. + return nil + } + available, err := units.RAMInBytes(store.Status.Available) + if err != nil { + logger.Warn("failed to parse available", + zap.String("available", store.Status.Available), zap.Error(err)) + return nil + } + ratio := available * 100 / capacity + if ratio < 10 { + storeType := "TiKV" + if engine.IsTiFlashHTTPResp(&store.Store) { + storeType = "TiFlash" + } + return errors.Errorf("the remaining storage capacity of %s(%s) is less than 10%%; please increase the storage capacity of %s and try again", + storeType, store.Store.Address, storeType) + } + return nil +} + +// executeJob handles a regionJob and tries to convert it to ingested stage. +// If non-retryable error occurs, it will return the error. +// If retryable error occurs, it will return nil and caller should check the stage +// of the regionJob to determine what to do with it. +func (local *Backend) executeJob( + ctx context.Context, + job *regionJob, +) error { + failpoint.Inject("WriteToTiKVNotEnoughDiskSpace", func(_ failpoint.Value) { + failpoint.Return( + errors.New("the remaining storage capacity of TiKV is less than 10%%; please increase the storage capacity of TiKV and try again")) + }) + if local.ShouldCheckTiKV { + for _, peer := range job.region.Region.GetPeers() { + store, err := local.pdHTTPCli.GetStore(ctx, peer.StoreId) + if err != nil { + log.FromContext(ctx).Warn("failed to get StoreInfo from pd http api", zap.Error(err)) + continue + } + err = checkDiskAvail(ctx, store) + if err != nil { + return err + } + } + } + + for { + err := local.writeToTiKV(ctx, job) + if err != nil { + if !local.isRetryableImportTiKVError(err) { + return err + } + // if it's retryable error, we retry from scanning region + log.FromContext(ctx).Warn("meet retryable error when writing to TiKV", + log.ShortError(err), zap.Stringer("job stage", job.stage)) + job.lastRetryableErr = err + return nil + } + + err = local.ingest(ctx, job) + if err != nil { + if !local.isRetryableImportTiKVError(err) { + return err + } + log.FromContext(ctx).Warn("meet retryable error when ingesting", + log.ShortError(err), zap.Stringer("job stage", job.stage)) + job.lastRetryableErr = err + return nil + } + // if the job.stage successfully converted into "ingested", it means + // these data are ingested into TiKV so we handle remaining data. + // For other job.stage, the job should be sent back to caller to retry + // later. 
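		// A single pass may cover only part of keyRange: when the region size or
		// key-count threshold is hit, writeResult.remainingStartKey is set and the
		// loop below advances the job's start key and goes around again.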
+ if job.stage != ingested { + return nil + } + + if job.writeResult == nil || job.writeResult.remainingStartKey == nil { + return nil + } + job.keyRange.Start = job.writeResult.remainingStartKey + job.convertStageTo(regionScanned) + } +} + +// ImportEngine imports an engine to TiKV. +func (local *Backend) ImportEngine( + ctx context.Context, + engineUUID uuid.UUID, + regionSplitSize, regionSplitKeys int64, +) error { + var e common.Engine + if externalEngine, ok := local.engineMgr.getExternalEngine(engineUUID); ok { + e = externalEngine + } else { + localEngine := local.engineMgr.lockEngine(engineUUID, importMutexStateImport) + if localEngine == nil { + // skip if engine not exist. See the comment of `CloseEngine` for more detail. + return nil + } + defer localEngine.unlock() + e = localEngine + } + + lfTotalSize, lfLength := e.KVStatistics() + if lfTotalSize == 0 { + // engine is empty, this is likes because it's a index engine but the table contains no index + log.FromContext(ctx).Info("engine contains no kv, skip import", zap.Stringer("engine", engineUUID)) + return nil + } + kvRegionSplitSize, kvRegionSplitKeys, err := GetRegionSplitSizeKeys(ctx, local.pdCli, local.tls) + if err == nil { + if kvRegionSplitSize > regionSplitSize { + regionSplitSize = kvRegionSplitSize + } + if kvRegionSplitKeys > regionSplitKeys { + regionSplitKeys = kvRegionSplitKeys + } + } else { + log.FromContext(ctx).Warn("fail to get region split keys and size", zap.Error(err)) + } + + // split sorted file into range about regionSplitSize per file + regionRanges, err := readAndSplitIntoRange(ctx, e, regionSplitSize, regionSplitKeys) + if err != nil { + return err + } + + if len(regionRanges) > 0 && local.PausePDSchedulerScope == config.PausePDSchedulerScopeTable { + log.FromContext(ctx).Info("pause pd scheduler of table scope") + subCtx, cancel := context.WithCancel(ctx) + defer cancel() + + var startKey, endKey []byte + if len(regionRanges[0].Start) > 0 { + startKey = codec.EncodeBytes(nil, regionRanges[0].Start) + } + if len(regionRanges[len(regionRanges)-1].End) > 0 { + endKey = codec.EncodeBytes(nil, regionRanges[len(regionRanges)-1].End) + } + done, err := pdutil.PauseSchedulersByKeyRange(subCtx, local.pdHTTPCli, startKey, endKey) + if err != nil { + return errors.Trace(err) + } + defer func() { + cancel() + <-done + }() + } + + if len(regionRanges) > 0 && local.BackendConfig.RaftKV2SwitchModeDuration > 0 { + log.FromContext(ctx).Info("switch import mode of ranges", + zap.String("startKey", hex.EncodeToString(regionRanges[0].Start)), + zap.String("endKey", hex.EncodeToString(regionRanges[len(regionRanges)-1].End))) + subCtx, cancel := context.WithCancel(ctx) + defer cancel() + + done, err := local.SwitchModeByKeyRanges(subCtx, regionRanges) + if err != nil { + return errors.Trace(err) + } + defer func() { + cancel() + <-done + }() + } + + log.FromContext(ctx).Info("start import engine", + zap.Stringer("uuid", engineUUID), + zap.Int("region ranges", len(regionRanges)), + zap.Int64("count", lfLength), + zap.Int64("size", lfTotalSize)) + + failpoint.Inject("ReadyForImportEngine", func() {}) + + err = local.doImport(ctx, e, regionRanges, regionSplitSize, regionSplitKeys) + if err == nil { + importedSize, importedLength := e.ImportedStatistics() + log.FromContext(ctx).Info("import engine success", + zap.Stringer("uuid", engineUUID), + zap.Int64("size", lfTotalSize), + zap.Int64("kvs", lfLength), + zap.Int64("importedSize", importedSize), + zap.Int64("importedCount", importedLength)) + } + return err +} + +// 
expose these variables to unit test. +var ( + testJobToWorkerCh = make(chan *regionJob) + testJobWg *sync.WaitGroup +) + +func (local *Backend) doImport(ctx context.Context, engine common.Engine, regionRanges []common.Range, regionSplitSize, regionSplitKeys int64) error { + /* + [prepareAndSendJob]-----jobToWorkerCh--->[workers] + ^ | + | jobFromWorkerCh + | | + | v + [regionJobRetryer]<--[dispatchJobGoroutine]-->done + */ + + var ( + ctx2, workerCancel = context.WithCancel(ctx) + // workerCtx.Done() means workflow is canceled by error. It may be caused + // by calling workerCancel() or workers in workGroup meets error. + workGroup, workerCtx = util.NewErrorGroupWithRecoverWithCtx(ctx2) + firstErr common.OnceError + // jobToWorkerCh and jobFromWorkerCh are unbuffered so jobs will not be + // owned by them. + jobToWorkerCh = make(chan *regionJob) + jobFromWorkerCh = make(chan *regionJob) + // jobWg tracks the number of jobs in this workflow. + // prepareAndSendJob, workers and regionJobRetryer can own jobs. + // When cancel on error, the goroutine of above three components have + // responsibility to Done jobWg of their owning jobs. + jobWg sync.WaitGroup + dispatchJobGoroutine = make(chan struct{}) + ) + defer workerCancel() + + failpoint.Inject("injectVariables", func() { + jobToWorkerCh = testJobToWorkerCh + testJobWg = &jobWg + }) + + retryer := startRegionJobRetryer(workerCtx, jobToWorkerCh, &jobWg) + + // dispatchJobGoroutine handles processed job from worker, it will only exit + // when jobFromWorkerCh is closed to avoid worker is blocked on sending to + // jobFromWorkerCh. + defer func() { + // use defer to close jobFromWorkerCh after all workers are exited + close(jobFromWorkerCh) + <-dispatchJobGoroutine + }() + go func() { + defer close(dispatchJobGoroutine) + for { + job, ok := <-jobFromWorkerCh + if !ok { + return + } + switch job.stage { + case regionScanned, wrote: + job.retryCount++ + if job.retryCount > maxWriteAndIngestRetryTimes { + firstErr.Set(job.lastRetryableErr) + workerCancel() + job.done(&jobWg) + continue + } + // max retry backoff time: 2+4+8+16+30*26=810s + sleepSecond := math.Pow(2, float64(job.retryCount)) + if sleepSecond > float64(maxRetryBackoffSecond) { + sleepSecond = float64(maxRetryBackoffSecond) + } + job.waitUntil = time.Now().Add(time.Second * time.Duration(sleepSecond)) + log.FromContext(ctx).Info("put job back to jobCh to retry later", + logutil.Key("startKey", job.keyRange.Start), + logutil.Key("endKey", job.keyRange.End), + zap.Stringer("stage", job.stage), + zap.Int("retryCount", job.retryCount), + zap.Time("waitUntil", job.waitUntil)) + if !retryer.push(job) { + // retryer is closed by worker error + job.done(&jobWg) + } + case ingested: + job.done(&jobWg) + case needRescan: + panic("should not reach here") + } + } + }() + + failpoint.Inject("skipStartWorker", func() { + failpoint.Goto("afterStartWorker") + }) + + for i := 0; i < local.WorkerConcurrency; i++ { + workGroup.Go(func() error { + return local.startWorker(workerCtx, jobToWorkerCh, jobFromWorkerCh, &jobWg) + }) + } + + failpoint.Label("afterStartWorker") + + workGroup.Go(func() error { + err := local.prepareAndSendJob( + workerCtx, + engine, + regionRanges, + regionSplitSize, + regionSplitKeys, + jobToWorkerCh, + &jobWg, + ) + if err != nil { + return err + } + + jobWg.Wait() + workerCancel() + return nil + }) + if err := workGroup.Wait(); err != nil { + if !common.IsContextCanceledError(err) { + log.FromContext(ctx).Error("do import meets error", zap.Error(err)) + } + 
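		// firstErr is a common.OnceError; if the dispatch goroutine already recorded
		// a retry-exhaustion error, that earlier error is the one doImport returns.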
firstErr.Set(err) + } + return firstErr.Get() +} + +// GetImportedKVCount returns the number of imported KV pairs of some engine. +func (local *Backend) GetImportedKVCount(engineUUID uuid.UUID) int64 { + return local.engineMgr.getImportedKVCount(engineUUID) +} + +// GetExternalEngineKVStatistics returns kv statistics of some engine. +func (local *Backend) GetExternalEngineKVStatistics(engineUUID uuid.UUID) ( + totalKVSize int64, totalKVCount int64) { + return local.engineMgr.getExternalEngineKVStatistics(engineUUID) +} + +// ResetEngine reset the engine and reclaim the space. +func (local *Backend) ResetEngine(ctx context.Context, engineUUID uuid.UUID) error { + return local.engineMgr.resetEngine(ctx, engineUUID, false) +} + +// ResetEngineSkipAllocTS is like ResetEngine but the inner TS of the engine is +// invalid. Caller must use OpenedEngine.SetTS to set a valid TS before import +// the engine. +func (local *Backend) ResetEngineSkipAllocTS(ctx context.Context, engineUUID uuid.UUID) error { + return local.engineMgr.resetEngine(ctx, engineUUID, true) +} + +// CleanupEngine cleanup the engine and reclaim the space. +func (local *Backend) CleanupEngine(ctx context.Context, engineUUID uuid.UUID) error { + return local.engineMgr.cleanupEngine(ctx, engineUUID) +} + +// GetDupeController returns a new dupe controller. +func (local *Backend) GetDupeController(dupeConcurrency int, errorMgr *errormanager.ErrorManager) *DupeController { + return &DupeController{ + splitCli: local.splitCli, + tikvCli: local.tikvCli, + tikvCodec: local.tikvCodec, + errorMgr: errorMgr, + dupeConcurrency: dupeConcurrency, + duplicateDB: local.engineMgr.getDuplicateDB(), + keyAdapter: local.engineMgr.getKeyAdapter(), + importClientFactory: local.importClientFactory, + resourceGroupName: local.ResourceGroupName, + taskType: local.TaskType, + } +} + +// UnsafeImportAndReset forces the backend to import the content of an engine +// into the target and then reset the engine to empty. This method will not +// close the engine. Make sure the engine is flushed manually before calling +// this method. +func (local *Backend) UnsafeImportAndReset(ctx context.Context, engineUUID uuid.UUID, regionSplitSize, regionSplitKeys int64) error { + // DO NOT call be.abstract.CloseEngine()! The engine should still be writable after + // calling UnsafeImportAndReset(). + logger := log.FromContext(ctx).With( + zap.String("engineTag", ""), + zap.Stringer("engineUUID", engineUUID), + ) + closedEngine := backend.NewClosedEngine(local, logger, engineUUID, 0) + if err := closedEngine.Import(ctx, regionSplitSize, regionSplitKeys); err != nil { + return err + } + return local.engineMgr.resetEngine(ctx, engineUUID, false) +} + +func engineSSTDir(storeDir string, engineUUID uuid.UUID) string { + return filepath.Join(storeDir, engineUUID.String()+".sst") +} + +// LocalWriter returns a new local writer. +func (local *Backend) LocalWriter(ctx context.Context, cfg *backend.LocalWriterConfig, engineUUID uuid.UUID) (backend.EngineWriter, error) { + return local.engineMgr.localWriter(ctx, cfg, engineUUID) +} + +// SwitchModeByKeyRanges will switch tikv mode for regions in the specific key range for multirocksdb. +// This function will spawn a goroutine to keep switch mode periodically until the context is done. +// The return done channel is used to notify the caller that the background goroutine is exited. 
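// The goroutine below switches the covered key ranges to import mode immediately,
// refreshes that state every RaftKV2SwitchModeDuration, and on cancellation restores
// normal mode using a fresh 5-second context so the recovery is not cut short.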
+func (local *Backend) SwitchModeByKeyRanges(ctx context.Context, ranges []common.Range) (<-chan struct{}, error) { + switcher := NewTiKVModeSwitcher(local.tls.TLSConfig(), local.pdHTTPCli, log.FromContext(ctx).Logger) + done := make(chan struct{}) + + keyRanges := make([]*sst.Range, 0, len(ranges)) + for _, r := range ranges { + startKey := r.Start + if len(r.Start) > 0 { + startKey = codec.EncodeBytes(nil, r.Start) + } + endKey := r.End + if len(r.End) > 0 { + endKey = codec.EncodeBytes(nil, r.End) + } + keyRanges = append(keyRanges, &sst.Range{ + Start: startKey, + End: endKey, + }) + } + + go func() { + defer close(done) + ticker := time.NewTicker(local.BackendConfig.RaftKV2SwitchModeDuration) + defer ticker.Stop() + switcher.ToImportMode(ctx, keyRanges...) + loop: + for { + select { + case <-ctx.Done(): + break loop + case <-ticker.C: + switcher.ToImportMode(ctx, keyRanges...) + } + } + // Use a new context to avoid the context is canceled by the caller. + recoverCtx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + switcher.ToNormalMode(recoverCtx, keyRanges...) + }() + return done, nil +} + +func openLocalWriter(cfg *backend.LocalWriterConfig, engine *Engine, tikvCodec tikvclient.Codec, cacheSize int64, kvBuffer *membuf.Buffer) (*Writer, error) { + // pre-allocate a long enough buffer to avoid a lot of runtime.growslice + // this can help save about 3% of CPU. + var preAllocWriteBatch []common.KvPair + if !cfg.Local.IsKVSorted { + preAllocWriteBatch = make([]common.KvPair, units.MiB) + // we want to keep the cacheSize as the whole limit of this local writer, but the + // main memory usage comes from two member: kvBuffer and writeBatch, so we split + // ~10% to writeBatch for !IsKVSorted, which means we estimate the average length + // of KV pairs are 9 times than the size of common.KvPair (9*72B = 648B). + cacheSize = cacheSize * 9 / 10 + } + w := &Writer{ + engine: engine, + memtableSizeLimit: cacheSize, + kvBuffer: kvBuffer, + isKVSorted: cfg.Local.IsKVSorted, + isWriteBatchSorted: true, + tikvCodec: tikvCodec, + writeBatch: preAllocWriteBatch, + } + engine.localWriters.Store(w, nil) + return w, nil +} + +// return the smallest []byte that is bigger than current bytes. +// special case when key is empty, empty bytes means infinity in our context, so directly return itself. +func nextKey(key []byte) []byte { + if len(key) == 0 { + return []byte{} + } + + // in tikv <= 4.x, tikv will truncate the row key, so we should fetch the next valid row key + // See: https://github.com/tikv/tikv/blob/f7f22f70e1585d7ca38a59ea30e774949160c3e8/components/raftstore/src/coprocessor/split_observer.rs#L36-L41 + // we only do this for IntHandle, which is checked by length + if tablecodec.IsRecordKey(key) && len(key) == tablecodec.RecordRowKeyLen { + tableID, handle, _ := tablecodec.DecodeRecordKey(key) + nextHandle := handle.Next() + // int handle overflow, use the next table prefix as nextKey + if nextHandle.Compare(handle) <= 0 { + return tablecodec.EncodeTablePrefix(tableID + 1) + } + return tablecodec.EncodeRowKeyWithHandle(tableID, nextHandle) + } + + // for index key and CommonHandle, directly append a 0x00 to the key. + res := make([]byte, 0, len(key)+1) + res = append(res, key...) + res = append(res, 0) + return res +} + +// EngineFileSizes implements DiskUsage interface. +func (local *Backend) EngineFileSizes() (res []backend.EngineFileSize) { + return local.engineMgr.engineFileSizes() +} + +// GetTS implements StoreHelper interface. 
+func (local *Backend) GetTS(ctx context.Context) (physical, logical int64, err error) { + return local.pdCli.GetTS(ctx) +} + +// GetTiKVCodec implements StoreHelper interface. +func (local *Backend) GetTiKVCodec() tikvclient.Codec { + return local.tikvCodec +} + +// CloseEngineMgr close the engine manager. +// This function is used for test. +func (local *Backend) CloseEngineMgr() { + local.engineMgr.close() +} + +var getSplitConfFromStoreFunc = getSplitConfFromStore + +// return region split size, region split keys, error +func getSplitConfFromStore(ctx context.Context, host string, tls *common.TLS) ( + splitSize int64, regionSplitKeys int64, err error) { + var ( + nested struct { + Coprocessor struct { + RegionSplitSize string `json:"region-split-size"` + RegionSplitKeys int64 `json:"region-split-keys"` + } `json:"coprocessor"` + } + ) + if err := tls.WithHost(host).GetJSON(ctx, "/config", &nested); err != nil { + return 0, 0, errors.Trace(err) + } + splitSize, err = units.FromHumanSize(nested.Coprocessor.RegionSplitSize) + if err != nil { + return 0, 0, errors.Trace(err) + } + + return splitSize, nested.Coprocessor.RegionSplitKeys, nil +} + +// GetRegionSplitSizeKeys return region split size, region split keys, error +func GetRegionSplitSizeKeys(ctx context.Context, cli pd.Client, tls *common.TLS) ( + regionSplitSize int64, regionSplitKeys int64, err error) { + stores, err := cli.GetAllStores(ctx, pd.WithExcludeTombstone()) + if err != nil { + return 0, 0, err + } + for _, store := range stores { + if store.StatusAddress == "" || engine.IsTiFlash(store) { + continue + } + serverInfo := infoschema.ServerInfo{ + Address: store.Address, + StatusAddr: store.StatusAddress, + } + serverInfo.ResolveLoopBackAddr() + regionSplitSize, regionSplitKeys, err := getSplitConfFromStoreFunc(ctx, serverInfo.StatusAddr, tls) + if err == nil { + return regionSplitSize, regionSplitKeys, nil + } + log.FromContext(ctx).Warn("get region split size and keys failed", zap.Error(err), zap.String("store", serverInfo.StatusAddr)) + } + return 0, 0, errors.New("get region split size and keys failed") +} diff --git a/pkg/lightning/backend/local/local_unix.go b/pkg/lightning/backend/local/local_unix.go index 20695c6ebd6ab..ec213e3664581 100644 --- a/pkg/lightning/backend/local/local_unix.go +++ b/pkg/lightning/backend/local/local_unix.go @@ -45,12 +45,12 @@ func VerifyRLimit(estimateMaxFiles RlimT) error { } var rLimit syscall.Rlimit err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit) - failpoint.Inject("GetRlimitValue", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("GetRlimitValue")); _err_ == nil { limit := RlimT(v.(int)) rLimit.Cur = limit rLimit.Max = limit err = nil - }) + } if err != nil { return errors.Trace(err) } @@ -63,11 +63,11 @@ func VerifyRLimit(estimateMaxFiles RlimT) error { } prevLimit := rLimit.Cur rLimit.Cur = estimateMaxFiles - failpoint.Inject("SetRlimitError", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("SetRlimitError")); _err_ == nil { if v.(bool) { err = errors.New("Setrlimit Injected Error") } - }) + } if err == nil { err = syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rLimit) } diff --git a/pkg/lightning/backend/local/local_unix.go__failpoint_stash__ b/pkg/lightning/backend/local/local_unix.go__failpoint_stash__ new file mode 100644 index 0000000000000..20695c6ebd6ab --- /dev/null +++ b/pkg/lightning/backend/local/local_unix.go__failpoint_stash__ @@ -0,0 +1,92 @@ +// Copyright 2020 PingCAP, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !windows + +package local + +import ( + "syscall" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/lightning/log" +) + +const ( + // maximum max open files value + maxRLimit = 1000000 +) + +// GetSystemRLimit returns the current open-file limit. +func GetSystemRLimit() (RlimT, error) { + var rLimit syscall.Rlimit + err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit) + return rLimit.Cur, err +} + +// VerifyRLimit checks whether the open-file limit is large enough. +// In Local-backend, we need to read and write a lot of L0 SST files, so we need +// to check system max open files limit. +func VerifyRLimit(estimateMaxFiles RlimT) error { + if estimateMaxFiles > maxRLimit { + estimateMaxFiles = maxRLimit + } + var rLimit syscall.Rlimit + err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit) + failpoint.Inject("GetRlimitValue", func(v failpoint.Value) { + limit := RlimT(v.(int)) + rLimit.Cur = limit + rLimit.Max = limit + err = nil + }) + if err != nil { + return errors.Trace(err) + } + if rLimit.Cur >= estimateMaxFiles { + return nil + } + if rLimit.Max < estimateMaxFiles { + // If the process is not started by privileged user, this will fail. + rLimit.Max = estimateMaxFiles + } + prevLimit := rLimit.Cur + rLimit.Cur = estimateMaxFiles + failpoint.Inject("SetRlimitError", func(v failpoint.Value) { + if v.(bool) { + err = errors.New("Setrlimit Injected Error") + } + }) + if err == nil { + err = syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rLimit) + } + if err != nil { + return errors.Annotatef(err, "the maximum number of open file descriptors is too small, got %d, expect greater or equal to %d", prevLimit, estimateMaxFiles) + } + + // fetch the rlimit again to make sure our setting has taken effect + err = syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit) + if err != nil { + return errors.Trace(err) + } + if rLimit.Cur < estimateMaxFiles { + helper := "Please manually execute `ulimit -n %d` to increase the open files limit." + return errors.Errorf("cannot update the maximum number of open file descriptors, expected: %d, got: %d. 
%s", + estimateMaxFiles, rLimit.Cur, helper) + } + + log.L().Info("Set the maximum number of open file descriptors(rlimit)", + zapRlimT("old", prevLimit), zapRlimT("new", estimateMaxFiles)) + return nil +} diff --git a/pkg/lightning/backend/local/region_job.go b/pkg/lightning/backend/local/region_job.go index 7bc812e4b9bb6..d18c33970a4ec 100644 --- a/pkg/lightning/backend/local/region_job.go +++ b/pkg/lightning/backend/local/region_job.go @@ -208,7 +208,7 @@ func (local *Backend) doWrite(ctx context.Context, j *regionJob) error { return nil } - failpoint.Inject("fakeRegionJobs", func() { + if _, _err_ := failpoint.Eval(_curpkg_("fakeRegionJobs")); _err_ == nil { front := j.injected[0] j.injected = j.injected[1:] j.writeResult = front.write.result @@ -216,8 +216,8 @@ func (local *Backend) doWrite(ctx context.Context, j *regionJob) error { if err == nil { j.convertStageTo(wrote) } - failpoint.Return(err) - }) + return err + } var cancel context.CancelFunc ctx, cancel = context.WithTimeoutCause(ctx, 15*time.Minute, common.ErrWriteTooSlow) @@ -261,7 +261,7 @@ func (local *Backend) doWrite(ctx context.Context, j *regionJob) error { ApiVersion: apiVersion, } - failpoint.Inject("changeEpochVersion", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("changeEpochVersion")); _err_ == nil { cloned := *meta.RegionEpoch meta.RegionEpoch = &cloned i := val.(int) @@ -270,7 +270,7 @@ func (local *Backend) doWrite(ctx context.Context, j *regionJob) error { } else { meta.RegionEpoch.ConfVer -= uint64(-i) } - }) + } annotateErr := func(in error, peer *metapb.Peer, msg string) error { // annotate the error with peer/store/region info to help debug. @@ -307,10 +307,10 @@ func (local *Backend) doWrite(ctx context.Context, j *regionJob) error { return annotateErr(err, peer, "when open write stream") } - failpoint.Inject("mockWritePeerErr", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockWritePeerErr")); _err_ == nil { err = errors.Errorf("mock write peer error") - failpoint.Return(annotateErr(err, peer, "when open write stream")) - }) + return annotateErr(err, peer, "when open write stream") + } // Bind uuid for this write request if err = wstream.Send(req); err != nil { @@ -360,9 +360,9 @@ func (local *Backend) doWrite(ctx context.Context, j *regionJob) error { return annotateErr(err, allPeers[i], "when send data") } } - failpoint.Inject("afterFlushKVs", func() { + if _, _err_ := failpoint.Eval(_curpkg_("afterFlushKVs")); _err_ == nil { log.FromContext(ctx).Info(fmt.Sprintf("afterFlushKVs count=%d,size=%d", count, size)) - }) + } return nil } @@ -444,10 +444,10 @@ func (local *Backend) doWrite(ctx context.Context, j *regionJob) error { } } - failpoint.Inject("NoLeader", func() { + if _, _err_ := failpoint.Eval(_curpkg_("NoLeader")); _err_ == nil { log.FromContext(ctx).Warn("enter failpoint NoLeader") leaderPeerMetas = nil - }) + } // if there is not leader currently, we don't forward the stage to wrote and let caller // handle the retry. 
@@ -488,12 +488,12 @@ func (local *Backend) ingest(ctx context.Context, j *regionJob) (err error) { return nil } - failpoint.Inject("fakeRegionJobs", func() { + if _, _err_ := failpoint.Eval(_curpkg_("fakeRegionJobs")); _err_ == nil { front := j.injected[0] j.injected = j.injected[1:] j.convertStageTo(front.ingest.nextStage) - failpoint.Return(front.ingest.err) - }) + return front.ingest.err + } if len(j.writeResult.sstMeta) == 0 { j.convertStageTo(ingested) @@ -597,7 +597,7 @@ func (local *Backend) doIngest(ctx context.Context, j *regionJob) (*sst.IngestRe log.FromContext(ctx).Debug("ingest meta", zap.Reflect("meta", ingestMetas)) - failpoint.Inject("FailIngestMeta", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("FailIngestMeta")); _err_ == nil { // only inject the error once var resp *sst.IngestResponse @@ -620,8 +620,8 @@ func (local *Backend) doIngest(ctx context.Context, j *regionJob) (*sst.IngestRe }, } } - failpoint.Return(resp, nil) - }) + return resp, nil + } leader := j.region.Leader if leader == nil { diff --git a/pkg/lightning/backend/local/region_job.go__failpoint_stash__ b/pkg/lightning/backend/local/region_job.go__failpoint_stash__ new file mode 100644 index 0000000000000..7bc812e4b9bb6 --- /dev/null +++ b/pkg/lightning/backend/local/region_job.go__failpoint_stash__ @@ -0,0 +1,907 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package local + +import ( + "container/heap" + "context" + "fmt" + "io" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/errorpb" + sst "github.com/pingcap/kvproto/pkg/import_sstpb" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/kvproto/pkg/metapb" + berrors "github.com/pingcap/tidb/br/pkg/errors" + "github.com/pingcap/tidb/br/pkg/logutil" + "github.com/pingcap/tidb/br/pkg/restore/split" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/lightning/metric" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/tikv/client-go/v2/util" + "go.uber.org/zap" + "google.golang.org/grpc" +) + +type jobStageTp string + +/* + + + v + +------+------+ + +->+regionScanned+<------+ + | +------+------+ | + | | | + | | | + | v | + | +--+--+ +-----+----+ + | |wrote+---->+needRescan| + | +--+--+ +-----+----+ + | | ^ + | | | + | v | + | +---+----+ | + +-----+ingested+---------+ + +---+----+ + | + v + +above diagram shows the state transition of a region job, here are some special +cases: + - regionScanned can directly jump to ingested if the keyRange has no data + - regionScanned can only transit to wrote. 
TODO: check if it should be transited + to needRescan + - if a job only partially writes the data, after it becomes ingested, it will + update its keyRange and transits to regionScanned to continue the remaining + data + - needRescan may output multiple regionScanned jobs when the old region is split +*/ +const ( + regionScanned jobStageTp = "regionScanned" + wrote jobStageTp = "wrote" + ingested jobStageTp = "ingested" + needRescan jobStageTp = "needRescan" + + // suppose each KV is about 32 bytes, 16 * units.KiB / 32 = 512 + defaultKVBatchCount = 512 +) + +func (j jobStageTp) String() string { + return string(j) +} + +// regionJob is dedicated to import the data in [keyRange.start, keyRange.end) +// to a region. The keyRange may be changed when processing because of writing +// partial data to TiKV or region split. +type regionJob struct { + keyRange common.Range + // TODO: check the keyRange so that it's always included in region + region *split.RegionInfo + // stage should be updated only by convertStageTo + stage jobStageTp + // writeResult is available only in wrote and ingested stage + writeResult *tikvWriteResult + + ingestData common.IngestData + regionSplitSize int64 + regionSplitKeys int64 + metrics *metric.Common + + retryCount int + waitUntil time.Time + lastRetryableErr error + + // injected is used in test to set the behaviour + injected []injectedBehaviour +} + +type tikvWriteResult struct { + sstMeta []*sst.SSTMeta + count int64 + totalBytes int64 + remainingStartKey []byte +} + +type injectedBehaviour struct { + write injectedWriteBehaviour + ingest injectedIngestBehaviour +} + +type injectedWriteBehaviour struct { + result *tikvWriteResult + err error +} + +type injectedIngestBehaviour struct { + nextStage jobStageTp + err error +} + +func (j *regionJob) convertStageTo(stage jobStageTp) { + j.stage = stage + switch stage { + case regionScanned: + j.writeResult = nil + case ingested: + // when writing is skipped because key range is empty + if j.writeResult == nil { + return + } + + j.ingestData.Finish(j.writeResult.totalBytes, j.writeResult.count) + if j.metrics != nil { + j.metrics.BytesCounter.WithLabelValues(metric.StateImported). + Add(float64(j.writeResult.totalBytes)) + } + case needRescan: + j.region = nil + } +} + +// ref means that the ingestData of job will be accessed soon. +func (j *regionJob) ref(wg *sync.WaitGroup) { + if wg != nil { + wg.Add(1) + } + if j.ingestData != nil { + j.ingestData.IncRef() + } +} + +// done promises that the ingestData of job will not be accessed. Same amount of +// done should be called to release the ingestData. +func (j *regionJob) done(wg *sync.WaitGroup) { + if j.ingestData != nil { + j.ingestData.DecRef() + } + if wg != nil { + wg.Done() + } +} + +// writeToTiKV writes the data to TiKV and mark this job as wrote stage. +// if any write logic has error, writeToTiKV will set job to a proper stage and return nil. +// if any underlying logic has error, writeToTiKV will return an error. +// we don't need to do cleanup for the pairs written to tikv if encounters an error, +// tikv will take the responsibility to do so. +// TODO: let client-go provide a high-level write interface. 
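// writeToTiKV classifies write failures: a retryable "RequestTooNew" error sends the
// job back to regionScanned so the same region is written again, while any other
// retryable error downgrades it to needRescan so the region is located afresh.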
+func (local *Backend) writeToTiKV(ctx context.Context, j *regionJob) error { + err := local.doWrite(ctx, j) + if err == nil { + return nil + } + if !common.IsRetryableError(err) { + return err + } + // currently only one case will restart write + if strings.Contains(err.Error(), "RequestTooNew") { + j.convertStageTo(regionScanned) + return err + } + j.convertStageTo(needRescan) + return err +} + +func (local *Backend) doWrite(ctx context.Context, j *regionJob) error { + if j.stage != regionScanned { + return nil + } + + failpoint.Inject("fakeRegionJobs", func() { + front := j.injected[0] + j.injected = j.injected[1:] + j.writeResult = front.write.result + err := front.write.err + if err == nil { + j.convertStageTo(wrote) + } + failpoint.Return(err) + }) + + var cancel context.CancelFunc + ctx, cancel = context.WithTimeoutCause(ctx, 15*time.Minute, common.ErrWriteTooSlow) + defer cancel() + + apiVersion := local.tikvCodec.GetAPIVersion() + clientFactory := local.importClientFactory + kvBatchSize := local.KVWriteBatchSize + bufferPool := local.engineMgr.getBufferPool() + writeLimiter := local.writeLimiter + + begin := time.Now() + region := j.region.Region + + firstKey, lastKey, err := j.ingestData.GetFirstAndLastKey(j.keyRange.Start, j.keyRange.End) + if err != nil { + return errors.Trace(err) + } + if firstKey == nil { + j.convertStageTo(ingested) + log.FromContext(ctx).Debug("keys within region is empty, skip doIngest", + logutil.Key("start", j.keyRange.Start), + logutil.Key("regionStart", region.StartKey), + logutil.Key("end", j.keyRange.End), + logutil.Key("regionEnd", region.EndKey)) + return nil + } + + firstKey = codec.EncodeBytes([]byte{}, firstKey) + lastKey = codec.EncodeBytes([]byte{}, lastKey) + + u := uuid.New() + meta := &sst.SSTMeta{ + Uuid: u[:], + RegionId: region.GetId(), + RegionEpoch: region.GetRegionEpoch(), + Range: &sst.Range{ + Start: firstKey, + End: lastKey, + }, + ApiVersion: apiVersion, + } + + failpoint.Inject("changeEpochVersion", func(val failpoint.Value) { + cloned := *meta.RegionEpoch + meta.RegionEpoch = &cloned + i := val.(int) + if i >= 0 { + meta.RegionEpoch.Version += uint64(i) + } else { + meta.RegionEpoch.ConfVer -= uint64(-i) + } + }) + + annotateErr := func(in error, peer *metapb.Peer, msg string) error { + // annotate the error with peer/store/region info to help debug. 
+ return errors.Annotatef( + in, + "peer %d, store %d, region %d, epoch %s, %s", + peer.Id, peer.StoreId, region.Id, region.RegionEpoch.String(), + msg, + ) + } + + leaderID := j.region.Leader.GetId() + clients := make([]sst.ImportSST_WriteClient, 0, len(region.GetPeers())) + allPeers := make([]*metapb.Peer, 0, len(region.GetPeers())) + req := &sst.WriteRequest{ + Chunk: &sst.WriteRequest_Meta{ + Meta: meta, + }, + Context: &kvrpcpb.Context{ + ResourceControlContext: &kvrpcpb.ResourceControlContext{ + ResourceGroupName: local.ResourceGroupName, + }, + RequestSource: util.BuildRequestSource(true, kv.InternalTxnLightning, local.TaskType), + }, + } + for _, peer := range region.GetPeers() { + cli, err := clientFactory.Create(ctx, peer.StoreId) + if err != nil { + return annotateErr(err, peer, "when create client") + } + + wstream, err := cli.Write(ctx) + if err != nil { + return annotateErr(err, peer, "when open write stream") + } + + failpoint.Inject("mockWritePeerErr", func() { + err = errors.Errorf("mock write peer error") + failpoint.Return(annotateErr(err, peer, "when open write stream")) + }) + + // Bind uuid for this write request + if err = wstream.Send(req); err != nil { + return annotateErr(err, peer, "when send meta") + } + clients = append(clients, wstream) + allPeers = append(allPeers, peer) + } + dataCommitTS := j.ingestData.GetTS() + req.Chunk = &sst.WriteRequest_Batch{ + Batch: &sst.WriteBatch{ + CommitTs: dataCommitTS, + }, + } + + pairs := make([]*sst.Pair, 0, defaultKVBatchCount) + count := 0 + size := int64(0) + totalSize := int64(0) + totalCount := int64(0) + // if region-split-size <= 96MiB, we bump the threshold a bit to avoid too many retry split + // because the range-properties is not 100% accurate + regionMaxSize := j.regionSplitSize + if j.regionSplitSize <= int64(config.SplitRegionSize) { + regionMaxSize = j.regionSplitSize * 4 / 3 + } + + flushKVs := func() error { + req.Chunk.(*sst.WriteRequest_Batch).Batch.Pairs = pairs[:count] + preparedMsg := &grpc.PreparedMsg{} + // by reading the source code, Encode need to find codec and compression from the stream + // because all stream has the same codec and compression, we can use any one of them + if err := preparedMsg.Encode(clients[0], req); err != nil { + return err + } + + for i := range clients { + if err := writeLimiter.WaitN(ctx, allPeers[i].StoreId, int(size)); err != nil { + return errors.Trace(err) + } + if err := clients[i].SendMsg(preparedMsg); err != nil { + if err == io.EOF { + // if it's EOF, need RecvMsg to get the error + dummy := &sst.WriteResponse{} + err = clients[i].RecvMsg(dummy) + } + return annotateErr(err, allPeers[i], "when send data") + } + } + failpoint.Inject("afterFlushKVs", func() { + log.FromContext(ctx).Info(fmt.Sprintf("afterFlushKVs count=%d,size=%d", count, size)) + }) + return nil + } + + iter := j.ingestData.NewIter(ctx, j.keyRange.Start, j.keyRange.End, bufferPool) + //nolint: errcheck + defer iter.Close() + + var remainingStartKey []byte + for iter.First(); iter.Valid(); iter.Next() { + k, v := iter.Key(), iter.Value() + kvSize := int64(len(k) + len(v)) + // here we reuse the `*sst.Pair`s to optimize object allocation + if count < len(pairs) { + pairs[count].Key = k + pairs[count].Value = v + } else { + pair := &sst.Pair{ + Key: k, + Value: v, + } + pairs = append(pairs, pair) + } + count++ + totalCount++ + size += kvSize + totalSize += kvSize + + if size >= kvBatchSize { + if err := flushKVs(); err != nil { + return errors.Trace(err) + } + count = 0 + size = 0 + 
iter.ReleaseBuf() + } + if totalSize >= regionMaxSize || totalCount >= j.regionSplitKeys { + // we will shrink the key range of this job to real written range + if iter.Next() { + remainingStartKey = append([]byte{}, iter.Key()...) + log.FromContext(ctx).Info("write to tikv partial finish", + zap.Int64("count", totalCount), + zap.Int64("size", totalSize), + logutil.Key("startKey", j.keyRange.Start), + logutil.Key("endKey", j.keyRange.End), + logutil.Key("remainStart", remainingStartKey), + logutil.Region(region), + logutil.Leader(j.region.Leader), + zap.Uint64("commitTS", dataCommitTS)) + } + break + } + } + + if iter.Error() != nil { + return errors.Trace(iter.Error()) + } + + if count > 0 { + if err := flushKVs(); err != nil { + return errors.Trace(err) + } + count = 0 + size = 0 + iter.ReleaseBuf() + } + + var leaderPeerMetas []*sst.SSTMeta + for i, wStream := range clients { + resp, closeErr := wStream.CloseAndRecv() + if closeErr != nil { + return annotateErr(closeErr, allPeers[i], "when close write stream") + } + if resp.Error != nil { + return annotateErr(errors.New("resp error: "+resp.Error.Message), allPeers[i], "when close write stream") + } + if leaderID == region.Peers[i].GetId() { + leaderPeerMetas = resp.Metas + log.FromContext(ctx).Debug("get metas after write kv stream to tikv", zap.Reflect("metas", leaderPeerMetas)) + } + } + + failpoint.Inject("NoLeader", func() { + log.FromContext(ctx).Warn("enter failpoint NoLeader") + leaderPeerMetas = nil + }) + + // if there is not leader currently, we don't forward the stage to wrote and let caller + // handle the retry. + if len(leaderPeerMetas) == 0 { + log.FromContext(ctx).Warn("write to tikv no leader", + logutil.Region(region), logutil.Leader(j.region.Leader), + zap.Uint64("leader_id", leaderID), logutil.SSTMeta(meta), + zap.Int64("kv_pairs", totalCount), zap.Int64("total_bytes", totalSize)) + return common.ErrNoLeader.GenWithStackByArgs(region.Id, leaderID) + } + + takeTime := time.Since(begin) + log.FromContext(ctx).Debug("write to kv", zap.Reflect("region", j.region), zap.Uint64("leader", leaderID), + zap.Reflect("meta", meta), zap.Reflect("return metas", leaderPeerMetas), + zap.Int64("kv_pairs", totalCount), zap.Int64("total_bytes", totalSize), + zap.Stringer("takeTime", takeTime)) + if m, ok := metric.FromContext(ctx); ok { + m.SSTSecondsHistogram.WithLabelValues(metric.SSTProcessWrite).Observe(takeTime.Seconds()) + } + + j.writeResult = &tikvWriteResult{ + sstMeta: leaderPeerMetas, + count: totalCount, + totalBytes: totalSize, + remainingStartKey: remainingStartKey, + } + j.convertStageTo(wrote) + return nil +} + +// ingest tries to finish the regionJob. +// if any ingest logic has error, ingest may retry sometimes to resolve it and finally +// set job to a proper stage with nil error returned. +// if any underlying logic has error, ingest will return an error to let caller +// handle it. 
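// ingest below retries transport-level failures up to maxRetryTimes, and on an
// ingest response error defers to convertStageOnIngestError to decide between an
// immediate retry, requeueing the job, or rescanning the region.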
+func (local *Backend) ingest(ctx context.Context, j *regionJob) (err error) { + if j.stage != wrote { + return nil + } + + failpoint.Inject("fakeRegionJobs", func() { + front := j.injected[0] + j.injected = j.injected[1:] + j.convertStageTo(front.ingest.nextStage) + failpoint.Return(front.ingest.err) + }) + + if len(j.writeResult.sstMeta) == 0 { + j.convertStageTo(ingested) + return nil + } + + if m, ok := metric.FromContext(ctx); ok { + begin := time.Now() + defer func() { + if err == nil { + m.SSTSecondsHistogram.WithLabelValues(metric.SSTProcessIngest).Observe(time.Since(begin).Seconds()) + } + }() + } + + for retry := 0; retry < maxRetryTimes; retry++ { + resp, err := local.doIngest(ctx, j) + if err == nil && resp.GetError() == nil { + j.convertStageTo(ingested) + return nil + } + if err != nil { + if common.IsContextCanceledError(err) { + return err + } + log.FromContext(ctx).Warn("meet underlying error, will retry ingest", + log.ShortError(err), logutil.SSTMetas(j.writeResult.sstMeta), + logutil.Region(j.region.Region), logutil.Leader(j.region.Leader)) + continue + } + canContinue, err := j.convertStageOnIngestError(resp) + if common.IsContextCanceledError(err) { + return err + } + if !canContinue { + log.FromContext(ctx).Warn("meet error and handle the job later", + zap.Stringer("job stage", j.stage), + logutil.ShortError(j.lastRetryableErr), + j.region.ToZapFields(), + logutil.Key("start", j.keyRange.Start), + logutil.Key("end", j.keyRange.End)) + return nil + } + log.FromContext(ctx).Warn("meet error and will doIngest region again", + logutil.ShortError(j.lastRetryableErr), + j.region.ToZapFields(), + logutil.Key("start", j.keyRange.Start), + logutil.Key("end", j.keyRange.End)) + } + return nil +} + +func (local *Backend) checkWriteStall( + ctx context.Context, + region *split.RegionInfo, +) (bool, *sst.IngestResponse, error) { + clientFactory := local.importClientFactory + for _, peer := range region.Region.GetPeers() { + cli, err := clientFactory.Create(ctx, peer.StoreId) + if err != nil { + return false, nil, errors.Trace(err) + } + // currently we use empty MultiIngestRequest to check if TiKV is busy. + // If in future the rate limit feature contains more metrics we can switch to use it. + resp, err := cli.MultiIngest(ctx, &sst.MultiIngestRequest{}) + if err != nil { + return false, nil, errors.Trace(err) + } + if resp.Error != nil && resp.Error.ServerIsBusy != nil { + return true, resp, nil + } + } + return false, nil, nil +} + +// doIngest send ingest commands to TiKV based on regionJob.writeResult.sstMeta. +// When meet error, it will remove finished sstMetas before return. 
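// When multi-ingest is supported, doIngest sends all SST metas of the job in a single
// MultiIngest call; otherwise it ingests them one by one. On failure it trims the
// already-ingested metas from writeResult.sstMeta so a retry does not resend them.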
+func (local *Backend) doIngest(ctx context.Context, j *regionJob) (*sst.IngestResponse, error) { + clientFactory := local.importClientFactory + supportMultiIngest := local.supportMultiIngest + shouldCheckWriteStall := local.ShouldCheckWriteStall + if shouldCheckWriteStall { + writeStall, resp, err := local.checkWriteStall(ctx, j.region) + if err != nil { + return nil, errors.Trace(err) + } + if writeStall { + return resp, nil + } + } + + batch := 1 + if supportMultiIngest { + batch = len(j.writeResult.sstMeta) + } + + var resp *sst.IngestResponse + for start := 0; start < len(j.writeResult.sstMeta); start += batch { + end := min(start+batch, len(j.writeResult.sstMeta)) + ingestMetas := j.writeResult.sstMeta[start:end] + + log.FromContext(ctx).Debug("ingest meta", zap.Reflect("meta", ingestMetas)) + + failpoint.Inject("FailIngestMeta", func(val failpoint.Value) { + // only inject the error once + var resp *sst.IngestResponse + + switch val.(string) { + case "notleader": + resp = &sst.IngestResponse{ + Error: &errorpb.Error{ + NotLeader: &errorpb.NotLeader{ + RegionId: j.region.Region.Id, + Leader: j.region.Leader, + }, + }, + } + case "epochnotmatch": + resp = &sst.IngestResponse{ + Error: &errorpb.Error{ + EpochNotMatch: &errorpb.EpochNotMatch{ + CurrentRegions: []*metapb.Region{j.region.Region}, + }, + }, + } + } + failpoint.Return(resp, nil) + }) + + leader := j.region.Leader + if leader == nil { + return nil, errors.Annotatef(berrors.ErrPDLeaderNotFound, + "region id %d has no leader", j.region.Region.Id) + } + + cli, err := clientFactory.Create(ctx, leader.StoreId) + if err != nil { + return nil, errors.Trace(err) + } + reqCtx := &kvrpcpb.Context{ + RegionId: j.region.Region.GetId(), + RegionEpoch: j.region.Region.GetRegionEpoch(), + Peer: leader, + ResourceControlContext: &kvrpcpb.ResourceControlContext{ + ResourceGroupName: local.ResourceGroupName, + }, + RequestSource: util.BuildRequestSource(true, kv.InternalTxnLightning, local.TaskType), + } + + if supportMultiIngest { + req := &sst.MultiIngestRequest{ + Context: reqCtx, + Ssts: ingestMetas, + } + resp, err = cli.MultiIngest(ctx, req) + } else { + req := &sst.IngestRequest{ + Context: reqCtx, + Sst: ingestMetas[0], + } + resp, err = cli.Ingest(ctx, req) + } + if resp.GetError() != nil || err != nil { + // remove finished sstMetas + j.writeResult.sstMeta = j.writeResult.sstMeta[start:] + return resp, errors.Trace(err) + } + } + return resp, nil +} + +// convertStageOnIngestError will try to fix the error contained in ingest response. +// Return (_, error) when another error occurred. +// Return (true, nil) when the job can retry ingesting immediately. +// Return (false, nil) when the job should be put back to queue. +func (j *regionJob) convertStageOnIngestError( + resp *sst.IngestResponse, +) (bool, error) { + if resp.GetError() == nil { + return true, nil + } + + var newRegion *split.RegionInfo + switch errPb := resp.GetError(); { + case errPb.NotLeader != nil: + j.lastRetryableErr = common.ErrKVNotLeader.GenWithStack(errPb.GetMessage()) + + // meet a problem that the region leader+peer are all updated but the return + // error is only "NotLeader", we should update the whole region info. 
+ j.convertStageTo(needRescan) + return false, nil + case errPb.EpochNotMatch != nil: + j.lastRetryableErr = common.ErrKVEpochNotMatch.GenWithStack(errPb.GetMessage()) + + if currentRegions := errPb.GetEpochNotMatch().GetCurrentRegions(); currentRegions != nil { + var currentRegion *metapb.Region + for _, r := range currentRegions { + if insideRegion(r, j.writeResult.sstMeta) { + currentRegion = r + break + } + } + if currentRegion != nil { + var newLeader *metapb.Peer + for _, p := range currentRegion.Peers { + if p.GetStoreId() == j.region.Leader.GetStoreId() { + newLeader = p + break + } + } + if newLeader != nil { + newRegion = &split.RegionInfo{ + Leader: newLeader, + Region: currentRegion, + } + } + } + } + if newRegion != nil { + j.region = newRegion + j.convertStageTo(regionScanned) + return false, nil + } + j.convertStageTo(needRescan) + return false, nil + case strings.Contains(errPb.Message, "raft: proposal dropped"): + j.lastRetryableErr = common.ErrKVRaftProposalDropped.GenWithStack(errPb.GetMessage()) + + j.convertStageTo(needRescan) + return false, nil + case errPb.ServerIsBusy != nil: + j.lastRetryableErr = common.ErrKVServerIsBusy.GenWithStack(errPb.GetMessage()) + + return false, nil + case errPb.RegionNotFound != nil: + j.lastRetryableErr = common.ErrKVRegionNotFound.GenWithStack(errPb.GetMessage()) + + j.convertStageTo(needRescan) + return false, nil + case errPb.ReadIndexNotReady != nil: + j.lastRetryableErr = common.ErrKVReadIndexNotReady.GenWithStack(errPb.GetMessage()) + + // this error happens when this region is splitting, the error might be: + // read index not ready, reason can not read index due to split, region 64037 + // we have paused schedule, but it's temporary, + // if next request takes a long time, there's chance schedule is enabled again + // or on key range border, another engine sharing this region tries to split this + // region may cause this error too. + j.convertStageTo(needRescan) + return false, nil + case errPb.DiskFull != nil: + j.lastRetryableErr = common.ErrKVIngestFailed.GenWithStack(errPb.GetMessage()) + + return false, errors.Errorf("non-retryable error: %s", resp.GetError().GetMessage()) + } + // all others doIngest error, such as stale command, etc. we'll retry it again from writeAndIngestByRange + j.lastRetryableErr = common.ErrKVIngestFailed.GenWithStack(resp.GetError().GetMessage()) + j.convertStageTo(regionScanned) + return false, nil +} + +type regionJobRetryHeap []*regionJob + +var _ heap.Interface = (*regionJobRetryHeap)(nil) + +func (h *regionJobRetryHeap) Len() int { + return len(*h) +} + +func (h *regionJobRetryHeap) Less(i, j int) bool { + v := *h + return v[i].waitUntil.Before(v[j].waitUntil) +} + +func (h *regionJobRetryHeap) Swap(i, j int) { + v := *h + v[i], v[j] = v[j], v[i] +} + +func (h *regionJobRetryHeap) Push(x any) { + *h = append(*h, x.(*regionJob)) +} + +func (h *regionJobRetryHeap) Pop() any { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} + +// regionJobRetryer is a concurrent-safe queue holding jobs that need to put +// back later, and put back when the regionJob.waitUntil is reached. It maintains +// a heap of jobs internally based on the regionJob.waitUntil field. 
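regionJobRetryHeap above is a textbook container/heap implementation keyed by waitUntil, so the job with the earliest deadline always sits at the front. A self-contained version of the same pattern (job here is a made-up stand-in for regionJob):

package main

import (
	"container/heap"
	"fmt"
	"time"
)

// job is a stand-in for regionJob: only the retry deadline matters for ordering.
type job struct {
	name      string
	waitUntil time.Time
}

// jobHeap implements heap.Interface ordered by the earliest waitUntil.
type jobHeap []*job

func (h jobHeap) Len() int           { return len(h) }
func (h jobHeap) Less(i, j int) bool { return h[i].waitUntil.Before(h[j].waitUntil) }
func (h jobHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *jobHeap) Push(x any)        { *h = append(*h, x.(*job)) }
func (h *jobHeap) Pop() any {
	old := *h
	n := len(old)
	x := old[n-1]
	*h = old[:n-1]
	return x
}

func main() {
	now := time.Now()
	h := &jobHeap{}
	heap.Push(h, &job{"late", now.Add(2 * time.Second)})
	heap.Push(h, &job{"soon", now.Add(1 * time.Second)})
	fmt.Println(heap.Pop(h).(*job).name) // soon: the front is always the earliest deadline
}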
+type regionJobRetryer struct { + // lock acquiring order: protectedClosed > protectedQueue > protectedToPutBack + protectedClosed struct { + mu sync.Mutex + closed bool + } + protectedQueue struct { + mu sync.Mutex + q regionJobRetryHeap + } + protectedToPutBack struct { + mu sync.Mutex + toPutBack *regionJob + } + putBackCh chan<- *regionJob + reload chan struct{} + jobWg *sync.WaitGroup +} + +// startRegionJobRetryer starts a new regionJobRetryer and it will run in +// background to put the job back to `putBackCh` when job's waitUntil is reached. +// Cancel the `ctx` will stop retryer and `jobWg.Done` will be trigger for jobs +// that are not put back yet. +func startRegionJobRetryer( + ctx context.Context, + putBackCh chan<- *regionJob, + jobWg *sync.WaitGroup, +) *regionJobRetryer { + ret := ®ionJobRetryer{ + putBackCh: putBackCh, + reload: make(chan struct{}, 1), + jobWg: jobWg, + } + ret.protectedQueue.q = make(regionJobRetryHeap, 0, 16) + go ret.run(ctx) + return ret +} + +// run is only internally used, caller should not use it. +func (q *regionJobRetryer) run(ctx context.Context) { + defer q.close() + + for { + var front *regionJob + q.protectedQueue.mu.Lock() + if len(q.protectedQueue.q) > 0 { + front = q.protectedQueue.q[0] + } + q.protectedQueue.mu.Unlock() + + switch { + case front != nil: + select { + case <-ctx.Done(): + return + case <-q.reload: + case <-time.After(time.Until(front.waitUntil)): + q.protectedQueue.mu.Lock() + q.protectedToPutBack.mu.Lock() + q.protectedToPutBack.toPutBack = heap.Pop(&q.protectedQueue.q).(*regionJob) + // release the lock of queue to avoid blocking regionJobRetryer.push + q.protectedQueue.mu.Unlock() + + // hold the lock of toPutBack to make sending to putBackCh and + // resetting toPutBack atomic w.r.t. regionJobRetryer.close + select { + case <-ctx.Done(): + q.protectedToPutBack.mu.Unlock() + return + case q.putBackCh <- q.protectedToPutBack.toPutBack: + q.protectedToPutBack.toPutBack = nil + q.protectedToPutBack.mu.Unlock() + } + } + default: + // len(q.q) == 0 + select { + case <-ctx.Done(): + return + case <-q.reload: + } + } + } +} + +// close is only internally used, caller should not use it. +func (q *regionJobRetryer) close() { + q.protectedClosed.mu.Lock() + defer q.protectedClosed.mu.Unlock() + q.protectedClosed.closed = true + + if q.protectedToPutBack.toPutBack != nil { + q.protectedToPutBack.toPutBack.done(q.jobWg) + } + for _, job := range q.protectedQueue.q { + job.done(q.jobWg) + } +} + +// push should not be blocked for long time in any cases. 
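The run loop above sleeps until the earliest waitUntil, but pushing a job with an even earlier deadline signals the reload channel so the loop re-reads the heap front instead of oversleeping. The sketch below isolates that select (waitEarliest is an illustrative helper, not part of the retryer):

package main

import (
	"context"
	"fmt"
	"time"
)

// waitEarliest blocks until the deadline, a reload signal, or cancellation, and
// reports which one woke it up - the same three-way select used by the run loop.
func waitEarliest(ctx context.Context, deadline time.Time, reload <-chan struct{}) string {
	timer := time.NewTimer(time.Until(deadline))
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return "canceled"
	case <-reload:
		return "reloaded"
	case <-timer.C:
		return "deadline reached"
	}
}

func main() {
	reload := make(chan struct{}, 1)
	go func() {
		time.Sleep(50 * time.Millisecond)
		reload <- struct{}{} // simulate push() signalling a newly queued, earlier job
	}()
	fmt.Println(waitEarliest(context.Background(), time.Now().Add(time.Hour), reload)) // reloaded
}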
+func (q *regionJobRetryer) push(job *regionJob) bool { + q.protectedClosed.mu.Lock() + defer q.protectedClosed.mu.Unlock() + if q.protectedClosed.closed { + return false + } + + q.protectedQueue.mu.Lock() + heap.Push(&q.protectedQueue.q, job) + q.protectedQueue.mu.Unlock() + + select { + case q.reload <- struct{}{}: + default: + } + return true +} diff --git a/pkg/lightning/backend/tidb/binding__failpoint_binding__.go b/pkg/lightning/backend/tidb/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..7cf3faffb040f --- /dev/null +++ b/pkg/lightning/backend/tidb/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package tidb + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/lightning/backend/tidb/tidb.go b/pkg/lightning/backend/tidb/tidb.go index 186e17d9d6d06..2d0ae231b4fbc 100644 --- a/pkg/lightning/backend/tidb/tidb.go +++ b/pkg/lightning/backend/tidb/tidb.go @@ -252,12 +252,11 @@ func (b *targetInfoGetter) FetchRemoteTableModels(ctx context.Context, schemaNam return nil } - failpoint.Inject( - "FetchRemoteTableModels_BeforeFetchTableAutoIDInfos", - func() { - fmt.Println("failpoint: FetchRemoteTableModels_BeforeFetchTableAutoIDInfos") - }, - ) + if _, _err_ := failpoint.Eval(_curpkg_("FetchRemoteTableModels_BeforeFetchTableAutoIDInfos")); _err_ == nil { + + fmt.Println("failpoint: FetchRemoteTableModels_BeforeFetchTableAutoIDInfos") + + } // init auto id column for each table for _, tbl := range tables { @@ -838,9 +837,9 @@ stmtLoop: } // max-error not yet reached (error consumed by errorMgr), proceed to next stmtTask. } - failpoint.Inject("FailIfImportedSomeRows", func() { + if _, _err_ := failpoint.Eval(_curpkg_("FailIfImportedSomeRows")); _err_ == nil { panic("forcing failure due to FailIfImportedSomeRows, before saving checkpoint") - }) + } return nil } diff --git a/pkg/lightning/backend/tidb/tidb.go__failpoint_stash__ b/pkg/lightning/backend/tidb/tidb.go__failpoint_stash__ new file mode 100644 index 0000000000000..186e17d9d6d06 --- /dev/null +++ b/pkg/lightning/backend/tidb/tidb.go__failpoint_stash__ @@ -0,0 +1,956 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package tidb + +import ( + "context" + "database/sql" + "encoding/hex" + "fmt" + "strconv" + "strings" + "time" + + gmysql "github.com/go-sql-driver/mysql" + "github.com/google/uuid" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/version" + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/lightning/backend" + "github.com/pingcap/tidb/pkg/lightning/backend/encode" + "github.com/pingcap/tidb/pkg/lightning/backend/kv" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/lightning/errormanager" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/lightning/verification" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/dbutil" + "github.com/pingcap/tidb/pkg/util/redact" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +var extraHandleTableColumn = &table.Column{ + ColumnInfo: kv.ExtraHandleColumnInfo, + GeneratedExpr: nil, + DefaultExpr: nil, +} + +const ( + writeRowsMaxRetryTimes = 3 +) + +type tidbRow struct { + insertStmt string + path string + offset int64 +} + +var emptyTiDBRow = tidbRow{ + insertStmt: "", + path: "", + offset: 0, +} + +type tidbRows []tidbRow + +// MarshalLogArray implements the zapcore.ArrayMarshaler interface +func (rows tidbRows) MarshalLogArray(encoder zapcore.ArrayEncoder) error { + for _, r := range rows { + encoder.AppendString(redact.Value(r.insertStmt)) + } + return nil +} + +type tidbEncoder struct { + mode mysql.SQLMode + tbl table.Table + se sessionctx.Context + // the index of table columns for each data field. + // index == len(table.columns) means this field is `_tidb_rowid` + columnIdx []int + // the max index used in this chunk, due to the ignore-columns config, we can't + // directly check the total column count, so we fall back to only check that + // the there are enough columns. + columnCnt int + // data file path + path string + logger log.Logger +} + +type encodingBuilder struct{} + +// NewEncodingBuilder creates an EncodingBuilder with TiDB backend implementation. +func NewEncodingBuilder() encode.EncodingBuilder { + return new(encodingBuilder) +} + +// NewEncoder creates a KV encoder. +// It implements the `backend.EncodingBuilder` interface. +func (*encodingBuilder) NewEncoder(ctx context.Context, config *encode.EncodingConfig) (encode.Encoder, error) { + se := kv.NewSessionCtx(&config.SessionOptions, log.FromContext(ctx)) + if config.SQLMode.HasStrictMode() { + se.GetSessionVars().SkipUTF8Check = false + se.GetSessionVars().SkipASCIICheck = false + } + + return &tidbEncoder{ + mode: config.SQLMode, + tbl: config.Table, + se: se, + path: config.Path, + logger: config.Logger, + }, nil +} + +// MakeEmptyRows creates an empty KV rows. +// It implements the `backend.EncodingBuilder` interface. +func (*encodingBuilder) MakeEmptyRows() encode.Rows { + return tidbRows(nil) +} + +type targetInfoGetter struct { + db *sql.DB +} + +// NewTargetInfoGetter creates an TargetInfoGetter with TiDB backend implementation. +func NewTargetInfoGetter(db *sql.DB) backend.TargetInfoGetter { + return &targetInfoGetter{ + db: db, + } +} + +// FetchRemoteDBModels implements the `backend.TargetInfoGetter` interface. 
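The generated binding__failpoint_binding__.go files introduced earlier in this patch exist only to compute the package's import path at runtime, so failpoint names can be namespaced as "<pkgpath>/<name>". A minimal reproduction of that _curpkg_ trick (curpkg and marker are illustrative names):

package main

import (
	"fmt"
	"reflect"
)

// marker is only used to recover this package's import path at runtime,
// the same trick as the generated _curpkg_ helpers in this patch.
type marker struct{}

func curpkg(name string) string {
	return reflect.TypeOf(marker{}).PkgPath() + "/" + name
}

func main() {
	// For package main this prints "main/FetchRemoteTableModels_BeforeFetchTableAutoIDInfos";
	// compiled inside pkg/lightning/backend/tidb it would carry that full import path instead.
	fmt.Println(curpkg("FetchRemoteTableModels_BeforeFetchTableAutoIDInfos"))
}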
+func (b *targetInfoGetter) FetchRemoteDBModels(ctx context.Context) ([]*model.DBInfo, error) { + results := []*model.DBInfo{} + logger := log.FromContext(ctx) + s := common.SQLWithRetry{ + DB: b.db, + Logger: logger, + } + err := s.Transact(ctx, "fetch db models", func(_ context.Context, tx *sql.Tx) error { + results = results[:0] + + rows, e := tx.Query("SHOW DATABASES") + if e != nil { + return e + } + defer rows.Close() + + for rows.Next() { + var dbName string + if e := rows.Scan(&dbName); e != nil { + return e + } + dbInfo := &model.DBInfo{ + Name: model.NewCIStr(dbName), + } + results = append(results, dbInfo) + } + return rows.Err() + }) + return results, err +} + +// FetchRemoteTableModels obtains the models of all tables given the schema name. +// It implements the `backend.TargetInfoGetter` interface. +// TODO: refactor +func (b *targetInfoGetter) FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) { + var err error + results := []*model.TableInfo{} + logger := log.FromContext(ctx) + s := common.SQLWithRetry{ + DB: b.db, + Logger: logger, + } + + err = s.Transact(ctx, "fetch table columns", func(_ context.Context, tx *sql.Tx) error { + var versionStr string + if versionStr, err = version.FetchVersion(ctx, tx); err != nil { + return err + } + serverInfo := version.ParseServerInfo(versionStr) + + rows, e := tx.Query(` + SELECT table_name, column_name, column_type, generation_expression, extra + FROM information_schema.columns + WHERE table_schema = ? + ORDER BY table_name, ordinal_position; + `, schemaName) + if e != nil { + return e + } + defer rows.Close() + + var ( + curTableName string + curColOffset int + curTable *model.TableInfo + ) + tables := []*model.TableInfo{} + for rows.Next() { + var tableName, columnName, columnType, generationExpr, columnExtra string + if e := rows.Scan(&tableName, &columnName, &columnType, &generationExpr, &columnExtra); e != nil { + return e + } + if tableName != curTableName { + curTable = &model.TableInfo{ + Name: model.NewCIStr(tableName), + State: model.StatePublic, + PKIsHandle: true, + } + tables = append(tables, curTable) + curTableName = tableName + curColOffset = 0 + } + + // see: https://github.com/pingcap/parser/blob/3b2fb4b41d73710bc6c4e1f4e8679d8be6a4863e/types/field_type.go#L185-L191 + var flag uint + if strings.HasSuffix(columnType, "unsigned") { + flag |= mysql.UnsignedFlag + } + if strings.Contains(columnExtra, "auto_increment") { + flag |= mysql.AutoIncrementFlag + } + + ft := types.FieldType{} + ft.SetFlag(flag) + curTable.Columns = append(curTable.Columns, &model.ColumnInfo{ + Name: model.NewCIStr(columnName), + Offset: curColOffset, + State: model.StatePublic, + FieldType: ft, + GeneratedExprString: generationExpr, + }) + curColOffset++ + } + if err := rows.Err(); err != nil { + return err + } + // shard_row_id/auto random is only available after tidb v4.0.0 + // `show table next_row_id` is also not available before tidb v4.0.0 + if serverInfo.ServerType != version.ServerTypeTiDB || serverInfo.ServerVersion.Major < 4 { + results = tables + return nil + } + + failpoint.Inject( + "FetchRemoteTableModels_BeforeFetchTableAutoIDInfos", + func() { + fmt.Println("failpoint: FetchRemoteTableModels_BeforeFetchTableAutoIDInfos") + }, + ) + + // init auto id column for each table + for _, tbl := range tables { + tblName := common.UniqueTable(schemaName, tbl.Name.O) + autoIDInfos, err := FetchTableAutoIDInfos(ctx, tx, tblName) + if err != nil { + logger.Warn("fetch table auto ID infos error. 
Ignore this table and continue.", zap.String("table_name", tblName), zap.Error(err)) + continue + } + for _, info := range autoIDInfos { + for _, col := range tbl.Columns { + if col.Name.O == info.Column { + switch info.Type { + case "AUTO_INCREMENT": + col.AddFlag(mysql.AutoIncrementFlag) + case "AUTO_RANDOM": + col.AddFlag(mysql.PriKeyFlag) + tbl.PKIsHandle = true + // set a stub here, since we don't really need the real value + tbl.AutoRandomBits = 1 + } + } + } + } + results = append(results, tbl) + } + return nil + }) + return results, err +} + +// CheckRequirements performs the check whether the backend satisfies the version requirements. +// It implements the `backend.TargetInfoGetter` interface. +func (*targetInfoGetter) CheckRequirements(ctx context.Context, _ *backend.CheckCtx) error { + log.FromContext(ctx).Info("skipping check requirements for tidb backend") + return nil +} + +type tidbBackend struct { + db *sql.DB + conflictCfg config.Conflict + // onDuplicate is the type of INSERT SQL. It may be different with + // conflictCfg.Strategy to implement other feature, but the behaviour in caller's + // view should be the same. + onDuplicate config.DuplicateResolutionAlgorithm + errorMgr *errormanager.ErrorManager + // maxChunkSize and maxChunkRows are the target size and number of rows of each INSERT SQL + // statement to be sent to downstream. Sometimes we want to reduce the txn size to avoid + // affecting the cluster too much. + maxChunkSize uint64 + maxChunkRows int +} + +var _ backend.Backend = (*tidbBackend)(nil) + +// NewTiDBBackend creates a new TiDB backend using the given database. +// +// The backend does not take ownership of `db`. Caller should close `db` +// manually after the backend expired. +func NewTiDBBackend( + ctx context.Context, + db *sql.DB, + cfg *config.Config, + errorMgr *errormanager.ErrorManager, +) backend.Backend { + conflict := cfg.Conflict + var onDuplicate config.DuplicateResolutionAlgorithm + switch conflict.Strategy { + case config.ErrorOnDup: + onDuplicate = config.ErrorOnDup + case config.ReplaceOnDup: + onDuplicate = config.ReplaceOnDup + case config.IgnoreOnDup: + if conflict.MaxRecordRows == 0 { + onDuplicate = config.IgnoreOnDup + } else { + // need to stop batch insert on error and fall back to row by row insert + // to record the row + onDuplicate = config.ErrorOnDup + } + default: + log.FromContext(ctx).Warn("unsupported conflict strategy for TiDB backend, overwrite with `error`") + onDuplicate = config.ErrorOnDup + } + return &tidbBackend{ + db: db, + conflictCfg: conflict, + onDuplicate: onDuplicate, + errorMgr: errorMgr, + maxChunkSize: uint64(cfg.TikvImporter.LogicalImportBatchSize), + maxChunkRows: cfg.TikvImporter.LogicalImportBatchRows, + } +} + +func (row tidbRow) Size() uint64 { + return uint64(len(row.insertStmt)) +} + +func (row tidbRow) String() string { + return row.insertStmt +} + +func (row tidbRow) ClassifyAndAppend(data *encode.Rows, checksum *verification.KVChecksum, _ *encode.Rows, _ *verification.KVChecksum) { + rows := (*data).(tidbRows) + // Cannot do `rows := data.(*tidbRows); *rows = append(*rows, row)`. 
+ //nolint:gocritic + *data = append(rows, row) + cs := verification.MakeKVChecksum(row.Size(), 1, 0) + checksum.Add(&cs) +} + +func (rows tidbRows) splitIntoChunks(splitSize uint64, splitRows int) []tidbRows { + if len(rows) == 0 { + return nil + } + + res := make([]tidbRows, 0, 1) + i := 0 + cumSize := uint64(0) + + for j, row := range rows { + if i < j && (cumSize+row.Size() > splitSize || j-i >= splitRows) { + res = append(res, rows[i:j]) + i = j + cumSize = 0 + } + cumSize += row.Size() + } + + return append(res, rows[i:]) +} + +func (rows tidbRows) Clear() encode.Rows { + return rows[:0] +} + +func (enc *tidbEncoder) appendSQLBytes(sb *strings.Builder, value []byte) { + sb.Grow(2 + len(value)) + sb.WriteByte('\'') + if enc.mode.HasNoBackslashEscapesMode() { + for _, b := range value { + if b == '\'' { + sb.WriteString(`''`) + } else { + sb.WriteByte(b) + } + } + } else { + for _, b := range value { + switch b { + case 0: + sb.WriteString(`\0`) + case '\b': + sb.WriteString(`\b`) + case '\n': + sb.WriteString(`\n`) + case '\r': + sb.WriteString(`\r`) + case '\t': + sb.WriteString(`\t`) + case 26: + sb.WriteString(`\Z`) + case '\'': + sb.WriteString(`''`) + case '\\': + sb.WriteString(`\\`) + default: + sb.WriteByte(b) + } + } + } + sb.WriteByte('\'') +} + +// appendSQL appends the SQL representation of the Datum into the string builder. +// Note that we cannot use Datum.ToString since it doesn't perform SQL escaping. +func (enc *tidbEncoder) appendSQL(sb *strings.Builder, datum *types.Datum, _ *table.Column) error { + switch datum.Kind() { + case types.KindNull: + sb.WriteString("NULL") + + case types.KindMinNotNull: + sb.WriteString("MINVALUE") + + case types.KindMaxValue: + sb.WriteString("MAXVALUE") + + case types.KindInt64: + // longest int64 = -9223372036854775808 which has 20 characters + var buffer [20]byte + value := strconv.AppendInt(buffer[:0], datum.GetInt64(), 10) + sb.Write(value) + + case types.KindUint64, types.KindMysqlEnum, types.KindMysqlSet: + // longest uint64 = 18446744073709551615 which has 20 characters + var buffer [20]byte + value := strconv.AppendUint(buffer[:0], datum.GetUint64(), 10) + sb.Write(value) + + case types.KindFloat32, types.KindFloat64: + // float64 has 16 digits of precision, so a buffer size of 32 is more than enough... 
+ var buffer [32]byte + value := strconv.AppendFloat(buffer[:0], datum.GetFloat64(), 'g', -1, 64) + sb.Write(value) + case types.KindString: + // See: https://github.com/pingcap/tidb-lightning/issues/550 + // if enc.mode.HasStrictMode() { + // d, err := table.CastValue(enc.se, *datum, col.ToInfo(), false, false) + // if err != nil { + // return errors.Trace(err) + // } + // datum = &d + // } + + enc.appendSQLBytes(sb, datum.GetBytes()) + case types.KindBytes: + enc.appendSQLBytes(sb, datum.GetBytes()) + + case types.KindMysqlJSON: + value, err := datum.GetMysqlJSON().MarshalJSON() + if err != nil { + return err + } + enc.appendSQLBytes(sb, value) + + case types.KindBinaryLiteral: + value := datum.GetBinaryLiteral() + sb.Grow(3 + 2*len(value)) + sb.WriteString("x'") + if _, err := hex.NewEncoder(sb).Write(value); err != nil { + return errors.Trace(err) + } + sb.WriteByte('\'') + + case types.KindMysqlBit: + var buffer [20]byte + intValue, err := datum.GetBinaryLiteral().ToInt(types.DefaultStmtNoWarningContext) + if err != nil { + return err + } + value := strconv.AppendUint(buffer[:0], intValue, 10) + sb.Write(value) + + // time, duration, decimal + default: + value, err := datum.ToString() + if err != nil { + return err + } + sb.WriteByte('\'') + sb.WriteString(value) + sb.WriteByte('\'') + } + + return nil +} + +func (*tidbEncoder) Close() {} + +func getColumnByIndex(cols []*table.Column, index int) *table.Column { + if index == len(cols) { + return extraHandleTableColumn + } + return cols[index] +} + +func (enc *tidbEncoder) Encode(row []types.Datum, _ int64, columnPermutation []int, offset int64) (encode.Row, error) { + cols := enc.tbl.Cols() + + if len(enc.columnIdx) == 0 { + columnMaxIdx := -1 + columnIdx := make([]int, len(columnPermutation)) + for i := 0; i < len(columnPermutation); i++ { + columnIdx[i] = -1 + } + for i, idx := range columnPermutation { + if idx >= 0 { + columnIdx[idx] = i + if idx > columnMaxIdx { + columnMaxIdx = idx + } + } + } + enc.columnIdx = columnIdx + enc.columnCnt = columnMaxIdx + 1 + } + + // TODO: since the column count doesn't exactly reflect the real column names, we only check the upper bound currently. + // See: tests/generated_columns/data/gencol.various_types.0.sql this sql has no columns, so encodeLoop will fill the + // column permutation with default, thus enc.columnCnt > len(row). + if len(row) < enc.columnCnt { + // 1. if len(row) < enc.columnCnt: data in row cannot populate the insert statement, because + // there are enc.columnCnt elements to insert but fewer columns in row + enc.logger.Error("column count mismatch", zap.Ints("column_permutation", columnPermutation), + zap.Array("data", kv.RowArrayMarshaller(row))) + return emptyTiDBRow, errors.Errorf("column count mismatch, expected %d, got %d", enc.columnCnt, len(row)) + } + + if len(row) > len(enc.columnIdx) { + // 2. 
if len(row) > len(columnIdx): raw row data has more columns than those + // in the table + enc.logger.Error("column count mismatch", zap.Ints("column_count", enc.columnIdx), + zap.Array("data", kv.RowArrayMarshaller(row))) + return emptyTiDBRow, errors.Errorf("column count mismatch, at most %d but got %d", len(enc.columnIdx), len(row)) + } + + var encoded strings.Builder + encoded.Grow(8 * len(row)) + encoded.WriteByte('(') + cnt := 0 + for i, field := range row { + if enc.columnIdx[i] < 0 { + continue + } + if cnt > 0 { + encoded.WriteByte(',') + } + datum := field + if err := enc.appendSQL(&encoded, &datum, getColumnByIndex(cols, enc.columnIdx[i])); err != nil { + enc.logger.Error("tidb encode failed", + zap.Array("original", kv.RowArrayMarshaller(row)), + zap.Int("originalCol", i), + log.ShortError(err), + ) + return nil, err + } + cnt++ + } + encoded.WriteByte(')') + return tidbRow{ + insertStmt: encoded.String(), + path: enc.path, + offset: offset, + }, nil +} + +// EncodeRowForRecord encodes a row to a string compatible with INSERT statements. +func EncodeRowForRecord(ctx context.Context, encTable table.Table, sqlMode mysql.SQLMode, row []types.Datum, columnPermutation []int) string { + enc := tidbEncoder{ + tbl: encTable, + mode: sqlMode, + logger: log.FromContext(ctx), + } + resRow, err := enc.Encode(row, 0, columnPermutation, 0) + if err != nil { + // if encode can't succeed, fallback to record the raw input strings + // ignore the error since it can only happen if the datum type is unknown, this can't happen here. + datumStr, _ := types.DatumsToString(row, true) + return datumStr + } + return resRow.(tidbRow).insertStmt +} + +func (*tidbBackend) Close() { + // *Not* going to close `be.db`. The db object is normally borrowed from a + // TidbManager, so we let the manager to close it. +} + +func (*tidbBackend) RetryImportDelay() time.Duration { + return 0 +} + +func (*tidbBackend) ShouldPostProcess() bool { + return true +} + +func (*tidbBackend) OpenEngine(context.Context, *backend.EngineConfig, uuid.UUID) error { + return nil +} + +func (*tidbBackend) CloseEngine(context.Context, *backend.EngineConfig, uuid.UUID) error { + return nil +} + +func (*tidbBackend) CleanupEngine(context.Context, uuid.UUID) error { + return nil +} + +func (*tidbBackend) ImportEngine(context.Context, uuid.UUID, int64, int64) error { + return nil +} + +func (be *tidbBackend) WriteRows(ctx context.Context, tableName string, columnNames []string, rows encode.Rows) error { + var err error +rowLoop: + for _, r := range rows.(tidbRows).splitIntoChunks(be.maxChunkSize, be.maxChunkRows) { + for i := 0; i < writeRowsMaxRetryTimes; i++ { + // Write in the batch mode first. + err = be.WriteBatchRowsToDB(ctx, tableName, columnNames, r) + switch { + case err == nil: + continue rowLoop + case common.IsRetryableError(err): + // retry next loop + case be.errorMgr.TypeErrorsRemain() > 0 || + be.errorMgr.ConflictErrorsRemain() > 0 || + (be.conflictCfg.Strategy == config.ErrorOnDup && !be.errorMgr.RecordErrorOnce()): + // WriteBatchRowsToDB failed in the batch mode and can not be retried, + // we need to redo the writing row-by-row to find where the error locates (and skip it correctly in future). + if err = be.WriteRowsToDB(ctx, tableName, columnNames, r); err != nil { + // If the error is not nil, it means we reach the max error count in the + // non-batch mode or this is "error" conflict strategy. 
+ return errors.Annotatef(err, "[%s] write rows exceed conflict threshold", tableName) + } + continue rowLoop + default: + return err + } + } + return errors.Annotatef(err, "[%s] batch write rows reach max retry %d and still failed", tableName, writeRowsMaxRetryTimes) + } + return nil +} + +type stmtTask struct { + rows tidbRows + stmt string +} + +// WriteBatchRowsToDB write rows in batch mode, which will insert multiple rows like this: +// +// insert into t1 values (111), (222), (333), (444); +func (be *tidbBackend) WriteBatchRowsToDB(ctx context.Context, tableName string, columnNames []string, rows tidbRows) error { + insertStmt := be.checkAndBuildStmt(rows, tableName, columnNames) + if insertStmt == nil { + return nil + } + // Note: we are not going to do interpolation (prepared statements) to avoid + // complication arise from data length overflow of BIT and BINARY columns + stmtTasks := make([]stmtTask, 1) + for i, row := range rows { + if i != 0 { + insertStmt.WriteByte(',') + } + insertStmt.WriteString(row.insertStmt) + } + stmtTasks[0] = stmtTask{rows, insertStmt.String()} + return be.execStmts(ctx, stmtTasks, tableName, true) +} + +func (be *tidbBackend) checkAndBuildStmt(rows tidbRows, tableName string, columnNames []string) *strings.Builder { + if len(rows) == 0 { + return nil + } + return be.buildStmt(tableName, columnNames) +} + +// WriteRowsToDB write rows in row-by-row mode, which will insert multiple rows like this: +// +// insert into t1 values (111); +// insert into t1 values (222); +// insert into t1 values (333); +// insert into t1 values (444); +// +// See more details in br#1366: https://github.com/pingcap/br/issues/1366 +func (be *tidbBackend) WriteRowsToDB(ctx context.Context, tableName string, columnNames []string, rows tidbRows) error { + insertStmt := be.checkAndBuildStmt(rows, tableName, columnNames) + if insertStmt == nil { + return nil + } + is := insertStmt.String() + stmtTasks := make([]stmtTask, 0, len(rows)) + for _, row := range rows { + var finalInsertStmt strings.Builder + finalInsertStmt.WriteString(is) + finalInsertStmt.WriteString(row.insertStmt) + stmtTasks = append(stmtTasks, stmtTask{[]tidbRow{row}, finalInsertStmt.String()}) + } + return be.execStmts(ctx, stmtTasks, tableName, false) +} + +func (be *tidbBackend) buildStmt(tableName string, columnNames []string) *strings.Builder { + var insertStmt strings.Builder + switch be.onDuplicate { + case config.ReplaceOnDup: + insertStmt.WriteString("REPLACE INTO ") + case config.IgnoreOnDup: + insertStmt.WriteString("INSERT IGNORE INTO ") + case config.ErrorOnDup: + insertStmt.WriteString("INSERT INTO ") + } + insertStmt.WriteString(tableName) + if len(columnNames) > 0 { + insertStmt.WriteByte('(') + for i, colName := range columnNames { + if i != 0 { + insertStmt.WriteByte(',') + } + common.WriteMySQLIdentifier(&insertStmt, colName) + } + insertStmt.WriteByte(')') + } + insertStmt.WriteString(" VALUES") + return &insertStmt +} + +func (be *tidbBackend) execStmts(ctx context.Context, stmtTasks []stmtTask, tableName string, batch bool) error { +stmtLoop: + for _, stmtTask := range stmtTasks { + var ( + result sql.Result + err error + ) + for i := 0; i < writeRowsMaxRetryTimes; i++ { + stmt := stmtTask.stmt + result, err = be.db.ExecContext(ctx, stmt) + if err == nil { + affected, err2 := result.RowsAffected() + if err2 != nil { + // should not happen + return errors.Trace(err2) + } + diff := int64(len(stmtTask.rows)) - affected + if diff < 0 { + diff = -diff + } + if diff > 0 { + if err2 = 
be.errorMgr.RecordDuplicateCount(diff); err2 != nil { + return err2 + } + } + continue stmtLoop + } + + if !common.IsContextCanceledError(err) { + log.FromContext(ctx).Error("execute statement failed", + zap.Array("rows", stmtTask.rows), zap.String("stmt", redact.Value(stmt)), zap.Error(err)) + } + // It's batch mode, just return the error. Caller will fall back to row-by-row mode. + if batch { + return errors.Trace(err) + } + if !common.IsRetryableError(err) { + break + } + } + + firstRow := stmtTask.rows[0] + + if isDupEntryError(err) { + // rowID is ignored in tidb backend + if be.conflictCfg.Strategy == config.ErrorOnDup { + be.errorMgr.RecordDuplicateOnce( + ctx, + log.FromContext(ctx), + tableName, + firstRow.path, + firstRow.offset, + err.Error(), + 0, + firstRow.insertStmt, + ) + return err + } + err = be.errorMgr.RecordDuplicate( + ctx, + log.FromContext(ctx), + tableName, + firstRow.path, + firstRow.offset, + err.Error(), + 0, + firstRow.insertStmt, + ) + } else { + err = be.errorMgr.RecordTypeError( + ctx, + log.FromContext(ctx), + tableName, + firstRow.path, + firstRow.offset, + firstRow.insertStmt, + err, + ) + } + if err != nil { + return errors.Trace(err) + } + // max-error not yet reached (error consumed by errorMgr), proceed to next stmtTask. + } + failpoint.Inject("FailIfImportedSomeRows", func() { + panic("forcing failure due to FailIfImportedSomeRows, before saving checkpoint") + }) + return nil +} + +func isDupEntryError(err error) bool { + merr, ok := errors.Cause(err).(*gmysql.MySQLError) + if !ok { + return false + } + return merr.Number == errno.ErrDupEntry +} + +// FlushEngine flushes the data in the engine to the underlying storage. +func (*tidbBackend) FlushEngine(context.Context, uuid.UUID) error { + return nil +} + +// FlushAllEngines flushes all the data in the engines to the underlying storage. +func (*tidbBackend) FlushAllEngines(context.Context) error { + return nil +} + +// ResetEngine resets the engine. +func (*tidbBackend) ResetEngine(context.Context, uuid.UUID) error { + return errors.New("cannot reset an engine in TiDB backend") +} + +// LocalWriter returns a writer that writes data to local storage. +func (be *tidbBackend) LocalWriter( + _ context.Context, + cfg *backend.LocalWriterConfig, + _ uuid.UUID, +) (backend.EngineWriter, error) { + return &Writer{be: be, tableName: cfg.TiDB.TableName}, nil +} + +// Writer is a writer that writes data to local storage. +type Writer struct { + be *tidbBackend + tableName string +} + +// Close implements the EngineWriter interface. +func (*Writer) Close(_ context.Context) (backend.ChunkFlushStatus, error) { + return nil, nil +} + +// AppendRows implements the EngineWriter interface. +func (w *Writer) AppendRows(ctx context.Context, columnNames []string, rows encode.Rows) error { + return w.be.WriteRows(ctx, w.tableName, columnNames, rows) +} + +// IsSynced implements the EngineWriter interface. +func (*Writer) IsSynced() bool { + return true +} + +// TableAutoIDInfo is the auto id information of a table. +type TableAutoIDInfo struct { + Column string + NextID uint64 + Type string +} + +// FetchTableAutoIDInfos fetches the auto id information of a table. 
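In execStmts above, INSERT IGNORE makes RowsAffected report fewer rows than were sent, and the absolute difference is what gets recorded as the duplicate count. A tiny sketch of that arithmetic (ignoredDuplicates is an illustrative name):

package main

import "fmt"

// ignoredDuplicates mirrors the bookkeeping in execStmts: the gap between the rows
// sent in one statement and the rows the server reports as affected is taken as the
// number of conflicting rows that were silently ignored.
func ignoredDuplicates(sentRows, rowsAffected int64) int64 {
	diff := sentRows - rowsAffected
	if diff < 0 {
		diff = -diff
	}
	return diff
}

func main() {
	fmt.Println(ignoredDuplicates(100, 97)) // 3 rows hit a duplicate key and were ignored
}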
+func FetchTableAutoIDInfos(ctx context.Context, exec dbutil.QueryExecutor, tableName string) ([]*TableAutoIDInfo, error) { + rows, e := exec.QueryContext(ctx, fmt.Sprintf("SHOW TABLE %s NEXT_ROW_ID", tableName)) + if e != nil { + return nil, errors.Trace(e) + } + var autoIDInfos []*TableAutoIDInfo + for rows.Next() { + var ( + dbName, tblName, columnName, idType string + nextID uint64 + ) + columns, err := rows.Columns() + if err != nil { + return nil, errors.Trace(err) + } + + //+--------------+------------+-------------+--------------------+----------------+ + //| DB_NAME | TABLE_NAME | COLUMN_NAME | NEXT_GLOBAL_ROW_ID | ID_TYPE | + //+--------------+------------+-------------+--------------------+----------------+ + //| testsysbench | t | _tidb_rowid | 1 | AUTO_INCREMENT | + //+--------------+------------+-------------+--------------------+----------------+ + + // if columns length is 4, it doesn't contain the last column `ID_TYPE`, and it will always be 'AUTO_INCREMENT' + // for v4.0.0~v4.0.2 show table t next_row_id only returns 4 columns. + if len(columns) == 4 { + err = rows.Scan(&dbName, &tblName, &columnName, &nextID) + idType = "AUTO_INCREMENT" + } else { + err = rows.Scan(&dbName, &tblName, &columnName, &nextID, &idType) + } + if err != nil { + return nil, errors.Trace(err) + } + autoIDInfos = append(autoIDInfos, &TableAutoIDInfo{ + Column: columnName, + NextID: nextID, + Type: idType, + }) + } + // Defer in for-loop would be costly, anyway, we don't need those rows after this turn of iteration. + //nolint:sqlclosecheck + if err := rows.Close(); err != nil { + return nil, errors.Trace(err) + } + if err := rows.Err(); err != nil { + return nil, errors.Trace(err) + } + return autoIDInfos, nil +} diff --git a/pkg/lightning/common/binding__failpoint_binding__.go b/pkg/lightning/common/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..a9dff357d9a9e --- /dev/null +++ b/pkg/lightning/common/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package common + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/lightning/common/storage_unix.go b/pkg/lightning/common/storage_unix.go index a19dfb7c0ee98..cbcbc8f0f3a1a 100644 --- a/pkg/lightning/common/storage_unix.go +++ b/pkg/lightning/common/storage_unix.go @@ -29,10 +29,10 @@ import ( // GetStorageSize gets storage's capacity and available size func GetStorageSize(dir string) (size StorageSize, err error) { - failpoint.Inject("GetStorageSize", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("GetStorageSize")); _err_ == nil { injectedSize := val.(int) - failpoint.Return(StorageSize{Capacity: uint64(injectedSize), Available: uint64(injectedSize)}, nil) - }) + return StorageSize{Capacity: uint64(injectedSize), Available: uint64(injectedSize)}, nil + } var stat unix.Statfs_t err = unix.Statfs(dir, &stat) diff --git a/pkg/lightning/common/storage_unix.go__failpoint_stash__ b/pkg/lightning/common/storage_unix.go__failpoint_stash__ new file mode 100644 index 0000000000000..a19dfb7c0ee98 --- /dev/null +++ b/pkg/lightning/common/storage_unix.go__failpoint_stash__ @@ -0,0 +1,79 @@ +// Copyright 2019 PingCAP, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !windows + +// TODO: Deduplicate this implementation with DM! + +package common + +import ( + "reflect" + "syscall" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "golang.org/x/sys/unix" +) + +// GetStorageSize gets storage's capacity and available size +func GetStorageSize(dir string) (size StorageSize, err error) { + failpoint.Inject("GetStorageSize", func(val failpoint.Value) { + injectedSize := val.(int) + failpoint.Return(StorageSize{Capacity: uint64(injectedSize), Available: uint64(injectedSize)}, nil) + }) + + var stat unix.Statfs_t + err = unix.Statfs(dir, &stat) + if err != nil { + return size, errors.Annotatef(err, "cannot get disk capacity at %s", dir) + } + + // When container is run in MacOS, `bsize` obtained by `statfs` syscall is not the fundamental block size, + // but the `iosize` (optimal transfer block size) instead, it's usually 1024 times larger than the `bsize`. + // for example `4096 * 1024`. To get the correct block size, we should use `frsize`. But `frsize` isn't + // guaranteed to be supported everywhere, so we need to check whether it's supported before use it. + // For more details, please refer to: https://github.com/docker/for-mac/issues/2136 + bSize := uint64(stat.Bsize) + field := reflect.ValueOf(&stat).Elem().FieldByName("Frsize") + if field.IsValid() { + if field.Kind() == reflect.Uint64 { + bSize = field.Uint() + } else { + bSize = uint64(field.Int()) + } + } + + // Available blocks * size per block = available space in bytes + size.Available = uint64(stat.Bavail) * bSize + size.Capacity = stat.Blocks * bSize + + return +} + +// SameDisk is used to check dir1 and dir2 in the same disk. 
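GetStorageSize above reaches for Frsize via reflection because that field is not present in unix.Statfs_t on every platform, falling back to Bsize when it is missing. The portable toy below shows the same field-by-name lookup against a stand-in struct (statT and blockSize are illustrative only):

package main

import (
	"fmt"
	"reflect"
)

// statT is a stand-in for the platform stat struct so the example runs everywhere.
type statT struct {
	Bsize  int64
	Frsize int64
}

// blockSize prefers Frsize when the struct has such a field, otherwise returns the fallback.
func blockSize(stat any, fallback uint64) uint64 {
	field := reflect.ValueOf(stat).Elem().FieldByName("Frsize")
	if !field.IsValid() {
		return fallback
	}
	if field.Kind() == reflect.Uint64 {
		return field.Uint()
	}
	return uint64(field.Int())
}

func main() {
	st := &statT{Bsize: 4096, Frsize: 4096 * 1024}
	fmt.Println(blockSize(st, uint64(st.Bsize))) // 4194304: Frsize wins when the field exists
}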
+func SameDisk(dir1 string, dir2 string) (bool, error) { + st1 := syscall.Stat_t{} + st2 := syscall.Stat_t{} + + if err := syscall.Stat(dir1, &st1); err != nil { + return false, err + } + + if err := syscall.Stat(dir2, &st2); err != nil { + return false, err + } + + return st1.Dev == st2.Dev, nil +} diff --git a/pkg/lightning/common/storage_windows.go b/pkg/lightning/common/storage_windows.go index 89b9483592f94..352636893c78f 100644 --- a/pkg/lightning/common/storage_windows.go +++ b/pkg/lightning/common/storage_windows.go @@ -33,10 +33,10 @@ var ( // GetStorageSize gets storage's capacity and available size func GetStorageSize(dir string) (size StorageSize, err error) { - failpoint.Inject("GetStorageSize", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("GetStorageSize")); _err_ == nil { injectedSize := val.(int) - failpoint.Return(StorageSize{Capacity: uint64(injectedSize), Available: uint64(injectedSize)}, nil) - }) + return StorageSize{Capacity: uint64(injectedSize), Available: uint64(injectedSize)}, nil + } r, _, e := getDiskFreeSpaceExW.Call( uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(dir))), uintptr(unsafe.Pointer(&size.Available)), diff --git a/pkg/lightning/common/storage_windows.go__failpoint_stash__ b/pkg/lightning/common/storage_windows.go__failpoint_stash__ new file mode 100644 index 0000000000000..89b9483592f94 --- /dev/null +++ b/pkg/lightning/common/storage_windows.go__failpoint_stash__ @@ -0,0 +1,56 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build windows + +// TODO: Deduplicate this implementation with DM! + +package common + +import ( + "syscall" + "unsafe" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" +) + +var ( + kernel32 = syscall.MustLoadDLL("kernel32.dll") + getDiskFreeSpaceExW = kernel32.MustFindProc("GetDiskFreeSpaceExW") +) + +// GetStorageSize gets storage's capacity and available size +func GetStorageSize(dir string) (size StorageSize, err error) { + failpoint.Inject("GetStorageSize", func(val failpoint.Value) { + injectedSize := val.(int) + failpoint.Return(StorageSize{Capacity: uint64(injectedSize), Available: uint64(injectedSize)}, nil) + }) + r, _, e := getDiskFreeSpaceExW.Call( + uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(dir))), + uintptr(unsafe.Pointer(&size.Available)), + uintptr(unsafe.Pointer(&size.Capacity)), + 0, + ) + if r == 0 { + err = errors.Annotatef(e, "cannot get disk capacity at %s", dir) + } + return +} + +// SameDisk is used to check dir1 and dir2 in the same disk. 
+func SameDisk(dir1 string, dir2 string) (bool, error) { + // FIXME + return false, nil +} diff --git a/pkg/lightning/common/util.go b/pkg/lightning/common/util.go index ac3ba7d1efaa5..1cca4d45d5f28 100644 --- a/pkg/lightning/common/util.go +++ b/pkg/lightning/common/util.go @@ -93,13 +93,13 @@ func (param *MySQLConnectParam) ToDriverConfig() *mysql.Config { } func tryConnectMySQL(cfg *mysql.Config) (*sql.DB, error) { - failpoint.Inject("MustMySQLPassword", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("MustMySQLPassword")); _err_ == nil { pwd := val.(string) if cfg.Passwd != pwd { - failpoint.Return(nil, &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"}) + return nil, &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"} } - failpoint.Return(nil, nil) - }) + return nil, nil + } c, err := mysql.NewConnector(cfg) if err != nil { return nil, errors.Trace(err) diff --git a/pkg/lightning/common/util.go__failpoint_stash__ b/pkg/lightning/common/util.go__failpoint_stash__ new file mode 100644 index 0000000000000..ac3ba7d1efaa5 --- /dev/null +++ b/pkg/lightning/common/util.go__failpoint_stash__ @@ -0,0 +1,704 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package common + +import ( + "bytes" + "context" + "crypto/tls" + "database/sql" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "os" + "strconv" + "strings" + "syscall" + "time" + + "github.com/go-sql-driver/mysql" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/lightning/log" + "github.com/pingcap/tidb/pkg/parser/model" + tmysql "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/dbutil" + "github.com/pingcap/tidb/pkg/util/format" + "go.uber.org/zap" +) + +const ( + retryTimeout = 3 * time.Second + + defaultMaxRetry = 3 +) + +// MySQLConnectParam records the parameters needed to connect to a MySQL database. +type MySQLConnectParam struct { + Host string + Port int + User string + Password string + SQLMode string + MaxAllowedPacket uint64 + TLSConfig *tls.Config + AllowFallbackToPlaintext bool + Net string + Vars map[string]string +} + +// ToDriverConfig converts the MySQLConnectParam to a mysql.Config. 
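The rewritten hunks in this patch replace failpoint.Inject callbacks with failpoint.Eval(_curpkg_(...)) guards, as seen for GetStorageSize and MustMySQLPassword above. Assuming the github.com/pingcap/failpoint runtime, a small example of how such an Eval-style failpoint is enabled and fired (the failpoint name here is arbitrary, not one defined by this patch):

package main

import (
	"fmt"

	"github.com/pingcap/failpoint"
)

func main() {
	// "return(42)" is failpoint term syntax: evaluating the failpoint yields 42.
	if err := failpoint.Enable("example.com/demo/MyFailpoint", "return(42)"); err != nil {
		panic(err)
	}
	// This mirrors the generated form used in the patch: Eval only succeeds while enabled.
	if val, err := failpoint.Eval("example.com/demo/MyFailpoint"); err == nil {
		fmt.Println("failpoint fired with value:", val) // failpoint fired with value: 42
	}
}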
+func (param *MySQLConnectParam) ToDriverConfig() *mysql.Config { + cfg := mysql.NewConfig() + cfg.Params = make(map[string]string) + + cfg.User = param.User + cfg.Passwd = param.Password + cfg.Net = "tcp" + if param.Net != "" { + cfg.Net = param.Net + } + cfg.Addr = net.JoinHostPort(param.Host, strconv.Itoa(param.Port)) + cfg.Params["charset"] = "utf8mb4" + cfg.Params["sql_mode"] = fmt.Sprintf("'%s'", param.SQLMode) + cfg.MaxAllowedPacket = int(param.MaxAllowedPacket) + + cfg.TLS = param.TLSConfig + cfg.AllowFallbackToPlaintext = param.AllowFallbackToPlaintext + + for k, v := range param.Vars { + cfg.Params[k] = fmt.Sprintf("'%s'", v) + } + return cfg +} + +func tryConnectMySQL(cfg *mysql.Config) (*sql.DB, error) { + failpoint.Inject("MustMySQLPassword", func(val failpoint.Value) { + pwd := val.(string) + if cfg.Passwd != pwd { + failpoint.Return(nil, &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"}) + } + failpoint.Return(nil, nil) + }) + c, err := mysql.NewConnector(cfg) + if err != nil { + return nil, errors.Trace(err) + } + db := sql.OpenDB(c) + if err = db.Ping(); err != nil { + _ = db.Close() + return nil, errors.Trace(err) + } + return db, nil +} + +// ConnectMySQL connects MySQL with the dsn. If access is denied and the password is a valid base64 encoding, +// we will try to connect MySQL with the base64 decoding of the password. +func ConnectMySQL(cfg *mysql.Config) (*sql.DB, error) { + // Try plain password first. + db, firstErr := tryConnectMySQL(cfg) + if firstErr == nil { + return db, nil + } + // If access is denied and password is encoded by base64, try the decoded string as well. + if mysqlErr, ok := errors.Cause(firstErr).(*mysql.MySQLError); ok && mysqlErr.Number == tmysql.ErrAccessDenied { + // If password is encoded by base64, try the decoded string as well. + password, decodeErr := base64.StdEncoding.DecodeString(cfg.Passwd) + if decodeErr == nil && string(password) != cfg.Passwd { + cfg.Passwd = string(password) + db2, err := tryConnectMySQL(cfg) + if err == nil { + return db2, nil + } + } + } + // If we can't connect successfully, return the first error. + return nil, errors.Trace(firstErr) +} + +// Connect creates a new connection to the database. +func (param *MySQLConnectParam) Connect() (*sql.DB, error) { + db, err := ConnectMySQL(param.ToDriverConfig()) + if err != nil { + return nil, errors.Trace(err) + } + return db, nil +} + +// IsDirExists checks if dir exists. +func IsDirExists(name string) bool { + f, err := os.Stat(name) + if err != nil { + return false + } + return f != nil && f.IsDir() +} + +// IsEmptyDir checks if dir is empty. +func IsEmptyDir(name string) bool { + entries, err := os.ReadDir(name) + if err != nil { + return false + } + return len(entries) == 0 +} + +// SQLWithRetry constructs a retryable transaction. 
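ConnectMySQL above retries with the base64-decoded password only when the first attempt fails with access denied and the configured password is valid base64 that decodes to something different. The decoding half of that fallback in isolation (decodedPassword is an illustrative name):

package main

import (
	"encoding/base64"
	"fmt"
)

// decodedPassword reports whether pwd is worth a second connection attempt after an
// access-denied error: it must decode cleanly and differ from the original string.
func decodedPassword(pwd string) (string, bool) {
	raw, err := base64.StdEncoding.DecodeString(pwd)
	if err != nil || string(raw) == pwd {
		return "", false
	}
	return string(raw), true
}

func main() {
	fmt.Println(decodedPassword("c2VjcmV0"))  // secret true
	fmt.Println(decodedPassword("plain-pwd")) //  false: not valid base64, no retry
}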
+type SQLWithRetry struct { + // either *sql.DB or *sql.Conn + DB dbutil.DBExecutor + Logger log.Logger + HideQueryLog bool +} + +func (SQLWithRetry) perform(_ context.Context, parentLogger log.Logger, purpose string, action func() error) error { + return Retry(purpose, parentLogger, action) +} + +// Retry is shared by SQLWithRetry.perform, implementation of GlueCheckpointsDB and TiDB's glue implementation +func Retry(purpose string, parentLogger log.Logger, action func() error) error { + var err error +outside: + for i := 0; i < defaultMaxRetry; i++ { + logger := parentLogger.With(zap.Int("retryCnt", i)) + + if i > 0 { + logger.Warn(purpose + " retry start") + time.Sleep(retryTimeout) + } + + err = action() + switch { + case err == nil: + return nil + // do not retry NotFound error + case errors.IsNotFound(err): + break outside + case IsRetryableError(err): + logger.Warn(purpose+" failed but going to try again", log.ShortError(err)) + continue + default: + logger.Warn(purpose+" failed with no retry", log.ShortError(err)) + break outside + } + } + + return errors.Annotatef(err, "%s failed", purpose) +} + +// QueryRow executes a query that is expected to return at most one row. +func (t SQLWithRetry) QueryRow(ctx context.Context, purpose string, query string, dest ...any) error { + logger := t.Logger + if !t.HideQueryLog { + logger = logger.With(zap.String("query", query)) + } + return t.perform(ctx, logger, purpose, func() error { + return t.DB.QueryRowContext(ctx, query).Scan(dest...) + }) +} + +// QueryStringRows executes a query that is expected to return multiple rows +// whose every column is string. +func (t SQLWithRetry) QueryStringRows(ctx context.Context, purpose string, query string) ([][]string, error) { + var res [][]string + logger := t.Logger + if !t.HideQueryLog { + logger = logger.With(zap.String("query", query)) + } + + err := t.perform(ctx, logger, purpose, func() error { + rows, err := t.DB.QueryContext(ctx, query) + if err != nil { + return err + } + defer rows.Close() + + colNames, err := rows.Columns() + if err != nil { + return err + } + for rows.Next() { + row := make([]string, len(colNames)) + refs := make([]any, 0, len(row)) + for i := range row { + refs = append(refs, &row[i]) + } + if err := rows.Scan(refs...); err != nil { + return err + } + res = append(res, row) + } + return rows.Err() + }) + + return res, err +} + +// Transact executes an action in a transaction, and retry if the +// action failed with a retryable error. +func (t SQLWithRetry) Transact(ctx context.Context, purpose string, action func(context.Context, *sql.Tx) error) error { + return t.perform(ctx, t.Logger, purpose, func() error { + txn, err := t.DB.BeginTx(ctx, nil) + if err != nil { + return errors.Annotate(err, "begin transaction failed") + } + + err = action(ctx, txn) + if err != nil { + rerr := txn.Rollback() + if rerr != nil { + t.Logger.Error(purpose+" rollback transaction failed", log.ShortError(rerr)) + } + // we should return the exec err, instead of the rollback rerr. + // no need to errors.Trace() it, as the error comes from user code anyway. + return err + } + + err = txn.Commit() + if err != nil { + return errors.Annotate(err, "commit transaction failed") + } + + return nil + }) +} + +// Exec executes a single SQL with optional retry. 
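Retry above loops up to defaultMaxRetry times, stopping early on success, on a not-found error, or on anything not classified as retryable. A trimmed-down, dependency-free version of that loop and how a caller would drive it:

package main

import (
	"errors"
	"fmt"
)

var errTemporary = errors.New("temporary failure")

// retry keeps calling action until it succeeds, returns a non-retryable error,
// or maxAttempts is exhausted - the same shape as the Retry helper above.
func retry(maxAttempts int, isRetryable func(error) bool, action func() error) error {
	var err error
	for i := 0; i < maxAttempts; i++ {
		if err = action(); err == nil {
			return nil
		}
		if !isRetryable(err) {
			return err
		}
	}
	return err
}

func main() {
	calls := 0
	err := retry(3, func(err error) bool { return errors.Is(err, errTemporary) }, func() error {
		calls++
		if calls < 3 {
			return errTemporary
		}
		return nil
	})
	fmt.Println(calls, err) // 3 <nil>: the first two temporary failures were retried
}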
+func (t SQLWithRetry) Exec(ctx context.Context, purpose string, query string, args ...any) error { + logger := t.Logger + if !t.HideQueryLog { + logger = logger.With(zap.String("query", query), zap.Reflect("args", args)) + } + return t.perform(ctx, logger, purpose, func() error { + _, err := t.DB.ExecContext(ctx, query, args...) + return errors.Trace(err) + }) +} + +// IsContextCanceledError returns whether the error is caused by context +// cancellation. This function should only be used when the code logic is +// affected by whether the error is canceling or not. +// +// This function returns `false` (not a context-canceled error) if `err == nil`. +func IsContextCanceledError(err error) bool { + return log.IsContextCanceledError(err) +} + +// UniqueTable returns an unique table name. +func UniqueTable(schema string, table string) string { + var builder strings.Builder + WriteMySQLIdentifier(&builder, schema) + builder.WriteByte('.') + WriteMySQLIdentifier(&builder, table) + return builder.String() +} + +func escapeIdentifiers(identifier []string) []any { + escaped := make([]any, len(identifier)) + for i, id := range identifier { + escaped[i] = EscapeIdentifier(id) + } + return escaped +} + +// SprintfWithIdentifiers escapes the identifiers and sprintf them. The input +// identifiers must not be escaped. +func SprintfWithIdentifiers(format string, identifiers ...string) string { + return fmt.Sprintf(format, escapeIdentifiers(identifiers)...) +} + +// FprintfWithIdentifiers escapes the identifiers and fprintf them. The input +// identifiers must not be escaped. +func FprintfWithIdentifiers(w io.Writer, format string, identifiers ...string) (int, error) { + return fmt.Fprintf(w, format, escapeIdentifiers(identifiers)...) +} + +// EscapeIdentifier quote and escape an sql identifier +func EscapeIdentifier(identifier string) string { + var builder strings.Builder + WriteMySQLIdentifier(&builder, identifier) + return builder.String() +} + +// WriteMySQLIdentifier writes a MySQL identifier into the string builder. +// Writes a MySQL identifier into the string builder. +// The identifier is always escaped into the form "`foo`". +func WriteMySQLIdentifier(builder *strings.Builder, identifier string) { + builder.Grow(len(identifier) + 2) + builder.WriteByte('`') + + // use a C-style loop instead of range loop to avoid UTF-8 decoding + for i := 0; i < len(identifier); i++ { + b := identifier[i] + if b == '`' { + builder.WriteString("``") + } else { + builder.WriteByte(b) + } + } + + builder.WriteByte('`') +} + +// InterpolateMySQLString interpolates a string into a MySQL string literal. +func InterpolateMySQLString(s string) string { + var builder strings.Builder + builder.Grow(len(s) + 2) + builder.WriteByte('\'') + for i := 0; i < len(s); i++ { + b := s[i] + if b == '\'' { + builder.WriteString("''") + } else { + builder.WriteByte(b) + } + } + builder.WriteByte('\'') + return builder.String() +} + +// TableExists return whether table with specified name exists in target db +func TableExists(ctx context.Context, db dbutil.QueryExecutor, schema, table string) (bool, error) { + query := "SELECT 1 from INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?" + var exist string + err := db.QueryRowContext(ctx, query, schema, table).Scan(&exist) + switch err { + case nil: + return true, nil + case sql.ErrNoRows: + return false, nil + default: + return false, errors.Annotatef(err, "check table exists failed") + } +} + +// SchemaExists return whether schema with specified name exists. 
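The escaping helpers above follow MySQL quoting rules: backquotes are doubled inside a backquoted identifier and single quotes are doubled inside a string literal. A miniature re-implementation that shows the expected output (escapeIdentifier and interpolateString here are simplified stand-ins for WriteMySQLIdentifier and InterpolateMySQLString):

package main

import (
	"fmt"
	"strings"
)

// escapeIdentifier wraps id in backquotes and doubles any embedded backquote.
func escapeIdentifier(id string) string {
	return "`" + strings.ReplaceAll(id, "`", "``") + "`"
}

// interpolateString wraps s in single quotes and doubles any embedded single quote.
func interpolateString(s string) string {
	return "'" + strings.ReplaceAll(s, "'", "''") + "'"
}

func main() {
	fmt.Println(escapeIdentifier("my`table"))   // `my``table`
	fmt.Println(interpolateString("it's fine")) // 'it''s fine'
}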
+func SchemaExists(ctx context.Context, db dbutil.QueryExecutor, schema string) (bool, error) { + query := "SELECT 1 from INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = ?" + var exist string + err := db.QueryRowContext(ctx, query, schema).Scan(&exist) + switch err { + case nil: + return true, nil + case sql.ErrNoRows: + return false, nil + default: + return false, errors.Annotatef(err, "check schema exists failed") + } +} + +// GetJSON fetches a page and parses it as JSON. The parsed result will be +// stored into the `v`. The variable `v` must be a pointer to a type that can be +// unmarshalled from JSON. +// +// Example: +// +// client := &http.Client{} +// var resp struct { IP string } +// if err := util.GetJSON(client, "http://api.ipify.org/?format=json", &resp); err != nil { +// return errors.Trace(err) +// } +// fmt.Println(resp.IP) +func GetJSON(ctx context.Context, client *http.Client, url string, v any) error { + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return errors.Trace(err) + } + + resp, err := client.Do(req) + if err != nil { + return errors.Trace(err) + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return errors.Trace(err) + } + return errors.Errorf("get %s http status code != 200, message %s", url, string(body)) + } + + return errors.Trace(json.NewDecoder(resp.Body).Decode(v)) +} + +// KillMySelf sends sigint to current process, used in integration test only +// +// Only works on Unix. Signaling on Windows is not supported. +func KillMySelf() error { + proc, err := os.FindProcess(os.Getpid()) + if err == nil { + err = proc.Signal(syscall.SIGINT) + } + return errors.Trace(err) +} + +// KvPair contains a key-value pair and other fields that can be used to ingest +// KV pairs into TiKV. +type KvPair struct { + // Key is the key of the KV pair + Key []byte + // Val is the value of the KV pair + Val []byte + // RowID identifies a KvPair in case two KvPairs are equal in Key and Val. It has + // two sources: + // + // When the KvPair is generated from ADD INDEX, the RowID is the encoded handle. + // + // Otherwise, the RowID is related to the row number in the source files, and + // encode the number with `codec.EncodeComparableVarint`. + RowID []byte +} + +// EncodeIntRowID encodes an int64 row id. +func EncodeIntRowID(rowID int64) []byte { + return codec.EncodeComparableVarint(nil, rowID) +} + +// TableHasAutoRowID return whether table has auto generated row id +func TableHasAutoRowID(info *model.TableInfo) bool { + return !info.PKIsHandle && !info.IsCommonHandle +} + +// TableHasAutoID return whether table has auto generated id. +func TableHasAutoID(info *model.TableInfo) bool { + return TableHasAutoRowID(info) || info.GetAutoIncrementColInfo() != nil || info.ContainsAutoRandomBits() +} + +// GetAutoRandomColumn return the column with auto_random, return nil if the table doesn't have it. 
+// todo: better put in ddl package, but this will cause import cycle since ddl package import lightning +func GetAutoRandomColumn(tblInfo *model.TableInfo) *model.ColumnInfo { + if !tblInfo.ContainsAutoRandomBits() { + return nil + } + if tblInfo.PKIsHandle { + return tblInfo.GetPkColInfo() + } else if tblInfo.IsCommonHandle { + pk := tables.FindPrimaryIndex(tblInfo) + if pk == nil { + return nil + } + offset := pk.Columns[0].Offset + return tblInfo.Columns[offset] + } + return nil +} + +// GetDropIndexInfos returns the index infos that need to be dropped and the remain indexes. +func GetDropIndexInfos( + tblInfo *model.TableInfo, +) (remainIndexes []*model.IndexInfo, dropIndexes []*model.IndexInfo) { + cols := tblInfo.Columns +loop: + for _, idxInfo := range tblInfo.Indices { + if idxInfo.State != model.StatePublic { + remainIndexes = append(remainIndexes, idxInfo) + continue + } + // Primary key is a cluster index. + if idxInfo.Primary && tblInfo.HasClusteredIndex() { + remainIndexes = append(remainIndexes, idxInfo) + continue + } + // Skip index that contains auto-increment column. + // Because auto column must be defined as a key. + for _, idxCol := range idxInfo.Columns { + flag := cols[idxCol.Offset].GetFlag() + if tmysql.HasAutoIncrementFlag(flag) { + remainIndexes = append(remainIndexes, idxInfo) + continue loop + } + } + dropIndexes = append(dropIndexes, idxInfo) + } + return remainIndexes, dropIndexes +} + +// BuildDropIndexSQL builds the SQL statement to drop index. +func BuildDropIndexSQL(dbName, tableName string, idxInfo *model.IndexInfo) string { + if idxInfo.Primary { + return SprintfWithIdentifiers("ALTER TABLE %s.%s DROP PRIMARY KEY", dbName, tableName) + } + return SprintfWithIdentifiers("ALTER TABLE %s.%s DROP INDEX %s", dbName, tableName, idxInfo.Name.O) +} + +// BuildAddIndexSQL builds the SQL statement to create missing indexes. +// It returns both a single SQL statement that creates all indexes at once, +// and a list of SQL statements that creates each index individually. +func BuildAddIndexSQL( + tableName string, + curTblInfo, + desiredTblInfo *model.TableInfo, +) (singleSQL string, multiSQLs []string) { + addIndexSpecs := make([]string, 0, len(desiredTblInfo.Indices)) +loop: + for _, desiredIdxInfo := range desiredTblInfo.Indices { + for _, curIdxInfo := range curTblInfo.Indices { + if curIdxInfo.Name.L == desiredIdxInfo.Name.L { + continue loop + } + } + + var buf bytes.Buffer + if desiredIdxInfo.Primary { + buf.WriteString("ADD PRIMARY KEY ") + } else if desiredIdxInfo.Unique { + buf.WriteString("ADD UNIQUE KEY ") + } else { + buf.WriteString("ADD KEY ") + } + // "primary" is a special name for primary key, we should not use it as index name. 
+ if desiredIdxInfo.Name.L != "primary" { + buf.WriteString(EscapeIdentifier(desiredIdxInfo.Name.O)) + } + + colStrs := make([]string, 0, len(desiredIdxInfo.Columns)) + for _, col := range desiredIdxInfo.Columns { + var colStr string + if desiredTblInfo.Columns[col.Offset].Hidden { + colStr = fmt.Sprintf("(%s)", desiredTblInfo.Columns[col.Offset].GeneratedExprString) + } else { + colStr = EscapeIdentifier(col.Name.O) + if col.Length != types.UnspecifiedLength { + colStr = fmt.Sprintf("%s(%s)", colStr, strconv.Itoa(col.Length)) + } + } + colStrs = append(colStrs, colStr) + } + fmt.Fprintf(&buf, "(%s)", strings.Join(colStrs, ",")) + + if desiredIdxInfo.Invisible { + fmt.Fprint(&buf, " INVISIBLE") + } + if desiredIdxInfo.Comment != "" { + fmt.Fprintf(&buf, ` COMMENT '%s'`, format.OutputFormat(desiredIdxInfo.Comment)) + } + addIndexSpecs = append(addIndexSpecs, buf.String()) + } + if len(addIndexSpecs) == 0 { + return "", nil + } + + singleSQL = fmt.Sprintf("ALTER TABLE %s %s", tableName, strings.Join(addIndexSpecs, ", ")) + for _, spec := range addIndexSpecs { + multiSQLs = append(multiSQLs, fmt.Sprintf("ALTER TABLE %s %s", tableName, spec)) + } + return singleSQL, multiSQLs +} + +// IsDupKeyError checks if err is a duplicate index error. +func IsDupKeyError(err error) bool { + if merr, ok := errors.Cause(err).(*mysql.MySQLError); ok { + switch merr.Number { + case errno.ErrDupKeyName, errno.ErrMultiplePriKey, errno.ErrDupUnique: + return true + } + } + return false +} + +// GetBackoffWeightFromDB gets the backoff weight from database. +func GetBackoffWeightFromDB(ctx context.Context, db *sql.DB) (int, error) { + val, err := getSessionVariable(ctx, db, variable.TiDBBackOffWeight) + if err != nil { + return 0, err + } + return strconv.Atoi(val) +} + +// GetExplicitRequestSourceTypeFromDB gets the explicit request source type from database. +func GetExplicitRequestSourceTypeFromDB(ctx context.Context, db *sql.DB) (string, error) { + return getSessionVariable(ctx, db, variable.TiDBExplicitRequestSourceType) +} + +// copy from dbutil to avoid import cycle +func getSessionVariable(ctx context.Context, db *sql.DB, variable string) (value string, err error) { + query := fmt.Sprintf("SHOW VARIABLES LIKE '%s'", variable) + rows, err := db.QueryContext(ctx, query) + + if err != nil { + return "", errors.Trace(err) + } + defer rows.Close() + + // Show an example. + /* + mysql> SHOW VARIABLES LIKE "binlog_format"; + +---------------+-------+ + | Variable_name | Value | + +---------------+-------+ + | binlog_format | ROW | + +---------------+-------+ + */ + + for rows.Next() { + if err = rows.Scan(&variable, &value); err != nil { + return "", errors.Trace(err) + } + } + + if err := rows.Err(); err != nil { + return "", errors.Trace(err) + } + + return value, nil +} + +// IsFunctionNotExistErr checks if err is a function not exist error. 
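// Example (illustrative) of the statements BuildAddIndexSQL above can return,
// assuming the desired schema has a primary key on `id` and a key `idx_a` on
// column `a` that the current schema lacks:
//
//	single, multi := common.BuildAddIndexSQL("`test`.`t`", curTblInfo, desiredTblInfo)
//	// single: "ALTER TABLE `test`.`t` ADD PRIMARY KEY (`id`), ADD KEY `idx_a`(`a`)"
//	// multi:  {"ALTER TABLE `test`.`t` ADD PRIMARY KEY (`id`)",
//	//          "ALTER TABLE `test`.`t` ADD KEY `idx_a`(`a`)"}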
+func IsFunctionNotExistErr(err error, functionName string) bool { + return err != nil && + (strings.Contains(err.Error(), "No database selected") || + strings.Contains(err.Error(), fmt.Sprintf("%s does not exist", functionName))) +} + +// IsRaftKV2 checks whether the raft-kv2 is enabled +func IsRaftKV2(ctx context.Context, db *sql.DB) (bool, error) { + var ( + getRaftKvVersionSQL = "show config where type = 'tikv' and name = 'storage.engine'" + raftKv2 = "raft-kv2" + tp, instance, name, value string + ) + + rows, err := db.QueryContext(ctx, getRaftKvVersionSQL) + if err != nil { + return false, errors.Trace(err) + } + defer rows.Close() + + for rows.Next() { + if err = rows.Scan(&tp, &instance, &name, &value); err != nil { + return false, errors.Trace(err) + } + if value == raftKv2 { + return true, nil + } + } + return false, rows.Err() +} + +// IsAccessDeniedNeedConfigPrivilegeError checks if err is generated from a query to TiDB which failed due to missing CONFIG privilege. +func IsAccessDeniedNeedConfigPrivilegeError(err error) bool { + e, ok := err.(*mysql.MySQLError) + return ok && e.Number == errno.ErrSpecificAccessDenied && strings.Contains(e.Message, "CONFIG") +} diff --git a/pkg/lightning/mydump/binding__failpoint_binding__.go b/pkg/lightning/mydump/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..ce8a21b37f034 --- /dev/null +++ b/pkg/lightning/mydump/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package mydump + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/lightning/mydump/loader.go b/pkg/lightning/mydump/loader.go index d2fc407ddcdeb..b0d22084400f6 100644 --- a/pkg/lightning/mydump/loader.go +++ b/pkg/lightning/mydump/loader.go @@ -789,14 +789,14 @@ func calculateFileBytes(ctx context.Context, // SampleFileCompressRatio samples the compress ratio of the compressed file. func SampleFileCompressRatio(ctx context.Context, fileMeta SourceFileMeta, store storage.ExternalStorage) (float64, error) { - failpoint.Inject("SampleFileCompressPercentage", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("SampleFileCompressPercentage")); _err_ == nil { switch v := val.(type) { case string: - failpoint.Return(1.0, errors.New(v)) + return 1.0, errors.New(v) case int: - failpoint.Return(float64(v)/100, nil) + return float64(v) / 100, nil } - }) + } if fileMeta.Compression == CompressionNone { return 1, nil } diff --git a/pkg/lightning/mydump/loader.go__failpoint_stash__ b/pkg/lightning/mydump/loader.go__failpoint_stash__ new file mode 100644 index 0000000000000..d2fc407ddcdeb --- /dev/null +++ b/pkg/lightning/mydump/loader.go__failpoint_stash__ @@ -0,0 +1,868 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package mydump + +import ( + "context" + "io" + "path/filepath" + "sort" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/br/pkg/storage" + "github.com/pingcap/tidb/pkg/lightning/common" + "github.com/pingcap/tidb/pkg/lightning/config" + "github.com/pingcap/tidb/pkg/lightning/log" + regexprrouter "github.com/pingcap/tidb/pkg/util/regexpr-router" + filter "github.com/pingcap/tidb/pkg/util/table-filter" + "go.uber.org/zap" +) + +// sampleCompressedFileSize represents how many bytes need to be sampled for compressed files +const ( + sampleCompressedFileSize = 4 * 1024 + maxSampleParquetDataSize = 8 * 1024 + maxSampleParquetRowCount = 500 +) + +// MDDatabaseMeta contains some parsed metadata for a database in the source by MyDumper Loader. +type MDDatabaseMeta struct { + Name string + SchemaFile FileInfo + Tables []*MDTableMeta + Views []*MDTableMeta + charSet string +} + +// NewMDDatabaseMeta creates an Mydumper database meta with specified character set. +func NewMDDatabaseMeta(charSet string) *MDDatabaseMeta { + return &MDDatabaseMeta{ + charSet: charSet, + } +} + +// GetSchema gets the schema SQL for a source database. +func (m *MDDatabaseMeta) GetSchema(ctx context.Context, store storage.ExternalStorage) string { + if m.SchemaFile.FileMeta.Path != "" { + schema, err := ExportStatement(ctx, store, m.SchemaFile, m.charSet) + if err != nil { + log.FromContext(ctx).Warn("failed to extract table schema", + zap.String("Path", m.SchemaFile.FileMeta.Path), + log.ShortError(err), + ) + } else if schemaStr := strings.TrimSpace(string(schema)); schemaStr != "" { + return schemaStr + } + } + // set default if schema sql is empty or failed to extract. + return common.SprintfWithIdentifiers("CREATE DATABASE IF NOT EXISTS %s", m.Name) +} + +// MDTableMeta contains some parsed metadata for a table in the source by MyDumper Loader. +type MDTableMeta struct { + DB string + Name string + SchemaFile FileInfo + DataFiles []FileInfo + charSet string + TotalSize int64 + IndexRatio float64 + // default to true, and if we do precheck, this var is updated using data sampling result, so it's not accurate. + IsRowOrdered bool +} + +// SourceFileMeta contains some analyzed metadata for a source file by MyDumper Loader. +type SourceFileMeta struct { + Path string + Type SourceType + Compression Compression + SortKey string + // FileSize is the size of the file in the storage. + FileSize int64 + // WARNING: variables below are not persistent + ExtendData ExtendColumnData + // RealSize is same as FileSize if the file is not compressed and not parquet. + // If the file is compressed, RealSize is the estimated uncompressed size. + // If the file is parquet, RealSize is the estimated data size after convert + // to row oriented storage. + RealSize int64 + Rows int64 // only for parquet +} + +// NewMDTableMeta creates an Mydumper table meta with specified character set. +func NewMDTableMeta(charSet string) *MDTableMeta { + return &MDTableMeta{ + charSet: charSet, + } +} + +// GetSchema gets the table-creating SQL for a source table. 
+func (m *MDTableMeta) GetSchema(ctx context.Context, store storage.ExternalStorage) (string, error) { + schemaFilePath := m.SchemaFile.FileMeta.Path + if len(schemaFilePath) <= 0 { + return "", errors.Errorf("schema file is missing for the table '%s.%s'", m.DB, m.Name) + } + fileExists, err := store.FileExists(ctx, schemaFilePath) + if err != nil { + return "", errors.Annotate(err, "check table schema file exists error") + } + if !fileExists { + return "", errors.Errorf("the provided schema file (%s) for the table '%s.%s' doesn't exist", + schemaFilePath, m.DB, m.Name) + } + schema, err := ExportStatement(ctx, store, m.SchemaFile, m.charSet) + if err != nil { + log.FromContext(ctx).Error("failed to extract table schema", + zap.String("Path", m.SchemaFile.FileMeta.Path), + log.ShortError(err), + ) + return "", errors.Trace(err) + } + return string(schema), nil +} + +// MDLoaderSetupConfig stores the configs when setting up a MDLoader. +// This can control the behavior when constructing an MDLoader. +type MDLoaderSetupConfig struct { + // MaxScanFiles specifies the maximum number of files to scan. + // If the value is <= 0, it means the number of data source files will be scanned as many as possible. + MaxScanFiles int + // ReturnPartialResultOnError specifies whether the currently scanned files are analyzed, + // and return the partial result. + ReturnPartialResultOnError bool + // FileIter controls the file iteration policy when constructing a MDLoader. + FileIter FileIterator +} + +// DefaultMDLoaderSetupConfig generates a default MDLoaderSetupConfig. +func DefaultMDLoaderSetupConfig() *MDLoaderSetupConfig { + return &MDLoaderSetupConfig{ + MaxScanFiles: 0, // By default, the loader will scan all the files. + ReturnPartialResultOnError: false, + FileIter: nil, + } +} + +// MDLoaderSetupOption is the option type for setting up a MDLoaderSetupConfig. +type MDLoaderSetupOption func(cfg *MDLoaderSetupConfig) + +// WithMaxScanFiles generates an option that limits the max scan files when setting up a MDLoader. +func WithMaxScanFiles(maxScanFiles int) MDLoaderSetupOption { + return func(cfg *MDLoaderSetupConfig) { + if maxScanFiles > 0 { + cfg.MaxScanFiles = maxScanFiles + cfg.ReturnPartialResultOnError = true + } + } +} + +// ReturnPartialResultOnError generates an option that controls +// whether return the partial scanned result on error when setting up a MDLoader. +func ReturnPartialResultOnError(supportPartialResult bool) MDLoaderSetupOption { + return func(cfg *MDLoaderSetupConfig) { + cfg.ReturnPartialResultOnError = supportPartialResult + } +} + +// WithFileIterator generates an option that specifies the file iteration policy. +func WithFileIterator(fileIter FileIterator) MDLoaderSetupOption { + return func(cfg *MDLoaderSetupConfig) { + cfg.FileIter = fileIter + } +} + +// LoaderConfig is the configuration for constructing a MDLoader. +type LoaderConfig struct { + // SourceID is the unique identifier for the data source, it's used in DM only. + // must be used together with Routes. + SourceID string + // SourceURL is the URL of the data source. + SourceURL string + // Routes is the routing rules for the tables, exclusive with FileRouters. + // it's deprecated in lightning, but still used in DM. + // when used this, DefaultFileRules must be true. + Routes config.Routes + // CharacterSet is the character set of the schema sql files. + CharacterSet string + // Filter is the filter for the tables, files related to filtered-out tables are not loaded. 
+ // must be specified, else all tables are filtered out, see config.GetDefaultFilter. + Filter []string + FileRouters []*config.FileRouteRule + // CaseSensitive indicates whether Routes and Filter are case-sensitive. + CaseSensitive bool + // DefaultFileRules indicates whether to use the default file routing rules. + // If it's true, the default file routing rules will be appended to the FileRouters. + // a little confusing, but it's true only when FileRouters is empty. + DefaultFileRules bool +} + +// NewLoaderCfg creates loader config from lightning config. +func NewLoaderCfg(cfg *config.Config) LoaderConfig { + return LoaderConfig{ + SourceID: cfg.Mydumper.SourceID, + SourceURL: cfg.Mydumper.SourceDir, + Routes: cfg.Routes, + CharacterSet: cfg.Mydumper.CharacterSet, + Filter: cfg.Mydumper.Filter, + FileRouters: cfg.Mydumper.FileRouters, + CaseSensitive: cfg.Mydumper.CaseSensitive, + DefaultFileRules: cfg.Mydumper.DefaultFileRules, + } +} + +// MDLoader is for 'Mydumper File Loader', which loads the files in the data source and generates a set of metadata. +type MDLoader struct { + store storage.ExternalStorage + dbs []*MDDatabaseMeta + filter filter.Filter + router *regexprrouter.RouteTable + fileRouter FileRouter + charSet string +} + +type mdLoaderSetup struct { + sourceID string + loader *MDLoader + dbSchemas []FileInfo + tableSchemas []FileInfo + viewSchemas []FileInfo + tableDatas []FileInfo + dbIndexMap map[string]int + tableIndexMap map[filter.Table]int + setupCfg *MDLoaderSetupConfig +} + +// NewLoader constructs a MyDumper loader that scanns the data source and constructs a set of metadatas. +func NewLoader(ctx context.Context, cfg LoaderConfig, opts ...MDLoaderSetupOption) (*MDLoader, error) { + u, err := storage.ParseBackend(cfg.SourceURL, nil) + if err != nil { + return nil, common.NormalizeError(err) + } + s, err := storage.New(ctx, u, &storage.ExternalStorageOptions{}) + if err != nil { + return nil, common.NormalizeError(err) + } + + return NewLoaderWithStore(ctx, cfg, s, opts...) +} + +// NewLoaderWithStore constructs a MyDumper loader with the provided external storage that scanns the data source and constructs a set of metadatas. +func NewLoaderWithStore(ctx context.Context, cfg LoaderConfig, + store storage.ExternalStorage, opts ...MDLoaderSetupOption) (*MDLoader, error) { + var r *regexprrouter.RouteTable + var err error + + mdLoaderSetupCfg := DefaultMDLoaderSetupConfig() + for _, o := range opts { + o(mdLoaderSetupCfg) + } + if mdLoaderSetupCfg.FileIter == nil { + mdLoaderSetupCfg.FileIter = &allFileIterator{ + store: store, + maxScanFiles: mdLoaderSetupCfg.MaxScanFiles, + } + } + + if len(cfg.Routes) > 0 && len(cfg.FileRouters) > 0 { + return nil, common.ErrInvalidConfig.GenWithStack("table route is deprecated, can't config both [routes] and [mydumper.files]") + } + + if len(cfg.Routes) > 0 { + r, err = regexprrouter.NewRegExprRouter(cfg.CaseSensitive, cfg.Routes) + if err != nil { + return nil, common.ErrInvalidConfig.Wrap(err).GenWithStack("invalid table route rule") + } + } + + f, err := filter.Parse(cfg.Filter) + if err != nil { + return nil, common.ErrInvalidConfig.Wrap(err).GenWithStack("parse filter failed") + } + if !cfg.CaseSensitive { + f = filter.CaseInsensitive(f) + } + + fileRouteRules := cfg.FileRouters + if cfg.DefaultFileRules { + fileRouteRules = append(fileRouteRules, defaultFileRouteRules...) 
+ } + + fileRouter, err := NewFileRouter(fileRouteRules, log.FromContext(ctx)) + if err != nil { + return nil, common.ErrInvalidConfig.Wrap(err).GenWithStack("parse file routing rule failed") + } + + mdl := &MDLoader{ + store: store, + filter: f, + router: r, + charSet: cfg.CharacterSet, + fileRouter: fileRouter, + } + + setup := mdLoaderSetup{ + sourceID: cfg.SourceID, + loader: mdl, + dbIndexMap: make(map[string]int), + tableIndexMap: make(map[filter.Table]int), + setupCfg: mdLoaderSetupCfg, + } + + if err := setup.setup(ctx); err != nil { + if mdLoaderSetupCfg.ReturnPartialResultOnError { + return mdl, errors.Trace(err) + } + return nil, errors.Trace(err) + } + + return mdl, nil +} + +type fileType int + +const ( + fileTypeDatabaseSchema fileType = iota + fileTypeTableSchema + fileTypeTableData +) + +// String implements the Stringer interface. +func (ftype fileType) String() string { + switch ftype { + case fileTypeDatabaseSchema: + return "database schema" + case fileTypeTableSchema: + return "table schema" + case fileTypeTableData: + return "table data" + default: + return "(unknown)" + } +} + +// FileInfo contains the information for a data file in a table. +type FileInfo struct { + TableName filter.Table + FileMeta SourceFileMeta +} + +// ExtendColumnData contains the extended column names and values information for a table. +type ExtendColumnData struct { + Columns []string + Values []string +} + +// setup the `s.loader.dbs` slice by scanning all *.sql files inside `dir`. +// +// The database and tables are inserted in a consistent order, so creating an +// MDLoader twice with the same data source is going to produce the same array, +// even after killing Lightning. +// +// This is achieved by using `filepath.Walk` internally which guarantees the +// files are visited in lexicographical order (note that this does not mean the +// databases and tables in the end are ordered lexicographically since they may +// be stored in different subdirectories). +// +// Will sort tables by table size, this means that the big table is imported +// at the latest, which to avoid large table take a long time to import and block +// small table to release index worker. 
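// Example (illustrative) of driving the scan described above from a caller;
// `lightningCfg` is an assumed, already-populated *config.Config:
//
//	cfg := mydump.NewLoaderCfg(lightningCfg)
//	loader, err := mydump.NewLoader(ctx, cfg, mydump.WithMaxScanFiles(100000))
//	if err != nil {
//		return err
//	}
//	for _, dbMeta := range loader.GetDatabases() {
//		for _, tblMeta := range dbMeta.Tables {
//			log.FromContext(ctx).Info("table found",
//				zap.String("db", tblMeta.DB), zap.String("table", tblMeta.Name),
//				zap.Int("files", len(tblMeta.DataFiles)), zap.Int64("size", tblMeta.TotalSize))
//		}
//	}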
+func (s *mdLoaderSetup) setup(ctx context.Context) error { + /* + Mydumper file names format + db —— {db}-schema-create.sql + table —— {db}.{table}-schema.sql + sql —— {db}.{table}.{part}.sql / {db}.{table}.sql + */ + var gerr error + fileIter := s.setupCfg.FileIter + if fileIter == nil { + return errors.New("file iterator is not defined") + } + if err := fileIter.IterateFiles(ctx, s.constructFileInfo); err != nil { + if !s.setupCfg.ReturnPartialResultOnError { + return common.ErrStorageUnknown.Wrap(err).GenWithStack("list file failed") + } + gerr = err + } + if err := s.route(); err != nil { + return common.ErrTableRoute.Wrap(err).GenWithStackByArgs() + } + + // setup database schema + if len(s.dbSchemas) != 0 { + for _, fileInfo := range s.dbSchemas { + if _, dbExists := s.insertDB(fileInfo); dbExists && s.loader.router == nil { + return common.ErrInvalidSchemaFile.GenWithStack("invalid database schema file, duplicated item - %s", fileInfo.FileMeta.Path) + } + } + } + + if len(s.tableSchemas) != 0 { + // setup table schema + for _, fileInfo := range s.tableSchemas { + if _, _, tableExists := s.insertTable(fileInfo); tableExists && s.loader.router == nil { + return common.ErrInvalidSchemaFile.GenWithStack("invalid table schema file, duplicated item - %s", fileInfo.FileMeta.Path) + } + } + } + + if len(s.viewSchemas) != 0 { + // setup view schema + for _, fileInfo := range s.viewSchemas { + _, tableExists := s.insertView(fileInfo) + if !tableExists { + // we are not expect the user only has view schema without table schema when user use dumpling to get view. + // remove the last `-view.sql` from path as the relate table schema file path + return common.ErrInvalidSchemaFile.GenWithStack("invalid view schema file, miss host table schema for view '%s'", fileInfo.TableName.Name) + } + } + } + + // Sql file for restore data + for _, fileInfo := range s.tableDatas { + // set a dummy `FileInfo` here without file meta because we needn't restore the table schema + tableMeta, _, _ := s.insertTable(FileInfo{TableName: fileInfo.TableName}) + tableMeta.DataFiles = append(tableMeta.DataFiles, fileInfo) + tableMeta.TotalSize += fileInfo.FileMeta.RealSize + } + + for _, dbMeta := range s.loader.dbs { + // Put the small table in the front of the slice which can avoid large table + // take a long time to import and block small table to release index worker. + meta := dbMeta + sort.SliceStable(meta.Tables, func(i, j int) bool { + return meta.Tables[i].TotalSize < meta.Tables[j].TotalSize + }) + + // sort each table source files by sort-key + for _, tbMeta := range meta.Tables { + dataFiles := tbMeta.DataFiles + sort.SliceStable(dataFiles, func(i, j int) bool { + return dataFiles[i].FileMeta.SortKey < dataFiles[j].FileMeta.SortKey + }) + } + } + + return gerr +} + +// FileHandler is the interface to handle the file give the path and size. +// It is mainly used in the `FileIterator` as parameters. +type FileHandler func(ctx context.Context, path string, size int64) error + +// FileIterator is the interface to iterate files in a data source. +// Use this interface to customize the file iteration policy. 
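// For example (illustrative sketch), a custom iterator that only visits files
// ending in ".csv"; the type name below is made up for the example:
//
//	type csvOnlyIterator struct {
//		store storage.ExternalStorage
//	}
//
//	func (it *csvOnlyIterator) IterateFiles(ctx context.Context, hdl FileHandler) error {
//		return it.store.WalkDir(ctx, &storage.WalkOption{}, func(path string, size int64) error {
//			if !strings.HasSuffix(path, ".csv") {
//				return nil
//			}
//			return hdl(ctx, path, size)
//		})
//	}
//
// It can be plugged in with WithFileIterator(&csvOnlyIterator{store: s}) when
// calling NewLoaderWithStore.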
+type FileIterator interface { + IterateFiles(ctx context.Context, hdl FileHandler) error +} + +type allFileIterator struct { + store storage.ExternalStorage + maxScanFiles int +} + +func (iter *allFileIterator) IterateFiles(ctx context.Context, hdl FileHandler) error { + // `filepath.Walk` yields the paths in a deterministic (lexicographical) order, + // meaning the file and chunk orders will be the same everytime it is called + // (as long as the source is immutable). + totalScannedFileCount := 0 + err := iter.store.WalkDir(ctx, &storage.WalkOption{}, func(path string, size int64) error { + totalScannedFileCount++ + if iter.maxScanFiles > 0 && totalScannedFileCount > iter.maxScanFiles { + return common.ErrTooManySourceFiles + } + return hdl(ctx, path, size) + }) + + return errors.Trace(err) +} + +func (s *mdLoaderSetup) constructFileInfo(ctx context.Context, path string, size int64) error { + logger := log.FromContext(ctx).With(zap.String("path", path)) + res, err := s.loader.fileRouter.Route(filepath.ToSlash(path)) + if err != nil { + return errors.Annotatef(err, "apply file routing on file '%s' failed", path) + } + if res == nil { + logger.Info("file is filtered by file router", zap.String("category", "loader")) + return nil + } + + info := FileInfo{ + TableName: filter.Table{Schema: res.Schema, Name: res.Name}, + FileMeta: SourceFileMeta{Path: path, Type: res.Type, Compression: res.Compression, SortKey: res.Key, FileSize: size, RealSize: size}, + } + + if s.loader.shouldSkip(&info.TableName) { + logger.Debug("ignoring table file", zap.String("category", "filter")) + + return nil + } + + switch res.Type { + case SourceTypeSchemaSchema: + s.dbSchemas = append(s.dbSchemas, info) + case SourceTypeTableSchema: + s.tableSchemas = append(s.tableSchemas, info) + case SourceTypeViewSchema: + s.viewSchemas = append(s.viewSchemas, info) + case SourceTypeSQL, SourceTypeCSV: + if info.FileMeta.Compression != CompressionNone { + compressRatio, err2 := SampleFileCompressRatio(ctx, info.FileMeta, s.loader.GetStore()) + if err2 != nil { + logger.Error("fail to calculate data file compress ratio", zap.String("category", "loader"), + zap.String("schema", res.Schema), zap.String("table", res.Name), zap.Stringer("type", res.Type)) + } else { + info.FileMeta.RealSize = int64(compressRatio * float64(info.FileMeta.FileSize)) + } + } + s.tableDatas = append(s.tableDatas, info) + case SourceTypeParquet: + parquestDataSize, err2 := SampleParquetDataSize(ctx, info.FileMeta, s.loader.GetStore()) + if err2 != nil { + logger.Error("fail to sample parquet data size", zap.String("category", "loader"), + zap.String("schema", res.Schema), zap.String("table", res.Name), zap.Stringer("type", res.Type), zap.Error(err2)) + } else { + info.FileMeta.RealSize = parquestDataSize + } + s.tableDatas = append(s.tableDatas, info) + } + + logger.Debug("file route result", zap.String("schema", res.Schema), + zap.String("table", res.Name), zap.Stringer("type", res.Type)) + + return nil +} + +func (l *MDLoader) shouldSkip(table *filter.Table) bool { + if len(table.Name) == 0 { + return !l.filter.MatchSchema(table.Schema) + } + return !l.filter.MatchTable(table.Schema, table.Name) +} + +func (s *mdLoaderSetup) route() error { + r := s.loader.router + if r == nil { + return nil + } + + type dbInfo struct { + fileMeta SourceFileMeta + count int // means file count(db/table/view schema and table data) + } + + knownDBNames := make(map[string]*dbInfo) + for _, info := range s.dbSchemas { + knownDBNames[info.TableName.Schema] = &dbInfo{ + 
fileMeta: info.FileMeta, + count: 1, + } + } + for _, info := range s.tableSchemas { + if _, ok := knownDBNames[info.TableName.Schema]; !ok { + knownDBNames[info.TableName.Schema] = &dbInfo{ + fileMeta: info.FileMeta, + count: 1, + } + } + knownDBNames[info.TableName.Schema].count++ + } + for _, info := range s.viewSchemas { + if _, ok := knownDBNames[info.TableName.Schema]; !ok { + knownDBNames[info.TableName.Schema] = &dbInfo{ + fileMeta: info.FileMeta, + count: 1, + } + } + knownDBNames[info.TableName.Schema].count++ + } + for _, info := range s.tableDatas { + if _, ok := knownDBNames[info.TableName.Schema]; !ok { + knownDBNames[info.TableName.Schema] = &dbInfo{ + fileMeta: info.FileMeta, + count: 1, + } + } + knownDBNames[info.TableName.Schema].count++ + } + + runRoute := func(arr []FileInfo) error { + for i, info := range arr { + rawDB, rawTable := info.TableName.Schema, info.TableName.Name + targetDB, targetTable, err := r.Route(rawDB, rawTable) + if err != nil { + return errors.Trace(err) + } + if targetDB != rawDB { + oldInfo := knownDBNames[rawDB] + oldInfo.count-- + newInfo, ok := knownDBNames[targetDB] + if !ok { + newInfo = &dbInfo{fileMeta: oldInfo.fileMeta, count: 1} + s.dbSchemas = append(s.dbSchemas, FileInfo{ + TableName: filter.Table{Schema: targetDB}, + FileMeta: oldInfo.fileMeta, + }) + } + newInfo.count++ + knownDBNames[targetDB] = newInfo + } + arr[i].TableName = filter.Table{Schema: targetDB, Name: targetTable} + extendCols, extendVals := r.FetchExtendColumn(rawDB, rawTable, s.sourceID) + if len(extendCols) > 0 { + arr[i].FileMeta.ExtendData = ExtendColumnData{ + Columns: extendCols, + Values: extendVals, + } + } + } + return nil + } + + // route for schema table and view + if err := runRoute(s.dbSchemas); err != nil { + return errors.Trace(err) + } + if err := runRoute(s.tableSchemas); err != nil { + return errors.Trace(err) + } + if err := runRoute(s.viewSchemas); err != nil { + return errors.Trace(err) + } + if err := runRoute(s.tableDatas); err != nil { + return errors.Trace(err) + } + // remove all schemas which has been entirely routed away(file count > 0) + // https://github.com/golang/go/wiki/SliceTricks#filtering-without-allocating + remainingSchemas := s.dbSchemas[:0] + for _, info := range s.dbSchemas { + if dbInfo := knownDBNames[info.TableName.Schema]; dbInfo.count > 0 { + remainingSchemas = append(remainingSchemas, info) + } else if dbInfo.count < 0 { + // this should not happen if there are no bugs in the code + return common.ErrTableRoute.GenWithStack("something wrong happened when route %s", info.TableName.String()) + } + } + s.dbSchemas = remainingSchemas + return nil +} + +func (s *mdLoaderSetup) insertDB(f FileInfo) (*MDDatabaseMeta, bool) { + dbIndex, ok := s.dbIndexMap[f.TableName.Schema] + if ok { + return s.loader.dbs[dbIndex], true + } + s.dbIndexMap[f.TableName.Schema] = len(s.loader.dbs) + ptr := &MDDatabaseMeta{ + Name: f.TableName.Schema, + SchemaFile: f, + charSet: s.loader.charSet, + } + s.loader.dbs = append(s.loader.dbs, ptr) + return ptr, false +} + +func (s *mdLoaderSetup) insertTable(fileInfo FileInfo) (tblMeta *MDTableMeta, dbExists bool, tableExists bool) { + dbFileInfo := FileInfo{ + TableName: filter.Table{ + Schema: fileInfo.TableName.Schema, + }, + FileMeta: SourceFileMeta{Type: SourceTypeSchemaSchema}, + } + dbMeta, dbExists := s.insertDB(dbFileInfo) + tableIndex, ok := s.tableIndexMap[fileInfo.TableName] + if ok { + return dbMeta.Tables[tableIndex], dbExists, true + } + s.tableIndexMap[fileInfo.TableName] = 
len(dbMeta.Tables) + ptr := &MDTableMeta{ + DB: fileInfo.TableName.Schema, + Name: fileInfo.TableName.Name, + SchemaFile: fileInfo, + DataFiles: make([]FileInfo, 0, 16), + charSet: s.loader.charSet, + IndexRatio: 0.0, + IsRowOrdered: true, + } + dbMeta.Tables = append(dbMeta.Tables, ptr) + return ptr, dbExists, false +} + +func (s *mdLoaderSetup) insertView(fileInfo FileInfo) (dbExists bool, tableExists bool) { + dbFileInfo := FileInfo{ + TableName: filter.Table{ + Schema: fileInfo.TableName.Schema, + }, + FileMeta: SourceFileMeta{Type: SourceTypeSchemaSchema}, + } + dbMeta, dbExists := s.insertDB(dbFileInfo) + _, ok := s.tableIndexMap[fileInfo.TableName] + if ok { + meta := &MDTableMeta{ + DB: fileInfo.TableName.Schema, + Name: fileInfo.TableName.Name, + SchemaFile: fileInfo, + charSet: s.loader.charSet, + IndexRatio: 0.0, + IsRowOrdered: true, + } + dbMeta.Views = append(dbMeta.Views, meta) + } + return dbExists, ok +} + +// GetDatabases gets the list of scanned MDDatabaseMeta for the loader. +func (l *MDLoader) GetDatabases() []*MDDatabaseMeta { + return l.dbs +} + +// GetStore gets the external storage used by the loader. +func (l *MDLoader) GetStore() storage.ExternalStorage { + return l.store +} + +func calculateFileBytes(ctx context.Context, + dataFile string, + compressType storage.CompressType, + store storage.ExternalStorage, + offset int64) (tot int, pos int64, err error) { + bytes := make([]byte, sampleCompressedFileSize) + reader, err := store.Open(ctx, dataFile, nil) + if err != nil { + return 0, 0, errors.Trace(err) + } + defer reader.Close() + + decompressConfig := storage.DecompressConfig{ZStdDecodeConcurrency: 1} + compressReader, err := storage.NewLimitedInterceptReader(reader, compressType, decompressConfig, offset) + if err != nil { + return 0, 0, errors.Trace(err) + } + + readBytes := func() error { + n, err2 := compressReader.Read(bytes) + if err2 != nil && errors.Cause(err2) != io.EOF && errors.Cause(err) != io.ErrUnexpectedEOF { + return err2 + } + tot += n + return err2 + } + + if offset == 0 { + err = readBytes() + if err != nil && errors.Cause(err) != io.EOF && errors.Cause(err) != io.ErrUnexpectedEOF { + return 0, 0, err + } + pos, err = compressReader.Seek(0, io.SeekCurrent) + if err != nil { + return 0, 0, errors.Trace(err) + } + return tot, pos, nil + } + + for { + err = readBytes() + if err != nil { + break + } + } + if err != nil && errors.Cause(err) != io.EOF && errors.Cause(err) != io.ErrUnexpectedEOF { + return 0, 0, errors.Trace(err) + } + return tot, offset, nil +} + +// SampleFileCompressRatio samples the compress ratio of the compressed file. +func SampleFileCompressRatio(ctx context.Context, fileMeta SourceFileMeta, store storage.ExternalStorage) (float64, error) { + failpoint.Inject("SampleFileCompressPercentage", func(val failpoint.Value) { + switch v := val.(type) { + case string: + failpoint.Return(1.0, errors.New(v)) + case int: + failpoint.Return(float64(v)/100, nil) + } + }) + if fileMeta.Compression == CompressionNone { + return 1, nil + } + compressType, err := ToStorageCompressType(fileMeta.Compression) + if err != nil { + return 0, err + } + // We use the following method to sample the compress ratio of the first few bytes of the file. + // 1. read first time aiming to find a valid compressed file offset. If we continue read now, the compress reader will + // request more data from file reader buffer them in its memory. We can't compute an accurate compress ratio. + // 2. 
we use a second reading and limit the file reader only read n bytes(n is the valid position we find in the first reading). + // Then we read all the data out from the compress reader. The data length m we read out is the uncompressed data length. + // Use m/n to compute the compress ratio. + // read first time, aims to find a valid end pos in compressed file + _, pos, err := calculateFileBytes(ctx, fileMeta.Path, compressType, store, 0) + if err != nil { + return 0, err + } + // read second time, original reader ends at first time's valid pos, compute sample data compress ratio + tot, pos, err := calculateFileBytes(ctx, fileMeta.Path, compressType, store, pos) + if err != nil { + return 0, err + } + return float64(tot) / float64(pos), nil +} + +// SampleParquetDataSize samples the data size of the parquet file. +func SampleParquetDataSize(ctx context.Context, fileMeta SourceFileMeta, store storage.ExternalStorage) (int64, error) { + totalRowCount, err := ReadParquetFileRowCountByFile(ctx, store, fileMeta) + if totalRowCount == 0 || err != nil { + return 0, err + } + + reader, err := store.Open(ctx, fileMeta.Path, nil) + if err != nil { + return 0, err + } + parser, err := NewParquetParser(ctx, store, reader, fileMeta.Path) + if err != nil { + //nolint: errcheck + reader.Close() + return 0, err + } + //nolint: errcheck + defer parser.Close() + + var ( + rowSize int64 + rowCount int64 + ) + for { + err = parser.ReadRow() + if err != nil { + if errors.Cause(err) == io.EOF { + break + } + return 0, err + } + lastRow := parser.LastRow() + rowCount++ + rowSize += int64(lastRow.Length) + parser.RecycleRow(lastRow) + if rowSize > maxSampleParquetDataSize || rowCount > maxSampleParquetRowCount { + break + } + } + size := int64(float64(totalRowCount) / float64(rowCount) * float64(rowSize)) + return size, nil +} diff --git a/pkg/meta/autoid/autoid.go b/pkg/meta/autoid/autoid.go index 237981497fcf9..d13b3f7fc2b03 100644 --- a/pkg/meta/autoid/autoid.go +++ b/pkg/meta/autoid/autoid.go @@ -539,16 +539,16 @@ func (alloc *allocator) GetType() AllocatorType { // NextStep return new auto id step according to previous step and consuming time. func NextStep(curStep int64, consumeDur time.Duration) int64 { - failpoint.Inject("mockAutoIDCustomize", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockAutoIDCustomize")); _err_ == nil { if val.(bool) { - failpoint.Return(3) + return 3 } - }) - failpoint.Inject("mockAutoIDChange", func(val failpoint.Value) { + } + if val, _err_ := failpoint.Eval(_curpkg_("mockAutoIDChange")); _err_ == nil { if val.(bool) { - failpoint.Return(step) + return step } - }) + } consumeRate := defaultConsumeTime.Seconds() / consumeDur.Seconds() res := int64(float64(curStep) * consumeRate) @@ -582,11 +582,11 @@ func newSinglePointAlloc(r Requirement, dbID, tblID int64, isUnsigned bool) *sin } // mockAutoIDChange failpoint is not implemented in this allocator, so fallback to use the default one. - failpoint.Inject("mockAutoIDChange", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockAutoIDChange")); _err_ == nil { if val.(bool) { spa = nil } - }) + } return spa } diff --git a/pkg/meta/autoid/autoid.go__failpoint_stash__ b/pkg/meta/autoid/autoid.go__failpoint_stash__ new file mode 100644 index 0000000000000..237981497fcf9 --- /dev/null +++ b/pkg/meta/autoid/autoid.go__failpoint_stash__ @@ -0,0 +1,1351 @@ +// Copyright 2015 PingCAP, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package autoid + +import ( + "bytes" + "context" + "fmt" + "math" + "strconv" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/autoid" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/tracing" + "github.com/tikv/client-go/v2/txnkv/txnsnapshot" + tikvutil "github.com/tikv/client-go/v2/util" + "go.uber.org/zap" +) + +// Attention: +// For reading cluster TiDB memory tables, the system schema/table should be same. +// Once the system schema/table id been allocated, it can't be changed any more. +// Change the system schema/table id may have the compatibility problem. +const ( + // SystemSchemaIDFlag is the system schema/table id flag, uses the highest bit position as system schema ID flag, it's exports for test. + SystemSchemaIDFlag = 1 << 62 + // InformationSchemaDBID is the information_schema schema id, it's exports for test. + InformationSchemaDBID int64 = SystemSchemaIDFlag | 1 + // PerformanceSchemaDBID is the performance_schema schema id, it's exports for test. + PerformanceSchemaDBID int64 = SystemSchemaIDFlag | 10000 + // MetricSchemaDBID is the metrics_schema schema id, it's exported for test. + MetricSchemaDBID int64 = SystemSchemaIDFlag | 20000 +) + +const ( + minStep = 30000 + maxStep = 2000000 + defaultConsumeTime = 10 * time.Second + minIncrement = 1 + maxIncrement = 65535 +) + +// RowIDBitLength is the bit number of a row id in TiDB. +const RowIDBitLength = 64 + +const ( + // AutoRandomShardBitsDefault is the default number of shard bits. + AutoRandomShardBitsDefault = 5 + // AutoRandomRangeBitsDefault is the default number of range bits. + AutoRandomRangeBitsDefault = 64 + // AutoRandomShardBitsMax is the max number of shard bits. + AutoRandomShardBitsMax = 15 + // AutoRandomRangeBitsMax is the max number of range bits. + AutoRandomRangeBitsMax = 64 + // AutoRandomRangeBitsMin is the min number of range bits. + AutoRandomRangeBitsMin = 32 + // AutoRandomIncBitsMin is the min number of auto random incremental bits. + AutoRandomIncBitsMin = 27 +) + +// AutoRandomShardBitsNormalize normalizes the auto random shard bits. 
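// For example (illustrative): an unspecified shard count falls back to the
// default, while an over-large one is rejected:
//
//	AutoRandomShardBitsNormalize(types.UnspecifiedLength, "id") // (5, nil), i.e. AutoRandomShardBitsDefault
//	AutoRandomShardBitsNormalize(6, "id")                       // (6, nil)
//	AutoRandomShardBitsNormalize(16, "id")                      // error: exceeds AutoRandomShardBitsMax (15)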
+func AutoRandomShardBitsNormalize(shard int, colName string) (ret uint64, err error) { + if shard == types.UnspecifiedLength { + return AutoRandomShardBitsDefault, nil + } + if shard <= 0 { + return 0, dbterror.ErrInvalidAutoRandom.FastGenByArgs(AutoRandomNonPositive) + } + if shard > AutoRandomShardBitsMax { + errMsg := fmt.Sprintf(AutoRandomOverflowErrMsg, AutoRandomShardBitsMax, shard, colName) + return 0, dbterror.ErrInvalidAutoRandom.FastGenByArgs(errMsg) + } + return uint64(shard), nil +} + +// AutoRandomRangeBitsNormalize normalizes the auto random range bits. +func AutoRandomRangeBitsNormalize(rangeBits int) (ret uint64, err error) { + if rangeBits == types.UnspecifiedLength { + return AutoRandomRangeBitsDefault, nil + } + if rangeBits < AutoRandomRangeBitsMin || rangeBits > AutoRandomRangeBitsMax { + errMsg := fmt.Sprintf(AutoRandomInvalidRangeBits, AutoRandomRangeBitsMin, AutoRandomRangeBitsMax, rangeBits) + return 0, dbterror.ErrInvalidAutoRandom.FastGenByArgs(errMsg) + } + return uint64(rangeBits), nil +} + +// Test needs to change it, so it's a variable. +var step = int64(30000) + +// AllocatorType is the type of allocator for generating auto-id. Different type of allocators use different key-value pairs. +type AllocatorType uint8 + +const ( + // RowIDAllocType indicates the allocator is used to allocate row id. + RowIDAllocType AllocatorType = iota + // AutoIncrementType indicates the allocator is used to allocate auto increment value. + AutoIncrementType + // AutoRandomType indicates the allocator is used to allocate auto-shard id. + AutoRandomType + // SequenceType indicates the allocator is used to allocate sequence value. + SequenceType +) + +func (a AllocatorType) String() string { + switch a { + case RowIDAllocType: + return "_tidb_rowid" + case AutoIncrementType: + return "auto_increment" + case AutoRandomType: + return "auto_random" + case SequenceType: + return "sequence" + } + return "unknown" +} + +// CustomAutoIncCacheOption is one kind of AllocOption to customize the allocator step length. +type CustomAutoIncCacheOption int64 + +// ApplyOn implements the AllocOption interface. +func (step CustomAutoIncCacheOption) ApplyOn(alloc *allocator) { + if step == 0 { + return + } + alloc.step = int64(step) + alloc.customStep = true +} + +// AllocOptionTableInfoVersion is used to pass the TableInfo.Version to the allocator. +type AllocOptionTableInfoVersion uint16 + +// ApplyOn implements the AllocOption interface. +func (v AllocOptionTableInfoVersion) ApplyOn(alloc *allocator) { + alloc.tbVersion = uint16(v) +} + +// AllocOption is a interface to define allocator custom options coming in future. +type AllocOption interface { + ApplyOn(*allocator) +} + +// Allocator is an auto increment id generator. +// Just keep id unique actually. +type Allocator interface { + // Alloc allocs N consecutive autoID for table with tableID, returning (min, max] of the allocated autoID batch. + // It gets a batch of autoIDs at a time. So it does not need to access storage for each call. + // The consecutive feature is used to insert multiple rows in a statement. + // increment & offset is used to validate the start position (the allocator's base is not always the last allocated id). + // The returned range is (min, max]: + // case increment=1 & offset=1: you can derive the ids like min+1, min+2... max. + // case increment=x & offset=y: you firstly need to seek to firstID by `SeekToFirstAutoIDXXX`, then derive the IDs like firstID, firstID + increment * 2... in the caller. 
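// For example (illustrative), with increment=1 and offset=1:
//	min, max, err := alloc.Alloc(ctx, 3, 1, 1)
// a result of (min=100, max=103) means the IDs 101, 102 and 103 have been
// reserved for the caller.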
+ Alloc(ctx context.Context, n uint64, increment, offset int64) (int64, int64, error) + + // AllocSeqCache allocs sequence batch value cached in table level(rather than in alloc), the returned range covering + // the size of sequence cache with it's increment. The returned round indicates the sequence cycle times if it is with + // cycle option. + AllocSeqCache() (min int64, max int64, round int64, err error) + + // Rebase rebases the autoID base for table with tableID and the new base value. + // If allocIDs is true, it will allocate some IDs and save to the cache. + // If allocIDs is false, it will not allocate IDs. + Rebase(ctx context.Context, newBase int64, allocIDs bool) error + + // ForceRebase set the next global auto ID to newBase. + ForceRebase(newBase int64) error + + // RebaseSeq rebases the sequence value in number axis with tableID and the new base value. + RebaseSeq(newBase int64) (int64, bool, error) + + // Base return the current base of Allocator. + Base() int64 + // End is only used for test. + End() int64 + // NextGlobalAutoID returns the next global autoID. + NextGlobalAutoID() (int64, error) + GetType() AllocatorType +} + +// Allocators represents a set of `Allocator`s. +type Allocators struct { + SepAutoInc bool + Allocs []Allocator +} + +// NewAllocators packs multiple `Allocator`s into Allocators. +func NewAllocators(sepAutoInc bool, allocators ...Allocator) Allocators { + return Allocators{ + SepAutoInc: sepAutoInc, + Allocs: allocators, + } +} + +// Append add an allocator to the allocators. +func (all Allocators) Append(a Allocator) Allocators { + return Allocators{ + SepAutoInc: all.SepAutoInc, + Allocs: append(all.Allocs, a), + } +} + +// Get returns the Allocator according to the AllocatorType. +func (all Allocators) Get(allocType AllocatorType) Allocator { + if !all.SepAutoInc { + if allocType == AutoIncrementType { + allocType = RowIDAllocType + } + } + + for _, a := range all.Allocs { + if a.GetType() == allocType { + return a + } + } + return nil +} + +// Filter filters all the allocators that match pred. +func (all Allocators) Filter(pred func(Allocator) bool) Allocators { + var ret []Allocator + for _, a := range all.Allocs { + if pred(a) { + ret = append(ret, a) + } + } + return Allocators{ + SepAutoInc: all.SepAutoInc, + Allocs: ret, + } +} + +type allocator struct { + mu sync.Mutex + base int64 + end int64 + store kv.Storage + // dbID is database ID where it was created. + dbID int64 + tbID int64 + tbVersion uint16 + isUnsigned bool + lastAllocTime time.Time + step int64 + customStep bool + allocType AllocatorType + sequence *model.SequenceInfo +} + +// GetStep is only used by tests +func GetStep() int64 { + return step +} + +// SetStep is only used by tests +func SetStep(s int64) { + step = s +} + +// Base implements autoid.Allocator Base interface. +func (alloc *allocator) Base() int64 { + alloc.mu.Lock() + defer alloc.mu.Unlock() + return alloc.base +} + +// End implements autoid.Allocator End interface. +func (alloc *allocator) End() int64 { + alloc.mu.Lock() + defer alloc.mu.Unlock() + return alloc.end +} + +// NextGlobalAutoID implements autoid.Allocator NextGlobalAutoID interface. 
+func (alloc *allocator) NextGlobalAutoID() (int64, error) { + var autoID int64 + startTime := time.Now() + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnMeta) + err := kv.RunInNewTxn(ctx, alloc.store, true, func(_ context.Context, txn kv.Transaction) error { + var err1 error + autoID, err1 = alloc.getIDAccessor(txn).Get() + if err1 != nil { + return errors.Trace(err1) + } + return nil + }) + metrics.AutoIDHistogram.WithLabelValues(metrics.GlobalAutoID, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) + if alloc.isUnsigned { + return int64(uint64(autoID) + 1), err + } + return autoID + 1, err +} + +func (alloc *allocator) rebase4Unsigned(ctx context.Context, requiredBase uint64, allocIDs bool) error { + // Satisfied by alloc.base, nothing to do. + if requiredBase <= uint64(alloc.base) { + return nil + } + // Satisfied by alloc.end, need to update alloc.base. + if requiredBase <= uint64(alloc.end) { + alloc.base = int64(requiredBase) + return nil + } + + ctx, allocatorStats, commitDetail := getAllocatorStatsFromCtx(ctx) + if allocatorStats != nil { + allocatorStats.rebaseCount++ + defer func() { + if commitDetail != nil { + allocatorStats.mergeCommitDetail(*commitDetail) + } + }() + } + var newBase, newEnd uint64 + startTime := time.Now() + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnMeta) + err := kv.RunInNewTxn(ctx, alloc.store, true, func(_ context.Context, txn kv.Transaction) error { + if allocatorStats != nil { + txn.SetOption(kv.CollectRuntimeStats, allocatorStats.SnapshotRuntimeStats) + } + idAcc := alloc.getIDAccessor(txn) + currentEnd, err1 := idAcc.Get() + if err1 != nil { + return err1 + } + uCurrentEnd := uint64(currentEnd) + if allocIDs { + newBase = max(uCurrentEnd, requiredBase) + newEnd = min(math.MaxUint64-uint64(alloc.step), newBase) + uint64(alloc.step) + } else { + if uCurrentEnd >= requiredBase { + newBase = uCurrentEnd + newEnd = uCurrentEnd + // Required base satisfied, we don't need to update KV. + return nil + } + // If we don't want to allocate IDs, for example when creating a table with a given base value, + // We need to make sure when other TiDB server allocates ID for the first time, requiredBase + 1 + // will be allocated, so we need to increase the end to exactly the requiredBase. + newBase = requiredBase + newEnd = requiredBase + } + _, err1 = idAcc.Inc(int64(newEnd - uCurrentEnd)) + return err1 + }) + metrics.AutoIDHistogram.WithLabelValues(metrics.TableAutoIDRebase, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) + if err != nil { + return err + } + alloc.base, alloc.end = int64(newBase), int64(newEnd) + return nil +} + +func (alloc *allocator) rebase4Signed(ctx context.Context, requiredBase int64, allocIDs bool) error { + // Satisfied by alloc.base, nothing to do. + if requiredBase <= alloc.base { + return nil + } + // Satisfied by alloc.end, need to update alloc.base. 
+ if requiredBase <= alloc.end { + alloc.base = requiredBase + return nil + } + + ctx, allocatorStats, commitDetail := getAllocatorStatsFromCtx(ctx) + if allocatorStats != nil { + allocatorStats.rebaseCount++ + defer func() { + if commitDetail != nil { + allocatorStats.mergeCommitDetail(*commitDetail) + } + }() + } + var newBase, newEnd int64 + startTime := time.Now() + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnMeta) + err := kv.RunInNewTxn(ctx, alloc.store, true, func(_ context.Context, txn kv.Transaction) error { + if allocatorStats != nil { + txn.SetOption(kv.CollectRuntimeStats, allocatorStats.SnapshotRuntimeStats) + } + idAcc := alloc.getIDAccessor(txn) + currentEnd, err1 := idAcc.Get() + if err1 != nil { + return err1 + } + if allocIDs { + newBase = max(currentEnd, requiredBase) + newEnd = min(math.MaxInt64-alloc.step, newBase) + alloc.step + } else { + if currentEnd >= requiredBase { + newBase = currentEnd + newEnd = currentEnd + // Required base satisfied, we don't need to update KV. + return nil + } + // If we don't want to allocate IDs, for example when creating a table with a given base value, + // We need to make sure when other TiDB server allocates ID for the first time, requiredBase + 1 + // will be allocated, so we need to increase the end to exactly the requiredBase. + newBase = requiredBase + newEnd = requiredBase + } + _, err1 = idAcc.Inc(newEnd - currentEnd) + return err1 + }) + metrics.AutoIDHistogram.WithLabelValues(metrics.TableAutoIDRebase, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) + if err != nil { + return err + } + alloc.base, alloc.end = newBase, newEnd + return nil +} + +// rebase4Sequence won't alloc batch immediately, cause it won't cache value in allocator. +func (alloc *allocator) rebase4Sequence(requiredBase int64) (int64, bool, error) { + startTime := time.Now() + alreadySatisfied := false + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnMeta) + err := kv.RunInNewTxn(ctx, alloc.store, true, func(_ context.Context, txn kv.Transaction) error { + acc := meta.NewMeta(txn).GetAutoIDAccessors(alloc.dbID, alloc.tbID) + currentEnd, err := acc.SequenceValue().Get() + if err != nil { + return err + } + if alloc.sequence.Increment > 0 { + if currentEnd >= requiredBase { + // Required base satisfied, we don't need to update KV. + alreadySatisfied = true + return nil + } + } else { + if currentEnd <= requiredBase { + // Required base satisfied, we don't need to update KV. + alreadySatisfied = true + return nil + } + } + + // If we don't want to allocate IDs, for example when creating a table with a given base value, + // We need to make sure when other TiDB server allocates ID for the first time, requiredBase + 1 + // will be allocated, so we need to increase the end to exactly the requiredBase. + _, err = acc.SequenceValue().Inc(requiredBase - currentEnd) + return err + }) + // TODO: sequence metrics + metrics.AutoIDHistogram.WithLabelValues(metrics.TableAutoIDRebase, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) + if err != nil { + return 0, false, err + } + if alreadySatisfied { + return 0, true, nil + } + return requiredBase, false, err +} + +// Rebase implements autoid.Allocator Rebase interface. +// The requiredBase is the minimum base value after Rebase. +// The real base may be greater than the required base. 
+func (alloc *allocator) Rebase(ctx context.Context, requiredBase int64, allocIDs bool) error { + alloc.mu.Lock() + defer alloc.mu.Unlock() + if alloc.isUnsigned { + return alloc.rebase4Unsigned(ctx, uint64(requiredBase), allocIDs) + } + return alloc.rebase4Signed(ctx, requiredBase, allocIDs) +} + +// ForceRebase implements autoid.Allocator ForceRebase interface. +func (alloc *allocator) ForceRebase(requiredBase int64) error { + if requiredBase == -1 { + return ErrAutoincReadFailed.GenWithStack("Cannot force rebase the next global ID to '0'") + } + alloc.mu.Lock() + defer alloc.mu.Unlock() + startTime := time.Now() + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnMeta) + err := kv.RunInNewTxn(ctx, alloc.store, true, func(_ context.Context, txn kv.Transaction) error { + idAcc := alloc.getIDAccessor(txn) + currentEnd, err1 := idAcc.Get() + if err1 != nil { + return err1 + } + var step int64 + if !alloc.isUnsigned { + step = requiredBase - currentEnd + } else { + uRequiredBase, uCurrentEnd := uint64(requiredBase), uint64(currentEnd) + step = int64(uRequiredBase - uCurrentEnd) + } + _, err1 = idAcc.Inc(step) + return err1 + }) + metrics.AutoIDHistogram.WithLabelValues(metrics.TableAutoIDRebase, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) + if err != nil { + return err + } + alloc.base, alloc.end = requiredBase, requiredBase + return nil +} + +// Rebase implements autoid.Allocator RebaseSeq interface. +// The return value is quite same as expression function, bool means whether it should be NULL, +// here it will be used in setval expression function (true meaning the set value has been satisfied, return NULL). +// case1:When requiredBase is satisfied with current value, it will return (0, true, nil), +// case2:When requiredBase is successfully set in, it will return (requiredBase, false, nil). +// If some error occurs in the process, return it immediately. +func (alloc *allocator) RebaseSeq(requiredBase int64) (int64, bool, error) { + alloc.mu.Lock() + defer alloc.mu.Unlock() + return alloc.rebase4Sequence(requiredBase) +} + +func (alloc *allocator) GetType() AllocatorType { + return alloc.allocType +} + +// NextStep return new auto id step according to previous step and consuming time. +func NextStep(curStep int64, consumeDur time.Duration) int64 { + failpoint.Inject("mockAutoIDCustomize", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(3) + } + }) + failpoint.Inject("mockAutoIDChange", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(step) + } + }) + + consumeRate := defaultConsumeTime.Seconds() / consumeDur.Seconds() + res := int64(float64(curStep) * consumeRate) + if res < minStep { + return minStep + } else if res > maxStep { + return maxStep + } + return res +} + +// MockForTest is exported for testing. +// The actual implementation is in github.com/pingcap/tidb/pkg/autoid_service because of the +// package circle depending issue. 
+var MockForTest func(kv.Storage) autoid.AutoIDAllocClient + +func newSinglePointAlloc(r Requirement, dbID, tblID int64, isUnsigned bool) *singlePointAlloc { + keyspaceID := uint32(r.Store().GetCodec().GetKeyspaceID()) + spa := &singlePointAlloc{ + dbID: dbID, + tblID: tblID, + isUnsigned: isUnsigned, + keyspaceID: keyspaceID, + } + if r.AutoIDClient() == nil { + // Only for test in mockstore + spa.ClientDiscover = &ClientDiscover{} + spa.mu.AutoIDAllocClient = MockForTest(r.Store()) + } else { + spa.ClientDiscover = r.AutoIDClient() + } + + // mockAutoIDChange failpoint is not implemented in this allocator, so fallback to use the default one. + failpoint.Inject("mockAutoIDChange", func(val failpoint.Value) { + if val.(bool) { + spa = nil + } + }) + return spa +} + +// Requirement is the parameter required by NewAllocator +type Requirement interface { + Store() kv.Storage + AutoIDClient() *ClientDiscover +} + +// NewAllocator returns a new auto increment id generator on the store. +func NewAllocator(r Requirement, dbID, tbID int64, isUnsigned bool, + allocType AllocatorType, opts ...AllocOption) Allocator { + var store kv.Storage + if r != nil { + store = r.Store() + } + alloc := &allocator{ + store: store, + dbID: dbID, + tbID: tbID, + isUnsigned: isUnsigned, + step: step, + lastAllocTime: time.Now(), + allocType: allocType, + } + for _, fn := range opts { + fn.ApplyOn(alloc) + } + + // Use the MySQL compatible AUTO_INCREMENT mode. + if alloc.customStep && alloc.step == 1 && alloc.tbVersion >= model.TableInfoVersion5 { + if allocType == AutoIncrementType { + alloc1 := newSinglePointAlloc(r, dbID, tbID, isUnsigned) + if alloc1 != nil { + return alloc1 + } + } else if allocType == RowIDAllocType { + // Now that the autoid and rowid allocator are separated, the AUTO_ID_CACHE 1 setting should not make + // the rowid allocator do not use cache. + alloc.customStep = false + alloc.step = step + } + } + + return alloc +} + +// NewSequenceAllocator returns a new sequence value generator on the store. +func NewSequenceAllocator(store kv.Storage, dbID, tbID int64, info *model.SequenceInfo) Allocator { + return &allocator{ + store: store, + dbID: dbID, + tbID: tbID, + // Sequence allocator is always signed. + isUnsigned: false, + lastAllocTime: time.Now(), + allocType: SequenceType, + sequence: info, + } +} + +// TODO: Handle allocators when changing Table ID during ALTER TABLE t PARTITION BY ... + +// NewAllocatorsFromTblInfo creates an array of allocators of different types with the information of model.TableInfo. 
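+// (Added commentary) Per the checks below: a row-ID allocator is created when the table has no
+// clustered handle or declares AUTO_INCREMENT, a dedicated auto-increment allocator for
+// AUTO_INCREMENT columns, an auto-random allocator for AUTO_RANDOM, and a sequence allocator
+// for sequence objects.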
+func NewAllocatorsFromTblInfo(r Requirement, schemaID int64, tblInfo *model.TableInfo) Allocators { + var allocs []Allocator + dbID := tblInfo.GetAutoIDSchemaID(schemaID) + idCacheOpt := CustomAutoIncCacheOption(tblInfo.AutoIdCache) + tblVer := AllocOptionTableInfoVersion(tblInfo.Version) + + hasRowID := !tblInfo.PKIsHandle && !tblInfo.IsCommonHandle + hasAutoIncID := tblInfo.GetAutoIncrementColInfo() != nil + if hasRowID || hasAutoIncID { + alloc := NewAllocator(r, dbID, tblInfo.ID, tblInfo.IsAutoIncColUnsigned(), RowIDAllocType, idCacheOpt, tblVer) + allocs = append(allocs, alloc) + } + if hasAutoIncID { + alloc := NewAllocator(r, dbID, tblInfo.ID, tblInfo.IsAutoIncColUnsigned(), AutoIncrementType, idCacheOpt, tblVer) + allocs = append(allocs, alloc) + } + hasAutoRandID := tblInfo.ContainsAutoRandomBits() + if hasAutoRandID { + alloc := NewAllocator(r, dbID, tblInfo.ID, tblInfo.IsAutoRandomBitColUnsigned(), AutoRandomType, idCacheOpt, tblVer) + allocs = append(allocs, alloc) + } + if tblInfo.IsSequence() { + allocs = append(allocs, NewSequenceAllocator(r.Store(), dbID, tblInfo.ID, tblInfo.Sequence)) + } + return NewAllocators(tblInfo.SepAutoInc(), allocs...) +} + +// Alloc implements autoid.Allocator Alloc interface. +// For autoIncrement allocator, the increment and offset should always be positive in [1, 65535]. +// Attention: +// When increment and offset is not the default value(1), the return range (min, max] need to +// calculate the correct start position rather than simply the add 1 to min. Then you can derive +// the successive autoID by adding increment * cnt to firstID for (n-1) times. +// +// Example: +// (6, 13] is returned, increment = 4, offset = 1, n = 2. +// 6 is the last allocated value for other autoID or handle, maybe with different increment and step, +// but actually we don't care about it, all we need is to calculate the new autoID corresponding to the +// increment and offset at this time now. To simplify the rule is like (ID - offset) % increment = 0, +// so the first autoID should be 9, then add increment to it to get 13. +func (alloc *allocator) Alloc(ctx context.Context, n uint64, increment, offset int64) (min int64, max int64, err error) { + if alloc.tbID == 0 { + return 0, 0, errInvalidTableID.GenWithStackByArgs("Invalid tableID") + } + if n == 0 { + return 0, 0, nil + } + if alloc.allocType == AutoIncrementType || alloc.allocType == RowIDAllocType { + if !validIncrementAndOffset(increment, offset) { + return 0, 0, errInvalidIncrementAndOffset.GenWithStackByArgs(increment, offset) + } + } + alloc.mu.Lock() + defer alloc.mu.Unlock() + if alloc.isUnsigned { + return alloc.alloc4Unsigned(ctx, n, increment, offset) + } + return alloc.alloc4Signed(ctx, n, increment, offset) +} + +func (alloc *allocator) AllocSeqCache() (min int64, max int64, round int64, err error) { + alloc.mu.Lock() + defer alloc.mu.Unlock() + return alloc.alloc4Sequence() +} + +func validIncrementAndOffset(increment, offset int64) bool { + return (increment >= minIncrement && increment <= maxIncrement) && (offset >= minIncrement && offset <= maxIncrement) +} + +// CalcNeededBatchSize is used to calculate batch size for autoID allocation. +// It firstly seeks to the first valid position based on increment and offset, +// then plus the length remained, which could be (n-1) * increment. 
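+//
+// Worked example (added commentary): base = 6, n = 2, increment = 4, offset = 1, signed.
+// SeekToFirstAutoIDSigned(6, 4, 1) = 9 and 9 + (2-1)*4 = 13, so the needed batch size is
+// 13 - 6 = 7, matching the (6, 13] range described for Alloc above.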
+func CalcNeededBatchSize(base, n, increment, offset int64, isUnsigned bool) int64 { + if increment == 1 { + return n + } + if isUnsigned { + // SeekToFirstAutoIDUnSigned seeks to the next unsigned valid position. + nr := SeekToFirstAutoIDUnSigned(uint64(base), uint64(increment), uint64(offset)) + // Calculate the total batch size needed. + nr += (uint64(n) - 1) * uint64(increment) + return int64(nr - uint64(base)) + } + nr := SeekToFirstAutoIDSigned(base, increment, offset) + // Calculate the total batch size needed. + nr += (n - 1) * increment + return nr - base +} + +// CalcSequenceBatchSize calculate the next sequence batch size. +func CalcSequenceBatchSize(base, size, increment, offset, min, max int64) (int64, error) { + // The sequence is positive growth. + if increment > 0 { + if increment == 1 { + // Sequence is already allocated to the end. + if base >= max { + return 0, ErrAutoincReadFailed + } + // The rest of sequence < cache size, return the rest. + if max-base < size { + return max - base, nil + } + // The rest of sequence is adequate. + return size, nil + } + nr, ok := SeekToFirstSequenceValue(base, increment, offset, min, max) + if !ok { + return 0, ErrAutoincReadFailed + } + // The rest of sequence < cache size, return the rest. + if max-nr < (size-1)*increment { + return max - base, nil + } + return (nr - base) + (size-1)*increment, nil + } + // The sequence is negative growth. + if increment == -1 { + if base <= min { + return 0, ErrAutoincReadFailed + } + if base-min < size { + return base - min, nil + } + return size, nil + } + nr, ok := SeekToFirstSequenceValue(base, increment, offset, min, max) + if !ok { + return 0, ErrAutoincReadFailed + } + // The rest of sequence < cache size, return the rest. + if nr-min < (size-1)*(-increment) { + return base - min, nil + } + return (base - nr) + (size-1)*(-increment), nil +} + +// SeekToFirstSequenceValue seeks to the next valid value (must be in range of [MIN, max]), +// the bool indicates whether the first value is got. +// The seeking formula is describe as below: +// +// nr := (base + increment - offset) / increment +// +// first := nr*increment + offset +// Because formula computation will overflow Int64, so we transfer it to uint64 for distance computation. +func SeekToFirstSequenceValue(base, increment, offset, min, max int64) (int64, bool) { + if increment > 0 { + // Sequence is already allocated to the end. + if base >= max { + return 0, false + } + uMax := EncodeIntToCmpUint(max) + uBase := EncodeIntToCmpUint(base) + uOffset := EncodeIntToCmpUint(offset) + uIncrement := uint64(increment) + if uMax-uBase < uIncrement { + // Enum the possible first value. + for i := uBase + 1; i <= uMax; i++ { + if (i-uOffset)%uIncrement == 0 { + return DecodeCmpUintToInt(i), true + } + } + return 0, false + } + nr := (uBase + uIncrement - uOffset) / uIncrement + nr = nr*uIncrement + uOffset + first := DecodeCmpUintToInt(nr) + return first, true + } + // Sequence is already allocated to the end. + if base <= min { + return 0, false + } + uMin := EncodeIntToCmpUint(min) + uBase := EncodeIntToCmpUint(base) + uOffset := EncodeIntToCmpUint(offset) + uIncrement := uint64(-increment) + if uBase-uMin < uIncrement { + // Enum the possible first value. 
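+		// (Added commentary, illustrative values) e.g. base=-3, increment=-5, offset=-1, min=-6:
+		// the distance to min is smaller than |increment|, so the loop below scans -4, -5, -6
+		// and returns -6, the next valid sequence value.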
+ for i := uBase - 1; i >= uMin; i-- { + if (uOffset-i)%uIncrement == 0 { + return DecodeCmpUintToInt(i), true + } + } + return 0, false + } + nr := (uOffset - uBase + uIncrement) / uIncrement + nr = uOffset - nr*uIncrement + first := DecodeCmpUintToInt(nr) + return first, true +} + +// SeekToFirstAutoIDSigned seeks to the next valid signed position. +func SeekToFirstAutoIDSigned(base, increment, offset int64) int64 { + nr := (base + increment - offset) / increment + nr = nr*increment + offset + return nr +} + +// SeekToFirstAutoIDUnSigned seeks to the next valid unsigned position. +func SeekToFirstAutoIDUnSigned(base, increment, offset uint64) uint64 { + nr := (base + increment - offset) / increment + nr = nr*increment + offset + return nr +} + +func (alloc *allocator) alloc4Signed(ctx context.Context, n uint64, increment, offset int64) (mini int64, max int64, err error) { + // Check offset rebase if necessary. + if offset-1 > alloc.base { + if err := alloc.rebase4Signed(ctx, offset-1, true); err != nil { + return 0, 0, err + } + } + // CalcNeededBatchSize calculates the total batch size needed. + n1 := CalcNeededBatchSize(alloc.base, int64(n), increment, offset, alloc.isUnsigned) + + // Condition alloc.base+N1 > alloc.end will overflow when alloc.base + N1 > MaxInt64. So need this. + if math.MaxInt64-alloc.base <= n1 { + return 0, 0, ErrAutoincReadFailed + } + // The local rest is not enough for allocN, skip it. + if alloc.base+n1 > alloc.end { + var newBase, newEnd int64 + startTime := time.Now() + nextStep := alloc.step + if !alloc.customStep && alloc.end > 0 { + // Although it may skip a segment here, we still think it is consumed. + consumeDur := startTime.Sub(alloc.lastAllocTime) + nextStep = NextStep(alloc.step, consumeDur) + } + + ctx, allocatorStats, commitDetail := getAllocatorStatsFromCtx(ctx) + if allocatorStats != nil { + allocatorStats.allocCount++ + defer func() { + if commitDetail != nil { + allocatorStats.mergeCommitDetail(*commitDetail) + } + }() + } + + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnMeta) + err := kv.RunInNewTxn(ctx, alloc.store, true, func(ctx context.Context, txn kv.Transaction) error { + defer tracing.StartRegion(ctx, "alloc.alloc4Signed").End() + if allocatorStats != nil { + txn.SetOption(kv.CollectRuntimeStats, allocatorStats.SnapshotRuntimeStats) + } + + idAcc := alloc.getIDAccessor(txn) + var err1 error + newBase, err1 = idAcc.Get() + if err1 != nil { + return err1 + } + // CalcNeededBatchSize calculates the total batch size needed on global base. + n1 = CalcNeededBatchSize(newBase, int64(n), increment, offset, alloc.isUnsigned) + // Although the step is customized by user, we still need to make sure nextStep is big enough for insert batch. + if nextStep < n1 { + nextStep = n1 + } + tmpStep := min(math.MaxInt64-newBase, nextStep) + // The global rest is not enough for alloc. + if tmpStep < n1 { + return ErrAutoincReadFailed + } + newEnd, err1 = idAcc.Inc(tmpStep) + return err1 + }) + metrics.AutoIDHistogram.WithLabelValues(metrics.TableAutoIDAlloc, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) + if err != nil { + return 0, 0, err + } + // Store the step for non-customized-step allocator to calculate next dynamic step. 
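+		// (Added commentary) NextStep scales the previous step by defaultConsumeTime/consumeDur
+		// and clamps the result to [minStep, maxStep], so quickly consumed ranges grow and
+		// rarely used ones shrink back toward minStep.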
+ if !alloc.customStep { + alloc.step = nextStep + } + alloc.lastAllocTime = time.Now() + if newBase == math.MaxInt64 { + return 0, 0, ErrAutoincReadFailed + } + alloc.base, alloc.end = newBase, newEnd + } + if logutil.BgLogger().Core().Enabled(zap.DebugLevel) { + logutil.BgLogger().Debug("alloc N signed ID", + zap.Uint64("from ID", uint64(alloc.base)), + zap.Uint64("to ID", uint64(alloc.base+n1)), + zap.Int64("table ID", alloc.tbID), + zap.Int64("database ID", alloc.dbID)) + } + mini = alloc.base + alloc.base += n1 + return mini, alloc.base, nil +} + +func (alloc *allocator) alloc4Unsigned(ctx context.Context, n uint64, increment, offset int64) (mini int64, max int64, err error) { + // Check offset rebase if necessary. + if uint64(offset-1) > uint64(alloc.base) { + if err := alloc.rebase4Unsigned(ctx, uint64(offset-1), true); err != nil { + return 0, 0, err + } + } + // CalcNeededBatchSize calculates the total batch size needed. + n1 := CalcNeededBatchSize(alloc.base, int64(n), increment, offset, alloc.isUnsigned) + + // Condition alloc.base+n1 > alloc.end will overflow when alloc.base + n1 > MaxInt64. So need this. + if math.MaxUint64-uint64(alloc.base) <= uint64(n1) { + return 0, 0, ErrAutoincReadFailed + } + // The local rest is not enough for alloc, skip it. + if uint64(alloc.base)+uint64(n1) > uint64(alloc.end) { + var newBase, newEnd int64 + startTime := time.Now() + nextStep := alloc.step + if !alloc.customStep { + // Although it may skip a segment here, we still treat it as consumed. + consumeDur := startTime.Sub(alloc.lastAllocTime) + nextStep = NextStep(alloc.step, consumeDur) + } + + ctx, allocatorStats, commitDetail := getAllocatorStatsFromCtx(ctx) + if allocatorStats != nil { + allocatorStats.allocCount++ + defer func() { + if commitDetail != nil { + allocatorStats.mergeCommitDetail(*commitDetail) + } + }() + } + + if codeRun := ctx.Value("testIssue39528"); codeRun != nil { + *(codeRun.(*bool)) = true + return 0, 0, errors.New("mock error for test") + } + + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnMeta) + err := kv.RunInNewTxn(ctx, alloc.store, true, func(ctx context.Context, txn kv.Transaction) error { + defer tracing.StartRegion(ctx, "alloc.alloc4Unsigned").End() + if allocatorStats != nil { + txn.SetOption(kv.CollectRuntimeStats, allocatorStats.SnapshotRuntimeStats) + } + + idAcc := alloc.getIDAccessor(txn) + var err1 error + newBase, err1 = idAcc.Get() + if err1 != nil { + return err1 + } + // CalcNeededBatchSize calculates the total batch size needed on new base. + n1 = CalcNeededBatchSize(newBase, int64(n), increment, offset, alloc.isUnsigned) + // Although the step is customized by user, we still need to make sure nextStep is big enough for insert batch. + if nextStep < n1 { + nextStep = n1 + } + tmpStep := int64(min(math.MaxUint64-uint64(newBase), uint64(nextStep))) + // The global rest is not enough for alloc. + if tmpStep < n1 { + return ErrAutoincReadFailed + } + newEnd, err1 = idAcc.Inc(tmpStep) + return err1 + }) + metrics.AutoIDHistogram.WithLabelValues(metrics.TableAutoIDAlloc, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) + if err != nil { + return 0, 0, err + } + // Store the step for non-customized-step allocator to calculate next dynamic step. 
+ if !alloc.customStep { + alloc.step = nextStep + } + alloc.lastAllocTime = time.Now() + if uint64(newBase) == math.MaxUint64 { + return 0, 0, ErrAutoincReadFailed + } + alloc.base, alloc.end = newBase, newEnd + } + logutil.Logger(context.TODO()).Debug("alloc unsigned ID", + zap.Uint64(" from ID", uint64(alloc.base)), + zap.Uint64("to ID", uint64(alloc.base+n1)), + zap.Int64("table ID", alloc.tbID), + zap.Int64("database ID", alloc.dbID)) + mini = alloc.base + // Use uint64 n directly. + alloc.base = int64(uint64(alloc.base) + uint64(n1)) + return mini, alloc.base, nil +} + +func getAllocatorStatsFromCtx(ctx context.Context) (context.Context, *AllocatorRuntimeStats, **tikvutil.CommitDetails) { + var allocatorStats *AllocatorRuntimeStats + var commitDetail *tikvutil.CommitDetails + ctxValue := ctx.Value(AllocatorRuntimeStatsCtxKey) + if ctxValue != nil { + allocatorStats = ctxValue.(*AllocatorRuntimeStats) + ctx = context.WithValue(ctx, tikvutil.CommitDetailCtxKey, &commitDetail) + } + return ctx, allocatorStats, &commitDetail +} + +// alloc4Sequence is used to alloc value for sequence, there are several aspects different from autoid logic. +// 1: sequence allocation don't need check rebase. +// 2: sequence allocation don't need auto step. +// 3: sequence allocation may have negative growth. +// 4: sequence allocation batch length can be dissatisfied. +// 5: sequence batch allocation will be consumed immediately. +func (alloc *allocator) alloc4Sequence() (min int64, max int64, round int64, err error) { + increment := alloc.sequence.Increment + offset := alloc.sequence.Start + minValue := alloc.sequence.MinValue + maxValue := alloc.sequence.MaxValue + cacheSize := alloc.sequence.CacheValue + if !alloc.sequence.Cache { + cacheSize = 1 + } + + var newBase, newEnd int64 + startTime := time.Now() + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnMeta) + err = kv.RunInNewTxn(ctx, alloc.store, true, func(_ context.Context, txn kv.Transaction) error { + acc := meta.NewMeta(txn).GetAutoIDAccessors(alloc.dbID, alloc.tbID) + var ( + err1 error + seqStep int64 + ) + // Get the real offset if the sequence is in cycle. + // round is used to count cycle times in sequence with cycle option. + if alloc.sequence.Cycle { + // GetSequenceCycle is used to get the flag `round`, which indicates whether the sequence is already in cycle. + round, err1 = acc.SequenceCycle().Get() + if err1 != nil { + return err1 + } + if round > 0 { + if increment > 0 { + offset = alloc.sequence.MinValue + } else { + offset = alloc.sequence.MaxValue + } + } + } + + // Get the global new base. + newBase, err1 = acc.SequenceValue().Get() + if err1 != nil { + return err1 + } + + // CalcNeededBatchSize calculates the total batch size needed. + seqStep, err1 = CalcSequenceBatchSize(newBase, cacheSize, increment, offset, minValue, maxValue) + + if err1 != nil && err1 == ErrAutoincReadFailed { + if !alloc.sequence.Cycle { + return err1 + } + // Reset the sequence base and offset. + if alloc.sequence.Increment > 0 { + newBase = alloc.sequence.MinValue - 1 + offset = alloc.sequence.MinValue + } else { + newBase = alloc.sequence.MaxValue + 1 + offset = alloc.sequence.MaxValue + } + err1 = acc.SequenceValue().Put(newBase) + if err1 != nil { + return err1 + } + + // Reset sequence round state value. + round++ + // SetSequenceCycle is used to store the flag `round` which indicates whether the sequence is already in cycle. 
+ // round > 0 means the sequence is already in cycle, so the offset should be minvalue / maxvalue rather than sequence.start. + // TiDB is a stateless node, it should know whether the sequence is already in cycle when restart. + err1 = acc.SequenceCycle().Put(round) + if err1 != nil { + return err1 + } + + // Recompute the sequence next batch size. + seqStep, err1 = CalcSequenceBatchSize(newBase, cacheSize, increment, offset, minValue, maxValue) + if err1 != nil { + return err1 + } + } + var delta int64 + if alloc.sequence.Increment > 0 { + delta = seqStep + } else { + delta = -seqStep + } + newEnd, err1 = acc.SequenceValue().Inc(delta) + return err1 + }) + + // TODO: sequence metrics + metrics.AutoIDHistogram.WithLabelValues(metrics.TableAutoIDAlloc, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) + if err != nil { + return 0, 0, 0, err + } + logutil.Logger(context.TODO()).Debug("alloc sequence value", + zap.Uint64(" from value", uint64(newBase)), + zap.Uint64("to value", uint64(newEnd)), + zap.Int64("table ID", alloc.tbID), + zap.Int64("database ID", alloc.dbID)) + return newBase, newEnd, round, nil +} + +func (alloc *allocator) getIDAccessor(txn kv.Transaction) meta.AutoIDAccessor { + acc := meta.NewMeta(txn).GetAutoIDAccessors(alloc.dbID, alloc.tbID) + switch alloc.allocType { + case RowIDAllocType: + return acc.RowID() + case AutoIncrementType: + return acc.IncrementID(alloc.tbVersion) + case AutoRandomType: + return acc.RandomID() + case SequenceType: + return acc.SequenceValue() + } + return nil +} + +const signMask uint64 = 0x8000000000000000 + +// EncodeIntToCmpUint make int v to comparable uint type +func EncodeIntToCmpUint(v int64) uint64 { + return uint64(v) ^ signMask +} + +// DecodeCmpUintToInt decodes the u that encoded by EncodeIntToCmpUint +func DecodeCmpUintToInt(u uint64) int64 { + return int64(u ^ signMask) +} + +// TestModifyBaseAndEndInjection exported for testing modifying the base and end. +func TestModifyBaseAndEndInjection(alloc Allocator, base, end int64) { + alloc.(*allocator).mu.Lock() + alloc.(*allocator).base = base + alloc.(*allocator).end = end + alloc.(*allocator).mu.Unlock() +} + +// ShardIDFormat is used to calculate the bit length of different segments in auto id. +// Generally, an auto id is consist of 4 segments: sign bit, reserved bits, shard bits and incremental bits. +// Take "a BIGINT AUTO_INCREMENT PRIMARY KEY" as an example, assume that the `shard_row_id_bits` = 5, +// the layout is like +// +// | [sign_bit] (1 bit) | [reserved bits] (0 bits) | [shard_bits] (5 bits) | [incremental_bits] (64-1-5=58 bits) | +// +// Please always use NewShardIDFormat() to instantiate. +type ShardIDFormat struct { + FieldType *types.FieldType + ShardBits uint64 + // Derived fields. + IncrementalBits uint64 +} + +// NewShardIDFormat create an instance of ShardIDFormat. +// RangeBits means the bit length of the sign bit + shard bits + incremental bits. +// If RangeBits is 0, it will be calculated according to field type automatically. +func NewShardIDFormat(fieldType *types.FieldType, shardBits, rangeBits uint64) ShardIDFormat { + var incrementalBits uint64 + if rangeBits == 0 { + // Zero means that the range bits is not specified. We interpret it as the length of BIGINT. 
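+		// (Added commentary, illustrative values) e.g. a signed BIGINT with shardBits = 5:
+		// 64 - 5 bits here, minus one sign bit below, leaves 58 incremental bits, and
+		// Compose(3, 7) == (3 << 58) | 7.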
+ incrementalBits = RowIDBitLength - shardBits + } else { + incrementalBits = rangeBits - shardBits + } + hasSignBit := !mysql.HasUnsignedFlag(fieldType.GetFlag()) + if hasSignBit { + incrementalBits-- + } + return ShardIDFormat{ + FieldType: fieldType, + ShardBits: shardBits, + IncrementalBits: incrementalBits, + } +} + +// IncrementalBitsCapacity returns the max capacity of incremental section of the current format. +func (s *ShardIDFormat) IncrementalBitsCapacity() uint64 { + return uint64(s.IncrementalMask()) +} + +// IncrementalMask returns 00..0[11..1], where [11..1] is the incremental part of the current format. +func (s *ShardIDFormat) IncrementalMask() int64 { + return (1 << s.IncrementalBits) - 1 +} + +// Compose generates an auto ID based on the given shard and an incremental ID. +func (s *ShardIDFormat) Compose(shard int64, id int64) int64 { + return ((shard & ((1 << s.ShardBits) - 1)) << s.IncrementalBits) | id +} + +type allocatorRuntimeStatsCtxKeyType struct{} + +// AllocatorRuntimeStatsCtxKey is the context key of allocator runtime stats. +var AllocatorRuntimeStatsCtxKey = allocatorRuntimeStatsCtxKeyType{} + +// AllocatorRuntimeStats is the execution stats of auto id allocator. +type AllocatorRuntimeStats struct { + *txnsnapshot.SnapshotRuntimeStats + *execdetails.RuntimeStatsWithCommit + allocCount int + rebaseCount int +} + +// NewAllocatorRuntimeStats return a new AllocatorRuntimeStats. +func NewAllocatorRuntimeStats() *AllocatorRuntimeStats { + return &AllocatorRuntimeStats{ + SnapshotRuntimeStats: &txnsnapshot.SnapshotRuntimeStats{}, + } +} + +func (e *AllocatorRuntimeStats) mergeCommitDetail(detail *tikvutil.CommitDetails) { + if detail == nil { + return + } + if e.RuntimeStatsWithCommit == nil { + e.RuntimeStatsWithCommit = &execdetails.RuntimeStatsWithCommit{} + } + e.RuntimeStatsWithCommit.MergeCommitDetails(detail) +} + +// String implements the RuntimeStats interface. +func (e *AllocatorRuntimeStats) String() string { + if e.allocCount == 0 && e.rebaseCount == 0 { + return "" + } + var buf bytes.Buffer + buf.WriteString("auto_id_allocator: {") + initialSize := buf.Len() + if e.allocCount > 0 { + buf.WriteString("alloc_cnt: ") + buf.WriteString(strconv.FormatInt(int64(e.allocCount), 10)) + } + if e.rebaseCount > 0 { + if buf.Len() > initialSize { + buf.WriteString(", ") + } + buf.WriteString("rebase_cnt: ") + buf.WriteString(strconv.FormatInt(int64(e.rebaseCount), 10)) + } + if e.SnapshotRuntimeStats != nil { + stats := e.SnapshotRuntimeStats.String() + if stats != "" { + if buf.Len() > initialSize { + buf.WriteString(", ") + } + buf.WriteString(e.SnapshotRuntimeStats.String()) + } + } + if e.RuntimeStatsWithCommit != nil { + stats := e.RuntimeStatsWithCommit.String() + if stats != "" { + if buf.Len() > initialSize { + buf.WriteString(", ") + } + buf.WriteString(stats) + } + } + buf.WriteString("}") + return buf.String() +} + +// Clone implements the RuntimeStats interface. +func (e *AllocatorRuntimeStats) Clone() *AllocatorRuntimeStats { + newRs := &AllocatorRuntimeStats{ + allocCount: e.allocCount, + rebaseCount: e.rebaseCount, + } + if e.SnapshotRuntimeStats != nil { + snapshotStats := e.SnapshotRuntimeStats.Clone() + newRs.SnapshotRuntimeStats = snapshotStats + } + if e.RuntimeStatsWithCommit != nil { + newRs.RuntimeStatsWithCommit = e.RuntimeStatsWithCommit.Clone().(*execdetails.RuntimeStatsWithCommit) + } + return newRs +} + +// Merge implements the RuntimeStats interface. 
+func (e *AllocatorRuntimeStats) Merge(other *AllocatorRuntimeStats) { + if other == nil { + return + } + if other.SnapshotRuntimeStats != nil { + if e.SnapshotRuntimeStats == nil { + e.SnapshotRuntimeStats = other.SnapshotRuntimeStats.Clone() + } else { + e.SnapshotRuntimeStats.Merge(other.SnapshotRuntimeStats) + } + } + if other.RuntimeStatsWithCommit != nil { + if e.RuntimeStatsWithCommit == nil { + e.RuntimeStatsWithCommit = other.RuntimeStatsWithCommit.Clone().(*execdetails.RuntimeStatsWithCommit) + } else { + e.RuntimeStatsWithCommit.Merge(other.RuntimeStatsWithCommit) + } + } +} diff --git a/pkg/meta/autoid/binding__failpoint_binding__.go b/pkg/meta/autoid/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..2c1025c7f434f --- /dev/null +++ b/pkg/meta/autoid/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package autoid + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/owner/binding__failpoint_binding__.go b/pkg/owner/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..6f8eac02d8e5b --- /dev/null +++ b/pkg/owner/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package owner + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/owner/manager.go b/pkg/owner/manager.go index cfe227e580699..4453f88bca07b 100644 --- a/pkg/owner/manager.go +++ b/pkg/owner/manager.go @@ -273,12 +273,12 @@ func (m *ownerManager) campaignLoop(etcdSession *concurrency.Session) { } m.sessionLease.Store(int64(etcdSession.Lease())) case <-campaignContext.Done(): - failpoint.Inject("MockDelOwnerKey", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("MockDelOwnerKey")); _err_ == nil { if v.(string) == "delOwnerKeyAndNotOwner" { logutil.Logger(logCtx).Info("mock break campaign and don't clear related info") return } - }) + } logutil.Logger(logCtx).Info("break campaign loop, context is done") m.revokeSession(logPrefix, etcdSession.Lease()) return @@ -408,13 +408,13 @@ func (m *ownerManager) SetOwnerOpValue(ctx context.Context, op OpType) error { } newOwnerVal := joinOwnerValues(ownerID, []byte{byte(op)}) - failpoint.Inject("MockDelOwnerKey", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("MockDelOwnerKey")); _err_ == nil { if valStr, ok := v.(string); ok { if err := mockDelOwnerKey(valStr, ownerKey, m); err != nil { - failpoint.Return(err) + return err } } - }) + } leaseOp := clientv3.WithLease(clientv3.LeaseID(m.sessionLease.Load())) resp, err := m.etcdCli.Txn(ctx). diff --git a/pkg/owner/manager.go__failpoint_stash__ b/pkg/owner/manager.go__failpoint_stash__ new file mode 100644 index 0000000000000..cfe227e580699 --- /dev/null +++ b/pkg/owner/manager.go__failpoint_stash__ @@ -0,0 +1,486 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package owner + +import ( + "bytes" + "context" + "fmt" + "os" + "strconv" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/ddl/util" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/terror" + util2 "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.etcd.io/etcd/api/v3/mvccpb" + "go.etcd.io/etcd/api/v3/v3rpc/rpctypes" + clientv3 "go.etcd.io/etcd/client/v3" + "go.etcd.io/etcd/client/v3/concurrency" + atomicutil "go.uber.org/atomic" + "go.uber.org/zap" +) + +// Listener is used to listen the ownerManager's owner state. +type Listener interface { + OnBecomeOwner() + OnRetireOwner() +} + +// Manager is used to campaign the owner and manage the owner information. +type Manager interface { + // ID returns the ID of the manager. + ID() string + // IsOwner returns whether the ownerManager is the owner. + IsOwner() bool + // RetireOwner make the manager to be a not owner. It's exported for testing. + RetireOwner() + // GetOwnerID gets the owner ID. + GetOwnerID(ctx context.Context) (string, error) + // SetOwnerOpValue updates the owner op value. + SetOwnerOpValue(ctx context.Context, op OpType) error + // CampaignOwner campaigns the owner. + CampaignOwner(...int) error + // ResignOwner lets the owner start a new election. + ResignOwner(ctx context.Context) error + // Cancel cancels this etcd ownerManager. + Cancel() + // RequireOwner requires the ownerManager is owner. + RequireOwner(ctx context.Context) error + // CampaignCancel cancels one etcd campaign + CampaignCancel() + // SetListener sets the listener, set before CampaignOwner. + SetListener(listener Listener) +} + +const ( + keyOpDefaultTimeout = 5 * time.Second +) + +// OpType is the owner key value operation type. +type OpType byte + +// List operation of types. +const ( + OpNone OpType = 0 + OpSyncUpgradingState OpType = 1 +) + +// String implements fmt.Stringer interface. +func (ot OpType) String() string { + switch ot { + case OpSyncUpgradingState: + return "sync upgrading state" + default: + return "none" + } +} + +// IsSyncedUpgradingState represents whether the upgrading state is synchronized. +func (ot OpType) IsSyncedUpgradingState() bool { + return ot == OpSyncUpgradingState +} + +// DDLOwnerChecker is used to check whether tidb is owner. +type DDLOwnerChecker interface { + // IsOwner returns whether the ownerManager is the owner. + IsOwner() bool +} + +// ownerManager represents the structure which is used for electing owner. +type ownerManager struct { + id string // id is the ID of the manager. + key string + ctx context.Context + prompt string + logPrefix string + logCtx context.Context + etcdCli *clientv3.Client + cancel context.CancelFunc + elec atomic.Pointer[concurrency.Election] + sessionLease *atomicutil.Int64 + wg sync.WaitGroup + campaignCancel context.CancelFunc + + listener Listener +} + +// NewOwnerManager creates a new Manager. 
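+//
+// Hypothetical usage sketch (added commentary; the prompt and key values are illustrative):
+//
+//	mgr := NewOwnerManager(ctx, etcdCli, "ddl", serverID, "/tidb/ddl/fg/owner")
+//	err := mgr.CampaignOwner()
+//
+// All nodes that campaign under the same key participate in one etcd election.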
+func NewOwnerManager(ctx context.Context, etcdCli *clientv3.Client, prompt, id, key string) Manager { + logPrefix := fmt.Sprintf("[%s] %s ownerManager %s", prompt, key, id) + ctx, cancelFunc := context.WithCancel(ctx) + return &ownerManager{ + etcdCli: etcdCli, + id: id, + key: key, + ctx: ctx, + prompt: prompt, + cancel: cancelFunc, + logPrefix: logPrefix, + logCtx: logutil.WithKeyValue(context.Background(), "owner info", logPrefix), + sessionLease: atomicutil.NewInt64(0), + } +} + +// ID implements Manager.ID interface. +func (m *ownerManager) ID() string { + return m.id +} + +// IsOwner implements Manager.IsOwner interface. +func (m *ownerManager) IsOwner() bool { + return m.elec.Load() != nil +} + +// Cancel implements Manager.Cancel interface. +func (m *ownerManager) Cancel() { + m.cancel() + m.wg.Wait() +} + +// RequireOwner implements Manager.RequireOwner interface. +func (*ownerManager) RequireOwner(_ context.Context) error { + return nil +} + +func (m *ownerManager) SetListener(listener Listener) { + m.listener = listener +} + +// ManagerSessionTTL is the etcd session's TTL in seconds. It's exported for testing. +var ManagerSessionTTL = 60 + +// setManagerSessionTTL sets the ManagerSessionTTL value, it's used for testing. +func setManagerSessionTTL() error { + ttlStr := os.Getenv("tidb_manager_ttl") + if ttlStr == "" { + return nil + } + ttl, err := strconv.Atoi(ttlStr) + if err != nil { + return errors.Trace(err) + } + ManagerSessionTTL = ttl + return nil +} + +// CampaignOwner implements Manager.CampaignOwner interface. +func (m *ownerManager) CampaignOwner(withTTL ...int) error { + ttl := ManagerSessionTTL + if len(withTTL) == 1 { + ttl = withTTL[0] + } + logPrefix := fmt.Sprintf("[%s] %s", m.prompt, m.key) + logutil.BgLogger().Info("start campaign owner", zap.String("ownerInfo", logPrefix)) + session, err := util2.NewSession(m.ctx, logPrefix, m.etcdCli, util2.NewSessionDefaultRetryCnt, ttl) + if err != nil { + return errors.Trace(err) + } + m.sessionLease.Store(int64(session.Lease())) + m.wg.Add(1) + go m.campaignLoop(session) + return nil +} + +// ResignOwner lets the owner start a new election. +func (m *ownerManager) ResignOwner(ctx context.Context) error { + elec := m.elec.Load() + if elec == nil { + return errors.Errorf("This node is not a owner, can't be resigned") + } + + childCtx, cancel := context.WithTimeout(ctx, keyOpDefaultTimeout) + err := elec.Resign(childCtx) + cancel() + if err != nil { + return errors.Trace(err) + } + + logutil.Logger(m.logCtx).Warn("resign owner success") + return nil +} + +func (m *ownerManager) toBeOwner(elec *concurrency.Election) { + m.elec.Store(elec) + logutil.Logger(m.logCtx).Info("become owner") + if m.listener != nil { + m.listener.OnBecomeOwner() + } +} + +// RetireOwner make the manager to be a not owner. +func (m *ownerManager) RetireOwner() { + m.elec.Store(nil) + logutil.Logger(m.logCtx).Info("retire owner") + if m.listener != nil { + m.listener.OnRetireOwner() + } +} + +// CampaignCancel implements Manager.CampaignCancel interface. 
+func (m *ownerManager) CampaignCancel() { + m.campaignCancel() + m.wg.Wait() +} + +func (m *ownerManager) campaignLoop(etcdSession *concurrency.Session) { + var campaignContext context.Context + campaignContext, m.campaignCancel = context.WithCancel(m.ctx) + defer func() { + m.campaignCancel() + if r := recover(); r != nil { + logutil.BgLogger().Error("recover panic", zap.String("prompt", m.prompt), zap.Any("error", r), zap.Stack("buffer")) + metrics.PanicCounter.WithLabelValues(metrics.LabelDDLOwner).Inc() + } + m.wg.Done() + }() + + logPrefix := m.logPrefix + logCtx := m.logCtx + var err error + for { + if err != nil { + metrics.CampaignOwnerCounter.WithLabelValues(m.prompt, err.Error()).Inc() + } + + select { + case <-etcdSession.Done(): + logutil.Logger(logCtx).Info("etcd session is done, creates a new one") + leaseID := etcdSession.Lease() + etcdSession, err = util2.NewSession(campaignContext, logPrefix, m.etcdCli, util2.NewSessionRetryUnlimited, ManagerSessionTTL) + if err != nil { + logutil.Logger(logCtx).Info("break campaign loop, NewSession failed", zap.Error(err)) + m.revokeSession(logPrefix, leaseID) + return + } + m.sessionLease.Store(int64(etcdSession.Lease())) + case <-campaignContext.Done(): + failpoint.Inject("MockDelOwnerKey", func(v failpoint.Value) { + if v.(string) == "delOwnerKeyAndNotOwner" { + logutil.Logger(logCtx).Info("mock break campaign and don't clear related info") + return + } + }) + logutil.Logger(logCtx).Info("break campaign loop, context is done") + m.revokeSession(logPrefix, etcdSession.Lease()) + return + default: + } + // If the etcd server turns clocks forward,the following case may occur. + // The etcd server deletes this session's lease ID, but etcd session doesn't find it. + // In this time if we do the campaign operation, the etcd server will return ErrLeaseNotFound. + if terror.ErrorEqual(err, rpctypes.ErrLeaseNotFound) { + if etcdSession != nil { + err = etcdSession.Close() + logutil.Logger(logCtx).Info("etcd session encounters the error of lease not found, closes it", zap.Error(err)) + } + continue + } + + elec := concurrency.NewElection(etcdSession, m.key) + err = elec.Campaign(campaignContext, m.id) + if err != nil { + logutil.Logger(logCtx).Info("failed to campaign", zap.Error(err)) + continue + } + + ownerKey, err := GetOwnerKey(campaignContext, logCtx, m.etcdCli, m.key, m.id) + if err != nil { + continue + } + + m.toBeOwner(elec) + m.watchOwner(campaignContext, etcdSession, ownerKey) + m.RetireOwner() + + metrics.CampaignOwnerCounter.WithLabelValues(m.prompt, metrics.NoLongerOwner).Inc() + logutil.Logger(logCtx).Warn("is not the owner") + } +} + +func (m *ownerManager) revokeSession(_ string, leaseID clientv3.LeaseID) { + // Revoke the session lease. + // If revoke takes longer than the ttl, lease is expired anyway. + cancelCtx, cancel := context.WithTimeout(context.Background(), + time.Duration(ManagerSessionTTL)*time.Second) + _, err := m.etcdCli.Revoke(cancelCtx, leaseID) + cancel() + logutil.Logger(m.logCtx).Info("revoke session", zap.Error(err)) +} + +// GetOwnerID implements Manager.GetOwnerID interface. 
+func (m *ownerManager) GetOwnerID(ctx context.Context) (string, error) { + _, ownerID, _, _, err := getOwnerInfo(ctx, m.logCtx, m.etcdCli, m.key) + return string(ownerID), errors.Trace(err) +} + +func getOwnerInfo(ctx, logCtx context.Context, etcdCli *clientv3.Client, ownerPath string) (string, []byte, OpType, int64, error) { + var op OpType + var resp *clientv3.GetResponse + var err error + for i := 0; i < 3; i++ { + if err = ctx.Err(); err != nil { + return "", nil, op, 0, errors.Trace(err) + } + + childCtx, cancel := context.WithTimeout(ctx, util.KeyOpDefaultTimeout) + resp, err = etcdCli.Get(childCtx, ownerPath, clientv3.WithFirstCreate()...) + cancel() + if err == nil { + break + } + logutil.Logger(logCtx).Info("etcd-cli get owner info failed", zap.String("key", ownerPath), zap.Int("retryCnt", i), zap.Error(err)) + time.Sleep(util.KeyOpRetryInterval) + } + if err != nil { + logutil.Logger(logCtx).Warn("etcd-cli get owner info failed", zap.Error(err)) + return "", nil, op, 0, errors.Trace(err) + } + if len(resp.Kvs) == 0 { + return "", nil, op, 0, concurrency.ErrElectionNoLeader + } + + var ownerID []byte + ownerID, op = splitOwnerValues(resp.Kvs[0].Value) + logutil.Logger(logCtx).Info("get owner", zap.ByteString("owner key", resp.Kvs[0].Key), + zap.ByteString("ownerID", ownerID), zap.Stringer("op", op)) + return string(resp.Kvs[0].Key), ownerID, op, resp.Kvs[0].ModRevision, nil +} + +// GetOwnerKey gets the owner key information. +func GetOwnerKey(ctx, logCtx context.Context, etcdCli *clientv3.Client, etcdKey, id string) (string, error) { + ownerKey, ownerID, _, _, err := getOwnerInfo(ctx, logCtx, etcdCli, etcdKey) + if err != nil { + return "", errors.Trace(err) + } + if string(ownerID) != id { + logutil.Logger(logCtx).Warn("is not the owner") + return "", errors.New("ownerInfoNotMatch") + } + + return ownerKey, nil +} + +func splitOwnerValues(val []byte) ([]byte, OpType) { + vals := bytes.Split(val, []byte("_")) + var op OpType + if len(vals) == 2 { + op = OpType(vals[1][0]) + } + return vals[0], op +} + +func joinOwnerValues(vals ...[]byte) []byte { + return bytes.Join(vals, []byte("_")) +} + +// SetOwnerOpValue implements Manager.SetOwnerOpValue interface. +func (m *ownerManager) SetOwnerOpValue(ctx context.Context, op OpType) error { + // owner don't change. + ownerKey, ownerID, currOp, modRevision, err := getOwnerInfo(ctx, m.logCtx, m.etcdCli, m.key) + if err != nil { + return errors.Trace(err) + } + if currOp == op { + logutil.Logger(m.logCtx).Info("set owner op is the same as the original, so do nothing.", zap.Stringer("op", op)) + return nil + } + if string(ownerID) != m.id { + return errors.New("ownerInfoNotMatch") + } + newOwnerVal := joinOwnerValues(ownerID, []byte{byte(op)}) + + failpoint.Inject("MockDelOwnerKey", func(v failpoint.Value) { + if valStr, ok := v.(string); ok { + if err := mockDelOwnerKey(valStr, ownerKey, m); err != nil { + failpoint.Return(err) + } + } + }) + + leaseOp := clientv3.WithLease(clientv3.LeaseID(m.sessionLease.Load())) + resp, err := m.etcdCli.Txn(ctx). + If(clientv3.Compare(clientv3.ModRevision(ownerKey), "=", modRevision)). + Then(clientv3.OpPut(ownerKey, string(newOwnerVal), leaseOp)). 
+ Commit() + if err == nil && !resp.Succeeded { + err = errors.New("put owner key failed, cmp is false") + } + logutil.BgLogger().Info("set owner op value", zap.String("owner key", ownerKey), zap.ByteString("ownerID", ownerID), + zap.Stringer("old Op", currOp), zap.Stringer("op", op), zap.Error(err)) + metrics.WatchOwnerCounter.WithLabelValues(m.prompt, metrics.PutValue+"_"+metrics.RetLabel(err)).Inc() + return errors.Trace(err) +} + +// GetOwnerOpValue gets the owner op value. +func GetOwnerOpValue(ctx context.Context, etcdCli *clientv3.Client, ownerPath, logPrefix string) (OpType, error) { + // It's using for testing. + if etcdCli == nil { + return *mockOwnerOpValue.Load(), nil + } + + logCtx := logutil.WithKeyValue(context.Background(), "owner info", logPrefix) + _, _, op, _, err := getOwnerInfo(ctx, logCtx, etcdCli, ownerPath) + return op, errors.Trace(err) +} + +func (m *ownerManager) watchOwner(ctx context.Context, etcdSession *concurrency.Session, key string) { + logPrefix := fmt.Sprintf("[%s] ownerManager %s watch owner key %v", m.prompt, m.id, key) + logCtx := logutil.WithKeyValue(context.Background(), "owner info", logPrefix) + logutil.BgLogger().Debug(logPrefix) + watchCh := m.etcdCli.Watch(ctx, key) + for { + select { + case resp, ok := <-watchCh: + if !ok { + metrics.WatchOwnerCounter.WithLabelValues(m.prompt, metrics.WatcherClosed).Inc() + logutil.Logger(logCtx).Info("watcher is closed, no owner") + return + } + if resp.Canceled { + metrics.WatchOwnerCounter.WithLabelValues(m.prompt, metrics.Cancelled).Inc() + logutil.Logger(logCtx).Info("watch canceled, no owner") + return + } + + for _, ev := range resp.Events { + if ev.Type == mvccpb.DELETE { + metrics.WatchOwnerCounter.WithLabelValues(m.prompt, metrics.Deleted).Inc() + logutil.Logger(logCtx).Info("watch failed, owner is deleted") + return + } + } + case <-etcdSession.Done(): + metrics.WatchOwnerCounter.WithLabelValues(m.prompt, metrics.SessionDone).Inc() + return + case <-ctx.Done(): + metrics.WatchOwnerCounter.WithLabelValues(m.prompt, metrics.CtxDone).Inc() + return + } + } +} + +func init() { + err := setManagerSessionTTL() + if err != nil { + logutil.BgLogger().Warn("set manager session TTL failed", zap.Error(err)) + } +} diff --git a/pkg/owner/mock.go b/pkg/owner/mock.go index 75a934307b41e..372f587e385f0 100644 --- a/pkg/owner/mock.go +++ b/pkg/owner/mock.go @@ -123,11 +123,11 @@ func (m *mockManager) GetOwnerID(_ context.Context) (string, error) { } func (*mockManager) SetOwnerOpValue(_ context.Context, op OpType) error { - failpoint.Inject("MockNotSetOwnerOp", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("MockNotSetOwnerOp")); _err_ == nil { if val.(bool) { - failpoint.Return(nil) + return nil } - }) + } mockOwnerOpValue.Store(&op) return nil } diff --git a/pkg/owner/mock.go__failpoint_stash__ b/pkg/owner/mock.go__failpoint_stash__ new file mode 100644 index 0000000000000..75a934307b41e --- /dev/null +++ b/pkg/owner/mock.go__failpoint_stash__ @@ -0,0 +1,230 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package owner + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/ddl/util" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/timeutil" + "go.uber.org/zap" +) + +var _ Manager = &mockManager{} + +// mockManager represents the structure which is used for electing owner. +// It's used for local store and testing. +// So this worker will always be the owner. +type mockManager struct { + id string // id is the ID of manager. + storeID string + key string + ctx context.Context + wg sync.WaitGroup + cancel context.CancelFunc + listener Listener + retireHook func() + campaignDone chan struct{} + resignDone chan struct{} +} + +var mockOwnerOpValue atomic.Pointer[OpType] + +// NewMockManager creates a new mock Manager. +func NewMockManager(ctx context.Context, id string, store kv.Storage, ownerKey string) Manager { + cancelCtx, cancelFunc := context.WithCancel(ctx) + storeID := "mock_store_id" + if store != nil { + storeID = store.UUID() + } + + // Make sure the mockOwnerOpValue is initialized before GetOwnerOpValue in bootstrap. + op := OpNone + mockOwnerOpValue.Store(&op) + return &mockManager{ + id: id, + storeID: storeID, + key: ownerKey, + ctx: cancelCtx, + cancel: cancelFunc, + campaignDone: make(chan struct{}), + resignDone: make(chan struct{}), + } +} + +// ID implements Manager.ID interface. +func (m *mockManager) ID() string { + return m.id +} + +// IsOwner implements Manager.IsOwner interface. +func (m *mockManager) IsOwner() bool { + logutil.BgLogger().Debug("owner manager checks owner", + zap.String("ownerKey", m.key), zap.String("ID", m.id)) + return util.MockGlobalStateEntry.OwnerKey(m.storeID, m.key).IsOwner(m.id) +} + +func (m *mockManager) toBeOwner() { + ok := util.MockGlobalStateEntry.OwnerKey(m.storeID, m.key).SetOwner(m.id) + if ok { + logutil.BgLogger().Info("owner manager gets owner", + zap.String("ownerKey", m.key), zap.String("ID", m.id)) + if m.listener != nil { + m.listener.OnBecomeOwner() + } + } +} + +// RetireOwner implements Manager.RetireOwner interface. +func (m *mockManager) RetireOwner() { + ok := util.MockGlobalStateEntry.OwnerKey(m.storeID, m.key).UnsetOwner(m.id) + if ok { + logutil.BgLogger().Info("owner manager retire owner", + zap.String("ownerKey", m.key), zap.String("ID", m.id)) + if m.listener != nil { + m.listener.OnRetireOwner() + } + } +} + +// Cancel implements Manager.Cancel interface. +func (m *mockManager) Cancel() { + m.cancel() + m.wg.Wait() + logutil.BgLogger().Info("owner manager is canceled", + zap.String("ownerKey", m.key), zap.String("ID", m.id)) +} + +// GetOwnerID implements Manager.GetOwnerID interface. +func (m *mockManager) GetOwnerID(_ context.Context) (string, error) { + if m.IsOwner() { + return m.ID(), nil + } + return "", errors.New("no owner") +} + +func (*mockManager) SetOwnerOpValue(_ context.Context, op OpType) error { + failpoint.Inject("MockNotSetOwnerOp", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(nil) + } + }) + mockOwnerOpValue.Store(&op) + return nil +} + +// CampaignOwner implements Manager.CampaignOwner interface. 
+func (m *mockManager) CampaignOwner(_ ...int) error { + m.wg.Add(1) + go func() { + logutil.BgLogger().Debug("owner manager campaign owner", + zap.String("ownerKey", m.key), zap.String("ID", m.id)) + defer m.wg.Done() + for { + select { + case <-m.campaignDone: + m.RetireOwner() + logutil.BgLogger().Debug("owner manager campaign done", zap.String("ID", m.id)) + return + case <-m.ctx.Done(): + m.RetireOwner() + logutil.BgLogger().Debug("owner manager is cancelled", zap.String("ID", m.id)) + return + case <-m.resignDone: + m.RetireOwner() + //nolint: errcheck + timeutil.Sleep(m.ctx, 1*time.Second) // Give a chance to the other owner managers to get owner. + default: + m.toBeOwner() + //nolint: errcheck + timeutil.Sleep(m.ctx, 1*time.Second) // Speed up domain.Close() + logutil.BgLogger().Debug("owner manager tick", zap.String("ID", m.id), + zap.String("ownerKey", m.key), zap.String("currentOwner", util.MockGlobalStateEntry.OwnerKey(m.storeID, m.key).GetOwner())) + } + } + }() + return nil +} + +// ResignOwner lets the owner start a new election. +func (m *mockManager) ResignOwner(_ context.Context) error { + m.resignDone <- struct{}{} + return nil +} + +// RequireOwner implements Manager.RequireOwner interface. +func (*mockManager) RequireOwner(context.Context) error { + return nil +} + +// SetListener implements Manager.SetListener interface. +func (m *mockManager) SetListener(listener Listener) { + m.listener = listener +} + +// CampaignCancel implements Manager.CampaignCancel interface +func (m *mockManager) CampaignCancel() { + m.campaignDone <- struct{}{} +} + +func mockDelOwnerKey(mockCal, ownerKey string, m *ownerManager) error { + checkIsOwner := func(m *ownerManager, checkTrue bool) error { + // 5s + for i := 0; i < 100; i++ { + if m.IsOwner() == checkTrue { + break + } + time.Sleep(50 * time.Millisecond) + } + if m.IsOwner() != checkTrue { + return errors.Errorf("expect manager state:%v", checkTrue) + } + return nil + } + + needCheckOwner := false + switch mockCal { + case "delOwnerKeyAndNotOwner": + m.CampaignCancel() + // Make sure the manager is not owner. And it will exit campaignLoop. + err := checkIsOwner(m, false) + if err != nil { + return err + } + case "onlyDelOwnerKey": + needCheckOwner = true + } + + err := util.DeleteKeyFromEtcd(ownerKey, m.etcdCli, 1, keyOpDefaultTimeout) + if err != nil { + return errors.Trace(err) + } + if needCheckOwner { + // Mock the manager become not owner because the owner is deleted(like TTL is timeout). + // And then the manager campaigns the owner again, and become the owner. 
+ err = checkIsOwner(m, true) + if err != nil { + return err + } + } + return nil +} diff --git a/pkg/parser/ast/binding__failpoint_binding__.go b/pkg/parser/ast/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..88f2a0560be99 --- /dev/null +++ b/pkg/parser/ast/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package ast + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/parser/ast/misc.go b/pkg/parser/ast/misc.go index fa07d309953d0..13d83ea4d8da3 100644 --- a/pkg/parser/ast/misc.go +++ b/pkg/parser/ast/misc.go @@ -3582,9 +3582,9 @@ func RedactURL(str string) string { return str } scheme := u.Scheme - failpoint.Inject("forceRedactURL", func() { + if _, _err_ := failpoint.Eval(_curpkg_("forceRedactURL")); _err_ == nil { scheme = "s3" - }) + } switch strings.ToLower(scheme) { case "s3", "ks3": values := u.Query() diff --git a/pkg/parser/ast/misc.go__failpoint_stash__ b/pkg/parser/ast/misc.go__failpoint_stash__ new file mode 100644 index 0000000000000..fa07d309953d0 --- /dev/null +++ b/pkg/parser/ast/misc.go__failpoint_stash__ @@ -0,0 +1,4209 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package ast + +import ( + "bytes" + "fmt" + "net/url" + "strconv" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/parser/auth" + "github.com/pingcap/tidb/pkg/parser/format" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" +) + +var ( + _ StmtNode = &AdminStmt{} + _ StmtNode = &AlterUserStmt{} + _ StmtNode = &AlterRangeStmt{} + _ StmtNode = &BeginStmt{} + _ StmtNode = &BinlogStmt{} + _ StmtNode = &CommitStmt{} + _ StmtNode = &CreateUserStmt{} + _ StmtNode = &DeallocateStmt{} + _ StmtNode = &DoStmt{} + _ StmtNode = &ExecuteStmt{} + _ StmtNode = &ExplainStmt{} + _ StmtNode = &GrantStmt{} + _ StmtNode = &PrepareStmt{} + _ StmtNode = &RollbackStmt{} + _ StmtNode = &SetPwdStmt{} + _ StmtNode = &SetRoleStmt{} + _ StmtNode = &SetDefaultRoleStmt{} + _ StmtNode = &SetStmt{} + _ StmtNode = &SetSessionStatesStmt{} + _ StmtNode = &UseStmt{} + _ StmtNode = &FlushStmt{} + _ StmtNode = &KillStmt{} + _ StmtNode = &CreateBindingStmt{} + _ StmtNode = &DropBindingStmt{} + _ StmtNode = &SetBindingStmt{} + _ StmtNode = &ShutdownStmt{} + _ StmtNode = &RestartStmt{} + _ StmtNode = &RenameUserStmt{} + _ StmtNode = &HelpStmt{} + _ StmtNode = &PlanReplayerStmt{} + _ StmtNode = &CompactTableStmt{} + _ StmtNode = &SetResourceGroupStmt{} + + _ Node = &PrivElem{} + _ Node = &VariableAssignment{} +) + +// Isolation level constants. 
+const ( + ReadCommitted = "READ-COMMITTED" + ReadUncommitted = "READ-UNCOMMITTED" + Serializable = "SERIALIZABLE" + RepeatableRead = "REPEATABLE-READ" + + PumpType = "PUMP" + DrainerType = "DRAINER" +) + +// Transaction mode constants. +const ( + Optimistic = "OPTIMISTIC" + Pessimistic = "PESSIMISTIC" +) + +// TypeOpt is used for parsing data type option from SQL. +type TypeOpt struct { + IsUnsigned bool + IsZerofill bool +} + +// FloatOpt is used for parsing floating-point type option from SQL. +// See http://dev.mysql.com/doc/refman/5.7/en/floating-point-types.html +type FloatOpt struct { + Flen int + Decimal int +} + +// AuthOption is used for parsing create use statement. +type AuthOption struct { + // ByAuthString set as true, if AuthString is used for authorization. Otherwise, authorization is done by HashString. + ByAuthString bool + AuthString string + ByHashString bool + HashString string + AuthPlugin string +} + +// Restore implements Node interface. +func (n *AuthOption) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("IDENTIFIED") + if n.AuthPlugin != "" { + ctx.WriteKeyWord(" WITH ") + ctx.WriteString(n.AuthPlugin) + } + if n.ByAuthString { + ctx.WriteKeyWord(" BY ") + ctx.WriteString(n.AuthString) + } else if n.ByHashString { + ctx.WriteKeyWord(" AS ") + ctx.WriteString(n.HashString) + } + return nil +} + +// TraceStmt is a statement to trace what sql actually does at background. +type TraceStmt struct { + stmtNode + + Stmt StmtNode + Format string + + TracePlan bool + TracePlanTarget string +} + +// Restore implements Node interface. +func (n *TraceStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("TRACE ") + if n.TracePlan { + ctx.WriteKeyWord("PLAN ") + if n.TracePlanTarget != "" { + ctx.WriteKeyWord("TARGET") + ctx.WritePlain(" = ") + ctx.WriteString(n.TracePlanTarget) + ctx.WritePlain(" ") + } + } else if n.Format != "row" { + ctx.WriteKeyWord("FORMAT") + ctx.WritePlain(" = ") + ctx.WriteString(n.Format) + ctx.WritePlain(" ") + } + if err := n.Stmt.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore TraceStmt.Stmt") + } + return nil +} + +// Accept implements Node Accept interface. +func (n *TraceStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*TraceStmt) + node, ok := n.Stmt.Accept(v) + if !ok { + return n, false + } + n.Stmt = node.(StmtNode) + return v.Leave(n) +} + +// ExplainForStmt is a statement to provite information about how is SQL statement executeing +// in connection #ConnectionID +// See https://dev.mysql.com/doc/refman/5.7/en/explain.html +type ExplainForStmt struct { + stmtNode + + Format string + ConnectionID uint64 +} + +// Restore implements Node interface. +func (n *ExplainForStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("EXPLAIN ") + ctx.WriteKeyWord("FORMAT ") + ctx.WritePlain("= ") + ctx.WriteString(n.Format) + ctx.WritePlain(" ") + ctx.WriteKeyWord("FOR ") + ctx.WriteKeyWord("CONNECTION ") + ctx.WritePlain(strconv.FormatUint(n.ConnectionID, 10)) + return nil +} + +// Accept implements Node Accept interface. +func (n *ExplainForStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*ExplainForStmt) + return v.Leave(n) +} + +// ExplainStmt is a statement to provide information about how is SQL statement executed +// or get columns information in a table. 
+// See https://dev.mysql.com/doc/refman/5.7/en/explain.html +type ExplainStmt struct { + stmtNode + + Stmt StmtNode + Format string + Analyze bool +} + +// Restore implements Node interface. +func (n *ExplainStmt) Restore(ctx *format.RestoreCtx) error { + if showStmt, ok := n.Stmt.(*ShowStmt); ok { + ctx.WriteKeyWord("DESC ") + if err := showStmt.Table.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore ExplainStmt.ShowStmt.Table") + } + if showStmt.Column != nil { + ctx.WritePlain(" ") + if err := showStmt.Column.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore ExplainStmt.ShowStmt.Column") + } + } + return nil + } + ctx.WriteKeyWord("EXPLAIN ") + if n.Analyze { + ctx.WriteKeyWord("ANALYZE ") + } + if !n.Analyze || strings.ToLower(n.Format) != "row" { + ctx.WriteKeyWord("FORMAT ") + ctx.WritePlain("= ") + ctx.WriteString(n.Format) + ctx.WritePlain(" ") + } + if err := n.Stmt.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore ExplainStmt.Stmt") + } + return nil +} + +// Accept implements Node Accept interface. +func (n *ExplainStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*ExplainStmt) + node, ok := n.Stmt.Accept(v) + if !ok { + return n, false + } + n.Stmt = node.(StmtNode) + return v.Leave(n) +} + +// PlanReplayerStmt is a statement to dump or load information for recreating plans +type PlanReplayerStmt struct { + stmtNode + + Stmt StmtNode + Analyze bool + Load bool + HistoricalStatsInfo *AsOfClause + + // Capture indicates 'plan replayer capture ' + Capture bool + // Remove indicates `plan replayer capture remove + Remove bool + + SQLDigest string + PlanDigest string + + // File is used to store 2 cases: + // 1. plan replayer load 'file'; + // 2. plan replayer dump explain 'file' + File string + + // Fields below are currently useless. + + // Where is the where clause in select statement. + Where ExprNode + // OrderBy is the ordering expression list. + OrderBy *OrderByClause + // Limit is the limit clause. + Limit *Limit +} + +// Restore implements Node interface. 
+func (n *PlanReplayerStmt) Restore(ctx *format.RestoreCtx) error { + if n.Load { + ctx.WriteKeyWord("PLAN REPLAYER LOAD ") + ctx.WriteString(n.File) + return nil + } + if n.Capture { + ctx.WriteKeyWord("PLAN REPLAYER CAPTURE ") + ctx.WriteString(n.SQLDigest) + ctx.WriteKeyWord(" ") + ctx.WriteString(n.PlanDigest) + return nil + } + if n.Remove { + ctx.WriteKeyWord("PLAN REPLAYER CAPTURE REMOVE ") + ctx.WriteString(n.SQLDigest) + ctx.WriteKeyWord(" ") + ctx.WriteString(n.PlanDigest) + return nil + } + + ctx.WriteKeyWord("PLAN REPLAYER DUMP ") + + if n.HistoricalStatsInfo != nil { + ctx.WriteKeyWord("WITH STATS ") + if err := n.HistoricalStatsInfo.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore PlanReplayerStmt.HistoricalStatsInfo") + } + ctx.WriteKeyWord(" ") + } + if n.Analyze { + ctx.WriteKeyWord("EXPLAIN ANALYZE ") + } else { + ctx.WriteKeyWord("EXPLAIN ") + } + if n.Stmt == nil { + if len(n.File) > 0 { + ctx.WriteString(n.File) + return nil + } + ctx.WriteKeyWord("SLOW QUERY") + if n.Where != nil { + ctx.WriteKeyWord(" WHERE ") + if err := n.Where.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore PlanReplayerStmt.Where") + } + } + if n.OrderBy != nil { + ctx.WriteKeyWord(" ") + if err := n.OrderBy.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore PlanReplayerStmt.OrderBy") + } + } + if n.Limit != nil { + ctx.WriteKeyWord(" ") + if err := n.Limit.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore PlanReplayerStmt.Limit") + } + } + return nil + } + if err := n.Stmt.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore PlanReplayerStmt.Stmt") + } + return nil +} + +// Accept implements Node Accept interface. +func (n *PlanReplayerStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + + n = newNode.(*PlanReplayerStmt) + + if n.Load { + return v.Leave(n) + } + + if n.HistoricalStatsInfo != nil { + info, ok := n.HistoricalStatsInfo.Accept(v) + if !ok { + return n, false + } + n.HistoricalStatsInfo = info.(*AsOfClause) + } + + if n.Stmt == nil { + if n.Where != nil { + node, ok := n.Where.Accept(v) + if !ok { + return n, false + } + n.Where = node.(ExprNode) + } + + if n.OrderBy != nil { + node, ok := n.OrderBy.Accept(v) + if !ok { + return n, false + } + n.OrderBy = node.(*OrderByClause) + } + + if n.Limit != nil { + node, ok := n.Limit.Accept(v) + if !ok { + return n, false + } + n.Limit = node.(*Limit) + } + return v.Leave(n) + } + + node, ok := n.Stmt.Accept(v) + if !ok { + return n, false + } + n.Stmt = node.(StmtNode) + return v.Leave(n) +} + +type CompactReplicaKind string + +const ( + // CompactReplicaKindAll means compacting both TiKV and TiFlash replicas. + CompactReplicaKindAll = "ALL" + + // CompactReplicaKindTiFlash means compacting TiFlash replicas. + CompactReplicaKindTiFlash = "TIFLASH" + + // CompactReplicaKindTiKV means compacting TiKV replicas. + CompactReplicaKindTiKV = "TIKV" +) + +// CompactTableStmt is a statement to manually compact a table. +type CompactTableStmt struct { + stmtNode + + Table *TableName + PartitionNames []model.CIStr + ReplicaKind CompactReplicaKind +} + +// Restore implements Node interface. 
+func (n *CompactTableStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("ALTER TABLE ") + n.Table.restoreName(ctx) + + ctx.WriteKeyWord(" COMPACT") + if len(n.PartitionNames) != 0 { + ctx.WriteKeyWord(" PARTITION ") + for i, partition := range n.PartitionNames { + if i != 0 { + ctx.WritePlain(",") + } + ctx.WriteName(partition.O) + } + } + if n.ReplicaKind != CompactReplicaKindAll { + ctx.WriteKeyWord(" ") + // Note: There is only TiFlash replica available now. TiKV will be added later. + ctx.WriteKeyWord(string(n.ReplicaKind)) + ctx.WriteKeyWord(" REPLICA") + } + return nil +} + +// Accept implements Node Accept interface. +func (n *CompactTableStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*CompactTableStmt) + node, ok := n.Table.Accept(v) + if !ok { + return n, false + } + n.Table = node.(*TableName) + return v.Leave(n) +} + +// PrepareStmt is a statement to prepares a SQL statement which contains placeholders, +// and it is executed with ExecuteStmt and released with DeallocateStmt. +// See https://dev.mysql.com/doc/refman/5.7/en/prepare.html +type PrepareStmt struct { + stmtNode + + Name string + SQLText string + SQLVar *VariableExpr +} + +// Restore implements Node interface. +func (n *PrepareStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("PREPARE ") + ctx.WriteName(n.Name) + ctx.WriteKeyWord(" FROM ") + if n.SQLText != "" { + ctx.WriteString(n.SQLText) + return nil + } + if n.SQLVar != nil { + if err := n.SQLVar.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore PrepareStmt.SQLVar") + } + return nil + } + return errors.New("An error occurred while restore PrepareStmt") +} + +// Accept implements Node Accept interface. +func (n *PrepareStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*PrepareStmt) + if n.SQLVar != nil { + node, ok := n.SQLVar.Accept(v) + if !ok { + return n, false + } + n.SQLVar = node.(*VariableExpr) + } + return v.Leave(n) +} + +// DeallocateStmt is a statement to release PreparedStmt. +// See https://dev.mysql.com/doc/refman/5.7/en/deallocate-prepare.html +type DeallocateStmt struct { + stmtNode + + Name string +} + +// Restore implements Node interface. +func (n *DeallocateStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("DEALLOCATE PREPARE ") + ctx.WriteName(n.Name) + return nil +} + +// Accept implements Node Accept interface. +func (n *DeallocateStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*DeallocateStmt) + return v.Leave(n) +} + +// Prepared represents a prepared statement. +type Prepared struct { + Stmt StmtNode + StmtType string +} + +// ExecuteStmt is a statement to execute PreparedStmt. +// See https://dev.mysql.com/doc/refman/5.7/en/execute.html +type ExecuteStmt struct { + stmtNode + + Name string + UsingVars []ExprNode + BinaryArgs interface{} + PrepStmt interface{} // the corresponding prepared statement + IdxInMulti int + + // FromGeneralStmt indicates whether this execute-stmt is converted from a general query. + // e.g. select * from t where a>2 --> execute 'select * from t where a>?' using 2 + FromGeneralStmt bool +} + +// Restore implements Node interface. 
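Aside, not part of the patch: the Restore methods in this file all render an AST node back into SQL text through a format.RestoreCtx. A minimal sketch under that assumption, using the CompactTableStmt shown above:

package main

import (
	"fmt"
	"strings"

	"github.com/pingcap/tidb/pkg/parser/ast"
	"github.com/pingcap/tidb/pkg/parser/format"
	"github.com/pingcap/tidb/pkg/parser/model"
)

func main() {
	// Build the node directly instead of parsing SQL.
	stmt := &ast.CompactTableStmt{
		Table:       &ast.TableName{Name: model.NewCIStr("t")},
		ReplicaKind: ast.CompactReplicaKindTiFlash,
	}
	var sb strings.Builder
	ctx := format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)
	if err := stmt.Restore(ctx); err != nil {
		panic(err)
	}
	// Expected to print something like: ALTER TABLE `t` COMPACT TIFLASH REPLICA
	fmt.Println(sb.String())
}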
+func (n *ExecuteStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("EXECUTE ") + ctx.WriteName(n.Name) + if len(n.UsingVars) > 0 { + ctx.WriteKeyWord(" USING ") + for i, val := range n.UsingVars { + if i != 0 { + ctx.WritePlain(",") + } + if err := val.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore ExecuteStmt.UsingVars index %d", i) + } + } + } + return nil +} + +// Accept implements Node Accept interface. +func (n *ExecuteStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*ExecuteStmt) + for i, val := range n.UsingVars { + node, ok := val.Accept(v) + if !ok { + return n, false + } + n.UsingVars[i] = node.(ExprNode) + } + return v.Leave(n) +} + +// BeginStmt is a statement to start a new transaction. +// See https://dev.mysql.com/doc/refman/5.7/en/commit.html +type BeginStmt struct { + stmtNode + Mode string + CausalConsistencyOnly bool + ReadOnly bool + // AS OF is used to read the data at a specific point of time. + // Should only be used when ReadOnly is true. + AsOf *AsOfClause +} + +// Restore implements Node interface. +func (n *BeginStmt) Restore(ctx *format.RestoreCtx) error { + if n.Mode == "" { + if n.ReadOnly { + ctx.WriteKeyWord("START TRANSACTION READ ONLY") + if n.AsOf != nil { + ctx.WriteKeyWord(" ") + return n.AsOf.Restore(ctx) + } + } else if n.CausalConsistencyOnly { + ctx.WriteKeyWord("START TRANSACTION WITH CAUSAL CONSISTENCY ONLY") + } else { + ctx.WriteKeyWord("START TRANSACTION") + } + } else { + ctx.WriteKeyWord("BEGIN ") + ctx.WriteKeyWord(n.Mode) + } + return nil +} + +// Accept implements Node Accept interface. +func (n *BeginStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + + if n.AsOf != nil { + node, ok := n.AsOf.Accept(v) + if !ok { + return n, false + } + n.AsOf = node.(*AsOfClause) + } + + n = newNode.(*BeginStmt) + return v.Leave(n) +} + +// BinlogStmt is an internal-use statement. +// We just parse and ignore it. +// See http://dev.mysql.com/doc/refman/5.7/en/binlog.html +type BinlogStmt struct { + stmtNode + Str string +} + +// Restore implements Node interface. +func (n *BinlogStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("BINLOG ") + ctx.WriteString(n.Str) + return nil +} + +// Accept implements Node Accept interface. +func (n *BinlogStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*BinlogStmt) + return v.Leave(n) +} + +// CompletionType defines completion_type used in COMMIT and ROLLBACK statements +type CompletionType int8 + +const ( + // CompletionTypeDefault refers to NO_CHAIN + CompletionTypeDefault CompletionType = iota + CompletionTypeChain + CompletionTypeRelease +) + +func (n CompletionType) Restore(ctx *format.RestoreCtx) error { + switch n { + case CompletionTypeDefault: + case CompletionTypeChain: + ctx.WriteKeyWord(" AND CHAIN") + case CompletionTypeRelease: + ctx.WriteKeyWord(" RELEASE") + } + return nil +} + +// CommitStmt is a statement to commit the current transaction. +// See https://dev.mysql.com/doc/refman/5.7/en/commit.html +type CommitStmt struct { + stmtNode + // CompletionType overwrites system variable `completion_type` within transaction + CompletionType CompletionType +} + +// Restore implements Node interface. 
+func (n *CommitStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("COMMIT") + if err := n.CompletionType.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore CommitStmt.CompletionType") + } + return nil +} + +// Accept implements Node Accept interface. +func (n *CommitStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*CommitStmt) + return v.Leave(n) +} + +// RollbackStmt is a statement to roll back the current transaction. +// See https://dev.mysql.com/doc/refman/5.7/en/commit.html +type RollbackStmt struct { + stmtNode + // CompletionType overwrites system variable `completion_type` within transaction + CompletionType CompletionType + // SavepointName is the savepoint name. + SavepointName string +} + +// Restore implements Node interface. +func (n *RollbackStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("ROLLBACK") + if n.SavepointName != "" { + ctx.WritePlain(" TO ") + ctx.WritePlain(n.SavepointName) + } + if err := n.CompletionType.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore RollbackStmt.CompletionType") + } + return nil +} + +// Accept implements Node Accept interface. +func (n *RollbackStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*RollbackStmt) + return v.Leave(n) +} + +// UseStmt is a statement to use the DBName database as the current database. +// See https://dev.mysql.com/doc/refman/5.7/en/use.html +type UseStmt struct { + stmtNode + + DBName string +} + +// Restore implements Node interface. +func (n *UseStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("USE ") + ctx.WriteName(n.DBName) + return nil +} + +// Accept implements Node Accept interface. +func (n *UseStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*UseStmt) + return v.Leave(n) +} + +const ( + // SetNames is the const for set names stmt. + // If VariableAssignment.Name == Names, it should be set names stmt. + SetNames = "SetNAMES" + // SetCharset is the const for set charset stmt. + SetCharset = "SetCharset" + // TiDBCloudStorageURI is the const for set tidb_cloud_storage_uri stmt. + TiDBCloudStorageURI = "tidb_cloud_storage_uri" +) + +// VariableAssignment is a variable assignment struct. +type VariableAssignment struct { + node + Name string + Value ExprNode + IsGlobal bool + IsSystem bool + + // ExtendValue is a way to store extended info. + // VariableAssignment should be able to store information for SetCharset/SetPWD Stmt. + // For SetCharsetStmt, Value is charset, ExtendValue is collation. + // TODO: Use SetStmt to implement set password statement. + ExtendValue ValueExpr +} + +// Restore implements Node interface. 
+func (n *VariableAssignment) Restore(ctx *format.RestoreCtx) error { + if n.IsSystem { + ctx.WritePlain("@@") + if n.IsGlobal { + ctx.WriteKeyWord("GLOBAL") + } else { + ctx.WriteKeyWord("SESSION") + } + ctx.WritePlain(".") + } else if n.Name != SetNames && n.Name != SetCharset { + ctx.WriteKeyWord("@") + } + if n.Name == SetNames { + ctx.WriteKeyWord("NAMES ") + } else if n.Name == SetCharset { + ctx.WriteKeyWord("CHARSET ") + } else { + ctx.WriteName(n.Name) + ctx.WritePlain("=") + } + if n.Name == TiDBCloudStorageURI { + // need to redact the url for safety when `show processlist;` + ctx.WritePlain(RedactURL(n.Value.(ValueExpr).GetString())) + } else if err := n.Value.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore VariableAssignment.Value") + } + if n.ExtendValue != nil { + ctx.WriteKeyWord(" COLLATE ") + if err := n.ExtendValue.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore VariableAssignment.ExtendValue") + } + } + return nil +} + +// Accept implements Node interface. +func (n *VariableAssignment) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*VariableAssignment) + node, ok := n.Value.Accept(v) + if !ok { + return n, false + } + n.Value = node.(ExprNode) + return v.Leave(n) +} + +// FlushStmtType is the type for FLUSH statement. +type FlushStmtType int + +// Flush statement types. +const ( + FlushNone FlushStmtType = iota + FlushTables + FlushPrivileges + FlushStatus + FlushTiDBPlugin + FlushHosts + FlushLogs + FlushClientErrorsSummary +) + +// LogType is the log type used in FLUSH statement. +type LogType int8 + +const ( + LogTypeDefault LogType = iota + LogTypeBinary + LogTypeEngine + LogTypeError + LogTypeGeneral + LogTypeSlow +) + +// FlushStmt is a statement to flush tables/privileges/optimizer costs and so on. +type FlushStmt struct { + stmtNode + + Tp FlushStmtType // Privileges/Tables/... + NoWriteToBinLog bool + LogType LogType + Tables []*TableName // For FlushTableStmt, if Tables is empty, it means flush all tables. + ReadLock bool + Plugins []string +} + +// Restore implements Node interface. 
+func (n *FlushStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("FLUSH ") + if n.NoWriteToBinLog { + ctx.WriteKeyWord("NO_WRITE_TO_BINLOG ") + } + switch n.Tp { + case FlushTables: + ctx.WriteKeyWord("TABLES") + for i, v := range n.Tables { + if i == 0 { + ctx.WritePlain(" ") + } else { + ctx.WritePlain(", ") + } + if err := v.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore FlushStmt.Tables[%d]", i) + } + } + if n.ReadLock { + ctx.WriteKeyWord(" WITH READ LOCK") + } + case FlushPrivileges: + ctx.WriteKeyWord("PRIVILEGES") + case FlushStatus: + ctx.WriteKeyWord("STATUS") + case FlushTiDBPlugin: + ctx.WriteKeyWord("TIDB PLUGINS") + for i, v := range n.Plugins { + if i == 0 { + ctx.WritePlain(" ") + } else { + ctx.WritePlain(", ") + } + ctx.WritePlain(v) + } + case FlushHosts: + ctx.WriteKeyWord("HOSTS") + case FlushLogs: + var logType string + switch n.LogType { + case LogTypeDefault: + logType = "LOGS" + case LogTypeBinary: + logType = "BINARY LOGS" + case LogTypeEngine: + logType = "ENGINE LOGS" + case LogTypeError: + logType = "ERROR LOGS" + case LogTypeGeneral: + logType = "GENERAL LOGS" + case LogTypeSlow: + logType = "SLOW LOGS" + } + ctx.WriteKeyWord(logType) + case FlushClientErrorsSummary: + ctx.WriteKeyWord("CLIENT_ERRORS_SUMMARY") + default: + return errors.New("Unsupported type of FlushStmt") + } + return nil +} + +// Accept implements Node Accept interface. +func (n *FlushStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*FlushStmt) + for i, t := range n.Tables { + node, ok := t.Accept(v) + if !ok { + return n, false + } + n.Tables[i] = node.(*TableName) + } + return v.Leave(n) +} + +// KillStmt is a statement to kill a query or connection. +type KillStmt struct { + stmtNode + + // Query indicates whether terminate a single query on this connection or the whole connection. + // If Query is true, terminates the statement the connection is currently executing, but leaves the connection itself intact. + // If Query is false, terminates the connection associated with the given ConnectionID, after terminating any statement the connection is executing. + Query bool + ConnectionID uint64 + // TiDBExtension is used to indicate whether the user knows he is sending kill statement to the right tidb-server. + // When the SQL grammar is "KILL TIDB [CONNECTION | QUERY] connectionID", TiDBExtension will be set. + // It's a special grammar extension in TiDB. This extension exists because, when the connection is: + // client -> LVS proxy -> TiDB, and type Ctrl+C in client, the following action will be executed: + // new a connection; kill xxx; + // kill command may send to the wrong TiDB, because the exists of LVS proxy, and kill the wrong session. + // So, "KILL TIDB" grammar is introduced, and it REQUIRES DIRECT client -> TiDB TOPOLOGY. + // TODO: The standard KILL grammar will be supported once we have global connectionID. + TiDBExtension bool + + Expr ExprNode +} + +// Restore implements Node interface. +func (n *KillStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("KILL") + if n.TiDBExtension { + ctx.WriteKeyWord(" TIDB") + } + if n.Query { + ctx.WriteKeyWord(" QUERY") + } + if n.Expr != nil { + ctx.WriteKeyWord(" ") + if err := n.Expr.Restore(ctx); err != nil { + return errors.Trace(err) + } + } else { + ctx.WritePlainf(" %d", n.ConnectionID) + } + return nil +} + +// Accept implements Node Accept interface. 
+func (n *KillStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*KillStmt) + return v.Leave(n) +} + +// SavepointStmt is the statement of SAVEPOINT. +type SavepointStmt struct { + stmtNode + // Name is the savepoint name. + Name string +} + +// Restore implements Node interface. +func (n *SavepointStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("SAVEPOINT ") + ctx.WritePlain(n.Name) + return nil +} + +// Accept implements Node Accept interface. +func (n *SavepointStmt) Accept(v Visitor) (Node, bool) { + newNode, _ := v.Enter(n) + n = newNode.(*SavepointStmt) + return v.Leave(n) +} + +// ReleaseSavepointStmt is the statement of RELEASE SAVEPOINT. +type ReleaseSavepointStmt struct { + stmtNode + // Name is the savepoint name. + Name string +} + +// Restore implements Node interface. +func (n *ReleaseSavepointStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("RELEASE SAVEPOINT ") + ctx.WritePlain(n.Name) + return nil +} + +// Accept implements Node Accept interface. +func (n *ReleaseSavepointStmt) Accept(v Visitor) (Node, bool) { + newNode, _ := v.Enter(n) + n = newNode.(*ReleaseSavepointStmt) + return v.Leave(n) +} + +// SetStmt is the statement to set variables. +type SetStmt struct { + stmtNode + // Variables is the list of variable assignment. + Variables []*VariableAssignment +} + +// Restore implements Node interface. +func (n *SetStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("SET ") + for i, v := range n.Variables { + if i != 0 { + ctx.WritePlain(", ") + } + if err := v.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore SetStmt.Variables[%d]", i) + } + } + return nil +} + +// Accept implements Node Accept interface. +func (n *SetStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*SetStmt) + for i, val := range n.Variables { + node, ok := val.Accept(v) + if !ok { + return n, false + } + n.Variables[i] = node.(*VariableAssignment) + } + return v.Leave(n) +} + +// SecureText implements SensitiveStatement interface. +// need to redact the tidb_cloud_storage_url for safety when `show processlist;` +func (n *SetStmt) SecureText() string { + redactedStmt := *n + var sb strings.Builder + _ = redactedStmt.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)) + return sb.String() +} + +// SetConfigStmt is the statement to set cluster configs. +type SetConfigStmt struct { + stmtNode + + Type string // TiDB, TiKV, PD + Instance string // '127.0.0.1:3306' + Name string // the variable name + Value ExprNode +} + +func (n *SetConfigStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("SET CONFIG ") + if n.Type != "" { + ctx.WriteKeyWord(n.Type) + } else { + ctx.WriteString(n.Instance) + } + ctx.WritePlain(" ") + ctx.WriteKeyWord(n.Name) + ctx.WritePlain(" = ") + return n.Value.Restore(ctx) +} + +func (n *SetConfigStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*SetConfigStmt) + node, ok := n.Value.Accept(v) + if !ok { + return n, false + } + n.Value = node.(ExprNode) + return v.Leave(n) +} + +// SetSessionStatesStmt is a statement to restore session states. 
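Aside, not part of the patch: the Accept methods above follow the parser's visitor pattern, where Enter runs before a node's children and Leave after. A rough sketch of a visitor that collects variable names from a parsed SET statement; the parser and test_driver imports are the usual way to obtain a StmtNode outside TiDB itself and are assumptions here:

package main

import (
	"fmt"

	"github.com/pingcap/tidb/pkg/parser"
	"github.com/pingcap/tidb/pkg/parser/ast"
	_ "github.com/pingcap/tidb/pkg/parser/test_driver"
)

// varCollector records the name of every VariableAssignment it visits.
type varCollector struct {
	names []string
}

func (c *varCollector) Enter(n ast.Node) (ast.Node, bool) {
	if va, ok := n.(*ast.VariableAssignment); ok {
		c.names = append(c.names, va.Name)
	}
	return n, false // keep walking children
}

func (c *varCollector) Leave(n ast.Node) (ast.Node, bool) {
	return n, true
}

func main() {
	p := parser.New()
	stmt, err := p.ParseOneStmt("SET @@GLOBAL.max_connections = 100, @x = 1", "", "")
	if err != nil {
		panic(err)
	}
	c := &varCollector{}
	stmt.Accept(c)
	fmt.Println(c.names) // expected: [max_connections x]
}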
+type SetSessionStatesStmt struct { + stmtNode + + SessionStates string +} + +func (n *SetSessionStatesStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("SET SESSION_STATES ") + ctx.WriteString(n.SessionStates) + return nil +} + +func (n *SetSessionStatesStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*SetSessionStatesStmt) + return v.Leave(n) +} + +/* +// SetCharsetStmt is a statement to assign values to character and collation variables. +// See https://dev.mysql.com/doc/refman/5.7/en/set-statement.html +type SetCharsetStmt struct { + stmtNode + + Charset string + Collate string +} + +// Accept implements Node Accept interface. +func (n *SetCharsetStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*SetCharsetStmt) + return v.Leave(n) +} +*/ + +// SetPwdStmt is a statement to assign a password to user account. +// See https://dev.mysql.com/doc/refman/5.7/en/set-password.html +type SetPwdStmt struct { + stmtNode + + User *auth.UserIdentity + Password string +} + +// Restore implements Node interface. +func (n *SetPwdStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("SET PASSWORD") + if n.User != nil { + ctx.WriteKeyWord(" FOR ") + if err := n.User.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore SetPwdStmt.User") + } + } + ctx.WritePlain("=") + ctx.WriteString(n.Password) + return nil +} + +// SecureText implements SensitiveStatement interface. +func (n *SetPwdStmt) SecureText() string { + return fmt.Sprintf("set password for user %s", n.User) +} + +// Accept implements Node Accept interface. +func (n *SetPwdStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*SetPwdStmt) + return v.Leave(n) +} + +type ChangeStmt struct { + stmtNode + + NodeType string + State string + NodeID string +} + +// Restore implements Node interface. +func (n *ChangeStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("CHANGE ") + ctx.WriteKeyWord(n.NodeType) + ctx.WriteKeyWord(" TO NODE_STATE ") + ctx.WritePlain("=") + ctx.WriteString(n.State) + ctx.WriteKeyWord(" FOR NODE_ID ") + ctx.WriteString(n.NodeID) + return nil +} + +// SecureText implements SensitiveStatement interface. +func (n *ChangeStmt) SecureText() string { + return fmt.Sprintf("change %s to node_state='%s' for node_id '%s'", strings.ToLower(n.NodeType), n.State, n.NodeID) +} + +// Accept implements Node Accept interface. +func (n *ChangeStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*ChangeStmt) + return v.Leave(n) +} + +// SetRoleStmtType is the type for FLUSH statement. +type SetRoleStmtType int + +// SetRole statement types. 
+const ( + SetRoleDefault SetRoleStmtType = iota + SetRoleNone + SetRoleAll + SetRoleAllExcept + SetRoleRegular +) + +type SetRoleStmt struct { + stmtNode + + SetRoleOpt SetRoleStmtType + RoleList []*auth.RoleIdentity +} + +func (n *SetRoleStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("SET ROLE") + switch n.SetRoleOpt { + case SetRoleDefault: + ctx.WriteKeyWord(" DEFAULT") + case SetRoleNone: + ctx.WriteKeyWord(" NONE") + case SetRoleAll: + ctx.WriteKeyWord(" ALL") + case SetRoleAllExcept: + ctx.WriteKeyWord(" ALL EXCEPT") + } + for i, role := range n.RoleList { + ctx.WritePlain(" ") + err := role.Restore(ctx) + if err != nil { + return errors.Annotate(err, "An error occurred while restore SetRoleStmt.RoleList") + } + if i != len(n.RoleList)-1 { + ctx.WritePlain(",") + } + } + return nil +} + +// Accept implements Node Accept interface. +func (n *SetRoleStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*SetRoleStmt) + return v.Leave(n) +} + +type SetDefaultRoleStmt struct { + stmtNode + + SetRoleOpt SetRoleStmtType + RoleList []*auth.RoleIdentity + UserList []*auth.UserIdentity +} + +func (n *SetDefaultRoleStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("SET DEFAULT ROLE") + switch n.SetRoleOpt { + case SetRoleNone: + ctx.WriteKeyWord(" NONE") + case SetRoleAll: + ctx.WriteKeyWord(" ALL") + default: + } + for i, role := range n.RoleList { + ctx.WritePlain(" ") + err := role.Restore(ctx) + if err != nil { + return errors.Annotate(err, "An error occurred while restore SetDefaultRoleStmt.RoleList") + } + if i != len(n.RoleList)-1 { + ctx.WritePlain(",") + } + } + ctx.WritePlain(" TO") + for i, user := range n.UserList { + ctx.WritePlain(" ") + err := user.Restore(ctx) + if err != nil { + return errors.Annotate(err, "An error occurred while restore SetDefaultRoleStmt.UserList") + } + if i != len(n.UserList)-1 { + ctx.WritePlain(",") + } + } + return nil +} + +// Accept implements Node Accept interface. +func (n *SetDefaultRoleStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*SetDefaultRoleStmt) + return v.Leave(n) +} + +// UserSpec is used for parsing create user statement. +type UserSpec struct { + User *auth.UserIdentity + AuthOpt *AuthOption + IsRole bool +} + +// Restore implements Node interface. +func (n *UserSpec) Restore(ctx *format.RestoreCtx) error { + if err := n.User.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore UserSpec.User") + } + if n.AuthOpt != nil { + ctx.WritePlain(" ") + if err := n.AuthOpt.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore UserSpec.AuthOpt") + } + } + return nil +} + +// SecurityString formats the UserSpec without password information. +func (n *UserSpec) SecurityString() string { + withPassword := false + if opt := n.AuthOpt; opt != nil { + if len(opt.AuthString) > 0 || len(opt.HashString) > 0 { + withPassword = true + } + } + if withPassword { + return fmt.Sprintf("{%s password = ***}", n.User) + } + return n.User.String() +} + +// EncodedPassword returns the encoded password (which is the real data mysql.user). +// The boolean value indicates input's password format is legal or not. 
+func (n *UserSpec) EncodedPassword() (string, bool) { + if n.AuthOpt == nil { + return "", true + } + + opt := n.AuthOpt + if opt.ByAuthString { + switch opt.AuthPlugin { + case mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password: + return auth.NewHashPassword(opt.AuthString, opt.AuthPlugin), true + case mysql.AuthSocket: + return "", true + default: + return auth.EncodePassword(opt.AuthString), true + } + } + + // store the LDAP dn directly in the password field + switch opt.AuthPlugin { + case mysql.AuthLDAPSimple, mysql.AuthLDAPSASL: + // TODO: validate the HashString to be a `dn` for LDAP + // It seems fine to not validate here, and LDAP server will give an error when the client'll try to login this user. + // The percona server implementation doesn't have a validation for this HashString. + // However, returning an error for obvious wrong format is more friendly. + return opt.HashString, true + } + + // In case we have 'IDENTIFIED WITH ' but no 'BY ' to set an empty password. + if opt.HashString == "" { + return opt.HashString, true + } + + // Not a legal password string. + switch opt.AuthPlugin { + case mysql.AuthCachingSha2Password: + if len(opt.HashString) != mysql.SHAPWDHashLen { + return "", false + } + case mysql.AuthTiDBSM3Password: + if len(opt.HashString) != mysql.SM3PWDHashLen { + return "", false + } + case "", mysql.AuthNativePassword: + if len(opt.HashString) != (mysql.PWDHashLen+1) || !strings.HasPrefix(opt.HashString, "*") { + return "", false + } + case mysql.AuthSocket: + default: + return "", false + } + return opt.HashString, true +} + +type AuthTokenOrTLSOption struct { + Type AuthTokenOrTLSOptionType + Value string +} + +func (t *AuthTokenOrTLSOption) Restore(ctx *format.RestoreCtx) error { + switch t.Type { + case TlsNone: + ctx.WriteKeyWord("NONE") + case Ssl: + ctx.WriteKeyWord("SSL") + case X509: + ctx.WriteKeyWord("X509") + case Cipher: + ctx.WriteKeyWord("CIPHER ") + ctx.WriteString(t.Value) + case Issuer: + ctx.WriteKeyWord("ISSUER ") + ctx.WriteString(t.Value) + case Subject: + ctx.WriteKeyWord("SUBJECT ") + ctx.WriteString(t.Value) + case SAN: + ctx.WriteKeyWord("SAN ") + ctx.WriteString(t.Value) + case TokenIssuer: + ctx.WriteKeyWord("TOKEN_ISSUER ") + ctx.WriteString(t.Value) + default: + return errors.Errorf("Unsupported AuthTokenOrTLSOption.Type %d", t.Type) + } + return nil +} + +type AuthTokenOrTLSOptionType int + +const ( + TlsNone AuthTokenOrTLSOptionType = iota + Ssl + X509 + Cipher + Issuer + Subject + SAN + TokenIssuer +) + +func (t AuthTokenOrTLSOptionType) String() string { + switch t { + case TlsNone: + return "NONE" + case Ssl: + return "SSL" + case X509: + return "X509" + case Cipher: + return "CIPHER" + case Issuer: + return "ISSUER" + case Subject: + return "SUBJECT" + case SAN: + return "SAN" + case TokenIssuer: + return "TOKEN_ISSUER" + default: + return "UNKNOWN" + } +} + +const ( + MaxQueriesPerHour = iota + 1 + MaxUpdatesPerHour + MaxConnectionsPerHour + MaxUserConnections +) + +type ResourceOption struct { + Type int + Count int64 +} + +func (r *ResourceOption) Restore(ctx *format.RestoreCtx) error { + switch r.Type { + case MaxQueriesPerHour: + ctx.WriteKeyWord("MAX_QUERIES_PER_HOUR ") + case MaxUpdatesPerHour: + ctx.WriteKeyWord("MAX_UPDATES_PER_HOUR ") + case MaxConnectionsPerHour: + ctx.WriteKeyWord("MAX_CONNECTIONS_PER_HOUR ") + case MaxUserConnections: + ctx.WriteKeyWord("MAX_USER_CONNECTIONS ") + default: + return errors.Errorf("Unsupported ResourceOption.Type %d", r.Type) + } + ctx.WritePlainf("%d", r.Count) + return 
nil +} + +const ( + PasswordExpire = iota + 1 + PasswordExpireDefault + PasswordExpireNever + PasswordExpireInterval + PasswordHistory + PasswordHistoryDefault + PasswordReuseInterval + PasswordReuseDefault + Lock + Unlock + FailedLoginAttempts + PasswordLockTime + PasswordLockTimeUnbounded + UserCommentType + UserAttributeType + PasswordRequireCurrentDefault + + UserResourceGroupName +) + +type PasswordOrLockOption struct { + Type int + Count int64 +} + +func (p *PasswordOrLockOption) Restore(ctx *format.RestoreCtx) error { + switch p.Type { + case PasswordExpire: + ctx.WriteKeyWord("PASSWORD EXPIRE") + case PasswordExpireDefault: + ctx.WriteKeyWord("PASSWORD EXPIRE DEFAULT") + case PasswordExpireNever: + ctx.WriteKeyWord("PASSWORD EXPIRE NEVER") + case PasswordExpireInterval: + ctx.WriteKeyWord("PASSWORD EXPIRE INTERVAL") + ctx.WritePlainf(" %d", p.Count) + ctx.WriteKeyWord(" DAY") + case Lock: + ctx.WriteKeyWord("ACCOUNT LOCK") + case Unlock: + ctx.WriteKeyWord("ACCOUNT UNLOCK") + case FailedLoginAttempts: + ctx.WriteKeyWord("FAILED_LOGIN_ATTEMPTS") + ctx.WritePlainf(" %d", p.Count) + case PasswordLockTime: + ctx.WriteKeyWord("PASSWORD_LOCK_TIME") + ctx.WritePlainf(" %d", p.Count) + case PasswordLockTimeUnbounded: + ctx.WriteKeyWord("PASSWORD_LOCK_TIME UNBOUNDED") + case PasswordHistory: + ctx.WriteKeyWord("PASSWORD HISTORY") + ctx.WritePlainf(" %d", p.Count) + case PasswordHistoryDefault: + ctx.WriteKeyWord("PASSWORD HISTORY DEFAULT") + case PasswordReuseInterval: + ctx.WriteKeyWord("PASSWORD REUSE INTERVAL") + ctx.WritePlainf(" %d", p.Count) + ctx.WriteKeyWord(" DAY") + case PasswordReuseDefault: + ctx.WriteKeyWord("PASSWORD REUSE INTERVAL DEFAULT") + default: + return errors.Errorf("Unsupported PasswordOrLockOption.Type %d", p.Type) + } + return nil +} + +type CommentOrAttributeOption struct { + Type int + Value string +} + +func (c *CommentOrAttributeOption) Restore(ctx *format.RestoreCtx) error { + if c.Type == UserCommentType { + ctx.WriteKeyWord(" COMMENT ") + ctx.WriteString(c.Value) + } else if c.Type == UserAttributeType { + ctx.WriteKeyWord(" ATTRIBUTE ") + ctx.WriteString(c.Value) + } + return nil +} + +type ResourceGroupNameOption struct { + Value string +} + +func (c *ResourceGroupNameOption) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord(" RESOURCE GROUP ") + ctx.WriteName(c.Value) + return nil +} + +// CreateUserStmt creates user account. +// See https://dev.mysql.com/doc/refman/8.0/en/create-user.html +type CreateUserStmt struct { + stmtNode + + IsCreateRole bool + IfNotExists bool + Specs []*UserSpec + AuthTokenOrTLSOptions []*AuthTokenOrTLSOption + ResourceOptions []*ResourceOption + PasswordOrLockOptions []*PasswordOrLockOption + CommentOrAttributeOption *CommentOrAttributeOption + ResourceGroupNameOption *ResourceGroupNameOption +} + +// Restore implements Node interface. 
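Aside, not part of the patch: EncodedPassword above either hashes a plain AUTH STRING or validates a pre-hashed string, depending on the plugin. A small sketch of the plain-password path for the default native plugin; the literal username and password are illustrative only:

package main

import (
	"fmt"

	"github.com/pingcap/tidb/pkg/parser/ast"
	"github.com/pingcap/tidb/pkg/parser/auth"
)

func main() {
	spec := &ast.UserSpec{
		User: &auth.UserIdentity{Username: "app", Hostname: "%"},
		AuthOpt: &ast.AuthOption{
			ByAuthString: true,
			AuthString:   "s3cret",
			// An empty AuthPlugin falls through to the default
			// mysql_native_password encoding branch.
		},
	}
	pwd, ok := spec.EncodedPassword()
	fmt.Println(ok)  // true
	fmt.Println(pwd) // a "*..."-prefixed mysql_native_password hash
}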
+func (n *CreateUserStmt) Restore(ctx *format.RestoreCtx) error { + if n.IsCreateRole { + ctx.WriteKeyWord("CREATE ROLE ") + } else { + ctx.WriteKeyWord("CREATE USER ") + } + if n.IfNotExists { + ctx.WriteKeyWord("IF NOT EXISTS ") + } + for i, v := range n.Specs { + if i != 0 { + ctx.WritePlain(", ") + } + if err := v.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore CreateUserStmt.Specs[%d]", i) + } + } + + if len(n.AuthTokenOrTLSOptions) != 0 { + ctx.WriteKeyWord(" REQUIRE ") + } + + for i, option := range n.AuthTokenOrTLSOptions { + if i != 0 { + ctx.WriteKeyWord(" AND ") + } + if err := option.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore CreateUserStmt.AuthTokenOrTLSOptions[%d]", i) + } + } + + if len(n.ResourceOptions) != 0 { + ctx.WriteKeyWord(" WITH") + } + + for i, v := range n.ResourceOptions { + ctx.WritePlain(" ") + if err := v.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore CreateUserStmt.ResourceOptions[%d]", i) + } + } + + for i, v := range n.PasswordOrLockOptions { + ctx.WritePlain(" ") + if err := v.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore CreateUserStmt.PasswordOrLockOptions[%d]", i) + } + } + + if n.CommentOrAttributeOption != nil { + if err := n.CommentOrAttributeOption.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore CreateUserStmt.CommentOrAttributeOption") + } + } + + if n.ResourceGroupNameOption != nil { + if err := n.ResourceGroupNameOption.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore CreateUserStmt.ResourceGroupNameOption") + } + } + + return nil +} + +// Accept implements Node Accept interface. +func (n *CreateUserStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*CreateUserStmt) + return v.Leave(n) +} + +// SecureText implements SensitiveStatement interface. +func (n *CreateUserStmt) SecureText() string { + var buf bytes.Buffer + buf.WriteString("create user") + for _, user := range n.Specs { + buf.WriteString(" ") + buf.WriteString(user.SecurityString()) + } + return buf.String() +} + +// AlterUserStmt modifies user account. +// See https://dev.mysql.com/doc/refman/8.0/en/alter-user.html +type AlterUserStmt struct { + stmtNode + + IfExists bool + CurrentAuth *AuthOption + Specs []*UserSpec + AuthTokenOrTLSOptions []*AuthTokenOrTLSOption + ResourceOptions []*ResourceOption + PasswordOrLockOptions []*PasswordOrLockOption + CommentOrAttributeOption *CommentOrAttributeOption + ResourceGroupNameOption *ResourceGroupNameOption +} + +// Restore implements Node interface. 
+func (n *AlterUserStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("ALTER USER ") + if n.IfExists { + ctx.WriteKeyWord("IF EXISTS ") + } + if n.CurrentAuth != nil { + ctx.WriteKeyWord("USER") + ctx.WritePlain("() ") + if err := n.CurrentAuth.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore AlterUserStmt.CurrentAuth") + } + } + for i, v := range n.Specs { + if i != 0 { + ctx.WritePlain(", ") + } + if err := v.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore AlterUserStmt.Specs[%d]", i) + } + } + + if len(n.AuthTokenOrTLSOptions) != 0 { + ctx.WriteKeyWord(" REQUIRE ") + } + + for i, option := range n.AuthTokenOrTLSOptions { + if i != 0 { + ctx.WriteKeyWord(" AND ") + } + if err := option.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore AlterUserStmt.AuthTokenOrTLSOptions[%d]", i) + } + } + + if len(n.ResourceOptions) != 0 { + ctx.WriteKeyWord(" WITH") + } + + for i, v := range n.ResourceOptions { + ctx.WritePlain(" ") + if err := v.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore AlterUserStmt.ResourceOptions[%d]", i) + } + } + + for i, v := range n.PasswordOrLockOptions { + ctx.WritePlain(" ") + if err := v.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore AlterUserStmt.PasswordOrLockOptions[%d]", i) + } + } + + if n.CommentOrAttributeOption != nil { + if err := n.CommentOrAttributeOption.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore AlterUserStmt.CommentOrAttributeOption") + } + } + + if n.ResourceGroupNameOption != nil { + if err := n.ResourceGroupNameOption.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore AlterUserStmt.ResourceGroupNameOption") + } + } + + return nil +} + +// SecureText implements SensitiveStatement interface. +func (n *AlterUserStmt) SecureText() string { + var buf bytes.Buffer + buf.WriteString("alter user") + for _, user := range n.Specs { + buf.WriteString(" ") + buf.WriteString(user.SecurityString()) + } + return buf.String() +} + +// Accept implements Node Accept interface. +func (n *AlterUserStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*AlterUserStmt) + return v.Leave(n) +} + +// AlterInstanceStmt modifies instance. +// See https://dev.mysql.com/doc/refman/8.0/en/alter-instance.html +type AlterInstanceStmt struct { + stmtNode + + ReloadTLS bool + NoRollbackOnError bool +} + +// Restore implements Node interface. +func (n *AlterInstanceStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("ALTER INSTANCE") + if n.ReloadTLS { + ctx.WriteKeyWord(" RELOAD TLS") + } + if n.NoRollbackOnError { + ctx.WriteKeyWord(" NO ROLLBACK ON ERROR") + } + return nil +} + +// Accept implements Node Accept interface. +func (n *AlterInstanceStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*AlterInstanceStmt) + return v.Leave(n) +} + +// AlterRangeStmt modifies range configuration. +type AlterRangeStmt struct { + stmtNode + RangeName model.CIStr + PlacementOption *PlacementOption +} + +// Restore implements Node interface. 
+func (n *AlterRangeStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("ALTER RANGE ") + ctx.WriteName(n.RangeName.O) + ctx.WritePlain(" ") + if err := n.PlacementOption.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore AlterRangeStmt.PlacementOption") + } + return nil +} + +// Accept implements Node Accept interface. +func (n *AlterRangeStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*AlterRangeStmt) + return v.Leave(n) +} + +// DropUserStmt creates user account. +// See http://dev.mysql.com/doc/refman/5.7/en/drop-user.html +type DropUserStmt struct { + stmtNode + + IfExists bool + IsDropRole bool + UserList []*auth.UserIdentity +} + +// Restore implements Node interface. +func (n *DropUserStmt) Restore(ctx *format.RestoreCtx) error { + if n.IsDropRole { + ctx.WriteKeyWord("DROP ROLE ") + } else { + ctx.WriteKeyWord("DROP USER ") + } + if n.IfExists { + ctx.WriteKeyWord("IF EXISTS ") + } + for i, v := range n.UserList { + if i != 0 { + ctx.WritePlain(", ") + } + if err := v.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore DropUserStmt.UserList[%d]", i) + } + } + return nil +} + +// Accept implements Node Accept interface. +func (n *DropUserStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*DropUserStmt) + return v.Leave(n) +} + +// CreateBindingStmt creates sql binding hint. +type CreateBindingStmt struct { + stmtNode + + GlobalScope bool + OriginNode StmtNode + HintedNode StmtNode + PlanDigest string +} + +func (n *CreateBindingStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("CREATE ") + if n.GlobalScope { + ctx.WriteKeyWord("GLOBAL ") + } else { + ctx.WriteKeyWord("SESSION ") + } + if n.OriginNode == nil { + ctx.WriteKeyWord("BINDING FROM HISTORY USING PLAN DIGEST ") + ctx.WriteString(n.PlanDigest) + } else { + ctx.WriteKeyWord("BINDING FOR ") + if err := n.OriginNode.Restore(ctx); err != nil { + return errors.Trace(err) + } + ctx.WriteKeyWord(" USING ") + if err := n.HintedNode.Restore(ctx); err != nil { + return errors.Trace(err) + } + } + return nil +} + +func (n *CreateBindingStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*CreateBindingStmt) + if n.OriginNode != nil { + origNode, ok := n.OriginNode.Accept(v) + if !ok { + return n, false + } + n.OriginNode = origNode.(StmtNode) + hintedNode, ok := n.HintedNode.Accept(v) + if !ok { + return n, false + } + n.HintedNode = hintedNode.(StmtNode) + } + return v.Leave(n) +} + +// DropBindingStmt deletes sql binding hint. 
+type DropBindingStmt struct { + stmtNode + + GlobalScope bool + OriginNode StmtNode + HintedNode StmtNode + SQLDigest string +} + +func (n *DropBindingStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("DROP ") + if n.GlobalScope { + ctx.WriteKeyWord("GLOBAL ") + } else { + ctx.WriteKeyWord("SESSION ") + } + ctx.WriteKeyWord("BINDING FOR ") + if n.OriginNode == nil { + ctx.WriteKeyWord("SQL DIGEST ") + ctx.WriteString(n.SQLDigest) + } else { + if err := n.OriginNode.Restore(ctx); err != nil { + return errors.Trace(err) + } + if n.HintedNode != nil { + ctx.WriteKeyWord(" USING ") + if err := n.HintedNode.Restore(ctx); err != nil { + return errors.Trace(err) + } + } + } + return nil +} + +func (n *DropBindingStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*DropBindingStmt) + if n.OriginNode != nil { + // OriginNode is nil means we build drop binding by sql digest + origNode, ok := n.OriginNode.Accept(v) + if !ok { + return n, false + } + n.OriginNode = origNode.(StmtNode) + if n.HintedNode != nil { + hintedNode, ok := n.HintedNode.Accept(v) + if !ok { + return n, false + } + n.HintedNode = hintedNode.(StmtNode) + } + } + return v.Leave(n) +} + +// BindingStatusType defines the status type for the binding +type BindingStatusType int8 + +// Binding status types. +const ( + BindingStatusTypeEnabled BindingStatusType = iota + BindingStatusTypeDisabled +) + +// SetBindingStmt sets sql binding status. +type SetBindingStmt struct { + stmtNode + + BindingStatusType BindingStatusType + OriginNode StmtNode + HintedNode StmtNode + SQLDigest string +} + +func (n *SetBindingStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("SET ") + ctx.WriteKeyWord("BINDING ") + switch n.BindingStatusType { + case BindingStatusTypeEnabled: + ctx.WriteKeyWord("ENABLED ") + case BindingStatusTypeDisabled: + ctx.WriteKeyWord("DISABLED ") + } + ctx.WriteKeyWord("FOR ") + if n.OriginNode == nil { + ctx.WriteKeyWord("SQL DIGEST ") + ctx.WriteString(n.SQLDigest) + } else { + if err := n.OriginNode.Restore(ctx); err != nil { + return errors.Trace(err) + } + if n.HintedNode != nil { + ctx.WriteKeyWord(" USING ") + if err := n.HintedNode.Restore(ctx); err != nil { + return errors.Trace(err) + } + } + } + return nil +} + +func (n *SetBindingStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*SetBindingStmt) + if n.OriginNode != nil { + // OriginNode is nil means we set binding stmt by sql digest + origNode, ok := n.OriginNode.Accept(v) + if !ok { + return n, false + } + n.OriginNode = origNode.(StmtNode) + if n.HintedNode != nil { + hintedNode, ok := n.HintedNode.Accept(v) + if !ok { + return n, false + } + n.HintedNode = hintedNode.(StmtNode) + } + } + return v.Leave(n) +} + +// Extended statistics types. +const ( + StatsTypeCardinality uint8 = iota + StatsTypeDependency + StatsTypeCorrelation +) + +// StatisticsSpec is the specification for ADD /DROP STATISTICS. +type StatisticsSpec struct { + StatsName string + StatsType uint8 + Columns []*ColumnName +} + +// CreateStatisticsStmt is a statement to create extended statistics. 
+// Examples: +// +// CREATE STATISTICS stats1 (cardinality) ON t(a, b, c); +// CREATE STATISTICS stats2 (dependency) ON t(a, b); +// CREATE STATISTICS stats3 (correlation) ON t(a, b); +type CreateStatisticsStmt struct { + stmtNode + + IfNotExists bool + StatsName string + StatsType uint8 + Table *TableName + Columns []*ColumnName +} + +// Restore implements Node interface. +func (n *CreateStatisticsStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("CREATE STATISTICS ") + if n.IfNotExists { + ctx.WriteKeyWord("IF NOT EXISTS ") + } + ctx.WriteName(n.StatsName) + switch n.StatsType { + case StatsTypeCardinality: + ctx.WriteKeyWord(" (cardinality) ") + case StatsTypeDependency: + ctx.WriteKeyWord(" (dependency) ") + case StatsTypeCorrelation: + ctx.WriteKeyWord(" (correlation) ") + } + ctx.WriteKeyWord("ON ") + if err := n.Table.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore CreateStatisticsStmt.Table") + } + + ctx.WritePlain("(") + for i, col := range n.Columns { + if i != 0 { + ctx.WritePlain(", ") + } + if err := col.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore CreateStatisticsStmt.Columns: [%v]", i) + } + } + ctx.WritePlain(")") + return nil +} + +// Accept implements Node Accept interface. +func (n *CreateStatisticsStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*CreateStatisticsStmt) + node, ok := n.Table.Accept(v) + if !ok { + return n, false + } + n.Table = node.(*TableName) + for i, col := range n.Columns { + node, ok = col.Accept(v) + if !ok { + return n, false + } + n.Columns[i] = node.(*ColumnName) + } + return v.Leave(n) +} + +// DropStatisticsStmt is a statement to drop extended statistics. +// Examples: +// +// DROP STATISTICS stats1; +type DropStatisticsStmt struct { + stmtNode + + StatsName string +} + +// Restore implements Node interface. +func (n *DropStatisticsStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("DROP STATISTICS ") + ctx.WriteName(n.StatsName) + return nil +} + +// Accept implements Node Accept interface. +func (n *DropStatisticsStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*DropStatisticsStmt) + return v.Leave(n) +} + +// DoStmt is the struct for DO statement. +type DoStmt struct { + stmtNode + + Exprs []ExprNode +} + +// Restore implements Node interface. +func (n *DoStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("DO ") + for i, v := range n.Exprs { + if i != 0 { + ctx.WritePlain(", ") + } + if err := v.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore DoStmt.Exprs[%d]", i) + } + } + return nil +} + +// Accept implements Node Accept interface. +func (n *DoStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*DoStmt) + for i, val := range n.Exprs { + node, ok := val.Accept(v) + if !ok { + return n, false + } + n.Exprs[i] = node.(ExprNode) + } + return v.Leave(n) +} + +// AdminStmtType is the type for admin statement. +type AdminStmtType int + +// Admin statement types. 
+const ( + AdminShowDDL AdminStmtType = iota + 1 + AdminCheckTable + AdminShowDDLJobs + AdminCancelDDLJobs + AdminPauseDDLJobs + AdminResumeDDLJobs + AdminCheckIndex + AdminRecoverIndex + AdminCleanupIndex + AdminCheckIndexRange + AdminShowDDLJobQueries + AdminShowDDLJobQueriesWithRange + AdminChecksumTable + AdminShowSlow + AdminShowNextRowID + AdminReloadExprPushdownBlacklist + AdminReloadOptRuleBlacklist + AdminPluginDisable + AdminPluginEnable + AdminFlushBindings + AdminCaptureBindings + AdminEvolveBindings + AdminReloadBindings + AdminReloadStatistics + AdminFlushPlanCache + AdminSetBDRRole + AdminShowBDRRole + AdminUnsetBDRRole +) + +// HandleRange represents a range where handle value >= Begin and < End. +type HandleRange struct { + Begin int64 + End int64 +} + +// BDRRole represents the role of the cluster in BDR mode. +type BDRRole string + +const ( + BDRRolePrimary BDRRole = "primary" + BDRRoleSecondary BDRRole = "secondary" + BDRRoleNone BDRRole = "" +) + +// DeniedByBDR checks whether the DDL is denied by BDR. +func DeniedByBDR(role BDRRole, action model.ActionType, job *model.Job) (denied bool) { + ddlType, ok := model.ActionBDRMap[action] + switch role { + case BDRRolePrimary: + if !ok { + return true + } + + // Can't add unique index on primary role. + if job != nil && (action == model.ActionAddIndex || action == model.ActionAddPrimaryKey) && + len(job.Args) >= 1 && job.Args[0].(bool) { + // job.Args[0] is unique when job.Type is ActionAddIndex or ActionAddPrimaryKey. + return true + } + + if ddlType == model.SafeDDL || ddlType == model.UnmanagementDDL { + return false + } + case BDRRoleSecondary: + if !ok { + return true + } + if ddlType == model.UnmanagementDDL { + return false + } + default: + // if user do not set bdr role, we will not deny any ddl as `none` + return false + } + + return true +} + +type StatementScope int + +const ( + StatementScopeNone StatementScope = iota + StatementScopeSession + StatementScopeInstance + StatementScopeGlobal +) + +// ShowSlowType defines the type for SlowSlow statement. +type ShowSlowType int + +const ( + // ShowSlowTop is a ShowSlowType constant. + ShowSlowTop ShowSlowType = iota + // ShowSlowRecent is a ShowSlowType constant. + ShowSlowRecent +) + +// ShowSlowKind defines the kind for SlowSlow statement when the type is ShowSlowTop. +type ShowSlowKind int + +const ( + // ShowSlowKindDefault is a ShowSlowKind constant. + ShowSlowKindDefault ShowSlowKind = iota + // ShowSlowKindInternal is a ShowSlowKind constant. + ShowSlowKindInternal + // ShowSlowKindAll is a ShowSlowKind constant. + ShowSlowKindAll +) + +// ShowSlow is used for the following command: +// +// admin show slow top [ internal | all] N +// admin show slow recent N +type ShowSlow struct { + Tp ShowSlowType + Count uint64 + Kind ShowSlowKind +} + +// Restore implements Node interface. +func (n *ShowSlow) Restore(ctx *format.RestoreCtx) error { + switch n.Tp { + case ShowSlowRecent: + ctx.WriteKeyWord("RECENT ") + case ShowSlowTop: + ctx.WriteKeyWord("TOP ") + switch n.Kind { + case ShowSlowKindDefault: + // do nothing + case ShowSlowKindInternal: + ctx.WriteKeyWord("INTERNAL ") + case ShowSlowKindAll: + ctx.WriteKeyWord("ALL ") + default: + return errors.New("Unsupported kind of ShowSlowTop") + } + default: + return errors.New("Unsupported type of ShowSlow") + } + ctx.WritePlainf("%d", n.Count) + return nil +} + +// LimitSimple is the struct for Admin statement limit option. 
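Aside, not part of the patch: DeniedByBDR above gates DDL by BDR role, allowing only DDL classified as safe (and no unique-index additions) on the primary, and only unmanaged DDL on the secondary. A rough illustration follows; the concrete outcome depends on model.ActionBDRMap, so the result is only printed, not asserted:

package main

import (
	"fmt"

	"github.com/pingcap/tidb/pkg/parser/ast"
	"github.com/pingcap/tidb/pkg/parser/model"
)

func main() {
	// With a nil job the unique-index special case is skipped, so only the
	// classification of the action in model.ActionBDRMap matters.
	roles := []ast.BDRRole{ast.BDRRolePrimary, ast.BDRRoleSecondary, ast.BDRRoleNone}
	for _, role := range roles {
		denied := ast.DeniedByBDR(role, model.ActionAddColumn, nil)
		fmt.Printf("role=%q addColumn denied=%v\n", role, denied)
	}
}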
+type LimitSimple struct { + Count uint64 + Offset uint64 +} + +// AdminStmt is the struct for Admin statement. +type AdminStmt struct { + stmtNode + + Tp AdminStmtType + Index string + Tables []*TableName + JobIDs []int64 + JobNumber int64 + + HandleRanges []HandleRange + ShowSlow *ShowSlow + Plugins []string + Where ExprNode + StatementScope StatementScope + LimitSimple LimitSimple + BDRRole BDRRole +} + +// Restore implements Node interface. +func (n *AdminStmt) Restore(ctx *format.RestoreCtx) error { + restoreTables := func() error { + for i, v := range n.Tables { + if i != 0 { + ctx.WritePlain(", ") + } + if err := v.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore AdminStmt.Tables[%d]", i) + } + } + return nil + } + restoreJobIDs := func() { + for i, v := range n.JobIDs { + if i != 0 { + ctx.WritePlain(", ") + } + ctx.WritePlainf("%d", v) + } + } + + ctx.WriteKeyWord("ADMIN ") + switch n.Tp { + case AdminShowDDL: + ctx.WriteKeyWord("SHOW DDL") + case AdminShowDDLJobs: + ctx.WriteKeyWord("SHOW DDL JOBS") + if n.JobNumber != 0 { + ctx.WritePlainf(" %d", n.JobNumber) + } + if n.Where != nil { + ctx.WriteKeyWord(" WHERE ") + if err := n.Where.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore ShowStmt.Where") + } + } + case AdminShowNextRowID: + ctx.WriteKeyWord("SHOW ") + if err := restoreTables(); err != nil { + return err + } + ctx.WriteKeyWord(" NEXT_ROW_ID") + case AdminCheckTable: + ctx.WriteKeyWord("CHECK TABLE ") + if err := restoreTables(); err != nil { + return err + } + case AdminCheckIndex: + ctx.WriteKeyWord("CHECK INDEX ") + if err := restoreTables(); err != nil { + return err + } + ctx.WritePlainf(" %s", n.Index) + case AdminRecoverIndex: + ctx.WriteKeyWord("RECOVER INDEX ") + if err := restoreTables(); err != nil { + return err + } + ctx.WritePlainf(" %s", n.Index) + case AdminCleanupIndex: + ctx.WriteKeyWord("CLEANUP INDEX ") + if err := restoreTables(); err != nil { + return err + } + ctx.WritePlainf(" %s", n.Index) + case AdminCheckIndexRange: + ctx.WriteKeyWord("CHECK INDEX ") + if err := restoreTables(); err != nil { + return err + } + ctx.WritePlainf(" %s", n.Index) + if n.HandleRanges != nil { + ctx.WritePlain(" ") + for i, v := range n.HandleRanges { + if i != 0 { + ctx.WritePlain(", ") + } + ctx.WritePlainf("(%d,%d)", v.Begin, v.End) + } + } + case AdminChecksumTable: + ctx.WriteKeyWord("CHECKSUM TABLE ") + if err := restoreTables(); err != nil { + return err + } + case AdminCancelDDLJobs: + ctx.WriteKeyWord("CANCEL DDL JOBS ") + restoreJobIDs() + case AdminPauseDDLJobs: + ctx.WriteKeyWord("PAUSE DDL JOBS ") + restoreJobIDs() + case AdminResumeDDLJobs: + ctx.WriteKeyWord("RESUME DDL JOBS ") + restoreJobIDs() + case AdminShowDDLJobQueries: + ctx.WriteKeyWord("SHOW DDL JOB QUERIES ") + restoreJobIDs() + case AdminShowDDLJobQueriesWithRange: + ctx.WriteKeyWord("SHOW DDL JOB QUERIES LIMIT ") + ctx.WritePlainf("%d, %d", n.LimitSimple.Offset, n.LimitSimple.Count) + case AdminShowSlow: + ctx.WriteKeyWord("SHOW SLOW ") + if err := n.ShowSlow.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore AdminStmt.ShowSlow") + } + case AdminReloadExprPushdownBlacklist: + ctx.WriteKeyWord("RELOAD EXPR_PUSHDOWN_BLACKLIST") + case AdminReloadOptRuleBlacklist: + ctx.WriteKeyWord("RELOAD OPT_RULE_BLACKLIST") + case AdminPluginEnable: + ctx.WriteKeyWord("PLUGINS ENABLE") + for i, v := range n.Plugins { + if i == 0 { + ctx.WritePlain(" ") + } else { + ctx.WritePlain(", ") 
+ } + ctx.WritePlain(v) + } + case AdminPluginDisable: + ctx.WriteKeyWord("PLUGINS DISABLE") + for i, v := range n.Plugins { + if i == 0 { + ctx.WritePlain(" ") + } else { + ctx.WritePlain(", ") + } + ctx.WritePlain(v) + } + case AdminFlushBindings: + ctx.WriteKeyWord("FLUSH BINDINGS") + case AdminCaptureBindings: + ctx.WriteKeyWord("CAPTURE BINDINGS") + case AdminEvolveBindings: + ctx.WriteKeyWord("EVOLVE BINDINGS") + case AdminReloadBindings: + ctx.WriteKeyWord("RELOAD BINDINGS") + case AdminReloadStatistics: + ctx.WriteKeyWord("RELOAD STATS_EXTENDED") + case AdminFlushPlanCache: + if n.StatementScope == StatementScopeSession { + ctx.WriteKeyWord("FLUSH SESSION PLAN_CACHE") + } else if n.StatementScope == StatementScopeInstance { + ctx.WriteKeyWord("FLUSH INSTANCE PLAN_CACHE") + } else if n.StatementScope == StatementScopeGlobal { + ctx.WriteKeyWord("FLUSH GLOBAL PLAN_CACHE") + } + case AdminSetBDRRole: + switch n.BDRRole { + case BDRRolePrimary: + ctx.WriteKeyWord("SET BDR ROLE PRIMARY") + case BDRRoleSecondary: + ctx.WriteKeyWord("SET BDR ROLE SECONDARY") + default: + return errors.New("Unsupported BDR role") + } + case AdminShowBDRRole: + ctx.WriteKeyWord("SHOW BDR ROLE") + case AdminUnsetBDRRole: + ctx.WriteKeyWord("UNSET BDR ROLE") + default: + return errors.New("Unsupported AdminStmt type") + } + return nil +} + +// Accept implements Node Accept interface. +func (n *AdminStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + + n = newNode.(*AdminStmt) + for i, val := range n.Tables { + node, ok := val.Accept(v) + if !ok { + return n, false + } + n.Tables[i] = node.(*TableName) + } + + if n.Where != nil { + node, ok := n.Where.Accept(v) + if !ok { + return n, false + } + n.Where = node.(ExprNode) + } + + return v.Leave(n) +} + +// RoleOrPriv is a temporary structure to be further processed into auth.RoleIdentity or PrivElem +type RoleOrPriv struct { + Symbols string // hold undecided symbols + Node interface{} // hold auth.RoleIdentity or PrivElem that can be sure when parsing +} + +func (n *RoleOrPriv) ToRole() (*auth.RoleIdentity, error) { + if n.Node != nil { + if r, ok := n.Node.(*auth.RoleIdentity); ok { + return r, nil + } + return nil, errors.Errorf("can't convert to RoleIdentity, type %T", n.Node) + } + return &auth.RoleIdentity{Username: n.Symbols, Hostname: "%"}, nil +} + +func (n *RoleOrPriv) ToPriv() (*PrivElem, error) { + if n.Node != nil { + if p, ok := n.Node.(*PrivElem); ok { + return p, nil + } + return nil, errors.Errorf("can't convert to PrivElem, type %T", n.Node) + } + if len(n.Symbols) == 0 { + return nil, errors.New("symbols should not be length 0") + } + return &PrivElem{Priv: mysql.ExtendedPriv, Name: n.Symbols}, nil +} + +// PrivElem is the privilege type and optional column list. +type PrivElem struct { + node + + Priv mysql.PrivilegeType + Cols []*ColumnName + Name string +} + +// Restore implements Node interface. 
+func (n *PrivElem) Restore(ctx *format.RestoreCtx) error { + if n.Priv == mysql.AllPriv { + ctx.WriteKeyWord("ALL") + } else if n.Priv == mysql.ExtendedPriv { + ctx.WriteKeyWord(n.Name) + } else { + str, ok := mysql.Priv2Str[n.Priv] + if !ok { + return errors.New("Undefined privilege type") + } + ctx.WriteKeyWord(str) + } + if n.Cols != nil { + ctx.WritePlain(" (") + for i, v := range n.Cols { + if i != 0 { + ctx.WritePlain(",") + } + if err := v.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore PrivElem.Cols[%d]", i) + } + } + ctx.WritePlain(")") + } + return nil +} + +// Accept implements Node Accept interface. +func (n *PrivElem) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*PrivElem) + for i, val := range n.Cols { + node, ok := val.Accept(v) + if !ok { + return n, false + } + n.Cols[i] = node.(*ColumnName) + } + return v.Leave(n) +} + +// ObjectTypeType is the type for object type. +type ObjectTypeType int + +const ( + // ObjectTypeNone is for empty object type. + ObjectTypeNone ObjectTypeType = iota + 1 + // ObjectTypeTable means the following object is a table. + ObjectTypeTable + // ObjectTypeFunction means the following object is a stored function. + ObjectTypeFunction + // ObjectTypeProcedure means the following object is a stored procedure. + ObjectTypeProcedure +) + +// Restore implements Node interface. +func (n ObjectTypeType) Restore(ctx *format.RestoreCtx) error { + switch n { + case ObjectTypeNone: + // do nothing + case ObjectTypeTable: + ctx.WriteKeyWord("TABLE") + case ObjectTypeFunction: + ctx.WriteKeyWord("FUNCTION") + case ObjectTypeProcedure: + ctx.WriteKeyWord("PROCEDURE") + default: + return errors.New("Unsupported object type") + } + return nil +} + +// GrantLevelType is the type for grant level. +type GrantLevelType int + +const ( + // GrantLevelNone is the dummy const for default value. + GrantLevelNone GrantLevelType = iota + 1 + // GrantLevelGlobal means the privileges are administrative or apply to all databases on a given server. + GrantLevelGlobal + // GrantLevelDB means the privileges apply to all objects in a given database. + GrantLevelDB + // GrantLevelTable means the privileges apply to all columns in a given table. + GrantLevelTable +) + +// GrantLevel is used for store the privilege scope. +type GrantLevel struct { + Level GrantLevelType + DBName string + TableName string +} + +// Restore implements Node interface. +func (n *GrantLevel) Restore(ctx *format.RestoreCtx) error { + switch n.Level { + case GrantLevelDB: + if n.DBName == "" { + ctx.WritePlain("*") + } else { + ctx.WriteName(n.DBName) + ctx.WritePlain(".*") + } + case GrantLevelGlobal: + ctx.WritePlain("*.*") + case GrantLevelTable: + if n.DBName != "" { + ctx.WriteName(n.DBName) + ctx.WritePlain(".") + } + ctx.WriteName(n.TableName) + } + return nil +} + +// RevokeStmt is the struct for REVOKE statement. +type RevokeStmt struct { + stmtNode + + Privs []*PrivElem + ObjectType ObjectTypeType + Level *GrantLevel + Users []*UserSpec +} + +// Restore implements Node interface. 
+func (n *RevokeStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("REVOKE ") + for i, v := range n.Privs { + if i != 0 { + ctx.WritePlain(", ") + } + if err := v.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore RevokeStmt.Privs[%d]", i) + } + } + ctx.WriteKeyWord(" ON ") + if n.ObjectType != ObjectTypeNone { + if err := n.ObjectType.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore RevokeStmt.ObjectType") + } + ctx.WritePlain(" ") + } + if err := n.Level.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore RevokeStmt.Level") + } + ctx.WriteKeyWord(" FROM ") + for i, v := range n.Users { + if i != 0 { + ctx.WritePlain(", ") + } + if err := v.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore RevokeStmt.Users[%d]", i) + } + } + return nil +} + +// Accept implements Node Accept interface. +func (n *RevokeStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*RevokeStmt) + for i, val := range n.Privs { + node, ok := val.Accept(v) + if !ok { + return n, false + } + n.Privs[i] = node.(*PrivElem) + } + return v.Leave(n) +} + +// RevokeStmt is the struct for REVOKE statement. +type RevokeRoleStmt struct { + stmtNode + + Roles []*auth.RoleIdentity + Users []*auth.UserIdentity +} + +// Restore implements Node interface. +func (n *RevokeRoleStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("REVOKE ") + for i, role := range n.Roles { + if i != 0 { + ctx.WritePlain(", ") + } + if err := role.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore RevokeRoleStmt.Roles[%d]", i) + } + } + ctx.WriteKeyWord(" FROM ") + for i, v := range n.Users { + if i != 0 { + ctx.WritePlain(", ") + } + if err := v.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore RevokeRoleStmt.Users[%d]", i) + } + } + return nil +} + +// Accept implements Node Accept interface. +func (n *RevokeRoleStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*RevokeRoleStmt) + return v.Leave(n) +} + +// GrantStmt is the struct for GRANT statement. +type GrantStmt struct { + stmtNode + + Privs []*PrivElem + ObjectType ObjectTypeType + Level *GrantLevel + Users []*UserSpec + AuthTokenOrTLSOptions []*AuthTokenOrTLSOption + WithGrant bool +} + +// Restore implements Node interface. 
+func (n *GrantStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("GRANT ") + for i, v := range n.Privs { + if i != 0 && v.Priv != 0 { + ctx.WritePlain(", ") + } else if v.Priv == 0 { + ctx.WritePlain(" ") + } + if err := v.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore GrantStmt.Privs[%d]", i) + } + } + ctx.WriteKeyWord(" ON ") + if n.ObjectType != ObjectTypeNone { + if err := n.ObjectType.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore GrantStmt.ObjectType") + } + ctx.WritePlain(" ") + } + if err := n.Level.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore GrantStmt.Level") + } + ctx.WriteKeyWord(" TO ") + for i, v := range n.Users { + if i != 0 { + ctx.WritePlain(", ") + } + if err := v.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore GrantStmt.Users[%d]", i) + } + } + if n.AuthTokenOrTLSOptions != nil { + if len(n.AuthTokenOrTLSOptions) != 0 { + ctx.WriteKeyWord(" REQUIRE ") + } + for i, option := range n.AuthTokenOrTLSOptions { + if i != 0 { + ctx.WriteKeyWord(" AND ") + } + if err := option.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore GrantStmt.AuthTokenOrTLSOptions[%d]", i) + } + } + } + if n.WithGrant { + ctx.WriteKeyWord(" WITH GRANT OPTION") + } + return nil +} + +// SecureText implements SensitiveStatement interface. +func (n *GrantStmt) SecureText() string { + text := n.text + // Filter "identified by xxx" because it would expose password information. + idx := strings.Index(strings.ToLower(text), "identified") + if idx > 0 { + text = text[:idx] + } + return text +} + +// Accept implements Node Accept interface. +func (n *GrantStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*GrantStmt) + for i, val := range n.Privs { + node, ok := val.Accept(v) + if !ok { + return n, false + } + n.Privs[i] = node.(*PrivElem) + } + return v.Leave(n) +} + +// GrantProxyStmt is the struct for GRANT PROXY statement. +type GrantProxyStmt struct { + stmtNode + + LocalUser *auth.UserIdentity + ExternalUsers []*auth.UserIdentity + WithGrant bool +} + +// Accept implements Node Accept interface. +func (n *GrantProxyStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*GrantProxyStmt) + return v.Leave(n) +} + +// Restore implements Node interface. +func (n *GrantProxyStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("GRANT PROXY ON ") + if err := n.LocalUser.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore GrantProxyStmt.LocalUser") + } + ctx.WriteKeyWord(" TO ") + for i, v := range n.ExternalUsers { + if i != 0 { + ctx.WritePlain(", ") + } + if err := v.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore GrantProxyStmt.ExternalUsers[%d]", i) + } + } + if n.WithGrant { + ctx.WriteKeyWord(" WITH GRANT OPTION") + } + return nil +} + +// GrantRoleStmt is the struct for GRANT TO statement. +type GrantRoleStmt struct { + stmtNode + + Roles []*auth.RoleIdentity + Users []*auth.UserIdentity +} + +// Accept implements Node Accept interface. 
+func (n *GrantRoleStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*GrantRoleStmt) + return v.Leave(n) +} + +// Restore implements Node interface. +func (n *GrantRoleStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("GRANT ") + if len(n.Roles) > 0 { + for i, role := range n.Roles { + if i != 0 { + ctx.WritePlain(", ") + } + if err := role.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore GrantRoleStmt.Roles[%d]", i) + } + } + } + ctx.WriteKeyWord(" TO ") + for i, v := range n.Users { + if i != 0 { + ctx.WritePlain(", ") + } + if err := v.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore GrantStmt.Users[%d]", i) + } + } + return nil +} + +// SecureText implements SensitiveStatement interface. +func (n *GrantRoleStmt) SecureText() string { + text := n.text + // Filter "identified by xxx" because it would expose password information. + idx := strings.Index(strings.ToLower(text), "identified") + if idx > 0 { + text = text[:idx] + } + return text +} + +// ShutdownStmt is a statement to stop the TiDB server. +// See https://dev.mysql.com/doc/refman/5.7/en/shutdown.html +type ShutdownStmt struct { + stmtNode +} + +// Restore implements Node interface. +func (n *ShutdownStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("SHUTDOWN") + return nil +} + +// Accept implements Node Accept interface. +func (n *ShutdownStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*ShutdownStmt) + return v.Leave(n) +} + +// RestartStmt is a statement to restart the TiDB server. +// See https://dev.mysql.com/doc/refman/8.0/en/restart.html +type RestartStmt struct { + stmtNode +} + +// Restore implements Node interface. +func (n *RestartStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("RESTART") + return nil +} + +// Accept implements Node Accept interface. +func (n *RestartStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*RestartStmt) + return v.Leave(n) +} + +// HelpStmt is a statement for server side help +// See https://dev.mysql.com/doc/refman/8.0/en/help.html +type HelpStmt struct { + stmtNode + + Topic string +} + +// Restore implements Node interface. +func (n *HelpStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("HELP ") + ctx.WriteString(n.Topic) + return nil +} + +// Accept implements Node Accept interface. +func (n *HelpStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*HelpStmt) + return v.Leave(n) +} + +// RenameUserStmt is a statement to rename a user. +// See http://dev.mysql.com/doc/refman/5.7/en/rename-user.html +type RenameUserStmt struct { + stmtNode + + UserToUsers []*UserToUser +} + +// Restore implements Node interface. +func (n *RenameUserStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("RENAME USER ") + for index, user2user := range n.UserToUsers { + if index != 0 { + ctx.WritePlain(", ") + } + if err := user2user.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore RenameUserStmt.UserToUsers") + } + } + return nil +} + +// Accept implements Node Accept interface. 
+func (n *RenameUserStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*RenameUserStmt) + + for i, t := range n.UserToUsers { + node, ok := t.Accept(v) + if !ok { + return n, false + } + n.UserToUsers[i] = node.(*UserToUser) + } + return v.Leave(n) +} + +// UserToUser represents renaming old user to new user used in RenameUserStmt. +type UserToUser struct { + node + OldUser *auth.UserIdentity + NewUser *auth.UserIdentity +} + +// Restore implements Node interface. +func (n *UserToUser) Restore(ctx *format.RestoreCtx) error { + if err := n.OldUser.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore UserToUser.OldUser") + } + ctx.WriteKeyWord(" TO ") + if err := n.NewUser.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore UserToUser.NewUser") + } + return nil +} + +// Accept implements Node Accept interface. +func (n *UserToUser) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*UserToUser) + return v.Leave(n) +} + +type BRIEKind uint8 +type BRIEOptionType uint16 + +const ( + BRIEKindBackup BRIEKind = iota + BRIEKindCancelJob + BRIEKindStreamStart + BRIEKindStreamMetaData + BRIEKindStreamStatus + BRIEKindStreamPause + BRIEKindStreamResume + BRIEKindStreamStop + BRIEKindStreamPurge + BRIEKindRestore + BRIEKindRestorePIT + BRIEKindShowJob + BRIEKindShowQuery + BRIEKindShowBackupMeta + // common BRIE options + BRIEOptionRateLimit BRIEOptionType = iota + 1 + BRIEOptionConcurrency + BRIEOptionChecksum + BRIEOptionSendCreds + BRIEOptionCheckpoint + BRIEOptionStartTS + BRIEOptionUntilTS + BRIEOptionChecksumConcurrency + BRIEOptionEncryptionMethod + BRIEOptionEncryptionKeyFile + // backup options + BRIEOptionBackupTimeAgo + BRIEOptionBackupTS + BRIEOptionBackupTSO + BRIEOptionLastBackupTS + BRIEOptionLastBackupTSO + BRIEOptionGCTTL + BRIEOptionCompressionLevel + BRIEOptionCompression + BRIEOptionIgnoreStats + BRIEOptionLoadStats + // restore options + BRIEOptionOnline + BRIEOptionFullBackupStorage + BRIEOptionRestoredTS + BRIEOptionWaitTiflashReady + BRIEOptionWithSysTable + // import options + BRIEOptionAnalyze + BRIEOptionBackend + BRIEOptionOnDuplicate + BRIEOptionSkipSchemaFiles + BRIEOptionStrictFormat + BRIEOptionTiKVImporter + BRIEOptionResume + // CSV options + BRIEOptionCSVBackslashEscape + BRIEOptionCSVDelimiter + BRIEOptionCSVHeader + BRIEOptionCSVNotNull + BRIEOptionCSVNull + BRIEOptionCSVSeparator + BRIEOptionCSVTrimLastSeparators + + BRIECSVHeaderIsColumns = ^uint64(0) +) + +type BRIEOptionLevel uint64 + +const ( + BRIEOptionLevelOff BRIEOptionLevel = iota // equals FALSE + BRIEOptionLevelRequired // equals TRUE + BRIEOptionLevelOptional +) + +func (kind BRIEKind) String() string { + switch kind { + case BRIEKindBackup: + return "BACKUP" + case BRIEKindRestore: + return "RESTORE" + case BRIEKindStreamStart: + return "BACKUP LOGS" + case BRIEKindStreamStop: + return "STOP BACKUP LOGS" + case BRIEKindStreamPause: + return "PAUSE BACKUP LOGS" + case BRIEKindStreamResume: + return "RESUME BACKUP LOGS" + case BRIEKindStreamStatus: + return "SHOW BACKUP LOGS STATUS" + case BRIEKindStreamMetaData: + return "SHOW BACKUP LOGS METADATA" + case BRIEKindStreamPurge: + return "PURGE BACKUP LOGS" + case BRIEKindRestorePIT: + return "RESTORE POINT" + case BRIEKindShowJob: + return "SHOW BR JOB" + case BRIEKindShowQuery: + return "SHOW BR JOB QUERY" + 
case BRIEKindCancelJob: + return "CANCEL BR JOB" + case BRIEKindShowBackupMeta: + return "SHOW BACKUP METADATA" + default: + return "" + } +} + +func (kind BRIEOptionType) String() string { + switch kind { + case BRIEOptionRateLimit: + return "RATE_LIMIT" + case BRIEOptionConcurrency: + return "CONCURRENCY" + case BRIEOptionChecksum: + return "CHECKSUM" + case BRIEOptionSendCreds: + return "SEND_CREDENTIALS_TO_TIKV" + case BRIEOptionBackupTimeAgo, BRIEOptionBackupTS, BRIEOptionBackupTSO: + return "SNAPSHOT" + case BRIEOptionLastBackupTS, BRIEOptionLastBackupTSO: + return "LAST_BACKUP" + case BRIEOptionOnline: + return "ONLINE" + case BRIEOptionCheckpoint: + return "CHECKPOINT" + case BRIEOptionAnalyze: + return "ANALYZE" + case BRIEOptionBackend: + return "BACKEND" + case BRIEOptionOnDuplicate: + return "ON_DUPLICATE" + case BRIEOptionSkipSchemaFiles: + return "SKIP_SCHEMA_FILES" + case BRIEOptionStrictFormat: + return "STRICT_FORMAT" + case BRIEOptionTiKVImporter: + return "TIKV_IMPORTER" + case BRIEOptionResume: + return "RESUME" + case BRIEOptionCSVBackslashEscape: + return "CSV_BACKSLASH_ESCAPE" + case BRIEOptionCSVDelimiter: + return "CSV_DELIMITER" + case BRIEOptionCSVHeader: + return "CSV_HEADER" + case BRIEOptionCSVNotNull: + return "CSV_NOT_NULL" + case BRIEOptionCSVNull: + return "CSV_NULL" + case BRIEOptionCSVSeparator: + return "CSV_SEPARATOR" + case BRIEOptionCSVTrimLastSeparators: + return "CSV_TRIM_LAST_SEPARATORS" + case BRIEOptionFullBackupStorage: + return "FULL_BACKUP_STORAGE" + case BRIEOptionRestoredTS: + return "RESTORED_TS" + case BRIEOptionStartTS: + return "START_TS" + case BRIEOptionUntilTS: + return "UNTIL_TS" + case BRIEOptionGCTTL: + return "GC_TTL" + case BRIEOptionWaitTiflashReady: + return "WAIT_TIFLASH_READY" + case BRIEOptionWithSysTable: + return "WITH_SYS_TABLE" + case BRIEOptionIgnoreStats: + return "IGNORE_STATS" + case BRIEOptionLoadStats: + return "LOAD_STATS" + case BRIEOptionChecksumConcurrency: + return "CHECKSUM_CONCURRENCY" + case BRIEOptionCompressionLevel: + return "COMPRESSION_LEVEL" + case BRIEOptionCompression: + return "COMPRESSION_TYPE" + case BRIEOptionEncryptionMethod: + return "ENCRYPTION_METHOD" + case BRIEOptionEncryptionKeyFile: + return "ENCRYPTION_KEY_FILE" + default: + return "" + } +} + +func (level BRIEOptionLevel) String() string { + switch level { + case BRIEOptionLevelOff: + return "OFF" + case BRIEOptionLevelOptional: + return "OPTIONAL" + case BRIEOptionLevelRequired: + return "REQUIRED" + default: + return "" + } +} + +type BRIEOption struct { + Tp BRIEOptionType + StrValue string + UintValue uint64 +} + +func (opt *BRIEOption) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord(opt.Tp.String()) + ctx.WritePlain(" = ") + switch opt.Tp { + case BRIEOptionBackupTS, BRIEOptionLastBackupTS, BRIEOptionBackend, BRIEOptionOnDuplicate, BRIEOptionTiKVImporter, BRIEOptionCSVDelimiter, BRIEOptionCSVNull, BRIEOptionCSVSeparator, BRIEOptionFullBackupStorage, BRIEOptionRestoredTS, BRIEOptionStartTS, BRIEOptionUntilTS, BRIEOptionGCTTL, BRIEOptionCompression, BRIEOptionEncryptionMethod, BRIEOptionEncryptionKeyFile: + ctx.WriteString(opt.StrValue) + case BRIEOptionBackupTimeAgo: + ctx.WritePlainf("%d ", opt.UintValue/1000) + ctx.WriteKeyWord("MICROSECOND AGO") + case BRIEOptionRateLimit: + ctx.WritePlainf("%d ", opt.UintValue/1048576) + ctx.WriteKeyWord("MB") + ctx.WritePlain("/") + ctx.WriteKeyWord("SECOND") + case BRIEOptionCSVHeader: + if opt.UintValue == BRIECSVHeaderIsColumns { + ctx.WriteKeyWord("COLUMNS") + } else { + 
ctx.WritePlainf("%d", opt.UintValue) + } + case BRIEOptionChecksum, BRIEOptionAnalyze: + // BACKUP/RESTORE doesn't support OPTIONAL value for now, should warn at executor + ctx.WriteKeyWord(BRIEOptionLevel(opt.UintValue).String()) + default: + ctx.WritePlainf("%d", opt.UintValue) + } + return nil +} + +// BRIEStmt is a statement for backup, restore, import and export. +type BRIEStmt struct { + stmtNode + + Kind BRIEKind + Schemas []string + Tables []*TableName + Storage string + JobID int64 + Options []*BRIEOption +} + +func (n *BRIEStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*BRIEStmt) + for i, val := range n.Tables { + node, ok := val.Accept(v) + if !ok { + return n, false + } + n.Tables[i] = node.(*TableName) + } + return v.Leave(n) +} + +func (n *BRIEStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord(n.Kind.String()) + + switch n.Kind { + case BRIEKindRestore, BRIEKindBackup: + switch { + case len(n.Tables) != 0: + ctx.WriteKeyWord(" TABLE ") + for index, table := range n.Tables { + if index != 0 { + ctx.WritePlain(", ") + } + if err := table.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore BRIEStmt.Tables[%d]", index) + } + } + case len(n.Schemas) != 0: + ctx.WriteKeyWord(" DATABASE ") + for index, schema := range n.Schemas { + if index != 0 { + ctx.WritePlain(", ") + } + ctx.WriteName(schema) + } + default: + ctx.WriteKeyWord(" DATABASE") + ctx.WritePlain(" *") + } + + if n.Kind == BRIEKindBackup { + ctx.WriteKeyWord(" TO ") + ctx.WriteString(n.Storage) + } else { + ctx.WriteKeyWord(" FROM ") + ctx.WriteString(n.Storage) + } + case BRIEKindCancelJob, BRIEKindShowJob, BRIEKindShowQuery: + ctx.WritePlainf(" %d", n.JobID) + case BRIEKindStreamStart: + ctx.WriteKeyWord(" TO ") + ctx.WriteString(n.Storage) + case BRIEKindRestorePIT, BRIEKindStreamMetaData, BRIEKindShowBackupMeta, BRIEKindStreamPurge: + ctx.WriteKeyWord(" FROM ") + ctx.WriteString(n.Storage) + } + + for _, opt := range n.Options { + ctx.WritePlain(" ") + if err := opt.Restore(ctx); err != nil { + return err + } + } + + return nil +} + +// RedactURL redacts the secret tokens in the URL. only S3 url need redaction for now. +// if the url is not a valid url, return the original string. +func RedactURL(str string) string { + // FIXME: this solution is not scalable, and duplicates some logic from BR. 
+ u, err := url.Parse(str) + if err != nil { + return str + } + scheme := u.Scheme + failpoint.Inject("forceRedactURL", func() { + scheme = "s3" + }) + switch strings.ToLower(scheme) { + case "s3", "ks3": + values := u.Query() + for k := range values { + // see below on why we normalize key + // https://github.com/pingcap/tidb/blob/a7c0d95f16ea2582bb569278c3f829403e6c3a7e/br/pkg/storage/parse.go#L163 + normalizedKey := strings.ToLower(strings.ReplaceAll(k, "_", "-")) + if normalizedKey == "access-key" || normalizedKey == "secret-access-key" || normalizedKey == "session-token" { + values[k] = []string{"xxxxxx"} + } + } + u.RawQuery = values.Encode() + } + return u.String() +} + +// SecureText implements SensitiveStmtNode +func (n *BRIEStmt) SecureText() string { + redactedStmt := &BRIEStmt{ + Kind: n.Kind, + Schemas: n.Schemas, + Tables: n.Tables, + Storage: RedactURL(n.Storage), + Options: n.Options, + } + + var sb strings.Builder + _ = redactedStmt.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)) + return sb.String() +} + +type ImportIntoActionTp string + +const ( + ImportIntoCancel ImportIntoActionTp = "cancel" +) + +// ImportIntoActionStmt represent CANCEL IMPORT INTO JOB statement. +// will support pause/resume/drop later. +type ImportIntoActionStmt struct { + stmtNode + + Tp ImportIntoActionTp + JobID int64 +} + +func (n *ImportIntoActionStmt) Accept(v Visitor) (Node, bool) { + newNode, _ := v.Enter(n) + return v.Leave(newNode) +} + +func (n *ImportIntoActionStmt) Restore(ctx *format.RestoreCtx) error { + if n.Tp != ImportIntoCancel { + return errors.Errorf("invalid IMPORT INTO action type: %s", n.Tp) + } + ctx.WriteKeyWord("CANCEL IMPORT JOB ") + ctx.WritePlainf("%d", n.JobID) + return nil +} + +// Ident is the table identifier composed of schema name and table name. +type Ident struct { + Schema model.CIStr + Name model.CIStr +} + +// String implements fmt.Stringer interface. +func (i Ident) String() string { + if i.Schema.O == "" { + return i.Name.O + } + return fmt.Sprintf("%s.%s", i.Schema, i.Name) +} + +// SelectStmtOpts wrap around select hints and switches +type SelectStmtOpts struct { + Distinct bool + SQLBigResult bool + SQLBufferResult bool + SQLCache bool + SQLSmallResult bool + CalcFoundRows bool + StraightJoin bool + Priority mysql.PriorityEnum + TableHints []*TableOptimizerHint + ExplicitAll bool +} + +// TableOptimizerHint is Table level optimizer hint +type TableOptimizerHint struct { + node + // HintName is the name or alias of the table(s) which the hint will affect. + // Table hints has no schema info + // It allows only table name or alias (if table has an alias) + HintName model.CIStr + // HintData is the payload of the hint. The actual type of this field + // is defined differently as according `HintName`. Define as following: + // + // Statement Execution Time Optimizer Hints + // See https://dev.mysql.com/doc/refman/5.7/en/optimizer-hints.html#optimizer-hints-execution-time + // - MAX_EXECUTION_TIME => uint64 + // - MEMORY_QUOTA => int64 + // - QUERY_TYPE => model.CIStr + // + // Time Range is used to hint the time range of inspection tables + // e.g: select /*+ time_range('','') */ * from information_schema.inspection_result. + // - TIME_RANGE => ast.HintTimeRange + // - READ_FROM_STORAGE => model.CIStr + // - USE_TOJA => bool + // - NTH_PLAN => int64 + HintData interface{} + // QBName is the default effective query block of this hint. 
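To illustrate the redaction behaviour of RedactURL above, a small hedged sketch follows; the ast import path is an assumption, and the parameter ordering of the re-encoded query string may differ because url.Values.Encode sorts keys.

package main

import (
	"fmt"

	"github.com/pingcap/tidb/pkg/parser/ast" // assumed import path for the package containing RedactURL
)

func main() {
	raw := "s3://bucket/prefix?access-key=AKIAEXAMPLE&secret-access-key=topsecret&region=us-west-2"
	// For s3/ks3 URLs the access-key, secret-access-key and session-token
	// values (after normalizing '_' to '-') should come back as "xxxxxx",
	// while non-credential parameters such as region are preserved.
	fmt.Println(ast.RedactURL(raw))
}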
+ QBName model.CIStr + Tables []HintTable + Indexes []model.CIStr +} + +// HintTimeRange is the payload of `TIME_RANGE` hint +type HintTimeRange struct { + From string + To string +} + +// HintSetVar is the payload of `SET_VAR` hint +type HintSetVar struct { + VarName string + Value string +} + +// HintTable is table in the hint. It may have query block info. +type HintTable struct { + DBName model.CIStr + TableName model.CIStr + QBName model.CIStr + PartitionList []model.CIStr +} + +func (ht *HintTable) Restore(ctx *format.RestoreCtx) { + if !ctx.Flags.HasWithoutSchemaNameFlag() { + if ht.DBName.L != "" { + ctx.WriteName(ht.DBName.String()) + ctx.WriteKeyWord(".") + } + } + ctx.WriteName(ht.TableName.String()) + if ht.QBName.L != "" { + ctx.WriteKeyWord("@") + ctx.WriteName(ht.QBName.String()) + } + if len(ht.PartitionList) > 0 { + ctx.WriteKeyWord(" PARTITION") + ctx.WritePlain("(") + for i, p := range ht.PartitionList { + if i > 0 { + ctx.WritePlain(", ") + } + ctx.WriteName(p.String()) + } + ctx.WritePlain(")") + } +} + +// Restore implements Node interface. +func (n *TableOptimizerHint) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord(n.HintName.String()) + ctx.WritePlain("(") + if n.QBName.L != "" { + if n.HintName.L != "qb_name" { + ctx.WriteKeyWord("@") + } + ctx.WriteName(n.QBName.String()) + } + if n.HintName.L == "qb_name" && len(n.Tables) == 0 { + ctx.WritePlain(")") + return nil + } + // Hints without args except query block. + switch n.HintName.L { + case "mpp_1phase_agg", "mpp_2phase_agg", "hash_agg", "stream_agg", "agg_to_cop", "read_consistent_replica", "no_index_merge", "ignore_plan_cache", "limit_to_cop", "straight_join", "merge", "no_decorrelate": + ctx.WritePlain(")") + return nil + } + if n.QBName.L != "" { + ctx.WritePlain(" ") + } + // Hints with args except query block. + switch n.HintName.L { + case "max_execution_time": + ctx.WritePlainf("%d", n.HintData.(uint64)) + case "resource_group": + ctx.WriteName(n.HintData.(string)) + case "nth_plan": + ctx.WritePlainf("%d", n.HintData.(int64)) + case "tidb_hj", "tidb_smj", "tidb_inlj", "hash_join", "hash_join_build", "hash_join_probe", "merge_join", "inl_join", + "broadcast_join", "shuffle_join", "inl_hash_join", "inl_merge_join", "leading", "no_hash_join", "no_merge_join", + "no_index_join", "no_index_hash_join", "no_index_merge_join": + for i, table := range n.Tables { + if i != 0 { + ctx.WritePlain(", ") + } + table.Restore(ctx) + } + case "use_index", "ignore_index", "use_index_merge", "force_index", "order_index", "no_order_index": + n.Tables[0].Restore(ctx) + ctx.WritePlain(" ") + for i, index := range n.Indexes { + if i != 0 { + ctx.WritePlain(", ") + } + ctx.WriteName(index.String()) + } + case "qb_name": + if len(n.Tables) > 0 { + ctx.WritePlain(", ") + for i, table := range n.Tables { + if i != 0 { + ctx.WritePlain(". 
") + } + table.Restore(ctx) + } + } + case "use_toja", "use_cascades": + if n.HintData.(bool) { + ctx.WritePlain("TRUE") + } else { + ctx.WritePlain("FALSE") + } + case "query_type": + ctx.WriteKeyWord(n.HintData.(model.CIStr).String()) + case "memory_quota": + ctx.WritePlainf("%d MB", n.HintData.(int64)/1024/1024) + case "read_from_storage": + ctx.WriteKeyWord(n.HintData.(model.CIStr).String()) + for i, table := range n.Tables { + if i == 0 { + ctx.WritePlain("[") + } + table.Restore(ctx) + if i == len(n.Tables)-1 { + ctx.WritePlain("]") + } else { + ctx.WritePlain(", ") + } + } + case "time_range": + hintData := n.HintData.(HintTimeRange) + ctx.WriteString(hintData.From) + ctx.WritePlain(", ") + ctx.WriteString(hintData.To) + case "set_var": + hintData := n.HintData.(HintSetVar) + ctx.WritePlain(hintData.VarName) + ctx.WritePlain(" = ") + ctx.WriteString(hintData.Value) + } + ctx.WritePlain(")") + return nil +} + +// Accept implements Node Accept interface. +func (n *TableOptimizerHint) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*TableOptimizerHint) + return v.Leave(n) +} + +// TextString represent a string, it can be a binary literal. +type TextString struct { + Value string + IsBinaryLiteral bool +} + +type BinaryLiteral interface { + ToString() string +} + +// NewDecimal creates a types.Decimal value, it's provided by parser driver. +var NewDecimal func(string) (interface{}, error) + +// NewHexLiteral creates a types.HexLiteral value, it's provided by parser driver. +var NewHexLiteral func(string) (interface{}, error) + +// NewBitLiteral creates a types.BitLiteral value, it's provided by parser driver. +var NewBitLiteral func(string) (interface{}, error) + +// SetResourceGroupStmt is a statement to set the resource group name for current session. +type SetResourceGroupStmt struct { + stmtNode + Name model.CIStr +} + +func (n *SetResourceGroupStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("SET RESOURCE GROUP ") + ctx.WriteName(n.Name.O) + return nil +} + +// Accept implements Node Accept interface. +func (n *SetResourceGroupStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*SetResourceGroupStmt) + return v.Leave(n) +} + +// CalibrateResourceType is the type for CalibrateResource statement. +type CalibrateResourceType int + +// calibrate resource [ workload < TPCC | OLTP_READ_WRITE | OLTP_READ_ONLY | OLTP_WRITE_ONLY | TPCH_10> ] +const ( + WorkloadNone CalibrateResourceType = iota + TPCC + OLTPREADWRITE + OLTPREADONLY + OLTPWRITEONLY + TPCH10 +) + +func (n CalibrateResourceType) Restore(ctx *format.RestoreCtx) error { + switch n { + case TPCC: + ctx.WriteKeyWord(" WORKLOAD TPCC") + case OLTPREADWRITE: + ctx.WriteKeyWord(" WORKLOAD OLTP_READ_WRITE") + case OLTPREADONLY: + ctx.WriteKeyWord(" WORKLOAD OLTP_READ_ONLY") + case OLTPWRITEONLY: + ctx.WriteKeyWord(" WORKLOAD OLTP_WRITE_ONLY") + case TPCH10: + ctx.WriteKeyWord(" WORKLOAD TPCH_10") + } + return nil +} + +// CalibrateResourceStmt is a statement to fetch the cluster RU capacity +type CalibrateResourceStmt struct { + stmtNode + DynamicCalibrateResourceOptionList []*DynamicCalibrateResourceOption + Tp CalibrateResourceType +} + +// Restore implements Node interface. 
+func (n *CalibrateResourceStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("CALIBRATE RESOURCE") + if err := n.Tp.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore CalibrateResourceStmt.CalibrateResourceType") + } + for i, option := range n.DynamicCalibrateResourceOptionList { + ctx.WritePlain(" ") + if err := option.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while splicing DynamicCalibrateResourceOption: [%v]", i) + } + } + return nil +} + +// Accept implements Node Accept interface. +func (n *CalibrateResourceStmt) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*CalibrateResourceStmt) + for _, val := range n.DynamicCalibrateResourceOptionList { + _, ok := val.Accept(v) + if !ok { + return n, false + } + } + return v.Leave(n) +} + +type DynamicCalibrateType int + +const ( + // specific time + CalibrateStartTime = iota + CalibrateEndTime + CalibrateDuration +) + +type DynamicCalibrateResourceOption struct { + stmtNode + Tp DynamicCalibrateType + StrValue string + Ts ExprNode + Unit TimeUnitType +} + +func (n *DynamicCalibrateResourceOption) Restore(ctx *format.RestoreCtx) error { + switch n.Tp { + case CalibrateStartTime: + ctx.WriteKeyWord("START_TIME ") + if err := n.Ts.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while splicing DynamicCalibrateResourceOption StartTime") + } + case CalibrateEndTime: + ctx.WriteKeyWord("END_TIME ") + if err := n.Ts.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while splicing DynamicCalibrateResourceOption EndTime") + } + case CalibrateDuration: + ctx.WriteKeyWord("DURATION ") + if len(n.StrValue) > 0 { + ctx.WriteString(n.StrValue) + } else { + ctx.WriteKeyWord("INTERVAL ") + if err := n.Ts.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occurred while restore DynamicCalibrateResourceOption DURATION TS") + } + ctx.WritePlain(" ") + ctx.WriteKeyWord(n.Unit.String()) + } + default: + return errors.Errorf("invalid DynamicCalibrateResourceOption: %d", n.Tp) + } + return nil +} + +// Accept implements Node Accept interface. +func (n *DynamicCalibrateResourceOption) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*DynamicCalibrateResourceOption) + if n.Ts != nil { + node, ok := n.Ts.Accept(v) + if !ok { + return n, false + } + n.Ts = node.(ExprNode) + } + return v.Leave(n) +} + +// DropQueryWatchStmt is a statement to drop a runaway watch item. +type DropQueryWatchStmt struct { + stmtNode + IntValue int64 +} + +func (n *DropQueryWatchStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("QUERY WATCH REMOVE ") + ctx.WritePlainf("%d", n.IntValue) + return nil +} + +// Accept implements Node Accept interface. +func (n *DropQueryWatchStmt) Accept(v Visitor) (Node, bool) { + newNode, _ := v.Enter(n) + n = newNode.(*DropQueryWatchStmt) + return v.Leave(n) +} + +// AddQueryWatchStmt is a statement to add a runaway watch item. 
+type AddQueryWatchStmt struct { + stmtNode + QueryWatchOptionList []*QueryWatchOption +} + +func (n *AddQueryWatchStmt) Restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("QUERY WATCH ADD") + for i, option := range n.QueryWatchOptionList { + ctx.WritePlain(" ") + if err := option.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while splicing QueryWatchOptionList: [%v]", i) + } + } + return nil +} + +// Accept implements Node Accept interface. +func (n *AddQueryWatchStmt) Accept(v Visitor) (Node, bool) { + newNode, _ := v.Enter(n) + n = newNode.(*AddQueryWatchStmt) + for _, val := range n.QueryWatchOptionList { + _, ok := val.Accept(v) + if !ok { + return n, false + } + } + return v.Leave(n) +} + +type QueryWatchOptionType int + +const ( + QueryWatchResourceGroup QueryWatchOptionType = iota + QueryWatchAction + QueryWatchType +) + +// QueryWatchOption is used for parsing manual management of watching runaway queries option. +type QueryWatchOption struct { + stmtNode + Tp QueryWatchOptionType + ResourceGroupOption *QueryWatchResourceGroupOption + ActionOption *ResourceGroupRunawayActionOption + TextOption *QueryWatchTextOption +} + +// Restore implements Node interface. +func (n *QueryWatchOption) Restore(ctx *format.RestoreCtx) error { + switch n.Tp { + case QueryWatchResourceGroup: + return n.ResourceGroupOption.restore(ctx) + case QueryWatchAction: + return n.ActionOption.Restore(ctx) + case QueryWatchType: + return n.TextOption.Restore(ctx) + } + return nil +} + +// Accept implements Node Accept interface. +func (n *QueryWatchOption) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*QueryWatchOption) + if n.ResourceGroupOption != nil && n.ResourceGroupOption.GroupNameExpr != nil { + node, ok := n.ResourceGroupOption.GroupNameExpr.Accept(v) + if !ok { + return n, false + } + n.ResourceGroupOption.GroupNameExpr = node.(ExprNode) + } + if n.ActionOption != nil { + node, ok := n.ActionOption.Accept(v) + if !ok { + return n, false + } + n.ActionOption = node.(*ResourceGroupRunawayActionOption) + } + if n.TextOption != nil { + node, ok := n.TextOption.Accept(v) + if !ok { + return n, false + } + n.TextOption = node.(*QueryWatchTextOption) + } + return v.Leave(n) +} + +func CheckQueryWatchAppend(ops []*QueryWatchOption, newOp *QueryWatchOption) bool { + for _, op := range ops { + if op.Tp == newOp.Tp { + return false + } + } + return true +} + +// QueryWatchResourceGroupOption is used for parsing the query watch resource group name. +type QueryWatchResourceGroupOption struct { + GroupNameStr model.CIStr + GroupNameExpr ExprNode +} + +func (n *QueryWatchResourceGroupOption) restore(ctx *format.RestoreCtx) error { + ctx.WriteKeyWord("RESOURCE GROUP ") + if n.GroupNameExpr != nil { + if err := n.GroupNameExpr.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while splicing ExprValue: [%v]", n.GroupNameExpr) + } + } else { + ctx.WriteName(n.GroupNameStr.String()) + } + return nil +} + +// QueryWatchTextOption is used for parsing the query watch text option. +type QueryWatchTextOption struct { + node + Type model.RunawayWatchType + PatternExpr ExprNode + TypeSpecified bool +} + +// Restore implements Node interface. 
+func (n *QueryWatchTextOption) Restore(ctx *format.RestoreCtx) error { + if n.TypeSpecified { + ctx.WriteKeyWord("SQL TEXT ") + ctx.WriteKeyWord(n.Type.String()) + ctx.WriteKeyWord(" TO ") + } else { + switch n.Type { + case model.WatchSimilar: + ctx.WriteKeyWord("SQL DIGEST ") + case model.WatchPlan: + ctx.WriteKeyWord("PLAN DIGEST ") + } + } + if err := n.PatternExpr.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while splicing ExprValue: [%v]", n.PatternExpr) + } + return nil +} + +// Accept implements Node Accept interface. +func (n *QueryWatchTextOption) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*QueryWatchTextOption) + if n.PatternExpr != nil { + node, ok := n.PatternExpr.Accept(v) + if !ok { + return n, false + } + n.PatternExpr = node.(ExprNode) + } + return v.Leave(n) +} diff --git a/pkg/planner/binding__failpoint_binding__.go b/pkg/planner/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..fc6da4ff0bb13 --- /dev/null +++ b/pkg/planner/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package planner + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/planner/cardinality/binding__failpoint_binding__.go b/pkg/planner/cardinality/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..f8cd42267a934 --- /dev/null +++ b/pkg/planner/cardinality/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package cardinality + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/planner/cardinality/row_count_index.go b/pkg/planner/cardinality/row_count_index.go index bc0f3c3226090..72dd3014fc89a 100644 --- a/pkg/planner/cardinality/row_count_index.go +++ b/pkg/planner/cardinality/row_count_index.go @@ -486,10 +486,10 @@ func expBackoffEstimation(sctx context.PlanContext, idx *statistics.Index, coll // Sort them. slices.Sort(singleColumnEstResults) l := len(singleColumnEstResults) - failpoint.Inject("cleanEstResults", func() { + if _, _err_ := failpoint.Eval(_curpkg_("cleanEstResults")); _err_ == nil { singleColumnEstResults = singleColumnEstResults[:0] l = 0 - }) + } if l == 1 { return singleColumnEstResults[0], true, nil } else if l == 0 { diff --git a/pkg/planner/cardinality/row_count_index.go__failpoint_stash__ b/pkg/planner/cardinality/row_count_index.go__failpoint_stash__ new file mode 100644 index 0000000000000..bc0f3c3226090 --- /dev/null +++ b/pkg/planner/cardinality/row_count_index.go__failpoint_stash__ @@ -0,0 +1,568 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
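The generated __failpoint_binding__ files above exist so that the rewritten failpoint.Eval calls can refer to a failpoint by its fully qualified name, that is, the package path plus the failpoint name. Below is a hedged, self-contained sketch of that pattern; the printed path is only illustrative, since the real value is resolved by reflection inside the planner packages.

package main

import (
	"fmt"
	"reflect"
)

// Mirrors the generated binding files: the package path is discovered once
// through reflection and cached, then prefixed onto every failpoint name
// that gets passed to failpoint.Eval.
type failpointBinding struct{ pkgpath string }

var bindingCache = &failpointBinding{}

func init() {
	bindingCache.pkgpath = reflect.TypeOf(failpointBinding{}).PkgPath()
}

// curpkg plays the role of _curpkg_ in the generated files.
func curpkg(name string) string {
	return bindingCache.pkgpath + "/" + name
}

func main() {
	// Inside the patched cardinality package this would resolve to something
	// like "<module path>/pkg/planner/cardinality/cleanEstResults" (assumed),
	// the key that failpoint.Eval is invoked with in row_count_index.go.
	fmt.Println(curpkg("cleanEstResults"))
}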
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cardinality + +import ( + "bytes" + "math" + "slices" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/planner/context" + "github.com/pingcap/tidb/pkg/planner/util/debugtrace" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/statistics" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/collate" + "github.com/pingcap/tidb/pkg/util/mathutil" + "github.com/pingcap/tidb/pkg/util/ranger" +) + +// GetRowCountByIndexRanges estimates the row count by a slice of Range. +func GetRowCountByIndexRanges(sctx context.PlanContext, coll *statistics.HistColl, idxID int64, indexRanges []*ranger.Range) (result float64, err error) { + var name string + if sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { + debugtrace.EnterContextCommon(sctx) + debugTraceGetRowCountInput(sctx, idxID, indexRanges) + defer func() { + debugtrace.RecordAnyValuesWithNames(sctx, "Name", name, "Result", result) + debugtrace.LeaveContextCommon(sctx) + }() + } + sc := sctx.GetSessionVars().StmtCtx + idx := coll.GetIdx(idxID) + colNames := make([]string, 0, 8) + if idx != nil { + if idx.Info != nil { + name = idx.Info.Name.O + for _, col := range idx.Info.Columns { + colNames = append(colNames, col.Name.O) + } + } + } + recordUsedItemStatsStatus(sctx, idx, coll.PhysicalID, idxID) + if statistics.IndexStatsIsInvalid(sctx, idx, coll, idxID) { + colsLen := -1 + if idx != nil && idx.Info.Unique { + colsLen = len(idx.Info.Columns) + } + result, err = getPseudoRowCountByIndexRanges(sc.TypeCtx(), indexRanges, float64(coll.RealtimeCount), colsLen) + if err == nil && sc.EnableOptimizerCETrace && idx != nil { + ceTraceRange(sctx, coll.PhysicalID, colNames, indexRanges, "Index Stats-Pseudo", uint64(result)) + } + return result, err + } + realtimeCnt, modifyCount := coll.GetScaledRealtimeAndModifyCnt(idx) + if sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { + debugtrace.RecordAnyValuesWithNames(sctx, + "Histogram NotNull Count", idx.Histogram.NotNullCount(), + "TopN total count", idx.TopN.TotalCount(), + "Increase Factor", idx.GetIncreaseFactor(realtimeCnt), + ) + } + if idx.CMSketch != nil && idx.StatsVer == statistics.Version1 { + result, err = getIndexRowCountForStatsV1(sctx, coll, idxID, indexRanges) + } else { + result, err = getIndexRowCountForStatsV2(sctx, idx, coll, indexRanges, realtimeCnt, modifyCount) + } + if sc.EnableOptimizerCETrace { + ceTraceRange(sctx, coll.PhysicalID, colNames, indexRanges, "Index Stats", uint64(result)) + } + return result, errors.Trace(err) +} + +func getIndexRowCountForStatsV1(sctx context.PlanContext, coll *statistics.HistColl, idxID int64, indexRanges []*ranger.Range) (float64, error) { + sc := sctx.GetSessionVars().StmtCtx + debugTrace := sc.EnableOptimizerDebugTrace + if debugTrace { + debugtrace.EnterContextCommon(sctx) + defer debugtrace.LeaveContextCommon(sctx) + } + idx := coll.GetIdx(idxID) + totalCount := 
float64(0) + for _, ran := range indexRanges { + if debugTrace { + debugTraceStartEstimateRange(sctx, ran, nil, nil, totalCount) + } + rangePosition := getOrdinalOfRangeCond(sc, ran) + var rangeVals []types.Datum + // Try to enum the last range values. + if rangePosition != len(ran.LowVal) { + rangeVals = statistics.EnumRangeValues(ran.LowVal[rangePosition], ran.HighVal[rangePosition], ran.LowExclude, ran.HighExclude) + if rangeVals != nil { + rangePosition++ + } + } + // If first one is range, just use the previous way to estimate; if it is [NULL, NULL] range + // on single-column index, use previous way as well, because CMSketch does not contain null + // values in this case. + if rangePosition == 0 || isSingleColIdxNullRange(idx, ran) { + realtimeCnt, modifyCount := coll.GetScaledRealtimeAndModifyCnt(idx) + count, err := getIndexRowCountForStatsV2(sctx, idx, nil, []*ranger.Range{ran}, realtimeCnt, modifyCount) + if err != nil { + return 0, errors.Trace(err) + } + if debugTrace { + debugTraceEndEstimateRange(sctx, count, debugTraceRange) + } + totalCount += count + continue + } + var selectivity float64 + // use CM Sketch to estimate the equal conditions + if rangeVals == nil { + bytes, err := codec.EncodeKey(sc.TimeZone(), nil, ran.LowVal[:rangePosition]...) + err = sc.HandleError(err) + if err != nil { + return 0, errors.Trace(err) + } + selectivity, err = getEqualCondSelectivity(sctx, coll, idx, bytes, rangePosition, ran) + if err != nil { + return 0, errors.Trace(err) + } + } else { + bytes, err := codec.EncodeKey(sc.TimeZone(), nil, ran.LowVal[:rangePosition-1]...) + err = sc.HandleError(err) + if err != nil { + return 0, errors.Trace(err) + } + prefixLen := len(bytes) + for _, val := range rangeVals { + bytes = bytes[:prefixLen] + bytes, err = codec.EncodeKey(sc.TimeZone(), bytes, val) + err = sc.HandleError(err) + if err != nil { + return 0, err + } + res, err := getEqualCondSelectivity(sctx, coll, idx, bytes, rangePosition, ran) + if err != nil { + return 0, errors.Trace(err) + } + selectivity += res + } + } + // use histogram to estimate the range condition + if rangePosition != len(ran.LowVal) { + rang := ranger.Range{ + LowVal: []types.Datum{ran.LowVal[rangePosition]}, + LowExclude: ran.LowExclude, + HighVal: []types.Datum{ran.HighVal[rangePosition]}, + HighExclude: ran.HighExclude, + Collators: []collate.Collator{ran.Collators[rangePosition]}, + } + var count float64 + var err error + colUniqueIDs := coll.Idx2ColUniqueIDs[idxID] + var colUniqueID int64 + if rangePosition >= len(colUniqueIDs) { + colUniqueID = -1 + } else { + colUniqueID = colUniqueIDs[rangePosition] + } + // prefer index stats over column stats + if idxIDs, ok := coll.ColUniqueID2IdxIDs[colUniqueID]; ok && len(idxIDs) > 0 { + idxID := idxIDs[0] + count, err = GetRowCountByIndexRanges(sctx, coll, idxID, []*ranger.Range{&rang}) + } else { + count, err = GetRowCountByColumnRanges(sctx, coll, colUniqueID, []*ranger.Range{&rang}) + } + if err != nil { + return 0, errors.Trace(err) + } + selectivity = selectivity * count / idx.TotalRowCount() + } + count := selectivity * idx.TotalRowCount() + if debugTrace { + debugTraceEndEstimateRange(sctx, count, debugTraceRange) + } + totalCount += count + } + if totalCount > idx.TotalRowCount() { + totalCount = idx.TotalRowCount() + } + return totalCount, nil +} + +// isSingleColIdxNullRange checks if a range is [NULL, NULL] on a single-column index. 
+func isSingleColIdxNullRange(idx *statistics.Index, ran *ranger.Range) bool { + if len(idx.Info.Columns) > 1 { + return false + } + l, h := ran.LowVal[0], ran.HighVal[0] + if l.IsNull() && h.IsNull() { + return true + } + return false +} + +// It uses the modifyCount to adjust the influence of modifications on the table. +func getIndexRowCountForStatsV2(sctx context.PlanContext, idx *statistics.Index, coll *statistics.HistColl, indexRanges []*ranger.Range, realtimeRowCount, modifyCount int64) (float64, error) { + sc := sctx.GetSessionVars().StmtCtx + debugTrace := sc.EnableOptimizerDebugTrace + if debugTrace { + debugtrace.EnterContextCommon(sctx) + defer debugtrace.LeaveContextCommon(sctx) + } + totalCount := float64(0) + isSingleColIdx := len(idx.Info.Columns) == 1 + for _, indexRange := range indexRanges { + var count float64 + lb, err := codec.EncodeKey(sc.TimeZone(), nil, indexRange.LowVal...) + err = sc.HandleError(err) + if err != nil { + return 0, err + } + rb, err := codec.EncodeKey(sc.TimeZone(), nil, indexRange.HighVal...) + err = sc.HandleError(err) + if err != nil { + return 0, err + } + if debugTrace { + debugTraceStartEstimateRange(sctx, indexRange, lb, rb, totalCount) + } + fullLen := len(indexRange.LowVal) == len(indexRange.HighVal) && len(indexRange.LowVal) == len(idx.Info.Columns) + if bytes.Equal(lb, rb) { + // case 1: it's a point + if indexRange.LowExclude || indexRange.HighExclude { + if debugTrace { + debugTraceEndEstimateRange(sctx, 0, debugTraceImpossible) + } + continue + } + if fullLen { + // At most 1 in this case. + if idx.Info.Unique { + totalCount++ + if debugTrace { + debugTraceEndEstimateRange(sctx, 1, debugTraceUniquePoint) + } + continue + } + count = equalRowCountOnIndex(sctx, idx, lb, realtimeRowCount, modifyCount) + // If the current table row count has changed, we should scale the row count accordingly. + count *= idx.GetIncreaseFactor(realtimeRowCount) + if debugTrace { + debugTraceEndEstimateRange(sctx, count, debugTracePoint) + } + totalCount += count + continue + } + } + + // case 2: it's an interval + // The final interval is [low, high) + if indexRange.LowExclude { + lb = kv.Key(lb).PrefixNext() + } + if !indexRange.HighExclude { + rb = kv.Key(rb).PrefixNext() + } + l := types.NewBytesDatum(lb) + r := types.NewBytesDatum(rb) + lowIsNull := bytes.Equal(lb, nullKeyBytes) + if isSingleColIdx && lowIsNull { + count += float64(idx.Histogram.NullCount) + } + expBackoffSuccess := false + // Due to the limitation of calcFraction and convertDatumToScalar, the histogram actually won't estimate anything. + // If the first column's range is point. + if rangePosition := getOrdinalOfRangeCond(sc, indexRange); rangePosition > 0 && idx.StatsVer >= statistics.Version2 && coll != nil { + var expBackoffSel float64 + expBackoffSel, expBackoffSuccess, err = expBackoffEstimation(sctx, idx, coll, indexRange) + if err != nil { + return 0, err + } + if expBackoffSuccess { + expBackoffCnt := expBackoffSel * idx.TotalRowCount() + + upperLimit := expBackoffCnt + // Use the multi-column stats to calculate the max possible row count of [l, r) + if idx.Histogram.Len() > 0 { + _, lowerBkt, _, _ := idx.Histogram.LocateBucket(sctx, l) + _, upperBkt, _, _ := idx.Histogram.LocateBucket(sctx, r) + if debugTrace { + statistics.DebugTraceBuckets(sctx, &idx.Histogram, []int{lowerBkt - 1, upperBkt}) + } + // Use Count of the Bucket before l as the lower bound. 
+ preCount := float64(0) + if lowerBkt > 0 { + preCount = float64(idx.Histogram.Buckets[lowerBkt-1].Count) + } + // Use Count of the Bucket where r exists as the upper bound. + upperCnt := float64(idx.Histogram.Buckets[upperBkt].Count) + upperLimit = upperCnt - preCount + upperLimit += float64(idx.TopN.BetweenCount(sctx, lb, rb)) + } + + // If the result of exponential backoff strategy is larger than the result from multi-column stats, + // use the upper limit from multi-column histogram instead. + if expBackoffCnt > upperLimit { + expBackoffCnt = upperLimit + } + count += expBackoffCnt + } + } + if !expBackoffSuccess { + count += betweenRowCountOnIndex(sctx, idx, l, r) + } + + // If the current table row count has changed, we should scale the row count accordingly. + increaseFactor := idx.GetIncreaseFactor(realtimeRowCount) + count *= increaseFactor + + // handling the out-of-range part + if (outOfRangeOnIndex(idx, l) && !(isSingleColIdx && lowIsNull)) || outOfRangeOnIndex(idx, r) { + histNDV := idx.NDV + // Exclude the TopN in Stats Version 2 + if idx.StatsVer == statistics.Version2 { + c := coll.GetCol(idx.Histogram.ID) + // If this is single column of a multi-column index - use the column's NDV rather than index NDV + isSingleColRange := len(indexRange.LowVal) == len(indexRange.HighVal) && len(indexRange.LowVal) == 1 + if isSingleColRange && !isSingleColIdx && c != nil && c.Histogram.NDV > 0 { + histNDV = c.Histogram.NDV - int64(c.TopN.Num()) + } else { + histNDV -= int64(idx.TopN.Num()) + } + } + count += idx.Histogram.OutOfRangeRowCount(sctx, &l, &r, modifyCount, histNDV, increaseFactor) + } + + if debugTrace { + debugTraceEndEstimateRange(sctx, count, debugTraceRange) + } + totalCount += count + } + // Don't allow the final result to go below 1 row + totalCount = mathutil.Clamp(totalCount, 1, float64(realtimeRowCount)) + return totalCount, nil +} + +var nullKeyBytes, _ = codec.EncodeKey(time.UTC, nil, types.NewDatum(nil)) + +func equalRowCountOnIndex(sctx context.PlanContext, idx *statistics.Index, b []byte, realtimeRowCount, modifyCount int64) (result float64) { + if sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { + debugtrace.EnterContextCommon(sctx) + debugtrace.RecordAnyValuesWithNames(sctx, "Encoded Value", b) + defer func() { + debugtrace.RecordAnyValuesWithNames(sctx, "Result", result) + debugtrace.LeaveContextCommon(sctx) + }() + } + if len(idx.Info.Columns) == 1 { + if bytes.Equal(b, nullKeyBytes) { + return float64(idx.Histogram.NullCount) + } + } + val := types.NewBytesDatum(b) + if idx.StatsVer < statistics.Version2 { + if idx.Histogram.NDV > 0 && outOfRangeOnIndex(idx, val) { + return outOfRangeEQSelectivity(sctx, idx.Histogram.NDV, realtimeRowCount, int64(idx.TotalRowCount())) * idx.TotalRowCount() + } + if idx.CMSketch != nil { + return float64(idx.QueryBytes(sctx, b)) + } + histRowCount, _ := idx.Histogram.EqualRowCount(sctx, val, false) + return histRowCount + } + // stats version == 2 + // 1. try to find this value in TopN + if idx.TopN != nil { + count, found := idx.TopN.QueryTopN(sctx, b) + if found { + return float64(count) + } + } + // 2. try to find this value in bucket.Repeat(the last value in every bucket) + histCnt, matched := idx.Histogram.EqualRowCount(sctx, val, true) + if matched { + return histCnt + } + // 3. 
use uniform distribution assumption for the rest (even when this value is not covered by the range of stats) + histNDV := float64(idx.Histogram.NDV - int64(idx.TopN.Num())) + if histNDV <= 0 { + // If the table hasn't been modified, it's safe to return 0. Otherwise, the TopN could be stale - return 1. + if modifyCount == 0 { + return 0 + } + return 1 + } + return idx.Histogram.NotNullCount() / histNDV +} + +// expBackoffEstimation estimate the multi-col cases following the Exponential Backoff. See comment below for details. +func expBackoffEstimation(sctx context.PlanContext, idx *statistics.Index, coll *statistics.HistColl, indexRange *ranger.Range) (sel float64, success bool, err error) { + if sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { + debugtrace.EnterContextCommon(sctx) + defer func() { + debugtrace.RecordAnyValuesWithNames(sctx, + "Result", sel, + "Success", success, + "error", err, + ) + debugtrace.LeaveContextCommon(sctx) + }() + } + tmpRan := []*ranger.Range{ + { + LowVal: make([]types.Datum, 1), + HighVal: make([]types.Datum, 1), + Collators: make([]collate.Collator, 1), + }, + } + colsIDs := coll.Idx2ColUniqueIDs[idx.Histogram.ID] + singleColumnEstResults := make([]float64, 0, len(indexRange.LowVal)) + // The following codes uses Exponential Backoff to reduce the impact of independent assumption. It works like: + // 1. Calc the selectivity of each column. + // 2. Sort them and choose the first 4 most selective filter and the corresponding selectivity is sel_1, sel_2, sel_3, sel_4 where i < j => sel_i < sel_j. + // 3. The final selectivity would be sel_1 * sel_2^{1/2} * sel_3^{1/4} * sel_4^{1/8}. + // This calculation reduced the independence assumption and can work well better than it. + for i := 0; i < len(indexRange.LowVal); i++ { + tmpRan[0].LowVal[0] = indexRange.LowVal[i] + tmpRan[0].HighVal[0] = indexRange.HighVal[i] + tmpRan[0].Collators[0] = indexRange.Collators[0] + if i == len(indexRange.LowVal)-1 { + tmpRan[0].LowExclude = indexRange.LowExclude + tmpRan[0].HighExclude = indexRange.HighExclude + } + colID := colsIDs[i] + var ( + count float64 + selectivity float64 + err error + foundStats bool + ) + if !statistics.ColumnStatsIsInvalid(coll.GetCol(colID), sctx, coll, colID) { + foundStats = true + count, err = GetRowCountByColumnRanges(sctx, coll, colID, tmpRan) + selectivity = count / float64(coll.RealtimeCount) + } + if idxIDs, ok := coll.ColUniqueID2IdxIDs[colID]; ok && !foundStats && len(indexRange.LowVal) > 1 { + // Note the `len(indexRange.LowVal) > 1` condition here, it means we only recursively call + // `GetRowCountByIndexRanges()` when the input `indexRange` is a multi-column range. This + // check avoids infinite recursion. + for _, idxID := range idxIDs { + if idxID == idx.Histogram.ID { + continue + } + idxStats := coll.GetIdx(idxID) + if idxStats == nil || statistics.IndexStatsIsInvalid(sctx, idxStats, coll, idxID) { + continue + } + foundStats = true + count, err = GetRowCountByIndexRanges(sctx, coll, idxID, tmpRan) + if err == nil { + break + } + realtimeCnt, _ := coll.GetScaledRealtimeAndModifyCnt(idxStats) + selectivity = count / float64(realtimeCnt) + } + } + if !foundStats { + continue + } + if err != nil { + return 0, false, err + } + singleColumnEstResults = append(singleColumnEstResults, selectivity) + } + // Sort them. 
+ slices.Sort(singleColumnEstResults) + l := len(singleColumnEstResults) + failpoint.Inject("cleanEstResults", func() { + singleColumnEstResults = singleColumnEstResults[:0] + l = 0 + }) + if l == 1 { + return singleColumnEstResults[0], true, nil + } else if l == 0 { + return 0, false, nil + } + // Do not allow the exponential backoff to go below the available index bound. If the number of predicates + // is less than the number of index columns - use 90% of the bound to differentiate a subset from full index match. + // If there is an individual column selectivity that goes below this bound, use that selectivity only. + histNDV := coll.RealtimeCount + if idx.NDV > 0 { + histNDV = idx.NDV + } + idxLowBound := 1 / float64(min(histNDV, coll.RealtimeCount)) + if l < len(idx.Info.Columns) { + idxLowBound /= 0.9 + } + minTwoCol := min(singleColumnEstResults[0], singleColumnEstResults[1], idxLowBound) + multTwoCol := singleColumnEstResults[0] * math.Sqrt(singleColumnEstResults[1]) + if l == 2 { + return max(minTwoCol, multTwoCol), true, nil + } + minThreeCol := min(minTwoCol, singleColumnEstResults[2]) + multThreeCol := multTwoCol * math.Sqrt(math.Sqrt(singleColumnEstResults[2])) + if l == 3 { + return max(minThreeCol, multThreeCol), true, nil + } + minFourCol := min(minThreeCol, singleColumnEstResults[3]) + multFourCol := multThreeCol * math.Sqrt(math.Sqrt(math.Sqrt(singleColumnEstResults[3]))) + return max(minFourCol, multFourCol), true, nil +} + +// outOfRangeOnIndex checks if the datum is out of the range. +func outOfRangeOnIndex(idx *statistics.Index, val types.Datum) bool { + if !idx.Histogram.OutOfRange(val) { + return false + } + if idx.Histogram.Len() > 0 && matchPrefix(idx.Histogram.Bounds.GetRow(0), 0, &val) { + return false + } + return true +} + +// matchPrefix checks whether ad is the prefix of value +func matchPrefix(row chunk.Row, colIdx int, ad *types.Datum) bool { + switch ad.Kind() { + case types.KindString, types.KindBytes, types.KindBinaryLiteral, types.KindMysqlBit: + return strings.HasPrefix(row.GetString(colIdx), ad.GetString()) + } + return false +} + +// betweenRowCountOnIndex estimates the row count for interval [l, r). +// The input sctx is just for debug trace, you can pass nil safely if that's not needed. +func betweenRowCountOnIndex(sctx context.PlanContext, idx *statistics.Index, l, r types.Datum) float64 { + histBetweenCnt := idx.Histogram.BetweenRowCount(sctx, l, r) + if idx.StatsVer == statistics.Version1 { + return histBetweenCnt + } + return float64(idx.TopN.BetweenCount(sctx, l.GetBytes(), r.GetBytes())) + histBetweenCnt +} + +// getOrdinalOfRangeCond gets the ordinal of the position range condition, +// if not exist, it returns the end position. 
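To make the exponential backoff combination in expBackoffEstimation concrete, here is a small standalone sketch (illustration only, not part of the patch); the selectivities 0.01, 0.1 and 0.2 are assumed values, not taken from any test in this series.

package main

import (
	"fmt"
	"math"
)

func main() {
	// Assumed single-column selectivities, already sorted ascending.
	sels := []float64{0.01, 0.1, 0.2}
	// Full independence would simply multiply them.
	independent := sels[0] * sels[1] * sels[2] // 2e-4
	// Exponential backoff: sel_1 * sel_2^(1/2) * sel_3^(1/4).
	backoff := sels[0] * math.Sqrt(sels[1]) * math.Sqrt(math.Sqrt(sels[2])) // ~2.1e-3
	fmt.Printf("independent=%g backoff=%g\n", independent, backoff)
	// The backoff estimate is deliberately less aggressive than full independence.
	// The surrounding max(min(...), mult...) logic additionally keeps the final
	// selectivity from dropping below the smaller of the individual selectivities
	// and idxLowBound (roughly 1/NDV of the index).
}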
+func getOrdinalOfRangeCond(sc *stmtctx.StatementContext, ran *ranger.Range) int { + for i := range ran.LowVal { + a, b := ran.LowVal[i], ran.HighVal[i] + cmp, err := a.Compare(sc.TypeCtx(), &b, ran.Collators[0]) + if err != nil { + return 0 + } + if cmp != 0 { + return i + } + } + return len(ran.LowVal) +} diff --git a/pkg/planner/cardinality/selectivity_test.go b/pkg/planner/cardinality/selectivity_test.go index 591b9da71813d..20d8abf39e825 100644 --- a/pkg/planner/cardinality/selectivity_test.go +++ b/pkg/planner/cardinality/selectivity_test.go @@ -235,7 +235,7 @@ func TestEstimationForUnknownValues(t *testing.T) { colID := table.Meta().Columns[0].ID count, err := cardinality.GetRowCountByColumnRanges(sctx, &statsTbl.HistColl, colID, getRange(30, 30)) require.NoError(t, err) - require.Equal(t, 0.2, count) + require.Equal(t, 1.2, count) count, err = cardinality.GetRowCountByColumnRanges(sctx, &statsTbl.HistColl, colID, getRange(9, 30)) require.NoError(t, err) @@ -248,7 +248,7 @@ func TestEstimationForUnknownValues(t *testing.T) { idxID := table.Meta().Indices[0].ID count, err = cardinality.GetRowCountByIndexRanges(sctx, &statsTbl.HistColl, idxID, getRange(30, 30)) require.NoError(t, err) - require.Equal(t, 0.1, count) + require.Equal(t, 1.0, count) count, err = cardinality.GetRowCountByIndexRanges(sctx, &statsTbl.HistColl, idxID, getRange(9, 30)) require.NoError(t, err) @@ -264,7 +264,7 @@ func TestEstimationForUnknownValues(t *testing.T) { colID = table.Meta().Columns[0].ID count, err = cardinality.GetRowCountByColumnRanges(sctx, &statsTbl.HistColl, colID, getRange(1, 30)) require.NoError(t, err) - require.Equal(t, 0.0, count) + require.Equal(t, 1.0, count) testKit.MustExec("drop table t") testKit.MustExec("create table t(a int, b int, index idx(b))") @@ -277,12 +277,12 @@ func TestEstimationForUnknownValues(t *testing.T) { colID = table.Meta().Columns[0].ID count, err = cardinality.GetRowCountByColumnRanges(sctx, &statsTbl.HistColl, colID, getRange(2, 2)) require.NoError(t, err) - require.Equal(t, 0.0, count) + require.Equal(t, 1.0, count) idxID = table.Meta().Indices[0].ID count, err = cardinality.GetRowCountByIndexRanges(sctx, &statsTbl.HistColl, idxID, getRange(2, 2)) require.NoError(t, err) - require.Equal(t, 0.0, count) + require.Equal(t, 1.0, count) } func TestEstimationUniqueKeyEqualConds(t *testing.T) { @@ -1189,8 +1189,8 @@ func TestCrossValidationSelectivity(t *testing.T) { require.NoError(t, h.DumpStatsDeltaToKV(true)) tk.MustExec("analyze table t") tk.MustQuery("explain format = 'brief' select * from t where a = 1 and b > 0 and b < 1000 and c > 1000").Check(testkit.Rows( - "TableReader 0.00 root data:Selection", - "└─Selection 0.00 cop[tikv] gt(test.t.c, 1000)", + "TableReader 1.00 root data:Selection", + "└─Selection 1.00 cop[tikv] gt(test.t.c, 1000)", " └─TableRangeScan 2.00 cop[tikv] table:t range:(1 0,1 1000), keep order:false")) } @@ -1212,8 +1212,8 @@ func TestIgnoreRealtimeStats(t *testing.T) { // From the real-time stats, we are able to know the total count is 11. 
testKit.MustExec("set @@tidb_opt_objective = 'moderate'") testKit.MustQuery("explain select * from t where a = 1 and b > 2").Check(testkit.Rows( - "TableReader_7 0.00 root data:Selection_6", - "└─Selection_6 0.00 cop[tikv] eq(test.t.a, 1), gt(test.t.b, 2)", + "TableReader_7 1.00 root data:Selection_6", + "└─Selection_6 1.00 cop[tikv] eq(test.t.a, 1), gt(test.t.b, 2)", " └─TableFullScan_5 11.00 cop[tikv] table:t keep order:false, stats:pseudo", )) diff --git a/pkg/planner/core/binding__failpoint_binding__.go b/pkg/planner/core/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..fd84c40441f21 --- /dev/null +++ b/pkg/planner/core/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package core + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/planner/core/collect_column_stats_usage.go b/pkg/planner/core/collect_column_stats_usage.go index 54c427c82289f..060a764789d42 100644 --- a/pkg/planner/core/collect_column_stats_usage.go +++ b/pkg/planner/core/collect_column_stats_usage.go @@ -190,9 +190,9 @@ func (c *columnStatsUsageCollector) addHistNeededColumns(ds *DataSource) { stats := domain.GetDomain(ds.SCtx()).StatsHandle() tblStats := stats.GetPartitionStats(ds.TableInfo, ds.PhysicalTableID) skipPseudoCheckForTest := false - failpoint.Inject("disablePseudoCheck", func() { + if _, _err_ := failpoint.Eval(_curpkg_("disablePseudoCheck")); _err_ == nil { skipPseudoCheckForTest = true - }) + } // Since we can not get the stats tbl, this table is not analyzed. So we don't need to consider load stats. if tblStats.Pseudo && !skipPseudoCheckForTest { return diff --git a/pkg/planner/core/collect_column_stats_usage.go__failpoint_stash__ b/pkg/planner/core/collect_column_stats_usage.go__failpoint_stash__ new file mode 100644 index 0000000000000..54c427c82289f --- /dev/null +++ b/pkg/planner/core/collect_column_stats_usage.go__failpoint_stash__ @@ -0,0 +1,456 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package core + +import ( + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/statistics/asyncload" + "github.com/pingcap/tidb/pkg/util/filter" + "github.com/pingcap/tidb/pkg/util/intset" + "golang.org/x/exp/maps" +) + +const ( + collectPredicateColumns uint64 = 1 << iota + collectHistNeededColumns +) + +// columnStatsUsageCollector collects predicate columns and/or histogram-needed columns from logical plan. +// Predicate columns are the columns whose statistics are utilized when making query plans, which usually occur in where conditions, join conditions and so on. +// Histogram-needed columns are the columns whose histograms are utilized when making query plans, which usually occur in the conditions pushed down to DataSource. +// The set of histogram-needed columns is the subset of that of predicate columns. +type columnStatsUsageCollector struct { + // collectMode indicates whether to collect predicate columns and/or histogram-needed columns + collectMode uint64 + // predicateCols records predicate columns. + predicateCols map[model.TableItemID]struct{} + // colMap maps expression.Column.UniqueID to the table columns whose statistics may be utilized to calculate statistics of the column. + // It is used for collecting predicate columns. + // For example, in `select count(distinct a, b) as e from t`, the count of column `e` is calculated as `max(ndv(t.a), ndv(t.b))` if + // we don't know `ndv(t.a, t.b)`(see (*LogicalAggregation).DeriveStats and getColsNDV for details). So when calculating the statistics + // of column `e`, we may use the statistics of column `t.a` and `t.b`. + colMap map[int64]map[model.TableItemID]struct{} + // histNeededCols records histogram-needed columns. The value field of the map indicates that whether we need to load the full stats of the time or not. + histNeededCols map[model.TableItemID]bool + // cols is used to store columns collected from expressions and saves some allocation. + cols []*expression.Column + + // visitedPhysTblIDs all ds.PhysicalTableID that have been visited. + // It's always collected, even collectHistNeededColumns is not set. + visitedPhysTblIDs *intset.FastIntSet + + // collectVisitedTable indicates whether to collect visited table + collectVisitedTable bool + // visitedtbls indicates the visited table + visitedtbls map[int64]struct{} +} + +func newColumnStatsUsageCollector(collectMode uint64, enabledPlanCapture bool) *columnStatsUsageCollector { + set := intset.NewFastIntSet() + collector := &columnStatsUsageCollector{ + collectMode: collectMode, + // Pre-allocate a slice to reduce allocation, 8 doesn't have special meaning. 
+ cols: make([]*expression.Column, 0, 8), + visitedPhysTblIDs: &set, + } + if collectMode&collectPredicateColumns != 0 { + collector.predicateCols = make(map[model.TableItemID]struct{}) + collector.colMap = make(map[int64]map[model.TableItemID]struct{}) + } + if collectMode&collectHistNeededColumns != 0 { + collector.histNeededCols = make(map[model.TableItemID]bool) + } + if enabledPlanCapture { + collector.collectVisitedTable = true + collector.visitedtbls = map[int64]struct{}{} + } + return collector +} + +func (c *columnStatsUsageCollector) addPredicateColumn(col *expression.Column) { + tblColIDs, ok := c.colMap[col.UniqueID] + if !ok { + // It may happen if some leaf of logical plan is LogicalMemTable/LogicalShow/LogicalShowDDLJobs. + return + } + for tblColID := range tblColIDs { + c.predicateCols[tblColID] = struct{}{} + } +} + +func (c *columnStatsUsageCollector) addPredicateColumnsFromExpressions(list []expression.Expression) { + cols := expression.ExtractColumnsAndCorColumnsFromExpressions(c.cols[:0], list) + for _, col := range cols { + c.addPredicateColumn(col) + } +} + +func (c *columnStatsUsageCollector) updateColMap(col *expression.Column, relatedCols []*expression.Column) { + if _, ok := c.colMap[col.UniqueID]; !ok { + c.colMap[col.UniqueID] = map[model.TableItemID]struct{}{} + } + for _, relatedCol := range relatedCols { + tblColIDs, ok := c.colMap[relatedCol.UniqueID] + if !ok { + // It may happen if some leaf of logical plan is LogicalMemTable/LogicalShow/LogicalShowDDLJobs. + continue + } + for tblColID := range tblColIDs { + c.colMap[col.UniqueID][tblColID] = struct{}{} + } + } +} + +func (c *columnStatsUsageCollector) updateColMapFromExpressions(col *expression.Column, list []expression.Expression) { + c.updateColMap(col, expression.ExtractColumnsAndCorColumnsFromExpressions(c.cols[:0], list)) +} + +func (c *columnStatsUsageCollector) collectPredicateColumnsForDataSource(ds *DataSource) { + // Skip all system tables. + if filter.IsSystemSchema(ds.DBName.L) { + return + } + // For partition tables, no matter whether it is static or dynamic pruning mode, we use table ID rather than partition ID to + // set TableColumnID.TableID. In this way, we keep the set of predicate columns consistent between different partitions and global table. + tblID := ds.TableInfo.ID + if c.collectVisitedTable { + c.visitedtbls[tblID] = struct{}{} + } + for _, col := range ds.Schema().Columns { + tblColID := model.TableItemID{TableID: tblID, ID: col.ID, IsIndex: false} + c.colMap[col.UniqueID] = map[model.TableItemID]struct{}{tblColID: {}} + } + // We should use `PushedDownConds` here. `AllConds` is used for partition pruning, which doesn't need stats. + c.addPredicateColumnsFromExpressions(ds.PushedDownConds) +} + +func (c *columnStatsUsageCollector) collectPredicateColumnsForJoin(p *LogicalJoin) { + // The only schema change is merging two schemas so there is no new column. + // Assume statistics of all the columns in EqualConditions/LeftConditions/RightConditions/OtherConditions are needed. 
+ exprs := make([]expression.Expression, 0, len(p.EqualConditions)+len(p.LeftConditions)+len(p.RightConditions)+len(p.OtherConditions)) + for _, cond := range p.EqualConditions { + exprs = append(exprs, cond) + } + for _, cond := range p.LeftConditions { + exprs = append(exprs, cond) + } + for _, cond := range p.RightConditions { + exprs = append(exprs, cond) + } + for _, cond := range p.OtherConditions { + exprs = append(exprs, cond) + } + c.addPredicateColumnsFromExpressions(exprs) +} + +func (c *columnStatsUsageCollector) collectPredicateColumnsForUnionAll(p *LogicalUnionAll) { + // statistics of the ith column of UnionAll come from statistics of the ith column of each child. + schemas := make([]*expression.Schema, 0, len(p.Children())) + relatedCols := make([]*expression.Column, 0, len(p.Children())) + for _, child := range p.Children() { + schemas = append(schemas, child.Schema()) + } + for i, col := range p.Schema().Columns { + relatedCols = relatedCols[:0] + for j := range p.Children() { + relatedCols = append(relatedCols, schemas[j].Columns[i]) + } + c.updateColMap(col, relatedCols) + } +} + +func (c *columnStatsUsageCollector) addHistNeededColumns(ds *DataSource) { + c.visitedPhysTblIDs.Insert(int(ds.PhysicalTableID)) + if c.collectMode&collectHistNeededColumns == 0 { + return + } + if c.collectVisitedTable { + tblID := ds.TableInfo.ID + c.visitedtbls[tblID] = struct{}{} + } + stats := domain.GetDomain(ds.SCtx()).StatsHandle() + tblStats := stats.GetPartitionStats(ds.TableInfo, ds.PhysicalTableID) + skipPseudoCheckForTest := false + failpoint.Inject("disablePseudoCheck", func() { + skipPseudoCheckForTest = true + }) + // Since we can not get the stats tbl, this table is not analyzed. So we don't need to consider load stats. + if tblStats.Pseudo && !skipPseudoCheckForTest { + return + } + columns := expression.ExtractColumnsFromExpressions(c.cols[:0], ds.PushedDownConds, nil) + + colIDSet := intset.NewFastIntSet() + + for _, col := range columns { + // If the column is plan-generated one, Skip it. + // TODO: we may need to consider the ExtraHandle. + if col.ID < 0 { + continue + } + tblColID := model.TableItemID{TableID: ds.PhysicalTableID, ID: col.ID, IsIndex: false} + colIDSet.Insert(int(col.ID)) + c.histNeededCols[tblColID] = true + } + for _, column := range ds.TableInfo.Columns { + // If the column is plan-generated one, Skip it. + // TODO: we may need to consider the ExtraHandle. + if column.ID < 0 { + continue + } + if !column.Hidden { + tblColID := model.TableItemID{TableID: ds.PhysicalTableID, ID: column.ID, IsIndex: false} + if _, ok := c.histNeededCols[tblColID]; !ok { + c.histNeededCols[tblColID] = false + } + } + } +} + +func (c *columnStatsUsageCollector) collectFromPlan(lp base.LogicalPlan) { + for _, child := range lp.Children() { + c.collectFromPlan(child) + } + if c.collectMode&collectPredicateColumns != 0 { + switch x := lp.(type) { + case *DataSource: + c.collectPredicateColumnsForDataSource(x) + case *LogicalIndexScan: + c.collectPredicateColumnsForDataSource(x.Source) + c.addPredicateColumnsFromExpressions(x.AccessConds) + case *LogicalTableScan: + c.collectPredicateColumnsForDataSource(x.Source) + c.addPredicateColumnsFromExpressions(x.AccessConds) + case *logicalop.LogicalProjection: + // Schema change from children to self. 
+ schema := x.Schema() + for i, expr := range x.Exprs { + c.updateColMapFromExpressions(schema.Columns[i], []expression.Expression{expr}) + } + case *LogicalSelection: + // Though the conditions in LogicalSelection are complex conditions which cannot be pushed down to DataSource, we still + // regard statistics of the columns in the conditions as needed. + c.addPredicateColumnsFromExpressions(x.Conditions) + case *LogicalAggregation: + // Just assume statistics of all the columns in GroupByItems are needed. + c.addPredicateColumnsFromExpressions(x.GroupByItems) + // Schema change from children to self. + schema := x.Schema() + for i, aggFunc := range x.AggFuncs { + c.updateColMapFromExpressions(schema.Columns[i], aggFunc.Args) + } + case *logicalop.LogicalWindow: + // Statistics of the columns in LogicalWindow.PartitionBy are used in optimizeByShuffle4Window. + // We don't use statistics of the columns in LogicalWindow.OrderBy currently. + for _, item := range x.PartitionBy { + c.addPredicateColumn(item.Col) + } + // Schema change from children to self. + windowColumns := x.GetWindowResultColumns() + for i, col := range windowColumns { + c.updateColMapFromExpressions(col, x.WindowFuncDescs[i].Args) + } + case *LogicalJoin: + c.collectPredicateColumnsForJoin(x) + case *LogicalApply: + c.collectPredicateColumnsForJoin(&x.LogicalJoin) + // Assume statistics of correlated columns are needed. + // Correlated columns can be found in LogicalApply.Children()[0].Schema(). Since we already visit LogicalApply.Children()[0], + // correlated columns must have existed in columnStatsUsageCollector.colMap. + for _, corCols := range x.CorCols { + c.addPredicateColumn(&corCols.Column) + } + case *logicalop.LogicalSort: + // Assume statistics of all the columns in ByItems are needed. + for _, item := range x.ByItems { + c.addPredicateColumnsFromExpressions([]expression.Expression{item.Expr}) + } + case *logicalop.LogicalTopN: + // Assume statistics of all the columns in ByItems are needed. + for _, item := range x.ByItems { + c.addPredicateColumnsFromExpressions([]expression.Expression{item.Expr}) + } + case *LogicalUnionAll: + c.collectPredicateColumnsForUnionAll(x) + case *LogicalPartitionUnionAll: + c.collectPredicateColumnsForUnionAll(&x.LogicalUnionAll) + case *LogicalCTE: + // Visit seedPartLogicalPlan and recursivePartLogicalPlan first. + c.collectFromPlan(x.Cte.seedPartLogicalPlan) + if x.Cte.recursivePartLogicalPlan != nil { + c.collectFromPlan(x.Cte.recursivePartLogicalPlan) + } + // Schema change from seedPlan/recursivePlan to self. + columns := x.Schema().Columns + seedColumns := x.Cte.seedPartLogicalPlan.Schema().Columns + var recursiveColumns []*expression.Column + if x.Cte.recursivePartLogicalPlan != nil { + recursiveColumns = x.Cte.recursivePartLogicalPlan.Schema().Columns + } + relatedCols := make([]*expression.Column, 0, 2) + for i, col := range columns { + relatedCols = append(relatedCols[:0], seedColumns[i]) + if recursiveColumns != nil { + relatedCols = append(relatedCols, recursiveColumns[i]) + } + c.updateColMap(col, relatedCols) + } + // If IsDistinct is true, then we use getColsNDV to calculate row count(see (*LogicalCTE).DeriveStat). In this case + // statistics of all the columns are needed. + if x.Cte.IsDistinct { + for _, col := range columns { + c.addPredicateColumn(col) + } + } + case *logicalop.LogicalCTETable: + // Schema change from seedPlan to self. 
+ for i, col := range x.Schema().Columns { + c.updateColMap(col, []*expression.Column{x.SeedSchema.Columns[i]}) + } + } + } + // Histogram-needed columns are the columns which occur in the conditions pushed down to DataSource. + // We don't consider LogicalCTE because seedLogicalPlan and recursiveLogicalPlan haven't got logical optimization + // yet(seedLogicalPlan and recursiveLogicalPlan are optimized in DeriveStats phase). Without logical optimization, + // there is no condition pushed down to DataSource so no histogram-needed column can be collected. + // + // Since c.visitedPhysTblIDs is also collected here and needs to be collected even collectHistNeededColumns is not set, + // so we do the c.collectMode check in addHistNeededColumns() after collecting c.visitedPhysTblIDs. + switch x := lp.(type) { + case *DataSource: + c.addHistNeededColumns(x) + case *LogicalIndexScan: + c.addHistNeededColumns(x.Source) + case *LogicalTableScan: + c.addHistNeededColumns(x.Source) + } +} + +// CollectColumnStatsUsage collects column stats usage from logical plan. +// predicate indicates whether to collect predicate columns and histNeeded indicates whether to collect histogram-needed columns. +// First return value: predicate columns +// Second return value: histogram-needed columns (nil if histNeeded is false) +// Third return value: ds.PhysicalTableID from all DataSource (always collected) +func CollectColumnStatsUsage(lp base.LogicalPlan, histNeeded bool) ( + []model.TableItemID, + []model.StatsLoadItem, + *intset.FastIntSet, +) { + var mode uint64 + // Always collect predicate columns. + mode |= collectPredicateColumns + if histNeeded { + mode |= collectHistNeededColumns + } + collector := newColumnStatsUsageCollector(mode, lp.SCtx().GetSessionVars().IsPlanReplayerCaptureEnabled()) + collector.collectFromPlan(lp) + if collector.collectVisitedTable { + recordTableRuntimeStats(lp.SCtx(), collector.visitedtbls) + } + itemSet2slice := func(set map[model.TableItemID]bool) []model.StatsLoadItem { + ret := make([]model.StatsLoadItem, 0, len(set)) + for item, fullLoad := range set { + ret = append(ret, model.StatsLoadItem{TableItemID: item, FullLoad: fullLoad}) + } + return ret + } + is := lp.SCtx().GetInfoSchema().(infoschema.InfoSchema) + statsHandle := domain.GetDomain(lp.SCtx()).StatsHandle() + physTblIDsWithNeededCols := intset.NewFastIntSet() + for neededCol, fullLoad := range collector.histNeededCols { + if !fullLoad { + continue + } + physTblIDsWithNeededCols.Insert(int(neededCol.TableID)) + } + collector.visitedPhysTblIDs.ForEach(func(physicalTblID int) { + // 1. collect table metadata + tbl, _ := infoschema.FindTableByTblOrPartID(is, int64(physicalTblID)) + if tbl == nil { + return + } + + // 2. handle extra sync/async stats loading for the determinate mode + + // If we visited a table without getting any columns need stats (likely because there are no pushed down + // predicates), and we are in the determinate mode, we need to make sure we are able to get the "analyze row + // count" in getStatsTable(), which means any column/index stats are available. + if lp.SCtx().GetSessionVars().GetOptObjective() != variable.OptObjectiveDeterminate || + // If we already collected some columns that need trigger sync laoding on this table, we don't need to + // additionally do anything for determinate mode. 
+ physTblIDsWithNeededCols.Has(physicalTblID) || + statsHandle == nil { + return + } + tblStats := statsHandle.GetTableStats(tbl.Meta()) + if tblStats == nil || tblStats.Pseudo { + return + } + var colToTriggerLoad *model.TableItemID + for _, col := range tbl.Cols() { + if col.State != model.StatePublic || (col.IsGenerated() && !col.GeneratedStored) || !tblStats.ColAndIdxExistenceMap.HasAnalyzed(col.ID, false) { + continue + } + if colStats := tblStats.GetCol(col.ID); colStats != nil { + // If any stats are already full loaded, we don't need to trigger stats loading on this table. + if colStats.IsFullLoad() { + colToTriggerLoad = nil + break + } + } + // Choose the first column we meet to trigger stats loading. + if colToTriggerLoad == nil { + colToTriggerLoad = &model.TableItemID{TableID: int64(physicalTblID), ID: col.ID, IsIndex: false} + } + } + if colToTriggerLoad == nil { + return + } + for _, idx := range tbl.Indices() { + if idx.Meta().State != model.StatePublic || idx.Meta().MVIndex { + continue + } + // If any stats are already full loaded, we don't need to trigger stats loading on this table. + if idxStats := tblStats.GetIdx(idx.Meta().ID); idxStats != nil && idxStats.IsFullLoad() { + colToTriggerLoad = nil + break + } + } + if colToTriggerLoad == nil { + return + } + if histNeeded { + collector.histNeededCols[*colToTriggerLoad] = true + } else { + asyncload.AsyncLoadHistogramNeededItems.Insert(*colToTriggerLoad, true) + } + }) + var ( + predicateCols []model.TableItemID + histNeededCols []model.StatsLoadItem + ) + predicateCols = maps.Keys(collector.predicateCols) + if histNeeded { + histNeededCols = itemSet2slice(collector.histNeededCols) + } + return predicateCols, histNeededCols, collector.visitedPhysTblIDs +} diff --git a/pkg/planner/core/debugtrace.go b/pkg/planner/core/debugtrace.go index babda2b1d4551..1f1dfdd68f196 100644 --- a/pkg/planner/core/debugtrace.go +++ b/pkg/planner/core/debugtrace.go @@ -196,11 +196,11 @@ func debugTraceGetStatsTbl( Outdated: outdated, StatsTblInfo: statistics.TraceStatsTbl(statsTbl), } - failpoint.Inject("DebugTraceStableStatsTbl", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("DebugTraceStableStatsTbl")); _err_ == nil { if val.(bool) { stabilizeGetStatsTblInfo(traceInfo) } - }) + } root.AppendStepToCurrentContext(traceInfo) } diff --git a/pkg/planner/core/debugtrace.go__failpoint_stash__ b/pkg/planner/core/debugtrace.go__failpoint_stash__ new file mode 100644 index 0000000000000..babda2b1d4551 --- /dev/null +++ b/pkg/planner/core/debugtrace.go__failpoint_stash__ @@ -0,0 +1,261 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
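The paired files in these hunks (the rewritten *.go, the generated binding__failpoint_binding__.go, and the *__failpoint_stash__ copies of the originals) are consistent with a failpoint-ctl style rewrite: each failpoint.Inject callback is inlined behind an explicit failpoint.Eval guard keyed by _curpkg_(name). The two forms below are excerpted from the collect_column_stats_usage.go hunks above, shown side by side for comparison (the exact tool invocation is not stated in the patch):

// Hand-written form (preserved verbatim in the __failpoint_stash__ copy):
failpoint.Inject("disablePseudoCheck", func() {
	skipPseudoCheckForTest = true
})

// Rewritten form in the instrumented file: the callback body runs only when
// the failpoint is enabled, and _curpkg_ prefixes the name with the package
// import path recorded in binding__failpoint_binding__.go.
if _, _err_ := failpoint.Eval(_curpkg_("disablePseudoCheck")); _err_ == nil {
	skipPseudoCheckForTest = true
}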
+ +package core + +import ( + "strconv" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/context" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/util" + "github.com/pingcap/tidb/pkg/planner/util/debugtrace" + "github.com/pingcap/tidb/pkg/statistics" + "github.com/pingcap/tidb/pkg/util/hint" +) + +/* + Below is debug trace for the received command from the client. + It records the input to the optimizer at the very beginning of query optimization. +*/ + +type receivedCmdInfo struct { + Command string + ExecutedASTText string + ExecuteStmtInfo *executeInfo +} + +type executeInfo struct { + PreparedSQL string + BinaryParamsInfo []binaryParamInfo + UseCursor bool +} + +type binaryParamInfo struct { + Type string + Value string +} + +func (info *binaryParamInfo) MarshalJSON() ([]byte, error) { + type binaryParamInfoForMarshal binaryParamInfo + infoForMarshal := new(binaryParamInfoForMarshal) + quote := `"` + // We only need the escape functionality of strconv.Quote, the quoting is not needed, + // so we trim the \" prefix and suffix here. + infoForMarshal.Type = strings.TrimSuffix( + strings.TrimPrefix( + strconv.Quote(info.Type), + quote), + quote) + infoForMarshal.Value = strings.TrimSuffix( + strings.TrimPrefix( + strconv.Quote(info.Value), + quote), + quote) + return debugtrace.EncodeJSONCommon(infoForMarshal) +} + +// DebugTraceReceivedCommand records the received command from the client to the debug trace. +func DebugTraceReceivedCommand(s base.PlanContext, cmd byte, stmtNode ast.StmtNode) { + sessionVars := s.GetSessionVars() + trace := debugtrace.GetOrInitDebugTraceRoot(s) + traceInfo := new(receivedCmdInfo) + trace.AppendStepWithNameToCurrentContext(traceInfo, "Received Command") + traceInfo.Command = mysql.Command2Str[cmd] + traceInfo.ExecutedASTText = stmtNode.Text() + + // Collect information for execute stmt, and record it in executeInfo. + var binaryParams []expression.Expression + var planCacheStmt *PlanCacheStmt + if execStmt, ok := stmtNode.(*ast.ExecuteStmt); ok { + if execStmt.PrepStmt != nil { + planCacheStmt, _ = execStmt.PrepStmt.(*PlanCacheStmt) + } + if execStmt.BinaryArgs != nil { + binaryParams, _ = execStmt.BinaryArgs.([]expression.Expression) + } + } + useCursor := sessionVars.HasStatusFlag(mysql.ServerStatusCursorExists) + // If none of them needs record, we don't need a executeInfo. + if binaryParams == nil && planCacheStmt == nil && !useCursor { + return + } + execInfo := &executeInfo{} + traceInfo.ExecuteStmtInfo = execInfo + execInfo.UseCursor = useCursor + if planCacheStmt != nil { + execInfo.PreparedSQL = planCacheStmt.StmtText + } + if len(binaryParams) > 0 { + execInfo.BinaryParamsInfo = make([]binaryParamInfo, len(binaryParams)) + for i, param := range binaryParams { + execInfo.BinaryParamsInfo[i].Type = param.GetType(s.GetExprCtx().GetEvalCtx()).String() + execInfo.BinaryParamsInfo[i].Value = param.StringWithCtx(s.GetExprCtx().GetEvalCtx(), errors.RedactLogDisable) + } + } +} + +/* + Below is debug trace for the hint that matches the current query. 
+*/ + +type bindingHint struct { + Hint *hint.HintsSet + trying bool +} + +func (b *bindingHint) MarshalJSON() ([]byte, error) { + tmp := make(map[string]string, 1) + hintStr, err := b.Hint.Restore() + if err != nil { + return debugtrace.EncodeJSONCommon(err) + } + if b.trying { + tmp["Trying Hint"] = hintStr + } else { + tmp["Best Hint"] = hintStr + } + return debugtrace.EncodeJSONCommon(tmp) +} + +// DebugTraceTryBinding records the hint that might be chosen to the debug trace. +func DebugTraceTryBinding(s context.PlanContext, binding *hint.HintsSet) { + root := debugtrace.GetOrInitDebugTraceRoot(s) + traceInfo := &bindingHint{ + Hint: binding, + trying: true, + } + root.AppendStepToCurrentContext(traceInfo) +} + +// DebugTraceBestBinding records the chosen hint to the debug trace. +func DebugTraceBestBinding(s context.PlanContext, binding *hint.HintsSet) { + root := debugtrace.GetOrInitDebugTraceRoot(s) + traceInfo := &bindingHint{ + Hint: binding, + trying: false, + } + root.AppendStepToCurrentContext(traceInfo) +} + +/* + Below is debug trace for getStatsTable(). + Part of the logic for collecting information is in statistics/debug_trace.go. +*/ + +type getStatsTblInfo struct { + TableName string + TblInfoID int64 + InputPhysicalID int64 + HandleIsNil bool + UsePartitionStats bool + CountIsZero bool + Uninitialized bool + Outdated bool + StatsTblInfo *statistics.StatsTblTraceInfo +} + +func debugTraceGetStatsTbl( + s base.PlanContext, + tblInfo *model.TableInfo, + pid int64, + handleIsNil, + usePartitionStats, + countIsZero, + uninitialized, + outdated bool, + statsTbl *statistics.Table, +) { + root := debugtrace.GetOrInitDebugTraceRoot(s) + traceInfo := &getStatsTblInfo{ + TableName: tblInfo.Name.O, + TblInfoID: tblInfo.ID, + InputPhysicalID: pid, + HandleIsNil: handleIsNil, + UsePartitionStats: usePartitionStats, + CountIsZero: countIsZero, + Uninitialized: uninitialized, + Outdated: outdated, + StatsTblInfo: statistics.TraceStatsTbl(statsTbl), + } + failpoint.Inject("DebugTraceStableStatsTbl", func(val failpoint.Value) { + if val.(bool) { + stabilizeGetStatsTblInfo(traceInfo) + } + }) + root.AppendStepToCurrentContext(traceInfo) +} + +// Only for test. +func stabilizeGetStatsTblInfo(info *getStatsTblInfo) { + info.TblInfoID = 100 + info.InputPhysicalID = 100 + tbl := info.StatsTblInfo + if tbl == nil { + return + } + tbl.PhysicalID = 100 + tbl.Version = 440930000000000000 + for _, col := range tbl.Columns { + col.LastUpdateVersion = 440930000000000000 + } + for _, idx := range tbl.Indexes { + idx.LastUpdateVersion = 440930000000000000 + } +} + +/* + Below is debug trace for AccessPath. 
+*/ + +type accessPathForDebugTrace struct { + IndexName string `json:",omitempty"` + AccessConditions []string + IndexFilters []string + TableFilters []string + PartialPaths []accessPathForDebugTrace `json:",omitempty"` + CountAfterAccess float64 + CountAfterIndex float64 +} + +func convertAccessPathForDebugTrace(ctx expression.EvalContext, path *util.AccessPath, out *accessPathForDebugTrace) { + if path.Index != nil { + out.IndexName = path.Index.Name.O + } + out.AccessConditions = expression.ExprsToStringsForDisplay(ctx, path.AccessConds) + out.IndexFilters = expression.ExprsToStringsForDisplay(ctx, path.IndexFilters) + out.TableFilters = expression.ExprsToStringsForDisplay(ctx, path.TableFilters) + out.CountAfterAccess = path.CountAfterAccess + out.CountAfterIndex = path.CountAfterIndex + out.PartialPaths = make([]accessPathForDebugTrace, len(path.PartialIndexPaths)) + for i, partialPath := range path.PartialIndexPaths { + convertAccessPathForDebugTrace(ctx, partialPath, &out.PartialPaths[i]) + } +} + +func debugTraceAccessPaths(s base.PlanContext, paths []*util.AccessPath) { + root := debugtrace.GetOrInitDebugTraceRoot(s) + traceInfo := make([]accessPathForDebugTrace, len(paths)) + for i, partialPath := range paths { + convertAccessPathForDebugTrace(s.GetExprCtx().GetEvalCtx(), partialPath, &traceInfo[i]) + } + root.AppendStepWithNameToCurrentContext(traceInfo, "Access paths") +} diff --git a/pkg/planner/core/encode.go b/pkg/planner/core/encode.go index af67053afa477..0ae4093bd5504 100644 --- a/pkg/planner/core/encode.go +++ b/pkg/planner/core/encode.go @@ -41,12 +41,12 @@ func EncodeFlatPlan(flat *FlatPhysicalPlan) string { if flat.InExecute { return "" } - failpoint.Inject("mockPlanRowCount", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockPlanRowCount")); _err_ == nil { selectPlan, _ := flat.Main.GetSelectPlan() for _, op := range selectPlan { op.Origin.StatsInfo().RowCount = float64(val.(int)) } - }) + } pn := encoderPool.Get().(*planEncoder) defer func() { pn.buf.Reset() @@ -164,9 +164,9 @@ func EncodePlan(p base.Plan) string { defer encoderPool.Put(pn) selectPlan := getSelectPlan(p) if selectPlan != nil { - failpoint.Inject("mockPlanRowCount", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockPlanRowCount")); _err_ == nil { selectPlan.StatsInfo().RowCount = float64(val.(int)) - }) + } } return pn.encodePlanTree(p) } diff --git a/pkg/planner/core/encode.go__failpoint_stash__ b/pkg/planner/core/encode.go__failpoint_stash__ new file mode 100644 index 0000000000000..af67053afa477 --- /dev/null +++ b/pkg/planner/core/encode.go__failpoint_stash__ @@ -0,0 +1,386 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
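For context on how failpoints such as mockPlanRowCount and DebugTraceStableStatsTbl are activated, a hedged test-side sketch follows. The full failpoint path is an assumption (package import path plus "/" plus name, which is what _curpkg_ produces); the patch itself does not show it, and the test name is hypothetical.

package core_test

import (
	"testing"

	"github.com/pingcap/failpoint"
)

func TestEncodePlanWithMockedRowCount(t *testing.T) {
	// Assumed full failpoint path; adjust to the real package import path.
	const fp = "github.com/pingcap/tidb/pkg/planner/core/mockPlanRowCount"
	// "return(100)" makes failpoint.Eval succeed with the int value 100, which
	// the mockPlanRowCount hooks in EncodeFlatPlan/EncodePlan then write into
	// the encoded plan's StatsInfo().RowCount.
	if err := failpoint.Enable(fp, "return(100)"); err != nil {
		t.Fatal(err)
	}
	defer func() {
		if err := failpoint.Disable(fp); err != nil {
			t.Fatal(err)
		}
	}()
	// ... build a plan and call core.EncodePlan(plan) here ...
}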
+ +package core + +import ( + "bytes" + "crypto/sha256" + "hash" + "strconv" + "sync" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/util/plancodec" +) + +// EncodeFlatPlan encodes a FlatPhysicalPlan with compression. +func EncodeFlatPlan(flat *FlatPhysicalPlan) string { + if len(flat.Main) == 0 { + return "" + } + // We won't collect the plan when we're in "EXPLAIN FOR" statement and the plan is from EXECUTE statement (please + // read comments of InExecute for details about the meaning of InExecute) because we are unable to get some + // necessary information when the execution of the plan is finished and some states in the session such as + // PreparedParams are cleaned. + // The behavior in BinaryPlanStrFromFlatPlan() is also the same. + if flat.InExecute { + return "" + } + failpoint.Inject("mockPlanRowCount", func(val failpoint.Value) { + selectPlan, _ := flat.Main.GetSelectPlan() + for _, op := range selectPlan { + op.Origin.StatsInfo().RowCount = float64(val.(int)) + } + }) + pn := encoderPool.Get().(*planEncoder) + defer func() { + pn.buf.Reset() + encoderPool.Put(pn) + }() + buf := pn.buf + buf.Reset() + opCount := len(flat.Main) + for _, cte := range flat.CTEs { + opCount += len(cte) + } + // assume an operator costs around 80 bytes, preallocate space for them + buf.Grow(80 * opCount) + encodeFlatPlanTree(flat.Main, 0, &buf) + for _, cte := range flat.CTEs { + fop := cte[0] + cteDef := cte[0].Origin.(*CTEDefinition) + id := cteDef.CTE.IDForStorage + tp := plancodec.TypeCTEDefinition + taskTypeInfo := plancodec.EncodeTaskType(fop.IsRoot, fop.StoreType) + p := fop.Origin + actRows, analyzeInfo, memoryInfo, diskInfo := getRuntimeInfoStr(p.SCtx(), p, nil) + var estRows float64 + if fop.IsPhysicalPlan { + estRows = fop.Origin.(base.PhysicalPlan).GetEstRowCountForDisplay() + } else if statsInfo := p.StatsInfo(); statsInfo != nil { + estRows = statsInfo.RowCount + } + plancodec.EncodePlanNode( + int(fop.Depth), + strconv.Itoa(id)+fop.Label.String(), + tp, + estRows, + taskTypeInfo, + fop.Origin.ExplainInfo(), + actRows, + analyzeInfo, + memoryInfo, + diskInfo, + &buf, + ) + if len(cte) > 1 { + encodeFlatPlanTree(cte[1:], 1, &buf) + } + } + return plancodec.Compress(buf.Bytes()) +} + +func encodeFlatPlanTree(flatTree FlatPlanTree, offset int, buf *bytes.Buffer) { + for i := 0; i < len(flatTree); { + fop := flatTree[i] + taskTypeInfo := plancodec.EncodeTaskType(fop.IsRoot, fop.StoreType) + p := fop.Origin + actRows, analyzeInfo, memoryInfo, diskInfo := getRuntimeInfoStr(p.SCtx(), p, nil) + var estRows float64 + if fop.IsPhysicalPlan { + estRows = fop.Origin.(base.PhysicalPlan).GetEstRowCountForDisplay() + } else if statsInfo := p.StatsInfo(); statsInfo != nil { + estRows = statsInfo.RowCount + } + plancodec.EncodePlanNode( + int(fop.Depth), + strconv.Itoa(fop.Origin.ID())+fop.Label.String(), + fop.Origin.TP(), + estRows, + taskTypeInfo, + fop.Origin.ExplainInfo(), + actRows, + analyzeInfo, + memoryInfo, + diskInfo, + buf, + ) + + if fop.NeedReverseDriverSide { + // If NeedReverseDriverSide is true, we don't rely on the order of flatTree. + // Instead, we manually slice the build and probe side children from flatTree and recursively call + // encodeFlatPlanTree to keep build side before probe side. 
+ buildSide := flatTree[fop.ChildrenIdx[1]-offset : fop.ChildrenEndIdx+1-offset] + probeSide := flatTree[fop.ChildrenIdx[0]-offset : fop.ChildrenIdx[1]-offset] + encodeFlatPlanTree(buildSide, fop.ChildrenIdx[1], buf) + encodeFlatPlanTree(probeSide, fop.ChildrenIdx[0], buf) + // Skip the children plan tree of the current operator. + i = fop.ChildrenEndIdx + 1 - offset + } else { + // Normally, we just go to the next element in the slice. + i++ + } + } +} + +var encoderPool = sync.Pool{ + New: func() any { + return &planEncoder{} + }, +} + +type planEncoder struct { + buf bytes.Buffer + encodedPlans map[int]bool + + ctes []*PhysicalCTE +} + +// EncodePlan is used to encodePlan the plan to the plan tree with compressing. +// Deprecated: FlattenPhysicalPlan() + EncodeFlatPlan() is preferred. +func EncodePlan(p base.Plan) string { + if explain, ok := p.(*Explain); ok { + p = explain.TargetPlan + } + if p == nil || p.SCtx() == nil { + return "" + } + pn := encoderPool.Get().(*planEncoder) + defer encoderPool.Put(pn) + selectPlan := getSelectPlan(p) + if selectPlan != nil { + failpoint.Inject("mockPlanRowCount", func(val failpoint.Value) { + selectPlan.StatsInfo().RowCount = float64(val.(int)) + }) + } + return pn.encodePlanTree(p) +} + +func (pn *planEncoder) encodePlanTree(p base.Plan) string { + pn.encodedPlans = make(map[int]bool) + pn.buf.Reset() + pn.ctes = pn.ctes[:0] + pn.encodePlan(p, true, kv.TiKV, 0) + pn.encodeCTEPlan() + return plancodec.Compress(pn.buf.Bytes()) +} + +func (pn *planEncoder) encodeCTEPlan() { + if len(pn.ctes) <= 0 { + return + } + explainedCTEPlan := make(map[int]struct{}) + for i := 0; i < len(pn.ctes); i++ { + x := (*CTEDefinition)(pn.ctes[i]) + // skip if the CTE has been explained, the same CTE has same IDForStorage + if _, ok := explainedCTEPlan[x.CTE.IDForStorage]; ok { + continue + } + taskTypeInfo := plancodec.EncodeTaskType(true, kv.TiKV) + actRows, analyzeInfo, memoryInfo, diskInfo := getRuntimeInfoStr(x.SCtx(), x, nil) + rowCount := 0.0 + if statsInfo := x.StatsInfo(); statsInfo != nil { + rowCount = x.StatsInfo().RowCount + } + plancodec.EncodePlanNode(0, strconv.Itoa(x.CTE.IDForStorage), plancodec.TypeCTEDefinition, rowCount, taskTypeInfo, x.ExplainInfo(), actRows, analyzeInfo, memoryInfo, diskInfo, &pn.buf) + pn.encodePlan(x.SeedPlan, true, kv.TiKV, 1) + if x.RecurPlan != nil { + pn.encodePlan(x.RecurPlan, true, kv.TiKV, 1) + } + explainedCTEPlan[x.CTE.IDForStorage] = struct{}{} + } +} + +func (pn *planEncoder) encodePlan(p base.Plan, isRoot bool, store kv.StoreType, depth int) { + taskTypeInfo := plancodec.EncodeTaskType(isRoot, store) + actRows, analyzeInfo, memoryInfo, diskInfo := getRuntimeInfoStr(p.SCtx(), p, nil) + rowCount := 0.0 + if pp, ok := p.(base.PhysicalPlan); ok { + rowCount = pp.GetEstRowCountForDisplay() + } else if statsInfo := p.StatsInfo(); statsInfo != nil { + rowCount = statsInfo.RowCount + } + plancodec.EncodePlanNode(depth, strconv.Itoa(p.ID()), p.TP(), rowCount, taskTypeInfo, p.ExplainInfo(), actRows, analyzeInfo, memoryInfo, diskInfo, &pn.buf) + pn.encodedPlans[p.ID()] = true + depth++ + + selectPlan := getSelectPlan(p) + if selectPlan == nil { + return + } + if !pn.encodedPlans[selectPlan.ID()] { + pn.encodePlan(selectPlan, isRoot, store, depth) + return + } + for _, child := range selectPlan.Children() { + if pn.encodedPlans[child.ID()] { + continue + } + pn.encodePlan(child, isRoot, store, depth) + } + switch copPlan := selectPlan.(type) { + case *PhysicalTableReader: + pn.encodePlan(copPlan.tablePlan, false, 
copPlan.StoreType, depth) + case *PhysicalIndexReader: + pn.encodePlan(copPlan.indexPlan, false, store, depth) + case *PhysicalIndexLookUpReader: + pn.encodePlan(copPlan.indexPlan, false, store, depth) + pn.encodePlan(copPlan.tablePlan, false, store, depth) + case *PhysicalIndexMergeReader: + for _, p := range copPlan.partialPlans { + pn.encodePlan(p, false, store, depth) + } + if copPlan.tablePlan != nil { + pn.encodePlan(copPlan.tablePlan, false, store, depth) + } + case *PhysicalCTE: + pn.ctes = append(pn.ctes, copPlan) + } +} + +var digesterPool = sync.Pool{ + New: func() any { + return &planDigester{ + hasher: sha256.New(), + } + }, +} + +type planDigester struct { + buf bytes.Buffer + encodedPlans map[int]bool + hasher hash.Hash +} + +// NormalizeFlatPlan normalizes a FlatPhysicalPlan and generates plan digest. +func NormalizeFlatPlan(flat *FlatPhysicalPlan) (normalized string, digest *parser.Digest) { + if flat == nil { + return "", parser.NewDigest(nil) + } + selectPlan, selectPlanOffset := flat.Main.GetSelectPlan() + if len(selectPlan) == 0 || !selectPlan[0].IsPhysicalPlan { + return "", parser.NewDigest(nil) + } + d := digesterPool.Get().(*planDigester) + defer func() { + d.buf.Reset() + d.hasher.Reset() + digesterPool.Put(d) + }() + // assume an operator costs around 30 bytes, preallocate space for them + d.buf.Grow(30 * len(selectPlan)) + for _, fop := range selectPlan { + taskTypeInfo := plancodec.EncodeTaskTypeForNormalize(fop.IsRoot, fop.StoreType) + p := fop.Origin.(base.PhysicalPlan) + plancodec.NormalizePlanNode( + int(fop.Depth-uint32(selectPlanOffset)), + fop.Origin.TP(), + taskTypeInfo, + p.ExplainNormalizedInfo(), + &d.buf, + ) + } + normalized = d.buf.String() + if len(normalized) == 0 { + return "", parser.NewDigest(nil) + } + _, err := d.hasher.Write(d.buf.Bytes()) + if err != nil { + panic(err) + } + digest = parser.NewDigest(d.hasher.Sum(nil)) + return +} + +// NormalizePlan is used to normalize the plan and generate plan digest. +// Deprecated: FlattenPhysicalPlan() + NormalizeFlatPlan() is preferred. 
+func NormalizePlan(p base.Plan) (normalized string, digest *parser.Digest) { + selectPlan := getSelectPlan(p) + if selectPlan == nil { + return "", parser.NewDigest(nil) + } + d := digesterPool.Get().(*planDigester) + defer func() { + d.buf.Reset() + d.hasher.Reset() + digesterPool.Put(d) + }() + d.normalizePlanTree(selectPlan) + normalized = d.buf.String() + _, err := d.hasher.Write(d.buf.Bytes()) + if err != nil { + panic(err) + } + digest = parser.NewDigest(d.hasher.Sum(nil)) + return +} + +func (d *planDigester) normalizePlanTree(p base.PhysicalPlan) { + d.encodedPlans = make(map[int]bool) + d.buf.Reset() + d.normalizePlan(p, true, kv.TiKV, 0) +} + +func (d *planDigester) normalizePlan(p base.PhysicalPlan, isRoot bool, store kv.StoreType, depth int) { + taskTypeInfo := plancodec.EncodeTaskTypeForNormalize(isRoot, store) + plancodec.NormalizePlanNode(depth, p.TP(), taskTypeInfo, p.ExplainNormalizedInfo(), &d.buf) + d.encodedPlans[p.ID()] = true + + depth++ + for _, child := range p.Children() { + if d.encodedPlans[child.ID()] { + continue + } + d.normalizePlan(child, isRoot, store, depth) + } + switch x := p.(type) { + case *PhysicalTableReader: + d.normalizePlan(x.tablePlan, false, x.StoreType, depth) + case *PhysicalIndexReader: + d.normalizePlan(x.indexPlan, false, store, depth) + case *PhysicalIndexLookUpReader: + d.normalizePlan(x.indexPlan, false, store, depth) + d.normalizePlan(x.tablePlan, false, store, depth) + case *PhysicalIndexMergeReader: + for _, p := range x.partialPlans { + d.normalizePlan(p, false, store, depth) + } + if x.tablePlan != nil { + d.normalizePlan(x.tablePlan, false, store, depth) + } + } +} + +func getSelectPlan(p base.Plan) base.PhysicalPlan { + var selectPlan base.PhysicalPlan + if physicalPlan, ok := p.(base.PhysicalPlan); ok { + selectPlan = physicalPlan + } else { + switch x := p.(type) { + case *Delete: + selectPlan = x.SelectPlan + case *Update: + selectPlan = x.SelectPlan + case *Insert: + selectPlan = x.SelectPlan + case *Explain: + selectPlan = getSelectPlan(x.TargetPlan) + } + } + return selectPlan +} diff --git a/pkg/planner/core/exhaust_physical_plans.go b/pkg/planner/core/exhaust_physical_plans.go index ad3398b1f5fad..21c1043b84848 100644 --- a/pkg/planner/core/exhaust_physical_plans.go +++ b/pkg/planner/core/exhaust_physical_plans.go @@ -873,11 +873,11 @@ func buildIndexJoinInner2TableScan( lastColMng = indexJoinResult.lastColManager } joins = make([]base.PhysicalPlan, 0, 3) - failpoint.Inject("MockOnlyEnableIndexHashJoin", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("MockOnlyEnableIndexHashJoin")); _err_ == nil { if val.(bool) && !p.SCtx().GetSessionVars().InRestrictedSQL { - failpoint.Return(constructIndexHashJoin(p, prop, outerIdx, innerTask, nil, keyOff2IdxOff, path, lastColMng)) + return constructIndexHashJoin(p, prop, outerIdx, innerTask, nil, keyOff2IdxOff, path, lastColMng) } - }) + } joins = append(joins, constructIndexJoin(p, prop, outerIdx, innerTask, ranges, keyOff2IdxOff, path, lastColMng, true)...) // We can reuse the `innerTask` here since index nested loop hash join // do not need the inner child to promise the order. 
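NormalizePlan and NormalizeFlatPlan above both derive the plan digest by hashing the normalized operator tree with SHA-256 and wrapping the sum in a parser.Digest. A minimal standalone approximation of just the hashing step (the normalized string here is invented for illustration; the real one is produced by plancodec.NormalizePlanNode per operator):

package main

import (
	"crypto/sha256"
	"fmt"
)

func main() {
	// Hypothetical normalized plan text.
	normalized := "TableReader\tSelection\tTableFullScan"
	h := sha256.New()
	h.Write([]byte(normalized)) // same bytes as d.buf in the code above
	fmt.Printf("plan digest: %x\n", h.Sum(nil))
}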
@@ -924,11 +924,11 @@ func buildIndexJoinInner2IndexScan( } } innerTask := constructInnerIndexScanTask(p, prop, wrapper, indexJoinResult.chosenPath, indexJoinResult.chosenRanges.Range(), indexJoinResult.chosenRemained, innerJoinKeys, indexJoinResult.idxOff2KeyOff, rangeInfo, false, false, avgInnerRowCnt, maxOneRow) - failpoint.Inject("MockOnlyEnableIndexHashJoin", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("MockOnlyEnableIndexHashJoin")); _err_ == nil { if val.(bool) && !p.SCtx().GetSessionVars().InRestrictedSQL && innerTask != nil { - failpoint.Return(constructIndexHashJoin(p, prop, outerIdx, innerTask, indexJoinResult.chosenRanges, keyOff2IdxOff, indexJoinResult.chosenPath, indexJoinResult.lastColManager)) + return constructIndexHashJoin(p, prop, outerIdx, innerTask, indexJoinResult.chosenRanges, keyOff2IdxOff, indexJoinResult.chosenPath, indexJoinResult.lastColManager) } - }) + } if innerTask != nil { joins = append(joins, constructIndexJoin(p, prop, outerIdx, innerTask, indexJoinResult.chosenRanges, keyOff2IdxOff, indexJoinResult.chosenPath, indexJoinResult.lastColManager, true)...) // We can reuse the `innerTask` here since index nested loop hash join @@ -1867,11 +1867,11 @@ func tryToGetMppHashJoin(p *LogicalJoin, prop *property.PhysicalProperty, useBCJ } // set preferredBuildIndex for test - failpoint.Inject("mockPreferredBuildIndex", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockPreferredBuildIndex")); _err_ == nil { if !p.SCtx().GetSessionVars().InRestrictedSQL { preferredBuildIndex = val.(int) } - }) + } baseJoin.InnerChildIdx = preferredBuildIndex childrenProps := make([]*property.PhysicalProperty, 2) diff --git a/pkg/planner/core/exhaust_physical_plans.go__failpoint_stash__ b/pkg/planner/core/exhaust_physical_plans.go__failpoint_stash__ new file mode 100644 index 0000000000000..ad3398b1f5fad --- /dev/null +++ b/pkg/planner/core/exhaust_physical_plans.go__failpoint_stash__ @@ -0,0 +1,3004 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package core + +import ( + "fmt" + "math" + "slices" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/expression/aggregation" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/cardinality" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/core/cost" + "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" + "github.com/pingcap/tidb/pkg/planner/property" + "github.com/pingcap/tidb/pkg/planner/util" + "github.com/pingcap/tidb/pkg/planner/util/fixcontrol" + "github.com/pingcap/tidb/pkg/statistics" + "github.com/pingcap/tidb/pkg/types" + h "github.com/pingcap/tidb/pkg/util/hint" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/plancodec" + "github.com/pingcap/tidb/pkg/util/ranger" + "github.com/pingcap/tidb/pkg/util/set" + "github.com/pingcap/tipb/go-tipb" + "go.uber.org/zap" +) + +func exhaustPhysicalPlans4LogicalUnionScan(lp base.LogicalPlan, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { + p := lp.(*logicalop.LogicalUnionScan) + if prop.IsFlashProp() { + p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced( + "MPP mode may be blocked because operator `UnionScan` is not supported now.") + return nil, true, nil + } + childProp := prop.CloneEssentialFields() + us := PhysicalUnionScan{ + Conditions: p.Conditions, + HandleCols: p.HandleCols, + }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset(), childProp) + return []base.PhysicalPlan{us}, true, nil +} + +func findMaxPrefixLen(candidates [][]*expression.Column, keys []*expression.Column) int { + maxLen := 0 + for _, candidateKeys := range candidates { + matchedLen := 0 + for i := range keys { + if !(i < len(candidateKeys) && keys[i].EqualColumn(candidateKeys[i])) { + break + } + matchedLen++ + } + if matchedLen > maxLen { + maxLen = matchedLen + } + } + return maxLen +} + +func moveEqualToOtherConditions(p *LogicalJoin, offsets []int) []expression.Expression { + // Construct used equal condition set based on the equal condition offsets. + usedEqConds := set.NewIntSet() + for _, eqCondIdx := range offsets { + usedEqConds.Insert(eqCondIdx) + } + + // Construct otherConds, which is composed of the original other conditions + // and the remained unused equal conditions. + numOtherConds := len(p.OtherConditions) + len(p.EqualConditions) - len(usedEqConds) + otherConds := make([]expression.Expression, len(p.OtherConditions), numOtherConds) + copy(otherConds, p.OtherConditions) + for eqCondIdx := range p.EqualConditions { + if !usedEqConds.Exist(eqCondIdx) { + otherConds = append(otherConds, p.EqualConditions[eqCondIdx]) + } + } + + return otherConds +} + +// Only if the input required prop is the prefix fo join keys, we can pass through this property. 
+func (p *PhysicalMergeJoin) tryToGetChildReqProp(prop *property.PhysicalProperty) ([]*property.PhysicalProperty, bool) { + all, desc := prop.AllSameOrder() + lProp := property.NewPhysicalProperty(property.RootTaskType, p.LeftJoinKeys, desc, math.MaxFloat64, false) + rProp := property.NewPhysicalProperty(property.RootTaskType, p.RightJoinKeys, desc, math.MaxFloat64, false) + lProp.CTEProducerStatus = prop.CTEProducerStatus + rProp.CTEProducerStatus = prop.CTEProducerStatus + if !prop.IsSortItemEmpty() { + // sort merge join fits the cases of massive ordered data, so desc scan is always expensive. + if !all { + return nil, false + } + if !prop.IsPrefix(lProp) && !prop.IsPrefix(rProp) { + return nil, false + } + if prop.IsPrefix(rProp) && p.JoinType == LeftOuterJoin { + return nil, false + } + if prop.IsPrefix(lProp) && p.JoinType == RightOuterJoin { + return nil, false + } + } + + return []*property.PhysicalProperty{lProp, rProp}, true +} + +func checkJoinKeyCollation(leftKeys, rightKeys []*expression.Column) bool { + // if a left key and its corresponding right key have different collation, don't use MergeJoin since + // the their children may sort their records in different ways + for i := range leftKeys { + lt := leftKeys[i].RetType + rt := rightKeys[i].RetType + if (lt.EvalType() == types.ETString && rt.EvalType() == types.ETString) && + (leftKeys[i].RetType.GetCharset() != rightKeys[i].RetType.GetCharset() || + leftKeys[i].RetType.GetCollate() != rightKeys[i].RetType.GetCollate()) { + return false + } + } + return true +} + +// GetMergeJoin convert the logical join to physical merge join based on the physical property. +func GetMergeJoin(p *LogicalJoin, prop *property.PhysicalProperty, schema *expression.Schema, statsInfo *property.StatsInfo, leftStatsInfo *property.StatsInfo, rightStatsInfo *property.StatsInfo) []base.PhysicalPlan { + joins := make([]base.PhysicalPlan, 0, len(p.LeftProperties)+1) + // The LeftProperties caches all the possible properties that are provided by its children. + leftJoinKeys, rightJoinKeys, isNullEQ, hasNullEQ := p.GetJoinKeys() + + // EnumType/SetType Unsupported: merge join conflicts with index order. + // ref: https://github.com/pingcap/tidb/issues/24473, https://github.com/pingcap/tidb/issues/25669 + for _, leftKey := range leftJoinKeys { + if leftKey.RetType.GetType() == mysql.TypeEnum || leftKey.RetType.GetType() == mysql.TypeSet { + return nil + } + } + for _, rightKey := range rightJoinKeys { + if rightKey.RetType.GetType() == mysql.TypeEnum || rightKey.RetType.GetType() == mysql.TypeSet { + return nil + } + } + + // TODO: support null equal join keys for merge join + if hasNullEQ { + return nil + } + for _, lhsChildProperty := range p.LeftProperties { + offsets := util.GetMaxSortPrefix(lhsChildProperty, leftJoinKeys) + // If not all equal conditions hit properties. We ban merge join heuristically. Because in this case, merge join + // may get a very low performance. In executor, executes join results before other conditions filter it. 
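+		// In other words, a merge join is only tried for this child property when every
+		// equi-join key is covered by the ordered prefix that the child already provides.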
+ if len(offsets) < len(leftJoinKeys) { + continue + } + + leftKeys := lhsChildProperty[:len(offsets)] + rightKeys := expression.NewSchema(rightJoinKeys...).ColumnsByIndices(offsets) + newIsNullEQ := make([]bool, 0, len(offsets)) + for _, offset := range offsets { + newIsNullEQ = append(newIsNullEQ, isNullEQ[offset]) + } + + prefixLen := findMaxPrefixLen(p.RightProperties, rightKeys) + if prefixLen == 0 { + continue + } + + leftKeys = leftKeys[:prefixLen] + rightKeys = rightKeys[:prefixLen] + newIsNullEQ = newIsNullEQ[:prefixLen] + if !checkJoinKeyCollation(leftKeys, rightKeys) { + continue + } + offsets = offsets[:prefixLen] + baseJoin := basePhysicalJoin{ + JoinType: p.JoinType, + LeftConditions: p.LeftConditions, + RightConditions: p.RightConditions, + DefaultValues: p.DefaultValues, + LeftJoinKeys: leftKeys, + RightJoinKeys: rightKeys, + IsNullEQ: newIsNullEQ, + } + mergeJoin := PhysicalMergeJoin{basePhysicalJoin: baseJoin}.Init(p.SCtx(), statsInfo.ScaleByExpectCnt(prop.ExpectedCnt), p.QueryBlockOffset()) + mergeJoin.SetSchema(schema) + mergeJoin.OtherConditions = moveEqualToOtherConditions(p, offsets) + mergeJoin.initCompareFuncs() + if reqProps, ok := mergeJoin.tryToGetChildReqProp(prop); ok { + // Adjust expected count for children nodes. + if prop.ExpectedCnt < statsInfo.RowCount { + expCntScale := prop.ExpectedCnt / statsInfo.RowCount + reqProps[0].ExpectedCnt = leftStatsInfo.RowCount * expCntScale + reqProps[1].ExpectedCnt = rightStatsInfo.RowCount * expCntScale + } + mergeJoin.childrenReqProps = reqProps + _, desc := prop.AllSameOrder() + mergeJoin.Desc = desc + joins = append(joins, mergeJoin) + } + } + + if p.PreferJoinType&h.PreferNoMergeJoin > 0 { + if p.PreferJoinType&h.PreferMergeJoin == 0 { + return nil + } + p.SCtx().GetSessionVars().StmtCtx.SetHintWarning( + "Some MERGE_JOIN and NO_MERGE_JOIN hints conflict, NO_MERGE_JOIN is ignored") + } + + // If TiDB_SMJ hint is existed, it should consider enforce merge join, + // because we can't trust lhsChildProperty completely. + if (p.PreferJoinType&h.PreferMergeJoin) > 0 || + shouldSkipHashJoin(p) { // if hash join is not allowed, generate as many other types of join as possible to avoid 'cant-find-plan' error. + joins = append(joins, getEnforcedMergeJoin(p, prop, schema, statsInfo)...) 
+ } + + return joins +} + +// Change JoinKeys order, by offsets array +// offsets array is generate by prop check +func getNewJoinKeysByOffsets(oldJoinKeys []*expression.Column, offsets []int) []*expression.Column { + newKeys := make([]*expression.Column, 0, len(oldJoinKeys)) + for _, offset := range offsets { + newKeys = append(newKeys, oldJoinKeys[offset]) + } + for pos, key := range oldJoinKeys { + isExist := false + for _, p := range offsets { + if p == pos { + isExist = true + break + } + } + if !isExist { + newKeys = append(newKeys, key) + } + } + return newKeys +} + +func getNewNullEQByOffsets(oldNullEQ []bool, offsets []int) []bool { + newNullEQ := make([]bool, 0, len(oldNullEQ)) + for _, offset := range offsets { + newNullEQ = append(newNullEQ, oldNullEQ[offset]) + } + for pos, key := range oldNullEQ { + isExist := false + for _, p := range offsets { + if p == pos { + isExist = true + break + } + } + if !isExist { + newNullEQ = append(newNullEQ, key) + } + } + return newNullEQ +} + +func getEnforcedMergeJoin(p *LogicalJoin, prop *property.PhysicalProperty, schema *expression.Schema, statsInfo *property.StatsInfo) []base.PhysicalPlan { + // Check whether SMJ can satisfy the required property + leftJoinKeys, rightJoinKeys, isNullEQ, hasNullEQ := p.GetJoinKeys() + // TODO: support null equal join keys for merge join + if hasNullEQ { + return nil + } + offsets := make([]int, 0, len(leftJoinKeys)) + all, desc := prop.AllSameOrder() + if !all { + return nil + } + evalCtx := p.SCtx().GetExprCtx().GetEvalCtx() + for _, item := range prop.SortItems { + isExist, hasLeftColInProp, hasRightColInProp := false, false, false + for joinKeyPos := 0; joinKeyPos < len(leftJoinKeys); joinKeyPos++ { + var key *expression.Column + if item.Col.Equal(evalCtx, leftJoinKeys[joinKeyPos]) { + key = leftJoinKeys[joinKeyPos] + hasLeftColInProp = true + } + if item.Col.Equal(evalCtx, rightJoinKeys[joinKeyPos]) { + key = rightJoinKeys[joinKeyPos] + hasRightColInProp = true + } + if key == nil { + continue + } + for i := 0; i < len(offsets); i++ { + if offsets[i] == joinKeyPos { + isExist = true + break + } + } + if !isExist { + offsets = append(offsets, joinKeyPos) + } + isExist = true + break + } + if !isExist { + return nil + } + // If the output wants the order of the inner side. We should reject it since we might add null-extend rows of that side. + if p.JoinType == LeftOuterJoin && hasRightColInProp { + return nil + } + if p.JoinType == RightOuterJoin && hasLeftColInProp { + return nil + } + } + // Generate the enforced sort merge join + leftKeys := getNewJoinKeysByOffsets(leftJoinKeys, offsets) + rightKeys := getNewJoinKeysByOffsets(rightJoinKeys, offsets) + newNullEQ := getNewNullEQByOffsets(isNullEQ, offsets) + otherConditions := make([]expression.Expression, len(p.OtherConditions), len(p.OtherConditions)+len(p.EqualConditions)) + copy(otherConditions, p.OtherConditions) + if !checkJoinKeyCollation(leftKeys, rightKeys) { + // if the join keys' collation are conflicted, we use the empty join key + // and move EqualConditions to OtherConditions. + leftKeys = nil + rightKeys = nil + newNullEQ = nil + otherConditions = append(otherConditions, expression.ScalarFuncs2Exprs(p.EqualConditions)...) 
+ } + lProp := property.NewPhysicalProperty(property.RootTaskType, leftKeys, desc, math.MaxFloat64, true) + rProp := property.NewPhysicalProperty(property.RootTaskType, rightKeys, desc, math.MaxFloat64, true) + baseJoin := basePhysicalJoin{ + JoinType: p.JoinType, + LeftConditions: p.LeftConditions, + RightConditions: p.RightConditions, + DefaultValues: p.DefaultValues, + LeftJoinKeys: leftKeys, + RightJoinKeys: rightKeys, + IsNullEQ: newNullEQ, + OtherConditions: otherConditions, + } + enforcedPhysicalMergeJoin := PhysicalMergeJoin{basePhysicalJoin: baseJoin, Desc: desc}.Init(p.SCtx(), statsInfo.ScaleByExpectCnt(prop.ExpectedCnt), p.QueryBlockOffset()) + enforcedPhysicalMergeJoin.SetSchema(schema) + enforcedPhysicalMergeJoin.childrenReqProps = []*property.PhysicalProperty{lProp, rProp} + enforcedPhysicalMergeJoin.initCompareFuncs() + return []base.PhysicalPlan{enforcedPhysicalMergeJoin} +} + +func (p *PhysicalMergeJoin) initCompareFuncs() { + p.CompareFuncs = make([]expression.CompareFunc, 0, len(p.LeftJoinKeys)) + for i := range p.LeftJoinKeys { + p.CompareFuncs = append(p.CompareFuncs, expression.GetCmpFunction(p.SCtx().GetExprCtx(), p.LeftJoinKeys[i], p.RightJoinKeys[i])) + } +} + +func shouldSkipHashJoin(p *LogicalJoin) bool { + return (p.PreferJoinType&h.PreferNoHashJoin) > 0 || (p.SCtx().GetSessionVars().DisableHashJoin) +} + +func getHashJoins(p *LogicalJoin, prop *property.PhysicalProperty) (joins []base.PhysicalPlan, forced bool) { + if !prop.IsSortItemEmpty() { // hash join doesn't promise any orders + return + } + + forceLeftToBuild := ((p.PreferJoinType & h.PreferLeftAsHJBuild) > 0) || ((p.PreferJoinType & h.PreferRightAsHJProbe) > 0) + forceRightToBuild := ((p.PreferJoinType & h.PreferRightAsHJBuild) > 0) || ((p.PreferJoinType & h.PreferLeftAsHJProbe) > 0) + if forceLeftToBuild && forceRightToBuild { + p.SCtx().GetSessionVars().StmtCtx.SetHintWarning("Some HASH_JOIN_BUILD and HASH_JOIN_PROBE hints are conflicts, please check the hints") + forceLeftToBuild = false + forceRightToBuild = false + } + + joins = make([]base.PhysicalPlan, 0, 2) + switch p.JoinType { + case SemiJoin, AntiSemiJoin, LeftOuterSemiJoin, AntiLeftOuterSemiJoin: + joins = append(joins, getHashJoin(p, prop, 1, false)) + if forceLeftToBuild || forceRightToBuild { + // Do not support specifying the build and probe side for semi join. 
+ p.SCtx().GetSessionVars().StmtCtx.SetHintWarning(fmt.Sprintf("We can't use the HASH_JOIN_BUILD or HASH_JOIN_PROBE hint for %s, please check the hint", p.JoinType)) + forceLeftToBuild = false + forceRightToBuild = false + } + case LeftOuterJoin: + if !forceLeftToBuild { + joins = append(joins, getHashJoin(p, prop, 1, false)) + } + if !forceRightToBuild { + joins = append(joins, getHashJoin(p, prop, 1, true)) + } + case RightOuterJoin: + if !forceLeftToBuild { + joins = append(joins, getHashJoin(p, prop, 0, true)) + } + if !forceRightToBuild { + joins = append(joins, getHashJoin(p, prop, 0, false)) + } + case InnerJoin: + if forceLeftToBuild { + joins = append(joins, getHashJoin(p, prop, 0, false)) + } else if forceRightToBuild { + joins = append(joins, getHashJoin(p, prop, 1, false)) + } else { + joins = append(joins, getHashJoin(p, prop, 1, false)) + joins = append(joins, getHashJoin(p, prop, 0, false)) + } + } + + forced = (p.PreferJoinType&h.PreferHashJoin > 0) || forceLeftToBuild || forceRightToBuild + shouldSkipHashJoin := shouldSkipHashJoin(p) + if !forced && shouldSkipHashJoin { + return nil, false + } else if forced && shouldSkipHashJoin { + p.SCtx().GetSessionVars().StmtCtx.SetHintWarning( + "A conflict between the HASH_JOIN hint and the NO_HASH_JOIN hint, " + + "or the tidb_opt_enable_hash_join system variable, the HASH_JOIN hint will take precedence.") + } + return +} + +func getHashJoin(p *LogicalJoin, prop *property.PhysicalProperty, innerIdx int, useOuterToBuild bool) *PhysicalHashJoin { + chReqProps := make([]*property.PhysicalProperty, 2) + chReqProps[innerIdx] = &property.PhysicalProperty{ExpectedCnt: math.MaxFloat64, CTEProducerStatus: prop.CTEProducerStatus} + chReqProps[1-innerIdx] = &property.PhysicalProperty{ExpectedCnt: math.MaxFloat64, CTEProducerStatus: prop.CTEProducerStatus} + if prop.ExpectedCnt < p.StatsInfo().RowCount { + expCntScale := prop.ExpectedCnt / p.StatsInfo().RowCount + chReqProps[1-innerIdx].ExpectedCnt = p.Children()[1-innerIdx].StatsInfo().RowCount * expCntScale + } + hashJoin := NewPhysicalHashJoin(p, innerIdx, useOuterToBuild, p.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), chReqProps...) + hashJoin.SetSchema(p.Schema()) + return hashJoin +} + +// When inner plan is TableReader, the parameter `ranges` will be nil. Because pk only have one column. So all of its range +// is generated during execution time. 
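getHashJoin above and constructIndexJoin below both scale the outer child's expected row count when the parent only needs part of the join output (expCntScale = prop.ExpectedCnt / joinRowCount). A minimal standalone sketch of that arithmetic, with made-up numbers and no planner types:

package main

import "fmt"

// scaledOuterExpectedCnt shows the proportional scaling: if the parent only
// needs propExpectedCnt rows out of an estimated joinRowCount, the outer
// child's expected count shrinks by the same ratio.
func scaledOuterExpectedCnt(propExpectedCnt, joinRowCount, outerRowCount float64) float64 {
	if propExpectedCnt >= joinRowCount {
		return outerRowCount // parent wants everything, no scaling
	}
	return outerRowCount * (propExpectedCnt / joinRowCount)
}

func main() {
	// e.g. a LIMIT above the join only wants 100 of an estimated 10,000 join
	// rows, so an outer child estimated at 5,000 rows is expected to feed ~50.
	fmt.Println(scaledOuterExpectedCnt(100, 10000, 5000)) // 50
}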
+func constructIndexJoin( + p *LogicalJoin, + prop *property.PhysicalProperty, + outerIdx int, + innerTask base.Task, + ranges ranger.MutableRanges, + keyOff2IdxOff []int, + path *util.AccessPath, + compareFilters *ColWithCmpFuncManager, + extractOtherEQ bool, +) []base.PhysicalPlan { + if ranges == nil { + ranges = ranger.Ranges{} // empty range + } + + joinType := p.JoinType + var ( + innerJoinKeys []*expression.Column + outerJoinKeys []*expression.Column + isNullEQ []bool + hasNullEQ bool + ) + if outerIdx == 0 { + outerJoinKeys, innerJoinKeys, isNullEQ, hasNullEQ = p.GetJoinKeys() + } else { + innerJoinKeys, outerJoinKeys, isNullEQ, hasNullEQ = p.GetJoinKeys() + } + // TODO: support null equal join keys for index join + if hasNullEQ { + return nil + } + chReqProps := make([]*property.PhysicalProperty, 2) + chReqProps[outerIdx] = &property.PhysicalProperty{TaskTp: property.RootTaskType, ExpectedCnt: math.MaxFloat64, SortItems: prop.SortItems, CTEProducerStatus: prop.CTEProducerStatus} + if prop.ExpectedCnt < p.StatsInfo().RowCount { + expCntScale := prop.ExpectedCnt / p.StatsInfo().RowCount + chReqProps[outerIdx].ExpectedCnt = p.Children()[outerIdx].StatsInfo().RowCount * expCntScale + } + newInnerKeys := make([]*expression.Column, 0, len(innerJoinKeys)) + newOuterKeys := make([]*expression.Column, 0, len(outerJoinKeys)) + newIsNullEQ := make([]bool, 0, len(isNullEQ)) + newKeyOff := make([]int, 0, len(keyOff2IdxOff)) + newOtherConds := make([]expression.Expression, len(p.OtherConditions), len(p.OtherConditions)+len(p.EqualConditions)) + copy(newOtherConds, p.OtherConditions) + for keyOff, idxOff := range keyOff2IdxOff { + if keyOff2IdxOff[keyOff] < 0 { + newOtherConds = append(newOtherConds, p.EqualConditions[keyOff]) + continue + } + newInnerKeys = append(newInnerKeys, innerJoinKeys[keyOff]) + newOuterKeys = append(newOuterKeys, outerJoinKeys[keyOff]) + newIsNullEQ = append(newIsNullEQ, isNullEQ[keyOff]) + newKeyOff = append(newKeyOff, idxOff) + } + + var outerHashKeys, innerHashKeys []*expression.Column + outerHashKeys, innerHashKeys = make([]*expression.Column, len(newOuterKeys)), make([]*expression.Column, len(newInnerKeys)) + copy(outerHashKeys, newOuterKeys) + copy(innerHashKeys, newInnerKeys) + // we can use the `col col` in `OtherCondition` to build the hashtable to avoid the unnecessary calculating. + for i := len(newOtherConds) - 1; extractOtherEQ && i >= 0; i = i - 1 { + switch c := newOtherConds[i].(type) { + case *expression.ScalarFunction: + if c.FuncName.L == ast.EQ { + lhs, ok1 := c.GetArgs()[0].(*expression.Column) + rhs, ok2 := c.GetArgs()[1].(*expression.Column) + if ok1 && ok2 { + if lhs.InOperand || rhs.InOperand { + // if this other-cond is from a `[not] in` sub-query, do not convert it into eq-cond since + // IndexJoin cannot deal with NULL correctly in this case; please see #25799 for more details. + continue + } + outerSchema, innerSchema := p.Children()[outerIdx].Schema(), p.Children()[1-outerIdx].Schema() + if outerSchema.Contains(lhs) && innerSchema.Contains(rhs) { + outerHashKeys = append(outerHashKeys, lhs) // nozero + innerHashKeys = append(innerHashKeys, rhs) // nozero + } else if innerSchema.Contains(lhs) && outerSchema.Contains(rhs) { + outerHashKeys = append(outerHashKeys, rhs) // nozero + innerHashKeys = append(innerHashKeys, lhs) // nozero + } + newOtherConds = append(newOtherConds[:i], newOtherConds[i+1:]...) 
+ } + } + default: + continue + } + } + + baseJoin := basePhysicalJoin{ + InnerChildIdx: 1 - outerIdx, + LeftConditions: p.LeftConditions, + RightConditions: p.RightConditions, + OtherConditions: newOtherConds, + JoinType: joinType, + OuterJoinKeys: newOuterKeys, + InnerJoinKeys: newInnerKeys, + IsNullEQ: newIsNullEQ, + DefaultValues: p.DefaultValues, + } + + join := PhysicalIndexJoin{ + basePhysicalJoin: baseJoin, + innerPlan: innerTask.Plan(), + KeyOff2IdxOff: newKeyOff, + Ranges: ranges, + CompareFilters: compareFilters, + OuterHashKeys: outerHashKeys, + InnerHashKeys: innerHashKeys, + }.Init(p.SCtx(), p.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), p.QueryBlockOffset(), chReqProps...) + if path != nil { + join.IdxColLens = path.IdxColLens + } + join.SetSchema(p.Schema()) + return []base.PhysicalPlan{join} +} + +func constructIndexMergeJoin( + p *LogicalJoin, + prop *property.PhysicalProperty, + outerIdx int, + innerTask base.Task, + ranges ranger.MutableRanges, + keyOff2IdxOff []int, + path *util.AccessPath, + compareFilters *ColWithCmpFuncManager, +) []base.PhysicalPlan { + hintExists := false + if (outerIdx == 1 && (p.PreferJoinType&h.PreferLeftAsINLMJInner) > 0) || (outerIdx == 0 && (p.PreferJoinType&h.PreferRightAsINLMJInner) > 0) { + hintExists = true + } + indexJoins := constructIndexJoin(p, prop, outerIdx, innerTask, ranges, keyOff2IdxOff, path, compareFilters, !hintExists) + indexMergeJoins := make([]base.PhysicalPlan, 0, len(indexJoins)) + for _, plan := range indexJoins { + join := plan.(*PhysicalIndexJoin) + // Index merge join can't handle hash keys. So we ban it heuristically. + if len(join.InnerHashKeys) > len(join.InnerJoinKeys) { + return nil + } + + // EnumType/SetType Unsupported: merge join conflicts with index order. + // ref: https://github.com/pingcap/tidb/issues/24473, https://github.com/pingcap/tidb/issues/25669 + for _, innerKey := range join.InnerJoinKeys { + if innerKey.RetType.GetType() == mysql.TypeEnum || innerKey.RetType.GetType() == mysql.TypeSet { + return nil + } + } + for _, outerKey := range join.OuterJoinKeys { + if outerKey.RetType.GetType() == mysql.TypeEnum || outerKey.RetType.GetType() == mysql.TypeSet { + return nil + } + } + + hasPrefixCol := false + for _, l := range join.IdxColLens { + if l != types.UnspecifiedLength { + hasPrefixCol = true + break + } + } + // If index column has prefix length, the merge join can not guarantee the relevance + // between index and join keys. So we should skip this case. + // For more details, please check the following code and comments. + if hasPrefixCol { + continue + } + + // keyOff2KeyOffOrderByIdx is map the join keys offsets to [0, len(joinKeys)) ordered by the + // join key position in inner index. + keyOff2KeyOffOrderByIdx := make([]int, len(join.OuterJoinKeys)) + keyOffMapList := make([]int, len(join.KeyOff2IdxOff)) + copy(keyOffMapList, join.KeyOff2IdxOff) + keyOffMap := make(map[int]int, len(keyOffMapList)) + for i, idxOff := range keyOffMapList { + keyOffMap[idxOff] = i + } + slices.Sort(keyOffMapList) + keyIsIndexPrefix := true + for keyOff, idxOff := range keyOffMapList { + if keyOff != idxOff { + keyIsIndexPrefix = false + break + } + keyOff2KeyOffOrderByIdx[keyOffMap[idxOff]] = keyOff + } + if !keyIsIndexPrefix { + continue + } + // isOuterKeysPrefix means whether the outer join keys are the prefix of the prop items. 
+ isOuterKeysPrefix := len(join.OuterJoinKeys) <= len(prop.SortItems) + compareFuncs := make([]expression.CompareFunc, 0, len(join.OuterJoinKeys)) + outerCompareFuncs := make([]expression.CompareFunc, 0, len(join.OuterJoinKeys)) + + for i := range join.KeyOff2IdxOff { + if isOuterKeysPrefix && !prop.SortItems[i].Col.EqualColumn(join.OuterJoinKeys[keyOff2KeyOffOrderByIdx[i]]) { + isOuterKeysPrefix = false + } + compareFuncs = append(compareFuncs, expression.GetCmpFunction(p.SCtx().GetExprCtx(), join.OuterJoinKeys[i], join.InnerJoinKeys[i])) + outerCompareFuncs = append(outerCompareFuncs, expression.GetCmpFunction(p.SCtx().GetExprCtx(), join.OuterJoinKeys[i], join.OuterJoinKeys[i])) + } + // canKeepOuterOrder means whether the prop items are the prefix of the outer join keys. + canKeepOuterOrder := len(prop.SortItems) <= len(join.OuterJoinKeys) + for i := 0; canKeepOuterOrder && i < len(prop.SortItems); i++ { + if !prop.SortItems[i].Col.EqualColumn(join.OuterJoinKeys[keyOff2KeyOffOrderByIdx[i]]) { + canKeepOuterOrder = false + } + } + // Since index merge join requires prop items the prefix of outer join keys + // or outer join keys the prefix of the prop items. So we need `canKeepOuterOrder` or + // `isOuterKeysPrefix` to be true. + if canKeepOuterOrder || isOuterKeysPrefix { + indexMergeJoin := PhysicalIndexMergeJoin{ + PhysicalIndexJoin: *join, + KeyOff2KeyOffOrderByIdx: keyOff2KeyOffOrderByIdx, + NeedOuterSort: !isOuterKeysPrefix, + CompareFuncs: compareFuncs, + OuterCompareFuncs: outerCompareFuncs, + Desc: !prop.IsSortItemEmpty() && prop.SortItems[0].Desc, + }.Init(p.SCtx()) + indexMergeJoins = append(indexMergeJoins, indexMergeJoin) + } + } + return indexMergeJoins +} + +func constructIndexHashJoin( + p *LogicalJoin, + prop *property.PhysicalProperty, + outerIdx int, + innerTask base.Task, + ranges ranger.MutableRanges, + keyOff2IdxOff []int, + path *util.AccessPath, + compareFilters *ColWithCmpFuncManager, +) []base.PhysicalPlan { + indexJoins := constructIndexJoin(p, prop, outerIdx, innerTask, ranges, keyOff2IdxOff, path, compareFilters, true) + indexHashJoins := make([]base.PhysicalPlan, 0, len(indexJoins)) + for _, plan := range indexJoins { + join := plan.(*PhysicalIndexJoin) + indexHashJoin := PhysicalIndexHashJoin{ + PhysicalIndexJoin: *join, + // Prop is empty means that the parent operator does not need the + // join operator to provide any promise of the output order. + KeepOuterOrder: !prop.IsSortItemEmpty(), + }.Init(p.SCtx()) + indexHashJoins = append(indexHashJoins, indexHashJoin) + } + return indexHashJoins +} + +// getIndexJoinByOuterIdx will generate index join by outerIndex. OuterIdx points out the outer child. +// First of all, we'll check whether the inner child is DataSource. +// Then, we will extract the join keys of p's equal conditions. Then check whether all of them are just the primary key +// or match some part of on index. If so we will choose the best one and construct a index join. +func getIndexJoinByOuterIdx(p *LogicalJoin, prop *property.PhysicalProperty, outerIdx int) (joins []base.PhysicalPlan) { + outerChild, innerChild := p.Children()[outerIdx], p.Children()[1-outerIdx] + all, _ := prop.AllSameOrder() + // If the order by columns are not all from outer child, index join cannot promise the order. 
+ if !prop.AllColsFromSchema(outerChild.Schema()) || !all { + return nil + } + var ( + innerJoinKeys []*expression.Column + outerJoinKeys []*expression.Column + ) + if outerIdx == 0 { + outerJoinKeys, innerJoinKeys, _, _ = p.GetJoinKeys() + } else { + innerJoinKeys, outerJoinKeys, _, _ = p.GetJoinKeys() + } + innerChildWrapper := extractIndexJoinInnerChildPattern(p, innerChild) + if innerChildWrapper == nil { + return nil + } + + var avgInnerRowCnt float64 + if outerChild.StatsInfo().RowCount > 0 { + avgInnerRowCnt = p.EqualCondOutCnt / outerChild.StatsInfo().RowCount + } + joins = buildIndexJoinInner2TableScan(p, prop, innerChildWrapper, innerJoinKeys, outerJoinKeys, outerIdx, avgInnerRowCnt) + if joins != nil { + return + } + return buildIndexJoinInner2IndexScan(p, prop, innerChildWrapper, innerJoinKeys, outerJoinKeys, outerIdx, avgInnerRowCnt) +} + +// indexJoinInnerChildWrapper is a wrapper for the inner child of an index join. +// It contains the lowest DataSource operator and other inner child operator +// which is flattened into a list structure from tree structure . +// For example, the inner child of an index join is a tree structure like: +// +// Projection +// Aggregation +// Selection +// DataSource +// +// The inner child wrapper will be: +// DataSource: the lowest DataSource operator. +// hasDitryWrite: whether the inner child contains dirty data. +// zippedChildren: [Projection, Aggregation, Selection] +type indexJoinInnerChildWrapper struct { + ds *DataSource + hasDitryWrite bool + zippedChildren []base.LogicalPlan +} + +func extractIndexJoinInnerChildPattern(p *LogicalJoin, innerChild base.LogicalPlan) *indexJoinInnerChildWrapper { + wrapper := &indexJoinInnerChildWrapper{} + nextChild := func(pp base.LogicalPlan) base.LogicalPlan { + if len(pp.Children()) != 1 { + return nil + } + return pp.Children()[0] + } +childLoop: + for curChild := innerChild; curChild != nil; curChild = nextChild(curChild) { + switch child := curChild.(type) { + case *DataSource: + wrapper.ds = child + break childLoop + case *logicalop.LogicalProjection, *LogicalSelection, *LogicalAggregation: + if !p.SCtx().GetSessionVars().EnableINLJoinInnerMultiPattern { + return nil + } + wrapper.zippedChildren = append(wrapper.zippedChildren, child) + case *logicalop.LogicalUnionScan: + wrapper.hasDitryWrite = true + wrapper.zippedChildren = append(wrapper.zippedChildren, child) + default: + return nil + } + } + if wrapper.ds == nil || wrapper.ds.PreferStoreType&h.PreferTiFlash != 0 { + return nil + } + return wrapper +} + +// buildIndexJoinInner2TableScan builds a TableScan as the inner child for an +// IndexJoin if possible. +// If the inner side of a index join is a TableScan, only one tuple will be +// fetched from the inner side for every tuple from the outer side. This will be +// promised to be no worse than building IndexScan as the inner child. 
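extractIndexJoinInnerChildPattern above flattens a chain of single-child operators into zippedChildren and stops at the DataSource; constructInnerByZippedChildren later rebuilds those wrappers on top of the chosen physical reader. A generic sketch of that walk, using made-up types rather than the planner's:

package main

import "fmt"

// op is a stand-in for a single-child logical operator.
type op struct {
	name  string
	child *op // nil marks the leaf (the "DataSource")
}

// flatten walks the chain top-down, collecting the wrapper names and returning
// the leaf, mirroring how the inner-child wrapper records zippedChildren.
func flatten(root *op) (wrappers []string, leaf *op) {
	for cur := root; cur != nil; cur = cur.child {
		if cur.child == nil {
			return wrappers, cur
		}
		wrappers = append(wrappers, cur.name)
	}
	return wrappers, nil
}

func main() {
	tree := &op{name: "Projection", child: &op{name: "Aggregation",
		child: &op{name: "Selection", child: &op{name: "DataSource"}}}}
	wrappers, leaf := flatten(tree)
	fmt.Println(wrappers, leaf.name) // [Projection Aggregation Selection] DataSource
}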
+func buildIndexJoinInner2TableScan( + p *LogicalJoin, + prop *property.PhysicalProperty, wrapper *indexJoinInnerChildWrapper, + innerJoinKeys, outerJoinKeys []*expression.Column, + outerIdx int, avgInnerRowCnt float64) (joins []base.PhysicalPlan) { + ds := wrapper.ds + var tblPath *util.AccessPath + for _, path := range ds.PossibleAccessPaths { + if path.IsTablePath() && path.StoreType == kv.TiKV { + tblPath = path + break + } + } + if tblPath == nil { + return nil + } + keyOff2IdxOff := make([]int, len(innerJoinKeys)) + newOuterJoinKeys := make([]*expression.Column, 0) + var ranges ranger.MutableRanges = ranger.Ranges{} + var innerTask, innerTask2 base.Task + var indexJoinResult *indexJoinPathResult + if ds.TableInfo.IsCommonHandle { + indexJoinResult, keyOff2IdxOff = getBestIndexJoinPathResult(p, ds, innerJoinKeys, outerJoinKeys, func(path *util.AccessPath) bool { return path.IsCommonHandlePath }) + if indexJoinResult == nil { + return nil + } + rangeInfo := indexJoinPathRangeInfo(p.SCtx(), outerJoinKeys, indexJoinResult) + innerTask = constructInnerTableScanTask(p, prop, wrapper, indexJoinResult.chosenRanges.Range(), outerJoinKeys, rangeInfo, false, false, avgInnerRowCnt) + // The index merge join's inner plan is different from index join, so we + // should construct another inner plan for it. + // Because we can't keep order for union scan, if there is a union scan in inner task, + // we can't construct index merge join. + if !wrapper.hasDitryWrite { + innerTask2 = constructInnerTableScanTask(p, prop, wrapper, indexJoinResult.chosenRanges.Range(), outerJoinKeys, rangeInfo, true, !prop.IsSortItemEmpty() && prop.SortItems[0].Desc, avgInnerRowCnt) + } + ranges = indexJoinResult.chosenRanges + } else { + pkMatched := false + pkCol := ds.getPKIsHandleCol() + if pkCol == nil { + return nil + } + for i, key := range innerJoinKeys { + if !key.EqualColumn(pkCol) { + keyOff2IdxOff[i] = -1 + continue + } + pkMatched = true + keyOff2IdxOff[i] = 0 + // Add to newOuterJoinKeys only if conditions contain inner primary key. For issue #14822. + newOuterJoinKeys = append(newOuterJoinKeys, outerJoinKeys[i]) + } + outerJoinKeys = newOuterJoinKeys + if !pkMatched { + return nil + } + ranges := ranger.FullIntRange(mysql.HasUnsignedFlag(pkCol.RetType.GetFlag())) + var buffer strings.Builder + buffer.WriteString("[") + for i, key := range outerJoinKeys { + if i != 0 { + buffer.WriteString(" ") + } + buffer.WriteString(key.StringWithCtx(p.SCtx().GetExprCtx().GetEvalCtx(), errors.RedactLogDisable)) + } + buffer.WriteString("]") + rangeInfo := buffer.String() + innerTask = constructInnerTableScanTask(p, prop, wrapper, ranges, outerJoinKeys, rangeInfo, false, false, avgInnerRowCnt) + // The index merge join's inner plan is different from index join, so we + // should construct another inner plan for it. + // Because we can't keep order for union scan, if there is a union scan in inner task, + // we can't construct index merge join. 
+ if !wrapper.hasDitryWrite { + innerTask2 = constructInnerTableScanTask(p, prop, wrapper, ranges, outerJoinKeys, rangeInfo, true, !prop.IsSortItemEmpty() && prop.SortItems[0].Desc, avgInnerRowCnt) + } + } + var ( + path *util.AccessPath + lastColMng *ColWithCmpFuncManager + ) + if indexJoinResult != nil { + path = indexJoinResult.chosenPath + lastColMng = indexJoinResult.lastColManager + } + joins = make([]base.PhysicalPlan, 0, 3) + failpoint.Inject("MockOnlyEnableIndexHashJoin", func(val failpoint.Value) { + if val.(bool) && !p.SCtx().GetSessionVars().InRestrictedSQL { + failpoint.Return(constructIndexHashJoin(p, prop, outerIdx, innerTask, nil, keyOff2IdxOff, path, lastColMng)) + } + }) + joins = append(joins, constructIndexJoin(p, prop, outerIdx, innerTask, ranges, keyOff2IdxOff, path, lastColMng, true)...) + // We can reuse the `innerTask` here since index nested loop hash join + // do not need the inner child to promise the order. + joins = append(joins, constructIndexHashJoin(p, prop, outerIdx, innerTask, ranges, keyOff2IdxOff, path, lastColMng)...) + if innerTask2 != nil { + joins = append(joins, constructIndexMergeJoin(p, prop, outerIdx, innerTask2, ranges, keyOff2IdxOff, path, lastColMng)...) + } + return joins +} + +func buildIndexJoinInner2IndexScan( + p *LogicalJoin, + prop *property.PhysicalProperty, wrapper *indexJoinInnerChildWrapper, innerJoinKeys, outerJoinKeys []*expression.Column, + outerIdx int, avgInnerRowCnt float64) (joins []base.PhysicalPlan) { + ds := wrapper.ds + indexValid := func(path *util.AccessPath) bool { + if path.IsTablePath() { + return false + } + // if path is index path. index path currently include two kind of, one is normal, and the other is mv index. + // for mv index like mvi(a, json, b), if driving condition is a=1, and we build a prefix scan with range [1,1] + // on mvi, it will return many index rows which breaks handle-unique attribute here. + // + // the basic rule is that: mv index can be and can only be accessed by indexMerge operator. (embedded handle duplication) + if !isMVIndexPath(path) { + return true // not a MVIndex path, it can successfully be index join probe side. 
+ } + return false + } + indexJoinResult, keyOff2IdxOff := getBestIndexJoinPathResult(p, ds, innerJoinKeys, outerJoinKeys, indexValid) + if indexJoinResult == nil { + return nil + } + joins = make([]base.PhysicalPlan, 0, 3) + rangeInfo := indexJoinPathRangeInfo(p.SCtx(), outerJoinKeys, indexJoinResult) + maxOneRow := false + if indexJoinResult.chosenPath.Index.Unique && indexJoinResult.usedColsLen == len(indexJoinResult.chosenPath.FullIdxCols) { + l := len(indexJoinResult.chosenAccess) + if l == 0 { + maxOneRow = true + } else { + sf, ok := indexJoinResult.chosenAccess[l-1].(*expression.ScalarFunction) + maxOneRow = ok && (sf.FuncName.L == ast.EQ) + } + } + innerTask := constructInnerIndexScanTask(p, prop, wrapper, indexJoinResult.chosenPath, indexJoinResult.chosenRanges.Range(), indexJoinResult.chosenRemained, innerJoinKeys, indexJoinResult.idxOff2KeyOff, rangeInfo, false, false, avgInnerRowCnt, maxOneRow) + failpoint.Inject("MockOnlyEnableIndexHashJoin", func(val failpoint.Value) { + if val.(bool) && !p.SCtx().GetSessionVars().InRestrictedSQL && innerTask != nil { + failpoint.Return(constructIndexHashJoin(p, prop, outerIdx, innerTask, indexJoinResult.chosenRanges, keyOff2IdxOff, indexJoinResult.chosenPath, indexJoinResult.lastColManager)) + } + }) + if innerTask != nil { + joins = append(joins, constructIndexJoin(p, prop, outerIdx, innerTask, indexJoinResult.chosenRanges, keyOff2IdxOff, indexJoinResult.chosenPath, indexJoinResult.lastColManager, true)...) + // We can reuse the `innerTask` here since index nested loop hash join + // do not need the inner child to promise the order. + joins = append(joins, constructIndexHashJoin(p, prop, outerIdx, innerTask, indexJoinResult.chosenRanges, keyOff2IdxOff, indexJoinResult.chosenPath, indexJoinResult.lastColManager)...) + } + // The index merge join's inner plan is different from index join, so we + // should construct another inner plan for it. + // Because we can't keep order for union scan, if there is a union scan in inner task, + // we can't construct index merge join. + if !wrapper.hasDitryWrite { + innerTask2 := constructInnerIndexScanTask(p, prop, wrapper, indexJoinResult.chosenPath, indexJoinResult.chosenRanges.Range(), indexJoinResult.chosenRemained, innerJoinKeys, indexJoinResult.idxOff2KeyOff, rangeInfo, true, !prop.IsSortItemEmpty() && prop.SortItems[0].Desc, avgInnerRowCnt, maxOneRow) + if innerTask2 != nil { + joins = append(joins, constructIndexMergeJoin(p, prop, outerIdx, innerTask2, indexJoinResult.chosenRanges, keyOff2IdxOff, indexJoinResult.chosenPath, indexJoinResult.lastColManager)...) + } + } + return joins +} + +// constructInnerTableScanTask is specially used to construct the inner plan for PhysicalIndexJoin. +func constructInnerTableScanTask( + p *LogicalJoin, + prop *property.PhysicalProperty, + wrapper *indexJoinInnerChildWrapper, + ranges ranger.Ranges, + _ []*expression.Column, + rangeInfo string, + keepOrder bool, + desc bool, + rowCount float64, +) base.Task { + ds := wrapper.ds + // If `ds.TableInfo.GetPartitionInfo() != nil`, + // it means the data source is a partition table reader. + // If the inner task need to keep order, the partition table reader can't satisfy it. 
+ if keepOrder && ds.TableInfo.GetPartitionInfo() != nil { + return nil + } + ts := PhysicalTableScan{ + Table: ds.TableInfo, + Columns: ds.Columns, + TableAsName: ds.TableAsName, + DBName: ds.DBName, + filterCondition: ds.PushedDownConds, + Ranges: ranges, + rangeInfo: rangeInfo, + KeepOrder: keepOrder, + Desc: desc, + physicalTableID: ds.PhysicalTableID, + isPartition: ds.PartitionDefIdx != nil, + tblCols: ds.TblCols, + tblColHists: ds.TblColHists, + }.Init(ds.SCtx(), ds.QueryBlockOffset()) + ts.SetSchema(ds.Schema().Clone()) + if rowCount <= 0 { + rowCount = float64(1) + } + selectivity := float64(1) + countAfterAccess := rowCount + if len(ts.filterCondition) > 0 { + var err error + selectivity, _, err = cardinality.Selectivity(ds.SCtx(), ds.TableStats.HistColl, ts.filterCondition, ds.PossibleAccessPaths) + if err != nil || selectivity <= 0 { + logutil.BgLogger().Debug("unexpected selectivity, use selection factor", zap.Float64("selectivity", selectivity), zap.String("table", ts.TableAsName.L)) + selectivity = cost.SelectionFactor + } + // rowCount is computed from result row count of join, which has already accounted the filters on DataSource, + // i.e, rowCount equals to `countAfterAccess * selectivity`. + countAfterAccess = rowCount / selectivity + } + ts.SetStats(&property.StatsInfo{ + // TableScan as inner child of IndexJoin can return at most 1 tuple for each outer row. + RowCount: math.Min(1.0, countAfterAccess), + StatsVersion: ds.StatsInfo().StatsVersion, + // NDV would not be used in cost computation of IndexJoin, set leave it as default nil. + }) + usedStats := p.SCtx().GetSessionVars().StmtCtx.GetUsedStatsInfo(false) + if usedStats != nil && usedStats.GetUsedInfo(ts.physicalTableID) != nil { + ts.usedStatsInfo = usedStats.GetUsedInfo(ts.physicalTableID) + } + copTask := &CopTask{ + tablePlan: ts, + indexPlanFinished: true, + tblColHists: ds.TblColHists, + keepOrder: ts.KeepOrder, + } + copTask.physPlanPartInfo = &PhysPlanPartInfo{ + PruningConds: ds.AllConds, + PartitionNames: ds.PartitionNames, + Columns: ds.TblCols, + ColumnNames: ds.OutputNames(), + } + ts.PlanPartInfo = copTask.physPlanPartInfo + selStats := ts.StatsInfo().Scale(selectivity) + ts.addPushedDownSelection(copTask, selStats) + return constructIndexJoinInnerSideTask(p, prop, copTask, ds, nil, wrapper) +} + +func constructInnerByZippedChildren(prop *property.PhysicalProperty, zippedChildren []base.LogicalPlan, child base.PhysicalPlan) base.PhysicalPlan { + for i := len(zippedChildren) - 1; i >= 0; i-- { + switch x := zippedChildren[i].(type) { + case *logicalop.LogicalUnionScan: + child = constructInnerUnionScan(prop, x, child) + case *logicalop.LogicalProjection: + child = constructInnerProj(prop, x, child) + case *LogicalSelection: + child = constructInnerSel(prop, x, child) + case *LogicalAggregation: + child = constructInnerAgg(prop, x, child) + } + } + return child +} + +func constructInnerAgg(prop *property.PhysicalProperty, logicalAgg *LogicalAggregation, child base.PhysicalPlan) base.PhysicalPlan { + if logicalAgg == nil { + return child + } + physicalHashAgg := NewPhysicalHashAgg(logicalAgg, logicalAgg.StatsInfo(), prop) + physicalHashAgg.SetSchema(logicalAgg.Schema().Clone()) + physicalHashAgg.SetChildren(child) + return physicalHashAgg +} + +func constructInnerSel(prop *property.PhysicalProperty, sel *LogicalSelection, child base.PhysicalPlan) base.PhysicalPlan { + if sel == nil { + return child + } + physicalSel := PhysicalSelection{ + Conditions: sel.Conditions, + }.Init(sel.SCtx(), 
sel.StatsInfo(), sel.QueryBlockOffset(), prop) + physicalSel.SetChildren(child) + return physicalSel +} + +func constructInnerProj(prop *property.PhysicalProperty, proj *logicalop.LogicalProjection, child base.PhysicalPlan) base.PhysicalPlan { + if proj == nil { + return child + } + physicalProj := PhysicalProjection{ + Exprs: proj.Exprs, + CalculateNoDelay: proj.CalculateNoDelay, + AvoidColumnEvaluator: proj.AvoidColumnEvaluator, + }.Init(proj.SCtx(), proj.StatsInfo(), proj.QueryBlockOffset(), prop) + physicalProj.SetChildren(child) + physicalProj.SetSchema(proj.Schema()) + return physicalProj +} + +func constructInnerUnionScan(prop *property.PhysicalProperty, us *logicalop.LogicalUnionScan, reader base.PhysicalPlan) base.PhysicalPlan { + if us == nil { + return reader + } + // Use `reader.StatsInfo()` instead of `us.StatsInfo()` because it should be more accurate. No need to specify + // childrenReqProps now since we have got reader already. + physicalUnionScan := PhysicalUnionScan{ + Conditions: us.Conditions, + HandleCols: us.HandleCols, + }.Init(us.SCtx(), reader.StatsInfo(), us.QueryBlockOffset(), prop) + physicalUnionScan.SetChildren(reader) + return physicalUnionScan +} + +// getColsNDVLowerBoundFromHistColl tries to get a lower bound of the NDV of columns (whose uniqueIDs are colUIDs). +func getColsNDVLowerBoundFromHistColl(colUIDs []int64, histColl *statistics.HistColl) int64 { + if len(colUIDs) == 0 || histColl == nil { + return -1 + } + + // 1. Try to get NDV from column stats if it's a single column. + if len(colUIDs) == 1 && histColl.ColNum() > 0 { + uid := colUIDs[0] + if colStats := histColl.GetCol(uid); colStats != nil && colStats.IsStatsInitialized() { + return colStats.NDV + } + } + + slices.Sort(colUIDs) + + // 2. Try to get NDV from index stats. + // Note that we don't need to specially handle prefix index here, because the NDV of a prefix index is + // equal or less than the corresponding normal index, and that's safe here since we want a lower bound. + for idxID, idxCols := range histColl.Idx2ColUniqueIDs { + if len(idxCols) != len(colUIDs) { + continue + } + orderedIdxCols := make([]int64, len(idxCols)) + copy(orderedIdxCols, idxCols) + slices.Sort(orderedIdxCols) + if !slices.Equal(orderedIdxCols, colUIDs) { + continue + } + if idxStats := histColl.GetIdx(idxID); idxStats != nil && idxStats.IsStatsInitialized() { + return idxStats.NDV + } + } + + // TODO: if there's an index that contains the expected columns, we can also make use of its NDV. + // For example, NDV(a,b,c) / NDV(c) is a safe lower bound of NDV(a,b). + + // 3. If we still haven't got an NDV, we use the maximum NDV in the column stats as a lower bound. + maxNDV := int64(-1) + for _, uid := range colUIDs { + colStats := histColl.GetCol(uid) + if colStats == nil || !colStats.IsStatsInitialized() { + continue + } + maxNDV = max(maxNDV, colStats.NDV) + } + return maxNDV +} + +// constructInnerIndexScanTask is specially used to construct the inner plan for PhysicalIndexJoin. +func constructInnerIndexScanTask( + p *LogicalJoin, + prop *property.PhysicalProperty, + wrapper *indexJoinInnerChildWrapper, + path *util.AccessPath, + ranges ranger.Ranges, + filterConds []expression.Expression, + _ []*expression.Column, + idxOffset2joinKeyOffset []int, + rangeInfo string, + keepOrder bool, + desc bool, + rowCount float64, + maxOneRow bool, +) base.Task { + ds := wrapper.ds + // If `ds.TableInfo.GetPartitionInfo() != nil`, + // it means the data source is a partition table reader. 
+ // If the inner task need to keep order, the partition table reader can't satisfy it. + if keepOrder && ds.TableInfo.GetPartitionInfo() != nil { + return nil + } + is := PhysicalIndexScan{ + Table: ds.TableInfo, + TableAsName: ds.TableAsName, + DBName: ds.DBName, + Columns: ds.Columns, + Index: path.Index, + IdxCols: path.IdxCols, + IdxColLens: path.IdxColLens, + dataSourceSchema: ds.Schema(), + KeepOrder: keepOrder, + Ranges: ranges, + rangeInfo: rangeInfo, + Desc: desc, + isPartition: ds.PartitionDefIdx != nil, + physicalTableID: ds.PhysicalTableID, + tblColHists: ds.TblColHists, + pkIsHandleCol: ds.getPKIsHandleCol(), + }.Init(ds.SCtx(), ds.QueryBlockOffset()) + cop := &CopTask{ + indexPlan: is, + tblColHists: ds.TblColHists, + tblCols: ds.TblCols, + keepOrder: is.KeepOrder, + } + cop.physPlanPartInfo = &PhysPlanPartInfo{ + PruningConds: ds.AllConds, + PartitionNames: ds.PartitionNames, + Columns: ds.TblCols, + ColumnNames: ds.OutputNames(), + } + if !path.IsSingleScan { + // On this way, it's double read case. + ts := PhysicalTableScan{ + Columns: ds.Columns, + Table: is.Table, + TableAsName: ds.TableAsName, + DBName: ds.DBName, + isPartition: ds.PartitionDefIdx != nil, + physicalTableID: ds.PhysicalTableID, + tblCols: ds.TblCols, + tblColHists: ds.TblColHists, + }.Init(ds.SCtx(), ds.QueryBlockOffset()) + ts.schema = is.dataSourceSchema.Clone() + if ds.TableInfo.IsCommonHandle { + commonHandle := ds.HandleCols.(*util.CommonHandleCols) + for _, col := range commonHandle.GetColumns() { + if ts.schema.ColumnIndex(col) == -1 { + ts.Schema().Append(col) + ts.Columns = append(ts.Columns, col.ToInfo()) + cop.needExtraProj = true + } + } + } + // We set `StatsVersion` here and fill other fields in `(*copTask).finishIndexPlan`. Since `copTask.indexPlan` may + // change before calling `(*copTask).finishIndexPlan`, we don't know the stats information of `ts` currently and on + // the other hand, it may be hard to identify `StatsVersion` of `ts` in `(*copTask).finishIndexPlan`. + ts.SetStats(&property.StatsInfo{StatsVersion: ds.TableStats.StatsVersion}) + usedStats := p.SCtx().GetSessionVars().StmtCtx.GetUsedStatsInfo(false) + if usedStats != nil && usedStats.GetUsedInfo(ts.physicalTableID) != nil { + ts.usedStatsInfo = usedStats.GetUsedInfo(ts.physicalTableID) + } + // If inner cop task need keep order, the extraHandleCol should be set. + if cop.keepOrder && !ds.TableInfo.IsCommonHandle { + var needExtraProj bool + cop.extraHandleCol, needExtraProj = ts.appendExtraHandleCol(ds) + cop.needExtraProj = cop.needExtraProj || needExtraProj + } + if cop.needExtraProj { + cop.originSchema = ds.Schema() + } + cop.tablePlan = ts + } + if cop.tablePlan != nil && ds.TableInfo.IsCommonHandle { + cop.commonHandleCols = ds.CommonHandleCols + } + is.initSchema(append(path.FullIdxCols, ds.CommonHandleCols...), cop.tablePlan != nil) + indexConds, tblConds := ds.splitIndexFilterConditions(filterConds, path.FullIdxCols, path.FullIdxColLens) + + // Note: due to a regression in JOB workload, we use the optimizer fix control to enable this for now. + // + // Because we are estimating an average row count of the inner side corresponding to each row from the outer side, + // the estimated row count of the IndexScan should be no larger than (total row count / NDV of join key columns). + // We can calculate the lower bound of the NDV therefore we can get an upper bound of the row count here. 
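+	// Illustrative numbers (not from this patch): with 1,000,000 table rows and a join-key NDV
+	// lower bound of 50,000, each outer row matches at most 1,000,000 / 50,000 = 20 inner rows
+	// on average, so the estimate computed below is capped at 20.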
+ rowCountUpperBound := -1.0 + fixControlOK := fixcontrol.GetBoolWithDefault(ds.SCtx().GetSessionVars().GetOptimizerFixControlMap(), fixcontrol.Fix44855, false) + if fixControlOK && ds.TableStats != nil { + usedColIDs := make([]int64, 0) + // We only consider columns in this index that (1) are used to probe as join key, + // and (2) are not prefix column in the index (for which we can't easily get a lower bound) + for idxOffset, joinKeyOffset := range idxOffset2joinKeyOffset { + if joinKeyOffset < 0 || + path.FullIdxColLens[idxOffset] != types.UnspecifiedLength || + path.FullIdxCols[idxOffset] == nil { + continue + } + usedColIDs = append(usedColIDs, path.FullIdxCols[idxOffset].UniqueID) + } + joinKeyNDV := getColsNDVLowerBoundFromHistColl(usedColIDs, ds.TableStats.HistColl) + if joinKeyNDV > 0 { + rowCountUpperBound = ds.TableStats.RowCount / float64(joinKeyNDV) + } + } + + if rowCountUpperBound > 0 { + rowCount = math.Min(rowCount, rowCountUpperBound) + } + if maxOneRow { + // Theoretically, this line is unnecessary because row count estimation of join should guarantee rowCount is not larger + // than 1.0; however, there may be rowCount larger than 1.0 in reality, e.g, pseudo statistics cases, which does not reflect + // unique constraint in NDV. + rowCount = math.Min(rowCount, 1.0) + } + tmpPath := &util.AccessPath{ + IndexFilters: indexConds, + TableFilters: tblConds, + CountAfterIndex: rowCount, + CountAfterAccess: rowCount, + } + // Assume equal conditions used by index join and other conditions are independent. + if len(tblConds) > 0 { + selectivity, _, err := cardinality.Selectivity(ds.SCtx(), ds.TableStats.HistColl, tblConds, ds.PossibleAccessPaths) + if err != nil || selectivity <= 0 { + logutil.BgLogger().Debug("unexpected selectivity, use selection factor", zap.Float64("selectivity", selectivity), zap.String("table", ds.TableAsName.L)) + selectivity = cost.SelectionFactor + } + // rowCount is computed from result row count of join, which has already accounted the filters on DataSource, + // i.e, rowCount equals to `countAfterIndex * selectivity`. 
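+		// Illustrative numbers (not from this patch): if rowCount is 30 and the table filters have
+		// selectivity 0.3, dividing it back out gives cnt = 30 / 0.3 = 100 expected rows before
+		// those filters are applied.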
+ cnt := rowCount / selectivity + if rowCountUpperBound > 0 { + cnt = math.Min(cnt, rowCountUpperBound) + } + if maxOneRow { + cnt = math.Min(cnt, 1.0) + } + tmpPath.CountAfterIndex = cnt + tmpPath.CountAfterAccess = cnt + } + if len(indexConds) > 0 { + selectivity, _, err := cardinality.Selectivity(ds.SCtx(), ds.TableStats.HistColl, indexConds, ds.PossibleAccessPaths) + if err != nil || selectivity <= 0 { + logutil.BgLogger().Debug("unexpected selectivity, use selection factor", zap.Float64("selectivity", selectivity), zap.String("table", ds.TableAsName.L)) + selectivity = cost.SelectionFactor + } + cnt := tmpPath.CountAfterIndex / selectivity + if rowCountUpperBound > 0 { + cnt = math.Min(cnt, rowCountUpperBound) + } + if maxOneRow { + cnt = math.Min(cnt, 1.0) + } + tmpPath.CountAfterAccess = cnt + } + is.SetStats(ds.TableStats.ScaleByExpectCnt(tmpPath.CountAfterAccess)) + usedStats := ds.SCtx().GetSessionVars().StmtCtx.GetUsedStatsInfo(false) + if usedStats != nil && usedStats.GetUsedInfo(is.physicalTableID) != nil { + is.usedStatsInfo = usedStats.GetUsedInfo(is.physicalTableID) + } + finalStats := ds.TableStats.ScaleByExpectCnt(rowCount) + if err := is.addPushedDownSelection(cop, ds, tmpPath, finalStats); err != nil { + logutil.BgLogger().Warn("unexpected error happened during addPushedDownSelection function", zap.Error(err)) + return nil + } + return constructIndexJoinInnerSideTask(p, prop, cop, ds, path, wrapper) +} + +// construct the inner join task by inner child plan tree +// The Logical include two parts: logicalplan->physicalplan, physicalplan->task +// Step1: whether agg can be pushed down to coprocessor +// +// Step1.1: If the agg can be pushded down to coprocessor, we will build a copTask and attach the agg to the copTask +// There are two kinds of agg: stream agg and hash agg. Stream agg depends on some conditions, such as the group by cols +// +// Step2: build other inner plan node to task +func constructIndexJoinInnerSideTask(p *LogicalJoin, prop *property.PhysicalProperty, dsCopTask *CopTask, ds *DataSource, path *util.AccessPath, wrapper *indexJoinInnerChildWrapper) base.Task { + var la *LogicalAggregation + var canPushAggToCop bool + if len(wrapper.zippedChildren) > 0 { + la, canPushAggToCop = wrapper.zippedChildren[len(wrapper.zippedChildren)-1].(*LogicalAggregation) + if la != nil && la.HasDistinct() { + // TODO: remove AllowDistinctAggPushDown after the cost estimation of distinct pushdown is implemented. + // If AllowDistinctAggPushDown is set to true, we should not consider RootTask. + if !la.SCtx().GetSessionVars().AllowDistinctAggPushDown { + canPushAggToCop = false + } + } + } + + // If the bottom plan is not aggregation or the aggregation can't be pushed to coprocessor, we will construct a root task directly. + if !canPushAggToCop { + result := dsCopTask.ConvertToRootTask(ds.SCtx()).(*RootTask) + result.SetPlan(constructInnerByZippedChildren(prop, wrapper.zippedChildren, result.GetPlan())) + return result + } + + // Try stream aggregation first. + // We will choose the stream aggregation if the following conditions are met: + // 1. Force hint stream agg by /*+ stream_agg() */ + // 2. Other conditions copy from getStreamAggs() in exhaust_physical_plans.go + _, preferStream := la.ResetHintIfConflicted() + for _, aggFunc := range la.AggFuncs { + if aggFunc.Mode == aggregation.FinalMode { + preferStream = false + break + } + } + // group by a + b is not interested in any order. 
+ groupByCols := la.GetGroupByCols() + if len(groupByCols) != len(la.GroupByItems) { + preferStream = false + } + if la.HasDistinct() && !la.DistinctArgsMeetsProperty() { + preferStream = false + } + // sort items must be the super set of group by items + if path != nil && path.Index != nil && !path.Index.MVIndex && + ds.TableInfo.GetPartitionInfo() == nil { + if len(path.IdxCols) < len(groupByCols) { + preferStream = false + } else { + sctx := p.SCtx() + for i, groupbyCol := range groupByCols { + if path.IdxColLens[i] != types.UnspecifiedLength || + !groupbyCol.EqualByExprAndID(sctx.GetExprCtx().GetEvalCtx(), path.IdxCols[i]) { + preferStream = false + } + } + } + } else { + preferStream = false + } + + // build physical agg and attach to task + var aggTask base.Task + // build stream agg and change ds keep order to true + if preferStream { + newGbyItems := make([]expression.Expression, len(la.GroupByItems)) + copy(newGbyItems, la.GroupByItems) + newAggFuncs := make([]*aggregation.AggFuncDesc, len(la.AggFuncs)) + copy(newAggFuncs, la.AggFuncs) + streamAgg := basePhysicalAgg{ + GroupByItems: newGbyItems, + AggFuncs: newAggFuncs, + }.initForStream(la.SCtx(), la.StatsInfo(), la.QueryBlockOffset(), prop) + streamAgg.SetSchema(la.Schema().Clone()) + // change to keep order for index scan and dsCopTask + if dsCopTask.indexPlan != nil { + // get the index scan from dsCopTask.indexPlan + physicalIndexScan, _ := dsCopTask.indexPlan.(*PhysicalIndexScan) + if physicalIndexScan == nil && len(dsCopTask.indexPlan.Children()) == 1 { + physicalIndexScan, _ = dsCopTask.indexPlan.Children()[0].(*PhysicalIndexScan) + } + if physicalIndexScan != nil { + physicalIndexScan.KeepOrder = true + dsCopTask.keepOrder = true + aggTask = streamAgg.Attach2Task(dsCopTask) + } + } + } + + // build hash agg, when the stream agg is illegal such as the order by prop is not matched + if aggTask == nil { + physicalHashAgg := NewPhysicalHashAgg(la, la.StatsInfo(), prop) + physicalHashAgg.SetSchema(la.Schema().Clone()) + aggTask = physicalHashAgg.Attach2Task(dsCopTask) + } + + // build other inner plan node to task + result, ok := aggTask.(*RootTask) + if !ok { + return nil + } + result.SetPlan(constructInnerByZippedChildren(prop, wrapper.zippedChildren[0:len(wrapper.zippedChildren)-1], result.p)) + return result +} + +func filterIndexJoinBySessionVars(sc base.PlanContext, indexJoins []base.PhysicalPlan) []base.PhysicalPlan { + if sc.GetSessionVars().EnableIndexMergeJoin { + return indexJoins + } + for i := len(indexJoins) - 1; i >= 0; i-- { + if _, ok := indexJoins[i].(*PhysicalIndexMergeJoin); ok { + indexJoins = append(indexJoins[:i], indexJoins[i+1:]...) + } + } + return indexJoins +} + +const ( + joinLeft = 0 + joinRight = 1 + indexJoinMethod = 0 + indexHashJoinMethod = 1 + indexMergeJoinMethod = 2 +) + +func getIndexJoinSideAndMethod(join base.PhysicalPlan) (innerSide, joinMethod int, ok bool) { + var innerIdx int + switch ij := join.(type) { + case *PhysicalIndexJoin: + innerIdx = ij.getInnerChildIdx() + joinMethod = indexJoinMethod + case *PhysicalIndexHashJoin: + innerIdx = ij.getInnerChildIdx() + joinMethod = indexHashJoinMethod + case *PhysicalIndexMergeJoin: + innerIdx = ij.getInnerChildIdx() + joinMethod = indexMergeJoinMethod + default: + return 0, 0, false + } + ok = true + innerSide = joinLeft + if innerIdx == 1 { + innerSide = joinRight + } + return +} + +// tryToGetIndexJoin returns all available index join plans, and the second returned value indicates whether this plan is enforced by hints. 
+func tryToGetIndexJoin(p *LogicalJoin, prop *property.PhysicalProperty) (indexJoins []base.PhysicalPlan, canForced bool) { + // supportLeftOuter and supportRightOuter indicates whether this type of join + // supports the left side or right side to be the outer side. + var supportLeftOuter, supportRightOuter bool + switch p.JoinType { + case SemiJoin, AntiSemiJoin, LeftOuterSemiJoin, AntiLeftOuterSemiJoin, LeftOuterJoin: + supportLeftOuter = true + case RightOuterJoin: + supportRightOuter = true + case InnerJoin: + supportLeftOuter, supportRightOuter = true, true + } + candidates := make([]base.PhysicalPlan, 0, 2) + if supportLeftOuter { + candidates = append(candidates, getIndexJoinByOuterIdx(p, prop, 0)...) + } + if supportRightOuter { + candidates = append(candidates, getIndexJoinByOuterIdx(p, prop, 1)...) + } + + // Handle hints and variables about index join. + // The priority is: force hints like TIDB_INLJ > filter hints like NO_INDEX_JOIN > variables. + // Handle hints conflict first. + stmtCtx := p.SCtx().GetSessionVars().StmtCtx + if p.PreferAny(h.PreferLeftAsINLJInner, h.PreferRightAsINLJInner) && p.PreferAny(h.PreferNoIndexJoin) { + stmtCtx.SetHintWarning("Some INL_JOIN and NO_INDEX_JOIN hints conflict, NO_INDEX_JOIN may be ignored") + } + if p.PreferAny(h.PreferLeftAsINLHJInner, h.PreferRightAsINLHJInner) && p.PreferAny(h.PreferNoIndexHashJoin) { + stmtCtx.SetHintWarning("Some INL_HASH_JOIN and NO_INDEX_HASH_JOIN hints conflict, NO_INDEX_HASH_JOIN may be ignored") + } + if p.PreferAny(h.PreferLeftAsINLMJInner, h.PreferRightAsINLMJInner) && p.PreferAny(h.PreferNoIndexMergeJoin) { + stmtCtx.SetHintWarning("Some INL_MERGE_JOIN and NO_INDEX_MERGE_JOIN hints conflict, NO_INDEX_MERGE_JOIN may be ignored") + } + + candidates, canForced = handleForceIndexJoinHints(p, prop, candidates) + if canForced { + return candidates, canForced + } + candidates = handleFilterIndexJoinHints(p, candidates) + return filterIndexJoinBySessionVars(p.SCtx(), candidates), false +} + +func handleFilterIndexJoinHints(p *LogicalJoin, candidates []base.PhysicalPlan) []base.PhysicalPlan { + if !p.PreferAny(h.PreferNoIndexJoin, h.PreferNoIndexHashJoin, h.PreferNoIndexMergeJoin) { + return candidates // no filter index join hints + } + filtered := make([]base.PhysicalPlan, 0, len(candidates)) + for _, candidate := range candidates { + _, joinMethod, ok := getIndexJoinSideAndMethod(candidate) + if !ok { + continue + } + if (p.PreferAny(h.PreferNoIndexJoin) && joinMethod == indexJoinMethod) || + (p.PreferAny(h.PreferNoIndexHashJoin) && joinMethod == indexHashJoinMethod) || + (p.PreferAny(h.PreferNoIndexMergeJoin) && joinMethod == indexMergeJoinMethod) { + continue + } + filtered = append(filtered, candidate) + } + return filtered +} + +// handleForceIndexJoinHints handles the force index join hints and returns all plans that can satisfy the hints. 
+func handleForceIndexJoinHints(p *LogicalJoin, prop *property.PhysicalProperty, candidates []base.PhysicalPlan) (indexJoins []base.PhysicalPlan, canForced bool) { + if !p.PreferAny(h.PreferRightAsINLJInner, h.PreferRightAsINLHJInner, h.PreferRightAsINLMJInner, + h.PreferLeftAsINLJInner, h.PreferLeftAsINLHJInner, h.PreferLeftAsINLMJInner) { + return candidates, false // no force index join hints + } + forced := make([]base.PhysicalPlan, 0, len(candidates)) + for _, candidate := range candidates { + innerSide, joinMethod, ok := getIndexJoinSideAndMethod(candidate) + if !ok { + continue + } + if (p.PreferAny(h.PreferLeftAsINLJInner) && innerSide == joinLeft && joinMethod == indexJoinMethod) || + (p.PreferAny(h.PreferRightAsINLJInner) && innerSide == joinRight && joinMethod == indexJoinMethod) || + (p.PreferAny(h.PreferLeftAsINLHJInner) && innerSide == joinLeft && joinMethod == indexHashJoinMethod) || + (p.PreferAny(h.PreferRightAsINLHJInner) && innerSide == joinRight && joinMethod == indexHashJoinMethod) || + (p.PreferAny(h.PreferLeftAsINLMJInner) && innerSide == joinLeft && joinMethod == indexMergeJoinMethod) || + (p.PreferAny(h.PreferRightAsINLMJInner) && innerSide == joinRight && joinMethod == indexMergeJoinMethod) { + forced = append(forced, candidate) + } + } + + if len(forced) > 0 { + return forced, true + } + // Cannot find any valid index join plan with these force hints. + // Print warning message if any hints cannot work. + // If the required property is not empty, we will enforce it and try the hint again. + // So we only need to generate warning message when the property is empty. + if prop.IsSortItemEmpty() { + var indexJoinTables, indexHashJoinTables, indexMergeJoinTables []h.HintedTable + if p.HintInfo != nil { + t := p.HintInfo.IndexJoin + indexJoinTables, indexHashJoinTables, indexMergeJoinTables = t.INLJTables, t.INLHJTables, t.INLMJTables + } + var errMsg string + switch { + case p.PreferAny(h.PreferLeftAsINLJInner, h.PreferRightAsINLJInner): // prefer index join + errMsg = fmt.Sprintf("Optimizer Hint %s or %s is inapplicable", h.Restore2JoinHint(h.HintINLJ, indexJoinTables), h.Restore2JoinHint(h.TiDBIndexNestedLoopJoin, indexJoinTables)) + case p.PreferAny(h.PreferLeftAsINLHJInner, h.PreferRightAsINLHJInner): // prefer index hash join + errMsg = fmt.Sprintf("Optimizer Hint %s is inapplicable", h.Restore2JoinHint(h.HintINLHJ, indexHashJoinTables)) + case p.PreferAny(h.PreferLeftAsINLMJInner, h.PreferRightAsINLMJInner): // prefer index merge join + errMsg = fmt.Sprintf("Optimizer Hint %s is inapplicable", h.Restore2JoinHint(h.HintINLMJ, indexMergeJoinTables)) + } + // Append inapplicable reason. + if len(p.EqualConditions) == 0 { + errMsg += " without column equal ON condition" + } + // Generate warning message to client. 
+ p.SCtx().GetSessionVars().StmtCtx.SetHintWarning(errMsg) + } + return candidates, false +} + +func checkChildFitBC(p base.Plan) bool { + if p.StatsInfo().HistColl == nil { + return p.SCtx().GetSessionVars().BroadcastJoinThresholdCount == -1 || p.StatsInfo().Count() < p.SCtx().GetSessionVars().BroadcastJoinThresholdCount + } + avg := cardinality.GetAvgRowSize(p.SCtx(), p.StatsInfo().HistColl, p.Schema().Columns, false, false) + sz := avg * float64(p.StatsInfo().Count()) + return p.SCtx().GetSessionVars().BroadcastJoinThresholdSize == -1 || sz < float64(p.SCtx().GetSessionVars().BroadcastJoinThresholdSize) +} + +func calcBroadcastExchangeSize(p base.Plan, mppStoreCnt int) (row float64, size float64, hasSize bool) { + s := p.StatsInfo() + row = float64(s.Count()) * float64(mppStoreCnt-1) + if s.HistColl == nil { + return row, 0, false + } + avg := cardinality.GetAvgRowSize(p.SCtx(), s.HistColl, p.Schema().Columns, false, false) + size = avg * row + return row, size, true +} + +func calcBroadcastExchangeSizeByChild(p1 base.Plan, p2 base.Plan, mppStoreCnt int) (row float64, size float64, hasSize bool) { + row1, size1, hasSize1 := calcBroadcastExchangeSize(p1, mppStoreCnt) + row2, size2, hasSize2 := calcBroadcastExchangeSize(p2, mppStoreCnt) + + // broadcast exchange size: + // Build: (mppStoreCnt - 1) * sizeof(BuildTable) + // Probe: 0 + // choose the child plan with the maximum approximate value as Probe + + if hasSize1 && hasSize2 { + return math.Min(row1, row2), math.Min(size1, size2), true + } + + return math.Min(row1, row2), 0, false +} + +func calcHashExchangeSize(p base.Plan, mppStoreCnt int) (row float64, sz float64, hasSize bool) { + s := p.StatsInfo() + row = float64(s.Count()) * float64(mppStoreCnt-1) / float64(mppStoreCnt) + if s.HistColl == nil { + return row, 0, false + } + avg := cardinality.GetAvgRowSize(p.SCtx(), s.HistColl, p.Schema().Columns, false, false) + sz = avg * row + return row, sz, true +} + +func calcHashExchangeSizeByChild(p1 base.Plan, p2 base.Plan, mppStoreCnt int) (row float64, size float64, hasSize bool) { + row1, size1, hasSize1 := calcHashExchangeSize(p1, mppStoreCnt) + row2, size2, hasSize2 := calcHashExchangeSize(p2, mppStoreCnt) + + // hash exchange size: + // Build: sizeof(BuildTable) * (mppStoreCnt - 1) / mppStoreCnt + // Probe: sizeof(ProbeTable) * (mppStoreCnt - 1) / mppStoreCnt + + if hasSize1 && hasSize2 { + return row1 + row2, size1 + size2, true + } + return row1 + row2, 0, false +} + +// The size of `Build` hash table when using broadcast join is about `X`. +// The size of `Build` hash table when using shuffle join is about `X / (mppStoreCnt)`. +// It will cost more time to construct `Build` hash table and search `Probe` while using broadcast join. +// Set a scale factor (`mppStoreCnt^*`) when estimating broadcast join in `isJoinFitMPPBCJ` and `isJoinChildFitMPPBCJ` (based on TPCH benchmark, it has been verified in Q9). 
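To make the scale-factor comparison above concrete, here is a small standalone calculation following the same formulas (the row counts and average row size are made up, and both sides are assumed to have the same row width to keep the arithmetic simple):

package main

import (
	"fmt"
	"math"
)

func main() {
	const storeCnt = 3.0
	buildRows, probeRows := 1000.0, 1000000.0
	avgRowSize := 20.0 // bytes; assumed identical for both sides

	// Broadcast: the smaller side is copied to every other store.
	szBC := math.Min(buildRows, probeRows) * (storeCnt - 1) * avgRowSize

	// Hash shuffle: both sides are repartitioned; each store keeps 1/storeCnt locally.
	szHash := (buildRows + probeRows) * (storeCnt - 1) / storeCnt * avgRowSize

	fmt.Printf("broadcast=%.0fB hash=%.0fB\n", szBC, szHash)
	// The broadcast size is multiplied by the store count before comparison,
	// reflecting the larger per-store build hash table it implies.
	fmt.Println("prefer broadcast join:", szBC*storeCnt <= szHash)
}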
+ +func isJoinFitMPPBCJ(p *LogicalJoin, mppStoreCnt int) bool { + rowBC, szBC, hasSizeBC := calcBroadcastExchangeSizeByChild(p.Children()[0], p.Children()[1], mppStoreCnt) + rowHash, szHash, hasSizeHash := calcHashExchangeSizeByChild(p.Children()[0], p.Children()[1], mppStoreCnt) + if hasSizeBC && hasSizeHash { + return szBC*float64(mppStoreCnt) <= szHash + } + return rowBC*float64(mppStoreCnt) <= rowHash +} + +func isJoinChildFitMPPBCJ(p *LogicalJoin, childIndexToBC int, mppStoreCnt int) bool { + rowBC, szBC, hasSizeBC := calcBroadcastExchangeSize(p.Children()[childIndexToBC], mppStoreCnt) + rowHash, szHash, hasSizeHash := calcHashExchangeSizeByChild(p.Children()[0], p.Children()[1], mppStoreCnt) + + if hasSizeBC && hasSizeHash { + return szBC*float64(mppStoreCnt) <= szHash + } + return rowBC*float64(mppStoreCnt) <= rowHash +} + +// If we can use mpp broadcast join, that's our first choice. +func preferMppBCJ(p *LogicalJoin) bool { + if len(p.EqualConditions) == 0 && p.SCtx().GetSessionVars().AllowCartesianBCJ == 2 { + return true + } + + onlyCheckChild1 := p.JoinType == LeftOuterJoin || p.JoinType == SemiJoin || p.JoinType == AntiSemiJoin + onlyCheckChild0 := p.JoinType == RightOuterJoin + + if p.SCtx().GetSessionVars().PreferBCJByExchangeDataSize { + mppStoreCnt, err := p.SCtx().GetMPPClient().GetMPPStoreCount() + + // No need to exchange data if there is only ONE mpp store. But the behavior of optimizer is unexpected if use broadcast way forcibly, such as tpch q4. + // TODO: always use broadcast way to exchange data if there is only ONE mpp store. + + if err == nil && mppStoreCnt > 0 { + if !(onlyCheckChild1 || onlyCheckChild0) { + return isJoinFitMPPBCJ(p, mppStoreCnt) + } + if mppStoreCnt > 1 { + if onlyCheckChild1 { + return isJoinChildFitMPPBCJ(p, 1, mppStoreCnt) + } else if onlyCheckChild0 { + return isJoinChildFitMPPBCJ(p, 0, mppStoreCnt) + } + } + // If mppStoreCnt is ONE and only need to check one child plan, rollback to original way. + // Otherwise, the plan of tpch q4 may be unexpected. 
+ } + } + + if onlyCheckChild1 { + return checkChildFitBC(p.Children()[1]) + } else if onlyCheckChild0 { + return checkChildFitBC(p.Children()[0]) + } + return checkChildFitBC(p.Children()[0]) || checkChildFitBC(p.Children()[1]) +} + +func canExprsInJoinPushdown(p *LogicalJoin, storeType kv.StoreType) bool { + equalExprs := make([]expression.Expression, 0, len(p.EqualConditions)) + for _, eqCondition := range p.EqualConditions { + if eqCondition.FuncName.L == ast.NullEQ { + return false + } + equalExprs = append(equalExprs, eqCondition) + } + pushDownCtx := GetPushDownCtx(p.SCtx()) + if !expression.CanExprsPushDown(pushDownCtx, equalExprs, storeType) { + return false + } + if !expression.CanExprsPushDown(pushDownCtx, p.LeftConditions, storeType) { + return false + } + if !expression.CanExprsPushDown(pushDownCtx, p.RightConditions, storeType) { + return false + } + if !expression.CanExprsPushDown(pushDownCtx, p.OtherConditions, storeType) { + return false + } + return true +} + +func tryToGetMppHashJoin(p *LogicalJoin, prop *property.PhysicalProperty, useBCJ bool) []base.PhysicalPlan { + if !prop.IsSortItemEmpty() { + return nil + } + if prop.TaskTp != property.RootTaskType && prop.TaskTp != property.MppTaskType { + return nil + } + + if !expression.IsPushDownEnabled(p.JoinType.String(), kv.TiFlash) { + p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because join type `" + p.JoinType.String() + "` is blocked by blacklist, check `table mysql.expr_pushdown_blacklist;` for more information.") + return nil + } + + if p.JoinType != InnerJoin && p.JoinType != LeftOuterJoin && p.JoinType != RightOuterJoin && p.JoinType != SemiJoin && p.JoinType != AntiSemiJoin && p.JoinType != LeftOuterSemiJoin && p.JoinType != AntiLeftOuterSemiJoin { + p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because join type `" + p.JoinType.String() + "` is not supported now.") + return nil + } + + if len(p.EqualConditions) == 0 { + if !useBCJ { + p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because `Cartesian Product` is only supported by broadcast join, check value and documents of variables `tidb_broadcast_join_threshold_size` and `tidb_broadcast_join_threshold_count`.") + return nil + } + if p.SCtx().GetSessionVars().AllowCartesianBCJ == 0 { + p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because `Cartesian Product` is only supported by broadcast join, check value and documents of variable `tidb_opt_broadcast_cartesian_join`.") + return nil + } + } + if len(p.LeftConditions) != 0 && p.JoinType != LeftOuterJoin { + p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because there is a join that is not `left join` but has left conditions, which is not supported by mpp now, see github.com/pingcap/tidb/issues/26090 for more information.") + return nil + } + if len(p.RightConditions) != 0 && p.JoinType != RightOuterJoin { + p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because there is a join that is not `right join` but has right conditions, which is not supported by mpp now.") + return nil + } + + if prop.MPPPartitionTp == property.BroadcastType { + return nil + } + if !canExprsInJoinPushdown(p, kv.TiFlash) { + return nil + } + lkeys, rkeys, _, _ := p.GetJoinKeys() + lNAkeys, rNAKeys := p.GetNAJoinKeys() + // check match property + baseJoin := basePhysicalJoin{ + JoinType: p.JoinType, + LeftConditions: p.LeftConditions, + 
RightConditions: p.RightConditions, + OtherConditions: p.OtherConditions, + DefaultValues: p.DefaultValues, + LeftJoinKeys: lkeys, + RightJoinKeys: rkeys, + LeftNAJoinKeys: lNAkeys, + RightNAJoinKeys: rNAKeys, + } + // It indicates which side is the build side. + forceLeftToBuild := ((p.PreferJoinType & h.PreferLeftAsHJBuild) > 0) || ((p.PreferJoinType & h.PreferRightAsHJProbe) > 0) + forceRightToBuild := ((p.PreferJoinType & h.PreferRightAsHJBuild) > 0) || ((p.PreferJoinType & h.PreferLeftAsHJProbe) > 0) + if forceLeftToBuild && forceRightToBuild { + p.SCtx().GetSessionVars().StmtCtx.SetHintWarning( + "Some HASH_JOIN_BUILD and HASH_JOIN_PROBE hints are conflicts, please check the hints") + forceLeftToBuild = false + forceRightToBuild = false + } + preferredBuildIndex := 0 + fixedBuildSide := false // Used to indicate whether the build side for the MPP join is fixed or not. + if p.JoinType == InnerJoin { + if p.Children()[0].StatsInfo().Count() > p.Children()[1].StatsInfo().Count() { + preferredBuildIndex = 1 + } + } else if p.JoinType.IsSemiJoin() { + if !useBCJ && !p.IsNAAJ() && len(p.EqualConditions) > 0 && (p.JoinType == SemiJoin || p.JoinType == AntiSemiJoin) { + // TiFlash only supports Non-null_aware non-cross semi/anti_semi join to use both sides as build side + preferredBuildIndex = 1 + // MPPOuterJoinFixedBuildSide default value is false + // use MPPOuterJoinFixedBuildSide here as a way to disable using left table as build side + if !p.SCtx().GetSessionVars().MPPOuterJoinFixedBuildSide && p.Children()[1].StatsInfo().Count() > p.Children()[0].StatsInfo().Count() { + preferredBuildIndex = 0 + } + } else { + preferredBuildIndex = 1 + fixedBuildSide = true + } + } + if p.JoinType == LeftOuterJoin || p.JoinType == RightOuterJoin { + // TiFlash does not require that the build side must be the inner table for outer join. + // so we can choose the build side based on the row count, except that: + // 1. it is a broadcast join(for broadcast join, it makes sense to use the broadcast side as the build side) + // 2. or session variable MPPOuterJoinFixedBuildSide is set to true + // 3. or nullAware/cross joins + if useBCJ || p.IsNAAJ() || len(p.EqualConditions) == 0 || p.SCtx().GetSessionVars().MPPOuterJoinFixedBuildSide { + if !p.SCtx().GetSessionVars().MPPOuterJoinFixedBuildSide { + // The hint has higher priority than variable. + fixedBuildSide = true + } + if p.JoinType == LeftOuterJoin { + preferredBuildIndex = 1 + } + } else if p.Children()[0].StatsInfo().Count() > p.Children()[1].StatsInfo().Count() { + preferredBuildIndex = 1 + } + } + + if forceLeftToBuild || forceRightToBuild { + match := (forceLeftToBuild && preferredBuildIndex == 0) || (forceRightToBuild && preferredBuildIndex == 1) + if !match { + if fixedBuildSide { + // A warning will be generated if the build side is fixed, but we attempt to change it using the hint. + p.SCtx().GetSessionVars().StmtCtx.SetHintWarning( + "Some HASH_JOIN_BUILD and HASH_JOIN_PROBE hints cannot be utilized for MPP joins, please check the hints") + } else { + // The HASH_JOIN_BUILD OR HASH_JOIN_PROBE hints can take effective. 
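That flip, together with the fixed-build-side warning above, can be condensed into a small standalone helper; this is a simplified restatement, not the planner's own code:

package main

import "fmt"

// resolveBuildSide condenses the interaction between the HASH_JOIN_BUILD /
// HASH_JOIN_PROBE hints and the cost-based choice: the hint flips the build side
// only when that side is not fixed by the join type or the broadcast decision.
func resolveBuildSide(preferredBuildIndex int, forceLeft, forceRight, fixedBuildSide bool) (int, string) {
	if !forceLeft && !forceRight {
		return preferredBuildIndex, ""
	}
	match := (forceLeft && preferredBuildIndex == 0) || (forceRight && preferredBuildIndex == 1)
	if match {
		return preferredBuildIndex, ""
	}
	if fixedBuildSide {
		return preferredBuildIndex, "hint ignored: the build side is fixed for this join"
	}
	return 1 - preferredBuildIndex, "" // the hint overrides the row-count heuristic
}

func main() {
	idx, warn := resolveBuildSide(1, true, false, false)
	fmt.Println(idx, warn) // 0 -- flipped to honour the hint
	idx, warn = resolveBuildSide(1, true, false, true)
	fmt.Println(idx, warn) // 1 -- kept, with a warning
}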
+ preferredBuildIndex = 1 - preferredBuildIndex + } + } + } + + // set preferredBuildIndex for test + failpoint.Inject("mockPreferredBuildIndex", func(val failpoint.Value) { + if !p.SCtx().GetSessionVars().InRestrictedSQL { + preferredBuildIndex = val.(int) + } + }) + + baseJoin.InnerChildIdx = preferredBuildIndex + childrenProps := make([]*property.PhysicalProperty, 2) + if useBCJ { + childrenProps[preferredBuildIndex] = &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.BroadcastType, CanAddEnforcer: true, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus} + expCnt := math.MaxFloat64 + if prop.ExpectedCnt < p.StatsInfo().RowCount { + expCntScale := prop.ExpectedCnt / p.StatsInfo().RowCount + expCnt = p.Children()[1-preferredBuildIndex].StatsInfo().RowCount * expCntScale + } + if prop.MPPPartitionTp == property.HashType { + lPartitionKeys, rPartitionKeys := p.GetPotentialPartitionKeys() + hashKeys := rPartitionKeys + if preferredBuildIndex == 1 { + hashKeys = lPartitionKeys + } + matches := prop.IsSubsetOf(hashKeys) + if len(matches) == 0 { + return nil + } + childrenProps[1-preferredBuildIndex] = &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: expCnt, MPPPartitionTp: property.HashType, MPPPartitionCols: prop.MPPPartitionCols, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus} + } else { + childrenProps[1-preferredBuildIndex] = &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: expCnt, MPPPartitionTp: property.AnyType, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus} + } + } else { + lPartitionKeys, rPartitionKeys := p.GetPotentialPartitionKeys() + if prop.MPPPartitionTp == property.HashType { + var matches []int + if p.JoinType == InnerJoin { + if matches = prop.IsSubsetOf(lPartitionKeys); len(matches) == 0 { + matches = prop.IsSubsetOf(rPartitionKeys) + } + } else if p.JoinType == RightOuterJoin { + // for right out join, only the right partition keys can possibly matches the prop, because + // the left partition keys will generate NULL values randomly + // todo maybe we can add a null-sensitive flag in the MPPPartitionColumn to indicate whether the partition column is + // null-sensitive(used in aggregation) or null-insensitive(used in join) + matches = prop.IsSubsetOf(rPartitionKeys) + } else { + // for left out join, only the left partition keys can possibly matches the prop, because + // the right partition keys will generate NULL values randomly + // for semi/anti semi/left out semi/anti left out semi join, only left partition keys are returned, + // so just check the left partition keys + matches = prop.IsSubsetOf(lPartitionKeys) + } + if len(matches) == 0 { + return nil + } + lPartitionKeys = choosePartitionKeys(lPartitionKeys, matches) + rPartitionKeys = choosePartitionKeys(rPartitionKeys, matches) + } + childrenProps[0] = &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.HashType, MPPPartitionCols: lPartitionKeys, CanAddEnforcer: true, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus} + childrenProps[1] = &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.HashType, MPPPartitionCols: rPartitionKeys, CanAddEnforcer: true, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus} + } + join := PhysicalHashJoin{ + basePhysicalJoin: baseJoin, + Concurrency: uint(p.SCtx().GetSessionVars().CopTiFlashConcurrencyFactor), 
+ EqualConditions: p.EqualConditions, + NAEqualConditions: p.NAEQConditions, + storeTp: kv.TiFlash, + mppShuffleJoin: !useBCJ, + // Mpp Join has quite heavy cost. Even limit might not suspend it in time, so we don't scale the count. + }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset(), childrenProps...) + join.SetSchema(p.Schema()) + return []base.PhysicalPlan{join} +} + +func choosePartitionKeys(keys []*property.MPPPartitionColumn, matches []int) []*property.MPPPartitionColumn { + newKeys := make([]*property.MPPPartitionColumn, 0, len(matches)) + for _, id := range matches { + newKeys = append(newKeys, keys[id]) + } + return newKeys +} + +func exhaustPhysicalPlans4LogicalExpand(p *LogicalExpand, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { + // under the mpp task type, if the sort item is not empty, refuse it, cause expanded data doesn't support any sort items. + if !prop.IsSortItemEmpty() { + // false, meaning we can add a sort enforcer. + return nil, false, nil + } + // when TiDB Expand execution is introduced: we can deal with two kind of physical plans. + // RootTaskType means expand should be run at TiDB node. + // (RootTaskType is the default option, we can also generate a mpp candidate for it) + // MPPTaskType means expand should be run at TiFlash node. + if prop.TaskTp != property.RootTaskType && prop.TaskTp != property.MppTaskType { + return nil, true, nil + } + // now Expand mode can only be executed on TiFlash node. + // Upper layer shouldn't expect any mpp partition from an Expand operator. + // todo: data output from Expand operator should keep the origin data mpp partition. + if prop.TaskTp == property.MppTaskType && prop.MPPPartitionTp != property.AnyType { + return nil, true, nil + } + var physicalExpands []base.PhysicalPlan + // for property.RootTaskType and property.MppTaskType with no partition option, we can give an MPP Expand. + canPushToTiFlash := p.CanPushToCop(kv.TiFlash) + if p.SCtx().GetSessionVars().IsMPPAllowed() && canPushToTiFlash { + mppProp := prop.CloneEssentialFields() + mppProp.TaskTp = property.MppTaskType + expand := PhysicalExpand{ + GroupingSets: p.RollupGroupingSets, + LevelExprs: p.LevelExprs, + ExtraGroupingColNames: p.ExtraGroupingColNames, + }.Init(p.SCtx(), p.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), p.QueryBlockOffset(), mppProp) + expand.SetSchema(p.Schema()) + physicalExpands = append(physicalExpands, expand) + // when the MppTaskType is required, we can return the physical plan directly. + if prop.TaskTp == property.MppTaskType { + return physicalExpands, true, nil + } + } + // for property.RootTaskType, we can give a TiDB Expand. 
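The block that follows builds one Expand candidate per task type by cloning the required property and overriding only its task type; the same enumeration pattern recurs for Projection, TopN and Limit further down. A minimal sketch of the pattern with a simplified property struct (not the real property.PhysicalProperty):

package main

import "fmt"

type taskType string

type physicalProperty struct {
	taskTp      taskType
	expectedCnt float64
}

// childPropsPerTaskType clones the parent's requirement once per candidate task
// type; the clones differ only in where the child plan is allowed to run.
func childPropsPerTaskType(parent physicalProperty, tps []taskType) []physicalProperty {
	props := make([]physicalProperty, 0, len(tps))
	for _, tp := range tps {
		p := parent // copy of the simplified struct
		p.taskTp = tp
		props = append(props, p)
	}
	return props
}

func main() {
	parent := physicalProperty{taskTp: "root", expectedCnt: 100}
	tps := []taskType{"copSingleRead", "copMultiRead", "mpp", "root"}
	fmt.Println(childPropsPerTaskType(parent, tps))
}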
+ { + taskTypes := []property.TaskType{property.CopSingleReadTaskType, property.CopMultiReadTaskType, property.MppTaskType, property.RootTaskType} + for _, taskType := range taskTypes { + // require cop task type for children.F + tidbProp := prop.CloneEssentialFields() + tidbProp.TaskTp = taskType + expand := PhysicalExpand{ + GroupingSets: p.RollupGroupingSets, + LevelExprs: p.LevelExprs, + ExtraGroupingColNames: p.ExtraGroupingColNames, + }.Init(p.SCtx(), p.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), p.QueryBlockOffset(), tidbProp) + expand.SetSchema(p.Schema()) + physicalExpands = append(physicalExpands, expand) + } + } + return physicalExpands, true, nil +} + +func exhaustPhysicalPlans4LogicalProjection(lp base.LogicalPlan, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { + p := lp.(*logicalop.LogicalProjection) + newProp, ok := p.TryToGetChildProp(prop) + if !ok { + return nil, true, nil + } + newProps := []*property.PhysicalProperty{newProp} + // generate a mpp task candidate if mpp mode is allowed + ctx := p.SCtx() + pushDownCtx := GetPushDownCtx(ctx) + if newProp.TaskTp != property.MppTaskType && ctx.GetSessionVars().IsMPPAllowed() && p.CanPushToCop(kv.TiFlash) && + expression.CanExprsPushDown(pushDownCtx, p.Exprs, kv.TiFlash) { + mppProp := newProp.CloneEssentialFields() + mppProp.TaskTp = property.MppTaskType + newProps = append(newProps, mppProp) + } + if newProp.TaskTp != property.CopSingleReadTaskType && ctx.GetSessionVars().AllowProjectionPushDown && p.CanPushToCop(kv.TiKV) && + expression.CanExprsPushDown(pushDownCtx, p.Exprs, kv.TiKV) && !expression.ContainVirtualColumn(p.Exprs) && + expression.ProjectionBenefitsFromPushedDown(p.Exprs, p.Children()[0].Schema().Len()) { + copProp := newProp.CloneEssentialFields() + copProp.TaskTp = property.CopSingleReadTaskType + newProps = append(newProps, copProp) + } + + ret := make([]base.PhysicalPlan, 0, len(newProps)) + for _, newProp := range newProps { + proj := PhysicalProjection{ + Exprs: p.Exprs, + CalculateNoDelay: p.CalculateNoDelay, + AvoidColumnEvaluator: p.AvoidColumnEvaluator, + }.Init(ctx, p.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), p.QueryBlockOffset(), newProp) + proj.SetSchema(p.Schema()) + ret = append(ret, proj) + } + return ret, true, nil +} + +func pushLimitOrTopNForcibly(p base.LogicalPlan) bool { + var meetThreshold bool + var preferPushDown *bool + switch lp := p.(type) { + case *logicalop.LogicalTopN: + preferPushDown = &lp.PreferLimitToCop + meetThreshold = lp.Count+lp.Offset <= uint64(lp.SCtx().GetSessionVars().LimitPushDownThreshold) + case *logicalop.LogicalLimit: + preferPushDown = &lp.PreferLimitToCop + meetThreshold = true // always push Limit down in this case since it has no side effect + default: + return false + } + + if *preferPushDown || meetThreshold { + if p.CanPushToCop(kv.TiKV) { + return true + } + if *preferPushDown { + p.SCtx().GetSessionVars().StmtCtx.SetHintWarning("Optimizer Hint LIMIT_TO_COP is inapplicable") + *preferPushDown = false + } + } + + return false +} + +func getPhysTopN(lt *logicalop.LogicalTopN, prop *property.PhysicalProperty) []base.PhysicalPlan { + allTaskTypes := []property.TaskType{property.CopSingleReadTaskType, property.CopMultiReadTaskType} + if !pushLimitOrTopNForcibly(lt) { + allTaskTypes = append(allTaskTypes, property.RootTaskType) + } + if lt.SCtx().GetSessionVars().IsMPPAllowed() { + allTaskTypes = append(allTaskTypes, property.MppTaskType) + } + ret := make([]base.PhysicalPlan, 0, len(allTaskTypes)) + for _, tp := range 
allTaskTypes { + resultProp := &property.PhysicalProperty{TaskTp: tp, ExpectedCnt: math.MaxFloat64, CTEProducerStatus: prop.CTEProducerStatus} + topN := PhysicalTopN{ + ByItems: lt.ByItems, + PartitionBy: lt.PartitionBy, + Count: lt.Count, + Offset: lt.Offset, + }.Init(lt.SCtx(), lt.StatsInfo(), lt.QueryBlockOffset(), resultProp) + ret = append(ret, topN) + } + return ret +} + +func getPhysLimits(lt *logicalop.LogicalTopN, prop *property.PhysicalProperty) []base.PhysicalPlan { + p, canPass := GetPropByOrderByItems(lt.ByItems) + if !canPass { + return nil + } + + allTaskTypes := []property.TaskType{property.CopSingleReadTaskType, property.CopMultiReadTaskType} + if !pushLimitOrTopNForcibly(lt) { + allTaskTypes = append(allTaskTypes, property.RootTaskType) + } + ret := make([]base.PhysicalPlan, 0, len(allTaskTypes)) + for _, tp := range allTaskTypes { + resultProp := &property.PhysicalProperty{TaskTp: tp, ExpectedCnt: float64(lt.Count + lt.Offset), SortItems: p.SortItems, CTEProducerStatus: prop.CTEProducerStatus} + limit := PhysicalLimit{ + Count: lt.Count, + Offset: lt.Offset, + PartitionBy: lt.GetPartitionBy(), + }.Init(lt.SCtx(), lt.StatsInfo(), lt.QueryBlockOffset(), resultProp) + limit.SetSchema(lt.Schema()) + ret = append(ret, limit) + } + return ret +} + +// MatchItems checks if this prop's columns can match by items totally. +func MatchItems(p *property.PhysicalProperty, items []*util.ByItems) bool { + if len(items) < len(p.SortItems) { + return false + } + for i, col := range p.SortItems { + sortItem := items[i] + if sortItem.Desc != col.Desc || !col.Col.EqualColumn(sortItem.Expr) { + return false + } + } + return true +} + +// GetHashJoin is public for cascades planner. +func GetHashJoin(la *LogicalApply, prop *property.PhysicalProperty) *PhysicalHashJoin { + return getHashJoin(&la.LogicalJoin, prop, 1, false) +} + +// ExhaustPhysicalPlans4LogicalApply generates the physical plan for a logical apply. 
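One detail of this function worth calling out is the apply-cache decision. Restated as a standalone calculation (the numbers mirror the example given in the function's own comment; the helper name is ours, not the planner's):

package main

import "fmt"

// canUseApplyCache restates the apply-cache heuristic: with 100 outer rows and 70
// distinct correlated-column values, roughly 30 probes can hit the cache, so the
// hit ratio is 1 - 70/100 = 0.3 and the cache is worth enabling.
func canUseApplyCache(rowCount, ndv, memQuotaApplyCache float64) bool {
	if rowCount == 0 {
		return false
	}
	cacheHitRatio := 1 - ndv/rowCount
	return cacheHitRatio > 0.1 && memQuotaApplyCache > 0
}

func main() {
	fmt.Println(canUseApplyCache(100, 70, 32<<20)) // true  (ratio 0.3)
	fmt.Println(canUseApplyCache(100, 95, 32<<20)) // false (ratio 0.05)
}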
+func ExhaustPhysicalPlans4LogicalApply(la *LogicalApply, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { + if !prop.AllColsFromSchema(la.Children()[0].Schema()) || prop.IsFlashProp() { // for convenient, we don't pass through any prop + la.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced( + "MPP mode may be blocked because operator `Apply` is not supported now.") + return nil, true, nil + } + if !prop.IsSortItemEmpty() && la.SCtx().GetSessionVars().EnableParallelApply { + la.SCtx().GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("Parallel Apply rejects the possible order properties of its outer child currently")) + return nil, true, nil + } + disableAggPushDownToCop(la.Children()[0]) + join := GetHashJoin(la, prop) + var columns = make([]*expression.Column, 0, len(la.CorCols)) + for _, colColumn := range la.CorCols { + columns = append(columns, &colColumn.Column) + } + cacheHitRatio := 0.0 + if la.StatsInfo().RowCount != 0 { + ndv, _ := cardinality.EstimateColsNDVWithMatchedLen(columns, la.Schema(), la.StatsInfo()) + // for example, if there are 100 rows and the number of distinct values of these correlated columns + // are 70, then we can assume 30 rows can hit the cache so the cache hit ratio is 1 - (70/100) = 0.3 + cacheHitRatio = 1 - (ndv / la.StatsInfo().RowCount) + } + + var canUseCache bool + if cacheHitRatio > 0.1 && la.SCtx().GetSessionVars().MemQuotaApplyCache > 0 { + canUseCache = true + } else { + canUseCache = false + } + + apply := PhysicalApply{ + PhysicalHashJoin: *join, + OuterSchema: la.CorCols, + CanUseCache: canUseCache, + }.Init(la.SCtx(), + la.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), + la.QueryBlockOffset(), + &property.PhysicalProperty{ExpectedCnt: math.MaxFloat64, SortItems: prop.SortItems, CTEProducerStatus: prop.CTEProducerStatus}, + &property.PhysicalProperty{ExpectedCnt: math.MaxFloat64, CTEProducerStatus: prop.CTEProducerStatus}) + apply.SetSchema(la.Schema()) + return []base.PhysicalPlan{apply}, true, nil +} + +func disableAggPushDownToCop(p base.LogicalPlan) { + for _, child := range p.Children() { + disableAggPushDownToCop(child) + } + if agg, ok := p.(*LogicalAggregation); ok { + agg.NoCopPushDown = true + } +} + +func tryToGetMppWindows(lw *logicalop.LogicalWindow, prop *property.PhysicalProperty) []base.PhysicalPlan { + if !prop.IsSortItemAllForPartition() { + return nil + } + if prop.TaskTp != property.RootTaskType && prop.TaskTp != property.MppTaskType { + return nil + } + if prop.MPPPartitionTp == property.BroadcastType { + return nil + } + + { + allSupported := true + sctx := lw.SCtx() + for _, windowFunc := range lw.WindowFuncDescs { + if !windowFunc.CanPushDownToTiFlash(GetPushDownCtx(sctx)) { + lw.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced( + "MPP mode may be blocked because window function `" + windowFunc.Name + "` or its arguments are not supported now.") + allSupported = false + } else if !expression.IsPushDownEnabled(windowFunc.Name, kv.TiFlash) { + lw.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because window function `" + windowFunc.Name + "` is blocked by blacklist, check `table mysql.expr_pushdown_blacklist;` for more information.") + return nil + } + } + if !allSupported { + return nil + } + + if lw.Frame != nil && lw.Frame.Type == ast.Ranges { + ctx := lw.SCtx().GetExprCtx() + if _, err := expression.ExpressionsToPBList(ctx.GetEvalCtx(), lw.Frame.Start.CalcFuncs, lw.SCtx().GetClient()); err != nil { + 
lw.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced( + "MPP mode may be blocked because window function frame can't be pushed down, because " + err.Error()) + return nil + } + if _, err := expression.ExpressionsToPBList(ctx.GetEvalCtx(), lw.Frame.End.CalcFuncs, lw.SCtx().GetClient()); err != nil { + lw.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced( + "MPP mode may be blocked because window function frame can't be pushed down, because " + err.Error()) + return nil + } + + if !lw.CheckComparisonForTiFlash(lw.Frame.Start) || !lw.CheckComparisonForTiFlash(lw.Frame.End) { + lw.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced( + "MPP mode may be blocked because window function frame can't be pushed down, because Duration vs Datetime is invalid comparison as TiFlash can't handle it so far.") + return nil + } + } + } + + var byItems []property.SortItem + byItems = append(byItems, lw.PartitionBy...) + byItems = append(byItems, lw.OrderBy...) + childProperty := &property.PhysicalProperty{ + ExpectedCnt: math.MaxFloat64, + CanAddEnforcer: true, + SortItems: byItems, + TaskTp: property.MppTaskType, + SortItemsForPartition: byItems, + CTEProducerStatus: prop.CTEProducerStatus, + } + if !prop.IsPrefix(childProperty) { + return nil + } + + if len(lw.PartitionBy) > 0 { + partitionCols := lw.GetPartitionKeys() + // trying to match the required partitions. + if prop.MPPPartitionTp == property.HashType { + matches := prop.IsSubsetOf(partitionCols) + if len(matches) == 0 { + // do not satisfy the property of its parent, so return empty + return nil + } + partitionCols = choosePartitionKeys(partitionCols, matches) + } + childProperty.MPPPartitionTp = property.HashType + childProperty.MPPPartitionCols = partitionCols + } else { + childProperty.MPPPartitionTp = property.SinglePartitionType + } + + if prop.MPPPartitionTp == property.SinglePartitionType && childProperty.MPPPartitionTp != property.SinglePartitionType { + return nil + } + + window := PhysicalWindow{ + WindowFuncDescs: lw.WindowFuncDescs, + PartitionBy: lw.PartitionBy, + OrderBy: lw.OrderBy, + Frame: lw.Frame, + storeTp: kv.TiFlash, + }.Init(lw.SCtx(), lw.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), lw.QueryBlockOffset(), childProperty) + window.SetSchema(lw.Schema()) + + return []base.PhysicalPlan{window} +} + +func exhaustPhysicalPlans4LogicalWindow(lp base.LogicalPlan, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { + lw := lp.(*logicalop.LogicalWindow) + windows := make([]base.PhysicalPlan, 0, 2) + + canPushToTiFlash := lw.CanPushToCop(kv.TiFlash) + if lw.SCtx().GetSessionVars().IsMPPAllowed() && canPushToTiFlash { + mppWindows := tryToGetMppWindows(lw, prop) + windows = append(windows, mppWindows...) + } + + // if there needs a mpp task, we don't generate tidb window function. + if prop.TaskTp == property.MppTaskType { + return windows, true, nil + } + var byItems []property.SortItem + byItems = append(byItems, lw.PartitionBy...) + byItems = append(byItems, lw.OrderBy...) 
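The concatenation above, and the child property built right after it, encode the window operator's sort requirement: partition-by columns first, order-by columns second, so the rows of one partition arrive as a single contiguous, correctly ordered run. A tiny standalone illustration with a simplified sort-item type:

package main

import "fmt"

type sortItem struct {
	col  string
	desc bool
}

// windowSortRequirement concatenates the partition-by columns before the
// order-by columns, so rows of one partition arrive as a contiguous ordered run.
func windowSortRequirement(partitionBy, orderBy []sortItem) []sortItem {
	items := make([]sortItem, 0, len(partitionBy)+len(orderBy))
	items = append(items, partitionBy...)
	items = append(items, orderBy...)
	return items
}

func main() {
	fmt.Println(windowSortRequirement(
		[]sortItem{{col: "dept"}},
		[]sortItem{{col: "salary", desc: true}},
	)) // [{dept false} {salary true}]
}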
+ childProperty := &property.PhysicalProperty{ExpectedCnt: math.MaxFloat64, SortItems: byItems, CanAddEnforcer: true, CTEProducerStatus: prop.CTEProducerStatus} + if !prop.IsPrefix(childProperty) { + return nil, true, nil + } + window := PhysicalWindow{ + WindowFuncDescs: lw.WindowFuncDescs, + PartitionBy: lw.PartitionBy, + OrderBy: lw.OrderBy, + Frame: lw.Frame, + }.Init(lw.SCtx(), lw.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), lw.QueryBlockOffset(), childProperty) + window.SetSchema(lw.Schema()) + + windows = append(windows, window) + return windows, true, nil +} + +func canPushToCopImpl(lp base.LogicalPlan, storeTp kv.StoreType, considerDual bool) bool { + p := lp.GetBaseLogicalPlan().(*logicalop.BaseLogicalPlan) + ret := true + for _, ch := range p.Children() { + switch c := ch.(type) { + case *DataSource: + validDs := false + indexMergeIsIntersection := false + for _, path := range c.PossibleAccessPaths { + if path.StoreType == storeTp { + validDs = true + } + if len(path.PartialIndexPaths) > 0 && path.IndexMergeIsIntersection { + indexMergeIsIntersection = true + } + } + ret = ret && validDs + + _, isTopN := p.Self().(*logicalop.LogicalTopN) + _, isLimit := p.Self().(*logicalop.LogicalLimit) + if (isTopN || isLimit) && indexMergeIsIntersection { + return false // TopN and Limit cannot be pushed down to the intersection type IndexMerge + } + + if c.TableInfo.TableCacheStatusType != model.TableCacheStatusDisable { + // Don't push to cop for cached table, it brings more harm than good: + // 1. Those tables are small enough, push to cop can't utilize several TiKV to accelerate computation. + // 2. Cached table use UnionScan to read the cache data, and push to cop is not supported when an UnionScan exists. + // Once aggregation is pushed to cop, the cache data can't be use any more. + return false + } + case *LogicalUnionAll: + if storeTp != kv.TiFlash { + return false + } + ret = ret && canPushToCopImpl(&c.BaseLogicalPlan, storeTp, true) + case *logicalop.LogicalSort: + if storeTp != kv.TiFlash { + return false + } + ret = ret && canPushToCopImpl(&c.BaseLogicalPlan, storeTp, true) + case *logicalop.LogicalProjection: + if storeTp != kv.TiFlash { + return false + } + ret = ret && canPushToCopImpl(&c.BaseLogicalPlan, storeTp, considerDual) + case *LogicalExpand: + // Expand itself only contains simple col ref and literal projection. (always ok, check its child) + if storeTp != kv.TiFlash { + return false + } + ret = ret && canPushToCopImpl(&c.BaseLogicalPlan, storeTp, considerDual) + case *logicalop.LogicalTableDual: + return storeTp == kv.TiFlash && considerDual + case *LogicalAggregation, *LogicalSelection, *LogicalJoin, *logicalop.LogicalWindow: + if storeTp != kv.TiFlash { + return false + } + ret = ret && c.CanPushToCop(storeTp) + // These operators can be partially push down to TiFlash, so we don't raise warning for them. 
+ case *logicalop.LogicalLimit, *logicalop.LogicalTopN: + return false + case *logicalop.LogicalSequence: + return storeTp == kv.TiFlash + case *LogicalCTE: + if storeTp != kv.TiFlash { + return false + } + if c.Cte.recursivePartLogicalPlan != nil || !c.Cte.seedPartLogicalPlan.CanPushToCop(storeTp) { + return false + } + return true + default: + p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced( + "MPP mode may be blocked because operator `" + c.TP() + "` is not supported now.") + return false + } + } + return ret +} + +func getEnforcedStreamAggs(la *LogicalAggregation, prop *property.PhysicalProperty) []base.PhysicalPlan { + if prop.IsFlashProp() { + return nil + } + _, desc := prop.AllSameOrder() + allTaskTypes := prop.GetAllPossibleChildTaskTypes() + enforcedAggs := make([]base.PhysicalPlan, 0, len(allTaskTypes)) + childProp := &property.PhysicalProperty{ + ExpectedCnt: math.Max(prop.ExpectedCnt*la.InputCount/la.StatsInfo().RowCount, prop.ExpectedCnt), + CanAddEnforcer: true, + SortItems: property.SortItemsFromCols(la.GetGroupByCols(), desc), + } + if !prop.IsPrefix(childProp) { + return enforcedAggs + } + taskTypes := []property.TaskType{property.CopSingleReadTaskType, property.CopMultiReadTaskType} + if la.HasDistinct() { + // TODO: remove AllowDistinctAggPushDown after the cost estimation of distinct pushdown is implemented. + // If AllowDistinctAggPushDown is set to true, we should not consider RootTask. + if !la.CanPushToCop(kv.TiKV) || !la.SCtx().GetSessionVars().AllowDistinctAggPushDown { + taskTypes = []property.TaskType{property.RootTaskType} + } + } else if !la.PreferAggToCop { + taskTypes = append(taskTypes, property.RootTaskType) + } + for _, taskTp := range taskTypes { + copiedChildProperty := new(property.PhysicalProperty) + *copiedChildProperty = *childProp // It's ok to not deep copy the "cols" field. + copiedChildProperty.TaskTp = taskTp + + newGbyItems := make([]expression.Expression, len(la.GroupByItems)) + copy(newGbyItems, la.GroupByItems) + newAggFuncs := make([]*aggregation.AggFuncDesc, len(la.AggFuncs)) + copy(newAggFuncs, la.AggFuncs) + + agg := basePhysicalAgg{ + GroupByItems: newGbyItems, + AggFuncs: newAggFuncs, + }.initForStream(la.SCtx(), la.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), la.QueryBlockOffset(), copiedChildProperty) + agg.SetSchema(la.Schema().Clone()) + enforcedAggs = append(enforcedAggs, agg) + } + return enforcedAggs +} + +func (la *LogicalAggregation) distinctArgsMeetsProperty() bool { + for _, aggFunc := range la.AggFuncs { + if aggFunc.HasDistinct { + for _, distinctArg := range aggFunc.Args { + if !expression.Contains(la.SCtx().GetExprCtx().GetEvalCtx(), la.GroupByItems, distinctArg) { + return false + } + } + } + } + return true +} + +func getStreamAggs(lp base.LogicalPlan, prop *property.PhysicalProperty) []base.PhysicalPlan { + la := lp.(*LogicalAggregation) + // TODO: support CopTiFlash task type in stream agg + if prop.IsFlashProp() { + return nil + } + all, desc := prop.AllSameOrder() + if !all { + return nil + } + + for _, aggFunc := range la.AggFuncs { + if aggFunc.Mode == aggregation.FinalMode { + return nil + } + } + // group by a + b is not interested in any order. 
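Besides the bare-column check that follows, note how the child property's expected count is scaled back through the aggregation, both in getEnforcedStreamAggs above and in the loop below. A standalone sketch of that scaling with made-up numbers:

package main

import (
	"fmt"
	"math"
)

// aggChildExpectedCnt scales the parent's expected row count back through the
// aggregation: if on average ten input rows feed one output group, producing the
// first N groups requires roughly 10*N input rows (and never fewer than N).
func aggChildExpectedCnt(expectedCnt, aggInputCount, aggOutputCount float64) float64 {
	return math.Max(expectedCnt*aggInputCount/aggOutputCount, expectedCnt)
}

func main() {
	// Parent wants 10 groups; the agg reads 1000 rows and emits 100 groups.
	fmt.Println(aggChildExpectedCnt(10, 1000, 100)) // 100
}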
+ groupByCols := la.GetGroupByCols() + if len(groupByCols) != len(la.GroupByItems) { + return nil + } + + allTaskTypes := prop.GetAllPossibleChildTaskTypes() + streamAggs := make([]base.PhysicalPlan, 0, len(la.PossibleProperties)*(len(allTaskTypes)-1)+len(allTaskTypes)) + childProp := &property.PhysicalProperty{ + ExpectedCnt: math.Max(prop.ExpectedCnt*la.InputCount/la.StatsInfo().RowCount, prop.ExpectedCnt), + } + + for _, possibleChildProperty := range la.PossibleProperties { + childProp.SortItems = property.SortItemsFromCols(possibleChildProperty[:len(groupByCols)], desc) + if !prop.IsPrefix(childProp) { + continue + } + // The table read of "CopDoubleReadTaskType" can't promises the sort + // property that the stream aggregation required, no need to consider. + taskTypes := []property.TaskType{property.CopSingleReadTaskType} + if la.HasDistinct() { + // TODO: remove AllowDistinctAggPushDown after the cost estimation of distinct pushdown is implemented. + // If AllowDistinctAggPushDown is set to true, we should not consider RootTask. + if !la.SCtx().GetSessionVars().AllowDistinctAggPushDown || !la.CanPushToCop(kv.TiKV) { + // if variable doesn't allow DistinctAggPushDown, just produce root task type. + // if variable does allow DistinctAggPushDown, but OP itself can't be pushed down to tikv, just produce root task type. + taskTypes = []property.TaskType{property.RootTaskType} + } else if !la.DistinctArgsMeetsProperty() { + continue + } + } else if !la.PreferAggToCop { + taskTypes = append(taskTypes, property.RootTaskType) + } + if !la.CanPushToCop(kv.TiKV) && !la.CanPushToCop(kv.TiFlash) { + taskTypes = []property.TaskType{property.RootTaskType} + } + for _, taskTp := range taskTypes { + copiedChildProperty := new(property.PhysicalProperty) + *copiedChildProperty = *childProp // It's ok to not deep copy the "cols" field. + copiedChildProperty.TaskTp = taskTp + + newGbyItems := make([]expression.Expression, len(la.GroupByItems)) + copy(newGbyItems, la.GroupByItems) + newAggFuncs := make([]*aggregation.AggFuncDesc, len(la.AggFuncs)) + copy(newAggFuncs, la.AggFuncs) + + agg := basePhysicalAgg{ + GroupByItems: newGbyItems, + AggFuncs: newAggFuncs, + }.initForStream(la.SCtx(), la.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), la.QueryBlockOffset(), copiedChildProperty) + agg.SetSchema(la.Schema().Clone()) + streamAggs = append(streamAggs, agg) + } + } + // If STREAM_AGG hint is existed, it should consider enforce stream aggregation, + // because we can't trust possibleChildProperty completely. + if (la.PreferAggType & h.PreferStreamAgg) > 0 { + streamAggs = append(streamAggs, getEnforcedStreamAggs(la, prop)...) 
+ } + return streamAggs +} + +// TODO: support more operators and distinct later +func checkCanPushDownToMPP(la *LogicalAggregation) bool { + hasUnsupportedDistinct := false + for _, agg := range la.AggFuncs { + // MPP does not support distinct except count distinct now + if agg.HasDistinct { + if agg.Name != ast.AggFuncCount && agg.Name != ast.AggFuncGroupConcat { + hasUnsupportedDistinct = true + } + } + // MPP does not support AggFuncApproxCountDistinct now + if agg.Name == ast.AggFuncApproxCountDistinct { + hasUnsupportedDistinct = true + } + } + if hasUnsupportedDistinct { + warnErr := errors.NewNoStackError("Aggregation can not be pushed to storage layer in mpp mode because it contains agg function with distinct") + if la.SCtx().GetSessionVars().StmtCtx.InExplainStmt { + la.SCtx().GetSessionVars().StmtCtx.AppendWarning(warnErr) + } else { + la.SCtx().GetSessionVars().StmtCtx.AppendExtraWarning(warnErr) + } + return false + } + return CheckAggCanPushCop(la.SCtx(), la.AggFuncs, la.GroupByItems, kv.TiFlash) +} + +func tryToGetMppHashAggs(la *LogicalAggregation, prop *property.PhysicalProperty) (hashAggs []base.PhysicalPlan) { + if !prop.IsSortItemEmpty() { + return nil + } + if prop.TaskTp != property.RootTaskType && prop.TaskTp != property.MppTaskType { + return nil + } + if prop.MPPPartitionTp == property.BroadcastType { + return nil + } + + // Is this aggregate a final stage aggregate? + // Final agg can't be split into multi-stage aggregate + hasFinalAgg := len(la.AggFuncs) > 0 && la.AggFuncs[0].Mode == aggregation.FinalMode + // count final agg should become sum for MPP execution path. + // In the traditional case, TiDB take up the final agg role and push partial agg to TiKV, + // while TiDB can tell the partialMode and do the sum computation rather than counting but MPP doesn't + finalAggAdjust := func(aggFuncs []*aggregation.AggFuncDesc) { + for i, agg := range aggFuncs { + if agg.Mode == aggregation.FinalMode && agg.Name == ast.AggFuncCount { + oldFT := agg.RetTp + aggFuncs[i], _ = aggregation.NewAggFuncDesc(la.SCtx().GetExprCtx(), ast.AggFuncSum, agg.Args, false) + aggFuncs[i].TypeInfer4FinalCount(oldFT) + } + } + } + // ref: https://github.com/pingcap/tiflash/blob/3ebb102fba17dce3d990d824a9df93d93f1ab + // 766/dbms/src/Flash/Coprocessor/AggregationInterpreterHelper.cpp#L26 + validMppAgg := func(mppAgg *PhysicalHashAgg) bool { + isFinalAgg := true + if mppAgg.AggFuncs[0].Mode != aggregation.FinalMode && mppAgg.AggFuncs[0].Mode != aggregation.CompleteMode { + isFinalAgg = false + } + for _, one := range mppAgg.AggFuncs[1:] { + otherIsFinalAgg := one.Mode == aggregation.FinalMode || one.Mode == aggregation.CompleteMode + if isFinalAgg != otherIsFinalAgg { + // different agg mode detected in mpp side. + return false + } + } + return true + } + + if len(la.GroupByItems) > 0 { + partitionCols := la.GetPotentialPartitionKeys() + // trying to match the required partitions. + if prop.MPPPartitionTp == property.HashType { + // partition key required by upper layer is subset of current layout. 
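A simplified version of the subset match performed below: every partition column required by the parent must appear among the columns this aggregation can partition by (its group-by columns), and the matched indexes are then used to keep only those keys. The sketch uses plain strings instead of the planner's partition-column type:

package main

import "fmt"

// isSubsetOf returns the indexes of the required columns inside the available
// columns, or nil when any required column is missing.
func isSubsetOf(required, available []string) []int {
	matches := make([]int, 0, len(required))
	for _, r := range required {
		found := -1
		for i, a := range available {
			if a == r {
				found = i
				break
			}
		}
		if found < 0 {
			return nil // a required column is missing: the property cannot be satisfied
		}
		matches = append(matches, found)
	}
	return matches
}

func main() {
	fmt.Println(isSubsetOf([]string{"b"}, []string{"a", "b", "c"})) // [1]
	fmt.Println(isSubsetOf([]string{"d"}, []string{"a", "b", "c"})) // []
}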
+ matches := prop.IsSubsetOf(partitionCols) + if len(matches) == 0 { + // do not satisfy the property of its parent, so return empty + return nil + } + partitionCols = choosePartitionKeys(partitionCols, matches) + } else if prop.MPPPartitionTp != property.AnyType { + return nil + } + // TODO: permute various partition columns from group-by columns + // 1-phase agg + // If there are no available partition cols, but still have group by items, that means group by items are all expressions or constants. + // To avoid mess, we don't do any one-phase aggregation in this case. + // If this is a skew distinct group agg, skip generating 1-phase agg, because skew data will cause performance issue + // + // Rollup can't be 1-phase agg: cause it will append grouping_id to the schema, and expand each row as multi rows with different grouping_id. + // In a general, group items should also append grouping_id as its group layout, let's say 1-phase agg has grouping items as , and + // lower OP can supply as original partition layout, when we insert Expand logic between them: + // --> after fill null in Expand --> and this shown two rows should be shuffled to the same node (the underlying partition is not satisfied yet) + // <1,1> in node A <1,null,gid=1> in node A + // <1,2> in node B <1,null,gid=1> in node B + if len(partitionCols) != 0 && !la.SCtx().GetSessionVars().EnableSkewDistinctAgg { + childProp := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.HashType, MPPPartitionCols: partitionCols, CanAddEnforcer: true, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus} + agg := NewPhysicalHashAgg(la, la.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), childProp) + agg.SetSchema(la.Schema().Clone()) + agg.MppRunMode = Mpp1Phase + finalAggAdjust(agg.AggFuncs) + if validMppAgg(agg) { + hashAggs = append(hashAggs, agg) + } + } + + // Final agg can't be split into multi-stage aggregate, so exit early + if hasFinalAgg { + return + } + + // 2-phase agg + // no partition property down,record partition cols inside agg itself, enforce shuffler latter. 
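For orientation, the MPP aggregation modes produced in this function differ mainly in where the partial and final aggregates run. The summary below is purely illustrative data flow, not how the planner represents these modes internally:

package main

import "fmt"

func main() {
	modes := []struct{ mode, flow string }{
		{"Mpp1Phase", "data hash-partitioned by group keys (shuffle enforced if needed) -> one complete agg per TiFlash node"},
		{"Mpp2Phase", "partial agg on every node -> shuffle by group keys -> final agg on TiFlash"},
		{"MppTiDB", "partial agg on TiFlash -> final agg on TiDB"},
		{"MppScalar", "scalar agg (no group by) with distinct/order by: all data passes through one TiFlash node at the end"},
	}
	for _, m := range modes {
		fmt.Printf("%-9s %s\n", m.mode, m.flow)
	}
}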
+ childProp := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.AnyType, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus} + agg := NewPhysicalHashAgg(la, la.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), childProp) + agg.SetSchema(la.Schema().Clone()) + agg.MppRunMode = Mpp2Phase + agg.MppPartitionCols = partitionCols + if validMppAgg(agg) { + hashAggs = append(hashAggs, agg) + } + + // agg runs on TiDB with a partial agg on TiFlash if possible + if prop.TaskTp == property.RootTaskType { + childProp := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus} + agg := NewPhysicalHashAgg(la, la.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), childProp) + agg.SetSchema(la.Schema().Clone()) + agg.MppRunMode = MppTiDB + hashAggs = append(hashAggs, agg) + } + } else if !hasFinalAgg { + // TODO: support scalar agg in MPP, merge the final result to one node + childProp := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus} + agg := NewPhysicalHashAgg(la, la.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), childProp) + agg.SetSchema(la.Schema().Clone()) + if la.HasDistinct() || la.HasOrderBy() { + // mpp scalar mode means the data will be pass through to only one tiFlash node at last. + agg.MppRunMode = MppScalar + } else { + agg.MppRunMode = MppTiDB + } + hashAggs = append(hashAggs, agg) + } + + // handle MPP Agg hints + var preferMode AggMppRunMode + var prefer bool + if la.PreferAggType&h.PreferMPP1PhaseAgg > 0 { + preferMode, prefer = Mpp1Phase, true + } else if la.PreferAggType&h.PreferMPP2PhaseAgg > 0 { + preferMode, prefer = Mpp2Phase, true + } + if prefer { + var preferPlans []base.PhysicalPlan + for _, agg := range hashAggs { + if hg, ok := agg.(*PhysicalHashAgg); ok && hg.MppRunMode == preferMode { + preferPlans = append(preferPlans, hg) + } + } + hashAggs = preferPlans + } + return +} + +// getHashAggs will generate some kinds of taskType here, which finally converted to different task plan. +// when deciding whether to add a kind of taskType, there is a rule here. [Not is Not, Yes is not Sure] +// eg: which means +// +// 1: when you find something here that block hashAgg to be pushed down to XXX, just skip adding the XXXTaskType. +// 2: when you find nothing here to block hashAgg to be pushed down to XXX, just add the XXXTaskType here. +// for 2, the final result for this physical operator enumeration is chosen or rejected is according to more factors later (hint/variable/partition/virtual-col/cost) +// +// That is to say, the non-complete positive judgement of canPushDownToMPP/canPushDownToTiFlash/canPushDownToTiKV is not that for sure here. 
+func getHashAggs(lp base.LogicalPlan, prop *property.PhysicalProperty) []base.PhysicalPlan { + la := lp.(*LogicalAggregation) + if !prop.IsSortItemEmpty() { + return nil + } + if prop.TaskTp == property.MppTaskType && !checkCanPushDownToMPP(la) { + return nil + } + hashAggs := make([]base.PhysicalPlan, 0, len(prop.GetAllPossibleChildTaskTypes())) + taskTypes := []property.TaskType{property.CopSingleReadTaskType, property.CopMultiReadTaskType} + canPushDownToTiFlash := la.CanPushToCop(kv.TiFlash) + canPushDownToMPP := canPushDownToTiFlash && la.SCtx().GetSessionVars().IsMPPAllowed() && checkCanPushDownToMPP(la) + if la.HasDistinct() { + // TODO: remove after the cost estimation of distinct pushdown is implemented. + if !la.SCtx().GetSessionVars().AllowDistinctAggPushDown || !la.CanPushToCop(kv.TiKV) { + // if variable doesn't allow DistinctAggPushDown, just produce root task type. + // if variable does allow DistinctAggPushDown, but OP itself can't be pushed down to tikv, just produce root task type. + taskTypes = []property.TaskType{property.RootTaskType} + } + } else if !la.PreferAggToCop { + taskTypes = append(taskTypes, property.RootTaskType) + } + if !la.CanPushToCop(kv.TiKV) && !canPushDownToTiFlash { + taskTypes = []property.TaskType{property.RootTaskType} + } + if canPushDownToMPP { + taskTypes = append(taskTypes, property.MppTaskType) + } else { + hasMppHints := false + var errMsg string + if la.PreferAggType&h.PreferMPP1PhaseAgg > 0 { + errMsg = "The agg can not push down to the MPP side, the MPP_1PHASE_AGG() hint is invalid" + hasMppHints = true + } + if la.PreferAggType&h.PreferMPP2PhaseAgg > 0 { + errMsg = "The agg can not push down to the MPP side, the MPP_2PHASE_AGG() hint is invalid" + hasMppHints = true + } + if hasMppHints { + la.SCtx().GetSessionVars().StmtCtx.SetHintWarning(errMsg) + } + } + if prop.IsFlashProp() { + taskTypes = []property.TaskType{prop.TaskTp} + } + + for _, taskTp := range taskTypes { + if taskTp == property.MppTaskType { + mppAggs := tryToGetMppHashAggs(la, prop) + if len(mppAggs) > 0 { + hashAggs = append(hashAggs, mppAggs...) 
+ } + } else { + agg := NewPhysicalHashAgg(la, la.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), &property.PhysicalProperty{ExpectedCnt: math.MaxFloat64, TaskTp: taskTp, CTEProducerStatus: prop.CTEProducerStatus}) + agg.SetSchema(la.Schema().Clone()) + hashAggs = append(hashAggs, agg) + } + } + return hashAggs +} + +func exhaustPhysicalPlans4LogicalSelection(p *LogicalSelection, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { + newProps := make([]*property.PhysicalProperty, 0, 2) + childProp := prop.CloneEssentialFields() + newProps = append(newProps, childProp) + + if prop.TaskTp != property.MppTaskType && + p.SCtx().GetSessionVars().IsMPPAllowed() && + p.canPushDown(kv.TiFlash) { + childPropMpp := prop.CloneEssentialFields() + childPropMpp.TaskTp = property.MppTaskType + newProps = append(newProps, childPropMpp) + } + + ret := make([]base.PhysicalPlan, 0, len(newProps)) + for _, newProp := range newProps { + sel := PhysicalSelection{ + Conditions: p.Conditions, + }.Init(p.SCtx(), p.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), p.QueryBlockOffset(), newProp) + ret = append(ret, sel) + } + return ret, true, nil +} + +func exhaustPhysicalPlans4LogicalLimit(lp base.LogicalPlan, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { + p := lp.(*logicalop.LogicalLimit) + return getLimitPhysicalPlans(p, prop) +} + +func getLimitPhysicalPlans(p *logicalop.LogicalLimit, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { + if !prop.IsSortItemEmpty() { + return nil, true, nil + } + + allTaskTypes := []property.TaskType{property.CopSingleReadTaskType, property.CopMultiReadTaskType} + if !pushLimitOrTopNForcibly(p) { + allTaskTypes = append(allTaskTypes, property.RootTaskType) + } + if p.CanPushToCop(kv.TiFlash) && p.SCtx().GetSessionVars().IsMPPAllowed() { + allTaskTypes = append(allTaskTypes, property.MppTaskType) + } + ret := make([]base.PhysicalPlan, 0, len(allTaskTypes)) + for _, tp := range allTaskTypes { + resultProp := &property.PhysicalProperty{TaskTp: tp, ExpectedCnt: float64(p.Count + p.Offset), CTEProducerStatus: prop.CTEProducerStatus} + limit := PhysicalLimit{ + Offset: p.Offset, + Count: p.Count, + PartitionBy: p.GetPartitionBy(), + }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset(), resultProp) + limit.SetSchema(p.Schema()) + ret = append(ret, limit) + } + return ret, true, nil +} + +func exhaustPhysicalPlans4LogicalLock(lp base.LogicalPlan, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { + p := lp.(*logicalop.LogicalLock) + if prop.IsFlashProp() { + p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced( + "MPP mode may be blocked because operator `Lock` is not supported now.") + return nil, true, nil + } + childProp := prop.CloneEssentialFields() + lock := PhysicalLock{ + Lock: p.Lock, + TblID2Handle: p.TblID2Handle, + TblID2PhysTblIDCol: p.TblID2PhysTblIDCol, + }.Init(p.SCtx(), p.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), childProp) + return []base.PhysicalPlan{lock}, true, nil +} + +func exhaustUnionAllPhysicalPlans(p *LogicalUnionAll, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { + // TODO: UnionAll can not pass any order, but we can change it to sort merge to keep order. + if !prop.IsSortItemEmpty() || (prop.IsFlashProp() && prop.TaskTp != property.MppTaskType) { + return nil, true, nil + } + // TODO: UnionAll can pass partition info, but for briefness, we prevent it from pushing down. 
+ if prop.TaskTp == property.MppTaskType && prop.MPPPartitionTp != property.AnyType { + return nil, true, nil + } + canUseMpp := p.SCtx().GetSessionVars().IsMPPAllowed() && canPushToCopImpl(&p.BaseLogicalPlan, kv.TiFlash, true) + chReqProps := make([]*property.PhysicalProperty, 0, p.ChildLen()) + for range p.Children() { + if canUseMpp && prop.TaskTp == property.MppTaskType { + chReqProps = append(chReqProps, &property.PhysicalProperty{ + ExpectedCnt: prop.ExpectedCnt, + TaskTp: property.MppTaskType, + RejectSort: true, + CTEProducerStatus: prop.CTEProducerStatus, + }) + } else { + chReqProps = append(chReqProps, &property.PhysicalProperty{ExpectedCnt: prop.ExpectedCnt, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus}) + } + } + ua := PhysicalUnionAll{ + mpp: canUseMpp && prop.TaskTp == property.MppTaskType, + }.Init(p.SCtx(), p.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), p.QueryBlockOffset(), chReqProps...) + ua.SetSchema(p.Schema()) + if canUseMpp && prop.TaskTp == property.RootTaskType { + chReqProps = make([]*property.PhysicalProperty, 0, p.ChildLen()) + for range p.Children() { + chReqProps = append(chReqProps, &property.PhysicalProperty{ + ExpectedCnt: prop.ExpectedCnt, + TaskTp: property.MppTaskType, + RejectSort: true, + CTEProducerStatus: prop.CTEProducerStatus, + }) + } + mppUA := PhysicalUnionAll{mpp: true}.Init(p.SCtx(), p.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), p.QueryBlockOffset(), chReqProps...) + mppUA.SetSchema(p.Schema()) + return []base.PhysicalPlan{ua, mppUA}, true, nil + } + return []base.PhysicalPlan{ua}, true, nil +} + +func exhaustPartitionUnionAllPhysicalPlans(p *LogicalPartitionUnionAll, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { + uas, flagHint, err := p.LogicalUnionAll.ExhaustPhysicalPlans(prop) + if err != nil { + return nil, false, err + } + for _, ua := range uas { + ua.(*PhysicalUnionAll).SetTP(plancodec.TypePartitionUnion) + } + return uas, flagHint, nil +} + +func exhaustPhysicalPlans4LogicalTopN(lp base.LogicalPlan, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { + lt := lp.(*logicalop.LogicalTopN) + if MatchItems(prop, lt.ByItems) { + return append(getPhysTopN(lt, prop), getPhysLimits(lt, prop)...), true, nil + } + return nil, true, nil +} + +func exhaustPhysicalPlans4LogicalSort(lp base.LogicalPlan, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { + ls := lp.(*logicalop.LogicalSort) + if prop.TaskTp == property.RootTaskType { + if MatchItems(prop, ls.ByItems) { + ret := make([]base.PhysicalPlan, 0, 2) + ret = append(ret, getPhysicalSort(ls, prop)) + ns := getNominalSort(ls, prop) + if ns != nil { + ret = append(ret, ns) + } + return ret, true, nil + } + } else if prop.TaskTp == property.MppTaskType && prop.RejectSort { + if canPushToCopImpl(&ls.BaseLogicalPlan, kv.TiFlash, true) { + ps := getNominalSortSimple(ls, prop) + return []base.PhysicalPlan{ps}, true, nil + } + } + return nil, true, nil +} + +func getPhysicalSort(ls *logicalop.LogicalSort, prop *property.PhysicalProperty) base.PhysicalPlan { + ps := PhysicalSort{ByItems: ls.ByItems}.Init(ls.SCtx(), ls.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), ls.QueryBlockOffset(), &property.PhysicalProperty{TaskTp: prop.TaskTp, ExpectedCnt: math.MaxFloat64, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus}) + return ps +} + +func getNominalSort(ls *logicalop.LogicalSort, reqProp *property.PhysicalProperty) *NominalSort { + prop, canPass, onlyColumn := GetPropByOrderByItemsContainScalarFunc(ls.ByItems) + 
if !canPass { + return nil + } + prop.RejectSort = true + prop.ExpectedCnt = reqProp.ExpectedCnt + ps := NominalSort{OnlyColumn: onlyColumn, ByItems: ls.ByItems}.Init( + ls.SCtx(), ls.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), ls.QueryBlockOffset(), prop) + return ps +} + +func getNominalSortSimple(ls *logicalop.LogicalSort, reqProp *property.PhysicalProperty) *NominalSort { + newProp := reqProp.CloneEssentialFields() + newProp.RejectSort = true + ps := NominalSort{OnlyColumn: true, ByItems: ls.ByItems}.Init( + ls.SCtx(), ls.StatsInfo().ScaleByExpectCnt(reqProp.ExpectedCnt), ls.QueryBlockOffset(), newProp) + return ps +} + +func exhaustPhysicalPlans4LogicalMaxOneRow(lp base.LogicalPlan, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { + p := lp.(*logicalop.LogicalMaxOneRow) + if !prop.IsSortItemEmpty() || prop.IsFlashProp() { + p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because operator `MaxOneRow` is not supported now.") + return nil, true, nil + } + mor := PhysicalMaxOneRow{}.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset(), &property.PhysicalProperty{ExpectedCnt: 2, CTEProducerStatus: prop.CTEProducerStatus}) + return []base.PhysicalPlan{mor}, true, nil +} + +func exhaustPhysicalPlans4LogicalCTE(p *LogicalCTE, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { + pcte := PhysicalCTE{CTE: p.Cte}.Init(p.SCtx(), p.StatsInfo()) + if prop.IsFlashProp() { + pcte.storageSender = PhysicalExchangeSender{ + ExchangeType: tipb.ExchangeType_Broadcast, + }.Init(p.SCtx(), p.StatsInfo()) + } + pcte.SetSchema(p.Schema()) + pcte.childrenReqProps = []*property.PhysicalProperty{prop.CloneEssentialFields()} + return []base.PhysicalPlan{(*PhysicalCTEStorage)(pcte)}, true, nil +} + +func exhaustPhysicalPlans4LogicalSequence(lp base.LogicalPlan, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { + p := lp.(*logicalop.LogicalSequence) + possibleChildrenProps := make([][]*property.PhysicalProperty, 0, 2) + anyType := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.AnyType, CanAddEnforcer: true, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus} + if prop.TaskTp == property.MppTaskType { + if prop.CTEProducerStatus == property.SomeCTEFailedMpp { + return nil, true, nil + } + anyType.CTEProducerStatus = property.AllCTECanMpp + possibleChildrenProps = append(possibleChildrenProps, []*property.PhysicalProperty{anyType, prop.CloneEssentialFields()}) + } else { + copied := prop.CloneEssentialFields() + copied.CTEProducerStatus = property.SomeCTEFailedMpp + possibleChildrenProps = append(possibleChildrenProps, []*property.PhysicalProperty{{TaskTp: property.RootTaskType, ExpectedCnt: math.MaxFloat64, CTEProducerStatus: property.SomeCTEFailedMpp}, copied}) + } + + if prop.TaskTp != property.MppTaskType && prop.CTEProducerStatus != property.SomeCTEFailedMpp && + p.SCtx().GetSessionVars().IsMPPAllowed() && prop.IsSortItemEmpty() { + possibleChildrenProps = append(possibleChildrenProps, []*property.PhysicalProperty{anyType, anyType.CloneEssentialFields()}) + } + seqs := make([]base.PhysicalPlan, 0, 2) + for _, propChoice := range possibleChildrenProps { + childReqs := make([]*property.PhysicalProperty, 0, p.ChildLen()) + for i := 0; i < p.ChildLen()-1; i++ { + childReqs = append(childReqs, propChoice[0].CloneEssentialFields()) + } + childReqs = append(childReqs, propChoice[1]) + seq := PhysicalSequence{}.Init(p.SCtx(), p.StatsInfo(), 
p.QueryBlockOffset(), childReqs...) + seq.SetSchema(p.Children()[p.ChildLen()-1].Schema()) + seqs = append(seqs, seq) + } + return seqs, true, nil +} diff --git a/pkg/planner/core/find_best_task.go b/pkg/planner/core/find_best_task.go index 8e819dee9c6e8..340e5f5a77c6c 100644 --- a/pkg/planner/core/find_best_task.go +++ b/pkg/planner/core/find_best_task.go @@ -1558,13 +1558,13 @@ func (ds *DataSource) convertToIndexMergeScan(prop *property.PhysicalProperty, c if !prop.IsSortItemEmpty() && candidate.path.IndexMergeIsIntersection { return base.InvalidTask, nil } - failpoint.Inject("forceIndexMergeKeepOrder", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("forceIndexMergeKeepOrder")); _err_ == nil { if len(candidate.path.PartialIndexPaths) > 0 && !candidate.path.IndexMergeIsIntersection { if prop.IsSortItemEmpty() { - failpoint.Return(base.InvalidTask, nil) + return base.InvalidTask, nil } } - }) + } path := candidate.path scans := make([]base.PhysicalPlan, 0, len(path.PartialIndexPaths)) cop := &CopTask{ diff --git a/pkg/planner/core/find_best_task.go__failpoint_stash__ b/pkg/planner/core/find_best_task.go__failpoint_stash__ new file mode 100644 index 0000000000000..8e819dee9c6e8 --- /dev/null +++ b/pkg/planner/core/find_best_task.go__failpoint_stash__ @@ -0,0 +1,2982 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package core + +import ( + "cmp" + "fmt" + "math" + "slices" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/cardinality" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/core/cost" + "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" + "github.com/pingcap/tidb/pkg/planner/property" + "github.com/pingcap/tidb/pkg/planner/util" + "github.com/pingcap/tidb/pkg/planner/util/fixcontrol" + "github.com/pingcap/tidb/pkg/planner/util/optimizetrace" + "github.com/pingcap/tidb/pkg/planner/util/utilfuncp" + "github.com/pingcap/tidb/pkg/statistics" + "github.com/pingcap/tidb/pkg/types" + tidbutil "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/collate" + h "github.com/pingcap/tidb/pkg/util/hint" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/ranger" + "github.com/pingcap/tidb/pkg/util/tracing" + "go.uber.org/zap" +) + +// PlanCounterDisabled is the default value of PlanCounterTp, indicating that optimizer needn't force a plan. +var PlanCounterDisabled base.PlanCounterTp = -1 + +// GetPropByOrderByItems will check if this sort property can be pushed or not. In order to simplify the problem, we only +// consider the case that all expression are columns. 
+func GetPropByOrderByItems(items []*util.ByItems) (*property.PhysicalProperty, bool) { + propItems := make([]property.SortItem, 0, len(items)) + for _, item := range items { + col, ok := item.Expr.(*expression.Column) + if !ok { + return nil, false + } + propItems = append(propItems, property.SortItem{Col: col, Desc: item.Desc}) + } + return &property.PhysicalProperty{SortItems: propItems}, true +} + +// GetPropByOrderByItemsContainScalarFunc will check if this sort property can be pushed or not. In order to simplify the +// problem, we only consider the case that all expression are columns or some special scalar functions. +func GetPropByOrderByItemsContainScalarFunc(items []*util.ByItems) (*property.PhysicalProperty, bool, bool) { + propItems := make([]property.SortItem, 0, len(items)) + onlyColumn := true + for _, item := range items { + switch expr := item.Expr.(type) { + case *expression.Column: + propItems = append(propItems, property.SortItem{Col: expr, Desc: item.Desc}) + case *expression.ScalarFunction: + col, desc := expr.GetSingleColumn(item.Desc) + if col == nil { + return nil, false, false + } + propItems = append(propItems, property.SortItem{Col: col, Desc: desc}) + onlyColumn = false + default: + return nil, false, false + } + } + return &property.PhysicalProperty{SortItems: propItems}, true, onlyColumn +} + +func findBestTask4LogicalTableDual(lp base.LogicalPlan, prop *property.PhysicalProperty, planCounter *base.PlanCounterTp, opt *optimizetrace.PhysicalOptimizeOp) (base.Task, int64, error) { + p := lp.(*logicalop.LogicalTableDual) + // If the required property is not empty and the row count > 1, + // we cannot ensure this required property. + // But if the row count is 0 or 1, we don't need to care about the property. + if (!prop.IsSortItemEmpty() && p.RowCount > 1) || planCounter.Empty() { + return base.InvalidTask, 0, nil + } + dual := PhysicalTableDual{ + RowCount: p.RowCount, + }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset()) + dual.SetSchema(p.Schema()) + planCounter.Dec(1) + utilfuncp.AppendCandidate4PhysicalOptimizeOp(opt, p, dual, prop) + rt := &RootTask{} + rt.SetPlan(dual) + rt.SetEmpty(p.RowCount == 0) + return rt, 1, nil +} + +func findBestTask4LogicalShow(lp base.LogicalPlan, prop *property.PhysicalProperty, planCounter *base.PlanCounterTp, _ *optimizetrace.PhysicalOptimizeOp) (base.Task, int64, error) { + p := lp.(*logicalop.LogicalShow) + if !prop.IsSortItemEmpty() || planCounter.Empty() { + return base.InvalidTask, 0, nil + } + pShow := PhysicalShow{ShowContents: p.ShowContents, Extractor: p.Extractor}.Init(p.SCtx()) + pShow.SetSchema(p.Schema()) + planCounter.Dec(1) + rt := &RootTask{} + rt.SetPlan(pShow) + return rt, 1, nil +} + +func findBestTask4LogicalShowDDLJobs(lp base.LogicalPlan, prop *property.PhysicalProperty, planCounter *base.PlanCounterTp, _ *optimizetrace.PhysicalOptimizeOp) (base.Task, int64, error) { + p := lp.(*logicalop.LogicalShowDDLJobs) + if !prop.IsSortItemEmpty() || planCounter.Empty() { + return base.InvalidTask, 0, nil + } + pShow := PhysicalShowDDLJobs{JobNumber: p.JobNumber}.Init(p.SCtx()) + pShow.SetSchema(p.Schema()) + planCounter.Dec(1) + rt := &RootTask{} + rt.SetPlan(pShow) + return rt, 1, nil +} + +// rebuildChildTasks rebuilds the childTasks to make the clock_th combination. 
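+// planCounter identifies one combination of child plans: treating the children's plan counts as
+// mixed-radix digits, each child's share of the counter is peeled off with div/mod below, so the
+// planCounter-th combination can be rebuilt deterministically after the children's task maps are rolled back.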
+func rebuildChildTasks(p *logicalop.BaseLogicalPlan, childTasks *[]base.Task, pp base.PhysicalPlan, childCnts []int64, planCounter int64, ts uint64, opt *optimizetrace.PhysicalOptimizeOp) error { + // The taskMap of children nodes should be rolled back first. + for _, child := range p.Children() { + child.RollBackTaskMap(ts) + } + + multAll := int64(1) + var curClock base.PlanCounterTp + for _, x := range childCnts { + multAll *= x + } + *childTasks = (*childTasks)[:0] + for j, child := range p.Children() { + multAll /= childCnts[j] + curClock = base.PlanCounterTp((planCounter-1)/multAll + 1) + childTask, _, err := child.FindBestTask(pp.GetChildReqProps(j), &curClock, opt) + planCounter = (planCounter-1)%multAll + 1 + if err != nil { + return err + } + if curClock != 0 { + return errors.Errorf("PlanCounterTp planCounter is not handled") + } + if childTask != nil && childTask.Invalid() { + return errors.Errorf("The current plan is invalid, please skip this plan") + } + *childTasks = append(*childTasks, childTask) + } + return nil +} + +func enumeratePhysicalPlans4Task( + p *logicalop.BaseLogicalPlan, + physicalPlans []base.PhysicalPlan, + prop *property.PhysicalProperty, + addEnforcer bool, + planCounter *base.PlanCounterTp, + opt *optimizetrace.PhysicalOptimizeOp, +) (base.Task, int64, error) { + var bestTask base.Task = base.InvalidTask + var curCntPlan, cntPlan int64 + var err error + childTasks := make([]base.Task, 0, p.ChildLen()) + childCnts := make([]int64, p.ChildLen()) + cntPlan = 0 + iteration := iteratePhysicalPlan4BaseLogical + if _, ok := p.Self().(*logicalop.LogicalSequence); ok { + iteration = iterateChildPlan4LogicalSequence + } + + for _, pp := range physicalPlans { + timeStampNow := p.GetLogicalTS4TaskMap() + savedPlanID := p.SCtx().GetSessionVars().PlanID.Load() + + childTasks, curCntPlan, childCnts, err = iteration(p, pp, childTasks, childCnts, prop, opt) + if err != nil { + return nil, 0, err + } + + // This check makes sure that there is no invalid child task. + if len(childTasks) != p.ChildLen() { + continue + } + + // If the target plan can be found in this physicalPlan(pp), rebuild childTasks to build the corresponding combination. + if planCounter.IsForce() && int64(*planCounter) <= curCntPlan { + p.SCtx().GetSessionVars().PlanID.Store(savedPlanID) + curCntPlan = int64(*planCounter) + err := rebuildChildTasks(p, &childTasks, pp, childCnts, int64(*planCounter), timeStampNow, opt) + if err != nil { + return nil, 0, err + } + } + + // Combine the best child tasks with parent physical plan. + curTask := pp.Attach2Task(childTasks...) + if curTask.Invalid() { + continue + } + + // An optimal task could not satisfy the property, so it should be converted here. + if _, ok := curTask.(*RootTask); !ok && prop.TaskTp == property.RootTaskType { + curTask = curTask.ConvertToRootTask(p.SCtx()) + } + + // Enforce curTask property + if addEnforcer { + curTask = enforceProperty(prop, curTask, p.Plan.SCtx()) + } + + // Optimize by shuffle executor to running in parallel manner. + if _, isMpp := curTask.(*MppTask); !isMpp && prop.IsSortItemEmpty() { + // Currently, we do not regard shuffled plan as a new plan. + curTask = optimizeByShuffle(curTask, p.Plan.SCtx()) + } + + cntPlan += curCntPlan + planCounter.Dec(curCntPlan) + + if planCounter.Empty() { + bestTask = curTask + break + } + utilfuncp.AppendCandidate4PhysicalOptimizeOp(opt, p, curTask.Plan(), prop) + // Get the most efficient one. 
+ if curIsBetter, err := compareTaskCost(curTask, bestTask, opt); err != nil { + return nil, 0, err + } else if curIsBetter { + bestTask = curTask + } + } + return bestTask, cntPlan, nil +} + +// iteratePhysicalPlan4BaseLogical is used to iterate the physical plan and get all child tasks. +func iteratePhysicalPlan4BaseLogical( + p *logicalop.BaseLogicalPlan, + selfPhysicalPlan base.PhysicalPlan, + childTasks []base.Task, + childCnts []int64, + _ *property.PhysicalProperty, + opt *optimizetrace.PhysicalOptimizeOp, +) ([]base.Task, int64, []int64, error) { + // Find best child tasks firstly. + childTasks = childTasks[:0] + // The curCntPlan records the number of possible plans for pp + curCntPlan := int64(1) + for j, child := range p.Children() { + childProp := selfPhysicalPlan.GetChildReqProps(j) + childTask, cnt, err := child.FindBestTask(childProp, &PlanCounterDisabled, opt) + childCnts[j] = cnt + if err != nil { + return nil, 0, childCnts, err + } + curCntPlan = curCntPlan * cnt + if childTask != nil && childTask.Invalid() { + return nil, 0, childCnts, nil + } + childTasks = append(childTasks, childTask) + } + + // This check makes sure that there is no invalid child task. + if len(childTasks) != p.ChildLen() { + return nil, 0, childCnts, nil + } + return childTasks, curCntPlan, childCnts, nil +} + +// iterateChildPlan4LogicalSequence does the special part for sequence. We need to iterate its child one by one to check whether the former child is a valid plan and then go to the nex +func iterateChildPlan4LogicalSequence( + p *logicalop.BaseLogicalPlan, + selfPhysicalPlan base.PhysicalPlan, + childTasks []base.Task, + childCnts []int64, + prop *property.PhysicalProperty, + opt *optimizetrace.PhysicalOptimizeOp, +) ([]base.Task, int64, []int64, error) { + // Find best child tasks firstly. + childTasks = childTasks[:0] + // The curCntPlan records the number of possible plans for pp + curCntPlan := int64(1) + lastIdx := p.ChildLen() - 1 + for j := 0; j < lastIdx; j++ { + child := p.Children()[j] + childProp := selfPhysicalPlan.GetChildReqProps(j) + childTask, cnt, err := child.FindBestTask(childProp, &PlanCounterDisabled, opt) + childCnts[j] = cnt + if err != nil { + return nil, 0, nil, err + } + curCntPlan = curCntPlan * cnt + if childTask != nil && childTask.Invalid() { + return nil, 0, nil, nil + } + _, isMpp := childTask.(*MppTask) + if !isMpp && prop.IsFlashProp() { + break + } + childTasks = append(childTasks, childTask) + } + // This check makes sure that there is no invalid child task. + if len(childTasks) != p.ChildLen()-1 { + return nil, 0, nil, nil + } + + lastChildProp := selfPhysicalPlan.GetChildReqProps(lastIdx).CloneEssentialFields() + if lastChildProp.IsFlashProp() { + lastChildProp.CTEProducerStatus = property.AllCTECanMpp + } + lastChildTask, cnt, err := p.Children()[lastIdx].FindBestTask(lastChildProp, &PlanCounterDisabled, opt) + childCnts[lastIdx] = cnt + if err != nil { + return nil, 0, nil, err + } + curCntPlan = curCntPlan * cnt + if lastChildTask != nil && lastChildTask.Invalid() { + return nil, 0, nil, nil + } + + if _, ok := lastChildTask.(*MppTask); !ok && lastChildProp.CTEProducerStatus == property.AllCTECanMpp { + return nil, 0, nil, nil + } + + childTasks = append(childTasks, lastChildTask) + return childTasks, curCntPlan, childCnts, nil +} + +// compareTaskCost compares cost of curTask and bestTask and returns whether curTask's cost is smaller than bestTask's. 
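+// Invalid tasks are treated as infinitely expensive: an invalid curTask never wins, and any valid
+// curTask beats an invalid bestTask.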
+func compareTaskCost(curTask, bestTask base.Task, op *optimizetrace.PhysicalOptimizeOp) (curIsBetter bool, err error) { + curCost, curInvalid, err := utilfuncp.GetTaskPlanCost(curTask, op) + if err != nil { + return false, err + } + bestCost, bestInvalid, err := utilfuncp.GetTaskPlanCost(bestTask, op) + if err != nil { + return false, err + } + if curInvalid { + return false, nil + } + if bestInvalid { + return true, nil + } + return curCost < bestCost, nil +} + +// getTaskPlanCost returns the cost of this task. +// The new cost interface will be used if EnableNewCostInterface is true. +// The second returned value indicates whether this task is valid. +func getTaskPlanCost(t base.Task, pop *optimizetrace.PhysicalOptimizeOp) (float64, bool, error) { + if t.Invalid() { + return math.MaxFloat64, true, nil + } + + // use the new cost interface + var ( + taskType property.TaskType + indexPartialCost float64 + ) + switch t.(type) { + case *RootTask: + taskType = property.RootTaskType + case *CopTask: // no need to know whether the task is single-read or double-read, so both CopSingleReadTaskType and CopDoubleReadTaskType are OK + cop := t.(*CopTask) + if cop.indexPlan != nil && cop.tablePlan != nil { // handle IndexLookup specially + taskType = property.CopMultiReadTaskType + // keep compatible with the old cost interface, for CopMultiReadTask, the cost is idxCost + tblCost. + if !cop.indexPlanFinished { // only consider index cost in this case + idxCost, err := getPlanCost(cop.indexPlan, taskType, optimizetrace.NewDefaultPlanCostOption().WithOptimizeTracer(pop)) + return idxCost, false, err + } + // consider both sides + idxCost, err := getPlanCost(cop.indexPlan, taskType, optimizetrace.NewDefaultPlanCostOption().WithOptimizeTracer(pop)) + if err != nil { + return 0, false, err + } + tblCost, err := getPlanCost(cop.tablePlan, taskType, optimizetrace.NewDefaultPlanCostOption().WithOptimizeTracer(pop)) + if err != nil { + return 0, false, err + } + return idxCost + tblCost, false, nil + } + + taskType = property.CopSingleReadTaskType + + // TiFlash can run cop task as well, check whether this cop task will run on TiKV or TiFlash. + if cop.tablePlan != nil { + leafNode := cop.tablePlan + for len(leafNode.Children()) > 0 { + leafNode = leafNode.Children()[0] + } + if tblScan, isScan := leafNode.(*PhysicalTableScan); isScan && tblScan.StoreType == kv.TiFlash { + taskType = property.MppTaskType + } + } + + // Detail reason ref about comment in function `convertToIndexMergeScan` + // for cop task with {indexPlan=nil, tablePlan=xxx, idxMergePartPlans=[x,x,x], indexPlanFinished=true} we should + // plus the partial index plan cost into the final cost. Because t.plan() the below code used only calculate the + // cost about table plan. + if cop.indexPlanFinished && len(cop.idxMergePartPlans) != 0 { + for _, partialScan := range cop.idxMergePartPlans { + partialCost, err := getPlanCost(partialScan, taskType, optimizetrace.NewDefaultPlanCostOption().WithOptimizeTracer(pop)) + if err != nil { + return 0, false, err + } + indexPartialCost += partialCost + } + } + case *MppTask: + taskType = property.MppTaskType + default: + return 0, false, errors.New("unknown task type") + } + if t.Plan() == nil { + // It's a very special case for index merge case. + // t.plan() == nil in index merge COP case, it means indexPlanFinished is false in other words. 
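+ // In that case the task's cost is simply the sum of its partial index scans' costs.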
+ cost := 0.0 + copTsk := t.(*CopTask) + for _, partialScan := range copTsk.idxMergePartPlans { + partialCost, err := getPlanCost(partialScan, taskType, optimizetrace.NewDefaultPlanCostOption().WithOptimizeTracer(pop)) + if err != nil { + return 0, false, err + } + cost += partialCost + } + return cost, false, nil + } + cost, err := getPlanCost(t.Plan(), taskType, optimizetrace.NewDefaultPlanCostOption().WithOptimizeTracer(pop)) + return cost + indexPartialCost, false, err +} + +func appendCandidate4PhysicalOptimizeOp(pop *optimizetrace.PhysicalOptimizeOp, lp base.LogicalPlan, pp base.PhysicalPlan, prop *property.PhysicalProperty) { + if pop == nil || pop.GetTracer() == nil || pp == nil { + return + } + candidate := &tracing.CandidatePlanTrace{ + PlanTrace: &tracing.PlanTrace{TP: pp.TP(), ID: pp.ID(), + ExplainInfo: pp.ExplainInfo(), ProperType: prop.String()}, + MappingLogicalPlan: tracing.CodecPlanName(lp.TP(), lp.ID())} + pop.GetTracer().AppendCandidate(candidate) + + // for PhysicalIndexMergeJoin/PhysicalIndexHashJoin/PhysicalIndexJoin, it will use innerTask as a child instead of calling findBestTask, + // and innerTask.plan() will be appended to planTree in appendChildCandidate using empty MappingLogicalPlan field, so it won't mapping with the logic plan, + // that will cause no physical plan when the logic plan got selected. + // the fix to add innerTask.plan() to planTree and mapping correct logic plan + index := -1 + var plan base.PhysicalPlan + switch join := pp.(type) { + case *PhysicalIndexMergeJoin: + index = join.InnerChildIdx + plan = join.innerPlan + case *PhysicalIndexHashJoin: + index = join.InnerChildIdx + plan = join.innerPlan + case *PhysicalIndexJoin: + index = join.InnerChildIdx + plan = join.innerPlan + } + if index != -1 { + child := lp.(*logicalop.BaseLogicalPlan).Children()[index] + candidate := &tracing.CandidatePlanTrace{ + PlanTrace: &tracing.PlanTrace{TP: plan.TP(), ID: plan.ID(), + ExplainInfo: plan.ExplainInfo(), ProperType: prop.String()}, + MappingLogicalPlan: tracing.CodecPlanName(child.TP(), child.ID())} + pop.GetTracer().AppendCandidate(candidate) + } + pp.AppendChildCandidate(pop) +} + +func appendPlanCostDetail4PhysicalOptimizeOp(pop *optimizetrace.PhysicalOptimizeOp, detail *tracing.PhysicalPlanCostDetail) { + if pop == nil || pop.GetTracer() == nil { + return + } + pop.GetTracer().PhysicalPlanCostDetails[fmt.Sprintf("%v_%v", detail.GetPlanType(), detail.GetPlanID())] = detail +} + +// findBestTask is key workflow that drive logic plan tree to generate optimal physical ones. +// The logic inside it is mainly about physical plan numeration and task encapsulation, it should +// be defined in core pkg, and be called by logic plan in their logic interface implementation. +func findBestTask(lp base.LogicalPlan, prop *property.PhysicalProperty, planCounter *base.PlanCounterTp, + opt *optimizetrace.PhysicalOptimizeOp) (bestTask base.Task, cntPlan int64, err error) { + p := lp.GetBaseLogicalPlan().(*logicalop.BaseLogicalPlan) + // If p is an inner plan in an IndexJoin, the IndexJoin will generate an inner plan by itself, + // and set inner child prop nil, so here we do nothing. + if prop == nil { + return nil, 1, nil + } + // Look up the task with this prop in the task map. + // It's used to reduce double counting. 
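+ // A cache hit counts as exactly one plan: the counter is decremented once and the stored task is reused.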
+ bestTask = p.GetTask(prop) + if bestTask != nil { + planCounter.Dec(1) + return bestTask, 1, nil + } + + canAddEnforcer := prop.CanAddEnforcer + + if prop.TaskTp != property.RootTaskType && !prop.IsFlashProp() { + // Currently all plan cannot totally push down to TiKV. + p.StoreTask(prop, base.InvalidTask) + return base.InvalidTask, 0, nil + } + + cntPlan = 0 + // prop should be read only because its cached hashcode might be not consistent + // when it is changed. So we clone a new one for the temporary changes. + newProp := prop.CloneEssentialFields() + var plansFitsProp, plansNeedEnforce []base.PhysicalPlan + var hintWorksWithProp bool + // Maybe the plan can satisfy the required property, + // so we try to get the task without the enforced sort first. + plansFitsProp, hintWorksWithProp, err = p.Self().ExhaustPhysicalPlans(newProp) + if err != nil { + return nil, 0, err + } + if !hintWorksWithProp && !newProp.IsSortItemEmpty() { + // If there is a hint in the plan and the hint cannot satisfy the property, + // we enforce this property and try to generate the PhysicalPlan again to + // make sure the hint can work. + canAddEnforcer = true + } + + if canAddEnforcer { + // Then, we use the empty property to get physicalPlans and + // try to get the task with an enforced sort. + newProp.SortItems = []property.SortItem{} + newProp.SortItemsForPartition = []property.SortItem{} + newProp.ExpectedCnt = math.MaxFloat64 + newProp.MPPPartitionCols = nil + newProp.MPPPartitionTp = property.AnyType + var hintCanWork bool + plansNeedEnforce, hintCanWork, err = p.Self().ExhaustPhysicalPlans(newProp) + if err != nil { + return nil, 0, err + } + if hintCanWork && !hintWorksWithProp { + // If the hint can work with the empty property, but cannot work with + // the required property, we give up `plansFitProp` to make sure the hint + // can work. + plansFitsProp = nil + } + if !hintCanWork && !hintWorksWithProp && !prop.CanAddEnforcer { + // If the original property is not enforced and hint cannot + // work anyway, we give up `plansNeedEnforce` for efficiency, + plansNeedEnforce = nil + } + newProp = prop + } + + var cnt int64 + var curTask base.Task + if bestTask, cnt, err = enumeratePhysicalPlans4Task(p, plansFitsProp, newProp, false, planCounter, opt); err != nil { + return nil, 0, err + } + cntPlan += cnt + if planCounter.Empty() { + goto END + } + + curTask, cnt, err = enumeratePhysicalPlans4Task(p, plansNeedEnforce, newProp, true, planCounter, opt) + if err != nil { + return nil, 0, err + } + cntPlan += cnt + if planCounter.Empty() { + bestTask = curTask + goto END + } + utilfuncp.AppendCandidate4PhysicalOptimizeOp(opt, p, curTask.Plan(), prop) + if curIsBetter, err := compareTaskCost(curTask, bestTask, opt); err != nil { + return nil, 0, err + } else if curIsBetter { + bestTask = curTask + } + +END: + p.StoreTask(prop, bestTask) + return bestTask, cntPlan, nil +} + +func findBestTask4LogicalMemTable(lp base.LogicalPlan, prop *property.PhysicalProperty, planCounter *base.PlanCounterTp, opt *optimizetrace.PhysicalOptimizeOp) (t base.Task, cntPlan int64, err error) { + p := lp.(*logicalop.LogicalMemTable) + if prop.MPPPartitionTp != property.AnyType { + return base.InvalidTask, 0, nil + } + + // If prop.CanAddEnforcer is true, the prop.SortItems need to be set nil for p.findBestTask. + // Before function return, reset it for enforcing task prop. 
+ oldProp := prop.CloneEssentialFields() + if prop.CanAddEnforcer { + // First, get the bestTask without enforced prop + prop.CanAddEnforcer = false + cnt := int64(0) + t, cnt, err = p.FindBestTask(prop, planCounter, opt) + if err != nil { + return nil, 0, err + } + prop.CanAddEnforcer = true + if t != base.InvalidTask { + cntPlan = cnt + return + } + // Next, get the bestTask with enforced prop + prop.SortItems = []property.SortItem{} + } + defer func() { + if err != nil { + return + } + if prop.CanAddEnforcer { + *prop = *oldProp + t = enforceProperty(prop, t, p.Plan.SCtx()) + prop.CanAddEnforcer = true + } + }() + + if !prop.IsSortItemEmpty() || planCounter.Empty() { + return base.InvalidTask, 0, nil + } + memTable := PhysicalMemTable{ + DBName: p.DBName, + Table: p.TableInfo, + Columns: p.Columns, + Extractor: p.Extractor, + QueryTimeRange: p.QueryTimeRange, + }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset()) + memTable.SetSchema(p.Schema()) + planCounter.Dec(1) + utilfuncp.AppendCandidate4PhysicalOptimizeOp(opt, p, memTable, prop) + rt := &RootTask{} + rt.SetPlan(memTable) + return rt, 1, nil +} + +// tryToGetDualTask will check if the push down predicate has false constant. If so, it will return table dual. +func (ds *DataSource) tryToGetDualTask() (base.Task, error) { + for _, cond := range ds.PushedDownConds { + if con, ok := cond.(*expression.Constant); ok && con.DeferredExpr == nil && con.ParamMarker == nil { + result, _, err := expression.EvalBool(ds.SCtx().GetExprCtx().GetEvalCtx(), []expression.Expression{cond}, chunk.Row{}) + if err != nil { + return nil, err + } + if !result { + dual := PhysicalTableDual{}.Init(ds.SCtx(), ds.StatsInfo(), ds.QueryBlockOffset()) + dual.SetSchema(ds.Schema()) + rt := &RootTask{} + rt.SetPlan(dual) + return rt, nil + } + } + } + return nil, nil +} + +// candidatePath is used to maintain required info for skyline pruning. +type candidatePath struct { + path *util.AccessPath + accessCondsColMap util.Col2Len // accessCondsColMap maps Column.UniqueID to column length for the columns in AccessConds. + indexCondsColMap util.Col2Len // indexCondsColMap maps Column.UniqueID to column length for the columns in AccessConds and indexFilters. + isMatchProp bool +} + +func compareBool(l, r bool) int { + if l == r { + return 0 + } + if !l { + return -1 + } + return 1 +} + +func compareIndexBack(lhs, rhs *candidatePath) (int, bool) { + result := compareBool(lhs.path.IsSingleScan, rhs.path.IsSingleScan) + if result == 0 && !lhs.path.IsSingleScan { + // if both lhs and rhs need to access table after IndexScan, we utilize the set of columns that occurred in AccessConds and IndexFilters + // to compare how many table rows will be accessed. + return util.CompareCol2Len(lhs.indexCondsColMap, rhs.indexCondsColMap) + } + return result, true +} + +// compareCandidates is the core of skyline pruning, which is used to decide which candidate path is better. +// The return value is 1 if lhs is better, -1 if rhs is better, 0 if they are equivalent or not comparable. +func compareCandidates(sctx base.PlanContext, prop *property.PhysicalProperty, lhs, rhs *candidatePath) int { + // Due to #50125, full scan on MVIndex has been disabled, so MVIndex path might lead to 'can't find a proper plan' error at the end. + // Avoid MVIndex path to exclude all other paths and leading to 'can't find a proper plan' error, see #49438 for an example. + if isMVIndexPath(lhs.path) || isMVIndexPath(rhs.path) { + return 0 + } + + // This rule is empirical but not always correct. 
+ // If x's range row count is significantly lower than y's, for example, 1000 times, we think x is better. + if lhs.path.CountAfterAccess > 100 && rhs.path.CountAfterAccess > 100 && // to prevent some extreme cases, e.g. 0.01 : 10 + len(lhs.path.PartialIndexPaths) == 0 && len(rhs.path.PartialIndexPaths) == 0 && // not IndexMerge since its row count estimation is not accurate enough + prop.ExpectedCnt == math.MaxFloat64 { // Limit may affect access row count + threshold := float64(fixcontrol.GetIntWithDefault(sctx.GetSessionVars().OptimizerFixControl, fixcontrol.Fix45132, 1000)) + if threshold > 0 { // set it to 0 to disable this rule + if lhs.path.CountAfterAccess/rhs.path.CountAfterAccess > threshold { + return -1 + } + if rhs.path.CountAfterAccess/lhs.path.CountAfterAccess > threshold { + return 1 + } + } + } + + // Below compares the two candidate paths on three dimensions: + // (1): the set of columns that occurred in the access condition, + // (2): does it require a double scan, + // (3): whether or not it matches the physical property. + // If `x` is not worse than `y` at all factors, + // and there exists one factor that `x` is better than `y`, then `x` is better than `y`. + accessResult, comparable1 := util.CompareCol2Len(lhs.accessCondsColMap, rhs.accessCondsColMap) + if !comparable1 { + return 0 + } + scanResult, comparable2 := compareIndexBack(lhs, rhs) + if !comparable2 { + return 0 + } + matchResult := compareBool(lhs.isMatchProp, rhs.isMatchProp) + sum := accessResult + scanResult + matchResult + if accessResult >= 0 && scanResult >= 0 && matchResult >= 0 && sum > 0 { + return 1 + } + if accessResult <= 0 && scanResult <= 0 && matchResult <= 0 && sum < 0 { + return -1 + } + return 0 +} + +func (ds *DataSource) isMatchProp(path *util.AccessPath, prop *property.PhysicalProperty) bool { + var isMatchProp bool + if path.IsIntHandlePath { + pkCol := ds.getPKIsHandleCol() + if len(prop.SortItems) == 1 && pkCol != nil { + isMatchProp = prop.SortItems[0].Col.EqualColumn(pkCol) + if path.StoreType == kv.TiFlash { + isMatchProp = isMatchProp && !prop.SortItems[0].Desc + } + } + return isMatchProp + } + all, _ := prop.AllSameOrder() + // When the prop is empty or `all` is false, `isMatchProp` is better to be `false` because + // it needs not to keep order for index scan. + + // Basically, if `prop.SortItems` is the prefix of `path.IdxCols`, then `isMatchProp` is true. However, we need to consider + // the situations when some columns of `path.IdxCols` are evaluated as constant. For example: + // ``` + // create table t(a int, b int, c int, d int, index idx_a_b_c(a, b, c), index idx_d_c_b_a(d, c, b, a)); + // select * from t where a = 1 order by b, c; + // select * from t where b = 1 order by a, c; + // select * from t where d = 1 and b = 2 order by c, a; + // select * from t where d = 1 and b = 2 order by c, b, a; + // ``` + // In the first two `SELECT` statements, `idx_a_b_c` matches the sort order. In the last two `SELECT` statements, `idx_d_c_b_a` + // matches the sort order. Hence, we use `path.ConstCols` to deal with the above situations. 
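+ // For example, for `idx_d_c_b_a(d, c, b, a)` with `d = 1 and b = 2 order by c, a`, the loop below
+ // matches `c` at position 1 (position 0 `d` is a constant column), skips the constant `b`, and then
+ // matches `a`, so the index satisfies the required order.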
+ if !prop.IsSortItemEmpty() && all && len(path.IdxCols) >= len(prop.SortItems) { + isMatchProp = true + i := 0 + for _, sortItem := range prop.SortItems { + found := false + for ; i < len(path.IdxCols); i++ { + if path.IdxColLens[i] == types.UnspecifiedLength && sortItem.Col.EqualColumn(path.IdxCols[i]) { + found = true + i++ + break + } + if path.ConstCols == nil || i >= len(path.ConstCols) || !path.ConstCols[i] { + break + } + } + if !found { + isMatchProp = false + break + } + } + } + return isMatchProp +} + +// matchPropForIndexMergeAlternatives will match the prop with inside PartialAlternativeIndexPaths, and choose +// 1 matched alternative to be a determined index merge partial path for each dimension in PartialAlternativeIndexPaths. +// finally, after we collected the all decided index merge partial paths, we will output a concrete index merge path +// with field PartialIndexPaths is fulfilled here. +// +// as we mentioned before, after deriveStats is done, the normal index OR path will be generated like below: +// +// `create table t (a int, b int, c int, key a(a), key b(b), key ac(a, c), key bc(b, c))` +// `explain format='verbose' select * from t where a=1 or b=1 order by c` +// +// like the case here: +// normal index merge OR path should be: +// for a=1, it has two partial alternative paths: [a, ac] +// for b=1, it has two partial alternative paths: [b, bc] +// and the index merge path: +// +// indexMergePath: { +// PartialIndexPaths: empty // 1D array here, currently is not decided yet. +// PartialAlternativeIndexPaths: [[a, ac], [b, bc]] // 2D array here, each for one DNF item choices. +// } +// +// let's say we have a prop requirement like sort by [c] here, we will choose the better one [ac] (because it can keep +// order) for the first batch [a, ac] from PartialAlternativeIndexPaths; and choose the better one [bc] (because it can +// keep order too) for the second batch [b, bc] from PartialAlternativeIndexPaths. Finally we output a concrete index +// merge path as +// +// indexMergePath: { +// PartialIndexPaths: [ac, bc] // just collected since they match the prop. +// ... +// } +// +// how about the prop is empty? that means the choice to be decided from [a, ac] and [b, bc] is quite random just according +// to their countAfterAccess. That's why we use a slices.SortFunc(matchIdxes, func(a, b int){}) inside there. After sort, +// the ASC order of matchIdxes of matched paths are ordered by their countAfterAccess, choosing the first one is straight forward. +// +// there is another case shown below, just the pick the first one after matchIdxes is ordered is not always right, as shown: +// special logic for alternative paths: +// +// index merge: +// matched paths-1: {pk, index1} +// matched paths-2: {pk} +// +// if we choose first one as we talked above, says pk here in the first matched paths, then path2 has no choice(avoiding all same +// index logic inside) but pk, this will result in all single index failure. so we need to sort the matchIdxes again according to +// their matched paths length, here mean: +// +// index merge: +// matched paths-1: {pk, index1} +// matched paths-2: {pk} +// +// and let matched paths-2 to be the first to make their determination --- choosing pk here, then next turn is matched paths-1 to +// make their choice, since pk is occupied, avoiding-all-same-index-logic inside will try to pick index1 here, so work can be done. 
+// +// at last, according to determinedIndexPartialPaths to rewrite their real countAfterAccess, this part is move from deriveStats to +// here. +func (ds *DataSource) matchPropForIndexMergeAlternatives(path *util.AccessPath, prop *property.PhysicalProperty) (*util.AccessPath, bool) { + // target: + // 1: index merge case, try to match the every alternative partial path to the order property as long as + // possible, and generate that property-matched index merge path out if any. + // 2: If the prop is empty (means no sort requirement), we will generate a random index partial combination + // path from all alternatives in case that no index merge path comes out. + + // Execution part doesn't support the merge operation for intersection case yet. + if path.IndexMergeIsIntersection { + return nil, false + } + + noSortItem := prop.IsSortItemEmpty() + allSame, _ := prop.AllSameOrder() + if !allSame { + return nil, false + } + // step1: match the property from all the index partial alternative paths. + determinedIndexPartialPaths := make([]*util.AccessPath, 0, len(path.PartialAlternativeIndexPaths)) + usedIndexMap := make(map[int64]struct{}, 1) + type idxWrapper struct { + // matchIdx is those match alternative paths from one alternative paths set. + // like we said above, for a=1, it has two partial alternative paths: [a, ac] + // if we met an empty property here, matchIdx from [a, ac] for a=1 will be both. = [0,1] + // if we met an sort[c] property here, matchIdx from [a, ac] for a=1 will be both. = [1] + matchIdx []int + // pathIdx actually is original position offset indicates where current matchIdx is + // computed from. eg: [[a, ac], [b, bc]] for sort[c] property: + // idxWrapper{[ac], 0}, 0 is the offset in first dimension of PartialAlternativeIndexPaths + // idxWrapper{[bc], 1}, 1 is the offset in first dimension of PartialAlternativeIndexPaths + pathIdx int + } + allMatchIdxes := make([]idxWrapper, 0, len(path.PartialAlternativeIndexPaths)) + // special logic for alternative paths: + // index merge: + // path1: {pk, index1} + // path2: {pk} + // if we choose pk in the first path, then path2 has no choice but pk, this will result in all single index failure. + // so we should collect all match prop paths down, stored as matchIdxes here. + for pathIdx, oneItemAlternatives := range path.PartialAlternativeIndexPaths { + matchIdxes := make([]int, 0, 1) + for i, oneIndexAlternativePath := range oneItemAlternatives { + // if there is some sort items and this path doesn't match this prop, continue. + if !noSortItem && !ds.isMatchProp(oneIndexAlternativePath, prop) { + continue + } + // two possibility here: + // 1. no sort items requirement. + // 2. matched with sorted items. + matchIdxes = append(matchIdxes, i) + } + if len(matchIdxes) == 0 { + // if all index alternative of one of the cnf item's couldn't match the sort property, + // the entire index merge union path can be ignored for this sort property, return false. + return nil, false + } + if len(matchIdxes) > 1 { + // if matchIdxes greater than 1, we should sort this match alternative path by its CountAfterAccess. 
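+ // Alternatives are ordered by their estimated row count (CountAfterIndex when the path has
+ // index filters, CountAfterAccess otherwise), so the cheaper estimate is tried first.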
+ tmpOneItemAlternatives := oneItemAlternatives + slices.SortStableFunc(matchIdxes, func(a, b int) int { + lhsCountAfter := tmpOneItemAlternatives[a].CountAfterAccess + if len(tmpOneItemAlternatives[a].IndexFilters) > 0 { + lhsCountAfter = tmpOneItemAlternatives[a].CountAfterIndex + } + rhsCountAfter := tmpOneItemAlternatives[b].CountAfterAccess + if len(tmpOneItemAlternatives[b].IndexFilters) > 0 { + rhsCountAfter = tmpOneItemAlternatives[b].CountAfterIndex + } + return cmp.Compare(lhsCountAfter, rhsCountAfter) + }) + } + allMatchIdxes = append(allMatchIdxes, idxWrapper{matchIdxes, pathIdx}) + } + // sort allMatchIdxes by its element length. + // index merge: index merge: + // path1: {pk, index1} ==> path2: {pk} + // path2: {pk} path1: {pk, index1} + // here for the fixed choice pk of path2, let it be the first one to choose, left choice of index1 to path1. + slices.SortStableFunc(allMatchIdxes, func(a, b idxWrapper) int { + lhsLen := len(a.matchIdx) + rhsLen := len(b.matchIdx) + return cmp.Compare(lhsLen, rhsLen) + }) + for _, matchIdxes := range allMatchIdxes { + // since matchIdxes are ordered by matchIdxes's length, + // we should use matchIdxes.pathIdx to locate where it comes from. + alternatives := path.PartialAlternativeIndexPaths[matchIdxes.pathIdx] + found := false + // pick a most suitable index partial alternative from all matched alternative paths according to asc CountAfterAccess, + // By this way, a distinguished one is better. + for _, oneIdx := range matchIdxes.matchIdx { + var indexID int64 + if alternatives[oneIdx].IsTablePath() { + indexID = -1 + } else { + indexID = alternatives[oneIdx].Index.ID + } + if _, ok := usedIndexMap[indexID]; !ok { + // try to avoid all index partial paths are all about a single index. + determinedIndexPartialPaths = append(determinedIndexPartialPaths, alternatives[oneIdx].Clone()) + usedIndexMap[indexID] = struct{}{} + found = true + break + } + } + if !found { + // just pick the same name index (just using the first one is ok), in case that there may be some other + // picked distinctive index path for other partial paths latter. + determinedIndexPartialPaths = append(determinedIndexPartialPaths, alternatives[matchIdxes.matchIdx[0]].Clone()) + // uedIndexMap[oneItemAlternatives[oneIdx].Index.ID] = struct{}{} must already be colored. + } + } + if len(usedIndexMap) == 1 { + // if all partial path are using a same index, meaningless and fail over. + return nil, false + } + // step2: gen a new **concrete** index merge path. + indexMergePath := &util.AccessPath{ + PartialIndexPaths: determinedIndexPartialPaths, + IndexMergeIsIntersection: false, + // inherit those determined can't pushed-down table filters. + TableFilters: path.TableFilters, + } + // path.ShouldBeKeptCurrentFilter record that whether there are some part of the cnf item couldn't be pushed down to tikv already. + shouldKeepCurrentFilter := path.KeepIndexMergeORSourceFilter + pushDownCtx := GetPushDownCtx(ds.SCtx()) + for _, path := range determinedIndexPartialPaths { + // If any partial path contains table filters, we need to keep the whole DNF filter in the Selection. + if len(path.TableFilters) > 0 { + if !expression.CanExprsPushDown(pushDownCtx, path.TableFilters, kv.TiKV) { + // if this table filters can't be pushed down, all of them should be kept in the table side, cleaning the lookup side here. + path.TableFilters = nil + } + shouldKeepCurrentFilter = true + } + // If any partial path's index filter cannot be pushed to TiKV, we should keep the whole DNF filter. 
+ if len(path.IndexFilters) != 0 && !expression.CanExprsPushDown(pushDownCtx, path.IndexFilters, kv.TiKV) { + shouldKeepCurrentFilter = true + // Clear IndexFilter, the whole filter will be put in indexMergePath.TableFilters. + path.IndexFilters = nil + } + } + // Keep this filter as a part of table filters for safety if it has any parameter. + if expression.MaybeOverOptimized4PlanCache(ds.SCtx().GetExprCtx(), []expression.Expression{path.IndexMergeORSourceFilter}) { + shouldKeepCurrentFilter = true + } + if shouldKeepCurrentFilter { + // add the cnf expression back as table filer. + indexMergePath.TableFilters = append(indexMergePath.TableFilters, path.IndexMergeORSourceFilter) + } + + // step3: after the index merge path is determined, compute the countAfterAccess as usual. + accessConds := make([]expression.Expression, 0, len(determinedIndexPartialPaths)) + for _, p := range determinedIndexPartialPaths { + indexCondsForP := p.AccessConds[:] + indexCondsForP = append(indexCondsForP, p.IndexFilters...) + if len(indexCondsForP) > 0 { + accessConds = append(accessConds, expression.ComposeCNFCondition(ds.SCtx().GetExprCtx(), indexCondsForP...)) + } + } + accessDNF := expression.ComposeDNFCondition(ds.SCtx().GetExprCtx(), accessConds...) + sel, _, err := cardinality.Selectivity(ds.SCtx(), ds.TableStats.HistColl, []expression.Expression{accessDNF}, nil) + if err != nil { + logutil.BgLogger().Debug("something wrong happened, use the default selectivity", zap.Error(err)) + sel = cost.SelectionFactor + } + indexMergePath.CountAfterAccess = sel * ds.TableStats.RowCount + if noSortItem { + // since there is no sort property, index merge case is generated by random combination, each alternative with the lower/lowest + // countAfterAccess, here the returned matchProperty should be false. + return indexMergePath, false + } + return indexMergePath, true +} + +func (ds *DataSource) isMatchPropForIndexMerge(path *util.AccessPath, prop *property.PhysicalProperty) bool { + // Execution part doesn't support the merge operation for intersection case yet. + if path.IndexMergeIsIntersection { + return false + } + allSame, _ := prop.AllSameOrder() + if !allSame { + return false + } + for _, partialPath := range path.PartialIndexPaths { + if !ds.isMatchProp(partialPath, prop) { + return false + } + } + return true +} + +func (ds *DataSource) getTableCandidate(path *util.AccessPath, prop *property.PhysicalProperty) *candidatePath { + candidate := &candidatePath{path: path} + candidate.isMatchProp = ds.isMatchProp(path, prop) + candidate.accessCondsColMap = util.ExtractCol2Len(ds.SCtx().GetExprCtx().GetEvalCtx(), path.AccessConds, nil, nil) + return candidate +} + +func (ds *DataSource) getIndexCandidate(path *util.AccessPath, prop *property.PhysicalProperty) *candidatePath { + candidate := &candidatePath{path: path} + candidate.isMatchProp = ds.isMatchProp(path, prop) + candidate.accessCondsColMap = util.ExtractCol2Len(ds.SCtx().GetExprCtx().GetEvalCtx(), path.AccessConds, path.IdxCols, path.IdxColLens) + candidate.indexCondsColMap = util.ExtractCol2Len(ds.SCtx().GetExprCtx().GetEvalCtx(), append(path.AccessConds, path.IndexFilters...), path.FullIdxCols, path.FullIdxColLens) + return candidate +} + +func (ds *DataSource) convergeIndexMergeCandidate(path *util.AccessPath, prop *property.PhysicalProperty) *candidatePath { + // since the all index path alternative paths is collected and undetermined, and we should determine a possible and concrete path for this prop. 
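+ // matchPropForIndexMergeAlternatives picks one concrete partial path for each DNF item; it returns
+ // nil when the required order cannot be satisfied (or the merge is an intersection), and the
+ // candidate is then skipped.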
+ possiblePath, match := ds.matchPropForIndexMergeAlternatives(path, prop) + if possiblePath == nil { + return nil + } + candidate := &candidatePath{path: possiblePath, isMatchProp: match} + return candidate +} + +func (ds *DataSource) getIndexMergeCandidate(path *util.AccessPath, prop *property.PhysicalProperty) *candidatePath { + candidate := &candidatePath{path: path} + candidate.isMatchProp = ds.isMatchPropForIndexMerge(path, prop) + return candidate +} + +// skylinePruning prunes access paths according to different factors. An access path can be pruned only if +// there exists a path that is not worse than it at all factors and there is at least one better factor. +func (ds *DataSource) skylinePruning(prop *property.PhysicalProperty) []*candidatePath { + candidates := make([]*candidatePath, 0, 4) + for _, path := range ds.PossibleAccessPaths { + // We should check whether the possible access path is valid first. + if path.StoreType != kv.TiFlash && prop.IsFlashProp() { + continue + } + if len(path.PartialAlternativeIndexPaths) > 0 { + // OR normal index merge path, try to determine every index partial path for this property. + candidate := ds.convergeIndexMergeCandidate(path, prop) + if candidate != nil { + candidates = append(candidates, candidate) + } + continue + } + if path.PartialIndexPaths != nil { + candidates = append(candidates, ds.getIndexMergeCandidate(path, prop)) + continue + } + // if we already know the range of the scan is empty, just return a TableDual + if len(path.Ranges) == 0 { + return []*candidatePath{{path: path}} + } + var currentCandidate *candidatePath + if path.IsTablePath() { + currentCandidate = ds.getTableCandidate(path, prop) + } else { + if !(len(path.AccessConds) > 0 || !prop.IsSortItemEmpty() || path.Forced || path.IsSingleScan) { + continue + } + // We will use index to generate physical plan if any of the following conditions is satisfied: + // 1. This path's access cond is not nil. + // 2. We have a non-empty prop to match. + // 3. This index is forced to choose. + // 4. The needed columns are all covered by index columns(and handleCol). + currentCandidate = ds.getIndexCandidate(path, prop) + } + pruned := false + for i := len(candidates) - 1; i >= 0; i-- { + if candidates[i].path.StoreType == kv.TiFlash { + continue + } + result := compareCandidates(ds.SCtx(), prop, candidates[i], currentCandidate) + if result == 1 { + pruned = true + // We can break here because the current candidate cannot prune others anymore. + break + } else if result == -1 { + candidates = append(candidates[:i], candidates[i+1:]...) + } + } + if !pruned { + candidates = append(candidates, currentCandidate) + } + } + + if ds.SCtx().GetSessionVars().GetAllowPreferRangeScan() && len(candidates) > 1 { + // If a candidate path is TiFlash-path or forced-path, we just keep them. For other candidate paths, if there exists + // any range scan path, we remove full scan paths and keep range scan paths. 
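+ // A path counts as a range scan only if its ranges do not cover the whole handle/index domain
+ // (ranger.HasFullRange); the unsigned-int-handle flag below makes sure an unsigned primary key's
+ // full range is detected correctly.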
+ preferredPaths := make([]*candidatePath, 0, len(candidates)) + var hasRangeScanPath bool + for _, c := range candidates { + if c.path.Forced || c.path.StoreType == kv.TiFlash { + preferredPaths = append(preferredPaths, c) + continue + } + var unsignedIntHandle bool + if c.path.IsIntHandlePath && ds.TableInfo.PKIsHandle { + if pkColInfo := ds.TableInfo.GetPkColInfo(); pkColInfo != nil { + unsignedIntHandle = mysql.HasUnsignedFlag(pkColInfo.GetFlag()) + } + } + if !ranger.HasFullRange(c.path.Ranges, unsignedIntHandle) { + preferredPaths = append(preferredPaths, c) + hasRangeScanPath = true + } + } + if hasRangeScanPath { + return preferredPaths + } + } + + return candidates +} + +func (ds *DataSource) getPruningInfo(candidates []*candidatePath, prop *property.PhysicalProperty) string { + if len(candidates) == len(ds.PossibleAccessPaths) { + return "" + } + if len(candidates) == 1 && len(candidates[0].path.Ranges) == 0 { + // For TableDual, we don't need to output pruning info. + return "" + } + names := make([]string, 0, len(candidates)) + var tableName string + if ds.TableAsName.O == "" { + tableName = ds.TableInfo.Name.O + } else { + tableName = ds.TableAsName.O + } + getSimplePathName := func(path *util.AccessPath) string { + if path.IsTablePath() { + if path.StoreType == kv.TiFlash { + return tableName + "(tiflash)" + } + return tableName + } + return path.Index.Name.O + } + for _, cand := range candidates { + if cand.path.PartialIndexPaths != nil { + partialNames := make([]string, 0, len(cand.path.PartialIndexPaths)) + for _, partialPath := range cand.path.PartialIndexPaths { + partialNames = append(partialNames, getSimplePathName(partialPath)) + } + names = append(names, fmt.Sprintf("IndexMerge{%s}", strings.Join(partialNames, ","))) + } else { + names = append(names, getSimplePathName(cand.path)) + } + } + items := make([]string, 0, len(prop.SortItems)) + for _, item := range prop.SortItems { + items = append(items, item.String()) + } + return fmt.Sprintf("[%s] remain after pruning paths for %s given Prop{SortItems: [%s], TaskTp: %s}", + strings.Join(names, ","), tableName, strings.Join(items, " "), prop.TaskTp) +} + +func (ds *DataSource) isPointGetConvertableSchema() bool { + for _, col := range ds.Columns { + if col.Name.L == model.ExtraHandleName.L { + continue + } + + // Only handle tables that all columns are public. + if col.State != model.StatePublic { + return false + } + } + return true +} + +// exploreEnforcedPlan determines whether to explore enforced plans for this DataSource if it has already found an unenforced plan. +// See #46177 for more information. +func (ds *DataSource) exploreEnforcedPlan() bool { + // default value is false to keep it compatible with previous versions. + return fixcontrol.GetBoolWithDefault(ds.SCtx().GetSessionVars().GetOptimizerFixControlMap(), fixcontrol.Fix46177, false) +} + +func findBestTask4DS(ds *DataSource, prop *property.PhysicalProperty, planCounter *base.PlanCounterTp, opt *optimizetrace.PhysicalOptimizeOp) (t base.Task, cntPlan int64, err error) { + // If ds is an inner plan in an IndexJoin, the IndexJoin will generate an inner plan by itself, + // and set inner child prop nil, so here we do nothing. 
+ if prop == nil { + planCounter.Dec(1) + return nil, 1, nil + } + if ds.IsForUpdateRead && ds.SCtx().GetSessionVars().TxnCtx.IsExplicit { + hasPointGetPath := false + for _, path := range ds.PossibleAccessPaths { + if ds.isPointGetPath(path) { + hasPointGetPath = true + break + } + } + tblName := ds.TableInfo.Name + ds.PossibleAccessPaths, err = filterPathByIsolationRead(ds.SCtx(), ds.PossibleAccessPaths, tblName, ds.DBName) + if err != nil { + return nil, 1, err + } + if hasPointGetPath { + newPaths := make([]*util.AccessPath, 0) + for _, path := range ds.PossibleAccessPaths { + // if the path is the point get range path with for update lock, we should forbid tiflash as it's store path (#39543) + if path.StoreType != kv.TiFlash { + newPaths = append(newPaths, path) + } + } + ds.PossibleAccessPaths = newPaths + } + } + t = ds.GetTask(prop) + if t != nil { + cntPlan = 1 + planCounter.Dec(1) + return + } + var cnt int64 + var unenforcedTask base.Task + // If prop.CanAddEnforcer is true, the prop.SortItems need to be set nil for ds.findBestTask. + // Before function return, reset it for enforcing task prop and storing map. + oldProp := prop.CloneEssentialFields() + if prop.CanAddEnforcer { + // First, get the bestTask without enforced prop + prop.CanAddEnforcer = false + unenforcedTask, cnt, err = ds.FindBestTask(prop, planCounter, opt) + if err != nil { + return nil, 0, err + } + if !unenforcedTask.Invalid() && !ds.exploreEnforcedPlan() { + ds.StoreTask(prop, unenforcedTask) + return unenforcedTask, cnt, nil + } + + // Then, explore the bestTask with enforced prop + prop.CanAddEnforcer = true + cntPlan += cnt + prop.SortItems = []property.SortItem{} + prop.MPPPartitionTp = property.AnyType + } else if prop.MPPPartitionTp != property.AnyType { + return base.InvalidTask, 0, nil + } + defer func() { + if err != nil { + return + } + if prop.CanAddEnforcer { + *prop = *oldProp + t = enforceProperty(prop, t, ds.Plan.SCtx()) + prop.CanAddEnforcer = true + } + + if unenforcedTask != nil && !unenforcedTask.Invalid() { + curIsBest, cerr := compareTaskCost(unenforcedTask, t, opt) + if cerr != nil { + err = cerr + return + } + if curIsBest { + t = unenforcedTask + } + } + + ds.StoreTask(prop, t) + err = validateTableSamplePlan(ds, t, err) + }() + + t, err = ds.tryToGetDualTask() + if err != nil || t != nil { + planCounter.Dec(1) + if t != nil { + appendCandidate(ds, t, prop, opt) + } + return t, 1, err + } + + t = base.InvalidTask + candidates := ds.skylinePruning(prop) + pruningInfo := ds.getPruningInfo(candidates, prop) + defer func() { + if err == nil && t != nil && !t.Invalid() && pruningInfo != "" { + warnErr := errors.NewNoStackError(pruningInfo) + if ds.SCtx().GetSessionVars().StmtCtx.InVerboseExplain { + ds.SCtx().GetSessionVars().StmtCtx.AppendNote(warnErr) + } else { + ds.SCtx().GetSessionVars().StmtCtx.AppendExtraNote(warnErr) + } + } + }() + + cntPlan = 0 + for _, candidate := range candidates { + path := candidate.path + if path.PartialIndexPaths != nil { + idxMergeTask, err := ds.convertToIndexMergeScan(prop, candidate, opt) + if err != nil { + return nil, 0, err + } + if !idxMergeTask.Invalid() { + cntPlan++ + planCounter.Dec(1) + } + appendCandidate(ds, idxMergeTask, prop, opt) + + curIsBetter, err := compareTaskCost(idxMergeTask, t, opt) + if err != nil { + return nil, 0, err + } + if curIsBetter || planCounter.Empty() { + t = idxMergeTask + } + if planCounter.Empty() { + return t, cntPlan, nil + } + continue + } + // if we already know the range of the scan is empty, just return a 
TableDual
+ if len(path.Ranges) == 0 {
+ // We should uncache the tableDual plan.
+ if expression.MaybeOverOptimized4PlanCache(ds.SCtx().GetExprCtx(), path.AccessConds) {
+ ds.SCtx().GetSessionVars().StmtCtx.SetSkipPlanCache("get a TableDual plan")
+ }
+ dual := PhysicalTableDual{}.Init(ds.SCtx(), ds.StatsInfo(), ds.QueryBlockOffset())
+ dual.SetSchema(ds.Schema())
+ cntPlan++
+ planCounter.Dec(1)
+ t := &RootTask{}
+ t.SetPlan(dual)
+ appendCandidate(ds, t, prop, opt)
+ return t, cntPlan, nil
+ }
+
+ canConvertPointGet := len(path.Ranges) > 0 && path.StoreType == kv.TiKV && ds.isPointGetConvertableSchema()
+
+ if canConvertPointGet && path.Index != nil && path.Index.MVIndex {
+ canConvertPointGet = false // cannot use PointGet upon MVIndex
+ }
+
+ if canConvertPointGet && !path.IsIntHandlePath {
+ // We simply do not build [batch] point get for prefix indexes. This can be optimized.
+ canConvertPointGet = path.Index.Unique && !path.Index.HasPrefixIndex()
+ // If any range cannot cover all columns of the index, we cannot build [batch] point get.
+ idxColsLen := len(path.Index.Columns)
+ for _, ran := range path.Ranges {
+ if len(ran.LowVal) != idxColsLen {
+ canConvertPointGet = false
+ break
+ }
+ }
+ }
+ if canConvertPointGet && ds.table.Meta().GetPartitionInfo() != nil {
+ // A partition table under dynamic prune mode does not support BatchPointGet.
+ // Due to sorting?
+ // Please make sure `where _tidb_rowid in (xx, xx)` is handled correctly when deleting this if statement.
+ if canConvertPointGet && len(path.Ranges) > 1 && ds.SCtx().GetSessionVars().StmtCtx.UseDynamicPartitionPrune() {
+ canConvertPointGet = false
+ }
+ if canConvertPointGet && len(path.Ranges) > 1 {
+ // TODO: This is now implemented, but to decrease
+ // the impact of supporting plan cache for partitioning,
+ // this is not yet enabled.
+ // TODO: just remove this if block and update/add tests...
+ // We can only build batch point get for hash partitions on a simple column now. This is
+ // decided by the current implementation of `BatchPointGetExec::initialize()`, specifically,
+ // the `getPhysID()` function. Once we optimize that part, we can come back and enable
+ // BatchPointGet plan for more cases.
+ hashPartColName := getHashOrKeyPartitionColumnName(ds.SCtx(), ds.table.Meta())
+ if hashPartColName == nil {
+ canConvertPointGet = false
+ }
+ }
+ // Partition table can't use `_tidb_rowid` to generate PointGet Plan unless one partition is explicitly specified.
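+ // Each partition allocates its own `_tidb_rowid`, so the same rowid value can appear in several
+ // partitions; a PointGet by `_tidb_rowid` is therefore only unambiguous when exactly one
+ // partition is named, e.g. (illustrative) `select * from t partition(p0) where _tidb_rowid = 1`.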
+ if canConvertPointGet && path.IsIntHandlePath && !ds.table.Meta().PKIsHandle && len(ds.PartitionNames) != 1 { + canConvertPointGet = false + } + if canConvertPointGet { + if path != nil && path.Index != nil && path.Index.Global { + // Don't convert to point get during ddl + // TODO: Revisit truncate partition and global index + if len(ds.TableInfo.GetPartitionInfo().DroppingDefinitions) > 0 || + len(ds.TableInfo.GetPartitionInfo().AddingDefinitions) > 0 { + canConvertPointGet = false + } + } + } + } + if canConvertPointGet { + allRangeIsPoint := true + tc := ds.SCtx().GetSessionVars().StmtCtx.TypeCtx() + for _, ran := range path.Ranges { + if !ran.IsPointNonNullable(tc) { + // unique indexes can have duplicated NULL rows so we cannot use PointGet if there is NULL + allRangeIsPoint = false + break + } + } + if allRangeIsPoint { + var pointGetTask base.Task + if len(path.Ranges) == 1 { + pointGetTask = ds.convertToPointGet(prop, candidate) + } else { + pointGetTask = ds.convertToBatchPointGet(prop, candidate) + } + + // Batch/PointGet plans may be over-optimized, like `a>=1(?) and a<=1(?)` --> `a=1` --> PointGet(a=1). + // For safety, prevent these plans from the plan cache here. + if !pointGetTask.Invalid() && expression.MaybeOverOptimized4PlanCache(ds.SCtx().GetExprCtx(), candidate.path.AccessConds) && !isSafePointGetPath4PlanCache(ds.SCtx(), candidate.path) { + ds.SCtx().GetSessionVars().StmtCtx.SetSkipPlanCache("Batch/PointGet plans may be over-optimized") + } + + appendCandidate(ds, pointGetTask, prop, opt) + if !pointGetTask.Invalid() { + cntPlan++ + planCounter.Dec(1) + } + curIsBetter, cerr := compareTaskCost(pointGetTask, t, opt) + if cerr != nil { + return nil, 0, cerr + } + if curIsBetter || planCounter.Empty() { + t = pointGetTask + if planCounter.Empty() { + return + } + continue + } + } + } + if path.IsTablePath() { + if ds.PreferStoreType&h.PreferTiFlash != 0 && path.StoreType == kv.TiKV { + continue + } + if ds.PreferStoreType&h.PreferTiKV != 0 && path.StoreType == kv.TiFlash { + continue + } + var tblTask base.Task + if ds.SampleInfo != nil { + tblTask, err = ds.convertToSampleTable(prop, candidate, opt) + } else { + tblTask, err = ds.convertToTableScan(prop, candidate, opt) + } + if err != nil { + return nil, 0, err + } + if !tblTask.Invalid() { + cntPlan++ + planCounter.Dec(1) + } + appendCandidate(ds, tblTask, prop, opt) + curIsBetter, err := compareTaskCost(tblTask, t, opt) + if err != nil { + return nil, 0, err + } + if curIsBetter || planCounter.Empty() { + t = tblTask + } + if planCounter.Empty() { + return t, cntPlan, nil + } + continue + } + // TiFlash storage do not support index scan. + if ds.PreferStoreType&h.PreferTiFlash != 0 { + continue + } + // TableSample do not support index scan. + if ds.SampleInfo != nil { + continue + } + idxTask, err := ds.convertToIndexScan(prop, candidate, opt) + if err != nil { + return nil, 0, err + } + if !idxTask.Invalid() { + cntPlan++ + planCounter.Dec(1) + } + appendCandidate(ds, idxTask, prop, opt) + curIsBetter, err := compareTaskCost(idxTask, t, opt) + if err != nil { + return nil, 0, err + } + if curIsBetter || planCounter.Empty() { + t = idxTask + } + if planCounter.Empty() { + return t, cntPlan, nil + } + } + + return +} + +// convertToIndexMergeScan builds the index merge scan for intersection or union cases. 
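+// Illustrative examples (assumed table t with indexes ia(a) and ib(b), not taken from the patch):
+//   union:        SELECT * FROM t WHERE a = 1 OR b = 2    -- handles from the partial index paths are unioned
+//   intersection: SELECT /*+ USE_INDEX_MERGE(t, ia, ib) */ * FROM t WHERE a = 1 AND b = 2
+//                 -- handles from the partial index paths are intersected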
+func (ds *DataSource) convertToIndexMergeScan(prop *property.PhysicalProperty, candidate *candidatePath, _ *optimizetrace.PhysicalOptimizeOp) (task base.Task, err error) { + if prop.IsFlashProp() || prop.TaskTp == property.CopSingleReadTaskType { + return base.InvalidTask, nil + } + // lift the limitation of that double read can not build index merge **COP** task with intersection. + // that means we can output a cop task here without encapsulating it as root task, for the convenience of attaching limit to its table side. + + if !prop.IsSortItemEmpty() && !candidate.isMatchProp { + return base.InvalidTask, nil + } + // while for now, we still can not push the sort prop to the intersection index plan side, temporarily banned here. + if !prop.IsSortItemEmpty() && candidate.path.IndexMergeIsIntersection { + return base.InvalidTask, nil + } + failpoint.Inject("forceIndexMergeKeepOrder", func(_ failpoint.Value) { + if len(candidate.path.PartialIndexPaths) > 0 && !candidate.path.IndexMergeIsIntersection { + if prop.IsSortItemEmpty() { + failpoint.Return(base.InvalidTask, nil) + } + } + }) + path := candidate.path + scans := make([]base.PhysicalPlan, 0, len(path.PartialIndexPaths)) + cop := &CopTask{ + indexPlanFinished: false, + tblColHists: ds.TblColHists, + } + cop.physPlanPartInfo = &PhysPlanPartInfo{ + PruningConds: pushDownNot(ds.SCtx().GetExprCtx(), ds.AllConds), + PartitionNames: ds.PartitionNames, + Columns: ds.TblCols, + ColumnNames: ds.OutputNames(), + } + // Add sort items for index scan for merge-sort operation between partitions. + byItems := make([]*util.ByItems, 0, len(prop.SortItems)) + for _, si := range prop.SortItems { + byItems = append(byItems, &util.ByItems{ + Expr: si.Col, + Desc: si.Desc, + }) + } + globalRemainingFilters := make([]expression.Expression, 0, 3) + for _, partPath := range path.PartialIndexPaths { + var scan base.PhysicalPlan + if partPath.IsTablePath() { + scan = ds.convertToPartialTableScan(prop, partPath, candidate.isMatchProp, byItems) + } else { + var remainingFilters []expression.Expression + scan, remainingFilters, err = ds.convertToPartialIndexScan(cop.physPlanPartInfo, prop, partPath, candidate.isMatchProp, byItems) + if err != nil { + return base.InvalidTask, err + } + if prop.TaskTp != property.RootTaskType && len(remainingFilters) > 0 { + return base.InvalidTask, nil + } + globalRemainingFilters = append(globalRemainingFilters, remainingFilters...) + } + scans = append(scans, scan) + } + totalRowCount := path.CountAfterAccess + if prop.ExpectedCnt < ds.StatsInfo().RowCount { + totalRowCount *= prop.ExpectedCnt / ds.StatsInfo().RowCount + } + ts, remainingFilters2, moreColumn, err := ds.buildIndexMergeTableScan(path.TableFilters, totalRowCount, candidate.isMatchProp) + if err != nil { + return base.InvalidTask, err + } + if prop.TaskTp != property.RootTaskType && len(remainingFilters2) > 0 { + return base.InvalidTask, nil + } + globalRemainingFilters = append(globalRemainingFilters, remainingFilters2...) 
+ cop.keepOrder = candidate.isMatchProp
+ cop.tablePlan = ts
+ cop.idxMergePartPlans = scans
+ cop.idxMergeIsIntersection = path.IndexMergeIsIntersection
+ cop.idxMergeAccessMVIndex = path.IndexMergeAccessMVIndex
+ if moreColumn {
+ cop.needExtraProj = true
+ cop.originSchema = ds.Schema()
+ }
+ if len(globalRemainingFilters) != 0 {
+ cop.rootTaskConds = globalRemainingFilters
+ }
+ // After we lift the limitation on intersection and cop-type tasks in the code above, we could
+ // mark the index plan as finished once we find that its table plan below is a pure table scan.
+ // However, that would cause cost underestimation when we estimate the cost of the entire cop
+ // task plan in the function `getTaskPlanCost`.
+ if prop.TaskTp == property.RootTaskType {
+ cop.indexPlanFinished = true
+ task = cop.ConvertToRootTask(ds.SCtx())
+ } else {
+ _, pureTableScan := ts.(*PhysicalTableScan)
+ if !pureTableScan {
+ cop.indexPlanFinished = true
+ }
+ task = cop
+ }
+ return task, nil
+}
+
+func (ds *DataSource) convertToPartialIndexScan(physPlanPartInfo *PhysPlanPartInfo, prop *property.PhysicalProperty, path *util.AccessPath, matchProp bool, byItems []*util.ByItems) (base.PhysicalPlan, []expression.Expression, error) {
+ is := ds.getOriginalPhysicalIndexScan(prop, path, matchProp, false)
+ // TODO: Consider using isIndexCoveringColumns() to avoid another TableRead
+ indexConds := path.IndexFilters
+ if matchProp {
+ if is.Table.GetPartitionInfo() != nil && !is.Index.Global && is.SCtx().GetSessionVars().StmtCtx.UseDynamicPartitionPrune() {
+ is.Columns, is.schema, _ = AddExtraPhysTblIDColumn(is.SCtx(), is.Columns, is.schema)
+ }
+ // Add sort items for index scan for merge-sort operation between partitions.
+ is.ByItems = byItems
+ }
+
+ // Add a `Selection` for `IndexScan` with global index.
+ // It should be pushed down to TiKV; the DataSource schema doesn't contain the partition id column.
+ indexConds, err := is.addSelectionConditionForGlobalIndex(ds, physPlanPartInfo, indexConds) + if err != nil { + return nil, nil, err + } + + if len(indexConds) > 0 { + pushedFilters, remainingFilter := extractFiltersForIndexMerge(GetPushDownCtx(ds.SCtx()), indexConds) + var selectivity float64 + if path.CountAfterAccess > 0 { + selectivity = path.CountAfterIndex / path.CountAfterAccess + } + rowCount := is.StatsInfo().RowCount * selectivity + stats := &property.StatsInfo{RowCount: rowCount} + stats.StatsVersion = ds.StatisticTable.Version + if ds.StatisticTable.Pseudo { + stats.StatsVersion = statistics.PseudoVersion + } + indexPlan := PhysicalSelection{Conditions: pushedFilters}.Init(is.SCtx(), stats, ds.QueryBlockOffset()) + indexPlan.SetChildren(is) + return indexPlan, remainingFilter, nil + } + return is, nil, nil +} + +func checkColinSchema(cols []*expression.Column, schema *expression.Schema) bool { + for _, col := range cols { + if schema.ColumnIndex(col) == -1 { + return false + } + } + return true +} + +func (ds *DataSource) convertToPartialTableScan(prop *property.PhysicalProperty, path *util.AccessPath, matchProp bool, byItems []*util.ByItems) (tablePlan base.PhysicalPlan) { + ts, rowCount := ds.getOriginalPhysicalTableScan(prop, path, matchProp) + overwritePartialTableScanSchema(ds, ts) + // remove ineffetive filter condition after overwriting physicalscan schema + newFilterConds := make([]expression.Expression, 0, len(path.TableFilters)) + for _, cond := range ts.filterCondition { + cols := expression.ExtractColumns(cond) + if checkColinSchema(cols, ts.schema) { + newFilterConds = append(newFilterConds, cond) + } + } + ts.filterCondition = newFilterConds + if matchProp { + if ts.Table.GetPartitionInfo() != nil && ts.SCtx().GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { + ts.Columns, ts.schema, _ = AddExtraPhysTblIDColumn(ts.SCtx(), ts.Columns, ts.schema) + } + ts.ByItems = byItems + } + if len(ts.filterCondition) > 0 { + selectivity, _, err := cardinality.Selectivity(ds.SCtx(), ds.TableStats.HistColl, ts.filterCondition, nil) + if err != nil { + logutil.BgLogger().Debug("calculate selectivity failed, use selection factor", zap.Error(err)) + selectivity = cost.SelectionFactor + } + tablePlan = PhysicalSelection{Conditions: ts.filterCondition}.Init(ts.SCtx(), ts.StatsInfo().ScaleByExpectCnt(selectivity*rowCount), ds.QueryBlockOffset()) + tablePlan.SetChildren(ts) + return tablePlan + } + tablePlan = ts + return tablePlan +} + +// overwritePartialTableScanSchema change the schema of partial table scan to handle columns. +func overwritePartialTableScanSchema(ds *DataSource, ts *PhysicalTableScan) { + handleCols := ds.HandleCols + if handleCols == nil { + handleCols = util.NewIntHandleCols(ds.newExtraHandleSchemaCol()) + } + hdColNum := handleCols.NumCols() + exprCols := make([]*expression.Column, 0, hdColNum) + infoCols := make([]*model.ColumnInfo, 0, hdColNum) + for i := 0; i < hdColNum; i++ { + col := handleCols.GetCol(i) + exprCols = append(exprCols, col) + if c := model.FindColumnInfoByID(ds.TableInfo.Columns, col.ID); c != nil { + infoCols = append(infoCols, c) + } else { + infoCols = append(infoCols, col.ToInfo()) + } + } + ts.schema = expression.NewSchema(exprCols...) + ts.Columns = infoCols +} + +// setIndexMergeTableScanHandleCols set the handle columns of the table scan. 
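+// The resolved handle columns let the table side of the IndexMerge reader fetch and merge the
+// rows identified by the handles returned from the partial index/table paths.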
+func setIndexMergeTableScanHandleCols(ds *DataSource, ts *PhysicalTableScan) (err error) { + handleCols := ds.HandleCols + if handleCols == nil { + handleCols = util.NewIntHandleCols(ds.newExtraHandleSchemaCol()) + } + hdColNum := handleCols.NumCols() + exprCols := make([]*expression.Column, 0, hdColNum) + for i := 0; i < hdColNum; i++ { + col := handleCols.GetCol(i) + exprCols = append(exprCols, col) + } + ts.HandleCols, err = handleCols.ResolveIndices(expression.NewSchema(exprCols...)) + return +} + +// buildIndexMergeTableScan() returns Selection that will be pushed to TiKV. +// Filters that cannot be pushed to TiKV are also returned, and an extra Selection above IndexMergeReader will be constructed later. +func (ds *DataSource) buildIndexMergeTableScan(tableFilters []expression.Expression, + totalRowCount float64, matchProp bool) (base.PhysicalPlan, []expression.Expression, bool, error) { + ts := PhysicalTableScan{ + Table: ds.TableInfo, + Columns: slices.Clone(ds.Columns), + TableAsName: ds.TableAsName, + DBName: ds.DBName, + isPartition: ds.PartitionDefIdx != nil, + physicalTableID: ds.PhysicalTableID, + HandleCols: ds.HandleCols, + tblCols: ds.TblCols, + tblColHists: ds.TblColHists, + }.Init(ds.SCtx(), ds.QueryBlockOffset()) + ts.SetSchema(ds.Schema().Clone()) + err := setIndexMergeTableScanHandleCols(ds, ts) + if err != nil { + return nil, nil, false, err + } + ts.SetStats(ds.TableStats.ScaleByExpectCnt(totalRowCount)) + usedStats := ds.SCtx().GetSessionVars().StmtCtx.GetUsedStatsInfo(false) + if usedStats != nil && usedStats.GetUsedInfo(ts.physicalTableID) != nil { + ts.usedStatsInfo = usedStats.GetUsedInfo(ts.physicalTableID) + } + if ds.StatisticTable.Pseudo { + ts.StatsInfo().StatsVersion = statistics.PseudoVersion + } + var currentTopPlan base.PhysicalPlan = ts + if len(tableFilters) > 0 { + pushedFilters, remainingFilters := extractFiltersForIndexMerge(GetPushDownCtx(ds.SCtx()), tableFilters) + pushedFilters1, remainingFilters1 := SplitSelCondsWithVirtualColumn(pushedFilters) + pushedFilters = pushedFilters1 + remainingFilters = append(remainingFilters, remainingFilters1...) + if len(pushedFilters) != 0 { + selectivity, _, err := cardinality.Selectivity(ds.SCtx(), ds.TableStats.HistColl, pushedFilters, nil) + if err != nil { + logutil.BgLogger().Debug("calculate selectivity failed, use selection factor", zap.Error(err)) + selectivity = cost.SelectionFactor + } + sel := PhysicalSelection{Conditions: pushedFilters}.Init(ts.SCtx(), ts.StatsInfo().ScaleByExpectCnt(selectivity*totalRowCount), ts.QueryBlockOffset()) + sel.SetChildren(ts) + currentTopPlan = sel + } + if len(remainingFilters) > 0 { + return currentTopPlan, remainingFilters, false, nil + } + } + // If we don't need to use ordered scan, we don't need do the following codes for adding new columns. + if !matchProp { + return currentTopPlan, nil, false, nil + } + + // Add the row handle into the schema. 
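+ // The ordered (keep-order) case needs the handle exposed so that results coming from different
+ // partitions can be merged; the three branches below cover an integer primary-key handle, a
+ // clustered (common) handle, and the implicit `_tidb_rowid` handle.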
+ columnAdded := false + if ts.Table.PKIsHandle { + pk := ts.Table.GetPkColInfo() + pkCol := expression.ColInfo2Col(ts.tblCols, pk) + if !ts.schema.Contains(pkCol) { + ts.schema.Append(pkCol) + ts.Columns = append(ts.Columns, pk) + columnAdded = true + } + } else if ts.Table.IsCommonHandle { + idxInfo := ts.Table.GetPrimaryKey() + for _, idxCol := range idxInfo.Columns { + col := ts.tblCols[idxCol.Offset] + if !ts.schema.Contains(col) { + columnAdded = true + ts.schema.Append(col) + ts.Columns = append(ts.Columns, col.ToInfo()) + } + } + } else if !ts.schema.Contains(ts.HandleCols.GetCol(0)) { + ts.schema.Append(ts.HandleCols.GetCol(0)) + ts.Columns = append(ts.Columns, model.NewExtraHandleColInfo()) + columnAdded = true + } + + // For the global index of the partitioned table, we also need the PhysicalTblID to identify the rows from each partition. + if ts.Table.GetPartitionInfo() != nil && ts.SCtx().GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { + var newColAdded bool + ts.Columns, ts.schema, newColAdded = AddExtraPhysTblIDColumn(ts.SCtx(), ts.Columns, ts.schema) + columnAdded = columnAdded || newColAdded + } + return currentTopPlan, nil, columnAdded, nil +} + +// extractFiltersForIndexMerge returns: +// `pushed`: exprs that can be pushed to TiKV. +// `remaining`: exprs that can NOT be pushed to TiKV but can be pushed to other storage engines. +// Why do we need this func? +// IndexMerge only works on TiKV, so we need to find all exprs that cannot be pushed to TiKV, and add a new Selection above IndexMergeReader. +// +// But the new Selection should exclude the exprs that can NOT be pushed to ALL the storage engines. +// Because these exprs have already been put in another Selection(check rule_predicate_push_down). +func extractFiltersForIndexMerge(ctx expression.PushDownContext, filters []expression.Expression) (pushed []expression.Expression, remaining []expression.Expression) { + for _, expr := range filters { + if expression.CanExprsPushDown(ctx, []expression.Expression{expr}, kv.TiKV) { + pushed = append(pushed, expr) + continue + } + if expression.CanExprsPushDown(ctx, []expression.Expression{expr}, kv.UnSpecified) { + remaining = append(remaining, expr) + } + } + return +} + +func isIndexColsCoveringCol(sctx expression.EvalContext, col *expression.Column, indexCols []*expression.Column, idxColLens []int, ignoreLen bool) bool { + for i, indexCol := range indexCols { + if indexCol == nil || !col.EqualByExprAndID(sctx, indexCol) { + continue + } + if ignoreLen || idxColLens[i] == types.UnspecifiedLength || idxColLens[i] == col.RetType.GetFlen() { + return true + } + } + return false +} + +func (ds *DataSource) indexCoveringColumn(column *expression.Column, indexColumns []*expression.Column, idxColLens []int, ignoreLen bool) bool { + if ds.TableInfo.PKIsHandle && mysql.HasPriKeyFlag(column.RetType.GetFlag()) { + return true + } + if column.ID == model.ExtraHandleID || column.ID == model.ExtraPhysTblID { + return true + } + evalCtx := ds.SCtx().GetExprCtx().GetEvalCtx() + coveredByPlainIndex := isIndexColsCoveringCol(evalCtx, column, indexColumns, idxColLens, ignoreLen) + coveredByClusteredIndex := isIndexColsCoveringCol(evalCtx, column, ds.CommonHandleCols, ds.CommonHandleLens, ignoreLen) + if !coveredByPlainIndex && !coveredByClusteredIndex { + return false + } + isClusteredNewCollationIdx := collate.NewCollationEnabled() && + column.GetType(evalCtx).EvalType() == types.ETString && + !mysql.HasBinaryFlag(column.GetType(evalCtx).GetFlag()) + if !coveredByPlainIndex && 
coveredByClusteredIndex && isClusteredNewCollationIdx && ds.table.Meta().CommonHandleVersion == 0 { + return false + } + return true +} + +func (ds *DataSource) isIndexCoveringColumns(columns, indexColumns []*expression.Column, idxColLens []int) bool { + for _, col := range columns { + if !ds.indexCoveringColumn(col, indexColumns, idxColLens, false) { + return false + } + } + return true +} + +func (ds *DataSource) isIndexCoveringCondition(condition expression.Expression, indexColumns []*expression.Column, idxColLens []int) bool { + switch v := condition.(type) { + case *expression.Column: + return ds.indexCoveringColumn(v, indexColumns, idxColLens, false) + case *expression.ScalarFunction: + // Even if the index only contains prefix `col`, the index can cover `col is null`. + if v.FuncName.L == ast.IsNull { + if col, ok := v.GetArgs()[0].(*expression.Column); ok { + return ds.indexCoveringColumn(col, indexColumns, idxColLens, true) + } + } + for _, arg := range v.GetArgs() { + if !ds.isIndexCoveringCondition(arg, indexColumns, idxColLens) { + return false + } + } + return true + } + return true +} + +func (ds *DataSource) isSingleScan(indexColumns []*expression.Column, idxColLens []int) bool { + if !ds.SCtx().GetSessionVars().OptPrefixIndexSingleScan || ds.ColsRequiringFullLen == nil { + // ds.ColsRequiringFullLen is set at (*DataSource).PruneColumns. In some cases we don't reach (*DataSource).PruneColumns + // and ds.ColsRequiringFullLen is nil, so we fall back to ds.isIndexCoveringColumns(ds.schema.Columns, indexColumns, idxColLens). + return ds.isIndexCoveringColumns(ds.Schema().Columns, indexColumns, idxColLens) + } + if !ds.isIndexCoveringColumns(ds.ColsRequiringFullLen, indexColumns, idxColLens) { + return false + } + for _, cond := range ds.AllConds { + if !ds.isIndexCoveringCondition(cond, indexColumns, idxColLens) { + return false + } + } + return true +} + +// If there is a table reader which needs to keep order, we should append a pk to table scan. +func (ts *PhysicalTableScan) appendExtraHandleCol(ds *DataSource) (*expression.Column, bool) { + handleCols := ds.HandleCols + if handleCols != nil { + return handleCols.GetCol(0), false + } + handleCol := ds.newExtraHandleSchemaCol() + ts.schema.Append(handleCol) + ts.Columns = append(ts.Columns, model.NewExtraHandleColInfo()) + return handleCol, true +} + +// convertToIndexScan converts the DataSource to index scan with idx. +func (ds *DataSource) convertToIndexScan(prop *property.PhysicalProperty, + candidate *candidatePath, _ *optimizetrace.PhysicalOptimizeOp) (task base.Task, err error) { + if candidate.path.Index.MVIndex { + // MVIndex is special since different index rows may return the same _row_id and this can break some assumptions of IndexReader. + // Currently only support using IndexMerge to access MVIndex instead of IndexReader. + // TODO: make IndexReader support accessing MVIndex directly. + return base.InvalidTask, nil + } + if !candidate.path.IsSingleScan { + // If it's parent requires single read task, return max cost. + if prop.TaskTp == property.CopSingleReadTaskType { + return base.InvalidTask, nil + } + } else if prop.TaskTp == property.CopMultiReadTaskType { + // If it's parent requires double read task, return max cost. + return base.InvalidTask, nil + } + if !prop.IsSortItemEmpty() && !candidate.isMatchProp { + return base.InvalidTask, nil + } + // If we need to keep order for the index scan, we should forbid the non-keep-order index scan when we try to generate the path. 
+ if prop.IsSortItemEmpty() && candidate.path.ForceKeepOrder { + return base.InvalidTask, nil + } + // If we don't need to keep order for the index scan, we should forbid the non-keep-order index scan when we try to generate the path. + if !prop.IsSortItemEmpty() && candidate.path.ForceNoKeepOrder { + return base.InvalidTask, nil + } + path := candidate.path + is := ds.getOriginalPhysicalIndexScan(prop, path, candidate.isMatchProp, candidate.path.IsSingleScan) + cop := &CopTask{ + indexPlan: is, + tblColHists: ds.TblColHists, + tblCols: ds.TblCols, + expectCnt: uint64(prop.ExpectedCnt), + } + cop.physPlanPartInfo = &PhysPlanPartInfo{ + PruningConds: pushDownNot(ds.SCtx().GetExprCtx(), ds.AllConds), + PartitionNames: ds.PartitionNames, + Columns: ds.TblCols, + ColumnNames: ds.OutputNames(), + } + if !candidate.path.IsSingleScan { + // On this way, it's double read case. + ts := PhysicalTableScan{ + Columns: util.CloneColInfos(ds.Columns), + Table: is.Table, + TableAsName: ds.TableAsName, + DBName: ds.DBName, + isPartition: ds.PartitionDefIdx != nil, + physicalTableID: ds.PhysicalTableID, + tblCols: ds.TblCols, + tblColHists: ds.TblColHists, + }.Init(ds.SCtx(), is.QueryBlockOffset()) + ts.SetSchema(ds.Schema().Clone()) + // We set `StatsVersion` here and fill other fields in `(*copTask).finishIndexPlan`. Since `copTask.indexPlan` may + // change before calling `(*copTask).finishIndexPlan`, we don't know the stats information of `ts` currently and on + // the other hand, it may be hard to identify `StatsVersion` of `ts` in `(*copTask).finishIndexPlan`. + ts.SetStats(&property.StatsInfo{StatsVersion: ds.TableStats.StatsVersion}) + usedStats := ds.SCtx().GetSessionVars().StmtCtx.GetUsedStatsInfo(false) + if usedStats != nil && usedStats.GetUsedInfo(ts.physicalTableID) != nil { + ts.usedStatsInfo = usedStats.GetUsedInfo(ts.physicalTableID) + } + cop.tablePlan = ts + } + task = cop + if cop.tablePlan != nil && ds.TableInfo.IsCommonHandle { + cop.commonHandleCols = ds.CommonHandleCols + commonHandle := ds.HandleCols.(*util.CommonHandleCols) + for _, col := range commonHandle.GetColumns() { + if ds.Schema().ColumnIndex(col) == -1 { + ts := cop.tablePlan.(*PhysicalTableScan) + ts.Schema().Append(col) + ts.Columns = append(ts.Columns, col.ToInfo()) + cop.needExtraProj = true + } + } + } + if candidate.isMatchProp { + cop.keepOrder = true + if cop.tablePlan != nil && !ds.TableInfo.IsCommonHandle { + col, isNew := cop.tablePlan.(*PhysicalTableScan).appendExtraHandleCol(ds) + cop.extraHandleCol = col + cop.needExtraProj = cop.needExtraProj || isNew + } + + if ds.TableInfo.GetPartitionInfo() != nil { + // Add sort items for index scan for merge-sort operation between partitions, only required for local index. 
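+ // A global index covers all partitions in a single sorted key space, so one ordered scan is
+ // already enough; only local (per-partition) index scans need ByItems for the merge-sort step.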
+ if !is.Index.Global { + byItems := make([]*util.ByItems, 0, len(prop.SortItems)) + for _, si := range prop.SortItems { + byItems = append(byItems, &util.ByItems{ + Expr: si.Col, + Desc: si.Desc, + }) + } + cop.indexPlan.(*PhysicalIndexScan).ByItems = byItems + } + if cop.tablePlan != nil && ds.SCtx().GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { + if !is.Index.Global { + is.Columns, is.schema, _ = AddExtraPhysTblIDColumn(is.SCtx(), is.Columns, is.Schema()) + } + var succ bool + // global index for tableScan with keepOrder also need PhysicalTblID + ts := cop.tablePlan.(*PhysicalTableScan) + ts.Columns, ts.schema, succ = AddExtraPhysTblIDColumn(ts.SCtx(), ts.Columns, ts.Schema()) + cop.needExtraProj = cop.needExtraProj || succ + } + } + } + if cop.needExtraProj { + cop.originSchema = ds.Schema() + } + // prop.IsSortItemEmpty() would always return true when coming to here, + // so we can just use prop.ExpectedCnt as parameter of addPushedDownSelection. + finalStats := ds.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt) + if err = is.addPushedDownSelection(cop, ds, path, finalStats); err != nil { + return base.InvalidTask, err + } + if prop.TaskTp == property.RootTaskType { + task = task.ConvertToRootTask(ds.SCtx()) + } else if _, ok := task.(*RootTask); ok { + return base.InvalidTask, nil + } + return task, nil +} + +func (is *PhysicalIndexScan) getScanRowSize() float64 { + idx := is.Index + scanCols := make([]*expression.Column, 0, len(idx.Columns)+1) + // If `initSchema` has already appended the handle column in schema, just use schema columns, otherwise, add extra handle column. + if len(idx.Columns) == len(is.schema.Columns) { + scanCols = append(scanCols, is.schema.Columns...) + handleCol := is.pkIsHandleCol + if handleCol != nil { + scanCols = append(scanCols, handleCol) + } + } else { + scanCols = is.schema.Columns + } + return cardinality.GetIndexAvgRowSize(is.SCtx(), is.tblColHists, scanCols, is.Index.Unique) +} + +// initSchema is used to set the schema of PhysicalIndexScan. Before calling this, +// make sure the following field of PhysicalIndexScan are initialized: +// +// PhysicalIndexScan.Table *model.TableInfo +// PhysicalIndexScan.Index *model.IndexInfo +// PhysicalIndexScan.Index.Columns []*IndexColumn +// PhysicalIndexScan.IdxCols []*expression.Column +// PhysicalIndexScan.Columns []*model.ColumnInfo +func (is *PhysicalIndexScan) initSchema(idxExprCols []*expression.Column, isDoubleRead bool) { + indexCols := make([]*expression.Column, len(is.IdxCols), len(is.Index.Columns)+1) + copy(indexCols, is.IdxCols) + + for i := len(is.IdxCols); i < len(is.Index.Columns); i++ { + if idxExprCols[i] != nil { + indexCols = append(indexCols, idxExprCols[i]) + } else { + // TODO: try to reuse the col generated when building the DataSource. 
+ indexCols = append(indexCols, &expression.Column{ + ID: is.Table.Columns[is.Index.Columns[i].Offset].ID, + RetType: &is.Table.Columns[is.Index.Columns[i].Offset].FieldType, + UniqueID: is.SCtx().GetSessionVars().AllocPlanColumnID(), + }) + } + } + is.NeedCommonHandle = is.Table.IsCommonHandle + + if is.NeedCommonHandle { + for i := len(is.Index.Columns); i < len(idxExprCols); i++ { + indexCols = append(indexCols, idxExprCols[i]) + } + } + setHandle := len(indexCols) > len(is.Index.Columns) + if !setHandle { + for i, col := range is.Columns { + if (mysql.HasPriKeyFlag(col.GetFlag()) && is.Table.PKIsHandle) || col.ID == model.ExtraHandleID { + indexCols = append(indexCols, is.dataSourceSchema.Columns[i]) + setHandle = true + break + } + } + } + + var extraPhysTblCol *expression.Column + // If `dataSouceSchema` contains `model.ExtraPhysTblID`, we should add it into `indexScan.schema` + for _, col := range is.dataSourceSchema.Columns { + if col.ID == model.ExtraPhysTblID { + extraPhysTblCol = col.Clone().(*expression.Column) + break + } + } + + if isDoubleRead || is.Index.Global { + // If it's double read case, the first index must return handle. So we should add extra handle column + // if there isn't a handle column. + if !setHandle { + if !is.Table.IsCommonHandle { + indexCols = append(indexCols, &expression.Column{ + RetType: types.NewFieldType(mysql.TypeLonglong), + ID: model.ExtraHandleID, + UniqueID: is.SCtx().GetSessionVars().AllocPlanColumnID(), + OrigName: model.ExtraHandleName.O, + }) + } + } + // If it's global index, handle and PhysTblID columns has to be added, so that needed pids can be filtered. + if is.Index.Global && extraPhysTblCol == nil { + indexCols = append(indexCols, &expression.Column{ + RetType: types.NewFieldType(mysql.TypeLonglong), + ID: model.ExtraPhysTblID, + UniqueID: is.SCtx().GetSessionVars().AllocPlanColumnID(), + OrigName: model.ExtraPhysTblIdName.O, + }) + } + } + + if extraPhysTblCol != nil { + indexCols = append(indexCols, extraPhysTblCol) + } + + is.SetSchema(expression.NewSchema(indexCols...)) +} + +func (is *PhysicalIndexScan) addSelectionConditionForGlobalIndex(p *DataSource, physPlanPartInfo *PhysPlanPartInfo, conditions []expression.Expression) ([]expression.Expression, error) { + if !is.Index.Global { + return conditions, nil + } + args := make([]expression.Expression, 0, len(p.PartitionNames)+1) + for _, col := range is.schema.Columns { + if col.ID == model.ExtraPhysTblID { + args = append(args, col.Clone()) + break + } + } + + if len(args) != 1 { + return nil, errors.Errorf("Can't find column %s in schema %s", model.ExtraPhysTblIdName.O, is.schema) + } + + // For SQL like 'select x from t partition(p0, p1) use index(idx)', + // we will add a `Selection` like `in(t._tidb_pid, p0, p1)` into the plan. + // For truncate/drop partitions, we should only return indexes where partitions still in public state. + idxArr, err := PartitionPruning(p.SCtx(), p.table.GetPartitionedTable(), + physPlanPartInfo.PruningConds, + physPlanPartInfo.PartitionNames, + physPlanPartInfo.Columns, + physPlanPartInfo.ColumnNames) + if err != nil { + return nil, err + } + needNot := false + pInfo := p.TableInfo.GetPartitionInfo() + if len(idxArr) == 1 && idxArr[0] == FullRange { + // Only filter adding and dropping partitions. 
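+ // That is, keep rows from every partition except the ones whose definitions are still being
+ // added or are being dropped, by wrapping the `IN` list built below in a `NOT`.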
+ if len(pInfo.AddingDefinitions) == 0 && len(pInfo.DroppingDefinitions) == 0 { + return conditions, nil + } + needNot = true + for _, p := range pInfo.AddingDefinitions { + args = append(args, expression.NewInt64Const(p.ID)) + } + for _, p := range pInfo.DroppingDefinitions { + args = append(args, expression.NewInt64Const(p.ID)) + } + } else if len(idxArr) == 0 { + // add an invalid pid as param for `IN` function + args = append(args, expression.NewInt64Const(-1)) + } else { + // `PartitionPruning`` func does not return adding and dropping partitions + for _, idx := range idxArr { + args = append(args, expression.NewInt64Const(pInfo.Definitions[idx].ID)) + } + } + condition, err := expression.NewFunction(p.SCtx().GetExprCtx(), ast.In, types.NewFieldType(mysql.TypeLonglong), args...) + if err != nil { + return nil, err + } + if needNot { + condition, err = expression.NewFunction(p.SCtx().GetExprCtx(), ast.UnaryNot, types.NewFieldType(mysql.TypeLonglong), condition) + if err != nil { + return nil, err + } + } + return append(conditions, condition), nil +} + +func (is *PhysicalIndexScan) addPushedDownSelection(copTask *CopTask, p *DataSource, path *util.AccessPath, finalStats *property.StatsInfo) error { + // Add filter condition to table plan now. + indexConds, tableConds := path.IndexFilters, path.TableFilters + tableConds, copTask.rootTaskConds = SplitSelCondsWithVirtualColumn(tableConds) + + var newRootConds []expression.Expression + pctx := GetPushDownCtx(is.SCtx()) + indexConds, newRootConds = expression.PushDownExprs(pctx, indexConds, kv.TiKV) + copTask.rootTaskConds = append(copTask.rootTaskConds, newRootConds...) + + tableConds, newRootConds = expression.PushDownExprs(pctx, tableConds, kv.TiKV) + copTask.rootTaskConds = append(copTask.rootTaskConds, newRootConds...) + + // Add a `Selection` for `IndexScan` with global index. + // It should pushdown to TiKV, DataSource schema doesn't contain partition id column. + indexConds, err := is.addSelectionConditionForGlobalIndex(p, copTask.physPlanPartInfo, indexConds) + if err != nil { + return err + } + + if indexConds != nil { + var selectivity float64 + if path.CountAfterAccess > 0 { + selectivity = path.CountAfterIndex / path.CountAfterAccess + } + count := is.StatsInfo().RowCount * selectivity + stats := p.TableStats.ScaleByExpectCnt(count) + indexSel := PhysicalSelection{Conditions: indexConds}.Init(is.SCtx(), stats, is.QueryBlockOffset()) + indexSel.SetChildren(is) + copTask.indexPlan = indexSel + } + if len(tableConds) > 0 { + copTask.finishIndexPlan() + tableSel := PhysicalSelection{Conditions: tableConds}.Init(is.SCtx(), finalStats, is.QueryBlockOffset()) + if len(copTask.rootTaskConds) != 0 { + selectivity, _, err := cardinality.Selectivity(is.SCtx(), copTask.tblColHists, tableConds, nil) + if err != nil { + logutil.BgLogger().Debug("calculate selectivity failed, use selection factor", zap.Error(err)) + selectivity = cost.SelectionFactor + } + tableSel.SetStats(copTask.Plan().StatsInfo().Scale(selectivity)) + } + tableSel.SetChildren(copTask.tablePlan) + copTask.tablePlan = tableSel + } + return nil +} + +// NeedExtraOutputCol is designed for check whether need an extra column for +// pid or physical table id when build indexReq. 
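+// A global index needs the partition id (pid) of each row, while a local index scan that carries
+// ByItems under dynamic partition pruning needs the physical table id for merge-sort.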
+func (is *PhysicalIndexScan) NeedExtraOutputCol() bool { + if is.Table.Partition == nil { + return false + } + // has global index, should return pid + if is.Index.Global { + return true + } + // has embedded limit, should return physical table id + if len(is.ByItems) != 0 && is.SCtx().GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { + return true + } + return false +} + +// SplitSelCondsWithVirtualColumn filter the select conditions which contain virtual column +func SplitSelCondsWithVirtualColumn(conds []expression.Expression) (withoutVirt []expression.Expression, withVirt []expression.Expression) { + for i := range conds { + if expression.ContainVirtualColumn(conds[i : i+1]) { + withVirt = append(withVirt, conds[i]) + } else { + withoutVirt = append(withoutVirt, conds[i]) + } + } + return withoutVirt, withVirt +} + +func matchIndicesProp(sctx base.PlanContext, idxCols []*expression.Column, colLens []int, propItems []property.SortItem) bool { + if len(idxCols) < len(propItems) { + return false + } + for i, item := range propItems { + if colLens[i] != types.UnspecifiedLength || !item.Col.EqualByExprAndID(sctx.GetExprCtx().GetEvalCtx(), idxCols[i]) { + return false + } + } + return true +} + +func (ds *DataSource) splitIndexFilterConditions(conditions []expression.Expression, indexColumns []*expression.Column, + idxColLens []int) (indexConds, tableConds []expression.Expression) { + var indexConditions, tableConditions []expression.Expression + for _, cond := range conditions { + var covered bool + if ds.SCtx().GetSessionVars().OptPrefixIndexSingleScan { + covered = ds.isIndexCoveringCondition(cond, indexColumns, idxColLens) + } else { + covered = ds.isIndexCoveringColumns(expression.ExtractColumns(cond), indexColumns, idxColLens) + } + if covered { + indexConditions = append(indexConditions, cond) + } else { + tableConditions = append(tableConditions, cond) + } + } + return indexConditions, tableConditions +} + +// GetPhysicalScan4LogicalTableScan returns PhysicalTableScan for the LogicalTableScan. +func GetPhysicalScan4LogicalTableScan(s *LogicalTableScan, schema *expression.Schema, stats *property.StatsInfo) *PhysicalTableScan { + ds := s.Source + ts := PhysicalTableScan{ + Table: ds.TableInfo, + Columns: ds.Columns, + TableAsName: ds.TableAsName, + DBName: ds.DBName, + isPartition: ds.PartitionDefIdx != nil, + physicalTableID: ds.PhysicalTableID, + Ranges: s.Ranges, + AccessCondition: s.AccessConds, + tblCols: ds.TblCols, + tblColHists: ds.TblColHists, + }.Init(s.SCtx(), s.QueryBlockOffset()) + ts.SetStats(stats) + ts.SetSchema(schema.Clone()) + return ts +} + +// GetPhysicalIndexScan4LogicalIndexScan returns PhysicalIndexScan for the logical IndexScan. +func GetPhysicalIndexScan4LogicalIndexScan(s *LogicalIndexScan, _ *expression.Schema, stats *property.StatsInfo) *PhysicalIndexScan { + ds := s.Source + is := PhysicalIndexScan{ + Table: ds.TableInfo, + TableAsName: ds.TableAsName, + DBName: ds.DBName, + Columns: s.Columns, + Index: s.Index, + IdxCols: s.IdxCols, + IdxColLens: s.IdxColLens, + AccessCondition: s.AccessConds, + Ranges: s.Ranges, + dataSourceSchema: ds.Schema(), + isPartition: ds.PartitionDefIdx != nil, + physicalTableID: ds.PhysicalTableID, + tblColHists: ds.TblColHists, + pkIsHandleCol: ds.getPKIsHandleCol(), + }.Init(ds.SCtx(), ds.QueryBlockOffset()) + is.SetStats(stats) + is.initSchema(s.FullIdxCols, s.IsDoubleRead) + return is +} + +// isPointGetPath indicates whether the conditions are point-get-able. 
+// eg: create table t(a int, b int,c int unique, primary (a,b)) +// select * from t where a = 1 and b = 1 and c =1; +// the datasource can access by primary key(a,b) or unique key c which are both point-get-able +func (ds *DataSource) isPointGetPath(path *util.AccessPath) bool { + if len(path.Ranges) < 1 { + return false + } + if !path.IsIntHandlePath { + if path.Index == nil { + return false + } + if !path.Index.Unique || path.Index.HasPrefixIndex() { + return false + } + idxColsLen := len(path.Index.Columns) + for _, ran := range path.Ranges { + if len(ran.LowVal) != idxColsLen { + return false + } + } + } + tc := ds.SCtx().GetSessionVars().StmtCtx.TypeCtx() + for _, ran := range path.Ranges { + if !ran.IsPointNonNullable(tc) { + return false + } + } + return true +} + +// convertToTableScan converts the DataSource to table scan. +func (ds *DataSource) convertToTableScan(prop *property.PhysicalProperty, candidate *candidatePath, _ *optimizetrace.PhysicalOptimizeOp) (base.Task, error) { + // It will be handled in convertToIndexScan. + if prop.TaskTp == property.CopMultiReadTaskType { + return base.InvalidTask, nil + } + if !prop.IsSortItemEmpty() && !candidate.isMatchProp { + return base.InvalidTask, nil + } + // If we need to keep order for the index scan, we should forbid the non-keep-order index scan when we try to generate the path. + if prop.IsSortItemEmpty() && candidate.path.ForceKeepOrder { + return base.InvalidTask, nil + } + // If we don't need to keep order for the index scan, we should forbid the non-keep-order index scan when we try to generate the path. + if !prop.IsSortItemEmpty() && candidate.path.ForceNoKeepOrder { + return base.InvalidTask, nil + } + ts, _ := ds.getOriginalPhysicalTableScan(prop, candidate.path, candidate.isMatchProp) + if ts.KeepOrder && ts.StoreType == kv.TiFlash && (ts.Desc || ds.SCtx().GetSessionVars().TiFlashFastScan) { + // TiFlash fast mode(https://github.com/pingcap/tidb/pull/35851) does not keep order in TableScan + return base.InvalidTask, nil + } + if ts.StoreType == kv.TiFlash { + for _, col := range ts.Columns { + if col.IsVirtualGenerated() { + col.AddFlag(mysql.GeneratedColumnFlag) + } + } + } + // In disaggregated tiflash mode, only MPP is allowed, cop and batchCop is deprecated. + // So if prop.TaskTp is RootTaskType, have to use mppTask then convert to rootTask. 
+ isTiFlashPath := ts.StoreType == kv.TiFlash + canMppConvertToRoot := prop.TaskTp == property.RootTaskType && ds.SCtx().GetSessionVars().IsMPPAllowed() && isTiFlashPath + canMppConvertToRootForDisaggregatedTiFlash := config.GetGlobalConfig().DisaggregatedTiFlash && canMppConvertToRoot + canMppConvertToRootForWhenTiFlashCopIsBanned := ds.SCtx().GetSessionVars().IsTiFlashCopBanned() && canMppConvertToRoot + if prop.TaskTp == property.MppTaskType || canMppConvertToRootForDisaggregatedTiFlash || canMppConvertToRootForWhenTiFlashCopIsBanned { + if ts.KeepOrder { + return base.InvalidTask, nil + } + if prop.MPPPartitionTp != property.AnyType { + return base.InvalidTask, nil + } + // ********************************** future deprecated start **************************/ + var hasVirtualColumn bool + for _, col := range ts.schema.Columns { + if col.VirtualExpr != nil { + ds.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because column `" + col.OrigName + "` is a virtual column which is not supported now.") + hasVirtualColumn = true + break + } + } + // in general, since MPP has supported the Gather operator to fill the virtual column, we should full lift restrictions here. + // we left them here, because cases like: + // parent-----+ + // V (when parent require a root task type here, we need convert mpp task to root task) + // projection [mpp task] [a] + // table-scan [mpp task] [a(virtual col as: b+1), b] + // in the process of converting mpp task to root task, the encapsulated table reader will use its first children schema [a] + // as its schema, so when we resolve indices later, the virtual column 'a' itself couldn't resolve itself anymore. + // + if hasVirtualColumn && !canMppConvertToRootForDisaggregatedTiFlash && !canMppConvertToRootForWhenTiFlashCopIsBanned { + return base.InvalidTask, nil + } + // ********************************** future deprecated end **************************/ + mppTask := &MppTask{ + p: ts, + partTp: property.AnyType, + tblColHists: ds.TblColHists, + } + ts.PlanPartInfo = &PhysPlanPartInfo{ + PruningConds: pushDownNot(ds.SCtx().GetExprCtx(), ds.AllConds), + PartitionNames: ds.PartitionNames, + Columns: ds.TblCols, + ColumnNames: ds.OutputNames(), + } + mppTask = ts.addPushedDownSelectionToMppTask(mppTask, ds.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt)) + var task base.Task = mppTask + if !mppTask.Invalid() { + if prop.TaskTp == property.MppTaskType && len(mppTask.rootTaskConds) > 0 { + // If got filters cannot be pushed down to tiflash, we have to make sure it will be executed in TiDB, + // So have to return a rootTask, but prop requires mppTask, cannot meet this requirement. + task = base.InvalidTask + } else if prop.TaskTp == property.RootTaskType { + // When got here, canMppConvertToRootX is true. + // This is for situations like cannot generate mppTask for some operators. + // Such as when the build side of HashJoin is Projection, + // which cannot pushdown to tiflash(because TiFlash doesn't support some expr in Proj) + // So HashJoin cannot pushdown to tiflash. But we still want TableScan to run on tiflash. + task = mppTask + task = task.ConvertToRootTask(ds.SCtx()) + } + } + return task, nil + } + if isTiFlashPath && config.GetGlobalConfig().DisaggregatedTiFlash || isTiFlashPath && ds.SCtx().GetSessionVars().IsTiFlashCopBanned() { + // prop.TaskTp is cop related, just return base.InvalidTask. 
+ return base.InvalidTask, nil + } + copTask := &CopTask{ + tablePlan: ts, + indexPlanFinished: true, + tblColHists: ds.TblColHists, + } + copTask.physPlanPartInfo = &PhysPlanPartInfo{ + PruningConds: pushDownNot(ds.SCtx().GetExprCtx(), ds.AllConds), + PartitionNames: ds.PartitionNames, + Columns: ds.TblCols, + ColumnNames: ds.OutputNames(), + } + ts.PlanPartInfo = copTask.physPlanPartInfo + var task base.Task = copTask + if candidate.isMatchProp { + copTask.keepOrder = true + if ds.TableInfo.GetPartitionInfo() != nil { + // TableScan on partition table on TiFlash can't keep order. + if ts.StoreType == kv.TiFlash { + return base.InvalidTask, nil + } + // Add sort items for table scan for merge-sort operation between partitions. + byItems := make([]*util.ByItems, 0, len(prop.SortItems)) + for _, si := range prop.SortItems { + byItems = append(byItems, &util.ByItems{ + Expr: si.Col, + Desc: si.Desc, + }) + } + ts.ByItems = byItems + } + } + ts.addPushedDownSelection(copTask, ds.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt)) + if prop.IsFlashProp() && len(copTask.rootTaskConds) != 0 { + return base.InvalidTask, nil + } + if prop.TaskTp == property.RootTaskType { + task = task.ConvertToRootTask(ds.SCtx()) + } else if _, ok := task.(*RootTask); ok { + return base.InvalidTask, nil + } + return task, nil +} + +func (ds *DataSource) convertToSampleTable(prop *property.PhysicalProperty, + candidate *candidatePath, _ *optimizetrace.PhysicalOptimizeOp) (base.Task, error) { + if prop.TaskTp == property.CopMultiReadTaskType { + return base.InvalidTask, nil + } + if !prop.IsSortItemEmpty() && !candidate.isMatchProp { + return base.InvalidTask, nil + } + if candidate.isMatchProp { + // Disable keep order property for sample table path. + return base.InvalidTask, nil + } + p := PhysicalTableSample{ + TableSampleInfo: ds.SampleInfo, + TableInfo: ds.table, + PhysicalTableID: ds.PhysicalTableID, + Desc: candidate.isMatchProp && prop.SortItems[0].Desc, + }.Init(ds.SCtx(), ds.QueryBlockOffset()) + p.schema = ds.Schema() + rt := &RootTask{} + rt.SetPlan(p) + return rt, nil +} + +func (ds *DataSource) convertToPointGet(prop *property.PhysicalProperty, candidate *candidatePath) base.Task { + if !prop.IsSortItemEmpty() && !candidate.isMatchProp { + return base.InvalidTask + } + if prop.TaskTp == property.CopMultiReadTaskType && candidate.path.IsSingleScan || + prop.TaskTp == property.CopSingleReadTaskType && !candidate.path.IsSingleScan { + return base.InvalidTask + } + + if tidbutil.IsMemDB(ds.DBName.L) { + return base.InvalidTask + } + + accessCnt := math.Min(candidate.path.CountAfterAccess, float64(1)) + pointGetPlan := PointGetPlan{ + ctx: ds.SCtx(), + AccessConditions: candidate.path.AccessConds, + schema: ds.Schema().Clone(), + dbName: ds.DBName.L, + TblInfo: ds.TableInfo, + outputNames: ds.OutputNames(), + LockWaitTime: ds.SCtx().GetSessionVars().LockWaitTimeout, + Columns: ds.Columns, + }.Init(ds.SCtx(), ds.TableStats.ScaleByExpectCnt(accessCnt), ds.QueryBlockOffset()) + if ds.PartitionDefIdx != nil { + pointGetPlan.PartitionIdx = ds.PartitionDefIdx + } + pointGetPlan.PartitionNames = ds.PartitionNames + rTsk := &RootTask{} + rTsk.SetPlan(pointGetPlan) + if candidate.path.IsIntHandlePath { + pointGetPlan.Handle = kv.IntHandle(candidate.path.Ranges[0].LowVal[0].GetInt64()) + pointGetPlan.UnsignedHandle = mysql.HasUnsignedFlag(ds.HandleCols.GetCol(0).RetType.GetFlag()) + pointGetPlan.accessCols = ds.TblCols + found := false + for i := range ds.Columns { + if ds.Columns[i].ID == ds.HandleCols.GetCol(0).ID { 
+ pointGetPlan.HandleColOffset = ds.Columns[i].Offset + found = true + break + } + } + if !found { + return base.InvalidTask + } + // Add filter condition to table plan now. + if len(candidate.path.TableFilters) > 0 { + sel := PhysicalSelection{ + Conditions: candidate.path.TableFilters, + }.Init(ds.SCtx(), ds.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), ds.QueryBlockOffset()) + sel.SetChildren(pointGetPlan) + rTsk.SetPlan(sel) + } + } else { + pointGetPlan.IndexInfo = candidate.path.Index + pointGetPlan.IdxCols = candidate.path.IdxCols + pointGetPlan.IdxColLens = candidate.path.IdxColLens + pointGetPlan.IndexValues = candidate.path.Ranges[0].LowVal + if candidate.path.IsSingleScan { + pointGetPlan.accessCols = candidate.path.IdxCols + } else { + pointGetPlan.accessCols = ds.TblCols + } + // Add index condition to table plan now. + if len(candidate.path.IndexFilters)+len(candidate.path.TableFilters) > 0 { + sel := PhysicalSelection{ + Conditions: append(candidate.path.IndexFilters, candidate.path.TableFilters...), + }.Init(ds.SCtx(), ds.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), ds.QueryBlockOffset()) + sel.SetChildren(pointGetPlan) + rTsk.SetPlan(sel) + } + } + + return rTsk +} + +func (ds *DataSource) convertToBatchPointGet(prop *property.PhysicalProperty, candidate *candidatePath) base.Task { + if !prop.IsSortItemEmpty() && !candidate.isMatchProp { + return base.InvalidTask + } + if prop.TaskTp == property.CopMultiReadTaskType && candidate.path.IsSingleScan || + prop.TaskTp == property.CopSingleReadTaskType && !candidate.path.IsSingleScan { + return base.InvalidTask + } + + accessCnt := math.Min(candidate.path.CountAfterAccess, float64(len(candidate.path.Ranges))) + batchPointGetPlan := &BatchPointGetPlan{ + ctx: ds.SCtx(), + dbName: ds.DBName.L, + AccessConditions: candidate.path.AccessConds, + TblInfo: ds.TableInfo, + KeepOrder: !prop.IsSortItemEmpty(), + Columns: ds.Columns, + PartitionNames: ds.PartitionNames, + } + if ds.PartitionDefIdx != nil { + batchPointGetPlan.SinglePartition = true + batchPointGetPlan.PartitionIdxs = []int{*ds.PartitionDefIdx} + } + if batchPointGetPlan.KeepOrder { + batchPointGetPlan.Desc = prop.SortItems[0].Desc + } + rTsk := &RootTask{} + if candidate.path.IsIntHandlePath { + for _, ran := range candidate.path.Ranges { + batchPointGetPlan.Handles = append(batchPointGetPlan.Handles, kv.IntHandle(ran.LowVal[0].GetInt64())) + } + batchPointGetPlan.accessCols = ds.TblCols + found := false + for i := range ds.Columns { + if ds.Columns[i].ID == ds.HandleCols.GetCol(0).ID { + batchPointGetPlan.HandleColOffset = ds.Columns[i].Offset + found = true + break + } + } + if !found { + return base.InvalidTask + } + + // Add filter condition to table plan now. 
+ if len(candidate.path.TableFilters) > 0 { + batchPointGetPlan.Init(ds.SCtx(), ds.TableStats.ScaleByExpectCnt(accessCnt), ds.Schema().Clone(), ds.OutputNames(), ds.QueryBlockOffset()) + sel := PhysicalSelection{ + Conditions: candidate.path.TableFilters, + }.Init(ds.SCtx(), ds.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), ds.QueryBlockOffset()) + sel.SetChildren(batchPointGetPlan) + rTsk.SetPlan(sel) + } + } else { + batchPointGetPlan.IndexInfo = candidate.path.Index + batchPointGetPlan.IdxCols = candidate.path.IdxCols + batchPointGetPlan.IdxColLens = candidate.path.IdxColLens + for _, ran := range candidate.path.Ranges { + batchPointGetPlan.IndexValues = append(batchPointGetPlan.IndexValues, ran.LowVal) + } + if !prop.IsSortItemEmpty() { + batchPointGetPlan.KeepOrder = true + batchPointGetPlan.Desc = prop.SortItems[0].Desc + } + if candidate.path.IsSingleScan { + batchPointGetPlan.accessCols = candidate.path.IdxCols + } else { + batchPointGetPlan.accessCols = ds.TblCols + } + // Add index condition to table plan now. + if len(candidate.path.IndexFilters)+len(candidate.path.TableFilters) > 0 { + batchPointGetPlan.Init(ds.SCtx(), ds.TableStats.ScaleByExpectCnt(accessCnt), ds.Schema().Clone(), ds.OutputNames(), ds.QueryBlockOffset()) + sel := PhysicalSelection{ + Conditions: append(candidate.path.IndexFilters, candidate.path.TableFilters...), + }.Init(ds.SCtx(), ds.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), ds.QueryBlockOffset()) + sel.SetChildren(batchPointGetPlan) + rTsk.SetPlan(sel) + } + } + if rTsk.GetPlan() == nil { + tmpP := batchPointGetPlan.Init(ds.SCtx(), ds.TableStats.ScaleByExpectCnt(accessCnt), ds.Schema().Clone(), ds.OutputNames(), ds.QueryBlockOffset()) + rTsk.SetPlan(tmpP) + } + + return rTsk +} + +func (ts *PhysicalTableScan) addPushedDownSelectionToMppTask(mpp *MppTask, stats *property.StatsInfo) *MppTask { + filterCondition, rootTaskConds := SplitSelCondsWithVirtualColumn(ts.filterCondition) + var newRootConds []expression.Expression + filterCondition, newRootConds = expression.PushDownExprs(GetPushDownCtx(ts.SCtx()), filterCondition, ts.StoreType) + mpp.rootTaskConds = append(rootTaskConds, newRootConds...) + + ts.filterCondition = filterCondition + // Add filter condition to table plan now. + if len(ts.filterCondition) > 0 { + sel := PhysicalSelection{Conditions: ts.filterCondition}.Init(ts.SCtx(), stats, ts.QueryBlockOffset()) + sel.SetChildren(ts) + mpp.p = sel + } + return mpp +} + +func (ts *PhysicalTableScan) addPushedDownSelection(copTask *CopTask, stats *property.StatsInfo) { + ts.filterCondition, copTask.rootTaskConds = SplitSelCondsWithVirtualColumn(ts.filterCondition) + var newRootConds []expression.Expression + ts.filterCondition, newRootConds = expression.PushDownExprs(GetPushDownCtx(ts.SCtx()), ts.filterCondition, ts.StoreType) + copTask.rootTaskConds = append(copTask.rootTaskConds, newRootConds...) + + // Add filter condition to table plan now. 
+ if len(ts.filterCondition) > 0 { + sel := PhysicalSelection{Conditions: ts.filterCondition}.Init(ts.SCtx(), stats, ts.QueryBlockOffset()) + if len(copTask.rootTaskConds) != 0 { + selectivity, _, err := cardinality.Selectivity(ts.SCtx(), copTask.tblColHists, ts.filterCondition, nil) + if err != nil { + logutil.BgLogger().Debug("calculate selectivity failed, use selection factor", zap.Error(err)) + selectivity = cost.SelectionFactor + } + sel.SetStats(ts.StatsInfo().Scale(selectivity)) + } + sel.SetChildren(ts) + copTask.tablePlan = sel + } +} + +func (ts *PhysicalTableScan) getScanRowSize() float64 { + if ts.StoreType == kv.TiKV { + return cardinality.GetTableAvgRowSize(ts.SCtx(), ts.tblColHists, ts.tblCols, ts.StoreType, true) + } + // If `ts.handleCol` is nil, then the schema of tableScan doesn't have handle column. + // This logic can be ensured in column pruning. + return cardinality.GetTableAvgRowSize(ts.SCtx(), ts.tblColHists, ts.Schema().Columns, ts.StoreType, ts.HandleCols != nil) +} + +func (ds *DataSource) getOriginalPhysicalTableScan(prop *property.PhysicalProperty, path *util.AccessPath, isMatchProp bool) (*PhysicalTableScan, float64) { + ts := PhysicalTableScan{ + Table: ds.TableInfo, + Columns: slices.Clone(ds.Columns), + TableAsName: ds.TableAsName, + DBName: ds.DBName, + isPartition: ds.PartitionDefIdx != nil, + physicalTableID: ds.PhysicalTableID, + Ranges: path.Ranges, + AccessCondition: path.AccessConds, + StoreType: path.StoreType, + HandleCols: ds.HandleCols, + tblCols: ds.TblCols, + tblColHists: ds.TblColHists, + constColsByCond: path.ConstCols, + prop: prop, + filterCondition: slices.Clone(path.TableFilters), + }.Init(ds.SCtx(), ds.QueryBlockOffset()) + ts.SetSchema(ds.Schema().Clone()) + rowCount := path.CountAfterAccess + if prop.ExpectedCnt < ds.StatsInfo().RowCount { + rowCount = cardinality.AdjustRowCountForTableScanByLimit(ds.SCtx(), + ds.StatsInfo(), ds.TableStats, ds.StatisticTable, + path, prop.ExpectedCnt, isMatchProp && prop.SortItems[0].Desc) + } + // We need NDV of columns since it may be used in cost estimation of join. Precisely speaking, + // we should track NDV of each histogram bucket, and sum up the NDV of buckets we actually need + // to scan, but this would only help improve accuracy of NDV for one column, for other columns, + // we still need to assume values are uniformly distributed. For simplicity, we use uniform-assumption + // for all columns now, as we do in `deriveStatsByFilter`. 
+ ts.SetStats(ds.TableStats.ScaleByExpectCnt(rowCount)) + usedStats := ds.SCtx().GetSessionVars().StmtCtx.GetUsedStatsInfo(false) + if usedStats != nil && usedStats.GetUsedInfo(ts.physicalTableID) != nil { + ts.usedStatsInfo = usedStats.GetUsedInfo(ts.physicalTableID) + } + if isMatchProp { + ts.Desc = prop.SortItems[0].Desc + ts.KeepOrder = true + } + return ts, rowCount +} + +func (ds *DataSource) getOriginalPhysicalIndexScan(prop *property.PhysicalProperty, path *util.AccessPath, isMatchProp bool, isSingleScan bool) *PhysicalIndexScan { + idx := path.Index + is := PhysicalIndexScan{ + Table: ds.TableInfo, + TableAsName: ds.TableAsName, + DBName: ds.DBName, + Columns: util.CloneColInfos(ds.Columns), + Index: idx, + IdxCols: path.IdxCols, + IdxColLens: path.IdxColLens, + AccessCondition: path.AccessConds, + Ranges: path.Ranges, + dataSourceSchema: ds.Schema(), + isPartition: ds.PartitionDefIdx != nil, + physicalTableID: ds.PhysicalTableID, + tblColHists: ds.TblColHists, + pkIsHandleCol: ds.getPKIsHandleCol(), + constColsByCond: path.ConstCols, + prop: prop, + }.Init(ds.SCtx(), ds.QueryBlockOffset()) + rowCount := path.CountAfterAccess + is.initSchema(append(path.FullIdxCols, ds.CommonHandleCols...), !isSingleScan) + + // If (1) there exists an index whose selectivity is smaller than the threshold, + // and (2) there is Selection on the IndexScan, we don't use the ExpectedCnt to + // adjust the estimated row count of the IndexScan. + ignoreExpectedCnt := ds.AccessPathMinSelectivity < ds.SCtx().GetSessionVars().OptOrderingIdxSelThresh && + len(path.IndexFilters)+len(path.TableFilters) > 0 + + if (isMatchProp || prop.IsSortItemEmpty()) && prop.ExpectedCnt < ds.StatsInfo().RowCount && !ignoreExpectedCnt { + rowCount = cardinality.AdjustRowCountForIndexScanByLimit(ds.SCtx(), + ds.StatsInfo(), ds.TableStats, ds.StatisticTable, + path, prop.ExpectedCnt, isMatchProp && prop.SortItems[0].Desc) + } + // ScaleByExpectCnt only allows to scale the row count smaller than the table total row count. + // But for MV index, it's possible that the IndexRangeScan row count is larger than the table total row count. + // Please see the Case 2 in CalcTotalSelectivityForMVIdxPath for an example. + if idx.MVIndex && rowCount > ds.TableStats.RowCount { + is.SetStats(ds.TableStats.Scale(rowCount / ds.TableStats.RowCount)) + } else { + is.SetStats(ds.TableStats.ScaleByExpectCnt(rowCount)) + } + usedStats := ds.SCtx().GetSessionVars().StmtCtx.GetUsedStatsInfo(false) + if usedStats != nil && usedStats.GetUsedInfo(is.physicalTableID) != nil { + is.usedStatsInfo = usedStats.GetUsedInfo(is.physicalTableID) + } + if isMatchProp { + is.Desc = prop.SortItems[0].Desc + is.KeepOrder = true + } + return is +} + +func findBestTask4LogicalCTE(p *LogicalCTE, prop *property.PhysicalProperty, counter *base.PlanCounterTp, pop *optimizetrace.PhysicalOptimizeOp) (t base.Task, cntPlan int64, err error) { + if p.ChildLen() > 0 { + return p.BaseLogicalPlan.FindBestTask(prop, counter, pop) + } + if !prop.IsSortItemEmpty() && !prop.CanAddEnforcer { + return base.InvalidTask, 1, nil + } + // The physical plan has been build when derive stats. 
+	pcte := PhysicalCTE{SeedPlan: p.Cte.seedPartPhysicalPlan, RecurPlan: p.Cte.recursivePartPhysicalPlan, CTE: p.Cte, cteAsName: p.CteAsName, cteName: p.CteName}.Init(p.SCtx(), p.StatsInfo())
+	pcte.SetSchema(p.Schema())
+	if prop.IsFlashProp() && prop.CTEProducerStatus == property.AllCTECanMpp {
+		pcte.readerReceiver = PhysicalExchangeReceiver{IsCTEReader: true}.Init(p.SCtx(), p.StatsInfo())
+		if prop.MPPPartitionTp != property.AnyType {
+			return base.InvalidTask, 1, nil
+		}
+		t = &MppTask{
+			p: pcte,
+			partTp: prop.MPPPartitionTp,
+			hashCols: prop.MPPPartitionCols,
+			tblColHists: p.StatsInfo().HistColl,
+		}
+	} else {
+		rt := &RootTask{}
+		rt.SetPlan(pcte)
+		rt.SetEmpty(false)
+		t = rt
+	}
+	if prop.CanAddEnforcer {
+		t = enforceProperty(prop, t, p.Plan.SCtx())
+	}
+	return t, 1, nil
+}
+
+func findBestTask4LogicalCTETable(lp base.LogicalPlan, prop *property.PhysicalProperty, _ *base.PlanCounterTp, _ *optimizetrace.PhysicalOptimizeOp) (t base.Task, cntPlan int64, err error) {
+	p := lp.(*logicalop.LogicalCTETable)
+	if !prop.IsSortItemEmpty() {
+		return base.InvalidTask, 0, nil
+	}
+
+	pcteTable := PhysicalCTETable{IDForStorage: p.IDForStorage}.Init(p.SCtx(), p.StatsInfo())
+	pcteTable.SetSchema(p.Schema())
+	rt := &RootTask{}
+	rt.SetPlan(pcteTable)
+	t = rt
+	return t, 1, nil
+}
+
+func appendCandidate(lp base.LogicalPlan, task base.Task, prop *property.PhysicalProperty, opt *optimizetrace.PhysicalOptimizeOp) {
+	if task == nil || task.Invalid() {
+		return
+	}
+	utilfuncp.AppendCandidate4PhysicalOptimizeOp(opt, lp, task.Plan(), prop)
+}
+
+// PushDownNot here can convert condition 'not (a != 1)' to 'a = 1'. When we build range from conds, the condition like
+// 'not (a != 1)' would not be handled so we need to convert it to 'a = 1', which can be handled when building range.
+func pushDownNot(ctx expression.BuildContext, conds []expression.Expression) []expression.Expression {
+	for i, cond := range conds {
+		conds[i] = expression.PushDownNot(ctx, cond)
+	}
+	return conds
+}
+
+func validateTableSamplePlan(ds *DataSource, t base.Task, err error) error {
+	if err != nil {
+		return err
+	}
+	if ds.SampleInfo != nil && !t.Invalid() {
+		if _, ok := t.Plan().(*PhysicalTableSample); !ok {
+			return expression.ErrInvalidTableSample.GenWithStackByArgs("plan not supported")
+		}
+	}
+	return nil
+}
diff --git a/pkg/planner/core/logical_join.go b/pkg/planner/core/logical_join.go
index 99dd7b4b97aa5..2e37caf53bf0c 100644
--- a/pkg/planner/core/logical_join.go
+++ b/pkg/planner/core/logical_join.go
@@ -600,12 +600,12 @@ func (p *LogicalJoin) PreparePossibleProperties(_ *expression.Schema, childrenPr
 // If the hint is not matched, it will get other candidates.
 // If the hint is not figured, we will pick all candidates.
 func (p *LogicalJoin) ExhaustPhysicalPlans(prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) {
-	failpoint.Inject("MockOnlyEnableIndexHashJoin", func(val failpoint.Value) {
+	if val, _err_ := failpoint.Eval(_curpkg_("MockOnlyEnableIndexHashJoin")); _err_ == nil {
 		if val.(bool) && !p.SCtx().GetSessionVars().InRestrictedSQL {
 			indexJoins, _ := tryToGetIndexJoin(p, prop)
-			failpoint.Return(indexJoins, true, nil)
+			return indexJoins, true, nil
 		}
-	})
+	}
 
 	if !isJoinHintSupportedInMPPMode(p.PreferJoinType) {
 		if hasMPPJoinHints(p.PreferJoinType) {
diff --git a/pkg/planner/core/logical_join.go__failpoint_stash__ b/pkg/planner/core/logical_join.go__failpoint_stash__
new file mode 100644
index 0000000000000..99dd7b4b97aa5
--- /dev/null
+++ b/pkg/planner/core/logical_join.go__failpoint_stash__
@@ -0,0 +1,1672 @@
+// Copyright 2024 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package core
+
+import (
+	"bytes"
+	"fmt"
+	"math"
+
+	"github.com/pingcap/failpoint"
+	"github.com/pingcap/tidb/pkg/expression"
+	"github.com/pingcap/tidb/pkg/kv"
+	"github.com/pingcap/tidb/pkg/parser/ast"
+	"github.com/pingcap/tidb/pkg/parser/mysql"
+	"github.com/pingcap/tidb/pkg/planner/cardinality"
+	"github.com/pingcap/tidb/pkg/planner/core/base"
+	"github.com/pingcap/tidb/pkg/planner/core/cost"
+	"github.com/pingcap/tidb/pkg/planner/core/operator/logicalop"
+	ruleutil "github.com/pingcap/tidb/pkg/planner/core/rule/util"
+	"github.com/pingcap/tidb/pkg/planner/funcdep"
+	"github.com/pingcap/tidb/pkg/planner/property"
+	"github.com/pingcap/tidb/pkg/planner/util"
+	"github.com/pingcap/tidb/pkg/planner/util/optimizetrace"
+	"github.com/pingcap/tidb/pkg/planner/util/utilfuncp"
+	"github.com/pingcap/tidb/pkg/types"
+	utilhint "github.com/pingcap/tidb/pkg/util/hint"
+	"github.com/pingcap/tidb/pkg/util/intset"
+	"github.com/pingcap/tidb/pkg/util/plancodec"
+)
+
+// JoinType contains CrossJoin, InnerJoin, LeftOuterJoin, RightOuterJoin, SemiJoin, AntiJoin.
+type JoinType int
+
+const (
+	// InnerJoin means inner join.
+	InnerJoin JoinType = iota
+	// LeftOuterJoin means left join.
+	LeftOuterJoin
+	// RightOuterJoin means right join.
+	RightOuterJoin
+	// SemiJoin means if row a in table A matches some rows in B, just output a.
+	SemiJoin
+	// AntiSemiJoin means if row a in table A does not match any row in B, then output a.
+	AntiSemiJoin
+	// LeftOuterSemiJoin means if row a in table A matches some rows in B, output (a, true), otherwise, output (a, false).
+	LeftOuterSemiJoin
+	// AntiLeftOuterSemiJoin means if row a in table A matches some rows in B, output (a, false), otherwise, output (a, true).
+ AntiLeftOuterSemiJoin +) + +// IsOuterJoin returns if this joiner is an outer joiner +func (tp JoinType) IsOuterJoin() bool { + return tp == LeftOuterJoin || tp == RightOuterJoin || + tp == LeftOuterSemiJoin || tp == AntiLeftOuterSemiJoin +} + +// IsSemiJoin returns if this joiner is a semi/anti-semi joiner +func (tp JoinType) IsSemiJoin() bool { + return tp == SemiJoin || tp == AntiSemiJoin || + tp == LeftOuterSemiJoin || tp == AntiLeftOuterSemiJoin +} + +func (tp JoinType) String() string { + switch tp { + case InnerJoin: + return "inner join" + case LeftOuterJoin: + return "left outer join" + case RightOuterJoin: + return "right outer join" + case SemiJoin: + return "semi join" + case AntiSemiJoin: + return "anti semi join" + case LeftOuterSemiJoin: + return "left outer semi join" + case AntiLeftOuterSemiJoin: + return "anti left outer semi join" + } + return "unsupported join type" +} + +// LogicalJoin is the logical join plan. +type LogicalJoin struct { + logicalop.LogicalSchemaProducer + + JoinType JoinType + Reordered bool + CartesianJoin bool + StraightJoin bool + + // HintInfo stores the join algorithm hint information specified by client. + HintInfo *utilhint.PlanHints + PreferJoinType uint + PreferJoinOrder bool + LeftPreferJoinType uint + RightPreferJoinType uint + + EqualConditions []*expression.ScalarFunction + // NAEQConditions means null aware equal conditions, which is used for null aware semi joins. + NAEQConditions []*expression.ScalarFunction + LeftConditions expression.CNFExprs + RightConditions expression.CNFExprs + OtherConditions expression.CNFExprs + + LeftProperties [][]*expression.Column + RightProperties [][]*expression.Column + + // DefaultValues is only used for left/right outer join, which is values the inner row's should be when the outer table + // doesn't match any inner table's row. + // That it's nil just means the default values is a slice of NULL. + // Currently, only `aggregation push down` phase will set this. + DefaultValues []types.Datum + + // FullSchema contains all the columns that the Join can output. It's ordered as [outer schema..., inner schema...]. + // This is useful for natural joins and "using" joins. In these cases, the join key columns from the + // inner side (or the right side when it's an inner join) will not be in the schema of Join. + // But upper operators should be able to find those "redundant" columns, and the user also can specifically select + // those columns, so we put the "redundant" columns here to make them be able to be found. + // + // For example: + // create table t1(a int, b int); create table t2(a int, b int); + // select * from t1 join t2 using (b); + // schema of the Join will be [t1.b, t1.a, t2.a]; FullSchema will be [t1.a, t1.b, t2.a, t2.b]. + // + // We record all columns and keep them ordered is for correctly handling SQLs like + // select t1.*, t2.* from t1 join t2 using (b); + // (*PlanBuilder).unfoldWildStar() handles the schema for such case. + FullSchema *expression.Schema + FullNames types.NameSlice + + // EqualCondOutCnt indicates the estimated count of joined rows after evaluating `EqualConditions`. + EqualCondOutCnt float64 +} + +// Init initializes LogicalJoin. +func (p LogicalJoin) Init(ctx base.PlanContext, offset int) *LogicalJoin { + p.BaseLogicalPlan = logicalop.NewBaseLogicalPlan(ctx, plancodec.TypeJoin, &p, offset) + return &p +} + +// *************************** start implementation of Plan interface *************************** + +// ExplainInfo implements Plan interface. 
+func (p *LogicalJoin) ExplainInfo() string { + evalCtx := p.SCtx().GetExprCtx().GetEvalCtx() + buffer := bytes.NewBufferString(p.JoinType.String()) + if len(p.EqualConditions) > 0 { + fmt.Fprintf(buffer, ", equal:%v", p.EqualConditions) + } + if len(p.LeftConditions) > 0 { + fmt.Fprintf(buffer, ", left cond:%s", + expression.SortedExplainExpressionList(evalCtx, p.LeftConditions)) + } + if len(p.RightConditions) > 0 { + fmt.Fprintf(buffer, ", right cond:%s", + expression.SortedExplainExpressionList(evalCtx, p.RightConditions)) + } + if len(p.OtherConditions) > 0 { + fmt.Fprintf(buffer, ", other cond:%s", + expression.SortedExplainExpressionList(evalCtx, p.OtherConditions)) + } + return buffer.String() +} + +// ReplaceExprColumns implements base.LogicalPlan interface. +func (p *LogicalJoin) ReplaceExprColumns(replace map[string]*expression.Column) { + for _, equalExpr := range p.EqualConditions { + ruleutil.ResolveExprAndReplace(equalExpr, replace) + } + for _, leftExpr := range p.LeftConditions { + ruleutil.ResolveExprAndReplace(leftExpr, replace) + } + for _, rightExpr := range p.RightConditions { + ruleutil.ResolveExprAndReplace(rightExpr, replace) + } + for _, otherExpr := range p.OtherConditions { + ruleutil.ResolveExprAndReplace(otherExpr, replace) + } +} + +// *************************** end implementation of Plan interface *************************** + +// *************************** start implementation of logicalPlan interface *************************** + +// HashCode inherits the BaseLogicalPlan.LogicalPlan.<0th> implementation. + +// PredicatePushDown implements the base.LogicalPlan.<1st> interface. +func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression, opt *optimizetrace.LogicalOptimizeOp) (ret []expression.Expression, retPlan base.LogicalPlan) { + var equalCond []*expression.ScalarFunction + var leftPushCond, rightPushCond, otherCond, leftCond, rightCond []expression.Expression + switch p.JoinType { + case LeftOuterJoin, LeftOuterSemiJoin, AntiLeftOuterSemiJoin: + predicates = p.outerJoinPropConst(predicates) + dual := Conds2TableDual(p, predicates) + if dual != nil { + appendTableDualTraceStep(p, dual, predicates, opt) + return ret, dual + } + // Handle where conditions + predicates = expression.ExtractFiltersFromDNFs(p.SCtx().GetExprCtx(), predicates) + // Only derive left where condition, because right where condition cannot be pushed down + equalCond, leftPushCond, rightPushCond, otherCond = p.extractOnCondition(predicates, true, false) + leftCond = leftPushCond + // Handle join conditions, only derive right join condition, because left join condition cannot be pushed down + _, derivedRightJoinCond := DeriveOtherConditions( + p, p.Children()[0].Schema(), p.Children()[1].Schema(), false, true) + rightCond = append(p.RightConditions, derivedRightJoinCond...) + p.RightConditions = nil + ret = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...) + ret = append(ret, rightPushCond...) 
+ case RightOuterJoin: + predicates = p.outerJoinPropConst(predicates) + dual := Conds2TableDual(p, predicates) + if dual != nil { + appendTableDualTraceStep(p, dual, predicates, opt) + return ret, dual + } + // Handle where conditions + predicates = expression.ExtractFiltersFromDNFs(p.SCtx().GetExprCtx(), predicates) + // Only derive right where condition, because left where condition cannot be pushed down + equalCond, leftPushCond, rightPushCond, otherCond = p.extractOnCondition(predicates, false, true) + rightCond = rightPushCond + // Handle join conditions, only derive left join condition, because right join condition cannot be pushed down + derivedLeftJoinCond, _ := DeriveOtherConditions( + p, p.Children()[0].Schema(), p.Children()[1].Schema(), true, false) + leftCond = append(p.LeftConditions, derivedLeftJoinCond...) + p.LeftConditions = nil + ret = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...) + ret = append(ret, leftPushCond...) + case SemiJoin, InnerJoin: + tempCond := make([]expression.Expression, 0, len(p.LeftConditions)+len(p.RightConditions)+len(p.EqualConditions)+len(p.OtherConditions)+len(predicates)) + tempCond = append(tempCond, p.LeftConditions...) + tempCond = append(tempCond, p.RightConditions...) + tempCond = append(tempCond, expression.ScalarFuncs2Exprs(p.EqualConditions)...) + tempCond = append(tempCond, p.OtherConditions...) + tempCond = append(tempCond, predicates...) + tempCond = expression.ExtractFiltersFromDNFs(p.SCtx().GetExprCtx(), tempCond) + tempCond = expression.PropagateConstant(p.SCtx().GetExprCtx(), tempCond) + // Return table dual when filter is constant false or null. + dual := Conds2TableDual(p, tempCond) + if dual != nil { + appendTableDualTraceStep(p, dual, tempCond, opt) + return ret, dual + } + equalCond, leftPushCond, rightPushCond, otherCond = p.extractOnCondition(tempCond, true, true) + p.LeftConditions = nil + p.RightConditions = nil + p.EqualConditions = equalCond + p.OtherConditions = otherCond + leftCond = leftPushCond + rightCond = rightPushCond + case AntiSemiJoin: + predicates = expression.PropagateConstant(p.SCtx().GetExprCtx(), predicates) + // Return table dual when filter is constant false or null. + dual := Conds2TableDual(p, predicates) + if dual != nil { + appendTableDualTraceStep(p, dual, predicates, opt) + return ret, dual + } + // `predicates` should only contain left conditions or constant filters. + _, leftPushCond, rightPushCond, _ = p.extractOnCondition(predicates, true, true) + // Do not derive `is not null` for anti join, since it may cause wrong results. + // For example: + // `select * from t t1 where t1.a not in (select b from t t2)` does not imply `t2.b is not null`, + // `select * from t t1 where t1.a not in (select a from t t2 where t1.b = t2.b` does not imply `t1.b is not null`, + // `select * from t t1 where not exists (select * from t t2 where t2.a = t1.a)` does not imply `t1.a is not null`, + leftCond = leftPushCond + rightCond = append(p.RightConditions, rightPushCond...) + p.RightConditions = nil + } + leftCond = expression.RemoveDupExprs(leftCond) + rightCond = expression.RemoveDupExprs(rightCond) + leftRet, lCh := p.Children()[0].PredicatePushDown(leftCond, opt) + rightRet, rCh := p.Children()[1].PredicatePushDown(rightCond, opt) + utilfuncp.AddSelection(p, lCh, leftRet, 0, opt) + utilfuncp.AddSelection(p, rCh, rightRet, 1, opt) + p.updateEQCond() + ruleutil.BuildKeyInfoPortal(p) + return ret, p.Self() +} + +// PruneColumns implements the base.LogicalPlan.<2nd> interface. 
+func (p *LogicalJoin) PruneColumns(parentUsedCols []*expression.Column, opt *optimizetrace.LogicalOptimizeOp) (base.LogicalPlan, error) { + leftCols, rightCols := p.extractUsedCols(parentUsedCols) + + var err error + p.Children()[0], err = p.Children()[0].PruneColumns(leftCols, opt) + if err != nil { + return nil, err + } + + p.Children()[1], err = p.Children()[1].PruneColumns(rightCols, opt) + if err != nil { + return nil, err + } + + p.mergeSchema() + if p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin { + joinCol := p.Schema().Columns[len(p.Schema().Columns)-1] + parentUsedCols = append(parentUsedCols, joinCol) + } + p.InlineProjection(parentUsedCols, opt) + return p, nil +} + +// FindBestTask inherits the BaseLogicalPlan.LogicalPlan.<3rd> implementation. + +// BuildKeyInfo implements the base.LogicalPlan.<4th> interface. +func (p *LogicalJoin) BuildKeyInfo(selfSchema *expression.Schema, childSchema []*expression.Schema) { + p.LogicalSchemaProducer.BuildKeyInfo(selfSchema, childSchema) + switch p.JoinType { + case SemiJoin, LeftOuterSemiJoin, AntiSemiJoin, AntiLeftOuterSemiJoin: + selfSchema.Keys = childSchema[0].Clone().Keys + case InnerJoin, LeftOuterJoin, RightOuterJoin: + // If there is no equal conditions, then cartesian product can't be prevented and unique key information will destroy. + if len(p.EqualConditions) == 0 { + return + } + lOk := false + rOk := false + // Such as 'select * from t1 join t2 where t1.a = t2.a and t1.b = t2.b'. + // If one sides (a, b) is a unique key, then the unique key information is remained. + // But we don't consider this situation currently. + // Only key made by one column is considered now. + evalCtx := p.SCtx().GetExprCtx().GetEvalCtx() + for _, expr := range p.EqualConditions { + ln := expr.GetArgs()[0].(*expression.Column) + rn := expr.GetArgs()[1].(*expression.Column) + for _, key := range childSchema[0].Keys { + if len(key) == 1 && key[0].Equal(evalCtx, ln) { + lOk = true + break + } + } + for _, key := range childSchema[1].Keys { + if len(key) == 1 && key[0].Equal(evalCtx, rn) { + rOk = true + break + } + } + } + // For inner join, if one side of one equal condition is unique key, + // another side's unique key information will all be reserved. + // If it's an outer join, NULL value will fill some position, which will destroy the unique key information. + if lOk && p.JoinType != LeftOuterJoin { + selfSchema.Keys = append(selfSchema.Keys, childSchema[1].Keys...) + } + if rOk && p.JoinType != RightOuterJoin { + selfSchema.Keys = append(selfSchema.Keys, childSchema[0].Keys...) + } + } +} + +// PushDownTopN implements the base.LogicalPlan.<5th> interface. +func (p *LogicalJoin) PushDownTopN(topNLogicalPlan base.LogicalPlan, opt *optimizetrace.LogicalOptimizeOp) base.LogicalPlan { + var topN *logicalop.LogicalTopN + if topNLogicalPlan != nil { + topN = topNLogicalPlan.(*logicalop.LogicalTopN) + } + switch p.JoinType { + case LeftOuterJoin, LeftOuterSemiJoin, AntiLeftOuterSemiJoin: + p.Children()[0] = p.pushDownTopNToChild(topN, 0, opt) + p.Children()[1] = p.Children()[1].PushDownTopN(nil, opt) + case RightOuterJoin: + p.Children()[1] = p.pushDownTopNToChild(topN, 1, opt) + p.Children()[0] = p.Children()[0].PushDownTopN(nil, opt) + default: + return p.BaseLogicalPlan.PushDownTopN(topN, opt) + } + + // The LogicalJoin may be also a LogicalApply. So we must use self to set parents. 
+ if topN != nil { + return topN.AttachChild(p.Self(), opt) + } + return p.Self() +} + +// DeriveTopN inherits the BaseLogicalPlan.LogicalPlan.<6th> implementation. + +// PredicateSimplification inherits the BaseLogicalPlan.LogicalPlan.<7th> implementation. + +// ConstantPropagation implements the base.LogicalPlan.<8th> interface. +// about the logic of constant propagation in From List. +// Query: select * from t, (select a, b from s where s.a>1) tmp where tmp.a=t.a +// Origin logical plan: +/* + +----------------+ + | LogicalJoin | + +-------^--------+ + | + +-------------+--------------+ + | | ++-----+------+ +------+------+ +| Projection | | TableScan | ++-----^------+ +-------------+ + | + | ++-----+------+ +| Selection | +| s.a>1 | ++------------+ +*/ +// 1. 'PullUpConstantPredicates': Call this function until find selection and pull up the constant predicate layer by layer +// LogicalSelection: find the s.a>1 +// LogicalProjection: get the s.a>1 and pull up it, changed to tmp.a>1 +// 2. 'addCandidateSelection': Add selection above of LogicalJoin, +// put all predicates pulled up from the lower layer into the current new selection. +// LogicalSelection: tmp.a >1 +// +// Optimized plan: +/* + +----------------+ + | Selection | + | tmp.a>1 | + +-------^--------+ + | + +-------+--------+ + | LogicalJoin | + +-------^--------+ + | + +-------------+--------------+ + | | ++-----+------+ +------+------+ +| Projection | | TableScan | ++-----^------+ +-------------+ + | + | ++-----+------+ +| Selection | +| s.a>1 | ++------------+ +*/ +// Return nil if the root of plan has not been changed +// Return new root if the root of plan is changed to selection +func (p *LogicalJoin) ConstantPropagation(parentPlan base.LogicalPlan, currentChildIdx int, opt *optimizetrace.LogicalOptimizeOp) (newRoot base.LogicalPlan) { + // step1: get constant predicate from left or right according to the JoinType + var getConstantPredicateFromLeft bool + var getConstantPredicateFromRight bool + switch p.JoinType { + case LeftOuterJoin: + getConstantPredicateFromLeft = true + case RightOuterJoin: + getConstantPredicateFromRight = true + case InnerJoin: + getConstantPredicateFromLeft = true + getConstantPredicateFromRight = true + default: + return + } + var candidateConstantPredicates []expression.Expression + if getConstantPredicateFromLeft { + candidateConstantPredicates = p.Children()[0].PullUpConstantPredicates() + } + if getConstantPredicateFromRight { + candidateConstantPredicates = append(candidateConstantPredicates, p.Children()[1].PullUpConstantPredicates()...) + } + if len(candidateConstantPredicates) == 0 { + return + } + + // step2: add selection above of LogicalJoin + return addCandidateSelection(p, currentChildIdx, parentPlan, candidateConstantPredicates, opt) +} + +// PullUpConstantPredicates inherits the BaseLogicalPlan.LogicalPlan.<9th> implementation. + +// RecursiveDeriveStats inherits the BaseLogicalPlan.LogicalPlan.<10th> implementation. + +// DeriveStats implements the base.LogicalPlan.<11th> interface. +// If the type of join is SemiJoin, the selectivity of it will be same as selection's. +// If the type of join is LeftOuterSemiJoin, it will not add or remove any row. The last column is a boolean value, whose NDV should be two. +// If the type of join is inner/outer join, the output of join(s, t) should be N(s) * N(t) / (V(s.key) * V(t.key)) * Min(s.key, t.key). +// N(s) stands for the number of rows in relation s. V(s.key) means the NDV of join key in s. 
+// This is a quite simple strategy: We assume every bucket of relation which will participate join has the same number of rows, and apply cross join for +// every matched bucket. +func (p *LogicalJoin) DeriveStats(childStats []*property.StatsInfo, selfSchema *expression.Schema, childSchema []*expression.Schema, colGroups [][]*expression.Column) (*property.StatsInfo, error) { + if p.StatsInfo() != nil { + // Reload GroupNDVs since colGroups may have changed. + p.StatsInfo().GroupNDVs = p.getGroupNDVs(colGroups, childStats) + return p.StatsInfo(), nil + } + leftProfile, rightProfile := childStats[0], childStats[1] + leftJoinKeys, rightJoinKeys, _, _ := p.GetJoinKeys() + p.EqualCondOutCnt = cardinality.EstimateFullJoinRowCount(p.SCtx(), + 0 == len(p.EqualConditions), + leftProfile, rightProfile, + leftJoinKeys, rightJoinKeys, + childSchema[0], childSchema[1], + nil, nil) + if p.JoinType == SemiJoin || p.JoinType == AntiSemiJoin { + p.SetStats(&property.StatsInfo{ + RowCount: leftProfile.RowCount * cost.SelectionFactor, + ColNDVs: make(map[int64]float64, len(leftProfile.ColNDVs)), + }) + for id, c := range leftProfile.ColNDVs { + p.StatsInfo().ColNDVs[id] = c * cost.SelectionFactor + } + return p.StatsInfo(), nil + } + if p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin { + p.SetStats(&property.StatsInfo{ + RowCount: leftProfile.RowCount, + ColNDVs: make(map[int64]float64, selfSchema.Len()), + }) + for id, c := range leftProfile.ColNDVs { + p.StatsInfo().ColNDVs[id] = c + } + p.StatsInfo().ColNDVs[selfSchema.Columns[selfSchema.Len()-1].UniqueID] = 2.0 + p.StatsInfo().GroupNDVs = p.getGroupNDVs(colGroups, childStats) + return p.StatsInfo(), nil + } + count := p.EqualCondOutCnt + if p.JoinType == LeftOuterJoin { + count = math.Max(count, leftProfile.RowCount) + } else if p.JoinType == RightOuterJoin { + count = math.Max(count, rightProfile.RowCount) + } + colNDVs := make(map[int64]float64, selfSchema.Len()) + for id, c := range leftProfile.ColNDVs { + colNDVs[id] = math.Min(c, count) + } + for id, c := range rightProfile.ColNDVs { + colNDVs[id] = math.Min(c, count) + } + p.SetStats(&property.StatsInfo{ + RowCount: count, + ColNDVs: colNDVs, + }) + p.StatsInfo().GroupNDVs = p.getGroupNDVs(colGroups, childStats) + return p.StatsInfo(), nil +} + +// ExtractColGroups implements the base.LogicalPlan.<12th> interface. +func (p *LogicalJoin) ExtractColGroups(colGroups [][]*expression.Column) [][]*expression.Column { + leftJoinKeys, rightJoinKeys, _, _ := p.GetJoinKeys() + extracted := make([][]*expression.Column, 0, 2+len(colGroups)) + if len(leftJoinKeys) > 1 && (p.JoinType == InnerJoin || p.JoinType == LeftOuterJoin || p.JoinType == RightOuterJoin) { + extracted = append(extracted, expression.SortColumns(leftJoinKeys), expression.SortColumns(rightJoinKeys)) + } + var outerSchema *expression.Schema + if p.JoinType == LeftOuterJoin || p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin { + outerSchema = p.Children()[0].Schema() + } else if p.JoinType == RightOuterJoin { + outerSchema = p.Children()[1].Schema() + } + if len(colGroups) == 0 || outerSchema == nil { + return extracted + } + _, offsets := outerSchema.ExtractColGroups(colGroups) + if len(offsets) == 0 { + return extracted + } + for _, offset := range offsets { + extracted = append(extracted, colGroups[offset]) + } + return extracted +} + +// PreparePossibleProperties implements base.LogicalPlan.<13th> interface. 
+func (p *LogicalJoin) PreparePossibleProperties(_ *expression.Schema, childrenProperties ...[][]*expression.Column) [][]*expression.Column { + leftProperties := childrenProperties[0] + rightProperties := childrenProperties[1] + // TODO: We should consider properties propagation. + p.LeftProperties = leftProperties + p.RightProperties = rightProperties + if p.JoinType == LeftOuterJoin || p.JoinType == LeftOuterSemiJoin { + rightProperties = nil + } else if p.JoinType == RightOuterJoin { + leftProperties = nil + } + resultProperties := make([][]*expression.Column, len(leftProperties)+len(rightProperties)) + for i, cols := range leftProperties { + resultProperties[i] = make([]*expression.Column, len(cols)) + copy(resultProperties[i], cols) + } + leftLen := len(leftProperties) + for i, cols := range rightProperties { + resultProperties[leftLen+i] = make([]*expression.Column, len(cols)) + copy(resultProperties[leftLen+i], cols) + } + return resultProperties +} + +// ExhaustPhysicalPlans implements the base.LogicalPlan.<14th> interface. +// it can generates hash join, index join and sort merge join. +// Firstly we check the hint, if hint is figured by user, we force to choose the corresponding physical plan. +// If the hint is not matched, it will get other candidates. +// If the hint is not figured, we will pick all candidates. +func (p *LogicalJoin) ExhaustPhysicalPlans(prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { + failpoint.Inject("MockOnlyEnableIndexHashJoin", func(val failpoint.Value) { + if val.(bool) && !p.SCtx().GetSessionVars().InRestrictedSQL { + indexJoins, _ := tryToGetIndexJoin(p, prop) + failpoint.Return(indexJoins, true, nil) + } + }) + + if !isJoinHintSupportedInMPPMode(p.PreferJoinType) { + if hasMPPJoinHints(p.PreferJoinType) { + // If there are MPP hints but has some conflicts join method hints, all the join hints are invalid. + p.SCtx().GetSessionVars().StmtCtx.SetHintWarning("The MPP join hints are in conflict, and you can only specify join method hints that are currently supported by MPP mode now") + p.PreferJoinType = 0 + } else { + // If there are no MPP hints but has some conflicts join method hints, the MPP mode will be blocked. + p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because you have used hint to specify a join algorithm which is not supported by mpp now.") + if prop.IsFlashProp() { + return nil, false, nil + } + } + } + if prop.MPPPartitionTp == property.BroadcastType { + return nil, false, nil + } + joins := make([]base.PhysicalPlan, 0, 8) + canPushToTiFlash := p.CanPushToCop(kv.TiFlash) + if p.SCtx().GetSessionVars().IsMPPAllowed() && canPushToTiFlash { + if (p.PreferJoinType & utilhint.PreferShuffleJoin) > 0 { + if shuffleJoins := tryToGetMppHashJoin(p, prop, false); len(shuffleJoins) > 0 { + return shuffleJoins, true, nil + } + } + if (p.PreferJoinType & utilhint.PreferBCJoin) > 0 { + if bcastJoins := tryToGetMppHashJoin(p, prop, true); len(bcastJoins) > 0 { + return bcastJoins, true, nil + } + } + if preferMppBCJ(p) { + mppJoins := tryToGetMppHashJoin(p, prop, true) + joins = append(joins, mppJoins...) + } else { + mppJoins := tryToGetMppHashJoin(p, prop, false) + joins = append(joins, mppJoins...) 
+ } + } else { + hasMppHints := false + var errMsg string + if (p.PreferJoinType & utilhint.PreferShuffleJoin) > 0 { + errMsg = "The join can not push down to the MPP side, the shuffle_join() hint is invalid" + hasMppHints = true + } + if (p.PreferJoinType & utilhint.PreferBCJoin) > 0 { + errMsg = "The join can not push down to the MPP side, the broadcast_join() hint is invalid" + hasMppHints = true + } + if hasMppHints { + p.SCtx().GetSessionVars().StmtCtx.SetHintWarning(errMsg) + } + } + if prop.IsFlashProp() { + return joins, true, nil + } + + if !p.IsNAAJ() { + // naaj refuse merge join and index join. + mergeJoins := GetMergeJoin(p, prop, p.Schema(), p.StatsInfo(), p.Children()[0].StatsInfo(), p.Children()[1].StatsInfo()) + if (p.PreferJoinType&utilhint.PreferMergeJoin) > 0 && len(mergeJoins) > 0 { + return mergeJoins, true, nil + } + joins = append(joins, mergeJoins...) + + indexJoins, forced := tryToGetIndexJoin(p, prop) + if forced { + return indexJoins, true, nil + } + joins = append(joins, indexJoins...) + } + + hashJoins, forced := getHashJoins(p, prop) + if forced && len(hashJoins) > 0 { + return hashJoins, true, nil + } + joins = append(joins, hashJoins...) + + if p.PreferJoinType > 0 { + // If we reach here, it means we have a hint that doesn't work. + // It might be affected by the required property, so we enforce + // this property and try the hint again. + return joins, false, nil + } + return joins, true, nil +} + +// ExtractCorrelatedCols implements the base.LogicalPlan.<15th> interface. +func (p *LogicalJoin) ExtractCorrelatedCols() []*expression.CorrelatedColumn { + corCols := make([]*expression.CorrelatedColumn, 0, len(p.EqualConditions)+len(p.LeftConditions)+len(p.RightConditions)+len(p.OtherConditions)) + for _, fun := range p.EqualConditions { + corCols = append(corCols, expression.ExtractCorColumns(fun)...) + } + for _, fun := range p.LeftConditions { + corCols = append(corCols, expression.ExtractCorColumns(fun)...) + } + for _, fun := range p.RightConditions { + corCols = append(corCols, expression.ExtractCorColumns(fun)...) + } + for _, fun := range p.OtherConditions { + corCols = append(corCols, expression.ExtractCorColumns(fun)...) + } + return corCols +} + +// MaxOneRow inherits the BaseLogicalPlan.LogicalPlan.<16th> implementation. + +// Children inherits the BaseLogicalPlan.LogicalPlan.<17th> implementation. + +// SetChildren inherits the BaseLogicalPlan.LogicalPlan.<18th> implementation. + +// SetChild inherits the BaseLogicalPlan.LogicalPlan.<19th> implementation. + +// RollBackTaskMap inherits the BaseLogicalPlan.LogicalPlan.<20th> implementation. + +// CanPushToCop inherits the BaseLogicalPlan.LogicalPlan.<21st> implementation. + +// ExtractFD implements the base.LogicalPlan.<22th> interface. +func (p *LogicalJoin) ExtractFD() *funcdep.FDSet { + switch p.JoinType { + case InnerJoin: + return p.extractFDForInnerJoin(nil) + case LeftOuterJoin, RightOuterJoin: + return p.extractFDForOuterJoin(nil) + case SemiJoin: + return p.extractFDForSemiJoin(nil) + default: + return &funcdep.FDSet{HashCodeToUniqueID: make(map[string]int)} + } +} + +// GetBaseLogicalPlan inherits the BaseLogicalPlan.LogicalPlan.<23th> implementation. + +// ConvertOuterToInnerJoin implements base.LogicalPlan.<24th> interface. 
+func (p *LogicalJoin) ConvertOuterToInnerJoin(predicates []expression.Expression) base.LogicalPlan { + innerTable := p.Children()[0] + outerTable := p.Children()[1] + switchChild := false + + if p.JoinType == LeftOuterJoin { + innerTable, outerTable = outerTable, innerTable + switchChild = true + } + + // First, simplify this join + if p.JoinType == LeftOuterJoin || p.JoinType == RightOuterJoin { + canBeSimplified := false + for _, expr := range predicates { + isOk := util.IsNullRejected(p.SCtx(), innerTable.Schema(), expr) + if isOk { + canBeSimplified = true + break + } + } + if canBeSimplified { + p.JoinType = InnerJoin + } + } + + // Next simplify join children + + combinedCond := mergeOnClausePredicates(p, predicates) + if p.JoinType == LeftOuterJoin || p.JoinType == RightOuterJoin { + innerTable = innerTable.ConvertOuterToInnerJoin(combinedCond) + outerTable = outerTable.ConvertOuterToInnerJoin(predicates) + } else if p.JoinType == InnerJoin || p.JoinType == SemiJoin { + innerTable = innerTable.ConvertOuterToInnerJoin(combinedCond) + outerTable = outerTable.ConvertOuterToInnerJoin(combinedCond) + } else if p.JoinType == AntiSemiJoin { + innerTable = innerTable.ConvertOuterToInnerJoin(predicates) + outerTable = outerTable.ConvertOuterToInnerJoin(combinedCond) + } else { + innerTable = innerTable.ConvertOuterToInnerJoin(predicates) + outerTable = outerTable.ConvertOuterToInnerJoin(predicates) + } + + if switchChild { + p.SetChild(0, outerTable) + p.SetChild(1, innerTable) + } else { + p.SetChild(0, innerTable) + p.SetChild(1, outerTable) + } + + return p +} + +// *************************** end implementation of logicalPlan interface *************************** + +// IsNAAJ checks if the join is a non-adjacent-join. +func (p *LogicalJoin) IsNAAJ() bool { + return len(p.NAEQConditions) > 0 +} + +// Shallow copies a LogicalJoin struct. +func (p *LogicalJoin) Shallow() *LogicalJoin { + join := *p + return join.Init(p.SCtx(), p.QueryBlockOffset()) +} + +func (p *LogicalJoin) extractFDForSemiJoin(filtersFromApply []expression.Expression) *funcdep.FDSet { + // 1: since semi join will keep the part or all rows of the outer table, it's outer FD can be saved. + // 2: the un-projected column will be left for the upper layer projection or already be pruned from bottom up. + outerFD, _ := p.Children()[0].ExtractFD(), p.Children()[1].ExtractFD() + fds := outerFD + + eqCondSlice := expression.ScalarFuncs2Exprs(p.EqualConditions) + allConds := append(eqCondSlice, p.OtherConditions...) + allConds = append(allConds, filtersFromApply...) + notNullColsFromFilters := ExtractNotNullFromConds(allConds, p) + + constUniqueIDs := ExtractConstantCols(p.LeftConditions, p.SCtx(), fds) + + fds.MakeNotNull(notNullColsFromFilters) + fds.AddConstants(constUniqueIDs) + p.SetFDs(fds) + return fds +} + +func (p *LogicalJoin) extractFDForInnerJoin(filtersFromApply []expression.Expression) *funcdep.FDSet { + leftFD, rightFD := p.Children()[0].ExtractFD(), p.Children()[1].ExtractFD() + fds := leftFD + fds.MakeCartesianProduct(rightFD) + + eqCondSlice := expression.ScalarFuncs2Exprs(p.EqualConditions) + // some join eq conditions are stored in the OtherConditions. + allConds := append(eqCondSlice, p.OtherConditions...) + allConds = append(allConds, filtersFromApply...) 
+ notNullColsFromFilters := ExtractNotNullFromConds(allConds, p) + + constUniqueIDs := ExtractConstantCols(allConds, p.SCtx(), fds) + + equivUniqueIDs := ExtractEquivalenceCols(allConds, p.SCtx(), fds) + + fds.MakeNotNull(notNullColsFromFilters) + fds.AddConstants(constUniqueIDs) + for _, equiv := range equivUniqueIDs { + fds.AddEquivalence(equiv[0], equiv[1]) + } + // merge the not-null-cols/registered-map from both side together. + fds.NotNullCols.UnionWith(rightFD.NotNullCols) + if fds.HashCodeToUniqueID == nil { + fds.HashCodeToUniqueID = rightFD.HashCodeToUniqueID + } else { + for k, v := range rightFD.HashCodeToUniqueID { + // If there's same constant in the different subquery, we might go into this IF branch. + if _, ok := fds.HashCodeToUniqueID[k]; ok { + continue + } + fds.HashCodeToUniqueID[k] = v + } + } + for i, ok := rightFD.GroupByCols.Next(0); ok; i, ok = rightFD.GroupByCols.Next(i + 1) { + fds.GroupByCols.Insert(i) + } + fds.HasAggBuilt = fds.HasAggBuilt || rightFD.HasAggBuilt + p.SetFDs(fds) + return fds +} + +func (p *LogicalJoin) extractFDForOuterJoin(filtersFromApply []expression.Expression) *funcdep.FDSet { + outerFD, innerFD := p.Children()[0].ExtractFD(), p.Children()[1].ExtractFD() + innerCondition := p.RightConditions + outerCondition := p.LeftConditions + outerCols, innerCols := intset.NewFastIntSet(), intset.NewFastIntSet() + for _, col := range p.Children()[0].Schema().Columns { + outerCols.Insert(int(col.UniqueID)) + } + for _, col := range p.Children()[1].Schema().Columns { + innerCols.Insert(int(col.UniqueID)) + } + if p.JoinType == RightOuterJoin { + innerFD, outerFD = outerFD, innerFD + innerCondition = p.LeftConditions + outerCondition = p.RightConditions + innerCols, outerCols = outerCols, innerCols + } + + eqCondSlice := expression.ScalarFuncs2Exprs(p.EqualConditions) + allConds := append(eqCondSlice, p.OtherConditions...) + allConds = append(allConds, innerCondition...) + allConds = append(allConds, outerCondition...) + allConds = append(allConds, filtersFromApply...) + notNullColsFromFilters := ExtractNotNullFromConds(allConds, p) + + filterFD := &funcdep.FDSet{HashCodeToUniqueID: make(map[string]int)} + + constUniqueIDs := ExtractConstantCols(allConds, p.SCtx(), filterFD) + + equivUniqueIDs := ExtractEquivalenceCols(allConds, p.SCtx(), filterFD) + + filterFD.AddConstants(constUniqueIDs) + equivOuterUniqueIDs := intset.NewFastIntSet() + equivAcrossNum := 0 + for _, equiv := range equivUniqueIDs { + filterFD.AddEquivalence(equiv[0], equiv[1]) + if equiv[0].SubsetOf(outerCols) && equiv[1].SubsetOf(innerCols) { + equivOuterUniqueIDs.UnionWith(equiv[0]) + equivAcrossNum++ + continue + } + if equiv[0].SubsetOf(innerCols) && equiv[1].SubsetOf(outerCols) { + equivOuterUniqueIDs.UnionWith(equiv[1]) + equivAcrossNum++ + } + } + filterFD.MakeNotNull(notNullColsFromFilters) + + // pre-perceive the filters for the convenience judgement of 3.3.1. + var opt funcdep.ArgOpts + if equivAcrossNum > 0 { + // find the equivalence FD across left and right cols. + var outConditionCols []*expression.Column + if len(outerCondition) != 0 { + outConditionCols = append(outConditionCols, expression.ExtractColumnsFromExpressions(nil, outerCondition, nil)...) + } + if len(p.OtherConditions) != 0 { + // other condition may contain right side cols, it doesn't affect the judgement of intersection of non-left-equiv cols. + outConditionCols = append(outConditionCols, expression.ExtractColumnsFromExpressions(nil, p.OtherConditions, nil)...) 
+ } + outerConditionUniqueIDs := intset.NewFastIntSet() + for _, col := range outConditionCols { + outerConditionUniqueIDs.Insert(int(col.UniqueID)) + } + // judge whether left filters is on non-left-equiv cols. + if outerConditionUniqueIDs.Intersects(outerCols.Difference(equivOuterUniqueIDs)) { + opt.SkipFDRule331 = true + } + } else { + // if there is none across equivalence condition, skip rule 3.3.1. + opt.SkipFDRule331 = true + } + + opt.OnlyInnerFilter = len(eqCondSlice) == 0 && len(outerCondition) == 0 && len(p.OtherConditions) == 0 + if opt.OnlyInnerFilter { + // if one of the inner condition is constant false, the inner side are all null, left make constant all of that. + for _, one := range innerCondition { + if c, ok := one.(*expression.Constant); ok && c.DeferredExpr == nil && c.ParamMarker == nil { + if isTrue, err := c.Value.ToBool(p.SCtx().GetSessionVars().StmtCtx.TypeCtx()); err == nil { + if isTrue == 0 { + // c is false + opt.InnerIsFalse = true + } + } + } + } + } + + fds := outerFD + fds.MakeOuterJoin(innerFD, filterFD, outerCols, innerCols, &opt) + p.SetFDs(fds) + return fds +} + +// GetJoinKeys extracts join keys(columns) from EqualConditions. It returns left join keys, right +// join keys and an `isNullEQ` array which means the `joinKey[i]` is a `NullEQ` function. The `hasNullEQ` +// means whether there is a `NullEQ` of a join key. +func (p *LogicalJoin) GetJoinKeys() (leftKeys, rightKeys []*expression.Column, isNullEQ []bool, hasNullEQ bool) { + for _, expr := range p.EqualConditions { + leftKeys = append(leftKeys, expr.GetArgs()[0].(*expression.Column)) + rightKeys = append(rightKeys, expr.GetArgs()[1].(*expression.Column)) + isNullEQ = append(isNullEQ, expr.FuncName.L == ast.NullEQ) + hasNullEQ = hasNullEQ || expr.FuncName.L == ast.NullEQ + } + return +} + +// GetNAJoinKeys extracts join keys(columns) from NAEqualCondition. +func (p *LogicalJoin) GetNAJoinKeys() (leftKeys, rightKeys []*expression.Column) { + for _, expr := range p.NAEQConditions { + leftKeys = append(leftKeys, expr.GetArgs()[0].(*expression.Column)) + rightKeys = append(rightKeys, expr.GetArgs()[1].(*expression.Column)) + } + return +} + +// GetPotentialPartitionKeys return potential partition keys for join, the potential partition keys are +// the join keys of EqualConditions +func (p *LogicalJoin) GetPotentialPartitionKeys() (leftKeys, rightKeys []*property.MPPPartitionColumn) { + for _, expr := range p.EqualConditions { + _, coll := expr.CharsetAndCollation() + collateID := property.GetCollateIDByNameForPartition(coll) + leftKeys = append(leftKeys, &property.MPPPartitionColumn{Col: expr.GetArgs()[0].(*expression.Column), CollateID: collateID}) + rightKeys = append(rightKeys, &property.MPPPartitionColumn{Col: expr.GetArgs()[1].(*expression.Column), CollateID: collateID}) + } + return +} + +// Decorrelate eliminate the correlated column with if the col is in schema. +func (p *LogicalJoin) Decorrelate(schema *expression.Schema) { + for i, cond := range p.LeftConditions { + p.LeftConditions[i] = cond.Decorrelate(schema) + } + for i, cond := range p.RightConditions { + p.RightConditions[i] = cond.Decorrelate(schema) + } + for i, cond := range p.OtherConditions { + p.OtherConditions[i] = cond.Decorrelate(schema) + } + for i, cond := range p.EqualConditions { + p.EqualConditions[i] = cond.Decorrelate(schema).(*expression.ScalarFunction) + } +} + +// ColumnSubstituteAll is used in projection elimination in apply de-correlation. 
+// Substitutions for all conditions should be successful, otherwise, we should keep all conditions unchanged. +func (p *LogicalJoin) ColumnSubstituteAll(schema *expression.Schema, exprs []expression.Expression) (hasFail bool) { + // make a copy of exprs for convenience of substitution (may change/partially change the expr tree) + cpLeftConditions := make(expression.CNFExprs, len(p.LeftConditions)) + cpRightConditions := make(expression.CNFExprs, len(p.RightConditions)) + cpOtherConditions := make(expression.CNFExprs, len(p.OtherConditions)) + cpEqualConditions := make([]*expression.ScalarFunction, len(p.EqualConditions)) + copy(cpLeftConditions, p.LeftConditions) + copy(cpRightConditions, p.RightConditions) + copy(cpOtherConditions, p.OtherConditions) + copy(cpEqualConditions, p.EqualConditions) + + exprCtx := p.SCtx().GetExprCtx() + // try to substitute columns in these condition. + for i, cond := range cpLeftConditions { + if hasFail, cpLeftConditions[i] = expression.ColumnSubstituteAll(exprCtx, cond, schema, exprs); hasFail { + return + } + } + + for i, cond := range cpRightConditions { + if hasFail, cpRightConditions[i] = expression.ColumnSubstituteAll(exprCtx, cond, schema, exprs); hasFail { + return + } + } + + for i, cond := range cpOtherConditions { + if hasFail, cpOtherConditions[i] = expression.ColumnSubstituteAll(exprCtx, cond, schema, exprs); hasFail { + return + } + } + + for i, cond := range cpEqualConditions { + var tmp expression.Expression + if hasFail, tmp = expression.ColumnSubstituteAll(exprCtx, cond, schema, exprs); hasFail { + return + } + cpEqualConditions[i] = tmp.(*expression.ScalarFunction) + } + + // if all substituted, change them atomically here. + p.LeftConditions = cpLeftConditions + p.RightConditions = cpRightConditions + p.OtherConditions = cpOtherConditions + p.EqualConditions = cpEqualConditions + + for i := len(p.EqualConditions) - 1; i >= 0; i-- { + newCond := p.EqualConditions[i] + + // If the columns used in the new filter all come from the left child, + // we can push this filter to it. + if expression.ExprFromSchema(newCond, p.Children()[0].Schema()) { + p.LeftConditions = append(p.LeftConditions, newCond) + p.EqualConditions = append(p.EqualConditions[:i], p.EqualConditions[i+1:]...) + continue + } + + // If the columns used in the new filter all come from the right + // child, we can push this filter to it. + if expression.ExprFromSchema(newCond, p.Children()[1].Schema()) { + p.RightConditions = append(p.RightConditions, newCond) + p.EqualConditions = append(p.EqualConditions[:i], p.EqualConditions[i+1:]...) + continue + } + + _, lhsIsCol := newCond.GetArgs()[0].(*expression.Column) + _, rhsIsCol := newCond.GetArgs()[1].(*expression.Column) + + // If the columns used in the new filter are not all expression.Column, + // we can not use it as join's equal condition. + if !(lhsIsCol && rhsIsCol) { + p.OtherConditions = append(p.OtherConditions, newCond) + p.EqualConditions = append(p.EqualConditions[:i], p.EqualConditions[i+1:]...) + continue + } + + p.EqualConditions[i] = newCond + } + return false +} + +// AttachOnConds extracts on conditions for join and set the `EqualConditions`, `LeftConditions`, `RightConditions` and +// `OtherConditions` by the result of extract. +func (p *LogicalJoin) AttachOnConds(onConds []expression.Expression) { + eq, left, right, other := p.extractOnCondition(onConds, false, false) + p.AppendJoinConds(eq, left, right, other) +} + +// AppendJoinConds appends new join conditions. 
+func (p *LogicalJoin) AppendJoinConds(eq []*expression.ScalarFunction, left, right, other []expression.Expression) { + p.EqualConditions = append(eq, p.EqualConditions...) + p.LeftConditions = append(left, p.LeftConditions...) + p.RightConditions = append(right, p.RightConditions...) + p.OtherConditions = append(other, p.OtherConditions...) +} + +// ExtractJoinKeys extract join keys as a schema for child with childIdx. +func (p *LogicalJoin) ExtractJoinKeys(childIdx int) *expression.Schema { + joinKeys := make([]*expression.Column, 0, len(p.EqualConditions)) + for _, eqCond := range p.EqualConditions { + joinKeys = append(joinKeys, eqCond.GetArgs()[childIdx].(*expression.Column)) + } + return expression.NewSchema(joinKeys...) +} + +// extractUsedCols extracts all the needed columns. +func (p *LogicalJoin) extractUsedCols(parentUsedCols []*expression.Column) (leftCols []*expression.Column, rightCols []*expression.Column) { + for _, eqCond := range p.EqualConditions { + parentUsedCols = append(parentUsedCols, expression.ExtractColumns(eqCond)...) + } + for _, leftCond := range p.LeftConditions { + parentUsedCols = append(parentUsedCols, expression.ExtractColumns(leftCond)...) + } + for _, rightCond := range p.RightConditions { + parentUsedCols = append(parentUsedCols, expression.ExtractColumns(rightCond)...) + } + for _, otherCond := range p.OtherConditions { + parentUsedCols = append(parentUsedCols, expression.ExtractColumns(otherCond)...) + } + for _, naeqCond := range p.NAEQConditions { + parentUsedCols = append(parentUsedCols, expression.ExtractColumns(naeqCond)...) + } + lChild := p.Children()[0] + rChild := p.Children()[1] + for _, col := range parentUsedCols { + if lChild.Schema().Contains(col) { + leftCols = append(leftCols, col) + } else if rChild.Schema().Contains(col) { + rightCols = append(rightCols, col) + } + } + return leftCols, rightCols +} + +// MergeSchema merge the schema of left and right child of join. +func (p *LogicalJoin) mergeSchema() { + p.SetSchema(buildLogicalJoinSchema(p.JoinType, p)) +} + +// pushDownTopNToChild will push a topN to one child of join. The idx stands for join child index. 0 is for left child. 
+func (p *LogicalJoin) pushDownTopNToChild(topN *logicalop.LogicalTopN, idx int, opt *optimizetrace.LogicalOptimizeOp) base.LogicalPlan { + if topN == nil { + return p.Children()[idx].PushDownTopN(nil, opt) + } + + for _, by := range topN.ByItems { + cols := expression.ExtractColumns(by.Expr) + for _, col := range cols { + if !p.Children()[idx].Schema().Contains(col) { + return p.Children()[idx].PushDownTopN(nil, opt) + } + } + } + + newTopN := logicalop.LogicalTopN{ + Count: topN.Count + topN.Offset, + ByItems: make([]*util.ByItems, len(topN.ByItems)), + PreferLimitToCop: topN.PreferLimitToCop, + }.Init(topN.SCtx(), topN.QueryBlockOffset()) + for i := range topN.ByItems { + newTopN.ByItems[i] = topN.ByItems[i].Clone() + } + appendTopNPushDownJoinTraceStep(p, newTopN, idx, opt) + return p.Children()[idx].PushDownTopN(newTopN, opt) +} + +// Add a new selection between parent plan and current plan with candidate predicates +/* ++-------------+ +-------------+ +| parentPlan | | parentPlan | ++-----^-------+ +-----^-------+ + | --addCandidateSelection---> | ++-----+-------+ +-----------+--------------+ +| currentPlan | | selection | ++-------------+ | candidate predicate | + +-----------^--------------+ + | + | + +----+--------+ + | currentPlan | + +-------------+ +*/ +// If the currentPlan at the top of query plan, return new root plan (selection) +// Else return nil +func addCandidateSelection(currentPlan base.LogicalPlan, currentChildIdx int, parentPlan base.LogicalPlan, + candidatePredicates []expression.Expression, opt *optimizetrace.LogicalOptimizeOp) (newRoot base.LogicalPlan) { + // generate a new selection for candidatePredicates + selection := LogicalSelection{Conditions: candidatePredicates}.Init(currentPlan.SCtx(), currentPlan.QueryBlockOffset()) + // add selection above of p + if parentPlan == nil { + newRoot = selection + } else { + parentPlan.SetChild(currentChildIdx, selection) + } + selection.SetChildren(currentPlan) + appendAddSelectionTraceStep(parentPlan, currentPlan, selection, opt) + if parentPlan == nil { + return newRoot + } + return nil +} + +func (p *LogicalJoin) getGroupNDVs(colGroups [][]*expression.Column, childStats []*property.StatsInfo) []property.GroupNDV { + outerIdx := int(-1) + if p.JoinType == LeftOuterJoin || p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin { + outerIdx = 0 + } else if p.JoinType == RightOuterJoin { + outerIdx = 1 + } + if outerIdx >= 0 && len(colGroups) > 0 { + return childStats[outerIdx].GroupNDVs + } + return nil +} + +// PreferAny checks whether the join type is in the joinFlags. +func (p *LogicalJoin) PreferAny(joinFlags ...uint) bool { + for _, flag := range joinFlags { + if p.PreferJoinType&flag > 0 { + return true + } + } + return false +} + +// ExtractOnCondition divide conditions in CNF of join node into 4 groups. +// These conditions can be where conditions, join conditions, or collection of both. +// If deriveLeft/deriveRight is set, we would try to derive more conditions for left/right plan. 
+func (p *LogicalJoin) ExtractOnCondition( + conditions []expression.Expression, + leftSchema *expression.Schema, + rightSchema *expression.Schema, + deriveLeft bool, + deriveRight bool) (eqCond []*expression.ScalarFunction, leftCond []expression.Expression, + rightCond []expression.Expression, otherCond []expression.Expression) { + ctx := p.SCtx() + for _, expr := range conditions { + // For queries like `select a in (select a from s where s.b = t.b) from t`, + // if subquery is empty caused by `s.b = t.b`, the result should always be + // false even if t.a is null or s.a is null. To make this join "empty aware", + // we should differentiate `t.a = s.a` from other column equal conditions, so + // we put it into OtherConditions instead of EqualConditions of join. + if expression.IsEQCondFromIn(expr) { + otherCond = append(otherCond, expr) + continue + } + binop, ok := expr.(*expression.ScalarFunction) + if ok && len(binop.GetArgs()) == 2 { + arg0, lOK := binop.GetArgs()[0].(*expression.Column) + arg1, rOK := binop.GetArgs()[1].(*expression.Column) + if lOK && rOK { + leftCol := leftSchema.RetrieveColumn(arg0) + rightCol := rightSchema.RetrieveColumn(arg1) + if leftCol == nil || rightCol == nil { + leftCol = leftSchema.RetrieveColumn(arg1) + rightCol = rightSchema.RetrieveColumn(arg0) + arg0, arg1 = arg1, arg0 + } + if leftCol != nil && rightCol != nil { + if deriveLeft { + if util.IsNullRejected(ctx, leftSchema, expr) && !mysql.HasNotNullFlag(leftCol.RetType.GetFlag()) { + notNullExpr := expression.BuildNotNullExpr(ctx.GetExprCtx(), leftCol) + leftCond = append(leftCond, notNullExpr) + } + } + if deriveRight { + if util.IsNullRejected(ctx, rightSchema, expr) && !mysql.HasNotNullFlag(rightCol.RetType.GetFlag()) { + notNullExpr := expression.BuildNotNullExpr(ctx.GetExprCtx(), rightCol) + rightCond = append(rightCond, notNullExpr) + } + } + if binop.FuncName.L == ast.EQ { + cond := expression.NewFunctionInternal(ctx.GetExprCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), arg0, arg1) + eqCond = append(eqCond, cond.(*expression.ScalarFunction)) + continue + } + } + } + } + columns := expression.ExtractColumns(expr) + // `columns` may be empty, if the condition is like `correlated_column op constant`, or `constant`, + // push this kind of constant condition down according to join type. + if len(columns) == 0 { + leftCond, rightCond = p.pushDownConstExpr(expr, leftCond, rightCond, deriveLeft || deriveRight) + continue + } + allFromLeft, allFromRight := true, true + for _, col := range columns { + if !leftSchema.Contains(col) { + allFromLeft = false + } + if !rightSchema.Contains(col) { + allFromRight = false + } + } + if allFromRight { + rightCond = append(rightCond, expr) + } else if allFromLeft { + leftCond = append(leftCond, expr) + } else { + // Relax expr to two supersets: leftRelaxedCond and rightRelaxedCond, the expression now is + // `expr AND leftRelaxedCond AND rightRelaxedCond`. Motivation is to push filters down to + // children as much as possible. 
+ if deriveLeft { + leftRelaxedCond := expression.DeriveRelaxedFiltersFromDNF(ctx.GetExprCtx(), expr, leftSchema) + if leftRelaxedCond != nil { + leftCond = append(leftCond, leftRelaxedCond) + } + } + if deriveRight { + rightRelaxedCond := expression.DeriveRelaxedFiltersFromDNF(ctx.GetExprCtx(), expr, rightSchema) + if rightRelaxedCond != nil { + rightCond = append(rightCond, rightRelaxedCond) + } + } + otherCond = append(otherCond, expr) + } + } + return +} + +// pushDownConstExpr checks if the condition is from filter condition, if true, push it down to both +// children of join, whatever the join type is; if false, push it down to inner child of outer join, +// and both children of non-outer-join. +func (p *LogicalJoin) pushDownConstExpr(expr expression.Expression, leftCond []expression.Expression, + rightCond []expression.Expression, filterCond bool) ([]expression.Expression, []expression.Expression) { + switch p.JoinType { + case LeftOuterJoin, LeftOuterSemiJoin, AntiLeftOuterSemiJoin: + if filterCond { + leftCond = append(leftCond, expr) + // Append the expr to right join condition instead of `rightCond`, to make it able to be + // pushed down to children of join. + p.RightConditions = append(p.RightConditions, expr) + } else { + rightCond = append(rightCond, expr) + } + case RightOuterJoin: + if filterCond { + rightCond = append(rightCond, expr) + p.LeftConditions = append(p.LeftConditions, expr) + } else { + leftCond = append(leftCond, expr) + } + case SemiJoin, InnerJoin: + leftCond = append(leftCond, expr) + rightCond = append(rightCond, expr) + case AntiSemiJoin: + if filterCond { + leftCond = append(leftCond, expr) + } + rightCond = append(rightCond, expr) + } + return leftCond, rightCond +} + +func (p *LogicalJoin) extractOnCondition(conditions []expression.Expression, deriveLeft bool, + deriveRight bool) (eqCond []*expression.ScalarFunction, leftCond []expression.Expression, + rightCond []expression.Expression, otherCond []expression.Expression) { + return p.ExtractOnCondition(conditions, p.Children()[0].Schema(), p.Children()[1].Schema(), deriveLeft, deriveRight) +} + +// SetPreferredJoinTypeAndOrder sets the preferred join type and order for the LogicalJoin. 
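[editor's note, not part of the patch] As the comments above describe, each CNF item of an ON/WHERE condition lands in one of four buckets: simple column equalities joining both sides, conditions touching only the left child, only the right child, or everything else. A simplified standalone sketch of that classification (it skips the argument swap, not-null derivation, and DNF relaxation that the real ExtractOnCondition performs; all names below are illustrative):

package main

import "fmt"

// cond is a toy CNF item: the columns it references, plus an optional
// "col = col" shape recorded in eqLHS/eqRHS.
type cond struct {
	text         string
	cols         []string
	eqLHS, eqRHS string // non-empty only for simple column equalities
}

func covered(cols []string, schema map[string]bool) bool {
	for _, c := range cols {
		if !schema[c] {
			return false
		}
	}
	return true
}

// classify splits conditions into eq / left-only / right-only / other,
// in the spirit of ExtractOnCondition.
func classify(conds []cond, left, right map[string]bool) (eq, l, r, other []string) {
	for _, c := range conds {
		if c.eqLHS != "" && left[c.eqLHS] && right[c.eqRHS] {
			eq = append(eq, c.text)
			continue
		}
		switch {
		case covered(c.cols, left):
			l = append(l, c.text)
		case covered(c.cols, right):
			r = append(r, c.text)
		default:
			other = append(other, c.text)
		}
	}
	return
}

func main() {
	left := map[string]bool{"t1.a": true, "t1.b": true}
	right := map[string]bool{"t2.a": true, "t2.c": true}
	eq, l, r, other := classify([]cond{
		{text: "t1.a = t2.a", cols: []string{"t1.a", "t2.a"}, eqLHS: "t1.a", eqRHS: "t2.a"},
		{text: "t1.b > 1", cols: []string{"t1.b"}},
		{text: "t2.c < 3", cols: []string{"t2.c"}},
		{text: "t1.b + t2.c = 5", cols: []string{"t1.b", "t2.c"}},
	}, left, right)
	fmt.Println(eq, l, r, other)
	// [t1.a = t2.a] [t1.b > 1] [t2.c < 3] [t1.b + t2.c = 5]
}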
+func (p *LogicalJoin) SetPreferredJoinTypeAndOrder(hintInfo *utilhint.PlanHints) { + if hintInfo == nil { + return + } + + lhsAlias := extractTableAlias(p.Children()[0], p.QueryBlockOffset()) + rhsAlias := extractTableAlias(p.Children()[1], p.QueryBlockOffset()) + if hintInfo.IfPreferMergeJoin(lhsAlias) { + p.PreferJoinType |= utilhint.PreferMergeJoin + p.LeftPreferJoinType |= utilhint.PreferMergeJoin + } + if hintInfo.IfPreferMergeJoin(rhsAlias) { + p.PreferJoinType |= utilhint.PreferMergeJoin + p.RightPreferJoinType |= utilhint.PreferMergeJoin + } + if hintInfo.IfPreferNoMergeJoin(lhsAlias) { + p.PreferJoinType |= utilhint.PreferNoMergeJoin + p.LeftPreferJoinType |= utilhint.PreferNoMergeJoin + } + if hintInfo.IfPreferNoMergeJoin(rhsAlias) { + p.PreferJoinType |= utilhint.PreferNoMergeJoin + p.RightPreferJoinType |= utilhint.PreferNoMergeJoin + } + if hintInfo.IfPreferBroadcastJoin(lhsAlias) { + p.PreferJoinType |= utilhint.PreferBCJoin + p.LeftPreferJoinType |= utilhint.PreferBCJoin + } + if hintInfo.IfPreferBroadcastJoin(rhsAlias) { + p.PreferJoinType |= utilhint.PreferBCJoin + p.RightPreferJoinType |= utilhint.PreferBCJoin + } + if hintInfo.IfPreferShuffleJoin(lhsAlias) { + p.PreferJoinType |= utilhint.PreferShuffleJoin + p.LeftPreferJoinType |= utilhint.PreferShuffleJoin + } + if hintInfo.IfPreferShuffleJoin(rhsAlias) { + p.PreferJoinType |= utilhint.PreferShuffleJoin + p.RightPreferJoinType |= utilhint.PreferShuffleJoin + } + if hintInfo.IfPreferHashJoin(lhsAlias) { + p.PreferJoinType |= utilhint.PreferHashJoin + p.LeftPreferJoinType |= utilhint.PreferHashJoin + } + if hintInfo.IfPreferHashJoin(rhsAlias) { + p.PreferJoinType |= utilhint.PreferHashJoin + p.RightPreferJoinType |= utilhint.PreferHashJoin + } + if hintInfo.IfPreferNoHashJoin(lhsAlias) { + p.PreferJoinType |= utilhint.PreferNoHashJoin + p.LeftPreferJoinType |= utilhint.PreferNoHashJoin + } + if hintInfo.IfPreferNoHashJoin(rhsAlias) { + p.PreferJoinType |= utilhint.PreferNoHashJoin + p.RightPreferJoinType |= utilhint.PreferNoHashJoin + } + if hintInfo.IfPreferINLJ(lhsAlias) { + p.PreferJoinType |= utilhint.PreferLeftAsINLJInner + p.LeftPreferJoinType |= utilhint.PreferINLJ + } + if hintInfo.IfPreferINLJ(rhsAlias) { + p.PreferJoinType |= utilhint.PreferRightAsINLJInner + p.RightPreferJoinType |= utilhint.PreferINLJ + } + if hintInfo.IfPreferINLHJ(lhsAlias) { + p.PreferJoinType |= utilhint.PreferLeftAsINLHJInner + p.LeftPreferJoinType |= utilhint.PreferINLHJ + } + if hintInfo.IfPreferINLHJ(rhsAlias) { + p.PreferJoinType |= utilhint.PreferRightAsINLHJInner + p.RightPreferJoinType |= utilhint.PreferINLHJ + } + if hintInfo.IfPreferINLMJ(lhsAlias) { + p.PreferJoinType |= utilhint.PreferLeftAsINLMJInner + p.LeftPreferJoinType |= utilhint.PreferINLMJ + } + if hintInfo.IfPreferINLMJ(rhsAlias) { + p.PreferJoinType |= utilhint.PreferRightAsINLMJInner + p.RightPreferJoinType |= utilhint.PreferINLMJ + } + if hintInfo.IfPreferNoIndexJoin(lhsAlias) { + p.PreferJoinType |= utilhint.PreferNoIndexJoin + p.LeftPreferJoinType |= utilhint.PreferNoIndexJoin + } + if hintInfo.IfPreferNoIndexJoin(rhsAlias) { + p.PreferJoinType |= utilhint.PreferNoIndexJoin + p.RightPreferJoinType |= utilhint.PreferNoIndexJoin + } + if hintInfo.IfPreferNoIndexHashJoin(lhsAlias) { + p.PreferJoinType |= utilhint.PreferNoIndexHashJoin + p.LeftPreferJoinType |= utilhint.PreferNoIndexHashJoin + } + if hintInfo.IfPreferNoIndexHashJoin(rhsAlias) { + p.PreferJoinType |= utilhint.PreferNoIndexHashJoin + p.RightPreferJoinType |= utilhint.PreferNoIndexHashJoin + } + if 
hintInfo.IfPreferNoIndexMergeJoin(lhsAlias) { + p.PreferJoinType |= utilhint.PreferNoIndexMergeJoin + p.LeftPreferJoinType |= utilhint.PreferNoIndexMergeJoin + } + if hintInfo.IfPreferNoIndexMergeJoin(rhsAlias) { + p.PreferJoinType |= utilhint.PreferNoIndexMergeJoin + p.RightPreferJoinType |= utilhint.PreferNoIndexMergeJoin + } + if hintInfo.IfPreferHJBuild(lhsAlias) { + p.PreferJoinType |= utilhint.PreferLeftAsHJBuild + p.LeftPreferJoinType |= utilhint.PreferHJBuild + } + if hintInfo.IfPreferHJBuild(rhsAlias) { + p.PreferJoinType |= utilhint.PreferRightAsHJBuild + p.RightPreferJoinType |= utilhint.PreferHJBuild + } + if hintInfo.IfPreferHJProbe(lhsAlias) { + p.PreferJoinType |= utilhint.PreferLeftAsHJProbe + p.LeftPreferJoinType |= utilhint.PreferHJProbe + } + if hintInfo.IfPreferHJProbe(rhsAlias) { + p.PreferJoinType |= utilhint.PreferRightAsHJProbe + p.RightPreferJoinType |= utilhint.PreferHJProbe + } + hasConflict := false + if !p.SCtx().GetSessionVars().EnableAdvancedJoinHint || p.SCtx().GetSessionVars().StmtCtx.StraightJoinOrder { + if containDifferentJoinTypes(p.PreferJoinType) { + hasConflict = true + } + } else if p.SCtx().GetSessionVars().EnableAdvancedJoinHint { + if containDifferentJoinTypes(p.LeftPreferJoinType) || containDifferentJoinTypes(p.RightPreferJoinType) { + hasConflict = true + } + } + if hasConflict { + p.SCtx().GetSessionVars().StmtCtx.SetHintWarning( + "Join hints are conflict, you can only specify one type of join") + p.PreferJoinType = 0 + } + // set the join order + if hintInfo.LeadingJoinOrder != nil { + p.PreferJoinOrder = hintInfo.MatchTableName([]*utilhint.HintedTable{lhsAlias, rhsAlias}, hintInfo.LeadingJoinOrder) + } + // set hintInfo for further usage if this hint info can be used. + if p.PreferJoinType != 0 || p.PreferJoinOrder { + p.HintInfo = hintInfo + } +} + +// SetPreferredJoinType generates hint information for the logicalJoin based on the hint information of its left and right children. +func (p *LogicalJoin) SetPreferredJoinType() { + if p.LeftPreferJoinType == 0 && p.RightPreferJoinType == 0 { + return + } + p.PreferJoinType = setPreferredJoinTypeFromOneSide(p.LeftPreferJoinType, true) | setPreferredJoinTypeFromOneSide(p.RightPreferJoinType, false) + if containDifferentJoinTypes(p.PreferJoinType) { + p.SCtx().GetSessionVars().StmtCtx.SetHintWarning( + "Join hints conflict after join reorder phase, you can only specify one type of join") + p.PreferJoinType = 0 + } +} + +// updateEQCond will extract the arguments of a equal condition that connect two expressions. +func (p *LogicalJoin) updateEQCond() { + lChild, rChild := p.Children()[0], p.Children()[1] + var lKeys, rKeys []expression.Expression + var lNAKeys, rNAKeys []expression.Expression + // We need two steps here: + // step1: try best to extract normal EQ condition from OtherCondition to join EqualConditions. + for i := len(p.OtherConditions) - 1; i >= 0; i-- { + need2Remove := false + if eqCond, ok := p.OtherConditions[i].(*expression.ScalarFunction); ok && eqCond.FuncName.L == ast.EQ { + // If it is a column equal condition converted from `[not] in (subq)`, do not move it + // to EqualConditions, and keep it in OtherConditions. Reference comments in `extractOnCondition` + // for detailed reasons. 
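[editor's note, not part of the patch] SetPreferredJoinTypeAndOrder and SetPreferredJoinType above accumulate hints as bits in a single uint and reject combinations that request more than one positive join algorithm. A standalone sketch of that bit-flag pattern; the flag names and values below are made up for illustration (TiDB's real constants live in pkg/util/hint, and the real containDifferentJoinTypes also ignores the no_xxx flags):

package main

import (
	"fmt"
	"math/bits"
)

// Illustrative hint flags, one bit per join algorithm preference.
const (
	preferMergeJoin uint = 1 << iota
	preferHashJoin
	preferIndexJoin
)

// conflicting reports whether more than one algorithm was requested,
// analogous in spirit to containDifferentJoinTypes.
func conflicting(flags uint) bool {
	return bits.OnesCount(flags) > 1
}

func main() {
	var pref uint
	pref |= preferHashJoin  // e.g. from a HASH_JOIN hint
	pref |= preferMergeJoin // e.g. from a MERGE_JOIN hint
	if conflicting(pref) {
		fmt.Println("Join hints are conflict, you can only specify one type of join")
		pref = 0 // drop all preferences, as the builder does
	}
	fmt.Printf("final flags: %b\n", pref)
}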
+ if expression.IsEQCondFromIn(eqCond) { + continue + } + lExpr, rExpr := eqCond.GetArgs()[0], eqCond.GetArgs()[1] + if expression.ExprFromSchema(lExpr, lChild.Schema()) && expression.ExprFromSchema(rExpr, rChild.Schema()) { + lKeys = append(lKeys, lExpr) + rKeys = append(rKeys, rExpr) + need2Remove = true + } else if expression.ExprFromSchema(lExpr, rChild.Schema()) && expression.ExprFromSchema(rExpr, lChild.Schema()) { + lKeys = append(lKeys, rExpr) + rKeys = append(rKeys, lExpr) + need2Remove = true + } + } + if need2Remove { + p.OtherConditions = append(p.OtherConditions[:i], p.OtherConditions[i+1:]...) + } + } + // eg: explain select * from t1, t3 where t1.a+1 = t3.a; + // tidb only accept the join key in EqualCondition as a normal column (join OP take granted for that) + // so once we found the left and right children's schema can supply the all columns in complicated EQ condition that used by left/right key. + // we will add a layer of projection here to convert the complicated expression of EQ's left or right side to be a normal column. + adjustKeyForm := func(leftKeys, rightKeys []expression.Expression, isNA bool) { + if len(leftKeys) > 0 { + needLProj, needRProj := false, false + for i := range leftKeys { + _, lOk := leftKeys[i].(*expression.Column) + _, rOk := rightKeys[i].(*expression.Column) + needLProj = needLProj || !lOk + needRProj = needRProj || !rOk + } + + var lProj, rProj *logicalop.LogicalProjection + if needLProj { + lProj = p.getProj(0) + } + if needRProj { + rProj = p.getProj(1) + } + for i := range leftKeys { + lKey, rKey := leftKeys[i], rightKeys[i] + if lProj != nil { + lKey = lProj.AppendExpr(lKey) + } + if rProj != nil { + rKey = rProj.AppendExpr(rKey) + } + eqCond := expression.NewFunctionInternal(p.SCtx().GetExprCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), lKey, rKey) + if isNA { + p.NAEQConditions = append(p.NAEQConditions, eqCond.(*expression.ScalarFunction)) + } else { + p.EqualConditions = append(p.EqualConditions, eqCond.(*expression.ScalarFunction)) + } + } + } + } + adjustKeyForm(lKeys, rKeys, false) + + // Step2: when step1 is finished, then we can determine whether we need to extract NA-EQ from OtherCondition to NAEQConditions. + // when there are still no EqualConditions, let's try to be a NAAJ. + // todo: by now, when there is already a normal EQ condition, just keep NA-EQ as other-condition filters above it. + // eg: select * from stu where stu.name not in (select name from exam where exam.stu_id = stu.id); + // combination of and for join key is little complicated for now. + canBeNAAJ := (p.JoinType == AntiSemiJoin || p.JoinType == AntiLeftOuterSemiJoin) && len(p.EqualConditions) == 0 + if canBeNAAJ && p.SCtx().GetSessionVars().OptimizerEnableNAAJ { + var otherCond expression.CNFExprs + for i := 0; i < len(p.OtherConditions); i++ { + eqCond, ok := p.OtherConditions[i].(*expression.ScalarFunction) + if ok && eqCond.FuncName.L == ast.EQ && expression.IsEQCondFromIn(eqCond) { + // here must be a EQCondFromIn. 
+ lExpr, rExpr := eqCond.GetArgs()[0], eqCond.GetArgs()[1] + if expression.ExprFromSchema(lExpr, lChild.Schema()) && expression.ExprFromSchema(rExpr, rChild.Schema()) { + lNAKeys = append(lNAKeys, lExpr) + rNAKeys = append(rNAKeys, rExpr) + } else if expression.ExprFromSchema(lExpr, rChild.Schema()) && expression.ExprFromSchema(rExpr, lChild.Schema()) { + lNAKeys = append(lNAKeys, rExpr) + rNAKeys = append(rNAKeys, lExpr) + } + continue + } + otherCond = append(otherCond, p.OtherConditions[i]) + } + p.OtherConditions = otherCond + // here is for cases like: select (a+1, b*3) not in (select a,b from t2) from t1. + adjustKeyForm(lNAKeys, rNAKeys, true) + } +} + +func (p *LogicalJoin) getProj(idx int) *logicalop.LogicalProjection { + child := p.Children()[idx] + proj, ok := child.(*logicalop.LogicalProjection) + if ok { + return proj + } + proj = logicalop.LogicalProjection{Exprs: make([]expression.Expression, 0, child.Schema().Len())}.Init(p.SCtx(), child.QueryBlockOffset()) + for _, col := range child.Schema().Columns { + proj.Exprs = append(proj.Exprs, col) + } + proj.SetSchema(child.Schema().Clone()) + proj.SetChildren(child) + p.Children()[idx] = proj + return proj +} + +// outerJoinPropConst propagates constant equal and column equal conditions over outer join. +func (p *LogicalJoin) outerJoinPropConst(predicates []expression.Expression) []expression.Expression { + outerTable := p.Children()[0] + innerTable := p.Children()[1] + if p.JoinType == RightOuterJoin { + innerTable, outerTable = outerTable, innerTable + } + lenJoinConds := len(p.EqualConditions) + len(p.LeftConditions) + len(p.RightConditions) + len(p.OtherConditions) + joinConds := make([]expression.Expression, 0, lenJoinConds) + for _, equalCond := range p.EqualConditions { + joinConds = append(joinConds, equalCond) + } + joinConds = append(joinConds, p.LeftConditions...) + joinConds = append(joinConds, p.RightConditions...) + joinConds = append(joinConds, p.OtherConditions...) 
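[editor's note, not part of the patch] The adjustKeyForm/getProj pair above rewrites an equality such as `t1.a + 1 = t3.a` so that both join keys become plain columns: the non-column side is appended to a projection under the join and replaced by the projected column. A standalone sketch of the idea, with string expressions standing in for expression.Expression (helper names are illustrative only):

package main

import "fmt"

// proj is a toy projection: it owns a list of output expressions and hands
// back a column name for any expression appended to it.
type proj struct {
	exprs []string
}

func (p *proj) appendExpr(e string) string {
	p.exprs = append(p.exprs, e)
	return fmt.Sprintf("proj_col_%d", len(p.exprs)-1)
}

// normalizeKey returns a plain column usable as a join key: columns pass
// through, anything else is materialized in the child projection first.
func normalizeKey(p *proj, expr string, isColumn bool) string {
	if isColumn {
		return expr
	}
	return p.appendExpr(expr)
}

func main() {
	leftProj := &proj{}
	lKey := normalizeKey(leftProj, "t1.a + 1", false) // complex: routed through the projection
	rKey := "t3.a"                                    // already a column
	fmt.Printf("join key: %s = %s, extra projected exprs: %v\n", lKey, rKey, leftProj.exprs)
}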
+ p.EqualConditions = nil + p.LeftConditions = nil + p.RightConditions = nil + p.OtherConditions = nil + nullSensitive := p.JoinType == AntiLeftOuterSemiJoin || p.JoinType == LeftOuterSemiJoin + joinConds, predicates = expression.PropConstOverOuterJoin(p.SCtx().GetExprCtx(), joinConds, predicates, outerTable.Schema(), innerTable.Schema(), nullSensitive) + p.AttachOnConds(joinConds) + return predicates +} diff --git a/pkg/planner/core/logical_plan_builder.go b/pkg/planner/core/logical_plan_builder.go index 6d818d8397819..ec4640ba1cdd2 100644 --- a/pkg/planner/core/logical_plan_builder.go +++ b/pkg/planner/core/logical_plan_builder.go @@ -4476,13 +4476,13 @@ func (b *PlanBuilder) buildDataSource(ctx context.Context, tn *ast.TableName, as // If dynamic partition prune isn't enabled or global stats is not ready, we won't enable dynamic prune mode in query usePartitionProcessor := !isDynamicEnabled || (!globalStatsReady && !allowDynamicWithoutStats) - failpoint.Inject("forceDynamicPrune", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("forceDynamicPrune")); _err_ == nil { if val.(bool) { if isDynamicEnabled { usePartitionProcessor = false } } - }) + } if usePartitionProcessor { b.optFlag = b.optFlag | flagPartitionProcessor diff --git a/pkg/planner/core/logical_plan_builder.go__failpoint_stash__ b/pkg/planner/core/logical_plan_builder.go__failpoint_stash__ new file mode 100644 index 0000000000000..6d818d8397819 --- /dev/null +++ b/pkg/planner/core/logical_plan_builder.go__failpoint_stash__ @@ -0,0 +1,7284 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package core + +import ( + "context" + "fmt" + "math" + "math/bits" + "sort" + "strconv" + "strings" + "time" + "unicode" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/errctx" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/expression/aggregation" + exprctx "github.com/pingcap/tidb/pkg/expression/context" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/format" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/opcode" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/planner/core/base" + core_metrics "github.com/pingcap/tidb/pkg/planner/core/metrics" + "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" + "github.com/pingcap/tidb/pkg/planner/property" + "github.com/pingcap/tidb/pkg/planner/util" + "github.com/pingcap/tidb/pkg/planner/util/coreusage" + "github.com/pingcap/tidb/pkg/planner/util/debugtrace" + "github.com/pingcap/tidb/pkg/planner/util/fixcontrol" + "github.com/pingcap/tidb/pkg/planner/util/tablesampler" + "github.com/pingcap/tidb/pkg/privilege" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/statistics" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/table/temptable" + "github.com/pingcap/tidb/pkg/types" + driver "github.com/pingcap/tidb/pkg/types/parser_driver" + util2 "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/collate" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" + "github.com/pingcap/tidb/pkg/util/hack" + h "github.com/pingcap/tidb/pkg/util/hint" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/intset" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/plancodec" + "github.com/pingcap/tidb/pkg/util/set" + "github.com/pingcap/tidb/pkg/util/size" + "github.com/pingcap/tipb/go-tipb" + "go.uber.org/zap" +) + +const ( + // ErrExprInSelect is in select fields for the error of ErrFieldNotInGroupBy + ErrExprInSelect = "SELECT list" + // ErrExprInOrderBy is in order by items for the error of ErrFieldNotInGroupBy + ErrExprInOrderBy = "ORDER BY" +) + +// aggOrderByResolver is currently resolving expressions of order by clause +// in aggregate function GROUP_CONCAT. +type aggOrderByResolver struct { + ctx base.PlanContext + err error + args []ast.ExprNode + exprDepth int // exprDepth is the depth of current expression in expression tree. +} + +func (a *aggOrderByResolver) Enter(inNode ast.Node) (ast.Node, bool) { + a.exprDepth++ + if n, ok := inNode.(*driver.ParamMarkerExpr); ok { + if a.exprDepth == 1 { + _, isNull, isExpectedType := getUintFromNode(a.ctx, n, false) + // For constant uint expression in top level, it should be treated as position expression. 
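[editor's note, not part of the patch] The aggOrderByResolver above turns a top-level integer in an aggregate's ORDER BY (e.g. GROUP_CONCAT(a, b ORDER BY 2)) into a positional reference and then resolves it back to the aggregate argument it names, with a 1-based bounds check. A tiny standalone sketch of that resolution step (error text is only an approximation of ErrUnknownColumn):

package main

import "fmt"

// resolvePosition maps an ORDER BY position inside an aggregate back to the
// argument it refers to, rejecting out-of-range positions.
func resolvePosition(pos int, args []string) (string, error) {
	if pos < 1 || pos > len(args) {
		return "", fmt.Errorf("unknown column '%d' in 'order clause'", pos)
	}
	return args[pos-1], nil
}

func main() {
	args := []string{"a", "b"} // the SEPARATOR argument is already stripped
	expr, err := resolvePosition(2, args)
	fmt.Println(expr, err) // b <nil>
	_, err = resolvePosition(3, args)
	fmt.Println(err) // out of range: rejected
}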
+ if !isNull && isExpectedType { + return expression.ConstructPositionExpr(n), true + } + } + } + return inNode, false +} + +func (a *aggOrderByResolver) Leave(inNode ast.Node) (ast.Node, bool) { + if v, ok := inNode.(*ast.PositionExpr); ok { + pos, isNull, err := expression.PosFromPositionExpr(a.ctx.GetExprCtx(), a.ctx, v) + if err != nil { + a.err = err + } + if err != nil || isNull { + return inNode, false + } + if pos < 1 || pos > len(a.args) { + errPos := strconv.Itoa(pos) + if v.P != nil { + errPos = "?" + } + a.err = plannererrors.ErrUnknownColumn.FastGenByArgs(errPos, "order clause") + return inNode, false + } + ret := a.args[pos-1] + return ret, true + } + return inNode, true +} + +func (b *PlanBuilder) buildExpand(p base.LogicalPlan, gbyItems []expression.Expression) (base.LogicalPlan, []expression.Expression, error) { + ectx := p.SCtx().GetExprCtx().GetEvalCtx() + b.optFlag |= flagResolveExpand + + // Rollup syntax require expand OP to do the data expansion, different data replica supply the different grouping layout. + distinctGbyExprs, gbyExprsRefPos := expression.DeduplicateGbyExpression(gbyItems) + // build another projection below. + proj := logicalop.LogicalProjection{Exprs: make([]expression.Expression, 0, p.Schema().Len()+len(distinctGbyExprs))}.Init(b.ctx, b.getSelectOffset()) + // project: child's output and distinct GbyExprs in advance. (make every group-by item to be a column) + projSchema := p.Schema().Clone() + names := p.OutputNames() + for _, col := range projSchema.Columns { + proj.Exprs = append(proj.Exprs, col) + } + distinctGbyColNames := make(types.NameSlice, 0, len(distinctGbyExprs)) + distinctGbyCols := make([]*expression.Column, 0, len(distinctGbyExprs)) + for _, expr := range distinctGbyExprs { + // distinct group expr has been resolved in resolveGby. + proj.Exprs = append(proj.Exprs, expr) + + // add the newly appended names. + var name *types.FieldName + if c, ok := expr.(*expression.Column); ok { + name = buildExpandFieldName(ectx, c, names[p.Schema().ColumnIndex(c)], "") + } else { + name = buildExpandFieldName(ectx, expr, nil, "") + } + names = append(names, name) + distinctGbyColNames = append(distinctGbyColNames, name) + + // since we will change the nullability of source col, proj it with a new col id. + col := &expression.Column{ + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + // clone it rather than using it directly, + RetType: expr.GetType(b.ctx.GetExprCtx().GetEvalCtx()).Clone(), + } + + projSchema.Append(col) + distinctGbyCols = append(distinctGbyCols, col) + } + proj.SetSchema(projSchema) + proj.SetChildren(p) + // since expand will ref original col and make some change, do the copy in executor rather than ref the same chunk.column. + proj.AvoidColumnEvaluator = true + proj.Proj4Expand = true + newGbyItems := expression.RestoreGbyExpression(distinctGbyCols, gbyExprsRefPos) + + // build expand. + rollupGroupingSets := expression.RollupGroupingSets(newGbyItems) + // eg: with rollup => {},{a},{a,b},{a,b,c} + // for every grouping set above, we should individually set those not-needed grouping-set col as null value. + // eg: let's say base schema is , d is unrelated col, keep it real in every grouping set projection. 
+ // for grouping set {a,b,c}, project it as: [a, b, c, d, gid] + // for grouping set {a,b}, project it as: [a, b, null, d, gid] + // for grouping set {a}, project it as: [a, null, null, d, gid] + // for grouping set {}, project it as: [null, null, null, d, gid] + expandSchema := proj.Schema().Clone() + expression.AdjustNullabilityFromGroupingSets(rollupGroupingSets, expandSchema) + expand := LogicalExpand{ + RollupGroupingSets: rollupGroupingSets, + DistinctGroupByCol: distinctGbyCols, + DistinctGbyColNames: distinctGbyColNames, + // for resolving grouping function args. + DistinctGbyExprs: distinctGbyExprs, + + // fill the gen col names when building level projections. + }.Init(b.ctx, b.getSelectOffset()) + + // if we want to use bitAnd for the quick computation of grouping function, then the maximum capacity of num of grouping is about 64. + expand.GroupingMode = tipb.GroupingMode_ModeBitAnd + if len(expand.RollupGroupingSets) > 64 { + expand.GroupingMode = tipb.GroupingMode_ModeNumericSet + } + + expand.DistinctSize, expand.RollupGroupingIDs, expand.RollupID2GIDS = expand.RollupGroupingSets.DistinctSize() + hasDuplicateGroupingSet := len(expand.RollupGroupingSets) != expand.DistinctSize + // append the generated column for logical Expand. + tp := types.NewFieldType(mysql.TypeLonglong) + tp.SetFlag(mysql.UnsignedFlag | mysql.NotNullFlag) + gid := &expression.Column{ + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: tp, + OrigName: "gid", + } + expand.GID = gid + expandSchema.Append(gid) + expand.ExtraGroupingColNames = append(expand.ExtraGroupingColNames, gid.OrigName) + names = append(names, buildExpandFieldName(ectx, gid, nil, "gid_")) + expand.GIDName = names[len(names)-1] + if hasDuplicateGroupingSet { + // the last two col of the schema should be gid & gpos + gpos := &expression.Column{ + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: tp.Clone(), + OrigName: "gpos", + } + expand.GPos = gpos + expandSchema.Append(gpos) + expand.ExtraGroupingColNames = append(expand.ExtraGroupingColNames, gpos.OrigName) + names = append(names, buildExpandFieldName(ectx, gpos, nil, "gpos_")) + expand.GPosName = names[len(names)-1] + } + expand.SetChildren(proj) + expand.SetSchema(expandSchema) + expand.SetOutputNames(names) + + // register current rollup Expand operator in current select block. + b.currentBlockExpand = expand + + // defer generating level-projection as last logical optimization rule. + return expand, newGbyItems, nil +} + +func (b *PlanBuilder) buildAggregation(ctx context.Context, p base.LogicalPlan, aggFuncList []*ast.AggregateFuncExpr, gbyItems []expression.Expression, + correlatedAggMap map[*ast.AggregateFuncExpr]int) (base.LogicalPlan, map[int]int, error) { + b.optFlag |= flagBuildKeyInfo + b.optFlag |= flagPushDownAgg + // We may apply aggregation eliminate optimization. + // So we add the flagMaxMinEliminate to try to convert max/min to topn and flagPushDownTopN to handle the newly added topn operator. + b.optFlag |= flagMaxMinEliminate + b.optFlag |= flagPushDownTopN + // when we eliminate the max and min we may add `is not null` filter. 
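[editor's note, not part of the patch] The per-grouping-set layouts sketched in the buildExpand comments above ({a,b,c} keeps every group column, {a,b} nulls out c, and so on, with unrelated columns kept and each replica tagged with a gid) can be illustrated with a tiny standalone generator. Everything below is a simplification, not LogicalExpand's real API, and the gid values are just the count of kept columns rather than TiDB's bit-and or numeric-set encodings:

package main

import "fmt"

// rollupLayouts prints, for GROUP BY a, b, c WITH ROLLUP, the row layout each
// grouping set produces: kept group-by columns, NULL for the rest, unrelated
// columns untouched, plus an illustrative grouping id.
func rollupLayouts(groupCols []string, otherCols []string) {
	// Rollup grouping sets: {}, {a}, {a,b}, {a,b,c} -> one replica per set.
	for keep := 0; keep <= len(groupCols); keep++ {
		row := make([]string, 0, len(groupCols)+len(otherCols)+1)
		for i, c := range groupCols {
			if i < keep {
				row = append(row, c)
			} else {
				row = append(row, "NULL")
			}
		}
		row = append(row, otherCols...)
		row = append(row, fmt.Sprintf("gid=%d", keep))
		fmt.Println(row)
	}
}

func main() {
	rollupLayouts([]string{"a", "b", "c"}, []string{"d"})
	// [NULL NULL NULL d gid=0]
	// [a NULL NULL d gid=1]
	// [a b NULL d gid=2]
	// [a b c d gid=3]
}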
+ b.optFlag |= flagPredicatePushDown + b.optFlag |= flagEliminateAgg + b.optFlag |= flagEliminateProjection + + if b.ctx.GetSessionVars().EnableSkewDistinctAgg { + b.optFlag |= flagSkewDistinctAgg + } + // flag it if cte contain aggregation + if b.buildingCTE { + b.outerCTEs[len(b.outerCTEs)-1].containAggOrWindow = true + } + var rollupExpand *LogicalExpand + if expand, ok := p.(*LogicalExpand); ok { + rollupExpand = expand + } + + plan4Agg := LogicalAggregation{AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(aggFuncList))}.Init(b.ctx, b.getSelectOffset()) + if hintinfo := b.TableHints(); hintinfo != nil { + plan4Agg.PreferAggType = hintinfo.PreferAggType + plan4Agg.PreferAggToCop = hintinfo.PreferAggToCop + } + schema4Agg := expression.NewSchema(make([]*expression.Column, 0, len(aggFuncList)+p.Schema().Len())...) + names := make(types.NameSlice, 0, len(aggFuncList)+p.Schema().Len()) + // aggIdxMap maps the old index to new index after applying common aggregation functions elimination. + aggIndexMap := make(map[int]int) + + allAggsFirstRow := true + for i, aggFunc := range aggFuncList { + newArgList := make([]expression.Expression, 0, len(aggFunc.Args)) + for _, arg := range aggFunc.Args { + newArg, np, err := b.rewrite(ctx, arg, p, nil, true) + if err != nil { + return nil, nil, err + } + p = np + newArgList = append(newArgList, newArg) + } + newFunc, err := aggregation.NewAggFuncDesc(b.ctx.GetExprCtx(), aggFunc.F, newArgList, aggFunc.Distinct) + if err != nil { + return nil, nil, err + } + if newFunc.Name != ast.AggFuncFirstRow { + allAggsFirstRow = false + } + if aggFunc.Order != nil { + trueArgs := aggFunc.Args[:len(aggFunc.Args)-1] // the last argument is SEPARATOR, remote it. + resolver := &aggOrderByResolver{ + ctx: b.ctx, + args: trueArgs, + } + for _, byItem := range aggFunc.Order.Items { + resolver.exprDepth = 0 + resolver.err = nil + retExpr, _ := byItem.Expr.Accept(resolver) + if resolver.err != nil { + return nil, nil, errors.Trace(resolver.err) + } + newByItem, np, err := b.rewrite(ctx, retExpr.(ast.ExprNode), p, nil, true) + if err != nil { + return nil, nil, err + } + p = np + newFunc.OrderByItems = append(newFunc.OrderByItems, &util.ByItems{Expr: newByItem, Desc: byItem.Desc}) + } + } + // combine identical aggregate functions + combined := false + for j := 0; j < i; j++ { + oldFunc := plan4Agg.AggFuncs[aggIndexMap[j]] + if oldFunc.Equal(b.ctx.GetExprCtx().GetEvalCtx(), newFunc) { + aggIndexMap[i] = aggIndexMap[j] + combined = true + if _, ok := correlatedAggMap[aggFunc]; ok { + if _, ok = b.correlatedAggMapper[aggFuncList[j]]; !ok { + b.correlatedAggMapper[aggFuncList[j]] = &expression.CorrelatedColumn{ + Column: *schema4Agg.Columns[aggIndexMap[j]], + Data: new(types.Datum), + } + } + b.correlatedAggMapper[aggFunc] = b.correlatedAggMapper[aggFuncList[j]] + } + break + } + } + // create new columns for aggregate functions which show up first + if !combined { + position := len(plan4Agg.AggFuncs) + aggIndexMap[i] = position + plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, newFunc) + column := expression.Column{ + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: newFunc.RetTp, + } + schema4Agg.Append(&column) + names = append(names, types.EmptyName) + if _, ok := correlatedAggMap[aggFunc]; ok { + b.correlatedAggMapper[aggFunc] = &expression.CorrelatedColumn{ + Column: column, + Data: new(types.Datum), + } + } + } + } + for i, col := range p.Schema().Columns { + newFunc, err := aggregation.NewAggFuncDesc(b.ctx.GetExprCtx(), ast.AggFuncFirstRow, 
[]expression.Expression{col}, false) + if err != nil { + return nil, nil, err + } + plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, newFunc) + newCol, _ := col.Clone().(*expression.Column) + newCol.RetType = newFunc.RetTp + schema4Agg.Append(newCol) + names = append(names, p.OutputNames()[i]) + } + var ( + join *LogicalJoin + isJoin bool + isSelectionJoin bool + ) + join, isJoin = p.(*LogicalJoin) + selection, isSelection := p.(*LogicalSelection) + if isSelection { + join, isSelectionJoin = selection.Children()[0].(*LogicalJoin) + } + if (isJoin && join.FullSchema != nil) || (isSelectionJoin && join.FullSchema != nil) { + for i, col := range join.FullSchema.Columns { + if p.Schema().Contains(col) { + continue + } + newFunc, err := aggregation.NewAggFuncDesc(b.ctx.GetExprCtx(), ast.AggFuncFirstRow, []expression.Expression{col}, false) + if err != nil { + return nil, nil, err + } + plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, newFunc) + newCol, _ := col.Clone().(*expression.Column) + newCol.RetType = newFunc.RetTp + schema4Agg.Append(newCol) + names = append(names, join.FullNames[i]) + } + } + hasGroupBy := len(gbyItems) > 0 + for i, aggFunc := range plan4Agg.AggFuncs { + err := aggFunc.UpdateNotNullFlag4RetType(hasGroupBy, allAggsFirstRow) + if err != nil { + return nil, nil, err + } + schema4Agg.Columns[i].RetType = aggFunc.RetTp + } + plan4Agg.SetOutputNames(names) + plan4Agg.SetChildren(p) + if rollupExpand != nil { + // append gid and gpos as the group keys if any. + plan4Agg.GroupByItems = append(gbyItems, rollupExpand.GID) + if rollupExpand.GPos != nil { + plan4Agg.GroupByItems = append(plan4Agg.GroupByItems, rollupExpand.GPos) + } + } else { + plan4Agg.GroupByItems = gbyItems + } + plan4Agg.SetSchema(schema4Agg) + return plan4Agg, aggIndexMap, nil +} + +func (b *PlanBuilder) buildTableRefs(ctx context.Context, from *ast.TableRefsClause) (p base.LogicalPlan, err error) { + if from == nil { + p = b.buildTableDual() + return + } + defer func() { + // After build the resultSetNode, need to reset it so that it can be referenced by outer level. + for _, cte := range b.outerCTEs { + cte.recursiveRef = false + } + }() + return b.buildResultSetNode(ctx, from.TableRefs, false) +} + +func (b *PlanBuilder) buildResultSetNode(ctx context.Context, node ast.ResultSetNode, isCTE bool) (p base.LogicalPlan, err error) { + //If it is building the CTE queries, we will mark them. + b.isCTE = isCTE + switch x := node.(type) { + case *ast.Join: + return b.buildJoin(ctx, x) + case *ast.TableSource: + var isTableName bool + switch v := x.Source.(type) { + case *ast.SelectStmt: + ci := b.prepareCTECheckForSubQuery() + defer resetCTECheckForSubQuery(ci) + b.optFlag = b.optFlag | flagConstantPropagation + p, err = b.buildSelect(ctx, v) + case *ast.SetOprStmt: + ci := b.prepareCTECheckForSubQuery() + defer resetCTECheckForSubQuery(ci) + p, err = b.buildSetOpr(ctx, v) + case *ast.TableName: + p, err = b.buildDataSource(ctx, v, &x.AsName) + isTableName = true + default: + err = plannererrors.ErrUnsupportedType.GenWithStackByArgs(v) + } + if err != nil { + return nil, err + } + + for _, name := range p.OutputNames() { + if name.Hidden { + continue + } + if x.AsName.L != "" { + name.TblName = x.AsName + } + } + // `TableName` is not a select block, so we do not need to handle it. 
+ var plannerSelectBlockAsName []ast.HintTable + if p := b.ctx.GetSessionVars().PlannerSelectBlockAsName.Load(); p != nil { + plannerSelectBlockAsName = *p + } + if len(plannerSelectBlockAsName) > 0 && !isTableName { + plannerSelectBlockAsName[p.QueryBlockOffset()] = ast.HintTable{DBName: p.OutputNames()[0].DBName, TableName: p.OutputNames()[0].TblName} + } + // Duplicate column name in one table is not allowed. + // "select * from (select 1, 1) as a;" is duplicate + dupNames := make(map[string]struct{}, len(p.Schema().Columns)) + for _, name := range p.OutputNames() { + colName := name.ColName.O + if _, ok := dupNames[colName]; ok { + return nil, plannererrors.ErrDupFieldName.GenWithStackByArgs(colName) + } + dupNames[colName] = struct{}{} + } + return p, nil + case *ast.SelectStmt: + return b.buildSelect(ctx, x) + case *ast.SetOprStmt: + return b.buildSetOpr(ctx, x) + default: + return nil, plannererrors.ErrUnsupportedType.GenWithStack("Unsupported ast.ResultSetNode(%T) for buildResultSetNode()", x) + } +} + +// extractTableAlias returns table alias of the base.LogicalPlan's columns. +// It will return nil when there are multiple table alias, because the alias is only used to check if +// the base.LogicalPlan Match some optimizer hints, and hints are not expected to take effect in this case. +func extractTableAlias(p base.Plan, parentOffset int) *h.HintedTable { + if len(p.OutputNames()) > 0 && p.OutputNames()[0].TblName.L != "" { + firstName := p.OutputNames()[0] + for _, name := range p.OutputNames() { + if name.TblName.L != firstName.TblName.L || + (name.DBName.L != "" && firstName.DBName.L != "" && name.DBName.L != firstName.DBName.L) { // DBName can be nil, see #46160 + return nil + } + } + qbOffset := p.QueryBlockOffset() + var blockAsNames []ast.HintTable + if p := p.SCtx().GetSessionVars().PlannerSelectBlockAsName.Load(); p != nil { + blockAsNames = *p + } + // For sub-queries like `(select * from t) t1`, t1 should belong to its surrounding select block. 
+ if qbOffset != parentOffset && blockAsNames != nil && blockAsNames[qbOffset].TableName.L != "" { + qbOffset = parentOffset + } + dbName := firstName.DBName + if dbName.L == "" { + dbName = model.NewCIStr(p.SCtx().GetSessionVars().CurrentDB) + } + return &h.HintedTable{DBName: dbName, TblName: firstName.TblName, SelectOffset: qbOffset} + } + return nil +} + +func setPreferredJoinTypeFromOneSide(preferJoinType uint, isLeft bool) (resJoinType uint) { + if preferJoinType == 0 { + return + } + if preferJoinType&h.PreferINLJ > 0 { + preferJoinType &= ^h.PreferINLJ + if isLeft { + resJoinType |= h.PreferLeftAsINLJInner + } else { + resJoinType |= h.PreferRightAsINLJInner + } + } + if preferJoinType&h.PreferINLHJ > 0 { + preferJoinType &= ^h.PreferINLHJ + if isLeft { + resJoinType |= h.PreferLeftAsINLHJInner + } else { + resJoinType |= h.PreferRightAsINLHJInner + } + } + if preferJoinType&h.PreferINLMJ > 0 { + preferJoinType &= ^h.PreferINLMJ + if isLeft { + resJoinType |= h.PreferLeftAsINLMJInner + } else { + resJoinType |= h.PreferRightAsINLMJInner + } + } + if preferJoinType&h.PreferHJBuild > 0 { + preferJoinType &= ^h.PreferHJBuild + if isLeft { + resJoinType |= h.PreferLeftAsHJBuild + } else { + resJoinType |= h.PreferRightAsHJBuild + } + } + if preferJoinType&h.PreferHJProbe > 0 { + preferJoinType &= ^h.PreferHJProbe + if isLeft { + resJoinType |= h.PreferLeftAsHJProbe + } else { + resJoinType |= h.PreferRightAsHJProbe + } + } + resJoinType |= preferJoinType + return +} + +func (ds *DataSource) setPreferredStoreType(hintInfo *h.PlanHints) { + if hintInfo == nil { + return + } + + var alias *h.HintedTable + if len(ds.TableAsName.L) != 0 { + alias = &h.HintedTable{DBName: ds.DBName, TblName: *ds.TableAsName, SelectOffset: ds.QueryBlockOffset()} + } else { + alias = &h.HintedTable{DBName: ds.DBName, TblName: ds.TableInfo.Name, SelectOffset: ds.QueryBlockOffset()} + } + if hintTbl := hintInfo.IfPreferTiKV(alias); hintTbl != nil { + for _, path := range ds.PossibleAccessPaths { + if path.StoreType == kv.TiKV { + ds.PreferStoreType |= h.PreferTiKV + ds.PreferPartitions[h.PreferTiKV] = hintTbl.Partitions + break + } + } + if ds.PreferStoreType&h.PreferTiKV == 0 { + errMsg := fmt.Sprintf("No available path for table %s.%s with the store type %s of the hint /*+ read_from_storage */, "+ + "please check the status of the table replica and variable value of tidb_isolation_read_engines(%v)", + ds.DBName.O, ds.table.Meta().Name.O, kv.TiKV.Name(), ds.SCtx().GetSessionVars().GetIsolationReadEngines()) + ds.SCtx().GetSessionVars().StmtCtx.SetHintWarning(errMsg) + } else { + ds.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because you have set a hint to read table `" + hintTbl.TblName.O + "` from TiKV.") + } + } + if hintTbl := hintInfo.IfPreferTiFlash(alias); hintTbl != nil { + // `ds.PreferStoreType != 0`, which means there's a hint hit the both TiKV value and TiFlash value for table. + // We can't support read a table from two different storages, even partition table. 
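[editor's note, not part of the patch] setPreferredStoreType above honors a read_from_storage hint only if at least one possible access path actually lives on the hinted engine; otherwise it records a hint warning, and hinting both engines for the same table clears the preference entirely. A minimal standalone sketch of just the availability check and warning path (the real code also records hinted partitions and the MPP warning):

package main

import "fmt"

type storeType string

const (
	tikv    storeType = "tikv"
	tiflash storeType = "tiflash"
)

// resolveStorageHint reports whether the hinted engine is reachable through
// any candidate access path; if not, it returns the warning to surface.
func resolveStorageHint(hinted storeType, paths []storeType) (ok bool, warn string) {
	for _, p := range paths {
		if p == hinted {
			return true, ""
		}
	}
	return false, fmt.Sprintf("No available path for the store type %s of the hint /*+ read_from_storage */", hinted)
}

func main() {
	paths := []storeType{tikv} // e.g. the table has no TiFlash replica
	if ok, warn := resolveStorageHint(tiflash, paths); !ok {
		fmt.Println(warn) // the builder records this as a hint warning instead of failing
	}
}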
+ if ds.PreferStoreType != 0 { + ds.SCtx().GetSessionVars().StmtCtx.SetHintWarning( + fmt.Sprintf("Storage hints are conflict, you can only specify one storage type of table %s.%s", + alias.DBName.L, alias.TblName.L)) + ds.PreferStoreType = 0 + return + } + for _, path := range ds.PossibleAccessPaths { + if path.StoreType == kv.TiFlash { + ds.PreferStoreType |= h.PreferTiFlash + ds.PreferPartitions[h.PreferTiFlash] = hintTbl.Partitions + break + } + } + if ds.PreferStoreType&h.PreferTiFlash == 0 { + errMsg := fmt.Sprintf("No available path for table %s.%s with the store type %s of the hint /*+ read_from_storage */, "+ + "please check the status of the table replica and variable value of tidb_isolation_read_engines(%v)", + ds.DBName.O, ds.table.Meta().Name.O, kv.TiFlash.Name(), ds.SCtx().GetSessionVars().GetIsolationReadEngines()) + ds.SCtx().GetSessionVars().StmtCtx.SetHintWarning(errMsg) + } + } +} + +func (b *PlanBuilder) buildJoin(ctx context.Context, joinNode *ast.Join) (base.LogicalPlan, error) { + // We will construct a "Join" node for some statements like "INSERT", + // "DELETE", "UPDATE", "REPLACE". For this scenario "joinNode.Right" is nil + // and we only build the left "ResultSetNode". + if joinNode.Right == nil { + return b.buildResultSetNode(ctx, joinNode.Left, false) + } + + b.optFlag = b.optFlag | flagPredicatePushDown + // Add join reorder flag regardless of inner join or outer join. + b.optFlag = b.optFlag | flagJoinReOrder + b.optFlag |= flagPredicateSimplification + b.optFlag |= flagConvertOuterToInnerJoin + + leftPlan, err := b.buildResultSetNode(ctx, joinNode.Left, false) + if err != nil { + return nil, err + } + + rightPlan, err := b.buildResultSetNode(ctx, joinNode.Right, false) + if err != nil { + return nil, err + } + + // The recursive part in CTE must not be on the right side of a LEFT JOIN. + if lc, ok := rightPlan.(*logicalop.LogicalCTETable); ok && joinNode.Tp == ast.LeftJoin { + return nil, plannererrors.ErrCTERecursiveForbiddenJoinOrder.GenWithStackByArgs(lc.Name) + } + + handleMap1 := b.handleHelper.popMap() + handleMap2 := b.handleHelper.popMap() + b.handleHelper.mergeAndPush(handleMap1, handleMap2) + + joinPlan := LogicalJoin{StraightJoin: joinNode.StraightJoin || b.inStraightJoin}.Init(b.ctx, b.getSelectOffset()) + joinPlan.SetChildren(leftPlan, rightPlan) + joinPlan.SetSchema(expression.MergeSchema(leftPlan.Schema(), rightPlan.Schema())) + joinPlan.SetOutputNames(make([]*types.FieldName, leftPlan.Schema().Len()+rightPlan.Schema().Len())) + copy(joinPlan.OutputNames(), leftPlan.OutputNames()) + copy(joinPlan.OutputNames()[leftPlan.Schema().Len():], rightPlan.OutputNames()) + + // Set join type. + switch joinNode.Tp { + case ast.LeftJoin: + // left outer join need to be checked elimination + b.optFlag = b.optFlag | flagEliminateOuterJoin + joinPlan.JoinType = LeftOuterJoin + util.ResetNotNullFlag(joinPlan.Schema(), leftPlan.Schema().Len(), joinPlan.Schema().Len()) + case ast.RightJoin: + // right outer join need to be checked elimination + b.optFlag = b.optFlag | flagEliminateOuterJoin + joinPlan.JoinType = RightOuterJoin + util.ResetNotNullFlag(joinPlan.Schema(), 0, leftPlan.Schema().Len()) + default: + joinPlan.JoinType = InnerJoin + } + + // Merge sub-plan's FullSchema into this join plan. + // Please read the comment of LogicalJoin.FullSchema for the details. 
+ var ( + lFullSchema, rFullSchema *expression.Schema + lFullNames, rFullNames types.NameSlice + ) + if left, ok := leftPlan.(*LogicalJoin); ok && left.FullSchema != nil { + lFullSchema = left.FullSchema + lFullNames = left.FullNames + } else { + lFullSchema = leftPlan.Schema() + lFullNames = leftPlan.OutputNames() + } + if right, ok := rightPlan.(*LogicalJoin); ok && right.FullSchema != nil { + rFullSchema = right.FullSchema + rFullNames = right.FullNames + } else { + rFullSchema = rightPlan.Schema() + rFullNames = rightPlan.OutputNames() + } + if joinNode.Tp == ast.RightJoin { + // Make sure lFullSchema means outer full schema and rFullSchema means inner full schema. + lFullSchema, rFullSchema = rFullSchema, lFullSchema + lFullNames, rFullNames = rFullNames, lFullNames + } + joinPlan.FullSchema = expression.MergeSchema(lFullSchema, rFullSchema) + + // Clear NotNull flag for the inner side schema if it's an outer join. + if joinNode.Tp == ast.LeftJoin || joinNode.Tp == ast.RightJoin { + util.ResetNotNullFlag(joinPlan.FullSchema, lFullSchema.Len(), joinPlan.FullSchema.Len()) + } + + // Merge sub-plan's FullNames into this join plan, similar to the FullSchema logic above. + joinPlan.FullNames = make([]*types.FieldName, 0, len(lFullNames)+len(rFullNames)) + for _, lName := range lFullNames { + name := *lName + joinPlan.FullNames = append(joinPlan.FullNames, &name) + } + for _, rName := range rFullNames { + name := *rName + joinPlan.FullNames = append(joinPlan.FullNames, &name) + } + + // Set preferred join algorithm if some join hints is specified by user. + joinPlan.SetPreferredJoinTypeAndOrder(b.TableHints()) + + // "NATURAL JOIN" doesn't have "ON" or "USING" conditions. + // + // The "NATURAL [LEFT] JOIN" of two tables is defined to be semantically + // equivalent to an "INNER JOIN" or a "LEFT JOIN" with a "USING" clause + // that names all columns that exist in both tables. + // + // See https://dev.mysql.com/doc/refman/5.7/en/join.html for more detail. + if joinNode.NaturalJoin { + err = b.buildNaturalJoin(joinPlan, leftPlan, rightPlan, joinNode) + if err != nil { + return nil, err + } + } else if joinNode.Using != nil { + err = b.buildUsingClause(joinPlan, leftPlan, rightPlan, joinNode) + if err != nil { + return nil, err + } + } else if joinNode.On != nil { + b.curClause = onClause + onExpr, newPlan, err := b.rewrite(ctx, joinNode.On.Expr, joinPlan, nil, false) + if err != nil { + return nil, err + } + if newPlan != joinPlan { + return nil, errors.New("ON condition doesn't support subqueries yet") + } + onCondition := expression.SplitCNFItems(onExpr) + // Keep these expressions as a LogicalSelection upon the inner join, in order to apply + // possible decorrelate optimizations. The ON clause is actually treated as a WHERE clause now. + if joinPlan.JoinType == InnerJoin { + sel := LogicalSelection{Conditions: onCondition}.Init(b.ctx, b.getSelectOffset()) + sel.SetChildren(joinPlan) + return sel, nil + } + joinPlan.AttachOnConds(onCondition) + } else if joinPlan.JoinType == InnerJoin { + // If a inner join without "ON" or "USING" clause, it's a cartesian + // product over the join tables. + joinPlan.CartesianJoin = true + } + + return joinPlan, nil +} + +// buildUsingClause eliminate the redundant columns and ordering columns based +// on the "USING" clause. +// +// According to the standard SQL, columns are ordered in the following way: +// 1. coalesced common columns of "leftPlan" and "rightPlan", in the order they +// appears in "leftPlan". +// 2. 
the rest columns in "leftPlan", in the order they appears in "leftPlan". +// 3. the rest columns in "rightPlan", in the order they appears in "rightPlan". +func (b *PlanBuilder) buildUsingClause(p *LogicalJoin, leftPlan, rightPlan base.LogicalPlan, join *ast.Join) error { + filter := make(map[string]bool, len(join.Using)) + for _, col := range join.Using { + filter[col.Name.L] = true + } + err := b.coalesceCommonColumns(p, leftPlan, rightPlan, join.Tp, filter) + if err != nil { + return err + } + // We do not need to coalesce columns for update and delete. + if b.inUpdateStmt || b.inDeleteStmt { + p.SetSchemaAndNames(expression.MergeSchema(p.Children()[0].Schema(), p.Children()[1].Schema()), + append(p.Children()[0].OutputNames(), p.Children()[1].OutputNames()...)) + } + return nil +} + +// buildNaturalJoin builds natural join output schema. It finds out all the common columns +// then using the same mechanism as buildUsingClause to eliminate redundant columns and build join conditions. +// According to standard SQL, producing this display order: +// +// All the common columns +// Every column in the first (left) table that is not a common column +// Every column in the second (right) table that is not a common column +func (b *PlanBuilder) buildNaturalJoin(p *LogicalJoin, leftPlan, rightPlan base.LogicalPlan, join *ast.Join) error { + err := b.coalesceCommonColumns(p, leftPlan, rightPlan, join.Tp, nil) + if err != nil { + return err + } + // We do not need to coalesce columns for update and delete. + if b.inUpdateStmt || b.inDeleteStmt { + p.SetSchemaAndNames(expression.MergeSchema(p.Children()[0].Schema(), p.Children()[1].Schema()), + append(p.Children()[0].OutputNames(), p.Children()[1].OutputNames()...)) + } + return nil +} + +// coalesceCommonColumns is used by buildUsingClause and buildNaturalJoin. The filter is used by buildUsingClause. +func (b *PlanBuilder) coalesceCommonColumns(p *LogicalJoin, leftPlan, rightPlan base.LogicalPlan, joinTp ast.JoinType, filter map[string]bool) error { + lsc := leftPlan.Schema().Clone() + rsc := rightPlan.Schema().Clone() + if joinTp == ast.LeftJoin { + util.ResetNotNullFlag(rsc, 0, rsc.Len()) + } else if joinTp == ast.RightJoin { + util.ResetNotNullFlag(lsc, 0, lsc.Len()) + } + lColumns, rColumns := lsc.Columns, rsc.Columns + lNames, rNames := leftPlan.OutputNames().Shallow(), rightPlan.OutputNames().Shallow() + if joinTp == ast.RightJoin { + leftPlan, rightPlan = rightPlan, leftPlan + lNames, rNames = rNames, lNames + lColumns, rColumns = rsc.Columns, lsc.Columns + } + + // Check using clause with ambiguous columns. + if filter != nil { + checkAmbiguous := func(names types.NameSlice) error { + columnNameInFilter := set.StringSet{} + for _, name := range names { + if _, ok := filter[name.ColName.L]; !ok { + continue + } + if columnNameInFilter.Exist(name.ColName.L) { + return plannererrors.ErrAmbiguous.GenWithStackByArgs(name.ColName.L, "from clause") + } + columnNameInFilter.Insert(name.ColName.L) + } + return nil + } + err := checkAmbiguous(lNames) + if err != nil { + return err + } + err = checkAmbiguous(rNames) + if err != nil { + return err + } + } else { + // Even with no using filter, we still should check the checkAmbiguous name before we try to find the common column from both side. + // (t3 cross join t4) natural join t1 + // t1 natural join (t3 cross join t4) + // t3 and t4 may generate the same name column from cross join. + // for every common column of natural join, the name from right or left should be exactly one. 
+ commonNames := make([]string, 0, len(lNames)) + lNameMap := make(map[string]int, len(lNames)) + rNameMap := make(map[string]int, len(rNames)) + for _, name := range lNames { + // Natural join should ignore _tidb_rowid + if name.ColName.L == "_tidb_rowid" { + continue + } + // record left map + if cnt, ok := lNameMap[name.ColName.L]; ok { + lNameMap[name.ColName.L] = cnt + 1 + } else { + lNameMap[name.ColName.L] = 1 + } + } + for _, name := range rNames { + // Natural join should ignore _tidb_rowid + if name.ColName.L == "_tidb_rowid" { + continue + } + // record right map + if cnt, ok := rNameMap[name.ColName.L]; ok { + rNameMap[name.ColName.L] = cnt + 1 + } else { + rNameMap[name.ColName.L] = 1 + } + // check left map + if cnt, ok := lNameMap[name.ColName.L]; ok { + if cnt > 1 { + return plannererrors.ErrAmbiguous.GenWithStackByArgs(name.ColName.L, "from clause") + } + commonNames = append(commonNames, name.ColName.L) + } + } + // check right map + for _, commonName := range commonNames { + if rNameMap[commonName] > 1 { + return plannererrors.ErrAmbiguous.GenWithStackByArgs(commonName, "from clause") + } + } + } + + // Find out all the common columns and put them ahead. + commonLen := 0 + for i, lName := range lNames { + // Natural join should ignore _tidb_rowid + if lName.ColName.L == "_tidb_rowid" { + continue + } + for j := commonLen; j < len(rNames); j++ { + if lName.ColName.L != rNames[j].ColName.L { + continue + } + + if len(filter) > 0 { + if !filter[lName.ColName.L] { + break + } + // Mark this column exist. + filter[lName.ColName.L] = false + } + + col := lColumns[i] + copy(lColumns[commonLen+1:i+1], lColumns[commonLen:i]) + lColumns[commonLen] = col + + name := lNames[i] + copy(lNames[commonLen+1:i+1], lNames[commonLen:i]) + lNames[commonLen] = name + + col = rColumns[j] + copy(rColumns[commonLen+1:j+1], rColumns[commonLen:j]) + rColumns[commonLen] = col + + name = rNames[j] + copy(rNames[commonLen+1:j+1], rNames[commonLen:j]) + rNames[commonLen] = name + + commonLen++ + break + } + } + + if len(filter) > 0 && len(filter) != commonLen { + for col, notExist := range filter { + if notExist { + return plannererrors.ErrUnknownColumn.GenWithStackByArgs(col, "from clause") + } + } + } + + schemaCols := make([]*expression.Column, len(lColumns)+len(rColumns)-commonLen) + copy(schemaCols[:len(lColumns)], lColumns) + copy(schemaCols[len(lColumns):], rColumns[commonLen:]) + names := make(types.NameSlice, len(schemaCols)) + copy(names, lNames) + copy(names[len(lNames):], rNames[commonLen:]) + + conds := make([]expression.Expression, 0, commonLen) + for i := 0; i < commonLen; i++ { + lc, rc := lsc.Columns[i], rsc.Columns[i] + cond, err := expression.NewFunction(b.ctx.GetExprCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), lc, rc) + if err != nil { + return err + } + conds = append(conds, cond) + if p.FullSchema != nil { + // since FullSchema is derived from left and right schema in upper layer, so rc/lc must be in FullSchema. + if joinTp == ast.RightJoin { + p.FullNames[p.FullSchema.ColumnIndex(lc)].Redundant = true + } else { + p.FullNames[p.FullSchema.ColumnIndex(rc)].Redundant = true + } + } + } + + p.SetSchema(expression.NewSchema(schemaCols...)) + p.SetOutputNames(names) + + p.OtherConditions = append(conds, p.OtherConditions...) 
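[editor's note, not part of the patch] coalesceCommonColumns above implements the standard NATURAL JOIN/USING column order: coalesced common columns first, in left-hand order, then the remaining left columns, then the remaining right columns, with one generated equality condition per common column. A standalone sketch of that ordering and condition generation (it ignores the USING filter, the _tidb_rowid exception, ambiguity checks, and outer-join null-flag handling):

package main

import "fmt"

// coalesce computes the output column order of a NATURAL JOIN and the implied
// join conditions, in the order described above.
func coalesce(left, right []string) (out []string, conds []string) {
	rightSet := make(map[string]bool, len(right))
	for _, c := range right {
		rightSet[c] = true
	}
	common := make(map[string]bool)
	for _, c := range left {
		if rightSet[c] {
			common[c] = true
			out = append(out, c) // common columns first, in left order
			conds = append(conds, fmt.Sprintf("l.%s = r.%s", c, c))
		}
	}
	for _, c := range left {
		if !common[c] {
			out = append(out, c) // then the rest of the left columns
		}
	}
	for _, c := range right {
		if !common[c] {
			out = append(out, c) // then the rest of the right columns
		}
	}
	return out, conds
}

func main() {
	out, conds := coalesce([]string{"id", "a", "b"}, []string{"id", "b", "c"})
	fmt.Println(out)   // [id b a c]  -- common columns id, b first, in left order
	fmt.Println(conds) // [l.id = r.id l.b = r.b]
}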
+ + return nil +} + +func (b *PlanBuilder) buildSelection(ctx context.Context, p base.LogicalPlan, where ast.ExprNode, aggMapper map[*ast.AggregateFuncExpr]int) (base.LogicalPlan, error) { + b.optFlag |= flagPredicatePushDown + b.optFlag |= flagDeriveTopNFromWindow + b.optFlag |= flagPredicateSimplification + if b.curClause != havingClause { + b.curClause = whereClause + } + + conditions := splitWhere(where) + expressions := make([]expression.Expression, 0, len(conditions)) + selection := LogicalSelection{}.Init(b.ctx, b.getSelectOffset()) + for _, cond := range conditions { + expr, np, err := b.rewrite(ctx, cond, p, aggMapper, false) + if err != nil { + return nil, err + } + // for case: explain SELECT year+2 as y, SUM(profit) AS profit FROM sales GROUP BY year+2, year+profit WITH ROLLUP having y > 2002; + // currently, we succeed to resolve y to (year+2), but fail to resolve (year+2) to grouping col, and to base column function: plus(year, 2) instead. + // which will cause this selection being pushed down through Expand OP itself. + // + // In expand, we will additionally project (year+2) out as a new column, let's say grouping_col here, and we wanna it can substitute any upper layer's (year+2) + expr = b.replaceGroupingFunc(expr) + + p = np + if expr == nil { + continue + } + expressions = append(expressions, expr) + } + cnfExpres := make([]expression.Expression, 0) + useCache := b.ctx.GetSessionVars().StmtCtx.UseCache() + for _, expr := range expressions { + cnfItems := expression.SplitCNFItems(expr) + for _, item := range cnfItems { + if con, ok := item.(*expression.Constant); ok && expression.ConstExprConsiderPlanCache(con, useCache) { + ret, _, err := expression.EvalBool(b.ctx.GetExprCtx().GetEvalCtx(), expression.CNFExprs{con}, chunk.Row{}) + if err != nil { + return nil, errors.Trace(err) + } + if ret { + continue + } + // If there is condition which is always false, return dual plan directly. + dual := logicalop.LogicalTableDual{}.Init(b.ctx, b.getSelectOffset()) + dual.SetOutputNames(p.OutputNames()) + dual.SetSchema(p.Schema()) + return dual, nil + } + cnfExpres = append(cnfExpres, item) + } + } + if len(cnfExpres) == 0 { + return p, nil + } + evalCtx := b.ctx.GetExprCtx().GetEvalCtx() + // check expr field types. + for i, expr := range cnfExpres { + if expr.GetType(evalCtx).EvalType() == types.ETString { + tp := &types.FieldType{} + tp.SetType(mysql.TypeDouble) + tp.SetFlag(expr.GetType(evalCtx).GetFlag()) + tp.SetFlen(mysql.MaxRealWidth) + tp.SetDecimal(types.UnspecifiedLength) + types.SetBinChsClnFlag(tp) + cnfExpres[i] = expression.TryPushCastIntoControlFunctionForHybridType(b.ctx.GetExprCtx(), expr, tp) + } + } + selection.Conditions = cnfExpres + selection.SetChildren(p) + return selection, nil +} + +// buildProjectionFieldNameFromColumns builds the field name, table name and database name when field expression is a column reference. +func (*PlanBuilder) buildProjectionFieldNameFromColumns(origField *ast.SelectField, colNameField *ast.ColumnNameExpr, name *types.FieldName) (colName, origColName, tblName, origTblName, dbName model.CIStr) { + origTblName, origColName, dbName = name.OrigTblName, name.OrigColName, name.DBName + if origField.AsName.L == "" { + colName = colNameField.Name.Name + } else { + colName = origField.AsName + } + if tblName.L == "" { + tblName = name.TblName + } else { + tblName = colNameField.Name.Table + } + return +} + +// buildProjectionFieldNameFromExpressions builds the field name when field expression is a normal expression. 
+func (b *PlanBuilder) buildProjectionFieldNameFromExpressions(_ context.Context, field *ast.SelectField) (model.CIStr, error) { + if agg, ok := field.Expr.(*ast.AggregateFuncExpr); ok && agg.F == ast.AggFuncFirstRow { + // When the query is select t.a from t group by a; The Column Name should be a but not t.a; + return agg.Args[0].(*ast.ColumnNameExpr).Name.Name, nil + } + + innerExpr := getInnerFromParenthesesAndUnaryPlus(field.Expr) + funcCall, isFuncCall := innerExpr.(*ast.FuncCallExpr) + // When used to produce a result set column, NAME_CONST() causes the column to have the given name. + // See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_name-const for details + if isFuncCall && funcCall.FnName.L == ast.NameConst { + if v, err := evalAstExpr(b.ctx.GetExprCtx(), funcCall.Args[0]); err == nil { + if s, err := v.ToString(); err == nil { + return model.NewCIStr(s), nil + } + } + return model.NewCIStr(""), plannererrors.ErrWrongArguments.GenWithStackByArgs("NAME_CONST") + } + valueExpr, isValueExpr := innerExpr.(*driver.ValueExpr) + + // Non-literal: Output as inputed, except that comments need to be removed. + if !isValueExpr { + return model.NewCIStr(parser.SpecFieldPattern.ReplaceAllStringFunc(field.Text(), parser.TrimComment)), nil + } + + // Literal: Need special processing + switch valueExpr.Kind() { + case types.KindString: + projName := valueExpr.GetString() + projOffset := valueExpr.GetProjectionOffset() + if projOffset >= 0 { + projName = projName[:projOffset] + } + // See #3686, #3994: + // For string literals, string content is used as column name. Non-graph initial characters are trimmed. + fieldName := strings.TrimLeftFunc(projName, func(r rune) bool { + return !unicode.IsOneOf(mysql.RangeGraph, r) + }) + return model.NewCIStr(fieldName), nil + case types.KindNull: + // See #4053, #3685 + return model.NewCIStr("NULL"), nil + case types.KindBinaryLiteral: + // Don't rewrite BIT literal or HEX literals + return model.NewCIStr(field.Text()), nil + case types.KindInt64: + // See #9683 + // TRUE or FALSE can be a int64 + if mysql.HasIsBooleanFlag(valueExpr.Type.GetFlag()) { + if i := valueExpr.GetValue().(int64); i == 0 { + return model.NewCIStr("FALSE"), nil + } + return model.NewCIStr("TRUE"), nil + } + fallthrough + + default: + fieldName := field.Text() + fieldName = strings.TrimLeft(fieldName, "\t\n +(") + fieldName = strings.TrimRight(fieldName, "\t\n )") + return model.NewCIStr(fieldName), nil + } +} + +func buildExpandFieldName(ctx expression.EvalContext, expr expression.Expression, name *types.FieldName, genName string) *types.FieldName { + _, isCol := expr.(*expression.Column) + var origTblName, origColName, dbName, colName, tblName model.CIStr + if genName != "" { + // for case like: gid_, gpos_ + colName = model.NewCIStr(expr.StringWithCtx(ctx, errors.RedactLogDisable)) + } else if isCol { + // col ref to original col, while its nullability may be changed. + origTblName, origColName, dbName = name.OrigTblName, name.OrigColName, name.DBName + colName = model.NewCIStr("ex_" + name.ColName.O) + tblName = model.NewCIStr("ex_" + name.TblName.O) + } else { + // Other: complicated expression. + colName = model.NewCIStr("ex_" + expr.StringWithCtx(ctx, errors.RedactLogDisable)) + } + newName := &types.FieldName{ + TblName: tblName, + OrigTblName: origTblName, + ColName: colName, + OrigColName: origColName, + DBName: dbName, + } + return newName +} + +// buildProjectionField builds the field object according to SelectField in projection. 
+func (b *PlanBuilder) buildProjectionField(ctx context.Context, p base.LogicalPlan, field *ast.SelectField, expr expression.Expression) (*expression.Column, *types.FieldName, error) { + var origTblName, tblName, origColName, colName, dbName model.CIStr + innerNode := getInnerFromParenthesesAndUnaryPlus(field.Expr) + col, isCol := expr.(*expression.Column) + // Correlated column won't affect the final output names. So we can put it in any of the three logic block. + // Don't put it into the first block just for simplifying the codes. + if colNameField, ok := innerNode.(*ast.ColumnNameExpr); ok && isCol { + // Field is a column reference. + idx := p.Schema().ColumnIndex(col) + var name *types.FieldName + // The column maybe the one from join's redundant part. + if idx == -1 { + name = findColFromNaturalUsingJoin(p, col) + } else { + name = p.OutputNames()[idx] + } + colName, origColName, tblName, origTblName, dbName = b.buildProjectionFieldNameFromColumns(field, colNameField, name) + } else if field.AsName.L != "" { + // Field has alias. + colName = field.AsName + } else { + // Other: field is an expression. + var err error + if colName, err = b.buildProjectionFieldNameFromExpressions(ctx, field); err != nil { + return nil, nil, err + } + } + name := &types.FieldName{ + TblName: tblName, + OrigTblName: origTblName, + ColName: colName, + OrigColName: origColName, + DBName: dbName, + } + if isCol { + return col, name, nil + } + if expr == nil { + return nil, name, nil + } + // invalid unique id + correlatedColUniqueID := int64(0) + if cc, ok := expr.(*expression.CorrelatedColumn); ok { + correlatedColUniqueID = cc.UniqueID + } + // for expr projection, we should record the map relationship down. + newCol := &expression.Column{ + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: expr.GetType(b.ctx.GetExprCtx().GetEvalCtx()), + CorrelatedColUniqueID: correlatedColUniqueID, + } + if b.ctx.GetSessionVars().OptimizerEnableNewOnlyFullGroupByCheck { + if b.ctx.GetSessionVars().MapHashCode2UniqueID4ExtendedCol == nil { + b.ctx.GetSessionVars().MapHashCode2UniqueID4ExtendedCol = make(map[string]int, 1) + } + b.ctx.GetSessionVars().MapHashCode2UniqueID4ExtendedCol[string(expr.HashCode())] = int(newCol.UniqueID) + } + newCol.SetCoercibility(expr.Coercibility()) + return newCol, name, nil +} + +type userVarTypeProcessor struct { + ctx context.Context + plan base.LogicalPlan + builder *PlanBuilder + mapper map[*ast.AggregateFuncExpr]int + err error +} + +func (p *userVarTypeProcessor) Enter(in ast.Node) (ast.Node, bool) { + v, ok := in.(*ast.VariableExpr) + if !ok { + return in, false + } + if v.IsSystem || v.Value == nil { + return in, true + } + _, p.plan, p.err = p.builder.rewrite(p.ctx, v, p.plan, p.mapper, true) + return in, true +} + +func (p *userVarTypeProcessor) Leave(in ast.Node) (ast.Node, bool) { + return in, p.err == nil +} + +func (b *PlanBuilder) preprocessUserVarTypes(ctx context.Context, p base.LogicalPlan, fields []*ast.SelectField, mapper map[*ast.AggregateFuncExpr]int) error { + aggMapper := make(map[*ast.AggregateFuncExpr]int) + for agg, i := range mapper { + aggMapper[agg] = i + } + processor := userVarTypeProcessor{ + ctx: ctx, + plan: p, + builder: b, + mapper: aggMapper, + } + for _, field := range fields { + field.Expr.Accept(&processor) + if processor.err != nil { + return processor.err + } + } + return nil +} + +// findColFromNaturalUsingJoin is used to recursively find the column from the +// underlying natural-using-join. +// e.g. 
For SQL like `select t2.a from t1 join t2 using(a) where t2.a > 0`, the +// plan will be `join->selection->projection`. The schema of the `selection` +// will be `[t1.a]`, thus we need to recursively retrieve the `t2.a` from the +// underlying join. +func findColFromNaturalUsingJoin(p base.LogicalPlan, col *expression.Column) (name *types.FieldName) { + switch x := p.(type) { + case *logicalop.LogicalLimit, *LogicalSelection, *logicalop.LogicalTopN, *logicalop.LogicalSort, *logicalop.LogicalMaxOneRow: + return findColFromNaturalUsingJoin(p.Children()[0], col) + case *LogicalJoin: + if x.FullSchema != nil { + idx := x.FullSchema.ColumnIndex(col) + return x.FullNames[idx] + } + } + return nil +} + +type resolveGroupingTraverseAction struct { + CurrentBlockExpand *LogicalExpand +} + +func (r resolveGroupingTraverseAction) Transform(expr expression.Expression) (res expression.Expression) { + switch x := expr.(type) { + case *expression.Column: + // when meeting a column, judge whether it's a relate grouping set col. + // eg: select a, b from t group by a, c with rollup, here a is, while b is not. + // in underlying Expand schema (a,b,c,a',c'), a select list should be resolved to a'. + res, _ = r.CurrentBlockExpand.trySubstituteExprWithGroupingSetCol(x) + case *expression.CorrelatedColumn: + // select 1 in (select t2.a from t group by t2.a, b with rollup) from t2; + // in this case: group by item has correlated column t2.a, and it's select list contains t2.a as well. + res, _ = r.CurrentBlockExpand.trySubstituteExprWithGroupingSetCol(x) + case *expression.Constant: + // constant just keep it real: select 1 from t group by a, b with rollup. + res = x + case *expression.ScalarFunction: + // scalar function just try to resolve itself first, then if not changed, trying resolve its children. + var substituted bool + res, substituted = r.CurrentBlockExpand.trySubstituteExprWithGroupingSetCol(x) + if !substituted { + // if not changed, try to resolve it children. + // select a+1, grouping(b) from t group by a+1 (projected as c), b with rollup: in this case, a+1 is resolved as c as a whole. + // select a+1, grouping(b) from t group by a(projected as a'), b with rollup : in this case, a+1 is resolved as a'+ 1. + newArgs := x.GetArgs() + for i, arg := range newArgs { + newArgs[i] = r.Transform(arg) + } + res = x + } + default: + res = expr + } + return res +} + +func (b *PlanBuilder) replaceGroupingFunc(expr expression.Expression) expression.Expression { + // current block doesn't have an expand OP, just return it. + if b.currentBlockExpand == nil { + return expr + } + // curExpand can supply the DistinctGbyExprs and gid col. + traverseAction := resolveGroupingTraverseAction{CurrentBlockExpand: b.currentBlockExpand} + return expr.Traverse(traverseAction) +} + +func (b *PlanBuilder) implicitProjectGroupingSetCols(projSchema *expression.Schema, projNames []*types.FieldName, projExprs []expression.Expression) (*expression.Schema, []*types.FieldName, []expression.Expression) { + if b.currentBlockExpand == nil { + return projSchema, projNames, projExprs + } + m := make(map[int64]struct{}, len(b.currentBlockExpand.DistinctGroupByCol)) + for _, col := range projSchema.Columns { + m[col.UniqueID] = struct{}{} + } + for idx, gCol := range b.currentBlockExpand.DistinctGroupByCol { + if _, ok := m[gCol.UniqueID]; ok { + // grouping col has been explicitly projected, not need to reserve it here for later order-by item (a+1) + // like: select a+1, b from t group by a+1 order by a+1. 
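Editor's note: resolveGroupingTraverseAction above substitutes a whole expression with its grouping-set column when possible and only recurses into a scalar function's children when the function itself was not substituted. A minimal sketch of that traversal order follows; the toy expr type and the string-keyed mapping are assumptions of the sketch, not the expression.Expression interface.

package main

import "fmt"

// expr is a toy expression node: either a leaf name or a function over children.
type expr struct {
	name     string
	children []*expr
}

func (e *expr) String() string {
	if len(e.children) == 0 {
		return e.name
	}
	s := e.name + "("
	for i, c := range e.children {
		if i > 0 {
			s += ", "
		}
		s += c.String()
	}
	return s + ")"
}

// transform tries to substitute the whole node first; only when that fails
// does it recurse into the children, mirroring the scalar-function branch above.
func transform(e *expr, subst func(*expr) (*expr, bool)) *expr {
	if repl, ok := subst(e); ok {
		return repl
	}
	for i, c := range e.children {
		e.children[i] = transform(c, subst)
	}
	return e
}

func main() {
	// select a+1 ... group by a+1: the whole plus(a, 1) maps to the projected
	// grouping column gby_0, so the result is gby_0, not plus(gby_0, 1).
	tree := &expr{name: "plus", children: []*expr{{name: "a"}, {name: "1"}}}
	mapping := map[string]*expr{"plus(a, 1)": {name: "gby_0"}, "a": {name: "gby_0"}}
	out := transform(tree, func(e *expr) (*expr, bool) {
		r, ok := mapping[e.String()]
		return r, ok
	})
	fmt.Println(out) // gby_0
}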
+ continue + } + // project the grouping col out implicitly here. If it's not used by later OP, it will be cleaned in column pruner. + projSchema.Append(gCol) + projExprs = append(projExprs, gCol) + projNames = append(projNames, b.currentBlockExpand.DistinctGbyColNames[idx]) + } + // project GID. + projSchema.Append(b.currentBlockExpand.GID) + projExprs = append(projExprs, b.currentBlockExpand.GID) + projNames = append(projNames, b.currentBlockExpand.GIDName) + // project GPos if any. + if b.currentBlockExpand.GPos != nil { + projSchema.Append(b.currentBlockExpand.GPos) + projExprs = append(projExprs, b.currentBlockExpand.GPos) + projNames = append(projNames, b.currentBlockExpand.GPosName) + } + return projSchema, projNames, projExprs +} + +// buildProjection returns a Projection plan and non-aux columns length. +func (b *PlanBuilder) buildProjection(ctx context.Context, p base.LogicalPlan, fields []*ast.SelectField, mapper map[*ast.AggregateFuncExpr]int, + windowMapper map[*ast.WindowFuncExpr]int, considerWindow bool, expandGenerateColumn bool) (base.LogicalPlan, []expression.Expression, int, error) { + err := b.preprocessUserVarTypes(ctx, p, fields, mapper) + if err != nil { + return nil, nil, 0, err + } + b.optFlag |= flagEliminateProjection + b.curClause = fieldList + proj := logicalop.LogicalProjection{Exprs: make([]expression.Expression, 0, len(fields))}.Init(b.ctx, b.getSelectOffset()) + schema := expression.NewSchema(make([]*expression.Column, 0, len(fields))...) + oldLen := 0 + newNames := make([]*types.FieldName, 0, len(fields)) + for i, field := range fields { + if !field.Auxiliary { + oldLen++ + } + + isWindowFuncField := ast.HasWindowFlag(field.Expr) + // Although window functions occurs in the select fields, but it has to be processed after having clause. + // So when we build the projection for select fields, we need to skip the window function. + // When `considerWindow` is false, we will only build fields for non-window functions, so we add fake placeholders. + // for window functions. These fake placeholders will be erased in column pruning. + // When `considerWindow` is true, all the non-window fields have been built, so we just use the schema columns. + if considerWindow && !isWindowFuncField { + col := p.Schema().Columns[i] + proj.Exprs = append(proj.Exprs, col) + schema.Append(col) + newNames = append(newNames, p.OutputNames()[i]) + continue + } else if !considerWindow && isWindowFuncField { + expr := expression.NewZero() + proj.Exprs = append(proj.Exprs, expr) + col, name, err := b.buildProjectionField(ctx, p, field, expr) + if err != nil { + return nil, nil, 0, err + } + schema.Append(col) + newNames = append(newNames, name) + continue + } + newExpr, np, err := b.rewriteWithPreprocess(ctx, field.Expr, p, mapper, windowMapper, true, nil) + if err != nil { + return nil, nil, 0, err + } + + // for case: select a+1, b, sum(b), grouping(a) from t group by a, b with rollup. + // the column inside aggregate (only sum(b) here) should be resolved to original source column, + // while for others, just use expanded columns if exists: a'+ 1, b', group(gid) + newExpr = b.replaceGroupingFunc(newExpr) + + // For window functions in the order by clause, we will append an field for it. + // We need rewrite the window mapper here so order by clause could find the added field. 
+ if considerWindow && isWindowFuncField && field.Auxiliary { + if windowExpr, ok := field.Expr.(*ast.WindowFuncExpr); ok { + windowMapper[windowExpr] = i + } + } + + p = np + proj.Exprs = append(proj.Exprs, newExpr) + + col, name, err := b.buildProjectionField(ctx, p, field, newExpr) + if err != nil { + return nil, nil, 0, err + } + schema.Append(col) + newNames = append(newNames, name) + } + // implicitly project expand grouping set cols, if not used later, it will being pruned out in logical column pruner. + schema, newNames, proj.Exprs = b.implicitProjectGroupingSetCols(schema, newNames, proj.Exprs) + + proj.SetSchema(schema) + proj.SetOutputNames(newNames) + if expandGenerateColumn { + // Sometimes we need to add some fields to the projection so that we can use generate column substitute + // optimization. For example: select a+1 from t order by a+1, with a virtual generate column c as (a+1) and + // an index on c. We need to add c into the projection so that we can replace a+1 with c. + exprToColumn := make(ExprColumnMap) + collectGenerateColumn(p, exprToColumn) + for expr, col := range exprToColumn { + idx := p.Schema().ColumnIndex(col) + if idx == -1 { + continue + } + if proj.Schema().Contains(col) { + continue + } + proj.Schema().Columns = append(proj.Schema().Columns, col) + proj.Exprs = append(proj.Exprs, expr) + proj.SetOutputNames(append(proj.OutputNames(), p.OutputNames()[idx])) + } + } + proj.SetChildren(p) + // delay the only-full-group-by-check in create view statement to later query. + if !b.isCreateView && b.ctx.GetSessionVars().OptimizerEnableNewOnlyFullGroupByCheck && b.ctx.GetSessionVars().SQLMode.HasOnlyFullGroupBy() { + fds := proj.ExtractFD() + // Projection -> Children -> ... + // Let the projection itself to evaluate the whole FD, which will build the connection + // 1: from select-expr to registered-expr + // 2: from base-column to select-expr + // After that + if fds.HasAggBuilt { + for offset, expr := range proj.Exprs[:len(fields)] { + // skip the auxiliary column in agg appended to select fields, which mainly comes from two kind of cases: + // 1: having agg(t.a), this will append t.a to the select fields, if it isn't here. + // 2: order by agg(t.a), this will append t.a to the select fields, if it isn't here. + if fields[offset].AuxiliaryColInAgg { + continue + } + item := intset.NewFastIntSet() + switch x := expr.(type) { + case *expression.Column: + item.Insert(int(x.UniqueID)) + case *expression.ScalarFunction: + if expression.CheckFuncInExpr(x, ast.AnyValue) { + continue + } + scalarUniqueID, ok := fds.IsHashCodeRegistered(string(hack.String(x.HashCode()))) + if !ok { + logutil.BgLogger().Warn("Error occurred while maintaining the functional dependency") + continue + } + item.Insert(scalarUniqueID) + default: + } + // Rule #1, if there are no group cols, the col in the order by shouldn't be limited. + if fds.GroupByCols.Only1Zero() && fields[offset].AuxiliaryColInOrderBy { + continue + } + + // Rule #2, if select fields are constant, it's ok. + if item.SubsetOf(fds.ConstantCols()) { + continue + } + + // Rule #3, if select fields are subset of group by items, it's ok. + if item.SubsetOf(fds.GroupByCols) { + continue + } + + // Rule #4, if select fields are dependencies of Strict FD with determinants in group-by items, it's ok. + // lax FD couldn't be done here, eg: for unique key (b), index key NULL & NULL are different rows with + // uncertain other column values. 
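Editor's note: rules #2 to #4 above all reduce to set-containment checks: a select item passes if its columns sit inside the constant columns, the group-by columns, or the strict closure of the group-by columns under the known functional dependencies. Before the patch's own closure check below, here is a simplified, self-contained closure computation; the fd type keyed by integer column IDs is a hypothetical stand-in for the planner's funcdep package.

package main

import "fmt"

// fd is one strict functional dependency: from -> to.
type fd struct {
	from, to map[int]bool
}

// closure returns every column determined by start under the given strict FDs,
// the same fixed point the rule #4 check relies on.
func closure(start map[int]bool, fds []fd) map[int]bool {
	out := make(map[int]bool, len(start))
	for c := range start {
		out[c] = true
	}
	for changed := true; changed; {
		changed = false
		for _, d := range fds {
			if !subset(d.from, out) {
				continue
			}
			for c := range d.to {
				if !out[c] {
					out[c] = true
					changed = true
				}
			}
		}
	}
	return out
}

func subset(a, b map[int]bool) bool {
	for c := range a {
		if !b[c] {
			return false
		}
	}
	return true
}

func main() {
	// group by column 1; column 1 strictly determines column 2 (for example a
	// unique key), so selecting column 2 is allowed while column 3 is not.
	groupBy := map[int]bool{1: true}
	fds := []fd{{from: map[int]bool{1: true}, to: map[int]bool{2: true}}}
	cl := closure(groupBy, fds)
	fmt.Println(subset(map[int]bool{2: true}, cl)) // true
	fmt.Println(subset(map[int]bool{3: true}, cl)) // false
}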
+ strictClosure := fds.ClosureOfStrict(fds.GroupByCols) + if item.SubsetOf(strictClosure) { + continue + } + // locate the base col that are not in (constant list / group by list / strict fd closure) for error show. + baseCols := expression.ExtractColumns(expr) + errShowCol := baseCols[0] + for _, col := range baseCols { + colSet := intset.NewFastIntSet(int(col.UniqueID)) + if !colSet.SubsetOf(strictClosure) { + errShowCol = col + break + } + } + // better use the schema alias name firstly if any. + name := "" + for idx, schemaCol := range proj.Schema().Columns { + if schemaCol.UniqueID == errShowCol.UniqueID { + name = proj.OutputNames()[idx].String() + break + } + } + if name == "" { + name = errShowCol.OrigName + } + // Only1Zero is to judge whether it's no-group-by-items case. + if !fds.GroupByCols.Only1Zero() { + return nil, nil, 0, plannererrors.ErrFieldNotInGroupBy.GenWithStackByArgs(offset+1, ErrExprInSelect, name) + } + return nil, nil, 0, plannererrors.ErrMixOfGroupFuncAndFields.GenWithStackByArgs(offset+1, name) + } + if fds.GroupByCols.Only1Zero() { + // maxOneRow is delayed from agg's ExtractFD logic since some details listed in it. + projectionUniqueIDs := intset.NewFastIntSet() + for _, expr := range proj.Exprs { + switch x := expr.(type) { + case *expression.Column: + projectionUniqueIDs.Insert(int(x.UniqueID)) + case *expression.ScalarFunction: + scalarUniqueID, ok := fds.IsHashCodeRegistered(string(hack.String(x.HashCode()))) + if !ok { + logutil.BgLogger().Warn("Error occurred while maintaining the functional dependency") + continue + } + projectionUniqueIDs.Insert(scalarUniqueID) + } + } + fds.MaxOneRow(projectionUniqueIDs) + } + // for select * from view (include agg), outer projection don't have to check select list with the inner group-by flag. + fds.HasAggBuilt = false + } + } + return proj, proj.Exprs, oldLen, nil +} + +func (b *PlanBuilder) buildDistinct(child base.LogicalPlan, length int) (*LogicalAggregation, error) { + b.optFlag = b.optFlag | flagBuildKeyInfo + b.optFlag = b.optFlag | flagPushDownAgg + plan4Agg := LogicalAggregation{ + AggFuncs: make([]*aggregation.AggFuncDesc, 0, child.Schema().Len()), + GroupByItems: expression.Column2Exprs(child.Schema().Clone().Columns[:length]), + }.Init(b.ctx, child.QueryBlockOffset()) + if hintinfo := b.TableHints(); hintinfo != nil { + plan4Agg.PreferAggType = hintinfo.PreferAggType + plan4Agg.PreferAggToCop = hintinfo.PreferAggToCop + } + for _, col := range child.Schema().Columns { + aggDesc, err := aggregation.NewAggFuncDesc(b.ctx.GetExprCtx(), ast.AggFuncFirstRow, []expression.Expression{col}, false) + if err != nil { + return nil, err + } + plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, aggDesc) + } + plan4Agg.SetChildren(child) + plan4Agg.SetSchema(child.Schema().Clone()) + plan4Agg.SetOutputNames(child.OutputNames()) + // Distinct will be rewritten as first_row, we reset the type here since the return type + // of first_row is not always the same as the column arg of first_row. + for i, col := range plan4Agg.Schema().Columns { + col.RetType = plan4Agg.AggFuncs[i].RetTp + } + return plan4Agg, nil +} + +// unionJoinFieldType finds the type which can carry the given types in Union. +// Note that unionJoinFieldType doesn't handle charset and collation, caller need to handle it by itself. +func unionJoinFieldType(a, b *types.FieldType) *types.FieldType { + // We ignore the pure NULL type. 
+ if a.GetType() == mysql.TypeNull { + return b + } else if b.GetType() == mysql.TypeNull { + return a + } + resultTp := types.AggFieldType([]*types.FieldType{a, b}) + // This logic will be intelligible when it is associated with the buildProjection4Union logic. + if resultTp.GetType() == mysql.TypeNewDecimal { + // The decimal result type will be unsigned only when all the decimals to be united are unsigned. + resultTp.AndFlag(b.GetFlag() & mysql.UnsignedFlag) + } else { + // Non-decimal results will be unsigned when a,b both unsigned. + // ref1: https://dev.mysql.com/doc/refman/5.7/en/union.html#union-result-set + // ref2: https://github.com/pingcap/tidb/issues/24953 + resultTp.AddFlag((a.GetFlag() & mysql.UnsignedFlag) & (b.GetFlag() & mysql.UnsignedFlag)) + } + resultTp.SetDecimalUnderLimit(max(a.GetDecimal(), b.GetDecimal())) + // `flen - decimal` is the fraction before '.' + if a.GetFlen() == -1 || b.GetFlen() == -1 { + resultTp.SetFlenUnderLimit(-1) + } else { + resultTp.SetFlenUnderLimit(max(a.GetFlen()-a.GetDecimal(), b.GetFlen()-b.GetDecimal()) + resultTp.GetDecimal()) + } + types.TryToFixFlenOfDatetime(resultTp) + if resultTp.EvalType() != types.ETInt && (a.EvalType() == types.ETInt || b.EvalType() == types.ETInt) && resultTp.GetFlen() < mysql.MaxIntWidth { + resultTp.SetFlen(mysql.MaxIntWidth) + } + expression.SetBinFlagOrBinStr(b, resultTp) + return resultTp +} + +// Set the flen of the union column using the max flen in children. +func (b *PlanBuilder) setUnionFlen(resultTp *types.FieldType, cols []expression.Expression) { + if resultTp.GetFlen() == -1 { + return + } + isBinary := resultTp.GetCharset() == charset.CharsetBin + for i := 0; i < len(cols); i++ { + childTp := cols[i].GetType(b.ctx.GetExprCtx().GetEvalCtx()) + childTpCharLen := 1 + if isBinary { + if charsetInfo, ok := charset.CharacterSetInfos[childTp.GetCharset()]; ok { + childTpCharLen = charsetInfo.Maxlen + } + } + resultTp.SetFlen(max(resultTp.GetFlen(), childTpCharLen*childTp.GetFlen())) + } +} + +func (b *PlanBuilder) buildProjection4Union(_ context.Context, u *LogicalUnionAll) error { + unionCols := make([]*expression.Column, 0, u.Children()[0].Schema().Len()) + names := make([]*types.FieldName, 0, u.Children()[0].Schema().Len()) + + // Infer union result types by its children's schema. + for i, col := range u.Children()[0].Schema().Columns { + tmpExprs := make([]expression.Expression, 0, len(u.Children())) + tmpExprs = append(tmpExprs, col) + resultTp := col.RetType + for j := 1; j < len(u.Children()); j++ { + tmpExprs = append(tmpExprs, u.Children()[j].Schema().Columns[i]) + childTp := u.Children()[j].Schema().Columns[i].RetType + resultTp = unionJoinFieldType(resultTp, childTp) + } + collation, err := expression.CheckAndDeriveCollationFromExprs(b.ctx.GetExprCtx(), "UNION", resultTp.EvalType(), tmpExprs...) + if err != nil || collation.Coer == expression.CoercibilityNone { + return collate.ErrIllegalMixCollation.GenWithStackByArgs("UNION") + } + resultTp.SetCharset(collation.Charset) + resultTp.SetCollate(collation.Collation) + b.setUnionFlen(resultTp, tmpExprs) + names = append(names, &types.FieldName{ColName: u.Children()[0].OutputNames()[i].ColName}) + unionCols = append(unionCols, &expression.Column{ + RetType: resultTp, + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + }) + } + u.SetSchema(expression.NewSchema(unionCols...)) + u.SetOutputNames(names) + // Process each child and add a projection above original child. + // So the schema of `UnionAll` can be the same with its children's. 
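Editor's note: the width and sign rules in unionJoinFieldType above can be read as plain arithmetic: the result is unsigned only when both inputs are unsigned, the fractional width is the larger of the two, and the integer width is the larger of the two integer parts. A worked sketch with a hypothetical (flen, decimal, unsigned) triple instead of *types.FieldType:

package main

import "fmt"

type numType struct {
	flen, decimal int
	unsigned      bool
}

// unionNumType merges two fixed-point types the way the hunk above merges
// decimals: max decimal, max integer width, unsigned only if both are.
func unionNumType(a, b numType) numType {
	res := numType{
		decimal:  max(a.decimal, b.decimal),
		unsigned: a.unsigned && b.unsigned,
	}
	if a.flen == -1 || b.flen == -1 {
		res.flen = -1 // unknown width stays unknown
		return res
	}
	// flen-decimal is the width before the decimal point.
	res.flen = max(a.flen-a.decimal, b.flen-b.decimal) + res.decimal
	return res
}

func main() {
	// DECIMAL(10,2) UNION DECIMAL(8,4) -> 8 integer digits + 4 fraction digits.
	fmt.Println(unionNumType(numType{10, 2, true}, numType{8, 4, false}))
	// {12 4 false}
}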
+ for childID, child := range u.Children() { + exprs := make([]expression.Expression, len(child.Schema().Columns)) + for i, srcCol := range child.Schema().Columns { + dstType := unionCols[i].RetType + srcType := srcCol.RetType + if !srcType.Equal(dstType) { + exprs[i] = expression.BuildCastFunction4Union(b.ctx.GetExprCtx(), srcCol, dstType) + } else { + exprs[i] = srcCol + } + } + b.optFlag |= flagEliminateProjection + proj := logicalop.LogicalProjection{Exprs: exprs, AvoidColumnEvaluator: true}.Init(b.ctx, b.getSelectOffset()) + proj.SetSchema(u.Schema().Clone()) + // reset the schema type to make the "not null" flag right. + for i, expr := range exprs { + proj.Schema().Columns[i].RetType = expr.GetType(b.ctx.GetExprCtx().GetEvalCtx()) + } + proj.SetChildren(child) + u.Children()[childID] = proj + } + return nil +} + +func (b *PlanBuilder) buildSetOpr(ctx context.Context, setOpr *ast.SetOprStmt) (base.LogicalPlan, error) { + if setOpr.With != nil { + l := len(b.outerCTEs) + defer func() { + b.outerCTEs = b.outerCTEs[:l] + }() + _, err := b.buildWith(ctx, setOpr.With) + if err != nil { + return nil, err + } + } + + // Because INTERSECT has higher precedence than UNION and EXCEPT. We build it first. + selectPlans := make([]base.LogicalPlan, 0, len(setOpr.SelectList.Selects)) + afterSetOprs := make([]*ast.SetOprType, 0, len(setOpr.SelectList.Selects)) + selects := setOpr.SelectList.Selects + for i := 0; i < len(selects); i++ { + intersects := []ast.Node{selects[i]} + for i+1 < len(selects) { + breakIteration := false + switch x := selects[i+1].(type) { + case *ast.SelectStmt: + if *x.AfterSetOperator != ast.Intersect && *x.AfterSetOperator != ast.IntersectAll { + breakIteration = true + } + case *ast.SetOprSelectList: + if *x.AfterSetOperator != ast.Intersect && *x.AfterSetOperator != ast.IntersectAll { + breakIteration = true + } + if x.Limit != nil || x.OrderBy != nil { + // when SetOprSelectList's limit and order-by is not nil, it means itself is converted from + // an independent ast.SetOprStmt in parser, its data should be evaluated first, and ordered + // by given items and conduct a limit on it, then it can only be integrated with other brothers. + breakIteration = true + } + } + if breakIteration { + break + } + intersects = append(intersects, selects[i+1]) + i++ + } + selectPlan, afterSetOpr, err := b.buildIntersect(ctx, intersects) + if err != nil { + return nil, err + } + selectPlans = append(selectPlans, selectPlan) + afterSetOprs = append(afterSetOprs, afterSetOpr) + } + setOprPlan, err := b.buildExcept(ctx, selectPlans, afterSetOprs) + if err != nil { + return nil, err + } + + oldLen := setOprPlan.Schema().Len() + + for i := 0; i < len(setOpr.SelectList.Selects); i++ { + b.handleHelper.popMap() + } + b.handleHelper.pushMap(nil) + + if setOpr.OrderBy != nil { + setOprPlan, err = b.buildSort(ctx, setOprPlan, setOpr.OrderBy.Items, nil, nil) + if err != nil { + return nil, err + } + } + + if setOpr.Limit != nil { + setOprPlan, err = b.buildLimit(setOprPlan, setOpr.Limit) + if err != nil { + return nil, err + } + } + + // Fix issue #8189 (https://github.com/pingcap/tidb/issues/8189). + // If there are extra expressions generated from `ORDER BY` clause, generate a `Projection` to remove them. + if oldLen != setOprPlan.Schema().Len() { + proj := logicalop.LogicalProjection{Exprs: expression.Column2Exprs(setOprPlan.Schema().Columns[:oldLen])}.Init(b.ctx, b.getSelectOffset()) + proj.SetChildren(setOprPlan) + schema := expression.NewSchema(setOprPlan.Schema().Clone().Columns[:oldLen]...) 
+ for _, col := range schema.Columns { + col.UniqueID = b.ctx.GetSessionVars().AllocPlanColumnID() + } + proj.SetOutputNames(setOprPlan.OutputNames()[:oldLen]) + proj.SetSchema(schema) + return proj, nil + } + return setOprPlan, nil +} + +func (b *PlanBuilder) buildSemiJoinForSetOperator( + leftOriginPlan base.LogicalPlan, + rightPlan base.LogicalPlan, + joinType JoinType) (leftPlan base.LogicalPlan, err error) { + leftPlan, err = b.buildDistinct(leftOriginPlan, leftOriginPlan.Schema().Len()) + if err != nil { + return nil, err + } + b.optFlag |= flagConvertOuterToInnerJoin + + joinPlan := LogicalJoin{JoinType: joinType}.Init(b.ctx, b.getSelectOffset()) + joinPlan.SetChildren(leftPlan, rightPlan) + joinPlan.SetSchema(leftPlan.Schema()) + joinPlan.SetOutputNames(make([]*types.FieldName, leftPlan.Schema().Len())) + copy(joinPlan.OutputNames(), leftPlan.OutputNames()) + for j := 0; j < len(rightPlan.Schema().Columns); j++ { + leftCol, rightCol := leftPlan.Schema().Columns[j], rightPlan.Schema().Columns[j] + eqCond, err := expression.NewFunction(b.ctx.GetExprCtx(), ast.NullEQ, types.NewFieldType(mysql.TypeTiny), leftCol, rightCol) + if err != nil { + return nil, err + } + _, leftArgIsColumn := eqCond.(*expression.ScalarFunction).GetArgs()[0].(*expression.Column) + _, rightArgIsColumn := eqCond.(*expression.ScalarFunction).GetArgs()[1].(*expression.Column) + if leftCol.RetType.GetType() != rightCol.RetType.GetType() || !leftArgIsColumn || !rightArgIsColumn { + joinPlan.OtherConditions = append(joinPlan.OtherConditions, eqCond) + } else { + joinPlan.EqualConditions = append(joinPlan.EqualConditions, eqCond.(*expression.ScalarFunction)) + } + } + return joinPlan, nil +} + +// buildIntersect build the set operator for 'intersect'. It is called before buildExcept and buildUnion because of its +// higher precedence. 
+func (b *PlanBuilder) buildIntersect(ctx context.Context, selects []ast.Node) (base.LogicalPlan, *ast.SetOprType, error) { + var leftPlan base.LogicalPlan + var err error + var afterSetOperator *ast.SetOprType + switch x := selects[0].(type) { + case *ast.SelectStmt: + afterSetOperator = x.AfterSetOperator + leftPlan, err = b.buildSelect(ctx, x) + case *ast.SetOprSelectList: + afterSetOperator = x.AfterSetOperator + leftPlan, err = b.buildSetOpr(ctx, &ast.SetOprStmt{SelectList: x, With: x.With, Limit: x.Limit, OrderBy: x.OrderBy}) + } + if err != nil { + return nil, nil, err + } + if len(selects) == 1 { + return leftPlan, afterSetOperator, nil + } + + columnNums := leftPlan.Schema().Len() + for i := 1; i < len(selects); i++ { + var rightPlan base.LogicalPlan + switch x := selects[i].(type) { + case *ast.SelectStmt: + if *x.AfterSetOperator == ast.IntersectAll { + // TODO: support intersect all + return nil, nil, errors.Errorf("TiDB do not support intersect all") + } + rightPlan, err = b.buildSelect(ctx, x) + case *ast.SetOprSelectList: + if *x.AfterSetOperator == ast.IntersectAll { + // TODO: support intersect all + return nil, nil, errors.Errorf("TiDB do not support intersect all") + } + rightPlan, err = b.buildSetOpr(ctx, &ast.SetOprStmt{SelectList: x, With: x.With, Limit: x.Limit, OrderBy: x.OrderBy}) + } + if err != nil { + return nil, nil, err + } + if rightPlan.Schema().Len() != columnNums { + return nil, nil, plannererrors.ErrWrongNumberOfColumnsInSelect.GenWithStackByArgs() + } + leftPlan, err = b.buildSemiJoinForSetOperator(leftPlan, rightPlan, SemiJoin) + if err != nil { + return nil, nil, err + } + } + return leftPlan, afterSetOperator, nil +} + +// buildExcept build the set operators for 'except', and in this function, it calls buildUnion at the same time. Because +// Union and except has the same precedence. +func (b *PlanBuilder) buildExcept(ctx context.Context, selects []base.LogicalPlan, afterSetOpts []*ast.SetOprType) (base.LogicalPlan, error) { + unionPlans := []base.LogicalPlan{selects[0]} + tmpAfterSetOpts := []*ast.SetOprType{nil} + columnNums := selects[0].Schema().Len() + for i := 1; i < len(selects); i++ { + rightPlan := selects[i] + if rightPlan.Schema().Len() != columnNums { + return nil, plannererrors.ErrWrongNumberOfColumnsInSelect.GenWithStackByArgs() + } + if *afterSetOpts[i] == ast.Except { + leftPlan, err := b.buildUnion(ctx, unionPlans, tmpAfterSetOpts) + if err != nil { + return nil, err + } + leftPlan, err = b.buildSemiJoinForSetOperator(leftPlan, rightPlan, AntiSemiJoin) + if err != nil { + return nil, err + } + unionPlans = []base.LogicalPlan{leftPlan} + tmpAfterSetOpts = []*ast.SetOprType{nil} + } else if *afterSetOpts[i] == ast.ExceptAll { + // TODO: support except all. 
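Editor's note: buildIntersect and buildExcept above implement the set operators as a semi join and an anti semi join over a deduplicated left side, with NULL-safe equality on every column. As a standalone illustration, rows are encoded as strings so that NULLs compare equal; that encoding and the setOp helper are assumptions of this sketch, not the executor's real row format.

package main

import (
	"fmt"
	"strings"
)

// key encodes a row so two rows collide exactly when every column is
// null-safe-equal, which is what the NullEQ conditions above express.
func key(row []*string) string {
	parts := make([]string, len(row))
	for i, v := range row {
		if v == nil {
			parts[i] = "\x00NULL"
		} else {
			parts[i] = *v
		}
	}
	return strings.Join(parts, "\x01")
}

// setOp returns the distinct rows of left that do (semi join, INTERSECT) or
// do not (anti semi join, EXCEPT) have a match in right.
func setOp(left, right [][]*string, anti bool) [][]*string {
	rightKeys := make(map[string]bool, len(right))
	for _, r := range right {
		rightKeys[key(r)] = true
	}
	seen := make(map[string]bool, len(left))
	var out [][]*string
	for _, l := range left {
		k := key(l)
		if seen[k] {
			continue // deduplicate the left side first, as buildDistinct does above
		}
		seen[k] = true
		if rightKeys[k] != anti {
			out = append(out, l)
		}
	}
	return out
}

func main() {
	s := func(v string) *string { return &v }
	left := [][]*string{{s("1")}, {s("1")}, {nil}, {s("2")}}
	right := [][]*string{{s("1")}, {nil}}
	fmt.Println(len(setOp(left, right, false))) // INTERSECT: 2 rows (1 and NULL)
	fmt.Println(len(setOp(left, right, true)))  // EXCEPT:    1 row  (2)
}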
+ return nil, errors.Errorf("TiDB do not support except all") + } else { + unionPlans = append(unionPlans, rightPlan) + tmpAfterSetOpts = append(tmpAfterSetOpts, afterSetOpts[i]) + } + } + return b.buildUnion(ctx, unionPlans, tmpAfterSetOpts) +} + +func (b *PlanBuilder) buildUnion(ctx context.Context, selects []base.LogicalPlan, afterSetOpts []*ast.SetOprType) (base.LogicalPlan, error) { + if len(selects) == 1 { + return selects[0], nil + } + distinctSelectPlans, allSelectPlans, err := b.divideUnionSelectPlans(ctx, selects, afterSetOpts) + if err != nil { + return nil, err + } + unionDistinctPlan, err := b.buildUnionAll(ctx, distinctSelectPlans) + if err != nil { + return nil, err + } + if unionDistinctPlan != nil { + unionDistinctPlan, err = b.buildDistinct(unionDistinctPlan, unionDistinctPlan.Schema().Len()) + if err != nil { + return nil, err + } + if len(allSelectPlans) > 0 { + // Can't change the statements order in order to get the correct column info. + allSelectPlans = append([]base.LogicalPlan{unionDistinctPlan}, allSelectPlans...) + } + } + + unionAllPlan, err := b.buildUnionAll(ctx, allSelectPlans) + if err != nil { + return nil, err + } + unionPlan := unionDistinctPlan + if unionAllPlan != nil { + unionPlan = unionAllPlan + } + + return unionPlan, nil +} + +// divideUnionSelectPlans resolves union's select stmts to logical plans. +// and divide result plans into "union-distinct" and "union-all" parts. +// divide rule ref: +// +// https://dev.mysql.com/doc/refman/5.7/en/union.html +// +// "Mixed UNION types are treated such that a DISTINCT union overrides any ALL union to its left." +func (*PlanBuilder) divideUnionSelectPlans(_ context.Context, selects []base.LogicalPlan, setOprTypes []*ast.SetOprType) (distinctSelects []base.LogicalPlan, allSelects []base.LogicalPlan, err error) { + firstUnionAllIdx := 0 + columnNums := selects[0].Schema().Len() + for i := len(selects) - 1; i > 0; i-- { + if firstUnionAllIdx == 0 && *setOprTypes[i] != ast.UnionAll { + firstUnionAllIdx = i + 1 + } + if selects[i].Schema().Len() != columnNums { + return nil, nil, plannererrors.ErrWrongNumberOfColumnsInSelect.GenWithStackByArgs() + } + } + return selects[:firstUnionAllIdx], selects[firstUnionAllIdx:], nil +} + +func (b *PlanBuilder) buildUnionAll(ctx context.Context, subPlan []base.LogicalPlan) (base.LogicalPlan, error) { + if len(subPlan) == 0 { + return nil, nil + } + u := LogicalUnionAll{}.Init(b.ctx, b.getSelectOffset()) + u.SetChildren(subPlan...) 
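Editor's note: divideUnionSelectPlans above encodes MySQL's rule that a DISTINCT union overrides every ALL union to its left: scanning the operators from the right, everything up to and including the last non-ALL operator goes into the distinct part. A small sketch over a plain operator slice; the opAll/opDistinct constants and splitUnion are hypothetical names for illustration only.

package main

import "fmt"

const (
	opAll      = "UNION ALL"
	opDistinct = "UNION"
)

// splitUnion returns how many leading branches must be deduplicated.
// ops[i] is the operator written before branch i, so ops[0] is unused,
// mirroring the setOprTypes slice above.
func splitUnion(ops []string) (distinctLen int) {
	for i := len(ops) - 1; i > 0; i-- {
		if ops[i] != opAll {
			return i + 1
		}
	}
	return 0
}

func main() {
	// s1 UNION ALL s2 UNION s3 UNION ALL s4:
	// the UNION before s3 forces s1..s3 into the distinct part.
	ops := []string{"", opAll, opDistinct, opAll}
	fmt.Println(splitUnion(ops)) // 3
}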
+ err := b.buildProjection4Union(ctx, u) + return u, err +} + +// itemTransformer transforms ParamMarkerExpr to PositionExpr in the context of ByItem +type itemTransformer struct{} + +func (*itemTransformer) Enter(inNode ast.Node) (ast.Node, bool) { + if n, ok := inNode.(*driver.ParamMarkerExpr); ok { + newNode := expression.ConstructPositionExpr(n) + return newNode, true + } + return inNode, false +} + +func (*itemTransformer) Leave(inNode ast.Node) (ast.Node, bool) { + return inNode, false +} + +func (b *PlanBuilder) buildSort(ctx context.Context, p base.LogicalPlan, byItems []*ast.ByItem, aggMapper map[*ast.AggregateFuncExpr]int, windowMapper map[*ast.WindowFuncExpr]int) (*logicalop.LogicalSort, error) { + return b.buildSortWithCheck(ctx, p, byItems, aggMapper, windowMapper, nil, 0, false) +} + +func (b *PlanBuilder) buildSortWithCheck(ctx context.Context, p base.LogicalPlan, byItems []*ast.ByItem, aggMapper map[*ast.AggregateFuncExpr]int, windowMapper map[*ast.WindowFuncExpr]int, + projExprs []expression.Expression, oldLen int, hasDistinct bool) (*logicalop.LogicalSort, error) { + if _, isUnion := p.(*LogicalUnionAll); isUnion { + b.curClause = globalOrderByClause + } else { + b.curClause = orderByClause + } + sort := logicalop.LogicalSort{}.Init(b.ctx, b.getSelectOffset()) + exprs := make([]*util.ByItems, 0, len(byItems)) + transformer := &itemTransformer{} + for i, item := range byItems { + newExpr, _ := item.Expr.Accept(transformer) + item.Expr = newExpr.(ast.ExprNode) + it, np, err := b.rewriteWithPreprocess(ctx, item.Expr, p, aggMapper, windowMapper, true, nil) + if err != nil { + return nil, err + } + // for case: select a+1, b, sum(b) from t group by a+1, b with rollup order by a+1. + // currently, we fail to resolve (a+1) in order-by to projection item (a+1), and adding + // another a' in the select fields instead, leading finally resolved expr is a'+1 here. + // + // Anyway, a and a' has the same column unique id, so we can do the replacement work like + // we did in build projection phase. + it = b.replaceGroupingFunc(it) + + // check whether ORDER BY items show up in SELECT DISTINCT fields, see #12442 + if hasDistinct && projExprs != nil { + err = b.checkOrderByInDistinct(item, i, it, p, projExprs, oldLen) + if err != nil { + return nil, err + } + } + + p = np + exprs = append(exprs, &util.ByItems{Expr: it, Desc: item.Desc}) + } + sort.ByItems = exprs + sort.SetChildren(p) + return sort, nil +} + +// checkOrderByInDistinct checks whether ORDER BY has conflicts with DISTINCT, see #12442 +func (b *PlanBuilder) checkOrderByInDistinct(byItem *ast.ByItem, idx int, expr expression.Expression, p base.LogicalPlan, originalExprs []expression.Expression, length int) error { + // Check if expressions in ORDER BY whole match some fields in DISTINCT. + // e.g. + // select distinct count(a) from t group by b order by count(a); ✔ + // select distinct a+1 from t order by a+1; ✔ + // select distinct a+1 from t order by a+2; ✗ + evalCtx := b.ctx.GetExprCtx().GetEvalCtx() + for j := 0; j < length; j++ { + // both check original expression & as name + if expr.Equal(evalCtx, originalExprs[j]) || expr.Equal(evalCtx, p.Schema().Columns[j]) { + return nil + } + } + + // Check if referenced columns of expressions in ORDER BY whole match some fields in DISTINCT, + // both original expression and alias can be referenced. + // e.g. 
+ // select distinct a from t order by sin(a); ✔ + // select distinct a, b from t order by a+b; ✔ + // select distinct count(a), sum(a) from t group by b order by sum(a); ✔ + cols := expression.ExtractColumns(expr) +CheckReferenced: + for _, col := range cols { + for j := 0; j < length; j++ { + if col.Equal(evalCtx, originalExprs[j]) || col.Equal(evalCtx, p.Schema().Columns[j]) { + continue CheckReferenced + } + } + + // Failed cases + // e.g. + // select distinct sin(a) from t order by a; ✗ + // select distinct a from t order by a+b; ✗ + if _, ok := byItem.Expr.(*ast.AggregateFuncExpr); ok { + return plannererrors.ErrAggregateInOrderNotSelect.GenWithStackByArgs(idx+1, "DISTINCT") + } + // select distinct count(a) from t group by b order by sum(a); ✗ + return plannererrors.ErrFieldInOrderNotSelect.GenWithStackByArgs(idx+1, col.OrigName, "DISTINCT") + } + return nil +} + +// getUintFromNode gets uint64 value from ast.Node. +// For ordinary statement, node should be uint64 constant value. +// For prepared statement, node is string. We should convert it to uint64. +func getUintFromNode(ctx base.PlanContext, n ast.Node, mustInt64orUint64 bool) (uVal uint64, isNull bool, isExpectedType bool) { + var val any + switch v := n.(type) { + case *driver.ValueExpr: + val = v.GetValue() + case *driver.ParamMarkerExpr: + if !v.InExecute { + return 0, false, true + } + if mustInt64orUint64 { + if expected, _ := CheckParamTypeInt64orUint64(v); !expected { + return 0, false, false + } + } + param, err := expression.ParamMarkerExpression(ctx, v, false) + if err != nil { + return 0, false, false + } + str, isNull, err := expression.GetStringFromConstant(ctx.GetExprCtx().GetEvalCtx(), param) + if err != nil { + return 0, false, false + } + if isNull { + return 0, true, true + } + val = str + default: + return 0, false, false + } + switch v := val.(type) { + case uint64: + return v, false, true + case int64: + if v >= 0 { + return uint64(v), false, true + } + case string: + ctx := ctx.GetSessionVars().StmtCtx.TypeCtx() + uVal, err := types.StrToUint(ctx, v, false) + if err != nil { + return 0, false, false + } + return uVal, false, true + } + return 0, false, false +} + +// CheckParamTypeInt64orUint64 check param type for plan cache limit, only allow int64 and uint64 now +// eg: set @a = 1; +func CheckParamTypeInt64orUint64(param *driver.ParamMarkerExpr) (bool, uint64) { + val := param.GetValue() + switch v := val.(type) { + case int64: + if v >= 0 { + return true, uint64(v) + } + case uint64: + return true, v + } + return false, 0 +} + +func extractLimitCountOffset(ctx base.PlanContext, limit *ast.Limit) (count uint64, + offset uint64, err error) { + var isExpectedType bool + if limit.Count != nil { + count, _, isExpectedType = getUintFromNode(ctx, limit.Count, true) + if !isExpectedType { + return 0, 0, plannererrors.ErrWrongArguments.GenWithStackByArgs("LIMIT") + } + } + if limit.Offset != nil { + offset, _, isExpectedType = getUintFromNode(ctx, limit.Offset, true) + if !isExpectedType { + return 0, 0, plannererrors.ErrWrongArguments.GenWithStackByArgs("LIMIT") + } + } + return count, offset, nil +} + +func (b *PlanBuilder) buildLimit(src base.LogicalPlan, limit *ast.Limit) (base.LogicalPlan, error) { + b.optFlag = b.optFlag | flagPushDownTopN + var ( + offset, count uint64 + err error + ) + if count, offset, err = extractLimitCountOffset(b.ctx, limit); err != nil { + return nil, err + } + + if count > math.MaxUint64-offset { + count = math.MaxUint64 - offset + } + if offset+count == 0 { + tableDual := 
logicalop.LogicalTableDual{RowCount: 0}.Init(b.ctx, b.getSelectOffset()) + tableDual.SetSchema(src.Schema()) + tableDual.SetOutputNames(src.OutputNames()) + return tableDual, nil + } + li := logicalop.LogicalLimit{ + Offset: offset, + Count: count, + }.Init(b.ctx, b.getSelectOffset()) + if hint := b.TableHints(); hint != nil { + li.PreferLimitToCop = hint.PreferLimitToCop + } + li.SetChildren(src) + return li, nil +} + +func resolveFromSelectFields(v *ast.ColumnNameExpr, fields []*ast.SelectField, ignoreAsName bool) (index int, err error) { + var matchedExpr ast.ExprNode + index = -1 + for i, field := range fields { + if field.Auxiliary { + continue + } + if field.Match(v, ignoreAsName) { + curCol, isCol := field.Expr.(*ast.ColumnNameExpr) + if !isCol { + return i, nil + } + if matchedExpr == nil { + matchedExpr = curCol + index = i + } else if !matchedExpr.(*ast.ColumnNameExpr).Name.Match(curCol.Name) && + !curCol.Name.Match(matchedExpr.(*ast.ColumnNameExpr).Name) { + return -1, plannererrors.ErrAmbiguous.GenWithStackByArgs(curCol.Name.Name.L, clauseMsg[fieldList]) + } + } + } + return +} + +// havingWindowAndOrderbyExprResolver visits Expr tree. +// It converts ColumnNameExpr to AggregateFuncExpr and collects AggregateFuncExpr. +type havingWindowAndOrderbyExprResolver struct { + inAggFunc bool + inWindowFunc bool + inWindowSpec bool + inExpr bool + err error + p base.LogicalPlan + selectFields []*ast.SelectField + aggMapper map[*ast.AggregateFuncExpr]int + colMapper map[*ast.ColumnNameExpr]int + gbyItems []*ast.ByItem + outerSchemas []*expression.Schema + outerNames [][]*types.FieldName + curClause clauseCode + prevClause []clauseCode +} + +func (a *havingWindowAndOrderbyExprResolver) pushCurClause(newClause clauseCode) { + a.prevClause = append(a.prevClause, a.curClause) + a.curClause = newClause +} + +func (a *havingWindowAndOrderbyExprResolver) popCurClause() { + a.curClause = a.prevClause[len(a.prevClause)-1] + a.prevClause = a.prevClause[:len(a.prevClause)-1] +} + +// Enter implements Visitor interface. +func (a *havingWindowAndOrderbyExprResolver) Enter(n ast.Node) (node ast.Node, skipChildren bool) { + switch n.(type) { + case *ast.AggregateFuncExpr: + a.inAggFunc = true + case *ast.WindowFuncExpr: + a.inWindowFunc = true + case *ast.WindowSpec: + a.inWindowSpec = true + case *driver.ParamMarkerExpr, *ast.ColumnNameExpr, *ast.ColumnName: + case *ast.SubqueryExpr, *ast.ExistsSubqueryExpr: + // Enter a new context, skip it. + // For example: select sum(c) + c + exists(select c from t) from t; + return n, true + case *ast.PartitionByClause: + a.pushCurClause(partitionByClause) + case *ast.OrderByClause: + if a.inWindowSpec { + a.pushCurClause(windowOrderByClause) + } + default: + a.inExpr = true + } + return n, false +} + +func (a *havingWindowAndOrderbyExprResolver) resolveFromPlan(v *ast.ColumnNameExpr, p base.LogicalPlan, resolveFieldsFirst bool) (int, error) { + idx, err := expression.FindFieldName(p.OutputNames(), v.Name) + if err != nil { + return -1, err + } + schemaCols, outputNames := p.Schema().Columns, p.OutputNames() + if idx < 0 { + // For SQL like `select t2.a from t1 join t2 using(a) where t2.a > 0 + // order by t2.a`, the query plan will be `join->selection->sort`. The + // schema of selection will be `[t1.a]`, thus we need to recursively + // retrieve the `t2.a` from the underlying join. 
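Editor's note: resolveFromSelectFields, earlier in this hunk, returns the first matching field but keeps scanning column fields so that two different source columns answering to the same name can be reported as ambiguous. A simplified version over (alias, source column) pairs; the field struct and resolve helper are assumptions of this sketch, not the resolver's real types.

package main

import (
	"errors"
	"fmt"
)

// field is a simplified select field: its output name and, when the field is
// a bare column reference, the fully qualified source column it refers to.
type field struct {
	name   string
	source string // empty for expressions such as a+1
}

// resolve returns the index of the field that a later ORDER BY / HAVING name
// refers to, or an error when two distinct source columns share that name.
func resolve(name string, fields []field) (int, error) {
	matched := -1
	for i, f := range fields {
		if f.name != name {
			continue
		}
		if f.source == "" {
			return i, nil // non-column expression fields win immediately, as above
		}
		if matched == -1 {
			matched = i
		} else if fields[matched].source != f.source {
			return -1, errors.New("column '" + name + "' in field list is ambiguous")
		}
	}
	return matched, nil
}

func main() {
	fields := []field{{name: "a", source: "t1.a"}, {name: "a", source: "t2.a"}}
	_, err := resolve("a", fields)
	fmt.Println(err) // ambiguous: t1.a and t2.a both answer to "a"
}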
+ switch x := p.(type) { + case *logicalop.LogicalLimit, *LogicalSelection, *logicalop.LogicalTopN, *logicalop.LogicalSort, *logicalop.LogicalMaxOneRow: + return a.resolveFromPlan(v, p.Children()[0], resolveFieldsFirst) + case *LogicalJoin: + if len(x.FullNames) != 0 { + idx, err = expression.FindFieldName(x.FullNames, v.Name) + schemaCols, outputNames = x.FullSchema.Columns, x.FullNames + } + } + if err != nil || idx < 0 { + // nowhere to be found. + return -1, err + } + } + col := schemaCols[idx] + if col.IsHidden { + return -1, plannererrors.ErrUnknownColumn.GenWithStackByArgs(v.Name, clauseMsg[a.curClause]) + } + name := outputNames[idx] + newColName := &ast.ColumnName{ + Schema: name.DBName, + Table: name.TblName, + Name: name.ColName, + } + for i, field := range a.selectFields { + if c, ok := field.Expr.(*ast.ColumnNameExpr); ok && c.Name.Match(newColName) { + return i, nil + } + } + // From https://github.com/pingcap/tidb/issues/51107 + // You should make the column in the having clause as the correlated column + // which is not relation with select's fields and GroupBy's fields. + // For SQLs like: + // SELECT * FROM `t1` WHERE NOT (`t1`.`col_1`>= ( + // SELECT `t2`.`col_7` + // FROM (`t1`) + // JOIN `t2` + // WHERE ISNULL(`t2`.`col_3`) HAVING `t1`.`col_6`>1951988) + // ) ; + // + // if resolveFieldsFirst is false, the groupby is not nil. + if resolveFieldsFirst && a.curClause == havingClause { + return -1, nil + } + sf := &ast.SelectField{ + Expr: &ast.ColumnNameExpr{Name: newColName}, + Auxiliary: true, + } + // appended with new select fields. set them with flag. + if a.inAggFunc { + // should skip check in FD for only full group by. + sf.AuxiliaryColInAgg = true + } else if a.curClause == orderByClause { + // should skip check in FD for only full group by only when group by item are empty. + sf.AuxiliaryColInOrderBy = true + } + sf.Expr.SetType(col.GetStaticType()) + a.selectFields = append(a.selectFields, sf) + return len(a.selectFields) - 1, nil +} + +// Leave implements Visitor interface. 
+func (a *havingWindowAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, ok bool) { + switch v := n.(type) { + case *ast.AggregateFuncExpr: + a.inAggFunc = false + a.aggMapper[v] = len(a.selectFields) + a.selectFields = append(a.selectFields, &ast.SelectField{ + Auxiliary: true, + Expr: v, + AsName: model.NewCIStr(fmt.Sprintf("sel_agg_%d", len(a.selectFields))), + }) + case *ast.WindowFuncExpr: + a.inWindowFunc = false + if a.curClause == havingClause { + a.err = plannererrors.ErrWindowInvalidWindowFuncUse.GenWithStackByArgs(strings.ToLower(v.Name)) + return node, false + } + if a.curClause == orderByClause { + a.selectFields = append(a.selectFields, &ast.SelectField{ + Auxiliary: true, + Expr: v, + AsName: model.NewCIStr(fmt.Sprintf("sel_window_%d", len(a.selectFields))), + }) + } + case *ast.WindowSpec: + a.inWindowSpec = false + case *ast.PartitionByClause: + a.popCurClause() + case *ast.OrderByClause: + if a.inWindowSpec { + a.popCurClause() + } + case *ast.ColumnNameExpr: + resolveFieldsFirst := true + if a.inAggFunc || a.inWindowFunc || a.inWindowSpec || (a.curClause == orderByClause && a.inExpr) || a.curClause == fieldList { + resolveFieldsFirst = false + } + if !a.inAggFunc && a.curClause != orderByClause { + for _, item := range a.gbyItems { + if col, ok := item.Expr.(*ast.ColumnNameExpr); ok && + (v.Name.Match(col.Name) || col.Name.Match(v.Name)) { + resolveFieldsFirst = false + break + } + } + } + var index int + if resolveFieldsFirst { + index, a.err = resolveFromSelectFields(v, a.selectFields, false) + if a.err != nil { + return node, false + } + if index != -1 && a.curClause == havingClause && ast.HasWindowFlag(a.selectFields[index].Expr) { + a.err = plannererrors.ErrWindowInvalidWindowFuncAliasUse.GenWithStackByArgs(v.Name.Name.O) + return node, false + } + if index == -1 { + if a.curClause == orderByClause { + index, a.err = a.resolveFromPlan(v, a.p, resolveFieldsFirst) + } else if a.curClause == havingClause && v.Name.Table.L != "" { + // For SQLs like: + // select a from t b having b.a; + index, a.err = a.resolveFromPlan(v, a.p, resolveFieldsFirst) + if a.err != nil { + return node, false + } + if index != -1 { + // For SQLs like: + // select a+1 from t having t.a; + field := a.selectFields[index] + if field.Auxiliary { // having can't use auxiliary field + index = -1 + } + } + } else { + index, a.err = resolveFromSelectFields(v, a.selectFields, true) + } + } + } else { + // We should ignore the err when resolving from schema. Because we could resolve successfully + // when considering select fields. + var err error + index, err = a.resolveFromPlan(v, a.p, resolveFieldsFirst) + _ = err + if index == -1 && a.curClause != fieldList && + a.curClause != windowOrderByClause && a.curClause != partitionByClause { + index, a.err = resolveFromSelectFields(v, a.selectFields, false) + if index != -1 && a.curClause == havingClause && ast.HasWindowFlag(a.selectFields[index].Expr) { + a.err = plannererrors.ErrWindowInvalidWindowFuncAliasUse.GenWithStackByArgs(v.Name.Name.O) + return node, false + } + } + } + if a.err != nil { + return node, false + } + if index == -1 { + // If we can't find it any where, it may be a correlated columns. 
+ for _, names := range a.outerNames { + idx, err1 := expression.FindFieldName(names, v.Name) + if err1 != nil { + a.err = err1 + return node, false + } + if idx >= 0 { + return n, true + } + } + a.err = plannererrors.ErrUnknownColumn.GenWithStackByArgs(v.Name.OrigColName(), clauseMsg[a.curClause]) + return node, false + } + if a.inAggFunc { + return a.selectFields[index].Expr, true + } + a.colMapper[v] = index + } + return n, true +} + +// resolveHavingAndOrderBy will process aggregate functions and resolve the columns that don't exist in select fields. +// If we found some columns that are not in select fields, we will append it to select fields and update the colMapper. +// When we rewrite the order by / having expression, we will find column in map at first. +func (b *PlanBuilder) resolveHavingAndOrderBy(ctx context.Context, sel *ast.SelectStmt, p base.LogicalPlan) ( + map[*ast.AggregateFuncExpr]int, map[*ast.AggregateFuncExpr]int, error) { + extractor := &havingWindowAndOrderbyExprResolver{ + p: p, + selectFields: sel.Fields.Fields, + aggMapper: make(map[*ast.AggregateFuncExpr]int), + colMapper: b.colMapper, + outerSchemas: b.outerSchemas, + outerNames: b.outerNames, + } + if sel.GroupBy != nil { + extractor.gbyItems = sel.GroupBy.Items + } + // Extract agg funcs from having clause. + if sel.Having != nil { + extractor.curClause = havingClause + n, ok := sel.Having.Expr.Accept(extractor) + if !ok { + return nil, nil, errors.Trace(extractor.err) + } + sel.Having.Expr = n.(ast.ExprNode) + } + havingAggMapper := extractor.aggMapper + extractor.aggMapper = make(map[*ast.AggregateFuncExpr]int) + // Extract agg funcs from order by clause. + if sel.OrderBy != nil { + extractor.curClause = orderByClause + for _, item := range sel.OrderBy.Items { + extractor.inExpr = false + if ast.HasWindowFlag(item.Expr) { + continue + } + n, ok := item.Expr.Accept(extractor) + if !ok { + return nil, nil, errors.Trace(extractor.err) + } + item.Expr = n.(ast.ExprNode) + } + } + sel.Fields.Fields = extractor.selectFields + // this part is used to fetch correlated column from sub-query item in order-by clause, and append the origin + // auxiliary select filed in select list, otherwise, sub-query itself won't get the name resolved in outer schema. + if sel.OrderBy != nil { + for _, byItem := range sel.OrderBy.Items { + if _, ok := byItem.Expr.(*ast.SubqueryExpr); ok { + // correlated agg will be extracted completely latter. + _, np, err := b.rewrite(ctx, byItem.Expr, p, nil, true) + if err != nil { + return nil, nil, errors.Trace(err) + } + correlatedCols := coreusage.ExtractCorrelatedCols4LogicalPlan(np) + for _, cone := range correlatedCols { + var colName *ast.ColumnName + for idx, pone := range p.Schema().Columns { + if cone.UniqueID == pone.UniqueID { + pname := p.OutputNames()[idx] + colName = &ast.ColumnName{ + Schema: pname.DBName, + Table: pname.TblName, + Name: pname.ColName, + } + break + } + } + if colName != nil { + columnNameExpr := &ast.ColumnNameExpr{Name: colName} + for _, field := range sel.Fields.Fields { + if c, ok := field.Expr.(*ast.ColumnNameExpr); ok && c.Name.Match(columnNameExpr.Name) && field.AsName.L == "" { + // deduplicate select fields: don't append it once it already has one. + // TODO: we add the field if it has alias, but actually they are the same column. We should not have two duplicate one. 
+ columnNameExpr = nil + break + } + } + if columnNameExpr != nil { + sel.Fields.Fields = append(sel.Fields.Fields, &ast.SelectField{ + Auxiliary: true, + Expr: columnNameExpr, + }) + } + } + } + } + } + } + return havingAggMapper, extractor.aggMapper, nil +} + +func (b *PlanBuilder) extractAggFuncsInExprs(exprs []ast.ExprNode) ([]*ast.AggregateFuncExpr, map[*ast.AggregateFuncExpr]int) { + extractor := &AggregateFuncExtractor{skipAggMap: b.correlatedAggMapper} + for _, expr := range exprs { + expr.Accept(extractor) + } + aggList := extractor.AggFuncs + totalAggMapper := make(map[*ast.AggregateFuncExpr]int, len(aggList)) + + for i, agg := range aggList { + totalAggMapper[agg] = i + } + return aggList, totalAggMapper +} + +func (b *PlanBuilder) extractAggFuncsInSelectFields(fields []*ast.SelectField) ([]*ast.AggregateFuncExpr, map[*ast.AggregateFuncExpr]int) { + extractor := &AggregateFuncExtractor{skipAggMap: b.correlatedAggMapper} + for _, f := range fields { + n, _ := f.Expr.Accept(extractor) + f.Expr = n.(ast.ExprNode) + } + aggList := extractor.AggFuncs + totalAggMapper := make(map[*ast.AggregateFuncExpr]int, len(aggList)) + + for i, agg := range aggList { + totalAggMapper[agg] = i + } + return aggList, totalAggMapper +} + +func (b *PlanBuilder) extractAggFuncsInByItems(byItems []*ast.ByItem) []*ast.AggregateFuncExpr { + extractor := &AggregateFuncExtractor{skipAggMap: b.correlatedAggMapper} + for _, f := range byItems { + n, _ := f.Expr.Accept(extractor) + f.Expr = n.(ast.ExprNode) + } + return extractor.AggFuncs +} + +// extractCorrelatedAggFuncs extracts correlated aggregates which belong to outer query from aggregate function list. +func (b *PlanBuilder) extractCorrelatedAggFuncs(ctx context.Context, p base.LogicalPlan, aggFuncs []*ast.AggregateFuncExpr) (outer []*ast.AggregateFuncExpr, err error) { + corCols := make([]*expression.CorrelatedColumn, 0, len(aggFuncs)) + cols := make([]*expression.Column, 0, len(aggFuncs)) + aggMapper := make(map[*ast.AggregateFuncExpr]int) + for _, agg := range aggFuncs { + for _, arg := range agg.Args { + expr, _, err := b.rewrite(ctx, arg, p, aggMapper, true) + if err != nil { + return nil, err + } + corCols = append(corCols, expression.ExtractCorColumns(expr)...) + cols = append(cols, expression.ExtractColumns(expr)...) + } + if len(corCols) > 0 && len(cols) == 0 { + outer = append(outer, agg) + } + aggMapper[agg] = -1 + corCols, cols = corCols[:0], cols[:0] + } + return +} + +// resolveWindowFunction will process window functions and resolve the columns that don't exist in select fields. 
+func (b *PlanBuilder) resolveWindowFunction(sel *ast.SelectStmt, p base.LogicalPlan) ( + map[*ast.AggregateFuncExpr]int, error) { + extractor := &havingWindowAndOrderbyExprResolver{ + p: p, + selectFields: sel.Fields.Fields, + aggMapper: make(map[*ast.AggregateFuncExpr]int), + colMapper: b.colMapper, + outerSchemas: b.outerSchemas, + outerNames: b.outerNames, + } + extractor.curClause = fieldList + for _, field := range sel.Fields.Fields { + if !ast.HasWindowFlag(field.Expr) { + continue + } + n, ok := field.Expr.Accept(extractor) + if !ok { + return nil, extractor.err + } + field.Expr = n.(ast.ExprNode) + } + for _, spec := range sel.WindowSpecs { + _, ok := spec.Accept(extractor) + if !ok { + return nil, extractor.err + } + } + if sel.OrderBy != nil { + extractor.curClause = orderByClause + for _, item := range sel.OrderBy.Items { + if !ast.HasWindowFlag(item.Expr) { + continue + } + n, ok := item.Expr.Accept(extractor) + if !ok { + return nil, extractor.err + } + item.Expr = n.(ast.ExprNode) + } + } + sel.Fields.Fields = extractor.selectFields + return extractor.aggMapper, nil +} + +// correlatedAggregateResolver visits Expr tree. +// It finds and collects all correlated aggregates which should be evaluated in the outer query. +type correlatedAggregateResolver struct { + ctx context.Context + err error + b *PlanBuilder + outerPlan base.LogicalPlan + + // correlatedAggFuncs stores aggregate functions which belong to outer query + correlatedAggFuncs []*ast.AggregateFuncExpr +} + +// Enter implements Visitor interface. +func (r *correlatedAggregateResolver) Enter(n ast.Node) (ast.Node, bool) { + if v, ok := n.(*ast.SelectStmt); ok { + if r.outerPlan != nil { + outerSchema := r.outerPlan.Schema() + r.b.outerSchemas = append(r.b.outerSchemas, outerSchema) + r.b.outerNames = append(r.b.outerNames, r.outerPlan.OutputNames()) + r.b.outerBlockExpand = append(r.b.outerBlockExpand, r.b.currentBlockExpand) + } + r.err = r.resolveSelect(v) + return n, true + } + return n, false +} + +// resolveSelect finds and collects correlated aggregates within the SELECT stmt. +// It resolves and builds FROM clause first to get a source plan, from which we can decide +// whether a column is correlated or not. +// Then it collects correlated aggregate from SELECT fields (including sub-queries), HAVING, +// ORDER BY, WHERE & GROUP BY. +// Finally it restore the original SELECT stmt. +func (r *correlatedAggregateResolver) resolveSelect(sel *ast.SelectStmt) (err error) { + if sel.With != nil { + l := len(r.b.outerCTEs) + defer func() { + r.b.outerCTEs = r.b.outerCTEs[:l] + }() + _, err := r.b.buildWith(r.ctx, sel.With) + if err != nil { + return err + } + } + // collect correlated aggregate from sub-queries inside FROM clause. 
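Editor's note: extractCorrelatedAggFuncs, just above, treats an aggregate as belonging to the outer query exactly when its arguments reference at least one correlated column and no column of the current query block. That test is easy to state on its own; the column-name sets below are a simplifying assumption, since the real code inspects rewritten expressions rather than names.

package main

import "fmt"

// isCorrelatedAgg reports whether an aggregate whose arguments reference the
// given columns must be evaluated in the outer query: it uses some outer
// (correlated) column and nothing from the local FROM list, matching the
// len(corCols) > 0 && len(cols) == 0 test above.
func isCorrelatedAgg(argCols []string, localCols map[string]bool) bool {
	outer, local := 0, 0
	for _, c := range argCols {
		if localCols[c] {
			local++
		} else {
			outer++
		}
	}
	return outer > 0 && local == 0
}

func main() {
	local := map[string]bool{"t2.a": true, "t2.b": true}
	fmt.Println(isCorrelatedAgg([]string{"t1.c"}, local))         // true: outer-only arguments
	fmt.Println(isCorrelatedAgg([]string{"t1.c", "t2.a"}, local)) // false: mixes in a local column
}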
+ if err := r.collectFromTableRefs(sel.From); err != nil { + return err + } + p, err := r.b.buildTableRefs(r.ctx, sel.From) + if err != nil { + return err + } + + // similar to process in PlanBuilder.buildSelect + originalFields := sel.Fields.Fields + sel.Fields.Fields, err = r.b.unfoldWildStar(p, sel.Fields.Fields) + if err != nil { + return err + } + if r.b.capFlag&canExpandAST != 0 { + originalFields = sel.Fields.Fields + } + + hasWindowFuncField := r.b.detectSelectWindow(sel) + if hasWindowFuncField { + _, err = r.b.resolveWindowFunction(sel, p) + if err != nil { + return err + } + } + + _, _, err = r.b.resolveHavingAndOrderBy(r.ctx, sel, p) + if err != nil { + return err + } + + // find and collect correlated aggregates recursively in sub-queries + _, err = r.b.resolveCorrelatedAggregates(r.ctx, sel, p) + if err != nil { + return err + } + + // collect from SELECT fields, HAVING, ORDER BY and window functions + if r.b.detectSelectAgg(sel) { + err = r.collectFromSelectFields(p, sel.Fields.Fields) + if err != nil { + return err + } + } + + // collect from WHERE + err = r.collectFromWhere(p, sel.Where) + if err != nil { + return err + } + + // collect from GROUP BY + err = r.collectFromGroupBy(p, sel.GroupBy) + if err != nil { + return err + } + + // restore the sub-query + sel.Fields.Fields = originalFields + r.b.handleHelper.popMap() + return nil +} + +func (r *correlatedAggregateResolver) collectFromTableRefs(from *ast.TableRefsClause) error { + if from == nil { + return nil + } + subResolver := &correlatedAggregateResolver{ + ctx: r.ctx, + b: r.b, + } + _, ok := from.TableRefs.Accept(subResolver) + if !ok { + return subResolver.err + } + if len(subResolver.correlatedAggFuncs) == 0 { + return nil + } + r.correlatedAggFuncs = append(r.correlatedAggFuncs, subResolver.correlatedAggFuncs...) + return nil +} + +func (r *correlatedAggregateResolver) collectFromSelectFields(p base.LogicalPlan, fields []*ast.SelectField) error { + aggList, _ := r.b.extractAggFuncsInSelectFields(fields) + r.b.curClause = fieldList + outerAggFuncs, err := r.b.extractCorrelatedAggFuncs(r.ctx, p, aggList) + if err != nil { + return nil + } + r.correlatedAggFuncs = append(r.correlatedAggFuncs, outerAggFuncs...) + return nil +} + +func (r *correlatedAggregateResolver) collectFromGroupBy(p base.LogicalPlan, groupBy *ast.GroupByClause) error { + if groupBy == nil { + return nil + } + aggList := r.b.extractAggFuncsInByItems(groupBy.Items) + r.b.curClause = groupByClause + outerAggFuncs, err := r.b.extractCorrelatedAggFuncs(r.ctx, p, aggList) + if err != nil { + return nil + } + r.correlatedAggFuncs = append(r.correlatedAggFuncs, outerAggFuncs...) + return nil +} + +func (r *correlatedAggregateResolver) collectFromWhere(p base.LogicalPlan, where ast.ExprNode) error { + if where == nil { + return nil + } + extractor := &AggregateFuncExtractor{skipAggMap: r.b.correlatedAggMapper} + _, _ = where.Accept(extractor) + r.b.curClause = whereClause + outerAggFuncs, err := r.b.extractCorrelatedAggFuncs(r.ctx, p, extractor.AggFuncs) + if err != nil { + return err + } + r.correlatedAggFuncs = append(r.correlatedAggFuncs, outerAggFuncs...) + return nil +} + +// Leave implements Visitor interface. 
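+// For a SELECT statement it pops the outer schemas, output names and expand state that Enter pushed, if any.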
+func (r *correlatedAggregateResolver) Leave(n ast.Node) (ast.Node, bool) { + if _, ok := n.(*ast.SelectStmt); ok { + if r.outerPlan != nil { + r.b.outerSchemas = r.b.outerSchemas[0 : len(r.b.outerSchemas)-1] + r.b.outerNames = r.b.outerNames[0 : len(r.b.outerNames)-1] + r.b.currentBlockExpand = r.b.outerBlockExpand[len(r.b.outerBlockExpand)-1] + r.b.outerBlockExpand = r.b.outerBlockExpand[0 : len(r.b.outerBlockExpand)-1] + } + } + return n, r.err == nil +} + +// resolveCorrelatedAggregates finds and collects all correlated aggregates which should be evaluated +// in the outer query from all the sub-queries inside SELECT fields. +func (b *PlanBuilder) resolveCorrelatedAggregates(ctx context.Context, sel *ast.SelectStmt, p base.LogicalPlan) (map[*ast.AggregateFuncExpr]int, error) { + resolver := &correlatedAggregateResolver{ + ctx: ctx, + b: b, + outerPlan: p, + } + correlatedAggList := make([]*ast.AggregateFuncExpr, 0) + for _, field := range sel.Fields.Fields { + _, ok := field.Expr.Accept(resolver) + if !ok { + return nil, resolver.err + } + correlatedAggList = append(correlatedAggList, resolver.correlatedAggFuncs...) + } + if sel.Having != nil { + _, ok := sel.Having.Expr.Accept(resolver) + if !ok { + return nil, resolver.err + } + correlatedAggList = append(correlatedAggList, resolver.correlatedAggFuncs...) + } + if sel.OrderBy != nil { + for _, item := range sel.OrderBy.Items { + _, ok := item.Expr.Accept(resolver) + if !ok { + return nil, resolver.err + } + correlatedAggList = append(correlatedAggList, resolver.correlatedAggFuncs...) + } + } + correlatedAggMap := make(map[*ast.AggregateFuncExpr]int) + for _, aggFunc := range correlatedAggList { + colMap := make(map[*types.FieldName]struct{}, len(p.Schema().Columns)) + allColFromAggExprNode(p, aggFunc, colMap) + for k := range colMap { + colName := &ast.ColumnName{ + Schema: k.DBName, + Table: k.TblName, + Name: k.ColName, + } + // Add the column referred in the agg func into the select list. So that we can resolve the agg func correctly. + // And we need set the AuxiliaryColInAgg to true to help our only_full_group_by checker work correctly. + sel.Fields.Fields = append(sel.Fields.Fields, &ast.SelectField{ + Auxiliary: true, + AuxiliaryColInAgg: true, + Expr: &ast.ColumnNameExpr{Name: colName}, + }) + } + correlatedAggMap[aggFunc] = len(sel.Fields.Fields) + sel.Fields.Fields = append(sel.Fields.Fields, &ast.SelectField{ + Auxiliary: true, + Expr: aggFunc, + AsName: model.NewCIStr(fmt.Sprintf("sel_subq_agg_%d", len(sel.Fields.Fields))), + }) + } + return correlatedAggMap, nil +} + +// gbyResolver resolves group by items from select fields. +type gbyResolver struct { + ctx base.PlanContext + fields []*ast.SelectField + schema *expression.Schema + names []*types.FieldName + err error + inExpr bool + isParam bool + skipAggMap map[*ast.AggregateFuncExpr]*expression.CorrelatedColumn + + exprDepth int // exprDepth is the depth of current expression in expression tree. +} + +func (g *gbyResolver) Enter(inNode ast.Node) (ast.Node, bool) { + g.exprDepth++ + switch n := inNode.(type) { + case *ast.SubqueryExpr, *ast.CompareSubqueryExpr, *ast.ExistsSubqueryExpr: + return inNode, true + case *driver.ParamMarkerExpr: + g.isParam = true + if g.exprDepth == 1 && !n.UseAsValueInGbyByClause { + _, isNull, isExpectedType := getUintFromNode(g.ctx, n, false) + // For constant uint expression in top level, it should be treated as position expression. 
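+			// e.g. with a prepared statement `... GROUP BY ?` and the parameter bound to 2, the marker is
+			// rewritten into a position expression that refers to the 2nd select field.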
+ if !isNull && isExpectedType { + return expression.ConstructPositionExpr(n), true + } + } + return n, true + case *driver.ValueExpr, *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.ColumnName: + default: + g.inExpr = true + } + return inNode, false +} + +func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) { + extractor := &AggregateFuncExtractor{skipAggMap: g.skipAggMap} + switch v := inNode.(type) { + case *ast.ColumnNameExpr: + idx, err := expression.FindFieldName(g.names, v.Name) + if idx < 0 || !g.inExpr { + var index int + index, g.err = resolveFromSelectFields(v, g.fields, false) + if g.err != nil { + g.err = plannererrors.ErrAmbiguous.GenWithStackByArgs(v.Name.Name.L, clauseMsg[groupByClause]) + return inNode, false + } + if idx >= 0 { + return inNode, true + } + if index != -1 { + ret := g.fields[index].Expr + ret.Accept(extractor) + if len(extractor.AggFuncs) != 0 { + err = plannererrors.ErrIllegalReference.GenWithStackByArgs(v.Name.OrigColName(), "reference to group function") + } else if ast.HasWindowFlag(ret) { + err = plannererrors.ErrIllegalReference.GenWithStackByArgs(v.Name.OrigColName(), "reference to window function") + } else { + if isParam, ok := ret.(*driver.ParamMarkerExpr); ok { + isParam.UseAsValueInGbyByClause = true + } + return ret, true + } + } + g.err = err + return inNode, false + } + case *ast.PositionExpr: + pos, isNull, err := expression.PosFromPositionExpr(g.ctx.GetExprCtx(), g.ctx, v) + if err != nil { + g.err = plannererrors.ErrUnknown.GenWithStackByArgs() + } + if err != nil || isNull { + return inNode, false + } + if pos < 1 || pos > len(g.fields) { + g.err = errors.Errorf("Unknown column '%d' in 'group statement'", pos) + return inNode, false + } + ret := g.fields[pos-1].Expr + ret.Accept(extractor) + if len(extractor.AggFuncs) != 0 || ast.HasWindowFlag(ret) { + fieldName := g.fields[pos-1].AsName.String() + if fieldName == "" { + fieldName = g.fields[pos-1].Text() + } + g.err = plannererrors.ErrWrongGroupField.GenWithStackByArgs(fieldName) + return inNode, false + } + return ret, true + case *ast.ValuesExpr: + if v.Column == nil { + g.err = plannererrors.ErrUnknownColumn.GenWithStackByArgs("", "VALUES() function") + } + } + return inNode, true +} + +func tblInfoFromCol(from ast.ResultSetNode, name *types.FieldName) *model.TableInfo { + tableList := ExtractTableList(from, true) + for _, field := range tableList { + if field.Name.L == name.TblName.L { + return field.TableInfo + } + } + return nil +} + +func buildFuncDependCol(p base.LogicalPlan, cond ast.ExprNode) (*types.FieldName, *types.FieldName, error) { + binOpExpr, ok := cond.(*ast.BinaryOperationExpr) + if !ok { + return nil, nil, nil + } + if binOpExpr.Op != opcode.EQ { + return nil, nil, nil + } + lColExpr, ok := binOpExpr.L.(*ast.ColumnNameExpr) + if !ok { + return nil, nil, nil + } + rColExpr, ok := binOpExpr.R.(*ast.ColumnNameExpr) + if !ok { + return nil, nil, nil + } + lIdx, err := expression.FindFieldName(p.OutputNames(), lColExpr.Name) + if err != nil { + return nil, nil, err + } + rIdx, err := expression.FindFieldName(p.OutputNames(), rColExpr.Name) + if err != nil { + return nil, nil, err + } + if lIdx == -1 { + return nil, nil, plannererrors.ErrUnknownColumn.GenWithStackByArgs(lColExpr.Name, "where clause") + } + if rIdx == -1 { + return nil, nil, plannererrors.ErrUnknownColumn.GenWithStackByArgs(rColExpr.Name, "where clause") + } + return p.OutputNames()[lIdx], p.OutputNames()[rIdx], nil +} + +func buildWhereFuncDepend(p base.LogicalPlan, where ast.ExprNode) 
(map[*types.FieldName]*types.FieldName, error) { + whereConditions := splitWhere(where) + colDependMap := make(map[*types.FieldName]*types.FieldName, 2*len(whereConditions)) + for _, cond := range whereConditions { + lCol, rCol, err := buildFuncDependCol(p, cond) + if err != nil { + return nil, err + } + if lCol == nil || rCol == nil { + continue + } + colDependMap[lCol] = rCol + colDependMap[rCol] = lCol + } + return colDependMap, nil +} + +func buildJoinFuncDepend(p base.LogicalPlan, from ast.ResultSetNode) (map[*types.FieldName]*types.FieldName, error) { + switch x := from.(type) { + case *ast.Join: + if x.On == nil { + return nil, nil + } + onConditions := splitWhere(x.On.Expr) + colDependMap := make(map[*types.FieldName]*types.FieldName, len(onConditions)) + for _, cond := range onConditions { + lCol, rCol, err := buildFuncDependCol(p, cond) + if err != nil { + return nil, err + } + if lCol == nil || rCol == nil { + continue + } + lTbl := tblInfoFromCol(x.Left, lCol) + if lTbl == nil { + lCol, rCol = rCol, lCol + } + switch x.Tp { + case ast.CrossJoin: + colDependMap[lCol] = rCol + colDependMap[rCol] = lCol + case ast.LeftJoin: + colDependMap[rCol] = lCol + case ast.RightJoin: + colDependMap[lCol] = rCol + } + } + return colDependMap, nil + default: + return nil, nil + } +} + +func checkColFuncDepend( + p base.LogicalPlan, + name *types.FieldName, + tblInfo *model.TableInfo, + gbyOrSingleValueColNames map[*types.FieldName]struct{}, + whereDependNames, joinDependNames map[*types.FieldName]*types.FieldName, +) bool { + for _, index := range tblInfo.Indices { + if !index.Unique { + continue + } + funcDepend := true + // if all columns of some unique/pri indexes are determined, all columns left are check-passed. + for _, indexCol := range index.Columns { + iColInfo := tblInfo.Columns[indexCol.Offset] + if !mysql.HasNotNullFlag(iColInfo.GetFlag()) { + funcDepend = false + break + } + cn := &ast.ColumnName{ + Schema: name.DBName, + Table: name.TblName, + Name: iColInfo.Name, + } + iIdx, err := expression.FindFieldName(p.OutputNames(), cn) + if err != nil || iIdx < 0 { + funcDepend = false + break + } + iName := p.OutputNames()[iIdx] + if _, ok := gbyOrSingleValueColNames[iName]; ok { + continue + } + if wCol, ok := whereDependNames[iName]; ok { + if _, ok = gbyOrSingleValueColNames[wCol]; ok { + continue + } + } + if jCol, ok := joinDependNames[iName]; ok { + if _, ok = gbyOrSingleValueColNames[jCol]; ok { + continue + } + } + funcDepend = false + break + } + if funcDepend { + return true + } + } + primaryFuncDepend := true + hasPrimaryField := false + for _, colInfo := range tblInfo.Columns { + if !mysql.HasPriKeyFlag(colInfo.GetFlag()) { + continue + } + hasPrimaryField = true + pkName := &ast.ColumnName{ + Schema: name.DBName, + Table: name.TblName, + Name: colInfo.Name, + } + pIdx, err := expression.FindFieldName(p.OutputNames(), pkName) + // It is possible that `pIdx < 0` and here is a case. + // ``` + // CREATE TABLE `BB` ( + // `pk` int(11) NOT NULL AUTO_INCREMENT, + // `col_int_not_null` int NOT NULL, + // PRIMARY KEY (`pk`) + // ); + // + // SELECT OUTR . col2 AS X + // FROM + // BB AS OUTR2 + // INNER JOIN + // (SELECT col_int_not_null AS col1, + // pk AS col2 + // FROM BB) AS OUTR ON OUTR2.col_int_not_null = OUTR.col1 + // GROUP BY OUTR2.col_int_not_null; + // ``` + // When we enter `checkColFuncDepend`, `pkName.Table` is `OUTR` which is an alias, while `pkName.Name` is `pk` + // which is a original name. Hence `expression.FindFieldName` will fail and `pIdx` will be less than 0. 
+ // Currently, when we meet `pIdx < 0`, we directly regard `primaryFuncDepend` as false and jump out. This way is + // easy to implement but makes only-full-group-by checker not smart enough. Later we will refactor only-full-group-by + // checker and resolve the inconsistency between the alias table name and the original column name. + if err != nil || pIdx < 0 { + primaryFuncDepend = false + break + } + pCol := p.OutputNames()[pIdx] + if _, ok := gbyOrSingleValueColNames[pCol]; ok { + continue + } + if wCol, ok := whereDependNames[pCol]; ok { + if _, ok = gbyOrSingleValueColNames[wCol]; ok { + continue + } + } + if jCol, ok := joinDependNames[pCol]; ok { + if _, ok = gbyOrSingleValueColNames[jCol]; ok { + continue + } + } + primaryFuncDepend = false + break + } + return primaryFuncDepend && hasPrimaryField +} + +// ErrExprLoc is for generate the ErrFieldNotInGroupBy error info +type ErrExprLoc struct { + Offset int + Loc string +} + +func checkExprInGroupByOrIsSingleValue( + p base.LogicalPlan, + expr ast.ExprNode, + offset int, + loc string, + gbyOrSingleValueColNames map[*types.FieldName]struct{}, + gbyExprs []ast.ExprNode, + notInGbyOrSingleValueColNames map[*types.FieldName]ErrExprLoc, +) { + if _, ok := expr.(*ast.AggregateFuncExpr); ok { + return + } + if f, ok := expr.(*ast.FuncCallExpr); ok { + if f.FnName.L == ast.Grouping { + // just skip grouping function check here, because later in building plan phase, we + // will do the grouping function valid check. + return + } + } + if _, ok := expr.(*ast.ColumnNameExpr); !ok { + for _, gbyExpr := range gbyExprs { + if ast.ExpressionDeepEqual(gbyExpr, expr) { + return + } + } + } + // Function `any_value` can be used in aggregation, even `ONLY_FULL_GROUP_BY` is set. + // See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_any-value for details + if f, ok := expr.(*ast.FuncCallExpr); ok { + if f.FnName.L == ast.AnyValue { + return + } + } + colMap := make(map[*types.FieldName]struct{}, len(p.Schema().Columns)) + allColFromExprNode(p, expr, colMap) + for col := range colMap { + if _, ok := gbyOrSingleValueColNames[col]; !ok { + notInGbyOrSingleValueColNames[col] = ErrExprLoc{Offset: offset, Loc: loc} + } + } +} + +func (b *PlanBuilder) checkOnlyFullGroupBy(p base.LogicalPlan, sel *ast.SelectStmt) (err error) { + if sel.GroupBy != nil { + err = b.checkOnlyFullGroupByWithGroupClause(p, sel) + } else { + err = b.checkOnlyFullGroupByWithOutGroupClause(p, sel) + } + return err +} + +func addGbyOrSingleValueColName(p base.LogicalPlan, colName *ast.ColumnName, gbyOrSingleValueColNames map[*types.FieldName]struct{}) { + idx, err := expression.FindFieldName(p.OutputNames(), colName) + if err != nil || idx < 0 { + return + } + gbyOrSingleValueColNames[p.OutputNames()[idx]] = struct{}{} +} + +func extractSingeValueColNamesFromWhere(p base.LogicalPlan, where ast.ExprNode, gbyOrSingleValueColNames map[*types.FieldName]struct{}) { + whereConditions := splitWhere(where) + for _, cond := range whereConditions { + binOpExpr, ok := cond.(*ast.BinaryOperationExpr) + if !ok || binOpExpr.Op != opcode.EQ { + continue + } + if colExpr, ok := binOpExpr.L.(*ast.ColumnNameExpr); ok { + if _, ok := binOpExpr.R.(ast.ValueExpr); ok { + addGbyOrSingleValueColName(p, colExpr.Name, gbyOrSingleValueColNames) + } + } else if colExpr, ok := binOpExpr.R.(*ast.ColumnNameExpr); ok { + if _, ok := binOpExpr.L.(ast.ValueExpr); ok { + addGbyOrSingleValueColName(p, colExpr.Name, gbyOrSingleValueColNames) + } + } + } +} + +func (*PlanBuilder) 
checkOnlyFullGroupByWithGroupClause(p base.LogicalPlan, sel *ast.SelectStmt) error { + gbyOrSingleValueColNames := make(map[*types.FieldName]struct{}, len(sel.Fields.Fields)) + gbyExprs := make([]ast.ExprNode, 0, len(sel.Fields.Fields)) + for _, byItem := range sel.GroupBy.Items { + expr := getInnerFromParenthesesAndUnaryPlus(byItem.Expr) + if colExpr, ok := expr.(*ast.ColumnNameExpr); ok { + addGbyOrSingleValueColName(p, colExpr.Name, gbyOrSingleValueColNames) + } else { + gbyExprs = append(gbyExprs, expr) + } + } + // MySQL permits a nonaggregate column not named in a GROUP BY clause when ONLY_FULL_GROUP_BY SQL mode is enabled, + // provided that this column is limited to a single value. + // See https://dev.mysql.com/doc/refman/5.7/en/group-by-handling.html for details. + extractSingeValueColNamesFromWhere(p, sel.Where, gbyOrSingleValueColNames) + + notInGbyOrSingleValueColNames := make(map[*types.FieldName]ErrExprLoc, len(sel.Fields.Fields)) + for offset, field := range sel.Fields.Fields { + if field.Auxiliary { + continue + } + checkExprInGroupByOrIsSingleValue(p, getInnerFromParenthesesAndUnaryPlus(field.Expr), offset, ErrExprInSelect, gbyOrSingleValueColNames, gbyExprs, notInGbyOrSingleValueColNames) + } + + if sel.OrderBy != nil { + for offset, item := range sel.OrderBy.Items { + if colName, ok := item.Expr.(*ast.ColumnNameExpr); ok { + index, err := resolveFromSelectFields(colName, sel.Fields.Fields, false) + if err != nil { + return err + } + // If the ByItem is in fields list, it has been checked already in above. + if index >= 0 { + continue + } + } + checkExprInGroupByOrIsSingleValue(p, getInnerFromParenthesesAndUnaryPlus(item.Expr), offset, ErrExprInOrderBy, gbyOrSingleValueColNames, gbyExprs, notInGbyOrSingleValueColNames) + } + } + if len(notInGbyOrSingleValueColNames) == 0 { + return nil + } + + whereDepends, err := buildWhereFuncDepend(p, sel.Where) + if err != nil { + return err + } + joinDepends, err := buildJoinFuncDepend(p, sel.From.TableRefs) + if err != nil { + return err + } + tblMap := make(map[*model.TableInfo]struct{}, len(notInGbyOrSingleValueColNames)) + for name, errExprLoc := range notInGbyOrSingleValueColNames { + tblInfo := tblInfoFromCol(sel.From.TableRefs, name) + if tblInfo == nil { + continue + } + if _, ok := tblMap[tblInfo]; ok { + continue + } + if checkColFuncDepend(p, name, tblInfo, gbyOrSingleValueColNames, whereDepends, joinDepends) { + tblMap[tblInfo] = struct{}{} + continue + } + switch errExprLoc.Loc { + case ErrExprInSelect: + if sel.GroupBy.Rollup { + return plannererrors.ErrFieldInGroupingNotGroupBy.GenWithStackByArgs(strconv.Itoa(errExprLoc.Offset + 1)) + } + return plannererrors.ErrFieldNotInGroupBy.GenWithStackByArgs(errExprLoc.Offset+1, errExprLoc.Loc, name.DBName.O+"."+name.TblName.O+"."+name.OrigColName.O) + case ErrExprInOrderBy: + return plannererrors.ErrFieldNotInGroupBy.GenWithStackByArgs(errExprLoc.Offset+1, errExprLoc.Loc, sel.OrderBy.Items[errExprLoc.Offset].Expr.Text()) + } + return nil + } + return nil +} + +func (*PlanBuilder) checkOnlyFullGroupByWithOutGroupClause(p base.LogicalPlan, sel *ast.SelectStmt) error { + resolver := colResolverForOnlyFullGroupBy{ + firstOrderByAggColIdx: -1, + } + resolver.curClause = fieldList + for idx, field := range sel.Fields.Fields { + resolver.exprIdx = idx + field.Accept(&resolver) + } + if len(resolver.nonAggCols) > 0 { + if sel.Having != nil { + sel.Having.Expr.Accept(&resolver) + } + if sel.OrderBy != nil { + resolver.curClause = orderByClause + for idx, byItem := range 
sel.OrderBy.Items { + resolver.exprIdx = idx + byItem.Expr.Accept(&resolver) + } + } + } + if resolver.firstOrderByAggColIdx != -1 && len(resolver.nonAggCols) > 0 { + // SQL like `select a from t where a = 1 order by count(b)` is illegal. + return plannererrors.ErrAggregateOrderNonAggQuery.GenWithStackByArgs(resolver.firstOrderByAggColIdx + 1) + } + if !resolver.hasAggFuncOrAnyValue || len(resolver.nonAggCols) == 0 { + return nil + } + singleValueColNames := make(map[*types.FieldName]struct{}, len(sel.Fields.Fields)) + extractSingeValueColNamesFromWhere(p, sel.Where, singleValueColNames) + whereDepends, err := buildWhereFuncDepend(p, sel.Where) + if err != nil { + return err + } + + joinDepends, err := buildJoinFuncDepend(p, sel.From.TableRefs) + if err != nil { + return err + } + tblMap := make(map[*model.TableInfo]struct{}, len(resolver.nonAggCols)) + for i, colName := range resolver.nonAggCols { + idx, err := expression.FindFieldName(p.OutputNames(), colName) + if err != nil || idx < 0 { + return plannererrors.ErrMixOfGroupFuncAndFields.GenWithStackByArgs(resolver.nonAggColIdxs[i]+1, colName.Name.O) + } + fieldName := p.OutputNames()[idx] + if _, ok := singleValueColNames[fieldName]; ok { + continue + } + tblInfo := tblInfoFromCol(sel.From.TableRefs, fieldName) + if tblInfo == nil { + continue + } + if _, ok := tblMap[tblInfo]; ok { + continue + } + if checkColFuncDepend(p, fieldName, tblInfo, singleValueColNames, whereDepends, joinDepends) { + tblMap[tblInfo] = struct{}{} + continue + } + return plannererrors.ErrMixOfGroupFuncAndFields.GenWithStackByArgs(resolver.nonAggColIdxs[i]+1, colName.Name.O) + } + return nil +} + +// colResolverForOnlyFullGroupBy visits Expr tree to find out if an Expr tree is an aggregation function. +// If so, find out the first column name that not in an aggregation function. 
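+// Column names seen outside any aggregate are recorded in nonAggCols, together with the index of the
+// select field or ORDER BY item they came from (nonAggColIdxs).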
+type colResolverForOnlyFullGroupBy struct { + nonAggCols []*ast.ColumnName + exprIdx int + nonAggColIdxs []int + hasAggFuncOrAnyValue bool + firstOrderByAggColIdx int + curClause clauseCode +} + +func (c *colResolverForOnlyFullGroupBy) Enter(node ast.Node) (ast.Node, bool) { + switch t := node.(type) { + case *ast.AggregateFuncExpr: + c.hasAggFuncOrAnyValue = true + if c.curClause == orderByClause { + c.firstOrderByAggColIdx = c.exprIdx + } + return node, true + case *ast.FuncCallExpr: + // enable function `any_value` in aggregation even `ONLY_FULL_GROUP_BY` is set + if t.FnName.L == ast.AnyValue { + c.hasAggFuncOrAnyValue = true + return node, true + } + case *ast.ColumnNameExpr: + c.nonAggCols = append(c.nonAggCols, t.Name) + c.nonAggColIdxs = append(c.nonAggColIdxs, c.exprIdx) + return node, true + case *ast.SubqueryExpr: + return node, true + } + return node, false +} + +func (*colResolverForOnlyFullGroupBy) Leave(node ast.Node) (ast.Node, bool) { + return node, true +} + +type aggColNameResolver struct { + colNameResolver +} + +func (*aggColNameResolver) Enter(inNode ast.Node) (ast.Node, bool) { + if _, ok := inNode.(*ast.ColumnNameExpr); ok { + return inNode, true + } + return inNode, false +} + +func allColFromAggExprNode(p base.LogicalPlan, n ast.Node, names map[*types.FieldName]struct{}) { + extractor := &aggColNameResolver{ + colNameResolver: colNameResolver{ + p: p, + names: names, + }, + } + n.Accept(extractor) +} + +type colNameResolver struct { + p base.LogicalPlan + names map[*types.FieldName]struct{} +} + +func (*colNameResolver) Enter(inNode ast.Node) (ast.Node, bool) { + switch inNode.(type) { + case *ast.ColumnNameExpr, *ast.SubqueryExpr, *ast.AggregateFuncExpr: + return inNode, true + } + return inNode, false +} + +func (c *colNameResolver) Leave(inNode ast.Node) (ast.Node, bool) { + if v, ok := inNode.(*ast.ColumnNameExpr); ok { + idx, err := expression.FindFieldName(c.p.OutputNames(), v.Name) + if err == nil && idx >= 0 { + c.names[c.p.OutputNames()[idx]] = struct{}{} + } + } + return inNode, true +} + +func allColFromExprNode(p base.LogicalPlan, n ast.Node, names map[*types.FieldName]struct{}) { + extractor := &colNameResolver{ + p: p, + names: names, + } + n.Accept(extractor) +} + +func (b *PlanBuilder) resolveGbyExprs(ctx context.Context, p base.LogicalPlan, gby *ast.GroupByClause, fields []*ast.SelectField) (base.LogicalPlan, []expression.Expression, bool, error) { + b.curClause = groupByClause + exprs := make([]expression.Expression, 0, len(gby.Items)) + resolver := &gbyResolver{ + ctx: b.ctx, + fields: fields, + schema: p.Schema(), + names: p.OutputNames(), + skipAggMap: b.correlatedAggMapper, + } + for _, item := range gby.Items { + resolver.inExpr = false + resolver.exprDepth = 0 + resolver.isParam = false + retExpr, _ := item.Expr.Accept(resolver) + if resolver.err != nil { + return nil, nil, false, errors.Trace(resolver.err) + } + if !resolver.isParam { + item.Expr = retExpr.(ast.ExprNode) + } + + itemExpr := retExpr.(ast.ExprNode) + expr, np, err := b.rewrite(ctx, itemExpr, p, nil, true) + if err != nil { + return nil, nil, false, err + } + + exprs = append(exprs, expr) + p = np + } + return p, exprs, gby.Rollup, nil +} + +func (*PlanBuilder) unfoldWildStar(p base.LogicalPlan, selectFields []*ast.SelectField) (resultList []*ast.SelectField, err error) { + join, isJoin := p.(*LogicalJoin) + for i, field := range selectFields { + if field.WildCard == nil { + resultList = append(resultList, field) + continue + } + if field.WildCard.Table.L == "" && i > 0 { + 
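+			// An unqualified `*` is only legal as the first item of the select list:
+			// `SELECT a, * FROM t` is rejected here, while `SELECT a, t.* FROM t` is allowed.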
return nil, plannererrors.ErrInvalidWildCard + } + list := unfoldWildStar(field, p.OutputNames(), p.Schema().Columns) + // For sql like `select t1.*, t2.* from t1 join t2 using(a)` or `select t1.*, t2.* from t1 natual join t2`, + // the schema of the Join doesn't contain enough columns because the join keys are coalesced in this schema. + // We should collect the columns from the FullSchema. + if isJoin && join.FullSchema != nil && field.WildCard.Table.L != "" { + list = unfoldWildStar(field, join.FullNames, join.FullSchema.Columns) + } + if len(list) == 0 { + return nil, plannererrors.ErrBadTable.GenWithStackByArgs(field.WildCard.Table) + } + resultList = append(resultList, list...) + } + return resultList, nil +} + +func unfoldWildStar(field *ast.SelectField, outputName types.NameSlice, column []*expression.Column) (resultList []*ast.SelectField) { + dbName := field.WildCard.Schema + tblName := field.WildCard.Table + for i, name := range outputName { + col := column[i] + if col.IsHidden { + continue + } + if (dbName.L == "" || dbName.L == name.DBName.L) && + (tblName.L == "" || tblName.L == name.TblName.L) && + col.ID != model.ExtraHandleID && col.ID != model.ExtraPhysTblID { + colName := &ast.ColumnNameExpr{ + Name: &ast.ColumnName{ + Schema: name.DBName, + Table: name.TblName, + Name: name.ColName, + }} + colName.SetType(col.GetStaticType()) + field := &ast.SelectField{Expr: colName} + field.SetText(nil, name.ColName.O) + resultList = append(resultList, field) + } + } + return resultList +} + +func (b *PlanBuilder) addAliasName(ctx context.Context, selectStmt *ast.SelectStmt, p base.LogicalPlan) (resultList []*ast.SelectField, err error) { + selectFields := selectStmt.Fields.Fields + projOutNames := make([]*types.FieldName, 0, len(selectFields)) + for _, field := range selectFields { + colNameField, isColumnNameExpr := field.Expr.(*ast.ColumnNameExpr) + if isColumnNameExpr { + colName := colNameField.Name.Name + if field.AsName.L != "" { + colName = field.AsName + } + projOutNames = append(projOutNames, &types.FieldName{ + TblName: colNameField.Name.Table, + OrigTblName: colNameField.Name.Table, + ColName: colName, + OrigColName: colNameField.Name.Name, + DBName: colNameField.Name.Schema, + }) + } else { + // create view v as select name_const('col', 100); + // The column in v should be 'col', so we call `buildProjectionField` to handle this. 
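+			// Only the derived output name is needed here; the built expression itself is discarded.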
+ _, name, err := b.buildProjectionField(ctx, p, field, nil) + if err != nil { + return nil, err + } + projOutNames = append(projOutNames, name) + } + } + + // dedupMap is used for renaming a duplicated anonymous column + dedupMap := make(map[string]int) + anonymousFields := make([]bool, len(selectFields)) + + for i, field := range selectFields { + newField := *field + if newField.AsName.L == "" { + newField.AsName = projOutNames[i].ColName + } + + if _, ok := field.Expr.(*ast.ColumnNameExpr); !ok && field.AsName.L == "" { + anonymousFields[i] = true + } else { + anonymousFields[i] = false + // dedupMap should be inited with all non-anonymous fields before renaming other duplicated anonymous fields + dedupMap[newField.AsName.L] = 0 + } + + resultList = append(resultList, &newField) + } + + // We should rename duplicated anonymous fields in the first SelectStmt of CreateViewStmt + // See: https://github.com/pingcap/tidb/issues/29326 + if selectStmt.AsViewSchema { + for i, field := range resultList { + if !anonymousFields[i] { + continue + } + + oldName := field.AsName + if dup, ok := dedupMap[field.AsName.L]; ok { + if dup == 0 { + field.AsName = model.NewCIStr(fmt.Sprintf("Name_exp_%s", field.AsName.O)) + } else { + field.AsName = model.NewCIStr(fmt.Sprintf("Name_exp_%d_%s", dup, field.AsName.O)) + } + dedupMap[oldName.L] = dup + 1 + } else { + dedupMap[oldName.L] = 0 + } + } + } + + return resultList, nil +} + +func (b *PlanBuilder) pushHintWithoutTableWarning(hint *ast.TableOptimizerHint) { + var sb strings.Builder + ctx := format.NewRestoreCtx(0, &sb) + if err := hint.Restore(ctx); err != nil { + return + } + b.ctx.GetSessionVars().StmtCtx.SetHintWarning( + fmt.Sprintf("Hint %s is inapplicable. Please specify the table names in the arguments.", sb.String())) +} + +func (b *PlanBuilder) pushTableHints(hints []*ast.TableOptimizerHint, currentLevel int) { + hints = b.hintProcessor.GetCurrentStmtHints(hints, currentLevel) + currentDB := b.ctx.GetSessionVars().CurrentDB + warnHandler := b.ctx.GetSessionVars().StmtCtx + planHints, subQueryHintFlags, err := h.ParsePlanHints(hints, currentLevel, currentDB, + b.hintProcessor, b.ctx.GetSessionVars().StmtCtx.StraightJoinOrder, + b.subQueryCtx == handlingExistsSubquery, b.subQueryCtx == notHandlingSubquery, warnHandler) + if err != nil { + return + } + b.tableHintInfo = append(b.tableHintInfo, planHints) + b.subQueryHintFlags |= subQueryHintFlags +} + +func (b *PlanBuilder) popVisitInfo() { + if len(b.visitInfo) == 0 { + return + } + b.visitInfo = b.visitInfo[:len(b.visitInfo)-1] +} + +func (b *PlanBuilder) popTableHints() { + hintInfo := b.tableHintInfo[len(b.tableHintInfo)-1] + for _, warning := range h.CollectUnmatchedHintWarnings(hintInfo) { + b.ctx.GetSessionVars().StmtCtx.SetHintWarning(warning) + } + b.tableHintInfo = b.tableHintInfo[:len(b.tableHintInfo)-1] +} + +// TableHints returns the *TableHintInfo of PlanBuilder. +func (b *PlanBuilder) TableHints() *h.PlanHints { + if len(b.tableHintInfo) == 0 { + return nil + } + return b.tableHintInfo[len(b.tableHintInfo)-1] +} + +func (b *PlanBuilder) buildSelect(ctx context.Context, sel *ast.SelectStmt) (p base.LogicalPlan, err error) { + b.pushSelectOffset(sel.QueryBlockOffset) + b.pushTableHints(sel.TableHints, sel.QueryBlockOffset) + defer func() { + b.popSelectOffset() + // table hints are only visible in the current SELECT statement. 
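+		// popTableHints also reports a warning for every hint in this scope that never matched anything.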
+ b.popTableHints() + }() + if b.buildingRecursivePartForCTE { + if sel.Distinct || sel.OrderBy != nil || sel.Limit != nil { + return nil, plannererrors.ErrNotSupportedYet.GenWithStackByArgs("ORDER BY / LIMIT / SELECT DISTINCT in recursive query block of Common Table Expression") + } + if sel.GroupBy != nil { + return nil, plannererrors.ErrCTERecursiveForbidsAggregation.FastGenByArgs(b.genCTETableNameForError()) + } + } + if sel.SelectStmtOpts != nil { + origin := b.inStraightJoin + b.inStraightJoin = sel.SelectStmtOpts.StraightJoin + defer func() { b.inStraightJoin = origin }() + } + + var ( + aggFuncs []*ast.AggregateFuncExpr + havingMap, orderMap, totalMap map[*ast.AggregateFuncExpr]int + windowAggMap map[*ast.AggregateFuncExpr]int + correlatedAggMap map[*ast.AggregateFuncExpr]int + gbyCols []expression.Expression + projExprs []expression.Expression + rollup bool + ) + + // set for update read to true before building result set node + if isForUpdateReadSelectLock(sel.LockInfo) { + b.isForUpdateRead = true + } + + if hints := b.TableHints(); hints != nil && hints.CTEMerge { + // Verify Merge hints in the current query, + // we will update parameters for those that meet the rules, and warn those that do not. + // If the current query uses Merge Hint and the query is a CTE, + // we update the HINT information for the current query. + // If the current query is not a CTE query (it may be a subquery within a CTE query + // or an external non-CTE query), we will give a warning. + // In particular, recursive CTE have separate warnings, so they are no longer called. + if b.buildingCTE { + if b.isCTE { + b.outerCTEs[len(b.outerCTEs)-1].forceInlineByHintOrVar = true + } else if !b.buildingRecursivePartForCTE { + // If there has subquery which is not CTE and using `MERGE()` hint, we will show this warning; + b.ctx.GetSessionVars().StmtCtx.SetHintWarning( + "Hint merge() is inapplicable. " + + "Please check whether the hint is used in the right place, " + + "you should use this hint inside the CTE.") + } + } else if !b.buildingCTE && !b.isCTE { + b.ctx.GetSessionVars().StmtCtx.SetHintWarning( + "Hint merge() is inapplicable. " + + "Please check whether the hint is used in the right place, " + + "you should use this hint inside the CTE.") + } + } + + var currentLayerCTEs []*cteInfo + if sel.With != nil { + l := len(b.outerCTEs) + defer func() { + b.outerCTEs = b.outerCTEs[:l] + }() + currentLayerCTEs, err = b.buildWith(ctx, sel.With) + if err != nil { + return nil, err + } + } + + p, err = b.buildTableRefs(ctx, sel.From) + if err != nil { + return nil, err + } + + originalFields := sel.Fields.Fields + sel.Fields.Fields, err = b.unfoldWildStar(p, sel.Fields.Fields) + if err != nil { + return nil, err + } + if b.capFlag&canExpandAST != 0 { + // To be compatible with MySQL, we add alias name for each select field when creating view. + sel.Fields.Fields, err = b.addAliasName(ctx, sel, p) + if err != nil { + return nil, err + } + originalFields = sel.Fields.Fields + } + + if sel.GroupBy != nil { + p, gbyCols, rollup, err = b.resolveGbyExprs(ctx, p, sel.GroupBy, sel.Fields.Fields) + if err != nil { + return nil, err + } + } + + if b.ctx.GetSessionVars().SQLMode.HasOnlyFullGroupBy() && sel.From != nil && !b.ctx.GetSessionVars().OptimizerEnableNewOnlyFullGroupByCheck { + err = b.checkOnlyFullGroupBy(p, sel) + if err != nil { + return nil, err + } + } + + hasWindowFuncField := b.detectSelectWindow(sel) + // Some SQL statements define WINDOW but do not use them. 
But we also need to check the window specification list. + // For example: select id from t group by id WINDOW w AS (ORDER BY uids DESC) ORDER BY id; + // We don't use the WINDOW w, but if the 'uids' column is not in the table t, we still need to report an error. + if hasWindowFuncField || sel.WindowSpecs != nil { + if b.buildingRecursivePartForCTE { + return nil, plannererrors.ErrCTERecursiveForbidsAggregation.FastGenByArgs(b.genCTETableNameForError()) + } + + windowAggMap, err = b.resolveWindowFunction(sel, p) + if err != nil { + return nil, err + } + } + // We must resolve having and order by clause before build projection, + // because when the query is "select a+1 as b from t having sum(b) < 0", we must replace sum(b) to sum(a+1), + // which only can be done before building projection and extracting Agg functions. + havingMap, orderMap, err = b.resolveHavingAndOrderBy(ctx, sel, p) + if err != nil { + return nil, err + } + + // We have to resolve correlated aggregate inside sub-queries before building aggregation and building projection, + // for instance, count(a) inside the sub-query of "select (select count(a)) from t" should be evaluated within + // the context of the outer query. So we have to extract such aggregates from sub-queries and put them into + // SELECT field list. + correlatedAggMap, err = b.resolveCorrelatedAggregates(ctx, sel, p) + if err != nil { + return nil, err + } + + // b.allNames will be used in evalDefaultExpr(). Default function is special because it needs to find the + // corresponding column name, but does not need the value in the column. + // For example, select a from t order by default(b), the column b will not be in select fields. Also because + // buildSort is after buildProjection, so we need get OutputNames before BuildProjection and store in allNames. + // Otherwise, we will get select fields instead of all OutputNames, so that we can't find the column b in the + // above example. + b.allNames = append(b.allNames, p.OutputNames()) + defer func() { b.allNames = b.allNames[:len(b.allNames)-1] }() + + if sel.Where != nil { + p, err = b.buildSelection(ctx, p, sel.Where, nil) + if err != nil { + return nil, err + } + } + l := sel.LockInfo + if l != nil && l.LockType != ast.SelectLockNone { + for _, tName := range l.Tables { + // CTE has no *model.HintedTable, we need to skip it. + if tName.TableInfo == nil { + continue + } + b.ctx.GetSessionVars().StmtCtx.LockTableIDs[tName.TableInfo.ID] = struct{}{} + } + p, err = b.buildSelectLock(p, l) + if err != nil { + return nil, err + } + } + b.handleHelper.popMap() + b.handleHelper.pushMap(nil) + + hasAgg := b.detectSelectAgg(sel) + needBuildAgg := hasAgg + if hasAgg { + if b.buildingRecursivePartForCTE { + return nil, plannererrors.ErrCTERecursiveForbidsAggregation.GenWithStackByArgs(b.genCTETableNameForError()) + } + + aggFuncs, totalMap = b.extractAggFuncsInSelectFields(sel.Fields.Fields) + // len(aggFuncs) == 0 and sel.GroupBy == nil indicates that all the aggregate functions inside the SELECT fields + // are actually correlated aggregates from the outer query, which have already been built in the outer query. + // The only thing we need to do is to find them from b.correlatedAggMap in buildProjection. + if len(aggFuncs) == 0 && sel.GroupBy == nil { + needBuildAgg = false + } + } + if needBuildAgg { + // if rollup syntax is specified, Expand OP is required to replicate the data to feed different grouping layout. 
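+		// e.g. `GROUP BY a, b WITH ROLLUP` needs the grouping layouts (a, b), (a) and (), so Expand
+		// replicates each input row once per layout before the aggregation is built.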
+ if rollup { + p, gbyCols, err = b.buildExpand(p, gbyCols) + if err != nil { + return nil, err + } + } + var aggIndexMap map[int]int + p, aggIndexMap, err = b.buildAggregation(ctx, p, aggFuncs, gbyCols, correlatedAggMap) + if err != nil { + return nil, err + } + for agg, idx := range totalMap { + totalMap[agg] = aggIndexMap[idx] + } + } + + var oldLen int + // According to https://dev.mysql.com/doc/refman/8.0/en/window-functions-usage.html, + // we can only process window functions after having clause, so `considerWindow` is false now. + p, projExprs, oldLen, err = b.buildProjection(ctx, p, sel.Fields.Fields, totalMap, nil, false, sel.OrderBy != nil) + if err != nil { + return nil, err + } + + if sel.Having != nil { + b.curClause = havingClause + p, err = b.buildSelection(ctx, p, sel.Having.Expr, havingMap) + if err != nil { + return nil, err + } + } + + b.windowSpecs, err = buildWindowSpecs(sel.WindowSpecs) + if err != nil { + return nil, err + } + + var windowMapper map[*ast.WindowFuncExpr]int + if hasWindowFuncField || sel.WindowSpecs != nil { + windowFuncs := extractWindowFuncs(sel.Fields.Fields) + // we need to check the func args first before we check the window spec + err := b.checkWindowFuncArgs(ctx, p, windowFuncs, windowAggMap) + if err != nil { + return nil, err + } + groupedFuncs, orderedSpec, err := b.groupWindowFuncs(windowFuncs) + if err != nil { + return nil, err + } + p, windowMapper, err = b.buildWindowFunctions(ctx, p, groupedFuncs, orderedSpec, windowAggMap) + if err != nil { + return nil, err + } + // `hasWindowFuncField == false` means there's only unused named window specs without window functions. + // In such case plan `p` is not changed, so we don't have to build another projection. + if hasWindowFuncField { + // Now we build the window function fields. + p, projExprs, oldLen, err = b.buildProjection(ctx, p, sel.Fields.Fields, windowAggMap, windowMapper, true, false) + if err != nil { + return nil, err + } + } + } + + if sel.Distinct { + p, err = b.buildDistinct(p, oldLen) + if err != nil { + return nil, err + } + } + + if sel.OrderBy != nil { + // We need to keep the ORDER BY clause for the following cases: + // 1. The select is top level query, order should be honored + // 2. The query has LIMIT clause + // 3. The control flag requires keeping ORDER BY explicitly + if len(b.qbOffset) == 1 || sel.Limit != nil || !b.ctx.GetSessionVars().RemoveOrderbyInSubquery { + if b.ctx.GetSessionVars().SQLMode.HasOnlyFullGroupBy() { + p, err = b.buildSortWithCheck(ctx, p, sel.OrderBy.Items, orderMap, windowMapper, projExprs, oldLen, sel.Distinct) + } else { + p, err = b.buildSort(ctx, p, sel.OrderBy.Items, orderMap, windowMapper) + } + if err != nil { + return nil, err + } + } + } + + if sel.Limit != nil { + p, err = b.buildLimit(p, sel.Limit) + if err != nil { + return nil, err + } + } + + sel.Fields.Fields = originalFields + if oldLen != p.Schema().Len() { + proj := logicalop.LogicalProjection{Exprs: expression.Column2Exprs(p.Schema().Columns[:oldLen])}.Init(b.ctx, b.getSelectOffset()) + proj.SetChildren(p) + schema := expression.NewSchema(p.Schema().Clone().Columns[:oldLen]...) 
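+		// Give the cloned columns fresh unique IDs for the schema of the trimming projection.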
+ for _, col := range schema.Columns { + col.UniqueID = b.ctx.GetSessionVars().AllocPlanColumnID() + } + proj.SetOutputNames(p.OutputNames()[:oldLen]) + proj.SetSchema(schema) + return b.tryToBuildSequence(currentLayerCTEs, proj), nil + } + + return b.tryToBuildSequence(currentLayerCTEs, p), nil +} + +func (b *PlanBuilder) tryToBuildSequence(ctes []*cteInfo, p base.LogicalPlan) base.LogicalPlan { + if !b.ctx.GetSessionVars().EnableMPPSharedCTEExecution { + return p + } + for i := len(ctes) - 1; i >= 0; i-- { + if !ctes[i].nonRecursive { + return p + } + if ctes[i].isInline || ctes[i].cteClass == nil { + ctes = append(ctes[:i], ctes[i+1:]...) + } + } + if len(ctes) == 0 { + return p + } + lctes := make([]base.LogicalPlan, 0, len(ctes)+1) + for _, cte := range ctes { + lcte := LogicalCTE{ + Cte: cte.cteClass, + CteAsName: cte.def.Name, + CteName: cte.def.Name, + SeedStat: cte.seedStat, + OnlyUsedAsStorage: true, + }.Init(b.ctx, b.getSelectOffset()) + lcte.SetSchema(getResultCTESchema(cte.seedLP.Schema(), b.ctx.GetSessionVars())) + lctes = append(lctes, lcte) + } + b.optFlag |= flagPushDownSequence + seq := logicalop.LogicalSequence{}.Init(b.ctx, b.getSelectOffset()) + seq.SetChildren(append(lctes, p)...) + seq.SetOutputNames(p.OutputNames().Shallow()) + return seq +} + +func (b *PlanBuilder) buildTableDual() *logicalop.LogicalTableDual { + b.handleHelper.pushMap(nil) + return logicalop.LogicalTableDual{RowCount: 1}.Init(b.ctx, b.getSelectOffset()) +} + +func (ds *DataSource) newExtraHandleSchemaCol() *expression.Column { + tp := types.NewFieldType(mysql.TypeLonglong) + tp.SetFlag(mysql.NotNullFlag | mysql.PriKeyFlag) + return &expression.Column{ + RetType: tp, + UniqueID: ds.SCtx().GetSessionVars().AllocPlanColumnID(), + ID: model.ExtraHandleID, + OrigName: fmt.Sprintf("%v.%v.%v", ds.DBName, ds.TableInfo.Name, model.ExtraHandleName), + } +} + +// AddExtraPhysTblIDColumn for partition table. +// 'select ... for update' on a partition table need to know the partition ID +// to construct the lock key, so this column is added to the chunk row. +// Also needed for checking against the sessions transaction buffer +func (ds *DataSource) AddExtraPhysTblIDColumn() *expression.Column { + // Avoid adding multiple times (should never happen!) + cols := ds.TblCols + for i := len(cols) - 1; i >= 0; i-- { + if cols[i].ID == model.ExtraPhysTblID { + return cols[i] + } + } + pidCol := &expression.Column{ + RetType: types.NewFieldType(mysql.TypeLonglong), + UniqueID: ds.SCtx().GetSessionVars().AllocPlanColumnID(), + ID: model.ExtraPhysTblID, + OrigName: fmt.Sprintf("%v.%v.%v", ds.DBName, ds.TableInfo.Name, model.ExtraPhysTblIdName), + } + + ds.Columns = append(ds.Columns, model.NewExtraPhysTblIDColInfo()) + schema := ds.Schema() + schema.Append(pidCol) + ds.SetOutputNames(append(ds.OutputNames(), &types.FieldName{ + DBName: ds.DBName, + TblName: ds.TableInfo.Name, + ColName: model.ExtraPhysTblIdName, + OrigColName: model.ExtraPhysTblIdName, + })) + ds.TblCols = append(ds.TblCols, pidCol) + return pidCol +} + +// getStatsTable gets statistics information for a table specified by "tableID". +// A pseudo statistics table is returned in any of the following scenario: +// 1. tidb-server started and statistics handle has not been initialized. +// 2. table row count from statistics is zero. +// 3. statistics is outdated. +// Note: please also update getLatestVersionFromStatsTable() when logic in this function changes. 
+func getStatsTable(ctx base.PlanContext, tblInfo *model.TableInfo, pid int64) *statistics.Table { + statsHandle := domain.GetDomain(ctx).StatsHandle() + var usePartitionStats, countIs0, pseudoStatsForUninitialized, pseudoStatsForOutdated bool + var statsTbl *statistics.Table + if ctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { + debugtrace.EnterContextCommon(ctx) + defer func() { + debugTraceGetStatsTbl(ctx, + tblInfo, + pid, + statsHandle == nil, + usePartitionStats, + countIs0, + pseudoStatsForUninitialized, + pseudoStatsForOutdated, + statsTbl, + ) + debugtrace.LeaveContextCommon(ctx) + }() + } + // 1. tidb-server started and statistics handle has not been initialized. + if statsHandle == nil { + return statistics.PseudoTable(tblInfo, false, true) + } + + if pid == tblInfo.ID || ctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { + statsTbl = statsHandle.GetTableStats(tblInfo) + } else { + usePartitionStats = true + statsTbl = statsHandle.GetPartitionStats(tblInfo, pid) + } + intest.Assert(statsTbl.ColAndIdxExistenceMap != nil, "The existence checking map must not be nil.") + + allowPseudoTblTriggerLoading := false + // In OptObjectiveDeterminate mode, we need to ignore the real-time stats. + // To achieve this, we copy the statsTbl and reset the real-time stats fields (set ModifyCount to 0 and set + // RealtimeCount to the row count from the ANALYZE, which is fetched from loaded stats in GetAnalyzeRowCount()). + if ctx.GetSessionVars().GetOptObjective() == variable.OptObjectiveDeterminate { + analyzeCount := max(int64(statsTbl.GetAnalyzeRowCount()), 0) + // If the two fields are already the values we want, we don't need to modify it, and also we don't need to copy. + if statsTbl.RealtimeCount != analyzeCount || statsTbl.ModifyCount != 0 { + // Here is a case that we need specially care about: + // The original stats table from the stats cache is not a pseudo table, but the analyze row count is 0 (probably + // because of no col/idx stats are loaded), which will makes it a pseudo table according to the rule 2 below. + // Normally, a pseudo table won't trigger stats loading since we assume it means "no stats available", but + // in such case, we need it able to trigger stats loading. + // That's why we use the special allowPseudoTblTriggerLoading flag here. + if !statsTbl.Pseudo && statsTbl.RealtimeCount > 0 && analyzeCount == 0 { + allowPseudoTblTriggerLoading = true + } + // Copy it so we can modify the ModifyCount and the RealtimeCount safely. + statsTbl = statsTbl.ShallowCopy() + statsTbl.RealtimeCount = analyzeCount + statsTbl.ModifyCount = 0 + } + } + + // 2. table row count from statistics is zero. + if statsTbl.RealtimeCount == 0 { + countIs0 = true + core_metrics.PseudoEstimationNotAvailable.Inc() + return statistics.PseudoTable(tblInfo, allowPseudoTblTriggerLoading, true) + } + + // 3. statistics is uninitialized or outdated. + pseudoStatsForUninitialized = !statsTbl.IsInitialized() + pseudoStatsForOutdated = ctx.GetSessionVars().GetEnablePseudoForOutdatedStats() && statsTbl.IsOutdated() + if pseudoStatsForUninitialized || pseudoStatsForOutdated { + tbl := *statsTbl + tbl.Pseudo = true + statsTbl = &tbl + if pseudoStatsForUninitialized { + core_metrics.PseudoEstimationNotAvailable.Inc() + } else { + core_metrics.PseudoEstimationOutdate.Inc() + } + } + + return statsTbl +} + +// getLatestVersionFromStatsTable gets statistics information for a table specified by "tableID", and get the max +// LastUpdateVersion among all Columns and Indices in it. 
+// Its overall logic is quite similar to getStatsTable(). During plan cache matching, only the latest version is needed. +// In such case, compared to getStatsTable(), this function can save some copies, memory allocations and unnecessary +// checks. Also, this function won't trigger metrics changes. +func getLatestVersionFromStatsTable(ctx sessionctx.Context, tblInfo *model.TableInfo, pid int64) (version uint64) { + statsHandle := domain.GetDomain(ctx).StatsHandle() + // 1. tidb-server started and statistics handle has not been initialized. Pseudo stats table. + if statsHandle == nil { + return 0 + } + + var statsTbl *statistics.Table + if pid == tblInfo.ID || ctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { + statsTbl = statsHandle.GetTableStats(tblInfo) + } else { + statsTbl = statsHandle.GetPartitionStats(tblInfo, pid) + } + + // 2. Table row count from statistics is zero. Pseudo stats table. + realtimeRowCount := statsTbl.RealtimeCount + if ctx.GetSessionVars().GetOptObjective() == variable.OptObjectiveDeterminate { + realtimeRowCount = max(int64(statsTbl.GetAnalyzeRowCount()), 0) + } + if realtimeRowCount == 0 { + return 0 + } + + // 3. Not pseudo stats table. Return the max LastUpdateVersion among all Columns and Indices + // return statsTbl.LastAnalyzeVersion + statsTbl.ForEachColumnImmutable(func(_ int64, col *statistics.Column) bool { + version = max(version, col.LastUpdateVersion) + return false + }) + statsTbl.ForEachIndexImmutable(func(_ int64, idx *statistics.Index) bool { + version = max(version, idx.LastUpdateVersion) + return false + }) + return version +} + +func (b *PlanBuilder) tryBuildCTE(ctx context.Context, tn *ast.TableName, asName *model.CIStr) (base.LogicalPlan, error) { + for i := len(b.outerCTEs) - 1; i >= 0; i-- { + cte := b.outerCTEs[i] + if cte.def.Name.L == tn.Name.L { + if cte.isBuilding { + if cte.nonRecursive { + // Can't see this CTE, try outer definition. + continue + } + + // Building the recursive part. + cte.useRecursive = true + if cte.seedLP == nil { + return nil, plannererrors.ErrCTERecursiveRequiresNonRecursiveFirst.FastGenByArgs(tn.Name.String()) + } + + if cte.enterSubquery || cte.recursiveRef { + return nil, plannererrors.ErrInvalidRequiresSingleReference.FastGenByArgs(tn.Name.String()) + } + + cte.recursiveRef = true + p := logicalop.LogicalCTETable{Name: cte.def.Name.String(), IDForStorage: cte.storageID, SeedStat: cte.seedStat, SeedSchema: cte.seedLP.Schema()}.Init(b.ctx, b.getSelectOffset()) + p.SetSchema(getResultCTESchema(cte.seedLP.Schema(), b.ctx.GetSessionVars())) + p.SetOutputNames(cte.seedLP.OutputNames()) + return p, nil + } + + b.handleHelper.pushMap(nil) + + hasLimit := false + limitBeg := uint64(0) + limitEnd := uint64(0) + if cte.limitLP != nil { + hasLimit = true + switch x := cte.limitLP.(type) { + case *logicalop.LogicalLimit: + limitBeg = x.Offset + limitEnd = x.Offset + x.Count + case *logicalop.LogicalTableDual: + // Beg and End will both be 0. 
+ default: + return nil, errors.Errorf("invalid type for limit plan: %v", cte.limitLP) + } + } + + if cte.cteClass == nil { + cte.cteClass = &CTEClass{ + IsDistinct: cte.isDistinct, + seedPartLogicalPlan: cte.seedLP, + recursivePartLogicalPlan: cte.recurLP, + IDForStorage: cte.storageID, + optFlag: cte.optFlag, + HasLimit: hasLimit, + LimitBeg: limitBeg, + LimitEnd: limitEnd, + pushDownPredicates: make([]expression.Expression, 0), + ColumnMap: make(map[string]*expression.Column), + } + } + var p base.LogicalPlan + lp := LogicalCTE{CteAsName: tn.Name, CteName: tn.Name, Cte: cte.cteClass, SeedStat: cte.seedStat}.Init(b.ctx, b.getSelectOffset()) + prevSchema := cte.seedLP.Schema().Clone() + lp.SetSchema(getResultCTESchema(cte.seedLP.Schema(), b.ctx.GetSessionVars())) + + // If current CTE query contain another CTE which 'containAggOrWindow' is true, current CTE 'containAggOrWindow' will be true + if b.buildingCTE { + b.outerCTEs[len(b.outerCTEs)-1].containAggOrWindow = cte.containAggOrWindow || b.outerCTEs[len(b.outerCTEs)-1].containAggOrWindow + } + // Compute cte inline + b.computeCTEInlineFlag(cte) + + if cte.recurLP == nil && cte.isInline { + saveCte := make([]*cteInfo, len(b.outerCTEs[i:])) + copy(saveCte, b.outerCTEs[i:]) + b.outerCTEs = b.outerCTEs[:i] + o := b.buildingCTE + b.buildingCTE = false + //nolint:all_revive,revive + defer func() { + b.outerCTEs = append(b.outerCTEs, saveCte...) + b.buildingCTE = o + }() + return b.buildDataSourceFromCTEMerge(ctx, cte.def) + } + + for i, col := range lp.Schema().Columns { + lp.Cte.ColumnMap[string(col.HashCode())] = prevSchema.Columns[i] + } + p = lp + p.SetOutputNames(cte.seedLP.OutputNames()) + if len(asName.String()) > 0 { + lp.CteAsName = *asName + var on types.NameSlice + for _, name := range p.OutputNames() { + cpOn := *name + cpOn.TblName = *asName + on = append(on, &cpOn) + } + p.SetOutputNames(on) + } + return p, nil + } + } + + return nil, nil +} + +// computeCTEInlineFlag, Combine the declaration of CTE and the use of CTE to jointly determine **whether a CTE can be inlined** +/* + There are some cases that CTE must be not inlined. + 1. CTE is recursive CTE. + 2. CTE contains agg or window and it is referenced by recursive part of CTE. + 3. Consumer count of CTE is more than one. + If 1 or 2 conditions are met, CTE cannot be inlined. + But if query is hint by 'merge()' or session variable "tidb_opt_force_inline_cte", + CTE will still not be inlined but a warning will be recorded "Hint or session variables are invalid" + If 3 condition is met, CTE can be inlined by hint and session variables. 
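+ For example, a non-recursive CTE that is referenced twice falls under condition 3: it stays
+ materialized by default, but merge() or tidb_opt_force_inline_cte can still force it to be inlined.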
+*/ +func (b *PlanBuilder) computeCTEInlineFlag(cte *cteInfo) { + if cte.recurLP != nil { + if cte.forceInlineByHintOrVar { + b.ctx.GetSessionVars().StmtCtx.SetHintWarning( + fmt.Sprintf("Recursive CTE %s can not be inlined by merge() or tidb_opt_force_inline_cte.", cte.def.Name)) + } + } else if cte.containAggOrWindow && b.buildingRecursivePartForCTE { + if cte.forceInlineByHintOrVar { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(plannererrors.ErrCTERecursiveForbidsAggregation.FastGenByArgs(cte.def.Name)) + } + } else if cte.consumerCount > 1 { + if cte.forceInlineByHintOrVar { + cte.isInline = true + } + } else { + cte.isInline = true + } +} + +func (b *PlanBuilder) buildDataSourceFromCTEMerge(ctx context.Context, cte *ast.CommonTableExpression) (base.LogicalPlan, error) { + p, err := b.buildResultSetNode(ctx, cte.Query.Query, true) + if err != nil { + return nil, err + } + b.handleHelper.popMap() + outPutNames := p.OutputNames() + for _, name := range outPutNames { + name.TblName = cte.Name + name.DBName = model.NewCIStr(b.ctx.GetSessionVars().CurrentDB) + } + + if len(cte.ColNameList) > 0 { + if len(cte.ColNameList) != len(p.OutputNames()) { + return nil, errors.New("CTE columns length is not consistent") + } + for i, n := range cte.ColNameList { + outPutNames[i].ColName = n + } + } + p.SetOutputNames(outPutNames) + return p, nil +} + +func (b *PlanBuilder) buildDataSource(ctx context.Context, tn *ast.TableName, asName *model.CIStr) (base.LogicalPlan, error) { + b.optFlag |= flagPredicateSimplification + dbName := tn.Schema + sessionVars := b.ctx.GetSessionVars() + + if dbName.L == "" { + // Try CTE. + p, err := b.tryBuildCTE(ctx, tn, asName) + if err != nil || p != nil { + return p, err + } + dbName = model.NewCIStr(sessionVars.CurrentDB) + } + + is := b.is + if len(b.buildingViewStack) > 0 { + // For tables in view, always ignore local temporary table, considering the below case: + // If a user created a normal table `t1` and a view `v1` referring `t1`, and then a local temporary table with a same name `t1` is created. + // At this time, executing 'select * from v1' should still return all records from normal table `t1` instead of temporary table `t1`. 
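+		// Detaching them here keeps name resolution inside the view from binding to a session-local temporary table.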
+ is = temptable.DetachLocalTemporaryTableInfoSchema(is) + } + + tbl, err := is.TableByName(ctx, dbName, tn.Name) + if err != nil { + return nil, err + } + + tbl, err = tryLockMDLAndUpdateSchemaIfNecessary(ctx, b.ctx, dbName, tbl, b.is) + if err != nil { + return nil, err + } + tableInfo := tbl.Meta() + + if b.isCreateView && tableInfo.TempTableType == model.TempTableLocal { + return nil, plannererrors.ErrViewSelectTemporaryTable.GenWithStackByArgs(tn.Name) + } + + var authErr error + if sessionVars.User != nil { + authErr = plannererrors.ErrTableaccessDenied.FastGenByArgs("SELECT", sessionVars.User.AuthUsername, sessionVars.User.AuthHostname, tableInfo.Name.L) + } + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SelectPriv, dbName.L, tableInfo.Name.L, "", authErr) + + if tbl.Type().IsVirtualTable() { + if tn.TableSample != nil { + return nil, expression.ErrInvalidTableSample.GenWithStackByArgs("Unsupported TABLESAMPLE in virtual tables") + } + return b.buildMemTable(ctx, dbName, tableInfo) + } + + tblName := *asName + if tblName.L == "" { + tblName = tn.Name + } + + if tableInfo.GetPartitionInfo() != nil { + // If `UseDynamicPruneMode` already been false, then we don't need to check whether execute `flagPartitionProcessor` + // otherwise we need to check global stats initialized for each partition table + if !b.ctx.GetSessionVars().IsDynamicPartitionPruneEnabled() { + b.optFlag = b.optFlag | flagPartitionProcessor + } else { + if !b.ctx.GetSessionVars().StmtCtx.UseDynamicPruneMode { + b.optFlag = b.optFlag | flagPartitionProcessor + } else { + h := domain.GetDomain(b.ctx).StatsHandle() + tblStats := h.GetTableStats(tableInfo) + isDynamicEnabled := b.ctx.GetSessionVars().IsDynamicPartitionPruneEnabled() + globalStatsReady := tblStats.IsAnalyzed() + skipMissingPartition := b.ctx.GetSessionVars().SkipMissingPartitionStats + // If we already enabled the tidb_skip_missing_partition_stats, the global stats can be treated as exist. + allowDynamicWithoutStats := fixcontrol.GetBoolWithDefault(b.ctx.GetSessionVars().GetOptimizerFixControlMap(), fixcontrol.Fix44262, skipMissingPartition) + + // If dynamic partition prune isn't enabled or global stats is not ready, we won't enable dynamic prune mode in query + usePartitionProcessor := !isDynamicEnabled || (!globalStatsReady && !allowDynamicWithoutStats) + + failpoint.Inject("forceDynamicPrune", func(val failpoint.Value) { + if val.(bool) { + if isDynamicEnabled { + usePartitionProcessor = false + } + } + }) + + if usePartitionProcessor { + b.optFlag = b.optFlag | flagPartitionProcessor + b.ctx.GetSessionVars().StmtCtx.UseDynamicPruneMode = false + if isDynamicEnabled { + b.ctx.GetSessionVars().StmtCtx.AppendWarning( + fmt.Errorf("disable dynamic pruning due to %s has no global stats", tableInfo.Name.String())) + } + } + } + } + pt := tbl.(table.PartitionedTable) + // check partition by name. 
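+		// A `PARTITION (p0, p1)` clause on the table restricts the read to the named partitions; an unknown partition name is rejected here.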
+ if len(tn.PartitionNames) > 0 { + pids := make(map[int64]struct{}, len(tn.PartitionNames)) + for _, name := range tn.PartitionNames { + pid, err := tables.FindPartitionByName(tableInfo, name.L) + if err != nil { + return nil, err + } + pids[pid] = struct{}{} + } + pt = tables.NewPartitionTableWithGivenSets(pt, pids) + } + b.partitionedTable = append(b.partitionedTable, pt) + } else if len(tn.PartitionNames) != 0 { + return nil, plannererrors.ErrPartitionClauseOnNonpartitioned + } + + possiblePaths, err := getPossibleAccessPaths(b.ctx, b.TableHints(), tn.IndexHints, tbl, dbName, tblName, b.isForUpdateRead, b.optFlag&flagPartitionProcessor > 0) + if err != nil { + return nil, err + } + + if tableInfo.IsView() { + if tn.TableSample != nil { + return nil, expression.ErrInvalidTableSample.GenWithStackByArgs("Unsupported TABLESAMPLE in views") + } + + // Get the hints belong to the current view. + currentQBNameMap4View := make(map[string][]ast.HintTable) + currentViewHints := make(map[string][]*ast.TableOptimizerHint) + for qbName, viewQBNameHintTable := range b.hintProcessor.ViewQBNameToTable { + if len(viewQBNameHintTable) == 0 { + continue + } + viewSelectOffset := b.getSelectOffset() + + var viewHintSelectOffset int + if viewQBNameHintTable[0].QBName.L == "" { + // If we do not explicit set the qbName, we will set the empty qb name to @sel_1. + viewHintSelectOffset = 1 + } else { + viewHintSelectOffset = b.hintProcessor.GetHintOffset(viewQBNameHintTable[0].QBName, viewSelectOffset) + } + + // Check whether the current view can match the view name in the hint. + if viewQBNameHintTable[0].TableName.L == tblName.L && viewHintSelectOffset == viewSelectOffset { + // If the view hint can match the current view, we pop the first view table in the query block hint's table list. + // It means the hint belong the current view, the first view name in hint is matched. + // Because of the nested views, so we should check the left table list in hint when build the data source from the view inside the current view. + currentQBNameMap4View[qbName] = viewQBNameHintTable[1:] + currentViewHints[qbName] = b.hintProcessor.ViewQBNameToHints[qbName] + b.hintProcessor.ViewQBNameUsed[qbName] = struct{}{} + } + } + return b.BuildDataSourceFromView(ctx, dbName, tableInfo, currentQBNameMap4View, currentViewHints) + } + + if tableInfo.IsSequence() { + if tn.TableSample != nil { + return nil, expression.ErrInvalidTableSample.GenWithStackByArgs("Unsupported TABLESAMPLE in sequences") + } + // When the source is a Sequence, we convert it to a TableDual, as what most databases do. + return b.buildTableDual(), nil + } + + // remain tikv access path to generate point get acceess path if existed + // see detail in issue: https://github.com/pingcap/tidb/issues/39543 + if !(b.isForUpdateRead && b.ctx.GetSessionVars().TxnCtx.IsExplicit) { + // Skip storage engine check for CreateView. + if b.capFlag&canExpandAST == 0 { + possiblePaths, err = filterPathByIsolationRead(b.ctx, possiblePaths, tblName, dbName) + if err != nil { + return nil, err + } + } + } + + // Try to substitute generate column only if there is an index on generate column. + for _, index := range tableInfo.Indices { + if index.State != model.StatePublic { + continue + } + for _, indexCol := range index.Columns { + colInfo := tbl.Cols()[indexCol.Offset] + if colInfo.IsGenerated() && !colInfo.GeneratedStored { + b.optFlag |= flagGcSubstitute + break + } + } + } + + var columns []*table.Column + if b.inUpdateStmt { + // create table t(a int, b int). 
+ // Imagine that, There are 2 TiDB instances in the cluster, name A, B. We add a column `c` to table t in the TiDB cluster. + // One of the TiDB, A, the column type in its infoschema is changed to public. And in the other TiDB, the column type is + // still StateWriteReorganization. + // TiDB A: insert into t values(1, 2, 3); + // TiDB B: update t set a = 2 where b = 2; + // If we use tbl.Cols() here, the update statement, will ignore the col `c`, and the data `3` will lost. + columns = tbl.WritableCols() + } else if b.inDeleteStmt { + // DeletableCols returns all columns of the table in deletable states. + columns = tbl.DeletableCols() + } else { + columns = tbl.Cols() + } + // extract the IndexMergeHint + var indexMergeHints []h.HintedIndex + if hints := b.TableHints(); hints != nil { + for i, hint := range hints.IndexMergeHintList { + if hint.Match(dbName, tblName) { + hints.IndexMergeHintList[i].Matched = true + // check whether the index names in IndexMergeHint are valid. + invalidIdxNames := make([]string, 0, len(hint.IndexHint.IndexNames)) + for _, idxName := range hint.IndexHint.IndexNames { + hasIdxName := false + for _, path := range possiblePaths { + if path.IsTablePath() { + if idxName.L == "primary" { + hasIdxName = true + break + } + continue + } + if idxName.L == path.Index.Name.L { + hasIdxName = true + break + } + } + if !hasIdxName { + invalidIdxNames = append(invalidIdxNames, idxName.String()) + } + } + if len(invalidIdxNames) == 0 { + indexMergeHints = append(indexMergeHints, hint) + } else { + // Append warning if there are invalid index names. + errMsg := fmt.Sprintf("use_index_merge(%s) is inapplicable, check whether the indexes (%s) "+ + "exist, or the indexes are conflicted with use_index/ignore_index/force_index hints.", + hint.IndexString(), strings.Join(invalidIdxNames, ", ")) + b.ctx.GetSessionVars().StmtCtx.SetHintWarning(errMsg) + } + } + } + } + ds := DataSource{ + DBName: dbName, + TableAsName: asName, + table: tbl, + TableInfo: tableInfo, + PhysicalTableID: tableInfo.ID, + AstIndexHints: tn.IndexHints, + IndexHints: b.TableHints().IndexHintList, + IndexMergeHints: indexMergeHints, + PossibleAccessPaths: possiblePaths, + Columns: make([]*model.ColumnInfo, 0, len(columns)), + PartitionNames: tn.PartitionNames, + TblCols: make([]*expression.Column, 0, len(columns)), + PreferPartitions: make(map[int][]model.CIStr), + IS: b.is, + IsForUpdateRead: b.isForUpdateRead, + }.Init(b.ctx, b.getSelectOffset()) + var handleCols util.HandleCols + schema := expression.NewSchema(make([]*expression.Column, 0, len(columns))...) + names := make([]*types.FieldName, 0, len(columns)) + for i, col := range columns { + ds.Columns = append(ds.Columns, col.ToInfo()) + names = append(names, &types.FieldName{ + DBName: dbName, + TblName: tableInfo.Name, + ColName: col.Name, + OrigTblName: tableInfo.Name, + OrigColName: col.Name, + // For update statement and delete statement, internal version should see the special middle state column, while user doesn't. + NotExplicitUsable: col.State != model.StatePublic, + }) + newCol := &expression.Column{ + UniqueID: sessionVars.AllocPlanColumnID(), + ID: col.ID, + RetType: col.FieldType.Clone(), + OrigName: names[i].String(), + IsHidden: col.Hidden, + } + if col.IsPKHandleColumn(tableInfo) { + handleCols = util.NewIntHandleCols(newCol) + } + schema.Append(newCol) + ds.TblCols = append(ds.TblCols, newCol) + } + // We append an extra handle column to the schema when the handle + // column is not the primary key of "ds". 
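+	// For clustered (common handle) tables the primary index columns serve as the handle; otherwise the hidden `_tidb_rowid` extra handle column is appended to the schema.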
+ if handleCols == nil { + if tableInfo.IsCommonHandle { + primaryIdx := tables.FindPrimaryIndex(tableInfo) + handleCols = util.NewCommonHandleCols(b.ctx.GetSessionVars().StmtCtx, tableInfo, primaryIdx, ds.TblCols) + } else { + extraCol := ds.newExtraHandleSchemaCol() + handleCols = util.NewIntHandleCols(extraCol) + ds.Columns = append(ds.Columns, model.NewExtraHandleColInfo()) + schema.Append(extraCol) + names = append(names, &types.FieldName{ + DBName: dbName, + TblName: tableInfo.Name, + ColName: model.ExtraHandleName, + OrigColName: model.ExtraHandleName, + }) + ds.TblCols = append(ds.TblCols, extraCol) + } + } + ds.HandleCols = handleCols + ds.UnMutableHandleCols = handleCols + handleMap := make(map[int64][]util.HandleCols) + handleMap[tableInfo.ID] = []util.HandleCols{handleCols} + b.handleHelper.pushMap(handleMap) + ds.SetSchema(schema) + ds.SetOutputNames(names) + ds.setPreferredStoreType(b.TableHints()) + ds.SampleInfo = tablesampler.NewTableSampleInfo(tn.TableSample, schema, b.partitionedTable) + b.isSampling = ds.SampleInfo != nil + + for i, colExpr := range ds.Schema().Columns { + var expr expression.Expression + if i < len(columns) { + if columns[i].IsGenerated() && !columns[i].GeneratedStored { + var err error + originVal := b.allowBuildCastArray + b.allowBuildCastArray = true + expr, _, err = b.rewrite(ctx, columns[i].GeneratedExpr.Clone(), ds, nil, true) + b.allowBuildCastArray = originVal + if err != nil { + return nil, err + } + colExpr.VirtualExpr = expr.Clone() + } + } + } + + // Init CommonHandleCols and CommonHandleLens for data source. + if tableInfo.IsCommonHandle { + ds.CommonHandleCols, ds.CommonHandleLens = expression.IndexInfo2Cols(ds.Columns, ds.Schema().Columns, tables.FindPrimaryIndex(tableInfo)) + } + // Init FullIdxCols, FullIdxColLens for accessPaths. + for _, path := range ds.PossibleAccessPaths { + if !path.IsIntHandlePath { + path.FullIdxCols, path.FullIdxColLens = expression.IndexInfo2Cols(ds.Columns, ds.Schema().Columns, path.Index) + + // check whether the path's index has a tidb_shard() prefix and the index column count + // more than 1. e.g. 
index(tidb_shard(a), a) + // set UkShardIndexPath only for unique secondary index + if !path.IsCommonHandlePath { + // tidb_shard expression must be first column of index + col := path.FullIdxCols[0] + if col != nil && + expression.GcColumnExprIsTidbShard(col.VirtualExpr) && + len(path.Index.Columns) > 1 && + path.Index.Unique { + path.IsUkShardIndexPath = true + ds.ContainExprPrefixUk = true + } + } + } + } + + var result base.LogicalPlan = ds + dirty := tableHasDirtyContent(b.ctx, tableInfo) + if dirty || tableInfo.TempTableType == model.TempTableLocal || tableInfo.TableCacheStatusType == model.TableCacheStatusEnable { + us := logicalop.LogicalUnionScan{HandleCols: handleCols}.Init(b.ctx, b.getSelectOffset()) + us.SetChildren(ds) + if tableInfo.Partition != nil && b.optFlag&flagPartitionProcessor == 0 { + // Adding ExtraPhysTblIDCol for UnionScan (transaction buffer handling) + // Not using old static prune mode + // Single TableReader for all partitions, needs the PhysTblID from storage + _ = ds.AddExtraPhysTblIDColumn() + } + result = us + } + + // Adding ExtraPhysTblIDCol for SelectLock (SELECT FOR UPDATE) is done when building SelectLock + + if sessionVars.StmtCtx.TblInfo2UnionScan == nil { + sessionVars.StmtCtx.TblInfo2UnionScan = make(map[*model.TableInfo]bool) + } + sessionVars.StmtCtx.TblInfo2UnionScan[tableInfo] = dirty + + return result, nil +} + +func (b *PlanBuilder) timeRangeForSummaryTable() util.QueryTimeRange { + const defaultSummaryDuration = 30 * time.Minute + hints := b.TableHints() + // User doesn't use TIME_RANGE hint + if hints == nil || (hints.TimeRangeHint.From == "" && hints.TimeRangeHint.To == "") { + to := time.Now() + from := to.Add(-defaultSummaryDuration) + return util.QueryTimeRange{From: from, To: to} + } + + // Parse time specified by user via TIM_RANGE hint + parse := func(s string) (time.Time, bool) { + t, err := time.ParseInLocation(util.MetricTableTimeFormat, s, time.Local) + if err != nil { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(err) + } + return t, err == nil + } + from, fromValid := parse(hints.TimeRangeHint.From) + to, toValid := parse(hints.TimeRangeHint.To) + switch { + case !fromValid && !toValid: + to = time.Now() + from = to.Add(-defaultSummaryDuration) + case fromValid && !toValid: + to = from.Add(defaultSummaryDuration) + case !fromValid && toValid: + from = to.Add(-defaultSummaryDuration) + } + + return util.QueryTimeRange{From: from, To: to} +} + +func (b *PlanBuilder) buildMemTable(_ context.Context, dbName model.CIStr, tableInfo *model.TableInfo) (base.LogicalPlan, error) { + // We can use the `TableInfo.Columns` directly because the memory table has + // a stable schema and there is no online DDL on the memory table. + schema := expression.NewSchema(make([]*expression.Column, 0, len(tableInfo.Columns))...) 
+ names := make([]*types.FieldName, 0, len(tableInfo.Columns)) + var handleCols util.HandleCols + for _, col := range tableInfo.Columns { + names = append(names, &types.FieldName{ + DBName: dbName, + TblName: tableInfo.Name, + ColName: col.Name, + OrigTblName: tableInfo.Name, + OrigColName: col.Name, + }) + // NOTE: Rewrite the expression if memory table supports generated columns in the future + newCol := &expression.Column{ + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + ID: col.ID, + RetType: &col.FieldType, + } + if tableInfo.PKIsHandle && mysql.HasPriKeyFlag(col.GetFlag()) { + handleCols = util.NewIntHandleCols(newCol) + } + schema.Append(newCol) + } + + if handleCols != nil { + handleMap := make(map[int64][]util.HandleCols) + handleMap[tableInfo.ID] = []util.HandleCols{handleCols} + b.handleHelper.pushMap(handleMap) + } else { + b.handleHelper.pushMap(nil) + } + + // NOTE: Add a `LogicalUnionScan` if we support update memory table in the future + p := logicalop.LogicalMemTable{ + DBName: dbName, + TableInfo: tableInfo, + Columns: make([]*model.ColumnInfo, len(tableInfo.Columns)), + }.Init(b.ctx, b.getSelectOffset()) + p.SetSchema(schema) + p.SetOutputNames(names) + copy(p.Columns, tableInfo.Columns) + + // Some memory tables can receive some predicates + switch dbName.L { + case util2.MetricSchemaName.L: + p.Extractor = newMetricTableExtractor() + case util2.InformationSchemaName.L: + switch upTbl := strings.ToUpper(tableInfo.Name.O); upTbl { + case infoschema.TableClusterConfig, infoschema.TableClusterLoad, infoschema.TableClusterHardware, infoschema.TableClusterSystemInfo: + p.Extractor = &ClusterTableExtractor{} + case infoschema.TableClusterLog: + p.Extractor = &ClusterLogTableExtractor{} + case infoschema.TableTiDBHotRegionsHistory: + p.Extractor = &HotRegionsHistoryTableExtractor{} + case infoschema.TableInspectionResult: + p.Extractor = &InspectionResultTableExtractor{} + p.QueryTimeRange = b.timeRangeForSummaryTable() + case infoschema.TableInspectionSummary: + p.Extractor = &InspectionSummaryTableExtractor{} + p.QueryTimeRange = b.timeRangeForSummaryTable() + case infoschema.TableInspectionRules: + p.Extractor = &InspectionRuleTableExtractor{} + case infoschema.TableMetricSummary, infoschema.TableMetricSummaryByLabel: + p.Extractor = &MetricSummaryTableExtractor{} + p.QueryTimeRange = b.timeRangeForSummaryTable() + case infoschema.TableSlowQuery: + p.Extractor = &SlowQueryExtractor{} + case infoschema.TableStorageStats: + p.Extractor = &TableStorageStatsExtractor{} + case infoschema.TableTiFlashTables, infoschema.TableTiFlashSegments: + p.Extractor = &TiFlashSystemTableExtractor{} + case infoschema.TableStatementsSummary, infoschema.TableStatementsSummaryHistory: + p.Extractor = &StatementsSummaryExtractor{} + case infoschema.TableTiKVRegionPeers: + p.Extractor = &TikvRegionPeersExtractor{} + case infoschema.TableColumns: + p.Extractor = &ColumnsTableExtractor{} + case infoschema.TableTables: + ex := &InfoSchemaTablesExtractor{} + ex.initExtractableColNames(upTbl) + p.Extractor = ex + case infoschema.TablePartitions: + ex := &InfoSchemaPartitionsExtractor{} + ex.initExtractableColNames(upTbl) + p.Extractor = ex + case infoschema.TableStatistics: + ex := &InfoSchemaStatisticsExtractor{} + ex.initExtractableColNames(upTbl) + p.Extractor = ex + case infoschema.TableSchemata: + ex := &InfoSchemaSchemataExtractor{} + ex.initExtractableColNames(upTbl) + p.Extractor = ex + case infoschema.TableReferConst, + infoschema.TableKeyColumn, + infoschema.TableSequences, + 
infoschema.TableCheckConstraints, + infoschema.TableTiDBCheckConstraints, + infoschema.TableTiDBIndexUsage, + infoschema.TableTiDBIndexes, + infoschema.TableViews, + infoschema.TableConstraints: + ex := &InfoSchemaBaseExtractor{} + ex.initExtractableColNames(upTbl) + p.Extractor = ex + case infoschema.TableTiKVRegionStatus: + p.Extractor = &TiKVRegionStatusExtractor{tablesID: make([]int64, 0)} + } + } + return p, nil +} + +// checkRecursiveView checks whether this view is recursively defined. +func (b *PlanBuilder) checkRecursiveView(dbName model.CIStr, tableName model.CIStr) (func(), error) { + viewFullName := dbName.L + "." + tableName.L + if b.buildingViewStack == nil { + b.buildingViewStack = set.NewStringSet() + } + // If this view has already been on the building stack, it means + // this view contains a recursive definition. + if b.buildingViewStack.Exist(viewFullName) { + return nil, plannererrors.ErrViewRecursive.GenWithStackByArgs(dbName.O, tableName.O) + } + // If the view is being renamed, we return the mysql compatible error message. + if b.capFlag&renameView != 0 && viewFullName == b.renamingViewName { + return nil, plannererrors.ErrNoSuchTable.GenWithStackByArgs(dbName.O, tableName.O) + } + b.buildingViewStack.Insert(viewFullName) + return func() { delete(b.buildingViewStack, viewFullName) }, nil +} + +// BuildDataSourceFromView is used to build base.LogicalPlan from view +// qbNameMap4View and viewHints are used for the view's hint. +// qbNameMap4View maps the query block name to the view table lists. +// viewHints group the view hints based on the view's query block name. +func (b *PlanBuilder) BuildDataSourceFromView(ctx context.Context, dbName model.CIStr, tableInfo *model.TableInfo, qbNameMap4View map[string][]ast.HintTable, viewHints map[string][]*ast.TableOptimizerHint) (base.LogicalPlan, error) { + viewDepth := b.ctx.GetSessionVars().StmtCtx.ViewDepth + b.ctx.GetSessionVars().StmtCtx.ViewDepth++ + deferFunc, err := b.checkRecursiveView(dbName, tableInfo.Name) + if err != nil { + return nil, err + } + defer deferFunc() + + charset, collation := b.ctx.GetSessionVars().GetCharsetInfo() + viewParser := parser.New() + viewParser.SetParserConfig(b.ctx.GetSessionVars().BuildParserConfig()) + selectNode, err := viewParser.ParseOneStmt(tableInfo.View.SelectStmt, charset, collation) + if err != nil { + return nil, err + } + originalVisitInfo := b.visitInfo + b.visitInfo = make([]visitInfo, 0) + + // For the case that views appear in CTE queries, + // we need to save the CTEs after the views are established. + var saveCte []*cteInfo + if len(b.outerCTEs) > 0 { + saveCte = make([]*cteInfo, len(b.outerCTEs)) + copy(saveCte, b.outerCTEs) + } else { + saveCte = nil + } + o := b.buildingCTE + b.buildingCTE = false + defer func() { + b.outerCTEs = saveCte + b.buildingCTE = o + }() + + hintProcessor := h.NewQBHintHandler(b.ctx.GetSessionVars().StmtCtx) + selectNode.Accept(hintProcessor) + currentQbNameMap4View := make(map[string][]ast.HintTable) + currentQbHints4View := make(map[string][]*ast.TableOptimizerHint) + currentQbHints := make(map[int][]*ast.TableOptimizerHint) + currentQbNameMap := make(map[string]int) + + for qbName, viewQbNameHint := range qbNameMap4View { + // Check whether the view hint belong the current view or its nested views. 
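+		// Hints that resolve to a query block offset of the current view are converted into normal hints below; the rest stay in the view-hint maps for nested views.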
+ qbOffset := -1 + if len(viewQbNameHint) == 0 { + qbOffset = 1 + } else if len(viewQbNameHint) == 1 && viewQbNameHint[0].TableName.L == "" { + qbOffset = hintProcessor.GetHintOffset(viewQbNameHint[0].QBName, -1) + } else { + currentQbNameMap4View[qbName] = viewQbNameHint + currentQbHints4View[qbName] = viewHints[qbName] + } + + if qbOffset != -1 { + // If the hint belongs to the current view and not belongs to it's nested views, we should convert the view hint to the normal hint. + // After we convert the view hint to the normal hint, it can be reused the origin hint's infrastructure. + currentQbHints[qbOffset] = viewHints[qbName] + currentQbNameMap[qbName] = qbOffset + + delete(qbNameMap4View, qbName) + delete(viewHints, qbName) + } + } + + hintProcessor.ViewQBNameToTable = qbNameMap4View + hintProcessor.ViewQBNameToHints = viewHints + hintProcessor.ViewQBNameUsed = make(map[string]struct{}) + hintProcessor.QBOffsetToHints = currentQbHints + hintProcessor.QBNameToSelOffset = currentQbNameMap + + originHintProcessor := b.hintProcessor + originPlannerSelectBlockAsName := b.ctx.GetSessionVars().PlannerSelectBlockAsName.Load() + b.hintProcessor = hintProcessor + newPlannerSelectBlockAsName := make([]ast.HintTable, hintProcessor.MaxSelectStmtOffset()+1) + b.ctx.GetSessionVars().PlannerSelectBlockAsName.Store(&newPlannerSelectBlockAsName) + defer func() { + b.hintProcessor.HandleUnusedViewHints() + b.hintProcessor = originHintProcessor + b.ctx.GetSessionVars().PlannerSelectBlockAsName.Store(originPlannerSelectBlockAsName) + }() + selectLogicalPlan, err := b.Build(ctx, selectNode) + if err != nil { + logutil.BgLogger().Error("build plan for view failed", zap.Error(err)) + if terror.ErrorNotEqual(err, plannererrors.ErrViewRecursive) && + terror.ErrorNotEqual(err, plannererrors.ErrNoSuchTable) && + terror.ErrorNotEqual(err, plannererrors.ErrInternal) && + terror.ErrorNotEqual(err, plannererrors.ErrFieldNotInGroupBy) && + terror.ErrorNotEqual(err, plannererrors.ErrMixOfGroupFuncAndFields) && + terror.ErrorNotEqual(err, plannererrors.ErrViewNoExplain) && + terror.ErrorNotEqual(err, plannererrors.ErrNotSupportedYet) { + err = plannererrors.ErrViewInvalid.GenWithStackByArgs(dbName.O, tableInfo.Name.O) + } + return nil, err + } + pm := privilege.GetPrivilegeManager(b.ctx) + if viewDepth != 0 && + b.ctx.GetSessionVars().StmtCtx.InExplainStmt && + pm != nil && + !pm.RequestVerification(b.ctx.GetSessionVars().ActiveRoles, dbName.L, tableInfo.Name.L, "", mysql.SelectPriv) { + return nil, plannererrors.ErrViewNoExplain + } + if tableInfo.View.Security == model.SecurityDefiner { + if pm != nil { + for _, v := range b.visitInfo { + if !pm.RequestVerificationWithUser(v.db, v.table, v.column, v.privilege, tableInfo.View.Definer) { + return nil, plannererrors.ErrViewInvalid.GenWithStackByArgs(dbName.O, tableInfo.Name.O) + } + } + } + b.visitInfo = b.visitInfo[:0] + } + b.visitInfo = append(originalVisitInfo, b.visitInfo...) 
+ + if b.ctx.GetSessionVars().StmtCtx.InExplainStmt { + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.ShowViewPriv, dbName.L, tableInfo.Name.L, "", plannererrors.ErrViewNoExplain) + } + + if len(tableInfo.Columns) != selectLogicalPlan.Schema().Len() { + return nil, plannererrors.ErrViewInvalid.GenWithStackByArgs(dbName.O, tableInfo.Name.O) + } + + return b.buildProjUponView(ctx, dbName, tableInfo, selectLogicalPlan) +} + +func (b *PlanBuilder) buildProjUponView(_ context.Context, dbName model.CIStr, tableInfo *model.TableInfo, selectLogicalPlan base.Plan) (base.LogicalPlan, error) { + columnInfo := tableInfo.Cols() + cols := selectLogicalPlan.Schema().Clone().Columns + outputNamesOfUnderlyingSelect := selectLogicalPlan.OutputNames().Shallow() + // In the old version of VIEW implementation, TableInfo.View.Cols is used to + // store the origin columns' names of the underlying SelectStmt used when + // creating the view. + if tableInfo.View.Cols != nil { + cols = cols[:0] + outputNamesOfUnderlyingSelect = outputNamesOfUnderlyingSelect[:0] + for _, info := range columnInfo { + idx := expression.FindFieldNameIdxByColName(selectLogicalPlan.OutputNames(), info.Name.L) + if idx == -1 { + return nil, plannererrors.ErrViewInvalid.GenWithStackByArgs(dbName.O, tableInfo.Name.O) + } + cols = append(cols, selectLogicalPlan.Schema().Columns[idx]) + outputNamesOfUnderlyingSelect = append(outputNamesOfUnderlyingSelect, selectLogicalPlan.OutputNames()[idx]) + } + } + + projSchema := expression.NewSchema(make([]*expression.Column, 0, len(tableInfo.Columns))...) + projExprs := make([]expression.Expression, 0, len(tableInfo.Columns)) + projNames := make(types.NameSlice, 0, len(tableInfo.Columns)) + for i, name := range outputNamesOfUnderlyingSelect { + origColName := name.ColName + if tableInfo.View.Cols != nil { + origColName = tableInfo.View.Cols[i] + } + projNames = append(projNames, &types.FieldName{ + // TblName is the of view instead of the name of the underlying table. + TblName: tableInfo.Name, + OrigTblName: name.OrigTblName, + ColName: columnInfo[i].Name, + OrigColName: origColName, + DBName: dbName, + }) + projSchema.Append(&expression.Column{ + UniqueID: cols[i].UniqueID, + RetType: cols[i].GetStaticType(), + }) + projExprs = append(projExprs, cols[i]) + } + projUponView := logicalop.LogicalProjection{Exprs: projExprs}.Init(b.ctx, b.getSelectOffset()) + projUponView.SetOutputNames(projNames) + projUponView.SetChildren(selectLogicalPlan.(base.LogicalPlan)) + projUponView.SetSchema(projSchema) + return projUponView, nil +} + +// buildApplyWithJoinType builds apply plan with outerPlan and innerPlan, which apply join with particular join type for +// every row from outerPlan and the whole innerPlan. +func (b *PlanBuilder) buildApplyWithJoinType(outerPlan, innerPlan base.LogicalPlan, tp JoinType, markNoDecorrelate bool) base.LogicalPlan { + b.optFlag = b.optFlag | flagPredicatePushDown | flagBuildKeyInfo | flagDecorrelate | flagConvertOuterToInnerJoin + ap := LogicalApply{LogicalJoin: LogicalJoin{JoinType: tp}, NoDecorrelate: markNoDecorrelate}.Init(b.ctx, b.getSelectOffset()) + ap.SetChildren(outerPlan, innerPlan) + ap.SetOutputNames(make([]*types.FieldName, outerPlan.Schema().Len()+innerPlan.Schema().Len())) + copy(ap.OutputNames(), outerPlan.OutputNames()) + ap.SetSchema(expression.MergeSchema(outerPlan.Schema(), innerPlan.Schema())) + setIsInApplyForCTE(innerPlan, ap.Schema()) + // Note that, tp can only be LeftOuterJoin or InnerJoin, so we don't consider other outer joins. 
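+	// For LeftOuterJoin the inner side may produce NULLs, so the NOT NULL flags of the inner-side columns are reset below.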
+ if tp == LeftOuterJoin { + b.optFlag = b.optFlag | flagEliminateOuterJoin + util.ResetNotNullFlag(ap.Schema(), outerPlan.Schema().Len(), ap.Schema().Len()) + } + for i := outerPlan.Schema().Len(); i < ap.Schema().Len(); i++ { + ap.OutputNames()[i] = types.EmptyName + } + ap.LogicalJoin.SetPreferredJoinTypeAndOrder(b.TableHints()) + return ap +} + +// buildSemiApply builds apply plan with outerPlan and innerPlan, which apply semi-join for every row from outerPlan and the whole innerPlan. +func (b *PlanBuilder) buildSemiApply(outerPlan, innerPlan base.LogicalPlan, condition []expression.Expression, + asScalar, not, considerRewrite, markNoDecorrelate bool) (base.LogicalPlan, error) { + b.optFlag = b.optFlag | flagPredicatePushDown | flagBuildKeyInfo | flagDecorrelate + + join, err := b.buildSemiJoin(outerPlan, innerPlan, condition, asScalar, not, considerRewrite) + if err != nil { + return nil, err + } + + setIsInApplyForCTE(innerPlan, join.Schema()) + ap := &LogicalApply{LogicalJoin: *join, NoDecorrelate: markNoDecorrelate} + ap.SetTP(plancodec.TypeApply) + ap.SetSelf(ap) + return ap, nil +} + +// setIsInApplyForCTE indicates CTE is the in inner side of Apply and correlate. +// the storage of cte needs to be reset for each outer row. +// It's better to handle this in CTEExec.Close(), but cte storage is closed when SQL is finished. +func setIsInApplyForCTE(p base.LogicalPlan, apSchema *expression.Schema) { + switch x := p.(type) { + case *LogicalCTE: + if len(coreusage.ExtractCorColumnsBySchema4LogicalPlan(p, apSchema)) > 0 { + x.Cte.IsInApply = true + } + setIsInApplyForCTE(x.Cte.seedPartLogicalPlan, apSchema) + if x.Cte.recursivePartLogicalPlan != nil { + setIsInApplyForCTE(x.Cte.recursivePartLogicalPlan, apSchema) + } + default: + for _, child := range p.Children() { + setIsInApplyForCTE(child, apSchema) + } + } +} + +func (b *PlanBuilder) buildMaxOneRow(p base.LogicalPlan) base.LogicalPlan { + // The query block of the MaxOneRow operator should be the same as that of its child. + maxOneRow := logicalop.LogicalMaxOneRow{}.Init(b.ctx, p.QueryBlockOffset()) + maxOneRow.SetChildren(p) + return maxOneRow +} + +func (b *PlanBuilder) buildSemiJoin(outerPlan, innerPlan base.LogicalPlan, onCondition []expression.Expression, asScalar, not, forceRewrite bool) (*LogicalJoin, error) { + b.optFlag |= flagConvertOuterToInnerJoin + joinPlan := LogicalJoin{}.Init(b.ctx, b.getSelectOffset()) + for i, expr := range onCondition { + onCondition[i] = expr.Decorrelate(outerPlan.Schema()) + } + joinPlan.SetChildren(outerPlan, innerPlan) + joinPlan.AttachOnConds(onCondition) + joinPlan.SetOutputNames(make([]*types.FieldName, outerPlan.Schema().Len(), outerPlan.Schema().Len()+innerPlan.Schema().Len()+1)) + copy(joinPlan.OutputNames(), outerPlan.OutputNames()) + if asScalar { + newSchema := outerPlan.Schema().Clone() + newSchema.Append(&expression.Column{ + RetType: types.NewFieldType(mysql.TypeTiny), + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + }) + joinPlan.SetOutputNames(append(joinPlan.OutputNames(), types.EmptyName)) + joinPlan.SetSchema(newSchema) + if not { + joinPlan.JoinType = AntiLeftOuterSemiJoin + } else { + joinPlan.JoinType = LeftOuterSemiJoin + } + } else { + joinPlan.SetSchema(outerPlan.Schema().Clone()) + if not { + joinPlan.JoinType = AntiSemiJoin + } else { + joinPlan.JoinType = SemiJoin + } + } + // Apply forces to choose hash join currently, so don't worry the hints will take effect if the semi join is in one apply. 
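+	// In other words, join method hints recorded here take no effect once the semi join ends up inside an Apply, since Apply forces hash join for now.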
+ joinPlan.SetPreferredJoinTypeAndOrder(b.TableHints()) + if forceRewrite { + joinPlan.PreferJoinType |= h.PreferRewriteSemiJoin + b.optFlag |= flagSemiJoinRewrite + } + return joinPlan, nil +} + +func getTableOffset(names []*types.FieldName, handleName *types.FieldName) (int, error) { + for i, name := range names { + if name.DBName.L == handleName.DBName.L && name.TblName.L == handleName.TblName.L { + return i, nil + } + } + return -1, errors.Errorf("Couldn't get column information when do update/delete") +} + +// TblColPosInfo represents an mapper from column index to handle index. +type TblColPosInfo struct { + TblID int64 + // Start and End represent the ordinal range [Start, End) of the consecutive columns. + Start, End int + // HandleOrdinal represents the ordinal of the handle column. + HandleCols util.HandleCols +} + +// MemoryUsage return the memory usage of TblColPosInfo +func (t *TblColPosInfo) MemoryUsage() (sum int64) { + if t == nil { + return + } + + sum = size.SizeOfInt64 + size.SizeOfInt*2 + if t.HandleCols != nil { + sum += t.HandleCols.MemoryUsage() + } + return +} + +// TblColPosInfoSlice attaches the methods of sort.Interface to []TblColPosInfos sorting in increasing order. +type TblColPosInfoSlice []TblColPosInfo + +// Len implements sort.Interface#Len. +func (c TblColPosInfoSlice) Len() int { + return len(c) +} + +// Swap implements sort.Interface#Swap. +func (c TblColPosInfoSlice) Swap(i, j int) { + c[i], c[j] = c[j], c[i] +} + +// Less implements sort.Interface#Less. +func (c TblColPosInfoSlice) Less(i, j int) bool { + return c[i].Start < c[j].Start +} + +// FindTblIdx finds the ordinal of the corresponding access column. +func (c TblColPosInfoSlice) FindTblIdx(colOrdinal int) (int, bool) { + if len(c) == 0 { + return 0, false + } + // find the smallest index of the range that its start great than colOrdinal. + // @see https://godoc.org/sort#Search + rangeBehindOrdinal := sort.Search(len(c), func(i int) bool { return c[i].Start > colOrdinal }) + if rangeBehindOrdinal == 0 { + return 0, false + } + return rangeBehindOrdinal - 1, true +} + +// buildColumns2Handle builds columns to handle mapping. +func buildColumns2Handle( + names []*types.FieldName, + tblID2Handle map[int64][]util.HandleCols, + tblID2Table map[int64]table.Table, + onlyWritableCol bool, +) (TblColPosInfoSlice, error) { + var cols2Handles TblColPosInfoSlice + for tblID, handleCols := range tblID2Handle { + tbl := tblID2Table[tblID] + var tblLen int + if onlyWritableCol { + tblLen = len(tbl.WritableCols()) + } else { + tblLen = len(tbl.Cols()) + } + for _, handleCol := range handleCols { + offset, err := getTableOffset(names, names[handleCol.GetCol(0).Index]) + if err != nil { + return nil, err + } + end := offset + tblLen + cols2Handles = append(cols2Handles, TblColPosInfo{tblID, offset, end, handleCol}) + } + } + sort.Sort(cols2Handles) + return cols2Handles, nil +} + +func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) (base.Plan, error) { + b.pushSelectOffset(0) + b.pushTableHints(update.TableHints, 0) + defer func() { + b.popSelectOffset() + // table hints are only visible in the current UPDATE statement. 
+ b.popTableHints() + }() + + b.inUpdateStmt = true + b.isForUpdateRead = true + + if update.With != nil { + l := len(b.outerCTEs) + defer func() { + b.outerCTEs = b.outerCTEs[:l] + }() + _, err := b.buildWith(ctx, update.With) + if err != nil { + return nil, err + } + } + + p, err := b.buildResultSetNode(ctx, update.TableRefs.TableRefs, false) + if err != nil { + return nil, err + } + + tableList := ExtractTableList(update.TableRefs.TableRefs, false) + for _, t := range tableList { + dbName := t.Schema.L + if dbName == "" { + dbName = b.ctx.GetSessionVars().CurrentDB + } + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SelectPriv, dbName, t.Name.L, "", nil) + } + + oldSchemaLen := p.Schema().Len() + if update.Where != nil { + p, err = b.buildSelection(ctx, p, update.Where, nil) + if err != nil { + return nil, err + } + } + if b.ctx.GetSessionVars().TxnCtx.IsPessimistic { + if update.TableRefs.TableRefs.Right == nil { + // buildSelectLock is an optimization that can reduce RPC call. + // We only need do this optimization for single table update which is the most common case. + // When TableRefs.Right is nil, it is single table update. + p, err = b.buildSelectLock(p, &ast.SelectLockInfo{ + LockType: ast.SelectLockForUpdate, + }) + if err != nil { + return nil, err + } + } + } + + if update.Order != nil { + p, err = b.buildSort(ctx, p, update.Order.Items, nil, nil) + if err != nil { + return nil, err + } + } + if update.Limit != nil { + p, err = b.buildLimit(p, update.Limit) + if err != nil { + return nil, err + } + } + + // Add project to freeze the order of output columns. + proj := logicalop.LogicalProjection{Exprs: expression.Column2Exprs(p.Schema().Columns[:oldSchemaLen])}.Init(b.ctx, b.getSelectOffset()) + proj.SetSchema(expression.NewSchema(make([]*expression.Column, oldSchemaLen)...)) + proj.SetOutputNames(make(types.NameSlice, len(p.OutputNames()))) + copy(proj.OutputNames(), p.OutputNames()) + copy(proj.Schema().Columns, p.Schema().Columns[:oldSchemaLen]) + proj.SetChildren(p) + p = proj + + utlr := &updatableTableListResolver{} + update.Accept(utlr) + orderedList, np, allAssignmentsAreConstant, err := b.buildUpdateLists(ctx, utlr.updatableTableList, update.List, p) + if err != nil { + return nil, err + } + p = np + + updt := Update{ + OrderedList: orderedList, + AllAssignmentsAreConstant: allAssignmentsAreConstant, + VirtualAssignmentsOffset: len(update.List), + }.Init(b.ctx) + updt.names = p.OutputNames() + // We cannot apply projection elimination when building the subplan, because + // columns in orderedList cannot be resolved. (^flagEliminateProjection should also be applied in postOptimize) + updt.SelectPlan, _, err = DoOptimize(ctx, b.ctx, b.optFlag&^flagEliminateProjection, p) + if err != nil { + return nil, err + } + err = updt.ResolveIndices() + if err != nil { + return nil, err + } + tblID2Handle, err := resolveIndicesForTblID2Handle(b.handleHelper.tailMap(), updt.SelectPlan.Schema()) + if err != nil { + return nil, err + } + tblID2table := make(map[int64]table.Table, len(tblID2Handle)) + for id := range tblID2Handle { + tblID2table[id], _ = b.is.TableByID(id) + } + updt.TblColPosInfos, err = buildColumns2Handle(updt.OutputNames(), tblID2Handle, tblID2table, true) + if err != nil { + return nil, err + } + updt.PartitionedTable = b.partitionedTable + updt.tblID2Table = tblID2table + err = updt.buildOnUpdateFKTriggers(b.ctx, b.is, tblID2table) + return updt, err +} + +// GetUpdateColumnsInfo get the update columns info. 
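+// The returned slice is indexed by each writable column's position in the UPDATE plan's output row; positions that belong to no target table stay nil.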
+func GetUpdateColumnsInfo(tblID2Table map[int64]table.Table, tblColPosInfos TblColPosInfoSlice, size int) []*table.Column { + colsInfo := make([]*table.Column, size) + for _, content := range tblColPosInfos { + tbl := tblID2Table[content.TblID] + for i, c := range tbl.WritableCols() { + colsInfo[content.Start+i] = c + } + } + return colsInfo +} + +type tblUpdateInfo struct { + name string + pkUpdated bool + partitionColUpdated bool +} + +// CheckUpdateList checks all related columns in updatable state. +func CheckUpdateList(assignFlags []int, updt *Update, newTblID2Table map[int64]table.Table) error { + updateFromOtherAlias := make(map[int64]tblUpdateInfo) + for _, content := range updt.TblColPosInfos { + tbl := newTblID2Table[content.TblID] + flags := assignFlags[content.Start:content.End] + var update, updatePK, updatePartitionCol bool + var partitionColumnNames []model.CIStr + if pt, ok := tbl.(table.PartitionedTable); ok && pt != nil { + partitionColumnNames = pt.GetPartitionColumnNames() + } + + for i, col := range tbl.WritableCols() { + // schema may be changed between building plan and building executor + // If i >= len(flags), it means the target table has been added columns, then we directly skip the check + if i >= len(flags) { + continue + } + if flags[i] < 0 { + continue + } + + if col.State != model.StatePublic { + return plannererrors.ErrUnknownColumn.GenWithStackByArgs(col.Name, clauseMsg[fieldList]) + } + + update = true + if mysql.HasPriKeyFlag(col.GetFlag()) { + updatePK = true + } + for _, partColName := range partitionColumnNames { + if col.Name.L == partColName.L { + updatePartitionCol = true + } + } + } + if update { + // Check for multi-updates on primary key, + // see https://dev.mysql.com/doc/mysql-errors/5.7/en/server-error-reference.html#error_er_multi_update_key_conflict + if otherTable, ok := updateFromOtherAlias[tbl.Meta().ID]; ok { + if otherTable.pkUpdated || updatePK || otherTable.partitionColUpdated || updatePartitionCol { + return plannererrors.ErrMultiUpdateKeyConflict.GenWithStackByArgs(otherTable.name, updt.names[content.Start].TblName.O) + } + } else { + updateFromOtherAlias[tbl.Meta().ID] = tblUpdateInfo{ + name: updt.names[content.Start].TblName.O, + pkUpdated: updatePK, + partitionColUpdated: updatePartitionCol, + } + } + } + } + return nil +} + +// If tl is CTE, its HintedTable will be nil. +// Only used in build plan from AST after preprocess. 
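+// A CTE reference carries no TableInfo, which is what the check below relies on.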
+func isCTE(tl *ast.TableName) bool { + return tl.TableInfo == nil +} + +func (b *PlanBuilder) buildUpdateLists(ctx context.Context, tableList []*ast.TableName, list []*ast.Assignment, p base.LogicalPlan) (newList []*expression.Assignment, po base.LogicalPlan, allAssignmentsAreConstant bool, e error) { + b.curClause = fieldList + // modifyColumns indicates which columns are in set list, + // and if it is set to `DEFAULT` + modifyColumns := make(map[string]bool, p.Schema().Len()) + var columnsIdx map[*ast.ColumnName]int + cacheColumnsIdx := false + if len(p.OutputNames()) > 16 { + cacheColumnsIdx = true + columnsIdx = make(map[*ast.ColumnName]int, len(list)) + } + for _, assign := range list { + idx, err := expression.FindFieldName(p.OutputNames(), assign.Column) + if err != nil { + return nil, nil, false, err + } + if idx < 0 { + return nil, nil, false, plannererrors.ErrUnknownColumn.GenWithStackByArgs(assign.Column.Name, "field list") + } + if cacheColumnsIdx { + columnsIdx[assign.Column] = idx + } + name := p.OutputNames()[idx] + foundListItem := false + for _, tl := range tableList { + if (tl.Schema.L == "" || tl.Schema.L == name.DBName.L) && (tl.Name.L == name.TblName.L) { + if isCTE(tl) || tl.TableInfo.IsView() || tl.TableInfo.IsSequence() { + return nil, nil, false, plannererrors.ErrNonUpdatableTable.GenWithStackByArgs(name.TblName.O, "UPDATE") + } + foundListItem = true + } + } + if !foundListItem { + // For case like: + // 1: update (select * from t1) t1 set b = 1111111 ----- (no updatable table here) + // 2: update (select 1 as a) as t, t1 set a=1 ----- (updatable t1 don't have column a) + // --- subQuery is not counted as updatable table. + return nil, nil, false, plannererrors.ErrNonUpdatableTable.GenWithStackByArgs(name.TblName.O, "UPDATE") + } + columnFullName := fmt.Sprintf("%s.%s.%s", name.DBName.L, name.TblName.L, name.ColName.L) + // We save a flag for the column in map `modifyColumns` + // This flag indicated if assign keyword `DEFAULT` to the column + modifyColumns[columnFullName] = IsDefaultExprSameColumn(p.OutputNames()[idx:idx+1], assign.Expr) + } + + // If columns in set list contains generated columns, raise error. + // And, fill virtualAssignments here; that's for generated columns. + virtualAssignments := make([]*ast.Assignment, 0) + for _, tn := range tableList { + if isCTE(tn) || tn.TableInfo.IsView() || tn.TableInfo.IsSequence() { + continue + } + + tableInfo := tn.TableInfo + tableVal, found := b.is.TableByID(tableInfo.ID) + if !found { + return nil, nil, false, infoschema.ErrTableNotExists.FastGenByArgs(tn.DBInfo.Name.O, tableInfo.Name.O) + } + for i, colInfo := range tableVal.Cols() { + if !colInfo.IsGenerated() { + continue + } + columnFullName := fmt.Sprintf("%s.%s.%s", tn.DBInfo.Name.L, tn.Name.L, colInfo.Name.L) + isDefault, ok := modifyColumns[columnFullName] + if ok && colInfo.Hidden { + return nil, nil, false, plannererrors.ErrUnknownColumn.GenWithStackByArgs(colInfo.Name, clauseMsg[fieldList]) + } + // Note: For INSERT, REPLACE, and UPDATE, if a generated column is inserted into, replaced, or updated explicitly, the only permitted value is DEFAULT. 
+ // see https://dev.mysql.com/doc/refman/8.0/en/create-table-generated-columns.html + if ok && !isDefault { + return nil, nil, false, plannererrors.ErrBadGeneratedColumn.GenWithStackByArgs(colInfo.Name.O, tableInfo.Name.O) + } + virtualAssignments = append(virtualAssignments, &ast.Assignment{ + Column: &ast.ColumnName{Schema: tn.Schema, Table: tn.Name, Name: colInfo.Name}, + Expr: tableVal.Cols()[i].GeneratedExpr.Clone(), + }) + } + } + + allAssignmentsAreConstant = true + newList = make([]*expression.Assignment, 0, p.Schema().Len()) + tblDbMap := make(map[string]string, len(tableList)) + for _, tbl := range tableList { + if isCTE(tbl) { + continue + } + tblDbMap[tbl.Name.L] = tbl.DBInfo.Name.L + } + + allAssignments := append(list, virtualAssignments...) + dependentColumnsModified := make(map[int64]bool) + for i, assign := range allAssignments { + var idx int + var err error + if cacheColumnsIdx { + if i, ok := columnsIdx[assign.Column]; ok { + idx = i + } else { + idx, err = expression.FindFieldName(p.OutputNames(), assign.Column) + } + } else { + idx, err = expression.FindFieldName(p.OutputNames(), assign.Column) + } + if err != nil { + return nil, nil, false, err + } + col := p.Schema().Columns[idx] + name := p.OutputNames()[idx] + var newExpr expression.Expression + var np base.LogicalPlan + if i < len(list) { + // If assign `DEFAULT` to column, fill the `defaultExpr.Name` before rewrite expression + if expr := extractDefaultExpr(assign.Expr); expr != nil { + expr.Name = assign.Column + } + newExpr, np, err = b.rewrite(ctx, assign.Expr, p, nil, true) + if err != nil { + return nil, nil, false, err + } + dependentColumnsModified[col.UniqueID] = true + } else { + // rewrite with generation expression + rewritePreprocess := func(assign *ast.Assignment) func(expr ast.Node) ast.Node { + return func(expr ast.Node) ast.Node { + switch x := expr.(type) { + case *ast.ColumnName: + return &ast.ColumnName{ + Schema: assign.Column.Schema, + Table: assign.Column.Table, + Name: x.Name, + } + default: + return expr + } + } + } + + o := b.allowBuildCastArray + b.allowBuildCastArray = true + newExpr, np, err = b.rewriteWithPreprocess(ctx, assign.Expr, p, nil, nil, true, rewritePreprocess(assign)) + b.allowBuildCastArray = o + if err != nil { + return nil, nil, false, err + } + // check if the column is modified + dependentColumns := expression.ExtractDependentColumns(newExpr) + var isModified bool + for _, col := range dependentColumns { + if dependentColumnsModified[col.UniqueID] { + isModified = true + break + } + } + // skip unmodified generated columns + if !isModified { + continue + } + } + if _, isConst := newExpr.(*expression.Constant); !isConst { + allAssignmentsAreConstant = false + } + p = np + if cols := expression.ExtractColumnSet(newExpr); cols.Len() > 0 { + b.ctx.GetSessionVars().StmtCtx.ColRefFromUpdatePlan.UnionWith(cols) + } + newList = append(newList, &expression.Assignment{Col: col, ColName: name.ColName, Expr: newExpr}) + dbName := name.DBName.L + // To solve issue#10028, we need to get database name by the table alias name. + if dbNameTmp, ok := tblDbMap[name.TblName.L]; ok { + dbName = dbNameTmp + } + if dbName == "" { + dbName = b.ctx.GetSessionVars().CurrentDB + } + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.UpdatePriv, dbName, name.OrigTblName.L, "", nil) + } + return newList, p, allAssignmentsAreConstant, nil +} + +// extractDefaultExpr extract a `DefaultExpr` from `ExprNode`, +// If it is a `DEFAULT` function like `DEFAULT(a)`, return nil. 
+// Only if it is `DEFAULT` keyword, it will return the `DefaultExpr`. +func extractDefaultExpr(node ast.ExprNode) *ast.DefaultExpr { + if expr, ok := node.(*ast.DefaultExpr); ok && expr.Name == nil { + return expr + } + return nil +} + +// IsDefaultExprSameColumn - DEFAULT or col = DEFAULT(col) +func IsDefaultExprSameColumn(names types.NameSlice, node ast.ExprNode) bool { + if expr, ok := node.(*ast.DefaultExpr); ok { + if expr.Name == nil { + // col = DEFAULT + return true + } + refIdx, err := expression.FindFieldName(names, expr.Name) + if refIdx == 0 && err == nil { + // col = DEFAULT(col) + return true + } + } + return false +} + +func (b *PlanBuilder) buildDelete(ctx context.Context, ds *ast.DeleteStmt) (base.Plan, error) { + b.pushSelectOffset(0) + b.pushTableHints(ds.TableHints, 0) + defer func() { + b.popSelectOffset() + // table hints are only visible in the current DELETE statement. + b.popTableHints() + }() + + b.inDeleteStmt = true + b.isForUpdateRead = true + + if ds.With != nil { + l := len(b.outerCTEs) + defer func() { + b.outerCTEs = b.outerCTEs[:l] + }() + _, err := b.buildWith(ctx, ds.With) + if err != nil { + return nil, err + } + } + + p, err := b.buildResultSetNode(ctx, ds.TableRefs.TableRefs, false) + if err != nil { + return nil, err + } + oldSchema := p.Schema() + oldLen := oldSchema.Len() + + // For explicit column usage, should use the all-public columns. + if ds.Where != nil { + p, err = b.buildSelection(ctx, p, ds.Where, nil) + if err != nil { + return nil, err + } + } + if b.ctx.GetSessionVars().TxnCtx.IsPessimistic { + if !ds.IsMultiTable { + p, err = b.buildSelectLock(p, &ast.SelectLockInfo{ + LockType: ast.SelectLockForUpdate, + }) + if err != nil { + return nil, err + } + } + } + + if ds.Order != nil { + p, err = b.buildSort(ctx, p, ds.Order.Items, nil, nil) + if err != nil { + return nil, err + } + } + + if ds.Limit != nil { + p, err = b.buildLimit(p, ds.Limit) + if err != nil { + return nil, err + } + } + + // If the delete is non-qualified it does not require Select Priv + if ds.Where == nil && ds.Order == nil { + b.popVisitInfo() + } + var authErr error + sessionVars := b.ctx.GetSessionVars() + + proj := logicalop.LogicalProjection{Exprs: expression.Column2Exprs(p.Schema().Columns[:oldLen])}.Init(b.ctx, b.getSelectOffset()) + proj.SetChildren(p) + proj.SetSchema(oldSchema.Clone()) + proj.SetOutputNames(p.OutputNames()[:oldLen]) + p = proj + + del := Delete{ + IsMultiTable: ds.IsMultiTable, + }.Init(b.ctx) + + del.names = p.OutputNames() + // Collect visitInfo. + if ds.Tables != nil { + // Delete a, b from a, b, c, d... add a and b. + updatableList := make(map[string]bool) + tbInfoList := make(map[string]*ast.TableName) + collectTableName(ds.TableRefs.TableRefs, &updatableList, &tbInfoList) + for _, tn := range ds.Tables.Tables { + var canUpdate, foundMatch = false, false + name := tn.Name.L + if tn.Schema.L == "" { + canUpdate, foundMatch = updatableList[name] + } + + if !foundMatch { + if tn.Schema.L == "" { + name = model.NewCIStr(b.ctx.GetSessionVars().CurrentDB).L + "." + tn.Name.L + } else { + name = tn.Schema.L + "." 
+ tn.Name.L + } + canUpdate, foundMatch = updatableList[name] + } + // check sql like: `delete b from (select * from t) as a, t` + if !foundMatch { + return nil, plannererrors.ErrUnknownTable.GenWithStackByArgs(tn.Name.O, "MULTI DELETE") + } + // check sql like: `delete a from (select * from t) as a, t` + if !canUpdate { + return nil, plannererrors.ErrNonUpdatableTable.GenWithStackByArgs(tn.Name.O, "DELETE") + } + tb := tbInfoList[name] + tn.DBInfo = tb.DBInfo + tn.TableInfo = tb.TableInfo + if tn.TableInfo.IsView() { + return nil, errors.Errorf("delete view %s is not supported now", tn.Name.O) + } + if tn.TableInfo.IsSequence() { + return nil, errors.Errorf("delete sequence %s is not supported now", tn.Name.O) + } + if sessionVars.User != nil { + authErr = plannererrors.ErrTableaccessDenied.FastGenByArgs("DELETE", sessionVars.User.AuthUsername, sessionVars.User.AuthHostname, tb.Name.L) + } + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.DeletePriv, tb.DBInfo.Name.L, tb.Name.L, "", authErr) + } + } else { + // Delete from a, b, c, d. + tableList := ExtractTableList(ds.TableRefs.TableRefs, false) + for _, v := range tableList { + if isCTE(v) { + return nil, plannererrors.ErrNonUpdatableTable.GenWithStackByArgs(v.Name.O, "DELETE") + } + if v.TableInfo.IsView() { + return nil, errors.Errorf("delete view %s is not supported now", v.Name.O) + } + if v.TableInfo.IsSequence() { + return nil, errors.Errorf("delete sequence %s is not supported now", v.Name.O) + } + dbName := v.Schema.L + if dbName == "" { + dbName = b.ctx.GetSessionVars().CurrentDB + } + if sessionVars.User != nil { + authErr = plannererrors.ErrTableaccessDenied.FastGenByArgs("DELETE", sessionVars.User.AuthUsername, sessionVars.User.AuthHostname, v.Name.L) + } + b.visitInfo = appendVisitInfo(b.visitInfo, mysql.DeletePriv, dbName, v.Name.L, "", authErr) + } + } + handleColsMap := b.handleHelper.tailMap() + tblID2Handle, err := resolveIndicesForTblID2Handle(handleColsMap, p.Schema()) + if err != nil { + return nil, err + } + if del.IsMultiTable { + // tblID2TableName is the table map value is an array which contains table aliases. + // Table ID may not be unique for deleting multiple tables, for statements like + // `delete from t as t1, t as t2`, the same table has two alias, we have to identify a table + // by its alias instead of ID. 
+ tblID2TableName := make(map[int64][]*ast.TableName, len(ds.Tables.Tables)) + for _, tn := range ds.Tables.Tables { + tblID2TableName[tn.TableInfo.ID] = append(tblID2TableName[tn.TableInfo.ID], tn) + } + tblID2Handle = del.cleanTblID2HandleMap(tblID2TableName, tblID2Handle, del.names) + } + tblID2table := make(map[int64]table.Table, len(tblID2Handle)) + for id := range tblID2Handle { + tblID2table[id], _ = b.is.TableByID(id) + } + del.TblColPosInfos, err = buildColumns2Handle(del.names, tblID2Handle, tblID2table, false) + if err != nil { + return nil, err + } + + del.SelectPlan, _, err = DoOptimize(ctx, b.ctx, b.optFlag, p) + if err != nil { + return nil, err + } + + err = del.buildOnDeleteFKTriggers(b.ctx, b.is, tblID2table) + return del, err +} + +func resolveIndicesForTblID2Handle(tblID2Handle map[int64][]util.HandleCols, schema *expression.Schema) (map[int64][]util.HandleCols, error) { + newMap := make(map[int64][]util.HandleCols, len(tblID2Handle)) + for i, cols := range tblID2Handle { + for _, col := range cols { + resolvedCol, err := col.ResolveIndices(schema) + if err != nil { + return nil, err + } + newMap[i] = append(newMap[i], resolvedCol) + } + } + return newMap, nil +} + +func (p *Delete) cleanTblID2HandleMap( + tablesToDelete map[int64][]*ast.TableName, + tblID2Handle map[int64][]util.HandleCols, + outputNames []*types.FieldName, +) map[int64][]util.HandleCols { + for id, cols := range tblID2Handle { + names, ok := tablesToDelete[id] + if !ok { + delete(tblID2Handle, id) + continue + } + for i := len(cols) - 1; i >= 0; i-- { + hCols := cols[i] + var hasMatch bool + for j := 0; j < hCols.NumCols(); j++ { + if p.matchingDeletingTable(names, outputNames[hCols.GetCol(j).Index]) { + hasMatch = true + break + } + } + if !hasMatch { + cols = append(cols[:i], cols[i+1:]...) + } + } + if len(cols) == 0 { + delete(tblID2Handle, id) + continue + } + tblID2Handle[id] = cols + } + return tblID2Handle +} + +// matchingDeletingTable checks whether this column is from the table which is in the deleting list. +func (*Delete) matchingDeletingTable(names []*ast.TableName, name *types.FieldName) bool { + for _, n := range names { + if (name.DBName.L == "" || name.DBName.L == n.DBInfo.Name.L) && name.TblName.L == n.Name.L { + return true + } + } + return false +} + +func getWindowName(name string) string { + if name == "" { + return "" + } + return name +} + +// buildProjectionForWindow builds the projection for expressions in the window specification that is not an column, +// so after the projection, window functions only needs to deal with columns. 
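+// For example, `avg(a+1) OVER (PARTITION BY b+1 ORDER BY c+1)` first gets a projection that materializes a+1, b+1 and c+1 as columns.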
+func (b *PlanBuilder) buildProjectionForWindow(ctx context.Context, p base.LogicalPlan, spec *ast.WindowSpec, args []ast.ExprNode, aggMap map[*ast.AggregateFuncExpr]int) (base.LogicalPlan, []property.SortItem, []property.SortItem, []expression.Expression, error) { + b.optFlag |= flagEliminateProjection + + var partitionItems, orderItems []*ast.ByItem + if spec.PartitionBy != nil { + partitionItems = spec.PartitionBy.Items + } + if spec.OrderBy != nil { + orderItems = spec.OrderBy.Items + } + + projLen := len(p.Schema().Columns) + len(partitionItems) + len(orderItems) + len(args) + proj := logicalop.LogicalProjection{Exprs: make([]expression.Expression, 0, projLen)}.Init(b.ctx, b.getSelectOffset()) + proj.SetSchema(expression.NewSchema(make([]*expression.Column, 0, projLen)...)) + proj.SetOutputNames(make([]*types.FieldName, p.Schema().Len(), projLen)) + for _, col := range p.Schema().Columns { + proj.Exprs = append(proj.Exprs, col) + proj.Schema().Append(col) + } + copy(proj.OutputNames(), p.OutputNames()) + + propertyItems := make([]property.SortItem, 0, len(partitionItems)+len(orderItems)) + var err error + p, propertyItems, err = b.buildByItemsForWindow(ctx, p, proj, partitionItems, propertyItems, aggMap) + if err != nil { + return nil, nil, nil, nil, err + } + lenPartition := len(propertyItems) + p, propertyItems, err = b.buildByItemsForWindow(ctx, p, proj, orderItems, propertyItems, aggMap) + if err != nil { + return nil, nil, nil, nil, err + } + + newArgList := make([]expression.Expression, 0, len(args)) + for _, arg := range args { + newArg, np, err := b.rewrite(ctx, arg, p, aggMap, true) + if err != nil { + return nil, nil, nil, nil, err + } + p = np + switch newArg.(type) { + case *expression.Column, *expression.Constant: + newArgList = append(newArgList, newArg.Clone()) + continue + } + proj.Exprs = append(proj.Exprs, newArg) + proj.SetOutputNames(append(proj.OutputNames(), types.EmptyName)) + col := &expression.Column{ + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: newArg.GetType(b.ctx.GetExprCtx().GetEvalCtx()), + } + proj.Schema().Append(col) + newArgList = append(newArgList, col) + } + + proj.SetChildren(p) + return proj, propertyItems[:lenPartition], propertyItems[lenPartition:], newArgList, nil +} + +func (b *PlanBuilder) buildArgs4WindowFunc(ctx context.Context, p base.LogicalPlan, args []ast.ExprNode, aggMap map[*ast.AggregateFuncExpr]int) ([]expression.Expression, error) { + b.optFlag |= flagEliminateProjection + + newArgList := make([]expression.Expression, 0, len(args)) + // use below index for created a new col definition + // it's okay here because we only want to return the args used in window function + newColIndex := 0 + for _, arg := range args { + newArg, np, err := b.rewrite(ctx, arg, p, aggMap, true) + if err != nil { + return nil, err + } + p = np + switch newArg.(type) { + case *expression.Column, *expression.Constant: + newArgList = append(newArgList, newArg.Clone()) + continue + } + col := &expression.Column{ + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: newArg.GetType(b.ctx.GetExprCtx().GetEvalCtx()), + } + newColIndex++ + newArgList = append(newArgList, col) + } + return newArgList, nil +} + +func (b *PlanBuilder) buildByItemsForWindow( + ctx context.Context, + p base.LogicalPlan, + proj *logicalop.LogicalProjection, + items []*ast.ByItem, + retItems []property.SortItem, + aggMap map[*ast.AggregateFuncExpr]int, +) (base.LogicalPlan, []property.SortItem, error) { + transformer := &itemTransformer{} + for _, item 
:= range items { + newExpr, _ := item.Expr.Accept(transformer) + item.Expr = newExpr.(ast.ExprNode) + it, np, err := b.rewrite(ctx, item.Expr, p, aggMap, true) + if err != nil { + return nil, nil, err + } + p = np + if it.GetType(b.ctx.GetExprCtx().GetEvalCtx()).GetType() == mysql.TypeNull { + continue + } + if col, ok := it.(*expression.Column); ok { + retItems = append(retItems, property.SortItem{Col: col, Desc: item.Desc}) + // We need to attempt to add this column because a subquery may be created during the expression rewrite process. + // Therefore, we need to ensure that the column from the newly created query plan is added. + // If the column is already in the schema, we don't need to add it again. + if !proj.Schema().Contains(col) { + proj.Exprs = append(proj.Exprs, col) + proj.SetOutputNames(append(proj.OutputNames(), types.EmptyName)) + proj.Schema().Append(col) + } + continue + } + proj.Exprs = append(proj.Exprs, it) + proj.SetOutputNames(append(proj.OutputNames(), types.EmptyName)) + col := &expression.Column{ + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: it.GetType(b.ctx.GetExprCtx().GetEvalCtx()), + } + proj.Schema().Append(col) + retItems = append(retItems, property.SortItem{Col: col, Desc: item.Desc}) + } + return p, retItems, nil +} + +// buildWindowFunctionFrameBound builds the bounds of window function frames. +// For type `Rows`, the bound expr must be an unsigned integer. +// For type `Range`, the bound expr must be temporal or numeric types. +func (b *PlanBuilder) buildWindowFunctionFrameBound(_ context.Context, spec *ast.WindowSpec, orderByItems []property.SortItem, boundClause *ast.FrameBound) (*logicalop.FrameBound, error) { + frameType := spec.Frame.Type + bound := &logicalop.FrameBound{Type: boundClause.Type, UnBounded: boundClause.UnBounded, IsExplicitRange: false} + if bound.UnBounded { + return bound, nil + } + + if frameType == ast.Rows { + if bound.Type == ast.CurrentRow { + return bound, nil + } + numRows, _, _ := getUintFromNode(b.ctx, boundClause.Expr, false) + bound.Num = numRows + return bound, nil + } + + bound.CalcFuncs = make([]expression.Expression, len(orderByItems)) + bound.CmpFuncs = make([]expression.CompareFunc, len(orderByItems)) + if bound.Type == ast.CurrentRow { + for i, item := range orderByItems { + col := item.Col + bound.CalcFuncs[i] = col + bound.CmpFuncs[i] = expression.GetCmpFunction(b.ctx.GetExprCtx(), col, col) + } + return bound, nil + } + + col := orderByItems[0].Col + // TODO: We also need to raise error for non-deterministic expressions, like rand(). + val, err := evalAstExprWithPlanCtx(b.ctx, boundClause.Expr) + if err != nil { + return nil, plannererrors.ErrWindowRangeBoundNotConstant.GenWithStackByArgs(getWindowName(spec.Name.O)) + } + expr := expression.Constant{Value: val, RetType: boundClause.Expr.GetType()} + + checker := &expression.ParamMarkerInPrepareChecker{} + boundClause.Expr.Accept(checker) + + // If it has paramMarker and is in prepare stmt. We don't need to eval it since its value is not decided yet. + if !checker.InPrepareStmt { + // Do not raise warnings for truncate. 
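+		// (Illustrative note) The bound is evaluated as an integer with truncation errors ignored;
+		// a NULL, negative, or non-evaluable value makes the whole frame illegal, e.g. a bound like
+		// `RANGE BETWEEN -1 PRECEDING AND CURRENT ROW` is rejected with ErrWindowFrameIllegal.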
+ exprCtx := exprctx.CtxWithHandleTruncateErrLevel(b.ctx.GetExprCtx(), errctx.LevelIgnore) + uVal, isNull, err := expr.EvalInt(exprCtx.GetEvalCtx(), chunk.Row{}) + if uVal < 0 || isNull || err != nil { + return nil, plannererrors.ErrWindowFrameIllegal.GenWithStackByArgs(getWindowName(spec.Name.O)) + } + } + + bound.IsExplicitRange = true + desc := orderByItems[0].Desc + var funcName string + if boundClause.Unit != ast.TimeUnitInvalid { + // TODO: Perhaps we don't need to transcode this back to generic string + unitVal := boundClause.Unit.String() + unit := expression.Constant{ + Value: types.NewStringDatum(unitVal), + RetType: types.NewFieldType(mysql.TypeVarchar), + } + + // When the order is asc: + // `+` for following, and `-` for the preceding + // When the order is desc, `+` becomes `-` and vice-versa. + funcName = ast.DateAdd + if (!desc && bound.Type == ast.Preceding) || (desc && bound.Type == ast.Following) { + funcName = ast.DateSub + } + + bound.CalcFuncs[0], err = expression.NewFunctionBase(b.ctx.GetExprCtx(), funcName, col.RetType, col, &expr, &unit) + if err != nil { + return nil, err + } + } else { + // When the order is asc: + // `+` for following, and `-` for the preceding + // When the order is desc, `+` becomes `-` and vice-versa. + funcName = ast.Plus + if (!desc && bound.Type == ast.Preceding) || (desc && bound.Type == ast.Following) { + funcName = ast.Minus + } + + bound.CalcFuncs[0], err = expression.NewFunctionBase(b.ctx.GetExprCtx(), funcName, col.RetType, col, &expr) + if err != nil { + return nil, err + } + } + + cmpDataType := expression.GetAccurateCmpType(b.ctx.GetExprCtx().GetEvalCtx(), col, bound.CalcFuncs[0]) + bound.UpdateCmpFuncsAndCmpDataType(cmpDataType) + return bound, nil +} + +// buildWindowFunctionFrame builds the window function frames. 
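+// (Illustrative note) A clause such as `ROWS BETWEEN 1 PRECEDING AND CURRENT ROW` produces a
+// WindowFrame whose Start and End are built by buildWindowFunctionFrameBound; a missing frame
+// clause simply yields a nil frame.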
+// See https://dev.mysql.com/doc/refman/8.0/en/window-functions-frames.html +func (b *PlanBuilder) buildWindowFunctionFrame(ctx context.Context, spec *ast.WindowSpec, orderByItems []property.SortItem) (*logicalop.WindowFrame, error) { + frameClause := spec.Frame + if frameClause == nil { + return nil, nil + } + frame := &logicalop.WindowFrame{Type: frameClause.Type} + var err error + frame.Start, err = b.buildWindowFunctionFrameBound(ctx, spec, orderByItems, &frameClause.Extent.Start) + if err != nil { + return nil, err + } + frame.End, err = b.buildWindowFunctionFrameBound(ctx, spec, orderByItems, &frameClause.Extent.End) + return frame, err +} + +func (b *PlanBuilder) checkWindowFuncArgs(ctx context.Context, p base.LogicalPlan, windowFuncExprs []*ast.WindowFuncExpr, windowAggMap map[*ast.AggregateFuncExpr]int) error { + checker := &expression.ParamMarkerInPrepareChecker{} + for _, windowFuncExpr := range windowFuncExprs { + if strings.ToLower(windowFuncExpr.Name) == ast.AggFuncGroupConcat { + return plannererrors.ErrNotSupportedYet.GenWithStackByArgs("group_concat as window function") + } + args, err := b.buildArgs4WindowFunc(ctx, p, windowFuncExpr.Args, windowAggMap) + if err != nil { + return err + } + checker.InPrepareStmt = false + for _, expr := range windowFuncExpr.Args { + expr.Accept(checker) + } + desc, err := aggregation.NewWindowFuncDesc(b.ctx.GetExprCtx(), windowFuncExpr.Name, args, checker.InPrepareStmt) + if err != nil { + return err + } + if desc == nil { + return plannererrors.ErrWrongArguments.GenWithStackByArgs(strings.ToLower(windowFuncExpr.Name)) + } + } + return nil +} + +func getAllByItems(itemsBuf []*ast.ByItem, spec *ast.WindowSpec) []*ast.ByItem { + itemsBuf = itemsBuf[:0] + if spec.PartitionBy != nil { + itemsBuf = append(itemsBuf, spec.PartitionBy.Items...) + } + if spec.OrderBy != nil { + itemsBuf = append(itemsBuf, spec.OrderBy.Items...) + } + return itemsBuf +} + +func restoreByItemText(item *ast.ByItem) string { + var sb strings.Builder + ctx := format.NewRestoreCtx(0, &sb) + err := item.Expr.Restore(ctx) + if err != nil { + return "" + } + return sb.String() +} + +func compareItems(lItems []*ast.ByItem, rItems []*ast.ByItem) bool { + minLen := min(len(lItems), len(rItems)) + for i := 0; i < minLen; i++ { + res := strings.Compare(restoreByItemText(lItems[i]), restoreByItemText(rItems[i])) + if res != 0 { + return res < 0 + } + res = compareBool(lItems[i].Desc, rItems[i].Desc) + if res != 0 { + return res < 0 + } + } + return len(lItems) < len(rItems) +} + +type windowFuncs struct { + spec *ast.WindowSpec + funcs []*ast.WindowFuncExpr +} + +// sortWindowSpecs sorts the window specifications by reversed alphabetical order, then we could add less `Sort` operator +// in physical plan because the window functions with the same partition by and order by clause will be at near places. 
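+// For example (illustrative), `(PARTITION BY a ORDER BY b)` and `(PARTITION BY a)` sort next to
+// each other, so the plan can often reuse one Sort operator for both windows.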
+func sortWindowSpecs(groupedFuncs map[*ast.WindowSpec][]*ast.WindowFuncExpr, orderedSpec []*ast.WindowSpec) []windowFuncs { + windows := make([]windowFuncs, 0, len(groupedFuncs)) + for _, spec := range orderedSpec { + windows = append(windows, windowFuncs{spec, groupedFuncs[spec]}) + } + lItemsBuf := make([]*ast.ByItem, 0, 4) + rItemsBuf := make([]*ast.ByItem, 0, 4) + sort.SliceStable(windows, func(i, j int) bool { + lItemsBuf = getAllByItems(lItemsBuf, windows[i].spec) + rItemsBuf = getAllByItems(rItemsBuf, windows[j].spec) + return !compareItems(lItemsBuf, rItemsBuf) + }) + return windows +} + +func (b *PlanBuilder) buildWindowFunctions(ctx context.Context, p base.LogicalPlan, groupedFuncs map[*ast.WindowSpec][]*ast.WindowFuncExpr, orderedSpec []*ast.WindowSpec, aggMap map[*ast.AggregateFuncExpr]int) (base.LogicalPlan, map[*ast.WindowFuncExpr]int, error) { + if b.buildingCTE { + b.outerCTEs[len(b.outerCTEs)-1].containAggOrWindow = true + } + args := make([]ast.ExprNode, 0, 4) + windowMap := make(map[*ast.WindowFuncExpr]int) + for _, window := range sortWindowSpecs(groupedFuncs, orderedSpec) { + args = args[:0] + spec, funcs := window.spec, window.funcs + for _, windowFunc := range funcs { + args = append(args, windowFunc.Args...) + } + np, partitionBy, orderBy, args, err := b.buildProjectionForWindow(ctx, p, spec, args, aggMap) + if err != nil { + return nil, nil, err + } + if len(funcs) == 0 { + // len(funcs) == 0 indicates this an unused named window spec, + // so we just check for its validity and don't have to build plan for it. + err := b.checkOriginWindowSpec(spec, orderBy) + if err != nil { + return nil, nil, err + } + continue + } + err = b.checkOriginWindowFuncs(funcs, orderBy) + if err != nil { + return nil, nil, err + } + frame, err := b.buildWindowFunctionFrame(ctx, spec, orderBy) + if err != nil { + return nil, nil, err + } + + window := logicalop.LogicalWindow{ + PartitionBy: partitionBy, + OrderBy: orderBy, + Frame: frame, + }.Init(b.ctx, b.getSelectOffset()) + window.SetOutputNames(make([]*types.FieldName, np.Schema().Len())) + copy(window.OutputNames(), np.OutputNames()) + schema := np.Schema().Clone() + descs := make([]*aggregation.WindowFuncDesc, 0, len(funcs)) + preArgs := 0 + checker := &expression.ParamMarkerInPrepareChecker{} + for _, windowFunc := range funcs { + checker.InPrepareStmt = false + for _, expr := range windowFunc.Args { + expr.Accept(checker) + } + desc, err := aggregation.NewWindowFuncDesc(b.ctx.GetExprCtx(), windowFunc.Name, args[preArgs:preArgs+len(windowFunc.Args)], checker.InPrepareStmt) + if err != nil { + return nil, nil, err + } + if desc == nil { + return nil, nil, plannererrors.ErrWrongArguments.GenWithStackByArgs(strings.ToLower(windowFunc.Name)) + } + preArgs += len(windowFunc.Args) + desc.WrapCastForAggArgs(b.ctx.GetExprCtx()) + descs = append(descs, desc) + windowMap[windowFunc] = schema.Len() + schema.Append(&expression.Column{ + UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), + RetType: desc.RetTp, + }) + window.SetOutputNames(append(window.OutputNames(), types.EmptyName)) + } + window.WindowFuncDescs = descs + window.SetChildren(np) + window.SetSchema(schema) + p = window + } + return p, windowMap, nil +} + +// checkOriginWindowFuncs checks the validity for original window specifications for a group of functions. +// Because the grouped specification is different from them, we should especially check them before build window frame. 
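+// For instance (illustrative), a window function using IGNORE NULLS, DISTINCT or FROM LAST is
+// rejected here with a "not supported yet" error before any frame is built.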
+func (b *PlanBuilder) checkOriginWindowFuncs(funcs []*ast.WindowFuncExpr, orderByItems []property.SortItem) error { + for _, f := range funcs { + if f.IgnoreNull { + return plannererrors.ErrNotSupportedYet.GenWithStackByArgs("IGNORE NULLS") + } + if f.Distinct { + return plannererrors.ErrNotSupportedYet.GenWithStackByArgs("(DISTINCT ..)") + } + if f.FromLast { + return plannererrors.ErrNotSupportedYet.GenWithStackByArgs("FROM LAST") + } + spec := &f.Spec + if f.Spec.Name.L != "" { + spec = b.windowSpecs[f.Spec.Name.L] + } + if err := b.checkOriginWindowSpec(spec, orderByItems); err != nil { + return err + } + } + return nil +} + +// checkOriginWindowSpec checks the validity for given window specification. +func (b *PlanBuilder) checkOriginWindowSpec(spec *ast.WindowSpec, orderByItems []property.SortItem) error { + if spec.Frame == nil { + return nil + } + if spec.Frame.Type == ast.Groups { + return plannererrors.ErrNotSupportedYet.GenWithStackByArgs("GROUPS") + } + start, end := spec.Frame.Extent.Start, spec.Frame.Extent.End + if start.Type == ast.Following && start.UnBounded { + return plannererrors.ErrWindowFrameStartIllegal.GenWithStackByArgs(getWindowName(spec.Name.O)) + } + if end.Type == ast.Preceding && end.UnBounded { + return plannererrors.ErrWindowFrameEndIllegal.GenWithStackByArgs(getWindowName(spec.Name.O)) + } + if start.Type == ast.Following && (end.Type == ast.Preceding || end.Type == ast.CurrentRow) { + return plannererrors.ErrWindowFrameIllegal.GenWithStackByArgs(getWindowName(spec.Name.O)) + } + if (start.Type == ast.Following || start.Type == ast.CurrentRow) && end.Type == ast.Preceding { + return plannererrors.ErrWindowFrameIllegal.GenWithStackByArgs(getWindowName(spec.Name.O)) + } + + err := b.checkOriginWindowFrameBound(&start, spec, orderByItems) + if err != nil { + return err + } + err = b.checkOriginWindowFrameBound(&end, spec, orderByItems) + if err != nil { + return err + } + return nil +} + +func (b *PlanBuilder) checkOriginWindowFrameBound(bound *ast.FrameBound, spec *ast.WindowSpec, orderByItems []property.SortItem) error { + if bound.Type == ast.CurrentRow || bound.UnBounded { + return nil + } + + frameType := spec.Frame.Type + if frameType == ast.Rows { + if bound.Unit != ast.TimeUnitInvalid { + return plannererrors.ErrWindowRowsIntervalUse.GenWithStackByArgs(getWindowName(spec.Name.O)) + } + _, isNull, isExpectedType := getUintFromNode(b.ctx, bound.Expr, false) + if isNull || !isExpectedType { + return plannererrors.ErrWindowFrameIllegal.GenWithStackByArgs(getWindowName(spec.Name.O)) + } + return nil + } + + if len(orderByItems) != 1 { + return plannererrors.ErrWindowRangeFrameOrderType.GenWithStackByArgs(getWindowName(spec.Name.O)) + } + orderItemType := orderByItems[0].Col.RetType.GetType() + isNumeric, isTemporal := types.IsTypeNumeric(orderItemType), types.IsTypeTemporal(orderItemType) + if !isNumeric && !isTemporal { + return plannererrors.ErrWindowRangeFrameOrderType.GenWithStackByArgs(getWindowName(spec.Name.O)) + } + if bound.Unit != ast.TimeUnitInvalid && !isTemporal { + return plannererrors.ErrWindowRangeFrameNumericType.GenWithStackByArgs(getWindowName(spec.Name.O)) + } + if bound.Unit == ast.TimeUnitInvalid && !isNumeric { + return plannererrors.ErrWindowRangeFrameTemporalType.GenWithStackByArgs(getWindowName(spec.Name.O)) + } + return nil +} + +func extractWindowFuncs(fields []*ast.SelectField) []*ast.WindowFuncExpr { + extractor := &WindowFuncExtractor{} + for _, f := range fields { + n, _ := f.Expr.Accept(extractor) + f.Expr = 
n.(ast.ExprNode) + } + return extractor.windowFuncs +} + +func (b *PlanBuilder) handleDefaultFrame(spec *ast.WindowSpec, windowFuncName string) (*ast.WindowSpec, bool) { + needFrame := aggregation.NeedFrame(windowFuncName) + // According to MySQL, In the absence of a frame clause, the default frame depends on whether an ORDER BY clause is present: + // (1) With order by, the default frame is equivalent to "RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"; + // (2) Without order by, the default frame is includes all partition rows, equivalent to "RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", + // or "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", which is the same as an empty frame. + // https://dev.mysql.com/doc/refman/8.0/en/window-functions-frames.html + if needFrame && spec.Frame == nil && spec.OrderBy != nil { + newSpec := *spec + newSpec.Frame = &ast.FrameClause{ + Type: ast.Ranges, + Extent: ast.FrameExtent{ + Start: ast.FrameBound{Type: ast.Preceding, UnBounded: true}, + End: ast.FrameBound{Type: ast.CurrentRow}, + }, + } + return &newSpec, true + } + // "RANGE/ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" is equivalent to empty frame. + if needFrame && spec.Frame != nil && + spec.Frame.Extent.Start.UnBounded && spec.Frame.Extent.End.UnBounded { + newSpec := *spec + newSpec.Frame = nil + return &newSpec, true + } + if !needFrame { + var updated bool + newSpec := *spec + + // For functions that operate on the entire partition, the frame clause will be ignored. + if spec.Frame != nil { + specName := spec.Name.O + b.ctx.GetSessionVars().StmtCtx.AppendNote(plannererrors.ErrWindowFunctionIgnoresFrame.FastGenByArgs(windowFuncName, getWindowName(specName))) + newSpec.Frame = nil + updated = true + } + if b.ctx.GetSessionVars().EnablePipelinedWindowExec { + useDefaultFrame, defaultFrame := aggregation.UseDefaultFrame(windowFuncName) + if useDefaultFrame { + newSpec.Frame = &defaultFrame + updated = true + } + } + if updated { + return &newSpec, true + } + } + return spec, false +} + +// append ast.WindowSpec to []*ast.WindowSpec if absent +func appendIfAbsentWindowSpec(specs []*ast.WindowSpec, ns *ast.WindowSpec) []*ast.WindowSpec { + for _, spec := range specs { + if spec == ns { + return specs + } + } + return append(specs, ns) +} + +func specEqual(s1, s2 *ast.WindowSpec) (equal bool, err error) { + if (s1 == nil && s2 != nil) || (s1 != nil && s2 == nil) { + return false, nil + } + var sb1, sb2 strings.Builder + ctx1 := format.NewRestoreCtx(0, &sb1) + ctx2 := format.NewRestoreCtx(0, &sb2) + if err = s1.Restore(ctx1); err != nil { + return + } + if err = s2.Restore(ctx2); err != nil { + return + } + return sb1.String() == sb2.String(), nil +} + +// groupWindowFuncs groups the window functions according to the window specification name. +// TODO: We can group the window function by the definition of window specification. +func (b *PlanBuilder) groupWindowFuncs(windowFuncs []*ast.WindowFuncExpr) (map[*ast.WindowSpec][]*ast.WindowFuncExpr, []*ast.WindowSpec, error) { + // updatedSpecMap is used to handle the specifications that have frame clause changed. 
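+	// (Illustrative note) The key is the lower-cased window name and the value keeps every rewritten
+	// variant of that named spec, since handleDefaultFrame may derive different default frames for
+	// different window functions referencing the same WINDOW definition.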
+ updatedSpecMap := make(map[string][]*ast.WindowSpec) + groupedWindow := make(map[*ast.WindowSpec][]*ast.WindowFuncExpr) + orderedSpec := make([]*ast.WindowSpec, 0, len(windowFuncs)) + for _, windowFunc := range windowFuncs { + if windowFunc.Spec.Name.L == "" { + spec := &windowFunc.Spec + if spec.Ref.L != "" { + ref, ok := b.windowSpecs[spec.Ref.L] + if !ok { + return nil, nil, plannererrors.ErrWindowNoSuchWindow.GenWithStackByArgs(getWindowName(spec.Ref.O)) + } + err := mergeWindowSpec(spec, ref) + if err != nil { + return nil, nil, err + } + } + spec, _ = b.handleDefaultFrame(spec, windowFunc.Name) + groupedWindow[spec] = append(groupedWindow[spec], windowFunc) + orderedSpec = appendIfAbsentWindowSpec(orderedSpec, spec) + continue + } + + name := windowFunc.Spec.Name.L + spec, ok := b.windowSpecs[name] + if !ok { + return nil, nil, plannererrors.ErrWindowNoSuchWindow.GenWithStackByArgs(windowFunc.Spec.Name.O) + } + newSpec, updated := b.handleDefaultFrame(spec, windowFunc.Name) + if !updated { + groupedWindow[spec] = append(groupedWindow[spec], windowFunc) + orderedSpec = appendIfAbsentWindowSpec(orderedSpec, spec) + } else { + var updatedSpec *ast.WindowSpec + if _, ok := updatedSpecMap[name]; !ok { + updatedSpecMap[name] = []*ast.WindowSpec{newSpec} + updatedSpec = newSpec + } else { + for _, spec := range updatedSpecMap[name] { + eq, err := specEqual(spec, newSpec) + if err != nil { + return nil, nil, err + } + if eq { + updatedSpec = spec + break + } + } + if updatedSpec == nil { + updatedSpec = newSpec + updatedSpecMap[name] = append(updatedSpecMap[name], newSpec) + } + } + groupedWindow[updatedSpec] = append(groupedWindow[updatedSpec], windowFunc) + orderedSpec = appendIfAbsentWindowSpec(orderedSpec, updatedSpec) + } + } + // Unused window specs should also be checked in b.buildWindowFunctions, + // so we add them to `groupedWindow` with empty window functions. + for _, spec := range b.windowSpecs { + if _, ok := groupedWindow[spec]; !ok { + if _, ok = updatedSpecMap[spec.Name.L]; !ok { + groupedWindow[spec] = nil + orderedSpec = appendIfAbsentWindowSpec(orderedSpec, spec) + } + } + } + return groupedWindow, orderedSpec, nil +} + +// resolveWindowSpec resolve window specifications for sql like `select ... from t window w1 as (w2), w2 as (partition by a)`. +// We need to resolve the referenced window to get the definition of current window spec. 
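+// Circular references such as `window w1 as (w2), w2 as (w1)` (illustrative example) are detected
+// via the inStack map and reported as ErrWindowCircularityInWindowGraph.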
+func resolveWindowSpec(spec *ast.WindowSpec, specs map[string]*ast.WindowSpec, inStack map[string]bool) error { + if inStack[spec.Name.L] { + return errors.Trace(plannererrors.ErrWindowCircularityInWindowGraph) + } + if spec.Ref.L == "" { + return nil + } + ref, ok := specs[spec.Ref.L] + if !ok { + return plannererrors.ErrWindowNoSuchWindow.GenWithStackByArgs(spec.Ref.O) + } + inStack[spec.Name.L] = true + err := resolveWindowSpec(ref, specs, inStack) + if err != nil { + return err + } + inStack[spec.Name.L] = false + return mergeWindowSpec(spec, ref) +} + +func mergeWindowSpec(spec, ref *ast.WindowSpec) error { + if ref.Frame != nil { + return plannererrors.ErrWindowNoInherentFrame.GenWithStackByArgs(ref.Name.O) + } + if spec.PartitionBy != nil { + return errors.Trace(plannererrors.ErrWindowNoChildPartitioning) + } + if ref.OrderBy != nil { + if spec.OrderBy != nil { + return plannererrors.ErrWindowNoRedefineOrderBy.GenWithStackByArgs(getWindowName(spec.Name.O), ref.Name.O) + } + spec.OrderBy = ref.OrderBy + } + spec.PartitionBy = ref.PartitionBy + spec.Ref = model.NewCIStr("") + return nil +} + +func buildWindowSpecs(specs []ast.WindowSpec) (map[string]*ast.WindowSpec, error) { + specsMap := make(map[string]*ast.WindowSpec, len(specs)) + for _, spec := range specs { + if _, ok := specsMap[spec.Name.L]; ok { + return nil, plannererrors.ErrWindowDuplicateName.GenWithStackByArgs(spec.Name.O) + } + newSpec := spec + specsMap[spec.Name.L] = &newSpec + } + inStack := make(map[string]bool, len(specs)) + for _, spec := range specsMap { + err := resolveWindowSpec(spec, specsMap, inStack) + if err != nil { + return nil, err + } + } + return specsMap, nil +} + +type updatableTableListResolver struct { + updatableTableList []*ast.TableName +} + +func (*updatableTableListResolver) Enter(inNode ast.Node) (ast.Node, bool) { + switch v := inNode.(type) { + case *ast.UpdateStmt, *ast.TableRefsClause, *ast.Join, *ast.TableSource, *ast.TableName: + return v, false + } + return inNode, true +} + +func (u *updatableTableListResolver) Leave(inNode ast.Node) (ast.Node, bool) { + if v, ok := inNode.(*ast.TableSource); ok { + if s, ok := v.Source.(*ast.TableName); ok { + if v.AsName.L != "" { + newTableName := *s + newTableName.Name = v.AsName + newTableName.Schema = model.NewCIStr("") + u.updatableTableList = append(u.updatableTableList, &newTableName) + } else { + u.updatableTableList = append(u.updatableTableList, s) + } + } + } + return inNode, true +} + +// ExtractTableList is a wrapper for tableListExtractor and removes duplicate TableName +// If asName is true, extract AsName prior to OrigName. +// Privilege check should use OrigName, while expression may use AsName. +func ExtractTableList(node ast.Node, asName bool) []*ast.TableName { + if node == nil { + return []*ast.TableName{} + } + e := &tableListExtractor{ + asName: asName, + tableNames: []*ast.TableName{}, + } + node.Accept(e) + tableNames := e.tableNames + m := make(map[string]map[string]*ast.TableName) // k1: schemaName, k2: tableName, v: ast.TableName + for _, x := range tableNames { + k1, k2 := x.Schema.L, x.Name.L + // allow empty schema name OR empty table name + if k1 != "" || k2 != "" { + if _, ok := m[k1]; !ok { + m[k1] = make(map[string]*ast.TableName) + } + m[k1][k2] = x + } + } + tableNames = tableNames[:0] + for _, x := range m { + for _, v := range x { + tableNames = append(tableNames, v) + } + } + return tableNames +} + +// tableListExtractor extracts all the TableNames from node. 
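+// Besides plain table references, it also records schema-only entries for statements such as
+// CREATE/ALTER/DROP DATABASE, GRANT/REVOKE and USE, which carry a schema name but no table
+// (illustrative summary).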
+type tableListExtractor struct { + asName bool + tableNames []*ast.TableName +} + +func (e *tableListExtractor) Enter(n ast.Node) (_ ast.Node, skipChildren bool) { + innerExtract := func(inner ast.Node) []*ast.TableName { + if inner == nil { + return nil + } + innerExtractor := &tableListExtractor{ + asName: e.asName, + tableNames: []*ast.TableName{}, + } + inner.Accept(innerExtractor) + return innerExtractor.tableNames + } + + switch x := n.(type) { + case *ast.TableName: + e.tableNames = append(e.tableNames, x) + case *ast.TableSource: + if s, ok := x.Source.(*ast.TableName); ok { + if x.AsName.L != "" && e.asName { + newTableName := *s + newTableName.Name = x.AsName + newTableName.Schema = model.NewCIStr("") + e.tableNames = append(e.tableNames, &newTableName) + } else { + e.tableNames = append(e.tableNames, s) + } + } else if s, ok := x.Source.(*ast.SelectStmt); ok { + if s.From != nil { + innerList := innerExtract(s.From.TableRefs) + if len(innerList) > 0 { + innerTableName := innerList[0] + if x.AsName.L != "" && e.asName { + newTableName := *innerList[0] + newTableName.Name = x.AsName + newTableName.Schema = model.NewCIStr("") + innerTableName = &newTableName + } + e.tableNames = append(e.tableNames, innerTableName) + } + } + } + return n, true + + case *ast.ShowStmt: + if x.DBName != "" { + e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(x.DBName)}) + } + case *ast.CreateDatabaseStmt: + e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.Name}) + case *ast.AlterDatabaseStmt: + e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.Name}) + case *ast.DropDatabaseStmt: + e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.Name}) + + case *ast.FlashBackDatabaseStmt: + e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.DBName}) + e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(x.NewName)}) + case *ast.FlashBackToTimestampStmt: + if x.DBName.L != "" { + e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.DBName}) + } + case *ast.FlashBackTableStmt: + if newName := x.NewName; newName != "" { + e.tableNames = append(e.tableNames, &ast.TableName{ + Schema: x.Table.Schema, + Name: model.NewCIStr(newName)}) + } + + case *ast.GrantStmt: + if x.ObjectType == ast.ObjectTypeTable || x.ObjectType == ast.ObjectTypeNone { + if x.Level.Level == ast.GrantLevelDB || x.Level.Level == ast.GrantLevelTable { + e.tableNames = append(e.tableNames, &ast.TableName{ + Schema: model.NewCIStr(x.Level.DBName), + Name: model.NewCIStr(x.Level.TableName), + }) + } + } + case *ast.RevokeStmt: + if x.ObjectType == ast.ObjectTypeTable || x.ObjectType == ast.ObjectTypeNone { + if x.Level.Level == ast.GrantLevelDB || x.Level.Level == ast.GrantLevelTable { + e.tableNames = append(e.tableNames, &ast.TableName{ + Schema: model.NewCIStr(x.Level.DBName), + Name: model.NewCIStr(x.Level.TableName), + }) + } + } + case *ast.BRIEStmt: + if x.Kind == ast.BRIEKindBackup || x.Kind == ast.BRIEKindRestore { + for _, v := range x.Schemas { + e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(v)}) + } + } + case *ast.UseStmt: + e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(x.DBName)}) + case *ast.ExecuteStmt: + if v, ok := x.PrepStmt.(*PlanCacheStmt); ok { + e.tableNames = append(e.tableNames, innerExtract(v.PreparedAst.Stmt)...) 
+ } + } + return n, false +} + +func (*tableListExtractor) Leave(n ast.Node) (ast.Node, bool) { + return n, true +} + +func collectTableName(node ast.ResultSetNode, updatableName *map[string]bool, info *map[string]*ast.TableName) { + switch x := node.(type) { + case *ast.Join: + collectTableName(x.Left, updatableName, info) + collectTableName(x.Right, updatableName, info) + case *ast.TableSource: + name := x.AsName.L + var canUpdate bool + var s *ast.TableName + if s, canUpdate = x.Source.(*ast.TableName); canUpdate { + if name == "" { + name = s.Schema.L + "." + s.Name.L + // it may be a CTE + if s.Schema.L == "" { + name = s.Name.L + } + } + (*info)[name] = s + } + (*updatableName)[name] = canUpdate && s.Schema.L != "" + } +} + +func appendDynamicVisitInfo(vi []visitInfo, privs []string, withGrant bool, err error) []visitInfo { + return append(vi, visitInfo{ + privilege: mysql.ExtendedPriv, + dynamicPrivs: privs, + dynamicWithGrant: withGrant, + err: err, + }) +} + +func appendVisitInfo(vi []visitInfo, priv mysql.PrivilegeType, db, tbl, col string, err error) []visitInfo { + return append(vi, visitInfo{ + privilege: priv, + db: db, + table: tbl, + column: col, + err: err, + }) +} + +func getInnerFromParenthesesAndUnaryPlus(expr ast.ExprNode) ast.ExprNode { + if pexpr, ok := expr.(*ast.ParenthesesExpr); ok { + return getInnerFromParenthesesAndUnaryPlus(pexpr.Expr) + } + if uexpr, ok := expr.(*ast.UnaryOperationExpr); ok && uexpr.Op == opcode.Plus { + return getInnerFromParenthesesAndUnaryPlus(uexpr.V) + } + return expr +} + +// containDifferentJoinTypes checks whether `PreferJoinType` contains different +// join types. +func containDifferentJoinTypes(preferJoinType uint) bool { + preferJoinType &= ^h.PreferNoHashJoin + preferJoinType &= ^h.PreferNoMergeJoin + preferJoinType &= ^h.PreferNoIndexJoin + preferJoinType &= ^h.PreferNoIndexHashJoin + preferJoinType &= ^h.PreferNoIndexMergeJoin + + inlMask := h.PreferRightAsINLJInner ^ h.PreferLeftAsINLJInner + inlhjMask := h.PreferRightAsINLHJInner ^ h.PreferLeftAsINLHJInner + inlmjMask := h.PreferRightAsINLMJInner ^ h.PreferLeftAsINLMJInner + hjRightBuildMask := h.PreferRightAsHJBuild ^ h.PreferLeftAsHJProbe + hjLeftBuildMask := h.PreferLeftAsHJBuild ^ h.PreferRightAsHJProbe + + mppMask := h.PreferShuffleJoin ^ h.PreferBCJoin + mask := inlMask ^ inlhjMask ^ inlmjMask ^ hjRightBuildMask ^ hjLeftBuildMask + onesCount := bits.OnesCount(preferJoinType & ^mask & ^mppMask) + if onesCount > 1 || onesCount == 1 && preferJoinType&mask > 0 { + return true + } + + cnt := 0 + if preferJoinType&inlMask > 0 { + cnt++ + } + if preferJoinType&inlhjMask > 0 { + cnt++ + } + if preferJoinType&inlmjMask > 0 { + cnt++ + } + if preferJoinType&hjLeftBuildMask > 0 { + cnt++ + } + if preferJoinType&hjRightBuildMask > 0 { + cnt++ + } + return cnt > 1 +} + +func hasMPPJoinHints(preferJoinType uint) bool { + return (preferJoinType&h.PreferBCJoin > 0) || (preferJoinType&h.PreferShuffleJoin > 0) +} + +// isJoinHintSupportedInMPPMode is used to check if the specified join hint is available under MPP mode. +func isJoinHintSupportedInMPPMode(preferJoinType uint) bool { + if preferJoinType == 0 { + return true + } + mppMask := h.PreferShuffleJoin ^ h.PreferBCJoin + // Currently, TiFlash only supports HASH JOIN, so the hint for HASH JOIN is available while other join method hints are forbidden. 
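+	// (Illustrative note) Any join-method hint bit outside this supported set (and outside the
+	// shuffle/broadcast MPP bits) makes the combination unsupported, e.g. a MERGE_JOIN hint.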
+ joinMethodHintSupportedByTiflash := h.PreferHashJoin ^ h.PreferLeftAsHJBuild ^ h.PreferRightAsHJBuild ^ h.PreferLeftAsHJProbe ^ h.PreferRightAsHJProbe + onesCount := bits.OnesCount(preferJoinType & ^joinMethodHintSupportedByTiflash & ^mppMask) + return onesCount < 1 +} + +func (b *PlanBuilder) buildCte(ctx context.Context, cte *ast.CommonTableExpression, isRecursive bool) (p base.LogicalPlan, err error) { + saveBuildingCTE := b.buildingCTE + b.buildingCTE = true + defer func() { + b.buildingCTE = saveBuildingCTE + }() + + if isRecursive { + // buildingRecursivePartForCTE likes a stack. We save it before building a recursive CTE and restore it after building. + // We need a stack because we need to handle the nested recursive CTE. And buildingRecursivePartForCTE indicates the innermost CTE. + saveCheck := b.buildingRecursivePartForCTE + b.buildingRecursivePartForCTE = false + err = b.buildRecursiveCTE(ctx, cte.Query.Query) + if err != nil { + return nil, err + } + b.buildingRecursivePartForCTE = saveCheck + } else { + p, err = b.buildResultSetNode(ctx, cte.Query.Query, true) + if err != nil { + return nil, err + } + + p, err = b.adjustCTEPlanOutputName(p, cte) + if err != nil { + return nil, err + } + + cInfo := b.outerCTEs[len(b.outerCTEs)-1] + cInfo.seedLP = p + } + return nil, nil +} + +// buildRecursiveCTE handles the with clause `with recursive xxx as xx`. +func (b *PlanBuilder) buildRecursiveCTE(ctx context.Context, cte ast.ResultSetNode) error { + b.isCTE = true + cInfo := b.outerCTEs[len(b.outerCTEs)-1] + switch x := (cte).(type) { + case *ast.SetOprStmt: + // 1. Handle the WITH clause if exists. + if x.With != nil { + l := len(b.outerCTEs) + sw := x.With + defer func() { + b.outerCTEs = b.outerCTEs[:l] + x.With = sw + }() + _, err := b.buildWith(ctx, x.With) + if err != nil { + return err + } + } + // Set it to nil, so that when builds the seed part, it won't build again. Reset it in defer so that the AST doesn't change after this function. + x.With = nil + + // 2. Build plans for each part of SetOprStmt. + recursive := make([]base.LogicalPlan, 0) + tmpAfterSetOptsForRecur := []*ast.SetOprType{nil} + + expectSeed := true + for i := 0; i < len(x.SelectList.Selects); i++ { + var p base.LogicalPlan + var err error + + var afterOpr *ast.SetOprType + switch y := x.SelectList.Selects[i].(type) { + case *ast.SelectStmt: + p, err = b.buildSelect(ctx, y) + afterOpr = y.AfterSetOperator + case *ast.SetOprSelectList: + p, err = b.buildSetOpr(ctx, &ast.SetOprStmt{SelectList: y, With: y.With}) + afterOpr = y.AfterSetOperator + } + + if expectSeed { + if cInfo.useRecursive { + // 3. If it fail to build a plan, it may be the recursive part. Then we build the seed part plan, and rebuild it. + if i == 0 { + return plannererrors.ErrCTERecursiveRequiresNonRecursiveFirst.GenWithStackByArgs(cInfo.def.Name.String()) + } + + // It's the recursive part. Build the seed part, and build this recursive part again. + // Before we build the seed part, do some checks. + if x.OrderBy != nil { + return plannererrors.ErrNotSupportedYet.GenWithStackByArgs("ORDER BY over UNION in recursive Common Table Expression") + } + // Limit clause is for the whole CTE instead of only for the seed part. + oriLimit := x.Limit + x.Limit = nil + + // Check union type. 
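+					// (Illustrative note) Only UNION and UNION ALL may connect the seed part with the
+					// recursive part; UNION (DISTINCT) additionally sets cInfo.isDistinct for deduplication.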
+ if afterOpr != nil { + if *afterOpr != ast.Union && *afterOpr != ast.UnionAll { + return plannererrors.ErrNotSupportedYet.GenWithStackByArgs(fmt.Sprintf("%s between seed part and recursive part, hint: The operator between seed part and recursive part must bu UNION[DISTINCT] or UNION ALL", afterOpr.String())) + } + cInfo.isDistinct = *afterOpr == ast.Union + } + + expectSeed = false + cInfo.useRecursive = false + + // Build seed part plan. + saveSelect := x.SelectList.Selects + x.SelectList.Selects = x.SelectList.Selects[:i] + // We're rebuilding the seed part, so we pop the result we built previously. + for _i := 0; _i < i; _i++ { + b.handleHelper.popMap() + } + p, err = b.buildSetOpr(ctx, x) + if err != nil { + return err + } + x.SelectList.Selects = saveSelect + p, err = b.adjustCTEPlanOutputName(p, cInfo.def) + if err != nil { + return err + } + cInfo.seedLP = p + + // Rebuild the plan. + i-- + b.buildingRecursivePartForCTE = true + x.Limit = oriLimit + continue + } + if err != nil { + return err + } + } else { + if err != nil { + return err + } + if afterOpr != nil { + if *afterOpr != ast.Union && *afterOpr != ast.UnionAll { + return plannererrors.ErrNotSupportedYet.GenWithStackByArgs(fmt.Sprintf("%s between recursive part's selects, hint: The operator between recursive part's selects must bu UNION[DISTINCT] or UNION ALL", afterOpr.String())) + } + } + if !cInfo.useRecursive { + return plannererrors.ErrCTERecursiveRequiresNonRecursiveFirst.GenWithStackByArgs(cInfo.def.Name.String()) + } + cInfo.useRecursive = false + recursive = append(recursive, p) + tmpAfterSetOptsForRecur = append(tmpAfterSetOptsForRecur, afterOpr) + } + } + + if len(recursive) == 0 { + // In this case, even if SQL specifies "WITH RECURSIVE", the CTE is non-recursive. + p, err := b.buildSetOpr(ctx, x) + if err != nil { + return err + } + p, err = b.adjustCTEPlanOutputName(p, cInfo.def) + if err != nil { + return err + } + cInfo.seedLP = p + return nil + } + + // Build the recursive part's logical plan. + recurPart, err := b.buildUnion(ctx, recursive, tmpAfterSetOptsForRecur) + if err != nil { + return err + } + recurPart, err = b.buildProjection4CTEUnion(ctx, cInfo.seedLP, recurPart) + if err != nil { + return err + } + // 4. Finally, we get the seed part plan and recursive part plan. + cInfo.recurLP = recurPart + // Only need to handle limit if x is SetOprStmt. + if x.Limit != nil { + limit, err := b.buildLimit(cInfo.seedLP, x.Limit) + if err != nil { + return err + } + limit.SetChildren(limit.Children()[:0]...) + cInfo.limitLP = limit + } + return nil + default: + p, err := b.buildResultSetNode(ctx, x, true) + if err != nil { + // Refine the error message. 
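+		// (Illustrative note) A recursive reference inside a plain (non set-operation) CTE body lands
+		// here, and the error is surfaced as the "recursive CTE requires UNION" variant instead.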
+ if errors.ErrorEqual(err, plannererrors.ErrCTERecursiveRequiresNonRecursiveFirst) { + err = plannererrors.ErrCTERecursiveRequiresUnion.GenWithStackByArgs(cInfo.def.Name.String()) + } + return err + } + p, err = b.adjustCTEPlanOutputName(p, cInfo.def) + if err != nil { + return err + } + cInfo.seedLP = p + return nil + } +} + +func (b *PlanBuilder) adjustCTEPlanOutputName(p base.LogicalPlan, def *ast.CommonTableExpression) (base.LogicalPlan, error) { + outPutNames := p.OutputNames() + for _, name := range outPutNames { + name.TblName = def.Name + name.DBName = model.NewCIStr(b.ctx.GetSessionVars().CurrentDB) + } + if len(def.ColNameList) > 0 { + if len(def.ColNameList) != len(p.OutputNames()) { + return nil, dbterror.ErrViewWrongList + } + for i, n := range def.ColNameList { + outPutNames[i].ColName = n + } + } + p.SetOutputNames(outPutNames) + return p, nil +} + +// prepareCTECheckForSubQuery prepares the check that the recursive CTE can't be referenced in subQuery. It's used before building a subQuery. +// For example: with recursive cte(n) as (select 1 union select * from (select * from cte) c1) select * from cte; +func (b *PlanBuilder) prepareCTECheckForSubQuery() []*cteInfo { + modifiedCTE := make([]*cteInfo, 0) + for _, cte := range b.outerCTEs { + if cte.isBuilding && !cte.enterSubquery { + cte.enterSubquery = true + modifiedCTE = append(modifiedCTE, cte) + } + } + return modifiedCTE +} + +// resetCTECheckForSubQuery resets the related variable. It's used after leaving a subQuery. +func resetCTECheckForSubQuery(ci []*cteInfo) { + for _, cte := range ci { + cte.enterSubquery = false + } +} + +// genCTETableNameForError find the nearest CTE name. +func (b *PlanBuilder) genCTETableNameForError() string { + name := "" + for i := len(b.outerCTEs) - 1; i >= 0; i-- { + if b.outerCTEs[i].isBuilding { + name = b.outerCTEs[i].def.Name.String() + break + } + } + return name +} + +func (b *PlanBuilder) buildWith(ctx context.Context, w *ast.WithClause) ([]*cteInfo, error) { + // Check CTE name must be unique. + nameMap := make(map[string]struct{}) + for _, cte := range w.CTEs { + if _, ok := nameMap[cte.Name.L]; ok { + return nil, plannererrors.ErrNonUniqTable + } + nameMap[cte.Name.L] = struct{}{} + } + ctes := make([]*cteInfo, 0, len(w.CTEs)) + for _, cte := range w.CTEs { + b.outerCTEs = append(b.outerCTEs, &cteInfo{def: cte, nonRecursive: !w.IsRecursive, isBuilding: true, storageID: b.allocIDForCTEStorage, seedStat: &property.StatsInfo{}, consumerCount: cte.ConsumerCount}) + b.allocIDForCTEStorage++ + saveFlag := b.optFlag + // Init the flag to flagPrunColumns, otherwise it's missing. + b.optFlag = flagPrunColumns + if b.ctx.GetSessionVars().EnableForceInlineCTE() { + b.outerCTEs[len(b.outerCTEs)-1].forceInlineByHintOrVar = true + } + _, err := b.buildCte(ctx, cte, w.IsRecursive) + if err != nil { + return nil, err + } + b.outerCTEs[len(b.outerCTEs)-1].optFlag = b.optFlag + b.outerCTEs[len(b.outerCTEs)-1].isBuilding = false + b.optFlag = saveFlag + // each cte (select statement) will generate a handle map, pop it out here. 
+ b.handleHelper.popMap() + ctes = append(ctes, b.outerCTEs[len(b.outerCTEs)-1]) + } + return ctes, nil +} + +func (b *PlanBuilder) buildProjection4CTEUnion(_ context.Context, seed base.LogicalPlan, recur base.LogicalPlan) (base.LogicalPlan, error) { + if seed.Schema().Len() != recur.Schema().Len() { + return nil, plannererrors.ErrWrongNumberOfColumnsInSelect.GenWithStackByArgs() + } + exprs := make([]expression.Expression, len(seed.Schema().Columns)) + resSchema := getResultCTESchema(seed.Schema(), b.ctx.GetSessionVars()) + for i, col := range recur.Schema().Columns { + if !resSchema.Columns[i].RetType.Equal(col.RetType) { + exprs[i] = expression.BuildCastFunction4Union(b.ctx.GetExprCtx(), col, resSchema.Columns[i].RetType) + } else { + exprs[i] = col + } + } + b.optFlag |= flagEliminateProjection + proj := logicalop.LogicalProjection{Exprs: exprs, AvoidColumnEvaluator: true}.Init(b.ctx, b.getSelectOffset()) + proj.SetSchema(resSchema) + proj.SetChildren(recur) + return proj, nil +} + +// The recursive part/CTE's schema is nullable, and the UID should be unique. +func getResultCTESchema(seedSchema *expression.Schema, svar *variable.SessionVars) *expression.Schema { + res := seedSchema.Clone() + for _, col := range res.Columns { + col.RetType = col.RetType.Clone() + col.UniqueID = svar.AllocPlanColumnID() + col.RetType.DelFlag(mysql.NotNullFlag) + // Since you have reallocated unique id here, the old-cloned-cached hash code is not valid anymore. + col.CleanHashCode() + } + return res +} diff --git a/pkg/planner/core/optimizer.go b/pkg/planner/core/optimizer.go index ebdd4706dcd18..1ae8bdc33dbfc 100644 --- a/pkg/planner/core/optimizer.go +++ b/pkg/planner/core/optimizer.go @@ -652,14 +652,14 @@ func (h *fineGrainedShuffleHelper) updateTarget(t shuffleTarget, p *basePhysical // calculateTiFlashStreamCountUsingMinLogicalCores uses minimal logical cpu cores among tiflash servers, and divide by 2 // return false, 0 if any err happens func calculateTiFlashStreamCountUsingMinLogicalCores(ctx context.Context, sctx base.PlanContext, serversInfo []infoschema.ServerInfo) (bool, uint64) { - failpoint.Inject("mockTiFlashStreamCountUsingMinLogicalCores", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockTiFlashStreamCountUsingMinLogicalCores")); _err_ == nil { intVal, err := strconv.Atoi(val.(string)) if err == nil { - failpoint.Return(true, uint64(intVal)) + return true, uint64(intVal) } else { - failpoint.Return(false, 0) + return false, 0 } - }) + } rows, err := infoschema.FetchClusterServerInfoWithoutPrivilegeCheck(ctx, sctx.GetSessionVars(), serversInfo, diagnosticspb.ServerInfoType_HardwareInfo, false) if err != nil { return false, 0 diff --git a/pkg/planner/core/optimizer.go__failpoint_stash__ b/pkg/planner/core/optimizer.go__failpoint_stash__ new file mode 100644 index 0000000000000..ebdd4706dcd18 --- /dev/null +++ b/pkg/planner/core/optimizer.go__failpoint_stash__ @@ -0,0 +1,1222 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package core + +import ( + "cmp" + "context" + "fmt" + "math" + "runtime" + "slices" + "strconv" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/diagnosticspb" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/expression/aggregation" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/lock" + tablelock "github.com/pingcap/tidb/pkg/lock/context" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/auth" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/core/rule" + "github.com/pingcap/tidb/pkg/planner/property" + "github.com/pingcap/tidb/pkg/planner/util/debugtrace" + "github.com/pingcap/tidb/pkg/planner/util/optimizetrace" + "github.com/pingcap/tidb/pkg/privilege" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" + utilhint "github.com/pingcap/tidb/pkg/util/hint" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/set" + "github.com/pingcap/tidb/pkg/util/tracing" + "github.com/pingcap/tipb/go-tipb" + "go.uber.org/atomic" + "go.uber.org/zap" +) + +// OptimizeAstNode optimizes the query to a physical plan directly. +var OptimizeAstNode func(ctx context.Context, sctx sessionctx.Context, node ast.Node, is infoschema.InfoSchema) (base.Plan, types.NameSlice, error) + +// AllowCartesianProduct means whether tidb allows cartesian join without equal conditions. +var AllowCartesianProduct = atomic.NewBool(true) + +// IsReadOnly check whether the ast.Node is a read only statement. +var IsReadOnly func(node ast.Node, vars *variable.SessionVars) bool + +// Note: The order of flags is same as the order of optRule in the list. +// Do not mess up the order. 
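+// (Illustrative pairing) flagGcSubstitute corresponds to &GcSubstituter{}, flagPrunColumns to
+// &ColumnPruner{}, and so on down to flagResolveExpand and &ResolveExpand{} in optRuleList below.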
+const ( + flagGcSubstitute uint64 = 1 << iota + flagPrunColumns + flagStabilizeResults + flagBuildKeyInfo + flagDecorrelate + flagSemiJoinRewrite + flagEliminateAgg + flagSkewDistinctAgg + flagEliminateProjection + flagMaxMinEliminate + flagConstantPropagation + flagConvertOuterToInnerJoin + flagPredicatePushDown + flagEliminateOuterJoin + flagPartitionProcessor + flagCollectPredicateColumnsPoint + flagPushDownAgg + flagDeriveTopNFromWindow + flagPredicateSimplification + flagPushDownTopN + flagSyncWaitStatsLoadPoint + flagJoinReOrder + flagPrunColumnsAgain + flagPushDownSequence + flagResolveExpand +) + +var optRuleList = []base.LogicalOptRule{ + &GcSubstituter{}, + &ColumnPruner{}, + &ResultReorder{}, + &rule.BuildKeySolver{}, + &DecorrelateSolver{}, + &SemiJoinRewriter{}, + &AggregationEliminator{}, + &SkewDistinctAggRewriter{}, + &ProjectionEliminator{}, + &MaxMinEliminator{}, + &ConstantPropagationSolver{}, + &ConvertOuterToInnerJoin{}, + &PPDSolver{}, + &OuterJoinEliminator{}, + &PartitionProcessor{}, + &CollectPredicateColumnsPoint{}, + &AggregationPushDownSolver{}, + &DeriveTopNFromWindow{}, + &PredicateSimplification{}, + &PushDownTopNOptimizer{}, + &SyncWaitStatsLoadPoint{}, + &JoinReOrderSolver{}, + &ColumnPruner{}, // column pruning again at last, note it will mess up the results of buildKeySolver + &PushDownSequenceSolver{}, + &ResolveExpand{}, +} + +// Interaction Rule List +/* The interaction rule will be trigger when it satisfies following conditions: +1. The related rule has been trigger and changed the plan +2. The interaction rule is enabled +*/ +var optInteractionRuleList = map[base.LogicalOptRule]base.LogicalOptRule{} + +// BuildLogicalPlanForTest builds a logical plan for testing purpose from ast.Node. +func BuildLogicalPlanForTest(ctx context.Context, sctx sessionctx.Context, node ast.Node, infoSchema infoschema.InfoSchema) (base.Plan, error) { + sctx.GetSessionVars().PlanID.Store(0) + sctx.GetSessionVars().PlanColumnID.Store(0) + builder, _ := NewPlanBuilder().Init(sctx.GetPlanCtx(), infoSchema, utilhint.NewQBHintHandler(nil)) + p, err := builder.Build(ctx, node) + if err != nil { + return nil, err + } + if logic, ok := p.(base.LogicalPlan); ok { + RecheckCTE(logic) + } + return p, err +} + +// CheckPrivilege checks the privilege for a user. +func CheckPrivilege(activeRoles []*auth.RoleIdentity, pm privilege.Manager, vs []visitInfo) error { + for _, v := range vs { + if v.privilege == mysql.ExtendedPriv { + hasPriv := false + for _, priv := range v.dynamicPrivs { + hasPriv = hasPriv || pm.RequestDynamicVerification(activeRoles, priv, v.dynamicWithGrant) + if hasPriv { + break + } + } + if !hasPriv { + if v.err == nil { + return plannererrors.ErrPrivilegeCheckFail.GenWithStackByArgs(v.dynamicPrivs) + } + return v.err + } + } else if !pm.RequestVerification(activeRoles, v.db, v.table, v.column, v.privilege) { + if v.err == nil { + return plannererrors.ErrPrivilegeCheckFail.GenWithStackByArgs(v.privilege.String()) + } + return v.err + } + } + return nil +} + +// VisitInfo4PrivCheck generates privilege check infos because privilege check of local temporary tables is different +// with normal tables. `CREATE` statement needs `CREATE TEMPORARY TABLE` privilege from the database, and subsequent +// statements do not need any privileges. 
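+// For example (illustrative), `CREATE TEMPORARY TABLE t ...` is checked against the database-level
+// CREATE TEMPORARY TABLES privilege, while a later DROP of that local temporary table needs no
+// privilege at all.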
+func VisitInfo4PrivCheck(ctx context.Context, is infoschema.InfoSchema, node ast.Node, vs []visitInfo) (privVisitInfo []visitInfo) { + if node == nil { + return vs + } + + switch stmt := node.(type) { + case *ast.CreateTableStmt: + privVisitInfo = make([]visitInfo, 0, len(vs)) + for _, v := range vs { + if v.privilege == mysql.CreatePriv { + if stmt.TemporaryKeyword == ast.TemporaryLocal { + // `CREATE TEMPORARY TABLE` privilege is required from the database, not the table. + newVisitInfo := v + newVisitInfo.privilege = mysql.CreateTMPTablePriv + newVisitInfo.table = "" + privVisitInfo = append(privVisitInfo, newVisitInfo) + } else { + // If both the normal table and temporary table already exist, we need to check the privilege. + privVisitInfo = append(privVisitInfo, v) + } + } else { + // `CREATE TABLE LIKE tmp` or `CREATE TABLE FROM SELECT tmp` in the future. + if needCheckTmpTablePriv(ctx, is, v) { + privVisitInfo = append(privVisitInfo, v) + } + } + } + case *ast.DropTableStmt: + // Dropping a local temporary table doesn't need any privileges. + if stmt.IsView { + privVisitInfo = vs + } else { + privVisitInfo = make([]visitInfo, 0, len(vs)) + if stmt.TemporaryKeyword != ast.TemporaryLocal { + for _, v := range vs { + if needCheckTmpTablePriv(ctx, is, v) { + privVisitInfo = append(privVisitInfo, v) + } + } + } + } + case *ast.GrantStmt, *ast.DropSequenceStmt, *ast.DropPlacementPolicyStmt: + // Some statements ignore local temporary tables, so they should check the privileges on normal tables. + privVisitInfo = vs + default: + privVisitInfo = make([]visitInfo, 0, len(vs)) + for _, v := range vs { + if needCheckTmpTablePriv(ctx, is, v) { + privVisitInfo = append(privVisitInfo, v) + } + } + } + return +} + +func needCheckTmpTablePriv(ctx context.Context, is infoschema.InfoSchema, v visitInfo) bool { + if v.db != "" && v.table != "" { + // Other statements on local temporary tables except `CREATE` do not check any privileges. + tb, err := is.TableByName(ctx, model.NewCIStr(v.db), model.NewCIStr(v.table)) + // If the table doesn't exist, we do not report errors to avoid leaking the existence of the table. + if err == nil && tb.Meta().TempTableType == model.TempTableLocal { + return false + } + } + return true +} + +// CheckTableLock checks the table lock. +func CheckTableLock(ctx tablelock.TableLockReadContext, is infoschema.InfoSchema, vs []visitInfo) error { + if !config.TableLockEnabled() { + return nil + } + + checker := lock.NewChecker(ctx, is) + for i := range vs { + err := checker.CheckTableLock(vs[i].db, vs[i].table, vs[i].privilege, vs[i].alterWritable) + // if table with lock-write table dropped, we can access other table, such as `rename` operation + if err == lock.ErrLockedTableDropped { + break + } + if err != nil { + return err + } + } + return nil +} + +func checkStableResultMode(sctx base.PlanContext) bool { + s := sctx.GetSessionVars() + st := s.StmtCtx + return s.EnableStableResultMode && (!st.InInsertStmt && !st.InUpdateStmt && !st.InDeleteStmt && !st.InLoadDataStmt) +} + +// doOptimize optimizes a logical plan into a physical plan, +// while also returning the optimized logical plan, the final physical plan, and the cost of the final plan. +// The returned logical plan is necessary for generating plans for Common Table Expressions (CTEs). 
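+// Roughly (illustrative summary): logicalOptimize applies the enabled rule flags, physicalOptimize
+// searches for the cheapest physical plan, and postOptimize runs rewrites such as projection
+// elimination and fine-grained shuffle setup.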
+func doOptimize( + ctx context.Context, + sctx base.PlanContext, + flag uint64, + logic base.LogicalPlan, +) (base.LogicalPlan, base.PhysicalPlan, float64, error) { + sessVars := sctx.GetSessionVars() + flag = adjustOptimizationFlags(flag, logic) + logic, err := logicalOptimize(ctx, flag, logic) + if err != nil { + return nil, nil, 0, err + } + + if !AllowCartesianProduct.Load() && existsCartesianProduct(logic) { + return nil, nil, 0, errors.Trace(plannererrors.ErrCartesianProductUnsupported) + } + planCounter := base.PlanCounterTp(sessVars.StmtCtx.StmtHints.ForceNthPlan) + if planCounter == 0 { + planCounter = -1 + } + physical, cost, err := physicalOptimize(logic, &planCounter) + if err != nil { + return nil, nil, 0, err + } + finalPlan := postOptimize(ctx, sctx, physical) + + if sessVars.StmtCtx.EnableOptimizerCETrace { + refineCETrace(sctx) + } + if sessVars.StmtCtx.EnableOptimizeTrace { + sessVars.StmtCtx.OptimizeTracer.RecordFinalPlan(finalPlan.BuildPlanTrace()) + } + return logic, finalPlan, cost, nil +} + +func adjustOptimizationFlags(flag uint64, logic base.LogicalPlan) uint64 { + // If there is something after flagPrunColumns, do flagPrunColumnsAgain. + if flag&flagPrunColumns > 0 && flag-flagPrunColumns > flagPrunColumns { + flag |= flagPrunColumnsAgain + } + if checkStableResultMode(logic.SCtx()) { + flag |= flagStabilizeResults + } + if logic.SCtx().GetSessionVars().StmtCtx.StraightJoinOrder { + // When we use the straight Join Order hint, we should disable the join reorder optimization. + flag &= ^flagJoinReOrder + } + flag |= flagCollectPredicateColumnsPoint + flag |= flagSyncWaitStatsLoadPoint + if !logic.SCtx().GetSessionVars().StmtCtx.UseDynamicPruneMode { + flag |= flagPartitionProcessor // apply partition pruning under static mode + } + return flag +} + +// DoOptimize optimizes a logical plan to a physical plan. +func DoOptimize( + ctx context.Context, + sctx base.PlanContext, + flag uint64, + logic base.LogicalPlan, +) (base.PhysicalPlan, float64, error) { + sessVars := sctx.GetSessionVars() + if sessVars.StmtCtx.EnableOptimizerDebugTrace { + debugtrace.EnterContextCommon(sctx) + defer debugtrace.LeaveContextCommon(sctx) + } + _, finalPlan, cost, err := doOptimize(ctx, sctx, flag, logic) + return finalPlan, cost, err +} + +// refineCETrace will adjust the content of CETrace. +// Currently, it will (1) deduplicate trace records, (2) sort the trace records (to make it easier in the tests) and (3) fill in the table name. 
+func refineCETrace(sctx base.PlanContext) { + stmtCtx := sctx.GetSessionVars().StmtCtx + stmtCtx.OptimizerCETrace = tracing.DedupCETrace(stmtCtx.OptimizerCETrace) + slices.SortFunc(stmtCtx.OptimizerCETrace, func(i, j *tracing.CETraceRecord) int { + if i == nil && j != nil { + return -1 + } + if i == nil || j == nil { + return 1 + } + + if c := cmp.Compare(i.TableID, j.TableID); c != 0 { + return c + } + if c := cmp.Compare(i.Type, j.Type); c != 0 { + return c + } + if c := cmp.Compare(i.Expr, j.Expr); c != 0 { + return c + } + return cmp.Compare(i.RowCount, j.RowCount) + }) + traceRecords := stmtCtx.OptimizerCETrace + is := sctx.GetDomainInfoSchema().(infoschema.InfoSchema) + for _, rec := range traceRecords { + tbl, _ := infoschema.FindTableByTblOrPartID(is, rec.TableID) + if tbl != nil { + rec.TableName = tbl.Meta().Name.O + continue + } + logutil.BgLogger().Warn("Failed to find table in infoschema", zap.String("category", "OptimizerTrace"), + zap.Int64("table id", rec.TableID)) + } +} + +// mergeContinuousSelections merge continuous selections which may occur after changing plans. +func mergeContinuousSelections(p base.PhysicalPlan) { + if sel, ok := p.(*PhysicalSelection); ok { + for { + childSel := sel.children[0] + tmp, ok := childSel.(*PhysicalSelection) + if !ok { + break + } + sel.Conditions = append(sel.Conditions, tmp.Conditions...) + sel.SetChild(0, tmp.children[0]) + } + } + for _, child := range p.Children() { + mergeContinuousSelections(child) + } + // merge continuous selections in a coprocessor task of tiflash + tableReader, isTableReader := p.(*PhysicalTableReader) + if isTableReader && tableReader.StoreType == kv.TiFlash { + mergeContinuousSelections(tableReader.tablePlan) + tableReader.TablePlans = flattenPushDownPlan(tableReader.tablePlan) + } +} + +func postOptimize(ctx context.Context, sctx base.PlanContext, plan base.PhysicalPlan) base.PhysicalPlan { + // some cases from update optimize will require avoiding projection elimination. + // see comments ahead of call of DoOptimize in function of buildUpdate(). + plan = eliminatePhysicalProjection(plan) + plan = InjectExtraProjection(plan) + mergeContinuousSelections(plan) + plan = eliminateUnionScanAndLock(sctx, plan) + plan = enableParallelApply(sctx, plan) + handleFineGrainedShuffle(ctx, sctx, plan) + propagateProbeParents(plan, nil) + countStarRewrite(plan) + disableReuseChunkIfNeeded(sctx, plan) + tryEnableLateMaterialization(sctx, plan) + generateRuntimeFilter(sctx, plan) + return plan +} + +func generateRuntimeFilter(sctx base.PlanContext, plan base.PhysicalPlan) { + if !sctx.GetSessionVars().IsRuntimeFilterEnabled() || sctx.GetSessionVars().InRestrictedSQL { + return + } + logutil.BgLogger().Debug("Start runtime filter generator") + rfGenerator := &RuntimeFilterGenerator{ + rfIDGenerator: &util.IDGenerator{}, + columnUniqueIDToRF: map[int64][]*RuntimeFilter{}, + parentPhysicalPlan: plan, + } + startRFGenerator := time.Now() + rfGenerator.GenerateRuntimeFilter(plan) + logutil.BgLogger().Debug("Finish runtime filter generator", + zap.Duration("Cost", time.Since(startRFGenerator))) +} + +// tryEnableLateMaterialization tries to push down some filter conditions to the table scan operator +// @brief: push down some filter conditions to the table scan operator +// @param: sctx: session context +// @param: plan: the physical plan to be pruned +// @note: this optimization is only applied when the TiFlash is used. 
+// @note: the following conditions should be satisfied: +// - Only the filter conditions with high selectivity should be pushed down. +// - The filter conditions which contain heavy cost functions should not be pushed down. +// - Filter conditions that apply to the same column are either pushed down or not pushed down at all. +func tryEnableLateMaterialization(sctx base.PlanContext, plan base.PhysicalPlan) { + // check if EnableLateMaterialization is set + if sctx.GetSessionVars().EnableLateMaterialization && !sctx.GetSessionVars().TiFlashFastScan { + predicatePushDownToTableScan(sctx, plan) + } + if sctx.GetSessionVars().EnableLateMaterialization && sctx.GetSessionVars().TiFlashFastScan { + sc := sctx.GetSessionVars().StmtCtx + sc.AppendWarning(errors.NewNoStackError("FastScan is not compatible with late materialization, late materialization is disabled")) + } +} + +/* +* +The countStarRewriter is used to rewrite + + count(*) -> count(not null column) + +**Only for TiFlash** +Attention: +Since count(*) is directly translated into count(1) during grammar parsing, +the rewritten pattern actually matches count(constant) + +Pattern: +PhysicalAggregation: count(constant) + + | + TableFullScan: TiFlash + +Optimize: +Table + + + +Query: select count(*) from table +ColumnPruningRule: datasource pick row_id +countStarRewrite: datasource pick k1 instead of row_id + + rewrite count(*) -> count(k1) + +Rewritten Query: select count(k1) from table +*/ +func countStarRewrite(plan base.PhysicalPlan) { + countStarRewriteInternal(plan) + if tableReader, ok := plan.(*PhysicalTableReader); ok { + countStarRewrite(tableReader.tablePlan) + } else { + for _, child := range plan.Children() { + countStarRewrite(child) + } + } +} + +func countStarRewriteInternal(plan base.PhysicalPlan) { + // match pattern any agg(count(constant)) -> tablefullscan(tiflash) + var physicalAgg *basePhysicalAgg + switch x := plan.(type) { + case *PhysicalHashAgg: + physicalAgg = x.getPointer() + case *PhysicalStreamAgg: + physicalAgg = x.getPointer() + default: + return + } + if len(physicalAgg.GroupByItems) > 0 || len(physicalAgg.children) != 1 { + return + } + for _, aggFunc := range physicalAgg.AggFuncs { + if aggFunc.Name != "count" || len(aggFunc.Args) != 1 || aggFunc.HasDistinct { + return + } + if _, ok := aggFunc.Args[0].(*expression.Constant); !ok { + return + } + } + physicalTableScan, ok := physicalAgg.Children()[0].(*PhysicalTableScan) + if !ok || !physicalTableScan.isFullScan() || physicalTableScan.StoreType != kv.TiFlash || len(physicalTableScan.schema.Columns) != 1 { + return + } + // rewrite datasource and agg args + rewriteTableScanAndAggArgs(physicalTableScan, physicalAgg.AggFuncs) +} + +// rewriteTableScanAndAggArgs Pick the narrowest and not null column from table +// If there is no not null column in Data Source, the row_id or pk column will be retained +func rewriteTableScanAndAggArgs(physicalTableScan *PhysicalTableScan, aggFuncs []*aggregation.AggFuncDesc) { + var resultColumnInfo *model.ColumnInfo + var resultColumn *expression.Column + + resultColumnInfo = physicalTableScan.Columns[0] + resultColumn = physicalTableScan.schema.Columns[0] + // prefer not null column from table + for _, columnInfo := range physicalTableScan.Table.Columns { + if columnInfo.FieldType.IsVarLengthType() { + continue + } + if mysql.HasNotNullFlag(columnInfo.GetFlag()) { + if columnInfo.GetFlen() < resultColumnInfo.GetFlen() { + resultColumnInfo = columnInfo + resultColumn = &expression.Column{ + UniqueID: 
physicalTableScan.SCtx().GetSessionVars().AllocPlanColumnID(), + ID: resultColumnInfo.ID, + RetType: resultColumnInfo.FieldType.Clone(), + OrigName: fmt.Sprintf("%s.%s.%s", physicalTableScan.DBName.L, physicalTableScan.Table.Name.L, resultColumnInfo.Name), + } + } + } + } + // table scan (row_id) -> (not null column) + physicalTableScan.Columns[0] = resultColumnInfo + physicalTableScan.schema.Columns[0] = resultColumn + // agg arg count(1) -> count(not null column) + arg := resultColumn.Clone() + for _, aggFunc := range aggFuncs { + constExpr, ok := aggFunc.Args[0].(*expression.Constant) + if !ok { + return + } + // count(null) shouldn't be rewritten + if constExpr.Value.IsNull() { + continue + } + aggFunc.Args[0] = arg + } +} + +// Only for MPP(Window<-[Sort]<-ExchangeReceiver<-ExchangeSender). +// TiFlashFineGrainedShuffleStreamCount: +// < 0: fine grained shuffle is disabled. +// > 0: use TiFlashFineGrainedShuffleStreamCount as stream count. +// == 0: use TiFlashMaxThreads as stream count when it's greater than 0. Otherwise set status as uninitialized. +func handleFineGrainedShuffle(ctx context.Context, sctx base.PlanContext, plan base.PhysicalPlan) { + streamCount := sctx.GetSessionVars().TiFlashFineGrainedShuffleStreamCount + if streamCount < 0 { + return + } + if streamCount == 0 { + if sctx.GetSessionVars().TiFlashMaxThreads > 0 { + streamCount = sctx.GetSessionVars().TiFlashMaxThreads + } + } + // use two separate cluster info to avoid grpc calls cost + tiflashServerCountInfo := tiflashClusterInfo{unInitialized, 0} + streamCountInfo := tiflashClusterInfo{unInitialized, 0} + if streamCount != 0 { + streamCountInfo.itemStatus = initialized + streamCountInfo.itemValue = uint64(streamCount) + } + setupFineGrainedShuffle(ctx, sctx, &streamCountInfo, &tiflashServerCountInfo, plan) +} + +func setupFineGrainedShuffle(ctx context.Context, sctx base.PlanContext, streamCountInfo *tiflashClusterInfo, tiflashServerCountInfo *tiflashClusterInfo, plan base.PhysicalPlan) { + if tableReader, ok := plan.(*PhysicalTableReader); ok { + if _, isExchangeSender := tableReader.tablePlan.(*PhysicalExchangeSender); isExchangeSender { + helper := fineGrainedShuffleHelper{shuffleTarget: unknown, plans: make([]*basePhysicalPlan, 1)} + setupFineGrainedShuffleInternal(ctx, sctx, tableReader.tablePlan, &helper, streamCountInfo, tiflashServerCountInfo) + } + } else { + for _, child := range plan.Children() { + setupFineGrainedShuffle(ctx, sctx, streamCountInfo, tiflashServerCountInfo, child) + } + } +} + +type shuffleTarget uint8 + +const ( + unknown shuffleTarget = iota + window + joinBuild + hashAgg +) + +type fineGrainedShuffleHelper struct { + shuffleTarget shuffleTarget + plans []*basePhysicalPlan + joinKeysCount int +} + +type tiflashClusterInfoStatus uint8 + +const ( + unInitialized tiflashClusterInfoStatus = iota + initialized + failed +) + +type tiflashClusterInfo struct { + itemStatus tiflashClusterInfoStatus + itemValue uint64 +} + +func (h *fineGrainedShuffleHelper) clear() { + h.shuffleTarget = unknown + h.plans = h.plans[:0] + h.joinKeysCount = 0 +} + +func (h *fineGrainedShuffleHelper) updateTarget(t shuffleTarget, p *basePhysicalPlan) { + h.shuffleTarget = t + h.plans = append(h.plans, p) +} + +// calculateTiFlashStreamCountUsingMinLogicalCores uses minimal logical cpu cores among tiflash servers, and divide by 2 +// return false, 0 if any err happens +func calculateTiFlashStreamCountUsingMinLogicalCores(ctx context.Context, sctx base.PlanContext, serversInfo []infoschema.ServerInfo) (bool, uint64) 
{ + failpoint.Inject("mockTiFlashStreamCountUsingMinLogicalCores", func(val failpoint.Value) { + intVal, err := strconv.Atoi(val.(string)) + if err == nil { + failpoint.Return(true, uint64(intVal)) + } else { + failpoint.Return(false, 0) + } + }) + rows, err := infoschema.FetchClusterServerInfoWithoutPrivilegeCheck(ctx, sctx.GetSessionVars(), serversInfo, diagnosticspb.ServerInfoType_HardwareInfo, false) + if err != nil { + return false, 0 + } + var initialMaxCores uint64 = 10000 + var minLogicalCores = initialMaxCores // set to a large enough value here + for _, row := range rows { + if row[4].GetString() == "cpu-logical-cores" { + logicalCpus, err := strconv.Atoi(row[5].GetString()) + if err == nil && logicalCpus > 0 { + minLogicalCores = min(minLogicalCores, uint64(logicalCpus)) + } + } + } + // No need to check len(serersInfo) == serverCount here, since missing some servers' info won't affect the correctness + if minLogicalCores > 1 && minLogicalCores != initialMaxCores { + if runtime.GOARCH == "amd64" { + // In most x86-64 platforms, `Thread(s) per core` is 2 + return true, minLogicalCores / 2 + } + // ARM cpus don't implement Hyper-threading. + return true, minLogicalCores + // Other platforms are too rare to consider + } + + return false, 0 +} + +func checkFineGrainedShuffleForJoinAgg(ctx context.Context, sctx base.PlanContext, streamCountInfo *tiflashClusterInfo, tiflashServerCountInfo *tiflashClusterInfo, exchangeColCount int, splitLimit uint64) (applyFlag bool, streamCount uint64) { + switch (*streamCountInfo).itemStatus { + case unInitialized: + streamCount = 4 // assume 8c node in cluster as minimal, stream count is 8 / 2 = 4 + case initialized: + streamCount = (*streamCountInfo).itemValue + case failed: + return false, 0 // probably won't reach this path + } + + var tiflashServerCount uint64 + switch (*tiflashServerCountInfo).itemStatus { + case unInitialized: + serversInfo, err := infoschema.GetTiFlashServerInfo(sctx.GetStore()) + if err != nil { + (*tiflashServerCountInfo).itemStatus = failed + (*tiflashServerCountInfo).itemValue = 0 + if (*streamCountInfo).itemStatus == unInitialized { + setDefaultStreamCount(streamCountInfo) + } + return false, 0 + } + tiflashServerCount = uint64(len(serversInfo)) + (*tiflashServerCountInfo).itemStatus = initialized + (*tiflashServerCountInfo).itemValue = tiflashServerCount + case initialized: + tiflashServerCount = (*tiflashServerCountInfo).itemValue + case failed: + return false, 0 + } + + // if already exceeds splitLimit, no need to fetch actual logical cores + if tiflashServerCount*uint64(exchangeColCount)*streamCount > splitLimit { + return false, 0 + } + + // if streamCount already initialized, and can pass splitLimit check + if (*streamCountInfo).itemStatus == initialized { + return true, streamCount + } + + serversInfo, err := infoschema.GetTiFlashServerInfo(sctx.GetStore()) + if err != nil { + (*tiflashServerCountInfo).itemStatus = failed + (*tiflashServerCountInfo).itemValue = 0 + return false, 0 + } + flag, temStreamCount := calculateTiFlashStreamCountUsingMinLogicalCores(ctx, sctx, serversInfo) + if !flag { + setDefaultStreamCount(streamCountInfo) + (*tiflashServerCountInfo).itemStatus = failed + return false, 0 + } + streamCount = temStreamCount + (*streamCountInfo).itemStatus = initialized + (*streamCountInfo).itemValue = streamCount + applyFlag = tiflashServerCount*uint64(exchangeColCount)*streamCount <= splitLimit + return applyFlag, streamCount +} + +func inferFineGrainedShuffleStreamCountForWindow(ctx context.Context, 
sctx base.PlanContext, streamCountInfo *tiflashClusterInfo, tiflashServerCountInfo *tiflashClusterInfo) (streamCount uint64) { + switch (*streamCountInfo).itemStatus { + case unInitialized: + if (*tiflashServerCountInfo).itemStatus == failed { + setDefaultStreamCount(streamCountInfo) + streamCount = (*streamCountInfo).itemValue + break + } + + serversInfo, err := infoschema.GetTiFlashServerInfo(sctx.GetStore()) + if err != nil { + setDefaultStreamCount(streamCountInfo) + streamCount = (*streamCountInfo).itemValue + (*tiflashServerCountInfo).itemStatus = failed + break + } + + if (*tiflashServerCountInfo).itemStatus == unInitialized { + (*tiflashServerCountInfo).itemStatus = initialized + (*tiflashServerCountInfo).itemValue = uint64(len(serversInfo)) + } + + flag, temStreamCount := calculateTiFlashStreamCountUsingMinLogicalCores(ctx, sctx, serversInfo) + if !flag { + setDefaultStreamCount(streamCountInfo) + streamCount = (*streamCountInfo).itemValue + (*tiflashServerCountInfo).itemStatus = failed + break + } + streamCount = temStreamCount + (*streamCountInfo).itemStatus = initialized + (*streamCountInfo).itemValue = streamCount + case initialized: + streamCount = (*streamCountInfo).itemValue + case failed: + setDefaultStreamCount(streamCountInfo) + streamCount = (*streamCountInfo).itemValue + } + return streamCount +} + +func setDefaultStreamCount(streamCountInfo *tiflashClusterInfo) { + (*streamCountInfo).itemStatus = initialized + (*streamCountInfo).itemValue = variable.DefStreamCountWhenMaxThreadsNotSet +} + +func setupFineGrainedShuffleInternal(ctx context.Context, sctx base.PlanContext, plan base.PhysicalPlan, helper *fineGrainedShuffleHelper, streamCountInfo *tiflashClusterInfo, tiflashServerCountInfo *tiflashClusterInfo) { + switch x := plan.(type) { + case *PhysicalWindow: + // Do not clear the plans because window executor will keep the data partition. + // For non hash partition window function, there will be a passthrough ExchangeSender to collect data, + // which will break data partition. + helper.updateTarget(window, &x.basePhysicalPlan) + setupFineGrainedShuffleInternal(ctx, sctx, x.children[0], helper, streamCountInfo, tiflashServerCountInfo) + case *PhysicalSort: + if x.IsPartialSort { + // Partial sort will keep the data partition. + helper.plans = append(helper.plans, &x.basePhysicalPlan) + } else { + // Global sort will break the data partition. 
+ helper.clear() + } + setupFineGrainedShuffleInternal(ctx, sctx, x.children[0], helper, streamCountInfo, tiflashServerCountInfo) + case *PhysicalSelection: + helper.plans = append(helper.plans, &x.basePhysicalPlan) + setupFineGrainedShuffleInternal(ctx, sctx, x.children[0], helper, streamCountInfo, tiflashServerCountInfo) + case *PhysicalProjection: + helper.plans = append(helper.plans, &x.basePhysicalPlan) + setupFineGrainedShuffleInternal(ctx, sctx, x.children[0], helper, streamCountInfo, tiflashServerCountInfo) + case *PhysicalExchangeReceiver: + helper.plans = append(helper.plans, &x.basePhysicalPlan) + setupFineGrainedShuffleInternal(ctx, sctx, x.children[0], helper, streamCountInfo, tiflashServerCountInfo) + case *PhysicalHashAgg: + // Todo: allow hash aggregation's output still benefits from fine grained shuffle + aggHelper := fineGrainedShuffleHelper{shuffleTarget: hashAgg, plans: []*basePhysicalPlan{}} + aggHelper.plans = append(aggHelper.plans, &x.basePhysicalPlan) + setupFineGrainedShuffleInternal(ctx, sctx, x.children[0], &aggHelper, streamCountInfo, tiflashServerCountInfo) + case *PhysicalHashJoin: + child0 := x.children[0] + child1 := x.children[1] + buildChild := child0 + probChild := child1 + joinKeys := x.LeftJoinKeys + if x.InnerChildIdx != 0 { + // Child1 is build side. + buildChild = child1 + joinKeys = x.RightJoinKeys + probChild = child0 + } + if len(joinKeys) > 0 { // Not cross join + buildHelper := fineGrainedShuffleHelper{shuffleTarget: joinBuild, plans: []*basePhysicalPlan{}} + buildHelper.plans = append(buildHelper.plans, &x.basePhysicalPlan) + buildHelper.joinKeysCount = len(joinKeys) + setupFineGrainedShuffleInternal(ctx, sctx, buildChild, &buildHelper, streamCountInfo, tiflashServerCountInfo) + } else { + buildHelper := fineGrainedShuffleHelper{shuffleTarget: unknown, plans: []*basePhysicalPlan{}} + setupFineGrainedShuffleInternal(ctx, sctx, buildChild, &buildHelper, streamCountInfo, tiflashServerCountInfo) + } + // don't apply fine grained shuffle for probe side + helper.clear() + setupFineGrainedShuffleInternal(ctx, sctx, probChild, helper, streamCountInfo, tiflashServerCountInfo) + case *PhysicalExchangeSender: + if x.ExchangeType == tipb.ExchangeType_Hash { + // Set up stream count for all plans based on shuffle target type. 
+ var exchangeColCount = x.Schema().Len() + switch helper.shuffleTarget { + case window: + streamCount := inferFineGrainedShuffleStreamCountForWindow(ctx, sctx, streamCountInfo, tiflashServerCountInfo) + x.TiFlashFineGrainedShuffleStreamCount = streamCount + for _, p := range helper.plans { + p.TiFlashFineGrainedShuffleStreamCount = streamCount + } + case hashAgg: + applyFlag, streamCount := checkFineGrainedShuffleForJoinAgg(ctx, sctx, streamCountInfo, tiflashServerCountInfo, exchangeColCount, 1200) // 1200: performance test result + if applyFlag { + x.TiFlashFineGrainedShuffleStreamCount = streamCount + for _, p := range helper.plans { + p.TiFlashFineGrainedShuffleStreamCount = streamCount + } + } + case joinBuild: + // Support hashJoin only when shuffle hash keys equals to join keys due to tiflash implementations + if len(x.HashCols) != helper.joinKeysCount { + break + } + applyFlag, streamCount := checkFineGrainedShuffleForJoinAgg(ctx, sctx, streamCountInfo, tiflashServerCountInfo, exchangeColCount, 600) // 600: performance test result + if applyFlag { + x.TiFlashFineGrainedShuffleStreamCount = streamCount + for _, p := range helper.plans { + p.TiFlashFineGrainedShuffleStreamCount = streamCount + } + } + } + } + // exchange sender will break the data partition. + helper.clear() + setupFineGrainedShuffleInternal(ctx, sctx, x.children[0], helper, streamCountInfo, tiflashServerCountInfo) + default: + for _, child := range x.Children() { + childHelper := fineGrainedShuffleHelper{shuffleTarget: unknown, plans: []*basePhysicalPlan{}} + setupFineGrainedShuffleInternal(ctx, sctx, child, &childHelper, streamCountInfo, tiflashServerCountInfo) + } + } +} + +// propagateProbeParents doesn't affect the execution plan, it only sets the probeParents field of a PhysicalPlan. +// It's for handling the inconsistency between row count in the statsInfo and the recorded actual row count. Please +// see comments in PhysicalPlan for details. +func propagateProbeParents(plan base.PhysicalPlan, probeParents []base.PhysicalPlan) { + plan.SetProbeParents(probeParents) + switch x := plan.(type) { + case *PhysicalApply, *PhysicalIndexJoin, *PhysicalIndexHashJoin, *PhysicalIndexMergeJoin: + if join, ok := plan.(interface{ getInnerChildIdx() int }); ok { + propagateProbeParents(plan.Children()[1-join.getInnerChildIdx()], probeParents) + + // The core logic of this method: + // Record every Apply and Index Join we met, record it in a slice, and set it in their inner children. + newParents := make([]base.PhysicalPlan, len(probeParents), len(probeParents)+1) + copy(newParents, probeParents) + newParents = append(newParents, plan) + propagateProbeParents(plan.Children()[join.getInnerChildIdx()], newParents) + } + case *PhysicalTableReader: + propagateProbeParents(x.tablePlan, probeParents) + case *PhysicalIndexReader: + propagateProbeParents(x.indexPlan, probeParents) + case *PhysicalIndexLookUpReader: + propagateProbeParents(x.indexPlan, probeParents) + propagateProbeParents(x.tablePlan, probeParents) + case *PhysicalIndexMergeReader: + for _, pchild := range x.partialPlans { + propagateProbeParents(pchild, probeParents) + } + propagateProbeParents(x.tablePlan, probeParents) + default: + for _, child := range plan.Children() { + propagateProbeParents(child, probeParents) + } + } +} + +func enableParallelApply(sctx base.PlanContext, plan base.PhysicalPlan) base.PhysicalPlan { + if !sctx.GetSessionVars().EnableParallelApply { + return plan + } + // the parallel apply has three limitation: + // 1. 
the parallel implementation now cannot keep order; + // 2. the inner child has to support clone; + // 3. if one Apply is in the inner side of another Apply, it cannot be parallel, for example: + // The topology of 3 Apply operators are A1(A2, A3), which means A2 is the outer child of A1 + // while A3 is the inner child. Then A1 and A2 can be parallel and A3 cannot. + if apply, ok := plan.(*PhysicalApply); ok { + outerIdx := 1 - apply.InnerChildIdx + noOrder := len(apply.GetChildReqProps(outerIdx).SortItems) == 0 // limitation 1 + _, err := SafeClone(sctx, apply.Children()[apply.InnerChildIdx]) + supportClone := err == nil // limitation 2 + if noOrder && supportClone { + apply.Concurrency = sctx.GetSessionVars().ExecutorConcurrency + } else { + if err != nil { + sctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("Some apply operators can not be executed in parallel: %v", err)) + } else { + sctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError("Some apply operators can not be executed in parallel")) + } + } + // because of the limitation 3, we cannot parallelize Apply operators in this Apply's inner size, + // so we only invoke recursively for its outer child. + apply.SetChild(outerIdx, enableParallelApply(sctx, apply.Children()[outerIdx])) + return apply + } + for i, child := range plan.Children() { + plan.SetChild(i, enableParallelApply(sctx, child)) + } + return plan +} + +// LogicalOptimizeTest is just exported for test. +func LogicalOptimizeTest(ctx context.Context, flag uint64, logic base.LogicalPlan) (base.LogicalPlan, error) { + return logicalOptimize(ctx, flag, logic) +} + +func logicalOptimize(ctx context.Context, flag uint64, logic base.LogicalPlan) (base.LogicalPlan, error) { + if logic.SCtx().GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { + debugtrace.EnterContextCommon(logic.SCtx()) + defer debugtrace.LeaveContextCommon(logic.SCtx()) + } + opt := optimizetrace.DefaultLogicalOptimizeOption() + vars := logic.SCtx().GetSessionVars() + if vars.StmtCtx.EnableOptimizeTrace { + vars.StmtCtx.OptimizeTracer = &tracing.OptimizeTracer{} + tracer := &tracing.LogicalOptimizeTracer{ + Steps: make([]*tracing.LogicalRuleOptimizeTracer, 0), + } + opt = opt.WithEnableOptimizeTracer(tracer) + defer func() { + vars.StmtCtx.OptimizeTracer.Logical = tracer + }() + } + var err error + var againRuleList []base.LogicalOptRule + for i, rule := range optRuleList { + // The order of flags is same as the order of optRule in the list. + // We use a bitmask to record which opt rules should be used. If the i-th bit is 1, it means we should + // apply i-th optimizing rule. + if flag&(1< 0 { + logic.SCtx().GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("The parameter of nth_plan() is out of range")) + } + if t.Invalid() { + errMsg := "Can't find a proper physical plan for this query" + if config.GetGlobalConfig().DisaggregatedTiFlash && !logic.SCtx().GetSessionVars().IsMPPAllowed() { + errMsg += ": cop and batchCop are not allowed in disaggregated tiflash mode, you should turn on tidb_allow_mpp switch" + } + return nil, 0, plannererrors.ErrInternal.GenWithStackByArgs(errMsg) + } + + if err = t.Plan().ResolveIndices(); err != nil { + return nil, 0, err + } + cost, err = getPlanCost(t.Plan(), property.RootTaskType, optimizetrace.NewDefaultPlanCostOption()) + return t.Plan(), cost, err +} + +// eliminateUnionScanAndLock set lock property for PointGet and BatchPointGet and eliminates UnionScan and Lock. 
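// Illustrative sketch, not part of this patch: propagateProbeParents above appends the
// current operator to the ancestor slice only when descending into the inner (probe)
// side of an Apply/IndexJoin, while the outer side inherits the slice unchanged. The
// sketchNode type is a hypothetical stand-in for base.PhysicalPlan.
type sketchNode struct {
	children  []*sketchNode
	innerIdx  int // -1 when the operator has no inner (probe) side
	probePath []*sketchNode
}

func sketchPropagate(n *sketchNode, parents []*sketchNode) {
	n.probePath = parents
	if n.innerIdx >= 0 && len(n.children) == 2 {
		sketchPropagate(n.children[1-n.innerIdx], parents)      // outer child: same ancestors
		inner := append(parents[:len(parents):len(parents)], n) // copy-on-append, plus self
		sketchPropagate(n.children[n.innerIdx], inner)
		return
	}
	for _, c := range n.children {
		sketchPropagate(c, parents)
	}
}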
+func eliminateUnionScanAndLock(sctx base.PlanContext, p base.PhysicalPlan) base.PhysicalPlan { + var pointGet *PointGetPlan + var batchPointGet *BatchPointGetPlan + var physLock *PhysicalLock + var unionScan *PhysicalUnionScan + iteratePhysicalPlan(p, func(p base.PhysicalPlan) bool { + if len(p.Children()) > 1 { + return false + } + switch x := p.(type) { + case *PointGetPlan: + pointGet = x + case *BatchPointGetPlan: + batchPointGet = x + case *PhysicalLock: + physLock = x + case *PhysicalUnionScan: + unionScan = x + } + return true + }) + if pointGet == nil && batchPointGet == nil { + return p + } + if physLock == nil && unionScan == nil { + return p + } + if physLock != nil { + lock, waitTime := getLockWaitTime(sctx, physLock.Lock) + if !lock { + return p + } + if pointGet != nil { + pointGet.Lock = lock + pointGet.LockWaitTime = waitTime + } else { + batchPointGet.Lock = lock + batchPointGet.LockWaitTime = waitTime + } + } + return transformPhysicalPlan(p, func(p base.PhysicalPlan) base.PhysicalPlan { + if p == physLock { + return p.Children()[0] + } + if p == unionScan { + return p.Children()[0] + } + return p + }) +} + +func iteratePhysicalPlan(p base.PhysicalPlan, f func(p base.PhysicalPlan) bool) { + if !f(p) { + return + } + for _, child := range p.Children() { + iteratePhysicalPlan(child, f) + } +} + +func transformPhysicalPlan(p base.PhysicalPlan, f func(p base.PhysicalPlan) base.PhysicalPlan) base.PhysicalPlan { + for i, child := range p.Children() { + p.Children()[i] = transformPhysicalPlan(child, f) + } + return f(p) +} + +func existsCartesianProduct(p base.LogicalPlan) bool { + if join, ok := p.(*LogicalJoin); ok && len(join.EqualConditions) == 0 { + return join.JoinType == InnerJoin || join.JoinType == LeftOuterJoin || join.JoinType == RightOuterJoin + } + for _, child := range p.Children() { + if existsCartesianProduct(child) { + return true + } + } + return false +} + +// DefaultDisabledLogicalRulesList indicates the logical rules which should be banned. +var DefaultDisabledLogicalRulesList *atomic.Value + +func disableReuseChunkIfNeeded(sctx base.PlanContext, plan base.PhysicalPlan) { + if !sctx.GetSessionVars().IsAllocValid() { + return + } + + if checkOverlongColType(sctx, plan) { + return + } + + for _, child := range plan.Children() { + disableReuseChunkIfNeeded(sctx, child) + } +} + +// checkOverlongColType Check if read field type is long field. +func checkOverlongColType(sctx base.PlanContext, plan base.PhysicalPlan) bool { + if plan == nil { + return false + } + switch plan.(type) { + case *PhysicalTableReader, *PhysicalIndexReader, + *PhysicalIndexLookUpReader, *PhysicalIndexMergeReader, *PointGetPlan: + if existsOverlongType(plan.Schema()) { + sctx.GetSessionVars().ClearAlloc(nil, false) + return true + } + } + return false +} + +// existsOverlongType Check if exists long type column. +func existsOverlongType(schema *expression.Schema) bool { + if schema == nil { + return false + } + for _, column := range schema.Columns { + switch column.RetType.GetType() { + case mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, + mysql.TypeBlob, mysql.TypeJSON: + return true + case mysql.TypeVarString, mysql.TypeVarchar: + // if the column is varchar and the length of + // the column is defined to be more than 1000, + // the column is considered a large type and + // disable chunk_reuse. 
+ if column.RetType.GetFlen() > 1000 { + return true + } + } + } + return false +} diff --git a/pkg/planner/core/rule_collect_plan_stats.go b/pkg/planner/core/rule_collect_plan_stats.go index e9047782c5095..33c07f664f0f9 100644 --- a/pkg/planner/core/rule_collect_plan_stats.go +++ b/pkg/planner/core/rule_collect_plan_stats.go @@ -106,13 +106,13 @@ func RequestLoadStats(ctx base.PlanContext, neededHistItems []model.StatsLoadIte if maxExecutionTime > 0 && maxExecutionTime < uint64(syncWait) { syncWait = int64(maxExecutionTime) } - failpoint.Inject("assertSyncWaitFailed", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("assertSyncWaitFailed")); _err_ == nil { if val.(bool) { if syncWait != 1 { panic("syncWait should be 1(ms)") } } - }) + } var timeout = time.Duration(syncWait * time.Millisecond.Nanoseconds()) stmtCtx := ctx.GetSessionVars().StmtCtx err := domain.GetDomain(ctx).StatsHandle().SendLoadRequests(stmtCtx, neededHistItems, timeout) diff --git a/pkg/planner/core/rule_collect_plan_stats.go__failpoint_stash__ b/pkg/planner/core/rule_collect_plan_stats.go__failpoint_stash__ new file mode 100644 index 0000000000000..e9047782c5095 --- /dev/null +++ b/pkg/planner/core/rule_collect_plan_stats.go__failpoint_stash__ @@ -0,0 +1,316 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package core + +import ( + "context" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/util/optimizetrace" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/statistics" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.uber.org/zap" +) + +// CollectPredicateColumnsPoint collects the columns that are used in the predicates. +type CollectPredicateColumnsPoint struct{} + +// Optimize implements LogicalOptRule.<0th> interface. +func (CollectPredicateColumnsPoint) Optimize(_ context.Context, plan base.LogicalPlan, _ *optimizetrace.LogicalOptimizeOp) (base.LogicalPlan, bool, error) { + planChanged := false + if plan.SCtx().GetSessionVars().InRestrictedSQL { + return plan, planChanged, nil + } + syncWait := plan.SCtx().GetSessionVars().StatsLoadSyncWait.Load() + histNeeded := syncWait > 0 + predicateColumns, histNeededColumns, visitedPhysTblIDs := CollectColumnStatsUsage(plan, histNeeded) + if len(predicateColumns) > 0 { + plan.SCtx().UpdateColStatsUsage(predicateColumns) + } + + // Prepare the table metadata to avoid repeatedly fetching from the infoSchema below, and trigger extra sync/async + // stats loading for the determinate mode. 
+ is := plan.SCtx().GetInfoSchema().(infoschema.InfoSchema) + tblID2Tbl := make(map[int64]table.Table) + visitedPhysTblIDs.ForEach(func(physicalTblID int) { + tbl, _ := infoschema.FindTableByTblOrPartID(is, int64(physicalTblID)) + if tbl == nil { + return + } + tblID2Tbl[int64(physicalTblID)] = tbl + }) + + // collect needed virtual columns from already needed columns + // Note that we use the dependingVirtualCols only to collect needed index stats, but not to trigger stats loading on + // the virtual columns themselves. It's because virtual columns themselves don't have statistics, while expression + // indexes, which are indexes on virtual columns, have statistics. We don't waste the resource here now. + dependingVirtualCols := CollectDependingVirtualCols(tblID2Tbl, histNeededColumns) + + histNeededIndices := collectSyncIndices(plan.SCtx(), append(histNeededColumns, dependingVirtualCols...), tblID2Tbl) + histNeededItems := collectHistNeededItems(histNeededColumns, histNeededIndices) + if histNeeded && len(histNeededItems) > 0 { + err := RequestLoadStats(plan.SCtx(), histNeededItems, syncWait) + return plan, planChanged, err + } + return plan, planChanged, nil +} + +// Name implements the base.LogicalOptRule.<1st> interface. +func (CollectPredicateColumnsPoint) Name() string { + return "collect_predicate_columns_point" +} + +// SyncWaitStatsLoadPoint sync-wait for stats load point. +type SyncWaitStatsLoadPoint struct{} + +// Optimize implements the base.LogicalOptRule.<0th> interface. +func (SyncWaitStatsLoadPoint) Optimize(_ context.Context, plan base.LogicalPlan, _ *optimizetrace.LogicalOptimizeOp) (base.LogicalPlan, bool, error) { + planChanged := false + if plan.SCtx().GetSessionVars().InRestrictedSQL { + return plan, planChanged, nil + } + if plan.SCtx().GetSessionVars().StmtCtx.IsSyncStatsFailed { + return plan, planChanged, nil + } + err := SyncWaitStatsLoad(plan) + return plan, planChanged, err +} + +// Name implements the base.LogicalOptRule.<1st> interface. 
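// Illustrative sketch, not part of this patch: the CollectDependingVirtualCols call in
// Optimize above (implemented further down) keeps a virtual column when any of its
// direct dependencies is already in the needed-column set, so the stats of expression
// indexes on it can be sync-loaded. The types here are hypothetical stand-ins for the
// table metadata.
type sketchVirtualCol struct {
	name string
	deps map[string]struct{} // columns the generated expression reads
}

func sketchDependingVirtualCols(needed map[string]struct{}, cols []sketchVirtualCol) []string {
	var out []string
	for _, c := range cols {
		if _, ok := needed[c.name]; ok {
			continue // already needed, nothing extra to collect
		}
		for d := range c.deps {
			if _, ok := needed[d]; ok {
				out = append(out, c.name)
				break
			}
		}
	}
	return out
}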
+func (SyncWaitStatsLoadPoint) Name() string { + return "sync_wait_stats_load_point" +} + +// RequestLoadStats send load column/index stats requests to stats handle +func RequestLoadStats(ctx base.PlanContext, neededHistItems []model.StatsLoadItem, syncWait int64) error { + maxExecutionTime := ctx.GetSessionVars().GetMaxExecutionTime() + if maxExecutionTime > 0 && maxExecutionTime < uint64(syncWait) { + syncWait = int64(maxExecutionTime) + } + failpoint.Inject("assertSyncWaitFailed", func(val failpoint.Value) { + if val.(bool) { + if syncWait != 1 { + panic("syncWait should be 1(ms)") + } + } + }) + var timeout = time.Duration(syncWait * time.Millisecond.Nanoseconds()) + stmtCtx := ctx.GetSessionVars().StmtCtx + err := domain.GetDomain(ctx).StatsHandle().SendLoadRequests(stmtCtx, neededHistItems, timeout) + if err != nil { + stmtCtx.IsSyncStatsFailed = true + if variable.StatsLoadPseudoTimeout.Load() { + logutil.BgLogger().Warn("RequestLoadStats failed", zap.Error(err)) + stmtCtx.AppendWarning(err) + return nil + } + logutil.BgLogger().Warn("RequestLoadStats failed", zap.Error(err)) + return err + } + return nil +} + +// SyncWaitStatsLoad sync-wait for stats load until timeout +func SyncWaitStatsLoad(plan base.LogicalPlan) error { + stmtCtx := plan.SCtx().GetSessionVars().StmtCtx + if len(stmtCtx.StatsLoad.NeededItems) <= 0 { + return nil + } + err := domain.GetDomain(plan.SCtx()).StatsHandle().SyncWaitStatsLoad(stmtCtx) + if err != nil { + stmtCtx.IsSyncStatsFailed = true + if variable.StatsLoadPseudoTimeout.Load() { + logutil.BgLogger().Warn("SyncWaitStatsLoad failed", zap.Error(err)) + stmtCtx.AppendWarning(err) + return nil + } + logutil.BgLogger().Error("SyncWaitStatsLoad failed", zap.Error(err)) + return err + } + return nil +} + +// CollectDependingVirtualCols collects the virtual columns that depend on the needed columns, and returns them in a new slice. +// +// Why do we need this? +// It's mainly for stats sync loading. +// Currently, virtual columns themselves don't have statistics. But expression indexes, which are indexes on virtual +// columns, have statistics. We need to collect needed virtual columns, then needed expression index stats can be +// collected for sync loading. +// In normal cases, if a virtual column can be used, which means related statistics may be needed, the corresponding +// expressions in the query must have already been replaced with the virtual column before here. So we just need to treat +// them like normal columns in stats sync loading, which means we just extract the Column from the expressions, the +// virtual columns we want will be there. +// However, in some cases (the mv index case now), the expressions are not replaced with the virtual columns before here. +// Instead, we match the expression in the query against the expression behind the virtual columns after here when +// building the access paths. This means we are unable to known what virtual columns will be needed by just extracting +// the Column from the expressions here. So we need to manually collect the virtual columns that may be needed. +// +// Note 1: As long as a virtual column depends on the needed columns, it will be collected. This could collect some virtual +// columns that are not actually needed. +// It's OK because that's how sync loading is expected. Sync loading only needs to ensure all actually needed stats are +// triggered to be loaded. Other logic of sync loading also works like this. 
+// If we want to collect only the virtual columns that are actually needed, we need to make the checking logic here exactly +// the same as the logic for generating the access paths, which will make the logic here very complicated. +// +// Note 2: Only direct dependencies are considered here. +// If a virtual column depends on another virtual column, and the latter depends on the needed columns, then the former +// will not be collected. +// For example: create table t(a int, b int, c int as (a+b), d int as (c+1)); If a is needed, then c will be collected, +// but d will not be collected. +// It's because currently it's impossible that statistics related to indirectly depending columns are actually needed. +// If we need to check indirect dependency some day, we can easily extend the logic here. +func CollectDependingVirtualCols(tblID2Tbl map[int64]table.Table, neededItems []model.StatsLoadItem) []model.StatsLoadItem { + generatedCols := make([]model.StatsLoadItem, 0) + + // group the neededItems by table id + tblID2neededColIDs := make(map[int64][]int64, len(tblID2Tbl)) + for _, item := range neededItems { + if item.IsIndex { + continue + } + tblID2neededColIDs[item.TableID] = append(tblID2neededColIDs[item.TableID], item.ID) + } + + // process them by table id + for tblID, colIDs := range tblID2neededColIDs { + tbl := tblID2Tbl[tblID] + if tbl == nil { + continue + } + // collect the needed columns on this table into a set for faster lookup + colNameSet := make(map[string]struct{}, len(colIDs)) + for _, colID := range colIDs { + name := tbl.Meta().FindColumnNameByID(colID) + if name == "" { + continue + } + colNameSet[name] = struct{}{} + } + // iterate columns in this table, and collect the virtual columns that depend on the needed columns + for _, col := range tbl.Cols() { + // only handles virtual columns + if !col.IsVirtualGenerated() { + continue + } + // If this column is already needed, then skip it. + if _, ok := colNameSet[col.Name.L]; ok { + continue + } + // If there exists a needed column that is depended on by this virtual column, + // then we think this virtual column is needed. + for depCol := range col.Dependences { + if _, ok := colNameSet[depCol]; ok { + generatedCols = append(generatedCols, model.StatsLoadItem{TableItemID: model.TableItemID{TableID: tblID, ID: col.ID, IsIndex: false}, FullLoad: true}) + break + } + } + } + } + return generatedCols +} + +// collectSyncIndices will collect the indices which includes following conditions: +// 1. the indices contained the any one of histNeededColumns, eg: histNeededColumns contained A,B columns, and idx_a is +// composed up by A column, then we thought the idx_a should be collected +// 2. 
The stats condition of idx_a can't meet IsFullLoad, which means its stats was evicted previously +func collectSyncIndices(ctx base.PlanContext, + histNeededColumns []model.StatsLoadItem, + tblID2Tbl map[int64]table.Table, +) map[model.TableItemID]struct{} { + histNeededIndices := make(map[model.TableItemID]struct{}) + stats := domain.GetDomain(ctx).StatsHandle() + for _, column := range histNeededColumns { + if column.IsIndex { + continue + } + tbl := tblID2Tbl[column.TableID] + if tbl == nil { + continue + } + colName := tbl.Meta().FindColumnNameByID(column.ID) + if colName == "" { + continue + } + for _, idx := range tbl.Indices() { + if idx.Meta().State != model.StatePublic { + continue + } + idxCol := idx.Meta().FindColumnByName(colName) + idxID := idx.Meta().ID + if idxCol != nil { + tblStats := stats.GetTableStats(tbl.Meta()) + if tblStats == nil || tblStats.Pseudo { + continue + } + _, loadNeeded := tblStats.IndexIsLoadNeeded(idxID) + if !loadNeeded { + continue + } + histNeededIndices[model.TableItemID{TableID: column.TableID, ID: idxID, IsIndex: true}] = struct{}{} + } + } + } + return histNeededIndices +} + +func collectHistNeededItems(histNeededColumns []model.StatsLoadItem, histNeededIndices map[model.TableItemID]struct{}) (histNeededItems []model.StatsLoadItem) { + histNeededItems = make([]model.StatsLoadItem, 0, len(histNeededColumns)+len(histNeededIndices)) + for idx := range histNeededIndices { + histNeededItems = append(histNeededItems, model.StatsLoadItem{TableItemID: idx, FullLoad: true}) + } + histNeededItems = append(histNeededItems, histNeededColumns...) + return +} + +func recordTableRuntimeStats(sctx base.PlanContext, tbls map[int64]struct{}) { + tblStats := sctx.GetSessionVars().StmtCtx.TableStats + if tblStats == nil { + tblStats = map[int64]any{} + } + for tblID := range tbls { + tblJSONStats, skip, err := recordSingleTableRuntimeStats(sctx, tblID) + if err != nil { + logutil.BgLogger().Warn("record table json stats failed", zap.Int64("tblID", tblID), zap.Error(err)) + } + if tblJSONStats == nil && !skip { + logutil.BgLogger().Warn("record table json stats failed due to empty", zap.Int64("tblID", tblID)) + } + tblStats[tblID] = tblJSONStats + } + sctx.GetSessionVars().StmtCtx.TableStats = tblStats +} + +func recordSingleTableRuntimeStats(sctx base.PlanContext, tblID int64) (stats *statistics.Table, skip bool, err error) { + dom := domain.GetDomain(sctx) + statsHandle := dom.StatsHandle() + is := sctx.GetDomainInfoSchema().(infoschema.InfoSchema) + tbl, ok := is.TableByID(tblID) + if !ok { + return nil, false, nil + } + tableInfo := tbl.Meta() + stats = statsHandle.GetTableStats(tableInfo) + // Skip the warning if the table is a temporary table because the temporary table doesn't have stats. + skip = tableInfo.TempTableType != model.TempTableNone + return stats, skip, nil +} diff --git a/pkg/planner/core/rule_eliminate_projection.go b/pkg/planner/core/rule_eliminate_projection.go index 0a082bc106ac5..40f2977caa27d 100644 --- a/pkg/planner/core/rule_eliminate_projection.go +++ b/pkg/planner/core/rule_eliminate_projection.go @@ -135,11 +135,11 @@ func doPhysicalProjectionElimination(p base.PhysicalPlan) base.PhysicalPlan { // eliminatePhysicalProjection should be called after physical optimization to // eliminate the redundant projection left after logical projection elimination. 
func eliminatePhysicalProjection(p base.PhysicalPlan) base.PhysicalPlan { - failpoint.Inject("DisableProjectionPostOptimization", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("DisableProjectionPostOptimization")); _err_ == nil { if val.(bool) { - failpoint.Return(p) + return p } - }) + } newRoot := doPhysicalProjectionElimination(p) return newRoot diff --git a/pkg/planner/core/rule_eliminate_projection.go__failpoint_stash__ b/pkg/planner/core/rule_eliminate_projection.go__failpoint_stash__ new file mode 100644 index 0000000000000..0a082bc106ac5 --- /dev/null +++ b/pkg/planner/core/rule_eliminate_projection.go__failpoint_stash__ @@ -0,0 +1,274 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package core + +import ( + "bytes" + "context" + "fmt" + + perrors "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" + ruleutil "github.com/pingcap/tidb/pkg/planner/core/rule/util" + "github.com/pingcap/tidb/pkg/planner/util/optimizetrace" +) + +// canProjectionBeEliminatedLoose checks whether a projection can be eliminated, +// returns true if every expression is a single column. +func canProjectionBeEliminatedLoose(p *logicalop.LogicalProjection) bool { + // project for expand will assign a new col id for col ref, because these column should be + // data cloned in the execution time and may be filled with null value at the same time. + // so it's not a REAL column reference. Detect the column ref in projection here and do + // the elimination here will restore the Expand's grouping sets column back to use the + // original column ref again. (which is not right) + if p.Proj4Expand { + return false + } + for _, expr := range p.Exprs { + _, ok := expr.(*expression.Column) + if !ok { + return false + } + } + return true +} + +// canProjectionBeEliminatedStrict checks whether a projection can be +// eliminated, returns true if the projection just copy its child's output. +func canProjectionBeEliminatedStrict(p *PhysicalProjection) bool { + // This is due to the in-compatibility between TiFlash and TiDB: + // For TiDB, the output schema of final agg is all the aggregated functions and for + // TiFlash, the output schema of agg(TiFlash not aware of the aggregation mode) is + // aggregated functions + group by columns, so to make the things work, for final + // mode aggregation that need to be running in TiFlash, always add an extra Project + // the align the output schema. In the future, we can solve this in-compatibility by + // passing down the aggregation mode to TiFlash. 
+ if physicalAgg, ok := p.Children()[0].(*PhysicalHashAgg); ok { + if physicalAgg.MppRunMode == Mpp1Phase || physicalAgg.MppRunMode == Mpp2Phase || physicalAgg.MppRunMode == MppScalar { + if physicalAgg.IsFinalAgg() { + return false + } + } + } + if physicalAgg, ok := p.Children()[0].(*PhysicalStreamAgg); ok { + if physicalAgg.MppRunMode == Mpp1Phase || physicalAgg.MppRunMode == Mpp2Phase || physicalAgg.MppRunMode == MppScalar { + if physicalAgg.IsFinalAgg() { + return false + } + } + } + // If this projection is specially added for `DO`, we keep it. + if p.CalculateNoDelay { + return false + } + if p.Schema().Len() == 0 { + return true + } + child := p.Children()[0] + if p.Schema().Len() != child.Schema().Len() { + return false + } + for i, expr := range p.Exprs { + col, ok := expr.(*expression.Column) + if !ok || !col.EqualColumn(child.Schema().Columns[i]) { + return false + } + } + return true +} + +func doPhysicalProjectionElimination(p base.PhysicalPlan) base.PhysicalPlan { + for i, child := range p.Children() { + p.Children()[i] = doPhysicalProjectionElimination(child) + } + + // eliminate projection in a coprocessor task + tableReader, isTableReader := p.(*PhysicalTableReader) + if isTableReader && tableReader.StoreType == kv.TiFlash { + tableReader.tablePlan = eliminatePhysicalProjection(tableReader.tablePlan) + tableReader.TablePlans = flattenPushDownPlan(tableReader.tablePlan) + return p + } + + proj, isProj := p.(*PhysicalProjection) + if !isProj || !canProjectionBeEliminatedStrict(proj) { + return p + } + child := p.Children()[0] + if childProj, ok := child.(*PhysicalProjection); ok { + // when current projection is an empty projection(schema pruned by column pruner), no need to reset child's schema + // TODO: avoid producing empty projection in column pruner. + if p.Schema().Len() != 0 { + childProj.SetSchema(p.Schema()) + } + // If any of the consecutive projection operators has the AvoidColumnEvaluator set to true, + // we need to set the AvoidColumnEvaluator of the remaining projection to true. + if proj.AvoidColumnEvaluator { + childProj.AvoidColumnEvaluator = true + } + } + for i, col := range p.Schema().Columns { + if p.SCtx().GetSessionVars().StmtCtx.ColRefFromUpdatePlan.Has(int(col.UniqueID)) && !child.Schema().Columns[i].Equal(nil, col) { + return p + } + } + return child +} + +// eliminatePhysicalProjection should be called after physical optimization to +// eliminate the redundant projection left after logical projection elimination. +func eliminatePhysicalProjection(p base.PhysicalPlan) base.PhysicalPlan { + failpoint.Inject("DisableProjectionPostOptimization", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(p) + } + }) + + newRoot := doPhysicalProjectionElimination(p) + return newRoot +} + +// For select, insert, delete list +// The projection eliminate in logical optimize will optimize the projection under the projection, window, agg +// The projection eliminate in post optimize will optimize other projection + +// ProjectionEliminator is for update stmt +// The projection eliminate in logical optimize has been forbidden. +// The projection eliminate in post optimize will optimize the projection under the projection, window, agg (the condition is same as logical optimize) +type ProjectionEliminator struct { +} + +// Optimize implements the logicalOptRule interface. 
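// Illustrative sketch, not part of this patch: canProjectionBeEliminatedStrict above
// boils down to "the projection merely copies its child's output column-for-column".
// Columns are modelled as plain strings here for brevity.
func sketchIsIdentityProjection(projExprs, childCols []string) bool {
	if len(projExprs) != len(childCols) {
		return false
	}
	for i, e := range projExprs {
		if e != childCols[i] { // each expression must be the child column at the same index
			return false
		}
	}
	return true
}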
+func (pe *ProjectionEliminator) Optimize(_ context.Context, lp base.LogicalPlan, opt *optimizetrace.LogicalOptimizeOp) (base.LogicalPlan, bool, error) { + planChanged := false + root := pe.eliminate(lp, make(map[string]*expression.Column), false, opt) + return root, planChanged, nil +} + +// eliminate eliminates the redundant projection in a logical plan. +func (pe *ProjectionEliminator) eliminate(p base.LogicalPlan, replace map[string]*expression.Column, canEliminate bool, opt *optimizetrace.LogicalOptimizeOp) base.LogicalPlan { + // LogicalCTE's logical optimization is independent. + if _, ok := p.(*LogicalCTE); ok { + return p + } + proj, isProj := p.(*logicalop.LogicalProjection) + childFlag := canEliminate + if _, isUnion := p.(*LogicalUnionAll); isUnion { + childFlag = false + } else if _, isAgg := p.(*LogicalAggregation); isAgg || isProj { + childFlag = true + } else if _, isWindow := p.(*logicalop.LogicalWindow); isWindow { + childFlag = true + } + for i, child := range p.Children() { + p.Children()[i] = pe.eliminate(child, replace, childFlag, opt) + } + + // replace logical plan schema + switch x := p.(type) { + case *LogicalJoin: + x.SetSchema(buildLogicalJoinSchema(x.JoinType, x)) + case *LogicalApply: + x.SetSchema(buildLogicalJoinSchema(x.JoinType, x)) + default: + for _, dst := range p.Schema().Columns { + ruleutil.ResolveColumnAndReplace(dst, replace) + } + } + // replace all of exprs in logical plan + p.ReplaceExprColumns(replace) + + // eliminate duplicate projection: projection with child projection + if isProj { + if child, ok := p.Children()[0].(*logicalop.LogicalProjection); ok && !expression.ExprsHasSideEffects(child.Exprs) { + ctx := p.SCtx() + for i := range proj.Exprs { + proj.Exprs[i] = ReplaceColumnOfExpr(proj.Exprs[i], child, child.Schema()) + foldedExpr := expression.FoldConstant(ctx.GetExprCtx(), proj.Exprs[i]) + // the folded expr should have the same null flag with the original expr, especially for the projection under union, so forcing it here. + foldedExpr.GetType(ctx.GetExprCtx().GetEvalCtx()).SetFlag((foldedExpr.GetType(ctx.GetExprCtx().GetEvalCtx()).GetFlag() & ^mysql.NotNullFlag) | (proj.Exprs[i].GetType(ctx.GetExprCtx().GetEvalCtx()).GetFlag() & mysql.NotNullFlag)) + proj.Exprs[i] = foldedExpr + } + p.Children()[0] = child.Children()[0] + appendDupProjEliminateTraceStep(proj, child, opt) + } + } + + if !(isProj && canEliminate && canProjectionBeEliminatedLoose(proj)) { + return p + } + exprs := proj.Exprs + for i, col := range proj.Schema().Columns { + replace[string(col.HashCode())] = exprs[i].(*expression.Column) + } + appendProjEliminateTraceStep(proj, opt) + return p.Children()[0] +} + +// ReplaceColumnOfExpr replaces column of expression by another LogicalProjection. +func ReplaceColumnOfExpr(expr expression.Expression, proj *logicalop.LogicalProjection, schema *expression.Schema) expression.Expression { + switch v := expr.(type) { + case *expression.Column: + idx := schema.ColumnIndex(v) + if idx != -1 && idx < len(proj.Exprs) { + return proj.Exprs[idx] + } + case *expression.ScalarFunction: + for i := range v.GetArgs() { + v.GetArgs()[i] = ReplaceColumnOfExpr(v.GetArgs()[i], proj, schema) + } + } + return expr +} + +// Name implements the logicalOptRule.<1st> interface. 
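// Illustrative sketch, not part of this patch: merging two stacked projections, as done
// in eliminate above, substitutes the child's defining expressions into the parent's
// column references so the parent can sit directly on the grandchild. A parent
// expression is simplified here to "an index into the child's output".
func sketchMergeProjections(parentRefs []int, childExprs []string) []string {
	merged := make([]string, len(parentRefs))
	for i, ref := range parentRefs {
		merged[i] = childExprs[ref] // e.g. SELECT c1 over SELECT a+b AS c1 becomes SELECT a+b
	}
	return merged
}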
+func (*ProjectionEliminator) Name() string { + return "projection_eliminate" +} + +func appendDupProjEliminateTraceStep(parent, child *logicalop.LogicalProjection, opt *optimizetrace.LogicalOptimizeOp) { + ectx := parent.SCtx().GetExprCtx().GetEvalCtx() + action := func() string { + buffer := bytes.NewBufferString( + fmt.Sprintf("%v_%v is eliminated, %v_%v's expressions changed into[", child.TP(), child.ID(), parent.TP(), parent.ID())) + for i, expr := range parent.Exprs { + if i > 0 { + buffer.WriteString(",") + } + buffer.WriteString(expr.StringWithCtx(ectx, perrors.RedactLogDisable)) + } + buffer.WriteString("]") + return buffer.String() + } + reason := func() string { + return fmt.Sprintf("%v_%v's child %v_%v is redundant", parent.TP(), parent.ID(), child.TP(), child.ID()) + } + opt.AppendStepToCurrent(child.ID(), child.TP(), reason, action) +} + +func appendProjEliminateTraceStep(proj *logicalop.LogicalProjection, opt *optimizetrace.LogicalOptimizeOp) { + reason := func() string { + return fmt.Sprintf("%v_%v's Exprs are all Columns", proj.TP(), proj.ID()) + } + action := func() string { + return fmt.Sprintf("%v_%v is eliminated", proj.TP(), proj.ID()) + } + opt.AppendStepToCurrent(proj.ID(), proj.TP(), reason, action) +} diff --git a/pkg/planner/core/rule_inject_extra_projection.go b/pkg/planner/core/rule_inject_extra_projection.go index e86c7db13fdd9..8335516747c85 100644 --- a/pkg/planner/core/rule_inject_extra_projection.go +++ b/pkg/planner/core/rule_inject_extra_projection.go @@ -36,11 +36,11 @@ import ( // 2. TiDB can be used as a coprocessor, when a plan tree been pushed down to // TiDB, we need to inject extra projections for the plan tree as well. func InjectExtraProjection(plan base.PhysicalPlan) base.PhysicalPlan { - failpoint.Inject("DisableProjectionPostOptimization", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("DisableProjectionPostOptimization")); _err_ == nil { if val.(bool) { - failpoint.Return(plan) + return plan } - }) + } return NewProjInjector().inject(plan) } diff --git a/pkg/planner/core/rule_inject_extra_projection.go__failpoint_stash__ b/pkg/planner/core/rule_inject_extra_projection.go__failpoint_stash__ new file mode 100644 index 0000000000000..e86c7db13fdd9 --- /dev/null +++ b/pkg/planner/core/rule_inject_extra_projection.go__failpoint_stash__ @@ -0,0 +1,352 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package core + +import ( + "slices" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/expression/aggregation" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/util" + "github.com/pingcap/tidb/pkg/planner/util/coreusage" +) + +// InjectExtraProjection is used to extract the expressions of specific +// operators into a physical Projection operator and inject the Projection below +// the operators. 
Thus we can accelerate the expression evaluation by eager +// evaluation. +// This function will be called in two situations: +// 1. In postOptimize. +// 2. TiDB can be used as a coprocessor, when a plan tree been pushed down to +// TiDB, we need to inject extra projections for the plan tree as well. +func InjectExtraProjection(plan base.PhysicalPlan) base.PhysicalPlan { + failpoint.Inject("DisableProjectionPostOptimization", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(plan) + } + }) + + return NewProjInjector().inject(plan) +} + +type projInjector struct { +} + +// NewProjInjector builds a projInjector. +func NewProjInjector() *projInjector { + return &projInjector{} +} + +func (pe *projInjector) inject(plan base.PhysicalPlan) base.PhysicalPlan { + for i, child := range plan.Children() { + plan.Children()[i] = pe.inject(child) + } + + if tr, ok := plan.(*PhysicalTableReader); ok && tr.StoreType == kv.TiFlash { + tr.tablePlan = pe.inject(tr.tablePlan) + tr.TablePlans = flattenPushDownPlan(tr.tablePlan) + } + + switch p := plan.(type) { + case *PhysicalHashAgg: + plan = InjectProjBelowAgg(plan, p.AggFuncs, p.GroupByItems) + case *PhysicalStreamAgg: + plan = InjectProjBelowAgg(plan, p.AggFuncs, p.GroupByItems) + case *PhysicalSort: + plan = InjectProjBelowSort(p, p.ByItems) + case *PhysicalTopN: + plan = InjectProjBelowSort(p, p.ByItems) + case *NominalSort: + plan = TurnNominalSortIntoProj(p, p.OnlyColumn, p.ByItems) + case *PhysicalUnionAll: + plan = injectProjBelowUnion(p) + } + return plan +} + +func injectProjBelowUnion(un *PhysicalUnionAll) *PhysicalUnionAll { + if !un.mpp { + return un + } + for i, ch := range un.children { + exprs := make([]expression.Expression, len(ch.Schema().Columns)) + needChange := false + for i, dstCol := range un.schema.Columns { + dstType := dstCol.RetType + srcCol := ch.Schema().Columns[i] + srcCol.Index = i + srcType := srcCol.RetType + if !srcType.Equal(dstType) || !(mysql.HasNotNullFlag(dstType.GetFlag()) == mysql.HasNotNullFlag(srcType.GetFlag())) { + exprs[i] = expression.BuildCastFunction4Union(un.SCtx().GetExprCtx(), srcCol, dstType) + needChange = true + } else { + exprs[i] = srcCol + } + } + if needChange { + proj := PhysicalProjection{ + Exprs: exprs, + }.Init(un.SCtx(), ch.StatsInfo(), 0) + proj.SetSchema(un.schema.Clone()) + proj.SetChildren(ch) + un.children[i] = proj + } + } + return un +} + +// InjectProjBelowAgg injects a ProjOperator below AggOperator. So that All +// scalar functions in aggregation may speed up by vectorized evaluation in +// the `proj`. If all the args of `aggFuncs`, and all the item of `groupByItems` +// are columns or constants, we do not need to build the `proj`. 
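// Illustrative sketch, not part of this patch: InjectProjBelowAgg below pulls scalar
// function arguments out of the aggregation into a child projection and rewrites the
// aggregation to reference the projection's output instead; plain columns and constants
// stay where they are. Expressions are plain strings and isScalarFunc is a stand-in
// predicate here.
func sketchExtractIntoProj(aggArgs []string, isScalarFunc func(string) bool) (projExprs []string, argToProj []int) {
	seen := make(map[string]int) // reuse one projection slot for identical expressions
	argToProj = make([]int, len(aggArgs))
	for i, arg := range aggArgs {
		if !isScalarFunc(arg) {
			argToProj[i] = -1 // column or constant: the aggregation keeps it directly
			continue
		}
		idx, ok := seen[arg]
		if !ok {
			idx = len(projExprs)
			projExprs = append(projExprs, arg)
			seen[arg] = idx
		}
		argToProj[i] = idx // the aggregation now reads projection output column idx
	}
	return projExprs, argToProj
}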
+func InjectProjBelowAgg(aggPlan base.PhysicalPlan, aggFuncs []*aggregation.AggFuncDesc, groupByItems []expression.Expression) base.PhysicalPlan { + hasScalarFunc := false + exprCtx := aggPlan.SCtx().GetExprCtx() + coreusage.WrapCastForAggFuncs(exprCtx, aggFuncs) + for i := 0; !hasScalarFunc && i < len(aggFuncs); i++ { + for _, arg := range aggFuncs[i].Args { + _, isScalarFunc := arg.(*expression.ScalarFunction) + hasScalarFunc = hasScalarFunc || isScalarFunc + } + for _, byItem := range aggFuncs[i].OrderByItems { + _, isScalarFunc := byItem.Expr.(*expression.ScalarFunction) + hasScalarFunc = hasScalarFunc || isScalarFunc + } + } + for i := 0; !hasScalarFunc && i < len(groupByItems); i++ { + _, isScalarFunc := groupByItems[i].(*expression.ScalarFunction) + hasScalarFunc = hasScalarFunc || isScalarFunc + } + if !hasScalarFunc { + return aggPlan + } + + projSchemaCols := make([]*expression.Column, 0, len(aggFuncs)+len(groupByItems)) + projExprs := make([]expression.Expression, 0, cap(projSchemaCols)) + cursor := 0 + + ectx := exprCtx.GetEvalCtx() + for _, f := range aggFuncs { + for i, arg := range f.Args { + if _, isCnst := arg.(*expression.Constant); isCnst { + continue + } + projExprs = append(projExprs, arg) + newArg := &expression.Column{ + UniqueID: aggPlan.SCtx().GetSessionVars().AllocPlanColumnID(), + RetType: arg.GetType(ectx), + Index: cursor, + } + projSchemaCols = append(projSchemaCols, newArg) + f.Args[i] = newArg + cursor++ + } + for _, byItem := range f.OrderByItems { + bi := byItem.Expr + if _, isCnst := bi.(*expression.Constant); isCnst { + continue + } + idx := slices.IndexFunc(projExprs, func(a expression.Expression) bool { + return a.Equal(ectx, bi) + }) + if idx < 0 { + projExprs = append(projExprs, bi) + newArg := &expression.Column{ + UniqueID: aggPlan.SCtx().GetSessionVars().AllocPlanColumnID(), + RetType: bi.GetType(ectx), + Index: cursor, + } + projSchemaCols = append(projSchemaCols, newArg) + byItem.Expr = newArg + cursor++ + } else { + byItem.Expr = projSchemaCols[idx] + } + } + } + + for i, item := range groupByItems { + it := item + if _, isCnst := it.(*expression.Constant); isCnst { + continue + } + idx := slices.IndexFunc(projExprs, func(a expression.Expression) bool { + return a.Equal(ectx, it) + }) + if idx < 0 { + projExprs = append(projExprs, it) + newArg := &expression.Column{ + UniqueID: aggPlan.SCtx().GetSessionVars().AllocPlanColumnID(), + RetType: item.GetType(ectx), + Index: cursor, + } + projSchemaCols = append(projSchemaCols, newArg) + groupByItems[i] = newArg + cursor++ + } else { + groupByItems[i] = projSchemaCols[idx] + } + } + + child := aggPlan.Children()[0] + prop := aggPlan.GetChildReqProps(0).CloneEssentialFields() + proj := PhysicalProjection{ + Exprs: projExprs, + AvoidColumnEvaluator: false, + }.Init(aggPlan.SCtx(), child.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), aggPlan.QueryBlockOffset(), prop) + proj.SetSchema(expression.NewSchema(projSchemaCols...)) + proj.SetChildren(child) + + aggPlan.SetChildren(proj) + return aggPlan +} + +// InjectProjBelowSort extracts the ScalarFunctions of `orderByItems` into a +// PhysicalProjection and injects it below PhysicalTopN/PhysicalSort. The schema +// of PhysicalSort and PhysicalTopN are the same as the schema of their +// children. When a projection is injected as the child of PhysicalSort and +// PhysicalTopN, some extra columns will be added into the schema of the +// Projection, thus we need to add another Projection upon them to prune the +// redundant columns. 
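// Illustrative sketch, not part of this patch: InjectProjBelowSort below wraps the
// sort/topN with two projections: the bottom one appends the computed order-by
// expressions to the child's columns (so the sort keys become plain columns), and the
// top one prunes the schema back to the original width. Schemas are string slices here.
func sketchWrapSortSchemas(childCols, scalarOrderExprs []string) (bottom, top []string) {
	bottom = append(append([]string{}, childCols...), scalarOrderExprs...) // child cols + computed sort keys
	top = append([]string{}, childCols...)                                 // restore the original schema
	return bottom, top
}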
+func InjectProjBelowSort(p base.PhysicalPlan, orderByItems []*util.ByItems) base.PhysicalPlan { + hasScalarFunc, numOrderByItems := false, len(orderByItems) + for i := 0; !hasScalarFunc && i < numOrderByItems; i++ { + _, isScalarFunc := orderByItems[i].Expr.(*expression.ScalarFunction) + hasScalarFunc = hasScalarFunc || isScalarFunc + } + if !hasScalarFunc { + return p + } + + topProjExprs := make([]expression.Expression, 0, p.Schema().Len()) + for i := range p.Schema().Columns { + col := p.Schema().Columns[i].Clone().(*expression.Column) + col.Index = i + topProjExprs = append(topProjExprs, col) + } + topProj := PhysicalProjection{ + Exprs: topProjExprs, + AvoidColumnEvaluator: false, + }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset(), nil) + topProj.SetSchema(p.Schema().Clone()) + topProj.SetChildren(p) + + childPlan := p.Children()[0] + bottomProjSchemaCols := make([]*expression.Column, 0, len(childPlan.Schema().Columns)+numOrderByItems) + bottomProjExprs := make([]expression.Expression, 0, len(childPlan.Schema().Columns)+numOrderByItems) + for _, col := range childPlan.Schema().Columns { + newCol := col.Clone().(*expression.Column) + newCol.Index = childPlan.Schema().ColumnIndex(newCol) + bottomProjSchemaCols = append(bottomProjSchemaCols, newCol) + bottomProjExprs = append(bottomProjExprs, newCol) + } + + for _, item := range orderByItems { + itemExpr := item.Expr + if _, isScalarFunc := itemExpr.(*expression.ScalarFunction); !isScalarFunc { + continue + } + bottomProjExprs = append(bottomProjExprs, itemExpr) + newArg := &expression.Column{ + UniqueID: p.SCtx().GetSessionVars().AllocPlanColumnID(), + RetType: itemExpr.GetType(p.SCtx().GetExprCtx().GetEvalCtx()), + Index: len(bottomProjSchemaCols), + } + bottomProjSchemaCols = append(bottomProjSchemaCols, newArg) + item.Expr = newArg + } + + childProp := p.GetChildReqProps(0).CloneEssentialFields() + bottomProj := PhysicalProjection{ + Exprs: bottomProjExprs, + AvoidColumnEvaluator: false, + }.Init(p.SCtx(), childPlan.StatsInfo().ScaleByExpectCnt(childProp.ExpectedCnt), p.QueryBlockOffset(), childProp) + bottomProj.SetSchema(expression.NewSchema(bottomProjSchemaCols...)) + bottomProj.SetChildren(childPlan) + p.SetChildren(bottomProj) + + if origChildProj, isChildProj := childPlan.(*PhysicalProjection); isChildProj { + refine4NeighbourProj(bottomProj, origChildProj) + } + refine4NeighbourProj(topProj, bottomProj) + + return topProj +} + +// TurnNominalSortIntoProj will turn nominal sort into two projections. This is to check if the scalar functions will +// overflow. 
+func TurnNominalSortIntoProj(p base.PhysicalPlan, onlyColumn bool, orderByItems []*util.ByItems) base.PhysicalPlan { + if onlyColumn { + return p.Children()[0] + } + + numOrderByItems := len(orderByItems) + childPlan := p.Children()[0] + + bottomProjSchemaCols := make([]*expression.Column, 0, len(childPlan.Schema().Columns)+numOrderByItems) + bottomProjExprs := make([]expression.Expression, 0, len(childPlan.Schema().Columns)+numOrderByItems) + for _, col := range childPlan.Schema().Columns { + newCol := col.Clone().(*expression.Column) + newCol.Index = childPlan.Schema().ColumnIndex(newCol) + bottomProjSchemaCols = append(bottomProjSchemaCols, newCol) + bottomProjExprs = append(bottomProjExprs, newCol) + } + + for _, item := range orderByItems { + itemExpr := item.Expr + if _, isScalarFunc := itemExpr.(*expression.ScalarFunction); !isScalarFunc { + continue + } + bottomProjExprs = append(bottomProjExprs, itemExpr) + newArg := &expression.Column{ + UniqueID: p.SCtx().GetSessionVars().AllocPlanColumnID(), + RetType: itemExpr.GetType(p.SCtx().GetExprCtx().GetEvalCtx()), + Index: len(bottomProjSchemaCols), + } + bottomProjSchemaCols = append(bottomProjSchemaCols, newArg) + } + + childProp := p.GetChildReqProps(0).CloneEssentialFields() + bottomProj := PhysicalProjection{ + Exprs: bottomProjExprs, + AvoidColumnEvaluator: false, + }.Init(p.SCtx(), childPlan.StatsInfo().ScaleByExpectCnt(childProp.ExpectedCnt), p.QueryBlockOffset(), childProp) + bottomProj.SetSchema(expression.NewSchema(bottomProjSchemaCols...)) + bottomProj.SetChildren(childPlan) + + topProjExprs := make([]expression.Expression, 0, childPlan.Schema().Len()) + for i := range childPlan.Schema().Columns { + col := childPlan.Schema().Columns[i].Clone().(*expression.Column) + col.Index = i + topProjExprs = append(topProjExprs, col) + } + topProj := PhysicalProjection{ + Exprs: topProjExprs, + AvoidColumnEvaluator: false, + }.Init(p.SCtx(), childPlan.StatsInfo().ScaleByExpectCnt(childProp.ExpectedCnt), p.QueryBlockOffset(), childProp) + topProj.SetSchema(childPlan.Schema().Clone()) + topProj.SetChildren(bottomProj) + + if origChildProj, isChildProj := childPlan.(*PhysicalProjection); isChildProj { + refine4NeighbourProj(bottomProj, origChildProj) + } + refine4NeighbourProj(topProj, bottomProj) + + return topProj +} diff --git a/pkg/planner/core/task.go b/pkg/planner/core/task.go index b926f5d9bb9fc..493c4e6b68145 100644 --- a/pkg/planner/core/task.go +++ b/pkg/planner/core/task.go @@ -2307,11 +2307,11 @@ func (p *PhysicalWindow) attach2TaskForMPP(mpp *MppTask) base.Task { columns := p.Schema().Clone().Columns[len(p.Schema().Columns)-len(p.WindowFuncDescs):] p.schema = expression.MergeSchema(mpp.Plan().Schema(), expression.NewSchema(columns...)) - failpoint.Inject("CheckMPPWindowSchemaLength", func() { + if _, _err_ := failpoint.Eval(_curpkg_("CheckMPPWindowSchemaLength")); _err_ == nil { if len(p.Schema().Columns) != len(mpp.Plan().Schema().Columns)+len(p.WindowFuncDescs) { panic("mpp physical window has incorrect schema length") } - }) + } return attachPlan2Task(p, mpp) } diff --git a/pkg/planner/core/task.go__failpoint_stash__ b/pkg/planner/core/task.go__failpoint_stash__ new file mode 100644 index 0000000000000..b926f5d9bb9fc --- /dev/null +++ b/pkg/planner/core/task.go__failpoint_stash__ @@ -0,0 +1,2473 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package core + +import ( + "math" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/expression/aggregation" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/cardinality" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/core/cost" + "github.com/pingcap/tidb/pkg/planner/core/operator/baseimpl" + "github.com/pingcap/tidb/pkg/planner/property" + "github.com/pingcap/tidb/pkg/planner/util" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/collate" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/paging" + "github.com/pingcap/tidb/pkg/util/plancodec" + "go.uber.org/zap" +) + +func attachPlan2Task(p base.PhysicalPlan, t base.Task) base.Task { + switch v := t.(type) { + case *CopTask: + if v.indexPlanFinished { + p.SetChildren(v.tablePlan) + v.tablePlan = p + } else { + p.SetChildren(v.indexPlan) + v.indexPlan = p + } + case *RootTask: + p.SetChildren(v.GetPlan()) + v.SetPlan(p) + case *MppTask: + p.SetChildren(v.p) + v.p = p + } + return t +} + +// finishIndexPlan means we no longer add plan to index plan, and compute the network cost for it. +func (t *CopTask) finishIndexPlan() { + if t.indexPlanFinished { + return + } + t.indexPlanFinished = true + // index merge case is specially handled for now. + // We need a elegant way to solve the stats of index merge in this case. + if t.tablePlan != nil && t.indexPlan != nil { + ts := t.tablePlan.(*PhysicalTableScan) + originStats := ts.StatsInfo() + ts.SetStats(t.indexPlan.StatsInfo()) + if originStats != nil { + // keep the original stats version + ts.StatsInfo().StatsVersion = originStats.StatsVersion + } + } +} + +func (t *CopTask) getStoreType() kv.StoreType { + if t.tablePlan == nil { + return kv.TiKV + } + tp := t.tablePlan + for len(tp.Children()) > 0 { + if len(tp.Children()) > 1 { + return kv.TiFlash + } + tp = tp.Children()[0] + } + if ts, ok := tp.(*PhysicalTableScan); ok { + return ts.StoreType + } + return kv.TiKV +} + +// Attach2Task implements PhysicalPlan interface. +func (p *basePhysicalPlan) Attach2Task(tasks ...base.Task) base.Task { + t := tasks[0].ConvertToRootTask(p.SCtx()) + return attachPlan2Task(p.self, t) +} + +// Attach2Task implements PhysicalPlan interface. +func (p *PhysicalUnionScan) Attach2Task(tasks ...base.Task) base.Task { + // We need to pull the projection under unionScan upon unionScan. + // Since the projection only prunes columns, it's ok the put it upon unionScan. + if sel, ok := tasks[0].Plan().(*PhysicalSelection); ok { + if pj, ok := sel.children[0].(*PhysicalProjection); ok { + // Convert unionScan->selection->projection to projection->unionScan->selection. + sel.SetChildren(pj.children...) 
+ p.SetChildren(sel) + p.SetStats(tasks[0].Plan().StatsInfo()) + rt, _ := tasks[0].(*RootTask) + rt.SetPlan(p) + pj.SetChildren(p) + return pj.Attach2Task(tasks...) + } + } + if pj, ok := tasks[0].Plan().(*PhysicalProjection); ok { + // Convert unionScan->projection to projection->unionScan, because unionScan can't handle projection as its children. + p.SetChildren(pj.children...) + p.SetStats(tasks[0].Plan().StatsInfo()) + rt, _ := tasks[0].(*RootTask) + rt.SetPlan(pj.children[0]) + pj.SetChildren(p) + return pj.Attach2Task(p.basePhysicalPlan.Attach2Task(tasks...)) + } + p.SetStats(tasks[0].Plan().StatsInfo()) + return p.basePhysicalPlan.Attach2Task(tasks...) +} + +// Attach2Task implements PhysicalPlan interface. +func (p *PhysicalApply) Attach2Task(tasks ...base.Task) base.Task { + lTask := tasks[0].ConvertToRootTask(p.SCtx()) + rTask := tasks[1].ConvertToRootTask(p.SCtx()) + p.SetChildren(lTask.Plan(), rTask.Plan()) + p.schema = BuildPhysicalJoinSchema(p.JoinType, p) + t := &RootTask{} + t.SetPlan(p) + return t +} + +// Attach2Task implements PhysicalPlan interface. +func (p *PhysicalIndexMergeJoin) Attach2Task(tasks ...base.Task) base.Task { + outerTask := tasks[1-p.InnerChildIdx].ConvertToRootTask(p.SCtx()) + if p.InnerChildIdx == 1 { + p.SetChildren(outerTask.Plan(), p.innerPlan) + } else { + p.SetChildren(p.innerPlan, outerTask.Plan()) + } + t := &RootTask{} + t.SetPlan(p) + return t +} + +// Attach2Task implements PhysicalPlan interface. +func (p *PhysicalIndexHashJoin) Attach2Task(tasks ...base.Task) base.Task { + outerTask := tasks[1-p.InnerChildIdx].ConvertToRootTask(p.SCtx()) + if p.InnerChildIdx == 1 { + p.SetChildren(outerTask.Plan(), p.innerPlan) + } else { + p.SetChildren(p.innerPlan, outerTask.Plan()) + } + t := &RootTask{} + t.SetPlan(p) + return t +} + +// Attach2Task implements PhysicalPlan interface. +func (p *PhysicalIndexJoin) Attach2Task(tasks ...base.Task) base.Task { + outerTask := tasks[1-p.InnerChildIdx].ConvertToRootTask(p.SCtx()) + if p.InnerChildIdx == 1 { + p.SetChildren(outerTask.Plan(), p.innerPlan) + } else { + p.SetChildren(p.innerPlan, outerTask.Plan()) + } + t := &RootTask{} + t.SetPlan(p) + return t +} + +// RowSize for cost model ver2 is simplified, always use this function to calculate row size. +func getAvgRowSize(stats *property.StatsInfo, cols []*expression.Column) (size float64) { + if stats.HistColl != nil { + size = cardinality.GetAvgRowSizeDataInDiskByRows(stats.HistColl, cols) + } else { + // Estimate using just the type info. + for _, col := range cols { + size += float64(chunk.EstimateTypeWidth(col.GetStaticType())) + } + } + return +} + +// Attach2Task implements PhysicalPlan interface. +func (p *PhysicalHashJoin) Attach2Task(tasks ...base.Task) base.Task { + if p.storeTp == kv.TiFlash { + return p.attach2TaskForTiFlash(tasks...) 
+ } + lTask := tasks[0].ConvertToRootTask(p.SCtx()) + rTask := tasks[1].ConvertToRootTask(p.SCtx()) + p.SetChildren(lTask.Plan(), rTask.Plan()) + task := &RootTask{} + task.SetPlan(p) + return task +} + +// TiDB only require that the types fall into the same catalog but TiFlash require the type to be exactly the same, so +// need to check if the conversion is a must +func needConvert(tp *types.FieldType, rtp *types.FieldType) bool { + // all the string type are mapped to the same type in TiFlash, so + // do not need convert for string types + if types.IsString(tp.GetType()) && types.IsString(rtp.GetType()) { + return false + } + if tp.GetType() != rtp.GetType() { + return true + } + if tp.GetType() != mysql.TypeNewDecimal { + return false + } + if tp.GetDecimal() != rtp.GetDecimal() { + return true + } + // for decimal type, TiFlash have 4 different impl based on the required precision + if tp.GetFlen() >= 0 && tp.GetFlen() <= 9 && rtp.GetFlen() >= 0 && rtp.GetFlen() <= 9 { + return false + } + if tp.GetFlen() > 9 && tp.GetFlen() <= 18 && rtp.GetFlen() > 9 && rtp.GetFlen() <= 18 { + return false + } + if tp.GetFlen() > 18 && tp.GetFlen() <= 38 && rtp.GetFlen() > 18 && rtp.GetFlen() <= 38 { + return false + } + if tp.GetFlen() > 38 && tp.GetFlen() <= 65 && rtp.GetFlen() > 38 && rtp.GetFlen() <= 65 { + return false + } + return true +} + +func negotiateCommonType(lType, rType *types.FieldType) (*types.FieldType, bool, bool) { + commonType := types.AggFieldType([]*types.FieldType{lType, rType}) + if commonType.GetType() == mysql.TypeNewDecimal { + lExtend := 0 + rExtend := 0 + cDec := rType.GetDecimal() + if lType.GetDecimal() < rType.GetDecimal() { + lExtend = rType.GetDecimal() - lType.GetDecimal() + } else if lType.GetDecimal() > rType.GetDecimal() { + rExtend = lType.GetDecimal() - rType.GetDecimal() + cDec = lType.GetDecimal() + } + lLen, rLen := lType.GetFlen()+lExtend, rType.GetFlen()+rExtend + cLen := max(lLen, rLen) + commonType.SetDecimalUnderLimit(cDec) + commonType.SetFlenUnderLimit(cLen) + } else if needConvert(lType, commonType) || needConvert(rType, commonType) { + if mysql.IsIntegerType(commonType.GetType()) { + // If the target type is int, both TiFlash and Mysql only support cast to Int64 + // so we need to promote the type to Int64 + commonType.SetType(mysql.TypeLonglong) + commonType.SetFlen(mysql.MaxIntWidth) + } + } + return commonType, needConvert(lType, commonType), needConvert(rType, commonType) +} + +func getProj(ctx base.PlanContext, p base.PhysicalPlan) *PhysicalProjection { + proj := PhysicalProjection{ + Exprs: make([]expression.Expression, 0, len(p.Schema().Columns)), + }.Init(ctx, p.StatsInfo(), p.QueryBlockOffset()) + for _, col := range p.Schema().Columns { + proj.Exprs = append(proj.Exprs, col) + } + proj.SetSchema(p.Schema().Clone()) + proj.SetChildren(p) + return proj +} + +func appendExpr(p *PhysicalProjection, expr expression.Expression) *expression.Column { + p.Exprs = append(p.Exprs, expr) + + col := &expression.Column{ + UniqueID: p.SCtx().GetSessionVars().AllocPlanColumnID(), + RetType: expr.GetType(p.SCtx().GetExprCtx().GetEvalCtx()), + } + col.SetCoercibility(expr.Coercibility()) + p.schema.Append(col) + return col +} + +// TiFlash join require that partition key has exactly the same type, while TiDB only guarantee the partition key is the same catalog, +// so if the partition key type is not exactly the same, we need add a projection below the join or exchanger if exists. 
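The precision ranges hard-coded in needConvert correspond to the four decimal widths TiFlash implements (roughly Decimal32/64/128/256). A small helper makes the rule easier to read; this is an illustrative sketch that mirrors the checks above, not an API exported by the planner, and it assumes both sides already share the same scale.

package main

import "fmt"

// decimalBucket maps a decimal precision (flen) to the storage width it falls
// into. Precisions in the same bucket share a physical representation, so no
// cast is needed between them; crossing a bucket boundary requires one.
func decimalBucket(flen int) int {
	switch {
	case flen < 0 || flen > 65:
		return -1 // unspecified or out of range for MySQL decimals
	case flen <= 9:
		return 32
	case flen <= 18:
		return 64
	case flen <= 38:
		return 128
	default:
		return 256
	}
}

// needDecimalConvert reports whether two decimal columns with the same scale
// need an explicit cast before they can be used as matching TiFlash join keys.
func needDecimalConvert(lFlen, rFlen int) bool {
	return decimalBucket(lFlen) != decimalBucket(rFlen)
}

func main() {
	fmt.Println(needDecimalConvert(8, 9))   // false: same bucket
	fmt.Println(needDecimalConvert(9, 10))  // true: 32-bit vs 64-bit bucket
	fmt.Println(needDecimalConvert(20, 38)) // false: same 128-bit bucket
}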
+func (p *PhysicalHashJoin) convertPartitionKeysIfNeed(lTask, rTask *MppTask) (*MppTask, *MppTask) { + lp := lTask.p + if _, ok := lp.(*PhysicalExchangeReceiver); ok { + lp = lp.Children()[0].Children()[0] + } + rp := rTask.p + if _, ok := rp.(*PhysicalExchangeReceiver); ok { + rp = rp.Children()[0].Children()[0] + } + // to mark if any partition key needs to convert + lMask := make([]bool, len(lTask.hashCols)) + rMask := make([]bool, len(rTask.hashCols)) + cTypes := make([]*types.FieldType, len(lTask.hashCols)) + lChanged := false + rChanged := false + for i := range lTask.hashCols { + lKey := lTask.hashCols[i] + rKey := rTask.hashCols[i] + cType, lConvert, rConvert := negotiateCommonType(lKey.Col.RetType, rKey.Col.RetType) + if lConvert { + lMask[i] = true + cTypes[i] = cType + lChanged = true + } + if rConvert { + rMask[i] = true + cTypes[i] = cType + rChanged = true + } + } + if !lChanged && !rChanged { + return lTask, rTask + } + var lProj, rProj *PhysicalProjection + if lChanged { + lProj = getProj(p.SCtx(), lp) + lp = lProj + } + if rChanged { + rProj = getProj(p.SCtx(), rp) + rp = rProj + } + + lPartKeys := make([]*property.MPPPartitionColumn, 0, len(rTask.hashCols)) + rPartKeys := make([]*property.MPPPartitionColumn, 0, len(lTask.hashCols)) + for i := range lTask.hashCols { + lKey := lTask.hashCols[i] + rKey := rTask.hashCols[i] + if lMask[i] { + cType := cTypes[i].Clone() + cType.SetFlag(lKey.Col.RetType.GetFlag()) + lCast := expression.BuildCastFunction(p.SCtx().GetExprCtx(), lKey.Col, cType) + lKey = &property.MPPPartitionColumn{Col: appendExpr(lProj, lCast), CollateID: lKey.CollateID} + } + if rMask[i] { + cType := cTypes[i].Clone() + cType.SetFlag(rKey.Col.RetType.GetFlag()) + rCast := expression.BuildCastFunction(p.SCtx().GetExprCtx(), rKey.Col, cType) + rKey = &property.MPPPartitionColumn{Col: appendExpr(rProj, rCast), CollateID: rKey.CollateID} + } + lPartKeys = append(lPartKeys, lKey) + rPartKeys = append(rPartKeys, rKey) + } + // if left or right child changes, we need to add enforcer. + if lChanged { + nlTask := lTask.Copy().(*MppTask) + nlTask.p = lProj + nlTask = nlTask.enforceExchanger(&property.PhysicalProperty{ + TaskTp: property.MppTaskType, + MPPPartitionTp: property.HashType, + MPPPartitionCols: lPartKeys, + }) + lTask = nlTask + } + if rChanged { + nrTask := rTask.Copy().(*MppTask) + nrTask.p = rProj + nrTask = nrTask.enforceExchanger(&property.PhysicalProperty{ + TaskTp: property.MppTaskType, + MPPPartitionTp: property.HashType, + MPPPartitionCols: rPartKeys, + }) + rTask = nrTask + } + return lTask, rTask +} + +func (p *PhysicalHashJoin) attach2TaskForMpp(tasks ...base.Task) base.Task { + lTask, lok := tasks[0].(*MppTask) + rTask, rok := tasks[1].(*MppTask) + if !lok || !rok { + return base.InvalidTask + } + if p.mppShuffleJoin { + // protection check is case of some bugs + if len(lTask.hashCols) != len(rTask.hashCols) || len(lTask.hashCols) == 0 { + return base.InvalidTask + } + lTask, rTask = p.convertPartitionKeysIfNeed(lTask, rTask) + } + p.SetChildren(lTask.Plan(), rTask.Plan()) + // outer task is the task that will pass its MPPPartitionType to the join result + // for broadcast inner join, it should be the non-broadcast side, since broadcast side is always the build side, so + // just use the probe side is ok. 
+ // for hash inner join, both side is ok, by default, we use the probe side + // for outer join, it should always be the outer side of the join + // for semi join, it should be the left side(the same as left out join) + outerTaskIndex := 1 - p.InnerChildIdx + if p.JoinType != InnerJoin { + if p.JoinType == RightOuterJoin { + outerTaskIndex = 1 + } else { + outerTaskIndex = 0 + } + } + // can not use the task from tasks because it maybe updated. + outerTask := lTask + if outerTaskIndex == 1 { + outerTask = rTask + } + task := &MppTask{ + p: p, + partTp: outerTask.partTp, + hashCols: outerTask.hashCols, + } + // Current TiFlash doesn't support receive Join executors' schema info directly from TiDB. + // Instead, it calculates Join executors' output schema using algorithm like BuildPhysicalJoinSchema which + // produces full semantic schema. + // Thus, the column prune optimization achievements will be abandoned here. + // To avoid the performance issue, add a projection here above the Join operator to prune useless columns explicitly. + // TODO(hyb): transfer Join executors' schema to TiFlash through DagRequest, and use it directly in TiFlash. + defaultSchema := BuildPhysicalJoinSchema(p.JoinType, p) + hashColArray := make([]*expression.Column, 0, len(task.hashCols)) + // For task.hashCols, these columns may not be contained in pruned columns: + // select A.id from A join B on A.id = B.id; Suppose B is probe side, and it's hash inner join. + // After column prune, the output schema of A join B will be A.id only; while the task's hashCols will be B.id. + // To make matters worse, the hashCols may be used to check if extra cast projection needs to be added, then the newly + // added projection will expect B.id as input schema. So make sure hashCols are included in task.p's schema. 
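The schema patch-up described in the comments above reduces to a small set operation: any hash partition column that column pruning removed from the join's output has to be appended back, because the exchange sender still hashes on it. The stand-alone sketch below uses plain column names instead of TiDB's column objects; ensureHashColsInSchema is an invented name.

package main

import "fmt"

// ensureHashColsInSchema returns the pruned output columns of a join extended
// with any hash partition columns that pruning removed, mirroring the extra
// projection the planner attaches above an MPP hash join.
func ensureHashColsInSchema(pruned, hashCols []string) []string {
	present := make(map[string]bool, len(pruned))
	out := append([]string(nil), pruned...)
	for _, c := range pruned {
		present[c] = true
	}
	for _, h := range hashCols {
		if !present[h] {
			out = append(out, h) // re-add so the exchanger can hash on it
			present[h] = true
		}
	}
	return out
}

func main() {
	// SELECT a.id FROM a JOIN b ON a.id = b.id, with b as the probe side:
	// pruning keeps only a.id, but the task still hashes on b.id.
	fmt.Println(ensureHashColsInSchema([]string{"a.id"}, []string{"b.id"})) // [a.id b.id]
}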
+ // TODO: planner should takes the hashCols attribute into consideration when perform column pruning; Or provide mechanism + // to constraint hashCols are always chosen inside Join's pruned schema + for _, hashCol := range task.hashCols { + hashColArray = append(hashColArray, hashCol.Col) + } + if p.schema.Len() < defaultSchema.Len() { + if p.schema.Len() > 0 { + proj := PhysicalProjection{ + Exprs: expression.Column2Exprs(p.schema.Columns), + }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset()) + + proj.SetSchema(p.Schema().Clone()) + for _, hashCol := range hashColArray { + if !proj.Schema().Contains(hashCol) && defaultSchema.Contains(hashCol) { + joinCol := defaultSchema.Columns[defaultSchema.ColumnIndex(hashCol)] + proj.Exprs = append(proj.Exprs, joinCol) + proj.Schema().Append(joinCol.Clone().(*expression.Column)) + } + } + attachPlan2Task(proj, task) + } else { + if len(hashColArray) == 0 { + constOne := expression.NewOne() + expr := make([]expression.Expression, 0, 1) + expr = append(expr, constOne) + proj := PhysicalProjection{ + Exprs: expr, + }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset()) + + proj.schema = expression.NewSchema(&expression.Column{ + UniqueID: proj.SCtx().GetSessionVars().AllocPlanColumnID(), + RetType: constOne.GetType(p.SCtx().GetExprCtx().GetEvalCtx()), + }) + attachPlan2Task(proj, task) + } else { + proj := PhysicalProjection{ + Exprs: make([]expression.Expression, 0, len(hashColArray)), + }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset()) + + clonedHashColArray := make([]*expression.Column, 0, len(task.hashCols)) + for _, hashCol := range hashColArray { + if defaultSchema.Contains(hashCol) { + joinCol := defaultSchema.Columns[defaultSchema.ColumnIndex(hashCol)] + proj.Exprs = append(proj.Exprs, joinCol) + clonedHashColArray = append(clonedHashColArray, joinCol.Clone().(*expression.Column)) + } + } + + proj.SetSchema(expression.NewSchema(clonedHashColArray...)) + attachPlan2Task(proj, task) + } + } + } + p.schema = defaultSchema + return task +} + +func (p *PhysicalHashJoin) attach2TaskForTiFlash(tasks ...base.Task) base.Task { + lTask, lok := tasks[0].(*CopTask) + rTask, rok := tasks[1].(*CopTask) + if !lok || !rok { + return p.attach2TaskForMpp(tasks...) + } + p.SetChildren(lTask.Plan(), rTask.Plan()) + p.schema = BuildPhysicalJoinSchema(p.JoinType, p) + if !lTask.indexPlanFinished { + lTask.finishIndexPlan() + } + if !rTask.indexPlanFinished { + rTask.finishIndexPlan() + } + + task := &CopTask{ + tblColHists: rTask.tblColHists, + indexPlanFinished: true, + tablePlan: p, + } + return task +} + +// Attach2Task implements PhysicalPlan interface. +func (p *PhysicalMergeJoin) Attach2Task(tasks ...base.Task) base.Task { + lTask := tasks[0].ConvertToRootTask(p.SCtx()) + rTask := tasks[1].ConvertToRootTask(p.SCtx()) + p.SetChildren(lTask.Plan(), rTask.Plan()) + t := &RootTask{} + t.SetPlan(p) + return t +} + +func buildIndexLookUpTask(ctx base.PlanContext, t *CopTask) *RootTask { + newTask := &RootTask{} + p := PhysicalIndexLookUpReader{ + tablePlan: t.tablePlan, + indexPlan: t.indexPlan, + ExtraHandleCol: t.extraHandleCol, + CommonHandleCols: t.commonHandleCols, + expectedCnt: t.expectCnt, + keepOrder: t.keepOrder, + }.Init(ctx, t.tablePlan.QueryBlockOffset()) + p.PlanPartInfo = t.physPlanPartInfo + setTableScanToTableRowIDScan(p.tablePlan) + p.SetStats(t.tablePlan.StatsInfo()) + // Do not inject the extra Projection even if t.needExtraProj is set, or the schema between the phase-1 agg and + // the final agg would be broken. 
Please reference comments for the similar logic in + // (*copTask).convertToRootTaskImpl() for the PhysicalTableReader case. + // We need to refactor these logics. + aggPushedDown := false + switch p.tablePlan.(type) { + case *PhysicalHashAgg, *PhysicalStreamAgg: + aggPushedDown = true + } + + if t.needExtraProj && !aggPushedDown { + schema := t.originSchema + proj := PhysicalProjection{Exprs: expression.Column2Exprs(schema.Columns)}.Init(ctx, p.StatsInfo(), t.tablePlan.QueryBlockOffset(), nil) + proj.SetSchema(schema) + proj.SetChildren(p) + newTask.SetPlan(proj) + } else { + newTask.SetPlan(p) + } + return newTask +} + +func extractRows(p base.PhysicalPlan) float64 { + f := float64(0) + for _, c := range p.Children() { + if len(c.Children()) != 0 { + f += extractRows(c) + } else { + f += c.StatsInfo().RowCount + } + } + return f +} + +// calcPagingCost calculates the cost for paging processing which may increase the seekCnt and reduce scanned rows. +func calcPagingCost(ctx base.PlanContext, indexPlan base.PhysicalPlan, expectCnt uint64) float64 { + sessVars := ctx.GetSessionVars() + indexRows := indexPlan.StatsCount() + sourceRows := extractRows(indexPlan) + // with paging, the scanned rows is always less than or equal to source rows. + if uint64(sourceRows) < expectCnt { + expectCnt = uint64(sourceRows) + } + seekCnt := paging.CalculateSeekCnt(expectCnt) + indexSelectivity := float64(1) + if sourceRows > indexRows { + indexSelectivity = indexRows / sourceRows + } + pagingCst := seekCnt*sessVars.GetSeekFactor(nil) + float64(expectCnt)*sessVars.GetCPUFactor() + pagingCst *= indexSelectivity + + // we want the diff between idxCst and pagingCst here, + // however, the idxCst does not contain seekFactor, so a seekFactor needs to be removed + return math.Max(pagingCst-sessVars.GetSeekFactor(nil), 0) +} + +func (t *CopTask) handleRootTaskConds(ctx base.PlanContext, newTask *RootTask) { + if len(t.rootTaskConds) > 0 { + selectivity, _, err := cardinality.Selectivity(ctx, t.tblColHists, t.rootTaskConds, nil) + if err != nil { + logutil.BgLogger().Debug("calculate selectivity failed, use selection factor", zap.Error(err)) + selectivity = cost.SelectionFactor + } + sel := PhysicalSelection{Conditions: t.rootTaskConds}.Init(ctx, newTask.GetPlan().StatsInfo().Scale(selectivity), newTask.GetPlan().QueryBlockOffset()) + sel.fromDataSource = true + sel.SetChildren(newTask.GetPlan()) + newTask.SetPlan(sel) + } +} + +// setTableScanToTableRowIDScan is to update the isChildOfIndexLookUp attribute of PhysicalTableScan child +func setTableScanToTableRowIDScan(p base.PhysicalPlan) { + if ts, ok := p.(*PhysicalTableScan); ok { + ts.SetIsChildOfIndexLookUp(true) + } else { + for _, child := range p.Children() { + setTableScanToTableRowIDScan(child) + } + } +} + +// Attach2Task attach limit to different cases. +// For Normal Index Lookup +// 1: attach the limit to table side or index side of normal index lookup cop task. (normal case, old code, no more +// explanation here) +// +// For Index Merge: +// 2: attach the limit to **table** side for index merge intersection case, cause intersection will invalidate the +// fetched limit+offset rows from each partial index plan, you can not decide how many you want in advance for partial +// index path, actually. After we sink limit to table side, we still need an upper root limit to control the real limit +// count admission. 
+// +// 3: attach the limit to **index** side for index merge union case, because each index plan will output the fetched +// limit+offset (* N path) rows, you still need an embedded pushedLimit inside index merge reader to cut it down. +// +// 4: attach the limit to the TOP of root index merge operator if there is some root condition exists for index merge +// intersection/union case. +func (p *PhysicalLimit) Attach2Task(tasks ...base.Task) base.Task { + t := tasks[0].Copy() + newPartitionBy := make([]property.SortItem, 0, len(p.GetPartitionBy())) + for _, expr := range p.GetPartitionBy() { + newPartitionBy = append(newPartitionBy, expr.Clone()) + } + + sunk := false + if cop, ok := t.(*CopTask); ok { + suspendLimitAboveTablePlan := func() { + newCount := p.Offset + p.Count + childProfile := cop.tablePlan.StatsInfo() + // but "regionNum" is unknown since the copTask can be a double read, so we ignore it now. + stats := util.DeriveLimitStats(childProfile, float64(newCount)) + pushedDownLimit := PhysicalLimit{PartitionBy: newPartitionBy, Count: newCount}.Init(p.SCtx(), stats, p.QueryBlockOffset()) + pushedDownLimit.SetChildren(cop.tablePlan) + cop.tablePlan = pushedDownLimit + // Don't use clone() so that Limit and its children share the same schema. Otherwise, the virtual generated column may not be resolved right. + pushedDownLimit.SetSchema(pushedDownLimit.children[0].Schema()) + t = cop.ConvertToRootTask(p.SCtx()) + } + if len(cop.idxMergePartPlans) == 0 { + // For double read which requires order being kept, the limit cannot be pushed down to the table side, + // because handles would be reordered before being sent to table scan. + if (!cop.keepOrder || !cop.indexPlanFinished || cop.indexPlan == nil) && len(cop.rootTaskConds) == 0 { + // When limit is pushed down, we should remove its offset. + newCount := p.Offset + p.Count + childProfile := cop.Plan().StatsInfo() + // Strictly speaking, for the row count of stats, we should multiply newCount with "regionNum", + // but "regionNum" is unknown since the copTask can be a double read, so we ignore it now. + stats := util.DeriveLimitStats(childProfile, float64(newCount)) + pushedDownLimit := PhysicalLimit{PartitionBy: newPartitionBy, Count: newCount}.Init(p.SCtx(), stats, p.QueryBlockOffset()) + cop = attachPlan2Task(pushedDownLimit, cop).(*CopTask) + // Don't use clone() so that Limit and its children share the same schema. Otherwise the virtual generated column may not be resolved right. + pushedDownLimit.SetSchema(pushedDownLimit.children[0].Schema()) + } + t = cop.ConvertToRootTask(p.SCtx()) + sunk = p.sinkIntoIndexLookUp(t) + } else if !cop.idxMergeIsIntersection { + // We only support push part of the order prop down to index merge build case. + if len(cop.rootTaskConds) == 0 { + // For double read which requires order being kept, the limit cannot be pushed down to the table side, + // because handles would be reordered before being sent to table scan. + if cop.indexPlanFinished && !cop.keepOrder { + // when the index plan is finished and index plan is not ordered, sink the limit to the index merge table side. + suspendLimitAboveTablePlan() + } else if !cop.indexPlanFinished { + // cop.indexPlanFinished = false indicates the table side is a pure table-scan, sink the limit to the index merge index side. 
+ newCount := p.Offset + p.Count + limitChildren := make([]base.PhysicalPlan, 0, len(cop.idxMergePartPlans)) + for _, partialScan := range cop.idxMergePartPlans { + childProfile := partialScan.StatsInfo() + stats := util.DeriveLimitStats(childProfile, float64(newCount)) + pushedDownLimit := PhysicalLimit{PartitionBy: newPartitionBy, Count: newCount}.Init(p.SCtx(), stats, p.QueryBlockOffset()) + pushedDownLimit.SetChildren(partialScan) + pushedDownLimit.SetSchema(pushedDownLimit.children[0].Schema()) + limitChildren = append(limitChildren, pushedDownLimit) + } + cop.idxMergePartPlans = limitChildren + t = cop.ConvertToRootTask(p.SCtx()) + sunk = p.sinkIntoIndexMerge(t) + } else { + // when there are some limitations, just sink the limit upon the index merge reader. + t = cop.ConvertToRootTask(p.SCtx()) + sunk = p.sinkIntoIndexMerge(t) + } + } else { + // when there are some root conditions, just sink the limit upon the index merge reader. + t = cop.ConvertToRootTask(p.SCtx()) + sunk = p.sinkIntoIndexMerge(t) + } + } else if cop.idxMergeIsIntersection { + // In the index merge with intersection case, only the limit can be pushed down to the index merge table side. + // Note Difference: + // IndexMerge.PushedLimit is applied before table scan fetching, limiting the indexPartialPlan rows returned (it maybe ordered if orderBy items not empty) + // TableProbeSide sink limit is applied on the top of table plan, which will quickly shut down the both fetch-back and read-back process. + if len(cop.rootTaskConds) == 0 { + if cop.indexPlanFinished { + // indicates the table side is not a pure table-scan, so we could only append the limit upon the table plan. + suspendLimitAboveTablePlan() + } else { + t = cop.ConvertToRootTask(p.SCtx()) + sunk = p.sinkIntoIndexMerge(t) + } + } else { + // Otherwise, suspend the limit out of index merge reader. + t = cop.ConvertToRootTask(p.SCtx()) + sunk = p.sinkIntoIndexMerge(t) + } + } else { + // Whatever the remained case is, we directly convert to it to root task. + t = cop.ConvertToRootTask(p.SCtx()) + } + } else if mpp, ok := t.(*MppTask); ok { + newCount := p.Offset + p.Count + childProfile := mpp.Plan().StatsInfo() + stats := util.DeriveLimitStats(childProfile, float64(newCount)) + pushedDownLimit := PhysicalLimit{Count: newCount, PartitionBy: newPartitionBy}.Init(p.SCtx(), stats, p.QueryBlockOffset()) + mpp = attachPlan2Task(pushedDownLimit, mpp).(*MppTask) + pushedDownLimit.SetSchema(pushedDownLimit.children[0].Schema()) + t = mpp.ConvertToRootTask(p.SCtx()) + } + if sunk { + return t + } + // Skip limit with partition on the root. This is a derived topN and window function + // will take care of the filter. + if len(p.GetPartitionBy()) > 0 { + return t + } + return attachPlan2Task(p, t) +} + +func (p *PhysicalLimit) sinkIntoIndexLookUp(t base.Task) bool { + root := t.(*RootTask) + reader, isDoubleRead := root.GetPlan().(*PhysicalIndexLookUpReader) + proj, isProj := root.GetPlan().(*PhysicalProjection) + if !isDoubleRead && !isProj { + return false + } + if isProj { + reader, isDoubleRead = proj.Children()[0].(*PhysicalIndexLookUpReader) + if !isDoubleRead { + return false + } + } + + // We can sink Limit into IndexLookUpReader only if tablePlan contains no Selection. + ts, isTableScan := reader.tablePlan.(*PhysicalTableScan) + if !isTableScan { + return false + } + + // If this happens, some Projection Operator must be inlined into this Limit. 
(issues/14428) + // For example, if the original plan is `IndexLookUp(col1, col2) -> Limit(col1, col2) -> Project(col1)`, + // then after inlining the Project, it will be `IndexLookUp(col1, col2) -> Limit(col1)` here. + // If the Limit is sunk into the IndexLookUp, the IndexLookUp's schema needs to be updated as well, + // So we add an extra projection to solve the problem. + if p.Schema().Len() != reader.Schema().Len() { + extraProj := PhysicalProjection{ + Exprs: expression.Column2Exprs(p.schema.Columns), + }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset(), nil) + extraProj.SetSchema(p.schema) + // If the root.p is already a Projection. We left the optimization for the later Projection Elimination. + extraProj.SetChildren(root.GetPlan()) + root.SetPlan(extraProj) + } + + reader.PushedLimit = &PushedDownLimit{ + Offset: p.Offset, + Count: p.Count, + } + originStats := ts.StatsInfo() + ts.SetStats(p.StatsInfo()) + if originStats != nil { + // keep the original stats version + ts.StatsInfo().StatsVersion = originStats.StatsVersion + } + reader.SetStats(p.StatsInfo()) + if isProj { + proj.SetStats(p.StatsInfo()) + } + return true +} + +func (p *PhysicalLimit) sinkIntoIndexMerge(t base.Task) bool { + root := t.(*RootTask) + imReader, isIm := root.GetPlan().(*PhysicalIndexMergeReader) + proj, isProj := root.GetPlan().(*PhysicalProjection) + if !isIm && !isProj { + return false + } + if isProj { + imReader, isIm = proj.Children()[0].(*PhysicalIndexMergeReader) + if !isIm { + return false + } + } + ts, ok := imReader.tablePlan.(*PhysicalTableScan) + if !ok { + return false + } + imReader.PushedLimit = &PushedDownLimit{ + Count: p.Count, + Offset: p.Offset, + } + // since ts.statsInfo.rowcount may dramatically smaller than limit.statsInfo. + // like limit: rowcount=1 + // ts: rowcount=0.0025 + originStats := ts.StatsInfo() + if originStats != nil { + // keep the original stats version + ts.StatsInfo().StatsVersion = originStats.StatsVersion + if originStats.RowCount < p.StatsInfo().RowCount { + ts.StatsInfo().RowCount = originStats.RowCount + } + } + needProj := p.schema.Len() != root.GetPlan().Schema().Len() + if !needProj { + for i := 0; i < p.schema.Len(); i++ { + if !p.schema.Columns[i].EqualColumn(root.GetPlan().Schema().Columns[i]) { + needProj = true + break + } + } + } + if needProj { + extraProj := PhysicalProjection{ + Exprs: expression.Column2Exprs(p.schema.Columns), + }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset(), nil) + extraProj.SetSchema(p.schema) + // If the root.p is already a Projection. We left the optimization for the later Projection Elimination. + extraProj.SetChildren(root.GetPlan()) + root.SetPlan(extraProj) + } + return true +} + +// Attach2Task implements PhysicalPlan interface. +func (p *PhysicalSort) Attach2Task(tasks ...base.Task) base.Task { + t := tasks[0].Copy() + t = attachPlan2Task(p, t) + return t +} + +// Attach2Task implements PhysicalPlan interface. 
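The PushedDownLimit bookkeeping above follows the usual distributed-limit rule: each storage-side path may be asked for at most offset+count rows, and only the root applies the offset after merging. The sketch below works that arithmetic through on fabricated per-task data; limitWithOffset is an invented helper, not planner code.

package main

import (
	"fmt"
	"sort"
)

// limitWithOffset emulates LIMIT offset, count over several cop-task outputs.
// Each task must return its first offset+count rows, because when order
// matters any task might contribute rows that survive the final cut.
func limitWithOffset(tasks [][]int, offset, count uint64) []int {
	pushed := offset + count // what is pushed below the reader
	merged := make([]int, 0)
	for _, rows := range tasks {
		sort.Ints(rows)
		if uint64(len(rows)) > pushed {
			rows = rows[:pushed]
		}
		merged = append(merged, rows...)
	}
	sort.Ints(merged)
	if uint64(len(merged)) <= offset {
		return nil
	}
	merged = merged[offset:] // only the root applies the offset
	if uint64(len(merged)) > count {
		merged = merged[:count]
	}
	return merged
}

func main() {
	tasks := [][]int{{9, 1, 5}, {2, 8}, {3, 7, 4}}
	fmt.Println(limitWithOffset(tasks, 2, 3)) // [3 4 5]
}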
+func (p *NominalSort) Attach2Task(tasks ...base.Task) base.Task { + if p.OnlyColumn { + return tasks[0] + } + t := tasks[0].Copy() + t = attachPlan2Task(p, t) + return t +} + +func (p *PhysicalTopN) getPushedDownTopN(childPlan base.PhysicalPlan) *PhysicalTopN { + newByItems := make([]*util.ByItems, 0, len(p.ByItems)) + for _, expr := range p.ByItems { + newByItems = append(newByItems, expr.Clone()) + } + newPartitionBy := make([]property.SortItem, 0, len(p.GetPartitionBy())) + for _, expr := range p.GetPartitionBy() { + newPartitionBy = append(newPartitionBy, expr.Clone()) + } + newCount := p.Offset + p.Count + childProfile := childPlan.StatsInfo() + // Strictly speaking, for the row count of pushed down TopN, we should multiply newCount with "regionNum", + // but "regionNum" is unknown since the copTask can be a double read, so we ignore it now. + stats := util.DeriveLimitStats(childProfile, float64(newCount)) + topN := PhysicalTopN{ + ByItems: newByItems, + PartitionBy: newPartitionBy, + Count: newCount, + }.Init(p.SCtx(), stats, p.QueryBlockOffset(), p.GetChildReqProps(0)) + topN.SetChildren(childPlan) + return topN +} + +// canPushToIndexPlan checks if this TopN can be pushed to the index side of copTask. +// It can be pushed to the index side when all columns used by ByItems are available from the index side and there's no prefix index column. +func (*PhysicalTopN) canPushToIndexPlan(indexPlan base.PhysicalPlan, byItemCols []*expression.Column) bool { + // If we call canPushToIndexPlan and there's no index plan, we should go into the index merge case. + // Index merge case is specially handled for now. So we directly return false here. + // So we directly return false. + if indexPlan == nil { + return false + } + schema := indexPlan.Schema() + for _, col := range byItemCols { + pos := schema.ColumnIndex(col) + if pos == -1 { + return false + } + if schema.Columns[pos].IsPrefix { + return false + } + } + return true +} + +// canExpressionConvertedToPB checks whether each of the the expression in TopN can be converted to pb. +func (p *PhysicalTopN) canExpressionConvertedToPB(storeTp kv.StoreType) bool { + exprs := make([]expression.Expression, 0, len(p.ByItems)) + for _, item := range p.ByItems { + exprs = append(exprs, item.Expr) + } + return expression.CanExprsPushDown(GetPushDownCtx(p.SCtx()), exprs, storeTp) +} + +// containVirtualColumn checks whether TopN.ByItems contains virtual generated columns. +func (p *PhysicalTopN) containVirtualColumn(tCols []*expression.Column) bool { + tColSet := make(map[int64]struct{}, len(tCols)) + for _, tCol := range tCols { + if tCol.ID > 0 && tCol.VirtualExpr != nil { + tColSet[tCol.ID] = struct{}{} + } + } + for _, by := range p.ByItems { + cols := expression.ExtractColumns(by.Expr) + for _, col := range cols { + if _, ok := tColSet[col.ID]; ok { + // A column with ID > 0 indicates that the column can be resolved by data source. + return true + } + } + } + return false +} + +// canPushDownToTiKV checks whether this topN can be pushed down to TiKV. 
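Both pushdown checks that follow boil down to the same two questions: can every ORDER BY expression be converted to a coprocessor expression, and does any of them touch a virtual generated column, whose value only exists after TiDB evaluates it? The walk below is a toy model with an invented expression type and an invented supported-function list; it only shows the shape of such a check.

package main

import "fmt"

// Toy expression tree: a node is either a bare column or a function call.
type expr struct {
	fn      string // empty for a bare column
	column  string
	virtual bool // column is a virtual generated column
	args    []expr
}

var supportedByStore = map[string]bool{"plus": true, "minus": true, "lower": true}

// canPushDown reports whether e could be evaluated by the storage layer:
// every function must be on the supported list and no referenced column
// may be a virtual generated column.
func canPushDown(e expr) bool {
	if e.fn == "" {
		return !e.virtual
	}
	if !supportedByStore[e.fn] {
		return false
	}
	for _, a := range e.args {
		if !canPushDown(a) {
			return false
		}
	}
	return true
}

func main() {
	ok := expr{fn: "plus", args: []expr{{column: "a"}, {column: "b"}}}
	bad := expr{fn: "plus", args: []expr{{column: "v", virtual: true}, {column: "b"}}}
	fmt.Println(canPushDown(ok))  // true
	fmt.Println(canPushDown(bad)) // false
}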
+func (p *PhysicalTopN) canPushDownToTiKV(copTask *CopTask) bool { + if !p.canExpressionConvertedToPB(kv.TiKV) { + return false + } + if len(copTask.rootTaskConds) != 0 { + return false + } + if !copTask.indexPlanFinished && len(copTask.idxMergePartPlans) > 0 { + for _, partialPlan := range copTask.idxMergePartPlans { + if p.containVirtualColumn(partialPlan.Schema().Columns) { + return false + } + } + } else if p.containVirtualColumn(copTask.Plan().Schema().Columns) { + return false + } + return true +} + +// canPushDownToTiFlash checks whether this topN can be pushed down to TiFlash. +func (p *PhysicalTopN) canPushDownToTiFlash(mppTask *MppTask) bool { + if !p.canExpressionConvertedToPB(kv.TiFlash) { + return false + } + if p.containVirtualColumn(mppTask.Plan().Schema().Columns) { + return false + } + return true +} + +// Attach2Task implements physical plan +func (p *PhysicalTopN) Attach2Task(tasks ...base.Task) base.Task { + t := tasks[0].Copy() + cols := make([]*expression.Column, 0, len(p.ByItems)) + for _, item := range p.ByItems { + cols = append(cols, expression.ExtractColumns(item.Expr)...) + } + needPushDown := len(cols) > 0 + if copTask, ok := t.(*CopTask); ok && needPushDown && p.canPushDownToTiKV(copTask) && len(copTask.rootTaskConds) == 0 { + // If all columns in topN are from index plan, we push it to index plan, otherwise we finish the index plan and + // push it to table plan. + var pushedDownTopN *PhysicalTopN + if !copTask.indexPlanFinished && p.canPushToIndexPlan(copTask.indexPlan, cols) { + pushedDownTopN = p.getPushedDownTopN(copTask.indexPlan) + copTask.indexPlan = pushedDownTopN + } else { + // It works for both normal index scan and index merge scan. + copTask.finishIndexPlan() + pushedDownTopN = p.getPushedDownTopN(copTask.tablePlan) + copTask.tablePlan = pushedDownTopN + } + } else if mppTask, ok := t.(*MppTask); ok && needPushDown && p.canPushDownToTiFlash(mppTask) { + pushedDownTopN := p.getPushedDownTopN(mppTask.p) + mppTask.p = pushedDownTopN + } + rootTask := t.ConvertToRootTask(p.SCtx()) + // Skip TopN with partition on the root. This is a derived topN and window function + // will take care of the filter. + if len(p.GetPartitionBy()) > 0 { + return t + } + return attachPlan2Task(p, rootTask) +} + +// Attach2Task implements the PhysicalPlan interface. +func (p *PhysicalExpand) Attach2Task(tasks ...base.Task) base.Task { + t := tasks[0].Copy() + // current expand can only be run in MPP TiFlash mode or Root Tidb mode. + // if expr inside could not be pushed down to tiFlash, it will error in converting to pb side. + if mpp, ok := t.(*MppTask); ok { + p.SetChildren(mpp.p) + mpp.p = p + return mpp + } + // For root task + // since expand should be in root side accordingly, convert to root task now. + root := t.ConvertToRootTask(p.SCtx()) + t = attachPlan2Task(p, root) + if root, ok := tasks[0].(*RootTask); ok && root.IsEmpty() { + t.(*RootTask).SetEmpty(true) + } + return t +} + +// Attach2Task implements PhysicalPlan interface. 
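The index-side branch chosen in Attach2Task above relies on canPushToIndexPlan, whose test reduces to column coverage: every column referenced by the ByItems must exist in the index schema and must not be a prefix column, since a prefix index stores only the first bytes of a value and cannot reproduce the full ordering. A minimal stand-alone version of that test, with invented column metadata:

package main

import "fmt"

type indexCol struct {
	name     string
	isPrefix bool // true when the index stores only a prefix of the column
}

// canPushToIndexSide reports whether a sort on byItemCols could be evaluated
// purely from the index: all columns must be present and none may be a prefix.
func canPushToIndexSide(indexSchema []indexCol, byItemCols []string) bool {
	byName := make(map[string]indexCol, len(indexSchema))
	for _, c := range indexSchema {
		byName[c.name] = c
	}
	for _, col := range byItemCols {
		ic, ok := byName[col]
		if !ok || ic.isPrefix {
			return false
		}
	}
	return true
}

func main() {
	idx := []indexCol{{name: "a"}, {name: "b", isPrefix: true}}
	fmt.Println(canPushToIndexSide(idx, []string{"a"}))      // true
	fmt.Println(canPushToIndexSide(idx, []string{"a", "b"})) // false: b is a prefix column
	fmt.Println(canPushToIndexSide(idx, []string{"c"}))      // false: c is not in the index
}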
+func (p *PhysicalProjection) Attach2Task(tasks ...base.Task) base.Task { + t := tasks[0].Copy() + if cop, ok := t.(*CopTask); ok { + if (len(cop.rootTaskConds) == 0 && len(cop.idxMergePartPlans) == 0) && expression.CanExprsPushDown(GetPushDownCtx(p.SCtx()), p.Exprs, cop.getStoreType()) { + copTask := attachPlan2Task(p, cop) + return copTask + } + } else if mpp, ok := t.(*MppTask); ok { + if expression.CanExprsPushDown(GetPushDownCtx(p.SCtx()), p.Exprs, kv.TiFlash) { + p.SetChildren(mpp.p) + mpp.p = p + return mpp + } + } + t = t.ConvertToRootTask(p.SCtx()) + t = attachPlan2Task(p, t) + if root, ok := tasks[0].(*RootTask); ok && root.IsEmpty() { + t.(*RootTask).SetEmpty(true) + } + return t +} + +func (p *PhysicalUnionAll) attach2MppTasks(tasks ...base.Task) base.Task { + t := &MppTask{p: p} + childPlans := make([]base.PhysicalPlan, 0, len(tasks)) + for _, tk := range tasks { + if mpp, ok := tk.(*MppTask); ok && !tk.Invalid() { + childPlans = append(childPlans, mpp.Plan()) + } else if root, ok := tk.(*RootTask); ok && root.IsEmpty() { + continue + } else { + return base.InvalidTask + } + } + if len(childPlans) == 0 { + return base.InvalidTask + } + p.SetChildren(childPlans...) + return t +} + +// Attach2Task implements PhysicalPlan interface. +func (p *PhysicalUnionAll) Attach2Task(tasks ...base.Task) base.Task { + for _, t := range tasks { + if _, ok := t.(*MppTask); ok { + if p.TP() == plancodec.TypePartitionUnion { + // In attach2MppTasks(), will attach PhysicalUnion to mppTask directly. + // But PartitionUnion cannot pushdown to tiflash, so here disable PartitionUnion pushdown to tiflash explicitly. + // For now, return base.InvalidTask immediately, we can refine this by letting childTask of PartitionUnion convert to rootTask. + return base.InvalidTask + } + return p.attach2MppTasks(tasks...) + } + } + t := &RootTask{} + t.SetPlan(p) + childPlans := make([]base.PhysicalPlan, 0, len(tasks)) + for _, task := range tasks { + task = task.ConvertToRootTask(p.SCtx()) + childPlans = append(childPlans, task.Plan()) + } + p.SetChildren(childPlans...) + return t +} + +// Attach2Task implements PhysicalPlan interface. +func (sel *PhysicalSelection) Attach2Task(tasks ...base.Task) base.Task { + if mppTask, _ := tasks[0].(*MppTask); mppTask != nil { // always push to mpp task. + if expression.CanExprsPushDown(GetPushDownCtx(sel.SCtx()), sel.Conditions, kv.TiFlash) { + return attachPlan2Task(sel, mppTask.Copy()) + } + } + t := tasks[0].ConvertToRootTask(sel.SCtx()) + return attachPlan2Task(sel, t) +} + +// CheckAggCanPushCop checks whether the aggFuncs and groupByItems can +// be pushed down to coprocessor. +func CheckAggCanPushCop(sctx base.PlanContext, aggFuncs []*aggregation.AggFuncDesc, groupByItems []expression.Expression, storeType kv.StoreType) bool { + sc := sctx.GetSessionVars().StmtCtx + ret := true + reason := "" + pushDownCtx := GetPushDownCtx(sctx) + for _, aggFunc := range aggFuncs { + // if the aggFunc contain VirtualColumn or CorrelatedColumn, it can not be pushed down. 
+ if expression.ContainVirtualColumn(aggFunc.Args) || expression.ContainCorrelatedColumn(aggFunc.Args) { + reason = "expressions of AggFunc `" + aggFunc.Name + "` contain virtual column or correlated column, which is not supported now" + ret = false + break + } + if !aggregation.CheckAggPushDown(sctx.GetExprCtx().GetEvalCtx(), aggFunc, storeType) { + reason = "AggFunc `" + aggFunc.Name + "` is not supported now" + ret = false + break + } + if !expression.CanExprsPushDownWithExtraInfo(GetPushDownCtx(sctx), aggFunc.Args, storeType, aggFunc.Name == ast.AggFuncSum) { + reason = "arguments of AggFunc `" + aggFunc.Name + "` contains unsupported exprs" + ret = false + break + } + orderBySize := len(aggFunc.OrderByItems) + if orderBySize > 0 { + exprs := make([]expression.Expression, 0, orderBySize) + for _, item := range aggFunc.OrderByItems { + exprs = append(exprs, item.Expr) + } + if !expression.CanExprsPushDownWithExtraInfo(GetPushDownCtx(sctx), exprs, storeType, false) { + reason = "arguments of AggFunc `" + aggFunc.Name + "` contains unsupported exprs in order-by clause" + ret = false + break + } + } + pb, _ := aggregation.AggFuncToPBExpr(pushDownCtx, aggFunc, storeType) + if pb == nil { + reason = "AggFunc `" + aggFunc.Name + "` can not be converted to pb expr" + ret = false + break + } + } + if ret && expression.ContainVirtualColumn(groupByItems) { + reason = "groupByItems contain virtual columns, which is not supported now" + ret = false + } + if ret && !expression.CanExprsPushDown(GetPushDownCtx(sctx), groupByItems, storeType) { + reason = "groupByItems contain unsupported exprs" + ret = false + } + + if !ret { + storageName := storeType.Name() + if storeType == kv.UnSpecified { + storageName = "storage layer" + } + warnErr := errors.NewNoStackError("Aggregation can not be pushed to " + storageName + " because " + reason) + if sc.InExplainStmt { + sc.AppendWarning(warnErr) + } else { + sc.AppendExtraWarning(warnErr) + } + } + return ret +} + +// AggInfo stores the information of an Aggregation. +type AggInfo struct { + AggFuncs []*aggregation.AggFuncDesc + GroupByItems []expression.Expression + Schema *expression.Schema +} + +// BuildFinalModeAggregation splits either LogicalAggregation or PhysicalAggregation to finalAgg and partial1Agg, +// returns the information of partial and final agg. +// partialIsCop means whether partial agg is a cop task. When partialIsCop is false, +// we do not set the AggMode for partialAgg cause it may be split further when +// building the aggregate executor(e.g. buildHashAgg will split the AggDesc further for parallel executing). 
+// firstRowFuncMap is a map between partial first_row to final first_row, will be used in RemoveUnnecessaryFirstRow +func BuildFinalModeAggregation( + sctx base.PlanContext, original *AggInfo, partialIsCop bool, isMPPTask bool) (partial, final *AggInfo, firstRowFuncMap map[*aggregation.AggFuncDesc]*aggregation.AggFuncDesc) { + ectx := sctx.GetExprCtx().GetEvalCtx() + + firstRowFuncMap = make(map[*aggregation.AggFuncDesc]*aggregation.AggFuncDesc, len(original.AggFuncs)) + partial = &AggInfo{ + AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(original.AggFuncs)), + GroupByItems: original.GroupByItems, + Schema: expression.NewSchema(), + } + partialCursor := 0 + final = &AggInfo{ + AggFuncs: make([]*aggregation.AggFuncDesc, len(original.AggFuncs)), + GroupByItems: make([]expression.Expression, 0, len(original.GroupByItems)), + Schema: original.Schema, + } + + partialGbySchema := expression.NewSchema() + // add group by columns + for _, gbyExpr := range partial.GroupByItems { + var gbyCol *expression.Column + if col, ok := gbyExpr.(*expression.Column); ok { + gbyCol = col + } else { + gbyCol = &expression.Column{ + UniqueID: sctx.GetSessionVars().AllocPlanColumnID(), + RetType: gbyExpr.GetType(ectx), + } + } + partialGbySchema.Append(gbyCol) + final.GroupByItems = append(final.GroupByItems, gbyCol) + } + + // TODO: Refactor the way of constructing aggregation functions. + // This for loop is ugly, but I do not find a proper way to reconstruct + // it right away. + + // group_concat is special when pushing down, it cannot take the two phase execution if no distinct but with orderBy, and other cases are also different: + // for example: group_concat([distinct] expr0, expr1[, order by expr2] separator ‘,’) + // no distinct, no orderBy: can two phase + // [final agg] group_concat(col#1,’,’) + // [part agg] group_concat(expr0, expr1,’,’) -> col#1 + // no distinct, orderBy: only one phase + // distinct, no orderBy: can two phase + // [final agg] group_concat(distinct col#0, col#1,’,’) + // [part agg] group by expr0 ->col#0, expr1 -> col#1 + // distinct, orderBy: can two phase + // [final agg] group_concat(distinct col#0, col#1, order by col#2,’,’) + // [part agg] group by expr0 ->col#0, expr1 -> col#1; agg function: firstrow(expr2)-> col#2 + + for i, aggFunc := range original.AggFuncs { + finalAggFunc := &aggregation.AggFuncDesc{HasDistinct: false} + finalAggFunc.Name = aggFunc.Name + finalAggFunc.OrderByItems = aggFunc.OrderByItems + args := make([]expression.Expression, 0, len(aggFunc.Args)) + if aggFunc.HasDistinct { + /* + eg: SELECT COUNT(DISTINCT a), SUM(b) FROM t GROUP BY c + + change from + [root] group by: c, funcs:count(distinct a), funcs:sum(b) + to + [root] group by: c, funcs:count(distinct a), funcs:sum(b) + [cop]: group by: c, a + */ + // onlyAddFirstRow means if the distinctArg does not occur in group by items, + // it should be replaced with a firstrow() agg function, needed for the order by items of group_concat() + getDistinctExpr := func(distinctArg expression.Expression, onlyAddFirstRow bool) (ret expression.Expression) { + // 1. add all args to partial.GroupByItems + foundInGroupBy := false + for j, gbyExpr := range partial.GroupByItems { + if gbyExpr.Equal(ectx, distinctArg) && gbyExpr.GetType(ectx).Equal(distinctArg.GetType(ectx)) { + // if the two expressions exactly the same in terms of data types and collation, then can avoid it. 
+ foundInGroupBy = true + ret = partialGbySchema.Columns[j] + break + } + } + if !foundInGroupBy { + var gbyCol *expression.Column + if col, ok := distinctArg.(*expression.Column); ok { + gbyCol = col + } else { + gbyCol = &expression.Column{ + UniqueID: sctx.GetSessionVars().AllocPlanColumnID(), + RetType: distinctArg.GetType(ectx), + } + } + // 2. add group by items if needed + if !onlyAddFirstRow { + partial.GroupByItems = append(partial.GroupByItems, distinctArg) + partialGbySchema.Append(gbyCol) + ret = gbyCol + } + // 3. add firstrow() if needed + if !partialIsCop || onlyAddFirstRow { + // if partial is a cop task, firstrow function is redundant since group by items are outputted + // by group by schema, and final functions use group by schema as their arguments. + // if partial agg is not cop, we must append firstrow function & schema, to output the group by + // items. + // maybe we can unify them sometime. + // only add firstrow for order by items of group_concat() + firstRow, err := aggregation.NewAggFuncDesc(sctx.GetExprCtx(), ast.AggFuncFirstRow, []expression.Expression{distinctArg}, false) + if err != nil { + panic("NewAggFuncDesc FirstRow meets error: " + err.Error()) + } + partial.AggFuncs = append(partial.AggFuncs, firstRow) + newCol, _ := gbyCol.Clone().(*expression.Column) + newCol.RetType = firstRow.RetTp + partial.Schema.Append(newCol) + if onlyAddFirstRow { + ret = newCol + } + partialCursor++ + } + } + return ret + } + + for j, distinctArg := range aggFunc.Args { + // the last arg of ast.AggFuncGroupConcat is the separator, so just put it into the final agg + if aggFunc.Name == ast.AggFuncGroupConcat && j+1 == len(aggFunc.Args) { + args = append(args, distinctArg) + continue + } + args = append(args, getDistinctExpr(distinctArg, false)) + } + + byItems := make([]*util.ByItems, 0, len(aggFunc.OrderByItems)) + for _, byItem := range aggFunc.OrderByItems { + byItems = append(byItems, &util.ByItems{Expr: getDistinctExpr(byItem.Expr, true), Desc: byItem.Desc}) + } + + if aggFunc.HasDistinct && isMPPTask && aggFunc.GroupingID > 0 { + // keep the groupingID as it was, otherwise the new split final aggregate's ganna lost its groupingID info. + finalAggFunc.GroupingID = aggFunc.GroupingID + } + + finalAggFunc.OrderByItems = byItems + finalAggFunc.HasDistinct = aggFunc.HasDistinct + // In logical optimize phase, the Agg->PartitionUnion->TableReader may become + // Agg1->PartitionUnion->Agg2->TableReader, and the Agg2 is a partial aggregation. + // So in the push down here, we need to add a new if-condition check: + // If the original agg mode is partial already, the finalAggFunc's mode become Partial2. + if aggFunc.Mode == aggregation.CompleteMode { + finalAggFunc.Mode = aggregation.CompleteMode + } else if aggFunc.Mode == aggregation.Partial1Mode || aggFunc.Mode == aggregation.Partial2Mode { + finalAggFunc.Mode = aggregation.Partial2Mode + } + } else { + if aggFunc.Name == ast.AggFuncGroupConcat && len(aggFunc.OrderByItems) > 0 { + // group_concat can only run in one phase if it has order by items but without distinct property + partial = nil + final = original + return + } + if aggregation.NeedCount(finalAggFunc.Name) { + // only Avg and Count need count + if isMPPTask && finalAggFunc.Name == ast.AggFuncCount { + // For MPP base.Task, the final count() is changed to sum(). + // Note: MPP mode does not run avg() directly, instead, avg() -> sum()/(case when count() = 0 then 1 else count() end), + // so we do not process it here. 
+ finalAggFunc.Name = ast.AggFuncSum + } else { + // avg branch + ft := types.NewFieldType(mysql.TypeLonglong) + ft.SetFlen(21) + ft.SetCharset(charset.CharsetBin) + ft.SetCollate(charset.CollationBin) + partial.Schema.Append(&expression.Column{ + UniqueID: sctx.GetSessionVars().AllocPlanColumnID(), + RetType: ft, + }) + args = append(args, partial.Schema.Columns[partialCursor]) + partialCursor++ + } + } + if finalAggFunc.Name == ast.AggFuncApproxCountDistinct { + ft := types.NewFieldType(mysql.TypeString) + ft.SetCharset(charset.CharsetBin) + ft.SetCollate(charset.CollationBin) + ft.AddFlag(mysql.NotNullFlag) + partial.Schema.Append(&expression.Column{ + UniqueID: sctx.GetSessionVars().AllocPlanColumnID(), + RetType: ft, + }) + args = append(args, partial.Schema.Columns[partialCursor]) + partialCursor++ + } + if aggregation.NeedValue(finalAggFunc.Name) { + partial.Schema.Append(&expression.Column{ + UniqueID: sctx.GetSessionVars().AllocPlanColumnID(), + RetType: original.Schema.Columns[i].GetType(ectx), + }) + args = append(args, partial.Schema.Columns[partialCursor]) + partialCursor++ + } + if aggFunc.Name == ast.AggFuncAvg { + cntAgg := aggFunc.Clone() + cntAgg.Name = ast.AggFuncCount + err := cntAgg.TypeInfer(sctx.GetExprCtx()) + if err != nil { // must not happen + partial = nil + final = original + return + } + partial.Schema.Columns[partialCursor-2].RetType = cntAgg.RetTp + // we must call deep clone in this case, to avoid sharing the arguments. + sumAgg := aggFunc.Clone() + sumAgg.Name = ast.AggFuncSum + sumAgg.TypeInfer4AvgSum(sumAgg.RetTp) + partial.Schema.Columns[partialCursor-1].RetType = sumAgg.RetTp + partial.AggFuncs = append(partial.AggFuncs, cntAgg, sumAgg) + } else if aggFunc.Name == ast.AggFuncApproxCountDistinct || aggFunc.Name == ast.AggFuncGroupConcat { + newAggFunc := aggFunc.Clone() + newAggFunc.Name = aggFunc.Name + newAggFunc.RetTp = partial.Schema.Columns[partialCursor-1].GetType(ectx) + partial.AggFuncs = append(partial.AggFuncs, newAggFunc) + if aggFunc.Name == ast.AggFuncGroupConcat { + // append the last separator arg + args = append(args, aggFunc.Args[len(aggFunc.Args)-1]) + } + } else { + // other agg desc just split into two parts + partialFuncDesc := aggFunc.Clone() + partial.AggFuncs = append(partial.AggFuncs, partialFuncDesc) + if aggFunc.Name == ast.AggFuncFirstRow { + firstRowFuncMap[partialFuncDesc] = finalAggFunc + } + } + + // In logical optimize phase, the Agg->PartitionUnion->TableReader may become + // Agg1->PartitionUnion->Agg2->TableReader, and the Agg2 is a partial aggregation. + // So in the push down here, we need to add a new if-condition check: + // If the original agg mode is partial already, the finalAggFunc's mode become Partial2. + if aggFunc.Mode == aggregation.CompleteMode { + finalAggFunc.Mode = aggregation.FinalMode + } else if aggFunc.Mode == aggregation.Partial1Mode || aggFunc.Mode == aggregation.Partial2Mode { + finalAggFunc.Mode = aggregation.Partial2Mode + } + } + + finalAggFunc.Args = args + finalAggFunc.RetTp = aggFunc.RetTp + final.AggFuncs[i] = finalAggFunc + } + partial.Schema.Append(partialGbySchema.Columns...) + if partialIsCop { + for _, f := range partial.AggFuncs { + f.Mode = aggregation.Partial1Mode + } + } + return +} + +// convertAvgForMPP converts avg(arg) to sum(arg)/(case when count(arg)=0 then 1 else count(arg) end), in detail: +// 1.rewrite avg() in the final aggregation to count() and sum(), and reconstruct its schema. 
+// 2.replace avg() with sum(arg)/(case when count(arg)=0 then 1 else count(arg) end) and reuse the original schema of the final aggregation. +// If there is no avg, nothing is changed and return nil. +func (p *basePhysicalAgg) convertAvgForMPP() *PhysicalProjection { + newSchema := expression.NewSchema() + newSchema.Keys = p.schema.Keys + newSchema.UniqueKeys = p.schema.UniqueKeys + newAggFuncs := make([]*aggregation.AggFuncDesc, 0, 2*len(p.AggFuncs)) + exprs := make([]expression.Expression, 0, 2*len(p.schema.Columns)) + // add agg functions schema + for i, aggFunc := range p.AggFuncs { + if aggFunc.Name == ast.AggFuncAvg { + // inset a count(column) + avgCount := aggFunc.Clone() + avgCount.Name = ast.AggFuncCount + err := avgCount.TypeInfer(p.SCtx().GetExprCtx()) + if err != nil { // must not happen + return nil + } + newAggFuncs = append(newAggFuncs, avgCount) + avgCountCol := &expression.Column{ + UniqueID: p.SCtx().GetSessionVars().AllocPlanColumnID(), + RetType: avgCount.RetTp, + } + newSchema.Append(avgCountCol) + // insert a sum(column) + avgSum := aggFunc.Clone() + avgSum.Name = ast.AggFuncSum + avgSum.TypeInfer4AvgSum(avgSum.RetTp) + newAggFuncs = append(newAggFuncs, avgSum) + avgSumCol := &expression.Column{ + UniqueID: p.schema.Columns[i].UniqueID, + RetType: avgSum.RetTp, + } + newSchema.Append(avgSumCol) + // avgSumCol/(case when avgCountCol=0 then 1 else avgCountCol end) + eq := expression.NewFunctionInternal(p.SCtx().GetExprCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), avgCountCol, expression.NewZero()) + caseWhen := expression.NewFunctionInternal(p.SCtx().GetExprCtx(), ast.Case, avgCountCol.RetType, eq, expression.NewOne(), avgCountCol) + divide := expression.NewFunctionInternal(p.SCtx().GetExprCtx(), ast.Div, avgSumCol.RetType, avgSumCol, caseWhen) + divide.(*expression.ScalarFunction).RetType = p.schema.Columns[i].RetType + exprs = append(exprs, divide) + } else { + // other non-avg agg use the old schema as it did. + newAggFuncs = append(newAggFuncs, aggFunc) + newSchema.Append(p.schema.Columns[i]) + exprs = append(exprs, p.schema.Columns[i]) + } + } + // no avgs + // for final agg, always add project due to in-compatibility between TiDB and TiFlash + if len(p.schema.Columns) == len(newSchema.Columns) && !p.IsFinalAgg() { + return nil + } + // add remaining columns to exprs + for i := len(p.AggFuncs); i < len(p.schema.Columns); i++ { + exprs = append(exprs, p.schema.Columns[i]) + } + proj := PhysicalProjection{ + Exprs: exprs, + CalculateNoDelay: false, + AvoidColumnEvaluator: false, + }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset(), p.GetChildReqProps(0).CloneEssentialFields()) + proj.SetSchema(p.schema) + + p.AggFuncs = newAggFuncs + p.schema = newSchema + + return proj +} + +func (p *basePhysicalAgg) newPartialAggregate(copTaskType kv.StoreType, isMPPTask bool) (partial, final base.PhysicalPlan) { + // Check if this aggregation can push down. + if !CheckAggCanPushCop(p.SCtx(), p.AggFuncs, p.GroupByItems, copTaskType) { + return nil, p.self + } + partialPref, finalPref, firstRowFuncMap := BuildFinalModeAggregation(p.SCtx(), &AggInfo{ + AggFuncs: p.AggFuncs, + GroupByItems: p.GroupByItems, + Schema: p.Schema().Clone(), + }, true, isMPPTask) + if partialPref == nil { + return nil, p.self + } + if p.TP() == plancodec.TypeStreamAgg && len(partialPref.GroupByItems) != len(finalPref.GroupByItems) { + return nil, p.self + } + // Remove unnecessary FirstRow. 
+ partialPref.AggFuncs = RemoveUnnecessaryFirstRow(p.SCtx(), + finalPref.GroupByItems, partialPref.AggFuncs, partialPref.GroupByItems, partialPref.Schema, firstRowFuncMap) + if copTaskType == kv.TiDB { + // For partial agg of TiDB cop task, since TiDB coprocessor reuse the TiDB executor, + // and TiDB aggregation executor won't output the group by value, + // so we need add `firstrow` aggregation function to output the group by value. + aggFuncs, err := genFirstRowAggForGroupBy(p.SCtx(), partialPref.GroupByItems) + if err != nil { + return nil, p.self + } + partialPref.AggFuncs = append(partialPref.AggFuncs, aggFuncs...) + } + p.AggFuncs = partialPref.AggFuncs + p.GroupByItems = partialPref.GroupByItems + p.schema = partialPref.Schema + partialAgg := p.self + // Create physical "final" aggregation. + prop := &property.PhysicalProperty{ExpectedCnt: math.MaxFloat64} + if p.TP() == plancodec.TypeStreamAgg { + finalAgg := basePhysicalAgg{ + AggFuncs: finalPref.AggFuncs, + GroupByItems: finalPref.GroupByItems, + MppRunMode: p.MppRunMode, + }.initForStream(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset(), prop) + finalAgg.schema = finalPref.Schema + return partialAgg, finalAgg + } + + finalAgg := basePhysicalAgg{ + AggFuncs: finalPref.AggFuncs, + GroupByItems: finalPref.GroupByItems, + MppRunMode: p.MppRunMode, + }.initForHash(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset(), prop) + finalAgg.schema = finalPref.Schema + // partialAgg and finalAgg use the same ref of stats + return partialAgg, finalAgg +} + +func (p *basePhysicalAgg) scale3StageForDistinctAgg() (bool, expression.GroupingSets) { + if p.canUse3Stage4SingleDistinctAgg() { + return true, nil + } + return p.canUse3Stage4MultiDistinctAgg() +} + +// canUse3Stage4MultiDistinctAgg returns true if this agg can use 3 stage for multi distinct aggregation +func (p *basePhysicalAgg) canUse3Stage4MultiDistinctAgg() (can bool, gss expression.GroupingSets) { + if !p.SCtx().GetSessionVars().Enable3StageDistinctAgg || !p.SCtx().GetSessionVars().Enable3StageMultiDistinctAgg || len(p.GroupByItems) > 0 { + return false, nil + } + defer func() { + // some clean work. + if !can { + for _, fun := range p.AggFuncs { + fun.GroupingID = 0 + } + } + }() + // groupingSets is alias of []GroupingSet, the below equal to = make([]GroupingSet, 0, 2) + groupingSets := make(expression.GroupingSets, 0, 2) + for _, fun := range p.AggFuncs { + if fun.HasDistinct { + if fun.Name != ast.AggFuncCount { + // now only for multi count(distinct x) + return false, nil + } + for _, arg := range fun.Args { + // bail out when args are not simple column, see GitHub issue #35417 + if _, ok := arg.(*expression.Column); !ok { + return false, nil + } + } + // here it's a valid count distinct agg with normal column args, collecting its distinct expr. + groupingSets = append(groupingSets, expression.GroupingSet{fun.Args}) + // groupingID now is the offset of target grouping in GroupingSets. + // todo: it may be changed after grouping set merge in the future. + fun.GroupingID = len(groupingSets) + } else if len(fun.Args) > 1 { + return false, nil + } + // banned group_concat(x order by y) + if len(fun.OrderByItems) > 0 || fun.Mode != aggregation.CompleteMode { + return false, nil + } + } + compressed := groupingSets.Merge() + if len(compressed) != len(groupingSets) { + p.SCtx().GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("Some grouping sets should be merged")) + // todo arenatlx: some grouping set should be merged which is not supported by now temporarily. 
+ return false, nil + } + if groupingSets.NeedCloneColumn() { + // todo: column clone haven't implemented. + return false, nil + } + if len(groupingSets) > 1 { + // fill the grouping ID for normal agg. + for _, fun := range p.AggFuncs { + if fun.GroupingID == 0 { + // the grouping ID hasn't set. find the targeting grouping set. + groupingSetOffset := groupingSets.TargetOne(fun.Args) + if groupingSetOffset == -1 { + // todo: if we couldn't find a existed current valid group layout, we need to copy the column out from being filled with null value. + p.SCtx().GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("couldn't find a proper group set for normal agg")) + return false, nil + } + // starting with 1 + fun.GroupingID = groupingSetOffset + 1 + } + } + return true, groupingSets + } + return false, nil +} + +// canUse3Stage4SingleDistinctAgg returns true if this agg can use 3 stage for distinct aggregation +func (p *basePhysicalAgg) canUse3Stage4SingleDistinctAgg() bool { + num := 0 + if !p.SCtx().GetSessionVars().Enable3StageDistinctAgg || len(p.GroupByItems) > 0 { + return false + } + for _, fun := range p.AggFuncs { + if fun.HasDistinct { + num++ + if num > 1 || fun.Name != ast.AggFuncCount { + return false + } + for _, arg := range fun.Args { + // bail out when args are not simple column, see GitHub issue #35417 + if _, ok := arg.(*expression.Column); !ok { + return false + } + } + } else if len(fun.Args) > 1 { + return false + } + + if len(fun.OrderByItems) > 0 || fun.Mode != aggregation.CompleteMode { + return false + } + } + return num == 1 +} + +func genFirstRowAggForGroupBy(ctx base.PlanContext, groupByItems []expression.Expression) ([]*aggregation.AggFuncDesc, error) { + aggFuncs := make([]*aggregation.AggFuncDesc, 0, len(groupByItems)) + for _, groupBy := range groupByItems { + agg, err := aggregation.NewAggFuncDesc(ctx.GetExprCtx(), ast.AggFuncFirstRow, []expression.Expression{groupBy}, false) + if err != nil { + return nil, err + } + aggFuncs = append(aggFuncs, agg) + } + return aggFuncs, nil +} + +// RemoveUnnecessaryFirstRow removes unnecessary FirstRow of the aggregation. This function can be +// used for both LogicalAggregation and PhysicalAggregation. +// When the select column is same with the group by key, the column can be removed and gets value from the group by key. +// e.g +// select a, count(b) from t group by a; +// The schema is [firstrow(a), count(b), a]. The column firstrow(a) is unnecessary. +// Can optimize the schema to [count(b), a] , and change the index to get value. +func RemoveUnnecessaryFirstRow( + sctx base.PlanContext, + finalGbyItems []expression.Expression, + partialAggFuncs []*aggregation.AggFuncDesc, + partialGbyItems []expression.Expression, + partialSchema *expression.Schema, + firstRowFuncMap map[*aggregation.AggFuncDesc]*aggregation.AggFuncDesc) []*aggregation.AggFuncDesc { + partialCursor := 0 + newAggFuncs := make([]*aggregation.AggFuncDesc, 0, len(partialAggFuncs)) + for _, aggFunc := range partialAggFuncs { + if aggFunc.Name == ast.AggFuncFirstRow { + canOptimize := false + for j, gbyExpr := range partialGbyItems { + if j >= len(finalGbyItems) { + // after distinct push, len(partialGbyItems) may larger than len(finalGbyItems) + // for example, + // select /*+ HASH_AGG() */ a, count(distinct a) from t; + // will generate to, + // HashAgg root funcs:count(distinct a), funcs:firstrow(a)" + // HashAgg cop group by:a, funcs:firstrow(a)->Column#6" + // the firstrow in root task can not be removed. 
+ break + } + // Skip if it's a constant. + // For SELECT DISTINCT SQRT(1) FROM t. + // We shouldn't remove the firstrow(SQRT(1)). + if _, ok := gbyExpr.(*expression.Constant); ok { + continue + } + if gbyExpr.Equal(sctx.GetExprCtx().GetEvalCtx(), aggFunc.Args[0]) { + canOptimize = true + firstRowFuncMap[aggFunc].Args[0] = finalGbyItems[j] + break + } + } + if canOptimize { + partialSchema.Columns = append(partialSchema.Columns[:partialCursor], partialSchema.Columns[partialCursor+1:]...) + continue + } + } + partialCursor += computePartialCursorOffset(aggFunc.Name) + newAggFuncs = append(newAggFuncs, aggFunc) + } + return newAggFuncs +} + +func computePartialCursorOffset(name string) int { + offset := 0 + if aggregation.NeedCount(name) { + offset++ + } + if aggregation.NeedValue(name) { + offset++ + } + if name == ast.AggFuncApproxCountDistinct { + offset++ + } + return offset +} + +// Attach2Task implements PhysicalPlan interface. +func (p *PhysicalStreamAgg) Attach2Task(tasks ...base.Task) base.Task { + t := tasks[0].Copy() + if cop, ok := t.(*CopTask); ok { + // We should not push agg down across + // 1. double read, since the data of second read is ordered by handle instead of index. The `extraHandleCol` is added + // if the double read needs to keep order. So we just use it to decided + // whether the following plan is double read with order reserved. + // 2. the case that there's filters should be calculated on TiDB side. + // 3. the case of index merge + if (cop.indexPlan != nil && cop.tablePlan != nil && cop.keepOrder) || len(cop.rootTaskConds) > 0 || len(cop.idxMergePartPlans) > 0 { + t = cop.ConvertToRootTask(p.SCtx()) + attachPlan2Task(p, t) + } else { + storeType := cop.getStoreType() + // TiFlash doesn't support Stream Aggregation + if storeType == kv.TiFlash && len(p.GroupByItems) > 0 { + return base.InvalidTask + } + partialAgg, finalAgg := p.newPartialAggregate(storeType, false) + if partialAgg != nil { + if cop.tablePlan != nil { + cop.finishIndexPlan() + partialAgg.SetChildren(cop.tablePlan) + cop.tablePlan = partialAgg + // If needExtraProj is true, a projection will be created above the PhysicalIndexLookUpReader to make sure + // the schema is the same as the original DataSource schema. + // However, we pushed down the agg here, the partial agg was placed on the top of tablePlan, and the final + // agg will be placed above the PhysicalIndexLookUpReader, and the schema will be set correctly for them. + // If we add the projection again, the projection will be between the PhysicalIndexLookUpReader and + // the partial agg, and the schema will be broken. + cop.needExtraProj = false + } else { + partialAgg.SetChildren(cop.indexPlan) + cop.indexPlan = partialAgg + } + } + t = cop.ConvertToRootTask(p.SCtx()) + attachPlan2Task(finalAgg, t) + } + } else if mpp, ok := t.(*MppTask); ok { + t = mpp.ConvertToRootTask(p.SCtx()) + attachPlan2Task(p, t) + } else { + attachPlan2Task(p, t) + } + return t +} + +// cpuCostDivisor computes the concurrency to which we would amortize CPU cost +// for hash aggregation. +func (p *PhysicalHashAgg) cpuCostDivisor(hasDistinct bool) (divisor, con float64) { + if hasDistinct { + return 0, 0 + } + sessionVars := p.SCtx().GetSessionVars() + finalCon, partialCon := sessionVars.HashAggFinalConcurrency(), sessionVars.HashAggPartialConcurrency() + // According to `ValidateSetSystemVar`, `finalCon` and `partialCon` cannot be less than or equal to 0. 
+ if finalCon == 1 && partialCon == 1 { + return 0, 0 + } + // It is tricky to decide which concurrency we should use to amortize CPU cost. Since cost of hash + // aggregation is tend to be under-estimated as explained in `attach2Task`, we choose the smaller + // concurrecy to make some compensation. + return math.Min(float64(finalCon), float64(partialCon)), float64(finalCon + partialCon) +} + +func (p *PhysicalHashAgg) attach2TaskForMpp1Phase(mpp *MppTask) base.Task { + // 1-phase agg: when the partition columns can be satisfied, where the plan does not need to enforce Exchange + // only push down the original agg + proj := p.convertAvgForMPP() + attachPlan2Task(p.self, mpp) + if proj != nil { + attachPlan2Task(proj, mpp) + } + return mpp +} + +// scaleStats4GroupingSets scale the derived stats because the lower source has been expanded. +// +// parent OP <- logicalAgg <- children OP (derived stats) +// | +// v +// parent OP <- physicalAgg <- children OP (stats used) +// | +// +----------+----------+----------+ +// Final Mid Partial Expand +// +// physical agg stats is reasonable from the whole, because expand operator is designed to facilitate +// the Mid and Partial Agg, which means when leaving the Final, its output rowcount could be exactly +// the same as what it derived(estimated) before entering physical optimization phase. +// +// From the cost model correctness, for these inserted sub-agg and even expand operator, we should +// recompute the stats for them particularly. +// +// for example: grouping sets {},{}, group by items {a,b,c,groupingID} +// after expand: +// +// a, b, c, groupingID +// ... null c 1 ---+ +// ... null c 1 +------- replica group 1 +// ... null c 1 ---+ +// null ... c 2 ---+ +// null ... c 2 +------- replica group 2 +// null ... c 2 ---+ +// +// since null value is seen the same when grouping data (groupingID in one replica is always the same): +// - so the num of group in replica 1 is equal to NDV(a,c) +// - so the num of group in replica 2 is equal to NDV(b,c) +// +// in a summary, the total num of group of all replica is equal to = Σ:NDV(each-grouping-set-cols, normal-group-cols) +func (p *PhysicalHashAgg) scaleStats4GroupingSets(groupingSets expression.GroupingSets, groupingIDCol *expression.Column, + childSchema *expression.Schema, childStats *property.StatsInfo) { + idSets := groupingSets.AllSetsColIDs() + normalGbyCols := make([]*expression.Column, 0, len(p.GroupByItems)) + for _, gbyExpr := range p.GroupByItems { + cols := expression.ExtractColumns(gbyExpr) + for _, col := range cols { + if !idSets.Has(int(col.UniqueID)) && col.UniqueID != groupingIDCol.UniqueID { + normalGbyCols = append(normalGbyCols, col) + } + } + } + sumNDV := float64(0) + for _, groupingSet := range groupingSets { + // for every grouping set, pick its cols out, and combine with normal group cols to get the ndv. + groupingSetCols := groupingSet.ExtractCols() + groupingSetCols = append(groupingSetCols, normalGbyCols...) + ndv, _ := cardinality.EstimateColsNDVWithMatchedLen(groupingSetCols, childSchema, childStats) + sumNDV += ndv + } + // After group operator, all same rows are grouped into one row, that means all + // change the sub-agg's stats + if p.StatsInfo() != nil { + // equivalence to a new cloned one. (cause finalAgg and partialAgg may share a same copy of stats) + cpStats := p.StatsInfo().Scale(1) + cpStats.RowCount = sumNDV + // We cannot estimate the ColNDVs for every output, so we use a conservative strategy. 
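// A minimal sketch of the row-count scaling described above, assuming the NDV of
// each column combination is already estimated from statistics: with grouping
// sets {a} and {b} plus a normal group-by column c, Expand produces one replica
// grouped by (a, c) and one grouped by (b, c), so the aggregate output size is
// the sum of the per-replica NDVs rather than a single NDV.
func groupsAfterExpand(ndvPerGroupingSet []float64) float64 {
    sumNDV := 0.0
    for _, ndv := range ndvPerGroupingSet {
        sumNDV += ndv // each grouping-set replica contributes its own groups
    }
    return sumNDV
}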
+ for k := range cpStats.ColNDVs { + cpStats.ColNDVs[k] = sumNDV + } + // for old groupNDV, if it's containing one more grouping set cols, just plus the NDV where the col is excluded. + // for example: old grouping NDV(b,c), where b is in grouping sets {},{}. so when countering the new NDV: + // cases: + // new grouping NDV(b,c) := old NDV(b,c) + NDV(null, c) = old NDV(b,c) + DNV(c). + // new grouping NDV(a,b,c) := old NDV(a,b,c) + NDV(null,b,c) + NDV(a,null,c) = old NDV(a,b,c) + NDV(b,c) + NDV(a,c) + allGroupingSetsIDs := groupingSets.AllSetsColIDs() + for _, oneGNDV := range cpStats.GroupNDVs { + newGNDV := oneGNDV.NDV + intersectionIDs := make([]int64, 0, len(oneGNDV.Cols)) + for i, id := range oneGNDV.Cols { + if allGroupingSetsIDs.Has(int(id)) { + // when meet an id in grouping sets, skip it (cause its null) and append the rest ids to count the incrementNDV. + beforeLen := len(intersectionIDs) + intersectionIDs = append(intersectionIDs, oneGNDV.Cols[i:]...) + incrementNDV, _ := cardinality.EstimateColsDNVWithMatchedLenFromUniqueIDs(intersectionIDs, childSchema, childStats) + newGNDV += incrementNDV + // restore the before intersectionIDs slice. + intersectionIDs = intersectionIDs[:beforeLen] + } + // insert ids one by one. + intersectionIDs = append(intersectionIDs, id) + } + oneGNDV.NDV = newGNDV + } + p.SetStats(cpStats) + } +} + +// adjust3StagePhaseAgg generate 3 stage aggregation for single/multi count distinct if applicable. +// +// select count(distinct a), count(b) from foo +// +// will generate plan: +// +// HashAgg sum(#1), sum(#2) -> final agg +// +- Exchange Passthrough +// +- HashAgg count(distinct a) #1, sum(#3) #2 -> middle agg +// +- Exchange HashPartition by a +// +- HashAgg count(b) #3, group by a -> partial agg +// +- TableScan foo +// +// select count(distinct a), count(distinct b), count(c) from foo +// +// will generate plan: +// +// HashAgg sum(#1), sum(#2), sum(#3) -> final agg +// +- Exchange Passthrough +// +- HashAgg count(distinct a) #1, count(distinct b) #2, sum(#4) #3 -> middle agg +// +- Exchange HashPartition by a,b,groupingID +// +- HashAgg count(c) #4, group by a,b,groupingID -> partial agg +// +- Expand {}, {} -> expand +// +- TableScan foo +func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg base.PhysicalPlan, canUse3StageAgg bool, + groupingSets expression.GroupingSets, mpp *MppTask) (final, mid, part, proj4Part base.PhysicalPlan, _ error) { + ectx := p.SCtx().GetExprCtx().GetEvalCtx() + + if !(partialAgg != nil && canUse3StageAgg) { + // quick path: return the original finalAgg and partiAgg. + return finalAgg, nil, partialAgg, nil, nil + } + if len(groupingSets) == 0 { + // single distinct agg mode. + clonedAgg, err := finalAgg.Clone(p.SCtx()) + if err != nil { + return nil, nil, nil, nil, err + } + + // step1: adjust middle agg. + middleHashAgg := clonedAgg.(*PhysicalHashAgg) + distinctPos := 0 + middleSchema := expression.NewSchema() + schemaMap := make(map[int64]*expression.Column, len(middleHashAgg.AggFuncs)) + for i, fun := range middleHashAgg.AggFuncs { + col := &expression.Column{ + UniqueID: p.SCtx().GetSessionVars().AllocPlanColumnID(), + RetType: fun.RetTp, + } + if fun.HasDistinct { + distinctPos = i + fun.Mode = aggregation.Partial1Mode + } else { + fun.Mode = aggregation.Partial2Mode + originalCol := fun.Args[0].(*expression.Column) + // mapping the current partial output column with the agg origin arg column. 
(final agg arg should use this one) + schemaMap[originalCol.UniqueID] = col + } + middleSchema.Append(col) + } + middleHashAgg.schema = middleSchema + + // step2: adjust final agg. + finalHashAgg := finalAgg.(*PhysicalHashAgg) + finalAggDescs := make([]*aggregation.AggFuncDesc, 0, len(finalHashAgg.AggFuncs)) + for i, fun := range finalHashAgg.AggFuncs { + newArgs := make([]expression.Expression, 0, 1) + if distinctPos == i { + // change count(distinct) to sum() + fun.Name = ast.AggFuncSum + fun.HasDistinct = false + newArgs = append(newArgs, middleSchema.Columns[i]) + } else { + for _, arg := range fun.Args { + newCol, err := arg.RemapColumn(schemaMap) + if err != nil { + return nil, nil, nil, nil, err + } + newArgs = append(newArgs, newCol) + } + } + fun.Mode = aggregation.FinalMode + fun.Args = newArgs + finalAggDescs = append(finalAggDescs, fun) + } + finalHashAgg.AggFuncs = finalAggDescs + // partialAgg is im-mutated from args. + return finalHashAgg, middleHashAgg, partialAgg, nil, nil + } + // multi distinct agg mode, having grouping sets. + // set the default expression to constant 1 for the convenience to choose default group set data. + var groupingIDCol expression.Expression + // enforce Expand operator above the children. + // physical plan is enumerated without children from itself, use mpp subtree instead p.children. + // scale(len(groupingSets)) will change the NDV, while Expand doesn't change the NDV and groupNDV. + stats := mpp.p.StatsInfo().Scale(float64(1)) + stats.RowCount = stats.RowCount * float64(len(groupingSets)) + physicalExpand := PhysicalExpand{ + GroupingSets: groupingSets, + }.Init(p.SCtx(), stats, mpp.p.QueryBlockOffset()) + // generate a new column as groupingID to identify which this row is targeting for. + tp := types.NewFieldType(mysql.TypeLonglong) + tp.SetFlag(mysql.UnsignedFlag | mysql.NotNullFlag) + groupingIDCol = &expression.Column{ + UniqueID: p.SCtx().GetSessionVars().AllocPlanColumnID(), + RetType: tp, + } + // append the physical expand op with groupingID column. + physicalExpand.SetSchema(mpp.p.Schema().Clone()) + physicalExpand.schema.Append(groupingIDCol.(*expression.Column)) + physicalExpand.GroupingIDCol = groupingIDCol.(*expression.Column) + // attach PhysicalExpand to mpp + attachPlan2Task(physicalExpand, mpp) + + // having group sets + clonedAgg, err := finalAgg.Clone(p.SCtx()) + if err != nil { + return nil, nil, nil, nil, err + } + cloneHashAgg := clonedAgg.(*PhysicalHashAgg) + // Clone(), it will share same base-plan elements from the finalAgg, including id,tp,stats. Make a new one here. + cloneHashAgg.Plan = baseimpl.NewBasePlan(cloneHashAgg.SCtx(), cloneHashAgg.TP(), cloneHashAgg.QueryBlockOffset()) + cloneHashAgg.SetStats(finalAgg.StatsInfo()) // reuse the final agg stats here. + + // step1: adjust partial agg, for normal agg here, adjust it to target for specified group data. + // Since we may substitute the first arg of normal agg with case-when expression here, append a + // customized proj here rather than depending on postOptimize to insert a blunt one for us. + // + // proj4Partial output all the base col from lower op + caseWhen proj cols. 
+ proj4Partial := new(PhysicalProjection).Init(p.SCtx(), mpp.p.StatsInfo(), mpp.p.QueryBlockOffset()) + for _, col := range mpp.p.Schema().Columns { + proj4Partial.Exprs = append(proj4Partial.Exprs, col) + } + proj4Partial.SetSchema(mpp.p.Schema().Clone()) + + partialHashAgg := partialAgg.(*PhysicalHashAgg) + partialHashAgg.GroupByItems = append(partialHashAgg.GroupByItems, groupingIDCol) + partialHashAgg.schema.Append(groupingIDCol.(*expression.Column)) + // it will create a new stats for partial agg. + partialHashAgg.scaleStats4GroupingSets(groupingSets, groupingIDCol.(*expression.Column), proj4Partial.Schema(), proj4Partial.StatsInfo()) + for _, fun := range partialHashAgg.AggFuncs { + if !fun.HasDistinct { + // for normal agg phase1, we should also modify them to target for specified group data. + // Expr = (case when groupingID = targeted_groupingID then arg else null end) + eqExpr := expression.NewFunctionInternal(p.SCtx().GetExprCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), groupingIDCol, expression.NewUInt64Const(fun.GroupingID)) + caseWhen := expression.NewFunctionInternal(p.SCtx().GetExprCtx(), ast.Case, fun.Args[0].GetType(ectx), eqExpr, fun.Args[0], expression.NewNull()) + caseWhenProjCol := &expression.Column{ + UniqueID: p.SCtx().GetSessionVars().AllocPlanColumnID(), + RetType: fun.Args[0].GetType(ectx), + } + proj4Partial.Exprs = append(proj4Partial.Exprs, caseWhen) + proj4Partial.Schema().Append(caseWhenProjCol) + fun.Args[0] = caseWhenProjCol + } + } + + // step2: adjust middle agg + // middleHashAgg shared the same stats with the final agg does. + middleHashAgg := cloneHashAgg + middleSchema := expression.NewSchema() + schemaMap := make(map[int64]*expression.Column, len(middleHashAgg.AggFuncs)) + for _, fun := range middleHashAgg.AggFuncs { + col := &expression.Column{ + UniqueID: p.SCtx().GetSessionVars().AllocPlanColumnID(), + RetType: fun.RetTp, + } + if fun.HasDistinct { + // let count distinct agg aggregate on whole-scope data rather using case-when expr to target on specified group. (agg null strict attribute) + fun.Mode = aggregation.Partial1Mode + } else { + fun.Mode = aggregation.Partial2Mode + originalCol := fun.Args[0].(*expression.Column) + // record the origin column unique id down before change it to be case when expr. + // mapping the current partial output column with the agg origin arg column. (final agg arg should use this one) + schemaMap[originalCol.UniqueID] = col + } + middleSchema.Append(col) + } + middleHashAgg.schema = middleSchema + + // step3: adjust final agg + finalHashAgg := finalAgg.(*PhysicalHashAgg) + finalAggDescs := make([]*aggregation.AggFuncDesc, 0, len(finalHashAgg.AggFuncs)) + for i, fun := range finalHashAgg.AggFuncs { + newArgs := make([]expression.Expression, 0, 1) + if fun.HasDistinct { + // change count(distinct) agg to sum() + fun.Name = ast.AggFuncSum + fun.HasDistinct = false + // count(distinct a,b) -> become a single partial result col. + newArgs = append(newArgs, middleSchema.Columns[i]) + } else { + // remap final normal agg args to be output schema of middle normal agg. 
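// A minimal sketch of the case-when targeting built in proj4Partial above,
// assuming int64 column values and using a nil pointer for SQL NULL (the real
// plan builds expression.ScalarFunction nodes instead): rows from a different
// grouping replica feed NULL into the normal aggregate, and null-ignoring
// aggregates such as count(col) simply skip them.
func targetGroupArg(rowGroupingID, targetGroupingID uint64, arg int64) *int64 {
    if rowGroupingID == targetGroupingID {
        return &arg // keep the argument only for the replica this aggregate targets
    }
    return nil // every other replica contributes NULL
}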
+ for _, arg := range fun.Args { + newCol, err := arg.RemapColumn(schemaMap) + if err != nil { + return nil, nil, nil, nil, err + } + newArgs = append(newArgs, newCol) + } + } + fun.Mode = aggregation.FinalMode + fun.Args = newArgs + fun.GroupingID = 0 + finalAggDescs = append(finalAggDescs, fun) + } + finalHashAgg.AggFuncs = finalAggDescs + return finalHashAgg, middleHashAgg, partialHashAgg, proj4Partial, nil +} + +func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...base.Task) base.Task { + ectx := p.SCtx().GetExprCtx().GetEvalCtx() + + t := tasks[0].Copy() + mpp, ok := t.(*MppTask) + if !ok { + return base.InvalidTask + } + switch p.MppRunMode { + case Mpp1Phase: + // 1-phase agg: when the partition columns can be satisfied, where the plan does not need to enforce Exchange + // only push down the original agg + proj := p.convertAvgForMPP() + attachPlan2Task(p, mpp) + if proj != nil { + attachPlan2Task(proj, mpp) + } + return mpp + case Mpp2Phase: + // TODO: when partition property is matched by sub-plan, we actually needn't do extra an exchange and final agg. + proj := p.convertAvgForMPP() + partialAgg, finalAgg := p.newPartialAggregate(kv.TiFlash, true) + if partialAgg == nil { + return base.InvalidTask + } + attachPlan2Task(partialAgg, mpp) + partitionCols := p.MppPartitionCols + if len(partitionCols) == 0 { + items := finalAgg.(*PhysicalHashAgg).GroupByItems + partitionCols = make([]*property.MPPPartitionColumn, 0, len(items)) + for _, expr := range items { + col, ok := expr.(*expression.Column) + if !ok { + return base.InvalidTask + } + partitionCols = append(partitionCols, &property.MPPPartitionColumn{ + Col: col, + CollateID: property.GetCollateIDByNameForPartition(col.GetType(ectx).GetCollate()), + }) + } + } + prop := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.HashType, MPPPartitionCols: partitionCols} + newMpp := mpp.enforceExchangerImpl(prop) + if newMpp.Invalid() { + return newMpp + } + attachPlan2Task(finalAgg, newMpp) + // TODO: how to set 2-phase cost? + if proj != nil { + attachPlan2Task(proj, newMpp) + } + return newMpp + case MppTiDB: + partialAgg, finalAgg := p.newPartialAggregate(kv.TiFlash, false) + if partialAgg != nil { + attachPlan2Task(partialAgg, mpp) + } + t = mpp.ConvertToRootTask(p.SCtx()) + attachPlan2Task(finalAgg, t) + return t + case MppScalar: + prop := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.SinglePartitionType} + if !mpp.needEnforceExchanger(prop) { + // On the one hand: when the low layer already satisfied the single partition layout, just do the all agg computation in the single node. + return p.attach2TaskForMpp1Phase(mpp) + } + // On the other hand: try to split the mppScalar agg into multi phases agg **down** to multi nodes since data already distributed across nodes. 
+ // we have to check it before the content of p has been modified + canUse3StageAgg, groupingSets := p.scale3StageForDistinctAgg() + proj := p.convertAvgForMPP() + partialAgg, finalAgg := p.newPartialAggregate(kv.TiFlash, true) + if finalAgg == nil { + return base.InvalidTask + } + + final, middle, partial, proj4Partial, err := p.adjust3StagePhaseAgg(partialAgg, finalAgg, canUse3StageAgg, groupingSets, mpp) + if err != nil { + return base.InvalidTask + } + + // partial agg proj would be null if one scalar agg cannot run in two-phase mode + if proj4Partial != nil { + attachPlan2Task(proj4Partial, mpp) + } + + // partial agg would be null if one scalar agg cannot run in two-phase mode + if partial != nil { + attachPlan2Task(partial, mpp) + } + + if middle != nil && canUse3StageAgg { + items := partial.(*PhysicalHashAgg).GroupByItems + partitionCols := make([]*property.MPPPartitionColumn, 0, len(items)) + for _, expr := range items { + col, ok := expr.(*expression.Column) + if !ok { + continue + } + partitionCols = append(partitionCols, &property.MPPPartitionColumn{ + Col: col, + CollateID: property.GetCollateIDByNameForPartition(col.GetType(ectx).GetCollate()), + }) + } + + exProp := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.HashType, MPPPartitionCols: partitionCols} + newMpp := mpp.enforceExchanger(exProp) + attachPlan2Task(middle, newMpp) + mpp = newMpp + } + + // prop here still be the first generated single-partition requirement. + newMpp := mpp.enforceExchanger(prop) + attachPlan2Task(final, newMpp) + if proj == nil { + proj = PhysicalProjection{ + Exprs: make([]expression.Expression, 0, len(p.Schema().Columns)), + }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset()) + for _, col := range p.Schema().Columns { + proj.Exprs = append(proj.Exprs, col) + } + proj.SetSchema(p.schema) + } + attachPlan2Task(proj, newMpp) + return newMpp + default: + return base.InvalidTask + } +} + +// Attach2Task implements the PhysicalPlan interface. +func (p *PhysicalHashAgg) Attach2Task(tasks ...base.Task) base.Task { + t := tasks[0].Copy() + if cop, ok := t.(*CopTask); ok { + if len(cop.rootTaskConds) == 0 && len(cop.idxMergePartPlans) == 0 { + copTaskType := cop.getStoreType() + partialAgg, finalAgg := p.newPartialAggregate(copTaskType, false) + if partialAgg != nil { + if cop.tablePlan != nil { + cop.finishIndexPlan() + partialAgg.SetChildren(cop.tablePlan) + cop.tablePlan = partialAgg + // If needExtraProj is true, a projection will be created above the PhysicalIndexLookUpReader to make sure + // the schema is the same as the original DataSource schema. + // However, we pushed down the agg here, the partial agg was placed on the top of tablePlan, and the final + // agg will be placed above the PhysicalIndexLookUpReader, and the schema will be set correctly for them. + // If we add the projection again, the projection will be between the PhysicalIndexLookUpReader and + // the partial agg, and the schema will be broken. 
+ cop.needExtraProj = false + } else { + partialAgg.SetChildren(cop.indexPlan) + cop.indexPlan = partialAgg + } + } + // In `newPartialAggregate`, we are using stats of final aggregation as stats + // of `partialAgg`, so the network cost of transferring result rows of `partialAgg` + // to TiDB is normally under-estimated for hash aggregation, since the group-by + // column may be independent of the column used for region distribution, so a closer + // estimation of network cost for hash aggregation may multiply the number of + // regions involved in the `partialAgg`, which is unknown however. + t = cop.ConvertToRootTask(p.SCtx()) + attachPlan2Task(finalAgg, t) + } else { + t = cop.ConvertToRootTask(p.SCtx()) + attachPlan2Task(p, t) + } + } else if _, ok := t.(*MppTask); ok { + return p.attach2TaskForMpp(tasks...) + } else { + attachPlan2Task(p, t) + } + return t +} + +func (p *PhysicalWindow) attach2TaskForMPP(mpp *MppTask) base.Task { + // FIXME: currently, tiflash's join has different schema with TiDB, + // so we have to rebuild the schema of join and operators which may inherit schema from join. + // for window, we take the sub-plan's schema, and the schema generated by windowDescs. + columns := p.Schema().Clone().Columns[len(p.Schema().Columns)-len(p.WindowFuncDescs):] + p.schema = expression.MergeSchema(mpp.Plan().Schema(), expression.NewSchema(columns...)) + + failpoint.Inject("CheckMPPWindowSchemaLength", func() { + if len(p.Schema().Columns) != len(mpp.Plan().Schema().Columns)+len(p.WindowFuncDescs) { + panic("mpp physical window has incorrect schema length") + } + }) + + return attachPlan2Task(p, mpp) +} + +// Attach2Task implements the PhysicalPlan interface. +func (p *PhysicalWindow) Attach2Task(tasks ...base.Task) base.Task { + if mpp, ok := tasks[0].Copy().(*MppTask); ok && p.storeTp == kv.TiFlash { + return p.attach2TaskForMPP(mpp) + } + t := tasks[0].ConvertToRootTask(p.SCtx()) + return attachPlan2Task(p.self, t) +} + +// Attach2Task implements the PhysicalPlan interface. +func (p *PhysicalCTEStorage) Attach2Task(tasks ...base.Task) base.Task { + t := tasks[0].Copy() + if mpp, ok := t.(*MppTask); ok { + p.SetChildren(t.Plan()) + return &MppTask{ + p: p, + partTp: mpp.partTp, + hashCols: mpp.hashCols, + tblColHists: mpp.tblColHists, + } + } + t.ConvertToRootTask(p.SCtx()) + p.SetChildren(t.Plan()) + ta := &RootTask{} + ta.SetPlan(p) + return ta +} + +// Attach2Task implements the PhysicalPlan interface. +func (p *PhysicalSequence) Attach2Task(tasks ...base.Task) base.Task { + for _, t := range tasks { + _, isMpp := t.(*MppTask) + if !isMpp { + return tasks[len(tasks)-1] + } + } + + lastTask := tasks[len(tasks)-1].(*MppTask) + + children := make([]base.PhysicalPlan, 0, len(tasks)) + for _, t := range tasks { + children = append(children, t.Plan()) + } + + p.SetChildren(children...) 
+ + mppTask := &MppTask{ + p: p, + partTp: lastTask.partTp, + hashCols: lastTask.hashCols, + tblColHists: lastTask.tblColHists, + } + return mppTask +} + +func collectPartitionInfosFromMPPPlan(p *PhysicalTableReader, mppPlan base.PhysicalPlan) { + switch x := mppPlan.(type) { + case *PhysicalTableScan: + p.TableScanAndPartitionInfos = append(p.TableScanAndPartitionInfos, tableScanAndPartitionInfo{x, x.PlanPartInfo}) + default: + for _, ch := range mppPlan.Children() { + collectPartitionInfosFromMPPPlan(p, ch) + } + } +} + +func collectRowSizeFromMPPPlan(mppPlan base.PhysicalPlan) (rowSize float64) { + if mppPlan != nil && mppPlan.StatsInfo() != nil && mppPlan.StatsInfo().HistColl != nil { + return cardinality.GetAvgRowSize(mppPlan.SCtx(), mppPlan.StatsInfo().HistColl, mppPlan.Schema().Columns, false, false) + } + return 1 // use 1 as lower-bound for safety +} + +func accumulateNetSeekCost4MPP(p base.PhysicalPlan) (cost float64) { + if ts, ok := p.(*PhysicalTableScan); ok { + return float64(len(ts.Ranges)) * float64(len(ts.Columns)) * ts.SCtx().GetSessionVars().GetSeekFactor(ts.Table) + } + for _, c := range p.Children() { + cost += accumulateNetSeekCost4MPP(c) + } + return +} + +func tryExpandVirtualColumn(p base.PhysicalPlan) { + if ts, ok := p.(*PhysicalTableScan); ok { + ts.Columns = ExpandVirtualColumn(ts.Columns, ts.schema, ts.Table.Columns) + return + } + for _, child := range p.Children() { + tryExpandVirtualColumn(child) + } +} + +func (t *MppTask) needEnforceExchanger(prop *property.PhysicalProperty) bool { + switch prop.MPPPartitionTp { + case property.AnyType: + return false + case property.BroadcastType: + return true + case property.SinglePartitionType: + return t.partTp != property.SinglePartitionType + default: + if t.partTp != property.HashType { + return true + } + // TODO: consider equalivant class + // TODO: `prop.IsSubsetOf` is enough, instead of equal. + // for example, if already partitioned by hash(B,C), then same (A,B,C) must distribute on a same node. 
+ if len(prop.MPPPartitionCols) != len(t.hashCols) { + return true + } + for i, col := range prop.MPPPartitionCols { + if !col.Equal(t.hashCols[i]) { + return true + } + } + return false + } +} + +func (t *MppTask) enforceExchanger(prop *property.PhysicalProperty) *MppTask { + if !t.needEnforceExchanger(prop) { + return t + } + return t.Copy().(*MppTask).enforceExchangerImpl(prop) +} + +func (t *MppTask) enforceExchangerImpl(prop *property.PhysicalProperty) *MppTask { + if collate.NewCollationEnabled() && !t.p.SCtx().GetSessionVars().HashExchangeWithNewCollation && prop.MPPPartitionTp == property.HashType { + for _, col := range prop.MPPPartitionCols { + if types.IsString(col.Col.RetType.GetType()) { + t.p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because when `new_collation_enabled` is true, HashJoin or HashAgg with string key is not supported now.") + return &MppTask{} + } + } + } + ctx := t.p.SCtx() + sender := PhysicalExchangeSender{ + ExchangeType: prop.MPPPartitionTp.ToExchangeType(), + HashCols: prop.MPPPartitionCols, + }.Init(ctx, t.p.StatsInfo()) + + if ctx.GetSessionVars().ChooseMppVersion() >= kv.MppVersionV1 { + sender.CompressionMode = ctx.GetSessionVars().ChooseMppExchangeCompressionMode() + } + + sender.SetChildren(t.p) + receiver := PhysicalExchangeReceiver{}.Init(ctx, t.p.StatsInfo()) + receiver.SetChildren(sender) + return &MppTask{ + p: receiver, + partTp: prop.MPPPartitionTp, + hashCols: prop.MPPPartitionCols, + } +} diff --git a/pkg/planner/optimize.go b/pkg/planner/optimize.go index cd7757176821e..b5796eca459e0 100644 --- a/pkg/planner/optimize.go +++ b/pkg/planner/optimize.go @@ -246,7 +246,7 @@ func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in useBinding := enableUseBinding && isStmtNode && match if sessVars.StmtCtx.EnableOptimizerDebugTrace { - failpoint.Inject("SetBindingTimeToZero", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("SetBindingTimeToZero")); _err_ == nil { if val.(bool) && bindings != nil { bindings = bindings.Copy() for i := range bindings { @@ -254,7 +254,7 @@ func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in bindings[i].UpdateTime = types.ZeroTime } } - }) + } debugtrace.RecordAnyValuesWithNames(pctx, "Used binding", useBinding, "Enable binding", enableUseBinding, @@ -459,19 +459,19 @@ var planBuilderPool = sync.Pool{ var optimizeCnt int func optimize(ctx context.Context, sctx pctx.PlanContext, node ast.Node, is infoschema.InfoSchema) (base.Plan, types.NameSlice, float64, error) { - failpoint.Inject("checkOptimizeCountOne", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("checkOptimizeCountOne")); _err_ == nil { // only count the optimization for SQL with specified text if testSQL, ok := val.(string); ok && testSQL == node.OriginalText() { optimizeCnt++ if optimizeCnt > 1 { - failpoint.Return(nil, nil, 0, errors.New("gofail wrong optimizerCnt error")) + return nil, nil, 0, errors.New("gofail wrong optimizerCnt error") } } - }) - failpoint.Inject("mockHighLoadForOptimize", func() { + } + if _, _err_ := failpoint.Eval(_curpkg_("mockHighLoadForOptimize")); _err_ == nil { sqlPrefixes := []string{"select"} topsql.MockHighCPULoad(sctx.GetSessionVars().StmtCtx.OriginalSQL, sqlPrefixes, 10) - }) + } sessVars := sctx.GetSessionVars() if sessVars.StmtCtx.EnableOptimizerDebugTrace { debugtrace.EnterContextCommon(sctx) @@ -561,9 +561,9 @@ func buildLogicalPlan(ctx context.Context, sctx pctx.PlanContext, node ast.Node, 
sctx.GetSessionVars().MapScalarSubQ = nil sctx.GetSessionVars().MapHashCode2UniqueID4ExtendedCol = nil - failpoint.Inject("mockRandomPlanID", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockRandomPlanID")); _err_ == nil { sctx.GetSessionVars().PlanID.Store(rand.Int31n(1000)) // nolint:gosec - }) + } // reset fields about rewrite sctx.GetSessionVars().RewritePhaseInfo.Reset() diff --git a/pkg/planner/optimize.go__failpoint_stash__ b/pkg/planner/optimize.go__failpoint_stash__ new file mode 100644 index 0000000000000..cd7757176821e --- /dev/null +++ b/pkg/planner/optimize.go__failpoint_stash__ @@ -0,0 +1,631 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package planner + +import ( + "context" + "math" + "math/rand" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/bindinfo" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/planner/cascades" + pctx "github.com/pingcap/tidb/pkg/planner/context" + "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/planner/util/debugtrace" + "github.com/pingcap/tidb/pkg/privilege" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" + "github.com/pingcap/tidb/pkg/util/hint" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/topsql" + "github.com/pingcap/tidb/pkg/util/tracing" + "go.uber.org/zap" +) + +// IsReadOnly check whether the ast.Node is a read only statement. +func IsReadOnly(node ast.Node, vars *variable.SessionVars) bool { + if execStmt, isExecStmt := node.(*ast.ExecuteStmt); isExecStmt { + prepareStmt, err := core.GetPreparedStmt(execStmt, vars) + if err != nil { + logutil.BgLogger().Warn("GetPreparedStmt failed", zap.Error(err)) + return false + } + return ast.IsReadOnly(prepareStmt.PreparedAst.Stmt) + } + return ast.IsReadOnly(node) +} + +// getPlanFromNonPreparedPlanCache tries to get an available cached plan from the NonPrepared Plan Cache for this stmt. 
+func getPlanFromNonPreparedPlanCache(ctx context.Context, sctx sessionctx.Context, stmt ast.StmtNode, is infoschema.InfoSchema) (p base.Plan, ns types.NameSlice, ok bool, err error) { + stmtCtx := sctx.GetSessionVars().StmtCtx + _, isExplain := stmt.(*ast.ExplainStmt) + if !sctx.GetSessionVars().EnableNonPreparedPlanCache || // disabled + stmtCtx.InPreparedPlanBuilding || // already in cached plan rebuilding phase + stmtCtx.EnableOptimizerCETrace || stmtCtx.EnableOptimizeTrace || // in trace + stmtCtx.InRestrictedSQL || // is internal SQL + isExplain || // explain external + !sctx.GetSessionVars().DisableTxnAutoRetry || // txn-auto-retry + sctx.GetSessionVars().InMultiStmts || // in multi-stmt + (stmtCtx.InExplainStmt && stmtCtx.ExplainFormat != types.ExplainFormatPlanCache) { // in explain internal + return nil, nil, false, nil + } + ok, reason := core.NonPreparedPlanCacheableWithCtx(sctx.GetPlanCtx(), stmt, is) + if !ok { + if !isExplain && stmtCtx.InExplainStmt && stmtCtx.ExplainFormat == types.ExplainFormatPlanCache { + stmtCtx.AppendWarning(errors.NewNoStackErrorf("skip non-prepared plan-cache: %s", reason)) + } + return nil, nil, false, nil + } + + paramSQL, paramsVals, err := core.GetParamSQLFromAST(stmt) + if err != nil { + return nil, nil, false, err + } + if intest.InTest && ctx.Value(core.PlanCacheKeyTestIssue43667{}) != nil { // update the AST in the middle of the process + ctx.Value(core.PlanCacheKeyTestIssue43667{}).(func(stmt ast.StmtNode))(stmt) + } + val := sctx.GetSessionVars().GetNonPreparedPlanCacheStmt(paramSQL) + paramExprs := core.Params2Expressions(paramsVals) + + if val == nil { + // Create a new AST upon this parameterized SQL instead of using the original AST. + // Keep the original AST unchanged to avoid any side effect. + paramStmt, err := core.ParseParameterizedSQL(sctx, paramSQL) + if err != nil { + // This can happen rarely, cannot parse the parameterized(restored) SQL successfully, skip the plan cache in this case. + sctx.GetSessionVars().StmtCtx.AppendWarning(err) + return nil, nil, false, nil + } + // GeneratePlanCacheStmtWithAST may evaluate these parameters so set their values into SCtx in advance. + if err := core.SetParameterValuesIntoSCtx(sctx.GetPlanCtx(), true, nil, paramExprs); err != nil { + return nil, nil, false, err + } + cachedStmt, _, _, err := core.GeneratePlanCacheStmtWithAST(ctx, sctx, false, paramSQL, paramStmt, is) + if err != nil { + return nil, nil, false, err + } + sctx.GetSessionVars().AddNonPreparedPlanCacheStmt(paramSQL, cachedStmt) + val = cachedStmt + } + cachedStmt := val.(*core.PlanCacheStmt) + + cachedPlan, names, err := core.GetPlanFromPlanCache(ctx, sctx, true, is, cachedStmt, paramExprs) + if err != nil { + return nil, nil, false, err + } + + if intest.InTest && ctx.Value(core.PlanCacheKeyTestIssue47133{}) != nil { + ctx.Value(core.PlanCacheKeyTestIssue47133{}).(func(names []*types.FieldName))(names) + } + + return cachedPlan, names, true, nil +} + +// Optimize does optimization and creates a Plan. 
+func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is infoschema.InfoSchema) (plan base.Plan, slice types.NameSlice, retErr error) { + defer tracing.StartRegion(ctx, "planner.Optimize").End() + sessVars := sctx.GetSessionVars() + pctx := sctx.GetPlanCtx() + if sessVars.StmtCtx.EnableOptimizerDebugTrace { + debugtrace.EnterContextCommon(pctx) + defer debugtrace.LeaveContextCommon(pctx) + } + + if !sessVars.InRestrictedSQL && (variable.RestrictedReadOnly.Load() || variable.VarTiDBSuperReadOnly.Load()) { + allowed, err := allowInReadOnlyMode(pctx, node) + if err != nil { + return nil, nil, err + } + if !allowed { + return nil, nil, errors.Trace(plannererrors.ErrSQLInReadOnlyMode) + } + } + + if sessVars.SQLMode.HasStrictMode() && !IsReadOnly(node, sessVars) { + sessVars.StmtCtx.TiFlashEngineRemovedDueToStrictSQLMode = true + _, hasTiFlashAccess := sessVars.IsolationReadEngines[kv.TiFlash] + if hasTiFlashAccess { + delete(sessVars.IsolationReadEngines, kv.TiFlash) + } + defer func() { + sessVars.StmtCtx.TiFlashEngineRemovedDueToStrictSQLMode = false + if hasTiFlashAccess { + sessVars.IsolationReadEngines[kv.TiFlash] = struct{}{} + } + }() + } + + // handle the execute statement + if execAST, ok := node.(*ast.ExecuteStmt); ok { + p, names, err := OptimizeExecStmt(ctx, sctx, execAST, is) + return p, names, err + } + + tableHints := hint.ExtractTableHintsFromStmtNode(node, sessVars.StmtCtx) + originStmtHints, _, warns := hint.ParseStmtHints(tableHints, + setVarHintChecker, hypoIndexChecker(ctx, is), + sessVars.CurrentDB, byte(kv.ReplicaReadFollower)) + sessVars.StmtCtx.StmtHints = originStmtHints + for _, warn := range warns { + sessVars.StmtCtx.AppendWarning(warn) + } + + defer func() { + // Override the resource group if the hint is set. + if retErr == nil && sessVars.StmtCtx.StmtHints.HasResourceGroup { + if variable.EnableResourceControl.Load() { + hasPriv := true + // only check dynamic privilege when strict-mode is enabled. + if variable.EnableResourceControlStrictMode.Load() { + checker := privilege.GetPrivilegeManager(sctx) + if checker != nil { + hasRgAdminPriv := checker.RequestDynamicVerification(sctx.GetSessionVars().ActiveRoles, "RESOURCE_GROUP_ADMIN", false) + hasRgUserPriv := checker.RequestDynamicVerification(sctx.GetSessionVars().ActiveRoles, "RESOURCE_GROUP_USER", false) + hasPriv = hasRgAdminPriv || hasRgUserPriv + } + } + if hasPriv { + sessVars.StmtCtx.ResourceGroupName = sessVars.StmtCtx.StmtHints.ResourceGroup + // if we are in a txn, should update the txn resource name to let the txn + // commit with the hint resource group. 
+ if txn, err := sctx.Txn(false); err == nil && txn != nil && txn.Valid() { + kv.SetTxnResourceGroup(txn, sessVars.StmtCtx.ResourceGroupName) + } + } else { + err := plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("SUPER or RESOURCE_GROUP_ADMIN or RESOURCE_GROUP_USER") + sessVars.StmtCtx.AppendWarning(err) + } + } else { + err := infoschema.ErrResourceGroupSupportDisabled + sessVars.StmtCtx.AppendWarning(err) + } + } + }() + + warns = warns[:0] + for name, val := range sessVars.StmtCtx.StmtHints.SetVars { + oldV, err := sessVars.SetSystemVarWithOldValAsRet(name, val) + if err != nil { + sessVars.StmtCtx.AppendWarning(err) + } + sessVars.StmtCtx.AddSetVarHintRestore(name, oldV) + } + if len(sessVars.StmtCtx.StmtHints.SetVars) > 0 { + sessVars.StmtCtx.SetSkipPlanCache("SET_VAR is used in the SQL") + } + + if _, isolationReadContainTiKV := sessVars.IsolationReadEngines[kv.TiKV]; isolationReadContainTiKV { + var fp base.Plan + if fpv, ok := sctx.Value(core.PointPlanKey).(core.PointPlanVal); ok { + // point plan is already tried in a multi-statement query. + fp = fpv.Plan + } else { + fp = core.TryFastPlan(pctx, node) + } + if fp != nil { + return fp, fp.OutputNames(), nil + } + } + if err := pctx.AdviseTxnWarmup(); err != nil { + return nil, nil, err + } + + enableUseBinding := sessVars.UsePlanBaselines + stmtNode, isStmtNode := node.(ast.StmtNode) + binding, match, scope := bindinfo.MatchSQLBinding(sctx, stmtNode) + var bindings bindinfo.Bindings + if match { + bindings = []bindinfo.Binding{binding} + } + + useBinding := enableUseBinding && isStmtNode && match + if sessVars.StmtCtx.EnableOptimizerDebugTrace { + failpoint.Inject("SetBindingTimeToZero", func(val failpoint.Value) { + if val.(bool) && bindings != nil { + bindings = bindings.Copy() + for i := range bindings { + bindings[i].CreateTime = types.ZeroTime + bindings[i].UpdateTime = types.ZeroTime + } + } + }) + debugtrace.RecordAnyValuesWithNames(pctx, + "Used binding", useBinding, + "Enable binding", enableUseBinding, + "IsStmtNode", isStmtNode, + "Matched", match, + "Scope", scope, + "Matched bindings", bindings, + ) + } + if isStmtNode { + // add the extra Limit after matching the bind record + stmtNode = core.TryAddExtraLimit(sctx, stmtNode) + node = stmtNode + } + + // try to get Plan from the NonPrepared Plan Cache + if sessVars.EnableNonPreparedPlanCache && + isStmtNode && + !useBinding { // TODO: support binding + cachedPlan, names, ok, err := getPlanFromNonPreparedPlanCache(ctx, sctx, stmtNode, is) + if err != nil { + return nil, nil, err + } + if ok { + return cachedPlan, names, nil + } + } + + var ( + names types.NameSlice + bestPlan, bestPlanFromBind base.Plan + chosenBinding bindinfo.Binding + err error + ) + if useBinding { + minCost := math.MaxFloat64 + var bindStmtHints hint.StmtHints + originHints := hint.CollectHint(stmtNode) + // bindings must be not nil when coming here, try to find the best binding. 
+ for _, binding := range bindings { + if !binding.IsBindingEnabled() { + continue + } + if sessVars.StmtCtx.EnableOptimizerDebugTrace { + core.DebugTraceTryBinding(pctx, binding.Hint) + } + hint.BindHint(stmtNode, binding.Hint) + curStmtHints, _, curWarns := hint.ParseStmtHints(binding.Hint.GetStmtHints(), + setVarHintChecker, hypoIndexChecker(ctx, is), + sessVars.CurrentDB, byte(kv.ReplicaReadFollower)) + sessVars.StmtCtx.StmtHints = curStmtHints + // update session var by hint /set_var/ + for name, val := range sessVars.StmtCtx.StmtHints.SetVars { + oldV, err := sessVars.SetSystemVarWithOldValAsRet(name, val) + if err != nil { + sessVars.StmtCtx.AppendWarning(err) + } + sessVars.StmtCtx.AddSetVarHintRestore(name, oldV) + } + plan, curNames, cost, err := optimize(ctx, pctx, node, is) + if err != nil { + binding.Status = bindinfo.Invalid + handleInvalidBinding(ctx, pctx, scope, binding) + continue + } + if cost < minCost { + bindStmtHints, warns, minCost, names, bestPlanFromBind, chosenBinding = curStmtHints, curWarns, cost, curNames, plan, binding + } + } + if bestPlanFromBind == nil { + sessVars.StmtCtx.AppendWarning(errors.NewNoStackError("no plan generated from bindings")) + } else { + bestPlan = bestPlanFromBind + sessVars.StmtCtx.StmtHints = bindStmtHints + for _, warn := range warns { + sessVars.StmtCtx.AppendWarning(warn) + } + sessVars.StmtCtx.BindSQL = chosenBinding.BindSQL + sessVars.FoundInBinding = true + if sessVars.StmtCtx.InVerboseExplain { + sessVars.StmtCtx.AppendNote(errors.NewNoStackErrorf("Using the bindSQL: %v", chosenBinding.BindSQL)) + } else { + sessVars.StmtCtx.AppendExtraNote(errors.NewNoStackErrorf("Using the bindSQL: %v", chosenBinding.BindSQL)) + } + if len(tableHints) > 0 { + sessVars.StmtCtx.AppendWarning(errors.NewNoStackErrorf("The system ignores the hints in the current query and uses the hints specified in the bindSQL: %v", chosenBinding.BindSQL)) + } + } + // Restore the hint to avoid changing the stmt node. + hint.BindHint(stmtNode, originHints) + } + + if sessVars.StmtCtx.EnableOptimizerDebugTrace && bestPlanFromBind != nil { + core.DebugTraceBestBinding(pctx, chosenBinding.Hint) + } + // No plan found from the bindings, or the bindings are ignored. + if bestPlan == nil { + sessVars.StmtCtx.StmtHints = originStmtHints + bestPlan, names, _, err = optimize(ctx, pctx, node, is) + if err != nil { + return nil, nil, err + } + } + + // Add a baseline evolution task if: + // 1. the returned plan is from bindings; + // 2. the query is a select statement; + // 3. the original binding contains no read_from_storage hint; + // 4. the plan when ignoring bindings contains no tiflash hint; + // 5. the pending verified binding has not been added already; + savedStmtHints := sessVars.StmtCtx.StmtHints + defer func() { + sessVars.StmtCtx.StmtHints = savedStmtHints + }() + if sessVars.EvolvePlanBaselines && bestPlanFromBind != nil && + sessVars.SelectLimit == math.MaxUint64 { // do not evolve this query if sql_select_limit is enabled + // Check bestPlanFromBind firstly to avoid nil stmtNode. + if _, ok := stmtNode.(*ast.SelectStmt); ok && !bindings[0].Hint.ContainTableHint(hint.HintReadFromStorage) { + sessVars.StmtCtx.StmtHints = originStmtHints + defPlan, _, _, err := optimize(ctx, pctx, node, is) + if err != nil { + // Ignore this evolution task. 
+ return bestPlan, names, nil + } + defPlanHints := core.GenHintsFromPhysicalPlan(defPlan) + for _, h := range defPlanHints { + if h.HintName.String() == hint.HintReadFromStorage { + return bestPlan, names, nil + } + } + } + } + + return bestPlan, names, nil +} + +// OptimizeForForeignKeyCascade does optimization and creates a Plan for foreign key cascade. +// Compare to Optimize, OptimizeForForeignKeyCascade only build plan by StmtNode, +// doesn't consider plan cache and plan binding, also doesn't do privilege check. +func OptimizeForForeignKeyCascade(ctx context.Context, sctx pctx.PlanContext, node ast.StmtNode, is infoschema.InfoSchema) (base.Plan, error) { + builder := planBuilderPool.Get().(*core.PlanBuilder) + defer planBuilderPool.Put(builder.ResetForReuse()) + hintProcessor := hint.NewQBHintHandler(sctx.GetSessionVars().StmtCtx) + builder.Init(sctx, is, hintProcessor) + p, err := builder.Build(ctx, node) + if err != nil { + return nil, err + } + if err := core.CheckTableLock(sctx, is, builder.GetVisitInfo()); err != nil { + return nil, err + } + return p, nil +} + +func allowInReadOnlyMode(sctx pctx.PlanContext, node ast.Node) (bool, error) { + pm := privilege.GetPrivilegeManager(sctx) + if pm == nil { + return true, nil + } + roles := sctx.GetSessionVars().ActiveRoles + // allow replication thread + // NOTE: it is required, whether SEM is enabled or not, only user with explicit RESTRICTED_REPLICA_WRITER_ADMIN granted can ignore the restriction, so we need to surpass the case that if SEM is not enabled, SUPER will has all privileges + if pm.HasExplicitlyGrantedDynamicPrivilege(roles, "RESTRICTED_REPLICA_WRITER_ADMIN", false) { + return true, nil + } + + switch node.(type) { + // allow change variables (otherwise can't unset read-only mode) + case *ast.SetStmt, + // allow analyze table + *ast.AnalyzeTableStmt, + *ast.UseStmt, + *ast.ShowStmt, + *ast.CreateBindingStmt, + *ast.DropBindingStmt, + *ast.PrepareStmt, + *ast.BeginStmt, + *ast.RollbackStmt: + return true, nil + case *ast.CommitStmt: + txn, err := sctx.Txn(true) + if err != nil { + return false, err + } + if !txn.IsReadOnly() { + return false, txn.Rollback() + } + return true, nil + } + + vars := sctx.GetSessionVars() + return IsReadOnly(node, vars), nil +} + +var planBuilderPool = sync.Pool{ + New: func() any { + return core.NewPlanBuilder() + }, +} + +// optimizeCnt is a global variable only used for test. 
+var optimizeCnt int + +func optimize(ctx context.Context, sctx pctx.PlanContext, node ast.Node, is infoschema.InfoSchema) (base.Plan, types.NameSlice, float64, error) { + failpoint.Inject("checkOptimizeCountOne", func(val failpoint.Value) { + // only count the optimization for SQL with specified text + if testSQL, ok := val.(string); ok && testSQL == node.OriginalText() { + optimizeCnt++ + if optimizeCnt > 1 { + failpoint.Return(nil, nil, 0, errors.New("gofail wrong optimizerCnt error")) + } + } + }) + failpoint.Inject("mockHighLoadForOptimize", func() { + sqlPrefixes := []string{"select"} + topsql.MockHighCPULoad(sctx.GetSessionVars().StmtCtx.OriginalSQL, sqlPrefixes, 10) + }) + sessVars := sctx.GetSessionVars() + if sessVars.StmtCtx.EnableOptimizerDebugTrace { + debugtrace.EnterContextCommon(sctx) + defer debugtrace.LeaveContextCommon(sctx) + } + + // build logical plan + hintProcessor := hint.NewQBHintHandler(sctx.GetSessionVars().StmtCtx) + node.Accept(hintProcessor) + defer hintProcessor.HandleUnusedViewHints() + builder := planBuilderPool.Get().(*core.PlanBuilder) + defer planBuilderPool.Put(builder.ResetForReuse()) + builder.Init(sctx, is, hintProcessor) + p, err := buildLogicalPlan(ctx, sctx, node, builder) + if err != nil { + return nil, nil, 0, err + } + + activeRoles := sessVars.ActiveRoles + // Check privilege. Maybe it's better to move this to the Preprocess, but + // we need the table information to check privilege, which is collected + // into the visitInfo in the logical plan builder. + if pm := privilege.GetPrivilegeManager(sctx); pm != nil { + visitInfo := core.VisitInfo4PrivCheck(ctx, is, node, builder.GetVisitInfo()) + if err := core.CheckPrivilege(activeRoles, pm, visitInfo); err != nil { + return nil, nil, 0, err + } + } + + if err := core.CheckTableLock(sctx, is, builder.GetVisitInfo()); err != nil { + return nil, nil, 0, err + } + + names := p.OutputNames() + + // Handle the non-logical plan statement. + logic, isLogicalPlan := p.(base.LogicalPlan) + if !isLogicalPlan { + return p, names, 0, nil + } + + core.RecheckCTE(logic) + + // Handle the logical plan statement, use cascades planner if enabled. 
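+	// Otherwise fall through to the default optimizer (core.DoOptimize) and record the time spent in DurationOptimization.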
+ if sessVars.GetEnableCascadesPlanner() { + finalPlan, cost, err := cascades.DefaultOptimizer.FindBestPlan(sctx, logic) + return finalPlan, names, cost, err + } + + beginOpt := time.Now() + finalPlan, cost, err := core.DoOptimize(ctx, sctx, builder.GetOptFlag(), logic) + // TODO: capture plan replayer here if it matches sql and plan digest + + sessVars.DurationOptimization = time.Since(beginOpt) + return finalPlan, names, cost, err +} + +// OptimizeExecStmt to handle the "execute" statement +func OptimizeExecStmt(ctx context.Context, sctx sessionctx.Context, + execAst *ast.ExecuteStmt, is infoschema.InfoSchema) (base.Plan, types.NameSlice, error) { + builder := planBuilderPool.Get().(*core.PlanBuilder) + defer planBuilderPool.Put(builder.ResetForReuse()) + pctx := sctx.GetPlanCtx() + builder.Init(pctx, is, nil) + + p, err := buildLogicalPlan(ctx, pctx, execAst, builder) + if err != nil { + return nil, nil, err + } + exec, ok := p.(*core.Execute) + if !ok { + return nil, nil, errors.Errorf("invalid result plan type, should be Execute") + } + plan, names, err := core.GetPlanFromPlanCache(ctx, sctx, false, is, exec.PrepStmt, exec.Params) + if err != nil { + return nil, nil, err + } + exec.Plan = plan + exec.SetOutputNames(names) + exec.Stmt = exec.PrepStmt.PreparedAst.Stmt + return exec, names, nil +} + +func buildLogicalPlan(ctx context.Context, sctx pctx.PlanContext, node ast.Node, builder *core.PlanBuilder) (base.Plan, error) { + sctx.GetSessionVars().PlanID.Store(0) + sctx.GetSessionVars().PlanColumnID.Store(0) + sctx.GetSessionVars().MapScalarSubQ = nil + sctx.GetSessionVars().MapHashCode2UniqueID4ExtendedCol = nil + + failpoint.Inject("mockRandomPlanID", func() { + sctx.GetSessionVars().PlanID.Store(rand.Int31n(1000)) // nolint:gosec + }) + + // reset fields about rewrite + sctx.GetSessionVars().RewritePhaseInfo.Reset() + beginRewrite := time.Now() + p, err := builder.Build(ctx, node) + if err != nil { + return nil, err + } + sctx.GetSessionVars().RewritePhaseInfo.DurationRewrite = time.Since(beginRewrite) + if exec, ok := p.(*core.Execute); ok && exec.PrepStmt != nil { + sctx.GetSessionVars().StmtCtx.Tables = core.GetDBTableInfo(exec.PrepStmt.VisitInfos) + } else { + sctx.GetSessionVars().StmtCtx.Tables = core.GetDBTableInfo(builder.GetVisitInfo()) + } + return p, nil +} + +func handleInvalidBinding(ctx context.Context, sctx pctx.PlanContext, level string, binding bindinfo.Binding) { + sessionHandle := sctx.Value(bindinfo.SessionBindInfoKeyType).(bindinfo.SessionBindingHandle) + err := sessionHandle.DropSessionBinding(binding.SQLDigest) + if err != nil { + logutil.Logger(ctx).Info("drop session bindings failed") + } + if level == metrics.ScopeSession { + return + } + + globalHandle := domain.GetDomain(sctx).BindHandle() + globalHandle.AddInvalidGlobalBinding(binding) +} + +// setVarHintChecker checks whether the variable name in set_var hint is valid. 
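+// An unknown variable is rejected with ErrUnresolvedHintName; a variable that exists but is not verified as hint-updatable only produces a warning.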
+func setVarHintChecker(varName, hint string) (ok bool, warning error) { + sysVar := variable.GetSysVar(varName) + if sysVar == nil { // no such a variable + return false, plannererrors.ErrUnresolvedHintName.FastGenByArgs(varName, hint) + } + if !sysVar.IsHintUpdatableVerified { + warning = plannererrors.ErrNotHintUpdatable.FastGenByArgs(varName) + } + return true, warning +} + +func hypoIndexChecker(ctx context.Context, is infoschema.InfoSchema) func(db, tbl, col model.CIStr) (colOffset int, err error) { + return func(db, tbl, col model.CIStr) (colOffset int, err error) { + t, err := is.TableByName(ctx, db, tbl) + if err != nil { + return 0, errors.NewNoStackErrorf("table '%v.%v' doesn't exist", db, tbl) + } + for i, tblCol := range t.Cols() { + if tblCol.Name.L == col.L { + return i, nil + } + } + return 0, errors.NewNoStackErrorf("can't find column %v in table %v.%v", col, db, tbl) + } +} + +func init() { + core.OptimizeAstNode = Optimize + core.IsReadOnly = IsReadOnly + bindinfo.GetGlobalBindingHandle = func(sctx sessionctx.Context) bindinfo.GlobalBindingHandle { + return domain.GetDomain(sctx).BindHandle() + } +} diff --git a/pkg/server/binding__failpoint_binding__.go b/pkg/server/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..884841332390a --- /dev/null +++ b/pkg/server/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package server + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/server/conn.go b/pkg/server/conn.go index 3cd8b5c052637..241f5b8c79ff5 100644 --- a/pkg/server/conn.go +++ b/pkg/server/conn.go @@ -261,9 +261,9 @@ func (cc *clientConn) authSwitchRequest(ctx context.Context, plugin string) ([]b clientPlugin = authPluginImpl.Name } } - failpoint.Inject("FakeAuthSwitch", func() { - failpoint.Return([]byte(clientPlugin), nil) - }) + if _, _err_ := failpoint.Eval(_curpkg_("FakeAuthSwitch")); _err_ == nil { + return []byte(clientPlugin), nil + } enclen := 1 + len(clientPlugin) + 1 + len(cc.salt) + 1 data := cc.alloc.AllocWithLen(4, enclen) data = append(data, mysql.AuthSwitchRequest) // switch request @@ -488,11 +488,11 @@ func (cc *clientConn) readPacket() ([]byte, error) { } func (cc *clientConn) writePacket(data []byte) error { - failpoint.Inject("FakeClientConn", func() { + if _, _err_ := failpoint.Eval(_curpkg_("FakeClientConn")); _err_ == nil { if cc.pkt == nil { - failpoint.Return(nil) + return nil } - }) + } return cc.pkt.WritePacket(data) } @@ -858,10 +858,10 @@ func (cc *clientConn) checkAuthPlugin(ctx context.Context, resp *handshake.Respo logutil.Logger(ctx).Warn("Failed to get authentication method for user", zap.String("user", cc.user), zap.String("host", host)) } - failpoint.Inject("FakeUser", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("FakeUser")); _err_ == nil { //nolint:forcetypeassert userplugin = val.(string) - }) + } if userplugin == mysql.AuthSocket { if !cc.isUnixSocket { return nil, servererr.ErrAccessDenied.FastGenByArgs(cc.user, host, hasPassword) @@ -1486,11 +1486,11 @@ func (cc *clientConn) flush(ctx context.Context) error { } } }() - failpoint.Inject("FakeClientConn", func() { + if _, _err_ := failpoint.Eval(_curpkg_("FakeClientConn")); _err_ == nil { if cc.pkt == nil { - failpoint.Return(nil) + 
return nil } - }) + } return cc.pkt.Flush() } @@ -2335,21 +2335,21 @@ func (cc *clientConn) writeChunks(ctx context.Context, rs resultset.ResultSet, b stmtDetail = stmtDetailRaw.(*execdetails.StmtExecDetails) } for { - failpoint.Inject("fetchNextErr", func(value failpoint.Value) { + if value, _err_ := failpoint.Eval(_curpkg_("fetchNextErr")); _err_ == nil { //nolint:forcetypeassert switch value.(string) { case "firstNext": - failpoint.Return(firstNext, storeerr.ErrTiFlashServerTimeout) + return firstNext, storeerr.ErrTiFlashServerTimeout case "secondNext": if !firstNext { - failpoint.Return(firstNext, storeerr.ErrTiFlashServerTimeout) + return firstNext, storeerr.ErrTiFlashServerTimeout } case "secondNextAndRetConflict": if !firstNext && validNextCount > 1 { - failpoint.Return(firstNext, kv.ErrWriteConflict) + return firstNext, kv.ErrWriteConflict } } - }) + } // Here server.tidbResultSet implements Next method. err := rs.Next(ctx, req) if err != nil { @@ -2549,9 +2549,9 @@ func (cc *clientConn) handleChangeUser(ctx context.Context, data []byte) error { Capability: cc.capability, } if fakeResp.AuthPlugin != "" { - failpoint.Inject("ChangeUserAuthSwitch", func(val failpoint.Value) { - failpoint.Return(errors.Errorf("%v", val)) - }) + if val, _err_ := failpoint.Eval(_curpkg_("ChangeUserAuthSwitch")); _err_ == nil { + return errors.Errorf("%v", val) + } newpass, err := cc.checkAuthPlugin(ctx, fakeResp) if err != nil { return err diff --git a/pkg/server/conn.go__failpoint_stash__ b/pkg/server/conn.go__failpoint_stash__ new file mode 100644 index 0000000000000..3cd8b5c052637 --- /dev/null +++ b/pkg/server/conn.go__failpoint_stash__ @@ -0,0 +1,2748 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// The MIT License (MIT) +// +// Copyright (c) 2014 wandoulabs +// Copyright (c) 2014 siddontang +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. 
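+// This stash file preserves the original, un-instrumented conn.go so that `failpoint-ctl disable` can restore it over the rewritten file.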
+ +package server + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/binary" + goerr "errors" + "fmt" + "io" + "net" + "os/user" + "runtime" + "runtime/pprof" + "runtime/trace" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + "unsafe" + + "github.com/klauspost/compress/zstd" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/domain/resourcegroup" + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/executor" + "github.com/pingcap/tidb/pkg/extension" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/auth" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/plugin" + "github.com/pingcap/tidb/pkg/privilege" + "github.com/pingcap/tidb/pkg/privilege/conn" + "github.com/pingcap/tidb/pkg/privilege/privileges/ldap" + servererr "github.com/pingcap/tidb/pkg/server/err" + "github.com/pingcap/tidb/pkg/server/handler/tikvhandler" + "github.com/pingcap/tidb/pkg/server/internal" + "github.com/pingcap/tidb/pkg/server/internal/column" + "github.com/pingcap/tidb/pkg/server/internal/dump" + "github.com/pingcap/tidb/pkg/server/internal/handshake" + "github.com/pingcap/tidb/pkg/server/internal/parse" + "github.com/pingcap/tidb/pkg/server/internal/resultset" + util2 "github.com/pingcap/tidb/pkg/server/internal/util" + server_metrics "github.com/pingcap/tidb/pkg/server/metrics" + "github.com/pingcap/tidb/pkg/session" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/sessiontxn" + storeerr "github.com/pingcap/tidb/pkg/store/driver/error" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/util/arena" + "github.com/pingcap/tidb/pkg/util/chunk" + contextutil "github.com/pingcap/tidb/pkg/util/context" + "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/hack" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/resourcegrouptag" + tlsutil "github.com/pingcap/tidb/pkg/util/tls" + "github.com/pingcap/tidb/pkg/util/topsql" + topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" + "github.com/pingcap/tidb/pkg/util/tracing" + "github.com/prometheus/client_golang/prometheus" + "github.com/tikv/client-go/v2/tikvrpc" + "github.com/tikv/client-go/v2/util" + "go.uber.org/zap" +) + +const ( + connStatusDispatching int32 = iota + connStatusReading + connStatusShutdown = variable.ConnStatusShutdown // Closed by server. + connStatusWaitShutdown = 3 // Notified by server to close. +) + +var ( + statusCompression = "Compression" + statusCompressionAlgorithm = "Compression_algorithm" + statusCompressionLevel = "Compression_level" +) + +var ( + // ConnectionInMemCounterForTest is a variable to count live connection object + ConnectionInMemCounterForTest = atomic.Int64{} +) + +// newClientConn creates a *clientConn object. 
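+// The connection ID is allocated from the server's domain; in tests a finalizer keeps ConnectionInMemCounterForTest in sync with live connection objects.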
+func newClientConn(s *Server) *clientConn { + cc := &clientConn{ + server: s, + connectionID: s.dom.NextConnID(), + collation: mysql.DefaultCollationID, + alloc: arena.NewAllocator(32 * 1024), + chunkAlloc: chunk.NewAllocator(), + status: connStatusDispatching, + lastActive: time.Now(), + authPlugin: mysql.AuthNativePassword, + quit: make(chan struct{}), + ppEnabled: s.cfg.ProxyProtocol.Networks != "", + } + + if intest.InTest { + ConnectionInMemCounterForTest.Add(1) + runtime.SetFinalizer(cc, func(*clientConn) { + ConnectionInMemCounterForTest.Add(-1) + }) + } + return cc +} + +// clientConn represents a connection between server and client, it maintains connection specific state, +// handles client query. +type clientConn struct { + pkt *internal.PacketIO // a helper to read and write data in packet format. + bufReadConn *util2.BufferedReadConn // a buffered-read net.Conn or buffered-read tls.Conn. + tlsConn *tls.Conn // TLS connection, nil if not TLS. + server *Server // a reference of server instance. + capability uint32 // client capability affects the way server handles client request. + connectionID uint64 // atomically allocated by a global variable, unique in process scope. + user string // user of the client. + dbname string // default database name. + salt []byte // random bytes used for authentication. + alloc arena.Allocator // an memory allocator for reducing memory allocation. + chunkAlloc chunk.Allocator + lastPacket []byte // latest sql query string, currently used for logging error. + // ShowProcess() and mysql.ComChangeUser both visit this field, ShowProcess() read information through + // the TiDBContext and mysql.ComChangeUser re-create it, so a lock is required here. + ctx struct { + sync.RWMutex + *TiDBContext // an interface to execute sql statements. + } + attrs map[string]string // attributes parsed from client handshake response. + serverHost string // server host + peerHost string // peer host + peerPort string // peer port + status int32 // dispatching/reading/shutdown/waitshutdown + lastCode uint16 // last error code + collation uint8 // collation used by client, may be different from the collation used by database. + lastActive time.Time // last active time + authPlugin string // default authentication plugin + isUnixSocket bool // connection is Unix Socket file + closeOnce sync.Once // closeOnce is used to make sure clientConn closes only once + rsEncoder *column.ResultEncoder // rsEncoder is used to encode the string result to different charsets + inputDecoder *util2.InputDecoder // inputDecoder is used to decode the different charsets of incoming strings to utf-8 + socketCredUID uint32 // UID from the other end of the Unix Socket + // mu is used for cancelling the execution of current transaction. + mu struct { + sync.RWMutex + cancelFunc context.CancelFunc + } + // quit is close once clientConn quit Run(). + quit chan struct{} + extensions *extension.SessionExtensions + + // Proxy Protocol Enabled + ppEnabled bool +} + +func (cc *clientConn) getCtx() *TiDBContext { + cc.ctx.RLock() + defer cc.ctx.RUnlock() + return cc.ctx.TiDBContext +} + +func (cc *clientConn) SetCtx(ctx *TiDBContext) { + cc.ctx.Lock() + cc.ctx.TiDBContext = ctx + cc.ctx.Unlock() +} + +func (cc *clientConn) String() string { + // MySQL converts a collation from u32 to char in the protocol, so the value could be wrong. 
It works fine for the + // default parameters (and libmysql seems not to provide any way to specify the collation other than the default + // one), so it's not a big problem. + collationStr := mysql.Collations[uint16(cc.collation)] + return fmt.Sprintf("id:%d, addr:%s status:%b, collation:%s, user:%s", + cc.connectionID, cc.bufReadConn.RemoteAddr(), cc.ctx.Status(), collationStr, cc.user, + ) +} + +func (cc *clientConn) setStatus(status int32) { + atomic.StoreInt32(&cc.status, status) + if ctx := cc.getCtx(); ctx != nil { + atomic.StoreInt32(&ctx.GetSessionVars().ConnectionStatus, status) + } +} + +func (cc *clientConn) getStatus() int32 { + return atomic.LoadInt32(&cc.status) +} + +func (cc *clientConn) CompareAndSwapStatus(oldStatus, newStatus int32) bool { + return atomic.CompareAndSwapInt32(&cc.status, oldStatus, newStatus) +} + +// authSwitchRequest is used by the server to ask the client to switch to a different authentication +// plugin. MySQL 8.0 libmysqlclient based clients by default always try `caching_sha2_password`, even +// when the server advertises the its default to be `mysql_native_password`. In addition to this switching +// may be needed on a per user basis as the authentication method is set per user. +// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_switch_request.html +// https://bugs.mysql.com/bug.php?id=93044 +func (cc *clientConn) authSwitchRequest(ctx context.Context, plugin string) ([]byte, error) { + clientPlugin := plugin + if plugin == mysql.AuthLDAPSASL { + clientPlugin += "_client" + } else if plugin == mysql.AuthLDAPSimple { + clientPlugin = mysql.AuthMySQLClearPassword + } else if authPluginImpl, ok := cc.extensions.GetAuthPlugin(plugin); ok { + if authPluginImpl.RequiredClientSidePlugin != "" { + clientPlugin = authPluginImpl.RequiredClientSidePlugin + } else { + // If RequiredClientSidePlugin is empty, use the plugin name as the client plugin. + clientPlugin = authPluginImpl.Name + } + } + failpoint.Inject("FakeAuthSwitch", func() { + failpoint.Return([]byte(clientPlugin), nil) + }) + enclen := 1 + len(clientPlugin) + 1 + len(cc.salt) + 1 + data := cc.alloc.AllocWithLen(4, enclen) + data = append(data, mysql.AuthSwitchRequest) // switch request + data = append(data, []byte(clientPlugin)...) + data = append(data, byte(0x00)) // requires null + if plugin == mysql.AuthLDAPSASL { + // append sasl auth method name + data = append(data, []byte(ldap.LDAPSASLAuthImpl.GetSASLAuthMethod())...) + data = append(data, byte(0x00)) + } else { + data = append(data, cc.salt...) + data = append(data, 0) + } + err := cc.writePacket(data) + if err != nil { + logutil.Logger(ctx).Debug("write response to client failed", zap.Error(err)) + return nil, err + } + err = cc.flush(ctx) + if err != nil { + logutil.Logger(ctx).Debug("flush response to client failed", zap.Error(err)) + return nil, err + } + resp, err := cc.readPacket() + if err != nil { + err = errors.SuspendStack(err) + if errors.Cause(err) == io.EOF { + logutil.Logger(ctx).Warn("authSwitchRequest response fail due to connection has be closed by client-side") + } else { + logutil.Logger(ctx).Warn("authSwitchRequest response fail", zap.Error(err)) + } + return nil, err + } + cc.authPlugin = plugin + return resp, nil +} + +// handshake works like TCP handshake, but in a higher level, it first writes initial packet to client, +// during handshake, client and server negotiate compatible features and do authentication. 
+// After handshake, client can send sql query to server. +func (cc *clientConn) handshake(ctx context.Context) error { + if err := cc.writeInitialHandshake(ctx); err != nil { + if errors.Cause(err) == io.EOF { + logutil.Logger(ctx).Debug("Could not send handshake due to connection has be closed by client-side") + } else { + logutil.Logger(ctx).Debug("Write init handshake to client fail", zap.Error(errors.SuspendStack(err))) + } + return err + } + if err := cc.readOptionalSSLRequestAndHandshakeResponse(ctx); err != nil { + err1 := cc.writeError(ctx, err) + if err1 != nil { + logutil.Logger(ctx).Debug("writeError failed", zap.Error(err1)) + } + return err + } + + // MySQL supports an "init_connect" query, which can be run on initial connection. + // The query must return a non-error or the client is disconnected. + if err := cc.initConnect(ctx); err != nil { + logutil.Logger(ctx).Warn("init_connect failed", zap.Error(err)) + initErr := servererr.ErrNewAbortingConnection.FastGenByArgs(cc.connectionID, "unconnected", cc.user, cc.peerHost, "init_connect command failed") + if err1 := cc.writeError(ctx, initErr); err1 != nil { + terror.Log(err1) + } + return initErr + } + + data := cc.alloc.AllocWithLen(4, 32) + data = append(data, mysql.OKHeader) + data = append(data, 0, 0) + if cc.capability&mysql.ClientProtocol41 > 0 { + data = dump.Uint16(data, mysql.ServerStatusAutocommit) + data = append(data, 0, 0) + } + + err := cc.writePacket(data) + cc.pkt.SetSequence(0) + if err != nil { + err = errors.SuspendStack(err) + logutil.Logger(ctx).Debug("write response to client failed", zap.Error(err)) + return err + } + + err = cc.flush(ctx) + if err != nil { + err = errors.SuspendStack(err) + logutil.Logger(ctx).Debug("flush response to client failed", zap.Error(err)) + return err + } + + // With mysql --compression-algorithms=zlib,zstd both flags are set, the result is Zlib + if cc.capability&mysql.ClientCompress > 0 { + cc.pkt.SetCompressionAlgorithm(mysql.CompressionZlib) + cc.ctx.SetCompressionAlgorithm(mysql.CompressionZlib) + } else if cc.capability&mysql.ClientZstdCompressionAlgorithm > 0 { + cc.pkt.SetCompressionAlgorithm(mysql.CompressionZstd) + cc.ctx.SetCompressionAlgorithm(mysql.CompressionZstd) + } + + return err +} + +func (cc *clientConn) Close() error { + // Be careful, this function should be re-entrant. It might be called more than once for a single connection. + // Any logic which is not idempotent should be in closeConn() and wrapped with `cc.closeOnce.Do`, like decresing + // metrics, releasing resources, etc. + // + // TODO: avoid calling this function multiple times. It's not intuitive that a connection can be closed multiple + // times. + cc.server.rwlock.Lock() + delete(cc.server.clients, cc.connectionID) + cc.server.rwlock.Unlock() + return closeConn(cc) +} + +// closeConn is idempotent and thread-safe. +// It will be called on the same `clientConn` more than once to avoid connection leak. +func closeConn(cc *clientConn) error { + var err error + cc.closeOnce.Do(func() { + if cc.connectionID > 0 { + cc.server.dom.ReleaseConnID(cc.connectionID) + cc.connectionID = 0 + } + if cc.bufReadConn != nil { + err := cc.bufReadConn.Close() + if err != nil { + // We need to expect connection might have already disconnected. + // This is because closeConn() might be called after a connection read-timeout. 
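+				// The failure is therefore only logged at debug level and otherwise ignored.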
+ logutil.Logger(context.Background()).Debug("could not close connection", zap.Error(err)) + } + } + + // Close statements and session + // At first, it'll decrese the count of connections in the resource group, update the corresponding gauge. + // Then it'll close the statements and session, which release advisory locks, row locks, etc. + if ctx := cc.getCtx(); ctx != nil { + resourceGroupName := ctx.GetSessionVars().ResourceGroupName + metrics.ConnGauge.WithLabelValues(resourceGroupName).Dec() + + err = ctx.Close() + } else { + metrics.ConnGauge.WithLabelValues(resourcegroup.DefaultResourceGroupName).Dec() + } + }) + return err +} + +func (cc *clientConn) closeWithoutLock() error { + delete(cc.server.clients, cc.connectionID) + return closeConn(cc) +} + +// writeInitialHandshake sends server version, connection ID, server capability, collation, server status +// and auth salt to the client. +func (cc *clientConn) writeInitialHandshake(ctx context.Context) error { + data := make([]byte, 4, 128) + + // min version 10 + data = append(data, 10) + // server version[00] + data = append(data, mysql.ServerVersion...) + data = append(data, 0) + // connection id + data = append(data, byte(cc.connectionID), byte(cc.connectionID>>8), byte(cc.connectionID>>16), byte(cc.connectionID>>24)) + // auth-plugin-data-part-1 + data = append(data, cc.salt[0:8]...) + // filler [00] + data = append(data, 0) + // capability flag lower 2 bytes, using default capability here + data = append(data, byte(cc.server.capability), byte(cc.server.capability>>8)) + // charset + if cc.collation == 0 { + cc.collation = uint8(mysql.DefaultCollationID) + } + data = append(data, cc.collation) + // status + data = dump.Uint16(data, mysql.ServerStatusAutocommit) + // below 13 byte may not be used + // capability flag upper 2 bytes, using default capability here + data = append(data, byte(cc.server.capability>>16), byte(cc.server.capability>>24)) + // length of auth-plugin-data + data = append(data, byte(len(cc.salt)+1)) + // reserved 10 [00] + data = append(data, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + // auth-plugin-data-part-2 + data = append(data, cc.salt[8:]...) + data = append(data, 0) + // auth-plugin name + if ctx := cc.getCtx(); ctx == nil { + if err := cc.openSession(); err != nil { + return err + } + } + defAuthPlugin, err := cc.ctx.GetSessionVars().GetGlobalSystemVar(context.Background(), variable.DefaultAuthPlugin) + if err != nil { + return err + } + cc.authPlugin = defAuthPlugin + data = append(data, []byte(defAuthPlugin)...) + + // Close the session to force this to be re-opened after we parse the response. This is needed + // to ensure we use the collation and client flags from the response for the session. 
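+	// A new session is opened again while the handshake response is processed.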
+ if err = cc.ctx.Close(); err != nil { + return err + } + cc.SetCtx(nil) + + data = append(data, 0) + if err = cc.writePacket(data); err != nil { + return err + } + return cc.flush(ctx) +} + +func (cc *clientConn) readPacket() ([]byte, error) { + if cc.getCtx() != nil { + cc.pkt.SetMaxAllowedPacket(cc.ctx.GetSessionVars().MaxAllowedPacket) + } + return cc.pkt.ReadPacket() +} + +func (cc *clientConn) writePacket(data []byte) error { + failpoint.Inject("FakeClientConn", func() { + if cc.pkt == nil { + failpoint.Return(nil) + } + }) + return cc.pkt.WritePacket(data) +} + +func (cc *clientConn) getWaitTimeout(ctx context.Context) uint64 { + sessVars := cc.ctx.GetSessionVars() + if sessVars.InTxn() && sessVars.IdleTransactionTimeout > 0 { + return uint64(sessVars.IdleTransactionTimeout) + } + return cc.getSessionVarsWaitTimeout(ctx) +} + +// getSessionVarsWaitTimeout get session variable wait_timeout +func (cc *clientConn) getSessionVarsWaitTimeout(ctx context.Context) uint64 { + valStr, exists := cc.ctx.GetSessionVars().GetSystemVar(variable.WaitTimeout) + if !exists { + return variable.DefWaitTimeout + } + waitTimeout, err := strconv.ParseUint(valStr, 10, 64) + if err != nil { + logutil.Logger(ctx).Warn("get sysval wait_timeout failed, use default value", zap.Error(err)) + // if get waitTimeout error, use default value + return variable.DefWaitTimeout + } + return waitTimeout +} + +func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Context) error { + // Read a packet. It may be a SSLRequest or HandshakeResponse. + data, err := cc.readPacket() + if err != nil { + err = errors.SuspendStack(err) + if errors.Cause(err) == io.EOF { + logutil.Logger(ctx).Debug("wait handshake response fail due to connection has be closed by client-side") + } else { + logutil.Logger(ctx).Debug("wait handshake response fail", zap.Error(err)) + } + return err + } + + var resp handshake.Response41 + var pos int + + if len(data) < 2 { + logutil.Logger(ctx).Error("got malformed handshake response", zap.ByteString("packetData", data)) + return mysql.ErrMalformPacket + } + + capability := uint32(binary.LittleEndian.Uint16(data[:2])) + if capability&mysql.ClientProtocol41 <= 0 { + logutil.Logger(ctx).Error("ClientProtocol41 flag is not set, please upgrade client") + return servererr.ErrNotSupportedAuthMode + } + pos, err = parse.HandshakeResponseHeader(ctx, &resp, data) + if err != nil { + terror.Log(err) + return err + } + + // After read packets we should update the client's host and port to grab + // real client's IP and port from PROXY Protocol header if PROXY Protocol is enabled. + _, _, err = cc.PeerHost("", true) + if err != nil { + terror.Log(err) + return err + } + // If enable proxy protocol check audit plugins after update real IP + if cc.ppEnabled { + err = cc.server.checkAuditPlugin(cc) + if err != nil { + return err + } + } + + if resp.Capability&mysql.ClientSSL > 0 { + tlsConfig := (*tls.Config)(atomic.LoadPointer(&cc.server.tlsConfig)) + if tlsConfig != nil { + // The packet is a SSLRequest, let's switch to TLS. + if err = cc.upgradeToTLS(tlsConfig); err != nil { + return err + } + // Read the following HandshakeResponse packet. 
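+			// The client repeats its handshake response over the now-encrypted connection.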
+ data, err = cc.readPacket() + if err != nil { + logutil.Logger(ctx).Warn("read handshake response failure after upgrade to TLS", zap.Error(err)) + return err + } + pos, err = parse.HandshakeResponseHeader(ctx, &resp, data) + if err != nil { + terror.Log(err) + return err + } + } + } else if tlsutil.RequireSecureTransport.Load() && !cc.isUnixSocket { + // If it's not a socket connection, we should reject the connection + // because TLS is required. + err := servererr.ErrSecureTransportRequired.FastGenByArgs() + terror.Log(err) + return err + } + + // Read the remaining part of the packet. + err = parse.HandshakeResponseBody(ctx, &resp, data, pos) + if err != nil { + terror.Log(err) + return err + } + + cc.capability = resp.Capability & cc.server.capability + cc.user = resp.User + cc.dbname = resp.DBName + cc.collation = resp.Collation + cc.attrs = resp.Attrs + cc.pkt.SetZstdLevel(zstd.EncoderLevelFromZstd(resp.ZstdLevel)) + + err = cc.handleAuthPlugin(ctx, &resp) + if err != nil { + return err + } + + switch resp.AuthPlugin { + case mysql.AuthCachingSha2Password: + resp.Auth, err = cc.authSha(ctx, resp) + if err != nil { + return err + } + case mysql.AuthTiDBSM3Password: + resp.Auth, err = cc.authSM3(ctx, resp) + if err != nil { + return err + } + case mysql.AuthNativePassword: + case mysql.AuthSocket: + case mysql.AuthTiDBSessionToken: + case mysql.AuthTiDBAuthToken: + case mysql.AuthMySQLClearPassword: + case mysql.AuthLDAPSASL: + case mysql.AuthLDAPSimple: + default: + if _, ok := cc.extensions.GetAuthPlugin(resp.AuthPlugin); !ok { + return errors.New("Unknown auth plugin") + } + } + + err = cc.openSessionAndDoAuth(resp.Auth, resp.AuthPlugin, resp.ZstdLevel) + if err != nil { + logutil.Logger(ctx).Warn("open new session or authentication failure", zap.Error(err)) + } + return err +} + +func (cc *clientConn) handleAuthPlugin(ctx context.Context, resp *handshake.Response41) error { + if resp.Capability&mysql.ClientPluginAuth > 0 { + newAuth, err := cc.checkAuthPlugin(ctx, resp) + if err != nil { + logutil.Logger(ctx).Warn("failed to check the user authplugin", zap.Error(err)) + return err + } + if len(newAuth) > 0 { + resp.Auth = newAuth + } + + if _, ok := cc.extensions.GetAuthPlugin(resp.AuthPlugin); ok { + // The auth plugin has been registered, skip other checks. + return nil + } + switch resp.AuthPlugin { + case mysql.AuthCachingSha2Password: + case mysql.AuthTiDBSM3Password: + case mysql.AuthNativePassword: + case mysql.AuthSocket: + case mysql.AuthTiDBSessionToken: + case mysql.AuthMySQLClearPassword: + case mysql.AuthLDAPSASL: + case mysql.AuthLDAPSimple: + default: + logutil.Logger(ctx).Warn("Unknown Auth Plugin", zap.String("plugin", resp.AuthPlugin)) + } + } else { + // MySQL 5.1 and older clients don't support authentication plugins. + logutil.Logger(ctx).Warn("Client without Auth Plugin support; Please upgrade client") + _, err := cc.checkAuthPlugin(ctx, resp) + if err != nil { + return err + } + resp.AuthPlugin = mysql.AuthNativePassword + } + return nil +} + +// authSha implements the caching_sha2_password specific part of the protocol. +func (cc *clientConn) authSha(ctx context.Context, resp handshake.Response41) ([]byte, error) { + const ( + shaCommand = 1 + requestRsaPubKey = 2 // Not supported yet, only TLS is supported as secure channel. + fastAuthOk = 3 + fastAuthFail = 4 + ) + + // If no password is specified, we don't send the FastAuthFail to do the full authentication + // as that doesn't make sense without a password and confuses the client. 
+ // https://github.com/pingcap/tidb/issues/40831 + if len(resp.Auth) == 0 { + return []byte{}, nil + } + + // Currently we always send a "FastAuthFail" as the cached part of the protocol isn't implemented yet. + // This triggers the client to send the full response. + err := cc.writePacket([]byte{0, 0, 0, 0, shaCommand, fastAuthFail}) + if err != nil { + logutil.Logger(ctx).Error("authSha packet write failed", zap.Error(err)) + return nil, err + } + err = cc.flush(ctx) + if err != nil { + logutil.Logger(ctx).Error("authSha packet flush failed", zap.Error(err)) + return nil, err + } + + data, err := cc.readPacket() + if err != nil { + logutil.Logger(ctx).Error("authSha packet read failed", zap.Error(err)) + return nil, err + } + return bytes.Trim(data, "\x00"), nil +} + +// authSM3 implements the tidb_sm3_password specific part of the protocol. +// tidb_sm3_password is very similar to caching_sha2_password. +func (cc *clientConn) authSM3(ctx context.Context, resp handshake.Response41) ([]byte, error) { + // If no password is specified, we don't send the FastAuthFail to do the full authentication + // as that doesn't make sense without a password and confuses the client. + // https://github.com/pingcap/tidb/issues/40831 + if len(resp.Auth) == 0 { + return []byte{}, nil + } + + err := cc.writePacket([]byte{0, 0, 0, 0, 1, 4}) // fastAuthFail + if err != nil { + logutil.Logger(ctx).Error("authSM3 packet write failed", zap.Error(err)) + return nil, err + } + err = cc.flush(ctx) + if err != nil { + logutil.Logger(ctx).Error("authSM3 packet flush failed", zap.Error(err)) + return nil, err + } + + data, err := cc.readPacket() + if err != nil { + logutil.Logger(ctx).Error("authSM3 packet read failed", zap.Error(err)) + return nil, err + } + return bytes.Trim(data, "\x00"), nil +} + +func (cc *clientConn) SessionStatusToString() string { + status := cc.ctx.Status() + inTxn, autoCommit := 0, 0 + if status&mysql.ServerStatusInTrans > 0 { + inTxn = 1 + } + if status&mysql.ServerStatusAutocommit > 0 { + autoCommit = 1 + } + return fmt.Sprintf("inTxn:%d, autocommit:%d", + inTxn, autoCommit, + ) +} + +func (cc *clientConn) openSession() error { + var tlsStatePtr *tls.ConnectionState + if cc.tlsConn != nil { + tlsState := cc.tlsConn.ConnectionState() + tlsStatePtr = &tlsState + } + ctx, err := cc.server.driver.OpenCtx(cc.connectionID, cc.capability, cc.collation, cc.dbname, tlsStatePtr, cc.extensions) + if err != nil { + return err + } + cc.SetCtx(ctx) + + err = cc.server.checkConnectionCount() + if err != nil { + return err + } + return nil +} + +func (cc *clientConn) openSessionAndDoAuth(authData []byte, authPlugin string, zstdLevel int) error { + // Open a context unless this was done before. 
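+	// (checkAuthPlugin has usually opened one already while resolving the user's authentication plugin.)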
+ if ctx := cc.getCtx(); ctx == nil { + err := cc.openSession() + if err != nil { + return err + } + } + + hasPassword := "YES" + if len(authData) == 0 { + hasPassword = "NO" + } + + host, port, err := cc.PeerHost(hasPassword, false) + if err != nil { + return err + } + + if !cc.isUnixSocket && authPlugin == mysql.AuthSocket { + return servererr.ErrAccessDeniedNoPassword.FastGenByArgs(cc.user, host) + } + + userIdentity := &auth.UserIdentity{Username: cc.user, Hostname: host, AuthPlugin: authPlugin} + if err = cc.ctx.Auth(userIdentity, authData, cc.salt, cc); err != nil { + return err + } + cc.ctx.SetPort(port) + cc.ctx.SetCompressionLevel(zstdLevel) + if cc.dbname != "" { + _, err = cc.useDB(context.Background(), cc.dbname) + if err != nil { + return err + } + } + cc.ctx.SetSessionManager(cc.server) + return nil +} + +// mockOSUserForAuthSocketTest should only be used in test +var mockOSUserForAuthSocketTest atomic.Pointer[string] + +// Check if the Authentication Plugin of the server, client and user configuration matches +func (cc *clientConn) checkAuthPlugin(ctx context.Context, resp *handshake.Response41) ([]byte, error) { + // Open a context unless this was done before. + if ctx := cc.getCtx(); ctx == nil { + err := cc.openSession() + if err != nil { + return nil, err + } + } + + authData := resp.Auth + // tidb_session_token is always permitted and skips stored user plugin. + if resp.AuthPlugin == mysql.AuthTiDBSessionToken { + return authData, nil + } + hasPassword := "YES" + if len(authData) == 0 { + hasPassword = "NO" + } + + host, _, err := cc.PeerHost(hasPassword, false) + if err != nil { + return nil, err + } + // Find the identity of the user based on username and peer host. + identity, err := cc.ctx.MatchIdentity(cc.user, host) + if err != nil { + return nil, servererr.ErrAccessDenied.FastGenByArgs(cc.user, host, hasPassword) + } + // Get the plugin for the identity. + userplugin, err := cc.ctx.AuthPluginForUser(identity) + if err != nil { + logutil.Logger(ctx).Warn("Failed to get authentication method for user", + zap.String("user", cc.user), zap.String("host", host)) + } + failpoint.Inject("FakeUser", func(val failpoint.Value) { + //nolint:forcetypeassert + userplugin = val.(string) + }) + if userplugin == mysql.AuthSocket { + if !cc.isUnixSocket { + return nil, servererr.ErrAccessDenied.FastGenByArgs(cc.user, host, hasPassword) + } + resp.AuthPlugin = mysql.AuthSocket + user, err := user.LookupId(fmt.Sprint(cc.socketCredUID)) + if err != nil { + return nil, err + } + uname := user.Username + + if intest.InTest { + if p := mockOSUserForAuthSocketTest.Load(); p != nil { + uname = *p + } + } + + return []byte(uname), nil + } + if len(userplugin) == 0 { + // No user plugin set, assuming MySQL Native Password + // This happens if the account doesn't exist or if the account doesn't have + // a password set. 
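+		// Switch such clients to mysql_native_password when they support plugin authentication.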
+ if resp.AuthPlugin != mysql.AuthNativePassword { + if resp.Capability&mysql.ClientPluginAuth > 0 { + resp.AuthPlugin = mysql.AuthNativePassword + authData, err := cc.authSwitchRequest(ctx, mysql.AuthNativePassword) + if err != nil { + return nil, err + } + return authData, nil + } + } + return nil, nil + } + + // If the authentication method send by the server (cc.authPlugin) doesn't match + // the plugin configured for the user account in the mysql.user.plugin column + // or if the authentication method send by the server doesn't match the authentication + // method send by the client (*authPlugin) then we need to switch the authentication + // method to match the one configured for that specific user. + if (cc.authPlugin != userplugin) || (cc.authPlugin != resp.AuthPlugin) { + if userplugin == mysql.AuthTiDBAuthToken { + userplugin = mysql.AuthMySQLClearPassword + } + if resp.Capability&mysql.ClientPluginAuth > 0 { + authData, err := cc.authSwitchRequest(ctx, userplugin) + if err != nil { + return nil, err + } + resp.AuthPlugin = userplugin + return authData, nil + } else if userplugin != mysql.AuthNativePassword { + // MySQL 5.1 and older don't support authentication plugins yet + return nil, servererr.ErrNotSupportedAuthMode + } + } + + return nil, nil +} + +func (cc *clientConn) PeerHost(hasPassword string, update bool) (host, port string, err error) { + // already get peer host + if len(cc.peerHost) > 0 { + // Proxy protocol enabled and not update + if cc.ppEnabled && !update { + return cc.peerHost, cc.peerPort, nil + } + // Proxy protocol not enabled + if !cc.ppEnabled { + return cc.peerHost, cc.peerPort, nil + } + } + host = variable.DefHostname + if cc.isUnixSocket { + cc.peerHost = host + cc.serverHost = host + return + } + addr := cc.bufReadConn.RemoteAddr().String() + host, port, err = net.SplitHostPort(addr) + if err != nil { + err = servererr.ErrAccessDenied.GenWithStackByArgs(cc.user, addr, hasPassword) + return + } + cc.peerHost = host + cc.peerPort = port + + serverAddr := cc.bufReadConn.LocalAddr().String() + serverHost, _, err := net.SplitHostPort(serverAddr) + if err != nil { + err = servererr.ErrAccessDenied.GenWithStackByArgs(cc.user, addr, hasPassword) + return + } + cc.serverHost = serverHost + + return +} + +// skipInitConnect follows MySQL's rules of when init-connect should be skipped. +// In 5.7 it is any user with SUPER privilege, but in 8.0 it is: +// - SUPER or the CONNECTION_ADMIN dynamic privilege. +// - (additional exception) users with expired passwords (not yet supported) +// In TiDB CONNECTION_ADMIN is satisfied by SUPER, so we only need to check once. +func (cc *clientConn) skipInitConnect() bool { + checker := privilege.GetPrivilegeManager(cc.ctx.Session) + activeRoles := cc.ctx.GetSessionVars().ActiveRoles + return checker != nil && checker.RequestDynamicVerification(activeRoles, "CONNECTION_ADMIN", false) +} + +// initResultEncoder initialize the result encoder for current connection. 
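+// The encoder converts result sets to the charset named by the character_set_results system variable.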
+func (cc *clientConn) initResultEncoder(ctx context.Context) { + chs, err := cc.ctx.GetSessionVars().GetSessionOrGlobalSystemVar(context.Background(), variable.CharacterSetResults) + if err != nil { + chs = "" + logutil.Logger(ctx).Warn("get character_set_results system variable failed", zap.Error(err)) + } + cc.rsEncoder = column.NewResultEncoder(chs) +} + +func (cc *clientConn) initInputEncoder(ctx context.Context) { + chs, err := cc.ctx.GetSessionVars().GetSessionOrGlobalSystemVar(context.Background(), variable.CharacterSetClient) + if err != nil { + chs = "" + logutil.Logger(ctx).Warn("get character_set_client system variable failed", zap.Error(err)) + } + cc.inputDecoder = util2.NewInputDecoder(chs) +} + +// initConnect runs the initConnect SQL statement if it has been specified. +// The semantics are MySQL compatible. +func (cc *clientConn) initConnect(ctx context.Context) error { + val, err := cc.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.InitConnect) + if err != nil { + return err + } + if val == "" || cc.skipInitConnect() { + return nil + } + logutil.Logger(ctx).Debug("init_connect starting") + stmts, err := cc.ctx.Parse(ctx, val) + if err != nil { + return err + } + for _, stmt := range stmts { + rs, err := cc.ctx.ExecuteStmt(ctx, stmt) + if err != nil { + return err + } + // init_connect does not care about the results, + // but they need to be drained because of lazy loading. + if rs != nil { + req := rs.NewChunk(nil) + for { + if err = rs.Next(ctx, req); err != nil { + return err + } + if req.NumRows() == 0 { + break + } + } + rs.Close() + } + } + logutil.Logger(ctx).Debug("init_connect complete") + return nil +} + +// Run reads client query and writes query result to client in for loop, if there is a panic during query handling, +// it will be recovered and log the panic error. +// This function returns and the connection is closed if there is an IO error or there is a panic. +func (cc *clientConn) Run(ctx context.Context) { + defer func() { + r := recover() + if r != nil { + logutil.Logger(ctx).Error("connection running loop panic", + zap.Stringer("lastSQL", getLastStmtInConn{cc}), + zap.String("err", fmt.Sprintf("%v", r)), + zap.Stack("stack"), + ) + err := cc.writeError(ctx, fmt.Errorf("%v", r)) + terror.Log(err) + metrics.PanicCounter.WithLabelValues(metrics.LabelSession).Inc() + } + if cc.getStatus() != connStatusShutdown { + err := cc.Close() + terror.Log(err) + } + + close(cc.quit) + }() + + parentCtx := ctx + var traceInfo *model.TraceInfo + // Usually, client connection status changes between [dispatching] <=> [reading]. + // When some event happens, server may notify this client connection by setting + // the status to special values, for example: kill or graceful shutdown. + // The client connection would detect the events when it fails to change status + // by CAS operation, it would then take some actions accordingly. + for { + sessVars := cc.ctx.GetSessionVars() + if alias := sessVars.SessionAlias; traceInfo == nil || traceInfo.SessionAlias != alias { + // We should reset the context trace info when traceInfo not inited or session alias changed. + traceInfo = &model.TraceInfo{ + ConnectionID: cc.connectionID, + SessionAlias: alias, + } + ctx = logutil.WithSessionAlias(parentCtx, sessVars.SessionAlias) + ctx = tracing.ContextWithTraceInfo(ctx, traceInfo) + } + + // Close connection between txn when we are going to shutdown server. 
+ // Note the current implementation when shutting down, for an idle connection, the connection may block at readPacket() + // consider provider a way to close the connection directly after sometime if we can not read any data. + if cc.server.inShutdownMode.Load() { + if !sessVars.InTxn() { + return + } + } + + if !cc.CompareAndSwapStatus(connStatusDispatching, connStatusReading) || + // The judge below will not be hit by all means, + // But keep it stayed as a reminder and for the code reference for connStatusWaitShutdown. + cc.getStatus() == connStatusWaitShutdown { + return + } + + cc.alloc.Reset() + // close connection when idle time is more than wait_timeout + // default 28800(8h), FIXME: should not block at here when we kill the connection. + waitTimeout := cc.getWaitTimeout(ctx) + cc.pkt.SetReadTimeout(time.Duration(waitTimeout) * time.Second) + start := time.Now() + data, err := cc.readPacket() + if err != nil { + if terror.ErrorNotEqual(err, io.EOF) { + if netErr, isNetErr := errors.Cause(err).(net.Error); isNetErr && netErr.Timeout() { + if cc.getStatus() == connStatusWaitShutdown { + logutil.Logger(ctx).Info("read packet timeout because of killed connection") + } else { + idleTime := time.Since(start) + logutil.Logger(ctx).Info("read packet timeout, close this connection", + zap.Duration("idle", idleTime), + zap.Uint64("waitTimeout", waitTimeout), + zap.Error(err), + ) + } + } else if errors.ErrorEqual(err, servererr.ErrNetPacketTooLarge) { + err := cc.writeError(ctx, err) + if err != nil { + terror.Log(err) + } + } else { + errStack := errors.ErrorStack(err) + if !strings.Contains(errStack, "use of closed network connection") { + logutil.Logger(ctx).Warn("read packet failed, close this connection", + zap.Error(errors.SuspendStack(err))) + } + } + } + server_metrics.DisconnectByClientWithError.Inc() + return + } + + // Should check InTxn() to avoid execute `begin` stmt. 
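+		// Only connections without an active transaction are terminated here during shutdown.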
+ if cc.server.inShutdownMode.Load() { + if !cc.ctx.GetSessionVars().InTxn() { + return + } + } + + if !cc.CompareAndSwapStatus(connStatusReading, connStatusDispatching) { + return + } + + startTime := time.Now() + err = cc.dispatch(ctx, data) + cc.ctx.GetSessionVars().ClearAlloc(&cc.chunkAlloc, err != nil) + cc.chunkAlloc.Reset() + if err != nil { + cc.audit(plugin.Error) // tell the plugin API there was a dispatch error + if terror.ErrorEqual(err, io.EOF) { + cc.addMetrics(data[0], startTime, nil) + server_metrics.DisconnectNormal.Inc() + return + } else if terror.ErrResultUndetermined.Equal(err) { + logutil.Logger(ctx).Error("result undetermined, close this connection", zap.Error(err)) + server_metrics.DisconnectErrorUndetermined.Inc() + return + } else if terror.ErrCritical.Equal(err) { + metrics.CriticalErrorCounter.Add(1) + logutil.Logger(ctx).Fatal("critical error, stop the server", zap.Error(err)) + } + var txnMode string + if ctx := cc.getCtx(); ctx != nil { + txnMode = ctx.GetSessionVars().GetReadableTxnMode() + } + vars := cc.getCtx().GetSessionVars() + for _, dbName := range session.GetDBNames(vars) { + metrics.ExecuteErrorCounter.WithLabelValues(metrics.ExecuteErrorToLabel(err), dbName, vars.ResourceGroupName).Inc() + } + + if storeerr.ErrLockAcquireFailAndNoWaitSet.Equal(err) { + logutil.Logger(ctx).Debug("Expected error for FOR UPDATE NOWAIT", zap.Error(err)) + } else { + var timestamp uint64 + if ctx := cc.getCtx(); ctx != nil && ctx.GetSessionVars() != nil && ctx.GetSessionVars().TxnCtx != nil { + timestamp = ctx.GetSessionVars().TxnCtx.StartTS + if timestamp == 0 && ctx.GetSessionVars().TxnCtx.StaleReadTs > 0 { + // for state-read query. + timestamp = ctx.GetSessionVars().TxnCtx.StaleReadTs + } + } + logutil.Logger(ctx).Info("command dispatched failed", + zap.String("connInfo", cc.String()), + zap.String("command", mysql.Command2Str[data[0]]), + zap.String("status", cc.SessionStatusToString()), + zap.Stringer("sql", getLastStmtInConn{cc}), + zap.String("txn_mode", txnMode), + zap.Uint64("timestamp", timestamp), + zap.String("err", errStrForLog(err, cc.ctx.GetSessionVars().EnableRedactLog)), + ) + } + err1 := cc.writeError(ctx, err) + terror.Log(err1) + } + cc.addMetrics(data[0], startTime, err) + cc.pkt.SetSequence(0) + cc.pkt.SetCompressedSequence(0) + } +} + +func errStrForLog(err error, redactMode string) string { + if redactMode != errors.RedactLogDisable { + // currently, only ErrParse is considered when enableRedactLog because it may contain sensitive information like + // password or accesskey + if parser.ErrParse.Equal(err) { + return "fail to parse SQL, and must redact the whole error when enable log redaction" + } + } + var ret string + if kv.ErrKeyExists.Equal(err) || parser.ErrParse.Equal(err) || infoschema.ErrTableNotExists.Equal(err) { + // Do not log stack for duplicated entry error. + ret = err.Error() + } else { + ret = errors.ErrorStack(err) + } + return ret +} + +func (cc *clientConn) addMetrics(cmd byte, startTime time.Time, err error) { + if cmd == mysql.ComQuery && cc.ctx.Value(sessionctx.LastExecuteDDL) != nil { + // Don't take DDL execute time into account. + // It's already recorded by other metrics in ddl package. 
+ return + } + + vars := cc.getCtx().GetSessionVars() + resourceGroupName := vars.ResourceGroupName + var counter prometheus.Counter + if len(resourceGroupName) == 0 || resourceGroupName == resourcegroup.DefaultResourceGroupName { + if err != nil && int(cmd) < len(server_metrics.QueryTotalCountErr) { + counter = server_metrics.QueryTotalCountErr[cmd] + } else if err == nil && int(cmd) < len(server_metrics.QueryTotalCountOk) { + counter = server_metrics.QueryTotalCountOk[cmd] + } + } + + if counter != nil { + counter.Inc() + } else { + label := server_metrics.CmdToString(cmd) + if err != nil { + metrics.QueryTotalCounter.WithLabelValues(label, "Error", resourceGroupName).Inc() + } else { + metrics.QueryTotalCounter.WithLabelValues(label, "OK", resourceGroupName).Inc() + } + } + + cost := time.Since(startTime) + sessionVar := cc.ctx.GetSessionVars() + affectedRows := cc.ctx.AffectedRows() + cc.ctx.GetTxnWriteThroughputSLI().FinishExecuteStmt(cost, affectedRows, sessionVar.InTxn()) + + stmtType := sessionVar.StmtCtx.StmtType + sqlType := metrics.LblGeneral + if stmtType != "" { + sqlType = stmtType + } + + switch sqlType { + case "Insert": + server_metrics.AffectedRowsCounterInsert.Add(float64(affectedRows)) + case "Replace": + server_metrics.AffectedRowsCounterReplace.Add(float64(affectedRows)) + case "Delete": + server_metrics.AffectedRowsCounterDelete.Add(float64(affectedRows)) + case "Update": + server_metrics.AffectedRowsCounterUpdate.Add(float64(affectedRows)) + } + + for _, dbName := range session.GetDBNames(vars) { + metrics.QueryDurationHistogram.WithLabelValues(sqlType, dbName, vars.StmtCtx.ResourceGroupName).Observe(cost.Seconds()) + metrics.QueryRPCHistogram.WithLabelValues(sqlType, dbName).Observe(float64(vars.StmtCtx.GetExecDetails().RequestCount)) + if vars.StmtCtx.GetExecDetails().ScanDetail != nil { + metrics.QueryProcessedKeyHistogram.WithLabelValues(sqlType, dbName).Observe(float64(vars.StmtCtx.GetExecDetails().ScanDetail.ProcessedKeys)) + } + } +} + +// dispatch handles client request based on command which is the first byte of the data. +// It also gets a token from server which is used to limit the concurrently handling clients. +// The most frequently used command is ComQuery. 
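+// The token is released and cc.lastActive refreshed in a deferred function once the command has been handled.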
+func (cc *clientConn) dispatch(ctx context.Context, data []byte) error { + defer func() { + // reset killed for each request + cc.ctx.GetSessionVars().SQLKiller.Reset() + }() + t := time.Now() + if (cc.ctx.Status() & mysql.ServerStatusInTrans) > 0 { + server_metrics.ConnIdleDurationHistogramInTxn.Observe(t.Sub(cc.lastActive).Seconds()) + } else { + server_metrics.ConnIdleDurationHistogramNotInTxn.Observe(t.Sub(cc.lastActive).Seconds()) + } + + cfg := config.GetGlobalConfig() + if cfg.OpenTracing.Enable { + var r tracing.Region + r, ctx = tracing.StartRegionWithNewRootSpan(ctx, "server.dispatch") + defer r.End() + } + + var cancelFunc context.CancelFunc + ctx, cancelFunc = context.WithCancel(ctx) + cc.mu.Lock() + cc.mu.cancelFunc = cancelFunc + cc.mu.Unlock() + + cc.lastPacket = data + cmd := data[0] + data = data[1:] + if topsqlstate.TopSQLEnabled() { + defer pprof.SetGoroutineLabels(ctx) + } + if variable.EnablePProfSQLCPU.Load() { + label := getLastStmtInConn{cc}.PProfLabel() + if len(label) > 0 { + defer pprof.SetGoroutineLabels(ctx) + ctx = pprof.WithLabels(ctx, pprof.Labels("sql", label)) + pprof.SetGoroutineLabels(ctx) + } + } + if trace.IsEnabled() { + lc := getLastStmtInConn{cc} + sqlType := lc.PProfLabel() + if len(sqlType) > 0 { + var task *trace.Task + ctx, task = trace.NewTask(ctx, sqlType) + defer task.End() + + trace.Log(ctx, "sql", lc.String()) + ctx = logutil.WithTraceLogger(ctx, tracing.TraceInfoFromContext(ctx)) + + taskID := *(*uint64)(unsafe.Pointer(task)) + ctx = pprof.WithLabels(ctx, pprof.Labels("trace", strconv.FormatUint(taskID, 10))) + pprof.SetGoroutineLabels(ctx) + } + } + token := cc.server.getToken() + defer func() { + // if handleChangeUser failed, cc.ctx may be nil + if ctx := cc.getCtx(); ctx != nil { + ctx.SetProcessInfo("", t, mysql.ComSleep, 0) + } + + cc.server.releaseToken(token) + cc.lastActive = time.Now() + }() + + vars := cc.ctx.GetSessionVars() + // reset killed for each request + vars.SQLKiller.Reset() + if cmd < mysql.ComEnd { + cc.ctx.SetCommandValue(cmd) + } + + dataStr := string(hack.String(data)) + switch cmd { + case mysql.ComPing, mysql.ComStmtClose, mysql.ComStmtSendLongData, mysql.ComStmtReset, + mysql.ComSetOption, mysql.ComChangeUser: + cc.ctx.SetProcessInfo("", t, cmd, 0) + case mysql.ComInitDB: + cc.ctx.SetProcessInfo("use "+dataStr, t, cmd, 0) + } + + switch cmd { + case mysql.ComQuit: + return io.EOF + case mysql.ComInitDB: + node, err := cc.useDB(ctx, dataStr) + cc.onExtensionStmtEnd(node, false, err) + if err != nil { + return err + } + return cc.writeOK(ctx) + case mysql.ComQuery: // Most frequently used command. + // For issue 1989 + // Input payload may end with byte '\0', we didn't find related mysql document about it, but mysql + // implementation accept that case. So trim the last '\0' here as if the payload an EOF string. 
+ // See http://dev.mysql.com/doc/internals/en/com-query.html + if len(data) > 0 && data[len(data)-1] == 0 { + data = data[:len(data)-1] + dataStr = string(hack.String(data)) + } + return cc.handleQuery(ctx, dataStr) + case mysql.ComFieldList: + return cc.handleFieldList(ctx, dataStr) + // ComCreateDB, ComDropDB + case mysql.ComRefresh: + return cc.handleRefresh(ctx, data[0]) + case mysql.ComShutdown: // redirect to SQL + if err := cc.handleQuery(ctx, "SHUTDOWN"); err != nil { + return err + } + return cc.writeOK(ctx) + case mysql.ComStatistics: + return cc.writeStats(ctx) + // ComProcessInfo, ComConnect, ComProcessKill, ComDebug + case mysql.ComPing: + return cc.writeOK(ctx) + case mysql.ComChangeUser: + return cc.handleChangeUser(ctx, data) + // ComBinlogDump, ComTableDump, ComConnectOut, ComRegisterSlave + case mysql.ComStmtPrepare: + // For issue 39132, same as ComQuery + if len(data) > 0 && data[len(data)-1] == 0 { + data = data[:len(data)-1] + dataStr = string(hack.String(data)) + } + return cc.HandleStmtPrepare(ctx, dataStr) + case mysql.ComStmtExecute: + return cc.handleStmtExecute(ctx, data) + case mysql.ComStmtSendLongData: + return cc.handleStmtSendLongData(data) + case mysql.ComStmtClose: + return cc.handleStmtClose(data) + case mysql.ComStmtReset: + return cc.handleStmtReset(ctx, data) + case mysql.ComSetOption: + return cc.handleSetOption(ctx, data) + case mysql.ComStmtFetch: + return cc.handleStmtFetch(ctx, data) + // ComDaemon, ComBinlogDumpGtid + case mysql.ComResetConnection: + return cc.handleResetConnection(ctx) + // ComEnd + default: + return mysql.NewErrf(mysql.ErrUnknown, "command %d not supported now", nil, cmd) + } +} + +func (cc *clientConn) writeStats(ctx context.Context) error { + var err error + var uptime int64 + info := tikvhandler.ServerInfo{} + info.ServerInfo, err = infosync.GetServerInfo() + if err != nil { + logutil.BgLogger().Error("Failed to get ServerInfo for uptime status", zap.Error(err)) + } else { + uptime = int64(time.Since(time.Unix(info.ServerInfo.StartTimestamp, 0)).Seconds()) + } + msg := []byte(fmt.Sprintf("Uptime: %d Threads: 0 Questions: 0 Slow queries: 0 Opens: 0 Flush tables: 0 Open tables: 0 Queries per second avg: 0.000", + uptime)) + data := cc.alloc.AllocWithLen(4, len(msg)) + data = append(data, msg...) + + err = cc.writePacket(data) + if err != nil { + return err + } + + return cc.flush(ctx) +} + +func (cc *clientConn) useDB(ctx context.Context, db string) (node ast.StmtNode, err error) { + // if input is "use `SELECT`", mysql client just send "SELECT" + // so we add `` around db. 
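For readers less familiar with the wire format handled by dispatch above: a command packet's payload is one command byte followed by command-specific data, and some clients append a trailing 0x00 to COM_QUERY and COM_STMT_PREPARE payloads, which is why both cases strip it. A small standalone sketch (helper and constant names are illustrative, not part of the patch):

package main

import "fmt"

// Command byte values from the MySQL protocol (illustrative subset).
const (
	comQuery       = 0x03
	comStmtPrepare = 0x16
)

// splitCommand mirrors what dispatch does with a raw payload: byte 0 selects
// the handler, the rest is the argument; a trailing 0x00 on COM_QUERY and
// COM_STMT_PREPARE payloads is tolerated and stripped.
func splitCommand(data []byte) (cmd byte, arg []byte) {
	cmd, arg = data[0], data[1:]
	if (cmd == comQuery || cmd == comStmtPrepare) && len(arg) > 0 && arg[len(arg)-1] == 0 {
		arg = arg[:len(arg)-1]
	}
	return cmd, arg
}

func main() {
	cmd, arg := splitCommand([]byte{comQuery, 'S', 'E', 'L', 'E', 'C', 'T', ' ', '1', 0})
	fmt.Printf("cmd=%#x sql=%q\n", cmd, arg) // cmd=0x3 sql="SELECT 1"
}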
+ stmts, err := cc.ctx.Parse(ctx, "use `"+db+"`") + if err != nil { + return nil, err + } + _, err = cc.ctx.ExecuteStmt(ctx, stmts[0]) + if err != nil { + return stmts[0], err + } + cc.dbname = db + return stmts[0], err +} + +func (cc *clientConn) flush(ctx context.Context) error { + var ( + stmtDetail *execdetails.StmtExecDetails + startTime time.Time + ) + if stmtDetailRaw := ctx.Value(execdetails.StmtExecDetailKey); stmtDetailRaw != nil { + //nolint:forcetypeassert + stmtDetail = stmtDetailRaw.(*execdetails.StmtExecDetails) + startTime = time.Now() + } + defer func() { + if stmtDetail != nil { + stmtDetail.WriteSQLRespDuration += time.Since(startTime) + } + trace.StartRegion(ctx, "FlushClientConn").End() + if ctx := cc.getCtx(); ctx != nil && ctx.WarningCount() > 0 { + for _, err := range ctx.GetWarnings() { + var warn *errors.Error + if ok := goerr.As(err.Err, &warn); ok { + code := uint16(warn.Code()) + errno.IncrementWarning(code, cc.user, cc.peerHost) + } + } + } + }() + failpoint.Inject("FakeClientConn", func() { + if cc.pkt == nil { + failpoint.Return(nil) + } + }) + return cc.pkt.Flush() +} + +func (cc *clientConn) writeOK(ctx context.Context) error { + return cc.writeOkWith(ctx, mysql.OKHeader, true, cc.ctx.Status()) +} + +func (cc *clientConn) writeOkWith(ctx context.Context, header byte, flush bool, status uint16) error { + msg := cc.ctx.LastMessage() + affectedRows := cc.ctx.AffectedRows() + lastInsertID := cc.ctx.LastInsertID() + warnCnt := cc.ctx.WarningCount() + + enclen := 0 + if len(msg) > 0 { + enclen = util2.LengthEncodedIntSize(uint64(len(msg))) + len(msg) + } + + data := cc.alloc.AllocWithLen(4, 32+enclen) + data = append(data, header) + data = dump.LengthEncodedInt(data, affectedRows) + data = dump.LengthEncodedInt(data, lastInsertID) + if cc.capability&mysql.ClientProtocol41 > 0 { + data = dump.Uint16(data, status) + data = dump.Uint16(data, warnCnt) + } + if enclen > 0 { + // although MySQL manual says the info message is string(https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html), + // it is actually string + data = dump.LengthEncodedString(data, []byte(msg)) + } + + err := cc.writePacket(data) + if err != nil { + return err + } + + if flush { + return cc.flush(ctx) + } + + return nil +} + +func (cc *clientConn) writeError(ctx context.Context, e error) error { + var ( + m *mysql.SQLError + te *terror.Error + ok bool + ) + originErr := errors.Cause(e) + if te, ok = originErr.(*terror.Error); ok { + m = terror.ToSQLError(te) + } else { + e := errors.Cause(originErr) + switch y := e.(type) { + case *terror.Error: + m = terror.ToSQLError(y) + default: + m = mysql.NewErrf(mysql.ErrUnknown, "%s", nil, e.Error()) + } + } + + cc.lastCode = m.Code + defer errno.IncrementError(m.Code, cc.user, cc.peerHost) + data := cc.alloc.AllocWithLen(4, 16+len(m.Message)) + data = append(data, mysql.ErrHeader) + data = append(data, byte(m.Code), byte(m.Code>>8)) + if cc.capability&mysql.ClientProtocol41 > 0 { + data = append(data, '#') + data = append(data, m.State...) + } + + data = append(data, m.Message...) + + err := cc.writePacket(data) + if err != nil { + return err + } + return cc.flush(ctx) +} + +// writeEOF writes an EOF packet or if ClientDeprecateEOF is set it +// writes an OK packet with EOF indicator. +// Note this function won't flush the stream because maybe there are more +// packets following it. +// serverStatus, a flag bit represents server information in the packet. +// Note: it is callers' responsibility to ensure correctness of serverStatus. 
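The OK packet assembled by writeOkWith above encodes affected rows and last-insert-id as MySQL length-encoded integers. A standalone sketch of that encoding (this mirrors what dump.LengthEncodedInt is expected to produce; the TiDB helper itself is not shown in this patch):

// lengthEncodedInt appends n in the MySQL length-encoded integer format:
// one byte for values below 251, otherwise a 0xfc/0xfd/0xfe marker followed
// by 2, 3 or 8 little-endian bytes.
func lengthEncodedInt(b []byte, n uint64) []byte {
	switch {
	case n < 251:
		return append(b, byte(n))
	case n < 1<<16:
		return append(b, 0xfc, byte(n), byte(n>>8))
	case n < 1<<24:
		return append(b, 0xfd, byte(n), byte(n>>8), byte(n>>16))
	default:
		return append(b, 0xfe,
			byte(n), byte(n>>8), byte(n>>16), byte(n>>24),
			byte(n>>32), byte(n>>40), byte(n>>48), byte(n>>56))
	}
}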
+func (cc *clientConn) writeEOF(ctx context.Context, serverStatus uint16) error { + if cc.capability&mysql.ClientDeprecateEOF > 0 { + return cc.writeOkWith(ctx, mysql.EOFHeader, false, serverStatus) + } + + data := cc.alloc.AllocWithLen(4, 9) + + data = append(data, mysql.EOFHeader) + if cc.capability&mysql.ClientProtocol41 > 0 { + data = dump.Uint16(data, cc.ctx.WarningCount()) + data = dump.Uint16(data, serverStatus) + } + + err := cc.writePacket(data) + return err +} + +func (cc *clientConn) writeReq(ctx context.Context, filePath string) error { + data := cc.alloc.AllocWithLen(4, 5+len(filePath)) + data = append(data, mysql.LocalInFileHeader) + data = append(data, filePath...) + + err := cc.writePacket(data) + if err != nil { + return err + } + + return cc.flush(ctx) +} + +// getDataFromPath gets file contents from file path. +func (cc *clientConn) getDataFromPath(ctx context.Context, path string) ([]byte, error) { + err := cc.writeReq(ctx, path) + if err != nil { + return nil, err + } + var prevData, curData []byte + for { + curData, err = cc.readPacket() + if err != nil && terror.ErrorNotEqual(err, io.EOF) { + return nil, err + } + if len(curData) == 0 { + break + } + prevData = append(prevData, curData...) + } + return prevData, nil +} + +// handleLoadStats does the additional work after processing the 'load stats' query. +// It sends client a file path, then reads the file content from client, loads it into the storage. +func (cc *clientConn) handleLoadStats(ctx context.Context, loadStatsInfo *executor.LoadStatsInfo) error { + // If the server handles the load data request, the client has to set the ClientLocalFiles capability. + if cc.capability&mysql.ClientLocalFiles == 0 { + return servererr.ErrNotAllowedCommand + } + if loadStatsInfo == nil { + return errors.New("load stats: info is empty") + } + data, err := cc.getDataFromPath(ctx, loadStatsInfo.Path) + if err != nil { + return err + } + if len(data) == 0 { + return nil + } + return loadStatsInfo.Update(data) +} + +// handleIndexAdvise does the index advise work and returns the advise result for index. +func (cc *clientConn) handleIndexAdvise(ctx context.Context, indexAdviseInfo *executor.IndexAdviseInfo) error { + if cc.capability&mysql.ClientLocalFiles == 0 { + return servererr.ErrNotAllowedCommand + } + if indexAdviseInfo == nil { + return errors.New("Index Advise: info is empty") + } + + data, err := cc.getDataFromPath(ctx, indexAdviseInfo.Path) + if err != nil { + return err + } + if len(data) == 0 { + return errors.New("Index Advise: infile is empty") + } + + if err := indexAdviseInfo.GetIndexAdvice(data); err != nil { + return err + } + + // TODO: Write the rss []ResultSet. It will be done in another PR. 
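getDataFromPath above implements the server side of the LOCAL INFILE style exchange: the server sends a file-request packet, the client streams the file contents in packets, and an empty packet marks end-of-file. A minimal standalone sketch of that read loop (the read callback stands in for cc.readPacket):

// readUntilEmptyPacket accumulates packets until the empty packet that marks
// end-of-file in the file-transfer exchange.
func readUntilEmptyPacket(read func() ([]byte, error)) ([]byte, error) {
	var buf []byte
	for {
		p, err := read()
		if err != nil {
			return nil, err
		}
		if len(p) == 0 {
			return buf, nil
		}
		buf = append(buf, p...)
	}
}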
+ return nil +} + +func (cc *clientConn) handlePlanReplayerLoad(ctx context.Context, planReplayerLoadInfo *executor.PlanReplayerLoadInfo) error { + if cc.capability&mysql.ClientLocalFiles == 0 { + return servererr.ErrNotAllowedCommand + } + if planReplayerLoadInfo == nil { + return errors.New("plan replayer load: info is empty") + } + data, err := cc.getDataFromPath(ctx, planReplayerLoadInfo.Path) + if err != nil { + return err + } + if len(data) == 0 { + return nil + } + return planReplayerLoadInfo.Update(data) +} + +func (cc *clientConn) handlePlanReplayerDump(ctx context.Context, e *executor.PlanReplayerDumpInfo) error { + if cc.capability&mysql.ClientLocalFiles == 0 { + return servererr.ErrNotAllowedCommand + } + if e == nil { + return errors.New("plan replayer dump: executor is empty") + } + data, err := cc.getDataFromPath(ctx, e.Path) + if err != nil { + logutil.BgLogger().Error(err.Error()) + return err + } + if len(data) == 0 { + return nil + } + return e.DumpSQLsFromFile(ctx, data) +} + +func (cc *clientConn) audit(eventType plugin.GeneralEvent) { + err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { + audit := plugin.DeclareAuditManifest(p.Manifest) + if audit.OnGeneralEvent != nil { + cmd := mysql.Command2Str[byte(atomic.LoadUint32(&cc.ctx.GetSessionVars().CommandValue))] + ctx := context.WithValue(context.Background(), plugin.ExecStartTimeCtxKey, cc.ctx.GetSessionVars().StartTime) + audit.OnGeneralEvent(ctx, cc.ctx.GetSessionVars(), eventType, cmd) + } + return nil + }) + if err != nil { + terror.Log(err) + } +} + +// handleQuery executes the sql query string and writes result set or result ok to the client. +// As the execution time of this function represents the performance of TiDB, we do time log and metrics here. +// Some special queries like `load data` that does not return result, which is handled in handleFileTransInConn. +func (cc *clientConn) handleQuery(ctx context.Context, sql string) (err error) { + defer trace.StartRegion(ctx, "handleQuery").End() + sessVars := cc.ctx.GetSessionVars() + sc := sessVars.StmtCtx + prevWarns := sc.GetWarnings() + var stmts []ast.StmtNode + cc.ctx.GetSessionVars().SetAlloc(cc.chunkAlloc) + if stmts, err = cc.ctx.Parse(ctx, sql); err != nil { + cc.onExtensionSQLParseFailed(sql, err) + return err + } + + if len(stmts) == 0 { + return cc.writeOK(ctx) + } + + warns := sc.GetWarnings() + parserWarns := warns[len(prevWarns):] + + var pointPlans []base.Plan + cc.ctx.GetSessionVars().InMultiStmts = false + if len(stmts) > 1 { + // The client gets to choose if it allows multi-statements, and + // probably defaults OFF. This helps prevent against SQL injection attacks + // by early terminating the first statement, and then running an entirely + // new statement. + + capabilities := cc.ctx.GetSessionVars().ClientCapability + if capabilities&mysql.ClientMultiStatements < 1 { + // The client does not have multi-statement enabled. We now need to determine + // how to handle an unsafe situation based on the multiStmt sysvar. 
+ switch cc.ctx.GetSessionVars().MultiStatementMode { + case variable.OffInt: + err = servererr.ErrMultiStatementDisabled + return err + case variable.OnInt: + // multi statement is fully permitted, do nothing + default: + warn := contextutil.SQLWarn{Level: contextutil.WarnLevelWarning, Err: servererr.ErrMultiStatementDisabled} + parserWarns = append(parserWarns, warn) + } + } + cc.ctx.GetSessionVars().InMultiStmts = true + + // Only pre-build point plans for multi-statement query + pointPlans, err = cc.prefetchPointPlanKeys(ctx, stmts, sql) + if err != nil { + for _, stmt := range stmts { + cc.onExtensionStmtEnd(stmt, false, err) + } + return err + } + metrics.NumOfMultiQueryHistogram.Observe(float64(len(stmts))) + } + if len(pointPlans) > 0 { + defer cc.ctx.ClearValue(plannercore.PointPlanKey) + } + var retryable bool + var lastStmt ast.StmtNode + var expiredStmtTaskID uint64 + for i, stmt := range stmts { + if lastStmt != nil { + cc.onExtensionStmtEnd(lastStmt, true, nil) + } + lastStmt = stmt + + // expiredTaskID is the task ID of the previous statement. When executing a stmt, + // the StmtCtx will be reinit and the TaskID will change. We can compare the StmtCtx.TaskID + // with the previous one to determine whether StmtCtx has been inited for the current stmt. + expiredStmtTaskID = sessVars.StmtCtx.TaskID + + if len(pointPlans) > 0 { + // Save the point plan in Session, so we don't need to build the point plan again. + cc.ctx.SetValue(plannercore.PointPlanKey, plannercore.PointPlanVal{Plan: pointPlans[i]}) + } + retryable, err = cc.handleStmt(ctx, stmt, parserWarns, i == len(stmts)-1) + if err != nil { + action, txnErr := sessiontxn.GetTxnManager(&cc.ctx).OnStmtErrorForNextAction(ctx, sessiontxn.StmtErrAfterQuery, err) + if txnErr != nil { + err = txnErr + break + } + + if retryable && action == sessiontxn.StmtActionRetryReady { + cc.ctx.GetSessionVars().RetryInfo.Retrying = true + _, err = cc.handleStmt(ctx, stmt, parserWarns, i == len(stmts)-1) + cc.ctx.GetSessionVars().RetryInfo.Retrying = false + if err != nil { + break + } + continue + } + if !retryable || !errors.ErrorEqual(err, storeerr.ErrTiFlashServerTimeout) { + break + } + _, allowTiFlashFallback := cc.ctx.GetSessionVars().AllowFallbackToTiKV[kv.TiFlash] + if !allowTiFlashFallback { + break + } + // When the TiFlash server seems down, we append a warning to remind the user to check the status of the TiFlash + // server and fallback to TiKV. + warns := append(parserWarns, contextutil.SQLWarn{Level: contextutil.WarnLevelError, Err: err}) + delete(cc.ctx.GetSessionVars().IsolationReadEngines, kv.TiFlash) + _, err = cc.handleStmt(ctx, stmt, warns, i == len(stmts)-1) + cc.ctx.GetSessionVars().IsolationReadEngines[kv.TiFlash] = struct{}{} + if err != nil { + break + } + } + } + + if lastStmt != nil { + cc.onExtensionStmtEnd(lastStmt, sessVars.StmtCtx.TaskID != expiredStmtTaskID, err) + } + + return err +} + +// prefetchPointPlanKeys extracts the point keys in multi-statement query, +// use BatchGet to get the keys, so the values will be cached in the snapshot cache, save RPC call cost. +// For pessimistic transaction, the keys will be batch locked. +func (cc *clientConn) prefetchPointPlanKeys(ctx context.Context, stmts []ast.StmtNode, sqls string) ([]base.Plan, error) { + txn, err := cc.ctx.Txn(false) + if err != nil { + return nil, err + } + if !txn.Valid() { + // Only prefetch in-transaction query for simplicity. + // Later we can support out-transaction multi-statement query. 
+ return nil, nil + } + vars := cc.ctx.GetSessionVars() + if vars.TxnCtx.IsPessimistic { + if vars.IsIsolation(ast.ReadCommitted) { + // TODO: to support READ-COMMITTED, we need to avoid getting new TS for each statement in the query. + return nil, nil + } + if vars.TxnCtx.GetForUpdateTS() != vars.TxnCtx.StartTS { + // Do not handle the case that ForUpdateTS is changed for simplicity. + return nil, nil + } + } + pointPlans := make([]base.Plan, len(stmts)) + var idxKeys []kv.Key //nolint: prealloc + var rowKeys []kv.Key //nolint: prealloc + isCommonHandle := make(map[string]bool, 0) + + handlePlan := func(sctx sessionctx.Context, p base.PhysicalPlan, resetStmtCtxFn func()) error { + var tableID int64 + switch v := p.(type) { + case *plannercore.PointGetPlan: + v.PrunePartitions(sctx) + tableID = executor.GetPhysID(v.TblInfo, v.PartitionIdx) + if v.IndexInfo != nil { + resetStmtCtxFn() + idxKey, err1 := plannercore.EncodeUniqueIndexKey(cc.getCtx(), v.TblInfo, v.IndexInfo, v.IndexValues, tableID) + if err1 != nil { + return err1 + } + idxKeys = append(idxKeys, idxKey) + isCommonHandle[string(hack.String(idxKey))] = v.TblInfo.IsCommonHandle + } else { + rowKeys = append(rowKeys, tablecodec.EncodeRowKeyWithHandle(tableID, v.Handle)) + } + case *plannercore.BatchPointGetPlan: + _, isTableDual := v.PrunePartitionsAndValues(sctx) + if isTableDual { + return nil + } + pi := v.TblInfo.GetPartitionInfo() + getPhysID := func(i int) int64 { + if pi == nil || i >= len(v.PartitionIdxs) { + return v.TblInfo.ID + } + return executor.GetPhysID(v.TblInfo, &v.PartitionIdxs[i]) + } + if v.IndexInfo != nil { + resetStmtCtxFn() + for i, idxVals := range v.IndexValues { + idxKey, err1 := plannercore.EncodeUniqueIndexKey(cc.getCtx(), v.TblInfo, v.IndexInfo, idxVals, getPhysID(i)) + if err1 != nil { + return err1 + } + idxKeys = append(idxKeys, idxKey) + isCommonHandle[string(hack.String(idxKey))] = v.TblInfo.IsCommonHandle + } + } else { + for i, handle := range v.Handles { + rowKeys = append(rowKeys, tablecodec.EncodeRowKeyWithHandle(getPhysID(i), handle)) + } + } + } + return nil + } + + sc := vars.StmtCtx + for i, stmt := range stmts { + if _, ok := stmt.(*ast.UseStmt); ok { + // If there is a "use db" statement, we shouldn't cache even if it's possible. + // Consider the scenario where there are statements that could execute on multiple + // schemas, but the schema is actually different. + return nil, nil + } + // TODO: the preprocess is run twice, we should find some way to avoid do it again. + if err = plannercore.Preprocess(ctx, cc.getCtx(), stmt); err != nil { + // error might happen, see https://github.com/pingcap/tidb/issues/39664 + return nil, nil + } + p := plannercore.TryFastPlan(cc.ctx.Session.GetPlanCtx(), stmt) + pointPlans[i] = p + if p == nil { + continue + } + // Only support Update and Delete for now. + // TODO: support other point plans. 
+ switch x := p.(type) { + case *plannercore.Update: + //nolint:forcetypeassert + updateStmt, ok := stmt.(*ast.UpdateStmt) + if !ok { + logutil.BgLogger().Warn("unexpected statement type for Update plan", + zap.String("type", fmt.Sprintf("%T", stmt))) + continue + } + err = handlePlan(cc.ctx.Session, x.SelectPlan, func() { + executor.ResetUpdateStmtCtx(sc, updateStmt, vars) + }) + if err != nil { + return nil, err + } + case *plannercore.Delete: + deleteStmt, ok := stmt.(*ast.DeleteStmt) + if !ok { + logutil.BgLogger().Warn("unexpected statement type for Delete plan", + zap.String("type", fmt.Sprintf("%T", stmt))) + continue + } + err = handlePlan(cc.ctx.Session, x.SelectPlan, func() { + executor.ResetDeleteStmtCtx(sc, deleteStmt, vars) + }) + if err != nil { + return nil, err + } + } + } + if len(idxKeys) == 0 && len(rowKeys) == 0 { + return pointPlans, nil + } + snapshot := txn.GetSnapshot() + setResourceGroupTaggerForMultiStmtPrefetch(snapshot, sqls) + idxVals, err1 := snapshot.BatchGet(ctx, idxKeys) + if err1 != nil { + return nil, err1 + } + for idxKey, idxVal := range idxVals { + h, err2 := tablecodec.DecodeHandleInIndexValue(idxVal) + if err2 != nil { + return nil, err2 + } + tblID := tablecodec.DecodeTableID(hack.Slice(idxKey)) + rowKeys = append(rowKeys, tablecodec.EncodeRowKeyWithHandle(tblID, h)) + } + if vars.TxnCtx.IsPessimistic { + allKeys := append(rowKeys, idxKeys...) + err = executor.LockKeys(ctx, cc.getCtx(), vars.LockWaitTimeout, allKeys...) + if err != nil { + // suppress the lock error, we are not going to handle it here for simplicity. + err = nil + logutil.BgLogger().Warn("lock keys error on prefetch", zap.Error(err)) + } + } else { + _, err = snapshot.BatchGet(ctx, rowKeys) + if err != nil { + return nil, err + } + } + return pointPlans, nil +} + +func setResourceGroupTaggerForMultiStmtPrefetch(snapshot kv.Snapshot, sqls string) { + if !topsqlstate.TopSQLEnabled() { + return + } + normalized, digest := parser.NormalizeDigest(sqls) + topsql.AttachAndRegisterSQLInfo(context.Background(), normalized, digest, false) + snapshot.SetOption(kv.ResourceGroupTagger, tikvrpc.ResourceGroupTagger(func(req *tikvrpc.Request) { + if req == nil { + return + } + if len(normalized) == 0 { + return + } + req.ResourceGroupTag = resourcegrouptag.EncodeResourceGroupTag(digest, nil, + resourcegrouptag.GetResourceGroupLabelByKey(resourcegrouptag.GetFirstKeyFromRequest(req))) + })) +} + +// The first return value indicates whether the call of handleStmt has no side effect and can be retried. +// Currently, the first return value is used to fall back to TiKV when TiFlash is down. 
+func (cc *clientConn) handleStmt( + ctx context.Context, stmt ast.StmtNode, + warns []contextutil.SQLWarn, lastStmt bool, +) (bool, error) { + ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) + ctx = context.WithValue(ctx, util.ExecDetailsKey, &util.ExecDetails{}) + ctx = context.WithValue(ctx, util.RUDetailsCtxKey, util.NewRUDetails()) + reg := trace.StartRegion(ctx, "ExecuteStmt") + cc.audit(plugin.Starting) + + // if stmt is load data stmt, store the channel that reads from the conn + // into the ctx for executor to use + if s, ok := stmt.(*ast.LoadDataStmt); ok { + if s.FileLocRef == ast.FileLocClient { + err := cc.preprocessLoadDataLocal(ctx) + defer cc.postprocessLoadDataLocal() + if err != nil { + return false, err + } + } + } + + rs, err := cc.ctx.ExecuteStmt(ctx, stmt) + reg.End() + // - If rs is not nil, the statement tracker detachment from session tracker + // is done in the `rs.Close` in most cases. + // - If the rs is nil and err is not nil, the detachment will be done in + // the `handleNoDelay`. + if rs != nil { + defer rs.Close() + } + + if err != nil { + // If error is returned during the planner phase or the executor.Open + // phase, the rs will be nil, and StmtCtx.MemTracker StmtCtx.DiskTracker + // will not be detached. We need to detach them manually. + if sv := cc.ctx.GetSessionVars(); sv != nil && sv.StmtCtx != nil { + sv.StmtCtx.DetachMemDiskTracker() + } + return true, err + } + + status := cc.ctx.Status() + if lastStmt { + cc.ctx.GetSessionVars().StmtCtx.AppendWarnings(warns) + } else { + status |= mysql.ServerMoreResultsExists + } + + if rs != nil { + if cc.getStatus() == connStatusShutdown { + return false, exeerrors.ErrQueryInterrupted + } + cc.ctx.GetSessionVars().SQLKiller.SetFinishFunc( + func() { + //nolint: errcheck + rs.Finish() + }) + cc.ctx.GetSessionVars().SQLKiller.InWriteResultSet.Store(true) + defer cc.ctx.GetSessionVars().SQLKiller.InWriteResultSet.Store(false) + defer cc.ctx.GetSessionVars().SQLKiller.ClearFinishFunc() + if retryable, err := cc.writeResultSet(ctx, rs, false, status, 0); err != nil { + return retryable, err + } + return false, nil + } + + handled, err := cc.handleFileTransInConn(ctx, status) + if handled { + if execStmt := cc.ctx.Value(session.ExecStmtVarKey); execStmt != nil { + //nolint:forcetypeassert + execStmt.(*executor.ExecStmt).FinishExecuteStmt(0, err, false) + } + } + return false, err +} + +// Preprocess LOAD DATA. Load data from a local file requires reading from the connection. +// The function pass a builder to build the connection reader to the context, +// which will be used in LoadDataExec. 
+func (cc *clientConn) preprocessLoadDataLocal(ctx context.Context) error { + if cc.capability&mysql.ClientLocalFiles == 0 { + return servererr.ErrNotAllowedCommand + } + + var readerBuilder executor.LoadDataReaderBuilder = func(filepath string) ( + io.ReadCloser, error, + ) { + err := cc.writeReq(ctx, filepath) + if err != nil { + return nil, err + } + + drained := false + r, w := io.Pipe() + + go func() { + var errOccurred error + + defer func() { + if errOccurred != nil { + // Continue reading packets to drain the connection + for !drained { + data, err := cc.readPacket() + if err != nil { + logutil.Logger(ctx).Error( + "drain connection failed in load data", + zap.Error(err), + ) + break + } + if len(data) == 0 { + drained = true + } + } + } + err := w.CloseWithError(errOccurred) + if err != nil { + logutil.Logger(ctx).Error( + "close pipe failed in `load data`", + zap.Error(err), + ) + } + }() + + for { + data, err := cc.readPacket() + if err != nil { + errOccurred = err + return + } + + if len(data) == 0 { + drained = true + return + } + + // Write all content in `data` + for len(data) > 0 { + n, err := w.Write(data) + if err != nil { + errOccurred = err + return + } + data = data[n:] + } + } + }() + + return r, nil + } + + cc.ctx.SetValue(executor.LoadDataReaderBuilderKey, readerBuilder) + + return nil +} + +func (cc *clientConn) postprocessLoadDataLocal() { + cc.ctx.ClearValue(executor.LoadDataReaderBuilderKey) +} + +func (cc *clientConn) handleFileTransInConn(ctx context.Context, status uint16) (bool, error) { + handled := false + + loadStats := cc.ctx.Value(executor.LoadStatsVarKey) + if loadStats != nil { + handled = true + defer cc.ctx.SetValue(executor.LoadStatsVarKey, nil) + //nolint:forcetypeassert + if err := cc.handleLoadStats(ctx, loadStats.(*executor.LoadStatsInfo)); err != nil { + return handled, err + } + } + + indexAdvise := cc.ctx.Value(executor.IndexAdviseVarKey) + if indexAdvise != nil { + handled = true + defer cc.ctx.SetValue(executor.IndexAdviseVarKey, nil) + //nolint:forcetypeassert + if err := cc.handleIndexAdvise(ctx, indexAdvise.(*executor.IndexAdviseInfo)); err != nil { + return handled, err + } + } + + planReplayerLoad := cc.ctx.Value(executor.PlanReplayerLoadVarKey) + if planReplayerLoad != nil { + handled = true + defer cc.ctx.SetValue(executor.PlanReplayerLoadVarKey, nil) + //nolint:forcetypeassert + if err := cc.handlePlanReplayerLoad(ctx, planReplayerLoad.(*executor.PlanReplayerLoadInfo)); err != nil { + return handled, err + } + } + + planReplayerDump := cc.ctx.Value(executor.PlanReplayerDumpVarKey) + if planReplayerDump != nil { + handled = true + defer cc.ctx.SetValue(executor.PlanReplayerDumpVarKey, nil) + //nolint:forcetypeassert + if err := cc.handlePlanReplayerDump(ctx, planReplayerDump.(*executor.PlanReplayerDumpInfo)); err != nil { + return handled, err + } + } + return handled, cc.writeOkWith(ctx, mysql.OKHeader, true, status) +} + +// handleFieldList returns the field list for a table. +// The sql string is composed of a table name and a terminating character \x00. 
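As the comment above notes, the COM_FIELD_LIST payload is a table name terminated by 0x00; anything after the terminator is the optional field wildcard, which handleFieldList ignores. A small standalone sketch of that split (helper name is illustrative):

package example

import "strings"

// parseFieldListPayload splits the COM_FIELD_LIST payload into the table name
// and the optional field wildcard that follows the 0x00 terminator.
func parseFieldListPayload(payload string) (table, wildcard string) {
	parts := strings.SplitN(payload, "\x00", 2)
	table = parts[0]
	if len(parts) == 2 {
		wildcard = parts[1]
	}
	return table, wildcard
}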
+func (cc *clientConn) handleFieldList(ctx context.Context, sql string) (err error) { + parts := strings.Split(sql, "\x00") + columns, err := cc.ctx.FieldList(parts[0]) + if err != nil { + return err + } + data := cc.alloc.AllocWithLen(4, 1024) + cc.initResultEncoder(ctx) + defer cc.rsEncoder.Clean() + for _, column := range columns { + data = data[0:4] + data = column.DumpWithDefault(data, cc.rsEncoder) + if err := cc.writePacket(data); err != nil { + return err + } + } + if err := cc.writeEOF(ctx, cc.ctx.Status()); err != nil { + return err + } + return cc.flush(ctx) +} + +// writeResultSet writes data into a result set and uses rs.Next to get row data back. +// If binary is true, the data would be encoded in BINARY format. +// serverStatus, a flag bit represents server information. +// fetchSize, the desired number of rows to be fetched each time when client uses cursor. +// retryable indicates whether the call of writeResultSet has no side effect and can be retried to correct error. The call +// has side effect in cursor mode or once data has been sent to client. Currently retryable is used to fallback to TiKV when +// TiFlash is down. +func (cc *clientConn) writeResultSet(ctx context.Context, rs resultset.ResultSet, binary bool, serverStatus uint16, fetchSize int) (retryable bool, runErr error) { + defer func() { + // close ResultSet when cursor doesn't exist + r := recover() + if r == nil { + return + } + recoverdErr, ok := r.(error) + if !ok || !(exeerrors.ErrMemoryExceedForQuery.Equal(recoverdErr) || + exeerrors.ErrMemoryExceedForInstance.Equal(recoverdErr) || + exeerrors.ErrQueryInterrupted.Equal(recoverdErr) || + exeerrors.ErrMaxExecTimeExceeded.Equal(recoverdErr)) { + panic(r) + } + runErr = recoverdErr + // TODO(jianzhang.zj: add metrics here) + logutil.Logger(ctx).Error("write query result panic", zap.Stringer("lastSQL", getLastStmtInConn{cc}), zap.Stack("stack"), zap.Any("recover", r)) + }() + cc.initResultEncoder(ctx) + defer cc.rsEncoder.Clean() + if mysql.HasCursorExistsFlag(serverStatus) { + crs, ok := rs.(resultset.CursorResultSet) + if !ok { + // this branch is actually unreachable + return false, errors.New("this cursor is not a resultSet") + } + if err := cc.writeChunksWithFetchSize(ctx, crs, serverStatus, fetchSize); err != nil { + return false, err + } + return false, cc.flush(ctx) + } + if retryable, err := cc.writeChunks(ctx, rs, binary, serverStatus); err != nil { + return retryable, err + } + + return false, cc.flush(ctx) +} + +func (cc *clientConn) writeColumnInfo(columns []*column.Info) error { + data := cc.alloc.AllocWithLen(4, 1024) + data = dump.LengthEncodedInt(data, uint64(len(columns))) + if err := cc.writePacket(data); err != nil { + return err + } + for _, v := range columns { + data = data[0:4] + data = v.Dump(data, cc.rsEncoder) + if err := cc.writePacket(data); err != nil { + return err + } + } + return nil +} + +// writeChunks writes data from a Chunk, which filled data by a ResultSet, into a connection. +// binary specifies the way to dump data. It throws any error while dumping data. +// serverStatus, a flag bit represents server information +// The first return value indicates whether error occurs at the first call of ResultSet.Next. 
+func (cc *clientConn) writeChunks(ctx context.Context, rs resultset.ResultSet, binary bool, serverStatus uint16) (bool, error) { + data := cc.alloc.AllocWithLen(4, 1024) + req := rs.NewChunk(cc.chunkAlloc) + gotColumnInfo := false + firstNext := true + validNextCount := 0 + var start time.Time + var stmtDetail *execdetails.StmtExecDetails + stmtDetailRaw := ctx.Value(execdetails.StmtExecDetailKey) + if stmtDetailRaw != nil { + //nolint:forcetypeassert + stmtDetail = stmtDetailRaw.(*execdetails.StmtExecDetails) + } + for { + failpoint.Inject("fetchNextErr", func(value failpoint.Value) { + //nolint:forcetypeassert + switch value.(string) { + case "firstNext": + failpoint.Return(firstNext, storeerr.ErrTiFlashServerTimeout) + case "secondNext": + if !firstNext { + failpoint.Return(firstNext, storeerr.ErrTiFlashServerTimeout) + } + case "secondNextAndRetConflict": + if !firstNext && validNextCount > 1 { + failpoint.Return(firstNext, kv.ErrWriteConflict) + } + } + }) + // Here server.tidbResultSet implements Next method. + err := rs.Next(ctx, req) + if err != nil { + return firstNext, err + } + if !gotColumnInfo { + // We need to call Next before we get columns. + // Otherwise, we will get incorrect columns info. + columns := rs.Columns() + if stmtDetail != nil { + start = time.Now() + } + if err = cc.writeColumnInfo(columns); err != nil { + return false, err + } + if cc.capability&mysql.ClientDeprecateEOF == 0 { + // metadata only needs EOF marker for old clients without ClientDeprecateEOF + if err = cc.writeEOF(ctx, serverStatus); err != nil { + return false, err + } + } + if stmtDetail != nil { + stmtDetail.WriteSQLRespDuration += time.Since(start) + } + gotColumnInfo = true + } + rowCount := req.NumRows() + if rowCount == 0 { + break + } + validNextCount++ + firstNext = false + reg := trace.StartRegion(ctx, "WriteClientConn") + if stmtDetail != nil { + start = time.Now() + } + for i := 0; i < rowCount; i++ { + data = data[0:4] + if binary { + data, err = column.DumpBinaryRow(data, rs.Columns(), req.GetRow(i), cc.rsEncoder) + } else { + data, err = column.DumpTextRow(data, rs.Columns(), req.GetRow(i), cc.rsEncoder) + } + if err != nil { + reg.End() + return false, err + } + if err = cc.writePacket(data); err != nil { + reg.End() + return false, err + } + } + reg.End() + if stmtDetail != nil { + stmtDetail.WriteSQLRespDuration += time.Since(start) + } + } + if err := rs.Finish(); err != nil { + return false, err + } + + if stmtDetail != nil { + start = time.Now() + } + + err := cc.writeEOF(ctx, serverStatus) + if stmtDetail != nil { + stmtDetail.WriteSQLRespDuration += time.Since(start) + } + return false, err +} + +// writeChunksWithFetchSize writes data from a Chunk, which filled data by a ResultSet, into a connection. +// binary specifies the way to dump data. It throws any error while dumping data. +// serverStatus, a flag bit represents server information. +// fetchSize, the desired number of rows to be fetched each time when client uses cursor. 
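The cursor path described in the comment above relies on two server-status bits: the EXECUTE response advertises an open cursor, and the final FETCH response clears that bit and marks the last row so the client stops issuing COM_STMT_FETCH. A standalone sketch of that flag handling (the constant values are the standard MySQL protocol bits; names are illustrative):

const (
	serverStatusCursorExists uint16 = 0x0040
	serverStatusLastRowSend  uint16 = 0x0080
)

// finishCursor mirrors the flag update performed when the row iterator is
// exhausted: drop the "cursor exists" bit and set "last row sent".
func finishCursor(status uint16) uint16 {
	status &^= serverStatusCursorExists
	status |= serverStatusLastRowSend
	return status
}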
+func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs resultset.CursorResultSet, serverStatus uint16, fetchSize int) error { + var ( + stmtDetail *execdetails.StmtExecDetails + err error + start time.Time + ) + data := cc.alloc.AllocWithLen(4, 1024) + stmtDetailRaw := ctx.Value(execdetails.StmtExecDetailKey) + if stmtDetailRaw != nil { + //nolint:forcetypeassert + stmtDetail = stmtDetailRaw.(*execdetails.StmtExecDetails) + } + if stmtDetail != nil { + start = time.Now() + } + + iter := rs.GetRowIterator() + // send the rows to the client according to fetchSize. + for i := 0; i < fetchSize && iter.Current(ctx) != iter.End(); i++ { + row := iter.Current(ctx) + + data = data[0:4] + data, err = column.DumpBinaryRow(data, rs.Columns(), row, cc.rsEncoder) + if err != nil { + return err + } + if err = cc.writePacket(data); err != nil { + return err + } + + iter.Next(ctx) + } + if iter.Error() != nil { + return iter.Error() + } + + // tell the client COM_STMT_FETCH has finished by setting proper serverStatus, + // and close ResultSet. + if iter.Current(ctx) == iter.End() { + serverStatus &^= mysql.ServerStatusCursorExists + serverStatus |= mysql.ServerStatusLastRowSend + } + + // don't include the time consumed by `cl.OnFetchReturned()` in the `WriteSQLRespDuration` + if stmtDetail != nil { + stmtDetail.WriteSQLRespDuration += time.Since(start) + } + + if cl, ok := rs.(resultset.FetchNotifier); ok { + cl.OnFetchReturned() + } + + start = time.Now() + err = cc.writeEOF(ctx, serverStatus) + if stmtDetail != nil { + stmtDetail.WriteSQLRespDuration += time.Since(start) + } + return err +} + +func (cc *clientConn) setConn(conn net.Conn) { + cc.bufReadConn = util2.NewBufferedReadConn(conn) + if cc.pkt == nil { + cc.pkt = internal.NewPacketIO(cc.bufReadConn) + } else { + // Preserve current sequence number. + cc.pkt.SetBufferedReadConn(cc.bufReadConn) + } +} + +func (cc *clientConn) upgradeToTLS(tlsConfig *tls.Config) error { + // Important: read from buffered reader instead of the original net.Conn because it may contain data we need. + tlsConn := tls.Server(cc.bufReadConn, tlsConfig) + if err := tlsConn.Handshake(); err != nil { + return err + } + cc.setConn(tlsConn) + cc.tlsConn = tlsConn + return nil +} + +func (cc *clientConn) handleChangeUser(ctx context.Context, data []byte) error { + user, data := util2.ParseNullTermString(data) + cc.user = string(hack.String(user)) + if len(data) < 1 { + return mysql.ErrMalformPacket + } + passLen := int(data[0]) + data = data[1:] + if passLen > len(data) { + return mysql.ErrMalformPacket + } + pass := data[:passLen] + data = data[passLen:] + dbName, data := util2.ParseNullTermString(data) + cc.dbname = string(hack.String(dbName)) + pluginName := "" + if len(data) > 0 { + // skip character set + if cc.capability&mysql.ClientProtocol41 > 0 && len(data) >= 2 { + data = data[2:] + } + if cc.capability&mysql.ClientPluginAuth > 0 && len(data) > 0 { + pluginNameB, _ := util2.ParseNullTermString(data) + pluginName = string(hack.String(pluginNameB)) + } + } + + if err := cc.ctx.Close(); err != nil { + logutil.Logger(ctx).Debug("close old context failed", zap.Error(err)) + } + // session was closed by `ctx.Close` and should `openSession` explicitly to renew session. + // `openSession` won't run again in `openSessionAndDoAuth` because ctx is not nil. 
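The COM_CHANGE_USER payload parsed above is a sequence of NUL-terminated strings and length-prefixed fields (user, password, database, character set, auth plugin). A standalone sketch of what ParseNullTermString is expected to do (its implementation is not shown in this patch):

package example

import "bytes"

// parseNullTermString returns the bytes before the first 0x00 and the
// remainder after it; if no terminator is present, the whole input is
// returned as the remainder.
func parseNullTermString(b []byte) (str, rest []byte) {
	i := bytes.IndexByte(b, 0)
	if i == -1 {
		return nil, b
	}
	return b[:i], b[i+1:]
}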
+ err := cc.openSession() + if err != nil { + return err + } + fakeResp := &handshake.Response41{ + Auth: pass, + AuthPlugin: pluginName, + Capability: cc.capability, + } + if fakeResp.AuthPlugin != "" { + failpoint.Inject("ChangeUserAuthSwitch", func(val failpoint.Value) { + failpoint.Return(errors.Errorf("%v", val)) + }) + newpass, err := cc.checkAuthPlugin(ctx, fakeResp) + if err != nil { + return err + } + if len(newpass) > 0 { + fakeResp.Auth = newpass + } + } + if err := cc.openSessionAndDoAuth(fakeResp.Auth, fakeResp.AuthPlugin, fakeResp.ZstdLevel); err != nil { + return err + } + return cc.handleCommonConnectionReset(ctx) +} + +func (cc *clientConn) handleResetConnection(ctx context.Context) error { + user := cc.ctx.GetSessionVars().User + err := cc.ctx.Close() + if err != nil { + logutil.Logger(ctx).Debug("close old context failed", zap.Error(err)) + } + var tlsStatePtr *tls.ConnectionState + if cc.tlsConn != nil { + tlsState := cc.tlsConn.ConnectionState() + tlsStatePtr = &tlsState + } + tidbCtx, err := cc.server.driver.OpenCtx(cc.connectionID, cc.capability, cc.collation, cc.dbname, tlsStatePtr, cc.extensions) + if err != nil { + return err + } + cc.SetCtx(tidbCtx) + if !cc.ctx.AuthWithoutVerification(user) { + return errors.New("Could not reset connection") + } + if cc.dbname != "" { // Restore the current DB + _, err = cc.useDB(context.Background(), cc.dbname) + if err != nil { + return err + } + } + cc.ctx.SetSessionManager(cc.server) + + return cc.handleCommonConnectionReset(ctx) +} + +func (cc *clientConn) handleCommonConnectionReset(ctx context.Context) error { + connectionInfo := cc.connectInfo() + cc.ctx.GetSessionVars().ConnectionInfo = connectionInfo + + cc.onExtensionConnEvent(extension.ConnReset, nil) + err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { + authPlugin := plugin.DeclareAuditManifest(p.Manifest) + if authPlugin.OnConnectionEvent != nil { + connInfo := cc.ctx.GetSessionVars().ConnectionInfo + err := authPlugin.OnConnectionEvent(context.Background(), plugin.ChangeUser, connInfo) + if err != nil { + return err + } + } + return nil + }) + if err != nil { + return err + } + return cc.writeOK(ctx) +} + +// safe to noop except 0x01 "FLUSH PRIVILEGES" +func (cc *clientConn) handleRefresh(ctx context.Context, subCommand byte) error { + if subCommand == 0x01 { + if err := cc.handleQuery(ctx, "FLUSH PRIVILEGES"); err != nil { + return err + } + } + return cc.writeOK(ctx) +} + +var _ fmt.Stringer = getLastStmtInConn{} + +type getLastStmtInConn struct { + *clientConn +} + +func (cc getLastStmtInConn) String() string { + if len(cc.lastPacket) == 0 { + return "" + } + cmd, data := cc.lastPacket[0], cc.lastPacket[1:] + switch cmd { + case mysql.ComInitDB: + return "Use " + string(data) + case mysql.ComFieldList: + return "ListFields " + string(data) + case mysql.ComQuery, mysql.ComStmtPrepare: + sql := string(hack.String(data)) + sql = parser.Normalize(sql, cc.ctx.GetSessionVars().EnableRedactLog) + return executor.FormatSQL(sql).String() + case mysql.ComStmtExecute, mysql.ComStmtFetch: + stmtID := binary.LittleEndian.Uint32(data[0:4]) + return executor.FormatSQL(cc.preparedStmt2String(stmtID)).String() + case mysql.ComStmtClose, mysql.ComStmtReset: + stmtID := binary.LittleEndian.Uint32(data[0:4]) + return mysql.Command2Str[cmd] + " " + strconv.Itoa(int(stmtID)) + default: + if cmdStr, ok := mysql.Command2Str[cmd]; ok { + return cmdStr + } + return string(hack.String(data)) + } +} + +// PProfLabel return sql label used to tag pprof. 
+func (cc getLastStmtInConn) PProfLabel() string { + if len(cc.lastPacket) == 0 { + return "" + } + cmd, data := cc.lastPacket[0], cc.lastPacket[1:] + switch cmd { + case mysql.ComInitDB: + return "UseDB" + case mysql.ComFieldList: + return "ListFields" + case mysql.ComStmtClose: + return "CloseStmt" + case mysql.ComStmtReset: + return "ResetStmt" + case mysql.ComQuery, mysql.ComStmtPrepare: + return parser.Normalize(executor.FormatSQL(string(hack.String(data))).String(), errors.RedactLogEnable) + case mysql.ComStmtExecute, mysql.ComStmtFetch: + stmtID := binary.LittleEndian.Uint32(data[0:4]) + return executor.FormatSQL(cc.preparedStmt2StringNoArgs(stmtID)).String() + default: + return "" + } +} + +var _ conn.AuthConn = &clientConn{} + +// WriteAuthMoreData implements `conn.AuthConn` interface +func (cc *clientConn) WriteAuthMoreData(data []byte) error { + // See https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_more_data.html + // the `AuthMoreData` packet is just an arbitrary binary slice with a byte 0x1 as prefix. + return cc.writePacket(append([]byte{0, 0, 0, 0, 1}, data...)) +} + +// ReadPacket implements `conn.AuthConn` interface +func (cc *clientConn) ReadPacket() ([]byte, error) { + return cc.readPacket() +} + +// Flush implements `conn.AuthConn` interface +func (cc *clientConn) Flush(ctx context.Context) error { + return cc.flush(ctx) +} + +type compressionStats struct{} + +// Stats returns the connection statistics. +func (*compressionStats) Stats(vars *variable.SessionVars) (map[string]any, error) { + m := make(map[string]any, 3) + + switch vars.CompressionAlgorithm { + case mysql.CompressionNone: + m[statusCompression] = "OFF" + m[statusCompressionAlgorithm] = "" + m[statusCompressionLevel] = 0 + case mysql.CompressionZlib: + m[statusCompression] = "ON" + m[statusCompressionAlgorithm] = "zlib" + m[statusCompressionLevel] = mysql.ZlibCompressDefaultLevel + case mysql.CompressionZstd: + m[statusCompression] = "ON" + m[statusCompressionAlgorithm] = "zstd" + m[statusCompressionLevel] = vars.CompressionLevel + default: + logutil.BgLogger().Debug( + "unexpected compression algorithm value", + zap.Int("algorithm", vars.CompressionAlgorithm), + ) + m[statusCompression] = "OFF" + m[statusCompressionAlgorithm] = "" + m[statusCompressionLevel] = 0 + } + + return m, nil +} + +// GetScope gets the status variables scope. 
+func (*compressionStats) GetScope(_ string) variable.ScopeFlag { + return variable.ScopeSession +} + +func init() { + variable.RegisterStatistics(&compressionStats{}) +} diff --git a/pkg/server/conn_stmt.go b/pkg/server/conn_stmt.go index 19ac430500944..8a3bedd411280 100644 --- a/pkg/server/conn_stmt.go +++ b/pkg/server/conn_stmt.go @@ -358,9 +358,9 @@ func (cc *clientConn) executeWithCursor(ctx context.Context, stmt PreparedStatem } } - failpoint.Inject("avoidEagerCursorFetch", func() { - failpoint.Return(false, errors.New("failpoint avoids eager cursor fetch")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("avoidEagerCursorFetch")); _err_ == nil { + return false, errors.New("failpoint avoids eager cursor fetch") + } cc.initResultEncoder(ctx) defer cc.rsEncoder.Clean() // fetch all results of the resultSet, and stored them locally, so that the future `FETCH` command can read @@ -376,12 +376,12 @@ func (cc *clientConn) executeWithCursor(ctx context.Context, stmt PreparedStatem rowContainer.GetDiskTracker().AttachTo(vars.DiskTracker) rowContainer.GetDiskTracker().SetLabel(memory.LabelForCursorFetch) if variable.EnableTmpStorageOnOOM.Load() { - failpoint.Inject("testCursorFetchSpill", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("testCursorFetchSpill")); _err_ == nil { if val, ok := val.(bool); val && ok { actionSpill := rowContainer.ActionSpillForTest() defer actionSpill.WaitForTest() } - }) + } action := memory.NewActionWithPriority(rowContainer.ActionSpill(), memory.DefCursorFetchSpillPriority) vars.MemTracker.FallbackOldAndSetNewAction(action) } diff --git a/pkg/server/conn_stmt.go__failpoint_stash__ b/pkg/server/conn_stmt.go__failpoint_stash__ new file mode 100644 index 0000000000000..19ac430500944 --- /dev/null +++ b/pkg/server/conn_stmt.go__failpoint_stash__ @@ -0,0 +1,673 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// The MIT License (MIT) +// +// Copyright (c) 2014 wandoulabs +// Copyright (c) 2014 siddontang +// +// Permission is hereby granted, free of charge, to any person obtaining a copy of +// this software and associated documentation files (the "Software"), to deal in +// the Software without restriction, including without limitation the rights to +// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +// the Software, and to permit persons to whom the Software is furnished to do so, +// subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. 
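For context on the conn_stmt.go hunk above and the conn_stmt.go__failpoint_stash__ file whose contents follow: failpoint.Inject markers are rewritten by the failpoint-ctl tool into explicit failpoint.Eval calls with the generated _curpkg_ binding, and the original file is kept as a __failpoint_stash__ backup. A minimal sketch of the hand-written marker form that corresponds to the generated code in the hunk (the failpoint name here is hypothetical):

package example

import (
	"github.com/pingcap/errors"
	"github.com/pingcap/failpoint"
)

// doSomething carries a hand-written failpoint marker. Running
// `failpoint-ctl enable` rewrites markers like this into the
// `if _, _err_ := failpoint.Eval(_curpkg_("mockError")); _err_ == nil { ... }`
// form seen in the hunk above, and stashes the original file alongside it.
func doSomething() error {
	failpoint.Inject("mockError", func() {
		failpoint.Return(errors.New("injected error"))
	})
	return nil
}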
+ +package server + +import ( + "context" + "encoding/binary" + "runtime/trace" + "strconv" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/param" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/mysql" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/server/internal/dump" + "github.com/pingcap/tidb/pkg/server/internal/parse" + "github.com/pingcap/tidb/pkg/server/internal/resultset" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/sessiontxn" + storeerr "github.com/pingcap/tidb/pkg/store/driver/error" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/redact" + "github.com/pingcap/tidb/pkg/util/topsql" + topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" + "github.com/tikv/client-go/v2/util" + "go.uber.org/zap" +) + +func (cc *clientConn) HandleStmtPrepare(ctx context.Context, sql string) error { + stmt, columns, params, err := cc.ctx.Prepare(sql) + if err != nil { + return err + } + data := make([]byte, 4, 128) + + // status ok + data = append(data, 0) + // stmt id + data = dump.Uint32(data, uint32(stmt.ID())) + // number columns + data = dump.Uint16(data, uint16(len(columns))) + // number params + data = dump.Uint16(data, uint16(len(params))) + // filter [00] + data = append(data, 0) + // warning count + data = append(data, 0, 0) // TODO support warning count + + if err := cc.writePacket(data); err != nil { + return err + } + + cc.initResultEncoder(ctx) + defer cc.rsEncoder.Clean() + if len(params) > 0 { + for i := 0; i < len(params); i++ { + data = data[0:4] + data = params[i].Dump(data, cc.rsEncoder) + + if err := cc.writePacket(data); err != nil { + return err + } + } + + if cc.capability&mysql.ClientDeprecateEOF == 0 { + // metadata only needs EOF marker for old clients without ClientDeprecateEOF + if err := cc.writeEOF(ctx, cc.ctx.Status()); err != nil { + return err + } + } + } + + if len(columns) > 0 { + for i := 0; i < len(columns); i++ { + data = data[0:4] + data = columns[i].Dump(data, cc.rsEncoder) + + if err := cc.writePacket(data); err != nil { + return err + } + } + + if cc.capability&mysql.ClientDeprecateEOF == 0 { + // metadata only needs EOF marker for old clients without ClientDeprecateEOF + if err := cc.writeEOF(ctx, cc.ctx.Status()); err != nil { + return err + } + } + } + return cc.flush(ctx) +} + +func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err error) { + defer trace.StartRegion(ctx, "HandleStmtExecute").End() + if len(data) < 9 { + return mysql.ErrMalformPacket + } + pos := 0 + stmtID := binary.LittleEndian.Uint32(data[0:4]) + pos += 4 + + stmt := cc.ctx.GetStatement(int(stmtID)) + if stmt == nil { + return mysql.NewErr(mysql.ErrUnknownStmtHandler, + strconv.FormatUint(uint64(stmtID), 10), "stmt_execute") + } + + flag := data[pos] + pos++ + // Please refer to https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html + // The client indicates that it wants to use cursor by setting this flag. + // Now we only support forward-only, read-only cursor. 
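HandleStmtPrepare above answers COM_STMT_PREPARE with a fixed 12-byte header before the parameter and column definitions. A standalone sketch of that header layout (helper name is illustrative):

package example

import "encoding/binary"

// prepareOKHeader builds the COM_STMT_PREPARE response header assembled
// above: [00] status, 4-byte statement id, 2-byte column count, 2-byte
// parameter count, one filler byte, 2-byte warning count (little-endian).
func prepareOKHeader(stmtID uint32, numColumns, numParams uint16) []byte {
	b := make([]byte, 0, 12)
	b = append(b, 0x00) // status: OK
	b = binary.LittleEndian.AppendUint32(b, stmtID)
	b = binary.LittleEndian.AppendUint16(b, numColumns)
	b = binary.LittleEndian.AppendUint16(b, numParams)
	b = append(b, 0x00)       // filler
	b = append(b, 0x00, 0x00) // warning count
	return b
}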
+ useCursor := false + if flag&mysql.CursorTypeReadOnly > 0 { + useCursor = true + } + if flag&mysql.CursorTypeForUpdate > 0 { + return mysql.NewErrf(mysql.ErrUnknown, "unsupported flag: CursorTypeForUpdate", nil) + } + if flag&mysql.CursorTypeScrollable > 0 { + return mysql.NewErrf(mysql.ErrUnknown, "unsupported flag: CursorTypeScrollable", nil) + } + + if useCursor { + cc.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusCursorExists, true) + defer cc.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusCursorExists, false) + } else { + // not using streaming ,can reuse chunk + cc.ctx.GetSessionVars().SetAlloc(cc.chunkAlloc) + } + // skip iteration-count, always 1 + pos += 4 + + var ( + nullBitmaps []byte + paramTypes []byte + paramValues []byte + ) + cc.initInputEncoder(ctx) + numParams := stmt.NumParams() + args := make([]param.BinaryParam, numParams) + if numParams > 0 { + nullBitmapLen := (numParams + 7) >> 3 + if len(data) < (pos + nullBitmapLen + 1) { + return mysql.ErrMalformPacket + } + nullBitmaps = data[pos : pos+nullBitmapLen] + pos += nullBitmapLen + + // new param bound flag + if data[pos] == 1 { + pos++ + if len(data) < (pos + (numParams << 1)) { + return mysql.ErrMalformPacket + } + + paramTypes = data[pos : pos+(numParams<<1)] + pos += numParams << 1 + paramValues = data[pos:] + // Just the first StmtExecute packet contain parameters type, + // we need save it for further use. + stmt.SetParamsType(paramTypes) + } else { + paramValues = data[pos+1:] + } + + err = parseBinaryParams(args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues, cc.inputDecoder) + // This `.Reset` resets the arguments, so it's fine to just ignore the error (and the it'll be reset again in the following routine) + errReset := stmt.Reset() + if errReset != nil { + logutil.Logger(ctx).Warn("fail to reset statement in EXECUTE command", zap.Error(errReset)) + } + if err != nil { + return errors.Annotate(err, cc.preparedStmt2String(stmtID)) + } + } + + sessVars := cc.ctx.GetSessionVars() + // expiredTaskID is the task ID of the previous statement. When executing a stmt, + // the StmtCtx will be reinit and the TaskID will change. We can compare the StmtCtx.TaskID + // with the previous one to determine whether StmtCtx has been inited for the current stmt. 
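The parameter block parsed above begins with a NULL bitmap of (numParams+7)/8 bytes, one bit per parameter with the least significant bit first, followed by the new-params-bound flag and, when that flag is 1, two type bytes per parameter. A standalone sketch of the bitmap lookup, as the COM_STMT_EXECUTE format is commonly documented:

// paramIsNull reports whether parameter i is NULL according to the
// COM_STMT_EXECUTE NULL bitmap: bit i of byte i/8, least significant first.
func paramIsNull(nullBitmap []byte, i int) bool {
	return nullBitmap[i/8]&(1<<(uint(i)%8)) != 0
}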
+ expiredTaskID := sessVars.StmtCtx.TaskID + err = cc.executePlanCacheStmt(ctx, stmt, args, useCursor) + cc.onExtensionBinaryExecuteEnd(stmt, args, sessVars.StmtCtx.TaskID != expiredTaskID, err) + return err +} + +func (cc *clientConn) executePlanCacheStmt(ctx context.Context, stmt any, args []param.BinaryParam, useCursor bool) (err error) { + ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) + ctx = context.WithValue(ctx, util.ExecDetailsKey, &util.ExecDetails{}) + ctx = context.WithValue(ctx, util.RUDetailsCtxKey, util.NewRUDetails()) + retryable, err := cc.executePreparedStmtAndWriteResult(ctx, stmt.(PreparedStatement), args, useCursor) + if err != nil { + action, txnErr := sessiontxn.GetTxnManager(&cc.ctx).OnStmtErrorForNextAction(ctx, sessiontxn.StmtErrAfterQuery, err) + if txnErr != nil { + return txnErr + } + + if retryable && action == sessiontxn.StmtActionRetryReady { + cc.ctx.GetSessionVars().RetryInfo.Retrying = true + _, err = cc.executePreparedStmtAndWriteResult(ctx, stmt.(PreparedStatement), args, useCursor) + cc.ctx.GetSessionVars().RetryInfo.Retrying = false + return err + } + } + _, allowTiFlashFallback := cc.ctx.GetSessionVars().AllowFallbackToTiKV[kv.TiFlash] + if allowTiFlashFallback && err != nil && errors.ErrorEqual(err, storeerr.ErrTiFlashServerTimeout) && retryable { + // When the TiFlash server seems down, we append a warning to remind the user to check the status of the TiFlash + // server and fallback to TiKV. + prevErr := err + delete(cc.ctx.GetSessionVars().IsolationReadEngines, kv.TiFlash) + defer func() { + cc.ctx.GetSessionVars().IsolationReadEngines[kv.TiFlash] = struct{}{} + }() + _, err = cc.executePreparedStmtAndWriteResult(ctx, stmt.(PreparedStatement), args, useCursor) + // We append warning after the retry because `ResetContextOfStmt` may be called during the retry, which clears warnings. + cc.ctx.GetSessionVars().StmtCtx.AppendError(prevErr) + } + return err +} + +// The first return value indicates whether the call of executePreparedStmtAndWriteResult has no side effect and can be retried. +// Currently the first return value is used to fallback to TiKV when TiFlash is down. +func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stmt PreparedStatement, args []param.BinaryParam, useCursor bool) (bool, error) { + vars := (&cc.ctx).GetSessionVars() + prepStmt, err := vars.GetPreparedStmtByID(uint32(stmt.ID())) + if err != nil { + return true, errors.Annotate(err, cc.preparedStmt2String(uint32(stmt.ID()))) + } + execStmt := &ast.ExecuteStmt{ + BinaryArgs: args, + PrepStmt: prepStmt, + } + + // first, try to clear the left cursor if there is one + if useCursor && stmt.GetCursorActive() { + if stmt.GetResultSet() != nil && stmt.GetResultSet().GetRowIterator() != nil { + stmt.GetResultSet().GetRowIterator().Close() + } + if stmt.GetRowContainer() != nil { + stmt.GetRowContainer().GetMemTracker().Detach() + stmt.GetRowContainer().GetDiskTracker().Detach() + err := stmt.GetRowContainer().Close() + if err != nil { + logutil.Logger(ctx).Error( + "Fail to close rowContainer before executing statement. May cause resource leak", + zap.Error(err)) + } + stmt.StoreRowContainer(nil) + } + stmt.StoreResultSet(nil) + stmt.SetCursorActive(false) + } + + // For the combination of `ComPrepare` and `ComExecute`, the statement name is stored in the client side, and the + // TiDB only has the ID, so don't try to construct an `EXECUTE SOMETHING`. Use the original prepared statement here + // instead. 
+ sql := "" + planCacheStmt, ok := prepStmt.(*plannercore.PlanCacheStmt) + if ok { + sql = planCacheStmt.StmtText + } + execStmt.SetText(charset.EncodingUTF8Impl, sql) + rs, err := (&cc.ctx).ExecuteStmt(ctx, execStmt) + var lazy bool + if rs != nil { + defer func() { + if !lazy { + rs.Close() + } + }() + } + if err != nil { + // If error is returned during the planner phase or the executor.Open + // phase, the rs will be nil, and StmtCtx.MemTracker StmtCtx.DiskTracker + // will not be detached. We need to detach them manually. + if sv := cc.ctx.GetSessionVars(); sv != nil && sv.StmtCtx != nil { + sv.StmtCtx.DetachMemDiskTracker() + } + return true, errors.Annotate(err, cc.preparedStmt2String(uint32(stmt.ID()))) + } + + if rs == nil { + if useCursor { + vars.SetStatusFlag(mysql.ServerStatusCursorExists, false) + } + return false, cc.writeOK(ctx) + } + if planCacheStmt, ok := prepStmt.(*plannercore.PlanCacheStmt); ok { + rs.SetPreparedStmt(planCacheStmt) + } + + // if the client wants to use cursor + // we should hold the ResultSet in PreparedStatement for next stmt_fetch, and only send back ColumnInfo. + // Tell the client cursor exists in server by setting proper serverStatus. + if useCursor { + lazy, err = cc.executeWithCursor(ctx, stmt, rs) + return false, err + } + retryable, err := cc.writeResultSet(ctx, rs, true, cc.ctx.Status(), 0) + if err != nil { + return retryable, errors.Annotate(err, cc.preparedStmt2String(uint32(stmt.ID()))) + } + return false, nil +} + +func (cc *clientConn) executeWithCursor(ctx context.Context, stmt PreparedStatement, rs resultset.ResultSet) (lazy bool, err error) { + vars := (&cc.ctx).GetSessionVars() + if vars.EnableLazyCursorFetch { + // try to execute with lazy cursor fetch + ok, err := cc.executeWithLazyCursor(ctx, stmt, rs) + + // if `ok` is false, should try to execute without lazy cursor fetch + if ok { + return true, err + } + } + + failpoint.Inject("avoidEagerCursorFetch", func() { + failpoint.Return(false, errors.New("failpoint avoids eager cursor fetch")) + }) + cc.initResultEncoder(ctx) + defer cc.rsEncoder.Clean() + // fetch all results of the resultSet, and stored them locally, so that the future `FETCH` command can read + // the rows directly to avoid running executor and accessing shared params/variables in the session + // NOTE: chunk should not be allocated from the connection allocator, which will reset after executing this command + // but the rows are still needed in the following FETCH command. + + // create the row container to manage spill + // this `rowContainer` will be released when the statement (or the connection) is closed. 
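The comment above describes the eager strategy: the whole result set is materialized into a spillable row container before the first FETCH, in contrast to the lazy path tried first when EnableLazyCursorFetch is set. A toy sketch of the drain loop (generic helper, purely illustrative):

package example

// drainAll sketches the eager cursor-fetch strategy: keep calling next until
// it returns an empty batch, accumulating everything up front so later FETCH
// commands read only local data and never touch the executor again.
func drainAll[T any](next func() ([]T, error)) ([]T, error) {
	var all []T
	for {
		batch, err := next()
		if err != nil {
			return nil, err
		}
		if len(batch) == 0 {
			return all, nil
		}
		all = append(all, batch...)
	}
}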
+ rowContainer := chunk.NewRowContainer(rs.FieldTypes(), vars.MaxChunkSize) + rowContainer.GetMemTracker().AttachTo(vars.MemTracker) + rowContainer.GetMemTracker().SetLabel(memory.LabelForCursorFetch) + rowContainer.GetDiskTracker().AttachTo(vars.DiskTracker) + rowContainer.GetDiskTracker().SetLabel(memory.LabelForCursorFetch) + if variable.EnableTmpStorageOnOOM.Load() { + failpoint.Inject("testCursorFetchSpill", func(val failpoint.Value) { + if val, ok := val.(bool); val && ok { + actionSpill := rowContainer.ActionSpillForTest() + defer actionSpill.WaitForTest() + } + }) + action := memory.NewActionWithPriority(rowContainer.ActionSpill(), memory.DefCursorFetchSpillPriority) + vars.MemTracker.FallbackOldAndSetNewAction(action) + } + defer func() { + if err != nil { + rowContainer.GetMemTracker().Detach() + rowContainer.GetDiskTracker().Detach() + errCloseRowContainer := rowContainer.Close() + if errCloseRowContainer != nil { + logutil.Logger(ctx).Error("Fail to close rowContainer in error handler. May cause resource leak", + zap.NamedError("original-error", err), zap.NamedError("close-error", errCloseRowContainer)) + } + } + }() + + for { + chk := rs.NewChunk(nil) + + if err = rs.Next(ctx, chk); err != nil { + return false, err + } + rowCount := chk.NumRows() + if rowCount == 0 { + break + } + + err = rowContainer.Add(chk) + if err != nil { + return false, err + } + } + + reader := chunk.NewRowContainerReader(rowContainer) + defer func() { + if err != nil { + reader.Close() + } + }() + crs := resultset.WrapWithRowContainerCursor(rs, reader) + if cl, ok := crs.(resultset.FetchNotifier); ok { + cl.OnFetchReturned() + } + stmt.StoreRowContainer(rowContainer) + + err = cc.writeExecuteResultWithCursor(ctx, stmt, crs) + return false, err +} + +// executeWithLazyCursor tries to detach the `ResultSet` and make it suitable to execute lazily. +// Be careful that the return value `(bool, error)` has different meaning with other similar functions. The first `bool` represent whether +// the `ResultSet` is suitable for lazy execution. If the return value is `(false, _)`, the `rs` in argument can still be used. If the +// first return value is `true` and `err` is not nil, the `rs` cannot be used anymore and should return the error to the upper layer. +func (cc *clientConn) executeWithLazyCursor(ctx context.Context, stmt PreparedStatement, rs resultset.ResultSet) (ok bool, err error) { + drs, ok, err := rs.TryDetach() + if !ok || err != nil { + return false, err + } + + vars := (&cc.ctx).GetSessionVars() + crs := resultset.WrapWithLazyCursor(drs, vars.InitChunkSize, vars.MaxChunkSize) + err = cc.writeExecuteResultWithCursor(ctx, stmt, crs) + return true, err +} + +// writeExecuteResultWithCursor will store the `ResultSet` in `stmt` and send the column info to the client. The logic is shared between +// lazy cursor fetch and normal(eager) cursor fetch. +func (cc *clientConn) writeExecuteResultWithCursor(ctx context.Context, stmt PreparedStatement, rs resultset.CursorResultSet) error { + var err error + + stmt.StoreResultSet(rs) + stmt.SetCursorActive(true) + defer func() { + if err != nil { + // the resultSet and rowContainer have been closed in former "defer" statement. + stmt.StoreResultSet(nil) + stmt.StoreRowContainer(nil) + stmt.SetCursorActive(false) + } + }() + + if err = cc.writeColumnInfo(rs.Columns()); err != nil { + return err + } + + // explicitly flush columnInfo to client. 
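A simplified, standalone sketch of the eager cursor-fetch approach above: drain the iterator once, keep rows in memory while they fit a small quota, and spill the rest to a temporary file so later FETCH commands can be served without touching the executor again. The row type, quota, and gob encoding are illustrative assumptions, not TiDB's chunk.RowContainer machinery:

package main

import (
	"encoding/gob"
	"fmt"
	"os"
)

type row struct {
	ID   int64
	Name string
}

// spillBuffer keeps rows in memory up to maxInMem and spills the rest to disk.
type spillBuffer struct {
	inMem    []row
	maxInMem int
	file     *os.File
	enc      *gob.Encoder
}

func (b *spillBuffer) add(r row) error {
	if len(b.inMem) < b.maxInMem {
		b.inMem = append(b.inMem, r)
		return nil
	}
	if b.file == nil {
		f, err := os.CreateTemp("", "cursor-spill-*")
		if err != nil {
			return err
		}
		b.file, b.enc = f, gob.NewEncoder(f)
	}
	return b.enc.Encode(r)
}

func (b *spillBuffer) close() {
	if b.file != nil {
		name := b.file.Name()
		b.file.Close()
		os.Remove(name)
	}
}

func main() {
	buf := &spillBuffer{maxInMem: 2} // tiny quota to force a spill
	defer buf.close()
	// Stand-in for the rs.Next loop: fetch every row up front, exactly once.
	for i := int64(1); i <= 5; i++ {
		if err := buf.add(row{ID: i, Name: fmt.Sprintf("row-%d", i)}); err != nil {
			panic(err)
		}
	}
	fmt.Printf("%d rows in memory, spilled=%v\n", len(buf.inMem), buf.file != nil)
}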
+ err = cc.writeEOF(ctx, cc.ctx.Status()) + if err != nil { + return err + } + + return cc.flush(ctx) +} + +func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err error) { + cc.ctx.GetSessionVars().StartTime = time.Now() + cc.ctx.GetSessionVars().ClearAlloc(nil, false) + cc.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusCursorExists, true) + defer cc.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusCursorExists, false) + // Reset the warn count. TODO: consider whether it's better to reset the whole session context/statement context. + if cc.ctx.GetSessionVars().StmtCtx != nil { + cc.ctx.GetSessionVars().StmtCtx.SetWarnings(nil) + } + cc.ctx.GetSessionVars().SysErrorCount = 0 + cc.ctx.GetSessionVars().SysWarningCount = 0 + + stmtID, fetchSize, err := parse.StmtFetchCmd(data) + if err != nil { + return err + } + + stmt := cc.ctx.GetStatement(int(stmtID)) + if stmt == nil { + return errors.Annotate(mysql.NewErr(mysql.ErrUnknownStmtHandler, + strconv.FormatUint(uint64(stmtID), 10), "stmt_fetch"), cc.preparedStmt2String(stmtID)) + } + if !stmt.GetCursorActive() { + return errors.Annotate(mysql.NewErr(mysql.ErrSpCursorNotOpen), cc.preparedStmt2String(stmtID)) + } + // from now on, we have made sure: the statement has an active cursor + // then if facing any error, this cursor should be reset + defer func() { + if err != nil { + errReset := stmt.Reset() + if errReset != nil { + logutil.Logger(ctx).Error("Fail to reset statement in error handler. May cause resource leak.", + zap.NamedError("original-error", err), zap.NamedError("reset-error", errReset)) + } + } + }() + + if topsqlstate.TopSQLEnabled() { + prepareObj, _ := cc.preparedStmtID2CachePreparedStmt(stmtID) + if prepareObj != nil && prepareObj.SQLDigest != nil { + ctx = topsql.AttachAndRegisterSQLInfo(ctx, prepareObj.NormalizedSQL, prepareObj.SQLDigest, false) + } + } + sql := "" + if prepared, ok := cc.ctx.GetStatement(int(stmtID)).(*TiDBStatement); ok { + sql = prepared.sql + } + cc.ctx.SetProcessInfo(sql, time.Now(), mysql.ComStmtExecute, 0) + rs := stmt.GetResultSet() + + _, err = cc.writeResultSet(ctx, rs, true, cc.ctx.Status(), int(fetchSize)) + // if the iterator reached the end before writing result, we could say the `FETCH` command will send EOF + if rs.GetRowIterator().Current(ctx) == rs.GetRowIterator().End() { + // also reset the statement when the cursor reaches the end + // don't overwrite the `err` in outer scope, to avoid redundant `Reset()` in `defer` statement (though, it's not + // a big problem, as the `Reset()` function call is idempotent.) + err := stmt.Reset() + if err != nil { + logutil.Logger(ctx).Error("Fail to reset statement when FETCH command reaches the end. 
May cause resource leak", + zap.NamedError("error", err)) + } + } + if err != nil { + return errors.Annotate(err, cc.preparedStmt2String(stmtID)) + } + + return nil +} + +func (cc *clientConn) handleStmtClose(data []byte) (err error) { + if len(data) < 4 { + return + } + + stmtID := int(binary.LittleEndian.Uint32(data[0:4])) + stmt := cc.ctx.GetStatement(stmtID) + if stmt != nil { + return stmt.Close() + } + + return +} + +func (cc *clientConn) handleStmtSendLongData(data []byte) (err error) { + if len(data) < 6 { + return mysql.ErrMalformPacket + } + + stmtID := int(binary.LittleEndian.Uint32(data[0:4])) + + stmt := cc.ctx.GetStatement(stmtID) + if stmt == nil { + return mysql.NewErr(mysql.ErrUnknownStmtHandler, + strconv.Itoa(stmtID), "stmt_send_longdata") + } + + paramID := int(binary.LittleEndian.Uint16(data[4:6])) + return stmt.AppendParam(paramID, data[6:]) +} + +func (cc *clientConn) handleStmtReset(ctx context.Context, data []byte) (err error) { + // A reset command should reset the statement to the state when it was right after prepare + // Then the following state should be cleared: + // 1.The opened cursor, including the rowContainer (and its cursor/memTracker). + // 2.The argument sent through `SEND_LONG_DATA`. + if len(data) < 4 { + return mysql.ErrMalformPacket + } + + stmtID := int(binary.LittleEndian.Uint32(data[0:4])) + stmt := cc.ctx.GetStatement(stmtID) + if stmt == nil { + return mysql.NewErr(mysql.ErrUnknownStmtHandler, + strconv.Itoa(stmtID), "stmt_reset") + } + err = stmt.Reset() + if err != nil { + // Both server and client cannot handle the error case well, so just left an error and return OK. + // It's fine to receive further `EXECUTE` command even the `Reset` function call failed. + logutil.Logger(ctx).Error("Fail to close statement in error handler of RESET command. 
May cause resource leak", + zap.NamedError("original-error", err), zap.NamedError("close-error", err)) + + return cc.writeOK(ctx) + } + + return cc.writeOK(ctx) +} + +// handleSetOption refer to https://dev.mysql.com/doc/internals/en/com-set-option.html +func (cc *clientConn) handleSetOption(ctx context.Context, data []byte) (err error) { + if len(data) < 2 { + return mysql.ErrMalformPacket + } + + switch binary.LittleEndian.Uint16(data[:2]) { + case 0: + cc.capability |= mysql.ClientMultiStatements + cc.ctx.SetClientCapability(cc.capability) + case 1: + cc.capability &^= mysql.ClientMultiStatements + cc.ctx.SetClientCapability(cc.capability) + default: + return mysql.ErrMalformPacket + } + + if err = cc.writeEOF(ctx, cc.ctx.Status()); err != nil { + return err + } + + return cc.flush(ctx) +} + +func (cc *clientConn) preparedStmt2String(stmtID uint32) string { + sv := cc.ctx.GetSessionVars() + if sv == nil { + return "" + } + sql := parser.Normalize(cc.preparedStmt2StringNoArgs(stmtID), sv.EnableRedactLog) + if m := sv.EnableRedactLog; m != errors.RedactLogEnable { + sql += redact.String(sv.EnableRedactLog, sv.PlanCacheParams.String()) + } + return sql +} + +func (cc *clientConn) preparedStmt2StringNoArgs(stmtID uint32) string { + sv := cc.ctx.GetSessionVars() + if sv == nil { + return "" + } + preparedObj, invalid := cc.preparedStmtID2CachePreparedStmt(stmtID) + if invalid { + return "invalidate PlanCacheStmt type, ID: " + strconv.FormatUint(uint64(stmtID), 10) + } + if preparedObj == nil { + return "prepared statement not found, ID: " + strconv.FormatUint(uint64(stmtID), 10) + } + return preparedObj.PreparedAst.Stmt.Text() +} + +func (cc *clientConn) preparedStmtID2CachePreparedStmt(stmtID uint32) (_ *plannercore.PlanCacheStmt, invalid bool) { + sv := cc.ctx.GetSessionVars() + if sv == nil { + return nil, false + } + preparedPointer, ok := sv.PreparedStmts[stmtID] + if !ok { + // not found + return nil, false + } + preparedObj, ok := preparedPointer.(*plannercore.PlanCacheStmt) + if !ok { + // invalid cache. should never happen. 
+ return nil, true + } + return preparedObj, false +} diff --git a/pkg/server/handler/binding__failpoint_binding__.go b/pkg/server/handler/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..12717ba2a9228 --- /dev/null +++ b/pkg/server/handler/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package handler + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/server/handler/extractorhandler/binding__failpoint_binding__.go b/pkg/server/handler/extractorhandler/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..7301cdac20fb4 --- /dev/null +++ b/pkg/server/handler/extractorhandler/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package extractorhandler + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/server/handler/extractorhandler/extractor.go b/pkg/server/handler/extractorhandler/extractor.go index 4d925e6e5f190..ba5831a6095a2 100644 --- a/pkg/server/handler/extractorhandler/extractor.go +++ b/pkg/server/handler/extractorhandler/extractor.go @@ -56,16 +56,16 @@ func (eh ExtractTaskServeHandler) ServeHTTP(w http.ResponseWriter, req *http.Req handler.WriteError(w, err) return } - failpoint.Inject("extractTaskServeHandler", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("extractTaskServeHandler")); _err_ == nil { if val.(bool) { w.WriteHeader(http.StatusOK) _, err = w.Write([]byte("mock")) if err != nil { handler.WriteError(w, err) } - failpoint.Return() + return } - }) + } name, err := eh.ExtractHandler.ExtractTask(context.Background(), task) if err != nil { diff --git a/pkg/server/handler/extractorhandler/extractor.go__failpoint_stash__ b/pkg/server/handler/extractorhandler/extractor.go__failpoint_stash__ new file mode 100644 index 0000000000000..4d925e6e5f190 --- /dev/null +++ b/pkg/server/handler/extractorhandler/extractor.go__failpoint_stash__ @@ -0,0 +1,169 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package extractorhandler + +import ( + "context" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/server/handler" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.uber.org/zap" +) + +const ( + extractPlanTaskType = "plan" +) + +// ExtractTaskServeHandler is the http serve handler for extract task handler +type ExtractTaskServeHandler struct { + ExtractHandler *domain.ExtractHandle +} + +// NewExtractTaskServeHandler creates a new extract task serve handler +func NewExtractTaskServeHandler(extractHandler *domain.ExtractHandle) *ExtractTaskServeHandler { + return &ExtractTaskServeHandler{ExtractHandler: extractHandler} +} + +// ServeHTTP serves http +func (eh ExtractTaskServeHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + task, isDump, err := buildExtractTask(req) + if err != nil { + logutil.BgLogger().Error("build extract task failed", zap.Error(err)) + handler.WriteError(w, err) + return + } + failpoint.Inject("extractTaskServeHandler", func(val failpoint.Value) { + if val.(bool) { + w.WriteHeader(http.StatusOK) + _, err = w.Write([]byte("mock")) + if err != nil { + handler.WriteError(w, err) + } + failpoint.Return() + } + }) + + name, err := eh.ExtractHandler.ExtractTask(context.Background(), task) + if err != nil { + logutil.BgLogger().Error("extract task failed", zap.Error(err)) + handler.WriteError(w, err) + return + } + w.WriteHeader(http.StatusOK) + if !isDump { + _, err = w.Write([]byte(name)) + if err != nil { + logutil.BgLogger().Error("extract handler failed", zap.Error(err)) + } + return + } + content, err := loadExtractResponse(name) + if err != nil { + logutil.BgLogger().Error("load extract task failed", zap.Error(err)) + handler.WriteError(w, err) + return + } + _, err = w.Write(content) + if err != nil { + handler.WriteError(w, err) + return + } + w.Header().Set("Content-Type", "application/zip") + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s.zip\"", name)) +} + +func loadExtractResponse(name string) ([]byte, error) { + path := filepath.Join(domain.GetExtractTaskDirName(), name) + //nolint: gosec + file, err := os.Open(path) + if err != nil { + return nil, err + } + defer file.Close() + content, err := io.ReadAll(file) + if err != nil { + return nil, err + } + return content, nil +} + +func buildExtractTask(req *http.Request) (*domain.ExtractTask, bool, error) { + extractTaskType := req.URL.Query().Get(handler.Type) + if strings.ToLower(extractTaskType) == extractPlanTaskType { + return buildExtractPlanTask(req) + } + logutil.BgLogger().Error("unknown extract task type") + return nil, false, errors.New("unknown extract task type") +} + +func buildExtractPlanTask(req *http.Request) (*domain.ExtractTask, bool, error) { + beginStr := req.URL.Query().Get(handler.Begin) + endStr := req.URL.Query().Get(handler.End) + var begin time.Time + var err error + if len(beginStr) < 1 { + begin = time.Now().Add(30 * time.Minute) + } else { + begin, err = time.Parse(types.TimeFormat, beginStr) + if err != nil { + logutil.BgLogger().Error("extract task begin time failed", zap.Error(err), zap.String("begin", beginStr)) + return nil, false, err + } + } + var end time.Time + if len(endStr) < 1 { + end = time.Now() + } else { + end, err = time.Parse(types.TimeFormat, endStr) + if err != nil { + logutil.BgLogger().Error("extract task end 
time failed", zap.Error(err), zap.String("end", endStr)) + return nil, false, err + } + } + isDump := extractBoolParam(handler.IsDump, false, req) + + return &domain.ExtractTask{ + ExtractType: domain.ExtractPlanType, + IsBackgroundJob: false, + Begin: begin, + End: end, + SkipStats: extractBoolParam(handler.IsSkipStats, false, req), + UseHistoryView: extractBoolParam(handler.IsHistoryView, true, req), + }, isDump, nil +} + +func extractBoolParam(param string, defaultValue bool, req *http.Request) bool { + str := req.URL.Query().Get(param) + if len(str) < 1 { + return defaultValue + } + v, err := strconv.ParseBool(str) + if err != nil { + return defaultValue + } + return v +} diff --git a/pkg/server/handler/tikv_handler.go b/pkg/server/handler/tikv_handler.go index 1a6339e05be6d..9c9081ce51270 100644 --- a/pkg/server/handler/tikv_handler.go +++ b/pkg/server/handler/tikv_handler.go @@ -259,11 +259,11 @@ func (t *TikvHandlerTool) GetRegionsMeta(regionIDs []uint64) ([]RegionMeta, erro return nil, errors.Trace(err) } - failpoint.Inject("errGetRegionByIDEmpty", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("errGetRegionByIDEmpty")); _err_ == nil { if val.(bool) { region.Meta = nil } - }) + } if region.Meta == nil { return nil, errors.Errorf("region not found for regionID %q", regionID) diff --git a/pkg/server/handler/tikv_handler.go__failpoint_stash__ b/pkg/server/handler/tikv_handler.go__failpoint_stash__ new file mode 100644 index 0000000000000..1a6339e05be6d --- /dev/null +++ b/pkg/server/handler/tikv_handler.go__failpoint_stash__ @@ -0,0 +1,279 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package handler + +import ( + "context" + "encoding/hex" + "fmt" + "net/url" + "strconv" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/session" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + derr "github.com/pingcap/tidb/pkg/store/driver/error" + "github.com/pingcap/tidb/pkg/store/helper" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/table/tables" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/tikv/client-go/v2/tikv" +) + +// TikvHandlerTool is a tool to handle TiKV data. +type TikvHandlerTool struct { + helper.Helper +} + +// NewTikvHandlerTool creates a new TikvHandlerTool. +func NewTikvHandlerTool(store helper.Storage) *TikvHandlerTool { + return &TikvHandlerTool{Helper: *helper.NewHelper(store)} +} + +type mvccKV struct { + Key string `json:"key"` + RegionID uint64 `json:"region_id"` + Value *kvrpcpb.MvccGetByKeyResponse `json:"value"` +} + +// GetRegionIDByKey gets the region id by the key. 
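One detail worth noting around the dump path above: net/http sends response headers with the first Write, so header mutations made after the body has been written have no effect. A reduced sketch of the conventional ordering for a zip download, with a hypothetical file path and task name:

package main

import (
	"fmt"
	"net/http"
	"os"
)

func serveZip(w http.ResponseWriter, _ *http.Request) {
	name := "extract_2024-08-07"          // hypothetical task name
	content, err := os.ReadFile("/tmp/" + name + ".zip") // hypothetical location
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	// Headers must be set before the first Write; they are flushed with it.
	w.Header().Set("Content-Type", "application/zip")
	w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name+".zip"))
	w.WriteHeader(http.StatusOK)
	_, _ = w.Write(content)
}

func main() {
	http.HandleFunc("/extract_task/dump", serveZip)
	_ = http.ListenAndServe("127.0.0.1:8080", nil)
}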
+func (t *TikvHandlerTool) GetRegionIDByKey(encodedKey []byte) (uint64, error) { + keyLocation, err := t.RegionCache.LocateKey(tikv.NewBackofferWithVars(context.Background(), 500, nil), encodedKey) + if err != nil { + return 0, derr.ToTiDBErr(err) + } + return keyLocation.Region.GetID(), nil +} + +// GetHandle gets the handle of the record. +func (t *TikvHandlerTool) GetHandle(tb table.PhysicalTable, params map[string]string, values url.Values) (kv.Handle, error) { + var handle kv.Handle + if intHandleStr, ok := params[Handle]; ok { + if tb.Meta().IsCommonHandle { + return nil, errors.BadRequestf("For clustered index tables, please use query strings to specify the column values.") + } + intHandle, err := strconv.ParseInt(intHandleStr, 0, 64) + if err != nil { + return nil, errors.Trace(err) + } + handle = kv.IntHandle(intHandle) + } else { + tblInfo := tb.Meta() + pkIdx := tables.FindPrimaryIndex(tblInfo) + if pkIdx == nil || !tblInfo.IsCommonHandle { + return nil, errors.BadRequestf("Clustered common handle not found.") + } + cols := tblInfo.Cols() + pkCols := make([]*model.ColumnInfo, 0, len(pkIdx.Columns)) + for _, idxCol := range pkIdx.Columns { + pkCols = append(pkCols, cols[idxCol.Offset]) + } + sc := stmtctx.NewStmtCtx() + sc.SetTimeZone(time.UTC) + pkDts, err := t.formValue2DatumRow(sc, values, pkCols) + if err != nil { + return nil, errors.Trace(err) + } + tablecodec.TruncateIndexValues(tblInfo, pkIdx, pkDts) + var handleBytes []byte + handleBytes, err = codec.EncodeKey(sc.TimeZone(), nil, pkDts...) + err = sc.HandleError(err) + if err != nil { + return nil, errors.Trace(err) + } + handle, err = kv.NewCommonHandle(handleBytes) + if err != nil { + return nil, errors.Trace(err) + } + } + return handle, nil +} + +// GetMvccByIdxValue gets the mvcc by the index value. +func (t *TikvHandlerTool) GetMvccByIdxValue(idx table.Index, values url.Values, idxCols []*model.ColumnInfo, handle kv.Handle) ([]*helper.MvccKV, error) { + // HTTP request is not a database session, set timezone to UTC directly here. + // See https://github.com/pingcap/tidb/blob/master/docs/tidb_http_api.md for more details. + sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) + idxRow, err := t.formValue2DatumRow(sc, values, idxCols) + if err != nil { + return nil, errors.Trace(err) + } + encodedKey, _, err := idx.GenIndexKey(sc.ErrCtx(), sc.TimeZone(), idxRow, handle, nil) + if err != nil { + return nil, errors.Trace(err) + } + data, err := t.GetMvccByEncodedKey(encodedKey) + if err != nil { + return nil, err + } + regionID, err := t.GetRegionIDByKey(encodedKey) + if err != nil { + return nil, err + } + idxData := &helper.MvccKV{Key: strings.ToUpper(hex.EncodeToString(encodedKey)), RegionID: regionID, Value: data} + tablecodec.IndexKey2TempIndexKey(encodedKey) + data, err = t.GetMvccByEncodedKey(encodedKey) + if err != nil { + return nil, err + } + regionID, err = t.GetRegionIDByKey(encodedKey) + if err != nil { + return nil, err + } + tempIdxData := &helper.MvccKV{Key: strings.ToUpper(hex.EncodeToString(encodedKey)), RegionID: regionID, Value: data} + return append([]*helper.MvccKV{}, idxData, tempIdxData), err +} + +// formValue2DatumRow converts URL query string to a Datum Row. 
+func (*TikvHandlerTool) formValue2DatumRow(sc *stmtctx.StatementContext, values url.Values, idxCols []*model.ColumnInfo) ([]types.Datum, error) { + data := make([]types.Datum, len(idxCols)) + for i, col := range idxCols { + colName := col.Name.String() + vals, ok := values[colName] + if !ok { + return nil, errors.BadRequestf("Missing value for index column %s.", colName) + } + + switch len(vals) { + case 0: + data[i].SetNull() + case 1: + bDatum := types.NewStringDatum(vals[0]) + cDatum, err := bDatum.ConvertTo(sc.TypeCtx(), &col.FieldType) + if err != nil { + return nil, errors.Trace(err) + } + data[i] = cDatum + default: + return nil, errors.BadRequestf("Invalid query form for column '%s', it's values are %v."+ + " Column value should be unique for one index record.", colName, vals) + } + } + return data, nil +} + +// GetTableID gets the table ID by the database name and table name. +func (t *TikvHandlerTool) GetTableID(dbName, tableName string) (int64, error) { + tbl, err := t.GetTable(dbName, tableName) + if err != nil { + return 0, errors.Trace(err) + } + return tbl.GetPhysicalID(), nil +} + +// GetTable gets the table by the database name and table name. +func (t *TikvHandlerTool) GetTable(dbName, tableName string) (table.PhysicalTable, error) { + schema, err := t.Schema() + if err != nil { + return nil, errors.Trace(err) + } + tableName, partitionName := ExtractTableAndPartitionName(tableName) + tableVal, err := schema.TableByName(context.Background(), model.NewCIStr(dbName), model.NewCIStr(tableName)) + if err != nil { + return nil, errors.Trace(err) + } + return t.GetPartition(tableVal, partitionName) +} + +// GetPartition gets the partition by the table and partition name. +func (*TikvHandlerTool) GetPartition(tableVal table.Table, partitionName string) (table.PhysicalTable, error) { + if pt, ok := tableVal.(table.PartitionedTable); ok { + if partitionName == "" { + return tableVal.(table.PhysicalTable), errors.New("work on partitioned table, please specify the table name like this: table(partition)") + } + tblInfo := pt.Meta() + pid, err := tables.FindPartitionByName(tblInfo, partitionName) + if err != nil { + return nil, errors.Trace(err) + } + return pt.GetPartition(pid), nil + } + if partitionName != "" { + return nil, fmt.Errorf("%s is not a partitionted table", tableVal.Meta().Name) + } + return tableVal.(table.PhysicalTable), nil +} + +// Schema gets the schema. +func (t *TikvHandlerTool) Schema() (infoschema.InfoSchema, error) { + dom, err := session.GetDomain(t.Store) + if err != nil { + return nil, err + } + return dom.InfoSchema(), nil +} + +// HandleMvccGetByHex handles the request of getting mvcc by hex encoded key. 
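The MVCC handlers above shuttle raw keys through hex strings, upper-cased on the way out and decoded on the way in. A tiny round-trip sketch with made-up key bytes:

package main

import (
	"encoding/hex"
	"fmt"
	"strings"
)

func main() {
	key := []byte("t\x80\x00\x00\x00\x00\x00\x00\x01_r") // made-up key bytes for illustration
	encoded := strings.ToUpper(hex.EncodeToString(key))
	fmt.Println(encoded)

	decoded, err := hex.DecodeString(encoded) // DecodeString accepts upper or lower case
	if err != nil {
		panic(err)
	}
	fmt.Println(string(decoded) == string(key)) // true
}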
+func (t *TikvHandlerTool) HandleMvccGetByHex(params map[string]string) (*mvccKV, error) { + encodedKey, err := hex.DecodeString(params[HexKey]) + if err != nil { + return nil, errors.Trace(err) + } + data, err := t.GetMvccByEncodedKey(encodedKey) + if err != nil { + return nil, errors.Trace(err) + } + regionID, err := t.GetRegionIDByKey(encodedKey) + if err != nil { + return nil, err + } + return &mvccKV{Key: strings.ToUpper(params[HexKey]), Value: data, RegionID: regionID}, nil +} + +// RegionMeta contains a region's peer detail +type RegionMeta struct { + ID uint64 `json:"region_id"` + Leader *metapb.Peer `json:"leader"` + Peers []*metapb.Peer `json:"peers"` + RegionEpoch *metapb.RegionEpoch `json:"region_epoch"` +} + +// GetRegionsMeta gets regions meta by regionIDs +func (t *TikvHandlerTool) GetRegionsMeta(regionIDs []uint64) ([]RegionMeta, error) { + regions := make([]RegionMeta, len(regionIDs)) + for i, regionID := range regionIDs { + region, err := t.RegionCache.PDClient().GetRegionByID(context.TODO(), regionID) + if err != nil { + return nil, errors.Trace(err) + } + + failpoint.Inject("errGetRegionByIDEmpty", func(val failpoint.Value) { + if val.(bool) { + region.Meta = nil + } + }) + + if region.Meta == nil { + return nil, errors.Errorf("region not found for regionID %q", regionID) + } + regions[i] = RegionMeta{ + ID: regionID, + Leader: region.Leader, + Peers: region.Meta.Peers, + RegionEpoch: region.Meta.RegionEpoch, + } + } + return regions, nil +} diff --git a/pkg/server/http_status.go b/pkg/server/http_status.go index 335b3855dc68c..e77fe87e0bbad 100644 --- a/pkg/server/http_status.go +++ b/pkg/server/http_status.go @@ -420,14 +420,14 @@ func (s *Server) startHTTPServer() { }) // failpoint is enabled only for tests so we can add some http APIs here for tests. - failpoint.Inject("enableTestAPI", func() { + if _, _err_ := failpoint.Eval(_curpkg_("enableTestAPI")); _err_ == nil { router.PathPrefix("/fail/").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.URL.Path = strings.TrimPrefix(r.URL.Path, "/fail") new(failpoint.HttpHandler).ServeHTTP(w, r) }) router.Handle("/test/{mod}/{op}", tikvhandler.NewTestHandler(tikvHandlerTool, 0)) - }) + } // ddlHook is enabled only for tests so we can substitute the callback in the DDL. router.Handle("/test/ddl/hook", tikvhandler.DDLHookHandler{}) diff --git a/pkg/server/http_status.go__failpoint_stash__ b/pkg/server/http_status.go__failpoint_stash__ new file mode 100644 index 0000000000000..335b3855dc68c --- /dev/null +++ b/pkg/server/http_status.go__failpoint_stash__ @@ -0,0 +1,613 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package server + +import ( + "archive/zip" + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/http/pprof" + "net/url" + "runtime" + rpprof "runtime/pprof" + "strconv" + "strings" + "sync" + "time" + + "github.com/gorilla/mux" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/fn" + pb "github.com/pingcap/kvproto/pkg/autoid" + autoid "github.com/pingcap/tidb/pkg/autoid_service" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/server/handler" + "github.com/pingcap/tidb/pkg/server/handler/optimizor" + "github.com/pingcap/tidb/pkg/server/handler/tikvhandler" + "github.com/pingcap/tidb/pkg/server/handler/ttlhandler" + util2 "github.com/pingcap/tidb/pkg/server/internal/util" + "github.com/pingcap/tidb/pkg/session" + "github.com/pingcap/tidb/pkg/store" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/cpuprofile" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/printer" + "github.com/pingcap/tidb/pkg/util/versioninfo" + "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/soheilhy/cmux" + "github.com/tiancaiamao/appdash/traceapp" + "go.uber.org/zap" + "google.golang.org/grpc/channelz/service" + static "sourcegraph.com/sourcegraph/appdash-data" +) + +const defaultStatusPort = 10080 + +func (s *Server) startStatusHTTP() error { + err := s.initHTTPListener() + if err != nil { + return err + } + go s.startHTTPServer() + return nil +} + +func serveError(w http.ResponseWriter, status int, txt string) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Go-Pprof", "1") + w.Header().Del("Content-Disposition") + w.WriteHeader(status) + _, err := fmt.Fprintln(w, txt) + terror.Log(err) +} + +func sleepWithCtx(ctx context.Context, d time.Duration) { + select { + case <-time.After(d): + case <-ctx.Done(): + } +} + +func (s *Server) listenStatusHTTPServer() error { + s.statusAddr = net.JoinHostPort(s.cfg.Status.StatusHost, strconv.Itoa(int(s.cfg.Status.StatusPort))) + if s.cfg.Status.StatusPort == 0 && !RunInGoTest { + s.statusAddr = net.JoinHostPort(s.cfg.Status.StatusHost, strconv.Itoa(defaultStatusPort)) + } + + logutil.BgLogger().Info("for status and metrics report", zap.String("listening on addr", s.statusAddr)) + clusterSecurity := s.cfg.Security.ClusterSecurity() + tlsConfig, err := clusterSecurity.ToTLSConfig() + if err != nil { + logutil.BgLogger().Error("invalid TLS config", zap.Error(err)) + return errors.Trace(err) + } + tlsConfig = s.SetCNChecker(tlsConfig) + + if tlsConfig != nil { + // we need to manage TLS here for cmux to distinguish between HTTP and gRPC. 
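The comment above is the reason the status listener is kept as a raw TCP/TLS listener: cmux later splits it between HTTP and gRPC by sniffing each incoming connection. A standalone sketch of that split, with a hypothetical local port and an empty gRPC server:

package main

import (
	"net"
	"net/http"

	"github.com/soheilhy/cmux"
	"google.golang.org/grpc"
)

func main() {
	l, err := net.Listen("tcp", "127.0.0.1:10080") // hypothetical status port
	if err != nil {
		panic(err)
	}
	m := cmux.New(l)
	// Order matters: plain HTTP/1 connections are matched first,
	// everything else is handed to the gRPC server.
	httpL := m.Match(cmux.HTTP1Fast())
	grpcL := m.Match(cmux.Any())

	httpSrv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})}
	grpcSrv := grpc.NewServer() // no services registered in this sketch

	go func() { _ = httpSrv.Serve(httpL) }()
	go func() { _ = grpcSrv.Serve(grpcL) }()
	_ = m.Serve() // blocks, demultiplexing connections to the two servers
}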
+ s.statusListener, err = tls.Listen("tcp", s.statusAddr, tlsConfig) + } else { + s.statusListener, err = net.Listen("tcp", s.statusAddr) + } + if err != nil { + logutil.BgLogger().Info("listen failed", zap.Error(err)) + return errors.Trace(err) + } else if RunInGoTest && s.cfg.Status.StatusPort == 0 { + s.statusAddr = s.statusListener.Addr().String() + s.cfg.Status.StatusPort = uint(s.statusListener.Addr().(*net.TCPAddr).Port) + } + return nil +} + +// Ballast try to reduce the GC frequency by using Ballast Object +type Ballast struct { + ballast []byte + ballastLock sync.Mutex + + maxSize int +} + +func newBallast(maxSize int) *Ballast { + var b Ballast + b.maxSize = 1024 * 1024 * 1024 * 2 + if maxSize > 0 { + b.maxSize = maxSize + } else { + // we try to use the total amount of ram as a reference to set the default ballastMaxSz + // since the fatal throw "runtime: out of memory" would never yield to `recover` + totalRAMSz, err := memory.MemTotal() + if err != nil { + logutil.BgLogger().Error("failed to get the total amount of RAM on this system", zap.Error(err)) + } else { + maxSzAdvice := totalRAMSz >> 2 + if uint64(b.maxSize) > maxSzAdvice { + b.maxSize = int(maxSzAdvice) + } + } + } + return &b +} + +// GetSize get the size of ballast object +func (b *Ballast) GetSize() int { + var sz int + b.ballastLock.Lock() + sz = len(b.ballast) + b.ballastLock.Unlock() + return sz +} + +// SetSize set the size of ballast object +func (b *Ballast) SetSize(newSz int) error { + if newSz < 0 { + return fmt.Errorf("newSz cannot be negative: %d", newSz) + } + if newSz > b.maxSize { + return fmt.Errorf("newSz cannot be bigger than %d but it has value %d", b.maxSize, newSz) + } + b.ballastLock.Lock() + b.ballast = make([]byte, newSz) + b.ballastLock.Unlock() + return nil +} + +// GenHTTPHandler generate a HTTP handler to get/set the size of this ballast object +func (b *Ballast) GenHTTPHandler() func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + _, err := w.Write([]byte(strconv.Itoa(b.GetSize()))) + terror.Log(err) + case http.MethodPost: + body, err := io.ReadAll(r.Body) + if err != nil { + terror.Log(err) + return + } + newSz, err := strconv.Atoi(string(body)) + if err == nil { + err = b.SetSize(newSz) + } + if err != nil { + w.WriteHeader(http.StatusBadRequest) + errStr := err.Error() + if _, err := w.Write([]byte(errStr)); err != nil { + terror.Log(err) + } + return + } + } + } +} + +func (s *Server) startHTTPServer() { + router := mux.NewRouter() + + router.HandleFunc("/status", s.handleStatus).Name("Status") + // HTTP path for prometheus. + router.Handle("/metrics", promhttp.Handler()).Name("Metrics") + + // HTTP path for dump statistics. + router.Handle("/stats/dump/{db}/{table}", s.newStatsHandler()). + Name("StatsDump") + router.Handle("/stats/dump/{db}/{table}/{snapshot}", s.newStatsHistoryHandler()). 
+ Name("StatsHistoryDump") + + router.Handle("/plan_replayer/dump/{filename}", s.newPlanReplayerHandler()).Name("PlanReplayerDump") + router.Handle("/extract_task/dump", s.newExtractServeHandler()).Name("ExtractTaskDump") + + router.Handle("/optimize_trace/dump/{filename}", s.newOptimizeTraceHandler()).Name("OptimizeTraceDump") + + tikvHandlerTool := s.NewTikvHandlerTool() + router.Handle("/settings", tikvhandler.NewSettingsHandler(tikvHandlerTool)).Name("Settings") + router.Handle("/binlog/recover", tikvhandler.BinlogRecover{}).Name("BinlogRecover") + + router.Handle("/schema", tikvhandler.NewSchemaHandler(tikvHandlerTool)).Name("Schema") + router.Handle("/schema/{db}", tikvhandler.NewSchemaHandler(tikvHandlerTool)) + router.Handle("/schema/{db}/{table}", tikvhandler.NewSchemaHandler(tikvHandlerTool)) + router.Handle("/tables/{colID}/{colTp}/{colFlag}/{colLen}", tikvhandler.ValueHandler{}) + + router.Handle("/schema_storage", tikvhandler.NewSchemaStorageHandler(tikvHandlerTool)).Name("Schema Storage") + router.Handle("/schema_storage/{db}", tikvhandler.NewSchemaStorageHandler(tikvHandlerTool)) + router.Handle("/schema_storage/{db}/{table}", tikvhandler.NewSchemaStorageHandler(tikvHandlerTool)) + + router.Handle("/ddl/history", tikvhandler.NewDDLHistoryJobHandler(tikvHandlerTool)).Name("DDL_History") + router.Handle("/ddl/owner/resign", tikvhandler.NewDDLResignOwnerHandler(tikvHandlerTool.Store.(kv.Storage))).Name("DDL_Owner_Resign") + + // HTTP path for get the TiDB config + router.Handle("/config", fn.Wrap(func() (*config.Config, error) { + return config.GetGlobalConfig(), nil + })) + router.Handle("/labels", tikvhandler.LabelHandler{}).Name("Labels") + + // HTTP path for get server info. + router.Handle("/info", tikvhandler.NewServerInfoHandler(tikvHandlerTool)).Name("Info") + router.Handle("/info/all", tikvhandler.NewAllServerInfoHandler(tikvHandlerTool)).Name("InfoALL") + // HTTP path for get db and table info that is related to the tableID. + router.Handle("/db-table/{tableID}", tikvhandler.NewDBTableHandler(tikvHandlerTool)) + // HTTP path for get table tiflash replica info. + router.Handle("/tiflash/replica-deprecated", tikvhandler.NewFlashReplicaHandler(tikvHandlerTool)) + + // HTTP path for upgrade operations. + router.Handle("/upgrade/{op}", handler.NewClusterUpgradeHandler(tikvHandlerTool.Store.(kv.Storage))).Name("upgrade operations") + + if s.cfg.Store == "tikv" { + // HTTP path for tikv. 
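Many of the routes above carry path variables such as {db} and {table}; gorilla/mux exposes them via mux.Vars, and naming a route is what lets the index page walk and list it later. A small sketch with a hypothetical handler and port:

package main

import (
	"fmt"
	"net/http"

	"github.com/gorilla/mux"
)

func main() {
	router := mux.NewRouter()
	router.HandleFunc("/tables/{db}/{table}/regions", func(w http.ResponseWriter, r *http.Request) {
		vars := mux.Vars(r) // {"db": "...", "table": "..."} extracted from the path
		fmt.Fprintf(w, "db=%s table=%s\n", vars["db"], vars["table"])
	}).Name("TableRegions")

	_ = http.ListenAndServe("127.0.0.1:8080", router)
}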
+ router.Handle("/tables/{db}/{table}/regions", tikvhandler.NewTableHandler(tikvHandlerTool, tikvhandler.OpTableRegions)) + router.Handle("/tables/{db}/{table}/ranges", tikvhandler.NewTableHandler(tikvHandlerTool, tikvhandler.OpTableRanges)) + router.Handle("/tables/{db}/{table}/scatter", tikvhandler.NewTableHandler(tikvHandlerTool, tikvhandler.OpTableScatter)) + router.Handle("/tables/{db}/{table}/stop-scatter", tikvhandler.NewTableHandler(tikvHandlerTool, tikvhandler.OpStopTableScatter)) + router.Handle("/tables/{db}/{table}/disk-usage", tikvhandler.NewTableHandler(tikvHandlerTool, tikvhandler.OpTableDiskUsage)) + router.Handle("/regions/meta", tikvhandler.NewRegionHandler(tikvHandlerTool)).Name("RegionsMeta") + router.Handle("/regions/hot", tikvhandler.NewRegionHandler(tikvHandlerTool)).Name("RegionHot") + router.Handle("/regions/{regionID}", tikvhandler.NewRegionHandler(tikvHandlerTool)) + } + + // HTTP path for get MVCC info + router.Handle("/mvcc/key/{db}/{table}", tikvhandler.NewMvccTxnHandler(tikvHandlerTool, tikvhandler.OpMvccGetByKey)) + router.Handle("/mvcc/key/{db}/{table}/{handle}", tikvhandler.NewMvccTxnHandler(tikvHandlerTool, tikvhandler.OpMvccGetByKey)) + router.Handle("/mvcc/txn/{startTS}/{db}/{table}", tikvhandler.NewMvccTxnHandler(tikvHandlerTool, tikvhandler.OpMvccGetByTxn)) + router.Handle("/mvcc/hex/{hexKey}", tikvhandler.NewMvccTxnHandler(tikvHandlerTool, tikvhandler.OpMvccGetByHex)) + router.Handle("/mvcc/index/{db}/{table}/{index}", tikvhandler.NewMvccTxnHandler(tikvHandlerTool, tikvhandler.OpMvccGetByIdx)) + router.Handle("/mvcc/index/{db}/{table}/{index}/{handle}", tikvhandler.NewMvccTxnHandler(tikvHandlerTool, tikvhandler.OpMvccGetByIdx)) + + // HTTP path for generate metric profile. + router.Handle("/metrics/profile", tikvhandler.NewProfileHandler(tikvHandlerTool)) + // HTTP path for web UI. + if host, port, err := net.SplitHostPort(s.statusAddr); err == nil { + if host == "" { + host = "localhost" + } + baseURL := &url.URL{ + Scheme: util.InternalHTTPSchema(), + Host: fmt.Sprintf("%s:%s", host, port), + } + router.HandleFunc("/web/trace", traceapp.HandleTiDB).Name("Trace Viewer") + sr := router.PathPrefix("/web/trace/").Subrouter() + if _, err := traceapp.New(traceapp.NewRouter(sr), baseURL); err != nil { + logutil.BgLogger().Error("new failed", zap.Error(err)) + } + router.PathPrefix("/static/").Handler(http.StripPrefix("/static", http.FileServer(static.Data))) + } + + router.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) + router.HandleFunc("/debug/pprof/profile", cpuprofile.ProfileHTTPHandler) + router.HandleFunc("/debug/pprof/symbol", pprof.Symbol) + router.HandleFunc("/debug/pprof/trace", pprof.Trace) + // Other /debug/pprof paths not covered above are redirected to pprof.Index. 
+ router.PathPrefix("/debug/pprof/").HandlerFunc(pprof.Index) + + ballast := newBallast(s.cfg.MaxBallastObjectSize) + { + err := ballast.SetSize(s.cfg.BallastObjectSize) + if err != nil { + logutil.BgLogger().Error("set initial ballast object size failed", zap.Error(err)) + } + } + router.HandleFunc("/debug/ballast-object-sz", ballast.GenHTTPHandler()) + + router.HandleFunc("/debug/gogc", func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + _, err := w.Write([]byte(strconv.Itoa(util.GetGOGC()))) + terror.Log(err) + case http.MethodPost: + body, err := io.ReadAll(r.Body) + if err != nil { + terror.Log(err) + return + } + + val, err := strconv.Atoi(string(body)) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + if _, err := w.Write([]byte(err.Error())); err != nil { + terror.Log(err) + } + return + } + + util.SetGOGC(val) + } + }) + + router.HandleFunc("/debug/zip", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="tidb_debug"`+time.Now().Format("20060102150405")+".zip")) + + // dump goroutine/heap/mutex + items := []struct { + name string + gc int + debug int + second int + }{ + {name: "goroutine", debug: 2}, + {name: "heap", gc: 1}, + {name: "mutex"}, + } + zw := zip.NewWriter(w) + for _, item := range items { + p := rpprof.Lookup(item.name) + if p == nil { + serveError(w, http.StatusNotFound, "Unknown profile") + return + } + if item.gc > 0 { + runtime.GC() + } + fw, err := zw.Create(item.name) + if err != nil { + serveError(w, http.StatusInternalServerError, fmt.Sprintf("Create zipped %s fail: %v", item.name, err)) + return + } + err = p.WriteTo(fw, item.debug) + terror.Log(err) + } + + // dump profile + fw, err := zw.Create("profile") + if err != nil { + serveError(w, http.StatusInternalServerError, fmt.Sprintf("Create zipped %s fail: %v", "profile", err)) + return + } + pc := cpuprofile.NewCollector() + if err := pc.StartCPUProfile(fw); err != nil { + serveError(w, http.StatusInternalServerError, + fmt.Sprintf("Could not enable CPU profiling: %s", err)) + return + } + sec, err := strconv.ParseInt(r.FormValue("seconds"), 10, 64) + if sec <= 0 || err != nil { + sec = 10 + } + sleepWithCtx(r.Context(), time.Duration(sec)*time.Second) + err = pc.StopCPUProfile() + if err != nil { + serveError(w, http.StatusInternalServerError, + fmt.Sprintf("Could not enable CPU profiling: %s", err)) + return + } + + // dump config + fw, err = zw.Create("config") + if err != nil { + serveError(w, http.StatusInternalServerError, fmt.Sprintf("Create zipped %s fail: %v", "config", err)) + return + } + js, err := json.MarshalIndent(config.GetGlobalConfig(), "", " ") + if err != nil { + serveError(w, http.StatusInternalServerError, fmt.Sprintf("get config info fail%v", err)) + return + } + _, err = fw.Write(js) + terror.Log(err) + + // dump version + fw, err = zw.Create("version") + if err != nil { + serveError(w, http.StatusInternalServerError, fmt.Sprintf("Create zipped %s fail: %v", "version", err)) + return + } + _, err = fw.Write([]byte(printer.GetTiDBInfo())) + terror.Log(err) + + err = zw.Close() + terror.Log(err) + }) + + // failpoint is enabled only for tests so we can add some http APIs here for tests. 
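The /debug/gogc handler above adjusts the runtime GC target at run time; the standard-library knob it presumably wraps is debug.SetGCPercent. A minimal sketch of the same GET/POST shape (route and port are illustrative, and the package-level variable is left unsynchronized for brevity):

package main

import (
	"io"
	"net/http"
	"runtime/debug"
	"strconv"
)

var gogc = 100 // current value; debug.SetGCPercent returns only the previous one

func handleGOGC(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		_, _ = w.Write([]byte(strconv.Itoa(gogc)))
	case http.MethodPost:
		body, err := io.ReadAll(r.Body)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		val, err := strconv.Atoi(string(body))
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		debug.SetGCPercent(val) // takes effect for subsequent collections
		gogc = val
	}
}

func main() {
	http.HandleFunc("/debug/gogc", handleGOGC)
	_ = http.ListenAndServe("127.0.0.1:8080", nil)
}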
+ failpoint.Inject("enableTestAPI", func() { + router.PathPrefix("/fail/").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.URL.Path = strings.TrimPrefix(r.URL.Path, "/fail") + new(failpoint.HttpHandler).ServeHTTP(w, r) + }) + + router.Handle("/test/{mod}/{op}", tikvhandler.NewTestHandler(tikvHandlerTool, 0)) + }) + + // ddlHook is enabled only for tests so we can substitute the callback in the DDL. + router.Handle("/test/ddl/hook", tikvhandler.DDLHookHandler{}) + + // ttlJobTriggerHandler is enabled only for tests, so we can accelerate the schedule of TTL job + router.Handle("/test/ttl/trigger/{db}/{table}", ttlhandler.NewTTLJobTriggerHandler(tikvHandlerTool.Store.(kv.Storage))) + + var ( + httpRouterPage bytes.Buffer + pathTemplate string + err error + ) + httpRouterPage.WriteString("TiDB Status and Metrics Report

</title></head><body><h1>TiDB Status and Metrics Report</h1><table>")
+	err = router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error {
+		pathTemplate, err = route.GetPathTemplate()
+		if err != nil {
+			logutil.BgLogger().Error("get HTTP router path failed", zap.Error(err))
+		}
+		name := route.GetName()
+		// If the name attribute is not set, GetName returns "".
+		// "traceapp.xxx" are introduced by the traceapp package and are also ignored.
+		if name != "" && !strings.HasPrefix(name, "traceapp") && err == nil {
+			httpRouterPage.WriteString("<tr><td><a href='" + pathTemplate + "'>" + name + "</a></td></tr>")
+		}
+		return nil
+	})
+	if err != nil {
+		logutil.BgLogger().Error("generate root failed", zap.Error(err))
+	}
+	httpRouterPage.WriteString("<tr><td><a href='/debug/pprof/'>Debug</a></td></tr>")
+	httpRouterPage.WriteString("</table></body></html>
") + router.HandleFunc("/", func(responseWriter http.ResponseWriter, _ *http.Request) { + _, err = responseWriter.Write(httpRouterPage.Bytes()) + if err != nil { + logutil.BgLogger().Error("write HTTP index page failed", zap.Error(err)) + } + }) + + serverMux := http.NewServeMux() + serverMux.Handle("/", router) + s.startStatusServerAndRPCServer(serverMux) +} + +func (s *Server) startStatusServerAndRPCServer(serverMux *http.ServeMux) { + m := cmux.New(s.statusListener) + // Match connections in order: + // First HTTP, and otherwise grpc. + httpL := m.Match(cmux.HTTP1Fast()) + grpcL := m.Match(cmux.Any()) + + statusServer := &http.Server{Addr: s.statusAddr, Handler: util2.NewCorsHandler(serverMux, s.cfg)} + grpcServer := NewRPCServer(s.cfg, s.dom, s) + service.RegisterChannelzServiceToServer(grpcServer) + if s.cfg.Store == "tikv" { + keyspaceName := config.GetGlobalKeyspaceName() + for { + var fullPath string + if keyspaceName == "" { + fullPath = fmt.Sprintf("%s://%s", s.cfg.Store, s.cfg.Path) + } else { + fullPath = fmt.Sprintf("%s://%s?keyspaceName=%s", s.cfg.Store, s.cfg.Path, keyspaceName) + } + store, err := store.New(fullPath) + if err != nil { + logutil.BgLogger().Error("new tikv store fail", zap.Error(err)) + break + } + ebd, ok := store.(kv.EtcdBackend) + if !ok { + break + } + etcdAddr, err := ebd.EtcdAddrs() + if err != nil { + logutil.BgLogger().Error("tikv store not etcd background", zap.Error(err)) + break + } + selfAddr := net.JoinHostPort(s.cfg.AdvertiseAddress, strconv.Itoa(int(s.cfg.Status.StatusPort))) + service := autoid.New(selfAddr, etcdAddr, store, ebd.TLSConfig()) + logutil.BgLogger().Info("register auto service at", zap.String("addr", selfAddr)) + pb.RegisterAutoIDAllocServer(grpcServer, service) + s.autoIDService = service + break + } + } + + s.statusServer = statusServer + s.grpcServer = grpcServer + + go util.WithRecovery(func() { + err := grpcServer.Serve(grpcL) + logutil.BgLogger().Error("grpc server error", zap.Error(err)) + }, nil) + + go util.WithRecovery(func() { + err := statusServer.Serve(httpL) + logutil.BgLogger().Error("http server error", zap.Error(err)) + }, nil) + + err := m.Serve() + if err != nil { + logutil.BgLogger().Error("start status/rpc server error", zap.Error(err)) + } +} + +// SetCNChecker set the CN checker for server. +func (s *Server) SetCNChecker(tlsConfig *tls.Config) *tls.Config { + if tlsConfig != nil && len(s.cfg.Security.ClusterVerifyCN) != 0 { + checkCN := make(map[string]struct{}) + for _, cn := range s.cfg.Security.ClusterVerifyCN { + cn = strings.TrimSpace(cn) + checkCN[cn] = struct{}{} + } + tlsConfig.VerifyPeerCertificate = func(_ [][]byte, verifiedChains [][]*x509.Certificate) error { + for _, chain := range verifiedChains { + if len(chain) != 0 { + if _, match := checkCN[chain[0].Subject.CommonName]; match { + return nil + } + } + } + return errors.Errorf("client certificate authentication failed. The Common Name from the client certificate was not found in the configuration cluster-verify-cn with value: %s", s.cfg.Security.ClusterVerifyCN) + } + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + } + return tlsConfig +} + +// Status of TiDB. +type Status struct { + Connections int `json:"connections"` + Version string `json:"version"` + GitHash string `json:"git_hash"` +} + +func (s *Server) handleStatus(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + // If the server is in the process of shutting down, return a non-200 status. 
+ // It is important not to return Status{} as acquiring the s.ConnectionCount() + // acquires a lock that may already be held by the shutdown process. + if !s.health.Load() { + w.WriteHeader(http.StatusInternalServerError) + return + } + st := Status{ + Connections: s.ConnectionCount(), + Version: mysql.ServerVersion, + GitHash: versioninfo.TiDBGitHash, + } + js, err := json.Marshal(st) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + logutil.BgLogger().Error("encode json failed", zap.Error(err)) + return + } + _, err = w.Write(js) + terror.Log(errors.Trace(err)) +} + +func (s *Server) newStatsHandler() *optimizor.StatsHandler { + store, ok := s.driver.(*TiDBDriver) + if !ok { + panic("Illegal driver") + } + + do, err := session.GetDomain(store.store) + if err != nil { + panic("Failed to get domain") + } + return optimizor.NewStatsHandler(do) +} + +func (s *Server) newStatsHistoryHandler() *optimizor.StatsHistoryHandler { + store, ok := s.driver.(*TiDBDriver) + if !ok { + panic("Illegal driver") + } + + do, err := session.GetDomain(store.store) + if err != nil { + panic("Failed to get domain") + } + return optimizor.NewStatsHistoryHandler(do) +} diff --git a/pkg/session/binding__failpoint_binding__.go b/pkg/session/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..9ef59b452261c --- /dev/null +++ b/pkg/session/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package session + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/session/nontransactional.go b/pkg/session/nontransactional.go index 390de607cdaf4..fe7eb3680c432 100644 --- a/pkg/session/nontransactional.go +++ b/pkg/session/nontransactional.go @@ -417,11 +417,11 @@ func doOneJob(ctx context.Context, job *job, totalJobCount int, options statemen rs, err := se.ExecuteStmt(ctx, options.stmt.DMLStmt) // collect errors - failpoint.Inject("batchDMLError", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("batchDMLError")); _err_ == nil { if val.(bool) { err = errors.New("injected batch(non-transactional) DML error") } - }) + } if err != nil { logutil.Logger(ctx).Error("Non-transactional DML SQL failed", zap.String("job", dmlSQLInLog), zap.Error(err), zap.Int("jobID", job.jobID), zap.Int("jobSize", job.jobSize)) job.err = err diff --git a/pkg/session/nontransactional.go__failpoint_stash__ b/pkg/session/nontransactional.go__failpoint_stash__ new file mode 100644 index 0000000000000..390de607cdaf4 --- /dev/null +++ b/pkg/session/nontransactional.go__failpoint_stash__ @@ -0,0 +1,847 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
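SetCNChecker above pins intra-cluster TLS connections to an allow-list of certificate Common Names by combining a custom VerifyPeerCertificate callback with mandatory client certificates. A reduced sketch, with certificate loading elided and hypothetical CNs:

package main

import (
	"crypto/tls"
	"crypto/x509"
	"errors"
)

// cnCheckedConfig rejects client certificates whose Common Name is not in the
// allow-list, mirroring the cluster-verify-cn idea above.
func cnCheckedConfig(base *tls.Config, allowed ...string) *tls.Config {
	allow := make(map[string]struct{}, len(allowed))
	for _, cn := range allowed {
		allow[cn] = struct{}{}
	}
	base.ClientAuth = tls.RequireAndVerifyClientCert
	base.VerifyPeerCertificate = func(_ [][]byte, verifiedChains [][]*x509.Certificate) error {
		for _, chain := range verifiedChains {
			if len(chain) > 0 {
				if _, ok := allow[chain[0].Subject.CommonName]; ok {
					return nil
				}
			}
		}
		return errors.New("client certificate CN not in the allow-list")
	}
	return base
}

func main() {
	// Certificates and ClientCAs would still need to be populated before serving.
	cfg := cnCheckedConfig(&tls.Config{}, "tidb-server", "pd-server") // hypothetical CNs
	_ = cfg
}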
+ +package session + +import ( + "context" + "fmt" + "math" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/format" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/opcode" + "github.com/pingcap/tidb/pkg/planner/core" + session_metrics "github.com/pingcap/tidb/pkg/session/metrics" + sessiontypes "github.com/pingcap/tidb/pkg/session/types" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/types" + driver "github.com/pingcap/tidb/pkg/types/parser_driver" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/collate" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/redact" + "github.com/pingcap/tidb/pkg/util/sqlexec" + "go.uber.org/zap" +) + +// ErrNonTransactionalJobFailure is the error when a non-transactional job fails. The error is returned and following jobs are canceled. +var ErrNonTransactionalJobFailure = dbterror.ClassSession.NewStd(errno.ErrNonTransactionalJobFailure) + +// job: handle keys in [start, end] +type job struct { + start types.Datum + end types.Datum + err error + jobID int + jobSize int // it can be inaccurate if there are concurrent writes + sql string +} + +// statementBuildInfo contains information that is needed to build the split statement in a job +type statementBuildInfo struct { + stmt *ast.NonTransactionalDMLStmt + shardColumnType types.FieldType + shardColumnRefer *ast.ResultField + originalCondition ast.ExprNode +} + +func (j job) String(redacted string) string { + return fmt.Sprintf("job id: %d, estimated size: %d, sql: %s", j.jobID, j.jobSize, redact.String(redacted, j.sql)) +} + +// HandleNonTransactionalDML is the entry point for a non-transactional DML statement +func HandleNonTransactionalDML(ctx context.Context, stmt *ast.NonTransactionalDMLStmt, se sessiontypes.Session) (sqlexec.RecordSet, error) { + sessVars := se.GetSessionVars() + originalReadStaleness := se.GetSessionVars().ReadStaleness + // NT-DML is a write operation, and should not be affected by read_staleness that is supposed to affect only SELECT. + sessVars.ReadStaleness = 0 + defer func() { + sessVars.ReadStaleness = originalReadStaleness + }() + err := core.Preprocess(ctx, se, stmt) + if err != nil { + return nil, err + } + if err := checkConstraint(stmt, se); err != nil { + return nil, err + } + + tableName, selectSQL, shardColumnInfo, tableSources, err := buildSelectSQL(stmt, se) + if err != nil { + return nil, err + } + + if err := checkConstraintWithShardColumn(se, stmt, tableName, shardColumnInfo, tableSources); err != nil { + return nil, err + } + + if stmt.DryRun == ast.DryRunQuery { + return buildDryRunResults(stmt.DryRun, []string{selectSQL}, se.GetSessionVars().BatchSize.MaxChunkSize) + } + + // TODO: choose an appropriate quota. + // Use the mem-quota-query as a workaround. As a result, a NT-DML may consume 2x of the memory quota. 
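For context on what this entry point receives, the statement form is TiDB's non-transactional BATCH DML, which splits one large DML along a shard column. A client-side sketch, with DSN, table, and column as assumptions; the DRY RUN output is read here as a single text column, which may differ in detail:

package main

import (
	"database/sql"
	"fmt"

	_ "github.com/go-sql-driver/mysql"
)

func main() {
	db, err := sql.Open("mysql", "root@tcp(127.0.0.1:4000)/test")
	if err != nil {
		panic(err)
	}
	defer db.Close()

	// DRY RUN previews how the DML would be split into per-range jobs without running them.
	rows, err := db.Query("BATCH ON id LIMIT 1000 DRY RUN DELETE FROM t WHERE created_at < NOW() - INTERVAL 30 DAY")
	if err != nil {
		panic(err)
	}
	defer rows.Close()
	for rows.Next() {
		var splitStmt string
		if err := rows.Scan(&splitStmt); err != nil {
			panic(err)
		}
		fmt.Println(splitStmt)
	}
	// Dropping "DRY RUN" runs the jobs for real; the session must be in autocommit and
	// outside an explicit transaction, as checkConstraint in this file enforces.
}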
+ memTracker := memory.NewTracker(memory.LabelForNonTransactionalDML, -1) + memTracker.AttachTo(se.GetSessionVars().MemTracker) + se.GetSessionVars().MemTracker.SetBytesLimit(se.GetSessionVars().MemQuotaQuery) + defer memTracker.Detach() + jobs, err := buildShardJobs(ctx, stmt, se, selectSQL, shardColumnInfo, memTracker) + if err != nil { + return nil, err + } + + splitStmts, err := runJobs(ctx, jobs, stmt, tableName, se, stmt.DMLStmt.WhereExpr()) + if err != nil { + return nil, err + } + if stmt.DryRun == ast.DryRunSplitDml { + return buildDryRunResults(stmt.DryRun, splitStmts, se.GetSessionVars().BatchSize.MaxChunkSize) + } + return buildExecuteResults(ctx, jobs, se.GetSessionVars().BatchSize.MaxChunkSize, se.GetSessionVars().EnableRedactLog) +} + +// we require: +// (1) in an update statement, shard column cannot be updated +// +// Note: this is not a comprehensive check. +// We do this to help user prevent some easy mistakes, at an acceptable maintenance cost. +func checkConstraintWithShardColumn(se sessiontypes.Session, stmt *ast.NonTransactionalDMLStmt, + tableName *ast.TableName, shardColumnInfo *model.ColumnInfo, tableSources []*ast.TableSource) error { + switch s := stmt.DMLStmt.(type) { + case *ast.UpdateStmt: + if err := checkUpdateShardColumn(se, s.List, shardColumnInfo, tableName, tableSources, true); err != nil { + return err + } + case *ast.InsertStmt: + // FIXME: is it possible to happen? + // `insert into t select * from t on duplicate key update id = id + 1` will return an ambiguous column error? + if err := checkUpdateShardColumn(se, s.OnDuplicate, shardColumnInfo, tableName, tableSources, false); err != nil { + return err + } + default: + } + return nil +} + +// shard column should not be updated. +func checkUpdateShardColumn(se sessiontypes.Session, assignments []*ast.Assignment, shardColumnInfo *model.ColumnInfo, + tableName *ast.TableName, tableSources []*ast.TableSource, isUpdate bool) error { + // if the table has alias, the alias is used in assignments, and we should use aliased name to compare + aliasedShardColumnTableName := tableName.Name.L + for _, tableSource := range tableSources { + if tableSource.Source.(*ast.TableName).Name.L == aliasedShardColumnTableName && tableSource.AsName.L != "" { + aliasedShardColumnTableName = tableSource.AsName.L + } + } + + if shardColumnInfo == nil { + return nil + } + for _, assignment := range assignments { + sameDB := (assignment.Column.Schema.L == tableName.Schema.L) || + (assignment.Column.Schema.L == "" && tableName.Schema.L == se.GetSessionVars().CurrentDB) + if !sameDB { + continue + } + sameTable := (assignment.Column.Table.L == aliasedShardColumnTableName) || (isUpdate && len(tableSources) == 1) + if !sameTable { + continue + } + if assignment.Column.Name.L == shardColumnInfo.Name.L { + return errors.New("Non-transactional DML, shard column cannot be updated") + } + } + return nil +} + +func checkConstraint(stmt *ast.NonTransactionalDMLStmt, se sessiontypes.Session) error { + sessVars := se.GetSessionVars() + if !(sessVars.IsAutocommit() && !sessVars.InTxn()) { + return errors.Errorf("non-transactional DML can only run in auto-commit mode. 
auto-commit:%v, inTxn:%v", + se.GetSessionVars().IsAutocommit(), se.GetSessionVars().InTxn()) + } + if variable.EnableBatchDML.Load() && sessVars.DMLBatchSize > 0 && (sessVars.BatchDelete || sessVars.BatchInsert) { + return errors.Errorf("can't run non-transactional DML with batch-dml") + } + + if sessVars.ReadConsistency.IsWeak() { + return errors.New("can't run non-transactional under weak read consistency") + } + if sessVars.SnapshotTS != 0 { + return errors.New("can't do non-transactional DML when tidb_snapshot is set") + } + + switch s := stmt.DMLStmt.(type) { + case *ast.DeleteStmt: + if err := checkTableRef(s.TableRefs, true); err != nil { + return err + } + if err := checkReadClauses(s.Limit, s.Order); err != nil { + return err + } + session_metrics.NonTransactionalDeleteCount.Inc() + case *ast.UpdateStmt: + if err := checkTableRef(s.TableRefs, true); err != nil { + return err + } + if err := checkReadClauses(s.Limit, s.Order); err != nil { + return err + } + session_metrics.NonTransactionalUpdateCount.Inc() + case *ast.InsertStmt: + if s.Select == nil { + return errors.New("Non-transactional insert supports insert select stmt only") + } + selectStmt, ok := s.Select.(*ast.SelectStmt) + if !ok { + return errors.New("Non-transactional insert doesn't support non-select source") + } + if err := checkTableRef(selectStmt.From, true); err != nil { + return err + } + if err := checkReadClauses(selectStmt.Limit, selectStmt.OrderBy); err != nil { + return err + } + session_metrics.NonTransactionalInsertCount.Inc() + default: + return errors.New("Unsupported DML type for non-transactional DML") + } + + return nil +} + +func checkTableRef(t *ast.TableRefsClause, allowMultipleTables bool) error { + if t == nil || t.TableRefs == nil || t.TableRefs.Left == nil { + return errors.New("table reference is nil") + } + if !allowMultipleTables && t.TableRefs.Right != nil { + return errors.New("Non-transactional statements don't support multiple tables") + } + return nil +} + +func checkReadClauses(limit *ast.Limit, order *ast.OrderByClause) error { + if limit != nil { + return errors.New("Non-transactional statements don't support limit") + } + if order != nil { + return errors.New("Non-transactional statements don't support order by") + } + return nil +} + +// single-threaded worker. work on the key range [start, end] +func runJobs(ctx context.Context, jobs []job, stmt *ast.NonTransactionalDMLStmt, + tableName *ast.TableName, se sessiontypes.Session, originalCondition ast.ExprNode) ([]string, error) { + // prepare for the construction of statement + var shardColumnRefer *ast.ResultField + var shardColumnType types.FieldType + for _, col := range tableName.TableInfo.Columns { + if col.Name.L == stmt.ShardColumn.Name.L { + shardColumnRefer = &ast.ResultField{ + Column: col, + Table: tableName.TableInfo, + DBName: tableName.Schema, + } + shardColumnType = col.FieldType + } + } + if shardColumnRefer == nil && stmt.ShardColumn.Name.L != model.ExtraHandleName.L { + return nil, errors.New("Non-transactional DML, shard column not found") + } + + splitStmts := make([]string, 0, len(jobs)) + for i := range jobs { + select { + case <-ctx.Done(): + failedJobs := make([]string, 0) + for _, job := range jobs { + if job.err != nil { + failedJobs = append(failedJobs, fmt.Sprintf("job:%s, error: %s", job.String(se.GetSessionVars().EnableRedactLog), job.err.Error())) + } + } + if len(failedJobs) == 0 { + logutil.Logger(ctx).Warn("Non-transactional DML worker exit because context canceled. 
No errors", + zap.Int("finished", i), zap.Int("total", len(jobs))) + } else { + logutil.Logger(ctx).Warn("Non-transactional DML worker exit because context canceled. Errors found", + zap.Int("finished", i), zap.Int("total", len(jobs)), zap.Strings("errors found", failedJobs)) + } + return nil, ctx.Err() + default: + } + + // _tidb_rowid + if shardColumnRefer == nil { + shardColumnType = *types.NewFieldType(mysql.TypeLonglong) + shardColumnRefer = &ast.ResultField{ + Column: model.NewExtraHandleColInfo(), + Table: tableName.TableInfo, + DBName: tableName.Schema, + } + } + stmtBuildInfo := statementBuildInfo{ + stmt: stmt, + shardColumnType: shardColumnType, + shardColumnRefer: shardColumnRefer, + originalCondition: originalCondition, + } + if stmt.DryRun == ast.DryRunSplitDml { + if i > 0 && i < len(jobs)-1 { + continue + } + splitStmt := doOneJob(ctx, &jobs[i], len(jobs), stmtBuildInfo, se, true) + splitStmts = append(splitStmts, splitStmt) + } else { + doOneJob(ctx, &jobs[i], len(jobs), stmtBuildInfo, se, false) + } + + // if the first job failed, there is a large chance that all jobs will fail. So return early. + if i == 0 && jobs[i].err != nil { + return nil, errors.Annotate(jobs[i].err, "Early return: error occurred in the first job. All jobs are canceled") + } + if jobs[i].err != nil && !se.GetSessionVars().NonTransactionalIgnoreError { + return nil, ErrNonTransactionalJobFailure.GenWithStackByArgs(jobs[i].jobID, len(jobs), jobs[i].start.String(), jobs[i].end.String(), jobs[i].String(se.GetSessionVars().EnableRedactLog), jobs[i].err.Error()) + } + } + return splitStmts, nil +} + +func doOneJob(ctx context.Context, job *job, totalJobCount int, options statementBuildInfo, se sessiontypes.Session, dryRun bool) string { + var whereCondition ast.ExprNode + + if job.start.IsNull() { + isNullCondition := &ast.IsNullExpr{ + Expr: &ast.ColumnNameExpr{ + Name: options.stmt.ShardColumn, + Refer: options.shardColumnRefer, + }, + Not: false, + } + if job.end.IsNull() { + // `where x is null` + whereCondition = isNullCondition + } else { + // `where (x <= job.end) || (x is null)` + right := &driver.ValueExpr{} + right.Type = options.shardColumnType + right.Datum = job.end + leCondition := &ast.BinaryOperationExpr{ + Op: opcode.LE, + L: &ast.ColumnNameExpr{ + Name: options.stmt.ShardColumn, + Refer: options.shardColumnRefer, + }, + R: right, + } + whereCondition = &ast.BinaryOperationExpr{ + Op: opcode.LogicOr, + L: leCondition, + R: isNullCondition, + } + } + } else { + // a normal between condition: `where x between start and end` + left := &driver.ValueExpr{} + left.Type = options.shardColumnType + left.Datum = job.start + right := &driver.ValueExpr{} + right.Type = options.shardColumnType + right.Datum = job.end + whereCondition = &ast.BetweenExpr{ + Expr: &ast.ColumnNameExpr{ + Name: options.stmt.ShardColumn, + Refer: options.shardColumnRefer, + }, + Left: left, + Right: right, + Not: false, + } + } + + if options.originalCondition == nil { + options.stmt.DMLStmt.SetWhereExpr(whereCondition) + } else { + options.stmt.DMLStmt.SetWhereExpr(&ast.BinaryOperationExpr{ + Op: opcode.LogicAnd, + L: whereCondition, + R: options.originalCondition, + }) + } + var sb strings.Builder + err := options.stmt.DMLStmt.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags| + format.RestoreNameBackQuotes| + format.RestoreSpacesAroundBinaryOperation| + format.RestoreBracketAroundBinaryOperation| + format.RestoreStringWithoutCharset, &sb)) + if err != nil { + logutil.Logger(ctx).Error("Non-transactional DML, 
failed to restore the DML statement", zap.Error(err)) + job.err = errors.New("Failed to restore the DML statement, probably because of unsupported type of the shard column") + return "" + } + dmlSQL := sb.String() + + if dryRun { + return dmlSQL + } + + job.sql = dmlSQL + logutil.Logger(ctx).Info("start a Non-transactional DML", + zap.String("job", job.String(se.GetSessionVars().EnableRedactLog)), zap.Int("totalJobCount", totalJobCount)) + dmlSQLInLog := parser.Normalize(dmlSQL, se.GetSessionVars().EnableRedactLog) + + options.stmt.DMLStmt.SetText(nil, fmt.Sprintf("/* job %v/%v */ %s", job.jobID, totalJobCount, dmlSQL)) + rs, err := se.ExecuteStmt(ctx, options.stmt.DMLStmt) + + // collect errors + failpoint.Inject("batchDMLError", func(val failpoint.Value) { + if val.(bool) { + err = errors.New("injected batch(non-transactional) DML error") + } + }) + if err != nil { + logutil.Logger(ctx).Error("Non-transactional DML SQL failed", zap.String("job", dmlSQLInLog), zap.Error(err), zap.Int("jobID", job.jobID), zap.Int("jobSize", job.jobSize)) + job.err = err + } else { + logutil.Logger(ctx).Info("Non-transactional DML SQL finished successfully", zap.Int("jobID", job.jobID), + zap.Int("jobSize", job.jobSize), zap.String("dmlSQL", dmlSQLInLog)) + } + if rs != nil { + _ = rs.Close() + } + return "" +} + +func buildShardJobs(ctx context.Context, stmt *ast.NonTransactionalDMLStmt, se sessiontypes.Session, + selectSQL string, shardColumnInfo *model.ColumnInfo, memTracker *memory.Tracker) ([]job, error) { + var shardColumnCollate string + if shardColumnInfo != nil { + shardColumnCollate = shardColumnInfo.GetCollate() + } else { + shardColumnCollate = "" + } + + // A NT-DML is not a SELECT. We ignore the SelectLimit for selectSQL so that it can read all values. + originalSelectLimit := se.GetSessionVars().SelectLimit + se.GetSessionVars().SelectLimit = math.MaxUint64 + // NT-DML is a write operation, and should not be affected by read_staleness that is supposed to affect only SELECT. 
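The per-job statement assembled by doOneJob above swaps the original WHERE clause for a range predicate on the shard column, AND'ed with the user's condition when one exists. A string-level sketch of the resulting condition shapes (buildJobWhere and its integer parameters are illustrative only; the real code builds AST nodes and handles arbitrary Datum types):

    package main

    import "fmt"

    // buildJobWhere shows the three shapes: a NULL start with a NULL end is
    // "IS NULL", a NULL start with a non-NULL end is "(<= end) OR (IS NULL)",
    // and otherwise a BETWEEN. In practice a non-nil start implies a non-nil end.
    func buildJobWhere(col string, start, end *int64, original string) string {
        var cond string
        switch {
        case start == nil && end == nil:
            cond = fmt.Sprintf("`%s` IS NULL", col)
        case start == nil:
            cond = fmt.Sprintf("(`%s` <= %d) OR (`%s` IS NULL)", col, *end, col)
        default:
            cond = fmt.Sprintf("`%s` BETWEEN %d AND %d", col, *start, *end)
        }
        if original == "" {
            return cond
        }
        return fmt.Sprintf("(%s) AND (%s)", cond, original)
    }

    func main() {
        start, end := int64(1), int64(1000)
        fmt.Println(buildJobWhere("id", &start, &end, "status = 'stale'"))
        // (`id` BETWEEN 1 AND 1000) AND (status = 'stale')
    }
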
+ rss, err := se.Execute(ctx, selectSQL) + se.GetSessionVars().SelectLimit = originalSelectLimit + + if err != nil { + return nil, err + } + if len(rss) != 1 { + return nil, errors.Errorf("Non-transactional DML, expecting 1 record set, but got %d", len(rss)) + } + rs := rss[0] + defer func() { + _ = rs.Close() + }() + + batchSize := int(stmt.Limit) + if batchSize <= 0 { + return nil, errors.New("Non-transactional DML, batch size should be positive") + } + jobCount := 0 + jobs := make([]job, 0) + currentSize := 0 + var currentStart, currentEnd types.Datum + + chk := rs.NewChunk(nil) + for { + err = rs.Next(ctx, chk) + if err != nil { + return nil, err + } + + // last chunk + if chk.NumRows() == 0 { + if currentSize > 0 { + // there's remaining work + jobs = appendNewJob(jobs, jobCount+1, currentStart, currentEnd, currentSize, memTracker) + } + break + } + + if len(jobs) > 0 && chk.NumRows()+currentSize < batchSize { + // not enough data for a batch + currentSize += chk.NumRows() + newEnd := chk.GetRow(chk.NumRows()-1).GetDatum(0, &rs.Fields()[0].Column.FieldType) + currentEnd = *newEnd.Clone() + continue + } + + iter := chunk.NewIterator4Chunk(chk) + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + if currentSize == 0 { + newStart := row.GetDatum(0, &rs.Fields()[0].Column.FieldType) + currentStart = *newStart.Clone() + } + newEnd := row.GetDatum(0, &rs.Fields()[0].Column.FieldType) + if currentSize >= batchSize { + cmp, err := newEnd.Compare(se.GetSessionVars().StmtCtx.TypeCtx(), ¤tEnd, collate.GetCollator(shardColumnCollate)) + if err != nil { + return nil, err + } + if cmp != 0 { + jobCount++ + jobs = appendNewJob(jobs, jobCount, *currentStart.Clone(), *currentEnd.Clone(), currentSize, memTracker) + currentSize = 0 + currentStart = newEnd + } + } + currentEnd = newEnd + currentSize++ + } + currentEnd = *currentEnd.Clone() + currentStart = *currentStart.Clone() + } + + return jobs, nil +} + +func appendNewJob(jobs []job, id int, start types.Datum, end types.Datum, size int, tracker *memory.Tracker) []job { + jobs = append(jobs, job{jobID: id, start: start, end: end, jobSize: size}) + tracker.Consume(start.EstimatedMemUsage() + end.EstimatedMemUsage() + 64) + return jobs +} + +func buildSelectSQL(stmt *ast.NonTransactionalDMLStmt, se sessiontypes.Session) ( + *ast.TableName, string, *model.ColumnInfo, []*ast.TableSource, error) { + // only use the first table + join, ok := stmt.DMLStmt.TableRefsJoin() + if !ok { + return nil, "", nil, nil, errors.New("Non-transactional DML, table source not found") + } + tableSources := make([]*ast.TableSource, 0) + tableSources, err := collectTableSourcesInJoin(join, tableSources) + if err != nil { + return nil, "", nil, nil, err + } + if len(tableSources) == 0 { + return nil, "", nil, nil, errors.New("Non-transactional DML, no tables found in table refs") + } + leftMostTableSource := tableSources[0] + leftMostTableName, ok := leftMostTableSource.Source.(*ast.TableName) + if !ok { + return nil, "", nil, nil, errors.New("Non-transactional DML, table name not found") + } + + shardColumnInfo, tableName, err := selectShardColumn(stmt, se, tableSources, leftMostTableName, leftMostTableSource) + if err != nil { + return nil, "", nil, nil, err + } + + var sb strings.Builder + if stmt.DMLStmt.WhereExpr() != nil { + err := stmt.DMLStmt.WhereExpr().Restore(format.NewRestoreCtx(format.DefaultRestoreFlags| + format.RestoreNameBackQuotes| + format.RestoreSpacesAroundBinaryOperation| + format.RestoreBracketAroundBinaryOperation| + 
format.RestoreStringWithoutCharset, &sb), + ) + if err != nil { + return nil, "", nil, nil, errors.Annotate(err, "Failed to restore where clause in non-transactional DML") + } + } else { + sb.WriteString("TRUE") + } + // assure NULL values are placed first + selectSQL := fmt.Sprintf("SELECT `%s` FROM `%s`.`%s` WHERE %s ORDER BY IF(ISNULL(`%s`),0,1),`%s`", + stmt.ShardColumn.Name.O, tableName.DBInfo.Name.O, tableName.Name.O, sb.String(), stmt.ShardColumn.Name.O, stmt.ShardColumn.Name.O) + return tableName, selectSQL, shardColumnInfo, tableSources, nil +} + +func selectShardColumn(stmt *ast.NonTransactionalDMLStmt, se sessiontypes.Session, tableSources []*ast.TableSource, + leftMostTableName *ast.TableName, leftMostTableSource *ast.TableSource) ( + *model.ColumnInfo, *ast.TableName, error) { + var indexed bool + var shardColumnInfo *model.ColumnInfo + var selectedTableName *ast.TableName + + if len(tableSources) == 1 { + // single table + leftMostTable, err := domain.GetDomain(se).InfoSchema().TableByName(context.Background(), leftMostTableName.Schema, leftMostTableName.Name) + if err != nil { + return nil, nil, err + } + selectedTableName = leftMostTableName + indexed, shardColumnInfo, err = selectShardColumnFromTheOnlyTable( + stmt, leftMostTableName, leftMostTableSource.AsName, leftMostTable) + if err != nil { + return nil, nil, err + } + } else { + // multi table join + if stmt.ShardColumn == nil { + leftMostTable, err := domain.GetDomain(se).InfoSchema().TableByName(context.Background(), leftMostTableName.Schema, leftMostTableName.Name) + if err != nil { + return nil, nil, err + } + selectedTableName = leftMostTableName + indexed, shardColumnInfo, err = selectShardColumnAutomatically(stmt, leftMostTable, leftMostTableName, leftMostTableSource.AsName) + if err != nil { + return nil, nil, err + } + } else if stmt.ShardColumn.Schema.L != "" && stmt.ShardColumn.Table.L != "" && stmt.ShardColumn.Name.L != "" { + specifiedDbName := stmt.ShardColumn.Schema + specifiedTableName := stmt.ShardColumn.Table + specifiedColName := stmt.ShardColumn.Name + + // the specified table must be in the join + tableInJoin := false + var chosenTableName model.CIStr + for _, tableSource := range tableSources { + tableSourceName := tableSource.Source.(*ast.TableName) + tableSourceFinalTableName := tableSource.AsName // precedence: alias name, then table name + if tableSourceFinalTableName.O == "" { + tableSourceFinalTableName = tableSourceName.Name + } + if tableSourceName.Schema.L == specifiedDbName.L && tableSourceFinalTableName.L == specifiedTableName.L { + tableInJoin = true + selectedTableName = tableSourceName + chosenTableName = tableSourceName.Name + break + } + } + if !tableInJoin { + return nil, nil, + errors.Errorf( + "Non-transactional DML, shard column %s.%s.%s is not in the tables involved in the join", + specifiedDbName.L, specifiedTableName.L, specifiedColName.L, + ) + } + + tbl, err := domain.GetDomain(se).InfoSchema().TableByName(context.Background(), specifiedDbName, chosenTableName) + if err != nil { + return nil, nil, err + } + indexed, shardColumnInfo, err = selectShardColumnByGivenName(specifiedColName.L, tbl) + if err != nil { + return nil, nil, err + } + } else { + return nil, nil, errors.New( + "Non-transactional DML, shard column must be fully specified (i.e. 
`BATCH ON dbname.tablename.colname`) when multiple tables are involved", + ) + } + } + if !indexed { + return nil, nil, errors.Errorf("Non-transactional DML, shard column %s is not indexed", stmt.ShardColumn.Name.L) + } + return shardColumnInfo, selectedTableName, nil +} + +func collectTableSourcesInJoin(node ast.ResultSetNode, tableSources []*ast.TableSource) ([]*ast.TableSource, error) { + if node == nil { + return tableSources, nil + } + switch x := node.(type) { + case *ast.Join: + var err error + tableSources, err = collectTableSourcesInJoin(x.Left, tableSources) + if err != nil { + return nil, err + } + tableSources, err = collectTableSourcesInJoin(x.Right, tableSources) + if err != nil { + return nil, err + } + case *ast.TableSource: + // assert it's a table name + if _, ok := x.Source.(*ast.TableName); !ok { + return nil, errors.New("Non-transactional DML, table name not found in join") + } + tableSources = append(tableSources, x) + default: + return nil, errors.Errorf("Non-transactional DML, unknown type %T in table refs", node) + } + return tableSources, nil +} + +// it attempts to auto-select a shard column from handle if not specified, and fills back the corresponding info in the stmt, +// making it transparent to following steps +func selectShardColumnFromTheOnlyTable(stmt *ast.NonTransactionalDMLStmt, tableName *ast.TableName, + tableAsName model.CIStr, tbl table.Table) ( + indexed bool, shardColumnInfo *model.ColumnInfo, err error) { + if stmt.ShardColumn == nil { + return selectShardColumnAutomatically(stmt, tbl, tableName, tableAsName) + } + + return selectShardColumnByGivenName(stmt.ShardColumn.Name.L, tbl) +} + +func selectShardColumnByGivenName(shardColumnName string, tbl table.Table) ( + indexed bool, shardColumnInfo *model.ColumnInfo, err error) { + tableInfo := tbl.Meta() + if shardColumnName == model.ExtraHandleName.L && !tableInfo.HasClusteredIndex() { + return true, nil, nil + } + + for _, col := range tbl.Cols() { + if col.Name.L == shardColumnName { + shardColumnInfo = col.ColumnInfo + break + } + } + if shardColumnInfo == nil { + return false, nil, errors.Errorf("shard column %s not found", shardColumnName) + } + // is int handle + if mysql.HasPriKeyFlag(shardColumnInfo.GetFlag()) && tableInfo.PKIsHandle { + return true, shardColumnInfo, nil + } + + for _, index := range tbl.Indices() { + if index.Meta().State != model.StatePublic || index.Meta().Invisible { + continue + } + indexColumns := index.Meta().Columns + // check only the first column + if len(indexColumns) > 0 && indexColumns[0].Name.L == shardColumnName { + indexed = true + break + } + } + return indexed, shardColumnInfo, nil +} + +func selectShardColumnAutomatically(stmt *ast.NonTransactionalDMLStmt, tbl table.Table, + tableName *ast.TableName, tableAsName model.CIStr) (bool, *model.ColumnInfo, error) { + // auto-detect shard column + var shardColumnInfo *model.ColumnInfo + tableInfo := tbl.Meta() + if tbl.Meta().PKIsHandle { + shardColumnInfo = tableInfo.GetPkColInfo() + } else if tableInfo.IsCommonHandle { + for _, index := range tableInfo.Indices { + if index.Primary { + if len(index.Columns) == 1 { + shardColumnInfo = tableInfo.Columns[index.Columns[0].Offset] + break + } + // if the clustered index contains multiple columns, we cannot automatically choose a column as the shard column + return false, nil, errors.New("Non-transactional DML, the clustered index contains multiple columns. 
Please specify a shard column") + } + } + if shardColumnInfo == nil { + return false, nil, errors.New("Non-transactional DML, the clustered index is not found") + } + } + + shardColumnName := model.ExtraHandleName.L + if shardColumnInfo != nil { + shardColumnName = shardColumnInfo.Name.L + } + + outputTableName := tableName.Name + if tableAsName.L != "" { + outputTableName = tableAsName + } + stmt.ShardColumn = &ast.ColumnName{ + Schema: tableName.Schema, + Table: outputTableName, // so that table alias works + Name: model.NewCIStr(shardColumnName), + } + return true, shardColumnInfo, nil +} + +func buildDryRunResults(dryRunOption int, results []string, maxChunkSize int) (sqlexec.RecordSet, error) { + var fieldName string + if dryRunOption == ast.DryRunSplitDml { + fieldName = "split statement examples" + } else { + fieldName = "query statement" + } + + resultFields := []*ast.ResultField{{ + Column: &model.ColumnInfo{ + FieldType: *types.NewFieldType(mysql.TypeString), + }, + ColumnAsName: model.NewCIStr(fieldName), + }} + rows := make([][]any, 0, len(results)) + for _, result := range results { + row := make([]any, 1) + row[0] = result + rows = append(rows, row) + } + return &sqlexec.SimpleRecordSet{ + ResultFields: resultFields, + Rows: rows, + MaxChunkSize: maxChunkSize, + }, nil +} + +func buildExecuteResults(ctx context.Context, jobs []job, maxChunkSize int, redactLog string) (sqlexec.RecordSet, error) { + failedJobs := make([]job, 0) + for _, job := range jobs { + if job.err != nil { + failedJobs = append(failedJobs, job) + } + } + if len(failedJobs) == 0 { + resultFields := []*ast.ResultField{ + { + Column: &model.ColumnInfo{ + FieldType: *types.NewFieldType(mysql.TypeLong), + }, + ColumnAsName: model.NewCIStr("number of jobs"), + }, + { + Column: &model.ColumnInfo{ + FieldType: *types.NewFieldType(mysql.TypeString), + }, + ColumnAsName: model.NewCIStr("job status"), + }, + } + rows := make([][]any, 1) + row := make([]any, 2) + row[0] = len(jobs) + row[1] = "all succeeded" + rows[0] = row + return &sqlexec.SimpleRecordSet{ + ResultFields: resultFields, + Rows: rows, + MaxChunkSize: maxChunkSize, + }, nil + } + + // ignoreError must be set. + var sb strings.Builder + for _, job := range failedJobs { + sb.WriteString(fmt.Sprintf("%s, %s;\n", job.String(redactLog), job.err.Error())) + } + + errStr := sb.String() + // log errors here in case the output is too long. There can be thousands of errors. + logutil.Logger(ctx).Error("Non-transactional DML failed", + zap.Int("num_failed_jobs", len(failedJobs)), zap.String("failed_jobs", errStr)) + + return nil, fmt.Errorf("%d/%d jobs failed in the non-transactional DML: %s, ...(more in logs)", + len(failedJobs), len(jobs), errStr[:min(500, len(errStr)-1)]) +} diff --git a/pkg/session/session.go b/pkg/session/session.go index e46006fee9389..4ca9bdc5227ba 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -519,13 +519,13 @@ func (s *session) doCommit(ctx context.Context) error { return err } // mockCommitError and mockGetTSErrorInRetry use to test PR #8743. 
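The session.go hunks that follow show the mechanical expansion applied when failpoints are enabled (for example via failpoint-ctl): each failpoint.Inject marker becomes an explicit failpoint.Eval guarded on its error, failpoint.Return becomes a plain return, and the generated _curpkg_ helper prefixes the failpoint name with the package path. A side-by-side sketch of the two forms, assuming a stand-in _curpkg_ stub and error value (both illustrative, not the real definitions):

    package example

    import (
        "errors"

        "github.com/pingcap/failpoint"
    )

    // normally generated per package by the failpoint tool
    func _curpkg_(name string) string { return "example/" + name }

    var errRetryable = errors.New("txn retryable")

    // marker form, as preserved in the *__failpoint_stash__ copy of the file
    func doCommitMarker() error {
        failpoint.Inject("mockCommitError", func(val failpoint.Value) {
            if val.(bool) {
                failpoint.Return(errRetryable)
            }
        })
        return nil
    }

    // expanded form, matching what this patch checks in
    func doCommitExpanded() error {
        if val, _err_ := failpoint.Eval(_curpkg_("mockCommitError")); _err_ == nil {
            if val.(bool) {
                return errRetryable
            }
        }
        return nil
    }
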
- failpoint.Inject("mockCommitError", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockCommitError")); _err_ == nil { if val.(bool) { if _, err := failpoint.Eval("tikvclient/mockCommitErrorOpt"); err == nil { - failpoint.Return(kv.ErrTxnRetryable) + return kv.ErrTxnRetryable } } - }) + } sessVars := s.GetSessionVars() @@ -661,10 +661,10 @@ func (s *session) commitTxnWithTemporaryData(ctx context.Context, txn kv.Transac sessVars := s.sessionVars txnTempTables := sessVars.TxnCtx.TemporaryTables if len(txnTempTables) == 0 { - failpoint.Inject("mockSleepBeforeTxnCommit", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("mockSleepBeforeTxnCommit")); _err_ == nil { ms := v.(int) time.Sleep(time.Millisecond * time.Duration(ms)) - }) + } return txn.Commit(ctx) } @@ -925,11 +925,11 @@ func (s *session) CommitTxn(ctx context.Context) error { // record the TTLInsertRows in the metric metrics.TTLInsertRowsCount.Add(float64(s.sessionVars.TxnCtx.InsertTTLRowsCount)) - failpoint.Inject("keepHistory", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("keepHistory")); _err_ == nil { if val.(bool) { - failpoint.Return(err) + return err } - }) + } s.sessionVars.TxnCtx.Cleanup() s.sessionVars.CleanupTxnReadTSIfUsed() return err @@ -1117,12 +1117,12 @@ func (s *session) retry(ctx context.Context, maxCnt uint) (err error) { logutil.Logger(ctx).Warn("transaction association", zap.Uint64("retrying txnStartTS", s.GetSessionVars().TxnCtx.StartTS), zap.Uint64("original txnStartTS", orgStartTS)) - failpoint.Inject("preCommitHook", func() { + if _, _err_ := failpoint.Eval(_curpkg_("preCommitHook")); _err_ == nil { hook, ok := ctx.Value("__preCommitHook").(func()) if ok { hook() } - }) + } if err == nil { err = s.doCommit(ctx) if err == nil { @@ -2068,12 +2068,12 @@ func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlex return nil, err } - failpoint.Inject("mockStmtSlow", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockStmtSlow")); _err_ == nil { if strings.Contains(stmtNode.Text(), "/* sleep */") { v, _ := val.(int) time.Sleep(time.Duration(v) * time.Millisecond) } - }) + } var stmtLabel string if execStmt, ok := stmtNode.(*ast.ExecuteStmt); ok { @@ -2243,12 +2243,12 @@ func (s *session) hasFileTransInConn() bool { // runStmt executes the sqlexec.Statement and commit or rollback the current transaction. 
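These failpoints are inert unless activated at runtime, typically from a test. A hedged sketch of toggling the mockStmtSlow failpoint from the hunk above (assuming the conventional package-path prefix produced by _curpkg_; the test body is only indicative):

    package session_test

    import (
        "testing"

        "github.com/pingcap/failpoint"
        "github.com/stretchr/testify/require"
    )

    func TestMockStmtSlow(t *testing.T) {
        fp := "github.com/pingcap/tidb/pkg/session/mockStmtSlow"
        // "return(200)" makes failpoint.Eval hand the value 200 to the guarded
        // block, which sleeps 200ms for statements containing "/* sleep */".
        require.NoError(t, failpoint.Enable(fp, "return(200)"))
        defer func() { require.NoError(t, failpoint.Disable(fp)) }()

        // ... execute a statement whose text contains "/* sleep */" and assert
        // that it takes at least roughly 200ms ...
    }
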
func runStmt(ctx context.Context, se *session, s sqlexec.Statement) (rs sqlexec.RecordSet, err error) { - failpoint.Inject("assertTxnManagerInRunStmt", func() { + if _, _err_ := failpoint.Eval(_curpkg_("assertTxnManagerInRunStmt")); _err_ == nil { sessiontxn.RecordAssert(se, "assertTxnManagerInRunStmt", true) if stmt, ok := s.(*executor.ExecStmt); ok { sessiontxn.AssertTxnManagerInfoSchema(se, stmt.InfoSchema) } - }) + } r, ctx := tracing.StartRegionEx(ctx, "session.runStmt") defer r.End() @@ -3946,15 +3946,15 @@ func (s *session) PrepareTSFuture(ctx context.Context, future oracle.Future, sco return errors.New("cannot prepare ts future when txn is valid") } - failpoint.Inject("assertTSONotRequest", func() { + if _, _err_ := failpoint.Eval(_curpkg_("assertTSONotRequest")); _err_ == nil { if _, ok := future.(sessiontxn.ConstantFuture); !ok && !s.isInternal() { panic("tso shouldn't be requested") } - }) + } - failpoint.InjectContext(ctx, "mockGetTSFail", func() { + if _, _err_ := failpoint.EvalContext(ctx, _curpkg_("mockGetTSFail")); _err_ == nil { future = txnFailFuture{} - }) + } s.txn.changeToPending(&txnFuture{ future: future, diff --git a/pkg/session/session.go__failpoint_stash__ b/pkg/session/session.go__failpoint_stash__ new file mode 100644 index 0000000000000..e46006fee9389 --- /dev/null +++ b/pkg/session/session.go__failpoint_stash__ @@ -0,0 +1,4611 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright 2013 The ql Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSES/QL-LICENSE file. 
+ +package session + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/hex" + "encoding/json" + stderrs "errors" + "fmt" + "math" + "math/rand" + "runtime/pprof" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/ngaut/pools" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/tidb/pkg/bindinfo" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl" + "github.com/pingcap/tidb/pkg/ddl/placement" + distsqlctx "github.com/pingcap/tidb/pkg/distsql/context" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" + "github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor" + "github.com/pingcap/tidb/pkg/disttask/importinto" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/executor" + "github.com/pingcap/tidb/pkg/executor/staticrecordset" + "github.com/pingcap/tidb/pkg/expression" + exprctx "github.com/pingcap/tidb/pkg/expression/context" + "github.com/pingcap/tidb/pkg/expression/contextsession" + "github.com/pingcap/tidb/pkg/extension" + "github.com/pingcap/tidb/pkg/extension/extensionimpl" + "github.com/pingcap/tidb/pkg/infoschema" + infoschemactx "github.com/pingcap/tidb/pkg/infoschema/context" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/owner" + "github.com/pingcap/tidb/pkg/param" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/auth" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/planner" + planctx "github.com/pingcap/tidb/pkg/planner/context" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/plugin" + "github.com/pingcap/tidb/pkg/privilege" + "github.com/pingcap/tidb/pkg/privilege/conn" + "github.com/pingcap/tidb/pkg/privilege/privileges" + "github.com/pingcap/tidb/pkg/session/cursor" + session_metrics "github.com/pingcap/tidb/pkg/session/metrics" + "github.com/pingcap/tidb/pkg/session/txninfo" + "github.com/pingcap/tidb/pkg/session/types" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/binloginfo" + "github.com/pingcap/tidb/pkg/sessionctx/sessionstates" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/statistics/handle/syncload" + "github.com/pingcap/tidb/pkg/statistics/handle/usage" + "github.com/pingcap/tidb/pkg/statistics/handle/usage/indexusage" + storeerr "github.com/pingcap/tidb/pkg/store/driver/error" + "github.com/pingcap/tidb/pkg/store/helper" + "github.com/pingcap/tidb/pkg/table" + tbctx "github.com/pingcap/tidb/pkg/table/context" + tbctximpl "github.com/pingcap/tidb/pkg/table/contextimpl" + "github.com/pingcap/tidb/pkg/table/temptable" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/ttl/ttlworker" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/collate" + "github.com/pingcap/tidb/pkg/util/dbterror" + 
"github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/logutil/consistency" + "github.com/pingcap/tidb/pkg/util/memory" + rangerctx "github.com/pingcap/tidb/pkg/util/ranger/context" + "github.com/pingcap/tidb/pkg/util/redact" + "github.com/pingcap/tidb/pkg/util/sem" + "github.com/pingcap/tidb/pkg/util/sli" + "github.com/pingcap/tidb/pkg/util/sqlescape" + "github.com/pingcap/tidb/pkg/util/sqlexec" + "github.com/pingcap/tidb/pkg/util/timeutil" + "github.com/pingcap/tidb/pkg/util/topsql" + topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" + "github.com/pingcap/tidb/pkg/util/topsql/stmtstats" + "github.com/pingcap/tidb/pkg/util/tracing" + tikverr "github.com/tikv/client-go/v2/error" + "github.com/tikv/client-go/v2/oracle" + tikvutil "github.com/tikv/client-go/v2/util" + "go.uber.org/zap" +) + +func init() { + executor.CreateSession = func(ctx sessionctx.Context) (sessionctx.Context, error) { + return CreateSession(ctx.GetStore()) + } + executor.CloseSession = func(ctx sessionctx.Context) { + if se, ok := ctx.(types.Session); ok { + se.Close() + } + } +} + +var _ types.Session = (*session)(nil) + +type stmtRecord struct { + st sqlexec.Statement + stmtCtx *stmtctx.StatementContext +} + +// StmtHistory holds all histories of statements in a txn. +type StmtHistory struct { + history []*stmtRecord +} + +// Add appends a stmt to history list. +func (h *StmtHistory) Add(st sqlexec.Statement, stmtCtx *stmtctx.StatementContext) { + s := &stmtRecord{ + st: st, + stmtCtx: stmtCtx, + } + h.history = append(h.history, s) +} + +// Count returns the count of the history. +func (h *StmtHistory) Count() int { + return len(h.history) +} + +type session struct { + // processInfo is used by ShowProcess(), and should be modified atomically. + processInfo atomic.Pointer[util.ProcessInfo] + txn LazyTxn + + mu struct { + sync.RWMutex + values map[fmt.Stringer]any + } + + currentCtx context.Context // only use for runtime.trace, Please NEVER use it. + currentPlan base.Plan + + store kv.Storage + + sessionPlanCache sessionctx.SessionPlanCache + + sessionVars *variable.SessionVars + sessionManager util.SessionManager + + pctx *planContextImpl + exprctx *contextsession.SessionExprContext + tblctx *tbctximpl.TableContextImpl + + statsCollector *usage.SessionStatsItem + // ddlOwnerManager is used in `select tidb_is_ddl_owner()` statement; + ddlOwnerManager owner.Manager + // lockedTables use to record the table locks hold by the session. + lockedTables map[int64]model.TableLockTpInfo + + // client shared coprocessor client per session + client kv.Client + + mppClient kv.MPPClient + + // indexUsageCollector collects index usage information. + idxUsageCollector *indexusage.SessionIndexUsageCollector + + // StmtStats is used to count various indicators of each SQL in this session + // at each point in time. These data will be periodically taken away by the + // background goroutine. The background goroutine will continue to aggregate + // all the local data in each session, and finally report them to the remote + // regularly. + stmtStats *stmtstats.StatementStats + + // Used to encode and decode each type of session states. + sessionStatesHandlers map[sessionstates.SessionStateType]sessionctx.SessionStatesHandler + + // Contains a list of sessions used to collect advisory locks. 
+ advisoryLocks map[string]*advisoryLock + + extensions *extension.SessionExtensions + + sandBoxMode bool + + cursorTracker cursor.Tracker +} + +var parserPool = &sync.Pool{New: func() any { return parser.New() }} + +// AddTableLock adds table lock to the session lock map. +func (s *session) AddTableLock(locks []model.TableLockTpInfo) { + for _, l := range locks { + // read only lock is session unrelated, skip it when adding lock to session. + if l.Tp != model.TableLockReadOnly { + s.lockedTables[l.TableID] = l + } + } +} + +// ReleaseTableLocks releases table lock in the session lock map. +func (s *session) ReleaseTableLocks(locks []model.TableLockTpInfo) { + for _, l := range locks { + delete(s.lockedTables, l.TableID) + } +} + +// ReleaseTableLockByTableIDs releases table lock in the session lock map by table ID. +func (s *session) ReleaseTableLockByTableIDs(tableIDs []int64) { + for _, tblID := range tableIDs { + delete(s.lockedTables, tblID) + } +} + +// CheckTableLocked checks the table lock. +func (s *session) CheckTableLocked(tblID int64) (bool, model.TableLockType) { + lt, ok := s.lockedTables[tblID] + if !ok { + return false, model.TableLockNone + } + return true, lt.Tp +} + +// GetAllTableLocks gets all table locks table id and db id hold by the session. +func (s *session) GetAllTableLocks() []model.TableLockTpInfo { + lockTpInfo := make([]model.TableLockTpInfo, 0, len(s.lockedTables)) + for _, tl := range s.lockedTables { + lockTpInfo = append(lockTpInfo, tl) + } + return lockTpInfo +} + +// HasLockedTables uses to check whether this session locked any tables. +// If so, the session can only visit the table which locked by self. +func (s *session) HasLockedTables() bool { + b := len(s.lockedTables) > 0 + return b +} + +// ReleaseAllTableLocks releases all table locks hold by the session. +func (s *session) ReleaseAllTableLocks() { + s.lockedTables = make(map[int64]model.TableLockTpInfo) +} + +// IsDDLOwner checks whether this session is DDL owner. 
+func (s *session) IsDDLOwner() bool { + return s.ddlOwnerManager.IsOwner() +} + +func (s *session) cleanRetryInfo() { + if s.sessionVars.RetryInfo.Retrying { + return + } + + retryInfo := s.sessionVars.RetryInfo + defer retryInfo.Clean() + if len(retryInfo.DroppedPreparedStmtIDs) == 0 { + return + } + + planCacheEnabled := s.GetSessionVars().EnablePreparedPlanCache + var cacheKey string + var err error + var preparedObj *plannercore.PlanCacheStmt + if planCacheEnabled { + firstStmtID := retryInfo.DroppedPreparedStmtIDs[0] + if preparedPointer, ok := s.sessionVars.PreparedStmts[firstStmtID]; ok { + preparedObj, ok = preparedPointer.(*plannercore.PlanCacheStmt) + if ok { + cacheKey, _, _, _, err = plannercore.NewPlanCacheKey(s, preparedObj) + if err != nil { + logutil.Logger(s.currentCtx).Warn("clean cached plan failed", zap.Error(err)) + return + } + } + } + } + for i, stmtID := range retryInfo.DroppedPreparedStmtIDs { + if planCacheEnabled { + if i > 0 && preparedObj != nil { + cacheKey, _, _, _, err = plannercore.NewPlanCacheKey(s, preparedObj) + if err != nil { + logutil.Logger(s.currentCtx).Warn("clean cached plan failed", zap.Error(err)) + return + } + } + if !s.sessionVars.IgnorePreparedCacheCloseStmt { // keep the plan in cache + s.GetSessionPlanCache().Delete(cacheKey) + } + } + s.sessionVars.RemovePreparedStmt(stmtID) + } +} + +func (s *session) Status() uint16 { + return s.sessionVars.Status() +} + +func (s *session) LastInsertID() uint64 { + if s.sessionVars.StmtCtx.LastInsertID > 0 { + return s.sessionVars.StmtCtx.LastInsertID + } + return s.sessionVars.StmtCtx.InsertID +} + +func (s *session) LastMessage() string { + return s.sessionVars.StmtCtx.GetMessage() +} + +func (s *session) AffectedRows() uint64 { + return s.sessionVars.StmtCtx.AffectedRows() +} + +func (s *session) SetClientCapability(capability uint32) { + s.sessionVars.ClientCapability = capability +} + +func (s *session) SetConnectionID(connectionID uint64) { + s.sessionVars.ConnectionID = connectionID +} + +func (s *session) SetTLSState(tlsState *tls.ConnectionState) { + // If user is not connected via TLS, then tlsState == nil. + if tlsState != nil { + s.sessionVars.TLSConnectionState = tlsState + } +} + +func (s *session) SetCompressionAlgorithm(ca int) { + s.sessionVars.CompressionAlgorithm = ca +} + +func (s *session) SetCompressionLevel(level int) { + s.sessionVars.CompressionLevel = level +} + +func (s *session) SetCommandValue(command byte) { + atomic.StoreUint32(&s.sessionVars.CommandValue, uint32(command)) +} + +func (s *session) SetCollation(coID int) error { + cs, co, err := charset.GetCharsetInfoByID(coID) + if err != nil { + return err + } + // If new collations are enabled, switch to the default + // collation if this one is not supported. 
+ co = collate.SubstituteMissingCollationToDefault(co) + for _, v := range variable.SetNamesVariables { + terror.Log(s.sessionVars.SetSystemVarWithoutValidation(v, cs)) + } + return s.sessionVars.SetSystemVarWithoutValidation(variable.CollationConnection, co) +} + +func (s *session) GetSessionPlanCache() sessionctx.SessionPlanCache { + // use the prepared plan cache + if !s.GetSessionVars().EnablePreparedPlanCache && !s.GetSessionVars().EnableNonPreparedPlanCache { + return nil + } + if s.sessionPlanCache == nil { // lazy construction + s.sessionPlanCache = plannercore.NewLRUPlanCache(uint(s.GetSessionVars().SessionPlanCacheSize), + variable.PreparedPlanCacheMemoryGuardRatio.Load(), plannercore.PreparedPlanCacheMaxMemory.Load(), s, false) + } + return s.sessionPlanCache +} + +func (s *session) SetSessionManager(sm util.SessionManager) { + s.sessionManager = sm +} + +func (s *session) GetSessionManager() util.SessionManager { + return s.sessionManager +} + +func (s *session) UpdateColStatsUsage(predicateColumns []model.TableItemID) { + if s.statsCollector == nil { + return + } + t := time.Now() + colMap := make(map[model.TableItemID]time.Time, len(predicateColumns)) + for _, col := range predicateColumns { + // TODO: Remove this assertion once it has been confirmed to operate correctly over a period of time. + intest.Assert(!col.IsIndex, "predicate column should only be table column") + colMap[col] = t + } + s.statsCollector.UpdateColStatsUsage(colMap) +} + +// FieldList returns fields list of a table. +func (s *session) FieldList(tableName string) ([]*ast.ResultField, error) { + is := s.GetInfoSchema().(infoschema.InfoSchema) + dbName := model.NewCIStr(s.GetSessionVars().CurrentDB) + tName := model.NewCIStr(tableName) + pm := privilege.GetPrivilegeManager(s) + if pm != nil && s.sessionVars.User != nil { + if !pm.RequestVerification(s.sessionVars.ActiveRoles, dbName.O, tName.O, "", mysql.AllPrivMask) { + user := s.sessionVars.User + u := user.Username + h := user.Hostname + if len(user.AuthUsername) > 0 && len(user.AuthHostname) > 0 { + u = user.AuthUsername + h = user.AuthHostname + } + return nil, plannererrors.ErrTableaccessDenied.GenWithStackByArgs("SELECT", u, h, tableName) + } + } + table, err := is.TableByName(context.Background(), dbName, tName) + if err != nil { + return nil, err + } + + cols := table.Cols() + fields := make([]*ast.ResultField, 0, len(cols)) + for _, col := range table.Cols() { + rf := &ast.ResultField{ + ColumnAsName: col.Name, + TableAsName: tName, + DBName: dbName, + Table: table.Meta(), + Column: col.ColumnInfo, + } + fields = append(fields, rf) + } + return fields, nil +} + +// TxnInfo returns a pointer to a *copy* of the internal TxnInfo, thus is *read only* +func (s *session) TxnInfo() *txninfo.TxnInfo { + s.txn.mu.RLock() + // Copy on read to get a snapshot, this API shouldn't be frequently called. 
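The copy-on-read noted above is a common way to expose a snapshot of mutex-protected state without letting callers hold the lock. A minimal generic sketch of the same pattern (the types here are illustrative, not the real TxnInfo):

    package main

    import (
        "fmt"
        "sync"
    )

    type info struct {
        StartTS uint64
        State   string
    }

    type holder struct {
        mu   sync.RWMutex
        info info
    }

    // Snapshot copies the value while holding the read lock, so the caller can
    // inspect it after the lock is released while writers keep mutating h.info.
    func (h *holder) Snapshot() info {
        h.mu.RLock()
        defer h.mu.RUnlock()
        return h.info // struct copy, not a reference
    }

    func main() {
        h := &holder{info: info{StartTS: 42, State: "Idle"}}
        fmt.Println(h.Snapshot())
    }
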
+ txnInfo := s.txn.mu.TxnInfo + s.txn.mu.RUnlock() + + if txnInfo.StartTS == 0 { + return nil + } + + processInfo := s.ShowProcess() + if processInfo == nil { + return nil + } + txnInfo.ConnectionID = processInfo.ID + txnInfo.Username = processInfo.User + txnInfo.CurrentDB = processInfo.DB + txnInfo.RelatedTableIDs = make(map[int64]struct{}) + s.GetSessionVars().GetRelatedTableForMDL().Range(func(key, _ any) bool { + txnInfo.RelatedTableIDs[key.(int64)] = struct{}{} + return true + }) + + return &txnInfo +} + +func (s *session) doCommit(ctx context.Context) error { + if !s.txn.Valid() { + return nil + } + + defer func() { + s.txn.changeToInvalid() + s.sessionVars.SetInTxn(false) + s.sessionVars.ClearDiskFullOpt() + }() + // check if the transaction is read-only + if s.txn.IsReadOnly() { + return nil + } + // check if the cluster is read-only + if !s.sessionVars.InRestrictedSQL && (variable.RestrictedReadOnly.Load() || variable.VarTiDBSuperReadOnly.Load()) { + // It is not internal SQL, and the cluster has one of RestrictedReadOnly or SuperReadOnly + // We need to privilege check again: a privilege check occurred during planning, but we need + // to prevent the case that a long running auto-commit statement is now trying to commit. + pm := privilege.GetPrivilegeManager(s) + roles := s.sessionVars.ActiveRoles + if pm != nil && !pm.HasExplicitlyGrantedDynamicPrivilege(roles, "RESTRICTED_REPLICA_WRITER_ADMIN", false) { + s.RollbackTxn(ctx) + return plannererrors.ErrSQLInReadOnlyMode + } + } + err := s.checkPlacementPolicyBeforeCommit() + if err != nil { + return err + } + // mockCommitError and mockGetTSErrorInRetry use to test PR #8743. + failpoint.Inject("mockCommitError", func(val failpoint.Value) { + if val.(bool) { + if _, err := failpoint.Eval("tikvclient/mockCommitErrorOpt"); err == nil { + failpoint.Return(kv.ErrTxnRetryable) + } + } + }) + + sessVars := s.GetSessionVars() + + var commitTSChecker func(uint64) bool + if tables := sessVars.TxnCtx.CachedTables; len(tables) > 0 { + c := cachedTableRenewLease{tables: tables} + now := time.Now() + err := c.start(ctx) + defer c.stop(ctx) + sessVars.StmtCtx.WaitLockLeaseTime += time.Since(now) + if err != nil { + return errors.Trace(err) + } + commitTSChecker = c.commitTSCheck + } + if err = sessiontxn.GetTxnManager(s).SetOptionsBeforeCommit(s.txn.Transaction, commitTSChecker); err != nil { + return err + } + + err = s.commitTxnWithTemporaryData(tikvutil.SetSessionID(ctx, sessVars.ConnectionID), &s.txn) + if err != nil { + err = s.handleAssertionFailure(ctx, err) + } + return err +} + +type cachedTableRenewLease struct { + tables map[int64]any + lease []uint64 // Lease for each visited cached tables. + exit chan struct{} +} + +func (c *cachedTableRenewLease) start(ctx context.Context) error { + c.exit = make(chan struct{}) + c.lease = make([]uint64, len(c.tables)) + wg := make(chan error, len(c.tables)) + ith := 0 + for _, raw := range c.tables { + tbl := raw.(table.CachedTable) + go tbl.WriteLockAndKeepAlive(ctx, c.exit, &c.lease[ith], wg) + ith++ + } + + // Wait for all LockForWrite() return, this function can return. + var err error + for ; ith > 0; ith-- { + tmp := <-wg + if tmp != nil { + err = tmp + } + } + return err +} + +func (c *cachedTableRenewLease) stop(_ context.Context) { + close(c.exit) +} + +func (c *cachedTableRenewLease) commitTSCheck(commitTS uint64) bool { + for i := 0; i < len(c.lease); i++ { + lease := atomic.LoadUint64(&c.lease[i]) + if commitTS >= lease { + // Txn fails to commit because the write lease is expired. 
+ return false + } + } + return true +} + +// handleAssertionFailure extracts the possible underlying assertionFailed error, +// gets the corresponding MVCC history and logs it. +// If it's not an assertion failure, returns the original error. +func (s *session) handleAssertionFailure(ctx context.Context, err error) error { + var assertionFailure *tikverr.ErrAssertionFailed + if !stderrs.As(err, &assertionFailure) { + return err + } + key := assertionFailure.Key + newErr := kv.ErrAssertionFailed.GenWithStackByArgs( + hex.EncodeToString(key), assertionFailure.Assertion.String(), assertionFailure.StartTs, + assertionFailure.ExistingStartTs, assertionFailure.ExistingCommitTs, + ) + + rmode := s.GetSessionVars().EnableRedactLog + if rmode == errors.RedactLogEnable { + return newErr + } + + var decodeFunc func(kv.Key, *kvrpcpb.MvccGetByKeyResponse, map[string]any) + // if it's a record key or an index key, decode it + if infoSchema, ok := s.sessionVars.TxnCtx.InfoSchema.(infoschema.InfoSchema); ok && + infoSchema != nil && (tablecodec.IsRecordKey(key) || tablecodec.IsIndexKey(key)) { + tableOrPartitionID := tablecodec.DecodeTableID(key) + tbl, ok := infoSchema.TableByID(tableOrPartitionID) + if !ok { + tbl, _, _ = infoSchema.FindTableByPartitionID(tableOrPartitionID) + } + if tbl == nil { + logutil.Logger(ctx).Warn("cannot find table by id", zap.Int64("tableID", tableOrPartitionID), zap.String("key", hex.EncodeToString(key))) + return newErr + } + + if tablecodec.IsRecordKey(key) { + decodeFunc = consistency.DecodeRowMvccData(tbl.Meta()) + } else { + tableInfo := tbl.Meta() + _, indexID, _, e := tablecodec.DecodeIndexKey(key) + if e != nil { + logutil.Logger(ctx).Error("assertion failed but cannot decode index key", zap.Error(e)) + return newErr + } + var indexInfo *model.IndexInfo + for _, idx := range tableInfo.Indices { + if idx.ID == indexID { + indexInfo = idx + break + } + } + if indexInfo == nil { + return newErr + } + decodeFunc = consistency.DecodeIndexMvccData(indexInfo) + } + } + if store, ok := s.store.(helper.Storage); ok { + content := consistency.GetMvccByKey(store, key, decodeFunc) + logutil.Logger(ctx).Error("assertion failed", zap.String("message", newErr.Error()), zap.String("mvcc history", redact.String(rmode, content))) + } + return newErr +} + +func (s *session) commitTxnWithTemporaryData(ctx context.Context, txn kv.Transaction) error { + sessVars := s.sessionVars + txnTempTables := sessVars.TxnCtx.TemporaryTables + if len(txnTempTables) == 0 { + failpoint.Inject("mockSleepBeforeTxnCommit", func(v failpoint.Value) { + ms := v.(int) + time.Sleep(time.Millisecond * time.Duration(ms)) + }) + return txn.Commit(ctx) + } + + sessionData := sessVars.TemporaryTableData + var ( + stage kv.StagingHandle + localTempTables *infoschema.SessionTables + ) + + if sessVars.LocalTemporaryTables != nil { + localTempTables = sessVars.LocalTemporaryTables.(*infoschema.SessionTables) + } else { + localTempTables = new(infoschema.SessionTables) + } + + defer func() { + // stage != kv.InvalidStagingHandle means error occurs, we need to cleanup sessionData + if stage != kv.InvalidStagingHandle { + sessionData.Cleanup(stage) + } + }() + + for tblID, tbl := range txnTempTables { + if !tbl.GetModified() { + continue + } + + if tbl.GetMeta().TempTableType != model.TempTableLocal { + continue + } + if _, ok := localTempTables.TableByID(tblID); !ok { + continue + } + + if stage == kv.InvalidStagingHandle { + stage = sessionData.Staging() + } + + tblPrefix := tablecodec.EncodeTablePrefix(tblID) + endKey 
:= tablecodec.EncodeTablePrefix(tblID + 1) + + txnMemBuffer := s.txn.GetMemBuffer() + iter, err := txnMemBuffer.Iter(tblPrefix, endKey) + if err != nil { + return err + } + + for iter.Valid() { + key := iter.Key() + if !bytes.HasPrefix(key, tblPrefix) { + break + } + + value := iter.Value() + if len(value) == 0 { + err = sessionData.DeleteTableKey(tblID, key) + } else { + err = sessionData.SetTableKey(tblID, key, iter.Value()) + } + + if err != nil { + return err + } + + err = iter.Next() + if err != nil { + return err + } + } + } + + err := txn.Commit(ctx) + if err != nil { + return err + } + + if stage != kv.InvalidStagingHandle { + sessionData.Release(stage) + stage = kv.InvalidStagingHandle + } + + return nil +} + +// errIsNoisy is used to filter DUPLICATE KEY errors. +// These can observed by users in INFORMATION_SCHEMA.CLIENT_ERRORS_SUMMARY_GLOBAL instead. +// +// The rationale for filtering these errors is because they are "client generated errors". i.e. +// of the errors defined in kv/error.go, these look to be clearly related to a client-inflicted issue, +// and the server is only responsible for handling the error correctly. It does not need to log. +func errIsNoisy(err error) bool { + if kv.ErrKeyExists.Equal(err) { + return true + } + if storeerr.ErrLockAcquireFailAndNoWaitSet.Equal(err) { + return true + } + return false +} + +func (s *session) doCommitWithRetry(ctx context.Context) error { + defer func() { + s.GetSessionVars().SetTxnIsolationLevelOneShotStateForNextTxn() + s.txn.changeToInvalid() + s.cleanRetryInfo() + sessiontxn.GetTxnManager(s).OnTxnEnd() + }() + if !s.txn.Valid() { + // If the transaction is invalid, maybe it has already been rolled back by the client. + return nil + } + isInternalTxn := false + if internal := s.txn.GetOption(kv.RequestSourceInternal); internal != nil && internal.(bool) { + isInternalTxn = true + } + var err error + txnSize := s.txn.Size() + isPessimistic := s.txn.IsPessimistic() + isPipelined := s.txn.IsPipelined() + r, ctx := tracing.StartRegionEx(ctx, "session.doCommitWithRetry") + defer r.End() + + err = s.doCommit(ctx) + if err != nil { + // polish the Write Conflict error message + newErr := s.tryReplaceWriteConflictError(err) + if newErr != nil { + err = newErr + } + + commitRetryLimit := s.sessionVars.RetryLimit + if !s.sessionVars.TxnCtx.CouldRetry { + commitRetryLimit = 0 + } + // Don't retry in BatchInsert mode. As a counter-example, insert into t1 select * from t2, + // BatchInsert already commit the first batch 1000 rows, then it commit 1000-2000 and retry the statement, + // Finally t1 will have more data than t2, with no errors return to user! + if s.isTxnRetryableError(err) && !s.sessionVars.BatchInsert && commitRetryLimit > 0 && !isPessimistic && !isPipelined { + logutil.Logger(ctx).Warn("sql", + zap.String("label", s.GetSQLLabel()), + zap.Error(err), + zap.String("txn", s.txn.GoString())) + // Transactions will retry 2 ~ commitRetryLimit times. + // We make larger transactions retry less times to prevent cluster resource outage. 
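The comment above is implemented by shrinking the retry budget linearly as the transaction approaches kv.TxnTotalSizeLimit: small transactions keep nearly the whole budget, transactions near the limit get about one retry. A standalone sketch of the same formula with example numbers (the size limit used here is only an assumed example value):

    package main

    import "fmt"

    // maxRetryCount mirrors the scaling used in doCommitWithRetry.
    func maxRetryCount(commitRetryLimit int64, txnSize, txnTotalSizeLimit uint64) int64 {
        txnSizeRate := float64(txnSize) / float64(txnTotalSizeLimit)
        return commitRetryLimit - int64(float64(commitRetryLimit-1)*txnSizeRate)
    }

    func main() {
        const limit = 100 << 20 // assume a 100 MiB txn-total-size-limit
        fmt.Println(maxRetryCount(10, 1<<20, limit))  // 1 MiB txn  -> 10
        fmt.Println(maxRetryCount(10, 90<<20, limit)) // 90 MiB txn -> 2
    }
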
+ txnSizeRate := float64(txnSize) / float64(kv.TxnTotalSizeLimit.Load()) + maxRetryCount := commitRetryLimit - int64(float64(commitRetryLimit-1)*txnSizeRate) + err = s.retry(ctx, uint(maxRetryCount)) + } else if !errIsNoisy(err) { + logutil.Logger(ctx).Warn("can not retry txn", + zap.String("label", s.GetSQLLabel()), + zap.Error(err), + zap.Bool("IsBatchInsert", s.sessionVars.BatchInsert), + zap.Bool("IsPessimistic", isPessimistic), + zap.Bool("InRestrictedSQL", s.sessionVars.InRestrictedSQL), + zap.Int64("tidb_retry_limit", s.sessionVars.RetryLimit), + zap.Bool("tidb_disable_txn_auto_retry", s.sessionVars.DisableTxnAutoRetry)) + } + } + counter := s.sessionVars.TxnCtx.StatementCount + duration := time.Since(s.GetSessionVars().TxnCtx.CreateTime).Seconds() + s.recordOnTransactionExecution(err, counter, duration, isInternalTxn) + + if err != nil { + if !errIsNoisy(err) { + logutil.Logger(ctx).Warn("commit failed", + zap.String("finished txn", s.txn.GoString()), + zap.Error(err)) + } + return err + } + s.updateStatsDeltaToCollector() + return nil +} + +// adds more information about the table in the error message +// precondition: oldErr is a 9007:WriteConflict Error +func (s *session) tryReplaceWriteConflictError(oldErr error) (newErr error) { + if !kv.ErrWriteConflict.Equal(oldErr) { + return nil + } + if errors.RedactLogEnabled.Load() == errors.RedactLogEnable { + return nil + } + originErr := errors.Cause(oldErr) + inErr, _ := originErr.(*errors.Error) + // we don't want to modify the oldErr, so copy the args list + oldArgs := inErr.Args() + args := make([]any, len(oldArgs)) + copy(args, oldArgs) + is := sessiontxn.GetTxnManager(s).GetTxnInfoSchema() + if is == nil { + return nil + } + newKeyTableField, ok := addTableNameInTableIDField(args[3], is) + if ok { + args[3] = newKeyTableField + } + newPrimaryKeyTableField, ok := addTableNameInTableIDField(args[5], is) + if ok { + args[5] = newPrimaryKeyTableField + } + return kv.ErrWriteConflict.FastGenByArgs(args...) +} + +// precondition: is != nil +func addTableNameInTableIDField(tableIDField any, is infoschema.InfoSchema) (enhancedMsg string, done bool) { + keyTableID, ok := tableIDField.(string) + if !ok { + return "", false + } + stringsInTableIDField := strings.Split(keyTableID, "=") + if len(stringsInTableIDField) == 0 { + return "", false + } + tableIDStr := stringsInTableIDField[len(stringsInTableIDField)-1] + tableID, err := strconv.ParseInt(tableIDStr, 10, 64) + if err != nil { + return "", false + } + var tableName string + tbl, ok := is.TableByID(tableID) + if !ok { + tableName = "unknown" + } else { + dbInfo, ok := infoschema.SchemaByTable(is, tbl.Meta()) + if !ok { + tableName = "unknown." + tbl.Meta().Name.String() + } else { + tableName = dbInfo.Name.String() + "." 
+ tbl.Meta().Name.String() + } + } + enhancedMsg = keyTableID + ", tableName=" + tableName + return enhancedMsg, true +} + +func (s *session) updateStatsDeltaToCollector() { + mapper := s.GetSessionVars().TxnCtx.TableDeltaMap + if s.statsCollector != nil && mapper != nil { + for _, item := range mapper { + if item.TableID > 0 { + s.statsCollector.Update(item.TableID, item.Delta, item.Count, &item.ColSize) + } + } + } +} + +func (s *session) CommitTxn(ctx context.Context) error { + r, ctx := tracing.StartRegionEx(ctx, "session.CommitTxn") + defer r.End() + + var commitDetail *tikvutil.CommitDetails + ctx = context.WithValue(ctx, tikvutil.CommitDetailCtxKey, &commitDetail) + err := s.doCommitWithRetry(ctx) + if commitDetail != nil { + s.sessionVars.StmtCtx.MergeExecDetails(nil, commitDetail) + } + + // record the TTLInsertRows in the metric + metrics.TTLInsertRowsCount.Add(float64(s.sessionVars.TxnCtx.InsertTTLRowsCount)) + + failpoint.Inject("keepHistory", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(err) + } + }) + s.sessionVars.TxnCtx.Cleanup() + s.sessionVars.CleanupTxnReadTSIfUsed() + return err +} + +func (s *session) RollbackTxn(ctx context.Context) { + r, ctx := tracing.StartRegionEx(ctx, "session.RollbackTxn") + defer r.End() + + if s.txn.Valid() { + terror.Log(s.txn.Rollback()) + } + if ctx.Value(inCloseSession{}) == nil { + s.cleanRetryInfo() + } + s.txn.changeToInvalid() + s.sessionVars.TxnCtx.Cleanup() + s.sessionVars.CleanupTxnReadTSIfUsed() + s.sessionVars.SetInTxn(false) + sessiontxn.GetTxnManager(s).OnTxnEnd() +} + +func (s *session) GetClient() kv.Client { + return s.client +} + +func (s *session) GetMPPClient() kv.MPPClient { + return s.mppClient +} + +func (s *session) String() string { + // TODO: how to print binded context in values appropriately? + sessVars := s.sessionVars + data := map[string]any{ + "id": sessVars.ConnectionID, + "user": sessVars.User, + "currDBName": sessVars.CurrentDB, + "status": sessVars.Status(), + "strictMode": sessVars.SQLMode.HasStrictMode(), + } + if s.txn.Valid() { + // if txn is committed or rolled back, txn is nil. + data["txn"] = s.txn.String() + } + if sessVars.SnapshotTS != 0 { + data["snapshotTS"] = sessVars.SnapshotTS + } + if sessVars.StmtCtx.LastInsertID > 0 { + data["lastInsertID"] = sessVars.StmtCtx.LastInsertID + } + if len(sessVars.PreparedStmts) > 0 { + data["preparedStmtCount"] = len(sessVars.PreparedStmts) + } + b, err := json.MarshalIndent(data, "", " ") + terror.Log(errors.Trace(err)) + return string(b) +} + +const sqlLogMaxLen = 1024 + +// SchemaChangedWithoutRetry is used for testing. 
+var SchemaChangedWithoutRetry uint32 + +func (s *session) GetSQLLabel() string { + if s.sessionVars.InRestrictedSQL { + return metrics.LblInternal + } + return metrics.LblGeneral +} + +func (s *session) isInternal() bool { + return s.sessionVars.InRestrictedSQL +} + +func (*session) isTxnRetryableError(err error) bool { + if atomic.LoadUint32(&SchemaChangedWithoutRetry) == 1 { + return kv.IsTxnRetryableError(err) + } + return kv.IsTxnRetryableError(err) || domain.ErrInfoSchemaChanged.Equal(err) +} + +func isEndTxnStmt(stmt ast.StmtNode, vars *variable.SessionVars) (bool, error) { + switch n := stmt.(type) { + case *ast.RollbackStmt, *ast.CommitStmt: + return true, nil + case *ast.ExecuteStmt: + ps, err := plannercore.GetPreparedStmt(n, vars) + if err != nil { + return false, err + } + return isEndTxnStmt(ps.PreparedAst.Stmt, vars) + } + return false, nil +} + +func (s *session) checkTxnAborted(stmt sqlexec.Statement) error { + if atomic.LoadUint32(&s.GetSessionVars().TxnCtx.LockExpire) == 0 { + return nil + } + // If the transaction is aborted, the following statements do not need to execute, except `commit` and `rollback`, + // because they are used to finish the aborted transaction. + if ok, err := isEndTxnStmt(stmt.(*executor.ExecStmt).StmtNode, s.sessionVars); err == nil && ok { + return nil + } else if err != nil { + return err + } + return kv.ErrLockExpire +} + +func (s *session) retry(ctx context.Context, maxCnt uint) (err error) { + var retryCnt uint + defer func() { + s.sessionVars.RetryInfo.Retrying = false + // retryCnt only increments on retryable error, so +1 here. + if s.sessionVars.InRestrictedSQL { + session_metrics.TransactionRetryInternal.Observe(float64(retryCnt + 1)) + } else { + session_metrics.TransactionRetryGeneral.Observe(float64(retryCnt + 1)) + } + s.sessionVars.SetInTxn(false) + if err != nil { + s.RollbackTxn(ctx) + } + s.txn.changeToInvalid() + }() + + connID := s.sessionVars.ConnectionID + s.sessionVars.RetryInfo.Retrying = true + if atomic.LoadUint32(&s.sessionVars.TxnCtx.ForUpdate) == 1 { + err = ErrForUpdateCantRetry.GenWithStackByArgs(connID) + return err + } + + nh := GetHistory(s) + var schemaVersion int64 + sessVars := s.GetSessionVars() + orgStartTS := sessVars.TxnCtx.StartTS + label := s.GetSQLLabel() + for { + if err = s.PrepareTxnCtx(ctx); err != nil { + return err + } + s.sessionVars.RetryInfo.ResetOffset() + for i, sr := range nh.history { + st := sr.st + s.sessionVars.StmtCtx = sr.stmtCtx + s.sessionVars.StmtCtx.CTEStorageMap = map[int]*executor.CTEStorages{} + s.sessionVars.StmtCtx.ResetForRetry() + s.sessionVars.PlanCacheParams.Reset() + schemaVersion, err = st.RebuildPlan(ctx) + if err != nil { + return err + } + + if retryCnt == 0 { + // We do not have to log the query every time. + // We print the queries at the first try only. 
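+ // sqlForLog truncates the text to sqlLogMaxLen; plan-cache parameters are
+ // appended only when redact log is not fully enabled.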
+ sql := sqlForLog(st.GetTextToLog(false)) + if sessVars.EnableRedactLog != errors.RedactLogEnable { + sql += redact.String(sessVars.EnableRedactLog, sessVars.PlanCacheParams.String()) + } + logutil.Logger(ctx).Warn("retrying", + zap.Int64("schemaVersion", schemaVersion), + zap.Uint("retryCnt", retryCnt), + zap.Int("queryNum", i), + zap.String("sql", sql)) + } else { + logutil.Logger(ctx).Warn("retrying", + zap.Int64("schemaVersion", schemaVersion), + zap.Uint("retryCnt", retryCnt), + zap.Int("queryNum", i)) + } + _, digest := s.sessionVars.StmtCtx.SQLDigest() + s.txn.onStmtStart(digest.String()) + if err = sessiontxn.GetTxnManager(s).OnStmtStart(ctx, st.GetStmtNode()); err == nil { + _, err = st.Exec(ctx) + } + s.txn.onStmtEnd() + if err != nil { + s.StmtRollback(ctx, false) + break + } + s.StmtCommit(ctx) + } + logutil.Logger(ctx).Warn("transaction association", + zap.Uint64("retrying txnStartTS", s.GetSessionVars().TxnCtx.StartTS), + zap.Uint64("original txnStartTS", orgStartTS)) + failpoint.Inject("preCommitHook", func() { + hook, ok := ctx.Value("__preCommitHook").(func()) + if ok { + hook() + } + }) + if err == nil { + err = s.doCommit(ctx) + if err == nil { + break + } + } + if !s.isTxnRetryableError(err) { + logutil.Logger(ctx).Warn("sql", + zap.String("label", label), + zap.Stringer("session", s), + zap.Error(err)) + metrics.SessionRetryErrorCounter.WithLabelValues(label, metrics.LblUnretryable).Inc() + return err + } + retryCnt++ + if retryCnt >= maxCnt { + logutil.Logger(ctx).Warn("sql", + zap.String("label", label), + zap.Uint("retry reached max count", retryCnt)) + metrics.SessionRetryErrorCounter.WithLabelValues(label, metrics.LblReachMax).Inc() + return err + } + logutil.Logger(ctx).Warn("sql", + zap.String("label", label), + zap.Error(err), + zap.String("txn", s.txn.GoString())) + kv.BackOff(retryCnt) + s.txn.changeToInvalid() + s.sessionVars.SetInTxn(false) + } + return err +} + +func sqlForLog(sql string) string { + if len(sql) > sqlLogMaxLen { + sql = sql[:sqlLogMaxLen] + fmt.Sprintf("(len:%d)", len(sql)) + } + return executor.QueryReplacer.Replace(sql) +} + +func (s *session) sysSessionPool() util.SessionPool { + return domain.GetDomain(s).SysSessionPool() +} + +func createSessionFunc(store kv.Storage) pools.Factory { + return func() (pools.Resource, error) { + se, err := createSession(store) + if err != nil { + return nil, err + } + err = se.sessionVars.SetSystemVar(variable.AutoCommit, "1") + if err != nil { + return nil, err + } + err = se.sessionVars.SetSystemVar(variable.MaxExecutionTime, "0") + if err != nil { + return nil, errors.Trace(err) + } + err = se.sessionVars.SetSystemVar(variable.MaxAllowedPacket, strconv.FormatUint(variable.DefMaxAllowedPacket, 10)) + if err != nil { + return nil, errors.Trace(err) + } + err = se.sessionVars.SetSystemVar(variable.TiDBEnableWindowFunction, variable.BoolToOnOff(variable.DefEnableWindowFunction)) + if err != nil { + return nil, errors.Trace(err) + } + err = se.sessionVars.SetSystemVar(variable.TiDBConstraintCheckInPlacePessimistic, variable.On) + if err != nil { + return nil, errors.Trace(err) + } + se.sessionVars.CommonGlobalLoaded = true + se.sessionVars.InRestrictedSQL = true + // Internal session uses default format to prevent memory leak problem. 
+ se.sessionVars.EnableChunkRPC = false + return se, nil + } +} + +func createSessionWithDomainFunc(store kv.Storage) func(*domain.Domain) (pools.Resource, error) { + return func(dom *domain.Domain) (pools.Resource, error) { + se, err := CreateSessionWithDomain(store, dom) + if err != nil { + return nil, err + } + err = se.sessionVars.SetSystemVar(variable.AutoCommit, "1") + if err != nil { + return nil, err + } + err = se.sessionVars.SetSystemVar(variable.MaxExecutionTime, "0") + if err != nil { + return nil, errors.Trace(err) + } + err = se.sessionVars.SetSystemVar(variable.MaxAllowedPacket, strconv.FormatUint(variable.DefMaxAllowedPacket, 10)) + if err != nil { + return nil, errors.Trace(err) + } + err = se.sessionVars.SetSystemVar(variable.TiDBConstraintCheckInPlacePessimistic, variable.On) + if err != nil { + return nil, errors.Trace(err) + } + se.sessionVars.CommonGlobalLoaded = true + se.sessionVars.InRestrictedSQL = true + // Internal session uses default format to prevent memory leak problem. + se.sessionVars.EnableChunkRPC = false + return se, nil + } +} + +func drainRecordSet(ctx context.Context, se *session, rs sqlexec.RecordSet, alloc chunk.Allocator) ([]chunk.Row, error) { + var rows []chunk.Row + var req *chunk.Chunk + req = rs.NewChunk(alloc) + for { + err := rs.Next(ctx, req) + if err != nil || req.NumRows() == 0 { + return rows, err + } + iter := chunk.NewIterator4Chunk(req) + for r := iter.Begin(); r != iter.End(); r = iter.Next() { + rows = append(rows, r) + } + req = chunk.Renew(req, se.sessionVars.MaxChunkSize) + } +} + +// getTableValue executes restricted sql and the result is one column. +// It returns a string value. +func (s *session) getTableValue(ctx context.Context, tblName string, varName string) (string, error) { + if ctx.Value(kv.RequestSourceKey) == nil { + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnSysVar) + } + rows, fields, err := s.ExecRestrictedSQL(ctx, nil, "SELECT VARIABLE_VALUE FROM %n.%n WHERE VARIABLE_NAME=%?", mysql.SystemDB, tblName, varName) + if err != nil { + return "", err + } + if len(rows) == 0 { + return "", errResultIsEmpty + } + d := rows[0].GetDatum(0, &fields[0].Column.FieldType) + value, err := d.ToString() + if err != nil { + return "", err + } + return value, nil +} + +// replaceGlobalVariablesTableValue executes restricted sql updates the variable value +// It will then notify the etcd channel that the value has changed. +func (s *session) replaceGlobalVariablesTableValue(ctx context.Context, varName, val string, updateLocal bool) error { + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnSysVar) + _, _, err := s.ExecRestrictedSQL(ctx, nil, `REPLACE INTO %n.%n (variable_name, variable_value) VALUES (%?, %?)`, mysql.SystemDB, mysql.GlobalVariablesTable, varName, val) + if err != nil { + return err + } + domain.GetDomain(s).NotifyUpdateSysVarCache(updateLocal) + return err +} + +// GetGlobalSysVar implements GlobalVarAccessor.GetGlobalSysVar interface. +func (s *session) GetGlobalSysVar(name string) (string, error) { + if s.Value(sessionctx.Initing) != nil { + // When running bootstrap or upgrade, we should not access global storage. + return "", nil + } + + sv := variable.GetSysVar(name) + if sv == nil { + // It might be a recently unregistered sysvar. We should return unknown + // since GetSysVar is the canonical version, but we can update the cache + // so the next request doesn't attempt to load this. + logutil.BgLogger().Info("sysvar does not exist. 
sysvar cache may be stale", zap.String("name", name)) + return "", variable.ErrUnknownSystemVar.GenWithStackByArgs(name) + } + + sysVar, err := domain.GetDomain(s).GetGlobalVar(name) + if err != nil { + // The sysvar exists, but there is no cache entry yet. + // This might be because the sysvar was only recently registered. + // In which case it is safe to return the default, but we can also + // update the cache for the future. + logutil.BgLogger().Info("sysvar not in cache yet. sysvar cache may be stale", zap.String("name", name)) + sysVar, err = s.getTableValue(context.TODO(), mysql.GlobalVariablesTable, name) + if err != nil { + return sv.Value, nil + } + } + // It might have been written from an earlier TiDB version, so we should do type validation + // See https://github.com/pingcap/tidb/issues/30255 for why we don't do full validation. + // If validation fails, we should return the default value: + // See: https://github.com/pingcap/tidb/pull/31566 + sysVar, err = sv.ValidateFromType(s.GetSessionVars(), sysVar, variable.ScopeGlobal) + if err != nil { + return sv.Value, nil + } + return sysVar, nil +} + +// SetGlobalSysVar implements GlobalVarAccessor.SetGlobalSysVar interface. +// it is called (but skipped) when setting instance scope +func (s *session) SetGlobalSysVar(ctx context.Context, name string, value string) (err error) { + sv := variable.GetSysVar(name) + if sv == nil { + return variable.ErrUnknownSystemVar.GenWithStackByArgs(name) + } + if value, err = sv.Validate(s.sessionVars, value, variable.ScopeGlobal); err != nil { + return err + } + if err = sv.SetGlobalFromHook(ctx, s.sessionVars, value, false); err != nil { + return err + } + if sv.HasInstanceScope() { // skip for INSTANCE scope + return nil + } + if sv.GlobalConfigName != "" { + domain.GetDomain(s).NotifyGlobalConfigChange(sv.GlobalConfigName, variable.OnOffToTrueFalse(value)) + } + return s.replaceGlobalVariablesTableValue(context.TODO(), sv.Name, value, true) +} + +// SetGlobalSysVarOnly updates the sysvar, but does not call the validation function or update aliases. +// This is helpful to prevent duplicate warnings being appended from aliases, or recursion. +// updateLocal indicates whether to rebuild the local SysVar Cache. This is helpful to prevent recursion. +func (s *session) SetGlobalSysVarOnly(ctx context.Context, name string, value string, updateLocal bool) (err error) { + sv := variable.GetSysVar(name) + if sv == nil { + return variable.ErrUnknownSystemVar.GenWithStackByArgs(name) + } + if err = sv.SetGlobalFromHook(ctx, s.sessionVars, value, true); err != nil { + return err + } + if sv.HasInstanceScope() { // skip for INSTANCE scope + return nil + } + return s.replaceGlobalVariablesTableValue(ctx, sv.Name, value, updateLocal) +} + +// SetTiDBTableValue implements GlobalVarAccessor.SetTiDBTableValue interface. +func (s *session) SetTiDBTableValue(name, value, comment string) error { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnSysVar) + _, _, err := s.ExecRestrictedSQL(ctx, nil, `REPLACE INTO mysql.tidb (variable_name, variable_value, comment) VALUES (%?, %?, %?)`, name, value, comment) + return err +} + +// GetTiDBTableValue implements GlobalVarAccessor.GetTiDBTableValue interface. 
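+// The value is read from the mysql.tidb table via getTableValue; when no row
+// exists the lookup surfaces errResultIsEmpty.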
+func (s *session) GetTiDBTableValue(name string) (string, error) { + return s.getTableValue(context.TODO(), mysql.TiDBTable, name) +} + +var _ sqlexec.SQLParser = &session{} + +func (s *session) ParseSQL(ctx context.Context, sql string, params ...parser.ParseParam) ([]ast.StmtNode, []error, error) { + defer tracing.StartRegion(ctx, "ParseSQL").End() + + p := parserPool.Get().(*parser.Parser) + defer parserPool.Put(p) + + sqlMode := s.sessionVars.SQLMode + if s.isInternal() { + sqlMode = mysql.DelSQLMode(sqlMode, mysql.ModeNoBackslashEscapes) + } + p.SetSQLMode(sqlMode) + p.SetParserConfig(s.sessionVars.BuildParserConfig()) + tmp, warn, err := p.ParseSQL(sql, params...) + // The []ast.StmtNode is referenced by the parser, to reuse the parser, make a copy of the result. + res := make([]ast.StmtNode, len(tmp)) + copy(res, tmp) + return res, warn, err +} + +func (s *session) SetProcessInfo(sql string, t time.Time, command byte, maxExecutionTime uint64) { + // If command == mysql.ComSleep, it means the SQL execution is finished. The processinfo is reset to SLEEP. + // If the SQL finished and the session is not in transaction, the current start timestamp need to reset to 0. + // Otherwise, it should be set to the transaction start timestamp. + // Why not reset the transaction start timestamp to 0 when transaction committed? + // Because the select statement and other statements need this timestamp to read data, + // after the transaction is committed. e.g. SHOW MASTER STATUS; + var curTxnStartTS uint64 + var curTxnCreateTime time.Time + if command != mysql.ComSleep || s.GetSessionVars().InTxn() { + curTxnStartTS = s.sessionVars.TxnCtx.StartTS + curTxnCreateTime = s.sessionVars.TxnCtx.CreateTime + } + // Set curTxnStartTS to SnapshotTS directly when the session is trying to historic read. + // It will avoid the session meet GC lifetime too short error. + if s.GetSessionVars().SnapshotTS != 0 { + curTxnStartTS = s.GetSessionVars().SnapshotTS + } + p := s.currentPlan + if explain, ok := p.(*plannercore.Explain); ok && explain.Analyze && explain.TargetPlan != nil { + p = explain.TargetPlan + } + + pi := util.ProcessInfo{ + ID: s.sessionVars.ConnectionID, + Port: s.sessionVars.Port, + DB: s.sessionVars.CurrentDB, + Command: command, + Plan: p, + PlanExplainRows: plannercore.GetExplainRowsForPlan(p), + RuntimeStatsColl: s.sessionVars.StmtCtx.RuntimeStatsColl, + Time: t, + State: s.Status(), + Info: sql, + CurTxnStartTS: curTxnStartTS, + CurTxnCreateTime: curTxnCreateTime, + StmtCtx: s.sessionVars.StmtCtx, + RefCountOfStmtCtx: &s.sessionVars.RefCountOfStmtCtx, + MemTracker: s.sessionVars.MemTracker, + DiskTracker: s.sessionVars.DiskTracker, + StatsInfo: plannercore.GetStatsInfo, + OOMAlarmVariablesInfo: s.getOomAlarmVariablesInfo(), + TableIDs: s.sessionVars.StmtCtx.TableIDs, + IndexNames: s.sessionVars.StmtCtx.IndexNames, + MaxExecutionTime: maxExecutionTime, + RedactSQL: s.sessionVars.EnableRedactLog, + ResourceGroupName: s.sessionVars.StmtCtx.ResourceGroupName, + SessionAlias: s.sessionVars.SessionAlias, + CursorTracker: s.cursorTracker, + } + oldPi := s.ShowProcess() + if p == nil { + // Store the last valid plan when the current plan is nil. + // This is for `explain for connection` statement has the ability to query the last valid plan. 
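+ // Only fall back to the previous plan when it actually produced explain
+ // rows; otherwise the process info keeps a nil plan for this round.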
+ if oldPi != nil && oldPi.Plan != nil && len(oldPi.PlanExplainRows) > 0 { + pi.Plan = oldPi.Plan + pi.PlanExplainRows = oldPi.PlanExplainRows + pi.RuntimeStatsColl = oldPi.RuntimeStatsColl + } + } + // We set process info before building plan, so we extended execution time. + if oldPi != nil && oldPi.Info == pi.Info && oldPi.Command == pi.Command { + pi.Time = oldPi.Time + } + if oldPi != nil && oldPi.CurTxnStartTS != 0 && oldPi.CurTxnStartTS == pi.CurTxnStartTS { + // Keep the last expensive txn log time, avoid print too many expensive txn logs. + pi.ExpensiveTxnLogTime = oldPi.ExpensiveTxnLogTime + } + _, digest := s.sessionVars.StmtCtx.SQLDigest() + pi.Digest = digest.String() + // DO NOT reset the currentPlan to nil until this query finishes execution, otherwise reentrant calls + // of SetProcessInfo would override Plan and PlanExplainRows to nil. + if command == mysql.ComSleep { + s.currentPlan = nil + } + if s.sessionVars.User != nil { + pi.User = s.sessionVars.User.Username + pi.Host = s.sessionVars.User.Hostname + } + s.processInfo.Store(&pi) +} + +// UpdateProcessInfo updates the session's process info for the running statement. +func (s *session) UpdateProcessInfo() { + pi := s.ShowProcess() + if pi == nil || pi.CurTxnStartTS != 0 { + return + } + // do not modify this two fields in place, see issue: issues/50607 + shallowCP := pi.Clone() + // Update the current transaction start timestamp. + shallowCP.CurTxnStartTS = s.sessionVars.TxnCtx.StartTS + shallowCP.CurTxnCreateTime = s.sessionVars.TxnCtx.CreateTime + s.processInfo.Store(shallowCP) +} + +func (s *session) getOomAlarmVariablesInfo() util.OOMAlarmVariablesInfo { + return util.OOMAlarmVariablesInfo{ + SessionAnalyzeVersion: s.sessionVars.AnalyzeVersion, + SessionEnabledRateLimitAction: s.sessionVars.EnabledRateLimitAction, + SessionMemQuotaQuery: s.sessionVars.MemQuotaQuery, + } +} + +func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...any) (rs sqlexec.RecordSet, err error) { + origin := s.sessionVars.InRestrictedSQL + s.sessionVars.InRestrictedSQL = true + defer func() { + s.sessionVars.InRestrictedSQL = origin + // Restore the goroutine label by using the original ctx after execution is finished. + pprof.SetGoroutineLabels(ctx) + }() + + r, ctx := tracing.StartRegionEx(ctx, "session.ExecuteInternal") + defer r.End() + logutil.Eventf(ctx, "execute: %s", sql) + + stmtNode, err := s.ParseWithParams(ctx, sql, args...) + if err != nil { + return nil, err + } + + rs, err = s.ExecuteStmt(ctx, stmtNode) + if err != nil { + s.sessionVars.StmtCtx.AppendError(err) + } + if rs == nil { + return nil, err + } + + return rs, err +} + +// Execute is deprecated, we can remove it as soon as plugins are migrated. +func (s *session) Execute(ctx context.Context, sql string) (recordSets []sqlexec.RecordSet, err error) { + r, ctx := tracing.StartRegionEx(ctx, "session.Execute") + defer r.End() + logutil.Eventf(ctx, "execute: %s", sql) + + stmtNodes, err := s.Parse(ctx, sql) + if err != nil { + return nil, err + } + if len(stmtNodes) != 1 { + return nil, errors.New("Execute() API doesn't support multiple statements any more") + } + + rs, err := s.ExecuteStmt(ctx, stmtNodes[0]) + if err != nil { + s.sessionVars.StmtCtx.AppendError(err) + } + if rs == nil { + return nil, err + } + return []sqlexec.RecordSet{rs}, err +} + +// Parse parses a query string to raw ast.StmtNode. 
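+// It loads the common global variables first so the parser sees the current
+// sql_mode, records parse-duration metrics, and converts parser warnings into
+// statement-context warnings.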
+func (s *session) Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) { + logutil.Logger(ctx).Debug("parse", zap.String("sql", sql)) + parseStartTime := time.Now() + + // Load the session variables to the context. + // This is necessary for the parser to get the current sql_mode. + if err := s.loadCommonGlobalVariablesIfNeeded(); err != nil { + return nil, err + } + + stmts, warns, err := s.ParseSQL(ctx, sql, s.sessionVars.GetParseParams()...) + if err != nil { + s.rollbackOnError(ctx) + err = util.SyntaxError(err) + + // Only print log message when this SQL is from the user. + // Mute the warning for internal SQLs. + if !s.sessionVars.InRestrictedSQL { + logutil.Logger(ctx).Warn("parse SQL failed", zap.Error(err), zap.String("SQL", redact.String(s.sessionVars.EnableRedactLog, sql))) + s.sessionVars.StmtCtx.AppendError(err) + } + return nil, err + } + + durParse := time.Since(parseStartTime) + s.GetSessionVars().DurationParse = durParse + isInternal := s.isInternal() + if isInternal { + session_metrics.SessionExecuteParseDurationInternal.Observe(durParse.Seconds()) + } else { + session_metrics.SessionExecuteParseDurationGeneral.Observe(durParse.Seconds()) + } + for _, warn := range warns { + s.sessionVars.StmtCtx.AppendWarning(util.SyntaxWarn(warn)) + } + return stmts, nil +} + +// ParseWithParams parses a query string, with arguments, to raw ast.StmtNode. +// Note that it will not do escaping if no variable arguments are passed. +func (s *session) ParseWithParams(ctx context.Context, sql string, args ...any) (ast.StmtNode, error) { + var err error + if len(args) > 0 { + sql, err = sqlescape.EscapeSQL(sql, args...) + if err != nil { + return nil, err + } + } + + internal := s.isInternal() + + var stmts []ast.StmtNode + var warns []error + parseStartTime := time.Now() + if internal { + // Do no respect the settings from clients, if it is for internal usage. + // Charsets from clients may give chance injections. + // Refer to https://stackoverflow.com/questions/5741187/sql-injection-that-gets-around-mysql-real-escape-string/12118602. + stmts, warns, err = s.ParseSQL(ctx, sql) + } else { + stmts, warns, err = s.ParseSQL(ctx, sql, s.sessionVars.GetParseParams()...) + } + if len(stmts) != 1 && err == nil { + err = errors.New("run multiple statements internally is not supported") + } + if err != nil { + s.rollbackOnError(ctx) + logSQL := sql[:min(500, len(sql))] + logutil.Logger(ctx).Warn("parse SQL failed", zap.Error(err), zap.String("SQL", redact.String(s.sessionVars.EnableRedactLog, logSQL))) + return nil, util.SyntaxError(err) + } + durParse := time.Since(parseStartTime) + if internal { + session_metrics.SessionExecuteParseDurationInternal.Observe(durParse.Seconds()) + } else { + session_metrics.SessionExecuteParseDurationGeneral.Observe(durParse.Seconds()) + } + for _, warn := range warns { + s.sessionVars.StmtCtx.AppendWarning(util.SyntaxWarn(warn)) + } + if topsqlstate.TopSQLEnabled() { + normalized, digest := parser.NormalizeDigest(sql) + if digest != nil { + // Reset the goroutine label when internal sql execute finish. + // Specifically reset in ExecRestrictedStmt function. + s.sessionVars.StmtCtx.IsSQLRegistered.Store(true) + topsql.AttachAndRegisterSQLInfo(ctx, normalized, digest, s.sessionVars.InRestrictedSQL) + } + } + return stmts[0], nil +} + +// GetAdvisoryLock acquires an advisory lock of lockName. +// Note that a lock can be acquired multiple times by the same session, +// in which case we increment a reference count. 
+// Each lock needs to be held in a unique session because +// we need to be able to ROLLBACK in any arbitrary order +// in order to release the locks. +func (s *session) GetAdvisoryLock(lockName string, timeout int64) error { + if lock, ok := s.advisoryLocks[lockName]; ok { + lock.IncrReferences() + return nil + } + sess, err := createSession(s.store) + if err != nil { + return err + } + infosync.StoreInternalSession(sess) + lock := &advisoryLock{session: sess, ctx: context.TODO(), owner: s.ShowProcess().ID} + err = lock.GetLock(lockName, timeout) + if err != nil { + return err + } + s.advisoryLocks[lockName] = lock + return nil +} + +// IsUsedAdvisoryLock checks if a lockName is already in use +func (s *session) IsUsedAdvisoryLock(lockName string) uint64 { + // Same session + if lock, ok := s.advisoryLocks[lockName]; ok { + return lock.owner + } + + // Check for transaction on advisory_locks table + sess, err := createSession(s.store) + if err != nil { + return 0 + } + lock := &advisoryLock{session: sess, ctx: context.TODO(), owner: s.ShowProcess().ID} + err = lock.IsUsedLock(lockName) + if err != nil { + // TODO: Return actual owner pid + // TODO: Check for mysql.ErrLockWaitTimeout and DeadLock + return 1 + } + return 0 +} + +// ReleaseAdvisoryLock releases an advisory locks held by the session. +// It returns FALSE if no lock by this name was held (by this session), +// and TRUE if a lock was held and "released". +// Note that the lock is not actually released if there are multiple +// references to the same lockName by the session, instead the reference +// count is decremented. +func (s *session) ReleaseAdvisoryLock(lockName string) (released bool) { + if lock, ok := s.advisoryLocks[lockName]; ok { + lock.DecrReferences() + if lock.ReferenceCount() <= 0 { + lock.Close() + delete(s.advisoryLocks, lockName) + infosync.DeleteInternalSession(lock.session) + } + return true + } + return false +} + +// ReleaseAllAdvisoryLocks releases all advisory locks held by the session +// and returns a count of the locks that were released. +// The count is based on unique locks held, so multiple references +// to the same lock do not need to be accounted for. +func (s *session) ReleaseAllAdvisoryLocks() int { + var count int + for lockName, lock := range s.advisoryLocks { + lock.Close() + count += lock.ReferenceCount() + delete(s.advisoryLocks, lockName) + infosync.DeleteInternalSession(lock.session) + } + return count +} + +// GetExtensions returns the `*extension.SessionExtensions` object +func (s *session) GetExtensions() *extension.SessionExtensions { + return s.extensions +} + +// SetExtensions sets the `*extension.SessionExtensions` object +func (s *session) SetExtensions(extensions *extension.SessionExtensions) { + s.extensions = extensions +} + +// InSandBoxMode indicates that this session is in sandbox mode +func (s *session) InSandBoxMode() bool { + return s.sandBoxMode +} + +// EnableSandBoxMode enable the sandbox mode. +func (s *session) EnableSandBoxMode() { + s.sandBoxMode = true +} + +// DisableSandBoxMode enable the sandbox mode. 
+func (s *session) DisableSandBoxMode() { + s.sandBoxMode = false +} + +// ParseWithParams4Test wrapper (s *session) ParseWithParams for test +func ParseWithParams4Test(ctx context.Context, s types.Session, + sql string, args ...any) (ast.StmtNode, error) { + return s.(*session).ParseWithParams(ctx, sql, args) +} + +var _ sqlexec.RestrictedSQLExecutor = &session{} +var _ sqlexec.SQLExecutor = &session{} + +// ExecRestrictedStmt implements RestrictedSQLExecutor interface. +func (s *session) ExecRestrictedStmt(ctx context.Context, stmtNode ast.StmtNode, opts ...sqlexec.OptionFuncAlias) ( + []chunk.Row, []*ast.ResultField, error) { + defer pprof.SetGoroutineLabels(ctx) + execOption := sqlexec.GetExecOption(opts) + var se *session + var clean func() + var err error + if execOption.UseCurSession { + se, clean, err = s.useCurrentSession(execOption) + } else { + se, clean, err = s.getInternalSession(execOption) + } + if err != nil { + return nil, nil, err + } + defer clean() + + startTime := time.Now() + metrics.SessionRestrictedSQLCounter.Inc() + ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) + ctx = context.WithValue(ctx, tikvutil.ExecDetailsKey, &tikvutil.ExecDetails{}) + ctx = context.WithValue(ctx, tikvutil.RUDetailsCtxKey, tikvutil.NewRUDetails()) + rs, err := se.ExecuteStmt(ctx, stmtNode) + if err != nil { + se.sessionVars.StmtCtx.AppendError(err) + } + if rs == nil { + return nil, nil, err + } + defer func() { + if closeErr := rs.Close(); closeErr != nil { + err = closeErr + } + }() + var rows []chunk.Row + rows, err = drainRecordSet(ctx, se, rs, nil) + if err != nil { + return nil, nil, err + } + + vars := se.GetSessionVars() + for _, dbName := range GetDBNames(vars) { + metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal, dbName, vars.StmtCtx.ResourceGroupName).Observe(time.Since(startTime).Seconds()) + } + return rows, rs.Fields(), err +} + +// ExecRestrictedStmt4Test wrapper `(s *session) ExecRestrictedStmt` for test. +func ExecRestrictedStmt4Test(ctx context.Context, s types.Session, + stmtNode ast.StmtNode, opts ...sqlexec.OptionFuncAlias) ( + []chunk.Row, []*ast.ResultField, error) { + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnOthers) + return s.(*session).ExecRestrictedStmt(ctx, stmtNode, opts...) 
+} + +// only set and clean session with execOption +func (s *session) useCurrentSession(execOption sqlexec.ExecOption) (*session, func(), error) { + var err error + orgSnapshotInfoSchema, orgSnapshotTS := s.sessionVars.SnapshotInfoschema, s.sessionVars.SnapshotTS + if execOption.SnapshotTS != 0 { + if err = s.sessionVars.SetSystemVar(variable.TiDBSnapshot, strconv.FormatUint(execOption.SnapshotTS, 10)); err != nil { + return nil, nil, err + } + s.sessionVars.SnapshotInfoschema, err = getSnapshotInfoSchema(s, execOption.SnapshotTS) + if err != nil { + return nil, nil, err + } + } + prevStatsVer := s.sessionVars.AnalyzeVersion + if execOption.AnalyzeVer != 0 { + s.sessionVars.AnalyzeVersion = execOption.AnalyzeVer + } + prevAnalyzeSnapshot := s.sessionVars.EnableAnalyzeSnapshot + if execOption.AnalyzeSnapshot != nil { + s.sessionVars.EnableAnalyzeSnapshot = *execOption.AnalyzeSnapshot + } + prePruneMode := s.sessionVars.PartitionPruneMode.Load() + if len(execOption.PartitionPruneMode) > 0 { + s.sessionVars.PartitionPruneMode.Store(execOption.PartitionPruneMode) + } + prevSQL := s.sessionVars.StmtCtx.OriginalSQL + prevStmtType := s.sessionVars.StmtCtx.StmtType + prevTables := s.sessionVars.StmtCtx.Tables + return s, func() { + s.sessionVars.AnalyzeVersion = prevStatsVer + s.sessionVars.EnableAnalyzeSnapshot = prevAnalyzeSnapshot + if err := s.sessionVars.SetSystemVar(variable.TiDBSnapshot, ""); err != nil { + logutil.BgLogger().Error("set tidbSnapshot error", zap.Error(err)) + } + s.sessionVars.SnapshotInfoschema = orgSnapshotInfoSchema + s.sessionVars.SnapshotTS = orgSnapshotTS + s.sessionVars.PartitionPruneMode.Store(prePruneMode) + s.sessionVars.StmtCtx.OriginalSQL = prevSQL + s.sessionVars.StmtCtx.StmtType = prevStmtType + s.sessionVars.StmtCtx.Tables = prevTables + s.sessionVars.MemTracker.Detach() + }, nil +} + +func (s *session) getInternalSession(execOption sqlexec.ExecOption) (*session, func(), error) { + tmp, err := s.sysSessionPool().Get() + if err != nil { + return nil, nil, errors.Trace(err) + } + se := tmp.(*session) + + // The special session will share the `InspectionTableCache` with current session + // if the current session in inspection mode. 
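+ // Every setting borrowed from the caller below is reset in the cleanup
+ // closure before the internal session goes back to the pool.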
+ if cache := s.sessionVars.InspectionTableCache; cache != nil { + se.sessionVars.InspectionTableCache = cache + } + se.sessionVars.OptimizerUseInvisibleIndexes = s.sessionVars.OptimizerUseInvisibleIndexes + + preSkipStats := s.sessionVars.SkipMissingPartitionStats + se.sessionVars.SkipMissingPartitionStats = s.sessionVars.SkipMissingPartitionStats + + if execOption.SnapshotTS != 0 { + if err := se.sessionVars.SetSystemVar(variable.TiDBSnapshot, strconv.FormatUint(execOption.SnapshotTS, 10)); err != nil { + return nil, nil, err + } + se.sessionVars.SnapshotInfoschema, err = getSnapshotInfoSchema(s, execOption.SnapshotTS) + if err != nil { + return nil, nil, err + } + } + + prevStatsVer := se.sessionVars.AnalyzeVersion + if execOption.AnalyzeVer != 0 { + se.sessionVars.AnalyzeVersion = execOption.AnalyzeVer + } + + prevAnalyzeSnapshot := se.sessionVars.EnableAnalyzeSnapshot + if execOption.AnalyzeSnapshot != nil { + se.sessionVars.EnableAnalyzeSnapshot = *execOption.AnalyzeSnapshot + } + + prePruneMode := se.sessionVars.PartitionPruneMode.Load() + if len(execOption.PartitionPruneMode) > 0 { + se.sessionVars.PartitionPruneMode.Store(execOption.PartitionPruneMode) + } + + return se, func() { + se.sessionVars.AnalyzeVersion = prevStatsVer + se.sessionVars.EnableAnalyzeSnapshot = prevAnalyzeSnapshot + if err := se.sessionVars.SetSystemVar(variable.TiDBSnapshot, ""); err != nil { + logutil.BgLogger().Error("set tidbSnapshot error", zap.Error(err)) + } + se.sessionVars.SnapshotInfoschema = nil + se.sessionVars.SnapshotTS = 0 + if !execOption.IgnoreWarning { + if se != nil && se.GetSessionVars().StmtCtx.WarningCount() > 0 { + warnings := se.GetSessionVars().StmtCtx.GetWarnings() + s.GetSessionVars().StmtCtx.AppendWarnings(warnings) + } + } + se.sessionVars.PartitionPruneMode.Store(prePruneMode) + se.sessionVars.OptimizerUseInvisibleIndexes = false + se.sessionVars.SkipMissingPartitionStats = preSkipStats + se.sessionVars.InspectionTableCache = nil + se.sessionVars.MemTracker.Detach() + s.sysSessionPool().Put(tmp) + }, nil +} + +func (s *session) withRestrictedSQLExecutor(ctx context.Context, opts []sqlexec.OptionFuncAlias, fn func(context.Context, *session) ([]chunk.Row, []*ast.ResultField, error)) ([]chunk.Row, []*ast.ResultField, error) { + execOption := sqlexec.GetExecOption(opts) + var se *session + var clean func() + var err error + if execOption.UseCurSession { + se, clean, err = s.useCurrentSession(execOption) + } else { + se, clean, err = s.getInternalSession(execOption) + } + if err != nil { + return nil, nil, errors.Trace(err) + } + defer clean() + if execOption.TrackSysProcID > 0 { + err = execOption.TrackSysProc(execOption.TrackSysProcID, se) + if err != nil { + return nil, nil, errors.Trace(err) + } + // unTrack should be called before clean (return sys session) + defer execOption.UnTrackSysProc(execOption.TrackSysProcID) + } + return fn(ctx, se) +} + +func (s *session) ExecRestrictedSQL(ctx context.Context, opts []sqlexec.OptionFuncAlias, sql string, params ...any) ([]chunk.Row, []*ast.ResultField, error) { + return s.withRestrictedSQLExecutor(ctx, opts, func(ctx context.Context, se *session) ([]chunk.Row, []*ast.ResultField, error) { + stmt, err := se.ParseWithParams(ctx, sql, params...) 
+ if err != nil { + return nil, nil, errors.Trace(err) + } + defer pprof.SetGoroutineLabels(ctx) + startTime := time.Now() + metrics.SessionRestrictedSQLCounter.Inc() + ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) + ctx = context.WithValue(ctx, tikvutil.ExecDetailsKey, &tikvutil.ExecDetails{}) + rs, err := se.ExecuteInternalStmt(ctx, stmt) + if err != nil { + se.sessionVars.StmtCtx.AppendError(err) + } + if rs == nil { + return nil, nil, err + } + defer func() { + if closeErr := rs.Close(); closeErr != nil { + err = closeErr + } + }() + var rows []chunk.Row + rows, err = drainRecordSet(ctx, se, rs, nil) + if err != nil { + return nil, nil, err + } + + vars := se.GetSessionVars() + for _, dbName := range GetDBNames(vars) { + metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal, dbName, vars.StmtCtx.ResourceGroupName).Observe(time.Since(startTime).Seconds()) + } + return rows, rs.Fields(), err + }) +} + +// ExecuteInternalStmt execute internal stmt +func (s *session) ExecuteInternalStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlexec.RecordSet, error) { + origin := s.sessionVars.InRestrictedSQL + s.sessionVars.InRestrictedSQL = true + defer func() { + s.sessionVars.InRestrictedSQL = origin + // Restore the goroutine label by using the original ctx after execution is finished. + pprof.SetGoroutineLabels(ctx) + }() + return s.ExecuteStmt(ctx, stmtNode) +} + +func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlexec.RecordSet, error) { + r, ctx := tracing.StartRegionEx(ctx, "session.ExecuteStmt") + defer r.End() + + if err := s.PrepareTxnCtx(ctx); err != nil { + return nil, err + } + if err := s.loadCommonGlobalVariablesIfNeeded(); err != nil { + return nil, err + } + + sessVars := s.sessionVars + sessVars.StartTime = time.Now() + + // Some executions are done in compile stage, so we reset them before compile. + if err := executor.ResetContextOfStmt(s, stmtNode); err != nil { + return nil, err + } + if execStmt, ok := stmtNode.(*ast.ExecuteStmt); ok { + if binParam, ok := execStmt.BinaryArgs.([]param.BinaryParam); ok { + args, err := param.ExecArgs(s.GetSessionVars().StmtCtx.TypeCtx(), binParam) + if err != nil { + return nil, err + } + execStmt.BinaryArgs = args + } + } + + normalizedSQL, digest := s.sessionVars.StmtCtx.SQLDigest() + cmdByte := byte(atomic.LoadUint32(&s.GetSessionVars().CommandValue)) + if topsqlstate.TopSQLEnabled() { + s.sessionVars.StmtCtx.IsSQLRegistered.Store(true) + ctx = topsql.AttachAndRegisterSQLInfo(ctx, normalizedSQL, digest, s.sessionVars.InRestrictedSQL) + } + if sessVars.InPlanReplayer { + sessVars.StmtCtx.EnableOptimizerDebugTrace = true + } else if dom := domain.GetDomain(s); dom != nil && !sessVars.InRestrictedSQL { + // This is the earliest place we can get the SQL digest for this execution. + // If we find this digest is registered for PLAN REPLAYER CAPTURE, we need to enable optimizer debug trace no matter + // the plan digest will be matched or not. 
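+ // A single SQL-digest match against the registered capture tasks is enough
+ // to switch optimizer debug tracing on for this statement.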
+ if planReplayerHandle := dom.GetPlanReplayerHandle(); planReplayerHandle != nil { + tasks := planReplayerHandle.GetTasks() + for _, task := range tasks { + if task.SQLDigest == digest.String() { + sessVars.StmtCtx.EnableOptimizerDebugTrace = true + } + } + } + } + if sessVars.StmtCtx.EnableOptimizerDebugTrace { + plannercore.DebugTraceReceivedCommand(s.pctx, cmdByte, stmtNode) + } + + if err := s.validateStatementInTxn(stmtNode); err != nil { + return nil, err + } + + if err := s.validateStatementReadOnlyInStaleness(stmtNode); err != nil { + return nil, err + } + + // Uncorrelated subqueries will execute once when building plan, so we reset process info before building plan. + s.currentPlan = nil // reset current plan + s.SetProcessInfo(stmtNode.Text(), time.Now(), cmdByte, 0) + s.txn.onStmtStart(digest.String()) + defer sessiontxn.GetTxnManager(s).OnStmtEnd() + defer s.txn.onStmtEnd() + + if err := s.onTxnManagerStmtStartOrRetry(ctx, stmtNode); err != nil { + return nil, err + } + + failpoint.Inject("mockStmtSlow", func(val failpoint.Value) { + if strings.Contains(stmtNode.Text(), "/* sleep */") { + v, _ := val.(int) + time.Sleep(time.Duration(v) * time.Millisecond) + } + }) + + var stmtLabel string + if execStmt, ok := stmtNode.(*ast.ExecuteStmt); ok { + prepareStmt, err := plannercore.GetPreparedStmt(execStmt, s.sessionVars) + if err == nil && prepareStmt.PreparedAst != nil { + stmtLabel = ast.GetStmtLabel(prepareStmt.PreparedAst.Stmt) + } + } + if stmtLabel == "" { + stmtLabel = ast.GetStmtLabel(stmtNode) + } + s.setRequestSource(ctx, stmtLabel, stmtNode) + + // Transform abstract syntax tree to a physical plan(stored in executor.ExecStmt). + compiler := executor.Compiler{Ctx: s} + stmt, err := compiler.Compile(ctx, stmtNode) + // check if resource group hint is valid, can't do this in planner.Optimize because we can access + // infoschema there. + if sessVars.StmtCtx.ResourceGroupName != sessVars.ResourceGroupName { + // if target resource group doesn't exist, fallback to the origin resource group. + if _, ok := domain.GetDomain(s).InfoSchema().ResourceGroupByName(model.NewCIStr(sessVars.StmtCtx.ResourceGroupName)); !ok { + logutil.Logger(ctx).Warn("Unknown resource group from hint", zap.String("name", sessVars.StmtCtx.ResourceGroupName)) + sessVars.StmtCtx.ResourceGroupName = sessVars.ResourceGroupName + if txn, err := s.Txn(false); err == nil && txn != nil && txn.Valid() { + kv.SetTxnResourceGroup(txn, sessVars.ResourceGroupName) + } + } + } + if err != nil { + s.rollbackOnError(ctx) + + // Only print log message when this SQL is from the user. + // Mute the warning for internal SQLs. + if !s.sessionVars.InRestrictedSQL { + if !variable.ErrUnknownSystemVar.Equal(err) { + sql := stmtNode.Text() + sql = parser.Normalize(sql, s.sessionVars.EnableRedactLog) + logutil.Logger(ctx).Warn("compile SQL failed", zap.Error(err), + zap.String("SQL", sql)) + } + } + return nil, err + } + + durCompile := time.Since(s.sessionVars.StartTime) + s.GetSessionVars().DurationCompile = durCompile + if s.isInternal() { + session_metrics.SessionExecuteCompileDurationInternal.Observe(durCompile.Seconds()) + } else { + session_metrics.SessionExecuteCompileDurationGeneral.Observe(durCompile.Seconds()) + } + s.currentPlan = stmt.Plan + if execStmt, ok := stmtNode.(*ast.ExecuteStmt); ok { + if execStmt.Name == "" { + // for exec-stmt on bin-protocol, ignore the plan detail in `show process` to gain performance benefits. + s.currentPlan = nil + } + } + + // Execute the physical plan. 
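+ // Point-get plans take the PsStmt short path and invalidate the lazy txn
+ // handle directly; all other plans go through runStmt.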
+ logStmt(stmt, s) + + var recordSet sqlexec.RecordSet + if stmt.PsStmt != nil { // point plan short path + recordSet, err = stmt.PointGet(ctx) + s.txn.changeToInvalid() + } else { + recordSet, err = runStmt(ctx, s, stmt) + } + + // Observe the resource group query total counter if the resource control is enabled and the + // current session is attached with a resource group. + resourceGroupName := s.GetSessionVars().StmtCtx.ResourceGroupName + if len(resourceGroupName) > 0 { + metrics.ResourceGroupQueryTotalCounter.WithLabelValues(resourceGroupName, resourceGroupName).Inc() + } + + if err != nil { + if !errIsNoisy(err) { + logutil.Logger(ctx).Warn("run statement failed", + zap.Int64("schemaVersion", s.GetInfoSchema().SchemaMetaVersion()), + zap.Error(err), + zap.String("session", s.String())) + } + return recordSet, err + } + return recordSet, nil +} + +func (s *session) GetSQLExecutor() sqlexec.SQLExecutor { + return s +} + +func (s *session) GetRestrictedSQLExecutor() sqlexec.RestrictedSQLExecutor { + return s +} + +func (s *session) onTxnManagerStmtStartOrRetry(ctx context.Context, node ast.StmtNode) error { + if s.sessionVars.RetryInfo.Retrying { + return sessiontxn.GetTxnManager(s).OnStmtRetry(ctx) + } + return sessiontxn.GetTxnManager(s).OnStmtStart(ctx, node) +} + +func (s *session) validateStatementInTxn(stmtNode ast.StmtNode) error { + vars := s.GetSessionVars() + if _, ok := stmtNode.(*ast.ImportIntoStmt); ok && vars.InTxn() { + return errors.New("cannot run IMPORT INTO in explicit transaction") + } + return nil +} + +func (s *session) validateStatementReadOnlyInStaleness(stmtNode ast.StmtNode) error { + vars := s.GetSessionVars() + if !vars.TxnCtx.IsStaleness && vars.TxnReadTS.PeakTxnReadTS() == 0 && !vars.EnableExternalTSRead || vars.InRestrictedSQL { + return nil + } + errMsg := "only support read-only statement during read-only staleness transactions" + node := stmtNode.(ast.Node) + switch v := node.(type) { + case *ast.SplitRegionStmt: + return nil + case *ast.SelectStmt: + // select lock statement needs start a transaction which will be conflict to stale read, + // we forbid select lock statement in stale read for now. + if v.LockInfo != nil { + return errors.New("select lock hasn't been supported in stale read yet") + } + if !planner.IsReadOnly(stmtNode, vars) { + return errors.New(errMsg) + } + return nil + case *ast.ExplainStmt, *ast.DoStmt, *ast.ShowStmt, *ast.SetOprStmt, *ast.ExecuteStmt, *ast.SetOprSelectList: + if !planner.IsReadOnly(stmtNode, vars) { + return errors.New(errMsg) + } + return nil + default: + } + // covered DeleteStmt/InsertStmt/UpdateStmt/CallStmt/LoadDataStmt + if _, ok := stmtNode.(ast.DMLNode); ok { + return errors.New(errMsg) + } + return nil +} + +// fileTransInConnKeys contains the keys of queries that will be handled by handleFileTransInConn. +var fileTransInConnKeys = []fmt.Stringer{ + executor.LoadDataVarKey, + executor.LoadStatsVarKey, + executor.IndexAdviseVarKey, + executor.PlanReplayerLoadVarKey, +} + +func (s *session) hasFileTransInConn() bool { + s.mu.RLock() + defer s.mu.RUnlock() + + for _, k := range fileTransInConnKeys { + v := s.mu.values[k] + if v != nil { + return true + } + } + return false +} + +// runStmt executes the sqlexec.Statement and commit or rollback the current transaction. 
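+// When a record set is produced it is wrapped in an execStmtResult, so the
+// transaction is only finished when the caller finishes or closes the result;
+// otherwise finishStmt is invoked before returning.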
+func runStmt(ctx context.Context, se *session, s sqlexec.Statement) (rs sqlexec.RecordSet, err error) { + failpoint.Inject("assertTxnManagerInRunStmt", func() { + sessiontxn.RecordAssert(se, "assertTxnManagerInRunStmt", true) + if stmt, ok := s.(*executor.ExecStmt); ok { + sessiontxn.AssertTxnManagerInfoSchema(se, stmt.InfoSchema) + } + }) + + r, ctx := tracing.StartRegionEx(ctx, "session.runStmt") + defer r.End() + if r.Span != nil { + r.Span.LogKV("sql", s.OriginText()) + } + + se.SetValue(sessionctx.QueryString, s.OriginText()) + if _, ok := s.(*executor.ExecStmt).StmtNode.(ast.DDLNode); ok { + se.SetValue(sessionctx.LastExecuteDDL, true) + } else { + se.ClearValue(sessionctx.LastExecuteDDL) + } + + sessVars := se.sessionVars + + // Save origTxnCtx here to avoid it reset in the transaction retry. + origTxnCtx := sessVars.TxnCtx + err = se.checkTxnAborted(s) + if err != nil { + return nil, err + } + if sessVars.TxnCtx.CouldRetry && !s.IsReadOnly(sessVars) { + // Only when the txn is could retry and the statement is not read only, need to do stmt-count-limit check, + // otherwise, the stmt won't be add into stmt history, and also don't need check. + // About `stmt-count-limit`, see more in https://docs.pingcap.com/tidb/stable/tidb-configuration-file#stmt-count-limit + if err := checkStmtLimit(ctx, se, false); err != nil { + return nil, err + } + } + + rs, err = s.Exec(ctx) + + if se.txn.Valid() && se.txn.IsPipelined() { + // Pipelined-DMLs can return assertion errors and write conflicts here because they flush + // during execution, handle these errors as we would handle errors after a commit. + if err != nil { + err = se.handleAssertionFailure(ctx, err) + } + newErr := se.tryReplaceWriteConflictError(err) + if newErr != nil { + err = newErr + } + } + + sessVars.TxnCtx.StatementCount++ + if rs != nil { + if se.GetSessionVars().StmtCtx.IsExplainAnalyzeDML { + if !sessVars.InTxn() { + se.StmtCommit(ctx) + if err := se.CommitTxn(ctx); err != nil { + return nil, err + } + } + } + return &execStmtResult{ + RecordSet: rs, + sql: s, + se: se, + }, err + } + + err = finishStmt(ctx, se, err, s) + if se.hasFileTransInConn() { + // The query will be handled later in handleFileTransInConn, + // then should call the ExecStmt.FinishExecuteStmt to finish this statement. + se.SetValue(ExecStmtVarKey, s.(*executor.ExecStmt)) + } else { + // If it is not a select statement or special query, we record its slow log here, + // then it could include the transaction commit time. + s.(*executor.ExecStmt).FinishExecuteStmt(origTxnCtx.StartTS, err, false) + } + return nil, err +} + +// ExecStmtVarKeyType is a dummy type to avoid naming collision in context. +type ExecStmtVarKeyType int + +// String defines a Stringer function for debugging and pretty printing. +func (ExecStmtVarKeyType) String() string { + return "exec_stmt_var_key" +} + +// ExecStmtVarKey is a variable key for ExecStmt. +const ExecStmtVarKey ExecStmtVarKeyType = 0 + +// execStmtResult is the return value of ExecuteStmt and it implements the sqlexec.RecordSet interface. +// Why we need a struct to wrap a RecordSet and provide another RecordSet? +// This is because there are so many session state related things that definitely not belongs to the original +// RecordSet, so this struct exists and RecordSet.Close() is overridden to handle that. 
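+// Finish runs finishStmt at most once (guarded by sync.Once), Close delegates
+// to Finish before closing the wrapped RecordSet, and TryDetach hands the
+// record set over to a cursor for read-only auto-commit queries.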
+type execStmtResult struct { + sqlexec.RecordSet + se *session + sql sqlexec.Statement + once sync.Once + closed bool +} + +func (rs *execStmtResult) Finish() error { + var err error + rs.once.Do(func() { + var err1 error + if f, ok := rs.RecordSet.(interface{ Finish() error }); ok { + err1 = f.Finish() + } + err2 := finishStmt(context.Background(), rs.se, err, rs.sql) + if err1 != nil { + err = err1 + } else { + err = err2 + } + }) + return err +} + +func (rs *execStmtResult) Close() error { + if rs.closed { + return nil + } + err1 := rs.Finish() + err2 := rs.RecordSet.Close() + rs.closed = true + if err1 != nil { + return err1 + } + return err2 +} + +func (rs *execStmtResult) TryDetach() (sqlexec.RecordSet, bool, error) { + if !rs.sql.IsReadOnly(rs.se.GetSessionVars()) { + return nil, false, nil + } + if !plannercore.IsAutoCommitTxn(rs.se.GetSessionVars()) { + return nil, false, nil + } + + drs, ok := rs.RecordSet.(sqlexec.DetachableRecordSet) + if !ok { + return nil, false, nil + } + detachedRS, ok, err := drs.TryDetach() + if !ok || err != nil { + return nil, ok, err + } + cursorHandle := rs.se.GetCursorTracker().NewCursor( + cursor.State{StartTS: rs.se.GetSessionVars().TxnCtx.StartTS}, + ) + crs := staticrecordset.WrapRecordSetWithCursor(cursorHandle, detachedRS) + + // Now, a transaction is not needed for the detached record set, so we commit the transaction and cleanup + // the session state. + err = finishStmt(context.Background(), rs.se, nil, rs.sql) + if err != nil { + err2 := detachedRS.Close() + if err2 != nil { + logutil.BgLogger().Error("close detached record set failed", zap.Error(err2)) + } + return nil, true, err + } + + return crs, true, nil +} + +// GetExecutor4Test exports the internal executor for test purpose. +func (rs *execStmtResult) GetExecutor4Test() any { + return rs.RecordSet.(interface{ GetExecutor4Test() any }).GetExecutor4Test() +} + +// rollbackOnError makes sure the next statement starts a new transaction with the latest InfoSchema. +func (s *session) rollbackOnError(ctx context.Context) { + if !s.sessionVars.InTxn() { + s.RollbackTxn(ctx) + } +} + +// PrepareStmt is used for executing prepare statement in binary protocol +func (s *session) PrepareStmt(sql string) (stmtID uint32, paramCount int, fields []*ast.ResultField, err error) { + defer func() { + if s.sessionVars.StmtCtx != nil { + s.sessionVars.StmtCtx.DetachMemDiskTracker() + } + }() + if s.sessionVars.TxnCtx.InfoSchema == nil { + // We don't need to create a transaction for prepare statement, just get information schema will do. + s.sessionVars.TxnCtx.InfoSchema = domain.GetDomain(s).InfoSchema() + } + err = s.loadCommonGlobalVariablesIfNeeded() + if err != nil { + return + } + + ctx := context.Background() + // NewPrepareExec may need startTS to build the executor, for example prepare statement has subquery in int. + // So we have to call PrepareTxnCtx here. + if err = s.PrepareTxnCtx(ctx); err != nil { + return + } + + prepareStmt := &ast.PrepareStmt{SQLText: sql} + if err = s.onTxnManagerStmtStartOrRetry(ctx, prepareStmt); err != nil { + return + } + + if err = sessiontxn.GetTxnManager(s).AdviseWarmup(); err != nil { + return + } + prepareExec := executor.NewPrepareExec(s, sql) + err = prepareExec.Next(ctx, nil) + // Rollback even if err is nil. + s.rollbackOnError(ctx) + + if err != nil { + return + } + return prepareExec.ID, prepareExec.ParamCount, prepareExec.Fields, nil +} + +// ExecutePreparedStmt executes a prepared statement. 
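+// It resolves the plan-cache statement by ID and re-enters ExecuteStmt with an
+// ast.ExecuteStmt carrying the binary arguments. An illustrative (not
+// exhaustive) binary-protocol flow looks like:
+//
+//	stmtID, _, _, err := se.PrepareStmt("SELECT a FROM t WHERE id = ?")
+//	rs, err := se.ExecutePreparedStmt(ctx, stmtID, params)
+//	err = se.DropPreparedStmt(stmtID)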
+func (s *session) ExecutePreparedStmt(ctx context.Context, stmtID uint32, params []expression.Expression) (sqlexec.RecordSet, error) { + prepStmt, err := s.sessionVars.GetPreparedStmtByID(stmtID) + if err != nil { + err = plannererrors.ErrStmtNotFound + logutil.Logger(ctx).Error("prepared statement not found", zap.Uint32("stmtID", stmtID)) + return nil, err + } + stmt, ok := prepStmt.(*plannercore.PlanCacheStmt) + if !ok { + return nil, errors.Errorf("invalid PlanCacheStmt type") + } + execStmt := &ast.ExecuteStmt{ + BinaryArgs: params, + PrepStmt: stmt, + } + return s.ExecuteStmt(ctx, execStmt) +} + +func (s *session) DropPreparedStmt(stmtID uint32) error { + vars := s.sessionVars + if _, ok := vars.PreparedStmts[stmtID]; !ok { + return plannererrors.ErrStmtNotFound + } + vars.RetryInfo.DroppedPreparedStmtIDs = append(vars.RetryInfo.DroppedPreparedStmtIDs, stmtID) + return nil +} + +func (s *session) Txn(active bool) (kv.Transaction, error) { + if !active { + return &s.txn, nil + } + _, err := sessiontxn.GetTxnManager(s).ActivateTxn() + s.SetMemoryFootprintChangeHook() + return &s.txn, err +} + +func (s *session) SetValue(key fmt.Stringer, value any) { + s.mu.Lock() + s.mu.values[key] = value + s.mu.Unlock() +} + +func (s *session) Value(key fmt.Stringer) any { + s.mu.RLock() + value := s.mu.values[key] + s.mu.RUnlock() + return value +} + +func (s *session) ClearValue(key fmt.Stringer) { + s.mu.Lock() + delete(s.mu.values, key) + s.mu.Unlock() +} + +type inCloseSession struct{} + +// Close function does some clean work when session end. +// Close should release the table locks which hold by the session. +func (s *session) Close() { + // TODO: do clean table locks when session exited without execute Close. + // TODO: do clean table locks when tidb-server was `kill -9`. + if s.HasLockedTables() && config.TableLockEnabled() { + if ds := config.TableLockDelayClean(); ds > 0 { + time.Sleep(time.Duration(ds) * time.Millisecond) + } + lockedTables := s.GetAllTableLocks() + err := domain.GetDomain(s).DDLExecutor().UnlockTables(s, lockedTables) + if err != nil { + logutil.BgLogger().Error("release table lock failed", zap.Uint64("conn", s.sessionVars.ConnectionID)) + } + } + s.ReleaseAllAdvisoryLocks() + if s.statsCollector != nil { + s.statsCollector.Delete() + } + if s.idxUsageCollector != nil { + s.idxUsageCollector.Flush() + } + bindValue := s.Value(bindinfo.SessionBindInfoKeyType) + if bindValue != nil { + bindValue.(bindinfo.SessionBindingHandle).Close() + } + ctx := context.WithValue(context.TODO(), inCloseSession{}, struct{}{}) + s.RollbackTxn(ctx) + if s.sessionVars != nil { + s.sessionVars.WithdrawAllPreparedStmt() + } + if s.stmtStats != nil { + s.stmtStats.SetFinished() + } + s.sessionVars.ClearDiskFullOpt() + if s.sessionPlanCache != nil { + s.sessionPlanCache.Close() + } +} + +// GetSessionVars implements the context.Context interface. +func (s *session) GetSessionVars() *variable.SessionVars { + return s.sessionVars +} + +// GetPlanCtx returns the PlanContext. +func (s *session) GetPlanCtx() planctx.PlanContext { + return s.pctx +} + +// GetExprCtx returns the expression context of the session. 
+func (s *session) GetExprCtx() exprctx.ExprContext { + return s.exprctx +} + +// GetTableCtx returns the table.MutateContext +func (s *session) GetTableCtx() tbctx.MutateContext { + return s.tblctx +} + +// GetDistSQLCtx returns the context used in DistSQL +func (s *session) GetDistSQLCtx() *distsqlctx.DistSQLContext { + vars := s.GetSessionVars() + sc := vars.StmtCtx + + return sc.GetOrInitDistSQLFromCache(func() *distsqlctx.DistSQLContext { + return &distsqlctx.DistSQLContext{ + WarnHandler: sc.WarnHandler, + InRestrictedSQL: sc.InRestrictedSQL, + Client: s.GetClient(), + + EnabledRateLimitAction: vars.EnabledRateLimitAction, + EnableChunkRPC: vars.EnableChunkRPC, + OriginalSQL: sc.OriginalSQL, + KVVars: vars.KVVars, + KvExecCounter: sc.KvExecCounter, + SessionMemTracker: vars.MemTracker, + + Location: sc.TimeZone(), + RuntimeStatsColl: sc.RuntimeStatsColl, + SQLKiller: &vars.SQLKiller, + ErrCtx: sc.ErrCtx(), + + TiFlashReplicaRead: vars.TiFlashReplicaRead, + TiFlashMaxThreads: vars.TiFlashMaxThreads, + TiFlashMaxBytesBeforeExternalJoin: vars.TiFlashMaxBytesBeforeExternalJoin, + TiFlashMaxBytesBeforeExternalGroupBy: vars.TiFlashMaxBytesBeforeExternalGroupBy, + TiFlashMaxBytesBeforeExternalSort: vars.TiFlashMaxBytesBeforeExternalSort, + TiFlashMaxQueryMemoryPerNode: vars.TiFlashMaxQueryMemoryPerNode, + TiFlashQuerySpillRatio: vars.TiFlashQuerySpillRatio, + + DistSQLConcurrency: vars.DistSQLScanConcurrency(), + ReplicaReadType: vars.GetReplicaRead(), + WeakConsistency: sc.WeakConsistency, + RCCheckTS: sc.RCCheckTS, + NotFillCache: sc.NotFillCache, + TaskID: sc.TaskID, + Priority: sc.Priority, + ResourceGroupTagger: sc.GetResourceGroupTagger(), + EnablePaging: vars.EnablePaging, + MinPagingSize: vars.MinPagingSize, + MaxPagingSize: vars.MaxPagingSize, + RequestSourceType: vars.RequestSourceType, + ExplicitRequestSourceType: vars.ExplicitRequestSourceType, + StoreBatchSize: vars.StoreBatchSize, + ResourceGroupName: sc.ResourceGroupName, + LoadBasedReplicaReadThreshold: vars.LoadBasedReplicaReadThreshold, + RunawayChecker: sc.RunawayChecker, + TiKVClientReadTimeout: vars.GetTiKVClientReadTimeout(), + + ReplicaClosestReadThreshold: vars.ReplicaClosestReadThreshold, + ConnectionID: vars.ConnectionID, + SessionAlias: vars.SessionAlias, + + ExecDetails: &sc.SyncExecDetails, + } + }) +} + +// GetRangerCtx returns the context used in `ranger` related functions +func (s *session) GetRangerCtx() *rangerctx.RangerContext { + vars := s.GetSessionVars() + sc := vars.StmtCtx + + rctx := sc.GetOrInitRangerCtxFromCache(func() any { + return &rangerctx.RangerContext{ + ExprCtx: s.GetExprCtx(), + TypeCtx: s.GetSessionVars().StmtCtx.TypeCtx(), + ErrCtx: s.GetSessionVars().StmtCtx.ErrCtx(), + + InPreparedPlanBuilding: s.GetSessionVars().StmtCtx.InPreparedPlanBuilding, + RegardNULLAsPoint: s.GetSessionVars().RegardNULLAsPoint, + OptPrefixIndexSingleScan: s.GetSessionVars().OptPrefixIndexSingleScan, + OptimizerFixControl: s.GetSessionVars().OptimizerFixControl, + + PlanCacheTracker: &s.GetSessionVars().StmtCtx.PlanCacheTracker, + RangeFallbackHandler: &s.GetSessionVars().StmtCtx.RangeFallbackHandler, + } + }) + + return rctx.(*rangerctx.RangerContext) +} + +// GetBuildPBCtx returns the context used in `ToPB` method +func (s *session) GetBuildPBCtx() *planctx.BuildPBContext { + vars := s.GetSessionVars() + sc := vars.StmtCtx + + bctx := sc.GetOrInitBuildPBCtxFromCache(func() any { + return &planctx.BuildPBContext{ + ExprCtx: s.GetExprCtx(), + Client: s.GetClient(), + + TiFlashFastScan: 
s.GetSessionVars().TiFlashFastScan, + TiFlashFineGrainedShuffleBatchSize: s.GetSessionVars().TiFlashFineGrainedShuffleBatchSize, + + // the following fields are used to build `expression.PushDownContext`. + // TODO: it'd be better to embed `expression.PushDownContext` in `BuildPBContext`. But `expression` already + // depends on this package, so we need to move `expression.PushDownContext` to a standalone package first. + GroupConcatMaxLen: s.GetSessionVars().GroupConcatMaxLen, + InExplainStmt: s.GetSessionVars().StmtCtx.InExplainStmt, + WarnHandler: s.GetSessionVars().StmtCtx.WarnHandler, + ExtraWarnghandler: s.GetSessionVars().StmtCtx.ExtraWarnHandler, + } + }) + + return bctx.(*planctx.BuildPBContext) +} + +func (s *session) AuthPluginForUser(user *auth.UserIdentity) (string, error) { + pm := privilege.GetPrivilegeManager(s) + authplugin, err := pm.GetAuthPluginForConnection(user.Username, user.Hostname) + if err != nil { + return "", err + } + return authplugin, nil +} + +// Auth validates a user using an authentication string and salt. +// If the password fails, it will keep trying other users until exhausted. +// This means it can not be refactored to use MatchIdentity yet. +func (s *session) Auth(user *auth.UserIdentity, authentication, salt []byte, authConn conn.AuthConn) error { + hasPassword := "YES" + if len(authentication) == 0 { + hasPassword = "NO" + } + pm := privilege.GetPrivilegeManager(s) + authUser, err := s.MatchIdentity(user.Username, user.Hostname) + if err != nil { + return privileges.ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword) + } + // Check whether continuous login failure is enabled to lock the account. + // If enabled, determine whether to unlock the account and notify TiDB to update the cache. + enableAutoLock := pm.IsAccountAutoLockEnabled(authUser.Username, authUser.Hostname) + if enableAutoLock { + err = failedLoginTrackingBegin(s) + if err != nil { + return err + } + lockStatusChanged, err := verifyAccountAutoLock(s, authUser.Username, authUser.Hostname) + if err != nil { + rollbackErr := failedLoginTrackingRollback(s) + if rollbackErr != nil { + return rollbackErr + } + return err + } + err = failedLoginTrackingCommit(s) + if err != nil { + rollbackErr := failedLoginTrackingRollback(s) + if rollbackErr != nil { + return rollbackErr + } + return err + } + if lockStatusChanged { + // Notification auto unlock. + err = domain.GetDomain(s).NotifyUpdatePrivilege() + if err != nil { + return err + } + } + } + + info, err := pm.ConnectionVerification(user, authUser.Username, authUser.Hostname, authentication, salt, s.sessionVars, authConn) + if err != nil { + if info.FailedDueToWrongPassword { + // when user enables the account locking function for consecutive login failures, + // the system updates the login failure count and determines whether to lock the account when authentication fails. + if enableAutoLock { + err := failedLoginTrackingBegin(s) + if err != nil { + return err + } + lockStatusChanged, passwordLocking, trackingErr := authFailedTracking(s, authUser.Username, authUser.Hostname) + if trackingErr != nil { + if rollBackErr := failedLoginTrackingRollback(s); rollBackErr != nil { + return rollBackErr + } + return trackingErr + } + if err := failedLoginTrackingCommit(s); err != nil { + if rollBackErr := failedLoginTrackingRollback(s); rollBackErr != nil { + return rollBackErr + } + return err + } + if lockStatusChanged { + // Notification auto lock. 
+ err := autolockAction(s, passwordLocking, authUser.Username, authUser.Hostname) + if err != nil { + return err + } + } + } + } + return err + } + + if variable.EnableResourceControl.Load() && info.ResourceGroupName != "" { + s.sessionVars.SetResourceGroupName(info.ResourceGroupName) + } + + if info.InSandBoxMode { + // Enter sandbox mode, only execute statement for resetting password. + s.EnableSandBoxMode() + } + if enableAutoLock { + err := failedLoginTrackingBegin(s) + if err != nil { + return err + } + // The password is correct. If the account is not locked, the number of login failure statistics will be cleared. + err = authSuccessClearCount(s, authUser.Username, authUser.Hostname) + if err != nil { + if rollBackErr := failedLoginTrackingRollback(s); rollBackErr != nil { + return rollBackErr + } + return err + } + err = failedLoginTrackingCommit(s) + if err != nil { + if rollBackErr := failedLoginTrackingRollback(s); rollBackErr != nil { + return rollBackErr + } + return err + } + } + pm.AuthSuccess(authUser.Username, authUser.Hostname) + user.AuthUsername = authUser.Username + user.AuthHostname = authUser.Hostname + s.sessionVars.User = user + s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname) + return nil +} + +func authSuccessClearCount(s *session, user string, host string) error { + // Obtain accurate lock status and failure count information. + passwordLocking, err := getFailedLoginUserAttributes(s, user, host) + if err != nil { + return err + } + // If the account is locked, it may be caused by the untimely update of the cache, + // directly report the account lock. + if passwordLocking.AutoAccountLocked { + if passwordLocking.PasswordLockTimeDays == -1 { + return privileges.GenerateAccountAutoLockErr(passwordLocking.FailedLoginAttempts, user, host, + "unlimited", "unlimited") + } + + lds := strconv.FormatInt(passwordLocking.PasswordLockTimeDays, 10) + return privileges.GenerateAccountAutoLockErr(passwordLocking.FailedLoginAttempts, user, host, lds, lds) + } + if passwordLocking.FailedLoginCount != 0 { + // If the number of account login failures is not zero, it will be updated to 0. + passwordLockingJSON := privileges.BuildSuccessPasswordLockingJSON(passwordLocking.FailedLoginAttempts, + passwordLocking.PasswordLockTimeDays) + if passwordLockingJSON != "" { + if err := s.passwordLocking(user, host, passwordLockingJSON); err != nil { + return err + } + } + } + return nil +} + +func verifyAccountAutoLock(s *session, user, host string) (bool, error) { + pm := privilege.GetPrivilegeManager(s) + // Use the cache to determine whether to unlock the account. + // If the account needs to be unlocked, read the database information to determine whether + // the account needs to be unlocked. Otherwise, an error message is displayed. + lockStatusInMemory, err := pm.VerifyAccountAutoLockInMemory(user, host) + if err != nil { + return false, err + } + // If the lock status in the cache is Unlock, the automatic unlock is skipped. + // If memory synchronization is slow and there is a lock in the database, it will be processed upon successful login. + if !lockStatusInMemory { + return false, nil + } + lockStatusChanged := false + var plJSON string + // After checking the cache, obtain the latest data from the database and determine + // whether to automatically unlock the database to prevent repeated unlock errors. 
+ pl, err := getFailedLoginUserAttributes(s, user, host) + if err != nil { + return false, err + } + if pl.AutoAccountLocked { + // If it is locked, need to check whether it can be automatically unlocked. + lockTimeDay := pl.PasswordLockTimeDays + if lockTimeDay == -1 { + return false, privileges.GenerateAccountAutoLockErr(pl.FailedLoginAttempts, user, host, "unlimited", "unlimited") + } + lastChanged := pl.AutoLockedLastChanged + d := time.Now().Unix() - lastChanged + if d <= lockTimeDay*24*60*60 { + lds := strconv.FormatInt(lockTimeDay, 10) + rds := strconv.FormatInt(int64(math.Ceil(float64(lockTimeDay)-float64(d)/(24*60*60))), 10) + return false, privileges.GenerateAccountAutoLockErr(pl.FailedLoginAttempts, user, host, lds, rds) + } + // Generate unlock json string. + plJSON = privileges.BuildPasswordLockingJSON(pl.FailedLoginAttempts, + pl.PasswordLockTimeDays, "N", 0, time.Now().Format(time.UnixDate)) + } + if plJSON != "" { + lockStatusChanged = true + if err = s.passwordLocking(user, host, plJSON); err != nil { + return false, err + } + } + return lockStatusChanged, nil +} + +func authFailedTracking(s *session, user string, host string) (bool, *privileges.PasswordLocking, error) { + // Obtain the number of consecutive password login failures. + passwordLocking, err := getFailedLoginUserAttributes(s, user, host) + if err != nil { + return false, nil, err + } + // Consecutive wrong password login failure times +1, + // If the lock condition is satisfied, the lock status is updated and the update cache is notified. + lockStatusChanged, err := userAutoAccountLocked(s, user, host, passwordLocking) + if err != nil { + return false, nil, err + } + return lockStatusChanged, passwordLocking, nil +} + +func autolockAction(s *session, passwordLocking *privileges.PasswordLocking, user, host string) error { + // Don't want to update the cache frequently, and only trigger the update cache when the lock status is updated. + err := domain.GetDomain(s).NotifyUpdatePrivilege() + if err != nil { + return err + } + // The number of failed login attempts reaches FAILED_LOGIN_ATTEMPTS. + // An error message is displayed indicating permission denial and account lock. + if passwordLocking.PasswordLockTimeDays == -1 { + return privileges.GenerateAccountAutoLockErr(passwordLocking.FailedLoginAttempts, user, host, + "unlimited", "unlimited") + } + lds := strconv.FormatInt(passwordLocking.PasswordLockTimeDays, 10) + return privileges.GenerateAccountAutoLockErr(passwordLocking.FailedLoginAttempts, user, host, lds, lds) +} + +func (s *session) passwordLocking(user string, host string, newAttributesStr string) error { + sql := new(strings.Builder) + sqlescape.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.UserTable) + sqlescape.MustFormatSQL(sql, "user_attributes=json_merge_patch(coalesce(user_attributes, '{}'), %?)", newAttributesStr) + sqlescape.MustFormatSQL(sql, " WHERE Host=%? 
and User=%?;", host, user) + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnPrivilege) + _, err := s.ExecuteInternal(ctx, sql.String()) + return err +} + +func failedLoginTrackingBegin(s *session) error { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnPrivilege) + _, err := s.ExecuteInternal(ctx, "BEGIN PESSIMISTIC") + return err +} + +func failedLoginTrackingCommit(s *session) error { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnPrivilege) + _, err := s.ExecuteInternal(ctx, "COMMIT") + if err != nil { + _, rollBackErr := s.ExecuteInternal(ctx, "ROLLBACK") + if rollBackErr != nil { + return rollBackErr + } + } + return err +} + +func failedLoginTrackingRollback(s *session) error { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnPrivilege) + _, err := s.ExecuteInternal(ctx, "ROLLBACK") + return err +} + +// getFailedLoginUserAttributes queries the exact number of consecutive password login failures (concurrency is not allowed). +func getFailedLoginUserAttributes(s *session, user string, host string) (*privileges.PasswordLocking, error) { + passwordLocking := &privileges.PasswordLocking{} + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnPrivilege) + rs, err := s.ExecuteInternal(ctx, `SELECT user_attributes from mysql.user WHERE USER = %? AND HOST = %? for update`, user, host) + if err != nil { + return passwordLocking, err + } + defer func() { + if closeErr := rs.Close(); closeErr != nil { + err = closeErr + } + }() + req := rs.NewChunk(nil) + iter := chunk.NewIterator4Chunk(req) + err = rs.Next(ctx, req) + if err != nil { + return passwordLocking, err + } + if req.NumRows() == 0 { + return passwordLocking, fmt.Errorf("user_attributes by `%s`@`%s` not found", user, host) + } + row := iter.Begin() + if !row.IsNull(0) { + passwordLockingJSON := row.GetJSON(0) + return passwordLocking, passwordLocking.ParseJSON(passwordLockingJSON) + } + return passwordLocking, fmt.Errorf("user_attributes by `%s`@`%s` not found", user, host) +} + +func userAutoAccountLocked(s *session, user string, host string, pl *privileges.PasswordLocking) (bool, error) { + // Indicates whether the user needs to update the lock status change. + lockStatusChanged := false + // The number of consecutive login failures is stored in the database. + // If the current login fails, one is added to the number of consecutive login failures + // stored in the database to determine whether the user needs to be locked and the number of update failures. + failedLoginCount := pl.FailedLoginCount + 1 + // If the cache is not updated, but it is already locked, it will report that the account is locked. 
+ if pl.AutoAccountLocked { + if pl.PasswordLockTimeDays == -1 { + return false, privileges.GenerateAccountAutoLockErr(pl.FailedLoginAttempts, user, host, + "unlimited", "unlimited") + } + lds := strconv.FormatInt(pl.PasswordLockTimeDays, 10) + return false, privileges.GenerateAccountAutoLockErr(pl.FailedLoginAttempts, user, host, lds, lds) + } + + autoAccountLocked := "N" + autoLockedLastChanged := "" + if pl.FailedLoginAttempts == 0 || pl.PasswordLockTimeDays == 0 { + return false, nil + } + + if failedLoginCount >= pl.FailedLoginAttempts { + autoLockedLastChanged = time.Now().Format(time.UnixDate) + autoAccountLocked = "Y" + lockStatusChanged = true + } + + newAttributesStr := privileges.BuildPasswordLockingJSON(pl.FailedLoginAttempts, + pl.PasswordLockTimeDays, autoAccountLocked, failedLoginCount, autoLockedLastChanged) + if newAttributesStr != "" { + return lockStatusChanged, s.passwordLocking(user, host, newAttributesStr) + } + return lockStatusChanged, nil +} + +// MatchIdentity finds the matching username + password in the MySQL privilege tables +// for a username + hostname, since MySQL can have wildcards. +func (s *session) MatchIdentity(username, remoteHost string) (*auth.UserIdentity, error) { + pm := privilege.GetPrivilegeManager(s) + var success bool + var skipNameResolve bool + var user = &auth.UserIdentity{} + varVal, err := s.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.SkipNameResolve) + if err == nil && variable.TiDBOptOn(varVal) { + skipNameResolve = true + } + user.Username, user.Hostname, success = pm.MatchIdentity(username, remoteHost, skipNameResolve) + if success { + return user, nil + } + // This error will not be returned to the user, access denied will be instead + return nil, fmt.Errorf("could not find matching user in MatchIdentity: %s, %s", username, remoteHost) +} + +// AuthWithoutVerification is required by the ResetConnection RPC +func (s *session) AuthWithoutVerification(user *auth.UserIdentity) bool { + pm := privilege.GetPrivilegeManager(s) + authUser, err := s.MatchIdentity(user.Username, user.Hostname) + if err != nil { + return false + } + if pm.GetAuthWithoutVerification(authUser.Username, authUser.Hostname) { + user.AuthUsername = authUser.Username + user.AuthHostname = authUser.Hostname + s.sessionVars.User = user + s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname) + return true + } + return false +} + +// SetSessionStatesHandler implements the Session.SetSessionStatesHandler interface. +func (s *session) SetSessionStatesHandler(stateType sessionstates.SessionStateType, handler sessionctx.SessionStatesHandler) { + s.sessionStatesHandlers[stateType] = handler +} + +// ReportUsageStats reports the usage stats +func (s *session) ReportUsageStats() { + if s.idxUsageCollector != nil { + s.idxUsageCollector.Report() + } +} + +// CreateSession4Test creates a new session environment for test. +func CreateSession4Test(store kv.Storage) (types.Session, error) { + se, err := CreateSession4TestWithOpt(store, nil) + if err == nil { + // Cover both chunk rpc encoding and default encoding. + // nolint:gosec + if rand.Intn(2) == 0 { + se.GetSessionVars().EnableChunkRPC = false + } else { + se.GetSessionVars().EnableChunkRPC = true + } + } + return se, err +} + +// Opt describes the option for creating session +type Opt struct { + PreparedPlanCache sessionctx.SessionPlanCache +} + +// CreateSession4TestWithOpt creates a new session environment for test. 
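// The locking decision in userAutoAccountLocked above reduces to a small
// predicate: tracking is off when FAILED_LOGIN_ATTEMPTS or PASSWORD_LOCK_TIME
// is zero; otherwise the account locks once the stored failure count plus the
// current failure reaches FAILED_LOGIN_ATTEMPTS. A condensed sketch of just
// that arithmetic; shouldAutoLock is an illustrative helper, not part of this
// patch.
package autolock

// shouldAutoLock mirrors the threshold check: priorFailures is the count
// already stored in mysql.user, and the current failed attempt adds one.
// A PASSWORD_LOCK_TIME of -1 (unbounded) still counts as enabled.
func shouldAutoLock(failedLoginAttempts, passwordLockTimeDays, priorFailures int64) bool {
	if failedLoginAttempts == 0 || passwordLockTimeDays == 0 {
		return false
	}
	return priorFailures+1 >= failedLoginAttempts
}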
+func CreateSession4TestWithOpt(store kv.Storage, opt *Opt) (types.Session, error) { + s, err := CreateSessionWithOpt(store, opt) + if err == nil { + // initialize session variables for test. + s.GetSessionVars().InitChunkSize = 2 + s.GetSessionVars().MaxChunkSize = 32 + s.GetSessionVars().MinPagingSize = variable.DefMinPagingSize + s.GetSessionVars().EnablePaging = variable.DefTiDBEnablePaging + s.GetSessionVars().StmtCtx.SetTimeZone(s.GetSessionVars().Location()) + err = s.GetSessionVars().SetSystemVarWithoutValidation(variable.CharacterSetConnection, "utf8mb4") + } + return s, err +} + +// CreateSession creates a new session environment. +func CreateSession(store kv.Storage) (types.Session, error) { + return CreateSessionWithOpt(store, nil) +} + +// CreateSessionWithOpt creates a new session environment with option. +// Use default option if opt is nil. +func CreateSessionWithOpt(store kv.Storage, opt *Opt) (types.Session, error) { + s, err := createSessionWithOpt(store, opt) + if err != nil { + return nil, err + } + + // Add auth here. + do, err := domap.Get(store) + if err != nil { + return nil, err + } + extensions, err := extension.GetExtensions() + if err != nil { + return nil, err + } + pm := privileges.NewUserPrivileges(do.PrivilegeHandle(), extensions) + privilege.BindPrivilegeManager(s, pm) + + // Add stats collector, and it will be freed by background stats worker + // which periodically updates stats using the collected data. + if do.StatsHandle() != nil && do.StatsUpdating() { + s.statsCollector = do.StatsHandle().NewSessionStatsItem().(*usage.SessionStatsItem) + if config.GetGlobalConfig().Instance.EnableCollectExecutionInfo.Load() { + s.idxUsageCollector = do.StatsHandle().NewSessionIndexUsageCollector() + } + } + + s.cursorTracker = cursor.NewTracker() + + return s, nil +} + +// loadCollationParameter loads collation parameter from mysql.tidb +func loadCollationParameter(ctx context.Context, se *session) (bool, error) { + para, err := se.getTableValue(ctx, mysql.TiDBTable, TidbNewCollationEnabled) + if err != nil { + return false, err + } + if para == varTrue { + return true, nil + } else if para == varFalse { + return false, nil + } + logutil.BgLogger().Warn( + "Unexpected value of 'new_collation_enabled' in 'mysql.tidb', use 'False' instead", + zap.String("value", para)) + return false, nil +} + +type tableBasicInfo struct { + SQL string + id int64 +} + +var ( + errResultIsEmpty = dbterror.ClassExecutor.NewStd(errno.ErrResultIsEmpty) + // DDLJobTables is a list of tables definitions used in concurrent DDL. + DDLJobTables = []tableBasicInfo{ + {ddl.JobTableSQL, ddl.JobTableID}, + {ddl.ReorgTableSQL, ddl.ReorgTableID}, + {ddl.HistoryTableSQL, ddl.HistoryTableID}, + } + // BackfillTables is a list of tables definitions used in dist reorg DDL. 
+ BackfillTables = []tableBasicInfo{ + {ddl.BackgroundSubtaskTableSQL, ddl.BackgroundSubtaskTableID}, + {ddl.BackgroundSubtaskHistoryTableSQL, ddl.BackgroundSubtaskHistoryTableID}, + } + mdlTable = `create table mysql.tidb_mdl_info( + job_id BIGINT NOT NULL PRIMARY KEY, + version BIGINT NOT NULL, + table_ids text(65535), + owner_id varchar(64) NOT NULL DEFAULT '' + );` +) + +func splitAndScatterTable(store kv.Storage, tableIDs []int64) { + if s, ok := store.(kv.SplittableStore); ok && atomic.LoadUint32(&ddl.EnableSplitTableRegion) == 1 { + ctxWithTimeout, cancel := context.WithTimeout(context.Background(), variable.DefWaitSplitRegionTimeout*time.Second) + var regionIDs []uint64 + for _, id := range tableIDs { + regionIDs = append(regionIDs, ddl.SplitRecordRegion(ctxWithTimeout, s, id, id, variable.DefTiDBScatterRegion)) + } + if variable.DefTiDBScatterRegion { + ddl.WaitScatterRegionFinish(ctxWithTimeout, s, regionIDs...) + } + cancel() + } +} + +// InitDDLJobTables is to create tidb_ddl_job, tidb_ddl_reorg and tidb_ddl_history, or tidb_background_subtask and tidb_background_subtask_history. +func InitDDLJobTables(store kv.Storage, targetVer meta.DDLTableVersion) error { + targetTables := DDLJobTables + if targetVer == meta.BackfillTableVersion { + targetTables = BackfillTables + } + return kv.RunInNewTxn(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), store, true, func(_ context.Context, txn kv.Transaction) error { + t := meta.NewMeta(txn) + tableVer, err := t.CheckDDLTableVersion() + if err != nil || tableVer >= targetVer { + return errors.Trace(err) + } + dbID, err := t.CreateMySQLDatabaseIfNotExists() + if err != nil { + return err + } + if err = createAndSplitTables(store, t, dbID, targetTables); err != nil { + return err + } + return t.SetDDLTables(targetVer) + }) +} + +func createAndSplitTables(store kv.Storage, t *meta.Meta, dbID int64, tables []tableBasicInfo) error { + tableIDs := make([]int64, 0, len(tables)) + for _, tbl := range tables { + tableIDs = append(tableIDs, tbl.id) + } + splitAndScatterTable(store, tableIDs) + p := parser.New() + for _, tbl := range tables { + stmt, err := p.ParseOneStmt(tbl.SQL, "", "") + if err != nil { + return errors.Trace(err) + } + tblInfo, err := ddl.BuildTableInfoFromAST(stmt.(*ast.CreateTableStmt)) + if err != nil { + return errors.Trace(err) + } + tblInfo.State = model.StatePublic + tblInfo.ID = tbl.id + tblInfo.UpdateTS = t.StartTS + err = t.CreateTableOrView(dbID, tblInfo) + if err != nil { + return errors.Trace(err) + } + } + return nil +} + +// InitMDLTable is to create tidb_mdl_info, which is used for metadata lock. 
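// InitDDLJobTables above and the init helpers that follow all share one shape:
// open an internal transaction with kv.RunInNewTxn, wrap it in meta.NewMeta,
// check a stored version, and only then mutate. A condensed sketch of that
// shape using only calls that appear in this patch; readDDLTableVersion is an
// illustrative helper, and the import paths are assumed from the repository's
// pkg/ layout.
package metapattern

import (
	"context"

	"github.com/pingcap/tidb/pkg/kv"
	"github.com/pingcap/tidb/pkg/meta"
)

// readDDLTableVersion reads the stored DDL table version inside a fresh
// internal transaction, without taking the write path.
func readDDLTableVersion(store kv.Storage) (meta.DDLTableVersion, error) {
	var ver meta.DDLTableVersion
	ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL)
	err := kv.RunInNewTxn(ctx, store, false, func(_ context.Context, txn kv.Transaction) error {
		t := meta.NewMeta(txn)
		v, err := t.CheckDDLTableVersion()
		ver = v
		return err
	})
	return ver, err
}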
+func InitMDLTable(store kv.Storage) error { + return kv.RunInNewTxn(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), store, true, func(_ context.Context, txn kv.Transaction) error { + t := meta.NewMeta(txn) + ver, err := t.CheckDDLTableVersion() + if err != nil || ver >= meta.MDLTableVersion { + return errors.Trace(err) + } + dbID, err := t.CreateMySQLDatabaseIfNotExists() + if err != nil { + return err + } + splitAndScatterTable(store, []int64{ddl.MDLTableID}) + p := parser.New() + stmt, err := p.ParseOneStmt(mdlTable, "", "") + if err != nil { + return errors.Trace(err) + } + tblInfo, err := ddl.BuildTableInfoFromAST(stmt.(*ast.CreateTableStmt)) + if err != nil { + return errors.Trace(err) + } + tblInfo.State = model.StatePublic + tblInfo.ID = ddl.MDLTableID + tblInfo.UpdateTS = t.StartTS + err = t.CreateTableOrView(dbID, tblInfo) + if err != nil { + return errors.Trace(err) + } + + return t.SetDDLTables(meta.MDLTableVersion) + }) +} + +// InitMDLVariableForBootstrap initializes the metadata lock variable. +func InitMDLVariableForBootstrap(store kv.Storage) error { + err := kv.RunInNewTxn(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), store, true, func(_ context.Context, txn kv.Transaction) error { + t := meta.NewMeta(txn) + return t.SetMetadataLock(true) + }) + if err != nil { + return err + } + variable.EnableMDL.Store(true) + return nil +} + +// InitTiDBSchemaCacheSize initializes the tidb schema cache size. +func InitTiDBSchemaCacheSize(store kv.Storage) error { + var ( + isNull bool + size uint64 + err error + ) + err = kv.RunInNewTxn(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), store, true, func(_ context.Context, txn kv.Transaction) error { + t := meta.NewMeta(txn) + size, isNull, err = t.GetSchemaCacheSize() + if err != nil { + return errors.Trace(err) + } + if isNull { + size = variable.DefTiDBSchemaCacheSize + return t.SetSchemaCacheSize(size) + } + return nil + }) + if err != nil { + return errors.Trace(err) + } + variable.SchemaCacheSize.Store(size) + return nil +} + +// InitMDLVariableForUpgrade initializes the metadata lock variable. +func InitMDLVariableForUpgrade(store kv.Storage) (bool, error) { + isNull := false + enable := false + var err error + err = kv.RunInNewTxn(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), store, true, func(_ context.Context, txn kv.Transaction) error { + t := meta.NewMeta(txn) + enable, isNull, err = t.GetMetadataLock() + if err != nil { + return err + } + return nil + }) + if isNull || !enable { + variable.EnableMDL.Store(false) + } else { + variable.EnableMDL.Store(true) + } + return isNull, err +} + +// InitMDLVariable initializes the metadata lock variable. +func InitMDLVariable(store kv.Storage) error { + isNull := false + enable := false + var err error + err = kv.RunInNewTxn(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), store, true, func(_ context.Context, txn kv.Transaction) error { + t := meta.NewMeta(txn) + enable, isNull, err = t.GetMetadataLock() + if err != nil { + return err + } + if isNull { + // Workaround for version: nightly-2022-11-07 to nightly-2022-11-17. + enable = true + logutil.BgLogger().Warn("metadata lock is null") + err = t.SetMetadataLock(true) + if err != nil { + return err + } + } + return nil + }) + variable.EnableMDL.Store(enable) + return err +} + +// BootstrapSession bootstrap session and domain. 
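// BootstrapSession below is the entry point callers hit once per store, after
// opening storage and before creating ordinary sessions. A minimal sketch of
// that call order in a test setting; the mockstore helper and the exact import
// paths are assumptions based on this repository's pkg/ layout, not something
// introduced by this patch.
package bootexample

import (
	"github.com/pingcap/tidb/pkg/session"
	"github.com/pingcap/tidb/pkg/store/mockstore"
)

// bootstrapForTest opens an in-memory store, bootstraps the system tables and
// domain, and only then creates a regular session against the store.
func bootstrapForTest() error {
	store, err := mockstore.NewMockStore()
	if err != nil {
		return err
	}
	defer func() { _ = store.Close() }()

	dom, err := session.BootstrapSession(store)
	if err != nil {
		return err
	}
	defer dom.Close()

	se, err := session.CreateSession4Test(store)
	if err != nil {
		return err
	}
	_ = se // the session is now usable for statements
	return nil
}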
+func BootstrapSession(store kv.Storage) (*domain.Domain, error) { + return bootstrapSessionImpl(store, createSessions) +} + +// BootstrapSession4DistExecution bootstrap session and dom for Distributed execution test, only for unit testing. +func BootstrapSession4DistExecution(store kv.Storage) (*domain.Domain, error) { + return bootstrapSessionImpl(store, createSessions4DistExecution) +} + +// bootstrapSessionImpl bootstraps session and domain. +// the process works as follows: +// - if we haven't bootstrapped to the target version +// - create/init/start domain +// - bootstrap or upgrade, some variables will be initialized and stored to system +// table in the process, such as system time-zone +// - close domain +// +// - create/init another domain +// - initialization global variables from system table that's required to use sessionCtx, +// such as system time zone +// - start domain and other routines. +func bootstrapSessionImpl(store kv.Storage, createSessionsImpl func(store kv.Storage, cnt int) ([]*session, error)) (*domain.Domain, error) { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnBootstrap) + cfg := config.GetGlobalConfig() + if len(cfg.Instance.PluginLoad) > 0 { + err := plugin.Load(context.Background(), plugin.Config{ + Plugins: strings.Split(cfg.Instance.PluginLoad, ","), + PluginDir: cfg.Instance.PluginDir, + }) + if err != nil { + return nil, err + } + } + err := InitDDLJobTables(store, meta.BaseDDLTableVersion) + if err != nil { + return nil, err + } + err = InitMDLTable(store) + if err != nil { + return nil, err + } + err = InitDDLJobTables(store, meta.BackfillTableVersion) + if err != nil { + return nil, err + } + err = InitTiDBSchemaCacheSize(store) + if err != nil { + return nil, err + } + ver := getStoreBootstrapVersion(store) + if ver == notBootstrapped { + runInBootstrapSession(store, bootstrap) + } else if ver < currentBootstrapVersion { + runInBootstrapSession(store, upgrade) + } else { + err = InitMDLVariable(store) + if err != nil { + return nil, err + } + } + + // initiate disttask framework components which need a store + scheduler.RegisterSchedulerFactory( + proto.ImportInto, + func(ctx context.Context, task *proto.Task, param scheduler.Param) scheduler.Scheduler { + return importinto.NewImportScheduler(ctx, task, param, store.(kv.StorageWithPD)) + }, + ) + taskexecutor.RegisterTaskType( + proto.ImportInto, + func(ctx context.Context, id string, task *proto.Task, table taskexecutor.TaskTable) taskexecutor.TaskExecutor { + return importinto.NewImportExecutor(ctx, id, task, table, store) + }, + ) + + analyzeConcurrencyQuota := int(config.GetGlobalConfig().Performance.AnalyzePartitionConcurrencyQuota) + concurrency := config.GetGlobalConfig().Performance.StatsLoadConcurrency + if concurrency == 0 { + // if concurrency is 0, we will set the concurrency of sync load by CPU. + concurrency = syncload.GetSyncLoadConcurrencyByCPU() + } + if concurrency < 0 { // it is only for test, in the production, negative value is illegal. + concurrency = 0 + } + + ses, err := createSessionsImpl(store, 10) + if err != nil { + return nil, err + } + ses[0].GetSessionVars().InRestrictedSQL = true + + // get system tz from mysql.tidb + tz, err := ses[0].getTableValue(ctx, mysql.TiDBTable, tidbSystemTZ) + if err != nil { + return nil, err + } + timeutil.SetSystemTZ(tz) + + // get the flag from `mysql`.`tidb` which indicating if new collations are enabled. 
+ newCollationEnabled, err := loadCollationParameter(ctx, ses[0]) + if err != nil { + return nil, err + } + collate.SetNewCollationEnabledForTest(newCollationEnabled) + + // only start the domain after we have initialized some global variables. + dom := domain.GetDomain(ses[0]) + err = dom.Start() + if err != nil { + return nil, err + } + + // To deal with the location partition failure caused by inconsistent NewCollationEnabled values(see issue #32416). + rebuildAllPartitionValueMapAndSorted(ses[0]) + + // We should make the load bind-info loop before other loops which has internal SQL. + // Because the internal SQL may access the global bind-info handler. As the result, the data race occurs here as the + // LoadBindInfoLoop inits global bind-info handler. + err = dom.LoadBindInfoLoop(ses[1], ses[2]) + if err != nil { + return nil, err + } + + if !config.GetGlobalConfig().Security.SkipGrantTable { + err = dom.LoadPrivilegeLoop(ses[3]) + if err != nil { + return nil, err + } + } + + // Rebuild sysvar cache in a loop + err = dom.LoadSysVarCacheLoop(ses[4]) + if err != nil { + return nil, err + } + + if config.GetGlobalConfig().DisaggregatedTiFlash && !config.GetGlobalConfig().UseAutoScaler { + // Invalid client-go tiflash_compute store cache if necessary. + err = dom.WatchTiFlashComputeNodeChange() + if err != nil { + return nil, err + } + } + + if err = extensionimpl.Bootstrap(context.Background(), dom); err != nil { + return nil, err + } + + if len(cfg.Instance.PluginLoad) > 0 { + err := plugin.Init(context.Background(), plugin.Config{EtcdClient: dom.GetEtcdClient()}) + if err != nil { + return nil, err + } + } + + err = executor.LoadExprPushdownBlacklist(ses[5]) + if err != nil { + return nil, err + } + err = executor.LoadOptRuleBlacklist(ctx, ses[5]) + if err != nil { + return nil, err + } + + planReplayerWorkerCnt := config.GetGlobalConfig().Performance.PlanReplayerDumpWorkerConcurrency + planReplayerWorkersSctx := make([]sessionctx.Context, planReplayerWorkerCnt) + pworkerSes, err := createSessions(store, int(planReplayerWorkerCnt)) + if err != nil { + return nil, err + } + for i := 0; i < int(planReplayerWorkerCnt); i++ { + planReplayerWorkersSctx[i] = pworkerSes[i] + } + // setup plan replayer handle + dom.SetupPlanReplayerHandle(ses[6], planReplayerWorkersSctx) + dom.StartPlanReplayerHandle() + // setup dumpFileGcChecker + dom.SetupDumpFileGCChecker(ses[7]) + dom.DumpFileGcCheckerLoop() + // setup historical stats worker + dom.SetupHistoricalStatsWorker(ses[8]) + dom.StartHistoricalStatsWorker() + failToLoadOrParseSQLFile := false // only used for unit test + if runBootstrapSQLFile { + pm := &privileges.UserPrivileges{ + Handle: dom.PrivilegeHandle(), + } + privilege.BindPrivilegeManager(ses[9], pm) + if err := doBootstrapSQLFile(ses[9]); err != nil && intest.InTest { + failToLoadOrParseSQLFile = true + } + } + // A sub context for update table stats, and other contexts for concurrent stats loading. 
+ cnt := 1 + concurrency + syncStatsCtxs, err := createSessions(store, cnt) + if err != nil { + return nil, err + } + subCtxs := make([]sessionctx.Context, cnt) + for i := 0; i < cnt; i++ { + subCtxs[i] = sessionctx.Context(syncStatsCtxs[i]) + } + + // setup extract Handle + extractWorkers := 1 + sctxs, err := createSessions(store, extractWorkers) + if err != nil { + return nil, err + } + extractWorkerSctxs := make([]sessionctx.Context, 0) + for _, sctx := range sctxs { + extractWorkerSctxs = append(extractWorkerSctxs, sctx) + } + dom.SetupExtractHandle(extractWorkerSctxs) + + // setup init stats loader + initStatsCtx, err := createSession(store) + if err != nil { + return nil, err + } + if err = dom.LoadAndUpdateStatsLoop(subCtxs, initStatsCtx); err != nil { + return nil, err + } + + // init the instance plan cache + dom.InitInstancePlanCache() + + // start TTL job manager after setup stats collector + // because TTL could modify a lot of columns, and need to trigger auto analyze + ttlworker.AttachStatsCollector = func(s sqlexec.SQLExecutor) sqlexec.SQLExecutor { + if s, ok := s.(*session); ok { + return attachStatsCollector(s, dom) + } + return s + } + ttlworker.DetachStatsCollector = func(s sqlexec.SQLExecutor) sqlexec.SQLExecutor { + if s, ok := s.(*session); ok { + return detachStatsCollector(s) + } + return s + } + dom.StartTTLJobManager() + + analyzeCtxs, err := createSessions(store, analyzeConcurrencyQuota) + if err != nil { + return nil, err + } + subCtxs2 := make([]sessionctx.Context, analyzeConcurrencyQuota) + for i := 0; i < analyzeConcurrencyQuota; i++ { + subCtxs2[i] = analyzeCtxs[i] + } + dom.SetupAnalyzeExec(subCtxs2) + dom.LoadSigningCertLoop(cfg.Security.SessionTokenSigningCert, cfg.Security.SessionTokenSigningKey) + + if raw, ok := store.(kv.EtcdBackend); ok { + err = raw.StartGCWorker() + if err != nil { + return nil, err + } + } + + // This only happens in testing, since the failure of loading or parsing sql file + // would panic the bootstrapping. + if intest.InTest && failToLoadOrParseSQLFile { + dom.Close() + return nil, errors.New("Fail to load or parse sql file") + } + err = dom.InitDistTaskLoop() + if err != nil { + return nil, err + } + return dom, err +} + +// GetDomain gets the associated domain for store. +func GetDomain(store kv.Storage) (*domain.Domain, error) { + return domap.Get(store) +} + +// runInBootstrapSession create a special session for bootstrap to run. +// If no bootstrap and storage is remote, we must use a little lease time to +// bootstrap quickly, after bootstrapped, we will reset the lease time. +// TODO: Using a bootstrap tool for doing this may be better later. +func runInBootstrapSession(store kv.Storage, bootstrap func(types.Session)) { + s, err := createSession(store) + if err != nil { + // Bootstrap fail will cause program exit. + logutil.BgLogger().Fatal("createSession error", zap.Error(err)) + } + dom := domain.GetDomain(s) + err = dom.Start() + if err != nil { + // Bootstrap fail will cause program exit. + logutil.BgLogger().Fatal("start domain error", zap.Error(err)) + } + + // For the bootstrap SQLs, the following variables should be compatible with old TiDB versions. 
+ s.sessionVars.EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly + + s.SetValue(sessionctx.Initing, true) + bootstrap(s) + finishBootstrap(store) + s.ClearValue(sessionctx.Initing) + + dom.Close() + if intest.InTest { + infosync.MockGlobalServerInfoManagerEntry.Close() + } + domap.Delete(store) +} + +func createSessions(store kv.Storage, cnt int) ([]*session, error) { + return createSessionsImpl(store, cnt, createSession) +} + +func createSessions4DistExecution(store kv.Storage, cnt int) ([]*session, error) { + domap.Delete(store) + + return createSessionsImpl(store, cnt, createSession4DistExecution) +} + +func createSessionsImpl(store kv.Storage, cnt int, createSessionImpl func(kv.Storage) (*session, error)) ([]*session, error) { + // Then we can create new dom + ses := make([]*session, cnt) + for i := 0; i < cnt; i++ { + se, err := createSessionImpl(store) + if err != nil { + return nil, err + } + ses[i] = se + } + + return ses, nil +} + +// createSession creates a new session. +// Please note that such a session is not tracked by the internal session list. +// This means the min ts reporter is not aware of it and may report a wrong min start ts. +// In most cases you should use a session pool in domain instead. +func createSession(store kv.Storage) (*session, error) { + return createSessionWithOpt(store, nil) +} + +func createSession4DistExecution(store kv.Storage) (*session, error) { + return createSessionWithOpt(store, nil) +} + +func createSessionWithOpt(store kv.Storage, opt *Opt) (*session, error) { + dom, err := domap.Get(store) + if err != nil { + return nil, err + } + s := &session{ + store: store, + ddlOwnerManager: dom.DDL().OwnerManager(), + client: store.GetClient(), + mppClient: store.GetMPPClient(), + stmtStats: stmtstats.CreateStatementStats(), + sessionStatesHandlers: make(map[sessionstates.SessionStateType]sessionctx.SessionStatesHandler), + } + s.sessionVars = variable.NewSessionVars(s) + s.exprctx = contextsession.NewSessionExprContext(s) + s.pctx = newPlanContextImpl(s) + s.tblctx = tbctximpl.NewTableContextImpl(s) + + if opt != nil && opt.PreparedPlanCache != nil { + s.sessionPlanCache = opt.PreparedPlanCache + } + s.mu.values = make(map[fmt.Stringer]any) + s.lockedTables = make(map[int64]model.TableLockTpInfo) + s.advisoryLocks = make(map[string]*advisoryLock) + + domain.BindDomain(s, dom) + // session implements variable.GlobalVarAccessor. Bind it to ctx. 
+ s.sessionVars.GlobalVarsAccessor = s + s.sessionVars.BinlogClient = binloginfo.GetPumpsClient() + s.txn.init() + + sessionBindHandle := bindinfo.NewSessionBindingHandle() + s.SetValue(bindinfo.SessionBindInfoKeyType, sessionBindHandle) + s.SetSessionStatesHandler(sessionstates.StateBinding, sessionBindHandle) + return s, nil +} + +// attachStatsCollector attaches the stats collector in the dom for the session +func attachStatsCollector(s *session, dom *domain.Domain) *session { + if dom.StatsHandle() != nil && dom.StatsUpdating() { + if s.statsCollector == nil { + s.statsCollector = dom.StatsHandle().NewSessionStatsItem().(*usage.SessionStatsItem) + } + if s.idxUsageCollector == nil && config.GetGlobalConfig().Instance.EnableCollectExecutionInfo.Load() { + s.idxUsageCollector = dom.StatsHandle().NewSessionIndexUsageCollector() + } + } + + return s +} + +// detachStatsCollector removes the stats collector in the session +func detachStatsCollector(s *session) *session { + if s.statsCollector != nil { + s.statsCollector.Delete() + s.statsCollector = nil + } + if s.idxUsageCollector != nil { + s.idxUsageCollector.Flush() + s.idxUsageCollector = nil + } + return s +} + +// CreateSessionWithDomain creates a new Session and binds it with a Domain. +// We need this because when we start DDL in Domain, the DDL need a session +// to change some system tables. But at that time, we have been already in +// a lock context, which cause we can't call createSession directly. +func CreateSessionWithDomain(store kv.Storage, dom *domain.Domain) (*session, error) { + s := &session{ + store: store, + sessionVars: variable.NewSessionVars(nil), + client: store.GetClient(), + mppClient: store.GetMPPClient(), + stmtStats: stmtstats.CreateStatementStats(), + sessionStatesHandlers: make(map[sessionstates.SessionStateType]sessionctx.SessionStatesHandler), + } + s.exprctx = contextsession.NewSessionExprContext(s) + s.pctx = newPlanContextImpl(s) + s.tblctx = tbctximpl.NewTableContextImpl(s) + s.mu.values = make(map[fmt.Stringer]any) + s.lockedTables = make(map[int64]model.TableLockTpInfo) + domain.BindDomain(s, dom) + // session implements variable.GlobalVarAccessor. Bind it to ctx. 
+ s.sessionVars.GlobalVarsAccessor = s + s.txn.init() + return s, nil +} + +const ( + notBootstrapped = 0 +) + +func getStoreBootstrapVersion(store kv.Storage) int64 { + storeBootstrappedLock.Lock() + defer storeBootstrappedLock.Unlock() + // check in memory + _, ok := storeBootstrapped[store.UUID()] + if ok { + return currentBootstrapVersion + } + + var ver int64 + // check in kv store + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnBootstrap) + err := kv.RunInNewTxn(ctx, store, false, func(_ context.Context, txn kv.Transaction) error { + var err error + t := meta.NewMeta(txn) + ver, err = t.GetBootstrapVersion() + return err + }) + if err != nil { + logutil.BgLogger().Fatal("check bootstrapped failed", + zap.Error(err)) + } + + if ver > notBootstrapped { + // here mean memory is not ok, but other server has already finished it + storeBootstrapped[store.UUID()] = true + } + + modifyBootstrapVersionForTest(ver) + return ver +} + +func finishBootstrap(store kv.Storage) { + setStoreBootstrapped(store.UUID()) + + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnBootstrap) + err := kv.RunInNewTxn(ctx, store, true, func(_ context.Context, txn kv.Transaction) error { + t := meta.NewMeta(txn) + err := t.FinishBootstrap(currentBootstrapVersion) + return err + }) + if err != nil { + logutil.BgLogger().Fatal("finish bootstrap failed", + zap.Error(err)) + } +} + +const quoteCommaQuote = "', '" + +// loadCommonGlobalVariablesIfNeeded loads and applies commonly used global variables for the session. +func (s *session) loadCommonGlobalVariablesIfNeeded() error { + vars := s.sessionVars + if vars.CommonGlobalLoaded { + return nil + } + if s.Value(sessionctx.Initing) != nil { + // When running bootstrap or upgrade, we should not access global storage. + // But we need to init max_allowed_packet to use concat function during bootstrap or upgrade. + err := vars.SetSystemVar(variable.MaxAllowedPacket, strconv.FormatUint(variable.DefMaxAllowedPacket, 10)) + if err != nil { + logutil.BgLogger().Error("set system variable max_allowed_packet error", zap.Error(err)) + } + return nil + } + + vars.CommonGlobalLoaded = true + + // Deep copy sessionvar cache + sessionCache, err := domain.GetDomain(s).GetSessionCache() + if err != nil { + return err + } + for varName, varVal := range sessionCache { + if _, ok := vars.GetSystemVar(varName); !ok { + err = vars.SetSystemVarWithRelaxedValidation(varName, varVal) + if err != nil { + if variable.ErrUnknownSystemVar.Equal(err) { + continue // sessionCache is stale; sysvar has likely been unregistered + } + return err + } + } + } + // when client set Capability Flags CLIENT_INTERACTIVE, init wait_timeout with interactive_timeout + if vars.ClientCapability&mysql.ClientInteractive > 0 { + if varVal, ok := vars.GetSystemVar(variable.InteractiveTimeout); ok { + if err := vars.SetSystemVar(variable.WaitTimeout, varVal); err != nil { + return err + } + } + } + return nil +} + +// PrepareTxnCtx begins a transaction, and creates a new transaction context. +// It is called before we execute a sql query. +func (s *session) PrepareTxnCtx(ctx context.Context) error { + s.currentCtx = ctx + if s.txn.validOrPending() { + return nil + } + + txnMode := ast.Optimistic + if !s.sessionVars.IsAutocommit() || (config.GetGlobalConfig().PessimisticTxn. 
+ PessimisticAutoCommit.Load() && !s.GetSessionVars().BulkDMLEnabled) { + if s.sessionVars.TxnMode == ast.Pessimistic { + txnMode = ast.Pessimistic + } + } + + if s.sessionVars.RetryInfo.Retrying { + txnMode = ast.Pessimistic + } + + return sessiontxn.GetTxnManager(s).EnterNewTxn(ctx, &sessiontxn.EnterNewTxnRequest{ + Type: sessiontxn.EnterNewTxnBeforeStmt, + TxnMode: txnMode, + }) +} + +// PrepareTSFuture uses to try to get ts future. +func (s *session) PrepareTSFuture(ctx context.Context, future oracle.Future, scope string) error { + if s.txn.Valid() { + return errors.New("cannot prepare ts future when txn is valid") + } + + failpoint.Inject("assertTSONotRequest", func() { + if _, ok := future.(sessiontxn.ConstantFuture); !ok && !s.isInternal() { + panic("tso shouldn't be requested") + } + }) + + failpoint.InjectContext(ctx, "mockGetTSFail", func() { + future = txnFailFuture{} + }) + + s.txn.changeToPending(&txnFuture{ + future: future, + store: s.store, + txnScope: scope, + pipelined: s.usePipelinedDmlOrWarn(ctx), + }) + return nil +} + +// GetPreparedTxnFuture returns the TxnFuture if it is valid or pending. +// It returns nil otherwise. +func (s *session) GetPreparedTxnFuture() sessionctx.TxnFuture { + if !s.txn.validOrPending() { + return nil + } + return &s.txn +} + +// RefreshTxnCtx implements context.RefreshTxnCtx interface. +func (s *session) RefreshTxnCtx(ctx context.Context) error { + var commitDetail *tikvutil.CommitDetails + ctx = context.WithValue(ctx, tikvutil.CommitDetailCtxKey, &commitDetail) + err := s.doCommit(ctx) + if commitDetail != nil { + s.GetSessionVars().StmtCtx.MergeExecDetails(nil, commitDetail) + } + if err != nil { + return err + } + + s.updateStatsDeltaToCollector() + + return sessiontxn.NewTxn(ctx, s) +} + +// GetStore gets the store of session. +func (s *session) GetStore() kv.Storage { + return s.store +} + +func (s *session) ShowProcess() *util.ProcessInfo { + return s.processInfo.Load() +} + +// GetStartTSFromSession returns the startTS in the session `se` +func GetStartTSFromSession(se any) (startTS, processInfoID uint64) { + tmp, ok := se.(*session) + if !ok { + logutil.BgLogger().Error("GetStartTSFromSession failed, can't transform to session struct") + return 0, 0 + } + txnInfo := tmp.TxnInfo() + if txnInfo != nil { + startTS = txnInfo.StartTS + processInfoID = txnInfo.ConnectionID + } + + logutil.BgLogger().Debug( + "GetStartTSFromSession getting startTS of internal session", + zap.Uint64("startTS", startTS), zap.Time("start time", oracle.GetTimeFromTS(startTS))) + + return startTS, processInfoID +} + +// logStmt logs some crucial SQL including: CREATE USER/GRANT PRIVILEGE/CHANGE PASSWORD/DDL etc and normal SQL +// if variable.ProcessGeneralLog is set. 
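// The txn-mode selection in PrepareTxnCtx above collapses to a small decision:
// start pessimistic when the statement is retrying, or when the session's
// TxnMode is pessimistic and either autocommit is off or pessimistic
// auto-commit is configured (and bulk DML is not in effect). A condensed
// sketch of that decision; pickTxnMode and the string constants are
// illustrative, not part of this patch.
package txnmode

const (
	optimistic  = "optimistic"
	pessimistic = "pessimistic"
)

// pickTxnMode mirrors the branch structure in PrepareTxnCtx.
func pickTxnMode(sessionPessimistic, autocommit, pessimisticAutoCommit, bulkDML, retrying bool) string {
	mode := optimistic
	if !autocommit || (pessimisticAutoCommit && !bulkDML) {
		if sessionPessimistic {
			mode = pessimistic
		}
	}
	if retrying {
		mode = pessimistic
	}
	return mode
}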
+func logStmt(execStmt *executor.ExecStmt, s *session) { + vars := s.GetSessionVars() + isCrucial := false + switch stmt := execStmt.StmtNode.(type) { + case *ast.DropIndexStmt: + isCrucial = true + if stmt.IsHypo { + isCrucial = false + } + case *ast.CreateIndexStmt: + isCrucial = true + if stmt.IndexOption != nil && stmt.IndexOption.Tp == model.IndexTypeHypo { + isCrucial = false + } + case *ast.CreateUserStmt, *ast.DropUserStmt, *ast.AlterUserStmt, *ast.SetPwdStmt, *ast.GrantStmt, + *ast.RevokeStmt, *ast.AlterTableStmt, *ast.CreateDatabaseStmt, *ast.CreateTableStmt, + *ast.DropDatabaseStmt, *ast.DropTableStmt, *ast.RenameTableStmt, *ast.TruncateTableStmt, + *ast.RenameUserStmt: + isCrucial = true + } + + if isCrucial { + user := vars.User + schemaVersion := s.GetInfoSchema().SchemaMetaVersion() + if ss, ok := execStmt.StmtNode.(ast.SensitiveStmtNode); ok { + logutil.BgLogger().Info("CRUCIAL OPERATION", + zap.Uint64("conn", vars.ConnectionID), + zap.Int64("schemaVersion", schemaVersion), + zap.String("secure text", ss.SecureText()), + zap.Stringer("user", user)) + } else { + logutil.BgLogger().Info("CRUCIAL OPERATION", + zap.Uint64("conn", vars.ConnectionID), + zap.Int64("schemaVersion", schemaVersion), + zap.String("cur_db", vars.CurrentDB), + zap.String("sql", execStmt.StmtNode.Text()), + zap.Stringer("user", user)) + } + } else { + logGeneralQuery(execStmt, s, false) + } +} + +func logGeneralQuery(execStmt *executor.ExecStmt, s *session, isPrepared bool) { + vars := s.GetSessionVars() + if variable.ProcessGeneralLog.Load() && !vars.InRestrictedSQL { + var query string + if isPrepared { + query = execStmt.OriginText() + } else { + query = execStmt.GetTextToLog(false) + } + + query = executor.QueryReplacer.Replace(query) + if vars.EnableRedactLog != errors.RedactLogEnable { + query += redact.String(vars.EnableRedactLog, vars.PlanCacheParams.String()) + } + logutil.GeneralLogger.Info("GENERAL_LOG", + zap.Uint64("conn", vars.ConnectionID), + zap.String("session_alias", vars.SessionAlias), + zap.String("user", vars.User.LoginString()), + zap.Int64("schemaVersion", s.GetInfoSchema().SchemaMetaVersion()), + zap.Uint64("txnStartTS", vars.TxnCtx.StartTS), + zap.Uint64("forUpdateTS", vars.TxnCtx.GetForUpdateTS()), + zap.Bool("isReadConsistency", vars.IsIsolation(ast.ReadCommitted)), + zap.String("currentDB", vars.CurrentDB), + zap.Bool("isPessimistic", vars.TxnCtx.IsPessimistic), + zap.String("sessionTxnMode", vars.GetReadableTxnMode()), + zap.String("sql", query)) + } +} + +func (s *session) recordOnTransactionExecution(err error, counter int, duration float64, isInternal bool) { + if s.sessionVars.TxnCtx.IsPessimistic { + if err != nil { + if isInternal { + session_metrics.TransactionDurationPessimisticAbortInternal.Observe(duration) + session_metrics.StatementPerTransactionPessimisticErrorInternal.Observe(float64(counter)) + } else { + session_metrics.TransactionDurationPessimisticAbortGeneral.Observe(duration) + session_metrics.StatementPerTransactionPessimisticErrorGeneral.Observe(float64(counter)) + } + } else { + if isInternal { + session_metrics.TransactionDurationPessimisticCommitInternal.Observe(duration) + session_metrics.StatementPerTransactionPessimisticOKInternal.Observe(float64(counter)) + } else { + session_metrics.TransactionDurationPessimisticCommitGeneral.Observe(duration) + session_metrics.StatementPerTransactionPessimisticOKGeneral.Observe(float64(counter)) + } + } + } else { + if err != nil { + if isInternal { + 
session_metrics.TransactionDurationOptimisticAbortInternal.Observe(duration) + session_metrics.StatementPerTransactionOptimisticErrorInternal.Observe(float64(counter)) + } else { + session_metrics.TransactionDurationOptimisticAbortGeneral.Observe(duration) + session_metrics.StatementPerTransactionOptimisticErrorGeneral.Observe(float64(counter)) + } + } else { + if isInternal { + session_metrics.TransactionDurationOptimisticCommitInternal.Observe(duration) + session_metrics.StatementPerTransactionOptimisticOKInternal.Observe(float64(counter)) + } else { + session_metrics.TransactionDurationOptimisticCommitGeneral.Observe(duration) + session_metrics.StatementPerTransactionOptimisticOKGeneral.Observe(float64(counter)) + } + } + } +} + +func (s *session) checkPlacementPolicyBeforeCommit() error { + var err error + // Get the txnScope of the transaction we're going to commit. + txnScope := s.GetSessionVars().TxnCtx.TxnScope + if txnScope == "" { + txnScope = kv.GlobalTxnScope + } + if txnScope != kv.GlobalTxnScope { + is := s.GetInfoSchema().(infoschema.InfoSchema) + deltaMap := s.GetSessionVars().TxnCtx.TableDeltaMap + for physicalTableID := range deltaMap { + var tableName string + var partitionName string + tblInfo, _, partInfo := is.FindTableByPartitionID(physicalTableID) + if tblInfo != nil && partInfo != nil { + tableName = tblInfo.Meta().Name.String() + partitionName = partInfo.Name.String() + } else { + tblInfo, _ := is.TableByID(physicalTableID) + tableName = tblInfo.Meta().Name.String() + } + bundle, ok := is.PlacementBundleByPhysicalTableID(physicalTableID) + if !ok { + errMsg := fmt.Sprintf("table %v doesn't have placement policies with txn_scope %v", + tableName, txnScope) + if len(partitionName) > 0 { + errMsg = fmt.Sprintf("table %v's partition %v doesn't have placement policies with txn_scope %v", + tableName, partitionName, txnScope) + } + err = dbterror.ErrInvalidPlacementPolicyCheck.GenWithStackByArgs(errMsg) + break + } + dcLocation, ok := bundle.GetLeaderDC(placement.DCLabelKey) + if !ok { + errMsg := fmt.Sprintf("table %v's leader placement policy is not defined", tableName) + if len(partitionName) > 0 { + errMsg = fmt.Sprintf("table %v's partition %v's leader placement policy is not defined", tableName, partitionName) + } + err = dbterror.ErrInvalidPlacementPolicyCheck.GenWithStackByArgs(errMsg) + break + } + if dcLocation != txnScope { + errMsg := fmt.Sprintf("table %v's leader location %v is out of txn_scope %v", tableName, dcLocation, txnScope) + if len(partitionName) > 0 { + errMsg = fmt.Sprintf("table %v's partition %v's leader location %v is out of txn_scope %v", + tableName, partitionName, dcLocation, txnScope) + } + err = dbterror.ErrInvalidPlacementPolicyCheck.GenWithStackByArgs(errMsg) + break + } + // FIXME: currently we assume the physicalTableID is the partition ID. In future, we should consider the situation + // if the physicalTableID belongs to a Table. 
+ partitionID := physicalTableID + tbl, _, partitionDefInfo := is.FindTableByPartitionID(partitionID) + if tbl != nil { + tblInfo := tbl.Meta() + state := tblInfo.Partition.GetStateByID(partitionID) + if state == model.StateGlobalTxnOnly { + err = dbterror.ErrInvalidPlacementPolicyCheck.GenWithStackByArgs( + fmt.Sprintf("partition %s of table %s can not be written by local transactions when its placement policy is being altered", + tblInfo.Name, partitionDefInfo.Name)) + break + } + } + } + } + return err +} + +func (s *session) SetPort(port string) { + s.sessionVars.Port = port +} + +// GetTxnWriteThroughputSLI implements the Context interface. +func (s *session) GetTxnWriteThroughputSLI() *sli.TxnWriteThroughputSLI { + return &s.txn.writeSLI +} + +// GetInfoSchema returns snapshotInfoSchema if snapshot schema is set. +// Transaction infoschema is returned if inside an explicit txn. +// Otherwise the latest infoschema is returned. +func (s *session) GetInfoSchema() infoschemactx.MetaOnlyInfoSchema { + vars := s.GetSessionVars() + var is infoschema.InfoSchema + if snap, ok := vars.SnapshotInfoschema.(infoschema.InfoSchema); ok { + logutil.BgLogger().Info("use snapshot schema", zap.Uint64("conn", vars.ConnectionID), zap.Int64("schemaVersion", snap.SchemaMetaVersion())) + is = snap + } else { + vars.TxnCtxMu.Lock() + if vars.TxnCtx != nil { + if tmp, ok := vars.TxnCtx.InfoSchema.(infoschema.InfoSchema); ok { + is = tmp + } + } + vars.TxnCtxMu.Unlock() + } + + if is == nil { + is = domain.GetDomain(s).InfoSchema() + } + + // Override the infoschema if the session has temporary table. + return temptable.AttachLocalTemporaryTableInfoSchema(s, is) +} + +func (s *session) GetDomainInfoSchema() infoschemactx.MetaOnlyInfoSchema { + is := domain.GetDomain(s).InfoSchema() + extIs := &infoschema.SessionExtendedInfoSchema{InfoSchema: is} + return temptable.AttachLocalTemporaryTableInfoSchema(s, extIs) +} + +func getSnapshotInfoSchema(s sessionctx.Context, snapshotTS uint64) (infoschema.InfoSchema, error) { + is, err := domain.GetDomain(s).GetSnapshotInfoSchema(snapshotTS) + if err != nil { + return nil, err + } + // Set snapshot does not affect the witness of the local temporary table. + // The session always see the latest temporary tables. + return temptable.AttachLocalTemporaryTableInfoSchema(s, is), nil +} + +func (s *session) GetStmtStats() *stmtstats.StatementStats { + return s.stmtStats +} + +// SetMemoryFootprintChangeHook sets the hook that is called when the memdb changes its size. +// Call this after s.txn becomes valid, since TxnInfo is initialized when the txn becomes valid. +func (s *session) SetMemoryFootprintChangeHook() { + if s.txn.MemHookSet() { + return + } + if config.GetGlobalConfig().Performance.TxnTotalSizeLimit != config.DefTxnTotalSizeLimit { + // if the user manually specifies the config, don't involve the new memory tracker mechanism, let the old config + // work as before. + return + } + hook := func(mem uint64) { + if s.sessionVars.MemDBFootprint == nil { + tracker := memory.NewTracker(memory.LabelForMemDB, -1) + tracker.AttachTo(s.sessionVars.MemTracker) + s.sessionVars.MemDBFootprint = tracker + } + s.sessionVars.MemDBFootprint.ReplaceBytesUsed(int64(mem)) + } + s.txn.SetMemoryFootprintChangeHook(hook) +} + +// EncodeSessionStates implements SessionStatesHandler.EncodeSessionStates interface. 
+func (s *session) EncodeSessionStates(ctx context.Context, + _ sessionctx.Context, sessionStates *sessionstates.SessionStates) error { + // Transaction status is hard to encode, so we do not support it. + s.txn.mu.Lock() + valid := s.txn.Valid() + s.txn.mu.Unlock() + if valid { + return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("session has an active transaction") + } + // Data in local temporary tables is hard to encode, so we do not support it. + // Check temporary tables here to avoid circle dependency. + if s.sessionVars.LocalTemporaryTables != nil { + localTempTables := s.sessionVars.LocalTemporaryTables.(*infoschema.SessionTables) + if localTempTables.Count() > 0 { + return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("session has local temporary tables") + } + } + // The advisory locks will be released when the session is closed. + if len(s.advisoryLocks) > 0 { + return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("session has advisory locks") + } + // The TableInfo stores session ID and server ID, so the session cannot be migrated. + if len(s.lockedTables) > 0 { + return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("session has locked tables") + } + // It's insecure to migrate sandBoxMode because users can fake it. + if s.InSandBoxMode() { + return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("session is in sandbox mode") + } + + if err := s.sessionVars.EncodeSessionStates(ctx, sessionStates); err != nil { + return err + } + + hasRestrictVarPriv := false + checker := privilege.GetPrivilegeManager(s) + if checker == nil || checker.RequestDynamicVerification(s.sessionVars.ActiveRoles, "RESTRICTED_VARIABLES_ADMIN", false) { + hasRestrictVarPriv = true + } + // Encode session variables. We put it here instead of SessionVars to avoid cycle import. + sessionStates.SystemVars = make(map[string]string) + for _, sv := range variable.GetSysVars() { + switch { + case sv.HasNoneScope(), !sv.HasSessionScope(): + // Hidden attribute is deprecated. + // None-scoped variables cannot be modified. + // Noop variables should also be migrated even if they are noop. + continue + case sv.ReadOnly: + // Skip read-only variables here. We encode them into SessionStates manually. + continue + } + // Get all session variables because the default values may change between versions. + val, keep, err := s.sessionVars.GetSessionStatesSystemVar(sv.Name) + switch { + case err != nil: + return err + case !keep: + continue + case !hasRestrictVarPriv && sem.IsEnabled() && sem.IsInvisibleSysVar(sv.Name): + // If the variable has a global scope, it should be the same with the global one. + // Otherwise, it should be the same with the default value. + defaultVal := sv.Value + if sv.HasGlobalScope() { + // If the session value is the same with the global one, skip it. + if defaultVal, err = sv.GetGlobalFromHook(ctx, s.sessionVars); err != nil { + return err + } + } + if val != defaultVal { + // Case 1: the RESTRICTED_VARIABLES_ADMIN is revoked after setting the session variable. + // Case 2: the global variable is updated after the session is created. + // In any case, the variable can't be set in the new session, so give up. + return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs(fmt.Sprintf("session has set invisible variable '%s'", sv.Name)) + } + default: + sessionStates.SystemVars[sv.Name] = val + } + } + + // Encode prepared statements and sql bindings. 
+ for _, handler := range s.sessionStatesHandlers { + if err := handler.EncodeSessionStates(ctx, s, sessionStates); err != nil { + return err + } + } + return nil +} + +// DecodeSessionStates implements SessionStatesHandler.DecodeSessionStates interface. +func (s *session) DecodeSessionStates(ctx context.Context, + _ sessionctx.Context, sessionStates *sessionstates.SessionStates) error { + // Decode prepared statements and sql bindings. + for _, handler := range s.sessionStatesHandlers { + if err := handler.DecodeSessionStates(ctx, s, sessionStates); err != nil { + return err + } + } + + // Decode session variables. + names := variable.OrderByDependency(sessionStates.SystemVars) + // Some variables must be set before others, e.g. tidb_enable_noop_functions should be before noop variables. + for _, name := range names { + val := sessionStates.SystemVars[name] + // Experimental system variables may change scope, data types, or even be removed. + // We just ignore the errors and continue. + if err := s.sessionVars.SetSystemVar(name, val); err != nil { + logutil.Logger(ctx).Warn("set session variable during decoding session states error", + zap.String("name", name), zap.String("value", val), zap.Error(err)) + } + } + + // Decoding session vars / prepared statements may override stmt ctx, such as warnings, + // so we decode stmt ctx at last. + return s.sessionVars.DecodeSessionStates(ctx, sessionStates) +} + +func (s *session) setRequestSource(ctx context.Context, stmtLabel string, stmtNode ast.StmtNode) { + if !s.isInternal() { + if txn, _ := s.Txn(false); txn != nil && txn.Valid() { + if txn.IsPipelined() { + stmtLabel = "pdml" + } + txn.SetOption(kv.RequestSourceType, stmtLabel) + } + s.sessionVars.RequestSourceType = stmtLabel + return + } + if source := ctx.Value(kv.RequestSourceKey); source != nil { + requestSource := source.(kv.RequestSource) + if requestSource.RequestSourceType != "" { + s.sessionVars.RequestSourceType = requestSource.RequestSourceType + return + } + } + // panic in test mode in case there are requests without source in the future. + // log warnings in production mode. + if intest.InTest { + panic("unexpected no source type context, if you see this error, " + + "the `RequestSourceTypeKey` is missing in your context") + } + logutil.Logger(ctx).Warn("unexpected no source type context, if you see this warning, "+ + "the `RequestSourceTypeKey` is missing in the context", + zap.Bool("internal", s.isInternal()), + zap.String("sql", stmtNode.Text())) +} + +// NewStmtIndexUsageCollector creates a new `*indexusage.StmtIndexUsageCollector` based on the internal session index +// usage collector +func (s *session) NewStmtIndexUsageCollector() *indexusage.StmtIndexUsageCollector { + if s.idxUsageCollector == nil { + return nil + } + + return indexusage.NewStmtIndexUsageCollector(s.idxUsageCollector) +} + +// usePipelinedDmlOrWarn returns the current statement can be executed as a pipelined DML. +func (s *session) usePipelinedDmlOrWarn(ctx context.Context) bool { + if !s.sessionVars.BulkDMLEnabled { + return false + } + stmtCtx := s.sessionVars.StmtCtx + if stmtCtx == nil { + return false + } + if stmtCtx.IsReadOnly { + return false + } + vars := s.GetSessionVars() + if !vars.TxnCtx.EnableMDL { + stmtCtx.AppendWarning( + errors.New( + "Pipelined DML can not be used without Metadata Lock. 
Fallback to standard mode", + ), + ) + return false + } + if (vars.BatchCommit || vars.BatchInsert || vars.BatchDelete) && vars.DMLBatchSize > 0 && variable.EnableBatchDML.Load() { + stmtCtx.AppendWarning(errors.New("Pipelined DML can not be used with the deprecated Batch DML. Fallback to standard mode")) + return false + } + if vars.BinlogClient != nil { + stmtCtx.AppendWarning(errors.New("Pipelined DML can not be used with Binlog: BinlogClient != nil. Fallback to standard mode")) + return false + } + if !(stmtCtx.InInsertStmt || stmtCtx.InDeleteStmt || stmtCtx.InUpdateStmt) { + if !stmtCtx.IsReadOnly { + stmtCtx.AppendWarning(errors.New("Pipelined DML can only be used for auto-commit INSERT, REPLACE, UPDATE or DELETE. Fallback to standard mode")) + } + return false + } + if s.isInternal() { + stmtCtx.AppendWarning(errors.New("Pipelined DML can not be used for internal SQL. Fallback to standard mode")) + return false + } + if vars.InTxn() { + stmtCtx.AppendWarning(errors.New("Pipelined DML can not be used in transaction. Fallback to standard mode")) + return false + } + if !vars.IsAutocommit() { + stmtCtx.AppendWarning(errors.New("Pipelined DML can only be used in autocommit mode. Fallback to standard mode")) + return false + } + if s.GetSessionVars().ConstraintCheckInPlace { + // we enforce that pipelined DML must lazily check key. + stmtCtx.AppendWarning( + errors.New( + "Pipelined DML can not be used when tidb_constraint_check_in_place=ON. " + + "Fallback to standard mode", + ), + ) + return false + } + is, ok := s.GetDomainInfoSchema().(infoschema.InfoSchema) + if !ok { + stmtCtx.AppendWarning(errors.New("Pipelined DML failed to get latest InfoSchema. Fallback to standard mode")) + return false + } + for _, t := range stmtCtx.Tables { + // get table schema from current infoschema + tbl, err := is.TableByName(ctx, model.NewCIStr(t.DB), model.NewCIStr(t.Table)) + if err != nil { + stmtCtx.AppendWarning(errors.New("Pipelined DML failed to get table schema. Fallback to standard mode")) + return false + } + if tbl.Meta().IsView() { + stmtCtx.AppendWarning(errors.New("Pipelined DML can not be used on view. Fallback to standard mode")) + return false + } + if tbl.Meta().IsSequence() { + stmtCtx.AppendWarning(errors.New("Pipelined DML can not be used on sequence. Fallback to standard mode")) + return false + } + if vars.ForeignKeyChecks && (len(tbl.Meta().ForeignKeys) > 0 || len(is.GetTableReferredForeignKeys(t.DB, t.Table)) > 0) { + stmtCtx.AppendWarning( + errors.New( + "Pipelined DML can not be used on table with foreign keys when foreign_key_checks = ON. Fallback to standard mode", + ), + ) + return false + } + if tbl.Meta().TempTableType != model.TempTableNone { + stmtCtx.AppendWarning( + errors.New( + "Pipelined DML can not be used on temporary tables. " + + "Fallback to standard mode", + ), + ) + return false + } + if tbl.Meta().TableCacheStatusType != model.TableCacheStatusDisable { + stmtCtx.AppendWarning( + errors.New( + "Pipelined DML can not be used on cached tables. " + + "Fallback to standard mode", + ), + ) + return false + } + } + + // tidb_dml_type=bulk will invalidate the config pessimistic-auto-commit. + // The behavior is as if the config is set to false. But we generate a warning for it. 
+ if config.GetGlobalConfig().PessimisticTxn.PessimisticAutoCommit.Load() { + stmtCtx.AppendWarning( + errors.New( + "pessimistic-auto-commit config is ignored in favor of Pipelined DML", + ), + ) + } + return true +} + +// RemoveLockDDLJobs removes the DDL jobs which doesn't get the metadata lock from job2ver. +func RemoveLockDDLJobs(s types.Session, job2ver map[int64]int64, job2ids map[int64]string, printLog bool) { + sv := s.GetSessionVars() + if sv.InRestrictedSQL { + return + } + sv.TxnCtxMu.Lock() + defer sv.TxnCtxMu.Unlock() + if sv.TxnCtx == nil { + return + } + sv.GetRelatedTableForMDL().Range(func(tblID, value any) bool { + for jobID, ver := range job2ver { + ids := util.Str2Int64Map(job2ids[jobID]) + if _, ok := ids[tblID.(int64)]; ok && value.(int64) < ver { + delete(job2ver, jobID) + elapsedTime := time.Since(oracle.GetTimeFromTS(sv.TxnCtx.StartTS)) + if elapsedTime > time.Minute && printLog { + logutil.BgLogger().Info("old running transaction block DDL", zap.Int64("table ID", tblID.(int64)), zap.Int64("jobID", jobID), zap.Uint64("connection ID", sv.ConnectionID), zap.Duration("elapsed time", elapsedTime)) + } else { + logutil.BgLogger().Debug("old running transaction block DDL", zap.Int64("table ID", tblID.(int64)), zap.Int64("jobID", jobID), zap.Uint64("connection ID", sv.ConnectionID), zap.Duration("elapsed time", elapsedTime)) + } + } + } + return true + }) +} + +// GetDBNames gets the sql layer database names from the session. +func GetDBNames(seVar *variable.SessionVars) []string { + dbNames := make(map[string]struct{}) + if seVar == nil || !config.GetGlobalConfig().Status.RecordDBLabel { + return []string{""} + } + if seVar.StmtCtx != nil { + for _, t := range seVar.StmtCtx.Tables { + dbNames[t.DB] = struct{}{} + } + } + if len(dbNames) == 0 { + dbNames[seVar.CurrentDB] = struct{}{} + } + ns := make([]string, 0, len(dbNames)) + for n := range dbNames { + ns = append(ns, n) + } + return ns +} + +// GetCursorTracker returns the internal `cursor.Tracker` +func (s *session) GetCursorTracker() cursor.Tracker { + return s.cursorTracker +} diff --git a/pkg/session/sync_upgrade.go b/pkg/session/sync_upgrade.go index 52829793f6af7..9fb0ae318ed93 100644 --- a/pkg/session/sync_upgrade.go +++ b/pkg/session/sync_upgrade.go @@ -81,14 +81,14 @@ func SyncUpgradeState(s sessionctx.Context, timeout time.Duration) error { // SyncNormalRunning syncs normal state to etcd. func SyncNormalRunning(s sessionctx.Context) error { bgCtx := context.Background() - failpoint.Inject("mockResumeAllJobsFailed", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockResumeAllJobsFailed")); _err_ == nil { if val.(bool) { dom := domain.GetDomain(s) //nolint: errcheck dom.DDL().StateSyncer().UpdateGlobalState(bgCtx, syncer.NewStateInfo(syncer.StateNormalRunning)) - failpoint.Return(nil) + return nil } - }) + } logger := logutil.BgLogger().With(zap.String("category", "upgrading")) jobErrs, err := ddl.ResumeAllJobsBySystem(s) diff --git a/pkg/session/sync_upgrade.go__failpoint_stash__ b/pkg/session/sync_upgrade.go__failpoint_stash__ new file mode 100644 index 0000000000000..52829793f6af7 --- /dev/null +++ b/pkg/session/sync_upgrade.go__failpoint_stash__ @@ -0,0 +1,163 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + "github.com/pingcap/tidb/pkg/ddl" + "github.com/pingcap/tidb/pkg/ddl/syncer" + dist_store "github.com/pingcap/tidb/pkg/disttask/framework/storage" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/owner" + sessiontypes "github.com/pingcap/tidb/pkg/session/types" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.uber.org/zap" +) + +// isContextDone checks if context is done. +func isContextDone(ctx context.Context) bool { + select { + case <-ctx.Done(): + return true + default: + } + return false +} + +// SyncUpgradeState syncs upgrade state to etcd. +func SyncUpgradeState(s sessionctx.Context, timeout time.Duration) error { + ctx, cancelFunc := context.WithTimeout(context.Background(), timeout) + defer cancelFunc() + dom := domain.GetDomain(s) + err := dom.DDL().StateSyncer().UpdateGlobalState(ctx, syncer.NewStateInfo(syncer.StateUpgrading)) + logger := logutil.BgLogger().With(zap.String("category", "upgrading")) + if err != nil { + logger.Error("update global state failed", zap.String("state", syncer.StateUpgrading), zap.Error(err)) + return err + } + + interval := 200 * time.Millisecond + for i := 0; ; i++ { + if isContextDone(ctx) { + logger.Error("get owner op failed", zap.Duration("timeout", timeout), zap.Error(err)) + return ctx.Err() + } + + var op owner.OpType + childCtx, cancel := context.WithTimeout(ctx, 3*time.Second) + op, err = owner.GetOwnerOpValue(childCtx, dom.EtcdClient(), ddl.DDLOwnerKey, "upgrade bootstrap") + cancel() + if err == nil && op.IsSyncedUpgradingState() { + break + } + if i%10 == 0 { + logger.Warn("get owner op failed", zap.Stringer("op", op), zap.Error(err)) + } + time.Sleep(interval) + } + + logger.Info("update global state to upgrading", zap.String("state", syncer.StateUpgrading)) + return nil +} + +// SyncNormalRunning syncs normal state to etcd. 
+func SyncNormalRunning(s sessionctx.Context) error { + bgCtx := context.Background() + failpoint.Inject("mockResumeAllJobsFailed", func(val failpoint.Value) { + if val.(bool) { + dom := domain.GetDomain(s) + //nolint: errcheck + dom.DDL().StateSyncer().UpdateGlobalState(bgCtx, syncer.NewStateInfo(syncer.StateNormalRunning)) + failpoint.Return(nil) + } + }) + + logger := logutil.BgLogger().With(zap.String("category", "upgrading")) + jobErrs, err := ddl.ResumeAllJobsBySystem(s) + if err != nil { + logger.Warn("resume all paused jobs failed", zap.Error(err)) + } + for _, e := range jobErrs { + logger.Warn("resume the job failed", zap.Error(e)) + } + + if mgr, _ := dist_store.GetTaskManager(); mgr != nil { + ctx := kv.WithInternalSourceType(bgCtx, kv.InternalDistTask) + err := mgr.AdjustTaskOverflowConcurrency(ctx, s) + if err != nil { + log.Warn("cannot adjust task overflow concurrency", zap.Error(err)) + } + } + + ctx, cancelFunc := context.WithTimeout(bgCtx, 3*time.Second) + defer cancelFunc() + dom := domain.GetDomain(s) + err = dom.DDL().StateSyncer().UpdateGlobalState(ctx, syncer.NewStateInfo(syncer.StateNormalRunning)) + if err != nil { + logger.Error("update global state to normal failed", zap.Error(err)) + return err + } + logger.Info("update global state to normal running finished") + return nil +} + +// IsUpgradingClusterState checks whether the global state is upgrading. +func IsUpgradingClusterState(s sessionctx.Context) (bool, error) { + dom := domain.GetDomain(s) + ctx, cancelFunc := context.WithTimeout(context.Background(), 3*time.Second) + defer cancelFunc() + stateInfo, err := dom.DDL().StateSyncer().GetGlobalState(ctx) + if err != nil { + return false, err + } + + return stateInfo.State == syncer.StateUpgrading, nil +} + +func printClusterState(s sessiontypes.Session, ver int64) { + // After SupportUpgradeHTTPOpVer version, the upgrade by paused user DDL can be notified through the HTTP API. + // We check the global state see if we are upgrading by paused the user DDL. 
+ if ver >= SupportUpgradeHTTPOpVer { + isUpgradingClusterStateWithRetry(s, ver, currentBootstrapVersion, time.Duration(internalSQLTimeout)*time.Second) + } +} + +func isUpgradingClusterStateWithRetry(s sessionctx.Context, oldVer, newVer int64, timeout time.Duration) { + now := time.Now() + interval := 200 * time.Millisecond + logger := logutil.BgLogger().With(zap.String("category", "upgrading")) + for i := 0; ; i++ { + isUpgrading, err := IsUpgradingClusterState(s) + if err == nil { + logger.Info("get global state", zap.Int64("old version", oldVer), zap.Int64("latest version", newVer), zap.Bool("is upgrading state", isUpgrading)) + return + } + + if time.Since(now) >= timeout { + logger.Error("get global state failed", zap.Int64("old version", oldVer), zap.Int64("latest version", newVer), zap.Error(err)) + return + } + if i%25 == 0 { + logger.Warn("get global state failed", zap.Int64("old version", oldVer), zap.Int64("latest version", newVer), zap.Error(err)) + } + time.Sleep(interval) + } +} diff --git a/pkg/session/tidb.go b/pkg/session/tidb.go index a59529dc4669a..1bd907e58d8c4 100644 --- a/pkg/session/tidb.go +++ b/pkg/session/tidb.go @@ -229,9 +229,9 @@ func recordAbortTxnDuration(sessVars *variable.SessionVars, isInternal bool) { } func finishStmt(ctx context.Context, se *session, meetsErr error, sql sqlexec.Statement) error { - failpoint.Inject("finishStmtError", func() { - failpoint.Return(errors.New("occur an error after finishStmt")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("finishStmtError")); _err_ == nil { + return errors.New("occur an error after finishStmt") + } sessVars := se.sessionVars if !sql.IsReadOnly(sessVars) { // All the history should be added here. diff --git a/pkg/session/tidb.go__failpoint_stash__ b/pkg/session/tidb.go__failpoint_stash__ new file mode 100644 index 0000000000000..a59529dc4669a --- /dev/null +++ b/pkg/session/tidb.go__failpoint_stash__ @@ -0,0 +1,403 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright 2013 The ql Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSES/QL-LICENSE file. 
+ +package session + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl" + "github.com/pingcap/tidb/pkg/ddl/schematracker" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/executor" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + session_metrics "github.com/pingcap/tidb/pkg/session/metrics" + "github.com/pingcap/tidb/pkg/session/types" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/sqlexec" + "github.com/pingcap/tidb/pkg/util/syncutil" + "go.uber.org/zap" +) + +type domainMap struct { + mu syncutil.Mutex + domains map[string]*domain.Domain +} + +// Get or create the domain for store. +// TODO decouple domain create from it, it's more clear to create domain explicitly +// before any usage of it. +func (dm *domainMap) Get(store kv.Storage) (d *domain.Domain, err error) { + dm.mu.Lock() + defer dm.mu.Unlock() + + if store == nil { + for _, d := range dm.domains { + // return available domain if any + return d, nil + } + return nil, errors.New("can not find available domain for a nil store") + } + + key := store.UUID() + + d = dm.domains[key] + if d != nil { + return + } + + ddlLease := time.Duration(atomic.LoadInt64(&schemaLease)) + statisticLease := time.Duration(atomic.LoadInt64(&statsLease)) + planReplayerGCLease := GetPlanReplayerGCLease() + err = util.RunWithRetry(util.DefaultMaxRetries, util.RetryInterval, func() (retry bool, err1 error) { + logutil.BgLogger().Info("new domain", + zap.String("store", store.UUID()), + zap.Stringer("ddl lease", ddlLease), + zap.Stringer("stats lease", statisticLease)) + factory := createSessionFunc(store) + sysFactory := createSessionWithDomainFunc(store) + d = domain.NewDomain(store, ddlLease, statisticLease, planReplayerGCLease, factory) + + var ddlInjector func(ddl.DDL, ddl.Executor, *infoschema.InfoCache) *schematracker.Checker + if injector, ok := store.(schematracker.StorageDDLInjector); ok { + ddlInjector = injector.Injector + } + err1 = d.Init(ddlLease, sysFactory, ddlInjector) + if err1 != nil { + // If we don't clean it, there are some dirty data when retrying the function of Init. + d.Close() + logutil.BgLogger().Error("init domain failed", zap.String("category", "ddl"), + zap.Error(err1)) + } + return true, err1 + }) + if err != nil { + return nil, err + } + dm.domains[key] = d + d.SetOnClose(func() { + dm.Delete(store) + }) + + return +} + +func (dm *domainMap) Delete(store kv.Storage) { + dm.mu.Lock() + delete(dm.domains, store.UUID()) + dm.mu.Unlock() +} + +var ( + domap = &domainMap{ + domains: map[string]*domain.Domain{}, + } + // store.UUID()-> IfBootstrapped + storeBootstrapped = make(map[string]bool) + storeBootstrappedLock sync.Mutex + + // schemaLease is the time for re-updating remote schema. + // In online DDL, we must wait 2 * SchemaLease time to guarantee + // all servers get the neweset schema. 
+ // Default schema lease time is 1 second, you can change it with a proper time, + // but you must know that too little may cause badly performance degradation. + // For production, you should set a big schema lease, like 300s+. + schemaLease = int64(1 * time.Second) + + // statsLease is the time for reload stats table. + statsLease = int64(3 * time.Second) + + // planReplayerGCLease is the time for plan replayer gc. + planReplayerGCLease = int64(10 * time.Minute) +) + +// ResetStoreForWithTiKVTest is only used in the test code. +// TODO: Remove domap and storeBootstrapped. Use store.SetOption() to do it. +func ResetStoreForWithTiKVTest(store kv.Storage) { + domap.Delete(store) + unsetStoreBootstrapped(store.UUID()) +} + +func setStoreBootstrapped(storeUUID string) { + storeBootstrappedLock.Lock() + defer storeBootstrappedLock.Unlock() + storeBootstrapped[storeUUID] = true +} + +// unsetStoreBootstrapped delete store uuid from stored bootstrapped map. +// currently this function only used for test. +func unsetStoreBootstrapped(storeUUID string) { + storeBootstrappedLock.Lock() + defer storeBootstrappedLock.Unlock() + delete(storeBootstrapped, storeUUID) +} + +// SetSchemaLease changes the default schema lease time for DDL. +// This function is very dangerous, don't use it if you really know what you do. +// SetSchemaLease only affects not local storage after bootstrapped. +func SetSchemaLease(lease time.Duration) { + atomic.StoreInt64(&schemaLease, int64(lease)) +} + +// SetStatsLease changes the default stats lease time for loading stats info. +func SetStatsLease(lease time.Duration) { + atomic.StoreInt64(&statsLease, int64(lease)) +} + +// SetPlanReplayerGCLease changes the default plan repalyer gc lease time. +func SetPlanReplayerGCLease(lease time.Duration) { + atomic.StoreInt64(&planReplayerGCLease, int64(lease)) +} + +// GetPlanReplayerGCLease returns the plan replayer gc lease time. +func GetPlanReplayerGCLease() time.Duration { + return time.Duration(atomic.LoadInt64(&planReplayerGCLease)) +} + +// DisableStats4Test disables the stats for tests. +func DisableStats4Test() { + SetStatsLease(-1) +} + +// Parse parses a query string to raw ast.StmtNode. +func Parse(ctx sessionctx.Context, src string) ([]ast.StmtNode, error) { + logutil.BgLogger().Debug("compiling", zap.String("source", src)) + sessVars := ctx.GetSessionVars() + p := parser.New() + p.SetParserConfig(sessVars.BuildParserConfig()) + p.SetSQLMode(sessVars.SQLMode) + stmts, warns, err := p.ParseSQL(src, sessVars.GetParseParams()...) 
+ for _, warn := range warns { + sessVars.StmtCtx.AppendWarning(warn) + } + if err != nil { + logutil.BgLogger().Warn("compiling", + zap.String("source", src), + zap.Error(err)) + return nil, err + } + return stmts, nil +} + +func recordAbortTxnDuration(sessVars *variable.SessionVars, isInternal bool) { + duration := time.Since(sessVars.TxnCtx.CreateTime).Seconds() + if sessVars.TxnCtx.IsPessimistic { + if isInternal { + session_metrics.TransactionDurationPessimisticAbortInternal.Observe(duration) + } else { + session_metrics.TransactionDurationPessimisticAbortGeneral.Observe(duration) + } + } else { + if isInternal { + session_metrics.TransactionDurationOptimisticAbortInternal.Observe(duration) + } else { + session_metrics.TransactionDurationOptimisticAbortGeneral.Observe(duration) + } + } +} + +func finishStmt(ctx context.Context, se *session, meetsErr error, sql sqlexec.Statement) error { + failpoint.Inject("finishStmtError", func() { + failpoint.Return(errors.New("occur an error after finishStmt")) + }) + sessVars := se.sessionVars + if !sql.IsReadOnly(sessVars) { + // All the history should be added here. + if meetsErr == nil && sessVars.TxnCtx.CouldRetry { + GetHistory(se).Add(sql, sessVars.StmtCtx) + } + + // Handle the stmt commit/rollback. + if se.txn.Valid() { + if meetsErr != nil { + se.StmtRollback(ctx, false) + } else { + se.StmtCommit(ctx) + } + } + } + err := autoCommitAfterStmt(ctx, se, meetsErr, sql) + if se.txn.pending() { + // After run statement finish, txn state is still pending means the + // statement never need a Txn(), such as: + // + // set @@tidb_general_log = 1 + // set @@autocommit = 0 + // select 1 + // + // Reset txn state to invalid to dispose the pending start ts. + se.txn.changeToInvalid() + } + if err != nil { + return err + } + return checkStmtLimit(ctx, se, true) +} + +func autoCommitAfterStmt(ctx context.Context, se *session, meetsErr error, sql sqlexec.Statement) error { + isInternal := false + if internal := se.txn.GetOption(kv.RequestSourceInternal); internal != nil && internal.(bool) { + isInternal = true + } + sessVars := se.sessionVars + if meetsErr != nil { + if !sessVars.InTxn() { + logutil.BgLogger().Info("rollbackTxn called due to ddl/autocommit failure") + se.RollbackTxn(ctx) + recordAbortTxnDuration(sessVars, isInternal) + } else if se.txn.Valid() && se.txn.IsPessimistic() && exeerrors.ErrDeadlock.Equal(meetsErr) { + logutil.BgLogger().Info("rollbackTxn for deadlock", zap.Uint64("txn", se.txn.StartTS())) + se.RollbackTxn(ctx) + recordAbortTxnDuration(sessVars, isInternal) + } + return meetsErr + } + + if !sessVars.InTxn() { + if err := se.CommitTxn(ctx); err != nil { + if _, ok := sql.(*executor.ExecStmt).StmtNode.(*ast.CommitStmt); ok { + err = errors.Annotatef(err, "previous statement: %s", se.GetSessionVars().PrevStmt) + } + return err + } + return nil + } + return nil +} + +func checkStmtLimit(ctx context.Context, se *session, isFinish bool) error { + // If the user insert, insert, insert ... but never commit, TiDB would OOM. + // So we limit the statement count in a transaction here. + var err error + sessVars := se.GetSessionVars() + history := GetHistory(se) + stmtCount := history.Count() + if !isFinish { + // history stmt count + current stmt, since current stmt is not finish, it has not add to history. 
+ stmtCount++ + } + if stmtCount > int(config.GetGlobalConfig().Performance.StmtCountLimit) { + if !sessVars.BatchCommit { + se.RollbackTxn(ctx) + return errors.Errorf("statement count %d exceeds the transaction limitation, transaction has been rollback, autocommit = %t", + stmtCount, sessVars.IsAutocommit()) + } + if !isFinish { + // if the stmt is not finish execute, then just return, since some work need to be done such as StmtCommit. + return nil + } + // If the stmt is finish execute, and exceed the StmtCountLimit, and BatchCommit is true, + // then commit the current transaction and create a new transaction. + err = sessiontxn.NewTxn(ctx, se) + // The transaction does not committed yet, we need to keep it in transaction. + // The last history could not be "commit"/"rollback" statement. + // It means it is impossible to start a new transaction at the end of the transaction. + // Because after the server executed "commit"/"rollback" statement, the session is out of the transaction. + sessVars.SetInTxn(true) + } + return err +} + +// GetHistory get all stmtHistory in current txn. Exported only for test. +// If stmtHistory is nil, will create a new one for current txn. +func GetHistory(ctx sessionctx.Context) *StmtHistory { + hist, ok := ctx.GetSessionVars().TxnCtx.History.(*StmtHistory) + if ok { + return hist + } + hist = new(StmtHistory) + ctx.GetSessionVars().TxnCtx.History = hist + return hist +} + +// GetRows4Test gets all the rows from a RecordSet, only used for test. +func GetRows4Test(ctx context.Context, _ sessionctx.Context, rs sqlexec.RecordSet) ([]chunk.Row, error) { + if rs == nil { + return nil, nil + } + var rows []chunk.Row + req := rs.NewChunk(nil) + // Must reuse `req` for imitating server.(*clientConn).writeChunks + for { + err := rs.Next(ctx, req) + if err != nil { + return nil, err + } + if req.NumRows() == 0 { + break + } + + iter := chunk.NewIterator4Chunk(req.CopyConstruct()) + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + rows = append(rows, row) + } + } + return rows, nil +} + +// ResultSetToStringSlice changes the RecordSet to [][]string. +func ResultSetToStringSlice(ctx context.Context, s types.Session, rs sqlexec.RecordSet) ([][]string, error) { + rows, err := GetRows4Test(ctx, s, rs) + if err != nil { + return nil, err + } + err = rs.Close() + if err != nil { + return nil, err + } + sRows := make([][]string, len(rows)) + for i := range rows { + row := rows[i] + iRow := make([]string, row.Len()) + for j := 0; j < row.Len(); j++ { + if row.IsNull(j) { + iRow[j] = "" + } else { + d := row.GetDatum(j, &rs.Fields()[j].Column.FieldType) + iRow[j], err = d.ToString() + if err != nil { + return nil, err + } + } + } + sRows[i] = iRow + } + return sRows, nil +} + +// Session errors. +var ( + ErrForUpdateCantRetry = dbterror.ClassSession.NewStd(errno.ErrForUpdateCantRetry) +) diff --git a/pkg/session/txn.go b/pkg/session/txn.go index 0ec88c4a5667a..c051166713ed7 100644 --- a/pkg/session/txn.go +++ b/pkg/session/txn.go @@ -422,29 +422,29 @@ func (txn *LazyTxn) Commit(ctx context.Context) error { txn.updateState(txninfo.TxnCommitting) txn.mu.Unlock() - failpoint.Inject("mockSlowCommit", func(_ failpoint.Value) {}) + failpoint.Eval(_curpkg_("mockSlowCommit")) // mockCommitError8942 is used for PR #8942. 
- failpoint.Inject("mockCommitError8942", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockCommitError8942")); _err_ == nil { if val.(bool) { - failpoint.Return(kv.ErrTxnRetryable) + return kv.ErrTxnRetryable } - }) + } // mockCommitRetryForAutoIncID is used to mock an commit retry for adjustAutoIncrementDatum. - failpoint.Inject("mockCommitRetryForAutoIncID", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockCommitRetryForAutoIncID")); _err_ == nil { if val.(bool) && !mockAutoIncIDRetry() { enableMockAutoIncIDRetry() - failpoint.Return(kv.ErrTxnRetryable) + return kv.ErrTxnRetryable } - }) + } - failpoint.Inject("mockCommitRetryForAutoRandID", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockCommitRetryForAutoRandID")); _err_ == nil { if val.(bool) && needMockAutoRandIDRetry() { decreaseMockAutoRandIDRetryCount() - failpoint.Return(kv.ErrTxnRetryable) + return kv.ErrTxnRetryable } - }) + } return txn.Transaction.Commit(ctx) } @@ -456,7 +456,7 @@ func (txn *LazyTxn) Rollback() error { txn.updateState(txninfo.TxnRollingBack) txn.mu.Unlock() // mockSlowRollback is used to mock a rollback which takes a long time - failpoint.Inject("mockSlowRollback", func(_ failpoint.Value) {}) + failpoint.Eval(_curpkg_("mockSlowRollback")) return txn.Transaction.Rollback() } @@ -474,7 +474,7 @@ func (txn *LazyTxn) LockKeys(ctx context.Context, lockCtx *kv.LockCtx, keys ...k // LockKeysFunc Wrap the inner transaction's `LockKeys` to record the status func (txn *LazyTxn) LockKeysFunc(ctx context.Context, lockCtx *kv.LockCtx, fn func(), keys ...kv.Key) error { - failpoint.Inject("beforeLockKeys", func() {}) + failpoint.Eval(_curpkg_("beforeLockKeys")) t := time.Now() var originState txninfo.TxnRunningState @@ -705,7 +705,7 @@ type txnFuture struct { func (tf *txnFuture) wait() (kv.Transaction, error) { startTS, err := tf.future.Wait() - failpoint.Inject("txnFutureWait", func() {}) + failpoint.Eval(_curpkg_("txnFutureWait")) if err == nil { if tf.pipelined { return tf.store.Begin(tikv.WithTxnScope(tf.txnScope), tikv.WithStartTS(startTS), tikv.WithPipelinedMemDB()) diff --git a/pkg/session/txn.go__failpoint_stash__ b/pkg/session/txn.go__failpoint_stash__ new file mode 100644 index 0000000000000..0ec88c4a5667a --- /dev/null +++ b/pkg/session/txn.go__failpoint_stash__ @@ -0,0 +1,778 @@ +// Copyright 2018 PingCAP, Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package session + +import ( + "bytes" + "context" + "fmt" + "runtime/trace" + "strings" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/session/txninfo" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/binloginfo" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/sli" + "github.com/pingcap/tidb/pkg/util/syncutil" + "github.com/pingcap/tipb/go-binlog" + "github.com/tikv/client-go/v2/oracle" + "github.com/tikv/client-go/v2/tikv" + "go.uber.org/zap" +) + +// LazyTxn wraps kv.Transaction to provide a new kv.Transaction. +// 1. It holds all statement related modification in the buffer before flush to the txn, +// so if execute statement meets error, the txn won't be made dirty. +// 2. It's a lazy transaction, that means it's a txnFuture before StartTS() is really need. +type LazyTxn struct { + // States of a LazyTxn should be one of the followings: + // Invalid: kv.Transaction == nil && txnFuture == nil + // Pending: kv.Transaction == nil && txnFuture != nil + // Valid: kv.Transaction != nil && txnFuture == nil + kv.Transaction + txnFuture *txnFuture + + initCnt int + stagingHandle kv.StagingHandle + mutations map[int64]*binlog.TableMutation + writeSLI sli.TxnWriteThroughputSLI + + enterFairLockingOnValid bool + + // TxnInfo is added for the lock view feature, the data is frequent modified but + // rarely read (just in query select * from information_schema.tidb_trx). + // The data in this session would be query by other sessions, so Mutex is necessary. + // Since read is rare, the reader can copy-on-read to get a data snapshot. + mu struct { + syncutil.RWMutex + txninfo.TxnInfo + } + + // mark the txn enables lazy uniqueness check in pessimistic transactions. + lazyUniquenessCheckEnabled bool +} + +// GetTableInfo returns the cached index name. +func (txn *LazyTxn) GetTableInfo(id int64) *model.TableInfo { + return txn.Transaction.GetTableInfo(id) +} + +// CacheTableInfo caches the index name. +func (txn *LazyTxn) CacheTableInfo(id int64, info *model.TableInfo) { + txn.Transaction.CacheTableInfo(id, info) +} + +func (txn *LazyTxn) init() { + txn.mutations = make(map[int64]*binlog.TableMutation) + txn.mu.Lock() + defer txn.mu.Unlock() + txn.mu.TxnInfo = txninfo.TxnInfo{} +} + +// call this under lock! +func (txn *LazyTxn) updateState(state txninfo.TxnRunningState) { + if txn.mu.TxnInfo.State != state { + lastState := txn.mu.TxnInfo.State + lastStateChangeTime := txn.mu.TxnInfo.LastStateChangeTime + txn.mu.TxnInfo.State = state + txn.mu.TxnInfo.LastStateChangeTime = time.Now() + if !lastStateChangeTime.IsZero() { + hasLockLbl := !txn.mu.TxnInfo.BlockStartTime.IsZero() + txninfo.TxnDurationHistogram(lastState, hasLockLbl).Observe(time.Since(lastStateChangeTime).Seconds()) + } + txninfo.TxnStatusEnteringCounter(state).Inc() + } +} + +func (txn *LazyTxn) initStmtBuf() { + if txn.Transaction == nil { + return + } + buf := txn.Transaction.GetMemBuffer() + txn.initCnt = buf.Len() + if !txn.IsPipelined() { + txn.stagingHandle = buf.Staging() + } +} + +// countHint is estimated count of mutations. 
+func (txn *LazyTxn) countHint() int { + if txn.stagingHandle == kv.InvalidStagingHandle { + return 0 + } + return txn.Transaction.GetMemBuffer().Len() - txn.initCnt +} + +func (txn *LazyTxn) flushStmtBuf() { + if txn.stagingHandle == kv.InvalidStagingHandle { + return + } + buf := txn.Transaction.GetMemBuffer() + + if txn.lazyUniquenessCheckEnabled { + keysNeedSetPersistentPNE := kv.FindKeysInStage(buf, txn.stagingHandle, func(_ kv.Key, flags kv.KeyFlags, _ []byte) bool { + return flags.HasPresumeKeyNotExists() + }) + for _, key := range keysNeedSetPersistentPNE { + buf.UpdateFlags(key, kv.SetPreviousPresumeKeyNotExists) + } + } + + if !txn.IsPipelined() { + buf.Release(txn.stagingHandle) + } + txn.initCnt = buf.Len() +} + +func (txn *LazyTxn) cleanupStmtBuf() { + if txn.stagingHandle == kv.InvalidStagingHandle { + return + } + buf := txn.Transaction.GetMemBuffer() + if !txn.IsPipelined() { + buf.Cleanup(txn.stagingHandle) + } + txn.initCnt = buf.Len() + + txn.mu.Lock() + defer txn.mu.Unlock() + txn.mu.TxnInfo.EntriesCount = uint64(txn.Transaction.Len()) +} + +// resetTxnInfo resets the transaction info. +// Note: call it under lock! +func (txn *LazyTxn) resetTxnInfo( + startTS uint64, + state txninfo.TxnRunningState, + entriesCount uint64, + currentSQLDigest string, + allSQLDigests []string, +) { + if !txn.mu.LastStateChangeTime.IsZero() { + lastState := txn.mu.State + hasLockLbl := !txn.mu.BlockStartTime.IsZero() + txninfo.TxnDurationHistogram(lastState, hasLockLbl).Observe(time.Since(txn.mu.TxnInfo.LastStateChangeTime).Seconds()) + } + if txn.mu.TxnInfo.StartTS != 0 { + txninfo.Recorder.OnTrxEnd(&txn.mu.TxnInfo) + } + txn.mu.TxnInfo = txninfo.TxnInfo{} + txn.mu.TxnInfo.StartTS = startTS + txn.mu.TxnInfo.State = state + txninfo.TxnStatusEnteringCounter(state).Inc() + txn.mu.TxnInfo.LastStateChangeTime = time.Now() + txn.mu.TxnInfo.EntriesCount = entriesCount + + txn.mu.TxnInfo.CurrentSQLDigest = currentSQLDigest + txn.mu.TxnInfo.AllSQLDigests = allSQLDigests +} + +// Size implements the MemBuffer interface. +func (txn *LazyTxn) Size() int { + if txn.Transaction == nil { + return 0 + } + return txn.Transaction.Size() +} + +// Mem implements the MemBuffer interface. +func (txn *LazyTxn) Mem() uint64 { + if txn.Transaction == nil { + return 0 + } + return txn.Transaction.Mem() +} + +// SetMemoryFootprintChangeHook sets the hook to be called when the memory footprint of this transaction changes. +func (txn *LazyTxn) SetMemoryFootprintChangeHook(hook func(uint64)) { + if txn.Transaction == nil { + return + } + txn.Transaction.SetMemoryFootprintChangeHook(hook) +} + +// MemHookSet returns whether the memory footprint change hook is set. +func (txn *LazyTxn) MemHookSet() bool { + if txn.Transaction == nil { + return false + } + return txn.Transaction.MemHookSet() +} + +// Valid implements the kv.Transaction interface. +func (txn *LazyTxn) Valid() bool { + return txn.Transaction != nil && txn.Transaction.Valid() +} + +func (txn *LazyTxn) pending() bool { + return txn.Transaction == nil && txn.txnFuture != nil +} + +func (txn *LazyTxn) validOrPending() bool { + return txn.txnFuture != nil || txn.Valid() +} + +func (txn *LazyTxn) String() string { + if txn.Transaction != nil { + return txn.Transaction.String() + } + if txn.txnFuture != nil { + res := "txnFuture" + if txn.enterFairLockingOnValid { + res += " (pending fair locking)" + } + return res + } + return "invalid transaction" +} + +// GoString implements the "%#v" format for fmt.Printf. 
+func (txn *LazyTxn) GoString() string { + var s strings.Builder + s.WriteString("Txn{") + if txn.pending() { + s.WriteString("state=pending") + } else if txn.Valid() { + s.WriteString("state=valid") + fmt.Fprintf(&s, ", txnStartTS=%d", txn.Transaction.StartTS()) + if len(txn.mutations) > 0 { + fmt.Fprintf(&s, ", len(mutations)=%d, %#v", len(txn.mutations), txn.mutations) + } + } else { + s.WriteString("state=invalid") + } + + s.WriteString("}") + return s.String() +} + +// GetOption implements the GetOption +func (txn *LazyTxn) GetOption(opt int) any { + if txn.Transaction == nil { + if opt == kv.TxnScope { + return "" + } + return nil + } + return txn.Transaction.GetOption(opt) +} + +func (txn *LazyTxn) changeToPending(future *txnFuture) { + txn.Transaction = nil + txn.txnFuture = future +} + +func (txn *LazyTxn) changePendingToValid(ctx context.Context, sctx sessionctx.Context) error { + if txn.txnFuture == nil { + return errors.New("transaction future is not set") + } + + future := txn.txnFuture + txn.txnFuture = nil + + defer trace.StartRegion(ctx, "WaitTsoFuture").End() + t, err := future.wait() + if err != nil { + txn.Transaction = nil + return err + } + txn.Transaction = t + txn.initStmtBuf() + + if txn.enterFairLockingOnValid { + txn.enterFairLockingOnValid = false + err = txn.Transaction.StartFairLocking() + if err != nil { + return err + } + } + + // The txnInfo may already recorded the first statement (usually "begin") when it's pending, so keep them. + txn.mu.Lock() + defer txn.mu.Unlock() + txn.resetTxnInfo( + t.StartTS(), + txninfo.TxnIdle, + uint64(txn.Transaction.Len()), + txn.mu.TxnInfo.CurrentSQLDigest, + txn.mu.TxnInfo.AllSQLDigests) + + // set resource group name for kv request such as lock pessimistic keys. + kv.SetTxnResourceGroup(txn, sctx.GetSessionVars().StmtCtx.ResourceGroupName) + // overwrite entry size limit by sys var. + if entrySizeLimit := sctx.GetSessionVars().TxnEntrySizeLimit; entrySizeLimit > 0 { + txn.SetOption(kv.SizeLimits, kv.TxnSizeLimits{ + Entry: entrySizeLimit, + Total: kv.TxnTotalSizeLimit.Load(), + }) + } + + return nil +} + +func (txn *LazyTxn) changeToInvalid() { + if txn.stagingHandle != kv.InvalidStagingHandle && !txn.IsPipelined() { + txn.Transaction.GetMemBuffer().Cleanup(txn.stagingHandle) + } + txn.stagingHandle = kv.InvalidStagingHandle + txn.Transaction = nil + txn.txnFuture = nil + + txn.enterFairLockingOnValid = false + + txn.mu.Lock() + lastState := txn.mu.TxnInfo.State + lastStateChangeTime := txn.mu.TxnInfo.LastStateChangeTime + hasLock := !txn.mu.TxnInfo.BlockStartTime.IsZero() + if txn.mu.TxnInfo.StartTS != 0 { + txninfo.Recorder.OnTrxEnd(&txn.mu.TxnInfo) + } + txn.mu.TxnInfo = txninfo.TxnInfo{} + txn.mu.Unlock() + if !lastStateChangeTime.IsZero() { + txninfo.TxnDurationHistogram(lastState, hasLock).Observe(time.Since(lastStateChangeTime).Seconds()) + } +} + +func (txn *LazyTxn) onStmtStart(currentSQLDigest string) { + if len(currentSQLDigest) == 0 { + return + } + + txn.mu.Lock() + defer txn.mu.Unlock() + txn.updateState(txninfo.TxnRunning) + txn.mu.TxnInfo.CurrentSQLDigest = currentSQLDigest + // Keeps at most 50 history sqls to avoid consuming too much memory. 
+ const maxTransactionStmtHistory int = 50 + if len(txn.mu.TxnInfo.AllSQLDigests) < maxTransactionStmtHistory { + txn.mu.TxnInfo.AllSQLDigests = append(txn.mu.TxnInfo.AllSQLDigests, currentSQLDigest) + } +} + +func (txn *LazyTxn) onStmtEnd() { + txn.mu.Lock() + defer txn.mu.Unlock() + txn.mu.TxnInfo.CurrentSQLDigest = "" + txn.updateState(txninfo.TxnIdle) +} + +var hasMockAutoIncIDRetry = int64(0) + +func enableMockAutoIncIDRetry() { + atomic.StoreInt64(&hasMockAutoIncIDRetry, 1) +} + +func mockAutoIncIDRetry() bool { + return atomic.LoadInt64(&hasMockAutoIncIDRetry) == 1 +} + +var mockAutoRandIDRetryCount = int64(0) + +func needMockAutoRandIDRetry() bool { + return atomic.LoadInt64(&mockAutoRandIDRetryCount) > 0 +} + +func decreaseMockAutoRandIDRetryCount() { + atomic.AddInt64(&mockAutoRandIDRetryCount, -1) +} + +// ResetMockAutoRandIDRetryCount set the number of occurrences of +// `kv.ErrTxnRetryable` when calling TxnState.Commit(). +func ResetMockAutoRandIDRetryCount(failTimes int64) { + atomic.StoreInt64(&mockAutoRandIDRetryCount, failTimes) +} + +// Commit overrides the Transaction interface. +func (txn *LazyTxn) Commit(ctx context.Context) error { + defer txn.reset() + if len(txn.mutations) != 0 || txn.countHint() != 0 { + logutil.BgLogger().Error("the code should never run here", + zap.String("TxnState", txn.GoString()), + zap.Int("staging handler", int(txn.stagingHandle)), + zap.Int("mutations", txn.countHint()), + zap.Stack("something must be wrong")) + return errors.Trace(kv.ErrInvalidTxn) + } + + txn.mu.Lock() + txn.updateState(txninfo.TxnCommitting) + txn.mu.Unlock() + + failpoint.Inject("mockSlowCommit", func(_ failpoint.Value) {}) + + // mockCommitError8942 is used for PR #8942. + failpoint.Inject("mockCommitError8942", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(kv.ErrTxnRetryable) + } + }) + + // mockCommitRetryForAutoIncID is used to mock an commit retry for adjustAutoIncrementDatum. + failpoint.Inject("mockCommitRetryForAutoIncID", func(val failpoint.Value) { + if val.(bool) && !mockAutoIncIDRetry() { + enableMockAutoIncIDRetry() + failpoint.Return(kv.ErrTxnRetryable) + } + }) + + failpoint.Inject("mockCommitRetryForAutoRandID", func(val failpoint.Value) { + if val.(bool) && needMockAutoRandIDRetry() { + decreaseMockAutoRandIDRetryCount() + failpoint.Return(kv.ErrTxnRetryable) + } + }) + + return txn.Transaction.Commit(ctx) +} + +// Rollback overrides the Transaction interface. +func (txn *LazyTxn) Rollback() error { + defer txn.reset() + txn.mu.Lock() + txn.updateState(txninfo.TxnRollingBack) + txn.mu.Unlock() + // mockSlowRollback is used to mock a rollback which takes a long time + failpoint.Inject("mockSlowRollback", func(_ failpoint.Value) {}) + return txn.Transaction.Rollback() +} + +// RollbackMemDBToCheckpoint overrides the Transaction interface. +func (txn *LazyTxn) RollbackMemDBToCheckpoint(savepoint *tikv.MemDBCheckpoint) { + txn.flushStmtBuf() + txn.Transaction.RollbackMemDBToCheckpoint(savepoint) + txn.cleanup() +} + +// LockKeys wraps the inner transaction's `LockKeys` to record the status +func (txn *LazyTxn) LockKeys(ctx context.Context, lockCtx *kv.LockCtx, keys ...kv.Key) error { + return txn.LockKeysFunc(ctx, lockCtx, nil, keys...) 
+} + +// LockKeysFunc Wrap the inner transaction's `LockKeys` to record the status +func (txn *LazyTxn) LockKeysFunc(ctx context.Context, lockCtx *kv.LockCtx, fn func(), keys ...kv.Key) error { + failpoint.Inject("beforeLockKeys", func() {}) + t := time.Now() + + var originState txninfo.TxnRunningState + txn.mu.Lock() + originState = txn.mu.TxnInfo.State + txn.updateState(txninfo.TxnLockAcquiring) + txn.mu.TxnInfo.BlockStartTime.Valid = true + txn.mu.TxnInfo.BlockStartTime.Time = t + txn.mu.Unlock() + lockFunc := func() { + if fn != nil { + fn() + } + txn.mu.Lock() + defer txn.mu.Unlock() + txn.updateState(originState) + txn.mu.TxnInfo.BlockStartTime.Valid = false + txn.mu.TxnInfo.EntriesCount = uint64(txn.Transaction.Len()) + } + return txn.Transaction.LockKeysFunc(ctx, lockCtx, lockFunc, keys...) +} + +// StartFairLocking wraps the inner transaction to support using fair locking with lazy initialization. +func (txn *LazyTxn) StartFairLocking() error { + if txn.Valid() { + return txn.Transaction.StartFairLocking() + } else if !txn.pending() { + err := errors.New("trying to start fair locking on a transaction in invalid state") + logutil.BgLogger().Error("unexpected error when starting fair locking", zap.Error(err), zap.Stringer("txn", txn)) + return err + } + txn.enterFairLockingOnValid = true + return nil +} + +// RetryFairLocking wraps the inner transaction to support using fair locking with lazy initialization. +func (txn *LazyTxn) RetryFairLocking(ctx context.Context) error { + if txn.Valid() { + return txn.Transaction.RetryFairLocking(ctx) + } else if !txn.pending() { + err := errors.New("trying to retry fair locking on a transaction in invalid state") + logutil.BgLogger().Error("unexpected error when retrying fair locking", zap.Error(err), zap.Stringer("txnStartTS", txn)) + return err + } + return nil +} + +// CancelFairLocking wraps the inner transaction to support using fair locking with lazy initialization. +func (txn *LazyTxn) CancelFairLocking(ctx context.Context) error { + if txn.Valid() { + return txn.Transaction.CancelFairLocking(ctx) + } else if !txn.pending() { + err := errors.New("trying to cancel fair locking on a transaction in invalid state") + logutil.BgLogger().Error("unexpected error when cancelling fair locking", zap.Error(err), zap.Stringer("txnStartTS", txn)) + return err + } + if !txn.enterFairLockingOnValid { + err := errors.New("trying to cancel fair locking when it's not started") + logutil.BgLogger().Error("unexpected error when cancelling fair locking", zap.Error(err), zap.Stringer("txnStartTS", txn)) + return err + } + txn.enterFairLockingOnValid = false + return nil +} + +// DoneFairLocking wraps the inner transaction to support using fair locking with lazy initialization. +func (txn *LazyTxn) DoneFairLocking(ctx context.Context) error { + if txn.Valid() { + return txn.Transaction.DoneFairLocking(ctx) + } + if !txn.pending() { + err := errors.New("trying to cancel fair locking on a transaction in invalid state") + logutil.BgLogger().Error("unexpected error when finishing fair locking") + return err + } + if !txn.enterFairLockingOnValid { + err := errors.New("trying to finish fair locking when it's not started") + logutil.BgLogger().Error("unexpected error when finishing fair locking") + return err + } + txn.enterFairLockingOnValid = false + return nil +} + +// IsInFairLockingMode wraps the inner transaction to support using fair locking with lazy initialization. 
+func (txn *LazyTxn) IsInFairLockingMode() bool { + if txn.Valid() { + return txn.Transaction.IsInFairLockingMode() + } else if txn.pending() { + return txn.enterFairLockingOnValid + } + return false +} + +func (txn *LazyTxn) reset() { + txn.cleanup() + txn.changeToInvalid() +} + +func (txn *LazyTxn) cleanup() { + txn.cleanupStmtBuf() + txn.initStmtBuf() + for key := range txn.mutations { + delete(txn.mutations, key) + } +} + +// KeysNeedToLock returns the keys need to be locked. +func (txn *LazyTxn) KeysNeedToLock() ([]kv.Key, error) { + if txn.stagingHandle == kv.InvalidStagingHandle { + return nil, nil + } + keys := make([]kv.Key, 0, txn.countHint()) + buf := txn.Transaction.GetMemBuffer() + buf.InspectStage(txn.stagingHandle, func(k kv.Key, flags kv.KeyFlags, v []byte) { + if !KeyNeedToLock(k, v, flags) { + return + } + keys = append(keys, k) + }) + + return keys, nil +} + +// Wait converts pending txn to valid +func (txn *LazyTxn) Wait(ctx context.Context, sctx sessionctx.Context) (kv.Transaction, error) { + if !txn.validOrPending() { + return txn, errors.AddStack(kv.ErrInvalidTxn) + } + if txn.pending() { + defer func(begin time.Time) { + sctx.GetSessionVars().DurationWaitTS = time.Since(begin) + }(time.Now()) + + // Transaction is lazy initialized. + // PrepareTxnCtx is called to get a tso future, makes s.txn a pending txn, + // If Txn() is called later, wait for the future to get a valid txn. + if err := txn.changePendingToValid(ctx, sctx); err != nil { + logutil.BgLogger().Error("active transaction fail", + zap.Error(err)) + txn.cleanup() + sctx.GetSessionVars().TxnCtx.StartTS = 0 + return txn, err + } + txn.lazyUniquenessCheckEnabled = !sctx.GetSessionVars().ConstraintCheckInPlacePessimistic + } + return txn, nil +} + +// KeyNeedToLock returns true if the key need to lock. +func KeyNeedToLock(k, v []byte, flags kv.KeyFlags) bool { + isTableKey := bytes.HasPrefix(k, tablecodec.TablePrefix()) + if !isTableKey { + // meta key always need to lock. + return true + } + + // a pessimistic locking is skipped, perform the conflict check and + // constraint check (more accurately, PresumeKeyNotExist) in prewrite (or later pessimistic locking) + if flags.HasNeedConstraintCheckInPrewrite() { + return false + } + + if flags.HasPresumeKeyNotExists() { + return true + } + + // lock row key, primary key and unique index for delete operation, + if len(v) == 0 { + return flags.HasNeedLocked() || tablecodec.IsRecordKey(k) + } + + if tablecodec.IsUntouchedIndexKValue(k, v) { + return false + } + + if !tablecodec.IsIndexKey(k) { + return true + } + + if tablecodec.IsTempIndexKey(k) { + tmpVal, err := tablecodec.DecodeTempIndexValue(v) + if err != nil { + logutil.BgLogger().Warn("decode temp index value failed", zap.Error(err)) + return false + } + current := tmpVal.Current() + return current.Handle != nil || tablecodec.IndexKVIsUnique(current.Value) + } + + return tablecodec.IndexKVIsUnique(v) +} + +func getBinlogMutation(ctx sessionctx.Context, tableID int64) *binlog.TableMutation { + bin := binloginfo.GetPrewriteValue(ctx, true) + for i := range bin.Mutations { + if bin.Mutations[i].TableId == tableID { + return &bin.Mutations[i] + } + } + idx := len(bin.Mutations) + bin.Mutations = append(bin.Mutations, binlog.TableMutation{TableId: tableID}) + return &bin.Mutations[idx] +} + +func mergeToMutation(m1, m2 *binlog.TableMutation) { + m1.InsertedRows = append(m1.InsertedRows, m2.InsertedRows...) + m1.UpdatedRows = append(m1.UpdatedRows, m2.UpdatedRows...) 
+ m1.DeletedIds = append(m1.DeletedIds, m2.DeletedIds...) + m1.DeletedPks = append(m1.DeletedPks, m2.DeletedPks...) + m1.DeletedRows = append(m1.DeletedRows, m2.DeletedRows...) + m1.Sequence = append(m1.Sequence, m2.Sequence...) +} + +type txnFailFuture struct{} + +func (txnFailFuture) Wait() (uint64, error) { + return 0, errors.New("mock get timestamp fail") +} + +// txnFuture is a promise, which promises to return a txn in future. +type txnFuture struct { + future oracle.Future + store kv.Storage + txnScope string + pipelined bool +} + +func (tf *txnFuture) wait() (kv.Transaction, error) { + startTS, err := tf.future.Wait() + failpoint.Inject("txnFutureWait", func() {}) + if err == nil { + if tf.pipelined { + return tf.store.Begin(tikv.WithTxnScope(tf.txnScope), tikv.WithStartTS(startTS), tikv.WithPipelinedMemDB()) + } + return tf.store.Begin(tikv.WithTxnScope(tf.txnScope), tikv.WithStartTS(startTS)) + } else if config.GetGlobalConfig().Store == "unistore" { + return nil, err + } + + logutil.BgLogger().Warn("wait tso failed", zap.Error(err)) + // It would retry get timestamp. + if tf.pipelined { + return tf.store.Begin(tikv.WithTxnScope(tf.txnScope), tikv.WithPipelinedMemDB()) + } + return tf.store.Begin(tikv.WithTxnScope(tf.txnScope)) +} + +// HasDirtyContent checks whether there's dirty update on the given table. +// Put this function here is to avoid cycle import. +func (s *session) HasDirtyContent(tid int64) bool { + // There should not be dirty content in a txn with pipelined memdb, and it also doesn't support Iter function. + if s.txn.Transaction == nil || s.txn.Transaction.IsPipelined() { + return false + } + seekKey := tablecodec.EncodeTablePrefix(tid) + it, err := s.txn.GetMemBuffer().Iter(seekKey, nil) + terror.Log(err) + return it.Valid() && bytes.HasPrefix(it.Key(), seekKey) +} + +// StmtCommit implements the sessionctx.Context interface. +func (s *session) StmtCommit(ctx context.Context) { + defer func() { + s.txn.cleanup() + }() + + txnManager := sessiontxn.GetTxnManager(s) + err := txnManager.OnStmtCommit(ctx) + if err != nil { + logutil.Logger(ctx).Error("txnManager failed to handle OnStmtCommit", zap.Error(err)) + } + + st := &s.txn + st.flushStmtBuf() + + // Need to flush binlog. + for tableID, delta := range st.mutations { + mutation := getBinlogMutation(s, tableID) + mergeToMutation(mutation, delta) + } +} + +// StmtRollback implements the sessionctx.Context interface. +func (s *session) StmtRollback(ctx context.Context, isForPessimisticRetry bool) { + txnManager := sessiontxn.GetTxnManager(s) + err := txnManager.OnStmtRollback(ctx, isForPessimisticRetry) + if err != nil { + logutil.Logger(ctx).Error("txnManager failed to handle OnStmtRollback", zap.Error(err)) + } + s.txn.cleanup() +} + +// StmtGetMutation implements the sessionctx.Context interface. 
+func (s *session) StmtGetMutation(tableID int64) *binlog.TableMutation { + st := &s.txn + if _, ok := st.mutations[tableID]; !ok { + st.mutations[tableID] = &binlog.TableMutation{TableId: tableID} + } + return st.mutations[tableID] +} diff --git a/pkg/session/txnmanager.go b/pkg/session/txnmanager.go index 7b3b512c6acee..dd77d7f08964e 100644 --- a/pkg/session/txnmanager.go +++ b/pkg/session/txnmanager.go @@ -108,12 +108,12 @@ func (m *txnManager) GetStmtForUpdateTS() (uint64, error) { return 0, err } - failpoint.Inject("assertTxnManagerForUpdateTSEqual", func() { + if _, _err_ := failpoint.Eval(_curpkg_("assertTxnManagerForUpdateTSEqual")); _err_ == nil { sessVars := m.sctx.GetSessionVars() if txnCtxForUpdateTS := sessVars.TxnCtx.GetForUpdateTS(); sessVars.SnapshotTS == 0 && ts != txnCtxForUpdateTS { panic(fmt.Sprintf("forUpdateTS not equal %d != %d", ts, txnCtxForUpdateTS)) } - }) + } return ts, nil } diff --git a/pkg/session/txnmanager.go__failpoint_stash__ b/pkg/session/txnmanager.go__failpoint_stash__ new file mode 100644 index 0000000000000..7b3b512c6acee --- /dev/null +++ b/pkg/session/txnmanager.go__failpoint_stash__ @@ -0,0 +1,381 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "context" + "fmt" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/sessiontxn/isolation" + "github.com/pingcap/tidb/pkg/sessiontxn/staleread" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +func init() { + sessiontxn.GetTxnManager = getTxnManager +} + +func getTxnManager(sctx sessionctx.Context) sessiontxn.TxnManager { + if manager, ok := sctx.GetSessionVars().TxnManager.(sessiontxn.TxnManager); ok { + return manager + } + + manager := newTxnManager(sctx) + sctx.GetSessionVars().TxnManager = manager + return manager +} + +// txnManager implements sessiontxn.TxnManager +type txnManager struct { + sctx sessionctx.Context + + stmtNode ast.StmtNode + ctxProvider sessiontxn.TxnContextProvider + + // We always reuse the same OptimisticTxnContextProvider in one session to reduce memory allocation cost for every new txn. 
+ reservedOptimisticProviders [2]isolation.OptimisticTxnContextProvider + + // used for slow transaction logs + events []event + lastInstant time.Time + enterTxnInstant time.Time +} + +type event struct { + event string + duration time.Duration +} + +func (s event) MarshalLogObject(enc zapcore.ObjectEncoder) error { + enc.AddString("event", s.event) + enc.AddDuration("gap", s.duration) + return nil +} + +func newTxnManager(sctx sessionctx.Context) *txnManager { + return &txnManager{sctx: sctx} +} + +func (m *txnManager) GetTxnInfoSchema() infoschema.InfoSchema { + if m.ctxProvider != nil { + return m.ctxProvider.GetTxnInfoSchema() + } + + if is := m.sctx.GetDomainInfoSchema(); is != nil { + return is.(infoschema.InfoSchema) + } + + return nil +} + +func (m *txnManager) GetStmtReadTS() (uint64, error) { + if m.ctxProvider == nil { + return 0, errors.New("context provider not set") + } + return m.ctxProvider.GetStmtReadTS() +} + +func (m *txnManager) GetStmtForUpdateTS() (uint64, error) { + if m.ctxProvider == nil { + return 0, errors.New("context provider not set") + } + + ts, err := m.ctxProvider.GetStmtForUpdateTS() + if err != nil { + return 0, err + } + + failpoint.Inject("assertTxnManagerForUpdateTSEqual", func() { + sessVars := m.sctx.GetSessionVars() + if txnCtxForUpdateTS := sessVars.TxnCtx.GetForUpdateTS(); sessVars.SnapshotTS == 0 && ts != txnCtxForUpdateTS { + panic(fmt.Sprintf("forUpdateTS not equal %d != %d", ts, txnCtxForUpdateTS)) + } + }) + + return ts, nil +} + +func (m *txnManager) GetTxnScope() string { + if m.ctxProvider == nil { + return kv.GlobalTxnScope + } + return m.ctxProvider.GetTxnScope() +} + +func (m *txnManager) GetReadReplicaScope() string { + if m.ctxProvider == nil { + return kv.GlobalReplicaScope + } + return m.ctxProvider.GetReadReplicaScope() +} + +// GetSnapshotWithStmtReadTS gets snapshot with read ts +func (m *txnManager) GetSnapshotWithStmtReadTS() (kv.Snapshot, error) { + if m.ctxProvider == nil { + return nil, errors.New("context provider not set") + } + return m.ctxProvider.GetSnapshotWithStmtReadTS() +} + +// GetSnapshotWithStmtForUpdateTS gets snapshot with for update ts +func (m *txnManager) GetSnapshotWithStmtForUpdateTS() (kv.Snapshot, error) { + if m.ctxProvider == nil { + return nil, errors.New("context provider not set") + } + return m.ctxProvider.GetSnapshotWithStmtForUpdateTS() +} + +func (m *txnManager) GetContextProvider() sessiontxn.TxnContextProvider { + return m.ctxProvider +} + +func (m *txnManager) EnterNewTxn(ctx context.Context, r *sessiontxn.EnterNewTxnRequest) error { + ctxProvider, err := m.newProviderWithRequest(r) + if err != nil { + return err + } + + if err = ctxProvider.OnInitialize(ctx, r.Type); err != nil { + m.sctx.RollbackTxn(ctx) + return err + } + + m.ctxProvider = ctxProvider + if r.Type == sessiontxn.EnterNewTxnWithBeginStmt { + m.sctx.GetSessionVars().SetInTxn(true) + } + + m.resetEvents() + m.recordEvent("enter txn") + return nil +} + +func (m *txnManager) OnTxnEnd() { + m.ctxProvider = nil + m.stmtNode = nil + + m.events = append(m.events, event{event: "txn end", duration: time.Since(m.lastInstant)}) + + duration := time.Since(m.enterTxnInstant) + threshold := m.sctx.GetSessionVars().SlowTxnThreshold + if threshold > 0 && uint64(duration.Milliseconds()) >= threshold { + logutil.BgLogger().Info( + "slow transaction", zap.Duration("duration", duration), + zap.Uint64("conn", m.sctx.GetSessionVars().ConnectionID), + zap.Uint64("txnStartTS", m.sctx.GetSessionVars().TxnCtx.StartTS), + zap.Objects("events", 
m.events), + ) + } + + m.lastInstant = time.Now() +} + +func (m *txnManager) GetCurrentStmt() ast.StmtNode { + return m.stmtNode +} + +// OnStmtStart is the hook that should be called when a new statement started +func (m *txnManager) OnStmtStart(ctx context.Context, node ast.StmtNode) error { + m.stmtNode = node + + if m.ctxProvider == nil { + return errors.New("context provider not set") + } + + var sql string + if node != nil { + sql = node.OriginalText() + sql = parser.Normalize(sql, m.sctx.GetSessionVars().EnableRedactLog) + } + m.recordEvent(sql) + return m.ctxProvider.OnStmtStart(ctx, m.stmtNode) +} + +// OnStmtEnd implements the TxnManager interface +func (m *txnManager) OnStmtEnd() { + m.recordEvent("stmt end") +} + +// OnPessimisticStmtStart is the hook that should be called when starts handling a pessimistic DML or +// a pessimistic select-for-update statements. +func (m *txnManager) OnPessimisticStmtStart(ctx context.Context) error { + if m.ctxProvider == nil { + return errors.New("context provider not set") + } + return m.ctxProvider.OnPessimisticStmtStart(ctx) +} + +// OnPessimisticStmtEnd is the hook that should be called when finishes handling a pessimistic DML or +// select-for-update statement. +func (m *txnManager) OnPessimisticStmtEnd(ctx context.Context, isSuccessful bool) error { + if m.ctxProvider == nil { + return errors.New("context provider not set") + } + return m.ctxProvider.OnPessimisticStmtEnd(ctx, isSuccessful) +} + +// OnStmtErrorForNextAction is the hook that should be called when a new statement get an error +func (m *txnManager) OnStmtErrorForNextAction(ctx context.Context, point sessiontxn.StmtErrorHandlePoint, err error) (sessiontxn.StmtErrorAction, error) { + if m.ctxProvider == nil { + return sessiontxn.NoIdea() + } + return m.ctxProvider.OnStmtErrorForNextAction(ctx, point, err) +} + +// ActivateTxn decides to activate txn according to the parameter `active` +func (m *txnManager) ActivateTxn() (kv.Transaction, error) { + if m.ctxProvider == nil { + return nil, errors.AddStack(kv.ErrInvalidTxn) + } + return m.ctxProvider.ActivateTxn() +} + +// OnStmtRetry is the hook that should be called when a statement retry +func (m *txnManager) OnStmtRetry(ctx context.Context) error { + if m.ctxProvider == nil { + return errors.New("context provider not set") + } + return m.ctxProvider.OnStmtRetry(ctx) +} + +// OnStmtCommit is the hook that should be called when a statement is executed successfully. +func (m *txnManager) OnStmtCommit(ctx context.Context) error { + if m.ctxProvider == nil { + return errors.New("context provider not set") + } + m.recordEvent("stmt commit") + return m.ctxProvider.OnStmtCommit(ctx) +} + +func (m *txnManager) recordEvent(eventName string) { + if m.events == nil { + m.resetEvents() + } + m.events = append(m.events, event{event: eventName, duration: time.Since(m.lastInstant)}) + m.lastInstant = time.Now() +} + +func (m *txnManager) resetEvents() { + if m.events == nil { + m.events = make([]event, 0, 10) + } else { + m.events = m.events[:0] + } + m.enterTxnInstant = time.Now() +} + +// OnStmtRollback is the hook that should be called when a statement fails to execute. +func (m *txnManager) OnStmtRollback(ctx context.Context, isForPessimisticRetry bool) error { + if m.ctxProvider == nil { + return errors.New("context provider not set") + } + m.recordEvent("stmt rollback") + return m.ctxProvider.OnStmtRollback(ctx, isForPessimisticRetry) +} + +// OnLocalTemporaryTableCreated is the hook that should be called when a temporary table created. 
+// The provider will update its state then +func (m *txnManager) OnLocalTemporaryTableCreated() { + if m.ctxProvider != nil { + m.ctxProvider.OnLocalTemporaryTableCreated() + } +} + +func (m *txnManager) AdviseWarmup() error { + if m.sctx.GetSessionVars().BulkDMLEnabled { + // We don't want to validate the feasibility of pipelined DML here. + // We'd like to check it later after optimization so that optimizer info can be used. + // And it does not make much sense to save such a little time for pipelined-dml as it's + // for bulk processing. + return nil + } + + if m.ctxProvider != nil { + return m.ctxProvider.AdviseWarmup() + } + return nil +} + +// AdviseOptimizeWithPlan providers optimization according to the plan +func (m *txnManager) AdviseOptimizeWithPlan(plan any) error { + if m.ctxProvider != nil { + return m.ctxProvider.AdviseOptimizeWithPlan(plan) + } + return nil +} + +func (m *txnManager) newProviderWithRequest(r *sessiontxn.EnterNewTxnRequest) (sessiontxn.TxnContextProvider, error) { + if r.Provider != nil { + return r.Provider, nil + } + + if r.StaleReadTS > 0 { + m.sctx.GetSessionVars().TxnCtx.StaleReadTs = r.StaleReadTS + return staleread.NewStalenessTxnContextProvider(m.sctx, r.StaleReadTS, nil), nil + } + + sessVars := m.sctx.GetSessionVars() + + txnMode := r.TxnMode + if txnMode == "" { + txnMode = sessVars.TxnMode + } + + switch txnMode { + case "", ast.Optimistic: + // When txnMode is 'OPTIMISTIC' or '', the transaction should be optimistic + provider := &m.reservedOptimisticProviders[0] + if old, ok := m.ctxProvider.(*isolation.OptimisticTxnContextProvider); ok && old == provider { + // We should make sure the new provider is not the same with the old one + provider = &m.reservedOptimisticProviders[1] + } + provider.ResetForNewTxn(m.sctx, r.CausalConsistencyOnly) + return provider, nil + case ast.Pessimistic: + // When txnMode is 'PESSIMISTIC', the provider should be determined by the isolation level + switch sessVars.IsolationLevelForNewTxn() { + case ast.ReadCommitted: + return isolation.NewPessimisticRCTxnContextProvider(m.sctx, r.CausalConsistencyOnly), nil + case ast.Serializable: + // The Oracle serializable isolation is actually SI in pessimistic mode. + // Do not update ForUpdateTS when the user is using the Serializable isolation level. + // It can be used temporarily on the few occasions when an Oracle-like isolation level is needed. + // Support for this does not mean that TiDB supports serializable isolation of MySQL. + // tidb_skip_isolation_level_check should still be disabled by default. + return isolation.NewPessimisticSerializableTxnContextProvider(m.sctx, r.CausalConsistencyOnly), nil + default: + // We use Repeatable read for all other cases. + return isolation.NewPessimisticRRTxnContextProvider(m.sctx, r.CausalConsistencyOnly), nil + } + default: + return nil, errors.Errorf("Invalid txn mode '%s'", txnMode) + } +} + +// SetOptionsBeforeCommit sets options before commit. 
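// It simply delegates to the active transaction context provider.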
+func (m *txnManager) SetOptionsBeforeCommit(txn kv.Transaction, commitTSChecker func(uint64) bool) error { + return m.ctxProvider.SetOptionsBeforeCommit(txn, commitTSChecker) +} diff --git a/pkg/sessionctx/sessionstates/binding__failpoint_binding__.go b/pkg/sessionctx/sessionstates/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..46a20b98f439c --- /dev/null +++ b/pkg/sessionctx/sessionstates/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package sessionstates + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/sessionctx/sessionstates/session_token.go b/pkg/sessionctx/sessionstates/session_token.go index 5702d4966f118..b141c0c605e8f 100644 --- a/pkg/sessionctx/sessionstates/session_token.go +++ b/pkg/sessionctx/sessionstates/session_token.go @@ -327,10 +327,10 @@ func (sc *signingCert) checkSignature(content, signature []byte) error { func getNow() time.Time { now := time.Now() - failpoint.Inject("mockNowOffset", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockNowOffset")); _err_ == nil { if s := uint64(val.(int)); s != 0 { now = now.Add(time.Duration(s)) } - }) + } return now } diff --git a/pkg/sessionctx/sessionstates/session_token.go__failpoint_stash__ b/pkg/sessionctx/sessionstates/session_token.go__failpoint_stash__ new file mode 100644 index 0000000000000..5702d4966f118 --- /dev/null +++ b/pkg/sessionctx/sessionstates/session_token.go__failpoint_stash__ @@ -0,0 +1,336 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sessionstates + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/sha512" + "crypto/tls" + "crypto/x509" + "encoding/json" + "strings" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.uber.org/zap" +) + +// Token-based authentication is used in session migration. We don't use typical authentication because the proxy +// cannot store the user passwords for security issues. +// +// The process of token-based authentication: +// 1. Before migrating the session, the proxy requires a token from server A. +// 2. Server A generates a token and signs it with a private key defined in the certificate. +// 3. The proxy authenticates with server B and sends the signed token as the password. +// 4. Server B checks the signature with the public key defined in the certificate and then verifies the token. +// +// The highlight is that the certificates on all the servers should be the same all the time. +// However, the certificates should be rotated periodically. 
Just in case of using different certificates to +// sign and check, a server should keep the old certificate for a while. A server will try both +// the 2 certificates to check the signature. +const ( + // A token needs a lifetime to avoid brute force attack. + tokenLifetime = time.Minute + // LoadCertInterval is the interval of reloading the certificate. The certificate should be rotated periodically. + LoadCertInterval = 10 * time.Minute + // After a certificate is replaced, it's still valid for oldCertValidTime. + // oldCertValidTime must be a little longer than LoadCertInterval, because the previous server may + // sign with the old cert but the new server checks with the new cert. + // - server A loads the old cert at 00:00:00. + // - the cert is rotated at 00:00:01 on all servers. + // - server B loads the new cert at 00:00:02. + // - server A signs token with the old cert at 00:10:00. + // - server B reloads the same new cert again at 00:10:01, and it has 3 certs now. + // - server B receives the token at 00:10:02, so the old cert should be valid for more than 10m after replacement. + oldCertValidTime = 15 * time.Minute +) + +// SessionToken represents the token used to authenticate with the new server. +type SessionToken struct { + Username string `json:"username"` + SignTime time.Time `json:"sign-time"` + ExpireTime time.Time `json:"expire-time"` + Signature []byte `json:"signature,omitempty"` +} + +// CreateSessionToken creates a token for the proxy. +func CreateSessionToken(username string) (*SessionToken, error) { + now := getNow() + token := &SessionToken{ + Username: username, + SignTime: now, + ExpireTime: now.Add(tokenLifetime), + } + tokenBytes, err := json.Marshal(token) + if err != nil { + return nil, errors.Trace(err) + } + if token.Signature, err = globalSigningCert.sign(tokenBytes); err != nil { + return nil, ErrCannotMigrateSession.GenWithStackByArgs(err.Error()) + } + return token, nil +} + +// ValidateSessionToken validates the token sent from the proxy. +func ValidateSessionToken(tokenBytes []byte, username string) (err error) { + var token SessionToken + if err = json.Unmarshal(tokenBytes, &token); err != nil { + return errors.Trace(err) + } + signature := token.Signature + // Clear the signature and marshal it again to get the original content. + token.Signature = nil + if tokenBytes, err = json.Marshal(token); err != nil { + return errors.Trace(err) + } + if err = globalSigningCert.checkSignature(tokenBytes, signature); err != nil { + return ErrCannotMigrateSession.GenWithStackByArgs(err.Error()) + } + now := getNow() + if now.After(token.ExpireTime) { + return ErrCannotMigrateSession.GenWithStackByArgs("token expired", token.ExpireTime.String()) + } + // An attacker may forge a very long lifetime to brute force, so we also need to check `SignTime`. + // However, we need to be tolerant of these problems: + // - The `tokenLifetime` may change between TiDB versions, so we can't check `token.SignTime.Add(tokenLifetime).Equal(token.ExpireTime)` + // - There may exist time bias between TiDB instances, so we can't check `now.After(token.SignTime)` + if token.SignTime.Add(tokenLifetime).Before(now) { + return ErrCannotMigrateSession.GenWithStackByArgs("token lifetime is too long", token.SignTime.String()) + } + if !strings.EqualFold(username, token.Username) { + return ErrCannotMigrateSession.GenWithStackByArgs("username does not match", username, token.Username) + } + return nil +} + +// SetKeyPath sets the path of key.pem and force load the certificate again. 
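// A minimal round-trip sketch, assuming illustrative cert/key paths and the username "u1";
// a token created on one server validates on any server that shares the same signing certificate:
func exampleSessionTokenRoundTrip() error {
	SetCertPath("/path/to/cert.pem") // assumed path
	SetKeyPath("/path/to/key.pem")   // assumed path
	token, err := CreateSessionToken("u1")
	if err != nil {
		return err
	}
	// The proxy sends the marshaled token (including its signature) as the password when it reconnects.
	tokenBytes, err := json.Marshal(token)
	if err != nil {
		return err
	}
	// The target server checks the signature, the token lifetime and the username.
	return ValidateSessionToken(tokenBytes, "u1")
}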
+func SetKeyPath(keyPath string) { + globalSigningCert.setKeyPath(keyPath) +} + +// SetCertPath sets the path of key.pem and force load the certificate again. +func SetCertPath(certPath string) { + globalSigningCert.setCertPath(certPath) +} + +// ReloadSigningCert is used to load the certificate periodically in a separate goroutine. +// It's impossible to know when the old certificate should expire without this goroutine: +// - If the certificate is rotated a minute ago, the old certificate should be still valid for a while. +// - If the certificate is rotated a month ago, the old certificate should expire for safety. +func ReloadSigningCert() { + globalSigningCert.lockAndLoad() +} + +var globalSigningCert signingCert + +// signingCert represents the parsed certificate used for token-based auth. +type signingCert struct { + sync.RWMutex + certPath string + keyPath string + // The cert file may happen to be rotated between signing and checking, so we keep the old cert for a while. + // certs contain all the certificates that are not expired yet. + certs []*certInfo +} + +type certInfo struct { + cert *x509.Certificate + privKey crypto.PrivateKey + expireTime time.Time +} + +func (sc *signingCert) setCertPath(certPath string) { + sc.Lock() + if certPath != sc.certPath { + sc.certPath = certPath + // It may fail expectedly because the key path is not set yet. + sc.checkAndLoadCert() + } + sc.Unlock() +} + +func (sc *signingCert) setKeyPath(keyPath string) { + sc.Lock() + if keyPath != sc.keyPath { + sc.keyPath = keyPath + // It may fail expectedly because the cert path is not set yet. + sc.checkAndLoadCert() + } + sc.Unlock() +} + +func (sc *signingCert) lockAndLoad() { + sc.Lock() + sc.checkAndLoadCert() + sc.Unlock() +} + +func (sc *signingCert) checkAndLoadCert() { + if len(sc.certPath) == 0 || len(sc.keyPath) == 0 { + return + } + if err := sc.loadCert(); err != nil { + logutil.BgLogger().Warn("loading signing cert failed", + zap.String("cert path", sc.certPath), + zap.String("key path", sc.keyPath), + zap.Error(err)) + } else { + logutil.BgLogger().Info("signing cert is loaded successfully", + zap.String("cert path", sc.certPath), + zap.String("key path", sc.keyPath)) + } +} + +// loadCert loads the cert and adds it into the cert list. +func (sc *signingCert) loadCert() error { + tlsCert, err := tls.LoadX509KeyPair(sc.certPath, sc.keyPath) + if err != nil { + return errors.Wrapf(err, "load x509 failed, cert path: %s, key path: %s", sc.certPath, sc.keyPath) + } + var cert *x509.Certificate + if tlsCert.Leaf != nil { + cert = tlsCert.Leaf + } else { + if cert, err = x509.ParseCertificate(tlsCert.Certificate[0]); err != nil { + return errors.Wrapf(err, "parse x509 cert failed, cert path: %s, key path: %s", sc.certPath, sc.keyPath) + } + } + + // Rotate certs. Ensure that the expireTime of certs is in descending order. + now := getNow() + newCerts := make([]*certInfo, 0, len(sc.certs)+1) + newCerts = append(newCerts, &certInfo{ + cert: cert, + privKey: tlsCert.PrivateKey, + expireTime: now.Add(LoadCertInterval + oldCertValidTime), + }) + for i := 0; i < len(sc.certs); i++ { + // Discard the certs that are already expired. + if now.After(sc.certs[i].expireTime) { + break + } + newCerts = append(newCerts, sc.certs[i]) + } + sc.certs = newCerts + return nil +} + +// sign generates a signature with the content and the private key. 
+func (sc *signingCert) sign(content []byte) ([]byte, error) { + var ( + signer crypto.Signer + opts crypto.SignerOpts + ) + sc.RLock() + defer sc.RUnlock() + if len(sc.certs) == 0 { + return nil, errors.New("no certificate or key file to sign the data") + } + // Always sign the token with the latest cert. + certInfo := sc.certs[0] + switch key := certInfo.privKey.(type) { + case ed25519.PrivateKey: + signer = key + opts = crypto.Hash(0) + case *rsa.PrivateKey: + signer = key + var pssHash crypto.Hash + switch certInfo.cert.SignatureAlgorithm { + case x509.SHA256WithRSAPSS: + pssHash = crypto.SHA256 + case x509.SHA384WithRSAPSS: + pssHash = crypto.SHA384 + case x509.SHA512WithRSAPSS: + pssHash = crypto.SHA512 + } + if pssHash != 0 { + h := pssHash.New() + h.Write(content) + content = h.Sum(nil) + opts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: pssHash} + break + } + switch certInfo.cert.SignatureAlgorithm { + case x509.SHA256WithRSA: + hashed := sha256.Sum256(content) + content = hashed[:] + opts = crypto.SHA256 + case x509.SHA384WithRSA: + hashed := sha512.Sum384(content) + content = hashed[:] + opts = crypto.SHA384 + case x509.SHA512WithRSA: + hashed := sha512.Sum512(content) + content = hashed[:] + opts = crypto.SHA512 + default: + return nil, errors.Errorf("not supported private key type '%s' for signing", certInfo.cert.SignatureAlgorithm.String()) + } + case *ecdsa.PrivateKey: + signer = key + default: + return nil, errors.Errorf("not supported private key type '%s' for signing", certInfo.cert.SignatureAlgorithm.String()) + } + return signer.Sign(rand.Reader, content, opts) +} + +// checkSignature checks the signature and the content. +func (sc *signingCert) checkSignature(content, signature []byte) error { + sc.RLock() + defer sc.RUnlock() + now := getNow() + var err error + for _, certInfo := range sc.certs { + // The expireTime is in descending order. So if the first one is expired, we skip the following. + if now.After(certInfo.expireTime) { + break + } + switch certInfo.privKey.(type) { + // ESDSA is special: `PrivateKey.Sign` doesn't match with `Certificate.CheckSignature`. 
+ case *ecdsa.PrivateKey: + if !ecdsa.VerifyASN1(certInfo.cert.PublicKey.(*ecdsa.PublicKey), content, signature) { + err = errors.New("x509: ECDSA verification failure") + } + default: + err = certInfo.cert.CheckSignature(certInfo.cert.SignatureAlgorithm, content, signature) + } + if err == nil { + return nil + } + } + // no certs (possible) or all certs are expired (impossible) + if err == nil { + return errors.Errorf("no valid certificate to check the signature, cached certificates: %d", len(sc.certs)) + } + return err +} + +func getNow() time.Time { + now := time.Now() + failpoint.Inject("mockNowOffset", func(val failpoint.Value) { + if s := uint64(val.(int)); s != 0 { + now = now.Add(time.Duration(s)) + } + }) + return now +} diff --git a/pkg/sessiontxn/isolation/base.go b/pkg/sessiontxn/isolation/base.go index 7ae7502f00534..6eb34820c5003 100644 --- a/pkg/sessiontxn/isolation/base.go +++ b/pkg/sessiontxn/isolation/base.go @@ -643,9 +643,9 @@ func newOracleFuture(ctx context.Context, sctx sessionctx.Context, scope string) r, ctx := tracing.StartRegionEx(ctx, "isolation.newOracleFuture") defer r.End() - failpoint.Inject("requestTsoFromPD", func() { + if _, _err_ := failpoint.Eval(_curpkg_("requestTsoFromPD")); _err_ == nil { sessiontxn.TsoRequestCountInc(sctx) - }) + } oracleStore := sctx.GetStore().GetOracle() option := &oracle.Option{TxnScope: scope} diff --git a/pkg/sessiontxn/isolation/base.go__failpoint_stash__ b/pkg/sessiontxn/isolation/base.go__failpoint_stash__ new file mode 100644 index 0000000000000..7ae7502f00534 --- /dev/null +++ b/pkg/sessiontxn/isolation/base.go__failpoint_stash__ @@ -0,0 +1,747 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package isolation + +import ( + "context" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/domain" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/binloginfo" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/sessiontxn/internal" + "github.com/pingcap/tidb/pkg/sessiontxn/staleread" + "github.com/pingcap/tidb/pkg/store/driver/txn" + "github.com/pingcap/tidb/pkg/table/temptable" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/util/tableutil" + "github.com/pingcap/tidb/pkg/util/tracing" + "github.com/pingcap/tipb/go-binlog" + tikvstore "github.com/tikv/client-go/v2/kv" + "github.com/tikv/client-go/v2/oracle" +) + +// baseTxnContextProvider is a base class for the transaction context providers that implement `TxnContextProvider` in different isolation. +// It provides some common functions below: +// - Provides a default `OnInitialize` method to initialize its inner state. 
+// - Provides some methods like `activateTxn` and `prepareTxn` to manage the inner transaction. +// - Provides default methods `GetTxnInfoSchema`, `GetStmtReadTS` and `GetStmtForUpdateTS` and return the snapshot information schema or ts when `tidb_snapshot` is set. +// - Provides other default methods like `Advise`, `OnStmtStart`, `OnStmtRetry` and `OnStmtErrorForNextAction` +// +// The subclass can set some inner property of `baseTxnContextProvider` when it is constructed. +// For example, `getStmtReadTSFunc` and `getStmtForUpdateTSFunc` should be set, and they will be called when `GetStmtReadTS` +// or `GetStmtForUpdate` to get the timestamp that should be used by the corresponding isolation level. +type baseTxnContextProvider struct { + // States that should be initialized when baseTxnContextProvider is created and should not be changed after that + sctx sessionctx.Context + causalConsistencyOnly bool + onInitializeTxnCtx func(*variable.TransactionContext) + onTxnActiveFunc func(kv.Transaction, sessiontxn.EnterNewTxnType) + getStmtReadTSFunc func() (uint64, error) + getStmtForUpdateTSFunc func() (uint64, error) + + // Runtime states + ctx context.Context + infoSchema infoschema.InfoSchema + txn kv.Transaction + isTxnPrepared bool + enterNewTxnType sessiontxn.EnterNewTxnType + // constStartTS is only used by point get max ts optimization currently. + // When constStartTS != 0, we use constStartTS directly without fetching it from tso. + // To save the cpu cycles `PrepareTSFuture` will also not be called when warmup (postpone to activate txn). + constStartTS uint64 +} + +// OnInitialize is the hook that should be called when enter a new txn with this provider +func (p *baseTxnContextProvider) OnInitialize(ctx context.Context, tp sessiontxn.EnterNewTxnType) (err error) { + if p.getStmtReadTSFunc == nil || p.getStmtForUpdateTSFunc == nil { + return errors.New("ts functions should not be nil") + } + + p.ctx = ctx + sessVars := p.sctx.GetSessionVars() + activeNow := true + switch tp { + case sessiontxn.EnterNewTxnDefault: + // As we will enter a new txn, we need to commit the old txn if it's still valid. + // There are two main steps here to enter a new txn: + // 1. prepareTxnWithOracleTS + // 2. ActivateTxn + if err := internal.CommitBeforeEnterNewTxn(p.ctx, p.sctx); err != nil { + return err + } + if err := p.prepareTxnWithOracleTS(); err != nil { + return err + } + case sessiontxn.EnterNewTxnWithBeginStmt: + if !canReuseTxnWhenExplicitBegin(p.sctx) { + // As we will enter a new txn, we need to commit the old txn if it's still valid. + // There are two main steps here to enter a new txn: + // 1. prepareTxnWithOracleTS + // 2. 
ActivateTxn + if err := internal.CommitBeforeEnterNewTxn(p.ctx, p.sctx); err != nil { + return err + } + if err := p.prepareTxnWithOracleTS(); err != nil { + return err + } + } + sessVars.SetInTxn(true) + case sessiontxn.EnterNewTxnBeforeStmt: + activeNow = false + default: + return errors.Errorf("Unsupported type: %v", tp) + } + + p.enterNewTxnType = tp + p.infoSchema = p.sctx.GetDomainInfoSchema().(infoschema.InfoSchema) + txnCtx := &variable.TransactionContext{ + TxnCtxNoNeedToRestore: variable.TxnCtxNoNeedToRestore{ + CreateTime: time.Now(), + InfoSchema: p.infoSchema, + TxnScope: sessVars.CheckAndGetTxnScope(), + }, + } + if p.onInitializeTxnCtx != nil { + p.onInitializeTxnCtx(txnCtx) + } + sessVars.TxnCtxMu.Lock() + sessVars.TxnCtx = txnCtx + sessVars.TxnCtxMu.Unlock() + if variable.EnableMDL.Load() { + sessVars.TxnCtx.EnableMDL = true + } + + txn, err := p.sctx.Txn(false) + if err != nil { + return err + } + p.isTxnPrepared = txn.Valid() || p.sctx.GetPreparedTxnFuture() != nil + if activeNow { + _, err = p.ActivateTxn() + } + + return err +} + +// GetTxnInfoSchema returns the information schema used by txn +func (p *baseTxnContextProvider) GetTxnInfoSchema() infoschema.InfoSchema { + if is := p.sctx.GetSessionVars().SnapshotInfoschema; is != nil { + return is.(infoschema.InfoSchema) + } + if _, ok := p.infoSchema.(*infoschema.SessionExtendedInfoSchema); !ok { + p.infoSchema = &infoschema.SessionExtendedInfoSchema{ + InfoSchema: p.infoSchema, + } + p.sctx.GetSessionVars().TxnCtx.InfoSchema = p.infoSchema + } + return p.infoSchema +} + +// GetTxnScope returns the current txn scope +func (p *baseTxnContextProvider) GetTxnScope() string { + return p.sctx.GetSessionVars().TxnCtx.TxnScope +} + +// GetReadReplicaScope returns the read replica scope +func (p *baseTxnContextProvider) GetReadReplicaScope() string { + if txnScope := p.GetTxnScope(); txnScope != kv.GlobalTxnScope && txnScope != "" { + // In local txn, we should use txnScope as the readReplicaScope + return txnScope + } + + if p.sctx.GetSessionVars().GetReplicaRead().IsClosestRead() { + // If closest read is set, we should use the scope where instance located. + return config.GetTxnScopeFromConfig() + } + + // When it is not local txn or closet read, we should use global scope + return kv.GlobalReplicaScope +} + +// GetStmtReadTS returns the read timestamp used by select statement (not for select ... for update) +func (p *baseTxnContextProvider) GetStmtReadTS() (uint64, error) { + if _, err := p.ActivateTxn(); err != nil { + return 0, err + } + + if snapshotTS := p.sctx.GetSessionVars().SnapshotTS; snapshotTS != 0 { + return snapshotTS, nil + } + return p.getStmtReadTSFunc() +} + +// GetStmtForUpdateTS returns the read timestamp used by update/insert/delete or select ... for update +func (p *baseTxnContextProvider) GetStmtForUpdateTS() (uint64, error) { + if _, err := p.ActivateTxn(); err != nil { + return 0, err + } + + if snapshotTS := p.sctx.GetSessionVars().SnapshotTS; snapshotTS != 0 { + return snapshotTS, nil + } + return p.getStmtForUpdateTSFunc() +} + +// OnStmtStart is the hook that should be called when a new statement started +func (p *baseTxnContextProvider) OnStmtStart(ctx context.Context, _ ast.StmtNode) error { + p.ctx = ctx + return nil +} + +// OnPessimisticStmtStart is the hook that should be called when starts handling a pessimistic DML or +// a pessimistic select-for-update statements. 
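// The base provider treats this as a no-op; pessimistic providers override it, for example to start fair locking.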
+func (p *baseTxnContextProvider) OnPessimisticStmtStart(_ context.Context) error { + return nil +} + +// OnPessimisticStmtEnd is the hook that should be called when finishes handling a pessimistic DML or +// select-for-update statement. +func (p *baseTxnContextProvider) OnPessimisticStmtEnd(_ context.Context, _ bool) error { + return nil +} + +// OnStmtRetry is the hook that should be called when a statement is retried internally. +func (p *baseTxnContextProvider) OnStmtRetry(ctx context.Context) error { + p.ctx = ctx + p.sctx.GetSessionVars().TxnCtx.CurrentStmtPessimisticLockCache = nil + return nil +} + +// OnStmtCommit is the hook that should be called when a statement is executed successfully. +func (p *baseTxnContextProvider) OnStmtCommit(_ context.Context) error { + return nil +} + +// OnStmtRollback is the hook that should be called when a statement fails to execute. +func (p *baseTxnContextProvider) OnStmtRollback(_ context.Context, _ bool) error { + return nil +} + +// OnLocalTemporaryTableCreated is the hook that should be called when a local temporary table created. +func (p *baseTxnContextProvider) OnLocalTemporaryTableCreated() { + p.infoSchema = temptable.AttachLocalTemporaryTableInfoSchema(p.sctx, p.infoSchema) + p.sctx.GetSessionVars().TxnCtx.InfoSchema = p.infoSchema + if p.txn != nil && p.txn.Valid() { + if interceptor := temptable.SessionSnapshotInterceptor(p.sctx, p.infoSchema); interceptor != nil { + p.txn.SetOption(kv.SnapInterceptor, interceptor) + } + } +} + +// OnStmtErrorForNextAction is the hook that should be called when a new statement get an error +func (p *baseTxnContextProvider) OnStmtErrorForNextAction(ctx context.Context, point sessiontxn.StmtErrorHandlePoint, err error) (sessiontxn.StmtErrorAction, error) { + switch point { + case sessiontxn.StmtErrAfterPessimisticLock: + // for pessimistic lock error, return the error by default + return sessiontxn.ErrorAction(err) + default: + return sessiontxn.NoIdea() + } +} + +func (p *baseTxnContextProvider) getTxnStartTS() (uint64, error) { + txn, err := p.ActivateTxn() + if err != nil { + return 0, err + } + return txn.StartTS(), nil +} + +// ActivateTxn activates the transaction and set the relevant context variables. +func (p *baseTxnContextProvider) ActivateTxn() (kv.Transaction, error) { + if p.txn != nil { + return p.txn, nil + } + + if err := p.prepareTxn(); err != nil { + return nil, err + } + + if p.constStartTS != 0 { + if err := p.replaceTxnTsFuture(sessiontxn.ConstantFuture(p.constStartTS)); err != nil { + return nil, err + } + } + + txnFuture := p.sctx.GetPreparedTxnFuture() + txn, err := txnFuture.Wait(p.ctx, p.sctx) + if err != nil { + return nil, err + } + + sessVars := p.sctx.GetSessionVars() + sessVars.TxnCtxMu.Lock() + sessVars.TxnCtx.StartTS = txn.StartTS() + sessVars.GetRowIDShardGenerator().SetShardStep(int(sessVars.ShardAllocateStep)) + sessVars.TxnCtxMu.Unlock() + if sessVars.MemDBFootprint != nil { + sessVars.MemDBFootprint.Detach() + } + sessVars.MemDBFootprint = nil + + if p.enterNewTxnType == sessiontxn.EnterNewTxnBeforeStmt && !sessVars.IsAutocommit() && sessVars.SnapshotTS == 0 { + sessVars.SetInTxn(true) + } + + txn.SetVars(sessVars.KVVars) + + p.SetOptionsOnTxnActive(txn) + + if p.onTxnActiveFunc != nil { + p.onTxnActiveFunc(txn, p.enterNewTxnType) + } + + p.txn = txn + return txn, nil +} + +// prepareTxn prepares txn with an oracle ts future. If the snapshotTS is set, +// the txn is prepared with it. 
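// Otherwise it asynchronously requests a ts future from the oracle, scoped to the transaction's txn scope.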
+func (p *baseTxnContextProvider) prepareTxn() error { + if p.isTxnPrepared { + return nil + } + + if snapshotTS := p.sctx.GetSessionVars().SnapshotTS; snapshotTS != 0 { + return p.replaceTxnTsFuture(sessiontxn.ConstantFuture(snapshotTS)) + } + + future := newOracleFuture(p.ctx, p.sctx, p.sctx.GetSessionVars().TxnCtx.TxnScope) + return p.replaceTxnTsFuture(future) +} + +// prepareTxnWithOracleTS +// The difference between prepareTxnWithOracleTS and prepareTxn is that prepareTxnWithOracleTS +// does not consider snapshotTS +func (p *baseTxnContextProvider) prepareTxnWithOracleTS() error { + if p.isTxnPrepared { + return nil + } + + future := newOracleFuture(p.ctx, p.sctx, p.sctx.GetSessionVars().TxnCtx.TxnScope) + return p.replaceTxnTsFuture(future) +} + +func (p *baseTxnContextProvider) forcePrepareConstStartTS(ts uint64) error { + if p.txn != nil { + return errors.New("cannot force prepare const start ts because txn is active") + } + p.constStartTS = ts + p.isTxnPrepared = true + return nil +} + +func (p *baseTxnContextProvider) replaceTxnTsFuture(future oracle.Future) error { + txn, err := p.sctx.Txn(false) + if err != nil { + return err + } + + if txn.Valid() { + return nil + } + + txnScope := p.sctx.GetSessionVars().TxnCtx.TxnScope + if err = p.sctx.PrepareTSFuture(p.ctx, future, txnScope); err != nil { + return err + } + + p.isTxnPrepared = true + return nil +} + +func (p *baseTxnContextProvider) isTidbSnapshotEnabled() bool { + return p.sctx.GetSessionVars().SnapshotTS != 0 +} + +// isBeginStmtWithStaleRead indicates whether the current statement is `BeginStmt` type with stale read +// Because stale read will use `staleread.StalenessTxnContextProvider` for query, so if `staleread.IsStmtStaleness()` +// returns true in other providers, it means the current statement is `BeginStmt` with stale read +func (p *baseTxnContextProvider) isBeginStmtWithStaleRead() bool { + return staleread.IsStmtStaleness(p.sctx) +} + +// AdviseWarmup provides warmup for inner state +func (p *baseTxnContextProvider) AdviseWarmup() error { + if p.isTxnPrepared || p.isBeginStmtWithStaleRead() { + // When executing `START TRANSACTION READ ONLY AS OF ...` no need to warmUp + return nil + } + return p.prepareTxn() +} + +// AdviseOptimizeWithPlan providers optimization according to the plan +func (p *baseTxnContextProvider) AdviseOptimizeWithPlan(_ any) error { + return nil +} + +// GetSnapshotWithStmtReadTS gets snapshot with read ts +func (p *baseTxnContextProvider) GetSnapshotWithStmtReadTS() (kv.Snapshot, error) { + ts, err := p.GetStmtReadTS() + if err != nil { + return nil, err + } + + return p.getSnapshotByTS(ts) +} + +// GetSnapshotWithStmtForUpdateTS gets snapshot with for update ts +func (p *baseTxnContextProvider) GetSnapshotWithStmtForUpdateTS() (kv.Snapshot, error) { + ts, err := p.GetStmtForUpdateTS() + if err != nil { + return nil, err + } + + return p.getSnapshotByTS(ts) +} + +// getSnapshotByTS get snapshot from store according to the snapshotTS and set the transaction related +// options before return +func (p *baseTxnContextProvider) getSnapshotByTS(snapshotTS uint64) (kv.Snapshot, error) { + txn, err := p.sctx.Txn(false) + if err != nil { + return nil, err + } + + txnCtx := p.sctx.GetSessionVars().TxnCtx + if txn.Valid() && txnCtx.StartTS == txnCtx.GetForUpdateTS() && txnCtx.StartTS == snapshotTS { + return txn.GetSnapshot(), nil + } + + sessVars := p.sctx.GetSessionVars() + snapshot := internal.GetSnapshotWithTS( + p.sctx, + snapshotTS, + temptable.SessionSnapshotInterceptor(p.sctx, 
p.infoSchema), + ) + + replicaReadType := sessVars.GetReplicaRead() + if replicaReadType.IsFollowerRead() && + !sessVars.StmtCtx.RCCheckTS && + !sessVars.RcWriteCheckTS { + snapshot.SetOption(kv.ReplicaRead, replicaReadType) + } + + return snapshot, nil +} + +func (p *baseTxnContextProvider) SetOptionsOnTxnActive(txn kv.Transaction) { + sessVars := p.sctx.GetSessionVars() + + readReplicaType := sessVars.GetReplicaRead() + if readReplicaType.IsFollowerRead() { + txn.SetOption(kv.ReplicaRead, readReplicaType) + } + + if interceptor := temptable.SessionSnapshotInterceptor( + p.sctx, + p.infoSchema, + ); interceptor != nil { + txn.SetOption(kv.SnapInterceptor, interceptor) + } + + if sessVars.StmtCtx.WeakConsistency { + txn.SetOption(kv.IsolationLevel, kv.RC) + } + + internal.SetTxnAssertionLevel(txn, sessVars.AssertionLevel) + + if p.sctx.GetSessionVars().InRestrictedSQL { + txn.SetOption(kv.RequestSourceInternal, true) + } + + if txn.IsPipelined() { + txn.SetOption(kv.RequestSourceType, "p-dml") + } else if tp := p.sctx.GetSessionVars().RequestSourceType; tp != "" { + txn.SetOption(kv.RequestSourceType, tp) + } + + if sessVars.LoadBasedReplicaReadThreshold > 0 { + txn.SetOption(kv.LoadBasedReplicaReadThreshold, sessVars.LoadBasedReplicaReadThreshold) + } + + txn.SetOption(kv.CommitHook, func(info string, _ error) { sessVars.LastTxnInfo = info }) + txn.SetOption(kv.EnableAsyncCommit, sessVars.EnableAsyncCommit) + txn.SetOption(kv.Enable1PC, sessVars.Enable1PC) + if sessVars.DiskFullOpt != kvrpcpb.DiskFullOpt_NotAllowedOnFull { + txn.SetDiskFullOpt(sessVars.DiskFullOpt) + } + txn.SetOption(kv.InfoSchema, sessVars.TxnCtx.InfoSchema) + if sessVars.StmtCtx.KvExecCounter != nil { + // Bind an interceptor for client-go to count the number of SQL executions of each TiKV. + txn.SetOption(kv.RPCInterceptor, sessVars.StmtCtx.KvExecCounter.RPCInterceptor()) + } + txn.SetOption(kv.ResourceGroupTagger, sessVars.StmtCtx.GetResourceGroupTagger()) + txn.SetOption(kv.ExplicitRequestSourceType, sessVars.ExplicitRequestSourceType) + + if p.causalConsistencyOnly || !sessVars.GuaranteeLinearizability { + // priority of the sysvar is lower than `start transaction with causal consistency only` + txn.SetOption(kv.GuaranteeLinearizability, false) + } else { + // We needn't ask the TiKV client to guarantee linearizability for auto-commit transactions + // because the property is naturally holds: + // We guarantee the commitTS of any transaction must not exceed the next timestamp from the TSO. + // An auto-commit transaction fetches its startTS from the TSO so its commitTS > its startTS > the commitTS + // of any previously committed transactions. + // Additionally, it's required to guarantee linearizability for snapshot read-only transactions though + // it does take effects on read-only transactions now. + txn.SetOption( + kv.GuaranteeLinearizability, + !sessVars.IsAutocommit() || + sessVars.SnapshotTS > 0 || + p.enterNewTxnType == sessiontxn.EnterNewTxnDefault || + p.enterNewTxnType == sessiontxn.EnterNewTxnWithBeginStmt, + ) + } + + txn.SetOption(kv.SessionID, p.sctx.GetSessionVars().ConnectionID) +} + +func (p *baseTxnContextProvider) SetOptionsBeforeCommit( + txn kv.Transaction, commitTSChecker func(uint64) bool, +) error { + sessVars := p.sctx.GetSessionVars() + // Pipelined dml txn already flushed mutations into stores, so we don't need to set options for them. + // Instead, some invariants must be checked to avoid anomalies though are unreachable in designed usages. 
+ if p.txn.IsPipelined() { + if p.txn.IsPipelined() && !sessVars.TxnCtx.EnableMDL { + return errors.New("cannot commit pipelined transaction without Metadata Lock: MDL is OFF") + } + if len(sessVars.TxnCtx.TemporaryTables) > 0 { + return errors.New("pipelined dml with temporary tables is not allowed") + } + if sessVars.BinlogClient != nil { + return errors.New("pipelined dml with binlog is not allowed") + } + if sessVars.CDCWriteSource != 0 { + return errors.New("pipelined dml with CDC source is not allowed") + } + if commitTSChecker != nil { + return errors.New("pipelined dml with commitTS checker is not allowed") + } + return nil + } + + // set resource tagger again for internal tasks separated in different transactions + txn.SetOption(kv.ResourceGroupTagger, sessVars.StmtCtx.GetResourceGroupTagger()) + + // Get the related table or partition IDs. + relatedPhysicalTables := sessVars.TxnCtx.TableDeltaMap + // Get accessed temporary tables in the transaction. + temporaryTables := sessVars.TxnCtx.TemporaryTables + physicalTableIDs := make([]int64, 0, len(relatedPhysicalTables)) + for id := range relatedPhysicalTables { + // Schema change on global temporary tables doesn't affect transactions. + if _, ok := temporaryTables[id]; ok { + continue + } + physicalTableIDs = append(physicalTableIDs, id) + } + needCheckSchema := true + // Set this option for 2 phase commit to validate schema lease. + if sessVars.TxnCtx != nil { + needCheckSchema = !sessVars.TxnCtx.EnableMDL + } + + // TODO: refactor SetOption usage to avoid race risk, should detect it in test. + // The pipelined txn will may be flushed in background, not touch the options to avoid races. + // to avoid session set overlap the txn set. + txn.SetOption( + kv.SchemaChecker, + domain.NewSchemaChecker( + domain.GetDomain(p.sctx), + p.GetTxnInfoSchema().SchemaMetaVersion(), + physicalTableIDs, + needCheckSchema, + ), + ) + + if sessVars.StmtCtx.KvExecCounter != nil { + // Bind an interceptor for client-go to count the number of SQL executions of each TiKV. + txn.SetOption(kv.RPCInterceptor, sessVars.StmtCtx.KvExecCounter.RPCInterceptor()) + } + + if tables := sessVars.TxnCtx.TemporaryTables; len(tables) > 0 { + txn.SetOption(kv.KVFilter, temporaryTableKVFilter(tables)) + } + + if sessVars.BinlogClient != nil { + prewriteValue := binloginfo.GetPrewriteValue(p.sctx, false) + if prewriteValue != nil { + prewriteData, err := prewriteValue.Marshal() + if err != nil { + return errors.Trace(err) + } + info := &binloginfo.BinlogInfo{ + Data: &binlog.Binlog{ + Tp: binlog.BinlogType_Prewrite, + PrewriteValue: prewriteData, + }, + Client: sessVars.BinlogClient, + } + txn.SetOption(kv.BinlogInfo, info) + } + } + + var txnSource uint64 + if val := txn.GetOption(kv.TxnSource); val != nil { + txnSource, _ = val.(uint64) + } + // If the transaction is started by CDC, we need to set the CDCWriteSource option. 
+ if sessVars.CDCWriteSource != 0 { + err := kv.SetCDCWriteSource(&txnSource, sessVars.CDCWriteSource) + if err != nil { + return errors.Trace(err) + } + + txn.SetOption(kv.TxnSource, txnSource) + } + + if commitTSChecker != nil { + txn.SetOption(kv.CommitTSUpperBoundCheck, commitTSChecker) + } + return nil +} + +// canReuseTxnWhenExplicitBegin returns whether we should reuse the txn when starting a transaction explicitly +func canReuseTxnWhenExplicitBegin(sctx sessionctx.Context) bool { + sessVars := sctx.GetSessionVars() + txnCtx := sessVars.TxnCtx + // If BEGIN is the first statement in TxnCtx, we can reuse the existing transaction, without the + // need to call NewTxn, which commits the existing transaction and begins a new one. + // If the last un-committed/un-rollback transaction is a time-bounded read-only transaction, we should + // always create a new transaction. + // If the variable `tidb_snapshot` is set, we should always create a new transaction because the current txn may be + // initialized with snapshot ts. + return txnCtx.History == nil && !txnCtx.IsStaleness && sessVars.SnapshotTS == 0 +} + +// newOracleFuture creates new future according to the scope and the session context +func newOracleFuture(ctx context.Context, sctx sessionctx.Context, scope string) oracle.Future { + r, ctx := tracing.StartRegionEx(ctx, "isolation.newOracleFuture") + defer r.End() + + failpoint.Inject("requestTsoFromPD", func() { + sessiontxn.TsoRequestCountInc(sctx) + }) + + oracleStore := sctx.GetStore().GetOracle() + option := &oracle.Option{TxnScope: scope} + + if sctx.GetSessionVars().UseLowResolutionTSO() { + return oracleStore.GetLowResolutionTimestampAsync(ctx, option) + } + return oracleStore.GetTimestampAsync(ctx, option) +} + +// funcFuture implements oracle.Future +type funcFuture func() (uint64, error) + +// Wait returns a ts got from the func +func (f funcFuture) Wait() (uint64, error) { + return f() +} + +// basePessimisticTxnContextProvider extends baseTxnContextProvider with some functionalities that are commonly used in +// pessimistic transactions. +type basePessimisticTxnContextProvider struct { + baseTxnContextProvider +} + +// OnPessimisticStmtStart is the hook that should be called when starts handling a pessimistic DML or +// a pessimistic select-for-update statements. +func (p *basePessimisticTxnContextProvider) OnPessimisticStmtStart(ctx context.Context) error { + if err := p.baseTxnContextProvider.OnPessimisticStmtStart(ctx); err != nil { + return err + } + if p.sctx.GetSessionVars().PessimisticTransactionFairLocking && + p.txn != nil && + p.sctx.GetSessionVars().ConnectionID != 0 && + !p.sctx.GetSessionVars().InRestrictedSQL { + if err := p.txn.StartFairLocking(); err != nil { + return err + } + } + return nil +} + +// OnPessimisticStmtEnd is the hook that should be called when finishes handling a pessimistic DML or +// select-for-update statement. 
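// On success it finishes fair locking (when active) and flushes the statement's pessimistic lock cache;
// on failure it cancels fair locking and discards that cache.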
+func (p *basePessimisticTxnContextProvider) OnPessimisticStmtEnd(ctx context.Context, isSuccessful bool) error { + if err := p.baseTxnContextProvider.OnPessimisticStmtEnd(ctx, isSuccessful); err != nil { + return err + } + if p.txn != nil && p.txn.IsInFairLockingMode() { + if isSuccessful { + if err := p.txn.DoneFairLocking(ctx); err != nil { + return err + } + } else { + if err := p.txn.CancelFairLocking(ctx); err != nil { + return err + } + } + } + + if isSuccessful { + p.sctx.GetSessionVars().TxnCtx.FlushStmtPessimisticLockCache() + } else { + p.sctx.GetSessionVars().TxnCtx.CurrentStmtPessimisticLockCache = nil + } + return nil +} + +func (p *basePessimisticTxnContextProvider) retryFairLockingIfNeeded(ctx context.Context) error { + if p.txn != nil && p.txn.IsInFairLockingMode() { + if err := p.txn.RetryFairLocking(ctx); err != nil { + return err + } + } + return nil +} + +func (p *basePessimisticTxnContextProvider) cancelFairLockingIfNeeded(ctx context.Context) error { + if p.txn != nil && p.txn.IsInFairLockingMode() { + if err := p.txn.CancelFairLocking(ctx); err != nil { + return err + } + } + return nil +} + +type temporaryTableKVFilter map[int64]tableutil.TempTable + +func (m temporaryTableKVFilter) IsUnnecessaryKeyValue( + key, value []byte, flags tikvstore.KeyFlags, +) (bool, error) { + tid := tablecodec.DecodeTableID(key) + if _, ok := m[tid]; ok { + return true, nil + } + + // This is the default filter for all tables. + defaultFilter := txn.TiDBKVFilter{} + return defaultFilter.IsUnnecessaryKeyValue(key, value, flags) +} diff --git a/pkg/sessiontxn/isolation/binding__failpoint_binding__.go b/pkg/sessiontxn/isolation/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..825cd58da0304 --- /dev/null +++ b/pkg/sessiontxn/isolation/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package isolation + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/sessiontxn/isolation/readcommitted.go b/pkg/sessiontxn/isolation/readcommitted.go index 7cd946df70ac9..aa3fed63bbb5c 100644 --- a/pkg/sessiontxn/isolation/readcommitted.go +++ b/pkg/sessiontxn/isolation/readcommitted.go @@ -131,9 +131,9 @@ func (p *PessimisticRCTxnContextProvider) OnStmtRetry(ctx context.Context) error if err := p.basePessimisticTxnContextProvider.OnStmtRetry(ctx); err != nil { return err } - failpoint.Inject("CallOnStmtRetry", func() { + if _, _err_ := failpoint.Eval(_curpkg_("CallOnStmtRetry")); _err_ == nil { sessiontxn.OnStmtRetryCountInc(p.sctx) - }) + } p.latestOracleTSValid = false p.checkTSInWriteStmt = false return p.prepareStmt(false) @@ -164,9 +164,9 @@ func (p *PessimisticRCTxnContextProvider) getOracleFuture() funcFuture { if ts, err = future.Wait(); err != nil { return } - failpoint.Inject("waitTsoOfOracleFuture", func() { + if _, _err_ := failpoint.Eval(_curpkg_("waitTsoOfOracleFuture")); _err_ == nil { sessiontxn.TsoWaitCountInc(p.sctx) - }) + } txnCtx.SetForUpdateTS(ts) ts = txnCtx.GetForUpdateTS() p.latestOracleTS = ts @@ -318,9 +318,9 @@ func (p *PessimisticRCTxnContextProvider) AdviseOptimizeWithPlan(val any) (err e } if useLastOracleTS { - failpoint.Inject("tsoUseConstantFuture", func() { + if _, _err_ := failpoint.Eval(_curpkg_("tsoUseConstantFuture")); _err_ == nil { 
sessiontxn.TsoUseConstantCountInc(p.sctx) - }) + } p.checkTSInWriteStmt = true p.stmtTSFuture = sessiontxn.ConstantFuture(p.latestOracleTS) } diff --git a/pkg/sessiontxn/isolation/readcommitted.go__failpoint_stash__ b/pkg/sessiontxn/isolation/readcommitted.go__failpoint_stash__ new file mode 100644 index 0000000000000..7cd946df70ac9 --- /dev/null +++ b/pkg/sessiontxn/isolation/readcommitted.go__failpoint_stash__ @@ -0,0 +1,360 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package isolation + +import ( + "context" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/terror" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/sessiontxn" + isolation_metrics "github.com/pingcap/tidb/pkg/sessiontxn/isolation/metrics" + "github.com/pingcap/tidb/pkg/util/logutil" + tikverr "github.com/tikv/client-go/v2/error" + "github.com/tikv/client-go/v2/oracle" + "go.uber.org/zap" +) + +type stmtState struct { + stmtTS uint64 + stmtTSFuture oracle.Future + stmtUseStartTS bool +} + +func (s *stmtState) prepareStmt(useStartTS bool) error { + *s = stmtState{ + stmtUseStartTS: useStartTS, + } + return nil +} + +// PessimisticRCTxnContextProvider provides txn context for isolation level read-committed +type PessimisticRCTxnContextProvider struct { + basePessimisticTxnContextProvider + stmtState + latestOracleTS uint64 + // latestOracleTSValid shows whether we have already fetched a ts from pd and whether the ts we fetched is still valid. 
+ latestOracleTSValid bool + // checkTSInWriteStmt is used to set RCCheckTS isolation for getting value when doing point-write + checkTSInWriteStmt bool +} + +// NewPessimisticRCTxnContextProvider returns a new PessimisticRCTxnContextProvider +func NewPessimisticRCTxnContextProvider(sctx sessionctx.Context, causalConsistencyOnly bool) *PessimisticRCTxnContextProvider { + provider := &PessimisticRCTxnContextProvider{ + basePessimisticTxnContextProvider: basePessimisticTxnContextProvider{ + baseTxnContextProvider: baseTxnContextProvider{ + sctx: sctx, + causalConsistencyOnly: causalConsistencyOnly, + onInitializeTxnCtx: func(txnCtx *variable.TransactionContext) { + txnCtx.IsPessimistic = true + txnCtx.Isolation = ast.ReadCommitted + }, + onTxnActiveFunc: func(txn kv.Transaction, _ sessiontxn.EnterNewTxnType) { + txn.SetOption(kv.Pessimistic, true) + }, + }, + }, + } + + provider.onTxnActiveFunc = func(txn kv.Transaction, _ sessiontxn.EnterNewTxnType) { + txn.SetOption(kv.Pessimistic, true) + provider.latestOracleTS = txn.StartTS() + provider.latestOracleTSValid = true + } + provider.getStmtReadTSFunc = provider.getStmtTS + provider.getStmtForUpdateTSFunc = provider.getStmtTS + return provider +} + +// OnStmtStart is the hook that should be called when a new statement started +func (p *PessimisticRCTxnContextProvider) OnStmtStart(ctx context.Context, node ast.StmtNode) error { + if err := p.basePessimisticTxnContextProvider.OnStmtStart(ctx, node); err != nil { + return err + } + + // Try to mark the `RCCheckTS` flag for the first time execution of in-transaction read requests + // using read-consistency isolation level. + if node != nil && NeedSetRCCheckTSFlag(p.sctx, node) { + p.sctx.GetSessionVars().StmtCtx.RCCheckTS = true + } + p.checkTSInWriteStmt = false + + return p.prepareStmt(!p.isTxnPrepared) +} + +// NeedSetRCCheckTSFlag checks whether it's needed to set `RCCheckTS` flag in current stmtctx. +func NeedSetRCCheckTSFlag(ctx sessionctx.Context, node ast.Node) bool { + sessionVars := ctx.GetSessionVars() + if sessionVars.ConnectionID > 0 && variable.EnableRCReadCheckTS.Load() && + sessionVars.InTxn() && !sessionVars.RetryInfo.Retrying && + plannercore.IsReadOnly(node, sessionVars) { + return true + } + return false +} + +// OnStmtErrorForNextAction is the hook that should be called when a new statement get an error +func (p *PessimisticRCTxnContextProvider) OnStmtErrorForNextAction(ctx context.Context, point sessiontxn.StmtErrorHandlePoint, err error) (sessiontxn.StmtErrorAction, error) { + switch point { + case sessiontxn.StmtErrAfterQuery: + return p.handleAfterQueryError(err) + case sessiontxn.StmtErrAfterPessimisticLock: + return p.handleAfterPessimisticLockError(ctx, err) + default: + return p.basePessimisticTxnContextProvider.OnStmtErrorForNextAction(ctx, point, err) + } +} + +// OnStmtRetry is the hook that should be called when a statement is retried internally. 
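// It also invalidates the cached oracle ts, so the retried statement fetches a fresh for-update ts.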
+func (p *PessimisticRCTxnContextProvider) OnStmtRetry(ctx context.Context) error { + if err := p.basePessimisticTxnContextProvider.OnStmtRetry(ctx); err != nil { + return err + } + failpoint.Inject("CallOnStmtRetry", func() { + sessiontxn.OnStmtRetryCountInc(p.sctx) + }) + p.latestOracleTSValid = false + p.checkTSInWriteStmt = false + return p.prepareStmt(false) +} + +func (p *PessimisticRCTxnContextProvider) prepareStmtTS() { + if p.stmtTSFuture != nil { + return + } + sessVars := p.sctx.GetSessionVars() + var stmtTSFuture oracle.Future + switch { + case p.stmtUseStartTS: + stmtTSFuture = funcFuture(p.getTxnStartTS) + case p.latestOracleTSValid && sessVars.StmtCtx.RCCheckTS: + stmtTSFuture = sessiontxn.ConstantFuture(p.latestOracleTS) + default: + stmtTSFuture = p.getOracleFuture() + } + + p.stmtTSFuture = stmtTSFuture +} + +func (p *PessimisticRCTxnContextProvider) getOracleFuture() funcFuture { + txnCtx := p.sctx.GetSessionVars().TxnCtx + future := newOracleFuture(p.ctx, p.sctx, txnCtx.TxnScope) + return func() (ts uint64, err error) { + if ts, err = future.Wait(); err != nil { + return + } + failpoint.Inject("waitTsoOfOracleFuture", func() { + sessiontxn.TsoWaitCountInc(p.sctx) + }) + txnCtx.SetForUpdateTS(ts) + ts = txnCtx.GetForUpdateTS() + p.latestOracleTS = ts + p.latestOracleTSValid = true + return + } +} + +func (p *PessimisticRCTxnContextProvider) getStmtTS() (ts uint64, err error) { + if p.stmtTS != 0 { + return p.stmtTS, nil + } + + var txn kv.Transaction + if txn, err = p.ActivateTxn(); err != nil { + return 0, err + } + + p.prepareStmtTS() + start := time.Now() + if ts, err = p.stmtTSFuture.Wait(); err != nil { + return 0, err + } + p.sctx.GetSessionVars().DurationWaitTS += time.Since(start) + + txn.SetOption(kv.SnapshotTS, ts) + p.stmtTS = ts + return +} + +// handleAfterQueryError will be called when the handle point is `StmtErrAfterQuery`. +// At this point the query will be retried from the beginning. +func (p *PessimisticRCTxnContextProvider) handleAfterQueryError(queryErr error) (sessiontxn.StmtErrorAction, error) { + sessVars := p.sctx.GetSessionVars() + if !errors.ErrorEqual(queryErr, kv.ErrWriteConflict) || !sessVars.StmtCtx.RCCheckTS { + return sessiontxn.NoIdea() + } + + isolation_metrics.RcReadCheckTSWriteConfilictCounter.Inc() + + logutil.Logger(p.ctx).Info("RC read with ts checking has failed, retry RC read", + zap.String("sql", sessVars.StmtCtx.OriginalSQL), zap.Error(queryErr)) + return sessiontxn.RetryReady() +} + +func (p *PessimisticRCTxnContextProvider) handleAfterPessimisticLockError(ctx context.Context, lockErr error) (sessiontxn.StmtErrorAction, error) { + txnCtx := p.sctx.GetSessionVars().TxnCtx + retryable := false + if deadlock, ok := errors.Cause(lockErr).(*tikverr.ErrDeadlock); ok && deadlock.IsRetryable { + logutil.Logger(p.ctx).Info("single statement deadlock, retry statement", + zap.Uint64("txn", txnCtx.StartTS), + zap.Uint64("lockTS", deadlock.LockTs), + zap.Stringer("lockKey", kv.Key(deadlock.LockKey)), + zap.Uint64("deadlockKeyHash", deadlock.DeadlockKeyHash)) + retryable = true + + // In fair locking mode, when statement retry happens, `retryFairLockingIfNeeded` should be + // called to make its state ready for retrying. But single-statement deadlock is an exception. We need to exit + // fair locking in single-statement-deadlock case, otherwise the lock this statement has acquired won't be + // released after retrying, so it still blocks another transaction and the deadlock won't be resolved. 
+ if err := p.cancelFairLockingIfNeeded(ctx); err != nil { + return sessiontxn.ErrorAction(err) + } + } else if terror.ErrorEqual(kv.ErrWriteConflict, lockErr) { + logutil.Logger(p.ctx).Debug("pessimistic write conflict, retry statement", + zap.Uint64("txn", txnCtx.StartTS), + zap.Uint64("forUpdateTS", txnCtx.GetForUpdateTS()), + zap.String("err", lockErr.Error())) + retryable = true + if p.checkTSInWriteStmt { + isolation_metrics.RcWriteCheckTSWriteConfilictCounter.Inc() + } + } + + if retryable { + if err := p.basePessimisticTxnContextProvider.retryFairLockingIfNeeded(ctx); err != nil { + return sessiontxn.ErrorAction(err) + } + return sessiontxn.RetryReady() + } + return sessiontxn.ErrorAction(lockErr) +} + +// AdviseWarmup provides warmup for inner state +func (p *PessimisticRCTxnContextProvider) AdviseWarmup() error { + if err := p.prepareTxn(); err != nil { + return err + } + + if !p.isTidbSnapshotEnabled() { + p.prepareStmtTS() + } + + return nil +} + +// planSkipGetTsoFromPD identifies the plans which don't need get newest ts from PD. +func planSkipGetTsoFromPD(sctx sessionctx.Context, plan base.Plan, inLockOrWriteStmt bool) bool { + switch v := plan.(type) { + case *plannercore.PointGetPlan: + return sctx.GetSessionVars().RcWriteCheckTS && (v.Lock || inLockOrWriteStmt) + case base.PhysicalPlan: + if len(v.Children()) == 0 { + return false + } + _, isPhysicalLock := v.(*plannercore.PhysicalLock) + for _, p := range v.Children() { + if !planSkipGetTsoFromPD(sctx, p, isPhysicalLock || inLockOrWriteStmt) { + return false + } + } + return true + case *plannercore.Update: + return planSkipGetTsoFromPD(sctx, v.SelectPlan, true) + case *plannercore.Delete: + return planSkipGetTsoFromPD(sctx, v.SelectPlan, true) + case *plannercore.Insert: + return v.SelectPlan == nil && len(v.OnDuplicate) == 0 && !v.IsReplace + } + return false +} + +// AdviseOptimizeWithPlan in read-committed covers as many cases as repeatable-read. +// We do not fetch latest ts immediately for such scenes. +// 1. A query like the form of "SELECT ... FOR UPDATE" whose execution plan is "PointGet". +// 2. An INSERT statement without "SELECT" subquery. +// 3. A UPDATE statement whose sub execution plan is "PointGet". +// 4. A DELETE statement whose sub execution plan is "PointGet". 
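+// Illustrative sketch (hypothetical table t(id INT PRIMARY KEY, v INT); assumes the
+// RcWriteCheckTS session variable is enabled, latestOracleTS is still valid, and the
+// statement is not an internal retry):
+//
+//	SELECT v FROM t WHERE id = 1 FOR UPDATE; -- locking PointGet: reuse latestOracleTS
+//	UPDATE t SET v = v + 1 WHERE id = 1;     -- sub-plan is a PointGet: reuse latestOracleTS
+//	INSERT INTO t VALUES (3, 3);             -- no SELECT sub-plan, no ON DUPLICATE: reuse
+//	UPDATE t SET v = v + 1 WHERE v > 0;      -- range scan: wait for a fresh TSO from PD
+//
+// When the timestamp is reused, checkTSInWriteStmt is set so the snapshot runs with the
+// RCCheckTS isolation option; a write conflict then falls back to a retry with a new TSO.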
+func (p *PessimisticRCTxnContextProvider) AdviseOptimizeWithPlan(val any) (err error) { + if p.isTidbSnapshotEnabled() || p.isBeginStmtWithStaleRead() { + return nil + } + if p.stmtUseStartTS || !p.latestOracleTSValid { + return nil + } + + plan, ok := val.(base.Plan) + if !ok { + return nil + } + + if execute, ok := plan.(*plannercore.Execute); ok { + plan = execute.Plan + } + + useLastOracleTS := false + if !p.sctx.GetSessionVars().RetryInfo.Retrying { + useLastOracleTS = planSkipGetTsoFromPD(p.sctx, plan, false) + } + + if useLastOracleTS { + failpoint.Inject("tsoUseConstantFuture", func() { + sessiontxn.TsoUseConstantCountInc(p.sctx) + }) + p.checkTSInWriteStmt = true + p.stmtTSFuture = sessiontxn.ConstantFuture(p.latestOracleTS) + } + + return nil +} + +// GetSnapshotWithStmtForUpdateTS gets snapshot with for update ts +func (p *PessimisticRCTxnContextProvider) GetSnapshotWithStmtForUpdateTS() (kv.Snapshot, error) { + snapshot, err := p.basePessimisticTxnContextProvider.GetSnapshotWithStmtForUpdateTS() + if err != nil { + return nil, err + } + if p.checkTSInWriteStmt { + snapshot.SetOption(kv.IsolationLevel, kv.RCCheckTS) + } + return snapshot, err +} + +// GetSnapshotWithStmtReadTS gets snapshot with read ts +func (p *PessimisticRCTxnContextProvider) GetSnapshotWithStmtReadTS() (kv.Snapshot, error) { + snapshot, err := p.basePessimisticTxnContextProvider.GetSnapshotWithStmtForUpdateTS() + if err != nil { + return nil, err + } + + if p.sctx.GetSessionVars().StmtCtx.RCCheckTS { + snapshot.SetOption(kv.IsolationLevel, kv.RCCheckTS) + } + + return snapshot, nil +} + +// IsCheckTSInWriteStmtMode is only used for test +func (p *PessimisticRCTxnContextProvider) IsCheckTSInWriteStmtMode() bool { + return p.checkTSInWriteStmt +} diff --git a/pkg/sessiontxn/isolation/repeatable_read.go b/pkg/sessiontxn/isolation/repeatable_read.go index 55f80568f1f88..077815399acbc 100644 --- a/pkg/sessiontxn/isolation/repeatable_read.go +++ b/pkg/sessiontxn/isolation/repeatable_read.go @@ -114,9 +114,9 @@ func (p *PessimisticRRTxnContextProvider) updateForUpdateTS() (err error) { return errors.Trace(kv.ErrInvalidTxn) } - failpoint.Inject("RequestTsoFromPD", func() { + if _, _err_ := failpoint.Eval(_curpkg_("RequestTsoFromPD")); _err_ == nil { sessiontxn.TsoRequestCountInc(sctx) - }) + } // Because the ForUpdateTS is used for the snapshot for reading data in DML. // We can avoid allocating a global TSO here to speed it up by using the local TSO. diff --git a/pkg/sessiontxn/isolation/repeatable_read.go__failpoint_stash__ b/pkg/sessiontxn/isolation/repeatable_read.go__failpoint_stash__ new file mode 100644 index 0000000000000..55f80568f1f88 --- /dev/null +++ b/pkg/sessiontxn/isolation/repeatable_read.go__failpoint_stash__ @@ -0,0 +1,309 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package isolation + +import ( + "context" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/terror" + plannercore "github.com/pingcap/tidb/pkg/planner/core" + "github.com/pingcap/tidb/pkg/planner/core/base" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/sessiontxn" + "github.com/pingcap/tidb/pkg/util/logutil" + tikverr "github.com/tikv/client-go/v2/error" + "go.uber.org/zap" +) + +// PessimisticRRTxnContextProvider provides txn context for isolation level repeatable-read +type PessimisticRRTxnContextProvider struct { + basePessimisticTxnContextProvider + + // Used for ForUpdateRead statement + forUpdateTS uint64 + latestForUpdateTS uint64 + // It may decide whether to update forUpdateTs when calling provider's getForUpdateTs + // See more details in the comments of optimizeWithPlan + optimizeForNotFetchingLatestTS bool +} + +// NewPessimisticRRTxnContextProvider returns a new PessimisticRRTxnContextProvider +func NewPessimisticRRTxnContextProvider(sctx sessionctx.Context, causalConsistencyOnly bool) *PessimisticRRTxnContextProvider { + provider := &PessimisticRRTxnContextProvider{ + basePessimisticTxnContextProvider: basePessimisticTxnContextProvider{ + baseTxnContextProvider: baseTxnContextProvider{ + sctx: sctx, + causalConsistencyOnly: causalConsistencyOnly, + onInitializeTxnCtx: func(txnCtx *variable.TransactionContext) { + txnCtx.IsPessimistic = true + txnCtx.Isolation = ast.RepeatableRead + }, + onTxnActiveFunc: func(txn kv.Transaction, _ sessiontxn.EnterNewTxnType) { + txn.SetOption(kv.Pessimistic, true) + }, + }, + }, + } + + provider.getStmtReadTSFunc = provider.getTxnStartTS + provider.getStmtForUpdateTSFunc = provider.getForUpdateTs + + return provider +} + +func (p *PessimisticRRTxnContextProvider) getForUpdateTs() (ts uint64, err error) { + if p.forUpdateTS != 0 { + return p.forUpdateTS, nil + } + + var txn kv.Transaction + if txn, err = p.ActivateTxn(); err != nil { + return 0, err + } + + if p.optimizeForNotFetchingLatestTS { + p.forUpdateTS = p.sctx.GetSessionVars().TxnCtx.GetForUpdateTS() + return p.forUpdateTS, nil + } + + txnCtx := p.sctx.GetSessionVars().TxnCtx + futureTS := newOracleFuture(p.ctx, p.sctx, txnCtx.TxnScope) + + start := time.Now() + if ts, err = futureTS.Wait(); err != nil { + return 0, err + } + p.sctx.GetSessionVars().DurationWaitTS += time.Since(start) + + txnCtx.SetForUpdateTS(ts) + txn.SetOption(kv.SnapshotTS, ts) + + p.forUpdateTS = ts + + return +} + +// updateForUpdateTS acquires the latest TSO and update the TransactionContext and kv.Transaction with it. +func (p *PessimisticRRTxnContextProvider) updateForUpdateTS() (err error) { + sctx := p.sctx + var txn kv.Transaction + + if txn, err = sctx.Txn(false); err != nil { + return err + } + + if !txn.Valid() { + return errors.Trace(kv.ErrInvalidTxn) + } + + failpoint.Inject("RequestTsoFromPD", func() { + sessiontxn.TsoRequestCountInc(sctx) + }) + + // Because the ForUpdateTS is used for the snapshot for reading data in DML. + // We can avoid allocating a global TSO here to speed it up by using the local TSO. 
+ version, err := sctx.GetStore().CurrentVersion(sctx.GetSessionVars().TxnCtx.TxnScope) + if err != nil { + return err + } + + sctx.GetSessionVars().TxnCtx.SetForUpdateTS(version.Ver) + p.latestForUpdateTS = version.Ver + txn.SetOption(kv.SnapshotTS, version.Ver) + + return nil +} + +// OnStmtStart is the hook that should be called when a new statement started +func (p *PessimisticRRTxnContextProvider) OnStmtStart(ctx context.Context, node ast.StmtNode) error { + if err := p.basePessimisticTxnContextProvider.OnStmtStart(ctx, node); err != nil { + return err + } + + p.forUpdateTS = 0 + p.optimizeForNotFetchingLatestTS = false + + return nil +} + +// OnStmtRetry is the hook that should be called when a statement is retried internally. +func (p *PessimisticRRTxnContextProvider) OnStmtRetry(ctx context.Context) (err error) { + if err = p.basePessimisticTxnContextProvider.OnStmtRetry(ctx); err != nil { + return err + } + + // If TxnCtx.forUpdateTS is updated in OnStmtErrorForNextAction, we assign the value to the provider + if p.latestForUpdateTS > p.forUpdateTS { + p.forUpdateTS = p.latestForUpdateTS + } else { + p.forUpdateTS = 0 + } + + p.optimizeForNotFetchingLatestTS = false + + return nil +} + +// OnStmtErrorForNextAction is the hook that should be called when a new statement get an error +func (p *PessimisticRRTxnContextProvider) OnStmtErrorForNextAction(ctx context.Context, point sessiontxn.StmtErrorHandlePoint, err error) (sessiontxn.StmtErrorAction, error) { + switch point { + case sessiontxn.StmtErrAfterPessimisticLock: + return p.handleAfterPessimisticLockError(ctx, err) + default: + return sessiontxn.NoIdea() + } +} + +// AdviseOptimizeWithPlan optimizes for update point get related execution. +// Use case: In for update point get related operations, we do not fetch ts from PD but use the last ts we fetched. +// +// We expect that the data that the point get acquires has not been changed. +// +// Benefit: Save the cost of acquiring ts from PD. +// Drawbacks: If the data has been changed since the ts we used, we need to retry. +// One exception is insert operation, when it has no select plan, we do not fetch the latest ts immediately. We only update ts +// if write conflict is incurred. +func (p *PessimisticRRTxnContextProvider) AdviseOptimizeWithPlan(val any) (err error) { + if p.isTidbSnapshotEnabled() || p.isBeginStmtWithStaleRead() { + return nil + } + + plan, ok := val.(base.Plan) + if !ok { + return nil + } + + if execute, ok := plan.(*plannercore.Execute); ok { + plan = execute.Plan + } + + p.optimizeForNotFetchingLatestTS = notNeedGetLatestTSFromPD(plan, false) + + return nil +} + +// notNeedGetLatestTSFromPD searches for optimization condition recursively +// Note: For point get and batch point get (name it plan), if one of the ancestor node is update/delete/physicalLock, +// we should check whether the plan.Lock is true or false. See comments in needNotToBeOptimized. +// inLockOrWriteStmt = true means one of the ancestor node is update/delete/physicalLock. +func notNeedGetLatestTSFromPD(plan base.Plan, inLockOrWriteStmt bool) bool { + switch v := plan.(type) { + case *plannercore.PointGetPlan: + // We do not optimize the point get/ batch point get if plan.lock = false and inLockOrWriteStmt = true. + // Theoretically, the plan.lock should be true if the flag is true. But due to the bug describing in Issue35524, + // the plan.lock can be false in the case of inLockOrWriteStmt being true. 
In this case, optimization here can lead to different results + // which cannot be accepted as AdviseOptimizeWithPlan cannot change results. + return !inLockOrWriteStmt || v.Lock + case *plannercore.BatchPointGetPlan: + return !inLockOrWriteStmt || v.Lock + case base.PhysicalPlan: + if len(v.Children()) == 0 { + return false + } + _, isPhysicalLock := v.(*plannercore.PhysicalLock) + for _, p := range v.Children() { + if !notNeedGetLatestTSFromPD(p, isPhysicalLock || inLockOrWriteStmt) { + return false + } + } + return true + case *plannercore.Update: + return notNeedGetLatestTSFromPD(v.SelectPlan, true) + case *plannercore.Delete: + return notNeedGetLatestTSFromPD(v.SelectPlan, true) + case *plannercore.Insert: + return v.SelectPlan == nil + } + return false +} + +func (p *PessimisticRRTxnContextProvider) handleAfterPessimisticLockError(ctx context.Context, lockErr error) (sessiontxn.StmtErrorAction, error) { + sessVars := p.sctx.GetSessionVars() + txnCtx := sessVars.TxnCtx + + if deadlock, ok := errors.Cause(lockErr).(*tikverr.ErrDeadlock); ok { + if !deadlock.IsRetryable { + return sessiontxn.ErrorAction(lockErr) + } + + logutil.Logger(p.ctx).Info("single statement deadlock, retry statement", + zap.Uint64("txn", txnCtx.StartTS), + zap.Uint64("lockTS", deadlock.LockTs), + zap.Stringer("lockKey", kv.Key(deadlock.LockKey)), + zap.Uint64("deadlockKeyHash", deadlock.DeadlockKeyHash)) + + // In fair locking mode, when statement retry happens, `retryFairLockingIfNeeded` should be + // called to make its state ready for retrying. But single-statement deadlock is an exception. We need to exit + // fair locking in single-statement-deadlock case, otherwise the lock this statement has acquired won't be + // released after retrying, so it still blocks another transaction and the deadlock won't be resolved. + if err := p.cancelFairLockingIfNeeded(ctx); err != nil { + return sessiontxn.ErrorAction(err) + } + } else if terror.ErrorEqual(kv.ErrWriteConflict, lockErr) { + // Always update forUpdateTS by getting a new timestamp from PD. + // If we use the conflict commitTS as the new forUpdateTS and async commit + // is used, the commitTS of this transaction may exceed the max timestamp + // that PD allocates. Then, the change may be invisible to a new transaction, + // which means linearizability is broken. + // suppose the following scenario: + // - Txn1/2/3 get start-ts + // - Txn1/2 all get min-commit-ts as required by async commit from PD in order + // - now max ts on PD is PD-max-ts + // - Txn2 commit with calculated commit-ts = PD-max-ts + 1 + // - Txn3 try lock a key committed by Txn2 and get write conflict and use + // conflict commit-ts as forUpdateTS, lock and read, TiKV will update its + // max-ts to PD-max-ts + 1 + // - Txn1 commit with calculated commit-ts = PD-max-ts + 2 + // - suppose Txn4 after Txn1 on same session, it gets start-ts = PD-max-ts + 1 from PD + // - Txn4 cannot see Txn1's changes because its start-ts is less than Txn1's commit-ts + // which breaks linearizability. + errStr := lockErr.Error() + forUpdateTS := txnCtx.GetForUpdateTS() + + logutil.Logger(p.ctx).Debug("pessimistic write conflict, retry statement", + zap.Uint64("txn", txnCtx.StartTS), + zap.Uint64("forUpdateTS", forUpdateTS), + zap.String("err", errStr)) + } else { + // This branch: if err is not nil, always update forUpdateTS to avoid problem described below. 
+ // For nowait, when ErrLock happened, ErrLockAcquireFailAndNoWaitSet will be returned, and in the same txn + // the select for updateTs must be updated, otherwise there maybe rollback problem. + // begin + // select for update key1 (here encounters ErrLocked or other errors (or max_execution_time like util), + // key1 lock has not gotten and async rollback key1 is raised) + // select for update key1 again (this time lock is acquired successfully (maybe lock was released by others)) + // the async rollback operation rollbacks the lock just acquired + if err := p.updateForUpdateTS(); err != nil { + logutil.Logger(p.ctx).Warn("UpdateForUpdateTS failed", zap.Error(err)) + } + + return sessiontxn.ErrorAction(lockErr) + } + + if err := p.updateForUpdateTS(); err != nil { + return sessiontxn.ErrorAction(lockErr) + } + + if err := p.retryFairLockingIfNeeded(ctx); err != nil { + return sessiontxn.ErrorAction(err) + } + return sessiontxn.RetryReady() +} diff --git a/pkg/sessiontxn/staleread/binding__failpoint_binding__.go b/pkg/sessiontxn/staleread/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..19c3042b7bdd1 --- /dev/null +++ b/pkg/sessiontxn/staleread/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package staleread + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/sessiontxn/staleread/util.go b/pkg/sessiontxn/staleread/util.go index c66daedfa2ab5..6919733064282 100644 --- a/pkg/sessiontxn/staleread/util.go +++ b/pkg/sessiontxn/staleread/util.go @@ -35,9 +35,9 @@ import ( // CalculateAsOfTsExpr calculates the TsExpr of AsOfClause to get a StartTS. func CalculateAsOfTsExpr(ctx context.Context, sctx pctx.PlanContext, tsExpr ast.ExprNode) (uint64, error) { sctx.GetSessionVars().StmtCtx.SetStaleTSOProvider(func() (uint64, error) { - failpoint.Inject("mockStaleReadTSO", func(val failpoint.Value) (uint64, error) { + if val, _err_ := failpoint.Eval(_curpkg_("mockStaleReadTSO")); _err_ == nil { return uint64(val.(int)), nil - }) + } // this function accepts a context, but we don't need it when there is a valid cached ts. // in most cases, the stale read ts can be calculated from `cached ts + time since cache - staleness`, // this can be more accurate than `time.Now() - staleness`, because TiDB's local time can drift. diff --git a/pkg/sessiontxn/staleread/util.go__failpoint_stash__ b/pkg/sessiontxn/staleread/util.go__failpoint_stash__ new file mode 100644 index 0000000000000..c66daedfa2ab5 --- /dev/null +++ b/pkg/sessiontxn/staleread/util.go__failpoint_stash__ @@ -0,0 +1,97 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package staleread + +import ( + "context" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/mysql" + pctx "github.com/pingcap/tidb/pkg/planner/context" + plannerutil "github.com/pingcap/tidb/pkg/planner/util" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" + "github.com/tikv/client-go/v2/oracle" +) + +// CalculateAsOfTsExpr calculates the TsExpr of AsOfClause to get a StartTS. +func CalculateAsOfTsExpr(ctx context.Context, sctx pctx.PlanContext, tsExpr ast.ExprNode) (uint64, error) { + sctx.GetSessionVars().StmtCtx.SetStaleTSOProvider(func() (uint64, error) { + failpoint.Inject("mockStaleReadTSO", func(val failpoint.Value) (uint64, error) { + return uint64(val.(int)), nil + }) + // this function accepts a context, but we don't need it when there is a valid cached ts. + // in most cases, the stale read ts can be calculated from `cached ts + time since cache - staleness`, + // this can be more accurate than `time.Now() - staleness`, because TiDB's local time can drift. + return sctx.GetStore().GetOracle().GetStaleTimestamp(ctx, oracle.GlobalTxnScope, 0) + }) + tsVal, err := plannerutil.EvalAstExprWithPlanCtx(sctx, tsExpr) + if err != nil { + return 0, err + } + + if tsVal.IsNull() { + return 0, plannererrors.ErrAsOf.FastGenWithCause("as of timestamp cannot be NULL") + } + + toTypeTimestamp := types.NewFieldType(mysql.TypeTimestamp) + // We need at least the millionsecond here, so set fsp to 3. + toTypeTimestamp.SetDecimal(3) + tsTimestamp, err := tsVal.ConvertTo(sctx.GetSessionVars().StmtCtx.TypeCtx(), toTypeTimestamp) + if err != nil { + return 0, err + } + tsTime, err := tsTimestamp.GetMysqlTime().GoTime(sctx.GetSessionVars().Location()) + if err != nil { + return 0, err + } + return oracle.GoTimeToTS(tsTime), nil +} + +// CalculateTsWithReadStaleness calculates the TsExpr for readStaleness duration +func CalculateTsWithReadStaleness(sctx sessionctx.Context, readStaleness time.Duration) (uint64, error) { + nowVal, err := expression.GetStmtTimestamp(sctx.GetExprCtx().GetEvalCtx()) + if err != nil { + return 0, err + } + tsVal := nowVal.Add(readStaleness) + sc := sctx.GetSessionVars().StmtCtx + minTsVal := expression.GetStmtMinSafeTime(sc, sctx.GetStore(), sc.TimeZone()) + return oracle.GoTimeToTS(expression.CalAppropriateTime(tsVal, nowVal, minTsVal)), nil +} + +// IsStmtStaleness indicates whether the current statement is staleness or not +func IsStmtStaleness(sctx sessionctx.Context) bool { + return sctx.GetSessionVars().StmtCtx.IsStaleness +} + +// GetExternalTimestamp returns the external timestamp in cache, or get and store it in cache +func GetExternalTimestamp(ctx context.Context, sc *stmtctx.StatementContext) (uint64, error) { + // Try to get from the stmt cache to make sure this function is deterministic. 
+ externalTimestamp, err := sc.GetOrEvaluateStmtCache(stmtctx.StmtExternalTSCacheKey, func() (any, error) { + return variable.GetExternalTimestamp(ctx) + }) + + if err != nil { + return 0, plannererrors.ErrAsOf.FastGenWithCause(err.Error()) + } + return externalTimestamp.(uint64), nil +} diff --git a/pkg/statistics/binding__failpoint_binding__.go b/pkg/statistics/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..d74e174269a9b --- /dev/null +++ b/pkg/statistics/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package statistics + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/statistics/cmsketch.go b/pkg/statistics/cmsketch.go index 5c013189b55fe..6fa5eb769eda2 100644 --- a/pkg/statistics/cmsketch.go +++ b/pkg/statistics/cmsketch.go @@ -99,7 +99,7 @@ func newTopNHelper(sample [][]byte, numTop uint32) *topNHelper { } } slices.SortStableFunc(sorted, func(i, j dataCnt) int { return -cmp.Compare(i.cnt, j.cnt) }) - failpoint.Inject("StabilizeV1AnalyzeTopN", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("StabilizeV1AnalyzeTopN")); _err_ == nil { if val.(bool) { // The earlier TopN entry will modify the CMSketch, therefore influence later TopN entry's row count. // So we need to make the order here fully deterministic to make the stats from analyze ver1 stable. @@ -109,7 +109,7 @@ func newTopNHelper(sample [][]byte, numTop uint32) *topNHelper { (sorted[i].cnt == sorted[j].cnt && string(sorted[i].data) < string(sorted[j].data)) }) } - }) + } var ( sumTopN uint64 @@ -270,9 +270,9 @@ func QueryValue(sctx context.PlanContext, c *CMSketch, t *TopN, val types.Datum) // QueryBytes is used to query the count of specified bytes. func (c *CMSketch) QueryBytes(d []byte) uint64 { - failpoint.Inject("mockQueryBytesMaxUint64", func(val failpoint.Value) { - failpoint.Return(uint64(val.(int))) - }) + if val, _err_ := failpoint.Eval(_curpkg_("mockQueryBytesMaxUint64")); _err_ == nil { + return uint64(val.(int)) + } h1, h2 := murmur3.Sum128(d) return c.queryHashValue(nil, h1, h2) } diff --git a/pkg/statistics/cmsketch.go__failpoint_stash__ b/pkg/statistics/cmsketch.go__failpoint_stash__ new file mode 100644 index 0000000000000..5c013189b55fe --- /dev/null +++ b/pkg/statistics/cmsketch.go__failpoint_stash__ @@ -0,0 +1,865 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package statistics + +import ( + "bytes" + "cmp" + "fmt" + "math" + "reflect" + "slices" + "sort" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/planner/context" + "github.com/pingcap/tidb/pkg/planner/util/debugtrace" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/hack" + "github.com/pingcap/tipb/go-tipb" + "github.com/twmb/murmur3" +) + +// topNThreshold is the minimum ratio of the number of topN elements in CMSketch, 10 means 1 / 10 = 10%. +const topNThreshold = uint64(10) + +var ( + // ErrQueryInterrupted indicates interrupted + ErrQueryInterrupted = dbterror.ClassExecutor.NewStd(mysql.ErrQueryInterrupted) +) + +// CMSketch is used to estimate point queries. +// Refer: https://en.wikipedia.org/wiki/Count-min_sketch +type CMSketch struct { + table [][]uint32 + count uint64 // TopN is not counted in count + defaultValue uint64 // In sampled data, if cmsketch returns a small value (less than avg value / 2), then this will returned. + depth int32 + width int32 +} + +// NewCMSketch returns a new CM sketch. +func NewCMSketch(d, w int32) *CMSketch { + tbl := make([][]uint32, d) + // Background: The Go's memory allocator will ask caller to sweep spans in some scenarios. + // This can cause memory allocation request latency unpredictable, if the list of spans which need sweep is too long. + // For memory allocation large than 32K, the allocator will never allocate memory from spans list. + // + // The memory referenced by the CMSketch will never be freed. + // If the number of table or index is extremely large, there will be a large amount of spans in global list. + // The default value of `d` is 5 and `w` is 2048, if we use a single slice for them the size will be 40K. + // This allocation will be handled by mheap and will never have impact on normal allocations. + arena := make([]uint32, d*w) + for i := range tbl { + tbl[i] = arena[i*int(w) : (i+1)*int(w)] + } + return &CMSketch{depth: d, width: w, table: tbl} +} + +// topNHelper wraps some variables used when building cmsketch with top n. +type topNHelper struct { + sorted []dataCnt + sampleSize uint64 + onlyOnceItems uint64 + sumTopN uint64 + actualNumTop uint32 +} + +func newTopNHelper(sample [][]byte, numTop uint32) *topNHelper { + counter := make(map[hack.MutableString]uint64, len(sample)) + for i := range sample { + counter[hack.String(sample[i])]++ + } + sorted, onlyOnceItems := make([]dataCnt, 0, len(counter)), uint64(0) + for key, cnt := range counter { + sorted = append(sorted, dataCnt{hack.Slice(string(key)), cnt}) + if cnt == 1 { + onlyOnceItems++ + } + } + slices.SortStableFunc(sorted, func(i, j dataCnt) int { return -cmp.Compare(i.cnt, j.cnt) }) + failpoint.Inject("StabilizeV1AnalyzeTopN", func(val failpoint.Value) { + if val.(bool) { + // The earlier TopN entry will modify the CMSketch, therefore influence later TopN entry's row count. + // So we need to make the order here fully deterministic to make the stats from analyze ver1 stable. 
+ // See (*SampleCollector).ExtractTopN(), which calls this function, for details + sort.SliceStable(sorted, func(i, j int) bool { + return sorted[i].cnt > sorted[j].cnt || + (sorted[i].cnt == sorted[j].cnt && string(sorted[i].data) < string(sorted[j].data)) + }) + } + }) + + var ( + sumTopN uint64 + sampleNDV = uint32(len(sorted)) + ) + numTop = min(sampleNDV, numTop) // Ensure numTop no larger than sampNDV. + // Only element whose frequency is not smaller than 2/3 multiples the + // frequency of the n-th element are added to the TopN statistics. We chose + // 2/3 as an empirical value because the average cardinality estimation + // error is relatively small compared with 1/2. + var actualNumTop uint32 + for ; actualNumTop < sampleNDV && actualNumTop < numTop*2; actualNumTop++ { + if actualNumTop >= numTop && sorted[actualNumTop].cnt*3 < sorted[numTop-1].cnt*2 { + break + } + if sorted[actualNumTop].cnt == 1 { + break + } + sumTopN += sorted[actualNumTop].cnt + } + + return &topNHelper{sorted, uint64(len(sample)), onlyOnceItems, sumTopN, actualNumTop} +} + +// NewCMSketchAndTopN returns a new CM sketch with TopN elements, the estimate NDV and the scale ratio. +func NewCMSketchAndTopN(d, w int32, sample [][]byte, numTop uint32, rowCount uint64) (*CMSketch, *TopN, uint64, uint64) { + if rowCount == 0 || len(sample) == 0 { + return nil, nil, 0, 0 + } + helper := newTopNHelper(sample, numTop) + // rowCount is not a accurate value when fast analyzing + // In some cases, if user triggers fast analyze when rowCount is close to sampleSize, unexpected bahavior might happen. + rowCount = max(rowCount, uint64(len(sample))) + estimateNDV, scaleRatio := calculateEstimateNDV(helper, rowCount) + defaultVal := calculateDefaultVal(helper, estimateNDV, scaleRatio, rowCount) + c, t := buildCMSAndTopN(helper, d, w, scaleRatio, defaultVal) + return c, t, estimateNDV, scaleRatio +} + +func buildCMSAndTopN(helper *topNHelper, d, w int32, scaleRatio uint64, defaultVal uint64) (c *CMSketch, t *TopN) { + c = NewCMSketch(d, w) + enableTopN := helper.sampleSize/topNThreshold <= helper.sumTopN + if enableTopN { + t = NewTopN(int(helper.actualNumTop)) + for i := uint32(0); i < helper.actualNumTop; i++ { + data, cnt := helper.sorted[i].data, helper.sorted[i].cnt + t.AppendTopN(data, cnt*scaleRatio) + } + t.Sort() + helper.sorted = helper.sorted[helper.actualNumTop:] + } + c.defaultValue = defaultVal + for i := range helper.sorted { + data, cnt := helper.sorted[i].data, helper.sorted[i].cnt + // If the value only occurred once in the sample, we assumes that there is no difference with + // value that does not occurred in the sample. + rowCount := defaultVal + if cnt > 1 { + rowCount = cnt * scaleRatio + } + c.InsertBytesByCount(data, rowCount) + } + return +} + +func calculateDefaultVal(helper *topNHelper, estimateNDV, scaleRatio, rowCount uint64) uint64 { + sampleNDV := uint64(len(helper.sorted)) + if rowCount <= (helper.sampleSize-helper.onlyOnceItems)*scaleRatio { + return 1 + } + estimateRemainingCount := rowCount - (helper.sampleSize-helper.onlyOnceItems)*scaleRatio + return estimateRemainingCount / max(1, estimateNDV-sampleNDV+helper.onlyOnceItems) +} + +// MemoryUsage returns the total memory usage of a CMSketch. +// only calc the hashtable size(CMSketch.table) and the CMSketch.topN +// data are not tracked because size of CMSketch.topN take little influence +// We ignore the size of other metadata in CMSketch. 
+func (c *CMSketch) MemoryUsage() (sum int64) { + if c == nil { + return + } + sum = int64(c.depth * c.width * 4) + return +} + +// InsertBytes inserts the bytes value into the CM Sketch. +func (c *CMSketch) InsertBytes(bytes []byte) { + c.InsertBytesByCount(bytes, 1) +} + +// InsertBytesByCount adds the bytes value into the TopN (if value already in TopN) or CM Sketch by delta, this does not updates c.defaultValue. +func (c *CMSketch) InsertBytesByCount(bytes []byte, count uint64) { + h1, h2 := murmur3.Sum128(bytes) + c.count += count + for i := range c.table { + j := (h1 + h2*uint64(i)) % uint64(c.width) + c.table[i][j] += uint32(count) + } +} + +func (c *CMSketch) considerDefVal(cnt uint64) bool { + return (cnt == 0 || (cnt > c.defaultValue && cnt < 2*(c.count/uint64(c.width)))) && c.defaultValue > 0 +} + +// setValue sets the count for value that hashed into (h1, h2), and update defaultValue if necessary. +func (c *CMSketch) setValue(h1, h2 uint64, count uint64) { + oriCount := c.queryHashValue(nil, h1, h2) + if c.considerDefVal(oriCount) { + // We should update c.defaultValue if we used c.defaultValue when getting the estimate count. + // This should make estimation better, remove this line if it does not work as expected. + c.defaultValue = uint64(float64(c.defaultValue)*0.95 + float64(c.defaultValue)*0.05) + if c.defaultValue == 0 { + // c.defaultValue never guess 0 since we are using a sampled data. + c.defaultValue = 1 + } + } + + c.count += count - oriCount + // let it overflow naturally + deltaCount := uint32(count) - uint32(oriCount) + for i := range c.table { + j := (h1 + h2*uint64(i)) % uint64(c.width) + c.table[i][j] = c.table[i][j] + deltaCount + } +} + +// SubValue remove a value from the CMSketch. +func (c *CMSketch) SubValue(h1, h2 uint64, count uint64) { + c.count -= count + for i := range c.table { + j := (h1 + h2*uint64(i)) % uint64(c.width) + c.table[i][j] = c.table[i][j] - uint32(count) + } +} + +// QueryValue is used to query the count of specified value. +func QueryValue(sctx context.PlanContext, c *CMSketch, t *TopN, val types.Datum) (uint64, error) { + var sc *stmtctx.StatementContext + tz := time.UTC + if sctx != nil { + sc = sctx.GetSessionVars().StmtCtx + tz = sc.TimeZone() + } + rawData, err := tablecodec.EncodeValue(tz, nil, val) + if sc != nil { + err = sc.HandleError(err) + } + if err != nil { + return 0, errors.Trace(err) + } + h1, h2 := murmur3.Sum128(rawData) + if ret, ok := t.QueryTopN(sctx, rawData); ok { + return ret, nil + } + return c.queryHashValue(sctx, h1, h2), nil +} + +// QueryBytes is used to query the count of specified bytes. +func (c *CMSketch) QueryBytes(d []byte) uint64 { + failpoint.Inject("mockQueryBytesMaxUint64", func(val failpoint.Value) { + failpoint.Return(uint64(val.(int))) + }) + h1, h2 := murmur3.Sum128(d) + return c.queryHashValue(nil, h1, h2) +} + +// The input sctx is just for debug trace, you can pass nil safely if that's not needed. 
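+// Worked example (hypothetical numbers): with depth = 3, width = 2048, c.count = 10000 and
+// the counters hashed from (h1, h2) being [60, 45, 80], each row's noise is
+// (10000 - counter) / 2047 = 4, so the denoised values are [56, 41, 76]. Their median is 56,
+// which is then capped by the smallest raw counter, giving an estimate of 45. Only if that
+// estimate fell into the considerDefVal band would c.defaultValue be returned instead.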
+func (c *CMSketch) queryHashValue(sctx context.PlanContext, h1, h2 uint64) (result uint64) { + vals := make([]uint32, c.depth) + originVals := make([]uint32, c.depth) + minValue := uint32(math.MaxUint32) + useDefaultValue := false + if sctx != nil && sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { + debugtrace.EnterContextCommon(sctx) + defer func() { + debugtrace.RecordAnyValuesWithNames(sctx, + "Origin Values", originVals, + "Values", vals, + "Use default value", useDefaultValue, + "Result", result, + ) + debugtrace.LeaveContextCommon(sctx) + }() + } + // We want that when res is 0 before the noise is eliminated, the default value is not used. + // So we need a temp value to distinguish before and after eliminating noise. + temp := uint32(1) + for i := range c.table { + j := (h1 + h2*uint64(i)) % uint64(c.width) + originVals[i] = c.table[i][j] + if minValue > c.table[i][j] { + minValue = c.table[i][j] + } + noise := (c.count - uint64(c.table[i][j])) / (uint64(c.width) - 1) + if uint64(c.table[i][j]) == 0 { + vals[i] = 0 + } else if uint64(c.table[i][j]) < noise { + vals[i] = temp + } else { + vals[i] = c.table[i][j] - uint32(noise) + temp + } + } + slices.Sort(vals) + res := vals[(c.depth-1)/2] + (vals[c.depth/2]-vals[(c.depth-1)/2])/2 + if res > minValue+temp { + res = minValue + temp + } + if res == 0 { + return uint64(0) + } + res = res - temp + if c.considerDefVal(uint64(res)) { + useDefaultValue = true + return c.defaultValue + } + return uint64(res) +} + +// MergeTopNAndUpdateCMSketch merges the src TopN into the dst, and spilled values will be inserted into the CMSketch. +func MergeTopNAndUpdateCMSketch(dst, src *TopN, c *CMSketch, numTop uint32) []TopNMeta { + topNs := []*TopN{src, dst} + mergedTopN, popedTopNPair := MergeTopN(topNs, numTop) + if mergedTopN == nil { + // mergedTopN == nil means the total count of the input TopN are equal to zero + return popedTopNPair + } + dst.TopN = mergedTopN.TopN + for _, topNMeta := range popedTopNPair { + c.InsertBytesByCount(topNMeta.Encoded, topNMeta.Count) + } + return popedTopNPair +} + +// MergeCMSketch merges two CM Sketch. +func (c *CMSketch) MergeCMSketch(rc *CMSketch) error { + if c == nil || rc == nil { + return nil + } + if c.depth != rc.depth || c.width != rc.width { + return errors.New("Dimensions of Count-Min Sketch should be the same") + } + c.count += rc.count + for i := range c.table { + for j := range c.table[i] { + c.table[i][j] += rc.table[i][j] + } + } + return nil +} + +// CMSketchToProto converts CMSketch to its protobuf representation. +func CMSketchToProto(c *CMSketch, topn *TopN) *tipb.CMSketch { + protoSketch := &tipb.CMSketch{} + if c != nil { + protoSketch.Rows = make([]*tipb.CMSketchRow, c.depth) + for i := range c.table { + protoSketch.Rows[i] = &tipb.CMSketchRow{Counters: make([]uint32, c.width)} + copy(protoSketch.Rows[i].Counters, c.table[i]) + } + protoSketch.DefaultValue = c.defaultValue + } + if topn != nil { + for _, dataMeta := range topn.TopN { + protoSketch.TopN = append(protoSketch.TopN, &tipb.CMSketchTopN{Data: dataMeta.Encoded, Count: dataMeta.Count}) + } + } + return protoSketch +} + +// CMSketchAndTopNFromProto converts CMSketch and TopN from its protobuf representation. 
+func CMSketchAndTopNFromProto(protoSketch *tipb.CMSketch) (*CMSketch, *TopN) { + if protoSketch == nil { + return nil, nil + } + retTopN := TopNFromProto(protoSketch.TopN) + if len(protoSketch.Rows) == 0 { + return nil, retTopN + } + c := NewCMSketch(int32(len(protoSketch.Rows)), int32(len(protoSketch.Rows[0].Counters))) + for i, row := range protoSketch.Rows { + c.count = 0 + for j, counter := range row.Counters { + c.table[i][j] = counter + c.count = c.count + uint64(counter) + } + } + c.defaultValue = protoSketch.DefaultValue + return c, retTopN +} + +// TopNFromProto converts TopN from its protobuf representation. +func TopNFromProto(protoTopN []*tipb.CMSketchTopN) *TopN { + if len(protoTopN) == 0 { + return nil + } + topN := NewTopN(len(protoTopN)) + for _, e := range protoTopN { + d := make([]byte, len(e.Data)) + copy(d, e.Data) + topN.AppendTopN(d, e.Count) + } + topN.Sort() + return topN +} + +// EncodeCMSketchWithoutTopN encodes the given CMSketch to byte slice. +// Note that it does not include the topN. +func EncodeCMSketchWithoutTopN(c *CMSketch) ([]byte, error) { + if c == nil { + return nil, nil + } + p := CMSketchToProto(c, nil) + p.TopN = nil + protoData, err := p.Marshal() + return protoData, err +} + +// DecodeCMSketchAndTopN decode a CMSketch from the given byte slice. +func DecodeCMSketchAndTopN(data []byte, topNRows []chunk.Row) (*CMSketch, *TopN, error) { + if data == nil && len(topNRows) == 0 { + return nil, nil, nil + } + if len(data) == 0 { + return nil, DecodeTopN(topNRows), nil + } + cm, err := DecodeCMSketch(data) + if err != nil { + return nil, nil, errors.Trace(err) + } + return cm, DecodeTopN(topNRows), nil +} + +// DecodeTopN decodes a TopN from the given byte slice. +func DecodeTopN(topNRows []chunk.Row) *TopN { + if len(topNRows) == 0 { + return nil + } + topN := NewTopN(len(topNRows)) + for _, row := range topNRows { + data := make([]byte, len(row.GetBytes(0))) + copy(data, row.GetBytes(0)) + topN.AppendTopN(data, row.GetUint64(1)) + } + topN.Sort() + return topN +} + +// DecodeCMSketch encodes the given CMSketch to byte slice. +func DecodeCMSketch(data []byte) (*CMSketch, error) { + if len(data) == 0 { + return nil, nil + } + protoSketch := &tipb.CMSketch{} + err := protoSketch.Unmarshal(data) + if err != nil { + return nil, errors.Trace(err) + } + if len(protoSketch.Rows) == 0 { + return nil, nil + } + c := NewCMSketch(int32(len(protoSketch.Rows)), int32(len(protoSketch.Rows[0].Counters))) + for i, row := range protoSketch.Rows { + c.count = 0 + for j, counter := range row.Counters { + c.table[i][j] = counter + c.count = c.count + uint64(counter) + } + } + c.defaultValue = protoSketch.DefaultValue + return c, nil +} + +// TotalCount returns the total count in the sketch, it is only used for test. +func (c *CMSketch) TotalCount() uint64 { + if c == nil { + return 0 + } + return c.count +} + +// Equal tests if two CM Sketch equal, it is only used for test. +func (c *CMSketch) Equal(rc *CMSketch) bool { + return reflect.DeepEqual(c, rc) +} + +// Copy makes a copy for current CMSketch. +func (c *CMSketch) Copy() *CMSketch { + if c == nil { + return nil + } + tbl := make([][]uint32, c.depth) + for i := range tbl { + tbl[i] = make([]uint32, c.width) + copy(tbl[i], c.table[i]) + } + return &CMSketch{count: c.count, width: c.width, depth: c.depth, table: tbl, defaultValue: c.defaultValue} +} + +// GetWidthAndDepth returns the width and depth of CM Sketch. 
+func (c *CMSketch) GetWidthAndDepth() (width, depth int32) { + return c.width, c.depth +} + +// CalcDefaultValForAnalyze calculate the default value for Analyze. +// The value of it is count / NDV in CMSketch. This means count and NDV are not include topN. +func (c *CMSketch) CalcDefaultValForAnalyze(ndv uint64) { + c.defaultValue = c.count / max(1, ndv) +} + +// TopN stores most-common values, which is used to estimate point queries. +type TopN struct { + TopN []TopNMeta +} + +// Scale scales the TopN by the given factor. +func (c *TopN) Scale(scaleFactor float64) { + for i := range c.TopN { + c.TopN[i].Count = uint64(float64(c.TopN[i].Count) * scaleFactor) + } +} + +// AppendTopN appends a topn into the TopN struct. +func (c *TopN) AppendTopN(data []byte, count uint64) { + if c == nil { + return + } + c.TopN = append(c.TopN, TopNMeta{data, count}) +} + +func (c *TopN) String() string { + if c == nil { + return "EmptyTopN" + } + builder := &strings.Builder{} + fmt.Fprintf(builder, "TopN{length: %v, ", len(c.TopN)) + fmt.Fprint(builder, "[") + for i := 0; i < len(c.TopN); i++ { + fmt.Fprintf(builder, "(%v, %v)", c.TopN[i].Encoded, c.TopN[i].Count) + if i+1 != len(c.TopN) { + fmt.Fprint(builder, ", ") + } + } + fmt.Fprint(builder, "]") + fmt.Fprint(builder, "}") + return builder.String() +} + +// Num returns the ndv of the TopN. +// +// TopN is declared directly in Histogram. So the Len is occupied by the Histogram. We use Num instead. +func (c *TopN) Num() int { + if c == nil { + return 0 + } + return len(c.TopN) +} + +// DecodedString returns the value with decoded result. +func (c *TopN) DecodedString(ctx sessionctx.Context, colTypes []byte) (string, error) { + if c == nil { + return "", nil + } + builder := &strings.Builder{} + fmt.Fprintf(builder, "TopN{length: %v, ", len(c.TopN)) + fmt.Fprint(builder, "[") + var tmpDatum types.Datum + for i := 0; i < len(c.TopN); i++ { + tmpDatum.SetBytes(c.TopN[i].Encoded) + valStr, err := ValueToString(ctx.GetSessionVars(), &tmpDatum, len(colTypes), colTypes) + if err != nil { + return "", err + } + fmt.Fprintf(builder, "(%v, %v)", valStr, c.TopN[i].Count) + if i+1 != len(c.TopN) { + fmt.Fprint(builder, ", ") + } + } + fmt.Fprint(builder, "]") + fmt.Fprint(builder, "}") + return builder.String(), nil +} + +// Copy makes a copy for current TopN. +func (c *TopN) Copy() *TopN { + if c == nil { + return nil + } + topN := make([]TopNMeta, len(c.TopN)) + for i, t := range c.TopN { + topN[i].Encoded = make([]byte, len(t.Encoded)) + copy(topN[i].Encoded, t.Encoded) + topN[i].Count = t.Count + } + return &TopN{ + TopN: topN, + } +} + +// TopNMeta stores the unit of the TopN. +type TopNMeta struct { + Encoded []byte + Count uint64 +} + +// QueryTopN returns the results for (h1, h2) in murmur3.Sum128(), if not exists, return (0, false). +// The input sctx is just for debug trace, you can pass nil safely if that's not needed. 
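+// Illustrative example (hypothetical encoded values): if a TopN t holds
+// {enc("a"): 10, enc("c"): 30, enc("e"): 50}, kept sorted by the encoded bytes, then
+//
+//	t.QueryTopN(nil, enc("c")) // returns (30, true): exact hit found by binary search
+//	t.QueryTopN(nil, enc("b")) // returns (0, false): caller falls back to the CMSketch estimate
+//
+// so point estimates use the exact TopN counts first and only consult the sketch on a miss.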
+func (c *TopN) QueryTopN(sctx context.PlanContext, d []byte) (result uint64, found bool) { + if sctx != nil && sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { + debugtrace.EnterContextCommon(sctx) + defer func() { + debugtrace.RecordAnyValuesWithNames(sctx, "Result", result, "Found", found) + debugtrace.LeaveContextCommon(sctx) + }() + } + if c == nil { + return 0, false + } + idx := c.FindTopN(d) + if sctx != nil && sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { + debugtrace.RecordAnyValuesWithNames(sctx, "FindTopN idx", idx) + } + if idx < 0 { + return 0, false + } + return c.TopN[idx].Count, true +} + +// FindTopN finds the index of the given value in the TopN. +func (c *TopN) FindTopN(d []byte) int { + if c == nil { + return -1 + } + if len(c.TopN) == 0 { + return -1 + } + if len(c.TopN) == 1 { + if bytes.Equal(c.TopN[0].Encoded, d) { + return 0 + } + return -1 + } + if bytes.Compare(c.TopN[len(c.TopN)-1].Encoded, d) < 0 { + return -1 + } + if bytes.Compare(c.TopN[0].Encoded, d) > 0 { + return -1 + } + idx, match := slices.BinarySearchFunc(c.TopN, d, func(a TopNMeta, b []byte) int { + return bytes.Compare(a.Encoded, b) + }) + if !match { + return -1 + } + return idx +} + +// LowerBound searches on the sorted top-n items, +// returns the smallest index i such that the value at element i is not less than `d`. +func (c *TopN) LowerBound(d []byte) (idx int, match bool) { + if c == nil { + return 0, false + } + idx, match = slices.BinarySearchFunc(c.TopN, d, func(a TopNMeta, b []byte) int { + return bytes.Compare(a.Encoded, b) + }) + return idx, match +} + +// BetweenCount estimates the row count for interval [l, r). +// The input sctx is just for debug trace, you can pass nil safely if that's not needed. +func (c *TopN) BetweenCount(sctx context.PlanContext, l, r []byte) (result uint64) { + if sctx != nil && sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { + debugtrace.EnterContextCommon(sctx) + defer func() { + debugtrace.RecordAnyValuesWithNames(sctx, "Result", result) + debugtrace.LeaveContextCommon(sctx) + }() + } + if c == nil { + return 0 + } + lIdx, _ := c.LowerBound(l) + rIdx, _ := c.LowerBound(r) + ret := uint64(0) + for i := lIdx; i < rIdx; i++ { + ret += c.TopN[i].Count + } + if sctx != nil && sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { + debugTraceTopNRange(sctx, c, lIdx, rIdx) + } + return ret +} + +// Sort sorts the topn items. +func (c *TopN) Sort() { + if c == nil { + return + } + slices.SortFunc(c.TopN, func(i, j TopNMeta) int { + return bytes.Compare(i.Encoded, j.Encoded) + }) +} + +// TotalCount returns how many data is stored in TopN. +func (c *TopN) TotalCount() uint64 { + if c == nil { + return 0 + } + total := uint64(0) + for _, t := range c.TopN { + total += t.Count + } + return total +} + +// Equal checks whether the two TopN are equal. +func (c *TopN) Equal(cc *TopN) bool { + if c.TotalCount() == 0 && cc.TotalCount() == 0 { + return true + } else if c.TotalCount() != cc.TotalCount() { + return false + } + if len(c.TopN) != len(cc.TopN) { + return false + } + for i := range c.TopN { + if !bytes.Equal(c.TopN[i].Encoded, cc.TopN[i].Encoded) { + return false + } + if c.TopN[i].Count != cc.TopN[i].Count { + return false + } + } + return true +} + +// RemoveVal remove the val from TopN if it exists. +func (c *TopN) RemoveVal(val []byte) { + if c == nil { + return + } + pos := c.FindTopN(val) + if pos == -1 { + return + } + c.TopN = append(c.TopN[:pos], c.TopN[pos+1:]...) 
+} + +// MemoryUsage returns the total memory usage of a topn. +func (c *TopN) MemoryUsage() (sum int64) { + if c == nil { + return + } + sum = 32 // size of array (24) + reference (8) + for _, meta := range c.TopN { + sum += 32 + int64(cap(meta.Encoded)) // 32 is size of byte array (24) + size of uint64 (8) + } + return +} + +// queryAddTopN TopN adds count to CMSketch.topN if exists, and returns the count of such elements after insert. +// If such elements does not in topn elements, nothing will happen and false will be returned. +func (c *TopN) updateTopNWithDelta(d []byte, delta uint64, increase bool) bool { + if c == nil || c.TopN == nil { + return false + } + idx := c.FindTopN(d) + if idx >= 0 { + if increase { + c.TopN[idx].Count += delta + } else { + c.TopN[idx].Count -= delta + } + return true + } + return false +} + +// NewTopN creates the new TopN struct by the given size. +func NewTopN(n int) *TopN { + return &TopN{TopN: make([]TopNMeta, 0, n)} +} + +// MergeTopN is used to merge more TopN structures to generate a new TopN struct by the given size. +// The input parameters are multiple TopN structures to be merged and the size of the new TopN that will be generated. +// The output parameters are the newly generated TopN structure and the remaining numbers. +// Notice: The n can be 0. So n has no default value, we must explicitly specify this value. +func MergeTopN(topNs []*TopN, n uint32) (*TopN, []TopNMeta) { + if CheckEmptyTopNs(topNs) { + return nil, nil + } + // Different TopN structures may hold the same value, we have to merge them. + counter := make(map[hack.MutableString]uint64) + for _, topN := range topNs { + if topN.TotalCount() == 0 { + continue + } + for _, val := range topN.TopN { + counter[hack.String(val.Encoded)] += val.Count + } + } + numTop := len(counter) + if numTop == 0 { + return nil, nil + } + sorted := make([]TopNMeta, 0, numTop) + for value, cnt := range counter { + data := hack.Slice(string(value)) + sorted = append(sorted, TopNMeta{Encoded: data, Count: cnt}) + } + return GetMergedTopNFromSortedSlice(sorted, n) +} + +// CheckEmptyTopNs checks whether all TopNs are empty. 
+func CheckEmptyTopNs(topNs []*TopN) bool { + for _, topN := range topNs { + if topN.TotalCount() != 0 { + return false + } + } + return true +} + +// SortTopnMeta sort topnMeta +func SortTopnMeta(topnMetas []TopNMeta) { + slices.SortFunc(topnMetas, func(i, j TopNMeta) int { + if i.Count != j.Count { + return cmp.Compare(j.Count, i.Count) + } + return bytes.Compare(i.Encoded, j.Encoded) + }) +} + +// TopnMetaCompare compare topnMeta +func TopnMetaCompare(i, j TopNMeta) int { + c := cmp.Compare(j.Count, i.Count) + if c != 0 { + return c + } + return bytes.Compare(i.Encoded, j.Encoded) +} + +// GetMergedTopNFromSortedSlice returns merged topn +func GetMergedTopNFromSortedSlice(sorted []TopNMeta, n uint32) (*TopN, []TopNMeta) { + SortTopnMeta(sorted) + n = min(uint32(len(sorted)), n) + + var finalTopN TopN + finalTopN.TopN = sorted[:n] + finalTopN.Sort() + return &finalTopN, sorted[n:] +} diff --git a/pkg/statistics/handle/autoanalyze/autoanalyze.go b/pkg/statistics/handle/autoanalyze/autoanalyze.go index 7ff39511b7865..3e4c38ae22852 100644 --- a/pkg/statistics/handle/autoanalyze/autoanalyze.go +++ b/pkg/statistics/handle/autoanalyze/autoanalyze.go @@ -720,7 +720,7 @@ func insertAnalyzeJob(sctx sessionctx.Context, job *statistics.AnalyzeJob, insta } job.ID = new(uint64) *job.ID = rows[0].GetUint64(0) - failpoint.Inject("DebugAnalyzeJobOperations", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("DebugAnalyzeJobOperations")); _err_ == nil { if val.(bool) { logutil.BgLogger().Info("InsertAnalyzeJob", zap.String("table_schema", job.DBName), @@ -730,7 +730,7 @@ func insertAnalyzeJob(sctx sessionctx.Context, job *statistics.AnalyzeJob, insta zap.Uint64("job_id", *job.ID), ) } - }) + } return nil } @@ -746,14 +746,14 @@ func startAnalyzeJob(sctx sessionctx.Context, job *statistics.AnalyzeJob) { if err != nil { statslogutil.StatsLogger().Warn("failed to update analyze job", zap.String("update", fmt.Sprintf("%s->%s", statistics.AnalyzePending, statistics.AnalyzeRunning)), zap.Error(err)) } - failpoint.Inject("DebugAnalyzeJobOperations", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("DebugAnalyzeJobOperations")); _err_ == nil { if val.(bool) { logutil.BgLogger().Info("StartAnalyzeJob", zap.Time("start_time", job.StartTime), zap.Uint64("job id", *job.ID), ) } - }) + } } // updateAnalyzeJobProgress updates count of the processed rows when increment reaches a threshold. 
@@ -770,14 +770,14 @@ func updateAnalyzeJobProgress(sctx sessionctx.Context, job *statistics.AnalyzeJo if err != nil { statslogutil.StatsLogger().Warn("failed to update analyze job", zap.String("update", fmt.Sprintf("process %v rows", delta)), zap.Error(err)) } - failpoint.Inject("DebugAnalyzeJobOperations", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("DebugAnalyzeJobOperations")); _err_ == nil { if val.(bool) { logutil.BgLogger().Info("UpdateAnalyzeJobProgress", zap.Int64("increase processed_rows", delta), zap.Uint64("job id", *job.ID), ) } - }) + } } // finishAnalyzeJob finishes an analyze or merge job @@ -825,7 +825,7 @@ func finishAnalyzeJob(sctx sessionctx.Context, job *statistics.AnalyzeJob, analy logutil.BgLogger().Warn("failed to update analyze job", zap.String("update", fmt.Sprintf("%s->%s", statistics.AnalyzeRunning, state)), zap.Error(err)) } - failpoint.Inject("DebugAnalyzeJobOperations", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("DebugAnalyzeJobOperations")); _err_ == nil { if val.(bool) { logger := logutil.BgLogger().With( zap.Time("end_time", job.EndTime), @@ -839,5 +839,5 @@ func finishAnalyzeJob(sctx sessionctx.Context, job *statistics.AnalyzeJob, analy } logger.Info("FinishAnalyzeJob") } - }) + } } diff --git a/pkg/statistics/handle/autoanalyze/autoanalyze.go__failpoint_stash__ b/pkg/statistics/handle/autoanalyze/autoanalyze.go__failpoint_stash__ new file mode 100644 index 0000000000000..7ff39511b7865 --- /dev/null +++ b/pkg/statistics/handle/autoanalyze/autoanalyze.go__failpoint_stash__ @@ -0,0 +1,843 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package autoanalyze + +import ( + "context" + "fmt" + "math/rand" + "net" + "strconv" + "strings" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/sysproctrack" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/statistics" + "github.com/pingcap/tidb/pkg/statistics/handle/autoanalyze/exec" + "github.com/pingcap/tidb/pkg/statistics/handle/autoanalyze/refresher" + "github.com/pingcap/tidb/pkg/statistics/handle/lockstats" + statslogutil "github.com/pingcap/tidb/pkg/statistics/handle/logutil" + statstypes "github.com/pingcap/tidb/pkg/statistics/handle/types" + statsutil "github.com/pingcap/tidb/pkg/statistics/handle/util" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/sqlescape" + "github.com/pingcap/tidb/pkg/util/timeutil" + "go.uber.org/zap" +) + +// statsAnalyze implements util.StatsAnalyze. +// statsAnalyze is used to handle auto-analyze and manage analyze jobs. 
+type statsAnalyze struct { + statsHandle statstypes.StatsHandle + // sysProcTracker is used to track sys process like analyze + sysProcTracker sysproctrack.Tracker +} + +// NewStatsAnalyze creates a new StatsAnalyze. +func NewStatsAnalyze( + statsHandle statstypes.StatsHandle, + sysProcTracker sysproctrack.Tracker, +) statstypes.StatsAnalyze { + return &statsAnalyze{statsHandle: statsHandle, sysProcTracker: sysProcTracker} +} + +// InsertAnalyzeJob inserts the analyze job to the storage. +func (sa *statsAnalyze) InsertAnalyzeJob(job *statistics.AnalyzeJob, instance string, procID uint64) error { + return statsutil.CallWithSCtx(sa.statsHandle.SPool(), func(sctx sessionctx.Context) error { + return insertAnalyzeJob(sctx, job, instance, procID) + }) +} + +func (sa *statsAnalyze) StartAnalyzeJob(job *statistics.AnalyzeJob) { + err := statsutil.CallWithSCtx(sa.statsHandle.SPool(), func(sctx sessionctx.Context) error { + startAnalyzeJob(sctx, job) + return nil + }) + if err != nil { + statslogutil.StatsLogger().Warn("failed to start analyze job", zap.Error(err)) + } +} + +func (sa *statsAnalyze) UpdateAnalyzeJobProgress(job *statistics.AnalyzeJob, rowCount int64) { + err := statsutil.CallWithSCtx(sa.statsHandle.SPool(), func(sctx sessionctx.Context) error { + updateAnalyzeJobProgress(sctx, job, rowCount) + return nil + }) + if err != nil { + statslogutil.StatsLogger().Warn("failed to update analyze job progress", zap.Error(err)) + } +} + +func (sa *statsAnalyze) FinishAnalyzeJob(job *statistics.AnalyzeJob, failReason error, analyzeType statistics.JobType) { + err := statsutil.CallWithSCtx(sa.statsHandle.SPool(), func(sctx sessionctx.Context) error { + finishAnalyzeJob(sctx, job, failReason, analyzeType) + return nil + }) + if err != nil { + statslogutil.StatsLogger().Warn("failed to finish analyze job", zap.Error(err)) + } +} + +// DeleteAnalyzeJobs deletes the analyze jobs whose update time is earlier than updateTime. +func (sa *statsAnalyze) DeleteAnalyzeJobs(updateTime time.Time) error { + return statsutil.CallWithSCtx(sa.statsHandle.SPool(), func(sctx sessionctx.Context) error { + _, _, err := statsutil.ExecRows(sctx, "DELETE FROM mysql.analyze_jobs WHERE update_time < CONVERT_TZ(%?, '+00:00', @@TIME_ZONE)", updateTime.UTC().Format(types.TimeFormat)) + return err + }) +} + +// CleanupCorruptedAnalyzeJobsOnCurrentInstance cleans up the potentially corrupted analyze job. +// It only cleans up the jobs that are associated with the current instance. +func (sa *statsAnalyze) CleanupCorruptedAnalyzeJobsOnCurrentInstance(currentRunningProcessIDs map[uint64]struct{}) error { + return statsutil.CallWithSCtx(sa.statsHandle.SPool(), func(sctx sessionctx.Context) error { + return CleanupCorruptedAnalyzeJobsOnCurrentInstance(sctx, currentRunningProcessIDs) + }, statsutil.FlagWrapTxn) +} + +// CleanupCorruptedAnalyzeJobsOnDeadInstances removes analyze jobs that may have been corrupted. +// Specifically, it removes jobs associated with instances that no longer exist in the cluster. +func (sa *statsAnalyze) CleanupCorruptedAnalyzeJobsOnDeadInstances() error { + return statsutil.CallWithSCtx(sa.statsHandle.SPool(), func(sctx sessionctx.Context) error { + return CleanupCorruptedAnalyzeJobsOnDeadInstances(sctx) + }, statsutil.FlagWrapTxn) +} + +// SelectAnalyzeJobsOnCurrentInstanceSQL is the SQL to select the analyze jobs whose +// state is `pending` or `running` and the update time is more than 10 minutes ago +// and the instance is current instance. 
+const SelectAnalyzeJobsOnCurrentInstanceSQL = `SELECT id, process_id + FROM mysql.analyze_jobs + WHERE instance = %? + AND state IN ('pending', 'running') + AND update_time < CONVERT_TZ(%?, '+00:00', @@TIME_ZONE)` + +// SelectAnalyzeJobsSQL is the SQL to select the analyze jobs whose +// state is `pending` or `running` and the update time is more than 10 minutes ago. +const SelectAnalyzeJobsSQL = `SELECT id, instance + FROM mysql.analyze_jobs + WHERE state IN ('pending', 'running') + AND update_time < CONVERT_TZ(%?, '+00:00', @@TIME_ZONE)` + +// BatchUpdateAnalyzeJobSQL is the SQL to update the analyze jobs to `failed` state. +const BatchUpdateAnalyzeJobSQL = `UPDATE mysql.analyze_jobs + SET state = 'failed', + fail_reason = 'The TiDB Server has either shut down or the analyze query was terminated during the analyze job execution', + process_id = NULL + WHERE id IN (%?)` + +func tenMinutesAgo() string { + return time.Now().Add(-10 * time.Minute).UTC().Format(types.TimeFormat) +} + +// CleanupCorruptedAnalyzeJobsOnCurrentInstance cleans up the potentially corrupted analyze job from current instance. +// Exported for testing. +func CleanupCorruptedAnalyzeJobsOnCurrentInstance( + sctx sessionctx.Context, + currentRunningProcessIDs map[uint64]struct{}, +) error { + serverInfo, err := infosync.GetServerInfo() + if err != nil { + return errors.Trace(err) + } + instance := net.JoinHostPort(serverInfo.IP, strconv.Itoa(int(serverInfo.Port))) + // Get all the analyze jobs whose state is `pending` or `running` and the update time is more than 10 minutes ago + // and the instance is current instance. + rows, _, err := statsutil.ExecRows( + sctx, + SelectAnalyzeJobsOnCurrentInstanceSQL, + instance, + tenMinutesAgo(), + ) + if err != nil { + return errors.Trace(err) + } + + jobIDs := make([]string, 0, len(rows)) + for _, row := range rows { + // The process ID is typically non-null for running or pending jobs. + // However, in rare cases(I don't which case), it may be null. Therefore, it's necessary to check its value. + if !row.IsNull(1) { + processID := row.GetUint64(1) + // If the process id is not in currentRunningProcessIDs, we need to clean up the job. + // They don't belong to current instance any more. + if _, ok := currentRunningProcessIDs[processID]; !ok { + jobID := row.GetUint64(0) + jobIDs = append(jobIDs, strconv.FormatUint(jobID, 10)) + } + } + } + + // Do a batch update to clean up the jobs. + if len(jobIDs) > 0 { + _, _, err = statsutil.ExecRows( + sctx, + BatchUpdateAnalyzeJobSQL, + jobIDs, + ) + if err != nil { + return errors.Trace(err) + } + statslogutil.StatsLogger().Info( + "clean up the potentially corrupted analyze jobs from current instance", + zap.Strings("jobIDs", jobIDs), + ) + } + + return nil +} + +// CleanupCorruptedAnalyzeJobsOnDeadInstances cleans up the potentially corrupted analyze job from dead instances. +func CleanupCorruptedAnalyzeJobsOnDeadInstances( + sctx sessionctx.Context, +) error { + rows, _, err := statsutil.ExecRows( + sctx, + SelectAnalyzeJobsSQL, + tenMinutesAgo(), + ) + if err != nil { + return errors.Trace(err) + } + if len(rows) == 0 { + return nil + } + + // Get all the instances from etcd. 
+ serverInfo, err := infosync.GetAllServerInfo(context.Background()) + if err != nil { + return errors.Trace(err) + } + instances := make(map[string]struct{}, len(serverInfo)) + for _, info := range serverInfo { + instance := net.JoinHostPort(info.IP, strconv.Itoa(int(info.Port))) + instances[instance] = struct{}{} + } + + jobIDs := make([]string, 0, len(rows)) + for _, row := range rows { + // If the instance is not in instances, we need to clean up the job. + // It means the instance is down or the instance is not in the cluster any more. + instance := row.GetString(1) + if _, ok := instances[instance]; !ok { + jobID := row.GetUint64(0) + jobIDs = append(jobIDs, strconv.FormatUint(jobID, 10)) + } + } + + // Do a batch update to clean up the jobs. + if len(jobIDs) > 0 { + _, _, err = statsutil.ExecRows( + sctx, + BatchUpdateAnalyzeJobSQL, + jobIDs, + ) + if err != nil { + return errors.Trace(err) + } + statslogutil.StatsLogger().Info( + "clean up the potentially corrupted analyze jobs from dead instances", + zap.Strings("jobIDs", jobIDs), + ) + } + + return nil +} + +// HandleAutoAnalyze analyzes the outdated tables. (The change percent of the table exceeds the threshold) +// It also analyzes newly created tables and newly added indexes. +func (sa *statsAnalyze) HandleAutoAnalyze() (analyzed bool) { + _ = statsutil.CallWithSCtx(sa.statsHandle.SPool(), func(sctx sessionctx.Context) error { + analyzed = HandleAutoAnalyze(sctx, sa.statsHandle, sa.sysProcTracker) + return nil + }) + return +} + +// CheckAnalyzeVersion checks whether all the statistics versions of this table's columns and indexes are the same. +func (sa *statsAnalyze) CheckAnalyzeVersion(tblInfo *model.TableInfo, physicalIDs []int64, version *int) bool { + // We simply choose one physical id to get its stats. + var tbl *statistics.Table + for _, pid := range physicalIDs { + tbl = sa.statsHandle.GetPartitionStats(tblInfo, pid) + if !tbl.Pseudo { + break + } + } + if tbl == nil || tbl.Pseudo { + return true + } + return statistics.CheckAnalyzeVerOnTable(tbl, version) +} + +// HandleAutoAnalyze analyzes the newly created table or index. +func HandleAutoAnalyze( + sctx sessionctx.Context, + statsHandle statstypes.StatsHandle, + sysProcTracker sysproctrack.Tracker, +) (analyzed bool) { + defer func() { + if r := recover(); r != nil { + statslogutil.StatsLogger().Error( + "HandleAutoAnalyze panicked", + zap.Any("recover", r), + zap.Stack("stack"), + ) + } + }() + if variable.EnableAutoAnalyzePriorityQueue.Load() { + r := refresher.NewRefresher(statsHandle, sysProcTracker) + err := r.RebuildTableAnalysisJobQueue() + if err != nil { + statslogutil.StatsLogger().Error("rebuild table analysis job queue failed", zap.Error(err)) + return false + } + return r.PickOneTableAndAnalyzeByPriority() + } + + parameters := exec.GetAutoAnalyzeParameters(sctx) + autoAnalyzeRatio := exec.ParseAutoAnalyzeRatio(parameters[variable.TiDBAutoAnalyzeRatio]) + // Determine the time window for auto-analysis and verify if the current time falls within this range. 
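+	// For example, with tidb_auto_analyze_start_time = '01:00 +0000' and tidb_auto_analyze_end_time = '03:00 +0000',
+	// auto analyze is only attempted between 01:00 and 03:00 UTC; outside that window this function returns false.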
+ start, end, err := exec.ParseAutoAnalysisWindow( + parameters[variable.TiDBAutoAnalyzeStartTime], + parameters[variable.TiDBAutoAnalyzeEndTime], + ) + if err != nil { + statslogutil.StatsLogger().Error( + "parse auto analyze period failed", + zap.Error(err), + ) + return false + } + if !timeutil.WithinDayTimePeriod(start, end, time.Now()) { + return false + } + pruneMode := variable.PartitionPruneMode(sctx.GetSessionVars().PartitionPruneMode.Load()) + + return RandomPickOneTableAndTryAutoAnalyze( + sctx, + statsHandle, + sysProcTracker, + autoAnalyzeRatio, + pruneMode, + start, + end, + ) +} + +// RandomPickOneTableAndTryAutoAnalyze randomly picks one table and tries to analyze it. +// 1. If the table is not analyzed, analyze it. +// 2. If the table is analyzed, analyze it when "tbl.ModifyCount/tbl.Count > autoAnalyzeRatio". +// 3. If the table is analyzed, analyze its indices when the index is not analyzed. +// 4. If the table is locked, skip it. +// Exposed solely for testing. +func RandomPickOneTableAndTryAutoAnalyze( + sctx sessionctx.Context, + statsHandle statstypes.StatsHandle, + sysProcTracker sysproctrack.Tracker, + autoAnalyzeRatio float64, + pruneMode variable.PartitionPruneMode, + start, end time.Time, +) bool { + is := sctx.GetDomainInfoSchema().(infoschema.InfoSchema) + dbs := infoschema.AllSchemaNames(is) + // Shuffle the database and table slice to randomize the order of analyzing tables. + rd := rand.New(rand.NewSource(time.Now().UnixNano())) // #nosec G404 + rd.Shuffle(len(dbs), func(i, j int) { + dbs[i], dbs[j] = dbs[j], dbs[i] + }) + // Query locked tables once to minimize overhead. + // Outdated lock info is acceptable as we verify table lock status pre-analysis. + lockedTables, err := lockstats.QueryLockedTables(sctx) + if err != nil { + statslogutil.StatsLogger().Error( + "check table lock failed", + zap.Error(err), + ) + return false + } + + for _, db := range dbs { + // Ignore the memory and system database. + if util.IsMemOrSysDB(strings.ToLower(db)) { + continue + } + + tbls, err := is.SchemaTableInfos(context.Background(), model.NewCIStr(db)) + terror.Log(err) + // We shuffle dbs and tbls so that the order of iterating tables is random. If the order is fixed and the auto + // analyze job of one table fails for some reason, it may always analyze the same table and fail again and again + // when the HandleAutoAnalyze is triggered. Randomizing the order can avoid the problem. + // TODO: Design a priority queue to place the table which needs analyze most in the front. + rd.Shuffle(len(tbls), func(i, j int) { + tbls[i], tbls[j] = tbls[j], tbls[i] + }) + + // We need to check every partition of every table to see if it needs to be analyzed. + for _, tblInfo := range tbls { + // Sometimes the tables are too many. Auto-analyze will take too much time on it. + // so we need to check the available time. + if !timeutil.WithinDayTimePeriod(start, end, time.Now()) { + return false + } + // If table locked, skip analyze all partitions of the table. + // FIXME: This check is not accurate, because other nodes may change the table lock status at any time. + if _, ok := lockedTables[tblInfo.ID]; ok { + continue + } + + if tblInfo.IsView() { + continue + } + + pi := tblInfo.GetPartitionInfo() + // No partitions, analyze the whole table. 
+ if pi == nil { + statsTbl := statsHandle.GetTableStatsForAutoAnalyze(tblInfo) + sql := "analyze table %n.%n" + analyzed := tryAutoAnalyzeTable(sctx, statsHandle, sysProcTracker, tblInfo, statsTbl, autoAnalyzeRatio, sql, db, tblInfo.Name.O) + if analyzed { + // analyze one table at a time to let it get the freshest parameters. + // others will be analyzed next round which is just 3s later. + return true + } + continue + } + // Only analyze the partition that has not been locked. + partitionDefs := make([]model.PartitionDefinition, 0, len(pi.Definitions)) + for _, def := range pi.Definitions { + if _, ok := lockedTables[def.ID]; !ok { + partitionDefs = append(partitionDefs, def) + } + } + partitionStats := getPartitionStats(statsHandle, tblInfo, partitionDefs) + if pruneMode == variable.Dynamic { + analyzed := tryAutoAnalyzePartitionTableInDynamicMode( + sctx, + statsHandle, + sysProcTracker, + tblInfo, + partitionDefs, + partitionStats, + db, + autoAnalyzeRatio, + ) + if analyzed { + return true + } + continue + } + for _, def := range partitionDefs { + sql := "analyze table %n.%n partition %n" + statsTbl := partitionStats[def.ID] + analyzed := tryAutoAnalyzeTable(sctx, statsHandle, sysProcTracker, tblInfo, statsTbl, autoAnalyzeRatio, sql, db, tblInfo.Name.O, def.Name.O) + if analyzed { + return true + } + } + } + } + + return false +} + +func getPartitionStats( + statsHandle statstypes.StatsHandle, + tblInfo *model.TableInfo, + defs []model.PartitionDefinition, +) map[int64]*statistics.Table { + partitionStats := make(map[int64]*statistics.Table, len(defs)) + + for _, def := range defs { + partitionStats[def.ID] = statsHandle.GetPartitionStatsForAutoAnalyze(tblInfo, def.ID) + } + + return partitionStats +} + +// Determine whether the table and index require analysis. +func tryAutoAnalyzeTable( + sctx sessionctx.Context, + statsHandle statstypes.StatsHandle, + sysProcTracker sysproctrack.Tracker, + tblInfo *model.TableInfo, + statsTbl *statistics.Table, + ratio float64, + sql string, + params ...any, +) bool { + // 1. If the statistics are either not loaded or are classified as pseudo, there is no need for analyze + // Pseudo statistics can be created by the optimizer, so we need to double check it. + // 2. If the table is too small, we don't want to waste time to analyze it. + // Leave the opportunity to other bigger tables. + if statsTbl == nil || statsTbl.Pseudo || statsTbl.RealtimeCount < statistics.AutoAnalyzeMinCnt { + return false + } + + // Check if the table needs to analyze. + if needAnalyze, reason := NeedAnalyzeTable( + statsTbl, + ratio, + ); needAnalyze { + escaped, err := sqlescape.EscapeSQL(sql, params...) + if err != nil { + return false + } + statslogutil.StatsLogger().Info( + "auto analyze triggered", + zap.String("sql", escaped), + zap.String("reason", reason), + ) + + tableStatsVer := sctx.GetSessionVars().AnalyzeVersion + statistics.CheckAnalyzeVerOnTable(statsTbl, &tableStatsVer) + exec.AutoAnalyze(sctx, statsHandle, sysProcTracker, tableStatsVer, sql, params...) + + return true + } + + // Whether the table needs to analyze or not, we need to check the indices of the table. + for _, idx := range tblInfo.Indices { + if idxStats := statsTbl.GetIdx(idx.ID); idxStats == nil && !statsTbl.ColAndIdxExistenceMap.HasAnalyzed(idx.ID, true) && idx.State == model.StatePublic { + sqlWithIdx := sql + " index %n" + paramsWithIdx := append(params, idx.Name.O) + escaped, err := sqlescape.EscapeSQL(sqlWithIdx, paramsWithIdx...) 
+ if err != nil { + return false + } + + statslogutil.StatsLogger().Info( + "auto analyze for unanalyzed indexes", + zap.String("sql", escaped), + ) + tableStatsVer := sctx.GetSessionVars().AnalyzeVersion + statistics.CheckAnalyzeVerOnTable(statsTbl, &tableStatsVer) + exec.AutoAnalyze(sctx, statsHandle, sysProcTracker, tableStatsVer, sqlWithIdx, paramsWithIdx...) + return true + } + } + return false +} + +// NeedAnalyzeTable checks if we need to analyze the table: +// 1. If the table has never been analyzed, we need to analyze it. +// 2. If the table had been analyzed before, we need to analyze it when +// "tbl.ModifyCount/tbl.Count > autoAnalyzeRatio" and the current time is +// between `start` and `end`. +// +// Exposed for test. +func NeedAnalyzeTable(tbl *statistics.Table, autoAnalyzeRatio float64) (bool, string) { + analyzed := tbl.IsAnalyzed() + if !analyzed { + return true, "table unanalyzed" + } + // Auto analyze is disabled. + if autoAnalyzeRatio == 0 { + return false, "" + } + // No need to analyze it. + tblCnt := float64(tbl.RealtimeCount) + if histCnt := tbl.GetAnalyzeRowCount(); histCnt > 0 { + tblCnt = histCnt + } + if float64(tbl.ModifyCount)/tblCnt <= autoAnalyzeRatio { + return false, "" + } + return true, fmt.Sprintf("too many modifications(%v/%v>%v)", tbl.ModifyCount, tblCnt, autoAnalyzeRatio) +} + +// It is very similar to tryAutoAnalyzeTable, but it commits the analyze job in batch for partitions. +func tryAutoAnalyzePartitionTableInDynamicMode( + sctx sessionctx.Context, + statsHandle statstypes.StatsHandle, + sysProcTracker sysproctrack.Tracker, + tblInfo *model.TableInfo, + partitionDefs []model.PartitionDefinition, + partitionStats map[int64]*statistics.Table, + db string, + ratio float64, +) bool { + tableStatsVer := sctx.GetSessionVars().AnalyzeVersion + analyzePartitionBatchSize := int(variable.AutoAnalyzePartitionBatchSize.Load()) + needAnalyzePartitionNames := make([]any, 0, len(partitionDefs)) + + for _, def := range partitionDefs { + partitionStats := partitionStats[def.ID] + // 1. If the statistics are either not loaded or are classified as pseudo, there is no need for analyze. + // Pseudo statistics can be created by the optimizer, so we need to double check it. + // 2. If the table is too small, we don't want to waste time to analyze it. + // Leave the opportunity to other bigger tables. 
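+		// Otherwise NeedAnalyzeTable applies the change-ratio check: for example, with the default
+		// tidb_auto_analyze_ratio of 0.5, a partition whose stats show count 10000 and modify_count 6000
+		// (0.6 > 0.5) is reported as needing analysis, while one with modify_count 4000 is not.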
+ if partitionStats == nil || partitionStats.Pseudo || partitionStats.RealtimeCount < statistics.AutoAnalyzeMinCnt { + continue + } + if needAnalyze, reason := NeedAnalyzeTable( + partitionStats, + ratio, + ); needAnalyze { + needAnalyzePartitionNames = append(needAnalyzePartitionNames, def.Name.O) + statslogutil.StatsLogger().Info( + "need to auto analyze", + zap.String("database", db), + zap.String("table", tblInfo.Name.String()), + zap.String("partition", def.Name.O), + zap.String("reason", reason), + ) + statistics.CheckAnalyzeVerOnTable(partitionStats, &tableStatsVer) + } + } + + getSQL := func(prefix, suffix string, numPartitions int) string { + var sqlBuilder strings.Builder + sqlBuilder.WriteString(prefix) + for i := 0; i < numPartitions; i++ { + if i != 0 { + sqlBuilder.WriteString(",") + } + sqlBuilder.WriteString(" %n") + } + sqlBuilder.WriteString(suffix) + return sqlBuilder.String() + } + + if len(needAnalyzePartitionNames) > 0 { + statslogutil.StatsLogger().Info("start to auto analyze", + zap.String("database", db), + zap.String("table", tblInfo.Name.String()), + zap.Any("partitions", needAnalyzePartitionNames), + zap.Int("analyze partition batch size", analyzePartitionBatchSize), + ) + + statsTbl := statsHandle.GetTableStats(tblInfo) + statistics.CheckAnalyzeVerOnTable(statsTbl, &tableStatsVer) + for i := 0; i < len(needAnalyzePartitionNames); i += analyzePartitionBatchSize { + start := i + end := start + analyzePartitionBatchSize + if end >= len(needAnalyzePartitionNames) { + end = len(needAnalyzePartitionNames) + } + + // Do batch analyze for partitions. + sql := getSQL("analyze table %n.%n partition", "", end-start) + params := append([]any{db, tblInfo.Name.O}, needAnalyzePartitionNames[start:end]...) + + statslogutil.StatsLogger().Info( + "auto analyze triggered", + zap.String("database", db), + zap.String("table", tblInfo.Name.String()), + zap.Any("partitions", needAnalyzePartitionNames[start:end]), + ) + exec.AutoAnalyze(sctx, statsHandle, sysProcTracker, tableStatsVer, sql, params...) + } + + return true + } + // Check if any index of the table needs to analyze. + for _, idx := range tblInfo.Indices { + if idx.State != model.StatePublic { + continue + } + // Collect all the partition names that need to analyze. + for _, def := range partitionDefs { + partitionStats := partitionStats[def.ID] + // 1. If the statistics are either not loaded or are classified as pseudo, there is no need for analyze. + // Pseudo statistics can be created by the optimizer, so we need to double check it. + if partitionStats == nil || partitionStats.Pseudo { + continue + } + // 2. If the index is not analyzed, we need to analyze it. + if !partitionStats.ColAndIdxExistenceMap.HasAnalyzed(idx.ID, true) { + needAnalyzePartitionNames = append(needAnalyzePartitionNames, def.Name.O) + statistics.CheckAnalyzeVerOnTable(partitionStats, &tableStatsVer) + } + } + if len(needAnalyzePartitionNames) > 0 { + statsTbl := statsHandle.GetTableStats(tblInfo) + statistics.CheckAnalyzeVerOnTable(statsTbl, &tableStatsVer) + + for i := 0; i < len(needAnalyzePartitionNames); i += analyzePartitionBatchSize { + start := i + end := start + analyzePartitionBatchSize + if end >= len(needAnalyzePartitionNames) { + end = len(needAnalyzePartitionNames) + } + + sql := getSQL("analyze table %n.%n partition", " index %n", end-start) + params := append([]any{db, tblInfo.Name.O}, needAnalyzePartitionNames[start:end]...) 
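+				// With three partitions in this batch, sql becomes
+				// "analyze table %n.%n partition %n, %n, %n index %n"; each %n is an identifier
+				// placeholder that is later filled with the corresponding escaped name from params.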
+ params = append(params, idx.Name.O) + statslogutil.StatsLogger().Info("auto analyze for unanalyzed", + zap.String("database", db), + zap.String("table", tblInfo.Name.String()), + zap.String("index", idx.Name.String()), + zap.Any("partitions", needAnalyzePartitionNames[start:end]), + ) + exec.AutoAnalyze(sctx, statsHandle, sysProcTracker, tableStatsVer, sql, params...) + } + + return true + } + } + + return false +} + +// insertAnalyzeJob inserts analyze job into mysql.analyze_jobs and gets job ID for further updating job. +func insertAnalyzeJob(sctx sessionctx.Context, job *statistics.AnalyzeJob, instance string, procID uint64) (err error) { + jobInfo := job.JobInfo + const textMaxLength = 65535 + if len(jobInfo) > textMaxLength { + jobInfo = jobInfo[:textMaxLength] + } + const insertJob = "INSERT INTO mysql.analyze_jobs (table_schema, table_name, partition_name, job_info, state, instance, process_id) VALUES (%?, %?, %?, %?, %?, %?, %?)" + _, _, err = statsutil.ExecRows(sctx, insertJob, job.DBName, job.TableName, job.PartitionName, jobInfo, statistics.AnalyzePending, instance, procID) + if err != nil { + return err + } + const getJobID = "SELECT LAST_INSERT_ID()" + rows, _, err := statsutil.ExecRows(sctx, getJobID) + if err != nil { + return err + } + job.ID = new(uint64) + *job.ID = rows[0].GetUint64(0) + failpoint.Inject("DebugAnalyzeJobOperations", func(val failpoint.Value) { + if val.(bool) { + logutil.BgLogger().Info("InsertAnalyzeJob", + zap.String("table_schema", job.DBName), + zap.String("table_name", job.TableName), + zap.String("partition_name", job.PartitionName), + zap.String("job_info", jobInfo), + zap.Uint64("job_id", *job.ID), + ) + } + }) + return nil +} + +// startAnalyzeJob marks the state of the analyze job as running and sets the start time. +func startAnalyzeJob(sctx sessionctx.Context, job *statistics.AnalyzeJob) { + if job == nil || job.ID == nil { + return + } + job.StartTime = time.Now() + job.Progress.SetLastDumpTime(job.StartTime) + const sql = "UPDATE mysql.analyze_jobs SET start_time = CONVERT_TZ(%?, '+00:00', @@TIME_ZONE), state = %? WHERE id = %?" + _, _, err := statsutil.ExecRows(sctx, sql, job.StartTime.UTC().Format(types.TimeFormat), statistics.AnalyzeRunning, *job.ID) + if err != nil { + statslogutil.StatsLogger().Warn("failed to update analyze job", zap.String("update", fmt.Sprintf("%s->%s", statistics.AnalyzePending, statistics.AnalyzeRunning)), zap.Error(err)) + } + failpoint.Inject("DebugAnalyzeJobOperations", func(val failpoint.Value) { + if val.(bool) { + logutil.BgLogger().Info("StartAnalyzeJob", + zap.Time("start_time", job.StartTime), + zap.Uint64("job id", *job.ID), + ) + } + }) +} + +// updateAnalyzeJobProgress updates count of the processed rows when increment reaches a threshold. +func updateAnalyzeJobProgress(sctx sessionctx.Context, job *statistics.AnalyzeJob, rowCount int64) { + if job == nil || job.ID == nil { + return + } + delta := job.Progress.Update(rowCount) + if delta == 0 { + return + } + const sql = "UPDATE mysql.analyze_jobs SET processed_rows = processed_rows + %? WHERE id = %?" 
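+	// delta is non-zero only when Progress.Update decides enough rows have accumulated since the
+	// last dump, so the UPDATE below is skipped for most calls and only issued when the delta is flushed.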
+ _, _, err := statsutil.ExecRows(sctx, sql, delta, *job.ID) + if err != nil { + statslogutil.StatsLogger().Warn("failed to update analyze job", zap.String("update", fmt.Sprintf("process %v rows", delta)), zap.Error(err)) + } + failpoint.Inject("DebugAnalyzeJobOperations", func(val failpoint.Value) { + if val.(bool) { + logutil.BgLogger().Info("UpdateAnalyzeJobProgress", + zap.Int64("increase processed_rows", delta), + zap.Uint64("job id", *job.ID), + ) + } + }) +} + +// finishAnalyzeJob finishes an analyze or merge job +func finishAnalyzeJob(sctx sessionctx.Context, job *statistics.AnalyzeJob, analyzeErr error, analyzeType statistics.JobType) { + if job == nil || job.ID == nil { + return + } + + job.EndTime = time.Now() + var sql string + var args []any + + // process_id is used to see which process is running the analyze job and kill the analyze job. After the analyze job + // is finished(or failed), process_id is useless and we set it to NULL to avoid `kill tidb process_id` wrongly. + if analyzeErr != nil { + failReason := analyzeErr.Error() + const textMaxLength = 65535 + if len(failReason) > textMaxLength { + failReason = failReason[:textMaxLength] + } + + if analyzeType == statistics.TableAnalysisJob { + sql = "UPDATE mysql.analyze_jobs SET processed_rows = processed_rows + %?, end_time = CONVERT_TZ(%?, '+00:00', @@TIME_ZONE), state = %?, fail_reason = %?, process_id = NULL WHERE id = %?" + args = []any{job.Progress.GetDeltaCount(), job.EndTime.UTC().Format(types.TimeFormat), statistics.AnalyzeFailed, failReason, *job.ID} + } else { + sql = "UPDATE mysql.analyze_jobs SET end_time = CONVERT_TZ(%?, '+00:00', @@TIME_ZONE), state = %?, fail_reason = %?, process_id = NULL WHERE id = %?" + args = []any{job.EndTime.UTC().Format(types.TimeFormat), statistics.AnalyzeFailed, failReason, *job.ID} + } + } else { + if analyzeType == statistics.TableAnalysisJob { + sql = "UPDATE mysql.analyze_jobs SET processed_rows = processed_rows + %?, end_time = CONVERT_TZ(%?, '+00:00', @@TIME_ZONE), state = %?, process_id = NULL WHERE id = %?" + args = []any{job.Progress.GetDeltaCount(), job.EndTime.UTC().Format(types.TimeFormat), statistics.AnalyzeFinished, *job.ID} + } else { + sql = "UPDATE mysql.analyze_jobs SET end_time = CONVERT_TZ(%?, '+00:00', @@TIME_ZONE), state = %?, process_id = NULL WHERE id = %?" + args = []any{job.EndTime.UTC().Format(types.TimeFormat), statistics.AnalyzeFinished, *job.ID} + } + } + + _, _, err := statsutil.ExecRows(sctx, sql, args...) 
+ if err != nil { + state := statistics.AnalyzeFinished + if analyzeErr != nil { + state = statistics.AnalyzeFailed + } + logutil.BgLogger().Warn("failed to update analyze job", zap.String("update", fmt.Sprintf("%s->%s", statistics.AnalyzeRunning, state)), zap.Error(err)) + } + + failpoint.Inject("DebugAnalyzeJobOperations", func(val failpoint.Value) { + if val.(bool) { + logger := logutil.BgLogger().With( + zap.Time("end_time", job.EndTime), + zap.Uint64("job id", *job.ID), + ) + if analyzeType == statistics.TableAnalysisJob { + logger = logger.With(zap.Int64("increase processed_rows", job.Progress.GetDeltaCount())) + } + if analyzeErr != nil { + logger = logger.With(zap.Error(analyzeErr)) + } + logger.Info("FinishAnalyzeJob") + } + }) +} diff --git a/pkg/statistics/handle/autoanalyze/binding__failpoint_binding__.go b/pkg/statistics/handle/autoanalyze/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..258dfa09d159f --- /dev/null +++ b/pkg/statistics/handle/autoanalyze/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package autoanalyze + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/statistics/handle/binding__failpoint_binding__.go b/pkg/statistics/handle/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..947a787548ae9 --- /dev/null +++ b/pkg/statistics/handle/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package handle + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/statistics/handle/bootstrap.go b/pkg/statistics/handle/bootstrap.go index c2157f21eed39..24cd4233acc27 100644 --- a/pkg/statistics/handle/bootstrap.go +++ b/pkg/statistics/handle/bootstrap.go @@ -734,7 +734,7 @@ func (h *Handle) InitStatsLite(ctx context.Context, is infoschema.InfoSchema) (e if err != nil { return err } - failpoint.Inject("beforeInitStatsLite", func() {}) + failpoint.Eval(_curpkg_("beforeInitStatsLite")) cache, err := h.initStatsMeta(ctx, is) if err != nil { return errors.Trace(err) @@ -769,7 +769,7 @@ func (h *Handle) InitStats(ctx context.Context, is infoschema.InfoSchema) (err e if err != nil { return err } - failpoint.Inject("beforeInitStats", func() {}) + failpoint.Eval(_curpkg_("beforeInitStats")) cache, err := h.initStatsMeta(ctx, is) if err != nil { return errors.Trace(err) diff --git a/pkg/statistics/handle/bootstrap.go__failpoint_stash__ b/pkg/statistics/handle/bootstrap.go__failpoint_stash__ new file mode 100644 index 0000000000000..c2157f21eed39 --- /dev/null +++ b/pkg/statistics/handle/bootstrap.go__failpoint_stash__ @@ -0,0 +1,815 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package handle + +import ( + "context" + "sync" + "sync/atomic" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/statistics" + "github.com/pingcap/tidb/pkg/statistics/handle/cache" + "github.com/pingcap/tidb/pkg/statistics/handle/initstats" + statslogutil "github.com/pingcap/tidb/pkg/statistics/handle/logutil" + statstypes "github.com/pingcap/tidb/pkg/statistics/handle/types" + "github.com/pingcap/tidb/pkg/statistics/handle/util" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "go.uber.org/zap" +) + +// initStatsStep is the step to load stats by paging. +const initStatsStep = int64(500) + +var maxTidRecord MaxTidRecord + +// MaxTidRecord is to record the max tid. +type MaxTidRecord struct { + mu sync.Mutex + tid atomic.Int64 +} + +func (h *Handle) initStatsMeta4Chunk(ctx context.Context, is infoschema.InfoSchema, cache statstypes.StatsCache, iter *chunk.Iterator4Chunk) { + var physicalID, maxPhysicalID int64 + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + physicalID = row.GetInt64(1) + + // Detect the context cancel signal, since it may take a long time for the loop. + // TODO: add context to TableInfoByID and remove this code block? + if ctx.Err() != nil { + return + } + + // The table is read-only. Please do not modify it. + table, ok := h.TableInfoByID(is, physicalID) + if !ok { + logutil.BgLogger().Debug("unknown physical ID in stats meta table, maybe it has been dropped", zap.Int64("ID", physicalID)) + continue + } + maxPhysicalID = max(physicalID, maxPhysicalID) + tableInfo := table.Meta() + newHistColl := *statistics.NewHistColl(physicalID, true, row.GetInt64(3), row.GetInt64(2), 4, 4) + snapshot := row.GetUint64(4) + tbl := &statistics.Table{ + HistColl: newHistColl, + Version: row.GetUint64(0), + ColAndIdxExistenceMap: statistics.NewColAndIndexExistenceMap(len(tableInfo.Columns), len(tableInfo.Indices)), + IsPkIsHandle: tableInfo.PKIsHandle, + // During the initialization phase, we need to initialize LastAnalyzeVersion with the snapshot, + // which ensures that we don't duplicate the auto-analyze of a particular type of table. + // When the predicate columns feature is turned on, if a table has neither predicate columns nor indexes, + // then auto-analyze will only analyze the _row_id and refresh stats_meta, + // but since we don't have any histograms or topn's created for _row_id at the moment. + // So if we don't initialize LastAnalyzeVersion with the snapshot here, + // it will stay at 0 and auto-analyze won't be able to detect that the table has been analyzed. 
+ // But in the future, we maybe will create some records for _row_id, see: + // https://github.com/pingcap/tidb/issues/51098 + LastAnalyzeVersion: snapshot, + } + cache.Put(physicalID, tbl) // put this table again since it is updated + } + maxTidRecord.mu.Lock() + defer maxTidRecord.mu.Unlock() + if maxTidRecord.tid.Load() < maxPhysicalID { + maxTidRecord.tid.Store(physicalID) + } +} + +func (h *Handle) initStatsMeta(ctx context.Context, is infoschema.InfoSchema) (statstypes.StatsCache, error) { + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnStats) + sql := "select HIGH_PRIORITY version, table_id, modify_count, count, snapshot from mysql.stats_meta" + rc, err := util.Exec(h.initStatsCtx, sql) + if err != nil { + return nil, errors.Trace(err) + } + defer terror.Call(rc.Close) + tables, err := cache.NewStatsCacheImpl(h) + if err != nil { + return nil, err + } + req := rc.NewChunk(nil) + iter := chunk.NewIterator4Chunk(req) + for { + err := rc.Next(ctx, req) + if err != nil { + return nil, errors.Trace(err) + } + if req.NumRows() == 0 { + break + } + h.initStatsMeta4Chunk(ctx, is, tables, iter) + } + return tables, nil +} + +func (h *Handle) initStatsHistograms4ChunkLite(is infoschema.InfoSchema, cache statstypes.StatsCache, iter *chunk.Iterator4Chunk) { + var table *statistics.Table + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + tblID := row.GetInt64(0) + if table == nil || table.PhysicalID != tblID { + if table != nil { + cache.Put(table.PhysicalID, table) // put this table in the cache because all statstics of the table have been read. + } + var ok bool + table, ok = cache.Get(tblID) + if !ok { + continue + } + table = table.Copy() + } + isIndex := row.GetInt64(1) + id := row.GetInt64(2) + ndv := row.GetInt64(3) + nullCount := row.GetInt64(5) + statsVer := row.GetInt64(7) + tbl, _ := h.TableInfoByID(is, table.PhysicalID) + // All the objects in the table share the same stats version. + if statsVer != statistics.Version0 { + table.StatsVer = int(statsVer) + } + if isIndex > 0 { + var idxInfo *model.IndexInfo + for _, idx := range tbl.Meta().Indices { + if idx.ID == id { + idxInfo = idx + break + } + } + if idxInfo == nil { + continue + } + table.ColAndIdxExistenceMap.InsertIndex(idxInfo.ID, idxInfo, statsVer != statistics.Version0) + if statsVer != statistics.Version0 { + // The LastAnalyzeVersion is added by ALTER table so its value might be 0. + table.LastAnalyzeVersion = max(table.LastAnalyzeVersion, row.GetUint64(4)) + } + } else { + var colInfo *model.ColumnInfo + for _, col := range tbl.Meta().Columns { + if col.ID == id { + colInfo = col + break + } + } + if colInfo == nil { + continue + } + table.ColAndIdxExistenceMap.InsertCol(colInfo.ID, colInfo, statsVer != statistics.Version0 || ndv > 0 || nullCount > 0) + if statsVer != statistics.Version0 { + // The LastAnalyzeVersion is added by ALTER table so its value might be 0. + table.LastAnalyzeVersion = max(table.LastAnalyzeVersion, row.GetUint64(4)) + } + } + } + if table != nil { + cache.Put(table.PhysicalID, table) // put this table in the cache because all statstics of the table have been read. 
+ } +} + +func (h *Handle) initStatsHistograms4Chunk(is infoschema.InfoSchema, cache statstypes.StatsCache, iter *chunk.Iterator4Chunk, isCacheFull bool) { + var table *statistics.Table + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + tblID, statsVer := row.GetInt64(0), row.GetInt64(8) + if table == nil || table.PhysicalID != tblID { + if table != nil { + cache.Put(table.PhysicalID, table) // put this table in the cache because all statstics of the table have been read. + } + var ok bool + table, ok = cache.Get(tblID) + if !ok { + continue + } + table = table.Copy() + } + // All the objects in the table share the same stats version. + if statsVer != statistics.Version0 { + table.StatsVer = int(statsVer) + } + id, ndv, nullCount, version, totColSize := row.GetInt64(2), row.GetInt64(3), row.GetInt64(5), row.GetUint64(4), row.GetInt64(7) + lastAnalyzePos := row.GetDatum(11, types.NewFieldType(mysql.TypeBlob)) + tbl, _ := h.TableInfoByID(is, table.PhysicalID) + if row.GetInt64(1) > 0 { + var idxInfo *model.IndexInfo + for _, idx := range tbl.Meta().Indices { + if idx.ID == id { + idxInfo = idx + break + } + } + if idxInfo == nil { + continue + } + + var cms *statistics.CMSketch + var topN *statistics.TopN + var err error + if !isCacheFull { + // stats cache is full. we should not put it into cache. but we must set LastAnalyzeVersion + cms, topN, err = statistics.DecodeCMSketchAndTopN(row.GetBytes(6), nil) + if err != nil { + cms = nil + terror.Log(errors.Trace(err)) + } + } + hist := statistics.NewHistogram(id, ndv, nullCount, version, types.NewFieldType(mysql.TypeBlob), chunk.InitialCapacity, 0) + index := &statistics.Index{ + Histogram: *hist, + CMSketch: cms, + TopN: topN, + Info: idxInfo, + StatsVer: statsVer, + Flag: row.GetInt64(10), + PhysicalID: tblID, + } + if statsVer != statistics.Version0 { + // We first set the StatsLoadedStatus as AllEvicted. when completing to load bucket, we will set it as ALlLoad. + index.StatsLoadedStatus = statistics.NewStatsAllEvictedStatus() + // The LastAnalyzeVersion is added by ALTER table so its value might be 0. + table.LastAnalyzeVersion = max(table.LastAnalyzeVersion, version) + } + lastAnalyzePos.Copy(&index.LastAnalyzePos) + table.SetIdx(idxInfo.ID, index) + table.ColAndIdxExistenceMap.InsertIndex(idxInfo.ID, idxInfo, statsVer != statistics.Version0) + } else { + var colInfo *model.ColumnInfo + for _, col := range tbl.Meta().Columns { + if col.ID == id { + colInfo = col + break + } + } + if colInfo == nil { + continue + } + hist := statistics.NewHistogram(id, ndv, nullCount, version, &colInfo.FieldType, 0, totColSize) + hist.Correlation = row.GetFloat64(9) + col := &statistics.Column{ + Histogram: *hist, + PhysicalID: table.PhysicalID, + Info: colInfo, + IsHandle: tbl.Meta().PKIsHandle && mysql.HasPriKeyFlag(colInfo.GetFlag()), + Flag: row.GetInt64(10), + StatsVer: statsVer, + } + // primary key column has no stats info, because primary key's is_index is false. so it cannot load the topn + col.StatsLoadedStatus = statistics.NewStatsAllEvictedStatus() + lastAnalyzePos.Copy(&col.LastAnalyzePos) + table.SetCol(hist.ID, col) + table.ColAndIdxExistenceMap.InsertCol(colInfo.ID, colInfo, statsVer != statistics.Version0 || ndv > 0 || nullCount > 0) + if statsVer != statistics.Version0 { + // The LastAnalyzeVersion is added by ALTER table so its value might be 0. 
+ table.LastAnalyzeVersion = max(table.LastAnalyzeVersion, version) + } + } + } + if table != nil { + cache.Put(table.PhysicalID, table) // put this table in the cache because all statstics of the table have been read. + } +} + +func (h *Handle) initStatsHistogramsLite(ctx context.Context, is infoschema.InfoSchema, cache statstypes.StatsCache) error { + sql := "select /*+ ORDER_INDEX(mysql.stats_histograms,tbl)*/ HIGH_PRIORITY table_id, is_index, hist_id, distinct_count, version, null_count, tot_col_size, stats_ver, correlation, flag, last_analyze_pos from mysql.stats_histograms order by table_id" + rc, err := util.Exec(h.initStatsCtx, sql) + if err != nil { + return errors.Trace(err) + } + defer terror.Call(rc.Close) + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnStats) + req := rc.NewChunk(nil) + iter := chunk.NewIterator4Chunk(req) + for { + err := rc.Next(ctx, req) + if err != nil { + return errors.Trace(err) + } + if req.NumRows() == 0 { + break + } + h.initStatsHistograms4ChunkLite(is, cache, iter) + } + return nil +} + +func (h *Handle) initStatsHistograms(is infoschema.InfoSchema, cache statstypes.StatsCache) error { + sql := "select /*+ ORDER_INDEX(mysql.stats_histograms,tbl)*/ HIGH_PRIORITY table_id, is_index, hist_id, distinct_count, version, null_count, cm_sketch, tot_col_size, stats_ver, correlation, flag, last_analyze_pos from mysql.stats_histograms order by table_id" + rc, err := util.Exec(h.initStatsCtx, sql) + if err != nil { + return errors.Trace(err) + } + defer terror.Call(rc.Close) + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) + req := rc.NewChunk(nil) + iter := chunk.NewIterator4Chunk(req) + for { + err := rc.Next(ctx, req) + if err != nil { + return errors.Trace(err) + } + if req.NumRows() == 0 { + break + } + h.initStatsHistograms4Chunk(is, cache, iter, false) + } + return nil +} + +func (h *Handle) initStatsHistogramsByPaging(is infoschema.InfoSchema, cache statstypes.StatsCache, task initstats.Task, totalMemory uint64) error { + se, err := h.Pool.SPool().Get() + if err != nil { + return err + } + defer func() { + if err == nil { // only recycle when no error + h.Pool.SPool().Put(se) + } + }() + + sctx := se.(sessionctx.Context) + // Why do we need to add `is_index=1` in the SQL? + // because it is aligned to the `initStatsTopN` function, which only loads the topn of the index too. + // the other will be loaded by sync load. + sql := "select HIGH_PRIORITY table_id, is_index, hist_id, distinct_count, version, null_count, cm_sketch, tot_col_size, stats_ver, correlation, flag, last_analyze_pos from mysql.stats_histograms where table_id >= %? and table_id < %? 
and is_index=1" + rc, err := util.Exec(sctx, sql, task.StartTid, task.EndTid) + if err != nil { + return errors.Trace(err) + } + defer terror.Call(rc.Close) + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) + req := rc.NewChunk(nil) + iter := chunk.NewIterator4Chunk(req) + for { + err := rc.Next(ctx, req) + if err != nil { + return errors.Trace(err) + } + if req.NumRows() == 0 { + break + } + h.initStatsHistograms4Chunk(is, cache, iter, isFullCache(cache, totalMemory)) + } + return nil +} + +func (h *Handle) initStatsHistogramsConcurrency(is infoschema.InfoSchema, cache statstypes.StatsCache, totalMemory uint64) error { + var maxTid = maxTidRecord.tid.Load() + tid := int64(0) + ls := initstats.NewRangeWorker("histogram", func(task initstats.Task) error { + return h.initStatsHistogramsByPaging(is, cache, task, totalMemory) + }, uint64(maxTid), uint64(initStatsStep)) + ls.LoadStats() + for tid <= maxTid { + ls.SendTask(initstats.Task{ + StartTid: tid, + EndTid: tid + initStatsStep, + }) + tid += initStatsStep + } + ls.Wait() + return nil +} + +func (*Handle) initStatsTopN4Chunk(cache statstypes.StatsCache, iter *chunk.Iterator4Chunk, totalMemory uint64) { + if isFullCache(cache, totalMemory) { + return + } + affectedIndexes := make(map[*statistics.Index]struct{}) + var table *statistics.Table + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + tblID := row.GetInt64(0) + if table == nil || table.PhysicalID != tblID { + if table != nil { + cache.Put(table.PhysicalID, table) // put this table in the cache because all statstics of the table have been read. + } + var ok bool + table, ok = cache.Get(tblID) + if !ok { + continue + } + table = table.Copy() + } + idx := table.GetIdx(row.GetInt64(1)) + if idx == nil || (idx.CMSketch == nil && idx.StatsVer <= statistics.Version1) { + continue + } + if idx.TopN == nil { + idx.TopN = statistics.NewTopN(32) + } + affectedIndexes[idx] = struct{}{} + data := make([]byte, len(row.GetBytes(2))) + copy(data, row.GetBytes(2)) + idx.TopN.AppendTopN(data, row.GetUint64(3)) + } + if table != nil { + cache.Put(table.PhysicalID, table) // put this table in the cache because all statstics of the table have been read. + } + for idx := range affectedIndexes { + idx.TopN.Sort() + } +} + +func (h *Handle) initStatsTopN(cache statstypes.StatsCache, totalMemory uint64) error { + sql := "select /*+ ORDER_INDEX(mysql.stats_top_n,tbl)*/ HIGH_PRIORITY table_id, hist_id, value, count from mysql.stats_top_n where is_index = 1 order by table_id" + rc, err := util.Exec(h.initStatsCtx, sql) + if err != nil { + return errors.Trace(err) + } + defer terror.Call(rc.Close) + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) + req := rc.NewChunk(nil) + iter := chunk.NewIterator4Chunk(req) + for { + err := rc.Next(ctx, req) + if err != nil { + return errors.Trace(err) + } + if req.NumRows() == 0 { + break + } + h.initStatsTopN4Chunk(cache, iter, totalMemory) + } + return nil +} + +func (h *Handle) initStatsTopNByPaging(cache statstypes.StatsCache, task initstats.Task, totalMemory uint64) error { + se, err := h.Pool.SPool().Get() + if err != nil { + return err + } + defer func() { + if err == nil { // only recycle when no error + h.Pool.SPool().Put(se) + } + }() + sctx := se.(sessionctx.Context) + sql := "select HIGH_PRIORITY table_id, hist_id, value, count from mysql.stats_top_n where is_index = 1 and table_id >= %? and table_id < %? 
order by table_id" + rc, err := util.Exec(sctx, sql, task.StartTid, task.EndTid) + if err != nil { + return errors.Trace(err) + } + defer terror.Call(rc.Close) + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) + req := rc.NewChunk(nil) + iter := chunk.NewIterator4Chunk(req) + for { + err := rc.Next(ctx, req) + if err != nil { + return errors.Trace(err) + } + if req.NumRows() == 0 { + break + } + h.initStatsTopN4Chunk(cache, iter, totalMemory) + } + return nil +} + +func (h *Handle) initStatsTopNConcurrency(cache statstypes.StatsCache, totalMemory uint64) error { + if isFullCache(cache, totalMemory) { + return nil + } + var maxTid = maxTidRecord.tid.Load() + tid := int64(0) + ls := initstats.NewRangeWorker("TopN", func(task initstats.Task) error { + if isFullCache(cache, totalMemory) { + return nil + } + return h.initStatsTopNByPaging(cache, task, totalMemory) + }, uint64(maxTid), uint64(initStatsStep)) + ls.LoadStats() + for tid <= maxTid { + if isFullCache(cache, totalMemory) { + break + } + ls.SendTask(initstats.Task{ + StartTid: tid, + EndTid: tid + initStatsStep, + }) + tid += initStatsStep + } + ls.Wait() + return nil +} + +func (*Handle) initStatsFMSketch4Chunk(cache statstypes.StatsCache, iter *chunk.Iterator4Chunk) { + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + table, ok := cache.Get(row.GetInt64(0)) + if !ok { + continue + } + fms, err := statistics.DecodeFMSketch(row.GetBytes(3)) + if err != nil { + fms = nil + terror.Log(errors.Trace(err)) + } + + isIndex := row.GetInt64(1) + id := row.GetInt64(2) + if isIndex == 1 { + if idxStats := table.GetIdx(id); idxStats != nil { + idxStats.FMSketch = fms + } + } else { + if colStats := table.GetCol(id); colStats != nil { + colStats.FMSketch = fms + } + } + cache.Put(table.PhysicalID, table) // put this table in the cache because all statstics of the table have been read. + } +} + +func (h *Handle) initStatsFMSketch(cache statstypes.StatsCache) error { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) + sql := "select HIGH_PRIORITY table_id, is_index, hist_id, value from mysql.stats_fm_sketch" + rc, err := util.Exec(h.initStatsCtx, sql) + if err != nil { + return errors.Trace(err) + } + defer terror.Call(rc.Close) + req := rc.NewChunk(nil) + iter := chunk.NewIterator4Chunk(req) + for { + err := rc.Next(ctx, req) + if err != nil { + return errors.Trace(err) + } + if req.NumRows() == 0 { + break + } + h.initStatsFMSketch4Chunk(cache, iter) + } + return nil +} + +func (*Handle) initStatsBuckets4Chunk(cache statstypes.StatsCache, iter *chunk.Iterator4Chunk) { + var table *statistics.Table + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + tableID, isIndex, histID := row.GetInt64(0), row.GetInt64(1), row.GetInt64(2) + if table == nil || table.PhysicalID != tableID { + if table != nil { + table.SetAllIndexFullLoadForBootstrap() + cache.Put(table.PhysicalID, table) // put this table in the cache because all statstics of the table have been read. 
+ } + var ok bool + table, ok = cache.Get(tableID) + if !ok { + continue + } + table = table.Copy() + } + var lower, upper types.Datum + var hist *statistics.Histogram + if isIndex > 0 { + index := table.GetIdx(histID) + if index == nil { + continue + } + hist = &index.Histogram + lower, upper = types.NewBytesDatum(row.GetBytes(5)), types.NewBytesDatum(row.GetBytes(6)) + } else { + column := table.GetCol(histID) + if column == nil { + continue + } + if !mysql.HasPriKeyFlag(column.Info.GetFlag()) { + continue + } + hist = &column.Histogram + d := types.NewBytesDatum(row.GetBytes(5)) + var err error + lower, err = d.ConvertTo(statistics.UTCWithAllowInvalidDateCtx, &column.Info.FieldType) + if err != nil { + logutil.BgLogger().Debug("decode bucket lower bound failed", zap.Error(err)) + table.DelCol(histID) + continue + } + d = types.NewBytesDatum(row.GetBytes(6)) + upper, err = d.ConvertTo(statistics.UTCWithAllowInvalidDateCtx, &column.Info.FieldType) + if err != nil { + logutil.BgLogger().Debug("decode bucket upper bound failed", zap.Error(err)) + table.DelCol(histID) + continue + } + } + hist.AppendBucketWithNDV(&lower, &upper, row.GetInt64(3), row.GetInt64(4), row.GetInt64(7)) + } + if table != nil { + cache.Put(table.PhysicalID, table) // put this table in the cache because all statstics of the table have been read. + } +} + +func (h *Handle) initStatsBuckets(cache statstypes.StatsCache, totalMemory uint64) error { + if isFullCache(cache, totalMemory) { + return nil + } + if config.GetGlobalConfig().Performance.ConcurrentlyInitStats { + err := h.initStatsBucketsConcurrency(cache, totalMemory) + if err != nil { + return errors.Trace(err) + } + } else { + sql := "select /*+ ORDER_INDEX(mysql.stats_buckets,tbl)*/ HIGH_PRIORITY table_id, is_index, hist_id, count, repeats, lower_bound, upper_bound, ndv from mysql.stats_buckets order by table_id, is_index, hist_id, bucket_id" + rc, err := util.Exec(h.initStatsCtx, sql) + if err != nil { + return errors.Trace(err) + } + defer terror.Call(rc.Close) + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) + req := rc.NewChunk(nil) + iter := chunk.NewIterator4Chunk(req) + for { + err := rc.Next(ctx, req) + if err != nil { + return errors.Trace(err) + } + if req.NumRows() == 0 { + break + } + h.initStatsBuckets4Chunk(cache, iter) + } + } + tables := cache.Values() + for _, table := range tables { + table.CalcPreScalar() + cache.Put(table.PhysicalID, table) // put this table in the cache because all statstics of the table have been read. + } + return nil +} + +func (h *Handle) initStatsBucketsByPaging(cache statstypes.StatsCache, task initstats.Task) error { + se, err := h.Pool.SPool().Get() + if err != nil { + return err + } + defer func() { + if err == nil { // only recycle when no error + h.Pool.SPool().Put(se) + } + }() + sctx := se.(sessionctx.Context) + sql := "select HIGH_PRIORITY table_id, is_index, hist_id, count, repeats, lower_bound, upper_bound, ndv from mysql.stats_buckets where table_id >= %? and table_id < %? 
order by table_id, is_index, hist_id, bucket_id" + rc, err := util.Exec(sctx, sql, task.StartTid, task.EndTid) + if err != nil { + return errors.Trace(err) + } + defer terror.Call(rc.Close) + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) + req := rc.NewChunk(nil) + iter := chunk.NewIterator4Chunk(req) + for { + err := rc.Next(ctx, req) + if err != nil { + return errors.Trace(err) + } + if req.NumRows() == 0 { + break + } + h.initStatsBuckets4Chunk(cache, iter) + } + return nil +} + +func (h *Handle) initStatsBucketsConcurrency(cache statstypes.StatsCache, totalMemory uint64) error { + if isFullCache(cache, totalMemory) { + return nil + } + var maxTid = maxTidRecord.tid.Load() + tid := int64(0) + ls := initstats.NewRangeWorker("bucket", func(task initstats.Task) error { + if isFullCache(cache, totalMemory) { + return nil + } + return h.initStatsBucketsByPaging(cache, task) + }, uint64(maxTid), uint64(initStatsStep)) + ls.LoadStats() + for tid <= maxTid { + ls.SendTask(initstats.Task{ + StartTid: tid, + EndTid: tid + initStatsStep, + }) + tid += initStatsStep + if isFullCache(cache, totalMemory) { + break + } + } + ls.Wait() + return nil +} + +// InitStatsLite initiates the stats cache. The function is liter and faster than InitStats. +// 1. Basic stats meta data is loaded.(count, modify count, etc.) +// 2. Column/index stats are loaded. (only histogram) +// 3. TopN, Bucket, FMSketch are not loaded. +func (h *Handle) InitStatsLite(ctx context.Context, is infoschema.InfoSchema) (err error) { + defer func() { + _, err1 := util.Exec(h.initStatsCtx, "commit") + if err == nil && err1 != nil { + err = err1 + } + }() + _, err = util.Exec(h.initStatsCtx, "begin") + if err != nil { + return err + } + failpoint.Inject("beforeInitStatsLite", func() {}) + cache, err := h.initStatsMeta(ctx, is) + if err != nil { + return errors.Trace(err) + } + statslogutil.StatsLogger().Info("complete to load the meta in the lite mode") + err = h.initStatsHistogramsLite(ctx, is, cache) + if err != nil { + cache.Close() + return errors.Trace(err) + } + statslogutil.StatsLogger().Info("complete to load the histogram in the lite mode") + h.Replace(cache) + return nil +} + +// InitStats initiates the stats cache. +// 1. Basic stats meta data is loaded.(count, modify count, etc.) +// 2. Column/index stats are loaded. 
(histogram, topn, buckets, FMSketch) +func (h *Handle) InitStats(ctx context.Context, is infoschema.InfoSchema) (err error) { + totalMemory, err := memory.MemTotal() + if err != nil { + return err + } + loadFMSketch := config.GetGlobalConfig().Performance.EnableLoadFMSketch + defer func() { + _, err1 := util.Exec(h.initStatsCtx, "commit") + if err == nil && err1 != nil { + err = err1 + } + }() + _, err = util.Exec(h.initStatsCtx, "begin") + if err != nil { + return err + } + failpoint.Inject("beforeInitStats", func() {}) + cache, err := h.initStatsMeta(ctx, is) + if err != nil { + return errors.Trace(err) + } + statslogutil.StatsLogger().Info("complete to load the meta") + if config.GetGlobalConfig().Performance.ConcurrentlyInitStats { + err = h.initStatsHistogramsConcurrency(is, cache, totalMemory) + } else { + err = h.initStatsHistograms(is, cache) + } + statslogutil.StatsLogger().Info("complete to load the histogram") + if err != nil { + return errors.Trace(err) + } + if config.GetGlobalConfig().Performance.ConcurrentlyInitStats { + err = h.initStatsTopNConcurrency(cache, totalMemory) + } else { + err = h.initStatsTopN(cache, totalMemory) + } + statslogutil.StatsLogger().Info("complete to load the topn") + if err != nil { + return err + } + if loadFMSketch { + err = h.initStatsFMSketch(cache) + if err != nil { + return err + } + statslogutil.StatsLogger().Info("complete to load the FM Sketch") + } + err = h.initStatsBuckets(cache, totalMemory) + statslogutil.StatsLogger().Info("complete to load the bucket") + if err != nil { + return errors.Trace(err) + } + h.Replace(cache) + return nil +} + +func isFullCache(cache statstypes.StatsCache, total uint64) bool { + memQuota := variable.StatsCacheMemQuota.Load() + return (uint64(cache.MemConsumed()) >= total/4) || (cache.MemConsumed() >= memQuota && memQuota != 0) +} diff --git a/pkg/statistics/handle/cache/binding__failpoint_binding__.go b/pkg/statistics/handle/cache/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..01370bd22e2f3 --- /dev/null +++ b/pkg/statistics/handle/cache/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package cache + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/statistics/handle/cache/statscache.go b/pkg/statistics/handle/cache/statscache.go index 66444375ea6e9..8a5b10edb2581 100644 --- a/pkg/statistics/handle/cache/statscache.go +++ b/pkg/statistics/handle/cache/statscache.go @@ -223,9 +223,9 @@ func (s *StatsCacheImpl) MemConsumed() (size int64) { // Get returns the specified table's stats. func (s *StatsCacheImpl) Get(tableID int64) (*statistics.Table, bool) { - failpoint.Inject("StatsCacheGetNil", func() { - failpoint.Return(nil, false) - }) + if _, _err_ := failpoint.Eval(_curpkg_("StatsCacheGetNil")); _err_ == nil { + return nil, false + } return s.Load().Get(tableID) } diff --git a/pkg/statistics/handle/cache/statscache.go__failpoint_stash__ b/pkg/statistics/handle/cache/statscache.go__failpoint_stash__ new file mode 100644 index 0000000000000..66444375ea6e9 --- /dev/null +++ b/pkg/statistics/handle/cache/statscache.go__failpoint_stash__ @@ -0,0 +1,287 @@ +// Copyright 2023 PingCAP, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a Copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "context" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/infoschema" + tidbmetrics "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/statistics" + "github.com/pingcap/tidb/pkg/statistics/handle/cache/internal/metrics" + statslogutil "github.com/pingcap/tidb/pkg/statistics/handle/logutil" + handle_metrics "github.com/pingcap/tidb/pkg/statistics/handle/metrics" + "github.com/pingcap/tidb/pkg/statistics/handle/types" + "github.com/pingcap/tidb/pkg/statistics/handle/util" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.uber.org/zap" +) + +// StatsCacheImpl implements util.StatsCache. +type StatsCacheImpl struct { + atomic.Pointer[StatsCache] + + statsHandle types.StatsHandle +} + +// NewStatsCacheImpl creates a new StatsCache. +func NewStatsCacheImpl(statsHandle types.StatsHandle) (types.StatsCache, error) { + newCache, err := NewStatsCache() + if err != nil { + return nil, err + } + + result := &StatsCacheImpl{ + statsHandle: statsHandle, + } + result.Store(newCache) + + return result, nil +} + +// NewStatsCacheImplForTest creates a new StatsCache for test. +func NewStatsCacheImplForTest() (types.StatsCache, error) { + return NewStatsCacheImpl(nil) +} + +// Update reads stats meta from store and updates the stats map. +func (s *StatsCacheImpl) Update(ctx context.Context, is infoschema.InfoSchema) error { + start := time.Now() + lastVersion := s.getLastVersion() + var ( + rows []chunk.Row + err error + ) + if err := util.CallWithSCtx(s.statsHandle.SPool(), func(sctx sessionctx.Context) error { + rows, _, err = util.ExecRows( + sctx, + "SELECT version, table_id, modify_count, count, snapshot from mysql.stats_meta where version > %? order by version", + lastVersion, + ) + return err + }); err != nil { + return errors.Trace(err) + } + + tables := make([]*statistics.Table, 0, len(rows)) + deletedTableIDs := make([]int64, 0, len(rows)) + + for _, row := range rows { + version := row.GetUint64(0) + physicalID := row.GetInt64(1) + modifyCount := row.GetInt64(2) + count := row.GetInt64(3) + snapshot := row.GetUint64(4) + + // Detect the context cancel signal, since it may take a long time for the loop. + // TODO: add context to TableInfoByID and remove this code block? + if ctx.Err() != nil { + return ctx.Err() + } + + table, ok := s.statsHandle.TableInfoByID(is, physicalID) + if !ok { + logutil.BgLogger().Debug( + "unknown physical ID in stats meta table, maybe it has been dropped", + zap.Int64("ID", physicalID), + ) + deletedTableIDs = append(deletedTableIDs, physicalID) + continue + } + tableInfo := table.Meta() + // If the table is not updated, we can skip it. 
+ if oldTbl, ok := s.Get(physicalID); ok && + oldTbl.Version >= version && + tableInfo.UpdateTS == oldTbl.TblInfoUpdateTS { + continue + } + tbl, err := s.statsHandle.TableStatsFromStorage( + tableInfo, + physicalID, + false, + 0, + ) + // Error is not nil may mean that there are some ddl changes on this table, we will not update it. + if err != nil { + statslogutil.StatsLogger().Error( + "error occurred when read table stats", + zap.String("table", tableInfo.Name.O), + zap.Error(err), + ) + continue + } + if tbl == nil { + deletedTableIDs = append(deletedTableIDs, physicalID) + continue + } + tbl.Version = version + tbl.RealtimeCount = count + tbl.ModifyCount = modifyCount + tbl.TblInfoUpdateTS = tableInfo.UpdateTS + // It only occurs in the following situations: + // 1. The table has already been analyzed, + // but because the predicate columns feature is turned on, and it doesn't have any columns or indexes analyzed, + // it only analyzes _row_id and refreshes stats_meta, in which case the snapshot is not zero. + // 2. LastAnalyzeVersion is 0 because it has never been loaded. + // In this case, we can initialize LastAnalyzeVersion to the snapshot, + // otherwise auto-analyze will assume that the table has never been analyzed and try to analyze it again. + if tbl.LastAnalyzeVersion == 0 && snapshot != 0 { + tbl.LastAnalyzeVersion = snapshot + } + tables = append(tables, tbl) + } + + s.UpdateStatsCache(tables, deletedTableIDs) + dur := time.Since(start) + tidbmetrics.StatsDeltaLoadHistogram.Observe(dur.Seconds()) + return nil +} + +func (s *StatsCacheImpl) getLastVersion() uint64 { + // Get the greatest version of the stats meta table. + lastVersion := s.MaxTableStatsVersion() + // We need this because for two tables, the smaller version may write later than the one with larger version. + // Consider the case that there are two tables A and B, their version and commit time is (A0, A1) and (B0, B1), + // and A0 < B0 < B1 < A1. We will first read the stats of B, and update the lastVersion to B0, but we cannot read + // the table stats of A0 if we read stats that greater than lastVersion which is B0. + // We can read the stats if the diff between commit time and version is less than three lease. + offset := util.DurationToTS(3 * s.statsHandle.Lease()) + if s.MaxTableStatsVersion() >= offset { + lastVersion = lastVersion - offset + } else { + lastVersion = 0 + } + + return lastVersion +} + +// Replace replaces this cache. +func (s *StatsCacheImpl) Replace(cache types.StatsCache) { + x := cache.(*StatsCacheImpl) + s.replace(x.Load()) +} + +// replace replaces the cache with the new cache. +func (s *StatsCacheImpl) replace(newCache *StatsCache) { + old := s.Swap(newCache) + if old != nil { + old.Close() + } + metrics.CostGauge.Set(float64(newCache.Cost())) +} + +// UpdateStatsCache updates the cache with the new cache. +func (s *StatsCacheImpl) UpdateStatsCache(tables []*statistics.Table, deletedIDs []int64) { + if enableQuota := config.GetGlobalConfig().Performance.EnableStatsCacheMemQuota; enableQuota { + s.Load().Update(tables, deletedIDs) + } else { + // TODO: remove this branch because we will always enable quota. + newCache := s.Load().CopyAndUpdate(tables, deletedIDs) + s.replace(newCache) + } +} + +// Close closes this cache. +func (s *StatsCacheImpl) Close() { + s.Load().Close() +} + +// Clear clears this cache. +// Create a empty cache and replace the old one. 
+func (s *StatsCacheImpl) Clear() { + cache, err := NewStatsCache() + if err != nil { + logutil.BgLogger().Warn("create stats cache failed", zap.Error(err)) + return + } + s.replace(cache) +} + +// MemConsumed returns its memory usage. +func (s *StatsCacheImpl) MemConsumed() (size int64) { + return s.Load().Cost() +} + +// Get returns the specified table's stats. +func (s *StatsCacheImpl) Get(tableID int64) (*statistics.Table, bool) { + failpoint.Inject("StatsCacheGetNil", func() { + failpoint.Return(nil, false) + }) + return s.Load().Get(tableID) +} + +// Put puts this table stats into the cache. +func (s *StatsCacheImpl) Put(id int64, t *statistics.Table) { + s.Load().put(id, t) +} + +// MaxTableStatsVersion returns the version of the current cache, which is defined as +// the max table stats version the cache has in its lifecycle. +func (s *StatsCacheImpl) MaxTableStatsVersion() uint64 { + return s.Load().Version() +} + +// Values returns all values in this cache. +func (s *StatsCacheImpl) Values() []*statistics.Table { + return s.Load().Values() +} + +// Len returns the length of this cache. +func (s *StatsCacheImpl) Len() int { + return s.Load().Len() +} + +// SetStatsCacheCapacity sets the cache's capacity. +func (s *StatsCacheImpl) SetStatsCacheCapacity(c int64) { + s.Load().SetCapacity(c) +} + +// UpdateStatsHealthyMetrics updates stats healthy distribution metrics according to stats cache. +func (s *StatsCacheImpl) UpdateStatsHealthyMetrics() { + distribution := make([]int64, 5) + uneligibleAnalyze := 0 + for _, tbl := range s.Values() { + distribution[4]++ // total table count + isEligibleForAnalysis := tbl.IsEligibleForAnalysis() + if !isEligibleForAnalysis { + uneligibleAnalyze++ + continue + } + healthy, ok := tbl.GetStatsHealthy() + if !ok { + continue + } + if healthy < 50 { + distribution[0]++ + } else if healthy < 80 { + distribution[1]++ + } else if healthy < 100 { + distribution[2]++ + } else { + distribution[3]++ + } + } + for i, val := range distribution { + handle_metrics.StatsHealthyGauges[i].Set(float64(val)) + } + handle_metrics.StatsHealthyGauges[5].Set(float64(uneligibleAnalyze)) +} diff --git a/pkg/statistics/handle/globalstats/binding__failpoint_binding__.go b/pkg/statistics/handle/globalstats/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..d91f90cd70082 --- /dev/null +++ b/pkg/statistics/handle/globalstats/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package globalstats + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/statistics/handle/globalstats/global_stats_async.go b/pkg/statistics/handle/globalstats/global_stats_async.go index 2ef30c6cdba24..8b2dce81f247a 100644 --- a/pkg/statistics/handle/globalstats/global_stats_async.go +++ b/pkg/statistics/handle/globalstats/global_stats_async.go @@ -242,12 +242,12 @@ func (a *AsyncMergePartitionStats2GlobalStats) ioWorker(sctx sessionctx.Context, return err } close(a.cmsketch) - failpoint.Inject("PanicSameTime", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("PanicSameTime")); _err_ == nil { if val, _ := val.(bool); val { time.Sleep(1 * time.Second) panic("test for PanicSameTime") } - }) + } err = a.loadHistogramAndTopN(sctx, a.globalTableInfo, isIndex) if err != 
nil { close(a.ioWorkerExitWhenErrChan) @@ -285,12 +285,12 @@ func (a *AsyncMergePartitionStats2GlobalStats) cpuWorker(stmtCtx *stmtctx.Statem statslogutil.StatsLogger().Warn("dealCMSketch failed", zap.Error(err)) return err } - failpoint.Inject("PanicSameTime", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("PanicSameTime")); _err_ == nil { if val, _ := val.(bool); val { time.Sleep(1 * time.Second) panic("test for PanicSameTime") } - }) + } err = a.dealHistogramAndTopN(stmtCtx, sctx, opts, isIndex, tz, analyzeVersion) if err != nil { statslogutil.StatsLogger().Warn("dealHistogramAndTopN failed", zap.Error(err)) @@ -370,7 +370,7 @@ func (a *AsyncMergePartitionStats2GlobalStats) loadFmsketch(sctx sessionctx.Cont } func (a *AsyncMergePartitionStats2GlobalStats) loadCMsketch(sctx sessionctx.Context, isIndex bool) error { - failpoint.Inject("PanicInIOWorker", nil) + failpoint.Eval(_curpkg_("PanicInIOWorker")) for i := 0; i < a.globalStats.Num; i++ { for _, partitionID := range a.partitionIDs { _, ok := a.skipPartition[skipItem{ @@ -401,12 +401,12 @@ func (a *AsyncMergePartitionStats2GlobalStats) loadCMsketch(sctx sessionctx.Cont } func (a *AsyncMergePartitionStats2GlobalStats) loadHistogramAndTopN(sctx sessionctx.Context, tableInfo *model.TableInfo, isIndex bool) error { - failpoint.Inject("ErrorSameTime", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("ErrorSameTime")); _err_ == nil { if val, _ := val.(bool); val { time.Sleep(1 * time.Second) - failpoint.Return(errors.New("ErrorSameTime returned error")) + return errors.New("ErrorSameTime returned error") } - }) + } for i := 0; i < a.globalStats.Num; i++ { hists := make([]*statistics.Histogram, 0, a.partitionNum) topn := make([]*statistics.TopN, 0, a.partitionNum) @@ -442,7 +442,7 @@ func (a *AsyncMergePartitionStats2GlobalStats) loadHistogramAndTopN(sctx session } func (a *AsyncMergePartitionStats2GlobalStats) dealFMSketch() { - failpoint.Inject("PanicInCPUWorker", nil) + failpoint.Eval(_curpkg_("PanicInCPUWorker")) for { select { case fms, ok := <-a.fmsketch: @@ -461,11 +461,11 @@ func (a *AsyncMergePartitionStats2GlobalStats) dealFMSketch() { } func (a *AsyncMergePartitionStats2GlobalStats) dealCMSketch() error { - failpoint.Inject("dealCMSketchErr", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("dealCMSketchErr")); _err_ == nil { if val, _ := val.(bool); val { - failpoint.Return(errors.New("dealCMSketch returned error")) + return errors.New("dealCMSketch returned error") } - }) + } for { select { case cms, ok := <-a.cmsketch: @@ -487,17 +487,17 @@ func (a *AsyncMergePartitionStats2GlobalStats) dealCMSketch() error { } func (a *AsyncMergePartitionStats2GlobalStats) dealHistogramAndTopN(stmtCtx *stmtctx.StatementContext, sctx sessionctx.Context, opts map[ast.AnalyzeOptionType]uint64, isIndex bool, tz *time.Location, analyzeVersion int) (err error) { - failpoint.Inject("dealHistogramAndTopNErr", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("dealHistogramAndTopNErr")); _err_ == nil { if val, _ := val.(bool); val { - failpoint.Return(errors.New("dealHistogramAndTopNErr returned error")) + return errors.New("dealHistogramAndTopNErr returned error") } - }) - failpoint.Inject("ErrorSameTime", func(val failpoint.Value) { + } + if val, _err_ := failpoint.Eval(_curpkg_("ErrorSameTime")); _err_ == nil { if val, _ := val.(bool); val { time.Sleep(1 * time.Second) - failpoint.Return(errors.New("ErrorSameTime returned error")) + return 
errors.New("ErrorSameTime returned error") } - }) + } for { select { case item, ok := <-a.histogramAndTopn: diff --git a/pkg/statistics/handle/globalstats/global_stats_async.go__failpoint_stash__ b/pkg/statistics/handle/globalstats/global_stats_async.go__failpoint_stash__ new file mode 100644 index 0000000000000..2ef30c6cdba24 --- /dev/null +++ b/pkg/statistics/handle/globalstats/global_stats_async.go__failpoint_stash__ @@ -0,0 +1,542 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package globalstats + +import ( + "context" + stderrors "errors" + "fmt" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/infoschema" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/statistics" + statslogutil "github.com/pingcap/tidb/pkg/statistics/handle/logutil" + "github.com/pingcap/tidb/pkg/statistics/handle/storage" + statstypes "github.com/pingcap/tidb/pkg/statistics/handle/types" + "github.com/pingcap/tidb/pkg/statistics/handle/util" + "github.com/pingcap/tidb/pkg/types" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" +) + +type mergeItem[T any] struct { + item T + idx int +} + +type skipItem struct { + histID int64 + partitionID int64 +} + +// toSQLIndex is used to convert bool to int64. +func toSQLIndex(isIndex bool) int { + var index = int(0) + if isIndex { + index = 1 + } + return index +} + +// AsyncMergePartitionStats2GlobalStats is used to merge partition stats to global stats. +// it divides the merge task into two parts. +// - IOWorker: load stats from storage. it will load fmsketch, cmsketch, histogram and topn. and send them to cpuWorker. +// - CPUWorker: merge the stats from IOWorker and generate global stats. +// +// ┌────────────────────────┐ ┌───────────────────────┐ +// │ │ │ │ +// │ │ │ │ +// │ │ │ │ +// │ IOWorker │ │ CPUWorker │ +// │ │ ────► │ │ +// │ │ │ │ +// │ │ │ │ +// │ │ │ │ +// └────────────────────────┘ └───────────────────────┘ +type AsyncMergePartitionStats2GlobalStats struct { + is infoschema.InfoSchema + statsHandle statstypes.StatsHandle + globalStats *GlobalStats + cmsketch chan mergeItem[*statistics.CMSketch] + fmsketch chan mergeItem[*statistics.FMSketch] + histogramAndTopn chan mergeItem[*StatsWrapper] + allPartitionStats map[int64]*statistics.Table + PartitionDefinition map[int64]model.PartitionDefinition + tableInfo map[int64]*model.TableInfo + // key is partition id and histID + skipPartition map[skipItem]struct{} + // ioWorker meet error, it will close this channel to notify cpuWorker. + ioWorkerExitWhenErrChan chan struct{} + // cpuWorker exit, it will close this channel to notify ioWorker. 
+ cpuWorkerExitChan chan struct{} + globalTableInfo *model.TableInfo + histIDs []int64 + globalStatsNDV []int64 + partitionIDs []int64 + partitionNum int + skipMissingPartitionStats bool +} + +// NewAsyncMergePartitionStats2GlobalStats creates a new AsyncMergePartitionStats2GlobalStats. +func NewAsyncMergePartitionStats2GlobalStats( + statsHandle statstypes.StatsHandle, + globalTableInfo *model.TableInfo, + histIDs []int64, + is infoschema.InfoSchema) (*AsyncMergePartitionStats2GlobalStats, error) { + partitionNum := len(globalTableInfo.Partition.Definitions) + return &AsyncMergePartitionStats2GlobalStats{ + statsHandle: statsHandle, + cmsketch: make(chan mergeItem[*statistics.CMSketch], 5), + fmsketch: make(chan mergeItem[*statistics.FMSketch], 5), + histogramAndTopn: make(chan mergeItem[*StatsWrapper]), + PartitionDefinition: make(map[int64]model.PartitionDefinition), + tableInfo: make(map[int64]*model.TableInfo), + partitionIDs: make([]int64, 0, partitionNum), + ioWorkerExitWhenErrChan: make(chan struct{}), + cpuWorkerExitChan: make(chan struct{}), + skipPartition: make(map[skipItem]struct{}), + allPartitionStats: make(map[int64]*statistics.Table), + globalTableInfo: globalTableInfo, + histIDs: histIDs, + is: is, + partitionNum: partitionNum, + }, nil +} + +func (a *AsyncMergePartitionStats2GlobalStats) prepare(sctx sessionctx.Context, isIndex bool) (err error) { + if len(a.histIDs) == 0 { + for _, col := range a.globalTableInfo.Columns { + // The virtual generated column stats can not be merged to the global stats. + if col.IsVirtualGenerated() { + continue + } + a.histIDs = append(a.histIDs, col.ID) + } + } + a.globalStats = newGlobalStats(len(a.histIDs)) + a.globalStats.Num = len(a.histIDs) + a.globalStatsNDV = make([]int64, 0, a.globalStats.Num) + // get all partition stats + for _, def := range a.globalTableInfo.Partition.Definitions { + partitionID := def.ID + a.partitionIDs = append(a.partitionIDs, partitionID) + a.PartitionDefinition[partitionID] = def + partitionTable, ok := a.statsHandle.TableInfoByID(a.is, partitionID) + if !ok { + return errors.Errorf("unknown physical ID %d in stats meta table, maybe it has been dropped", partitionID) + } + tableInfo := partitionTable.Meta() + a.tableInfo[partitionID] = tableInfo + realtimeCount, modifyCount, isNull, err := storage.StatsMetaCountAndModifyCount(sctx, partitionID) + if err != nil { + return err + } + if !isNull { + // In a partition, we will only update globalStats.Count once. 
+ a.globalStats.Count += realtimeCount + a.globalStats.ModifyCount += modifyCount + } + err1 := skipPartition(sctx, partitionID, isIndex) + if err1 != nil { + // no idx so idx = 0 + err := a.dealWithSkipPartition(partitionID, isIndex, 0, err1) + if err != nil { + return err + } + if types.ErrPartitionStatsMissing.Equal(err1) { + continue + } + } + for idx, hist := range a.histIDs { + err1 := skipColumnPartition(sctx, partitionID, isIndex, hist) + if err1 != nil { + err := a.dealWithSkipPartition(partitionID, isIndex, idx, err1) + if err != nil { + return err + } + if types.ErrPartitionStatsMissing.Equal(err1) { + break + } + } + } + } + return nil +} + +func (a *AsyncMergePartitionStats2GlobalStats) dealWithSkipPartition(partitionID int64, isIndex bool, idx int, err error) error { + switch { + case types.ErrPartitionStatsMissing.Equal(err): + return a.dealErrPartitionStatsMissing(partitionID) + case types.ErrPartitionColumnStatsMissing.Equal(err): + return a.dealErrPartitionColumnStatsMissing(isIndex, partitionID, idx) + default: + return err + } +} + +func (a *AsyncMergePartitionStats2GlobalStats) dealErrPartitionStatsMissing(partitionID int64) error { + missingPart := fmt.Sprintf("partition `%s`", a.PartitionDefinition[partitionID].Name.L) + a.globalStats.MissingPartitionStats = append(a.globalStats.MissingPartitionStats, missingPart) + for _, histID := range a.histIDs { + a.skipPartition[skipItem{ + histID: histID, + partitionID: partitionID, + }] = struct{}{} + } + return nil +} + +func (a *AsyncMergePartitionStats2GlobalStats) dealErrPartitionColumnStatsMissing(isIndex bool, partitionID int64, idx int) error { + var missingPart string + if isIndex { + missingPart = fmt.Sprintf("partition `%s` index `%s`", a.PartitionDefinition[partitionID].Name.L, a.tableInfo[partitionID].FindIndexNameByID(a.histIDs[idx])) + } else { + missingPart = fmt.Sprintf("partition `%s` column `%s`", a.PartitionDefinition[partitionID].Name.L, a.tableInfo[partitionID].FindColumnNameByID(a.histIDs[idx])) + } + if !a.skipMissingPartitionStats { + return types.ErrPartitionColumnStatsMissing.GenWithStackByArgs(fmt.Sprintf("table `%s` %s", a.tableInfo[partitionID].Name.L, missingPart)) + } + a.globalStats.MissingPartitionStats = append(a.globalStats.MissingPartitionStats, missingPart) + a.skipPartition[skipItem{ + histID: a.histIDs[idx], + partitionID: partitionID, + }] = struct{}{} + return nil +} + +func (a *AsyncMergePartitionStats2GlobalStats) ioWorker(sctx sessionctx.Context, isIndex bool) (err error) { + defer func() { + if r := recover(); r != nil { + statslogutil.StatsLogger().Warn("ioWorker panic", zap.Stack("stack"), zap.Any("error", r)) + close(a.ioWorkerExitWhenErrChan) + err = errors.New(fmt.Sprint(r)) + } + }() + err = a.loadFmsketch(sctx, isIndex) + if err != nil { + close(a.ioWorkerExitWhenErrChan) + return err + } + close(a.fmsketch) + err = a.loadCMsketch(sctx, isIndex) + if err != nil { + close(a.ioWorkerExitWhenErrChan) + return err + } + close(a.cmsketch) + failpoint.Inject("PanicSameTime", func(val failpoint.Value) { + if val, _ := val.(bool); val { + time.Sleep(1 * time.Second) + panic("test for PanicSameTime") + } + }) + err = a.loadHistogramAndTopN(sctx, a.globalTableInfo, isIndex) + if err != nil { + close(a.ioWorkerExitWhenErrChan) + return err + } + close(a.histogramAndTopn) + return nil +} + +func (a *AsyncMergePartitionStats2GlobalStats) cpuWorker(stmtCtx *stmtctx.StatementContext, sctx sessionctx.Context, opts map[ast.AnalyzeOptionType]uint64, isIndex bool, tz *time.Location, 
analyzeVersion int) (err error) { + defer func() { + if r := recover(); r != nil { + statslogutil.StatsLogger().Warn("cpuWorker panic", zap.Stack("stack"), zap.Any("error", r)) + err = errors.New(fmt.Sprint(r)) + } + close(a.cpuWorkerExitChan) + }() + a.dealFMSketch() + select { + case <-a.ioWorkerExitWhenErrChan: + return nil + default: + for i := 0; i < a.globalStats.Num; i++ { + // Update the global NDV. + globalStatsNDV := a.globalStats.Fms[i].NDV() + if globalStatsNDV > a.globalStats.Count { + globalStatsNDV = a.globalStats.Count + } + a.globalStatsNDV = append(a.globalStatsNDV, globalStatsNDV) + a.globalStats.Fms[i].DestroyAndPutToPool() + } + } + err = a.dealCMSketch() + if err != nil { + statslogutil.StatsLogger().Warn("dealCMSketch failed", zap.Error(err)) + return err + } + failpoint.Inject("PanicSameTime", func(val failpoint.Value) { + if val, _ := val.(bool); val { + time.Sleep(1 * time.Second) + panic("test for PanicSameTime") + } + }) + err = a.dealHistogramAndTopN(stmtCtx, sctx, opts, isIndex, tz, analyzeVersion) + if err != nil { + statslogutil.StatsLogger().Warn("dealHistogramAndTopN failed", zap.Error(err)) + return err + } + return nil +} + +// Result returns the global stats. +func (a *AsyncMergePartitionStats2GlobalStats) Result() *GlobalStats { + return a.globalStats +} + +// MergePartitionStats2GlobalStats merges partition stats to global stats. +func (a *AsyncMergePartitionStats2GlobalStats) MergePartitionStats2GlobalStats( + sctx sessionctx.Context, + opts map[ast.AnalyzeOptionType]uint64, + isIndex bool, +) error { + a.skipMissingPartitionStats = sctx.GetSessionVars().SkipMissingPartitionStats + tz := sctx.GetSessionVars().StmtCtx.TimeZone() + analyzeVersion := sctx.GetSessionVars().AnalyzeVersion + stmtCtx := sctx.GetSessionVars().StmtCtx + return util.CallWithSCtx(a.statsHandle.SPool(), + func(sctx sessionctx.Context) error { + err := a.prepare(sctx, isIndex) + if err != nil { + return err + } + ctx := context.Background() + metawg, _ := errgroup.WithContext(ctx) + mergeWg, _ := errgroup.WithContext(ctx) + metawg.Go(func() error { + return a.ioWorker(sctx, isIndex) + }) + mergeWg.Go(func() error { + return a.cpuWorker(stmtCtx, sctx, opts, isIndex, tz, analyzeVersion) + }) + err = metawg.Wait() + if err != nil { + if err1 := mergeWg.Wait(); err1 != nil { + err = stderrors.Join(err, err1) + } + return err + } + return mergeWg.Wait() + }, + ) +} + +func (a *AsyncMergePartitionStats2GlobalStats) loadFmsketch(sctx sessionctx.Context, isIndex bool) error { + for i := 0; i < a.globalStats.Num; i++ { + // load fmsketch from tikv + for _, partitionID := range a.partitionIDs { + _, ok := a.skipPartition[skipItem{ + histID: a.histIDs[i], + partitionID: partitionID, + }] + if ok { + continue + } + fmsketch, err := storage.FMSketchFromStorage(sctx, partitionID, int64(toSQLIndex(isIndex)), a.histIDs[i]) + if err != nil { + return err + } + select { + case a.fmsketch <- mergeItem[*statistics.FMSketch]{ + fmsketch, i, + }: + case <-a.cpuWorkerExitChan: + statslogutil.StatsLogger().Warn("ioWorker detects CPUWorker has exited") + return nil + } + } + } + return nil +} + +func (a *AsyncMergePartitionStats2GlobalStats) loadCMsketch(sctx sessionctx.Context, isIndex bool) error { + failpoint.Inject("PanicInIOWorker", nil) + for i := 0; i < a.globalStats.Num; i++ { + for _, partitionID := range a.partitionIDs { + _, ok := a.skipPartition[skipItem{ + histID: a.histIDs[i], + partitionID: partitionID, + }] + if ok { + continue + } + cmsketch, err := storage.CMSketchFromStorage(sctx, 
partitionID, toSQLIndex(isIndex), a.histIDs[i]) + if err != nil { + return err + } + a.cmsketch <- mergeItem[*statistics.CMSketch]{ + cmsketch, i, + } + select { + case a.cmsketch <- mergeItem[*statistics.CMSketch]{ + cmsketch, i, + }: + case <-a.cpuWorkerExitChan: + statslogutil.StatsLogger().Warn("ioWorker detects CPUWorker has exited") + return nil + } + } + } + return nil +} + +func (a *AsyncMergePartitionStats2GlobalStats) loadHistogramAndTopN(sctx sessionctx.Context, tableInfo *model.TableInfo, isIndex bool) error { + failpoint.Inject("ErrorSameTime", func(val failpoint.Value) { + if val, _ := val.(bool); val { + time.Sleep(1 * time.Second) + failpoint.Return(errors.New("ErrorSameTime returned error")) + } + }) + for i := 0; i < a.globalStats.Num; i++ { + hists := make([]*statistics.Histogram, 0, a.partitionNum) + topn := make([]*statistics.TopN, 0, a.partitionNum) + for _, partitionID := range a.partitionIDs { + _, ok := a.skipPartition[skipItem{ + histID: a.histIDs[i], + partitionID: partitionID, + }] + if ok { + continue + } + h, err := storage.LoadHistogram(sctx, partitionID, toSQLIndex(isIndex), a.histIDs[i], tableInfo) + if err != nil { + return err + } + t, err := storage.TopNFromStorage(sctx, partitionID, toSQLIndex(isIndex), a.histIDs[i]) + if err != nil { + return err + } + hists = append(hists, h) + topn = append(topn, t) + } + select { + case a.histogramAndTopn <- mergeItem[*StatsWrapper]{ + NewStatsWrapper(hists, topn), i, + }: + case <-a.cpuWorkerExitChan: + statslogutil.StatsLogger().Warn("ioWorker detects CPUWorker has exited") + return nil + } + } + return nil +} + +func (a *AsyncMergePartitionStats2GlobalStats) dealFMSketch() { + failpoint.Inject("PanicInCPUWorker", nil) + for { + select { + case fms, ok := <-a.fmsketch: + if !ok { + return + } + if a.globalStats.Fms[fms.idx] == nil { + a.globalStats.Fms[fms.idx] = fms.item + } else { + a.globalStats.Fms[fms.idx].MergeFMSketch(fms.item) + } + case <-a.ioWorkerExitWhenErrChan: + return + } + } +} + +func (a *AsyncMergePartitionStats2GlobalStats) dealCMSketch() error { + failpoint.Inject("dealCMSketchErr", func(val failpoint.Value) { + if val, _ := val.(bool); val { + failpoint.Return(errors.New("dealCMSketch returned error")) + } + }) + for { + select { + case cms, ok := <-a.cmsketch: + if !ok { + return nil + } + if a.globalStats.Cms[cms.idx] == nil { + a.globalStats.Cms[cms.idx] = cms.item + } else { + err := a.globalStats.Cms[cms.idx].MergeCMSketch(cms.item) + if err != nil { + return err + } + } + case <-a.ioWorkerExitWhenErrChan: + return nil + } + } +} + +func (a *AsyncMergePartitionStats2GlobalStats) dealHistogramAndTopN(stmtCtx *stmtctx.StatementContext, sctx sessionctx.Context, opts map[ast.AnalyzeOptionType]uint64, isIndex bool, tz *time.Location, analyzeVersion int) (err error) { + failpoint.Inject("dealHistogramAndTopNErr", func(val failpoint.Value) { + if val, _ := val.(bool); val { + failpoint.Return(errors.New("dealHistogramAndTopNErr returned error")) + } + }) + failpoint.Inject("ErrorSameTime", func(val failpoint.Value) { + if val, _ := val.(bool); val { + time.Sleep(1 * time.Second) + failpoint.Return(errors.New("ErrorSameTime returned error")) + } + }) + for { + select { + case item, ok := <-a.histogramAndTopn: + if !ok { + return nil + } + var err error + var poppedTopN []statistics.TopNMeta + var allhg []*statistics.Histogram + wrapper := item.item + a.globalStats.TopN[item.idx], poppedTopN, allhg, err = mergeGlobalStatsTopN(a.statsHandle.GPool(), sctx, wrapper, + tz, analyzeVersion, 
uint32(opts[ast.AnalyzeOptNumTopN]), isIndex) + if err != nil { + return err + } + + // Merge histogram. + globalHg := &(a.globalStats.Hg[item.idx]) + *globalHg, err = statistics.MergePartitionHist2GlobalHist(stmtCtx, allhg, poppedTopN, + int64(opts[ast.AnalyzeOptNumBuckets]), isIndex) + if err != nil { + return err + } + + // NOTICE: after merging bucket NDVs have the trend to be underestimated, so for safe we don't use them. + for j := range (*globalHg).Buckets { + (*globalHg).Buckets[j].NDV = 0 + } + (*globalHg).NDV = a.globalStatsNDV[item.idx] + case <-a.ioWorkerExitWhenErrChan: + return nil + } + } +} + +func skipPartition(sctx sessionctx.Context, partitionID int64, isIndex bool) error { + return storage.CheckSkipPartition(sctx, partitionID, toSQLIndex(isIndex)) +} + +func skipColumnPartition(sctx sessionctx.Context, partitionID int64, isIndex bool, histsID int64) error { + return storage.CheckSkipColumnPartiion(sctx, partitionID, toSQLIndex(isIndex), histsID) +} diff --git a/pkg/statistics/handle/storage/binding__failpoint_binding__.go b/pkg/statistics/handle/storage/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..a1a747a15d57f --- /dev/null +++ b/pkg/statistics/handle/storage/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package storage + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/statistics/handle/storage/read.go b/pkg/statistics/handle/storage/read.go index 410f66b53c63a..c5aa829062d10 100644 --- a/pkg/statistics/handle/storage/read.go +++ b/pkg/statistics/handle/storage/read.go @@ -221,9 +221,9 @@ func CheckSkipColumnPartiion(sctx sessionctx.Context, tblID int64, isIndex int, // ExtendedStatsFromStorage reads extended stats from storage. func ExtendedStatsFromStorage(sctx sessionctx.Context, table *statistics.Table, tableID int64, loadAll bool) (*statistics.Table, error) { - failpoint.Inject("injectExtStatsLoadErr", func() { - failpoint.Return(nil, errors.New("gofail extendedStatsFromStorage error")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("injectExtStatsLoadErr")); _err_ == nil { + return nil, errors.New("gofail extendedStatsFromStorage error") + } lastVersion := uint64(0) if table.ExtendedStats != nil && !loadAll { lastVersion = table.ExtendedStats.LastUpdateVersion diff --git a/pkg/statistics/handle/storage/read.go__failpoint_stash__ b/pkg/statistics/handle/storage/read.go__failpoint_stash__ new file mode 100644 index 0000000000000..410f66b53c63a --- /dev/null +++ b/pkg/statistics/handle/storage/read.go__failpoint_stash__ @@ -0,0 +1,759 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
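// A minimal sketch of the failpoint-ctl rewrite that the hunks above apply throughout these
// statistics packages: each failpoint.Inject marker is replaced with a plain failpoint.Eval
// call, the original file is kept verbatim as a *__failpoint_stash__ copy, and the generated
// binding__failpoint_binding__.go supplies the _curpkg_ helper that prefixes a failpoint name
// with its package import path. Using the injectExtStatsLoadErr failpoint from read.go above:
//
//	// Marker form (preserved in read.go__failpoint_stash__):
//	failpoint.Inject("injectExtStatsLoadErr", func() {
//		failpoint.Return(nil, errors.New("gofail extendedStatsFromStorage error"))
//	})
//
//	// Rewritten form (compiled when failpoints are enabled):
//	if _, _err_ := failpoint.Eval(_curpkg_("injectExtStatsLoadErr")); _err_ == nil {
//		return nil, errors.New("gofail extendedStatsFromStorage error")
//	}
//
// A test could then activate it through the fully qualified name that _curpkg_ produces; the
// "return(true)" term below is only an assumed example value, since this failpoint ignores it:
//
//	name := "github.com/pingcap/tidb/pkg/statistics/handle/storage/injectExtStatsLoadErr"
//	if err := failpoint.Enable(name, "return(true)"); err == nil {
//		defer failpoint.Disable(name)
//	}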
+ +package storage + +import ( + "encoding/json" + "strconv" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/statistics" + "github.com/pingcap/tidb/pkg/statistics/asyncload" + statslogutil "github.com/pingcap/tidb/pkg/statistics/handle/logutil" + statstypes "github.com/pingcap/tidb/pkg/statistics/handle/types" + "github.com/pingcap/tidb/pkg/statistics/handle/util" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/sqlexec" + "go.uber.org/zap" +) + +// StatsMetaCountAndModifyCount reads count and modify_count for the given table from mysql.stats_meta. +func StatsMetaCountAndModifyCount(sctx sessionctx.Context, tableID int64) (count, modifyCount int64, isNull bool, err error) { + rows, _, err := util.ExecRows(sctx, "select count, modify_count from mysql.stats_meta where table_id = %?", tableID) + if err != nil { + return 0, 0, false, err + } + if len(rows) == 0 { + return 0, 0, true, nil + } + count = int64(rows[0].GetUint64(0)) + modifyCount = rows[0].GetInt64(1) + return count, modifyCount, false, nil +} + +// HistMetaFromStorageWithHighPriority reads the meta info of the histogram from the storage. +func HistMetaFromStorageWithHighPriority(sctx sessionctx.Context, item *model.TableItemID, possibleColInfo *model.ColumnInfo) (*statistics.Histogram, *types.Datum, int64, int64, error) { + isIndex := 0 + var tp *types.FieldType + if item.IsIndex { + isIndex = 1 + tp = types.NewFieldType(mysql.TypeBlob) + } else { + tp = &possibleColInfo.FieldType + } + rows, _, err := util.ExecRows(sctx, + "select high_priority distinct_count, version, null_count, tot_col_size, stats_ver, correlation, flag, last_analyze_pos from mysql.stats_histograms where table_id = %? and hist_id = %? and is_index = %?", + item.TableID, + item.ID, + isIndex, + ) + if err != nil { + return nil, nil, 0, 0, err + } + if len(rows) == 0 { + return nil, nil, 0, 0, nil + } + hist := statistics.NewHistogram(item.ID, rows[0].GetInt64(0), rows[0].GetInt64(2), rows[0].GetUint64(1), tp, chunk.InitialCapacity, rows[0].GetInt64(3)) + hist.Correlation = rows[0].GetFloat64(5) + lastPos := rows[0].GetDatum(7, types.NewFieldType(mysql.TypeBlob)) + return hist, &lastPos, rows[0].GetInt64(4), rows[0].GetInt64(6), nil +} + +// HistogramFromStorageWithPriority wraps the HistogramFromStorage with the given kv.Priority. +// Sync load and async load will use high priority to get data. +func HistogramFromStorageWithPriority( + sctx sessionctx.Context, + tableID int64, + colID int64, + tp *types.FieldType, + distinct int64, + isIndex int, + ver uint64, + nullCount int64, + totColSize int64, + corr float64, + priority int, +) (*statistics.Histogram, error) { + selectPrefix := "select " + switch priority { + case kv.PriorityHigh: + selectPrefix += "high_priority " + case kv.PriorityLow: + selectPrefix += "low_priority " + } + rows, fields, err := util.ExecRows(sctx, selectPrefix+"count, repeats, lower_bound, upper_bound, ndv from mysql.stats_buckets where table_id = %? and is_index = %? and hist_id = %? 
order by bucket_id", tableID, isIndex, colID) + if err != nil { + return nil, errors.Trace(err) + } + bucketSize := len(rows) + hg := statistics.NewHistogram(colID, distinct, nullCount, ver, tp, bucketSize, totColSize) + hg.Correlation = corr + totalCount := int64(0) + for i := 0; i < bucketSize; i++ { + count := rows[i].GetInt64(0) + repeats := rows[i].GetInt64(1) + var upperBound, lowerBound types.Datum + if isIndex == 1 { + lowerBound = rows[i].GetDatum(2, &fields[2].Column.FieldType) + upperBound = rows[i].GetDatum(3, &fields[3].Column.FieldType) + } else { + d := rows[i].GetDatum(2, &fields[2].Column.FieldType) + // For new collation data, when storing the bounds of the histogram, we store the collate key instead of the + // original value. + // But there's additional conversion logic for new collation data, and the collate key might be longer than + // the FieldType.flen. + // If we use the original FieldType here, there might be errors like "Invalid utf8mb4 character string" + // or "Data too long". + // So we change it to TypeBlob to bypass those logics here. + if tp.EvalType() == types.ETString && tp.GetType() != mysql.TypeEnum && tp.GetType() != mysql.TypeSet { + tp = types.NewFieldType(mysql.TypeBlob) + } + lowerBound, err = d.ConvertTo(statistics.UTCWithAllowInvalidDateCtx, tp) + if err != nil { + return nil, errors.Trace(err) + } + d = rows[i].GetDatum(3, &fields[3].Column.FieldType) + upperBound, err = d.ConvertTo(statistics.UTCWithAllowInvalidDateCtx, tp) + if err != nil { + return nil, errors.Trace(err) + } + } + totalCount += count + hg.AppendBucketWithNDV(&lowerBound, &upperBound, totalCount, repeats, rows[i].GetInt64(4)) + } + hg.PreCalculateScalar() + return hg, nil +} + +// CMSketchAndTopNFromStorageWithHighPriority reads CMSketch and TopN from storage. +func CMSketchAndTopNFromStorageWithHighPriority(sctx sessionctx.Context, tblID int64, isIndex, histID, statsVer int64) (_ *statistics.CMSketch, _ *statistics.TopN, err error) { + topNRows, _, err := util.ExecRows(sctx, "select HIGH_PRIORITY value, count from mysql.stats_top_n where table_id = %? and is_index = %? and hist_id = %?", tblID, isIndex, histID) + if err != nil { + return nil, nil, err + } + // If we are on version higher than 1. Don't read Count-Min Sketch. + if statsVer > statistics.Version1 { + return statistics.DecodeCMSketchAndTopN(nil, topNRows) + } + rows, _, err := util.ExecRows(sctx, "select cm_sketch from mysql.stats_histograms where table_id = %? and is_index = %? and hist_id = %?", tblID, isIndex, histID) + if err != nil { + return nil, nil, err + } + if len(rows) == 0 { + return statistics.DecodeCMSketchAndTopN(nil, topNRows) + } + return statistics.DecodeCMSketchAndTopN(rows[0].GetBytes(0), topNRows) +} + +// CMSketchFromStorage reads CMSketch from storage +func CMSketchFromStorage(sctx sessionctx.Context, tblID int64, isIndex int, histID int64) (_ *statistics.CMSketch, err error) { + rows, _, err := util.ExecRows(sctx, "select cm_sketch from mysql.stats_histograms where table_id = %? and is_index = %? and hist_id = %?", tblID, isIndex, histID) + if err != nil || len(rows) == 0 { + return nil, err + } + return statistics.DecodeCMSketch(rows[0].GetBytes(0)) +} + +// TopNFromStorage reads TopN from storage +func TopNFromStorage(sctx sessionctx.Context, tblID int64, isIndex int, histID int64) (_ *statistics.TopN, err error) { + rows, _, err := util.ExecRows(sctx, "select HIGH_PRIORITY value, count from mysql.stats_top_n where table_id = %? and is_index = %? 
and hist_id = %?", tblID, isIndex, histID) + if err != nil || len(rows) == 0 { + return nil, err + } + return statistics.DecodeTopN(rows), nil +} + +// FMSketchFromStorage reads FMSketch from storage +func FMSketchFromStorage(sctx sessionctx.Context, tblID int64, isIndex, histID int64) (_ *statistics.FMSketch, err error) { + rows, _, err := util.ExecRows(sctx, "select value from mysql.stats_fm_sketch where table_id = %? and is_index = %? and hist_id = %?", tblID, isIndex, histID) + if err != nil || len(rows) == 0 { + return nil, err + } + return statistics.DecodeFMSketch(rows[0].GetBytes(0)) +} + +// CheckSkipPartition checks if we can skip loading the partition. +func CheckSkipPartition(sctx sessionctx.Context, tblID int64, isIndex int) error { + rows, _, err := util.ExecRows(sctx, "select distinct_count from mysql.stats_histograms where table_id =%? and is_index = %?", tblID, isIndex) + if err != nil { + return err + } + if len(rows) == 0 { + return types.ErrPartitionStatsMissing + } + return nil +} + +// CheckSkipColumnPartiion checks if we can skip loading the partition. +func CheckSkipColumnPartiion(sctx sessionctx.Context, tblID int64, isIndex int, histsID int64) error { + rows, _, err := util.ExecRows(sctx, "select distinct_count from mysql.stats_histograms where table_id = %? and is_index = %? and hist_id = %?", tblID, isIndex, histsID) + if err != nil { + return err + } + if len(rows) == 0 { + return types.ErrPartitionColumnStatsMissing + } + return nil +} + +// ExtendedStatsFromStorage reads extended stats from storage. +func ExtendedStatsFromStorage(sctx sessionctx.Context, table *statistics.Table, tableID int64, loadAll bool) (*statistics.Table, error) { + failpoint.Inject("injectExtStatsLoadErr", func() { + failpoint.Return(nil, errors.New("gofail extendedStatsFromStorage error")) + }) + lastVersion := uint64(0) + if table.ExtendedStats != nil && !loadAll { + lastVersion = table.ExtendedStats.LastUpdateVersion + } else { + table.ExtendedStats = statistics.NewExtendedStatsColl() + } + rows, _, err := util.ExecRows(sctx, "select name, status, type, column_ids, stats, version from mysql.stats_extended where table_id = %? and status in (%?, %?, %?) 
and version > %?", + tableID, statistics.ExtendedStatsInited, statistics.ExtendedStatsAnalyzed, statistics.ExtendedStatsDeleted, lastVersion) + if err != nil || len(rows) == 0 { + return table, nil + } + for _, row := range rows { + lastVersion = max(lastVersion, row.GetUint64(5)) + name := row.GetString(0) + status := uint8(row.GetInt64(1)) + if status == statistics.ExtendedStatsDeleted || status == statistics.ExtendedStatsInited { + delete(table.ExtendedStats.Stats, name) + } else { + item := &statistics.ExtendedStatsItem{ + Tp: uint8(row.GetInt64(2)), + } + colIDs := row.GetString(3) + err := json.Unmarshal([]byte(colIDs), &item.ColIDs) + if err != nil { + statslogutil.StatsLogger().Error("decode column IDs failed", zap.String("column_ids", colIDs), zap.Error(err)) + return nil, err + } + statsStr := row.GetString(4) + if item.Tp == ast.StatsTypeCardinality || item.Tp == ast.StatsTypeCorrelation { + if statsStr != "" { + item.ScalarVals, err = strconv.ParseFloat(statsStr, 64) + if err != nil { + statslogutil.StatsLogger().Error("parse scalar stats failed", zap.String("stats", statsStr), zap.Error(err)) + return nil, err + } + } + } else { + item.StringVals = statsStr + } + table.ExtendedStats.Stats[name] = item + } + } + table.ExtendedStats.LastUpdateVersion = lastVersion + return table, nil +} + +func indexStatsFromStorage(sctx sessionctx.Context, row chunk.Row, table *statistics.Table, tableInfo *model.TableInfo, loadAll bool, lease time.Duration, tracker *memory.Tracker) error { + histID := row.GetInt64(2) + distinct := row.GetInt64(3) + histVer := row.GetUint64(4) + nullCount := row.GetInt64(5) + statsVer := row.GetInt64(7) + idx := table.GetIdx(histID) + flag := row.GetInt64(8) + lastAnalyzePos := row.GetDatum(10, types.NewFieldType(mysql.TypeBlob)) + + for _, idxInfo := range tableInfo.Indices { + if histID != idxInfo.ID { + continue + } + table.ColAndIdxExistenceMap.InsertIndex(idxInfo.ID, idxInfo, statsVer != statistics.Version0) + // All the objects in the table shares the same stats version. + // Update here. + if statsVer != statistics.Version0 { + table.StatsVer = int(statsVer) + table.LastAnalyzeVersion = max(table.LastAnalyzeVersion, histVer) + } + // We will not load buckets, topn and cmsketch if: + // 1. lease > 0, and: + // 2. the index doesn't have any of buckets, topn, cmsketch in memory before, and: + // 3. loadAll is false. + // 4. lite-init-stats is true(remove the condition when lite init stats is GA). + notNeedLoad := lease > 0 && + (idx == nil || ((!idx.IsStatsInitialized() || idx.IsAllEvicted()) && idx.LastUpdateVersion < histVer)) && + !loadAll && + config.GetGlobalConfig().Performance.LiteInitStats + if notNeedLoad { + // If we don't have this index in memory, skip it. 
+ if idx == nil { + return nil + } + idx = &statistics.Index{ + Histogram: *statistics.NewHistogram(histID, distinct, nullCount, histVer, types.NewFieldType(mysql.TypeBlob), 0, 0), + StatsVer: statsVer, + Info: idxInfo, + Flag: flag, + PhysicalID: table.PhysicalID, + } + if idx.IsAnalyzed() { + idx.StatsLoadedStatus = statistics.NewStatsAllEvictedStatus() + } + lastAnalyzePos.Copy(&idx.LastAnalyzePos) + break + } + if idx == nil || idx.LastUpdateVersion < histVer || loadAll { + hg, err := HistogramFromStorageWithPriority(sctx, table.PhysicalID, histID, types.NewFieldType(mysql.TypeBlob), distinct, 1, histVer, nullCount, 0, 0, kv.PriorityNormal) + if err != nil { + return errors.Trace(err) + } + cms, topN, err := CMSketchAndTopNFromStorageWithHighPriority(sctx, table.PhysicalID, 1, idxInfo.ID, statsVer) + if err != nil { + return errors.Trace(err) + } + var fmSketch *statistics.FMSketch + if loadAll { + // FMSketch is only used when merging partition stats into global stats. When merging partition stats into global stats, + // we load all the statistics, i.e., loadAll is true. + fmSketch, err = FMSketchFromStorage(sctx, table.PhysicalID, 1, histID) + if err != nil { + return errors.Trace(err) + } + } + idx = &statistics.Index{ + Histogram: *hg, + CMSketch: cms, + TopN: topN, + FMSketch: fmSketch, + Info: idxInfo, + StatsVer: statsVer, + Flag: flag, + PhysicalID: table.PhysicalID, + } + if statsVer != statistics.Version0 { + idx.StatsLoadedStatus = statistics.NewStatsFullLoadStatus() + } + lastAnalyzePos.Copy(&idx.LastAnalyzePos) + } + break + } + if idx != nil { + if tracker != nil { + tracker.Consume(idx.MemoryUsage().TotalMemoryUsage()) + } + table.SetIdx(histID, idx) + } else { + logutil.BgLogger().Debug("we cannot find index id in table info. It may be deleted.", zap.Int64("indexID", histID), zap.String("table", tableInfo.Name.O)) + } + return nil +} + +func columnStatsFromStorage(sctx sessionctx.Context, row chunk.Row, table *statistics.Table, tableInfo *model.TableInfo, loadAll bool, lease time.Duration, tracker *memory.Tracker) error { + histID := row.GetInt64(2) + distinct := row.GetInt64(3) + histVer := row.GetUint64(4) + nullCount := row.GetInt64(5) + totColSize := row.GetInt64(6) + statsVer := row.GetInt64(7) + correlation := row.GetFloat64(9) + lastAnalyzePos := row.GetDatum(10, types.NewFieldType(mysql.TypeBlob)) + col := table.GetCol(histID) + flag := row.GetInt64(8) + + for _, colInfo := range tableInfo.Columns { + if histID != colInfo.ID { + continue + } + table.ColAndIdxExistenceMap.InsertCol(histID, colInfo, statsVer != statistics.Version0 || distinct > 0 || nullCount > 0) + // All the objects in the table shares the same stats version. + // Update here. + if statsVer != statistics.Version0 { + table.StatsVer = int(statsVer) + table.LastAnalyzeVersion = max(table.LastAnalyzeVersion, histVer) + } + isHandle := tableInfo.PKIsHandle && mysql.HasPriKeyFlag(colInfo.GetFlag()) + // We will not load buckets, topn and cmsketch if: + // 1. lease > 0, and: + // 2. this column is not handle or lite-init-stats is true(remove the condition when lite init stats is GA), and: + // 3. the column doesn't have any of buckets, topn, cmsketch in memory before, and: + // 4. loadAll is false. + // + // Here is the explanation of the condition `!col.IsStatsInitialized() || col.IsAllEvicted()`. + // For one column: + // 1. If there is no stats for it in the storage(i.e., analyze has never been executed before), then its stats status + // would be `!col.IsStatsInitialized()`. 
In this case we should go the `notNeedLoad` path. + // 2. If there exists stats for it in the storage but its stats status is `col.IsAllEvicted()`, there are two + // sub cases for this case. One is that the column stats have never been used/needed by the optimizer so they have + // never been loaded. The other is that the column stats were loaded and then evicted. For the both sub cases, + // we should go the `notNeedLoad` path. + // 3. If some parts(Histogram/TopN/CMSketch) of stats for it exist in TiDB memory currently, we choose to load all of + // its new stats once we find stats version is updated. + notNeedLoad := lease > 0 && + (!isHandle || config.GetGlobalConfig().Performance.LiteInitStats) && + (col == nil || ((!col.IsStatsInitialized() || col.IsAllEvicted()) && col.LastUpdateVersion < histVer)) && + !loadAll + if notNeedLoad { + // If we don't have the column in memory currently, just skip it. + if col == nil { + return nil + } + col = &statistics.Column{ + PhysicalID: table.PhysicalID, + Histogram: *statistics.NewHistogram(histID, distinct, nullCount, histVer, &colInfo.FieldType, 0, totColSize), + Info: colInfo, + IsHandle: tableInfo.PKIsHandle && mysql.HasPriKeyFlag(colInfo.GetFlag()), + Flag: flag, + StatsVer: statsVer, + } + if col.StatsAvailable() { + col.StatsLoadedStatus = statistics.NewStatsAllEvictedStatus() + } + lastAnalyzePos.Copy(&col.LastAnalyzePos) + col.Histogram.Correlation = correlation + break + } + if col == nil || col.LastUpdateVersion < histVer || loadAll { + hg, err := HistogramFromStorageWithPriority(sctx, table.PhysicalID, histID, &colInfo.FieldType, distinct, 0, histVer, nullCount, totColSize, correlation, kv.PriorityNormal) + if err != nil { + return errors.Trace(err) + } + cms, topN, err := CMSketchAndTopNFromStorageWithHighPriority(sctx, table.PhysicalID, 0, colInfo.ID, statsVer) + if err != nil { + return errors.Trace(err) + } + var fmSketch *statistics.FMSketch + if loadAll { + // FMSketch is only used when merging partition stats into global stats. When merging partition stats into global stats, + // we load all the statistics, i.e., loadAll is true. + fmSketch, err = FMSketchFromStorage(sctx, table.PhysicalID, 0, histID) + if err != nil { + return errors.Trace(err) + } + } + col = &statistics.Column{ + PhysicalID: table.PhysicalID, + Histogram: *hg, + Info: colInfo, + CMSketch: cms, + TopN: topN, + FMSketch: fmSketch, + IsHandle: tableInfo.PKIsHandle && mysql.HasPriKeyFlag(colInfo.GetFlag()), + Flag: flag, + StatsVer: statsVer, + } + if col.StatsAvailable() { + col.StatsLoadedStatus = statistics.NewStatsFullLoadStatus() + } + lastAnalyzePos.Copy(&col.LastAnalyzePos) + break + } + if col.TotColSize != totColSize { + newCol := *col + newCol.TotColSize = totColSize + col = &newCol + } + break + } + if col != nil { + if tracker != nil { + tracker.Consume(col.MemoryUsage().TotalMemoryUsage()) + } + table.SetCol(col.ID, col) + } else { + // If we didn't find a Column or Index in tableInfo, we won't load the histogram for it. + // But don't worry, next lease the ddl will be updated, and we will load a same table for two times to + // avoid error. + logutil.BgLogger().Debug("we cannot find column in table info now. It may be deleted", zap.Int64("colID", histID), zap.String("table", tableInfo.Name.O)) + } + return nil +} + +// TableStatsFromStorage loads table stats info from storage. 
+func TableStatsFromStorage(sctx sessionctx.Context, snapshot uint64, tableInfo *model.TableInfo, tableID int64, loadAll bool, lease time.Duration, table *statistics.Table) (_ *statistics.Table, err error) { + tracker := memory.NewTracker(memory.LabelForAnalyzeMemory, -1) + tracker.AttachTo(sctx.GetSessionVars().MemTracker) + defer tracker.Detach() + // If table stats is pseudo, we also need to copy it, since we will use the column stats when + // the average error rate of it is small. + if table == nil || snapshot > 0 { + histColl := *statistics.NewHistColl(tableID, true, 0, 0, 4, 4) + table = &statistics.Table{ + HistColl: histColl, + ColAndIdxExistenceMap: statistics.NewColAndIndexExistenceMap(len(tableInfo.Columns), len(tableInfo.Indices)), + } + } else { + // We copy it before writing to avoid race. + table = table.Copy() + } + table.Pseudo = false + + realtimeCount, modidyCount, isNull, err := StatsMetaCountAndModifyCount(sctx, tableID) + if err != nil || isNull { + return nil, err + } + table.ModifyCount = modidyCount + table.RealtimeCount = realtimeCount + + rows, _, err := util.ExecRows(sctx, "select table_id, is_index, hist_id, distinct_count, version, null_count, tot_col_size, stats_ver, flag, correlation, last_analyze_pos from mysql.stats_histograms where table_id = %?", tableID) + // Check deleted table. + if err != nil || len(rows) == 0 { + return nil, nil + } + for _, row := range rows { + if err := sctx.GetSessionVars().SQLKiller.HandleSignal(); err != nil { + return nil, err + } + if row.GetInt64(1) > 0 { + err = indexStatsFromStorage(sctx, row, table, tableInfo, loadAll, lease, tracker) + } else { + err = columnStatsFromStorage(sctx, row, table, tableInfo, loadAll, lease, tracker) + } + if err != nil { + return nil, err + } + } + return ExtendedStatsFromStorage(sctx, table, tableID, loadAll) +} + +// LoadHistogram will load histogram from storage. +func LoadHistogram(sctx sessionctx.Context, tableID int64, isIndex int, histID int64, tableInfo *model.TableInfo) (*statistics.Histogram, error) { + row, _, err := util.ExecRows(sctx, "select distinct_count, version, null_count, tot_col_size, stats_ver, flag, correlation, last_analyze_pos from mysql.stats_histograms where table_id = %? and is_index = %? and hist_id = %?", tableID, isIndex, histID) + if err != nil || len(row) == 0 { + return nil, err + } + distinct := row[0].GetInt64(0) + histVer := row[0].GetUint64(1) + nullCount := row[0].GetInt64(2) + var totColSize int64 + var corr float64 + var tp types.FieldType + if isIndex == 0 { + totColSize = row[0].GetInt64(3) + corr = row[0].GetFloat64(6) + for _, colInfo := range tableInfo.Columns { + if histID != colInfo.ID { + continue + } + tp = colInfo.FieldType + break + } + return HistogramFromStorageWithPriority(sctx, tableID, histID, &tp, distinct, isIndex, histVer, nullCount, totColSize, corr, kv.PriorityNormal) + } + return HistogramFromStorageWithPriority(sctx, tableID, histID, types.NewFieldType(mysql.TypeBlob), distinct, isIndex, histVer, nullCount, 0, 0, kv.PriorityNormal) +} + +// LoadNeededHistograms will load histograms for those needed columns/indices. +func LoadNeededHistograms(sctx sessionctx.Context, statsCache statstypes.StatsCache, loadFMSketch bool) (err error) { + items := asyncload.AsyncLoadHistogramNeededItems.AllItems() + for _, item := range items { + if !item.IsIndex { + err = loadNeededColumnHistograms(sctx, statsCache, item.TableItemID, loadFMSketch, item.FullLoad) + } else { + // Index is always full load. 
+ err = loadNeededIndexHistograms(sctx, statsCache, item.TableItemID, loadFMSketch) + } + if err != nil { + return err + } + } + return nil +} + +// CleanFakeItemsForShowHistInFlights cleans the invalid inserted items. +func CleanFakeItemsForShowHistInFlights(statsCache statstypes.StatsCache) int { + items := asyncload.AsyncLoadHistogramNeededItems.AllItems() + reallyNeeded := 0 + for _, item := range items { + tbl, ok := statsCache.Get(item.TableID) + if !ok { + asyncload.AsyncLoadHistogramNeededItems.Delete(item.TableItemID) + continue + } + loadNeeded := false + if item.IsIndex { + _, loadNeeded = tbl.IndexIsLoadNeeded(item.ID) + } else { + var analyzed bool + _, loadNeeded, analyzed = tbl.ColumnIsLoadNeeded(item.ID, item.FullLoad) + loadNeeded = loadNeeded && analyzed + } + if !loadNeeded { + asyncload.AsyncLoadHistogramNeededItems.Delete(item.TableItemID) + continue + } + reallyNeeded++ + } + return reallyNeeded +} + +func loadNeededColumnHistograms(sctx sessionctx.Context, statsCache statstypes.StatsCache, col model.TableItemID, loadFMSketch bool, fullLoad bool) (err error) { + tbl, ok := statsCache.Get(col.TableID) + if !ok { + return nil + } + var colInfo *model.ColumnInfo + _, loadNeeded, analyzed := tbl.ColumnIsLoadNeeded(col.ID, true) + if !loadNeeded || !analyzed { + asyncload.AsyncLoadHistogramNeededItems.Delete(col) + return nil + } + colInfo = tbl.ColAndIdxExistenceMap.GetCol(col.ID) + hg, _, statsVer, _, err := HistMetaFromStorageWithHighPriority(sctx, &col, colInfo) + if hg == nil || err != nil { + asyncload.AsyncLoadHistogramNeededItems.Delete(col) + return err + } + var ( + cms *statistics.CMSketch + topN *statistics.TopN + fms *statistics.FMSketch + ) + if fullLoad { + hg, err = HistogramFromStorageWithPriority(sctx, col.TableID, col.ID, &colInfo.FieldType, hg.NDV, 0, hg.LastUpdateVersion, hg.NullCount, hg.TotColSize, hg.Correlation, kv.PriorityHigh) + if err != nil { + return errors.Trace(err) + } + cms, topN, err = CMSketchAndTopNFromStorageWithHighPriority(sctx, col.TableID, 0, col.ID, statsVer) + if err != nil { + return errors.Trace(err) + } + if loadFMSketch { + fms, err = FMSketchFromStorage(sctx, col.TableID, 0, col.ID) + if err != nil { + return errors.Trace(err) + } + } + } + colHist := &statistics.Column{ + PhysicalID: col.TableID, + Histogram: *hg, + Info: colInfo, + CMSketch: cms, + TopN: topN, + FMSketch: fms, + IsHandle: tbl.IsPkIsHandle && mysql.HasPriKeyFlag(colInfo.GetFlag()), + StatsVer: statsVer, + } + // Reload the latest stats cache, otherwise the `updateStatsCache` may fail with high probability, because functions + // like `GetPartitionStats` called in `fmSketchFromStorage` would have modified the stats cache already. 
+ tbl, ok = statsCache.Get(col.TableID) + if !ok { + return nil + } + tbl = tbl.Copy() + if colHist.StatsAvailable() { + if fullLoad { + colHist.StatsLoadedStatus = statistics.NewStatsFullLoadStatus() + } else { + colHist.StatsLoadedStatus = statistics.NewStatsAllEvictedStatus() + } + tbl.LastAnalyzeVersion = max(tbl.LastAnalyzeVersion, colHist.LastUpdateVersion) + if statsVer != statistics.Version0 { + tbl.StatsVer = int(statsVer) + } + } + tbl.SetCol(col.ID, colHist) + statsCache.UpdateStatsCache([]*statistics.Table{tbl}, nil) + asyncload.AsyncLoadHistogramNeededItems.Delete(col) + if col.IsSyncLoadFailed { + logutil.BgLogger().Warn("Hist for column should already be loaded as sync but not found.", + zap.Int64("table_id", colHist.PhysicalID), + zap.Int64("column_id", colHist.Info.ID), + zap.String("column_name", colHist.Info.Name.O)) + } + return nil +} + +func loadNeededIndexHistograms(sctx sessionctx.Context, statsCache statstypes.StatsCache, idx model.TableItemID, loadFMSketch bool) (err error) { + tbl, ok := statsCache.Get(idx.TableID) + if !ok { + return nil + } + _, loadNeeded := tbl.IndexIsLoadNeeded(idx.ID) + if !loadNeeded { + asyncload.AsyncLoadHistogramNeededItems.Delete(idx) + return nil + } + hgMeta, lastAnalyzePos, statsVer, flag, err := HistMetaFromStorageWithHighPriority(sctx, &idx, nil) + if hgMeta == nil || err != nil { + asyncload.AsyncLoadHistogramNeededItems.Delete(idx) + return err + } + idxInfo := tbl.ColAndIdxExistenceMap.GetIndex(idx.ID) + hg, err := HistogramFromStorageWithPriority(sctx, idx.TableID, idx.ID, types.NewFieldType(mysql.TypeBlob), hgMeta.NDV, 1, hgMeta.LastUpdateVersion, hgMeta.NullCount, hgMeta.TotColSize, hgMeta.Correlation, kv.PriorityHigh) + if err != nil { + return errors.Trace(err) + } + cms, topN, err := CMSketchAndTopNFromStorageWithHighPriority(sctx, idx.TableID, 1, idx.ID, statsVer) + if err != nil { + return errors.Trace(err) + } + var fms *statistics.FMSketch + if loadFMSketch { + fms, err = FMSketchFromStorage(sctx, idx.TableID, 1, idx.ID) + if err != nil { + return errors.Trace(err) + } + } + idxHist := &statistics.Index{Histogram: *hg, CMSketch: cms, TopN: topN, FMSketch: fms, + Info: idxInfo, StatsVer: statsVer, + Flag: flag, PhysicalID: idx.TableID, + StatsLoadedStatus: statistics.NewStatsFullLoadStatus()} + lastAnalyzePos.Copy(&idxHist.LastAnalyzePos) + + tbl, ok = statsCache.Get(idx.TableID) + if !ok { + return nil + } + tbl = tbl.Copy() + if idxHist.StatsVer != statistics.Version0 { + tbl.StatsVer = int(idxHist.StatsVer) + } + tbl.SetIdx(idx.ID, idxHist) + tbl.LastAnalyzeVersion = max(tbl.LastAnalyzeVersion, idxHist.LastUpdateVersion) + statsCache.UpdateStatsCache([]*statistics.Table{tbl}, nil) + if idx.IsSyncLoadFailed { + logutil.BgLogger().Warn("Hist for index should already be loaded as sync but not found.", + zap.Int64("table_id", idx.TableID), + zap.Int64("index_id", idxHist.Info.ID), + zap.String("index_name", idxHist.Info.Name.O)) + } + asyncload.AsyncLoadHistogramNeededItems.Delete(idx) + return nil +} + +// StatsMetaByTableIDFromStorage gets the stats meta of a table from storage. +func StatsMetaByTableIDFromStorage(sctx sessionctx.Context, tableID int64, snapshot uint64) (version uint64, modifyCount, count int64, err error) { + var rows []chunk.Row + if snapshot == 0 { + rows, _, err = util.ExecRows(sctx, + "SELECT version, modify_count, count from mysql.stats_meta where table_id = %? 
order by version", tableID) + } else { + rows, _, err = util.ExecWithOpts(sctx, + []sqlexec.OptionFuncAlias{sqlexec.ExecOptionWithSnapshot(snapshot), sqlexec.ExecOptionUseCurSession}, + "SELECT version, modify_count, count from mysql.stats_meta where table_id = %? order by version", tableID) + } + if err != nil || len(rows) == 0 { + return + } + version = rows[0].GetUint64(0) + modifyCount = rows[0].GetInt64(1) + count = rows[0].GetInt64(2) + return +} diff --git a/pkg/statistics/handle/syncload/binding__failpoint_binding__.go b/pkg/statistics/handle/syncload/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..f2b453264a678 --- /dev/null +++ b/pkg/statistics/handle/syncload/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package syncload + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/statistics/handle/syncload/stats_syncload.go b/pkg/statistics/handle/syncload/stats_syncload.go index 65a234f71b474..e6972877bfe1d 100644 --- a/pkg/statistics/handle/syncload/stats_syncload.go +++ b/pkg/statistics/handle/syncload/stats_syncload.go @@ -83,14 +83,14 @@ type statsWrapper struct { func (s *statsSyncLoad) SendLoadRequests(sc *stmtctx.StatementContext, neededHistItems []model.StatsLoadItem, timeout time.Duration) error { remainedItems := s.removeHistLoadedColumns(neededHistItems) - failpoint.Inject("assertSyncLoadItems", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("assertSyncLoadItems")); _err_ == nil { if sc.OptimizeTracer != nil { count := val.(int) if len(remainedItems) != count { panic("remained items count wrong") } } - }) + } if len(remainedItems) <= 0 { return nil @@ -361,12 +361,12 @@ func (s *statsSyncLoad) handleOneItemTask(task *statstypes.NeededItemTask) (err // readStatsForOneItem reads hist for one column/index, TODO load data via kv-get asynchronously func (*statsSyncLoad) readStatsForOneItem(sctx sessionctx.Context, item model.TableItemID, w *statsWrapper, isPkIsHandle bool, fullLoad bool) (*statsWrapper, error) { - failpoint.Inject("mockReadStatsForOnePanic", nil) - failpoint.Inject("mockReadStatsForOneFail", func(val failpoint.Value) { + failpoint.Eval(_curpkg_("mockReadStatsForOnePanic")) + if val, _err_ := failpoint.Eval(_curpkg_("mockReadStatsForOneFail")); _err_ == nil { if val.(bool) { - failpoint.Return(nil, errors.New("gofail ReadStatsForOne error")) + return nil, errors.New("gofail ReadStatsForOne error") } - }) + } loadFMSketch := config.GetGlobalConfig().Performance.EnableLoadFMSketch var hg *statistics.Histogram var err error diff --git a/pkg/statistics/handle/syncload/stats_syncload.go__failpoint_stash__ b/pkg/statistics/handle/syncload/stats_syncload.go__failpoint_stash__ new file mode 100644 index 0000000000000..65a234f71b474 --- /dev/null +++ b/pkg/statistics/handle/syncload/stats_syncload.go__failpoint_stash__ @@ -0,0 +1,574 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package syncload + +import ( + "fmt" + "math/rand" + "runtime" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/statistics" + "github.com/pingcap/tidb/pkg/statistics/handle/storage" + statstypes "github.com/pingcap/tidb/pkg/statistics/handle/types" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.uber.org/zap" + "golang.org/x/sync/singleflight" +) + +// RetryCount is the max retry count for a sync load task. +const RetryCount = 3 + +// GetSyncLoadConcurrencyByCPU returns the concurrency of sync load by CPU. +func GetSyncLoadConcurrencyByCPU() int { + core := runtime.GOMAXPROCS(0) + if core <= 8 { + return 5 + } else if core <= 16 { + return 6 + } else if core <= 32 { + return 8 + } + return 10 +} + +type statsSyncLoad struct { + statsHandle statstypes.StatsHandle + StatsLoad statstypes.StatsLoad +} + +var globalStatsSyncLoadSingleFlight singleflight.Group + +// NewStatsSyncLoad creates a new StatsSyncLoad. 
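+// Both task channels are buffered with Performance.StatsLoadQueueSize entries
+// taken from the global config. A minimal, illustrative call sequence (the
+// handle, statement context and item slice below are placeholders, not fixed
+// names):
+//
+//	syncLoad := NewStatsSyncLoad(statsHandle)
+//	_ = syncLoad.SendLoadRequests(stmtCtx, neededItems, time.Second)
+//	_ = syncLoad.SyncWaitStatsLoad(stmtCtx)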
+func NewStatsSyncLoad(statsHandle statstypes.StatsHandle) statstypes.StatsSyncLoad { + s := &statsSyncLoad{statsHandle: statsHandle} + cfg := config.GetGlobalConfig() + s.StatsLoad.NeededItemsCh = make(chan *statstypes.NeededItemTask, cfg.Performance.StatsLoadQueueSize) + s.StatsLoad.TimeoutItemsCh = make(chan *statstypes.NeededItemTask, cfg.Performance.StatsLoadQueueSize) + return s +} + +type statsWrapper struct { + colInfo *model.ColumnInfo + idxInfo *model.IndexInfo + col *statistics.Column + idx *statistics.Index +} + +// SendLoadRequests send neededColumns requests +func (s *statsSyncLoad) SendLoadRequests(sc *stmtctx.StatementContext, neededHistItems []model.StatsLoadItem, timeout time.Duration) error { + remainedItems := s.removeHistLoadedColumns(neededHistItems) + + failpoint.Inject("assertSyncLoadItems", func(val failpoint.Value) { + if sc.OptimizeTracer != nil { + count := val.(int) + if len(remainedItems) != count { + panic("remained items count wrong") + } + } + }) + + if len(remainedItems) <= 0 { + return nil + } + sc.StatsLoad.Timeout = timeout + sc.StatsLoad.NeededItems = remainedItems + sc.StatsLoad.ResultCh = make([]<-chan singleflight.Result, 0, len(remainedItems)) + for _, item := range remainedItems { + localItem := item + resultCh := globalStatsSyncLoadSingleFlight.DoChan(localItem.Key(), func() (any, error) { + timer := time.NewTimer(timeout) + defer timer.Stop() + task := &statstypes.NeededItemTask{ + Item: localItem, + ToTimeout: time.Now().Local().Add(timeout), + ResultCh: make(chan stmtctx.StatsLoadResult, 1), + } + select { + case s.StatsLoad.NeededItemsCh <- task: + metrics.SyncLoadDedupCounter.Inc() + result, ok := <-task.ResultCh + intest.Assert(ok, "task.ResultCh cannot be closed") + return result, nil + case <-timer.C: + return nil, errors.New("sync load stats channel is full and timeout sending task to channel") + } + }) + sc.StatsLoad.ResultCh = append(sc.StatsLoad.ResultCh, resultCh) + } + sc.StatsLoad.LoadStartTime = time.Now() + return nil +} + +// SyncWaitStatsLoad sync waits loading of neededColumns and return false if timeout +func (*statsSyncLoad) SyncWaitStatsLoad(sc *stmtctx.StatementContext) error { + if len(sc.StatsLoad.NeededItems) <= 0 { + return nil + } + var errorMsgs []string + defer func() { + if len(errorMsgs) > 0 { + logutil.BgLogger().Warn("SyncWaitStatsLoad meets error", + zap.Strings("errors", errorMsgs)) + } + sc.StatsLoad.NeededItems = nil + }() + resultCheckMap := map[model.TableItemID]struct{}{} + for _, col := range sc.StatsLoad.NeededItems { + resultCheckMap[col.TableItemID] = struct{}{} + } + timer := time.NewTimer(sc.StatsLoad.Timeout) + defer timer.Stop() + for _, resultCh := range sc.StatsLoad.ResultCh { + select { + case result, ok := <-resultCh: + metrics.SyncLoadCounter.Inc() + if !ok { + return errors.New("sync load stats channel closed unexpectedly") + } + // this error is from statsSyncLoad.SendLoadRequests which start to task and send task into worker, + // not the stats loading error + if result.Err != nil { + errorMsgs = append(errorMsgs, result.Err.Error()) + } else { + val := result.Val.(stmtctx.StatsLoadResult) + // this error is from the stats loading error + if val.HasError() { + errorMsgs = append(errorMsgs, val.ErrorMsg()) + } + delete(resultCheckMap, val.Item) + } + case <-timer.C: + metrics.SyncLoadCounter.Inc() + metrics.SyncLoadTimeoutCounter.Inc() + return errors.New("sync load stats timeout") + } + } + if len(resultCheckMap) == 0 { + 
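+		// Only observe the wait latency when a result was received for every requested item.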
metrics.SyncLoadHistogram.Observe(float64(time.Since(sc.StatsLoad.LoadStartTime).Milliseconds())) + return nil + } + return nil +} + +// removeHistLoadedColumns removed having-hist columns based on neededColumns and statsCache. +func (s *statsSyncLoad) removeHistLoadedColumns(neededItems []model.StatsLoadItem) []model.StatsLoadItem { + remainedItems := make([]model.StatsLoadItem, 0, len(neededItems)) + for _, item := range neededItems { + tbl, ok := s.statsHandle.Get(item.TableID) + if !ok { + continue + } + if item.IsIndex { + _, loadNeeded := tbl.IndexIsLoadNeeded(item.ID) + if loadNeeded { + remainedItems = append(remainedItems, item) + } + continue + } + _, loadNeeded, _ := tbl.ColumnIsLoadNeeded(item.ID, item.FullLoad) + if loadNeeded { + remainedItems = append(remainedItems, item) + } + } + return remainedItems +} + +// AppendNeededItem appends needed columns/indices to ch, it is only used for test +func (s *statsSyncLoad) AppendNeededItem(task *statstypes.NeededItemTask, timeout time.Duration) error { + timer := time.NewTimer(timeout) + defer timer.Stop() + select { + case s.StatsLoad.NeededItemsCh <- task: + case <-timer.C: + return errors.New("Channel is full and timeout writing to channel") + } + return nil +} + +var errExit = errors.New("Stop loading since domain is closed") + +// SubLoadWorker loads hist data for each column +func (s *statsSyncLoad) SubLoadWorker(sctx sessionctx.Context, exit chan struct{}, exitWg *util.WaitGroupEnhancedWrapper) { + defer func() { + exitWg.Done() + logutil.BgLogger().Info("SubLoadWorker exited.") + }() + // if the last task is not successfully handled in last round for error or panic, pass it to this round to retry + var lastTask *statstypes.NeededItemTask + for { + task, err := s.HandleOneTask(sctx, lastTask, exit) + lastTask = task + if err != nil { + switch err { + case errExit: + return + default: + // To avoid the thundering herd effect + // thundering herd effect: Everyone tries to retry a large number of requests simultaneously when a problem occurs. + r := rand.Intn(500) + time.Sleep(s.statsHandle.Lease()/10 + time.Duration(r)*time.Microsecond) + continue + } + } + } +} + +// HandleOneTask handles last task if not nil, else handle a new task from chan, and return current task if fail somewhere. +// - If the task is handled successfully, return nil, nil. +// - If the task is timeout, return the task and nil. The caller should retry the timeout task without sleep. +// - If the task is failed, return the task, error. The caller should retry the timeout task with sleep. 
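+// A failed task is retried at most RetryCount times (tracked by isVaildForRetry);
+// after that its error is sent back on the task's ResultCh instead of being retried again.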
+func (s *statsSyncLoad) HandleOneTask(sctx sessionctx.Context, lastTask *statstypes.NeededItemTask, exit chan struct{}) (task *statstypes.NeededItemTask, err error) { + defer func() { + // recover for each task, worker keeps working + if r := recover(); r != nil { + logutil.BgLogger().Error("stats loading panicked", zap.Any("error", r), zap.Stack("stack")) + err = errors.Errorf("stats loading panicked: %v", r) + } + }() + if lastTask == nil { + task, err = s.drainColTask(sctx, exit) + if err != nil { + if err != errExit { + logutil.BgLogger().Error("Fail to drain task for stats loading.", zap.Error(err)) + } + return task, err + } + } else { + task = lastTask + } + result := stmtctx.StatsLoadResult{Item: task.Item.TableItemID} + err = s.handleOneItemTask(task) + if err == nil { + task.ResultCh <- result + return nil, nil + } + if !isVaildForRetry(task) { + result.Error = err + task.ResultCh <- result + return nil, nil + } + return task, err +} + +func isVaildForRetry(task *statstypes.NeededItemTask) bool { + task.Retry++ + return task.Retry <= RetryCount +} + +func (s *statsSyncLoad) handleOneItemTask(task *statstypes.NeededItemTask) (err error) { + se, err := s.statsHandle.SPool().Get() + if err != nil { + return err + } + sctx := se.(sessionctx.Context) + sctx.GetSessionVars().StmtCtx.Priority = mysql.HighPriority + defer func() { + // recover for each task, worker keeps working + if r := recover(); r != nil { + logutil.BgLogger().Error("handleOneItemTask panicked", zap.Any("recover", r), zap.Stack("stack")) + err = errors.Errorf("stats loading panicked: %v", r) + } + if err == nil { // only recycle when no error + sctx.GetSessionVars().StmtCtx.Priority = mysql.NoPriority + s.statsHandle.SPool().Put(se) + } + }() + item := task.Item.TableItemID + tbl, ok := s.statsHandle.Get(item.TableID) + if !ok { + return nil + } + wrapper := &statsWrapper{} + if item.IsIndex { + index, loadNeeded := tbl.IndexIsLoadNeeded(item.ID) + if !loadNeeded { + return nil + } + if index != nil { + wrapper.idxInfo = index.Info + } else { + wrapper.idxInfo = tbl.ColAndIdxExistenceMap.GetIndex(item.ID) + } + } else { + col, loadNeeded, analyzed := tbl.ColumnIsLoadNeeded(item.ID, task.Item.FullLoad) + if !loadNeeded { + return nil + } + if col != nil { + wrapper.colInfo = col.Info + } else { + wrapper.colInfo = tbl.ColAndIdxExistenceMap.GetCol(item.ID) + } + // If this column is not analyzed yet and we don't have it in memory. + // We create a fake one for the pseudo estimation. 
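+		// The fake column carries an all-zero histogram (zero NDV and null count), so the
+		// optimizer keeps using pseudo estimation for it until the column is actually analyzed.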
+ if loadNeeded && !analyzed { + wrapper.col = &statistics.Column{ + PhysicalID: item.TableID, + Info: wrapper.colInfo, + Histogram: *statistics.NewHistogram(item.ID, 0, 0, 0, &wrapper.colInfo.FieldType, 0, 0), + IsHandle: tbl.IsPkIsHandle && mysql.HasPriKeyFlag(wrapper.colInfo.GetFlag()), + } + s.updateCachedItem(item, wrapper.col, wrapper.idx, task.Item.FullLoad) + return nil + } + } + t := time.Now() + needUpdate := false + wrapper, err = s.readStatsForOneItem(sctx, item, wrapper, tbl.IsPkIsHandle, task.Item.FullLoad) + if err != nil { + return err + } + if item.IsIndex { + if wrapper.idxInfo != nil { + needUpdate = true + } + } else { + if wrapper.colInfo != nil { + needUpdate = true + } + } + metrics.ReadStatsHistogram.Observe(float64(time.Since(t).Milliseconds())) + if needUpdate { + s.updateCachedItem(item, wrapper.col, wrapper.idx, task.Item.FullLoad) + } + return nil +} + +// readStatsForOneItem reads hist for one column/index, TODO load data via kv-get asynchronously +func (*statsSyncLoad) readStatsForOneItem(sctx sessionctx.Context, item model.TableItemID, w *statsWrapper, isPkIsHandle bool, fullLoad bool) (*statsWrapper, error) { + failpoint.Inject("mockReadStatsForOnePanic", nil) + failpoint.Inject("mockReadStatsForOneFail", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(nil, errors.New("gofail ReadStatsForOne error")) + } + }) + loadFMSketch := config.GetGlobalConfig().Performance.EnableLoadFMSketch + var hg *statistics.Histogram + var err error + isIndexFlag := int64(0) + hg, lastAnalyzePos, statsVer, flag, err := storage.HistMetaFromStorageWithHighPriority(sctx, &item, w.colInfo) + if err != nil { + return nil, err + } + if hg == nil { + logutil.BgLogger().Error("fail to get hist meta for this histogram, possibly a deleted one", zap.Int64("table_id", item.TableID), + zap.Int64("hist_id", item.ID), zap.Bool("is_index", item.IsIndex)) + return nil, errors.Trace(fmt.Errorf("fail to get hist meta for this histogram, table_id:%v, hist_id:%v, is_index:%v", item.TableID, item.ID, item.IsIndex)) + } + if item.IsIndex { + isIndexFlag = 1 + } + var cms *statistics.CMSketch + var topN *statistics.TopN + var fms *statistics.FMSketch + if fullLoad { + if item.IsIndex { + hg, err = storage.HistogramFromStorageWithPriority(sctx, item.TableID, item.ID, types.NewFieldType(mysql.TypeBlob), hg.NDV, int(isIndexFlag), hg.LastUpdateVersion, hg.NullCount, hg.TotColSize, hg.Correlation, kv.PriorityHigh) + if err != nil { + return nil, errors.Trace(err) + } + } else { + hg, err = storage.HistogramFromStorageWithPriority(sctx, item.TableID, item.ID, &w.colInfo.FieldType, hg.NDV, int(isIndexFlag), hg.LastUpdateVersion, hg.NullCount, hg.TotColSize, hg.Correlation, kv.PriorityHigh) + if err != nil { + return nil, errors.Trace(err) + } + } + cms, topN, err = storage.CMSketchAndTopNFromStorageWithHighPriority(sctx, item.TableID, isIndexFlag, item.ID, statsVer) + if err != nil { + return nil, errors.Trace(err) + } + if loadFMSketch { + fms, err = storage.FMSketchFromStorage(sctx, item.TableID, isIndexFlag, item.ID) + if err != nil { + return nil, errors.Trace(err) + } + } + } + if item.IsIndex { + idxHist := &statistics.Index{ + Histogram: *hg, + CMSketch: cms, + TopN: topN, + FMSketch: fms, + Info: w.idxInfo, + StatsVer: statsVer, + Flag: flag, + PhysicalID: item.TableID, + } + if statsVer != statistics.Version0 { + if fullLoad { + idxHist.StatsLoadedStatus = statistics.NewStatsFullLoadStatus() + } else { + idxHist.StatsLoadedStatus = statistics.NewStatsAllEvictedStatus() + } + } + 
lastAnalyzePos.Copy(&idxHist.LastAnalyzePos) + w.idx = idxHist + } else { + colHist := &statistics.Column{ + PhysicalID: item.TableID, + Histogram: *hg, + Info: w.colInfo, + CMSketch: cms, + TopN: topN, + FMSketch: fms, + IsHandle: isPkIsHandle && mysql.HasPriKeyFlag(w.colInfo.GetFlag()), + StatsVer: statsVer, + } + if colHist.StatsAvailable() { + if fullLoad { + colHist.StatsLoadedStatus = statistics.NewStatsFullLoadStatus() + } else { + colHist.StatsLoadedStatus = statistics.NewStatsAllEvictedStatus() + } + } + w.col = colHist + } + return w, nil +} + +// drainColTask will hang until a column task can return, and either task or error will be returned. +func (s *statsSyncLoad) drainColTask(sctx sessionctx.Context, exit chan struct{}) (*statstypes.NeededItemTask, error) { + // select NeededColumnsCh firstly, if no task, then select TimeoutColumnsCh + for { + select { + case <-exit: + return nil, errExit + case task, ok := <-s.StatsLoad.NeededItemsCh: + if !ok { + return nil, errors.New("drainColTask: cannot read from NeededColumnsCh, maybe the chan is closed") + } + // if the task has already timeout, no sql is sync-waiting for it, + // so do not handle it just now, put it to another channel with lower priority + if time.Now().After(task.ToTimeout) { + task.ToTimeout.Add(time.Duration(sctx.GetSessionVars().StatsLoadSyncWait.Load()) * time.Microsecond) + s.writeToTimeoutChan(s.StatsLoad.TimeoutItemsCh, task) + continue + } + return task, nil + case task, ok := <-s.StatsLoad.TimeoutItemsCh: + select { + case <-exit: + return nil, errExit + case task0, ok0 := <-s.StatsLoad.NeededItemsCh: + if !ok0 { + return nil, errors.New("drainColTask: cannot read from NeededColumnsCh, maybe the chan is closed") + } + // send task back to TimeoutColumnsCh and return the task drained from NeededColumnsCh + s.writeToTimeoutChan(s.StatsLoad.TimeoutItemsCh, task) + return task0, nil + default: + if !ok { + return nil, errors.New("drainColTask: cannot read from TimeoutColumnsCh, maybe the chan is closed") + } + // NeededColumnsCh is empty now, handle task from TimeoutColumnsCh + return task, nil + } + } + } +} + +// writeToTimeoutChan writes in a nonblocking way, and if the channel queue is full, it's ok to drop the task. +func (*statsSyncLoad) writeToTimeoutChan(taskCh chan *statstypes.NeededItemTask, task *statstypes.NeededItemTask) { + select { + case taskCh <- task: + default: + } +} + +// writeToChanWithTimeout writes a task to a channel and blocks until timeout. +func (*statsSyncLoad) writeToChanWithTimeout(taskCh chan *statstypes.NeededItemTask, task *statstypes.NeededItemTask, timeout time.Duration) error { + timer := time.NewTimer(timeout) + defer timer.Stop() + select { + case taskCh <- task: + case <-timer.C: + return errors.New("Channel is full and timeout writing to channel") + } + return nil +} + +// writeToResultChan safe-writes with panic-recover so one write-fail will not have big impact. +func (*statsSyncLoad) writeToResultChan(resultCh chan stmtctx.StatsLoadResult, rs stmtctx.StatsLoadResult) { + defer func() { + if r := recover(); r != nil { + logutil.BgLogger().Error("writeToResultChan panicked", zap.Any("error", r), zap.Stack("stack")) + } + }() + select { + case resultCh <- rs: + default: + } +} + +// updateCachedItem updates the column/index hist to global statsCache. 
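+// It re-reads the table from the stats cache under StatsLoad.Lock, copies it,
+// installs the freshly loaded histogram, and publishes the copy through
+// UpdateStatsCache; the write is skipped when an item that is already loaded at
+// least as fully is present.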
+func (s *statsSyncLoad) updateCachedItem(item model.TableItemID, colHist *statistics.Column, idxHist *statistics.Index, fullLoaded bool) (updated bool) { + s.StatsLoad.Lock() + defer s.StatsLoad.Unlock() + // Reload the latest stats cache, otherwise the `updateStatsCache` may fail with high probability, because functions + // like `GetPartitionStats` called in `fmSketchFromStorage` would have modified the stats cache already. + tbl, ok := s.statsHandle.Get(item.TableID) + if !ok { + return false + } + if !item.IsIndex && colHist != nil { + c := tbl.GetCol(item.ID) + // - If the stats is fully loaded, + // - If the stats is meta-loaded and we also just need the meta. + if c != nil && (c.IsFullLoad() || !fullLoaded) { + return false + } + tbl = tbl.Copy() + tbl.SetCol(item.ID, colHist) + // If the column is analyzed we refresh the map for the possible change. + if colHist.StatsAvailable() { + tbl.ColAndIdxExistenceMap.InsertCol(item.ID, colHist.Info, true) + } + // All the objects shares the same stats version. Update it here. + if colHist.StatsVer != statistics.Version0 { + tbl.StatsVer = statistics.Version0 + } + } else if item.IsIndex && idxHist != nil { + index := tbl.GetIdx(item.ID) + // - If the stats is fully loaded, + // - If the stats is meta-loaded and we also just need the meta. + if index != nil && (index.IsFullLoad() || !fullLoaded) { + return true + } + tbl = tbl.Copy() + tbl.SetIdx(item.ID, idxHist) + // If the index is analyzed we refresh the map for the possible change. + if idxHist.IsAnalyzed() { + tbl.ColAndIdxExistenceMap.InsertIndex(item.ID, idxHist.Info, true) + // All the objects shares the same stats version. Update it here. + tbl.StatsVer = statistics.Version0 + } + } + s.statsHandle.UpdateStatsCache([]*statistics.Table{tbl}, nil) + return true +} diff --git a/pkg/store/copr/batch_coprocessor.go b/pkg/store/copr/batch_coprocessor.go index 850204fb9b168..9b8d1416f5798 100644 --- a/pkg/store/copr/batch_coprocessor.go +++ b/pkg/store/copr/batch_coprocessor.go @@ -747,7 +747,7 @@ func buildBatchCopTasksConsistentHash( } func failpointCheckForConsistentHash(tasks []*batchCopTask) { - failpoint.Inject("checkOnlyDispatchToTiFlashComputeNodes", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("checkOnlyDispatchToTiFlashComputeNodes")); _err_ == nil { logutil.BgLogger().Debug("in checkOnlyDispatchToTiFlashComputeNodes") // This failpoint will be tested in test-infra case, because we needs setup a cluster. 
@@ -768,18 +768,18 @@ func failpointCheckForConsistentHash(tasks []*batchCopTask) { panic(err) } } - }) + } } func failpointCheckWhichPolicy(act tiflashcompute.DispatchPolicy) { - failpoint.Inject("testWhichDispatchPolicy", func(exp failpoint.Value) { + if exp, _err_ := failpoint.Eval(_curpkg_("testWhichDispatchPolicy")); _err_ == nil { expStr := exp.(string) actStr := tiflashcompute.GetDispatchPolicy(act) if actStr != expStr { err := errors.Errorf("tiflash_compute dispatch should be %v, but got %v", expStr, actStr) panic(err) } - }) + } } func filterAllStoresAccordingToTiFlashReplicaRead(allStores []uint64, aliveStores *aliveStoresBundle, policy tiflash.ReplicaRead) (storesMatchedPolicy []uint64, needsCrossZoneAccess bool) { @@ -1174,11 +1174,11 @@ func (b *batchCopIterator) run(ctx context.Context) { for _, task := range b.tasks { b.wg.Add(1) boMaxSleep := CopNextMaxBackoff - failpoint.Inject("ReduceCopNextMaxBackoff", func(value failpoint.Value) { + if value, _err_ := failpoint.Eval(_curpkg_("ReduceCopNextMaxBackoff")); _err_ == nil { if value.(bool) { boMaxSleep = 2 } - }) + } bo := backoff.NewBackofferWithVars(ctx, boMaxSleep, b.vars) go b.handleTask(ctx, bo, task) } diff --git a/pkg/store/copr/batch_coprocessor.go__failpoint_stash__ b/pkg/store/copr/batch_coprocessor.go__failpoint_stash__ new file mode 100644 index 0000000000000..850204fb9b168 --- /dev/null +++ b/pkg/store/copr/batch_coprocessor.go__failpoint_stash__ @@ -0,0 +1,1588 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package copr + +import ( + "bytes" + "cmp" + "context" + "fmt" + "io" + "math" + "math/rand" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/coprocessor" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/log" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl/placement" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/store/driver/backoff" + derr "github.com/pingcap/tidb/pkg/store/driver/error" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/tiflash" + "github.com/pingcap/tidb/pkg/util/tiflashcompute" + "github.com/tikv/client-go/v2/metrics" + "github.com/tikv/client-go/v2/tikv" + "github.com/tikv/client-go/v2/tikvrpc" + "github.com/twmb/murmur3" + "go.uber.org/zap" +) + +const fetchTopoMaxBackoff = 20000 + +// batchCopTask comprises of multiple copTask that will send to same store. +type batchCopTask struct { + storeAddr string + cmdType tikvrpc.CmdType + ctx *tikv.RPCContext + + regionInfos []RegionInfo // region info for single physical table + // PartitionTableRegions indicates region infos for each partition table, used by scanning partitions in batch. + // Thus, one of `regionInfos` and `PartitionTableRegions` must be nil. 
+ PartitionTableRegions []*coprocessor.TableRegions +} + +type batchCopResponse struct { + pbResp *coprocessor.BatchResponse + detail *CopRuntimeStats + + // batch Cop Response is yet to return startKey. So batchCop cannot retry partially. + startKey kv.Key + err error + respSize int64 + respTime time.Duration +} + +// GetData implements the kv.ResultSubset GetData interface. +func (rs *batchCopResponse) GetData() []byte { + return rs.pbResp.Data +} + +// GetStartKey implements the kv.ResultSubset GetStartKey interface. +func (rs *batchCopResponse) GetStartKey() kv.Key { + return rs.startKey +} + +// GetExecDetails is unavailable currently, because TiFlash has not collected exec details for batch cop. +// TODO: Will fix in near future. +func (rs *batchCopResponse) GetCopRuntimeStats() *CopRuntimeStats { + return rs.detail +} + +// MemSize returns how many bytes of memory this response use +func (rs *batchCopResponse) MemSize() int64 { + if rs.respSize != 0 { + return rs.respSize + } + + // ignore rs.err + rs.respSize += int64(cap(rs.startKey)) + if rs.detail != nil { + rs.respSize += int64(sizeofExecDetails) + } + if rs.pbResp != nil { + // Using a approximate size since it's hard to get a accurate value. + rs.respSize += int64(rs.pbResp.Size()) + } + return rs.respSize +} + +func (rs *batchCopResponse) RespTime() time.Duration { + return rs.respTime +} + +func deepCopyStoreTaskMap(storeTaskMap map[uint64]*batchCopTask) map[uint64]*batchCopTask { + storeTasks := make(map[uint64]*batchCopTask) + for storeID, task := range storeTaskMap { + t := batchCopTask{ + storeAddr: task.storeAddr, + cmdType: task.cmdType, + ctx: task.ctx, + } + t.regionInfos = make([]RegionInfo, len(task.regionInfos)) + copy(t.regionInfos, task.regionInfos) + storeTasks[storeID] = &t + } + return storeTasks +} + +func regionTotalCount(storeTasks map[uint64]*batchCopTask, candidateRegionInfos []RegionInfo) int { + count := len(candidateRegionInfos) + for _, task := range storeTasks { + count += len(task.regionInfos) + } + return count +} + +const ( + maxBalanceScore = 100 + balanceScoreThreshold = 85 +) + +// Select at most cnt RegionInfos from candidateRegionInfos that belong to storeID. +// If selected[i] is true, candidateRegionInfos[i] has been selected and should be skip. +// storeID2RegionIndex is a map that key is storeID and value is a region index slice. +// selectRegion use storeID2RegionIndex to find RegionInfos that belong to storeID efficiently. +func selectRegion(storeID uint64, candidateRegionInfos []RegionInfo, selected []bool, storeID2RegionIndex map[uint64][]int, cnt int64) []RegionInfo { + regionIndexes, ok := storeID2RegionIndex[storeID] + if !ok { + logutil.BgLogger().Error("selectRegion: storeID2RegionIndex not found", zap.Uint64("storeID", storeID)) + return nil + } + var regionInfos []RegionInfo + i := 0 + for ; i < len(regionIndexes) && len(regionInfos) < int(cnt); i++ { + idx := regionIndexes[i] + if selected[idx] { + continue + } + selected[idx] = true + regionInfos = append(regionInfos, candidateRegionInfos[idx]) + } + // Remove regions that has been selected. 
+ storeID2RegionIndex[storeID] = regionIndexes[i:] + return regionInfos +} + +// Higher scores mean more balance: (100 - unblance percentage) +func balanceScore(maxRegionCount, minRegionCount int, balanceContinuousRegionCount int64) int { + if minRegionCount <= 0 { + return math.MinInt32 + } + unbalanceCount := maxRegionCount - minRegionCount + if unbalanceCount <= int(balanceContinuousRegionCount) { + return maxBalanceScore + } + return maxBalanceScore - unbalanceCount*100/minRegionCount +} + +func isBalance(score int) bool { + return score >= balanceScoreThreshold +} + +func checkBatchCopTaskBalance(storeTasks map[uint64]*batchCopTask, balanceContinuousRegionCount int64) (int, []string) { + if len(storeTasks) == 0 { + return 0, []string{} + } + maxRegionCount := 0 + minRegionCount := math.MaxInt32 + balanceInfos := []string{} + for storeID, task := range storeTasks { + cnt := len(task.regionInfos) + if cnt > maxRegionCount { + maxRegionCount = cnt + } + if cnt < minRegionCount { + minRegionCount = cnt + } + balanceInfos = append(balanceInfos, fmt.Sprintf("storeID %d storeAddr %s regionCount %d", storeID, task.storeAddr, cnt)) + } + return balanceScore(maxRegionCount, minRegionCount, balanceContinuousRegionCount), balanceInfos +} + +// balanceBatchCopTaskWithContinuity try to balance `continuous regions` between TiFlash Stores. +// In fact, not absolutely continuous is required, regions' range are closed to store in a TiFlash segment is enough for internal read optimization. +// +// First, sort candidateRegionInfos by their key ranges. +// Second, build a storeID2RegionIndex data structure to fastly locate regions of a store (avoid scanning candidateRegionInfos repeatedly). +// Third, each store will take balanceContinuousRegionCount from the sorted candidateRegionInfos. These regions are stored very close to each other in TiFlash. +// Fourth, if the region count is not balance between TiFlash, it may fallback to the original balance logic. +func balanceBatchCopTaskWithContinuity(storeTaskMap map[uint64]*batchCopTask, candidateRegionInfos []RegionInfo, balanceContinuousRegionCount int64) ([]*batchCopTask, int) { + if len(candidateRegionInfos) < 500 { + return nil, 0 + } + funcStart := time.Now() + regionCount := regionTotalCount(storeTaskMap, candidateRegionInfos) + storeTasks := deepCopyStoreTaskMap(storeTaskMap) + + // Sort regions by their key ranges. + slices.SortFunc(candidateRegionInfos, func(i, j RegionInfo) int { + // Special case: Sort empty ranges to the end. + if i.Ranges.Len() < 1 || j.Ranges.Len() < 1 { + return cmp.Compare(j.Ranges.Len(), i.Ranges.Len()) + } + // StartKey0 < StartKey1 + return bytes.Compare(i.Ranges.At(0).StartKey, j.Ranges.At(0).StartKey) + }) + + balanceStart := time.Now() + // Build storeID -> region index slice index and we can fastly locate regions of a store. + storeID2RegionIndex := make(map[uint64][]int) + for i, ri := range candidateRegionInfos { + for _, storeID := range ri.AllStores { + if val, ok := storeID2RegionIndex[storeID]; ok { + storeID2RegionIndex[storeID] = append(val, i) + } else { + storeID2RegionIndex[storeID] = []int{i} + } + } + } + + // If selected[i] is true, candidateRegionInfos[i] is selected by a store and should skip it in selectRegion. + selected := make([]bool, len(candidateRegionInfos)) + for { + totalCount := 0 + selectCountThisRound := 0 + for storeID, task := range storeTasks { + // Each store select balanceContinuousRegionCount regions from candidateRegionInfos. 
+ // Since candidateRegionInfos is sorted, it is very likely that these regions are close to each other in TiFlash. + regionInfo := selectRegion(storeID, candidateRegionInfos, selected, storeID2RegionIndex, balanceContinuousRegionCount) + task.regionInfos = append(task.regionInfos, regionInfo...) + totalCount += len(task.regionInfos) + selectCountThisRound += len(regionInfo) + } + if totalCount >= regionCount { + break + } + if selectCountThisRound == 0 { + logutil.BgLogger().Error("selectCandidateRegionInfos fail: some region cannot find relevant store.", zap.Int("regionCount", regionCount), zap.Int("candidateCount", len(candidateRegionInfos))) + return nil, 0 + } + } + balanceEnd := time.Now() + + score, balanceInfos := checkBatchCopTaskBalance(storeTasks, balanceContinuousRegionCount) + if !isBalance(score) { + logutil.BgLogger().Warn("balanceBatchCopTaskWithContinuity is not balance", zap.Int("score", score), zap.Strings("balanceInfos", balanceInfos)) + } + + totalCount := 0 + var res []*batchCopTask + for _, task := range storeTasks { + totalCount += len(task.regionInfos) + if len(task.regionInfos) > 0 { + res = append(res, task) + } + } + if totalCount != regionCount { + logutil.BgLogger().Error("balanceBatchCopTaskWithContinuity error", zap.Int("totalCount", totalCount), zap.Int("regionCount", regionCount)) + return nil, 0 + } + + logutil.BgLogger().Debug("balanceBatchCopTaskWithContinuity time", + zap.Int("candidateRegionCount", len(candidateRegionInfos)), + zap.Int64("balanceContinuousRegionCount", balanceContinuousRegionCount), + zap.Int("balanceScore", score), + zap.Duration("balanceTime", balanceEnd.Sub(balanceStart)), + zap.Duration("totalTime", time.Since(funcStart))) + + return res, score +} + +// balanceBatchCopTask balance the regions between available stores, the basic rule is +// 1. the first region of each original batch cop task belongs to its original store because some +// meta data(like the rpc context) in batchCopTask is related to it +// 2. for the remaining regions: +// if there is only 1 available store, then put the region to the related store +// otherwise, these region will be balance between TiFlash stores. +// +// Currently, there are two balance strategies. +// The first balance strategy: use a greedy algorithm to put it into the store with highest weight. This strategy only consider the region count between TiFlash stores. +// +// The second balance strategy: Not only consider the region count between TiFlash stores, but also try to make the regions' range continuous(stored in TiFlash closely). +// If balanceWithContinuity is true, the second balance strategy is enable. +func balanceBatchCopTask(aliveStores []*tikv.Store, originalTasks []*batchCopTask, balanceWithContinuity bool, balanceContinuousRegionCount int64) []*batchCopTask { + if len(originalTasks) == 0 { + log.Info("Batch cop task balancer got an empty task set.") + return originalTasks + } + storeTaskMap := make(map[uint64]*batchCopTask) + // storeCandidateRegionMap stores all the possible store->region map. Its content is + // store id -> region signature -> region info. We can see it as store id -> region lists. 
+ storeCandidateRegionMap := make(map[uint64]map[string]RegionInfo) + totalRegionCandidateNum := 0 + totalRemainingRegionNum := 0 + + for _, s := range aliveStores { + storeTaskMap[s.StoreID()] = &batchCopTask{ + storeAddr: s.GetAddr(), + cmdType: originalTasks[0].cmdType, + ctx: &tikv.RPCContext{Addr: s.GetAddr(), Store: s}, + } + } + + var candidateRegionInfos []RegionInfo + for _, task := range originalTasks { + for _, ri := range task.regionInfos { + // for each region, figure out the valid store num + validStoreNum := 0 + var validStoreID uint64 + for _, storeID := range ri.AllStores { + if _, ok := storeTaskMap[storeID]; ok { + validStoreNum++ + // original store id might be invalid, so we have to set it again. + validStoreID = storeID + } + } + if validStoreNum == 0 { + logutil.BgLogger().Warn("Meet regions that don't have an available store. Give up balancing") + return originalTasks + } else if validStoreNum == 1 { + // if only one store is valid, just put it to storeTaskMap + storeTaskMap[validStoreID].regionInfos = append(storeTaskMap[validStoreID].regionInfos, ri) + } else { + // if more than one store is valid, put the region + // to store candidate map + totalRegionCandidateNum += validStoreNum + totalRemainingRegionNum++ + candidateRegionInfos = append(candidateRegionInfos, ri) + taskKey := ri.Region.String() + for _, storeID := range ri.AllStores { + if _, validStore := storeTaskMap[storeID]; !validStore { + continue + } + if _, ok := storeCandidateRegionMap[storeID]; !ok { + candidateMap := make(map[string]RegionInfo) + storeCandidateRegionMap[storeID] = candidateMap + } + if _, duplicateRegion := storeCandidateRegionMap[storeID][taskKey]; duplicateRegion { + // duplicated region, should not happen, just give up balance + logutil.BgLogger().Warn("Meet duplicated region info during when trying to balance batch cop task, give up balancing") + return originalTasks + } + storeCandidateRegionMap[storeID][taskKey] = ri + } + } + } + } + + // If balanceBatchCopTaskWithContinuity failed (not balance or return nil), it will fallback to the original balance logic. + // So storeTaskMap should not be modify. 
+ var contiguousTasks []*batchCopTask = nil + contiguousBalanceScore := 0 + if balanceWithContinuity { + contiguousTasks, contiguousBalanceScore = balanceBatchCopTaskWithContinuity(storeTaskMap, candidateRegionInfos, balanceContinuousRegionCount) + if isBalance(contiguousBalanceScore) && contiguousTasks != nil { + return contiguousTasks + } + } + + if totalRemainingRegionNum > 0 { + avgStorePerRegion := float64(totalRegionCandidateNum) / float64(totalRemainingRegionNum) + findNextStore := func(candidateStores []uint64) uint64 { + store := uint64(math.MaxUint64) + weightedRegionNum := math.MaxFloat64 + if candidateStores != nil { + for _, storeID := range candidateStores { + if _, validStore := storeCandidateRegionMap[storeID]; !validStore { + continue + } + num := float64(len(storeCandidateRegionMap[storeID]))/avgStorePerRegion + float64(len(storeTaskMap[storeID].regionInfos)) + if num < weightedRegionNum { + store = storeID + weightedRegionNum = num + } + } + if store != uint64(math.MaxUint64) { + return store + } + } + for storeID := range storeTaskMap { + if _, validStore := storeCandidateRegionMap[storeID]; !validStore { + continue + } + num := float64(len(storeCandidateRegionMap[storeID]))/avgStorePerRegion + float64(len(storeTaskMap[storeID].regionInfos)) + if num < weightedRegionNum { + store = storeID + weightedRegionNum = num + } + } + return store + } + + store := findNextStore(nil) + for totalRemainingRegionNum > 0 { + if store == uint64(math.MaxUint64) { + break + } + var key string + var ri RegionInfo + for key, ri = range storeCandidateRegionMap[store] { + // get the first region + break + } + storeTaskMap[store].regionInfos = append(storeTaskMap[store].regionInfos, ri) + totalRemainingRegionNum-- + for _, id := range ri.AllStores { + if _, ok := storeCandidateRegionMap[id]; ok { + delete(storeCandidateRegionMap[id], key) + totalRegionCandidateNum-- + if len(storeCandidateRegionMap[id]) == 0 { + delete(storeCandidateRegionMap, id) + } + } + } + if totalRemainingRegionNum > 0 { + avgStorePerRegion = float64(totalRegionCandidateNum) / float64(totalRemainingRegionNum) + // it is not optimal because we only check the stores that affected by this region, in fact in order + // to find out the store with the lowest weightedRegionNum, all stores should be checked, but I think + // check only the affected stores is more simple and will get a good enough result + store = findNextStore(ri.AllStores) + } + } + if totalRemainingRegionNum > 0 { + logutil.BgLogger().Warn("Some regions are not used when trying to balance batch cop task, give up balancing") + return originalTasks + } + } + + if contiguousTasks != nil { + score, balanceInfos := checkBatchCopTaskBalance(storeTaskMap, balanceContinuousRegionCount) + if !isBalance(score) { + logutil.BgLogger().Warn("Region count is not balance and use contiguousTasks", zap.Int("contiguousBalanceScore", contiguousBalanceScore), zap.Int("score", score), zap.Strings("balanceInfos", balanceInfos)) + return contiguousTasks + } + } + + var ret []*batchCopTask + for _, task := range storeTaskMap { + if len(task.regionInfos) > 0 { + ret = append(ret, task) + } + } + return ret +} + +func buildBatchCopTasksForNonPartitionedTable( + ctx context.Context, + bo *backoff.Backoffer, + store *kvStore, + ranges *KeyRanges, + storeType kv.StoreType, + isMPP bool, + ttl time.Duration, + balanceWithContinuity bool, + balanceContinuousRegionCount int64, + dispatchPolicy tiflashcompute.DispatchPolicy, + tiflashReplicaReadPolicy tiflash.ReplicaRead, + appendWarning 
func(error)) ([]*batchCopTask, error) { + if config.GetGlobalConfig().DisaggregatedTiFlash { + if config.GetGlobalConfig().UseAutoScaler { + return buildBatchCopTasksConsistentHash(ctx, bo, store, []*KeyRanges{ranges}, storeType, ttl, dispatchPolicy) + } + return buildBatchCopTasksConsistentHashForPD(bo, store, []*KeyRanges{ranges}, storeType, ttl, dispatchPolicy) + } + return buildBatchCopTasksCore(bo, store, []*KeyRanges{ranges}, storeType, isMPP, ttl, balanceWithContinuity, balanceContinuousRegionCount, tiflashReplicaReadPolicy, appendWarning) +} + +func buildBatchCopTasksForPartitionedTable( + ctx context.Context, + bo *backoff.Backoffer, + store *kvStore, + rangesForEachPhysicalTable []*KeyRanges, + storeType kv.StoreType, + isMPP bool, + ttl time.Duration, + balanceWithContinuity bool, + balanceContinuousRegionCount int64, + partitionIDs []int64, + dispatchPolicy tiflashcompute.DispatchPolicy, + tiflashReplicaReadPolicy tiflash.ReplicaRead, + appendWarning func(error)) (batchTasks []*batchCopTask, err error) { + if config.GetGlobalConfig().DisaggregatedTiFlash { + if config.GetGlobalConfig().UseAutoScaler { + batchTasks, err = buildBatchCopTasksConsistentHash(ctx, bo, store, rangesForEachPhysicalTable, storeType, ttl, dispatchPolicy) + } else { + // todo: remove this after AutoScaler is stable. + batchTasks, err = buildBatchCopTasksConsistentHashForPD(bo, store, rangesForEachPhysicalTable, storeType, ttl, dispatchPolicy) + } + } else { + batchTasks, err = buildBatchCopTasksCore(bo, store, rangesForEachPhysicalTable, storeType, isMPP, ttl, balanceWithContinuity, balanceContinuousRegionCount, tiflashReplicaReadPolicy, appendWarning) + } + if err != nil { + return nil, err + } + // generate tableRegions for batchCopTasks + convertRegionInfosToPartitionTableRegions(batchTasks, partitionIDs) + return batchTasks, nil +} + +func filterAliveStoresStr(ctx context.Context, storesStr []string, ttl time.Duration, kvStore *kvStore) (aliveStores []string) { + aliveIdx := filterAliveStoresHelper(ctx, storesStr, ttl, kvStore) + for _, idx := range aliveIdx { + aliveStores = append(aliveStores, storesStr[idx]) + } + return aliveStores +} + +func filterAliveStores(ctx context.Context, stores []*tikv.Store, ttl time.Duration, kvStore *kvStore) (aliveStores []*tikv.Store) { + storesStr := make([]string, 0, len(stores)) + for _, s := range stores { + storesStr = append(storesStr, s.GetAddr()) + } + + aliveIdx := filterAliveStoresHelper(ctx, storesStr, ttl, kvStore) + for _, idx := range aliveIdx { + aliveStores = append(aliveStores, stores[idx]) + } + return aliveStores +} + +func filterAliveStoresHelper(ctx context.Context, stores []string, ttl time.Duration, kvStore *kvStore) (aliveIdx []int) { + var wg sync.WaitGroup + var mu sync.Mutex + wg.Add(len(stores)) + for i := range stores { + go func(idx int) { + defer wg.Done() + s := stores[idx] + + // Check if store is failed already. 
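+			// A store that GlobalMPPFailedStoreProber still considers failed is skipped here without probing it again.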
+ if ok := GlobalMPPFailedStoreProber.IsRecovery(ctx, s, ttl); !ok { + return + } + + tikvClient := kvStore.GetTiKVClient() + if ok := detectMPPStore(ctx, tikvClient, s, DetectTimeoutLimit); !ok { + GlobalMPPFailedStoreProber.Add(ctx, s, tikvClient) + return + } + + mu.Lock() + defer mu.Unlock() + aliveIdx = append(aliveIdx, idx) + }(i) + } + wg.Wait() + + logutil.BgLogger().Info("detecting available mpp stores", zap.Any("total", len(stores)), zap.Any("alive", len(aliveIdx))) + return aliveIdx +} + +func getTiFlashComputeRPCContextByConsistentHash(ids []tikv.RegionVerID, storesStr []string) (res []*tikv.RPCContext, err error) { + // Use RendezvousHash + for _, id := range ids { + var maxHash uint32 = 0 + var maxHashStore string = "" + for _, store := range storesStr { + h := murmur3.StringSum32(fmt.Sprintf("%s-%d", store, id.GetID())) + if h > maxHash { + maxHash = h + maxHashStore = store + } + } + rpcCtx := &tikv.RPCContext{ + Region: id, + Addr: maxHashStore, + } + res = append(res, rpcCtx) + } + return res, nil +} + +func getTiFlashComputeRPCContextByRoundRobin(ids []tikv.RegionVerID, storesStr []string) (res []*tikv.RPCContext, err error) { + startIdx := rand.Intn(len(storesStr)) + for _, id := range ids { + rpcCtx := &tikv.RPCContext{ + Region: id, + Addr: storesStr[startIdx%len(storesStr)], + } + + startIdx++ + res = append(res, rpcCtx) + } + return res, nil +} + +// 1. Split range by region location to build copTasks. +// 2. For each copTask build its rpcCtx , the target tiflash_compute node will be chosen using consistent hash. +// 3. All copTasks that will be sent to one tiflash_compute node are put in one batchCopTask. +func buildBatchCopTasksConsistentHash( + ctx context.Context, + bo *backoff.Backoffer, + kvStore *kvStore, + rangesForEachPhysicalTable []*KeyRanges, + storeType kv.StoreType, + ttl time.Duration, + dispatchPolicy tiflashcompute.DispatchPolicy) (res []*batchCopTask, err error) { + failpointCheckWhichPolicy(dispatchPolicy) + start := time.Now() + const cmdType = tikvrpc.CmdBatchCop + cache := kvStore.GetRegionCache() + fetchTopoBo := backoff.NewBackofferWithVars(ctx, fetchTopoMaxBackoff, nil) + + var ( + retryNum int + rangesLen int + storesStr []string + ) + + tasks := make([]*copTask, 0) + regionIDs := make([]tikv.RegionVerID, 0) + + for i, ranges := range rangesForEachPhysicalTable { + rangesLen += ranges.Len() + locations, err := cache.SplitKeyRangesByLocations(bo, ranges, UnspecifiedLimit, false, false) + if err != nil { + return nil, errors.Trace(err) + } + for _, lo := range locations { + tasks = append(tasks, &copTask{ + region: lo.Location.Region, + ranges: lo.Ranges, + cmdType: cmdType, + storeType: storeType, + partitionIndex: int64(i), + }) + regionIDs = append(regionIDs, lo.Location.Region) + } + } + splitKeyElapsed := time.Since(start) + + fetchTopoStart := time.Now() + for { + retryNum++ + storesStr, err = tiflashcompute.GetGlobalTopoFetcher().FetchAndGetTopo() + if err != nil { + return nil, err + } + storesBefFilter := len(storesStr) + storesStr = filterAliveStoresStr(ctx, storesStr, ttl, kvStore) + logutil.BgLogger().Info("topo filter alive", zap.Any("topo", storesStr)) + if len(storesStr) == 0 { + errMsg := "Cannot find proper topo to dispatch MPPTask: " + if storesBefFilter == 0 { + errMsg += "topo from AutoScaler is empty" + } else { + errMsg += "detect aliveness failed, no alive ComputeNode" + } + retErr := errors.New(errMsg) + logutil.BgLogger().Info("buildBatchCopTasksConsistentHash retry because FetchAndGetTopo return empty topo", 
zap.Int("retryNum", retryNum)) + if intest.InTest && retryNum > 3 { + return nil, retErr + } + err := fetchTopoBo.Backoff(tikv.BoTiFlashRPC(), retErr) + if err != nil { + return nil, retErr + } + continue + } + break + } + fetchTopoElapsed := time.Since(fetchTopoStart) + + var rpcCtxs []*tikv.RPCContext + if dispatchPolicy == tiflashcompute.DispatchPolicyRR { + rpcCtxs, err = getTiFlashComputeRPCContextByRoundRobin(regionIDs, storesStr) + } else if dispatchPolicy == tiflashcompute.DispatchPolicyConsistentHash { + rpcCtxs, err = getTiFlashComputeRPCContextByConsistentHash(regionIDs, storesStr) + } else { + err = errors.Errorf("unexpected dispatch policy %v", dispatchPolicy) + } + if err != nil { + return nil, err + } + if len(rpcCtxs) != len(tasks) { + return nil, errors.Errorf("length should be equal, len(rpcCtxs): %d, len(tasks): %d", len(rpcCtxs), len(tasks)) + } + taskMap := make(map[string]*batchCopTask) + for i, rpcCtx := range rpcCtxs { + regionInfo := RegionInfo{ + // tasks and rpcCtxs are correspond to each other. + Region: tasks[i].region, + Ranges: tasks[i].ranges, + PartitionIndex: tasks[i].partitionIndex, + } + if batchTask, ok := taskMap[rpcCtx.Addr]; ok { + batchTask.regionInfos = append(batchTask.regionInfos, regionInfo) + } else { + batchTask := &batchCopTask{ + storeAddr: rpcCtx.Addr, + cmdType: cmdType, + ctx: rpcCtx, + regionInfos: []RegionInfo{regionInfo}, + } + taskMap[rpcCtx.Addr] = batchTask + res = append(res, batchTask) + } + } + logutil.BgLogger().Info("buildBatchCopTasksConsistentHash done", + zap.Any("len(tasks)", len(taskMap)), + zap.Any("len(tiflash_compute)", len(storesStr)), + zap.Any("dispatchPolicy", tiflashcompute.GetDispatchPolicy(dispatchPolicy))) + + if log.GetLevel() <= zap.DebugLevel { + debugTaskMap := make(map[string]string, len(taskMap)) + for s, b := range taskMap { + debugTaskMap[s] = fmt.Sprintf("addr: %s; regionInfos: %v", b.storeAddr, b.regionInfos) + } + logutil.BgLogger().Debug("detailed info buildBatchCopTasksConsistentHash", zap.Any("taskMap", debugTaskMap), zap.Any("allStores", storesStr)) + } + + if elapsed := time.Since(start); elapsed > time.Millisecond*500 { + logutil.BgLogger().Warn("buildBatchCopTasksConsistentHash takes too much time", + zap.Duration("total elapsed", elapsed), + zap.Int("retryNum", retryNum), + zap.Duration("splitKeyElapsed", splitKeyElapsed), + zap.Duration("fetchTopoElapsed", fetchTopoElapsed), + zap.Int("range len", rangesLen), + zap.Int("copTaskNum", len(tasks)), + zap.Int("batchCopTaskNum", len(res))) + } + failpointCheckForConsistentHash(res) + return res, nil +} + +func failpointCheckForConsistentHash(tasks []*batchCopTask) { + failpoint.Inject("checkOnlyDispatchToTiFlashComputeNodes", func(val failpoint.Value) { + logutil.BgLogger().Debug("in checkOnlyDispatchToTiFlashComputeNodes") + + // This failpoint will be tested in test-infra case, because we needs setup a cluster. + // All tiflash_compute nodes addrs are stored in val, separated by semicolon. 
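+		// The expected failpoint value is a single string, e.g. "tiflash-compute-0:3930;tiflash-compute-1:3930" (addresses here are purely illustrative).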
+ str := val.(string) + addrs := strings.Split(str, ";") + if len(addrs) < 1 { + err := fmt.Sprintf("unexpected length of tiflash_compute node addrs: %v, %s", len(addrs), str) + panic(err) + } + addrMap := make(map[string]struct{}) + for _, addr := range addrs { + addrMap[addr] = struct{}{} + } + for _, batchTask := range tasks { + if _, ok := addrMap[batchTask.storeAddr]; !ok { + err := errors.Errorf("batchCopTask send to node which is not tiflash_compute: %v(tiflash_compute nodes: %s)", batchTask.storeAddr, str) + panic(err) + } + } + }) +} + +func failpointCheckWhichPolicy(act tiflashcompute.DispatchPolicy) { + failpoint.Inject("testWhichDispatchPolicy", func(exp failpoint.Value) { + expStr := exp.(string) + actStr := tiflashcompute.GetDispatchPolicy(act) + if actStr != expStr { + err := errors.Errorf("tiflash_compute dispatch should be %v, but got %v", expStr, actStr) + panic(err) + } + }) +} + +func filterAllStoresAccordingToTiFlashReplicaRead(allStores []uint64, aliveStores *aliveStoresBundle, policy tiflash.ReplicaRead) (storesMatchedPolicy []uint64, needsCrossZoneAccess bool) { + if policy.IsAllReplicas() { + for _, id := range allStores { + if _, ok := aliveStores.storeIDsInAllZones[id]; ok { + storesMatchedPolicy = append(storesMatchedPolicy, id) + } + } + return + } + // Check whether exists available stores in TiDB zone. If so, we only need to access TiFlash stores in TiDB zone. + for _, id := range allStores { + if _, ok := aliveStores.storeIDsInTiDBZone[id]; ok { + storesMatchedPolicy = append(storesMatchedPolicy, id) + } + } + // If no available stores in TiDB zone, we need to access TiFlash stores in other zones. + if len(storesMatchedPolicy) == 0 { + // needsCrossZoneAccess indicates whether we need to access(directly read or remote read) TiFlash stores in other zones. + needsCrossZoneAccess = true + + if policy == tiflash.ClosestAdaptive { + // If the policy is `ClosestAdaptive`, we can dispatch tasks to the TiFlash stores in other zones. + for _, id := range allStores { + if _, ok := aliveStores.storeIDsInAllZones[id]; ok { + storesMatchedPolicy = append(storesMatchedPolicy, id) + } + } + } else if policy == tiflash.ClosestReplicas { + // If the policy is `ClosestReplicas`, we dispatch tasks to the TiFlash stores in TiDB zone and remote read from other zones. + for id := range aliveStores.storeIDsInTiDBZone { + storesMatchedPolicy = append(storesMatchedPolicy, id) + } + } + } + return +} + +func getAllUsedTiFlashStores(allTiFlashStores []*tikv.Store, allUsedTiFlashStoresMap map[uint64]struct{}) []*tikv.Store { + allUsedTiFlashStores := make([]*tikv.Store, 0, len(allUsedTiFlashStoresMap)) + for _, store := range allTiFlashStores { + _, ok := allUsedTiFlashStoresMap[store.StoreID()] + if ok { + allUsedTiFlashStores = append(allUsedTiFlashStores, store) + } + } + return allUsedTiFlashStores +} + +// getAliveStoresAndStoreIDs gets alive TiFlash stores and their IDs. +// If tiflashReplicaReadPolicy is not all_replicas, it will also return the IDs of the alive TiFlash stores in TiDB zone. 
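+// The TiDB zone is identified by the `zone` label (placement.DCLabelKey); TiFlash
+// stores without that label are treated as belonging to another zone.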
+func getAliveStoresAndStoreIDs(ctx context.Context, cache *RegionCache, allUsedTiFlashStoresMap map[uint64]struct{}, ttl time.Duration, store *kvStore, tiflashReplicaReadPolicy tiflash.ReplicaRead, tidbZone string) (aliveStores *aliveStoresBundle) { + aliveStores = new(aliveStoresBundle) + allTiFlashStores := cache.RegionCache.GetTiFlashStores(tikv.LabelFilterNoTiFlashWriteNode) + allUsedTiFlashStores := getAllUsedTiFlashStores(allTiFlashStores, allUsedTiFlashStoresMap) + aliveStores.storesInAllZones = filterAliveStores(ctx, allUsedTiFlashStores, ttl, store) + + if !tiflashReplicaReadPolicy.IsAllReplicas() { + aliveStores.storeIDsInTiDBZone = make(map[uint64]struct{}, len(aliveStores.storesInAllZones)) + for _, as := range aliveStores.storesInAllZones { + // If the `zone` label of the TiFlash store is not set, we treat it as a TiFlash store in other zones. + if tiflashZone, isSet := as.GetLabelValue(placement.DCLabelKey); isSet && tiflashZone == tidbZone { + aliveStores.storeIDsInTiDBZone[as.StoreID()] = struct{}{} + aliveStores.storesInTiDBZone = append(aliveStores.storesInTiDBZone, as) + } + } + } + if !tiflashReplicaReadPolicy.IsClosestReplicas() { + aliveStores.storeIDsInAllZones = make(map[uint64]struct{}, len(aliveStores.storesInAllZones)) + for _, as := range aliveStores.storesInAllZones { + aliveStores.storeIDsInAllZones[as.StoreID()] = struct{}{} + } + } + return aliveStores +} + +// filterAccessibleStoresAndBuildRegionInfo filters the stores that can be accessed according to: +// 1. tiflash_replica_read policy +// 2. whether the store is alive +// After filtering, it will build the RegionInfo. +func filterAccessibleStoresAndBuildRegionInfo( + cache *RegionCache, + allStores []uint64, + bo *Backoffer, + task *copTask, + rpcCtx *tikv.RPCContext, + aliveStores *aliveStoresBundle, + tiflashReplicaReadPolicy tiflash.ReplicaRead, + regionInfoNeedsReloadOnSendFail []RegionInfo, + regionsInOtherZones []uint64, + maxRemoteReadCountAllowed int, + tidbZone string) (regionInfo RegionInfo, _ []RegionInfo, _ []uint64, err error) { + needCrossZoneAccess := false + allStores, needCrossZoneAccess = filterAllStoresAccordingToTiFlashReplicaRead(allStores, aliveStores, tiflashReplicaReadPolicy) + + regionInfo = RegionInfo{ + Region: task.region, + Meta: rpcCtx.Meta, + Ranges: task.ranges, + AllStores: allStores, + PartitionIndex: task.partitionIndex} + + if needCrossZoneAccess { + regionsInOtherZones = append(regionsInOtherZones, task.region.GetID()) + regionInfoNeedsReloadOnSendFail = append(regionInfoNeedsReloadOnSendFail, regionInfo) + if tiflashReplicaReadPolicy.IsClosestReplicas() && len(regionsInOtherZones) > maxRemoteReadCountAllowed { + regionIDErrMsg := "" + for i := 0; i < 3 && i < len(regionsInOtherZones); i++ { + regionIDErrMsg += fmt.Sprintf("%d, ", regionsInOtherZones[i]) + } + err = errors.Errorf( + "no less than %d region(s) can not be accessed by TiFlash in the zone [%s]: %setc", + len(regionsInOtherZones), tidbZone, regionIDErrMsg) + // We need to reload the region cache here to avoid the failure throughout the region cache refresh TTL. 
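For closest_replicas, regions that can only be served from another zone are counted against a budget of local-store count times a per-node limit; exceeding it aborts the build with an error that names the first few regions. A rough sketch of just that check (maxRemoteReadPerNode and the exact wording are illustrative):

package main

import "fmt"

const maxRemoteReadPerNode = 1 // stand-in for the real per-node limit

func checkRemoteReadBudget(remoteRegions []uint64, localStoreCount int, zone string) error {
	budget := localStoreCount * maxRemoteReadPerNode
	if len(remoteRegions) <= budget {
		return nil
	}
	msg := ""
	for i := 0; i < 3 && i < len(remoteRegions); i++ {
		msg += fmt.Sprintf("%d, ", remoteRegions[i])
	}
	return fmt.Errorf("no less than %d region(s) can not be accessed by TiFlash in the zone [%s]: %setc",
		len(remoteRegions), zone, msg)
}

func main() {
	fmt.Println(checkRemoteReadBudget([]uint64{7, 8, 9}, 1, "us-west-1"))
}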
+ cache.OnSendFailForBatchRegions(bo, rpcCtx.Store, regionInfoNeedsReloadOnSendFail, true, err) + return regionInfo, nil, nil, err + } + } + return regionInfo, regionInfoNeedsReloadOnSendFail, regionsInOtherZones, nil +} + +type aliveStoresBundle struct { + storesInAllZones []*tikv.Store + storeIDsInAllZones map[uint64]struct{} + storesInTiDBZone []*tikv.Store + storeIDsInTiDBZone map[uint64]struct{} +} + +// When `partitionIDs != nil`, it means that buildBatchCopTasksCore is constructing a batch cop tasks for PartitionTableScan. +// At this time, `len(rangesForEachPhysicalTable) == len(partitionIDs)` and `rangesForEachPhysicalTable[i]` is for partition `partitionIDs[i]`. +// Otherwise, `rangesForEachPhysicalTable[0]` indicates the range for the single physical table. +func buildBatchCopTasksCore(bo *backoff.Backoffer, store *kvStore, rangesForEachPhysicalTable []*KeyRanges, storeType kv.StoreType, isMPP bool, ttl time.Duration, balanceWithContinuity bool, balanceContinuousRegionCount int64, tiflashReplicaReadPolicy tiflash.ReplicaRead, appendWarning func(error)) ([]*batchCopTask, error) { + cache := store.GetRegionCache() + start := time.Now() + const cmdType = tikvrpc.CmdBatchCop + rangesLen := 0 + + tidbZone, isTiDBLabelZoneSet := config.GetGlobalConfig().Labels[placement.DCLabelKey] + var ( + aliveStores *aliveStoresBundle + maxRemoteReadCountAllowed int + ) + if !isTiDBLabelZoneSet { + tiflashReplicaReadPolicy = tiflash.AllReplicas + } + + for { + var tasks []*copTask + rangesLen = 0 + for i, ranges := range rangesForEachPhysicalTable { + rangesLen += ranges.Len() + locations, err := cache.SplitKeyRangesByLocations(bo, ranges, UnspecifiedLimit, false, false) + if err != nil { + return nil, errors.Trace(err) + } + for _, lo := range locations { + tasks = append(tasks, &copTask{ + region: lo.Location.Region, + ranges: lo.Ranges, + cmdType: cmdType, + storeType: storeType, + partitionIndex: int64(i), + }) + } + } + + rpcCtxs := make([]*tikv.RPCContext, 0, len(tasks)) + usedTiFlashStores := make([][]uint64, 0, len(tasks)) + usedTiFlashStoresMap := make(map[uint64]struct{}, 0) + needRetry := false + for _, task := range tasks { + rpcCtx, err := cache.GetTiFlashRPCContext(bo.TiKVBackoffer(), task.region, isMPP, tikv.LabelFilterNoTiFlashWriteNode) + if err != nil { + return nil, errors.Trace(err) + } + + // When rpcCtx is nil, it's not only attributed to the miss region, but also + // some TiFlash stores crash and can't be recovered. + // That is not an error that can be easily recovered, so we regard this error + // same as rpc error. + if rpcCtx == nil { + needRetry = true + logutil.BgLogger().Info("retry for TiFlash peer with region missing", zap.Uint64("region id", task.region.GetID())) + // Probably all the regions are invalid. Make the loop continue and mark all the regions invalid. + // Then `splitRegion` will reloads these regions. + continue + } + + allStores, _ := cache.GetAllValidTiFlashStores(task.region, rpcCtx.Store, tikv.LabelFilterNoTiFlashWriteNode) + for _, storeID := range allStores { + usedTiFlashStoresMap[storeID] = struct{}{} + } + rpcCtxs = append(rpcCtxs, rpcCtx) + usedTiFlashStores = append(usedTiFlashStores, allStores) + } + + if needRetry { + // As mentioned above, nil rpcCtx is always attributed to failed stores. + // It's equal to long poll the store but get no response. Here we'd better use + // TiFlash error to trigger the TiKV fallback mechanism. 
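When any region comes back without a TiFlash RPC context, the whole round is thrown away: the builder backs off and rebuilds every task instead of patching the single region. A minimal sketch of that retry-the-whole-batch shape, assuming a plain sleep in place of the real Backoffer:

package main

import (
	"errors"
	"fmt"
	"time"
)

var errRegionMiss = errors.New("cannot find region with TiFlash peer")

// buildWithRetry retries the whole build when any region misses its peer,
// sleeping between rounds the way a backoffer would.
func buildWithRetry(buildOnce func() (tasks []string, retry bool), maxRetry int) ([]string, error) {
	for attempt := 0; attempt < maxRetry; attempt++ {
		tasks, retry := buildOnce()
		if !retry {
			return tasks, nil
		}
		time.Sleep(time.Duration(attempt+1) * 10 * time.Millisecond) // simplified backoff
	}
	return nil, errRegionMiss
}

func main() {
	calls := 0
	tasks, err := buildWithRetry(func() ([]string, bool) {
		calls++
		return []string{"task"}, calls < 2 // first round simulates a region miss
	}, 5)
	fmt.Println(tasks, err)
}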
+ err := bo.Backoff(tikv.BoTiFlashRPC(), errors.New("Cannot find region with TiFlash peer")) + if err != nil { + return nil, errors.Trace(err) + } + continue + } + + aliveStores = getAliveStoresAndStoreIDs(bo.GetCtx(), cache, usedTiFlashStoresMap, ttl, store, tiflashReplicaReadPolicy, tidbZone) + if tiflashReplicaReadPolicy.IsClosestReplicas() { + if len(aliveStores.storeIDsInTiDBZone) == 0 { + return nil, errors.Errorf("There is no region in tidb zone(%s)", tidbZone) + } + maxRemoteReadCountAllowed = len(aliveStores.storeIDsInTiDBZone) * tiflash.MaxRemoteReadCountPerNodeForClosestReplicas + } + + var batchTasks []*batchCopTask + var regionIDsInOtherZones []uint64 + var regionInfosNeedReloadOnSendFail []RegionInfo + storeTaskMap := make(map[string]*batchCopTask) + storeIDsUnionSetForAllTasks := make(map[uint64]struct{}) + for idx, task := range tasks { + var err error + var regionInfo RegionInfo + regionInfo, regionInfosNeedReloadOnSendFail, regionIDsInOtherZones, err = filterAccessibleStoresAndBuildRegionInfo(cache, usedTiFlashStores[idx], bo, task, rpcCtxs[idx], aliveStores, tiflashReplicaReadPolicy, regionInfosNeedReloadOnSendFail, regionIDsInOtherZones, maxRemoteReadCountAllowed, tidbZone) + if err != nil { + return nil, err + } + if batchCop, ok := storeTaskMap[rpcCtxs[idx].Addr]; ok { + batchCop.regionInfos = append(batchCop.regionInfos, regionInfo) + } else { + batchTask := &batchCopTask{ + storeAddr: rpcCtxs[idx].Addr, + cmdType: cmdType, + ctx: rpcCtxs[idx], + regionInfos: []RegionInfo{regionInfo}, + } + storeTaskMap[rpcCtxs[idx].Addr] = batchTask + } + for _, storeID := range regionInfo.AllStores { + storeIDsUnionSetForAllTasks[storeID] = struct{}{} + } + } + + if len(regionIDsInOtherZones) != 0 { + warningMsg := fmt.Sprintf("total %d region(s) can not be accessed by TiFlash in the zone [%s]:", len(regionIDsInOtherZones), tidbZone) + regionIDErrMsg := "" + for i := 0; i < 3 && i < len(regionIDsInOtherZones); i++ { + regionIDErrMsg += fmt.Sprintf("%d, ", regionIDsInOtherZones[i]) + } + warningMsg += regionIDErrMsg + "etc" + appendWarning(errors.NewNoStackErrorf(warningMsg)) + } + + for _, task := range storeTaskMap { + batchTasks = append(batchTasks, task) + } + if log.GetLevel() <= zap.DebugLevel { + msg := "Before region balance:" + for _, task := range batchTasks { + msg += " store " + task.storeAddr + ": " + strconv.Itoa(len(task.regionInfos)) + " regions," + } + logutil.BgLogger().Debug(msg) + } + balanceStart := time.Now() + storesUnionSetForAllTasks := make([]*tikv.Store, 0, len(storeIDsUnionSetForAllTasks)) + for _, store := range aliveStores.storesInAllZones { + if _, ok := storeIDsUnionSetForAllTasks[store.StoreID()]; ok { + storesUnionSetForAllTasks = append(storesUnionSetForAllTasks, store) + } + } + batchTasks = balanceBatchCopTask(storesUnionSetForAllTasks, batchTasks, balanceWithContinuity, balanceContinuousRegionCount) + balanceElapsed := time.Since(balanceStart) + if log.GetLevel() <= zap.DebugLevel { + msg := "After region balance:" + for _, task := range batchTasks { + msg += " store " + task.storeAddr + ": " + strconv.Itoa(len(task.regionInfos)) + " regions," + } + logutil.BgLogger().Debug(msg) + } + + if elapsed := time.Since(start); elapsed > time.Millisecond*500 { + logutil.BgLogger().Warn("buildBatchCopTasksCore takes too much time", + zap.Duration("elapsed", elapsed), + zap.Duration("balanceElapsed", balanceElapsed), + zap.Int("range len", rangesLen), + zap.Int("task len", len(batchTasks))) + } + 
metrics.TxnRegionsNumHistogramWithBatchCoprocessor.Observe(float64(len(batchTasks))) + return batchTasks, nil + } +} + +func convertRegionInfosToPartitionTableRegions(batchTasks []*batchCopTask, partitionIDs []int64) { + for _, copTask := range batchTasks { + tableRegions := make([]*coprocessor.TableRegions, len(partitionIDs)) + // init coprocessor.TableRegions + for j, pid := range partitionIDs { + tableRegions[j] = &coprocessor.TableRegions{ + PhysicalTableId: pid, + } + } + // fill region infos + for _, ri := range copTask.regionInfos { + tableRegions[ri.PartitionIndex].Regions = append(tableRegions[ri.PartitionIndex].Regions, + ri.toCoprocessorRegionInfo()) + } + count := 0 + // clear empty table region + for j := 0; j < len(tableRegions); j++ { + if len(tableRegions[j].Regions) != 0 { + tableRegions[count] = tableRegions[j] + count++ + } + } + copTask.PartitionTableRegions = tableRegions[:count] + copTask.regionInfos = nil + } +} + +func (c *CopClient) sendBatch(ctx context.Context, req *kv.Request, vars *tikv.Variables, option *kv.ClientSendOption) kv.Response { + if req.KeepOrder || req.Desc { + return copErrorResponse{errors.New("batch coprocessor cannot prove keep order or desc property")} + } + ctx = context.WithValue(ctx, tikv.TxnStartKey(), req.StartTs) + bo := backoff.NewBackofferWithVars(ctx, copBuildTaskMaxBackoff, vars) + + var tasks []*batchCopTask + var err error + if req.PartitionIDAndRanges != nil { + // For Partition Table Scan + keyRanges := make([]*KeyRanges, 0, len(req.PartitionIDAndRanges)) + partitionIDs := make([]int64, 0, len(req.PartitionIDAndRanges)) + for _, pi := range req.PartitionIDAndRanges { + keyRanges = append(keyRanges, NewKeyRanges(pi.KeyRanges)) + partitionIDs = append(partitionIDs, pi.ID) + } + tasks, err = buildBatchCopTasksForPartitionedTable(ctx, bo, c.store.kvStore, keyRanges, req.StoreType, false, 0, false, 0, partitionIDs, tiflashcompute.DispatchPolicyInvalid, option.TiFlashReplicaRead, option.AppendWarning) + } else { + // TODO: merge the if branch. + ranges := NewKeyRanges(req.KeyRanges.FirstPartitionRange()) + tasks, err = buildBatchCopTasksForNonPartitionedTable(ctx, bo, c.store.kvStore, ranges, req.StoreType, false, 0, false, 0, tiflashcompute.DispatchPolicyInvalid, option.TiFlashReplicaRead, option.AppendWarning) + } + + if err != nil { + return copErrorResponse{err} + } + it := &batchCopIterator{ + store: c.store.kvStore, + req: req, + finishCh: make(chan struct{}), + vars: vars, + rpcCancel: tikv.NewRPCanceller(), + enableCollectExecutionInfo: option.EnableCollectExecutionInfo, + tiflashReplicaReadPolicy: option.TiFlashReplicaRead, + appendWarning: option.AppendWarning, + } + ctx = context.WithValue(ctx, tikv.RPCCancellerCtxKey{}, it.rpcCancel) + it.tasks = tasks + it.respChan = make(chan *batchCopResponse, 2048) + go it.run(ctx) + return it +} + +type batchCopIterator struct { + store *kvStore + req *kv.Request + finishCh chan struct{} + + tasks []*batchCopTask + + // Batch results are stored in respChan. + respChan chan *batchCopResponse + + vars *tikv.Variables + + rpcCancel *tikv.RPCCanceller + + wg sync.WaitGroup + // closed represents when the Close is called. + // There are two cases we need to close the `finishCh` channel, one is when context is done, the other one is + // when the Close is called. we use atomic.CompareAndSwap `closed` to to make sure the channel is not closed twice. 
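The closed flag is flipped with compare-and-swap so that finishCh is closed exactly once even when Close and context cancellation race. A tiny standalone illustration of that idiom:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

type closer struct {
	closed   uint32
	finishCh chan struct{}
}

// closeOnce closes finishCh at most once, no matter how many goroutines call it.
func (c *closer) closeOnce() {
	if atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
		close(c.finishCh)
	}
}

func main() {
	c := &closer{finishCh: make(chan struct{})}
	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func() { defer wg.Done(); c.closeOnce() }() // no double-close panic
	}
	wg.Wait()
	_, open := <-c.finishCh
	fmt.Println("channel open:", open) // false: channel is closed
}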
+ closed uint32 + + enableCollectExecutionInfo bool + tiflashReplicaReadPolicy tiflash.ReplicaRead + appendWarning func(error) +} + +func (b *batchCopIterator) run(ctx context.Context) { + // We run workers for every batch cop. + for _, task := range b.tasks { + b.wg.Add(1) + boMaxSleep := CopNextMaxBackoff + failpoint.Inject("ReduceCopNextMaxBackoff", func(value failpoint.Value) { + if value.(bool) { + boMaxSleep = 2 + } + }) + bo := backoff.NewBackofferWithVars(ctx, boMaxSleep, b.vars) + go b.handleTask(ctx, bo, task) + } + b.wg.Wait() + close(b.respChan) +} + +// Next returns next coprocessor result. +// NOTE: Use nil to indicate finish, so if the returned ResultSubset is not nil, reader should continue to call Next(). +func (b *batchCopIterator) Next(ctx context.Context) (kv.ResultSubset, error) { + var ( + resp *batchCopResponse + ok bool + closed bool + ) + + // Get next fetched resp from chan + resp, ok, closed = b.recvFromRespCh(ctx) + if !ok || closed { + return nil, nil + } + + if resp.err != nil { + return nil, errors.Trace(resp.err) + } + + err := b.store.CheckVisibility(b.req.StartTs) + if err != nil { + return nil, errors.Trace(err) + } + return resp, nil +} + +func (b *batchCopIterator) recvFromRespCh(ctx context.Context) (resp *batchCopResponse, ok bool, exit bool) { + ticker := time.NewTicker(3 * time.Second) + defer ticker.Stop() + for { + select { + case resp, ok = <-b.respChan: + return + case <-ticker.C: + killed := atomic.LoadUint32(b.vars.Killed) + if killed != 0 { + logutil.Logger(ctx).Info( + "a killed signal is received", + zap.Uint32("signal", killed), + ) + resp = &batchCopResponse{err: derr.ErrQueryInterrupted} + ok = true + return + } + case <-b.finishCh: + exit = true + return + case <-ctx.Done(): + // We select the ctx.Done() in the thread of `Next` instead of in the worker to avoid the cost of `WithCancel`. + if atomic.CompareAndSwapUint32(&b.closed, 0, 1) { + close(b.finishCh) + } + exit = true + return + } + } +} + +// Close releases the resource. +func (b *batchCopIterator) Close() error { + if atomic.CompareAndSwapUint32(&b.closed, 0, 1) { + close(b.finishCh) + } + b.rpcCancel.CancelAll() + b.wg.Wait() + return nil +} + +func (b *batchCopIterator) handleTask(ctx context.Context, bo *Backoffer, task *batchCopTask) { + tasks := []*batchCopTask{task} + for idx := 0; idx < len(tasks); idx++ { + ret, err := b.handleTaskOnce(ctx, bo, tasks[idx]) + if err != nil { + resp := &batchCopResponse{err: errors.Trace(err), detail: new(CopRuntimeStats)} + b.sendToRespCh(resp) + break + } + tasks = append(tasks, ret...) + } + b.wg.Done() +} + +// Merge all ranges and request again. 
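handleTask drains a slice that may grow while it is being walked: sub-tasks produced by a retry are appended and picked up by the same index loop. A compact sketch of that pattern (processOnce is a hypothetical stand-in for handleTaskOnce):

package main

import "fmt"

// drain processes items in order; processOnce may return follow-up items,
// which are appended and handled by the same loop.
func drain(initial []int, processOnce func(int) []int) int {
	items := initial
	handled := 0
	for idx := 0; idx < len(items); idx++ {
		items = append(items, processOnce(items[idx])...)
		handled++
	}
	return handled
}

func main() {
	// 10 splits once into two children, which produce no further work.
	n := drain([]int{10}, func(v int) []int {
		if v == 10 {
			return []int{11, 12}
		}
		return nil
	})
	fmt.Println("handled", n, "tasks") // handled 3 tasks
}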
+func (b *batchCopIterator) retryBatchCopTask(ctx context.Context, bo *backoff.Backoffer, batchTask *batchCopTask) ([]*batchCopTask, error) { + if batchTask.regionInfos != nil { + var ranges []kv.KeyRange + for _, ri := range batchTask.regionInfos { + ri.Ranges.Do(func(ran *kv.KeyRange) { + ranges = append(ranges, *ran) + }) + } + // need to make sure the key ranges is sorted + slices.SortFunc(ranges, func(i, j kv.KeyRange) int { + return bytes.Compare(i.StartKey, j.StartKey) + }) + ret, err := buildBatchCopTasksForNonPartitionedTable(ctx, bo, b.store, NewKeyRanges(ranges), b.req.StoreType, false, 0, false, 0, tiflashcompute.DispatchPolicyInvalid, b.tiflashReplicaReadPolicy, b.appendWarning) + return ret, err + } + // Retry Partition Table Scan + keyRanges := make([]*KeyRanges, 0, len(batchTask.PartitionTableRegions)) + pid := make([]int64, 0, len(batchTask.PartitionTableRegions)) + for _, trs := range batchTask.PartitionTableRegions { + pid = append(pid, trs.PhysicalTableId) + ranges := make([]kv.KeyRange, 0, len(trs.Regions)) + for _, ri := range trs.Regions { + for _, ran := range ri.Ranges { + ranges = append(ranges, kv.KeyRange{ + StartKey: ran.Start, + EndKey: ran.End, + }) + } + } + // need to make sure the key ranges is sorted + slices.SortFunc(ranges, func(i, j kv.KeyRange) int { + return bytes.Compare(i.StartKey, j.StartKey) + }) + keyRanges = append(keyRanges, NewKeyRanges(ranges)) + } + ret, err := buildBatchCopTasksForPartitionedTable(ctx, bo, b.store, keyRanges, b.req.StoreType, false, 0, false, 0, pid, tiflashcompute.DispatchPolicyInvalid, b.tiflashReplicaReadPolicy, b.appendWarning) + return ret, err +} + +// TiFlashReadTimeoutUltraLong represents the max time that tiflash request may take, since it may scan many regions for tiflash. +const TiFlashReadTimeoutUltraLong = 3600 * time.Second + +func (b *batchCopIterator) handleTaskOnce(ctx context.Context, bo *backoff.Backoffer, task *batchCopTask) ([]*batchCopTask, error) { + sender := NewRegionBatchRequestSender(b.store.GetRegionCache(), b.store.GetTiKVClient(), b.enableCollectExecutionInfo) + var regionInfos = make([]*coprocessor.RegionInfo, 0, len(task.regionInfos)) + for _, ri := range task.regionInfos { + regionInfos = append(regionInfos, ri.toCoprocessorRegionInfo()) + } + + copReq := coprocessor.BatchRequest{ + Tp: b.req.Tp, + StartTs: b.req.StartTs, + Data: b.req.Data, + SchemaVer: b.req.SchemaVar, + Regions: regionInfos, + TableRegions: task.PartitionTableRegions, + ConnectionId: b.req.ConnID, + ConnectionAlias: b.req.ConnAlias, + } + + rgName := b.req.ResourceGroupName + if !variable.EnableResourceControl.Load() { + rgName = "" + } + req := tikvrpc.NewRequest(task.cmdType, &copReq, kvrpcpb.Context{ + IsolationLevel: isolationLevelToPB(b.req.IsolationLevel), + Priority: priorityToPB(b.req.Priority), + NotFillCache: b.req.NotFillCache, + RecordTimeStat: true, + RecordScanStat: true, + TaskId: b.req.TaskID, + ResourceControlContext: &kvrpcpb.ResourceControlContext{ + ResourceGroupName: rgName, + }, + }) + if b.req.ResourceGroupTagger != nil { + b.req.ResourceGroupTagger(req) + } + req.StoreTp = getEndPointType(kv.TiFlash) + + logutil.BgLogger().Debug("send batch request to ", zap.String("req info", req.String()), zap.Int("cop task len", len(task.regionInfos))) + resp, retry, cancel, err := sender.SendReqToAddr(bo, task.ctx, task.regionInfos, req, TiFlashReadTimeoutUltraLong) + // If there are store errors, we should retry for all regions. 
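As the comments in retryBatchCopTask note, the ranges recovered from a failed batch must be re-sorted by start key before tasks are rebuilt. The sort itself is just slices.SortFunc over bytes.Compare, as in this standalone sketch:

package main

import (
	"bytes"
	"fmt"
	"slices"
)

type keyRange struct {
	StartKey, EndKey []byte
}

// sortRanges orders ranges by their start key, as required before rebuilding tasks.
func sortRanges(ranges []keyRange) {
	slices.SortFunc(ranges, func(a, b keyRange) int {
		return bytes.Compare(a.StartKey, b.StartKey)
	})
}

func main() {
	rs := []keyRange{
		{StartKey: []byte("t\x80c"), EndKey: []byte("t\x80d")},
		{StartKey: []byte("t\x80a"), EndKey: []byte("t\x80b")},
	}
	sortRanges(rs)
	for _, r := range rs {
		fmt.Printf("%q -> %q\n", r.StartKey, r.EndKey)
	}
}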
+ if retry { + return b.retryBatchCopTask(ctx, bo, task) + } + if err != nil { + err = derr.ToTiDBErr(err) + return nil, errors.Trace(err) + } + defer cancel() + return nil, b.handleStreamedBatchCopResponse(ctx, bo, resp.Resp.(*tikvrpc.BatchCopStreamResponse), task) +} + +func (b *batchCopIterator) handleStreamedBatchCopResponse(ctx context.Context, bo *Backoffer, response *tikvrpc.BatchCopStreamResponse, task *batchCopTask) (err error) { + defer response.Close() + resp := response.BatchResponse + if resp == nil { + // streaming request returns io.EOF, so the first Response is nil. + return + } + for { + err = b.handleBatchCopResponse(bo, resp, task) + if err != nil { + return errors.Trace(err) + } + resp, err = response.Recv() + if err != nil { + if errors.Cause(err) == io.EOF { + return nil + } + + if err1 := bo.Backoff(tikv.BoTiKVRPC(), errors.Errorf("recv stream response error: %v, task store addr: %s", err, task.storeAddr)); err1 != nil { + return errors.Trace(err) + } + + // No coprocessor.Response for network error, rebuild task based on the last success one. + if errors.Cause(err) == context.Canceled { + logutil.BgLogger().Info("stream recv timeout", zap.Error(err)) + } else { + logutil.BgLogger().Info("stream unknown error", zap.Error(err)) + } + return derr.ErrTiFlashServerTimeout + } + } +} + +func (b *batchCopIterator) handleBatchCopResponse(bo *Backoffer, response *coprocessor.BatchResponse, task *batchCopTask) (err error) { + if otherErr := response.GetOtherError(); otherErr != "" { + err = errors.Errorf("other error: %s", otherErr) + logutil.BgLogger().Warn("other error", + zap.Uint64("txnStartTS", b.req.StartTs), + zap.String("storeAddr", task.storeAddr), + zap.Error(err)) + return errors.Trace(err) + } + + if len(response.RetryRegions) > 0 { + logutil.BgLogger().Info("multiple regions are stale and need to be refreshed", zap.Int("region size", len(response.RetryRegions))) + for idx, retry := range response.RetryRegions { + id := tikv.NewRegionVerID(retry.Id, retry.RegionEpoch.ConfVer, retry.RegionEpoch.Version) + logutil.BgLogger().Info("invalid region because tiflash detected stale region", zap.String("region id", id.String())) + b.store.GetRegionCache().InvalidateCachedRegionWithReason(id, tikv.EpochNotMatch) + if idx >= 10 { + logutil.BgLogger().Info("stale regions are too many, so we omit the rest ones") + break + } + } + return + } + + resp := &batchCopResponse{ + pbResp: response, + detail: new(CopRuntimeStats), + } + + b.handleCollectExecutionInfo(bo, resp, task) + b.sendToRespCh(resp) + + return +} + +func (b *batchCopIterator) sendToRespCh(resp *batchCopResponse) (exit bool) { + select { + case b.respChan <- resp: + case <-b.finishCh: + exit = true + } + return +} + +func (b *batchCopIterator) handleCollectExecutionInfo(bo *Backoffer, resp *batchCopResponse, task *batchCopTask) { + if !b.enableCollectExecutionInfo { + return + } + backoffTimes := bo.GetBackoffTimes() + resp.detail.BackoffTime = time.Duration(bo.GetTotalSleep()) * time.Millisecond + resp.detail.BackoffSleep = make(map[string]time.Duration, len(backoffTimes)) + resp.detail.BackoffTimes = make(map[string]int, len(backoffTimes)) + for backoff := range backoffTimes { + resp.detail.BackoffTimes[backoff] = backoffTimes[backoff] + resp.detail.BackoffSleep[backoff] = time.Duration(bo.GetBackoffSleepMS()[backoff]) * time.Millisecond + } + resp.detail.CalleeAddress = task.storeAddr +} + +// Only called when UseAutoScaler is false. 
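handleStreamedBatchCopResponse keeps pulling chunks until Recv reports io.EOF; the sketch below reduces that loop to its core and leaves out the real code's backoff and ErrTiFlashServerTimeout handling (chunkStream and fakeStream are stand-ins, not tikvrpc types):

package main

import (
	"errors"
	"fmt"
	"io"
)

type chunkStream interface {
	Recv() ([]byte, error)
}

// consume handles every chunk until EOF; any other error aborts the loop.
func consume(s chunkStream, handle func([]byte) error) error {
	for {
		chunk, err := s.Recv()
		if err != nil {
			if errors.Is(err, io.EOF) {
				return nil // stream finished cleanly
			}
			return fmt.Errorf("stream recv error: %w", err)
		}
		if err := handle(chunk); err != nil {
			return err
		}
	}
}

type fakeStream struct{ left int }

func (f *fakeStream) Recv() ([]byte, error) {
	if f.left == 0 {
		return nil, io.EOF
	}
	f.left--
	return []byte("chunk"), nil
}

func main() {
	err := consume(&fakeStream{left: 2}, func(b []byte) error { fmt.Println(string(b)); return nil })
	fmt.Println("err:", err)
}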
+func buildBatchCopTasksConsistentHashForPD(bo *backoff.Backoffer, + kvStore *kvStore, + rangesForEachPhysicalTable []*KeyRanges, + storeType kv.StoreType, + ttl time.Duration, + dispatchPolicy tiflashcompute.DispatchPolicy) (res []*batchCopTask, err error) { + failpointCheckWhichPolicy(dispatchPolicy) + const cmdType = tikvrpc.CmdBatchCop + var ( + retryNum int + rangesLen int + copTaskNum int + splitKeyElapsed time.Duration + getStoreElapsed time.Duration + ) + cache := kvStore.GetRegionCache() + start := time.Now() + + for { + retryNum++ + rangesLen = 0 + tasks := make([]*copTask, 0) + regionIDs := make([]tikv.RegionVerID, 0) + + splitKeyStart := time.Now() + for i, ranges := range rangesForEachPhysicalTable { + rangesLen += ranges.Len() + locations, err := cache.SplitKeyRangesByLocations(bo, ranges, UnspecifiedLimit, false, false) + if err != nil { + return nil, errors.Trace(err) + } + for _, lo := range locations { + tasks = append(tasks, &copTask{ + region: lo.Location.Region, + ranges: lo.Ranges, + cmdType: cmdType, + storeType: storeType, + partitionIndex: int64(i), + }) + regionIDs = append(regionIDs, lo.Location.Region) + } + } + splitKeyElapsed += time.Since(splitKeyStart) + + getStoreStart := time.Now() + stores, err := cache.GetTiFlashComputeStores(bo.TiKVBackoffer()) + if err != nil { + return nil, err + } + stores = filterAliveStores(bo.GetCtx(), stores, ttl, kvStore) + if len(stores) == 0 { + return nil, errors.New("tiflash_compute node is unavailable") + } + getStoreElapsed = time.Since(getStoreStart) + + storesStr := make([]string, 0, len(stores)) + for _, s := range stores { + storesStr = append(storesStr, s.GetAddr()) + } + var rpcCtxs []*tikv.RPCContext + if dispatchPolicy == tiflashcompute.DispatchPolicyRR { + rpcCtxs, err = getTiFlashComputeRPCContextByRoundRobin(regionIDs, storesStr) + } else if dispatchPolicy == tiflashcompute.DispatchPolicyConsistentHash { + rpcCtxs, err = getTiFlashComputeRPCContextByConsistentHash(regionIDs, storesStr) + } else { + err = errors.Errorf("unexpected dispatch policy %v", dispatchPolicy) + } + if err != nil { + return nil, err + } + if rpcCtxs == nil { + logutil.BgLogger().Info("buildBatchCopTasksConsistentHashForPD retry because rcpCtx is nil", zap.Int("retryNum", retryNum)) + err := bo.Backoff(tikv.BoTiFlashRPC(), errors.New("Cannot find region with TiFlash peer")) + if err != nil { + return nil, errors.Trace(err) + } + continue + } + if len(rpcCtxs) != len(tasks) { + return nil, errors.Errorf("length should be equal, len(rpcCtxs): %d, len(tasks): %d", len(rpcCtxs), len(tasks)) + } + copTaskNum = len(tasks) + taskMap := make(map[string]*batchCopTask) + for i, rpcCtx := range rpcCtxs { + regionInfo := RegionInfo{ + // tasks and rpcCtxs are correspond to each other. 
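With DispatchPolicyRR each region is handed to the next tiflash_compute node in turn, while DispatchPolicyConsistentHash keys on the region so assignments stay stable when the node list changes. A toy round-robin assignment under those assumptions:

package main

import "fmt"

// assignRoundRobin maps each region ID to a store address by cycling
// through the address list.
func assignRoundRobin(regionIDs []uint64, addrs []string) map[uint64]string {
	out := make(map[uint64]string, len(regionIDs))
	for i, id := range regionIDs {
		out[id] = addrs[i%len(addrs)]
	}
	return out
}

func main() {
	fmt.Println(assignRoundRobin([]uint64{1, 2, 3, 4}, []string{"c0:3930", "c1:3930"}))
}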
+ Region: tasks[i].region, + Ranges: tasks[i].ranges, + PartitionIndex: tasks[i].partitionIndex, + } + if batchTask, ok := taskMap[rpcCtx.Addr]; ok { + batchTask.regionInfos = append(batchTask.regionInfos, regionInfo) + } else { + batchTask := &batchCopTask{ + storeAddr: rpcCtx.Addr, + cmdType: cmdType, + ctx: rpcCtx, + regionInfos: []RegionInfo{regionInfo}, + } + taskMap[rpcCtx.Addr] = batchTask + res = append(res, batchTask) + } + } + logutil.BgLogger().Info("buildBatchCopTasksConsistentHashForPD done", + zap.Any("len(tasks)", len(taskMap)), + zap.Any("len(tiflash_compute)", len(stores)), + zap.Any("dispatchPolicy", tiflashcompute.GetDispatchPolicy(dispatchPolicy))) + if log.GetLevel() <= zap.DebugLevel { + debugTaskMap := make(map[string]string, len(taskMap)) + for s, b := range taskMap { + debugTaskMap[s] = fmt.Sprintf("addr: %s; regionInfos: %v", b.storeAddr, b.regionInfos) + } + logutil.BgLogger().Debug("detailed info buildBatchCopTasksConsistentHashForPD", zap.Any("taskMap", debugTaskMap), zap.Any("allStores", storesStr)) + } + break + } + + if elapsed := time.Since(start); elapsed > time.Millisecond*500 { + logutil.BgLogger().Warn("buildBatchCopTasksConsistentHashForPD takes too much time", + zap.Duration("total elapsed", elapsed), + zap.Int("retryNum", retryNum), + zap.Duration("splitKeyElapsed", splitKeyElapsed), + zap.Duration("getStoreElapsed", getStoreElapsed), + zap.Int("range len", rangesLen), + zap.Int("copTaskNum", copTaskNum), + zap.Int("batchCopTaskNum", len(res))) + } + failpointCheckForConsistentHash(res) + return res, nil +} diff --git a/pkg/store/copr/binding__failpoint_binding__.go b/pkg/store/copr/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..3a49604367f85 --- /dev/null +++ b/pkg/store/copr/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package copr + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/store/copr/coprocessor.go b/pkg/store/copr/coprocessor.go index 6b59d8096784d..4b831d0c72e3b 100644 --- a/pkg/store/copr/coprocessor.go +++ b/pkg/store/copr/coprocessor.go @@ -111,9 +111,9 @@ func (c *CopClient) Send(ctx context.Context, req *kv.Request, variables any, op // BuildCopIterator builds the iterator without calling `open`. 
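The generated _curpkg_ helper simply prefixes the failpoint name with the package path, which is also the fully qualified name a test would use to switch the failpoint on. A hedged sketch of enabling the DisablePaging failpoint around a test body (the testify usage and the test itself are illustrative, not taken from this patch):

package copr_test

import (
	"testing"

	"github.com/pingcap/failpoint"
	"github.com/stretchr/testify/require"
)

func TestDisablePagingFailpoint(t *testing.T) {
	fp := "github.com/pingcap/tidb/pkg/store/copr/DisablePaging"
	// "return(true)" makes failpoint.Eval succeed inside BuildCopIterator.
	require.NoError(t, failpoint.Enable(fp, "return(true)"))
	defer func() { require.NoError(t, failpoint.Disable(fp)) }()

	// ... build a coprocessor request here and assert req.Paging.Enable is false.
}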
func (c *CopClient) BuildCopIterator(ctx context.Context, req *kv.Request, vars *tikv.Variables, option *kv.ClientSendOption) (*copIterator, kv.Response) { eventCb := option.EventCb - failpoint.Inject("DisablePaging", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("DisablePaging")); _err_ == nil { req.Paging.Enable = false - }) + } if req.StoreType == kv.TiDB { // coprocessor on TiDB doesn't support paging req.Paging.Enable = false @@ -122,13 +122,13 @@ func (c *CopClient) BuildCopIterator(ctx context.Context, req *kv.Request, vars // coprocessor request but type is not DAG req.Paging.Enable = false } - failpoint.Inject("checkKeyRangeSortedForPaging", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("checkKeyRangeSortedForPaging")); _err_ == nil { if req.Paging.Enable { if !req.KeyRanges.IsFullySorted() { logutil.BgLogger().Fatal("distsql request key range not sorted!") } } - }) + } if !checkStoreBatchCopr(req) { req.StoreBatchSize = 0 } @@ -327,11 +327,11 @@ func buildCopTasks(bo *Backoffer, ranges *KeyRanges, opt *buildCopTaskOpt) ([]*c } rangesPerTaskLimit := rangesPerTask - failpoint.Inject("setRangesPerTask", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("setRangesPerTask")); _err_ == nil { if v, ok := val.(int); ok { rangesPerTaskLimit = v } - }) + } // TODO(youjiali1995): is there any request type that needn't be split by buckets? locs, err := cache.SplitKeyRangesByBuckets(bo, ranges) @@ -789,12 +789,12 @@ func init() { // send the result back. func (worker *copIteratorWorker) run(ctx context.Context) { defer func() { - failpoint.Inject("ticase-4169", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("ticase-4169")); _err_ == nil { if val.(bool) { worker.memTracker.Consume(10 * MockResponseSizeForTest) worker.memTracker.Consume(10 * MockResponseSizeForTest) } - }) + } worker.wg.Done() }() // 16KB ballast helps grow the stack to the requirement of copIteratorWorker. 
@@ -865,12 +865,12 @@ func (it *copIterator) open(ctx context.Context, enabledRateLimitAction, enableC } taskSender.respChan = it.respChan it.actionOnExceed.setEnabled(enabledRateLimitAction) - failpoint.Inject("ticase-4171", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("ticase-4171")); _err_ == nil { if val.(bool) { it.memTracker.Consume(10 * MockResponseSizeForTest) it.memTracker.Consume(10 * MockResponseSizeForTest) } - }) + } go taskSender.run(it.req.ConnID) } @@ -897,7 +897,7 @@ func (sender *copIteratorTaskSender) run(connID uint64) { break } if connID > 0 { - failpoint.Inject("pauseCopIterTaskSender", func() {}) + failpoint.Eval(_curpkg_("pauseCopIterTaskSender")) } } close(sender.taskCh) @@ -911,7 +911,7 @@ func (sender *copIteratorTaskSender) run(connID uint64) { } func (it *copIterator) recvFromRespCh(ctx context.Context, respCh <-chan *copResponse) (resp *copResponse, ok bool, exit bool) { - failpoint.InjectCall("CtxCancelBeforeReceive", ctx) + failpoint.Call(_curpkg_("CtxCancelBeforeReceive"), ctx) ticker := time.NewTicker(3 * time.Second) defer ticker.Stop() for { @@ -919,13 +919,13 @@ func (it *copIterator) recvFromRespCh(ctx context.Context, respCh <-chan *copRes case resp, ok = <-respCh: if it.memTracker != nil && resp != nil { consumed := resp.MemSize() - failpoint.Inject("testRateLimitActionMockConsumeAndAssert", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("testRateLimitActionMockConsumeAndAssert")); _err_ == nil { if val.(bool) { if resp != finCopResp { consumed = MockResponseSizeForTest } } - }) + } it.memTracker.Consume(-consumed) } return @@ -991,14 +991,14 @@ func (sender *copIteratorTaskSender) sendToTaskCh(t *copTask, sendTo chan<- *cop func (worker *copIteratorWorker) sendToRespCh(resp *copResponse, respCh chan<- *copResponse, checkOOM bool) (exit bool) { if worker.memTracker != nil && checkOOM { consumed := resp.MemSize() - failpoint.Inject("testRateLimitActionMockConsumeAndAssert", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("testRateLimitActionMockConsumeAndAssert")); _err_ == nil { if val.(bool) { if resp != finCopResp { consumed = MockResponseSizeForTest } } - }) - failpoint.Inject("ConsumeRandomPanic", nil) + } + failpoint.Eval(_curpkg_("ConsumeRandomPanic")) worker.memTracker.Consume(consumed) } select { @@ -1022,16 +1022,16 @@ func (it *copIterator) Next(ctx context.Context) (kv.ResultSubset, error) { ) defer func() { if resp == nil { - failpoint.Inject("ticase-4170", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("ticase-4170")); _err_ == nil { if val.(bool) { it.memTracker.Consume(10 * MockResponseSizeForTest) it.memTracker.Consume(10 * MockResponseSizeForTest) } - }) + } } }() // wait unit at least 5 copResponse received. - failpoint.Inject("testRateLimitActionMockWaitMax", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("testRateLimitActionMockWaitMax")); _err_ == nil { if val.(bool) { // we only need to trigger oom at least once. if len(it.tasks) > 9 { @@ -1040,7 +1040,7 @@ func (it *copIterator) Next(ctx context.Context) (kv.ResultSubset, error) { } } } - }) + } // If data order matters, response should be returned in the same order as copTask slice. // Otherwise all responses are returned from a single channel. 
if it.respChan != nil { @@ -1117,11 +1117,11 @@ func chooseBackoffer(ctx context.Context, backoffermap map[uint64]*Backoffer, ta return bo } boMaxSleep := CopNextMaxBackoff - failpoint.Inject("ReduceCopNextMaxBackoff", func(value failpoint.Value) { + if value, _err_ := failpoint.Eval(_curpkg_("ReduceCopNextMaxBackoff")); _err_ == nil { if value.(bool) { boMaxSleep = 2 } - }) + } newbo := backoff.NewBackofferWithVars(ctx, boMaxSleep, worker.vars) backoffermap[task.region.GetID()] = newbo return newbo @@ -1165,11 +1165,11 @@ func (worker *copIteratorWorker) handleTask(ctx context.Context, task *copTask, // handleTaskOnce handles single copTask, successful results are send to channel. // If error happened, returns error. If region split or meet lock, returns the remain tasks. func (worker *copIteratorWorker) handleTaskOnce(bo *Backoffer, task *copTask, ch chan<- *copResponse) ([]*copTask, error) { - failpoint.Inject("handleTaskOnceError", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("handleTaskOnceError")); _err_ == nil { if val.(bool) { - failpoint.Return(nil, errors.New("mock handleTaskOnce error")) + return nil, errors.New("mock handleTaskOnce error") } - }) + } if task.paging { task.pagingTaskIdx = atomic.AddUint32(worker.pagingTaskIdx, 1) @@ -1222,10 +1222,10 @@ func (worker *copIteratorWorker) handleTaskOnce(bo *Backoffer, task *copTask, ch if task.tikvClientReadTimeout > 0 { timeout = time.Duration(task.tikvClientReadTimeout) * time.Millisecond } - failpoint.Inject("sleepCoprRequest", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("sleepCoprRequest")); _err_ == nil { //nolint:durationcheck time.Sleep(time.Millisecond * time.Duration(v.(int))) - }) + } if worker.req.RunawayChecker != nil { if err := worker.req.RunawayChecker.BeforeCopRequest(req); err != nil { @@ -1259,14 +1259,14 @@ func (worker *copIteratorWorker) handleTaskOnce(bo *Backoffer, task *copTask, ch timeout, getEndPointType(task.storeType), task.storeAddr, ops...) 
err = derr.ToTiDBErr(err) if worker.req.RunawayChecker != nil { - failpoint.Inject("sleepCoprAfterReq", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("sleepCoprAfterReq")); _err_ == nil { //nolint:durationcheck value := v.(int) time.Sleep(time.Millisecond * time.Duration(value)) if value > 50 { err = errors.Errorf("Coprocessor task terminated due to exceeding the deadline") } - }) + } err = worker.req.RunawayChecker.CheckCopRespError(err) } if err != nil { @@ -1540,9 +1540,9 @@ func (worker *copIteratorWorker) handleBatchCopResponse(bo *Backoffer, rpcCtx *t }, } task := batchedTask.task - failpoint.Inject("batchCopRegionError", func() { + if _, _err_ := failpoint.Eval(_curpkg_("batchCopRegionError")); _err_ == nil { batchResp.RegionError = &errorpb.Error{} - }) + } if regionErr := batchResp.GetRegionError(); regionErr != nil { errStr := fmt.Sprintf("region_id:%v, region_ver:%v, store_type:%s, peer_addr:%s, error:%s", task.region.GetID(), task.region.GetVer(), task.storeType.Name(), task.storeAddr, regionErr.String()) @@ -1776,11 +1776,11 @@ func (worker *copIteratorWorker) handleCollectExecutionInfo(bo *Backoffer, rpcCt if !worker.enableCollectExecutionInfo { return } - failpoint.Inject("disable-collect-execution", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("disable-collect-execution")); _err_ == nil { if val.(bool) { panic("shouldn't reachable") } - }) + } if resp.detail == nil { resp.detail = new(CopRuntimeStats) } @@ -2023,13 +2023,13 @@ func (e *rateLimitAction) Action(t *memory.Tracker) { } return } - failpoint.Inject("testRateLimitActionMockConsumeAndAssert", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("testRateLimitActionMockConsumeAndAssert")); _err_ == nil { if val.(bool) { if e.cond.triggerCountForTest+e.cond.remainingTokenNum != e.totalTokenNum { panic("triggerCount + remainingTokenNum not equal to totalTokenNum") } } - }) + } logutil.BgLogger().Info("memory exceeds quota, destroy one token now.", zap.Int64("consumed", t.BytesConsumed()), zap.Int64("quota", t.GetBytesLimit()), @@ -2143,9 +2143,9 @@ func optRowHint(req *kv.Request) bool { // disable extra concurrency for internal tasks. return false } - failpoint.Inject("disableFixedRowCountHint", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("disableFixedRowCountHint")); _err_ == nil { opt = false - }) + } return opt } diff --git a/pkg/store/copr/coprocessor.go__failpoint_stash__ b/pkg/store/copr/coprocessor.go__failpoint_stash__ new file mode 100644 index 0000000000000..6b59d8096784d --- /dev/null +++ b/pkg/store/copr/coprocessor.go__failpoint_stash__ @@ -0,0 +1,2170 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package copr + +import ( + "context" + "fmt" + "math" + "net" + "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + "unsafe" + + "github.com/gogo/protobuf/proto" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/coprocessor" + "github.com/pingcap/kvproto/pkg/errorpb" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/domain/resourcegroup" + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/kv" + tidbmetrics "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + copr_metrics "github.com/pingcap/tidb/pkg/store/copr/metrics" + "github.com/pingcap/tidb/pkg/store/driver/backoff" + derr "github.com/pingcap/tidb/pkg/store/driver/error" + "github.com/pingcap/tidb/pkg/store/driver/options" + util2 "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/paging" + "github.com/pingcap/tidb/pkg/util/size" + "github.com/pingcap/tidb/pkg/util/tracing" + "github.com/pingcap/tidb/pkg/util/trxevents" + "github.com/pingcap/tipb/go-tipb" + "github.com/tikv/client-go/v2/metrics" + "github.com/tikv/client-go/v2/tikv" + "github.com/tikv/client-go/v2/tikvrpc" + "github.com/tikv/client-go/v2/tikvrpc/interceptor" + "github.com/tikv/client-go/v2/txnkv/txnlock" + "github.com/tikv/client-go/v2/txnkv/txnsnapshot" + "github.com/tikv/client-go/v2/util" + "go.uber.org/zap" +) + +// Maximum total sleep time(in ms) for kv/cop commands. +const ( + copBuildTaskMaxBackoff = 5000 + CopNextMaxBackoff = 20000 + CopSmallTaskRow = 32 // 32 is the initial batch size of TiKV + smallTaskSigma = 0.5 + smallConcPerCore = 20 +) + +// CopClient is coprocessor client. +type CopClient struct { + kv.RequestTypeSupportedChecker + store *Store + replicaReadSeed uint32 +} + +// Send builds the request and gets the coprocessor iterator response. +func (c *CopClient) Send(ctx context.Context, req *kv.Request, variables any, option *kv.ClientSendOption) kv.Response { + vars, ok := variables.(*tikv.Variables) + if !ok { + return copErrorResponse{errors.Errorf("unsupported variables:%+v", variables)} + } + if req.StoreType == kv.TiFlash && req.BatchCop { + logutil.BgLogger().Debug("send batch requests") + return c.sendBatch(ctx, req, vars, option) + } + ctx = context.WithValue(ctx, tikv.TxnStartKey(), req.StartTs) + ctx = context.WithValue(ctx, util.RequestSourceKey, req.RequestSource) + ctx = interceptor.WithRPCInterceptor(ctx, interceptor.GetRPCInterceptorFromCtx(ctx)) + enabledRateLimitAction := option.EnabledRateLimitAction + sessionMemTracker := option.SessionMemTracker + it, errRes := c.BuildCopIterator(ctx, req, vars, option) + if errRes != nil { + return errRes + } + ctx = context.WithValue(ctx, tikv.RPCCancellerCtxKey{}, it.rpcCancel) + if sessionMemTracker != nil && enabledRateLimitAction { + sessionMemTracker.FallbackOldAndSetNewAction(it.actionOnExceed) + } + it.open(ctx, enabledRateLimitAction, option.EnableCollectExecutionInfo) + return it +} + +// BuildCopIterator builds the iterator without calling `open`. 
+func (c *CopClient) BuildCopIterator(ctx context.Context, req *kv.Request, vars *tikv.Variables, option *kv.ClientSendOption) (*copIterator, kv.Response) { + eventCb := option.EventCb + failpoint.Inject("DisablePaging", func(_ failpoint.Value) { + req.Paging.Enable = false + }) + if req.StoreType == kv.TiDB { + // coprocessor on TiDB doesn't support paging + req.Paging.Enable = false + } + if req.Tp != kv.ReqTypeDAG { + // coprocessor request but type is not DAG + req.Paging.Enable = false + } + failpoint.Inject("checkKeyRangeSortedForPaging", func(_ failpoint.Value) { + if req.Paging.Enable { + if !req.KeyRanges.IsFullySorted() { + logutil.BgLogger().Fatal("distsql request key range not sorted!") + } + } + }) + if !checkStoreBatchCopr(req) { + req.StoreBatchSize = 0 + } + + bo := backoff.NewBackofferWithVars(ctx, copBuildTaskMaxBackoff, vars) + var ( + tasks []*copTask + err error + ) + tryRowHint := optRowHint(req) + elapsed := time.Duration(0) + buildOpt := &buildCopTaskOpt{ + req: req, + cache: c.store.GetRegionCache(), + eventCb: eventCb, + respChan: req.KeepOrder, + elapsed: &elapsed, + } + buildTaskFunc := func(ranges []kv.KeyRange, hints []int) error { + keyRanges := NewKeyRanges(ranges) + if tryRowHint { + buildOpt.rowHints = hints + } + tasksFromRanges, err := buildCopTasks(bo, keyRanges, buildOpt) + if err != nil { + return err + } + if len(tasks) == 0 { + tasks = tasksFromRanges + return nil + } + tasks = append(tasks, tasksFromRanges...) + return nil + } + // Here we build the task by partition, not directly by region. + // This is because it's possible that TiDB merge multiple small partition into one region which break some assumption. + // Keep it split by partition would be more safe. + err = req.KeyRanges.ForEachPartitionWithErr(buildTaskFunc) + // only batch store requests in first build. + req.StoreBatchSize = 0 + reqType := "null" + if req.ClosestReplicaReadAdjuster != nil { + reqType = "miss" + if req.ClosestReplicaReadAdjuster(req, len(tasks)) { + reqType = "hit" + } + } + tidbmetrics.DistSQLCoprClosestReadCounter.WithLabelValues(reqType).Inc() + if err != nil { + return nil, copErrorResponse{err} + } + it := &copIterator{ + store: c.store, + req: req, + concurrency: req.Concurrency, + finishCh: make(chan struct{}), + vars: vars, + memTracker: req.MemTracker, + replicaReadSeed: c.replicaReadSeed, + rpcCancel: tikv.NewRPCanceller(), + buildTaskElapsed: *buildOpt.elapsed, + runawayChecker: req.RunawayChecker, + } + // Pipelined-dml can flush locks when it is still reading. + // The coprocessor of the txn should not be blocked by itself. + // It should be the only case where a coprocessor can read locks of the same ts. + // + // But when start_ts is not obtained from PD, + // the start_ts could conflict with another pipelined-txn's start_ts. + // in which case the locks of same ts cannot be ignored. + // We rely on the assumption: start_ts is not from PD => this is a stale read. + if !req.IsStaleness { + it.resolvedLocks.Put(req.StartTs) + } + it.tasks = tasks + if it.concurrency > len(tasks) { + it.concurrency = len(tasks) + } + if tryRowHint { + var smallTasks int + smallTasks, it.smallTaskConcurrency = smallTaskConcurrency(tasks, c.store.numcpu) + if len(tasks)-smallTasks < it.concurrency { + it.concurrency = len(tasks) - smallTasks + } + } + if it.concurrency < 1 { + // Make sure that there is at least one worker. 
+ it.concurrency = 1 + } + + if it.req.KeepOrder { + if it.smallTaskConcurrency > 20 { + it.smallTaskConcurrency = 20 + } + it.sendRate = util.NewRateLimit(2 * (it.concurrency + it.smallTaskConcurrency)) + it.respChan = nil + } else { + it.respChan = make(chan *copResponse) + it.sendRate = util.NewRateLimit(it.concurrency + it.smallTaskConcurrency) + } + it.actionOnExceed = newRateLimitAction(uint(it.sendRate.GetCapacity())) + return it, nil +} + +// copTask contains a related Region and KeyRange for a kv.Request. +type copTask struct { + taskID uint64 + region tikv.RegionVerID + bucketsVer uint64 + ranges *KeyRanges + + respChan chan *copResponse + storeAddr string + cmdType tikvrpc.CmdType + storeType kv.StoreType + + eventCb trxevents.EventCallback + paging bool + pagingSize uint64 + pagingTaskIdx uint32 + + partitionIndex int64 // used by balanceBatchCopTask in PartitionTableScan + requestSource util.RequestSource + RowCountHint int // used for extra concurrency of small tasks, -1 for unknown row count + batchTaskList map[uint64]*batchedCopTask + + // when this task is batched and the leader's wait duration exceeds the load-based threshold, + // we set this field to the target replica store ID and redirect the request to the replica. + redirect2Replica *uint64 + busyThreshold time.Duration + meetLockFallback bool + + // timeout value for one kv readonly request + tikvClientReadTimeout uint64 + // firstReadType is used to indicate the type of first read when retrying. + firstReadType string +} + +type batchedCopTask struct { + task *copTask + region coprocessor.RegionInfo + storeID uint64 + peer *metapb.Peer + loadBasedReplicaRetry bool +} + +func (r *copTask) String() string { + return fmt.Sprintf("region(%d %d %d) ranges(%d) store(%s)", + r.region.GetID(), r.region.GetConfVer(), r.region.GetVer(), r.ranges.Len(), r.storeAddr) +} + +func (r *copTask) ToPBBatchTasks() []*coprocessor.StoreBatchTask { + if len(r.batchTaskList) == 0 { + return nil + } + pbTasks := make([]*coprocessor.StoreBatchTask, 0, len(r.batchTaskList)) + for _, task := range r.batchTaskList { + storeBatchTask := &coprocessor.StoreBatchTask{ + RegionId: task.region.GetRegionId(), + RegionEpoch: task.region.GetRegionEpoch(), + Peer: task.peer, + Ranges: task.region.GetRanges(), + TaskId: task.task.taskID, + } + pbTasks = append(pbTasks, storeBatchTask) + } + return pbTasks +} + +// rangesPerTask limits the length of the ranges slice sent in one copTask. +const rangesPerTask = 25000 + +type buildCopTaskOpt struct { + req *kv.Request + cache *RegionCache + eventCb trxevents.EventCallback + respChan bool + rowHints []int + elapsed *time.Duration + // ignoreTiKVClientReadTimeout is used to ignore tikv_client_read_timeout configuration, use default timeout instead. + ignoreTiKVClientReadTimeout bool +} + +func buildCopTasks(bo *Backoffer, ranges *KeyRanges, opt *buildCopTaskOpt) ([]*copTask, error) { + req, cache, eventCb, hints := opt.req, opt.cache, opt.eventCb, opt.rowHints + start := time.Now() + defer tracing.StartRegion(bo.GetCtx(), "copr.buildCopTasks").End() + cmdType := tikvrpc.CmdCop + if req.StoreType == kv.TiDB { + return buildTiDBMemCopTasks(ranges, req) + } + rangesLen := ranges.Len() + // something went wrong, disable hints to avoid out of range index. 
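Each region's range list is cut into windows of at most rangesPerTask entries so a single coprocessor request never carries an oversized gRPC message. The chunking loop reduces to this pattern (the limit is shrunk here for the demo):

package main

import "fmt"

const rangesPerTaskLimit = 3 // the real limit is 25000; small here for the demo

// chunk splits n items into [start, end) windows of at most limit entries.
func chunk(n, limit int) [][2]int {
	var out [][2]int
	for i := 0; i < n; {
		next := min(i+limit, n)
		out = append(out, [2]int{i, next})
		i = next
	}
	return out
}

func main() {
	fmt.Println(chunk(7, rangesPerTaskLimit)) // [[0 3] [3 6] [6 7]]
}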
+ if len(hints) != rangesLen { + hints = nil + } + + rangesPerTaskLimit := rangesPerTask + failpoint.Inject("setRangesPerTask", func(val failpoint.Value) { + if v, ok := val.(int); ok { + rangesPerTaskLimit = v + } + }) + + // TODO(youjiali1995): is there any request type that needn't be split by buckets? + locs, err := cache.SplitKeyRangesByBuckets(bo, ranges) + if err != nil { + return nil, errors.Trace(err) + } + // Channel buffer is 2 for handling region split. + // In a common case, two region split tasks will not be blocked. + chanSize := 2 + // in paging request, a request will be returned in multi batches, + // enlarge the channel size to avoid the request blocked by buffer full. + if req.Paging.Enable { + chanSize = 18 + } + + var builder taskBuilder + if req.StoreBatchSize > 0 && hints != nil { + builder = newBatchTaskBuilder(bo, req, cache, req.ReplicaRead) + } else { + builder = newLegacyTaskBuilder(len(locs)) + } + origRangeIdx := 0 + for _, loc := range locs { + // TiKV will return gRPC error if the message is too large. So we need to limit the length of the ranges slice + // to make sure the message can be sent successfully. + rLen := loc.Ranges.Len() + // If this is a paging request, we set the paging size to minPagingSize, + // the size will grow every round. + pagingSize := uint64(0) + if req.Paging.Enable { + pagingSize = req.Paging.MinPagingSize + } + for i := 0; i < rLen; { + nextI := min(i+rangesPerTaskLimit, rLen) + hint := -1 + // calculate the row count hint + if hints != nil { + startKey, endKey := loc.Ranges.RefAt(i).StartKey, loc.Ranges.RefAt(nextI-1).EndKey + // move to the previous range if startKey of current range is lower than endKey of previous location. + // In the following example, task1 will move origRangeIdx to region(i, z). + // When counting the row hint for task2, we need to move origRangeIdx back to region(a, h). + // |<- region(a, h) ->| |<- region(i, z) ->| + // |<- task1 ->| |<- task2 ->| ... + if origRangeIdx > 0 && ranges.RefAt(origRangeIdx-1).EndKey.Cmp(startKey) > 0 { + origRangeIdx-- + } + hint = 0 + for nextOrigRangeIdx := origRangeIdx; nextOrigRangeIdx < ranges.Len(); nextOrigRangeIdx++ { + rangeStart := ranges.RefAt(nextOrigRangeIdx).StartKey + if rangeStart.Cmp(endKey) > 0 { + origRangeIdx = nextOrigRangeIdx + break + } + hint += hints[nextOrigRangeIdx] + } + } + task := &copTask{ + region: loc.Location.Region, + bucketsVer: loc.getBucketVersion(), + ranges: loc.Ranges.Slice(i, nextI), + cmdType: cmdType, + storeType: req.StoreType, + eventCb: eventCb, + paging: req.Paging.Enable, + pagingSize: pagingSize, + requestSource: req.RequestSource, + RowCountHint: hint, + busyThreshold: req.StoreBusyThreshold, + } + if !opt.ignoreTiKVClientReadTimeout { + task.tikvClientReadTimeout = req.TiKVClientReadTimeout + } + // only keep-order need chan inside task. + // tasks by region error will reuse the channel of parent task. + if req.KeepOrder && opt.respChan { + task.respChan = make(chan *copResponse, chanSize) + } + if err = builder.handle(task); err != nil { + return nil, err + } + i = nextI + if req.Paging.Enable { + if req.LimitSize != 0 && req.LimitSize < pagingSize { + // disable paging for small limit. 
+ task.paging = false + task.pagingSize = 0 + } else { + pagingSize = paging.GrowPagingSize(pagingSize, req.Paging.MaxPagingSize) + } + } + } + } + + if req.Desc { + builder.reverse() + } + tasks := builder.build() + elapsed := time.Since(start) + if elapsed > time.Millisecond*500 { + logutil.BgLogger().Warn("buildCopTasks takes too much time", + zap.Duration("elapsed", elapsed), + zap.Int("range len", rangesLen), + zap.Int("task len", len(tasks))) + } + if opt.elapsed != nil { + *opt.elapsed = *opt.elapsed + elapsed + } + metrics.TxnRegionsNumHistogramWithCoprocessor.Observe(float64(builder.regionNum())) + return tasks, nil +} + +type taskBuilder interface { + handle(*copTask) error + reverse() + build() []*copTask + regionNum() int +} + +type legacyTaskBuilder struct { + tasks []*copTask +} + +func newLegacyTaskBuilder(hint int) *legacyTaskBuilder { + return &legacyTaskBuilder{ + tasks: make([]*copTask, 0, hint), + } +} + +func (b *legacyTaskBuilder) handle(task *copTask) error { + b.tasks = append(b.tasks, task) + return nil +} + +func (b *legacyTaskBuilder) regionNum() int { + return len(b.tasks) +} + +func (b *legacyTaskBuilder) reverse() { + reverseTasks(b.tasks) +} + +func (b *legacyTaskBuilder) build() []*copTask { + return b.tasks +} + +type storeReplicaKey struct { + storeID uint64 + replicaRead bool +} + +type batchStoreTaskBuilder struct { + bo *Backoffer + req *kv.Request + cache *RegionCache + taskID uint64 + limit int + store2Idx map[storeReplicaKey]int + tasks []*copTask + replicaRead kv.ReplicaReadType +} + +func newBatchTaskBuilder(bo *Backoffer, req *kv.Request, cache *RegionCache, replicaRead kv.ReplicaReadType) *batchStoreTaskBuilder { + return &batchStoreTaskBuilder{ + bo: bo, + req: req, + cache: cache, + taskID: 0, + limit: req.StoreBatchSize, + store2Idx: make(map[storeReplicaKey]int, 16), + tasks: make([]*copTask, 0, 16), + replicaRead: replicaRead, + } +} + +func (b *batchStoreTaskBuilder) handle(task *copTask) (err error) { + b.taskID++ + task.taskID = b.taskID + handled := false + defer func() { + if !handled && err == nil { + // fallback to non-batch way. It's mainly caused by region miss. + b.tasks = append(b.tasks, task) + } + }() + // only batch small tasks for memory control. + if b.limit <= 0 || !isSmallTask(task) { + return nil + } + batchedTask, err := b.cache.BuildBatchTask(b.bo, b.req, task, b.replicaRead) + if err != nil { + return err + } + if batchedTask == nil { + return nil + } + key := storeReplicaKey{ + storeID: batchedTask.storeID, + replicaRead: batchedTask.loadBasedReplicaRetry, + } + if idx, ok := b.store2Idx[key]; !ok || len(b.tasks[idx].batchTaskList) >= b.limit { + if batchedTask.loadBasedReplicaRetry { + // If the task is dispatched to leader because all followers are busy, + // task.redirect2Replica != nil means the busy threshold shouldn't take effect again. + batchedTask.task.redirect2Replica = &batchedTask.storeID + } + b.tasks = append(b.tasks, batchedTask.task) + b.store2Idx[key] = len(b.tasks) - 1 + } else { + if b.tasks[idx].batchTaskList == nil { + b.tasks[idx].batchTaskList = make(map[uint64]*batchedCopTask, b.limit) + // disable paging for batched task. 
+ b.tasks[idx].paging = false + b.tasks[idx].pagingSize = 0 + } + if task.RowCountHint > 0 { + b.tasks[idx].RowCountHint += task.RowCountHint + } + b.tasks[idx].batchTaskList[task.taskID] = batchedTask + } + handled = true + return nil +} + +func (b *batchStoreTaskBuilder) regionNum() int { + // we allocate b.taskID for each region task, so the final b.taskID is equal to the related region number. + return int(b.taskID) +} + +func (b *batchStoreTaskBuilder) reverse() { + reverseTasks(b.tasks) +} + +func (b *batchStoreTaskBuilder) build() []*copTask { + return b.tasks +} + +func buildTiDBMemCopTasks(ranges *KeyRanges, req *kv.Request) ([]*copTask, error) { + servers, err := infosync.GetAllServerInfo(context.Background()) + if err != nil { + return nil, err + } + cmdType := tikvrpc.CmdCop + tasks := make([]*copTask, 0, len(servers)) + for _, ser := range servers { + if req.TiDBServerID > 0 && req.TiDBServerID != ser.ServerIDGetter() { + continue + } + + addr := net.JoinHostPort(ser.IP, strconv.FormatUint(uint64(ser.StatusPort), 10)) + tasks = append(tasks, &copTask{ + ranges: ranges, + respChan: make(chan *copResponse, 2), + cmdType: cmdType, + storeType: req.StoreType, + storeAddr: addr, + RowCountHint: -1, + }) + } + return tasks, nil +} + +func reverseTasks(tasks []*copTask) { + for i := 0; i < len(tasks)/2; i++ { + j := len(tasks) - i - 1 + tasks[i], tasks[j] = tasks[j], tasks[i] + } +} + +func isSmallTask(task *copTask) bool { + // strictly, only RowCountHint == -1 stands for unknown task rows, + // but when RowCountHint == 0, it may be caused by initialized value, + // to avoid the future bugs, let the tasks with RowCountHint == 0 be non-small tasks. + return task.RowCountHint > 0 && + (len(task.batchTaskList) == 0 && task.RowCountHint <= CopSmallTaskRow) || + (len(task.batchTaskList) > 0 && task.RowCountHint <= 2*CopSmallTaskRow) +} + +// smallTaskConcurrency counts the small tasks of tasks, +// then returns the task count and extra concurrency for small tasks. +func smallTaskConcurrency(tasks []*copTask, numcpu int) (int, int) { + res := 0 + for _, task := range tasks { + if isSmallTask(task) { + res++ + } + } + if res == 0 { + return 0, 0 + } + // Calculate the extra concurrency for small tasks + // extra concurrency = tasks / (1 + sigma * sqrt(log(tasks ^ 2))) + extraConc := int(float64(res) / (1 + smallTaskSigma*math.Sqrt(2*math.Log(float64(res))))) + if numcpu <= 0 { + numcpu = 1 + } + smallTaskConcurrencyLimit := smallConcPerCore * numcpu + if extraConc > smallTaskConcurrencyLimit { + extraConc = smallTaskConcurrencyLimit + } + return res, extraConc +} + +// CopInfo is used to expose functions of copIterator. +type CopInfo interface { + // GetConcurrency returns the concurrency and small task concurrency. + GetConcurrency() (int, int) + // GetStoreBatchInfo returns the batched and fallback num. + GetStoreBatchInfo() (uint64, uint64) + // GetBuildTaskElapsed returns the duration of building task. + GetBuildTaskElapsed() time.Duration +} + +type copIterator struct { + store *Store + req *kv.Request + concurrency int + smallTaskConcurrency int + finishCh chan struct{} + + // If keepOrder, results are stored in copTask.respChan, read them out one by one. + tasks []*copTask + // curr indicates the curr id of the finished copTask + curr int + + // sendRate controls the sending rate of copIteratorTaskSender + sendRate *util.RateLimit + + // Otherwise, results are stored in respChan. 
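The extra concurrency granted to small tasks follows the damped formula in the comment, extra = n / (1 + sigma * sqrt(2 * ln n)), clamped to smallConcPerCore workers per core. Reproducing just the arithmetic in isolation:

package main

import (
	"fmt"
	"math"
)

const (
	smallTaskSigma   = 0.5
	smallConcPerCore = 20
)

// extraConcurrency mirrors the damping formula used for small cop tasks.
func extraConcurrency(smallTasks, numCPU int) int {
	if smallTasks == 0 {
		return 0
	}
	extra := int(float64(smallTasks) / (1 + smallTaskSigma*math.Sqrt(2*math.Log(float64(smallTasks)))))
	if numCPU <= 0 {
		numCPU = 1
	}
	if limit := smallConcPerCore * numCPU; extra > limit {
		extra = limit
	}
	return extra
}

func main() {
	fmt.Println(extraConcurrency(100, 8)) // grows sublinearly with the task count
}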
+ respChan chan *copResponse + + vars *tikv.Variables + + memTracker *memory.Tracker + + replicaReadSeed uint32 + + rpcCancel *tikv.RPCCanceller + + wg sync.WaitGroup + // closed represents when the Close is called. + // There are two cases we need to close the `finishCh` channel, one is when context is done, the other one is + // when the Close is called. we use atomic.CompareAndSwap `closed` to make sure the channel is not closed twice. + closed uint32 + + resolvedLocks util.TSSet + committedLocks util.TSSet + + actionOnExceed *rateLimitAction + pagingTaskIdx uint32 + + buildTaskElapsed time.Duration + storeBatchedNum atomic.Uint64 + storeBatchedFallbackNum atomic.Uint64 + + runawayChecker *resourcegroup.RunawayChecker + unconsumedStats *unconsumedCopRuntimeStats +} + +// copIteratorWorker receives tasks from copIteratorTaskSender, handles tasks and sends the copResponse to respChan. +type copIteratorWorker struct { + taskCh <-chan *copTask + wg *sync.WaitGroup + store *Store + req *kv.Request + respChan chan<- *copResponse + finishCh <-chan struct{} + vars *tikv.Variables + kvclient *txnsnapshot.ClientHelper + + memTracker *memory.Tracker + + replicaReadSeed uint32 + + enableCollectExecutionInfo bool + pagingTaskIdx *uint32 + + storeBatchedNum *atomic.Uint64 + storeBatchedFallbackNum *atomic.Uint64 + unconsumedStats *unconsumedCopRuntimeStats +} + +// copIteratorTaskSender sends tasks to taskCh then wait for the workers to exit. +type copIteratorTaskSender struct { + taskCh chan<- *copTask + smallTaskCh chan<- *copTask + wg *sync.WaitGroup + tasks []*copTask + finishCh <-chan struct{} + respChan chan<- *copResponse + sendRate *util.RateLimit +} + +type copResponse struct { + pbResp *coprocessor.Response + detail *CopRuntimeStats + startKey kv.Key + err error + respSize int64 + respTime time.Duration +} + +const sizeofExecDetails = int(unsafe.Sizeof(execdetails.ExecDetails{})) + +// GetData implements the kv.ResultSubset GetData interface. +func (rs *copResponse) GetData() []byte { + return rs.pbResp.Data +} + +// GetStartKey implements the kv.ResultSubset GetStartKey interface. +func (rs *copResponse) GetStartKey() kv.Key { + return rs.startKey +} + +func (rs *copResponse) GetCopRuntimeStats() *CopRuntimeStats { + return rs.detail +} + +// MemSize returns how many bytes of memory this response use +func (rs *copResponse) MemSize() int64 { + if rs.respSize != 0 { + return rs.respSize + } + if rs == finCopResp { + return 0 + } + + // ignore rs.err + rs.respSize += int64(cap(rs.startKey)) + if rs.detail != nil { + rs.respSize += int64(sizeofExecDetails) + } + if rs.pbResp != nil { + // Using a approximate size since it's hard to get a accurate value. + rs.respSize += int64(rs.pbResp.Size()) + } + return rs.respSize +} + +func (rs *copResponse) RespTime() time.Duration { + return rs.respTime +} + +const minLogCopTaskTime = 300 * time.Millisecond + +// When the worker finished `handleTask`, we need to notify the copIterator that there is one task finished. +// For the non-keep-order case, we send a finCopResp into the respCh after `handleTask`. When copIterator recv +// finCopResp from the respCh, it will be aware that there is one task finished. +var finCopResp *copResponse + +func init() { + finCopResp = &copResponse{} +} + +// run is a worker function that get a copTask from channel, handle it and +// send the result back. 
+func (worker *copIteratorWorker) run(ctx context.Context) { + defer func() { + failpoint.Inject("ticase-4169", func(val failpoint.Value) { + if val.(bool) { + worker.memTracker.Consume(10 * MockResponseSizeForTest) + worker.memTracker.Consume(10 * MockResponseSizeForTest) + } + }) + worker.wg.Done() + }() + // 16KB ballast helps grow the stack to the requirement of copIteratorWorker. + // This reduces the `morestack` call during the execution of `handleTask`, thus improvement the efficiency of TiDB. + // TODO: remove ballast after global pool is applied. + ballast := make([]byte, 16*size.KB) + for task := range worker.taskCh { + respCh := worker.respChan + if respCh == nil { + respCh = task.respChan + } + worker.handleTask(ctx, task, respCh) + if worker.respChan != nil { + // When a task is finished by the worker, send a finCopResp into channel to notify the copIterator that + // there is a task finished. + worker.sendToRespCh(finCopResp, worker.respChan, false) + } + if task.respChan != nil { + close(task.respChan) + } + if worker.finished() { + return + } + } + runtime.KeepAlive(ballast) +} + +// open starts workers and sender goroutines. +func (it *copIterator) open(ctx context.Context, enabledRateLimitAction, enableCollectExecutionInfo bool) { + taskCh := make(chan *copTask, 1) + smallTaskCh := make(chan *copTask, 1) + it.unconsumedStats = &unconsumedCopRuntimeStats{} + it.wg.Add(it.concurrency + it.smallTaskConcurrency) + // Start it.concurrency number of workers to handle cop requests. + for i := 0; i < it.concurrency+it.smallTaskConcurrency; i++ { + var ch chan *copTask + if i < it.concurrency { + ch = taskCh + } else { + ch = smallTaskCh + } + worker := &copIteratorWorker{ + taskCh: ch, + wg: &it.wg, + store: it.store, + req: it.req, + respChan: it.respChan, + finishCh: it.finishCh, + vars: it.vars, + kvclient: txnsnapshot.NewClientHelper(it.store.store, &it.resolvedLocks, &it.committedLocks, false), + memTracker: it.memTracker, + replicaReadSeed: it.replicaReadSeed, + enableCollectExecutionInfo: enableCollectExecutionInfo, + pagingTaskIdx: &it.pagingTaskIdx, + storeBatchedNum: &it.storeBatchedNum, + storeBatchedFallbackNum: &it.storeBatchedFallbackNum, + unconsumedStats: it.unconsumedStats, + } + go worker.run(ctx) + } + taskSender := &copIteratorTaskSender{ + taskCh: taskCh, + smallTaskCh: smallTaskCh, + wg: &it.wg, + tasks: it.tasks, + finishCh: it.finishCh, + sendRate: it.sendRate, + } + taskSender.respChan = it.respChan + it.actionOnExceed.setEnabled(enabledRateLimitAction) + failpoint.Inject("ticase-4171", func(val failpoint.Value) { + if val.(bool) { + it.memTracker.Consume(10 * MockResponseSizeForTest) + it.memTracker.Consume(10 * MockResponseSizeForTest) + } + }) + go taskSender.run(it.req.ConnID) +} + +func (sender *copIteratorTaskSender) run(connID uint64) { + // Send tasks to feed the worker goroutines. + for _, t := range sender.tasks { + // we control the sending rate to prevent all tasks + // being done (aka. all of the responses are buffered) by copIteratorWorker. + // We keep the number of inflight tasks within the number of 2 * concurrency when Keep Order is true. + // If KeepOrder is false, the number equals the concurrency. + // It sends one more task if a task has been finished in copIterator.Next. 
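// Editor's aside (not part of the patch): the sendRate token exchange described above
// keeps the number of in-flight copTasks bounded (roughly 2 * concurrency when
// KeepOrder is set). A minimal sketch of the same idea using a buffered channel as the
// token pool; rateLimit here is a stand-in for the util.RateLimit the patch relies on,
// not its actual implementation:

package sketch

type rateLimit struct{ tokens chan struct{} }

func newRateLimit(n int) *rateLimit {
	return &rateLimit{tokens: make(chan struct{}, n)}
}

// getToken blocks until a send slot is free or the finish channel closes; a true
// return means the caller should stop sending.
func (r *rateLimit) getToken(finish <-chan struct{}) (exit bool) {
	select {
	case r.tokens <- struct{}{}:
		return false
	case <-finish:
		return true
	}
}

// putToken releases a slot once the consumer has drained a finished task.
func (r *rateLimit) putToken() { <-r.tokens }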
+ exit := sender.sendRate.GetToken(sender.finishCh) + if exit { + break + } + var sendTo chan<- *copTask + if isSmallTask(t) { + sendTo = sender.smallTaskCh + } else { + sendTo = sender.taskCh + } + exit = sender.sendToTaskCh(t, sendTo) + if exit { + break + } + if connID > 0 { + failpoint.Inject("pauseCopIterTaskSender", func() {}) + } + } + close(sender.taskCh) + close(sender.smallTaskCh) + + // Wait for worker goroutines to exit. + sender.wg.Wait() + if sender.respChan != nil { + close(sender.respChan) + } +} + +func (it *copIterator) recvFromRespCh(ctx context.Context, respCh <-chan *copResponse) (resp *copResponse, ok bool, exit bool) { + failpoint.InjectCall("CtxCancelBeforeReceive", ctx) + ticker := time.NewTicker(3 * time.Second) + defer ticker.Stop() + for { + select { + case resp, ok = <-respCh: + if it.memTracker != nil && resp != nil { + consumed := resp.MemSize() + failpoint.Inject("testRateLimitActionMockConsumeAndAssert", func(val failpoint.Value) { + if val.(bool) { + if resp != finCopResp { + consumed = MockResponseSizeForTest + } + } + }) + it.memTracker.Consume(-consumed) + } + return + case <-it.finishCh: + exit = true + return + case <-ticker.C: + killed := atomic.LoadUint32(it.vars.Killed) + if killed != 0 { + logutil.Logger(ctx).Info( + "a killed signal is received", + zap.Uint32("signal", killed), + ) + resp = &copResponse{err: derr.ErrQueryInterrupted} + ok = true + return + } + case <-ctx.Done(): + // We select the ctx.Done() in the thread of `Next` instead of in the worker to avoid the cost of `WithCancel`. + if atomic.CompareAndSwapUint32(&it.closed, 0, 1) { + close(it.finishCh) + } + exit = true + return + } + } +} + +// GetConcurrency returns the concurrency and small task concurrency. +func (it *copIterator) GetConcurrency() (int, int) { + return it.concurrency, it.smallTaskConcurrency +} + +// GetStoreBatchInfo returns the batched and fallback num. +func (it *copIterator) GetStoreBatchInfo() (uint64, uint64) { + return it.storeBatchedNum.Load(), it.storeBatchedFallbackNum.Load() +} + +// GetBuildTaskElapsed returns the duration of building task. +func (it *copIterator) GetBuildTaskElapsed() time.Duration { + return it.buildTaskElapsed +} + +// GetSendRate returns the rate-limit object. +func (it *copIterator) GetSendRate() *util.RateLimit { + return it.sendRate +} + +// GetTasks returns the built tasks. +func (it *copIterator) GetTasks() []*copTask { + return it.tasks +} + +func (sender *copIteratorTaskSender) sendToTaskCh(t *copTask, sendTo chan<- *copTask) (exit bool) { + select { + case sendTo <- t: + case <-sender.finishCh: + exit = true + } + return +} + +func (worker *copIteratorWorker) sendToRespCh(resp *copResponse, respCh chan<- *copResponse, checkOOM bool) (exit bool) { + if worker.memTracker != nil && checkOOM { + consumed := resp.MemSize() + failpoint.Inject("testRateLimitActionMockConsumeAndAssert", func(val failpoint.Value) { + if val.(bool) { + if resp != finCopResp { + consumed = MockResponseSizeForTest + } + } + }) + failpoint.Inject("ConsumeRandomPanic", nil) + worker.memTracker.Consume(consumed) + } + select { + case respCh <- resp: + case <-worker.finishCh: + exit = true + } + return +} + +// MockResponseSizeForTest mock the response size +const MockResponseSizeForTest = 100 * 1024 * 1024 + +// Next returns next coprocessor result. +// NOTE: Use nil to indicate finish, so if the returned ResultSubset is not nil, reader should continue to call Next(). 
+func (it *copIterator) Next(ctx context.Context) (kv.ResultSubset, error) { + var ( + resp *copResponse + ok bool + closed bool + ) + defer func() { + if resp == nil { + failpoint.Inject("ticase-4170", func(val failpoint.Value) { + if val.(bool) { + it.memTracker.Consume(10 * MockResponseSizeForTest) + it.memTracker.Consume(10 * MockResponseSizeForTest) + } + }) + } + }() + // wait unit at least 5 copResponse received. + failpoint.Inject("testRateLimitActionMockWaitMax", func(val failpoint.Value) { + if val.(bool) { + // we only need to trigger oom at least once. + if len(it.tasks) > 9 { + for it.memTracker.MaxConsumed() < 5*MockResponseSizeForTest { + time.Sleep(10 * time.Millisecond) + } + } + } + }) + // If data order matters, response should be returned in the same order as copTask slice. + // Otherwise all responses are returned from a single channel. + if it.respChan != nil { + // Get next fetched resp from chan + resp, ok, closed = it.recvFromRespCh(ctx, it.respChan) + if !ok || closed { + it.actionOnExceed.close() + return nil, errors.Trace(ctx.Err()) + } + if resp == finCopResp { + it.actionOnExceed.destroyTokenIfNeeded(func() { + it.sendRate.PutToken() + }) + return it.Next(ctx) + } + } else { + for { + if it.curr >= len(it.tasks) { + // Resp will be nil if iterator is finishCh. + it.actionOnExceed.close() + return nil, nil + } + task := it.tasks[it.curr] + resp, ok, closed = it.recvFromRespCh(ctx, task.respChan) + if closed { + // Close() is called or context cancelled/timeout, so Next() is invalid. + return nil, errors.Trace(ctx.Err()) + } + if ok { + break + } + it.actionOnExceed.destroyTokenIfNeeded(func() { + it.sendRate.PutToken() + }) + // Switch to next task. + it.tasks[it.curr] = nil + it.curr++ + } + } + + if resp.err != nil { + return nil, errors.Trace(resp.err) + } + + err := it.store.CheckVisibility(it.req.StartTs) + if err != nil { + return nil, errors.Trace(err) + } + return resp, nil +} + +// HasUnconsumedCopRuntimeStats indicate whether has unconsumed CopRuntimeStats. +type HasUnconsumedCopRuntimeStats interface { + // CollectUnconsumedCopRuntimeStats returns unconsumed CopRuntimeStats. + CollectUnconsumedCopRuntimeStats() []*CopRuntimeStats +} + +func (it *copIterator) CollectUnconsumedCopRuntimeStats() []*CopRuntimeStats { + if it == nil || it.unconsumedStats == nil { + return nil + } + it.unconsumedStats.Lock() + stats := make([]*CopRuntimeStats, 0, len(it.unconsumedStats.stats)) + stats = append(stats, it.unconsumedStats.stats...) + it.unconsumedStats.Unlock() + return stats +} + +// Associate each region with an independent backoffer. In this way, when multiple regions are +// unavailable, TiDB can execute very quickly without blocking +func chooseBackoffer(ctx context.Context, backoffermap map[uint64]*Backoffer, task *copTask, worker *copIteratorWorker) *Backoffer { + bo, ok := backoffermap[task.region.GetID()] + if ok { + return bo + } + boMaxSleep := CopNextMaxBackoff + failpoint.Inject("ReduceCopNextMaxBackoff", func(value failpoint.Value) { + if value.(bool) { + boMaxSleep = 2 + } + }) + newbo := backoff.NewBackofferWithVars(ctx, boMaxSleep, worker.vars) + backoffermap[task.region.GetID()] = newbo + return newbo +} + +// handleTask handles single copTask, sends the result to channel, retry automatically on error. 
+func (worker *copIteratorWorker) handleTask(ctx context.Context, task *copTask, respCh chan<- *copResponse) { + defer func() { + r := recover() + if r != nil { + logutil.BgLogger().Error("copIteratorWork meet panic", + zap.Any("r", r), + zap.Stack("stack trace")) + resp := &copResponse{err: util2.GetRecoverError(r)} + // if panic has happened, set checkOOM to false to avoid another panic. + worker.sendToRespCh(resp, respCh, false) + } + }() + remainTasks := []*copTask{task} + backoffermap := make(map[uint64]*Backoffer) + for len(remainTasks) > 0 { + curTask := remainTasks[0] + bo := chooseBackoffer(ctx, backoffermap, curTask, worker) + tasks, err := worker.handleTaskOnce(bo, curTask, respCh) + if err != nil { + resp := &copResponse{err: errors.Trace(err)} + worker.sendToRespCh(resp, respCh, true) + return + } + if worker.finished() { + break + } + if len(tasks) > 0 { + remainTasks = append(tasks, remainTasks[1:]...) + } else { + remainTasks = remainTasks[1:] + } + } +} + +// handleTaskOnce handles single copTask, successful results are send to channel. +// If error happened, returns error. If region split or meet lock, returns the remain tasks. +func (worker *copIteratorWorker) handleTaskOnce(bo *Backoffer, task *copTask, ch chan<- *copResponse) ([]*copTask, error) { + failpoint.Inject("handleTaskOnceError", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(nil, errors.New("mock handleTaskOnce error")) + } + }) + + if task.paging { + task.pagingTaskIdx = atomic.AddUint32(worker.pagingTaskIdx, 1) + } + + copReq := coprocessor.Request{ + Tp: worker.req.Tp, + StartTs: worker.req.StartTs, + Data: worker.req.Data, + Ranges: task.ranges.ToPBRanges(), + SchemaVer: worker.req.SchemaVar, + PagingSize: task.pagingSize, + Tasks: task.ToPBBatchTasks(), + ConnectionId: worker.req.ConnID, + ConnectionAlias: worker.req.ConnAlias, + } + + cacheKey, cacheValue := worker.buildCacheKey(task, &copReq) + + replicaRead := worker.req.ReplicaRead + rgName := worker.req.ResourceGroupName + if task.storeType == kv.TiFlash && !variable.EnableResourceControl.Load() { + // By calling variable.EnableGlobalResourceControlFunc() and setting global variables, + // tikv/client-go can sense whether the rg function is enabled + // But for tiflash, it check if rgName is empty to decide if resource control is enabled or not. 
+ rgName = "" + } + req := tikvrpc.NewReplicaReadRequest(task.cmdType, &copReq, options.GetTiKVReplicaReadType(replicaRead), &worker.replicaReadSeed, kvrpcpb.Context{ + IsolationLevel: isolationLevelToPB(worker.req.IsolationLevel), + Priority: priorityToPB(worker.req.Priority), + NotFillCache: worker.req.NotFillCache, + RecordTimeStat: true, + RecordScanStat: true, + TaskId: worker.req.TaskID, + ResourceControlContext: &kvrpcpb.ResourceControlContext{ + ResourceGroupName: rgName, + }, + BusyThresholdMs: uint32(task.busyThreshold.Milliseconds()), + BucketsVersion: task.bucketsVer, + }) + req.InputRequestSource = task.requestSource.GetRequestSource() + if task.firstReadType != "" { + req.ReadType = task.firstReadType + req.IsRetryRequest = true + } + if worker.req.ResourceGroupTagger != nil { + worker.req.ResourceGroupTagger(req) + } + timeout := config.GetGlobalConfig().TiKVClient.CoprReqTimeout + if task.tikvClientReadTimeout > 0 { + timeout = time.Duration(task.tikvClientReadTimeout) * time.Millisecond + } + failpoint.Inject("sleepCoprRequest", func(v failpoint.Value) { + //nolint:durationcheck + time.Sleep(time.Millisecond * time.Duration(v.(int))) + }) + + if worker.req.RunawayChecker != nil { + if err := worker.req.RunawayChecker.BeforeCopRequest(req); err != nil { + return nil, err + } + } + req.StoreTp = getEndPointType(task.storeType) + startTime := time.Now() + if worker.kvclient.Stats == nil { + worker.kvclient.Stats = tikv.NewRegionRequestRuntimeStats() + } + // set ReadReplicaScope and TxnScope so that req.IsStaleRead will be true when it's a global scope stale read. + req.ReadReplicaScope = worker.req.ReadReplicaScope + req.TxnScope = worker.req.TxnScope + if task.meetLockFallback { + req.DisableStaleReadMeetLock() + } else if worker.req.IsStaleness { + req.EnableStaleWithMixedReplicaRead() + } + staleRead := req.GetStaleRead() + ops := make([]tikv.StoreSelectorOption, 0, 2) + if len(worker.req.MatchStoreLabels) > 0 { + ops = append(ops, tikv.WithMatchLabels(worker.req.MatchStoreLabels)) + } + if task.redirect2Replica != nil { + req.ReplicaRead = true + req.ReplicaReadType = options.GetTiKVReplicaReadType(kv.ReplicaReadFollower) + ops = append(ops, tikv.WithMatchStores([]uint64{*task.redirect2Replica})) + } + resp, rpcCtx, storeAddr, err := worker.kvclient.SendReqCtx(bo.TiKVBackoffer(), req, task.region, + timeout, getEndPointType(task.storeType), task.storeAddr, ops...) + err = derr.ToTiDBErr(err) + if worker.req.RunawayChecker != nil { + failpoint.Inject("sleepCoprAfterReq", func(v failpoint.Value) { + //nolint:durationcheck + value := v.(int) + time.Sleep(time.Millisecond * time.Duration(value)) + if value > 50 { + err = errors.Errorf("Coprocessor task terminated due to exceeding the deadline") + } + }) + err = worker.req.RunawayChecker.CheckCopRespError(err) + } + if err != nil { + if task.storeType == kv.TiDB { + err = worker.handleTiDBSendReqErr(err, task, ch) + return nil, err + } + worker.collectUnconsumedCopRuntimeStats(bo, rpcCtx) + return nil, errors.Trace(err) + } + + // Set task.storeAddr field so its task.String() method have the store address information. 
+ task.storeAddr = storeAddr + + costTime := time.Since(startTime) + copResp := resp.Resp.(*coprocessor.Response) + + if costTime > minLogCopTaskTime { + worker.logTimeCopTask(costTime, task, bo, copResp) + } + + storeID := strconv.FormatUint(req.Context.GetPeer().GetStoreId(), 10) + isInternal := util.IsRequestSourceInternal(&task.requestSource) + scope := metrics.LblGeneral + if isInternal { + scope = metrics.LblInternal + } + metrics.TiKVCoprocessorHistogram.WithLabelValues(storeID, strconv.FormatBool(staleRead), scope).Observe(costTime.Seconds()) + if copResp != nil { + tidbmetrics.DistSQLCoprRespBodySize.WithLabelValues(storeAddr).Observe(float64(len(copResp.Data))) + } + + var remains []*copTask + if worker.req.Paging.Enable { + remains, err = worker.handleCopPagingResult(bo, rpcCtx, &copResponse{pbResp: copResp}, cacheKey, cacheValue, task, ch, costTime) + } else { + // Handles the response for non-paging copTask. + remains, err = worker.handleCopResponse(bo, rpcCtx, &copResponse{pbResp: copResp}, cacheKey, cacheValue, task, ch, nil, costTime) + } + if req.ReadType != "" { + for _, remain := range remains { + remain.firstReadType = req.ReadType + } + } + return remains, err +} + +const ( + minLogBackoffTime = 100 + minLogKVProcessTime = 100 +) + +func (worker *copIteratorWorker) logTimeCopTask(costTime time.Duration, task *copTask, bo *Backoffer, resp *coprocessor.Response) { + logStr := fmt.Sprintf("[TIME_COP_PROCESS] resp_time:%s txnStartTS:%d region_id:%d store_addr:%s", costTime, worker.req.StartTs, task.region.GetID(), task.storeAddr) + if worker.kvclient.Stats != nil { + logStr += fmt.Sprintf(" stats:%s", worker.kvclient.Stats.String()) + } + if bo.GetTotalSleep() > minLogBackoffTime { + backoffTypes := strings.ReplaceAll(fmt.Sprintf("%v", bo.TiKVBackoffer().GetTypes()), " ", ",") + logStr += fmt.Sprintf(" backoff_ms:%d backoff_types:%s", bo.GetTotalSleep(), backoffTypes) + } + if regionErr := resp.GetRegionError(); regionErr != nil { + logStr += fmt.Sprintf(" region_err:%s", regionErr.String()) + } + // resp might be nil, but it is safe to call resp.GetXXX here. 
+ detailV2 := resp.GetExecDetailsV2() + detail := resp.GetExecDetails() + var timeDetail *kvrpcpb.TimeDetail + if detailV2 != nil && detailV2.TimeDetail != nil { + timeDetail = detailV2.TimeDetail + } else if detail != nil && detail.TimeDetail != nil { + timeDetail = detail.TimeDetail + } + if timeDetail != nil { + logStr += fmt.Sprintf(" kv_process_ms:%d", timeDetail.ProcessWallTimeMs) + logStr += fmt.Sprintf(" kv_wait_ms:%d", timeDetail.WaitWallTimeMs) + logStr += fmt.Sprintf(" kv_read_ms:%d", timeDetail.KvReadWallTimeMs) + if timeDetail.ProcessWallTimeMs <= minLogKVProcessTime { + logStr = strings.Replace(logStr, "TIME_COP_PROCESS", "TIME_COP_WAIT", 1) + } + } + + if detailV2 != nil && detailV2.ScanDetailV2 != nil { + logStr += fmt.Sprintf(" processed_versions:%d", detailV2.ScanDetailV2.ProcessedVersions) + logStr += fmt.Sprintf(" total_versions:%d", detailV2.ScanDetailV2.TotalVersions) + logStr += fmt.Sprintf(" rocksdb_delete_skipped_count:%d", detailV2.ScanDetailV2.RocksdbDeleteSkippedCount) + logStr += fmt.Sprintf(" rocksdb_key_skipped_count:%d", detailV2.ScanDetailV2.RocksdbKeySkippedCount) + logStr += fmt.Sprintf(" rocksdb_cache_hit_count:%d", detailV2.ScanDetailV2.RocksdbBlockCacheHitCount) + logStr += fmt.Sprintf(" rocksdb_read_count:%d", detailV2.ScanDetailV2.RocksdbBlockReadCount) + logStr += fmt.Sprintf(" rocksdb_read_byte:%d", detailV2.ScanDetailV2.RocksdbBlockReadByte) + } else if detail != nil && detail.ScanDetail != nil { + logStr = appendScanDetail(logStr, "write", detail.ScanDetail.Write) + logStr = appendScanDetail(logStr, "data", detail.ScanDetail.Data) + logStr = appendScanDetail(logStr, "lock", detail.ScanDetail.Lock) + } + logutil.Logger(bo.GetCtx()).Info(logStr) +} + +func appendScanDetail(logStr string, columnFamily string, scanInfo *kvrpcpb.ScanInfo) string { + if scanInfo != nil { + logStr += fmt.Sprintf(" scan_total_%s:%d", columnFamily, scanInfo.Total) + logStr += fmt.Sprintf(" scan_processed_%s:%d", columnFamily, scanInfo.Processed) + } + return logStr +} + +func (worker *copIteratorWorker) handleCopPagingResult(bo *Backoffer, rpcCtx *tikv.RPCContext, resp *copResponse, cacheKey []byte, cacheValue *coprCacheValue, task *copTask, ch chan<- *copResponse, costTime time.Duration) ([]*copTask, error) { + remainedTasks, err := worker.handleCopResponse(bo, rpcCtx, resp, cacheKey, cacheValue, task, ch, nil, costTime) + if err != nil || len(remainedTasks) != 0 { + // If there is region error or lock error, keep the paging size and retry. + for _, remainedTask := range remainedTasks { + remainedTask.pagingSize = task.pagingSize + } + return remainedTasks, errors.Trace(err) + } + pagingRange := resp.pbResp.Range + // only paging requests need to calculate the next ranges + if pagingRange == nil { + // If the storage engine doesn't support paging protocol, it should have return all the region data. + // So we finish here. + return nil, nil + } + + // calculate next ranges and grow the paging size + task.ranges = worker.calculateRemain(task.ranges, pagingRange, worker.req.Desc) + if task.ranges.Len() == 0 { + return nil, nil + } + + task.pagingSize = paging.GrowPagingSize(task.pagingSize, worker.req.Paging.MaxPagingSize) + return []*copTask{task}, nil +} + +// handleCopResponse checks coprocessor Response for region split and lock, +// returns more tasks when that happens, or handles the response if no error. +// if we're handling coprocessor paging response, lastRange is the range of last +// successful response, otherwise it's nil. 
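// Editor's aside (not part of the patch): paging requests above start with a small
// page and grow after each successful page until a maximum is reached. A
// self-contained sketch of that growth policy; the doubling factor and cap handling
// are assumed for illustration and may differ from paging.GrowPagingSize:

package sketch

func growPagingSize(current, maxSize uint64) uint64 {
	next := current * 2
	if next > maxSize {
		return maxSize
	}
	return next
}

// Example: with current=64 and maxSize=8192 the sequence is
// 64 -> 128 -> 256 -> ... -> 8192, after which it stays at the cap.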
+func (worker *copIteratorWorker) handleCopResponse(bo *Backoffer, rpcCtx *tikv.RPCContext, resp *copResponse, cacheKey []byte, cacheValue *coprCacheValue, task *copTask, ch chan<- *copResponse, lastRange *coprocessor.KeyRange, costTime time.Duration) ([]*copTask, error) { + if ver := resp.pbResp.GetLatestBucketsVersion(); task.bucketsVer < ver { + worker.store.GetRegionCache().UpdateBucketsIfNeeded(task.region, ver) + } + if regionErr := resp.pbResp.GetRegionError(); regionErr != nil { + if rpcCtx != nil && task.storeType == kv.TiDB { + resp.err = errors.Errorf("error: %v", regionErr) + worker.sendToRespCh(resp, ch, true) + return nil, nil + } + errStr := fmt.Sprintf("region_id:%v, region_ver:%v, store_type:%s, peer_addr:%s, error:%s", + task.region.GetID(), task.region.GetVer(), task.storeType.Name(), task.storeAddr, regionErr.String()) + if err := bo.Backoff(tikv.BoRegionMiss(), errors.New(errStr)); err != nil { + return nil, errors.Trace(err) + } + // We may meet RegionError at the first packet, but not during visiting the stream. + remains, err := buildCopTasks(bo, task.ranges, &buildCopTaskOpt{ + req: worker.req, + cache: worker.store.GetRegionCache(), + respChan: false, + eventCb: task.eventCb, + ignoreTiKVClientReadTimeout: true, + }) + if err != nil { + return remains, err + } + return worker.handleBatchRemainsOnErr(bo, rpcCtx, remains, resp.pbResp, task, ch) + } + if lockErr := resp.pbResp.GetLocked(); lockErr != nil { + if err := worker.handleLockErr(bo, lockErr, task); err != nil { + return nil, err + } + task.meetLockFallback = true + return worker.handleBatchRemainsOnErr(bo, rpcCtx, []*copTask{task}, resp.pbResp, task, ch) + } + if otherErr := resp.pbResp.GetOtherError(); otherErr != "" { + err := errors.Errorf("other error: %s", otherErr) + + firstRangeStartKey := task.ranges.At(0).StartKey + lastRangeEndKey := task.ranges.At(task.ranges.Len() - 1).EndKey + + logutil.Logger(bo.GetCtx()).Warn("other error", + zap.Uint64("txnStartTS", worker.req.StartTs), + zap.Uint64("regionID", task.region.GetID()), + zap.Uint64("regionVer", task.region.GetVer()), + zap.Uint64("regionConfVer", task.region.GetConfVer()), + zap.Uint64("bucketsVer", task.bucketsVer), + zap.Uint64("latestBucketsVer", resp.pbResp.GetLatestBucketsVersion()), + zap.Int("rangeNums", task.ranges.Len()), + zap.ByteString("firstRangeStartKey", firstRangeStartKey), + zap.ByteString("lastRangeEndKey", lastRangeEndKey), + zap.String("storeAddr", task.storeAddr), + zap.Error(err)) + if strings.Contains(err.Error(), "write conflict") { + return nil, kv.ErrWriteConflict.FastGen("%s", otherErr) + } + return nil, errors.Trace(err) + } + // When the request is using paging API, the `Range` is not nil. 
+ if resp.pbResp.Range != nil { + resp.startKey = resp.pbResp.Range.Start + } else if task.ranges != nil && task.ranges.Len() > 0 { + resp.startKey = task.ranges.At(0).StartKey + } + worker.handleCollectExecutionInfo(bo, rpcCtx, resp) + resp.respTime = costTime + + if err := worker.handleCopCache(task, resp, cacheKey, cacheValue); err != nil { + return nil, err + } + + pbResp := resp.pbResp + worker.sendToRespCh(resp, ch, true) + return worker.handleBatchCopResponse(bo, rpcCtx, pbResp, task.batchTaskList, ch) +} + +func (worker *copIteratorWorker) handleBatchRemainsOnErr(bo *Backoffer, rpcCtx *tikv.RPCContext, remains []*copTask, resp *coprocessor.Response, task *copTask, ch chan<- *copResponse) ([]*copTask, error) { + if len(task.batchTaskList) == 0 { + return remains, nil + } + batchedTasks := task.batchTaskList + task.batchTaskList = nil + batchedRemains, err := worker.handleBatchCopResponse(bo, rpcCtx, resp, batchedTasks, ch) + if err != nil { + return nil, err + } + return append(remains, batchedRemains...), nil +} + +// handle the batched cop response. +// tasks will be changed, so the input tasks should not be used after calling this function. +func (worker *copIteratorWorker) handleBatchCopResponse(bo *Backoffer, rpcCtx *tikv.RPCContext, resp *coprocessor.Response, + tasks map[uint64]*batchedCopTask, ch chan<- *copResponse) (remainTasks []*copTask, err error) { + if len(tasks) == 0 { + return nil, nil + } + batchedNum := len(tasks) + busyThresholdFallback := false + defer func() { + if err != nil { + return + } + if !busyThresholdFallback { + worker.storeBatchedNum.Add(uint64(batchedNum - len(remainTasks))) + worker.storeBatchedFallbackNum.Add(uint64(len(remainTasks))) + } + }() + appendRemainTasks := func(tasks ...*copTask) { + if remainTasks == nil { + // allocate size of remain length + remainTasks = make([]*copTask, 0, len(tasks)) + } + remainTasks = append(remainTasks, tasks...) + } + // need Addr for recording details. + var dummyRPCCtx *tikv.RPCContext + if rpcCtx != nil { + dummyRPCCtx = &tikv.RPCContext{ + Addr: rpcCtx.Addr, + } + } + batchResps := resp.GetBatchResponses() + for _, batchResp := range batchResps { + taskID := batchResp.GetTaskId() + batchedTask, ok := tasks[taskID] + if !ok { + return nil, errors.Errorf("task id %d not found", batchResp.GetTaskId()) + } + delete(tasks, taskID) + resp := &copResponse{ + pbResp: &coprocessor.Response{ + Data: batchResp.Data, + ExecDetailsV2: batchResp.ExecDetailsV2, + }, + } + task := batchedTask.task + failpoint.Inject("batchCopRegionError", func() { + batchResp.RegionError = &errorpb.Error{} + }) + if regionErr := batchResp.GetRegionError(); regionErr != nil { + errStr := fmt.Sprintf("region_id:%v, region_ver:%v, store_type:%s, peer_addr:%s, error:%s", + task.region.GetID(), task.region.GetVer(), task.storeType.Name(), task.storeAddr, regionErr.String()) + if err := bo.Backoff(tikv.BoRegionMiss(), errors.New(errStr)); err != nil { + return nil, errors.Trace(err) + } + remains, err := buildCopTasks(bo, task.ranges, &buildCopTaskOpt{ + req: worker.req, + cache: worker.store.GetRegionCache(), + respChan: false, + eventCb: task.eventCb, + ignoreTiKVClientReadTimeout: true, + }) + if err != nil { + return nil, err + } + appendRemainTasks(remains...) 
+ continue + } + //TODO: handle locks in batch + if lockErr := batchResp.GetLocked(); lockErr != nil { + if err := worker.handleLockErr(bo, resp.pbResp.GetLocked(), task); err != nil { + return nil, err + } + task.meetLockFallback = true + appendRemainTasks(task) + continue + } + if otherErr := batchResp.GetOtherError(); otherErr != "" { + err := errors.Errorf("other error: %s", otherErr) + + firstRangeStartKey := task.ranges.At(0).StartKey + lastRangeEndKey := task.ranges.At(task.ranges.Len() - 1).EndKey + + logutil.Logger(bo.GetCtx()).Warn("other error", + zap.Uint64("txnStartTS", worker.req.StartTs), + zap.Uint64("regionID", task.region.GetID()), + zap.Uint64("regionVer", task.region.GetVer()), + zap.Uint64("regionConfVer", task.region.GetConfVer()), + zap.Uint64("bucketsVer", task.bucketsVer), + // TODO: add bucket version in log + //zap.Uint64("latestBucketsVer", batchResp.GetLatestBucketsVersion()), + zap.Int("rangeNums", task.ranges.Len()), + zap.ByteString("firstRangeStartKey", firstRangeStartKey), + zap.ByteString("lastRangeEndKey", lastRangeEndKey), + zap.String("storeAddr", task.storeAddr), + zap.Error(err)) + if strings.Contains(err.Error(), "write conflict") { + return nil, kv.ErrWriteConflict.FastGen("%s", otherErr) + } + return nil, errors.Trace(err) + } + worker.handleCollectExecutionInfo(bo, dummyRPCCtx, resp) + worker.sendToRespCh(resp, ch, true) + } + for _, t := range tasks { + task := t.task + // when the error is generated by client or a load-based server busy, + // response is empty by design, skip warning for this case. + if len(batchResps) != 0 { + firstRangeStartKey := task.ranges.At(0).StartKey + lastRangeEndKey := task.ranges.At(task.ranges.Len() - 1).EndKey + logutil.Logger(bo.GetCtx()).Error("response of batched task missing", + zap.Uint64("id", task.taskID), + zap.Uint64("txnStartTS", worker.req.StartTs), + zap.Uint64("regionID", task.region.GetID()), + zap.Uint64("regionVer", task.region.GetVer()), + zap.Uint64("regionConfVer", task.region.GetConfVer()), + zap.Uint64("bucketsVer", task.bucketsVer), + zap.Int("rangeNums", task.ranges.Len()), + zap.ByteString("firstRangeStartKey", firstRangeStartKey), + zap.ByteString("lastRangeEndKey", lastRangeEndKey), + zap.String("storeAddr", task.storeAddr)) + } + appendRemainTasks(t.task) + } + if regionErr := resp.GetRegionError(); regionErr != nil && regionErr.ServerIsBusy != nil && + regionErr.ServerIsBusy.EstimatedWaitMs > 0 && len(remainTasks) != 0 { + if len(batchResps) != 0 { + return nil, errors.New("store batched coprocessor with server is busy error shouldn't contain responses") + } + busyThresholdFallback = true + handler := newBatchTaskBuilder(bo, worker.req, worker.store.GetRegionCache(), kv.ReplicaReadFollower) + for _, task := range remainTasks { + // do not set busy threshold again. + task.busyThreshold = 0 + if err = handler.handle(task); err != nil { + return nil, err + } + } + remainTasks = handler.build() + } + return remainTasks, nil +} + +func (worker *copIteratorWorker) handleLockErr(bo *Backoffer, lockErr *kvrpcpb.LockInfo, task *copTask) error { + if lockErr == nil { + return nil + } + resolveLockDetail := worker.getLockResolverDetails() + // Be care that we didn't redact the SQL statement because the log is DEBUG level. 
+ if task.eventCb != nil { + task.eventCb(trxevents.WrapCopMeetLock(&trxevents.CopMeetLock{ + LockInfo: lockErr, + })) + } else { + logutil.Logger(bo.GetCtx()).Debug("coprocessor encounters lock", + zap.Stringer("lock", lockErr)) + } + resolveLocksOpts := txnlock.ResolveLocksOptions{ + CallerStartTS: worker.req.StartTs, + Locks: []*txnlock.Lock{txnlock.NewLock(lockErr)}, + Detail: resolveLockDetail, + } + resolveLocksRes, err1 := worker.kvclient.ResolveLocksWithOpts(bo.TiKVBackoffer(), resolveLocksOpts) + err1 = derr.ToTiDBErr(err1) + if err1 != nil { + return errors.Trace(err1) + } + msBeforeExpired := resolveLocksRes.TTL + if msBeforeExpired > 0 { + if err := bo.BackoffWithMaxSleepTxnLockFast(int(msBeforeExpired), errors.New(lockErr.String())); err != nil { + return errors.Trace(err) + } + } + return nil +} + +func (worker *copIteratorWorker) buildCacheKey(task *copTask, copReq *coprocessor.Request) (cacheKey []byte, cacheValue *coprCacheValue) { + // If there are many ranges, it is very likely to be a TableLookupRequest. They are not worth to cache since + // computing is not the main cost. Ignore requests with many ranges directly to avoid slowly building the cache key. + if task.cmdType == tikvrpc.CmdCop && worker.store.coprCache != nil && worker.req.Cacheable && worker.store.coprCache.CheckRequestAdmission(len(copReq.Ranges)) { + cKey, err := coprCacheBuildKey(copReq) + if err == nil { + cacheKey = cKey + cValue := worker.store.coprCache.Get(cKey) + copReq.IsCacheEnabled = true + + if cValue != nil && cValue.RegionID == task.region.GetID() && cValue.TimeStamp <= worker.req.StartTs { + // Append cache version to the request to skip Coprocessor computation if possible + // when request result is cached + copReq.CacheIfMatchVersion = cValue.RegionDataVersion + cacheValue = cValue + } else { + copReq.CacheIfMatchVersion = 0 + } + } else { + logutil.BgLogger().Warn("Failed to build copr cache key", zap.Error(err)) + } + } + return +} + +func (worker *copIteratorWorker) handleCopCache(task *copTask, resp *copResponse, cacheKey []byte, cacheValue *coprCacheValue) error { + if resp.pbResp.IsCacheHit { + if cacheValue == nil { + return errors.New("Internal error: received illegal TiKV response") + } + copr_metrics.CoprCacheCounterHit.Add(1) + // Cache hit and is valid: use cached data as response data and we don't update the cache. + data := make([]byte, len(cacheValue.Data)) + copy(data, cacheValue.Data) + resp.pbResp.Data = data + if worker.req.Paging.Enable { + var start, end []byte + if cacheValue.PageStart != nil { + start = make([]byte, len(cacheValue.PageStart)) + copy(start, cacheValue.PageStart) + } + if cacheValue.PageEnd != nil { + end = make([]byte, len(cacheValue.PageEnd)) + copy(end, cacheValue.PageEnd) + } + // When paging protocol is used, the response key range is part of the cache data. + if start != nil || end != nil { + resp.pbResp.Range = &coprocessor.KeyRange{ + Start: start, + End: end, + } + } else { + resp.pbResp.Range = nil + } + } + // `worker.enableCollectExecutionInfo` is loaded from the instance's config. Because it's not related to the request, + // the cache key can be same when `worker.enableCollectExecutionInfo` is true or false. + // When `worker.enableCollectExecutionInfo` is false, the `resp.detail` is nil, and hit cache is still possible. + // Check `resp.detail` to avoid panic. 
+ // Details: https://github.com/pingcap/tidb/issues/48212 + if resp.detail != nil { + resp.detail.CoprCacheHit = true + } + return nil + } + copr_metrics.CoprCacheCounterMiss.Add(1) + // Cache not hit or cache hit but not valid: update the cache if the response can be cached. + if cacheKey != nil && resp.pbResp.CanBeCached && resp.pbResp.CacheLastVersion > 0 { + if resp.detail != nil { + if worker.store.coprCache.CheckResponseAdmission(resp.pbResp.Data.Size(), resp.detail.TimeDetail.ProcessTime, task.pagingTaskIdx) { + data := make([]byte, len(resp.pbResp.Data)) + copy(data, resp.pbResp.Data) + + newCacheValue := coprCacheValue{ + Data: data, + TimeStamp: worker.req.StartTs, + RegionID: task.region.GetID(), + RegionDataVersion: resp.pbResp.CacheLastVersion, + } + // When paging protocol is used, the response key range is part of the cache data. + if r := resp.pbResp.GetRange(); r != nil { + newCacheValue.PageStart = append([]byte{}, r.GetStart()...) + newCacheValue.PageEnd = append([]byte{}, r.GetEnd()...) + } + worker.store.coprCache.Set(cacheKey, &newCacheValue) + } + } + } + return nil +} + +func (worker *copIteratorWorker) getLockResolverDetails() *util.ResolveLockDetail { + if !worker.enableCollectExecutionInfo { + return nil + } + return &util.ResolveLockDetail{} +} + +func (worker *copIteratorWorker) handleCollectExecutionInfo(bo *Backoffer, rpcCtx *tikv.RPCContext, resp *copResponse) { + defer func() { + worker.kvclient.Stats = nil + }() + if !worker.enableCollectExecutionInfo { + return + } + failpoint.Inject("disable-collect-execution", func(val failpoint.Value) { + if val.(bool) { + panic("shouldn't reachable") + } + }) + if resp.detail == nil { + resp.detail = new(CopRuntimeStats) + } + worker.collectCopRuntimeStats(resp.detail, bo, rpcCtx, resp) +} + +func (worker *copIteratorWorker) collectCopRuntimeStats(copStats *CopRuntimeStats, bo *Backoffer, rpcCtx *tikv.RPCContext, resp *copResponse) { + copStats.ReqStats = worker.kvclient.Stats + backoffTimes := bo.GetBackoffTimes() + copStats.BackoffTime = time.Duration(bo.GetTotalSleep()) * time.Millisecond + copStats.BackoffSleep = make(map[string]time.Duration, len(backoffTimes)) + copStats.BackoffTimes = make(map[string]int, len(backoffTimes)) + for backoff := range backoffTimes { + copStats.BackoffTimes[backoff] = backoffTimes[backoff] + copStats.BackoffSleep[backoff] = time.Duration(bo.GetBackoffSleepMS()[backoff]) * time.Millisecond + } + if rpcCtx != nil { + copStats.CalleeAddress = rpcCtx.Addr + } + if resp == nil { + return + } + sd := &util.ScanDetail{} + td := util.TimeDetail{} + if pbDetails := resp.pbResp.ExecDetailsV2; pbDetails != nil { + // Take values in `ExecDetailsV2` first. 
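// Editor's aside (not part of the patch): a cached coprocessor result above is only
// proposed to the server when it came from the same region and from a snapshot no
// newer than the request's start ts; the server then answers IsCacheHit if the region
// data version still matches. A simplified sketch of the client-side validity check,
// with field names chosen for illustration:

package sketch

type cachedValue struct {
	regionID          uint64
	timestamp         uint64 // start ts of the request that produced the entry
	regionDataVersion uint64
}

// reusableVersion returns the data version to send as "cache if match version",
// or 0 when the entry cannot be reused for this request.
func reusableVersion(v *cachedValue, regionID, startTS uint64) uint64 {
	if v == nil || v.regionID != regionID || v.timestamp > startTS {
		return 0
	}
	return v.regionDataVersion
}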
+ if pbDetails.TimeDetail != nil || pbDetails.TimeDetailV2 != nil { + td.MergeFromTimeDetail(pbDetails.TimeDetailV2, pbDetails.TimeDetail) + } + if scanDetailV2 := pbDetails.ScanDetailV2; scanDetailV2 != nil { + sd.MergeFromScanDetailV2(scanDetailV2) + } + } else if pbDetails := resp.pbResp.ExecDetails; pbDetails != nil { + if timeDetail := pbDetails.TimeDetail; timeDetail != nil { + td.MergeFromTimeDetail(nil, timeDetail) + } + if scanDetail := pbDetails.ScanDetail; scanDetail != nil { + if scanDetail.Write != nil { + sd.ProcessedKeys = scanDetail.Write.Processed + sd.TotalKeys = scanDetail.Write.Total + } + } + } + copStats.ScanDetail = sd + copStats.TimeDetail = td +} + +func (worker *copIteratorWorker) collectUnconsumedCopRuntimeStats(bo *Backoffer, rpcCtx *tikv.RPCContext) { + if worker.kvclient.Stats == nil { + return + } + copStats := &CopRuntimeStats{} + worker.collectCopRuntimeStats(copStats, bo, rpcCtx, nil) + worker.unconsumedStats.Lock() + worker.unconsumedStats.stats = append(worker.unconsumedStats.stats, copStats) + worker.unconsumedStats.Unlock() + worker.kvclient.Stats = nil +} + +// CopRuntimeStats contains execution detail information. +type CopRuntimeStats struct { + execdetails.ExecDetails + ReqStats *tikv.RegionRequestRuntimeStats + + CoprCacheHit bool +} + +type unconsumedCopRuntimeStats struct { + sync.Mutex + stats []*CopRuntimeStats +} + +func (worker *copIteratorWorker) handleTiDBSendReqErr(err error, task *copTask, ch chan<- *copResponse) error { + errCode := errno.ErrUnknown + errMsg := err.Error() + if terror.ErrorEqual(err, derr.ErrTiKVServerTimeout) { + errCode = errno.ErrTiKVServerTimeout + errMsg = "TiDB server timeout, address is " + task.storeAddr + } + if terror.ErrorEqual(err, derr.ErrTiFlashServerTimeout) { + errCode = errno.ErrTiFlashServerTimeout + errMsg = "TiDB server timeout, address is " + task.storeAddr + } + selResp := tipb.SelectResponse{ + Warnings: []*tipb.Error{ + { + Code: int32(errCode), + Msg: errMsg, + }, + }, + } + data, err := proto.Marshal(&selResp) + if err != nil { + return errors.Trace(err) + } + resp := &copResponse{ + pbResp: &coprocessor.Response{ + Data: data, + }, + detail: &CopRuntimeStats{}, + } + worker.sendToRespCh(resp, ch, true) + return nil +} + +// calculateRetry splits the input ranges into two, and take one of them according to desc flag. +// It's used in paging API, to calculate which range is consumed and what needs to be retry. +// For example: +// ranges: [r1 --> r2) [r3 --> r4) +// split: [s1 --> s2) +// In normal scan order, all data before s1 is consumed, so the retry ranges should be [s1 --> r2) [r3 --> r4) +// In reverse scan order, all data after s2 is consumed, so the retry ranges should be [r1 --> r2) [r3 --> s2) +func (worker *copIteratorWorker) calculateRetry(ranges *KeyRanges, split *coprocessor.KeyRange, desc bool) *KeyRanges { + if split == nil { + return ranges + } + if desc { + left, _ := ranges.Split(split.End) + return left + } + _, right := ranges.Split(split.Start) + return right +} + +// calculateRemain calculates the remain ranges to be processed, it's used in paging API. 
+// For example: +// ranges: [r1 --> r2) [r3 --> r4) +// split: [s1 --> s2) +// In normal scan order, all data before s2 is consumed, so the remained ranges should be [s2 --> r4) +// In reverse scan order, all data after s1 is consumed, so the remained ranges should be [r1 --> s1) +func (worker *copIteratorWorker) calculateRemain(ranges *KeyRanges, split *coprocessor.KeyRange, desc bool) *KeyRanges { + if split == nil { + return ranges + } + if desc { + left, _ := ranges.Split(split.Start) + return left + } + _, right := ranges.Split(split.End) + return right +} + +// finished checks the flags and finished channel, it tells whether the worker is finished. +func (worker *copIteratorWorker) finished() bool { + if worker.vars != nil && worker.vars.Killed != nil { + killed := atomic.LoadUint32(worker.vars.Killed) + if killed != 0 { + logutil.BgLogger().Info( + "a killed signal is received in copIteratorWorker", + zap.Uint32("signal", killed), + ) + return true + } + } + select { + case <-worker.finishCh: + return true + default: + return false + } +} + +func (it *copIterator) Close() error { + if atomic.CompareAndSwapUint32(&it.closed, 0, 1) { + close(it.finishCh) + } + it.rpcCancel.CancelAll() + it.actionOnExceed.close() + it.wg.Wait() + return nil +} + +// copErrorResponse returns error when calling Next() +type copErrorResponse struct{ error } + +func (it copErrorResponse) Next(ctx context.Context) (kv.ResultSubset, error) { + return nil, it.error +} + +func (it copErrorResponse) Close() error { + return nil +} + +// rateLimitAction an OOM Action which is used to control the token if OOM triggered. The token number should be +// set on initial. Each time the Action is triggered, one token would be destroyed. If the count of the token is less +// than 2, the action would be delegated to the fallback action. +type rateLimitAction struct { + memory.BaseOOMAction + // enabled indicates whether the rateLimitAction is permitted to Action. 1 means permitted, 0 denied. + enabled uint32 + // totalTokenNum indicates the total token at initial + totalTokenNum uint + cond struct { + sync.Mutex + // exceeded indicates whether have encountered OOM situation. 
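// Editor's aside (not part of the patch): calculateRetry and calculateRemain above
// split the pending key ranges at the boundary of the last page the store returned. A
// simplified sketch of the "remain" case on a single [start, end) byte range, ignoring
// the multi-range bookkeeping that KeyRanges.Split performs:

package sketch

import "bytes"

type keyRange struct{ start, end []byte }

// remainAfterPage returns what is still unread once a page covering
// [pageStart, pageEnd) has been consumed. If the result has start >= end,
// the range is fully consumed.
func remainAfterPage(r keyRange, pageStart, pageEnd []byte, desc bool) keyRange {
	if desc {
		// Descending scan: everything at or after pageStart is already consumed.
		if bytes.Compare(pageStart, r.end) < 0 {
			r.end = pageStart
		}
		return r
	}
	// Ascending scan: everything before pageEnd is already consumed.
	if bytes.Compare(pageEnd, r.start) > 0 {
		r.start = pageEnd
	}
	return r
}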
+ exceeded bool + // remainingTokenNum indicates the count of tokens which still exists + remainingTokenNum uint + once sync.Once + // triggerCountForTest indicates the total count of the rateLimitAction's Action being executed + triggerCountForTest uint + } +} + +func newRateLimitAction(totalTokenNumber uint) *rateLimitAction { + return &rateLimitAction{ + totalTokenNum: totalTokenNumber, + cond: struct { + sync.Mutex + exceeded bool + remainingTokenNum uint + once sync.Once + triggerCountForTest uint + }{ + Mutex: sync.Mutex{}, + exceeded: false, + remainingTokenNum: totalTokenNumber, + once: sync.Once{}, + }, + } +} + +// Action implements ActionOnExceed.Action +func (e *rateLimitAction) Action(t *memory.Tracker) { + if !e.isEnabled() { + if fallback := e.GetFallback(); fallback != nil { + fallback.Action(t) + } + return + } + e.conditionLock() + defer e.conditionUnlock() + e.cond.once.Do(func() { + if e.cond.remainingTokenNum < 2 { + e.setEnabled(false) + logutil.BgLogger().Info("memory exceeds quota, rateLimitAction delegate to fallback action", + zap.Uint("total token count", e.totalTokenNum)) + if fallback := e.GetFallback(); fallback != nil { + fallback.Action(t) + } + return + } + failpoint.Inject("testRateLimitActionMockConsumeAndAssert", func(val failpoint.Value) { + if val.(bool) { + if e.cond.triggerCountForTest+e.cond.remainingTokenNum != e.totalTokenNum { + panic("triggerCount + remainingTokenNum not equal to totalTokenNum") + } + } + }) + logutil.BgLogger().Info("memory exceeds quota, destroy one token now.", + zap.Int64("consumed", t.BytesConsumed()), + zap.Int64("quota", t.GetBytesLimit()), + zap.Uint("total token count", e.totalTokenNum), + zap.Uint("remaining token count", e.cond.remainingTokenNum)) + e.cond.exceeded = true + e.cond.triggerCountForTest++ + }) +} + +// GetPriority get the priority of the Action. +func (e *rateLimitAction) GetPriority() int64 { + return memory.DefRateLimitPriority +} + +// destroyTokenIfNeeded will check the `exceed` flag after copWorker finished one task. +// If the exceed flag is true and there is no token been destroyed before, one token will be destroyed, +// or the token would be return back. +func (e *rateLimitAction) destroyTokenIfNeeded(returnToken func()) { + if !e.isEnabled() { + returnToken() + return + } + e.conditionLock() + defer e.conditionUnlock() + if !e.cond.exceeded { + returnToken() + return + } + // If actionOnExceed has been triggered and there is no token have been destroyed before, + // destroy one token. + e.cond.remainingTokenNum = e.cond.remainingTokenNum - 1 + e.cond.exceeded = false + e.cond.once = sync.Once{} +} + +func (e *rateLimitAction) conditionLock() { + e.cond.Lock() +} + +func (e *rateLimitAction) conditionUnlock() { + e.cond.Unlock() +} + +func (e *rateLimitAction) close() { + if !e.isEnabled() { + return + } + e.setEnabled(false) + e.conditionLock() + defer e.conditionUnlock() + e.cond.exceeded = false + e.SetFinished() +} + +func (e *rateLimitAction) setEnabled(enabled bool) { + newValue := uint32(0) + if enabled { + newValue = uint32(1) + } + atomic.StoreUint32(&e.enabled, newValue) +} + +func (e *rateLimitAction) isEnabled() bool { + return atomic.LoadUint32(&e.enabled) > 0 +} + +// priorityToPB converts priority type to wire type. 
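// Editor's aside (not part of the patch): rateLimitAction above responds to a
// memory-quota breach by permanently removing one send token the next time a task
// finishes, so fewer responses are buffered at once; once fewer than two tokens remain
// it hands the problem to the fallback OOM action. A condensed sketch of that token
// bookkeeping (locking and the once-per-trigger guard are omitted for brevity):

package sketch

type limiter struct {
	remaining int  // tokens still alive
	exceeded  bool // set when memory exceeded the quota
}

// onMemoryExceeded reports whether the limiter can still absorb the breach by
// shrinking itself, or whether a fallback action must take over.
func (l *limiter) onMemoryExceeded() (delegateToFallback bool) {
	if l.remaining < 2 {
		return true
	}
	l.exceeded = true
	return false
}

// onTaskFinished either returns the token to the pool or destroys it, depending on
// whether a memory breach is pending.
func (l *limiter) onTaskFinished(returnToken func()) {
	if !l.exceeded {
		returnToken()
		return
	}
	l.remaining--
	l.exceeded = false
}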
+func priorityToPB(pri int) kvrpcpb.CommandPri { + switch pri { + case kv.PriorityLow: + return kvrpcpb.CommandPri_Low + case kv.PriorityHigh: + return kvrpcpb.CommandPri_High + default: + return kvrpcpb.CommandPri_Normal + } +} + +func isolationLevelToPB(level kv.IsoLevel) kvrpcpb.IsolationLevel { + switch level { + case kv.RC: + return kvrpcpb.IsolationLevel_RC + case kv.SI: + return kvrpcpb.IsolationLevel_SI + case kv.RCCheckTS: + return kvrpcpb.IsolationLevel_RCCheckTS + default: + return kvrpcpb.IsolationLevel_SI + } +} + +// BuildKeyRanges is used for test, quickly build key ranges from paired keys. +func BuildKeyRanges(keys ...string) []kv.KeyRange { + var ranges []kv.KeyRange + for i := 0; i < len(keys); i += 2 { + ranges = append(ranges, kv.KeyRange{ + StartKey: []byte(keys[i]), + EndKey: []byte(keys[i+1]), + }) + } + return ranges +} + +func optRowHint(req *kv.Request) bool { + opt := true + if req.StoreType == kv.TiDB { + return false + } + if req.RequestSource.RequestSourceInternal || req.Tp != kv.ReqTypeDAG { + // disable extra concurrency for internal tasks. + return false + } + failpoint.Inject("disableFixedRowCountHint", func(_ failpoint.Value) { + opt = false + }) + return opt +} + +func checkStoreBatchCopr(req *kv.Request) bool { + if req.Tp != kv.ReqTypeDAG || req.StoreType != kv.TiKV { + return false + } + // TODO: support keep-order batch + if req.ReplicaRead != kv.ReplicaReadLeader || req.KeepOrder { + // Disable batch copr for follower read + return false + } + // Disable batch copr when paging is enabled. + if req.Paging.Enable { + return false + } + // Disable it for internal requests to avoid regression. + if req.RequestSource.RequestSourceInternal { + return false + } + return true +} diff --git a/pkg/store/copr/mpp.go b/pkg/store/copr/mpp.go index 32c098aa35cf2..618b04f20abf9 100644 --- a/pkg/store/copr/mpp.go +++ b/pkg/store/copr/mpp.go @@ -280,10 +280,10 @@ func (c *MPPClient) CheckVisibility(startTime uint64) error { } func (c *mppStoreCnt) getMPPStoreCount(ctx context.Context, pdClient pd.Client, TTL int64) (int, error) { - failpoint.Inject("mppStoreCountSetLastUpdateTime", func(value failpoint.Value) { + if value, _err_ := failpoint.Eval(_curpkg_("mppStoreCountSetLastUpdateTime")); _err_ == nil { v, _ := strconv.ParseInt(value.(string), 10, 0) c.lastUpdate = v - }) + } lastUpdate := atomic.LoadInt64(&c.lastUpdate) now := time.Now().UnixMicro() @@ -295,10 +295,10 @@ func (c *mppStoreCnt) getMPPStoreCount(ctx context.Context, pdClient pd.Client, } } - failpoint.Inject("mppStoreCountSetLastUpdateTimeP2", func(value failpoint.Value) { + if value, _err_ := failpoint.Eval(_curpkg_("mppStoreCountSetLastUpdateTimeP2")); _err_ == nil { v, _ := strconv.ParseInt(value.(string), 10, 0) c.lastUpdate = v - }) + } if !atomic.CompareAndSwapInt64(&c.lastUpdate, lastUpdate, now) { if isInit { @@ -311,11 +311,11 @@ func (c *mppStoreCnt) getMPPStoreCount(ctx context.Context, pdClient pd.Client, cnt := 0 stores, err := pdClient.GetAllStores(ctx, pd.WithExcludeTombstone()) - failpoint.Inject("mppStoreCountPDError", func(value failpoint.Value) { + if value, _err_ := failpoint.Eval(_curpkg_("mppStoreCountPDError")); _err_ == nil { if value.(bool) { err = errors.New("failed to get mpp store count") } - }) + } if err != nil { // always to update cache next time @@ -328,9 +328,9 @@ func (c *mppStoreCnt) getMPPStoreCount(ctx context.Context, pdClient pd.Client, } cnt += 1 } - failpoint.Inject("mppStoreCountSetMPPCnt", func(value failpoint.Value) { + if value, _err_ := 
failpoint.Eval(_curpkg_("mppStoreCountSetMPPCnt")); _err_ == nil { cnt = value.(int) - }) + } if !isInit || atomic.LoadInt64(&c.lastUpdate) == now { atomic.StoreInt32(&c.cnt, int32(cnt)) diff --git a/pkg/store/copr/mpp.go__failpoint_stash__ b/pkg/store/copr/mpp.go__failpoint_stash__ new file mode 100644 index 0000000000000..32c098aa35cf2 --- /dev/null +++ b/pkg/store/copr/mpp.go__failpoint_stash__ @@ -0,0 +1,346 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package copr + +import ( + "context" + "strconv" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/coprocessor" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/kvproto/pkg/mpp" + "github.com/pingcap/log" + "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/store/driver/backoff" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/tiflash" + "github.com/pingcap/tidb/pkg/util/tiflashcompute" + "github.com/tikv/client-go/v2/tikv" + "github.com/tikv/client-go/v2/tikvrpc" + pd "github.com/tikv/pd/client" + "go.uber.org/zap" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// MPPClient servers MPP requests. +type MPPClient struct { + store *kvStore +} + +type mppStoreCnt struct { + cnt int32 + lastUpdate int64 + initFlag int32 +} + +// GetAddress returns the network address. +func (c *batchCopTask) GetAddress() string { + return c.storeAddr +} + +// ConstructMPPTasks receives ScheduleRequest, which are actually collects of kv ranges. We allocates MPPTaskMeta for them and returns. 
+func (c *MPPClient) ConstructMPPTasks(ctx context.Context, req *kv.MPPBuildTasksRequest, ttl time.Duration, dispatchPolicy tiflashcompute.DispatchPolicy, tiflashReplicaReadPolicy tiflash.ReplicaRead, appendWarning func(error)) ([]kv.MPPTaskMeta, error) { + ctx = context.WithValue(ctx, tikv.TxnStartKey(), req.StartTS) + bo := backoff.NewBackofferWithVars(ctx, copBuildTaskMaxBackoff, nil) + var tasks []*batchCopTask + var err error + if req.PartitionIDAndRanges != nil { + rangesForEachPartition := make([]*KeyRanges, len(req.PartitionIDAndRanges)) + partitionIDs := make([]int64, len(req.PartitionIDAndRanges)) + for i, p := range req.PartitionIDAndRanges { + rangesForEachPartition[i] = NewKeyRanges(p.KeyRanges) + partitionIDs[i] = p.ID + } + tasks, err = buildBatchCopTasksForPartitionedTable(ctx, bo, c.store, rangesForEachPartition, kv.TiFlash, true, ttl, true, 20, partitionIDs, dispatchPolicy, tiflashReplicaReadPolicy, appendWarning) + } else { + if req.KeyRanges == nil { + return nil, errors.New("KeyRanges in MPPBuildTasksRequest is nil") + } + ranges := NewKeyRanges(req.KeyRanges) + tasks, err = buildBatchCopTasksForNonPartitionedTable(ctx, bo, c.store, ranges, kv.TiFlash, true, ttl, true, 20, dispatchPolicy, tiflashReplicaReadPolicy, appendWarning) + } + + if err != nil { + return nil, errors.Trace(err) + } + mppTasks := make([]kv.MPPTaskMeta, 0, len(tasks)) + for _, copTask := range tasks { + mppTasks = append(mppTasks, copTask) + } + return mppTasks, nil +} + +// DispatchMPPTask dispatch mpp task, and returns valid response when retry = false and err is nil +func (c *MPPClient) DispatchMPPTask(param kv.DispatchMPPTaskParam) (resp *mpp.DispatchTaskResponse, retry bool, err error) { + req := param.Req + var regionInfos []*coprocessor.RegionInfo + originalTask, ok := req.Meta.(*batchCopTask) + if ok { + for _, ri := range originalTask.regionInfos { + regionInfos = append(regionInfos, ri.toCoprocessorRegionInfo()) + } + } + + // meta for current task. + taskMeta := &mpp.TaskMeta{StartTs: req.StartTs, QueryTs: req.MppQueryID.QueryTs, LocalQueryId: req.MppQueryID.LocalQueryID, TaskId: req.ID, ServerId: req.MppQueryID.ServerID, + GatherId: req.GatherID, + Address: req.Meta.GetAddress(), + CoordinatorAddress: req.CoordinatorAddress, + ReportExecutionSummary: req.ReportExecutionSummary, + MppVersion: req.MppVersion.ToInt64(), + ResourceGroupName: req.ResourceGroupName, + ConnectionId: req.ConnectionID, + ConnectionAlias: req.ConnectionAlias, + } + + mppReq := &mpp.DispatchTaskRequest{ + Meta: taskMeta, + EncodedPlan: req.Data, + // TODO: This is only an experience value. It's better to be configurable. + Timeout: 60, + SchemaVer: req.SchemaVar, + Regions: regionInfos, + } + if originalTask != nil { + mppReq.TableRegions = originalTask.PartitionTableRegions + if mppReq.TableRegions != nil { + mppReq.Regions = nil + } + } + + wrappedReq := tikvrpc.NewRequest(tikvrpc.CmdMPPTask, mppReq, kvrpcpb.Context{}) + wrappedReq.StoreTp = getEndPointType(kv.TiFlash) + + // TODO: Handle dispatch task response correctly, including retry logic and cancel logic. + var rpcResp *tikvrpc.Response + invalidPDCache := config.GetGlobalConfig().DisaggregatedTiFlash && !config.GetGlobalConfig().UseAutoScaler + bo := backoff.NewBackofferWithTikvBo(param.Bo) + + // If copTasks is not empty, we should send request according to region distribution. + // Or else it's the task without region, which always happens in high layer task without table. 
+ // In that case + if originalTask != nil { + sender := NewRegionBatchRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient(), param.EnableCollectExecutionInfo) + rpcResp, retry, _, err = sender.SendReqToAddr(bo, originalTask.ctx, originalTask.regionInfos, wrappedReq, tikv.ReadTimeoutMedium) + // No matter what the rpc error is, we won't retry the mpp dispatch tasks. + // TODO: If we want to retry, we must redo the plan fragment cutting and task scheduling. + // That's a hard job but we can try it in the future. + if sender.GetRPCError() != nil { + logutil.BgLogger().Warn("mpp dispatch meet io error", zap.String("error", sender.GetRPCError().Error()), zap.Uint64("timestamp", taskMeta.StartTs), zap.Int64("task", taskMeta.TaskId), zap.Int64("mpp-version", taskMeta.MppVersion)) + if invalidPDCache { + c.store.GetRegionCache().InvalidateTiFlashComputeStores() + } + err = sender.GetRPCError() + } + } else { + rpcResp, err = c.store.GetTiKVClient().SendRequest(param.Ctx, req.Meta.GetAddress(), wrappedReq, tikv.ReadTimeoutMedium) + if errors.Cause(err) == context.Canceled || status.Code(errors.Cause(err)) == codes.Canceled { + retry = false + } else if err != nil { + if invalidPDCache { + c.store.GetRegionCache().InvalidateTiFlashComputeStores() + } + if bo.Backoff(tikv.BoTiFlashRPC(), err) == nil { + retry = true + } + } + } + + if err != nil || retry { + return nil, retry, err + } + + realResp := rpcResp.Resp.(*mpp.DispatchTaskResponse) + if realResp.Error != nil { + return realResp, false, nil + } + + if len(realResp.RetryRegions) > 0 { + logutil.BgLogger().Info("TiFlash found " + strconv.Itoa(len(realResp.RetryRegions)) + " stale regions. Only first " + strconv.Itoa(min(10, len(realResp.RetryRegions))) + " regions will be logged if the log level is higher than Debug") + for index, retry := range realResp.RetryRegions { + id := tikv.NewRegionVerID(retry.Id, retry.RegionEpoch.ConfVer, retry.RegionEpoch.Version) + if index < 10 || log.GetLevel() <= zap.DebugLevel { + logutil.BgLogger().Info("invalid region because tiflash detected stale region", zap.String("region id", id.String())) + } + c.store.GetRegionCache().InvalidateCachedRegionWithReason(id, tikv.EpochNotMatch) + } + } + return realResp, retry, err +} + +// CancelMPPTasks cancels mpp tasks +// NOTE: We do not retry here, because retry is helpless when errors result from TiFlash or Network. If errors occur, the execution on TiFlash will finally stop after some minutes. +// This function is exclusively called, and only the first call succeeds sending tasks and setting all tasks as cancelled, while others will not work. 
+func (c *MPPClient) CancelMPPTasks(param kv.CancelMPPTasksParam) { + usedStoreAddrs := param.StoreAddr + reqs := param.Reqs + if len(usedStoreAddrs) == 0 || len(reqs) == 0 { + return + } + + firstReq := reqs[0] + killReq := &mpp.CancelTaskRequest{ + Meta: &mpp.TaskMeta{StartTs: firstReq.StartTs, GatherId: firstReq.GatherID, QueryTs: firstReq.MppQueryID.QueryTs, LocalQueryId: firstReq.MppQueryID.LocalQueryID, ServerId: firstReq.MppQueryID.ServerID, MppVersion: firstReq.MppVersion.ToInt64(), ResourceGroupName: firstReq.ResourceGroupName}, + } + + wrappedReq := tikvrpc.NewRequest(tikvrpc.CmdMPPCancel, killReq, kvrpcpb.Context{}) + wrappedReq.StoreTp = getEndPointType(kv.TiFlash) + + // send cancel cmd to all stores where tasks run + invalidPDCache := config.GetGlobalConfig().DisaggregatedTiFlash && !config.GetGlobalConfig().UseAutoScaler + wg := util.WaitGroupWrapper{} + gotErr := atomic.Bool{} + for addr := range usedStoreAddrs { + storeAddr := addr + wg.Run(func() { + _, err := c.store.GetTiKVClient().SendRequest(context.Background(), storeAddr, wrappedReq, tikv.ReadTimeoutShort) + logutil.BgLogger().Debug("cancel task", zap.Uint64("query id ", firstReq.StartTs), zap.String("on addr", storeAddr), zap.Int64("mpp-version", firstReq.MppVersion.ToInt64())) + if err != nil { + logutil.BgLogger().Error("cancel task error", zap.Error(err), zap.Uint64("query id", firstReq.StartTs), zap.String("on addr", storeAddr), zap.Int64("mpp-version", firstReq.MppVersion.ToInt64())) + if invalidPDCache { + gotErr.CompareAndSwap(false, true) + } + } + }) + } + wg.Wait() + if invalidPDCache && gotErr.Load() { + c.store.GetRegionCache().InvalidateTiFlashComputeStores() + } +} + +// EstablishMPPConns build a mpp connection to receive data, return valid response when err is nil +func (c *MPPClient) EstablishMPPConns(param kv.EstablishMPPConnsParam) (*tikvrpc.MPPStreamResponse, error) { + req := param.Req + taskMeta := param.TaskMeta + connReq := &mpp.EstablishMPPConnectionRequest{ + SenderMeta: taskMeta, + ReceiverMeta: &mpp.TaskMeta{ + StartTs: req.StartTs, + GatherId: req.GatherID, + QueryTs: req.MppQueryID.QueryTs, + LocalQueryId: req.MppQueryID.LocalQueryID, + ServerId: req.MppQueryID.ServerID, + MppVersion: req.MppVersion.ToInt64(), + TaskId: -1, + ResourceGroupName: req.ResourceGroupName, + }, + } + + var err error + + wrappedReq := tikvrpc.NewRequest(tikvrpc.CmdMPPConn, connReq, kvrpcpb.Context{}) + wrappedReq.StoreTp = getEndPointType(kv.TiFlash) + + // Drain results from root task. + // We don't need to process any special error. When we meet errors, just let it fail. + rpcResp, err := c.store.GetTiKVClient().SendRequest(param.Ctx, req.Meta.GetAddress(), wrappedReq, TiFlashReadTimeoutUltraLong) + + var stream *tikvrpc.MPPStreamResponse + if rpcResp != nil && rpcResp.Resp != nil { + stream = rpcResp.Resp.(*tikvrpc.MPPStreamResponse) + } + + if err != nil { + if stream != nil { + stream.Close() + } + logutil.BgLogger().Warn("establish mpp connection meet error and cannot retry", zap.String("error", err.Error()), zap.Uint64("timestamp", taskMeta.StartTs), zap.Int64("task", taskMeta.TaskId), zap.Int64("mpp-version", taskMeta.MppVersion)) + if config.GetGlobalConfig().DisaggregatedTiFlash && !config.GetGlobalConfig().UseAutoScaler { + c.store.GetRegionCache().InvalidateTiFlashComputeStores() + } + return nil, err + } + + return stream, nil +} + +// CheckVisibility checks if it is safe to read using given ts. 
+func (c *MPPClient) CheckVisibility(startTime uint64) error { + return c.store.CheckVisibility(startTime) +} + +func (c *mppStoreCnt) getMPPStoreCount(ctx context.Context, pdClient pd.Client, TTL int64) (int, error) { + failpoint.Inject("mppStoreCountSetLastUpdateTime", func(value failpoint.Value) { + v, _ := strconv.ParseInt(value.(string), 10, 0) + c.lastUpdate = v + }) + + lastUpdate := atomic.LoadInt64(&c.lastUpdate) + now := time.Now().UnixMicro() + isInit := atomic.LoadInt32(&c.initFlag) != 0 + + if now-lastUpdate < TTL { + if isInit { + return int(atomic.LoadInt32(&c.cnt)), nil + } + } + + failpoint.Inject("mppStoreCountSetLastUpdateTimeP2", func(value failpoint.Value) { + v, _ := strconv.ParseInt(value.(string), 10, 0) + c.lastUpdate = v + }) + + if !atomic.CompareAndSwapInt64(&c.lastUpdate, lastUpdate, now) { + if isInit { + return int(atomic.LoadInt32(&c.cnt)), nil + } + // if has't initialized, always fetch latest mpp store info + } + + // update mpp store cache + cnt := 0 + stores, err := pdClient.GetAllStores(ctx, pd.WithExcludeTombstone()) + + failpoint.Inject("mppStoreCountPDError", func(value failpoint.Value) { + if value.(bool) { + err = errors.New("failed to get mpp store count") + } + }) + + if err != nil { + // always to update cache next time + atomic.StoreInt32(&c.initFlag, 0) + return 0, err + } + for _, s := range stores { + if !tikv.LabelFilterNoTiFlashWriteNode(s.GetLabels()) { + continue + } + cnt += 1 + } + failpoint.Inject("mppStoreCountSetMPPCnt", func(value failpoint.Value) { + cnt = value.(int) + }) + + if !isInit || atomic.LoadInt64(&c.lastUpdate) == now { + atomic.StoreInt32(&c.cnt, int32(cnt)) + atomic.StoreInt32(&c.initFlag, 1) + } + + return cnt, nil +} + +// GetMPPStoreCount returns number of TiFlash stores +func (c *MPPClient) GetMPPStoreCount() (int, error) { + return c.store.mppStoreCnt.getMPPStoreCount(c.store.store.Ctx(), c.store.store.GetPDClient(), 120*1e6 /* TTL 120sec */) +} diff --git a/pkg/store/driver/txn/binding__failpoint_binding__.go b/pkg/store/driver/txn/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..5a3dafc415412 --- /dev/null +++ b/pkg/store/driver/txn/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package txn + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/store/driver/txn/binlog.go b/pkg/store/driver/txn/binlog.go index b17bcbaed25fa..96006349db43b 100644 --- a/pkg/store/driver/txn/binlog.go +++ b/pkg/store/driver/txn/binlog.go @@ -65,12 +65,12 @@ func (e *binlogExecutor) Commit(ctx context.Context, commitTS int64) { wg := sync.WaitGroup{} mock := false - failpoint.Inject("mockSyncBinlogCommit", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockSyncBinlogCommit")); _err_ == nil { if val.(bool) { wg.Add(1) mock = true } - }) + } go func() { logutil.Eventf(ctx, "start write finish binlog") binlogWriteResult := e.binInfo.WriteBinlog(e.txn.GetClusterID()) diff --git a/pkg/store/driver/txn/binlog.go__failpoint_stash__ b/pkg/store/driver/txn/binlog.go__failpoint_stash__ new file mode 100644 index 0000000000000..b17bcbaed25fa --- /dev/null +++ b/pkg/store/driver/txn/binlog.go__failpoint_stash__ @@ -0,0 +1,90 @@ +// Copyright 2021 PingCAP, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package txn + +import ( + "context" + "sync" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/sessionctx/binloginfo" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tipb/go-binlog" + "github.com/tikv/client-go/v2/tikv" + "go.uber.org/zap" +) + +type binlogExecutor struct { + txn *tikv.KVTxn + binInfo *binloginfo.BinlogInfo +} + +func (e *binlogExecutor) Skip() { + binloginfo.RemoveOneSkippedCommitter() +} + +func (e *binlogExecutor) Prewrite(ctx context.Context, primary []byte) <-chan tikv.BinlogWriteResult { + ch := make(chan tikv.BinlogWriteResult, 1) + go func() { + logutil.Eventf(ctx, "start prewrite binlog") + bin := e.binInfo.Data + bin.StartTs = int64(e.txn.StartTS()) + if bin.Tp == binlog.BinlogType_Prewrite { + bin.PrewriteKey = primary + } + wr := e.binInfo.WriteBinlog(e.txn.GetClusterID()) + if wr.Skipped() { + e.binInfo.Data.PrewriteValue = nil + binloginfo.AddOneSkippedCommitter() + } + logutil.Eventf(ctx, "finish prewrite binlog") + ch <- wr + }() + return ch +} + +func (e *binlogExecutor) Commit(ctx context.Context, commitTS int64) { + e.binInfo.Data.Tp = binlog.BinlogType_Commit + if commitTS == 0 { + e.binInfo.Data.Tp = binlog.BinlogType_Rollback + } + e.binInfo.Data.CommitTs = commitTS + e.binInfo.Data.PrewriteValue = nil + + wg := sync.WaitGroup{} + mock := false + failpoint.Inject("mockSyncBinlogCommit", func(val failpoint.Value) { + if val.(bool) { + wg.Add(1) + mock = true + } + }) + go func() { + logutil.Eventf(ctx, "start write finish binlog") + binlogWriteResult := e.binInfo.WriteBinlog(e.txn.GetClusterID()) + err := binlogWriteResult.GetError() + if err != nil { + logutil.BgLogger().Error("failed to write binlog", + zap.Error(err)) + } + logutil.Eventf(ctx, "finish write finish binlog") + if mock { + wg.Done() + } + }() + if mock { + wg.Wait() + } +} diff --git a/pkg/store/driver/txn/txn_driver.go b/pkg/store/driver/txn/txn_driver.go index 03c288852a70f..91e48f2977c0f 100644 --- a/pkg/store/driver/txn/txn_driver.go +++ b/pkg/store/driver/txn/txn_driver.go @@ -401,7 +401,7 @@ func (txn *tikvTxn) UpdateMemBufferFlags(key []byte, flags ...kv.FlagsOp) { func (txn *tikvTxn) generateWriteConflictForLockedWithConflict(lockCtx *kv.LockCtx) error { if lockCtx.MaxLockedWithConflictTS != 0 { - failpoint.Inject("lockedWithConflictOccurs", func() {}) + failpoint.Eval(_curpkg_("lockedWithConflictOccurs")) var bufTableID, bufRest bytes.Buffer foundKey := false for k, v := range lockCtx.Values { diff --git a/pkg/store/driver/txn/txn_driver.go__failpoint_stash__ b/pkg/store/driver/txn/txn_driver.go__failpoint_stash__ new file mode 100644 index 0000000000000..03c288852a70f --- /dev/null +++ b/pkg/store/driver/txn/txn_driver.go__failpoint_stash__ @@ -0,0 +1,491 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package txn + +import ( + "bytes" + "context" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/sessionctx/binloginfo" + derr "github.com/pingcap/tidb/pkg/store/driver/error" + "github.com/pingcap/tidb/pkg/store/driver/options" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/tracing" + tikverr "github.com/tikv/client-go/v2/error" + tikvstore "github.com/tikv/client-go/v2/kv" + "github.com/tikv/client-go/v2/tikv" + "github.com/tikv/client-go/v2/tikvrpc" + "github.com/tikv/client-go/v2/tikvrpc/interceptor" + "github.com/tikv/client-go/v2/txnkv" + "github.com/tikv/client-go/v2/txnkv/txnsnapshot" + "go.uber.org/zap" +) + +type tikvTxn struct { + *tikv.KVTxn + idxNameCache map[int64]*model.TableInfo + snapshotInterceptor kv.SnapshotInterceptor + // columnMapsCache is a cache used for the mutation checker + columnMapsCache any + isCommitterWorking atomic.Bool + memBuffer *memBuffer +} + +// NewTiKVTxn returns a new Transaction. +func NewTiKVTxn(txn *tikv.KVTxn) kv.Transaction { + txn.SetKVFilter(TiDBKVFilter{}) + + // init default size limits by config + entryLimit := kv.TxnEntrySizeLimit.Load() + totalLimit := kv.TxnTotalSizeLimit.Load() + txn.GetUnionStore().SetEntrySizeLimit(entryLimit, totalLimit) + + return &tikvTxn{ + txn, make(map[int64]*model.TableInfo), nil, nil, atomic.Bool{}, + newMemBuffer(txn.GetMemBuffer(), txn.IsPipelined()), + } +} + +func (txn *tikvTxn) GetTableInfo(id int64) *model.TableInfo { + return txn.idxNameCache[id] +} + +func (txn *tikvTxn) SetDiskFullOpt(level kvrpcpb.DiskFullOpt) { + txn.KVTxn.SetDiskFullOpt(level) +} + +func (txn *tikvTxn) CacheTableInfo(id int64, info *model.TableInfo) { + txn.idxNameCache[id] = info + // For partition table, also cached tblInfo with TableID for global index. + if info != nil && info.ID != id { + txn.idxNameCache[info.ID] = info + } +} + +func (txn *tikvTxn) LockKeys(ctx context.Context, lockCtx *kv.LockCtx, keysInput ...kv.Key) error { + if intest.InTest { + txn.isCommitterWorking.Store(true) + defer txn.isCommitterWorking.Store(false) + } + keys := toTiKVKeys(keysInput) + err := txn.KVTxn.LockKeys(ctx, lockCtx, keys...) + if err != nil { + return txn.extractKeyErr(err) + } + return txn.generateWriteConflictForLockedWithConflict(lockCtx) +} + +func (txn *tikvTxn) LockKeysFunc(ctx context.Context, lockCtx *kv.LockCtx, fn func(), keysInput ...kv.Key) error { + if intest.InTest { + txn.isCommitterWorking.Store(true) + defer txn.isCommitterWorking.Store(false) + } + keys := toTiKVKeys(keysInput) + err := txn.KVTxn.LockKeysFunc(ctx, lockCtx, fn, keys...) 
+ if err != nil { + return txn.extractKeyErr(err) + } + return txn.generateWriteConflictForLockedWithConflict(lockCtx) +} + +func (txn *tikvTxn) Commit(ctx context.Context) error { + if intest.InTest { + txn.isCommitterWorking.Store(true) + } + err := txn.KVTxn.Commit(ctx) + return txn.extractKeyErr(err) +} + +func (txn *tikvTxn) GetMemDBCheckpoint() *tikv.MemDBCheckpoint { + buf := txn.KVTxn.GetMemBuffer() + return buf.Checkpoint() +} + +func (txn *tikvTxn) RollbackMemDBToCheckpoint(savepoint *tikv.MemDBCheckpoint) { + buf := txn.KVTxn.GetMemBuffer() + buf.RevertToCheckpoint(savepoint) +} + +// GetSnapshot returns the Snapshot binding to this transaction. +func (txn *tikvTxn) GetSnapshot() kv.Snapshot { + return &tikvSnapshot{txn.KVTxn.GetSnapshot(), txn.snapshotInterceptor} +} + +// Iter creates an Iterator positioned on the first entry that k <= entry's key. +// If such entry is not found, it returns an invalid Iterator with no error. +// It yields only keys that < upperBound. If upperBound is nil, it means the upperBound is unbounded. +// The Iterator must be Closed after use. +func (txn *tikvTxn) Iter(k kv.Key, upperBound kv.Key) (iter kv.Iterator, err error) { + var dirtyIter, snapIter kv.Iterator + if dirtyIter, err = txn.GetMemBuffer().Iter(k, upperBound); err != nil { + return nil, err + } + + if snapIter, err = txn.GetSnapshot().Iter(k, upperBound); err != nil { + dirtyIter.Close() + return nil, err + } + + iter, err = NewUnionIter(dirtyIter, snapIter, false) + if err != nil { + dirtyIter.Close() + snapIter.Close() + } + + return iter, err +} + +// IterReverse creates a reversed Iterator positioned on the first entry which key is less than k. +// The returned iterator will iterate from greater key to smaller key. +// If k is nil, the returned iterator will be positioned at the last key. +func (txn *tikvTxn) IterReverse(k kv.Key, lowerBound kv.Key) (iter kv.Iterator, err error) { + var dirtyIter, snapIter kv.Iterator + + if dirtyIter, err = txn.GetMemBuffer().IterReverse(k, lowerBound); err != nil { + return nil, err + } + + if snapIter, err = txn.GetSnapshot().IterReverse(k, lowerBound); err != nil { + dirtyIter.Close() + return nil, err + } + + iter, err = NewUnionIter(dirtyIter, snapIter, true) + if err != nil { + dirtyIter.Close() + snapIter.Close() + } + + return iter, err +} + +// BatchGet gets kv from the memory buffer of statement and transaction, and the kv storage. +// Do not use len(value) == 0 or value == nil to represent non-exist. +// If a key doesn't exist, there shouldn't be any corresponding entry in the result map. 
+func (txn *tikvTxn) BatchGet(ctx context.Context, keys []kv.Key) (map[string][]byte, error) { + r, ctx := tracing.StartRegionEx(ctx, "tikvTxn.BatchGet") + defer r.End() + return NewBufferBatchGetter(txn.GetMemBuffer(), nil, txn.GetSnapshot()).BatchGet(ctx, keys) +} + +func (txn *tikvTxn) Delete(k kv.Key) error { + err := txn.KVTxn.Delete(k) + return derr.ToTiDBErr(err) +} + +func (txn *tikvTxn) Get(ctx context.Context, k kv.Key) ([]byte, error) { + val, err := txn.GetMemBuffer().Get(ctx, k) + if kv.ErrNotExist.Equal(err) { + val, err = txn.GetSnapshot().Get(ctx, k) + } + + if err == nil && len(val) == 0 { + return nil, kv.ErrNotExist + } + + return val, err +} + +func (txn *tikvTxn) Set(k kv.Key, v []byte) error { + err := txn.KVTxn.Set(k, v) + return derr.ToTiDBErr(err) +} + +func (txn *tikvTxn) GetMemBuffer() kv.MemBuffer { + if txn.memBuffer == nil { + txn.memBuffer = newMemBuffer(txn.KVTxn.GetMemBuffer(), txn.IsPipelined()) + } + return txn.memBuffer +} + +func (txn *tikvTxn) SetOption(opt int, val any) { + if intest.InTest { + txn.assertCommitterNotWorking() + } + switch opt { + case kv.BinlogInfo: + txn.SetBinlogExecutor(&binlogExecutor{ + txn: txn.KVTxn, + binInfo: val.(*binloginfo.BinlogInfo), // val cannot be other type. + }) + case kv.SchemaChecker: + txn.SetSchemaLeaseChecker(val.(tikv.SchemaLeaseChecker)) + case kv.IsolationLevel: + level := getTiKVIsolationLevel(val.(kv.IsoLevel)) + txn.KVTxn.GetSnapshot().SetIsolationLevel(level) + case kv.Priority: + txn.KVTxn.SetPriority(getTiKVPriority(val.(int))) + case kv.NotFillCache: + txn.KVTxn.GetSnapshot().SetNotFillCache(val.(bool)) + case kv.Pessimistic: + txn.SetPessimistic(val.(bool)) + case kv.SnapshotTS: + txn.KVTxn.GetSnapshot().SetSnapshotTS(val.(uint64)) + case kv.ReplicaRead: + t := options.GetTiKVReplicaReadType(val.(kv.ReplicaReadType)) + txn.KVTxn.GetSnapshot().SetReplicaRead(t) + case kv.TaskID: + txn.KVTxn.GetSnapshot().SetTaskID(val.(uint64)) + case kv.InfoSchema: + txn.SetSchemaVer(val.(tikv.SchemaVer)) + case kv.CollectRuntimeStats: + if val == nil { + txn.KVTxn.GetSnapshot().SetRuntimeStats(nil) + } else { + txn.KVTxn.GetSnapshot().SetRuntimeStats(val.(*txnsnapshot.SnapshotRuntimeStats)) + } + case kv.SampleStep: + txn.KVTxn.GetSnapshot().SetSampleStep(val.(uint32)) + case kv.CommitHook: + txn.SetCommitCallback(val.(func(string, error))) + case kv.EnableAsyncCommit: + txn.SetEnableAsyncCommit(val.(bool)) + case kv.Enable1PC: + txn.SetEnable1PC(val.(bool)) + case kv.GuaranteeLinearizability: + txn.SetCausalConsistency(!val.(bool)) + case kv.TxnScope: + txn.SetScope(val.(string)) + case kv.IsStalenessReadOnly: + txn.KVTxn.GetSnapshot().SetIsStalenessReadOnly(val.(bool)) + case kv.MatchStoreLabels: + txn.KVTxn.GetSnapshot().SetMatchStoreLabels(val.([]*metapb.StoreLabel)) + case kv.ResourceGroupTag: + txn.KVTxn.SetResourceGroupTag(val.([]byte)) + case kv.ResourceGroupTagger: + txn.KVTxn.SetResourceGroupTagger(val.(tikvrpc.ResourceGroupTagger)) + case kv.KVFilter: + txn.KVTxn.SetKVFilter(val.(tikv.KVFilter)) + case kv.SnapInterceptor: + txn.snapshotInterceptor = val.(kv.SnapshotInterceptor) + case kv.CommitTSUpperBoundCheck: + txn.KVTxn.SetCommitTSUpperBoundCheck(val.(func(commitTS uint64) bool)) + case kv.RPCInterceptor: + txn.KVTxn.AddRPCInterceptor(val.(interceptor.RPCInterceptor)) + case kv.AssertionLevel: + txn.KVTxn.SetAssertionLevel(val.(kvrpcpb.AssertionLevel)) + case kv.TableToColumnMaps: + txn.columnMapsCache = val + case kv.RequestSourceInternal: + txn.KVTxn.SetRequestSourceInternal(val.(bool)) + case 
kv.RequestSourceType: + txn.KVTxn.SetRequestSourceType(val.(string)) + case kv.ExplicitRequestSourceType: + txn.KVTxn.SetExplicitRequestSourceType(val.(string)) + case kv.ReplicaReadAdjuster: + txn.KVTxn.GetSnapshot().SetReplicaReadAdjuster(val.(txnkv.ReplicaReadAdjuster)) + case kv.TxnSource: + txn.KVTxn.SetTxnSource(val.(uint64)) + case kv.ResourceGroupName: + txn.KVTxn.SetResourceGroupName(val.(string)) + case kv.LoadBasedReplicaReadThreshold: + txn.KVTxn.GetSnapshot().SetLoadBasedReplicaReadThreshold(val.(time.Duration)) + case kv.TiKVClientReadTimeout: + txn.KVTxn.GetSnapshot().SetKVReadTimeout(time.Duration(val.(uint64) * uint64(time.Millisecond))) + case kv.SizeLimits: + limits := val.(kv.TxnSizeLimits) + txn.KVTxn.GetUnionStore().SetEntrySizeLimit(limits.Entry, limits.Total) + case kv.SessionID: + txn.KVTxn.SetSessionID(val.(uint64)) + } +} + +func (txn *tikvTxn) GetOption(opt int) any { + switch opt { + case kv.GuaranteeLinearizability: + return !txn.KVTxn.IsCasualConsistency() + case kv.TxnScope: + return txn.KVTxn.GetScope() + case kv.TableToColumnMaps: + return txn.columnMapsCache + case kv.RequestSourceInternal: + return txn.RequestSourceInternal + case kv.RequestSourceType: + return txn.RequestSourceType + default: + return nil + } +} + +// SetVars sets variables to the transaction. +func (txn *tikvTxn) SetVars(vars any) { + if vs, ok := vars.(*tikv.Variables); ok { + txn.KVTxn.SetVars(vs) + } +} + +func (txn *tikvTxn) GetVars() any { + return txn.KVTxn.GetVars() +} + +func (txn *tikvTxn) extractKeyErr(err error) error { + if e, ok := errors.Cause(err).(*tikverr.ErrKeyExist); ok { + return txn.extractKeyExistsErr(e) + } + return extractKeyErr(err) +} + +func (txn *tikvTxn) extractKeyExistsErr(errExist *tikverr.ErrKeyExist) error { + var key kv.Key = errExist.GetKey() + tableID, indexID, isRecord, err := tablecodec.DecodeKeyHead(key) + if err != nil { + return genKeyExistsError("UNKNOWN", key.String(), err) + } + indexID = tablecodec.IndexIDMask & indexID + + tblInfo := txn.GetTableInfo(tableID) + if tblInfo == nil { + return genKeyExistsError("UNKNOWN", key.String(), errors.New("cannot find table info")) + } + var value []byte + if txn.IsPipelined() { + value = errExist.Value + if len(value) == 0 { + return genKeyExistsError( + "UNKNOWN", + key.String(), + errors.New("The value is empty (a delete)"), + ) + } + } else { + value, err = txn.KVTxn.GetUnionStore().GetMemBuffer().GetMemDB().SelectValueHistory(key, func(value []byte) bool { return len(value) != 0 }) + } + if err != nil { + return genKeyExistsError("UNKNOWN", key.String(), err) + } + + if isRecord { + return ExtractKeyExistsErrFromHandle(key, value, tblInfo) + } + return ExtractKeyExistsErrFromIndex(key, value, tblInfo, indexID) +} + +// SetAssertion sets an assertion for the key operation. +func (txn *tikvTxn) SetAssertion(key []byte, assertion ...kv.FlagsOp) error { + f, err := txn.GetUnionStore().GetMemBuffer().GetFlags(key) + if err != nil && !tikverr.IsErrNotFound(err) { + return err + } + if err == nil && f.HasAssertionFlags() { + return nil + } + txn.UpdateMemBufferFlags(key, assertion...) + return nil +} + +func (txn *tikvTxn) UpdateMemBufferFlags(key []byte, flags ...kv.FlagsOp) { + txn.GetUnionStore().GetMemBuffer().UpdateFlags(key, getTiKVFlagsOps(flags)...) 
+} + +func (txn *tikvTxn) generateWriteConflictForLockedWithConflict(lockCtx *kv.LockCtx) error { + if lockCtx.MaxLockedWithConflictTS != 0 { + failpoint.Inject("lockedWithConflictOccurs", func() {}) + var bufTableID, bufRest bytes.Buffer + foundKey := false + for k, v := range lockCtx.Values { + if v.LockedWithConflictTS >= lockCtx.MaxLockedWithConflictTS { + foundKey = true + prettyWriteKey(&bufTableID, &bufRest, []byte(k)) + break + } + } + if !foundKey { + bufTableID.WriteString("") + } + // TODO: Primary is not exported here. + primary := " primary=" + primaryRest := "" + return kv.ErrWriteConflict.FastGenByArgs(txn.StartTS(), 0, lockCtx.MaxLockedWithConflictTS, bufTableID.String(), bufRest.String(), primary, primaryRest, "LockedWithConflict") + } + return nil +} + +// StartFairLocking adapts the method signature of `KVTxn` to satisfy kv.FairLockingController. +// TODO: Update the methods' signatures in client-go to avoid this adaptor functions. +// TODO: Rename aggressive locking in client-go to fair locking. +func (txn *tikvTxn) StartFairLocking() error { + txn.KVTxn.StartAggressiveLocking() + return nil +} + +// RetryFairLocking adapts the method signature of `KVTxn` to satisfy kv.FairLockingController. +func (txn *tikvTxn) RetryFairLocking(ctx context.Context) error { + txn.KVTxn.RetryAggressiveLocking(ctx) + return nil +} + +// CancelFairLocking adapts the method signature of `KVTxn` to satisfy kv.FairLockingController. +func (txn *tikvTxn) CancelFairLocking(ctx context.Context) error { + txn.KVTxn.CancelAggressiveLocking(ctx) + return nil +} + +// DoneFairLocking adapts the method signature of `KVTxn` to satisfy kv.FairLockingController. +func (txn *tikvTxn) DoneFairLocking(ctx context.Context) error { + txn.KVTxn.DoneAggressiveLocking(ctx) + return nil +} + +// IsInFairLockingMode adapts the method signature of `KVTxn` to satisfy kv.FairLockingController. +func (txn *tikvTxn) IsInFairLockingMode() bool { + return txn.KVTxn.IsInAggressiveLockingMode() +} + +// MayFlush wraps the flush function and extract the error. +func (txn *tikvTxn) MayFlush() error { + if !txn.IsPipelined() { + return nil + } + if intest.InTest { + txn.isCommitterWorking.Store(true) + } + _, err := txn.KVTxn.GetMemBuffer().Flush(false) + return txn.extractKeyErr(err) +} + +// assertCommitterNotWorking asserts that the committer is not working, so it's safe to modify the options for txn and committer. +// It panics when committer is working, only use it when test with --tags=intest tag. +func (txn *tikvTxn) assertCommitterNotWorking() { + if txn.isCommitterWorking.Load() { + panic("committer is working") + } +} + +// TiDBKVFilter is the filter specific to TiDB to filter out KV pairs that needn't be committed. +type TiDBKVFilter struct{} + +// IsUnnecessaryKeyValue defines which kinds of KV pairs from TiDB needn't be committed. 
+func (f TiDBKVFilter) IsUnnecessaryKeyValue(key, value []byte, flags tikvstore.KeyFlags) (bool, error) { + isUntouchedValue := tablecodec.IsUntouchedIndexKValue(key, value) + if isUntouchedValue && flags.HasPresumeKeyNotExists() { + logutil.BgLogger().Error("unexpected path the untouched key value with PresumeKeyNotExists flag", + zap.Stringer("key", kv.Key(key)), zap.Stringer("value", kv.Key(value)), + zap.Uint16("flags", uint16(flags)), zap.Stack("stack")) + return false, errors.Errorf( + "unexpected path the untouched key=%s value=%s contains PresumeKeyNotExists flag keyFlags=%v", + kv.Key(key).String(), kv.Key(value).String(), flags) + } + return isUntouchedValue, nil +} diff --git a/pkg/store/gcworker/binding__failpoint_binding__.go b/pkg/store/gcworker/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..158fd645690b5 --- /dev/null +++ b/pkg/store/gcworker/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package gcworker + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/store/gcworker/gc_worker.go b/pkg/store/gcworker/gc_worker.go index 3f5c1d3d166b3..7f579774703a4 100644 --- a/pkg/store/gcworker/gc_worker.go +++ b/pkg/store/gcworker/gc_worker.go @@ -734,9 +734,9 @@ func (w *GCWorker) setGCWorkerServiceSafePoint(ctx context.Context, safePoint ui } func (w *GCWorker) runGCJob(ctx context.Context, safePoint uint64, concurrency int) error { - failpoint.Inject("mockRunGCJobFail", func() { - failpoint.Return(errors.New("mock failure of runGCJoB")) - }) + if _, _err_ := failpoint.Eval(_curpkg_("mockRunGCJobFail")); _err_ == nil { + return errors.New("mock failure of runGCJoB") + } metrics.GCWorkerCounter.WithLabelValues("run_job").Inc() err := w.resolveLocks(ctx, safePoint, concurrency) @@ -832,9 +832,9 @@ func (w *GCWorker) deleteRanges(ctx context.Context, safePoint uint64, concurren } else { err = w.doUnsafeDestroyRangeRequest(ctx, startKey, endKey, concurrency) } - failpoint.Inject("ignoreDeleteRangeFailed", func() { + if _, _err_ := failpoint.Eval(_curpkg_("ignoreDeleteRangeFailed")); _err_ == nil { err = nil - }) + } if err != nil { logutil.Logger(ctx).Error("delete range failed on range", zap.String("category", "gc worker"), @@ -1132,9 +1132,9 @@ func (w *GCWorker) resolveLocks( handler := func(ctx context.Context, r tikvstore.KeyRange) (rangetask.TaskStat, error) { scanLimit := uint32(tikv.GCScanLockLimit) - failpoint.Inject("lowScanLockLimit", func() { + if _, _err_ := failpoint.Eval(_curpkg_("lowScanLockLimit")); _err_ == nil { scanLimit = 3 - }) + } return tikv.ResolveLocksForRange(ctx, w.regionLockResolver, safePoint, r.StartKey, r.EndKey, tikv.NewGcResolveLockMaxBackoffer, scanLimit) } @@ -1492,7 +1492,7 @@ func (w *GCWorker) saveValueToSysTable(key, value string) error { func (w *GCWorker) doGCPlacementRules(se sessiontypes.Session, _ uint64, dr util.DelRangeTask, gcPlacementRuleCache map[int64]any) (err error) { // Get the job from the job history var historyJob *model.Job - failpoint.Inject("mockHistoryJobForGC", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("mockHistoryJobForGC")); _err_ == nil { args, err1 := json.Marshal([]any{kv.Key{}, []int64{int64(v.(int))}}) if err1 != nil { return @@ -1503,7 +1503,7 @@ func (w *GCWorker) 
doGCPlacementRules(se sessiontypes.Session, _ uint64, dr util TableID: int64(v.(int)), RawArgs: args, } - }) + } if historyJob == nil { historyJob, err = ddl.GetHistoryJobByID(se, dr.JobID) if err != nil { @@ -1566,7 +1566,7 @@ func (w *GCWorker) doGCPlacementRules(se sessiontypes.Session, _ uint64, dr util func (w *GCWorker) doGCLabelRules(dr util.DelRangeTask) (err error) { // Get the job from the job history var historyJob *model.Job - failpoint.Inject("mockHistoryJob", func(v failpoint.Value) { + if v, _err_ := failpoint.Eval(_curpkg_("mockHistoryJob")); _err_ == nil { args, err1 := json.Marshal([]any{kv.Key{}, []int64{}, []string{v.(string)}}) if err1 != nil { return @@ -1576,7 +1576,7 @@ func (w *GCWorker) doGCLabelRules(dr util.DelRangeTask) (err error) { Type: model.ActionDropTable, RawArgs: args, } - }) + } if historyJob == nil { se := createSession(w.store) historyJob, err = ddl.GetHistoryJobByID(se, dr.JobID) diff --git a/pkg/store/gcworker/gc_worker.go__failpoint_stash__ b/pkg/store/gcworker/gc_worker.go__failpoint_stash__ new file mode 100644 index 0000000000000..3f5c1d3d166b3 --- /dev/null +++ b/pkg/store/gcworker/gc_worker.go__failpoint_stash__ @@ -0,0 +1,1759 @@ +// Copyright 2017 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gcworker + +import ( + "bytes" + "context" + "encoding/hex" + "encoding/json" + "fmt" + "math" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/errorpb" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/tidb/pkg/ddl" + "github.com/pingcap/tidb/pkg/ddl/label" + "github.com/pingcap/tidb/pkg/ddl/placement" + "github.com/pingcap/tidb/pkg/ddl/util" + "github.com/pingcap/tidb/pkg/domain/infosync" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/privilege" + "github.com/pingcap/tidb/pkg/session" + sessiontypes "github.com/pingcap/tidb/pkg/session/types" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/logutil" + tikverr "github.com/tikv/client-go/v2/error" + tikvstore "github.com/tikv/client-go/v2/kv" + "github.com/tikv/client-go/v2/oracle" + "github.com/tikv/client-go/v2/tikv" + "github.com/tikv/client-go/v2/tikvrpc" + "github.com/tikv/client-go/v2/txnkv/rangetask" + tikvutil "github.com/tikv/client-go/v2/util" + pd "github.com/tikv/pd/client" + "go.uber.org/zap" +) + +// GCWorker periodically triggers GC process on tikv server. 
+type GCWorker struct { + uuid string + desc string + store kv.Storage + tikvStore tikv.Storage + pdClient pd.Client + gcIsRunning bool + lastFinish time.Time + cancel context.CancelFunc + done chan error + regionLockResolver tikv.RegionLockResolver +} + +// NewGCWorker creates a GCWorker instance. +func NewGCWorker(store kv.Storage, pdClient pd.Client) (*GCWorker, error) { + ver, err := store.CurrentVersion(kv.GlobalTxnScope) + if err != nil { + return nil, errors.Trace(err) + } + hostName, err := os.Hostname() + if err != nil { + hostName = "unknown" + } + tikvStore, ok := store.(tikv.Storage) + if !ok { + return nil, errors.New("GC should run against TiKV storage") + } + uuid := strconv.FormatUint(ver.Ver, 16) + resolverIdentifier := fmt.Sprintf("gc-worker-%s", uuid) + worker := &GCWorker{ + uuid: uuid, + desc: fmt.Sprintf("host:%s, pid:%d, start at %s", hostName, os.Getpid(), time.Now()), + store: store, + tikvStore: tikvStore, + pdClient: pdClient, + gcIsRunning: false, + lastFinish: time.Now(), + regionLockResolver: tikv.NewRegionLockResolver(resolverIdentifier, tikvStore), + done: make(chan error), + } + variable.RegisterStatistics(worker) + return worker, nil +} + +// Start starts the worker. +func (w *GCWorker) Start() { + var ctx context.Context + ctx, w.cancel = context.WithCancel(context.Background()) + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnGC) + var wg sync.WaitGroup + wg.Add(1) + go w.start(ctx, &wg) + wg.Wait() // Wait create session finish in worker, some test code depend on this to avoid race. +} + +// Close stops background goroutines. +func (w *GCWorker) Close() { + w.cancel() +} + +const ( + booleanTrue = "true" + booleanFalse = "false" + + gcWorkerTickInterval = time.Minute + gcWorkerLease = time.Minute * 2 + gcLeaderUUIDKey = "tikv_gc_leader_uuid" + gcLeaderDescKey = "tikv_gc_leader_desc" + gcLeaderLeaseKey = "tikv_gc_leader_lease" + + gcLastRunTimeKey = "tikv_gc_last_run_time" + gcRunIntervalKey = "tikv_gc_run_interval" + gcDefaultRunInterval = time.Minute * 10 + gcWaitTime = time.Minute * 1 + gcRedoDeleteRangeDelay = 24 * time.Hour + + gcLifeTimeKey = "tikv_gc_life_time" + gcDefaultLifeTime = time.Minute * 10 + gcMinLifeTime = time.Minute * 10 + gcSafePointKey = "tikv_gc_safe_point" + gcConcurrencyKey = "tikv_gc_concurrency" + gcDefaultConcurrency = 2 + gcMinConcurrency = 1 + gcMaxConcurrency = 128 + + gcEnableKey = "tikv_gc_enable" + gcDefaultEnableValue = true + + gcModeKey = "tikv_gc_mode" + gcModeCentral = "central" + gcModeDistributed = "distributed" + gcModeDefault = gcModeDistributed + + gcScanLockModeKey = "tikv_gc_scan_lock_mode" + + gcAutoConcurrencyKey = "tikv_gc_auto_concurrency" + gcDefaultAutoConcurrency = true + + gcWorkerServiceSafePointID = "gc_worker" + + // Status var names start with tidb_% + tidbGCLastRunTime = "tidb_gc_last_run_time" + tidbGCLeaderDesc = "tidb_gc_leader_desc" + tidbGCLeaderLease = "tidb_gc_leader_lease" + tidbGCLeaderUUID = "tidb_gc_leader_uuid" + tidbGCSafePoint = "tidb_gc_safe_point" +) + +var gcSafePointCacheInterval = tikv.GcSafePointCacheInterval + +var gcVariableComments = map[string]string{ + gcLeaderUUIDKey: "Current GC worker leader UUID. (DO NOT EDIT)", + gcLeaderDescKey: "Host name and pid of current GC leader. (DO NOT EDIT)", + gcLeaderLeaseKey: "Current GC worker leader lease. (DO NOT EDIT)", + gcLastRunTimeKey: "The time when last GC starts. 
(DO NOT EDIT)", + gcRunIntervalKey: "GC run interval, at least 10m, in Go format.", + gcLifeTimeKey: "All versions within life time will not be collected by GC, at least 10m, in Go format.", + gcSafePointKey: "All versions after safe point can be accessed. (DO NOT EDIT)", + gcConcurrencyKey: "How many goroutines used to do GC parallel, [1, 128], default 2", + gcEnableKey: "Current GC enable status", + gcModeKey: "Mode of GC, \"central\" or \"distributed\"", + gcAutoConcurrencyKey: "Let TiDB pick the concurrency automatically. If set false, tikv_gc_concurrency will be used", + gcScanLockModeKey: "Mode of scanning locks, \"physical\" or \"legacy\".(Deprecated)", +} + +const ( + unsafeDestroyRangeTimeout = 5 * time.Minute + gcTimeout = 5 * time.Minute +) + +func (w *GCWorker) start(ctx context.Context, wg *sync.WaitGroup) { + logutil.Logger(ctx).Info("start", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid)) + + w.tick(ctx) // Immediately tick once to initialize configs. + wg.Done() + + ticker := time.NewTicker(gcWorkerTickInterval) + defer ticker.Stop() + defer func() { + r := recover() + if r != nil { + logutil.Logger(ctx).Error("gcWorker", + zap.Any("r", r), + zap.Stack("stack")) + metrics.PanicCounter.WithLabelValues(metrics.LabelGCWorker).Inc() + } + }() + for { + select { + case <-ticker.C: + w.tick(ctx) + case err := <-w.done: + w.gcIsRunning = false + w.lastFinish = time.Now() + if err != nil { + logutil.Logger(ctx).Error("runGCJob", zap.String("category", "gc worker"), zap.Error(err)) + } + case <-ctx.Done(): + logutil.Logger(ctx).Info("quit", zap.String("category", "gc worker"), zap.String("uuid", w.uuid)) + return + } + } +} + +func createSession(store kv.Storage) sessiontypes.Session { + for { + se, err := session.CreateSession(store) + if err != nil { + logutil.BgLogger().Warn("create session", zap.String("category", "gc worker"), zap.Error(err)) + continue + } + // Disable privilege check for gc worker session. + privilege.BindPrivilegeManager(se, nil) + se.GetSessionVars().CommonGlobalLoaded = true + se.GetSessionVars().InRestrictedSQL = true + se.GetSessionVars().SetDiskFullOpt(kvrpcpb.DiskFullOpt_AllowedOnAlmostFull) + return se + } +} + +// GetScope gets the status variables scope. +func (w *GCWorker) GetScope(status string) variable.ScopeFlag { + return variable.DefaultStatusVarScopeFlag +} + +// Stats returns the server statistics. 
+func (w *GCWorker) Stats(vars *variable.SessionVars) (map[string]any, error) { + m := make(map[string]any) + if v, err := w.loadValueFromSysTable(gcLeaderUUIDKey); err == nil { + m[tidbGCLeaderUUID] = v + } + if v, err := w.loadValueFromSysTable(gcLeaderDescKey); err == nil { + m[tidbGCLeaderDesc] = v + } + if v, err := w.loadValueFromSysTable(gcLeaderLeaseKey); err == nil { + m[tidbGCLeaderLease] = v + } + if v, err := w.loadValueFromSysTable(gcLastRunTimeKey); err == nil { + m[tidbGCLastRunTime] = v + } + if v, err := w.loadValueFromSysTable(gcSafePointKey); err == nil { + m[tidbGCSafePoint] = v + } + return m, nil +} + +func (w *GCWorker) tick(ctx context.Context) { + isLeader, err := w.checkLeader(ctx) + if err != nil { + logutil.Logger(ctx).Warn("check leader", zap.String("category", "gc worker"), zap.Error(err)) + metrics.GCJobFailureCounter.WithLabelValues("check_leader").Inc() + return + } + if isLeader { + err = w.leaderTick(ctx) + if err != nil { + logutil.Logger(ctx).Warn("leader tick", zap.String("category", "gc worker"), zap.Error(err)) + } + } else { + // Config metrics should always be updated by leader, set them to 0 when current instance is not leader. + metrics.GCConfigGauge.WithLabelValues(gcRunIntervalKey).Set(0) + metrics.GCConfigGauge.WithLabelValues(gcLifeTimeKey).Set(0) + } +} + +// getGCSafePoint returns the current gc safe point. +func getGCSafePoint(ctx context.Context, pdClient pd.Client) (uint64, error) { + // If there is try to set gc safepoint is 0, the interface will not set gc safepoint to 0, + // it will return current gc safepoint. + safePoint, err := pdClient.UpdateGCSafePoint(ctx, 0) + if err != nil { + return 0, errors.Trace(err) + } + return safePoint, nil +} + +func (w *GCWorker) logIsGCSafePointTooEarly(ctx context.Context, safePoint uint64) error { + now, err := w.getOracleTime() + if err != nil { + return errors.Trace(err) + } + + checkTs := oracle.GoTimeToTS(now.Add(-gcDefaultLifeTime * 2)) + if checkTs > safePoint { + logutil.Logger(ctx).Info("gc safepoint is too early. "+ + "Maybe there is a bit BR/Lightning/CDC task, "+ + "or a long transaction is running "+ + "or need a tidb without setting keyspace-name to calculate and update gc safe point.", + zap.String("category", "gc worker")) + } + return nil +} + +func (w *GCWorker) runKeyspaceDeleteRange(ctx context.Context, concurrency int) error { + // Get safe point from PD. + // The GC safe point is updated only after the global GC have done resolveLocks phase globally. + // So, in the following code, resolveLocks must have been done by the global GC on the ranges to be deleted, + // so its safe to delete the ranges. 
+ safePoint, err := getGCSafePoint(ctx, w.pdClient) + if err != nil { + logutil.Logger(ctx).Info("get gc safe point error", zap.String("category", "gc worker"), zap.Error(errors.Trace(err))) + return nil + } + + if safePoint == 0 { + logutil.Logger(ctx).Info("skip keyspace delete range, because gc safe point is 0", zap.String("category", "gc worker")) + return nil + } + + err = w.logIsGCSafePointTooEarly(ctx, safePoint) + if err != nil { + logutil.Logger(ctx).Info("log is gc safe point is too early error", zap.String("category", "gc worker"), zap.Error(errors.Trace(err))) + return nil + } + + keyspaceID := w.store.GetCodec().GetKeyspaceID() + logutil.Logger(ctx).Info("start keyspace delete range", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Int("concurrency", concurrency), + zap.Uint32("keyspaceID", uint32(keyspaceID)), + zap.Uint64("GCSafepoint", safePoint)) + + // Do deleteRanges. + err = w.deleteRanges(ctx, safePoint, concurrency) + if err != nil { + logutil.Logger(ctx).Error("delete range returns an error", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Error(err)) + metrics.GCJobFailureCounter.WithLabelValues("delete_range").Inc() + return errors.Trace(err) + } + + // Do redoDeleteRanges. + err = w.redoDeleteRanges(ctx, safePoint, concurrency) + if err != nil { + logutil.Logger(ctx).Error("redo-delete range returns an error", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Error(err)) + metrics.GCJobFailureCounter.WithLabelValues("redo_delete_range").Inc() + return errors.Trace(err) + } + + return nil +} + +// leaderTick of GC worker checks if it should start a GC job every tick. +func (w *GCWorker) leaderTick(ctx context.Context) error { + if w.gcIsRunning { + logutil.Logger(ctx).Info("there's already a gc job running, skipped", zap.String("category", "gc worker"), + zap.String("leaderTick on", w.uuid)) + return nil + } + + concurrency, err := w.getGCConcurrency(ctx) + if err != nil { + logutil.Logger(ctx).Info("failed to get gc concurrency.", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Error(err)) + return errors.Trace(err) + } + + // Gc safe point is not separated by keyspace now. The whole cluster has only one global gc safe point. + // So at least one TiDB with `keyspace-name` not set is required in the whole cluster to calculate and update gc safe point. + // If `keyspace-name` is set, the TiDB node will only do its own delete range, and will not calculate gc safe point and resolve locks. + // Note that when `keyspace-name` is set, `checkLeader` will be done within the key space. + // Therefore only one TiDB node in each key space will be responsible to do delete range. + if w.store.GetCodec().GetKeyspace() != nil { + err = w.runKeyspaceGCJob(ctx, concurrency) + if err != nil { + return errors.Trace(err) + } + return nil + } + + ok, safePoint, err := w.prepare(ctx) + if err != nil { + metrics.GCJobFailureCounter.WithLabelValues("prepare").Inc() + return errors.Trace(err) + } else if !ok { + return nil + } + // When the worker is just started, or an old GC job has just finished, + // wait a while before starting a new job. 
+ if time.Since(w.lastFinish) < gcWaitTime { + logutil.Logger(ctx).Info("another gc job has just finished, skipped.", zap.String("category", "gc worker"), + zap.String("leaderTick on ", w.uuid)) + return nil + } + + w.gcIsRunning = true + logutil.Logger(ctx).Info("starts the whole job", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Uint64("safePoint", safePoint), + zap.Int("concurrency", concurrency)) + go func() { + w.done <- w.runGCJob(ctx, safePoint, concurrency) + }() + return nil +} + +func (w *GCWorker) runKeyspaceGCJob(ctx context.Context, concurrency int) error { + // When the worker is just started, or an old GC job has just finished, + // wait a while before starting a new job. + if time.Since(w.lastFinish) < gcWaitTime { + logutil.Logger(ctx).Info("another keyspace gc job has just finished, skipped.", zap.String("category", "gc worker"), + zap.String("leaderTick on ", w.uuid)) + return nil + } + + now, err := w.getOracleTime() + if err != nil { + return errors.Trace(err) + } + ok, err := w.checkGCInterval(now) + if err != nil || !ok { + return errors.Trace(err) + } + + go func() { + w.done <- w.runKeyspaceDeleteRange(ctx, concurrency) + }() + + err = w.saveTime(gcLastRunTimeKey, now) + if err != nil { + return errors.Trace(err) + } + + return nil +} + +// prepare checks preconditions for starting a GC job. It returns a bool +// that indicates whether the GC job should start and the new safePoint. +func (w *GCWorker) prepare(ctx context.Context) (bool, uint64, error) { + // Add a transaction here is to prevent following situations: + // 1. GC check gcEnable is true, continue to do GC + // 2. The user sets gcEnable to false + // 3. The user gets `tikv_gc_safe_point` value is t1, then the user thinks the data after time t1 won't be clean by GC. + // 4. GC update `tikv_gc_safe_point` value to t2, continue do GC in this round. + // Then the data record that has been dropped between time t1 and t2, will be cleaned by GC, but the user thinks the data after t1 won't be clean by GC. 
+ se := createSession(w.store) + defer se.Close() + _, err := se.ExecuteInternal(ctx, "BEGIN") + if err != nil { + return false, 0, errors.Trace(err) + } + doGC, safePoint, err := w.checkPrepare(ctx) + if doGC { + err = se.CommitTxn(ctx) + if err != nil { + return false, 0, errors.Trace(err) + } + } else { + se.RollbackTxn(ctx) + } + return doGC, safePoint, errors.Trace(err) +} + +func (w *GCWorker) checkPrepare(ctx context.Context) (bool, uint64, error) { + enable, err := w.checkGCEnable() + if err != nil { + return false, 0, errors.Trace(err) + } + + if !enable { + logutil.Logger(ctx).Warn("gc status is disabled.", zap.String("category", "gc worker")) + return false, 0, nil + } + now, err := w.getOracleTime() + if err != nil { + return false, 0, errors.Trace(err) + } + ok, err := w.checkGCInterval(now) + if err != nil || !ok { + return false, 0, errors.Trace(err) + } + newSafePoint, newSafePointValue, err := w.calcNewSafePoint(ctx, now) + if err != nil || newSafePoint == nil { + return false, 0, errors.Trace(err) + } + err = w.saveTime(gcLastRunTimeKey, now) + if err != nil { + return false, 0, errors.Trace(err) + } + err = w.saveTime(gcSafePointKey, *newSafePoint) + if err != nil { + return false, 0, errors.Trace(err) + } + return true, newSafePointValue, nil +} + +func (w *GCWorker) calcGlobalMinStartTS(ctx context.Context) (uint64, error) { + kvs, err := w.tikvStore.GetSafePointKV().GetWithPrefix(infosync.ServerMinStartTSPath) + if err != nil { + return 0, err + } + + var globalMinStartTS uint64 = math.MaxUint64 + for _, v := range kvs { + minStartTS, err := strconv.ParseUint(string(v.Value), 10, 64) + if err != nil { + logutil.Logger(ctx).Warn("parse minStartTS failed", zap.Error(err)) + continue + } + if minStartTS < globalMinStartTS { + globalMinStartTS = minStartTS + } + } + return globalMinStartTS, nil +} + +// calcNewSafePoint uses the current global transaction min start timestamp to calculate the new safe point. +func (w *GCWorker) calcSafePointByMinStartTS(ctx context.Context, safePoint uint64) uint64 { + globalMinStartTS, err := w.calcGlobalMinStartTS(ctx) + if err != nil { + logutil.Logger(ctx).Warn("get all minStartTS failed", zap.Error(err)) + return safePoint + } + + // If the lock.ts <= max_ts(safePoint), it will be collected and resolved by the gc worker, + // the locks of ongoing pessimistic transactions could be resolved by the gc worker and then + // the transaction is aborted, decrement the value by 1 to avoid this. 
+ globalMinStartAllowedTS := globalMinStartTS + if globalMinStartTS > 0 { + globalMinStartAllowedTS = globalMinStartTS - 1 + } + + if globalMinStartAllowedTS < safePoint { + logutil.Logger(ctx).Info("gc safepoint blocked by a running session", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Uint64("globalMinStartTS", globalMinStartTS), + zap.Uint64("globalMinStartAllowedTS", globalMinStartAllowedTS), + zap.Uint64("safePoint", safePoint)) + safePoint = globalMinStartAllowedTS + } + return safePoint +} + +func (w *GCWorker) getOracleTime() (time.Time, error) { + currentVer, err := w.store.CurrentVersion(kv.GlobalTxnScope) + if err != nil { + return time.Time{}, errors.Trace(err) + } + return oracle.GetTimeFromTS(currentVer.Ver), nil +} + +func (w *GCWorker) checkGCEnable() (bool, error) { + return w.loadBooleanWithDefault(gcEnableKey, gcDefaultEnableValue) +} + +func (w *GCWorker) checkUseAutoConcurrency() (bool, error) { + return w.loadBooleanWithDefault(gcAutoConcurrencyKey, gcDefaultAutoConcurrency) +} + +func (w *GCWorker) loadBooleanWithDefault(key string, defaultValue bool) (bool, error) { + str, err := w.loadValueFromSysTable(key) + if err != nil { + return false, errors.Trace(err) + } + if str == "" { + // Save default value for gc enable key. The default value is always true. + defaultValueStr := booleanFalse + if defaultValue { + defaultValueStr = booleanTrue + } + err = w.saveValueToSysTable(key, defaultValueStr) + if err != nil { + return defaultValue, errors.Trace(err) + } + return defaultValue, nil + } + return strings.EqualFold(str, booleanTrue), nil +} + +func (w *GCWorker) getGCConcurrency(ctx context.Context) (int, error) { + useAutoConcurrency, err := w.checkUseAutoConcurrency() + if err != nil { + logutil.Logger(ctx).Error("failed to load config gc_auto_concurrency. use default value.", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Error(err)) + useAutoConcurrency = gcDefaultAutoConcurrency + } + if !useAutoConcurrency { + return w.loadGCConcurrencyWithDefault() + } + + stores, err := w.getStoresForGC(ctx) + concurrency := len(stores) + if err != nil { + logutil.Logger(ctx).Error("failed to get up stores to calculate concurrency. use config.", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Error(err)) + + concurrency, err = w.loadGCConcurrencyWithDefault() + if err != nil { + logutil.Logger(ctx).Error("failed to load gc concurrency from config. 
use default value.", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Error(err)) + concurrency = gcDefaultConcurrency + } + } + + if concurrency == 0 { + logutil.Logger(ctx).Error("no store is up", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid)) + return 0, errors.New("[gc worker] no store is up") + } + + return concurrency, nil +} + +func (w *GCWorker) checkGCInterval(now time.Time) (bool, error) { + runInterval, err := w.loadDurationWithDefault(gcRunIntervalKey, gcDefaultRunInterval) + if err != nil { + return false, errors.Trace(err) + } + metrics.GCConfigGauge.WithLabelValues(gcRunIntervalKey).Set(runInterval.Seconds()) + lastRun, err := w.loadTime(gcLastRunTimeKey) + if err != nil { + return false, errors.Trace(err) + } + + if lastRun != nil && lastRun.Add(*runInterval).After(now) { + logutil.BgLogger().Debug("skipping garbage collection because gc interval hasn't elapsed since last run", zap.String("category", "gc worker"), + zap.String("leaderTick on", w.uuid), + zap.Duration("interval", *runInterval), + zap.Time("last run", *lastRun)) + return false, nil + } + + return true, nil +} + +// validateGCLifeTime checks whether life time is small than min gc life time. +func (w *GCWorker) validateGCLifeTime(lifeTime time.Duration) (time.Duration, error) { + if lifeTime >= gcMinLifeTime { + return lifeTime, nil + } + + logutil.BgLogger().Info("invalid gc life time", zap.String("category", "gc worker"), + zap.Duration("get gc life time", lifeTime), + zap.Duration("min gc life time", gcMinLifeTime)) + + err := w.saveDuration(gcLifeTimeKey, gcMinLifeTime) + return gcMinLifeTime, err +} + +func (w *GCWorker) calcNewSafePoint(ctx context.Context, now time.Time) (*time.Time, uint64, error) { + lifeTime, err := w.loadDurationWithDefault(gcLifeTimeKey, gcDefaultLifeTime) + if err != nil { + return nil, 0, errors.Trace(err) + } + *lifeTime, err = w.validateGCLifeTime(*lifeTime) + if err != nil { + return nil, 0, err + } + metrics.GCConfigGauge.WithLabelValues(gcLifeTimeKey).Set(lifeTime.Seconds()) + + lastSafePoint, err := w.loadTime(gcSafePointKey) + if err != nil { + return nil, 0, errors.Trace(err) + } + + safePointValue := w.calcSafePointByMinStartTS(ctx, oracle.GoTimeToTS(now.Add(-*lifeTime))) + safePointValue, err = w.setGCWorkerServiceSafePoint(ctx, safePointValue) + if err != nil { + return nil, 0, errors.Trace(err) + } + + // safepoint is recorded in time.Time format which strips the logical part of the timestamp. + // To prevent the GC worker from keeping working due to the loss of logical part when the + // safe point isn't changed, we should compare them in time.Time format. + safePoint := oracle.GetTimeFromTS(safePointValue) + // We should never decrease safePoint. + if lastSafePoint != nil && !safePoint.After(*lastSafePoint) { + logutil.BgLogger().Info("last safe point is later than current one."+ + "No need to gc."+ + "This might be caused by manually enlarging gc lifetime", + zap.String("category", "gc worker"), + zap.String("leaderTick on", w.uuid), + zap.Time("last safe point", *lastSafePoint), + zap.Time("current safe point", safePoint)) + return nil, 0, nil + } + return &safePoint, safePointValue, nil +} + +// setGCWorkerServiceSafePoint sets the given safePoint as TiDB's service safePoint to PD, and returns the current minimal +// service safePoint among all services. +func (w *GCWorker) setGCWorkerServiceSafePoint(ctx context.Context, safePoint uint64) (uint64, error) { + // Sets TTL to MAX to make it permanently valid. 
+ minSafePoint, err := w.pdClient.UpdateServiceGCSafePoint(ctx, gcWorkerServiceSafePointID, math.MaxInt64, safePoint) + if err != nil { + logutil.Logger(ctx).Error("failed to update service safe point", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Error(err)) + metrics.GCJobFailureCounter.WithLabelValues("update_service_safe_point").Inc() + return 0, errors.Trace(err) + } + if minSafePoint < safePoint { + logutil.Logger(ctx).Info("there's another service in the cluster requires an earlier safe point. "+ + "gc will continue with the earlier one", + zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Uint64("ourSafePoint", safePoint), + zap.Uint64("minSafePoint", minSafePoint), + ) + safePoint = minSafePoint + } + return safePoint, nil +} + +func (w *GCWorker) runGCJob(ctx context.Context, safePoint uint64, concurrency int) error { + failpoint.Inject("mockRunGCJobFail", func() { + failpoint.Return(errors.New("mock failure of runGCJoB")) + }) + metrics.GCWorkerCounter.WithLabelValues("run_job").Inc() + + err := w.resolveLocks(ctx, safePoint, concurrency) + if err != nil { + logutil.Logger(ctx).Error("resolve locks returns an error", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Error(err)) + metrics.GCJobFailureCounter.WithLabelValues("resolve_lock").Inc() + return errors.Trace(err) + } + + // Save safe point to pd. + err = w.saveSafePoint(w.tikvStore.GetSafePointKV(), safePoint) + if err != nil { + logutil.Logger(ctx).Error("failed to save safe point to PD", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Error(err)) + metrics.GCJobFailureCounter.WithLabelValues("save_safe_point").Inc() + return errors.Trace(err) + } + // Sleep to wait for all other tidb instances update their safepoint cache. + time.Sleep(gcSafePointCacheInterval) + + err = w.deleteRanges(ctx, safePoint, concurrency) + if err != nil { + logutil.Logger(ctx).Error("delete range returns an error", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Error(err)) + metrics.GCJobFailureCounter.WithLabelValues("delete_range").Inc() + return errors.Trace(err) + } + err = w.redoDeleteRanges(ctx, safePoint, concurrency) + if err != nil { + logutil.Logger(ctx).Error("redo-delete range returns an error", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Error(err)) + metrics.GCJobFailureCounter.WithLabelValues("redo_delete_range").Inc() + return errors.Trace(err) + } + + if w.checkUseDistributedGC() { + err = w.uploadSafePointToPD(ctx, safePoint) + if err != nil { + logutil.Logger(ctx).Error("failed to upload safe point to PD", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Error(err)) + metrics.GCJobFailureCounter.WithLabelValues("upload_safe_point").Inc() + return errors.Trace(err) + } + } else { + err = w.doGC(ctx, safePoint, concurrency) + if err != nil { + logutil.Logger(ctx).Error("do GC returns an error", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Error(err)) + metrics.GCJobFailureCounter.WithLabelValues("gc").Inc() + return errors.Trace(err) + } + } + + return nil +} + +// deleteRanges processes all delete range records whose ts < safePoint in table `gc_delete_range` +// `concurrency` specifies the concurrency to send NotifyDeleteRange. 
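+// Each range is destroyed either with an UnsafeDestroyRange request (classic raftstore) or a
+// rangetask delete-range task (raftstore-v2), after which the placement rules and label rules
+// of the dropped objects are garbage collected as well.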
+func (w *GCWorker) deleteRanges(ctx context.Context, safePoint uint64, concurrency int) error { + metrics.GCWorkerCounter.WithLabelValues("delete_range").Inc() + + se := createSession(w.store) + defer se.Close() + ranges, err := util.LoadDeleteRanges(ctx, se, safePoint) + if err != nil { + return errors.Trace(err) + } + + v2, err := util.IsRaftKv2(ctx, se) + if err != nil { + return errors.Trace(err) + } + // Cache table ids on which placement rules have been GC-ed, to avoid redundantly GC the same table id multiple times. + gcPlacementRuleCache := make(map[int64]any, len(ranges)) + + logutil.Logger(ctx).Info("start delete ranges", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Int("ranges", len(ranges))) + startTime := time.Now() + for _, r := range ranges { + startKey, endKey := r.Range() + if v2 { + // In raftstore-v2, we use delete range instead to avoid deletion omission + task := rangetask.NewDeleteRangeTask(w.tikvStore, startKey, endKey, concurrency) + err = task.Execute(ctx) + } else { + err = w.doUnsafeDestroyRangeRequest(ctx, startKey, endKey, concurrency) + } + failpoint.Inject("ignoreDeleteRangeFailed", func() { + err = nil + }) + + if err != nil { + logutil.Logger(ctx).Error("delete range failed on range", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Stringer("startKey", startKey), + zap.Stringer("endKey", endKey), + zap.Error(err)) + continue + } + + if err := w.doGCPlacementRules(se, safePoint, r, gcPlacementRuleCache); err != nil { + logutil.Logger(ctx).Error("gc placement rules failed on range", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Int64("jobID", r.JobID), + zap.Int64("elementID", r.ElementID), + zap.Error(err)) + continue + } + if err := w.doGCLabelRules(r); err != nil { + logutil.Logger(ctx).Error("gc label rules failed on range", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Int64("jobID", r.JobID), + zap.Int64("elementID", r.ElementID), + zap.Error(err)) + continue + } + + err = util.CompleteDeleteRange(se, r, !v2) + if err != nil { + logutil.Logger(ctx).Error("failed to mark delete range task done", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Stringer("startKey", startKey), + zap.Stringer("endKey", endKey), + zap.Error(err)) + metrics.GCUnsafeDestroyRangeFailuresCounterVec.WithLabelValues("save").Inc() + } + } + logutil.Logger(ctx).Info("finish delete ranges", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Int("num of ranges", len(ranges)), + zap.Duration("cost time", time.Since(startTime))) + metrics.GCHistogram.WithLabelValues("delete_ranges").Observe(time.Since(startTime).Seconds()) + return nil +} + +// redoDeleteRanges checks all deleted ranges whose ts is at least `lifetime + 24h` ago. See TiKV RFC #2. +// `concurrency` specifies the concurrency to send NotifyDeleteRange. +func (w *GCWorker) redoDeleteRanges(ctx context.Context, safePoint uint64, concurrency int) error { + metrics.GCWorkerCounter.WithLabelValues("redo_delete_range").Inc() + + // We check delete range records that are deleted about 24 hours ago. 
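+ // oracle.ComposeTS(physicalMs, logical) packs the physical milliseconds into the high bits of
+ // a TSO, so subtracting ComposeTS(gcRedoDeleteRangeDelay in ms, 0) from safePoint moves the
+ // timestamp back by roughly 24 hours of wall-clock time.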
+ redoDeleteRangesTs := safePoint - oracle.ComposeTS(int64(gcRedoDeleteRangeDelay.Seconds())*1000, 0) + + se := createSession(w.store) + ranges, err := util.LoadDoneDeleteRanges(ctx, se, redoDeleteRangesTs) + se.Close() + if err != nil { + return errors.Trace(err) + } + + logutil.Logger(ctx).Info("start redo-delete ranges", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Int("num of ranges", len(ranges))) + startTime := time.Now() + for _, r := range ranges { + startKey, endKey := r.Range() + + err = w.doUnsafeDestroyRangeRequest(ctx, startKey, endKey, concurrency) + if err != nil { + logutil.Logger(ctx).Error("redo-delete range failed on range", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Stringer("startKey", startKey), + zap.Stringer("endKey", endKey), + zap.Error(err)) + continue + } + + se := createSession(w.store) + err := util.DeleteDoneRecord(se, r) + se.Close() + if err != nil { + logutil.Logger(ctx).Error("failed to remove delete_range_done record", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Stringer("startKey", startKey), + zap.Stringer("endKey", endKey), + zap.Error(err)) + metrics.GCUnsafeDestroyRangeFailuresCounterVec.WithLabelValues("save_redo").Inc() + } + } + logutil.Logger(ctx).Info("finish redo-delete ranges", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Int("num of ranges", len(ranges)), + zap.Duration("cost time", time.Since(startTime))) + metrics.GCHistogram.WithLabelValues("redo_delete_ranges").Observe(time.Since(startTime).Seconds()) + return nil +} + +func (w *GCWorker) doUnsafeDestroyRangeRequest(ctx context.Context, startKey []byte, endKey []byte, _ int) error { + // Get all stores every time deleting a region. So the store list is less probably to be stale. 
+ stores, err := w.getStoresForGC(ctx) + if err != nil { + logutil.Logger(ctx).Error("delete ranges: got an error while trying to get store list from PD", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Error(err)) + metrics.GCUnsafeDestroyRangeFailuresCounterVec.WithLabelValues("get_stores").Inc() + return errors.Trace(err) + } + + req := tikvrpc.NewRequest(tikvrpc.CmdUnsafeDestroyRange, &kvrpcpb.UnsafeDestroyRangeRequest{ + StartKey: startKey, + EndKey: endKey, + }, kvrpcpb.Context{DiskFullOpt: kvrpcpb.DiskFullOpt_AllowedOnAlmostFull}) + + var wg sync.WaitGroup + errChan := make(chan error, len(stores)) + + for _, store := range stores { + address := store.Address + storeID := store.Id + wg.Add(1) + go func() { + defer wg.Done() + + resp, err1 := w.tikvStore.GetTiKVClient().SendRequest(ctx, address, req, unsafeDestroyRangeTimeout) + if err1 == nil { + if resp == nil || resp.Resp == nil { + err1 = errors.Errorf("unsafe destroy range returns nil response from store %v", storeID) + } else { + errStr := (resp.Resp.(*kvrpcpb.UnsafeDestroyRangeResponse)).Error + if len(errStr) > 0 { + err1 = errors.Errorf("unsafe destroy range failed on store %v: %s", storeID, errStr) + } + } + } + + if err1 != nil { + metrics.GCUnsafeDestroyRangeFailuresCounterVec.WithLabelValues("send").Inc() + } + errChan <- err1 + }() + } + + var errs []string + for range stores { + err1 := <-errChan + if err1 != nil { + errs = append(errs, err1.Error()) + } + } + + wg.Wait() + + if len(errs) > 0 { + return errors.Errorf("[gc worker] destroy range finished with errors: %v", errs) + } + + return nil +} + +// needsGCOperationForStore checks if the store-level requests related to GC needs to be sent to the store. The store-level +// requests includes UnsafeDestroyRange, PhysicalScanLock, etc. +func needsGCOperationForStore(store *metapb.Store) (bool, error) { + // TombStone means the store has been removed from the cluster and there isn't any peer on the store, so needn't do GC for it. + // Offline means the store is being removed from the cluster and it becomes tombstone after all peers are removed from it, + // so we need to do GC for it. + if store.State == metapb.StoreState_Tombstone { + return false, nil + } + + engineLabel := "" + for _, label := range store.GetLabels() { + if label.GetKey() == placement.EngineLabelKey { + engineLabel = label.GetValue() + break + } + } + + switch engineLabel { + case placement.EngineLabelTiFlash: + // For a TiFlash node, it uses other approach to delete dropped tables, so it's safe to skip sending + // UnsafeDestroyRange requests; it has only learner peers and their data must exist in TiKV, so it's safe to + // skip physical resolve locks for it. + return false, nil + + case placement.EngineLabelTiFlashCompute: + logutil.BgLogger().Debug("will ignore gc tiflash_compute node", zap.String("category", "gc worker")) + return false, nil + + case placement.EngineLabelTiKV, "": + // If no engine label is set, it should be a TiKV node. + return true, nil + + default: + return true, errors.Errorf("unsupported store engine \"%v\" with storeID %v, addr %v", + engineLabel, + store.GetId(), + store.GetAddress()) + } +} + +// getStoresForGC gets the list of stores that needs to be processed during GC. 
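+// The result is filtered through needsGCOperationForStore, so tombstone stores and
+// TiFlash / tiflash_compute nodes are excluded.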
+func (w *GCWorker) getStoresForGC(ctx context.Context) ([]*metapb.Store, error) { + stores, err := w.pdClient.GetAllStores(ctx) + if err != nil { + return nil, errors.Trace(err) + } + + upStores := make([]*metapb.Store, 0, len(stores)) + for _, store := range stores { + needsGCOp, err := needsGCOperationForStore(store) + if err != nil { + return nil, errors.Trace(err) + } + if needsGCOp { + upStores = append(upStores, store) + } + } + return upStores, nil +} + +func (w *GCWorker) getStoresMapForGC(ctx context.Context) (map[uint64]*metapb.Store, error) { + stores, err := w.getStoresForGC(ctx) + if err != nil { + return nil, err + } + + storesMap := make(map[uint64]*metapb.Store, len(stores)) + for _, store := range stores { + storesMap[store.Id] = store + } + + return storesMap, nil +} + +func (w *GCWorker) loadGCConcurrencyWithDefault() (int, error) { + str, err := w.loadValueFromSysTable(gcConcurrencyKey) + if err != nil { + return gcDefaultConcurrency, errors.Trace(err) + } + if str == "" { + err = w.saveValueToSysTable(gcConcurrencyKey, strconv.Itoa(gcDefaultConcurrency)) + if err != nil { + return gcDefaultConcurrency, errors.Trace(err) + } + return gcDefaultConcurrency, nil + } + + jobConcurrency, err := strconv.Atoi(str) + if err != nil { + return gcDefaultConcurrency, err + } + + if jobConcurrency < gcMinConcurrency { + jobConcurrency = gcMinConcurrency + } + + if jobConcurrency > gcMaxConcurrency { + jobConcurrency = gcMaxConcurrency + } + + return jobConcurrency, nil +} + +// Central mode is deprecated in v5.0. This function will always return true. +func (w *GCWorker) checkUseDistributedGC() bool { + mode, err := w.loadValueFromSysTable(gcModeKey) + if err == nil && mode == "" { + err = w.saveValueToSysTable(gcModeKey, gcModeDefault) + } + if err != nil { + logutil.BgLogger().Error("failed to load gc mode, fall back to distributed mode", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Error(err)) + metrics.GCJobFailureCounter.WithLabelValues("check_gc_mode").Inc() + } else if strings.EqualFold(mode, gcModeCentral) { + logutil.BgLogger().Warn("distributed mode will be used as central mode is deprecated", zap.String("category", "gc worker")) + } else if !strings.EqualFold(mode, gcModeDistributed) { + logutil.BgLogger().Warn("distributed mode will be used", zap.String("category", "gc worker"), + zap.String("invalid gc mode", mode)) + } + return true +} + +func (w *GCWorker) resolveLocks( + ctx context.Context, + safePoint uint64, + concurrency int, +) error { + metrics.GCWorkerCounter.WithLabelValues("resolve_locks").Inc() + logutil.Logger(ctx).Info("start resolve locks", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Uint64("safePoint", safePoint), + zap.Int("concurrency", concurrency)) + startTime := time.Now() + + handler := func(ctx context.Context, r tikvstore.KeyRange) (rangetask.TaskStat, error) { + scanLimit := uint32(tikv.GCScanLockLimit) + failpoint.Inject("lowScanLockLimit", func() { + scanLimit = 3 + }) + return tikv.ResolveLocksForRange(ctx, w.regionLockResolver, safePoint, r.StartKey, r.EndKey, tikv.NewGcResolveLockMaxBackoffer, scanLimit) + } + + runner := rangetask.NewRangeTaskRunner("resolve-locks-runner", w.tikvStore, concurrency, handler) + // Run resolve lock on the whole TiKV cluster. Empty keys means the range is unbounded. 
+ err := runner.RunOnRange(ctx, []byte(""), []byte("")) + if err != nil { + logutil.Logger(ctx).Error("resolve locks failed", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Uint64("safePoint", safePoint), + zap.Error(err)) + return errors.Trace(err) + } + + logutil.Logger(ctx).Info("finish resolve locks", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Uint64("safePoint", safePoint), + zap.Int("regions", runner.CompletedRegions())) + metrics.GCHistogram.WithLabelValues("resolve_locks").Observe(time.Since(startTime).Seconds()) + return nil +} + +const gcOneRegionMaxBackoff = 20000 + +func (w *GCWorker) uploadSafePointToPD(ctx context.Context, safePoint uint64) error { + var newSafePoint uint64 + var err error + + bo := tikv.NewBackofferWithVars(ctx, gcOneRegionMaxBackoff, nil) + for { + newSafePoint, err = w.pdClient.UpdateGCSafePoint(ctx, safePoint) + if err != nil { + if errors.Cause(err) == context.Canceled { + return errors.Trace(err) + } + err = bo.Backoff(tikv.BoPDRPC(), errors.Errorf("failed to upload safe point to PD, err: %v", err)) + if err != nil { + return errors.Trace(err) + } + continue + } + break + } + + if newSafePoint != safePoint { + logutil.Logger(ctx).Warn("PD rejected safe point", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Uint64("our safe point", safePoint), + zap.Uint64("using another safe point", newSafePoint)) + return errors.Errorf("PD rejected our safe point %v but is using another safe point %v", safePoint, newSafePoint) + } + logutil.Logger(ctx).Info("sent safe point to PD", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Uint64("safe point", safePoint)) + return nil +} + +func (w *GCWorker) doGCForRange(ctx context.Context, startKey []byte, endKey []byte, safePoint uint64) (rangetask.TaskStat, error) { + var stat rangetask.TaskStat + defer func() { + metrics.GCActionRegionResultCounter.WithLabelValues("success").Add(float64(stat.CompletedRegions)) + metrics.GCActionRegionResultCounter.WithLabelValues("fail").Add(float64(stat.FailedRegions)) + }() + key := startKey + for { + bo := tikv.NewBackofferWithVars(ctx, gcOneRegionMaxBackoff, nil) + loc, err := w.tikvStore.GetRegionCache().LocateKey(bo, key) + if err != nil { + return stat, errors.Trace(err) + } + + var regionErr *errorpb.Error + regionErr, err = w.doGCForRegion(bo, safePoint, loc.Region) + + // we check regionErr here first, because we know 'regionErr' and 'err' should not return together, to keep it to + // make the process correct. + if regionErr != nil { + err = bo.Backoff(tikv.BoRegionMiss(), errors.New(regionErr.String())) + if err == nil { + continue + } + } + + if err != nil { + logutil.BgLogger().Warn("[gc worker]", + zap.String("uuid", w.uuid), + zap.String("gc for range", fmt.Sprintf("[%d, %d)", startKey, endKey)), + zap.Uint64("safePoint", safePoint), + zap.Error(err)) + stat.FailedRegions++ + } else { + stat.CompletedRegions++ + } + + key = loc.EndKey + if len(key) == 0 || bytes.Compare(key, endKey) >= 0 { + break + } + } + + return stat, nil +} + +// doGCForRegion used for gc for region. 
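+// It sends a GC request carrying the given safe point to a single region, returning a region
+// error and an ordinary error;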
+// these two errors should not return together, for more, see the func 'doGC' +func (w *GCWorker) doGCForRegion(bo *tikv.Backoffer, safePoint uint64, region tikv.RegionVerID) (*errorpb.Error, error) { + req := tikvrpc.NewRequest(tikvrpc.CmdGC, &kvrpcpb.GCRequest{ + SafePoint: safePoint, + }) + + resp, err := w.tikvStore.SendReq(bo, req, region, gcTimeout) + if err != nil { + return nil, errors.Trace(err) + } + regionErr, err := resp.GetRegionError() + if err != nil { + return nil, errors.Trace(err) + } + if regionErr != nil { + return regionErr, nil + } + + if resp.Resp == nil { + return nil, errors.Trace(tikverr.ErrBodyMissing) + } + gcResp := resp.Resp.(*kvrpcpb.GCResponse) + if gcResp.GetError() != nil { + return nil, errors.Errorf("unexpected gc error: %s", gcResp.GetError()) + } + + return nil, nil +} + +func (w *GCWorker) doGC(ctx context.Context, safePoint uint64, concurrency int) error { + metrics.GCWorkerCounter.WithLabelValues("do_gc").Inc() + logutil.Logger(ctx).Info("start doing gc for all keys", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Int("concurrency", concurrency), + zap.Uint64("safePoint", safePoint)) + startTime := time.Now() + + runner := rangetask.NewRangeTaskRunner( + "gc-runner", + w.tikvStore, + concurrency, + func(ctx context.Context, r tikvstore.KeyRange) (rangetask.TaskStat, error) { + return w.doGCForRange(ctx, r.StartKey, r.EndKey, safePoint) + }) + + err := runner.RunOnRange(ctx, []byte(""), []byte("")) + if err != nil { + logutil.Logger(ctx).Warn("failed to do gc for all keys", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Int("concurrency", concurrency), + zap.Error(err)) + return errors.Trace(err) + } + + successRegions := runner.CompletedRegions() + failedRegions := runner.FailedRegions() + + logutil.Logger(ctx).Info("finished doing gc for all keys", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid), + zap.Uint64("safePoint", safePoint), + zap.Int("successful regions", successRegions), + zap.Int("failed regions", failedRegions), + zap.Duration("total cost time", time.Since(startTime))) + metrics.GCHistogram.WithLabelValues("do_gc").Observe(time.Since(startTime).Seconds()) + + return nil +} + +func (w *GCWorker) checkLeader(ctx context.Context) (bool, error) { + metrics.GCWorkerCounter.WithLabelValues("check_leader").Inc() + se := createSession(w.store) + defer se.Close() + + _, err := se.ExecuteInternal(ctx, "BEGIN") + if err != nil { + return false, errors.Trace(err) + } + leader, err := w.loadValueFromSysTable(gcLeaderUUIDKey) + if err != nil { + se.RollbackTxn(ctx) + return false, errors.Trace(err) + } + logutil.BgLogger().Debug("got leader", zap.String("category", "gc worker"), zap.String("uuid", leader)) + if leader == w.uuid { + err = w.saveTime(gcLeaderLeaseKey, time.Now().Add(gcWorkerLease)) + if err != nil { + se.RollbackTxn(ctx) + return false, errors.Trace(err) + } + err = se.CommitTxn(ctx) + if err != nil { + return false, errors.Trace(err) + } + return true, nil + } + + se.RollbackTxn(ctx) + + _, err = se.ExecuteInternal(ctx, "BEGIN") + if err != nil { + return false, errors.Trace(err) + } + lease, err := w.loadTime(gcLeaderLeaseKey) + if err != nil { + se.RollbackTxn(ctx) + return false, errors.Trace(err) + } + if lease == nil || lease.Before(time.Now()) { + logutil.BgLogger().Debug("register as leader", zap.String("category", "gc worker"), + zap.String("uuid", w.uuid)) + metrics.GCWorkerCounter.WithLabelValues("register_leader").Inc() + + err = 
w.saveValueToSysTable(gcLeaderUUIDKey, w.uuid) + if err != nil { + se.RollbackTxn(ctx) + return false, errors.Trace(err) + } + err = w.saveValueToSysTable(gcLeaderDescKey, w.desc) + if err != nil { + se.RollbackTxn(ctx) + return false, errors.Trace(err) + } + err = w.saveTime(gcLeaderLeaseKey, time.Now().Add(gcWorkerLease)) + if err != nil { + se.RollbackTxn(ctx) + return false, errors.Trace(err) + } + err = se.CommitTxn(ctx) + if err != nil { + return false, errors.Trace(err) + } + return true, nil + } + se.RollbackTxn(ctx) + return false, nil +} + +func (w *GCWorker) saveSafePoint(kv tikv.SafePointKV, t uint64) error { + s := strconv.FormatUint(t, 10) + err := kv.Put(tikv.GcSavedSafePoint, s) + if err != nil { + logutil.BgLogger().Error("save safepoint failed", zap.Error(err)) + return errors.Trace(err) + } + return nil +} + +func (w *GCWorker) saveTime(key string, t time.Time) error { + err := w.saveValueToSysTable(key, t.Format(tikvutil.GCTimeFormat)) + return errors.Trace(err) +} + +func (w *GCWorker) loadTime(key string) (*time.Time, error) { + str, err := w.loadValueFromSysTable(key) + if err != nil { + return nil, errors.Trace(err) + } + if str == "" { + return nil, nil + } + t, err := tikvutil.CompatibleParseGCTime(str) + if err != nil { + return nil, errors.Trace(err) + } + return &t, nil +} + +func (w *GCWorker) saveDuration(key string, d time.Duration) error { + err := w.saveValueToSysTable(key, d.String()) + return errors.Trace(err) +} + +func (w *GCWorker) loadDuration(key string) (*time.Duration, error) { + str, err := w.loadValueFromSysTable(key) + if err != nil { + return nil, errors.Trace(err) + } + if str == "" { + return nil, nil + } + d, err := time.ParseDuration(str) + if err != nil { + return nil, errors.Trace(err) + } + return &d, nil +} + +func (w *GCWorker) loadDurationWithDefault(key string, def time.Duration) (*time.Duration, error) { + d, err := w.loadDuration(key) + if err != nil { + return nil, errors.Trace(err) + } + if d == nil { + err = w.saveDuration(key, def) + if err != nil { + return nil, errors.Trace(err) + } + return &def, nil + } + return d, nil +} + +func (w *GCWorker) loadValueFromSysTable(key string) (string, error) { + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnGC) + se := createSession(w.store) + defer se.Close() + rs, err := se.ExecuteInternal(ctx, `SELECT HIGH_PRIORITY (variable_value) FROM mysql.tidb WHERE variable_name=%? FOR UPDATE`, key) + if rs != nil { + defer terror.Call(rs.Close) + } + if err != nil { + return "", errors.Trace(err) + } + req := rs.NewChunk(nil) + err = rs.Next(ctx, req) + if err != nil { + return "", errors.Trace(err) + } + if req.NumRows() == 0 { + logutil.BgLogger().Debug("load kv", zap.String("category", "gc worker"), + zap.String("key", key)) + return "", nil + } + value := req.GetRow(0).GetString(0) + logutil.BgLogger().Debug("load kv", zap.String("category", "gc worker"), + zap.String("key", key), + zap.String("value", value)) + return value, nil +} + +func (w *GCWorker) saveValueToSysTable(key, value string) error { + const stmt = `INSERT HIGH_PRIORITY INTO mysql.tidb VALUES (%?, %?, %?) 
+ ON DUPLICATE KEY + UPDATE variable_value = %?, comment = %?` + se := createSession(w.store) + defer se.Close() + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnGC) + _, err := se.ExecuteInternal(ctx, stmt, + key, value, gcVariableComments[key], + value, gcVariableComments[key]) + logutil.BgLogger().Debug("save kv", zap.String("category", "gc worker"), + zap.String("key", key), + zap.String("value", value), + zap.Error(err)) + return errors.Trace(err) +} + +// GC placement rules when the partitions are removed by the GC worker. +// Placement rules cannot be removed immediately after drop table / truncate table, +// because the tables can be flashed back or recovered. +func (w *GCWorker) doGCPlacementRules(se sessiontypes.Session, _ uint64, dr util.DelRangeTask, gcPlacementRuleCache map[int64]any) (err error) { + // Get the job from the job history + var historyJob *model.Job + failpoint.Inject("mockHistoryJobForGC", func(v failpoint.Value) { + args, err1 := json.Marshal([]any{kv.Key{}, []int64{int64(v.(int))}}) + if err1 != nil { + return + } + historyJob = &model.Job{ + ID: dr.JobID, + Type: model.ActionDropTable, + TableID: int64(v.(int)), + RawArgs: args, + } + }) + if historyJob == nil { + historyJob, err = ddl.GetHistoryJobByID(se, dr.JobID) + if err != nil { + return + } + if historyJob == nil { + return dbterror.ErrDDLJobNotFound.GenWithStackByArgs(dr.JobID) + } + } + + // Notify PD to drop the placement rules of partition-ids and table-id, even if there may be no placement rules. + var physicalTableIDs []int64 + switch historyJob.Type { + case model.ActionDropTable, model.ActionTruncateTable: + var startKey kv.Key + if err = historyJob.DecodeArgs(&startKey, &physicalTableIDs); err != nil { + return + } + physicalTableIDs = append(physicalTableIDs, historyJob.TableID) + case model.ActionDropSchema, model.ActionDropTablePartition, model.ActionTruncateTablePartition, + model.ActionReorganizePartition, model.ActionRemovePartitioning, + model.ActionAlterTablePartitioning: + if err = historyJob.DecodeArgs(&physicalTableIDs); err != nil { + return + } + } + + // Skip table ids that's already successfully handled. + tmp := physicalTableIDs[:0] + for _, id := range physicalTableIDs { + if _, ok := gcPlacementRuleCache[id]; !ok { + tmp = append(tmp, id) + } + } + physicalTableIDs = tmp + + if len(physicalTableIDs) == 0 { + return + } + + if err := infosync.DeleteTiFlashPlacementRules(context.Background(), physicalTableIDs); err != nil { + logutil.BgLogger().Error("delete placement rules failed", zap.Error(err), zap.Int64s("tableIDs", physicalTableIDs)) + } + bundles := make([]*placement.Bundle, 0, len(physicalTableIDs)) + for _, id := range physicalTableIDs { + bundles = append(bundles, placement.NewBundle(id)) + } + err = infosync.PutRuleBundlesWithDefaultRetry(context.TODO(), bundles) + if err != nil { + return + } + + // Cache the table id if its related rule are deleted successfully. 
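+ // Later delete-range tasks that reference the same table id can then skip the PD calls above.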
+ for _, id := range physicalTableIDs { + gcPlacementRuleCache[id] = struct{}{} + } + return nil +} + +func (w *GCWorker) doGCLabelRules(dr util.DelRangeTask) (err error) { + // Get the job from the job history + var historyJob *model.Job + failpoint.Inject("mockHistoryJob", func(v failpoint.Value) { + args, err1 := json.Marshal([]any{kv.Key{}, []int64{}, []string{v.(string)}}) + if err1 != nil { + return + } + historyJob = &model.Job{ + ID: dr.JobID, + Type: model.ActionDropTable, + RawArgs: args, + } + }) + if historyJob == nil { + se := createSession(w.store) + historyJob, err = ddl.GetHistoryJobByID(se, dr.JobID) + se.Close() + if err != nil { + return + } + if historyJob == nil { + return dbterror.ErrDDLJobNotFound.GenWithStackByArgs(dr.JobID) + } + } + + if historyJob.Type == model.ActionDropTable { + var ( + startKey kv.Key + physicalTableIDs []int64 + ruleIDs []string + rules map[string]*label.Rule + ) + if err = historyJob.DecodeArgs(&startKey, &physicalTableIDs, &ruleIDs); err != nil { + return + } + + // TODO: Here we need to get rules from PD and filter the rules which is not elegant. We should find a better way. + rules, err = infosync.GetLabelRules(context.TODO(), ruleIDs) + if err != nil { + return + } + + ruleIDs = getGCRules(append(physicalTableIDs, historyJob.TableID), rules) + patch := label.NewRulePatch([]*label.Rule{}, ruleIDs) + err = infosync.UpdateLabelRules(context.TODO(), patch) + } + return +} + +func getGCRules(ids []int64, rules map[string]*label.Rule) []string { + oldRange := make(map[string]struct{}) + for _, id := range ids { + startKey := hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(id))) + endKey := hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(id+1))) + oldRange[startKey+endKey] = struct{}{} + } + + var gcRules []string + for _, rule := range rules { + find := false + for _, d := range rule.Data.([]any) { + if r, ok := d.(map[string]any); ok { + nowRange := fmt.Sprintf("%s%s", r["start_key"], r["end_key"]) + if _, ok := oldRange[nowRange]; ok { + find = true + } + } + } + if find { + gcRules = append(gcRules, rule.ID) + } + } + return gcRules +} + +// RunGCJob sends GC command to KV. It is exported for kv api, do not use it with GCWorker at the same time. +// only use for test +func RunGCJob(ctx context.Context, regionLockResolver tikv.RegionLockResolver, s tikv.Storage, pd pd.Client, safePoint uint64, identifier string, concurrency int) error { + gcWorker := &GCWorker{ + tikvStore: s, + uuid: identifier, + pdClient: pd, + regionLockResolver: regionLockResolver, + } + + if concurrency <= 0 { + return errors.Errorf("[gc worker] gc concurrency should greater than 0, current concurrency: %v", concurrency) + } + + safePoint, err := gcWorker.setGCWorkerServiceSafePoint(ctx, safePoint) + if err != nil { + return errors.Trace(err) + } + + err = gcWorker.resolveLocks(ctx, safePoint, concurrency) + if err != nil { + return errors.Trace(err) + } + + err = gcWorker.saveSafePoint(gcWorker.tikvStore.GetSafePointKV(), safePoint) + if err != nil { + return errors.Trace(err) + } + // Sleep to wait for all other tidb instances update their safepoint cache. + time.Sleep(gcSafePointCacheInterval) + err = gcWorker.doGC(ctx, safePoint, concurrency) + if err != nil { + return errors.Trace(err) + } + return nil +} + +// RunDistributedGCJob notifies TiKVs to do GC. It is exported for kv api, do not use it with GCWorker at the same time. +// This function may not finish immediately because it may take some time to do resolveLocks. 
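+// Unlike RunGCJob, it does not send per-region GC commands; it only uploads the safe point to PD
+// so that TiKV instances garbage collect their own data.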
+// Param concurrency specifies the concurrency of resolveLocks phase. +func RunDistributedGCJob(ctx context.Context, regionLockResolver tikv.RegionLockResolver, s tikv.Storage, pd pd.Client, safePoint uint64, identifier string, concurrency int) error { + gcWorker := &GCWorker{ + tikvStore: s, + uuid: identifier, + pdClient: pd, + regionLockResolver: regionLockResolver, + } + + safePoint, err := gcWorker.setGCWorkerServiceSafePoint(ctx, safePoint) + if err != nil { + return errors.Trace(err) + } + err = gcWorker.resolveLocks(ctx, safePoint, concurrency) + if err != nil { + return errors.Trace(err) + } + + // Save safe point to pd. + err = gcWorker.saveSafePoint(gcWorker.tikvStore.GetSafePointKV(), safePoint) + if err != nil { + return errors.Trace(err) + } + // Sleep to wait for all other tidb instances update their safepoint cache. + time.Sleep(gcSafePointCacheInterval) + + err = gcWorker.uploadSafePointToPD(ctx, safePoint) + if err != nil { + return errors.Trace(err) + } + return nil +} + +// RunResolveLocks resolves all locks before the safePoint. +// It is exported only for test, do not use it in the production environment. +func RunResolveLocks(ctx context.Context, s tikv.Storage, pd pd.Client, safePoint uint64, identifier string, concurrency int) error { + gcWorker := &GCWorker{ + tikvStore: s, + uuid: identifier, + pdClient: pd, + regionLockResolver: tikv.NewRegionLockResolver("test-resolver", s), + } + return gcWorker.resolveLocks(ctx, safePoint, concurrency) +} + +// MockGCWorker is for test. +type MockGCWorker struct { + worker *GCWorker +} + +// NewMockGCWorker creates a MockGCWorker instance ONLY for test. +func NewMockGCWorker(store kv.Storage) (*MockGCWorker, error) { + ver, err := store.CurrentVersion(kv.GlobalTxnScope) + if err != nil { + return nil, errors.Trace(err) + } + hostName, err := os.Hostname() + if err != nil { + hostName = "unknown" + } + worker := &GCWorker{ + uuid: strconv.FormatUint(ver.Ver, 16), + desc: fmt.Sprintf("host:%s, pid:%d, start at %s", hostName, os.Getpid(), time.Now()), + store: store, + tikvStore: store.(tikv.Storage), + gcIsRunning: false, + lastFinish: time.Now(), + done: make(chan error), + pdClient: store.(tikv.Storage).GetRegionCache().PDClient(), + } + return &MockGCWorker{worker: worker}, nil +} + +// DeleteRanges calls deleteRanges internally, just for test. 
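+// A minimal usage sketch for a test, assuming a mockstore-backed kv.Storage named `store` and
+// that this package is imported as `gcworker` (the names here are illustrative, not part of the API):
+//
+//	w, err := gcworker.NewMockGCWorker(store)
+//	if err != nil {
+//		t.Fatal(err)
+//	}
+//	if err := w.DeleteRanges(context.Background(), math.MaxUint64); err != nil {
+//		t.Fatal(err)
+//	}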
+func (w *MockGCWorker) DeleteRanges(ctx context.Context, safePoint uint64) error { + logutil.Logger(ctx).Error("deleteRanges is called") + ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnGC) + return w.worker.deleteRanges(ctx, safePoint, 1) +} diff --git a/pkg/store/mockstore/unistore/binding__failpoint_binding__.go b/pkg/store/mockstore/unistore/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..2394b995d2777 --- /dev/null +++ b/pkg/store/mockstore/unistore/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package unistore + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/store/mockstore/unistore/cophandler/binding__failpoint_binding__.go b/pkg/store/mockstore/unistore/cophandler/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..bea17e070ec6f --- /dev/null +++ b/pkg/store/mockstore/unistore/cophandler/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package cophandler + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/store/mockstore/unistore/cophandler/cop_handler.go b/pkg/store/mockstore/unistore/cophandler/cop_handler.go index 7097f07104476..b26a3cfb7bff4 100644 --- a/pkg/store/mockstore/unistore/cophandler/cop_handler.go +++ b/pkg/store/mockstore/unistore/cophandler/cop_handler.go @@ -147,10 +147,10 @@ func ExecutorListsToTree(exec []*tipb.Executor) *tipb.Executor { func handleCopDAGRequest(dbReader *dbreader.DBReader, lockStore *lockstore.MemStore, req *coprocessor.Request) (resp *coprocessor.Response) { startTime := time.Now() resp = &coprocessor.Response{} - failpoint.Inject("mockCopCacheInUnistore", func(cacheVersion failpoint.Value) { + if cacheVersion, _err_ := failpoint.Eval(_curpkg_("mockCopCacheInUnistore")); _err_ == nil { if req.IsCacheEnabled { if uint64(cacheVersion.(int)) == req.CacheIfMatchVersion { - failpoint.Return(&coprocessor.Response{IsCacheHit: true, CacheLastVersion: uint64(cacheVersion.(int))}) + return &coprocessor.Response{IsCacheHit: true, CacheLastVersion: uint64(cacheVersion.(int))} } else { defer func() { resp.CanBeCached = true @@ -165,7 +165,7 @@ func handleCopDAGRequest(dbReader *dbreader.DBReader, lockStore *lockstore.MemSt }() } } - }) + } dagCtx, dagReq, err := buildDAG(dbReader, lockStore, req) if err != nil { resp.OtherError = err.Error() diff --git a/pkg/store/mockstore/unistore/cophandler/cop_handler.go__failpoint_stash__ b/pkg/store/mockstore/unistore/cophandler/cop_handler.go__failpoint_stash__ new file mode 100644 index 0000000000000..7097f07104476 --- /dev/null +++ b/pkg/store/mockstore/unistore/cophandler/cop_handler.go__failpoint_stash__ @@ -0,0 +1,674 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cophandler + +import ( + "bytes" + "context" + "fmt" + "strings" + "sync" + "time" + + "github.com/golang/protobuf/proto" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/coprocessor" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/tidb/pkg/expression" + "github.com/pingcap/tidb/pkg/expression/aggregation" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/charset" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/store/mockstore/unistore/client" + "github.com/pingcap/tidb/pkg/store/mockstore/unistore/lockstore" + "github.com/pingcap/tidb/pkg/store/mockstore/unistore/tikv/dbreader" + "github.com/pingcap/tidb/pkg/store/mockstore/unistore/tikv/kverrors" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/collate" + contextutil "github.com/pingcap/tidb/pkg/util/context" + "github.com/pingcap/tidb/pkg/util/mock" + "github.com/pingcap/tidb/pkg/util/rowcodec" + "github.com/pingcap/tipb/go-tipb" +) + +var globalLocationMap *locationMap = newLocationMap() + +type locationMap struct { + lmap map[string]*time.Location + mu sync.RWMutex +} + +func newLocationMap() *locationMap { + return &locationMap{ + lmap: make(map[string]*time.Location), + } +} + +func (l *locationMap) getLocation(name string) (*time.Location, bool) { + l.mu.RLock() + defer l.mu.RUnlock() + result, ok := l.lmap[name] + return result, ok +} + +func (l *locationMap) setLocation(name string, value *time.Location) { + l.mu.Lock() + defer l.mu.Unlock() + l.lmap[name] = value +} + +// MPPCtx is the mpp execution context +type MPPCtx struct { + RPCClient client.Client + StoreAddr string + TaskHandler *MPPTaskHandler + Ctx context.Context +} + +// HandleCopRequest handles coprocessor request. +func HandleCopRequest(dbReader *dbreader.DBReader, lockStore *lockstore.MemStore, req *coprocessor.Request) *coprocessor.Response { + return HandleCopRequestWithMPPCtx(dbReader, lockStore, req, nil) +} + +// HandleCopRequestWithMPPCtx handles coprocessor request, actually, this is the updated version for +// HandleCopRequest(after mpp test is supported), however, go does not support function overloading, +// I have to rename it to HandleCopRequestWithMPPCtx. 
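+// The request is dispatched on req.Tp: DAG requests go to handleCopDAGRequest (or HandleMPPDAGReq
+// when an MPP task handler is attached), while analyze and checksum requests have dedicated handlers.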
+func HandleCopRequestWithMPPCtx(dbReader *dbreader.DBReader, lockStore *lockstore.MemStore, req *coprocessor.Request, mppCtx *MPPCtx) *coprocessor.Response { + switch req.Tp { + case kv.ReqTypeDAG: + if mppCtx != nil && mppCtx.TaskHandler != nil { + return HandleMPPDAGReq(dbReader, req, mppCtx) + } + return handleCopDAGRequest(dbReader, lockStore, req) + case kv.ReqTypeAnalyze: + return handleCopAnalyzeRequest(dbReader, req) + case kv.ReqTypeChecksum: + return handleCopChecksumRequest(dbReader, req) + } + return &coprocessor.Response{OtherError: fmt.Sprintf("unsupported request type %d", req.GetTp())} +} + +type dagContext struct { + *evalContext + dbReader *dbreader.DBReader + lockStore *lockstore.MemStore + resolvedLocks []uint64 + dagReq *tipb.DAGRequest + keyRanges []*coprocessor.KeyRange + startTS uint64 +} + +// ExecutorListsToTree converts a list of executors to a tree. +func ExecutorListsToTree(exec []*tipb.Executor) *tipb.Executor { + i := len(exec) - 1 + rootExec := exec[i] + for i--; 0 <= i; i-- { + switch exec[i+1].Tp { + case tipb.ExecType_TypeAggregation: + exec[i+1].Aggregation.Child = exec[i] + case tipb.ExecType_TypeProjection: + exec[i+1].Projection.Child = exec[i] + case tipb.ExecType_TypeTopN: + exec[i+1].TopN.Child = exec[i] + case tipb.ExecType_TypeLimit: + exec[i+1].Limit.Child = exec[i] + case tipb.ExecType_TypeSelection: + exec[i+1].Selection.Child = exec[i] + case tipb.ExecType_TypeStreamAgg: + exec[i+1].Aggregation.Child = exec[i] + default: + panic("unsupported dag executor type") + } + } + return rootExec +} + +// handleCopDAGRequest handles coprocessor DAG request using MPP executors. +func handleCopDAGRequest(dbReader *dbreader.DBReader, lockStore *lockstore.MemStore, req *coprocessor.Request) (resp *coprocessor.Response) { + startTime := time.Now() + resp = &coprocessor.Response{} + failpoint.Inject("mockCopCacheInUnistore", func(cacheVersion failpoint.Value) { + if req.IsCacheEnabled { + if uint64(cacheVersion.(int)) == req.CacheIfMatchVersion { + failpoint.Return(&coprocessor.Response{IsCacheHit: true, CacheLastVersion: uint64(cacheVersion.(int))}) + } else { + defer func() { + resp.CanBeCached = true + resp.CacheLastVersion = uint64(cacheVersion.(int)) + if resp.ExecDetails == nil { + resp.ExecDetails = &kvrpcpb.ExecDetails{TimeDetail: &kvrpcpb.TimeDetail{ProcessWallTimeMs: 500}} + } else if resp.ExecDetails.TimeDetail == nil { + resp.ExecDetails.TimeDetail = &kvrpcpb.TimeDetail{ProcessWallTimeMs: 500} + } else { + resp.ExecDetails.TimeDetail.ProcessWallTimeMs = 500 + } + }() + } + } + }) + dagCtx, dagReq, err := buildDAG(dbReader, lockStore, req) + if err != nil { + resp.OtherError = err.Error() + return resp + } + + exec, chunks, lastRange, counts, ndvs, err := buildAndRunMPPExecutor(dagCtx, dagReq, req.PagingSize) + + sc := dagCtx.sctx.GetSessionVars().StmtCtx + if err != nil { + errMsg := err.Error() + if strings.HasPrefix(errMsg, ErrExecutorNotSupportedMsg) { + resp.OtherError = err.Error() + return resp + } + return genRespWithMPPExec(nil, lastRange, nil, nil, exec, dagReq, err, sc.GetWarnings(), time.Since(startTime)) + } + return genRespWithMPPExec(chunks, lastRange, counts, ndvs, exec, dagReq, err, sc.GetWarnings(), time.Since(startTime)) +} + +func buildAndRunMPPExecutor(dagCtx *dagContext, dagReq *tipb.DAGRequest, pagingSize uint64) (mppExec, []tipb.Chunk, *coprocessor.KeyRange, []int64, []int64, error) { + rootExec := dagReq.RootExecutor + if rootExec == nil { + rootExec = ExecutorListsToTree(dagReq.Executors) + } + + var counts, ndvs []int64 
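+ // counts and ndvs stay nil unless the DAG request sets CollectRangeCounts, in which case one
+ // slot per key range is allocated below.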
+ + if dagReq.GetCollectRangeCounts() { + counts = make([]int64, len(dagCtx.keyRanges)) + ndvs = make([]int64, len(dagCtx.keyRanges)) + } + builder := &mppExecBuilder{ + sctx: dagCtx.sctx, + dbReader: dagCtx.dbReader, + dagReq: dagReq, + dagCtx: dagCtx, + mppCtx: nil, + counts: counts, + ndvs: ndvs, + } + var lastRange *coprocessor.KeyRange + if pagingSize > 0 { + lastRange = &coprocessor.KeyRange{} + builder.paging = lastRange + builder.pagingSize = pagingSize + } + exec, err := builder.buildMPPExecutor(rootExec) + if err != nil { + return nil, nil, nil, nil, nil, err + } + chunks, err := mppExecute(exec, dagCtx, dagReq, pagingSize) + if lastRange != nil && len(lastRange.Start) == 0 && len(lastRange.End) == 0 { + // When should this happen, something is wrong? + lastRange = nil + } + return exec, chunks, lastRange, counts, ndvs, err +} + +func mppExecute(exec mppExec, dagCtx *dagContext, dagReq *tipb.DAGRequest, pagingSize uint64) (chunks []tipb.Chunk, err error) { + err = exec.open() + defer func() { + err := exec.stop() + if err != nil { + panic(err) + } + }() + if err != nil { + return + } + + var totalRows uint64 + var chk *chunk.Chunk + fields := exec.getFieldTypes() + for { + chk, err = exec.next() + if err != nil || chk == nil || chk.NumRows() == 0 { + return + } + + switch dagReq.EncodeType { + case tipb.EncodeType_TypeDefault: + chunks, err = useDefaultEncoding(chk, dagCtx, dagReq, fields, chunks) + case tipb.EncodeType_TypeChunk: + chunks = useChunkEncoding(chk, dagReq, fields, chunks) + if pagingSize > 0 { + totalRows += uint64(chk.NumRows()) + if totalRows > pagingSize { + return + } + } + default: + err = fmt.Errorf("unsupported DAG request encode type %s", dagReq.EncodeType) + } + if err != nil { + return + } + } +} + +func useDefaultEncoding(chk *chunk.Chunk, dagCtx *dagContext, dagReq *tipb.DAGRequest, + fields []*types.FieldType, chunks []tipb.Chunk) ([]tipb.Chunk, error) { + var buf []byte + var datums []types.Datum + var err error + numRows := chk.NumRows() + sc := dagCtx.sctx.GetSessionVars().StmtCtx + errCtx := sc.ErrCtx() + for i := 0; i < numRows; i++ { + datums = datums[:0] + if dagReq.OutputOffsets != nil { + for _, j := range dagReq.OutputOffsets { + datums = append(datums, chk.GetRow(i).GetDatum(int(j), fields[j])) + } + } else { + for j, ft := range fields { + datums = append(datums, chk.GetRow(i).GetDatum(j, ft)) + } + } + buf, err = codec.EncodeValue(sc.TimeZone(), buf[:0], datums...) 
+ err = errCtx.HandleError(err) + if err != nil { + return nil, errors.Trace(err) + } + chunks = appendRow(chunks, buf, i) + } + return chunks, nil +} + +func useChunkEncoding(chk *chunk.Chunk, dagReq *tipb.DAGRequest, fields []*types.FieldType, chunks []tipb.Chunk) []tipb.Chunk { + if dagReq.OutputOffsets != nil { + offsets := make([]int, len(dagReq.OutputOffsets)) + newFields := make([]*types.FieldType, len(dagReq.OutputOffsets)) + for i := 0; i < len(dagReq.OutputOffsets); i++ { + offset := dagReq.OutputOffsets[i] + offsets[i] = int(offset) + newFields[i] = fields[offset] + } + chk = chk.Prune(offsets) + fields = newFields + } + + c := chunk.NewCodec(fields) + buffer := c.Encode(chk) + chunks = append(chunks, tipb.Chunk{ + RowsData: buffer, + }) + return chunks +} + +func buildDAG(reader *dbreader.DBReader, lockStore *lockstore.MemStore, req *coprocessor.Request) (*dagContext, *tipb.DAGRequest, error) { + if len(req.Ranges) == 0 { + return nil, nil, errors.New("request range is null") + } + if req.GetTp() != kv.ReqTypeDAG { + return nil, nil, errors.Errorf("unsupported request type %d", req.GetTp()) + } + + dagReq := new(tipb.DAGRequest) + err := proto.Unmarshal(req.Data, dagReq) + if err != nil { + return nil, nil, errors.Trace(err) + } + var tz *time.Location + switch dagReq.TimeZoneName { + case "": + tz = time.FixedZone("UTC", int(dagReq.TimeZoneOffset)) + case "System": + tz = time.Local + default: + var ok bool + tz, ok = globalLocationMap.getLocation(dagReq.TimeZoneName) + if !ok { + tz, err = time.LoadLocation(dagReq.TimeZoneName) + if err != nil { + return nil, nil, errors.Trace(err) + } + globalLocationMap.setLocation(dagReq.TimeZoneName, tz) + } + } + sctx := flagsAndTzToSessionContext(dagReq.Flags, tz) + if dagReq.DivPrecisionIncrement != nil { + sctx.GetSessionVars().DivPrecisionIncrement = int(*dagReq.DivPrecisionIncrement) + } else { + sctx.GetSessionVars().DivPrecisionIncrement = variable.DefDivPrecisionIncrement + } + ctx := &dagContext{ + evalContext: &evalContext{sctx: sctx}, + dbReader: reader, + lockStore: lockStore, + dagReq: dagReq, + keyRanges: req.Ranges, + startTS: req.StartTs, + resolvedLocks: req.Context.ResolvedLocks, + } + return ctx, dagReq, err +} + +func getAggInfo(ctx *dagContext, pbAgg *tipb.Aggregation) ([]aggregation.Aggregation, []expression.Expression, error) { + length := len(pbAgg.AggFunc) + aggs := make([]aggregation.Aggregation, 0, length) + var err error + for _, expr := range pbAgg.AggFunc { + var aggExpr aggregation.Aggregation + aggExpr, _, err = aggregation.NewDistAggFunc(expr, ctx.fieldTps, ctx.sctx.GetExprCtx()) + if err != nil { + return nil, nil, errors.Trace(err) + } + aggs = append(aggs, aggExpr) + } + groupBys, err := convertToExprs(ctx.sctx, ctx.fieldTps, pbAgg.GetGroupBy()) + if err != nil { + return nil, nil, errors.Trace(err) + } + + return aggs, groupBys, nil +} + +func getTopNInfo(ctx *evalContext, topN *tipb.TopN) (heap *topNHeap, conds []expression.Expression, err error) { + pbConds := make([]*tipb.Expr, len(topN.OrderBy)) + for i, item := range topN.OrderBy { + pbConds[i] = item.Expr + } + heap = &topNHeap{ + totalCount: int(topN.Limit), + topNSorter: topNSorter{ + orderByItems: topN.OrderBy, + sc: ctx.sctx.GetSessionVars().StmtCtx, + }, + } + if conds, err = convertToExprs(ctx.sctx, ctx.fieldTps, pbConds); err != nil { + return nil, nil, errors.Trace(err) + } + + return heap, conds, nil +} + +type evalContext struct { + columnInfos []*tipb.ColumnInfo + fieldTps []*types.FieldType + primaryCols []int64 + sctx 
sessionctx.Context +} + +func (e *evalContext) setColumnInfo(cols []*tipb.ColumnInfo) { + e.columnInfos = make([]*tipb.ColumnInfo, len(cols)) + copy(e.columnInfos, cols) + + e.fieldTps = make([]*types.FieldType, 0, len(e.columnInfos)) + for _, col := range e.columnInfos { + ft := fieldTypeFromPBColumn(col) + e.fieldTps = append(e.fieldTps, ft) + } +} + +func newRowDecoder(columnInfos []*tipb.ColumnInfo, fieldTps []*types.FieldType, primaryCols []int64, timeZone *time.Location) (*rowcodec.ChunkDecoder, error) { + var ( + pkCols []int64 + cols = make([]rowcodec.ColInfo, 0, len(columnInfos)) + ) + for i := range columnInfos { + info := columnInfos[i] + if info.ColumnId == model.ExtraPhysTblID { + // Skip since it needs to be filled in from the key + continue + } + ft := fieldTps[i] + col := rowcodec.ColInfo{ + ID: info.ColumnId, + Ft: ft, + IsPKHandle: info.PkHandle, + } + cols = append(cols, col) + if info.PkHandle { + pkCols = append(pkCols, info.ColumnId) + } + } + if len(pkCols) == 0 { + if primaryCols != nil { + pkCols = primaryCols + } else { + pkCols = []int64{-1} + } + } + def := func(i int, chk *chunk.Chunk) error { + info := columnInfos[i] + if info.PkHandle || len(info.DefaultVal) == 0 { + chk.AppendNull(i) + return nil + } + decoder := codec.NewDecoder(chk, timeZone) + _, err := decoder.DecodeOne(info.DefaultVal, i, fieldTps[i]) + if err != nil { + return err + } + return nil + } + return rowcodec.NewChunkDecoder(cols, pkCols, def, timeZone), nil +} + +// flagsAndTzToSessionContext creates a sessionctx.Context from a `tipb.SelectRequest.Flags`. +func flagsAndTzToSessionContext(flags uint64, tz *time.Location) sessionctx.Context { + sc := stmtctx.NewStmtCtx() + sc.InitFromPBFlagAndTz(flags, tz) + sctx := mock.NewContext() + sctx.GetSessionVars().StmtCtx = sc + sctx.GetSessionVars().TimeZone = tz + return sctx +} + +// ErrLocked is returned when trying to Read/Write on a locked key. Client should +// backoff or cleanup the lock then retry. +type ErrLocked struct { + Key []byte + Primary []byte + StartTS uint64 + TTL uint64 + LockType uint8 +} + +// BuildLockErr generates ErrKeyLocked objects +func BuildLockErr(key []byte, primaryKey []byte, startTS uint64, TTL uint64, lockType uint8) *ErrLocked { + errLocked := &ErrLocked{ + Key: key, + Primary: primaryKey, + StartTS: startTS, + TTL: TTL, + LockType: lockType, + } + return errLocked +} + +// Error formats the lock to a string. 
+func (e *ErrLocked) Error() string { + return fmt.Sprintf("key is locked, key: %q, Type: %v, primary: %q, startTS: %v", e.Key, e.LockType, e.Primary, e.StartTS) +} + +func genRespWithMPPExec(chunks []tipb.Chunk, lastRange *coprocessor.KeyRange, counts, ndvs []int64, exec mppExec, dagReq *tipb.DAGRequest, err error, warnings []contextutil.SQLWarn, dur time.Duration) *coprocessor.Response { + resp := &coprocessor.Response{ + Range: lastRange, + } + selResp := &tipb.SelectResponse{ + Error: toPBError(err), + Chunks: chunks, + OutputCounts: counts, + Ndvs: ndvs, + EncodeType: dagReq.EncodeType, + } + executors := dagReq.Executors + if dagReq.CollectExecutionSummaries != nil && *dagReq.CollectExecutionSummaries { + // for simplicity, we assume all executors to be spending the same amount of time as the request + timeProcessed := uint64(dur / time.Nanosecond) + execSummary := make([]*tipb.ExecutorExecutionSummary, len(executors)) + e := exec + for i := len(executors) - 1; 0 <= i; i-- { + execSummary[i] = e.buildSummary() + execSummary[i].TimeProcessedNs = &timeProcessed + if i != 0 { + e = exec.child() + } + } + selResp.ExecutionSummaries = execSummary + } + if len(warnings) > 0 { + selResp.Warnings = make([]*tipb.Error, 0, len(warnings)) + for i := range warnings { + selResp.Warnings = append(selResp.Warnings, toPBError(warnings[i].Err)) + } + } + if locked, ok := errors.Cause(err).(*ErrLocked); ok { + resp.Locked = &kvrpcpb.LockInfo{ + Key: locked.Key, + PrimaryLock: locked.Primary, + LockVersion: locked.StartTS, + LockTtl: locked.TTL, + } + } + resp.ExecDetails = &kvrpcpb.ExecDetails{ + TimeDetail: &kvrpcpb.TimeDetail{ProcessWallTimeMs: uint64(dur / time.Millisecond)}, + } + resp.ExecDetailsV2 = &kvrpcpb.ExecDetailsV2{ + TimeDetail: resp.ExecDetails.TimeDetail, + } + data, mErr := proto.Marshal(selResp) + if mErr != nil { + resp.OtherError = mErr.Error() + return resp + } + resp.Data = data + if err != nil { + if conflictErr, ok := errors.Cause(err).(*kverrors.ErrConflict); ok { + resp.OtherError = conflictErr.Error() + } + } + return resp +} + +func toPBError(err error) *tipb.Error { + if err == nil { + return nil + } + perr := new(tipb.Error) + e := errors.Cause(err) + switch y := e.(type) { + case *terror.Error: + tmp := terror.ToSQLError(y) + perr.Code = int32(tmp.Code) + perr.Msg = tmp.Message + case *mysql.SQLError: + perr.Code = int32(y.Code) + perr.Msg = y.Message + default: + perr.Code = int32(1) + perr.Msg = err.Error() + } + return perr +} + +// extractKVRanges extracts kv.KeyRanges slice from a SelectRequest. 
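+// Each key range is clipped to [startKey, endKey); ranges that fall completely outside that window
+// are skipped, and the result is reversed when descScan is true.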
+func extractKVRanges(startKey, endKey []byte, keyRanges []*coprocessor.KeyRange, descScan bool) (kvRanges []kv.KeyRange, err error) { + kvRanges = make([]kv.KeyRange, 0, len(keyRanges)) + for _, kran := range keyRanges { + if bytes.Compare(kran.GetStart(), kran.GetEnd()) >= 0 { + err = errors.Errorf("invalid range, start should be smaller than end: %v %v", kran.GetStart(), kran.GetEnd()) + return + } + + upperKey := kran.GetEnd() + if bytes.Compare(upperKey, startKey) <= 0 { + continue + } + lowerKey := kran.GetStart() + if len(endKey) != 0 && bytes.Compare(lowerKey, endKey) >= 0 { + break + } + r := kv.KeyRange{ + StartKey: kv.Key(maxStartKey(lowerKey, startKey)), + EndKey: kv.Key(minEndKey(upperKey, endKey)), + } + kvRanges = append(kvRanges, r) + } + if descScan { + reverseKVRanges(kvRanges) + } + return +} + +func reverseKVRanges(kvRanges []kv.KeyRange) { + for i := 0; i < len(kvRanges)/2; i++ { + j := len(kvRanges) - i - 1 + kvRanges[i], kvRanges[j] = kvRanges[j], kvRanges[i] + } +} + +func maxStartKey(rangeStartKey kv.Key, regionStartKey []byte) []byte { + if bytes.Compare([]byte(rangeStartKey), regionStartKey) > 0 { + return []byte(rangeStartKey) + } + return regionStartKey +} + +func minEndKey(rangeEndKey kv.Key, regionEndKey []byte) []byte { + if len(regionEndKey) == 0 || bytes.Compare([]byte(rangeEndKey), regionEndKey) < 0 { + return []byte(rangeEndKey) + } + return regionEndKey +} + +const rowsPerChunk = 64 + +func appendRow(chunks []tipb.Chunk, data []byte, rowCnt int) []tipb.Chunk { + if rowCnt%rowsPerChunk == 0 { + chunks = append(chunks, tipb.Chunk{}) + } + cur := &chunks[len(chunks)-1] + cur.RowsData = append(cur.RowsData, data...) + return chunks +} + +// fieldTypeFromPBColumn creates a types.FieldType from tipb.ColumnInfo. +func fieldTypeFromPBColumn(col *tipb.ColumnInfo) *types.FieldType { + charsetStr, collationStr, _ := charset.GetCharsetInfoByID(int(collate.RestoreCollationIDIfNeeded(col.GetCollation()))) + ft := &types.FieldType{} + ft.SetType(byte(col.GetTp())) + ft.SetFlag(uint(col.GetFlag())) + ft.SetFlen(int(col.GetColumnLen())) + ft.SetDecimal(int(col.GetDecimal())) + ft.SetElems(col.Elems) + ft.SetCharset(charsetStr) + ft.SetCollate(collationStr) + return ft +} + +// handleCopChecksumRequest handles coprocessor check sum request. +func handleCopChecksumRequest(dbReader *dbreader.DBReader, req *coprocessor.Request) *coprocessor.Response { + resp := &tipb.ChecksumResponse{ + Checksum: 1, + TotalKvs: 1, + TotalBytes: 1, + } + data, err := resp.Marshal() + if err != nil { + return &coprocessor.Response{OtherError: fmt.Sprintf("marshal checksum response error: %v", err)} + } + return &coprocessor.Response{Data: data} +} diff --git a/pkg/store/mockstore/unistore/rpc.go b/pkg/store/mockstore/unistore/rpc.go index b255138dff63b..d22ab62ccc6f8 100644 --- a/pkg/store/mockstore/unistore/rpc.go +++ b/pkg/store/mockstore/unistore/rpc.go @@ -62,41 +62,41 @@ var UnistoreRPCClientSendHook atomic.Pointer[func(*tikvrpc.Request)] // SendRequest sends a request to mock cluster. 
func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error) { - failpoint.Inject("rpcServerBusy", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("rpcServerBusy")); _err_ == nil { if val.(bool) { - failpoint.Return(tikvrpc.GenRegionErrorResp(req, &errorpb.Error{ServerIsBusy: &errorpb.ServerIsBusy{}})) + return tikvrpc.GenRegionErrorResp(req, &errorpb.Error{ServerIsBusy: &errorpb.ServerIsBusy{}}) } - }) - failpoint.Inject("epochNotMatch", func(val failpoint.Value) { + } + if val, _err_ := failpoint.Eval(_curpkg_("epochNotMatch")); _err_ == nil { if val.(bool) { - failpoint.Return(tikvrpc.GenRegionErrorResp(req, &errorpb.Error{EpochNotMatch: &errorpb.EpochNotMatch{}})) + return tikvrpc.GenRegionErrorResp(req, &errorpb.Error{EpochNotMatch: &errorpb.EpochNotMatch{}}) } - }) + } - failpoint.Inject("unistoreRPCClientSendHook", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("unistoreRPCClientSendHook")); _err_ == nil { if fn := UnistoreRPCClientSendHook.Load(); val.(bool) && fn != nil { (*fn)(req) } - }) + } - failpoint.Inject("rpcTiKVAllowedOnAlmostFull", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("rpcTiKVAllowedOnAlmostFull")); _err_ == nil { if val.(bool) { if req.Type == tikvrpc.CmdPrewrite || req.Type == tikvrpc.CmdCommit { if req.Context.DiskFullOpt != kvrpcpb.DiskFullOpt_AllowedOnAlmostFull { - failpoint.Return(tikvrpc.GenRegionErrorResp(req, &errorpb.Error{DiskFull: &errorpb.DiskFull{StoreId: []uint64{1}, Reason: "disk full"}})) + return tikvrpc.GenRegionErrorResp(req, &errorpb.Error{DiskFull: &errorpb.DiskFull{StoreId: []uint64{1}, Reason: "disk full"}}) } } } - }) - failpoint.Inject("unistoreRPCDeadlineExceeded", func(val failpoint.Value) { + } + if val, _err_ := failpoint.Eval(_curpkg_("unistoreRPCDeadlineExceeded")); _err_ == nil { if val.(bool) && timeout < time.Second { - failpoint.Return(tikvrpc.GenRegionErrorResp(req, &errorpb.Error{Message: "Deadline is exceeded"})) + return tikvrpc.GenRegionErrorResp(req, &errorpb.Error{Message: "Deadline is exceeded"}) } - }) - failpoint.Inject("unistoreRPCSlowByInjestSleep", func(val failpoint.Value) { + } + if val, _err_ := failpoint.Eval(_curpkg_("unistoreRPCSlowByInjestSleep")); _err_ == nil { time.Sleep(time.Duration(val.(int) * int(time.Millisecond))) - failpoint.Return(tikvrpc.GenRegionErrorResp(req, &errorpb.Error{Message: "Deadline is exceeded"})) - }) + return tikvrpc.GenRegionErrorResp(req, &errorpb.Error{Message: "Deadline is exceeded"}) + } select { case <-ctx.Done(): @@ -127,10 +127,10 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R resp.Resp, err = c.usSvr.KvGet(ctx, req.Get()) case tikvrpc.CmdScan: kvScanReq := req.Scan() - failpoint.Inject("rpcScanResult", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("rpcScanResult")); _err_ == nil { switch val.(string) { case "keyError": - failpoint.Return(&tikvrpc.Response{ + return &tikvrpc.Response{ Resp: &kvrpcpb.ScanResponse{Error: &kvrpcpb.KeyError{ Locked: &kvrpcpb.LockInfo{ PrimaryLock: kvScanReq.StartKey, @@ -141,38 +141,38 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R LockType: kvrpcpb.Op_Put, }, }}, - }, nil) + }, nil } - }) + } resp.Resp, err = c.usSvr.KvScan(ctx, kvScanReq) case tikvrpc.CmdPrewrite: - failpoint.Inject("rpcPrewriteResult", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("rpcPrewriteResult")); 
_err_ == nil { if val != nil { switch val.(string) { case "timeout": - failpoint.Return(nil, errors.New("timeout")) + return nil, errors.New("timeout") case "notLeader": - failpoint.Return(&tikvrpc.Response{ + return &tikvrpc.Response{ Resp: &kvrpcpb.PrewriteResponse{RegionError: &errorpb.Error{NotLeader: &errorpb.NotLeader{}}}, - }, nil) + }, nil case "writeConflict": - failpoint.Return(&tikvrpc.Response{ + return &tikvrpc.Response{ Resp: &kvrpcpb.PrewriteResponse{Errors: []*kvrpcpb.KeyError{{Conflict: &kvrpcpb.WriteConflict{}}}}, - }, nil) + }, nil } } - }) + } r := req.Prewrite() c.cluster.handleDelay(r.StartVersion, r.Context.RegionId) resp.Resp, err = c.usSvr.KvPrewrite(ctx, r) - failpoint.Inject("rpcPrewriteTimeout", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("rpcPrewriteTimeout")); _err_ == nil { if val.(bool) { - failpoint.Return(nil, undeterminedErr) + return nil, undeterminedErr } - }) + } case tikvrpc.CmdPessimisticLock: r := req.PessimisticLock() c.cluster.handleDelay(r.StartVersion, r.Context.RegionId) @@ -180,28 +180,28 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R case tikvrpc.CmdPessimisticRollback: resp.Resp, err = c.usSvr.KVPessimisticRollback(ctx, req.PessimisticRollback()) case tikvrpc.CmdCommit: - failpoint.Inject("rpcCommitResult", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("rpcCommitResult")); _err_ == nil { switch val.(string) { case "timeout": - failpoint.Return(nil, errors.New("timeout")) + return nil, errors.New("timeout") case "notLeader": - failpoint.Return(&tikvrpc.Response{ + return &tikvrpc.Response{ Resp: &kvrpcpb.CommitResponse{RegionError: &errorpb.Error{NotLeader: &errorpb.NotLeader{}}}, - }, nil) + }, nil case "keyError": - failpoint.Return(&tikvrpc.Response{ + return &tikvrpc.Response{ Resp: &kvrpcpb.CommitResponse{Error: &kvrpcpb.KeyError{}}, - }, nil) + }, nil } - }) + } resp.Resp, err = c.usSvr.KvCommit(ctx, req.Commit()) - failpoint.Inject("rpcCommitTimeout", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("rpcCommitTimeout")); _err_ == nil { if val.(bool) { - failpoint.Return(nil, undeterminedErr) + return nil, undeterminedErr } - }) + } case tikvrpc.CmdCleanup: resp.Resp, err = c.usSvr.KvCleanup(ctx, req.Cleanup()) case tikvrpc.CmdCheckTxnStatus: @@ -212,10 +212,10 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R resp.Resp, err = c.usSvr.KvTxnHeartBeat(ctx, req.TxnHeartBeat()) case tikvrpc.CmdBatchGet: batchGetReq := req.BatchGet() - failpoint.Inject("rpcBatchGetResult", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("rpcBatchGetResult")); _err_ == nil { switch val.(string) { case "keyError": - failpoint.Return(&tikvrpc.Response{ + return &tikvrpc.Response{ Resp: &kvrpcpb.BatchGetResponse{Error: &kvrpcpb.KeyError{ Locked: &kvrpcpb.LockInfo{ PrimaryLock: batchGetReq.Keys[0], @@ -226,9 +226,9 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R LockType: kvrpcpb.Op_Put, }, }}, - }, nil) + }, nil } - }) + } resp.Resp, err = c.usSvr.KvBatchGet(ctx, batchGetReq) case tikvrpc.CmdBatchRollback: @@ -262,41 +262,41 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R case tikvrpc.CmdCopStream: resp.Resp, err = c.handleCopStream(ctx, req.Cop()) case tikvrpc.CmdBatchCop: - failpoint.Inject("BatchCopCancelled", func(value failpoint.Value) { + if value, _err_ := failpoint.Eval(_curpkg_("BatchCopCancelled")); _err_ == nil { if 
value.(bool) { - failpoint.Return(nil, context.Canceled) + return nil, context.Canceled } - }) + } - failpoint.Inject("BatchCopRpcErr"+addr, func(value failpoint.Value) { + if value, _err_ := failpoint.Eval(_curpkg_("BatchCopRpcErr" + addr)); _err_ == nil { if value.(string) == addr { - failpoint.Return(nil, errors.New("rpc error")) + return nil, errors.New("rpc error") } - }) + } resp.Resp, err = c.handleBatchCop(ctx, req.BatchCop(), timeout) case tikvrpc.CmdMPPConn: - failpoint.Inject("mppConnTimeout", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mppConnTimeout")); _err_ == nil { if val.(bool) { - failpoint.Return(nil, errors.New("rpc error")) + return nil, errors.New("rpc error") } - }) - failpoint.Inject("MppVersionError", func(val failpoint.Value) { + } + if val, _err_ := failpoint.Eval(_curpkg_("MppVersionError")); _err_ == nil { if v := int64(val.(int)); v > req.EstablishMPPConn().GetReceiverMeta().GetMppVersion() || v > req.EstablishMPPConn().GetSenderMeta().GetMppVersion() { - failpoint.Return(nil, context.Canceled) + return nil, context.Canceled } - }) + } resp.Resp, err = c.handleEstablishMPPConnection(ctx, req.EstablishMPPConn(), timeout, storeID) case tikvrpc.CmdMPPTask: - failpoint.Inject("mppDispatchTimeout", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mppDispatchTimeout")); _err_ == nil { if val.(bool) { - failpoint.Return(nil, errors.New("rpc error")) + return nil, errors.New("rpc error") } - }) - failpoint.Inject("MppVersionError", func(val failpoint.Value) { + } + if val, _err_ := failpoint.Eval(_curpkg_("MppVersionError")); _err_ == nil { if v := int64(val.(int)); v > req.DispatchMPPTask().GetMeta().GetMppVersion() { - failpoint.Return(nil, context.Canceled) + return nil, context.Canceled } - }) + } resp.Resp, err = c.handleDispatchMPPTask(ctx, req.DispatchMPPTask(), storeID) case tikvrpc.CmdMPPCancel: case tikvrpc.CmdMvccGetByKey: @@ -367,11 +367,11 @@ func (c *RPCClient) handleEstablishMPPConnection(ctx context.Context, r *mpp.Est if err != nil { return nil, err } - failpoint.Inject("establishMppConnectionErr", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("establishMppConnectionErr")); _err_ == nil { if val.(bool) { - failpoint.Return(nil, errors.New("rpc error")) + return nil, errors.New("rpc error") } - }) + } var mockClient = mockMPPConnectionClient{mppResponses: mockServer.mppResponses, idx: 0, ctx: ctx, targetTask: r.ReceiverMeta} streamResp := &tikvrpc.MPPStreamResponse{Tikv_EstablishMPPConnectionClient: &mockClient} _, cancel := context.WithCancel(ctx) @@ -510,11 +510,11 @@ func (mock *mockBatchCopClient) Recv() (*coprocessor.BatchResponse, error) { } return ret, err } - failpoint.Inject("batchCopRecvTimeout", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("batchCopRecvTimeout")); _err_ == nil { if val.(bool) { - failpoint.Return(nil, context.Canceled) + return nil, context.Canceled } - }) + } return nil, io.EOF } @@ -532,23 +532,23 @@ func (mock *mockMPPConnectionClient) Recv() (*mpp.MPPDataPacket, error) { mock.idx++ return ret, nil } - failpoint.Inject("mppRecvTimeout", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mppRecvTimeout")); _err_ == nil { if int64(val.(int)) == mock.targetTask.TaskId { - failpoint.Return(nil, context.Canceled) + return nil, context.Canceled } - }) - failpoint.Inject("mppRecvHang", func(val failpoint.Value) { + } + if val, _err_ := failpoint.Eval(_curpkg_("mppRecvHang")); _err_ == nil { for val.(bool) { 
select { case <-mock.ctx.Done(): { - failpoint.Return(nil, context.Canceled) + return nil, context.Canceled } default: time.Sleep(1 * time.Second) } } - }) + } return nil, io.EOF } diff --git a/pkg/store/mockstore/unistore/rpc.go__failpoint_stash__ b/pkg/store/mockstore/unistore/rpc.go__failpoint_stash__ new file mode 100644 index 0000000000000..b255138dff63b --- /dev/null +++ b/pkg/store/mockstore/unistore/rpc.go__failpoint_stash__ @@ -0,0 +1,582 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package unistore + +import ( + "context" + "io" + "math" + "os" + "strconv" + "sync/atomic" + "time" + + "github.com/golang/protobuf/proto" + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/kvproto/pkg/coprocessor" + "github.com/pingcap/kvproto/pkg/debugpb" + "github.com/pingcap/kvproto/pkg/errorpb" + "github.com/pingcap/kvproto/pkg/kvrpcpb" + "github.com/pingcap/kvproto/pkg/metapb" + "github.com/pingcap/kvproto/pkg/mpp" + "github.com/pingcap/tidb/pkg/parser/terror" + us "github.com/pingcap/tidb/pkg/store/mockstore/unistore/tikv" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/tikv/client-go/v2/tikv" + "github.com/tikv/client-go/v2/tikvrpc" + "google.golang.org/grpc/metadata" +) + +// For gofail injection. +var undeterminedErr = terror.ErrResultUndetermined + +// RPCClient sends kv RPC calls to mock cluster. RPCClient mocks the behavior of +// a rpc client at tikv's side. +type RPCClient struct { + usSvr *us.Server + cluster *Cluster + path string + rawHandler *rawHandler + persistent bool + closed int32 +} + +// CheckResourceTagForTopSQLInGoTest is used to identify whether check resource tag for TopSQL. +var CheckResourceTagForTopSQLInGoTest bool + +// UnistoreRPCClientSendHook exports for test. +var UnistoreRPCClientSendHook atomic.Pointer[func(*tikvrpc.Request)] + +// SendRequest sends a request to mock cluster. 
+func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error) { + failpoint.Inject("rpcServerBusy", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(tikvrpc.GenRegionErrorResp(req, &errorpb.Error{ServerIsBusy: &errorpb.ServerIsBusy{}})) + } + }) + failpoint.Inject("epochNotMatch", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(tikvrpc.GenRegionErrorResp(req, &errorpb.Error{EpochNotMatch: &errorpb.EpochNotMatch{}})) + } + }) + + failpoint.Inject("unistoreRPCClientSendHook", func(val failpoint.Value) { + if fn := UnistoreRPCClientSendHook.Load(); val.(bool) && fn != nil { + (*fn)(req) + } + }) + + failpoint.Inject("rpcTiKVAllowedOnAlmostFull", func(val failpoint.Value) { + if val.(bool) { + if req.Type == tikvrpc.CmdPrewrite || req.Type == tikvrpc.CmdCommit { + if req.Context.DiskFullOpt != kvrpcpb.DiskFullOpt_AllowedOnAlmostFull { + failpoint.Return(tikvrpc.GenRegionErrorResp(req, &errorpb.Error{DiskFull: &errorpb.DiskFull{StoreId: []uint64{1}, Reason: "disk full"}})) + } + } + } + }) + failpoint.Inject("unistoreRPCDeadlineExceeded", func(val failpoint.Value) { + if val.(bool) && timeout < time.Second { + failpoint.Return(tikvrpc.GenRegionErrorResp(req, &errorpb.Error{Message: "Deadline is exceeded"})) + } + }) + failpoint.Inject("unistoreRPCSlowByInjestSleep", func(val failpoint.Value) { + time.Sleep(time.Duration(val.(int) * int(time.Millisecond))) + failpoint.Return(tikvrpc.GenRegionErrorResp(req, &errorpb.Error{Message: "Deadline is exceeded"})) + }) + + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + if atomic.LoadInt32(&c.closed) != 0 { + // Return `context.Canceled` can break Backoff. + return nil, context.Canceled + } + + storeID, err := c.usSvr.GetStoreIDByAddr(addr) + if err != nil { + return nil, err + } + + if CheckResourceTagForTopSQLInGoTest { + err = checkResourceTagForTopSQL(req) + if err != nil { + return nil, err + } + } + + resp := &tikvrpc.Response{} + switch req.Type { + case tikvrpc.CmdGet: + resp.Resp, err = c.usSvr.KvGet(ctx, req.Get()) + case tikvrpc.CmdScan: + kvScanReq := req.Scan() + failpoint.Inject("rpcScanResult", func(val failpoint.Value) { + switch val.(string) { + case "keyError": + failpoint.Return(&tikvrpc.Response{ + Resp: &kvrpcpb.ScanResponse{Error: &kvrpcpb.KeyError{ + Locked: &kvrpcpb.LockInfo{ + PrimaryLock: kvScanReq.StartKey, + LockVersion: kvScanReq.Version - 1, + Key: kvScanReq.StartKey, + LockTtl: 50, + TxnSize: 1, + LockType: kvrpcpb.Op_Put, + }, + }}, + }, nil) + } + }) + + resp.Resp, err = c.usSvr.KvScan(ctx, kvScanReq) + case tikvrpc.CmdPrewrite: + failpoint.Inject("rpcPrewriteResult", func(val failpoint.Value) { + if val != nil { + switch val.(string) { + case "timeout": + failpoint.Return(nil, errors.New("timeout")) + case "notLeader": + failpoint.Return(&tikvrpc.Response{ + Resp: &kvrpcpb.PrewriteResponse{RegionError: &errorpb.Error{NotLeader: &errorpb.NotLeader{}}}, + }, nil) + case "writeConflict": + failpoint.Return(&tikvrpc.Response{ + Resp: &kvrpcpb.PrewriteResponse{Errors: []*kvrpcpb.KeyError{{Conflict: &kvrpcpb.WriteConflict{}}}}, + }, nil) + } + } + }) + + r := req.Prewrite() + c.cluster.handleDelay(r.StartVersion, r.Context.RegionId) + resp.Resp, err = c.usSvr.KvPrewrite(ctx, r) + + failpoint.Inject("rpcPrewriteTimeout", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(nil, undeterminedErr) + } + }) + case tikvrpc.CmdPessimisticLock: + r := req.PessimisticLock() + 
c.cluster.handleDelay(r.StartVersion, r.Context.RegionId) + resp.Resp, err = c.usSvr.KvPessimisticLock(ctx, r) + case tikvrpc.CmdPessimisticRollback: + resp.Resp, err = c.usSvr.KVPessimisticRollback(ctx, req.PessimisticRollback()) + case tikvrpc.CmdCommit: + failpoint.Inject("rpcCommitResult", func(val failpoint.Value) { + switch val.(string) { + case "timeout": + failpoint.Return(nil, errors.New("timeout")) + case "notLeader": + failpoint.Return(&tikvrpc.Response{ + Resp: &kvrpcpb.CommitResponse{RegionError: &errorpb.Error{NotLeader: &errorpb.NotLeader{}}}, + }, nil) + case "keyError": + failpoint.Return(&tikvrpc.Response{ + Resp: &kvrpcpb.CommitResponse{Error: &kvrpcpb.KeyError{}}, + }, nil) + } + }) + + resp.Resp, err = c.usSvr.KvCommit(ctx, req.Commit()) + + failpoint.Inject("rpcCommitTimeout", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(nil, undeterminedErr) + } + }) + case tikvrpc.CmdCleanup: + resp.Resp, err = c.usSvr.KvCleanup(ctx, req.Cleanup()) + case tikvrpc.CmdCheckTxnStatus: + resp.Resp, err = c.usSvr.KvCheckTxnStatus(ctx, req.CheckTxnStatus()) + case tikvrpc.CmdCheckSecondaryLocks: + resp.Resp, err = c.usSvr.KvCheckSecondaryLocks(ctx, req.CheckSecondaryLocks()) + case tikvrpc.CmdTxnHeartBeat: + resp.Resp, err = c.usSvr.KvTxnHeartBeat(ctx, req.TxnHeartBeat()) + case tikvrpc.CmdBatchGet: + batchGetReq := req.BatchGet() + failpoint.Inject("rpcBatchGetResult", func(val failpoint.Value) { + switch val.(string) { + case "keyError": + failpoint.Return(&tikvrpc.Response{ + Resp: &kvrpcpb.BatchGetResponse{Error: &kvrpcpb.KeyError{ + Locked: &kvrpcpb.LockInfo{ + PrimaryLock: batchGetReq.Keys[0], + LockVersion: batchGetReq.Version - 1, + Key: batchGetReq.Keys[0], + LockTtl: 50, + TxnSize: 1, + LockType: kvrpcpb.Op_Put, + }, + }}, + }, nil) + } + }) + + resp.Resp, err = c.usSvr.KvBatchGet(ctx, batchGetReq) + case tikvrpc.CmdBatchRollback: + resp.Resp, err = c.usSvr.KvBatchRollback(ctx, req.BatchRollback()) + case tikvrpc.CmdScanLock: + resp.Resp, err = c.usSvr.KvScanLock(ctx, req.ScanLock()) + case tikvrpc.CmdResolveLock: + resp.Resp, err = c.usSvr.KvResolveLock(ctx, req.ResolveLock()) + case tikvrpc.CmdGC: + resp.Resp, err = c.usSvr.KvGC(ctx, req.GC()) + case tikvrpc.CmdDeleteRange: + resp.Resp, err = c.usSvr.KvDeleteRange(ctx, req.DeleteRange()) + case tikvrpc.CmdRawGet: + resp.Resp, err = c.rawHandler.RawGet(ctx, req.RawGet()) + case tikvrpc.CmdRawBatchGet: + resp.Resp, err = c.rawHandler.RawBatchGet(ctx, req.RawBatchGet()) + case tikvrpc.CmdRawPut: + resp.Resp, err = c.rawHandler.RawPut(ctx, req.RawPut()) + case tikvrpc.CmdRawBatchPut: + resp.Resp, err = c.rawHandler.RawBatchPut(ctx, req.RawBatchPut()) + case tikvrpc.CmdRawDelete: + resp.Resp, err = c.rawHandler.RawDelete(ctx, req.RawDelete()) + case tikvrpc.CmdRawBatchDelete: + resp.Resp, err = c.rawHandler.RawBatchDelete(ctx, req.RawBatchDelete()) + case tikvrpc.CmdRawDeleteRange: + resp.Resp, err = c.rawHandler.RawDeleteRange(ctx, req.RawDeleteRange()) + case tikvrpc.CmdRawScan: + resp.Resp, err = c.rawHandler.RawScan(ctx, req.RawScan()) + case tikvrpc.CmdCop: + resp.Resp, err = c.usSvr.Coprocessor(ctx, req.Cop()) + case tikvrpc.CmdCopStream: + resp.Resp, err = c.handleCopStream(ctx, req.Cop()) + case tikvrpc.CmdBatchCop: + failpoint.Inject("BatchCopCancelled", func(value failpoint.Value) { + if value.(bool) { + failpoint.Return(nil, context.Canceled) + } + }) + + failpoint.Inject("BatchCopRpcErr"+addr, func(value failpoint.Value) { + if value.(string) == addr { + failpoint.Return(nil, errors.New("rpc error")) 
+ } + }) + resp.Resp, err = c.handleBatchCop(ctx, req.BatchCop(), timeout) + case tikvrpc.CmdMPPConn: + failpoint.Inject("mppConnTimeout", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(nil, errors.New("rpc error")) + } + }) + failpoint.Inject("MppVersionError", func(val failpoint.Value) { + if v := int64(val.(int)); v > req.EstablishMPPConn().GetReceiverMeta().GetMppVersion() || v > req.EstablishMPPConn().GetSenderMeta().GetMppVersion() { + failpoint.Return(nil, context.Canceled) + } + }) + resp.Resp, err = c.handleEstablishMPPConnection(ctx, req.EstablishMPPConn(), timeout, storeID) + case tikvrpc.CmdMPPTask: + failpoint.Inject("mppDispatchTimeout", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(nil, errors.New("rpc error")) + } + }) + failpoint.Inject("MppVersionError", func(val failpoint.Value) { + if v := int64(val.(int)); v > req.DispatchMPPTask().GetMeta().GetMppVersion() { + failpoint.Return(nil, context.Canceled) + } + }) + resp.Resp, err = c.handleDispatchMPPTask(ctx, req.DispatchMPPTask(), storeID) + case tikvrpc.CmdMPPCancel: + case tikvrpc.CmdMvccGetByKey: + resp.Resp, err = c.usSvr.MvccGetByKey(ctx, req.MvccGetByKey()) + case tikvrpc.CmdMPPAlive: + resp.Resp, err = c.usSvr.IsAlive(ctx, req.IsMPPAlive()) + case tikvrpc.CmdMvccGetByStartTs: + resp.Resp, err = c.usSvr.MvccGetByStartTs(ctx, req.MvccGetByStartTs()) + case tikvrpc.CmdSplitRegion: + resp.Resp, err = c.usSvr.SplitRegion(ctx, req.SplitRegion()) + case tikvrpc.CmdDebugGetRegionProperties: + resp.Resp, err = c.handleDebugGetRegionProperties(ctx, req.DebugGetRegionProperties()) + return resp, err + case tikvrpc.CmdStoreSafeTS: + resp.Resp, err = c.usSvr.GetStoreSafeTS(ctx, req.StoreSafeTS()) + return resp, err + case tikvrpc.CmdUnsafeDestroyRange: + // Pretend it was done. Unistore does not have "destroy", and the + // keys has already been removed one-by-one before through: + // (dr *delRange) startEmulator() + resp.Resp = &kvrpcpb.UnsafeDestroyRangeResponse{} + return resp, nil + case tikvrpc.CmdFlush: + r := req.Flush() + c.cluster.handleDelay(r.StartTs, r.Context.RegionId) + resp.Resp, err = c.usSvr.KvFlush(ctx, r) + case tikvrpc.CmdBufferBatchGet: + r := req.BufferBatchGet() + resp.Resp, err = c.usSvr.KvBufferBatchGet(ctx, r) + default: + err = errors.Errorf("not support this request type %v", req.Type) + } + if err != nil { + return nil, err + } + var regErr *errorpb.Error + if req.Type != tikvrpc.CmdBatchCop && req.Type != tikvrpc.CmdMPPConn && req.Type != tikvrpc.CmdMPPTask && req.Type != tikvrpc.CmdMPPAlive { + regErr, err = resp.GetRegionError() + } + if err != nil { + return nil, err + } + if regErr != nil { + if regErr.EpochNotMatch != nil { + for i, newReg := range regErr.EpochNotMatch.CurrentRegions { + regErr.EpochNotMatch.CurrentRegions[i] = proto.Clone(newReg).(*metapb.Region) + } + } + } + return resp, nil +} + +func (c *RPCClient) handleCopStream(ctx context.Context, req *coprocessor.Request) (*tikvrpc.CopStreamResponse, error) { + copResp, err := c.usSvr.Coprocessor(ctx, req) + if err != nil { + return nil, err + } + return &tikvrpc.CopStreamResponse{ + Tikv_CoprocessorStreamClient: new(mockCopStreamClient), + Response: copResp, + }, nil +} + +// handleEstablishMPPConnection handle the mock mpp collection came from root or peers. 
+func (c *RPCClient) handleEstablishMPPConnection(ctx context.Context, r *mpp.EstablishMPPConnectionRequest, timeout time.Duration, storeID uint64) (*tikvrpc.MPPStreamResponse, error) { + mockServer := new(mockMPPConnectStreamServer) + err := c.usSvr.EstablishMPPConnectionWithStoreID(r, mockServer, storeID) + if err != nil { + return nil, err + } + failpoint.Inject("establishMppConnectionErr", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(nil, errors.New("rpc error")) + } + }) + var mockClient = mockMPPConnectionClient{mppResponses: mockServer.mppResponses, idx: 0, ctx: ctx, targetTask: r.ReceiverMeta} + streamResp := &tikvrpc.MPPStreamResponse{Tikv_EstablishMPPConnectionClient: &mockClient} + _, cancel := context.WithCancel(ctx) + streamResp.Lease.Cancel = cancel + streamResp.Timeout = timeout + // mock the stream resp from the server's resp slice + first, err := streamResp.Recv() + if err != nil { + if errors.Cause(err) != io.EOF { + return nil, errors.Trace(err) + } + } + streamResp.MPPDataPacket = first + return streamResp, nil +} + +func (c *RPCClient) handleDispatchMPPTask(ctx context.Context, r *mpp.DispatchTaskRequest, storeID uint64) (*mpp.DispatchTaskResponse, error) { + return c.usSvr.DispatchMPPTaskWithStoreID(ctx, r, storeID) +} + +func (c *RPCClient) handleBatchCop(ctx context.Context, r *coprocessor.BatchRequest, timeout time.Duration) (*tikvrpc.BatchCopStreamResponse, error) { + mockBatchCopServer := &mockBatchCoprocessorStreamServer{} + err := c.usSvr.BatchCoprocessor(r, mockBatchCopServer) + if err != nil { + return nil, err + } + var mockBatchCopClient = mockBatchCopClient{batchResponses: mockBatchCopServer.batchResponses, idx: 0} + batchResp := &tikvrpc.BatchCopStreamResponse{Tikv_BatchCoprocessorClient: &mockBatchCopClient} + _, cancel := context.WithCancel(ctx) + batchResp.Lease.Cancel = cancel + batchResp.Timeout = timeout + first, err := batchResp.Recv() + if err != nil { + return nil, errors.Trace(err) + } + batchResp.BatchResponse = first + return batchResp, nil +} + +func (c *RPCClient) handleDebugGetRegionProperties(ctx context.Context, req *debugpb.GetRegionPropertiesRequest) (*debugpb.GetRegionPropertiesResponse, error) { + region := c.cluster.GetRegion(req.RegionId) + _, start, err := codec.DecodeBytes(region.StartKey, nil) + if err != nil { + return nil, err + } + _, end, err := codec.DecodeBytes(region.EndKey, nil) + if err != nil { + return nil, err + } + scanResp, err := c.usSvr.KvScan(ctx, &kvrpcpb.ScanRequest{ + Context: &kvrpcpb.Context{ + RegionId: region.Id, + RegionEpoch: region.RegionEpoch, + }, + StartKey: start, + EndKey: end, + Version: math.MaxUint64, + Limit: math.MaxUint32, + }) + if err != nil { + return nil, err + } + if err := scanResp.GetRegionError(); err != nil { + panic(err) + } + return &debugpb.GetRegionPropertiesResponse{ + Props: []*debugpb.Property{{ + Name: "mvcc.num_rows", + Value: strconv.Itoa(len(scanResp.Pairs)), + }}}, nil +} + +// Close closes RPCClient and cleanup temporal resources. +func (c *RPCClient) Close() error { + atomic.StoreInt32(&c.closed, 1) + if c.usSvr != nil { + c.usSvr.Stop() + } + if !c.persistent && c.path != "" { + err := os.RemoveAll(c.path) + _ = err + } + return nil +} + +// CloseAddr implements tikv.Client interface and it does nothing. +func (c *RPCClient) CloseAddr(addr string) error { + return nil +} + +// SetEventListener implements tikv.Client interface. 
+func (c *RPCClient) SetEventListener(listener tikv.ClientEventListener) {} + +type mockClientStream struct{} + +// Header implements grpc.ClientStream interface +func (mockClientStream) Header() (metadata.MD, error) { return nil, nil } + +// Trailer implements grpc.ClientStream interface +func (mockClientStream) Trailer() metadata.MD { return nil } + +// CloseSend implements grpc.ClientStream interface +func (mockClientStream) CloseSend() error { return nil } + +// Context implements grpc.ClientStream interface +func (mockClientStream) Context() context.Context { return nil } + +// SendMsg implements grpc.ClientStream interface +func (mockClientStream) SendMsg(m any) error { return nil } + +// RecvMsg implements grpc.ClientStream interface +func (mockClientStream) RecvMsg(m any) error { return nil } + +type mockCopStreamClient struct { + mockClientStream +} + +func (mock *mockCopStreamClient) Recv() (*coprocessor.Response, error) { + return nil, io.EOF +} + +type mockBatchCopClient struct { + mockClientStream + batchResponses []*coprocessor.BatchResponse + idx int +} + +func (mock *mockBatchCopClient) Recv() (*coprocessor.BatchResponse, error) { + if mock.idx < len(mock.batchResponses) { + ret := mock.batchResponses[mock.idx] + mock.idx++ + var err error + if len(ret.OtherError) > 0 { + err = errors.New(ret.OtherError) + ret = nil + } + return ret, err + } + failpoint.Inject("batchCopRecvTimeout", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(nil, context.Canceled) + } + }) + return nil, io.EOF +} + +type mockMPPConnectionClient struct { + mockClientStream + mppResponses []*mpp.MPPDataPacket + idx int + ctx context.Context + targetTask *mpp.TaskMeta +} + +func (mock *mockMPPConnectionClient) Recv() (*mpp.MPPDataPacket, error) { + if mock.idx < len(mock.mppResponses) { + ret := mock.mppResponses[mock.idx] + mock.idx++ + return ret, nil + } + failpoint.Inject("mppRecvTimeout", func(val failpoint.Value) { + if int64(val.(int)) == mock.targetTask.TaskId { + failpoint.Return(nil, context.Canceled) + } + }) + failpoint.Inject("mppRecvHang", func(val failpoint.Value) { + for val.(bool) { + select { + case <-mock.ctx.Done(): + { + failpoint.Return(nil, context.Canceled) + } + default: + time.Sleep(1 * time.Second) + } + } + }) + return nil, io.EOF +} + +type mockServerStream struct{} + +func (mockServerStream) SetHeader(metadata.MD) error { return nil } +func (mockServerStream) SendHeader(metadata.MD) error { return nil } +func (mockServerStream) SetTrailer(metadata.MD) {} +func (mockServerStream) Context() context.Context { return nil } +func (mockServerStream) SendMsg(any) error { return nil } +func (mockServerStream) RecvMsg(any) error { return nil } + +type mockBatchCoprocessorStreamServer struct { + mockServerStream + batchResponses []*coprocessor.BatchResponse +} + +func (mockBatchCopServer *mockBatchCoprocessorStreamServer) Send(response *coprocessor.BatchResponse) error { + mockBatchCopServer.batchResponses = append(mockBatchCopServer.batchResponses, response) + return nil +} + +type mockMPPConnectStreamServer struct { + mockServerStream + mppResponses []*mpp.MPPDataPacket +} + +func (mockMPPConnectStreamServer *mockMPPConnectStreamServer) Send(mppResponse *mpp.MPPDataPacket) error { + mockMPPConnectStreamServer.mppResponses = append(mockMPPConnectStreamServer.mppResponses, mppResponse) + return nil +} diff --git a/pkg/table/contextimpl/binding__failpoint_binding__.go b/pkg/table/contextimpl/binding__failpoint_binding__.go new file mode 100644 index 
0000000000000..beab72aa0be19 --- /dev/null +++ b/pkg/table/contextimpl/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package contextimpl + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/table/contextimpl/table.go b/pkg/table/contextimpl/table.go index 1f701246ab8a8..202a8c0017490 100644 --- a/pkg/table/contextimpl/table.go +++ b/pkg/table/contextimpl/table.go @@ -120,11 +120,11 @@ func (ctx *TableContextImpl) GetReservedRowIDAlloc() (*stmtctx.ReservedRowIDAllo // GetBinlogSupport implements the MutateContext interface. func (ctx *TableContextImpl) GetBinlogSupport() (context.BinlogSupport, bool) { - failpoint.Inject("forceWriteBinlog", func() { + if _, _err_ := failpoint.Eval(_curpkg_("forceWriteBinlog")); _err_ == nil { // Just to cover binlog related code in this package, since the `BinlogClient` is // still nil, mutations won't be written to pump on commit. - failpoint.Return(ctx, true) - }) + return ctx, true + } if ctx.vars().BinlogClient != nil { return ctx, true } diff --git a/pkg/table/contextimpl/table.go__failpoint_stash__ b/pkg/table/contextimpl/table.go__failpoint_stash__ new file mode 100644 index 0000000000000..1f701246ab8a8 --- /dev/null +++ b/pkg/table/contextimpl/table.go__failpoint_stash__ @@ -0,0 +1,200 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package contextimpl + +import ( + "github.com/pingcap/failpoint" + exprctx "github.com/pingcap/tidb/pkg/expression/context" + "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/table/context" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tipb/go-binlog" +) + +var _ context.MutateContext = &TableContextImpl{} +var _ context.AllocatorContext = &TableContextImpl{} + +// TableContextImpl is used to provide context for table operations. +type TableContextImpl struct { + sessionctx.Context + // mutateBuffers is a memory pool for table related memory allocation that aims to reuse memory + // and saves allocation + // The buffers are supposed to be used inside AddRecord/UpdateRecord/RemoveRecord. + mutateBuffers *context.MutateBuffers +} + +// NewTableContextImpl creates a new TableContextImpl. 
+func NewTableContextImpl(sctx sessionctx.Context) *TableContextImpl { + return &TableContextImpl{ + Context: sctx, + mutateBuffers: context.NewMutateBuffers(sctx.GetSessionVars().GetWriteStmtBufs()), + } +} + +// AlternativeAllocators implements the AllocatorContext interface +func (ctx *TableContextImpl) AlternativeAllocators(tbl *model.TableInfo) (allocators autoid.Allocators, ok bool) { + // Use an independent allocator for global temporary tables. + if tbl.TempTableType == model.TempTableGlobal { + if tempTbl := ctx.vars().GetTemporaryTable(tbl); tempTbl != nil { + if alloc := tempTbl.GetAutoIDAllocator(); alloc != nil { + return autoid.NewAllocators(false, alloc), true + } + } + // If the session is not in a txn, for example, in "show create table", use the original allocator. + } + return +} + +// GetExprCtx returns the ExprContext +func (ctx *TableContextImpl) GetExprCtx() exprctx.ExprContext { + return ctx.Context.GetExprCtx() +} + +// ConnectionID implements the MutateContext interface. +func (ctx *TableContextImpl) ConnectionID() uint64 { + return ctx.vars().ConnectionID +} + +// InRestrictedSQL returns whether the current context is used in restricted SQL. +func (ctx *TableContextImpl) InRestrictedSQL() bool { + return ctx.vars().InRestrictedSQL +} + +// TxnAssertionLevel implements the MutateContext interface. +func (ctx *TableContextImpl) TxnAssertionLevel() variable.AssertionLevel { + return ctx.vars().AssertionLevel +} + +// EnableMutationChecker implements the MutateContext interface. +func (ctx *TableContextImpl) EnableMutationChecker() bool { + return ctx.vars().EnableMutationChecker +} + +// GetRowEncodingConfig returns the RowEncodingConfig. +func (ctx *TableContextImpl) GetRowEncodingConfig() context.RowEncodingConfig { + vars := ctx.vars() + return context.RowEncodingConfig{ + IsRowLevelChecksumEnabled: vars.IsRowLevelChecksumEnabled(), + RowEncoder: &vars.RowEncoder, + } +} + +// GetMutateBuffers implements the MutateContext interface. +func (ctx *TableContextImpl) GetMutateBuffers() *context.MutateBuffers { + return ctx.mutateBuffers +} + +// GetRowIDShardGenerator implements the MutateContext interface. +func (ctx *TableContextImpl) GetRowIDShardGenerator() *variable.RowIDShardGenerator { + return ctx.vars().GetRowIDShardGenerator() +} + +// GetReservedRowIDAlloc implements the MutateContext interface. +func (ctx *TableContextImpl) GetReservedRowIDAlloc() (*stmtctx.ReservedRowIDAlloc, bool) { + if sc := ctx.vars().StmtCtx; sc != nil { + return &sc.ReservedRowIDAlloc, true + } + // `StmtCtx` should not be nil in the `variable.SessionVars`. + // We just put an assertion that will panic only if in test here. + // In production code, here returns (nil, false) to make code safe + // because some old code checks `StmtCtx != nil` but we don't know why. + intest.Assert(false, "SessionVars.StmtCtx should not be nil") + return nil, false +} + +// GetBinlogSupport implements the MutateContext interface. +func (ctx *TableContextImpl) GetBinlogSupport() (context.BinlogSupport, bool) { + failpoint.Inject("forceWriteBinlog", func() { + // Just to cover binlog related code in this package, since the `BinlogClient` is + // still nil, mutations won't be written to pump on commit. + failpoint.Return(ctx, true) + }) + if ctx.vars().BinlogClient != nil { + return ctx, true + } + return nil, false +} + +// GetBinlogMutation implements the BinlogSupport interface. 
+func (ctx *TableContextImpl) GetBinlogMutation(tblID int64) *binlog.TableMutation { + return ctx.Context.StmtGetMutation(tblID) +} + +// GetStatisticsSupport implements the MutateContext interface. +func (ctx *TableContextImpl) GetStatisticsSupport() (context.StatisticsSupport, bool) { + if ctx.vars().TxnCtx != nil { + return ctx, true + } + return nil, false +} + +// UpdatePhysicalTableDelta implements the StatisticsSupport interface. +func (ctx *TableContextImpl) UpdatePhysicalTableDelta( + physicalTableID int64, delta int64, count int64, cols variable.DeltaCols, +) { + if txnCtx := ctx.vars().TxnCtx; txnCtx != nil { + txnCtx.UpdateDeltaForTable(physicalTableID, delta, count, cols) + } +} + +// GetCachedTableSupport implements the MutateContext interface. +func (ctx *TableContextImpl) GetCachedTableSupport() (context.CachedTableSupport, bool) { + if ctx.vars().TxnCtx != nil { + return ctx, true + } + return nil, false +} + +// AddCachedTableHandleToTxn implements `CachedTableSupport` interface +func (ctx *TableContextImpl) AddCachedTableHandleToTxn(tableID int64, handle any) { + txnCtx := ctx.vars().TxnCtx + if txnCtx.CachedTables == nil { + txnCtx.CachedTables = make(map[int64]any) + } + if _, ok := txnCtx.CachedTables[tableID]; !ok { + txnCtx.CachedTables[tableID] = handle + } +} + +// GetTemporaryTableSupport implements the MutateContext interface. +func (ctx *TableContextImpl) GetTemporaryTableSupport() (context.TemporaryTableSupport, bool) { + if ctx.vars().TxnCtx == nil { + return nil, false + } + return ctx, true +} + +// GetTemporaryTableSizeLimit implements TemporaryTableSupport interface. +func (ctx *TableContextImpl) GetTemporaryTableSizeLimit() int64 { + return ctx.vars().TMPTableSize +} + +// AddTemporaryTableToTxn implements the TemporaryTableSupport interface. 
+func (ctx *TableContextImpl) AddTemporaryTableToTxn(tblInfo *model.TableInfo) (context.TemporaryTableHandler, bool) { + vars := ctx.vars() + if tbl := vars.GetTemporaryTable(tblInfo); tbl != nil { + tbl.SetModified(true) + return context.NewTemporaryTableHandler(tbl, vars.TemporaryTableData), true + } + return context.TemporaryTableHandler{}, false +} + +func (ctx *TableContextImpl) vars() *variable.SessionVars { + return ctx.Context.GetSessionVars() +} diff --git a/pkg/table/tables/binding__failpoint_binding__.go b/pkg/table/tables/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..fc93c522f2734 --- /dev/null +++ b/pkg/table/tables/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package tables + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/table/tables/cache.go b/pkg/table/tables/cache.go index 1051ee8a995f6..7554cfd7f6e0d 100644 --- a/pkg/table/tables/cache.go +++ b/pkg/table/tables/cache.go @@ -100,9 +100,9 @@ func (c *cachedTable) TryReadFromCache(ts uint64, leaseDuration time.Duration) ( distance := leaseTime.Sub(nowTime) var triggerFailpoint bool - failpoint.Inject("mockRenewLeaseABA1", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("mockRenewLeaseABA1")); _err_ == nil { triggerFailpoint = true - }) + } if distance >= 0 && distance <= leaseDuration/2 || triggerFailpoint { if h := c.TakeStateRemoteHandleNoWait(); h != nil { @@ -273,11 +273,11 @@ func (c *cachedTable) RemoveRecord(sctx table.MutateContext, h kv.Handle, r []ty var TestMockRenewLeaseABA2 chan struct{} func (c *cachedTable) renewLease(handle StateRemote, ts uint64, data *cacheData, leaseDuration time.Duration) { - failpoint.Inject("mockRenewLeaseABA2", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("mockRenewLeaseABA2")); _err_ == nil { c.PutStateRemoteHandle(handle) <-TestMockRenewLeaseABA2 c.TakeStateRemoteHandle() - }) + } defer c.PutStateRemoteHandle(handle) @@ -298,9 +298,9 @@ func (c *cachedTable) renewLease(handle StateRemote, ts uint64, data *cacheData, }) } - failpoint.Inject("mockRenewLeaseABA2", func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_("mockRenewLeaseABA2")); _err_ == nil { TestMockRenewLeaseABA2 <- struct{}{} - }) + } } const cacheTableWriteLease = 5 * time.Second diff --git a/pkg/table/tables/cache.go__failpoint_stash__ b/pkg/table/tables/cache.go__failpoint_stash__ new file mode 100644 index 0000000000000..1051ee8a995f6 --- /dev/null +++ b/pkg/table/tables/cache.go__failpoint_stash__ @@ -0,0 +1,355 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package tables + +import ( + "context" + "sync/atomic" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/sqlexec" + "github.com/tikv/client-go/v2/oracle" + "github.com/tikv/client-go/v2/tikv" + "go.uber.org/zap" +) + +var ( + _ table.CachedTable = &cachedTable{} +) + +type cachedTable struct { + TableCommon + cacheData atomic.Pointer[cacheData] + totalSize int64 + // StateRemote is not thread-safe, this tokenLimit is used to keep only one visitor. + tokenLimit +} + +type tokenLimit chan StateRemote + +func (t tokenLimit) TakeStateRemoteHandle() StateRemote { + handle := <-t + return handle +} + +func (t tokenLimit) TakeStateRemoteHandleNoWait() StateRemote { + select { + case handle := <-t: + return handle + default: + return nil + } +} + +func (t tokenLimit) PutStateRemoteHandle(handle StateRemote) { + t <- handle +} + +// cacheData pack the cache data and lease. +type cacheData struct { + Start uint64 + Lease uint64 + kv.MemBuffer +} + +func leaseFromTS(ts uint64, leaseDuration time.Duration) uint64 { + physicalTime := oracle.GetTimeFromTS(ts) + lease := oracle.GoTimeToTS(physicalTime.Add(leaseDuration)) + return lease +} + +func newMemBuffer(store kv.Storage) (kv.MemBuffer, error) { + // Here is a trick to get a MemBuffer data, because the internal API is not exposed. + // Create a transaction with start ts 0, and take the MemBuffer out. + buffTxn, err := store.Begin(tikv.WithStartTS(0)) + if err != nil { + return nil, err + } + return buffTxn.GetMemBuffer(), nil +} + +func (c *cachedTable) TryReadFromCache(ts uint64, leaseDuration time.Duration) (kv.MemBuffer, bool /*loading*/) { + data := c.cacheData.Load() + if data == nil { + return nil, false + } + if ts >= data.Start && ts < data.Lease { + leaseTime := oracle.GetTimeFromTS(data.Lease) + nowTime := oracle.GetTimeFromTS(ts) + distance := leaseTime.Sub(nowTime) + + var triggerFailpoint bool + failpoint.Inject("mockRenewLeaseABA1", func(_ failpoint.Value) { + triggerFailpoint = true + }) + + if distance >= 0 && distance <= leaseDuration/2 || triggerFailpoint { + if h := c.TakeStateRemoteHandleNoWait(); h != nil { + go c.renewLease(h, ts, data, leaseDuration) + } + } + // If data is not nil, but data.MemBuffer is nil, it means the data is being + // loading by a background goroutine. + return data.MemBuffer, data.MemBuffer == nil + } + return nil, false +} + +// newCachedTable creates a new CachedTable Instance +func newCachedTable(tbl *TableCommon) (table.Table, error) { + ret := &cachedTable{ + TableCommon: tbl.Copy(), + tokenLimit: make(chan StateRemote, 1), + } + return ret, nil +} + +// Init is an extra operation for cachedTable after TableFromMeta, +// Because cachedTable need some additional parameter that can't be passed in TableFromMeta. 
+func (c *cachedTable) Init(exec sqlexec.SQLExecutor) error { + raw, ok := exec.(sqlExec) + if !ok { + return errors.New("Need sqlExec rather than sqlexec.SQLExecutor") + } + handle := NewStateRemote(raw) + c.PutStateRemoteHandle(handle) + return nil +} + +func (c *cachedTable) loadDataFromOriginalTable(store kv.Storage) (kv.MemBuffer, uint64, int64, error) { + buffer, err := newMemBuffer(store) + if err != nil { + return nil, 0, 0, err + } + var startTS uint64 + totalSize := int64(0) + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnCacheTable) + err = kv.RunInNewTxn(ctx, store, true, func(ctx context.Context, txn kv.Transaction) error { + prefix := tablecodec.GenTablePrefix(c.tableID) + if err != nil { + return errors.Trace(err) + } + startTS = txn.StartTS() + it, err := txn.Iter(prefix, prefix.PrefixNext()) + if err != nil { + return errors.Trace(err) + } + defer it.Close() + + for it.Valid() && it.Key().HasPrefix(prefix) { + key := it.Key() + value := it.Value() + err = buffer.Set(key, value) + if err != nil { + return errors.Trace(err) + } + totalSize += int64(len(key)) + totalSize += int64(len(value)) + err = it.Next() + if err != nil { + return errors.Trace(err) + } + } + return nil + }) + if err != nil { + return nil, 0, totalSize, err + } + + return buffer, startTS, totalSize, nil +} + +func (c *cachedTable) UpdateLockForRead(ctx context.Context, store kv.Storage, ts uint64, leaseDuration time.Duration) { + if h := c.TakeStateRemoteHandleNoWait(); h != nil { + go c.updateLockForRead(ctx, h, store, ts, leaseDuration) + } +} + +func (c *cachedTable) updateLockForRead(ctx context.Context, handle StateRemote, store kv.Storage, ts uint64, leaseDuration time.Duration) { + defer func() { + if r := recover(); r != nil { + log.Error("panic in the recoverable goroutine", + zap.Any("r", r), + zap.Stack("stack trace")) + } + c.PutStateRemoteHandle(handle) + }() + + // Load data from original table and the update lock information. + tid := c.Meta().ID + lease := leaseFromTS(ts, leaseDuration) + succ, err := handle.LockForRead(ctx, tid, lease) + if err != nil { + log.Warn("lock cached table for read", zap.Error(err)) + return + } + if succ { + c.cacheData.Store(&cacheData{ + Start: ts, + Lease: lease, + MemBuffer: nil, // Async loading, this will be set later. + }) + + // Make the load data process async, in case that loading data takes longer the + // lease duration, then the loaded data get staled and that process repeats forever. + go func() { + start := time.Now() + mb, startTS, totalSize, err := c.loadDataFromOriginalTable(store) + metrics.LoadTableCacheDurationHistogram.Observe(time.Since(start).Seconds()) + if err != nil { + log.Info("load data from table fail", zap.Error(err)) + return + } + + tmp := c.cacheData.Load() + if tmp != nil && tmp.Start == ts { + c.cacheData.Store(&cacheData{ + Start: startTS, + Lease: tmp.Lease, + MemBuffer: mb, + }) + atomic.StoreInt64(&c.totalSize, totalSize) + } + }() + } + // Current status is not suitable to cache. +} + +const cachedTableSizeLimit = 64 * (1 << 20) + +// AddRecord implements the AddRecord method for the table.Table interface. +func (c *cachedTable) AddRecord(sctx table.MutateContext, r []types.Datum, opts ...table.AddRecordOption) (recordID kv.Handle, err error) { + if atomic.LoadInt64(&c.totalSize) > cachedTableSizeLimit { + return nil, table.ErrOptOnCacheTable.GenWithStackByArgs("table too large") + } + txnCtxAddCachedTable(sctx, c.Meta().ID, c) + return c.TableCommon.AddRecord(sctx, r, opts...) 
+} + +func txnCtxAddCachedTable(sctx table.MutateContext, tid int64, handle *cachedTable) { + if s, ok := sctx.GetCachedTableSupport(); ok { + s.AddCachedTableHandleToTxn(tid, handle) + } +} + +// UpdateRecord implements table.Table +func (c *cachedTable) UpdateRecord(ctx table.MutateContext, h kv.Handle, oldData, newData []types.Datum, touched []bool, opts ...table.UpdateRecordOption) error { + // Prevent furthur writing when the table is already too large. + if atomic.LoadInt64(&c.totalSize) > cachedTableSizeLimit { + return table.ErrOptOnCacheTable.GenWithStackByArgs("table too large") + } + txnCtxAddCachedTable(ctx, c.Meta().ID, c) + return c.TableCommon.UpdateRecord(ctx, h, oldData, newData, touched, opts...) +} + +// RemoveRecord implements table.Table RemoveRecord interface. +func (c *cachedTable) RemoveRecord(sctx table.MutateContext, h kv.Handle, r []types.Datum) error { + txnCtxAddCachedTable(sctx, c.Meta().ID, c) + return c.TableCommon.RemoveRecord(sctx, h, r) +} + +// TestMockRenewLeaseABA2 is used by test function TestRenewLeaseABAFailPoint. +var TestMockRenewLeaseABA2 chan struct{} + +func (c *cachedTable) renewLease(handle StateRemote, ts uint64, data *cacheData, leaseDuration time.Duration) { + failpoint.Inject("mockRenewLeaseABA2", func(_ failpoint.Value) { + c.PutStateRemoteHandle(handle) + <-TestMockRenewLeaseABA2 + c.TakeStateRemoteHandle() + }) + + defer c.PutStateRemoteHandle(handle) + + tid := c.Meta().ID + lease := leaseFromTS(ts, leaseDuration) + newLease, err := handle.RenewReadLease(context.Background(), tid, data.Lease, lease) + if err != nil { + if !kv.IsTxnRetryableError(err) { + log.Warn("Renew read lease error", zap.Error(err)) + } + return + } + if newLease > 0 { + c.cacheData.Store(&cacheData{ + Start: data.Start, + Lease: newLease, + MemBuffer: data.MemBuffer, + }) + } + + failpoint.Inject("mockRenewLeaseABA2", func(_ failpoint.Value) { + TestMockRenewLeaseABA2 <- struct{}{} + }) +} + +const cacheTableWriteLease = 5 * time.Second + +func (c *cachedTable) WriteLockAndKeepAlive(ctx context.Context, exit chan struct{}, leasePtr *uint64, wg chan error) { + writeLockLease, err := c.lockForWrite(ctx) + atomic.StoreUint64(leasePtr, writeLockLease) + wg <- err + if err != nil { + logutil.Logger(ctx).Warn("lock for write lock fail", zap.String("category", "cached table"), zap.Error(err)) + return + } + + t := time.NewTicker(cacheTableWriteLease / 2) + defer t.Stop() + for { + select { + case <-t.C: + if err := c.renew(ctx, leasePtr); err != nil { + logutil.Logger(ctx).Warn("renew write lock lease fail", zap.String("category", "cached table"), zap.Error(err)) + return + } + case <-exit: + return + } + } +} + +func (c *cachedTable) renew(ctx context.Context, leasePtr *uint64) error { + oldLease := atomic.LoadUint64(leasePtr) + physicalTime := oracle.GetTimeFromTS(oldLease) + newLease := oracle.GoTimeToTS(physicalTime.Add(cacheTableWriteLease)) + + h := c.TakeStateRemoteHandle() + defer c.PutStateRemoteHandle(h) + + succ, err := h.RenewWriteLease(ctx, c.Meta().ID, newLease) + if err != nil { + return errors.Trace(err) + } + if succ { + atomic.StoreUint64(leasePtr, newLease) + } + return nil +} + +func (c *cachedTable) lockForWrite(ctx context.Context) (uint64, error) { + handle := c.TakeStateRemoteHandle() + defer c.PutStateRemoteHandle(handle) + + return handle.LockForWrite(ctx, c.Meta().ID, cacheTableWriteLease) +} diff --git a/pkg/table/tables/mutation_checker.go b/pkg/table/tables/mutation_checker.go index 33f18ea37cb95..0fd6fdeb50517 100644 --- 
a/pkg/table/tables/mutation_checker.go +++ b/pkg/table/tables/mutation_checker.go @@ -549,8 +549,8 @@ func corruptMutations(t *TableCommon, txn kv.Transaction, sh kv.StagingHandle, c } func injectMutationError(t *TableCommon, txn kv.Transaction, sh kv.StagingHandle) error { - failpoint.Inject("corruptMutations", func(commands failpoint.Value) { - failpoint.Return(corruptMutations(t, txn, sh, commands.(string))) - }) + if commands, _err_ := failpoint.Eval(_curpkg_("corruptMutations")); _err_ == nil { + return corruptMutations(t, txn, sh, commands.(string)) + } return nil } diff --git a/pkg/table/tables/mutation_checker.go__failpoint_stash__ b/pkg/table/tables/mutation_checker.go__failpoint_stash__ new file mode 100644 index 0000000000000..33f18ea37cb95 --- /dev/null +++ b/pkg/table/tables/mutation_checker.go__failpoint_stash__ @@ -0,0 +1,556 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables + +import ( + "fmt" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/errno" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/table" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/collate" + "github.com/pingcap/tidb/pkg/util/dbterror" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/rowcodec" + "go.uber.org/zap" +) + +var ( + // ErrInconsistentRowValue is the error when values in a row insertion does not match the expected ones. + ErrInconsistentRowValue = dbterror.ClassTable.NewStd(errno.ErrInconsistentRowValue) + // ErrInconsistentHandle is the error when the handle in the row/index insertions does not match. + ErrInconsistentHandle = dbterror.ClassTable.NewStd(errno.ErrInconsistentHandle) + // ErrInconsistentIndexedValue is the error when decoded values from the index mutation cannot match row value + ErrInconsistentIndexedValue = dbterror.ClassTable.NewStd(errno.ErrInconsistentIndexedValue) +) + +type mutation struct { + key kv.Key + flags kv.KeyFlags + value []byte + indexID int64 // only for index mutations +} + +type columnMaps struct { + ColumnIDToInfo map[int64]*model.ColumnInfo + ColumnIDToFieldType map[int64]*types.FieldType + IndexIDToInfo map[int64]*model.IndexInfo + IndexIDToRowColInfos map[int64][]rowcodec.ColInfo +} + +// CheckDataConsistency checks whether the given set of mutations corresponding to a single row is consistent. +// Namely, assume the database is consistent before, applying the mutations shouldn't break the consistency. +// It aims at reducing bugs that will corrupt data, and preventing mistakes from spreading if possible. 
+// +// 3 conditions are checked: +// (1) row.value is consistent with input data +// (2) the handle is consistent in row and index insertions +// (3) the keys of the indices are consistent with the values of rows +// +// The check doesn't work and just returns nil when: +// (1) the table is partitioned +// (2) new collation is enabled and restored data is needed +// +// The check is performed on almost every write. Its performance matters. +// Let M = the number of mutations, C = the number of columns in the table, +// I = the sum of the number of columns in all indices, +// The time complexity is O(M * C + I) +// The space complexity is O(M + C + I) +func CheckDataConsistency( + txn kv.Transaction, tc types.Context, t *TableCommon, + rowToInsert, rowToRemove []types.Datum, memBuffer kv.MemBuffer, sh kv.StagingHandle, +) error { + if t.Meta().GetPartitionInfo() != nil { + return nil + } + if txn.IsPipelined() { + return nil + } + if sh == 0 { + // some implementations of MemBuffer doesn't support staging, e.g. that in pkg/lightning/backend/kv + return nil + } + indexMutations, rowInsertion, err := collectTableMutationsFromBufferStage(t, memBuffer, sh) + if err != nil { + return errors.Trace(err) + } + + columnMaps := getColumnMaps(txn, t) + + // Row insertion consistency check contributes the least to defending data-index consistency, but costs most CPU resources. + // So we disable it for now. + // + // if rowToInsert != nil { + // if err := checkRowInsertionConsistency( + // sessVars, rowToInsert, rowInsertion, columnMaps.ColumnIDToInfo, columnMaps.ColumnIDToFieldType, t.Meta().Name.O, + // ); err != nil { + // return errors.Trace(err) + // } + // } + + if rowInsertion.key != nil { + if err = checkHandleConsistency(rowInsertion, indexMutations, columnMaps.IndexIDToInfo, t.Meta()); err != nil { + return errors.Trace(err) + } + } + + if err := checkIndexKeys( + tc, t, rowToInsert, rowToRemove, indexMutations, columnMaps.IndexIDToInfo, columnMaps.IndexIDToRowColInfos, + ); err != nil { + return errors.Trace(err) + } + return nil +} + +// checkHandleConsistency checks whether the handles, with regard to a single-row change, +// in row insertions and index insertions are consistent. +// A PUT_index implies a PUT_row with the same handle. +// Deletions are not checked since the values of deletions are unknown +func checkHandleConsistency(rowInsertion mutation, indexMutations []mutation, indexIDToInfo map[int64]*model.IndexInfo, tblInfo *model.TableInfo) error { + var insertionHandle kv.Handle + var err error + + if rowInsertion.key == nil { + return nil + } + insertionHandle, err = tablecodec.DecodeRowKey(rowInsertion.key) + if err != nil { + return errors.Trace(err) + } + + for _, m := range indexMutations { + if len(m.value) == 0 { + continue + } + + // Generate correct index id for check. + idxID := m.indexID & tablecodec.IndexIDMask + indexInfo, ok := indexIDToInfo[idxID] + if !ok { + return errors.New("index not found") + } + + // If this is the temporary index data, need to remove the last byte of index data(version about when it is written). + var ( + value []byte + orgKey []byte + indexHandle kv.Handle + ) + if idxID != m.indexID { + if tablecodec.TempIndexValueIsUntouched(m.value) { + // We never commit the untouched key values to the storage. Skip this check. 
+ continue + } + var tempIdxVal tablecodec.TempIndexValue + tempIdxVal, err = tablecodec.DecodeTempIndexValue(m.value) + if err != nil { + return err + } + if !tempIdxVal.IsEmpty() { + value = tempIdxVal.Current().Value + } + if len(value) == 0 { + // Skip the deleted operation values. + continue + } + orgKey = append(orgKey, m.key...) + tablecodec.TempIndexKey2IndexKey(orgKey) + indexHandle, err = tablecodec.DecodeIndexHandle(orgKey, value, len(indexInfo.Columns)) + } else { + indexHandle, err = tablecodec.DecodeIndexHandle(m.key, m.value, len(indexInfo.Columns)) + } + if err != nil { + return errors.Trace(err) + } + // NOTE: handle type can be different, see issue 29520 + if indexHandle.IsInt() == insertionHandle.IsInt() && indexHandle.Compare(insertionHandle) != 0 { + err = ErrInconsistentHandle.GenWithStackByArgs(tblInfo.Name, indexInfo.Name.O, indexHandle, insertionHandle, m, rowInsertion) + logutil.BgLogger().Error("inconsistent handle in index and record insertions", zap.Error(err)) + return err + } + } + + return err +} + +// checkIndexKeys checks whether the decoded data from keys of index mutations are consistent with the expected ones. +// +// How it works: +// +// Assume the set of row values changes from V1 to V2, we check +// (1) V2 - V1 = {added indices} +// (2) V1 - V2 = {deleted indices} +// +// To check (1), we need +// (a) {added indices} is a subset of {needed indices} => each index mutation is consistent with the input/row key/value +// (b) {needed indices} is a subset of {added indices}. The check process would be exactly the same with how we generate the mutations, thus ignored. +func checkIndexKeys( + tc types.Context, t *TableCommon, rowToInsert, rowToRemove []types.Datum, + indexMutations []mutation, indexIDToInfo map[int64]*model.IndexInfo, + indexIDToRowColInfos map[int64][]rowcodec.ColInfo, +) error { + var indexData []types.Datum + for _, m := range indexMutations { + var value []byte + // Generate correct index id for check. + idxID := m.indexID & tablecodec.IndexIDMask + indexInfo, ok := indexIDToInfo[idxID] + if !ok { + return errors.New("index not found") + } + rowColInfos, ok := indexIDToRowColInfos[idxID] + if !ok { + return errors.New("index not found") + } + + var isTmpIdxValAndDeleted bool + // If this is temp index data, need remove last byte of index data. + if idxID != m.indexID { + if tablecodec.TempIndexValueIsUntouched(m.value) { + // We never commit the untouched key values to the storage. Skip this check. + continue + } + tmpVal, err := tablecodec.DecodeTempIndexValue(m.value) + if err != nil { + return err + } + curElem := tmpVal.Current() + isTmpIdxValAndDeleted = curElem.Delete + value = append(value, curElem.Value...) + } else { + value = append(value, m.value...) 
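Editor's note: the added/deleted split that checkIndexKeys implements — deletions are validated against the old row, insertions against the new row — is summarized in the toy below. Column names and string values are illustrative only.

package main

import "fmt"

// indexEntry is a decoded index mutation: its key columns as strings and
// whether it is a deletion.
type indexEntry struct {
	cols    []string
	deleted bool
}

// checkIndexKeys compares deletions against the old row and insertions against
// the new row, which is the added/removed split described above.
func checkIndexKeys(oldRow, newRow map[string]string, idxCols []string, entries []indexEntry) error {
	for _, e := range entries {
		row := newRow
		if e.deleted {
			row = oldRow
		}
		for i, col := range idxCols {
			if row[col] != e.cols[i] {
				return fmt.Errorf("index column %q decoded as %q, row has %q", col, e.cols[i], row[col])
			}
		}
	}
	return nil
}

func main() {
	oldRow := map[string]string{"a": "1", "b": "x"}
	newRow := map[string]string{"a": "1", "b": "y"}
	entries := []indexEntry{
		{cols: []string{"y"}},                // added entry, derived from the new row
		{cols: []string{"x"}, deleted: true}, // removed entry, derived from the old row
	}
	fmt.Println(checkIndexKeys(oldRow, newRow, []string{"b"}, entries)) // <nil>
}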
+ } + + // when we cannot decode the key to get the original value + if len(value) == 0 && NeedRestoredData(indexInfo.Columns, t.Meta().Columns) { + continue + } + + decodedIndexValues, err := tablecodec.DecodeIndexKV( + m.key, value, len(indexInfo.Columns), tablecodec.HandleNotNeeded, rowColInfos, + ) + if err != nil { + return errors.Trace(err) + } + + // reuse the underlying memory, save an allocation + if indexData == nil { + indexData = make([]types.Datum, 0, len(decodedIndexValues)) + } else { + indexData = indexData[:0] + } + + loc := tc.Location() + for i, v := range decodedIndexValues { + fieldType := t.Columns[indexInfo.Columns[i].Offset].FieldType.ArrayType() + datum, err := tablecodec.DecodeColumnValue(v, fieldType, loc) + if err != nil { + return errors.Trace(err) + } + indexData = append(indexData, datum) + } + + // When it is in add index new backfill state. + if len(value) == 0 || isTmpIdxValAndDeleted { + err = compareIndexData(tc, t.Columns, indexData, rowToRemove, indexInfo, t.Meta()) + } else { + err = compareIndexData(tc, t.Columns, indexData, rowToInsert, indexInfo, t.Meta()) + } + if err != nil { + return errors.Trace(err) + } + } + return nil +} + +// checkRowInsertionConsistency checks whether the values of row mutations are consistent with the expected ones +// We only check data added since a deletion of a row doesn't care about its value (and we cannot know it) +func checkRowInsertionConsistency( + sessVars *variable.SessionVars, rowToInsert []types.Datum, rowInsertion mutation, + columnIDToInfo map[int64]*model.ColumnInfo, columnIDToFieldType map[int64]*types.FieldType, tableName string, +) error { + if rowToInsert == nil { + // it's a deletion + return nil + } + + decodedData, err := tablecodec.DecodeRowToDatumMap(rowInsertion.value, columnIDToFieldType, sessVars.Location()) + if err != nil { + return errors.Trace(err) + } + + // NOTE: we cannot check if the decoded values contain all columns since some columns may be skipped. It can even be empty + // Instead, we check that decoded index values are consistent with the input row. 
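Editor's note: the (currently disabled) row-insertion check reads as below when stripped to its core — only the columns actually present in the decoded row value are compared with the input, since the encoder may skip columns. Plain strings stand in for types.Datum.

package main

import "fmt"

// checkRowInsertion compares only the columns present in the decoded row value
// with the values the statement intended to write.
func checkRowInsertion(decoded, input map[int64]string) error {
	for colID, got := range decoded {
		if want, ok := input[colID]; ok && got != want {
			return fmt.Errorf("column %d: row value %q, input %q", colID, got, want)
		}
	}
	return nil
}

func main() {
	input := map[int64]string{1: "a", 2: "b"}
	fmt.Println(checkRowInsertion(map[int64]string{1: "a"}, input))  // <nil>
	fmt.Println(checkRowInsertion(map[int64]string{1: "zz"}, input)) // error
}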
+ + for columnID, decodedDatum := range decodedData { + inputDatum := rowToInsert[columnIDToInfo[columnID].Offset] + cmp, err := decodedDatum.Compare(sessVars.StmtCtx.TypeCtx(), &inputDatum, collate.GetCollator(decodedDatum.Collation())) + if err != nil { + return errors.Trace(err) + } + if cmp != 0 { + err = ErrInconsistentRowValue.GenWithStackByArgs(tableName, inputDatum.String(), decodedDatum.String()) + logutil.BgLogger().Error("inconsistent row value in row insertion", zap.Error(err)) + return err + } + } + return nil +} + +// collectTableMutationsFromBufferStage collects mutations of the current table from the mem buffer stage +// It returns: (1) all index mutations (2) the only row insertion +// If there are no row insertions, the 2nd returned value is nil +// If there are multiple row insertions, an error is returned +func collectTableMutationsFromBufferStage(t *TableCommon, memBuffer kv.MemBuffer, sh kv.StagingHandle) ( + []mutation, mutation, error, +) { + indexMutations := make([]mutation, 0) + var rowInsertion mutation + var err error + inspector := func(key kv.Key, flags kv.KeyFlags, data []byte) { + // only check the current table + if tablecodec.DecodeTableID(key) == t.physicalTableID { + m := mutation{key, flags, data, 0} + if rowcodec.IsRowKey(key) { + if len(data) > 0 { + if rowInsertion.key == nil { + rowInsertion = m + } else { + err = errors.Errorf( + "multiple row mutations added/mutated, one = %+v, another = %+v", rowInsertion, m, + ) + } + } + } else { + _, m.indexID, _, err = tablecodec.DecodeIndexKey(m.key) + if err != nil { + err = errors.Trace(err) + } + indexMutations = append(indexMutations, m) + } + } + } + memBuffer.InspectStage(sh, inspector) + return indexMutations, rowInsertion, err +} + +// compareIndexData compares the decoded index data with the input data. +// Returns error if the index data is not a subset of the input data. +func compareIndexData( + tc types.Context, cols []*table.Column, indexData, input []types.Datum, indexInfo *model.IndexInfo, + tableInfo *model.TableInfo, +) error { + for i := range indexData { + decodedMutationDatum := indexData[i] + expectedDatum := input[indexInfo.Columns[i].Offset] + + tablecodec.TruncateIndexValue( + &expectedDatum, indexInfo.Columns[i], + cols[indexInfo.Columns[i].Offset].ColumnInfo, + ) + tablecodec.TruncateIndexValue( + &decodedMutationDatum, indexInfo.Columns[i], + cols[indexInfo.Columns[i].Offset].ColumnInfo, + ) + + comparison, err := CompareIndexAndVal(tc, expectedDatum, decodedMutationDatum, + collate.GetCollator(decodedMutationDatum.Collation()), + cols[indexInfo.Columns[i].Offset].ColumnInfo.FieldType.IsArray() && expectedDatum.Kind() == types.KindMysqlJSON) + if err != nil { + return errors.Trace(err) + } + + if comparison != 0 { + err = ErrInconsistentIndexedValue.GenWithStackByArgs( + tableInfo.Name.O, indexInfo.Name.O, cols[indexInfo.Columns[i].Offset].ColumnInfo.Name.O, + decodedMutationDatum.String(), expectedDatum.String(), + ) + logutil.BgLogger().Error("inconsistent indexed value in index insertion", zap.Error(err)) + return err + } + } + return nil +} + +// CompareIndexAndVal compare index valued and row value. +func CompareIndexAndVal(tc types.Context, rowVal types.Datum, idxVal types.Datum, collator collate.Collator, cmpMVIndex bool) (int, error) { + var cmpRes int + var err error + if cmpMVIndex { + // If it is multi-valued index, we should check the JSON contains the indexed value. 
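Editor's note: for a multi-valued index the comparison flips from equality to containment, as the comment above says. A hedged toy of that idea using encoding/json rather than TiDB's binary JSON type:

package main

import (
	"encoding/json"
	"fmt"
)

// mvIndexContains reports whether the indexed scalar appears in the row's JSON
// array, mirroring the element-by-element loop in CompareIndexAndVal.
func mvIndexContains(rowJSON string, idxVal any) (bool, error) {
	var elems []any
	if err := json.Unmarshal([]byte(rowJSON), &elems); err != nil {
		return false, err
	}
	for _, e := range elems {
		if e == idxVal {
			return true, nil
		}
	}
	return false, nil
}

func main() {
	ok, _ := mvIndexContains(`[1, 2, 3]`, float64(2)) // json numbers decode as float64
	fmt.Println(ok)                                   // true
}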
+ bj := rowVal.GetMysqlJSON() + count := bj.GetElemCount() + for elemIdx := 0; elemIdx < count; elemIdx++ { + jsonDatum := types.NewJSONDatum(bj.ArrayGetElem(elemIdx)) + cmpRes, err = jsonDatum.Compare(tc, &idxVal, collate.GetBinaryCollator()) + if err != nil { + return 0, errors.Trace(err) + } + if cmpRes == 0 { + break + } + } + } else { + cmpRes, err = idxVal.Compare(tc, &rowVal, collator) + } + return cmpRes, err +} + +// getColumnMaps tries to get the columnMaps from transaction options. If there isn't one, it builds one and stores it. +// It saves redundant computations of the map. +func getColumnMaps(txn kv.Transaction, t *TableCommon) columnMaps { + getter := func() (map[int64]columnMaps, bool) { + m, ok := txn.GetOption(kv.TableToColumnMaps).(map[int64]columnMaps) + return m, ok + } + setter := func(maps map[int64]columnMaps) { + txn.SetOption(kv.TableToColumnMaps, maps) + } + columnMaps := getOrBuildColumnMaps(getter, setter, t) + return columnMaps +} + +// getOrBuildColumnMaps tries to get the columnMaps from some place. If there isn't one, it builds one and stores it. +// It saves redundant computations of the map. +func getOrBuildColumnMaps( + getter func() (map[int64]columnMaps, bool), setter func(map[int64]columnMaps), t *TableCommon, +) columnMaps { + tableMaps, ok := getter() + if !ok || tableMaps == nil { + tableMaps = make(map[int64]columnMaps) + } + maps, ok := tableMaps[t.tableID] + if !ok { + maps = columnMaps{ + make(map[int64]*model.ColumnInfo, len(t.Meta().Columns)), + make(map[int64]*types.FieldType, len(t.Meta().Columns)), + make(map[int64]*model.IndexInfo, len(t.Indices())), + make(map[int64][]rowcodec.ColInfo, len(t.Indices())), + } + + for _, col := range t.Meta().Columns { + maps.ColumnIDToInfo[col.ID] = col + maps.ColumnIDToFieldType[col.ID] = &(col.FieldType) + } + for _, index := range t.Indices() { + if index.Meta().Primary && t.meta.IsCommonHandle { + continue + } + maps.IndexIDToInfo[index.Meta().ID] = index.Meta() + maps.IndexIDToRowColInfos[index.Meta().ID] = BuildRowcodecColInfoForIndexColumns(index.Meta(), t.Meta()) + } + + tableMaps[t.tableID] = maps + setter(tableMaps) + } + return maps +} + +// only used in tests +// commands is a comma separated string, each representing a type of corruptions to the mutations +// The injection depends on actual encoding rules. +func corruptMutations(t *TableCommon, txn kv.Transaction, sh kv.StagingHandle, cmds string) error { + commands := strings.Split(cmds, ",") + memBuffer := txn.GetMemBuffer() + + indexMutations, _, err := collectTableMutationsFromBufferStage(t, memBuffer, sh) + if err != nil { + return errors.Trace(err) + } + + for _, cmd := range commands { + switch cmd { + case "extraIndex": + // an extra index mutation + { + if len(indexMutations) == 0 { + continue + } + indexMutation := indexMutations[0] + key := make([]byte, len(indexMutation.key)) + copy(key, indexMutation.key) + key[len(key)-1]++ + if len(indexMutation.value) == 0 { + if err := memBuffer.Delete(key); err != nil { + return errors.Trace(err) + } + } else { + if err := memBuffer.Set(key, indexMutation.value); err != nil { + return errors.Trace(err) + } + } + } + case "missingIndex": + // an index mutation is missing + // "missIndex" should be placed in front of "extraIndex"es, + // in case it removes the mutation that was just added + { + indexMutation := indexMutations[0] + memBuffer.RemoveFromBuffer(indexMutation.key) + } + case "corruptIndexKey": + // a corrupted index mutation. 
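Editor's note: getColumnMaps/getOrBuildColumnMaps above are a per-transaction memoization — the per-table maps are built once and parked on the transaction through opaque getter/setter hooks. A minimal sketch of that shape, with the map contents simplified:

package main

import "fmt"

type colMaps struct{ colNames map[int64]string }

// getOrBuild returns the cached maps for tableID, building and storing them on
// the first call, exactly like getOrBuildColumnMaps does with the txn option.
func getOrBuild(get func() (map[int64]colMaps, bool), set func(map[int64]colMaps), tableID int64, build func() colMaps) colMaps {
	all, ok := get()
	if !ok || all == nil {
		all = make(map[int64]colMaps)
	}
	m, ok := all[tableID]
	if !ok {
		m = build()
		all[tableID] = m
		set(all)
	}
	return m
}

func main() {
	var cached map[int64]colMaps // stands in for the kv.TableToColumnMaps txn option
	builds := 0
	build := func() colMaps { builds++; return colMaps{colNames: map[int64]string{1: "a"}} }
	get := func() (map[int64]colMaps, bool) { return cached, cached != nil }
	set := func(m map[int64]colMaps) { cached = m }

	getOrBuild(get, set, 42, build)
	getOrBuild(get, set, 42, build)
	fmt.Println(builds) // 1: the second call hits the cache
}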
+ // TODO: distinguish which part is corrupted, value or handle + { + indexMutation := indexMutations[0] + key := indexMutation.key + memBuffer.RemoveFromBuffer(key) + key[len(key)-1]++ + if len(indexMutation.value) == 0 { + if err := memBuffer.Delete(key); err != nil { + return errors.Trace(err) + } + } else { + if err := memBuffer.Set(key, indexMutation.value); err != nil { + return errors.Trace(err) + } + } + } + case "corruptIndexValue": + // TODO: distinguish which part to corrupt, int handle, common handle, or restored data? + // It doesn't make much sense to always corrupt the last byte + { + if len(indexMutations) == 0 { + continue + } + indexMutation := indexMutations[0] + value := indexMutation.value + if len(value) > 0 { + value[len(value)-1]++ + if err := memBuffer.Set(indexMutation.key, value); err != nil { + return errors.Trace(err) + } + } + } + default: + return fmt.Errorf("unknown command to corrupt mutation: %s", cmd) + } + } + return nil +} + +func injectMutationError(t *TableCommon, txn kv.Transaction, sh kv.StagingHandle) error { + failpoint.Inject("corruptMutations", func(commands failpoint.Value) { + failpoint.Return(corruptMutations(t, txn, sh, commands.(string))) + }) + return nil +} diff --git a/pkg/table/tables/tables.go b/pkg/table/tables/tables.go index b44056c4b2613..cd0bb4c0e18f6 100644 --- a/pkg/table/tables/tables.go +++ b/pkg/table/tables/tables.go @@ -544,17 +544,17 @@ func (t *TableCommon) updateRecord(sctx table.MutateContext, h kv.Handle, oldDat return err } - failpoint.Inject("updateRecordForceAssertNotExist", func() { + if _, _err_ := failpoint.Eval(_curpkg_("updateRecordForceAssertNotExist")); _err_ == nil { // Assert the key doesn't exist while it actually exists. This is helpful to test if assertion takes effect. // Since only the first assertion takes effect, set the injected assertion before setting the correct one to // override it. if sctx.ConnectionID() != 0 { logutil.BgLogger().Info("force asserting not exist on UpdateRecord", zap.String("category", "failpoint"), zap.Uint64("startTS", txn.StartTS())) if err = txn.SetAssertion(key, kv.SetAssertNotExist); err != nil { - failpoint.Return(err) + return err } } - }) + } if t.shouldAssert(sctx.TxnAssertionLevel()) { err = txn.SetAssertion(key, kv.SetAssertExist) @@ -930,17 +930,17 @@ func (t *TableCommon) addRecord(sctx table.MutateContext, r []types.Datum, opt * return nil, err } - failpoint.Inject("addRecordForceAssertExist", func() { + if _, _err_ := failpoint.Eval(_curpkg_("addRecordForceAssertExist")); _err_ == nil { // Assert the key exists while it actually doesn't. This is helpful to test if assertion takes effect. // Since only the first assertion takes effect, set the injected assertion before setting the correct one to // override it. if sctx.ConnectionID() != 0 { logutil.BgLogger().Info("force asserting exist on AddRecord", zap.String("category", "failpoint"), zap.Uint64("startTS", txn.StartTS())) if err = txn.SetAssertion(key, kv.SetAssertExist); err != nil { - failpoint.Return(nil, err) + return nil, err } } - }) + } if setPresume && !txn.IsPessimistic() { err = txn.SetAssertion(key, kv.SetAssertUnknown) } else { @@ -1364,17 +1364,17 @@ func (t *TableCommon) removeRowData(ctx table.MutateContext, h kv.Handle) error } key := t.RecordKey(h) - failpoint.Inject("removeRecordForceAssertNotExist", func() { + if _, _err_ := failpoint.Eval(_curpkg_("removeRecordForceAssertNotExist")); _err_ == nil { // Assert the key doesn't exist while it actually exists. 
This is helpful to test if assertion takes effect. // Since only the first assertion takes effect, set the injected assertion before setting the correct one to // override it. if ctx.ConnectionID() != 0 { logutil.BgLogger().Info("force asserting not exist on RemoveRecord", zap.String("category", "failpoint"), zap.Uint64("startTS", txn.StartTS())) if err = txn.SetAssertion(key, kv.SetAssertNotExist); err != nil { - failpoint.Return(err) + return err } } - }) + } if t.shouldAssert(ctx.TxnAssertionLevel()) { err = txn.SetAssertion(key, kv.SetAssertExist) } else { diff --git a/pkg/table/tables/tables.go__failpoint_stash__ b/pkg/table/tables/tables.go__failpoint_stash__ new file mode 100644 index 0000000000000..b44056c4b2613 --- /dev/null +++ b/pkg/table/tables/tables.go__failpoint_stash__ @@ -0,0 +1,2100 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright 2013 The ql Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSES/QL-LICENSE file. + +package tables + +import ( + "context" + "fmt" + "math" + "strconv" + "strings" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/expression" + exprctx "github.com/pingcap/tidb/pkg/expression/context" + "github.com/pingcap/tidb/pkg/kv" + "github.com/pingcap/tidb/pkg/meta" + "github.com/pingcap/tidb/pkg/meta/autoid" + "github.com/pingcap/tidb/pkg/parser/ast" + "github.com/pingcap/tidb/pkg/parser/model" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/sessionctx/binloginfo" + "github.com/pingcap/tidb/pkg/sessionctx/variable" + "github.com/pingcap/tidb/pkg/statistics" + "github.com/pingcap/tidb/pkg/table" + tbctx "github.com/pingcap/tidb/pkg/table/context" + "github.com/pingcap/tidb/pkg/tablecodec" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/codec" + "github.com/pingcap/tidb/pkg/util/collate" + "github.com/pingcap/tidb/pkg/util/generatedexpr" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/stringutil" + "github.com/pingcap/tidb/pkg/util/tableutil" + "github.com/pingcap/tidb/pkg/util/tracing" + "github.com/pingcap/tipb/go-binlog" + "github.com/pingcap/tipb/go-tipb" + "go.uber.org/zap" +) + +// TableCommon is shared by both Table and partition. +// NOTE: when copying this struct, use Copy() to clear its columns cache. +type TableCommon struct { + // TODO: Why do we need tableID, when it is already in meta.ID ? + tableID int64 + // physicalTableID is a unique int64 to identify a physical table. 
+ physicalTableID int64 + Columns []*table.Column + + // column caches + // They are pointers to support copying TableCommon to CachedTable and PartitionedTable + publicColumns []*table.Column + visibleColumns []*table.Column + hiddenColumns []*table.Column + writableColumns []*table.Column + fullHiddenColsAndVisibleColumns []*table.Column + writableConstraints []*table.Constraint + + indices []table.Index + meta *model.TableInfo + allocs autoid.Allocators + sequence *sequenceCommon + dependencyColumnOffsets []int + Constraints []*table.Constraint + + // recordPrefix and indexPrefix are generated using physicalTableID. + recordPrefix kv.Key + indexPrefix kv.Key +} + +// ResetColumnsCache implements testingKnob interface. +func (t *TableCommon) ResetColumnsCache() { + t.publicColumns = t.getCols(full) + t.visibleColumns = t.getCols(visible) + t.hiddenColumns = t.getCols(hidden) + + t.writableColumns = make([]*table.Column, 0, len(t.Columns)) + for _, col := range t.Columns { + if col.State == model.StateDeleteOnly || col.State == model.StateDeleteReorganization { + continue + } + t.writableColumns = append(t.writableColumns, col) + } + + t.fullHiddenColsAndVisibleColumns = make([]*table.Column, 0, len(t.Columns)) + for _, col := range t.Columns { + if col.Hidden || col.State == model.StatePublic { + t.fullHiddenColsAndVisibleColumns = append(t.fullHiddenColsAndVisibleColumns, col) + } + } + + if t.Constraints != nil { + t.writableConstraints = make([]*table.Constraint, 0, len(t.Constraints)) + for _, con := range t.Constraints { + if !con.Enforced { + continue + } + if con.State == model.StateDeleteOnly || con.State == model.StateDeleteReorganization { + continue + } + t.writableConstraints = append(t.writableConstraints, con) + } + } +} + +// Copy copies a TableCommon struct, and reset its column cache. This is not a deep copy. +func (t *TableCommon) Copy() TableCommon { + newTable := *t + return newTable +} + +// MockTableFromMeta only serves for test. +func MockTableFromMeta(tblInfo *model.TableInfo) table.Table { + columns := make([]*table.Column, 0, len(tblInfo.Columns)) + for _, colInfo := range tblInfo.Columns { + col := table.ToColumn(colInfo) + columns = append(columns, col) + } + + constraints, err := table.LoadCheckConstraint(tblInfo) + if err != nil { + return nil + } + var t TableCommon + initTableCommon(&t, tblInfo, tblInfo.ID, columns, autoid.NewAllocators(false), constraints) + if tblInfo.TableCacheStatusType != model.TableCacheStatusDisable { + ret, err := newCachedTable(&t) + if err != nil { + return nil + } + return ret + } + if tblInfo.GetPartitionInfo() == nil { + if err := initTableIndices(&t); err != nil { + return nil + } + return &t + } + + ret, err := newPartitionedTable(&t, tblInfo) + if err != nil { + return nil + } + return ret +} + +// TableFromMeta creates a Table instance from model.TableInfo. +func TableFromMeta(allocs autoid.Allocators, tblInfo *model.TableInfo) (table.Table, error) { + if tblInfo.State == model.StateNone { + return nil, table.ErrTableStateCantNone.GenWithStackByArgs(tblInfo.Name) + } + + colsLen := len(tblInfo.Columns) + columns := make([]*table.Column, 0, colsLen) + for i, colInfo := range tblInfo.Columns { + if colInfo.State == model.StateNone { + return nil, table.ErrColumnStateCantNone.GenWithStackByArgs(colInfo.Name) + } + + // Print some information when the column's offset isn't equal to i. 
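Editor's note: ResetColumnsCache above precomputes the filtered column slices (public, visible, hidden, writable, ...) so hot paths never re-filter by DDL state. A reduced sketch of the same idea with made-up states:

package main

import "fmt"

type colState int

const (
	statePublic colState = iota
	stateDeleteOnly
	stateWriteOnly
)

type column struct {
	name   string
	state  colState
	hidden bool
}

type tableCommon struct {
	columns         []column
	publicColumns   []column
	writableColumns []column
}

// resetColumnsCache rebuilds the cached slices from the column states.
func (t *tableCommon) resetColumnsCache() {
	t.publicColumns = t.publicColumns[:0]
	t.writableColumns = t.writableColumns[:0]
	for _, c := range t.columns {
		if c.state == statePublic && !c.hidden {
			t.publicColumns = append(t.publicColumns, c)
		}
		if c.state != stateDeleteOnly {
			t.writableColumns = append(t.writableColumns, c)
		}
	}
}

func main() {
	t := &tableCommon{columns: []column{{"a", statePublic, false}, {"b", stateDeleteOnly, false}, {"c", stateWriteOnly, false}}}
	t.resetColumnsCache()
	fmt.Println(len(t.publicColumns), len(t.writableColumns)) // 1 2
}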
+ if colInfo.Offset != i { + logutil.BgLogger().Error("wrong table schema", zap.Any("table", tblInfo), zap.Any("column", colInfo), zap.Int("index", i), zap.Int("offset", colInfo.Offset), zap.Int("columnNumber", colsLen)) + } + + col := table.ToColumn(colInfo) + if col.IsGenerated() { + genStr := colInfo.GeneratedExprString + expr, err := buildGeneratedExpr(tblInfo, genStr) + if err != nil { + return nil, err + } + col.GeneratedExpr = table.NewClonableExprNode(func() ast.ExprNode { + newExpr, err1 := buildGeneratedExpr(tblInfo, genStr) + if err1 != nil { + logutil.BgLogger().Warn("unexpected parse generated string error", + zap.String("generatedStr", genStr), + zap.Error(err1)) + return expr + } + return newExpr + }, expr) + } + // default value is expr. + if col.DefaultIsExpr { + expr, err := generatedexpr.ParseExpression(colInfo.DefaultValue.(string)) + if err != nil { + return nil, err + } + col.DefaultExpr = expr + } + columns = append(columns, col) + } + + constraints, err := table.LoadCheckConstraint(tblInfo) + if err != nil { + return nil, err + } + var t TableCommon + initTableCommon(&t, tblInfo, tblInfo.ID, columns, allocs, constraints) + if tblInfo.GetPartitionInfo() == nil { + if err := initTableIndices(&t); err != nil { + return nil, err + } + if tblInfo.TableCacheStatusType != model.TableCacheStatusDisable { + return newCachedTable(&t) + } + return &t, nil + } + return newPartitionedTable(&t, tblInfo) +} + +func buildGeneratedExpr(tblInfo *model.TableInfo, genExpr string) (ast.ExprNode, error) { + expr, err := generatedexpr.ParseExpression(genExpr) + if err != nil { + return nil, err + } + expr, err = generatedexpr.SimpleResolveName(expr, tblInfo) + if err != nil { + return nil, err + } + return expr, nil +} + +// initTableCommon initializes a TableCommon struct. +func initTableCommon(t *TableCommon, tblInfo *model.TableInfo, physicalTableID int64, cols []*table.Column, allocs autoid.Allocators, constraints []*table.Constraint) { + t.tableID = tblInfo.ID + t.physicalTableID = physicalTableID + t.allocs = allocs + t.meta = tblInfo + t.Columns = cols + t.Constraints = constraints + t.recordPrefix = tablecodec.GenTableRecordPrefix(physicalTableID) + t.indexPrefix = tablecodec.GenTableIndexPrefix(physicalTableID) + if tblInfo.IsSequence() { + t.sequence = &sequenceCommon{meta: tblInfo.Sequence} + } + for _, col := range cols { + if col.ChangeStateInfo != nil { + t.dependencyColumnOffsets = append(t.dependencyColumnOffsets, col.ChangeStateInfo.DependencyColumnOffset) + } + } + t.ResetColumnsCache() +} + +// initTableIndices initializes the indices of the TableCommon. +func initTableIndices(t *TableCommon) error { + tblInfo := t.meta + for _, idxInfo := range tblInfo.Indices { + if idxInfo.State == model.StateNone { + return table.ErrIndexStateCantNone.GenWithStackByArgs(idxInfo.Name) + } + + // Use partition ID for index, because TableCommon may be table or partition. + idx := NewIndex(t.physicalTableID, tblInfo, idxInfo) + intest.AssertFunc(func() bool { + // `TableCommon.indices` is type of `[]table.Index` to implement interface method `Table.Indices`. + // However, we have an assumption that the specific type of each element in it should always be `*index`. + // We have this assumption because some codes access the inner method of `*index`, + // and they use `asIndex` to cast `table.Index` to `*index`. 
+ _, ok := idx.(*index) + intest.Assert(ok, "index should be type of `*index`") + return true + }) + t.indices = append(t.indices, idx) + } + return nil +} + +// asIndex casts a table.Index to *index which is the actual type of index in TableCommon. +func asIndex(idx table.Index) *index { + return idx.(*index) +} + +func initTableCommonWithIndices(t *TableCommon, tblInfo *model.TableInfo, physicalTableID int64, cols []*table.Column, allocs autoid.Allocators, constraints []*table.Constraint) error { + initTableCommon(t, tblInfo, physicalTableID, cols, allocs, constraints) + return initTableIndices(t) +} + +// Indices implements table.Table Indices interface. +func (t *TableCommon) Indices() []table.Index { + return t.indices +} + +// GetWritableIndexByName gets the index meta from the table by the index name. +func GetWritableIndexByName(idxName string, t table.Table) table.Index { + for _, idx := range t.Indices() { + if !IsIndexWritable(idx) { + continue + } + if idxName == idx.Meta().Name.L { + return idx + } + } + return nil +} + +// deletableIndices implements table.Table deletableIndices interface. +func (t *TableCommon) deletableIndices() []table.Index { + // All indices are deletable because we don't need to check StateNone. + return t.indices +} + +// Meta implements table.Table Meta interface. +func (t *TableCommon) Meta() *model.TableInfo { + return t.meta +} + +// GetPhysicalID implements table.Table GetPhysicalID interface. +func (t *TableCommon) GetPhysicalID() int64 { + return t.physicalTableID +} + +// GetPartitionedTable implements table.Table GetPhysicalID interface. +func (t *TableCommon) GetPartitionedTable() table.PartitionedTable { + return nil +} + +type getColsMode int64 + +const ( + _ getColsMode = iota + visible + hidden + full +) + +func (t *TableCommon) getCols(mode getColsMode) []*table.Column { + columns := make([]*table.Column, 0, len(t.Columns)) + for _, col := range t.Columns { + if col.State != model.StatePublic { + continue + } + if (mode == visible && col.Hidden) || (mode == hidden && !col.Hidden) { + continue + } + columns = append(columns, col) + } + return columns +} + +// Cols implements table.Table Cols interface. +func (t *TableCommon) Cols() []*table.Column { + return t.publicColumns +} + +// VisibleCols implements table.Table VisibleCols interface. +func (t *TableCommon) VisibleCols() []*table.Column { + return t.visibleColumns +} + +// HiddenCols implements table.Table HiddenCols interface. +func (t *TableCommon) HiddenCols() []*table.Column { + return t.hiddenColumns +} + +// WritableCols implements table WritableCols interface. +func (t *TableCommon) WritableCols() []*table.Column { + return t.writableColumns +} + +// DeletableCols implements table DeletableCols interface. +func (t *TableCommon) DeletableCols() []*table.Column { + return t.Columns +} + +// WritableConstraint returns constraints of the table in writable states. +func (t *TableCommon) WritableConstraint() []*table.Constraint { + if t.Constraints == nil { + return nil + } + return t.writableConstraints +} + +// FullHiddenColsAndVisibleCols implements table FullHiddenColsAndVisibleCols interface. +func (t *TableCommon) FullHiddenColsAndVisibleCols() []*table.Column { + return t.fullHiddenColsAndVisibleColumns +} + +// RecordPrefix implements table.Table interface. +func (t *TableCommon) RecordPrefix() kv.Key { + return t.recordPrefix +} + +// IndexPrefix implements table.Table interface. 
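Editor's note: RecordPrefix/IndexPrefix/RecordKey produce keys in TiDB's familiar layout, roughly t{physicalTableID}_r{handle} for rows and t{physicalTableID}_i{indexID}... for index entries. The toy below shows only the structure; the real code uses tablecodec's order-preserving binary encoding, not fmt.

package main

import "fmt"

// recordPrefix and indexPrefix mimic the shape produced by
// tablecodec.GenTableRecordPrefix / GenTableIndexPrefix.
func recordPrefix(physicalTableID int64) string { return fmt.Sprintf("t%d_r", physicalTableID) }
func indexPrefix(physicalTableID int64) string  { return fmt.Sprintf("t%d_i", physicalTableID) }

func recordKey(physicalTableID, handle int64) string {
	return recordPrefix(physicalTableID) + fmt.Sprintf("%d", handle)
}

func main() {
	fmt.Println(recordKey(45, 7))      // t45_r7
	fmt.Println(indexPrefix(45) + "1") // t45_i1, followed by the encoded index values
}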
+func (t *TableCommon) IndexPrefix() kv.Key { + return t.indexPrefix +} + +// RecordKey implements table.Table interface. +func (t *TableCommon) RecordKey(h kv.Handle) kv.Key { + return tablecodec.EncodeRecordKey(t.recordPrefix, h) +} + +// shouldAssert checks if the partition should be in consistent +// state and can have assertion. +func (t *TableCommon) shouldAssert(level variable.AssertionLevel) bool { + p := t.Meta().Partition + if p != nil { + // This disables asserting during Reorganize Partition. + switch level { + case variable.AssertionLevelFast: + // Fast option, just skip assertion for all partitions. + if p.DDLState != model.StateNone && p.DDLState != model.StatePublic { + return false + } + case variable.AssertionLevelStrict: + // Strict, only disable assertion for intermediate partitions. + // If there were an easy way to get from a TableCommon back to the partitioned table... + for i := range p.AddingDefinitions { + if t.physicalTableID == p.AddingDefinitions[i].ID { + return false + } + } + } + } + return true +} + +// UpdateRecord implements table.Table UpdateRecord interface. +// `touched` means which columns are really modified, used for secondary indices. +// Length of `oldData` and `newData` equals to length of `t.WritableCols()`. +func (t *TableCommon) UpdateRecord(ctx table.MutateContext, h kv.Handle, oldData, newData []types.Datum, touched []bool, opts ...table.UpdateRecordOption) error { + opt := table.NewUpdateRecordOpt(opts...) + return t.updateRecord(ctx, h, oldData, newData, touched, opt) +} + +func (t *TableCommon) updateRecord(sctx table.MutateContext, h kv.Handle, oldData, newData []types.Datum, touched []bool, opt *table.UpdateRecordOpt) error { + txn, err := sctx.Txn(true) + if err != nil { + return err + } + + memBuffer := txn.GetMemBuffer() + sh := memBuffer.Staging() + defer memBuffer.Cleanup(sh) + + if m := t.Meta(); m.TempTableType != model.TempTableNone { + if tmpTable, sizeLimit, ok := addTemporaryTable(sctx, m); ok { + if err = checkTempTableSize(tmpTable, sizeLimit); err != nil { + return err + } + defer handleTempTableSize(tmpTable, txn.Size(), txn) + } + } + + var binlogColIDs []int64 + var binlogOldRow, binlogNewRow []types.Datum + numColsCap := len(newData) + 1 // +1 for the extra handle column that we may need to append. + + // a reusable buffer to save malloc + // Note: The buffer should not be referenced or modified outside this function. + // It can only act as a temporary buffer for the current function call. + mutateBuffers := sctx.GetMutateBuffers() + encodeRowBuffer := mutateBuffers.GetEncodeRowBufferWithCap(numColsCap) + checkRowBuffer := mutateBuffers.GetCheckRowBufferWithCap(numColsCap) + binlogSupport, shouldWriteBinlog := getBinlogSupport(sctx, t.meta) + if shouldWriteBinlog { + binlogColIDs = make([]int64, 0, numColsCap) + binlogOldRow = make([]types.Datum, 0, numColsCap) + binlogNewRow = make([]types.Datum, 0, numColsCap) + } + + for _, col := range t.Columns { + var value types.Datum + if col.State == model.StateDeleteOnly || col.State == model.StateDeleteReorganization { + if col.ChangeStateInfo != nil { + // TODO: Check overflow or ignoreTruncate. 
+ value, err = table.CastColumnValue(sctx.GetExprCtx(), oldData[col.DependencyColumnOffset], col.ColumnInfo, false, false) + if err != nil { + logutil.BgLogger().Info("update record cast value failed", zap.Any("col", col), zap.Uint64("txnStartTS", txn.StartTS()), + zap.String("handle", h.String()), zap.Any("val", oldData[col.DependencyColumnOffset]), zap.Error(err)) + return err + } + oldData = append(oldData, value) + touched = append(touched, touched[col.DependencyColumnOffset]) + } + continue + } + if col.State != model.StatePublic { + // If col is in write only or write reorganization state we should keep the oldData. + // Because the oldData must be the original data(it's changed by other TiDBs.) or the original default value. + // TODO: Use newData directly. + value = oldData[col.Offset] + if col.ChangeStateInfo != nil { + // TODO: Check overflow or ignoreTruncate. + value, err = table.CastColumnValue(sctx.GetExprCtx(), newData[col.DependencyColumnOffset], col.ColumnInfo, false, false) + if err != nil { + return err + } + newData[col.Offset] = value + touched[col.Offset] = touched[col.DependencyColumnOffset] + } + } else { + value = newData[col.Offset] + } + if !t.canSkip(col, &value) { + encodeRowBuffer.AddColVal(col.ID, value) + } + checkRowBuffer.AddColVal(value) + if shouldWriteBinlog && !t.canSkipUpdateBinlog(col, value) { + binlogColIDs = append(binlogColIDs, col.ID) + binlogOldRow = append(binlogOldRow, oldData[col.Offset]) + binlogNewRow = append(binlogNewRow, value) + } + } + // check data constraint + evalCtx := sctx.GetExprCtx().GetEvalCtx() + if constraints := t.WritableConstraint(); len(constraints) > 0 { + if err = table.CheckRowConstraint(evalCtx, constraints, checkRowBuffer.GetRowToCheck()); err != nil { + return err + } + } + // rebuild index + err = t.rebuildUpdateRecordIndices(sctx, txn, h, touched, oldData, newData, opt) + if err != nil { + return err + } + + key := t.RecordKey(h) + tc, ec := evalCtx.TypeCtx(), evalCtx.ErrCtx() + err = encodeRowBuffer.WriteMemBufferEncoded(sctx.GetRowEncodingConfig(), tc.Location(), ec, memBuffer, key) + if err != nil { + return err + } + + failpoint.Inject("updateRecordForceAssertNotExist", func() { + // Assert the key doesn't exist while it actually exists. This is helpful to test if assertion takes effect. + // Since only the first assertion takes effect, set the injected assertion before setting the correct one to + // override it. 
+ if sctx.ConnectionID() != 0 { + logutil.BgLogger().Info("force asserting not exist on UpdateRecord", zap.String("category", "failpoint"), zap.Uint64("startTS", txn.StartTS())) + if err = txn.SetAssertion(key, kv.SetAssertNotExist); err != nil { + failpoint.Return(err) + } + } + }) + + if t.shouldAssert(sctx.TxnAssertionLevel()) { + err = txn.SetAssertion(key, kv.SetAssertExist) + } else { + err = txn.SetAssertion(key, kv.SetAssertUnknown) + } + if err != nil { + return err + } + + if err = injectMutationError(t, txn, sh); err != nil { + return err + } + if sctx.EnableMutationChecker() { + if err = CheckDataConsistency(txn, tc, t, newData, oldData, memBuffer, sh); err != nil { + return errors.Trace(err) + } + } + + memBuffer.Release(sh) + if shouldWriteBinlog { + if !t.meta.PKIsHandle && !t.meta.IsCommonHandle { + binlogColIDs = append(binlogColIDs, model.ExtraHandleID) + binlogOldRow = append(binlogOldRow, types.NewIntDatum(h.IntValue())) + binlogNewRow = append(binlogNewRow, types.NewIntDatum(h.IntValue())) + } + err = t.addUpdateBinlog(sctx, binlogSupport, binlogOldRow, binlogNewRow, binlogColIDs) + if err != nil { + return err + } + } + + if s, ok := sctx.GetStatisticsSupport(); ok { + colSizeBuffer := mutateBuffers.GetColSizeDeltaBufferWithCap(len(t.Cols())) + for id, col := range t.Cols() { + size, err := codec.EstimateValueSize(tc, newData[id]) + if err != nil { + continue + } + newLen := size - 1 + size, err = codec.EstimateValueSize(tc, oldData[id]) + if err != nil { + continue + } + oldLen := size - 1 + colSizeBuffer.AddColSizeDelta(col.ID, int64(newLen-oldLen)) + } + s.UpdatePhysicalTableDelta(t.physicalTableID, 0, 1, colSizeBuffer) + } + return nil +} + +func (t *TableCommon) rebuildUpdateRecordIndices( + ctx table.MutateContext, txn kv.Transaction, + h kv.Handle, touched []bool, oldData []types.Datum, newData []types.Datum, + opt *table.UpdateRecordOpt, +) error { + for _, idx := range t.deletableIndices() { + if t.meta.IsCommonHandle && idx.Meta().Primary { + continue + } + for _, ic := range idx.Meta().Columns { + if !touched[ic.Offset] { + continue + } + oldVs, err := idx.FetchValues(oldData, nil) + if err != nil { + return err + } + if err = t.removeRowIndex(ctx, h, oldVs, idx, txn); err != nil { + return err + } + break + } + } + createIdxOpt := opt.GetCreateIdxOpt() + for _, idx := range t.Indices() { + if !IsIndexWritable(idx) { + continue + } + if t.meta.IsCommonHandle && idx.Meta().Primary { + continue + } + untouched := true + for _, ic := range idx.Meta().Columns { + if !touched[ic.Offset] { + continue + } + untouched = false + break + } + if untouched && opt.SkipWriteUntouchedIndices { + continue + } + newVs, err := idx.FetchValues(newData, nil) + if err != nil { + return err + } + if err := t.buildIndexForRow(ctx, h, newVs, newData, asIndex(idx), txn, untouched, createIdxOpt); err != nil { + return err + } + } + return nil +} + +// FindPrimaryIndex uses to find primary index in tableInfo. +func FindPrimaryIndex(tblInfo *model.TableInfo) *model.IndexInfo { + var pkIdx *model.IndexInfo + for _, idx := range tblInfo.Indices { + if idx.Primary { + pkIdx = idx + break + } + } + return pkIdx +} + +// TryGetCommonPkColumnIds get the IDs of primary key column if the table has common handle. 
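Editor's note: updateRecord (and addRecord/RemoveRecord below) follow the same staging discipline — write into a staging layer, defer Cleanup so any early return discards the writes, and Release only after the assertions and the mutation checker pass. A miniature version of that flow with a stand-in buffer, not kv.MemBuffer:

package main

import "fmt"

// memBuffer is a stand-in for kv.MemBuffer: staged writes become visible only
// on Release; Cleanup after a Release is a harmless no-op here.
type memBuffer struct {
	committed []string
	staged    []string
}

func (b *memBuffer) Staging() int       { return 1 }
func (b *memBuffer) Set(v string)       { b.staged = append(b.staged, v) }
func (b *memBuffer) Cleanup(handle int) { b.staged = nil }
func (b *memBuffer) Release(handle int) {
	b.committed = append(b.committed, b.staged...)
	b.staged = nil
}

func updateRecord(b *memBuffer, checksPass bool) error {
	sh := b.Staging()
	defer b.Cleanup(sh)

	b.Set("new row")
	if !checksPass {
		return fmt.Errorf("mutation checker failed") // staged write is dropped by Cleanup
	}
	b.Release(sh)
	return nil
}

func main() {
	b := &memBuffer{}
	_ = updateRecord(b, false)
	_ = updateRecord(b, true)
	fmt.Println(b.committed) // [new row]
}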
+func TryGetCommonPkColumnIds(tbl *model.TableInfo) []int64 { + if !tbl.IsCommonHandle { + return nil + } + pkIdx := FindPrimaryIndex(tbl) + pkColIDs := make([]int64, 0, len(pkIdx.Columns)) + for _, idxCol := range pkIdx.Columns { + pkColIDs = append(pkColIDs, tbl.Columns[idxCol.Offset].ID) + } + return pkColIDs +} + +// PrimaryPrefixColumnIDs get prefix column ids in primary key. +func PrimaryPrefixColumnIDs(tbl *model.TableInfo) (prefixCols []int64) { + for _, idx := range tbl.Indices { + if !idx.Primary { + continue + } + for _, col := range idx.Columns { + if col.Length > 0 && tbl.Columns[col.Offset].GetFlen() > col.Length { + prefixCols = append(prefixCols, tbl.Columns[col.Offset].ID) + } + } + } + return +} + +// TryGetCommonPkColumns get the primary key columns if the table has common handle. +func TryGetCommonPkColumns(tbl table.Table) []*table.Column { + if !tbl.Meta().IsCommonHandle { + return nil + } + pkIdx := FindPrimaryIndex(tbl.Meta()) + cols := tbl.Cols() + pkCols := make([]*table.Column, 0, len(pkIdx.Columns)) + for _, idxCol := range pkIdx.Columns { + pkCols = append(pkCols, cols[idxCol.Offset]) + } + return pkCols +} + +func addTemporaryTable(sctx table.MutateContext, tblInfo *model.TableInfo) (tbctx.TemporaryTableHandler, int64, bool) { + if s, ok := sctx.GetTemporaryTableSupport(); ok { + if h, ok := s.AddTemporaryTableToTxn(tblInfo); ok { + return h, s.GetTemporaryTableSizeLimit(), ok + } + } + return tbctx.TemporaryTableHandler{}, 0, false +} + +// The size of a temporary table is calculated by accumulating the transaction size delta. +func handleTempTableSize(t tbctx.TemporaryTableHandler, txnSizeBefore int, txn kv.Transaction) { + t.UpdateTxnDeltaSize(txn.Size() - txnSizeBefore) +} + +func checkTempTableSize(tmpTable tbctx.TemporaryTableHandler, sizeLimit int64) error { + if tmpTable.GetCommittedSize()+tmpTable.GetDirtySize() > sizeLimit { + return table.ErrTempTableFull.GenWithStackByArgs(tmpTable.Meta().Name.O) + } + return nil +} + +// AddRecord implements table.Table AddRecord interface. +func (t *TableCommon) AddRecord(sctx table.MutateContext, r []types.Datum, opts ...table.AddRecordOption) (recordID kv.Handle, err error) { + // TODO: optimize the allocation (and calculation) of opt. + opt := table.NewAddRecordOpt(opts...) + return t.addRecord(sctx, r, opt) +} + +func (t *TableCommon) addRecord(sctx table.MutateContext, r []types.Datum, opt *table.AddRecordOpt) (recordID kv.Handle, err error) { + txn, err := sctx.Txn(true) + if err != nil { + return nil, err + } + if m := t.Meta(); m.TempTableType != model.TempTableNone { + if tmpTable, sizeLimit, ok := addTemporaryTable(sctx, m); ok { + if err = checkTempTableSize(tmpTable, sizeLimit); err != nil { + return nil, err + } + defer handleTempTableSize(tmpTable, txn.Size(), txn) + } + } + + var ctx context.Context + if opt.Ctx != nil { + ctx = opt.Ctx + var r tracing.Region + r, ctx = tracing.StartRegionEx(ctx, "table.AddRecord") + defer r.End() + } else { + ctx = context.Background() + } + + evalCtx := sctx.GetExprCtx().GetEvalCtx() + tc, ec := evalCtx.TypeCtx(), evalCtx.ErrCtx() + + var hasRecordID bool + cols := t.Cols() + // opt.IsUpdate is a flag for update. + // If handle ID is changed when update, update will remove the old record first, and then call `AddRecord` to add a new record. + // Currently, only insert can set _tidb_rowid, update can not update _tidb_rowid. + if len(r) > len(cols) && !opt.IsUpdate { + // The last value is _tidb_rowid. 
+ recordID = kv.IntHandle(r[len(r)-1].GetInt64()) + hasRecordID = true + } else { + tblInfo := t.Meta() + txn.CacheTableInfo(t.physicalTableID, tblInfo) + if tblInfo.PKIsHandle { + recordID = kv.IntHandle(r[tblInfo.GetPkColInfo().Offset].GetInt64()) + hasRecordID = true + } else if tblInfo.IsCommonHandle { + pkIdx := FindPrimaryIndex(tblInfo) + pkDts := make([]types.Datum, 0, len(pkIdx.Columns)) + for _, idxCol := range pkIdx.Columns { + pkDts = append(pkDts, r[idxCol.Offset]) + } + tablecodec.TruncateIndexValues(tblInfo, pkIdx, pkDts) + var handleBytes []byte + handleBytes, err = codec.EncodeKey(tc.Location(), nil, pkDts...) + err = ec.HandleError(err) + if err != nil { + return + } + recordID, err = kv.NewCommonHandle(handleBytes) + if err != nil { + return + } + hasRecordID = true + } + } + if !hasRecordID { + if opt.ReserveAutoID > 0 { + // Reserve a batch of auto ID in the statement context. + // The reserved ID could be used in the future within this statement, by the + // following AddRecord() operation. + // Make the IDs continuous benefit for the performance of TiKV. + if reserved, ok := sctx.GetReservedRowIDAlloc(); ok { + var baseRowID, maxRowID int64 + if baseRowID, maxRowID, err = AllocHandleIDs(ctx, sctx, t, uint64(opt.ReserveAutoID)); err != nil { + return nil, err + } + reserved.Reset(baseRowID, maxRowID) + } + } + + recordID, err = AllocHandle(ctx, sctx, t) + if err != nil { + return nil, err + } + } + + // a reusable buffer to save malloc + // Note: The buffer should not be referenced or modified outside this function. + // It can only act as a temporary buffer for the current function call. + mutateBuffers := sctx.GetMutateBuffers() + encodeRowBuffer := mutateBuffers.GetEncodeRowBufferWithCap(len(r)) + memBuffer := txn.GetMemBuffer() + sh := memBuffer.Staging() + defer memBuffer.Cleanup(sh) + + sessVars := sctx.GetSessionVars() + for _, col := range t.Columns { + var value types.Datum + if col.State == model.StateDeleteOnly || col.State == model.StateDeleteReorganization { + continue + } + // In column type change, since we have set the origin default value for changing col, but + // for the new insert statement, we should use the casted value of relative column to insert. + if col.ChangeStateInfo != nil && col.State != model.StatePublic { + // TODO: Check overflow or ignoreTruncate. + value, err = table.CastColumnValue(sctx.GetExprCtx(), r[col.DependencyColumnOffset], col.ColumnInfo, false, false) + if err != nil { + return nil, err + } + if len(r) < len(t.WritableCols()) { + r = append(r, value) + } else { + r[col.Offset] = value + } + encodeRowBuffer.AddColVal(col.ID, value) + continue + } + if col.State == model.StatePublic { + value = r[col.Offset] + } else { + // col.ChangeStateInfo must be nil here. + // because `col.State != model.StatePublic` is true here, if col.ChangeStateInfo is not nil, the col should + // be handle by the previous if-block. + + if opt.IsUpdate { + // If `AddRecord` is called by an update, the default value should be handled the update. + value = r[col.Offset] + } else { + // If `AddRecord` is called by an insert and the col is in write only or write reorganization state, we must + // add it with its default value. + value, err = table.GetColOriginDefaultValue(sctx.GetExprCtx(), col.ToInfo()) + if err != nil { + return nil, err + } + // add value to `r` for dirty db in transaction. + // Otherwise when update will panic cause by get value of column in write only state from dirty db. 
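Editor's note: the branches above that pick recordID amount to a small decision: an explicit _tidb_rowid or an integer primary key becomes the handle directly, a clustered table encodes its PK columns into a common handle, and otherwise a row ID is allocated. A sketch with illustrative types only:

package main

import "fmt"

type rowIDAllocator struct{ next int64 }

func (a *rowIDAllocator) alloc() int64 {
	a.next++
	return a.next
}

// chooseHandle condenses the recordID branches; all parameters are illustrative.
func chooseHandle(explicitRowID, intPK *int64, commonPKCols []string, alloc *rowIDAllocator) string {
	switch {
	case explicitRowID != nil:
		return fmt.Sprintf("int handle %d (explicit _tidb_rowid)", *explicitRowID)
	case intPK != nil:
		return fmt.Sprintf("int handle %d (PKIsHandle)", *intPK)
	case len(commonPKCols) > 0:
		return fmt.Sprintf("common handle %v (encoded clustered PK)", commonPKCols)
	default:
		return fmt.Sprintf("int handle %d (allocated row ID)", alloc.alloc())
	}
}

func main() {
	alloc := &rowIDAllocator{}
	pk := int64(9)
	fmt.Println(chooseHandle(nil, &pk, nil, alloc))
	fmt.Println(chooseHandle(nil, nil, []string{"us", "2024-01-01"}, alloc))
	fmt.Println(chooseHandle(nil, nil, nil, alloc))
}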
+ if col.Offset < len(r) { + r[col.Offset] = value + } else { + r = append(r, value) + } + } + } + if !t.canSkip(col, &value) { + encodeRowBuffer.AddColVal(col.ID, value) + } + } + // check data constraint + if err = table.CheckRowConstraintWithDatum(evalCtx, t.WritableConstraint(), r); err != nil { + return nil, err + } + key := t.RecordKey(recordID) + var setPresume bool + if opt.DupKeyCheck != table.DupKeyCheckSkip { + if t.meta.TempTableType != model.TempTableNone { + // Always check key for temporary table because it does not write to TiKV + _, err = txn.Get(ctx, key) + } else if sctx.GetSessionVars().LazyCheckKeyNotExists() || txn.IsPipelined() { + var v []byte + v, err = txn.GetMemBuffer().GetLocal(ctx, key) + if err != nil { + setPresume = true + } + if err == nil && len(v) == 0 { + err = kv.ErrNotExist + } + } else { + _, err = txn.Get(ctx, key) + } + if err == nil { + dupErr := getDuplicateError(t.Meta(), recordID, r) + return recordID, dupErr + } else if !kv.ErrNotExist.Equal(err) { + return recordID, err + } + } + + var flags []kv.FlagsOp + if setPresume { + flags = []kv.FlagsOp{kv.SetPresumeKeyNotExists} + if !sessVars.ConstraintCheckInPlacePessimistic && sessVars.TxnCtx.IsPessimistic && sessVars.InTxn() && + !sctx.InRestrictedSQL() && sctx.ConnectionID() > 0 { + flags = append(flags, kv.SetNeedConstraintCheckInPrewrite) + } + } + + err = encodeRowBuffer.WriteMemBufferEncoded(sctx.GetRowEncodingConfig(), tc.Location(), ec, memBuffer, key, flags...) + if err != nil { + return nil, err + } + + failpoint.Inject("addRecordForceAssertExist", func() { + // Assert the key exists while it actually doesn't. This is helpful to test if assertion takes effect. + // Since only the first assertion takes effect, set the injected assertion before setting the correct one to + // override it. + if sctx.ConnectionID() != 0 { + logutil.BgLogger().Info("force asserting exist on AddRecord", zap.String("category", "failpoint"), zap.Uint64("startTS", txn.StartTS())) + if err = txn.SetAssertion(key, kv.SetAssertExist); err != nil { + failpoint.Return(nil, err) + } + } + }) + if setPresume && !txn.IsPessimistic() { + err = txn.SetAssertion(key, kv.SetAssertUnknown) + } else { + err = txn.SetAssertion(key, kv.SetAssertNotExist) + } + if err != nil { + return nil, err + } + + // Insert new entries into indices. + h, err := t.addIndices(sctx, recordID, r, txn, opt.GetCreateIdxOpt()) + if err != nil { + return h, err + } + + if err = injectMutationError(t, txn, sh); err != nil { + return nil, err + } + if sctx.EnableMutationChecker() { + if err = CheckDataConsistency(txn, tc, t, r, nil, memBuffer, sh); err != nil { + return nil, errors.Trace(err) + } + } + + memBuffer.Release(sh) + + binlogSupport, shouldWriteBinlog := getBinlogSupport(sctx, t.meta) + if shouldWriteBinlog { + // For insert, TiDB and Binlog can use same row and schema. + err = t.addInsertBinlog(sctx, binlogSupport, recordID, encodeRowBuffer) + if err != nil { + return nil, err + } + } + + if s, ok := sctx.GetStatisticsSupport(); ok { + colSizeBuffer := sctx.GetMutateBuffers().GetColSizeDeltaBufferWithCap(len(t.Cols())) + for id, col := range t.Cols() { + size, err := codec.EstimateValueSize(tc, r[id]) + if err != nil { + continue + } + colSizeBuffer.AddColSizeDelta(col.ID, int64(size-1)) + } + s.UpdatePhysicalTableDelta(t.physicalTableID, 1, 1, colSizeBuffer) + } + return recordID, nil +} + +// genIndexKeyStrs generates index content strings representation. 
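Editor's note: the duplicate-key handling above has two modes — the eager path reads the key (possibly from TiKV) before writing, while the lazy path only consults the local mem-buffer and tags the write "presume key not exists" so any conflict is reported at prewrite. A toy of the two paths:

package main

import "fmt"

type writeIntent struct {
	key     string
	presume bool // SetPresumeKeyNotExists in the real code
}

// addKey returns the intent for a new key, or a duplicate-key error.
func addKey(local, remote map[string]string, lazy bool, key string) (writeIntent, error) {
	if _, ok := local[key]; ok {
		return writeIntent{}, fmt.Errorf("duplicate key %q", key)
	}
	if lazy {
		// Skip the remote lookup; TiKV re-checks existence at prewrite.
		return writeIntent{key: key, presume: true}, nil
	}
	if _, ok := remote[key]; ok {
		return writeIntent{}, fmt.Errorf("duplicate key %q", key)
	}
	return writeIntent{key: key}, nil
}

func main() {
	local := map[string]string{}
	remote := map[string]string{"t1_r1": "row"} // already exists in the store
	fmt.Println(addKey(local, remote, true, "t1_r1"))  // presume=true: conflict surfaces at prewrite
	fmt.Println(addKey(local, remote, false, "t1_r1")) // immediate duplicate-key error
}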
+func genIndexKeyStrs(colVals []types.Datum) ([]string, error) { + // Pass pre-composed error to txn. + strVals := make([]string, 0, len(colVals)) + for _, cv := range colVals { + cvs := "NULL" + var err error + if !cv.IsNull() { + cvs, err = types.ToString(cv.GetValue()) + if err != nil { + return nil, err + } + } + strVals = append(strVals, cvs) + } + return strVals, nil +} + +// addIndices adds data into indices. If any key is duplicated, returns the original handle. +func (t *TableCommon) addIndices(sctx table.MutateContext, recordID kv.Handle, r []types.Datum, txn kv.Transaction, opt *table.CreateIdxOpt) (kv.Handle, error) { + writeBufs := sctx.GetMutateBuffers().GetWriteStmtBufs() + indexVals := writeBufs.IndexValsBuf + skipCheck := opt.DupKeyCheck == table.DupKeyCheckSkip + for _, v := range t.Indices() { + if !IsIndexWritable(v) { + continue + } + if t.meta.IsCommonHandle && v.Meta().Primary { + continue + } + // We declared `err` here to make sure `indexVals` is assigned with `=` instead of `:=`. + // The latter one will create a new variable that shadows the outside `indexVals` that makes `indexVals` outside + // always nil, and we cannot reuse it. + var err error + indexVals, err = v.FetchValues(r, indexVals) + if err != nil { + return nil, err + } + var dupErr error + if !skipCheck && v.Meta().Unique { + // Make error message consistent with MySQL. + tablecodec.TruncateIndexValues(t.meta, v.Meta(), indexVals) + colStrVals, err := genIndexKeyStrs(indexVals) + if err != nil { + return nil, err + } + dupErr = kv.GenKeyExistsErr(colStrVals, fmt.Sprintf("%s.%s", v.TableMeta().Name.String(), v.Meta().Name.String())) + } + rsData := TryGetHandleRestoredDataWrapper(t.meta, r, nil, v.Meta()) + if dupHandle, err := asIndex(v).create(sctx, txn, indexVals, recordID, rsData, false, opt); err != nil { + if kv.ErrKeyExists.Equal(err) { + return dupHandle, dupErr + } + return nil, err + } + } + // save the buffer, multi rows insert can use it. + writeBufs.IndexValsBuf = indexVals + return nil, nil +} + +// RowWithCols is used to get the corresponding column datum values with the given handle. +func RowWithCols(t table.Table, ctx sessionctx.Context, h kv.Handle, cols []*table.Column) ([]types.Datum, error) { + // Get raw row data from kv. + key := tablecodec.EncodeRecordKey(t.RecordPrefix(), h) + txn, err := ctx.Txn(true) + if err != nil { + return nil, err + } + value, err := txn.Get(context.TODO(), key) + if err != nil { + return nil, err + } + v, _, err := DecodeRawRowData(ctx, t.Meta(), h, cols, value) + if err != nil { + return nil, err + } + return v, nil +} + +func containFullColInHandle(meta *model.TableInfo, col *table.Column) (containFullCol bool, idxInHandle int) { + pkIdx := FindPrimaryIndex(meta) + for i, idxCol := range pkIdx.Columns { + if meta.Columns[idxCol.Offset].ID == col.ID { + idxInHandle = i + containFullCol = idxCol.Length == types.UnspecifiedLength + return + } + } + return +} + +// DecodeRawRowData decodes raw row data into a datum slice and a (columnID:columnValue) map. 
+func DecodeRawRowData(ctx sessionctx.Context, meta *model.TableInfo, h kv.Handle, cols []*table.Column, + value []byte) ([]types.Datum, map[int64]types.Datum, error) { + v := make([]types.Datum, len(cols)) + colTps := make(map[int64]*types.FieldType, len(cols)) + prefixCols := make(map[int64]struct{}) + for i, col := range cols { + if col == nil { + continue + } + if col.IsPKHandleColumn(meta) { + if mysql.HasUnsignedFlag(col.GetFlag()) { + v[i].SetUint64(uint64(h.IntValue())) + } else { + v[i].SetInt64(h.IntValue()) + } + continue + } + if col.IsCommonHandleColumn(meta) && !types.NeedRestoredData(&col.FieldType) { + if containFullCol, idxInHandle := containFullColInHandle(meta, col); containFullCol { + dtBytes := h.EncodedCol(idxInHandle) + _, dt, err := codec.DecodeOne(dtBytes) + if err != nil { + return nil, nil, err + } + dt, err = tablecodec.Unflatten(dt, &col.FieldType, ctx.GetSessionVars().Location()) + if err != nil { + return nil, nil, err + } + v[i] = dt + continue + } + prefixCols[col.ID] = struct{}{} + } + colTps[col.ID] = &col.FieldType + } + rowMap, err := tablecodec.DecodeRowToDatumMap(value, colTps, ctx.GetSessionVars().Location()) + if err != nil { + return nil, rowMap, err + } + defaultVals := make([]types.Datum, len(cols)) + for i, col := range cols { + if col == nil { + continue + } + if col.IsPKHandleColumn(meta) || (col.IsCommonHandleColumn(meta) && !types.NeedRestoredData(&col.FieldType)) { + if _, isPrefix := prefixCols[col.ID]; !isPrefix { + continue + } + } + ri, ok := rowMap[col.ID] + if ok { + v[i] = ri + continue + } + if col.IsVirtualGenerated() { + continue + } + if col.ChangeStateInfo != nil { + v[i], _, err = GetChangingColVal(ctx.GetExprCtx(), cols, col, rowMap, defaultVals) + } else { + v[i], err = GetColDefaultValue(ctx.GetExprCtx(), col, defaultVals) + } + if err != nil { + return nil, rowMap, err + } + } + return v, rowMap, nil +} + +// GetChangingColVal gets the changing column value when executing "modify/change column" statement. +// For statement like update-where, it will fetch the old row out and insert it into kv again. +// Since update statement can see the writable columns, it is responsible for the casting relative column / get the fault value here. +// old row : a-b-[nil] +// new row : a-b-[a'/default] +// Thus the writable new row is corresponding to Write-Only constraints. +func GetChangingColVal(ctx exprctx.BuildContext, cols []*table.Column, col *table.Column, rowMap map[int64]types.Datum, defaultVals []types.Datum) (_ types.Datum, isDefaultVal bool, err error) { + relativeCol := cols[col.ChangeStateInfo.DependencyColumnOffset] + idxColumnVal, ok := rowMap[relativeCol.ID] + if ok { + idxColumnVal, err = table.CastColumnValue(ctx, idxColumnVal, col.ColumnInfo, false, false) + // TODO: Consider sql_mode and the error msg(encounter this error check whether to rollback). + if err != nil { + return idxColumnVal, false, errors.Trace(err) + } + return idxColumnVal, false, nil + } + + idxColumnVal, err = GetColDefaultValue(ctx, col, defaultVals) + if err != nil { + return idxColumnVal, false, errors.Trace(err) + } + + return idxColumnVal, true, nil +} + +// RemoveRecord implements table.Table RemoveRecord interface. 
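Editor's note: DecodeRawRowData resolves each requested column from one of several places; the compressed sketch below shows that fallback order (clustered PK columns from the handle itself, then the decoded row map, then the column default, with virtual generated columns left for later evaluation). Types are simplified stand-ins.

package main

import "fmt"

type column struct {
	id         int64
	inHandle   bool
	virtualGen bool
	defaultVal string
}

// colValue returns the value for one column and whether it was resolved here.
func colValue(col column, handleVals, rowMap map[int64]string) (string, bool) {
	if col.inHandle {
		return handleVals[col.id], true
	}
	if v, ok := rowMap[col.id]; ok {
		return v, true
	}
	if col.virtualGen {
		return "", false // computed later from other columns
	}
	return col.defaultVal, true
}

func main() {
	handleVals := map[int64]string{1: "pk-val"}
	rowMap := map[int64]string{2: "stored"}
	cols := []column{{id: 1, inHandle: true}, {id: 2}, {id: 3, defaultVal: "def"}, {id: 4, virtualGen: true}}
	for _, c := range cols {
		v, ok := colValue(c, handleVals, rowMap)
		fmt.Println(c.id, v, ok)
	}
}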
+func (t *TableCommon) RemoveRecord(ctx table.MutateContext, h kv.Handle, r []types.Datum) error { + txn, err := ctx.Txn(true) + if err != nil { + return err + } + + memBuffer := txn.GetMemBuffer() + sh := memBuffer.Staging() + defer memBuffer.Cleanup(sh) + + err = t.removeRowData(ctx, h) + if err != nil { + return err + } + + if m := t.Meta(); m.TempTableType != model.TempTableNone { + if tmpTable, sizeLimit, ok := addTemporaryTable(ctx, m); ok { + if err = checkTempTableSize(tmpTable, sizeLimit); err != nil { + return err + } + defer handleTempTableSize(tmpTable, txn.Size(), txn) + } + } + + // The table has non-public column and this column is doing the operation of "modify/change column". + if len(t.Columns) > len(r) && t.Columns[len(r)].ChangeStateInfo != nil { + // The changing column datum derived from related column should be casted here. + // Otherwise, the existed changing indexes will not be deleted. + relatedColDatum := r[t.Columns[len(r)].ChangeStateInfo.DependencyColumnOffset] + value, err := table.CastColumnValue(ctx.GetExprCtx(), relatedColDatum, t.Columns[len(r)].ColumnInfo, false, false) + if err != nil { + logutil.BgLogger().Info("remove record cast value failed", zap.Any("col", t.Columns[len(r)]), + zap.String("handle", h.String()), zap.Any("val", relatedColDatum), zap.Error(err)) + return err + } + r = append(r, value) + } + err = t.removeRowIndices(ctx, h, r) + if err != nil { + return err + } + + if err = injectMutationError(t, txn, sh); err != nil { + return err + } + + tc := ctx.GetExprCtx().GetEvalCtx().TypeCtx() + if ctx.EnableMutationChecker() { + if err = CheckDataConsistency(txn, tc, t, nil, r, memBuffer, sh); err != nil { + return errors.Trace(err) + } + } + memBuffer.Release(sh) + + binlogSupport, shouldWriteBinlog := getBinlogSupport(ctx, t.meta) + if shouldWriteBinlog { + cols := t.DeletableCols() + colIDs := make([]int64, 0, len(cols)+1) + for _, col := range cols { + colIDs = append(colIDs, col.ID) + } + var binlogRow []types.Datum + if !t.meta.PKIsHandle && !t.meta.IsCommonHandle { + colIDs = append(colIDs, model.ExtraHandleID) + binlogRow = make([]types.Datum, 0, len(r)+1) + binlogRow = append(binlogRow, r...) + handleData, err := h.Data() + if err != nil { + return err + } + binlogRow = append(binlogRow, handleData...) + } else { + binlogRow = r + } + err = t.addDeleteBinlog(ctx, binlogSupport, binlogRow, colIDs) + } + + if s, ok := ctx.GetStatisticsSupport(); ok { + // a reusable buffer to save malloc + // Note: The buffer should not be referenced or modified outside this function. + // It can only act as a temporary buffer for the current function call. + colSizeBuffer := ctx.GetMutateBuffers().GetColSizeDeltaBufferWithCap(len(t.Cols())) + for id, col := range t.Cols() { + size, err := codec.EstimateValueSize(tc, r[id]) + if err != nil { + continue + } + colSizeBuffer.AddColSizeDelta(col.ID, -int64(size-1)) + } + s.UpdatePhysicalTableDelta( + t.physicalTableID, -1, 1, colSizeBuffer, + ) + } + return err +} + +func (t *TableCommon) addInsertBinlog(ctx table.MutateContext, support tbctx.BinlogSupport, h kv.Handle, encodeRowBuffer *tbctx.EncodeRowBuffer) error { + evalCtx := ctx.GetExprCtx().GetEvalCtx() + loc, ec := evalCtx.Location(), evalCtx.ErrCtx() + handleData, err := h.Data() + if err != nil { + return err + } + pk, err := codec.EncodeValue(loc, nil, handleData...) 
+ err = ec.HandleError(err) + if err != nil { + return err + } + value, err := encodeRowBuffer.EncodeBinlogRowData(loc, ec) + if err != nil { + return err + } + bin := append(pk, value...) + mutation := support.GetBinlogMutation(t.tableID) + mutation.InsertedRows = append(mutation.InsertedRows, bin) + mutation.Sequence = append(mutation.Sequence, binlog.MutationType_Insert) + return nil +} + +func (t *TableCommon) addUpdateBinlog(ctx table.MutateContext, support tbctx.BinlogSupport, oldRow, newRow []types.Datum, colIDs []int64) error { + evalCtx := ctx.GetExprCtx().GetEvalCtx() + loc, ec := evalCtx.Location(), evalCtx.ErrCtx() + old, err := tablecodec.EncodeOldRow(loc, oldRow, colIDs, nil, nil) + err = ec.HandleError(err) + if err != nil { + return err + } + newVal, err := tablecodec.EncodeOldRow(loc, newRow, colIDs, nil, nil) + err = ec.HandleError(err) + if err != nil { + return err + } + bin := append(old, newVal...) + mutation := support.GetBinlogMutation(t.tableID) + mutation.UpdatedRows = append(mutation.UpdatedRows, bin) + mutation.Sequence = append(mutation.Sequence, binlog.MutationType_Update) + return nil +} + +func (t *TableCommon) addDeleteBinlog(ctx table.MutateContext, support tbctx.BinlogSupport, r []types.Datum, colIDs []int64) error { + evalCtx := ctx.GetExprCtx().GetEvalCtx() + loc, ec := evalCtx.Location(), evalCtx.ErrCtx() + data, err := tablecodec.EncodeOldRow(loc, r, colIDs, nil, nil) + err = ec.HandleError(err) + if err != nil { + return err + } + mutation := support.GetBinlogMutation(t.tableID) + mutation.DeletedRows = append(mutation.DeletedRows, data) + mutation.Sequence = append(mutation.Sequence, binlog.MutationType_DeleteRow) + return nil +} + +func writeSequenceUpdateValueBinlog(sctx sessionctx.Context, db, sequence string, end int64) error { + // 1: when sequenceCommon update the local cache passively. + // 2: When sequenceCommon setval to the allocator actively. + // Both of this two case means the upper bound the sequence has changed in meta, which need to write the binlog + // to the downstream. + // Sequence sends `select setval(seq, num)` sql string to downstream via `setDDLBinlog`, which is mocked as a DDL binlog. + binlogCli := sctx.GetSessionVars().BinlogClient + sqlMode := sctx.GetSessionVars().SQLMode + sequenceFullName := stringutil.Escape(db, sqlMode) + "." + stringutil.Escape(sequence, sqlMode) + sql := "select setval(" + sequenceFullName + ", " + strconv.FormatInt(end, 10) + ")" + + ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnMeta) + err := kv.RunInNewTxn(ctx, sctx.GetStore(), true, func(ctx context.Context, txn kv.Transaction) error { + m := meta.NewMeta(txn) + mockJobID, err := m.GenGlobalID() + if err != nil { + return err + } + binloginfo.SetDDLBinlog(binlogCli, txn, mockJobID, int32(model.StatePublic), sql) + return nil + }) + return err +} + +func (t *TableCommon) removeRowData(ctx table.MutateContext, h kv.Handle) error { + // Remove row data. + txn, err := ctx.Txn(true) + if err != nil { + return err + } + + key := t.RecordKey(h) + failpoint.Inject("removeRecordForceAssertNotExist", func() { + // Assert the key doesn't exist while it actually exists. This is helpful to test if assertion takes effect. + // Since only the first assertion takes effect, set the injected assertion before setting the correct one to + // override it. 
+ if ctx.ConnectionID() != 0 { + logutil.BgLogger().Info("force asserting not exist on RemoveRecord", zap.String("category", "failpoint"), zap.Uint64("startTS", txn.StartTS())) + if err = txn.SetAssertion(key, kv.SetAssertNotExist); err != nil { + failpoint.Return(err) + } + } + }) + if t.shouldAssert(ctx.TxnAssertionLevel()) { + err = txn.SetAssertion(key, kv.SetAssertExist) + } else { + err = txn.SetAssertion(key, kv.SetAssertUnknown) + } + if err != nil { + return err + } + return txn.Delete(key) +} + +// removeRowIndices removes all the indices of a row. +func (t *TableCommon) removeRowIndices(ctx table.MutateContext, h kv.Handle, rec []types.Datum) error { + txn, err := ctx.Txn(true) + if err != nil { + return err + } + for _, v := range t.deletableIndices() { + if v.Meta().Primary && (t.Meta().IsCommonHandle || t.Meta().PKIsHandle) { + continue + } + vals, err := v.FetchValues(rec, nil) + if err != nil { + logutil.BgLogger().Info("remove row index failed", zap.Any("index", v.Meta()), zap.Uint64("txnStartTS", txn.StartTS()), zap.String("handle", h.String()), zap.Any("record", rec), zap.Error(err)) + return err + } + if err = v.Delete(ctx, txn, vals, h); err != nil { + if v.Meta().State != model.StatePublic && kv.ErrNotExist.Equal(err) { + // If the index is not in public state, we may have not created the index, + // or already deleted the index, so skip ErrNotExist error. + logutil.BgLogger().Debug("row index not exists", zap.Any("index", v.Meta()), zap.Uint64("txnStartTS", txn.StartTS()), zap.String("handle", h.String())) + continue + } + return err + } + } + return nil +} + +// removeRowIndex implements table.Table RemoveRowIndex interface. +func (t *TableCommon) removeRowIndex(ctx table.MutateContext, h kv.Handle, vals []types.Datum, idx table.Index, txn kv.Transaction) error { + return idx.Delete(ctx, txn, vals, h) +} + +// buildIndexForRow implements table.Table BuildIndexForRow interface. +func (t *TableCommon) buildIndexForRow(ctx table.MutateContext, h kv.Handle, vals []types.Datum, newData []types.Datum, idx *index, txn kv.Transaction, untouched bool, opt *table.CreateIdxOpt) error { + rsData := TryGetHandleRestoredDataWrapper(t.meta, newData, nil, idx.Meta()) + if _, err := idx.create(ctx, txn, vals, h, rsData, untouched, opt); err != nil { + if kv.ErrKeyExists.Equal(err) { + // Make error message consistent with MySQL. + tablecodec.TruncateIndexValues(t.meta, idx.Meta(), vals) + colStrVals, err1 := genIndexKeyStrs(vals) + if err1 != nil { + // if genIndexKeyStrs failed, return the original error. + return err + } + + return kv.GenKeyExistsErr(colStrVals, fmt.Sprintf("%s.%s", idx.TableMeta().Name.String(), idx.Meta().Name.String())) + } + return err + } + return nil +} + +// IterRecords iterates records in the table and calls fn. 
+func IterRecords(t table.Table, ctx sessionctx.Context, cols []*table.Column, + fn table.RecordIterFunc) error { + prefix := t.RecordPrefix() + txn, err := ctx.Txn(true) + if err != nil { + return err + } + + startKey := tablecodec.EncodeRecordKey(t.RecordPrefix(), kv.IntHandle(math.MinInt64)) + it, err := txn.Iter(startKey, prefix.PrefixNext()) + if err != nil { + return err + } + defer it.Close() + + if !it.Valid() { + return nil + } + + logutil.BgLogger().Debug("iterate records", zap.ByteString("startKey", startKey), zap.ByteString("key", it.Key()), zap.ByteString("value", it.Value())) + + colMap := make(map[int64]*types.FieldType, len(cols)) + for _, col := range cols { + colMap[col.ID] = &col.FieldType + } + defaultVals := make([]types.Datum, len(cols)) + for it.Valid() && it.Key().HasPrefix(prefix) { + // first kv pair is row lock information. + // TODO: check valid lock + // get row handle + handle, err := tablecodec.DecodeRowKey(it.Key()) + if err != nil { + return err + } + rowMap, err := tablecodec.DecodeRowToDatumMap(it.Value(), colMap, ctx.GetSessionVars().Location()) + if err != nil { + return err + } + pkIds, decodeLoc := TryGetCommonPkColumnIds(t.Meta()), ctx.GetSessionVars().Location() + data := make([]types.Datum, len(cols)) + for _, col := range cols { + if col.IsPKHandleColumn(t.Meta()) { + if mysql.HasUnsignedFlag(col.GetFlag()) { + data[col.Offset].SetUint64(uint64(handle.IntValue())) + } else { + data[col.Offset].SetInt64(handle.IntValue()) + } + continue + } else if mysql.HasPriKeyFlag(col.GetFlag()) { + data[col.Offset], err = tryDecodeColumnFromCommonHandle(col, handle, pkIds, decodeLoc) + if err != nil { + return err + } + continue + } + if _, ok := rowMap[col.ID]; ok { + data[col.Offset] = rowMap[col.ID] + continue + } + data[col.Offset], err = GetColDefaultValue(ctx.GetExprCtx(), col, defaultVals) + if err != nil { + return err + } + } + more, err := fn(handle, data, cols) + if !more || err != nil { + return err + } + + rk := tablecodec.EncodeRecordKey(t.RecordPrefix(), handle) + err = kv.NextUntil(it, util.RowKeyPrefixFilter(rk)) + if err != nil { + return err + } + } + + return nil +} + +func tryDecodeColumnFromCommonHandle(col *table.Column, handle kv.Handle, pkIds []int64, decodeLoc *time.Location) (types.Datum, error) { + for i, hid := range pkIds { + if hid != col.ID { + continue + } + _, d, err := codec.DecodeOne(handle.EncodedCol(i)) + if err != nil { + return types.Datum{}, errors.Trace(err) + } + if d, err = tablecodec.Unflatten(d, &col.FieldType, decodeLoc); err != nil { + return types.Datum{}, err + } + return d, nil + } + return types.Datum{}, nil +} + +// GetColDefaultValue gets a column default value. +// The defaultVals is used to avoid calculating the default value multiple times. +func GetColDefaultValue(ctx exprctx.BuildContext, col *table.Column, defaultVals []types.Datum) ( + colVal types.Datum, err error) { + if col.GetOriginDefaultValue() == nil && mysql.HasNotNullFlag(col.GetFlag()) { + return colVal, errors.New("Miss column") + } + if defaultVals[col.Offset].IsNull() { + colVal, err = table.GetColOriginDefaultValue(ctx, col.ToInfo()) + if err != nil { + return colVal, err + } + defaultVals[col.Offset] = colVal + } else { + colVal = defaultVals[col.Offset] + } + + return colVal, nil +} + +// AllocHandle allocate a new handle. +// A statement could reserve some ID in the statement context, try those ones first. 
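IterRecords above walks every record key under the table prefix and hands each decoded row to a caller-supplied callback, filling handle columns, stored values, and column defaults before the call. A hedged usage sketch from outside the package (tbl and sctx stand for an already-open table.Table and sessionctx.Context; only the callback shape is taken from the signature above):

    rowCount := 0
    err := tables.IterRecords(tbl, sctx, tbl.Cols(),
        func(h kv.Handle, rec []types.Datum, cols []*table.Column) (bool, error) {
            // rec is ordered by column offset; return false to stop the scan early.
            rowCount++
            return true, nil
        })
    if err != nil {
        return err
    }
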
+func AllocHandle(ctx context.Context, mctx table.MutateContext, t table.Table) (kv.IntHandle, + error) { + if mctx != nil { + if reserved, ok := mctx.GetReservedRowIDAlloc(); ok { + // First try to alloc if the statement has reserved auto ID. + if rowID, ok := reserved.Consume(); ok { + return kv.IntHandle(rowID), nil + } + } + } + + _, rowID, err := AllocHandleIDs(ctx, mctx, t, 1) + return kv.IntHandle(rowID), err +} + +// AllocHandleIDs allocates n handle ids (_tidb_rowid), and caches the range +// in the table.MutateContext. +func AllocHandleIDs(ctx context.Context, mctx table.MutateContext, t table.Table, n uint64) (int64, int64, error) { + meta := t.Meta() + base, maxID, err := t.Allocators(mctx).Get(autoid.RowIDAllocType).Alloc(ctx, n, 1, 1) + if err != nil { + return 0, 0, err + } + if meta.ShardRowIDBits > 0 { + shardFmt := autoid.NewShardIDFormat(types.NewFieldType(mysql.TypeLonglong), meta.ShardRowIDBits, autoid.RowIDBitLength) + // Use max record ShardRowIDBits to check overflow. + if OverflowShardBits(maxID, meta.MaxShardRowIDBits, autoid.RowIDBitLength, true) { + // If overflow, the rowID may be duplicated. For examples, + // t.meta.ShardRowIDBits = 4 + // rowID = 0010111111111111111111111111111111111111111111111111111111111111 + // shard = 0100000000000000000000000000000000000000000000000000000000000000 + // will be duplicated with: + // rowID = 0100111111111111111111111111111111111111111111111111111111111111 + // shard = 0010000000000000000000000000000000000000000000000000000000000000 + return 0, 0, autoid.ErrAutoincReadFailed + } + shard := mctx.GetRowIDShardGenerator().GetCurrentShard(int(n)) + base = shardFmt.Compose(shard, base) + maxID = shardFmt.Compose(shard, maxID) + } + return base, maxID, nil +} + +// OverflowShardBits checks whether the recordID overflow `1<<(typeBitsLength-shardRowIDBits-1) -1`. +func OverflowShardBits(recordID int64, shardRowIDBits uint64, typeBitsLength uint64, reservedSignBit bool) bool { + var signBit uint64 + if reservedSignBit { + signBit = 1 + } + mask := (1<<shardRowIDBits - 1) << (typeBitsLength - shardRowIDBits - signBit) + return recordID&int64(mask) > 0 +} + +// Allocators implements table.Table Allocators interface. +func (t *TableCommon) Allocators(ctx table.AllocatorContext) autoid.Allocators { + if ctx == nil { + return t.allocs + } + if alloc, ok := ctx.AlternativeAllocators(t.meta); ok { + return alloc + } + return t.allocs +} + +// Type implements table.Table Type interface. +func (t *TableCommon) Type() table.Type { + return table.NormalTable +} + +func getBinlogSupport(ctx table.MutateContext, tblInfo *model.TableInfo) (tbctx.BinlogSupport, bool) { + if tblInfo.TempTableType != model.TempTableNone || ctx.InRestrictedSQL() { + return nil, false + } + return ctx.GetBinlogSupport() +} + +func (t *TableCommon) canSkip(col *table.Column, value *types.Datum) bool { + return CanSkip(t.Meta(), col, value) +} + +// CanSkip is for these cases, we can skip the columns in encoded row: +// 1. the column is included in primary key; +// 2. the column's default value is null, and the value equals to that but has no origin default; +// 3. the column is virtual generated.
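The shard composition in AllocHandleIDs and the guard in OverflowShardBits above are easier to see with concrete bits. The following is a simplified standalone sketch, not the autoid implementation: it assumes a layout of one sign bit, then the shard bits, then the auto-allocated part.

    package main

    import "fmt"

    // composeShardedRowID puts `shard` into the bits just below the sign bit,
    // which is the idea behind autoid.ShardIDFormat.Compose used above.
    func composeShardedRowID(shard, rowID int64, shardBits uint64) int64 {
        incrementalBits := 64 - shardBits - 1 // 1 bit reserved for the sign
        mask := int64((1<<shardBits)-1) << incrementalBits
        return (shard << incrementalBits & mask) | rowID
    }

    // overflowShardBits reports whether rowID already occupies the shard bits,
    // mirroring what OverflowShardBits guards against.
    func overflowShardBits(rowID int64, shardBits uint64) bool {
        incrementalBits := 64 - shardBits - 1
        mask := int64((1<<shardBits)-1) << incrementalBits
        return rowID&mask > 0
    }

    func main() {
        const shardBits = 4
        fmt.Println(overflowShardBits(1<<58, shardBits)) // false: still fits below the shard bits
        fmt.Println(overflowShardBits(1<<59, shardBits)) // true: allocation reached the shard bits, the error path above
        fmt.Printf("%064b\n", composeShardedRowID(0b0101, 42, shardBits))
    }
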
+func CanSkip(info *model.TableInfo, col *table.Column, value *types.Datum) bool { + if col.IsPKHandleColumn(info) { + return true + } + if col.IsCommonHandleColumn(info) { + pkIdx := FindPrimaryIndex(info) + for _, idxCol := range pkIdx.Columns { + if info.Columns[idxCol.Offset].ID != col.ID { + continue + } + canSkip := idxCol.Length == types.UnspecifiedLength + canSkip = canSkip && !types.NeedRestoredData(&col.FieldType) + return canSkip + } + } + if col.GetDefaultValue() == nil && value.IsNull() && col.GetOriginDefaultValue() == nil { + return true + } + if col.IsVirtualGenerated() { + return true + } + return false +} + +// canSkipUpdateBinlog checks whether the column can be skipped or not. +func (t *TableCommon) canSkipUpdateBinlog(col *table.Column, value types.Datum) bool { + return col.IsVirtualGenerated() +} + +// FindIndexByColName returns a public table index containing only one column named `name`. +func FindIndexByColName(t table.Table, name string) table.Index { + for _, idx := range t.Indices() { + // only public index can be read. + if idx.Meta().State != model.StatePublic { + continue + } + + if len(idx.Meta().Columns) == 1 && strings.EqualFold(idx.Meta().Columns[0].Name.L, name) { + return idx + } + } + return nil +} + +func getDuplicateError(tblInfo *model.TableInfo, handle kv.Handle, row []types.Datum) error { + keyName := tblInfo.Name.String() + ".PRIMARY" + + if handle.IsInt() { + return kv.GenKeyExistsErr([]string{handle.String()}, keyName) + } + pkIdx := FindPrimaryIndex(tblInfo) + if pkIdx == nil { + handleData, err := handle.Data() + if err != nil { + return kv.ErrKeyExists.FastGenByArgs(handle.String(), keyName) + } + colStrVals, err := genIndexKeyStrs(handleData) + if err != nil { + return kv.ErrKeyExists.FastGenByArgs(handle.String(), keyName) + } + return kv.GenKeyExistsErr(colStrVals, keyName) + } + pkDts := make([]types.Datum, 0, len(pkIdx.Columns)) + for _, idxCol := range pkIdx.Columns { + pkDts = append(pkDts, row[idxCol.Offset]) + } + tablecodec.TruncateIndexValues(tblInfo, pkIdx, pkDts) + colStrVals, err := genIndexKeyStrs(pkDts) + if err != nil { + // if genIndexKeyStrs failed, return ErrKeyExists with handle.String(). + return kv.ErrKeyExists.FastGenByArgs(handle.String(), keyName) + } + return kv.GenKeyExistsErr(colStrVals, keyName) +} + +func init() { + table.TableFromMeta = TableFromMeta + table.MockTableFromMeta = MockTableFromMeta + tableutil.TempTableFromMeta = TempTableFromMeta +} + +// sequenceCommon cache the sequence value. +// `alter sequence` will invalidate the cached range. +// `setval` will recompute the start position of cached value. +type sequenceCommon struct { + meta *model.SequenceInfo + // base < end when increment > 0. + // base > end when increment < 0. + end int64 + base int64 + // round is used to count the cycle times. + round int64 + mu sync.RWMutex +} + +// GetSequenceBaseEndRound is used in test. +func (s *sequenceCommon) GetSequenceBaseEndRound() (int64, int64, int64) { + s.mu.RLock() + defer s.mu.RUnlock() + return s.base, s.end, s.round +} + +// GetSequenceNextVal implements util.SequenceTable GetSequenceNextVal interface. +// Caching the sequence value in table, we can easily be notified with the cache empty, +// and write the binlogInfo in table level rather than in allocator. +func (t *TableCommon) GetSequenceNextVal(ctx any, dbName, seqName string) (nextVal int64, err error) { + seq := t.sequence + if seq == nil { + // TODO: refine the error. 
+ return 0, errors.New("sequenceCommon is nil") + } + seq.mu.Lock() + defer seq.mu.Unlock() + + err = func() error { + // Check if need to update the cache batch from storage. + // Because seq.base is not always the last allocated value (may be set by setval()). + // So we should try to seek the next value in cache (not just add increment to seq.base). + var ( + updateCache bool + offset int64 + ok bool + ) + if seq.base == seq.end { + // There is no cache yet. + updateCache = true + } else { + // Seek the first valid value in cache. + offset = seq.getOffset() + if seq.meta.Increment > 0 { + nextVal, ok = autoid.SeekToFirstSequenceValue(seq.base, seq.meta.Increment, offset, seq.base, seq.end) + } else { + nextVal, ok = autoid.SeekToFirstSequenceValue(seq.base, seq.meta.Increment, offset, seq.end, seq.base) + } + if !ok { + updateCache = true + } + } + if !updateCache { + return nil + } + // Update batch alloc from kv storage. + sequenceAlloc, err1 := getSequenceAllocator(t.allocs) + if err1 != nil { + return err1 + } + var base, end, round int64 + base, end, round, err1 = sequenceAlloc.AllocSeqCache() + if err1 != nil { + return err1 + } + // Only update local cache when alloc succeed. + seq.base = base + seq.end = end + seq.round = round + // write sequence binlog to the pumpClient. + if ctx.(sessionctx.Context).GetSessionVars().BinlogClient != nil { + err = writeSequenceUpdateValueBinlog(ctx.(sessionctx.Context), dbName, seqName, seq.end) + if err != nil { + return err + } + } + // Seek the first valid value in new cache. + // Offset may have changed cause the round is updated. + offset = seq.getOffset() + if seq.meta.Increment > 0 { + nextVal, ok = autoid.SeekToFirstSequenceValue(seq.base, seq.meta.Increment, offset, seq.base, seq.end) + } else { + nextVal, ok = autoid.SeekToFirstSequenceValue(seq.base, seq.meta.Increment, offset, seq.end, seq.base) + } + if !ok { + return errors.New("can't find the first value in sequence cache") + } + return nil + }() + // Sequence alloc in kv store error. + if err != nil { + if err == autoid.ErrAutoincReadFailed { + return 0, table.ErrSequenceHasRunOut.GenWithStackByArgs(dbName, seqName) + } + return 0, err + } + seq.base = nextVal + return nextVal, nil +} + +// SetSequenceVal implements util.SequenceTable SetSequenceVal interface. +// The returned bool indicates the newVal is already under the base. +func (t *TableCommon) SetSequenceVal(ctx any, newVal int64, dbName, seqName string) (int64, bool, error) { + seq := t.sequence + if seq == nil { + // TODO: refine the error. + return 0, false, errors.New("sequenceCommon is nil") + } + seq.mu.Lock() + defer seq.mu.Unlock() + + if seq.meta.Increment > 0 { + if newVal <= t.sequence.base { + return 0, true, nil + } + if newVal <= t.sequence.end { + t.sequence.base = newVal + return newVal, false, nil + } + } else { + if newVal >= t.sequence.base { + return 0, true, nil + } + if newVal >= t.sequence.end { + t.sequence.base = newVal + return newVal, false, nil + } + } + + // Invalid the current cache. + t.sequence.base = t.sequence.end + + // Rebase from kv storage. + sequenceAlloc, err := getSequenceAllocator(t.allocs) + if err != nil { + return 0, false, err + } + res, alreadySatisfied, err := sequenceAlloc.RebaseSeq(newVal) + if err != nil { + return 0, false, err + } + if !alreadySatisfied { + // Write sequence binlog to the pumpClient. 
+ if ctx.(sessionctx.Context).GetSessionVars().BinlogClient != nil { + err = writeSequenceUpdateValueBinlog(ctx.(sessionctx.Context), dbName, seqName, seq.end) + if err != nil { + return 0, false, err + } + } + } + // Record the current end after setval succeed. + // Consider the following case. + // create sequence seq + // setval(seq, 100) setval(seq, 50) + // Because no cache (base, end keep 0), so the second setval won't return NULL. + t.sequence.base, t.sequence.end = newVal, newVal + return res, alreadySatisfied, nil +} + +// getOffset is used in under GetSequenceNextVal & SetSequenceVal, which mu is locked. +func (s *sequenceCommon) getOffset() int64 { + offset := s.meta.Start + if s.meta.Cycle && s.round > 0 { + if s.meta.Increment > 0 { + offset = s.meta.MinValue + } else { + offset = s.meta.MaxValue + } + } + return offset +} + +// GetSequenceID implements util.SequenceTable GetSequenceID interface. +func (t *TableCommon) GetSequenceID() int64 { + return t.tableID +} + +// GetSequenceCommon is used in test to get sequenceCommon. +func (t *TableCommon) GetSequenceCommon() *sequenceCommon { + return t.sequence +} + +// TryGetHandleRestoredDataWrapper tries to get the restored data for handle if needed. The argument can be a slice or a map. +func TryGetHandleRestoredDataWrapper(tblInfo *model.TableInfo, row []types.Datum, rowMap map[int64]types.Datum, idx *model.IndexInfo) []types.Datum { + if !collate.NewCollationEnabled() || !tblInfo.IsCommonHandle || tblInfo.CommonHandleVersion == 0 { + return nil + } + rsData := make([]types.Datum, 0, 4) + pkIdx := FindPrimaryIndex(tblInfo) + for _, pkIdxCol := range pkIdx.Columns { + pkCol := tblInfo.Columns[pkIdxCol.Offset] + if !types.NeedRestoredData(&pkCol.FieldType) { + continue + } + var datum types.Datum + if len(rowMap) > 0 { + datum = rowMap[pkCol.ID] + } else { + datum = row[pkCol.Offset] + } + TryTruncateRestoredData(&datum, pkCol, pkIdxCol, idx) + ConvertDatumToTailSpaceCount(&datum, pkCol) + rsData = append(rsData, datum) + } + return rsData +} + +// TryTruncateRestoredData tries to truncate index values. +// Says that primary key(a (8)), +// For index t(a), don't truncate the value. +// For index t(a(9)), truncate to a(9). +// For index t(a(7)), truncate to a(8). +func TryTruncateRestoredData(datum *types.Datum, pkCol *model.ColumnInfo, + pkIdxCol *model.IndexColumn, idx *model.IndexInfo) { + truncateTargetCol := pkIdxCol + for _, idxCol := range idx.Columns { + if idxCol.Offset == pkIdxCol.Offset { + truncateTargetCol = maxIndexLen(pkIdxCol, idxCol) + break + } + } + tablecodec.TruncateIndexValue(datum, truncateTargetCol, pkCol) +} + +// ConvertDatumToTailSpaceCount converts a string datum to an int datum that represents the tail space count. +func ConvertDatumToTailSpaceCount(datum *types.Datum, col *model.ColumnInfo) { + if collate.IsBinCollation(col.GetCollate()) { + *datum = types.NewIntDatum(stringutil.GetTailSpaceCount(datum.GetString())) + } +} + +func maxIndexLen(idxA, idxB *model.IndexColumn) *model.IndexColumn { + if idxA.Length == types.UnspecifiedLength { + return idxA + } + if idxB.Length == types.UnspecifiedLength { + return idxB + } + if idxA.Length > idxB.Length { + return idxA + } + return idxB +} + +func getSequenceAllocator(allocs autoid.Allocators) (autoid.Allocator, error) { + for _, alloc := range allocs.Allocs { + if alloc.GetType() == autoid.SequenceType { + return alloc, nil + } + } + // TODO: refine the error. 
+ return nil, errors.New("sequence allocator is nil") +} + +// BuildTableScanFromInfos build tipb.TableScan with *model.TableInfo and *model.ColumnInfo. +func BuildTableScanFromInfos(tableInfo *model.TableInfo, columnInfos []*model.ColumnInfo) *tipb.TableScan { + pkColIDs := TryGetCommonPkColumnIds(tableInfo) + tsExec := &tipb.TableScan{ + TableId: tableInfo.ID, + Columns: util.ColumnsToProto(columnInfos, tableInfo.PKIsHandle, false), + PrimaryColumnIds: pkColIDs, + } + if tableInfo.IsCommonHandle { + tsExec.PrimaryPrefixColumnIds = PrimaryPrefixColumnIDs(tableInfo) + } + return tsExec +} + +// BuildPartitionTableScanFromInfos build tipb.PartitonTableScan with *model.TableInfo and *model.ColumnInfo. +func BuildPartitionTableScanFromInfos(tableInfo *model.TableInfo, columnInfos []*model.ColumnInfo, fastScan bool) *tipb.PartitionTableScan { + pkColIDs := TryGetCommonPkColumnIds(tableInfo) + tsExec := &tipb.PartitionTableScan{ + TableId: tableInfo.ID, + Columns: util.ColumnsToProto(columnInfos, tableInfo.PKIsHandle, false), + PrimaryColumnIds: pkColIDs, + IsFastScan: &fastScan, + } + if tableInfo.IsCommonHandle { + tsExec.PrimaryPrefixColumnIds = PrimaryPrefixColumnIDs(tableInfo) + } + return tsExec +} + +// SetPBColumnsDefaultValue sets the default values of tipb.ColumnInfo. +func SetPBColumnsDefaultValue(ctx expression.BuildContext, pbColumns []*tipb.ColumnInfo, columns []*model.ColumnInfo) error { + for i, c := range columns { + // For virtual columns, we set their default values to NULL so that TiKV will return NULL properly, + // They real values will be computed later. + if c.IsGenerated() && !c.GeneratedStored { + pbColumns[i].DefaultVal = []byte{codec.NilFlag} + } + if c.GetOriginDefaultValue() == nil { + continue + } + + evalCtx := ctx.GetEvalCtx() + d, err := table.GetColOriginDefaultValueWithoutStrictSQLMode(ctx, c) + if err != nil { + return err + } + + pbColumns[i].DefaultVal, err = tablecodec.EncodeValue(evalCtx.Location(), nil, d) + ec := evalCtx.ErrCtx() + err = ec.HandleError(err) + if err != nil { + return err + } + } + return nil +} + +// TemporaryTable is used to store transaction-specific or session-specific information for global / local temporary tables. +// For example, stats and autoID should have their own copies of data, instead of being shared by all sessions. +type TemporaryTable struct { + // Whether it's modified in this transaction. + modified bool + // The stats of this table. So far it's always pseudo stats. + stats *statistics.Table + // The autoID allocator of this table. + autoIDAllocator autoid.Allocator + // Table size. + size int64 + + meta *model.TableInfo +} + +// TempTableFromMeta builds a TempTable from model.TableInfo. +func TempTableFromMeta(tblInfo *model.TableInfo) tableutil.TempTable { + return &TemporaryTable{ + modified: false, + stats: statistics.PseudoTable(tblInfo, false, false), + autoIDAllocator: autoid.NewAllocatorFromTempTblInfo(tblInfo), + meta: tblInfo, + } +} + +// GetAutoIDAllocator is implemented from TempTable.GetAutoIDAllocator. +func (t *TemporaryTable) GetAutoIDAllocator() autoid.Allocator { + return t.autoIDAllocator +} + +// SetModified is implemented from TempTable.SetModified. +func (t *TemporaryTable) SetModified(modified bool) { + t.modified = modified +} + +// GetModified is implemented from TempTable.GetModified. +func (t *TemporaryTable) GetModified() bool { + return t.modified +} + +// GetStats is implemented from TempTable.GetStats. 
+func (t *TemporaryTable) GetStats() any { + return t.stats +} + +// GetSize gets the table size. +func (t *TemporaryTable) GetSize() int64 { + return t.size +} + +// SetSize sets the table size. +func (t *TemporaryTable) SetSize(v int64) { + t.size = v +} + +// GetMeta gets the table meta. +func (t *TemporaryTable) GetMeta() *model.TableInfo { + return t.meta +} diff --git a/pkg/ttl/ttlworker/binding__failpoint_binding__.go b/pkg/ttl/ttlworker/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..74271806c94e4 --- /dev/null +++ b/pkg/ttl/ttlworker/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package ttlworker + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/ttl/ttlworker/config.go b/pkg/ttl/ttlworker/config.go index b5cef0b0ecc11..a562fd04afd77 100644 --- a/pkg/ttl/ttlworker/config.go +++ b/pkg/ttl/ttlworker/config.go @@ -35,36 +35,36 @@ const ttlTaskHeartBeatTickerInterval = time.Minute const ttlGCInterval = time.Hour func getUpdateInfoSchemaCacheInterval() time.Duration { - failpoint.Inject("update-info-schema-cache-interval", func(val failpoint.Value) time.Duration { + if val, _err_ := failpoint.Eval(_curpkg_("update-info-schema-cache-interval")); _err_ == nil { return time.Duration(val.(int)) - }) + } return updateInfoSchemaCacheInterval } func getUpdateTTLTableStatusCacheInterval() time.Duration { - failpoint.Inject("update-status-table-cache-interval", func(val failpoint.Value) time.Duration { + if val, _err_ := failpoint.Eval(_curpkg_("update-status-table-cache-interval")); _err_ == nil { return time.Duration(val.(int)) - }) + } return updateTTLTableStatusCacheInterval } func getResizeWorkersInterval() time.Duration { - failpoint.Inject("resize-workers-interval", func(val failpoint.Value) time.Duration { + if val, _err_ := failpoint.Eval(_curpkg_("resize-workers-interval")); _err_ == nil { return time.Duration(val.(int)) - }) + } return resizeWorkersInterval } func getTaskManagerLoopTickerInterval() time.Duration { - failpoint.Inject("task-manager-loop-interval", func(val failpoint.Value) time.Duration { + if val, _err_ := failpoint.Eval(_curpkg_("task-manager-loop-interval")); _err_ == nil { return time.Duration(val.(int)) - }) + } return taskManagerLoopTickerInterval } func getTaskManagerHeartBeatExpireInterval() time.Duration { - failpoint.Inject("task-manager-heartbeat-expire-interval", func(val failpoint.Value) time.Duration { + if val, _err_ := failpoint.Eval(_curpkg_("task-manager-heartbeat-expire-interval")); _err_ == nil { return time.Duration(val.(int)) - }) + } return 2 * ttlTaskHeartBeatTickerInterval } diff --git a/pkg/ttl/ttlworker/config.go__failpoint_stash__ b/pkg/ttl/ttlworker/config.go__failpoint_stash__ new file mode 100644 index 0000000000000..b5cef0b0ecc11 --- /dev/null +++ b/pkg/ttl/ttlworker/config.go__failpoint_stash__ @@ -0,0 +1,70 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
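The config.go hunk above replaces each failpoint.Inject marker with an explicit failpoint.Eval(_curpkg_(name)) check and keeps the untouched original as the __failpoint_stash__ file reproduced below. Either form does nothing until the failpoint is enabled, typically from a test. A hedged sketch using the names from this hunk (require and t are assumed testify/test fixtures; the full name is the package path that _curpkg_ resolves to plus the marker name):

    // Make getResizeWorkersInterval return 1s; this failpoint reads the value as int nanoseconds.
    fpName := "github.com/pingcap/tidb/pkg/ttl/ttlworker/resize-workers-interval"
    require.NoError(t, failpoint.Enable(fpName, "return(1000000000)"))
    defer func() { require.NoError(t, failpoint.Disable(fpName)) }()
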
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ttlworker + +import ( + "time" + + "github.com/pingcap/failpoint" +) + +const jobManagerLoopTickerInterval = 10 * time.Second + +const updateInfoSchemaCacheInterval = 2 * time.Minute +const updateTTLTableStatusCacheInterval = 2 * time.Minute + +const ttlInternalSQLTimeout = 30 * time.Second +const resizeWorkersInterval = 30 * time.Second +const splitScanCount = 64 +const ttlJobTimeout = 6 * time.Hour + +const taskManagerLoopTickerInterval = time.Minute +const ttlTaskHeartBeatTickerInterval = time.Minute +const ttlGCInterval = time.Hour + +func getUpdateInfoSchemaCacheInterval() time.Duration { + failpoint.Inject("update-info-schema-cache-interval", func(val failpoint.Value) time.Duration { + return time.Duration(val.(int)) + }) + return updateInfoSchemaCacheInterval +} + +func getUpdateTTLTableStatusCacheInterval() time.Duration { + failpoint.Inject("update-status-table-cache-interval", func(val failpoint.Value) time.Duration { + return time.Duration(val.(int)) + }) + return updateTTLTableStatusCacheInterval +} + +func getResizeWorkersInterval() time.Duration { + failpoint.Inject("resize-workers-interval", func(val failpoint.Value) time.Duration { + return time.Duration(val.(int)) + }) + return resizeWorkersInterval +} + +func getTaskManagerLoopTickerInterval() time.Duration { + failpoint.Inject("task-manager-loop-interval", func(val failpoint.Value) time.Duration { + return time.Duration(val.(int)) + }) + return taskManagerLoopTickerInterval +} + +func getTaskManagerHeartBeatExpireInterval() time.Duration { + failpoint.Inject("task-manager-heartbeat-expire-interval", func(val failpoint.Value) time.Duration { + return time.Duration(val.(int)) + }) + return 2 * ttlTaskHeartBeatTickerInterval +} diff --git a/pkg/util/binding__failpoint_binding__.go b/pkg/util/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..c7fcdb8c0fcf4 --- /dev/null +++ b/pkg/util/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package util + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/util/breakpoint/binding__failpoint_binding__.go b/pkg/util/breakpoint/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..2199b6182c919 --- /dev/null +++ b/pkg/util/breakpoint/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package breakpoint + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/util/breakpoint/breakpoint.go b/pkg/util/breakpoint/breakpoint.go index 60c0e1f828799..10eeb0f1eb64a 100644 --- a/pkg/util/breakpoint/breakpoint.go +++ 
b/pkg/util/breakpoint/breakpoint.go @@ -25,10 +25,10 @@ const NotifyBreakPointFuncKey = stringutil.StringerStr("breakPointNotifyFunc") // Inject injects a break point to a session func Inject(sctx sessionctx.Context, name string) { - failpoint.Inject(name, func(_ failpoint.Value) { + if _, _err_ := failpoint.Eval(_curpkg_(name)); _err_ == nil { val := sctx.Value(NotifyBreakPointFuncKey) if breakPointNotifyAndWaitContinue, ok := val.(func(string)); ok { breakPointNotifyAndWaitContinue(name) } - }) + } } diff --git a/pkg/util/breakpoint/breakpoint.go__failpoint_stash__ b/pkg/util/breakpoint/breakpoint.go__failpoint_stash__ new file mode 100644 index 0000000000000..60c0e1f828799 --- /dev/null +++ b/pkg/util/breakpoint/breakpoint.go__failpoint_stash__ @@ -0,0 +1,34 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package breakpoint + +import ( + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/util/stringutil" +) + +// NotifyBreakPointFuncKey is the key where break point notify function located +const NotifyBreakPointFuncKey = stringutil.StringerStr("breakPointNotifyFunc") + +// Inject injects a break point to a session +func Inject(sctx sessionctx.Context, name string) { + failpoint.Inject(name, func(_ failpoint.Value) { + val := sctx.Value(NotifyBreakPointFuncKey) + if breakPointNotifyAndWaitContinue, ok := val.(func(string)); ok { + breakPointNotifyAndWaitContinue(name) + } + }) +} diff --git a/pkg/util/cgroup/binding__failpoint_binding__.go b/pkg/util/cgroup/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..d5cbc65ac07c0 --- /dev/null +++ b/pkg/util/cgroup/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package cgroup + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/util/cgroup/cgroup_cpu_linux.go b/pkg/util/cgroup/cgroup_cpu_linux.go index 665b81f24e2e9..fe8512439d1d0 100644 --- a/pkg/util/cgroup/cgroup_cpu_linux.go +++ b/pkg/util/cgroup/cgroup_cpu_linux.go @@ -28,13 +28,13 @@ import ( // GetCgroupCPU returns the CPU usage and quota for the current cgroup. 
func GetCgroupCPU() (CPUUsage, error) { - failpoint.Inject("GetCgroupCPUErr", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("GetCgroupCPUErr")); _err_ == nil { //nolint:forcetypeassert if val.(bool) { var cpuUsage CPUUsage - failpoint.Return(cpuUsage, errors.Errorf("mockAddBatchDDLJobsErr")) + return cpuUsage, errors.Errorf("mockAddBatchDDLJobsErr") } - }) + } cpuusage, err := getCgroupCPU("/") cpuusage.NumCPU = runtime.NumCPU() diff --git a/pkg/util/cgroup/cgroup_cpu_linux.go__failpoint_stash__ b/pkg/util/cgroup/cgroup_cpu_linux.go__failpoint_stash__ new file mode 100644 index 0000000000000..665b81f24e2e9 --- /dev/null +++ b/pkg/util/cgroup/cgroup_cpu_linux.go__failpoint_stash__ @@ -0,0 +1,100 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux + +package cgroup + +import ( + "math" + "os" + "runtime" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" +) + +// GetCgroupCPU returns the CPU usage and quota for the current cgroup. +func GetCgroupCPU() (CPUUsage, error) { + failpoint.Inject("GetCgroupCPUErr", func(val failpoint.Value) { + //nolint:forcetypeassert + if val.(bool) { + var cpuUsage CPUUsage + failpoint.Return(cpuUsage, errors.Errorf("mockAddBatchDDLJobsErr")) + } + }) + cpuusage, err := getCgroupCPU("/") + + cpuusage.NumCPU = runtime.NumCPU() + return cpuusage, err +} + +// CPUQuotaToGOMAXPROCS converts the CPU quota applied to the calling process +// to a valid GOMAXPROCS value. +func CPUQuotaToGOMAXPROCS(minValue int) (int, CPUQuotaStatus, error) { + quota, err := GetCgroupCPU() + if err != nil { + return -1, CPUQuotaUndefined, err + } + maxProcs := int(math.Ceil(quota.CPUShares())) + if minValue > 0 && maxProcs < minValue { + return minValue, CPUQuotaMinUsed, nil + } + return maxProcs, CPUQuotaUsed, nil +} + +// GetCPUPeriodAndQuota returns CPU period and quota time of cgroup. +func GetCPUPeriodAndQuota() (period int64, quota int64, err error) { + return getCgroupCPUPeriodAndQuota("/") +} + +// InContainer returns true if the process is running in a container. +func InContainer() bool { + // for cgroup V1, check /proc/self/cgroup, for V2, check /proc/self/mountinfo + return inContainer(procPathCGroup) || inContainer(procPathMountInfo) +} + +func inContainer(path string) bool { + v, err := os.ReadFile(path) + if err != nil { + return false + } + + // For cgroup V1, check /proc/self/cgroup + if path == procPathCGroup { + if strings.Contains(string(v), "docker") || + strings.Contains(string(v), "kubepods") || + strings.Contains(string(v), "containerd") { + return true + } + } + + // For cgroup V2, check /proc/self/mountinfo + if path == procPathMountInfo { + lines := strings.Split(string(v), "\n") + for _, line := range lines { + v := strings.Split(line, " ") + // check mount point of root dir is on overlay or not. + // v[4] means `mount point`, v[8] means `filesystem type`. 
+ // see details from https://man7.org/linux/man-pages/man5/proc.5.html + // TODO: enhance this check, as overlay is not the only storage driver for container. + if len(v) > 8 && v[4] == "/" && v[8] == "overlay" { + return true + } + } + } + + return false +} diff --git a/pkg/util/cgroup/cgroup_cpu_unsupport.go b/pkg/util/cgroup/cgroup_cpu_unsupport.go index 092875c09657b..8c8b7c2c4c3ec 100644 --- a/pkg/util/cgroup/cgroup_cpu_unsupport.go +++ b/pkg/util/cgroup/cgroup_cpu_unsupport.go @@ -26,12 +26,12 @@ import ( // GetCgroupCPU returns the CPU usage and quota for the current cgroup. func GetCgroupCPU() (CPUUsage, error) { var cpuUsage CPUUsage - failpoint.Inject("GetCgroupCPUErr", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("GetCgroupCPUErr")); _err_ == nil { //nolint:forcetypeassert if val.(bool) { - failpoint.Return(cpuUsage, errors.Errorf("mockAddBatchDDLJobsErr")) + return cpuUsage, errors.Errorf("mockAddBatchDDLJobsErr") } - }) + } cpuUsage.NumCPU = runtime.NumCPU() return cpuUsage, nil } diff --git a/pkg/util/cgroup/cgroup_cpu_unsupport.go__failpoint_stash__ b/pkg/util/cgroup/cgroup_cpu_unsupport.go__failpoint_stash__ new file mode 100644 index 0000000000000..092875c09657b --- /dev/null +++ b/pkg/util/cgroup/cgroup_cpu_unsupport.go__failpoint_stash__ @@ -0,0 +1,55 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !linux + +package cgroup + +import ( + "runtime" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" +) + +// GetCgroupCPU returns the CPU usage and quota for the current cgroup. +func GetCgroupCPU() (CPUUsage, error) { + var cpuUsage CPUUsage + failpoint.Inject("GetCgroupCPUErr", func(val failpoint.Value) { + //nolint:forcetypeassert + if val.(bool) { + failpoint.Return(cpuUsage, errors.Errorf("mockAddBatchDDLJobsErr")) + } + }) + cpuUsage.NumCPU = runtime.NumCPU() + return cpuUsage, nil +} + +// GetCPUPeriodAndQuota returns CPU period and quota time of cgroup. +// This is Linux-specific and not supported in the current OS. +func GetCPUPeriodAndQuota() (period int64, quota int64, err error) { + return -1, -1, nil +} + +// CPUQuotaToGOMAXPROCS converts the CPU quota applied to the calling process +// to a valid GOMAXPROCS value. This is Linux-specific and not supported in the +// current OS. +func CPUQuotaToGOMAXPROCS(_ int) (int, CPUQuotaStatus, error) { + return -1, CPUQuotaUndefined, nil +} + +// InContainer returns true if the process is running in a container. 
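CPUQuotaToGOMAXPROCS in the stashed Linux implementation above rounds the cgroup CPU share up to an integer and applies a caller-supplied floor. Assuming CPUShares() is quota/period (its body is not shown in this hunk), a small standalone sketch of the same arithmetic:

    package main

    import (
        "fmt"
        "math"
    )

    // quotaToGOMAXPROCS mirrors the ceil-then-floor logic of CPUQuotaToGOMAXPROCS,
    // with the cgroup share expressed directly as quota/period.
    func quotaToGOMAXPROCS(quotaUs, periodUs int64, minValue int) int {
        maxProcs := int(math.Ceil(float64(quotaUs) / float64(periodUs)))
        if minValue > 0 && maxProcs < minValue {
            return minValue
        }
        return maxProcs
    }

    func main() {
        // cpu.cfs_quota_us=250000, cpu.cfs_period_us=100000 => 2.5 CPUs => GOMAXPROCS 3
        fmt.Println(quotaToGOMAXPROCS(250000, 100000, 1))
        // a tiny quota is still clamped up to the requested minimum
        fmt.Println(quotaToGOMAXPROCS(10000, 100000, 2))
    }
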
+func InContainer() bool { + return false +} diff --git a/pkg/util/chunk/binding__failpoint_binding__.go b/pkg/util/chunk/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..ff0fbc387a100 --- /dev/null +++ b/pkg/util/chunk/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package chunk + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/util/chunk/chunk_in_disk.go b/pkg/util/chunk/chunk_in_disk.go index dadcddefd4584..e8e7177cac03d 100644 --- a/pkg/util/chunk/chunk_in_disk.go +++ b/pkg/util/chunk/chunk_in_disk.go @@ -329,7 +329,7 @@ func (d *DataInDiskByChunks) NumChunks() int { func injectChunkInDiskRandomError() error { var err error - failpoint.Inject("ChunkInDiskError", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("ChunkInDiskError")); _err_ == nil { if val.(bool) { randNum := rand.Int31n(10000) if randNum < 3 { @@ -339,6 +339,6 @@ func injectChunkInDiskRandomError() error { time.Sleep(time.Duration(delayTime) * time.Millisecond) } } - }) + } return err } diff --git a/pkg/util/chunk/chunk_in_disk.go__failpoint_stash__ b/pkg/util/chunk/chunk_in_disk.go__failpoint_stash__ new file mode 100644 index 0000000000000..dadcddefd4584 --- /dev/null +++ b/pkg/util/chunk/chunk_in_disk.go__failpoint_stash__ @@ -0,0 +1,344 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package chunk + +import ( + "io" + "math/rand" + "os" + "strconv" + "time" + "unsafe" + + errors2 "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/disk" + "github.com/pingcap/tidb/pkg/util/memory" +) + +const byteLen = int64(unsafe.Sizeof(byte(0))) +const intLen = int64(unsafe.Sizeof(int(0))) +const int64Len = int64(unsafe.Sizeof(int64(0))) + +const chkFixedSize = intLen * 4 +const colMetaSize = int64Len * 4 + +const defaultChunkDataInDiskByChunksPath = "defaultChunkDataInDiskByChunksPath" + +// DataInDiskByChunks represents some data stored in temporary disk. +// They can only be restored by chunks. +type DataInDiskByChunks struct { + fieldTypes []*types.FieldType + offsetOfEachChunk []int64 + + totalDataSize int64 + totalRowNum int64 + diskTracker *disk.Tracker // track disk usage. + + dataFile diskFileReaderWriter + + // Write or read data needs this buffer to temporarily store data + buf []byte +} + +// NewDataInDiskByChunks creates a new DataInDiskByChunks with field types. +func NewDataInDiskByChunks(fieldTypes []*types.FieldType) *DataInDiskByChunks { + d := &DataInDiskByChunks{ + fieldTypes: fieldTypes, + totalDataSize: 0, + totalRowNum: 0, + // TODO: set the quota of disk usage. 
+ diskTracker: disk.NewTracker(memory.LabelForChunkDataInDiskByChunks, -1), + buf: make([]byte, 0, 4096), + } + return d +} + +func (d *DataInDiskByChunks) initDiskFile() (err error) { + err = disk.CheckAndInitTempDir() + if err != nil { + return + } + err = d.dataFile.initWithFileName(defaultChunkDataInDiskByChunksPath + strconv.Itoa(d.diskTracker.Label())) + return +} + +// GetDiskTracker returns the memory tracker of this List. +func (d *DataInDiskByChunks) GetDiskTracker() *disk.Tracker { + return d.diskTracker +} + +// Add adds a chunk to the DataInDiskByChunks. Caller must make sure the input chk has the same field types. +// Warning: Do not concurrently call this function. +func (d *DataInDiskByChunks) Add(chk *Chunk) (err error) { + if err := injectChunkInDiskRandomError(); err != nil { + return err + } + + if chk.NumRows() == 0 { + return errors2.New("Chunk spilled to disk should have at least 1 row") + } + + if d.dataFile.file == nil { + err = d.initDiskFile() + if err != nil { + return + } + } + + serializedBytesNum := d.serializeDataToBuf(chk) + + var writeNum int + writeNum, err = d.dataFile.write(d.buf) + if err != nil { + return + } + + if int64(writeNum) != serializedBytesNum { + return errors2.New("Some data fail to be spilled to disk") + } + d.offsetOfEachChunk = append(d.offsetOfEachChunk, d.totalDataSize) + d.totalDataSize += serializedBytesNum + d.totalRowNum += int64(chk.NumRows()) + d.dataFile.offWrite += serializedBytesNum + + d.diskTracker.Consume(serializedBytesNum) + return +} + +func (d *DataInDiskByChunks) getChunkSize(chkIdx int) int64 { + totalChunkNum := len(d.offsetOfEachChunk) + if chkIdx == totalChunkNum-1 { + return d.totalDataSize - d.offsetOfEachChunk[chkIdx] + } + return d.offsetOfEachChunk[chkIdx+1] - d.offsetOfEachChunk[chkIdx] +} + +// GetChunk gets a Chunk from the DataInDiskByChunks by chkIdx. +func (d *DataInDiskByChunks) GetChunk(chkIdx int) (*Chunk, error) { + if err := injectChunkInDiskRandomError(); err != nil { + return nil, err + } + + reader := d.dataFile.getSectionReader(d.offsetOfEachChunk[chkIdx]) + chkSize := d.getChunkSize(chkIdx) + + if cap(d.buf) < int(chkSize) { + d.buf = make([]byte, chkSize) + } else { + d.buf = d.buf[:chkSize] + } + + readByteNum, err := io.ReadFull(reader, d.buf) + if err != nil { + return nil, err + } + + if int64(readByteNum) != chkSize { + return nil, errors2.New("Fail to restore the spilled chunk") + } + + chk := NewEmptyChunk(d.fieldTypes) + d.deserializeDataToChunk(chk) + + return chk, nil +} + +// Close releases the disk resource. 
+func (d *DataInDiskByChunks) Close() { + if d.dataFile.file != nil { + d.diskTracker.Consume(-d.diskTracker.BytesConsumed()) + terror.Call(d.dataFile.file.Close) + terror.Log(os.Remove(d.dataFile.file.Name())) + } +} + +func (d *DataInDiskByChunks) serializeColMeta(pos int64, length int64, nullMapSize int64, dataSize int64, offsetSize int64) { + *(*int64)(unsafe.Pointer(&d.buf[pos])) = length + *(*int64)(unsafe.Pointer(&d.buf[pos+int64Len])) = nullMapSize + *(*int64)(unsafe.Pointer(&d.buf[pos+int64Len*2])) = dataSize + *(*int64)(unsafe.Pointer(&d.buf[pos+int64Len*3])) = offsetSize +} + +func (d *DataInDiskByChunks) serializeOffset(pos *int64, offsets []int64, offsetSize int64) { + d.buf = d.buf[:*pos+offsetSize] + for _, offset := range offsets { + *(*int64)(unsafe.Pointer(&d.buf[*pos])) = offset + *pos += int64Len + } +} + +func (d *DataInDiskByChunks) serializeChunkData(pos *int64, chk *Chunk, selSize int64) { + d.buf = d.buf[:chkFixedSize] + *(*int)(unsafe.Pointer(&d.buf[*pos])) = chk.numVirtualRows + *(*int)(unsafe.Pointer(&d.buf[*pos+intLen])) = chk.capacity + *(*int)(unsafe.Pointer(&d.buf[*pos+intLen*2])) = chk.requiredRows + *(*int)(unsafe.Pointer(&d.buf[*pos+intLen*3])) = int(selSize) + *pos += chkFixedSize + + d.buf = d.buf[:*pos+selSize] + + selLen := len(chk.sel) + for i := 0; i < selLen; i++ { + *(*int)(unsafe.Pointer(&d.buf[*pos])) = chk.sel[i] + *pos += intLen + } +} + +func (d *DataInDiskByChunks) serializeColumns(pos *int64, chk *Chunk) { + for _, col := range chk.columns { + d.buf = d.buf[:*pos+colMetaSize] + nullMapSize := int64(len(col.nullBitmap)) * byteLen + dataSize := int64(len(col.data)) * byteLen + offsetSize := int64(len(col.offsets)) * int64Len + d.serializeColMeta(*pos, int64(col.length), nullMapSize, dataSize, offsetSize) + *pos += colMetaSize + + d.buf = append(d.buf, col.nullBitmap...) + d.buf = append(d.buf, col.data...) + *pos += nullMapSize + dataSize + d.serializeOffset(pos, col.offsets, offsetSize) + } +} + +// Serialized format of a chunk: +// chunk data: | numVirtualRows | capacity | requiredRows | selSize | sel... | +// column1 data: | length | nullMapSize | dataSize | offsetSize | nullBitmap... | data... | offsets... | +// column2 data: | length | nullMapSize | dataSize | offsetSize | nullBitmap... | data... | offsets... | +// ... +// columnN data: | length | nullMapSize | dataSize | offsetSize | nullBitmap... | data... | offsets... | +// +// `xxx...` means this is a variable field filled by bytes. 
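The layout described in the comment above is written with raw pointer stores rather than encoding/binary, as serializeColMeta and serializeChunkData show. A self-contained sketch of that technique, assuming native byte order and a same-process round trip, which is all a spill file needs:

    package main

    import (
        "fmt"
        "unsafe"
    )

    const int64Len = int64(unsafe.Sizeof(int64(0)))

    // putInt64 stores v at buf[pos] in native byte order, like serializeColMeta does.
    func putInt64(buf []byte, pos int64, v int64) {
        *(*int64)(unsafe.Pointer(&buf[pos])) = v
    }

    // getInt64 reads the value back, like deserializeColMeta does.
    func getInt64(buf []byte, pos int64) int64 {
        return *(*int64)(unsafe.Pointer(&buf[pos]))
    }

    func main() {
        // A fake column meta block: | length | nullMapSize | dataSize | offsetSize |
        buf := make([]byte, 4*int64Len)
        for i, v := range []int64{3, 1, 24, 0} {
            putInt64(buf, int64(i)*int64Len, v)
        }
        for i := int64(0); i < 4; i++ {
            fmt.Println(getInt64(buf, i*int64Len))
        }
    }
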
+func (d *DataInDiskByChunks) serializeDataToBuf(chk *Chunk) int64 { + totalBytes := int64(0) + + // Calculate total memory that buffer needs + selSize := int64(len(chk.sel)) * intLen + totalBytes += chkFixedSize + selSize + for _, col := range chk.columns { + nullMapSize := int64(len(col.nullBitmap)) * byteLen + dataSize := int64(len(col.data)) * byteLen + offsetSize := int64(len(col.offsets)) * int64Len + totalBytes += colMetaSize + nullMapSize + dataSize + offsetSize + } + + if cap(d.buf) < int(totalBytes) { + d.buf = make([]byte, 0, totalBytes) + } + + pos := int64(0) + d.serializeChunkData(&pos, chk, selSize) + d.serializeColumns(&pos, chk) + return totalBytes +} + +func (d *DataInDiskByChunks) deserializeColMeta(pos *int64) (length int64, nullMapSize int64, dataSize int64, offsetSize int64) { + length = *(*int64)(unsafe.Pointer(&d.buf[*pos])) + *pos += int64Len + + nullMapSize = *(*int64)(unsafe.Pointer(&d.buf[*pos])) + *pos += int64Len + + dataSize = *(*int64)(unsafe.Pointer(&d.buf[*pos])) + *pos += int64Len + + offsetSize = *(*int64)(unsafe.Pointer(&d.buf[*pos])) + *pos += int64Len + return +} + +func (d *DataInDiskByChunks) deserializeSel(chk *Chunk, pos *int64, selSize int) { + selLen := int64(selSize) / intLen + chk.sel = make([]int, selLen) + for i := int64(0); i < selLen; i++ { + chk.sel[i] = *(*int)(unsafe.Pointer(&d.buf[*pos])) + *pos += intLen + } +} + +func (d *DataInDiskByChunks) deserializeChunkData(chk *Chunk, pos *int64) { + chk.numVirtualRows = *(*int)(unsafe.Pointer(&d.buf[*pos])) + *pos += intLen + + chk.capacity = *(*int)(unsafe.Pointer(&d.buf[*pos])) + *pos += intLen + + chk.requiredRows = *(*int)(unsafe.Pointer(&d.buf[*pos])) + *pos += intLen + + selSize := *(*int)(unsafe.Pointer(&d.buf[*pos])) + *pos += intLen + if selSize != 0 { + d.deserializeSel(chk, pos, selSize) + } +} + +func (d *DataInDiskByChunks) deserializeOffsets(dst []int64, pos *int64) { + offsetNum := len(dst) + for i := 0; i < offsetNum; i++ { + dst[i] = *(*int64)(unsafe.Pointer(&d.buf[*pos])) + *pos += int64Len + } +} + +func (d *DataInDiskByChunks) deserializeColumns(chk *Chunk, pos *int64) { + for _, col := range chk.columns { + length, nullMapSize, dataSize, offsetSize := d.deserializeColMeta(pos) + col.nullBitmap = make([]byte, nullMapSize) + col.data = make([]byte, dataSize) + col.offsets = make([]int64, offsetSize/int64Len) + + col.length = int(length) + copy(col.nullBitmap, d.buf[*pos:*pos+nullMapSize]) + *pos += nullMapSize + copy(col.data, d.buf[*pos:*pos+dataSize]) + *pos += dataSize + d.deserializeOffsets(col.offsets, pos) + } +} + +func (d *DataInDiskByChunks) deserializeDataToChunk(chk *Chunk) { + pos := int64(0) + d.deserializeChunkData(chk, &pos) + d.deserializeColumns(chk, &pos) +} + +// NumRows returns total spilled row number +func (d *DataInDiskByChunks) NumRows() int64 { + return d.totalRowNum +} + +// NumChunks returns total spilled chunk number +func (d *DataInDiskByChunks) NumChunks() int { + return len(d.offsetOfEachChunk) +} + +func injectChunkInDiskRandomError() error { + var err error + failpoint.Inject("ChunkInDiskError", func(val failpoint.Value) { + if val.(bool) { + randNum := rand.Int31n(10000) + if randNum < 3 { + err = errors2.New("random error is triggered") + } else if randNum < 6 { + delayTime := rand.Int31n(10) + 5 + time.Sleep(time.Duration(delayTime) * time.Millisecond) + } + } + }) + return err +} diff --git a/pkg/util/chunk/row_container.go b/pkg/util/chunk/row_container.go index 9b443569fb05c..572c09b7984ad 100644 --- a/pkg/util/chunk/row_container.go 
+++ b/pkg/util/chunk/row_container.go @@ -164,11 +164,11 @@ func (c *RowContainer) spillToDisk(preSpillError error) { logutil.BgLogger().Error("spill to disk failed", zap.Stack("stack"), zap.Error(err)) } }() - failpoint.Inject("spillToDiskOutOfDiskQuota", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("spillToDiskOutOfDiskQuota")); _err_ == nil { if val.(bool) { panic("out of disk quota when spilling") } - }) + } if preSpillError != nil { c.m.records.spillError = preSpillError return @@ -249,11 +249,11 @@ func (c *RowContainer) NumChunks() int { func (c *RowContainer) Add(chk *Chunk) (err error) { c.m.RLock() defer c.m.RUnlock() - failpoint.Inject("testRowContainerDeadLock", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("testRowContainerDeadLock")); _err_ == nil { if val.(bool) { time.Sleep(time.Second) } - }) + } if c.alreadySpilled() { if err := c.m.records.spillError; err != nil { return err @@ -559,11 +559,11 @@ func (c *SortedRowContainer) keyColumnsLess(i, j int) bool { c.memTracker.Consume(1) c.timesOfRowCompare = 0 } - failpoint.Inject("SignalCheckpointForSort", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("SignalCheckpointForSort")); _err_ == nil { if val.(bool) { c.timesOfRowCompare += 1024 } - }) + } c.timesOfRowCompare++ rowI := c.m.records.inMemory.GetRow(c.ptrM.rowPtrs[i]) rowJ := c.m.records.inMemory.GetRow(c.ptrM.rowPtrs[j]) @@ -598,11 +598,11 @@ func (c *SortedRowContainer) Sort() (ret error) { c.ptrM.rowPtrs = append(c.ptrM.rowPtrs, RowPtr{ChkIdx: uint32(chkIdx), RowIdx: uint32(rowIdx)}) } } - failpoint.Inject("errorDuringSortRowContainer", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("errorDuringSortRowContainer")); _err_ == nil { if val.(bool) { panic("sort meet error") } - }) + } sort.Slice(c.ptrM.rowPtrs, c.keyColumnsLess) return } diff --git a/pkg/util/chunk/row_container.go__failpoint_stash__ b/pkg/util/chunk/row_container.go__failpoint_stash__ new file mode 100644 index 0000000000000..9b443569fb05c --- /dev/null +++ b/pkg/util/chunk/row_container.go__failpoint_stash__ @@ -0,0 +1,691 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package chunk + +import ( + "fmt" + "sort" + "sync" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util/disk" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "go.uber.org/zap" + "golang.org/x/sys/cpu" +) + +// ErrCannotAddBecauseSorted indicate that the SortPartition is sorted and prohibit inserting data. +var ErrCannotAddBecauseSorted = errors.New("can not add because sorted") + +type rowContainerRecord struct { + inMemory *List + inDisk *DataInDiskByRows + // spillError stores the error when spilling. + spillError error +} + +type mutexForRowContainer struct { + // Add cache padding to avoid false sharing issue. 
+ _ cpu.CacheLinePad + // RWMutex guarantees spill and get operator for rowContainer is mutually exclusive. + // `rLock` and `wLocks` is introduced to reduce the contention when multiple + // goroutine touch the same rowContainer concurrently. If there are multiple + // goroutines touch the same rowContainer concurrently, it's recommended to + // use RowContainer.ShallowCopyWithNewMutex to build a new RowContainer for + // each goroutine. Thus each goroutine holds its own rLock but share the same + // underlying data, which can reduce the contention on m.rLock remarkably and + // get better performance. + rLock sync.RWMutex + wLocks []*sync.RWMutex + records *rowContainerRecord + _ cpu.CacheLinePad +} + +// Lock locks rw for writing. +func (m *mutexForRowContainer) Lock() { + for _, l := range m.wLocks { + l.Lock() + } +} + +// Unlock unlocks rw for writing. +func (m *mutexForRowContainer) Unlock() { + for _, l := range m.wLocks { + l.Unlock() + } +} + +// RLock locks rw for reading. +func (m *mutexForRowContainer) RLock() { + m.rLock.RLock() +} + +// RUnlock undoes a single RLock call. +func (m *mutexForRowContainer) RUnlock() { + m.rLock.RUnlock() +} + +type spillHelper interface { + SpillToDisk() + hasEnoughDataToSpill(t *memory.Tracker) bool +} + +// RowContainer provides a place for many rows, so many that we might want to spill them into disk. +// nolint:structcheck +type RowContainer struct { + m *mutexForRowContainer + + memTracker *memory.Tracker + diskTracker *disk.Tracker + actionSpill *SpillDiskAction +} + +// NewRowContainer creates a new RowContainer in memory. +func NewRowContainer(fieldType []*types.FieldType, chunkSize int) *RowContainer { + li := NewList(fieldType, chunkSize, chunkSize) + rc := &RowContainer{ + m: &mutexForRowContainer{ + records: &rowContainerRecord{inMemory: li}, + rLock: sync.RWMutex{}, + wLocks: []*sync.RWMutex{}, + }, + memTracker: memory.NewTracker(memory.LabelForRowContainer, -1), + diskTracker: disk.NewTracker(memory.LabelForRowContainer, -1), + } + rc.m.wLocks = append(rc.m.wLocks, &rc.m.rLock) + li.GetMemTracker().AttachTo(rc.GetMemTracker()) + return rc +} + +// ShallowCopyWithNewMutex shallow clones a RowContainer. +// The new RowContainer shares the same underlying data with the old one but +// holds an individual rLock. +func (c *RowContainer) ShallowCopyWithNewMutex() *RowContainer { + newRC := *c + newRC.m = &mutexForRowContainer{ + records: c.m.records, + rLock: sync.RWMutex{}, + wLocks: []*sync.RWMutex{}, + } + c.m.wLocks = append(c.m.wLocks, &newRC.m.rLock) + return &newRC +} + +// SpillToDisk spills data to disk. This function may be called in parallel. +func (c *RowContainer) SpillToDisk() { + c.spillToDisk(nil) +} + +func (*RowContainer) hasEnoughDataToSpill(_ *memory.Tracker) bool { + return true +} + +func (c *RowContainer) spillToDisk(preSpillError error) { + c.m.Lock() + defer c.m.Unlock() + if c.alreadySpilled() { + return + } + // c.actionSpill may be nil when testing SpillToDisk directly. + if c.actionSpill != nil { + if c.actionSpill.getStatus() == spilledYet { + // The rowContainer has been closed. 
+ return + } + c.actionSpill.setStatus(spilling) + defer c.actionSpill.cond.Broadcast() + defer c.actionSpill.setStatus(spilledYet) + } + var err error + memory.QueryForceDisk.Add(1) + n := c.m.records.inMemory.NumChunks() + c.m.records.inDisk = NewDataInDiskByRows(c.m.records.inMemory.FieldTypes()) + c.m.records.inDisk.diskTracker.AttachTo(c.diskTracker) + defer func() { + if r := recover(); r != nil { + err := fmt.Errorf("%v", r) + c.m.records.spillError = err + logutil.BgLogger().Error("spill to disk failed", zap.Stack("stack"), zap.Error(err)) + } + }() + failpoint.Inject("spillToDiskOutOfDiskQuota", func(val failpoint.Value) { + if val.(bool) { + panic("out of disk quota when spilling") + } + }) + if preSpillError != nil { + c.m.records.spillError = preSpillError + return + } + for i := 0; i < n; i++ { + chk := c.m.records.inMemory.GetChunk(i) + err = c.m.records.inDisk.Add(chk) + if err != nil { + c.m.records.spillError = err + return + } + c.m.records.inMemory.GetMemTracker().HandleKillSignal() + } + c.m.records.inMemory.Clear() +} + +// Reset resets RowContainer. +func (c *RowContainer) Reset() error { + c.m.Lock() + defer c.m.Unlock() + if c.alreadySpilled() { + err := c.m.records.inDisk.Close() + c.m.records.inDisk = nil + if err != nil { + return err + } + c.actionSpill.Reset() + } else { + c.m.records.inMemory.Reset() + } + return nil +} + +// alreadySpilled indicates that records have spilled out into disk. +func (c *RowContainer) alreadySpilled() bool { + return c.m.records.inDisk != nil +} + +// AlreadySpilledSafeForTest indicates that records have spilled out into disk. It's thread-safe. +// The function is only used for test. +func (c *RowContainer) AlreadySpilledSafeForTest() bool { + c.m.RLock() + defer c.m.RUnlock() + return c.m.records.inDisk != nil +} + +// NumRow returns the number of rows in the container +func (c *RowContainer) NumRow() int { + c.m.RLock() + defer c.m.RUnlock() + if c.alreadySpilled() { + return c.m.records.inDisk.Len() + } + return c.m.records.inMemory.Len() +} + +// NumRowsOfChunk returns the number of rows of a chunk in the DataInDiskByRows. +func (c *RowContainer) NumRowsOfChunk(chkID int) int { + c.m.RLock() + defer c.m.RUnlock() + if c.alreadySpilled() { + return c.m.records.inDisk.NumRowsOfChunk(chkID) + } + return c.m.records.inMemory.NumRowsOfChunk(chkID) +} + +// NumChunks returns the number of chunks in the container. +func (c *RowContainer) NumChunks() int { + c.m.RLock() + defer c.m.RUnlock() + if c.alreadySpilled() { + return c.m.records.inDisk.NumChunks() + } + return c.m.records.inMemory.NumChunks() +} + +// Add appends a chunk into the RowContainer. +func (c *RowContainer) Add(chk *Chunk) (err error) { + c.m.RLock() + defer c.m.RUnlock() + failpoint.Inject("testRowContainerDeadLock", func(val failpoint.Value) { + if val.(bool) { + time.Sleep(time.Second) + } + }) + if c.alreadySpilled() { + if err := c.m.records.spillError; err != nil { + return err + } + err = c.m.records.inDisk.Add(chk) + } else { + c.m.records.inMemory.Add(chk) + } + return +} + +// AllocChunk allocates a new chunk from RowContainer. +func (c *RowContainer) AllocChunk() (chk *Chunk) { + return c.m.records.inMemory.allocChunk() +} + +// GetChunk returns chkIdx th chunk of in memory records. 
+func (c *RowContainer) GetChunk(chkIdx int) (*Chunk, error) { + c.m.RLock() + defer c.m.RUnlock() + if !c.alreadySpilled() { + return c.m.records.inMemory.GetChunk(chkIdx), nil + } + if err := c.m.records.spillError; err != nil { + return nil, err + } + return c.m.records.inDisk.GetChunk(chkIdx) +} + +// GetRow returns the row the ptr pointed to. +func (c *RowContainer) GetRow(ptr RowPtr) (row Row, err error) { + row, _, err = c.GetRowAndAppendToChunkIfInDisk(ptr, nil) + return row, err +} + +// GetRowAndAppendToChunkIfInDisk gets a Row from the RowContainer by RowPtr. If the container has spilled, the row will +// be appended to the chunk. It'll return `nil` chunk if the container hasn't spilled, or it returns an error. +func (c *RowContainer) GetRowAndAppendToChunkIfInDisk(ptr RowPtr, chk *Chunk) (row Row, _ *Chunk, err error) { + c.m.RLock() + defer c.m.RUnlock() + if c.alreadySpilled() { + if err := c.m.records.spillError; err != nil { + return Row{}, nil, err + } + return c.m.records.inDisk.GetRowAndAppendToChunk(ptr, chk) + } + return c.m.records.inMemory.GetRow(ptr), nil, nil +} + +// GetRowAndAlwaysAppendToChunk gets a Row from the RowContainer by RowPtr. Unlike `GetRowAndAppendToChunkIfInDisk`, this +// function always appends the row to the chunk, without considering whether it has spilled. +// It'll return `nil` chunk if it returns an error, or the chunk will be the same with the argument. +func (c *RowContainer) GetRowAndAlwaysAppendToChunk(ptr RowPtr, chk *Chunk) (row Row, _ *Chunk, err error) { + row, retChk, err := c.GetRowAndAppendToChunkIfInDisk(ptr, chk) + if err != nil { + return row, nil, err + } + + if retChk == nil { + // The container hasn't spilled, and the row is not appended to the chunk, so append the chunk explicitly here + chk.AppendRow(row) + } + + return row, chk, nil +} + +// GetMemTracker returns the memory tracker in records, panics if the RowContainer has already spilled. +func (c *RowContainer) GetMemTracker() *memory.Tracker { + return c.memTracker +} + +// GetDiskTracker returns the underlying disk usage tracker in recordsInDisk. +func (c *RowContainer) GetDiskTracker() *disk.Tracker { + return c.diskTracker +} + +// Close close the RowContainer +func (c *RowContainer) Close() (err error) { + c.m.RLock() + defer c.m.RUnlock() + if c.actionSpill != nil { + // Set status to spilledYet to avoid spilling. + c.actionSpill.setStatus(spilledYet) + c.actionSpill.cond.Broadcast() + c.actionSpill.SetFinished() + } + c.memTracker.Detach() + c.diskTracker.Detach() + if c.alreadySpilled() { + err = c.m.records.inDisk.Close() + c.m.records.inDisk = nil + } + c.m.records.inMemory.Clear() + c.m.records.inMemory = nil + return +} + +// ActionSpill returns a SpillDiskAction for spilling over to disk. +func (c *RowContainer) ActionSpill() *SpillDiskAction { + if c.actionSpill == nil { + c.actionSpill = &SpillDiskAction{ + c: c, + baseSpillDiskAction: &baseSpillDiskAction{cond: spillStatusCond{sync.NewCond(new(sync.Mutex)), notSpilled}}, + } + } + return c.actionSpill +} + +// ActionSpillForTest returns a SpillDiskAction for spilling over to disk for test. 
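For context on the API these hunks touch, here is a hedged usage sketch of RowContainer: NewRowContainer, Add, NumChunks, GetChunk and Close are taken from the file above, while the field-type and chunk construction details are assumptions made for illustration only:

package main

import (
    "fmt"

    "github.com/pingcap/tidb/pkg/parser/mysql"
    "github.com/pingcap/tidb/pkg/types"
    "github.com/pingcap/tidb/pkg/util/chunk"
)

func main() {
    fieldTypes := []*types.FieldType{types.NewFieldType(mysql.TypeLonglong)}
    rc := chunk.NewRowContainer(fieldTypes, 32)

    // Fill one chunk and hand it to the container; before any spill it stays in memory.
    chk := chunk.New(fieldTypes, 32, 32)
    for i := int64(0); i < 10; i++ {
        chk.AppendInt64(0, i)
    }
    if err := rc.Add(chk); err != nil {
        panic(err)
    }

    // Read it back chunk by chunk; after a spill GetChunk transparently reads from disk.
    for i := 0; i < rc.NumChunks(); i++ {
        out, err := rc.GetChunk(i)
        if err != nil {
            panic(err)
        }
        fmt.Println("rows in chunk", i, "=", out.NumRows())
    }
    _ = rc.Close()
}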
+func (c *RowContainer) ActionSpillForTest() *SpillDiskAction { + c.actionSpill = &SpillDiskAction{ + c: c, + baseSpillDiskAction: &baseSpillDiskAction{ + testSyncInputFunc: func() { + c.actionSpill.testWg.Add(1) + }, + testSyncOutputFunc: func() { + c.actionSpill.testWg.Done() + }, + cond: spillStatusCond{sync.NewCond(new(sync.Mutex)), notSpilled}, + }, + } + return c.actionSpill +} + +type baseSpillDiskAction struct { + memory.BaseOOMAction + m sync.Mutex + once sync.Once + cond spillStatusCond + + // test function only used for test sync. + testSyncInputFunc func() + testSyncOutputFunc func() + testWg sync.WaitGroup +} + +// SpillDiskAction implements memory.ActionOnExceed for chunk.List. If +// the memory quota of a query is exceeded, SpillDiskAction.Action is +// triggered. +type SpillDiskAction struct { + c *RowContainer + *baseSpillDiskAction +} + +// Action sends a signal to trigger spillToDisk method of RowContainer +// and if it is already triggered before, call its fallbackAction. +func (a *SpillDiskAction) Action(t *memory.Tracker) { + a.action(t, a.c) +} + +type spillStatusCond struct { + *sync.Cond + // status indicates different stages for the Action + // notSpilled indicates the rowContainer is not spilled. + // spilling indicates the rowContainer is spilling. + // spilledYet indicates thr rowContainer is spilled. + status spillStatus +} + +type spillStatus uint32 + +const ( + notSpilled spillStatus = iota + spilling + spilledYet +) + +func (a *baseSpillDiskAction) setStatus(status spillStatus) { + a.cond.L.Lock() + defer a.cond.L.Unlock() + a.cond.status = status +} + +func (a *baseSpillDiskAction) getStatus() spillStatus { + a.cond.L.Lock() + defer a.cond.L.Unlock() + return a.cond.status +} + +func (a *baseSpillDiskAction) action(t *memory.Tracker, spillHelper spillHelper) { + a.m.Lock() + defer a.m.Unlock() + + if a.getStatus() == notSpilled && spillHelper.hasEnoughDataToSpill(t) { + a.once.Do(func() { + logutil.BgLogger().Info("memory exceeds quota, spill to disk now.", + zap.Int64("consumed", t.BytesConsumed()), zap.Int64("quota", t.GetBytesLimit())) + if a.testSyncInputFunc != nil { + a.testSyncInputFunc() + go func() { + spillHelper.SpillToDisk() + a.testSyncOutputFunc() + }() + return + } + go spillHelper.SpillToDisk() + }) + return + } + + a.cond.L.Lock() + for a.cond.status == spilling { + a.cond.Wait() + } + a.cond.L.Unlock() + + if !t.CheckExceed() { + return + } + if fallback := a.GetFallback(); fallback != nil { + fallback.Action(t) + } +} + +// Reset resets the status for SpillDiskAction. +func (a *baseSpillDiskAction) Reset() { + a.m.Lock() + defer a.m.Unlock() + a.setStatus(notSpilled) + a.once = sync.Once{} +} + +// GetPriority get the priority of the Action. +func (*baseSpillDiskAction) GetPriority() int64 { + return memory.DefSpillPriority +} + +// WaitForTest waits all goroutine have gone. +func (a *baseSpillDiskAction) WaitForTest() { + a.testWg.Wait() +} + +// SortedRowContainer provides a place for many rows, so many that we might want to sort and spill them into disk. +type SortedRowContainer struct { + *RowContainer + ptrM struct { + sync.RWMutex + // rowPtrs store the chunk index and row index for each row. + // rowPtrs != nil indicates the pointer is initialized and sorted. + // It will get an ErrCannotAddBecauseSorted when trying to insert data if rowPtrs != nil. + rowPtrs []RowPtr + } + + ByItemsDesc []bool + // keyColumns is the column index of the by items. + keyColumns []int + // keyCmpFuncs is used to compare each ByItem. 
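The baseSpillDiskAction.action method above serializes concurrent memory-exceeded callbacks through a condition variable over a three-state status (notSpilled, spilling, spilledYet): the first caller starts the spill, later callers wait for it to finish. A minimal sketch of that wait protocol only, with shortened names and none of the tracker/fallback handling:

package demo

import "sync"

type spillState int

const (
    notSpilled spillState = iota
    spilling
    spilledYet
)

// spillGate lets exactly one caller run the spill while later callers wait
// for it to finish, mirroring the status transitions used above.
type spillGate struct {
    cond  *sync.Cond
    state spillState
}

func newSpillGate() *spillGate { return &spillGate{cond: sync.NewCond(new(sync.Mutex))} }

func (g *spillGate) trigger(spill func()) {
    g.cond.L.Lock()
    if g.state == notSpilled {
        g.state = spilling
        g.cond.L.Unlock()
        spill()
        g.cond.L.Lock()
        g.state = spilledYet
        g.cond.Broadcast()
        g.cond.L.Unlock()
        return
    }
    // Someone else is spilling (or already spilled): wait until it is done.
    for g.state == spilling {
        g.cond.Wait()
    }
    g.cond.L.Unlock()
}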
+ keyCmpFuncs []CompareFunc + + actionSpill *SortAndSpillDiskAction + memTracker *memory.Tracker + + // Sort is a time-consuming operation, we need to set a checkpoint to detect + // the outside signal periodically. + timesOfRowCompare uint +} + +// NewSortedRowContainer creates a new SortedRowContainer in memory. +func NewSortedRowContainer(fieldType []*types.FieldType, chunkSize int, byItemsDesc []bool, + keyColumns []int, keyCmpFuncs []CompareFunc) *SortedRowContainer { + src := SortedRowContainer{RowContainer: NewRowContainer(fieldType, chunkSize), + ByItemsDesc: byItemsDesc, keyColumns: keyColumns, keyCmpFuncs: keyCmpFuncs} + src.memTracker = memory.NewTracker(memory.LabelForRowContainer, -1) + src.RowContainer.GetMemTracker().AttachTo(src.GetMemTracker()) + return &src +} + +// Close close the SortedRowContainer +func (c *SortedRowContainer) Close() error { + c.ptrM.Lock() + defer c.ptrM.Unlock() + c.GetMemTracker().Consume(int64(-8 * c.NumRow())) + c.ptrM.rowPtrs = nil + return c.RowContainer.Close() +} + +func (c *SortedRowContainer) lessRow(rowI, rowJ Row) bool { + for i, colIdx := range c.keyColumns { + cmpFunc := c.keyCmpFuncs[i] + if cmpFunc != nil { + cmp := cmpFunc(rowI, colIdx, rowJ, colIdx) + if c.ByItemsDesc[i] { + cmp = -cmp + } + if cmp < 0 { + return true + } else if cmp > 0 { + return false + } + } + } + return false +} + +// SignalCheckpointForSort indicates the times of row comparation that a signal detection will be triggered. +const SignalCheckpointForSort uint = 10240 + +// keyColumnsLess is the less function for key columns. +func (c *SortedRowContainer) keyColumnsLess(i, j int) bool { + if c.timesOfRowCompare >= SignalCheckpointForSort { + // Trigger Consume for checking the NeedKill signal + c.memTracker.Consume(1) + c.timesOfRowCompare = 0 + } + failpoint.Inject("SignalCheckpointForSort", func(val failpoint.Value) { + if val.(bool) { + c.timesOfRowCompare += 1024 + } + }) + c.timesOfRowCompare++ + rowI := c.m.records.inMemory.GetRow(c.ptrM.rowPtrs[i]) + rowJ := c.m.records.inMemory.GetRow(c.ptrM.rowPtrs[j]) + return c.lessRow(rowI, rowJ) +} + +// Sort inits pointers and sorts the records. +func (c *SortedRowContainer) Sort() (ret error) { + c.ptrM.Lock() + defer c.ptrM.Unlock() + ret = nil + defer func() { + if r := recover(); r != nil { + if err, ok := r.(error); ok { + ret = err + } else { + ret = fmt.Errorf("%v", r) + } + } + }() + if c.ptrM.rowPtrs != nil { + return + } + c.ptrM.rowPtrs = make([]RowPtr, 0, c.NumRow()) // The memory usage has been tracked in SortedRowContainer.Add() function + for chkIdx := 0; chkIdx < c.NumChunks(); chkIdx++ { + rowChk, err := c.GetChunk(chkIdx) + // err must be nil, because the chunk is in memory. + if err != nil { + panic(err) + } + for rowIdx := 0; rowIdx < rowChk.NumRows(); rowIdx++ { + c.ptrM.rowPtrs = append(c.ptrM.rowPtrs, RowPtr{ChkIdx: uint32(chkIdx), RowIdx: uint32(rowIdx)}) + } + } + failpoint.Inject("errorDuringSortRowContainer", func(val failpoint.Value) { + if val.(bool) { + panic("sort meet error") + } + }) + sort.Slice(c.ptrM.rowPtrs, c.keyColumnsLess) + return +} + +// SpillToDisk spills data to disk. This function may be called in parallel. +func (c *SortedRowContainer) SpillToDisk() { + err := c.Sort() + c.RowContainer.spillToDisk(err) +} + +func (c *SortedRowContainer) hasEnoughDataToSpill(t *memory.Tracker) bool { + // Guarantee that each partition size is at least 10% of the threshold, to avoid opening too many files. 
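keyColumnsLess above keeps a comparison counter so that a long in-memory sort periodically pokes the memory tracker (via Consume(1)) and can observe a kill signal. A standalone sketch of that checkpoint idea; the cancelled callback is a stand-in for the tracker and is not part of the real code:

package main

import (
    "errors"
    "fmt"
    "sort"
)

const signalCheckpoint = 10240 // mirrors SignalCheckpointForSort above

type checkpointSorter struct {
    data      []int
    compares  uint
    cancelled func() bool // stand-in for the memory tracker's kill-signal check
}

func (s *checkpointSorter) Len() int      { return len(s.data) }
func (s *checkpointSorter) Swap(i, j int) { s.data[i], s.data[j] = s.data[j], s.data[i] }
func (s *checkpointSorter) Less(i, j int) bool {
    if s.compares >= signalCheckpoint {
        s.compares = 0
        if s.cancelled() {
            // The real code surfaces this through Consume and the recover in Sort().
            panic(errors.New("sort cancelled"))
        }
    }
    s.compares++
    return s.data[i] < s.data[j]
}

func main() {
    s := &checkpointSorter{data: []int{3, 1, 2}, cancelled: func() bool { return false }}
    sort.Sort(s)
    fmt.Println(s.data)
}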
+ return c.GetMemTracker().BytesConsumed() > t.GetBytesLimit()/10 +} + +// Add appends a chunk into the SortedRowContainer. +func (c *SortedRowContainer) Add(chk *Chunk) (err error) { + c.ptrM.RLock() + defer c.ptrM.RUnlock() + if c.ptrM.rowPtrs != nil { + return ErrCannotAddBecauseSorted + } + // Consume the memory usage of rowPtrs in advance + c.GetMemTracker().Consume(int64(chk.NumRows() * 8)) + return c.RowContainer.Add(chk) +} + +// GetSortedRow returns the row the idx pointed to. +func (c *SortedRowContainer) GetSortedRow(idx int) (Row, error) { + c.ptrM.RLock() + defer c.ptrM.RUnlock() + ptr := c.ptrM.rowPtrs[idx] + return c.RowContainer.GetRow(ptr) +} + +// GetSortedRowAndAlwaysAppendToChunk returns the row the idx pointed to. +func (c *SortedRowContainer) GetSortedRowAndAlwaysAppendToChunk(idx int, chk *Chunk) (Row, *Chunk, error) { + c.ptrM.RLock() + defer c.ptrM.RUnlock() + ptr := c.ptrM.rowPtrs[idx] + return c.RowContainer.GetRowAndAlwaysAppendToChunk(ptr, chk) +} + +// ActionSpill returns a SortAndSpillDiskAction for sorting and spilling over to disk. +func (c *SortedRowContainer) ActionSpill() *SortAndSpillDiskAction { + if c.actionSpill == nil { + c.actionSpill = &SortAndSpillDiskAction{ + c: c, + baseSpillDiskAction: c.RowContainer.ActionSpill().baseSpillDiskAction, + } + } + return c.actionSpill +} + +// ActionSpillForTest returns a SortAndSpillDiskAction for sorting and spilling over to disk for test. +func (c *SortedRowContainer) ActionSpillForTest() *SortAndSpillDiskAction { + c.actionSpill = &SortAndSpillDiskAction{ + c: c, + baseSpillDiskAction: c.RowContainer.ActionSpillForTest().baseSpillDiskAction, + } + return c.actionSpill +} + +// GetMemTracker return the memory tracker for the sortedRowContainer +func (c *SortedRowContainer) GetMemTracker() *memory.Tracker { + return c.memTracker +} + +// SortAndSpillDiskAction implements memory.ActionOnExceed for chunk.List. If +// the memory quota of a query is exceeded, SortAndSpillDiskAction.Action is +// triggered. +type SortAndSpillDiskAction struct { + c *SortedRowContainer + *baseSpillDiskAction +} + +// Action sends a signal to trigger sortAndSpillToDisk method of RowContainer +// and if it is already triggered before, call its fallbackAction. +func (a *SortAndSpillDiskAction) Action(t *memory.Tracker) { + a.action(t, a.c) +} + +// WaitForTest waits all goroutine have gone. +func (a *SortAndSpillDiskAction) WaitForTest() { + a.testWg.Wait() +} diff --git a/pkg/util/chunk/row_container_reader.go b/pkg/util/chunk/row_container_reader.go index ca124083079c5..797934e66e0e7 100644 --- a/pkg/util/chunk/row_container_reader.go +++ b/pkg/util/chunk/row_container_reader.go @@ -124,11 +124,11 @@ func (reader *rowContainerReader) startWorker() { for chkIdx := 0; chkIdx < reader.rc.NumChunks(); chkIdx++ { chk, err := reader.rc.GetChunk(chkIdx) - failpoint.Inject("get-chunk-error", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("get-chunk-error")); _err_ == nil { if val.(bool) { err = errors.New("fail to get chunk for test") } - }) + } if err != nil { reader.err = err return diff --git a/pkg/util/chunk/row_container_reader.go__failpoint_stash__ b/pkg/util/chunk/row_container_reader.go__failpoint_stash__ new file mode 100644 index 0000000000000..ca124083079c5 --- /dev/null +++ b/pkg/util/chunk/row_container_reader.go__failpoint_stash__ @@ -0,0 +1,170 @@ +// Copyright 2023 PingCAP, Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package chunk + +import ( + "context" + "runtime" + "sync" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/util/logutil" +) + +// RowContainerReader is a forward-only iterator for the row container. It provides an interface similar to other +// iterators, but it doesn't provide `ReachEnd` function and requires manually closing to release goroutine. +// +// It's recommended to use the following pattern to use it: +// +// for iter := NewRowContainerReader(rc); iter.Current() != iter.End(); iter.Next() { +// ... +// } +// iter.Close() +// if iter.Error() != nil { +// } +type RowContainerReader interface { + // Next returns the next Row. + Next() Row + + // Current returns the current Row. + Current() Row + + // End returns the invalid end Row. + End() Row + + // Error returns none-nil error if anything wrong happens during the iteration. + Error() error + + // Close closes the dumper + Close() +} + +var _ RowContainerReader = &rowContainerReader{} + +// rowContainerReader is a forward-only iterator for the row container +// It will spawn two goroutines for reading chunks from disk, and converting the chunk to rows. The row will only be sent +// to `rowCh` inside only after when the full chunk has been read, to avoid concurrently read/write to the chunk. +// +// TODO: record the memory allocated for the channel and chunks. +type rowContainerReader struct { + // context, cancel and waitgroup are used to stop and wait until all goroutine stops. + ctx context.Context + cancel func() + wg sync.WaitGroup + + rc *RowContainer + + currentRow Row + rowCh chan Row + + // this error will only be set by worker + err error +} + +// Next implements RowContainerReader +func (reader *rowContainerReader) Next() Row { + for row := range reader.rowCh { + reader.currentRow = row + return row + } + reader.currentRow = reader.End() + return reader.End() +} + +// Current implements RowContainerReader +func (reader *rowContainerReader) Current() Row { + return reader.currentRow +} + +// End implements RowContainerReader +func (*rowContainerReader) End() Row { + return Row{} +} + +// Error implements RowContainerReader +func (reader *rowContainerReader) Error() error { + return reader.err +} + +func (reader *rowContainerReader) initializeChannel() { + if reader.rc.NumChunks() == 0 { + reader.rowCh = make(chan Row, 1024) + } else { + assumeChunkSize := reader.rc.NumRowsOfChunk(0) + // To avoid blocking in sending to `rowCh` and don't start reading the next chunk, it'd be better to give it + // a buffer at least larger than a single chunk. Here it's allocated twice the chunk size to leave some margin. 
+ reader.rowCh = make(chan Row, 2*assumeChunkSize) + } +} + +// Close implements RowContainerReader +func (reader *rowContainerReader) Close() { + reader.cancel() + reader.wg.Wait() +} + +func (reader *rowContainerReader) startWorker() { + reader.wg.Add(1) + go func() { + defer close(reader.rowCh) + defer reader.wg.Done() + + for chkIdx := 0; chkIdx < reader.rc.NumChunks(); chkIdx++ { + chk, err := reader.rc.GetChunk(chkIdx) + failpoint.Inject("get-chunk-error", func(val failpoint.Value) { + if val.(bool) { + err = errors.New("fail to get chunk for test") + } + }) + if err != nil { + reader.err = err + return + } + + for i := 0; i < chk.NumRows(); i++ { + select { + case reader.rowCh <- chk.GetRow(i): + case <-reader.ctx.Done(): + return + } + } + } + }() +} + +// NewRowContainerReader creates a forward only iterator for row container +func NewRowContainerReader(rc *RowContainer) *rowContainerReader { + ctx, cancel := context.WithCancel(context.Background()) + + reader := &rowContainerReader{ + ctx: ctx, + cancel: cancel, + wg: sync.WaitGroup{}, + + rc: rc, + } + reader.initializeChannel() + reader.startWorker() + reader.Next() + runtime.SetFinalizer(reader, func(reader *rowContainerReader) { + if reader.ctx.Err() == nil { + logutil.BgLogger().Warn("rowContainerReader is closed by finalizer") + reader.Close() + } + }) + + return reader +} diff --git a/pkg/util/codec/binding__failpoint_binding__.go b/pkg/util/codec/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..894075781dd35 --- /dev/null +++ b/pkg/util/codec/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package codec + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/util/codec/decimal.go b/pkg/util/codec/decimal.go index 7b26feb86b2d8..a4a0d80c78f87 100644 --- a/pkg/util/codec/decimal.go +++ b/pkg/util/codec/decimal.go @@ -47,11 +47,11 @@ func valueSizeOfDecimal(dec *types.MyDecimal, precision, frac int) (int, error) // DecodeDecimal decodes bytes to decimal. func DecodeDecimal(b []byte) ([]byte, *types.MyDecimal, int, int, error) { - failpoint.Inject("errorInDecodeDecimal", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("errorInDecodeDecimal")); _err_ == nil { if val.(bool) { - failpoint.Return(b, nil, 0, 0, errors.New("gofail error")) + return b, nil, 0, 0, errors.New("gofail error") } - }) + } if len(b) < 3 { return b, nil, 0, 0, errors.New("insufficient bytes to decode value") diff --git a/pkg/util/codec/decimal.go__failpoint_stash__ b/pkg/util/codec/decimal.go__failpoint_stash__ new file mode 100644 index 0000000000000..7b26feb86b2d8 --- /dev/null +++ b/pkg/util/codec/decimal.go__failpoint_stash__ @@ -0,0 +1,69 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
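The RowContainerReader doc comment above prescribes a specific forward-only iteration pattern: iterate until Current() equals End(), then Close(), then check Error(). A small sketch following that pattern; summing the first column is an arbitrary example, not something this patch does:

package demo

import "github.com/pingcap/tidb/pkg/util/chunk"

// sumFirstColumn iterates rc with the forward-only reader, following the
// pattern from the RowContainerReader doc comment above.
func sumFirstColumn(rc *chunk.RowContainer) (int64, error) {
    var sum int64
    iter := chunk.NewRowContainerReader(rc)
    for ; iter.Current() != iter.End(); iter.Next() {
        sum += iter.Current().GetInt64(0)
    }
    iter.Close()
    return sum, iter.Error()
}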
+// See the License for the specific language governing permissions and +// limitations under the License. + +package codec + +import ( + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/types" +) + +// EncodeDecimal encodes a decimal into a byte slice which can be sorted lexicographically later. +func EncodeDecimal(b []byte, dec *types.MyDecimal, precision, frac int) ([]byte, error) { + if precision == 0 { + precision, frac = dec.PrecisionAndFrac() + } + if frac > mysql.MaxDecimalScale { + frac = mysql.MaxDecimalScale + } + b = append(b, byte(precision), byte(frac)) + b, err := dec.WriteBin(precision, frac, b) + return b, errors.Trace(err) +} + +func valueSizeOfDecimal(dec *types.MyDecimal, precision, frac int) (int, error) { + if precision == 0 { + precision, frac = dec.PrecisionAndFrac() + } + binSize, err := types.DecimalBinSize(precision, frac) + if err != nil { + return 0, err + } + return binSize + 2, nil +} + +// DecodeDecimal decodes bytes to decimal. +func DecodeDecimal(b []byte) ([]byte, *types.MyDecimal, int, int, error) { + failpoint.Inject("errorInDecodeDecimal", func(val failpoint.Value) { + if val.(bool) { + failpoint.Return(b, nil, 0, 0, errors.New("gofail error")) + } + }) + + if len(b) < 3 { + return b, nil, 0, 0, errors.New("insufficient bytes to decode value") + } + precision := int(b[0]) + frac := int(b[1]) + b = b[2:] + dec := new(types.MyDecimal) + binSize, err := dec.FromBin(b, precision, frac) + b = b[binSize:] + if err != nil { + return b, nil, precision, frac, errors.Trace(err) + } + return b, dec, precision, frac, nil +} diff --git a/pkg/util/cpu/binding__failpoint_binding__.go b/pkg/util/cpu/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..440594fb462e6 --- /dev/null +++ b/pkg/util/cpu/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package cpu + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/util/cpu/cpu.go b/pkg/util/cpu/cpu.go index 6d14cbfada58d..c6a72ae64de22 100644 --- a/pkg/util/cpu/cpu.go +++ b/pkg/util/cpu/cpu.go @@ -125,8 +125,8 @@ func getCPUTime() (userTimeMillis, sysTimeMillis int64, err error) { // GetCPUCount returns the number of logical CPUs usable by the current process. func GetCPUCount() int { - failpoint.Inject("mockNumCpu", func(val failpoint.Value) { - failpoint.Return(val.(int)) - }) + if val, _err_ := failpoint.Eval(_curpkg_("mockNumCpu")); _err_ == nil { + return val.(int) + } return runtime.GOMAXPROCS(0) } diff --git a/pkg/util/cpu/cpu.go__failpoint_stash__ b/pkg/util/cpu/cpu.go__failpoint_stash__ new file mode 100644 index 0000000000000..6d14cbfada58d --- /dev/null +++ b/pkg/util/cpu/cpu.go__failpoint_stash__ @@ -0,0 +1,132 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package cpu + +import ( + "os" + "runtime" + "sync" + "time" + + sigar "github.com/cloudfoundry/gosigar" + "github.com/pingcap/failpoint" + "github.com/pingcap/log" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/util/cgroup" + "github.com/pingcap/tidb/pkg/util/mathutil" + "go.uber.org/atomic" + "go.uber.org/zap" +) + +var cpuUsage atomic.Float64 + +// If your kernel is lower than linux 4.7, you cannot get the cpu usage in the container. +var unsupported atomic.Bool + +// GetCPUUsage returns the cpu usage of the current process. +func GetCPUUsage() (float64, bool) { + return cpuUsage.Load(), unsupported.Load() +} + +// Observer is used to observe the cpu usage of the current process. +type Observer struct { + utime int64 + stime int64 + now int64 + exit chan struct{} + cpu mathutil.ExponentialMovingAverage + wg sync.WaitGroup +} + +// NewCPUObserver returns a cpu observer. +func NewCPUObserver() *Observer { + return &Observer{ + exit: make(chan struct{}), + now: time.Now().UnixNano(), + cpu: *mathutil.NewExponentialMovingAverage(0.95, 10), + } +} + +// Start starts the cpu observer. +func (c *Observer) Start() { + _, err := cgroup.GetCgroupCPU() + if err != nil { + unsupported.Store(true) + log.Error("GetCgroupCPU", zap.Error(err)) + return + } + c.wg.Add(1) + go func() { + ticker := time.NewTicker(100 * time.Millisecond) + defer func() { + ticker.Stop() + c.wg.Done() + }() + for { + select { + case <-ticker.C: + curr := c.observe() + c.cpu.Add(curr) + cpuUsage.Store(c.cpu.Get()) + metrics.EMACPUUsageGauge.Set(c.cpu.Get()) + case <-c.exit: + return + } + } + }() +} + +// Stop stops the cpu observer. +func (c *Observer) Stop() { + close(c.exit) + c.wg.Wait() +} + +func (c *Observer) observe() float64 { + user, sys, err := getCPUTime() + if err != nil { + log.Error("getCPUTime", zap.Error(err)) + } + cgroupCPU, _ := cgroup.GetCgroupCPU() + cpuShare := cgroupCPU.CPUShares() + now := time.Now().UnixNano() + dur := float64(now - c.now) + utime := user * 1e6 + stime := sys * 1e6 + urate := float64(utime-c.utime) / dur + srate := float64(stime-c.stime) / dur + c.now = now + c.utime = utime + c.stime = stime + return (srate + urate) / cpuShare +} + +// getCPUTime returns the cumulative user/system time (in ms) since the process start. +func getCPUTime() (userTimeMillis, sysTimeMillis int64, err error) { + pid := os.Getpid() + cpuTime := sigar.ProcTime{} + if err := cpuTime.Get(pid); err != nil { + return 0, 0, err + } + return int64(cpuTime.User), int64(cpuTime.Sys), nil +} + +// GetCPUCount returns the number of logical CPUs usable by the current process. 
+func GetCPUCount() int { + failpoint.Inject("mockNumCpu", func(val failpoint.Value) { + failpoint.Return(val.(int)) + }) + return runtime.GOMAXPROCS(0) +} diff --git a/pkg/util/etcd.go b/pkg/util/etcd.go index bff00a8d66428..66b28b29ae44f 100644 --- a/pkg/util/etcd.go +++ b/pkg/util/etcd.go @@ -51,21 +51,21 @@ func NewSession(ctx context.Context, logPrefix string, etcdCli *clientv3.Client, return etcdSession, errors.Trace(err) } - failpoint.Inject("closeClient", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("closeClient")); _err_ == nil { if val.(bool) { if err := etcdCli.Close(); err != nil { - failpoint.Return(etcdSession, errors.Trace(err)) + return etcdSession, errors.Trace(err) } } - }) + } - failpoint.Inject("closeGrpc", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("closeGrpc")); _err_ == nil { if val.(bool) { if err := etcdCli.ActiveConnection().Close(); err != nil { - failpoint.Return(etcdSession, errors.Trace(err)) + return etcdSession, errors.Trace(err) } } - }) + } startTime := time.Now() etcdSession, err = concurrency.NewSession(etcdCli, diff --git a/pkg/util/etcd.go__failpoint_stash__ b/pkg/util/etcd.go__failpoint_stash__ new file mode 100644 index 0000000000000..bff00a8d66428 --- /dev/null +++ b/pkg/util/etcd.go__failpoint_stash__ @@ -0,0 +1,103 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "context" + "math" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/metrics" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/util/logutil" + clientv3 "go.etcd.io/etcd/client/v3" + "go.etcd.io/etcd/client/v3/concurrency" + "go.uber.org/zap" + "google.golang.org/grpc" +) + +const ( + newSessionRetryInterval = 200 * time.Millisecond + logIntervalCnt = int(3 * time.Second / newSessionRetryInterval) + + // NewSessionDefaultRetryCnt is the default retry times when create new session. + NewSessionDefaultRetryCnt = 3 + // NewSessionRetryUnlimited is the unlimited retry times when create new session. + NewSessionRetryUnlimited = math.MaxInt64 +) + +// NewSession creates a new etcd session. 
+func NewSession(ctx context.Context, logPrefix string, etcdCli *clientv3.Client, retryCnt, ttl int) (*concurrency.Session, error) { + var err error + + var etcdSession *concurrency.Session + failedCnt := 0 + for i := 0; i < retryCnt; i++ { + if err = contextDone(ctx, err); err != nil { + return etcdSession, errors.Trace(err) + } + + failpoint.Inject("closeClient", func(val failpoint.Value) { + if val.(bool) { + if err := etcdCli.Close(); err != nil { + failpoint.Return(etcdSession, errors.Trace(err)) + } + } + }) + + failpoint.Inject("closeGrpc", func(val failpoint.Value) { + if val.(bool) { + if err := etcdCli.ActiveConnection().Close(); err != nil { + failpoint.Return(etcdSession, errors.Trace(err)) + } + } + }) + + startTime := time.Now() + etcdSession, err = concurrency.NewSession(etcdCli, + concurrency.WithTTL(ttl), concurrency.WithContext(ctx)) + metrics.NewSessionHistogram.WithLabelValues(logPrefix, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) + if err == nil { + break + } + if failedCnt%logIntervalCnt == 0 { + logutil.BgLogger().Warn("failed to new session to etcd", zap.String("ownerInfo", logPrefix), zap.Error(err)) + } + + time.Sleep(newSessionRetryInterval) + failedCnt++ + } + return etcdSession, errors.Trace(err) +} + +func contextDone(ctx context.Context, err error) error { + select { + case <-ctx.Done(): + return errors.Trace(ctx.Err()) + default: + } + // Sometime the ctx isn't closed, but the etcd client is closed, + // we need to treat it as if context is done. + // TODO: Make sure ctx is closed with etcd client. + if terror.ErrorEqual(err, context.Canceled) || + terror.ErrorEqual(err, context.DeadlineExceeded) || + terror.ErrorEqual(err, grpc.ErrClientConnClosing) { + return errors.Trace(err) + } + + return nil +} diff --git a/pkg/util/gctuner/binding__failpoint_binding__.go b/pkg/util/gctuner/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..58a2d96183692 --- /dev/null +++ b/pkg/util/gctuner/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package gctuner + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/util/gctuner/memory_limit_tuner.go b/pkg/util/gctuner/memory_limit_tuner.go index dcf2ccb96ee22..027fa9f5b6ac8 100644 --- a/pkg/util/gctuner/memory_limit_tuner.go +++ b/pkg/util/gctuner/memory_limit_tuner.go @@ -107,17 +107,17 @@ func (t *memoryLimitTuner) tuning() { if intest.InTest { resetInterval = 3 * time.Second } - failpoint.Inject("mockUpdateGlobalVarDuringAdjustPercentage", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockUpdateGlobalVarDuringAdjustPercentage")); _err_ == nil { if val, ok := val.(bool); val && ok { time.Sleep(300 * time.Millisecond) t.UpdateMemoryLimit() } - }) - failpoint.Inject("testMemoryLimitTuner", func(val failpoint.Value) { + } + if val, _err_ := failpoint.Eval(_curpkg_("testMemoryLimitTuner")); _err_ == nil { if val, ok := val.(bool); val && ok { resetInterval = 1 * time.Second } - }) + } time.Sleep(resetInterval) debug.SetMemoryLimit(t.calcMemoryLimit(t.GetPercentage())) for !t.adjustPercentageInProgress.CompareAndSwap(true, false) { diff --git a/pkg/util/gctuner/memory_limit_tuner.go__failpoint_stash__ b/pkg/util/gctuner/memory_limit_tuner.go__failpoint_stash__ new file 
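NewSession above retries etcd session creation on a fixed interval and logs only every logIntervalCnt consecutive failures to keep the log quiet. A stripped-down sketch of that retry skeleton; doWork and logf are placeholders, not TiDB APIs:

package demo

import (
    "context"
    "time"
)

const (
    retryInterval  = 200 * time.Millisecond
    logIntervalCnt = int(3 * time.Second / retryInterval)
)

// retryWithThrottledLogging mirrors the loop in NewSession above: try up to
// retryCnt times, sleep between attempts, and log only every logIntervalCnt
// consecutive failures.
func retryWithThrottledLogging(ctx context.Context, retryCnt int, doWork func() error, logf func(string, ...any)) error {
    var err error
    failed := 0
    for i := 0; i < retryCnt; i++ {
        if ctxErr := ctx.Err(); ctxErr != nil {
            return ctxErr
        }
        if err = doWork(); err == nil {
            return nil
        }
        if failed%logIntervalCnt == 0 {
            logf("attempt failed, will retry: %v", err)
        }
        failed++
        time.Sleep(retryInterval)
    }
    return err
}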
mode 100644 index 0000000000000..dcf2ccb96ee22 --- /dev/null +++ b/pkg/util/gctuner/memory_limit_tuner.go__failpoint_stash__ @@ -0,0 +1,190 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gctuner + +import ( + "math" + "runtime/debug" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/intest" + "github.com/pingcap/tidb/pkg/util/memory" + atomicutil "go.uber.org/atomic" +) + +// GlobalMemoryLimitTuner only allow one memory limit tuner in one process +var GlobalMemoryLimitTuner = &memoryLimitTuner{} + +// Go runtime trigger GC when hit memory limit which managed via runtime/debug.SetMemoryLimit. +// So we can change memory limit dynamically to avoid frequent GC when memory usage is greater than the limit. +type memoryLimitTuner struct { + finalizer *finalizer + isValidValueSet atomicutil.Bool + percentage atomicutil.Float64 + adjustPercentageInProgress atomicutil.Bool + serverMemLimitBeforeAdjust atomicutil.Uint64 + percentageBeforeAdjust atomicutil.Float64 + nextGCTriggeredByMemoryLimit atomicutil.Bool + + // The flag to disable memory limit adjust. There might be many tasks need to activate it in future, + // so it is integer type. + adjustDisabled atomicutil.Int64 +} + +// fallbackPercentage indicates the fallback memory limit percentage when turning. +const fallbackPercentage float64 = 1.1 + +var memoryGoroutineCntInTest = *atomicutil.NewInt64(0) + +// WaitMemoryLimitTunerExitInTest is used to wait memory limit tuner exit in test. +func WaitMemoryLimitTunerExitInTest() { + if intest.InTest { + for memoryGoroutineCntInTest.Load() > 0 { + time.Sleep(100 * time.Millisecond) + } + } +} + +// DisableAdjustMemoryLimit makes memoryLimitTuner directly return `initGOMemoryLimitValue` when function `calcMemoryLimit` is called. +func (t *memoryLimitTuner) DisableAdjustMemoryLimit() { + t.adjustDisabled.Add(1) + debug.SetMemoryLimit(initGOMemoryLimitValue) +} + +// EnableAdjustMemoryLimit makes memoryLimitTuner return an adjusted memory limit when function `calcMemoryLimit` is called. +func (t *memoryLimitTuner) EnableAdjustMemoryLimit() { + t.adjustDisabled.Add(-1) + t.UpdateMemoryLimit() +} + +// tuning check the memory nextGC and judge whether this GC is trigger by memory limit. +// Go runtime ensure that it will be called serially. +func (t *memoryLimitTuner) tuning() { + if !t.isValidValueSet.Load() { + return + } + r := memory.ForceReadMemStats() + gogc := util.GetGOGC() + ratio := float64(100+gogc) / 100 + // This `if` checks whether the **last** GC was triggered by MemoryLimit as far as possible. + // If the **last** GC was triggered by MemoryLimit, we'll set MemoryLimit to MAXVALUE to return control back to GOGC + // to avoid frequent GC when memory usage fluctuates above and below MemoryLimit. 
+ // The logic we judge whether the **last** GC was triggered by MemoryLimit is as follows: + // suppose `NextGC` = `HeapInUse * (100 + GOGC) / 100)`, + // - If NextGC < MemoryLimit, the **next** GC will **not** be triggered by MemoryLimit thus we do not care about + // why the **last** GC is triggered. And MemoryLimit will not be reset this time. + // - Only if NextGC >= MemoryLimit , the **next** GC will be triggered by MemoryLimit. Thus, we need to reset + // MemoryLimit after the **next** GC happens if needed. + if float64(r.HeapInuse)*ratio > float64(debug.SetMemoryLimit(-1)) { + if t.nextGCTriggeredByMemoryLimit.Load() && t.adjustPercentageInProgress.CompareAndSwap(false, true) { + // It's ok to update `adjustPercentageInProgress`, `serverMemLimitBeforeAdjust` and `percentageBeforeAdjust` not in a transaction. + // The update of memory limit is eventually consistent. + t.serverMemLimitBeforeAdjust.Store(memory.ServerMemoryLimit.Load()) + t.percentageBeforeAdjust.Store(t.GetPercentage()) + go func() { + if intest.InTest { + memoryGoroutineCntInTest.Inc() + defer memoryGoroutineCntInTest.Dec() + } + memory.MemoryLimitGCLast.Store(time.Now()) + memory.MemoryLimitGCTotal.Add(1) + debug.SetMemoryLimit(t.calcMemoryLimit(fallbackPercentage)) + resetInterval := 1 * time.Minute // Wait 1 minute and set back, to avoid frequent GC + if intest.InTest { + resetInterval = 3 * time.Second + } + failpoint.Inject("mockUpdateGlobalVarDuringAdjustPercentage", func(val failpoint.Value) { + if val, ok := val.(bool); val && ok { + time.Sleep(300 * time.Millisecond) + t.UpdateMemoryLimit() + } + }) + failpoint.Inject("testMemoryLimitTuner", func(val failpoint.Value) { + if val, ok := val.(bool); val && ok { + resetInterval = 1 * time.Second + } + }) + time.Sleep(resetInterval) + debug.SetMemoryLimit(t.calcMemoryLimit(t.GetPercentage())) + for !t.adjustPercentageInProgress.CompareAndSwap(true, false) { + continue + } + }() + memory.TriggerMemoryLimitGC.Store(true) + } + t.nextGCTriggeredByMemoryLimit.Store(true) + } else { + t.nextGCTriggeredByMemoryLimit.Store(false) + memory.TriggerMemoryLimitGC.Store(false) + } +} + +// Start starts the memory limit tuner. +func (t *memoryLimitTuner) Start() { + t.finalizer = newFinalizer(t.tuning) // Start tuning +} + +// Stop stops the memory limit tuner. +func (t *memoryLimitTuner) Stop() { + t.finalizer.stop() +} + +// SetPercentage set the percentage for memory limit tuner. +func (t *memoryLimitTuner) SetPercentage(percentage float64) { + t.percentage.Store(percentage) +} + +// GetPercentage get the percentage from memory limit tuner. +func (t *memoryLimitTuner) GetPercentage() float64 { + return t.percentage.Load() +} + +// UpdateMemoryLimit updates the memory limit. +// This function should be called when `tidb_server_memory_limit` or `tidb_server_memory_limit_gc_trigger` is modified. 
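The tuning method above decides whether the next GC will be driven by the Go memory limit by estimating NextGC as HeapInuse*(100+GOGC)/100 and comparing it with the current limit. A minimal sketch of just that predicate; passing GOGC in as a parameter is an assumption made here (the real code reads it via util.GetGOGC):

package demo

import "runtime/debug"

// nextGCByMemoryLimit reports whether, given the current heap-in-use bytes and
// GOGC percentage, the next GC would be triggered by the memory limit rather
// than by the GOGC heap-growth target: NextGC ~= HeapInuse * (100+GOGC)/100.
func nextGCByMemoryLimit(heapInuse uint64, gogc int) bool {
    ratio := float64(100+gogc) / 100
    // debug.SetMemoryLimit(-1) only reads the current limit without changing it.
    return float64(heapInuse)*ratio > float64(debug.SetMemoryLimit(-1))
}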
+func (t *memoryLimitTuner) UpdateMemoryLimit() { + if t.adjustPercentageInProgress.Load() { + if t.serverMemLimitBeforeAdjust.Load() == memory.ServerMemoryLimit.Load() && t.percentageBeforeAdjust.Load() == t.GetPercentage() { + return + } + } + var memoryLimit = t.calcMemoryLimit(t.GetPercentage()) + if memoryLimit == math.MaxInt64 { + t.isValidValueSet.Store(false) + memoryLimit = initGOMemoryLimitValue + } else { + t.isValidValueSet.Store(true) + } + debug.SetMemoryLimit(memoryLimit) +} + +func (t *memoryLimitTuner) calcMemoryLimit(percentage float64) int64 { + if t.adjustDisabled.Load() > 0 { + return initGOMemoryLimitValue + } + memoryLimit := int64(float64(memory.ServerMemoryLimit.Load()) * percentage) // `tidb_server_memory_limit` * `tidb_server_memory_limit_gc_trigger` + if memoryLimit == 0 { + memoryLimit = math.MaxInt64 + } + return memoryLimit +} + +var initGOMemoryLimitValue int64 + +func init() { + initGOMemoryLimitValue = debug.SetMemoryLimit(-1) + GlobalMemoryLimitTuner.Start() +} diff --git a/pkg/util/memory/binding__failpoint_binding__.go b/pkg/util/memory/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..2f327d1d1bb00 --- /dev/null +++ b/pkg/util/memory/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package memory + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/util/memory/meminfo.go b/pkg/util/memory/meminfo.go index 0d7bd68a2bcbe..f89304a4629cd 100644 --- a/pkg/util/memory/meminfo.go +++ b/pkg/util/memory/meminfo.go @@ -36,11 +36,11 @@ var MemUsed func() (uint64, error) // GetMemTotalIgnoreErr returns the total amount of RAM on this system/container. If error occurs, return 0. func GetMemTotalIgnoreErr() uint64 { if memTotal, err := MemTotal(); err == nil { - failpoint.Inject("GetMemTotalError", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("GetMemTotalError")); _err_ == nil { if val, ok := val.(bool); val && ok { memTotal = 0 } - }) + } return memTotal } return 0 diff --git a/pkg/util/memory/meminfo.go__failpoint_stash__ b/pkg/util/memory/meminfo.go__failpoint_stash__ new file mode 100644 index 0000000000000..0d7bd68a2bcbe --- /dev/null +++ b/pkg/util/memory/meminfo.go__failpoint_stash__ @@ -0,0 +1,215 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package memory + +import ( + "sync" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/sysutil" + "github.com/pingcap/tidb/pkg/parser/terror" + "github.com/pingcap/tidb/pkg/util/cgroup" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/shirou/gopsutil/v3/mem" + "go.uber.org/zap" +) + +// MemTotal returns the total amount of RAM on this system +var MemTotal func() (uint64, error) + +// MemUsed returns the total used amount of RAM on this system +var MemUsed func() (uint64, error) + +// GetMemTotalIgnoreErr returns the total amount of RAM on this system/container. If error occurs, return 0. +func GetMemTotalIgnoreErr() uint64 { + if memTotal, err := MemTotal(); err == nil { + failpoint.Inject("GetMemTotalError", func(val failpoint.Value) { + if val, ok := val.(bool); val && ok { + memTotal = 0 + } + }) + return memTotal + } + return 0 +} + +// MemTotalNormal returns the total amount of RAM on this system in non-container environment. +func MemTotalNormal() (uint64, error) { + total, t := memLimit.get() + if time.Since(t) < 60*time.Second { + return total, nil + } + return memTotalNormal() +} + +func memTotalNormal() (uint64, error) { + v, err := mem.VirtualMemory() + if err != nil { + return 0, err + } + memLimit.set(v.Total, time.Now()) + return v.Total, nil +} + +// MemUsedNormal returns the total used amount of RAM on this system in non-container environment. +func MemUsedNormal() (uint64, error) { + used, t := memUsage.get() + if time.Since(t) < 500*time.Millisecond { + return used, nil + } + v, err := mem.VirtualMemory() + if err != nil { + return 0, err + } + memUsage.set(v.Used, time.Now()) + return v.Used, nil +} + +type memInfoCache struct { + updateTime time.Time + mu *sync.RWMutex + mem uint64 +} + +func (c *memInfoCache) get() (memo uint64, t time.Time) { + c.mu.RLock() + defer c.mu.RUnlock() + memo, t = c.mem, c.updateTime + return +} + +func (c *memInfoCache) set(memo uint64, t time.Time) { + c.mu.Lock() + defer c.mu.Unlock() + c.mem, c.updateTime = memo, t +} + +// expiration time is 60s +var memLimit *memInfoCache + +// expiration time is 500ms +var memUsage *memInfoCache + +// expiration time is 500ms +// save the memory usage of the server process +var serverMemUsage *memInfoCache + +// MemTotalCGroup returns the total amount of RAM on this system in container environment. +func MemTotalCGroup() (uint64, error) { + memo, t := memLimit.get() + if time.Since(t) < 60*time.Second { + return memo, nil + } + memo, err := cgroup.GetMemoryLimit() + if err != nil { + return memo, err + } + v, err := mem.VirtualMemory() + if err != nil { + return 0, err + } + memo = min(v.Total, memo) + memLimit.set(memo, time.Now()) + return memo, nil +} + +// MemUsedCGroup returns the total used amount of RAM on this system in container environment. +func MemUsedCGroup() (uint64, error) { + memo, t := memUsage.get() + if time.Since(t) < 500*time.Millisecond { + return memo, nil + } + memo, err := cgroup.GetMemoryUsage() + if err != nil { + return memo, err + } + v, err := mem.VirtualMemory() + if err != nil { + return 0, err + } + memo = min(v.Used, memo) + memUsage.set(memo, time.Now()) + return memo, nil +} + +// it is for test and init. 
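memInfoCache above is a small read-through cache with a freshness window (60s for totals, 500ms for usage and server memory). A standalone sketch of the same idea; the load callback is a placeholder for mem.VirtualMemory or the cgroup readers:

package demo

import (
    "sync"
    "time"
)

// cachedUint64 caches a uint64 reading for ttl and refreshes it lazily on
// access, mirroring how MemTotalNormal/MemUsedNormal reuse memInfoCache above.
type cachedUint64 struct {
    mu   sync.RWMutex
    val  uint64
    at   time.Time
    ttl  time.Duration
    load func() (uint64, error)
}

func (c *cachedUint64) get() (uint64, error) {
    c.mu.RLock()
    val, at := c.val, c.at
    c.mu.RUnlock()
    if time.Since(at) < c.ttl {
        return val, nil
    }
    fresh, err := c.load()
    if err != nil {
        return 0, err
    }
    c.mu.Lock()
    c.val, c.at = fresh, time.Now()
    c.mu.Unlock()
    return fresh, nil
}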
+func init() { + if cgroup.InContainer() { + MemTotal = MemTotalCGroup + MemUsed = MemUsedCGroup + sysutil.RegisterGetMemoryCapacity(MemTotalCGroup) + } else { + MemTotal = MemTotalNormal + MemUsed = MemUsedNormal + } + memLimit = &memInfoCache{ + mu: &sync.RWMutex{}, + } + memUsage = &memInfoCache{ + mu: &sync.RWMutex{}, + } + serverMemUsage = &memInfoCache{ + mu: &sync.RWMutex{}, + } + _, err := MemTotal() + terror.MustNil(err) + _, err = MemUsed() + terror.MustNil(err) +} + +// InitMemoryHook initializes the memory hook. +// It is to solve the problem that tidb cannot read cgroup in the systemd. +// so if we are not in the container, we compare the cgroup memory limit and the physical memory, +// the cgroup memory limit is smaller, we use the cgroup memory hook. +func InitMemoryHook() { + if cgroup.InContainer() { + logutil.BgLogger().Info("use cgroup memory hook because TiDB is in the container") + return + } + cgroupValue, err := cgroup.GetMemoryLimit() + if err != nil { + return + } + physicalValue, err := memTotalNormal() + if err != nil { + return + } + if physicalValue > cgroupValue && cgroupValue != 0 { + MemTotal = MemTotalCGroup + MemUsed = MemUsedCGroup + sysutil.RegisterGetMemoryCapacity(MemTotalCGroup) + logutil.BgLogger().Info("use cgroup memory hook", zap.Int64("cgroupMemorySize", int64(cgroupValue)), zap.Int64("physicalMemorySize", int64(physicalValue))) + } else { + logutil.BgLogger().Info("use physical memory hook", zap.Int64("cgroupMemorySize", int64(cgroupValue)), zap.Int64("physicalMemorySize", int64(physicalValue))) + } + _, err = MemTotal() + terror.MustNil(err) + _, err = MemUsed() + terror.MustNil(err) +} + +// InstanceMemUsed returns the memory usage of this TiDB server +func InstanceMemUsed() (uint64, error) { + used, t := serverMemUsage.get() + if time.Since(t) < 500*time.Millisecond { + return used, nil + } + var memoryUsage uint64 + instanceStats := ReadMemStats() + memoryUsage = instanceStats.HeapAlloc + serverMemUsage.set(memoryUsage, time.Now()) + return memoryUsage, nil +} diff --git a/pkg/util/memory/memstats.go b/pkg/util/memory/memstats.go index 4ea192620bee2..b0d53f08d59df 100644 --- a/pkg/util/memory/memstats.go +++ b/pkg/util/memory/memstats.go @@ -35,10 +35,10 @@ func ReadMemStats() (memStats *runtime.MemStats) { } else { memStats = ForceReadMemStats() } - failpoint.Inject("ReadMemStats", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("ReadMemStats")); _err_ == nil { injectedSize := val.(int) memStats = &runtime.MemStats{HeapInuse: memStats.HeapInuse + uint64(injectedSize)} - }) + } return } diff --git a/pkg/util/memory/memstats.go__failpoint_stash__ b/pkg/util/memory/memstats.go__failpoint_stash__ new file mode 100644 index 0000000000000..4ea192620bee2 --- /dev/null +++ b/pkg/util/memory/memstats.go__failpoint_stash__ @@ -0,0 +1,57 @@ +// Copyright 2018 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package memory + +import ( + "runtime" + "sync/atomic" + "time" + + "github.com/pingcap/failpoint" +) + +var stats atomic.Pointer[globalMstats] + +// ReadMemInterval controls the interval to read memory stats. +const ReadMemInterval = 300 * time.Millisecond + +// ReadMemStats read the mem stats from runtime.ReadMemStats +func ReadMemStats() (memStats *runtime.MemStats) { + s := stats.Load() + if s != nil { + memStats = &s.m + } else { + memStats = ForceReadMemStats() + } + failpoint.Inject("ReadMemStats", func(val failpoint.Value) { + injectedSize := val.(int) + memStats = &runtime.MemStats{HeapInuse: memStats.HeapInuse + uint64(injectedSize)} + }) + return +} + +// ForceReadMemStats is to force read memory stats. +func ForceReadMemStats() *runtime.MemStats { + var g globalMstats + g.ts = time.Now() + runtime.ReadMemStats(&g.m) + stats.Store(&g) + return &g.m +} + +type globalMstats struct { + ts time.Time + m runtime.MemStats +} diff --git a/pkg/util/replayer/binding__failpoint_binding__.go b/pkg/util/replayer/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..69f926fdeb093 --- /dev/null +++ b/pkg/util/replayer/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package replayer + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/util/replayer/replayer.go b/pkg/util/replayer/replayer.go index 9711479fb6c41..d4b1c12643a0e 100644 --- a/pkg/util/replayer/replayer.go +++ b/pkg/util/replayer/replayer.go @@ -59,9 +59,9 @@ func GeneratePlanReplayerFileName(isCapture, isContinuesCapture, enableHistorica func generatePlanReplayerFileName(isCapture, isContinuesCapture, enableHistoricalStatsForCapture bool) (string, error) { // Generate key and create zip file time := time.Now().UnixNano() - failpoint.Inject("InjectPlanReplayerFileNameTimeField", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("InjectPlanReplayerFileNameTimeField")); _err_ == nil { time = int64(val.(int)) - }) + } b := make([]byte, 16) //nolint: gosec _, err := rand.Read(b) diff --git a/pkg/util/replayer/replayer.go__failpoint_stash__ b/pkg/util/replayer/replayer.go__failpoint_stash__ new file mode 100644 index 0000000000000..9711479fb6c41 --- /dev/null +++ b/pkg/util/replayer/replayer.go__failpoint_stash__ @@ -0,0 +1,87 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package replayer + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/config" +) + +// PlanReplayerTaskKey indicates key of a plan replayer task +type PlanReplayerTaskKey struct { + SQLDigest string + PlanDigest string +} + +// GeneratePlanReplayerFile generates plan replayer file +func GeneratePlanReplayerFile(isCapture, isContinuesCapture, enableHistoricalStatsForCapture bool) (*os.File, string, error) { + path := GetPlanReplayerDirName() + err := os.MkdirAll(path, os.ModePerm) + if err != nil { + return nil, "", errors.AddStack(err) + } + fileName, err := generatePlanReplayerFileName(isCapture, isContinuesCapture, enableHistoricalStatsForCapture) + if err != nil { + return nil, "", errors.AddStack(err) + } + zf, err := os.Create(filepath.Join(path, fileName)) + if err != nil { + return nil, "", errors.AddStack(err) + } + return zf, fileName, err +} + +// GeneratePlanReplayerFileName generates plan replayer capture task name +func GeneratePlanReplayerFileName(isCapture, isContinuesCapture, enableHistoricalStatsForCapture bool) (string, error) { + return generatePlanReplayerFileName(isCapture, isContinuesCapture, enableHistoricalStatsForCapture) +} + +func generatePlanReplayerFileName(isCapture, isContinuesCapture, enableHistoricalStatsForCapture bool) (string, error) { + // Generate key and create zip file + time := time.Now().UnixNano() + failpoint.Inject("InjectPlanReplayerFileNameTimeField", func(val failpoint.Value) { + time = int64(val.(int)) + }) + b := make([]byte, 16) + //nolint: gosec + _, err := rand.Read(b) + if err != nil { + return "", err + } + key := base64.URLEncoding.EncodeToString(b) + // "capture_replayer" in filename has special meaning for the /plan_replayer/dump/ HTTP handler + if isContinuesCapture || isCapture && enableHistoricalStatsForCapture { + return fmt.Sprintf("capture_replayer_%v_%v.zip", key, time), nil + } + if isCapture && !enableHistoricalStatsForCapture { + return fmt.Sprintf("capture_normal_replayer_%v_%v.zip", key, time), nil + } + return fmt.Sprintf("replayer_%v_%v.zip", key, time), nil +} + +// GetPlanReplayerDirName returns plan replayer directory path. +// The path is related to the process id. 
+func GetPlanReplayerDirName() string { + tidbLogDir := filepath.Dir(config.GetGlobalConfig().Log.File.Filename) + return filepath.Join(tidbLogDir, "replayer") +} diff --git a/pkg/util/servermemorylimit/binding__failpoint_binding__.go b/pkg/util/servermemorylimit/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..a3d497c09b313 --- /dev/null +++ b/pkg/util/servermemorylimit/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package servermemorylimit + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/util/servermemorylimit/servermemorylimit.go b/pkg/util/servermemorylimit/servermemorylimit.go index 47022d23c6635..b7bb2c01b61fa 100644 --- a/pkg/util/servermemorylimit/servermemorylimit.go +++ b/pkg/util/servermemorylimit/servermemorylimit.go @@ -136,11 +136,11 @@ func killSessIfNeeded(s *sessionToBeKilled, bt uint64, sm util.SessionManager) { if bt == 0 { return } - failpoint.Inject("issue42662_2", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("issue42662_2")); _err_ == nil { if val.(bool) { bt = 1 } - }) + } instanceStats := memory.ReadMemStats() if instanceStats.HeapInuse > MemoryMaxUsed.Load() { MemoryMaxUsed.Store(instanceStats.HeapInuse) diff --git a/pkg/util/servermemorylimit/servermemorylimit.go__failpoint_stash__ b/pkg/util/servermemorylimit/servermemorylimit.go__failpoint_stash__ new file mode 100644 index 0000000000000..47022d23c6635 --- /dev/null +++ b/pkg/util/servermemorylimit/servermemorylimit.go__failpoint_stash__ @@ -0,0 +1,264 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package servermemorylimit + +import ( + "fmt" + "runtime" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/parser/mysql" + "github.com/pingcap/tidb/pkg/types" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/memory" + "github.com/pingcap/tidb/pkg/util/sqlkiller" + atomicutil "go.uber.org/atomic" + "go.uber.org/zap" +) + +// Process global Observation indicators for memory limit. +var ( + MemoryMaxUsed = atomicutil.NewUint64(0) + SessionKillLast = atomicutil.NewTime(time.Time{}) + SessionKillTotal = atomicutil.NewInt64(0) + IsKilling = atomicutil.NewBool(false) + GlobalMemoryOpsHistoryManager = &memoryOpsHistoryManager{} +) + +// Handle is the handler for server memory limit. +type Handle struct { + exitCh chan struct{} + sm atomic.Value +} + +// NewServerMemoryLimitHandle builds a new server memory limit handler. 
+func NewServerMemoryLimitHandle(exitCh chan struct{}) *Handle { + return &Handle{exitCh: exitCh} +} + +// SetSessionManager sets the SessionManager which is used to fetching the info +// of all active sessions. +func (smqh *Handle) SetSessionManager(sm util.SessionManager) *Handle { + smqh.sm.Store(sm) + return smqh +} + +// Run starts a server memory limit checker goroutine at the start time of the server. +// This goroutine will obtain the `heapInuse` of Golang runtime periodically and compare it with `tidb_server_memory_limit`. +// When `heapInuse` is greater than `tidb_server_memory_limit`, it will set the `needKill` flag of `MemUsageTop1Tracker`. +// When the corresponding SQL try to acquire more memory(next Tracker.Consume() call), it will trigger panic and exit. +// When this goroutine detects the `needKill` SQL has exited successfully, it will immediately trigger runtime.GC() to release memory resources. +func (smqh *Handle) Run() { + tickInterval := time.Millisecond * time.Duration(100) + ticker := time.NewTicker(tickInterval) + defer ticker.Stop() + sm := smqh.sm.Load().(util.SessionManager) + sessionToBeKilled := &sessionToBeKilled{} + for { + select { + case <-ticker.C: + killSessIfNeeded(sessionToBeKilled, memory.ServerMemoryLimit.Load(), sm) + case <-smqh.exitCh: + return + } + } +} + +type sessionToBeKilled struct { + isKilling bool + sqlStartTime time.Time + sessionID uint64 + sessionTracker *memory.Tracker + + killStartTime time.Time + lastLogTime time.Time +} + +func (s *sessionToBeKilled) reset() { + s.isKilling = false + s.sqlStartTime = time.Time{} + s.sessionID = 0 + s.sessionTracker = nil + s.killStartTime = time.Time{} + s.lastLogTime = time.Time{} +} + +func killSessIfNeeded(s *sessionToBeKilled, bt uint64, sm util.SessionManager) { + if s.isKilling { + if info, ok := sm.GetProcessInfo(s.sessionID); ok { + if info.Time == s.sqlStartTime { + if time.Since(s.lastLogTime) > 5*time.Second { + logutil.BgLogger().Warn(fmt.Sprintf("global memory controller failed to kill the top-consumer in %ds", + time.Since(s.killStartTime)/time.Second), + zap.Uint64("conn", info.ID), + zap.String("sql digest", info.Digest), + zap.String("sql text", fmt.Sprintf("%.100v", info.Info)), + zap.Int64("sql memory usage", info.MemTracker.BytesConsumed())) + s.lastLogTime = time.Now() + + if seconds := time.Since(s.killStartTime) / time.Second; seconds >= 60 { + // If the SQL cannot be terminated after 60 seconds, it may be stuck in the network stack while writing packets to the client, + // encountering some bugs that cause it to hang, or failing to detect the kill signal. + // In this case, the resources can be reclaimed by calling the `Finish` method, and then we can start looking for the next SQL with the largest memory usage. + logutil.BgLogger().Warn(fmt.Sprintf("global memory controller failed to kill the top-consumer in %d seconds. 
Attempting to force close the executors.", seconds)) + s.sessionTracker.Killer.FinishResultSet() + goto Succ + } + } + return + } + } + Succ: + s.reset() + IsKilling.Store(false) + memory.MemUsageTop1Tracker.CompareAndSwap(s.sessionTracker, nil) + //nolint: all_revive,revive + runtime.GC() + logutil.BgLogger().Warn("global memory controller killed the top1 memory consumer successfully") + } + + if bt == 0 { + return + } + failpoint.Inject("issue42662_2", func(val failpoint.Value) { + if val.(bool) { + bt = 1 + } + }) + instanceStats := memory.ReadMemStats() + if instanceStats.HeapInuse > MemoryMaxUsed.Load() { + MemoryMaxUsed.Store(instanceStats.HeapInuse) + } + limitSessMinSize := memory.ServerMemoryLimitSessMinSize.Load() + if instanceStats.HeapInuse > bt { + t := memory.MemUsageTop1Tracker.Load() + if t != nil { + sessionID := t.SessionID.Load() + memUsage := t.BytesConsumed() + // If the memory usage of the top1 session is less than tidb_server_memory_limit_sess_min_size, we do not need to kill it. + if uint64(memUsage) < limitSessMinSize { + memory.MemUsageTop1Tracker.CompareAndSwap(t, nil) + t = nil + } else if info, ok := sm.GetProcessInfo(sessionID); ok { + logutil.BgLogger().Warn("global memory controller tries to kill the top1 memory consumer", + zap.Uint64("conn", info.ID), + zap.String("sql digest", info.Digest), + zap.String("sql text", fmt.Sprintf("%.100v", info.Info)), + zap.Uint64("tidb_server_memory_limit", bt), + zap.Uint64("heap inuse", instanceStats.HeapInuse), + zap.Int64("sql memory usage", info.MemTracker.BytesConsumed()), + ) + s.sessionID = sessionID + s.sqlStartTime = info.Time + s.isKilling = true + s.sessionTracker = t + t.Killer.SendKillSignal(sqlkiller.ServerMemoryExceeded) + + killTime := time.Now() + SessionKillTotal.Add(1) + SessionKillLast.Store(killTime) + IsKilling.Store(true) + GlobalMemoryOpsHistoryManager.recordOne(info, killTime, bt, instanceStats.HeapInuse) + s.lastLogTime = time.Now() + s.killStartTime = time.Now() + } + } + // If no one larger than tidb_server_memory_limit_sess_min_size is found, we will not kill any one. 
+ if t == nil { + if s.lastLogTime.IsZero() { + s.lastLogTime = time.Now() + } + if time.Since(s.lastLogTime) < 5*time.Second { + return + } + logutil.BgLogger().Warn("global memory controller tries to kill the top1 memory consumer, but no one larger than tidb_server_memory_limit_sess_min_size is found", zap.Uint64("tidb_server_memory_limit_sess_min_size", limitSessMinSize)) + s.lastLogTime = time.Now() + } + } +} + +type memoryOpsHistoryManager struct { + mu sync.Mutex + infos []memoryOpsHistory + offsets int +} + +type memoryOpsHistory struct { + killTime time.Time + memoryLimit uint64 + memoryCurrent uint64 + processInfoDatum []types.Datum // id,user,host,db,command,time,state,info,digest,mem,disk,txnStart +} + +func (m *memoryOpsHistoryManager) init() { + m.infos = make([]memoryOpsHistory, 50) + m.offsets = 0 +} + +func (m *memoryOpsHistoryManager) recordOne(info *util.ProcessInfo, killTime time.Time, memoryLimit uint64, memoryCurrent uint64) { + m.mu.Lock() + defer m.mu.Unlock() + op := memoryOpsHistory{killTime: killTime, memoryLimit: memoryLimit, memoryCurrent: memoryCurrent, processInfoDatum: types.MakeDatums(info.ToRow(time.UTC)...)} + sqlInfo := op.processInfoDatum[7] + sqlInfo.SetString(fmt.Sprintf("%.256v", sqlInfo.GetString()), mysql.DefaultCollationName) // Truncated + // Only record the last 50 history ops + m.infos[m.offsets] = op + m.offsets++ + if m.offsets >= 50 { + m.offsets = 0 + } +} + +func (m *memoryOpsHistoryManager) GetRows() [][]types.Datum { + m.mu.Lock() + defer m.mu.Unlock() + rows := make([][]types.Datum, 0, len(m.infos)) + getRowFromInfo := func(info memoryOpsHistory) { + killTime := types.NewTime(types.FromGoTime(info.killTime), mysql.TypeDatetime, 0) + op := "SessionKill" + rows = append(rows, []types.Datum{ + types.NewDatum(killTime), // TIME + types.NewDatum(op), // OPS + types.NewDatum(info.memoryLimit), // MEMORY_LIMIT + types.NewDatum(info.memoryCurrent), // MEMORY_CURRENT + info.processInfoDatum[0], // PROCESSID + info.processInfoDatum[9], // MEM + info.processInfoDatum[10], // DISK + info.processInfoDatum[2], // CLIENT + info.processInfoDatum[3], // DB + info.processInfoDatum[1], // USER + info.processInfoDatum[8], // SQL_DIGEST + info.processInfoDatum[7], // SQL_TEXT + }) + } + var zeroTime = time.Time{} + for i := 0; i < len(m.infos); i++ { + pos := (m.offsets + i) % len(m.infos) + info := m.infos[pos] + if info.killTime.Equal(zeroTime) { + continue + } + getRowFromInfo(info) + } + return rows +} + +func init() { + GlobalMemoryOpsHistoryManager.init() +} diff --git a/pkg/util/session_pool.go b/pkg/util/session_pool.go index 95f5dd9515b43..f233b04d4888d 100644 --- a/pkg/util/session_pool.go +++ b/pkg/util/session_pool.go @@ -66,9 +66,9 @@ func (p *pool) Get() (resource pools.Resource, err error) { } // Put the internal session to the map of SessionManager - failpoint.Inject("mockSessionPoolReturnError", func() { + if _, _err_ := failpoint.Eval(_curpkg_("mockSessionPoolReturnError")); _err_ == nil { err = errors.New("mockSessionPoolReturnError") - }) + } if err == nil && p.getCallback != nil { p.getCallback(resource) diff --git a/pkg/util/session_pool.go__failpoint_stash__ b/pkg/util/session_pool.go__failpoint_stash__ new file mode 100644 index 0000000000000..95f5dd9515b43 --- /dev/null +++ b/pkg/util/session_pool.go__failpoint_stash__ @@ -0,0 +1,113 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "errors" + "sync" + + "github.com/ngaut/pools" + "github.com/pingcap/failpoint" +) + +// SessionPool is a recyclable resource pool for the session. +type SessionPool interface { + Get() (pools.Resource, error) + Put(pools.Resource) + Close() +} + +// resourceCallback is a helper function to be triggered after Get/Put call. +type resourceCallback func(pools.Resource) + +type pool struct { + resources chan pools.Resource + factory pools.Factory + mu struct { + sync.RWMutex + closed bool + } + getCallback resourceCallback + putCallback resourceCallback +} + +// NewSessionPool creates a new session pool with the given capacity and factory function. +func NewSessionPool(capacity int, factory pools.Factory, getCallback, putCallback resourceCallback) SessionPool { + return &pool{ + resources: make(chan pools.Resource, capacity), + factory: factory, + getCallback: getCallback, + putCallback: putCallback, + } +} + +// Get gets a session from the session pool. +func (p *pool) Get() (resource pools.Resource, err error) { + var ok bool + select { + case resource, ok = <-p.resources: + if !ok { + err = errors.New("session pool closed") + } + default: + resource, err = p.factory() + } + + // Put the internal session to the map of SessionManager + failpoint.Inject("mockSessionPoolReturnError", func() { + err = errors.New("mockSessionPoolReturnError") + }) + + if err == nil && p.getCallback != nil { + p.getCallback(resource) + } + + return +} + +// Put puts the session back to the pool. +func (p *pool) Put(resource pools.Resource) { + p.mu.RLock() + defer p.mu.RUnlock() + if p.putCallback != nil { + p.putCallback(resource) + } + if p.mu.closed { + resource.Close() + return + } + + select { + case p.resources <- resource: + default: + resource.Close() + } +} + +// Close closes the pool to release all resources. +func (p *pool) Close() { + p.mu.Lock() + if p.mu.closed { + p.mu.Unlock() + return + } + p.mu.closed = true + close(p.resources) + p.mu.Unlock() + + for r := range p.resources { + r.Close() + } +} diff --git a/pkg/util/sli/binding__failpoint_binding__.go b/pkg/util/sli/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..c9eff40605613 --- /dev/null +++ b/pkg/util/sli/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package sli + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/util/sli/sli.go b/pkg/util/sli/sli.go index 4acdee79e5e09..1917255c28585 100644 --- a/pkg/util/sli/sli.go +++ b/pkg/util/sli/sli.go @@ -50,9 +50,9 @@ func (t *TxnWriteThroughputSLI) FinishExecuteStmt(cost time.Duration, affectRow t.reportMetric() // Skip reset for test. 
- failpoint.Inject("CheckTxnWriteThroughput", func() { - failpoint.Return() - }) + if _, _err_ := failpoint.Eval(_curpkg_("CheckTxnWriteThroughput")); _err_ == nil { + return + } // Reset for next transaction. t.Reset() diff --git a/pkg/util/sli/sli.go__failpoint_stash__ b/pkg/util/sli/sli.go__failpoint_stash__ new file mode 100644 index 0000000000000..4acdee79e5e09 --- /dev/null +++ b/pkg/util/sli/sli.go__failpoint_stash__ @@ -0,0 +1,120 @@ +// Copyright 2021 PingCAP, Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sli + +import ( + "fmt" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/metrics" +) + +// TxnWriteThroughputSLI uses to report transaction write throughput metrics for SLI. +type TxnWriteThroughputSLI struct { + invalid bool + affectRow uint64 + writeSize int + readKeys int + writeKeys int + writeTime time.Duration +} + +// FinishExecuteStmt records the cost for write statement which affect rows more than 0. +// And report metrics when the transaction is committed. +func (t *TxnWriteThroughputSLI) FinishExecuteStmt(cost time.Duration, affectRow uint64, inTxn bool) { + if affectRow > 0 { + t.writeTime += cost + t.affectRow += affectRow + } + + // Currently not in transaction means the last transaction is finish, should report metrics and reset data. + if !inTxn { + if affectRow == 0 { + // AffectRows is 0 when statement is commit. + t.writeTime += cost + } + // Report metrics after commit this transaction. + t.reportMetric() + + // Skip reset for test. + failpoint.Inject("CheckTxnWriteThroughput", func() { + failpoint.Return() + }) + + // Reset for next transaction. + t.Reset() + } +} + +// AddReadKeys adds the read keys. +func (t *TxnWriteThroughputSLI) AddReadKeys(readKeys int64) { + t.readKeys += int(readKeys) +} + +// AddTxnWriteSize adds the transaction write size and keys. +func (t *TxnWriteThroughputSLI) AddTxnWriteSize(size, keys int) { + t.writeSize += size + t.writeKeys += keys +} + +func (t *TxnWriteThroughputSLI) reportMetric() { + if t.IsInvalid() { + return + } + if t.IsSmallTxn() { + metrics.SmallTxnWriteDuration.Observe(t.writeTime.Seconds()) + } else { + metrics.TxnWriteThroughput.Observe(float64(t.writeSize) / t.writeTime.Seconds()) + } +} + +// SetInvalid marks this transaction is invalid to report SLI metrics. +func (t *TxnWriteThroughputSLI) SetInvalid() { + t.invalid = true +} + +// IsInvalid checks the transaction is valid to report SLI metrics. Currently, the following case will cause invalid: +// 1. The transaction contains `insert|replace into ... select ... from ...` statement. +// 2. The write SQL statement has more read keys than write keys. +func (t *TxnWriteThroughputSLI) IsInvalid() bool { + return t.invalid || t.readKeys > t.writeKeys || t.writeSize == 0 || t.writeTime == 0 +} + +const ( + smallTxnAffectRow = 20 + smallTxnSize = 1 * 1024 * 1024 // 1MB +) + +// IsSmallTxn exports for testing. 
+func (t *TxnWriteThroughputSLI) IsSmallTxn() bool { + return t.affectRow <= smallTxnAffectRow && t.writeSize <= smallTxnSize +} + +// Reset exports for testing. +func (t *TxnWriteThroughputSLI) Reset() { + t.invalid = false + t.affectRow = 0 + t.writeSize = 0 + t.readKeys = 0 + t.writeKeys = 0 + t.writeTime = 0 +} + +// String exports for testing. +func (t *TxnWriteThroughputSLI) String() string { + return fmt.Sprintf("invalid: %v, affectRow: %v, writeSize: %v, readKeys: %v, writeKeys: %v, writeTime: %v", + t.invalid, t.affectRow, t.writeSize, t.readKeys, t.writeKeys, t.writeTime.String()) +} diff --git a/pkg/util/sqlkiller/binding__failpoint_binding__.go b/pkg/util/sqlkiller/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..02be0a09bb3d4 --- /dev/null +++ b/pkg/util/sqlkiller/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package sqlkiller + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/util/sqlkiller/sqlkiller.go b/pkg/util/sqlkiller/sqlkiller.go index 06782653a5d05..bb3553f77fa8d 100644 --- a/pkg/util/sqlkiller/sqlkiller.go +++ b/pkg/util/sqlkiller/sqlkiller.go @@ -108,7 +108,7 @@ func (killer *SQLKiller) ClearFinishFunc() { // HandleSignal handles the kill signal and return the error. func (killer *SQLKiller) HandleSignal() error { - failpoint.Inject("randomPanic", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("randomPanic")); _err_ == nil { if p, ok := val.(int); ok { if rand.Float64() > (float64)(p)/1000 { if killer.ConnID != 0 { @@ -117,7 +117,7 @@ func (killer *SQLKiller) HandleSignal() error { } } } - }) + } status := atomic.LoadUint32(&killer.Signal) err := killer.getKillError(status) if status == ServerMemoryExceeded { diff --git a/pkg/util/sqlkiller/sqlkiller.go__failpoint_stash__ b/pkg/util/sqlkiller/sqlkiller.go__failpoint_stash__ new file mode 100644 index 0000000000000..06782653a5d05 --- /dev/null +++ b/pkg/util/sqlkiller/sqlkiller.go__failpoint_stash__ @@ -0,0 +1,136 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlkiller + +import ( + "math/rand" + "sync" + "sync/atomic" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.uber.org/zap" +) + +type killSignal = uint32 + +// KillSignal types. +const ( + UnspecifiedKillSignal killSignal = iota + QueryInterrupted + MaxExecTimeExceeded + QueryMemoryExceeded + ServerMemoryExceeded + // When you add a new signal, you should also modify store/driver/error/ToTidbErr, + // so that errors in client can be correctly converted to tidb errors. +) + +// SQLKiller is used to kill a query. 
+type SQLKiller struct { + Signal killSignal + ConnID uint64 + // FinishFuncLock is used to ensure that Finish is not called and modified at the same time. + // An external call to the Finish function only allows when the main goroutine to be in the writeResultSet process. + // When the main goroutine exits the writeResultSet process, the Finish function will be cleared. + FinishFuncLock sync.Mutex + Finish func() + // InWriteResultSet is used to indicate whether the query is currently calling clientConn.writeResultSet(). + // If the query is in writeResultSet and Finish() can acquire rs.finishLock, we can assume the query is waiting for the client to receive data from the server over network I/O. + InWriteResultSet atomic.Bool +} + +// SendKillSignal sends a kill signal to the query. +func (killer *SQLKiller) SendKillSignal(reason killSignal) { + if atomic.CompareAndSwapUint32(&killer.Signal, 0, reason) { + status := atomic.LoadUint32(&killer.Signal) + err := killer.getKillError(status) + logutil.BgLogger().Warn("kill initiated", zap.Uint64("connection ID", killer.ConnID), zap.String("reason", err.Error())) + } +} + +// GetKillSignal gets the kill signal. +func (killer *SQLKiller) GetKillSignal() killSignal { + return atomic.LoadUint32(&killer.Signal) +} + +// getKillError gets the error according to the kill signal. +func (killer *SQLKiller) getKillError(status killSignal) error { + switch status { + case QueryInterrupted: + return exeerrors.ErrQueryInterrupted.GenWithStackByArgs() + case MaxExecTimeExceeded: + return exeerrors.ErrMaxExecTimeExceeded.GenWithStackByArgs() + case QueryMemoryExceeded: + return exeerrors.ErrMemoryExceedForQuery.GenWithStackByArgs(killer.ConnID) + case ServerMemoryExceeded: + return exeerrors.ErrMemoryExceedForInstance.GenWithStackByArgs(killer.ConnID) + } + return nil +} + +// FinishResultSet is used to close the result set. +// If a kill signal is sent but the SQL query is stuck in the network stack while writing packets to the client, +// encountering some bugs that cause it to hang, or failing to detect the kill signal, we can call Finish to release resources used during the SQL execution process. +func (killer *SQLKiller) FinishResultSet() { + killer.FinishFuncLock.Lock() + defer killer.FinishFuncLock.Unlock() + if killer.Finish != nil { + killer.Finish() + } +} + +// SetFinishFunc sets the finish function. +func (killer *SQLKiller) SetFinishFunc(fn func()) { + killer.FinishFuncLock.Lock() + defer killer.FinishFuncLock.Unlock() + killer.Finish = fn +} + +// ClearFinishFunc clears the finish function.1 +func (killer *SQLKiller) ClearFinishFunc() { + killer.FinishFuncLock.Lock() + defer killer.FinishFuncLock.Unlock() + killer.Finish = nil +} + +// HandleSignal handles the kill signal and return the error. +func (killer *SQLKiller) HandleSignal() error { + failpoint.Inject("randomPanic", func(val failpoint.Value) { + if p, ok := val.(int); ok { + if rand.Float64() > (float64)(p)/1000 { + if killer.ConnID != 0 { + targetStatus := rand.Int31n(5) + atomic.StoreUint32(&killer.Signal, uint32(targetStatus)) + } + } + } + }) + status := atomic.LoadUint32(&killer.Signal) + err := killer.getKillError(status) + if status == ServerMemoryExceeded { + logutil.BgLogger().Warn("global memory controller, NeedKill signal is received successfully", + zap.Uint64("conn", killer.ConnID)) + } + return err +} + +// Reset resets the SqlKiller. 
+func (killer *SQLKiller) Reset() { + if atomic.LoadUint32(&killer.Signal) != 0 { + logutil.BgLogger().Warn("kill finished", zap.Uint64("conn", killer.ConnID)) + } + atomic.StoreUint32(&killer.Signal, 0) +} diff --git a/pkg/util/stmtsummary/binding__failpoint_binding__.go b/pkg/util/stmtsummary/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..bc46da5472193 --- /dev/null +++ b/pkg/util/stmtsummary/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package stmtsummary + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/util/stmtsummary/statement_summary.go b/pkg/util/stmtsummary/statement_summary.go index e9dc133b1c46e..636cf4fa991fd 100644 --- a/pkg/util/stmtsummary/statement_summary.go +++ b/pkg/util/stmtsummary/statement_summary.go @@ -297,7 +297,7 @@ func (ssMap *stmtSummaryByDigestMap) AddStatement(sei *StmtExecInfo) { // All times are counted in seconds. now := time.Now().Unix() - failpoint.Inject("mockTimeForStatementsSummary", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockTimeForStatementsSummary")); _err_ == nil { // mockTimeForStatementsSummary takes string of Unix timestamp if unixTimeStr, ok := val.(string); ok { unixTime, err := strconv.ParseInt(unixTimeStr, 10, 64) @@ -306,7 +306,7 @@ func (ssMap *stmtSummaryByDigestMap) AddStatement(sei *StmtExecInfo) { } now = unixTime } - }) + } intervalSeconds := ssMap.refreshInterval() historySize := ssMap.historySize() diff --git a/pkg/util/stmtsummary/statement_summary.go__failpoint_stash__ b/pkg/util/stmtsummary/statement_summary.go__failpoint_stash__ new file mode 100644 index 0000000000000..e9dc133b1c46e --- /dev/null +++ b/pkg/util/stmtsummary/statement_summary.go__failpoint_stash__ @@ -0,0 +1,1039 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stmtsummary + +import ( + "bytes" + "cmp" + "container/list" + "fmt" + "math" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/execdetails" + "github.com/pingcap/tidb/pkg/util/hack" + "github.com/pingcap/tidb/pkg/util/kvcache" + "github.com/pingcap/tidb/pkg/util/plancodec" + "github.com/tikv/client-go/v2/util" + atomic2 "go.uber.org/atomic" + "golang.org/x/exp/maps" +) + +// stmtSummaryByDigestKey defines key for stmtSummaryByDigestMap.summaryMap. +type stmtSummaryByDigestKey struct { + // Same statements may appear in different schema, but they refer to different tables. + schemaName string + digest string + // The digest of the previous statement. + prevDigest string + // The digest of the plan of this SQL. 
+ planDigest string + // `resourceGroupName` is the resource group's name of this statement is bind to. + resourceGroupName string + // `hash` is the hash value of this object. + hash []byte +} + +// Hash implements SimpleLRUCache.Key. +// Only when current SQL is `commit` do we record `prevSQL`. Otherwise, `prevSQL` is empty. +// `prevSQL` is included in the key To distinguish different transactions. +func (key *stmtSummaryByDigestKey) Hash() []byte { + if len(key.hash) == 0 { + key.hash = make([]byte, 0, len(key.schemaName)+len(key.digest)+len(key.prevDigest)+len(key.planDigest)+len(key.resourceGroupName)) + key.hash = append(key.hash, hack.Slice(key.digest)...) + key.hash = append(key.hash, hack.Slice(key.schemaName)...) + key.hash = append(key.hash, hack.Slice(key.prevDigest)...) + key.hash = append(key.hash, hack.Slice(key.planDigest)...) + key.hash = append(key.hash, hack.Slice(key.resourceGroupName)...) + } + return key.hash +} + +// stmtSummaryByDigestMap is a LRU cache that stores statement summaries. +type stmtSummaryByDigestMap struct { + // It's rare to read concurrently, so RWMutex is not needed. + sync.Mutex + summaryMap *kvcache.SimpleLRUCache + // beginTimeForCurInterval is the begin time for current summary. + beginTimeForCurInterval int64 + + // These options are set by global system variables and are accessed concurrently. + optEnabled *atomic2.Bool + optEnableInternalQuery *atomic2.Bool + optMaxStmtCount *atomic2.Uint32 + optRefreshInterval *atomic2.Int64 + optHistorySize *atomic2.Int32 + optMaxSQLLength *atomic2.Int32 + + // other stores summary of evicted data. + other *stmtSummaryByDigestEvicted +} + +// StmtSummaryByDigestMap is a global map containing all statement summaries. +var StmtSummaryByDigestMap = newStmtSummaryByDigestMap() + +// stmtSummaryByDigest is the summary for each type of statements. +type stmtSummaryByDigest struct { + // It's rare to read concurrently, so RWMutex is not needed. + // Mutex is only used to lock `history`. + sync.Mutex + initialized bool + // Each element in history is a summary in one interval. + history *list.List + // Following fields are common for each summary element. + // They won't change once this object is created, so locking is not needed. + schemaName string + digest string + planDigest string + stmtType string + normalizedSQL string + tableNames string + isInternal bool +} + +// stmtSummaryByDigestElement is the summary for each type of statements in current interval. +type stmtSummaryByDigestElement struct { + sync.Mutex + // Each summary is summarized between [beginTime, endTime). 
+ beginTime int64 + endTime int64 + // basic + sampleSQL string + charset string + collation string + prevSQL string + samplePlan string + sampleBinaryPlan string + planHint string + indexNames []string + execCount int64 + sumErrors int + sumWarnings int + // latency + sumLatency time.Duration + maxLatency time.Duration + minLatency time.Duration + sumParseLatency time.Duration + maxParseLatency time.Duration + sumCompileLatency time.Duration + maxCompileLatency time.Duration + // coprocessor + sumNumCopTasks int64 + maxCopProcessTime time.Duration + maxCopProcessAddress string + maxCopWaitTime time.Duration + maxCopWaitAddress string + // TiKV + sumProcessTime time.Duration + maxProcessTime time.Duration + sumWaitTime time.Duration + maxWaitTime time.Duration + sumBackoffTime time.Duration + maxBackoffTime time.Duration + sumTotalKeys int64 + maxTotalKeys int64 + sumProcessedKeys int64 + maxProcessedKeys int64 + sumRocksdbDeleteSkippedCount uint64 + maxRocksdbDeleteSkippedCount uint64 + sumRocksdbKeySkippedCount uint64 + maxRocksdbKeySkippedCount uint64 + sumRocksdbBlockCacheHitCount uint64 + maxRocksdbBlockCacheHitCount uint64 + sumRocksdbBlockReadCount uint64 + maxRocksdbBlockReadCount uint64 + sumRocksdbBlockReadByte uint64 + maxRocksdbBlockReadByte uint64 + // txn + commitCount int64 + sumGetCommitTsTime time.Duration + maxGetCommitTsTime time.Duration + sumPrewriteTime time.Duration + maxPrewriteTime time.Duration + sumCommitTime time.Duration + maxCommitTime time.Duration + sumLocalLatchTime time.Duration + maxLocalLatchTime time.Duration + sumCommitBackoffTime int64 + maxCommitBackoffTime int64 + sumResolveLockTime int64 + maxResolveLockTime int64 + sumWriteKeys int64 + maxWriteKeys int + sumWriteSize int64 + maxWriteSize int + sumPrewriteRegionNum int64 + maxPrewriteRegionNum int32 + sumTxnRetry int64 + maxTxnRetry int + sumBackoffTimes int64 + backoffTypes map[string]int + authUsers map[string]struct{} + // other + sumMem int64 + maxMem int64 + sumDisk int64 + maxDisk int64 + sumAffectedRows uint64 + sumKVTotal time.Duration + sumPDTotal time.Duration + sumBackoffTotal time.Duration + sumWriteSQLRespTotal time.Duration + sumResultRows int64 + maxResultRows int64 + minResultRows int64 + prepared bool + // The first time this type of SQL executes. + firstSeen time.Time + // The last time this type of SQL executes. + lastSeen time.Time + // plan cache + planInCache bool + planCacheHits int64 + planInBinding bool + // pessimistic execution retry information. + execRetryCount uint + execRetryTime time.Duration + // request-units + resourceGroupName string + StmtRUSummary + + planCacheUnqualifiedCount int64 + lastPlanCacheUnqualified string // the reason why this query is unqualified for the plan cache +} + +// StmtExecInfo records execution information of each statement. 
+type StmtExecInfo struct { + SchemaName string + OriginalSQL fmt.Stringer + Charset string + Collation string + NormalizedSQL string + Digest string + PrevSQL string + PrevSQLDigest string + PlanGenerator func() (string, string, any) + BinaryPlanGenerator func() string + PlanDigest string + PlanDigestGen func() string + User string + TotalLatency time.Duration + ParseLatency time.Duration + CompileLatency time.Duration + StmtCtx *stmtctx.StatementContext + CopTasks *execdetails.CopTasksDetails + ExecDetail *execdetails.ExecDetails + MemMax int64 + DiskMax int64 + StartTime time.Time + IsInternal bool + Succeed bool + PlanInCache bool + PlanInBinding bool + ExecRetryCount uint + ExecRetryTime time.Duration + execdetails.StmtExecDetails + ResultRows int64 + TiKVExecDetails util.ExecDetails + Prepared bool + KeyspaceName string + KeyspaceID uint32 + ResourceGroupName string + RUDetail *util.RUDetails + + PlanCacheUnqualified string +} + +// newStmtSummaryByDigestMap creates an empty stmtSummaryByDigestMap. +func newStmtSummaryByDigestMap() *stmtSummaryByDigestMap { + ssbde := newStmtSummaryByDigestEvicted() + + // This initializes the stmtSummaryByDigestMap with "compiled defaults" + // (which are regrettably duplicated from sessionctx/variable/tidb_vars.go). + // Unfortunately we need to do this to avoid circular dependencies, but the correct + // values will be applied on startup as soon as domain.LoadSysVarCacheLoop() is called, + // which in turn calls func domain.checkEnableServerGlobalVar(name, sVal string) for each sysvar. + // Currently this is early enough in the startup sequence. + maxStmtCount := uint(3000) + newSsMap := &stmtSummaryByDigestMap{ + summaryMap: kvcache.NewSimpleLRUCache(maxStmtCount, 0, 0), + optMaxStmtCount: atomic2.NewUint32(uint32(maxStmtCount)), + optEnabled: atomic2.NewBool(true), + optEnableInternalQuery: atomic2.NewBool(false), + optRefreshInterval: atomic2.NewInt64(1800), + optHistorySize: atomic2.NewInt32(24), + optMaxSQLLength: atomic2.NewInt32(4096), + other: ssbde, + } + newSsMap.summaryMap.SetOnEvict(func(k kvcache.Key, v kvcache.Value) { + historySize := newSsMap.historySize() + newSsMap.other.AddEvicted(k.(*stmtSummaryByDigestKey), v.(*stmtSummaryByDigest), historySize) + }) + return newSsMap +} + +// AddStatement adds a statement to StmtSummaryByDigestMap. +func (ssMap *stmtSummaryByDigestMap) AddStatement(sei *StmtExecInfo) { + // All times are counted in seconds. + now := time.Now().Unix() + + failpoint.Inject("mockTimeForStatementsSummary", func(val failpoint.Value) { + // mockTimeForStatementsSummary takes string of Unix timestamp + if unixTimeStr, ok := val.(string); ok { + unixTime, err := strconv.ParseInt(unixTimeStr, 10, 64) + if err != nil { + panic(err.Error()) + } + now = unixTime + } + }) + + intervalSeconds := ssMap.refreshInterval() + historySize := ssMap.historySize() + + key := &stmtSummaryByDigestKey{ + schemaName: sei.SchemaName, + digest: sei.Digest, + prevDigest: sei.PrevSQLDigest, + planDigest: sei.PlanDigest, + resourceGroupName: sei.ResourceGroupName, + } + // Calculate hash value in advance, to reduce the time holding the lock. + key.Hash() + + // Enclose the block in a function to ensure the lock will always be released. + summary, beginTime := func() (*stmtSummaryByDigest, int64) { + ssMap.Lock() + defer ssMap.Unlock() + + // Check again. Statements could be added before disabling the flag and after Clear(). 
+ if !ssMap.Enabled() { + return nil, 0 + } + if sei.IsInternal && !ssMap.EnabledInternal() { + return nil, 0 + } + + if ssMap.beginTimeForCurInterval+intervalSeconds <= now { + // `beginTimeForCurInterval` is a multiple of intervalSeconds, so that when the interval is a multiple + // of 60 (or 600, 1800, 3600, etc), begin time shows 'XX:XX:00', not 'XX:XX:01'~'XX:XX:59'. + ssMap.beginTimeForCurInterval = now / intervalSeconds * intervalSeconds + } + + beginTime := ssMap.beginTimeForCurInterval + value, ok := ssMap.summaryMap.Get(key) + var summary *stmtSummaryByDigest + if !ok { + // Lazy initialize it to release ssMap.mutex ASAP. + summary = new(stmtSummaryByDigest) + ssMap.summaryMap.Put(key, summary) + } else { + summary = value.(*stmtSummaryByDigest) + } + summary.isInternal = summary.isInternal && sei.IsInternal + return summary, beginTime + }() + // Lock a single entry, not the whole cache. + if summary != nil { + summary.add(sei, beginTime, intervalSeconds, historySize) + } +} + +// Clear removes all statement summaries. +func (ssMap *stmtSummaryByDigestMap) Clear() { + ssMap.Lock() + defer ssMap.Unlock() + + ssMap.summaryMap.DeleteAll() + ssMap.other.Clear() + ssMap.beginTimeForCurInterval = 0 +} + +// clearInternal removes all statement summaries which are internal summaries. +func (ssMap *stmtSummaryByDigestMap) clearInternal() { + ssMap.Lock() + defer ssMap.Unlock() + + for _, key := range ssMap.summaryMap.Keys() { + summary, ok := ssMap.summaryMap.Get(key) + if !ok { + continue + } + if summary.(*stmtSummaryByDigest).isInternal { + ssMap.summaryMap.Delete(key) + } + } +} + +// BindableStmt is a wrapper struct for a statement that is extracted from statements_summary and can be +// created binding on. +type BindableStmt struct { + Schema string + Query string + PlanHint string + Charset string + Collation string + Users map[string]struct{} // which users have processed this stmt +} + +// GetMoreThanCntBindableStmt gets users' select/update/delete SQLs that occurred more than the specified count. +func (ssMap *stmtSummaryByDigestMap) GetMoreThanCntBindableStmt(cnt int64) []*BindableStmt { + ssMap.Lock() + values := ssMap.summaryMap.Values() + ssMap.Unlock() + + stmts := make([]*BindableStmt, 0, len(values)) + for _, value := range values { + ssbd := value.(*stmtSummaryByDigest) + func() { + ssbd.Lock() + defer ssbd.Unlock() + if ssbd.initialized && (ssbd.stmtType == "Select" || ssbd.stmtType == "Delete" || ssbd.stmtType == "Update" || ssbd.stmtType == "Insert" || ssbd.stmtType == "Replace") { + if ssbd.history.Len() > 0 { + ssElement := ssbd.history.Back().Value.(*stmtSummaryByDigestElement) + ssElement.Lock() + + // Empty auth users means that it is an internal queries. + if len(ssElement.authUsers) > 0 && (int64(ssbd.history.Len()) > cnt || ssElement.execCount > cnt) { + stmt := &BindableStmt{ + Schema: ssbd.schemaName, + Query: ssElement.sampleSQL, + PlanHint: ssElement.planHint, + Charset: ssElement.charset, + Collation: ssElement.collation, + Users: make(map[string]struct{}), + } + maps.Copy(stmt.Users, ssElement.authUsers) + // If it is SQL command prepare / execute, the ssElement.sampleSQL is `execute ...`, we should get the original select query. + // If it is binary protocol prepare / execute, ssbd.normalizedSQL should be same as ssElement.sampleSQL. 
+ if ssElement.prepared { + stmt.Query = ssbd.normalizedSQL + } + stmts = append(stmts, stmt) + } + ssElement.Unlock() + } + } + }() + } + return stmts +} + +// SetEnabled enables or disables statement summary +func (ssMap *stmtSummaryByDigestMap) SetEnabled(value bool) error { + // `optEnabled` and `ssMap` don't need to be strictly atomically updated. + ssMap.optEnabled.Store(value) + if !value { + ssMap.Clear() + } + return nil +} + +// Enabled returns whether statement summary is enabled. +func (ssMap *stmtSummaryByDigestMap) Enabled() bool { + return ssMap.optEnabled.Load() +} + +// SetEnabledInternalQuery enables or disables internal statement summary +func (ssMap *stmtSummaryByDigestMap) SetEnabledInternalQuery(value bool) error { + // `optEnableInternalQuery` and `ssMap` don't need to be strictly atomically updated. + ssMap.optEnableInternalQuery.Store(value) + if !value { + ssMap.clearInternal() + } + return nil +} + +// EnabledInternal returns whether internal statement summary is enabled. +func (ssMap *stmtSummaryByDigestMap) EnabledInternal() bool { + return ssMap.optEnableInternalQuery.Load() +} + +// SetRefreshInterval sets refreshing interval in ssMap.sysVars. +func (ssMap *stmtSummaryByDigestMap) SetRefreshInterval(value int64) error { + ssMap.optRefreshInterval.Store(value) + return nil +} + +// refreshInterval gets the refresh interval for summaries. +func (ssMap *stmtSummaryByDigestMap) refreshInterval() int64 { + return ssMap.optRefreshInterval.Load() +} + +// SetHistorySize sets the history size for all summaries. +func (ssMap *stmtSummaryByDigestMap) SetHistorySize(value int) error { + ssMap.optHistorySize.Store(int32(value)) + return nil +} + +// historySize gets the history size for summaries. +func (ssMap *stmtSummaryByDigestMap) historySize() int { + return int(ssMap.optHistorySize.Load()) +} + +// SetHistorySize sets the history size for all summaries. +func (ssMap *stmtSummaryByDigestMap) SetMaxStmtCount(value uint) error { + // `optMaxStmtCount` and `ssMap` don't need to be strictly atomically updated. + ssMap.optMaxStmtCount.Store(uint32(value)) + + ssMap.Lock() + defer ssMap.Unlock() + return ssMap.summaryMap.SetCapacity(value) +} + +// Used by tests +// nolint: unused +func (ssMap *stmtSummaryByDigestMap) maxStmtCount() int { + return int(ssMap.optMaxStmtCount.Load()) +} + +// SetHistorySize sets the history size for all summaries. +func (ssMap *stmtSummaryByDigestMap) SetMaxSQLLength(value int) error { + ssMap.optMaxSQLLength.Store(int32(value)) + return nil +} + +func (ssMap *stmtSummaryByDigestMap) maxSQLLength() int { + return int(ssMap.optMaxSQLLength.Load()) +} + +// GetBindableStmtFromCluster gets users' select/update/delete SQL. +func GetBindableStmtFromCluster(rows []chunk.Row) *BindableStmt { + for _, row := range rows { + user := row.GetString(3) + stmtType := row.GetString(0) + if user != "" && (stmtType == "Select" || stmtType == "Delete" || stmtType == "Update" || stmtType == "Insert" || stmtType == "Replace") { + // Empty auth users means that it is an internal queries. + stmt := &BindableStmt{ + Schema: row.GetString(1), //schemaName + Query: row.GetString(5), //sampleSQL + PlanHint: row.GetString(8), //planHint + Charset: row.GetString(6), //charset + Collation: row.GetString(7), //collation + } + // If it is SQL command prepare / execute, we should remove the arguments + // If it is binary protocol prepare / execute, ssbd.normalizedSQL should be same as ssElement.sampleSQL. 
+ if row.GetInt64(4) == 1 { + if idx := strings.LastIndex(stmt.Query, "[arguments:"); idx != -1 { + stmt.Query = stmt.Query[:idx] + } + } + return stmt + } + } + return nil +} + +// newStmtSummaryByDigest creates a stmtSummaryByDigest from StmtExecInfo. +func (ssbd *stmtSummaryByDigest) init(sei *StmtExecInfo, _ int64, _ int64, _ int) { + // Use "," to separate table names to support FIND_IN_SET. + var buffer bytes.Buffer + for i, value := range sei.StmtCtx.Tables { + // In `create database` statement, DB name is not empty but table name is empty. + if len(value.Table) == 0 { + continue + } + buffer.WriteString(strings.ToLower(value.DB)) + buffer.WriteString(".") + buffer.WriteString(strings.ToLower(value.Table)) + if i < len(sei.StmtCtx.Tables)-1 { + buffer.WriteString(",") + } + } + tableNames := buffer.String() + + planDigest := sei.PlanDigest + if sei.PlanDigestGen != nil && len(planDigest) == 0 { + // It comes here only when the plan is 'Point_Get'. + planDigest = sei.PlanDigestGen() + } + ssbd.schemaName = sei.SchemaName + ssbd.digest = sei.Digest + ssbd.planDigest = planDigest + ssbd.stmtType = sei.StmtCtx.StmtType + ssbd.normalizedSQL = formatSQL(sei.NormalizedSQL) + ssbd.tableNames = tableNames + ssbd.history = list.New() + ssbd.initialized = true +} + +func (ssbd *stmtSummaryByDigest) add(sei *StmtExecInfo, beginTime int64, intervalSeconds int64, historySize int) { + // Enclose this block in a function to ensure the lock will always be released. + ssElement, isElementNew := func() (*stmtSummaryByDigestElement, bool) { + ssbd.Lock() + defer ssbd.Unlock() + + if !ssbd.initialized { + ssbd.init(sei, beginTime, intervalSeconds, historySize) + } + + var ssElement *stmtSummaryByDigestElement + isElementNew := true + if ssbd.history.Len() > 0 { + lastElement := ssbd.history.Back().Value.(*stmtSummaryByDigestElement) + if lastElement.beginTime >= beginTime { + ssElement = lastElement + isElementNew = false + } else { + // The last elements expires to the history. + lastElement.onExpire(intervalSeconds) + } + } + if isElementNew { + // If the element is new created, `ssElement.add(sei)` should be done inside the lock of `ssbd`. + ssElement = newStmtSummaryByDigestElement(sei, beginTime, intervalSeconds) + if ssElement == nil { + return nil, isElementNew + } + ssbd.history.PushBack(ssElement) + } + + // `historySize` might be modified anytime, so check expiration every time. + // Even if history is set to 0, current summary is still needed. + for ssbd.history.Len() > historySize && ssbd.history.Len() > 1 { + ssbd.history.Remove(ssbd.history.Front()) + } + + return ssElement, isElementNew + }() + + // Lock a single entry, not the whole `ssbd`. + if !isElementNew { + ssElement.add(sei, intervalSeconds) + } +} + +// collectHistorySummaries puts at most `historySize` summaries to an array. 
+func (ssbd *stmtSummaryByDigest) collectHistorySummaries(checker *stmtSummaryChecker, historySize int) []*stmtSummaryByDigestElement { + ssbd.Lock() + defer ssbd.Unlock() + + if !ssbd.initialized { + return nil + } + if checker != nil && !checker.isDigestValid(ssbd.digest) { + return nil + } + + ssElements := make([]*stmtSummaryByDigestElement, 0, ssbd.history.Len()) + for listElement := ssbd.history.Front(); listElement != nil && len(ssElements) < historySize; listElement = listElement.Next() { + ssElement := listElement.Value.(*stmtSummaryByDigestElement) + ssElements = append(ssElements, ssElement) + } + return ssElements +} + +// MaxEncodedPlanSizeInBytes is the upper limit of the size of the plan and the binary plan in the stmt summary. +var MaxEncodedPlanSizeInBytes = 1024 * 1024 + +func newStmtSummaryByDigestElement(sei *StmtExecInfo, beginTime int64, intervalSeconds int64) *stmtSummaryByDigestElement { + // sampleSQL / authUsers(sampleUser) / samplePlan / prevSQL / indexNames store the values shown at the first time, + // because it compacts performance to update every time. + samplePlan, planHint, e := sei.PlanGenerator() + if e != nil { + return nil + } + if len(samplePlan) > MaxEncodedPlanSizeInBytes { + samplePlan = plancodec.PlanDiscardedEncoded + } + binPlan := "" + if sei.BinaryPlanGenerator != nil { + binPlan = sei.BinaryPlanGenerator() + if len(binPlan) > MaxEncodedPlanSizeInBytes { + binPlan = plancodec.BinaryPlanDiscardedEncoded + } + } + ssElement := &stmtSummaryByDigestElement{ + beginTime: beginTime, + sampleSQL: formatSQL(sei.OriginalSQL.String()), + charset: sei.Charset, + collation: sei.Collation, + // PrevSQL is already truncated to cfg.Log.QueryLogMaxLen. + prevSQL: sei.PrevSQL, + // samplePlan needs to be decoded so it can't be truncated. + samplePlan: samplePlan, + sampleBinaryPlan: binPlan, + planHint: planHint, + indexNames: sei.StmtCtx.IndexNames, + minLatency: sei.TotalLatency, + firstSeen: sei.StartTime, + lastSeen: sei.StartTime, + backoffTypes: make(map[string]int), + authUsers: make(map[string]struct{}), + planInCache: false, + planCacheHits: 0, + planInBinding: false, + prepared: sei.Prepared, + minResultRows: math.MaxInt64, + resourceGroupName: sei.ResourceGroupName, + } + ssElement.add(sei, intervalSeconds) + return ssElement +} + +// onExpire is called when this element expires to history. +func (ssElement *stmtSummaryByDigestElement) onExpire(intervalSeconds int64) { + ssElement.Lock() + defer ssElement.Unlock() + + // refreshInterval may change anytime, so we need to update endTime. + if ssElement.beginTime+intervalSeconds > ssElement.endTime { + // // If interval changes to a bigger value, update endTime to beginTime + interval. + ssElement.endTime = ssElement.beginTime + intervalSeconds + } else if ssElement.beginTime+intervalSeconds < ssElement.endTime { + now := time.Now().Unix() + // If interval changes to a smaller value and now > beginTime + interval, update endTime to current time. + if now > ssElement.beginTime+intervalSeconds { + ssElement.endTime = now + } + } +} + +func (ssElement *stmtSummaryByDigestElement) add(sei *StmtExecInfo, intervalSeconds int64) { + ssElement.Lock() + defer ssElement.Unlock() + + // add user to auth users set + if len(sei.User) > 0 { + ssElement.authUsers[sei.User] = struct{}{} + } + + // refreshInterval may change anytime, update endTime ASAP. 
+ ssElement.endTime = ssElement.beginTime + intervalSeconds + ssElement.execCount++ + if !sei.Succeed { + ssElement.sumErrors++ + } + ssElement.sumWarnings += int(sei.StmtCtx.WarningCount()) + + // latency + ssElement.sumLatency += sei.TotalLatency + if sei.TotalLatency > ssElement.maxLatency { + ssElement.maxLatency = sei.TotalLatency + } + if sei.TotalLatency < ssElement.minLatency { + ssElement.minLatency = sei.TotalLatency + } + ssElement.sumParseLatency += sei.ParseLatency + if sei.ParseLatency > ssElement.maxParseLatency { + ssElement.maxParseLatency = sei.ParseLatency + } + ssElement.sumCompileLatency += sei.CompileLatency + if sei.CompileLatency > ssElement.maxCompileLatency { + ssElement.maxCompileLatency = sei.CompileLatency + } + + // coprocessor + numCopTasks := int64(sei.CopTasks.NumCopTasks) + ssElement.sumNumCopTasks += numCopTasks + if sei.CopTasks.MaxProcessTime > ssElement.maxCopProcessTime { + ssElement.maxCopProcessTime = sei.CopTasks.MaxProcessTime + ssElement.maxCopProcessAddress = sei.CopTasks.MaxProcessAddress + } + if sei.CopTasks.MaxWaitTime > ssElement.maxCopWaitTime { + ssElement.maxCopWaitTime = sei.CopTasks.MaxWaitTime + ssElement.maxCopWaitAddress = sei.CopTasks.MaxWaitAddress + } + + // TiKV + ssElement.sumProcessTime += sei.ExecDetail.TimeDetail.ProcessTime + if sei.ExecDetail.TimeDetail.ProcessTime > ssElement.maxProcessTime { + ssElement.maxProcessTime = sei.ExecDetail.TimeDetail.ProcessTime + } + ssElement.sumWaitTime += sei.ExecDetail.TimeDetail.WaitTime + if sei.ExecDetail.TimeDetail.WaitTime > ssElement.maxWaitTime { + ssElement.maxWaitTime = sei.ExecDetail.TimeDetail.WaitTime + } + ssElement.sumBackoffTime += sei.ExecDetail.BackoffTime + if sei.ExecDetail.BackoffTime > ssElement.maxBackoffTime { + ssElement.maxBackoffTime = sei.ExecDetail.BackoffTime + } + + if sei.ExecDetail.ScanDetail != nil { + ssElement.sumTotalKeys += sei.ExecDetail.ScanDetail.TotalKeys + if sei.ExecDetail.ScanDetail.TotalKeys > ssElement.maxTotalKeys { + ssElement.maxTotalKeys = sei.ExecDetail.ScanDetail.TotalKeys + } + ssElement.sumProcessedKeys += sei.ExecDetail.ScanDetail.ProcessedKeys + if sei.ExecDetail.ScanDetail.ProcessedKeys > ssElement.maxProcessedKeys { + ssElement.maxProcessedKeys = sei.ExecDetail.ScanDetail.ProcessedKeys + } + ssElement.sumRocksdbDeleteSkippedCount += sei.ExecDetail.ScanDetail.RocksdbDeleteSkippedCount + if sei.ExecDetail.ScanDetail.RocksdbDeleteSkippedCount > ssElement.maxRocksdbDeleteSkippedCount { + ssElement.maxRocksdbDeleteSkippedCount = sei.ExecDetail.ScanDetail.RocksdbDeleteSkippedCount + } + ssElement.sumRocksdbKeySkippedCount += sei.ExecDetail.ScanDetail.RocksdbKeySkippedCount + if sei.ExecDetail.ScanDetail.RocksdbKeySkippedCount > ssElement.maxRocksdbKeySkippedCount { + ssElement.maxRocksdbKeySkippedCount = sei.ExecDetail.ScanDetail.RocksdbKeySkippedCount + } + ssElement.sumRocksdbBlockCacheHitCount += sei.ExecDetail.ScanDetail.RocksdbBlockCacheHitCount + if sei.ExecDetail.ScanDetail.RocksdbBlockCacheHitCount > ssElement.maxRocksdbBlockCacheHitCount { + ssElement.maxRocksdbBlockCacheHitCount = sei.ExecDetail.ScanDetail.RocksdbBlockCacheHitCount + } + ssElement.sumRocksdbBlockReadCount += sei.ExecDetail.ScanDetail.RocksdbBlockReadCount + if sei.ExecDetail.ScanDetail.RocksdbBlockReadCount > ssElement.maxRocksdbBlockReadCount { + ssElement.maxRocksdbBlockReadCount = sei.ExecDetail.ScanDetail.RocksdbBlockReadCount + } + ssElement.sumRocksdbBlockReadByte += sei.ExecDetail.ScanDetail.RocksdbBlockReadByte + if 
sei.ExecDetail.ScanDetail.RocksdbBlockReadByte > ssElement.maxRocksdbBlockReadByte { + ssElement.maxRocksdbBlockReadByte = sei.ExecDetail.ScanDetail.RocksdbBlockReadByte + } + } + + // txn + commitDetails := sei.ExecDetail.CommitDetail + if commitDetails != nil { + ssElement.commitCount++ + ssElement.sumPrewriteTime += commitDetails.PrewriteTime + if commitDetails.PrewriteTime > ssElement.maxPrewriteTime { + ssElement.maxPrewriteTime = commitDetails.PrewriteTime + } + ssElement.sumCommitTime += commitDetails.CommitTime + if commitDetails.CommitTime > ssElement.maxCommitTime { + ssElement.maxCommitTime = commitDetails.CommitTime + } + ssElement.sumGetCommitTsTime += commitDetails.GetCommitTsTime + if commitDetails.GetCommitTsTime > ssElement.maxGetCommitTsTime { + ssElement.maxGetCommitTsTime = commitDetails.GetCommitTsTime + } + resolveLockTime := atomic.LoadInt64(&commitDetails.ResolveLock.ResolveLockTime) + ssElement.sumResolveLockTime += resolveLockTime + if resolveLockTime > ssElement.maxResolveLockTime { + ssElement.maxResolveLockTime = resolveLockTime + } + ssElement.sumLocalLatchTime += commitDetails.LocalLatchTime + if commitDetails.LocalLatchTime > ssElement.maxLocalLatchTime { + ssElement.maxLocalLatchTime = commitDetails.LocalLatchTime + } + ssElement.sumWriteKeys += int64(commitDetails.WriteKeys) + if commitDetails.WriteKeys > ssElement.maxWriteKeys { + ssElement.maxWriteKeys = commitDetails.WriteKeys + } + ssElement.sumWriteSize += int64(commitDetails.WriteSize) + if commitDetails.WriteSize > ssElement.maxWriteSize { + ssElement.maxWriteSize = commitDetails.WriteSize + } + prewriteRegionNum := atomic.LoadInt32(&commitDetails.PrewriteRegionNum) + ssElement.sumPrewriteRegionNum += int64(prewriteRegionNum) + if prewriteRegionNum > ssElement.maxPrewriteRegionNum { + ssElement.maxPrewriteRegionNum = prewriteRegionNum + } + ssElement.sumTxnRetry += int64(commitDetails.TxnRetry) + if commitDetails.TxnRetry > ssElement.maxTxnRetry { + ssElement.maxTxnRetry = commitDetails.TxnRetry + } + commitDetails.Mu.Lock() + commitBackoffTime := commitDetails.Mu.CommitBackoffTime + ssElement.sumCommitBackoffTime += commitBackoffTime + if commitBackoffTime > ssElement.maxCommitBackoffTime { + ssElement.maxCommitBackoffTime = commitBackoffTime + } + ssElement.sumBackoffTimes += int64(len(commitDetails.Mu.PrewriteBackoffTypes)) + for _, backoffType := range commitDetails.Mu.PrewriteBackoffTypes { + ssElement.backoffTypes[backoffType]++ + } + ssElement.sumBackoffTimes += int64(len(commitDetails.Mu.CommitBackoffTypes)) + for _, backoffType := range commitDetails.Mu.CommitBackoffTypes { + ssElement.backoffTypes[backoffType]++ + } + commitDetails.Mu.Unlock() + } + + // plan cache + if sei.PlanInCache { + ssElement.planInCache = true + ssElement.planCacheHits++ + } else { + ssElement.planInCache = false + } + if sei.PlanCacheUnqualified != "" { + ssElement.planCacheUnqualifiedCount++ + ssElement.lastPlanCacheUnqualified = sei.PlanCacheUnqualified + } + + // SPM + if sei.PlanInBinding { + ssElement.planInBinding = true + } else { + ssElement.planInBinding = false + } + + // other + ssElement.sumAffectedRows += sei.StmtCtx.AffectedRows() + ssElement.sumMem += sei.MemMax + if sei.MemMax > ssElement.maxMem { + ssElement.maxMem = sei.MemMax + } + ssElement.sumDisk += sei.DiskMax + if sei.DiskMax > ssElement.maxDisk { + ssElement.maxDisk = sei.DiskMax + } + if sei.StartTime.Before(ssElement.firstSeen) { + ssElement.firstSeen = sei.StartTime + } + if ssElement.lastSeen.Before(sei.StartTime) { + 
ssElement.lastSeen = sei.StartTime + } + if sei.ExecRetryCount > 0 { + ssElement.execRetryCount += sei.ExecRetryCount + ssElement.execRetryTime += sei.ExecRetryTime + } + if sei.ResultRows > 0 { + ssElement.sumResultRows += sei.ResultRows + if ssElement.maxResultRows < sei.ResultRows { + ssElement.maxResultRows = sei.ResultRows + } + if ssElement.minResultRows > sei.ResultRows { + ssElement.minResultRows = sei.ResultRows + } + } else { + ssElement.minResultRows = 0 + } + ssElement.sumKVTotal += time.Duration(atomic.LoadInt64(&sei.TiKVExecDetails.WaitKVRespDuration)) + ssElement.sumPDTotal += time.Duration(atomic.LoadInt64(&sei.TiKVExecDetails.WaitPDRespDuration)) + ssElement.sumBackoffTotal += time.Duration(atomic.LoadInt64(&sei.TiKVExecDetails.BackoffDuration)) + ssElement.sumWriteSQLRespTotal += sei.StmtExecDetails.WriteSQLRespDuration + + // request-units + ssElement.StmtRUSummary.Add(sei.RUDetail) +} + +// Truncate SQL to maxSQLLength. +func formatSQL(sql string) string { + maxSQLLength := StmtSummaryByDigestMap.maxSQLLength() + length := len(sql) + if length > maxSQLLength { + var result strings.Builder + result.WriteString(sql[:maxSQLLength]) + fmt.Fprintf(&result, "(len:%d)", length) + return result.String() + } + return sql +} + +// Format the backoffType map to a string or nil. +func formatBackoffTypes(backoffMap map[string]int) any { + type backoffStat struct { + backoffType string + count int + } + + size := len(backoffMap) + if size == 0 { + return nil + } + + backoffArray := make([]backoffStat, 0, len(backoffMap)) + for backoffType, count := range backoffMap { + backoffArray = append(backoffArray, backoffStat{backoffType, count}) + } + slices.SortFunc(backoffArray, func(i, j backoffStat) int { + return cmp.Compare(j.count, i.count) + }) + + var buffer bytes.Buffer + for index, stat := range backoffArray { + if _, err := fmt.Fprintf(&buffer, "%v:%d", stat.backoffType, stat.count); err != nil { + return "FORMAT ERROR" + } + if index < len(backoffArray)-1 { + buffer.WriteString(",") + } + } + return buffer.String() +} + +func avgInt(sum int64, count int64) int64 { + if count > 0 { + return sum / count + } + return 0 +} + +func avgFloat(sum int64, count int64) float64 { + if count > 0 { + return float64(sum) / float64(count) + } + return 0 +} + +func avgSumFloat(sum float64, count int64) float64 { + if count > 0 { + return sum / float64(count) + } + return 0 +} + +func convertEmptyToNil(str string) any { + if str == "" { + return nil + } + return str +} + +// StmtRUSummary is the request-units summary for each type of statements. +type StmtRUSummary struct { + SumRRU float64 `json:"sum_rru"` + SumWRU float64 `json:"sum_wru"` + SumRUWaitDuration time.Duration `json:"sum_ru_wait_duration"` + MaxRRU float64 `json:"max_rru"` + MaxWRU float64 `json:"max_wru"` + MaxRUWaitDuration time.Duration `json:"max_ru_wait_duration"` +} + +// Add add a new sample value to the ru summary record. +func (s *StmtRUSummary) Add(info *util.RUDetails) { + if info != nil { + rru := info.RRU() + s.SumRRU += rru + if s.MaxRRU < rru { + s.MaxRRU = rru + } + wru := info.WRU() + s.SumWRU += wru + if s.MaxWRU < wru { + s.MaxWRU = wru + } + ruWaitDur := info.RUWaitDuration() + s.SumRUWaitDuration += ruWaitDur + if s.MaxRUWaitDuration < ruWaitDur { + s.MaxRUWaitDuration = ruWaitDur + } + } +} + +// Merge merges the value of 2 ru summary records. 
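+// Sums accumulate and maxes keep the larger value, so merging two interval
+// records is equivalent to having observed all of their samples in one record.
+// A minimal sketch (variable names are illustrative only):
+//
+//	var total StmtRUSummary
+//	total.Merge(&intervalA)
+//	total.Merge(&intervalB)
+//	// total.SumRRU == intervalA.SumRRU + intervalB.SumRRU
+//	// total.MaxRRU is the larger of the two MaxRRU values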
+func (s *StmtRUSummary) Merge(other *StmtRUSummary) { + s.SumRRU += other.SumRRU + s.SumWRU += other.SumWRU + s.SumRUWaitDuration += other.SumRUWaitDuration + if s.MaxRRU < other.MaxRRU { + s.MaxRRU = other.MaxRRU + } + if s.MaxWRU < other.MaxWRU { + s.MaxWRU = other.MaxWRU + } + if s.MaxRUWaitDuration < other.MaxRUWaitDuration { + s.MaxRUWaitDuration = other.MaxRUWaitDuration + } +} diff --git a/pkg/util/topsql/binding__failpoint_binding__.go b/pkg/util/topsql/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..2baa7cc332089 --- /dev/null +++ b/pkg/util/topsql/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package topsql + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/util/topsql/reporter/binding__failpoint_binding__.go b/pkg/util/topsql/reporter/binding__failpoint_binding__.go new file mode 100644 index 0000000000000..2b1d47d228c18 --- /dev/null +++ b/pkg/util/topsql/reporter/binding__failpoint_binding__.go @@ -0,0 +1,14 @@ + +package reporter + +import "reflect" + +type __failpointBindingType struct {pkgpath string} +var __failpointBindingCache = &__failpointBindingType{} + +func init() { + __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() +} +func _curpkg_(name string) string { + return __failpointBindingCache.pkgpath + "/" + name +} diff --git a/pkg/util/topsql/reporter/pubsub.go b/pkg/util/topsql/reporter/pubsub.go index a9ce1fe7e9b8c..ceb6139d5fbea 100644 --- a/pkg/util/topsql/reporter/pubsub.go +++ b/pkg/util/topsql/reporter/pubsub.go @@ -138,7 +138,7 @@ func (ds *pubSubDataSink) run() error { return ctx.Err() } - failpoint.Inject("mockGrpcLogPanic", nil) + failpoint.Eval(_curpkg_("mockGrpcLogPanic")) if err != nil { logutil.BgLogger().Warn( "[top-sql] pubsub datasink failed to send data to subscriber", diff --git a/pkg/util/topsql/reporter/pubsub.go__failpoint_stash__ b/pkg/util/topsql/reporter/pubsub.go__failpoint_stash__ new file mode 100644 index 0000000000000..a9ce1fe7e9b8c --- /dev/null +++ b/pkg/util/topsql/reporter/pubsub.go__failpoint_stash__ @@ -0,0 +1,274 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reporter + +import ( + "context" + "errors" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/logutil" + reporter_metrics "github.com/pingcap/tidb/pkg/util/topsql/reporter/metrics" + "github.com/pingcap/tipb/go-tipb" + "go.uber.org/zap" +) + +// TopSQLPubSubService implements tipb.TopSQLPubSubServer. +// +// If a client subscribes to TopSQL records, the TopSQLPubSubService is responsible +// for registering an associated DataSink to the reporter. 
Then the DataSink sends +// data to the client periodically. +type TopSQLPubSubService struct { + dataSinkRegisterer DataSinkRegisterer +} + +// NewTopSQLPubSubService creates a new TopSQLPubSubService. +func NewTopSQLPubSubService(dataSinkRegisterer DataSinkRegisterer) *TopSQLPubSubService { + return &TopSQLPubSubService{dataSinkRegisterer: dataSinkRegisterer} +} + +var _ tipb.TopSQLPubSubServer = &TopSQLPubSubService{} + +// Subscribe registers dataSinks to the reporter and redirects data received from reporter +// to subscribers associated with those dataSinks. +func (ps *TopSQLPubSubService) Subscribe(_ *tipb.TopSQLSubRequest, stream tipb.TopSQLPubSub_SubscribeServer) error { + ds := newPubSubDataSink(stream, ps.dataSinkRegisterer) + if err := ps.dataSinkRegisterer.Register(ds); err != nil { + return err + } + return ds.run() +} + +type pubSubDataSink struct { + ctx context.Context + cancel context.CancelFunc + + stream tipb.TopSQLPubSub_SubscribeServer + sendTaskCh chan sendTask + + // for deregister + registerer DataSinkRegisterer +} + +func newPubSubDataSink(stream tipb.TopSQLPubSub_SubscribeServer, registerer DataSinkRegisterer) *pubSubDataSink { + ctx, cancel := context.WithCancel(stream.Context()) + + return &pubSubDataSink{ + ctx: ctx, + cancel: cancel, + + stream: stream, + sendTaskCh: make(chan sendTask, 1), + + registerer: registerer, + } +} + +var _ DataSink = &pubSubDataSink{} + +func (ds *pubSubDataSink) TrySend(data *ReportData, deadline time.Time) error { + select { + case ds.sendTaskCh <- sendTask{data: data, deadline: deadline}: + return nil + case <-ds.ctx.Done(): + return ds.ctx.Err() + default: + reporter_metrics.IgnoreReportChannelFullCounter.Inc() + return errors.New("the channel of pubsub dataSink is full") + } +} + +func (ds *pubSubDataSink) OnReporterClosing() { + ds.cancel() +} + +func (ds *pubSubDataSink) run() error { + defer func() { + if r := recover(); r != nil { + // To catch panic when log grpc error. https://github.com/pingcap/tidb/issues/51301. + logutil.BgLogger().Error("[top-sql] got panic in pub sub data sink, just ignore", zap.Error(util.GetRecoverError(r))) + } + ds.registerer.Deregister(ds) + ds.cancel() + }() + + for { + select { + case task := <-ds.sendTaskCh: + ctx, cancel := context.WithDeadline(ds.ctx, task.deadline) + var err error + + start := time.Now() + go util.WithRecovery(func() { + defer cancel() + err = ds.doSend(ctx, task.data) + + if err != nil { + reporter_metrics.ReportAllDurationFailedHistogram.Observe(time.Since(start).Seconds()) + } else { + reporter_metrics.ReportAllDurationSuccHistogram.Observe(time.Since(start).Seconds()) + } + }, nil) + + // When the deadline is exceeded, the closure inside `go util.WithRecovery` above may not notice that + // immediately because it can be blocked by `stream.Send`. + // In order to clean up resources as quickly as possible, we let that closure run in an individual goroutine, + // and wait for timeout here. 
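+ // ctx.Done() is closed either by the deferred cancel() once doSend returns,
+ // by the deadline passing, or by ds.ctx being canceled, so this wait is bounded.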
+ <-ctx.Done() + + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + logutil.BgLogger().Warn( + "[top-sql] pubsub datasink failed to send data to subscriber due to deadline exceeded", + zap.Time("deadline", task.deadline), + ) + return ctx.Err() + } + + failpoint.Inject("mockGrpcLogPanic", nil) + if err != nil { + logutil.BgLogger().Warn( + "[top-sql] pubsub datasink failed to send data to subscriber", + zap.Error(err), + ) + return err + } + case <-ds.ctx.Done(): + return ds.ctx.Err() + } + } +} + +func (ds *pubSubDataSink) doSend(ctx context.Context, data *ReportData) error { + if err := ds.sendTopSQLRecords(ctx, data.DataRecords); err != nil { + return err + } + if err := ds.sendSQLMeta(ctx, data.SQLMetas); err != nil { + return err + } + return ds.sendPlanMeta(ctx, data.PlanMetas) +} + +func (ds *pubSubDataSink) sendTopSQLRecords(ctx context.Context, records []tipb.TopSQLRecord) (err error) { + if len(records) == 0 { + return + } + + start := time.Now() + sentCount := 0 + defer func() { + reporter_metrics.TopSQLReportRecordCounterHistogram.Observe(float64(sentCount)) + if err != nil { + reporter_metrics.ReportRecordDurationFailedHistogram.Observe(time.Since(start).Seconds()) + } else { + reporter_metrics.ReportRecordDurationSuccHistogram.Observe(time.Since(start).Seconds()) + } + }() + + topSQLRecord := &tipb.TopSQLSubResponse_Record{} + r := &tipb.TopSQLSubResponse{RespOneof: topSQLRecord} + + for i := range records { + topSQLRecord.Record = &records[i] + if err = ds.stream.Send(r); err != nil { + return + } + sentCount++ + + select { + case <-ctx.Done(): + err = ctx.Err() + return + default: + } + } + + return +} + +func (ds *pubSubDataSink) sendSQLMeta(ctx context.Context, sqlMetas []tipb.SQLMeta) (err error) { + if len(sqlMetas) == 0 { + return + } + + start := time.Now() + sentCount := 0 + defer func() { + reporter_metrics.TopSQLReportSQLCountHistogram.Observe(float64(sentCount)) + if err != nil { + reporter_metrics.ReportSQLDurationFailedHistogram.Observe(time.Since(start).Seconds()) + } else { + reporter_metrics.ReportSQLDurationSuccHistogram.Observe(time.Since(start).Seconds()) + } + }() + + sqlMeta := &tipb.TopSQLSubResponse_SqlMeta{} + r := &tipb.TopSQLSubResponse{RespOneof: sqlMeta} + + for i := range sqlMetas { + sqlMeta.SqlMeta = &sqlMetas[i] + if err = ds.stream.Send(r); err != nil { + return + } + sentCount++ + + select { + case <-ctx.Done(): + err = ctx.Err() + return + default: + } + } + + return +} + +func (ds *pubSubDataSink) sendPlanMeta(ctx context.Context, planMetas []tipb.PlanMeta) (err error) { + if len(planMetas) == 0 { + return + } + + start := time.Now() + sentCount := 0 + defer func() { + reporter_metrics.TopSQLReportPlanCountHistogram.Observe(float64(sentCount)) + if err != nil { + reporter_metrics.ReportPlanDurationFailedHistogram.Observe(time.Since(start).Seconds()) + } else { + reporter_metrics.ReportPlanDurationSuccHistogram.Observe(time.Since(start).Seconds()) + } + }() + + planMeta := &tipb.TopSQLSubResponse_PlanMeta{} + r := &tipb.TopSQLSubResponse{RespOneof: planMeta} + + for i := range planMetas { + planMeta.PlanMeta = &planMetas[i] + if err = ds.stream.Send(r); err != nil { + return + } + sentCount++ + + select { + case <-ctx.Done(): + err = ctx.Err() + return + default: + } + } + + return +} diff --git a/pkg/util/topsql/reporter/reporter.go b/pkg/util/topsql/reporter/reporter.go index ddb2fd4c81c82..40afca010e75c 100644 --- a/pkg/util/topsql/reporter/reporter.go +++ b/pkg/util/topsql/reporter/reporter.go @@ -287,14 +287,14 @@ func (tsr 
*RemoteTopSQLReporter) doReport(data *ReportData) { return } timeout := reportTimeout - failpoint.Inject("resetTimeoutForTest", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("resetTimeoutForTest")); _err_ == nil { if val.(bool) { interval := time.Duration(topsqlstate.GlobalState.ReportIntervalSeconds.Load()) * time.Second if interval < timeout { timeout = interval } } - }) + } _ = tsr.trySend(data, time.Now().Add(timeout)) } diff --git a/pkg/util/topsql/reporter/reporter.go__failpoint_stash__ b/pkg/util/topsql/reporter/reporter.go__failpoint_stash__ new file mode 100644 index 0000000000000..ddb2fd4c81c82 --- /dev/null +++ b/pkg/util/topsql/reporter/reporter.go__failpoint_stash__ @@ -0,0 +1,333 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package reporter + +import ( + "context" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/util" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/topsql/collector" + reporter_metrics "github.com/pingcap/tidb/pkg/util/topsql/reporter/metrics" + topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" + "github.com/pingcap/tidb/pkg/util/topsql/stmtstats" + "go.uber.org/zap" +) + +const ( + reportTimeout = 40 * time.Second + collectChanBufferSize = 2 +) + +var nowFunc = time.Now + +// TopSQLReporter collects Top SQL metrics. +type TopSQLReporter interface { + collector.Collector + stmtstats.Collector + + // Start uses to start the reporter. + Start() + + // RegisterSQL registers a normalizedSQL with SQLDigest. + // + // Note that the normalized SQL string can be of >1M long. + // This function should be thread-safe, which means concurrently calling it + // in several goroutines should be fine. It should also return immediately, + // and do any CPU-intensive job asynchronously. + RegisterSQL(sqlDigest []byte, normalizedSQL string, isInternal bool) + + // RegisterPlan like RegisterSQL, but for normalized plan strings. + // isLarge indicates the size of normalizedPlan is big. + RegisterPlan(planDigest []byte, normalizedPlan string, isLarge bool) + + // Close uses to close and release the reporter resource. + Close() +} + +var _ TopSQLReporter = &RemoteTopSQLReporter{} +var _ DataSinkRegisterer = &RemoteTopSQLReporter{} + +// RemoteTopSQLReporter implements TopSQLReporter that sends data to a remote agent. +// This should be called periodically to collect TopSQL resource usage metrics. 
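+//
+// Rough data flow, as wired below: Collect and CollectStmtStatsMap push records
+// into buffered channels; collectWorker folds them into the in-memory collecting
+// state and, once per report interval, hands a round to reportWorker, which
+// builds a ReportData and delivers it to every registered DataSink via doReport.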
+type RemoteTopSQLReporter struct { + ctx context.Context + reportCollectedDataChan chan collectedData + cancel context.CancelFunc + sqlCPUCollector *collector.SQLCPUCollector + collectCPUTimeChan chan []collector.SQLCPUTimeRecord + collectStmtStatsChan chan stmtstats.StatementStatsMap + collecting *collecting + normalizedSQLMap *normalizedSQLMap + normalizedPlanMap *normalizedPlanMap + stmtStatsBuffer map[uint64]stmtstats.StatementStatsMap // timestamp => stmtstats.StatementStatsMap + // calling decodePlan this can take a while, so should not block critical paths. + decodePlan planBinaryDecodeFunc + // Instead of dropping large plans, we compress it into encoded format and report + compressPlan planBinaryCompressFunc + DefaultDataSinkRegisterer +} + +// NewRemoteTopSQLReporter creates a new RemoteTopSQLReporter. +// +// decodePlan is a decoding function which will be called asynchronously to decode the plan binary to string. +func NewRemoteTopSQLReporter(decodePlan planBinaryDecodeFunc, compressPlan planBinaryCompressFunc) *RemoteTopSQLReporter { + ctx, cancel := context.WithCancel(context.Background()) + tsr := &RemoteTopSQLReporter{ + DefaultDataSinkRegisterer: NewDefaultDataSinkRegisterer(ctx), + ctx: ctx, + cancel: cancel, + collectCPUTimeChan: make(chan []collector.SQLCPUTimeRecord, collectChanBufferSize), + collectStmtStatsChan: make(chan stmtstats.StatementStatsMap, collectChanBufferSize), + reportCollectedDataChan: make(chan collectedData, 1), + collecting: newCollecting(), + normalizedSQLMap: newNormalizedSQLMap(), + normalizedPlanMap: newNormalizedPlanMap(), + stmtStatsBuffer: map[uint64]stmtstats.StatementStatsMap{}, + decodePlan: decodePlan, + compressPlan: compressPlan, + } + tsr.sqlCPUCollector = collector.NewSQLCPUCollector(tsr) + return tsr +} + +// Start implements the TopSQLReporter interface. +func (tsr *RemoteTopSQLReporter) Start() { + tsr.sqlCPUCollector.Start() + go tsr.collectWorker() + go tsr.reportWorker() +} + +// Collect implements tracecpu.Collector. +// +// WARN: It will drop the DataRecords if the processing is not in time. +// This function is thread-safe and efficient. +func (tsr *RemoteTopSQLReporter) Collect(data []collector.SQLCPUTimeRecord) { + if len(data) == 0 { + return + } + select { + case tsr.collectCPUTimeChan <- data: + default: + // ignore if chan blocked + reporter_metrics.IgnoreCollectChannelFullCounter.Inc() + } +} + +// CollectStmtStatsMap implements stmtstats.Collector. +// +// WARN: It will drop the DataRecords if the processing is not in time. +// This function is thread-safe and efficient. +func (tsr *RemoteTopSQLReporter) CollectStmtStatsMap(data stmtstats.StatementStatsMap) { + if len(data) == 0 { + return + } + select { + case tsr.collectStmtStatsChan <- data: + default: + // ignore if chan blocked + reporter_metrics.IgnoreCollectStmtChannelFullCounter.Inc() + } +} + +// RegisterSQL implements TopSQLReporter. +// +// This function is thread-safe and efficient. +func (tsr *RemoteTopSQLReporter) RegisterSQL(sqlDigest []byte, normalizedSQL string, isInternal bool) { + tsr.normalizedSQLMap.register(sqlDigest, normalizedSQL, isInternal) +} + +// RegisterPlan implements TopSQLReporter. +// +// This function is thread-safe and efficient. +func (tsr *RemoteTopSQLReporter) RegisterPlan(planDigest []byte, normalizedPlan string, isLarge bool) { + tsr.normalizedPlanMap.register(planDigest, normalizedPlan, isLarge) +} + +// Close implements TopSQLReporter. 
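+// The expected lifecycle matches SetupTopSQL/Close in pkg/util/topsql; a minimal
+// sketch, assuming the plancodec helpers used there:
+//
+//	r := NewRemoteTopSQLReporter(plancodec.DecodeNormalizedPlan, plancodec.Compress)
+//	r.Start()
+//	defer r.Close()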
+func (tsr *RemoteTopSQLReporter) Close() { + tsr.cancel() + tsr.sqlCPUCollector.Stop() + tsr.onReporterClosing() +} + +// collectWorker consumes and collects data from tracecpu.Collector/stmtstats.Collector. +func (tsr *RemoteTopSQLReporter) collectWorker() { + defer util.Recover("top-sql", "collectWorker", nil, false) + + currentReportInterval := topsqlstate.GlobalState.ReportIntervalSeconds.Load() + reportTicker := time.NewTicker(time.Second * time.Duration(currentReportInterval)) + defer reportTicker.Stop() + for { + select { + case <-tsr.ctx.Done(): + return + case data := <-tsr.collectCPUTimeChan: + timestamp := uint64(nowFunc().Unix()) + tsr.processCPUTimeData(timestamp, data) + case data := <-tsr.collectStmtStatsChan: + timestamp := uint64(nowFunc().Unix()) + tsr.stmtStatsBuffer[timestamp] = data + case <-reportTicker.C: + tsr.processStmtStatsData() + tsr.takeDataAndSendToReportChan() + // Update `reportTicker` if report interval changed. + if newInterval := topsqlstate.GlobalState.ReportIntervalSeconds.Load(); newInterval != currentReportInterval { + currentReportInterval = newInterval + reportTicker.Reset(time.Second * time.Duration(currentReportInterval)) + } + } + } +} + +// processCPUTimeData collects top N cpuRecords of each round into tsr.collecting, and evict the +// data that is not in top N. All the evicted cpuRecords will be summary into the others. +func (tsr *RemoteTopSQLReporter) processCPUTimeData(timestamp uint64, data cpuRecords) { + defer util.Recover("top-sql", "processCPUTimeData", nil, false) + + // Get top N cpuRecords of each round cpuRecords. Collect the top N to tsr.collecting + // for each round. SQL meta will not be evicted, since the evicted SQL can be appeared + // on other components (TiKV) TopN DataRecords. + top, evicted := data.topN(int(topsqlstate.GlobalState.MaxStatementCount.Load())) + for _, r := range top { + tsr.collecting.getOrCreateRecord(r.SQLDigest, r.PlanDigest).appendCPUTime(timestamp, r.CPUTimeMs) + } + if len(evicted) == 0 { + return + } + totalEvictedCPUTime := uint32(0) + for _, e := range evicted { + totalEvictedCPUTime += e.CPUTimeMs + // Mark which digests are evicted under each timestamp. + // We will determine whether the corresponding CPUTime has been evicted + // when collecting stmtstats. If so, then we can ignore it directly. + tsr.collecting.markAsEvicted(timestamp, e.SQLDigest, e.PlanDigest) + } + tsr.collecting.appendOthersCPUTime(timestamp, totalEvictedCPUTime) +} + +// processStmtStatsData collects tsr.stmtStatsBuffer into tsr.collecting. +// All the evicted items will be summary into the others. +func (tsr *RemoteTopSQLReporter) processStmtStatsData() { + defer util.Recover("top-sql", "processStmtStatsData", nil, false) + + for timestamp, data := range tsr.stmtStatsBuffer { + for digest, item := range data { + sqlDigest, planDigest := []byte(digest.SQLDigest), []byte(digest.PlanDigest) + if tsr.collecting.hasEvicted(timestamp, sqlDigest, planDigest) { + // This timestamp+sql+plan has been evicted due to low CPUTime. + tsr.collecting.appendOthersStmtStatsItem(timestamp, *item) + continue + } + tsr.collecting.getOrCreateRecord(sqlDigest, planDigest).appendStmtStatsItem(timestamp, *item) + } + } + tsr.stmtStatsBuffer = map[uint64]stmtstats.StatementStatsMap{} +} + +// takeDataAndSendToReportChan takes records data and then send to the report channel for reporting. +func (tsr *RemoteTopSQLReporter) takeDataAndSendToReportChan() { + // Send to report channel. When channel is full, data will be dropped. 
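+ // The send is non-blocking: if reportWorker has not drained the previous round
+ // yet, this round is dropped and IgnoreReportChannelFullCounter is incremented.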
+ select { + case tsr.reportCollectedDataChan <- collectedData{ + collected: tsr.collecting.take(), + normalizedSQLMap: tsr.normalizedSQLMap.take(), + normalizedPlanMap: tsr.normalizedPlanMap.take(), + }: + default: + // ignore if chan blocked + reporter_metrics.IgnoreReportChannelFullCounter.Inc() + } +} + +// reportWorker sends data to the gRPC endpoint from the `reportCollectedDataChan` one by one. +func (tsr *RemoteTopSQLReporter) reportWorker() { + defer util.Recover("top-sql", "reportWorker", nil, false) + + for { + select { + case data := <-tsr.reportCollectedDataChan: + // When `reportCollectedDataChan` receives something, there could be ongoing + // `RegisterSQL` and `RegisterPlan` running, who writes to the data structure + // that `data` contains. So we wait for a little while to ensure that writes + // are finished. + time.Sleep(time.Millisecond * 100) + rs := data.collected.getReportRecords() + // Convert to protobuf data and do report. + tsr.doReport(&ReportData{ + DataRecords: rs.toProto(), + SQLMetas: data.normalizedSQLMap.toProto(), + PlanMetas: data.normalizedPlanMap.toProto(tsr.decodePlan, tsr.compressPlan), + }) + case <-tsr.ctx.Done(): + return + } + } +} + +// doReport sends ReportData to DataSinks. +func (tsr *RemoteTopSQLReporter) doReport(data *ReportData) { + defer util.Recover("top-sql", "doReport", nil, false) + + if !data.hasData() { + return + } + timeout := reportTimeout + failpoint.Inject("resetTimeoutForTest", func(val failpoint.Value) { + if val.(bool) { + interval := time.Duration(topsqlstate.GlobalState.ReportIntervalSeconds.Load()) * time.Second + if interval < timeout { + timeout = interval + } + } + }) + _ = tsr.trySend(data, time.Now().Add(timeout)) +} + +// trySend sends ReportData to all internal registered DataSinks. +func (tsr *RemoteTopSQLReporter) trySend(data *ReportData, deadline time.Time) error { + tsr.DefaultDataSinkRegisterer.Lock() + dataSinks := make([]DataSink, 0, len(tsr.dataSinks)) + for ds := range tsr.dataSinks { + dataSinks = append(dataSinks, ds) + } + tsr.DefaultDataSinkRegisterer.Unlock() + for _, ds := range dataSinks { + if err := ds.TrySend(data, deadline); err != nil { + logutil.BgLogger().Warn("failed to send data to datasink", zap.String("category", "top-sql"), zap.Error(err)) + } + } + return nil +} + +// onReporterClosing calls the OnReporterClosing method of all internally registered DataSinks. +func (tsr *RemoteTopSQLReporter) onReporterClosing() { + var m map[DataSink]struct{} + tsr.DefaultDataSinkRegisterer.Lock() + m, tsr.dataSinks = tsr.dataSinks, make(map[DataSink]struct{}) + tsr.DefaultDataSinkRegisterer.Unlock() + for d := range m { + d.OnReporterClosing() + } +} + +// collectedData is used for transmission in the channel. 
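+// It bundles one reporting round: the records taken from the collecting state
+// together with the SQL and plan meta maps captured at the same moment, so
+// reportWorker can build a consistent ReportData from a single value.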
+type collectedData struct { + collected *collecting + normalizedSQLMap *normalizedSQLMap + normalizedPlanMap *normalizedPlanMap +} diff --git a/pkg/util/topsql/topsql.go b/pkg/util/topsql/topsql.go index d6818051125c7..daa237f229fb7 100644 --- a/pkg/util/topsql/topsql.go +++ b/pkg/util/topsql/topsql.go @@ -106,7 +106,7 @@ func AttachAndRegisterSQLInfo(ctx context.Context, normalizedSQL string, sqlDige linkSQLTextWithDigest(sqlDigestBytes, normalizedSQL, isInternal) - failpoint.Inject("mockHighLoadForEachSQL", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockHighLoadForEachSQL")); _err_ == nil { // In integration test, some SQL run very fast that Top SQL pprof profile unable to sample data of those SQL, // So need mock some high cpu load to make sure pprof profile successfully samples the data of those SQL. // Attention: Top SQL pprof profile unable to sample data of those SQL which run very fast, this behavior is expected. @@ -118,7 +118,7 @@ func AttachAndRegisterSQLInfo(ctx context.Context, normalizedSQL string, sqlDige logutil.BgLogger().Info("attach SQL info", zap.String("sql", normalizedSQL)) } } - }) + } return ctx } @@ -135,14 +135,14 @@ func AttachSQLAndPlanInfo(ctx context.Context, sqlDigest *parser.Digest, planDig ctx = collector.CtxWithSQLAndPlanDigest(ctx, sqlDigestStr, planDigestStr) pprof.SetGoroutineLabels(ctx) - failpoint.Inject("mockHighLoadForEachPlan", func(val failpoint.Value) { + if val, _err_ := failpoint.Eval(_curpkg_("mockHighLoadForEachPlan")); _err_ == nil { // Work like mockHighLoadForEachSQL failpoint. if val.(bool) { if MockHighCPULoad("", []string{""}, 1) { logutil.BgLogger().Info("attach SQL info") } } - }) + } return ctx } diff --git a/pkg/util/topsql/topsql.go__failpoint_stash__ b/pkg/util/topsql/topsql.go__failpoint_stash__ new file mode 100644 index 0000000000000..d6818051125c7 --- /dev/null +++ b/pkg/util/topsql/topsql.go__failpoint_stash__ @@ -0,0 +1,187 @@ +// Copyright 2021 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package topsql + +import ( + "context" + "runtime/pprof" + "strings" + "time" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/util/logutil" + "github.com/pingcap/tidb/pkg/util/plancodec" + "github.com/pingcap/tidb/pkg/util/topsql/collector" + "github.com/pingcap/tidb/pkg/util/topsql/reporter" + "github.com/pingcap/tidb/pkg/util/topsql/stmtstats" + "github.com/pingcap/tipb/go-tipb" + "go.uber.org/zap" + "google.golang.org/grpc" +) + +const ( + // MaxSQLTextSize exports for testing. + MaxSQLTextSize = 4 * 1024 + // MaxBinaryPlanSize exports for testing. 
+ MaxBinaryPlanSize = 2 * 1024 +) + +var ( + globalTopSQLReport reporter.TopSQLReporter + singleTargetDataSink *reporter.SingleTargetDataSink +) + +func init() { + remoteReporter := reporter.NewRemoteTopSQLReporter(plancodec.DecodeNormalizedPlan, plancodec.Compress) + globalTopSQLReport = remoteReporter + singleTargetDataSink = reporter.NewSingleTargetDataSink(remoteReporter) +} + +// SetupTopSQL sets up the top-sql worker. +func SetupTopSQL() { + globalTopSQLReport.Start() + singleTargetDataSink.Start() + + stmtstats.RegisterCollector(globalTopSQLReport) + stmtstats.SetupAggregator() +} + +// SetupTopSQLForTest sets up the global top-sql reporter, it's exporting for test. +func SetupTopSQLForTest(r reporter.TopSQLReporter) { + globalTopSQLReport = r +} + +// RegisterPubSubServer registers TopSQLPubSubService to the given gRPC server. +func RegisterPubSubServer(s *grpc.Server) { + if register, ok := globalTopSQLReport.(reporter.DataSinkRegisterer); ok { + service := reporter.NewTopSQLPubSubService(register) + tipb.RegisterTopSQLPubSubServer(s, service) + } +} + +// Close uses to close and release the top sql resource. +func Close() { + singleTargetDataSink.Close() + globalTopSQLReport.Close() + stmtstats.CloseAggregator() +} + +// RegisterSQL uses to register SQL information into Top SQL. +func RegisterSQL(normalizedSQL string, sqlDigest *parser.Digest, isInternal bool) { + if sqlDigest != nil { + sqlDigestBytes := sqlDigest.Bytes() + linkSQLTextWithDigest(sqlDigestBytes, normalizedSQL, isInternal) + } +} + +// RegisterPlan uses to register plan information into Top SQL. +func RegisterPlan(normalizedPlan string, planDigest *parser.Digest) { + if planDigest != nil { + planDigestBytes := planDigest.Bytes() + linkPlanTextWithDigest(planDigestBytes, normalizedPlan) + } +} + +// AttachAndRegisterSQLInfo attach the sql information into Top SQL and register the SQL meta information. +func AttachAndRegisterSQLInfo(ctx context.Context, normalizedSQL string, sqlDigest *parser.Digest, isInternal bool) context.Context { + if sqlDigest == nil || len(sqlDigest.String()) == 0 { + return ctx + } + sqlDigestBytes := sqlDigest.Bytes() + ctx = collector.CtxWithSQLDigest(ctx, sqlDigest.String()) + pprof.SetGoroutineLabels(ctx) + + linkSQLTextWithDigest(sqlDigestBytes, normalizedSQL, isInternal) + + failpoint.Inject("mockHighLoadForEachSQL", func(val failpoint.Value) { + // In integration test, some SQL run very fast that Top SQL pprof profile unable to sample data of those SQL, + // So need mock some high cpu load to make sure pprof profile successfully samples the data of those SQL. + // Attention: Top SQL pprof profile unable to sample data of those SQL which run very fast, this behavior is expected. + // The integration test was just want to make sure each type of SQL will be set goroutine labels and and can be collected. 
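+ // Tests typically arm this with failpoint.Enable using the package path plus the
+ // failpoint name (for example "github.com/pingcap/tidb/pkg/util/topsql/mockHighLoadForEachSQL");
+ // the exact path shown here is illustrative.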
+ if val.(bool) { + sqlPrefixes := []string{"insert", "update", "delete", "load", "replace", "select", "begin", + "commit", "analyze", "explain", "trace", "create", "set global"} + if MockHighCPULoad(normalizedSQL, sqlPrefixes, 1) { + logutil.BgLogger().Info("attach SQL info", zap.String("sql", normalizedSQL)) + } + } + }) + return ctx +} + +// AttachSQLAndPlanInfo attach the sql and plan information into Top SQL +func AttachSQLAndPlanInfo(ctx context.Context, sqlDigest *parser.Digest, planDigest *parser.Digest) context.Context { + if sqlDigest == nil || len(sqlDigest.String()) == 0 { + return ctx + } + var planDigestStr string + sqlDigestStr := sqlDigest.String() + if planDigest != nil { + planDigestStr = planDigest.String() + } + ctx = collector.CtxWithSQLAndPlanDigest(ctx, sqlDigestStr, planDigestStr) + pprof.SetGoroutineLabels(ctx) + + failpoint.Inject("mockHighLoadForEachPlan", func(val failpoint.Value) { + // Work like mockHighLoadForEachSQL failpoint. + if val.(bool) { + if MockHighCPULoad("", []string{""}, 1) { + logutil.BgLogger().Info("attach SQL info") + } + } + }) + return ctx +} + +// MockHighCPULoad mocks high cpu load, only use in failpoint test. +func MockHighCPULoad(sql string, sqlPrefixs []string, load int64) bool { + lowerSQL := strings.ToLower(sql) + if strings.Contains(lowerSQL, "mysql") && !strings.Contains(lowerSQL, "global_variables") { + return false + } + match := false + for _, prefix := range sqlPrefixs { + if strings.HasPrefix(lowerSQL, prefix) { + match = true + break + } + } + if !match { + return false + } + start := time.Now() + for { + if time.Since(start) > 12*time.Millisecond*time.Duration(load) { + break + } + for i := 0; i < 10e5; i++ { + continue + } + } + return true +} + +func linkSQLTextWithDigest(sqlDigest []byte, normalizedSQL string, isInternal bool) { + if len(normalizedSQL) > MaxSQLTextSize { + normalizedSQL = normalizedSQL[:MaxSQLTextSize] + } + + globalTopSQLReport.RegisterSQL(sqlDigest, normalizedSQL, isInternal) +} + +func linkPlanTextWithDigest(planDigest []byte, normalizedBinaryPlan string) { + globalTopSQLReport.RegisterPlan(planDigest, normalizedBinaryPlan, len(normalizedBinaryPlan) > MaxBinaryPlanSize) +} From a6ce64b48d96fd56b70c7ef8f238fc5ae3dc1c13 Mon Sep 17 00:00:00 2001 From: tpp Date: Wed, 7 Aug 2024 14:12:01 -0500 Subject: [PATCH 07/35] testcase updates6 --- br/pkg/backup/binding__failpoint_binding__.go | 14 - .../binding__failpoint_binding__.go | 14 - br/pkg/backup/prepare_snap/prepare.go | 6 +- .../prepare.go__failpoint_stash__ | 484 -- br/pkg/backup/store.go | 32 +- br/pkg/backup/store.go__failpoint_stash__ | 307 - .../binding__failpoint_binding__.go | 14 - br/pkg/checkpoint/checkpoint.go | 22 +- .../checkpoint.go__failpoint_stash__ | 872 -- .../checksum/binding__failpoint_binding__.go | 14 - br/pkg/checksum/executor.go | 4 +- .../checksum/executor.go__failpoint_stash__ | 419 - br/pkg/conn/binding__failpoint_binding__.go | 14 - br/pkg/conn/conn.go | 18 +- br/pkg/conn/conn.go__failpoint_stash__ | 457 -- br/pkg/pdutil/binding__failpoint_binding__.go | 14 - br/pkg/pdutil/pd.go | 4 +- br/pkg/pdutil/pd.go__failpoint_stash__ | 782 -- .../restore/binding__failpoint_binding__.go | 14 - .../binding__failpoint_binding__.go | 14 - br/pkg/restore/log_client/client.go | 30 +- .../log_client/client.go__failpoint_stash__ | 1689 ---- br/pkg/restore/log_client/import_retry.go | 4 +- .../import_retry.go__failpoint_stash__ | 284 - br/pkg/restore/misc.go | 4 +- br/pkg/restore/misc.go__failpoint_stash__ | 157 - 
.../binding__failpoint_binding__.go | 14 - br/pkg/restore/snap_client/client.go | 10 +- .../snap_client/client.go__failpoint_stash__ | 1200 --- br/pkg/restore/snap_client/context_manager.go | 4 +- .../context_manager.go__failpoint_stash__ | 290 - br/pkg/restore/snap_client/import.go | 8 +- .../snap_client/import.go__failpoint_stash__ | 846 -- .../split/binding__failpoint_binding__.go | 14 - br/pkg/restore/split/client.go | 26 +- .../split/client.go__failpoint_stash__ | 1067 --- br/pkg/restore/split/split.go | 4 +- .../restore/split/split.go__failpoint_stash__ | 352 - .../storage/binding__failpoint_binding__.go | 14 - br/pkg/storage/s3.go | 8 +- br/pkg/storage/s3.go__failpoint_stash__ | 1208 --- br/pkg/streamhelper/advancer.go | 12 +- .../advancer.go__failpoint_stash__ | 735 -- br/pkg/streamhelper/advancer_cliext.go | 8 +- .../advancer_cliext.go__failpoint_stash__ | 301 - .../binding__failpoint_binding__.go | 14 - br/pkg/streamhelper/flush_subscriber.go | 6 +- .../flush_subscriber.go__failpoint_stash__ | 373 - br/pkg/task/backup.go | 8 +- br/pkg/task/backup.go__failpoint_stash__ | 835 -- br/pkg/task/binding__failpoint_binding__.go | 14 - .../operator/binding__failpoint_binding__.go | 14 - br/pkg/task/operator/cmd.go | 6 +- .../task/operator/cmd.go__failpoint_stash__ | 251 - br/pkg/task/restore.go | 4 +- br/pkg/task/restore.go__failpoint_stash__ | 1713 ---- br/pkg/task/stream.go | 12 +- br/pkg/task/stream.go__failpoint_stash__ | 1846 ----- br/pkg/utils/backoff.go | 4 +- br/pkg/utils/backoff.go__failpoint_stash__ | 323 - br/pkg/utils/binding__failpoint_binding__.go | 14 - br/pkg/utils/pprof.go | 4 +- br/pkg/utils/pprof.go__failpoint_stash__ | 69 - br/pkg/utils/register.go | 16 +- br/pkg/utils/register.go__failpoint_stash__ | 334 - br/pkg/utils/store_manager.go | 4 +- .../utils/store_manager.go__failpoint_stash__ | 264 - .../export/binding__failpoint_binding__.go | 14 - dumpling/export/config.go | 4 +- dumpling/export/config.go__failpoint_stash__ | 809 -- dumpling/export/dump.go | 26 +- dumpling/export/dump.go__failpoint_stash__ | 1704 ---- dumpling/export/sql.go | 12 +- dumpling/export/sql.go__failpoint_stash__ | 1643 ---- dumpling/export/status.go | 4 +- dumpling/export/status.go__failpoint_stash__ | 144 - dumpling/export/writer_util.go | 16 +- .../export/writer_util.go__failpoint_stash__ | 674 -- .../importer/binding__failpoint_binding__.go | 14 - lightning/pkg/importer/chunk_process.go | 14 +- .../chunk_process.go__failpoint_stash__ | 778 -- lightning/pkg/importer/get_pre_info.go | 35 +- .../get_pre_info.go__failpoint_stash__ | 835 -- lightning/pkg/importer/import.go | 30 +- .../pkg/importer/import.go__failpoint_stash__ | 2080 ----- lightning/pkg/importer/table_import.go | 40 +- .../table_import.go__failpoint_stash__ | 1822 ----- .../server/binding__failpoint_binding__.go | 14 - lightning/pkg/server/lightning.go | 40 +- .../server/lightning.go__failpoint_stash__ | 1152 --- pkg/autoid_service/autoid.go | 6 +- .../autoid.go__failpoint_stash__ | 612 -- .../binding__failpoint_binding__.go | 14 - pkg/bindinfo/binding__failpoint_binding__.go | 14 - pkg/bindinfo/global_handle.go | 4 +- .../global_handle.go__failpoint_stash__ | 745 -- pkg/ddl/add_column.go | 8 +- pkg/ddl/add_column.go__failpoint_stash__ | 1288 --- pkg/ddl/backfilling.go | 26 +- pkg/ddl/backfilling.go__failpoint_stash__ | 1124 --- pkg/ddl/backfilling_dist_scheduler.go | 22 +- ...lling_dist_scheduler.go__failpoint_stash__ | 639 -- pkg/ddl/backfilling_operators.go | 34 +- ...ackfilling_operators.go__failpoint_stash__ | 962 --- 
pkg/ddl/backfilling_read_index.go | 2 +- ...ckfilling_read_index.go__failpoint_stash__ | 315 - pkg/ddl/binding__failpoint_binding__.go | 14 - pkg/ddl/cluster.go | 16 +- pkg/ddl/cluster.go__failpoint_stash__ | 902 -- pkg/ddl/column.go | 22 +- pkg/ddl/column.go__failpoint_stash__ | 1320 --- pkg/ddl/constraint.go | 12 +- pkg/ddl/constraint.go__failpoint_stash__ | 432 - pkg/ddl/create_table.go | 14 +- pkg/ddl/create_table.go__failpoint_stash__ | 1527 ---- pkg/ddl/ddl.go | 12 +- pkg/ddl/ddl.go__failpoint_stash__ | 1421 ---- pkg/ddl/ddl_tiflash_api.go | 32 +- pkg/ddl/ddl_tiflash_api.go__failpoint_stash__ | 608 -- pkg/ddl/delete_range.go | 4 +- pkg/ddl/delete_range.go__failpoint_stash__ | 548 -- pkg/ddl/executor.go | 32 +- pkg/ddl/executor.go__failpoint_stash__ | 6540 --------------- pkg/ddl/index.go | 72 +- pkg/ddl/index.go__failpoint_stash__ | 2616 ------ pkg/ddl/index_cop.go | 8 +- pkg/ddl/index_cop.go__failpoint_stash__ | 392 - pkg/ddl/index_merge_tmp.go | 4 +- pkg/ddl/index_merge_tmp.go__failpoint_stash__ | 400 - pkg/ddl/ingest/backend.go | 16 +- pkg/ddl/ingest/backend.go__failpoint_stash__ | 343 - pkg/ddl/ingest/backend_mgr.go | 4 +- .../ingest/backend_mgr.go__failpoint_stash__ | 285 - .../ingest/binding__failpoint_binding__.go | 14 - pkg/ddl/ingest/checkpoint.go | 16 +- .../ingest/checkpoint.go__failpoint_stash__ | 509 -- pkg/ddl/ingest/disk_root.go | 6 +- .../ingest/disk_root.go__failpoint_stash__ | 157 - pkg/ddl/ingest/env.go | 4 +- pkg/ddl/ingest/env.go__failpoint_stash__ | 189 - pkg/ddl/ingest/mock.go | 2 +- pkg/ddl/ingest/mock.go__failpoint_stash__ | 236 - pkg/ddl/job_scheduler.go | 34 +- pkg/ddl/job_scheduler.go__failpoint_stash__ | 837 -- pkg/ddl/job_submitter.go | 18 +- pkg/ddl/job_submitter.go__failpoint_stash__ | 669 -- pkg/ddl/job_worker.go | 32 +- pkg/ddl/job_worker.go__failpoint_stash__ | 1013 --- pkg/ddl/mock.go | 20 +- pkg/ddl/mock.go__failpoint_stash__ | 260 - pkg/ddl/modify_column.go | 18 +- pkg/ddl/modify_column.go__failpoint_stash__ | 1318 --- pkg/ddl/partition.go | 54 +- pkg/ddl/partition.go__failpoint_stash__ | 4922 ----------- .../placement/binding__failpoint_binding__.go | 14 - pkg/ddl/placement/bundle.go | 4 +- .../placement/bundle.go__failpoint_stash__ | 712 -- pkg/ddl/reorg.go | 18 +- pkg/ddl/reorg.go__failpoint_stash__ | 982 --- pkg/ddl/rollingback.go | 6 +- pkg/ddl/rollingback.go__failpoint_stash__ | 629 -- pkg/ddl/schema_version.go | 8 +- pkg/ddl/schema_version.go__failpoint_stash__ | 418 - .../session/binding__failpoint_binding__.go | 14 - pkg/ddl/session/session.go | 4 +- pkg/ddl/session/session.go__failpoint_stash__ | 137 - .../syncer/binding__failpoint_binding__.go | 14 - pkg/ddl/syncer/syncer.go | 8 +- pkg/ddl/syncer/syncer.go__failpoint_stash__ | 629 -- pkg/ddl/table.go | 36 +- pkg/ddl/table.go__failpoint_stash__ | 1681 ---- pkg/ddl/util/binding__failpoint_binding__.go | 14 - pkg/ddl/util/util.go | 4 +- pkg/ddl/util/util.go__failpoint_stash__ | 427 - pkg/distsql/binding__failpoint_binding__.go | 14 - pkg/distsql/request_builder.go | 8 +- .../request_builder.go__failpoint_stash__ | 862 -- pkg/distsql/select_result.go | 6 +- .../select_result.go__failpoint_stash__ | 815 -- .../scheduler/binding__failpoint_binding__.go | 14 - pkg/disttask/framework/scheduler/nodes.go | 4 +- .../scheduler/nodes.go__failpoint_stash__ | 191 - pkg/disttask/framework/scheduler/scheduler.go | 32 +- .../scheduler/scheduler.go__failpoint_stash__ | 624 -- .../framework/scheduler/scheduler_manager.go | 14 +- .../scheduler_manager.go__failpoint_stash__ | 485 -- 
.../storage/binding__failpoint_binding__.go | 14 - pkg/disttask/framework/storage/history.go | 4 +- .../storage/history.go__failpoint_stash__ | 96 - pkg/disttask/framework/storage/task_table.go | 8 +- .../storage/task_table.go__failpoint_stash__ | 809 -- .../binding__failpoint_binding__.go | 14 - .../framework/taskexecutor/task_executor.go | 38 +- .../task_executor.go__failpoint_stash__ | 683 -- .../binding__failpoint_binding__.go | 14 - pkg/disttask/importinto/planner.go | 14 +- .../importinto/planner.go__failpoint_stash__ | 528 -- pkg/disttask/importinto/scheduler.go | 20 +- .../scheduler.go__failpoint_stash__ | 719 -- pkg/disttask/importinto/subtask_executor.go | 14 +- .../subtask_executor.go__failpoint_stash__ | 142 - pkg/disttask/importinto/task_executor.go | 4 +- .../task_executor.go__failpoint_stash__ | 577 -- pkg/domain/binding__failpoint_binding__.go | 14 - pkg/domain/domain.go | 38 +- pkg/domain/domain.go__failpoint_stash__ | 3239 -------- pkg/domain/historical_stats.go | 4 +- .../historical_stats.go__failpoint_stash__ | 98 - .../infosync/binding__failpoint_binding__.go | 14 - pkg/domain/infosync/info.go | 32 +- .../infosync/info.go__failpoint_stash__ | 1457 ---- pkg/domain/infosync/tiflash_manager.go | 4 +- .../tiflash_manager.go__failpoint_stash__ | 893 -- pkg/domain/plan_replayer_dump.go | 4 +- .../plan_replayer_dump.go__failpoint_stash__ | 923 --- pkg/domain/runaway.go | 12 +- pkg/domain/runaway.go__failpoint_stash__ | 659 -- pkg/executor/adapter.go | 40 +- pkg/executor/adapter.go__failpoint_stash__ | 2240 ----- pkg/executor/aggregate/agg_hash_executor.go | 34 +- .../agg_hash_executor.go__failpoint_stash__ | 857 -- .../aggregate/agg_hash_final_worker.go | 12 +- ...gg_hash_final_worker.go__failpoint_stash__ | 283 - .../aggregate/agg_hash_partial_worker.go | 14 +- ..._hash_partial_worker.go__failpoint_stash__ | 390 - pkg/executor/aggregate/agg_stream_executor.go | 14 +- .../agg_stream_executor.go__failpoint_stash__ | 237 - pkg/executor/aggregate/agg_util.go | 4 +- .../aggregate/agg_util.go__failpoint_stash__ | 312 - .../aggregate/binding__failpoint_binding__.go | 14 - pkg/executor/analyze.go | 12 +- pkg/executor/analyze.go__failpoint_stash__ | 619 -- pkg/executor/analyze_col.go | 8 +- .../analyze_col.go__failpoint_stash__ | 494 -- pkg/executor/analyze_col_v2.go | 20 +- .../analyze_col_v2.go__failpoint_stash__ | 885 -- pkg/executor/analyze_idx.go | 14 +- .../analyze_idx.go__failpoint_stash__ | 344 - pkg/executor/batch_point_get.go | 6 +- .../batch_point_get.go__failpoint_stash__ | 528 -- pkg/executor/binding__failpoint_binding__.go | 14 - pkg/executor/brie.go | 10 +- pkg/executor/brie.go__failpoint_stash__ | 826 -- pkg/executor/builder.go | 42 +- pkg/executor/builder.go__failpoint_stash__ | 5659 ------------- pkg/executor/checksum.go | 2 +- pkg/executor/checksum.go__failpoint_stash__ | 336 - pkg/executor/compiler.go | 8 +- pkg/executor/compiler.go__failpoint_stash__ | 568 -- pkg/executor/cte.go | 20 +- pkg/executor/cte.go__failpoint_stash__ | 770 -- pkg/executor/executor.go | 6 +- pkg/executor/executor.go__failpoint_stash__ | 2673 ------ pkg/executor/import_into.go | 2 +- .../import_into.go__failpoint_stash__ | 344 - .../importer/binding__failpoint_binding__.go | 14 - pkg/executor/importer/job.go | 4 +- .../importer/job.go__failpoint_stash__ | 370 - pkg/executor/importer/table_import.go | 16 +- .../table_import.go__failpoint_stash__ | 983 --- pkg/executor/index_merge_reader.go | 72 +- .../index_merge_reader.go__failpoint_stash__ | 2056 ----- pkg/executor/infoschema_reader.go | 4 
+- .../infoschema_reader.go__failpoint_stash__ | 3878 --------- pkg/executor/inspection_result.go | 4 +- .../inspection_result.go__failpoint_stash__ | 1248 --- .../binding__failpoint_binding__.go | 14 - .../calibrateresource/calibrate_resource.go | 12 +- .../calibrate_resource.go__failpoint_stash__ | 704 -- .../exec/binding__failpoint_binding__.go | 14 - pkg/executor/internal/exec/executor.go | 12 +- .../exec/executor.go__failpoint_stash__ | 468 -- .../mpp/binding__failpoint_binding__.go | 14 - .../internal/mpp/executor_with_retry.go | 16 +- .../executor_with_retry.go__failpoint_stash__ | 254 - .../internal/mpp/local_mpp_coordinator.go | 18 +- ...ocal_mpp_coordinator.go__failpoint_stash__ | 772 -- .../pdhelper/binding__failpoint_binding__.go | 14 - pkg/executor/internal/pdhelper/pd.go | 4 +- .../pdhelper/pd.go__failpoint_stash__ | 128 - pkg/executor/join/base_join_probe.go | 4 +- .../base_join_probe.go__failpoint_stash__ | 593 -- .../join/binding__failpoint_binding__.go | 14 - pkg/executor/join/hash_join_base.go | 26 +- .../join/hash_join_base.go__failpoint_stash__ | 379 - pkg/executor/join/hash_join_v1.go | 16 +- .../join/hash_join_v1.go__failpoint_stash__ | 1434 ---- pkg/executor/join/hash_join_v2.go | 12 +- .../join/hash_join_v2.go__failpoint_stash__ | 943 --- pkg/executor/join/index_lookup_hash_join.go | 34 +- ...dex_lookup_hash_join.go__failpoint_stash__ | 884 -- pkg/executor/join/index_lookup_join.go | 14 +- .../index_lookup_join.go__failpoint_stash__ | 882 -- pkg/executor/join/index_lookup_merge_join.go | 6 +- ...ex_lookup_merge_join.go__failpoint_stash__ | 743 -- pkg/executor/join/merge_join.go | 10 +- .../join/merge_join.go__failpoint_stash__ | 420 - pkg/executor/load_data.go | 8 +- pkg/executor/load_data.go__failpoint_stash__ | 780 -- pkg/executor/memtable_reader.go | 8 +- .../memtable_reader.go__failpoint_stash__ | 1009 --- pkg/executor/metrics_reader.go | 12 +- .../metrics_reader.go__failpoint_stash__ | 365 - pkg/executor/parallel_apply.go | 8 +- .../parallel_apply.go__failpoint_stash__ | 405 - pkg/executor/point_get.go | 10 +- pkg/executor/point_get.go__failpoint_stash__ | 824 -- pkg/executor/projection.go | 16 +- pkg/executor/projection.go__failpoint_stash__ | 501 -- pkg/executor/shuffle.go | 12 +- pkg/executor/shuffle.go__failpoint_stash__ | 492 -- pkg/executor/slow_query.go | 8 +- pkg/executor/slow_query.go__failpoint_stash__ | 1259 --- .../sortexec/binding__failpoint_binding__.go | 14 - pkg/executor/sortexec/parallel_sort_worker.go | 8 +- ...parallel_sort_worker.go__failpoint_stash__ | 229 - pkg/executor/sortexec/sort.go | 16 +- .../sortexec/sort.go__failpoint_stash__ | 845 -- pkg/executor/sortexec/sort_partition.go | 8 +- .../sort_partition.go__failpoint_stash__ | 367 - pkg/executor/sortexec/sort_util.go | 4 +- .../sortexec/sort_util.go__failpoint_stash__ | 124 - pkg/executor/sortexec/topn.go | 4 +- .../sortexec/topn.go__failpoint_stash__ | 647 -- pkg/executor/sortexec/topn_worker.go | 4 +- .../topn_worker.go__failpoint_stash__ | 130 - pkg/executor/table_reader.go | 4 +- .../table_reader.go__failpoint_stash__ | 632 -- .../unionexec/binding__failpoint_binding__.go | 14 - pkg/executor/unionexec/union.go | 12 +- .../unionexec/union.go__failpoint_stash__ | 232 - .../binding__failpoint_binding__.go | 14 - pkg/expression/aggregation/explain.go | 4 +- .../aggregation/explain.go__failpoint_stash__ | 80 - .../binding__failpoint_binding__.go | 14 - pkg/expression/builtin_json.go | 6 +- .../builtin_json.go__failpoint_stash__ | 1940 ----- pkg/expression/builtin_time.go | 8 +- 
.../builtin_time.go__failpoint_stash__ | 6832 ---------------- pkg/expression/expr_to_pb.go | 4 +- .../expr_to_pb.go__failpoint_stash__ | 319 - pkg/expression/helper.go | 6 +- pkg/expression/helper.go__failpoint_stash__ | 170 - pkg/expression/infer_pushdown.go | 14 +- .../infer_pushdown.go__failpoint_stash__ | 536 -- pkg/expression/util.go | 6 +- pkg/expression/util.go__failpoint_stash__ | 1921 ----- .../binding__failpoint_binding__.go | 14 - pkg/infoschema/builder.go | 6 +- pkg/infoschema/builder.go__failpoint_stash__ | 1040 --- pkg/infoschema/infoschema_v2.go | 6 +- .../infoschema_v2.go__failpoint_stash__ | 1456 ---- .../binding__failpoint_binding__.go | 14 - pkg/infoschema/perfschema/tables.go | 4 +- .../perfschema/tables.go__failpoint_stash__ | 415 - pkg/infoschema/sieve.go | 6 +- pkg/infoschema/sieve.go__failpoint_stash__ | 272 - pkg/infoschema/tables.go | 22 +- pkg/infoschema/tables.go__failpoint_stash__ | 2694 ------ pkg/kv/binding__failpoint_binding__.go | 14 - pkg/kv/txn.go | 10 +- pkg/kv/txn.go__failpoint_stash__ | 247 - pkg/lightning/backend/backend.go | 4 +- .../backend/backend.go__failpoint_stash__ | 439 - .../backend/binding__failpoint_binding__.go | 14 - .../external/binding__failpoint_binding__.go | 14 - pkg/lightning/backend/external/byte_reader.go | 4 +- .../byte_reader.go__failpoint_stash__ | 351 - pkg/lightning/backend/external/engine.go | 4 +- .../external/engine.go__failpoint_stash__ | 732 -- pkg/lightning/backend/external/merge_v2.go | 4 +- .../external/merge_v2.go__failpoint_stash__ | 183 - .../local/binding__failpoint_binding__.go | 14 - pkg/lightning/backend/local/checksum.go | 2 +- .../local/checksum.go__failpoint_stash__ | 517 -- pkg/lightning/backend/local/engine.go | 12 +- .../local/engine.go__failpoint_stash__ | 1682 ---- pkg/lightning/backend/local/engine_mgr.go | 4 +- .../local/engine_mgr.go__failpoint_stash__ | 658 -- pkg/lightning/backend/local/local.go | 50 +- .../backend/local/local.go__failpoint_stash__ | 1754 ---- pkg/lightning/backend/local/local_unix.go | 8 +- .../local/local_unix.go__failpoint_stash__ | 92 - pkg/lightning/backend/local/region_job.go | 36 +- .../local/region_job.go__failpoint_stash__ | 907 -- .../tidb/binding__failpoint_binding__.go | 14 - pkg/lightning/backend/tidb/tidb.go | 15 +- .../backend/tidb/tidb.go__failpoint_stash__ | 956 --- .../common/binding__failpoint_binding__.go | 14 - pkg/lightning/common/storage_unix.go | 6 +- .../common/storage_unix.go__failpoint_stash__ | 79 - pkg/lightning/common/storage_windows.go | 6 +- .../storage_windows.go__failpoint_stash__ | 56 - pkg/lightning/common/util.go | 8 +- .../common/util.go__failpoint_stash__ | 704 -- .../mydump/binding__failpoint_binding__.go | 14 - pkg/lightning/mydump/loader.go | 8 +- .../mydump/loader.go__failpoint_stash__ | 868 -- pkg/meta/autoid/autoid.go | 16 +- pkg/meta/autoid/autoid.go__failpoint_stash__ | 1351 --- .../autoid/binding__failpoint_binding__.go | 14 - pkg/owner/binding__failpoint_binding__.go | 14 - pkg/owner/manager.go | 10 +- pkg/owner/manager.go__failpoint_stash__ | 486 -- pkg/owner/mock.go | 6 +- pkg/owner/mock.go__failpoint_stash__ | 230 - .../ast/binding__failpoint_binding__.go | 14 - pkg/parser/ast/misc.go | 4 +- pkg/parser/ast/misc.go__failpoint_stash__ | 4209 ---------- pkg/planner/binding__failpoint_binding__.go | 14 - .../binding__failpoint_binding__.go | 14 - pkg/planner/cardinality/row_count_index.go | 4 +- .../row_count_index.go__failpoint_stash__ | 568 -- .../core/binding__failpoint_binding__.go | 14 - .../core/collect_column_stats_usage.go | 
4 +- ...t_column_stats_usage.go__failpoint_stash__ | 456 -- pkg/planner/core/debugtrace.go | 4 +- .../core/debugtrace.go__failpoint_stash__ | 261 - pkg/planner/core/encode.go | 8 +- pkg/planner/core/encode.go__failpoint_stash__ | 386 - pkg/planner/core/exhaust_physical_plans.go | 16 +- ...haust_physical_plans.go__failpoint_stash__ | 3004 ------- pkg/planner/core/find_best_task.go | 6 +- .../core/find_best_task.go__failpoint_stash__ | 2982 ------- pkg/planner/core/logical_join.go | 6 +- .../core/logical_join.go__failpoint_stash__ | 1672 ---- pkg/planner/core/logical_plan_builder.go | 4 +- ...logical_plan_builder.go__failpoint_stash__ | 7284 ----------------- pkg/planner/core/optimizer.go | 8 +- .../core/optimizer.go__failpoint_stash__ | 1222 --- pkg/planner/core/rule_collect_plan_stats.go | 4 +- ...e_collect_plan_stats.go__failpoint_stash__ | 316 - pkg/planner/core/rule_eliminate_projection.go | 6 +- ...eliminate_projection.go__failpoint_stash__ | 274 - .../core/rule_inject_extra_projection.go | 6 +- ...ect_extra_projection.go__failpoint_stash__ | 352 - pkg/planner/core/task.go | 4 +- pkg/planner/core/task.go__failpoint_stash__ | 2473 ------ pkg/planner/optimize.go | 18 +- pkg/planner/optimize.go__failpoint_stash__ | 631 -- pkg/server/binding__failpoint_binding__.go | 14 - pkg/server/conn.go | 38 +- pkg/server/conn.go__failpoint_stash__ | 2748 ------- pkg/server/conn_stmt.go | 10 +- pkg/server/conn_stmt.go__failpoint_stash__ | 673 -- .../handler/binding__failpoint_binding__.go | 14 - .../binding__failpoint_binding__.go | 14 - .../handler/extractorhandler/extractor.go | 6 +- .../extractor.go__failpoint_stash__ | 169 - pkg/server/handler/tikv_handler.go | 4 +- .../tikv_handler.go__failpoint_stash__ | 279 - pkg/server/http_status.go | 4 +- pkg/server/http_status.go__failpoint_stash__ | 613 -- pkg/session/binding__failpoint_binding__.go | 14 - pkg/session/nontransactional.go | 4 +- .../nontransactional.go__failpoint_stash__ | 847 -- pkg/session/session.go | 36 +- pkg/session/session.go__failpoint_stash__ | 4611 ----------- pkg/session/sync_upgrade.go | 6 +- .../sync_upgrade.go__failpoint_stash__ | 163 - pkg/session/tidb.go | 6 +- pkg/session/tidb.go__failpoint_stash__ | 403 - pkg/session/txn.go | 26 +- pkg/session/txn.go__failpoint_stash__ | 778 -- pkg/session/txnmanager.go | 4 +- pkg/session/txnmanager.go__failpoint_stash__ | 381 - .../binding__failpoint_binding__.go | 14 - pkg/sessionctx/sessionstates/session_token.go | 4 +- .../session_token.go__failpoint_stash__ | 336 - pkg/sessiontxn/isolation/base.go | 4 +- .../isolation/base.go__failpoint_stash__ | 747 -- .../isolation/binding__failpoint_binding__.go | 14 - pkg/sessiontxn/isolation/readcommitted.go | 12 +- .../readcommitted.go__failpoint_stash__ | 360 - pkg/sessiontxn/isolation/repeatable_read.go | 4 +- .../repeatable_read.go__failpoint_stash__ | 309 - .../staleread/binding__failpoint_binding__.go | 14 - pkg/sessiontxn/staleread/util.go | 4 +- .../staleread/util.go__failpoint_stash__ | 97 - .../binding__failpoint_binding__.go | 14 - pkg/statistics/cmsketch.go | 10 +- pkg/statistics/cmsketch.go__failpoint_stash__ | 865 -- .../handle/autoanalyze/autoanalyze.go | 16 +- .../autoanalyze.go__failpoint_stash__ | 843 -- .../binding__failpoint_binding__.go | 14 - .../handle/binding__failpoint_binding__.go | 14 - pkg/statistics/handle/bootstrap.go | 4 +- .../handle/bootstrap.go__failpoint_stash__ | 815 -- .../cache/binding__failpoint_binding__.go | 14 - pkg/statistics/handle/cache/statscache.go | 6 +- .../cache/statscache.go__failpoint_stash__ | 
287 - .../binding__failpoint_binding__.go | 14 - .../handle/globalstats/global_stats_async.go | 36 +- .../global_stats_async.go__failpoint_stash__ | 542 -- .../storage/binding__failpoint_binding__.go | 14 - pkg/statistics/handle/storage/read.go | 6 +- .../handle/storage/read.go__failpoint_stash__ | 759 -- .../syncload/binding__failpoint_binding__.go | 14 - .../handle/syncload/stats_syncload.go | 12 +- .../stats_syncload.go__failpoint_stash__ | 574 -- pkg/store/copr/batch_coprocessor.go | 12 +- .../batch_coprocessor.go__failpoint_stash__ | 1588 ---- .../copr/binding__failpoint_binding__.go | 14 - pkg/store/copr/coprocessor.go | 76 +- .../copr/coprocessor.go__failpoint_stash__ | 2170 ----- pkg/store/copr/mpp.go | 16 +- pkg/store/copr/mpp.go__failpoint_stash__ | 346 - .../txn/binding__failpoint_binding__.go | 14 - pkg/store/driver/txn/binlog.go | 4 +- .../driver/txn/binlog.go__failpoint_stash__ | 90 - pkg/store/driver/txn/txn_driver.go | 2 +- .../txn/txn_driver.go__failpoint_stash__ | 491 -- .../gcworker/binding__failpoint_binding__.go | 14 - pkg/store/gcworker/gc_worker.go | 22 +- .../gcworker/gc_worker.go__failpoint_stash__ | 1759 ---- .../unistore/binding__failpoint_binding__.go | 14 - .../binding__failpoint_binding__.go | 14 - .../unistore/cophandler/cop_handler.go | 6 +- .../cop_handler.go__failpoint_stash__ | 674 -- pkg/store/mockstore/unistore/rpc.go | 150 +- .../unistore/rpc.go__failpoint_stash__ | 582 -- .../binding__failpoint_binding__.go | 14 - pkg/table/contextimpl/table.go | 6 +- .../contextimpl/table.go__failpoint_stash__ | 200 - .../tables/binding__failpoint_binding__.go | 14 - pkg/table/tables/cache.go | 12 +- pkg/table/tables/cache.go__failpoint_stash__ | 355 - pkg/table/tables/mutation_checker.go | 6 +- .../mutation_checker.go__failpoint_stash__ | 556 -- pkg/table/tables/tables.go | 18 +- pkg/table/tables/tables.go__failpoint_stash__ | 2100 ----- .../ttlworker/binding__failpoint_binding__.go | 14 - pkg/ttl/ttlworker/config.go | 20 +- .../ttlworker/config.go__failpoint_stash__ | 70 - pkg/util/binding__failpoint_binding__.go | 14 - .../binding__failpoint_binding__.go | 14 - pkg/util/breakpoint/breakpoint.go | 4 +- .../breakpoint.go__failpoint_stash__ | 34 - .../cgroup/binding__failpoint_binding__.go | 14 - pkg/util/cgroup/cgroup_cpu_linux.go | 6 +- .../cgroup_cpu_linux.go__failpoint_stash__ | 100 - pkg/util/cgroup/cgroup_cpu_unsupport.go | 6 +- ...cgroup_cpu_unsupport.go__failpoint_stash__ | 55 - .../chunk/binding__failpoint_binding__.go | 14 - pkg/util/chunk/chunk_in_disk.go | 4 +- .../chunk/chunk_in_disk.go__failpoint_stash__ | 344 - pkg/util/chunk/row_container.go | 16 +- .../chunk/row_container.go__failpoint_stash__ | 691 -- pkg/util/chunk/row_container_reader.go | 4 +- ...row_container_reader.go__failpoint_stash__ | 170 - .../codec/binding__failpoint_binding__.go | 14 - pkg/util/codec/decimal.go | 6 +- pkg/util/codec/decimal.go__failpoint_stash__ | 69 - pkg/util/cpu/binding__failpoint_binding__.go | 14 - pkg/util/cpu/cpu.go | 6 +- pkg/util/cpu/cpu.go__failpoint_stash__ | 132 - pkg/util/etcd.go | 12 +- pkg/util/etcd.go__failpoint_stash__ | 103 - .../gctuner/binding__failpoint_binding__.go | 14 - pkg/util/gctuner/memory_limit_tuner.go | 8 +- .../memory_limit_tuner.go__failpoint_stash__ | 190 - .../memory/binding__failpoint_binding__.go | 14 - pkg/util/memory/meminfo.go | 4 +- pkg/util/memory/meminfo.go__failpoint_stash__ | 215 - pkg/util/memory/memstats.go | 4 +- .../memory/memstats.go__failpoint_stash__ | 57 - .../replayer/binding__failpoint_binding__.go | 14 - 
pkg/util/replayer/replayer.go | 4 +- .../replayer/replayer.go__failpoint_stash__ | 87 - .../binding__failpoint_binding__.go | 14 - .../servermemorylimit/servermemorylimit.go | 4 +- .../servermemorylimit.go__failpoint_stash__ | 264 - pkg/util/session_pool.go | 4 +- pkg/util/session_pool.go__failpoint_stash__ | 113 - pkg/util/sli/binding__failpoint_binding__.go | 14 - pkg/util/sli/sli.go | 6 +- pkg/util/sli/sli.go__failpoint_stash__ | 120 - .../sqlkiller/binding__failpoint_binding__.go | 14 - pkg/util/sqlkiller/sqlkiller.go | 4 +- .../sqlkiller/sqlkiller.go__failpoint_stash__ | 136 - .../binding__failpoint_binding__.go | 14 - pkg/util/stmtsummary/statement_summary.go | 4 +- .../statement_summary.go__failpoint_stash__ | 1039 --- .../topsql/binding__failpoint_binding__.go | 14 - .../reporter/binding__failpoint_binding__.go | 14 - pkg/util/topsql/reporter/pubsub.go | 2 +- .../reporter/pubsub.go__failpoint_stash__ | 274 - pkg/util/topsql/reporter/reporter.go | 4 +- .../reporter/reporter.go__failpoint_stash__ | 333 - pkg/util/topsql/topsql.go | 8 +- pkg/util/topsql/topsql.go__failpoint_stash__ | 187 - 591 files changed, 1613 insertions(+), 218907 deletions(-) delete mode 100644 br/pkg/backup/binding__failpoint_binding__.go delete mode 100644 br/pkg/backup/prepare_snap/binding__failpoint_binding__.go delete mode 100644 br/pkg/backup/prepare_snap/prepare.go__failpoint_stash__ delete mode 100644 br/pkg/backup/store.go__failpoint_stash__ delete mode 100644 br/pkg/checkpoint/binding__failpoint_binding__.go delete mode 100644 br/pkg/checkpoint/checkpoint.go__failpoint_stash__ delete mode 100644 br/pkg/checksum/binding__failpoint_binding__.go delete mode 100644 br/pkg/checksum/executor.go__failpoint_stash__ delete mode 100644 br/pkg/conn/binding__failpoint_binding__.go delete mode 100644 br/pkg/conn/conn.go__failpoint_stash__ delete mode 100644 br/pkg/pdutil/binding__failpoint_binding__.go delete mode 100644 br/pkg/pdutil/pd.go__failpoint_stash__ delete mode 100644 br/pkg/restore/binding__failpoint_binding__.go delete mode 100644 br/pkg/restore/log_client/binding__failpoint_binding__.go delete mode 100644 br/pkg/restore/log_client/client.go__failpoint_stash__ delete mode 100644 br/pkg/restore/log_client/import_retry.go__failpoint_stash__ delete mode 100644 br/pkg/restore/misc.go__failpoint_stash__ delete mode 100644 br/pkg/restore/snap_client/binding__failpoint_binding__.go delete mode 100644 br/pkg/restore/snap_client/client.go__failpoint_stash__ delete mode 100644 br/pkg/restore/snap_client/context_manager.go__failpoint_stash__ delete mode 100644 br/pkg/restore/snap_client/import.go__failpoint_stash__ delete mode 100644 br/pkg/restore/split/binding__failpoint_binding__.go delete mode 100644 br/pkg/restore/split/client.go__failpoint_stash__ delete mode 100644 br/pkg/restore/split/split.go__failpoint_stash__ delete mode 100644 br/pkg/storage/binding__failpoint_binding__.go delete mode 100644 br/pkg/storage/s3.go__failpoint_stash__ delete mode 100644 br/pkg/streamhelper/advancer.go__failpoint_stash__ delete mode 100644 br/pkg/streamhelper/advancer_cliext.go__failpoint_stash__ delete mode 100644 br/pkg/streamhelper/binding__failpoint_binding__.go delete mode 100644 br/pkg/streamhelper/flush_subscriber.go__failpoint_stash__ delete mode 100644 br/pkg/task/backup.go__failpoint_stash__ delete mode 100644 br/pkg/task/binding__failpoint_binding__.go delete mode 100644 br/pkg/task/operator/binding__failpoint_binding__.go delete mode 100644 br/pkg/task/operator/cmd.go__failpoint_stash__ delete mode 100644 
br/pkg/task/restore.go__failpoint_stash__ delete mode 100644 br/pkg/task/stream.go__failpoint_stash__ delete mode 100644 br/pkg/utils/backoff.go__failpoint_stash__ delete mode 100644 br/pkg/utils/binding__failpoint_binding__.go delete mode 100644 br/pkg/utils/pprof.go__failpoint_stash__ delete mode 100644 br/pkg/utils/register.go__failpoint_stash__ delete mode 100644 br/pkg/utils/store_manager.go__failpoint_stash__ delete mode 100644 dumpling/export/binding__failpoint_binding__.go delete mode 100644 dumpling/export/config.go__failpoint_stash__ delete mode 100644 dumpling/export/dump.go__failpoint_stash__ delete mode 100644 dumpling/export/sql.go__failpoint_stash__ delete mode 100644 dumpling/export/status.go__failpoint_stash__ delete mode 100644 dumpling/export/writer_util.go__failpoint_stash__ delete mode 100644 lightning/pkg/importer/binding__failpoint_binding__.go delete mode 100644 lightning/pkg/importer/chunk_process.go__failpoint_stash__ delete mode 100644 lightning/pkg/importer/get_pre_info.go__failpoint_stash__ delete mode 100644 lightning/pkg/importer/import.go__failpoint_stash__ delete mode 100644 lightning/pkg/importer/table_import.go__failpoint_stash__ delete mode 100644 lightning/pkg/server/binding__failpoint_binding__.go delete mode 100644 lightning/pkg/server/lightning.go__failpoint_stash__ delete mode 100644 pkg/autoid_service/autoid.go__failpoint_stash__ delete mode 100644 pkg/autoid_service/binding__failpoint_binding__.go delete mode 100644 pkg/bindinfo/binding__failpoint_binding__.go delete mode 100644 pkg/bindinfo/global_handle.go__failpoint_stash__ delete mode 100644 pkg/ddl/add_column.go__failpoint_stash__ delete mode 100644 pkg/ddl/backfilling.go__failpoint_stash__ delete mode 100644 pkg/ddl/backfilling_dist_scheduler.go__failpoint_stash__ delete mode 100644 pkg/ddl/backfilling_operators.go__failpoint_stash__ delete mode 100644 pkg/ddl/backfilling_read_index.go__failpoint_stash__ delete mode 100644 pkg/ddl/binding__failpoint_binding__.go delete mode 100644 pkg/ddl/cluster.go__failpoint_stash__ delete mode 100644 pkg/ddl/column.go__failpoint_stash__ delete mode 100644 pkg/ddl/constraint.go__failpoint_stash__ delete mode 100644 pkg/ddl/create_table.go__failpoint_stash__ delete mode 100644 pkg/ddl/ddl.go__failpoint_stash__ delete mode 100644 pkg/ddl/ddl_tiflash_api.go__failpoint_stash__ delete mode 100644 pkg/ddl/delete_range.go__failpoint_stash__ delete mode 100644 pkg/ddl/executor.go__failpoint_stash__ delete mode 100644 pkg/ddl/index.go__failpoint_stash__ delete mode 100644 pkg/ddl/index_cop.go__failpoint_stash__ delete mode 100644 pkg/ddl/index_merge_tmp.go__failpoint_stash__ delete mode 100644 pkg/ddl/ingest/backend.go__failpoint_stash__ delete mode 100644 pkg/ddl/ingest/backend_mgr.go__failpoint_stash__ delete mode 100644 pkg/ddl/ingest/binding__failpoint_binding__.go delete mode 100644 pkg/ddl/ingest/checkpoint.go__failpoint_stash__ delete mode 100644 pkg/ddl/ingest/disk_root.go__failpoint_stash__ delete mode 100644 pkg/ddl/ingest/env.go__failpoint_stash__ delete mode 100644 pkg/ddl/ingest/mock.go__failpoint_stash__ delete mode 100644 pkg/ddl/job_scheduler.go__failpoint_stash__ delete mode 100644 pkg/ddl/job_submitter.go__failpoint_stash__ delete mode 100644 pkg/ddl/job_worker.go__failpoint_stash__ delete mode 100644 pkg/ddl/mock.go__failpoint_stash__ delete mode 100644 pkg/ddl/modify_column.go__failpoint_stash__ delete mode 100644 pkg/ddl/partition.go__failpoint_stash__ delete mode 100644 pkg/ddl/placement/binding__failpoint_binding__.go delete mode 100644 
pkg/ddl/placement/bundle.go__failpoint_stash__ delete mode 100644 pkg/ddl/reorg.go__failpoint_stash__ delete mode 100644 pkg/ddl/rollingback.go__failpoint_stash__ delete mode 100644 pkg/ddl/schema_version.go__failpoint_stash__ delete mode 100644 pkg/ddl/session/binding__failpoint_binding__.go delete mode 100644 pkg/ddl/session/session.go__failpoint_stash__ delete mode 100644 pkg/ddl/syncer/binding__failpoint_binding__.go delete mode 100644 pkg/ddl/syncer/syncer.go__failpoint_stash__ delete mode 100644 pkg/ddl/table.go__failpoint_stash__ delete mode 100644 pkg/ddl/util/binding__failpoint_binding__.go delete mode 100644 pkg/ddl/util/util.go__failpoint_stash__ delete mode 100644 pkg/distsql/binding__failpoint_binding__.go delete mode 100644 pkg/distsql/request_builder.go__failpoint_stash__ delete mode 100644 pkg/distsql/select_result.go__failpoint_stash__ delete mode 100644 pkg/disttask/framework/scheduler/binding__failpoint_binding__.go delete mode 100644 pkg/disttask/framework/scheduler/nodes.go__failpoint_stash__ delete mode 100644 pkg/disttask/framework/scheduler/scheduler.go__failpoint_stash__ delete mode 100644 pkg/disttask/framework/scheduler/scheduler_manager.go__failpoint_stash__ delete mode 100644 pkg/disttask/framework/storage/binding__failpoint_binding__.go delete mode 100644 pkg/disttask/framework/storage/history.go__failpoint_stash__ delete mode 100644 pkg/disttask/framework/storage/task_table.go__failpoint_stash__ delete mode 100644 pkg/disttask/framework/taskexecutor/binding__failpoint_binding__.go delete mode 100644 pkg/disttask/framework/taskexecutor/task_executor.go__failpoint_stash__ delete mode 100644 pkg/disttask/importinto/binding__failpoint_binding__.go delete mode 100644 pkg/disttask/importinto/planner.go__failpoint_stash__ delete mode 100644 pkg/disttask/importinto/scheduler.go__failpoint_stash__ delete mode 100644 pkg/disttask/importinto/subtask_executor.go__failpoint_stash__ delete mode 100644 pkg/disttask/importinto/task_executor.go__failpoint_stash__ delete mode 100644 pkg/domain/binding__failpoint_binding__.go delete mode 100644 pkg/domain/domain.go__failpoint_stash__ delete mode 100644 pkg/domain/historical_stats.go__failpoint_stash__ delete mode 100644 pkg/domain/infosync/binding__failpoint_binding__.go delete mode 100644 pkg/domain/infosync/info.go__failpoint_stash__ delete mode 100644 pkg/domain/infosync/tiflash_manager.go__failpoint_stash__ delete mode 100644 pkg/domain/plan_replayer_dump.go__failpoint_stash__ delete mode 100644 pkg/domain/runaway.go__failpoint_stash__ delete mode 100644 pkg/executor/adapter.go__failpoint_stash__ delete mode 100644 pkg/executor/aggregate/agg_hash_executor.go__failpoint_stash__ delete mode 100644 pkg/executor/aggregate/agg_hash_final_worker.go__failpoint_stash__ delete mode 100644 pkg/executor/aggregate/agg_hash_partial_worker.go__failpoint_stash__ delete mode 100644 pkg/executor/aggregate/agg_stream_executor.go__failpoint_stash__ delete mode 100644 pkg/executor/aggregate/agg_util.go__failpoint_stash__ delete mode 100644 pkg/executor/aggregate/binding__failpoint_binding__.go delete mode 100644 pkg/executor/analyze.go__failpoint_stash__ delete mode 100644 pkg/executor/analyze_col.go__failpoint_stash__ delete mode 100644 pkg/executor/analyze_col_v2.go__failpoint_stash__ delete mode 100644 pkg/executor/analyze_idx.go__failpoint_stash__ delete mode 100644 pkg/executor/batch_point_get.go__failpoint_stash__ delete mode 100644 pkg/executor/binding__failpoint_binding__.go delete mode 100644 pkg/executor/brie.go__failpoint_stash__ 
delete mode 100644 pkg/executor/builder.go__failpoint_stash__ delete mode 100644 pkg/executor/checksum.go__failpoint_stash__ delete mode 100644 pkg/executor/compiler.go__failpoint_stash__ delete mode 100644 pkg/executor/cte.go__failpoint_stash__ delete mode 100644 pkg/executor/executor.go__failpoint_stash__ delete mode 100644 pkg/executor/import_into.go__failpoint_stash__ delete mode 100644 pkg/executor/importer/binding__failpoint_binding__.go delete mode 100644 pkg/executor/importer/job.go__failpoint_stash__ delete mode 100644 pkg/executor/importer/table_import.go__failpoint_stash__ delete mode 100644 pkg/executor/index_merge_reader.go__failpoint_stash__ delete mode 100644 pkg/executor/infoschema_reader.go__failpoint_stash__ delete mode 100644 pkg/executor/inspection_result.go__failpoint_stash__ delete mode 100644 pkg/executor/internal/calibrateresource/binding__failpoint_binding__.go delete mode 100644 pkg/executor/internal/calibrateresource/calibrate_resource.go__failpoint_stash__ delete mode 100644 pkg/executor/internal/exec/binding__failpoint_binding__.go delete mode 100644 pkg/executor/internal/exec/executor.go__failpoint_stash__ delete mode 100644 pkg/executor/internal/mpp/binding__failpoint_binding__.go delete mode 100644 pkg/executor/internal/mpp/executor_with_retry.go__failpoint_stash__ delete mode 100644 pkg/executor/internal/mpp/local_mpp_coordinator.go__failpoint_stash__ delete mode 100644 pkg/executor/internal/pdhelper/binding__failpoint_binding__.go delete mode 100644 pkg/executor/internal/pdhelper/pd.go__failpoint_stash__ delete mode 100644 pkg/executor/join/base_join_probe.go__failpoint_stash__ delete mode 100644 pkg/executor/join/binding__failpoint_binding__.go delete mode 100644 pkg/executor/join/hash_join_base.go__failpoint_stash__ delete mode 100644 pkg/executor/join/hash_join_v1.go__failpoint_stash__ delete mode 100644 pkg/executor/join/hash_join_v2.go__failpoint_stash__ delete mode 100644 pkg/executor/join/index_lookup_hash_join.go__failpoint_stash__ delete mode 100644 pkg/executor/join/index_lookup_join.go__failpoint_stash__ delete mode 100644 pkg/executor/join/index_lookup_merge_join.go__failpoint_stash__ delete mode 100644 pkg/executor/join/merge_join.go__failpoint_stash__ delete mode 100644 pkg/executor/load_data.go__failpoint_stash__ delete mode 100644 pkg/executor/memtable_reader.go__failpoint_stash__ delete mode 100644 pkg/executor/metrics_reader.go__failpoint_stash__ delete mode 100644 pkg/executor/parallel_apply.go__failpoint_stash__ delete mode 100644 pkg/executor/point_get.go__failpoint_stash__ delete mode 100644 pkg/executor/projection.go__failpoint_stash__ delete mode 100644 pkg/executor/shuffle.go__failpoint_stash__ delete mode 100644 pkg/executor/slow_query.go__failpoint_stash__ delete mode 100644 pkg/executor/sortexec/binding__failpoint_binding__.go delete mode 100644 pkg/executor/sortexec/parallel_sort_worker.go__failpoint_stash__ delete mode 100644 pkg/executor/sortexec/sort.go__failpoint_stash__ delete mode 100644 pkg/executor/sortexec/sort_partition.go__failpoint_stash__ delete mode 100644 pkg/executor/sortexec/sort_util.go__failpoint_stash__ delete mode 100644 pkg/executor/sortexec/topn.go__failpoint_stash__ delete mode 100644 pkg/executor/sortexec/topn_worker.go__failpoint_stash__ delete mode 100644 pkg/executor/table_reader.go__failpoint_stash__ delete mode 100644 pkg/executor/unionexec/binding__failpoint_binding__.go delete mode 100644 pkg/executor/unionexec/union.go__failpoint_stash__ delete mode 100644 
pkg/expression/aggregation/binding__failpoint_binding__.go delete mode 100644 pkg/expression/aggregation/explain.go__failpoint_stash__ delete mode 100644 pkg/expression/binding__failpoint_binding__.go delete mode 100644 pkg/expression/builtin_json.go__failpoint_stash__ delete mode 100644 pkg/expression/builtin_time.go__failpoint_stash__ delete mode 100644 pkg/expression/expr_to_pb.go__failpoint_stash__ delete mode 100644 pkg/expression/helper.go__failpoint_stash__ delete mode 100644 pkg/expression/infer_pushdown.go__failpoint_stash__ delete mode 100644 pkg/expression/util.go__failpoint_stash__ delete mode 100644 pkg/infoschema/binding__failpoint_binding__.go delete mode 100644 pkg/infoschema/builder.go__failpoint_stash__ delete mode 100644 pkg/infoschema/infoschema_v2.go__failpoint_stash__ delete mode 100644 pkg/infoschema/perfschema/binding__failpoint_binding__.go delete mode 100644 pkg/infoschema/perfschema/tables.go__failpoint_stash__ delete mode 100644 pkg/infoschema/sieve.go__failpoint_stash__ delete mode 100644 pkg/infoschema/tables.go__failpoint_stash__ delete mode 100644 pkg/kv/binding__failpoint_binding__.go delete mode 100644 pkg/kv/txn.go__failpoint_stash__ delete mode 100644 pkg/lightning/backend/backend.go__failpoint_stash__ delete mode 100644 pkg/lightning/backend/binding__failpoint_binding__.go delete mode 100644 pkg/lightning/backend/external/binding__failpoint_binding__.go delete mode 100644 pkg/lightning/backend/external/byte_reader.go__failpoint_stash__ delete mode 100644 pkg/lightning/backend/external/engine.go__failpoint_stash__ delete mode 100644 pkg/lightning/backend/external/merge_v2.go__failpoint_stash__ delete mode 100644 pkg/lightning/backend/local/binding__failpoint_binding__.go delete mode 100644 pkg/lightning/backend/local/checksum.go__failpoint_stash__ delete mode 100644 pkg/lightning/backend/local/engine.go__failpoint_stash__ delete mode 100644 pkg/lightning/backend/local/engine_mgr.go__failpoint_stash__ delete mode 100644 pkg/lightning/backend/local/local.go__failpoint_stash__ delete mode 100644 pkg/lightning/backend/local/local_unix.go__failpoint_stash__ delete mode 100644 pkg/lightning/backend/local/region_job.go__failpoint_stash__ delete mode 100644 pkg/lightning/backend/tidb/binding__failpoint_binding__.go delete mode 100644 pkg/lightning/backend/tidb/tidb.go__failpoint_stash__ delete mode 100644 pkg/lightning/common/binding__failpoint_binding__.go delete mode 100644 pkg/lightning/common/storage_unix.go__failpoint_stash__ delete mode 100644 pkg/lightning/common/storage_windows.go__failpoint_stash__ delete mode 100644 pkg/lightning/common/util.go__failpoint_stash__ delete mode 100644 pkg/lightning/mydump/binding__failpoint_binding__.go delete mode 100644 pkg/lightning/mydump/loader.go__failpoint_stash__ delete mode 100644 pkg/meta/autoid/autoid.go__failpoint_stash__ delete mode 100644 pkg/meta/autoid/binding__failpoint_binding__.go delete mode 100644 pkg/owner/binding__failpoint_binding__.go delete mode 100644 pkg/owner/manager.go__failpoint_stash__ delete mode 100644 pkg/owner/mock.go__failpoint_stash__ delete mode 100644 pkg/parser/ast/binding__failpoint_binding__.go delete mode 100644 pkg/parser/ast/misc.go__failpoint_stash__ delete mode 100644 pkg/planner/binding__failpoint_binding__.go delete mode 100644 pkg/planner/cardinality/binding__failpoint_binding__.go delete mode 100644 pkg/planner/cardinality/row_count_index.go__failpoint_stash__ delete mode 100644 pkg/planner/core/binding__failpoint_binding__.go delete mode 100644 
pkg/planner/core/collect_column_stats_usage.go__failpoint_stash__ delete mode 100644 pkg/planner/core/debugtrace.go__failpoint_stash__ delete mode 100644 pkg/planner/core/encode.go__failpoint_stash__ delete mode 100644 pkg/planner/core/exhaust_physical_plans.go__failpoint_stash__ delete mode 100644 pkg/planner/core/find_best_task.go__failpoint_stash__ delete mode 100644 pkg/planner/core/logical_join.go__failpoint_stash__ delete mode 100644 pkg/planner/core/logical_plan_builder.go__failpoint_stash__ delete mode 100644 pkg/planner/core/optimizer.go__failpoint_stash__ delete mode 100644 pkg/planner/core/rule_collect_plan_stats.go__failpoint_stash__ delete mode 100644 pkg/planner/core/rule_eliminate_projection.go__failpoint_stash__ delete mode 100644 pkg/planner/core/rule_inject_extra_projection.go__failpoint_stash__ delete mode 100644 pkg/planner/core/task.go__failpoint_stash__ delete mode 100644 pkg/planner/optimize.go__failpoint_stash__ delete mode 100644 pkg/server/binding__failpoint_binding__.go delete mode 100644 pkg/server/conn.go__failpoint_stash__ delete mode 100644 pkg/server/conn_stmt.go__failpoint_stash__ delete mode 100644 pkg/server/handler/binding__failpoint_binding__.go delete mode 100644 pkg/server/handler/extractorhandler/binding__failpoint_binding__.go delete mode 100644 pkg/server/handler/extractorhandler/extractor.go__failpoint_stash__ delete mode 100644 pkg/server/handler/tikv_handler.go__failpoint_stash__ delete mode 100644 pkg/server/http_status.go__failpoint_stash__ delete mode 100644 pkg/session/binding__failpoint_binding__.go delete mode 100644 pkg/session/nontransactional.go__failpoint_stash__ delete mode 100644 pkg/session/session.go__failpoint_stash__ delete mode 100644 pkg/session/sync_upgrade.go__failpoint_stash__ delete mode 100644 pkg/session/tidb.go__failpoint_stash__ delete mode 100644 pkg/session/txn.go__failpoint_stash__ delete mode 100644 pkg/session/txnmanager.go__failpoint_stash__ delete mode 100644 pkg/sessionctx/sessionstates/binding__failpoint_binding__.go delete mode 100644 pkg/sessionctx/sessionstates/session_token.go__failpoint_stash__ delete mode 100644 pkg/sessiontxn/isolation/base.go__failpoint_stash__ delete mode 100644 pkg/sessiontxn/isolation/binding__failpoint_binding__.go delete mode 100644 pkg/sessiontxn/isolation/readcommitted.go__failpoint_stash__ delete mode 100644 pkg/sessiontxn/isolation/repeatable_read.go__failpoint_stash__ delete mode 100644 pkg/sessiontxn/staleread/binding__failpoint_binding__.go delete mode 100644 pkg/sessiontxn/staleread/util.go__failpoint_stash__ delete mode 100644 pkg/statistics/binding__failpoint_binding__.go delete mode 100644 pkg/statistics/cmsketch.go__failpoint_stash__ delete mode 100644 pkg/statistics/handle/autoanalyze/autoanalyze.go__failpoint_stash__ delete mode 100644 pkg/statistics/handle/autoanalyze/binding__failpoint_binding__.go delete mode 100644 pkg/statistics/handle/binding__failpoint_binding__.go delete mode 100644 pkg/statistics/handle/bootstrap.go__failpoint_stash__ delete mode 100644 pkg/statistics/handle/cache/binding__failpoint_binding__.go delete mode 100644 pkg/statistics/handle/cache/statscache.go__failpoint_stash__ delete mode 100644 pkg/statistics/handle/globalstats/binding__failpoint_binding__.go delete mode 100644 pkg/statistics/handle/globalstats/global_stats_async.go__failpoint_stash__ delete mode 100644 pkg/statistics/handle/storage/binding__failpoint_binding__.go delete mode 100644 pkg/statistics/handle/storage/read.go__failpoint_stash__ delete mode 100644 
pkg/statistics/handle/syncload/binding__failpoint_binding__.go delete mode 100644 pkg/statistics/handle/syncload/stats_syncload.go__failpoint_stash__ delete mode 100644 pkg/store/copr/batch_coprocessor.go__failpoint_stash__ delete mode 100644 pkg/store/copr/binding__failpoint_binding__.go delete mode 100644 pkg/store/copr/coprocessor.go__failpoint_stash__ delete mode 100644 pkg/store/copr/mpp.go__failpoint_stash__ delete mode 100644 pkg/store/driver/txn/binding__failpoint_binding__.go delete mode 100644 pkg/store/driver/txn/binlog.go__failpoint_stash__ delete mode 100644 pkg/store/driver/txn/txn_driver.go__failpoint_stash__ delete mode 100644 pkg/store/gcworker/binding__failpoint_binding__.go delete mode 100644 pkg/store/gcworker/gc_worker.go__failpoint_stash__ delete mode 100644 pkg/store/mockstore/unistore/binding__failpoint_binding__.go delete mode 100644 pkg/store/mockstore/unistore/cophandler/binding__failpoint_binding__.go delete mode 100644 pkg/store/mockstore/unistore/cophandler/cop_handler.go__failpoint_stash__ delete mode 100644 pkg/store/mockstore/unistore/rpc.go__failpoint_stash__ delete mode 100644 pkg/table/contextimpl/binding__failpoint_binding__.go delete mode 100644 pkg/table/contextimpl/table.go__failpoint_stash__ delete mode 100644 pkg/table/tables/binding__failpoint_binding__.go delete mode 100644 pkg/table/tables/cache.go__failpoint_stash__ delete mode 100644 pkg/table/tables/mutation_checker.go__failpoint_stash__ delete mode 100644 pkg/table/tables/tables.go__failpoint_stash__ delete mode 100644 pkg/ttl/ttlworker/binding__failpoint_binding__.go delete mode 100644 pkg/ttl/ttlworker/config.go__failpoint_stash__ delete mode 100644 pkg/util/binding__failpoint_binding__.go delete mode 100644 pkg/util/breakpoint/binding__failpoint_binding__.go delete mode 100644 pkg/util/breakpoint/breakpoint.go__failpoint_stash__ delete mode 100644 pkg/util/cgroup/binding__failpoint_binding__.go delete mode 100644 pkg/util/cgroup/cgroup_cpu_linux.go__failpoint_stash__ delete mode 100644 pkg/util/cgroup/cgroup_cpu_unsupport.go__failpoint_stash__ delete mode 100644 pkg/util/chunk/binding__failpoint_binding__.go delete mode 100644 pkg/util/chunk/chunk_in_disk.go__failpoint_stash__ delete mode 100644 pkg/util/chunk/row_container.go__failpoint_stash__ delete mode 100644 pkg/util/chunk/row_container_reader.go__failpoint_stash__ delete mode 100644 pkg/util/codec/binding__failpoint_binding__.go delete mode 100644 pkg/util/codec/decimal.go__failpoint_stash__ delete mode 100644 pkg/util/cpu/binding__failpoint_binding__.go delete mode 100644 pkg/util/cpu/cpu.go__failpoint_stash__ delete mode 100644 pkg/util/etcd.go__failpoint_stash__ delete mode 100644 pkg/util/gctuner/binding__failpoint_binding__.go delete mode 100644 pkg/util/gctuner/memory_limit_tuner.go__failpoint_stash__ delete mode 100644 pkg/util/memory/binding__failpoint_binding__.go delete mode 100644 pkg/util/memory/meminfo.go__failpoint_stash__ delete mode 100644 pkg/util/memory/memstats.go__failpoint_stash__ delete mode 100644 pkg/util/replayer/binding__failpoint_binding__.go delete mode 100644 pkg/util/replayer/replayer.go__failpoint_stash__ delete mode 100644 pkg/util/servermemorylimit/binding__failpoint_binding__.go delete mode 100644 pkg/util/servermemorylimit/servermemorylimit.go__failpoint_stash__ delete mode 100644 pkg/util/session_pool.go__failpoint_stash__ delete mode 100644 pkg/util/sli/binding__failpoint_binding__.go delete mode 100644 pkg/util/sli/sli.go__failpoint_stash__ delete mode 100644 
pkg/util/sqlkiller/binding__failpoint_binding__.go
 delete mode 100644 pkg/util/sqlkiller/sqlkiller.go__failpoint_stash__
 delete mode 100644 pkg/util/stmtsummary/binding__failpoint_binding__.go
 delete mode 100644 pkg/util/stmtsummary/statement_summary.go__failpoint_stash__
 delete mode 100644 pkg/util/topsql/binding__failpoint_binding__.go
 delete mode 100644 pkg/util/topsql/reporter/binding__failpoint_binding__.go
 delete mode 100644 pkg/util/topsql/reporter/pubsub.go__failpoint_stash__
 delete mode 100644 pkg/util/topsql/reporter/reporter.go__failpoint_stash__
 delete mode 100644 pkg/util/topsql/topsql.go__failpoint_stash__

diff --git a/br/pkg/backup/binding__failpoint_binding__.go b/br/pkg/backup/binding__failpoint_binding__.go
deleted file mode 100644
index 20f171c30d696..0000000000000
--- a/br/pkg/backup/binding__failpoint_binding__.go
+++ /dev/null
@@ -1,14 +0,0 @@
-
-package backup
-
-import "reflect"
-
-type __failpointBindingType struct {pkgpath string}
-var __failpointBindingCache = &__failpointBindingType{}
-
-func init() {
-	__failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath()
-}
-func _curpkg_(name string) string {
-	return __failpointBindingCache.pkgpath + "/" + name
-}
diff --git a/br/pkg/backup/prepare_snap/binding__failpoint_binding__.go b/br/pkg/backup/prepare_snap/binding__failpoint_binding__.go
deleted file mode 100644
index f51555385eb65..0000000000000
--- a/br/pkg/backup/prepare_snap/binding__failpoint_binding__.go
+++ /dev/null
@@ -1,14 +0,0 @@
-
-package preparesnap
-
-import "reflect"
-
-type __failpointBindingType struct {pkgpath string}
-var __failpointBindingCache = &__failpointBindingType{}
-
-func init() {
-	__failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath()
-}
-func _curpkg_(name string) string {
-	return __failpointBindingCache.pkgpath + "/" + name
-}
diff --git a/br/pkg/backup/prepare_snap/prepare.go b/br/pkg/backup/prepare_snap/prepare.go
index 8ea386eb5a2bc..2435fb4986bdc 100644
--- a/br/pkg/backup/prepare_snap/prepare.go
+++ b/br/pkg/backup/prepare_snap/prepare.go
@@ -454,9 +454,9 @@ func (p *Preparer) pushWaitApply(reqs pendingRequests, region Region) {
 // PrepareConnections prepares the connections for each store.
 // This will pause the admin commands for each store.
 func (p *Preparer) PrepareConnections(ctx context.Context) error {
-	if _, _err_ := failpoint.Eval(_curpkg_("PrepareConnectionsErr")); _err_ == nil {
-		return errors.New("mock PrepareConnectionsErr")
-	}
+	failpoint.Inject("PrepareConnectionsErr", func() {
+		failpoint.Return(errors.New("mock PrepareConnectionsErr"))
+	})
 	log.Info("Preparing connections to stores.")
 	stores, err := p.env.GetAllLiveStores(ctx)
 	if err != nil {
diff --git a/br/pkg/backup/prepare_snap/prepare.go__failpoint_stash__ b/br/pkg/backup/prepare_snap/prepare.go__failpoint_stash__
deleted file mode 100644
index 2435fb4986bdc..0000000000000
--- a/br/pkg/backup/prepare_snap/prepare.go__failpoint_stash__
+++ /dev/null
@@ -1,484 +0,0 @@
-// Copyright 2024 PingCAP, Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and -// limitations under the License. - -package preparesnap - -import ( - "bytes" - "context" - "fmt" - "time" - - "github.com/google/btree" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - brpb "github.com/pingcap/kvproto/pkg/brpb" - "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/log" - "github.com/pingcap/tidb/br/pkg/logutil" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" - "golang.org/x/sync/errgroup" -) - -const ( - /* The combination of defaultMaxRetry and defaultRetryBackoff limits - the whole procedure to about 5 min if there is a region always fail. - Also note that we are batching during retrying. Retrying many region - costs only one chance of retrying if they are batched. */ - - defaultMaxRetry = 60 - defaultRetryBackoff = 5 * time.Second - defaultLeaseDur = 120 * time.Second - - /* Give pd enough time to find the region. If we aren't able to fetch - the region, the whole procedure might be aborted. */ - - regionCacheMaxBackoffMs = 60000 -) - -type pendingRequests map[uint64]*brpb.PrepareSnapshotBackupRequest - -type rangeOrRegion struct { - // If it is a range, this should be zero. - id uint64 - startKey []byte - endKey []byte -} - -func (r rangeOrRegion) String() string { - rng := logutil.StringifyRangeOf(r.startKey, r.endKey) - if r.id == 0 { - return fmt.Sprintf("range%s", rng) - } - return fmt.Sprintf("region(id=%d, range=%s)", r.id, rng) -} - -func (r rangeOrRegion) compareWith(than rangeOrRegion) bool { - return bytes.Compare(r.startKey, than.startKey) < 0 -} - -type Preparer struct { - /* Environments. */ - env Env - - /* Internal Status. */ - inflightReqs map[uint64]metapb.Region - failed []rangeOrRegion - waitApplyDoneRegions btree.BTreeG[rangeOrRegion] - retryTime int - nextRetry *time.Timer - - /* Internal I/O. */ - eventChan chan event - clients map[uint64]*prepareStream - - /* Interface for caller. */ - waitApplyFinished bool - - /* Some configurations. They aren't thread safe. - You may need to configure them before starting the Preparer. */ - RetryBackoff time.Duration - RetryLimit int - LeaseDuration time.Duration - - /* Observers. 
Initialize them before starting.*/ - AfterConnectionsEstablished func() -} - -func New(env Env) *Preparer { - prep := &Preparer{ - env: env, - - inflightReqs: make(map[uint64]metapb.Region), - waitApplyDoneRegions: *btree.NewG(16, rangeOrRegion.compareWith), - eventChan: make(chan event, 128), - clients: make(map[uint64]*prepareStream), - - RetryBackoff: defaultRetryBackoff, - RetryLimit: defaultMaxRetry, - LeaseDuration: defaultLeaseDur, - } - return prep -} - -func (p *Preparer) MarshalLogObject(om zapcore.ObjectEncoder) error { - om.AddInt("inflight_requests", len(p.inflightReqs)) - reqs := 0 - for _, r := range p.inflightReqs { - om.AddString("simple_inflight_region", rangeOrRegion{id: r.Id, startKey: r.StartKey, endKey: r.EndKey}.String()) - reqs += 1 - if reqs > 3 { - break - } - } - om.AddInt("failed_requests", len(p.failed)) - failed := 0 - for _, r := range p.failed { - om.AddString("simple_failed_region", r.String()) - failed += 1 - if failed > 5 { - break - } - } - err := om.AddArray("connected_stores", zapcore.ArrayMarshalerFunc(func(ae zapcore.ArrayEncoder) error { - for id := range p.clients { - ae.AppendUint64(id) - } - return nil - })) - if err != nil { - return err - } - om.AddInt("retry_time", p.retryTime) - om.AddBool("wait_apply_finished", p.waitApplyFinished) - return nil -} - -// DriveLoopAndWaitPrepare drives the state machine and block the -// current goroutine until we are safe to start taking snapshot. -// -// After this invoked, you shouldn't share this `Preparer` with any other goroutines. -// -// After this the cluster will enter the land between normal and taking snapshot. -// This state will continue even this function returns, until `Finalize` invoked. -// Splitting, ingesting and conf changing will all be blocked. -func (p *Preparer) DriveLoopAndWaitPrepare(ctx context.Context) error { - logutil.CL(ctx).Info("Start drive the loop.", zap.Duration("retry_backoff", p.RetryBackoff), - zap.Int("retry_limit", p.RetryLimit), - zap.Duration("lease_duration", p.LeaseDuration)) - p.retryTime = 0 - if err := p.PrepareConnections(ctx); err != nil { - log.Error("failed to prepare connections", logutil.ShortError(err)) - return errors.Annotate(err, "failed to prepare connections") - } - if p.AfterConnectionsEstablished != nil { - p.AfterConnectionsEstablished() - } - if err := p.AdvanceState(ctx); err != nil { - log.Error("failed to check the progress of our work", logutil.ShortError(err)) - return errors.Annotate(err, "failed to begin step") - } - for !p.waitApplyFinished { - if err := p.WaitAndHandleNextEvent(ctx); err != nil { - log.Error("failed to wait and handle next event", logutil.ShortError(err)) - return errors.Annotate(err, "failed to step") - } - } - return nil -} - -// Finalize notify the cluster to go back to the normal mode. -// This will return an error if the cluster has already entered the normal mode when this is called. 
-func (p *Preparer) Finalize(ctx context.Context) error { - eg := new(errgroup.Group) - for id, cli := range p.clients { - cli := cli - id := id - eg.Go(func() error { - if err := cli.Finalize(ctx); err != nil { - return errors.Annotatef(err, "failed to finalize the prepare stream for %d", id) - } - return nil - }) - } - errCh := make(chan error, 1) - go func() { - if err := eg.Wait(); err != nil { - logutil.CL(ctx).Warn("failed to finalize some prepare streams.", logutil.ShortError(err)) - errCh <- err - return - } - logutil.CL(ctx).Info("all connections to store have shuted down.") - errCh <- nil - }() - for { - select { - case event, ok := <-p.eventChan: - if !ok { - return nil - } - if err := p.onEvent(ctx, event); err != nil { - return err - } - case err, ok := <-errCh: - if !ok { - panic("unreachable.") - } - if err != nil { - return err - } - // All streams are finialized, they shouldn't send more events to event chan. - close(p.eventChan) - case <-ctx.Done(): - return ctx.Err() - } - } -} - -func (p *Preparer) batchEvents(evts *[]event) { - for { - select { - case evt := <-p.eventChan: - *evts = append(*evts, evt) - default: - return - } - } -} - -// WaitAndHandleNextEvent is exported for test usage. -// This waits the next event (wait apply done, errors, etc..) of preparing. -// Generally `DriveLoopAndWaitPrepare` is all you need. -func (p *Preparer) WaitAndHandleNextEvent(ctx context.Context) error { - select { - case <-ctx.Done(): - logutil.CL(ctx).Warn("User canceled.", logutil.ShortError(ctx.Err())) - return ctx.Err() - case evt := <-p.eventChan: - logutil.CL(ctx).Debug("received event", zap.Stringer("event", evt)) - events := []event{evt} - p.batchEvents(&events) - for _, evt := range events { - err := p.onEvent(ctx, evt) - if err != nil { - return errors.Annotatef(err, "failed to handle event %v", evt) - } - } - return p.AdvanceState(ctx) - case <-p.retryChan(): - return p.workOnPendingRanges(ctx) - } -} - -func (p *Preparer) removePendingRequest(r *metapb.Region) bool { - r2, ok := p.inflightReqs[r.GetId()] - if !ok { - return false - } - matches := r2.GetRegionEpoch().GetVersion() == r.GetRegionEpoch().GetVersion() && - r2.GetRegionEpoch().GetConfVer() == r.GetRegionEpoch().GetConfVer() - if !matches { - return false - } - delete(p.inflightReqs, r.GetId()) - return true -} - -func (p *Preparer) onEvent(ctx context.Context, e event) error { - switch e.ty { - case eventMiscErr: - // Note: some of errors might be able to be retry. - // But for now it seems there isn't one. - return errors.Annotatef(e.err, "unrecoverable error at store %d", e.storeID) - case eventWaitApplyDone: - if !p.removePendingRequest(e.region) { - logutil.CL(ctx).Warn("received unmatched response, perhaps stale, drop it", zap.Stringer("region", e.region)) - return nil - } - r := rangeOrRegion{ - id: e.region.GetId(), - startKey: e.region.GetStartKey(), - endKey: e.region.GetEndKey(), - } - if e.err != nil { - logutil.CL(ctx).Warn("requesting a region failed.", zap.Uint64("store", e.storeID), logutil.ShortError(e.err)) - p.failed = append(p.failed, r) - if p.nextRetry != nil { - p.nextRetry.Stop() - } - // Reset the timer so we can collect more regions. - // Note: perhaps it is better to make a deadline heap or something - // so every region backoffs the same time. 
- p.nextRetry = time.NewTimer(p.RetryBackoff) - return nil - } - if item, ok := p.waitApplyDoneRegions.ReplaceOrInsert(r); ok { - logutil.CL(ctx).Warn("overlapping in success region", - zap.Stringer("old_region", item), - zap.Stringer("new_region", r)) - } - default: - return errors.Annotatef(unsupported(), "unsupported event type %d", e.ty) - } - - return nil -} - -func (p *Preparer) retryChan() <-chan time.Time { - if p.nextRetry == nil { - return nil - } - return p.nextRetry.C -} - -// AdvanceState is exported for test usage. -// This call will check whether now we are safe to forward the whole procedure. -// If we can, this will set `p.waitApplyFinished` to true. -// Generally `DriveLoopAndWaitPrepare` is all you need, you may not want to call this. -func (p *Preparer) AdvanceState(ctx context.Context) error { - logutil.CL(ctx).Info("Checking the progress of our work.", zap.Object("current", p)) - if len(p.inflightReqs) == 0 && len(p.failed) == 0 { - holes := p.checkHole() - if len(holes) == 0 { - p.waitApplyFinished = true - return nil - } - logutil.CL(ctx).Warn("It seems there are still some works to be done.", zap.Stringers("regions", holes)) - p.failed = holes - return p.workOnPendingRanges(ctx) - } - - return nil -} - -func (p *Preparer) checkHole() []rangeOrRegion { - log.Info("Start checking the hole.", zap.Int("len", p.waitApplyDoneRegions.Len())) - if p.waitApplyDoneRegions.Len() == 0 { - return []rangeOrRegion{{}} - } - - last := []byte("") - failed := []rangeOrRegion{} - p.waitApplyDoneRegions.Ascend(func(item rangeOrRegion) bool { - if bytes.Compare(last, item.startKey) < 0 { - failed = append(failed, rangeOrRegion{startKey: last, endKey: item.startKey}) - } - last = item.endKey - return true - }) - // Not the end key of key space. - if len(last) > 0 { - failed = append(failed, rangeOrRegion{ - startKey: last, - }) - } - return failed -} - -func (p *Preparer) workOnPendingRanges(ctx context.Context) error { - p.nextRetry = nil - if len(p.failed) == 0 { - return nil - } - p.retryTime += 1 - if p.retryTime > p.RetryLimit { - return retryLimitExceeded() - } - - logutil.CL(ctx).Info("retrying some ranges incomplete.", zap.Int("ranges", len(p.failed))) - preqs := pendingRequests{} - for _, r := range p.failed { - rs, err := p.env.LoadRegionsInKeyRange(ctx, r.startKey, r.endKey) - if err != nil { - return errors.Annotatef(err, "retrying range of %s: get region", logutil.StringifyRangeOf(r.startKey, r.endKey)) - } - logutil.CL(ctx).Info("loaded regions in range for retry.", zap.Int("regions", len(rs))) - for _, region := range rs { - p.pushWaitApply(preqs, region) - } - } - p.failed = nil - return p.sendWaitApply(ctx, preqs) -} - -func (p *Preparer) sendWaitApply(ctx context.Context, reqs pendingRequests) error { - logutil.CL(ctx).Info("about to send wait apply to stores", zap.Int("to-stores", len(reqs))) - for store, req := range reqs { - logutil.CL(ctx).Info("sending wait apply requests to store", zap.Uint64("store", store), zap.Int("regions", len(req.Regions))) - stream, err := p.streamOf(ctx, store) - if err != nil { - return errors.Annotatef(err, "failed to dial the store %d", store) - } - err = stream.cli.Send(req) - if err != nil { - return errors.Annotatef(err, "failed to send message to the store %d", store) - } - } - return nil -} - -func (p *Preparer) streamOf(ctx context.Context, storeID uint64) (*prepareStream, error) { - _, ok := p.clients[storeID] - if !ok { - log.Warn("stream of store found a store not established connection", zap.Uint64("store", storeID)) - cli, 
err := p.env.ConnectToStore(ctx, storeID) - if err != nil { - return nil, errors.Annotatef(err, "failed to dial store %d", storeID) - } - if err := p.createAndCacheStream(ctx, cli, storeID); err != nil { - return nil, errors.Annotatef(err, "failed to create and cache stream for store %d", storeID) - } - } - return p.clients[storeID], nil -} - -func (p *Preparer) createAndCacheStream(ctx context.Context, cli PrepareClient, storeID uint64) error { - if _, ok := p.clients[storeID]; ok { - return nil - } - - s := new(prepareStream) - s.storeID = storeID - s.output = p.eventChan - s.leaseDuration = p.LeaseDuration - err := s.InitConn(ctx, cli) - if err != nil { - return err - } - p.clients[storeID] = s - return nil -} - -func (p *Preparer) pushWaitApply(reqs pendingRequests, region Region) { - leader := region.GetLeaderStoreID() - if _, ok := reqs[leader]; !ok { - reqs[leader] = new(brpb.PrepareSnapshotBackupRequest) - reqs[leader].Ty = brpb.PrepareSnapshotBackupRequestType_WaitApply - } - reqs[leader].Regions = append(reqs[leader].Regions, region.GetMeta()) - p.inflightReqs[region.GetMeta().Id] = *region.GetMeta() -} - -// PrepareConnections prepares the connections for each store. -// This will pause the admin commands for each store. -func (p *Preparer) PrepareConnections(ctx context.Context) error { - failpoint.Inject("PrepareConnectionsErr", func() { - failpoint.Return(errors.New("mock PrepareConnectionsErr")) - }) - log.Info("Preparing connections to stores.") - stores, err := p.env.GetAllLiveStores(ctx) - if err != nil { - return errors.Annotate(err, "failed to get all live stores") - } - - log.Info("Start to initialize the connections.", zap.Int("stores", len(stores))) - clients := map[uint64]PrepareClient{} - for _, store := range stores { - cli, err := p.env.ConnectToStore(ctx, store.Id) - if err != nil { - return errors.Annotatef(err, "failed to dial the store %d", store.Id) - } - clients[store.Id] = cli - } - - for id, cli := range clients { - log.Info("Start to pause the admin commands.", zap.Uint64("store", id)) - if err := p.createAndCacheStream(ctx, cli, id); err != nil { - return errors.Annotatef(err, "failed to create and cache stream for store %d", id) - } - } - - return nil -} diff --git a/br/pkg/backup/store.go b/br/pkg/backup/store.go index 48f935b03f4b0..02f7166193918 100644 --- a/br/pkg/backup/store.go +++ b/br/pkg/backup/store.go @@ -63,7 +63,7 @@ func doSendBackup( req backuppb.BackupRequest, respFn func(*backuppb.BackupResponse) error, ) error { - if v, _err_ := failpoint.Eval(_curpkg_("hint-backup-start")); _err_ == nil { + failpoint.Inject("hint-backup-start", func(v failpoint.Value) { logutil.CL(ctx).Info("failpoint hint-backup-start injected, " + "process will notify the shell.") if sigFile, ok := v.(string); ok { @@ -76,9 +76,9 @@ func doSendBackup( } } time.Sleep(3 * time.Second) - } + }) bCli, err := client.Backup(ctx, &req) - if val, _err_ := failpoint.Eval(_curpkg_("reset-retryable-error")); _err_ == nil { + failpoint.Inject("reset-retryable-error", func(val failpoint.Value) { switch val.(string) { case "Unavailable": { @@ -91,13 +91,13 @@ func doSendBackup( err = status.Error(codes.Internal, "Internal error") } } - } - if val, _err_ := failpoint.Eval(_curpkg_("reset-not-retryable-error")); _err_ == nil { + }) + failpoint.Inject("reset-not-retryable-error", func(val failpoint.Value) { if val.(bool) { logutil.CL(ctx).Debug("failpoint reset-not-retryable-error injected.") err = status.Error(codes.Unknown, "Your server was haunted hence doesn't work, meow :3") } 
-	}
+	})
-	if val, _err_ := failpoint.Eval(_curpkg_("reset-not-retryable-error")); _err_ == nil {
+	failpoint.Inject("reset-not-retryable-error", func(val failpoint.Value) {
 		if val.(bool) {
 			logutil.CL(ctx).Debug("failpoint reset-not-retryable-error injected.")
 			err = status.Error(codes.Unknown, "Your server was haunted hence doesn't work, meow :3")
 		}
-	}
+	})
 	if err != nil {
 		return err
 	}
@@ -159,28 +159,28 @@
 				zap.Int("retry", retry), zap.Int("reqIndex", reqIndex))
 			return doSendBackup(ectx, backupCli, bkReq, func(resp *backuppb.BackupResponse) error {
 				// Forward all responses (including error).
-				if val, _err_ := failpoint.Eval(_curpkg_("backup-timeout-error")); _err_ == nil {
+				failpoint.Inject("backup-timeout-error", func(val failpoint.Value) {
 					msg := val.(string)
 					logutil.CL(ectx).Info("failpoint backup-timeout-error injected.", zap.String("msg", msg))
 					resp.Error = &backuppb.Error{
 						Msg: msg,
 					}
-				}
-				if val, _err_ := failpoint.Eval(_curpkg_("backup-storage-error")); _err_ == nil {
+				})
+				failpoint.Inject("backup-storage-error", func(val failpoint.Value) {
 					msg := val.(string)
 					logutil.CL(ectx).Debug("failpoint backup-storage-error injected.", zap.String("msg", msg))
 					resp.Error = &backuppb.Error{
 						Msg: msg,
 					}
-				}
-				if val, _err_ := failpoint.Eval(_curpkg_("tikv-rw-error")); _err_ == nil {
+				})
+				failpoint.Inject("tikv-rw-error", func(val failpoint.Value) {
 					msg := val.(string)
 					logutil.CL(ectx).Debug("failpoint tikv-rw-error injected.", zap.String("msg", msg))
 					resp.Error = &backuppb.Error{
 						Msg: msg,
 					}
-				}
-				if val, _err_ := failpoint.Eval(_curpkg_("tikv-region-error")); _err_ == nil {
+				})
+				failpoint.Inject("tikv-region-error", func(val failpoint.Value) {
 					msg := val.(string)
 					logutil.CL(ectx).Debug("failpoint tikv-region-error injected.", zap.String("msg", msg))
 					resp.Error = &backuppb.Error{
@@ -191,7 +191,7 @@ func startBackup(
 						},
 					},
 				}
-				}
+				})
 				select {
 				case <-ectx.Done():
 					return ectx.Err()
@@ -247,12 +247,12 @@ func ObserveStoreChangesAsync(ctx context.Context, stateNotifier chan BackupRetr
 			logutil.CL(ctx).Warn("failed to watch store changes at beginning, ignore it", zap.Error(err))
 		}
 		tickInterval := 30 * time.Second
-		if val, _err_ := failpoint.Eval(_curpkg_("backup-store-change-tick")); _err_ == nil {
+		failpoint.Inject("backup-store-change-tick", func(val failpoint.Value) {
 			if val.(bool) {
 				tickInterval = 100 * time.Millisecond
 			}
 			logutil.CL(ctx).Info("failpoint backup-store-change-tick injected.", zap.Duration("interval", tickInterval))
-		}
+		})
 		tick := time.NewTicker(tickInterval)
 		for {
 			select {
diff --git a/br/pkg/backup/store.go__failpoint_stash__ b/br/pkg/backup/store.go__failpoint_stash__
deleted file mode 100644
index 02f7166193918..0000000000000
--- a/br/pkg/backup/store.go__failpoint_stash__
+++ /dev/null
@@ -1,307 +0,0 @@
-// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0.
- -package backup - -import ( - "context" - "io" - "os" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - backuppb "github.com/pingcap/kvproto/pkg/brpb" - "github.com/pingcap/kvproto/pkg/errorpb" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/log" - "github.com/pingcap/tidb/br/pkg/logutil" - "github.com/pingcap/tidb/br/pkg/rtree" - "github.com/pingcap/tidb/br/pkg/utils" - "github.com/pingcap/tidb/br/pkg/utils/storewatch" - tidbutil "github.com/pingcap/tidb/pkg/util" - pd "github.com/tikv/pd/client" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -type BackupRetryPolicy struct { - One uint64 - All bool -} - -type BackupSender interface { - SendAsync( - ctx context.Context, - round uint64, - storeID uint64, - request backuppb.BackupRequest, - concurrency uint, - cli backuppb.BackupClient, - respCh chan *ResponseAndStore, - StateNotifier chan BackupRetryPolicy) -} - -type ResponseAndStore struct { - Resp *backuppb.BackupResponse - StoreID uint64 -} - -func (r ResponseAndStore) GetResponse() *backuppb.BackupResponse { - return r.Resp -} - -func (r ResponseAndStore) GetStoreID() uint64 { - return r.StoreID -} - -func doSendBackup( - ctx context.Context, - client backuppb.BackupClient, - req backuppb.BackupRequest, - respFn func(*backuppb.BackupResponse) error, -) error { - failpoint.Inject("hint-backup-start", func(v failpoint.Value) { - logutil.CL(ctx).Info("failpoint hint-backup-start injected, " + - "process will notify the shell.") - if sigFile, ok := v.(string); ok { - file, err := os.Create(sigFile) - if err != nil { - log.Warn("failed to create file for notifying, skipping notify", zap.Error(err)) - } - if file != nil { - file.Close() - } - } - time.Sleep(3 * time.Second) - }) - bCli, err := client.Backup(ctx, &req) - failpoint.Inject("reset-retryable-error", func(val failpoint.Value) { - switch val.(string) { - case "Unavailable": - { - logutil.CL(ctx).Debug("failpoint reset-retryable-error unavailable injected.") - err = status.Error(codes.Unavailable, "Unavailable error") - } - case "Internal": - { - logutil.CL(ctx).Debug("failpoint reset-retryable-error internal injected.") - err = status.Error(codes.Internal, "Internal error") - } - } - }) - failpoint.Inject("reset-not-retryable-error", func(val failpoint.Value) { - if val.(bool) { - logutil.CL(ctx).Debug("failpoint reset-not-retryable-error injected.") - err = status.Error(codes.Unknown, "Your server was haunted hence doesn't work, meow :3") - } - }) - if err != nil { - return err - } - defer func() { - _ = bCli.CloseSend() - }() - - for { - resp, err := bCli.Recv() - if err != nil { - if errors.Cause(err) == io.EOF { // nolint:errorlint - logutil.CL(ctx).Debug("backup streaming finish", - logutil.Key("backup-start-key", req.GetStartKey()), - logutil.Key("backup-end-key", req.GetEndKey())) - return nil - } - return err - } - // TODO: handle errors in the resp. 
- logutil.CL(ctx).Debug("range backed up", - logutil.Key("small-range-start-key", resp.GetStartKey()), - logutil.Key("small-range-end-key", resp.GetEndKey()), - zap.Int("api-version", int(resp.ApiVersion))) - err = respFn(resp) - if err != nil { - return errors.Trace(err) - } - } -} - -func startBackup( - ctx context.Context, - storeID uint64, - backupReq backuppb.BackupRequest, - backupCli backuppb.BackupClient, - concurrency uint, - respCh chan *ResponseAndStore, -) error { - // this goroutine handle the response from a single store - select { - case <-ctx.Done(): - return ctx.Err() - default: - logutil.CL(ctx).Info("try backup", zap.Uint64("storeID", storeID)) - // Send backup request to the store. - // handle the backup response or internal error here. - // handle the store error(reboot or network partition) outside. - reqs := SplitBackupReqRanges(backupReq, concurrency) - pool := tidbutil.NewWorkerPool(concurrency, "store_backup") - eg, ectx := errgroup.WithContext(ctx) - for i, req := range reqs { - bkReq := req - reqIndex := i - pool.ApplyOnErrorGroup(eg, func() error { - retry := -1 - return utils.WithRetry(ectx, func() error { - retry += 1 - logutil.CL(ectx).Info("backup to store", zap.Uint64("storeID", storeID), - zap.Int("retry", retry), zap.Int("reqIndex", reqIndex)) - return doSendBackup(ectx, backupCli, bkReq, func(resp *backuppb.BackupResponse) error { - // Forward all responses (including error). - failpoint.Inject("backup-timeout-error", func(val failpoint.Value) { - msg := val.(string) - logutil.CL(ectx).Info("failpoint backup-timeout-error injected.", zap.String("msg", msg)) - resp.Error = &backuppb.Error{ - Msg: msg, - } - }) - failpoint.Inject("backup-storage-error", func(val failpoint.Value) { - msg := val.(string) - logutil.CL(ectx).Debug("failpoint backup-storage-error injected.", zap.String("msg", msg)) - resp.Error = &backuppb.Error{ - Msg: msg, - } - }) - failpoint.Inject("tikv-rw-error", func(val failpoint.Value) { - msg := val.(string) - logutil.CL(ectx).Debug("failpoint tikv-rw-error injected.", zap.String("msg", msg)) - resp.Error = &backuppb.Error{ - Msg: msg, - } - }) - failpoint.Inject("tikv-region-error", func(val failpoint.Value) { - msg := val.(string) - logutil.CL(ectx).Debug("failpoint tikv-region-error injected.", zap.String("msg", msg)) - resp.Error = &backuppb.Error{ - // Msg: msg, - Detail: &backuppb.Error_RegionError{ - RegionError: &errorpb.Error{ - Message: msg, - }, - }, - } - }) - select { - case <-ectx.Done(): - return ectx.Err() - case respCh <- &ResponseAndStore{ - Resp: resp, - StoreID: storeID, - }: - } - return nil - }) - }, utils.NewBackupSSTBackoffer()) - }) - } - return eg.Wait() - } -} - -func getBackupRanges(ranges []rtree.Range) []*kvrpcpb.KeyRange { - requestRanges := make([]*kvrpcpb.KeyRange, 0, len(ranges)) - for _, r := range ranges { - requestRanges = append(requestRanges, &kvrpcpb.KeyRange{ - StartKey: r.StartKey, - EndKey: r.EndKey, - }) - } - return requestRanges -} - -func ObserveStoreChangesAsync(ctx context.Context, stateNotifier chan BackupRetryPolicy, pdCli pd.Client) { - go func() { - sendAll := false - newJoinStoresMap := make(map[uint64]struct{}) - cb := storewatch.MakeCallback(storewatch.WithOnReboot(func(s *metapb.Store) { - sendAll = true - }), storewatch.WithOnDisconnect(func(s *metapb.Store) { - sendAll = true - }), storewatch.WithOnNewStoreRegistered(func(s *metapb.Store) { - // only backup for this store - newJoinStoresMap[s.Id] = struct{}{} - })) - - notifyFn := func(ctx context.Context, sendPolicy 
BackupRetryPolicy) { - select { - case <-ctx.Done(): - case stateNotifier <- sendPolicy: - } - } - - watcher := storewatch.New(pdCli, cb) - // make a first step, and make the state correct for next 30s check - err := watcher.Step(ctx) - if err != nil { - logutil.CL(ctx).Warn("failed to watch store changes at beginning, ignore it", zap.Error(err)) - } - tickInterval := 30 * time.Second - failpoint.Inject("backup-store-change-tick", func(val failpoint.Value) { - if val.(bool) { - tickInterval = 100 * time.Millisecond - } - logutil.CL(ctx).Info("failpoint backup-store-change-tick injected.", zap.Duration("interval", tickInterval)) - }) - tick := time.NewTicker(tickInterval) - for { - select { - case <-ctx.Done(): - return - case <-tick.C: - // reset the state - sendAll = false - clear(newJoinStoresMap) - logutil.CL(ctx).Info("check store changes every tick") - err := watcher.Step(ctx) - if err != nil { - logutil.CL(ctx).Warn("failed to watch store changes, ignore it", zap.Error(err)) - } - if sendAll { - logutil.CL(ctx).Info("detect some store(s) restarted or disconnected, notify with all stores") - notifyFn(ctx, BackupRetryPolicy{All: true}) - } else if len(newJoinStoresMap) > 0 { - for storeID := range newJoinStoresMap { - logutil.CL(ctx).Info("detect a new registered store, notify with this store", zap.Uint64("storeID", storeID)) - notifyFn(ctx, BackupRetryPolicy{One: storeID}) - } - } - } - } - }() -} - -func SplitBackupReqRanges(req backuppb.BackupRequest, count uint) []backuppb.BackupRequest { - rangeCount := len(req.SubRanges) - if rangeCount == 0 { - return []backuppb.BackupRequest{req} - } - splitRequests := make([]backuppb.BackupRequest, 0, count) - if count <= 1 { - // 0/1 means no need to split, just send one batch request - return []backuppb.BackupRequest{req} - } - splitStep := rangeCount / int(count) - if splitStep == 0 { - // splitStep should be at least 1 - // if count >= rangeCount, means no batch, split them all - splitStep = 1 - } - subRanges := req.SubRanges - for i := 0; i < rangeCount; i += splitStep { - splitReq := req - splitReq.SubRanges = subRanges[i:min(i+splitStep, rangeCount)] - splitRequests = append(splitRequests, splitReq) - } - return splitRequests -} diff --git a/br/pkg/checkpoint/binding__failpoint_binding__.go b/br/pkg/checkpoint/binding__failpoint_binding__.go deleted file mode 100644 index 8590a12a60919..0000000000000 --- a/br/pkg/checkpoint/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package checkpoint - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/br/pkg/checkpoint/checkpoint.go b/br/pkg/checkpoint/checkpoint.go index ab438f2635665..4b397a60e5eeb 100644 --- a/br/pkg/checkpoint/checkpoint.go +++ b/br/pkg/checkpoint/checkpoint.go @@ -391,7 +391,7 @@ func (r *CheckpointRunner[K, V]) startCheckpointMainLoop( tickDurationForChecksum, tickDurationForLock time.Duration, ) { - if _, _err_ := failpoint.Eval(_curpkg_("checkpoint-more-quickly-flush")); _err_ == nil { + failpoint.Inject("checkpoint-more-quickly-flush", func(_ failpoint.Value) { tickDurationForChecksum = 1 * time.Second tickDurationForFlush = 3 * time.Second if tickDurationForLock > 0 { @@ -402,7 +402,7 @@ func (r *CheckpointRunner[K, V]) startCheckpointMainLoop( 
zap.Duration("checksum", tickDurationForChecksum), zap.Duration("lock", tickDurationForLock), ) - } + }) r.wg.Add(1) checkpointLoop := func(ctx context.Context) { defer r.wg.Done() @@ -506,9 +506,9 @@ func (r *CheckpointRunner[K, V]) doChecksumFlush(ctx context.Context, checksumIt return errors.Annotatef(err, "failed to write file %s for checkpoint checksum", fname) } - if _, _err_ := failpoint.Eval(_curpkg_("failed-after-checkpoint-flushes-checksum")); _err_ == nil { - return errors.Errorf("failpoint: failed after checkpoint flushes checksum") - } + failpoint.Inject("failed-after-checkpoint-flushes-checksum", func(_ failpoint.Value) { + failpoint.Return(errors.Errorf("failpoint: failed after checkpoint flushes checksum")) + }) return nil } @@ -570,9 +570,9 @@ func (r *CheckpointRunner[K, V]) doFlush(ctx context.Context, meta map[K]*RangeG } } - if _, _err_ := failpoint.Eval(_curpkg_("failed-after-checkpoint-flushes")); _err_ == nil { - return errors.Errorf("failpoint: failed after checkpoint flushes") - } + failpoint.Inject("failed-after-checkpoint-flushes", func(_ failpoint.Value) { + failpoint.Return(errors.Errorf("failpoint: failed after checkpoint flushes")) + }) return nil } @@ -663,9 +663,9 @@ func (r *CheckpointRunner[K, V]) updateLock(ctx context.Context) error { return errors.Trace(err) } - if _, _err_ := failpoint.Eval(_curpkg_("failed-after-checkpoint-updates-lock")); _err_ == nil { - return errors.Errorf("failpoint: failed after checkpoint updates lock") - } + failpoint.Inject("failed-after-checkpoint-updates-lock", func(_ failpoint.Value) { + failpoint.Return(errors.Errorf("failpoint: failed after checkpoint updates lock")) + }) return nil } diff --git a/br/pkg/checkpoint/checkpoint.go__failpoint_stash__ b/br/pkg/checkpoint/checkpoint.go__failpoint_stash__ deleted file mode 100644 index 4b397a60e5eeb..0000000000000 --- a/br/pkg/checkpoint/checkpoint.go__failpoint_stash__ +++ /dev/null @@ -1,872 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package checkpoint - -import ( - "bytes" - "context" - "crypto/sha256" - "encoding/base64" - "encoding/json" - "fmt" - "math/rand" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - backuppb "github.com/pingcap/kvproto/pkg/brpb" - "github.com/pingcap/log" - "github.com/pingcap/tidb/br/pkg/logutil" - "github.com/pingcap/tidb/br/pkg/metautil" - "github.com/pingcap/tidb/br/pkg/rtree" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/br/pkg/summary" - "github.com/pingcap/tidb/br/pkg/utils" - "github.com/pingcap/tidb/pkg/util" - "github.com/tikv/client-go/v2/oracle" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" -) - -const CheckpointDir = "checkpoints" - -type flushPosition struct { - CheckpointDataDir string - CheckpointChecksumDir string - CheckpointLockPath string -} - -const MaxChecksumTotalCost float64 = 60.0 - -const defaultTickDurationForFlush = 30 * time.Second - -const defaultTckDurationForChecksum = 5 * time.Second - -const defaultTickDurationForLock = 4 * time.Minute - -const lockTimeToLive = 5 * time.Minute - -type KeyType interface { - ~BackupKeyType | ~RestoreKeyType -} - -type RangeType struct { - *rtree.Range -} - -func (r RangeType) IdentKey() []byte { - return r.StartKey -} - -type ValueType interface { - IdentKey() []byte -} - -type CheckpointMessage[K KeyType, V ValueType] struct { - // start-key of the origin range - GroupKey K - - Group []V -} - -// A Checkpoint Range File is like this: -// -// CheckpointData -// +----------------+ RangeGroupData RangeGroup -// | DureTime | +--------------------------+ encrypted +--------------------+ -// | RangeGroupData-+---> | RangeGroupsEncriptedData-+----------> | GroupKey/TableID | -// | RangeGroupData | | Checksum | | Range | -// | ... | | CipherIv | | ... | -// | RangeGroupData | | Size | | Range | -// +----------------+ +--------------------------+ +--------------------+ -// -// For restore, because there is no group key, so there is only one RangeGroupData -// with multi-ranges in the ChecksumData. - -type RangeGroup[K KeyType, V ValueType] struct { - GroupKey K `json:"group-key"` - Group []V `json:"groups"` -} - -type RangeGroupData struct { - RangeGroupsEncriptedData []byte - Checksum []byte - CipherIv []byte - - Size int -} - -type CheckpointData struct { - DureTime time.Duration `json:"dure-time"` - RangeGroupMetas []*RangeGroupData `json:"range-group-metas"` -} - -// A Checkpoint Checksum File is like this: -// -// ChecksumInfo ChecksumItems ChecksumItem -// +------------+ +--------------+ +--------------+ -// | Content--+--> | ChecksumItem-+---> | TableID | -// | Checksum | | ChecksumItem | | Crc64xor | -// | DureTime | | ... 
| | TotalKvs | -// +------------+ | ChecksumItem | | TotalBytes | -// +--------------+ +--------------+ - -type ChecksumItem struct { - TableID int64 `json:"table-id"` - Crc64xor uint64 `json:"crc64-xor"` - TotalKvs uint64 `json:"total-kvs"` - TotalBytes uint64 `json:"total-bytes"` -} - -type ChecksumItems struct { - Items []*ChecksumItem `json:"checksum-items"` -} - -type ChecksumInfo struct { - Content []byte `json:"content"` - Checksum []byte `json:"checksum"` - DureTime time.Duration `json:"dure-time"` -} - -type GlobalTimer interface { - GetTS(context.Context) (int64, int64, error) -} - -type CheckpointRunner[K KeyType, V ValueType] struct { - flushPosition - lockId uint64 - - meta map[K]*RangeGroup[K, V] - checksum ChecksumItems - - valueMarshaler func(*RangeGroup[K, V]) ([]byte, error) - - storage storage.ExternalStorage - cipher *backuppb.CipherInfo - timer GlobalTimer - - appendCh chan *CheckpointMessage[K, V] - checksumCh chan *ChecksumItem - doneCh chan bool - metaCh chan map[K]*RangeGroup[K, V] - checksumMetaCh chan ChecksumItems - lockCh chan struct{} - errCh chan error - err error - errLock sync.RWMutex - - wg sync.WaitGroup -} - -func newCheckpointRunner[K KeyType, V ValueType]( - ctx context.Context, - storage storage.ExternalStorage, - cipher *backuppb.CipherInfo, - timer GlobalTimer, - f flushPosition, - vm func(*RangeGroup[K, V]) ([]byte, error), -) *CheckpointRunner[K, V] { - return &CheckpointRunner[K, V]{ - flushPosition: f, - - meta: make(map[K]*RangeGroup[K, V]), - checksum: ChecksumItems{Items: make([]*ChecksumItem, 0)}, - - valueMarshaler: vm, - - storage: storage, - cipher: cipher, - timer: timer, - - appendCh: make(chan *CheckpointMessage[K, V]), - checksumCh: make(chan *ChecksumItem), - doneCh: make(chan bool, 1), - metaCh: make(chan map[K]*RangeGroup[K, V]), - checksumMetaCh: make(chan ChecksumItems), - lockCh: make(chan struct{}), - errCh: make(chan error, 1), - err: nil, - } -} - -func (r *CheckpointRunner[K, V]) FlushChecksum( - ctx context.Context, - tableID int64, - crc64xor uint64, - totalKvs uint64, - totalBytes uint64, -) error { - checksumItem := &ChecksumItem{ - TableID: tableID, - Crc64xor: crc64xor, - TotalKvs: totalKvs, - TotalBytes: totalBytes, - } - return r.FlushChecksumItem(ctx, checksumItem) -} - -func (r *CheckpointRunner[K, V]) FlushChecksumItem( - ctx context.Context, - checksumItem *ChecksumItem, -) error { - select { - case <-ctx.Done(): - return errors.Annotatef(ctx.Err(), "failed to append checkpoint checksum item") - case err, ok := <-r.errCh: - if !ok { - r.errLock.RLock() - err = r.err - r.errLock.RUnlock() - return errors.Annotate(err, "[checkpoint] Checksum: failed to append checkpoint checksum item") - } - return err - case r.checksumCh <- checksumItem: - return nil - } -} - -func (r *CheckpointRunner[K, V]) Append( - ctx context.Context, - message *CheckpointMessage[K, V], -) error { - select { - case <-ctx.Done(): - return errors.Annotatef(ctx.Err(), "failed to append checkpoint message") - case err, ok := <-r.errCh: - if !ok { - r.errLock.RLock() - err = r.err - r.errLock.RUnlock() - return errors.Annotate(err, "[checkpoint] Append: failed to append checkpoint message") - } - return err - case r.appendCh <- message: - return nil - } -} - -// Note: Cannot be parallel with `Append` function -func (r *CheckpointRunner[K, V]) WaitForFinish(ctx context.Context, flush bool) { - if r.doneCh != nil { - select { - case r.doneCh <- flush: - - default: - log.Warn("not the first close the checkpoint runner", zap.String("category", 
"checkpoint")) - } - } - // wait the range flusher exit - r.wg.Wait() - // remove the checkpoint lock - if r.lockId > 0 { - err := r.storage.DeleteFile(ctx, r.CheckpointLockPath) - if err != nil { - log.Warn("failed to remove the checkpoint lock", zap.Error(err)) - } - } -} - -// Send the checksum to the flush goroutine, and reset the CheckpointRunner's checksum -func (r *CheckpointRunner[K, V]) flushChecksum(ctx context.Context, errCh chan error) error { - checksum := ChecksumItems{ - Items: r.checksum.Items, - } - r.checksum.Items = make([]*ChecksumItem, 0) - // do flush - select { - case <-ctx.Done(): - return ctx.Err() - case err := <-errCh: - return err - case r.checksumMetaCh <- checksum: - } - return nil -} - -// Send the meta to the flush goroutine, and reset the CheckpointRunner's meta -func (r *CheckpointRunner[K, V]) flushMeta(ctx context.Context, errCh chan error) error { - meta := r.meta - r.meta = make(map[K]*RangeGroup[K, V]) - // do flush - select { - case <-ctx.Done(): - return ctx.Err() - case err := <-errCh: - return err - case r.metaCh <- meta: - } - return nil -} - -func (r *CheckpointRunner[K, V]) setLock(ctx context.Context, errCh chan error) error { - select { - case <-ctx.Done(): - return ctx.Err() - case err := <-errCh: - return err - case r.lockCh <- struct{}{}: - } - return nil -} - -// start a goroutine to flush the meta, which is sent from `checkpoint looper`, to the external storage -func (r *CheckpointRunner[K, V]) startCheckpointFlushLoop(ctx context.Context, wg *sync.WaitGroup) chan error { - errCh := make(chan error, 1) - wg.Add(1) - flushWorker := func(ctx context.Context, errCh chan error) { - defer wg.Done() - for { - select { - case <-ctx.Done(): - if err := ctx.Err(); err != nil { - errCh <- err - } - return - case meta, ok := <-r.metaCh: - if !ok { - log.Info("stop checkpoint flush worker") - return - } - if err := r.doFlush(ctx, meta); err != nil { - errCh <- errors.Annotate(err, "failed to flush checkpoint data.") - return - } - case checksums, ok := <-r.checksumMetaCh: - if !ok { - log.Info("stop checkpoint flush worker") - return - } - if err := r.doChecksumFlush(ctx, checksums); err != nil { - errCh <- errors.Annotate(err, "failed to flush checkpoint checksum.") - return - } - case _, ok := <-r.lockCh: - if !ok { - log.Info("stop checkpoint flush worker") - return - } - if err := r.updateLock(ctx); err != nil { - errCh <- errors.Annotate(err, "failed to update checkpoint lock.") - return - } - } - } - } - - go flushWorker(ctx, errCh) - return errCh -} - -func (r *CheckpointRunner[K, V]) sendError(err error) { - select { - case r.errCh <- err: - log.Error("send the error", zap.String("category", "checkpoint"), zap.Error(err)) - r.errLock.Lock() - r.err = err - r.errLock.Unlock() - close(r.errCh) - default: - log.Error("errCh is blocked", logutil.ShortError(err)) - } -} - -func (r *CheckpointRunner[K, V]) startCheckpointMainLoop( - ctx context.Context, - tickDurationForFlush, - tickDurationForChecksum, - tickDurationForLock time.Duration, -) { - failpoint.Inject("checkpoint-more-quickly-flush", func(_ failpoint.Value) { - tickDurationForChecksum = 1 * time.Second - tickDurationForFlush = 3 * time.Second - if tickDurationForLock > 0 { - tickDurationForLock = 1 * time.Second - } - log.Info("adjust the tick duration for flush or lock", - zap.Duration("flush", tickDurationForFlush), - zap.Duration("checksum", tickDurationForChecksum), - zap.Duration("lock", tickDurationForLock), - ) - }) - r.wg.Add(1) - checkpointLoop := func(ctx context.Context) { - 
defer r.wg.Done() - cctx, cancel := context.WithCancel(ctx) - defer cancel() - var wg sync.WaitGroup - errCh := r.startCheckpointFlushLoop(cctx, &wg) - flushTicker := time.NewTicker(tickDurationForFlush) - defer flushTicker.Stop() - checksumTicker := time.NewTicker(tickDurationForChecksum) - defer checksumTicker.Stop() - // register time ticker, the lock ticker is optional - lockTicker := dispatcherTicker(tickDurationForLock) - defer lockTicker.Stop() - for { - select { - case <-ctx.Done(): - if err := ctx.Err(); err != nil { - r.sendError(err) - } - return - case <-lockTicker.Ch(): - if err := r.setLock(ctx, errCh); err != nil { - r.sendError(err) - return - } - case <-checksumTicker.C: - if err := r.flushChecksum(ctx, errCh); err != nil { - r.sendError(err) - return - } - case <-flushTicker.C: - if err := r.flushMeta(ctx, errCh); err != nil { - r.sendError(err) - return - } - case msg := <-r.appendCh: - groups, exist := r.meta[msg.GroupKey] - if !exist { - groups = &RangeGroup[K, V]{ - GroupKey: msg.GroupKey, - Group: make([]V, 0), - } - r.meta[msg.GroupKey] = groups - } - groups.Group = append(groups.Group, msg.Group...) - case msg := <-r.checksumCh: - r.checksum.Items = append(r.checksum.Items, msg) - case flush := <-r.doneCh: - log.Info("stop checkpoint runner") - if flush { - // NOTE: the exit step, don't send error any more. - if err := r.flushMeta(ctx, errCh); err != nil { - log.Error("failed to flush checkpoint meta", zap.Error(err)) - } else if err := r.flushChecksum(ctx, errCh); err != nil { - log.Error("failed to flush checkpoint checksum", zap.Error(err)) - } - } - // close the channel to flush worker - // and wait it to consumes all the metas - close(r.metaCh) - close(r.checksumMetaCh) - close(r.lockCh) - wg.Wait() - return - case err := <-errCh: - // pass flush worker's error back - r.sendError(err) - return - } - } - } - - go checkpointLoop(ctx) -} - -// flush the checksum to the external storage -func (r *CheckpointRunner[K, V]) doChecksumFlush(ctx context.Context, checksumItems ChecksumItems) error { - if len(checksumItems.Items) == 0 { - return nil - } - content, err := json.Marshal(checksumItems) - if err != nil { - return errors.Trace(err) - } - - checksum := sha256.Sum256(content) - checksumInfo := &ChecksumInfo{ - Content: content, - Checksum: checksum[:], - DureTime: summary.NowDureTime(), - } - - data, err := json.Marshal(checksumInfo) - if err != nil { - return errors.Trace(err) - } - - fname := fmt.Sprintf("%s/t%d_and__.cpt", r.CheckpointChecksumDir, checksumItems.Items[0].TableID) - if err = r.storage.WriteFile(ctx, fname, data); err != nil { - return errors.Annotatef(err, "failed to write file %s for checkpoint checksum", fname) - } - - failpoint.Inject("failed-after-checkpoint-flushes-checksum", func(_ failpoint.Value) { - failpoint.Return(errors.Errorf("failpoint: failed after checkpoint flushes checksum")) - }) - return nil -} - -// flush the meta to the external storage -func (r *CheckpointRunner[K, V]) doFlush(ctx context.Context, meta map[K]*RangeGroup[K, V]) error { - if len(meta) == 0 { - return nil - } - - checkpointData := &CheckpointData{ - DureTime: summary.NowDureTime(), - RangeGroupMetas: make([]*RangeGroupData, 0, len(meta)), - } - - var fname []byte = nil - - for _, group := range meta { - if len(group.Group) == 0 { - continue - } - - // use the first item's group-key and sub-range-key as the filename - if len(fname) == 0 { - fname = append([]byte(fmt.Sprint(group.GroupKey, '.', '.')), group.Group[0].IdentKey()...) 
- } - - // Flush the metaFile to storage - content, err := r.valueMarshaler(group) - if err != nil { - return errors.Trace(err) - } - - encryptBuff, iv, err := metautil.Encrypt(content, r.cipher) - if err != nil { - return errors.Trace(err) - } - - checksum := sha256.Sum256(content) - - checkpointData.RangeGroupMetas = append(checkpointData.RangeGroupMetas, &RangeGroupData{ - RangeGroupsEncriptedData: encryptBuff, - Checksum: checksum[:], - Size: len(content), - CipherIv: iv, - }) - } - - if len(checkpointData.RangeGroupMetas) > 0 { - data, err := json.Marshal(checkpointData) - if err != nil { - return errors.Trace(err) - } - - checksum := sha256.Sum256(fname) - checksumEncoded := base64.URLEncoding.EncodeToString(checksum[:]) - path := fmt.Sprintf("%s/%s_%d.cpt", r.CheckpointDataDir, checksumEncoded, rand.Uint64()) - if err := r.storage.WriteFile(ctx, path, data); err != nil { - return errors.Trace(err) - } - } - - failpoint.Inject("failed-after-checkpoint-flushes", func(_ failpoint.Value) { - failpoint.Return(errors.Errorf("failpoint: failed after checkpoint flushes")) - }) - return nil -} - -type CheckpointLock struct { - LockId uint64 `json:"lock-id"` - ExpireAt int64 `json:"expire-at"` -} - -// get ts with retry -func (r *CheckpointRunner[K, V]) getTS(ctx context.Context) (int64, int64, error) { - var ( - p int64 = 0 - l int64 = 0 - retry int = 0 - ) - errRetry := utils.WithRetry(ctx, func() error { - var err error - p, l, err = r.timer.GetTS(ctx) - if err != nil { - retry++ - log.Info("failed to get ts", zap.Int("retry", retry), zap.Error(err)) - return err - } - - return nil - }, utils.NewPDReqBackoffer()) - - return p, l, errors.Trace(errRetry) -} - -// flush the lock to the external storage -func (r *CheckpointRunner[K, V]) flushLock(ctx context.Context, p int64) error { - lock := &CheckpointLock{ - LockId: r.lockId, - ExpireAt: p + lockTimeToLive.Milliseconds(), - } - log.Info("start to flush the checkpoint lock", zap.Int64("lock-at", p), - zap.Int64("expire-at", lock.ExpireAt)) - data, err := json.Marshal(lock) - if err != nil { - return errors.Trace(err) - } - - err = r.storage.WriteFile(ctx, r.CheckpointLockPath, data) - return errors.Trace(err) -} - -// check whether this lock belongs to this BR -func (r *CheckpointRunner[K, V]) checkLockFile(ctx context.Context, now int64) error { - data, err := r.storage.ReadFile(ctx, r.CheckpointLockPath) - if err != nil { - return errors.Trace(err) - } - lock := &CheckpointLock{} - err = json.Unmarshal(data, lock) - if err != nil { - return errors.Trace(err) - } - if lock.ExpireAt <= now { - if lock.LockId > r.lockId { - return errors.Errorf("There are another BR(%d) running after but setting lock before this one(%d). "+ - "Please check whether the BR is running. If not, you can retry.", lock.LockId, r.lockId) - } - if lock.LockId == r.lockId { - log.Warn("The lock has expired.", - zap.Int64("expire-at(ms)", lock.ExpireAt), zap.Int64("now(ms)", now)) - } - } else if lock.LockId != r.lockId { - return errors.Errorf("The existing lock will expire in %d seconds. "+ - "There may be another BR(%d) running. 
If not, you can wait for the lock to expire, "+ - "or delete the file `%s%s` manually.", - (lock.ExpireAt-now)/1000, lock.LockId, strings.TrimRight(r.storage.URI(), "/"), r.CheckpointLockPath) - } - - return nil -} - -// generate a new lock and flush the lock to the external storage -func (r *CheckpointRunner[K, V]) updateLock(ctx context.Context) error { - p, _, err := r.getTS(ctx) - if err != nil { - return errors.Trace(err) - } - if err = r.checkLockFile(ctx, p); err != nil { - return errors.Trace(err) - } - if err = r.flushLock(ctx, p); err != nil { - return errors.Trace(err) - } - - failpoint.Inject("failed-after-checkpoint-updates-lock", func(_ failpoint.Value) { - failpoint.Return(errors.Errorf("failpoint: failed after checkpoint updates lock")) - }) - - return nil -} - -// Attempt to initialize the lock. Need to stop the backup when there is an unexpired locks. -func (r *CheckpointRunner[K, V]) initialLock(ctx context.Context) error { - p, l, err := r.getTS(ctx) - if err != nil { - return errors.Trace(err) - } - r.lockId = oracle.ComposeTS(p, l) - exist, err := r.storage.FileExists(ctx, r.CheckpointLockPath) - if err != nil { - return errors.Trace(err) - } - if exist { - if err := r.checkLockFile(ctx, p); err != nil { - return errors.Trace(err) - } - } - if err = r.flushLock(ctx, p); err != nil { - return errors.Trace(err) - } - - // wait for 3 seconds to check whether the lock file is overwritten by another BR - time.Sleep(3 * time.Second) - err = r.checkLockFile(ctx, p) - return errors.Trace(err) -} - -// walk the whole checkpoint range files and retrieve the metadata of backed up/restored ranges -// and return the total time cost in the past executions -func walkCheckpointFile[K KeyType, V ValueType]( - ctx context.Context, - s storage.ExternalStorage, - cipher *backuppb.CipherInfo, - subDir string, - fn func(groupKey K, value V), -) (time.Duration, error) { - // records the total time cost in the past executions - var pastDureTime time.Duration = 0 - err := s.WalkDir(ctx, &storage.WalkOption{SubDir: subDir}, func(path string, size int64) error { - if strings.HasSuffix(path, ".cpt") { - content, err := s.ReadFile(ctx, path) - if err != nil { - return errors.Trace(err) - } - - checkpointData := &CheckpointData{} - if err = json.Unmarshal(content, checkpointData); err != nil { - log.Error("failed to unmarshal the checkpoint data info, skip it", zap.Error(err)) - return nil - } - - if checkpointData.DureTime > pastDureTime { - pastDureTime = checkpointData.DureTime - } - for _, meta := range checkpointData.RangeGroupMetas { - decryptContent, err := metautil.Decrypt(meta.RangeGroupsEncriptedData, cipher, meta.CipherIv) - if err != nil { - return errors.Trace(err) - } - - checksum := sha256.Sum256(decryptContent) - if !bytes.Equal(meta.Checksum, checksum[:]) { - log.Error("checkpoint checksum info's checksum mismatch, skip it", - zap.ByteString("expect", meta.Checksum), - zap.ByteString("got", checksum[:]), - ) - continue - } - - group := &RangeGroup[K, V]{} - if err = json.Unmarshal(decryptContent, group); err != nil { - return errors.Trace(err) - } - - for _, g := range group.Group { - fn(group.GroupKey, g) - } - } - } - return nil - }) - - return pastDureTime, errors.Trace(err) -} - -// load checkpoint meta data from external storage and unmarshal back -func loadCheckpointMeta[T any](ctx context.Context, s storage.ExternalStorage, path string, m *T) error { - data, err := s.ReadFile(ctx, path) - if err != nil { - return errors.Trace(err) - } - - err = json.Unmarshal(data, m) - 
return errors.Trace(err) -} - -// walk the whole checkpoint checksum files and retrieve checksum information of tables calculated -func loadCheckpointChecksum( - ctx context.Context, - s storage.ExternalStorage, - subDir string, -) (map[int64]*ChecksumItem, time.Duration, error) { - var pastDureTime time.Duration = 0 - checkpointChecksum := make(map[int64]*ChecksumItem) - err := s.WalkDir(ctx, &storage.WalkOption{SubDir: subDir}, func(path string, size int64) error { - data, err := s.ReadFile(ctx, path) - if err != nil { - return errors.Trace(err) - } - info := &ChecksumInfo{} - err = json.Unmarshal(data, info) - if err != nil { - log.Error("failed to unmarshal the checkpoint checksum info, skip it", zap.Error(err)) - return nil - } - - checksum := sha256.Sum256(info.Content) - if !bytes.Equal(info.Checksum, checksum[:]) { - log.Error("checkpoint checksum info's checksum mismatch, skip it", - zap.ByteString("expect", info.Checksum), - zap.ByteString("got", checksum[:]), - ) - return nil - } - - if info.DureTime > pastDureTime { - pastDureTime = info.DureTime - } - - items := &ChecksumItems{} - err = json.Unmarshal(info.Content, items) - if err != nil { - return errors.Trace(err) - } - - for _, c := range items.Items { - checkpointChecksum[c.TableID] = c - } - return nil - }) - return checkpointChecksum, pastDureTime, errors.Trace(err) -} - -func saveCheckpointMetadata[T any](ctx context.Context, s storage.ExternalStorage, meta *T, path string) error { - data, err := json.Marshal(meta) - if err != nil { - return errors.Trace(err) - } - - err = s.WriteFile(ctx, path, data) - return errors.Trace(err) -} - -func removeCheckpointData(ctx context.Context, s storage.ExternalStorage, subDir string) error { - var ( - // Generate one file every 30 seconds, so there are only 1200 files in 10 hours. 
- removedFileNames = make([]string, 0, 1200) - - removeCnt int = 0 - removeSize int64 = 0 - ) - err := s.WalkDir(ctx, &storage.WalkOption{SubDir: subDir}, func(path string, size int64) error { - if !strings.HasSuffix(path, ".cpt") && !strings.HasSuffix(path, ".meta") && !strings.HasSuffix(path, ".lock") { - return nil - } - removedFileNames = append(removedFileNames, path) - removeCnt += 1 - removeSize += size - return nil - }) - if err != nil { - return errors.Trace(err) - } - log.Info("start to remove checkpoint data", - zap.String("checkpoint task", subDir), - zap.Int("remove-count", removeCnt), - zap.Int64("remove-size", removeSize), - ) - - maxFailedFilesNum := int64(16) - var failedFilesCount atomic.Int64 - pool := util.NewWorkerPool(4, "checkpoint remove worker") - eg, gCtx := errgroup.WithContext(ctx) - for _, filename := range removedFileNames { - name := filename - pool.ApplyOnErrorGroup(eg, func() error { - if err := s.DeleteFile(gCtx, name); err != nil { - log.Warn("failed to remove the file", zap.String("filename", name), zap.Error(err)) - if failedFilesCount.Add(1) >= maxFailedFilesNum { - return errors.Annotate(err, "failed to delete too many files") - } - } - return nil - }) - } - if err := eg.Wait(); err != nil { - return errors.Trace(err) - } - log.Info("all the checkpoint data has been removed", zap.String("checkpoint task", subDir)) - return nil -} diff --git a/br/pkg/checksum/binding__failpoint_binding__.go b/br/pkg/checksum/binding__failpoint_binding__.go deleted file mode 100644 index c63a7388ac3ba..0000000000000 --- a/br/pkg/checksum/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package checksum - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/br/pkg/checksum/executor.go b/br/pkg/checksum/executor.go index 58c9ebc56a3cc..22f5a13d23a65 100644 --- a/br/pkg/checksum/executor.go +++ b/br/pkg/checksum/executor.go @@ -387,12 +387,12 @@ func (exec *Executor) Execute( vars.BackOffWeight = exec.backoffWeight } resp, err = sendChecksumRequest(ctx, client, req, vars) - if val, _err_ := failpoint.Eval(_curpkg_("checksumRetryErr")); _err_ == nil { + failpoint.Inject("checksumRetryErr", func(val failpoint.Value) { // first time reach here. return error if val.(bool) { err = errors.New("inject checksum error") } - } + }) if err != nil { return errors.Trace(err) } diff --git a/br/pkg/checksum/executor.go__failpoint_stash__ b/br/pkg/checksum/executor.go__failpoint_stash__ deleted file mode 100644 index 22f5a13d23a65..0000000000000 --- a/br/pkg/checksum/executor.go__failpoint_stash__ +++ /dev/null @@ -1,419 +0,0 @@ -// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. 
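// Illustrative sketch, not part of this patch: the "checksumRetryErr" marker restored
// above is meant to fail only the first checksum attempt ("first time reach here"),
// letting the surrounding retry loop recover. A count-limited failpoint term expresses
// that; the path and the elided Execute call are hypothetical.
package checksum_test

import (
	"testing"

	"github.com/pingcap/failpoint"
)

func TestChecksumRetriesPastInjectedError(t *testing.T) {
	const fp = "github.com/pingcap/tidb/br/pkg/checksum/checksumRetryErr"
	// Fire exactly once, then evaluate to nothing on later attempts.
	if err := failpoint.Enable(fp, "1*return(true)"); err != nil {
		t.Fatal(err)
	}
	defer func() {
		_ = failpoint.Disable(fp)
	}()
	// ... build an Executor and call Execute; it should succeed after one retry ...
}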
- -package checksum - -import ( - "context" - - "github.com/gogo/protobuf/proto" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/log" - "github.com/pingcap/tidb/br/pkg/metautil" - "github.com/pingcap/tidb/br/pkg/utils" - "github.com/pingcap/tidb/pkg/distsql" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/util/ranger" - "github.com/pingcap/tipb/go-tipb" - "go.uber.org/zap" -) - -// ExecutorBuilder is used to build a "kv.Request". -type ExecutorBuilder struct { - table *model.TableInfo - ts uint64 - - oldTable *metautil.Table - - concurrency uint - backoffWeight int - - oldKeyspace []byte - newKeyspace []byte - - resourceGroupName string - explicitRequestSourceType string -} - -// NewExecutorBuilder returns a new executor builder. -func NewExecutorBuilder(table *model.TableInfo, ts uint64) *ExecutorBuilder { - return &ExecutorBuilder{ - table: table, - ts: ts, - - concurrency: variable.DefDistSQLScanConcurrency, - } -} - -// SetOldTable set a old table info to the builder. -func (builder *ExecutorBuilder) SetOldTable(oldTable *metautil.Table) *ExecutorBuilder { - builder.oldTable = oldTable - return builder -} - -// SetConcurrency set the concurrency of the checksum executing. -func (builder *ExecutorBuilder) SetConcurrency(conc uint) *ExecutorBuilder { - builder.concurrency = conc - return builder -} - -// SetBackoffWeight set the backoffWeight of the checksum executing. -func (builder *ExecutorBuilder) SetBackoffWeight(backoffWeight int) *ExecutorBuilder { - builder.backoffWeight = backoffWeight - return builder -} - -func (builder *ExecutorBuilder) SetOldKeyspace(keyspace []byte) *ExecutorBuilder { - builder.oldKeyspace = keyspace - return builder -} - -func (builder *ExecutorBuilder) SetNewKeyspace(keyspace []byte) *ExecutorBuilder { - builder.newKeyspace = keyspace - return builder -} - -func (builder *ExecutorBuilder) SetResourceGroupName(name string) *ExecutorBuilder { - builder.resourceGroupName = name - return builder -} - -func (builder *ExecutorBuilder) SetExplicitRequestSourceType(name string) *ExecutorBuilder { - builder.explicitRequestSourceType = name - return builder -} - -// Build builds a checksum executor. 
-func (builder *ExecutorBuilder) Build() (*Executor, error) { - reqs, err := buildChecksumRequest( - builder.table, - builder.oldTable, - builder.ts, - builder.concurrency, - builder.oldKeyspace, - builder.newKeyspace, - builder.resourceGroupName, - builder.explicitRequestSourceType, - ) - if err != nil { - return nil, errors.Trace(err) - } - return &Executor{reqs: reqs, backoffWeight: builder.backoffWeight}, nil -} - -func buildChecksumRequest( - newTable *model.TableInfo, - oldTable *metautil.Table, - startTS uint64, - concurrency uint, - oldKeyspace []byte, - newKeyspace []byte, - resourceGroupName, explicitRequestSourceType string, -) ([]*kv.Request, error) { - var partDefs []model.PartitionDefinition - if part := newTable.Partition; part != nil { - partDefs = part.Definitions - } - - reqs := make([]*kv.Request, 0, (len(newTable.Indices)+1)*(len(partDefs)+1)) - var oldTableID int64 - if oldTable != nil { - oldTableID = oldTable.Info.ID - } - rs, err := buildRequest(newTable, newTable.ID, oldTable, oldTableID, startTS, concurrency, - oldKeyspace, newKeyspace, resourceGroupName, explicitRequestSourceType) - if err != nil { - return nil, errors.Trace(err) - } - reqs = append(reqs, rs...) - - for _, partDef := range partDefs { - var oldPartID int64 - if oldTable != nil { - for _, oldPartDef := range oldTable.Info.Partition.Definitions { - if oldPartDef.Name == partDef.Name { - oldPartID = oldPartDef.ID - } - } - } - rs, err := buildRequest(newTable, partDef.ID, oldTable, oldPartID, startTS, concurrency, - oldKeyspace, newKeyspace, resourceGroupName, explicitRequestSourceType) - if err != nil { - return nil, errors.Trace(err) - } - reqs = append(reqs, rs...) - } - - return reqs, nil -} - -func buildRequest( - tableInfo *model.TableInfo, - tableID int64, - oldTable *metautil.Table, - oldTableID int64, - startTS uint64, - concurrency uint, - oldKeyspace []byte, - newKeyspace []byte, - resourceGroupName, explicitRequestSourceType string, -) ([]*kv.Request, error) { - reqs := make([]*kv.Request, 0) - req, err := buildTableRequest(tableInfo, tableID, oldTable, oldTableID, startTS, concurrency, - oldKeyspace, newKeyspace, resourceGroupName, explicitRequestSourceType) - if err != nil { - return nil, errors.Trace(err) - } - reqs = append(reqs, req) - - for _, indexInfo := range tableInfo.Indices { - if indexInfo.State != model.StatePublic { - continue - } - var oldIndexInfo *model.IndexInfo - if oldTable != nil { - for _, oldIndex := range oldTable.Info.Indices { - if oldIndex.Name == indexInfo.Name { - oldIndexInfo = oldIndex - break - } - } - if oldIndexInfo == nil { - log.Panic("index not found in origin table, "+ - "please check the restore table has the same index info with origin table", - zap.Int64("table id", tableID), - zap.Stringer("table name", tableInfo.Name), - zap.Int64("origin table id", oldTableID), - zap.Stringer("origin table name", oldTable.Info.Name), - zap.Stringer("index name", indexInfo.Name)) - } - } - req, err = buildIndexRequest( - tableID, indexInfo, oldTableID, oldIndexInfo, startTS, concurrency, - oldKeyspace, newKeyspace, resourceGroupName, explicitRequestSourceType) - if err != nil { - return nil, errors.Trace(err) - } - reqs = append(reqs, req) - } - - return reqs, nil -} - -func buildTableRequest( - tableInfo *model.TableInfo, - tableID int64, - oldTable *metautil.Table, - oldTableID int64, - startTS uint64, - concurrency uint, - oldKeyspace []byte, - newKeyspace []byte, - resourceGroupName, explicitRequestSourceType string, -) (*kv.Request, error) { - var rule 
*tipb.ChecksumRewriteRule - if oldTable != nil { - rule = &tipb.ChecksumRewriteRule{ - OldPrefix: append(append([]byte{}, oldKeyspace...), tablecodec.GenTableRecordPrefix(oldTableID)...), - NewPrefix: append(append([]byte{}, newKeyspace...), tablecodec.GenTableRecordPrefix(tableID)...), - } - } - - checksum := &tipb.ChecksumRequest{ - ScanOn: tipb.ChecksumScanOn_Table, - Algorithm: tipb.ChecksumAlgorithm_Crc64_Xor, - Rule: rule, - } - - var ranges []*ranger.Range - if tableInfo.IsCommonHandle { - ranges = ranger.FullNotNullRange() - } else { - ranges = ranger.FullIntRange(false) - } - - var builder distsql.RequestBuilder - // Use low priority to reducing impact to other requests. - builder.Request.Priority = kv.PriorityLow - return builder.SetHandleRanges(nil, tableID, tableInfo.IsCommonHandle, ranges). - SetStartTS(startTS). - SetChecksumRequest(checksum). - SetConcurrency(int(concurrency)). - SetResourceGroupName(resourceGroupName). - SetExplicitRequestSourceType(explicitRequestSourceType). - Build() -} - -func buildIndexRequest( - tableID int64, - indexInfo *model.IndexInfo, - oldTableID int64, - oldIndexInfo *model.IndexInfo, - startTS uint64, - concurrency uint, - oldKeyspace []byte, - newKeyspace []byte, - resourceGroupName, ExplicitRequestSourceType string, -) (*kv.Request, error) { - var rule *tipb.ChecksumRewriteRule - if oldIndexInfo != nil { - rule = &tipb.ChecksumRewriteRule{ - OldPrefix: append(append([]byte{}, oldKeyspace...), - tablecodec.EncodeTableIndexPrefix(oldTableID, oldIndexInfo.ID)...), - NewPrefix: append(append([]byte{}, newKeyspace...), - tablecodec.EncodeTableIndexPrefix(tableID, indexInfo.ID)...), - } - } - checksum := &tipb.ChecksumRequest{ - ScanOn: tipb.ChecksumScanOn_Index, - Algorithm: tipb.ChecksumAlgorithm_Crc64_Xor, - Rule: rule, - } - - ranges := ranger.FullRange() - - var builder distsql.RequestBuilder - // Use low priority to reducing impact to other requests. - builder.Request.Priority = kv.PriorityLow - return builder.SetIndexRanges(nil, tableID, indexInfo.ID, ranges). - SetStartTS(startTS). - SetChecksumRequest(checksum). - SetConcurrency(int(concurrency)). - SetResourceGroupName(resourceGroupName). - SetExplicitRequestSourceType(ExplicitRequestSourceType). - Build() -} - -func sendChecksumRequest( - ctx context.Context, client kv.Client, req *kv.Request, vars *kv.Variables, -) (resp *tipb.ChecksumResponse, err error) { - res, err := distsql.Checksum(ctx, client, req, vars) - if err != nil { - return nil, errors.Trace(err) - } - defer func() { - if err1 := res.Close(); err1 != nil { - err = err1 - } - }() - - resp = &tipb.ChecksumResponse{} - - for { - data, err := res.NextRaw(ctx) - if err != nil { - return nil, errors.Trace(err) - } - if data == nil { - break - } - checksum := &tipb.ChecksumResponse{} - if err = checksum.Unmarshal(data); err != nil { - return nil, errors.Trace(err) - } - updateChecksumResponse(resp, checksum) - } - - return resp, nil -} - -func updateChecksumResponse(resp, update *tipb.ChecksumResponse) { - resp.Checksum ^= update.Checksum - resp.TotalKvs += update.TotalKvs - resp.TotalBytes += update.TotalBytes -} - -// Executor is a checksum executor. -type Executor struct { - reqs []*kv.Request - backoffWeight int -} - -// Len returns the total number of checksum requests. -func (exec *Executor) Len() int { - return len(exec.reqs) -} - -// Each executes the function to each requests in the executor. 
-func (exec *Executor) Each(f func(*kv.Request) error) error { - for _, req := range exec.reqs { - err := f(req) - if err != nil { - return errors.Trace(err) - } - } - return nil -} - -// RawRequests extracts the raw requests associated with this executor. -// This is mainly used for debugging only. -func (exec *Executor) RawRequests() ([]*tipb.ChecksumRequest, error) { - res := make([]*tipb.ChecksumRequest, 0, len(exec.reqs)) - for _, req := range exec.reqs { - rawReq := new(tipb.ChecksumRequest) - if err := proto.Unmarshal(req.Data, rawReq); err != nil { - return nil, errors.Trace(err) - } - res = append(res, rawReq) - } - return res, nil -} - -// Execute executes a checksum executor. -func (exec *Executor) Execute( - ctx context.Context, - client kv.Client, - updateFn func(), -) (*tipb.ChecksumResponse, error) { - checksumResp := &tipb.ChecksumResponse{} - checksumBackoffer := utils.InitialRetryState(utils.ChecksumRetryTime, - utils.ChecksumWaitInterval, utils.ChecksumMaxWaitInterval) - for _, req := range exec.reqs { - // Pointer to SessionVars.Killed - // Killed is a flag to indicate that this query is killed. - // - // It is useful in TiDB, however, it's a place holder in BR. - killed := uint32(0) - var ( - resp *tipb.ChecksumResponse - err error - ) - err = utils.WithRetry(ctx, func() error { - vars := kv.NewVariables(&killed) - if exec.backoffWeight > 0 { - vars.BackOffWeight = exec.backoffWeight - } - resp, err = sendChecksumRequest(ctx, client, req, vars) - failpoint.Inject("checksumRetryErr", func(val failpoint.Value) { - // first time reach here. return error - if val.(bool) { - err = errors.New("inject checksum error") - } - }) - if err != nil { - return errors.Trace(err) - } - return nil - }, &checksumBackoffer) - if err != nil { - return nil, errors.Trace(err) - } - updateChecksumResponse(checksumResp, resp) - updateFn() - } - return checksumResp, checkContextDone(ctx) -} - -// The coprocessor won't return the error if the context is done, -// so sometimes BR would get the incomplete result. -// checkContextDone makes sure the result is not affected by CONTEXT DONE. 
-func checkContextDone(ctx context.Context) error { - ctxErr := ctx.Err() - if ctxErr != nil { - return errors.Annotate(ctxErr, "context is cancelled by other error") - } - return nil -} diff --git a/br/pkg/conn/binding__failpoint_binding__.go b/br/pkg/conn/binding__failpoint_binding__.go deleted file mode 100644 index 195eaae166265..0000000000000 --- a/br/pkg/conn/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package conn - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/br/pkg/conn/conn.go b/br/pkg/conn/conn.go index 29cb9dafa12ba..cdb81a011c8a5 100644 --- a/br/pkg/conn/conn.go +++ b/br/pkg/conn/conn.go @@ -87,29 +87,29 @@ func GetAllTiKVStoresWithRetry(ctx context.Context, ctx, func() error { stores, err = util.GetAllTiKVStores(ctx, pdClient, storeBehavior) - if val, _err_ := failpoint.Eval(_curpkg_("hint-GetAllTiKVStores-error")); _err_ == nil { + failpoint.Inject("hint-GetAllTiKVStores-error", func(val failpoint.Value) { logutil.CL(ctx).Debug("failpoint hint-GetAllTiKVStores-error injected.") if val.(bool) { err = status.Error(codes.Unknown, "Retryable error") - return err + failpoint.Return(err) } - } + }) - if val, _err_ := failpoint.Eval(_curpkg_("hint-GetAllTiKVStores-grpc-cancel")); _err_ == nil { + failpoint.Inject("hint-GetAllTiKVStores-grpc-cancel", func(val failpoint.Value) { logutil.CL(ctx).Debug("failpoint hint-GetAllTiKVStores-grpc-cancel injected.") if val.(bool) { err = status.Error(codes.Canceled, "Cancel Retry") - return err + failpoint.Return(err) } - } + }) - if val, _err_ := failpoint.Eval(_curpkg_("hint-GetAllTiKVStores-ctx-cancel")); _err_ == nil { + failpoint.Inject("hint-GetAllTiKVStores-ctx-cancel", func(val failpoint.Value) { logutil.CL(ctx).Debug("failpoint hint-GetAllTiKVStores-ctx-cancel injected.") if val.(bool) { err = context.Canceled - return err + failpoint.Return(err) } - } + }) return errors.Trace(err) }, diff --git a/br/pkg/conn/conn.go__failpoint_stash__ b/br/pkg/conn/conn.go__failpoint_stash__ deleted file mode 100644 index cdb81a011c8a5..0000000000000 --- a/br/pkg/conn/conn.go__failpoint_stash__ +++ /dev/null @@ -1,457 +0,0 @@ -// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. 
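// Illustrative sketch, not part of this patch: the three markers restored in
// GetAllTiKVStoresWithRetry above inject a retryable gRPC error (Unknown), a gRPC
// cancellation, and a context cancellation so the retry wrapper's error classification
// can be exercised. This classifier is a simplified stand-in for BR's actual backoffer
// logic, shown only to make that distinction concrete.
package conn_test

import (
	"context"
	"errors"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// retryable reports whether an error is worth another attempt: cancellations stop
// immediately, while transient codes keep the retry loop going.
func retryable(err error) bool {
	if err == nil || errors.Is(err, context.Canceled) {
		return false
	}
	if s, ok := status.FromError(err); ok {
		switch s.Code() {
		case codes.Canceled:
			return false
		case codes.Unavailable, codes.Unknown:
			return true
		}
	}
	return false
}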
- -package conn - -import ( - "context" - "crypto/tls" - "fmt" - "io" - "net" - "net/http" - "net/url" - "strings" - - "github.com/docker/go-units" - "github.com/opentracing/opentracing-go" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - backuppb "github.com/pingcap/kvproto/pkg/brpb" - logbackup "github.com/pingcap/kvproto/pkg/logbackuppb" - "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/log" - kvconfig "github.com/pingcap/tidb/br/pkg/config" - "github.com/pingcap/tidb/br/pkg/conn/util" - berrors "github.com/pingcap/tidb/br/pkg/errors" - "github.com/pingcap/tidb/br/pkg/glue" - "github.com/pingcap/tidb/br/pkg/logutil" - "github.com/pingcap/tidb/br/pkg/pdutil" - "github.com/pingcap/tidb/br/pkg/utils" - "github.com/pingcap/tidb/br/pkg/version" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/kv" - "github.com/tikv/client-go/v2/oracle" - "github.com/tikv/client-go/v2/tikv" - "github.com/tikv/client-go/v2/txnkv/txnlock" - pd "github.com/tikv/pd/client" - "go.uber.org/zap" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/keepalive" - "google.golang.org/grpc/status" -) - -const ( - // DefaultMergeRegionSizeBytes is the default region split size, 96MB. - // See https://github.com/tikv/tikv/blob/v4.0.8/components/raftstore/src/coprocessor/config.rs#L35-L38 - DefaultMergeRegionSizeBytes uint64 = 96 * units.MiB - - // DefaultMergeRegionKeyCount is the default region key count, 960000. - DefaultMergeRegionKeyCount uint64 = 960000 - - // DefaultImportNumGoroutines is the default number of threads for import. - // use 128 as default value, which is 8 times of the default value of tidb. - // we think is proper for IO-bound cases. - DefaultImportNumGoroutines uint = 128 -) - -type VersionCheckerType int - -const ( - // default version checker - NormalVersionChecker VersionCheckerType = iota - // version checker for PiTR - StreamVersionChecker -) - -// Mgr manages connections to a TiDB cluster. -type Mgr struct { - *pdutil.PdController - dom *domain.Domain - storage kv.Storage // Used to access SQL related interfaces. - tikvStore tikv.Storage // Used to access TiKV specific interfaces. 
- ownsStorage bool - - *utils.StoreManager -} - -func GetAllTiKVStoresWithRetry(ctx context.Context, - pdClient util.StoreMeta, - storeBehavior util.StoreBehavior, -) ([]*metapb.Store, error) { - stores := make([]*metapb.Store, 0) - var err error - - errRetry := utils.WithRetry( - ctx, - func() error { - stores, err = util.GetAllTiKVStores(ctx, pdClient, storeBehavior) - failpoint.Inject("hint-GetAllTiKVStores-error", func(val failpoint.Value) { - logutil.CL(ctx).Debug("failpoint hint-GetAllTiKVStores-error injected.") - if val.(bool) { - err = status.Error(codes.Unknown, "Retryable error") - failpoint.Return(err) - } - }) - - failpoint.Inject("hint-GetAllTiKVStores-grpc-cancel", func(val failpoint.Value) { - logutil.CL(ctx).Debug("failpoint hint-GetAllTiKVStores-grpc-cancel injected.") - if val.(bool) { - err = status.Error(codes.Canceled, "Cancel Retry") - failpoint.Return(err) - } - }) - - failpoint.Inject("hint-GetAllTiKVStores-ctx-cancel", func(val failpoint.Value) { - logutil.CL(ctx).Debug("failpoint hint-GetAllTiKVStores-ctx-cancel injected.") - if val.(bool) { - err = context.Canceled - failpoint.Return(err) - } - }) - - return errors.Trace(err) - }, - utils.NewPDReqBackoffer(), - ) - - return stores, errors.Trace(errRetry) -} - -func checkStoresAlive(ctx context.Context, - pdclient pd.Client, - storeBehavior util.StoreBehavior) error { - // Check live tikv. - stores, err := util.GetAllTiKVStores(ctx, pdclient, storeBehavior) - if err != nil { - log.Error("failed to get store", zap.Error(err)) - return errors.Trace(err) - } - - liveStoreCount := 0 - for _, s := range stores { - if s.GetState() != metapb.StoreState_Up { - continue - } - liveStoreCount++ - } - log.Info("checked alive KV stores", zap.Int("aliveStores", liveStoreCount), zap.Int("totalStores", len(stores))) - return nil -} - -// NewMgr creates a new Mgr. -// -// Domain is optional for Backup, set `needDomain` to false to disable -// initializing Domain. -func NewMgr( - ctx context.Context, - g glue.Glue, - pdAddrs []string, - tlsConf *tls.Config, - securityOption pd.SecurityOption, - keepalive keepalive.ClientParameters, - storeBehavior util.StoreBehavior, - checkRequirements bool, - needDomain bool, - versionCheckerType VersionCheckerType, -) (*Mgr, error) { - if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { - span1 := span.Tracer().StartSpan("conn.NewMgr", opentracing.ChildOf(span.Context())) - defer span1.Finish() - ctx = opentracing.ContextWithSpan(ctx, span1) - } - - log.Info("new mgr", zap.Strings("pdAddrs", pdAddrs)) - - controller, err := pdutil.NewPdController(ctx, pdAddrs, tlsConf, securityOption) - if err != nil { - log.Error("failed to create pd controller", zap.Error(err)) - return nil, errors.Trace(err) - } - if checkRequirements { - var checker version.VerChecker - switch versionCheckerType { - case NormalVersionChecker: - checker = version.CheckVersionForBR - case StreamVersionChecker: - checker = version.CheckVersionForBRPiTR - default: - return nil, errors.Errorf("unknown command type, comman code is %d", versionCheckerType) - } - err = version.CheckClusterVersion(ctx, controller.GetPDClient(), checker) - if err != nil { - return nil, errors.Annotate(err, "running BR in incompatible version of cluster, "+ - "if you believe it's OK, use --check-requirements=false to skip.") - } - } - - err = checkStoresAlive(ctx, controller.GetPDClient(), storeBehavior) - if err != nil { - return nil, errors.Trace(err) - } - - // Disable GC because TiDB enables GC already. 
- path := fmt.Sprintf( - "tikv://%s?disableGC=true&keyspaceName=%s", - strings.Join(pdAddrs, ","), config.GetGlobalKeyspaceName(), - ) - storage, err := g.Open(path, securityOption) - if err != nil { - return nil, errors.Trace(err) - } - - tikvStorage, ok := storage.(tikv.Storage) - if !ok { - return nil, berrors.ErrKVNotTiKV - } - - var dom *domain.Domain - if needDomain { - dom, err = g.GetDomain(storage) - if err != nil { - return nil, errors.Trace(err) - } - // we must check tidb(tikv version) any time after concurrent ddl feature implemented in v6.2. - // we will keep this check until 7.0, which allow the breaking changes. - // NOTE: must call it after domain created! - // FIXME: remove this check in v7.0 - err = version.CheckClusterVersion(ctx, controller.GetPDClient(), version.CheckVersionForDDL) - if err != nil { - return nil, errors.Annotate(err, "unable to check cluster version for ddl") - } - } - - mgr := &Mgr{ - PdController: controller, - storage: storage, - tikvStore: tikvStorage, - dom: dom, - ownsStorage: g.OwnsStorage(), - StoreManager: utils.NewStoreManager(controller.GetPDClient(), keepalive, tlsConf), - } - return mgr, nil -} - -// GetBackupClient get or create a backup client. -func (mgr *Mgr) GetBackupClient(ctx context.Context, storeID uint64) (backuppb.BackupClient, error) { - var cli backuppb.BackupClient - if err := mgr.WithConn(ctx, storeID, func(cc *grpc.ClientConn) { - cli = backuppb.NewBackupClient(cc) - }); err != nil { - return nil, err - } - return cli, nil -} - -func (mgr *Mgr) GetLogBackupClient(ctx context.Context, storeID uint64) (logbackup.LogBackupClient, error) { - var cli logbackup.LogBackupClient - if err := mgr.WithConn(ctx, storeID, func(cc *grpc.ClientConn) { - cli = logbackup.NewLogBackupClient(cc) - }); err != nil { - return nil, err - } - return cli, nil -} - -// GetStorage returns a kv storage. -func (mgr *Mgr) GetStorage() kv.Storage { - return mgr.storage -} - -// GetTLSConfig returns the tls config. -func (mgr *Mgr) GetTLSConfig() *tls.Config { - return mgr.StoreManager.TLSConfig() -} - -// GetStore gets the tikvStore. -func (mgr *Mgr) GetStore() tikv.Storage { - return mgr.tikvStore -} - -// GetLockResolver gets the LockResolver. -func (mgr *Mgr) GetLockResolver() *txnlock.LockResolver { - return mgr.tikvStore.GetLockResolver() -} - -// GetDomain returns a tikv storage. -func (mgr *Mgr) GetDomain() *domain.Domain { - return mgr.dom -} - -func (mgr *Mgr) Close() { - if mgr.StoreManager != nil { - mgr.StoreManager.Close() - } - // Gracefully shutdown domain so it does not affect other TiDB DDL. - // Must close domain before closing storage, otherwise it gets stuck forever. - if mgr.ownsStorage { - if mgr.dom != nil { - mgr.dom.Close() - } - tikv.StoreShuttingDown(1) - _ = mgr.storage.Close() - } - - mgr.PdController.Close() -} - -// GetTS gets current ts from pd. -func (mgr *Mgr) GetTS(ctx context.Context) (uint64, error) { - p, l, err := mgr.GetPDClient().GetTS(ctx) - if err != nil { - return 0, errors.Trace(err) - } - - return oracle.ComposeTS(p, l), nil -} - -// ProcessTiKVConfigs handle the tikv config for region split size, region split keys, and import goroutines in place. -// It retrieves the config from all alive tikv stores and returns the minimum values. -// If retrieving the config fails, it returns the default config values. 
-func (mgr *Mgr) ProcessTiKVConfigs(ctx context.Context, cfg *kvconfig.KVConfig, client *http.Client) { - mergeRegionSize := cfg.MergeRegionSize - mergeRegionKeyCount := cfg.MergeRegionKeyCount - importGoroutines := cfg.ImportGoroutines - - if mergeRegionSize.Modified && mergeRegionKeyCount.Modified && importGoroutines.Modified { - log.Info("no need to retrieve the config from tikv if user has set the config") - return - } - - err := mgr.GetConfigFromTiKV(ctx, client, func(resp *http.Response) error { - respBytes, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - if !mergeRegionSize.Modified || !mergeRegionKeyCount.Modified { - size, keys, e := kvconfig.ParseMergeRegionSizeFromConfig(respBytes) - if e != nil { - log.Warn("Failed to parse region split size and keys from config", logutil.ShortError(e)) - return e - } - if mergeRegionKeyCount.Value == DefaultMergeRegionKeyCount || keys < mergeRegionKeyCount.Value { - mergeRegionSize.Value = size - mergeRegionKeyCount.Value = keys - } - } - if !importGoroutines.Modified { - threads, e := kvconfig.ParseImportThreadsFromConfig(respBytes) - if e != nil { - log.Warn("Failed to parse import num-threads from config", logutil.ShortError(e)) - return e - } - // We use 8 times the default value because it's an IO-bound case. - if importGoroutines.Value == DefaultImportNumGoroutines || (threads > 0 && threads*8 < importGoroutines.Value) { - importGoroutines.Value = threads * 8 - } - } - // replace the value - cfg.MergeRegionSize = mergeRegionSize - cfg.MergeRegionKeyCount = mergeRegionKeyCount - cfg.ImportGoroutines = importGoroutines - return nil - }) - - if err != nil { - log.Warn("Failed to get config from TiKV; using default", logutil.ShortError(err)) - } -} - -// IsLogBackupEnabled is used for br to check whether tikv has enabled log backup. -func (mgr *Mgr) IsLogBackupEnabled(ctx context.Context, client *http.Client) (bool, error) { - logbackupEnable := true - err := mgr.GetConfigFromTiKV(ctx, client, func(resp *http.Response) error { - respBytes, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - enable, err := kvconfig.ParseLogBackupEnableFromConfig(respBytes) - if err != nil { - log.Warn("Failed to parse log-backup enable from config", logutil.ShortError(err)) - return err - } - logbackupEnable = logbackupEnable && enable - return nil - }) - return logbackupEnable, errors.Trace(err) -} - -// GetConfigFromTiKV get configs from all alive tikv stores. -func (mgr *Mgr) GetConfigFromTiKV(ctx context.Context, cli *http.Client, fn func(*http.Response) error) error { - allStores, err := GetAllTiKVStoresWithRetry(ctx, mgr.GetPDClient(), util.SkipTiFlash) - if err != nil { - return errors.Trace(err) - } - - httpPrefix := "http://" - if mgr.GetTLSConfig() != nil { - httpPrefix = "https://" - } - - for _, store := range allStores { - if store.State != metapb.StoreState_Up { - continue - } - // we need make sure every available store support backup-stream otherwise we might lose data. 
- // so check every store's config - addr, err := handleTiKVAddress(store, httpPrefix) - if err != nil { - return err - } - configAddr := fmt.Sprintf("%s/config", addr.String()) - - err = utils.WithRetry(ctx, func() error { - resp, e := cli.Get(configAddr) - if e != nil { - return e - } - defer resp.Body.Close() - err = fn(resp) - if err != nil { - return err - } - return nil - }, utils.NewPDReqBackoffer()) - if err != nil { - // if one store failed, break and return error - return err - } - } - return nil -} - -func handleTiKVAddress(store *metapb.Store, httpPrefix string) (*url.URL, error) { - statusAddr := store.GetStatusAddress() - nodeAddr := store.GetAddress() - if !strings.HasPrefix(statusAddr, "http") { - statusAddr = httpPrefix + statusAddr - } - if !strings.HasPrefix(nodeAddr, "http") { - nodeAddr = httpPrefix + nodeAddr - } - - statusUrl, err := url.Parse(statusAddr) - if err != nil { - return nil, err - } - nodeUrl, err := url.Parse(nodeAddr) - if err != nil { - return nil, err - } - - // we try status address as default - addr := statusUrl - // but in sometimes we may not get the correct status address from PD. - if statusUrl.Hostname() != nodeUrl.Hostname() { - // if not matched, we use the address as default, but change the port - addr.Host = net.JoinHostPort(nodeUrl.Hostname(), statusUrl.Port()) - log.Warn("store address and status address mismatch the host, we will use the store address as hostname", - zap.Uint64("store", store.Id), - zap.String("status address", statusAddr), - zap.String("node address", nodeAddr), - zap.Any("request address", statusUrl), - ) - } - return addr, nil -} diff --git a/br/pkg/pdutil/binding__failpoint_binding__.go b/br/pkg/pdutil/binding__failpoint_binding__.go deleted file mode 100644 index 35b536059c6a3..0000000000000 --- a/br/pkg/pdutil/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package pdutil - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/br/pkg/pdutil/pd.go b/br/pkg/pdutil/pd.go index 4e7ec9c0a614a..31f20cd8af433 100644 --- a/br/pkg/pdutil/pd.go +++ b/br/pkg/pdutil/pd.go @@ -212,12 +212,12 @@ func parseVersion(versionStr string) *semver.Version { zap.String("version", versionStr), zap.Error(err)) version = &semver.Version{Major: 0, Minor: 0, Patch: 0} } - if val, _err_ := failpoint.Eval(_curpkg_("PDEnabledPauseConfig")); _err_ == nil { + failpoint.Inject("PDEnabledPauseConfig", func(val failpoint.Value) { if val.(bool) { // test pause config is enable version = &semver.Version{Major: 5, Minor: 0, Patch: 0} } - } + }) return version } diff --git a/br/pkg/pdutil/pd.go__failpoint_stash__ b/br/pkg/pdutil/pd.go__failpoint_stash__ deleted file mode 100644 index 31f20cd8af433..0000000000000 --- a/br/pkg/pdutil/pd.go__failpoint_stash__ +++ /dev/null @@ -1,782 +0,0 @@ -// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. 
- -package pdutil - -import ( - "context" - "crypto/tls" - "encoding/hex" - "fmt" - "math" - "net/http" - "strings" - "time" - - "github.com/coreos/go-semver/semver" - "github.com/docker/go-units" - "github.com/google/uuid" - "github.com/opentracing/opentracing-go" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/log" - berrors "github.com/pingcap/tidb/br/pkg/errors" - "github.com/pingcap/tidb/pkg/util/codec" - pd "github.com/tikv/pd/client" - pdhttp "github.com/tikv/pd/client/http" - "github.com/tikv/pd/client/retry" - "go.uber.org/zap" - "google.golang.org/grpc" -) - -const ( - maxMsgSize = int(128 * units.MiB) // pd.ScanRegion may return a large response - pauseTimeout = 5 * time.Minute - // pd request retry time when connection fail - PDRequestRetryTime = 120 - // set max-pending-peer-count to a large value to avoid scatter region failed. - maxPendingPeerUnlimited uint64 = math.MaxInt32 -) - -// pauseConfigGenerator generate a config value according to store count and current value. -type pauseConfigGenerator func(int, any) any - -// zeroPauseConfig sets the config to 0. -func zeroPauseConfig(int, any) any { - return 0 -} - -// pauseConfigMulStores multiplies the existing value by -// number of stores. The value is limited to 40, as larger value -// may make the cluster unstable. -func pauseConfigMulStores(stores int, raw any) any { - rawCfg := raw.(float64) - return math.Min(40, rawCfg*float64(stores)) -} - -// pauseConfigFalse sets the config to "false". -func pauseConfigFalse(int, any) any { - return "false" -} - -// constConfigGeneratorBuilder build a pauseConfigGenerator based on a given const value. -func constConfigGeneratorBuilder(val any) pauseConfigGenerator { - return func(int, any) any { - return val - } -} - -// ClusterConfig represents a set of scheduler whose config have been modified -// along with their original config. -type ClusterConfig struct { - // Enable PD schedulers before restore - Schedulers []string `json:"schedulers"` - // Original scheudle configuration - ScheduleCfg map[string]any `json:"schedule_cfg"` -} - -type pauseSchedulerBody struct { - Delay int64 `json:"delay"` -} - -var ( - // in v4.0.8 version we can use pause configs - // see https://github.com/tikv/pd/pull/3088 - pauseConfigVersion = semver.Version{Major: 4, Minor: 0, Patch: 8} - - // After v6.1.0 version, we can pause schedulers by key range with TTL. - minVersionForRegionLabelTTL = semver.Version{Major: 6, Minor: 1, Patch: 0} - - // Schedulers represent region/leader schedulers which can impact on performance. - Schedulers = map[string]struct{}{ - "balance-leader-scheduler": {}, - "balance-hot-region-scheduler": {}, - "balance-region-scheduler": {}, - - "shuffle-leader-scheduler": {}, - "shuffle-region-scheduler": {}, - "shuffle-hot-region-scheduler": {}, - - "evict-slow-store-scheduler": {}, - } - expectPDCfgGenerators = map[string]pauseConfigGenerator{ - "merge-schedule-limit": zeroPauseConfig, - // TODO "leader-schedule-limit" and "region-schedule-limit" don't support ttl for now, - // but we still need set these config for compatible with old version. - // we need wait for https://github.com/tikv/pd/pull/3131 merged. 
- // see details https://github.com/pingcap/br/pull/592#discussion_r522684325 - "leader-schedule-limit": pauseConfigMulStores, - "region-schedule-limit": pauseConfigMulStores, - "max-snapshot-count": pauseConfigMulStores, - "enable-location-replacement": pauseConfigFalse, - "max-pending-peer-count": constConfigGeneratorBuilder(maxPendingPeerUnlimited), - } - - // defaultPDCfg find by https://github.com/tikv/pd/blob/master/conf/config.toml. - // only use for debug command. - defaultPDCfg = map[string]any{ - "merge-schedule-limit": 8, - "leader-schedule-limit": 4, - "region-schedule-limit": 2048, - "enable-location-replacement": "true", - } -) - -// DefaultExpectPDCfgGenerators returns default pd config generators -func DefaultExpectPDCfgGenerators() map[string]pauseConfigGenerator { - clone := make(map[string]pauseConfigGenerator, len(expectPDCfgGenerators)) - for k := range expectPDCfgGenerators { - clone[k] = expectPDCfgGenerators[k] - } - return clone -} - -// PdController manage get/update config from pd. -type PdController struct { - pdClient pd.Client - pdHTTPCli pdhttp.Client - version *semver.Version - - // control the pause schedulers goroutine - schedulerPauseCh chan struct{} - // control the ttl of pausing schedulers - SchedulerPauseTTL time.Duration -} - -// NewPdController creates a new PdController. -func NewPdController( - ctx context.Context, - pdAddrs []string, - tlsConf *tls.Config, - securityOption pd.SecurityOption, -) (*PdController, error) { - maxCallMsgSize := []grpc.DialOption{ - grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(maxMsgSize)), - grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(maxMsgSize)), - } - pdClient, err := pd.NewClientWithContext( - ctx, pdAddrs, securityOption, - pd.WithGRPCDialOptions(maxCallMsgSize...), - // If the time too short, we may scatter a region many times, because - // the interface `ScatterRegions` may time out. - pd.WithCustomTimeoutOption(60*time.Second), - ) - if err != nil { - log.Error("fail to create pd client", zap.Error(err)) - return nil, errors.Trace(err) - } - - pdHTTPCliConfig := make([]pdhttp.ClientOption, 0, 1) - if tlsConf != nil { - pdHTTPCliConfig = append(pdHTTPCliConfig, pdhttp.WithTLSConfig(tlsConf)) - } - pdHTTPCli := pdhttp.NewClientWithServiceDiscovery( - "br/lightning PD controller", - pdClient.GetServiceDiscovery(), - pdHTTPCliConfig..., - ).WithBackoffer(retry.InitialBackoffer(time.Second, time.Second, PDRequestRetryTime*time.Second)) - versionStr, err := pdHTTPCli.GetPDVersion(ctx) - if err != nil { - pdHTTPCli.Close() - pdClient.Close() - return nil, errors.Trace(err) - } - version := parseVersion(versionStr) - - return &PdController{ - pdClient: pdClient, - pdHTTPCli: pdHTTPCli, - version: version, - // We should make a buffered channel here otherwise when context canceled, - // gracefully shutdown will stick at resuming schedulers. 
- schedulerPauseCh: make(chan struct{}, 1), - }, nil -} - -func NewPdControllerWithPDClient(pdClient pd.Client, pdHTTPCli pdhttp.Client, v *semver.Version) *PdController { - return &PdController{ - pdClient: pdClient, - pdHTTPCli: pdHTTPCli, - version: v, - schedulerPauseCh: make(chan struct{}, 1), - } -} - -func parseVersion(versionStr string) *semver.Version { - // we need trim space or semver will parse failed - v := strings.TrimSpace(versionStr) - v = strings.Trim(v, "\"") - v = strings.TrimPrefix(v, "v") - version, err := semver.NewVersion(v) - if err != nil { - log.Warn("fail back to v0.0.0 version", - zap.String("version", versionStr), zap.Error(err)) - version = &semver.Version{Major: 0, Minor: 0, Patch: 0} - } - failpoint.Inject("PDEnabledPauseConfig", func(val failpoint.Value) { - if val.(bool) { - // test pause config is enable - version = &semver.Version{Major: 5, Minor: 0, Patch: 0} - } - }) - return version -} - -func (p *PdController) isPauseConfigEnabled() bool { - return p.version.Compare(pauseConfigVersion) >= 0 -} - -// SetPDClient set pd addrs and cli for test. -func (p *PdController) SetPDClient(pdClient pd.Client) { - p.pdClient = pdClient -} - -// GetPDClient set pd addrs and cli for test. -func (p *PdController) GetPDClient() pd.Client { - return p.pdClient -} - -// GetPDHTTPClient returns the pd http client. -func (p *PdController) GetPDHTTPClient() pdhttp.Client { - return p.pdHTTPCli -} - -// GetClusterVersion returns the current cluster version. -func (p *PdController) GetClusterVersion(ctx context.Context) (string, error) { - v, err := p.pdHTTPCli.GetClusterVersion(ctx) - return v, errors.Trace(err) -} - -// GetRegionCount returns the region count in the specified range. -func (p *PdController) GetRegionCount(ctx context.Context, startKey, endKey []byte) (int, error) { - // TiKV reports region start/end keys to PD in memcomparable-format. - var start, end []byte - start = codec.EncodeBytes(nil, startKey) - if len(endKey) != 0 { // Empty end key means the max. - end = codec.EncodeBytes(nil, endKey) - } - status, err := p.pdHTTPCli.GetRegionStatusByKeyRange(ctx, pdhttp.NewKeyRange(start, end), true) - if err != nil { - return 0, errors.Trace(err) - } - return status.Count, nil -} - -// GetStoreInfo returns the info of store with the specified id. -func (p *PdController) GetStoreInfo(ctx context.Context, storeID uint64) (*pdhttp.StoreInfo, error) { - info, err := p.pdHTTPCli.GetStore(ctx, storeID) - return info, errors.Trace(err) -} - -func (p *PdController) doPauseSchedulers( - ctx context.Context, - schedulers []string, -) ([]string, error) { - // pause this scheduler with 300 seconds - delay := int64(p.ttlOfPausing().Seconds()) - removedSchedulers := make([]string, 0, len(schedulers)) - for _, scheduler := range schedulers { - err := p.pdHTTPCli.SetSchedulerDelay(ctx, scheduler, delay) - if err != nil { - return removedSchedulers, errors.Trace(err) - } - removedSchedulers = append(removedSchedulers, scheduler) - } - return removedSchedulers, nil -} - -func (p *PdController) pauseSchedulersAndConfigWith( - ctx context.Context, schedulers []string, - schedulerCfg map[string]any, -) ([]string, error) { - // first pause this scheduler, if the first time failed. we should return the error - // so put first time out of for loop. and in for loop we could ignore other failed pause. 
- removedSchedulers, err := p.doPauseSchedulers(ctx, schedulers) - if err != nil { - log.Error("failed to pause scheduler at beginning", - zap.Strings("name", schedulers), zap.Error(err)) - return nil, errors.Trace(err) - } - log.Info("pause scheduler successful at beginning", zap.Strings("name", schedulers)) - if schedulerCfg != nil { - err = p.doPauseConfigs(ctx, schedulerCfg) - if err != nil { - log.Error("failed to pause config at beginning", - zap.Any("cfg", schedulerCfg), zap.Error(err)) - return nil, errors.Trace(err) - } - log.Info("pause configs successful at beginning", zap.Any("cfg", schedulerCfg)) - } - - go func() { - tick := time.NewTicker(p.ttlOfPausing() / 3) - defer tick.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-tick.C: - _, err := p.doPauseSchedulers(ctx, schedulers) - if err != nil { - log.Warn("pause scheduler failed, ignore it and wait next time pause", zap.Error(err)) - } - if schedulerCfg != nil { - err = p.doPauseConfigs(ctx, schedulerCfg) - if err != nil { - log.Warn("pause configs failed, ignore it and wait next time pause", zap.Error(err)) - } - } - log.Info("pause scheduler(configs)", zap.Strings("name", removedSchedulers), - zap.Any("cfg", schedulerCfg)) - case <-p.schedulerPauseCh: - log.Info("exit pause scheduler and configs successful") - return - } - } - }() - return removedSchedulers, nil -} - -// ResumeSchedulers resume pd scheduler. -func (p *PdController) ResumeSchedulers(ctx context.Context, schedulers []string) error { - return errors.Trace(p.resumeSchedulerWith(ctx, schedulers)) -} - -func (p *PdController) resumeSchedulerWith(ctx context.Context, schedulers []string) (err error) { - log.Info("resume scheduler", zap.Strings("schedulers", schedulers)) - p.schedulerPauseCh <- struct{}{} - - // 0 means stop pause. - delay := int64(0) - for _, scheduler := range schedulers { - err = p.pdHTTPCli.SetSchedulerDelay(ctx, scheduler, delay) - if err != nil { - log.Error("failed to resume scheduler after retry, you may reset this scheduler manually"+ - "or just wait this scheduler pause timeout", zap.String("scheduler", scheduler)) - } else { - log.Info("resume scheduler successful", zap.String("scheduler", scheduler)) - } - } - // no need to return error, because the pause will timeout. - return nil -} - -// ListSchedulers list all pd scheduler. -func (p *PdController) ListSchedulers(ctx context.Context) ([]string, error) { - s, err := p.pdHTTPCli.GetSchedulers(ctx) - return s, errors.Trace(err) -} - -// GetPDScheduleConfig returns PD schedule config value associated with the key. -// It returns nil if there is no such config item. -func (p *PdController) GetPDScheduleConfig(ctx context.Context) (map[string]any, error) { - cfg, err := p.pdHTTPCli.GetScheduleConfig(ctx) - return cfg, errors.Trace(err) -} - -// UpdatePDScheduleConfig updates PD schedule config value associated with the key. -func (p *PdController) UpdatePDScheduleConfig(ctx context.Context) error { - log.Info("update pd with default config", zap.Any("cfg", defaultPDCfg)) - return errors.Trace(p.doUpdatePDScheduleConfig(ctx, defaultPDCfg)) -} - -func (p *PdController) doUpdatePDScheduleConfig( - ctx context.Context, cfg map[string]any, ttlSeconds ...float64, -) error { - newCfg := make(map[string]any) - for k, v := range cfg { - // if we want use ttl, we need use config prefix first. - // which means cfg should transfer from "max-merge-region-keys" to "schedule.max-merge-region-keys". 
- sc := fmt.Sprintf("schedule.%s", k) - newCfg[sc] = v - } - - if err := p.pdHTTPCli.SetConfig(ctx, newCfg, ttlSeconds...); err != nil { - return errors.Annotatef( - berrors.ErrPDUpdateFailed, - "failed to update PD schedule config: %s", - err.Error(), - ) - } - return nil -} - -func (p *PdController) doPauseConfigs(ctx context.Context, cfg map[string]any) error { - // pause this scheduler with 300 seconds - return errors.Trace(p.doUpdatePDScheduleConfig(ctx, cfg, p.ttlOfPausing().Seconds())) -} - -func restoreSchedulers(ctx context.Context, pd *PdController, clusterCfg ClusterConfig, - configsNeedRestore map[string]pauseConfigGenerator) error { - if err := pd.ResumeSchedulers(ctx, clusterCfg.Schedulers); err != nil { - return errors.Annotate(err, "fail to add PD schedulers") - } - log.Info("restoring config", zap.Any("config", clusterCfg.ScheduleCfg)) - mergeCfg := make(map[string]any) - for cfgKey := range configsNeedRestore { - value := clusterCfg.ScheduleCfg[cfgKey] - if value == nil { - // Ignore non-exist config. - continue - } - mergeCfg[cfgKey] = value - } - - prefix := make([]float64, 0, 1) - if pd.isPauseConfigEnabled() { - // set config's ttl to zero, make temporary config invalid immediately. - prefix = append(prefix, 0) - } - // reset config with previous value. - if err := pd.doUpdatePDScheduleConfig(ctx, mergeCfg, prefix...); err != nil { - return errors.Annotate(err, "fail to update PD merge config") - } - return nil -} - -// MakeUndoFunctionByConfig return an UndoFunc based on specified ClusterConfig -func (p *PdController) MakeUndoFunctionByConfig(config ClusterConfig) UndoFunc { - return p.GenRestoreSchedulerFunc(config, expectPDCfgGenerators) -} - -// GenRestoreSchedulerFunc gen restore func -func (p *PdController) GenRestoreSchedulerFunc(config ClusterConfig, - configsNeedRestore map[string]pauseConfigGenerator) UndoFunc { - // todo: we only need config names, not a map[string]pauseConfigGenerator - restore := func(ctx context.Context) error { - return restoreSchedulers(ctx, p, config, configsNeedRestore) - } - return restore -} - -// RemoveSchedulers removes the schedulers that may slow down BR speed. -func (p *PdController) RemoveSchedulers(ctx context.Context) (undo UndoFunc, err error) { - undo = Nop - - origin, _, err1 := p.RemoveSchedulersWithOrigin(ctx) - if err1 != nil { - err = err1 - return - } - - undo = p.MakeUndoFunctionByConfig(ClusterConfig{Schedulers: origin.Schedulers, ScheduleCfg: origin.ScheduleCfg}) - return undo, errors.Trace(err) -} - -// RemoveSchedulersWithConfig removes the schedulers that may slow down BR speed. -func (p *PdController) RemoveSchedulersWithConfig( - ctx context.Context, -) (undo UndoFunc, config *ClusterConfig, err error) { - undo = Nop - - origin, _, err1 := p.RemoveSchedulersWithOrigin(ctx) - if err1 != nil { - err = err1 - return - } - - undo = p.MakeUndoFunctionByConfig(ClusterConfig{Schedulers: origin.Schedulers, ScheduleCfg: origin.ScheduleCfg}) - return undo, &origin, errors.Trace(err) -} - -// RemoveAllPDSchedulers pause pd scheduler during the snapshot backup and restore -func (p *PdController) RemoveAllPDSchedulers(ctx context.Context) (undo UndoFunc, err error) { - undo = Nop - - // during the backup, we shall stop all scheduler so that restore easy to implement - // during phase-2, pd is fresh and in recovering-mode(recovering-mark=true), there's no leader - // so there's no leader or region schedule initially. when phase-2 start force setting leaders, schedule may begin. 
- // we don't want pd do any leader or region schedule during this time, so we set those params to 0 - // before we force setting leaders - const enableTiKVSplitRegion = "enable-tikv-split-region" - scheduleLimitParams := []string{ - "hot-region-schedule-limit", - "leader-schedule-limit", - "merge-schedule-limit", - "region-schedule-limit", - "replica-schedule-limit", - enableTiKVSplitRegion, - } - pdConfigGenerators := DefaultExpectPDCfgGenerators() - for _, param := range scheduleLimitParams { - if param == enableTiKVSplitRegion { - pdConfigGenerators[param] = func(int, any) any { return false } - } else { - pdConfigGenerators[param] = func(int, any) any { return 0 } - } - } - - oldPDConfig, _, err1 := p.RemoveSchedulersWithConfigGenerator(ctx, pdConfigGenerators) - if err1 != nil { - err = err1 - return - } - - undo = p.GenRestoreSchedulerFunc(oldPDConfig, pdConfigGenerators) - return undo, errors.Trace(err) -} - -// RemoveSchedulersWithOrigin pause and remove br related schedule configs and return the origin and modified configs -func (p *PdController) RemoveSchedulersWithOrigin(ctx context.Context) ( - origin ClusterConfig, - modified ClusterConfig, - err error, -) { - origin, modified, err = p.RemoveSchedulersWithConfigGenerator(ctx, expectPDCfgGenerators) - err = errors.Trace(err) - return -} - -// RemoveSchedulersWithConfigGenerator pause scheduler with custom config generator -func (p *PdController) RemoveSchedulersWithConfigGenerator( - ctx context.Context, - pdConfigGenerators map[string]pauseConfigGenerator, -) (origin ClusterConfig, modified ClusterConfig, err error) { - if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { - span1 := span.Tracer().StartSpan("PdController.RemoveSchedulers", - opentracing.ChildOf(span.Context())) - defer span1.Finish() - ctx = opentracing.ContextWithSpan(ctx, span1) - } - - originCfg := ClusterConfig{} - removedCfg := ClusterConfig{} - stores, err := p.pdClient.GetAllStores(ctx) - if err != nil { - return originCfg, removedCfg, errors.Trace(err) - } - scheduleCfg, err := p.GetPDScheduleConfig(ctx) - if err != nil { - return originCfg, removedCfg, errors.Trace(err) - } - disablePDCfg := make(map[string]any, len(pdConfigGenerators)) - originPDCfg := make(map[string]any, len(pdConfigGenerators)) - for cfgKey, cfgValFunc := range pdConfigGenerators { - value, ok := scheduleCfg[cfgKey] - if !ok { - // Ignore non-exist config. - continue - } - disablePDCfg[cfgKey] = cfgValFunc(len(stores), value) - originPDCfg[cfgKey] = value - } - originCfg.ScheduleCfg = originPDCfg - removedCfg.ScheduleCfg = disablePDCfg - - log.Debug("saved PD config", zap.Any("config", scheduleCfg)) - - // Remove default PD scheduler that may affect restore process. 
- existSchedulers, err := p.ListSchedulers(ctx) - if err != nil { - return originCfg, removedCfg, errors.Trace(err) - } - needRemoveSchedulers := make([]string, 0, len(existSchedulers)) - for _, s := range existSchedulers { - if _, ok := Schedulers[s]; ok { - needRemoveSchedulers = append(needRemoveSchedulers, s) - } - } - - removedSchedulers, err := p.doRemoveSchedulersWith(ctx, needRemoveSchedulers, disablePDCfg) - if err != nil { - return originCfg, removedCfg, errors.Trace(err) - } - - originCfg.Schedulers = removedSchedulers - removedCfg.Schedulers = removedSchedulers - - return originCfg, removedCfg, nil -} - -// RemoveSchedulersWithCfg removes pd schedulers and configs with specified ClusterConfig -func (p *PdController) RemoveSchedulersWithCfg(ctx context.Context, removeCfg ClusterConfig) error { - _, err := p.doRemoveSchedulersWith(ctx, removeCfg.Schedulers, removeCfg.ScheduleCfg) - return errors.Trace(err) -} - -func (p *PdController) doRemoveSchedulersWith( - ctx context.Context, - needRemoveSchedulers []string, - disablePDCfg map[string]any, -) ([]string, error) { - if !p.isPauseConfigEnabled() { - return nil, errors.Errorf("pd version %s not support pause config, please upgrade", p.version.String()) - } - // after 4.0.8 we can set these config with TTL - s, err := p.pauseSchedulersAndConfigWith(ctx, needRemoveSchedulers, disablePDCfg) - return s, errors.Trace(err) -} - -// GetMinResolvedTS get min-resolved-ts from pd -func (p *PdController) GetMinResolvedTS(ctx context.Context) (uint64, error) { - ts, _, err := p.pdHTTPCli.GetMinResolvedTSByStoresIDs(ctx, nil) - return ts, errors.Trace(err) -} - -// RecoverBaseAllocID recover base alloc id -func (p *PdController) RecoverBaseAllocID(ctx context.Context, id uint64) error { - return errors.Trace(p.pdHTTPCli.ResetBaseAllocID(ctx, id)) -} - -// ResetTS reset current ts of pd -func (p *PdController) ResetTS(ctx context.Context, ts uint64) error { - // reset-ts of PD will never set ts < current pd ts - // we set force-use-larger=true to allow ts > current pd ts + 24h(on default) - err := p.pdHTTPCli.ResetTS(ctx, ts, true) - if err == nil { - return nil - } - if strings.Contains(err.Error(), http.StatusText(http.StatusForbidden)) { - log.Info("reset-ts returns with status forbidden, ignore") - return nil - } - return errors.Trace(err) -} - -// MarkRecovering mark pd into recovering -func (p *PdController) MarkRecovering(ctx context.Context) error { - return errors.Trace(p.pdHTTPCli.SetSnapshotRecoveringMark(ctx)) -} - -// UnmarkRecovering unmark pd recovering -func (p *PdController) UnmarkRecovering(ctx context.Context) error { - return errors.Trace(p.pdHTTPCli.DeleteSnapshotRecoveringMark(ctx)) -} - -// RegionLabel is the label of a region. This struct is partially copied from -// https://github.com/tikv/pd/blob/783d060861cef37c38cbdcab9777fe95c17907fe/server/schedule/labeler/rules.go#L31. -type RegionLabel struct { - Key string `json:"key"` - Value string `json:"value"` - TTL string `json:"ttl,omitempty"` - StartAt string `json:"start_at,omitempty"` -} - -// LabelRule is the rule to assign labels to a region. This struct is partially copied from -// https://github.com/tikv/pd/blob/783d060861cef37c38cbdcab9777fe95c17907fe/server/schedule/labeler/rules.go#L41. -type LabelRule struct { - ID string `json:"id"` - Labels []RegionLabel `json:"labels"` - RuleType string `json:"rule_type"` - Data any `json:"data"` -} - -// KeyRangeRule contains the start key and end key of the LabelRule. 
This struct is partially copied from -// https://github.com/tikv/pd/blob/783d060861cef37c38cbdcab9777fe95c17907fe/server/schedule/labeler/rules.go#L62. -type KeyRangeRule struct { - StartKeyHex string `json:"start_key"` // hex format start key, for marshal/unmarshal - EndKeyHex string `json:"end_key"` // hex format end key, for marshal/unmarshal -} - -// PauseSchedulersByKeyRange will pause schedulers for regions in the specific key range. -// This function will spawn a goroutine to keep pausing schedulers periodically until the context is done. -// The return done channel is used to notify the caller that the background goroutine is exited. -func PauseSchedulersByKeyRange( - ctx context.Context, - pdHTTPCli pdhttp.Client, - startKey, endKey []byte, -) (done <-chan struct{}, err error) { - done, err = pauseSchedulerByKeyRangeWithTTL(ctx, pdHTTPCli, startKey, endKey, pauseTimeout) - // Wait for the rule to take effect because the PD operator is processed asynchronously. - // To synchronize this, checking the operator status may not be enough. For details, see - // https://github.com/pingcap/tidb/issues/49477. - // Let's use two times default value of `patrol-region-interval` from PD configuration. - <-time.After(20 * time.Millisecond) - return done, errors.Trace(err) -} - -func pauseSchedulerByKeyRangeWithTTL( - ctx context.Context, - pdHTTPCli pdhttp.Client, - startKey, endKey []byte, - ttl time.Duration, -) (<-chan struct{}, error) { - rule := &pdhttp.LabelRule{ - ID: uuid.New().String(), - Labels: []pdhttp.RegionLabel{{ - Key: "schedule", - Value: "deny", - TTL: ttl.String(), - }}, - RuleType: "key-range", - // Data should be a list of KeyRangeRule when rule type is key-range. - // See https://github.com/tikv/pd/blob/783d060861cef37c38cbdcab9777fe95c17907fe/server/schedule/labeler/rules.go#L169. - Data: []KeyRangeRule{{ - StartKeyHex: hex.EncodeToString(startKey), - EndKeyHex: hex.EncodeToString(endKey), - }}, - } - done := make(chan struct{}) - - if err := pdHTTPCli.SetRegionLabelRule(ctx, rule); err != nil { - close(done) - return nil, errors.Trace(err) - } - - go func() { - defer close(done) - ticker := time.NewTicker(ttl / 3) - defer ticker.Stop() - loop: - for { - select { - case <-ticker.C: - if err := pdHTTPCli.SetRegionLabelRule(ctx, rule); err != nil { - if berrors.IsContextCanceled(err) { - break loop - } - log.Warn("pause scheduler by key range failed, ignore it and wait next time pause", - zap.Error(err)) - } - case <-ctx.Done(): - break loop - } - } - // Use a new context to avoid the context is canceled by the caller. - recoverCtx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - // Set ttl to 0 to remove the rule. - rule.Labels[0].TTL = time.Duration(0).String() - deleteRule := &pdhttp.LabelRulePatch{DeleteRules: []string{rule.ID}} - if err := pdHTTPCli.PatchRegionLabelRules(recoverCtx, deleteRule); err != nil { - log.Warn("failed to delete region label rule, the rule will be removed after ttl expires", - zap.String("rule-id", rule.ID), zap.Duration("ttl", ttl), zap.Error(err)) - } - }() - return done, nil -} - -// CanPauseSchedulerByKeyRange returns whether the scheduler can be paused by key range. -func (p *PdController) CanPauseSchedulerByKeyRange() bool { - // We need ttl feature to ensure scheduler can recover from pause automatically. - return p.version.Compare(minVersionForRegionLabelTTL) >= 0 -} - -// Close closes the connection to pd. 
-func (p *PdController) Close() { - p.pdClient.Close() - if p.pdHTTPCli != nil { - // nil in some unit tests - p.pdHTTPCli.Close() - } - if p.schedulerPauseCh != nil { - close(p.schedulerPauseCh) - } -} - -func (p *PdController) ttlOfPausing() time.Duration { - if p.SchedulerPauseTTL > 0 { - return p.SchedulerPauseTTL - } - return pauseTimeout -} - -// FetchPDVersion get pd version -func FetchPDVersion(ctx context.Context, pdHTTPCli pdhttp.Client) (*semver.Version, error) { - ver, err := pdHTTPCli.GetPDVersion(ctx) - if err != nil { - return nil, errors.Trace(err) - } - - return parseVersion(ver), nil -} diff --git a/br/pkg/restore/binding__failpoint_binding__.go b/br/pkg/restore/binding__failpoint_binding__.go deleted file mode 100644 index dab6e72f95323..0000000000000 --- a/br/pkg/restore/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package restore - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/br/pkg/restore/log_client/binding__failpoint_binding__.go b/br/pkg/restore/log_client/binding__failpoint_binding__.go deleted file mode 100644 index b17db4979c387..0000000000000 --- a/br/pkg/restore/log_client/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package logclient - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/br/pkg/restore/log_client/client.go b/br/pkg/restore/log_client/client.go index 18a14c76c10e3..fb34e709ce4ec 100644 --- a/br/pkg/restore/log_client/client.go +++ b/br/pkg/restore/log_client/client.go @@ -828,9 +828,9 @@ func (rc *LogClient) RestoreMetaKVFiles( filesInDefaultCF = SortMetaKVFiles(filesInDefaultCF) filesInWriteCF = SortMetaKVFiles(filesInWriteCF) - if _, _err_ := failpoint.Eval(_curpkg_("failed-before-id-maps-saved")); _err_ == nil { - return errors.New("failpoint: failed before id maps saved") - } + failpoint.Inject("failed-before-id-maps-saved", func(_ failpoint.Value) { + failpoint.Return(errors.New("failpoint: failed before id maps saved")) + }) log.Info("start to restore meta files", zap.Int("total files", len(files)), @@ -848,9 +848,9 @@ func (rc *LogClient) RestoreMetaKVFiles( return errors.Trace(err) } } - if _, _err_ := failpoint.Eval(_curpkg_("failed-after-id-maps-saved")); _err_ == nil { - return errors.New("failpoint: failed after id maps saved") - } + failpoint.Inject("failed-after-id-maps-saved", func(_ failpoint.Value) { + failpoint.Return(errors.New("failpoint: failed after id maps saved")) + }) // run the rewrite and restore meta-kv into TiKV cluster. 
if err := RestoreMetaKVFilesWithBatchMethod( @@ -1087,18 +1087,18 @@ func (rc *LogClient) restoreMetaKvEntries( log.Debug("after rewrite entry", zap.Int("new-key-len", len(newEntry.Key)), zap.Int("new-value-len", len(entry.E.Value)), zap.ByteString("new-key", newEntry.Key)) - if _, _err_ := failpoint.Eval(_curpkg_("failed-to-restore-metakv")); _err_ == nil { - return 0, 0, errors.Errorf("failpoint: failed to restore metakv") - } + failpoint.Inject("failed-to-restore-metakv", func(_ failpoint.Value) { + failpoint.Return(0, 0, errors.Errorf("failpoint: failed to restore metakv")) + }) if err := rc.rawKVClient.Put(ctx, newEntry.Key, newEntry.Value, entry.Ts); err != nil { return 0, 0, errors.Trace(err) } // for failpoint, we need to flush the cache in rawKVClient every time - if _, _err_ := failpoint.Eval(_curpkg_("do-not-put-metakv-in-batch")); _err_ == nil { + failpoint.Inject("do-not-put-metakv-in-batch", func(_ failpoint.Value) { if err := rc.rawKVClient.PutRest(ctx); err != nil { - return 0, 0, errors.Trace(err) + failpoint.Return(0, 0, errors.Trace(err)) } - } + }) kvCount++ size += uint64(len(newEntry.Key) + len(newEntry.Value)) } @@ -1397,11 +1397,11 @@ NEXTSQL: return errors.Trace(err) } } - if v, _err_ := failpoint.Eval(_curpkg_("failed-before-create-ingest-index")); _err_ == nil { + failpoint.Inject("failed-before-create-ingest-index", func(v failpoint.Value) { if v != nil && v.(bool) { - return errors.New("failed before create ingest index") + failpoint.Return(errors.New("failed before create ingest index")) } - } + }) // create the repaired index when first execution or not found it if err := rc.se.ExecuteInternal(ctx, sql.AddSQL, sql.AddArgs...); err != nil { return errors.Trace(err) diff --git a/br/pkg/restore/log_client/client.go__failpoint_stash__ b/br/pkg/restore/log_client/client.go__failpoint_stash__ deleted file mode 100644 index fb34e709ce4ec..0000000000000 --- a/br/pkg/restore/log_client/client.go__failpoint_stash__ +++ /dev/null @@ -1,1689 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package logclient - -import ( - "cmp" - "context" - "crypto/tls" - "fmt" - "math" - "os" - "slices" - "strconv" - "strings" - "sync" - "time" - - "github.com/fatih/color" - "github.com/opentracing/opentracing-go" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - backuppb "github.com/pingcap/kvproto/pkg/brpb" - "github.com/pingcap/log" - "github.com/pingcap/tidb/br/pkg/checkpoint" - "github.com/pingcap/tidb/br/pkg/checksum" - "github.com/pingcap/tidb/br/pkg/conn" - "github.com/pingcap/tidb/br/pkg/conn/util" - "github.com/pingcap/tidb/br/pkg/glue" - "github.com/pingcap/tidb/br/pkg/logutil" - "github.com/pingcap/tidb/br/pkg/metautil" - "github.com/pingcap/tidb/br/pkg/restore" - "github.com/pingcap/tidb/br/pkg/restore/ingestrec" - importclient "github.com/pingcap/tidb/br/pkg/restore/internal/import_client" - logsplit "github.com/pingcap/tidb/br/pkg/restore/internal/log_split" - "github.com/pingcap/tidb/br/pkg/restore/internal/rawkv" - "github.com/pingcap/tidb/br/pkg/restore/split" - "github.com/pingcap/tidb/br/pkg/restore/tiflashrec" - restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/br/pkg/stream" - "github.com/pingcap/tidb/br/pkg/summary" - "github.com/pingcap/tidb/br/pkg/utils" - "github.com/pingcap/tidb/br/pkg/utils/iter" - "github.com/pingcap/tidb/br/pkg/version" - ddlutil "github.com/pingcap/tidb/pkg/ddl/util" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/parser/model" - tidbutil "github.com/pingcap/tidb/pkg/util" - filter "github.com/pingcap/tidb/pkg/util/table-filter" - "github.com/tikv/client-go/v2/config" - kvutil "github.com/tikv/client-go/v2/util" - pd "github.com/tikv/pd/client" - pdhttp "github.com/tikv/pd/client/http" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" - "golang.org/x/sync/errgroup" - "google.golang.org/grpc/keepalive" -) - -const MetaKVBatchSize = 64 * 1024 * 1024 -const maxSplitKeysOnce = 10240 - -// rawKVBatchCount specifies the count of entries that the rawkv client puts into TiKV. -const rawKVBatchCount = 64 - -type LogClient struct { - cipher *backuppb.CipherInfo - pdClient pd.Client - pdHTTPClient pdhttp.Client - clusterID uint64 - dom *domain.Domain - tlsConf *tls.Config - keepaliveConf keepalive.ClientParameters - - rawKVClient *rawkv.RawKVBatchClient - storage storage.ExternalStorage - - se glue.Session - - // currentTS is used for rewrite meta kv when restore stream. - // Can not use `restoreTS` directly, because schema created in `full backup` maybe is new than `restoreTS`. - currentTS uint64 - - *LogFileManager - - workerPool *tidbutil.WorkerPool - fileImporter *LogFileImporter - - // the query to insert rows into table `gc_delete_range`, lack of ts. - deleteRangeQuery []*stream.PreDelRangeQuery - deleteRangeQueryCh chan *stream.PreDelRangeQuery - deleteRangeQueryWaitGroup sync.WaitGroup - - // checkpoint information for log restore - useCheckpoint bool -} - -// NewRestoreClient returns a new RestoreClient. -func NewRestoreClient( - pdClient pd.Client, - pdHTTPCli pdhttp.Client, - tlsConf *tls.Config, - keepaliveConf keepalive.ClientParameters, -) *LogClient { - return &LogClient{ - pdClient: pdClient, - pdHTTPClient: pdHTTPCli, - tlsConf: tlsConf, - keepaliveConf: keepaliveConf, - deleteRangeQuery: make([]*stream.PreDelRangeQuery, 0), - deleteRangeQueryCh: make(chan *stream.PreDelRangeQuery, 10), - } -} - -// Close a client. 
-func (rc *LogClient) Close() { - // close the connection, and it must be succeed when in SQL mode. - if rc.se != nil { - rc.se.Close() - } - - if rc.rawKVClient != nil { - rc.rawKVClient.Close() - } - - if err := rc.fileImporter.Close(); err != nil { - log.Warn("failed to close file improter") - } - - log.Info("Restore client closed") -} - -func (rc *LogClient) SetRawKVBatchClient( - ctx context.Context, - pdAddrs []string, - security config.Security, -) error { - rawkvClient, err := rawkv.NewRawkvClient(ctx, pdAddrs, security) - if err != nil { - return errors.Trace(err) - } - - rc.rawKVClient = rawkv.NewRawKVBatchClient(rawkvClient, rawKVBatchCount) - return nil -} - -func (rc *LogClient) SetCrypter(crypter *backuppb.CipherInfo) { - rc.cipher = crypter -} - -func (rc *LogClient) SetConcurrency(c uint) { - log.Info("download worker pool", zap.Uint("size", c)) - rc.workerPool = tidbutil.NewWorkerPool(c, "file") -} - -func (rc *LogClient) SetStorage(ctx context.Context, backend *backuppb.StorageBackend, opts *storage.ExternalStorageOptions) error { - var err error - rc.storage, err = storage.New(ctx, backend, opts) - if err != nil { - return errors.Trace(err) - } - return nil -} - -func (rc *LogClient) SetCurrentTS(ts uint64) { - rc.currentTS = ts -} - -// GetClusterID gets the cluster id from down-stream cluster. -func (rc *LogClient) GetClusterID(ctx context.Context) uint64 { - if rc.clusterID <= 0 { - rc.clusterID = rc.pdClient.GetClusterID(ctx) - } - return rc.clusterID -} - -func (rc *LogClient) GetDomain() *domain.Domain { - return rc.dom -} - -func (rc *LogClient) CleanUpKVFiles( - ctx context.Context, -) error { - // Current we only have v1 prefix. - // In the future, we can add more operation for this interface. - return rc.fileImporter.ClearFiles(ctx, rc.pdClient, "v1") -} - -func (rc *LogClient) StartCheckpointRunnerForLogRestore(ctx context.Context, taskName string) (*checkpoint.CheckpointRunner[checkpoint.LogRestoreKeyType, checkpoint.LogRestoreValueType], error) { - runner, err := checkpoint.StartCheckpointRunnerForLogRestore(ctx, rc.storage, rc.cipher, taskName) - return runner, errors.Trace(err) -} - -// Init create db connection and domain for storage. -func (rc *LogClient) Init(g glue.Glue, store kv.Storage) error { - var err error - rc.se, err = g.CreateSession(store) - if err != nil { - return errors.Trace(err) - } - - // Set SQL mode to None for avoiding SQL compatibility problem - err = rc.se.Execute(context.Background(), "set @@sql_mode=''") - if err != nil { - return errors.Trace(err) - } - - rc.dom, err = g.GetDomain(store) - if err != nil { - return errors.Trace(err) - } - - return nil -} - -func (rc *LogClient) InitClients(ctx context.Context, backend *backuppb.StorageBackend) { - stores, err := conn.GetAllTiKVStoresWithRetry(ctx, rc.pdClient, util.SkipTiFlash) - if err != nil { - log.Fatal("failed to get stores", zap.Error(err)) - } - - metaClient := split.NewClient(rc.pdClient, rc.pdHTTPClient, rc.tlsConf, maxSplitKeysOnce, len(stores)+1) - importCli := importclient.NewImportClient(metaClient, rc.tlsConf, rc.keepaliveConf) - rc.fileImporter = NewLogFileImporter(metaClient, importCli, backend) -} - -func (rc *LogClient) InitCheckpointMetadataForLogRestore(ctx context.Context, taskName string, gcRatio string) (string, error) { - rc.useCheckpoint = true - - // it shows that the user has modified gc-ratio, if `gcRatio` doesn't equal to "1.1". - // update the `gcRatio` for checkpoint metadata. 
- if gcRatio == utils.DefaultGcRatioVal { - // if the checkpoint metadata exists in the external storage, the restore is not - // for the first time. - exists, err := checkpoint.ExistsRestoreCheckpoint(ctx, rc.storage, taskName) - if err != nil { - return "", errors.Trace(err) - } - - if exists { - // load the checkpoint since this is not the first time to restore - meta, err := checkpoint.LoadCheckpointMetadataForRestore(ctx, rc.storage, taskName) - if err != nil { - return "", errors.Trace(err) - } - - log.Info("reuse gc ratio from checkpoint metadata", zap.String("gc-ratio", gcRatio)) - return meta.GcRatio, nil - } - } - - // initialize the checkpoint metadata since it is the first time to restore. - log.Info("save gc ratio into checkpoint metadata", zap.String("gc-ratio", gcRatio)) - if err := checkpoint.SaveCheckpointMetadataForRestore(ctx, rc.storage, &checkpoint.CheckpointMetadataForRestore{ - GcRatio: gcRatio, - }, taskName); err != nil { - return gcRatio, errors.Trace(err) - } - - return gcRatio, nil -} - -func (rc *LogClient) InstallLogFileManager(ctx context.Context, startTS, restoreTS uint64, metadataDownloadBatchSize uint) error { - init := LogFileManagerInit{ - StartTS: startTS, - RestoreTS: restoreTS, - Storage: rc.storage, - - MetadataDownloadBatchSize: metadataDownloadBatchSize, - } - var err error - rc.LogFileManager, err = CreateLogFileManager(ctx, init) - if err != nil { - return err - } - return nil -} - -type FilesInRegion struct { - defaultSize uint64 - defaultKVCount int64 - writeSize uint64 - writeKVCount int64 - - defaultFiles []*LogDataFileInfo - writeFiles []*LogDataFileInfo - deleteFiles []*LogDataFileInfo -} - -type FilesInTable struct { - regionMapFiles map[int64]*FilesInRegion -} - -func ApplyKVFilesWithBatchMethod( - ctx context.Context, - logIter LogIter, - batchCount int, - batchSize uint64, - applyFunc func(files []*LogDataFileInfo, kvCount int64, size uint64), - applyWg *sync.WaitGroup, -) error { - var ( - tableMapFiles = make(map[int64]*FilesInTable) - tmpFiles = make([]*LogDataFileInfo, 0, batchCount) - tmpSize uint64 = 0 - tmpKVCount int64 = 0 - ) - for r := logIter.TryNext(ctx); !r.Finished; r = logIter.TryNext(ctx) { - if r.Err != nil { - return r.Err - } - - f := r.Item - if f.GetType() == backuppb.FileType_Put && f.GetLength() >= batchSize { - applyFunc([]*LogDataFileInfo{f}, f.GetNumberOfEntries(), f.GetLength()) - continue - } - - fit, exist := tableMapFiles[f.TableId] - if !exist { - fit = &FilesInTable{ - regionMapFiles: make(map[int64]*FilesInRegion), - } - tableMapFiles[f.TableId] = fit - } - fs, exist := fit.regionMapFiles[f.RegionId] - if !exist { - fs = &FilesInRegion{} - fit.regionMapFiles[f.RegionId] = fs - } - - if f.GetType() == backuppb.FileType_Delete { - if fs.defaultFiles == nil { - fs.deleteFiles = make([]*LogDataFileInfo, 0) - } - fs.deleteFiles = append(fs.deleteFiles, f) - } else { - if f.GetCf() == stream.DefaultCF { - if fs.defaultFiles == nil { - fs.defaultFiles = make([]*LogDataFileInfo, 0, batchCount) - } - fs.defaultFiles = append(fs.defaultFiles, f) - fs.defaultSize += f.Length - fs.defaultKVCount += f.GetNumberOfEntries() - if len(fs.defaultFiles) >= batchCount || fs.defaultSize >= batchSize { - applyFunc(fs.defaultFiles, fs.defaultKVCount, fs.defaultSize) - fs.defaultFiles = nil - fs.defaultSize = 0 - fs.defaultKVCount = 0 - } - } else { - if fs.writeFiles == nil { - fs.writeFiles = make([]*LogDataFileInfo, 0, batchCount) - } - fs.writeFiles = append(fs.writeFiles, f) - fs.writeSize += f.GetLength() - 
fs.writeKVCount += f.GetNumberOfEntries() - if len(fs.writeFiles) >= batchCount || fs.writeSize >= batchSize { - applyFunc(fs.writeFiles, fs.writeKVCount, fs.writeSize) - fs.writeFiles = nil - fs.writeSize = 0 - fs.writeKVCount = 0 - } - } - } - } - - for _, fwt := range tableMapFiles { - for _, fs := range fwt.regionMapFiles { - if len(fs.defaultFiles) > 0 { - applyFunc(fs.defaultFiles, fs.defaultKVCount, fs.defaultSize) - } - if len(fs.writeFiles) > 0 { - applyFunc(fs.writeFiles, fs.writeKVCount, fs.writeSize) - } - } - } - - applyWg.Wait() - for _, fwt := range tableMapFiles { - for _, fs := range fwt.regionMapFiles { - for _, d := range fs.deleteFiles { - tmpFiles = append(tmpFiles, d) - tmpSize += d.GetLength() - tmpKVCount += d.GetNumberOfEntries() - - if len(tmpFiles) >= batchCount || tmpSize >= batchSize { - applyFunc(tmpFiles, tmpKVCount, tmpSize) - tmpFiles = make([]*LogDataFileInfo, 0, batchCount) - tmpSize = 0 - tmpKVCount = 0 - } - } - if len(tmpFiles) > 0 { - applyFunc(tmpFiles, tmpKVCount, tmpSize) - tmpFiles = make([]*LogDataFileInfo, 0, batchCount) - tmpSize = 0 - tmpKVCount = 0 - } - } - } - - return nil -} - -func ApplyKVFilesWithSingelMethod( - ctx context.Context, - files LogIter, - applyFunc func(file []*LogDataFileInfo, kvCount int64, size uint64), - applyWg *sync.WaitGroup, -) error { - deleteKVFiles := make([]*LogDataFileInfo, 0) - - for r := files.TryNext(ctx); !r.Finished; r = files.TryNext(ctx) { - if r.Err != nil { - return r.Err - } - - f := r.Item - if f.GetType() == backuppb.FileType_Delete { - deleteKVFiles = append(deleteKVFiles, f) - continue - } - applyFunc([]*LogDataFileInfo{f}, f.GetNumberOfEntries(), f.GetLength()) - } - - applyWg.Wait() - log.Info("restore delete files", zap.Int("count", len(deleteKVFiles))) - for _, file := range deleteKVFiles { - f := file - applyFunc([]*LogDataFileInfo{f}, f.GetNumberOfEntries(), f.GetLength()) - } - - return nil -} - -func (rc *LogClient) RestoreKVFiles( - ctx context.Context, - rules map[int64]*restoreutils.RewriteRules, - idrules map[int64]int64, - logIter LogIter, - runner *checkpoint.CheckpointRunner[checkpoint.LogRestoreKeyType, checkpoint.LogRestoreValueType], - pitrBatchCount uint32, - pitrBatchSize uint32, - updateStats func(kvCount uint64, size uint64), - onProgress func(cnt int64), -) error { - var ( - err error - fileCount = 0 - start = time.Now() - supportBatch = version.CheckPITRSupportBatchKVFiles() - skipFile = 0 - ) - defer func() { - if err == nil { - elapsed := time.Since(start) - log.Info("Restore KV files", zap.Duration("take", elapsed)) - summary.CollectSuccessUnit("files", fileCount, elapsed) - } - }() - - if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { - span1 := span.Tracer().StartSpan("Client.RestoreKVFiles", opentracing.ChildOf(span.Context())) - defer span1.Finish() - ctx = opentracing.ContextWithSpan(ctx, span1) - } - - var applyWg sync.WaitGroup - eg, ectx := errgroup.WithContext(ctx) - applyFunc := func(files []*LogDataFileInfo, kvCount int64, size uint64) { - if len(files) == 0 { - return - } - // get rewrite rule from table id. - // because the tableID of files is the same. - rule, ok := rules[files[0].TableId] - if !ok { - // TODO handle new created table - // For this version we do not handle new created table after full backup. - // in next version we will perform rewrite and restore meta key to restore new created tables. - // so we can simply skip the file that doesn't have the rule here. 
- onProgress(int64(len(files))) - summary.CollectInt("FileSkip", len(files)) - log.Debug("skip file due to table id not matched", zap.Int64("table-id", files[0].TableId)) - skipFile += len(files) - } else { - applyWg.Add(1) - downstreamId := idrules[files[0].TableId] - rc.workerPool.ApplyOnErrorGroup(eg, func() (err error) { - fileStart := time.Now() - defer applyWg.Done() - defer func() { - onProgress(int64(len(files))) - updateStats(uint64(kvCount), size) - summary.CollectInt("File", len(files)) - - if err == nil { - filenames := make([]string, 0, len(files)) - if runner == nil { - for _, f := range files { - filenames = append(filenames, f.Path+", ") - } - } else { - for _, f := range files { - filenames = append(filenames, f.Path+", ") - if e := checkpoint.AppendRangeForLogRestore(ectx, runner, f.MetaDataGroupName, downstreamId, f.OffsetInMetaGroup, f.OffsetInMergedGroup); e != nil { - err = errors.Annotate(e, "failed to append checkpoint data") - break - } - } - } - log.Info("import files done", zap.Int("batch-count", len(files)), zap.Uint64("batch-size", size), - zap.Duration("take", time.Since(fileStart)), zap.Strings("files", filenames)) - } - }() - - return rc.fileImporter.ImportKVFiles(ectx, files, rule, rc.shiftStartTS, rc.startTS, rc.restoreTS, supportBatch) - }) - } - } - - rc.workerPool.ApplyOnErrorGroup(eg, func() error { - if supportBatch { - err = ApplyKVFilesWithBatchMethod(ectx, logIter, int(pitrBatchCount), uint64(pitrBatchSize), applyFunc, &applyWg) - } else { - err = ApplyKVFilesWithSingelMethod(ectx, logIter, applyFunc, &applyWg) - } - return errors.Trace(err) - }) - - if err = eg.Wait(); err != nil { - summary.CollectFailureUnit("file", err) - log.Error("restore files failed", zap.Error(err)) - } - - log.Info("total skip files due to table id not matched", zap.Int("count", skipFile)) - if skipFile > 0 { - log.Debug("table id in full backup storage", zap.Any("tables", rules)) - } - - return errors.Trace(err) -} - -func (rc *LogClient) initSchemasMap( - ctx context.Context, - clusterID uint64, - restoreTS uint64, -) ([]*backuppb.PitrDBMap, error) { - filename := metautil.PitrIDMapsFilename(clusterID, restoreTS) - exist, err := rc.storage.FileExists(ctx, filename) - if err != nil { - return nil, errors.Annotatef(err, "failed to check filename:%s ", filename) - } else if !exist { - log.Info("pitr id maps isn't existed", zap.String("file", filename)) - return nil, nil - } - - metaData, err := rc.storage.ReadFile(ctx, filename) - if err != nil { - return nil, errors.Trace(err) - } - backupMeta := &backuppb.BackupMeta{} - if err = backupMeta.Unmarshal(metaData); err != nil { - return nil, errors.Trace(err) - } - - return backupMeta.GetDbMaps(), nil -} - -func initFullBackupTables( - ctx context.Context, - s storage.ExternalStorage, - tableFilter filter.Filter, -) (map[int64]*metautil.Table, error) { - metaData, err := s.ReadFile(ctx, metautil.MetaFile) - if err != nil { - return nil, errors.Trace(err) - } - backupMeta := &backuppb.BackupMeta{} - if err = backupMeta.Unmarshal(metaData); err != nil { - return nil, errors.Trace(err) - } - - // read full backup databases to get map[table]table.Info - reader := metautil.NewMetaReader(backupMeta, s, nil) - - databases, err := metautil.LoadBackupTables(ctx, reader, false) - if err != nil { - return nil, errors.Trace(err) - } - - tables := make(map[int64]*metautil.Table) - for _, db := range databases { - dbName := db.Info.Name.O - if name, ok := utils.GetSysDBName(db.Info.Name); utils.IsSysDB(name) && ok { - dbName = name - } - - 
if !tableFilter.MatchSchema(dbName) { - continue - } - - for _, table := range db.Tables { - // check this db is empty. - if table.Info == nil { - tables[db.Info.ID] = table - continue - } - if !tableFilter.MatchTable(dbName, table.Info.Name.O) { - continue - } - tables[table.Info.ID] = table - } - } - - return tables, nil -} - -type FullBackupStorageConfig struct { - Backend *backuppb.StorageBackend - Opts *storage.ExternalStorageOptions -} - -type InitSchemaConfig struct { - // required - IsNewTask bool - TableFilter filter.Filter - - // optional - TiFlashRecorder *tiflashrec.TiFlashRecorder - FullBackupStorage *FullBackupStorageConfig -} - -const UnsafePITRLogRestoreStartBeforeAnyUpstreamUserDDL = "UNSAFE_PITR_LOG_RESTORE_START_BEFORE_ANY_UPSTREAM_USER_DDL" - -func (rc *LogClient) generateDBReplacesFromFullBackupStorage( - ctx context.Context, - cfg *InitSchemaConfig, -) (map[stream.UpstreamID]*stream.DBReplace, error) { - dbReplaces := make(map[stream.UpstreamID]*stream.DBReplace) - if cfg.FullBackupStorage == nil { - envVal, ok := os.LookupEnv(UnsafePITRLogRestoreStartBeforeAnyUpstreamUserDDL) - if ok && len(envVal) > 0 { - log.Info(fmt.Sprintf("the environment variable %s is active, skip loading the base schemas.", UnsafePITRLogRestoreStartBeforeAnyUpstreamUserDDL)) - return dbReplaces, nil - } - return nil, errors.Errorf("miss upstream table information at `start-ts`(%d) but the full backup path is not specified", rc.startTS) - } - s, err := storage.New(ctx, cfg.FullBackupStorage.Backend, cfg.FullBackupStorage.Opts) - if err != nil { - return nil, errors.Trace(err) - } - fullBackupTables, err := initFullBackupTables(ctx, s, cfg.TableFilter) - if err != nil { - return nil, errors.Trace(err) - } - for _, t := range fullBackupTables { - dbName, _ := utils.GetSysDBCIStrName(t.DB.Name) - newDBInfo, exist := rc.dom.InfoSchema().SchemaByName(dbName) - if !exist { - log.Info("db not existed", zap.String("dbname", dbName.String())) - continue - } - - dbReplace, exist := dbReplaces[t.DB.ID] - if !exist { - dbReplace = stream.NewDBReplace(t.DB.Name.O, newDBInfo.ID) - dbReplaces[t.DB.ID] = dbReplace - } - - if t.Info == nil { - // If the db is empty, skip it. - continue - } - newTableInfo, err := restore.GetTableSchema(rc.GetDomain(), dbName, t.Info.Name) - if err != nil { - log.Info("table not existed", zap.String("tablename", dbName.String()+"."+t.Info.Name.String())) - continue - } - - dbReplace.TableMap[t.Info.ID] = &stream.TableReplace{ - Name: newTableInfo.Name.O, - TableID: newTableInfo.ID, - PartitionMap: restoreutils.GetPartitionIDMap(newTableInfo, t.Info), - IndexMap: restoreutils.GetIndexIDMap(newTableInfo, t.Info), - } - } - return dbReplaces, nil -} - -// InitSchemasReplaceForDDL gets schemas information Mapping from old schemas to new schemas. -// It is used to rewrite meta kv-event. 
-func (rc *LogClient) InitSchemasReplaceForDDL( - ctx context.Context, - cfg *InitSchemaConfig, -) (*stream.SchemasReplace, error) { - var ( - err error - dbMaps []*backuppb.PitrDBMap - // the id map doesn't need to construct only when it is not the first execution - needConstructIdMap bool - - dbReplaces map[stream.UpstreamID]*stream.DBReplace - ) - - // not new task, load schemas map from external storage - if !cfg.IsNewTask { - log.Info("try to load pitr id maps") - needConstructIdMap = false - dbMaps, err = rc.initSchemasMap(ctx, rc.GetClusterID(ctx), rc.restoreTS) - if err != nil { - return nil, errors.Trace(err) - } - } - - // a new task, but without full snapshot restore, tries to load - // schemas map whose `restore-ts`` is the task's `start-ts`. - if len(dbMaps) <= 0 && cfg.FullBackupStorage == nil { - log.Info("try to load pitr id maps of the previous task", zap.Uint64("start-ts", rc.startTS)) - needConstructIdMap = true - dbMaps, err = rc.initSchemasMap(ctx, rc.GetClusterID(ctx), rc.startTS) - if err != nil { - return nil, errors.Trace(err) - } - existTiFlashTable := false - rc.dom.InfoSchema().ListTablesWithSpecialAttribute(func(tableInfo *model.TableInfo) bool { - if tableInfo.TiFlashReplica != nil && tableInfo.TiFlashReplica.Count > 0 { - existTiFlashTable = true - } - return false - }) - if existTiFlashTable { - return nil, errors.Errorf("exist table(s) have tiflash replica, please remove it before restore") - } - } - - if len(dbMaps) <= 0 { - log.Info("no id maps, build the table replaces from cluster and full backup schemas") - needConstructIdMap = true - dbReplaces, err = rc.generateDBReplacesFromFullBackupStorage(ctx, cfg) - if err != nil { - return nil, errors.Trace(err) - } - } else { - dbReplaces = stream.FromSchemaMaps(dbMaps) - } - - for oldDBID, dbReplace := range dbReplaces { - log.Info("replace info", func() []zapcore.Field { - fields := make([]zapcore.Field, 0, (len(dbReplace.TableMap)+1)*3) - fields = append(fields, - zap.String("dbName", dbReplace.Name), - zap.Int64("oldID", oldDBID), - zap.Int64("newID", dbReplace.DbID)) - for oldTableID, tableReplace := range dbReplace.TableMap { - fields = append(fields, - zap.String("table", tableReplace.Name), - zap.Int64("oldID", oldTableID), - zap.Int64("newID", tableReplace.TableID)) - } - return fields - }()...) - } - - rp := stream.NewSchemasReplace( - dbReplaces, needConstructIdMap, cfg.TiFlashRecorder, rc.currentTS, cfg.TableFilter, rc.GenGlobalID, rc.GenGlobalIDs, - rc.RecordDeleteRange) - return rp, nil -} - -func SortMetaKVFiles(files []*backuppb.DataFileInfo) []*backuppb.DataFileInfo { - slices.SortFunc(files, func(i, j *backuppb.DataFileInfo) int { - if c := cmp.Compare(i.GetMinTs(), j.GetMinTs()); c != 0 { - return c - } - if c := cmp.Compare(i.GetMaxTs(), j.GetMaxTs()); c != 0 { - return c - } - return cmp.Compare(i.GetResolvedTs(), j.GetResolvedTs()) - }) - return files -} - -// RestoreMetaKVFiles tries to restore files about meta kv-event from stream-backup. -func (rc *LogClient) RestoreMetaKVFiles( - ctx context.Context, - files []*backuppb.DataFileInfo, - schemasReplace *stream.SchemasReplace, - updateStats func(kvCount uint64, size uint64), - progressInc func(), -) error { - filesInWriteCF := make([]*backuppb.DataFileInfo, 0, len(files)) - filesInDefaultCF := make([]*backuppb.DataFileInfo, 0, len(files)) - - // The k-v events in default CF should be restored firstly. 
The reason is that: - // The error of transactions of meta could happen if restore write CF events successfully, - // but failed to restore default CF events. - for _, f := range files { - if f.Cf == stream.WriteCF { - filesInWriteCF = append(filesInWriteCF, f) - continue - } - if f.Type == backuppb.FileType_Delete { - // this should happen abnormally. - // only do some preventive checks here. - log.Warn("detected delete file of meta key, skip it", zap.Any("file", f)) - continue - } - if f.Cf == stream.DefaultCF { - filesInDefaultCF = append(filesInDefaultCF, f) - } - } - filesInDefaultCF = SortMetaKVFiles(filesInDefaultCF) - filesInWriteCF = SortMetaKVFiles(filesInWriteCF) - - failpoint.Inject("failed-before-id-maps-saved", func(_ failpoint.Value) { - failpoint.Return(errors.New("failpoint: failed before id maps saved")) - }) - - log.Info("start to restore meta files", - zap.Int("total files", len(files)), - zap.Int("default files", len(filesInDefaultCF)), - zap.Int("write files", len(filesInWriteCF))) - - if schemasReplace.NeedConstructIdMap() { - // Preconstruct the map and save it into external storage. - if err := rc.PreConstructAndSaveIDMap( - ctx, - filesInWriteCF, - filesInDefaultCF, - schemasReplace, - ); err != nil { - return errors.Trace(err) - } - } - failpoint.Inject("failed-after-id-maps-saved", func(_ failpoint.Value) { - failpoint.Return(errors.New("failpoint: failed after id maps saved")) - }) - - // run the rewrite and restore meta-kv into TiKV cluster. - if err := RestoreMetaKVFilesWithBatchMethod( - ctx, - filesInDefaultCF, - filesInWriteCF, - schemasReplace, - updateStats, - progressInc, - rc.RestoreBatchMetaKVFiles, - ); err != nil { - return errors.Trace(err) - } - - // Update global schema version and report all of TiDBs. - if err := rc.UpdateSchemaVersion(ctx); err != nil { - return errors.Trace(err) - } - return nil -} - -// PreConstructAndSaveIDMap constructs id mapping and save it. 
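// For illustration, a standalone sketch of the ordering SortMetaKVFiles above
// relies on: compare by MinTs, then MaxTs, then ResolvedTs. The fileTs type
// and the sample values are made up for this example; only "cmp" and "slices"
// from the standard library are needed.
type fileTs struct{ minTs, maxTs, resolvedTs uint64 }

func sortByTs(files []fileTs) {
	slices.SortFunc(files, func(i, j fileTs) int {
		if c := cmp.Compare(i.minTs, j.minTs); c != 0 {
			return c
		}
		if c := cmp.Compare(i.maxTs, j.maxTs); c != 0 {
			return c
		}
		return cmp.Compare(i.resolvedTs, j.resolvedTs)
	})
}

// sortByTs([]fileTs{{3, 9, 9}, {1, 8, 7}, {1, 5, 6}}) yields
// [{1 5 6} {1 8 7} {3 9 9}]: files with equal MinTs fall back to MaxTs.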
-func (rc *LogClient) PreConstructAndSaveIDMap( - ctx context.Context, - fsInWriteCF, fsInDefaultCF []*backuppb.DataFileInfo, - sr *stream.SchemasReplace, -) error { - sr.SetPreConstructMapStatus() - - if err := rc.constructIDMap(ctx, fsInWriteCF, sr); err != nil { - return errors.Trace(err) - } - if err := rc.constructIDMap(ctx, fsInDefaultCF, sr); err != nil { - return errors.Trace(err) - } - - if err := rc.SaveIDMap(ctx, sr); err != nil { - return errors.Trace(err) - } - return nil -} - -func (rc *LogClient) constructIDMap( - ctx context.Context, - fs []*backuppb.DataFileInfo, - sr *stream.SchemasReplace, -) error { - for _, f := range fs { - entries, _, err := rc.ReadAllEntries(ctx, f, math.MaxUint64) - if err != nil { - return errors.Trace(err) - } - - for _, entry := range entries { - if _, err := sr.RewriteKvEntry(&entry.E, f.GetCf()); err != nil { - return errors.Trace(err) - } - } - } - return nil -} - -func RestoreMetaKVFilesWithBatchMethod( - ctx context.Context, - defaultFiles []*backuppb.DataFileInfo, - writeFiles []*backuppb.DataFileInfo, - schemasReplace *stream.SchemasReplace, - updateStats func(kvCount uint64, size uint64), - progressInc func(), - restoreBatch func( - ctx context.Context, - files []*backuppb.DataFileInfo, - schemasReplace *stream.SchemasReplace, - kvEntries []*KvEntryWithTS, - filterTS uint64, - updateStats func(kvCount uint64, size uint64), - progressInc func(), - cf string, - ) ([]*KvEntryWithTS, error), -) error { - // the average size of each KV is 2560 Bytes - // kvEntries is kvs left by the previous batch - const kvSize = 2560 - var ( - rangeMin uint64 - rangeMax uint64 - err error - - batchSize uint64 = 0 - defaultIdx int = 0 - writeIdx int = 0 - - defaultKvEntries = make([]*KvEntryWithTS, 0) - writeKvEntries = make([]*KvEntryWithTS, 0) - ) - // Set restoreKV to SchemaReplace. - schemasReplace.SetRestoreKVStatus() - - for i, f := range defaultFiles { - if i == 0 { - rangeMax = f.MaxTs - rangeMin = f.MinTs - batchSize = f.Length - } else { - if f.MinTs <= rangeMax && batchSize+f.Length <= MetaKVBatchSize { - rangeMin = min(rangeMin, f.MinTs) - rangeMax = max(rangeMax, f.MaxTs) - batchSize += f.Length - } else { - // Either f.MinTS > rangeMax or f.MinTs is the filterTs we need. - // So it is ok to pass f.MinTs as filterTs. - defaultKvEntries, err = restoreBatch(ctx, defaultFiles[defaultIdx:i], schemasReplace, defaultKvEntries, f.MinTs, updateStats, progressInc, stream.DefaultCF) - if err != nil { - return errors.Trace(err) - } - defaultIdx = i - rangeMin = f.MinTs - rangeMax = f.MaxTs - // the initial batch size is the size of left kvs and the current file length. 
- batchSize = uint64(len(defaultKvEntries)*kvSize) + f.Length - - // restore writeCF kv to f.MinTs - var toWriteIdx int - for toWriteIdx = writeIdx; toWriteIdx < len(writeFiles); toWriteIdx++ { - if writeFiles[toWriteIdx].MinTs >= f.MinTs { - break - } - } - writeKvEntries, err = restoreBatch(ctx, writeFiles[writeIdx:toWriteIdx], schemasReplace, writeKvEntries, f.MinTs, updateStats, progressInc, stream.WriteCF) - if err != nil { - return errors.Trace(err) - } - writeIdx = toWriteIdx - } - } - } - - // restore the left meta kv files and entries - // Notice: restoreBatch needs to realize the parameter `files` and `kvEntries` might be empty - // Assert: defaultIdx <= len(defaultFiles) && writeIdx <= len(writeFiles) - _, err = restoreBatch(ctx, defaultFiles[defaultIdx:], schemasReplace, defaultKvEntries, math.MaxUint64, updateStats, progressInc, stream.DefaultCF) - if err != nil { - return errors.Trace(err) - } - _, err = restoreBatch(ctx, writeFiles[writeIdx:], schemasReplace, writeKvEntries, math.MaxUint64, updateStats, progressInc, stream.WriteCF) - if err != nil { - return errors.Trace(err) - } - - return nil -} - -func (rc *LogClient) RestoreBatchMetaKVFiles( - ctx context.Context, - files []*backuppb.DataFileInfo, - schemasReplace *stream.SchemasReplace, - kvEntries []*KvEntryWithTS, - filterTS uint64, - updateStats func(kvCount uint64, size uint64), - progressInc func(), - cf string, -) ([]*KvEntryWithTS, error) { - nextKvEntries := make([]*KvEntryWithTS, 0) - curKvEntries := make([]*KvEntryWithTS, 0) - if len(files) == 0 && len(kvEntries) == 0 { - return nextKvEntries, nil - } - - // filter the kv from kvEntries again. - for _, kv := range kvEntries { - if kv.Ts < filterTS { - curKvEntries = append(curKvEntries, kv) - } else { - nextKvEntries = append(nextKvEntries, kv) - } - } - - // read all of entries from files. - for _, f := range files { - es, nextEs, err := rc.ReadAllEntries(ctx, f, filterTS) - if err != nil { - return nextKvEntries, errors.Trace(err) - } - - curKvEntries = append(curKvEntries, es...) - nextKvEntries = append(nextKvEntries, nextEs...) - } - - // sort these entries. - slices.SortFunc(curKvEntries, func(i, j *KvEntryWithTS) int { - return cmp.Compare(i.Ts, j.Ts) - }) - - // restore these entries with rawPut() method. 
- kvCount, size, err := rc.restoreMetaKvEntries(ctx, schemasReplace, curKvEntries, cf) - if err != nil { - return nextKvEntries, errors.Trace(err) - } - - if schemasReplace.IsRestoreKVStatus() { - updateStats(kvCount, size) - for i := 0; i < len(files); i++ { - progressInc() - } - } - return nextKvEntries, nil -} - -func (rc *LogClient) restoreMetaKvEntries( - ctx context.Context, - sr *stream.SchemasReplace, - entries []*KvEntryWithTS, - columnFamily string, -) (uint64, uint64, error) { - var ( - kvCount uint64 - size uint64 - ) - - rc.rawKVClient.SetColumnFamily(columnFamily) - - for _, entry := range entries { - log.Debug("before rewrte entry", zap.Uint64("key-ts", entry.Ts), zap.Int("key-len", len(entry.E.Key)), - zap.Int("value-len", len(entry.E.Value)), zap.ByteString("key", entry.E.Key)) - - newEntry, err := sr.RewriteKvEntry(&entry.E, columnFamily) - if err != nil { - log.Error("rewrite txn entry failed", zap.Int("klen", len(entry.E.Key)), - logutil.Key("txn-key", entry.E.Key)) - return 0, 0, errors.Trace(err) - } else if newEntry == nil { - continue - } - log.Debug("after rewrite entry", zap.Int("new-key-len", len(newEntry.Key)), - zap.Int("new-value-len", len(entry.E.Value)), zap.ByteString("new-key", newEntry.Key)) - - failpoint.Inject("failed-to-restore-metakv", func(_ failpoint.Value) { - failpoint.Return(0, 0, errors.Errorf("failpoint: failed to restore metakv")) - }) - if err := rc.rawKVClient.Put(ctx, newEntry.Key, newEntry.Value, entry.Ts); err != nil { - return 0, 0, errors.Trace(err) - } - // for failpoint, we need to flush the cache in rawKVClient every time - failpoint.Inject("do-not-put-metakv-in-batch", func(_ failpoint.Value) { - if err := rc.rawKVClient.PutRest(ctx); err != nil { - failpoint.Return(0, 0, errors.Trace(err)) - } - }) - kvCount++ - size += uint64(len(newEntry.Key) + len(newEntry.Value)) - } - - return kvCount, size, rc.rawKVClient.PutRest(ctx) -} - -// GenGlobalID generates a global id by transaction way. -func (rc *LogClient) GenGlobalID(ctx context.Context) (int64, error) { - var id int64 - storage := rc.GetDomain().Store() - - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnBR) - err := kv.RunInNewTxn( - ctx, - storage, - true, - func(ctx context.Context, txn kv.Transaction) error { - var e error - t := meta.NewMeta(txn) - id, e = t.GenGlobalID() - return e - }) - - return id, err -} - -// GenGlobalIDs generates several global ids by transaction way. -func (rc *LogClient) GenGlobalIDs(ctx context.Context, n int) ([]int64, error) { - ids := make([]int64, 0) - storage := rc.GetDomain().Store() - - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnBR) - err := kv.RunInNewTxn( - ctx, - storage, - true, - func(ctx context.Context, txn kv.Transaction) error { - var e error - t := meta.NewMeta(txn) - ids, e = t.GenGlobalIDs(n) - return e - }) - - return ids, err -} - -// UpdateSchemaVersion updates schema version by transaction way. -func (rc *LogClient) UpdateSchemaVersion(ctx context.Context) error { - storage := rc.GetDomain().Store() - var schemaVersion int64 - - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnBR) - if err := kv.RunInNewTxn( - ctx, - storage, - true, - func(ctx context.Context, txn kv.Transaction) error { - t := meta.NewMeta(txn) - var e error - // To trigger full-reload instead of diff-reload, we need to increase the schema version - // by at least `domain.LoadSchemaDiffVersionGapThreshold`. 
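// GenSchemaVersions reserves the whole version gap in one call; the ActionNone
// SchemaDiff written just below marks the new version as valid so that TiDB
// instances do not keep polling for the intermediate diffs that were never
// generated.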
- schemaVersion, e = t.GenSchemaVersions(64 + domain.LoadSchemaDiffVersionGapThreshold) - if e != nil { - return e - } - // add the diff key so that the domain won't retry to reload the schemas with `schemaVersion` frequently. - return t.SetSchemaDiff(&model.SchemaDiff{ - Version: schemaVersion, - Type: model.ActionNone, - SchemaID: -1, - TableID: -1, - RegenerateSchemaMap: true, - }) - }, - ); err != nil { - return errors.Trace(err) - } - - log.Info("update global schema version", zap.Int64("global-schema-version", schemaVersion)) - - ver := strconv.FormatInt(schemaVersion, 10) - if err := ddlutil.PutKVToEtcd( - ctx, - rc.GetDomain().GetEtcdClient(), - math.MaxInt, - ddlutil.DDLGlobalSchemaVersion, - ver, - ); err != nil { - return errors.Annotatef(err, "failed to put global schema verson %v to etcd", ver) - } - - return nil -} - -func (rc *LogClient) WrapLogFilesIterWithSplitHelper(logIter LogIter, rules map[int64]*restoreutils.RewriteRules, g glue.Glue, store kv.Storage) (LogIter, error) { - se, err := g.CreateSession(store) - if err != nil { - return nil, errors.Trace(err) - } - execCtx := se.GetSessionCtx().GetRestrictedSQLExecutor() - splitSize, splitKeys := utils.GetRegionSplitInfo(execCtx) - log.Info("get split threshold from tikv config", zap.Uint64("split-size", splitSize), zap.Int64("split-keys", splitKeys)) - client := split.NewClient(rc.pdClient, rc.pdHTTPClient, rc.tlsConf, maxSplitKeysOnce, 3) - return NewLogFilesIterWithSplitHelper(logIter, rules, client, splitSize, splitKeys), nil -} - -func (rc *LogClient) generateKvFilesSkipMap(ctx context.Context, downstreamIdset map[int64]struct{}, taskName string) (*LogFilesSkipMap, error) { - skipMap := NewLogFilesSkipMap() - t, err := checkpoint.WalkCheckpointFileForRestore(ctx, rc.storage, rc.cipher, taskName, func(groupKey checkpoint.LogRestoreKeyType, off checkpoint.LogRestoreValueMarshaled) { - for tableID, foffs := range off.Foffs { - // filter out the checkpoint data of dropped table - if _, exists := downstreamIdset[tableID]; exists { - for _, foff := range foffs { - skipMap.Insert(groupKey, off.Goff, foff) - } - } - } - }) - if err != nil { - return nil, errors.Trace(err) - } - summary.AdjustStartTimeToEarlierTime(t) - return skipMap, nil -} - -func (rc *LogClient) WrapLogFilesIterWithCheckpoint( - ctx context.Context, - logIter LogIter, - downstreamIdset map[int64]struct{}, - taskName string, - updateStats func(kvCount, size uint64), - onProgress func(), -) (LogIter, error) { - skipMap, err := rc.generateKvFilesSkipMap(ctx, downstreamIdset, taskName) - if err != nil { - return nil, errors.Trace(err) - } - return iter.FilterOut(logIter, func(d *LogDataFileInfo) bool { - if skipMap.NeedSkip(d.MetaDataGroupName, d.OffsetInMetaGroup, d.OffsetInMergedGroup) { - onProgress() - updateStats(uint64(d.NumberOfEntries), d.Length) - return true - } - return false - }), nil -} - -const ( - alterTableDropIndexSQL = "ALTER TABLE %n.%n DROP INDEX %n" - alterTableAddIndexFormat = "ALTER TABLE %%n.%%n ADD INDEX %%n(%s)" - alterTableAddUniqueIndexFormat = "ALTER TABLE %%n.%%n ADD UNIQUE KEY %%n(%s)" - alterTableAddPrimaryFormat = "ALTER TABLE %%n.%%n ADD PRIMARY KEY (%s) NONCLUSTERED" -) - -func (rc *LogClient) generateRepairIngestIndexSQLs( - ctx context.Context, - ingestRecorder *ingestrec.IngestRecorder, - taskName string, -) ([]checkpoint.CheckpointIngestIndexRepairSQL, bool, error) { - var sqls []checkpoint.CheckpointIngestIndexRepairSQL - if rc.useCheckpoint { - exists, err := checkpoint.ExistsCheckpointIngestIndexRepairSQLs(ctx, 
rc.storage, taskName) - if err != nil { - return sqls, false, errors.Trace(err) - } - if exists { - checkpointSQLs, err := checkpoint.LoadCheckpointIngestIndexRepairSQLs(ctx, rc.storage, taskName) - if err != nil { - return sqls, false, errors.Trace(err) - } - sqls = checkpointSQLs.SQLs - log.Info("load ingest index repair sqls from checkpoint", zap.String("category", "ingest"), zap.Reflect("sqls", sqls)) - return sqls, true, nil - } - } - - if err := ingestRecorder.UpdateIndexInfo(rc.dom.InfoSchema()); err != nil { - return sqls, false, errors.Trace(err) - } - if err := ingestRecorder.Iterate(func(_, indexID int64, info *ingestrec.IngestIndexInfo) error { - var ( - addSQL strings.Builder - addArgs []any = make([]any, 0, 5+len(info.ColumnArgs)) - ) - if info.IsPrimary { - addSQL.WriteString(fmt.Sprintf(alterTableAddPrimaryFormat, info.ColumnList)) - addArgs = append(addArgs, info.SchemaName.O, info.TableName.O) - addArgs = append(addArgs, info.ColumnArgs...) - } else if info.IndexInfo.Unique { - addSQL.WriteString(fmt.Sprintf(alterTableAddUniqueIndexFormat, info.ColumnList)) - addArgs = append(addArgs, info.SchemaName.O, info.TableName.O, info.IndexInfo.Name.O) - addArgs = append(addArgs, info.ColumnArgs...) - } else { - addSQL.WriteString(fmt.Sprintf(alterTableAddIndexFormat, info.ColumnList)) - addArgs = append(addArgs, info.SchemaName.O, info.TableName.O, info.IndexInfo.Name.O) - addArgs = append(addArgs, info.ColumnArgs...) - } - // USING BTREE/HASH/RTREE - indexTypeStr := info.IndexInfo.Tp.String() - if len(indexTypeStr) > 0 { - addSQL.WriteString(" USING ") - addSQL.WriteString(indexTypeStr) - } - - // COMMENT [...] - if len(info.IndexInfo.Comment) > 0 { - addSQL.WriteString(" COMMENT %?") - addArgs = append(addArgs, info.IndexInfo.Comment) - } - - if info.IndexInfo.Invisible { - addSQL.WriteString(" INVISIBLE") - } else { - addSQL.WriteString(" VISIBLE") - } - - sqls = append(sqls, checkpoint.CheckpointIngestIndexRepairSQL{ - IndexID: indexID, - SchemaName: info.SchemaName, - TableName: info.TableName, - IndexName: info.IndexInfo.Name.O, - AddSQL: addSQL.String(), - AddArgs: addArgs, - }) - - return nil - }); err != nil { - return sqls, false, errors.Trace(err) - } - - if rc.useCheckpoint && len(sqls) > 0 { - if err := checkpoint.SaveCheckpointIngestIndexRepairSQLs(ctx, rc.storage, &checkpoint.CheckpointIngestIndexRepairSQLs{ - SQLs: sqls, - }, taskName); err != nil { - return sqls, false, errors.Trace(err) - } - } - return sqls, false, nil -} - -// RepairIngestIndex drops the indexes from IngestRecorder and re-add them. -func (rc *LogClient) RepairIngestIndex(ctx context.Context, ingestRecorder *ingestrec.IngestRecorder, g glue.Glue, taskName string) error { - sqls, fromCheckpoint, err := rc.generateRepairIngestIndexSQLs(ctx, ingestRecorder, taskName) - if err != nil { - return errors.Trace(err) - } - - info := rc.dom.InfoSchema() - console := glue.GetConsole(g) -NEXTSQL: - for _, sql := range sqls { - progressTitle := fmt.Sprintf("repair ingest index %s for table %s.%s", sql.IndexName, sql.SchemaName, sql.TableName) - - tableInfo, err := info.TableByName(ctx, sql.SchemaName, sql.TableName) - if err != nil { - return errors.Trace(err) - } - oldIndexIDFound := false - if fromCheckpoint { - for _, idx := range tableInfo.Indices() { - indexInfo := idx.Meta() - if indexInfo.ID == sql.IndexID { - // the original index id is not dropped - oldIndexIDFound = true - break - } - // what if index's state is not public? 
- if indexInfo.Name.O == sql.IndexName { - // find the same name index, but not the same index id, - // which means the repaired index id is created - if _, err := fmt.Fprintf(console.Out(), "%s ... %s\n", progressTitle, color.HiGreenString("SKIPPED DUE TO CHECKPOINT MODE")); err != nil { - return errors.Trace(err) - } - continue NEXTSQL - } - } - } - - if err := func(sql checkpoint.CheckpointIngestIndexRepairSQL) error { - w := console.StartProgressBar(progressTitle, glue.OnlyOneTask) - defer w.Close() - - // TODO: When the TiDB supports the DROP and CREATE the same name index in one SQL, - // the checkpoint for ingest recorder can be removed and directly use the SQL: - // ALTER TABLE db.tbl DROP INDEX `i_1`, ADD IDNEX `i_1` ... - // - // This SQL is compatible with checkpoint: If one ingest index has been recreated by - // the SQL, the index's id would be another one. In the next retry execution, BR can - // not find the ingest index's dropped id so that BR regards it as a dropped index by - // restored metakv and then skips repairing it. - - // only when first execution or old index id is not dropped - if !fromCheckpoint || oldIndexIDFound { - if err := rc.se.ExecuteInternal(ctx, alterTableDropIndexSQL, sql.SchemaName.O, sql.TableName.O, sql.IndexName); err != nil { - return errors.Trace(err) - } - } - failpoint.Inject("failed-before-create-ingest-index", func(v failpoint.Value) { - if v != nil && v.(bool) { - failpoint.Return(errors.New("failed before create ingest index")) - } - }) - // create the repaired index when first execution or not found it - if err := rc.se.ExecuteInternal(ctx, sql.AddSQL, sql.AddArgs...); err != nil { - return errors.Trace(err) - } - w.Inc() - if err := w.Wait(ctx); err != nil { - return errors.Trace(err) - } - return nil - }(sql); err != nil { - return errors.Trace(err) - } - } - - return nil -} - -func (rc *LogClient) RecordDeleteRange(sql *stream.PreDelRangeQuery) { - rc.deleteRangeQueryCh <- sql -} - -// use channel to save the delete-range query to make it thread-safety. 
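// A minimal, generic sketch of the same collect-through-a-channel pattern used
// by RunGCRowsLoader/GetGCRows below, assuming only "context" and "sync" from
// the standard library. The queryCollector type is hypothetical and exists
// only for this example: a single goroutine owns the slice, so appends never
// race with senders, and results() must only be called once all senders are done.
type queryCollector struct {
	ch  chan string
	wg  sync.WaitGroup
	out []string
}

func (c *queryCollector) run(ctx context.Context) {
	c.wg.Add(1)
	go func() {
		defer c.wg.Done()
		for {
			select {
			case <-ctx.Done():
				return
			case q, ok := <-c.ch:
				if !ok {
					return // channel closed and drained
				}
				c.out = append(c.out, q)
			}
		}
	}()
}

func (c *queryCollector) results() []string {
	close(c.ch) // signal the goroutine that no more queries will arrive
	c.wg.Wait()
	return c.out
}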
-func (rc *LogClient) RunGCRowsLoader(ctx context.Context) { - rc.deleteRangeQueryWaitGroup.Add(1) - - go func() { - defer rc.deleteRangeQueryWaitGroup.Done() - for { - select { - case <-ctx.Done(): - return - case query, ok := <-rc.deleteRangeQueryCh: - if !ok { - return - } - rc.deleteRangeQuery = append(rc.deleteRangeQuery, query) - } - } - }() -} - -// InsertGCRows insert the querys into table `gc_delete_range` -func (rc *LogClient) InsertGCRows(ctx context.Context) error { - close(rc.deleteRangeQueryCh) - rc.deleteRangeQueryWaitGroup.Wait() - ts, err := restore.GetTSWithRetry(ctx, rc.pdClient) - if err != nil { - return errors.Trace(err) - } - jobIDMap := make(map[int64]int64) - for _, query := range rc.deleteRangeQuery { - paramsList := make([]any, 0, len(query.ParamsList)*5) - for _, params := range query.ParamsList { - newJobID, exists := jobIDMap[params.JobID] - if !exists { - newJobID, err = rc.GenGlobalID(ctx) - if err != nil { - return errors.Trace(err) - } - jobIDMap[params.JobID] = newJobID - } - log.Info("insert into the delete range", - zap.Int64("jobID", newJobID), - zap.Int64("elemID", params.ElemID), - zap.String("startKey", params.StartKey), - zap.String("endKey", params.EndKey), - zap.Uint64("ts", ts)) - // (job_id, elem_id, start_key, end_key, ts) - paramsList = append(paramsList, newJobID, params.ElemID, params.StartKey, params.EndKey, ts) - } - if len(paramsList) > 0 { - // trim the ',' behind the query.Sql if exists - // that's when the rewrite rule of the last table id is not exist - sql := strings.TrimSuffix(query.Sql, ",") - if err := rc.se.ExecuteInternal(ctx, sql, paramsList...); err != nil { - return errors.Trace(err) - } - } - } - return nil -} - -// only for unit test -func (rc *LogClient) GetGCRows() []*stream.PreDelRangeQuery { - close(rc.deleteRangeQueryCh) - rc.deleteRangeQueryWaitGroup.Wait() - return rc.deleteRangeQuery -} - -// SaveIDMap saves the id mapping information. -func (rc *LogClient) SaveIDMap( - ctx context.Context, - sr *stream.SchemasReplace, -) error { - idMaps := sr.TidySchemaMaps() - clusterID := rc.GetClusterID(ctx) - metaFileName := metautil.PitrIDMapsFilename(clusterID, rc.restoreTS) - metaWriter := metautil.NewMetaWriter(rc.storage, metautil.MetaFileSize, false, metaFileName, nil) - metaWriter.Update(func(m *backuppb.BackupMeta) { - // save log startTS to backupmeta file - m.ClusterId = clusterID - m.DbMaps = idMaps - }) - - if err := metaWriter.FlushBackupMeta(ctx); err != nil { - return errors.Trace(err) - } - if rc.useCheckpoint { - var items map[int64]model.TiFlashReplicaInfo - if sr.TiflashRecorder != nil { - items = sr.TiflashRecorder.GetItems() - } - log.Info("save checkpoint task info with InLogRestoreAndIdMapPersist status") - if err := checkpoint.SaveCheckpointTaskInfoForLogRestore(ctx, rc.storage, &checkpoint.CheckpointTaskInfoForLogRestore{ - Progress: checkpoint.InLogRestoreAndIdMapPersist, - StartTS: rc.startTS, - RestoreTS: rc.restoreTS, - RewriteTS: rc.currentTS, - TiFlashItems: items, - }, rc.GetClusterID(ctx)); err != nil { - return errors.Trace(err) - } - } - return nil -} - -// called by failpoint, only used for test -// it would print the checksum result into the log, and -// the auto-test script records them to compare another -// cluster's checksum. 
-func (rc *LogClient) FailpointDoChecksumForLogRestore( - ctx context.Context, - kvClient kv.Client, - pdClient pd.Client, - idrules map[int64]int64, - rewriteRules map[int64]*restoreutils.RewriteRules, -) (finalErr error) { - startTS, err := restore.GetTSWithRetry(ctx, rc.pdClient) - if err != nil { - return errors.Trace(err) - } - // set gc safepoint for checksum - sp := utils.BRServiceSafePoint{ - BackupTS: startTS, - TTL: utils.DefaultBRGCSafePointTTL, - ID: utils.MakeSafePointID(), - } - cctx, gcSafePointKeeperCancel := context.WithCancel(ctx) - defer func() { - log.Info("start to remove gc-safepoint keeper") - // close the gc safe point keeper at first - gcSafePointKeeperCancel() - // set the ttl to 0 to remove the gc-safe-point - sp.TTL = 0 - if err := utils.UpdateServiceSafePoint(ctx, pdClient, sp); err != nil { - log.Warn("failed to update service safe point, backup may fail if gc triggered", - zap.Error(err), - ) - } - log.Info("finish removing gc-safepoint keeper") - }() - err = utils.StartServiceSafePointKeeper(cctx, pdClient, sp) - if err != nil { - return errors.Trace(err) - } - - eg, ectx := errgroup.WithContext(ctx) - pool := tidbutil.NewWorkerPool(4, "checksum for log restore") - infoSchema := rc.GetDomain().InfoSchema() - // downstream id -> upstream id - reidRules := make(map[int64]int64) - for upstreamID, downstreamID := range idrules { - reidRules[downstreamID] = upstreamID - } - for upstreamID, downstreamID := range idrules { - newTable, ok := infoSchema.TableByID(downstreamID) - if !ok { - // a dropped table - continue - } - rewriteRule, ok := rewriteRules[upstreamID] - if !ok { - continue - } - newTableInfo := newTable.Meta() - var definitions []model.PartitionDefinition - if newTableInfo.Partition != nil { - for _, def := range newTableInfo.Partition.Definitions { - upid, ok := reidRules[def.ID] - if !ok { - log.Panic("no rewrite rule for parition table id", zap.Int64("id", def.ID)) - } - definitions = append(definitions, model.PartitionDefinition{ - ID: upid, - }) - } - } - oldPartition := &model.PartitionInfo{ - Definitions: definitions, - } - oldTable := &metautil.Table{ - Info: &model.TableInfo{ - ID: upstreamID, - Indices: newTableInfo.Indices, - Partition: oldPartition, - }, - } - pool.ApplyOnErrorGroup(eg, func() error { - exe, err := checksum.NewExecutorBuilder(newTableInfo, startTS). - SetOldTable(oldTable). - SetConcurrency(4). - SetOldKeyspace(rewriteRule.OldKeyspace). - SetNewKeyspace(rewriteRule.NewKeyspace). - SetExplicitRequestSourceType(kvutil.ExplicitTypeBR). 
- Build() - if err != nil { - return errors.Trace(err) - } - checksumResp, err := exe.Execute(ectx, kvClient, func() {}) - if err != nil { - return errors.Trace(err) - } - // print to log so that the test script can get final checksum - log.Info("failpoint checksum completed", - zap.String("table-name", newTableInfo.Name.O), - zap.Int64("upstream-id", oldTable.Info.ID), - zap.Uint64("checksum", checksumResp.Checksum), - zap.Uint64("total-kvs", checksumResp.TotalKvs), - zap.Uint64("total-bytes", checksumResp.TotalBytes), - ) - return nil - }) - } - - return eg.Wait() -} - -type LogFilesIterWithSplitHelper struct { - iter LogIter - helper *logsplit.LogSplitHelper - buffer []*LogDataFileInfo - next int -} - -const SplitFilesBufferSize = 4096 - -func NewLogFilesIterWithSplitHelper(iter LogIter, rules map[int64]*restoreutils.RewriteRules, client split.SplitClient, splitSize uint64, splitKeys int64) LogIter { - return &LogFilesIterWithSplitHelper{ - iter: iter, - helper: logsplit.NewLogSplitHelper(rules, client, splitSize, splitKeys), - buffer: nil, - next: 0, - } -} - -func (splitIter *LogFilesIterWithSplitHelper) TryNext(ctx context.Context) iter.IterResult[*LogDataFileInfo] { - if splitIter.next >= len(splitIter.buffer) { - splitIter.buffer = make([]*LogDataFileInfo, 0, SplitFilesBufferSize) - for r := splitIter.iter.TryNext(ctx); !r.Finished; r = splitIter.iter.TryNext(ctx) { - if r.Err != nil { - return r - } - f := r.Item - splitIter.helper.Merge(f.DataFileInfo) - splitIter.buffer = append(splitIter.buffer, f) - if len(splitIter.buffer) >= SplitFilesBufferSize { - break - } - } - splitIter.next = 0 - if len(splitIter.buffer) == 0 { - return iter.Done[*LogDataFileInfo]() - } - log.Info("start to split the regions") - startTime := time.Now() - if err := splitIter.helper.Split(ctx); err != nil { - return iter.Throw[*LogDataFileInfo](errors.Trace(err)) - } - log.Info("end to split the regions", zap.Duration("takes", time.Since(startTime))) - } - - res := iter.Emit(splitIter.buffer[splitIter.next]) - splitIter.next += 1 - return res -} diff --git a/br/pkg/restore/log_client/import_retry.go b/br/pkg/restore/log_client/import_retry.go index ab7d60c7e98b1..93f454d6252e5 100644 --- a/br/pkg/restore/log_client/import_retry.go +++ b/br/pkg/restore/log_client/import_retry.go @@ -92,12 +92,12 @@ func (o *OverRegionsInRangeController) handleInRegionError(ctx context.Context, if strings.Contains(result.StoreError.GetMessage(), "memory is limited") { sleepDuration := 15 * time.Second - if val, _err_ := failpoint.Eval(_curpkg_("hint-memory-is-limited")); _err_ == nil { + failpoint.Inject("hint-memory-is-limited", func(val failpoint.Value) { if val.(bool) { logutil.CL(ctx).Debug("failpoint hint-memory-is-limited injected.") sleepDuration = 100 * time.Microsecond } - } + }) time.Sleep(sleepDuration) return true } diff --git a/br/pkg/restore/log_client/import_retry.go__failpoint_stash__ b/br/pkg/restore/log_client/import_retry.go__failpoint_stash__ deleted file mode 100644 index 93f454d6252e5..0000000000000 --- a/br/pkg/restore/log_client/import_retry.go__failpoint_stash__ +++ /dev/null @@ -1,284 +0,0 @@ -// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. 
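// For context, a minimal sketch of the marker form that the import_retry.go
// hunk above restores. As committed, failpoint.Inject is an empty marker that
// the failpoint-ctl tool expands into the failpoint.Eval/_curpkg_ form; the
// *__failpoint_stash__ and binding__failpoint_binding__.go files removed in
// this patch are artifacts of that expansion. The failpoint name, package
// path, and function below are made up for the example.
func sleepUnlessHinted(d time.Duration) {
	failpoint.Inject("shorten-sleep", func(val failpoint.Value) {
		if val.(bool) {
			d = 100 * time.Microsecond
		}
	})
	time.Sleep(d)
}

// A test would activate it roughly like:
//   _ = failpoint.Enable("github.com/example/pkg/shorten-sleep", "return(true)")
//   defer failpoint.Disable("github.com/example/pkg/shorten-sleep")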
- -package logclient - -import ( - "context" - "strings" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/errorpb" - "github.com/pingcap/kvproto/pkg/import_sstpb" - "github.com/pingcap/kvproto/pkg/metapb" - berrors "github.com/pingcap/tidb/br/pkg/errors" - "github.com/pingcap/tidb/br/pkg/logutil" - "github.com/pingcap/tidb/br/pkg/restore/split" - restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" - "github.com/pingcap/tidb/br/pkg/utils" - "github.com/tikv/client-go/v2/kv" - "go.uber.org/multierr" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -type RegionFunc func(ctx context.Context, r *split.RegionInfo) RPCResult - -type OverRegionsInRangeController struct { - start []byte - end []byte - metaClient split.SplitClient - - errors error - rs *utils.RetryState -} - -// OverRegionsInRange creates a controller that cloud be used to scan regions in a range and -// apply a function over these regions. -// You can then call the `Run` method for applying some functions. -func OverRegionsInRange(start, end []byte, metaClient split.SplitClient, retryStatus *utils.RetryState) OverRegionsInRangeController { - // IMPORTANT: we record the start/end key with TimeStamp. - // but scanRegion will drop the TimeStamp and the end key is exclusive. - // if we do not use PrefixNextKey. we might scan fewer regions than we expected. - // and finally cause the data lost. - end = restoreutils.TruncateTS(end) - end = kv.PrefixNextKey(end) - - return OverRegionsInRangeController{ - start: start, - end: end, - metaClient: metaClient, - rs: retryStatus, - } -} - -func (o *OverRegionsInRangeController) onError(_ context.Context, result RPCResult, region *split.RegionInfo) { - o.errors = multierr.Append(o.errors, errors.Annotatef(&result, "execute over region %v failed", region.Region)) - // TODO: Maybe handle some of region errors like `epoch not match`? -} - -func (o *OverRegionsInRangeController) tryFindLeader(ctx context.Context, region *split.RegionInfo) (*metapb.Peer, error) { - var leader *metapb.Peer - failed := false - leaderRs := utils.InitialRetryState(4, 5*time.Second, 10*time.Second) - err := utils.WithRetry(ctx, func() error { - r, err := o.metaClient.GetRegionByID(ctx, region.Region.Id) - if err != nil { - return err - } - if !split.CheckRegionEpoch(r, region) { - failed = true - return nil - } - if r.Leader != nil { - leader = r.Leader - return nil - } - return errors.Annotatef(berrors.ErrPDLeaderNotFound, "there is no leader for region %d", region.Region.Id) - }, &leaderRs) - if failed { - return nil, errors.Annotatef(berrors.ErrKVEpochNotMatch, "the current epoch of %s is changed", region) - } - if err != nil { - return nil, err - } - return leader, nil -} - -// handleInRegionError handles the error happens internal in the region. Update the region info, and perform a suitable backoff. 
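// The returned bool tells the caller whether to retry this region in place
// (possibly with a refreshed leader); when it is false, runInRegion falls back
// to rescanning the whole range via runOverRegions.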
-func (o *OverRegionsInRangeController) handleInRegionError(ctx context.Context, result RPCResult, region *split.RegionInfo) (cont bool) { - if result.StoreError.GetServerIsBusy() != nil { - if strings.Contains(result.StoreError.GetMessage(), "memory is limited") { - sleepDuration := 15 * time.Second - - failpoint.Inject("hint-memory-is-limited", func(val failpoint.Value) { - if val.(bool) { - logutil.CL(ctx).Debug("failpoint hint-memory-is-limited injected.") - sleepDuration = 100 * time.Microsecond - } - }) - time.Sleep(sleepDuration) - return true - } - } - - if nl := result.StoreError.GetNotLeader(); nl != nil { - if nl.Leader != nil { - region.Leader = nl.Leader - // try the new leader immediately. - return true - } - // we retry manually, simply record the retry event. - time.Sleep(o.rs.ExponentialBackoff()) - // There may not be leader, waiting... - leader, err := o.tryFindLeader(ctx, region) - if err != nil { - // Leave the region info unchanged, let it retry then. - logutil.CL(ctx).Warn("failed to find leader", logutil.Region(region.Region), logutil.ShortError(err)) - return false - } - region.Leader = leader - return true - } - // For other errors, like `ServerIsBusy`, `RegionIsNotInitialized`, just trivially backoff. - time.Sleep(o.rs.ExponentialBackoff()) - return true -} - -func (o *OverRegionsInRangeController) prepareLogCtx(ctx context.Context) context.Context { - lctx := logutil.ContextWithField( - ctx, - logutil.Key("startKey", o.start), - logutil.Key("endKey", o.end), - ) - return lctx -} - -// Run executes the `regionFunc` over the regions in `o.start` and `o.end`. -// It would retry the errors according to the `rpcResponse`. -func (o *OverRegionsInRangeController) Run(ctx context.Context, f RegionFunc) error { - return o.runOverRegions(o.prepareLogCtx(ctx), f) -} - -func (o *OverRegionsInRangeController) runOverRegions(ctx context.Context, f RegionFunc) error { - if !o.rs.ShouldRetry() { - return o.errors - } - - // Scan regions covered by the file range - regionInfos, errScanRegion := split.PaginateScanRegion( - ctx, o.metaClient, o.start, o.end, split.ScanRegionPaginationLimit) - if errScanRegion != nil { - return errors.Trace(errScanRegion) - } - - for _, region := range regionInfos { - cont, err := o.runInRegion(ctx, f, region) - if err != nil { - return err - } - if !cont { - return nil - } - } - return nil -} - -// runInRegion executes the function in the region, and returns `cont = false` if no need for trying for next region. -func (o *OverRegionsInRangeController) runInRegion(ctx context.Context, f RegionFunc, region *split.RegionInfo) (cont bool, err error) { - if !o.rs.ShouldRetry() { - return false, o.errors - } - result := f(ctx, region) - - if !result.OK() { - o.onError(ctx, result, region) - switch result.StrategyForRetry() { - case StrategyGiveUp: - logutil.CL(ctx).Warn("unexpected error, should stop to retry", logutil.ShortError(&result), logutil.Region(region.Region)) - return false, o.errors - case StrategyFromThisRegion: - logutil.CL(ctx).Warn("retry for region", logutil.Region(region.Region), logutil.ShortError(&result)) - if !o.handleInRegionError(ctx, result, region) { - return false, o.runOverRegions(ctx, f) - } - return o.runInRegion(ctx, f, region) - case StrategyFromStart: - logutil.CL(ctx).Warn("retry for execution over regions", logutil.ShortError(&result)) - // TODO: make a backoffer considering more about the error info, - // instead of ingore the result and retry. 
- time.Sleep(o.rs.ExponentialBackoff()) - return false, o.runOverRegions(ctx, f) - } - } - return true, nil -} - -// RPCResult is the result after executing some RPCs to TiKV. -type RPCResult struct { - Err error - - ImportError string - StoreError *errorpb.Error -} - -func RPCResultFromPBError(err *import_sstpb.Error) RPCResult { - return RPCResult{ - ImportError: err.GetMessage(), - StoreError: err.GetStoreError(), - } -} - -func RPCResultFromError(err error) RPCResult { - return RPCResult{ - Err: err, - } -} - -func RPCResultOK() RPCResult { - return RPCResult{} -} - -type RetryStrategy int - -const ( - StrategyGiveUp RetryStrategy = iota - StrategyFromThisRegion - StrategyFromStart -) - -func (r *RPCResult) StrategyForRetry() RetryStrategy { - if r.Err != nil { - return r.StrategyForRetryGoError() - } - return r.StrategyForRetryStoreError() -} - -func (r *RPCResult) StrategyForRetryStoreError() RetryStrategy { - if r.StoreError == nil && r.ImportError == "" { - return StrategyGiveUp - } - - if r.StoreError.GetServerIsBusy() != nil || - r.StoreError.GetRegionNotInitialized() != nil || - r.StoreError.GetNotLeader() != nil || - r.StoreError.GetServerIsBusy() != nil { - return StrategyFromThisRegion - } - - return StrategyFromStart -} - -func (r *RPCResult) StrategyForRetryGoError() RetryStrategy { - if r.Err == nil { - return StrategyGiveUp - } - - // we should unwrap the error or we cannot get the write gRPC status. - if gRPCErr, ok := status.FromError(errors.Cause(r.Err)); ok { - switch gRPCErr.Code() { - case codes.Unavailable, codes.Aborted, codes.ResourceExhausted, codes.DeadlineExceeded: - return StrategyFromThisRegion - } - } - - return StrategyGiveUp -} - -func (r *RPCResult) Error() string { - if r.Err != nil { - return r.Err.Error() - } - if r.StoreError != nil { - return r.StoreError.GetMessage() - } - if r.ImportError != "" { - return r.ImportError - } - return "BUG(There is no error but reported as error)" -} - -func (r *RPCResult) OK() bool { - return r.Err == nil && r.ImportError == "" && r.StoreError == nil -} diff --git a/br/pkg/restore/misc.go b/br/pkg/restore/misc.go index 469b4ea7b9cca..62d7fbc32fdb4 100644 --- a/br/pkg/restore/misc.go +++ b/br/pkg/restore/misc.go @@ -137,11 +137,11 @@ func GetTSWithRetry(ctx context.Context, pdClient pd.Client) (uint64, error) { err := utils.WithRetry(ctx, func() error { startTS, getTSErr = GetTS(ctx, pdClient) - if val, _err_ := failpoint.Eval(_curpkg_("get-ts-error")); _err_ == nil { + failpoint.Inject("get-ts-error", func(val failpoint.Value) { if val.(bool) && retry < 3 { getTSErr = errors.Errorf("rpc error: code = Unknown desc = [PD:tso:ErrGenerateTimestamp]generate timestamp failed, requested pd is not leader of cluster") } - } + }) retry++ if getTSErr != nil { diff --git a/br/pkg/restore/misc.go__failpoint_stash__ b/br/pkg/restore/misc.go__failpoint_stash__ deleted file mode 100644 index 62d7fbc32fdb4..0000000000000 --- a/br/pkg/restore/misc.go__failpoint_stash__ +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package restore - -import ( - "context" - "fmt" - "strings" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/log" - berrors "github.com/pingcap/tidb/br/pkg/errors" - "github.com/pingcap/tidb/br/pkg/logutil" - "github.com/pingcap/tidb/br/pkg/utils" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/parser/model" - tidbutil "github.com/pingcap/tidb/pkg/util" - "github.com/tikv/client-go/v2/oracle" - pd "github.com/tikv/pd/client" - "go.uber.org/zap" -) - -// deprecated parameter -type Granularity string - -const ( - FineGrained Granularity = "fine-grained" - CoarseGrained Granularity = "coarse-grained" -) - -type UniqueTableName struct { - DB string - Table string -} - -func TransferBoolToValue(enable bool) string { - if enable { - return "ON" - } - return "OFF" -} - -// GetTableSchema returns the schema of a table from TiDB. -func GetTableSchema( - dom *domain.Domain, - dbName model.CIStr, - tableName model.CIStr, -) (*model.TableInfo, error) { - info := dom.InfoSchema() - table, err := info.TableByName(context.Background(), dbName, tableName) - if err != nil { - return nil, errors.Trace(err) - } - return table.Meta(), nil -} - -const maxUserTablesNum = 10 - -// AssertUserDBsEmpty check whether user dbs exist in the cluster -func AssertUserDBsEmpty(dom *domain.Domain) error { - databases := dom.InfoSchema().AllSchemas() - m := meta.NewSnapshotMeta(dom.Store().GetSnapshot(kv.MaxVersion)) - userTables := make([]string, 0, maxUserTablesNum+1) - appendTables := func(dbName, tableName string) bool { - if len(userTables) >= maxUserTablesNum { - userTables = append(userTables, "...") - return true - } - userTables = append(userTables, fmt.Sprintf("%s.%s", dbName, tableName)) - return false - } -LISTDBS: - for _, db := range databases { - dbName := db.Name.L - if tidbutil.IsMemOrSysDB(dbName) { - continue - } - tables, err := m.ListSimpleTables(db.ID) - if err != nil { - return errors.Annotatef(err, "failed to iterator tables of database[id=%d]", db.ID) - } - if len(tables) == 0 { - // tidb create test db on fresh cluster - // if it's empty we don't take it as user db - if dbName != "test" { - if appendTables(db.Name.O, "") { - break LISTDBS - } - } - continue - } - for _, table := range tables { - if appendTables(db.Name.O, table.Name.O) { - break LISTDBS - } - } - } - if len(userTables) > 0 { - return errors.Annotate(berrors.ErrRestoreNotFreshCluster, - "user db/tables: "+strings.Join(userTables, ", ")) - } - return nil -} - -// GetTS gets a new timestamp from PD. -func GetTS(ctx context.Context, pdClient pd.Client) (uint64, error) { - p, l, err := pdClient.GetTS(ctx) - if err != nil { - return 0, errors.Trace(err) - } - restoreTS := oracle.ComposeTS(p, l) - return restoreTS, nil -} - -// GetTSWithRetry gets a new timestamp with retry from PD. 
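// It wraps GetTS with utils.WithRetry and a PD request backoffer; the
// "get-ts-error" failpoint below exists only to let tests simulate transient
// "pd is not leader" failures.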
-func GetTSWithRetry(ctx context.Context, pdClient pd.Client) (uint64, error) { - var ( - startTS uint64 - getTSErr error - retry uint - ) - - err := utils.WithRetry(ctx, func() error { - startTS, getTSErr = GetTS(ctx, pdClient) - failpoint.Inject("get-ts-error", func(val failpoint.Value) { - if val.(bool) && retry < 3 { - getTSErr = errors.Errorf("rpc error: code = Unknown desc = [PD:tso:ErrGenerateTimestamp]generate timestamp failed, requested pd is not leader of cluster") - } - }) - - retry++ - if getTSErr != nil { - log.Warn("failed to get TS, retry it", zap.Uint("retry time", retry), logutil.ShortError(getTSErr)) - } - return getTSErr - }, utils.NewPDReqBackoffer()) - - if err != nil { - log.Error("failed to get TS", zap.Error(err)) - } - return startTS, errors.Trace(err) -} diff --git a/br/pkg/restore/snap_client/binding__failpoint_binding__.go b/br/pkg/restore/snap_client/binding__failpoint_binding__.go deleted file mode 100644 index 777c78dc6633b..0000000000000 --- a/br/pkg/restore/snap_client/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package snapclient - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/br/pkg/restore/snap_client/client.go b/br/pkg/restore/snap_client/client.go index d1fb0f18c227f..061fad6388016 100644 --- a/br/pkg/restore/snap_client/client.go +++ b/br/pkg/restore/snap_client/client.go @@ -726,11 +726,11 @@ func (rc *SnapClient) createTablesInWorkerPool(ctx context.Context, tables []*me workers.ApplyWithIDInErrorGroup(eg, func(id uint64) error { db := rc.dbPool[id%uint64(len(rc.dbPool))] cts, err := rc.createTables(ectx, db, tableSlice, newTS) // ddl job for [lastSent:i) - if val, _err_ := failpoint.Eval(_curpkg_("restore-createtables-error")); _err_ == nil { + failpoint.Inject("restore-createtables-error", func(val failpoint.Value) { if val.(bool) { err = errors.New("sample error without extra message") } - } + }) if err != nil { log.Error("create tables fail", zap.Error(err)) return err @@ -825,9 +825,9 @@ func (rc *SnapClient) IsFullClusterRestore() bool { // IsFull returns whether this backup is full. func (rc *SnapClient) IsFull() bool { - if _, _err_ := failpoint.Eval(_curpkg_("mock-incr-backup-data")); _err_ == nil { - return false - } + failpoint.Inject("mock-incr-backup-data", func() { + failpoint.Return(false) + }) return !rc.IsIncremental() } diff --git a/br/pkg/restore/snap_client/client.go__failpoint_stash__ b/br/pkg/restore/snap_client/client.go__failpoint_stash__ deleted file mode 100644 index 061fad6388016..0000000000000 --- a/br/pkg/restore/snap_client/client.go__failpoint_stash__ +++ /dev/null @@ -1,1200 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package snapclient - -import ( - "bytes" - "cmp" - "context" - "crypto/tls" - "encoding/json" - "fmt" - "slices" - "strings" - "sync" - "time" - - "github.com/opentracing/opentracing-go" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - backuppb "github.com/pingcap/kvproto/pkg/brpb" - "github.com/pingcap/log" - "github.com/pingcap/tidb/br/pkg/checkpoint" - "github.com/pingcap/tidb/br/pkg/checksum" - "github.com/pingcap/tidb/br/pkg/conn" - "github.com/pingcap/tidb/br/pkg/conn/util" - berrors "github.com/pingcap/tidb/br/pkg/errors" - "github.com/pingcap/tidb/br/pkg/glue" - "github.com/pingcap/tidb/br/pkg/logutil" - "github.com/pingcap/tidb/br/pkg/metautil" - "github.com/pingcap/tidb/br/pkg/pdutil" - "github.com/pingcap/tidb/br/pkg/restore" - importclient "github.com/pingcap/tidb/br/pkg/restore/internal/import_client" - tidallocdb "github.com/pingcap/tidb/br/pkg/restore/internal/prealloc_db" - tidalloc "github.com/pingcap/tidb/br/pkg/restore/internal/prealloc_table_id" - internalutils "github.com/pingcap/tidb/br/pkg/restore/internal/utils" - "github.com/pingcap/tidb/br/pkg/restore/split" - restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" - "github.com/pingcap/tidb/br/pkg/rtree" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/br/pkg/summary" - "github.com/pingcap/tidb/br/pkg/utils" - "github.com/pingcap/tidb/br/pkg/version" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/parser/model" - tidbutil "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/redact" - kvutil "github.com/tikv/client-go/v2/util" - pd "github.com/tikv/pd/client" - pdhttp "github.com/tikv/pd/client/http" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" - "google.golang.org/grpc/keepalive" -) - -const ( - strictPlacementPolicyMode = "STRICT" - ignorePlacementPolicyMode = "IGNORE" - - defaultDDLConcurrency = 16 - maxSplitKeysOnce = 10240 -) - -const minBatchDdlSize = 1 - -type SnapClient struct { - // Tool clients used by SnapClient - fileImporter *SnapFileImporter - pdClient pd.Client - pdHTTPClient pdhttp.Client - - // User configurable parameters - cipher *backuppb.CipherInfo - concurrencyPerStore uint - keepaliveConf keepalive.ClientParameters - rateLimit uint64 - tlsConf *tls.Config - - switchCh chan struct{} - - storeCount int - supportPolicy bool - workerPool *tidbutil.WorkerPool - - noSchema bool - hasSpeedLimited bool - - databases map[string]*metautil.Database - ddlJobs []*model.Job - - // store tables need to rebase info like auto id and random id and so on after create table - rebasedTablesMap map[restore.UniqueTableName]bool - - backupMeta *backuppb.BackupMeta - - // TODO Remove this field or replace it with a []*DB, - // since https://github.com/pingcap/br/pull/377 needs more DBs to speed up DDL execution. - // And for now, we must inject a pool of DBs to `Client.GoCreateTables`, otherwise there would be a race condition. - // This is dirty: why we need DBs from different sources? - // By replace it with a []*DB, we can remove the dirty parameter of `Client.GoCreateTable`, - // along with them in some private functions. - // Before you do it, you can firstly read discussions at - // https://github.com/pingcap/br/pull/377#discussion_r446594501, - // this probably isn't as easy as it seems like (however, not hard, too :D) - db *tidallocdb.DB - - // use db pool to speed up restoration in BR binary mode. 
- dbPool []*tidallocdb.DB - - dom *domain.Domain - - // correspond to --tidb-placement-mode config. - // STRICT(default) means policy related SQL can be executed in tidb. - // IGNORE means policy related SQL will be ignored. - policyMode string - - // policy name -> policy info - policyMap *sync.Map - - batchDdlSize uint - - // if fullClusterRestore = true: - // - if there's system tables in the backup(backup data since br 5.1.0), the cluster should be a fresh cluster - // without user database or table. and system tables about privileges is restored together with user data. - // - if there no system tables in the backup(backup data from br < 5.1.0), restore all user data just like - // previous version did. - // if fullClusterRestore = false, restore all user data just like previous version did. - // fullClusterRestore = true when there is no explicit filter setting, and it's full restore or point command - // with a full backup data. - // todo: maybe change to an enum - // this feature is controlled by flag with-sys-table - fullClusterRestore bool - - // see RestoreCommonConfig.WithSysTable - withSysTable bool - - // the rewrite mode of the downloaded SST files in TiKV. - rewriteMode RewriteMode - - // checkpoint information for snapshot restore - checkpointRunner *checkpoint.CheckpointRunner[checkpoint.RestoreKeyType, checkpoint.RestoreValueType] - checkpointChecksum map[int64]*checkpoint.ChecksumItem -} - -// NewRestoreClient returns a new RestoreClient. -func NewRestoreClient( - pdClient pd.Client, - pdHTTPCli pdhttp.Client, - tlsConf *tls.Config, - keepaliveConf keepalive.ClientParameters, -) *SnapClient { - return &SnapClient{ - pdClient: pdClient, - pdHTTPClient: pdHTTPCli, - tlsConf: tlsConf, - keepaliveConf: keepaliveConf, - switchCh: make(chan struct{}), - } -} - -func (rc *SnapClient) closeConn() { - // rc.db can be nil in raw kv mode. - if rc.db != nil { - rc.db.Close() - } - for _, db := range rc.dbPool { - db.Close() - } -} - -// Close a client. -func (rc *SnapClient) Close() { - // close the connection, and it must be succeed when in SQL mode. - rc.closeConn() - - if err := rc.fileImporter.Close(); err != nil { - log.Warn("failed to close file improter") - } - - log.Info("Restore client closed") -} - -func (rc *SnapClient) SetRateLimit(rateLimit uint64) { - rc.rateLimit = rateLimit -} - -func (rc *SnapClient) SetCrypter(crypter *backuppb.CipherInfo) { - rc.cipher = crypter -} - -// GetClusterID gets the cluster id from down-stream cluster. -func (rc *SnapClient) GetClusterID(ctx context.Context) uint64 { - return rc.pdClient.GetClusterID(ctx) -} - -func (rc *SnapClient) GetDomain() *domain.Domain { - return rc.dom -} - -// GetTLSConfig returns the tls config. -func (rc *SnapClient) GetTLSConfig() *tls.Config { - return rc.tlsConf -} - -// GetSupportPolicy tells whether target tidb support placement policy. -func (rc *SnapClient) GetSupportPolicy() bool { - return rc.supportPolicy -} - -func (rc *SnapClient) updateConcurrency() { - // we believe 32 is large enough for download worker pool. - // it won't reach the limit if sst files distribute evenly. - // when restore memory usage is still too high, we should reduce concurrencyPerStore - // to sarifice some speed to reduce memory usage. - count := uint(rc.storeCount) * rc.concurrencyPerStore * 32 - log.Info("download coarse worker pool", zap.Uint("size", count)) - rc.workerPool = tidbutil.NewWorkerPool(count, "file") -} - -// SetConcurrencyPerStore sets the concurrency of download files for each store. 
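// The value set here feeds updateConcurrency above: for example, 3 TiKV stores
// with c = 8 would size the coarse download worker pool at 3 * 8 * 32 = 768
// workers (the numbers are illustrative, not defaults).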
-func (rc *SnapClient) SetConcurrencyPerStore(c uint) { - log.Info("per-store download worker pool", zap.Uint("size", c)) - rc.concurrencyPerStore = c -} - -func (rc *SnapClient) SetBatchDdlSize(batchDdlsize uint) { - rc.batchDdlSize = batchDdlsize -} - -func (rc *SnapClient) GetBatchDdlSize() uint { - return rc.batchDdlSize -} - -func (rc *SnapClient) SetWithSysTable(withSysTable bool) { - rc.withSysTable = withSysTable -} - -// TODO: remove this check and return RewriteModeKeyspace -func (rc *SnapClient) SetRewriteMode(ctx context.Context) { - if err := version.CheckClusterVersion(ctx, rc.pdClient, version.CheckVersionForKeyspaceBR); err != nil { - log.Warn("Keyspace BR is not supported in this cluster, fallback to legacy restore", zap.Error(err)) - rc.rewriteMode = RewriteModeLegacy - } else { - rc.rewriteMode = RewriteModeKeyspace - } -} - -func (rc *SnapClient) GetRewriteMode() RewriteMode { - return rc.rewriteMode -} - -// SetPlacementPolicyMode to policy mode. -func (rc *SnapClient) SetPlacementPolicyMode(withPlacementPolicy string) { - switch strings.ToUpper(withPlacementPolicy) { - case strictPlacementPolicyMode: - rc.policyMode = strictPlacementPolicyMode - case ignorePlacementPolicyMode: - rc.policyMode = ignorePlacementPolicyMode - default: - rc.policyMode = strictPlacementPolicyMode - } - log.Info("set placement policy mode", zap.String("mode", rc.policyMode)) -} - -// AllocTableIDs would pre-allocate the table's origin ID if exists, so that the TiKV doesn't need to rewrite the key in -// the download stage. -func (rc *SnapClient) AllocTableIDs(ctx context.Context, tables []*metautil.Table) error { - preallocedTableIDs := tidalloc.New(tables) - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnBR) - err := kv.RunInNewTxn(ctx, rc.GetDomain().Store(), true, func(_ context.Context, txn kv.Transaction) error { - return preallocedTableIDs.Alloc(meta.NewMeta(txn)) - }) - if err != nil { - return err - } - - log.Info("registering the table IDs", zap.Stringer("ids", preallocedTableIDs)) - for i := range rc.dbPool { - rc.dbPool[i].RegisterPreallocatedIDs(preallocedTableIDs) - } - if rc.db != nil { - rc.db.RegisterPreallocatedIDs(preallocedTableIDs) - } - return nil -} - -// InitCheckpoint initialize the checkpoint status for the cluster. If the cluster is -// restored for the first time, it will initialize the checkpoint metadata. Otherwrise, -// it will load checkpoint metadata and checkpoint ranges/checksum from the external -// storage. -func (rc *SnapClient) InitCheckpoint( - ctx context.Context, - s storage.ExternalStorage, - taskName string, - config *pdutil.ClusterConfig, - checkpointFirstRun bool, -) (map[int64]map[string]struct{}, *pdutil.ClusterConfig, error) { - var ( - // checkpoint sets distinguished by range key - checkpointSetWithTableID = make(map[int64]map[string]struct{}) - - checkpointClusterConfig *pdutil.ClusterConfig - - err error - ) - - if !checkpointFirstRun { - // load the checkpoint since this is not the first time to restore - meta, err := checkpoint.LoadCheckpointMetadataForRestore(ctx, s, taskName) - if err != nil { - return checkpointSetWithTableID, nil, errors.Trace(err) - } - - // The schedulers config is nil, so the restore-schedulers operation is just nil. - // Then the undo function would use the result undo of `remove schedulers` operation, - // instead of that in checkpoint meta. 
- if meta.SchedulersConfig != nil { - checkpointClusterConfig = meta.SchedulersConfig - } - - // t1 is the latest time the checkpoint ranges persisted to the external storage. - t1, err := checkpoint.WalkCheckpointFileForRestore(ctx, s, rc.cipher, taskName, func(tableID int64, rangeKey checkpoint.RestoreValueType) { - checkpointSet, exists := checkpointSetWithTableID[tableID] - if !exists { - checkpointSet = make(map[string]struct{}) - checkpointSetWithTableID[tableID] = checkpointSet - } - checkpointSet[rangeKey.RangeKey] = struct{}{} - }) - if err != nil { - return checkpointSetWithTableID, nil, errors.Trace(err) - } - // t2 is the latest time the checkpoint checksum persisted to the external storage. - checkpointChecksum, t2, err := checkpoint.LoadCheckpointChecksumForRestore(ctx, s, taskName) - if err != nil { - return checkpointSetWithTableID, nil, errors.Trace(err) - } - rc.checkpointChecksum = checkpointChecksum - // use the later time to adjust the summary elapsed time. - if t1 > t2 { - summary.AdjustStartTimeToEarlierTime(t1) - } else { - summary.AdjustStartTimeToEarlierTime(t2) - } - } else { - // initialize the checkpoint metadata since it is the first time to restore. - meta := &checkpoint.CheckpointMetadataForRestore{} - // a nil config means undo function - if config != nil { - meta.SchedulersConfig = &pdutil.ClusterConfig{Schedulers: config.Schedulers, ScheduleCfg: config.ScheduleCfg} - } - if err = checkpoint.SaveCheckpointMetadataForRestore(ctx, s, meta, taskName); err != nil { - return checkpointSetWithTableID, nil, errors.Trace(err) - } - } - - rc.checkpointRunner, err = checkpoint.StartCheckpointRunnerForRestore(ctx, s, rc.cipher, taskName) - return checkpointSetWithTableID, checkpointClusterConfig, errors.Trace(err) -} - -func (rc *SnapClient) WaitForFinishCheckpoint(ctx context.Context, flush bool) { - if rc.checkpointRunner != nil { - rc.checkpointRunner.WaitForFinish(ctx, flush) - } -} - -// makeDBPool makes a session pool with specficated size by sessionFactory. -func makeDBPool(size uint, dbFactory func() (*tidallocdb.DB, error)) ([]*tidallocdb.DB, error) { - dbPool := make([]*tidallocdb.DB, 0, size) - for i := uint(0); i < size; i++ { - db, e := dbFactory() - if e != nil { - return dbPool, e - } - if db != nil { - dbPool = append(dbPool, db) - } - } - return dbPool, nil -} - -// Init create db connection and domain for storage. -func (rc *SnapClient) Init(g glue.Glue, store kv.Storage) error { - // setDB must happen after set PolicyMode. - // we will use policyMode to set session variables. - var err error - rc.db, rc.supportPolicy, err = tidallocdb.NewDB(g, store, rc.policyMode) - if err != nil { - return errors.Trace(err) - } - rc.dom, err = g.GetDomain(store) - if err != nil { - return errors.Trace(err) - } - - // init backupMeta only for passing unit test - if rc.backupMeta == nil { - rc.backupMeta = new(backuppb.BackupMeta) - } - - // There are different ways to create session between in binary and in SQL. - // - // Maybe allow user modify the DDL concurrency isn't necessary, - // because executing DDL is really I/O bound (or, algorithm bound?), - // and we cost most of time at waiting DDL jobs be enqueued. - // So these jobs won't be faster or slower when machine become faster or slower, - // hence make it a fixed value would be fine. 
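// defaultDDLConcurrency (declared as 16 above) is that fixed pool size; the
// pooled sessions are later picked round-robin to send DDLs concurrently, as
// in createTablesInWorkerPool in the hunk above.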
- rc.dbPool, err = makeDBPool(defaultDDLConcurrency, func() (*tidallocdb.DB, error) { - db, _, err := tidallocdb.NewDB(g, store, rc.policyMode) - return db, err - }) - if err != nil { - log.Warn("create session pool failed, we will send DDLs only by created sessions", - zap.Error(err), - zap.Int("sessionCount", len(rc.dbPool)), - ) - } - return errors.Trace(err) -} - -func (rc *SnapClient) initClients(ctx context.Context, backend *backuppb.StorageBackend, isRawKvMode bool, isTxnKvMode bool) error { - stores, err := conn.GetAllTiKVStoresWithRetry(ctx, rc.pdClient, util.SkipTiFlash) - if err != nil { - return errors.Annotate(err, "failed to get stores") - } - rc.storeCount = len(stores) - rc.updateConcurrency() - - var splitClientOpts []split.ClientOptionalParameter - if isRawKvMode { - splitClientOpts = append(splitClientOpts, split.WithRawKV()) - } - metaClient := split.NewClient(rc.pdClient, rc.pdHTTPClient, rc.tlsConf, maxSplitKeysOnce, rc.storeCount+1, splitClientOpts...) - importCli := importclient.NewImportClient(metaClient, rc.tlsConf, rc.keepaliveConf) - rc.fileImporter, err = NewSnapFileImporter(ctx, metaClient, importCli, backend, isRawKvMode, isTxnKvMode, stores, rc.rewriteMode, rc.concurrencyPerStore) - return errors.Trace(err) -} - -func (rc *SnapClient) needLoadSchemas(backupMeta *backuppb.BackupMeta) bool { - return !(backupMeta.IsRawKv || backupMeta.IsTxnKv) -} - -// InitBackupMeta loads schemas from BackupMeta to initialize RestoreClient. -func (rc *SnapClient) InitBackupMeta( - c context.Context, - backupMeta *backuppb.BackupMeta, - backend *backuppb.StorageBackend, - reader *metautil.MetaReader, - loadStats bool) error { - if rc.needLoadSchemas(backupMeta) { - databases, err := metautil.LoadBackupTables(c, reader, loadStats) - if err != nil { - return errors.Trace(err) - } - rc.databases = databases - - var ddlJobs []*model.Job - // ddls is the bytes of json.Marshal - ddls, err := reader.ReadDDLs(c) - if err != nil { - return errors.Trace(err) - } - if len(ddls) != 0 { - err = json.Unmarshal(ddls, &ddlJobs) - if err != nil { - return errors.Trace(err) - } - } - rc.ddlJobs = ddlJobs - } - rc.backupMeta = backupMeta - log.Info("load backupmeta", zap.Int("databases", len(rc.databases)), zap.Int("jobs", len(rc.ddlJobs))) - - return rc.initClients(c, backend, backupMeta.IsRawKv, backupMeta.IsTxnKv) -} - -// IsRawKvMode checks whether the backup data is in raw kv format, in which case transactional recover is forbidden. -func (rc *SnapClient) IsRawKvMode() bool { - return rc.backupMeta.IsRawKv -} - -// GetFilesInRawRange gets all files that are in the given range or intersects with the given range. -func (rc *SnapClient) GetFilesInRawRange(startKey []byte, endKey []byte, cf string) ([]*backuppb.File, error) { - if !rc.IsRawKvMode() { - return nil, errors.Annotate(berrors.ErrRestoreModeMismatch, "the backup data is not in raw kv mode") - } - - for _, rawRange := range rc.backupMeta.RawRanges { - // First check whether the given range is backup-ed. If not, we cannot perform the restore. - if rawRange.Cf != cf { - continue - } - - if (len(rawRange.EndKey) > 0 && bytes.Compare(startKey, rawRange.EndKey) >= 0) || - (len(endKey) > 0 && bytes.Compare(rawRange.StartKey, endKey) >= 0) { - // The restoring range is totally out of the current range. Skip it. - continue - } - - if bytes.Compare(startKey, rawRange.StartKey) < 0 || - utils.CompareEndKey(endKey, rawRange.EndKey) > 0 { - // Only partial of the restoring range is in the current backup-ed range. 
So the given range can't be fully - // restored. - return nil, errors.Annotatef(berrors.ErrRestoreRangeMismatch, - "the given range to restore [%s, %s) is not fully covered by the range that was backed up [%s, %s)", - redact.Key(startKey), redact.Key(endKey), redact.Key(rawRange.StartKey), redact.Key(rawRange.EndKey), - ) - } - - // We have found the range that contains the given range. Find all necessary files. - files := make([]*backuppb.File, 0) - - for _, file := range rc.backupMeta.Files { - if file.Cf != cf { - continue - } - - if len(file.EndKey) > 0 && bytes.Compare(file.EndKey, startKey) < 0 { - // The file is before the range to be restored. - continue - } - if len(endKey) > 0 && bytes.Compare(endKey, file.StartKey) <= 0 { - // The file is after the range to be restored. - // The specified endKey is exclusive, so when it equals to a file's startKey, the file is still skipped. - continue - } - - files = append(files, file) - } - - // There should be at most one backed up range that covers the restoring range. - return files, nil - } - - return nil, errors.Annotate(berrors.ErrRestoreRangeMismatch, "no backup data in the range") -} - -// ResetTS resets the timestamp of PD to a bigger value. -func (rc *SnapClient) ResetTS(ctx context.Context, pdCtrl *pdutil.PdController) error { - restoreTS := rc.backupMeta.GetEndVersion() - log.Info("reset pd timestamp", zap.Uint64("ts", restoreTS)) - return utils.WithRetry(ctx, func() error { - return pdCtrl.ResetTS(ctx, restoreTS) - }, utils.NewPDReqBackoffer()) -} - -// GetDatabases returns all databases. -func (rc *SnapClient) GetDatabases() []*metautil.Database { - dbs := make([]*metautil.Database, 0, len(rc.databases)) - for _, db := range rc.databases { - dbs = append(dbs, db) - } - return dbs -} - -// HasBackedUpSysDB whether we have backed up system tables -// br backs system tables up since 5.1.0 -func (rc *SnapClient) HasBackedUpSysDB() bool { - sysDBs := []string{"mysql", "sys"} - for _, db := range sysDBs { - temporaryDB := utils.TemporaryDBName(db) - _, backedUp := rc.databases[temporaryDB.O] - if backedUp { - return true - } - } - return false -} - -// GetPlacementPolicies returns policies. -func (rc *SnapClient) GetPlacementPolicies() (*sync.Map, error) { - policies := &sync.Map{} - for _, p := range rc.backupMeta.Policies { - policyInfo := &model.PolicyInfo{} - err := json.Unmarshal(p.Info, policyInfo) - if err != nil { - return nil, errors.Trace(err) - } - policies.Store(policyInfo.Name.L, policyInfo) - } - return policies, nil -} - -// GetDDLJobs returns ddl jobs. -func (rc *SnapClient) GetDDLJobs() []*model.Job { - return rc.ddlJobs -} - -// SetPolicyMap set policyMap. -func (rc *SnapClient) SetPolicyMap(p *sync.Map) { - rc.policyMap = p -} - -// CreatePolicies creates all policies in full restore. -func (rc *SnapClient) CreatePolicies(ctx context.Context, policyMap *sync.Map) error { - var err error - policyMap.Range(func(key, value any) bool { - e := rc.db.CreatePlacementPolicy(ctx, value.(*model.PolicyInfo)) - if e != nil { - err = e - return false - } - return true - }) - return err -} - -// CreateDatabases creates databases. If the client has the db pool, it would create it. 
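[Editor's note on the raw-range checks in GetFilesInRawRange above; illustrative only, not part of the patch: backup ranges and file ranges are half-open, and an empty end key means "unbounded". A compact model of the overlap test, using only the standard library bytes package:]

// rangesOverlap reports whether [aStart, aEnd) and [bStart, bEnd) intersect,
// where an empty end key is treated as +infinity, as in the checks above.
func rangesOverlap(aStart, aEnd, bStart, bEnd []byte) bool {
	if len(aEnd) > 0 && bytes.Compare(bStart, aEnd) >= 0 {
		return false // b begins at or after a ends
	}
	if len(bEnd) > 0 && bytes.Compare(aStart, bEnd) >= 0 {
		return false // a begins at or after b ends
	}
	return true
}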
-func (rc *SnapClient) CreateDatabases(ctx context.Context, dbs []*metautil.Database) error { - if rc.IsSkipCreateSQL() { - log.Info("skip create database") - return nil - } - - if len(rc.dbPool) == 0 { - log.Info("create databases sequentially") - for _, db := range dbs { - err := rc.db.CreateDatabase(ctx, db.Info, rc.supportPolicy, rc.policyMap) - if err != nil { - return errors.Trace(err) - } - } - return nil - } - - log.Info("create databases in db pool", zap.Int("pool size", len(rc.dbPool))) - eg, ectx := errgroup.WithContext(ctx) - workers := tidbutil.NewWorkerPool(uint(len(rc.dbPool)), "DB DDL workers") - for _, db_ := range dbs { - db := db_ - workers.ApplyWithIDInErrorGroup(eg, func(id uint64) error { - conn := rc.dbPool[id%uint64(len(rc.dbPool))] - return conn.CreateDatabase(ectx, db.Info, rc.supportPolicy, rc.policyMap) - }) - } - return eg.Wait() -} - -// generateRebasedTables generate a map[UniqueTableName]bool to represent tables that haven't updated table info. -// there are two situations: -// 1. tables that already exists in the restored cluster. -// 2. tables that are created by executing ddl jobs. -// so, only tables in incremental restoration will be added to the map -func (rc *SnapClient) generateRebasedTables(tables []*metautil.Table) { - if !rc.IsIncremental() { - // in full restoration, all tables are created by Session.CreateTable, and all tables' info is updated. - rc.rebasedTablesMap = make(map[restore.UniqueTableName]bool) - return - } - - rc.rebasedTablesMap = make(map[restore.UniqueTableName]bool, len(tables)) - for _, table := range tables { - rc.rebasedTablesMap[restore.UniqueTableName{DB: table.DB.Name.String(), Table: table.Info.Name.String()}] = true - } -} - -// getRebasedTables returns tables that may need to be rebase auto increment id or auto random id -func (rc *SnapClient) getRebasedTables() map[restore.UniqueTableName]bool { - return rc.rebasedTablesMap -} - -func (rc *SnapClient) createTables( - ctx context.Context, - db *tidallocdb.DB, - tables []*metautil.Table, - newTS uint64, -) ([]CreatedTable, error) { - log.Info("client to create tables") - if rc.IsSkipCreateSQL() { - log.Info("skip create table and alter autoIncID") - } else { - err := db.CreateTables(ctx, tables, rc.getRebasedTables(), rc.supportPolicy, rc.policyMap) - if err != nil { - return nil, errors.Trace(err) - } - } - cts := make([]CreatedTable, 0, len(tables)) - for _, table := range tables { - newTableInfo, err := restore.GetTableSchema(rc.dom, table.DB.Name, table.Info.Name) - if err != nil { - return nil, errors.Trace(err) - } - if newTableInfo.IsCommonHandle != table.Info.IsCommonHandle { - return nil, errors.Annotatef(berrors.ErrRestoreModeMismatch, - "Clustered index option mismatch. 
Restored cluster's @@tidb_enable_clustered_index should be %v (backup table = %v, created table = %v).", - restore.TransferBoolToValue(table.Info.IsCommonHandle), - table.Info.IsCommonHandle, - newTableInfo.IsCommonHandle) - } - rules := restoreutils.GetRewriteRules(newTableInfo, table.Info, newTS, true) - ct := CreatedTable{ - RewriteRule: rules, - Table: newTableInfo, - OldTable: table, - } - log.Debug("new created tables", zap.Any("table", ct)) - cts = append(cts, ct) - } - return cts, nil -} - -func (rc *SnapClient) createTablesInWorkerPool(ctx context.Context, tables []*metautil.Table, newTS uint64, outCh chan<- CreatedTable) error { - eg, ectx := errgroup.WithContext(ctx) - rater := logutil.TraceRateOver(logutil.MetricTableCreatedCounter) - workers := tidbutil.NewWorkerPool(uint(len(rc.dbPool)), "Create Tables Worker") - numOfTables := len(tables) - - for lastSent := 0; lastSent < numOfTables; lastSent += int(rc.batchDdlSize) { - end := min(lastSent+int(rc.batchDdlSize), len(tables)) - log.Info("create tables", zap.Int("table start", lastSent), zap.Int("table end", end)) - - tableSlice := tables[lastSent:end] - workers.ApplyWithIDInErrorGroup(eg, func(id uint64) error { - db := rc.dbPool[id%uint64(len(rc.dbPool))] - cts, err := rc.createTables(ectx, db, tableSlice, newTS) // ddl job for [lastSent:i) - failpoint.Inject("restore-createtables-error", func(val failpoint.Value) { - if val.(bool) { - err = errors.New("sample error without extra message") - } - }) - if err != nil { - log.Error("create tables fail", zap.Error(err)) - return err - } - for _, ct := range cts { - log.Debug("table created and send to next", - zap.Int("output chan size", len(outCh)), - zap.Stringer("table", ct.OldTable.Info.Name), - zap.Stringer("database", ct.OldTable.DB.Name)) - outCh <- ct - rater.Inc() - rater.L().Info("table created", - zap.Stringer("table", ct.OldTable.Info.Name), - zap.Stringer("database", ct.OldTable.DB.Name)) - } - return err - }) - } - return eg.Wait() -} - -func (rc *SnapClient) createTable( - ctx context.Context, - db *tidallocdb.DB, - table *metautil.Table, - newTS uint64, -) (CreatedTable, error) { - if rc.IsSkipCreateSQL() { - log.Info("skip create table and alter autoIncID", zap.Stringer("table", table.Info.Name)) - } else { - err := db.CreateTable(ctx, table, rc.getRebasedTables(), rc.supportPolicy, rc.policyMap) - if err != nil { - return CreatedTable{}, errors.Trace(err) - } - } - newTableInfo, err := restore.GetTableSchema(rc.dom, table.DB.Name, table.Info.Name) - if err != nil { - return CreatedTable{}, errors.Trace(err) - } - if newTableInfo.IsCommonHandle != table.Info.IsCommonHandle { - return CreatedTable{}, errors.Annotatef(berrors.ErrRestoreModeMismatch, - "Clustered index option mismatch. 
Restored cluster's @@tidb_enable_clustered_index should be %v (backup table = %v, created table = %v).", - restore.TransferBoolToValue(table.Info.IsCommonHandle), - table.Info.IsCommonHandle, - newTableInfo.IsCommonHandle) - } - rules := restoreutils.GetRewriteRules(newTableInfo, table.Info, newTS, true) - et := CreatedTable{ - RewriteRule: rules, - Table: newTableInfo, - OldTable: table, - } - return et, nil -} - -func (rc *SnapClient) createTablesWithSoleDB(ctx context.Context, - createOneTable func(ctx context.Context, db *tidallocdb.DB, t *metautil.Table) error, - tables []*metautil.Table) error { - for _, t := range tables { - if err := createOneTable(ctx, rc.db, t); err != nil { - return errors.Trace(err) - } - } - return nil -} - -func (rc *SnapClient) createTablesWithDBPool(ctx context.Context, - createOneTable func(ctx context.Context, db *tidallocdb.DB, t *metautil.Table) error, - tables []*metautil.Table) error { - eg, ectx := errgroup.WithContext(ctx) - workers := tidbutil.NewWorkerPool(uint(len(rc.dbPool)), "DDL workers") - for _, t := range tables { - table := t - workers.ApplyWithIDInErrorGroup(eg, func(id uint64) error { - db := rc.dbPool[id%uint64(len(rc.dbPool))] - return createOneTable(ectx, db, table) - }) - } - return eg.Wait() -} - -// InitFullClusterRestore init fullClusterRestore and set SkipGrantTable as needed -func (rc *SnapClient) InitFullClusterRestore(explicitFilter bool) { - rc.fullClusterRestore = !explicitFilter && rc.IsFull() - - log.Info("full cluster restore", zap.Bool("value", rc.fullClusterRestore)) -} - -func (rc *SnapClient) IsFullClusterRestore() bool { - return rc.fullClusterRestore -} - -// IsFull returns whether this backup is full. -func (rc *SnapClient) IsFull() bool { - failpoint.Inject("mock-incr-backup-data", func() { - failpoint.Return(false) - }) - return !rc.IsIncremental() -} - -// IsIncremental returns whether this backup is incremental. -func (rc *SnapClient) IsIncremental() bool { - return !(rc.backupMeta.StartVersion == rc.backupMeta.EndVersion || - rc.backupMeta.StartVersion == 0) -} - -// NeedCheckFreshCluster is every time. except restore from a checkpoint or user has not set filter argument. -func (rc *SnapClient) NeedCheckFreshCluster(ExplicitFilter bool, firstRun bool) bool { - return rc.IsFull() && !ExplicitFilter && firstRun -} - -// EnableSkipCreateSQL sets switch of skip create schema and tables. -func (rc *SnapClient) EnableSkipCreateSQL() { - rc.noSchema = true -} - -// IsSkipCreateSQL returns whether we need skip create schema and tables in restore. -func (rc *SnapClient) IsSkipCreateSQL() bool { - return rc.noSchema -} - -// CheckTargetClusterFresh check whether the target cluster is fresh or not -// if there's no user dbs or tables, we take it as a fresh cluster, although -// user may have created some users or made other changes. -func (rc *SnapClient) CheckTargetClusterFresh(ctx context.Context) error { - log.Info("checking whether target cluster is fresh") - return restore.AssertUserDBsEmpty(rc.dom) -} - -// ExecDDLs executes the queries of the ddl jobs. -func (rc *SnapClient) ExecDDLs(ctx context.Context, ddlJobs []*model.Job) error { - // Sort the ddl jobs by schema version in ascending order. 
- slices.SortFunc(ddlJobs, func(i, j *model.Job) int { - return cmp.Compare(i.BinlogInfo.SchemaVersion, j.BinlogInfo.SchemaVersion) - }) - - for _, job := range ddlJobs { - err := rc.db.ExecDDL(ctx, job) - if err != nil { - return errors.Trace(err) - } - log.Info("execute ddl query", - zap.String("db", job.SchemaName), - zap.String("query", job.Query), - zap.Int64("historySchemaVersion", job.BinlogInfo.SchemaVersion)) - } - return nil -} - -func (rc *SnapClient) ResetSpeedLimit(ctx context.Context) error { - rc.hasSpeedLimited = false - err := rc.setSpeedLimit(ctx, 0) - if err != nil { - return errors.Trace(err) - } - return nil -} - -func (rc *SnapClient) setSpeedLimit(ctx context.Context, rateLimit uint64) error { - if !rc.hasSpeedLimited { - stores, err := util.GetAllTiKVStores(ctx, rc.pdClient, util.SkipTiFlash) - if err != nil { - return errors.Trace(err) - } - - eg, ectx := errgroup.WithContext(ctx) - for _, store := range stores { - if err := ectx.Err(); err != nil { - return errors.Trace(err) - } - - finalStore := store - rc.workerPool.ApplyOnErrorGroup(eg, - func() error { - err := rc.fileImporter.SetDownloadSpeedLimit(ectx, finalStore.GetId(), rateLimit) - if err != nil { - return errors.Trace(err) - } - return nil - }) - } - - if err := eg.Wait(); err != nil { - return errors.Trace(err) - } - rc.hasSpeedLimited = true - } - return nil -} - -func getFileRangeKey(f string) string { - // the backup date file pattern is `{store_id}_{region_id}_{epoch_version}_{key}_{ts}_{cf}.sst` - // so we need to compare with out the `_{cf}.sst` suffix - idx := strings.LastIndex(f, "_") - if idx < 0 { - panic(fmt.Sprintf("invalid backup data file name: '%s'", f)) - } - - return f[:idx] -} - -// isFilesBelongToSameRange check whether two files are belong to the same range with different cf. -func isFilesBelongToSameRange(f1, f2 string) bool { - return getFileRangeKey(f1) == getFileRangeKey(f2) -} - -func drainFilesByRange(files []*backuppb.File) ([]*backuppb.File, []*backuppb.File) { - if len(files) == 0 { - return nil, nil - } - idx := 1 - for idx < len(files) { - if !isFilesBelongToSameRange(files[idx-1].Name, files[idx].Name) { - break - } - idx++ - } - - return files[:idx], files[idx:] -} - -// RestoreSSTFiles tries to restore the files. 
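[Editor's note, usage illustration only, not part of the patch; the file names are hypothetical but follow the `{store_id}_{region_id}_{epoch_version}_{key}_{ts}_{cf}.sst` pattern described above: getFileRangeKey strips the trailing `_{cf}.sst`, so the write-CF and default-CF files of one range are batched together by drainFilesByRange.]

func exampleSameRange() bool {
	write := "4_120_7_6d795f6b6579_449573_write.sst"
	deflt := "4_120_7_6d795f6b6579_449573_default.sst"
	// Both reduce to "4_120_7_6d795f6b6579_449573", so they are drained as one range.
	return isFilesBelongToSameRange(write, deflt) // true
}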
-func (rc *SnapClient) RestoreSSTFiles( - ctx context.Context, - tableIDWithFiles []TableIDWithFiles, - updateCh glue.Progress, -) (err error) { - start := time.Now() - fileCount := 0 - defer func() { - elapsed := time.Since(start) - if err == nil { - log.Info("Restore files", zap.Duration("take", elapsed)) - summary.CollectSuccessUnit("files", fileCount, elapsed) - } - }() - - log.Debug("start to restore files", zap.Int("files", fileCount)) - - if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { - span1 := span.Tracer().StartSpan("Client.RestoreSSTFiles", opentracing.ChildOf(span.Context())) - defer span1.Finish() - ctx = opentracing.ContextWithSpan(ctx, span1) - } - - eg, ectx := errgroup.WithContext(ctx) - err = rc.setSpeedLimit(ctx, rc.rateLimit) - if err != nil { - return errors.Trace(err) - } - - var rangeFiles []*backuppb.File - var leftFiles []*backuppb.File -LOOPFORTABLE: - for _, tableIDWithFile := range tableIDWithFiles { - tableID := tableIDWithFile.TableID - files := tableIDWithFile.Files - rules := tableIDWithFile.RewriteRules - fileCount += len(files) - for rangeFiles, leftFiles = drainFilesByRange(files); len(rangeFiles) != 0; rangeFiles, leftFiles = drainFilesByRange(leftFiles) { - if ectx.Err() != nil { - log.Warn("Restoring encountered error and already stopped, give up remained files.", - zap.Int("remained", len(leftFiles)), - logutil.ShortError(ectx.Err())) - // We will fetch the error from the errgroup then (If there were). - // Also note if the parent context has been canceled or something, - // breaking here directly is also a reasonable behavior. - break LOOPFORTABLE - } - filesReplica := rangeFiles - rc.fileImporter.WaitUntilUnblock() - rc.workerPool.ApplyOnErrorGroup(eg, func() (restoreErr error) { - fileStart := time.Now() - defer func() { - if restoreErr == nil { - log.Info("import files done", logutil.Files(filesReplica), - zap.Duration("take", time.Since(fileStart))) - updateCh.Inc() - } - }() - if importErr := rc.fileImporter.ImportSSTFiles(ectx, filesReplica, rules, rc.cipher, rc.dom.Store().GetCodec().GetAPIVersion()); importErr != nil { - return errors.Trace(importErr) - } - - // the data of this range has been import done - if rc.checkpointRunner != nil && len(filesReplica) > 0 { - rangeKey := getFileRangeKey(filesReplica[0].Name) - // The checkpoint range shows this ranges of kvs has been restored into - // the table corresponding to the table-id. - if err := checkpoint.AppendRangesForRestore(ectx, rc.checkpointRunner, tableID, rangeKey); err != nil { - return errors.Trace(err) - } - } - return nil - }) - } - } - - if err := eg.Wait(); err != nil { - summary.CollectFailureUnit("file", err) - log.Error( - "restore files failed", - zap.Error(err), - ) - return errors.Trace(err) - } - // Once the parent context canceled and there is no task running in the errgroup, - // we may break the for loop without error in the errgroup. (Will this happen?) - // At that time, return the error in the context here. 
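[Editor's note, not part of the patch: the ordering here is deliberate. Worker errors from the errgroup are surfaced first, and only then is the parent context consulted, so a cancellation that merely broke the dispatch loop is still reported. A condensed sketch of that shape, using the same golang.org/x/sync/errgroup package already imported in this file:]

func waitThenCheckCtx(ctx context.Context, eg *errgroup.Group) error {
	if err := eg.Wait(); err != nil {
		return err // a real restore failure wins over plain cancellation
	}
	return ctx.Err() // nil on success; context.Canceled or DeadlineExceeded otherwise
}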
- return ctx.Err() -} - -func (rc *SnapClient) execChecksum( - ctx context.Context, - tbl *CreatedTable, - kvClient kv.Client, - concurrency uint, -) error { - logger := log.L().With( - zap.String("db", tbl.OldTable.DB.Name.O), - zap.String("table", tbl.OldTable.Info.Name.O), - ) - - if tbl.OldTable.NoChecksum() { - logger.Warn("table has no checksum, skipping checksum") - return nil - } - - if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { - span1 := span.Tracer().StartSpan("Client.execChecksum", opentracing.ChildOf(span.Context())) - defer span1.Finish() - ctx = opentracing.ContextWithSpan(ctx, span1) - } - - item, exists := rc.checkpointChecksum[tbl.Table.ID] - if !exists { - startTS, err := restore.GetTSWithRetry(ctx, rc.pdClient) - if err != nil { - return errors.Trace(err) - } - exe, err := checksum.NewExecutorBuilder(tbl.Table, startTS). - SetOldTable(tbl.OldTable). - SetConcurrency(concurrency). - SetOldKeyspace(tbl.RewriteRule.OldKeyspace). - SetNewKeyspace(tbl.RewriteRule.NewKeyspace). - SetExplicitRequestSourceType(kvutil.ExplicitTypeBR). - Build() - if err != nil { - return errors.Trace(err) - } - checksumResp, err := exe.Execute(ctx, kvClient, func() { - // TODO: update progress here. - }) - if err != nil { - return errors.Trace(err) - } - item = &checkpoint.ChecksumItem{ - TableID: tbl.Table.ID, - Crc64xor: checksumResp.Checksum, - TotalKvs: checksumResp.TotalKvs, - TotalBytes: checksumResp.TotalBytes, - } - if rc.checkpointRunner != nil { - err = rc.checkpointRunner.FlushChecksumItem(ctx, item) - if err != nil { - return errors.Trace(err) - } - } - } - table := tbl.OldTable - if item.Crc64xor != table.Crc64Xor || - item.TotalKvs != table.TotalKvs || - item.TotalBytes != table.TotalBytes { - logger.Error("failed in validate checksum", - zap.Uint64("origin tidb crc64", table.Crc64Xor), - zap.Uint64("calculated crc64", item.Crc64xor), - zap.Uint64("origin tidb total kvs", table.TotalKvs), - zap.Uint64("calculated total kvs", item.TotalKvs), - zap.Uint64("origin tidb total bytes", table.TotalBytes), - zap.Uint64("calculated total bytes", item.TotalBytes), - ) - return errors.Annotate(berrors.ErrRestoreChecksumMismatch, "failed to validate checksum") - } - logger.Info("success in validate checksum") - return nil -} - -func (rc *SnapClient) WaitForFilesRestored(ctx context.Context, files []*backuppb.File, updateCh glue.Progress) error { - errCh := make(chan error, len(files)) - eg, ectx := errgroup.WithContext(ctx) - defer close(errCh) - - for _, file := range files { - fileReplica := file - rc.workerPool.ApplyOnErrorGroup(eg, - func() error { - defer func() { - log.Info("import sst files done", logutil.Files(files)) - updateCh.Inc() - }() - return rc.fileImporter.ImportSSTFiles(ectx, []*backuppb.File{fileReplica}, restoreutils.EmptyRewriteRule(), rc.cipher, rc.backupMeta.ApiVersion) - }) - } - if err := eg.Wait(); err != nil { - return errors.Trace(err) - } - return nil -} - -// RestoreRaw tries to restore raw keys in the specified range. 
-func (rc *SnapClient) RestoreRaw( - ctx context.Context, startKey []byte, endKey []byte, files []*backuppb.File, updateCh glue.Progress, -) error { - start := time.Now() - defer func() { - elapsed := time.Since(start) - log.Info("Restore Raw", - logutil.Key("startKey", startKey), - logutil.Key("endKey", endKey), - zap.Duration("take", elapsed)) - }() - err := rc.fileImporter.SetRawRange(startKey, endKey) - if err != nil { - return errors.Trace(err) - } - - err = rc.WaitForFilesRestored(ctx, files, updateCh) - if err != nil { - return errors.Trace(err) - } - log.Info( - "finish to restore raw range", - logutil.Key("startKey", startKey), - logutil.Key("endKey", endKey), - ) - return nil -} - -// SplitRanges implements TiKVRestorer. It splits region by -// data range after rewrite. -func (rc *SnapClient) SplitRanges( - ctx context.Context, - ranges []rtree.Range, - updateCh glue.Progress, - isRawKv bool, -) error { - splitClientOpts := make([]split.ClientOptionalParameter, 0, 2) - splitClientOpts = append(splitClientOpts, split.WithOnSplit(func(keys [][]byte) { - for range keys { - updateCh.Inc() - } - })) - if isRawKv { - splitClientOpts = append(splitClientOpts, split.WithRawKV()) - } - - splitter := internalutils.NewRegionSplitter(split.NewClient( - rc.pdClient, - rc.pdHTTPClient, - rc.tlsConf, - maxSplitKeysOnce, - rc.storeCount+1, - splitClientOpts..., - )) - - return splitter.ExecuteSplit(ctx, ranges) -} diff --git a/br/pkg/restore/snap_client/context_manager.go b/br/pkg/restore/snap_client/context_manager.go index 1c5a2569a6a00..294f774630db6 100644 --- a/br/pkg/restore/snap_client/context_manager.go +++ b/br/pkg/restore/snap_client/context_manager.go @@ -242,10 +242,10 @@ func (manager *brContextManager) waitPlacementSchedule(ctx context.Context, tabl } log.Info("start waiting placement schedule") ticker := time.NewTicker(time.Second * 10) - if _, _err_ := failpoint.Eval(_curpkg_("wait-placement-schedule-quicker-ticker")); _err_ == nil { + failpoint.Inject("wait-placement-schedule-quicker-ticker", func() { ticker.Stop() ticker = time.NewTicker(time.Millisecond * 500) - } + }) defer ticker.Stop() for { select { diff --git a/br/pkg/restore/snap_client/context_manager.go__failpoint_stash__ b/br/pkg/restore/snap_client/context_manager.go__failpoint_stash__ deleted file mode 100644 index 294f774630db6..0000000000000 --- a/br/pkg/restore/snap_client/context_manager.go__failpoint_stash__ +++ /dev/null @@ -1,290 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package snapclient - -import ( - "context" - "crypto/tls" - "encoding/hex" - "fmt" - "strconv" - "sync" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/log" - "github.com/pingcap/tidb/br/pkg/conn" - "github.com/pingcap/tidb/br/pkg/conn/util" - berrors "github.com/pingcap/tidb/br/pkg/errors" - "github.com/pingcap/tidb/br/pkg/restore/split" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/util/codec" - pd "github.com/tikv/pd/client" - pdhttp "github.com/tikv/pd/client/http" - "go.uber.org/zap" -) - -// ContextManager is the struct to manage a TiKV 'context' for restore. -// Batcher will call Enter when any table should be restore on batch, -// so you can do some prepare work here(e.g. set placement rules for online restore). -type ContextManager interface { - // Enter make some tables 'enter' this context(a.k.a., prepare for restore). - Enter(ctx context.Context, tables []CreatedTable) error - // Leave make some tables 'leave' this context(a.k.a., restore is done, do some post-works). - Leave(ctx context.Context, tables []CreatedTable) error - // Close closes the context manager, sometimes when the manager is 'killed' and should do some cleanup - // it would be call. - Close(ctx context.Context) -} - -// NewBRContextManager makes a BR context manager, that is, -// set placement rules for online restore when enter(see ), -// unset them when leave. -func NewBRContextManager(ctx context.Context, pdClient pd.Client, pdHTTPCli pdhttp.Client, tlsConf *tls.Config, isOnline bool) (ContextManager, error) { - manager := &brContextManager{ - // toolClient reuse the split.SplitClient to do miscellaneous things. It doesn't - // call split related functions so set the arguments to arbitrary values. - toolClient: split.NewClient(pdClient, pdHTTPCli, tlsConf, maxSplitKeysOnce, 3), - isOnline: isOnline, - - hasTable: make(map[int64]CreatedTable), - } - - err := manager.loadRestoreStores(ctx, pdClient) - return manager, errors.Trace(err) -} - -type brContextManager struct { - toolClient split.SplitClient - restoreStores []uint64 - isOnline bool - - // This 'set' of table ID allow us to handle each table just once. 
- hasTable map[int64]CreatedTable - mu sync.Mutex -} - -func (manager *brContextManager) Close(ctx context.Context) { - tbls := make([]*model.TableInfo, 0, len(manager.hasTable)) - for _, tbl := range manager.hasTable { - tbls = append(tbls, tbl.Table) - } - manager.splitPostWork(ctx, tbls) -} - -func (manager *brContextManager) Enter(ctx context.Context, tables []CreatedTable) error { - placementRuleTables := make([]*model.TableInfo, 0, len(tables)) - manager.mu.Lock() - defer manager.mu.Unlock() - - for _, tbl := range tables { - if _, ok := manager.hasTable[tbl.Table.ID]; !ok { - placementRuleTables = append(placementRuleTables, tbl.Table) - } - manager.hasTable[tbl.Table.ID] = tbl - } - - return manager.splitPrepareWork(ctx, placementRuleTables) -} - -func (manager *brContextManager) Leave(ctx context.Context, tables []CreatedTable) error { - manager.mu.Lock() - defer manager.mu.Unlock() - placementRuleTables := make([]*model.TableInfo, 0, len(tables)) - - for _, table := range tables { - placementRuleTables = append(placementRuleTables, table.Table) - } - - manager.splitPostWork(ctx, placementRuleTables) - log.Info("restore table done", zapTables(tables)) - for _, tbl := range placementRuleTables { - delete(manager.hasTable, tbl.ID) - } - return nil -} - -func (manager *brContextManager) splitPostWork(ctx context.Context, tables []*model.TableInfo) { - err := manager.resetPlacementRules(ctx, tables) - if err != nil { - log.Warn("reset placement rules failed", zap.Error(err)) - return - } -} - -func (manager *brContextManager) splitPrepareWork(ctx context.Context, tables []*model.TableInfo) error { - err := manager.setupPlacementRules(ctx, tables) - if err != nil { - log.Error("setup placement rules failed", zap.Error(err)) - return errors.Trace(err) - } - - err = manager.waitPlacementSchedule(ctx, tables) - if err != nil { - log.Error("wait placement schedule failed", zap.Error(err)) - return errors.Trace(err) - } - return nil -} - -const ( - restoreLabelKey = "exclusive" - restoreLabelValue = "restore" -) - -// loadRestoreStores loads the stores used to restore data. This function is called only when is online. -func (manager *brContextManager) loadRestoreStores(ctx context.Context, pdClient util.StoreMeta) error { - if !manager.isOnline { - return nil - } - stores, err := conn.GetAllTiKVStoresWithRetry(ctx, pdClient, util.SkipTiFlash) - if err != nil { - return errors.Trace(err) - } - for _, s := range stores { - if s.GetState() != metapb.StoreState_Up { - continue - } - for _, l := range s.GetLabels() { - if l.GetKey() == restoreLabelKey && l.GetValue() == restoreLabelValue { - manager.restoreStores = append(manager.restoreStores, s.GetId()) - break - } - } - } - log.Info("load restore stores", zap.Uint64s("store-ids", manager.restoreStores)) - return nil -} - -// SetupPlacementRules sets rules for the tables' regions. 
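[Editor's note, illustrative only, not part of the patch: each online-restore placement rule below covers exactly one table's key span, [tablePrefix(t.ID), tablePrefix(t.ID+1)), encoded and hex-printed for PD. A small helper restating that derivation with the same codec/tablecodec calls used in the function:]

func tableRuleRange(tableID int64) (startHex, endHex string) {
	start := codec.EncodeBytes([]byte{}, tablecodec.EncodeTablePrefix(tableID))
	end := codec.EncodeBytes([]byte{}, tablecodec.EncodeTablePrefix(tableID+1))
	return hex.EncodeToString(start), hex.EncodeToString(end)
}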
-func (manager *brContextManager) setupPlacementRules(ctx context.Context, tables []*model.TableInfo) error { - if !manager.isOnline || len(manager.restoreStores) == 0 { - return nil - } - log.Info("start setting placement rules") - rule, err := manager.toolClient.GetPlacementRule(ctx, "pd", "default") - if err != nil { - return errors.Trace(err) - } - rule.Index = 100 - rule.Override = true - rule.LabelConstraints = append(rule.LabelConstraints, pdhttp.LabelConstraint{ - Key: restoreLabelKey, - Op: "in", - Values: []string{restoreLabelValue}, - }) - for _, t := range tables { - rule.ID = getRuleID(t.ID) - rule.StartKeyHex = hex.EncodeToString(codec.EncodeBytes([]byte{}, tablecodec.EncodeTablePrefix(t.ID))) - rule.EndKeyHex = hex.EncodeToString(codec.EncodeBytes([]byte{}, tablecodec.EncodeTablePrefix(t.ID+1))) - err = manager.toolClient.SetPlacementRule(ctx, rule) - if err != nil { - return errors.Trace(err) - } - } - log.Info("finish setting placement rules") - return nil -} - -func (manager *brContextManager) checkRegions(ctx context.Context, tables []*model.TableInfo) (bool, string, error) { - for i, t := range tables { - start := codec.EncodeBytes([]byte{}, tablecodec.EncodeTablePrefix(t.ID)) - end := codec.EncodeBytes([]byte{}, tablecodec.EncodeTablePrefix(t.ID+1)) - ok, regionProgress, err := manager.checkRange(ctx, start, end) - if err != nil { - return false, "", errors.Trace(err) - } - if !ok { - return false, fmt.Sprintf("table %v/%v, %s", i, len(tables), regionProgress), nil - } - } - return true, "", nil -} - -func (manager *brContextManager) checkRange(ctx context.Context, start, end []byte) (bool, string, error) { - regions, err := manager.toolClient.ScanRegions(ctx, start, end, -1) - if err != nil { - return false, "", errors.Trace(err) - } - for i, r := range regions { - NEXT_PEER: - for _, p := range r.Region.GetPeers() { - for _, storeID := range manager.restoreStores { - if p.GetStoreId() == storeID { - continue NEXT_PEER - } - } - return false, fmt.Sprintf("region %v/%v", i, len(regions)), nil - } - } - return true, "", nil -} - -// waitPlacementSchedule waits PD to move tables to restore stores. -func (manager *brContextManager) waitPlacementSchedule(ctx context.Context, tables []*model.TableInfo) error { - if !manager.isOnline || len(manager.restoreStores) == 0 { - return nil - } - log.Info("start waiting placement schedule") - ticker := time.NewTicker(time.Second * 10) - failpoint.Inject("wait-placement-schedule-quicker-ticker", func() { - ticker.Stop() - ticker = time.NewTicker(time.Millisecond * 500) - }) - defer ticker.Stop() - for { - select { - case <-ticker.C: - ok, progress, err := manager.checkRegions(ctx, tables) - if err != nil { - return errors.Trace(err) - } - if ok { - log.Info("finish waiting placement schedule") - return nil - } - log.Info("placement schedule progress: " + progress) - case <-ctx.Done(): - return ctx.Err() - } - } -} - -func getRuleID(tableID int64) string { - return "restore-t" + strconv.FormatInt(tableID, 10) -} - -// resetPlacementRules removes placement rules for tables. 
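[Editor's note, trivial usage example, not part of the patch: rule IDs are derived purely from the table ID, so setup in setupPlacementRules and cleanup in resetPlacementRules always address the same rule.]

func exampleRuleID() string {
	return getRuleID(42) // "restore-t42"
}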
-func (manager *brContextManager) resetPlacementRules(ctx context.Context, tables []*model.TableInfo) error { - if !manager.isOnline || len(manager.restoreStores) == 0 { - return nil - } - log.Info("start resetting placement rules") - var failedTables []int64 - for _, t := range tables { - err := manager.toolClient.DeletePlacementRule(ctx, "pd", getRuleID(t.ID)) - if err != nil { - log.Info("failed to delete placement rule for table", zap.Int64("table-id", t.ID)) - failedTables = append(failedTables, t.ID) - } - } - if len(failedTables) > 0 { - return errors.Annotatef(berrors.ErrPDInvalidResponse, "failed to delete placement rules for tables %v", failedTables) - } - return nil -} diff --git a/br/pkg/restore/snap_client/import.go b/br/pkg/restore/snap_client/import.go index 1a27d67011066..cdab5a678628a 100644 --- a/br/pkg/restore/snap_client/import.go +++ b/br/pkg/restore/snap_client/import.go @@ -489,15 +489,15 @@ func (importer *SnapFileImporter) download( downloadMetas, e = importer.downloadSST(ctx, regionInfo, files, rewriteRules, cipher, apiVersion) } - if val, _err_ := failpoint.Eval(_curpkg_("restore-storage-error")); _err_ == nil { + failpoint.Inject("restore-storage-error", func(val failpoint.Value) { msg := val.(string) log.Debug("failpoint restore-storage-error injected.", zap.String("msg", msg)) e = errors.Annotate(e, msg) - } - if _, _err_ := failpoint.Eval(_curpkg_("restore-gRPC-error")); _err_ == nil { + }) + failpoint.Inject("restore-gRPC-error", func(_ failpoint.Value) { log.Warn("the connection to TiKV has been cut by a neko, meow :3") e = status.Error(codes.Unavailable, "the connection to TiKV has been cut by a neko, meow :3") - } + }) if isDecryptSstErr(e) { log.Info("fail to decrypt when download sst, try again with no-crypt", logutil.Files(files)) if importer.kvMode == Raw || importer.kvMode == Txn { diff --git a/br/pkg/restore/snap_client/import.go__failpoint_stash__ b/br/pkg/restore/snap_client/import.go__failpoint_stash__ deleted file mode 100644 index cdab5a678628a..0000000000000 --- a/br/pkg/restore/snap_client/import.go__failpoint_stash__ +++ /dev/null @@ -1,846 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package snapclient - -import ( - "bytes" - "context" - "fmt" - "math/rand" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/google/uuid" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - backuppb "github.com/pingcap/kvproto/pkg/brpb" - "github.com/pingcap/kvproto/pkg/import_sstpb" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/log" - berrors "github.com/pingcap/tidb/br/pkg/errors" - "github.com/pingcap/tidb/br/pkg/logutil" - importclient "github.com/pingcap/tidb/br/pkg/restore/internal/import_client" - "github.com/pingcap/tidb/br/pkg/restore/split" - restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" - "github.com/pingcap/tidb/br/pkg/summary" - "github.com/pingcap/tidb/br/pkg/utils" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/util/codec" - kvutil "github.com/tikv/client-go/v2/util" - "go.uber.org/multierr" - "go.uber.org/zap" - "golang.org/x/exp/maps" - "golang.org/x/sync/errgroup" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -type KvMode int - -const ( - TiDB KvMode = iota - Raw - Txn -) - -const ( - // Todo: make it configable - gRPCTimeOut = 25 * time.Minute -) - -// RewriteMode is a mode flag that tells the TiKV how to handle the rewrite rules. -type RewriteMode int - -const ( - // RewriteModeLegacy means no rewrite rule is applied. - RewriteModeLegacy RewriteMode = iota - - // RewriteModeKeyspace means the rewrite rule could be applied to keyspace. - RewriteModeKeyspace -) - -type storeTokenChannelMap struct { - sync.RWMutex - tokens map[uint64]chan struct{} -} - -func (s *storeTokenChannelMap) acquireTokenCh(storeID uint64, bufferSize uint) chan struct{} { - s.RLock() - tokenCh, ok := s.tokens[storeID] - // handle the case that the store is new-scaled in the cluster - if !ok { - s.RUnlock() - s.Lock() - // Notice: worker channel can't replaced, because it is still used after unlock. 
- if tokenCh, ok = s.tokens[storeID]; !ok { - tokenCh = utils.BuildWorkerTokenChannel(bufferSize) - s.tokens[storeID] = tokenCh - } - s.Unlock() - } else { - s.RUnlock() - } - return tokenCh -} - -func (s *storeTokenChannelMap) ShouldBlock() bool { - s.RLock() - defer s.RUnlock() - if len(s.tokens) == 0 { - // never block if there is no store worker pool - return false - } - for _, pool := range s.tokens { - if len(pool) > 0 { - // At least one store worker pool has available worker - return false - } - } - return true -} - -func newStoreTokenChannelMap(stores []*metapb.Store, bufferSize uint) *storeTokenChannelMap { - storeTokenChannelMap := &storeTokenChannelMap{ - sync.RWMutex{}, - make(map[uint64]chan struct{}), - } - if bufferSize == 0 { - return storeTokenChannelMap - } - for _, store := range stores { - ch := utils.BuildWorkerTokenChannel(bufferSize) - storeTokenChannelMap.tokens[store.Id] = ch - } - return storeTokenChannelMap -} - -type SnapFileImporter struct { - metaClient split.SplitClient - importClient importclient.ImporterClient - backend *backuppb.StorageBackend - - downloadTokensMap *storeTokenChannelMap - ingestTokensMap *storeTokenChannelMap - - concurrencyPerStore uint - - kvMode KvMode - rawStartKey []byte - rawEndKey []byte - rewriteMode RewriteMode - - cacheKey string - cond *sync.Cond -} - -func NewSnapFileImporter( - ctx context.Context, - metaClient split.SplitClient, - importClient importclient.ImporterClient, - backend *backuppb.StorageBackend, - isRawKvMode bool, - isTxnKvMode bool, - tikvStores []*metapb.Store, - rewriteMode RewriteMode, - concurrencyPerStore uint, -) (*SnapFileImporter, error) { - kvMode := TiDB - if isRawKvMode { - kvMode = Raw - } - if isTxnKvMode { - kvMode = Txn - } - - fileImporter := &SnapFileImporter{ - metaClient: metaClient, - backend: backend, - importClient: importClient, - downloadTokensMap: newStoreTokenChannelMap(tikvStores, concurrencyPerStore), - ingestTokensMap: newStoreTokenChannelMap(tikvStores, concurrencyPerStore), - kvMode: kvMode, - rewriteMode: rewriteMode, - cacheKey: fmt.Sprintf("BR-%s-%d", time.Now().Format("20060102150405"), rand.Int63()), - concurrencyPerStore: concurrencyPerStore, - cond: sync.NewCond(new(sync.Mutex)), - } - - err := fileImporter.checkMultiIngestSupport(ctx, tikvStores) - return fileImporter, errors.Trace(err) -} - -func (importer *SnapFileImporter) WaitUntilUnblock() { - importer.cond.L.Lock() - for importer.ShouldBlock() { - // wait for download worker notified - importer.cond.Wait() - } - importer.cond.L.Unlock() -} - -func (importer *SnapFileImporter) ShouldBlock() bool { - if importer != nil { - return importer.downloadTokensMap.ShouldBlock() || importer.ingestTokensMap.ShouldBlock() - } - return false -} - -func (importer *SnapFileImporter) releaseToken(tokenCh chan struct{}) { - tokenCh <- struct{}{} - // finish the task, notify the main goroutine to continue - importer.cond.L.Lock() - importer.cond.Signal() - importer.cond.L.Unlock() -} - -func (importer *SnapFileImporter) Close() error { - if importer != nil && importer.importClient != nil { - return importer.importClient.CloseGrpcClient() - } - return nil -} - -func (importer *SnapFileImporter) SetDownloadSpeedLimit(ctx context.Context, storeID, rateLimit uint64) error { - req := &import_sstpb.SetDownloadSpeedLimitRequest{ - SpeedLimit: rateLimit, - } - _, err := importer.importClient.SetDownloadSpeedLimit(ctx, storeID, req) - return errors.Trace(err) -} - -// checkMultiIngestSupport checks whether all stores support multi-ingest -func 
(importer *SnapFileImporter) checkMultiIngestSupport(ctx context.Context, tikvStores []*metapb.Store) error { - storeIDs := make([]uint64, 0, len(tikvStores)) - for _, s := range tikvStores { - if s.State != metapb.StoreState_Up { - continue - } - storeIDs = append(storeIDs, s.Id) - } - - if err := importer.importClient.CheckMultiIngestSupport(ctx, storeIDs); err != nil { - return errors.Trace(err) - } - return nil -} - -// SetRawRange sets the range to be restored in raw kv mode. -func (importer *SnapFileImporter) SetRawRange(startKey, endKey []byte) error { - if importer.kvMode != Raw { - return errors.Annotate(berrors.ErrRestoreModeMismatch, "file importer is not in raw kv mode") - } - importer.rawStartKey = startKey - importer.rawEndKey = endKey - return nil -} - -func getKeyRangeByMode(mode KvMode) func(f *backuppb.File, rules *restoreutils.RewriteRules) ([]byte, []byte, error) { - switch mode { - case Raw: - return func(f *backuppb.File, rules *restoreutils.RewriteRules) ([]byte, []byte, error) { - return f.GetStartKey(), f.GetEndKey(), nil - } - case Txn: - return func(f *backuppb.File, rules *restoreutils.RewriteRules) ([]byte, []byte, error) { - start, end := f.GetStartKey(), f.GetEndKey() - if len(start) != 0 { - start = codec.EncodeBytes([]byte{}, f.GetStartKey()) - } - if len(end) != 0 { - end = codec.EncodeBytes([]byte{}, f.GetEndKey()) - } - return start, end, nil - } - default: - return func(f *backuppb.File, rules *restoreutils.RewriteRules) ([]byte, []byte, error) { - return restoreutils.GetRewriteRawKeys(f, rules) - } - } -} - -// getKeyRangeForFiles gets the maximum range on files. -func (importer *SnapFileImporter) getKeyRangeForFiles( - files []*backuppb.File, - rewriteRules *restoreutils.RewriteRules, -) ([]byte, []byte, error) { - var ( - startKey, endKey []byte - start, end []byte - err error - ) - getRangeFn := getKeyRangeByMode(importer.kvMode) - for _, f := range files { - start, end, err = getRangeFn(f, rewriteRules) - if err != nil { - return nil, nil, errors.Trace(err) - } - if len(startKey) == 0 || bytes.Compare(start, startKey) < 0 { - startKey = start - } - if len(endKey) == 0 || bytes.Compare(endKey, end) < 0 { - endKey = end - } - } - - log.Debug("rewrite file keys", logutil.Files(files), - logutil.Key("startKey", startKey), logutil.Key("endKey", endKey)) - return startKey, endKey, nil -} - -// ImportSSTFiles tries to import a file. -// All rules must contain encoded keys. -func (importer *SnapFileImporter) ImportSSTFiles( - ctx context.Context, - files []*backuppb.File, - rewriteRules *restoreutils.RewriteRules, - cipher *backuppb.CipherInfo, - apiVersion kvrpcpb.APIVersion, -) error { - start := time.Now() - log.Debug("import file", logutil.Files(files)) - - // Rewrite the start key and end key of file to scan regions - startKey, endKey, err := importer.getKeyRangeForFiles(files, rewriteRules) - if err != nil { - return errors.Trace(err) - } - - err = utils.WithRetry(ctx, func() error { - // Scan regions covered by the file range - regionInfos, errScanRegion := split.PaginateScanRegion( - ctx, importer.metaClient, startKey, endKey, split.ScanRegionPaginationLimit) - if errScanRegion != nil { - return errors.Trace(errScanRegion) - } - - log.Debug("scan regions", logutil.Files(files), zap.Int("count", len(regionInfos))) - // Try to download and ingest the file in every region - regionLoop: - for _, regionInfo := range regionInfos { - info := regionInfo - // Try to download file. 
- downloadMetas, errDownload := importer.download(ctx, info, files, rewriteRules, cipher, apiVersion) - if errDownload != nil { - for _, e := range multierr.Errors(errDownload) { - switch errors.Cause(e) { // nolint:errorlint - case berrors.ErrKVRewriteRuleNotFound, berrors.ErrKVRangeIsEmpty: - // Skip this region - log.Warn("download file skipped", - logutil.Files(files), - logutil.Region(info.Region), - logutil.Key("startKey", startKey), - logutil.Key("endKey", endKey), - logutil.Key("file-simple-start", files[0].StartKey), - logutil.Key("file-simple-end", files[0].EndKey), - logutil.ShortError(e)) - continue regionLoop - } - } - log.Warn("download file failed, retry later", - logutil.Files(files), - logutil.Region(info.Region), - logutil.Key("startKey", startKey), - logutil.Key("endKey", endKey), - logutil.ShortError(errDownload)) - return errors.Trace(errDownload) - } - log.Debug("download file done", - zap.String("file-sample", files[0].Name), zap.Stringer("take", time.Since(start)), - logutil.Key("start", files[0].StartKey), logutil.Key("end", files[0].EndKey)) - start = time.Now() - if errIngest := importer.ingest(ctx, files, info, downloadMetas); errIngest != nil { - log.Warn("ingest file failed, retry later", - logutil.Files(files), - logutil.SSTMetas(downloadMetas), - logutil.Region(info.Region), - zap.Error(errIngest)) - return errors.Trace(errIngest) - } - log.Debug("ingest file done", zap.String("file-sample", files[0].Name), zap.Stringer("take", time.Since(start))) - } - - for _, f := range files { - summary.CollectSuccessUnit(summary.TotalKV, 1, f.TotalKvs) - summary.CollectSuccessUnit(summary.TotalBytes, 1, f.TotalBytes) - } - return nil - }, utils.NewImportSSTBackoffer()) - if err != nil { - log.Error("import sst file failed after retry, stop the whole progress", logutil.Files(files), zap.Error(err)) - return errors.Trace(err) - } - return nil -} - -// getSSTMetaFromFile compares the keys in file, region and rewrite rules, then returns a sst conn. -// The range of the returned sst meta is [regionRule.NewKeyPrefix, append(regionRule.NewKeyPrefix, 0xff)]. -func getSSTMetaFromFile( - id []byte, - file *backuppb.File, - region *metapb.Region, - regionRule *import_sstpb.RewriteRule, - rewriteMode RewriteMode, -) (meta *import_sstpb.SSTMeta, err error) { - r := *region - // If the rewrite mode is for keyspace, then the region bound should be decoded. - if rewriteMode == RewriteModeKeyspace { - if len(region.GetStartKey()) > 0 { - _, r.StartKey, err = codec.DecodeBytes(region.GetStartKey(), nil) - if err != nil { - return - } - } - if len(region.GetEndKey()) > 0 { - _, r.EndKey, err = codec.DecodeBytes(region.GetEndKey(), nil) - if err != nil { - return - } - } - } - - // Get the column family of the file by the file name. - var cfName string - if strings.Contains(file.GetName(), restoreutils.DefaultCFName) { - cfName = restoreutils.DefaultCFName - } else if strings.Contains(file.GetName(), restoreutils.WriteCFName) { - cfName = restoreutils.WriteCFName - } - // Find the overlapped part between the file and the region. - // Here we rewrites the keys to compare with the keys of the region. 
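[Editor's note, simplified model, not part of the patch; names are hypothetical: the lines below clamp the rewritten prefix range to the region, i.e. start = max(new prefix start, region start) and end = min(padded prefix end, region end), where an empty region end key means unbounded. Stated as a stand-alone function over raw byte slices:]

func clampToRegion(prefixStart, prefixEnd, regionStart, regionEnd []byte) (start, end []byte) {
	start = prefixStart
	if bytes.Compare(start, regionStart) < 0 {
		start = regionStart
	}
	end = prefixEnd
	if len(regionEnd) > 0 && bytes.Compare(end, regionEnd) > 0 {
		end = regionEnd
	}
	return start, end
}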
- rangeStart := regionRule.GetNewKeyPrefix() - // rangeStart = max(rangeStart, region.StartKey) - if bytes.Compare(rangeStart, r.GetStartKey()) < 0 { - rangeStart = r.GetStartKey() - } - - // Append 10 * 0xff to make sure rangeEnd cover all file key - // If choose to regionRule.NewKeyPrefix + 1, it may cause WrongPrefix here - // https://github.com/tikv/tikv/blob/970a9bf2a9ea782a455ae579ad237aaf6cb1daec/ - // components/sst_importer/src/sst_importer.rs#L221 - suffix := []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff} - rangeEnd := append(append([]byte{}, regionRule.GetNewKeyPrefix()...), suffix...) - // rangeEnd = min(rangeEnd, region.EndKey) - if len(r.GetEndKey()) > 0 && bytes.Compare(rangeEnd, r.GetEndKey()) > 0 { - rangeEnd = r.GetEndKey() - } - - if bytes.Compare(rangeStart, rangeEnd) > 0 { - log.Panic("range start exceed range end", - logutil.File(file), - logutil.Key("startKey", rangeStart), - logutil.Key("endKey", rangeEnd)) - } - - log.Debug("get sstMeta", - logutil.Region(region), - logutil.File(file), - logutil.Key("startKey", rangeStart), - logutil.Key("endKey", rangeEnd)) - - return &import_sstpb.SSTMeta{ - Uuid: id, - CfName: cfName, - Range: &import_sstpb.Range{ - Start: rangeStart, - End: rangeEnd, - }, - Length: file.GetSize_(), - RegionId: region.GetId(), - RegionEpoch: region.GetRegionEpoch(), - CipherIv: file.GetCipherIv(), - }, nil -} - -// a new way to download ssts files -// 1. download write + default sst files at peer level. -// 2. control the download concurrency per store. -func (importer *SnapFileImporter) download( - ctx context.Context, - regionInfo *split.RegionInfo, - files []*backuppb.File, - rewriteRules *restoreutils.RewriteRules, - cipher *backuppb.CipherInfo, - apiVersion kvrpcpb.APIVersion, -) ([]*import_sstpb.SSTMeta, error) { - var ( - downloadMetas = make([]*import_sstpb.SSTMeta, 0, len(files)) - ) - errDownload := utils.WithRetry(ctx, func() error { - var e error - // we treat Txn kv file as Raw kv file. 
because we don't have table id to decode - if importer.kvMode == Raw || importer.kvMode == Txn { - downloadMetas, e = importer.downloadRawKVSST(ctx, regionInfo, files, cipher, apiVersion) - } else { - downloadMetas, e = importer.downloadSST(ctx, regionInfo, files, rewriteRules, cipher, apiVersion) - } - - failpoint.Inject("restore-storage-error", func(val failpoint.Value) { - msg := val.(string) - log.Debug("failpoint restore-storage-error injected.", zap.String("msg", msg)) - e = errors.Annotate(e, msg) - }) - failpoint.Inject("restore-gRPC-error", func(_ failpoint.Value) { - log.Warn("the connection to TiKV has been cut by a neko, meow :3") - e = status.Error(codes.Unavailable, "the connection to TiKV has been cut by a neko, meow :3") - }) - if isDecryptSstErr(e) { - log.Info("fail to decrypt when download sst, try again with no-crypt", logutil.Files(files)) - if importer.kvMode == Raw || importer.kvMode == Txn { - downloadMetas, e = importer.downloadRawKVSST(ctx, regionInfo, files, nil, apiVersion) - } else { - downloadMetas, e = importer.downloadSST(ctx, regionInfo, files, rewriteRules, nil, apiVersion) - } - } - if e != nil { - return errors.Trace(e) - } - - return nil - }, utils.NewDownloadSSTBackoffer()) - - return downloadMetas, errDownload -} - -func (importer *SnapFileImporter) buildDownloadRequest( - file *backuppb.File, - rewriteRules *restoreutils.RewriteRules, - regionInfo *split.RegionInfo, - cipher *backuppb.CipherInfo, -) (*import_sstpb.DownloadRequest, import_sstpb.SSTMeta, error) { - uid := uuid.New() - id := uid[:] - // Get the rewrite rule for the file. - fileRule := restoreutils.FindMatchedRewriteRule(file, rewriteRules) - if fileRule == nil { - return nil, import_sstpb.SSTMeta{}, errors.Trace(berrors.ErrKVRewriteRuleNotFound) - } - - // For the legacy version of TiKV, we need to encode the key prefix, since in the legacy - // version, the TiKV will rewrite the key with the encoded prefix without decoding the keys in - // the SST file. For the new version of TiKV that support keyspace rewrite, we don't need to - // encode the key prefix. The TiKV will decode the keys in the SST file and rewrite the keys - // with the plain prefix and encode the keys before writing to SST. 
- - // for the keyspace rewrite mode - rule := *fileRule - // for the legacy rewrite mode - if importer.rewriteMode == RewriteModeLegacy { - rule.OldKeyPrefix = restoreutils.EncodeKeyPrefix(fileRule.GetOldKeyPrefix()) - rule.NewKeyPrefix = restoreutils.EncodeKeyPrefix(fileRule.GetNewKeyPrefix()) - } - - sstMeta, err := getSSTMetaFromFile(id, file, regionInfo.Region, &rule, importer.rewriteMode) - if err != nil { - return nil, import_sstpb.SSTMeta{}, err - } - - req := &import_sstpb.DownloadRequest{ - Sst: *sstMeta, - StorageBackend: importer.backend, - Name: file.GetName(), - RewriteRule: rule, - CipherInfo: cipher, - StorageCacheId: importer.cacheKey, - // For the older version of TiDB, the request type will be default to `import_sstpb.RequestType_Legacy` - RequestType: import_sstpb.DownloadRequestType_Keyspace, - Context: &kvrpcpb.Context{ - ResourceControlContext: &kvrpcpb.ResourceControlContext{ - ResourceGroupName: "", // TODO, - }, - RequestSource: kvutil.BuildRequestSource(true, kv.InternalTxnBR, kvutil.ExplicitTypeBR), - }, - } - return req, *sstMeta, nil -} - -func (importer *SnapFileImporter) downloadSST( - ctx context.Context, - regionInfo *split.RegionInfo, - files []*backuppb.File, - rewriteRules *restoreutils.RewriteRules, - cipher *backuppb.CipherInfo, - apiVersion kvrpcpb.APIVersion, -) ([]*import_sstpb.SSTMeta, error) { - var mu sync.Mutex - downloadMetasMap := make(map[string]import_sstpb.SSTMeta) - resultMetasMap := make(map[string]*import_sstpb.SSTMeta) - downloadReqsMap := make(map[string]*import_sstpb.DownloadRequest) - for _, file := range files { - req, sstMeta, err := importer.buildDownloadRequest(file, rewriteRules, regionInfo, cipher) - if err != nil { - return nil, errors.Trace(err) - } - sstMeta.ApiVersion = apiVersion - downloadMetasMap[file.Name] = sstMeta - downloadReqsMap[file.Name] = req - } - - eg, ectx := errgroup.WithContext(ctx) - for _, p := range regionInfo.Region.GetPeers() { - peer := p - eg.Go(func() error { - tokenCh := importer.downloadTokensMap.acquireTokenCh(peer.GetStoreId(), importer.concurrencyPerStore) - select { - case <-ectx.Done(): - return ectx.Err() - case <-tokenCh: - } - defer func() { - importer.releaseToken(tokenCh) - }() - for _, file := range files { - req, ok := downloadReqsMap[file.Name] - if !ok { - return errors.New("not found file key for download request") - } - var err error - var resp *import_sstpb.DownloadResponse - resp, err = utils.WithRetryV2(ectx, utils.NewDownloadSSTBackoffer(), func(ctx context.Context) (*import_sstpb.DownloadResponse, error) { - dctx, cancel := context.WithTimeout(ctx, gRPCTimeOut) - defer cancel() - return importer.importClient.DownloadSST(dctx, peer.GetStoreId(), req) - }) - if err != nil { - return errors.Trace(err) - } - if resp.GetError() != nil { - return errors.Annotate(berrors.ErrKVDownloadFailed, resp.GetError().GetMessage()) - } - if resp.GetIsEmpty() { - return errors.Trace(berrors.ErrKVRangeIsEmpty) - } - - mu.Lock() - sstMeta, ok := downloadMetasMap[file.Name] - if !ok { - mu.Unlock() - return errors.Errorf("not found file %s for download sstMeta", file.Name) - } - sstMeta.Range = &import_sstpb.Range{ - Start: restoreutils.TruncateTS(resp.Range.GetStart()), - End: restoreutils.TruncateTS(resp.Range.GetEnd()), - } - resultMetasMap[file.Name] = &sstMeta - mu.Unlock() - - log.Debug("download from peer", - logutil.Region(regionInfo.Region), - logutil.File(file), - logutil.Peer(peer), - logutil.Key("resp-range-start", resp.Range.Start), - logutil.Key("resp-range-end", resp.Range.End), - 
zap.Bool("resp-isempty", resp.IsEmpty), - zap.Uint32("resp-crc32", resp.Crc32), - zap.Int("len files", len(files)), - ) - } - return nil - }) - } - if err := eg.Wait(); err != nil { - return nil, err - } - return maps.Values(resultMetasMap), nil -} - -func (importer *SnapFileImporter) downloadRawKVSST( - ctx context.Context, - regionInfo *split.RegionInfo, - files []*backuppb.File, - cipher *backuppb.CipherInfo, - apiVersion kvrpcpb.APIVersion, -) ([]*import_sstpb.SSTMeta, error) { - downloadMetas := make([]*import_sstpb.SSTMeta, 0, len(files)) - for _, file := range files { - uid := uuid.New() - id := uid[:] - // Empty rule - var rule import_sstpb.RewriteRule - sstMeta, err := getSSTMetaFromFile(id, file, regionInfo.Region, &rule, RewriteModeLegacy) - if err != nil { - return nil, err - } - - // Cut the SST file's range to fit in the restoring range. - if bytes.Compare(importer.rawStartKey, sstMeta.Range.GetStart()) > 0 { - sstMeta.Range.Start = importer.rawStartKey - } - if len(importer.rawEndKey) > 0 && - (len(sstMeta.Range.GetEnd()) == 0 || bytes.Compare(importer.rawEndKey, sstMeta.Range.GetEnd()) <= 0) { - sstMeta.Range.End = importer.rawEndKey - sstMeta.EndKeyExclusive = true - } - if bytes.Compare(sstMeta.Range.GetStart(), sstMeta.Range.GetEnd()) > 0 { - return nil, errors.Trace(berrors.ErrKVRangeIsEmpty) - } - - req := &import_sstpb.DownloadRequest{ - Sst: *sstMeta, - StorageBackend: importer.backend, - Name: file.GetName(), - RewriteRule: rule, - IsRawKv: true, - CipherInfo: cipher, - StorageCacheId: importer.cacheKey, - } - log.Debug("download SST", logutil.SSTMeta(sstMeta), logutil.Region(regionInfo.Region)) - - var atomicResp atomic.Pointer[import_sstpb.DownloadResponse] - eg, ectx := errgroup.WithContext(ctx) - for _, p := range regionInfo.Region.GetPeers() { - peer := p - eg.Go(func() error { - resp, err := importer.importClient.DownloadSST(ectx, peer.GetStoreId(), req) - if err != nil { - return errors.Trace(err) - } - if resp.GetError() != nil { - return errors.Annotate(berrors.ErrKVDownloadFailed, resp.GetError().GetMessage()) - } - if resp.GetIsEmpty() { - return errors.Trace(berrors.ErrKVRangeIsEmpty) - } - - atomicResp.Store(resp) - return nil - }) - } - - if err := eg.Wait(); err != nil { - return nil, err - } - - downloadResp := atomicResp.Load() - sstMeta.Range.Start = downloadResp.Range.GetStart() - sstMeta.Range.End = downloadResp.Range.GetEnd() - sstMeta.ApiVersion = apiVersion - downloadMetas = append(downloadMetas, sstMeta) - } - return downloadMetas, nil -} - -func (importer *SnapFileImporter) ingest( - ctx context.Context, - files []*backuppb.File, - info *split.RegionInfo, - downloadMetas []*import_sstpb.SSTMeta, -) error { - tokenCh := importer.ingestTokensMap.acquireTokenCh(info.Leader.GetStoreId(), importer.concurrencyPerStore) - select { - case <-ctx.Done(): - return ctx.Err() - case <-tokenCh: - } - defer func() { - importer.releaseToken(tokenCh) - }() - for { - ingestResp, errIngest := importer.ingestSSTs(ctx, downloadMetas, info) - if errIngest != nil { - return errors.Trace(errIngest) - } - - errPb := ingestResp.GetError() - switch { - case errPb == nil: - return nil - case errPb.NotLeader != nil: - // If error is `NotLeader`, update the region info and retry - var newInfo *split.RegionInfo - if newLeader := errPb.GetNotLeader().GetLeader(); newLeader != nil { - newInfo = &split.RegionInfo{ - Leader: newLeader, - Region: info.Region, - } - } else { - for { - // Slow path, get region from PD - newInfo, errIngest = importer.metaClient.GetRegion( - ctx, 
info.Region.GetStartKey()) - if errIngest != nil { - return errors.Trace(errIngest) - } - if newInfo != nil { - break - } - // do not get region info, wait a second and GetRegion() again. - log.Warn("ingest get region by key return nil", logutil.Region(info.Region), - logutil.Files(files), - logutil.SSTMetas(downloadMetas), - ) - time.Sleep(time.Second) - } - } - - if !split.CheckRegionEpoch(newInfo, info) { - return errors.Trace(berrors.ErrKVEpochNotMatch) - } - log.Debug("ingest sst returns not leader error, retry it", - logutil.Files(files), - logutil.SSTMetas(downloadMetas), - logutil.Region(info.Region), - zap.Stringer("newLeader", newInfo.Leader)) - info = newInfo - case errPb.EpochNotMatch != nil: - // TODO handle epoch not match error - // 1. retry download if needed - // 2. retry ingest - return errors.Trace(berrors.ErrKVEpochNotMatch) - case errPb.KeyNotInRegion != nil: - return errors.Trace(berrors.ErrKVKeyNotInRegion) - default: - // Other errors like `ServerIsBusy`, `RegionNotFound`, etc. should be retryable - return errors.Annotatef(berrors.ErrKVIngestFailed, "ingest error %s", errPb) - } - } -} - -func (importer *SnapFileImporter) ingestSSTs( - ctx context.Context, - sstMetas []*import_sstpb.SSTMeta, - regionInfo *split.RegionInfo, -) (*import_sstpb.IngestResponse, error) { - leader := regionInfo.Leader - if leader == nil { - return nil, errors.Annotatef(berrors.ErrPDLeaderNotFound, - "region id %d has no leader", regionInfo.Region.Id) - } - reqCtx := &kvrpcpb.Context{ - RegionId: regionInfo.Region.GetId(), - RegionEpoch: regionInfo.Region.GetRegionEpoch(), - Peer: leader, - ResourceControlContext: &kvrpcpb.ResourceControlContext{ - ResourceGroupName: "", // TODO, - }, - RequestSource: kvutil.BuildRequestSource(true, kv.InternalTxnBR, kvutil.ExplicitTypeBR), - } - - req := &import_sstpb.MultiIngestRequest{ - Context: reqCtx, - Ssts: sstMetas, - } - log.Debug("ingest SSTs", logutil.SSTMetas(sstMetas), logutil.Leader(leader)) - resp, err := importer.importClient.MultiIngest(ctx, leader.GetStoreId(), req) - return resp, errors.Trace(err) -} - -func isDecryptSstErr(err error) bool { - return err != nil && - strings.Contains(err.Error(), "Engine Engine") && - strings.Contains(err.Error(), "Corruption: Bad table magic number") -} diff --git a/br/pkg/restore/split/binding__failpoint_binding__.go b/br/pkg/restore/split/binding__failpoint_binding__.go deleted file mode 100644 index bb941977d20db..0000000000000 --- a/br/pkg/restore/split/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package split - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/br/pkg/restore/split/client.go b/br/pkg/restore/split/client.go index c6c54cf0f46e6..226ecc58360e0 100644 --- a/br/pkg/restore/split/client.go +++ b/br/pkg/restore/split/client.go @@ -277,7 +277,7 @@ func splitRegionWithFailpoint( keys [][]byte, isRawKv bool, ) (*kvrpcpb.SplitRegionResponse, error) { - if injectNewLeader, _err_ := failpoint.Eval(_curpkg_("not-leader-error")); _err_ == nil { + failpoint.Inject("not-leader-error", func(injectNewLeader failpoint.Value) { log.Debug("failpoint not-leader-error injected.") resp := &kvrpcpb.SplitRegionResponse{ RegionError: &errorpb.Error{ @@ -289,16 +289,16 @@ func 
splitRegionWithFailpoint( if injectNewLeader.(bool) { resp.RegionError.NotLeader.Leader = regionInfo.Leader } - return resp, nil - } - if _, _err_ := failpoint.Eval(_curpkg_("somewhat-retryable-error")); _err_ == nil { + failpoint.Return(resp, nil) + }) + failpoint.Inject("somewhat-retryable-error", func() { log.Debug("failpoint somewhat-retryable-error injected.") - return &kvrpcpb.SplitRegionResponse{ + failpoint.Return(&kvrpcpb.SplitRegionResponse{ RegionError: &errorpb.Error{ ServerIsBusy: &errorpb.ServerIsBusy{}, }, - }, nil - } + }, nil) + }) return client.SplitRegion(ctx, &kvrpcpb.SplitRegionRequest{ Context: &kvrpcpb.Context{ RegionId: regionInfo.Region.Id, @@ -646,9 +646,9 @@ func (bo *splitBackoffer) Attempt() int { } func (c *pdClient) SplitWaitAndScatter(ctx context.Context, region *RegionInfo, keys [][]byte) ([]*RegionInfo, error) { - if _, _err_ := failpoint.Eval(_curpkg_("failToSplit")); _err_ == nil { - return nil, errors.New("retryable error") - } + failpoint.Inject("failToSplit", func(_ failpoint.Value) { + failpoint.Return(nil, errors.New("retryable error")) + }) if len(keys) == 0 { return []*RegionInfo{region}, nil } @@ -764,10 +764,10 @@ func (c *pdClient) GetOperator(ctx context.Context, regionID uint64) (*pdpb.GetO } func (c *pdClient) ScanRegions(ctx context.Context, key, endKey []byte, limit int) ([]*RegionInfo, error) { - if _, _err_ := failpoint.Eval(_curpkg_("no-leader-error")); _err_ == nil { + failpoint.Inject("no-leader-error", func(_ failpoint.Value) { logutil.CL(ctx).Debug("failpoint no-leader-error injected.") - return nil, status.Error(codes.Unavailable, "not leader") - } + failpoint.Return(nil, status.Error(codes.Unavailable, "not leader")) + }) //nolint:staticcheck regions, err := c.client.ScanRegions(ctx, key, endKey, limit) diff --git a/br/pkg/restore/split/client.go__failpoint_stash__ b/br/pkg/restore/split/client.go__failpoint_stash__ deleted file mode 100644 index 226ecc58360e0..0000000000000 --- a/br/pkg/restore/split/client.go__failpoint_stash__ +++ /dev/null @@ -1,1067 +0,0 @@ -// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. - -package split - -import ( - "bytes" - "context" - "crypto/tls" - "slices" - "strconv" - "strings" - "sync" - "time" - - "github.com/docker/go-units" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/errorpb" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/kvproto/pkg/pdpb" - "github.com/pingcap/kvproto/pkg/tikvpb" - "github.com/pingcap/log" - "github.com/pingcap/tidb/br/pkg/conn/util" - berrors "github.com/pingcap/tidb/br/pkg/errors" - "github.com/pingcap/tidb/br/pkg/logutil" - "github.com/pingcap/tidb/br/pkg/utils" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/lightning/common" - "github.com/pingcap/tidb/pkg/lightning/config" - brlog "github.com/pingcap/tidb/pkg/lightning/log" - tidbutil "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/intest" - pd "github.com/tikv/pd/client" - pdhttp "github.com/tikv/pd/client/http" - "go.uber.org/multierr" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/status" -) - -const ( - splitRegionMaxRetryTime = 4 -) - -var ( - // the max total key size in a split region batch. 
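Note on these hunks: the changes to br/pkg/restore/split/client.go restore the source form of the failpoints (failpoint.Inject with a closure) from the generated failpoint.Eval(_curpkg_(...)) form, and the deleted binding__failpoint_binding__.go and *__failpoint_stash__ files in this patch appear to be the artifacts that failpoint-ctl leaves behind when failpoints are enabled. As a rough sketch of how one of these failpoints is typically driven from a test under the standard pingcap/failpoint workflow (the test name, package path and term below are illustrative assumptions, not taken from this patch; assumes imports of testing, github.com/pingcap/failpoint and github.com/stretchr/testify/require):

func TestFailToSplitInjected(t *testing.T) {
	// The failpoint path is the package import path plus the failpoint name.
	fp := "github.com/pingcap/tidb/br/pkg/restore/split/failToSplit"
	require.NoError(t, failpoint.Enable(fp, `return(true)`))
	defer func() {
		require.NoError(t, failpoint.Disable(fp))
	}()
	// Exercise pdClient.SplitWaitAndScatter here and assert that it surfaces
	// the injected "retryable error".
}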
- // our threshold should be smaller than TiKV's raft max entry size(default is 8MB). - maxBatchSplitSize = 6 * units.MiB -) - -// SplitClient is an external client used by RegionSplitter. -type SplitClient interface { - // GetStore gets a store by a store id. - GetStore(ctx context.Context, storeID uint64) (*metapb.Store, error) - // GetRegion gets a region which includes a specified key. - GetRegion(ctx context.Context, key []byte) (*RegionInfo, error) - // GetRegionByID gets a region by a region id. - GetRegionByID(ctx context.Context, regionID uint64) (*RegionInfo, error) - // SplitKeysAndScatter splits the related regions of the keys and scatters the - // new regions. It returns the new regions that need to be called with - // WaitRegionsScattered. - SplitKeysAndScatter(ctx context.Context, sortedSplitKeys [][]byte) ([]*RegionInfo, error) - - // SplitWaitAndScatter splits a region from a batch of keys, waits for the split - // is finished, and scatters the new regions. It will return the original region, - // new regions and error. The input keys should not be encoded. - // - // The split step has a few retry times. If it meets error, the error is returned - // directly. - // - // The split waiting step has a backoff retry logic, if split has made progress, - // it will not increase the retry counter. Otherwise, it will retry for about 1h. - // If the retry is timeout, it will log a warning and continue. - // - // The scatter step has a few retry times. If it meets error, it will log a - // warning and continue. - // TODO(lance6716): remove this function in interface after BR uses SplitKeysAndScatter. - SplitWaitAndScatter(ctx context.Context, region *RegionInfo, keys [][]byte) ([]*RegionInfo, error) - // GetOperator gets the status of operator of the specified region. - GetOperator(ctx context.Context, regionID uint64) (*pdpb.GetOperatorResponse, error) - // ScanRegions gets a list of regions, starts from the region that contains key. - // Limit limits the maximum number of regions returned. - ScanRegions(ctx context.Context, key, endKey []byte, limit int) ([]*RegionInfo, error) - // GetPlacementRule loads a placement rule from PD. - GetPlacementRule(ctx context.Context, groupID, ruleID string) (*pdhttp.Rule, error) - // SetPlacementRule insert or update a placement rule to PD. - SetPlacementRule(ctx context.Context, rule *pdhttp.Rule) error - // DeletePlacementRule removes a placement rule from PD. - DeletePlacementRule(ctx context.Context, groupID, ruleID string) error - // SetStoresLabel add or update specified label of stores. If labelValue - // is empty, it clears the label. - SetStoresLabel(ctx context.Context, stores []uint64, labelKey, labelValue string) error - // WaitRegionsScattered waits for an already started scatter region action to - // finish. Internally it will backoff and retry at the maximum internal of 2 - // seconds. If the scatter makes progress during the retry, it will not decrease - // the retry counter. If there's always no progress, it will retry for about 1h. - // Caller can set the context timeout to control the max waiting time. - // - // The first return value is always the number of regions that are not finished - // scattering no matter what the error is. - WaitRegionsScattered(ctx context.Context, regionInfos []*RegionInfo) (notFinished int, err error) -} - -// pdClient is a wrapper of pd client, can be used by RegionSplitter. 
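The SplitClient interface above is the seam between the restore pipeline and PD. A minimal usage sketch based only on the method signatures and doc comments in this listing (error handling abbreviated; the function name and logging call are illustrative):

func splitAndScatterSketch(ctx context.Context, cli SplitClient, sortedKeys [][]byte) error {
	// SplitKeysAndScatter splits the covering regions and starts scattering the
	// new ones; it returns the regions whose scatter still has to be awaited.
	newRegions, err := cli.SplitKeysAndScatter(ctx, sortedKeys)
	if err != nil {
		return err
	}
	// WaitRegionsScattered blocks (with internal backoff) until scattering
	// finishes or the retry budget is exhausted.
	notFinished, err := cli.WaitRegionsScattered(ctx, newRegions)
	if err != nil {
		log.Warn("some regions were not scattered in time",
			zap.Int("notFinished", notFinished), zap.Error(err))
	}
	return nil
}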
-type pdClient struct { - mu sync.Mutex - client pd.Client - httpCli pdhttp.Client - tlsConf *tls.Config - storeCache map[uint64]*metapb.Store - - // FIXME when config changed during the lifetime of pdClient, - // this may mislead the scatter. - needScatterVal bool - needScatterInit sync.Once - - isRawKv bool - onSplit func(key [][]byte) - splitConcurrency int - splitBatchKeyCnt int -} - -type ClientOptionalParameter func(*pdClient) - -// WithRawKV sets the client to use raw kv mode. -func WithRawKV() ClientOptionalParameter { - return func(c *pdClient) { - c.isRawKv = true - } -} - -// WithOnSplit sets a callback function to be called after each split. -func WithOnSplit(onSplit func(key [][]byte)) ClientOptionalParameter { - return func(c *pdClient) { - c.onSplit = onSplit - } -} - -// NewClient creates a SplitClient. -// -// splitBatchKeyCnt controls how many keys are sent to TiKV in a batch in split -// region API. splitConcurrency controls how many regions are split concurrently. -func NewClient( - client pd.Client, - httpCli pdhttp.Client, - tlsConf *tls.Config, - splitBatchKeyCnt int, - splitConcurrency int, - opts ...ClientOptionalParameter, -) SplitClient { - cli := &pdClient{ - client: client, - httpCli: httpCli, - tlsConf: tlsConf, - storeCache: make(map[uint64]*metapb.Store), - splitBatchKeyCnt: splitBatchKeyCnt, - splitConcurrency: splitConcurrency, - } - for _, opt := range opts { - opt(cli) - } - return cli -} - -func (c *pdClient) needScatter(ctx context.Context) bool { - c.needScatterInit.Do(func() { - var err error - c.needScatterVal, err = c.checkNeedScatter(ctx) - if err != nil { - log.Warn( - "failed to check whether need to scatter, use permissive strategy: always scatter", - logutil.ShortError(err)) - c.needScatterVal = true - } - if !c.needScatterVal { - log.Info("skipping scatter because the replica number isn't less than store count.") - } - }) - return c.needScatterVal -} - -func (c *pdClient) scatterRegions(ctx context.Context, newRegions []*RegionInfo) error { - log.Info("scatter regions", zap.Int("regions", len(newRegions))) - // the retry is for the temporary network errors during sending request. - return utils.WithRetry(ctx, func() error { - err := c.tryScatterRegions(ctx, newRegions) - if isUnsupportedError(err) { - log.Warn("batch scatter isn't supported, rollback to old method", logutil.ShortError(err)) - c.scatterRegionsSequentially( - ctx, newRegions, - // backoff about 6s, or we give up scattering this region. 
- &ExponentialBackoffer{ - Attempts: 7, - BaseBackoff: 100 * time.Millisecond, - }) - return nil - } - return err - }, &ExponentialBackoffer{Attempts: 3, BaseBackoff: 500 * time.Millisecond}) -} - -func (c *pdClient) tryScatterRegions(ctx context.Context, regionInfo []*RegionInfo) error { - regionsID := make([]uint64, 0, len(regionInfo)) - for _, v := range regionInfo { - regionsID = append(regionsID, v.Region.Id) - log.Debug("scattering regions", logutil.Key("start", v.Region.StartKey), - logutil.Key("end", v.Region.EndKey), - zap.Uint64("id", v.Region.Id)) - } - resp, err := c.client.ScatterRegions(ctx, regionsID, pd.WithSkipStoreLimit()) - if err != nil { - return err - } - if pbErr := resp.GetHeader().GetError(); pbErr.GetType() != pdpb.ErrorType_OK { - return errors.Annotatef(berrors.ErrPDInvalidResponse, - "pd returns error during batch scattering: %s", pbErr) - } - return nil -} - -func (c *pdClient) GetStore(ctx context.Context, storeID uint64) (*metapb.Store, error) { - c.mu.Lock() - defer c.mu.Unlock() - store, ok := c.storeCache[storeID] - if ok { - return store, nil - } - store, err := c.client.GetStore(ctx, storeID) - if err != nil { - return nil, errors.Trace(err) - } - c.storeCache[storeID] = store - return store, nil -} - -func (c *pdClient) GetRegion(ctx context.Context, key []byte) (*RegionInfo, error) { - region, err := c.client.GetRegion(ctx, key) - if err != nil { - return nil, errors.Trace(err) - } - if region == nil { - return nil, nil - } - return &RegionInfo{ - Region: region.Meta, - Leader: region.Leader, - }, nil -} - -func (c *pdClient) GetRegionByID(ctx context.Context, regionID uint64) (*RegionInfo, error) { - region, err := c.client.GetRegionByID(ctx, regionID) - if err != nil { - return nil, errors.Trace(err) - } - if region == nil { - return nil, nil - } - return &RegionInfo{ - Region: region.Meta, - Leader: region.Leader, - PendingPeers: region.PendingPeers, - DownPeers: region.DownPeers, - }, nil -} - -func splitRegionWithFailpoint( - ctx context.Context, - regionInfo *RegionInfo, - peer *metapb.Peer, - client tikvpb.TikvClient, - keys [][]byte, - isRawKv bool, -) (*kvrpcpb.SplitRegionResponse, error) { - failpoint.Inject("not-leader-error", func(injectNewLeader failpoint.Value) { - log.Debug("failpoint not-leader-error injected.") - resp := &kvrpcpb.SplitRegionResponse{ - RegionError: &errorpb.Error{ - NotLeader: &errorpb.NotLeader{ - RegionId: regionInfo.Region.Id, - }, - }, - } - if injectNewLeader.(bool) { - resp.RegionError.NotLeader.Leader = regionInfo.Leader - } - failpoint.Return(resp, nil) - }) - failpoint.Inject("somewhat-retryable-error", func() { - log.Debug("failpoint somewhat-retryable-error injected.") - failpoint.Return(&kvrpcpb.SplitRegionResponse{ - RegionError: &errorpb.Error{ - ServerIsBusy: &errorpb.ServerIsBusy{}, - }, - }, nil) - }) - return client.SplitRegion(ctx, &kvrpcpb.SplitRegionRequest{ - Context: &kvrpcpb.Context{ - RegionId: regionInfo.Region.Id, - RegionEpoch: regionInfo.Region.RegionEpoch, - Peer: peer, - }, - SplitKeys: keys, - IsRawKv: isRawKv, - }) -} - -func (c *pdClient) sendSplitRegionRequest( - ctx context.Context, regionInfo *RegionInfo, keys [][]byte, -) (*kvrpcpb.SplitRegionResponse, error) { - var splitErrors error - for i := 0; i < splitRegionMaxRetryTime; i++ { - retry, result, err := sendSplitRegionRequest(ctx, c, regionInfo, keys, &splitErrors, i) - if retry { - continue - } - if err != nil { - return nil, multierr.Append(splitErrors, err) - } - if result != nil { - return result, nil - } - return nil, 
errors.Trace(splitErrors) - } - return nil, errors.Trace(splitErrors) -} - -func sendSplitRegionRequest( - ctx context.Context, - c *pdClient, - regionInfo *RegionInfo, - keys [][]byte, - splitErrors *error, - retry int, -) (bool, *kvrpcpb.SplitRegionResponse, error) { - if intest.InTest { - mockCli, ok := c.client.(*MockPDClientForSplit) - if ok { - return mockCli.SplitRegion(regionInfo, keys, c.isRawKv) - } - } - var peer *metapb.Peer - // scanRegions may return empty Leader in https://github.com/tikv/pd/blob/v4.0.8/server/grpc_service.go#L524 - // so wee also need check Leader.Id != 0 - if regionInfo.Leader != nil && regionInfo.Leader.Id != 0 { - peer = regionInfo.Leader - } else { - if len(regionInfo.Region.Peers) == 0 { - return false, nil, - errors.Annotatef(berrors.ErrRestoreNoPeer, "region[%d] doesn't have any peer", - regionInfo.Region.GetId()) - } - peer = regionInfo.Region.Peers[0] - } - storeID := peer.GetStoreId() - store, err := c.GetStore(ctx, storeID) - if err != nil { - return false, nil, err - } - opt := grpc.WithTransportCredentials(insecure.NewCredentials()) - if c.tlsConf != nil { - opt = grpc.WithTransportCredentials(credentials.NewTLS(c.tlsConf)) - } - conn, err := grpc.Dial(store.GetAddress(), opt, - config.DefaultGrpcKeepaliveParams) - if err != nil { - return false, nil, err - } - defer conn.Close() - client := tikvpb.NewTikvClient(conn) - resp, err := splitRegionWithFailpoint(ctx, regionInfo, peer, client, keys, c.isRawKv) - if err != nil { - return false, nil, err - } - if resp.RegionError != nil { - log.Warn("fail to split region", - logutil.Region(regionInfo.Region), - logutil.Keys(keys), - zap.Stringer("regionErr", resp.RegionError)) - *splitErrors = multierr.Append(*splitErrors, - errors.Annotatef(berrors.ErrRestoreSplitFailed, "split region failed: err=%v", resp.RegionError)) - if nl := resp.RegionError.NotLeader; nl != nil { - if leader := nl.GetLeader(); leader != nil { - regionInfo.Leader = leader - } else { - newRegionInfo, findLeaderErr := c.GetRegionByID(ctx, nl.RegionId) - if findLeaderErr != nil { - return false, nil, findLeaderErr - } - if !CheckRegionEpoch(newRegionInfo, regionInfo) { - return false, nil, berrors.ErrKVEpochNotMatch - } - log.Info("find new leader", zap.Uint64("new leader", newRegionInfo.Leader.Id)) - regionInfo = newRegionInfo - } - log.Info("split region meet not leader error, retrying", - zap.Int("retry times", retry), - zap.Uint64("regionID", regionInfo.Region.Id), - zap.Any("new leader", regionInfo.Leader), - ) - return true, nil, nil - } - // TODO: we don't handle RegionNotMatch and RegionNotFound here, - // because I think we don't have enough information to retry. - // But maybe we can handle them here by some information the error itself provides. - if resp.RegionError.ServerIsBusy != nil || - resp.RegionError.StaleCommand != nil { - log.Warn("a error occurs on split region", - zap.Int("retry times", retry), - zap.Uint64("regionID", regionInfo.Region.Id), - zap.String("error", resp.RegionError.Message), - zap.Any("error verbose", resp.RegionError), - ) - return true, nil, nil - } - return false, nil, nil - } - return false, resp, nil -} - -// batchSplitRegionsWithOrigin calls the batch split region API and groups the -// returned regions into two groups: the region with the same ID as the origin, -// and the other regions. The former does not need to be scattered while the -// latter need to be scattered. 
-// -// Depending on the TiKV configuration right-derive-when-split, the origin region -// can be the first return region or the last return region. -func (c *pdClient) batchSplitRegionsWithOrigin( - ctx context.Context, regionInfo *RegionInfo, keys [][]byte, -) (*RegionInfo, []*RegionInfo, error) { - resp, err := c.sendSplitRegionRequest(ctx, regionInfo, keys) - if err != nil { - return nil, nil, errors.Trace(err) - } - - regions := resp.GetRegions() - newRegionInfos := make([]*RegionInfo, 0, len(regions)) - var originRegion *RegionInfo - for _, region := range regions { - var leader *metapb.Peer - - // Assume the leaders will be at the same store. - if regionInfo.Leader != nil { - for _, p := range region.GetPeers() { - if p.GetStoreId() == regionInfo.Leader.GetStoreId() { - leader = p - break - } - } - } - // original region - if region.GetId() == regionInfo.Region.GetId() { - originRegion = &RegionInfo{ - Region: region, - Leader: leader, - } - continue - } - newRegionInfos = append(newRegionInfos, &RegionInfo{ - Region: region, - Leader: leader, - }) - } - return originRegion, newRegionInfos, nil -} - -func (c *pdClient) waitRegionsSplit(ctx context.Context, newRegions []*RegionInfo) error { - backoffer := NewBackoffMayNotCountBackoffer() - needRecheck := make([]*RegionInfo, 0, len(newRegions)) - return utils.WithRetryReturnLastErr(ctx, func() error { - needRecheck = needRecheck[:0] - - for _, r := range newRegions { - regionID := r.Region.GetId() - - ok, err := c.hasHealthyRegion(ctx, regionID) - if !ok || err != nil { - if err != nil { - brlog.FromContext(ctx).Warn( - "wait for split failed", - zap.Uint64("regionID", regionID), - zap.Error(err), - ) - } - needRecheck = append(needRecheck, r) - } - } - - if len(needRecheck) == 0 { - return nil - } - - backoffErr := ErrBackoff - // if made progress in this round, don't increase the retryCnt - if len(needRecheck) < len(newRegions) { - backoffErr = ErrBackoffAndDontCount - } - newRegions = slices.Clone(needRecheck) - - return errors.Annotatef( - backoffErr, - "WaitRegionsSplit not finished, needRecheck: %d, the first unfinished region: %s", - len(needRecheck), needRecheck[0].Region.String(), - ) - }, backoffer) -} - -func (c *pdClient) hasHealthyRegion(ctx context.Context, regionID uint64) (bool, error) { - regionInfo, err := c.GetRegionByID(ctx, regionID) - if err != nil { - return false, errors.Trace(err) - } - // the region hasn't get ready. - if regionInfo == nil { - return false, nil - } - - // check whether the region is healthy and report. - // TODO: the log may be too verbose. we should use Prometheus metrics once it get ready for BR. - for _, peer := range regionInfo.PendingPeers { - log.Debug("unhealthy region detected", logutil.Peer(peer), zap.String("type", "pending")) - } - for _, peer := range regionInfo.DownPeers { - log.Debug("unhealthy region detected", logutil.Peer(peer), zap.String("type", "down")) - } - // we ignore down peers for they are (normally) hard to be fixed in reasonable time. - // (or once there is a peer down, we may get stuck at waiting region get ready.) - return len(regionInfo.PendingPeers) == 0, nil -} - -func (c *pdClient) SplitKeysAndScatter(ctx context.Context, sortedSplitKeys [][]byte) ([]*RegionInfo, error) { - if len(sortedSplitKeys) == 0 { - return nil, nil - } - // we need to find the regions that contain the split keys. However, the scan - // region API accepts a key range [start, end) where end key is exclusive, and if - // sortedSplitKeys length is 1, scan region may return empty result. 
So we - // increase the end key a bit. If the end key is on the region boundaries, it - // will be skipped by getSplitKeysOfRegions. - scanStart := codec.EncodeBytesExt(nil, sortedSplitKeys[0], c.isRawKv) - lastKey := kv.Key(sortedSplitKeys[len(sortedSplitKeys)-1]) - if len(lastKey) > 0 { - lastKey = lastKey.Next() - } - scanEnd := codec.EncodeBytesExt(nil, lastKey, c.isRawKv) - - // mu protects ret, retrySplitKeys, lastSplitErr - mu := sync.Mutex{} - ret := make([]*RegionInfo, 0, len(sortedSplitKeys)+1) - retrySplitKeys := make([][]byte, 0, len(sortedSplitKeys)) - var lastSplitErr error - - err := utils.WithRetryReturnLastErr(ctx, func() error { - ret = ret[:0] - - if len(retrySplitKeys) > 0 { - scanStart = codec.EncodeBytesExt(nil, retrySplitKeys[0], c.isRawKv) - lastKey2 := kv.Key(retrySplitKeys[len(retrySplitKeys)-1]) - scanEnd = codec.EncodeBytesExt(nil, lastKey2.Next(), c.isRawKv) - } - regions, err := PaginateScanRegion(ctx, c, scanStart, scanEnd, ScanRegionPaginationLimit) - if err != nil { - return err - } - log.Info("paginate scan regions", - zap.Int("count", len(regions)), - logutil.Key("start", scanStart), - logutil.Key("end", scanEnd)) - - allSplitKeys := sortedSplitKeys - if len(retrySplitKeys) > 0 { - allSplitKeys = retrySplitKeys - retrySplitKeys = retrySplitKeys[:0] - } - splitKeyMap := getSplitKeysOfRegions(allSplitKeys, regions, c.isRawKv) - workerPool := tidbutil.NewWorkerPool(uint(c.splitConcurrency), "split keys") - eg, eCtx := errgroup.WithContext(ctx) - for region, splitKeys := range splitKeyMap { - region := region - splitKeys := splitKeys - workerPool.ApplyOnErrorGroup(eg, func() error { - // TODO(lance6716): add error handling to retry from scan or retry from split - newRegions, err2 := c.SplitWaitAndScatter(eCtx, region, splitKeys) - if err2 != nil { - if common.IsContextCanceledError(err2) { - return err2 - } - log.Warn("split and scatter region meet error, will retry", - zap.Uint64("region_id", region.Region.Id), - zap.Error(err2)) - mu.Lock() - retrySplitKeys = append(retrySplitKeys, splitKeys...) - lastSplitErr = err2 - mu.Unlock() - return nil - } - - if len(newRegions) != len(splitKeys) { - log.Warn("split key count and new region count mismatch", - zap.Int("new region count", len(newRegions)), - zap.Int("split key count", len(splitKeys))) - } - mu.Lock() - ret = append(ret, newRegions...) 
- mu.Unlock() - return nil - }) - } - if err2 := eg.Wait(); err2 != nil { - return err2 - } - if len(retrySplitKeys) == 0 { - return nil - } - slices.SortFunc(retrySplitKeys, bytes.Compare) - return lastSplitErr - }, newSplitBackoffer()) - return ret, errors.Trace(err) -} - -type splitBackoffer struct { - state utils.RetryState -} - -func newSplitBackoffer() *splitBackoffer { - return &splitBackoffer{ - state: utils.InitialRetryState(SplitRetryTimes, SplitRetryInterval, SplitMaxRetryInterval), - } -} - -func (bo *splitBackoffer) NextBackoff(err error) time.Duration { - if berrors.ErrInvalidRange.Equal(err) { - bo.state.GiveUp() - return 0 - } - return bo.state.ExponentialBackoff() -} - -func (bo *splitBackoffer) Attempt() int { - return bo.state.Attempt() -} - -func (c *pdClient) SplitWaitAndScatter(ctx context.Context, region *RegionInfo, keys [][]byte) ([]*RegionInfo, error) { - failpoint.Inject("failToSplit", func(_ failpoint.Value) { - failpoint.Return(nil, errors.New("retryable error")) - }) - if len(keys) == 0 { - return []*RegionInfo{region}, nil - } - - var ( - start, end = 0, 0 - batchSize = 0 - newRegions = make([]*RegionInfo, 0, len(keys)) - ) - - for end <= len(keys) { - if end == len(keys) || - batchSize+len(keys[end]) > maxBatchSplitSize || - end-start >= c.splitBatchKeyCnt { - // split, wait and scatter for this batch - originRegion, newRegionsOfBatch, err := c.batchSplitRegionsWithOrigin(ctx, region, keys[start:end]) - if err != nil { - return nil, errors.Trace(err) - } - err = c.waitRegionsSplit(ctx, newRegionsOfBatch) - if err != nil { - brlog.FromContext(ctx).Warn( - "wait regions split failed, will continue anyway", - zap.Error(err), - ) - } - if err = ctx.Err(); err != nil { - return nil, errors.Trace(err) - } - err = c.scatterRegions(ctx, newRegionsOfBatch) - if err != nil { - brlog.FromContext(ctx).Warn( - "scatter regions failed, will continue anyway", - zap.Error(err), - ) - } - if c.onSplit != nil { - c.onSplit(keys[start:end]) - } - - // the region with the max start key is the region need to be further split, - // depending on the origin region is the first region or last region, we need to - // compare the origin region and the last one of new regions. - lastNewRegion := newRegionsOfBatch[len(newRegionsOfBatch)-1] - if bytes.Compare(originRegion.Region.StartKey, lastNewRegion.Region.StartKey) < 0 { - region = lastNewRegion - } else { - region = originRegion - } - newRegions = append(newRegions, newRegionsOfBatch...) 
- batchSize = 0 - start = end - } - - if end < len(keys) { - batchSize += len(keys[end]) - } - end++ - } - - return newRegions, errors.Trace(ctx.Err()) -} - -func (c *pdClient) getStoreCount(ctx context.Context) (int, error) { - stores, err := util.GetAllTiKVStores(ctx, c.client, util.SkipTiFlash) - if err != nil { - return 0, err - } - return len(stores), err -} - -func (c *pdClient) getMaxReplica(ctx context.Context) (int, error) { - resp, err := c.httpCli.GetReplicateConfig(ctx) - if err != nil { - return 0, errors.Trace(err) - } - key := "max-replicas" - val, ok := resp[key] - if !ok { - return 0, errors.Errorf("key %s not found in response %v", key, resp) - } - return int(val.(float64)), nil -} - -func (c *pdClient) checkNeedScatter(ctx context.Context) (bool, error) { - storeCount, err := c.getStoreCount(ctx) - if err != nil { - return false, err - } - maxReplica, err := c.getMaxReplica(ctx) - if err != nil { - return false, err - } - log.Info("checking whether need to scatter", zap.Int("store", storeCount), zap.Int("max-replica", maxReplica)) - // Skipping scatter may lead to leader unbalanced, - // currently, we skip scatter only when: - // 1. max-replica > store-count (Probably a misconfigured or playground cluster.) - // 2. store-count == 1 (No meaning for scattering.) - // We can still omit scatter when `max-replica == store-count`, if we create a BalanceLeader operator here, - // however, there isn't evidence for transform leader is much faster than scattering empty regions. - return storeCount >= maxReplica && storeCount > 1, nil -} - -func (c *pdClient) scatterRegion(ctx context.Context, regionInfo *RegionInfo) error { - if !c.needScatter(ctx) { - return nil - } - return c.client.ScatterRegion(ctx, regionInfo.Region.GetId()) -} - -func (c *pdClient) GetOperator(ctx context.Context, regionID uint64) (*pdpb.GetOperatorResponse, error) { - return c.client.GetOperator(ctx, regionID) -} - -func (c *pdClient) ScanRegions(ctx context.Context, key, endKey []byte, limit int) ([]*RegionInfo, error) { - failpoint.Inject("no-leader-error", func(_ failpoint.Value) { - logutil.CL(ctx).Debug("failpoint no-leader-error injected.") - failpoint.Return(nil, status.Error(codes.Unavailable, "not leader")) - }) - - //nolint:staticcheck - regions, err := c.client.ScanRegions(ctx, key, endKey, limit) - if err != nil { - return nil, errors.Trace(err) - } - regionInfos := make([]*RegionInfo, 0, len(regions)) - for _, region := range regions { - regionInfos = append(regionInfos, &RegionInfo{ - Region: region.Meta, - Leader: region.Leader, - }) - } - return regionInfos, nil -} - -func (c *pdClient) GetPlacementRule(ctx context.Context, groupID, ruleID string) (*pdhttp.Rule, error) { - resp, err := c.httpCli.GetPlacementRule(ctx, groupID, ruleID) - return resp, errors.Trace(err) -} - -func (c *pdClient) SetPlacementRule(ctx context.Context, rule *pdhttp.Rule) error { - return c.httpCli.SetPlacementRule(ctx, rule) -} - -func (c *pdClient) DeletePlacementRule(ctx context.Context, groupID, ruleID string) error { - return c.httpCli.DeletePlacementRule(ctx, groupID, ruleID) -} - -func (c *pdClient) SetStoresLabel( - ctx context.Context, stores []uint64, labelKey, labelValue string, -) error { - m := map[string]string{labelKey: labelValue} - for _, id := range stores { - err := c.httpCli.SetStoreLabels(ctx, int64(id), m) - if err != nil { - return errors.Trace(err) - } - } - return nil -} - -func (c *pdClient) scatterRegionsSequentially(ctx context.Context, newRegions []*RegionInfo, backoffer utils.Backoffer) 
{ - newRegionSet := make(map[uint64]*RegionInfo, len(newRegions)) - for _, newRegion := range newRegions { - newRegionSet[newRegion.Region.Id] = newRegion - } - - if err := utils.WithRetry(ctx, func() error { - log.Info("trying to scatter regions...", zap.Int("remain", len(newRegionSet))) - var errs error - for _, region := range newRegionSet { - err := c.scatterRegion(ctx, region) - if err == nil { - // it is safe according to the Go language spec. - delete(newRegionSet, region.Region.Id) - } else if !PdErrorCanRetry(err) { - log.Warn("scatter meet error cannot be retried, skipping", - logutil.ShortError(err), - logutil.Region(region.Region), - ) - delete(newRegionSet, region.Region.Id) - } - errs = multierr.Append(errs, err) - } - return errs - }, backoffer); err != nil { - log.Warn("Some regions haven't been scattered because errors.", - zap.Int("count", len(newRegionSet)), - // if all region are failed to scatter, the short error might also be verbose... - logutil.ShortError(err), - logutil.AbbreviatedArray("failed-regions", newRegionSet, func(i any) []string { - m := i.(map[uint64]*RegionInfo) - result := make([]string, 0, len(m)) - for id := range m { - result = append(result, strconv.Itoa(int(id))) - } - return result - }), - ) - } -} - -func (c *pdClient) isScatterRegionFinished( - ctx context.Context, - regionID uint64, -) (scatterDone bool, needRescatter bool, scatterErr error) { - resp, err := c.GetOperator(ctx, regionID) - if err != nil { - if common.IsRetryableError(err) { - // retry in the next cycle - return false, false, nil - } - return false, false, errors.Trace(err) - } - return isScatterRegionFinished(resp) -} - -func (c *pdClient) WaitRegionsScattered(ctx context.Context, regions []*RegionInfo) (int, error) { - var ( - backoffer = NewBackoffMayNotCountBackoffer() - retryCnt = -1 - needRescatter = make([]*RegionInfo, 0, len(regions)) - needRecheck = make([]*RegionInfo, 0, len(regions)) - ) - - err := utils.WithRetryReturnLastErr(ctx, func() error { - retryCnt++ - loggedInThisRound := false - needRecheck = needRecheck[:0] - needRescatter = needRescatter[:0] - - for i, region := range regions { - regionID := region.Region.GetId() - - if retryCnt > 10 && !loggedInThisRound { - loggedInThisRound = true - resp, err := c.GetOperator(ctx, regionID) - brlog.FromContext(ctx).Info( - "retried many times to wait for scattering regions, checking operator", - zap.Int("retryCnt", retryCnt), - zap.Uint64("firstRegionID", regionID), - zap.Stringer("response", resp), - zap.Error(err), - ) - } - - ok, rescatter, err := c.isScatterRegionFinished(ctx, regionID) - if err != nil { - if !common.IsRetryableError(err) { - brlog.FromContext(ctx).Warn( - "wait for scatter region encountered non-retryable error", - logutil.Region(region.Region), - zap.Error(err), - ) - needRecheck = append(needRecheck, regions[i:]...) 
- return err - } - // if meet retryable error, recheck this region in next round - brlog.FromContext(ctx).Warn( - "wait for scatter region encountered error, will retry again", - logutil.Region(region.Region), - zap.Error(err), - ) - needRecheck = append(needRecheck, region) - continue - } - - if ok { - continue - } - // not finished scattered, check again in next round - needRecheck = append(needRecheck, region) - - if rescatter { - needRescatter = append(needRescatter, region) - } - } - - if len(needRecheck) == 0 { - return nil - } - - backoffErr := ErrBackoff - // if made progress in this round, don't increase the retryCnt - if len(needRecheck) < len(regions) { - backoffErr = ErrBackoffAndDontCount - } - - regions = slices.Clone(needRecheck) - - if len(needRescatter) > 0 { - scatterErr := c.scatterRegions(ctx, needRescatter) - if scatterErr != nil { - if !common.IsRetryableError(scatterErr) { - return scatterErr - } - - return errors.Annotate(backoffErr, scatterErr.Error()) - } - } - return errors.Annotatef( - backoffErr, - "scatter region not finished, retryCnt: %d, needRecheck: %d, needRescatter: %d, the first unfinished region: %s", - retryCnt, len(needRecheck), len(needRescatter), needRecheck[0].Region.String(), - ) - }, backoffer) - - return len(needRecheck), err -} - -// isScatterRegionFinished checks whether the scatter region operator is -// finished. -func isScatterRegionFinished(resp *pdpb.GetOperatorResponse) ( - scatterDone bool, - needRescatter bool, - scatterErr error, -) { - // Heartbeat may not be sent to PD - if respErr := resp.GetHeader().GetError(); respErr != nil { - if respErr.GetType() == pdpb.ErrorType_REGION_NOT_FOUND { - return true, false, nil - } - return false, false, errors.Annotatef( - berrors.ErrPDInvalidResponse, - "get operator error: %s, error message: %s", - respErr.GetType(), - respErr.GetMessage(), - ) - } - // that 'scatter-operator' has finished - if string(resp.GetDesc()) != "scatter-region" { - return true, false, nil - } - switch resp.GetStatus() { - case pdpb.OperatorStatus_SUCCESS: - return true, false, nil - case pdpb.OperatorStatus_RUNNING: - return false, false, nil - default: - return false, true, nil - } -} - -// CheckRegionEpoch check region epoch. -func CheckRegionEpoch(_new, _old *RegionInfo) bool { - return _new.Region.GetId() == _old.Region.GetId() && - _new.Region.GetRegionEpoch().GetVersion() == _old.Region.GetRegionEpoch().GetVersion() && - _new.Region.GetRegionEpoch().GetConfVer() == _old.Region.GetRegionEpoch().GetConfVer() -} - -// ExponentialBackoffer trivially retry any errors it meets. -// It's useful when the caller has handled the errors but -// only want to a more semantic backoff implementation. -type ExponentialBackoffer struct { - Attempts int - BaseBackoff time.Duration -} - -func (b *ExponentialBackoffer) exponentialBackoff() time.Duration { - bo := b.BaseBackoff - b.Attempts-- - if b.Attempts == 0 { - return 0 - } - b.BaseBackoff *= 2 - return bo -} - -// PdErrorCanRetry when pd error retry. -func PdErrorCanRetry(err error) bool { - // There are 3 type of reason that PD would reject a `scatter` request: - // (1) region %d has no leader - // (2) region %d is hot - // (3) region %d is not fully replicated - // - // (2) shouldn't happen in a recently splitted region. - // (1) and (3) might happen, and should be retried. 
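As a quick check of the "backoff about 6s" note earlier in this listing: with Attempts: 7 and BaseBackoff: 100ms, exponentialBackoff above yields sleeps of 100, 200, 400, 800, 1600 and 3200 ms before the counter reaches zero, roughly 6.3s in total. A small helper that reproduces that arithmetic (illustrative only, not part of the patch):

func totalExponentialBackoff(attempts int, base time.Duration) time.Duration {
	var total time.Duration
	// The final attempt returns a zero delay, so only attempts-1 sleeps happen.
	for i := 0; i < attempts-1; i++ {
		total += base
		base *= 2
	}
	return total // totalExponentialBackoff(7, 100*time.Millisecond) == 6300ms
}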
- grpcErr := status.Convert(err) - if grpcErr == nil { - return false - } - return strings.Contains(grpcErr.Message(), "is not fully replicated") || - strings.Contains(grpcErr.Message(), "has no leader") -} - -// NextBackoff returns a duration to wait before retrying again. -func (b *ExponentialBackoffer) NextBackoff(error) time.Duration { - // trivially exponential back off, because we have handled the error at upper level. - return b.exponentialBackoff() -} - -// Attempt returns the remain attempt times -func (b *ExponentialBackoffer) Attempt() int { - return b.Attempts -} - -// isUnsupportedError checks whether we should fallback to ScatterRegion API when meeting the error. -func isUnsupportedError(err error) bool { - s, ok := status.FromError(errors.Cause(err)) - if !ok { - // Not a gRPC error. Something other went wrong. - return false - } - // In two conditions, we fallback to ScatterRegion: - // (1) If the RPC endpoint returns UNIMPLEMENTED. (This is just for making test cases not be so magic.) - // (2) If the Message is "region 0 not found": - // In fact, PD reuses the gRPC endpoint `ScatterRegion` for the batch version of scattering. - // When the request contains the field `regionIDs`, it would use the batch version, - // Otherwise, it uses the old version and scatter the region with `regionID` in the request. - // When facing 4.x, BR(which uses v5.x PD clients and call `ScatterRegions`!) would set `regionIDs` - // which would be ignored by protocol buffers, and leave the `regionID` be zero. - // Then the older version of PD would try to search the region with ID 0. - // (Then it consistently fails, and returns "region 0 not found".) - return s.Code() == codes.Unimplemented || - strings.Contains(s.Message(), "region 0 not found") -} diff --git a/br/pkg/restore/split/split.go b/br/pkg/restore/split/split.go index fef8899ab10a9..ce6faa90b209c 100644 --- a/br/pkg/restore/split/split.go +++ b/br/pkg/restore/split/split.go @@ -233,11 +233,11 @@ func (b *WaitRegionOnlineBackoffer) NextBackoff(err error) time.Duration { // it needs more time to wait splitting the regions that contains data in PITR. // 2s * 150 delayTime := b.Stat.ExponentialBackoff() - if val, _err_ := failpoint.Eval(_curpkg_("hint-scan-region-backoff")); _err_ == nil { + failpoint.Inject("hint-scan-region-backoff", func(val failpoint.Value) { if val.(bool) { delayTime = time.Microsecond } - } + }) return delayTime } b.Stat.GiveUp() diff --git a/br/pkg/restore/split/split.go__failpoint_stash__ b/br/pkg/restore/split/split.go__failpoint_stash__ deleted file mode 100644 index ce6faa90b209c..0000000000000 --- a/br/pkg/restore/split/split.go__failpoint_stash__ +++ /dev/null @@ -1,352 +0,0 @@ -// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. - -package split - -import ( - "bytes" - "context" - "encoding/hex" - goerrors "errors" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/log" - berrors "github.com/pingcap/tidb/br/pkg/errors" - "github.com/pingcap/tidb/br/pkg/logutil" - "github.com/pingcap/tidb/br/pkg/utils" - "github.com/pingcap/tidb/pkg/lightning/config" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/redact" - "go.uber.org/zap" -) - -var ( - WaitRegionOnlineAttemptTimes = config.DefaultRegionCheckBackoffLimit - SplitRetryTimes = 150 -) - -// Constants for split retry machinery. 
-const ( - SplitRetryInterval = 50 * time.Millisecond - SplitMaxRetryInterval = 4 * time.Second - - // it takes 30 minutes to scatter regions when each TiKV has 400k regions - ScatterWaitUpperInterval = 30 * time.Minute - - ScanRegionPaginationLimit = 128 -) - -func checkRegionConsistency(startKey, endKey []byte, regions []*RegionInfo) error { - // current pd can't guarantee the consistency of returned regions - if len(regions) == 0 { - return errors.Annotatef(berrors.ErrPDBatchScanRegion, "scan region return empty result, startKey: %s, endKey: %s", - redact.Key(startKey), redact.Key(endKey)) - } - - if bytes.Compare(regions[0].Region.StartKey, startKey) > 0 { - return errors.Annotatef(berrors.ErrPDBatchScanRegion, - "first region %d's startKey(%s) > startKey(%s), region epoch: %s", - regions[0].Region.Id, - redact.Key(regions[0].Region.StartKey), redact.Key(startKey), - regions[0].Region.RegionEpoch.String()) - } else if len(regions[len(regions)-1].Region.EndKey) != 0 && - bytes.Compare(regions[len(regions)-1].Region.EndKey, endKey) < 0 { - return errors.Annotatef(berrors.ErrPDBatchScanRegion, - "last region %d's endKey(%s) < endKey(%s), region epoch: %s", - regions[len(regions)-1].Region.Id, - redact.Key(regions[len(regions)-1].Region.EndKey), redact.Key(endKey), - regions[len(regions)-1].Region.RegionEpoch.String()) - } - - cur := regions[0] - if cur.Leader == nil { - return errors.Annotatef(berrors.ErrPDBatchScanRegion, - "region %d's leader is nil", cur.Region.Id) - } - if cur.Leader.StoreId == 0 { - return errors.Annotatef(berrors.ErrPDBatchScanRegion, - "region %d's leader's store id is 0", cur.Region.Id) - } - for _, r := range regions[1:] { - if r.Leader == nil { - return errors.Annotatef(berrors.ErrPDBatchScanRegion, - "region %d's leader is nil", r.Region.Id) - } - if r.Leader.StoreId == 0 { - return errors.Annotatef(berrors.ErrPDBatchScanRegion, - "region %d's leader's store id is 0", r.Region.Id) - } - if !bytes.Equal(cur.Region.EndKey, r.Region.StartKey) { - return errors.Annotatef(berrors.ErrPDBatchScanRegion, - "region %d's endKey not equal to next region %d's startKey, endKey: %s, startKey: %s, region epoch: %s %s", - cur.Region.Id, r.Region.Id, - redact.Key(cur.Region.EndKey), redact.Key(r.Region.StartKey), - cur.Region.RegionEpoch.String(), r.Region.RegionEpoch.String()) - } - cur = r - } - - return nil -} - -// PaginateScanRegion scan regions with a limit pagination and return all regions -// at once. The returned regions are continuous and cover the key range. If not, -// or meet errors, it will retry internally. -func PaginateScanRegion( - ctx context.Context, client SplitClient, startKey, endKey []byte, limit int, -) ([]*RegionInfo, error) { - if len(endKey) != 0 && bytes.Compare(startKey, endKey) > 0 { - return nil, errors.Annotatef(berrors.ErrInvalidRange, "startKey > endKey, startKey: %s, endkey: %s", - hex.EncodeToString(startKey), hex.EncodeToString(endKey)) - } - - var ( - lastRegions []*RegionInfo - err error - backoffer = NewWaitRegionOnlineBackoffer() - ) - _ = utils.WithRetry(ctx, func() error { - regions := make([]*RegionInfo, 0, 16) - scanStartKey := startKey - for { - var batch []*RegionInfo - batch, err = client.ScanRegions(ctx, scanStartKey, endKey, limit) - if err != nil { - err = errors.Annotatef(berrors.ErrPDBatchScanRegion.Wrap(err), "scan regions from start-key:%s, err: %s", - redact.Key(scanStartKey), err.Error()) - return err - } - regions = append(regions, batch...) 
- if len(batch) < limit { - // No more region - break - } - scanStartKey = batch[len(batch)-1].Region.GetEndKey() - if len(scanStartKey) == 0 || - (len(endKey) > 0 && bytes.Compare(scanStartKey, endKey) >= 0) { - // All key space have scanned - break - } - } - // if the number of regions changed, we can infer TiKV side really - // made some progress so don't increase the retry times. - if len(regions) != len(lastRegions) { - backoffer.Stat.ReduceRetry() - } - lastRegions = regions - - if err = checkRegionConsistency(startKey, endKey, regions); err != nil { - log.Warn("failed to scan region, retrying", - logutil.ShortError(err), - zap.Int("regionLength", len(regions))) - return err - } - return nil - }, backoffer) - - return lastRegions, err -} - -// checkPartRegionConsistency only checks the continuity of regions and the first region consistency. -func checkPartRegionConsistency(startKey, endKey []byte, regions []*RegionInfo) error { - // current pd can't guarantee the consistency of returned regions - if len(regions) == 0 { - return errors.Annotatef(berrors.ErrPDBatchScanRegion, - "scan region return empty result, startKey: %s, endKey: %s", - redact.Key(startKey), redact.Key(endKey)) - } - - if bytes.Compare(regions[0].Region.StartKey, startKey) > 0 { - return errors.Annotatef(berrors.ErrPDBatchScanRegion, - "first region's startKey > startKey, startKey: %s, regionStartKey: %s", - redact.Key(startKey), redact.Key(regions[0].Region.StartKey)) - } - - cur := regions[0] - for _, r := range regions[1:] { - if !bytes.Equal(cur.Region.EndKey, r.Region.StartKey) { - return errors.Annotatef(berrors.ErrPDBatchScanRegion, - "region endKey not equal to next region startKey, endKey: %s, startKey: %s", - redact.Key(cur.Region.EndKey), redact.Key(r.Region.StartKey)) - } - cur = r - } - - return nil -} - -func ScanRegionsWithRetry( - ctx context.Context, client SplitClient, startKey, endKey []byte, limit int, -) ([]*RegionInfo, error) { - if len(endKey) != 0 && bytes.Compare(startKey, endKey) > 0 { - return nil, errors.Annotatef(berrors.ErrInvalidRange, "startKey > endKey, startKey: %s, endkey: %s", - hex.EncodeToString(startKey), hex.EncodeToString(endKey)) - } - - var regions []*RegionInfo - var err error - // we don't need to return multierr. since there only 3 times retry. - // in most case 3 times retry have the same error. so we just return the last error. - // actually we'd better remove all multierr in br/lightning. - // because it's not easy to check multierr equals normal error. - // see https://github.com/pingcap/tidb/issues/33419. - _ = utils.WithRetry(ctx, func() error { - regions, err = client.ScanRegions(ctx, startKey, endKey, limit) - if err != nil { - err = errors.Annotatef(berrors.ErrPDBatchScanRegion, "scan regions from start-key:%s, err: %s", - redact.Key(startKey), err.Error()) - return err - } - - if err = checkPartRegionConsistency(startKey, endKey, regions); err != nil { - log.Warn("failed to scan region, retrying", logutil.ShortError(err)) - return err - } - - return nil - }, NewWaitRegionOnlineBackoffer()) - - return regions, err -} - -type WaitRegionOnlineBackoffer struct { - Stat utils.RetryState -} - -// NewWaitRegionOnlineBackoffer create a backoff to wait region online. 
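PaginateScanRegion and ScanRegionsWithRetry above both wrap their scan loop in utils.WithRetry with this backoffer. A sketch of that wiring in isolation, where scanOnce is an assumed placeholder for the operation being retried:

func scanWithRegionOnlineBackoff(ctx context.Context) error {
	return utils.WithRetry(ctx, func() error {
		// Errors wrapped as berrors.ErrPDBatchScanRegion take the exponential
		// backoff branch of NextBackoff; any other error makes the backoffer
		// give up and end the retry loop.
		return scanOnce(ctx)
	}, NewWaitRegionOnlineBackoffer())
}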
-func NewWaitRegionOnlineBackoffer() *WaitRegionOnlineBackoffer { - return &WaitRegionOnlineBackoffer{ - Stat: utils.InitialRetryState( - WaitRegionOnlineAttemptTimes, - time.Millisecond*10, - time.Second*2, - ), - } -} - -// NextBackoff returns a duration to wait before retrying again -func (b *WaitRegionOnlineBackoffer) NextBackoff(err error) time.Duration { - // TODO(lance6716): why we only backoff when the error is ErrPDBatchScanRegion? - var perr *errors.Error - if goerrors.As(err, &perr) && berrors.ErrPDBatchScanRegion.ID() == perr.ID() { - // it needs more time to wait splitting the regions that contains data in PITR. - // 2s * 150 - delayTime := b.Stat.ExponentialBackoff() - failpoint.Inject("hint-scan-region-backoff", func(val failpoint.Value) { - if val.(bool) { - delayTime = time.Microsecond - } - }) - return delayTime - } - b.Stat.GiveUp() - return 0 -} - -// Attempt returns the remain attempt times -func (b *WaitRegionOnlineBackoffer) Attempt() int { - return b.Stat.Attempt() -} - -// BackoffMayNotCountBackoffer is a backoffer but it may not increase the retry -// counter. It should be used with ErrBackoff or ErrBackoffAndDontCount. -type BackoffMayNotCountBackoffer struct { - state utils.RetryState -} - -var ( - ErrBackoff = errors.New("found backoff error") - ErrBackoffAndDontCount = errors.New("found backoff error but don't count") -) - -// NewBackoffMayNotCountBackoffer creates a new backoffer that may backoff or retry. -// -// TODO: currently it has the same usage as NewWaitRegionOnlineBackoffer so we -// don't expose its inner settings. -func NewBackoffMayNotCountBackoffer() *BackoffMayNotCountBackoffer { - return &BackoffMayNotCountBackoffer{ - state: utils.InitialRetryState( - WaitRegionOnlineAttemptTimes, - time.Millisecond*10, - time.Second*2, - ), - } -} - -// NextBackoff implements utils.Backoffer. For BackoffMayNotCountBackoffer, only -// ErrBackoff and ErrBackoffAndDontCount is meaningful. -func (b *BackoffMayNotCountBackoffer) NextBackoff(err error) time.Duration { - if errors.ErrorEqual(err, ErrBackoff) { - return b.state.ExponentialBackoff() - } - if errors.ErrorEqual(err, ErrBackoffAndDontCount) { - delay := b.state.ExponentialBackoff() - b.state.ReduceRetry() - return delay - } - b.state.GiveUp() - return 0 -} - -// Attempt implements utils.Backoffer. -func (b *BackoffMayNotCountBackoffer) Attempt() int { - return b.state.Attempt() -} - -// getSplitKeysOfRegions checks every input key is necessary to split region on -// it. Returns a map from region to split keys belongs to it. -// -// The key will be skipped if it's the region boundary. -// -// prerequisite: -// - sortedKeys are sorted in ascending order. -// - sortedRegions are continuous and sorted in ascending order by start key. -// - sortedRegions can cover all keys in sortedKeys. -// PaginateScanRegion should satisfy the above prerequisites. -func getSplitKeysOfRegions( - sortedKeys [][]byte, - sortedRegions []*RegionInfo, - isRawKV bool, -) map[*RegionInfo][][]byte { - splitKeyMap := make(map[*RegionInfo][][]byte, len(sortedRegions)) - curKeyIndex := 0 - splitKey := codec.EncodeBytesExt(nil, sortedKeys[curKeyIndex], isRawKV) - - for _, region := range sortedRegions { - for { - if len(sortedKeys[curKeyIndex]) == 0 { - // should not happen? - goto nextKey - } - // If splitKey is the boundary of the region, don't need to split on it. - if bytes.Equal(splitKey, region.Region.GetStartKey()) { - goto nextKey - } - // If splitKey is not in this region, we should move to the next region. 
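	// Worked example of the boundary check above and the ContainsInterior check
	// just below (keys are illustrative, not from the patch): with sorted regions
	// [a, c) and [c, f) and sorted split keys {b, c, e}, key "b" is attributed to
	// [a, c); key "c" fails ContainsInterior for [a, c) and is then skipped for
	// [c, f) because it equals that region's start key; key "e" is attributed to
	// [c, f). The result is map{[a,c): [b], [c,f): [e]}, so no split is ever
	// issued on an existing region boundary.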
- if !region.ContainsInterior(splitKey) { - break - } - - splitKeyMap[region] = append(splitKeyMap[region], sortedKeys[curKeyIndex]) - - nextKey: - curKeyIndex++ - if curKeyIndex >= len(sortedKeys) { - return splitKeyMap - } - splitKey = codec.EncodeBytesExt(nil, sortedKeys[curKeyIndex], isRawKV) - } - } - lastKey := sortedKeys[len(sortedKeys)-1] - endOfLastRegion := sortedRegions[len(sortedRegions)-1].Region.GetEndKey() - if !bytes.Equal(lastKey, endOfLastRegion) { - log.Error("in getSplitKeysOfRegions, regions don't cover all keys", - zap.String("firstKey", hex.EncodeToString(sortedKeys[0])), - zap.String("lastKey", hex.EncodeToString(lastKey)), - zap.String("firstRegionStartKey", hex.EncodeToString(sortedRegions[0].Region.GetStartKey())), - zap.String("lastRegionEndKey", hex.EncodeToString(endOfLastRegion)), - ) - } - return splitKeyMap -} diff --git a/br/pkg/storage/binding__failpoint_binding__.go b/br/pkg/storage/binding__failpoint_binding__.go deleted file mode 100644 index a1a747a15d57f..0000000000000 --- a/br/pkg/storage/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package storage - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/br/pkg/storage/s3.go b/br/pkg/storage/s3.go index 0755077968f39..3987512b2a0a2 100644 --- a/br/pkg/storage/s3.go +++ b/br/pkg/storage/s3.go @@ -590,10 +590,10 @@ func (rs *S3Storage) ReadFile(ctx context.Context, file string) ([]byte, error) // close the body of response since data has been already read out result.Body.Close() // for unit test - if _, _err_ := failpoint.Eval(_curpkg_("read-s3-body-failed")); _err_ == nil { + failpoint.Inject("read-s3-body-failed", func(_ failpoint.Value) { log.Info("original error", zap.Error(readErr)) readErr = errors.Errorf("read: connection reset by peer") - } + }) if readErr != nil { if isDeadlineExceedError(readErr) || isCancelError(readErr) { return nil, errors.Annotatef(readErr, "failed to read body from get object result, file info: input.bucket='%s', input.key='%s', retryCnt='%d'", @@ -1169,12 +1169,12 @@ func isConnectionRefusedError(err error) bool { func (rl retryerWithLog) ShouldRetry(r *request.Request) bool { // for unit test - if _, _err_ := failpoint.Eval(_curpkg_("replace-error-to-connection-reset-by-peer")); _err_ == nil { + failpoint.Inject("replace-error-to-connection-reset-by-peer", func(_ failpoint.Value) { log.Info("original error", zap.Error(r.Error)) if r.Error != nil { r.Error = errors.New("read tcp *.*.*.*:*->*.*.*.*:*: read: connection reset by peer") } - } + }) if r.HTTPRequest.URL.Host == ec2MetaAddress && (isDeadlineExceedError(r.Error) || isConnectionResetError(r.Error)) { // fast fail for unreachable linklocal address in EC2 containers. log.Warn("failed to get EC2 metadata. skipping.", logutil.ShortError(r.Error)) diff --git a/br/pkg/storage/s3.go__failpoint_stash__ b/br/pkg/storage/s3.go__failpoint_stash__ deleted file mode 100644 index 3987512b2a0a2..0000000000000 --- a/br/pkg/storage/s3.go__failpoint_stash__ +++ /dev/null @@ -1,1208 +0,0 @@ -// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. 
- -package storage - -import ( - "bytes" - "context" - "fmt" - "io" - "net/url" - "path" - "regexp" - "strconv" - "strings" - "sync" - "time" - - alicred "github.com/aliyun/alibaba-cloud-sdk-go/sdk/auth/credentials" - aliproviders "github.com/aliyun/alibaba-cloud-sdk-go/sdk/auth/credentials/providers" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/client" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/credentials/stscreds" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/aws/aws-sdk-go/service/s3/s3iface" - "github.com/aws/aws-sdk-go/service/s3/s3manager" - "github.com/google/uuid" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - backuppb "github.com/pingcap/kvproto/pkg/brpb" - "github.com/pingcap/log" - berrors "github.com/pingcap/tidb/br/pkg/errors" - "github.com/pingcap/tidb/br/pkg/logutil" - "github.com/pingcap/tidb/pkg/util/prefetch" - "github.com/spf13/pflag" - "go.uber.org/zap" -) - -var hardcodedS3ChunkSize = 5 * 1024 * 1024 - -const ( - s3EndpointOption = "s3.endpoint" - s3RegionOption = "s3.region" - s3StorageClassOption = "s3.storage-class" - s3SseOption = "s3.sse" - s3SseKmsKeyIDOption = "s3.sse-kms-key-id" - s3ACLOption = "s3.acl" - s3ProviderOption = "s3.provider" - s3RoleARNOption = "s3.role-arn" - s3ExternalIDOption = "s3.external-id" - notFound = "NotFound" - // number of retries to make of operations. - maxRetries = 7 - // max number of retries when meets error - maxErrorRetries = 3 - ec2MetaAddress = "169.254.169.254" - - // the maximum number of byte to read for seek. - maxSkipOffsetByRead = 1 << 16 // 64KB - - defaultRegion = "us-east-1" - // to check the cloud type by endpoint tag. - domainAliyun = "aliyuncs.com" -) - -var permissionCheckFn = map[Permission]func(context.Context, s3iface.S3API, *backuppb.S3) error{ - AccessBuckets: s3BucketExistenceCheck, - ListObjects: listObjectsCheck, - GetObject: getObjectCheck, - PutAndDeleteObject: PutAndDeleteObjectCheck, -} - -// WriteBufferSize is the size of the buffer used for writing. (64K may be a better choice) -var WriteBufferSize = 5 * 1024 * 1024 - -// S3Storage defines some standard operations for BR/Lightning on the S3 storage. -// It implements the `ExternalStorage` interface. -type S3Storage struct { - svc s3iface.S3API - options *backuppb.S3 -} - -// GetS3APIHandle gets the handle to the S3 API. -func (rs *S3Storage) GetS3APIHandle() s3iface.S3API { - return rs.svc -} - -// GetOptions gets the external storage operations for the S3. -func (rs *S3Storage) GetOptions() *backuppb.S3 { - return rs.options -} - -// S3Uploader does multi-part upload to s3. -type S3Uploader struct { - svc s3iface.S3API - createOutput *s3.CreateMultipartUploadOutput - completeParts []*s3.CompletedPart -} - -// UploadPart update partial data to s3, we should call CreateMultipartUpload to start it, -// and call CompleteMultipartUpload to finish it. 
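For readers unfamiliar with the S3 multipart API that S3Uploader wraps, the flow is a fixed three-call sequence. A condensed sketch against the same aws-sdk-go v1 interfaces imported above; bucket and key are placeholders, and on real S3 every part except the last must be at least 5 MiB, which is why hardcodedS3ChunkSize is 5*1024*1024:

func uploadInParts(ctx context.Context, svc s3iface.S3API, bucket, key string, parts [][]byte) error {
    out, err := svc.CreateMultipartUploadWithContext(ctx, &s3.CreateMultipartUploadInput{
        Bucket: aws.String(bucket),
        Key:    aws.String(key),
    })
    if err != nil {
        return err
    }
    completed := make([]*s3.CompletedPart, 0, len(parts))
    for i, p := range parts {
        num := aws.Int64(int64(i + 1)) // part numbers start at 1, as in Write above
        up, err := svc.UploadPartWithContext(ctx, &s3.UploadPartInput{
            Bucket:     out.Bucket,
            Key:        out.Key,
            UploadId:   out.UploadId,
            PartNumber: num,
            Body:       bytes.NewReader(p),
        })
        if err != nil {
            return err
        }
        completed = append(completed, &s3.CompletedPart{ETag: up.ETag, PartNumber: num})
    }
    _, err = svc.CompleteMultipartUploadWithContext(ctx, &s3.CompleteMultipartUploadInput{
        Bucket:          out.Bucket,
        Key:             out.Key,
        UploadId:        out.UploadId,
        MultipartUpload: &s3.CompletedMultipartUpload{Parts: completed},
    })
    return err
}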
-func (u *S3Uploader) Write(ctx context.Context, data []byte) (int, error) { - partInput := &s3.UploadPartInput{ - Body: bytes.NewReader(data), - Bucket: u.createOutput.Bucket, - Key: u.createOutput.Key, - PartNumber: aws.Int64(int64(len(u.completeParts) + 1)), - UploadId: u.createOutput.UploadId, - ContentLength: aws.Int64(int64(len(data))), - } - - uploadResult, err := u.svc.UploadPartWithContext(ctx, partInput) - if err != nil { - return 0, errors.Trace(err) - } - u.completeParts = append(u.completeParts, &s3.CompletedPart{ - ETag: uploadResult.ETag, - PartNumber: partInput.PartNumber, - }) - return len(data), nil -} - -// Close complete multi upload request. -func (u *S3Uploader) Close(ctx context.Context) error { - completeInput := &s3.CompleteMultipartUploadInput{ - Bucket: u.createOutput.Bucket, - Key: u.createOutput.Key, - UploadId: u.createOutput.UploadId, - MultipartUpload: &s3.CompletedMultipartUpload{ - Parts: u.completeParts, - }, - } - _, err := u.svc.CompleteMultipartUploadWithContext(ctx, completeInput) - return errors.Trace(err) -} - -// S3BackendOptions contains options for s3 storage. -type S3BackendOptions struct { - Endpoint string `json:"endpoint" toml:"endpoint"` - Region string `json:"region" toml:"region"` - StorageClass string `json:"storage-class" toml:"storage-class"` - Sse string `json:"sse" toml:"sse"` - SseKmsKeyID string `json:"sse-kms-key-id" toml:"sse-kms-key-id"` - ACL string `json:"acl" toml:"acl"` - AccessKey string `json:"access-key" toml:"access-key"` - SecretAccessKey string `json:"secret-access-key" toml:"secret-access-key"` - SessionToken string `json:"session-token" toml:"session-token"` - Provider string `json:"provider" toml:"provider"` - ForcePathStyle bool `json:"force-path-style" toml:"force-path-style"` - UseAccelerateEndpoint bool `json:"use-accelerate-endpoint" toml:"use-accelerate-endpoint"` - RoleARN string `json:"role-arn" toml:"role-arn"` - ExternalID string `json:"external-id" toml:"external-id"` - ObjectLockEnabled bool `json:"object-lock-enabled" toml:"object-lock-enabled"` -} - -// Apply apply s3 options on backuppb.S3. -func (options *S3BackendOptions) Apply(s3 *backuppb.S3) error { - if options.Endpoint != "" { - u, err := url.Parse(options.Endpoint) - if err != nil { - return errors.Trace(err) - } - if u.Scheme == "" { - return errors.Errorf("scheme not found in endpoint") - } - if u.Host == "" { - return errors.Errorf("host not found in endpoint") - } - } - // In some cases, we need to set ForcePathStyle to false. 
- // Refer to: https://rclone.org/s3/#s3-force-path-style - if options.Provider == "alibaba" || options.Provider == "netease" || - options.UseAccelerateEndpoint { - options.ForcePathStyle = false - } - if options.AccessKey == "" && options.SecretAccessKey != "" { - return errors.Annotate(berrors.ErrStorageInvalidConfig, "access_key not found") - } - if options.AccessKey != "" && options.SecretAccessKey == "" { - return errors.Annotate(berrors.ErrStorageInvalidConfig, "secret_access_key not found") - } - - s3.Endpoint = strings.TrimSuffix(options.Endpoint, "/") - s3.Region = options.Region - // StorageClass, SSE and ACL are acceptable to be empty - s3.StorageClass = options.StorageClass - s3.Sse = options.Sse - s3.SseKmsKeyId = options.SseKmsKeyID - s3.Acl = options.ACL - s3.AccessKey = options.AccessKey - s3.SecretAccessKey = options.SecretAccessKey - s3.SessionToken = options.SessionToken - s3.ForcePathStyle = options.ForcePathStyle - s3.RoleArn = options.RoleARN - s3.ExternalId = options.ExternalID - s3.Provider = options.Provider - return nil -} - -// defineS3Flags defines the command line flags for S3BackendOptions. -func defineS3Flags(flags *pflag.FlagSet) { - // TODO: remove experimental tag if it's stable - flags.String(s3EndpointOption, "", - "(experimental) Set the S3 endpoint URL, please specify the http or https scheme explicitly") - flags.String(s3RegionOption, "", "(experimental) Set the S3 region, e.g. us-east-1") - flags.String(s3StorageClassOption, "", "(experimental) Set the S3 storage class, e.g. STANDARD") - flags.String(s3SseOption, "", "Set S3 server-side encryption, e.g. aws:kms") - flags.String(s3SseKmsKeyIDOption, "", "KMS CMK key id to use with S3 server-side encryption."+ - "Leave empty to use S3 owned key.") - flags.String(s3ACLOption, "", "(experimental) Set the S3 canned ACLs, e.g. authenticated-read") - flags.String(s3ProviderOption, "", "(experimental) Set the S3 provider, e.g. aws, alibaba, ceph") - flags.String(s3RoleARNOption, "", "(experimental) Set the ARN of the IAM role to assume when accessing AWS S3") - flags.String(s3ExternalIDOption, "", "(experimental) Set the external ID when assuming the role to access AWS S3") -} - -// parseFromFlags parse S3BackendOptions from command line flags. -func (options *S3BackendOptions) parseFromFlags(flags *pflag.FlagSet) error { - var err error - options.Endpoint, err = flags.GetString(s3EndpointOption) - if err != nil { - return errors.Trace(err) - } - options.Endpoint = strings.TrimSuffix(options.Endpoint, "/") - options.Region, err = flags.GetString(s3RegionOption) - if err != nil { - return errors.Trace(err) - } - options.Sse, err = flags.GetString(s3SseOption) - if err != nil { - return errors.Trace(err) - } - options.SseKmsKeyID, err = flags.GetString(s3SseKmsKeyIDOption) - if err != nil { - return errors.Trace(err) - } - options.ACL, err = flags.GetString(s3ACLOption) - if err != nil { - return errors.Trace(err) - } - options.StorageClass, err = flags.GetString(s3StorageClassOption) - if err != nil { - return errors.Trace(err) - } - options.ForcePathStyle = true - options.Provider, err = flags.GetString(s3ProviderOption) - if err != nil { - return errors.Trace(err) - } - options.RoleARN, err = flags.GetString(s3RoleARNOption) - if err != nil { - return errors.Trace(err) - } - options.ExternalID, err = flags.GetString(s3ExternalIDOption) - if err != nil { - return errors.Trace(err) - } - return nil -} - -// NewS3StorageForTest creates a new S3Storage for testing only. 
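To make the validation in Apply above concrete (the endpoint must carry an explicit scheme and host, and access-key/secret-key must be provided together), here is a small usage sketch; the endpoint, credentials and bucket names are placeholders:

    opts := &S3BackendOptions{
        Endpoint:        "http://127.0.0.1:9000", // scheme and host are both required
        AccessKey:       "minioadmin",
        SecretAccessKey: "minioadmin", // must be set together with AccessKey
        ForcePathStyle:  true,
    }
    backend := &backuppb.S3{Bucket: "backup-bucket", Prefix: "nightly"}
    if err := opts.Apply(backend); err != nil {
        log.Fatal("invalid s3 backend options", zap.Error(err))
    }
    // backend now carries the trimmed endpoint, credentials and the optional
    // SSE/ACL/storage-class settings, ready to be passed to NewS3Storage.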
-func NewS3StorageForTest(svc s3iface.S3API, options *backuppb.S3) *S3Storage { - return &S3Storage{ - svc: svc, - options: options, - } -} - -// auto access without ak / sk. -func autoNewCred(qs *backuppb.S3) (cred *credentials.Credentials, err error) { - if qs.AccessKey != "" && qs.SecretAccessKey != "" { - return credentials.NewStaticCredentials(qs.AccessKey, qs.SecretAccessKey, qs.SessionToken), nil - } - endpoint := qs.Endpoint - // if endpoint is empty,return no error and run default(aws) follow. - if endpoint == "" { - return nil, nil - } - // if it Contains 'aliyuncs', fetch the sts token. - if strings.Contains(endpoint, domainAliyun) { - return createOssRAMCred() - } - // other case ,return no error and run default(aws) follow. - return nil, nil -} - -func createOssRAMCred() (*credentials.Credentials, error) { - cred, err := aliproviders.NewInstanceMetadataProvider().Retrieve() - if err != nil { - log.Warn("failed to get aliyun ram credential", zap.Error(err)) - return nil, nil - } - var aliCred, ok = cred.(*alicred.StsTokenCredential) - if !ok { - return nil, errors.Errorf("invalid credential type %T", cred) - } - newCred := credentials.NewChainCredentials([]credentials.Provider{ - &credentials.EnvProvider{}, - &credentials.SharedCredentialsProvider{}, - &credentials.StaticProvider{Value: credentials.Value{AccessKeyID: aliCred.AccessKeyId, SecretAccessKey: aliCred.AccessKeySecret, SessionToken: aliCred.AccessKeyStsToken, ProviderName: ""}}, - }) - if _, err := newCred.Get(); err != nil { - return nil, errors.Trace(err) - } - return newCred, nil -} - -// NewS3Storage initialize a new s3 storage for metadata. -func NewS3Storage(ctx context.Context, backend *backuppb.S3, opts *ExternalStorageOptions) (obj *S3Storage, errRet error) { - qs := *backend - awsConfig := aws.NewConfig(). - WithS3ForcePathStyle(qs.ForcePathStyle). 
- WithCredentialsChainVerboseErrors(true) - if qs.Region == "" { - awsConfig.WithRegion(defaultRegion) - } else { - awsConfig.WithRegion(qs.Region) - } - - if opts.S3Retryer != nil { - request.WithRetryer(awsConfig, opts.S3Retryer) - } else { - request.WithRetryer(awsConfig, defaultS3Retryer()) - } - - if qs.Endpoint != "" { - awsConfig.WithEndpoint(qs.Endpoint) - } - if opts.HTTPClient != nil { - awsConfig.WithHTTPClient(opts.HTTPClient) - } - cred, err := autoNewCred(&qs) - if err != nil { - return nil, errors.Trace(err) - } - if cred != nil { - awsConfig.WithCredentials(cred) - } - // awsConfig.WithLogLevel(aws.LogDebugWithSigning) - awsSessionOpts := session.Options{ - Config: *awsConfig, - } - ses, err := session.NewSessionWithOptions(awsSessionOpts) - if err != nil { - return nil, errors.Trace(err) - } - - if !opts.SendCredentials { - // Clear the credentials if exists so that they will not be sent to TiKV - backend.AccessKey = "" - backend.SecretAccessKey = "" - backend.SessionToken = "" - } else if ses.Config.Credentials != nil { - if qs.AccessKey == "" || qs.SecretAccessKey == "" { - v, cerr := ses.Config.Credentials.Get() - if cerr != nil { - return nil, errors.Trace(cerr) - } - backend.AccessKey = v.AccessKeyID - backend.SecretAccessKey = v.SecretAccessKey - backend.SessionToken = v.SessionToken - } - } - - s3CliConfigs := []*aws.Config{} - // if role ARN and external ID are provided, try to get the credential using this way - if len(qs.RoleArn) > 0 { - creds := stscreds.NewCredentials(ses, qs.RoleArn, func(p *stscreds.AssumeRoleProvider) { - if len(qs.ExternalId) > 0 { - p.ExternalID = &qs.ExternalId - } - }) - s3CliConfigs = append(s3CliConfigs, - aws.NewConfig().WithCredentials(creds), - ) - } - c := s3.New(ses, s3CliConfigs...) - - var region string - if len(qs.Provider) == 0 || qs.Provider == "aws" { - confCred := ses.Config.Credentials - setCredOpt := func(req *request.Request) { - // s3manager.GetBucketRegionWithClient will set credential anonymous, which works with s3. - // we need reassign credential to be compatible with minio authentication. - if confCred != nil { - req.Config.Credentials = confCred - } - // s3manager.GetBucketRegionWithClient use path style addressing default. - // we need set S3ForcePathStyle by our config if we set endpoint. - if qs.Endpoint != "" { - req.Config.S3ForcePathStyle = ses.Config.S3ForcePathStyle - } - } - region, err = s3manager.GetBucketRegionWithClient(ctx, c, qs.Bucket, setCredOpt) - if err != nil { - return nil, errors.Annotatef(err, "failed to get region of bucket %s", qs.Bucket) - } - } else { - // for other s3 compatible provider like ovh storage didn't return the region correctlly - // so we cannot automatically get the bucket region. just fallback to manually region setting. - region = qs.Region - } - - if qs.Region != region { - if qs.Region != "" { - return nil, errors.Trace(fmt.Errorf("s3 bucket and region are not matched, bucket=%s, input region=%s, real region=%s", - qs.Bucket, qs.Region, region)) - } - - qs.Region = region - backend.Region = region - if region != defaultRegion { - s3CliConfigs = append(s3CliConfigs, aws.NewConfig().WithRegion(region)) - c = s3.New(ses, s3CliConfigs...) 
- } - } - log.Info("succeed to get bucket region from s3", zap.String("bucket region", region)) - - if len(qs.Prefix) > 0 && !strings.HasSuffix(qs.Prefix, "/") { - qs.Prefix += "/" - } - - for _, p := range opts.CheckPermissions { - err := permissionCheckFn[p](ctx, c, &qs) - if err != nil { - return nil, errors.Annotatef(berrors.ErrStorageInvalidPermission, "check permission %s failed due to %v", p, err) - } - } - - s3Storage := &S3Storage{ - svc: c, - options: &qs, - } - if opts.CheckS3ObjectLockOptions { - backend.ObjectLockEnabled = s3Storage.IsObjectLockEnabled() - } - return s3Storage, nil -} - -// s3BucketExistenceCheck checks if a bucket exists. -func s3BucketExistenceCheck(_ context.Context, svc s3iface.S3API, qs *backuppb.S3) error { - input := &s3.HeadBucketInput{ - Bucket: aws.String(qs.Bucket), - } - _, err := svc.HeadBucket(input) - return errors.Trace(err) -} - -// listObjectsCheck checks the permission of listObjects -func listObjectsCheck(_ context.Context, svc s3iface.S3API, qs *backuppb.S3) error { - input := &s3.ListObjectsInput{ - Bucket: aws.String(qs.Bucket), - Prefix: aws.String(qs.Prefix), - MaxKeys: aws.Int64(1), - } - _, err := svc.ListObjects(input) - if err != nil { - return errors.Trace(err) - } - return nil -} - -// getObjectCheck checks the permission of getObject -func getObjectCheck(_ context.Context, svc s3iface.S3API, qs *backuppb.S3) error { - input := &s3.GetObjectInput{ - Bucket: aws.String(qs.Bucket), - Key: aws.String("not-exists"), - } - _, err := svc.GetObject(input) - if aerr, ok := err.(awserr.Error); ok { - if aerr.Code() == "NoSuchKey" { - // if key not exists and we reach this error, that - // means we have the correct permission to GetObject - // other we will get another error - return nil - } - return errors.Trace(err) - } - return nil -} - -// PutAndDeleteObjectCheck checks the permission of putObject -// S3 API doesn't provide a way to check the permission, we have to put an -// object to check the permission. -// exported for testing. -func PutAndDeleteObjectCheck(ctx context.Context, svc s3iface.S3API, options *backuppb.S3) (err error) { - file := fmt.Sprintf("access-check/%s", uuid.New().String()) - defer func() { - // we always delete the object used for permission check, - // even on error, since the object might be created successfully even - // when it returns an error. 
- input := &s3.DeleteObjectInput{ - Bucket: aws.String(options.Bucket), - Key: aws.String(options.Prefix + file), - } - _, err2 := svc.DeleteObjectWithContext(ctx, input) - if aerr, ok := err2.(awserr.Error); ok { - if aerr.Code() != "NoSuchKey" { - log.Warn("failed to delete object used for permission check", - zap.String("bucket", options.Bucket), - zap.String("key", *input.Key), zap.Error(err2)) - } - } - if err == nil { - err = errors.Trace(err2) - } - }() - // when no permission, aws returns err with code "AccessDenied" - input := buildPutObjectInput(options, file, []byte("check")) - _, err = svc.PutObjectWithContext(ctx, input) - return errors.Trace(err) -} - -func (rs *S3Storage) IsObjectLockEnabled() bool { - input := &s3.GetObjectLockConfigurationInput{ - Bucket: aws.String(rs.options.Bucket), - } - resp, err := rs.svc.GetObjectLockConfiguration(input) - if err != nil { - log.Warn("failed to check object lock for bucket", zap.String("bucket", rs.options.Bucket), zap.Error(err)) - return false - } - if resp != nil && resp.ObjectLockConfiguration != nil { - if s3.ObjectLockEnabledEnabled == aws.StringValue(resp.ObjectLockConfiguration.ObjectLockEnabled) { - return true - } - } - return false -} - -func buildPutObjectInput(options *backuppb.S3, file string, data []byte) *s3.PutObjectInput { - input := &s3.PutObjectInput{ - Body: aws.ReadSeekCloser(bytes.NewReader(data)), - Bucket: aws.String(options.Bucket), - Key: aws.String(options.Prefix + file), - } - if options.Acl != "" { - input = input.SetACL(options.Acl) - } - if options.Sse != "" { - input = input.SetServerSideEncryption(options.Sse) - } - if options.SseKmsKeyId != "" { - input = input.SetSSEKMSKeyId(options.SseKmsKeyId) - } - if options.StorageClass != "" { - input = input.SetStorageClass(options.StorageClass) - } - return input -} - -// WriteFile writes data to a file to storage. -func (rs *S3Storage) WriteFile(ctx context.Context, file string, data []byte) error { - input := buildPutObjectInput(rs.options, file, data) - // we don't need to calculate contentMD5 if s3 object lock enabled. - // since aws-go-sdk already did it in #computeBodyHashes - // https://github.com/aws/aws-sdk-go/blob/bcb2cf3fc2263c8c28b3119b07d2dbb44d7c93a0/service/s3/body_hash.go#L30 - _, err := rs.svc.PutObjectWithContext(ctx, input) - if err != nil { - return errors.Trace(err) - } - hinput := &s3.HeadObjectInput{ - Bucket: aws.String(rs.options.Bucket), - Key: aws.String(rs.options.Prefix + file), - } - err = rs.svc.WaitUntilObjectExistsWithContext(ctx, hinput) - return errors.Trace(err) -} - -// ReadFile reads the file from the storage and returns the contents. 
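The permission check functions above are not called directly by users; they are selected through ExternalStorageOptions.CheckPermissions when the storage is constructed (see NewS3Storage earlier in this file). A hedged sketch of a caller asking for an access and write probe up front:

    store, err := storage.NewS3Storage(ctx, backend, &storage.ExternalStorageOptions{
        SendCredentials: false,
        CheckPermissions: []storage.Permission{
            storage.AccessBuckets,      // HeadBucket
            storage.PutAndDeleteObject, // put, then delete, a temporary access-check object
        },
    })
    if err != nil {
        return errors.Annotate(err, "s3 permission check failed")
    }
    _ = store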
-func (rs *S3Storage) ReadFile(ctx context.Context, file string) ([]byte, error) { - var ( - data []byte - readErr error - ) - for retryCnt := 0; retryCnt < maxErrorRetries; retryCnt += 1 { - input := &s3.GetObjectInput{ - Bucket: aws.String(rs.options.Bucket), - Key: aws.String(rs.options.Prefix + file), - } - result, err := rs.svc.GetObjectWithContext(ctx, input) - if err != nil { - return nil, errors.Annotatef(err, - "failed to read s3 file, file info: input.bucket='%s', input.key='%s'", - *input.Bucket, *input.Key) - } - data, readErr = io.ReadAll(result.Body) - // close the body of response since data has been already read out - result.Body.Close() - // for unit test - failpoint.Inject("read-s3-body-failed", func(_ failpoint.Value) { - log.Info("original error", zap.Error(readErr)) - readErr = errors.Errorf("read: connection reset by peer") - }) - if readErr != nil { - if isDeadlineExceedError(readErr) || isCancelError(readErr) { - return nil, errors.Annotatef(readErr, "failed to read body from get object result, file info: input.bucket='%s', input.key='%s', retryCnt='%d'", - *input.Bucket, *input.Key, retryCnt) - } - continue - } - return data, nil - } - // retry too much, should be failed - return nil, errors.Annotatef(readErr, "failed to read body from get object result (retry too much), file info: input.bucket='%s', input.key='%s'", - rs.options.Bucket, rs.options.Prefix+file) -} - -// DeleteFile delete the file in s3 storage -func (rs *S3Storage) DeleteFile(ctx context.Context, file string) error { - input := &s3.DeleteObjectInput{ - Bucket: aws.String(rs.options.Bucket), - Key: aws.String(rs.options.Prefix + file), - } - - _, err := rs.svc.DeleteObjectWithContext(ctx, input) - return errors.Trace(err) -} - -// s3DeleteObjectsLimit is the upper limit of objects in a delete request. -// See https://docs.aws.amazon.com/sdk-for-go/api/service/s3/#S3.DeleteObjects. -const s3DeleteObjectsLimit = 1000 - -// DeleteFiles delete the files in batch in s3 storage. -func (rs *S3Storage) DeleteFiles(ctx context.Context, files []string) error { - for len(files) > 0 { - batch := files - if len(batch) > s3DeleteObjectsLimit { - batch = batch[:s3DeleteObjectsLimit] - } - objects := make([]*s3.ObjectIdentifier, 0, len(batch)) - for _, file := range batch { - objects = append(objects, &s3.ObjectIdentifier{ - Key: aws.String(rs.options.Prefix + file), - }) - } - input := &s3.DeleteObjectsInput{ - Bucket: aws.String(rs.options.Bucket), - Delete: &s3.Delete{ - Objects: objects, - Quiet: aws.Bool(false), - }, - } - _, err := rs.svc.DeleteObjectsWithContext(ctx, input) - if err != nil { - return errors.Trace(err) - } - files = files[len(batch):] - } - return nil -} - -// FileExists check if file exists on s3 storage. -func (rs *S3Storage) FileExists(ctx context.Context, file string) (bool, error) { - input := &s3.HeadObjectInput{ - Bucket: aws.String(rs.options.Bucket), - Key: aws.String(rs.options.Prefix + file), - } - - _, err := rs.svc.HeadObjectWithContext(ctx, input) - if err != nil { - if aerr, ok := errors.Cause(err).(awserr.Error); ok { // nolint:errorlint - switch aerr.Code() { - case s3.ErrCodeNoSuchBucket, s3.ErrCodeNoSuchKey, notFound: - return false, nil - } - } - return false, errors.Trace(err) - } - return true, nil -} - -// WalkDir traverse all the files in a dir. -// -// fn is the function called for each regular file visited by WalkDir. 
-// The first argument is the file path that can be used in `Open` -// function; the second argument is the size in byte of the file determined -// by path. -func (rs *S3Storage) WalkDir(ctx context.Context, opt *WalkOption, fn func(string, int64) error) error { - if opt == nil { - opt = &WalkOption{} - } - prefix := path.Join(rs.options.Prefix, opt.SubDir) - if len(prefix) > 0 && !strings.HasSuffix(prefix, "/") { - prefix += "/" - } - - if len(opt.ObjPrefix) != 0 { - prefix += opt.ObjPrefix - } - - maxKeys := int64(1000) - if opt.ListCount > 0 { - maxKeys = opt.ListCount - } - req := &s3.ListObjectsInput{ - Bucket: aws.String(rs.options.Bucket), - Prefix: aws.String(prefix), - MaxKeys: aws.Int64(maxKeys), - } - - for { - // FIXME: We can't use ListObjectsV2, it is not universally supported. - // (Ceph RGW supported ListObjectsV2 since v15.1.0, released 2020 Jan 30th) - // (as of 2020, DigitalOcean Spaces still does not support V2 - https://developers.digitalocean.com/documentation/spaces/#list-bucket-contents) - res, err := rs.svc.ListObjectsWithContext(ctx, req) - if err != nil { - return errors.Trace(err) - } - for _, r := range res.Contents { - // https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListObjects.html#AmazonS3-ListObjects-response-NextMarker - - // - // `res.NextMarker` is populated only if we specify req.Delimiter. - // Aliyun OSS and minio will populate NextMarker no matter what, - // but this documented behavior does apply to AWS S3: - // - // "If response does not include the NextMarker and it is truncated, - // you can use the value of the last Key in the response as the marker - // in the subsequent request to get the next set of object keys." - req.Marker = r.Key - - // when walk on specify directory, the result include storage.Prefix, - // which can not be reuse in other API(Open/Read) directly. - // so we use TrimPrefix to filter Prefix for next Open/Read. - path := strings.TrimPrefix(*r.Key, rs.options.Prefix) - // trim the prefix '/' to ensure that the path returned is consistent with the local storage - path = strings.TrimPrefix(path, "/") - itemSize := *r.Size - - // filter out s3's empty directory items - if itemSize <= 0 && strings.HasSuffix(path, "/") { - log.Info("this path is an empty directory and cannot be opened in S3. Skip it", zap.String("path", path)) - continue - } - if err = fn(path, itemSize); err != nil { - return errors.Trace(err) - } - } - if !aws.BoolValue(res.IsTruncated) { - break - } - } - - return nil -} - -// URI returns s3:///. -func (rs *S3Storage) URI() string { - return "s3://" + rs.options.Bucket + "/" + rs.options.Prefix -} - -// Open a Reader by file path. -func (rs *S3Storage) Open(ctx context.Context, path string, o *ReaderOption) (ExternalFileReader, error) { - start := int64(0) - end := int64(0) - prefetchSize := 0 - if o != nil { - if o.StartOffset != nil { - start = *o.StartOffset - } - if o.EndOffset != nil { - end = *o.EndOffset - } - prefetchSize = o.PrefetchSize - } - reader, r, err := rs.open(ctx, path, start, end) - if err != nil { - return nil, errors.Trace(err) - } - if prefetchSize > 0 { - reader = prefetch.NewReader(reader, o.PrefetchSize) - } - return &s3ObjectReader{ - storage: rs, - name: path, - reader: reader, - ctx: ctx, - rangeInfo: r, - prefetchSize: prefetchSize, - }, nil -} - -// RangeInfo represents the an HTTP Content-Range header value -// of the form `bytes [Start]-[End]/[Size]`. -type RangeInfo struct { - // Start is the absolute position of the first byte of the byte range, - // starting from 0. 
- Start int64 - // End is the absolute position of the last byte of the byte range. This end - // offset is inclusive, e.g. if the Size is 1000, the maximum value of End - // would be 999. - End int64 - // Size is the total size of the original file. - Size int64 -} - -// if endOffset > startOffset, should return reader for bytes in [startOffset, endOffset). -func (rs *S3Storage) open( - ctx context.Context, - path string, - startOffset, endOffset int64, -) (io.ReadCloser, RangeInfo, error) { - input := &s3.GetObjectInput{ - Bucket: aws.String(rs.options.Bucket), - Key: aws.String(rs.options.Prefix + path), - } - - // If we just open part of the object, we set `Range` in the request. - // If we meant to open the whole object, not just a part of it, - // we do not pass the range in the request, - // so that even if the object is empty, we can still get the response without errors. - // Then this behavior is similar to openning an empty file in local file system. - isFullRangeRequest := false - var rangeOffset *string - switch { - case endOffset > startOffset: - // s3 endOffset is inclusive - rangeOffset = aws.String(fmt.Sprintf("bytes=%d-%d", startOffset, endOffset-1)) - case startOffset == 0: - // openning the whole object, no need to fill the `Range` field in the request - isFullRangeRequest = true - default: - rangeOffset = aws.String(fmt.Sprintf("bytes=%d-", startOffset)) - } - input.Range = rangeOffset - result, err := rs.svc.GetObjectWithContext(ctx, input) - if err != nil { - return nil, RangeInfo{}, errors.Trace(err) - } - - var r RangeInfo - // Those requests without a `Range` will have no `ContentRange` in the response, - // In this case, we'll parse the `ContentLength` field instead. - if isFullRangeRequest { - // We must ensure the `ContentLengh` has data even if for empty objects, - // otherwise we have no places to get the object size - if result.ContentLength == nil { - return nil, RangeInfo{}, errors.Annotatef(berrors.ErrStorageUnknown, "open file '%s' failed. The S3 object has no content length", path) - } - objectSize := *(result.ContentLength) - r = RangeInfo{ - Start: 0, - End: objectSize - 1, - Size: objectSize, - } - } else { - r, err = ParseRangeInfo(result.ContentRange) - if err != nil { - return nil, RangeInfo{}, errors.Trace(err) - } - } - - if startOffset != r.Start || (endOffset != 0 && endOffset != r.End+1) { - return nil, r, errors.Annotatef(berrors.ErrStorageUnknown, "open file '%s' failed, expected range: %s, got: %v", - path, *rangeOffset, result.ContentRange) - } - - return result.Body, r, nil -} - -var contentRangeRegex = regexp.MustCompile(`bytes (\d+)-(\d+)/(\d+)$`) - -// ParseRangeInfo parses the Content-Range header and returns the offsets. 
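A worked example of the Content-Range form this parser expects, matching the regexp just above:

    ri, err := ParseRangeInfo(aws.String("bytes 100-199/1000"))
    // err == nil
    // ri.Start == 100, ri.End == 199 (inclusive), ri.Size == 1000
    // i.e. a ranged GET for [100, 200) returned End-Start+1 == 100 bytes of a 1000-byte object.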
-func ParseRangeInfo(info *string) (ri RangeInfo, err error) { - if info == nil || len(*info) == 0 { - err = errors.Annotate(berrors.ErrStorageUnknown, "ContentRange is empty") - return - } - subMatches := contentRangeRegex.FindStringSubmatch(*info) - if len(subMatches) != 4 { - err = errors.Annotatef(berrors.ErrStorageUnknown, "invalid content range: '%s'", *info) - return - } - - ri.Start, err = strconv.ParseInt(subMatches[1], 10, 64) - if err != nil { - err = errors.Annotatef(err, "invalid start offset value '%s' in ContentRange '%s'", subMatches[1], *info) - return - } - ri.End, err = strconv.ParseInt(subMatches[2], 10, 64) - if err != nil { - err = errors.Annotatef(err, "invalid end offset value '%s' in ContentRange '%s'", subMatches[2], *info) - return - } - ri.Size, err = strconv.ParseInt(subMatches[3], 10, 64) - if err != nil { - err = errors.Annotatef(err, "invalid size size value '%s' in ContentRange '%s'", subMatches[3], *info) - return - } - return -} - -// s3ObjectReader wrap GetObjectOutput.Body and add the `Seek` method. -type s3ObjectReader struct { - storage *S3Storage - name string - reader io.ReadCloser - pos int64 - rangeInfo RangeInfo - // reader context used for implement `io.Seek` - // currently, lightning depends on package `xitongsys/parquet-go` to read parquet file and it needs `io.Seeker` - // See: https://github.com/xitongsys/parquet-go/blob/207a3cee75900b2b95213627409b7bac0f190bb3/source/source.go#L9-L10 - ctx context.Context - prefetchSize int -} - -// Read implement the io.Reader interface. -func (r *s3ObjectReader) Read(p []byte) (n int, err error) { - retryCnt := 0 - maxCnt := r.rangeInfo.End + 1 - r.pos - if maxCnt == 0 { - return 0, io.EOF - } - if maxCnt > int64(len(p)) { - maxCnt = int64(len(p)) - } - n, err = r.reader.Read(p[:maxCnt]) - // TODO: maybe we should use !errors.Is(err, io.EOF) here to avoid error lint, but currently, pingcap/errors - // doesn't implement this method yet. - for err != nil && errors.Cause(err) != io.EOF && retryCnt < maxErrorRetries { //nolint:errorlint - log.L().Warn( - "read s3 object failed, will retry", - zap.String("file", r.name), - zap.Int("retryCnt", retryCnt), - zap.Error(err), - ) - // if can retry, reopen a new reader and try read again - end := r.rangeInfo.End + 1 - if end == r.rangeInfo.Size { - end = 0 - } - _ = r.reader.Close() - - newReader, _, err1 := r.storage.open(r.ctx, r.name, r.pos, end) - if err1 != nil { - log.Warn("open new s3 reader failed", zap.String("file", r.name), zap.Error(err1)) - return - } - r.reader = newReader - if r.prefetchSize > 0 { - r.reader = prefetch.NewReader(r.reader, r.prefetchSize) - } - retryCnt++ - n, err = r.reader.Read(p[:maxCnt]) - } - - r.pos += int64(n) - return -} - -// Close implement the io.Closer interface. -func (r *s3ObjectReader) Close() error { - return r.reader.Close() -} - -// Seek implement the io.Seeker interface. -// -// Currently, tidb-lightning depends on this method to read parquet file for s3 storage. 
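The Seek implementation that follows has three regimes worth keeping in mind while reading it: a short forward seek (at most 64 KiB, maxSkipOffsetByRead) is served by reading and discarding from the current body; any other offset closes the body and reopens a ranged GET; and seeking at or past the object size pins the reader at EOF without issuing a request. A small usage sketch, with the object name as a placeholder:

func skim(ctx context.Context, store storage.ExternalStorage) error {
    r, err := store.Open(ctx, "data.parquet", nil)
    if err != nil {
        return err
    }
    defer r.Close()
    // <= 64 KiB ahead: served by discarding bytes from the open body.
    if _, err := r.Seek(4*1024, io.SeekStart); err != nil {
        return err
    }
    // At or past the object size: no new request, subsequent Reads return io.EOF.
    if _, err := r.Seek(0, io.SeekEnd); err != nil {
        return err
    }
    return nil
}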
-func (r *s3ObjectReader) Seek(offset int64, whence int) (int64, error) { - var realOffset int64 - switch whence { - case io.SeekStart: - realOffset = offset - case io.SeekCurrent: - realOffset = r.pos + offset - case io.SeekEnd: - realOffset = r.rangeInfo.Size + offset - default: - return 0, errors.Annotatef(berrors.ErrStorageUnknown, "Seek: invalid whence '%d'", whence) - } - if realOffset < 0 { - return 0, errors.Annotatef(berrors.ErrStorageUnknown, "Seek in '%s': invalid offset to seek '%d'.", r.name, realOffset) - } - - if realOffset == r.pos { - return realOffset, nil - } else if realOffset >= r.rangeInfo.Size { - // See: https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35 - // because s3's GetObject interface doesn't allow get a range that matches zero length data, - // so if the position is out of range, we need to always return io.EOF after the seek operation. - - // close current read and open a new one which target offset - if err := r.reader.Close(); err != nil { - log.L().Warn("close s3 reader failed, will ignore this error", logutil.ShortError(err)) - } - - r.reader = io.NopCloser(bytes.NewReader(nil)) - r.pos = r.rangeInfo.Size - return r.pos, nil - } - - // if seek ahead no more than 64k, we discard these data - if realOffset > r.pos && realOffset-r.pos <= maxSkipOffsetByRead { - _, err := io.CopyN(io.Discard, r, realOffset-r.pos) - if err != nil { - return r.pos, errors.Trace(err) - } - return realOffset, nil - } - - // close current read and open a new one which target offset - err := r.reader.Close() - if err != nil { - return 0, errors.Trace(err) - } - - newReader, info, err := r.storage.open(r.ctx, r.name, realOffset, 0) - if err != nil { - return 0, errors.Trace(err) - } - r.reader = newReader - if r.prefetchSize > 0 { - r.reader = prefetch.NewReader(r.reader, r.prefetchSize) - } - r.rangeInfo = info - r.pos = realOffset - return realOffset, nil -} - -func (r *s3ObjectReader) GetFileSize() (int64, error) { - return r.rangeInfo.Size, nil -} - -// createUploader create multi upload request. -func (rs *S3Storage) createUploader(ctx context.Context, name string) (ExternalFileWriter, error) { - input := &s3.CreateMultipartUploadInput{ - Bucket: aws.String(rs.options.Bucket), - Key: aws.String(rs.options.Prefix + name), - } - if rs.options.Acl != "" { - input = input.SetACL(rs.options.Acl) - } - if rs.options.Sse != "" { - input = input.SetServerSideEncryption(rs.options.Sse) - } - if rs.options.SseKmsKeyId != "" { - input = input.SetSSEKMSKeyId(rs.options.SseKmsKeyId) - } - if rs.options.StorageClass != "" { - input = input.SetStorageClass(rs.options.StorageClass) - } - - resp, err := rs.svc.CreateMultipartUploadWithContext(ctx, input) - if err != nil { - return nil, errors.Trace(err) - } - return &S3Uploader{ - svc: rs.svc, - createOutput: resp, - completeParts: make([]*s3.CompletedPart, 0, 128), - }, nil -} - -type s3ObjectWriter struct { - wd *io.PipeWriter - wg *sync.WaitGroup - err error -} - -// Write implement the io.Writer interface. -func (s *s3ObjectWriter) Write(_ context.Context, p []byte) (int, error) { - return s.wd.Write(p) -} - -// Close implement the io.Closer interface. -func (s *s3ObjectWriter) Close(_ context.Context) error { - err := s.wd.Close() - if err != nil { - return err - } - s.wg.Wait() - return s.err -} - -// Create creates multi upload request. 
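The Create implementation below picks one of two upload paths depending on WriterOption.Concurrency (the S3Uploader above for sequential uploads, an s3manager pipe-backed uploader otherwise) and wraps either in a buffered writer. A usage sketch from the caller's side; the sizes are illustrative:

func writeLarge(ctx context.Context, store storage.ExternalStorage, name string, chunks <-chan []byte) error {
    w, err := store.Create(ctx, name, &storage.WriterOption{
        Concurrency: 4,                // >1 selects the s3manager pipe-based path below
        PartSize:    16 * 1024 * 1024, // also used as the internal buffer size
    })
    if err != nil {
        return err
    }
    for c := range chunks {
        if _, err := w.Write(ctx, c); err != nil {
            return err
        }
    }
    return w.Close(ctx)
}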
-func (rs *S3Storage) Create(ctx context.Context, name string, option *WriterOption) (ExternalFileWriter, error) { - var uploader ExternalFileWriter - var err error - if option == nil || option.Concurrency <= 1 { - uploader, err = rs.createUploader(ctx, name) - if err != nil { - return nil, err - } - } else { - up := s3manager.NewUploaderWithClient(rs.svc, func(u *s3manager.Uploader) { - u.PartSize = option.PartSize - u.Concurrency = option.Concurrency - u.BufferProvider = s3manager.NewBufferedReadSeekerWriteToPool(option.Concurrency * hardcodedS3ChunkSize) - }) - rd, wd := io.Pipe() - upParams := &s3manager.UploadInput{ - Bucket: aws.String(rs.options.Bucket), - Key: aws.String(rs.options.Prefix + name), - Body: rd, - } - s3Writer := &s3ObjectWriter{wd: wd, wg: &sync.WaitGroup{}} - s3Writer.wg.Add(1) - go func() { - _, err := up.UploadWithContext(ctx, upParams) - // like a channel we only let sender close the pipe in happy path - if err != nil { - log.Warn("upload to s3 failed", zap.String("filename", name), zap.Error(err)) - _ = rd.CloseWithError(err) - } - s3Writer.err = err - s3Writer.wg.Done() - }() - uploader = s3Writer - } - bufSize := WriteBufferSize - if option != nil && option.PartSize > 0 { - bufSize = int(option.PartSize) - } - uploaderWriter := newBufferedWriter(uploader, bufSize, NoCompression) - return uploaderWriter, nil -} - -// Rename implements ExternalStorage interface. -func (rs *S3Storage) Rename(ctx context.Context, oldFileName, newFileName string) error { - content, err := rs.ReadFile(ctx, oldFileName) - if err != nil { - return errors.Trace(err) - } - err = rs.WriteFile(ctx, newFileName, content) - if err != nil { - return errors.Trace(err) - } - if err = rs.DeleteFile(ctx, oldFileName); err != nil { - return errors.Trace(err) - } - return nil -} - -// Close implements ExternalStorage interface. -func (*S3Storage) Close() {} - -// retryerWithLog wrappes the client.DefaultRetryer, and logging when retry triggered. -type retryerWithLog struct { - client.DefaultRetryer -} - -func isCancelError(err error) bool { - return strings.Contains(err.Error(), "context canceled") -} - -func isDeadlineExceedError(err error) bool { - // TODO find a better way. - // Known challenges: - // - // If we want to unwrap the r.Error: - // 1. the err should be an awserr.Error (let it be awsErr) - // 2. awsErr.OrigErr() should be an *url.Error (let it be urlErr). - // 3. urlErr.Err should be a http.httpError (which is private). - // - // If we want to reterive the error from the request context: - // The error of context in the HTTPRequest (i.e. r.HTTPRequest.Context().Err() ) is nil. - return strings.Contains(err.Error(), "context deadline exceeded") -} - -func isConnectionResetError(err error) bool { - return strings.Contains(err.Error(), "read: connection reset") -} - -func isConnectionRefusedError(err error) bool { - return strings.Contains(err.Error(), "connection refused") -} - -func (rl retryerWithLog) ShouldRetry(r *request.Request) bool { - // for unit test - failpoint.Inject("replace-error-to-connection-reset-by-peer", func(_ failpoint.Value) { - log.Info("original error", zap.Error(r.Error)) - if r.Error != nil { - r.Error = errors.New("read tcp *.*.*.*:*->*.*.*.*:*: read: connection reset by peer") - } - }) - if r.HTTPRequest.URL.Host == ec2MetaAddress && (isDeadlineExceedError(r.Error) || isConnectionResetError(r.Error)) { - // fast fail for unreachable linklocal address in EC2 containers. - log.Warn("failed to get EC2 metadata. 
skipping.", logutil.ShortError(r.Error)) - return false - } - if isConnectionResetError(r.Error) { - return true - } - if isConnectionRefusedError(r.Error) { - return false - } - return rl.DefaultRetryer.ShouldRetry(r) -} - -func (rl retryerWithLog) RetryRules(r *request.Request) time.Duration { - backoffTime := rl.DefaultRetryer.RetryRules(r) - if backoffTime > 0 { - log.Warn("failed to request s3, retrying", zap.Error(r.Error), zap.Duration("backoff", backoffTime)) - } - return backoffTime -} - -func defaultS3Retryer() request.Retryer { - return retryerWithLog{ - DefaultRetryer: client.DefaultRetryer{ - NumMaxRetries: maxRetries, - MinRetryDelay: 1 * time.Second, - MinThrottleDelay: 2 * time.Second, - }, - } -} diff --git a/br/pkg/streamhelper/advancer.go b/br/pkg/streamhelper/advancer.go index 1b21c8da19e59..6d477994ef07f 100644 --- a/br/pkg/streamhelper/advancer.go +++ b/br/pkg/streamhelper/advancer.go @@ -138,9 +138,9 @@ func (c *checkpoint) equal(o *checkpoint) bool { // we should try to resolve lock for the range // to keep the RPO in 5 min. func (c *checkpoint) needResolveLocks() bool { - if val, _err_ := failpoint.Eval(_curpkg_("NeedResolveLocks")); _err_ == nil { - return val.(bool) - } + failpoint.Inject("NeedResolveLocks", func(val failpoint.Value) { + failpoint.Return(val.(bool)) + }) return time.Since(c.resolveLockTime) > 3*time.Minute } @@ -532,7 +532,7 @@ func (c *CheckpointAdvancer) SpawnSubscriptionHandler(ctx context.Context) { if !ok { return } - failpoint.Eval(_curpkg_("subscription-handler-loop")) + failpoint.Inject("subscription-handler-loop", func() {}) c.WithCheckpoints(func(vsf *spans.ValueSortedFull) { if vsf == nil { log.Warn("Span tree not found, perhaps stale event of removed tasks.", @@ -555,7 +555,7 @@ func (c *CheckpointAdvancer) subscribeTick(ctx context.Context) error { if c.subscriber == nil { return nil } - failpoint.Eval(_curpkg_("get_subscriber")) + failpoint.Inject("get_subscriber", nil) if err := c.subscriber.UpdateStoreTopology(ctx); err != nil { log.Warn("Error when updating store topology.", zap.String("category", "log backup advancer"), logutil.ShortError(err)) @@ -684,7 +684,7 @@ func (c *CheckpointAdvancer) asyncResolveLocksForRanges(ctx context.Context, tar // run in another goroutine // do not block main tick here go func() { - failpoint.Eval(_curpkg_("AsyncResolveLocks")) + failpoint.Inject("AsyncResolveLocks", func() {}) handler := func(ctx context.Context, r tikvstore.KeyRange) (rangetask.TaskStat, error) { // we will scan all locks and try to resolve them by check txn status. return tikv.ResolveLocksForRange( diff --git a/br/pkg/streamhelper/advancer.go__failpoint_stash__ b/br/pkg/streamhelper/advancer.go__failpoint_stash__ deleted file mode 100644 index 6d477994ef07f..0000000000000 --- a/br/pkg/streamhelper/advancer.go__failpoint_stash__ +++ /dev/null @@ -1,735 +0,0 @@ -// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. 
- -package streamhelper - -import ( - "bytes" - "context" - "fmt" - "math" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - backuppb "github.com/pingcap/kvproto/pkg/brpb" - "github.com/pingcap/log" - "github.com/pingcap/tidb/br/pkg/logutil" - "github.com/pingcap/tidb/br/pkg/streamhelper/config" - "github.com/pingcap/tidb/br/pkg/streamhelper/spans" - "github.com/pingcap/tidb/br/pkg/utils" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/util" - tikvstore "github.com/tikv/client-go/v2/kv" - "github.com/tikv/client-go/v2/oracle" - "github.com/tikv/client-go/v2/tikv" - "github.com/tikv/client-go/v2/txnkv/rangetask" - "go.uber.org/multierr" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" -) - -// CheckpointAdvancer is the central node for advancing the checkpoint of log backup. -// It's a part of "checkpoint v3". -// Generally, it scan the regions in the task range, collect checkpoints from tikvs. -/* - ┌──────┐ - ┌────►│ TiKV │ - │ └──────┘ - │ - │ - ┌──────────┐GetLastFlushTSOfRegion│ ┌──────┐ - │ Advancer ├──────────────────────┼────►│ TiKV │ - └────┬─────┘ │ └──────┘ - │ │ - │ │ - │ │ ┌──────┐ - │ └────►│ TiKV │ - │ └──────┘ - │ - │ UploadCheckpointV3 ┌──────────────────┐ - └─────────────────────►│ PD │ - └──────────────────┘ -*/ -type CheckpointAdvancer struct { - env Env - - // The concurrency accessed task: - // both by the task listener and ticking. - task *backuppb.StreamBackupTaskInfo - taskRange []kv.KeyRange - taskMu sync.Mutex - - // the read-only config. - // once tick begin, this should not be changed for now. - cfg config.Config - - // the cached last checkpoint. - // if no progress, this cache can help us don't to send useless requests. - lastCheckpoint *checkpoint - lastCheckpointMu sync.Mutex - inResolvingLock atomic.Bool - isPaused atomic.Bool - - checkpoints *spans.ValueSortedFull - checkpointsMu sync.Mutex - - subscriber *FlushSubscriber - subscriberMu sync.Mutex -} - -// HasTask returns whether the advancer has been bound to a task. -func (c *CheckpointAdvancer) HasTask() bool { - c.taskMu.Lock() - defer c.taskMu.Unlock() - - return c.task != nil -} - -// HasSubscriber returns whether the advancer is associated with a subscriber. -func (c *CheckpointAdvancer) HasSubscribion() bool { - c.subscriberMu.Lock() - defer c.subscriberMu.Unlock() - - return c.subscriber != nil && len(c.subscriber.subscriptions) > 0 -} - -// checkpoint represents the TS with specific range. -// it's only used in advancer.go. -type checkpoint struct { - StartKey []byte - EndKey []byte - TS uint64 - - // It's better to use PD timestamp in future, for now - // use local time to decide the time to resolve lock is ok. - resolveLockTime time.Time -} - -func newCheckpointWithTS(ts uint64) *checkpoint { - return &checkpoint{ - TS: ts, - resolveLockTime: time.Now(), - } -} - -func NewCheckpointWithSpan(s spans.Valued) *checkpoint { - return &checkpoint{ - StartKey: s.Key.StartKey, - EndKey: s.Key.EndKey, - TS: s.Value, - resolveLockTime: time.Now(), - } -} - -func (c *checkpoint) safeTS() uint64 { - return c.TS - 1 -} - -func (c *checkpoint) equal(o *checkpoint) bool { - return bytes.Equal(c.StartKey, o.StartKey) && - bytes.Equal(c.EndKey, o.EndKey) && c.TS == o.TS -} - -// if a checkpoint stay in a time too long(3 min) -// we should try to resolve lock for the range -// to keep the RPO in 5 min. 
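The 3-minute staleness rule in this comment is what later drives asyncResolveLocksForRanges: once a checkpoint has not moved for that long, the advancer assumes stray locks are holding it back and resolves them to keep the 5-minute RPO. A tiny in-package sketch of the rule; newCheckpointWithTS and resolveLockTime are unexported, so this only compiles inside streamhelper or its tests:

    cp := newCheckpointWithTS(oracle.ComposeTS(time.Now().UnixMilli(), 0))
    fmt.Println(cp.needResolveLocks()) // false: resolveLockTime was just set

    cp.resolveLockTime = time.Now().Add(-4 * time.Minute)
    fmt.Println(cp.needResolveLocks()) // true: stuck for more than 3 minutes, trigger lock resolving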
-func (c *checkpoint) needResolveLocks() bool { - failpoint.Inject("NeedResolveLocks", func(val failpoint.Value) { - failpoint.Return(val.(bool)) - }) - return time.Since(c.resolveLockTime) > 3*time.Minute -} - -// NewCheckpointAdvancer creates a checkpoint advancer with the env. -func NewCheckpointAdvancer(env Env) *CheckpointAdvancer { - return &CheckpointAdvancer{ - env: env, - cfg: config.Default(), - } -} - -// UpdateConfig updates the config for the advancer. -// Note this should be called before starting the loop, because there isn't locks, -// TODO: support updating config when advancer starts working. -// (Maybe by applying changes at begin of ticking, and add locks.) -func (c *CheckpointAdvancer) UpdateConfig(newConf config.Config) { - c.cfg = newConf -} - -// UpdateConfigWith updates the config by modifying the current config. -func (c *CheckpointAdvancer) UpdateConfigWith(f func(*config.Config)) { - cfg := c.cfg - f(&cfg) - c.UpdateConfig(cfg) -} - -// UpdateLastCheckpoint modify the checkpoint in ticking. -func (c *CheckpointAdvancer) UpdateLastCheckpoint(p *checkpoint) { - c.lastCheckpointMu.Lock() - c.lastCheckpoint = p - c.lastCheckpointMu.Unlock() -} - -// Config returns the current config. -func (c *CheckpointAdvancer) Config() config.Config { - return c.cfg -} - -// GetInResolvingLock only used for test. -func (c *CheckpointAdvancer) GetInResolvingLock() bool { - return c.inResolvingLock.Load() -} - -// GetCheckpointInRange scans the regions in the range, -// collect them to the collector. -func (c *CheckpointAdvancer) GetCheckpointInRange(ctx context.Context, start, end []byte, - collector *clusterCollector) error { - log.Debug("scanning range", logutil.Key("start", start), logutil.Key("end", end)) - iter := IterateRegion(c.env, start, end) - for !iter.Done() { - rs, err := iter.Next(ctx) - if err != nil { - return err - } - log.Debug("scan region", zap.Int("len", len(rs))) - for _, r := range rs { - err := collector.CollectRegion(r) - if err != nil { - log.Warn("meet error during getting checkpoint", logutil.ShortError(err)) - return err - } - } - } - return nil -} - -func (c *CheckpointAdvancer) recordTimeCost(message string, fields ...zap.Field) func() { - now := time.Now() - label := strings.ReplaceAll(message, " ", "-") - return func() { - cost := time.Since(now) - fields = append(fields, zap.Stringer("take", cost)) - metrics.AdvancerTickDuration.WithLabelValues(label).Observe(cost.Seconds()) - log.Debug(message, fields...) - } -} - -// tryAdvance tries to advance the checkpoint ts of a set of ranges which shares the same checkpoint. 
-func (c *CheckpointAdvancer) tryAdvance(ctx context.Context, length int, - getRange func(int) kv.KeyRange) (err error) { - defer c.recordTimeCost("try advance", zap.Int("len", length))() - defer utils.PanicToErr(&err) - - ranges := spans.Collapse(length, getRange) - workers := util.NewWorkerPool(uint(config.DefaultMaxConcurrencyAdvance)*4, "sub ranges") - eg, cx := errgroup.WithContext(ctx) - collector := NewClusterCollector(ctx, c.env) - collector.SetOnSuccessHook(func(u uint64, kr kv.KeyRange) { - c.checkpointsMu.Lock() - defer c.checkpointsMu.Unlock() - c.checkpoints.Merge(spans.Valued{Key: kr, Value: u}) - }) - clampedRanges := utils.IntersectAll(ranges, utils.CloneSlice(c.taskRange)) - for _, r := range clampedRanges { - r := r - workers.ApplyOnErrorGroup(eg, func() (e error) { - defer c.recordTimeCost("get regions in range")() - defer utils.PanicToErr(&e) - return c.GetCheckpointInRange(cx, r.StartKey, r.EndKey, collector) - }) - } - err = eg.Wait() - if err != nil { - return err - } - - _, err = collector.Finish(ctx) - if err != nil { - return err - } - return nil -} - -func tsoBefore(n time.Duration) uint64 { - now := time.Now() - return oracle.ComposeTS(now.UnixMilli()-n.Milliseconds(), 0) -} - -func tsoAfter(ts uint64, n time.Duration) uint64 { - return oracle.GoTimeToTS(oracle.GetTimeFromTS(ts).Add(n)) -} - -func (c *CheckpointAdvancer) WithCheckpoints(f func(*spans.ValueSortedFull)) { - c.checkpointsMu.Lock() - defer c.checkpointsMu.Unlock() - - f(c.checkpoints) -} - -// only used for test -func (c *CheckpointAdvancer) NewCheckpoints(cps *spans.ValueSortedFull) { - c.checkpoints = cps -} - -func (c *CheckpointAdvancer) fetchRegionHint(ctx context.Context, startKey []byte) string { - region, err := locateKeyOfRegion(ctx, c.env, startKey) - if err != nil { - return errors.Annotate(err, "failed to fetch region").Error() - } - r := region.Region - l := region.Leader - prs := []int{} - for _, p := range r.GetPeers() { - prs = append(prs, int(p.StoreId)) - } - metrics.LogBackupCurrentLastRegionID.Set(float64(r.Id)) - metrics.LogBackupCurrentLastRegionLeaderStoreID.Set(float64(l.StoreId)) - return fmt.Sprintf("ID=%d,Leader=%d,ConfVer=%d,Version=%d,Peers=%v,RealRange=%s", - r.GetId(), l.GetStoreId(), r.GetRegionEpoch().GetConfVer(), r.GetRegionEpoch().GetVersion(), - prs, logutil.StringifyRangeOf(r.GetStartKey(), r.GetEndKey())) -} - -func (c *CheckpointAdvancer) CalculateGlobalCheckpointLight(ctx context.Context, - threshold time.Duration) (spans.Valued, error) { - var targets []spans.Valued - var minValue spans.Valued - thresholdTso := tsoBefore(threshold) - c.WithCheckpoints(func(vsf *spans.ValueSortedFull) { - vsf.TraverseValuesLessThan(thresholdTso, func(v spans.Valued) bool { - targets = append(targets, v) - return true - }) - minValue = vsf.Min() - }) - sctx, cancel := context.WithTimeout(ctx, time.Second) - // Always fetch the hint and update the metrics. 
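tryAdvance above feeds every successfully collected region checkpoint into the value-sorted span tree through the SetOnSuccessHook merge, and CalculateGlobalCheckpointLight then only polls the spans whose value is still below the threshold TSO. A rough sketch of that data structure's role, using the spans constructors called elsewhere in this file; the exact Merge semantics are inferred from those calls, so treat this as illustrative:

    // Task range ["a","z") starts at checkpoint 0.
    full := spans.Sorted(spans.NewFullWith([]kv.KeyRange{
        {StartKey: []byte("a"), EndKey: []byte("z")},
    }, 0))

    // A flush event, or a collected region checkpoint, advances part of the range.
    full.Merge(spans.Valued{
        Key:   kv.KeyRange{StartKey: []byte("a"), EndKey: []byte("m")},
        Value: 42,
    })

    // The global checkpoint candidate is the minimum over all spans, so it is
    // still held back by the untouched ["m","z") half.
    fmt.Println(full.Min())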
- hint := c.fetchRegionHint(sctx, minValue.Key.StartKey) - logger := log.Debug - if minValue.Value < thresholdTso { - logger = log.Info - } - logger("current last region", zap.String("category", "log backup advancer hint"), - zap.Stringer("min", minValue), zap.Int("for-polling", len(targets)), - zap.String("min-ts", oracle.GetTimeFromTS(minValue.Value).Format(time.RFC3339)), - zap.String("region-hint", hint), - ) - cancel() - if len(targets) == 0 { - return minValue, nil - } - err := c.tryAdvance(ctx, len(targets), func(i int) kv.KeyRange { return targets[i].Key }) - if err != nil { - return minValue, err - } - return minValue, nil -} - -func (c *CheckpointAdvancer) consumeAllTask(ctx context.Context, ch <-chan TaskEvent) error { - for { - select { - case e, ok := <-ch: - if !ok { - return nil - } - log.Info("meet task event", zap.Stringer("event", &e)) - if err := c.onTaskEvent(ctx, e); err != nil { - if errors.Cause(e.Err) != context.Canceled { - log.Error("listen task meet error, would reopen.", logutil.ShortError(err)) - return err - } - return nil - } - default: - return nil - } - } -} - -// beginListenTaskChange bootstraps the initial task set, -// and returns a channel respecting the change of tasks. -func (c *CheckpointAdvancer) beginListenTaskChange(ctx context.Context) (<-chan TaskEvent, error) { - ch := make(chan TaskEvent, 1024) - if err := c.env.Begin(ctx, ch); err != nil { - return nil, err - } - err := c.consumeAllTask(ctx, ch) - if err != nil { - return nil, err - } - return ch, nil -} - -// StartTaskListener starts the task listener for the advancer. -// When no task detected, advancer would do nothing, please call this before begin the tick loop. -func (c *CheckpointAdvancer) StartTaskListener(ctx context.Context) { - cx, cancel := context.WithCancel(ctx) - var ch <-chan TaskEvent - for { - if cx.Err() != nil { - // make linter happy. 
- cancel() - return - } - var err error - ch, err = c.beginListenTaskChange(cx) - if err == nil { - break - } - log.Warn("failed to begin listening, retrying...", logutil.ShortError(err)) - time.Sleep(c.cfg.BackoffTime) - } - - go func() { - defer cancel() - for { - select { - case <-ctx.Done(): - return - case e, ok := <-ch: - if !ok { - log.Info("Task watcher exits due to stream ends.", zap.String("category", "log backup advancer")) - return - } - log.Info("Meet task event", zap.String("category", "log backup advancer"), zap.Stringer("event", &e)) - if err := c.onTaskEvent(ctx, e); err != nil { - if errors.Cause(e.Err) != context.Canceled { - log.Error("listen task meet error, would reopen.", logutil.ShortError(err)) - time.AfterFunc(c.cfg.BackoffTime, func() { c.StartTaskListener(ctx) }) - } - log.Info("Task watcher exits due to some error.", zap.String("category", "log backup advancer"), - logutil.ShortError(err)) - return - } - } - } - }() -} - -func (c *CheckpointAdvancer) setCheckpoints(cps *spans.ValueSortedFull) { - c.checkpointsMu.Lock() - c.checkpoints = cps - c.checkpointsMu.Unlock() -} - -func (c *CheckpointAdvancer) onTaskEvent(ctx context.Context, e TaskEvent) error { - c.taskMu.Lock() - defer c.taskMu.Unlock() - switch e.Type { - case EventAdd: - utils.LogBackupTaskCountInc() - c.task = e.Info - c.taskRange = spans.Collapse(len(e.Ranges), func(i int) kv.KeyRange { return e.Ranges[i] }) - c.setCheckpoints(spans.Sorted(spans.NewFullWith(e.Ranges, 0))) - globalCheckpointTs, err := c.env.GetGlobalCheckpointForTask(ctx, e.Name) - if err != nil { - log.Error("failed to get global checkpoint, skipping.", logutil.ShortError(err)) - return err - } - if globalCheckpointTs < c.task.StartTs { - globalCheckpointTs = c.task.StartTs - } - log.Info("get global checkpoint", zap.Uint64("checkpoint", globalCheckpointTs)) - c.lastCheckpoint = newCheckpointWithTS(globalCheckpointTs) - p, err := c.env.BlockGCUntil(ctx, globalCheckpointTs) - if err != nil { - log.Warn("failed to upload service GC safepoint, skipping.", logutil.ShortError(err)) - } - log.Info("added event", zap.Stringer("task", e.Info), - zap.Stringer("ranges", logutil.StringifyKeys(c.taskRange)), zap.Uint64("current-checkpoint", p)) - case EventDel: - utils.LogBackupTaskCountDec() - c.task = nil - c.isPaused.Store(false) - c.taskRange = nil - // This would be synced by `taskMu`, perhaps we'd better rename that to `tickMu`. - // Do the null check because some of test cases won't equip the advancer with subscriber. - if c.subscriber != nil { - c.subscriber.Clear() - } - c.setCheckpoints(nil) - if err := c.env.ClearV3GlobalCheckpointForTask(ctx, e.Name); err != nil { - log.Warn("failed to clear global checkpoint", logutil.ShortError(err)) - } - if err := c.env.UnblockGC(ctx); err != nil { - log.Warn("failed to remove service GC safepoint", logutil.ShortError(err)) - } - metrics.LastCheckpoint.DeleteLabelValues(e.Name) - case EventPause: - if c.task.GetName() == e.Name { - c.isPaused.Store(true) - } - case EventResume: - if c.task.GetName() == e.Name { - c.isPaused.Store(false) - } - case EventErr: - return e.Err - } - return nil -} - -func (c *CheckpointAdvancer) setCheckpoint(ctx context.Context, s spans.Valued) bool { - cp := NewCheckpointWithSpan(s) - if cp.TS < c.lastCheckpoint.TS { - log.Warn("failed to update global checkpoint: stale", - zap.Uint64("old", c.lastCheckpoint.TS), zap.Uint64("new", cp.TS)) - return false - } - // Need resolve lock for different range and same TS - // so check the range and TS here. 
- if cp.equal(c.lastCheckpoint) { - return false - } - c.UpdateLastCheckpoint(cp) - metrics.LastCheckpoint.WithLabelValues(c.task.GetName()).Set(float64(c.lastCheckpoint.TS)) - return true -} - -// advanceCheckpointBy advances the checkpoint by a checkpoint getter function. -func (c *CheckpointAdvancer) advanceCheckpointBy(ctx context.Context, - getCheckpoint func(context.Context) (spans.Valued, error)) error { - start := time.Now() - cp, err := getCheckpoint(ctx) - if err != nil { - return err - } - - if c.setCheckpoint(ctx, cp) { - log.Info("uploading checkpoint for task", - zap.Stringer("checkpoint", oracle.GetTimeFromTS(cp.Value)), - zap.Uint64("checkpoint", cp.Value), - zap.String("task", c.task.Name), - zap.Stringer("take", time.Since(start))) - } - return nil -} - -func (c *CheckpointAdvancer) stopSubscriber() { - c.subscriberMu.Lock() - defer c.subscriberMu.Unlock() - c.subscriber.Drop() - c.subscriber = nil -} - -func (c *CheckpointAdvancer) SpawnSubscriptionHandler(ctx context.Context) { - c.subscriberMu.Lock() - defer c.subscriberMu.Unlock() - c.subscriber = NewSubscriber(c.env, c.env, WithMasterContext(ctx)) - es := c.subscriber.Events() - log.Info("Subscription handler spawned.", zap.String("category", "log backup subscription manager")) - - go func() { - defer utils.CatchAndLogPanic() - for { - select { - case <-ctx.Done(): - return - case event, ok := <-es: - if !ok { - return - } - failpoint.Inject("subscription-handler-loop", func() {}) - c.WithCheckpoints(func(vsf *spans.ValueSortedFull) { - if vsf == nil { - log.Warn("Span tree not found, perhaps stale event of removed tasks.", - zap.String("category", "log backup subscription manager")) - return - } - log.Debug("Accepting region flush event.", - zap.Stringer("range", logutil.StringifyRange(event.Key)), - zap.Uint64("checkpoint", event.Value)) - vsf.Merge(event) - }) - } - } - }() -} - -func (c *CheckpointAdvancer) subscribeTick(ctx context.Context) error { - c.subscriberMu.Lock() - defer c.subscriberMu.Unlock() - if c.subscriber == nil { - return nil - } - failpoint.Inject("get_subscriber", nil) - if err := c.subscriber.UpdateStoreTopology(ctx); err != nil { - log.Warn("Error when updating store topology.", - zap.String("category", "log backup advancer"), logutil.ShortError(err)) - } - c.subscriber.HandleErrors(ctx) - return c.subscriber.PendingErrors() -} - -func (c *CheckpointAdvancer) isCheckpointLagged(ctx context.Context) (bool, error) { - if c.cfg.CheckPointLagLimit <= 0 { - return false, nil - } - - now, err := c.env.FetchCurrentTS(ctx) - if err != nil { - return false, err - } - - lagDuration := oracle.GetTimeFromTS(now).Sub(oracle.GetTimeFromTS(c.lastCheckpoint.TS)) - if lagDuration > c.cfg.CheckPointLagLimit { - log.Warn("checkpoint lag is too large", zap.String("category", "log backup advancer"), - zap.Stringer("lag", lagDuration)) - return true, nil - } - return false, nil -} - -func (c *CheckpointAdvancer) importantTick(ctx context.Context) error { - c.checkpointsMu.Lock() - c.setCheckpoint(ctx, c.checkpoints.Min()) - c.checkpointsMu.Unlock() - if err := c.env.UploadV3GlobalCheckpointForTask(ctx, c.task.Name, c.lastCheckpoint.TS); err != nil { - return errors.Annotate(err, "failed to upload global checkpoint") - } - isLagged, err := c.isCheckpointLagged(ctx) - if err != nil { - return errors.Annotate(err, "failed to check timestamp") - } - if isLagged { - err := c.env.PauseTask(ctx, c.task.Name) - if err != nil { - return errors.Annotate(err, "failed to pause task") - } - return 
errors.Annotate(errors.Errorf("check point lagged too large"), "check point lagged too large") - } - p, err := c.env.BlockGCUntil(ctx, c.lastCheckpoint.safeTS()) - if err != nil { - return errors.Annotatef(err, - "failed to update service GC safe point, current checkpoint is %d, target checkpoint is %d", - c.lastCheckpoint.safeTS(), p) - } - if p <= c.lastCheckpoint.safeTS() { - log.Info("updated log backup GC safe point.", - zap.Uint64("checkpoint", p), zap.Uint64("target", c.lastCheckpoint.safeTS())) - } - if p > c.lastCheckpoint.safeTS() { - log.Warn("update log backup GC safe point failed: stale.", - zap.Uint64("checkpoint", p), zap.Uint64("target", c.lastCheckpoint.safeTS())) - } - return nil -} - -func (c *CheckpointAdvancer) optionalTick(cx context.Context) error { - // lastCheckpoint is not increased too long enough. - // assume the cluster has expired locks for whatever reasons. - var targets []spans.Valued - if c.lastCheckpoint != nil && c.lastCheckpoint.needResolveLocks() && c.inResolvingLock.CompareAndSwap(false, true) { - c.WithCheckpoints(func(vsf *spans.ValueSortedFull) { - // when get locks here. assume these locks are not belong to same txn, - // but these locks' start ts are close to 1 minute. try resolve these locks at one time - vsf.TraverseValuesLessThan(tsoAfter(c.lastCheckpoint.TS, time.Minute), func(v spans.Valued) bool { - targets = append(targets, v) - return true - }) - }) - if len(targets) != 0 { - log.Info("Advancer starts to resolve locks", zap.Int("targets", len(targets))) - // use new context here to avoid timeout - ctx := context.Background() - c.asyncResolveLocksForRanges(ctx, targets) - } else { - // don't forget set state back - c.inResolvingLock.Store(false) - } - } - threshold := c.Config().GetDefaultStartPollThreshold() - if err := c.subscribeTick(cx); err != nil { - log.Warn("Subscriber meet error, would polling the checkpoint.", zap.String("category", "log backup advancer"), - logutil.ShortError(err)) - threshold = c.Config().GetSubscriberErrorStartPollThreshold() - } - - return c.advanceCheckpointBy(cx, func(cx context.Context) (spans.Valued, error) { - return c.CalculateGlobalCheckpointLight(cx, threshold) - }) -} - -func (c *CheckpointAdvancer) tick(ctx context.Context) error { - c.taskMu.Lock() - defer c.taskMu.Unlock() - if c.task == nil || c.isPaused.Load() { - log.Debug("No tasks yet, skipping advancing.") - return nil - } - - var errs error - - cx, cancel := context.WithTimeout(ctx, c.Config().TickTimeout()) - defer cancel() - err := c.optionalTick(cx) - if err != nil { - log.Warn("option tick failed.", zap.String("category", "log backup advancer"), logutil.ShortError(err)) - errs = multierr.Append(errs, err) - } - - err = c.importantTick(ctx) - if err != nil { - log.Warn("important tick failed.", zap.String("category", "log backup advancer"), logutil.ShortError(err)) - errs = multierr.Append(errs, err) - } - - return errs -} - -func (c *CheckpointAdvancer) asyncResolveLocksForRanges(ctx context.Context, targets []spans.Valued) { - // run in another goroutine - // do not block main tick here - go func() { - failpoint.Inject("AsyncResolveLocks", func() {}) - handler := func(ctx context.Context, r tikvstore.KeyRange) (rangetask.TaskStat, error) { - // we will scan all locks and try to resolve them by check txn status. 
- return tikv.ResolveLocksForRange( - ctx, c.env, math.MaxUint64, r.StartKey, r.EndKey, tikv.NewGcResolveLockMaxBackoffer, tikv.GCScanLockLimit) - } - workerPool := util.NewWorkerPool(uint(config.DefaultMaxConcurrencyAdvance), "advancer resolve locks") - var wg sync.WaitGroup - for _, r := range targets { - targetRange := r - wg.Add(1) - workerPool.Apply(func() { - defer wg.Done() - // Run resolve lock on the whole TiKV cluster. - // it will use startKey/endKey to scan region in PD. - // but regionCache already has a codecPDClient. so just use decode key here. - // and it almost only include one region here. so set concurrency to 1. - runner := rangetask.NewRangeTaskRunner("advancer-resolve-locks-runner", - c.env.GetStore(), 1, handler) - err := runner.RunOnRange(ctx, targetRange.Key.StartKey, targetRange.Key.EndKey) - if err != nil { - // wait for next tick - log.Warn("resolve locks failed, wait for next tick", zap.String("category", "advancer"), - zap.String("uuid", "log backup advancer"), - zap.Error(err)) - } - }) - } - wg.Wait() - log.Info("finish resolve locks for checkpoint", zap.String("category", "advancer"), - zap.String("uuid", "log backup advancer"), - logutil.Key("StartKey", c.lastCheckpoint.StartKey), - logutil.Key("EndKey", c.lastCheckpoint.EndKey), - zap.Int("targets", len(targets))) - c.lastCheckpointMu.Lock() - c.lastCheckpoint.resolveLockTime = time.Now() - c.lastCheckpointMu.Unlock() - c.inResolvingLock.Store(false) - }() -} - -func (c *CheckpointAdvancer) TEST_registerCallbackForSubscriptions(f func()) int { - cnt := 0 - for _, sub := range c.subscriber.subscriptions { - sub.onDaemonExit = f - cnt += 1 - } - return cnt -} diff --git a/br/pkg/streamhelper/advancer_cliext.go b/br/pkg/streamhelper/advancer_cliext.go index f283120549451..1411c306c3abd 100644 --- a/br/pkg/streamhelper/advancer_cliext.go +++ b/br/pkg/streamhelper/advancer_cliext.go @@ -183,10 +183,10 @@ func (t AdvancerExt) startListen(ctx context.Context, rev int64, ch chan<- TaskE for { select { case resp, ok := <-taskCh: - if _, _err_ := failpoint.Eval(_curpkg_("advancer_close_channel")); _err_ == nil { + failpoint.Inject("advancer_close_channel", func() { // We cannot really close the channel, just simulating it. ok = false - } + }) if !ok { ch <- errorEvent(io.EOF) return @@ -195,10 +195,10 @@ func (t AdvancerExt) startListen(ctx context.Context, rev int64, ch chan<- TaskE return } case resp, ok := <-pauseCh: - if _, _err_ := failpoint.Eval(_curpkg_("advancer_close_pause_channel")); _err_ == nil { + failpoint.Inject("advancer_close_pause_channel", func() { // We cannot really close the channel, just simulating it. ok = false - } + }) if !ok { ch <- errorEvent(io.EOF) return diff --git a/br/pkg/streamhelper/advancer_cliext.go__failpoint_stash__ b/br/pkg/streamhelper/advancer_cliext.go__failpoint_stash__ deleted file mode 100644 index 1411c306c3abd..0000000000000 --- a/br/pkg/streamhelper/advancer_cliext.go__failpoint_stash__ +++ /dev/null @@ -1,301 +0,0 @@ -// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. 
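// NOTE (illustrative sketch, not part of the diff above): the advancer_cliext.go
// hunks restore the source-level failpoint.Inject markers that failpoint-ctl had
// rewritten into failpoint.Eval(_curpkg_(...)) calls, and the *__failpoint_stash__
// and binding__failpoint_binding__.go deletions remove the generated artifacts of
// that rewrite. The restored markers stay inert until a test enables them.
// Assuming the usual "<import path>/<failpoint name>" path and the stock
// failpoint/testify APIs, closing of the task watch channel could be simulated as:
func TestSimulateWatchChannelClose(t *testing.T) {
	const fp = "github.com/pingcap/tidb/br/pkg/streamhelper/advancer_close_channel"
	require.NoError(t, failpoint.Enable(fp, `return(true)`))
	defer func() { require.NoError(t, failpoint.Disable(fp)) }()
	// ... drive AdvancerExt.startListen here and assert that an io.EOF
	// TaskEvent arrives on the output channel ...
}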
- -package streamhelper - -import ( - "bytes" - "context" - "encoding/binary" - "fmt" - "io" - "strings" - - "github.com/golang/protobuf/proto" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - backuppb "github.com/pingcap/kvproto/pkg/brpb" - "github.com/pingcap/log" - berrors "github.com/pingcap/tidb/br/pkg/errors" - "github.com/pingcap/tidb/br/pkg/logutil" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/util/redact" - clientv3 "go.etcd.io/etcd/client/v3" - "go.uber.org/zap" -) - -type EventType int - -const ( - EventAdd EventType = iota - EventDel - EventErr - EventPause - EventResume -) - -func (t EventType) String() string { - switch t { - case EventAdd: - return "Add" - case EventDel: - return "Del" - case EventErr: - return "Err" - case EventPause: - return "Pause" - case EventResume: - return "Resume" - } - return "Unknown" -} - -type TaskEvent struct { - Type EventType - Name string - Info *backuppb.StreamBackupTaskInfo - Ranges []kv.KeyRange - Err error -} - -func (t *TaskEvent) String() string { - if t.Err != nil { - return fmt.Sprintf("%s(%s, err = %s)", t.Type, t.Name, t.Err) - } - return fmt.Sprintf("%s(%s)", t.Type, t.Name) -} - -type AdvancerExt struct { - MetaDataClient -} - -func errorEvent(err error) TaskEvent { - return TaskEvent{ - Type: EventErr, - Err: err, - } -} - -func (t AdvancerExt) toTaskEvent(ctx context.Context, event *clientv3.Event) (TaskEvent, error) { - te := TaskEvent{} - var prefix string - - if bytes.HasPrefix(event.Kv.Key, []byte(PrefixOfTask())) { - prefix = PrefixOfTask() - te.Name = strings.TrimPrefix(string(event.Kv.Key), prefix) - } else if bytes.HasPrefix(event.Kv.Key, []byte(PrefixOfPause())) { - prefix = PrefixOfPause() - te.Name = strings.TrimPrefix(string(event.Kv.Key), prefix) - } else { - return TaskEvent{}, - errors.Annotatef(berrors.ErrInvalidArgument, "the path isn't a task/pause path (%s)", - string(event.Kv.Key)) - } - - switch { - case event.Type == clientv3.EventTypePut && prefix == PrefixOfTask(): - te.Type = EventAdd - case event.Type == clientv3.EventTypeDelete && prefix == PrefixOfTask(): - te.Type = EventDel - case event.Type == clientv3.EventTypePut && prefix == PrefixOfPause(): - te.Type = EventPause - case event.Type == clientv3.EventTypeDelete && prefix == PrefixOfPause(): - te.Type = EventResume - default: - return TaskEvent{}, - errors.Annotatef(berrors.ErrInvalidArgument, - "invalid event type or prefix: type=%s, prefix=%s", event.Type, prefix) - } - - te.Info = new(backuppb.StreamBackupTaskInfo) - if err := proto.Unmarshal(event.Kv.Value, te.Info); err != nil { - return TaskEvent{}, err - } - - var err error - te.Ranges, err = t.MetaDataClient.TaskByInfo(*te.Info).Ranges(ctx) - if err != nil { - return TaskEvent{}, err - } - - return te, nil -} - -func (t AdvancerExt) eventFromWatch(ctx context.Context, resp clientv3.WatchResponse) ([]TaskEvent, error) { - result := make([]TaskEvent, 0, len(resp.Events)) - if err := resp.Err(); err != nil { - return nil, err - } - for _, event := range resp.Events { - te, err := t.toTaskEvent(ctx, event) - if err != nil { - te.Type = EventErr - te.Err = err - } - result = append(result, te) - } - return result, nil -} - -func (t AdvancerExt) startListen(ctx context.Context, rev int64, ch chan<- TaskEvent) { - taskCh := t.Client.Watcher.Watch(ctx, PrefixOfTask(), clientv3.WithPrefix(), clientv3.WithRev(rev)) - pauseCh := t.Client.Watcher.Watch(ctx, PrefixOfPause(), clientv3.WithPrefix(), clientv3.WithRev(rev)) - - // inner function def - handleResponse := 
func(resp clientv3.WatchResponse) bool { - events, err := t.eventFromWatch(ctx, resp) - if err != nil { - log.Warn("Meet error during receiving the task event.", - zap.String("category", "log backup advancer"), logutil.ShortError(err)) - ch <- errorEvent(err) - return false - } - for _, event := range events { - ch <- event - } - return true - } - - // inner function def - collectRemaining := func() { - log.Info("Start collecting remaining events in the channel.", zap.String("category", "log backup advancer"), - zap.Int("remained", len(taskCh))) - defer log.Info("Finish collecting remaining events in the channel.", zap.String("category", "log backup advancer")) - for { - if taskCh == nil && pauseCh == nil { - return - } - - select { - case resp, ok := <-taskCh: - if !ok || !handleResponse(resp) { - taskCh = nil - } - case resp, ok := <-pauseCh: - if !ok || !handleResponse(resp) { - pauseCh = nil - } - } - } - } - - go func() { - defer close(ch) - for { - select { - case resp, ok := <-taskCh: - failpoint.Inject("advancer_close_channel", func() { - // We cannot really close the channel, just simulating it. - ok = false - }) - if !ok { - ch <- errorEvent(io.EOF) - return - } - if !handleResponse(resp) { - return - } - case resp, ok := <-pauseCh: - failpoint.Inject("advancer_close_pause_channel", func() { - // We cannot really close the channel, just simulating it. - ok = false - }) - if !ok { - ch <- errorEvent(io.EOF) - return - } - if !handleResponse(resp) { - return - } - case <-ctx.Done(): - collectRemaining() - ch <- errorEvent(ctx.Err()) - return - } - } - }() -} - -func (t AdvancerExt) getFullTasksAsEvent(ctx context.Context) ([]TaskEvent, int64, error) { - tasks, rev, err := t.GetAllTasksWithRevision(ctx) - if err != nil { - return nil, 0, err - } - events := make([]TaskEvent, 0, len(tasks)) - for _, task := range tasks { - ranges, err := task.Ranges(ctx) - if err != nil { - return nil, 0, err - } - te := TaskEvent{ - Type: EventAdd, - Name: task.Info.Name, - Info: &(task.Info), - Ranges: ranges, - } - events = append(events, te) - } - return events, rev, nil -} - -func (t AdvancerExt) Begin(ctx context.Context, ch chan<- TaskEvent) error { - initialTasks, rev, err := t.getFullTasksAsEvent(ctx) - if err != nil { - return err - } - // Note: maybe `go` here so we won't block? 
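// Begin replays every existing task as an EventAdd and then streams incremental
// watch events starting from rev+1. A minimal consumer, sketched under the
// assumption that `adv` is an AdvancerExt wired to a live etcd client (the
// handler bodies are placeholders, not code from this patch):
func consumeTaskEvents(ctx context.Context, adv AdvancerExt) error {
	ch := make(chan TaskEvent, 16)
	if err := adv.Begin(ctx, ch); err != nil {
		return err
	}
	// startListen closes ch when its goroutine exits, so this loop terminates.
	for ev := range ch {
		switch ev.Type {
		case EventAdd, EventResume:
			log.Info("task running", zap.String("name", ev.Name))
		case EventDel, EventPause:
			log.Info("task stopped", zap.String("name", ev.Name))
		case EventErr:
			return ev.Err
		}
	}
	return nil
}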
- for _, task := range initialTasks { - ch <- task - } - t.startListen(ctx, rev+1, ch) - return nil -} - -func (t AdvancerExt) GetGlobalCheckpointForTask(ctx context.Context, taskName string) (uint64, error) { - key := GlobalCheckpointOf(taskName) - resp, err := t.KV.Get(ctx, key) - if err != nil { - return 0, err - } - - if len(resp.Kvs) == 0 { - return 0, nil - } - - firstKV := resp.Kvs[0] - value := firstKV.Value - if len(value) != 8 { - return 0, errors.Annotatef(berrors.ErrPiTRMalformedMetadata, - "the global checkpoint isn't 64bits (it is %d bytes, value = %s)", - len(value), - redact.Key(value)) - } - - return binary.BigEndian.Uint64(value), nil -} - -func (t AdvancerExt) UploadV3GlobalCheckpointForTask(ctx context.Context, taskName string, checkpoint uint64) error { - key := GlobalCheckpointOf(taskName) - value := string(encodeUint64(checkpoint)) - oldValue, err := t.GetGlobalCheckpointForTask(ctx, taskName) - if err != nil { - return err - } - - if checkpoint < oldValue { - log.Warn("skipping upload global checkpoint", zap.String("category", "log backup advancer"), - zap.Uint64("old", oldValue), zap.Uint64("new", checkpoint)) - return nil - } - - _, err = t.KV.Put(ctx, key, value) - if err != nil { - return err - } - return nil -} - -func (t AdvancerExt) ClearV3GlobalCheckpointForTask(ctx context.Context, taskName string) error { - key := GlobalCheckpointOf(taskName) - _, err := t.KV.Delete(ctx, key) - return err -} diff --git a/br/pkg/streamhelper/binding__failpoint_binding__.go b/br/pkg/streamhelper/binding__failpoint_binding__.go deleted file mode 100644 index 0872efd5448a7..0000000000000 --- a/br/pkg/streamhelper/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package streamhelper - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/br/pkg/streamhelper/flush_subscriber.go b/br/pkg/streamhelper/flush_subscriber.go index 8d32adb576749..f1310b0212372 100644 --- a/br/pkg/streamhelper/flush_subscriber.go +++ b/br/pkg/streamhelper/flush_subscriber.go @@ -102,10 +102,10 @@ func (f *FlushSubscriber) UpdateStoreTopology(ctx context.Context) error { // Clear clears all the subscriptions. func (f *FlushSubscriber) Clear() { timeout := clearSubscriberTimeOut - if v, _err_ := failpoint.Eval(_curpkg_("FlushSubscriber.Clear.timeoutMs")); _err_ == nil { + failpoint.Inject("FlushSubscriber.Clear.timeoutMs", func(v failpoint.Value) { //nolint:durationcheck timeout = time.Duration(v.(int)) * time.Millisecond - } + }) log.Info("Clearing.", zap.String("category", "log backup flush subscriber"), zap.Duration("timeout", timeout)) @@ -302,7 +302,7 @@ func (s *subscription) listenOver(ctx context.Context, cli eventStream) { logutil.Key("event", m.EndKey), logutil.ShortError(err)) continue } - failpoint.Eval(_curpkg_("subscription.listenOver.aboutToSend")) + failpoint.Inject("subscription.listenOver.aboutToSend", func() {}) evt := spans.Valued{ Key: spans.Span{ diff --git a/br/pkg/streamhelper/flush_subscriber.go__failpoint_stash__ b/br/pkg/streamhelper/flush_subscriber.go__failpoint_stash__ deleted file mode 100644 index f1310b0212372..0000000000000 --- a/br/pkg/streamhelper/flush_subscriber.go__failpoint_stash__ +++ /dev/null @@ -1,373 +0,0 @@ -// Copyright 2022 PingCAP, Inc. 
Licensed under Apache-2.0. - -package streamhelper - -import ( - "context" - "io" - "strconv" - "sync" - "time" - - "github.com/google/uuid" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - logbackup "github.com/pingcap/kvproto/pkg/logbackuppb" - "github.com/pingcap/log" - berrors "github.com/pingcap/tidb/br/pkg/errors" - "github.com/pingcap/tidb/br/pkg/logutil" - "github.com/pingcap/tidb/br/pkg/streamhelper/spans" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/util/codec" - "go.uber.org/multierr" - "go.uber.org/zap" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -const ( - // clearSubscriberTimeOut is the timeout for clearing the subscriber. - clearSubscriberTimeOut = 1 * time.Minute -) - -// FlushSubscriber maintains the state of subscribing to the cluster. -type FlushSubscriber struct { - dialer LogBackupService - cluster TiKVClusterMeta - - // Current connections. - subscriptions map[uint64]*subscription - // The output channel. - eventsTunnel chan spans.Valued - // The background context for subscribes. - masterCtx context.Context -} - -// SubscriberConfig is a config which cloud be applied into the subscriber. -type SubscriberConfig func(*FlushSubscriber) - -// WithMasterContext sets the "master context" for the subscriber, -// that context would be the "background" context for every subtasks created by the subscription manager. -func WithMasterContext(ctx context.Context) SubscriberConfig { - return func(fs *FlushSubscriber) { fs.masterCtx = ctx } -} - -// NewSubscriber creates a new subscriber via the environment and optional configs. -func NewSubscriber(dialer LogBackupService, cluster TiKVClusterMeta, config ...SubscriberConfig) *FlushSubscriber { - subs := &FlushSubscriber{ - dialer: dialer, - cluster: cluster, - - subscriptions: map[uint64]*subscription{}, - eventsTunnel: make(chan spans.Valued, 1024), - masterCtx: context.Background(), - } - - for _, c := range config { - c(subs) - } - - return subs -} - -// UpdateStoreTopology fetches the current store topology and try to adapt the subscription state with it. -func (f *FlushSubscriber) UpdateStoreTopology(ctx context.Context) error { - stores, err := f.cluster.Stores(ctx) - if err != nil { - return errors.Annotate(err, "failed to get store list") - } - - storeSet := map[uint64]struct{}{} - for _, store := range stores { - sub, ok := f.subscriptions[store.ID] - if !ok { - f.addSubscription(ctx, store) - f.subscriptions[store.ID].connect(f.masterCtx, f.dialer) - } else if sub.storeBootAt != store.BootAt { - sub.storeBootAt = store.BootAt - sub.connect(f.masterCtx, f.dialer) - } - storeSet[store.ID] = struct{}{} - } - - for id := range f.subscriptions { - _, ok := storeSet[id] - if !ok { - f.removeSubscription(ctx, id) - } - } - return nil -} - -// Clear clears all the subscriptions. -func (f *FlushSubscriber) Clear() { - timeout := clearSubscriberTimeOut - failpoint.Inject("FlushSubscriber.Clear.timeoutMs", func(v failpoint.Value) { - //nolint:durationcheck - timeout = time.Duration(v.(int)) * time.Millisecond - }) - log.Info("Clearing.", - zap.String("category", "log backup flush subscriber"), - zap.Duration("timeout", timeout)) - cx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - for id := range f.subscriptions { - f.removeSubscription(cx, id) - } -} - -// Drop terminates the lifetime of the subscriber. -// This subscriber would be no more usable. 
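// Clear (above) lets tests shrink the one-minute shutdown budget through the
// "FlushSubscriber.Clear.timeoutMs" failpoint, whose value is read as an int of
// milliseconds. A sketch of enabling it, assuming the usual
// "<import path>/<name>" failpoint path:
func useShortClearTimeout(t *testing.T) {
	const fp = "github.com/pingcap/tidb/br/pkg/streamhelper/FlushSubscriber.Clear.timeoutMs"
	// The 500 becomes failpoint.Value and is cast via v.(int) inside Clear.
	require.NoError(t, failpoint.Enable(fp, `return(500)`))
	t.Cleanup(func() { require.NoError(t, failpoint.Disable(fp)) })
}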
-func (f *FlushSubscriber) Drop() { - f.Clear() - close(f.eventsTunnel) -} - -// HandleErrors execute the handlers over all pending errors. -// Note that the handler may cannot handle the pending errors, at that time, -// you can fetch the errors via `PendingErrors` call. -func (f *FlushSubscriber) HandleErrors(ctx context.Context) { - for id, sub := range f.subscriptions { - err := sub.loadError() - if err != nil { - retry := f.canBeRetried(err) - log.Warn("Meet error.", zap.String("category", "log backup flush subscriber"), - logutil.ShortError(err), zap.Bool("can-retry?", retry), zap.Uint64("store", id)) - if retry { - sub.connect(f.masterCtx, f.dialer) - } - } - } -} - -// Events returns the output channel of the events. -func (f *FlushSubscriber) Events() <-chan spans.Valued { - return f.eventsTunnel -} - -type eventStream = logbackup.LogBackup_SubscribeFlushEventClient - -type joinHandle <-chan struct{} - -func (jh joinHandle) Wait(ctx context.Context) { - select { - case <-jh: - case <-ctx.Done(): - log.Warn("join handle timed out.", zap.StackSkip("caller", 1)) - } -} - -func spawnJoinable(f func()) joinHandle { - c := make(chan struct{}) - go func() { - defer close(c) - f() - }() - return c -} - -// subscription is the state of subscription of one store. -// initially, it is IDLE, where cancel == nil. -// once `connect` called, it goto CONNECTED, where cancel != nil and err == nil. -// once some error (both foreground or background) happens, it goto ERROR, where err != nil. -type subscription struct { - // the handle to cancel the worker goroutine. - cancel context.CancelFunc - // the handle to wait until the worker goroutine exits. - background joinHandle - errMu sync.Mutex - err error - - // Immutable state. - storeID uint64 - // We record start bootstrap time and once a store restarts - // we need to try reconnect even there is a error cannot be retry. - storeBootAt uint64 - output chan<- spans.Valued - - onDaemonExit func() -} - -func (s *subscription) emitError(err error) { - s.errMu.Lock() - defer s.errMu.Unlock() - - s.err = err -} - -func (s *subscription) loadError() error { - s.errMu.Lock() - defer s.errMu.Unlock() - - return s.err -} - -func (s *subscription) clearError() { - s.errMu.Lock() - defer s.errMu.Unlock() - - s.err = nil -} - -func newSubscription(toStore Store, output chan<- spans.Valued) *subscription { - return &subscription{ - storeID: toStore.ID, - storeBootAt: toStore.BootAt, - output: output, - } -} - -func (s *subscription) connect(ctx context.Context, dialer LogBackupService) { - err := s.doConnect(ctx, dialer) - if err != nil { - s.emitError(err) - } -} - -func (s *subscription) doConnect(ctx context.Context, dialer LogBackupService) error { - log.Info("Adding subscription.", zap.String("category", "log backup subscription manager"), - zap.Uint64("store", s.storeID), zap.Uint64("boot", s.storeBootAt)) - // We should shutdown the background task firstly. - // Once it yields some error during shuting down, the error won't be brought to next run. 
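// spawnJoinable and joinHandle (defined just above) give each per-store listener
// goroutine a handle that close() can wait on with a deadline. The pattern in
// isolation, as a sketch reusing those helpers:
func waitForBackgroundWork(ctx context.Context) {
	jh := spawnJoinable(func() {
		// ... long-running per-store listener work would go here ...
	})
	// Returns once the goroutine exits, or logs a warning if ctx expires first.
	jh.Wait(ctx)
}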
- s.close(ctx) - s.clearError() - - c, err := dialer.GetLogBackupClient(ctx, s.storeID) - if err != nil { - return errors.Annotate(err, "failed to get log backup client") - } - cx, cancel := context.WithCancel(ctx) - cli, err := c.SubscribeFlushEvent(cx, &logbackup.SubscribeFlushEventRequest{ - ClientId: uuid.NewString(), - }) - if err != nil { - cancel() - _ = dialer.ClearCache(ctx, s.storeID) - return errors.Annotate(err, "failed to subscribe events") - } - lcx := logutil.ContextWithField(cx, zap.Uint64("store-id", s.storeID), - zap.String("category", "log backup flush subscriber")) - s.cancel = cancel - s.background = spawnJoinable(func() { s.listenOver(lcx, cli) }) - return nil -} - -func (s *subscription) close(ctx context.Context) { - if s.cancel != nil { - s.cancel() - s.background.Wait(ctx) - } - // HACK: don't close the internal channel here, - // because it is a ever-sharing channel. -} - -func (s *subscription) listenOver(ctx context.Context, cli eventStream) { - storeID := s.storeID - logutil.CL(ctx).Info("Listen starting.", zap.Uint64("store", storeID)) - defer func() { - if s.onDaemonExit != nil { - s.onDaemonExit() - } - - if pData := recover(); pData != nil { - log.Warn("Subscriber paniked.", zap.Uint64("store", storeID), zap.Any("panic-data", pData), zap.Stack("stack")) - s.emitError(errors.Annotatef(berrors.ErrUnknown, "panic during executing: %v", pData)) - } - }() - for { - // Shall we use RecvMsg for better performance? - // Note that the spans.Full requires the input slice be immutable. - msg, err := cli.Recv() - if err != nil { - logutil.CL(ctx).Info("Listen stopped.", - zap.Uint64("store", storeID), logutil.ShortError(err)) - if err == io.EOF || err == context.Canceled || status.Code(err) == codes.Canceled { - return - } - s.emitError(errors.Annotatef(err, "while receiving from store id %d", storeID)) - return - } - - log.Debug("Sending events.", zap.Int("size", len(msg.Events))) - for _, m := range msg.Events { - start, err := decodeKey(m.StartKey) - if err != nil { - logutil.CL(ctx).Warn("start key not encoded, skipping", - logutil.Key("event", m.StartKey), logutil.ShortError(err)) - continue - } - end, err := decodeKey(m.EndKey) - if err != nil { - logutil.CL(ctx).Warn("end key not encoded, skipping", - logutil.Key("event", m.EndKey), logutil.ShortError(err)) - continue - } - failpoint.Inject("subscription.listenOver.aboutToSend", func() {}) - - evt := spans.Valued{ - Key: spans.Span{ - StartKey: start, - EndKey: end, - }, - Value: m.Checkpoint, - } - select { - case s.output <- evt: - case <-ctx.Done(): - logutil.CL(ctx).Warn("Context canceled while sending events.", - zap.Uint64("store", storeID)) - return - } - } - metrics.RegionCheckpointSubscriptionEvent.WithLabelValues( - strconv.Itoa(int(storeID))).Observe(float64(len(msg.Events))) - } -} - -func (f *FlushSubscriber) addSubscription(ctx context.Context, toStore Store) { - f.subscriptions[toStore.ID] = newSubscription(toStore, f.eventsTunnel) -} - -func (f *FlushSubscriber) removeSubscription(ctx context.Context, toStore uint64) { - subs, ok := f.subscriptions[toStore] - if ok { - log.Info("Removing subscription.", zap.String("category", "log backup subscription manager"), - zap.Uint64("store", toStore)) - subs.close(ctx) - delete(f.subscriptions, toStore) - } -} - -// decodeKey decodes the key from TiKV, because the region range is encoded in TiKV. -func decodeKey(key []byte) ([]byte, error) { - if len(key) == 0 { - return key, nil - } - // Ignore the timestamp... 
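// Region boundaries reported by TiKV are memcomparable-encoded, so decodeKey
// strips that encoding before the key is merged into the span tree. A round
// trip with pkg/util/codec, as a sketch (the literal key and the extra bytes
// import are illustrative only):
func encodedKeyRoundTrip() {
	raw := []byte("t\x80\x00\x00\x00\x00\x00\x00\x01")
	enc := codec.EncodeBytes(nil, raw)
	_, dec, err := codec.DecodeBytes(enc, nil)
	if err != nil || !bytes.Equal(dec, raw) {
		panic("memcomparable round trip should be lossless")
	}
}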
- _, data, err := codec.DecodeBytes(key, nil) - if err != nil { - return key, err - } - return data, err -} - -func (f *FlushSubscriber) canBeRetried(err error) bool { - for _, e := range multierr.Errors(errors.Cause(err)) { - s := status.Convert(e) - // Is there any other error cannot be retried? - if s.Code() == codes.Unimplemented { - return false - } - } - return true -} - -func (f *FlushSubscriber) PendingErrors() error { - var allErr error - for _, s := range f.subscriptions { - if err := s.loadError(); err != nil { - allErr = multierr.Append(allErr, errors.Annotatef(err, "store %d has error", s.storeID)) - } - } - return allErr -} diff --git a/br/pkg/task/backup.go b/br/pkg/task/backup.go index 2acb33e8aa6c8..a04f14f8b519a 100644 --- a/br/pkg/task/backup.go +++ b/br/pkg/task/backup.go @@ -631,7 +631,7 @@ func RunBackup(c context.Context, g glue.Glue, cmdName string, cfg *BackupConfig progressCount := uint64(0) progressCallBack := func() { updateCh.Inc() - if v, _err_ := failpoint.Eval(_curpkg_("progress-call-back")); _err_ == nil { + failpoint.Inject("progress-call-back", func(v failpoint.Value) { log.Info("failpoint progress-call-back injected") atomic.AddUint64(&progressCount, 1) if fileName, ok := v.(string); ok { @@ -645,7 +645,7 @@ func RunBackup(c context.Context, g glue.Glue, cmdName string, cfg *BackupConfig log.Warn("failed to write data to file", zap.Error(err)) } } - } + }) } if cfg.UseCheckpoint { @@ -668,7 +668,7 @@ func RunBackup(c context.Context, g glue.Glue, cmdName string, cfg *BackupConfig }() } - if v, _err_ := failpoint.Eval(_curpkg_("s3-outage-during-writing-file")); _err_ == nil { + failpoint.Inject("s3-outage-during-writing-file", func(v failpoint.Value) { log.Info("failpoint s3-outage-during-writing-file injected, " + "process will sleep for 5s and notify the shell to kill s3 service.") if sigFile, ok := v.(string); ok { @@ -681,7 +681,7 @@ func RunBackup(c context.Context, g glue.Glue, cmdName string, cfg *BackupConfig } } time.Sleep(5 * time.Second) - } + }) metawriter.StartWriteMetasAsync(ctx, metautil.AppendDataFile) err = client.BackupRanges(ctx, ranges, req, uint(cfg.Concurrency), cfg.ReplicaReadLabel, metawriter, progressCallBack) diff --git a/br/pkg/task/backup.go__failpoint_stash__ b/br/pkg/task/backup.go__failpoint_stash__ deleted file mode 100644 index a04f14f8b519a..0000000000000 --- a/br/pkg/task/backup.go__failpoint_stash__ +++ /dev/null @@ -1,835 +0,0 @@ -// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. 
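// NOTE (sketch, not part of the stashed file below): the backup.go hunks above
// restore two string-valued failpoints. "progress-call-back" appends a region
// counter to the file named by the failpoint value; a test can enable it as
// follows (failpoint path assumed to be "<import path>/<name>", the temp file
// is only an example):
func enableProgressCallback(t *testing.T) string {
	const fp = "github.com/pingcap/tidb/br/pkg/task/progress-call-back"
	out := filepath.Join(t.TempDir(), "progress")
	require.NoError(t, failpoint.Enable(fp, fmt.Sprintf(`return("%s")`, out)))
	t.Cleanup(func() { require.NoError(t, failpoint.Disable(fp)) })
	return out // the backup later appends "region:<n>" lines to this file
}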
- -package task - -import ( - "context" - "crypto/sha256" - "encoding/json" - "fmt" - "os" - "strconv" - "strings" - "sync/atomic" - "time" - - "github.com/docker/go-units" - "github.com/opentracing/opentracing-go" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - backuppb "github.com/pingcap/kvproto/pkg/brpb" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/log" - "github.com/pingcap/tidb/br/pkg/backup" - "github.com/pingcap/tidb/br/pkg/checkpoint" - "github.com/pingcap/tidb/br/pkg/checksum" - "github.com/pingcap/tidb/br/pkg/conn" - berrors "github.com/pingcap/tidb/br/pkg/errors" - "github.com/pingcap/tidb/br/pkg/glue" - "github.com/pingcap/tidb/br/pkg/logutil" - "github.com/pingcap/tidb/br/pkg/metautil" - "github.com/pingcap/tidb/br/pkg/rtree" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/br/pkg/summary" - "github.com/pingcap/tidb/br/pkg/utils" - "github.com/pingcap/tidb/br/pkg/version" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - "github.com/pingcap/tidb/pkg/statistics/handle" - "github.com/pingcap/tidb/pkg/types" - "github.com/spf13/pflag" - "github.com/tikv/client-go/v2/oracle" - kvutil "github.com/tikv/client-go/v2/util" - "go.uber.org/multierr" - "go.uber.org/zap" -) - -const ( - flagBackupTimeago = "timeago" - flagBackupTS = "backupts" - flagLastBackupTS = "lastbackupts" - flagCompressionType = "compression" - flagCompressionLevel = "compression-level" - flagRemoveSchedulers = "remove-schedulers" - flagIgnoreStats = "ignore-stats" - flagUseBackupMetaV2 = "use-backupmeta-v2" - flagUseCheckpoint = "use-checkpoint" - flagKeyspaceName = "keyspace-name" - flagReplicaReadLabel = "replica-read-label" - flagTableConcurrency = "table-concurrency" - - flagGCTTL = "gcttl" - - defaultBackupConcurrency = 4 - maxBackupConcurrency = 256 -) - -const ( - FullBackupCmd = "Full Backup" - DBBackupCmd = "Database Backup" - TableBackupCmd = "Table Backup" - RawBackupCmd = "Raw Backup" - TxnBackupCmd = "Txn Backup" - EBSBackupCmd = "EBS Backup" -) - -// CompressionConfig is the configuration for sst file compression. -type CompressionConfig struct { - CompressionType backuppb.CompressionType `json:"compression-type" toml:"compression-type"` - CompressionLevel int32 `json:"compression-level" toml:"compression-level"` -} - -// BackupConfig is the configuration specific for backup tasks. 
-type BackupConfig struct { - Config - - TimeAgo time.Duration `json:"time-ago" toml:"time-ago"` - BackupTS uint64 `json:"backup-ts" toml:"backup-ts"` - LastBackupTS uint64 `json:"last-backup-ts" toml:"last-backup-ts"` - GCTTL int64 `json:"gc-ttl" toml:"gc-ttl"` - RemoveSchedulers bool `json:"remove-schedulers" toml:"remove-schedulers"` - IgnoreStats bool `json:"ignore-stats" toml:"ignore-stats"` - UseBackupMetaV2 bool `json:"use-backupmeta-v2"` - UseCheckpoint bool `json:"use-checkpoint" toml:"use-checkpoint"` - ReplicaReadLabel map[string]string `json:"replica-read-label" toml:"replica-read-label"` - TableConcurrency uint `json:"table-concurrency" toml:"table-concurrency"` - CompressionConfig - - // for ebs-based backup - FullBackupType FullBackupType `json:"full-backup-type" toml:"full-backup-type"` - VolumeFile string `json:"volume-file" toml:"volume-file"` - SkipAWS bool `json:"skip-aws" toml:"skip-aws"` - CloudAPIConcurrency uint `json:"cloud-api-concurrency" toml:"cloud-api-concurrency"` - ProgressFile string `json:"progress-file" toml:"progress-file"` - SkipPauseGCAndScheduler bool `json:"skip-pause-gc-and-scheduler" toml:"skip-pause-gc-and-scheduler"` -} - -// DefineBackupFlags defines common flags for the backup command. -func DefineBackupFlags(flags *pflag.FlagSet) { - flags.Duration( - flagBackupTimeago, 0, - "The history version of the backup task, e.g. 1m, 1h. Do not exceed GCSafePoint") - - // TODO: remove experimental tag if it's stable - flags.Uint64(flagLastBackupTS, 0, "(experimental) the last time backup ts,"+ - " use for incremental backup, support TSO only") - flags.String(flagBackupTS, "", "the backup ts support TSO or datetime,"+ - " e.g. '400036290571534337', '2018-05-11 01:42:23'") - flags.Int64(flagGCTTL, utils.DefaultBRGCSafePointTTL, "the TTL (in seconds) that PD holds for BR's GC safepoint") - flags.String(flagCompressionType, "zstd", - "backup sst file compression algorithm, value can be one of 'lz4|zstd|snappy'") - flags.Int32(flagCompressionLevel, 0, "compression level used for sst file compression") - - flags.Uint32(flagConcurrency, 4, "The size of a BR thread pool that executes tasks, "+ - "One task represents one table range (or one index range) according to the backup schemas. If there is one table with one index."+ - "there will be two tasks to back up this table. This value should increase if you need to back up lots of tables or indices.") - - flags.Uint(flagTableConcurrency, backup.DefaultSchemaConcurrency, "The size of a BR thread pool used for backup table metas, "+ - "including tableInfo/checksum and stats.") - - flags.Bool(flagRemoveSchedulers, false, - "disable the balance, shuffle and region-merge schedulers in PD to speed up backup") - // This flag can impact the online cluster, so hide it in case of abuse. - _ = flags.MarkHidden(flagRemoveSchedulers) - - // Disable stats by default. - // TODO: we need a better way to backup/restore stats. - flags.Bool(flagIgnoreStats, true, "ignore backup stats") - - flags.Bool(flagUseBackupMetaV2, true, - "use backup meta v2 to store meta info") - - flags.String(flagKeyspaceName, "", "keyspace name for backup") - // This flag will change the structure of backupmeta. - // we must make sure the old three version of br can parse the v2 meta to keep compatibility. - // so this flag should set to false for three version by default. - // for example: - // if we put this feature in v4.0.14, then v4.0.14 br can parse v2 meta - // but will generate v1 meta due to this flag is false. 
the behaviour is as same as v4.0.15, v4.0.16. - // finally v4.0.17 will set this flag to true, and generate v2 meta. - // - // the version currently is v7.4.0, the flag can be set to true as default value. - // _ = flags.MarkHidden(flagUseBackupMetaV2) - - flags.Bool(flagUseCheckpoint, true, "use checkpoint mode") - _ = flags.MarkHidden(flagUseCheckpoint) - - flags.String(flagReplicaReadLabel, "", "specify the label of the stores to be used for backup, e.g. 'label_key:label_value'") -} - -// ParseFromFlags parses the backup-related flags from the flag set. -func (cfg *BackupConfig) ParseFromFlags(flags *pflag.FlagSet) error { - timeAgo, err := flags.GetDuration(flagBackupTimeago) - if err != nil { - return errors.Trace(err) - } - if timeAgo < 0 { - return errors.Annotate(berrors.ErrInvalidArgument, "negative timeago is not allowed") - } - cfg.TimeAgo = timeAgo - cfg.LastBackupTS, err = flags.GetUint64(flagLastBackupTS) - if err != nil { - return errors.Trace(err) - } - backupTS, err := flags.GetString(flagBackupTS) - if err != nil { - return errors.Trace(err) - } - cfg.BackupTS, err = ParseTSString(backupTS, false) - if err != nil { - return errors.Trace(err) - } - cfg.UseBackupMetaV2, err = flags.GetBool(flagUseBackupMetaV2) - if err != nil { - return errors.Trace(err) - } - cfg.UseCheckpoint, err = flags.GetBool(flagUseCheckpoint) - if err != nil { - return errors.Trace(err) - } - if cfg.LastBackupTS > 0 { - // TODO: compatible with incremental backup - cfg.UseCheckpoint = false - log.Info("since incremental backup is used, turn off checkpoint mode") - } - gcTTL, err := flags.GetInt64(flagGCTTL) - if err != nil { - return errors.Trace(err) - } - cfg.GCTTL = gcTTL - cfg.Concurrency, err = flags.GetUint32(flagConcurrency) - if err != nil { - return errors.Trace(err) - } - if cfg.TableConcurrency, err = flags.GetUint(flagTableConcurrency); err != nil { - return errors.Trace(err) - } - - compressionCfg, err := parseCompressionFlags(flags) - if err != nil { - return errors.Trace(err) - } - cfg.CompressionConfig = *compressionCfg - - if err = cfg.Config.ParseFromFlags(flags); err != nil { - return errors.Trace(err) - } - cfg.RemoveSchedulers, err = flags.GetBool(flagRemoveSchedulers) - if err != nil { - return errors.Trace(err) - } - cfg.IgnoreStats, err = flags.GetBool(flagIgnoreStats) - if err != nil { - return errors.Trace(err) - } - cfg.KeyspaceName, err = flags.GetString(flagKeyspaceName) - if err != nil { - return errors.Trace(err) - } - - if flags.Lookup(flagFullBackupType) != nil { - // for backup full - fullBackupType, err := flags.GetString(flagFullBackupType) - if err != nil { - return errors.Trace(err) - } - if !FullBackupType(fullBackupType).Valid() { - return errors.New("invalid full backup type") - } - cfg.FullBackupType = FullBackupType(fullBackupType) - cfg.SkipAWS, err = flags.GetBool(flagSkipAWS) - if err != nil { - return errors.Trace(err) - } - cfg.CloudAPIConcurrency, err = flags.GetUint(flagCloudAPIConcurrency) - if err != nil { - return errors.Trace(err) - } - cfg.VolumeFile, err = flags.GetString(flagBackupVolumeFile) - if err != nil { - return errors.Trace(err) - } - cfg.ProgressFile, err = flags.GetString(flagProgressFile) - if err != nil { - return errors.Trace(err) - } - cfg.SkipPauseGCAndScheduler, err = flags.GetBool(flagOperatorPausedGCAndSchedulers) - if err != nil { - return errors.Trace(err) - } - } - - cfg.ReplicaReadLabel, err = parseReplicaReadLabelFlag(flags) - if err != nil { - return errors.Trace(err) - } - - return nil -} - -// parseCompressionFlags 
parses the backup-related flags from the flag set. -func parseCompressionFlags(flags *pflag.FlagSet) (*CompressionConfig, error) { - compressionStr, err := flags.GetString(flagCompressionType) - if err != nil { - return nil, errors.Trace(err) - } - compressionType, err := parseCompressionType(compressionStr) - if err != nil { - return nil, errors.Trace(err) - } - level, err := flags.GetInt32(flagCompressionLevel) - if err != nil { - return nil, errors.Trace(err) - } - return &CompressionConfig{ - CompressionLevel: level, - CompressionType: compressionType, - }, nil -} - -// Adjust is use for BR(binary) and BR in TiDB. -// When new config was add and not included in parser. -// we should set proper value in this function. -// so that both binary and TiDB will use same default value. -func (cfg *BackupConfig) Adjust() { - cfg.adjust() - usingDefaultConcurrency := false - if cfg.Config.Concurrency == 0 { - cfg.Config.Concurrency = defaultBackupConcurrency - usingDefaultConcurrency = true - } - if cfg.Config.Concurrency > maxBackupConcurrency { - cfg.Config.Concurrency = maxBackupConcurrency - } - if cfg.RateLimit != unlimited { - // TiKV limits the upload rate by each backup request. - // When the backup requests are sent concurrently, - // the ratelimit couldn't work as intended. - // Degenerating to sequentially sending backup requests to avoid this. - if !usingDefaultConcurrency { - logutil.WarnTerm("setting `--ratelimit` and `--concurrency` at the same time, "+ - "ignoring `--concurrency`: `--ratelimit` forces sequential (i.e. concurrency = 1) backup", - zap.String("ratelimit", units.HumanSize(float64(cfg.RateLimit))+"/s"), - zap.Uint32("concurrency-specified", cfg.Config.Concurrency)) - } - cfg.Config.Concurrency = 1 - } - - if cfg.GCTTL == 0 { - cfg.GCTTL = utils.DefaultBRGCSafePointTTL - } - // Use zstd as default - if cfg.CompressionType == backuppb.CompressionType_UNKNOWN { - cfg.CompressionType = backuppb.CompressionType_ZSTD - } - if cfg.CloudAPIConcurrency == 0 { - cfg.CloudAPIConcurrency = defaultCloudAPIConcurrency - } -} - -type immutableBackupConfig struct { - LastBackupTS uint64 `json:"last-backup-ts"` - IgnoreStats bool `json:"ignore-stats"` - UseCheckpoint bool `json:"use-checkpoint"` - - storage.BackendOptions - Storage string `json:"storage"` - PD []string `json:"pd"` - SendCreds bool `json:"send-credentials-to-tikv"` - NoCreds bool `json:"no-credentials"` - FilterStr []string `json:"filter-strings"` - CipherInfo backuppb.CipherInfo `json:"cipher"` - KeyspaceName string `json:"keyspace-name"` -} - -// a rough hash for checkpoint checker -func (cfg *BackupConfig) Hash() ([]byte, error) { - config := &immutableBackupConfig{ - LastBackupTS: cfg.LastBackupTS, - IgnoreStats: cfg.IgnoreStats, - UseCheckpoint: cfg.UseCheckpoint, - - BackendOptions: cfg.BackendOptions, - Storage: cfg.Storage, - PD: cfg.PD, - SendCreds: cfg.SendCreds, - NoCreds: cfg.NoCreds, - FilterStr: cfg.FilterStr, - CipherInfo: cfg.CipherInfo, - KeyspaceName: cfg.KeyspaceName, - } - data, err := json.Marshal(config) - if err != nil { - return nil, errors.Trace(err) - } - hash := sha256.Sum256(data) - - return hash[:], nil -} - -func isFullBackup(cmdName string) bool { - return cmdName == FullBackupCmd -} - -// RunBackup starts a backup task inside the current goroutine. 
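// A minimal, illustrative way to drive RunBackup programmatically, assuming a
// glue.Glue implementation `g` is available; the PD address and storage URI
// below are placeholders, not values from this patch:
func runFullBackup(ctx context.Context, g glue.Glue) error {
	cfg := DefaultBackupConfig()
	cfg.PD = []string{"127.0.0.1:2379"}
	cfg.Storage = "local:///tmp/backup-demo"
	cfg.CompressionType = backuppb.CompressionType_ZSTD // zstd is also the flag default
	return RunBackup(ctx, g, FullBackupCmd, &cfg)
}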
-func RunBackup(c context.Context, g glue.Glue, cmdName string, cfg *BackupConfig) error { - cfg.Adjust() - config.UpdateGlobal(func(conf *config.Config) { - conf.KeyspaceName = cfg.KeyspaceName - }) - - defer summary.Summary(cmdName) - ctx, cancel := context.WithCancel(c) - defer cancel() - - if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { - span1 := span.Tracer().StartSpan("task.RunBackup", opentracing.ChildOf(span.Context())) - defer span1.Finish() - ctx = opentracing.ContextWithSpan(ctx, span1) - } - - u, err := storage.ParseBackend(cfg.Storage, &cfg.BackendOptions) - if err != nil { - return errors.Trace(err) - } - // if use noop as external storage, turn off the checkpoint mode - if u.GetNoop() != nil { - log.Info("since noop external storage is used, turn off checkpoint mode") - cfg.UseCheckpoint = false - } - skipStats := cfg.IgnoreStats - // For backup, Domain is not needed if user ignores stats. - // Domain loads all table info into memory. By skipping Domain, we save - // lots of memory (about 500MB for 40K 40 fields YCSB tables). - needDomain := !skipStats - mgr, err := NewMgr(ctx, g, cfg.PD, cfg.TLS, GetKeepalive(&cfg.Config), cfg.CheckRequirements, needDomain, conn.NormalVersionChecker) - if err != nil { - return errors.Trace(err) - } - defer mgr.Close() - // after version check, check the cluster whether support checkpoint mode - if cfg.UseCheckpoint { - err = version.CheckCheckpointSupport() - if err != nil { - log.Warn("unable to use checkpoint mode, fall back to normal mode", zap.Error(err)) - cfg.UseCheckpoint = false - } - } - var statsHandle *handle.Handle - if !skipStats { - statsHandle = mgr.GetDomain().StatsHandle() - } - - var newCollationEnable string - err = g.UseOneShotSession(mgr.GetStorage(), !needDomain, func(se glue.Session) error { - newCollationEnable, err = se.GetGlobalVariable(utils.GetTidbNewCollationEnabled()) - if err != nil { - return errors.Trace(err) - } - log.Info(fmt.Sprintf("get %s config from mysql.tidb table", utils.TidbNewCollationEnabled), - zap.String(utils.GetTidbNewCollationEnabled(), newCollationEnable)) - return nil - }) - if err != nil { - return errors.Trace(err) - } - - client := backup.NewBackupClient(ctx, mgr) - - // set cipher only for checkpoint - client.SetCipher(&cfg.CipherInfo) - - opts := storage.ExternalStorageOptions{ - NoCredentials: cfg.NoCreds, - SendCredentials: cfg.SendCreds, - CheckS3ObjectLockOptions: true, - } - if err = client.SetStorageAndCheckNotInUse(ctx, u, &opts); err != nil { - return errors.Trace(err) - } - // if checkpoint mode is unused at this time but there is checkpoint meta, - // CheckCheckpoint will stop backing up - cfgHash, err := cfg.Hash() - if err != nil { - return errors.Trace(err) - } - err = client.CheckCheckpoint(cfgHash) - if err != nil { - return errors.Trace(err) - } - err = client.SetLockFile(ctx) - if err != nil { - return errors.Trace(err) - } - // if use checkpoint and gcTTL is the default value - // update gcttl to checkpoint's default gc ttl - if cfg.UseCheckpoint && cfg.GCTTL == utils.DefaultBRGCSafePointTTL { - cfg.GCTTL = utils.DefaultCheckpointGCSafePointTTL - log.Info("use checkpoint's default GC TTL", zap.Int64("GC TTL", cfg.GCTTL)) - } - client.SetGCTTL(cfg.GCTTL) - - backupTS, err := client.GetTS(ctx, cfg.TimeAgo, cfg.BackupTS) - if err != nil { - return errors.Trace(err) - } - g.Record("BackupTS", backupTS) - safePointID := client.GetSafePointID() - sp := utils.BRServiceSafePoint{ - BackupTS: backupTS, - TTL: client.GetGCTTL(), - ID: 
safePointID, - } - - // use lastBackupTS as safePoint if exists - isIncrementalBackup := cfg.LastBackupTS > 0 - if isIncrementalBackup { - sp.BackupTS = cfg.LastBackupTS - } - - log.Info("current backup safePoint job", zap.Object("safePoint", sp)) - cctx, gcSafePointKeeperCancel := context.WithCancel(ctx) - gcSafePointKeeperRemovable := false - defer func() { - // don't reset the gc-safe-point if checkpoint mode is used and backup is not finished - if cfg.UseCheckpoint && !gcSafePointKeeperRemovable { - log.Info("skip removing gc-safepoint keeper for next retry", zap.String("gc-id", sp.ID)) - return - } - log.Info("start to remove gc-safepoint keeper") - // close the gc safe point keeper at first - gcSafePointKeeperCancel() - // set the ttl to 0 to remove the gc-safe-point - sp.TTL = 0 - if err := utils.UpdateServiceSafePoint(ctx, mgr.GetPDClient(), sp); err != nil { - log.Warn("failed to update service safe point, backup may fail if gc triggered", - zap.Error(err), - ) - } - log.Info("finish removing gc-safepoint keeper") - }() - err = utils.StartServiceSafePointKeeper(cctx, mgr.GetPDClient(), sp) - if err != nil { - return errors.Trace(err) - } - - if cfg.RemoveSchedulers { - log.Debug("removing some PD schedulers") - restore, e := mgr.RemoveSchedulers(ctx) - defer func() { - if ctx.Err() != nil { - log.Warn("context canceled, doing clean work with background context") - ctx = context.Background() - } - if restoreE := restore(ctx); restoreE != nil { - log.Warn("failed to restore removed schedulers, you may need to restore them manually", zap.Error(restoreE)) - } - }() - if e != nil { - return errors.Trace(err) - } - } - - req := backuppb.BackupRequest{ - ClusterId: client.GetClusterID(), - StartVersion: cfg.LastBackupTS, - EndVersion: backupTS, - RateLimit: cfg.RateLimit, - StorageBackend: client.GetStorageBackend(), - Concurrency: defaultBackupConcurrency, - CompressionType: cfg.CompressionType, - CompressionLevel: cfg.CompressionLevel, - CipherInfo: &cfg.CipherInfo, - ReplicaRead: len(cfg.ReplicaReadLabel) != 0, - Context: &kvrpcpb.Context{ - ResourceControlContext: &kvrpcpb.ResourceControlContext{ - ResourceGroupName: "", // TODO, - }, - RequestSource: kvutil.BuildRequestSource(true, kv.InternalTxnBR, kvutil.ExplicitTypeBR), - }, - } - brVersion := g.GetVersion() - clusterVersion, err := mgr.GetClusterVersion(ctx) - if err != nil { - return errors.Trace(err) - } - - ranges, schemas, policies, err := client.BuildBackupRangeAndSchema(mgr.GetStorage(), cfg.TableFilter, backupTS, isFullBackup(cmdName)) - if err != nil { - return errors.Trace(err) - } - - // Metafile size should be less than 64MB. - metawriter := metautil.NewMetaWriter(client.GetStorage(), - metautil.MetaFileSize, cfg.UseBackupMetaV2, "", &cfg.CipherInfo) - // Hack way to update backupmeta. 
- metawriter.Update(func(m *backuppb.BackupMeta) { - m.StartVersion = req.StartVersion - m.EndVersion = req.EndVersion - m.IsRawKv = req.IsRawKv - m.ClusterId = req.ClusterId - m.ClusterVersion = clusterVersion - m.BrVersion = brVersion - m.NewCollationsEnabled = newCollationEnable - m.ApiVersion = mgr.GetStorage().GetCodec().GetAPIVersion() - }) - - log.Info("get placement policies", zap.Int("count", len(policies))) - if len(policies) != 0 { - metawriter.Update(func(m *backuppb.BackupMeta) { - m.Policies = policies - }) - } - - // nothing to backup - if len(ranges) == 0 { - pdAddress := strings.Join(cfg.PD, ",") - log.Warn("Nothing to backup, maybe connected to cluster for restoring", - zap.String("PD address", pdAddress)) - - err = metawriter.FlushBackupMeta(ctx) - if err == nil { - summary.SetSuccessStatus(true) - } - return err - } - - if isIncrementalBackup { - if backupTS <= cfg.LastBackupTS { - log.Error("LastBackupTS is larger or equal to current TS") - return errors.Annotate(berrors.ErrInvalidArgument, "LastBackupTS is larger or equal to current TS") - } - err = utils.CheckGCSafePoint(ctx, mgr.GetPDClient(), cfg.LastBackupTS) - if err != nil { - log.Error("Check gc safepoint for last backup ts failed", zap.Error(err)) - return errors.Trace(err) - } - - metawriter.StartWriteMetasAsync(ctx, metautil.AppendDDL) - err = backup.WriteBackupDDLJobs(metawriter, g, mgr.GetStorage(), cfg.LastBackupTS, backupTS, needDomain) - if err != nil { - return errors.Trace(err) - } - if err = metawriter.FinishWriteMetas(ctx, metautil.AppendDDL); err != nil { - return errors.Trace(err) - } - } - - summary.CollectInt("backup total ranges", len(ranges)) - - approximateRegions, err := getRegionCountOfRanges(ctx, mgr, ranges) - if err != nil { - return errors.Trace(err) - } - // Redirect to log if there is no log file to avoid unreadable output. 
- updateCh := g.StartProgress( - ctx, cmdName, int64(approximateRegions), !cfg.LogProgress) - summary.CollectInt("backup total regions", approximateRegions) - - progressCount := uint64(0) - progressCallBack := func() { - updateCh.Inc() - failpoint.Inject("progress-call-back", func(v failpoint.Value) { - log.Info("failpoint progress-call-back injected") - atomic.AddUint64(&progressCount, 1) - if fileName, ok := v.(string); ok { - f, osErr := os.OpenFile(fileName, os.O_CREATE|os.O_WRONLY, os.ModePerm) - if osErr != nil { - log.Warn("failed to create file", zap.Error(osErr)) - } - msg := []byte(fmt.Sprintf("region:%d\n", atomic.LoadUint64(&progressCount))) - _, err = f.Write(msg) - if err != nil { - log.Warn("failed to write data to file", zap.Error(err)) - } - } - }) - } - - if cfg.UseCheckpoint { - if err = client.StartCheckpointRunner(ctx, cfgHash, backupTS, ranges, safePointID, progressCallBack); err != nil { - return errors.Trace(err) - } - defer func() { - if !gcSafePointKeeperRemovable { - log.Info("wait for flush checkpoint...") - client.WaitForFinishCheckpoint(ctx, true) - } else { - log.Info("start to remove checkpoint data for backup") - client.WaitForFinishCheckpoint(ctx, false) - if removeErr := checkpoint.RemoveCheckpointDataForBackup(ctx, client.GetStorage()); removeErr != nil { - log.Warn("failed to remove checkpoint data for backup", zap.Error(removeErr)) - } else { - log.Info("the checkpoint data for backup is removed.") - } - } - }() - } - - failpoint.Inject("s3-outage-during-writing-file", func(v failpoint.Value) { - log.Info("failpoint s3-outage-during-writing-file injected, " + - "process will sleep for 5s and notify the shell to kill s3 service.") - if sigFile, ok := v.(string); ok { - file, err := os.Create(sigFile) - if err != nil { - log.Warn("failed to create file for notifying, skipping notify", zap.Error(err)) - } - if file != nil { - file.Close() - } - } - time.Sleep(5 * time.Second) - }) - - metawriter.StartWriteMetasAsync(ctx, metautil.AppendDataFile) - err = client.BackupRanges(ctx, ranges, req, uint(cfg.Concurrency), cfg.ReplicaReadLabel, metawriter, progressCallBack) - if err != nil { - return errors.Trace(err) - } - // Backup has finished - updateCh.Close() - - err = metawriter.FinishWriteMetas(ctx, metautil.AppendDataFile) - if err != nil { - return errors.Trace(err) - } - - skipChecksum := !cfg.Checksum || isIncrementalBackup - checksumProgress := int64(schemas.Len()) - if skipChecksum { - checksumProgress = 1 - if isIncrementalBackup { - // Since we don't support checksum for incremental data, fast checksum should be skipped. - log.Info("Skip fast checksum in incremental backup") - } else { - // When user specified not to calculate checksum, don't calculate checksum. - log.Info("Skip fast checksum") - } - } - updateCh = g.StartProgress(ctx, "Checksum", checksumProgress, !cfg.LogProgress) - schemasConcurrency := min(cfg.TableConcurrency, uint(schemas.Len())) - - err = schemas.BackupSchemas( - ctx, metawriter, client.GetCheckpointRunner(), mgr.GetStorage(), statsHandle, backupTS, schemasConcurrency, cfg.ChecksumConcurrency, skipChecksum, updateCh) - if err != nil { - return errors.Trace(err) - } - - err = metawriter.FlushBackupMeta(ctx) - if err != nil { - return errors.Trace(err) - } - // Since backupmeta is flushed on the external storage, - // we can remove the gc safepoint keeper - gcSafePointKeeperRemovable = true - - // Checksum has finished, close checksum progress. 
- updateCh.Close() - - if !skipChecksum { - // Check if checksum from files matches checksum from coprocessor. - err = checksum.FastChecksum(ctx, metawriter.Backupmeta(), client.GetStorage(), &cfg.CipherInfo) - if err != nil { - return errors.Trace(err) - } - } - archiveSize := metawriter.ArchiveSize() - g.Record(summary.BackupDataSize, archiveSize) - //backup from tidb will fetch a general Size issue https://github.com/pingcap/tidb/issues/27247 - g.Record("Size", archiveSize) - // Set task summary to success status. - summary.SetSuccessStatus(true) - return nil -} - -func getRegionCountOfRanges( - ctx context.Context, - mgr *conn.Mgr, - ranges []rtree.Range, -) (int, error) { - // The number of regions need to backup - approximateRegions := 0 - for _, r := range ranges { - regionCount, err := mgr.GetRegionCount(ctx, r.StartKey, r.EndKey) - if err != nil { - return 0, errors.Trace(err) - } - approximateRegions += regionCount - } - return approximateRegions, nil -} - -// ParseTSString port from tidb setSnapshotTS. -func ParseTSString(ts string, tzCheck bool) (uint64, error) { - if len(ts) == 0 { - return 0, nil - } - if tso, err := strconv.ParseUint(ts, 10, 64); err == nil { - return tso, nil - } - - loc := time.Local - sc := stmtctx.NewStmtCtxWithTimeZone(loc) - if tzCheck { - tzIdx, _, _, _, _ := types.GetTimezone(ts) - if tzIdx < 0 { - return 0, errors.Errorf("must set timezone when using datetime format ts, e.g. '2018-05-11 01:42:23+0800'") - } - } - t, err := types.ParseTime(sc.TypeCtx(), ts, mysql.TypeTimestamp, types.MaxFsp) - if err != nil { - return 0, errors.Trace(err) - } - t1, err := t.GoTime(loc) - if err != nil { - return 0, errors.Trace(err) - } - return oracle.GoTimeToTS(t1), nil -} - -func DefaultBackupConfig() BackupConfig { - fs := pflag.NewFlagSet("dummy", pflag.ContinueOnError) - DefineCommonFlags(fs) - DefineBackupFlags(fs) - cfg := BackupConfig{} - err := multierr.Combine( - cfg.ParseFromFlags(fs), - cfg.Config.ParseFromFlags(fs), - ) - if err != nil { - log.Panic("infallible operation failed.", zap.Error(err)) - } - return cfg -} - -func parseCompressionType(s string) (backuppb.CompressionType, error) { - var ct backuppb.CompressionType - switch s { - case "lz4": - ct = backuppb.CompressionType_LZ4 - case "snappy": - ct = backuppb.CompressionType_SNAPPY - case "zstd": - ct = backuppb.CompressionType_ZSTD - default: - return backuppb.CompressionType_UNKNOWN, errors.Annotatef(berrors.ErrInvalidArgument, "invalid compression type '%s'", s) - } - return ct, nil -} - -func parseReplicaReadLabelFlag(flags *pflag.FlagSet) (map[string]string, error) { - replicaReadLabelStr, err := flags.GetString(flagReplicaReadLabel) - if err != nil { - return nil, errors.Trace(err) - } - if replicaReadLabelStr == "" { - return nil, nil - } - kv := strings.Split(replicaReadLabelStr, ":") - if len(kv) != 2 { - return nil, errors.Annotatef(berrors.ErrInvalidArgument, "invalid replica read label '%s'", replicaReadLabelStr) - } - return map[string]string{kv[0]: kv[1]}, nil -} diff --git a/br/pkg/task/binding__failpoint_binding__.go b/br/pkg/task/binding__failpoint_binding__.go deleted file mode 100644 index ecd81d7d48d63..0000000000000 --- a/br/pkg/task/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package task - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) 
string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/br/pkg/task/operator/binding__failpoint_binding__.go b/br/pkg/task/operator/binding__failpoint_binding__.go deleted file mode 100644 index fc7dd0d4f3fa7..0000000000000 --- a/br/pkg/task/operator/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package operator - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/br/pkg/task/operator/cmd.go b/br/pkg/task/operator/cmd.go index 5ae421460737c..cbe5c3ac2442b 100644 --- a/br/pkg/task/operator/cmd.go +++ b/br/pkg/task/operator/cmd.go @@ -146,9 +146,9 @@ func AdaptEnvForSnapshotBackup(ctx context.Context, cfg *PauseGcConfig) error { }) cx.run(func() error { return pauseAdminAndWaitApply(cx, initChan) }) go func() { - if _, _err_ := failpoint.Eval(_curpkg_("SkipReadyHint")); _err_ == nil { - return - } + failpoint.Inject("SkipReadyHint", func() { + failpoint.Return() + }) cx.rdGrp.Wait() if cfg.OnAllReady != nil { cfg.OnAllReady() diff --git a/br/pkg/task/operator/cmd.go__failpoint_stash__ b/br/pkg/task/operator/cmd.go__failpoint_stash__ deleted file mode 100644 index cbe5c3ac2442b..0000000000000 --- a/br/pkg/task/operator/cmd.go__failpoint_stash__ +++ /dev/null @@ -1,251 +0,0 @@ -// Copyright 2023 PingCAP, Inc. Licensed under Apache-2.0. - -package operator - -import ( - "context" - "crypto/tls" - "runtime/debug" - "sync" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/log" - preparesnap "github.com/pingcap/tidb/br/pkg/backup/prepare_snap" - berrors "github.com/pingcap/tidb/br/pkg/errors" - "github.com/pingcap/tidb/br/pkg/logutil" - "github.com/pingcap/tidb/br/pkg/pdutil" - "github.com/pingcap/tidb/br/pkg/task" - "github.com/pingcap/tidb/br/pkg/utils" - "github.com/tikv/client-go/v2/tikv" - "go.uber.org/multierr" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" - "google.golang.org/grpc/keepalive" -) - -func dialPD(ctx context.Context, cfg *task.Config) (*pdutil.PdController, error) { - var tc *tls.Config - if cfg.TLS.IsEnabled() { - var err error - tc, err = cfg.TLS.ToTLSConfig() - if err != nil { - return nil, err - } - } - mgr, err := pdutil.NewPdController(ctx, cfg.PD, tc, cfg.TLS.ToPDSecurityOption()) - if err != nil { - return nil, err - } - return mgr, nil -} - -func (cx *AdaptEnvForSnapshotBackupContext) cleanUpWith(f func(ctx context.Context)) { - cx.cleanUpWithRetErr(nil, func(ctx context.Context) error { f(ctx); return nil }) -} - -func (cx *AdaptEnvForSnapshotBackupContext) cleanUpWithRetErr(errOut *error, f func(ctx context.Context) error) { - ctx, cancel := context.WithTimeout(context.Background(), cx.cfg.TTL) - defer cancel() - err := f(ctx) - if errOut != nil { - *errOut = multierr.Combine(*errOut, err) - } -} - -func (cx *AdaptEnvForSnapshotBackupContext) run(f func() error) { - cx.rdGrp.Add(1) - buf := debug.Stack() - cx.runGrp.Go(func() error { - err := f() - if err != nil { - log.Error("A task failed.", zap.Error(err), zap.ByteString("task-created-at", buf)) - } - return err - }) -} - -type AdaptEnvForSnapshotBackupContext struct { - context.Context - - pdMgr *pdutil.PdController - kvMgr *utils.StoreManager - cfg PauseGcConfig - - rdGrp sync.WaitGroup - runGrp *errgroup.Group -} - -func (cx 
*AdaptEnvForSnapshotBackupContext) Close() { - cx.pdMgr.Close() - cx.kvMgr.Close() -} - -func (cx *AdaptEnvForSnapshotBackupContext) GetBackOffer(operation string) utils.Backoffer { - state := utils.InitialRetryState(64, 1*time.Second, 10*time.Second) - bo := utils.GiveUpRetryOn(&state, berrors.ErrPossibleInconsistency) - bo = utils.VerboseRetry(bo, logutil.CL(cx).With(zap.String("operation", operation))) - return bo -} - -func (cx *AdaptEnvForSnapshotBackupContext) ReadyL(name string, notes ...zap.Field) { - logutil.CL(cx).Info("Stage ready.", append(notes, zap.String("component", name))...) - cx.rdGrp.Done() -} - -func hintAllReady() { - // Hacking: some version of operators using the follow two logs to check whether we are ready... - log.Info("Schedulers are paused.") - log.Info("GC is paused.") - log.Info("All ready.") -} - -// AdaptEnvForSnapshotBackup blocks the current goroutine and pause the GC safepoint and remove the scheduler by the config. -// This function will block until the context being canceled. -func AdaptEnvForSnapshotBackup(ctx context.Context, cfg *PauseGcConfig) error { - utils.DumpGoroutineWhenExit.Store(true) - mgr, err := dialPD(ctx, &cfg.Config) - if err != nil { - return errors.Annotate(err, "failed to dial PD") - } - mgr.SchedulerPauseTTL = cfg.TTL - var tconf *tls.Config - if cfg.TLS.IsEnabled() { - tconf, err = cfg.TLS.ToTLSConfig() - if err != nil { - return errors.Annotate(err, "invalid tls config") - } - } - kvMgr := utils.NewStoreManager(mgr.GetPDClient(), keepalive.ClientParameters{ - Time: cfg.Config.GRPCKeepaliveTime, - Timeout: cfg.Config.GRPCKeepaliveTimeout, - }, tconf) - eg, ectx := errgroup.WithContext(ctx) - cx := &AdaptEnvForSnapshotBackupContext{ - Context: logutil.ContextWithField(ectx, zap.String("tag", "br_operator")), - pdMgr: mgr, - kvMgr: kvMgr, - cfg: *cfg, - rdGrp: sync.WaitGroup{}, - runGrp: eg, - } - defer cx.Close() - - initChan := make(chan struct{}) - cx.run(func() error { return pauseGCKeeper(cx) }) - cx.run(func() error { - log.Info("Pause scheduler waiting all connections established.") - select { - case <-initChan: - case <-cx.Done(): - return cx.Err() - } - log.Info("Pause scheduler noticed connections established.") - return pauseSchedulerKeeper(cx) - }) - cx.run(func() error { return pauseAdminAndWaitApply(cx, initChan) }) - go func() { - failpoint.Inject("SkipReadyHint", func() { - failpoint.Return() - }) - cx.rdGrp.Wait() - if cfg.OnAllReady != nil { - cfg.OnAllReady() - } - utils.DumpGoroutineWhenExit.Store(false) - hintAllReady() - }() - defer func() { - if cfg.OnExit != nil { - cfg.OnExit() - } - }() - - return eg.Wait() -} - -func pauseAdminAndWaitApply(cx *AdaptEnvForSnapshotBackupContext, afterConnectionsEstablished chan<- struct{}) error { - env := preparesnap.CliEnv{ - Cache: tikv.NewRegionCache(cx.pdMgr.GetPDClient()), - Mgr: cx.kvMgr, - } - defer env.Cache.Close() - retryEnv := preparesnap.RetryAndSplitRequestEnv{Env: env} - begin := time.Now() - prep := preparesnap.New(retryEnv) - prep.LeaseDuration = cx.cfg.TTL - prep.AfterConnectionsEstablished = func() { - log.Info("All connections are stablished.") - close(afterConnectionsEstablished) - } - - defer cx.cleanUpWith(func(ctx context.Context) { - if err := prep.Finalize(ctx); err != nil { - logutil.CL(ctx).Warn("failed to finalize the prepare stream", logutil.ShortError(err)) - } - }) - - // We must use our own context here, or once we are cleaning up the client will be invalid. 
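// Illustrative-only sketch (not part of this patch) of the detached-context cleanup
// pattern used by cleanUpWith/cleanUpWithRetErr above: by the time a cleanup callback
// runs, the task's own context is usually already canceled, so a fresh context with a
// TTL-bounded deadline is derived from context.Background() instead. The helper name
// below is hypothetical.
package sketch

import (
	"context"
	"time"
)

func cleanupWithFreshContext(ttl time.Duration, f func(ctx context.Context) error) error {
	// A detached context keeps the cleanup alive even after the parent context
	// (the one the task ran under) has been canceled.
	ctx, cancel := context.WithTimeout(context.Background(), ttl)
	defer cancel()
	return f(ctx)
}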
- myCtx := logutil.ContextWithField(context.Background(), zap.String("category", "pause_admin_and_wait_apply")) - if err := prep.DriveLoopAndWaitPrepare(myCtx); err != nil { - return err - } - - cx.ReadyL("pause_admin_and_wait_apply", zap.Stringer("take", time.Since(begin))) - <-cx.Done() - return nil -} - -func pauseGCKeeper(cx *AdaptEnvForSnapshotBackupContext) (err error) { - // Note: should we remove the service safepoint as soon as this exits? - sp := utils.BRServiceSafePoint{ - ID: utils.MakeSafePointID(), - TTL: int64(cx.cfg.TTL.Seconds()), - BackupTS: cx.cfg.SafePoint, - } - if sp.BackupTS == 0 { - rts, err := cx.pdMgr.GetMinResolvedTS(cx) - if err != nil { - return err - } - logutil.CL(cx).Info("No service safepoint provided, using the minimal resolved TS.", zap.Uint64("min-resolved-ts", rts)) - sp.BackupTS = rts - } - err = utils.StartServiceSafePointKeeper(cx, cx.pdMgr.GetPDClient(), sp) - if err != nil { - return err - } - cx.ReadyL("pause_gc", zap.Object("safepoint", sp)) - defer cx.cleanUpWithRetErr(&err, func(ctx context.Context) error { - cancelSP := utils.BRServiceSafePoint{ - ID: sp.ID, - TTL: 0, - } - return utils.UpdateServiceSafePoint(ctx, cx.pdMgr.GetPDClient(), cancelSP) - }) - // Note: in fact we can directly return here. - // But the name `keeper` implies once the function exits, - // the GC should be resume, so let's block here. - <-cx.Done() - return nil -} - -func pauseSchedulerKeeper(ctx *AdaptEnvForSnapshotBackupContext) error { - undo, err := ctx.pdMgr.RemoveAllPDSchedulers(ctx) - if undo != nil { - defer ctx.cleanUpWith(func(ctx context.Context) { - if err := undo(ctx); err != nil { - log.Warn("failed to restore pd scheduler.", logutil.ShortError(err)) - } - }) - } - if err != nil { - return err - } - ctx.ReadyL("pause_scheduler") - // Wait until the context canceled. - // So we can properly do the clean up work. - <-ctx.Done() - return nil -} diff --git a/br/pkg/task/restore.go b/br/pkg/task/restore.go index 392a8b005b858..8bc6383be78b6 100644 --- a/br/pkg/task/restore.go +++ b/br/pkg/task/restore.go @@ -1113,10 +1113,10 @@ func runRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConf // Restore sst files in batch. batchSize := mathutil.MaxInt - if v, _err_ := failpoint.Eval(_curpkg_("small-batch-size")); _err_ == nil { + failpoint.Inject("small-batch-size", func(v failpoint.Value) { log.Info("failpoint small batch size is on", zap.Int("size", v.(int))) batchSize = v.(int) - } + }) // Split/Scatter + Download/Ingest progressLen := int64(rangeSize + len(files)) diff --git a/br/pkg/task/restore.go__failpoint_stash__ b/br/pkg/task/restore.go__failpoint_stash__ deleted file mode 100644 index 8bc6383be78b6..0000000000000 --- a/br/pkg/task/restore.go__failpoint_stash__ +++ /dev/null @@ -1,1713 +0,0 @@ -// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. 
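// Hedged sketch (not from this patch) of how a failpoint marker such as the
// "small-batch-size" one above is usually exercised from a test. The Eval(_curpkg_(...))
// form and the __failpoint_stash__ files being removed here are what failpoint-ctl
// generates when it rewrites Inject markers; the failpoint path string below is an
// assumption for illustration (import path of the package plus the failpoint name).
package sketch

import (
	"testing"

	"github.com/pingcap/failpoint"
	"github.com/stretchr/testify/require"
)

func TestSmallBatchSizeFailpoint(t *testing.T) {
	// While enabled (and with the source rewritten by failpoint-ctl), the
	// failpoint.Inject("small-batch-size", ...) callback runs with the value 2.
	fp := "github.com/pingcap/tidb/br/pkg/task/small-batch-size" // assumed path
	require.NoError(t, failpoint.Enable(fp, "return(2)"))
	defer func() {
		require.NoError(t, failpoint.Disable(fp))
	}()
	// ... drive the restore code path under test here ...
}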
- -package task - -import ( - "cmp" - "context" - "fmt" - "slices" - "strings" - "time" - - "github.com/docker/go-units" - "github.com/google/uuid" - "github.com/opentracing/opentracing-go" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - backuppb "github.com/pingcap/kvproto/pkg/brpb" - "github.com/pingcap/log" - "github.com/pingcap/tidb/br/pkg/checkpoint" - pconfig "github.com/pingcap/tidb/br/pkg/config" - "github.com/pingcap/tidb/br/pkg/conn" - connutil "github.com/pingcap/tidb/br/pkg/conn/util" - berrors "github.com/pingcap/tidb/br/pkg/errors" - "github.com/pingcap/tidb/br/pkg/glue" - "github.com/pingcap/tidb/br/pkg/httputil" - "github.com/pingcap/tidb/br/pkg/logutil" - "github.com/pingcap/tidb/br/pkg/metautil" - "github.com/pingcap/tidb/br/pkg/restore" - snapclient "github.com/pingcap/tidb/br/pkg/restore/snap_client" - "github.com/pingcap/tidb/br/pkg/restore/tiflashrec" - "github.com/pingcap/tidb/br/pkg/summary" - "github.com/pingcap/tidb/br/pkg/utils" - "github.com/pingcap/tidb/br/pkg/version" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/collate" - "github.com/pingcap/tidb/pkg/util/engine" - "github.com/pingcap/tidb/pkg/util/mathutil" - "github.com/spf13/cobra" - "github.com/spf13/pflag" - "github.com/tikv/client-go/v2/tikv" - pd "github.com/tikv/pd/client" - "github.com/tikv/pd/client/http" - clientv3 "go.etcd.io/etcd/client/v3" - "go.uber.org/multierr" - "go.uber.org/zap" -) - -const ( - flagOnline = "online" - flagNoSchema = "no-schema" - flagLoadStats = "load-stats" - flagGranularity = "granularity" - flagConcurrencyPerStore = "tikv-max-restore-concurrency" - flagAllowPITRFromIncremental = "allow-pitr-from-incremental" - - // FlagMergeRegionSizeBytes is the flag name of merge small regions by size - FlagMergeRegionSizeBytes = "merge-region-size-bytes" - // FlagMergeRegionKeyCount is the flag name of merge small regions by key count - FlagMergeRegionKeyCount = "merge-region-key-count" - // FlagPDConcurrency controls concurrency pd-relative operations like split & scatter. - FlagPDConcurrency = "pd-concurrency" - // FlagStatsConcurrency controls concurrency to restore statistic. - FlagStatsConcurrency = "stats-concurrency" - // FlagBatchFlushInterval controls after how long the restore batch would be auto sended. - FlagBatchFlushInterval = "batch-flush-interval" - // FlagDdlBatchSize controls batch ddl size to create a batch of tables - FlagDdlBatchSize = "ddl-batch-size" - // FlagWithPlacementPolicy corresponds to tidb config with-tidb-placement-mode - // current only support STRICT or IGNORE, the default is STRICT according to tidb. - FlagWithPlacementPolicy = "with-tidb-placement-mode" - // FlagKeyspaceName corresponds to tidb config keyspace-name - FlagKeyspaceName = "keyspace-name" - - // FlagWaitTiFlashReady represents whether wait tiflash replica ready after table restored and checksumed. - FlagWaitTiFlashReady = "wait-tiflash-ready" - - // FlagStreamStartTS and FlagStreamRestoreTS is used for log restore timestamp range. - FlagStreamStartTS = "start-ts" - FlagStreamRestoreTS = "restored-ts" - // FlagStreamFullBackupStorage is used for log restore, represents the full backup storage. - FlagStreamFullBackupStorage = "full-backup-storage" - // FlagPiTRBatchCount and FlagPiTRBatchSize are used for restore log with batch method. 
- FlagPiTRBatchCount = "pitr-batch-count" - FlagPiTRBatchSize = "pitr-batch-size" - FlagPiTRConcurrency = "pitr-concurrency" - - FlagResetSysUsers = "reset-sys-users" - - defaultPiTRBatchCount = 8 - defaultPiTRBatchSize = 16 * 1024 * 1024 - defaultRestoreConcurrency = 128 - defaultPiTRConcurrency = 16 - defaultPDConcurrency = 1 - defaultStatsConcurrency = 12 - defaultBatchFlushInterval = 16 * time.Second - defaultFlagDdlBatchSize = 128 - resetSpeedLimitRetryTimes = 3 - maxRestoreBatchSizeLimit = 10240 -) - -const ( - FullRestoreCmd = "Full Restore" - DBRestoreCmd = "DataBase Restore" - TableRestoreCmd = "Table Restore" - PointRestoreCmd = "Point Restore" - RawRestoreCmd = "Raw Restore" - TxnRestoreCmd = "Txn Restore" -) - -// RestoreCommonConfig is the common configuration for all BR restore tasks. -type RestoreCommonConfig struct { - Online bool `json:"online" toml:"online"` - Granularity string `json:"granularity" toml:"granularity"` - ConcurrencyPerStore pconfig.ConfigTerm[uint] `json:"tikv-max-restore-concurrency" toml:"tikv-max-restore-concurrency"` - - // MergeSmallRegionSizeBytes is the threshold of merging small regions (Default 96MB, region split size). - // MergeSmallRegionKeyCount is the threshold of merging smalle regions (Default 960_000, region split key count). - // See https://github.com/tikv/tikv/blob/v4.0.8/components/raftstore/src/coprocessor/config.rs#L35-L38 - MergeSmallRegionSizeBytes pconfig.ConfigTerm[uint64] `json:"merge-region-size-bytes" toml:"merge-region-size-bytes"` - MergeSmallRegionKeyCount pconfig.ConfigTerm[uint64] `json:"merge-region-key-count" toml:"merge-region-key-count"` - - // determines whether enable restore sys table on default, see fullClusterRestore in restore/client.go - WithSysTable bool `json:"with-sys-table" toml:"with-sys-table"` - - ResetSysUsers []string `json:"reset-sys-users" toml:"reset-sys-users"` -} - -// adjust adjusts the abnormal config value in the current config. -// useful when not starting BR from CLI (e.g. from BRIE in SQL). -func (cfg *RestoreCommonConfig) adjust() { - if !cfg.MergeSmallRegionKeyCount.Modified { - cfg.MergeSmallRegionKeyCount.Value = conn.DefaultMergeRegionKeyCount - } - if !cfg.MergeSmallRegionSizeBytes.Modified { - cfg.MergeSmallRegionSizeBytes.Value = conn.DefaultMergeRegionSizeBytes - } - if len(cfg.Granularity) == 0 { - cfg.Granularity = string(restore.CoarseGrained) - } - if !cfg.ConcurrencyPerStore.Modified { - cfg.ConcurrencyPerStore.Value = conn.DefaultImportNumGoroutines - } -} - -// DefineRestoreCommonFlags defines common flags for the restore command. 
-func DefineRestoreCommonFlags(flags *pflag.FlagSet) { - // TODO remove experimental tag if it's stable - flags.Bool(flagOnline, false, "(experimental) Whether online when restore") - flags.String(flagGranularity, string(restore.CoarseGrained), "(deprecated) Whether split & scatter regions using fine-grained way during restore") - flags.Uint(flagConcurrencyPerStore, 128, "The size of thread pool on each store that executes tasks") - flags.Uint32(flagConcurrency, 128, "(deprecated) The size of thread pool on BR that executes tasks, "+ - "where each task restores one SST file to TiKV") - flags.Uint64(FlagMergeRegionSizeBytes, conn.DefaultMergeRegionSizeBytes, - "the threshold of merging small regions (Default 96MB, region split size)") - flags.Uint64(FlagMergeRegionKeyCount, conn.DefaultMergeRegionKeyCount, - "the threshold of merging small regions (Default 960_000, region split key count)") - flags.Uint(FlagPDConcurrency, defaultPDConcurrency, - "concurrency pd-relative operations like split & scatter.") - flags.Uint(FlagStatsConcurrency, defaultStatsConcurrency, - "concurrency to restore statistic") - flags.Duration(FlagBatchFlushInterval, defaultBatchFlushInterval, - "after how long a restore batch would be auto sent.") - flags.Uint(FlagDdlBatchSize, defaultFlagDdlBatchSize, - "batch size for ddl to create a batch of tables once.") - flags.Bool(flagWithSysTable, true, "whether restore system privilege tables on default setting") - flags.StringArrayP(FlagResetSysUsers, "", []string{"cloud_admin", "root"}, "whether reset these users after restoration") - flags.Bool(flagUseFSR, false, "whether enable FSR for AWS snapshots") - - _ = flags.MarkHidden(FlagResetSysUsers) - _ = flags.MarkHidden(FlagMergeRegionSizeBytes) - _ = flags.MarkHidden(FlagMergeRegionKeyCount) - _ = flags.MarkHidden(FlagPDConcurrency) - _ = flags.MarkHidden(FlagStatsConcurrency) - _ = flags.MarkHidden(FlagBatchFlushInterval) - _ = flags.MarkHidden(FlagDdlBatchSize) -} - -// ParseFromFlags parses the config from the flag set. -func (cfg *RestoreCommonConfig) ParseFromFlags(flags *pflag.FlagSet) error { - var err error - cfg.Online, err = flags.GetBool(flagOnline) - if err != nil { - return errors.Trace(err) - } - cfg.Granularity, err = flags.GetString(flagGranularity) - if err != nil { - return errors.Trace(err) - } - cfg.ConcurrencyPerStore.Value, err = flags.GetUint(flagConcurrencyPerStore) - if err != nil { - return errors.Trace(err) - } - cfg.ConcurrencyPerStore.Modified = flags.Changed(flagConcurrencyPerStore) - - cfg.MergeSmallRegionKeyCount.Value, err = flags.GetUint64(FlagMergeRegionKeyCount) - if err != nil { - return errors.Trace(err) - } - cfg.MergeSmallRegionKeyCount.Modified = flags.Changed(FlagMergeRegionKeyCount) - - cfg.MergeSmallRegionSizeBytes.Value, err = flags.GetUint64(FlagMergeRegionSizeBytes) - if err != nil { - return errors.Trace(err) - } - cfg.MergeSmallRegionSizeBytes.Modified = flags.Changed(FlagMergeRegionSizeBytes) - - if flags.Lookup(flagWithSysTable) != nil { - cfg.WithSysTable, err = flags.GetBool(flagWithSysTable) - if err != nil { - return errors.Trace(err) - } - } - cfg.ResetSysUsers, err = flags.GetStringArray(FlagResetSysUsers) - if err != nil { - return errors.Trace(err) - } - return errors.Trace(err) -} - -// RestoreConfig is the configuration specific for restore tasks. 
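// Minimal illustration (example values assumed, not part of the patch) of the pflag
// pattern behind ConfigTerm's Value/Modified pair above: the typed Get* helpers read a
// flag's value, while Changed reports whether the user set it explicitly, which is what
// lets adjust() fill in defaults only for flags the user left untouched.
package sketch

import (
	"fmt"

	"github.com/spf13/pflag"
)

func ExampleChangedFlag() {
	fs := pflag.NewFlagSet("demo", pflag.ContinueOnError)
	fs.Uint64("merge-region-size-bytes", 96*1024*1024, "threshold of merging small regions")

	_ = fs.Parse([]string{"--merge-region-size-bytes=1048576"})

	v, _ := fs.GetUint64("merge-region-size-bytes")
	fmt.Println(v, fs.Changed("merge-region-size-bytes"))
	// Output: 1048576 true
}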
-type RestoreConfig struct { - Config - RestoreCommonConfig - - NoSchema bool `json:"no-schema" toml:"no-schema"` - LoadStats bool `json:"load-stats" toml:"load-stats"` - PDConcurrency uint `json:"pd-concurrency" toml:"pd-concurrency"` - StatsConcurrency uint `json:"stats-concurrency" toml:"stats-concurrency"` - BatchFlushInterval time.Duration `json:"batch-flush-interval" toml:"batch-flush-interval"` - // DdlBatchSize use to define the size of batch ddl to create tables - DdlBatchSize uint `json:"ddl-batch-size" toml:"ddl-batch-size"` - - WithPlacementPolicy string `json:"with-tidb-placement-mode" toml:"with-tidb-placement-mode"` - - // FullBackupStorage is used to run `restore full` before `restore log`. - // if it is empty, directly take restoring log justly. - FullBackupStorage string `json:"full-backup-storage" toml:"full-backup-storage"` - - // AllowPITRFromIncremental indicates whether this restore should enter a compatibility mode for incremental restore. - // In this restore mode, the restore will not perform timestamp rewrite on the incremental data. - AllowPITRFromIncremental bool `json:"allow-pitr-from-incremental" toml:"allow-pitr-from-incremental"` - - // [startTs, RestoreTS] is used to `restore log` from StartTS to RestoreTS. - StartTS uint64 `json:"start-ts" toml:"start-ts"` - RestoreTS uint64 `json:"restore-ts" toml:"restore-ts"` - tiflashRecorder *tiflashrec.TiFlashRecorder `json:"-" toml:"-"` - PitrBatchCount uint32 `json:"pitr-batch-count" toml:"pitr-batch-count"` - PitrBatchSize uint32 `json:"pitr-batch-size" toml:"pitr-batch-size"` - PitrConcurrency uint32 `json:"-" toml:"-"` - - UseCheckpoint bool `json:"use-checkpoint" toml:"use-checkpoint"` - checkpointSnapshotRestoreTaskName string `json:"-" toml:"-"` - checkpointLogRestoreTaskName string `json:"-" toml:"-"` - checkpointTaskInfoClusterID uint64 `json:"-" toml:"-"` - WaitTiflashReady bool `json:"wait-tiflash-ready" toml:"wait-tiflash-ready"` - - // for ebs-based restore - FullBackupType FullBackupType `json:"full-backup-type" toml:"full-backup-type"` - Prepare bool `json:"prepare" toml:"prepare"` - OutputFile string `json:"output-file" toml:"output-file"` - SkipAWS bool `json:"skip-aws" toml:"skip-aws"` - CloudAPIConcurrency uint `json:"cloud-api-concurrency" toml:"cloud-api-concurrency"` - VolumeType pconfig.EBSVolumeType `json:"volume-type" toml:"volume-type"` - VolumeIOPS int64 `json:"volume-iops" toml:"volume-iops"` - VolumeThroughput int64 `json:"volume-throughput" toml:"volume-throughput"` - VolumeEncrypted bool `json:"volume-encrypted" toml:"volume-encrypted"` - ProgressFile string `json:"progress-file" toml:"progress-file"` - TargetAZ string `json:"target-az" toml:"target-az"` - UseFSR bool `json:"use-fsr" toml:"use-fsr"` -} - -// DefineRestoreFlags defines common flags for the restore tidb command. 
-func DefineRestoreFlags(flags *pflag.FlagSet) { - flags.Bool(flagNoSchema, false, "skip creating schemas and tables, reuse existing empty ones") - flags.Bool(flagLoadStats, true, "Run load stats at end of snapshot restore task") - // Do not expose this flag - _ = flags.MarkHidden(flagNoSchema) - flags.String(FlagWithPlacementPolicy, "STRICT", "correspond to tidb global/session variable with-tidb-placement-mode") - flags.String(FlagKeyspaceName, "", "correspond to tidb config keyspace-name") - - flags.Bool(flagUseCheckpoint, true, "use checkpoint mode") - _ = flags.MarkHidden(flagUseCheckpoint) - - flags.Bool(FlagWaitTiFlashReady, false, "whether wait tiflash replica ready if tiflash exists") - flags.Bool(flagAllowPITRFromIncremental, true, "whether make incremental restore compatible with later log restore"+ - " default is true, the incremental restore will not perform rewrite on the incremental data"+ - " meanwhile the incremental restore will not allow to restore 3 backfilled type ddl jobs,"+ - " these ddl jobs are Add index, Modify column and Reorganize partition") - - DefineRestoreCommonFlags(flags) -} - -// DefineStreamRestoreFlags defines for the restore log command. -func DefineStreamRestoreFlags(command *cobra.Command) { - command.Flags().String(FlagStreamStartTS, "", "the start timestamp which log restore from.\n"+ - "support TSO or datetime, e.g. '400036290571534337' or '2018-05-11 01:42:23+0800'") - command.Flags().String(FlagStreamRestoreTS, "", "the point of restore, used for log restore.\n"+ - "support TSO or datetime, e.g. '400036290571534337' or '2018-05-11 01:42:23+0800'") - command.Flags().String(FlagStreamFullBackupStorage, "", "specify the backup full storage. "+ - "fill it if want restore full backup before restore log.") - command.Flags().Uint32(FlagPiTRBatchCount, defaultPiTRBatchCount, "specify the batch count to restore log.") - command.Flags().Uint32(FlagPiTRBatchSize, defaultPiTRBatchSize, "specify the batch size to retore log.") - command.Flags().Uint32(FlagPiTRConcurrency, defaultPiTRConcurrency, "specify the concurrency to restore log.") -} - -// ParseStreamRestoreFlags parses the `restore stream` flags from the flag set. -func (cfg *RestoreConfig) ParseStreamRestoreFlags(flags *pflag.FlagSet) error { - tsString, err := flags.GetString(FlagStreamStartTS) - if err != nil { - return errors.Trace(err) - } - if cfg.StartTS, err = ParseTSString(tsString, true); err != nil { - return errors.Trace(err) - } - tsString, err = flags.GetString(FlagStreamRestoreTS) - if err != nil { - return errors.Trace(err) - } - if cfg.RestoreTS, err = ParseTSString(tsString, true); err != nil { - return errors.Trace(err) - } - - if cfg.FullBackupStorage, err = flags.GetString(FlagStreamFullBackupStorage); err != nil { - return errors.Trace(err) - } - - if cfg.StartTS > 0 && len(cfg.FullBackupStorage) > 0 { - return errors.Annotatef(berrors.ErrInvalidArgument, "%v and %v are mutually exclusive", - FlagStreamStartTS, FlagStreamFullBackupStorage) - } - - if cfg.PitrBatchCount, err = flags.GetUint32(FlagPiTRBatchCount); err != nil { - return errors.Trace(err) - } - if cfg.PitrBatchSize, err = flags.GetUint32(FlagPiTRBatchSize); err != nil { - return errors.Trace(err) - } - if cfg.PitrConcurrency, err = flags.GetUint32(FlagPiTRConcurrency); err != nil { - return errors.Trace(err) - } - return nil -} - -// ParseFromFlags parses the restore-related flags from the flag set. 
-func (cfg *RestoreConfig) ParseFromFlags(flags *pflag.FlagSet) error { - var err error - cfg.NoSchema, err = flags.GetBool(flagNoSchema) - if err != nil { - return errors.Trace(err) - } - cfg.LoadStats, err = flags.GetBool(flagLoadStats) - if err != nil { - return errors.Trace(err) - } - err = cfg.Config.ParseFromFlags(flags) - if err != nil { - return errors.Trace(err) - } - err = cfg.RestoreCommonConfig.ParseFromFlags(flags) - if err != nil { - return errors.Trace(err) - } - cfg.Concurrency, err = flags.GetUint32(flagConcurrency) - if err != nil { - return errors.Trace(err) - } - if cfg.Config.Concurrency == 0 { - cfg.Config.Concurrency = defaultRestoreConcurrency - } - cfg.PDConcurrency, err = flags.GetUint(FlagPDConcurrency) - if err != nil { - return errors.Annotatef(err, "failed to get flag %s", FlagPDConcurrency) - } - cfg.StatsConcurrency, err = flags.GetUint(FlagStatsConcurrency) - if err != nil { - return errors.Annotatef(err, "failed to get flag %s", FlagStatsConcurrency) - } - cfg.BatchFlushInterval, err = flags.GetDuration(FlagBatchFlushInterval) - if err != nil { - return errors.Annotatef(err, "failed to get flag %s", FlagBatchFlushInterval) - } - - cfg.DdlBatchSize, err = flags.GetUint(FlagDdlBatchSize) - if err != nil { - return errors.Annotatef(err, "failed to get flag %s", FlagDdlBatchSize) - } - cfg.WithPlacementPolicy, err = flags.GetString(FlagWithPlacementPolicy) - if err != nil { - return errors.Annotatef(err, "failed to get flag %s", FlagWithPlacementPolicy) - } - cfg.KeyspaceName, err = flags.GetString(FlagKeyspaceName) - if err != nil { - return errors.Annotatef(err, "failed to get flag %s", FlagKeyspaceName) - } - cfg.UseCheckpoint, err = flags.GetBool(flagUseCheckpoint) - if err != nil { - return errors.Annotatef(err, "failed to get flag %s", flagUseCheckpoint) - } - - cfg.WaitTiflashReady, err = flags.GetBool(FlagWaitTiFlashReady) - if err != nil { - return errors.Annotatef(err, "failed to get flag %s", FlagWaitTiFlashReady) - } - - cfg.AllowPITRFromIncremental, err = flags.GetBool(flagAllowPITRFromIncremental) - if err != nil { - return errors.Annotatef(err, "failed to get flag %s", flagAllowPITRFromIncremental) - } - - if flags.Lookup(flagFullBackupType) != nil { - // for restore full only - fullBackupType, err := flags.GetString(flagFullBackupType) - if err != nil { - return errors.Trace(err) - } - if !FullBackupType(fullBackupType).Valid() { - return errors.New("invalid full backup type") - } - cfg.FullBackupType = FullBackupType(fullBackupType) - cfg.Prepare, err = flags.GetBool(flagPrepare) - if err != nil { - return errors.Trace(err) - } - cfg.SkipAWS, err = flags.GetBool(flagSkipAWS) - if err != nil { - return errors.Trace(err) - } - cfg.CloudAPIConcurrency, err = flags.GetUint(flagCloudAPIConcurrency) - if err != nil { - return errors.Trace(err) - } - cfg.OutputFile, err = flags.GetString(flagOutputMetaFile) - if err != nil { - return errors.Trace(err) - } - volumeType, err := flags.GetString(flagVolumeType) - if err != nil { - return errors.Trace(err) - } - cfg.VolumeType = pconfig.EBSVolumeType(volumeType) - if !cfg.VolumeType.Valid() { - return errors.New("invalid volume type: " + volumeType) - } - if cfg.VolumeIOPS, err = flags.GetInt64(flagVolumeIOPS); err != nil { - return errors.Trace(err) - } - if cfg.VolumeThroughput, err = flags.GetInt64(flagVolumeThroughput); err != nil { - return errors.Trace(err) - } - if cfg.VolumeEncrypted, err = flags.GetBool(flagVolumeEncrypted); err != nil { - return errors.Trace(err) - } - - cfg.ProgressFile, err = 
flags.GetString(flagProgressFile) - if err != nil { - return errors.Trace(err) - } - - cfg.TargetAZ, err = flags.GetString(flagTargetAZ) - if err != nil { - return errors.Trace(err) - } - - cfg.UseFSR, err = flags.GetBool(flagUseFSR) - if err != nil { - return errors.Trace(err) - } - - // iops: gp3 [3,000-16,000]; io1/io2 [100-32,000] - // throughput: gp3 [125, 1000]; io1/io2 cannot set throughput - // io1 and io2 volumes support up to 64,000 IOPS only on Instances built on the Nitro System. - // Other instance families support performance up to 32,000 IOPS. - // https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_CreateVolume.html - // todo: check lower/upper bound - } - - return nil -} - -// Adjust is use for BR(binary) and BR in TiDB. -// When new config was added and not included in parser. -// we should set proper value in this function. -// so that both binary and TiDB will use same default value. -func (cfg *RestoreConfig) Adjust() { - cfg.Config.adjust() - cfg.RestoreCommonConfig.adjust() - - if cfg.Config.Concurrency == 0 { - cfg.Config.Concurrency = defaultRestoreConcurrency - } - if cfg.Config.SwitchModeInterval == 0 { - cfg.Config.SwitchModeInterval = defaultSwitchInterval - } - if cfg.PDConcurrency == 0 { - cfg.PDConcurrency = defaultPDConcurrency - } - if cfg.StatsConcurrency == 0 { - cfg.StatsConcurrency = defaultStatsConcurrency - } - if cfg.BatchFlushInterval == 0 { - cfg.BatchFlushInterval = defaultBatchFlushInterval - } - if cfg.DdlBatchSize == 0 { - cfg.DdlBatchSize = defaultFlagDdlBatchSize - } - if cfg.CloudAPIConcurrency == 0 { - cfg.CloudAPIConcurrency = defaultCloudAPIConcurrency - } -} - -func (cfg *RestoreConfig) adjustRestoreConfigForStreamRestore() { - if cfg.PitrConcurrency == 0 { - cfg.PitrConcurrency = defaultPiTRConcurrency - } - if cfg.PitrBatchCount == 0 { - cfg.PitrBatchCount = defaultPiTRBatchCount - } - if cfg.PitrBatchSize == 0 { - cfg.PitrBatchSize = defaultPiTRBatchSize - } - // another goroutine is used to iterate the backup file - cfg.PitrConcurrency += 1 - log.Info("set restore kv files concurrency", zap.Int("concurrency", int(cfg.PitrConcurrency))) - cfg.Config.Concurrency = cfg.PitrConcurrency -} - -// generateLogRestoreTaskName generates the log restore taskName for checkpoint -func (cfg *RestoreConfig) generateLogRestoreTaskName(clusterID, startTS, restoreTs uint64) string { - cfg.checkpointTaskInfoClusterID = clusterID - cfg.checkpointLogRestoreTaskName = fmt.Sprintf("%d/%d.%d", clusterID, startTS, restoreTs) - return cfg.checkpointLogRestoreTaskName -} - -// generateSnapshotRestoreTaskName generates the snapshot restore taskName for checkpoint -func (cfg *RestoreConfig) generateSnapshotRestoreTaskName(clusterID uint64) string { - cfg.checkpointSnapshotRestoreTaskName = fmt.Sprint(clusterID) - return cfg.checkpointSnapshotRestoreTaskName -} - -func configureRestoreClient(ctx context.Context, client *snapclient.SnapClient, cfg *RestoreConfig) error { - client.SetRateLimit(cfg.RateLimit) - client.SetCrypter(&cfg.CipherInfo) - if cfg.NoSchema { - client.EnableSkipCreateSQL() - } - client.SetBatchDdlSize(cfg.DdlBatchSize) - client.SetPlacementPolicyMode(cfg.WithPlacementPolicy) - client.SetWithSysTable(cfg.WithSysTable) - client.SetRewriteMode(ctx) - return nil -} - -func CheckNewCollationEnable( - backupNewCollationEnable string, - g glue.Glue, - storage kv.Storage, - CheckRequirements bool, -) (bool, error) { - se, err := g.CreateSession(storage) - if err != nil { - return false, errors.Trace(err) - } - - newCollationEnable, err := 
se.GetGlobalVariable(utils.GetTidbNewCollationEnabled()) - if err != nil { - return false, errors.Trace(err) - } - // collate.newCollationEnabled is set to 1 when the collate package is initialized, - // so we need to modify this value according to the config of the cluster - // before using the collate package. - enabled := newCollationEnable == "True" - // modify collate.newCollationEnabled according to the config of the cluster - collate.SetNewCollationEnabledForTest(enabled) - log.Info(fmt.Sprintf("set %s", utils.TidbNewCollationEnabled), zap.Bool("new_collation_enabled", enabled)) - - if backupNewCollationEnable == "" { - if CheckRequirements { - return enabled, errors.Annotatef(berrors.ErrUnknown, - "the value '%s' not found in backupmeta. "+ - "you can use \"SELECT VARIABLE_VALUE FROM mysql.tidb WHERE VARIABLE_NAME='%s';\" to manually check the config. "+ - "if you ensure the value '%s' in backup cluster is as same as restore cluster, use --check-requirements=false to skip this check", - utils.TidbNewCollationEnabled, utils.TidbNewCollationEnabled, utils.TidbNewCollationEnabled) - } - log.Warn(fmt.Sprintf("the config '%s' is not in backupmeta", utils.TidbNewCollationEnabled)) - return enabled, nil - } - - if !strings.EqualFold(backupNewCollationEnable, newCollationEnable) { - return enabled, errors.Annotatef(berrors.ErrUnknown, - "the config '%s' not match, upstream:%v, downstream: %v", - utils.TidbNewCollationEnabled, backupNewCollationEnable, newCollationEnable) - } - - return enabled, nil -} - -// CheckRestoreDBAndTable is used to check whether the restore dbs or tables have been backup -func CheckRestoreDBAndTable(schemas []*metautil.Database, cfg *RestoreConfig) error { - if len(cfg.Schemas) == 0 && len(cfg.Tables) == 0 { - return nil - } - schemasMap := make(map[string]struct{}) - tablesMap := make(map[string]struct{}) - for _, db := range schemas { - dbName := db.Info.Name.L - if dbCIStrName, ok := utils.GetSysDBCIStrName(db.Info.Name); utils.IsSysDB(dbCIStrName.O) && ok { - dbName = dbCIStrName.L - } - schemasMap[utils.EncloseName(dbName)] = struct{}{} - for _, table := range db.Tables { - if table.Info == nil { - // we may back up empty database. 
- continue - } - tablesMap[utils.EncloseDBAndTable(dbName, table.Info.Name.L)] = struct{}{} - } - } - restoreSchemas := cfg.Schemas - restoreTables := cfg.Tables - for schema := range restoreSchemas { - schemaLName := strings.ToLower(schema) - if _, ok := schemasMap[schemaLName]; !ok { - return errors.Annotatef(berrors.ErrUndefinedRestoreDbOrTable, - "[database: %v] has not been backup, please ensure you has input a correct database name", schema) - } - } - for table := range restoreTables { - tableLName := strings.ToLower(table) - if _, ok := tablesMap[tableLName]; !ok { - return errors.Annotatef(berrors.ErrUndefinedRestoreDbOrTable, - "[table: %v] has not been backup, please ensure you has input a correct table name", table) - } - } - return nil -} - -func isFullRestore(cmdName string) bool { - return cmdName == FullRestoreCmd -} - -// IsStreamRestore checks the command is `restore point` -func IsStreamRestore(cmdName string) bool { - return cmdName == PointRestoreCmd -} - -func registerTaskToPD(ctx context.Context, etcdCLI *clientv3.Client) (closeF func(context.Context) error, err error) { - register := utils.NewTaskRegister(etcdCLI, utils.RegisterRestore, fmt.Sprintf("restore-%s", uuid.New())) - err = register.RegisterTask(ctx) - return register.Close, errors.Trace(err) -} - -func removeCheckpointDataForSnapshotRestore(ctx context.Context, storageName string, taskName string, config *Config) error { - _, s, err := GetStorage(ctx, storageName, config) - if err != nil { - return errors.Trace(err) - } - return errors.Trace(checkpoint.RemoveCheckpointDataForRestore(ctx, s, taskName)) -} - -func removeCheckpointDataForLogRestore(ctx context.Context, storageName string, taskName string, clusterID uint64, config *Config) error { - _, s, err := GetStorage(ctx, storageName, config) - if err != nil { - return errors.Trace(err) - } - return errors.Trace(checkpoint.RemoveCheckpointDataForLogRestore(ctx, s, taskName, clusterID)) -} - -func DefaultRestoreConfig() RestoreConfig { - fs := pflag.NewFlagSet("dummy", pflag.ContinueOnError) - DefineCommonFlags(fs) - DefineRestoreFlags(fs) - cfg := RestoreConfig{} - err := multierr.Combine( - cfg.ParseFromFlags(fs), - cfg.RestoreCommonConfig.ParseFromFlags(fs), - cfg.Config.ParseFromFlags(fs), - ) - if err != nil { - log.Panic("infallible failed.", zap.Error(err)) - } - - return cfg -} - -// RunRestore starts a restore task inside the current goroutine. 
-func RunRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConfig) error { - etcdCLI, err := dialEtcdWithCfg(c, cfg.Config) - if err != nil { - return err - } - defer func() { - if err := etcdCLI.Close(); err != nil { - log.Error("failed to close the etcd client", zap.Error(err)) - } - }() - if err := checkTaskExists(c, cfg, etcdCLI); err != nil { - return errors.Annotate(err, "failed to check task exists") - } - closeF, err := registerTaskToPD(c, etcdCLI) - if err != nil { - return errors.Annotate(err, "failed to register task to pd") - } - defer func() { - _ = closeF(c) - }() - - config.UpdateGlobal(func(conf *config.Config) { - conf.KeyspaceName = cfg.KeyspaceName - }) - - var restoreError error - if IsStreamRestore(cmdName) { - restoreError = RunStreamRestore(c, g, cmdName, cfg) - } else { - restoreError = runRestore(c, g, cmdName, cfg, nil) - } - if restoreError != nil { - return errors.Trace(restoreError) - } - // Clear the checkpoint data - if cfg.UseCheckpoint { - if len(cfg.checkpointLogRestoreTaskName) > 0 { - log.Info("start to remove checkpoint data for log restore") - err = removeCheckpointDataForLogRestore(c, cfg.Config.Storage, cfg.checkpointLogRestoreTaskName, cfg.checkpointTaskInfoClusterID, &cfg.Config) - if err != nil { - log.Warn("failed to remove checkpoint data for log restore", zap.Error(err)) - } - } - if len(cfg.checkpointSnapshotRestoreTaskName) > 0 { - log.Info("start to remove checkpoint data for snapshot restore.") - var storage string - if IsStreamRestore(cmdName) { - storage = cfg.FullBackupStorage - } else { - storage = cfg.Config.Storage - } - err = removeCheckpointDataForSnapshotRestore(c, storage, cfg.checkpointSnapshotRestoreTaskName, &cfg.Config) - if err != nil { - log.Warn("failed to remove checkpoint data for snapshot restore", zap.Error(err)) - } - } - log.Info("all the checkpoint data is removed.") - } - return nil -} - -func runRestore(c context.Context, g glue.Glue, cmdName string, cfg *RestoreConfig, checkInfo *PiTRTaskInfo) error { - cfg.Adjust() - defer summary.Summary(cmdName) - ctx, cancel := context.WithCancel(c) - defer cancel() - - if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { - span1 := span.Tracer().StartSpan("task.RunRestore", opentracing.ChildOf(span.Context())) - defer span1.Finish() - ctx = opentracing.ContextWithSpan(ctx, span1) - } - - // Restore needs domain to do DDL. - needDomain := true - keepaliveCfg := GetKeepalive(&cfg.Config) - mgr, err := NewMgr(ctx, g, cfg.PD, cfg.TLS, keepaliveCfg, cfg.CheckRequirements, needDomain, conn.NormalVersionChecker) - if err != nil { - return errors.Trace(err) - } - defer mgr.Close() - codec := mgr.GetStorage().GetCodec() - - // need retrieve these configs from tikv if not set in command. - kvConfigs := &pconfig.KVConfig{ - ImportGoroutines: cfg.ConcurrencyPerStore, - MergeRegionSize: cfg.MergeSmallRegionSizeBytes, - MergeRegionKeyCount: cfg.MergeSmallRegionKeyCount, - } - - // according to https://github.com/pingcap/tidb/issues/34167. - // we should get the real config from tikv to adapt the dynamic region. - httpCli := httputil.NewClient(mgr.GetTLSConfig()) - mgr.ProcessTiKVConfigs(ctx, kvConfigs, httpCli) - - keepaliveCfg.PermitWithoutStream = true - client := snapclient.NewRestoreClient(mgr.GetPDClient(), mgr.GetPDHTTPClient(), mgr.GetTLSConfig(), keepaliveCfg) - // using tikv config to set the concurrency-per-store for client. 
- client.SetConcurrencyPerStore(kvConfigs.ImportGoroutines.Value) - err = configureRestoreClient(ctx, client, cfg) - if err != nil { - return errors.Trace(err) - } - // Init DB connection sessions - err = client.Init(g, mgr.GetStorage()) - defer client.Close() - - if err != nil { - return errors.Trace(err) - } - u, s, backupMeta, err := ReadBackupMeta(ctx, metautil.MetaFile, &cfg.Config) - if err != nil { - return errors.Trace(err) - } - if cfg.CheckRequirements { - err := checkIncompatibleChangefeed(ctx, backupMeta.EndVersion, mgr.GetDomain().GetEtcdClient()) - log.Info("Checking incompatible TiCDC changefeeds before restoring.", - logutil.ShortError(err), zap.Uint64("restore-ts", backupMeta.EndVersion)) - if err != nil { - return errors.Trace(err) - } - } - - backupVersion := version.NormalizeBackupVersion(backupMeta.ClusterVersion) - if cfg.CheckRequirements && backupVersion != nil { - if versionErr := version.CheckClusterVersion(ctx, mgr.GetPDClient(), version.CheckVersionForBackup(backupVersion)); versionErr != nil { - return errors.Trace(versionErr) - } - } - if _, err = CheckNewCollationEnable(backupMeta.GetNewCollationsEnabled(), g, mgr.GetStorage(), cfg.CheckRequirements); err != nil { - return errors.Trace(err) - } - - reader := metautil.NewMetaReader(backupMeta, s, &cfg.CipherInfo) - if err = client.InitBackupMeta(c, backupMeta, u, reader, cfg.LoadStats); err != nil { - return errors.Trace(err) - } - - if client.IsRawKvMode() { - return errors.Annotate(berrors.ErrRestoreModeMismatch, "cannot do transactional restore from raw kv data") - } - if err = CheckRestoreDBAndTable(client.GetDatabases(), cfg); err != nil { - return err - } - files, tables, dbs := filterRestoreFiles(client, cfg) - if len(dbs) == 0 && len(tables) != 0 { - return errors.Annotate(berrors.ErrRestoreInvalidBackup, "contain tables but no databases") - } - - if cfg.CheckRequirements { - if err := checkDiskSpace(ctx, mgr, files, tables); err != nil { - return errors.Trace(err) - } - } - - archiveSize := reader.ArchiveSize(ctx, files) - g.Record(summary.RestoreDataSize, archiveSize) - //restore from tidb will fetch a general Size issue https://github.com/pingcap/tidb/issues/27247 - g.Record("Size", archiveSize) - restoreTS, err := restore.GetTSWithRetry(ctx, mgr.GetPDClient()) - if err != nil { - return errors.Trace(err) - } - - // for full + log restore. should check the cluster is empty. - if client.IsFull() && checkInfo != nil && checkInfo.FullRestoreCheckErr != nil { - return checkInfo.FullRestoreCheckErr - } - - if client.IsIncremental() { - // don't support checkpoint for the ddl restore - log.Info("the incremental snapshot restore doesn't support checkpoint mode, so unuse checkpoint.") - cfg.UseCheckpoint = false - } - - importModeSwitcher := restore.NewImportModeSwitcher(mgr.GetPDClient(), cfg.Config.SwitchModeInterval, mgr.GetTLSConfig()) - restoreSchedulers, schedulersConfig, err := restore.RestorePreWork(ctx, mgr, importModeSwitcher, cfg.Online, true) - if err != nil { - return errors.Trace(err) - } - - schedulersRemovable := false - defer func() { - // don't reset pd scheduler if checkpoint mode is used and restored is not finished - if cfg.UseCheckpoint && !schedulersRemovable { - log.Info("skip removing pd schehduler for next retry") - return - } - log.Info("start to remove the pd scheduler") - // run the post-work to avoid being stuck in the import - // mode or emptied schedulers. 
- restore.RestorePostWork(ctx, importModeSwitcher, restoreSchedulers, cfg.Online) - log.Info("finish removing pd scheduler") - }() - - var checkpointTaskName string - var checkpointFirstRun bool = true - if cfg.UseCheckpoint { - checkpointTaskName = cfg.generateSnapshotRestoreTaskName(client.GetClusterID(ctx)) - // if the checkpoint metadata exists in the external storage, the restore is not - // for the first time. - existsCheckpointMetadata, err := checkpoint.ExistsRestoreCheckpoint(ctx, s, checkpointTaskName) - if err != nil { - return errors.Trace(err) - } - checkpointFirstRun = !existsCheckpointMetadata - } - - if isFullRestore(cmdName) { - if client.NeedCheckFreshCluster(cfg.ExplicitFilter, checkpointFirstRun) { - if err = client.CheckTargetClusterFresh(ctx); err != nil { - return errors.Trace(err) - } - } - // todo: move this check into InitFullClusterRestore, we should move restore config into a separate package - // to avoid import cycle problem which we won't do it in this pr, then refactor this - // - // if it's point restore and reached here, then cmdName=FullRestoreCmd and len(cfg.FullBackupStorage) > 0 - if cfg.WithSysTable { - client.InitFullClusterRestore(cfg.ExplicitFilter) - } - } - - if client.IsFullClusterRestore() && client.HasBackedUpSysDB() { - if err = snapclient.CheckSysTableCompatibility(mgr.GetDomain(), tables); err != nil { - return errors.Trace(err) - } - } - - // reload or register the checkpoint - var checkpointSetWithTableID map[int64]map[string]struct{} - if cfg.UseCheckpoint { - sets, restoreSchedulersConfigFromCheckpoint, err := client.InitCheckpoint(ctx, s, checkpointTaskName, schedulersConfig, checkpointFirstRun) - if err != nil { - return errors.Trace(err) - } - if restoreSchedulersConfigFromCheckpoint != nil { - restoreSchedulers = mgr.MakeUndoFunctionByConfig(*restoreSchedulersConfigFromCheckpoint) - } - checkpointSetWithTableID = sets - - defer func() { - // need to flush the whole checkpoint data so that br can quickly jump to - // the log kv restore step when the next retry. - log.Info("wait for flush checkpoint...") - client.WaitForFinishCheckpoint(ctx, len(cfg.FullBackupStorage) > 0 || !schedulersRemovable) - }() - } - - sp := utils.BRServiceSafePoint{ - BackupTS: restoreTS, - TTL: utils.DefaultBRGCSafePointTTL, - ID: utils.MakeSafePointID(), - } - g.Record("BackupTS", backupMeta.EndVersion) - g.Record("RestoreTS", restoreTS) - cctx, gcSafePointKeeperCancel := context.WithCancel(ctx) - defer func() { - log.Info("start to remove gc-safepoint keeper") - // close the gc safe point keeper at first - gcSafePointKeeperCancel() - // set the ttl to 0 to remove the gc-safe-point - sp.TTL = 0 - if err := utils.UpdateServiceSafePoint(ctx, mgr.GetPDClient(), sp); err != nil { - log.Warn("failed to update service safe point, backup may fail if gc triggered", - zap.Error(err), - ) - } - log.Info("finish removing gc-safepoint keeper") - }() - // restore checksum will check safe point with its start ts, see details at - // https://github.com/pingcap/tidb/blob/180c02127105bed73712050594da6ead4d70a85f/store/tikv/kv.go#L186-L190 - // so, we should keep the safe point unchangeable. to avoid GC life time is shorter than transaction duration. 
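// Hedged sketch, not the BR implementation, of what a service-safepoint keeper like
// utils.StartServiceSafePointKeeper above has to do: re-register the same BackupTS on a
// ticker so the TTL never lapses while the restore (and its checksum transactions) is
// still running. Function and parameter names here are hypothetical.
package sketch

import (
	"context"
	"time"

	"github.com/pingcap/log"
	pd "github.com/tikv/pd/client"
	"go.uber.org/zap"
)

func keepServiceSafePoint(ctx context.Context, pdCli pd.Client, id string, ttlSec int64, backupTS uint64) {
	// Refresh at half the TTL so a single missed tick does not let the safepoint expire.
	tick := time.NewTicker(time.Duration(ttlSec/2) * time.Second)
	defer tick.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-tick.C:
			// Only the TTL is refreshed; BackupTS stays unchanged, matching the
			// "keep the safe point unchangeable" requirement described above.
			if _, err := pdCli.UpdateServiceGCSafePoint(ctx, id, ttlSec, backupTS); err != nil {
				log.Warn("failed to refresh service gc safepoint", zap.Error(err))
			}
		}
	}
}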
- err = utils.StartServiceSafePointKeeper(cctx, mgr.GetPDClient(), sp) - if err != nil { - return errors.Trace(err) - } - - ddlJobs := FilterDDLJobs(client.GetDDLJobs(), tables) - ddlJobs = FilterDDLJobByRules(ddlJobs, DDLJobBlockListRule) - if cfg.AllowPITRFromIncremental { - err = CheckDDLJobByRules(ddlJobs, DDLJobLogIncrementalCompactBlockListRule) - if err != nil { - return errors.Trace(err) - } - } - - err = PreCheckTableTiFlashReplica(ctx, mgr.GetPDClient(), tables, cfg.tiflashRecorder) - if err != nil { - return errors.Trace(err) - } - - err = PreCheckTableClusterIndex(tables, ddlJobs, mgr.GetDomain()) - if err != nil { - return errors.Trace(err) - } - - // pre-set TiDB config for restore - restoreDBConfig := enableTiDBConfig() - defer restoreDBConfig() - - if client.GetSupportPolicy() { - // create policy if backupMeta has policies. - policies, err := client.GetPlacementPolicies() - if err != nil { - return errors.Trace(err) - } - if isFullRestore(cmdName) { - // we should restore all policies during full restoration. - err = client.CreatePolicies(ctx, policies) - if err != nil { - return errors.Trace(err) - } - } else { - client.SetPolicyMap(policies) - } - } - - // preallocate the table id, because any ddl job or database creation also allocates the global ID - err = client.AllocTableIDs(ctx, tables) - if err != nil { - return errors.Trace(err) - } - - // execute DDL first - err = client.ExecDDLs(ctx, ddlJobs) - if err != nil { - return errors.Trace(err) - } - - // nothing to restore, maybe only ddl changes in incremental restore - if len(dbs) == 0 && len(tables) == 0 { - log.Info("nothing to restore, all databases and tables are filtered out") - // even nothing to restore, we show a success message since there is no failure. - summary.SetSuccessStatus(true) - return nil - } - - if err = client.CreateDatabases(ctx, dbs); err != nil { - return errors.Trace(err) - } - - var newTS uint64 - if client.IsIncremental() { - if !cfg.AllowPITRFromIncremental { - // we need to get the new ts after execDDL - // or backfilled data in upstream may not be covered by - // the new ts. - // see https://github.com/pingcap/tidb/issues/54426 - newTS, err = restore.GetTSWithRetry(ctx, mgr.GetPDClient()) - if err != nil { - return errors.Trace(err) - } - } - } - // We make bigger errCh so we won't block on multi-part failed. - errCh := make(chan error, 32) - - tableStream := client.GoCreateTables(ctx, tables, newTS, errCh) - - if len(files) == 0 { - log.Info("no files, empty databases and tables are restored") - summary.SetSuccessStatus(true) - // don't return immediately, wait all pipeline done. - } else { - oldKeyspace, _, err := tikv.DecodeKey(files[0].GetStartKey(), backupMeta.ApiVersion) - if err != nil { - return errors.Trace(err) - } - newKeyspace := codec.GetKeyspace() - - // If the API V2 data occurs in the restore process, the cluster must - // support the keyspace rewrite mode. - if (len(oldKeyspace) > 0 || len(newKeyspace) > 0) && client.GetRewriteMode() == snapclient.RewriteModeLegacy { - return errors.Annotate(berrors.ErrRestoreModeMismatch, "cluster only supports legacy rewrite mode") - } - - // Hijack the tableStream and rewrite the rewrite rules. 
- tableStream = util.ChanMap(tableStream, func(t snapclient.CreatedTable) snapclient.CreatedTable { - // Set the keyspace info for the checksum requests - t.RewriteRule.OldKeyspace = oldKeyspace - t.RewriteRule.NewKeyspace = newKeyspace - - for _, rule := range t.RewriteRule.Data { - rule.OldKeyPrefix = append(append([]byte{}, oldKeyspace...), rule.OldKeyPrefix...) - rule.NewKeyPrefix = codec.EncodeKey(rule.NewKeyPrefix) - } - return t - }) - } - - if cfg.tiflashRecorder != nil { - tableStream = util.ChanMap(tableStream, func(t snapclient.CreatedTable) snapclient.CreatedTable { - if cfg.tiflashRecorder != nil { - cfg.tiflashRecorder.Rewrite(t.OldTable.Info.ID, t.Table.ID) - } - return t - }) - } - - // Block on creating tables before restore starts. since create table is no longer a heavy operation any more. - tableStream = client.GoBlockCreateTablesPipeline(ctx, maxRestoreBatchSizeLimit, tableStream) - - tableFileMap := MapTableToFiles(files) - log.Debug("mapped table to files", zap.Any("result map", tableFileMap)) - - rangeStream := client.GoValidateFileRanges( - ctx, tableStream, tableFileMap, kvConfigs.MergeRegionSize.Value, kvConfigs.MergeRegionKeyCount.Value, errCh) - - rangeSize := EstimateRangeSize(files) - summary.CollectInt("restore ranges", rangeSize) - log.Info("range and file prepared", zap.Int("file count", len(files)), zap.Int("range count", rangeSize)) - - // Do not reset timestamp if we are doing incremental restore, because - // we are not allowed to decrease timestamp. - if !client.IsIncremental() { - if err = client.ResetTS(ctx, mgr.PdController); err != nil { - log.Error("reset pd TS failed", zap.Error(err)) - return errors.Trace(err) - } - } - - // Restore sst files in batch. - batchSize := mathutil.MaxInt - failpoint.Inject("small-batch-size", func(v failpoint.Value) { - log.Info("failpoint small batch size is on", zap.Int("size", v.(int))) - batchSize = v.(int) - }) - - // Split/Scatter + Download/Ingest - progressLen := int64(rangeSize + len(files)) - if cfg.Checksum { - progressLen += int64(len(tables)) - } - if cfg.WaitTiflashReady { - progressLen += int64(len(tables)) - } - // Redirect to log if there is no log file to avoid unreadable output. 
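// Illustrative sketch (assumed, not the actual pkg/util implementation) of the ChanMap
// helper used above to rewrite each CreatedTable as it flows through the pipeline: it
// returns a channel that yields f applied to every element and closes once the input
// channel closes.
package sketch

func chanMap[T, R any](in <-chan T, f func(T) R) <-chan R {
	out := make(chan R)
	go func() {
		defer close(out)
		// range exits when `in` is closed, which then closes `out` and lets
		// downstream pipeline stages terminate in order.
		for v := range in {
			out <- f(v)
		}
	}()
	return out
}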
- updateCh := g.StartProgress( - ctx, - cmdName, - progressLen, - !cfg.LogProgress) - defer updateCh.Close() - sender, err := snapclient.NewTiKVSender(ctx, client, updateCh, cfg.PDConcurrency) - if err != nil { - return errors.Trace(err) - } - manager, err := snapclient.NewBRContextManager(ctx, mgr.GetPDClient(), mgr.GetPDHTTPClient(), mgr.GetTLSConfig(), cfg.Online) - if err != nil { - return errors.Trace(err) - } - batcher, afterTableRestoredCh := snapclient.NewBatcher(ctx, sender, manager, errCh, updateCh) - batcher.SetCheckpoint(checkpointSetWithTableID) - batcher.SetThreshold(batchSize) - batcher.EnableAutoCommit(ctx, cfg.BatchFlushInterval) - go restoreTableStream(ctx, rangeStream, batcher, errCh) - - var finish <-chan struct{} - postHandleCh := afterTableRestoredCh - - // pipeline checksum - if cfg.Checksum { - postHandleCh = client.GoValidateChecksum( - ctx, postHandleCh, mgr.GetStorage().GetClient(), errCh, updateCh, cfg.ChecksumConcurrency) - } - - // pipeline update meta and load stats - postHandleCh = client.GoUpdateMetaAndLoadStats(ctx, s, postHandleCh, errCh, cfg.StatsConcurrency, cfg.LoadStats) - - // pipeline wait Tiflash synced - if cfg.WaitTiflashReady { - postHandleCh = client.GoWaitTiFlashReady(ctx, postHandleCh, updateCh, errCh) - } - - finish = dropToBlackhole(ctx, postHandleCh, errCh) - - // Reset speed limit. ResetSpeedLimit must be called after client.InitBackupMeta has been called. - defer func() { - var resetErr error - // In future we may need a mechanism to set speed limit in ttl. like what we do in switchmode. TODO - for retry := 0; retry < resetSpeedLimitRetryTimes; retry++ { - resetErr = client.ResetSpeedLimit(ctx) - if resetErr != nil { - log.Warn("failed to reset speed limit, retry it", - zap.Int("retry time", retry), logutil.ShortError(resetErr)) - time.Sleep(time.Duration(retry+3) * time.Second) - continue - } - break - } - if resetErr != nil { - log.Error("failed to reset speed limit, please reset it manually", zap.Error(resetErr)) - } - }() - - select { - case err = <-errCh: - err = multierr.Append(err, multierr.Combine(Exhaust(errCh)...)) - case <-finish: - } - - // If any error happened, return now. - if err != nil { - return errors.Trace(err) - } - - // The cost of rename user table / replace into system table wouldn't be so high. - // So leave it out of the pipeline for easier implementation. - err = client.RestoreSystemSchemas(ctx, cfg.TableFilter) - if err != nil { - return errors.Trace(err) - } - - schedulersRemovable = true - - // Set task summary to success status. 
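// Illustrative sketch (names assumed, not the patch's code) of the "first error or
// normal finish" pattern used above: pipeline stages report failures on a buffered
// error channel, the caller waits for whichever of error/finish comes first, then
// drains any already-queued errors so none is silently dropped.
package sketch

import "go.uber.org/multierr"

func waitPipeline(errCh chan error, finish <-chan struct{}) error {
	var err error
	select {
	case err = <-errCh:
	case <-finish:
		return nil
	}
	// Drain whatever else is already queued without blocking.
	for {
		select {
		case e := <-errCh:
			err = multierr.Append(err, e)
		default:
			return err
		}
	}
}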
- summary.SetSuccessStatus(true) - return nil -} - -func getMaxReplica(ctx context.Context, mgr *conn.Mgr) (cnt uint64, err error) { - var resp map[string]any - err = utils.WithRetry(ctx, func() error { - resp, err = mgr.GetPDHTTPClient().GetReplicateConfig(ctx) - return err - }, utils.NewPDReqBackoffer()) - if err != nil { - return 0, errors.Trace(err) - } - - key := "max-replicas" - val, ok := resp[key] - if !ok { - return 0, errors.Errorf("key %s not found in response %v", key, resp) - } - return uint64(val.(float64)), nil -} - -func getStores(ctx context.Context, mgr *conn.Mgr) (stores *http.StoresInfo, err error) { - err = utils.WithRetry(ctx, func() error { - stores, err = mgr.GetPDHTTPClient().GetStores(ctx) - return err - }, utils.NewPDReqBackoffer()) - if err != nil { - return nil, errors.Trace(err) - } - return stores, nil -} - -func EstimateTikvUsage(files []*backuppb.File, replicaCnt uint64, storeCnt uint64) uint64 { - if storeCnt == 0 { - return 0 - } - if replicaCnt > storeCnt { - replicaCnt = storeCnt - } - totalSize := uint64(0) - for _, file := range files { - totalSize += file.GetSize_() - } - log.Info("estimate tikv usage", zap.Uint64("total size", totalSize), zap.Uint64("replicaCnt", replicaCnt), zap.Uint64("store count", storeCnt)) - return totalSize * replicaCnt / storeCnt -} - -func EstimateTiflashUsage(tables []*metautil.Table, storeCnt uint64) uint64 { - if storeCnt == 0 { - return 0 - } - tiflashTotal := uint64(0) - for _, table := range tables { - if table.Info.TiFlashReplica == nil || table.Info.TiFlashReplica.Count <= 0 { - continue - } - tableBytes := uint64(0) - for _, file := range table.Files { - tableBytes += file.GetSize_() - } - tiflashTotal += tableBytes * table.Info.TiFlashReplica.Count - } - log.Info("estimate tiflash usage", zap.Uint64("total size", tiflashTotal), zap.Uint64("store count", storeCnt)) - return tiflashTotal / storeCnt -} - -func CheckStoreSpace(necessary uint64, store *http.StoreInfo) error { - available, err := units.RAMInBytes(store.Status.Available) - if err != nil { - return errors.Annotatef(berrors.ErrPDInvalidResponse, "store %d has invalid available space %s", store.Store.ID, store.Status.Available) - } - if available <= 0 { - return errors.Annotatef(berrors.ErrPDInvalidResponse, "store %d has invalid available space %s", store.Store.ID, store.Status.Available) - } - if uint64(available) < necessary { - return errors.Annotatef(berrors.ErrKVDiskFull, "store %d has no space left on device, available %s, necessary %s", - store.Store.ID, units.BytesSize(float64(available)), units.BytesSize(float64(necessary))) - } - return nil -} - -func checkDiskSpace(ctx context.Context, mgr *conn.Mgr, files []*backuppb.File, tables []*metautil.Table) error { - maxReplica, err := getMaxReplica(ctx, mgr) - if err != nil { - return errors.Trace(err) - } - stores, err := getStores(ctx, mgr) - if err != nil { - return errors.Trace(err) - } - - var tikvCnt, tiflashCnt uint64 = 0, 0 - for i := range stores.Stores { - store := &stores.Stores[i] - if engine.IsTiFlashHTTPResp(&store.Store) { - tiflashCnt += 1 - continue - } - tikvCnt += 1 - } - - // We won't need to restore more than 1800 PB data at one time, right? 
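	// The 1000 PB guard in preserve below is about uint64 overflow rather than a real
	// capacity limit: the value is multiplied by uint64(ratio*10) (11 or 14 here), and
	// uint64 wraps around at roughly 18.4 EB, so bases in the exabyte range are returned
	// unscaled instead of being multiplied.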
- preserve := func(base uint64, ratio float32) uint64 { - if base > 1000*units.PB { - return base - } - return base * uint64(ratio*10) / 10 - } - tikvUsage := preserve(EstimateTikvUsage(files, maxReplica, tikvCnt), 1.1) - tiflashUsage := preserve(EstimateTiflashUsage(tables, tiflashCnt), 1.4) - log.Info("preserved disk space", zap.Uint64("tikv", tikvUsage), zap.Uint64("tiflash", tiflashUsage)) - - err = utils.WithRetry(ctx, func() error { - stores, err = getStores(ctx, mgr) - if err != nil { - return errors.Trace(err) - } - for _, store := range stores.Stores { - if engine.IsTiFlashHTTPResp(&store.Store) { - if err := CheckStoreSpace(tiflashUsage, &store); err != nil { - return errors.Trace(err) - } - continue - } - if err := CheckStoreSpace(tikvUsage, &store); err != nil { - return errors.Trace(err) - } - } - return nil - }, utils.NewDiskCheckBackoffer()) - if err != nil { - return errors.Trace(err) - } - return nil -} - -// Exhaust drains all remaining errors in the channel, into a slice of errors. -func Exhaust(ec <-chan error) []error { - out := make([]error, 0, len(ec)) - for { - select { - case err := <-ec: - out = append(out, err) - default: - // errCh will NEVER be closed(ya see, it has multi sender-part), - // so we just consume the current backlog of this channel, then return. - return out - } - } -} - -// EstimateRangeSize estimates the total range count by file. -func EstimateRangeSize(files []*backuppb.File) int { - result := 0 - for _, f := range files { - if strings.HasSuffix(f.GetName(), "_write.sst") { - result++ - } - } - return result -} - -// MapTableToFiles makes a map that mapping table ID to its backup files. -// aware that one file can and only can hold one table. -func MapTableToFiles(files []*backuppb.File) map[int64][]*backuppb.File { - result := map[int64][]*backuppb.File{} - for _, file := range files { - tableID := tablecodec.DecodeTableID(file.GetStartKey()) - tableEndID := tablecodec.DecodeTableID(file.GetEndKey()) - if tableID != tableEndID { - log.Panic("key range spread between many files.", - zap.String("file name", file.Name), - logutil.Key("startKey", file.StartKey), - logutil.Key("endKey", file.EndKey)) - } - if tableID == 0 { - log.Panic("invalid table key of file", - zap.String("file name", file.Name), - logutil.Key("startKey", file.StartKey), - logutil.Key("endKey", file.EndKey)) - } - result[tableID] = append(result[tableID], file) - } - return result -} - -// dropToBlackhole drop all incoming tables into black hole, -// i.e. don't execute checksum, just increase the process anyhow. -func dropToBlackhole( - ctx context.Context, - inCh <-chan *snapclient.CreatedTable, - errCh chan<- error, -) <-chan struct{} { - outCh := make(chan struct{}, 1) - go func() { - defer func() { - close(outCh) - }() - for { - select { - case <-ctx.Done(): - errCh <- ctx.Err() - return - case _, ok := <-inCh: - if !ok { - return - } - } - } - }() - return outCh -} - -// filterRestoreFiles filters tables that can't be processed after applying cfg.TableFilter.MatchTable. -// if the db has no table that can be processed, the db will be filtered too. 
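// Small illustration (not from the patch) of the tablecodec call that MapTableToFiles
// above relies on: every data key embeds its table ID, so DecodeTableID on a backup
// file's start key identifies which table the file belongs to. Values are assumed.
package sketch

import (
	"fmt"

	"github.com/pingcap/tidb/pkg/kv"
	"github.com/pingcap/tidb/pkg/tablecodec"
)

func ExampleDecodeTableID() {
	key := tablecodec.EncodeRowKeyWithHandle(42, kv.IntHandle(1))
	fmt.Println(tablecodec.DecodeTableID(key))
	// Output: 42
}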
-func filterRestoreFiles( - client *snapclient.SnapClient, - cfg *RestoreConfig, -) (files []*backuppb.File, tables []*metautil.Table, dbs []*metautil.Database) { - for _, db := range client.GetDatabases() { - dbName := db.Info.Name.O - if name, ok := utils.GetSysDBName(db.Info.Name); utils.IsSysDB(name) && ok { - dbName = name - } - if !cfg.TableFilter.MatchSchema(dbName) { - continue - } - dbs = append(dbs, db) - for _, table := range db.Tables { - if table.Info == nil || !cfg.TableFilter.MatchTable(dbName, table.Info.Name.O) { - continue - } - files = append(files, table.Files...) - tables = append(tables, table) - } - } - return -} - -// enableTiDBConfig tweaks some of configs of TiDB to make the restore progress go well. -// return a function that could restore the config to origin. -func enableTiDBConfig() func() { - restoreConfig := config.RestoreFunc() - config.UpdateGlobal(func(conf *config.Config) { - // set max-index-length before execute DDLs and create tables - // we set this value to max(3072*4), otherwise we might not restore table - // when upstream and downstream both set this value greater than default(3072) - conf.MaxIndexLength = config.DefMaxOfMaxIndexLength - log.Warn("set max-index-length to max(3072*4) to skip check index length in DDL") - conf.IndexLimit = config.DefMaxOfIndexLimit - log.Warn("set index-limit to max(64*8) to skip check index count in DDL") - conf.TableColumnCountLimit = config.DefMaxOfTableColumnCountLimit - log.Warn("set table-column-count to max(4096) to skip check column count in DDL") - }) - return restoreConfig -} - -// restoreTableStream blocks current goroutine and restore a stream of tables, -// by send tables to batcher. -func restoreTableStream( - ctx context.Context, - inputCh <-chan snapclient.TableWithRange, - batcher *snapclient.Batcher, - errCh chan<- error, -) { - oldTableCount := 0 - defer func() { - // when things done, we must clean pending requests. - batcher.Close() - log.Info("doing postwork", - zap.Int("table count", oldTableCount), - ) - }() - - for { - select { - case <-ctx.Done(): - errCh <- ctx.Err() - return - case t, ok := <-inputCh: - if !ok { - return - } - oldTableCount += 1 - - batcher.Add(t) - } - } -} - -func getTiFlashNodeCount(ctx context.Context, pdClient pd.Client) (uint64, error) { - tiFlashStores, err := conn.GetAllTiKVStoresWithRetry(ctx, pdClient, connutil.TiFlashOnly) - if err != nil { - return 0, errors.Trace(err) - } - return uint64(len(tiFlashStores)), nil -} - -// PreCheckTableTiFlashReplica checks whether TiFlash replica is less than TiFlash node. -func PreCheckTableTiFlashReplica( - ctx context.Context, - pdClient pd.Client, - tables []*metautil.Table, - recorder *tiflashrec.TiFlashRecorder, -) error { - tiFlashStoreCount, err := getTiFlashNodeCount(ctx, pdClient) - if err != nil { - return err - } - for _, table := range tables { - if table.Info.TiFlashReplica != nil { - // we should not set available to true. because we cannot guarantee the raft log lag of tiflash when restore finished. - // just let tiflash ticker set it by checking lag of all related regions. 
- table.Info.TiFlashReplica.Available = false - table.Info.TiFlashReplica.AvailablePartitionIDs = nil - if recorder != nil { - recorder.AddTable(table.Info.ID, *table.Info.TiFlashReplica) - log.Info("record tiflash replica for table, to reset it by ddl later", - zap.Stringer("db", table.DB.Name), - zap.Stringer("table", table.Info.Name), - ) - table.Info.TiFlashReplica = nil - } else if table.Info.TiFlashReplica.Count > tiFlashStoreCount { - // we cannot satisfy TiFlash replica in restore cluster. so we should - // set TiFlashReplica to unavailable in tableInfo, to avoid TiDB cannot sense TiFlash and make plan to TiFlash - // see details at https://github.com/pingcap/br/issues/931 - // TODO maybe set table.Info.TiFlashReplica.Count to tiFlashStoreCount, but we need more tests about it. - log.Warn("table does not satisfy tiflash replica requirements, set tiflash replcia to unavailable", - zap.Stringer("db", table.DB.Name), - zap.Stringer("table", table.Info.Name), - zap.Uint64("expect tiflash replica", table.Info.TiFlashReplica.Count), - zap.Uint64("actual tiflash store", tiFlashStoreCount), - ) - table.Info.TiFlashReplica = nil - } - } - } - return nil -} - -// PreCheckTableClusterIndex checks whether backup tables and existed tables have different cluster index options。 -func PreCheckTableClusterIndex( - tables []*metautil.Table, - ddlJobs []*model.Job, - dom *domain.Domain, -) error { - for _, table := range tables { - oldTableInfo, err := restore.GetTableSchema(dom, table.DB.Name, table.Info.Name) - // table exists in database - if err == nil { - if table.Info.IsCommonHandle != oldTableInfo.IsCommonHandle { - return errors.Annotatef(berrors.ErrRestoreModeMismatch, - "Clustered index option mismatch. Restored cluster's @@tidb_enable_clustered_index should be %v (backup table = %v, created table = %v).", - restore.TransferBoolToValue(table.Info.IsCommonHandle), - table.Info.IsCommonHandle, - oldTableInfo.IsCommonHandle) - } - } - } - for _, job := range ddlJobs { - if job.Type == model.ActionCreateTable { - tableInfo := job.BinlogInfo.TableInfo - if tableInfo != nil { - oldTableInfo, err := restore.GetTableSchema(dom, model.NewCIStr(job.SchemaName), tableInfo.Name) - // table exists in database - if err == nil { - if tableInfo.IsCommonHandle != oldTableInfo.IsCommonHandle { - return errors.Annotatef(berrors.ErrRestoreModeMismatch, - "Clustered index option mismatch. Restored cluster's @@tidb_enable_clustered_index should be %v (backup table = %v, created table = %v).", - restore.TransferBoolToValue(tableInfo.IsCommonHandle), - tableInfo.IsCommonHandle, - oldTableInfo.IsCommonHandle) - } - } - } - } - } - return nil -} - -func getDatabases(tables []*metautil.Table) (dbs []*model.DBInfo) { - dbIDs := make(map[int64]bool) - for _, table := range tables { - if !dbIDs[table.DB.ID] { - dbs = append(dbs, table.DB) - dbIDs[table.DB.ID] = true - } - } - return -} - -// FilterDDLJobs filters ddl jobs. -func FilterDDLJobs(allDDLJobs []*model.Job, tables []*metautil.Table) (ddlJobs []*model.Job) { - // Sort the ddl jobs by schema version in descending order. - slices.SortFunc(allDDLJobs, func(i, j *model.Job) int { - return cmp.Compare(j.BinlogInfo.SchemaVersion, i.BinlogInfo.SchemaVersion) - }) - dbs := getDatabases(tables) - for _, db := range dbs { - // These maps is for solving some corner case. - // e.g. 
let "t=2" indicates that the id of database "t" is 2, if the ddl execution sequence is: - // rename "a" to "b"(a=1) -> drop "b"(b=1) -> create "b"(b=2) -> rename "b" to "a"(a=2) - // Which we cannot find the "create" DDL by name and id directly. - // To cover †his case, we must find all names and ids the database/table ever had. - dbIDs := make(map[int64]bool) - dbIDs[db.ID] = true - dbNames := make(map[string]bool) - dbNames[db.Name.String()] = true - for _, job := range allDDLJobs { - if job.BinlogInfo.DBInfo != nil { - if dbIDs[job.SchemaID] || dbNames[job.BinlogInfo.DBInfo.Name.String()] { - ddlJobs = append(ddlJobs, job) - // The the jobs executed with the old id, like the step 2 in the example above. - dbIDs[job.SchemaID] = true - // For the jobs executed after rename, like the step 3 in the example above. - dbNames[job.BinlogInfo.DBInfo.Name.String()] = true - } - } - } - } - - for _, table := range tables { - tableIDs := make(map[int64]bool) - tableIDs[table.Info.ID] = true - tableNames := make(map[restore.UniqueTableName]bool) - name := restore.UniqueTableName{DB: table.DB.Name.String(), Table: table.Info.Name.String()} - tableNames[name] = true - for _, job := range allDDLJobs { - if job.BinlogInfo.TableInfo != nil { - name = restore.UniqueTableName{DB: job.SchemaName, Table: job.BinlogInfo.TableInfo.Name.String()} - if tableIDs[job.TableID] || tableNames[name] { - ddlJobs = append(ddlJobs, job) - tableIDs[job.TableID] = true - // For truncate table, the id may be changed - tableIDs[job.BinlogInfo.TableInfo.ID] = true - tableNames[name] = true - } - } - } - } - return ddlJobs -} - -// CheckDDLJobByRules if one of rules returns true, the job in srcDDLJobs will be filtered. -func CheckDDLJobByRules(srcDDLJobs []*model.Job, rules ...DDLJobFilterRule) error { - for _, ddlJob := range srcDDLJobs { - for _, rule := range rules { - if rule(ddlJob) { - return errors.Annotatef(berrors.ErrRestoreModeMismatch, "DDL job %s is not allowed in incremental restore"+ - " when --allow-pitr-from-incremental enabled", ddlJob.String()) - } - } - } - return nil -} - -// FilterDDLJobByRules if one of rules returns true, the job in srcDDLJobs will be filtered. -func FilterDDLJobByRules(srcDDLJobs []*model.Job, rules ...DDLJobFilterRule) (dstDDLJobs []*model.Job) { - dstDDLJobs = make([]*model.Job, 0, len(srcDDLJobs)) - for _, ddlJob := range srcDDLJobs { - passed := true - for _, rule := range rules { - if rule(ddlJob) { - passed = false - break - } - } - - if passed { - dstDDLJobs = append(dstDDLJobs, ddlJob) - } - } - - return -} - -type DDLJobFilterRule func(ddlJob *model.Job) bool - -var incrementalRestoreActionBlockList = map[model.ActionType]struct{}{ - model.ActionSetTiFlashReplica: {}, - model.ActionUpdateTiFlashReplicaStatus: {}, - model.ActionLockTable: {}, - model.ActionUnlockTable: {}, -} - -var logIncrementalRestoreCompactibleBlockList = map[model.ActionType]struct{}{ - model.ActionAddIndex: {}, - model.ActionModifyColumn: {}, - model.ActionReorganizePartition: {}, -} - -// DDLJobBlockListRule rule for filter ddl job with type in block list. 
-func DDLJobBlockListRule(ddlJob *model.Job) bool { - return checkIsInActions(ddlJob.Type, incrementalRestoreActionBlockList) -} - -func DDLJobLogIncrementalCompactBlockListRule(ddlJob *model.Job) bool { - return checkIsInActions(ddlJob.Type, logIncrementalRestoreCompactibleBlockList) -} - -func checkIsInActions(action model.ActionType, actions map[model.ActionType]struct{}) bool { - _, ok := actions[action] - return ok -} diff --git a/br/pkg/task/stream.go b/br/pkg/task/stream.go index cf46760af7677..29e3177df7e0c 100644 --- a/br/pkg/task/stream.go +++ b/br/pkg/task/stream.go @@ -1174,9 +1174,9 @@ func RunStreamRestore( return errors.Trace(err) } - if _, _err_ := failpoint.Eval(_curpkg_("failed-before-full-restore")); _err_ == nil { - return errors.New("failpoint: failed before full restore") - } + failpoint.Inject("failed-before-full-restore", func(_ failpoint.Value) { + failpoint.Return(errors.New("failpoint: failed before full restore")) + }) recorder := tiflashrec.New() cfg.tiflashRecorder = recorder @@ -1469,11 +1469,11 @@ func restoreStream( } } - if _, _err_ := failpoint.Eval(_curpkg_("do-checksum-with-rewrite-rules")); _err_ == nil { + failpoint.Inject("do-checksum-with-rewrite-rules", func(_ failpoint.Value) { if err := client.FailpointDoChecksumForLogRestore(ctx, mgr.GetStorage().GetClient(), mgr.GetPDClient(), idrules, rewriteRules); err != nil { - return errors.Annotate(err, "failed to do checksum") + failpoint.Return(errors.Annotate(err, "failed to do checksum")) } - } + }) gcDisabledRestorable = true diff --git a/br/pkg/task/stream.go__failpoint_stash__ b/br/pkg/task/stream.go__failpoint_stash__ deleted file mode 100644 index 29e3177df7e0c..0000000000000 --- a/br/pkg/task/stream.go__failpoint_stash__ +++ /dev/null @@ -1,1846 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
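For illustration only: DDLJobBlockListRule and checkIsInActions earlier in this patch are plain set-membership checks over DDL action types. A minimal standalone sketch of the same shape, with hypothetical stand-in types that are not part of this patch:

package example

// ActionType stands in for model.ActionType; the real code uses the TiDB enum.
type ActionType int

const (
	ActionAddIndex ActionType = iota
	ActionLockTable
	ActionUnlockTable
)

// blockList is a set: membership means the DDL job must be filtered out.
var blockList = map[ActionType]struct{}{
	ActionLockTable:   {},
	ActionUnlockTable: {},
}

// blocked mirrors checkIsInActions: report whether the action type is listed.
func blocked(a ActionType) bool {
	_, ok := blockList[a]
	return ok
}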
- -package task - -import ( - "bytes" - "context" - "encoding/binary" - "fmt" - "math" - "net/http" - "slices" - "strings" - "sync" - "time" - - "github.com/docker/go-units" - "github.com/fatih/color" - "github.com/opentracing/opentracing-go" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - backuppb "github.com/pingcap/kvproto/pkg/brpb" - "github.com/pingcap/log" - "github.com/pingcap/tidb/br/pkg/backup" - "github.com/pingcap/tidb/br/pkg/checkpoint" - "github.com/pingcap/tidb/br/pkg/conn" - berrors "github.com/pingcap/tidb/br/pkg/errors" - "github.com/pingcap/tidb/br/pkg/glue" - "github.com/pingcap/tidb/br/pkg/httputil" - "github.com/pingcap/tidb/br/pkg/logutil" - "github.com/pingcap/tidb/br/pkg/metautil" - "github.com/pingcap/tidb/br/pkg/restore" - "github.com/pingcap/tidb/br/pkg/restore/ingestrec" - logclient "github.com/pingcap/tidb/br/pkg/restore/log_client" - "github.com/pingcap/tidb/br/pkg/restore/tiflashrec" - restoreutils "github.com/pingcap/tidb/br/pkg/restore/utils" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/br/pkg/stream" - "github.com/pingcap/tidb/br/pkg/streamhelper" - advancercfg "github.com/pingcap/tidb/br/pkg/streamhelper/config" - "github.com/pingcap/tidb/br/pkg/streamhelper/daemon" - "github.com/pingcap/tidb/br/pkg/summary" - "github.com/pingcap/tidb/br/pkg/utils" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/util/cdcutil" - "github.com/spf13/pflag" - "github.com/tikv/client-go/v2/oracle" - clientv3 "go.etcd.io/etcd/client/v3" - "go.uber.org/zap" -) - -const ( - flagYes = "yes" - flagUntil = "until" - flagStreamJSONOutput = "json" - flagStreamTaskName = "task-name" - flagStreamStartTS = "start-ts" - flagStreamEndTS = "end-ts" - flagGCSafePointTTS = "gc-ttl" - - truncateLockPath = "truncating.lock" - hintOnTruncateLock = "There might be another truncate task running, or a truncate task that didn't exit properly. " + - "You may check the metadata and continue by wait other task finish or manually delete the lock file " + truncateLockPath + " at the external storage." -) - -var ( - StreamStart = "log start" - StreamStop = "log stop" - StreamPause = "log pause" - StreamResume = "log resume" - StreamStatus = "log status" - StreamTruncate = "log truncate" - StreamMetadata = "log metadata" - StreamCtl = "log advancer" - - skipSummaryCommandList = map[string]struct{}{ - StreamStatus: {}, - StreamTruncate: {}, - } - - streamShiftDuration = time.Hour -) - -var StreamCommandMap = map[string]func(c context.Context, g glue.Glue, cmdName string, cfg *StreamConfig) error{ - StreamStart: RunStreamStart, - StreamStop: RunStreamStop, - StreamPause: RunStreamPause, - StreamResume: RunStreamResume, - StreamStatus: RunStreamStatus, - StreamTruncate: RunStreamTruncate, - StreamMetadata: RunStreamMetadata, - StreamCtl: RunStreamAdvancer, -} - -// StreamConfig specifies the configure about backup stream -type StreamConfig struct { - Config - - TaskName string `json:"task-name" toml:"task-name"` - - // StartTS usually equals the tso of full-backup, but user can reset it - StartTS uint64 `json:"start-ts" toml:"start-ts"` - EndTS uint64 `json:"end-ts" toml:"end-ts"` - // SafePointTTL ensures TiKV can scan entries not being GC at [startTS, currentTS] - SafePointTTL int64 `json:"safe-point-ttl" toml:"safe-point-ttl"` - - // Spec for the command `truncate`, we should truncate the until when? 
- Until uint64 `json:"until" toml:"until"` - DryRun bool `json:"dry-run" toml:"dry-run"` - SkipPrompt bool `json:"skip-prompt" toml:"skip-prompt"` - - // Spec for the command `status`. - JSONOutput bool `json:"json-output" toml:"json-output"` - - // Spec for the command `advancer`. - AdvancerCfg advancercfg.Config `json:"advancer-config" toml:"advancer-config"` -} - -func (cfg *StreamConfig) makeStorage(ctx context.Context) (storage.ExternalStorage, error) { - u, err := storage.ParseBackend(cfg.Storage, &cfg.BackendOptions) - if err != nil { - return nil, errors.Trace(err) - } - opts := getExternalStorageOptions(&cfg.Config, u) - storage, err := storage.New(ctx, u, &opts) - if err != nil { - return nil, errors.Trace(err) - } - return storage, nil -} - -// DefineStreamStartFlags defines flags used for `stream start` -func DefineStreamStartFlags(flags *pflag.FlagSet) { - DefineStreamCommonFlags(flags) - - flags.String(flagStreamStartTS, "", - "usually equals last full backupTS, used for backup log. Default value is current ts.\n"+ - "support TSO or datetime, e.g. '400036290571534337' or '2018-05-11 01:42:23+0800'.") - // 999999999999999999 means 2090-11-18 22:07:45 - flags.String(flagStreamEndTS, "999999999999999999", "end ts, indicate stopping observe after endTS"+ - "support TSO or datetime") - _ = flags.MarkHidden(flagStreamEndTS) - flags.Int64(flagGCSafePointTTS, utils.DefaultStreamStartSafePointTTL, - "the TTL (in seconds) that PD holds for BR's GC safepoint") - _ = flags.MarkHidden(flagGCSafePointTTS) -} - -func DefineStreamPauseFlags(flags *pflag.FlagSet) { - DefineStreamCommonFlags(flags) - flags.Int64(flagGCSafePointTTS, utils.DefaultStreamPauseSafePointTTL, - "the TTL (in seconds) that PD holds for BR's GC safepoint") -} - -// DefineStreamCommonFlags define common flags for `stream task` -func DefineStreamCommonFlags(flags *pflag.FlagSet) { - flags.String(flagStreamTaskName, "", "The task name for the backup log task.") -} - -func DefineStreamStatusCommonFlags(flags *pflag.FlagSet) { - flags.String(flagStreamTaskName, stream.WildCard, - "The task name for backup stream log. If default, get status of all of tasks", - ) - flags.Bool(flagStreamJSONOutput, false, - "Print JSON as the output.", - ) -} - -func DefineStreamTruncateLogFlags(flags *pflag.FlagSet) { - flags.String(flagUntil, "", "Remove all backup data until this TS."+ - "(support TSO or datetime, e.g. 
'400036290571534337' or '2018-05-11 01:42:23+0800'.)") - flags.Bool(flagDryRun, false, "Run the command but don't really delete the files.") - flags.BoolP(flagYes, "y", false, "Skip all prompts and always execute the command.") -} - -func (cfg *StreamConfig) ParseStreamStatusFromFlags(flags *pflag.FlagSet) error { - var err error - cfg.JSONOutput, err = flags.GetBool(flagStreamJSONOutput) - if err != nil { - return errors.Trace(err) - } - - if err = cfg.ParseStreamCommonFromFlags(flags); err != nil { - return errors.Trace(err) - } - - return nil -} - -func (cfg *StreamConfig) ParseStreamTruncateFromFlags(flags *pflag.FlagSet) error { - tsString, err := flags.GetString(flagUntil) - if err != nil { - return errors.Trace(err) - } - if cfg.Until, err = ParseTSString(tsString, true); err != nil { - return errors.Trace(err) - } - if cfg.SkipPrompt, err = flags.GetBool(flagYes); err != nil { - return errors.Trace(err) - } - if cfg.DryRun, err = flags.GetBool(flagDryRun); err != nil { - return errors.Trace(err) - } - return nil -} - -// ParseStreamStartFromFlags parse parameters for `stream start` -func (cfg *StreamConfig) ParseStreamStartFromFlags(flags *pflag.FlagSet) error { - err := cfg.ParseStreamCommonFromFlags(flags) - if err != nil { - return errors.Trace(err) - } - - tsString, err := flags.GetString(flagStreamStartTS) - if err != nil { - return errors.Trace(err) - } - - if cfg.StartTS, err = ParseTSString(tsString, true); err != nil { - return errors.Trace(err) - } - - tsString, err = flags.GetString(flagStreamEndTS) - if err != nil { - return errors.Trace(err) - } - - if cfg.EndTS, err = ParseTSString(tsString, true); err != nil { - return errors.Trace(err) - } - - if cfg.SafePointTTL, err = flags.GetInt64(flagGCSafePointTTS); err != nil { - return errors.Trace(err) - } - - if cfg.SafePointTTL <= 0 { - cfg.SafePointTTL = utils.DefaultStreamStartSafePointTTL - } - - return nil -} - -// ParseStreamPauseFromFlags parse parameters for `stream pause` -func (cfg *StreamConfig) ParseStreamPauseFromFlags(flags *pflag.FlagSet) error { - err := cfg.ParseStreamCommonFromFlags(flags) - if err != nil { - return errors.Trace(err) - } - - if cfg.SafePointTTL, err = flags.GetInt64(flagGCSafePointTTS); err != nil { - return errors.Trace(err) - } - if cfg.SafePointTTL <= 0 { - cfg.SafePointTTL = utils.DefaultStreamPauseSafePointTTL - } - return nil -} - -// ParseStreamCommonFromFlags parse parameters for `stream task` -func (cfg *StreamConfig) ParseStreamCommonFromFlags(flags *pflag.FlagSet) error { - var err error - - cfg.TaskName, err = flags.GetString(flagStreamTaskName) - if err != nil { - return errors.Trace(err) - } - - if len(cfg.TaskName) <= 0 { - return errors.Annotate(berrors.ErrInvalidArgument, "Miss parameters task-name") - } - return nil -} - -type streamMgr struct { - cfg *StreamConfig - mgr *conn.Mgr - bc *backup.Client - httpCli *http.Client -} - -func NewStreamMgr(ctx context.Context, cfg *StreamConfig, g glue.Glue, isStreamStart bool) (*streamMgr, error) { - mgr, err := NewMgr(ctx, g, cfg.PD, cfg.TLS, GetKeepalive(&cfg.Config), - cfg.CheckRequirements, false, conn.StreamVersionChecker) - if err != nil { - return nil, errors.Trace(err) - } - defer func() { - if err != nil { - mgr.Close() - } - }() - - // just stream start need Storage - s := &streamMgr{ - cfg: cfg, - mgr: mgr, - } - if isStreamStart { - client := backup.NewBackupClient(ctx, mgr) - - backend, err := storage.ParseBackend(cfg.Storage, &cfg.BackendOptions) - if err != nil { - return nil, errors.Trace(err) - } - - opts := 
storage.ExternalStorageOptions{ - NoCredentials: cfg.NoCreds, - SendCredentials: cfg.SendCreds, - CheckS3ObjectLockOptions: true, - } - if err = client.SetStorage(ctx, backend, &opts); err != nil { - return nil, errors.Trace(err) - } - s.bc = client - - // create http client to do some requirements check. - s.httpCli = httputil.NewClient(mgr.GetTLSConfig()) - } - return s, nil -} - -func (s *streamMgr) close() { - s.mgr.Close() -} - -func (s *streamMgr) checkLock(ctx context.Context) (bool, error) { - return s.bc.GetStorage().FileExists(ctx, metautil.LockFile) -} - -func (s *streamMgr) setLock(ctx context.Context) error { - return s.bc.SetLockFile(ctx) -} - -// adjustAndCheckStartTS checks that startTS should be smaller than currentTS, -// and endTS is larger than currentTS. -func (s *streamMgr) adjustAndCheckStartTS(ctx context.Context) error { - currentTS, err := s.mgr.GetTS(ctx) - if err != nil { - return errors.Trace(err) - } - // set currentTS to startTS as a default value - if s.cfg.StartTS == 0 { - s.cfg.StartTS = currentTS - } - - if currentTS < s.cfg.StartTS { - return errors.Annotatef(berrors.ErrInvalidArgument, - "invalid timestamps, startTS %d should be smaller than currentTS %d", - s.cfg.StartTS, currentTS) - } - if s.cfg.EndTS <= currentTS { - return errors.Annotatef(berrors.ErrInvalidArgument, - "invalid timestamps, endTS %d should be larger than currentTS %d", - s.cfg.EndTS, currentTS) - } - - return nil -} - -// checkImportTaskRunning checks whether there is any import task running. -func (s *streamMgr) checkImportTaskRunning(ctx context.Context, etcdCLI *clientv3.Client) error { - list, err := utils.GetImportTasksFrom(ctx, etcdCLI) - if err != nil { - return errors.Trace(err) - } - if !list.Empty() { - return errors.Errorf("There are some lightning/restore tasks running: %s"+ - "please stop or wait finishing at first. "+ - "If the lightning/restore task is forced to terminate by system, "+ - "please wait for ttl to decrease to 0.", list.MessageToUser()) - } - return nil -} - -// setGCSafePoint sets the server safe point to PD. -func (s *streamMgr) setGCSafePoint(ctx context.Context, sp utils.BRServiceSafePoint) error { - err := utils.CheckGCSafePoint(ctx, s.mgr.GetPDClient(), sp.BackupTS) - if err != nil { - return errors.Annotatef(err, - "failed to check gc safePoint, ts %v", sp.BackupTS) - } - - err = utils.UpdateServiceSafePoint(ctx, s.mgr.GetPDClient(), sp) - if err != nil { - return errors.Trace(err) - } - - log.Info("set stream safePoint", zap.Object("safePoint", sp)) - return nil -} - -func (s *streamMgr) buildObserveRanges() ([]kv.KeyRange, error) { - dRanges, err := stream.BuildObserveDataRanges( - s.mgr.GetStorage(), - s.cfg.FilterStr, - s.cfg.TableFilter, - s.cfg.StartTS, - ) - if err != nil { - return nil, errors.Trace(err) - } - - mRange := stream.BuildObserveMetaRange() - rs := append([]kv.KeyRange{*mRange}, dRanges...) 
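For illustration only: the slice built above is sorted by start key in the SortFunc call that follows. A minimal standalone sketch of that comparator pattern using only the standard library; keyRange is a hypothetical stand-in for kv.KeyRange:

package main

import (
	"bytes"
	"fmt"
	"slices"
)

// keyRange is a stand-in for kv.KeyRange: a half-open [StartKey, EndKey) span.
type keyRange struct {
	StartKey, EndKey []byte
}

func main() {
	rs := []keyRange{
		{StartKey: []byte("t\x80"), EndKey: []byte("t\x81")},
		{StartKey: []byte("m"), EndKey: []byte("n")}, // the meta range sorts first
	}
	// Sort lexicographically by StartKey, mirroring the SortFunc call that follows.
	slices.SortFunc(rs, func(a, b keyRange) int {
		return bytes.Compare(a.StartKey, b.StartKey)
	})
	for _, r := range rs {
		fmt.Printf("%q -> %q\n", r.StartKey, r.EndKey)
	}
}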
- slices.SortFunc(rs, func(i, j kv.KeyRange) int { - return bytes.Compare(i.StartKey, j.StartKey) - }) - - return rs, nil -} - -func (s *streamMgr) backupFullSchemas(ctx context.Context) error { - clusterVersion, err := s.mgr.GetClusterVersion(ctx) - if err != nil { - return errors.Trace(err) - } - - metaWriter := metautil.NewMetaWriter(s.bc.GetStorage(), metautil.MetaFileSize, true, metautil.MetaFile, nil) - metaWriter.Update(func(m *backuppb.BackupMeta) { - // save log startTS to backupmeta file - m.StartVersion = s.cfg.StartTS - m.ClusterId = s.bc.GetClusterID() - m.ClusterVersion = clusterVersion - }) - - if err = metaWriter.FlushBackupMeta(ctx); err != nil { - return errors.Trace(err) - } - return nil -} - -func (s *streamMgr) checkStreamStartEnable(ctx context.Context) error { - supportStream, err := s.mgr.IsLogBackupEnabled(ctx, s.httpCli) - if err != nil { - return errors.Trace(err) - } - if !supportStream { - return errors.New("Unable to create task about log-backup. " + - "please set TiKV config `log-backup.enable` to true and restart TiKVs.") - } - - return nil -} - -type RestoreFunc func(string) error - -// KeepGcDisabled keeps GC disabled and return a function that used to gc enabled. -// gc.ratio-threshold = "-1.0", which represents disable gc in TiKV. -func KeepGcDisabled(g glue.Glue, store kv.Storage) (RestoreFunc, string, error) { - se, err := g.CreateSession(store) - if err != nil { - return nil, "", errors.Trace(err) - } - - execCtx := se.GetSessionCtx().GetRestrictedSQLExecutor() - oldRatio, err := utils.GetGcRatio(execCtx) - if err != nil { - return nil, "", errors.Trace(err) - } - - newRatio := "-1.0" - err = utils.SetGcRatio(execCtx, newRatio) - if err != nil { - return nil, "", errors.Trace(err) - } - - // If the oldRatio is negative, which is not normal status. - // It should set default value "1.1" after PiTR finished. 
- if strings.HasPrefix(oldRatio, "-") { - oldRatio = utils.DefaultGcRatioVal - } - - return func(ratio string) error { - return utils.SetGcRatio(execCtx, ratio) - }, oldRatio, nil -} - -// RunStreamCommand run all kinds of `stream task` -func RunStreamCommand( - ctx context.Context, - g glue.Glue, - cmdName string, - cfg *StreamConfig, -) error { - cfg.Config.adjust() - defer func() { - if _, ok := skipSummaryCommandList[cmdName]; !ok { - summary.Summary(cmdName) - } - }() - commandFn, exist := StreamCommandMap[cmdName] - if !exist { - return errors.Annotatef(berrors.ErrInvalidArgument, "invalid command %s", cmdName) - } - - if err := commandFn(ctx, g, cmdName, cfg); err != nil { - log.Error("failed to stream", zap.String("command", cmdName), zap.Error(err)) - summary.SetSuccessStatus(false) - summary.CollectFailureUnit(cmdName, err) - return err - } - summary.SetSuccessStatus(true) - return nil -} - -// RunStreamStart specifies starting a stream task -func RunStreamStart( - c context.Context, - g glue.Glue, - cmdName string, - cfg *StreamConfig, -) error { - ctx, cancelFn := context.WithCancel(c) - defer cancelFn() - - if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { - span1 := span.Tracer().StartSpan("task.RunStreamStart", opentracing.ChildOf(span.Context())) - defer span1.Finish() - ctx = opentracing.ContextWithSpan(ctx, span1) - } - - streamMgr, err := NewStreamMgr(ctx, cfg, g, true) - if err != nil { - return errors.Trace(err) - } - defer streamMgr.close() - - if err = streamMgr.checkStreamStartEnable(ctx); err != nil { - return errors.Trace(err) - } - if err = streamMgr.adjustAndCheckStartTS(ctx); err != nil { - return errors.Trace(err) - } - - etcdCLI, err := dialEtcdWithCfg(ctx, cfg.Config) - if err != nil { - return errors.Trace(err) - } - cli := streamhelper.NewMetaDataClient(etcdCLI) - defer func() { - if closeErr := cli.Close(); closeErr != nil { - log.Warn("failed to close etcd client", zap.Error(closeErr)) - } - }() - if err = streamMgr.checkImportTaskRunning(ctx, cli.Client); err != nil { - return errors.Trace(err) - } - // It supports single stream log task currently. - if count, err := cli.GetTaskCount(ctx); err != nil { - return errors.Trace(err) - } else if count > 0 { - return errors.Annotate(berrors.ErrStreamLogTaskExist, "It supports single stream log task currently") - } - - exist, err := streamMgr.checkLock(ctx) - if err != nil { - return errors.Trace(err) - } - // exist is true, which represents restart a stream task. Or create a new stream task. 
- if exist { - logInfo, err := getLogRange(ctx, &cfg.Config) - if err != nil { - return errors.Trace(err) - } - if logInfo.clusterID > 0 && logInfo.clusterID != streamMgr.bc.GetClusterID() { - return errors.Annotatef(berrors.ErrInvalidArgument, - "the stream log files from cluster ID:%v and current cluster ID:%v ", - logInfo.clusterID, streamMgr.bc.GetClusterID()) - } - - cfg.StartTS = logInfo.logMaxTS - if err = streamMgr.setGCSafePoint( - ctx, - utils.BRServiceSafePoint{ - ID: utils.MakeSafePointID(), - TTL: cfg.SafePointTTL, - BackupTS: cfg.StartTS, - }, - ); err != nil { - return errors.Trace(err) - } - } else { - if err = streamMgr.setGCSafePoint( - ctx, - utils.BRServiceSafePoint{ - ID: utils.MakeSafePointID(), - TTL: cfg.SafePointTTL, - BackupTS: cfg.StartTS, - }, - ); err != nil { - return errors.Trace(err) - } - if err = streamMgr.setLock(ctx); err != nil { - return errors.Trace(err) - } - if err = streamMgr.backupFullSchemas(ctx); err != nil { - return errors.Trace(err) - } - } - - ranges, err := streamMgr.buildObserveRanges() - if err != nil { - return errors.Trace(err) - } else if len(ranges) == 0 { - // nothing to backup - pdAddress := strings.Join(cfg.PD, ",") - log.Warn("Nothing to observe, maybe connected to cluster for restoring", - zap.String("PD address", pdAddress)) - return errors.Annotate(berrors.ErrInvalidArgument, "nothing need to observe") - } - - ti := streamhelper.TaskInfo{ - PBInfo: backuppb.StreamBackupTaskInfo{ - Storage: streamMgr.bc.GetStorageBackend(), - StartTs: cfg.StartTS, - EndTs: cfg.EndTS, - Name: cfg.TaskName, - TableFilter: cfg.FilterStr, - CompressionType: backuppb.CompressionType_ZSTD, - }, - Ranges: ranges, - Pausing: false, - } - if err = cli.PutTask(ctx, ti); err != nil { - return errors.Trace(err) - } - summary.Log(cmdName, ti.ZapTaskInfo()...) 
- return nil -} - -func RunStreamMetadata( - c context.Context, - g glue.Glue, - cmdName string, - cfg *StreamConfig, -) error { - ctx, cancelFn := context.WithCancel(c) - defer cancelFn() - - if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { - span1 := span.Tracer().StartSpan( - "task.RunStreamCheckLog", - opentracing.ChildOf(span.Context()), - ) - defer span1.Finish() - ctx = opentracing.ContextWithSpan(ctx, span1) - } - - logInfo, err := getLogRange(ctx, &cfg.Config) - if err != nil { - return errors.Trace(err) - } - - logMinDate := stream.FormatDate(oracle.GetTimeFromTS(logInfo.logMinTS)) - logMaxDate := stream.FormatDate(oracle.GetTimeFromTS(logInfo.logMaxTS)) - summary.Log(cmdName, zap.Uint64("log-min-ts", logInfo.logMinTS), - zap.String("log-min-date", logMinDate), - zap.Uint64("log-max-ts", logInfo.logMaxTS), - zap.String("log-max-date", logMaxDate), - ) - return nil -} - -// RunStreamStop specifies stoping a stream task -func RunStreamStop( - c context.Context, - g glue.Glue, - cmdName string, - cfg *StreamConfig, -) error { - ctx, cancelFn := context.WithCancel(c) - defer cancelFn() - - if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { - span1 := span.Tracer().StartSpan( - "task.RunStreamStop", - opentracing.ChildOf(span.Context()), - ) - defer span1.Finish() - ctx = opentracing.ContextWithSpan(ctx, span1) - } - - streamMgr, err := NewStreamMgr(ctx, cfg, g, false) - if err != nil { - return errors.Trace(err) - } - defer streamMgr.close() - - etcdCLI, err := dialEtcdWithCfg(ctx, cfg.Config) - if err != nil { - return errors.Trace(err) - } - cli := streamhelper.NewMetaDataClient(etcdCLI) - defer func() { - if closeErr := cli.Close(); closeErr != nil { - log.Warn("failed to close etcd client", zap.Error(closeErr)) - } - }() - // to add backoff - ti, err := cli.GetTask(ctx, cfg.TaskName) - if err != nil { - return errors.Trace(err) - } - - if err = cli.DeleteTask(ctx, cfg.TaskName); err != nil { - return errors.Trace(err) - } - - if err := streamMgr.setGCSafePoint(ctx, - utils.BRServiceSafePoint{ - ID: buildPauseSafePointName(ti.Info.Name), - TTL: 0, // 0 means remove this service safe point. - BackupTS: math.MaxUint64, - }, - ); err != nil { - log.Warn("failed to remove safe point", zap.String("error", err.Error())) - } - - summary.Log(cmdName, logutil.StreamBackupTaskInfo(&ti.Info)) - return nil -} - -// RunStreamPause specifies pausing a stream task. 
-func RunStreamPause( - c context.Context, - g glue.Glue, - cmdName string, - cfg *StreamConfig, -) error { - ctx, cancelFn := context.WithCancel(c) - defer cancelFn() - - if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { - span1 := span.Tracer().StartSpan( - "task.RunStreamPause", - opentracing.ChildOf(span.Context()), - ) - defer span1.Finish() - ctx = opentracing.ContextWithSpan(ctx, span1) - } - - streamMgr, err := NewStreamMgr(ctx, cfg, g, false) - if err != nil { - return errors.Trace(err) - } - defer streamMgr.close() - - etcdCLI, err := dialEtcdWithCfg(ctx, cfg.Config) - if err != nil { - return errors.Trace(err) - } - cli := streamhelper.NewMetaDataClient(etcdCLI) - defer func() { - if closeErr := cli.Close(); closeErr != nil { - log.Warn("failed to close etcd client", zap.Error(closeErr)) - } - }() - // to add backoff - ti, isPaused, err := cli.GetTaskWithPauseStatus(ctx, cfg.TaskName) - if err != nil { - return errors.Trace(err) - } else if isPaused { - return errors.Annotatef(berrors.ErrKVUnknown, "The task %s is paused already.", cfg.TaskName) - } - - globalCheckPointTS, err := ti.GetGlobalCheckPointTS(ctx) - if err != nil { - return errors.Trace(err) - } - if err = streamMgr.setGCSafePoint( - ctx, - utils.BRServiceSafePoint{ - ID: buildPauseSafePointName(ti.Info.Name), - TTL: cfg.SafePointTTL, - BackupTS: globalCheckPointTS, - }, - ); err != nil { - return errors.Trace(err) - } - - err = cli.PauseTask(ctx, cfg.TaskName) - if err != nil { - return errors.Trace(err) - } - - summary.Log(cmdName, logutil.StreamBackupTaskInfo(&ti.Info)) - return nil -} - -// RunStreamResume specifies resuming a stream task. -func RunStreamResume( - c context.Context, - g glue.Glue, - cmdName string, - cfg *StreamConfig, -) error { - ctx, cancelFn := context.WithCancel(c) - defer cancelFn() - - if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { - span1 := span.Tracer().StartSpan( - "task.RunStreamResume", - opentracing.ChildOf(span.Context()), - ) - defer span1.Finish() - ctx = opentracing.ContextWithSpan(ctx, span1) - } - - streamMgr, err := NewStreamMgr(ctx, cfg, g, false) - if err != nil { - return errors.Trace(err) - } - defer streamMgr.close() - - etcdCLI, err := dialEtcdWithCfg(ctx, cfg.Config) - if err != nil { - return errors.Trace(err) - } - cli := streamhelper.NewMetaDataClient(etcdCLI) - defer func() { - if closeErr := cli.Close(); closeErr != nil { - log.Warn("failed to close etcd client", zap.Error(closeErr)) - } - }() - // to add backoff - ti, isPaused, err := cli.GetTaskWithPauseStatus(ctx, cfg.TaskName) - if err != nil { - return errors.Trace(err) - } else if !isPaused { - return errors.Annotatef(berrors.ErrKVUnknown, - "The task %s is active already.", cfg.TaskName) - } - - globalCheckPointTS, err := ti.GetGlobalCheckPointTS(ctx) - if err != nil { - return errors.Trace(err) - } - err = utils.CheckGCSafePoint(ctx, streamMgr.mgr.GetPDClient(), globalCheckPointTS) - if err != nil { - return errors.Annotatef(err, "the global checkpoint ts: %v(%s) has been gc. 
", - globalCheckPointTS, oracle.GetTimeFromTS(globalCheckPointTS)) - } - - err = cli.ResumeTask(ctx, cfg.TaskName) - if err != nil { - return errors.Trace(err) - } - - err = cli.CleanLastErrorOfTask(ctx, cfg.TaskName) - if err != nil { - return err - } - - if err := streamMgr.setGCSafePoint(ctx, - utils.BRServiceSafePoint{ - ID: buildPauseSafePointName(ti.Info.Name), - TTL: utils.DefaultStreamStartSafePointTTL, - BackupTS: globalCheckPointTS, - }, - ); err != nil { - log.Warn("failed to remove safe point", - zap.Uint64("safe-point", globalCheckPointTS), zap.String("error", err.Error())) - } - - summary.Log(cmdName, logutil.StreamBackupTaskInfo(&ti.Info)) - return nil -} - -func RunStreamAdvancer(c context.Context, g glue.Glue, cmdName string, cfg *StreamConfig) error { - ctx, cancel := context.WithCancel(c) - defer cancel() - mgr, err := NewMgr(ctx, g, cfg.PD, cfg.TLS, GetKeepalive(&cfg.Config), - cfg.CheckRequirements, false, conn.StreamVersionChecker) - if err != nil { - return err - } - - etcdCLI, err := dialEtcdWithCfg(ctx, cfg.Config) - if err != nil { - return err - } - env := streamhelper.CliEnv(mgr.StoreManager, mgr.GetStore(), etcdCLI) - advancer := streamhelper.NewCheckpointAdvancer(env) - advancer.UpdateConfig(cfg.AdvancerCfg) - advancerd := daemon.New(advancer, streamhelper.OwnerManagerForLogBackup(ctx, etcdCLI), cfg.AdvancerCfg.TickDuration) - loop, err := advancerd.Begin(ctx) - if err != nil { - return err - } - loop() - return nil -} - -func checkConfigForStatus(pd []string) error { - if len(pd) == 0 { - return errors.Annotatef(berrors.ErrInvalidArgument, - "the command needs access to PD, please specify `-u` or `--pd`") - } - - return nil -} - -// makeStatusController makes the status controller via some config. -// this should better be in the `stream` package but it is impossible because of cyclic requirements. 
-func makeStatusController(ctx context.Context, cfg *StreamConfig, g glue.Glue) (*stream.StatusController, error) { - console := glue.GetConsole(g) - etcdCLI, err := dialEtcdWithCfg(ctx, cfg.Config) - if err != nil { - return nil, err - } - cli := streamhelper.NewMetaDataClient(etcdCLI) - var printer stream.TaskPrinter - if !cfg.JSONOutput { - printer = stream.PrintTaskByTable(console) - } else { - printer = stream.PrintTaskWithJSON(console) - } - mgr, err := NewMgr(ctx, g, cfg.PD, cfg.TLS, GetKeepalive(&cfg.Config), - cfg.CheckRequirements, false, conn.StreamVersionChecker) - if err != nil { - return nil, err - } - return stream.NewStatusController(cli, mgr, printer), nil -} - -// RunStreamStatus get status for a specific stream task -func RunStreamStatus( - c context.Context, - g glue.Glue, - cmdName string, - cfg *StreamConfig, -) error { - ctx, cancelFn := context.WithCancel(c) - defer cancelFn() - - if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { - span1 := span.Tracer().StartSpan( - "task.RunStreamStatus", - opentracing.ChildOf(span.Context()), - ) - defer span1.Finish() - ctx = opentracing.ContextWithSpan(ctx, span1) - } - - if err := checkConfigForStatus(cfg.PD); err != nil { - return err - } - ctl, err := makeStatusController(ctx, cfg, g) - if err != nil { - return err - } - - defer func() { - if closeErr := ctl.Close(); closeErr != nil { - log.Warn("failed to close etcd client", zap.Error(closeErr)) - } - }() - return ctl.PrintStatusOfTask(ctx, cfg.TaskName) -} - -// RunStreamTruncate truncates the log that belong to (0, until-ts) -func RunStreamTruncate(c context.Context, g glue.Glue, cmdName string, cfg *StreamConfig) (err error) { - console := glue.GetConsole(g) - em := color.New(color.Bold).SprintFunc() - warn := color.New(color.Bold, color.FgHiRed).SprintFunc() - formatTS := func(ts uint64) string { - return oracle.GetTimeFromTS(ts).Format("2006-01-02 15:04:05.0000") - } - if cfg.Until == 0 { - return errors.Annotatef(berrors.ErrInvalidArgument, "please provide the `--until` ts") - } - - ctx, cancelFn := context.WithCancel(c) - defer cancelFn() - - extStorage, err := cfg.makeStorage(ctx) - if err != nil { - return err - } - if err := storage.TryLockRemote(ctx, extStorage, truncateLockPath, hintOnTruncateLock); err != nil { - return err - } - defer utils.WithCleanUp(&err, 10*time.Second, func(ctx context.Context) error { - return storage.UnlockRemote(ctx, extStorage, truncateLockPath) - }) - - sp, err := stream.GetTSFromFile(ctx, extStorage, stream.TruncateSafePointFileName) - if err != nil { - return err - } - - if cfg.Until < sp { - console.Println("According to the log, you have truncated backup data before", em(formatTS(sp))) - if !cfg.SkipPrompt && !console.PromptBool("Continue? ") { - return nil - } - } - - readMetaDone := console.ShowTask("Reading Metadata... 
", glue.WithTimeCost()) - metas := stream.StreamMetadataSet{ - MetadataDownloadBatchSize: cfg.MetadataDownloadBatchSize, - Helper: stream.NewMetadataHelper(), - DryRun: cfg.DryRun, - } - shiftUntilTS, err := metas.LoadUntilAndCalculateShiftTS(ctx, extStorage, cfg.Until) - if err != nil { - return err - } - readMetaDone() - - var ( - fileCount int = 0 - kvCount int64 = 0 - totalSize uint64 = 0 - ) - - metas.IterateFilesFullyBefore(shiftUntilTS, func(d *stream.FileGroupInfo) (shouldBreak bool) { - fileCount++ - totalSize += d.Length - kvCount += d.KVCount - return - }) - console.Printf("We are going to remove %s files, until %s.\n", - em(fileCount), - em(formatTS(cfg.Until)), - ) - if !cfg.SkipPrompt && !console.PromptBool(warn("Sure? ")) { - return nil - } - - if cfg.Until > sp && !cfg.DryRun { - if err := stream.SetTSToFile( - ctx, extStorage, cfg.Until, stream.TruncateSafePointFileName); err != nil { - return err - } - } - - // begin to remove - p := console.StartProgressBar( - "Clearing Data Files and Metadata", fileCount, - glue.WithTimeCost(), - glue.WithConstExtraField("kv-count", kvCount), - glue.WithConstExtraField("kv-size", fmt.Sprintf("%d(%s)", totalSize, units.HumanSize(float64(totalSize)))), - ) - defer p.Close() - - notDeleted, err := metas.RemoveDataFilesAndUpdateMetadataInBatch(ctx, shiftUntilTS, extStorage, p.IncBy) - if err != nil { - return err - } - - if err := p.Wait(ctx); err != nil { - return err - } - - if len(notDeleted) > 0 { - const keepFirstNFailure = 16 - console.Println("Files below are not deleted due to error, you may clear it manually, check log for detail error:") - console.Println("- Total", em(len(notDeleted)), "items.") - if len(notDeleted) > keepFirstNFailure { - console.Println("-", em(len(notDeleted)-keepFirstNFailure), "items omitted.") - // TODO: maybe don't add them at the very first. - notDeleted = notDeleted[:keepFirstNFailure] - } - for _, f := range notDeleted { - console.Println(f) - } - } - - return nil -} - -// checkTaskExists checks whether there is a log backup task running. -// If so, return an error. -func checkTaskExists(ctx context.Context, cfg *RestoreConfig, etcdCLI *clientv3.Client) error { - if err := checkConfigForStatus(cfg.PD); err != nil { - return err - } - - cli := streamhelper.NewMetaDataClient(etcdCLI) - // check log backup task - tasks, err := cli.GetAllTasks(ctx) - if err != nil { - return err - } - if len(tasks) > 0 { - return errors.Errorf("log backup task is running: %s, "+ - "please stop the task before restore, and after PITR operation finished, "+ - "create log-backup task again and create a full backup on this cluster", tasks[0].Info.Name) - } - - return nil -} - -func checkIncompatibleChangefeed(ctx context.Context, backupTS uint64, etcdCLI *clientv3.Client) error { - nameSet, err := cdcutil.GetIncompatibleChangefeedsWithSafeTS(ctx, etcdCLI, backupTS) - if err != nil { - return err - } - if !nameSet.Empty() { - return errors.Errorf("%splease remove changefeed(s) before restore", nameSet.MessageToUser()) - } - return nil -} - -// RunStreamRestore restores stream log. 
-func RunStreamRestore( - c context.Context, - g glue.Glue, - cmdName string, - cfg *RestoreConfig, -) (err error) { - ctx, cancelFn := context.WithCancel(c) - defer cancelFn() - - if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { - span1 := span.Tracer().StartSpan("task.RunStreamRestore", opentracing.ChildOf(span.Context())) - defer span1.Finish() - ctx = opentracing.ContextWithSpan(ctx, span1) - } - _, s, err := GetStorage(ctx, cfg.Config.Storage, &cfg.Config) - if err != nil { - return errors.Trace(err) - } - logInfo, err := getLogRangeWithStorage(ctx, s) - if err != nil { - return errors.Trace(err) - } - if cfg.RestoreTS == 0 { - cfg.RestoreTS = logInfo.logMaxTS - } - - if len(cfg.FullBackupStorage) > 0 { - startTS, fullClusterID, err := getFullBackupTS(ctx, cfg) - if err != nil { - return errors.Trace(err) - } - if logInfo.clusterID > 0 && fullClusterID > 0 && logInfo.clusterID != fullClusterID { - return errors.Annotatef(berrors.ErrInvalidArgument, - "the full snapshot(from cluster ID:%v) and log(from cluster ID:%v) come from different cluster.", - fullClusterID, logInfo.clusterID) - } - - cfg.StartTS = startTS - if cfg.StartTS < logInfo.logMinTS { - return errors.Annotatef(berrors.ErrInvalidArgument, - "it has gap between full backup ts:%d(%s) and log backup ts:%d(%s). ", - cfg.StartTS, oracle.GetTimeFromTS(cfg.StartTS), - logInfo.logMinTS, oracle.GetTimeFromTS(logInfo.logMinTS)) - } - } - - log.Info("start restore on point", - zap.Uint64("restore-from", cfg.StartTS), zap.Uint64("restore-to", cfg.RestoreTS), - zap.Uint64("log-min-ts", logInfo.logMinTS), zap.Uint64("log-max-ts", logInfo.logMaxTS)) - if err := checkLogRange(cfg.StartTS, cfg.RestoreTS, logInfo.logMinTS, logInfo.logMaxTS); err != nil { - return errors.Trace(err) - } - - checkInfo, err := checkPiTRTaskInfo(ctx, g, s, cfg) - if err != nil { - return errors.Trace(err) - } - - failpoint.Inject("failed-before-full-restore", func(_ failpoint.Value) { - failpoint.Return(errors.New("failpoint: failed before full restore")) - }) - - recorder := tiflashrec.New() - cfg.tiflashRecorder = recorder - // restore full snapshot. - if checkInfo.NeedFullRestore { - logStorage := cfg.Config.Storage - cfg.Config.Storage = cfg.FullBackupStorage - // TiFlash replica is restored to down-stream on 'pitr' currently. - if err = runRestore(ctx, g, FullRestoreCmd, cfg, checkInfo); err != nil { - return errors.Trace(err) - } - cfg.Config.Storage = logStorage - } else if len(cfg.FullBackupStorage) > 0 { - skipMsg := []byte(fmt.Sprintf("%s command is skipped due to checkpoint mode for restore\n", FullRestoreCmd)) - if _, err := glue.GetConsole(g).Out().Write(skipMsg); err != nil { - return errors.Trace(err) - } - if checkInfo.CheckpointInfo != nil && checkInfo.CheckpointInfo.TiFlashItems != nil { - log.Info("load tiflash records of snapshot restore from checkpoint") - if err != nil { - return errors.Trace(err) - } - cfg.tiflashRecorder.Load(checkInfo.CheckpointInfo.TiFlashItems) - } - } - // restore log. 
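For illustration only: the timestamp checks above (the full-backup gap check plus checkLogRange) essentially reduce to one ordering constraint, logMinTS <= startTS <= restoreTS <= logMaxTS. A minimal standalone sketch of that validation, with hypothetical names that are not part of this patch:

package example

import "fmt"

// checkWindow mirrors the constraint enforced above: the restore window
// [startTS, restoreTS] must lie inside the available log range
// [logMinTS, logMaxTS].
func checkWindow(startTS, restoreTS, logMinTS, logMaxTS uint64) error {
	if logMinTS > startTS || startTS > restoreTS || restoreTS > logMaxTS {
		return fmt.Errorf("restore window [%d, %d] is outside log range [%d, %d]",
			startTS, restoreTS, logMinTS, logMaxTS)
	}
	return nil
}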
- cfg.adjustRestoreConfigForStreamRestore() - if err := restoreStream(ctx, g, cfg, checkInfo.CheckpointInfo); err != nil { - return errors.Trace(err) - } - return nil -} - -// RunStreamRestore start restore job -func restoreStream( - c context.Context, - g glue.Glue, - cfg *RestoreConfig, - taskInfo *checkpoint.CheckpointTaskInfoForLogRestore, -) (err error) { - var ( - totalKVCount uint64 - totalSize uint64 - checkpointTotalKVCount uint64 - checkpointTotalSize uint64 - currentTS uint64 - mu sync.Mutex - startTime = time.Now() - ) - defer func() { - if err != nil { - summary.Log("restore log failed summary", zap.Error(err)) - } else { - totalDureTime := time.Since(startTime) - summary.Log("restore log success summary", zap.Duration("total-take", totalDureTime), - zap.Uint64("source-start-point", cfg.StartTS), - zap.Uint64("source-end-point", cfg.RestoreTS), - zap.Uint64("target-end-point", currentTS), - zap.String("source-start", stream.FormatDate(oracle.GetTimeFromTS(cfg.StartTS))), - zap.String("source-end", stream.FormatDate(oracle.GetTimeFromTS(cfg.RestoreTS))), - zap.String("target-end", stream.FormatDate(oracle.GetTimeFromTS(currentTS))), - zap.Uint64("total-kv-count", totalKVCount), - zap.Uint64("skipped-kv-count-by-checkpoint", checkpointTotalKVCount), - zap.String("total-size", units.HumanSize(float64(totalSize))), - zap.String("skipped-size-by-checkpoint", units.HumanSize(float64(checkpointTotalSize))), - zap.String("average-speed", units.HumanSize(float64(totalSize)/totalDureTime.Seconds())+"/s"), - ) - } - }() - - ctx, cancelFn := context.WithCancel(c) - defer cancelFn() - - if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { - span1 := span.Tracer().StartSpan( - "restoreStream", - opentracing.ChildOf(span.Context()), - ) - defer span1.Finish() - ctx = opentracing.ContextWithSpan(ctx, span1) - } - - mgr, err := NewMgr(ctx, g, cfg.PD, cfg.TLS, GetKeepalive(&cfg.Config), - cfg.CheckRequirements, true, conn.StreamVersionChecker) - if err != nil { - return errors.Trace(err) - } - defer mgr.Close() - - client, err := createRestoreClient(ctx, g, cfg, mgr) - if err != nil { - return errors.Annotate(err, "failed to create restore client") - } - defer client.Close() - - if taskInfo != nil && taskInfo.RewriteTS > 0 { - // reuse the task's rewrite ts - log.Info("reuse the task's rewrite ts", zap.Uint64("rewrite-ts", taskInfo.RewriteTS)) - currentTS = taskInfo.RewriteTS - } else { - currentTS, err = restore.GetTSWithRetry(ctx, mgr.GetPDClient()) - if err != nil { - return errors.Trace(err) - } - } - client.SetCurrentTS(currentTS) - - importModeSwitcher := restore.NewImportModeSwitcher(mgr.GetPDClient(), cfg.Config.SwitchModeInterval, mgr.GetTLSConfig()) - restoreSchedulers, _, err := restore.RestorePreWork(ctx, mgr, importModeSwitcher, cfg.Online, false) - if err != nil { - return errors.Trace(err) - } - // Always run the post-work even on error, so we don't stuck in the import - // mode or emptied schedulers - defer restore.RestorePostWork(ctx, importModeSwitcher, restoreSchedulers, cfg.Online) - - // It need disable GC in TiKV when PiTR. - // because the process of PITR is concurrent and kv events isn't sorted by tso. 
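For illustration only: the lines that follow save the old gc.ratio-threshold, force it to -1.0 for the duration of the restore, and restore it in a deferred closure guarded by a flag so checkpoint retries keep GC disabled. A minimal standalone sketch of that save/restore-with-guard shape, using hypothetical names and no TiDB APIs:

package main

import "fmt"

// current stands in for the tikv gc.ratio-threshold value that the real code
// reads and writes through a TiDB session.
var current = "1.1"

// keepDisabled mirrors the KeepGcDisabled shape: remember the old value, set
// the disabling value, and hand back a closure that can restore it later.
func keepDisabled() (restore func(string) error, old string) {
	old, current = current, "-1.0"
	return func(v string) error {
		current = v
		return nil
	}, old
}

func main() {
	restore, old := keepDisabled()
	finished := false // plays the role of gcDisabledRestorable
	defer func() {
		if !finished {
			fmt.Println("leave GC disabled so a checkpoint retry can resume")
			return
		}
		_ = restore(old)
		fmt.Println("gc ratio restored to", current)
	}()

	// ... the actual restore work would run here ...
	finished = true
}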
- restoreGc, oldRatio, err := KeepGcDisabled(g, mgr.GetStorage()) - if err != nil { - return errors.Trace(err) - } - gcDisabledRestorable := false - defer func() { - // don't restore the gc-ratio-threshold if checkpoint mode is used and restored is not finished - if cfg.UseCheckpoint && !gcDisabledRestorable { - log.Info("skip restore the gc-ratio-threshold for next retry") - return - } - - log.Info("start to restore gc", zap.String("ratio", oldRatio)) - if err := restoreGc(oldRatio); err != nil { - log.Error("failed to set gc enabled", zap.Error(err)) - } - log.Info("finish restoring gc") - }() - - var taskName string - var checkpointRunner *checkpoint.CheckpointRunner[checkpoint.LogRestoreKeyType, checkpoint.LogRestoreValueType] - if cfg.UseCheckpoint { - taskName = cfg.generateLogRestoreTaskName(client.GetClusterID(ctx), cfg.StartTS, cfg.RestoreTS) - oldRatioFromCheckpoint, err := client.InitCheckpointMetadataForLogRestore(ctx, taskName, oldRatio) - if err != nil { - return errors.Trace(err) - } - oldRatio = oldRatioFromCheckpoint - - checkpointRunner, err = client.StartCheckpointRunnerForLogRestore(ctx, taskName) - if err != nil { - return errors.Trace(err) - } - defer func() { - log.Info("wait for flush checkpoint...") - checkpointRunner.WaitForFinish(ctx, !gcDisabledRestorable) - }() - } - - err = client.InstallLogFileManager(ctx, cfg.StartTS, cfg.RestoreTS, cfg.MetadataDownloadBatchSize) - if err != nil { - return err - } - - // get full backup meta storage to generate rewrite rules. - fullBackupStorage, err := parseFullBackupTablesStorage(cfg) - if err != nil { - return errors.Trace(err) - } - // load the id maps only when the checkpoint mode is used and not the first execution - newTask := true - if taskInfo != nil && taskInfo.Progress == checkpoint.InLogRestoreAndIdMapPersist { - newTask = false - } - // get the schemas ID replace information. - schemasReplace, err := client.InitSchemasReplaceForDDL(ctx, &logclient.InitSchemaConfig{ - IsNewTask: newTask, - TableFilter: cfg.TableFilter, - TiFlashRecorder: cfg.tiflashRecorder, - FullBackupStorage: fullBackupStorage, - }) - if err != nil { - return errors.Trace(err) - } - schemasReplace.AfterTableRewritten = func(deleted bool, tableInfo *model.TableInfo) { - // When the table replica changed to 0, the tiflash replica might be set to `nil`. - // We should remove the table if we meet. - if deleted || tableInfo.TiFlashReplica == nil { - cfg.tiflashRecorder.DelTable(tableInfo.ID) - return - } - cfg.tiflashRecorder.AddTable(tableInfo.ID, *tableInfo.TiFlashReplica) - // Remove the replica firstly. Let's restore them at the end. 
- tableInfo.TiFlashReplica = nil - } - - updateStats := func(kvCount uint64, size uint64) { - mu.Lock() - defer mu.Unlock() - totalKVCount += kvCount - totalSize += size - } - dataFileCount := 0 - ddlFiles, err := client.LoadDDLFilesAndCountDMLFiles(ctx, &dataFileCount) - if err != nil { - return err - } - pm := g.StartProgress(ctx, "Restore Meta Files", int64(len(ddlFiles)), !cfg.LogProgress) - if err = withProgress(pm, func(p glue.Progress) error { - client.RunGCRowsLoader(ctx) - return client.RestoreMetaKVFiles(ctx, ddlFiles, schemasReplace, updateStats, p.Inc) - }); err != nil { - return errors.Annotate(err, "failed to restore meta files") - } - - rewriteRules := initRewriteRules(schemasReplace) - - ingestRecorder := schemasReplace.GetIngestRecorder() - if err := rangeFilterFromIngestRecorder(ingestRecorder, rewriteRules); err != nil { - return errors.Trace(err) - } - - // generate the upstream->downstream id maps for checkpoint - idrules := make(map[int64]int64) - downstreamIdset := make(map[int64]struct{}) - for upstreamId, rule := range rewriteRules { - downstreamId := restoreutils.GetRewriteTableID(upstreamId, rule) - idrules[upstreamId] = downstreamId - downstreamIdset[downstreamId] = struct{}{} - } - - logFilesIter, err := client.LoadDMLFiles(ctx) - if err != nil { - return errors.Trace(err) - } - pd := g.StartProgress(ctx, "Restore KV Files", int64(dataFileCount), !cfg.LogProgress) - err = withProgress(pd, func(p glue.Progress) error { - if cfg.UseCheckpoint { - updateStatsWithCheckpoint := func(kvCount, size uint64) { - mu.Lock() - defer mu.Unlock() - totalKVCount += kvCount - totalSize += size - checkpointTotalKVCount += kvCount - checkpointTotalSize += size - } - logFilesIter, err = client.WrapLogFilesIterWithCheckpoint(ctx, logFilesIter, downstreamIdset, taskName, updateStatsWithCheckpoint, p.Inc) - if err != nil { - return errors.Trace(err) - } - } - logFilesIterWithSplit, err := client.WrapLogFilesIterWithSplitHelper(logFilesIter, rewriteRules, g, mgr.GetStorage()) - if err != nil { - return errors.Trace(err) - } - - return client.RestoreKVFiles(ctx, rewriteRules, idrules, logFilesIterWithSplit, checkpointRunner, cfg.PitrBatchCount, cfg.PitrBatchSize, updateStats, p.IncBy) - }) - if err != nil { - return errors.Annotate(err, "failed to restore kv files") - } - - if err = client.CleanUpKVFiles(ctx); err != nil { - return errors.Annotate(err, "failed to clean up") - } - - if err = client.InsertGCRows(ctx); err != nil { - return errors.Annotate(err, "failed to insert rows into gc_delete_range") - } - - if err = client.RepairIngestIndex(ctx, ingestRecorder, g, taskName); err != nil { - return errors.Annotate(err, "failed to repair ingest index") - } - - if cfg.tiflashRecorder != nil { - sqls := cfg.tiflashRecorder.GenerateAlterTableDDLs(mgr.GetDomain().InfoSchema()) - log.Info("Generating SQLs for restoring TiFlash Replica", - zap.Strings("sqls", sqls)) - err = g.UseOneShotSession(mgr.GetStorage(), false, func(se glue.Session) error { - for _, sql := range sqls { - if errExec := se.ExecuteInternal(ctx, sql); errExec != nil { - logutil.WarnTerm("Failed to restore tiflash replica config, you may execute the sql restore it manually.", - logutil.ShortError(errExec), - zap.String("sql", sql), - ) - } - } - return nil - }) - if err != nil { - return err - } - } - - failpoint.Inject("do-checksum-with-rewrite-rules", func(_ failpoint.Value) { - if err := client.FailpointDoChecksumForLogRestore(ctx, mgr.GetStorage().GetClient(), mgr.GetPDClient(), idrules, rewriteRules); err != nil { - 
failpoint.Return(errors.Annotate(err, "failed to do checksum")) - } - }) - - gcDisabledRestorable = true - - return nil -} - -func createRestoreClient(ctx context.Context, g glue.Glue, cfg *RestoreConfig, mgr *conn.Mgr) (*logclient.LogClient, error) { - var err error - keepaliveCfg := GetKeepalive(&cfg.Config) - keepaliveCfg.PermitWithoutStream = true - client := logclient.NewRestoreClient(mgr.GetPDClient(), mgr.GetPDHTTPClient(), mgr.GetTLSConfig(), keepaliveCfg) - err = client.Init(g, mgr.GetStorage()) - if err != nil { - return nil, errors.Trace(err) - } - defer func() { - if err != nil { - client.Close() - } - }() - - u, err := storage.ParseBackend(cfg.Storage, &cfg.BackendOptions) - if err != nil { - return nil, errors.Trace(err) - } - - opts := getExternalStorageOptions(&cfg.Config, u) - if err = client.SetStorage(ctx, u, &opts); err != nil { - return nil, errors.Trace(err) - } - client.SetCrypter(&cfg.CipherInfo) - client.SetConcurrency(uint(cfg.Concurrency)) - client.InitClients(ctx, u) - - err = client.SetRawKVBatchClient(ctx, cfg.PD, cfg.TLS.ToKVSecurity()) - if err != nil { - return nil, errors.Trace(err) - } - - return client, nil -} - -// rangeFilterFromIngestRecorder rewrites the table id of items in the ingestRecorder -// TODO: need to implement the range filter out feature -func rangeFilterFromIngestRecorder(recorder *ingestrec.IngestRecorder, rewriteRules map[int64]*restoreutils.RewriteRules) error { - err := recorder.RewriteTableID(func(tableID int64) (int64, bool, error) { - rewriteRule, exists := rewriteRules[tableID] - if !exists { - // since the table's files will be skipped restoring, here also skips. - return 0, true, nil - } - newTableID := restoreutils.GetRewriteTableID(tableID, rewriteRule) - if newTableID == 0 { - return 0, false, errors.Errorf("newTableID is 0, tableID: %d", tableID) - } - return newTableID, false, nil - }) - return errors.Trace(err) -} - -func getExternalStorageOptions(cfg *Config, u *backuppb.StorageBackend) storage.ExternalStorageOptions { - var httpClient *http.Client - if u.GetGcs() == nil { - httpClient = storage.GetDefaultHttpClient(cfg.MetadataDownloadBatchSize) - } - return storage.ExternalStorageOptions{ - NoCredentials: cfg.NoCreds, - SendCredentials: cfg.SendCreds, - HTTPClient: httpClient, - } -} - -func checkLogRange(restoreFrom, restoreTo, logMinTS, logMaxTS uint64) error { - // serveral ts constraint: - // logMinTS <= restoreFrom <= restoreTo <= logMaxTS - if logMinTS > restoreFrom || restoreFrom > restoreTo || restoreTo > logMaxTS { - return errors.Annotatef(berrors.ErrInvalidArgument, - "restore log from %d(%s) to %d(%s), "+ - " but the current existed log from %d(%s) to %d(%s)", - restoreFrom, oracle.GetTimeFromTS(restoreFrom), - restoreTo, oracle.GetTimeFromTS(restoreTo), - logMinTS, oracle.GetTimeFromTS(logMinTS), - logMaxTS, oracle.GetTimeFromTS(logMaxTS), - ) - } - return nil -} - -// withProgress execute some logic with the progress, and close it once the execution done. -func withProgress(p glue.Progress, cc func(p glue.Progress) error) error { - defer p.Close() - return cc(p) -} - -type backupLogInfo struct { - logMaxTS uint64 - logMinTS uint64 - clusterID uint64 -} - -// getLogRange gets the log-min-ts and log-max-ts of starting log backup. 
-func getLogRange( - ctx context.Context, - cfg *Config, -) (backupLogInfo, error) { - _, s, err := GetStorage(ctx, cfg.Storage, cfg) - if err != nil { - return backupLogInfo{}, errors.Trace(err) - } - return getLogRangeWithStorage(ctx, s) -} - -func getLogRangeWithStorage( - ctx context.Context, - s storage.ExternalStorage, -) (backupLogInfo, error) { - // logStartTS: Get log start ts from backupmeta file. - metaData, err := s.ReadFile(ctx, metautil.MetaFile) - if err != nil { - return backupLogInfo{}, errors.Trace(err) - } - backupMeta := &backuppb.BackupMeta{} - if err = backupMeta.Unmarshal(metaData); err != nil { - return backupLogInfo{}, errors.Trace(err) - } - // endVersion > 0 represents that the storage has been used for `br backup` - if backupMeta.GetEndVersion() > 0 { - return backupLogInfo{}, errors.Annotate(berrors.ErrStorageUnknown, - "the storage has been used for full backup") - } - logStartTS := backupMeta.GetStartVersion() - - // truncateTS: get log truncate ts from TruncateSafePointFileName. - // If truncateTS equals 0, which represents the stream log has never been truncated. - truncateTS, err := stream.GetTSFromFile(ctx, s, stream.TruncateSafePointFileName) - if err != nil { - return backupLogInfo{}, errors.Trace(err) - } - logMinTS := max(logStartTS, truncateTS) - - // get max global resolved ts from metas. - logMaxTS, err := getGlobalCheckpointFromStorage(ctx, s) - if err != nil { - return backupLogInfo{}, errors.Trace(err) - } - logMaxTS = max(logMinTS, logMaxTS) - - return backupLogInfo{ - logMaxTS: logMaxTS, - logMinTS: logMinTS, - clusterID: backupMeta.ClusterId, - }, nil -} - -func getGlobalCheckpointFromStorage(ctx context.Context, s storage.ExternalStorage) (uint64, error) { - var globalCheckPointTS uint64 = 0 - opt := storage.WalkOption{SubDir: stream.GetStreamBackupGlobalCheckpointPrefix()} - err := s.WalkDir(ctx, &opt, func(path string, size int64) error { - if !strings.HasSuffix(path, ".ts") { - return nil - } - - buff, err := s.ReadFile(ctx, path) - if err != nil { - return errors.Trace(err) - } - ts := binary.LittleEndian.Uint64(buff) - globalCheckPointTS = max(ts, globalCheckPointTS) - return nil - }) - return globalCheckPointTS, errors.Trace(err) -} - -// getFullBackupTS gets the snapshot-ts of full bakcup -func getFullBackupTS( - ctx context.Context, - cfg *RestoreConfig, -) (uint64, uint64, error) { - _, s, err := GetStorage(ctx, cfg.FullBackupStorage, &cfg.Config) - if err != nil { - return 0, 0, errors.Trace(err) - } - - metaData, err := s.ReadFile(ctx, metautil.MetaFile) - if err != nil { - return 0, 0, errors.Trace(err) - } - - backupmeta := &backuppb.BackupMeta{} - if err = backupmeta.Unmarshal(metaData); err != nil { - return 0, 0, errors.Trace(err) - } - - return backupmeta.GetEndVersion(), backupmeta.GetClusterId(), nil -} - -func parseFullBackupTablesStorage( - cfg *RestoreConfig, -) (*logclient.FullBackupStorageConfig, error) { - if len(cfg.FullBackupStorage) == 0 { - log.Info("the full backup path is not specified, so BR will try to get id maps") - return nil, nil - } - u, err := storage.ParseBackend(cfg.FullBackupStorage, &cfg.BackendOptions) - if err != nil { - return nil, errors.Trace(err) - } - return &logclient.FullBackupStorageConfig{ - Backend: u, - Opts: storageOpts(&cfg.Config), - }, nil -} - -func initRewriteRules(schemasReplace *stream.SchemasReplace) map[int64]*restoreutils.RewriteRules { - rules := make(map[int64]*restoreutils.RewriteRules) - filter := schemasReplace.TableFilter - - for _, dbReplace := range 
schemasReplace.DbMap { - if utils.IsSysDB(dbReplace.Name) || !filter.MatchSchema(dbReplace.Name) { - continue - } - - for oldTableID, tableReplace := range dbReplace.TableMap { - if !filter.MatchTable(dbReplace.Name, tableReplace.Name) { - continue - } - - if _, exist := rules[oldTableID]; !exist { - log.Info("add rewrite rule", - zap.String("tableName", dbReplace.Name+"."+tableReplace.Name), - zap.Int64("oldID", oldTableID), zap.Int64("newID", tableReplace.TableID)) - rules[oldTableID] = restoreutils.GetRewriteRuleOfTable( - oldTableID, tableReplace.TableID, 0, tableReplace.IndexMap, false) - } - - for oldID, newID := range tableReplace.PartitionMap { - if _, exist := rules[oldID]; !exist { - log.Info("add rewrite rule", - zap.String("tableName", dbReplace.Name+"."+tableReplace.Name), - zap.Int64("oldID", oldID), zap.Int64("newID", newID)) - rules[oldID] = restoreutils.GetRewriteRuleOfTable(oldID, newID, 0, tableReplace.IndexMap, false) - } - } - } - } - return rules -} - -// ShiftTS gets a smaller shiftTS than startTS. -// It has a safe duration between shiftTS and startTS for trasaction. -func ShiftTS(startTS uint64) uint64 { - physical := oracle.ExtractPhysical(startTS) - logical := oracle.ExtractLogical(startTS) - - shiftPhysical := physical - streamShiftDuration.Milliseconds() - if shiftPhysical < 0 { - return 0 - } - return oracle.ComposeTS(shiftPhysical, logical) -} - -func buildPauseSafePointName(taskName string) string { - return fmt.Sprintf("%s_pause_safepoint", taskName) -} - -func checkPiTRRequirements(mgr *conn.Mgr) error { - return restore.AssertUserDBsEmpty(mgr.GetDomain()) -} - -type PiTRTaskInfo struct { - CheckpointInfo *checkpoint.CheckpointTaskInfoForLogRestore - NeedFullRestore bool - FullRestoreCheckErr error -} - -func checkPiTRTaskInfo( - ctx context.Context, - g glue.Glue, - s storage.ExternalStorage, - cfg *RestoreConfig, -) (*PiTRTaskInfo, error) { - var ( - doFullRestore = (len(cfg.FullBackupStorage) > 0) - curTaskInfo *checkpoint.CheckpointTaskInfoForLogRestore - errTaskMsg string - ) - checkInfo := &PiTRTaskInfo{} - - mgr, err := NewMgr(ctx, g, cfg.PD, cfg.TLS, GetKeepalive(&cfg.Config), - cfg.CheckRequirements, true, conn.StreamVersionChecker) - if err != nil { - return checkInfo, errors.Trace(err) - } - defer mgr.Close() - - clusterID := mgr.GetPDClient().GetClusterID(ctx) - if cfg.UseCheckpoint { - exists, err := checkpoint.ExistsCheckpointTaskInfo(ctx, s, clusterID) - if err != nil { - return checkInfo, errors.Trace(err) - } - if exists { - curTaskInfo, err = checkpoint.LoadCheckpointTaskInfoForLogRestore(ctx, s, clusterID) - if err != nil { - return checkInfo, errors.Trace(err) - } - // TODO: check whether user has manually modified the cluster(ddl). If so, regard the behavior - // as restore from scratch. (update `curTaskInfo.RewriteTs` to 0 as an uninitial value) - - // The task info is written to external storage without status `InSnapshotRestore` only when - // id-maps is persist into external storage, so there is no need to do snapshot restore again. 
- if curTaskInfo.StartTS == cfg.StartTS && curTaskInfo.RestoreTS == cfg.RestoreTS { - // the same task, check whether skip snapshot restore - doFullRestore = doFullRestore && (curTaskInfo.Progress == checkpoint.InSnapshotRestore) - // update the snapshot restore task name to clean up in final - if !doFullRestore && (len(cfg.FullBackupStorage) > 0) { - _ = cfg.generateSnapshotRestoreTaskName(clusterID) - } - log.Info("the same task", zap.Bool("skip-snapshot-restore", !doFullRestore)) - } else { - // not the same task, so overwrite the taskInfo with a new task - log.Info("not the same task, start to restore from scratch") - errTaskMsg = fmt.Sprintf( - "a new task [start-ts=%d] [restored-ts=%d] while the last task info: [start-ts=%d] [restored-ts=%d] [skip-snapshot-restore=%t]", - cfg.StartTS, cfg.RestoreTS, curTaskInfo.StartTS, curTaskInfo.RestoreTS, curTaskInfo.Progress == checkpoint.InLogRestoreAndIdMapPersist) - - curTaskInfo = nil - } - } - } - checkInfo.CheckpointInfo = curTaskInfo - checkInfo.NeedFullRestore = doFullRestore - // restore full snapshot precheck. - if doFullRestore { - if !(cfg.UseCheckpoint && curTaskInfo != nil) { - // Only when use checkpoint and not the first execution, - // skip checking requirements. - log.Info("check pitr requirements for the first execution") - if err := checkPiTRRequirements(mgr); err != nil { - if len(errTaskMsg) > 0 { - err = errors.Annotatef(err, "The current restore task is regarded as %s. "+ - "If you ensure that no changes have been made to the cluster since the last execution, "+ - "you can adjust the `start-ts` or `restored-ts` to continue with the previous execution. "+ - "Otherwise, if you want to restore from scratch, please clean the cluster at first", errTaskMsg) - } - // delay cluster checks after we get the backupmeta. - // for the case that the restore inc + log backup, - // we can still restore them. - checkInfo.FullRestoreCheckErr = err - return checkInfo, nil - } - } - } - - // persist the new task info - if cfg.UseCheckpoint && curTaskInfo == nil { - log.Info("save checkpoint task info with `InSnapshotRestore` status") - if err := checkpoint.SaveCheckpointTaskInfoForLogRestore(ctx, s, &checkpoint.CheckpointTaskInfoForLogRestore{ - Progress: checkpoint.InSnapshotRestore, - StartTS: cfg.StartTS, - RestoreTS: cfg.RestoreTS, - // updated in the stage of `InLogRestoreAndIdMapPersist` - RewriteTS: 0, - TiFlashItems: nil, - }, clusterID); err != nil { - return checkInfo, errors.Trace(err) - } - } - return checkInfo, nil -} diff --git a/br/pkg/utils/backoff.go b/br/pkg/utils/backoff.go index e8093d86cfa35..385ed4319a06a 100644 --- a/br/pkg/utils/backoff.go +++ b/br/pkg/utils/backoff.go @@ -268,9 +268,9 @@ func (bo *pdReqBackoffer) NextBackoff(err error) time.Duration { } } - if _, _err_ := failpoint.Eval(_curpkg_("set-attempt-to-one")); _err_ == nil { + failpoint.Inject("set-attempt-to-one", func(_ failpoint.Value) { bo.attempt = 1 - } + }) if bo.delayTime > bo.maxDelayTime { return bo.maxDelayTime } diff --git a/br/pkg/utils/backoff.go__failpoint_stash__ b/br/pkg/utils/backoff.go__failpoint_stash__ deleted file mode 100644 index 385ed4319a06a..0000000000000 --- a/br/pkg/utils/backoff.go__failpoint_stash__ +++ /dev/null @@ -1,323 +0,0 @@ -// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. 
- -package utils - -import ( - "context" - "database/sql" - "io" - "math" - "strings" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/log" - berrors "github.com/pingcap/tidb/br/pkg/errors" - "go.uber.org/multierr" - "go.uber.org/zap" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -const ( - // importSSTRetryTimes specifies the retry time. Its longest time is about 90s-100s. - importSSTRetryTimes = 16 - importSSTWaitInterval = 40 * time.Millisecond - importSSTMaxWaitInterval = 10 * time.Second - - downloadSSTRetryTimes = 8 - downloadSSTWaitInterval = 1 * time.Second - downloadSSTMaxWaitInterval = 4 * time.Second - - backupSSTRetryTimes = 5 - backupSSTWaitInterval = 2 * time.Second - backupSSTMaxWaitInterval = 3 * time.Second - - resetTSRetryTime = 32 - resetTSWaitInterval = 50 * time.Millisecond - resetTSMaxWaitInterval = 2 * time.Second - - resetTSRetryTimeExt = 600 - resetTSWaitIntervalExt = 500 * time.Millisecond - resetTSMaxWaitIntervalExt = 300 * time.Second - - // region heartbeat are 10 seconds by default, if some region has 2 heartbeat missing (15 seconds), it appear to be a network issue between PD and TiKV. - FlashbackRetryTime = 3 - FlashbackWaitInterval = 3 * time.Second - FlashbackMaxWaitInterval = 15 * time.Second - - ChecksumRetryTime = 8 - ChecksumWaitInterval = 1 * time.Second - ChecksumMaxWaitInterval = 30 * time.Second - - gRPC_Cancel = "the client connection is closing" -) - -// At least, there are two possible cancel() call, -// one from go context, another from gRPC, here we retry when gRPC cancel with connection closing -func isGRPCCancel(err error) bool { - if s, ok := status.FromError(err); ok { - if strings.Contains(s.Message(), gRPC_Cancel) { - return true - } - } - return false -} - -// ConstantBackoff is a backoffer that retry forever until success. -type ConstantBackoff time.Duration - -// NextBackoff returns a duration to wait before retrying again -func (c ConstantBackoff) NextBackoff(err error) time.Duration { - return time.Duration(c) -} - -// Attempt returns the remain attempt times -func (c ConstantBackoff) Attempt() int { - // A large enough value. Also still safe for arithmetic operations (won't easily overflow). - return math.MaxInt16 -} - -// RetryState is the mutable state needed for retrying. -// It likes the `utils.Backoffer`, but more fundamental: -// this only control the backoff time and knows nothing about what error happens. -// NOTE: Maybe also implement the backoffer via this. -type RetryState struct { - maxRetry int - retryTimes int - - maxBackoff time.Duration - nextBackoff time.Duration -} - -// InitialRetryState make the initial state for retrying. -func InitialRetryState(maxRetryTimes int, initialBackoff, maxBackoff time.Duration) RetryState { - return RetryState{ - maxRetry: maxRetryTimes, - maxBackoff: maxBackoff, - nextBackoff: initialBackoff, - } -} - -// Whether in the current state we can retry. -func (rs *RetryState) ShouldRetry() bool { - return rs.retryTimes < rs.maxRetry -} - -// Get the exponential backoff durion and transform the state. -func (rs *RetryState) ExponentialBackoff() time.Duration { - rs.retryTimes++ - backoff := rs.nextBackoff - rs.nextBackoff *= 2 - if rs.nextBackoff > rs.maxBackoff { - rs.nextBackoff = rs.maxBackoff - } - return backoff -} - -func (rs *RetryState) GiveUp() { - rs.retryTimes = rs.maxRetry -} - -// ReduceRetry reduces retry times for 1. 
-func (rs *RetryState) ReduceRetry() { - rs.retryTimes-- -} - -// Attempt implements the `Backoffer`. -// TODO: Maybe use this to replace the `exponentialBackoffer` (which is nearly homomorphic to this)? -func (rs *RetryState) Attempt() int { - return rs.maxRetry - rs.retryTimes -} - -// NextBackoff implements the `Backoffer`. -func (rs *RetryState) NextBackoff(error) time.Duration { - return rs.ExponentialBackoff() -} - -type importerBackoffer struct { - attempt int - delayTime time.Duration - maxDelayTime time.Duration - errContext *ErrorContext -} - -// NewBackoffer creates a new controller regulating a truncated exponential backoff. -func NewBackoffer(attempt int, delayTime, maxDelayTime time.Duration, errContext *ErrorContext) Backoffer { - return &importerBackoffer{ - attempt: attempt, - delayTime: delayTime, - maxDelayTime: maxDelayTime, - errContext: errContext, - } -} - -func NewImportSSTBackoffer() Backoffer { - errContext := NewErrorContext("import sst", 3) - return NewBackoffer(importSSTRetryTimes, importSSTWaitInterval, importSSTMaxWaitInterval, errContext) -} - -func NewDownloadSSTBackoffer() Backoffer { - errContext := NewErrorContext("download sst", 3) - return NewBackoffer(downloadSSTRetryTimes, downloadSSTWaitInterval, downloadSSTMaxWaitInterval, errContext) -} - -func NewBackupSSTBackoffer() Backoffer { - errContext := NewErrorContext("backup sst", 3) - return NewBackoffer(backupSSTRetryTimes, backupSSTWaitInterval, backupSSTMaxWaitInterval, errContext) -} - -func (bo *importerBackoffer) NextBackoff(err error) time.Duration { - // we don't care storeID here. - errs := multierr.Errors(err) - lastErr := errs[len(errs)-1] - res := HandleUnknownBackupError(lastErr.Error(), 0, bo.errContext) - if res.Strategy == StrategyRetry { - bo.delayTime = 2 * bo.delayTime - bo.attempt-- - } else { - e := errors.Cause(lastErr) - switch e { // nolint:errorlint - case berrors.ErrKVEpochNotMatch, berrors.ErrKVDownloadFailed, berrors.ErrKVIngestFailed, berrors.ErrPDLeaderNotFound: - bo.delayTime = 2 * bo.delayTime - bo.attempt-- - case berrors.ErrKVRangeIsEmpty, berrors.ErrKVRewriteRuleNotFound: - // Expected error, finish the operation - bo.delayTime = 0 - bo.attempt = 0 - default: - switch status.Code(e) { - case codes.Unavailable, codes.Aborted, codes.DeadlineExceeded, codes.ResourceExhausted, codes.Internal: - bo.delayTime = 2 * bo.delayTime - bo.attempt-- - case codes.Canceled: - if isGRPCCancel(lastErr) { - bo.delayTime = 2 * bo.delayTime - bo.attempt-- - } else { - bo.delayTime = 0 - bo.attempt = 0 - } - default: - // Unexpected error - bo.delayTime = 0 - bo.attempt = 0 - log.Warn("unexpected error, stop retrying", zap.Error(err)) - } - } - } - if bo.delayTime > bo.maxDelayTime { - return bo.maxDelayTime - } - return bo.delayTime -} - -func (bo *importerBackoffer) Attempt() int { - return bo.attempt -} - -type pdReqBackoffer struct { - attempt int - delayTime time.Duration - maxDelayTime time.Duration -} - -func NewPDReqBackoffer() Backoffer { - return &pdReqBackoffer{ - attempt: resetTSRetryTime, - delayTime: resetTSWaitInterval, - maxDelayTime: resetTSMaxWaitInterval, - } -} - -func NewPDReqBackofferExt() Backoffer { - return &pdReqBackoffer{ - attempt: resetTSRetryTimeExt, - delayTime: resetTSWaitIntervalExt, - maxDelayTime: resetTSMaxWaitIntervalExt, - } -} - -func (bo *pdReqBackoffer) NextBackoff(err error) time.Duration { - // bo.delayTime = 2 * bo.delayTime - // bo.attempt-- - e := errors.Cause(err) - switch e { // nolint:errorlint - case nil, context.Canceled, 
context.DeadlineExceeded, sql.ErrNoRows: - // Excepted error, finish the operation - bo.delayTime = 0 - bo.attempt = 0 - case berrors.ErrRestoreTotalKVMismatch, io.EOF: - bo.delayTime = 2 * bo.delayTime - bo.attempt-- - default: - // If the connection timeout, pd client would cancel the context, and return grpc context cancel error. - // So make the codes.Canceled retryable too. - // It's OK to retry the grpc context cancel error, because the parent context cancel returns context.Canceled. - // For example, cancel the `ectx` and then pdClient.GetTS(ectx) returns context.Canceled instead of grpc context canceled. - switch status.Code(e) { - case codes.DeadlineExceeded, codes.Canceled, codes.NotFound, codes.AlreadyExists, codes.PermissionDenied, codes.ResourceExhausted, codes.Aborted, codes.OutOfRange, codes.Unavailable, codes.DataLoss, codes.Unknown: - bo.delayTime = 2 * bo.delayTime - bo.attempt-- - default: - // Unexcepted error - bo.delayTime = 0 - bo.attempt = 0 - log.Warn("unexcepted error, stop to retry", zap.Error(err)) - } - } - - failpoint.Inject("set-attempt-to-one", func(_ failpoint.Value) { - bo.attempt = 1 - }) - if bo.delayTime > bo.maxDelayTime { - return bo.maxDelayTime - } - return bo.delayTime -} - -func (bo *pdReqBackoffer) Attempt() int { - return bo.attempt -} - -type DiskCheckBackoffer struct { - attempt int - delayTime time.Duration - maxDelayTime time.Duration -} - -func NewDiskCheckBackoffer() Backoffer { - return &DiskCheckBackoffer{ - attempt: resetTSRetryTime, - delayTime: resetTSWaitInterval, - maxDelayTime: resetTSMaxWaitInterval, - } -} - -func (bo *DiskCheckBackoffer) NextBackoff(err error) time.Duration { - e := errors.Cause(err) - switch e { // nolint:errorlint - case nil, context.Canceled, context.DeadlineExceeded, berrors.ErrKVDiskFull: - bo.delayTime = 0 - bo.attempt = 0 - case berrors.ErrPDInvalidResponse: - bo.delayTime = 2 * bo.delayTime - bo.attempt-- - default: - bo.delayTime = 2 * bo.delayTime - if bo.attempt > 5 { - bo.attempt = 5 - } - bo.attempt-- - } - - if bo.delayTime > bo.maxDelayTime { - return bo.maxDelayTime - } - return bo.delayTime -} - -func (bo *DiskCheckBackoffer) Attempt() int { - return bo.attempt -} diff --git a/br/pkg/utils/binding__failpoint_binding__.go b/br/pkg/utils/binding__failpoint_binding__.go deleted file mode 100644 index 0a24d8976f30c..0000000000000 --- a/br/pkg/utils/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package utils - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/br/pkg/utils/pprof.go b/br/pkg/utils/pprof.go index 66e5c5e57d14f..c2e5ad63c8e5a 100644 --- a/br/pkg/utils/pprof.go +++ b/br/pkg/utils/pprof.go @@ -33,11 +33,11 @@ func listen(statusAddr string) (net.Listener, error) { log.Warn("Try to start pprof when it has been started, nothing will happen", zap.String("address", startedPProf)) return nil, errors.Annotate(berrors.ErrUnknown, "try to start pprof when it has been started at "+startedPProf) } - if v, _err_ := failpoint.Eval(_curpkg_("determined-pprof-port")); _err_ == nil { + failpoint.Inject("determined-pprof-port", func(v failpoint.Value) { port := v.(int) statusAddr = fmt.Sprintf(":%d", port) log.Info("injecting failpoint, pprof will start at determined port", zap.Int("port", port)) - 
} + }) listener, err := net.Listen("tcp", statusAddr) if err != nil { log.Warn("failed to start pprof", zap.String("addr", statusAddr), zap.Error(err)) diff --git a/br/pkg/utils/pprof.go__failpoint_stash__ b/br/pkg/utils/pprof.go__failpoint_stash__ deleted file mode 100644 index c2e5ad63c8e5a..0000000000000 --- a/br/pkg/utils/pprof.go__failpoint_stash__ +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. - -package utils - -import ( - "fmt" - "net" //nolint:goimports - // #nosec - // register HTTP handler for /debug/pprof - "net/http" - // For pprof - _ "net/http/pprof" // #nosec G108 - "os" - "sync" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/log" - berrors "github.com/pingcap/tidb/br/pkg/errors" - tidbutils "github.com/pingcap/tidb/pkg/util" - "go.uber.org/zap" -) - -var ( - startedPProf = "" - mu sync.Mutex -) - -func listen(statusAddr string) (net.Listener, error) { - mu.Lock() - defer mu.Unlock() - if startedPProf != "" { - log.Warn("Try to start pprof when it has been started, nothing will happen", zap.String("address", startedPProf)) - return nil, errors.Annotate(berrors.ErrUnknown, "try to start pprof when it has been started at "+startedPProf) - } - failpoint.Inject("determined-pprof-port", func(v failpoint.Value) { - port := v.(int) - statusAddr = fmt.Sprintf(":%d", port) - log.Info("injecting failpoint, pprof will start at determined port", zap.Int("port", port)) - }) - listener, err := net.Listen("tcp", statusAddr) - if err != nil { - log.Warn("failed to start pprof", zap.String("addr", statusAddr), zap.Error(err)) - return nil, errors.Trace(err) - } - startedPProf = listener.Addr().String() - log.Info("bound pprof to addr", zap.String("addr", startedPProf)) - _, _ = fmt.Fprintf(os.Stderr, "bound pprof to addr %s\n", startedPProf) - return listener, nil -} - -// StartPProfListener forks a new goroutine listening on specified port and provide pprof info. 
-func StartPProfListener(statusAddr string, wrapper *tidbutils.TLS) error { - listener, err := listen(statusAddr) - if err != nil { - return err - } - - go func() { - if e := http.Serve(wrapper.WrapListener(listener), nil); e != nil { - log.Warn("failed to serve pprof", zap.String("addr", startedPProf), zap.Error(e)) - mu.Lock() - startedPProf = "" - mu.Unlock() - return - } - }() - return nil -} diff --git a/br/pkg/utils/register.go b/br/pkg/utils/register.go index 3aeb4a9b21343..95a102ae68d26 100644 --- a/br/pkg/utils/register.go +++ b/br/pkg/utils/register.go @@ -195,17 +195,17 @@ func (tr *taskRegister) keepaliveLoop(ctx context.Context, ch <-chan *clientv3.L if timeLeftThreshold < minTimeLeftThreshold { timeLeftThreshold = minTimeLeftThreshold } - if _, _err_ := failpoint.Eval(_curpkg_("brie-task-register-always-grant")); _err_ == nil { + failpoint.Inject("brie-task-register-always-grant", func(_ failpoint.Value) { timeLeftThreshold = tr.ttl - } + }) for { CONSUMERESP: for { - if _, _err_ := failpoint.Eval(_curpkg_("brie-task-register-keepalive-stop")); _err_ == nil { + failpoint.Inject("brie-task-register-keepalive-stop", func(_ failpoint.Value) { if _, err = tr.client.Lease.Revoke(ctx, tr.curLeaseID); err != nil { log.Warn("brie-task-register-keepalive-stop", zap.Error(err)) } - } + }) select { case <-ctx.Done(): return @@ -223,9 +223,9 @@ func (tr *taskRegister) keepaliveLoop(ctx context.Context, ch <-chan *clientv3.L timeGap := time.Since(lastUpdateTime) if tr.ttl-timeGap <= timeLeftThreshold { lease, err := tr.grant(ctx) - if _, _err_ := failpoint.Eval(_curpkg_("brie-task-register-failed-to-grant")); _err_ == nil { + failpoint.Inject("brie-task-register-failed-to-grant", func(_ failpoint.Value) { err = errors.New("failpoint-error") - } + }) if err != nil { select { case <-ctx.Done(): @@ -243,9 +243,9 @@ func (tr *taskRegister) keepaliveLoop(ctx context.Context, ch <-chan *clientv3.L if needReputKV { // if the lease has expired, need to put the key again _, err := tr.client.KV.Put(ctx, tr.key, "", clientv3.WithLease(tr.curLeaseID)) - if _, _err_ := failpoint.Eval(_curpkg_("brie-task-register-failed-to-reput")); _err_ == nil { + failpoint.Inject("brie-task-register-failed-to-reput", func(_ failpoint.Value) { err = errors.New("failpoint-error") - } + }) if err != nil { select { case <-ctx.Done(): diff --git a/br/pkg/utils/register.go__failpoint_stash__ b/br/pkg/utils/register.go__failpoint_stash__ deleted file mode 100644 index 95a102ae68d26..0000000000000 --- a/br/pkg/utils/register.go__failpoint_stash__ +++ /dev/null @@ -1,334 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package utils - -import ( - "context" - "fmt" - "path" - "strings" - "sync" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/log" - clientv3 "go.etcd.io/etcd/client/v3" - "go.uber.org/zap" -) - -// RegisterTaskType for the sub-prefix path for key -type RegisterTaskType int - -const ( - RegisterRestore RegisterTaskType = iota - RegisterLightning - RegisterImportInto -) - -func (tp RegisterTaskType) String() string { - switch tp { - case RegisterRestore: - return "restore" - case RegisterLightning: - return "lightning" - case RegisterImportInto: - return "import-into" - } - return "default" -} - -// The key format should be {RegisterImportTaskPrefix}/{RegisterTaskType}/{taskName} -const ( - // RegisterImportTaskPrefix is the prefix of the key for task register - // todo: remove "/import" suffix, it's confusing to have a key like "/tidb/brie/import/restore/restore-xxx" - RegisterImportTaskPrefix = "/tidb/brie/import" - - RegisterRetryInternal = 10 * time.Second - defaultTaskRegisterTTL = 3 * time.Minute // 3 minutes -) - -// TaskRegister can register the task to PD with a lease. -type TaskRegister interface { - // Close closes the background task if using RegisterTask - // and revoke the lease. - // NOTE: we don't close the etcd client here, call should do it. - Close(ctx context.Context) (err error) - // RegisterTask firstly put its key to PD with a lease, - // and start to keepalive the lease in the background. - // DO NOT mix calls to RegisterTask and RegisterTaskOnce. - RegisterTask(c context.Context) error - // RegisterTaskOnce put its key to PD with a lease if the key does not exist, - // else we refresh the lease. - // you have to call this method periodically to keep the lease alive. - // DO NOT mix calls to RegisterTask and RegisterTaskOnce. 
- RegisterTaskOnce(ctx context.Context) error -} - -type taskRegister struct { - client *clientv3.Client - ttl time.Duration - secondTTL int64 - key string - - // leaseID used to revoke the lease - curLeaseID clientv3.LeaseID - wg sync.WaitGroup - cancel context.CancelFunc -} - -// NewTaskRegisterWithTTL build a TaskRegister with key format {RegisterTaskPrefix}/{RegisterTaskType}/{taskName} -func NewTaskRegisterWithTTL(client *clientv3.Client, ttl time.Duration, tp RegisterTaskType, taskName string) TaskRegister { - return &taskRegister{ - client: client, - ttl: ttl, - secondTTL: int64(ttl / time.Second), - key: path.Join(RegisterImportTaskPrefix, tp.String(), taskName), - - curLeaseID: clientv3.NoLease, - } -} - -// NewTaskRegister build a TaskRegister with key format {RegisterTaskPrefix}/{RegisterTaskType}/{taskName} -func NewTaskRegister(client *clientv3.Client, tp RegisterTaskType, taskName string) TaskRegister { - return NewTaskRegisterWithTTL(client, defaultTaskRegisterTTL, tp, taskName) -} - -// Close implements the TaskRegister interface -func (tr *taskRegister) Close(ctx context.Context) (err error) { - // not needed if using RegisterTaskOnce - if tr.cancel != nil { - tr.cancel() - } - tr.wg.Wait() - if tr.curLeaseID != clientv3.NoLease { - _, err = tr.client.Lease.Revoke(ctx, tr.curLeaseID) - if err != nil { - log.Warn("failed to revoke the lease", zap.Error(err), zap.Int64("lease-id", int64(tr.curLeaseID))) - } - } - return err -} - -func (tr *taskRegister) grant(ctx context.Context) (*clientv3.LeaseGrantResponse, error) { - lease, err := tr.client.Lease.Grant(ctx, tr.secondTTL) - if err != nil { - return nil, err - } - if len(lease.Error) > 0 { - return nil, errors.New(lease.Error) - } - return lease, nil -} - -// RegisterTaskOnce implements the TaskRegister interface -func (tr *taskRegister) RegisterTaskOnce(ctx context.Context) error { - resp, err := tr.client.Get(ctx, tr.key) - if err != nil { - return errors.Trace(err) - } - if len(resp.Kvs) == 0 { - lease, err2 := tr.grant(ctx) - if err2 != nil { - return errors.Annotatef(err2, "failed grant a lease") - } - tr.curLeaseID = lease.ID - _, err2 = tr.client.KV.Put(ctx, tr.key, "", clientv3.WithLease(lease.ID)) - if err2 != nil { - return errors.Trace(err2) - } - } else { - // if the task is run distributively, like IMPORT INTO, we should refresh the lease ID, - // in case the owner changed during the registration, and the new owner create the key. 
- tr.curLeaseID = clientv3.LeaseID(resp.Kvs[0].Lease) - _, err2 := tr.client.Lease.KeepAliveOnce(ctx, tr.curLeaseID) - if err2 != nil { - return errors.Trace(err2) - } - } - return nil -} - -// RegisterTask implements the TaskRegister interface -func (tr *taskRegister) RegisterTask(c context.Context) error { - cctx, cancel := context.WithCancel(c) - tr.cancel = cancel - lease, err := tr.grant(cctx) - if err != nil { - return errors.Annotatef(err, "failed grant a lease") - } - tr.curLeaseID = lease.ID - _, err = tr.client.KV.Put(cctx, tr.key, "", clientv3.WithLease(lease.ID)) - if err != nil { - return errors.Trace(err) - } - - // KeepAlive interval equals to ttl/3 - respCh, err := tr.client.Lease.KeepAlive(cctx, lease.ID) - if err != nil { - return errors.Trace(err) - } - tr.wg.Add(1) - go tr.keepaliveLoop(cctx, respCh) - return nil -} - -func (tr *taskRegister) keepaliveLoop(ctx context.Context, ch <-chan *clientv3.LeaseKeepAliveResponse) { - defer tr.wg.Done() - const minTimeLeftThreshold time.Duration = 20 * time.Second - var ( - timeLeftThreshold time.Duration = tr.ttl / 4 - lastUpdateTime time.Time = time.Now() - err error - ) - if timeLeftThreshold < minTimeLeftThreshold { - timeLeftThreshold = minTimeLeftThreshold - } - failpoint.Inject("brie-task-register-always-grant", func(_ failpoint.Value) { - timeLeftThreshold = tr.ttl - }) - for { - CONSUMERESP: - for { - failpoint.Inject("brie-task-register-keepalive-stop", func(_ failpoint.Value) { - if _, err = tr.client.Lease.Revoke(ctx, tr.curLeaseID); err != nil { - log.Warn("brie-task-register-keepalive-stop", zap.Error(err)) - } - }) - select { - case <-ctx.Done(): - return - case _, ok := <-ch: - if !ok { - break CONSUMERESP - } - lastUpdateTime = time.Now() - } - } - log.Warn("the keepalive channel is closed, try to recreate it") - needReputKV := false - RECREATE: - for { - timeGap := time.Since(lastUpdateTime) - if tr.ttl-timeGap <= timeLeftThreshold { - lease, err := tr.grant(ctx) - failpoint.Inject("brie-task-register-failed-to-grant", func(_ failpoint.Value) { - err = errors.New("failpoint-error") - }) - if err != nil { - select { - case <-ctx.Done(): - return - default: - } - log.Warn("failed to grant lease", zap.Error(err)) - time.Sleep(RegisterRetryInternal) - continue - } - tr.curLeaseID = lease.ID - lastUpdateTime = time.Now() - needReputKV = true - } - if needReputKV { - // if the lease has expired, need to put the key again - _, err := tr.client.KV.Put(ctx, tr.key, "", clientv3.WithLease(tr.curLeaseID)) - failpoint.Inject("brie-task-register-failed-to-reput", func(_ failpoint.Value) { - err = errors.New("failpoint-error") - }) - if err != nil { - select { - case <-ctx.Done(): - return - default: - } - log.Warn("failed to put new kv", zap.Error(err)) - time.Sleep(RegisterRetryInternal) - continue - } - needReputKV = false - } - // recreate keepalive - ch, err = tr.client.Lease.KeepAlive(ctx, tr.curLeaseID) - if err != nil { - select { - case <-ctx.Done(): - return - default: - } - log.Warn("failed to create new kv", zap.Error(err)) - time.Sleep(RegisterRetryInternal) - continue - } - - break RECREATE - } - } -} - -// RegisterTask saves the task's information -type RegisterTask struct { - Key string - LeaseID int64 - TTL int64 -} - -// MessageToUser marshal the task to user message -func (task RegisterTask) MessageToUser() string { - return fmt.Sprintf("[ key: %s, lease-id: %x, ttl: %ds ]", task.Key, task.LeaseID, task.TTL) -} - -type RegisterTasksList struct { - Tasks []RegisterTask -} - -func (list RegisterTasksList) 
MessageToUser() string { - var tasksMsgBuf strings.Builder - for _, task := range list.Tasks { - tasksMsgBuf.WriteString(task.MessageToUser()) - tasksMsgBuf.WriteString(", ") - } - return tasksMsgBuf.String() -} - -func (list RegisterTasksList) Empty() bool { - return len(list.Tasks) == 0 -} - -// GetImportTasksFrom try to get all the import tasks with prefix `RegisterTaskPrefix` -func GetImportTasksFrom(ctx context.Context, client *clientv3.Client) (RegisterTasksList, error) { - resp, err := client.KV.Get(ctx, RegisterImportTaskPrefix, clientv3.WithPrefix()) - if err != nil { - return RegisterTasksList{}, errors.Trace(err) - } - - list := RegisterTasksList{ - Tasks: make([]RegisterTask, 0, len(resp.Kvs)), - } - for _, kv := range resp.Kvs { - leaseResp, err := client.Lease.TimeToLive(ctx, clientv3.LeaseID(kv.Lease)) - if err != nil { - return list, errors.Annotatef(err, "failed to get time-to-live of lease: %x", kv.Lease) - } - // the lease has expired - if leaseResp.TTL <= 0 { - continue - } - list.Tasks = append(list.Tasks, RegisterTask{ - Key: string(kv.Key), - LeaseID: kv.Lease, - TTL: leaseResp.TTL, - }) - } - return list, nil -} diff --git a/br/pkg/utils/store_manager.go b/br/pkg/utils/store_manager.go index cbddafd5bfe70..73e7e3fbb7a07 100644 --- a/br/pkg/utils/store_manager.go +++ b/br/pkg/utils/store_manager.go @@ -119,7 +119,7 @@ func (mgr *StoreManager) PDClient() pd.Client { } func (mgr *StoreManager) getGrpcConnLocked(ctx context.Context, storeID uint64) (*grpc.ClientConn, error) { - if v, _err_ := failpoint.Eval(_curpkg_("hint-get-backup-client")); _err_ == nil { + failpoint.Inject("hint-get-backup-client", func(v failpoint.Value) { log.Info("failpoint hint-get-backup-client injected, "+ "process will notify the shell.", zap.Uint64("store", storeID)) if sigFile, ok := v.(string); ok { @@ -132,7 +132,7 @@ func (mgr *StoreManager) getGrpcConnLocked(ctx context.Context, storeID uint64) } } time.Sleep(3 * time.Second) - } + }) store, err := mgr.pdClient.GetStore(ctx, storeID) if err != nil { return nil, errors.Trace(err) diff --git a/br/pkg/utils/store_manager.go__failpoint_stash__ b/br/pkg/utils/store_manager.go__failpoint_stash__ deleted file mode 100644 index 73e7e3fbb7a07..0000000000000 --- a/br/pkg/utils/store_manager.go__failpoint_stash__ +++ /dev/null @@ -1,264 +0,0 @@ -// Copyright 2022 PingCAP, Inc. Licensed under Apache-2.0. - -package utils - -import ( - "context" - "crypto/tls" - "os" - "sync" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - backuppb "github.com/pingcap/kvproto/pkg/brpb" - "github.com/pingcap/log" - berrors "github.com/pingcap/tidb/br/pkg/errors" - "github.com/pingcap/tidb/br/pkg/logutil" - pd "github.com/tikv/pd/client" - "go.uber.org/zap" - "google.golang.org/grpc" - "google.golang.org/grpc/backoff" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/keepalive" -) - -const ( - dialTimeout = 30 * time.Second - resetRetryTimes = 3 -) - -// Pool is a lazy pool of gRPC channels. -// When `Get` called, it lazily allocates new connection if connection not full. -// If it's full, then it will return allocated channels round-robin. -type Pool struct { - mu sync.Mutex - - conns []*grpc.ClientConn - next int - cap int - newConn func(ctx context.Context) (*grpc.ClientConn, error) -} - -func (p *Pool) takeConns() (conns []*grpc.ClientConn) { - p.mu.Lock() - defer p.mu.Unlock() - p.conns, conns = nil, p.conns - p.next = 0 - return conns -} - -// Close closes the conn pool. 
-func (p *Pool) Close() { - for _, c := range p.takeConns() { - if err := c.Close(); err != nil { - log.Warn("failed to close clientConn", zap.String("target", c.Target()), zap.Error(err)) - } - } -} - -// Get tries to get an existing connection from the pool, or make a new one if the pool not full. -func (p *Pool) Get(ctx context.Context) (*grpc.ClientConn, error) { - p.mu.Lock() - defer p.mu.Unlock() - if len(p.conns) < p.cap { - c, err := p.newConn(ctx) - if err != nil { - return nil, err - } - p.conns = append(p.conns, c) - return c, nil - } - - conn := p.conns[p.next] - p.next = (p.next + 1) % p.cap - return conn, nil -} - -// NewConnPool creates a new Pool by the specified conn factory function and capacity. -func NewConnPool(capacity int, newConn func(ctx context.Context) (*grpc.ClientConn, error)) *Pool { - return &Pool{ - cap: capacity, - conns: make([]*grpc.ClientConn, 0, capacity), - newConn: newConn, - - mu: sync.Mutex{}, - } -} - -type StoreManager struct { - pdClient pd.Client - grpcClis struct { - mu sync.Mutex - clis map[uint64]*grpc.ClientConn - } - keepalive keepalive.ClientParameters - tlsConf *tls.Config -} - -func (mgr *StoreManager) GetKeepalive() keepalive.ClientParameters { - return mgr.keepalive -} - -// NewStoreManager create a new manager for gRPC connections to stores. -func NewStoreManager(pdCli pd.Client, kl keepalive.ClientParameters, tlsConf *tls.Config) *StoreManager { - return &StoreManager{ - pdClient: pdCli, - grpcClis: struct { - mu sync.Mutex - clis map[uint64]*grpc.ClientConn - }{clis: make(map[uint64]*grpc.ClientConn)}, - keepalive: kl, - tlsConf: tlsConf, - } -} - -func (mgr *StoreManager) PDClient() pd.Client { - return mgr.pdClient -} - -func (mgr *StoreManager) getGrpcConnLocked(ctx context.Context, storeID uint64) (*grpc.ClientConn, error) { - failpoint.Inject("hint-get-backup-client", func(v failpoint.Value) { - log.Info("failpoint hint-get-backup-client injected, "+ - "process will notify the shell.", zap.Uint64("store", storeID)) - if sigFile, ok := v.(string); ok { - file, err := os.Create(sigFile) - if err != nil { - log.Warn("failed to create file for notifying, skipping notify", zap.Error(err)) - } - if file != nil { - file.Close() - } - } - time.Sleep(3 * time.Second) - }) - store, err := mgr.pdClient.GetStore(ctx, storeID) - if err != nil { - return nil, errors.Trace(err) - } - opt := grpc.WithTransportCredentials(insecure.NewCredentials()) - if mgr.tlsConf != nil { - opt = grpc.WithTransportCredentials(credentials.NewTLS(mgr.tlsConf)) - } - ctx, cancel := context.WithTimeout(ctx, dialTimeout) - bfConf := backoff.DefaultConfig - bfConf.MaxDelay = time.Second * 3 - addr := store.GetPeerAddress() - if addr == "" { - addr = store.GetAddress() - } - log.Info("StoreManager: dialing to store.", zap.String("address", addr), zap.Uint64("store-id", storeID)) - conn, err := grpc.DialContext( - ctx, - addr, - opt, - grpc.WithBlock(), - grpc.WithConnectParams(grpc.ConnectParams{Backoff: bfConf}), - grpc.WithKeepaliveParams(mgr.keepalive), - ) - cancel() - if err != nil { - return nil, berrors.ErrFailedToConnect.Wrap(err).GenWithStack("failed to make connection to store %d", storeID) - } - return conn, nil -} - -func (mgr *StoreManager) RemoveConn(ctx context.Context, storeID uint64) error { - if ctx.Err() != nil { - return errors.Trace(ctx.Err()) - } - - mgr.grpcClis.mu.Lock() - defer mgr.grpcClis.mu.Unlock() - - if conn, ok := mgr.grpcClis.clis[storeID]; ok { - // Find a cached backup client. 
- err := conn.Close() - if err != nil { - log.Warn("close backup connection failed, ignore it", zap.Uint64("storeID", storeID)) - } - delete(mgr.grpcClis.clis, storeID) - return nil - } - return nil -} - -func (mgr *StoreManager) TryWithConn(ctx context.Context, storeID uint64, f func(*grpc.ClientConn) error) error { - if ctx.Err() != nil { - return errors.Trace(ctx.Err()) - } - - mgr.grpcClis.mu.Lock() - defer mgr.grpcClis.mu.Unlock() - - if conn, ok := mgr.grpcClis.clis[storeID]; ok { - // Find a cached backup client. - return f(conn) - } - - conn, err := mgr.getGrpcConnLocked(ctx, storeID) - if err != nil { - return errors.Trace(err) - } - // Cache the conn. - mgr.grpcClis.clis[storeID] = conn - return f(conn) -} - -func (mgr *StoreManager) WithConn(ctx context.Context, storeID uint64, f func(*grpc.ClientConn)) error { - return mgr.TryWithConn(ctx, storeID, func(cc *grpc.ClientConn) error { f(cc); return nil }) -} - -// ResetBackupClient reset the connection for backup client. -func (mgr *StoreManager) ResetBackupClient(ctx context.Context, storeID uint64) (backuppb.BackupClient, error) { - var ( - conn *grpc.ClientConn - err error - ) - err = mgr.RemoveConn(ctx, storeID) - if err != nil { - return nil, errors.Trace(err) - } - - mgr.grpcClis.mu.Lock() - defer mgr.grpcClis.mu.Unlock() - - for retry := 0; retry < resetRetryTimes; retry++ { - conn, err = mgr.getGrpcConnLocked(ctx, storeID) - if err != nil { - log.Warn("failed to reset grpc connection, retry it", - zap.Int("retry time", retry), logutil.ShortError(err)) - time.Sleep(time.Duration(retry+3) * time.Second) - continue - } - mgr.grpcClis.clis[storeID] = conn - break - } - if err != nil { - return nil, errors.Trace(err) - } - return backuppb.NewBackupClient(conn), nil -} - -// Close closes all client in Mgr. 
-func (mgr *StoreManager) Close() { - if mgr == nil { - return - } - mgr.grpcClis.mu.Lock() - for _, cli := range mgr.grpcClis.clis { - err := cli.Close() - if err != nil { - log.Error("fail to close Mgr", zap.Error(err)) - } - } - mgr.grpcClis.mu.Unlock() -} - -func (mgr *StoreManager) TLSConfig() *tls.Config { - if mgr == nil { - return nil - } - return mgr.tlsConf -} diff --git a/dumpling/export/binding__failpoint_binding__.go b/dumpling/export/binding__failpoint_binding__.go deleted file mode 100644 index e39dec4835192..0000000000000 --- a/dumpling/export/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package export - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/dumpling/export/config.go b/dumpling/export/config.go index a799184ae43c1..52337ec732601 100644 --- a/dumpling/export/config.go +++ b/dumpling/export/config.go @@ -291,11 +291,11 @@ func (conf *Config) GetDriverConfig(db string) *mysql.Config { if conf.AllowCleartextPasswords { driverCfg.AllowCleartextPasswords = true } - if val, _err_ := failpoint.Eval(_curpkg_("SetWaitTimeout")); _err_ == nil { + failpoint.Inject("SetWaitTimeout", func(val failpoint.Value) { driverCfg.Params = map[string]string{ "wait_timeout": strconv.Itoa(val.(int)), } - } + }) return driverCfg } diff --git a/dumpling/export/config.go__failpoint_stash__ b/dumpling/export/config.go__failpoint_stash__ deleted file mode 100644 index 52337ec732601..0000000000000 --- a/dumpling/export/config.go__failpoint_stash__ +++ /dev/null @@ -1,809 +0,0 @@ -// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. 
- -package export - -import ( - "context" - "crypto/tls" - "encoding/json" - "fmt" - "net" - "strconv" - "strings" - "text/template" - "time" - - "github.com/coreos/go-semver/semver" - "github.com/docker/go-units" - "github.com/go-sql-driver/mysql" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/br/pkg/version" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/promutil" - filter "github.com/pingcap/tidb/pkg/util/table-filter" - "github.com/prometheus/client_golang/prometheus" - "github.com/spf13/pflag" - "go.uber.org/atomic" - "go.uber.org/zap" -) - -const ( - flagDatabase = "database" - flagTablesList = "tables-list" - flagHost = "host" - flagUser = "user" - flagPort = "port" - flagPassword = "password" - flagAllowCleartextPasswords = "allow-cleartext-passwords" - flagThreads = "threads" - flagFilesize = "filesize" - flagStatementSize = "statement-size" - flagOutput = "output" - flagLoglevel = "loglevel" - flagLogfile = "logfile" - flagLogfmt = "logfmt" - flagConsistency = "consistency" - flagSnapshot = "snapshot" - flagNoViews = "no-views" - flagNoSequences = "no-sequences" - flagSortByPk = "order-by-primary-key" - flagStatusAddr = "status-addr" - flagRows = "rows" - flagWhere = "where" - flagEscapeBackslash = "escape-backslash" - flagFiletype = "filetype" - flagNoHeader = "no-header" - flagNoSchemas = "no-schemas" - flagNoData = "no-data" - flagCsvNullValue = "csv-null-value" - flagSQL = "sql" - flagFilter = "filter" - flagCaseSensitive = "case-sensitive" - flagDumpEmptyDatabase = "dump-empty-database" - flagTidbMemQuotaQuery = "tidb-mem-quota-query" - flagCA = "ca" - flagCert = "cert" - flagKey = "key" - flagCsvSeparator = "csv-separator" - flagCsvDelimiter = "csv-delimiter" - flagCsvLineTerminator = "csv-line-terminator" - flagOutputFilenameTemplate = "output-filename-template" - flagCompleteInsert = "complete-insert" - flagParams = "params" - flagReadTimeout = "read-timeout" - flagTransactionalConsistency = "transactional-consistency" - flagCompress = "compress" - flagCsvOutputDialect = "csv-output-dialect" - - // FlagHelp represents the help flag - FlagHelp = "help" -) - -// CSVDialect is the dialect of the CSV output for compatible with different import target -type CSVDialect int - -const ( - // CSVDialectDefault is the default dialect, which is MySQL/MariaDB/TiDB etc. - CSVDialectDefault CSVDialect = iota - // CSVDialectSnowflake is the dialect of Snowflake - CSVDialectSnowflake - // CSVDialectRedshift is the dialect of Redshift - CSVDialectRedshift - // CSVDialectBigQuery is the dialect of BigQuery - CSVDialectBigQuery -) - -// BinaryFormat is the format of binary data -// Three standard formats are supported: UTF8, HEX and Base64 now. -type BinaryFormat int - -const ( - // BinaryFormatUTF8 is the default format, format binary data as UTF8 string - BinaryFormatUTF8 BinaryFormat = iota - // BinaryFormatHEX format binary data as HEX string, e.g. 12ABCD - BinaryFormatHEX - // BinaryFormatBase64 format binary data as Base64 string, e.g. 
123qwer== - BinaryFormatBase64 -) - -// DialectBinaryFormatMap is the map of dialect and binary format -var DialectBinaryFormatMap = map[CSVDialect]BinaryFormat{ - CSVDialectDefault: BinaryFormatUTF8, - CSVDialectSnowflake: BinaryFormatHEX, - CSVDialectRedshift: BinaryFormatHEX, - CSVDialectBigQuery: BinaryFormatBase64, -} - -// Config is the dump config for dumpling -type Config struct { - storage.BackendOptions - - SpecifiedTables bool - AllowCleartextPasswords bool - SortByPk bool - NoViews bool - NoSequences bool - NoHeader bool - NoSchemas bool - NoData bool - CompleteInsert bool - TransactionalConsistency bool - EscapeBackslash bool - DumpEmptyDatabase bool - PosAfterConnect bool - CompressType storage.CompressType - - Host string - Port int - Threads int - User string - Password string `json:"-"` - Security struct { - TLS *tls.Config `json:"-"` - CAPath string - CertPath string - KeyPath string - SSLCABytes []byte `json:"-"` - SSLCertBytes []byte `json:"-"` - SSLKeyBytes []byte `json:"-"` - } - - LogLevel string - LogFile string - LogFormat string - OutputDirPath string - StatusAddr string - Snapshot string - Consistency string - CsvNullValue string - SQL string - CsvSeparator string - CsvDelimiter string - CsvLineTerminator string - Databases []string - - TableFilter filter.Filter `json:"-"` - Where string - FileType string - ServerInfo version.ServerInfo - Logger *zap.Logger `json:"-"` - OutputFileTemplate *template.Template `json:"-"` - Rows uint64 - ReadTimeout time.Duration - TiDBMemQuotaQuery uint64 - FileSize uint64 - StatementSize uint64 - SessionParams map[string]any - Tables DatabaseTables - CollationCompatible string - CsvOutputDialect CSVDialect - - Labels prometheus.Labels `json:"-"` - PromFactory promutil.Factory `json:"-"` - PromRegistry promutil.Registry `json:"-"` - ExtStorage storage.ExternalStorage `json:"-"` - MinTLSVersion uint16 `json:"-"` - - IOTotalBytes *atomic.Uint64 - Net string -} - -// ServerInfoUnknown is the unknown database type to dumpling -var ServerInfoUnknown = version.ServerInfo{ - ServerType: version.ServerTypeUnknown, - ServerVersion: nil, -} - -// DefaultConfig returns the default export Config for dumpling -func DefaultConfig() *Config { - allFilter, _ := filter.Parse([]string{"*.*"}) - return &Config{ - Databases: nil, - Host: "127.0.0.1", - User: "root", - Port: 3306, - Password: "", - Threads: 4, - Logger: nil, - StatusAddr: ":8281", - FileSize: UnspecifiedSize, - StatementSize: DefaultStatementSize, - OutputDirPath: ".", - ServerInfo: ServerInfoUnknown, - SortByPk: true, - Tables: nil, - Snapshot: "", - Consistency: ConsistencyTypeAuto, - NoViews: true, - NoSequences: true, - Rows: UnspecifiedSize, - Where: "", - EscapeBackslash: true, - FileType: "", - NoHeader: false, - NoSchemas: false, - NoData: false, - CsvNullValue: "\\N", - SQL: "", - TableFilter: allFilter, - DumpEmptyDatabase: true, - CsvDelimiter: "\"", - CsvSeparator: ",", - CsvLineTerminator: "\r\n", - SessionParams: make(map[string]any), - OutputFileTemplate: DefaultOutputFileTemplate, - PosAfterConnect: false, - CollationCompatible: LooseCollationCompatible, - CsvOutputDialect: CSVDialectDefault, - SpecifiedTables: false, - PromFactory: promutil.NewDefaultFactory(), - PromRegistry: promutil.NewDefaultRegistry(), - TransactionalConsistency: true, - } -} - -// String returns dumpling's config in json format -func (conf *Config) String() string { - cfg, err := json.Marshal(conf) - if err != nil && conf.Logger != nil { - conf.Logger.Error("fail to marshal config to json", 
zap.Error(err)) - } - return string(cfg) -} - -// GetDriverConfig returns the MySQL driver config from Config. -func (conf *Config) GetDriverConfig(db string) *mysql.Config { - driverCfg := mysql.NewConfig() - // maxAllowedPacket=0 can be used to automatically fetch the max_allowed_packet variable from server on every connection. - // https://github.com/go-sql-driver/mysql#maxallowedpacket - hostPort := net.JoinHostPort(conf.Host, strconv.Itoa(conf.Port)) - driverCfg.User = conf.User - driverCfg.Passwd = conf.Password - driverCfg.Net = "tcp" - if conf.Net != "" { - driverCfg.Net = conf.Net - } - driverCfg.Addr = hostPort - driverCfg.DBName = db - driverCfg.Collation = "utf8mb4_general_ci" - driverCfg.ReadTimeout = conf.ReadTimeout - driverCfg.WriteTimeout = 30 * time.Second - driverCfg.InterpolateParams = true - driverCfg.MaxAllowedPacket = 0 - if conf.Security.TLS != nil { - driverCfg.TLS = conf.Security.TLS - } else { - // Use TLS first. - driverCfg.AllowFallbackToPlaintext = true - minTLSVersion := uint16(tls.VersionTLS12) - if conf.MinTLSVersion != 0 { - minTLSVersion = conf.MinTLSVersion - } - /* #nosec G402 */ - driverCfg.TLS = &tls.Config{ - InsecureSkipVerify: true, - MinVersion: minTLSVersion, - NextProtos: []string{"h2", "http/1.1"}, // specify `h2` to let Go use HTTP/2. - } - } - if conf.AllowCleartextPasswords { - driverCfg.AllowCleartextPasswords = true - } - failpoint.Inject("SetWaitTimeout", func(val failpoint.Value) { - driverCfg.Params = map[string]string{ - "wait_timeout": strconv.Itoa(val.(int)), - } - }) - return driverCfg -} - -func timestampDirName() string { - return fmt.Sprintf("./export-%s", time.Now().Format(time.RFC3339)) -} - -// DefineFlags defines flags of dumpling's configuration -func (*Config) DefineFlags(flags *pflag.FlagSet) { - storage.DefineFlags(flags) - flags.StringSliceP(flagDatabase, "B", nil, "Databases to dump") - flags.StringSliceP(flagTablesList, "T", nil, "Comma delimited table list to dump; must be qualified table names") - flags.StringP(flagHost, "h", "127.0.0.1", "The host to connect to") - flags.StringP(flagUser, "u", "root", "Username with privileges to run the dump") - flags.IntP(flagPort, "P", 4000, "TCP/IP port to connect to") - flags.StringP(flagPassword, "p", "", "User password") - flags.Bool(flagAllowCleartextPasswords, false, "Allow passwords to be sent in cleartext (warning: don't use without TLS)") - flags.IntP(flagThreads, "t", 4, "Number of goroutines to use, default 4") - flags.StringP(flagFilesize, "F", "", "The approximate size of output file") - flags.Uint64P(flagStatementSize, "s", DefaultStatementSize, "Attempted size of INSERT statement in bytes") - flags.StringP(flagOutput, "o", timestampDirName(), "Output directory") - flags.String(flagLoglevel, "info", "Log level: {debug|info|warn|error|dpanic|panic|fatal}") - flags.StringP(flagLogfile, "L", "", "Log file `path`, leave empty to write to console") - flags.String(flagLogfmt, "text", "Log `format`: {text|json}") - flags.String(flagConsistency, ConsistencyTypeAuto, "Consistency level during dumping: {auto|none|flush|lock|snapshot}") - flags.String(flagSnapshot, "", "Snapshot position (uint64 or MySQL style string timestamp). 
Valid only when consistency=snapshot") - flags.BoolP(flagNoViews, "W", true, "Do not dump views") - flags.Bool(flagNoSequences, true, "Do not dump sequences") - flags.Bool(flagSortByPk, true, "Sort dump results by primary key through order by sql") - flags.String(flagStatusAddr, ":8281", "dumpling API server and pprof addr") - flags.Uint64P(flagRows, "r", UnspecifiedSize, "If specified, dumpling will split table into chunks and concurrently dump them to different files to improve efficiency. For TiDB v3.0+, specify this will make dumpling split table with each file one TiDB region(no matter how many rows is).\n"+ - "If not specified, dumpling will dump table without inner-concurrency which could be relatively slow. default unlimited") - flags.String(flagWhere, "", "Dump only selected records") - flags.Bool(flagEscapeBackslash, true, "use backslash to escape special characters") - flags.String(flagFiletype, "", "The type of export file (sql/csv)") - flags.Bool(flagNoHeader, false, "whether not to dump CSV table header") - flags.BoolP(flagNoSchemas, "m", false, "Do not dump table schemas with the data") - flags.BoolP(flagNoData, "d", false, "Do not dump table data") - flags.String(flagCsvNullValue, "\\N", "The null value used when export to csv") - flags.StringP(flagSQL, "S", "", "Dump data with given sql. This argument doesn't support concurrent dump") - _ = flags.MarkHidden(flagSQL) - flags.StringSliceP(flagFilter, "f", []string{"*.*", DefaultTableFilter}, "filter to select which tables to dump") - flags.Bool(flagCaseSensitive, false, "whether the filter should be case-sensitive") - flags.Bool(flagDumpEmptyDatabase, true, "whether to dump empty database") - flags.Uint64(flagTidbMemQuotaQuery, UnspecifiedSize, "The maximum memory limit for a single SQL statement, in bytes.") - flags.String(flagCA, "", "The path name to the certificate authority file for TLS connection") - flags.String(flagCert, "", "The path name to the client certificate file for TLS connection") - flags.String(flagKey, "", "The path name to the client private key file for TLS connection") - flags.String(flagCsvSeparator, ",", "The separator for csv files, default ','") - flags.String(flagCsvDelimiter, "\"", "The delimiter for values in csv files, default '\"'") - flags.String(flagCsvLineTerminator, "\r\n", "The line terminator for csv files, default '\\r\\n'") - flags.String(flagOutputFilenameTemplate, "", "The output filename template (without file extension)") - flags.Bool(flagCompleteInsert, false, "Use complete INSERT statements that include column names") - flags.StringToString(flagParams, nil, `Extra session variables used while dumping, accepted format: --params "character_set_client=latin1,character_set_connection=latin1"`) - flags.Bool(FlagHelp, false, "Print help message and quit") - flags.Duration(flagReadTimeout, 15*time.Minute, "I/O read timeout for db connection.") - _ = flags.MarkHidden(flagReadTimeout) - flags.Bool(flagTransactionalConsistency, true, "Only support transactional consistency") - _ = flags.MarkHidden(flagTransactionalConsistency) - flags.StringP(flagCompress, "c", "", "Compress output file type, support 'gzip', 'snappy', 'zstd', 'no-compression' now") - flags.String(flagCsvOutputDialect, "", "The dialect of output CSV file, support 'snowflake', 'redshift', 'bigquery' now") -} - -// ParseFromFlags parses dumpling's export.Config from flags -// nolint: gocyclo -func (conf *Config) ParseFromFlags(flags *pflag.FlagSet) error { - var err error - conf.Databases, err = 
flags.GetStringSlice(flagDatabase) - if err != nil { - return errors.Trace(err) - } - conf.Host, err = flags.GetString(flagHost) - if err != nil { - return errors.Trace(err) - } - conf.User, err = flags.GetString(flagUser) - if err != nil { - return errors.Trace(err) - } - conf.Port, err = flags.GetInt(flagPort) - if err != nil { - return errors.Trace(err) - } - conf.Password, err = flags.GetString(flagPassword) - if err != nil { - return errors.Trace(err) - } - conf.AllowCleartextPasswords, err = flags.GetBool(flagAllowCleartextPasswords) - if err != nil { - return errors.Trace(err) - } - conf.Threads, err = flags.GetInt(flagThreads) - if err != nil { - return errors.Trace(err) - } - conf.StatementSize, err = flags.GetUint64(flagStatementSize) - if err != nil { - return errors.Trace(err) - } - conf.OutputDirPath, err = flags.GetString(flagOutput) - if err != nil { - return errors.Trace(err) - } - conf.LogLevel, err = flags.GetString(flagLoglevel) - if err != nil { - return errors.Trace(err) - } - conf.LogFile, err = flags.GetString(flagLogfile) - if err != nil { - return errors.Trace(err) - } - conf.LogFormat, err = flags.GetString(flagLogfmt) - if err != nil { - return errors.Trace(err) - } - conf.Consistency, err = flags.GetString(flagConsistency) - if err != nil { - return errors.Trace(err) - } - conf.Snapshot, err = flags.GetString(flagSnapshot) - if err != nil { - return errors.Trace(err) - } - conf.NoViews, err = flags.GetBool(flagNoViews) - if err != nil { - return errors.Trace(err) - } - conf.NoSequences, err = flags.GetBool(flagNoSequences) - if err != nil { - return errors.Trace(err) - } - conf.SortByPk, err = flags.GetBool(flagSortByPk) - if err != nil { - return errors.Trace(err) - } - conf.StatusAddr, err = flags.GetString(flagStatusAddr) - if err != nil { - return errors.Trace(err) - } - conf.Rows, err = flags.GetUint64(flagRows) - if err != nil { - return errors.Trace(err) - } - conf.Where, err = flags.GetString(flagWhere) - if err != nil { - return errors.Trace(err) - } - conf.EscapeBackslash, err = flags.GetBool(flagEscapeBackslash) - if err != nil { - return errors.Trace(err) - } - conf.FileType, err = flags.GetString(flagFiletype) - if err != nil { - return errors.Trace(err) - } - conf.NoHeader, err = flags.GetBool(flagNoHeader) - if err != nil { - return errors.Trace(err) - } - conf.NoSchemas, err = flags.GetBool(flagNoSchemas) - if err != nil { - return errors.Trace(err) - } - conf.NoData, err = flags.GetBool(flagNoData) - if err != nil { - return errors.Trace(err) - } - conf.CsvNullValue, err = flags.GetString(flagCsvNullValue) - if err != nil { - return errors.Trace(err) - } - conf.SQL, err = flags.GetString(flagSQL) - if err != nil { - return errors.Trace(err) - } - conf.DumpEmptyDatabase, err = flags.GetBool(flagDumpEmptyDatabase) - if err != nil { - return errors.Trace(err) - } - conf.Security.CAPath, err = flags.GetString(flagCA) - if err != nil { - return errors.Trace(err) - } - conf.Security.CertPath, err = flags.GetString(flagCert) - if err != nil { - return errors.Trace(err) - } - conf.Security.KeyPath, err = flags.GetString(flagKey) - if err != nil { - return errors.Trace(err) - } - conf.CsvSeparator, err = flags.GetString(flagCsvSeparator) - if err != nil { - return errors.Trace(err) - } - conf.CsvDelimiter, err = flags.GetString(flagCsvDelimiter) - if err != nil { - return errors.Trace(err) - } - conf.CsvLineTerminator, err = flags.GetString(flagCsvLineTerminator) - if err != nil { - return errors.Trace(err) - } - conf.CompleteInsert, err = 
flags.GetBool(flagCompleteInsert) - if err != nil { - return errors.Trace(err) - } - conf.ReadTimeout, err = flags.GetDuration(flagReadTimeout) - if err != nil { - return errors.Trace(err) - } - conf.TransactionalConsistency, err = flags.GetBool(flagTransactionalConsistency) - if err != nil { - return errors.Trace(err) - } - conf.TiDBMemQuotaQuery, err = flags.GetUint64(flagTidbMemQuotaQuery) - if err != nil { - return errors.Trace(err) - } - - if conf.Threads <= 0 { - return errors.Errorf("--threads is set to %d. It should be greater than 0", conf.Threads) - } - if len(conf.CsvSeparator) == 0 { - return errors.New("--csv-separator is set to \"\". It must not be an empty string") - } - - if conf.SessionParams == nil { - conf.SessionParams = make(map[string]any) - } - - tablesList, err := flags.GetStringSlice(flagTablesList) - if err != nil { - return errors.Trace(err) - } - fileSizeStr, err := flags.GetString(flagFilesize) - if err != nil { - return errors.Trace(err) - } - filters, err := flags.GetStringSlice(flagFilter) - if err != nil { - return errors.Trace(err) - } - caseSensitive, err := flags.GetBool(flagCaseSensitive) - if err != nil { - return errors.Trace(err) - } - outputFilenameFormat, err := flags.GetString(flagOutputFilenameTemplate) - if err != nil { - return errors.Trace(err) - } - params, err := flags.GetStringToString(flagParams) - if err != nil { - return errors.Trace(err) - } - - conf.SpecifiedTables = len(tablesList) > 0 - conf.Tables, err = GetConfTables(tablesList) - if err != nil { - return errors.Trace(err) - } - - conf.TableFilter, err = ParseTableFilter(tablesList, filters) - if err != nil { - return errors.Errorf("failed to parse filter: %s", err) - } - - if !caseSensitive { - conf.TableFilter = filter.CaseInsensitive(conf.TableFilter) - } - - conf.FileSize, err = ParseFileSize(fileSizeStr) - if err != nil { - return errors.Trace(err) - } - - if outputFilenameFormat == "" && conf.SQL != "" { - outputFilenameFormat = DefaultAnonymousOutputFileTemplateText - } - tmpl, err := ParseOutputFileTemplate(outputFilenameFormat) - if err != nil { - return errors.Errorf("failed to parse output filename template (--output-filename-template '%s')", outputFilenameFormat) - } - conf.OutputFileTemplate = tmpl - - compressType, err := flags.GetString(flagCompress) - if err != nil { - return errors.Trace(err) - } - conf.CompressType, err = ParseCompressType(compressType) - if err != nil { - return errors.Trace(err) - } - - dialect, err := flags.GetString(flagCsvOutputDialect) - if err != nil { - return errors.Trace(err) - } - if dialect != "" && conf.FileType != "csv" { - return errors.Errorf("%s is only supported when dumping whole table to csv, not compatible with %s", flagCsvOutputDialect, conf.FileType) - } - conf.CsvOutputDialect, err = ParseOutputDialect(dialect) - if err != nil { - return errors.Trace(err) - } - - for k, v := range params { - conf.SessionParams[k] = v - } - - err = conf.BackendOptions.ParseFromFlags(pflag.CommandLine) - if err != nil { - return errors.Trace(err) - } - - return nil -} - -// ParseFileSize parses file size from tables-list and filter arguments -func ParseFileSize(fileSizeStr string) (uint64, error) { - if len(fileSizeStr) == 0 { - return UnspecifiedSize, nil - } else if fileSizeMB, err := strconv.ParseUint(fileSizeStr, 10, 64); err == nil { - fmt.Printf("Warning: -F without unit is not recommended, try using `-F '%dMiB'` in the future\n", fileSizeMB) - return fileSizeMB * units.MiB, nil - } else if size, err := units.RAMInBytes(fileSizeStr); 
err == nil {
-		return uint64(size), nil
-	}
-	return 0, errors.Errorf("failed to parse filesize (-F '%s')", fileSizeStr)
-}
-
-// ParseTableFilter parses table filter from tables-list and filter arguments
-func ParseTableFilter(tablesList, filters []string) (filter.Filter, error) {
-	if len(tablesList) == 0 {
-		return filter.Parse(filters)
-	}
-
-	// only parse -T when -f is default value. otherwise bail out.
-	if !sameStringArray(filters, []string{"*.*", DefaultTableFilter}) {
-		return nil, errors.New("cannot pass --tables-list and --filter together")
-	}
-
-	tableNames := make([]filter.Table, 0, len(tablesList))
-	for _, table := range tablesList {
-		parts := strings.SplitN(table, ".", 2)
-		if len(parts) < 2 {
-			return nil, errors.Errorf("--tables-list only accepts qualified table names, but `%s` lacks a dot", table)
-		}
-		tableNames = append(tableNames, filter.Table{Schema: parts[0], Name: parts[1]})
-	}
-
-	return filter.NewTablesFilter(tableNames...), nil
-}
-
-// GetConfTables parses tables from tables-list and filter arguments
-func GetConfTables(tablesList []string) (DatabaseTables, error) {
-	dbTables := DatabaseTables{}
-	var (
-		tablename    string
-		avgRowLength uint64
-	)
-	avgRowLength = 0
-	for _, tablename = range tablesList {
-		parts := strings.SplitN(tablename, ".", 2)
-		if len(parts) < 2 {
-			return nil, errors.Errorf("--tables-list only accepts qualified table names, but `%s` lacks a dot", tablename)
-		}
-		dbName := parts[0]
-		tbName := parts[1]
-		dbTables[dbName] = append(dbTables[dbName], &TableInfo{tbName, avgRowLength, TableTypeBase})
-	}
-	return dbTables, nil
-}
-
-// ParseCompressType parses compressType string to storage.CompressType
-func ParseCompressType(compressType string) (storage.CompressType, error) {
-	switch compressType {
-	case "", "no-compression":
-		return storage.NoCompression, nil
-	case "gzip", "gz":
-		return storage.Gzip, nil
-	case "snappy":
-		return storage.Snappy, nil
-	case "zstd", "zst":
-		return storage.Zstd, nil
-	default:
-		return storage.NoCompression, errors.Errorf("unknown compress type %s", compressType)
-	}
-}
-
-// ParseOutputDialect parses output dialect string to Dialect
-func ParseOutputDialect(outputDialect string) (CSVDialect, error) {
-	switch outputDialect {
-	case "", "default":
-		return CSVDialectDefault, nil
-	case "snowflake":
-		return CSVDialectSnowflake, nil
-	case "redshift":
-		return CSVDialectRedshift, nil
-	case "bigquery":
-		return CSVDialectBigQuery, nil
-	default:
-		return CSVDialectDefault, errors.Errorf("unknown output dialect %s", outputDialect)
-	}
-}
-
-func (conf *Config) createExternalStorage(ctx context.Context) (storage.ExternalStorage, error) {
-	if conf.ExtStorage != nil {
-		return conf.ExtStorage, nil
-	}
-	b, err := storage.ParseBackend(conf.OutputDirPath, &conf.BackendOptions)
-	if err != nil {
-		return nil, errors.Trace(err)
-	}
-
-	// TODO: support setting httpClient with certification later
-	return storage.New(ctx, b, &storage.ExternalStorageOptions{})
-}
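// The helpers above are small and composable: ParseFileSize treats a bare
// number as a MiB count and otherwise defers to units.RAMInBytes, while
// ParseCompressType and ParseOutputDialect map the CLI spellings onto the
// corresponding enum values. A minimal usage sketch, assuming the exported
// helpers are imported from the dumpling export package
// (github.com/pingcap/tidb/dumpling/export):
package main

import (
	"fmt"

	"github.com/pingcap/tidb/dumpling/export"
)

func main() {
	// "256MiB" goes through units.RAMInBytes => 268435456 bytes.
	size, err := export.ParseFileSize("256MiB")
	if err != nil {
		panic(err)
	}
	// "gz" is accepted as an alias for gzip => storage.Gzip.
	compress, err := export.ParseCompressType("gz")
	if err != nil {
		panic(err)
	}
	fmt.Println(size, compress)
}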
-
-const (
-	// UnspecifiedSize means the filesize/statement-size is unspecified
-	UnspecifiedSize = 0
-	// DefaultStatementSize is the default statement size
-	DefaultStatementSize = 1000000
-	// TiDBMemQuotaQueryName is the session variable TiDBMemQuotaQuery's name in TiDB
-	TiDBMemQuotaQueryName = "tidb_mem_quota_query"
-	// DefaultTableFilter is the default exclude table filter. It will exclude all system databases
-	DefaultTableFilter = "!/^(mysql|sys|INFORMATION_SCHEMA|PERFORMANCE_SCHEMA|METRICS_SCHEMA|INSPECTION_SCHEMA)$/.*"
-
-	defaultTaskChannelCapacity = 128
-	defaultDumpGCSafePointTTL  = 5 * 60
-	defaultEtcdDialTimeOut     = 3 * time.Second
-
-	// LooseCollationCompatible is used in DM, represents a collation setting for best compatibility.
-	LooseCollationCompatible = "loose"
-	// StrictCollationCompatible is used in DM, represents a collation setting for correctness.
-	StrictCollationCompatible = "strict"
-
-	dumplingServiceSafePointPrefix = "dumpling"
-)
-
-var (
-	decodeRegionVersion = semver.New("3.0.0")
-	gcSafePointVersion  = semver.New("4.0.0")
-	tableSampleVersion  = semver.New("5.0.0-nightly")
-)
-
-func adjustConfig(conf *Config, fns ...func(*Config) error) error {
-	for _, f := range fns {
-		err := f(conf)
-		if err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
-func buildTLSConfig(conf *Config) error {
-	tlsConfig, err := util.NewTLSConfig(
-		util.WithCAPath(conf.Security.CAPath),
-		util.WithCertAndKeyPath(conf.Security.CertPath, conf.Security.KeyPath),
-		util.WithCAContent(conf.Security.SSLCABytes),
-		util.WithCertAndKeyContent(conf.Security.SSLCertBytes, conf.Security.SSLKeyBytes),
-		util.WithMinTLSVersion(conf.MinTLSVersion),
-	)
-	if err != nil {
-		return errors.Trace(err)
-	}
-	conf.Security.TLS = tlsConfig
-	return nil
-}
-
-func validateSpecifiedSQL(conf *Config) error {
-	if conf.SQL != "" && conf.Where != "" {
-		return errors.New("can't specify both --sql and --where at the same time. Please try to combine them into --sql")
-	}
-	return nil
-}
-
-func adjustFileFormat(conf *Config) error {
-	conf.FileType = strings.ToLower(conf.FileType)
-	switch conf.FileType {
-	case "":
-		if conf.SQL != "" {
-			conf.FileType = FileFormatCSVString
-		} else {
-			conf.FileType = FileFormatSQLTextString
-		}
-	case FileFormatSQLTextString:
-		if conf.SQL != "" {
-			return errors.Errorf("unsupported config.FileType '%s' when we specify --sql, please unset --filetype or set it to 'csv'", conf.FileType)
-		}
-	case FileFormatCSVString:
-	default:
-		return errors.Errorf("unknown config.FileType '%s'", conf.FileType)
-	}
-	return nil
-}
-
-func matchMysqlBugversion(info version.ServerInfo) bool {
-	// if 8.0.3 <= mysql8 version < 8.0.23
-	// FLUSH TABLES WITH READ LOCK could block other sessions from executing SHOW TABLE STATUS.
- // see more in https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-23.html - if info.ServerType != version.ServerTypeMySQL { - return false - } - currentVersion := info.ServerVersion - bugVersionStart := semver.New("8.0.2") - bugVersionEnd := semver.New("8.0.23") - return bugVersionStart.LessThan(*currentVersion) && currentVersion.LessThan(*bugVersionEnd) -} diff --git a/dumpling/export/dump.go b/dumpling/export/dump.go index 0d8627108e089..0d94c814eaa77 100644 --- a/dumpling/export/dump.go +++ b/dumpling/export/dump.go @@ -72,7 +72,7 @@ type Dumper struct { // NewDumper returns a new Dumper func NewDumper(ctx context.Context, conf *Config) (*Dumper, error) { - if val, _err_ := failpoint.Eval(_curpkg_("setExtStorage")); _err_ == nil { + failpoint.Inject("setExtStorage", func(val failpoint.Value) { path := val.(string) b, err := storage.ParseBackend(path, nil) if err != nil { @@ -83,7 +83,7 @@ func NewDumper(ctx context.Context, conf *Config) (*Dumper, error) { panic(err) } conf.ExtStorage = s - } + }) tctx, cancelFn := tcontext.Background().WithContext(ctx).WithCancel() d := &Dumper{ @@ -111,7 +111,7 @@ func NewDumper(ctx context.Context, conf *Config) (*Dumper, error) { if err != nil { return nil, err } - if _, _err_ := failpoint.Eval(_curpkg_("SetIOTotalBytes")); _err_ == nil { + failpoint.Inject("SetIOTotalBytes", func(_ failpoint.Value) { d.conf.IOTotalBytes = gatomic.NewUint64(0) d.conf.Net = uuid.New().String() go func() { @@ -120,7 +120,7 @@ func NewDumper(ctx context.Context, conf *Config) (*Dumper, error) { d.tctx.L().Logger.Info("IOTotalBytes", zap.Uint64("IOTotalBytes", d.conf.IOTotalBytes.Load())) } }() - } + }) err = runSteps(d, initLogger, @@ -257,9 +257,9 @@ func (d *Dumper) Dump() (dumpErr error) { } chanSize := defaultTaskChannelCapacity - if _, _err_ := failpoint.Eval(_curpkg_("SmallDumpChanSize")); _err_ == nil { + failpoint.Inject("SmallDumpChanSize", func() { chanSize = 1 - } + }) taskIn, taskOut := infiniteChan[Task]() // todo: refine metrics AddGauge(d.metrics.taskChannelCapacity, float64(chanSize)) @@ -280,7 +280,7 @@ func (d *Dumper) Dump() (dumpErr error) { } } // Inject consistency failpoint test after we release the table lock - failpoint.Eval(_curpkg_("ConsistencyCheck")) + failpoint.Inject("ConsistencyCheck", nil) if conf.PosAfterConnect { // record again, to provide a location to exit safe mode for DM @@ -300,7 +300,7 @@ func (d *Dumper) Dump() (dumpErr error) { tableDataStartTime := time.Now() - if _, _err_ := failpoint.Eval(_curpkg_("PrintTiDBMemQuotaQuery")); _err_ == nil { + failpoint.Inject("PrintTiDBMemQuotaQuery", func(_ failpoint.Value) { row := d.dbHandle.QueryRowContext(tctx, "select @@tidb_mem_quota_query;") var s string err = row.Scan(&s) @@ -309,7 +309,7 @@ func (d *Dumper) Dump() (dumpErr error) { } else { fmt.Printf("tidb_mem_quota_query == %s\n", s) } - } + }) baseConn := newBaseConn(metaConn, true, rebuildMetaConn) if conf.SQL == "" { @@ -321,10 +321,10 @@ func (d *Dumper) Dump() (dumpErr error) { } d.metrics.progressReady.Store(true) close(taskIn) - if _, _err_ := failpoint.Eval(_curpkg_("EnableLogProgress")); _err_ == nil { + failpoint.Inject("EnableLogProgress", func() { time.Sleep(1 * time.Second) tctx.L().Debug("progress ready, sleep 1s") - } + }) _ = baseConn.DBConn.Close() if err := wg.Wait(); err != nil { summary.CollectFailureUnit("dump table data", err) @@ -357,10 +357,10 @@ func (d *Dumper) startWriters(tctx *tcontext.Context, wg *errgroup.Group, taskCh // tctx.L().Debug("finished dumping table data", // 
zap.String("database", td.Meta.DatabaseName()), // zap.String("table", td.Meta.TableName())) - if _, _err_ := failpoint.Eval(_curpkg_("EnableLogProgress")); _err_ == nil { + failpoint.Inject("EnableLogProgress", func() { time.Sleep(1 * time.Second) tctx.L().Debug("EnableLogProgress, sleep 1s") - } + }) } }) writer.setFinishTaskCallBack(func(task Task) { diff --git a/dumpling/export/dump.go__failpoint_stash__ b/dumpling/export/dump.go__failpoint_stash__ deleted file mode 100644 index 0d94c814eaa77..0000000000000 --- a/dumpling/export/dump.go__failpoint_stash__ +++ /dev/null @@ -1,1704 +0,0 @@ -// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. - -package export - -import ( - "bytes" - "context" - "database/sql" - "database/sql/driver" - "encoding/hex" - "fmt" - "math/big" - "net" - "slices" - "strconv" - "strings" - "sync/atomic" - "time" - - "github.com/coreos/go-semver/semver" - // import mysql driver - "github.com/go-sql-driver/mysql" - "github.com/google/uuid" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - pclog "github.com/pingcap/log" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/br/pkg/summary" - "github.com/pingcap/tidb/br/pkg/version" - "github.com/pingcap/tidb/dumpling/cli" - tcontext "github.com/pingcap/tidb/dumpling/context" - "github.com/pingcap/tidb/dumpling/log" - infoschema "github.com/pingcap/tidb/pkg/infoschema/context" - "github.com/pingcap/tidb/pkg/parser" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/format" - "github.com/pingcap/tidb/pkg/store/helper" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/codec" - pd "github.com/tikv/pd/client" - gatomic "go.uber.org/atomic" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" -) - -var openDBFunc = openDB - -var errEmptyHandleVals = errors.New("empty handleVals for TiDB table") - -// After TiDB v6.2.0 we always enable tidb_enable_paging by default. 
-// see https://docs.pingcap.com/zh/tidb/dev/system-variables#tidb_enable_paging-%E4%BB%8E-v540-%E7%89%88%E6%9C%AC%E5%BC%80%E5%A7%8B%E5%BC%95%E5%85%A5 -var enablePagingVersion = semver.New("6.2.0") - -// Dumper is the dump progress structure -type Dumper struct { - tctx *tcontext.Context - cancelCtx context.CancelFunc - conf *Config - metrics *metrics - - extStore storage.ExternalStorage - dbHandle *sql.DB - - tidbPDClientForGC pd.Client - selectTiDBTableRegionFunc func(tctx *tcontext.Context, conn *BaseConn, meta TableMeta) (pkFields []string, pkVals [][]string, err error) - totalTables int64 - charsetAndDefaultCollationMap map[string]string - - speedRecorder *SpeedRecorder -} - -// NewDumper returns a new Dumper -func NewDumper(ctx context.Context, conf *Config) (*Dumper, error) { - failpoint.Inject("setExtStorage", func(val failpoint.Value) { - path := val.(string) - b, err := storage.ParseBackend(path, nil) - if err != nil { - panic(err) - } - s, err := storage.New(context.Background(), b, &storage.ExternalStorageOptions{}) - if err != nil { - panic(err) - } - conf.ExtStorage = s - }) - - tctx, cancelFn := tcontext.Background().WithContext(ctx).WithCancel() - d := &Dumper{ - tctx: tctx, - conf: conf, - cancelCtx: cancelFn, - selectTiDBTableRegionFunc: selectTiDBTableRegion, - speedRecorder: NewSpeedRecorder(), - } - - var err error - - d.metrics = newMetrics(conf.PromFactory, conf.Labels) - d.metrics.registerTo(conf.PromRegistry) - defer func() { - if err != nil { - d.metrics.unregisterFrom(conf.PromRegistry) - } - }() - - err = adjustConfig(conf, - buildTLSConfig, - validateSpecifiedSQL, - adjustFileFormat) - if err != nil { - return nil, err - } - failpoint.Inject("SetIOTotalBytes", func(_ failpoint.Value) { - d.conf.IOTotalBytes = gatomic.NewUint64(0) - d.conf.Net = uuid.New().String() - go func() { - for { - time.Sleep(10 * time.Millisecond) - d.tctx.L().Logger.Info("IOTotalBytes", zap.Uint64("IOTotalBytes", d.conf.IOTotalBytes.Load())) - } - }() - }) - - err = runSteps(d, - initLogger, - createExternalStore, - startHTTPService, - openSQLDB, - detectServerInfo, - resolveAutoConsistency, - - validateResolveAutoConsistency, - tidbSetPDClientForGC, - tidbGetSnapshot, - tidbStartGCSavepointUpdateService, - - setSessionParam) - return d, err -} - -// Dump dumps table from database -// nolint: gocyclo -func (d *Dumper) Dump() (dumpErr error) { - initColTypeRowReceiverMap() - var ( - conn *sql.Conn - err error - conCtrl ConsistencyController - ) - tctx, conf, pool := d.tctx, d.conf, d.dbHandle - tctx.L().Info("begin to run Dump", zap.Stringer("conf", conf)) - m := newGlobalMetadata(tctx, d.extStore, conf.Snapshot) - repeatableRead := needRepeatableRead(conf.ServerInfo.ServerType, conf.Consistency) - defer func() { - if dumpErr == nil { - _ = m.writeGlobalMetaData() - } - }() - - // for consistency lock, we should get table list at first to generate the lock tables SQL - if conf.Consistency == ConsistencyTypeLock { - conn, err = createConnWithConsistency(tctx, pool, repeatableRead) - if err != nil { - return errors.Trace(err) - } - if err = prepareTableListToDump(tctx, conf, conn); err != nil { - _ = conn.Close() - return err - } - _ = conn.Close() - } - - conCtrl, err = NewConsistencyController(tctx, conf, pool) - if err != nil { - return err - } - if err = conCtrl.Setup(tctx); err != nil { - return errors.Trace(err) - } - // To avoid lock is not released - defer func() { - err = conCtrl.TearDown(tctx) - if err != nil { - tctx.L().Warn("fail to tear down consistency controller", 
zap.Error(err)) - } - }() - - metaConn, err := createConnWithConsistency(tctx, pool, repeatableRead) - if err != nil { - return err - } - defer func() { - _ = metaConn.Close() - }() - m.recordStartTime(time.Now()) - // for consistency lock, we can write snapshot info after all tables are locked. - // the binlog pos may changed because there is still possible write between we lock tables and write master status. - // but for the locked tables doing replication that starts from metadata is safe. - // for consistency flush, record snapshot after whole tables are locked. The recorded meta info is exactly the locked snapshot. - // for consistency snapshot, we should use the snapshot that we get/set at first in metadata. TiDB will assure the snapshot of TSO. - // for consistency none, the binlog pos in metadata might be earlier than dumped data. We need to enable safe-mode to assure data safety. - err = m.recordGlobalMetaData(metaConn, conf.ServerInfo.ServerType, false) - if err != nil { - tctx.L().Info("get global metadata failed", log.ShortError(err)) - } - - if d.conf.CollationCompatible == StrictCollationCompatible { - //init charset and default collation map - d.charsetAndDefaultCollationMap, err = GetCharsetAndDefaultCollation(tctx.Context, metaConn) - if err != nil { - return err - } - } - - // for other consistencies, we should get table list after consistency is set up and GlobalMetaData is cached - if conf.Consistency != ConsistencyTypeLock { - if err = prepareTableListToDump(tctx, conf, metaConn); err != nil { - return err - } - } - if err = d.renewSelectTableRegionFuncForLowerTiDB(tctx); err != nil { - tctx.L().Info("cannot update select table region info for TiDB", log.ShortError(err)) - } - - atomic.StoreInt64(&d.totalTables, int64(calculateTableCount(conf.Tables))) - - rebuildMetaConn := func(conn *sql.Conn, updateMeta bool) (*sql.Conn, error) { - _ = conn.Raw(func(any) error { - // return an `ErrBadConn` to ensure close the connection, but do not put it back to the pool. - // if we choose to use `Close`, it will always put the connection back to the pool. - return driver.ErrBadConn - }) - - newConn, err1 := createConnWithConsistency(tctx, pool, repeatableRead) - if err1 != nil { - return conn, errors.Trace(err1) - } - conn = newConn - // renew the master status after connection. dm can't close safe-mode until dm reaches current pos - if updateMeta && conf.PosAfterConnect { - err1 = m.recordGlobalMetaData(conn, conf.ServerInfo.ServerType, true) - if err1 != nil { - return conn, errors.Trace(err1) - } - } - return conn, nil - } - - rebuildConn := func(conn *sql.Conn, updateMeta bool) (*sql.Conn, error) { - // make sure that the lock connection is still alive - err1 := conCtrl.PingContext(tctx) - if err1 != nil { - return conn, errors.Trace(err1) - } - return rebuildMetaConn(conn, updateMeta) - } - - chanSize := defaultTaskChannelCapacity - failpoint.Inject("SmallDumpChanSize", func() { - chanSize = 1 - }) - taskIn, taskOut := infiniteChan[Task]() - // todo: refine metrics - AddGauge(d.metrics.taskChannelCapacity, float64(chanSize)) - wg, writingCtx := errgroup.WithContext(tctx) - writerCtx := tctx.WithContext(writingCtx) - writers, tearDownWriters, err := d.startWriters(writerCtx, wg, taskOut, rebuildConn) - if err != nil { - return err - } - defer tearDownWriters() - - if conf.TransactionalConsistency { - if conf.Consistency == ConsistencyTypeFlush || conf.Consistency == ConsistencyTypeLock { - tctx.L().Info("All the dumping transactions have started. 
Start to unlock tables") - } - if err = conCtrl.TearDown(tctx); err != nil { - return errors.Trace(err) - } - } - // Inject consistency failpoint test after we release the table lock - failpoint.Inject("ConsistencyCheck", nil) - - if conf.PosAfterConnect { - // record again, to provide a location to exit safe mode for DM - err = m.recordGlobalMetaData(metaConn, conf.ServerInfo.ServerType, true) - if err != nil { - tctx.L().Info("get global metadata (after connection pool established) failed", log.ShortError(err)) - } - } - - summary.SetLogCollector(summary.NewLogCollector(tctx.L().Info)) - summary.SetUnit(summary.BackupUnit) - defer summary.Summary(summary.BackupUnit) - - logProgressCtx, logProgressCancel := tctx.WithCancel() - go d.runLogProgress(logProgressCtx) - defer logProgressCancel() - - tableDataStartTime := time.Now() - - failpoint.Inject("PrintTiDBMemQuotaQuery", func(_ failpoint.Value) { - row := d.dbHandle.QueryRowContext(tctx, "select @@tidb_mem_quota_query;") - var s string - err = row.Scan(&s) - if err != nil { - fmt.Println(errors.Trace(err)) - } else { - fmt.Printf("tidb_mem_quota_query == %s\n", s) - } - }) - baseConn := newBaseConn(metaConn, true, rebuildMetaConn) - - if conf.SQL == "" { - if err = d.dumpDatabases(writerCtx, baseConn, taskIn); err != nil && !errors.ErrorEqual(err, context.Canceled) { - return err - } - } else { - d.dumpSQL(writerCtx, baseConn, taskIn) - } - d.metrics.progressReady.Store(true) - close(taskIn) - failpoint.Inject("EnableLogProgress", func() { - time.Sleep(1 * time.Second) - tctx.L().Debug("progress ready, sleep 1s") - }) - _ = baseConn.DBConn.Close() - if err := wg.Wait(); err != nil { - summary.CollectFailureUnit("dump table data", err) - return errors.Trace(err) - } - summary.CollectSuccessUnit("dump cost", countTotalTask(writers), time.Since(tableDataStartTime)) - - summary.SetSuccessStatus(true) - m.recordFinishTime(time.Now()) - return nil -} - -func (d *Dumper) startWriters(tctx *tcontext.Context, wg *errgroup.Group, taskChan <-chan Task, - rebuildConnFn func(*sql.Conn, bool) (*sql.Conn, error)) ([]*Writer, func(), error) { - conf, pool := d.conf, d.dbHandle - writers := make([]*Writer, conf.Threads) - for i := 0; i < conf.Threads; i++ { - conn, err := createConnWithConsistency(tctx, pool, needRepeatableRead(conf.ServerInfo.ServerType, conf.Consistency)) - if err != nil { - return nil, func() {}, err - } - writer := NewWriter(tctx, int64(i), conf, conn, d.extStore, d.metrics) - writer.rebuildConnFn = rebuildConnFn - writer.setFinishTableCallBack(func(task Task) { - if _, ok := task.(*TaskTableData); ok { - IncCounter(d.metrics.finishedTablesCounter) - // FIXME: actually finishing the last chunk doesn't means this table is 'finished'. - // We can call this table is 'finished' if all its chunks are finished. - // Comment this log now to avoid ambiguity. 
- // tctx.L().Debug("finished dumping table data", - // zap.String("database", td.Meta.DatabaseName()), - // zap.String("table", td.Meta.TableName())) - failpoint.Inject("EnableLogProgress", func() { - time.Sleep(1 * time.Second) - tctx.L().Debug("EnableLogProgress, sleep 1s") - }) - } - }) - writer.setFinishTaskCallBack(func(task Task) { - IncGauge(d.metrics.taskChannelCapacity) - if td, ok := task.(*TaskTableData); ok { - d.metrics.completedChunks.Add(1) - tctx.L().Debug("finish dumping table data task", - zap.String("database", td.Meta.DatabaseName()), - zap.String("table", td.Meta.TableName()), - zap.Int("chunkIdx", td.ChunkIndex)) - } - }) - wg.Go(func() error { - return writer.run(taskChan) - }) - writers[i] = writer - } - tearDown := func() { - for _, w := range writers { - _ = w.conn.Close() - } - } - return writers, tearDown, nil -} - -func (d *Dumper) dumpDatabases(tctx *tcontext.Context, metaConn *BaseConn, taskChan chan<- Task) error { - conf := d.conf - allTables := conf.Tables - - // policy should be created before database - // placement policy in other server type can be different, so we only handle the tidb server - if conf.ServerInfo.ServerType == version.ServerTypeTiDB { - policyNames, err := ListAllPlacementPolicyNames(tctx, metaConn) - if err != nil { - errCause := errors.Cause(err) - if mysqlErr, ok := errCause.(*mysql.MySQLError); ok && mysqlErr.Number == ErrNoSuchTable { - // some old tidb version and other server type doesn't support placement rules, we can skip it. - tctx.L().Debug("cannot dump placement policy, maybe the server doesn't support it", log.ShortError(err)) - } else { - tctx.L().Warn("fail to dump placement policy: ", log.ShortError(err)) - } - } - for _, policy := range policyNames { - createPolicySQL, err := ShowCreatePlacementPolicy(tctx, metaConn, policy) - if err != nil { - return errors.Trace(err) - } - wrappedCreatePolicySQL := fmt.Sprintf("/*T![placement] %s */", createPolicySQL) - task := NewTaskPolicyMeta(policy, wrappedCreatePolicySQL) - ctxDone := d.sendTaskToChan(tctx, task, taskChan) - if ctxDone { - return tctx.Err() - } - } - } - - parser1 := parser.New() - for dbName, tables := range allTables { - if !conf.NoSchemas { - createDatabaseSQL, err := ShowCreateDatabase(tctx, metaConn, dbName) - if err != nil { - return errors.Trace(err) - } - - // adjust db collation - createDatabaseSQL, err = adjustDatabaseCollation(tctx, d.conf.CollationCompatible, parser1, createDatabaseSQL, d.charsetAndDefaultCollationMap) - if err != nil { - return errors.Trace(err) - } - - task := NewTaskDatabaseMeta(dbName, createDatabaseSQL) - ctxDone := d.sendTaskToChan(tctx, task, taskChan) - if ctxDone { - return tctx.Err() - } - } - - for _, table := range tables { - tctx.L().Debug("start dumping table...", zap.String("database", dbName), - zap.String("table", table.Name)) - meta, err := dumpTableMeta(tctx, conf, metaConn, dbName, table) - if err != nil { - return errors.Trace(err) - } - - if !conf.NoSchemas { - switch table.Type { - case TableTypeView: - task := NewTaskViewMeta(dbName, table.Name, meta.ShowCreateTable(), meta.ShowCreateView()) - ctxDone := d.sendTaskToChan(tctx, task, taskChan) - if ctxDone { - return tctx.Err() - } - case TableTypeSequence: - task := NewTaskSequenceMeta(dbName, table.Name, meta.ShowCreateTable()) - ctxDone := d.sendTaskToChan(tctx, task, taskChan) - if ctxDone { - return tctx.Err() - } - default: - // adjust table collation - newCreateSQL, err := adjustTableCollation(tctx, d.conf.CollationCompatible, parser1, 
meta.ShowCreateTable(), d.charsetAndDefaultCollationMap) - if err != nil { - return errors.Trace(err) - } - meta.(*tableMeta).showCreateTable = newCreateSQL - - task := NewTaskTableMeta(dbName, table.Name, meta.ShowCreateTable()) - ctxDone := d.sendTaskToChan(tctx, task, taskChan) - if ctxDone { - return tctx.Err() - } - } - } - if table.Type == TableTypeBase { - err = d.dumpTableData(tctx, metaConn, meta, taskChan) - if err != nil { - return errors.Trace(err) - } - } - } - } - return nil -} - -// adjustDatabaseCollation adjusts db collation and return new create sql and collation -func adjustDatabaseCollation(tctx *tcontext.Context, collationCompatible string, parser *parser.Parser, originSQL string, charsetAndDefaultCollationMap map[string]string) (string, error) { - if collationCompatible != StrictCollationCompatible { - return originSQL, nil - } - stmt, err := parser.ParseOneStmt(originSQL, "", "") - if err != nil { - tctx.L().Warn("parse create database error, maybe tidb parser doesn't support it", zap.String("originSQL", originSQL), log.ShortError(err)) - return originSQL, nil - } - createStmt, ok := stmt.(*ast.CreateDatabaseStmt) - if !ok { - return originSQL, nil - } - var charset string - for _, createOption := range createStmt.Options { - // already have 'Collation' - if createOption.Tp == ast.DatabaseOptionCollate { - return originSQL, nil - } - if createOption.Tp == ast.DatabaseOptionCharset { - charset = createOption.Value - } - } - // get db collation - collation, ok := charsetAndDefaultCollationMap[strings.ToLower(charset)] - if !ok { - tctx.L().Warn("not found database charset default collation.", zap.String("originSQL", originSQL), zap.String("charset", strings.ToLower(charset))) - return originSQL, nil - } - // add collation - createStmt.Options = append(createStmt.Options, &ast.DatabaseOption{Tp: ast.DatabaseOptionCollate, Value: collation}) - // rewrite sql - var b []byte - bf := bytes.NewBuffer(b) - err = createStmt.Restore(&format.RestoreCtx{ - Flags: format.DefaultRestoreFlags | format.RestoreTiDBSpecialComment, - In: bf, - }) - if err != nil { - return "", errors.Trace(err) - } - return bf.String(), nil -} - -// adjustTableCollation adjusts table collation -func adjustTableCollation(tctx *tcontext.Context, collationCompatible string, parser *parser.Parser, originSQL string, charsetAndDefaultCollationMap map[string]string) (string, error) { - if collationCompatible != StrictCollationCompatible { - return originSQL, nil - } - stmt, err := parser.ParseOneStmt(originSQL, "", "") - if err != nil { - tctx.L().Warn("parse create table error, maybe tidb parser doesn't support it", zap.String("originSQL", originSQL), log.ShortError(err)) - return originSQL, nil - } - createStmt, ok := stmt.(*ast.CreateTableStmt) - if !ok { - return originSQL, nil - } - var charset string - var collation string - for _, createOption := range createStmt.Options { - // already have 'Collation' - if createOption.Tp == ast.TableOptionCollate { - collation = createOption.StrValue - break - } - if createOption.Tp == ast.TableOptionCharset { - charset = createOption.StrValue - } - } - - if collation == "" && charset != "" { - collation, ok := charsetAndDefaultCollationMap[strings.ToLower(charset)] - if !ok { - tctx.L().Warn("not found table charset default collation.", zap.String("originSQL", originSQL), zap.String("charset", strings.ToLower(charset))) - return originSQL, nil - } - - // add collation - createStmt.Options = append(createStmt.Options, &ast.TableOption{Tp: ast.TableOptionCollate, 
StrValue: collation}) - } - - // adjust columns collation - adjustColumnsCollation(tctx, createStmt, charsetAndDefaultCollationMap) - - // rewrite sql - var b []byte - bf := bytes.NewBuffer(b) - err = createStmt.Restore(&format.RestoreCtx{ - Flags: format.DefaultRestoreFlags | format.RestoreTiDBSpecialComment, - In: bf, - }) - if err != nil { - return "", errors.Trace(err) - } - return bf.String(), nil -} - -// adjustColumnsCollation adds column's collation. -func adjustColumnsCollation(tctx *tcontext.Context, createStmt *ast.CreateTableStmt, charsetAndDefaultCollationMap map[string]string) { -ColumnLoop: - for _, col := range createStmt.Cols { - for _, options := range col.Options { - // already have 'Collation' - if options.Tp == ast.ColumnOptionCollate { - continue ColumnLoop - } - } - fieldType := col.Tp - if fieldType.GetCollate() != "" { - continue - } - if fieldType.GetCharset() != "" { - // just have charset - collation, ok := charsetAndDefaultCollationMap[strings.ToLower(fieldType.GetCharset())] - if !ok { - tctx.L().Warn("not found charset default collation for column.", zap.String("table", createStmt.Table.Name.String()), zap.String("column", col.Name.String()), zap.String("charset", strings.ToLower(fieldType.GetCharset()))) - continue - } - fieldType.SetCollate(collation) - } - } -} - -func (d *Dumper) dumpTableData(tctx *tcontext.Context, conn *BaseConn, meta TableMeta, taskChan chan<- Task) error { - conf := d.conf - if conf.NoData { - return nil - } - - // Update total rows - fieldName, _ := pickupPossibleField(tctx, meta, conn) - c := estimateCount(tctx, meta.DatabaseName(), meta.TableName(), conn, fieldName, conf) - AddCounter(d.metrics.estimateTotalRowsCounter, float64(c)) - - if conf.Rows == UnspecifiedSize { - return d.sequentialDumpTable(tctx, conn, meta, taskChan) - } - return d.concurrentDumpTable(tctx, conn, meta, taskChan) -} - -func (d *Dumper) buildConcatTask(tctx *tcontext.Context, conn *BaseConn, meta TableMeta) (*TaskTableData, error) { - tableChan := make(chan Task, 128) - errCh := make(chan error, 1) - go func() { - // adjust rows to suitable rows for this table - d.conf.Rows = GetSuitableRows(meta.AvgRowLength()) - err := d.concurrentDumpTable(tctx, conn, meta, tableChan) - d.conf.Rows = UnspecifiedSize - if err != nil { - errCh <- err - } else { - close(errCh) - } - }() - tableDataArr := make([]*tableData, 0) - handleSubTask := func(task Task) { - tableTask, ok := task.(*TaskTableData) - if !ok { - tctx.L().Warn("unexpected task when splitting table chunks", zap.String("task", tableTask.Brief())) - return - } - tableDataInst, ok := tableTask.Data.(*tableData) - if !ok { - tctx.L().Warn("unexpected task.Data when splitting table chunks", zap.String("task", tableTask.Brief())) - return - } - tableDataArr = append(tableDataArr, tableDataInst) - d.metrics.totalChunks.Dec() - } - for { - select { - case err, ok := <-errCh: - if !ok { - // make sure all the subtasks in tableChan are handled - for len(tableChan) > 0 { - task := <-tableChan - handleSubTask(task) - } - if len(tableDataArr) <= 1 { - return nil, nil - } - queries := make([]string, 0, len(tableDataArr)) - colLen := tableDataArr[0].colLen - for _, tableDataInst := range tableDataArr { - queries = append(queries, tableDataInst.query) - if colLen != tableDataInst.colLen { - tctx.L().Warn("colLen varies for same table", - zap.Int("oldColLen", colLen), - zap.String("oldQuery", queries[0]), - zap.Int("newColLen", tableDataInst.colLen), - zap.String("newQuery", tableDataInst.query)) - return nil, nil - } - } 
- return d.newTaskTableData(meta, newMultiQueriesChunk(queries, colLen), 0, 1), nil - } - return nil, err - case task := <-tableChan: - handleSubTask(task) - } - } -} - -func (d *Dumper) dumpWholeTableDirectly(tctx *tcontext.Context, meta TableMeta, taskChan chan<- Task, partition, orderByClause string, currentChunk, totalChunks int) error { - conf := d.conf - tableIR := SelectAllFromTable(conf, meta, partition, orderByClause) - task := d.newTaskTableData(meta, tableIR, currentChunk, totalChunks) - ctxDone := d.sendTaskToChan(tctx, task, taskChan) - if ctxDone { - return tctx.Err() - } - return nil -} - -func (d *Dumper) sequentialDumpTable(tctx *tcontext.Context, conn *BaseConn, meta TableMeta, taskChan chan<- Task) error { - conf := d.conf - if conf.ServerInfo.ServerType == version.ServerTypeTiDB { - task, err := d.buildConcatTask(tctx, conn, meta) - if err != nil { - return errors.Trace(err) - } - if task != nil { - ctxDone := d.sendTaskToChan(tctx, task, taskChan) - if ctxDone { - return tctx.Err() - } - return nil - } - tctx.L().Info("didn't build tidb concat sqls, will select all from table now", - zap.String("database", meta.DatabaseName()), - zap.String("table", meta.TableName())) - } - orderByClause, err := buildOrderByClause(tctx, conf, conn, meta.DatabaseName(), meta.TableName(), meta.HasImplicitRowID()) - if err != nil { - return err - } - return d.dumpWholeTableDirectly(tctx, meta, taskChan, "", orderByClause, 0, 1) -} - -// concurrentDumpTable tries to split table into several chunks to dump -func (d *Dumper) concurrentDumpTable(tctx *tcontext.Context, conn *BaseConn, meta TableMeta, taskChan chan<- Task) error { - conf := d.conf - db, tbl := meta.DatabaseName(), meta.TableName() - if conf.ServerInfo.ServerType == version.ServerTypeTiDB && - conf.ServerInfo.ServerVersion != nil && - (conf.ServerInfo.ServerVersion.Compare(*tableSampleVersion) >= 0 || - (conf.ServerInfo.HasTiKV && conf.ServerInfo.ServerVersion.Compare(*decodeRegionVersion) >= 0)) { - err := d.concurrentDumpTiDBTables(tctx, conn, meta, taskChan) - // don't retry on context error and successful tasks - if err2 := errors.Cause(err); err2 == nil || err2 == context.DeadlineExceeded || err2 == context.Canceled { - return err - } else if err2 != errEmptyHandleVals { - tctx.L().Info("fallback to concurrent dump tables using rows due to some problem. This won't influence the whole dump process", - zap.String("database", db), zap.String("table", tbl), log.ShortError(err)) - } - } - - orderByClause, err := buildOrderByClause(tctx, conf, conn, db, tbl, meta.HasImplicitRowID()) - if err != nil { - return err - } - - field, err := pickupPossibleField(tctx, meta, conn) - if err != nil || field == "" { - // skip split chunk logic if not found proper field - tctx.L().Info("fallback to sequential dump due to no proper field. This won't influence the whole dump process", - zap.String("database", db), zap.String("table", tbl), log.ShortError(err)) - return d.dumpWholeTableDirectly(tctx, meta, taskChan, "", orderByClause, 0, 1) - } - - count := estimateCount(d.tctx, db, tbl, conn, field, conf) - tctx.L().Info("get estimated rows count", - zap.String("database", db), - zap.String("table", tbl), - zap.Uint64("estimateCount", count)) - if count < conf.Rows { - // skip chunk logic if estimates are low - tctx.L().Info("fallback to sequential dump due to estimate count < rows. 
This won't influence the whole dump process", - zap.Uint64("estimate count", count), - zap.Uint64("conf.rows", conf.Rows), - zap.String("database", db), - zap.String("table", tbl)) - return d.dumpWholeTableDirectly(tctx, meta, taskChan, "", orderByClause, 0, 1) - } - - min, max, err := d.selectMinAndMaxIntValue(tctx, conn, db, tbl, field) - if err != nil { - tctx.L().Info("fallback to sequential dump due to cannot get bounding values. This won't influence the whole dump process", - log.ShortError(err)) - return d.dumpWholeTableDirectly(tctx, meta, taskChan, "", orderByClause, 0, 1) - } - tctx.L().Debug("get int bounding values", - zap.String("lower", min.String()), - zap.String("upper", max.String())) - - // every chunk would have eventual adjustments - estimatedChunks := count / conf.Rows - estimatedStep := new(big.Int).Sub(max, min).Uint64()/estimatedChunks + 1 - bigEstimatedStep := new(big.Int).SetUint64(estimatedStep) - cutoff := new(big.Int).Set(min) - totalChunks := estimatedChunks - if estimatedStep == 1 { - totalChunks = new(big.Int).Sub(max, min).Uint64() + 1 - } - - selectField, selectLen := meta.SelectedField(), meta.SelectedLen() - - chunkIndex := 0 - nullValueCondition := "" - if conf.Where == "" { - nullValueCondition = fmt.Sprintf("`%s` IS NULL OR ", escapeString(field)) - } - for max.Cmp(cutoff) >= 0 { - nextCutOff := new(big.Int).Add(cutoff, bigEstimatedStep) - where := fmt.Sprintf("%s(`%s` >= %d AND `%s` < %d)", nullValueCondition, escapeString(field), cutoff, escapeString(field), nextCutOff) - query := buildSelectQuery(db, tbl, selectField, "", buildWhereCondition(conf, where), orderByClause) - if len(nullValueCondition) > 0 { - nullValueCondition = "" - } - task := d.newTaskTableData(meta, newTableData(query, selectLen, false), chunkIndex, int(totalChunks)) - ctxDone := d.sendTaskToChan(tctx, task, taskChan) - if ctxDone { - return tctx.Err() - } - cutoff = nextCutOff - chunkIndex++ - } - return nil -} - -func (d *Dumper) sendTaskToChan(tctx *tcontext.Context, task Task, taskChan chan<- Task) (ctxDone bool) { - select { - case <-tctx.Done(): - return true - case taskChan <- task: - tctx.L().Debug("send task to writer", - zap.String("task", task.Brief())) - DecGauge(d.metrics.taskChannelCapacity) - return false - } -} - -func (d *Dumper) selectMinAndMaxIntValue(tctx *tcontext.Context, conn *BaseConn, db, tbl, field string) (*big.Int, *big.Int, error) { - conf, zero := d.conf, &big.Int{} - query := fmt.Sprintf("SELECT MIN(`%s`),MAX(`%s`) FROM `%s`.`%s`", - escapeString(field), escapeString(field), escapeString(db), escapeString(tbl)) - if conf.Where != "" { - query = fmt.Sprintf("%s WHERE %s", query, conf.Where) - } - tctx.L().Debug("split chunks", zap.String("query", query)) - - var smin sql.NullString - var smax sql.NullString - err := conn.QuerySQL(tctx, func(rows *sql.Rows) error { - err := rows.Scan(&smin, &smax) - rows.Close() - return err - }, func() {}, query) - if err != nil { - return zero, zero, errors.Annotatef(err, "can't get min/max values to split chunks, query: %s", query) - } - if !smax.Valid || !smin.Valid { - // found no data - return zero, zero, errors.Errorf("no invalid min/max value found in query %s", query) - } - - max := new(big.Int) - min := new(big.Int) - var ok bool - if max, ok = max.SetString(smax.String, 10); !ok { - return zero, zero, errors.Errorf("fail to convert max value %s in query %s", smax.String, query) - } - if min, ok = min.SetString(smin.String, 10); !ok { - return zero, zero, errors.Errorf("fail to convert min value %s in query 
%s", smin.String, query) - } - return min, max, nil -} - -func (d *Dumper) concurrentDumpTiDBTables(tctx *tcontext.Context, conn *BaseConn, meta TableMeta, taskChan chan<- Task) error { - db, tbl := meta.DatabaseName(), meta.TableName() - - var ( - handleColNames []string - handleVals [][]string - err error - ) - // for TiDB v5.0+, we can use table sample directly - if d.conf.ServerInfo.ServerVersion.Compare(*tableSampleVersion) >= 0 { - tctx.L().Debug("dumping TiDB tables with TABLESAMPLE", - zap.String("database", db), zap.String("table", tbl)) - handleColNames, handleVals, err = selectTiDBTableSample(tctx, conn, meta) - } else { - // for TiDB v3.0+, we can use table region decode in TiDB directly - tctx.L().Debug("dumping TiDB tables with TABLE REGIONS", - zap.String("database", db), zap.String("table", tbl)) - var partitions []string - if d.conf.ServerInfo.ServerVersion.Compare(*gcSafePointVersion) >= 0 { - partitions, err = GetPartitionNames(tctx, conn, db, tbl) - } - if err == nil { - if len(partitions) != 0 { - return d.concurrentDumpTiDBPartitionTables(tctx, conn, meta, taskChan, partitions) - } - handleColNames, handleVals, err = d.selectTiDBTableRegionFunc(tctx, conn, meta) - } - } - if err != nil { - return err - } - return d.sendConcurrentDumpTiDBTasks(tctx, meta, taskChan, handleColNames, handleVals, "", 0, len(handleVals)+1) -} - -func (d *Dumper) concurrentDumpTiDBPartitionTables(tctx *tcontext.Context, conn *BaseConn, meta TableMeta, taskChan chan<- Task, partitions []string) error { - db, tbl := meta.DatabaseName(), meta.TableName() - tctx.L().Debug("dumping TiDB tables with TABLE REGIONS for partition table", - zap.String("database", db), zap.String("table", tbl), zap.Strings("partitions", partitions)) - - startChunkIdx := 0 - totalChunk := 0 - cachedHandleVals := make([][][]string, len(partitions)) - - handleColNames, _, err := selectTiDBRowKeyFields(tctx, conn, meta, checkTiDBTableRegionPkFields) - if err != nil { - return err - } - // cache handleVals here to calculate the total chunks - for i, partition := range partitions { - handleVals, err := selectTiDBPartitionRegion(tctx, conn, db, tbl, partition) - if err != nil { - return err - } - totalChunk += len(handleVals) + 1 - cachedHandleVals[i] = handleVals - } - for i, partition := range partitions { - err := d.sendConcurrentDumpTiDBTasks(tctx, meta, taskChan, handleColNames, cachedHandleVals[i], partition, startChunkIdx, totalChunk) - if err != nil { - return err - } - startChunkIdx += len(cachedHandleVals[i]) + 1 - } - return nil -} - -func (d *Dumper) sendConcurrentDumpTiDBTasks(tctx *tcontext.Context, - meta TableMeta, taskChan chan<- Task, - handleColNames []string, handleVals [][]string, partition string, - startChunkIdx, totalChunk int) error { - db, tbl := meta.DatabaseName(), meta.TableName() - if len(handleVals) == 0 { - if partition == "" { - // return error to make outside function try using rows method to dump data - return errors.Annotatef(errEmptyHandleVals, "table: `%s`.`%s`", escapeString(db), escapeString(tbl)) - } - return d.dumpWholeTableDirectly(tctx, meta, taskChan, partition, buildOrderByClauseString(handleColNames), startChunkIdx, totalChunk) - } - conf := d.conf - selectField, selectLen := meta.SelectedField(), meta.SelectedLen() - where := buildWhereClauses(handleColNames, handleVals) - orderByClause := buildOrderByClauseString(handleColNames) - - for i, w := range where { - query := buildSelectQuery(db, tbl, selectField, partition, buildWhereCondition(conf, w), orderByClause) - task := 
d.newTaskTableData(meta, newTableData(query, selectLen, false), i+startChunkIdx, totalChunk) - ctxDone := d.sendTaskToChan(tctx, task, taskChan) - if ctxDone { - return tctx.Err() - } - } - return nil -} - -// L returns real logger -func (d *Dumper) L() log.Logger { - return d.tctx.L() -} - -func selectTiDBTableSample(tctx *tcontext.Context, conn *BaseConn, meta TableMeta) (pkFields []string, pkVals [][]string, err error) { - pkFields, pkColTypes, err := selectTiDBRowKeyFields(tctx, conn, meta, nil) - if err != nil { - return nil, nil, errors.Trace(err) - } - - query := buildTiDBTableSampleQuery(pkFields, meta.DatabaseName(), meta.TableName()) - pkValNum := len(pkFields) - var iter SQLRowIter - rowRec := MakeRowReceiver(pkColTypes) - buf := new(bytes.Buffer) - - err = conn.QuerySQL(tctx, func(rows *sql.Rows) error { - if iter == nil { - iter = &rowIter{ - rows: rows, - args: make([]any, pkValNum), - } - } - err = iter.Decode(rowRec) - if err != nil { - return errors.Trace(err) - } - pkValRow := make([]string, 0, pkValNum) - for _, rec := range rowRec.receivers { - rec.WriteToBuffer(buf, true) - pkValRow = append(pkValRow, buf.String()) - buf.Reset() - } - pkVals = append(pkVals, pkValRow) - return nil - }, func() { - if iter != nil { - _ = iter.Close() - iter = nil - } - rowRec = MakeRowReceiver(pkColTypes) - pkVals = pkVals[:0] - buf.Reset() - }, query) - if err == nil && iter != nil && iter.Error() != nil { - err = iter.Error() - } - - return pkFields, pkVals, err -} - -func buildTiDBTableSampleQuery(pkFields []string, dbName, tblName string) string { - template := "SELECT %s FROM `%s`.`%s` TABLESAMPLE REGIONS() ORDER BY %s" - quotaPk := make([]string, len(pkFields)) - for i, s := range pkFields { - quotaPk[i] = fmt.Sprintf("`%s`", escapeString(s)) - } - pks := strings.Join(quotaPk, ",") - return fmt.Sprintf(template, pks, escapeString(dbName), escapeString(tblName), pks) -} - -func selectTiDBRowKeyFields(tctx *tcontext.Context, conn *BaseConn, meta TableMeta, checkPkFields func([]string, []string) error) (pkFields, pkColTypes []string, err error) { - if meta.HasImplicitRowID() { - pkFields, pkColTypes = []string{"_tidb_rowid"}, []string{"BIGINT"} - } else { - pkFields, pkColTypes, err = GetPrimaryKeyAndColumnTypes(tctx, conn, meta) - if err == nil { - if checkPkFields != nil { - err = checkPkFields(pkFields, pkColTypes) - } - } - } - return -} - -func checkTiDBTableRegionPkFields(pkFields, pkColTypes []string) (err error) { - if len(pkFields) != 1 || len(pkColTypes) != 1 { - err = errors.Errorf("unsupported primary key for selectTableRegion. pkFields: [%s], pkColTypes: [%s]", strings.Join(pkFields, ", "), strings.Join(pkColTypes, ", ")) - return - } - if _, ok := dataTypeInt[pkColTypes[0]]; !ok { - err = errors.Errorf("unsupported primary key type for selectTableRegion. pkFields: [%s], pkColTypes: [%s]", strings.Join(pkFields, ", "), strings.Join(pkColTypes, ", ")) - } - return -} - -func selectTiDBTableRegion(tctx *tcontext.Context, conn *BaseConn, meta TableMeta) (pkFields []string, pkVals [][]string, err error) { - pkFields, _, err = selectTiDBRowKeyFields(tctx, conn, meta, checkTiDBTableRegionPkFields) - if err != nil { - return - } - - var ( - startKey, decodedKey sql.NullString - rowID = -1 - ) - const ( - tableRegionSQL = "SELECT START_KEY,tidb_decode_key(START_KEY) from INFORMATION_SCHEMA.TIKV_REGION_STATUS s WHERE s.DB_NAME = ? AND s.TABLE_NAME = ? 
AND IS_INDEX = 0 ORDER BY START_KEY;" - tidbRowID = "_tidb_rowid=" - ) - dbName, tableName := meta.DatabaseName(), meta.TableName() - logger := tctx.L().With(zap.String("database", dbName), zap.String("table", tableName)) - err = conn.QuerySQL(tctx, func(rows *sql.Rows) error { - rowID++ - err = rows.Scan(&startKey, &decodedKey) - if err != nil { - return errors.Trace(err) - } - // first region's start key has no use. It may come from another table or might be invalid - if rowID == 0 { - return nil - } - if !startKey.Valid { - logger.Debug("meet invalid start key", zap.Int("rowID", rowID)) - return nil - } - if !decodedKey.Valid { - logger.Debug("meet invalid decoded start key", zap.Int("rowID", rowID), zap.String("startKey", startKey.String)) - return nil - } - pkVal, err2 := extractTiDBRowIDFromDecodedKey(tidbRowID, decodedKey.String) - if err2 != nil { - logger.Debug("cannot extract pkVal from decoded start key", - zap.Int("rowID", rowID), zap.String("startKey", startKey.String), zap.String("decodedKey", decodedKey.String), log.ShortError(err2)) - } else { - pkVals = append(pkVals, []string{pkVal}) - } - return nil - }, func() { - pkFields = pkFields[:0] - pkVals = pkVals[:0] - }, tableRegionSQL, dbName, tableName) - - return pkFields, pkVals, errors.Trace(err) -} - -func selectTiDBPartitionRegion(tctx *tcontext.Context, conn *BaseConn, dbName, tableName, partition string) (pkVals [][]string, err error) { - var startKeys [][]string - const ( - partitionRegionSQL = "SHOW TABLE `%s`.`%s` PARTITION(`%s`) REGIONS" - regionRowKey = "r_" - ) - logger := tctx.L().With(zap.String("database", dbName), zap.String("table", tableName), zap.String("partition", partition)) - startKeys, err = conn.QuerySQLWithColumns(tctx, []string{"START_KEY"}, fmt.Sprintf(partitionRegionSQL, escapeString(dbName), escapeString(tableName), escapeString(partition))) - if err != nil { - return - } - for rowID, startKey := range startKeys { - if rowID == 0 || len(startKey) != 1 { - continue - } - pkVal, err2 := extractTiDBRowIDFromDecodedKey(regionRowKey, startKey[0]) - if err2 != nil { - logger.Debug("show table region start key doesn't have rowID", - zap.Int("rowID", rowID), zap.String("startKey", startKey[0]), zap.Error(err2)) - } else { - pkVals = append(pkVals, []string{pkVal}) - } - } - - return pkVals, nil -} - -func extractTiDBRowIDFromDecodedKey(indexField, key string) (string, error) { - if p := strings.Index(key, indexField); p != -1 { - p += len(indexField) - return key[p:], nil - } - return "", errors.Errorf("decoded key %s doesn't have %s field", key, indexField) -} - -func getListTableTypeByConf(conf *Config) listTableType { - // use listTableByShowTableStatus by default because it has better performance - listType := listTableByShowTableStatus - if conf.Consistency == ConsistencyTypeLock { - // for consistency lock, we need to build the tables to dump as soon as possible - listType = listTableByInfoSchema - } else if conf.Consistency == ConsistencyTypeFlush && matchMysqlBugversion(conf.ServerInfo) { - // For some buggy versions of mysql, we need a workaround to get a list of table names. 
- listType = listTableByShowFullTables - } - return listType -} - -func prepareTableListToDump(tctx *tcontext.Context, conf *Config, db *sql.Conn) error { - if conf.SQL != "" { - return nil - } - - ifSeqExists, err := CheckIfSeqExists(db) - if err != nil { - return err - } - var listType listTableType - if ifSeqExists { - listType = listTableByShowFullTables - } else { - listType = getListTableTypeByConf(conf) - } - - if conf.SpecifiedTables { - return updateSpecifiedTablesMeta(tctx, db, conf.Tables, listType) - } - databases, err := prepareDumpingDatabases(tctx, conf, db) - if err != nil { - return err - } - - tableTypes := []TableType{TableTypeBase} - if !conf.NoViews { - tableTypes = append(tableTypes, TableTypeView) - } - if !conf.NoSequences { - tableTypes = append(tableTypes, TableTypeSequence) - } - - conf.Tables, err = ListAllDatabasesTables(tctx, db, databases, listType, tableTypes...) - if err != nil { - return err - } - - filterTables(tctx, conf) - return nil -} - -func dumpTableMeta(tctx *tcontext.Context, conf *Config, conn *BaseConn, db string, table *TableInfo) (TableMeta, error) { - tbl := table.Name - selectField, selectLen, err := buildSelectField(tctx, conn, db, tbl, conf.CompleteInsert) - if err != nil { - return nil, err - } - var ( - colTypes []*sql.ColumnType - hasImplicitRowID bool - ) - if conf.ServerInfo.ServerType == version.ServerTypeTiDB { - hasImplicitRowID, err = SelectTiDBRowID(tctx, conn, db, tbl) - if err != nil { - tctx.L().Info("check implicit rowID failed", zap.String("database", db), zap.String("table", tbl), log.ShortError(err)) - } - } - - // If all columns are generated - if table.Type == TableTypeBase { - if selectField == "" { - colTypes, err = GetColumnTypes(tctx, conn, "*", db, tbl) - } else { - colTypes, err = GetColumnTypes(tctx, conn, selectField, db, tbl) - } - } - if err != nil { - return nil, err - } - - meta := &tableMeta{ - avgRowLength: table.AvgRowLength, - database: db, - table: tbl, - colTypes: colTypes, - selectedField: selectField, - selectedLen: selectLen, - hasImplicitRowID: hasImplicitRowID, - specCmts: getSpecialComments(conf.ServerInfo.ServerType), - } - - if conf.NoSchemas { - return meta, nil - } - switch table.Type { - case TableTypeView: - viewName := table.Name - createTableSQL, createViewSQL, err1 := ShowCreateView(tctx, conn, db, viewName) - if err1 != nil { - return meta, err1 - } - meta.showCreateTable = createTableSQL - meta.showCreateView = createViewSQL - return meta, nil - case TableTypeSequence: - sequenceName := table.Name - createSequenceSQL, err2 := ShowCreateSequence(tctx, conn, db, sequenceName, conf) - if err2 != nil { - return meta, err2 - } - meta.showCreateTable = createSequenceSQL - return meta, nil - } - - createTableSQL, err := ShowCreateTable(tctx, conn, db, tbl) - if err != nil { - return nil, err - } - meta.showCreateTable = createTableSQL - return meta, nil -} - -func (d *Dumper) dumpSQL(tctx *tcontext.Context, metaConn *BaseConn, taskChan chan<- Task) { - conf := d.conf - meta := &tableMeta{} - data := newTableData(conf.SQL, 0, true) - task := d.newTaskTableData(meta, data, 0, 1) - c := detectEstimateRows(tctx, metaConn, fmt.Sprintf("EXPLAIN %s", conf.SQL), []string{"rows", "estRows", "count"}) - AddCounter(d.metrics.estimateTotalRowsCounter, float64(c)) - atomic.StoreInt64(&d.totalTables, int64(1)) - d.sendTaskToChan(tctx, task, taskChan) -} - -func canRebuildConn(consistency string, trxConsistencyOnly bool) bool { - switch consistency { - case ConsistencyTypeLock, ConsistencyTypeFlush: - return 
!trxConsistencyOnly - case ConsistencyTypeSnapshot, ConsistencyTypeNone: - return true - default: - return false - } -} - -// Close closes a Dumper and stop dumping immediately -func (d *Dumper) Close() error { - d.cancelCtx() - d.metrics.unregisterFrom(d.conf.PromRegistry) - if d.dbHandle != nil { - return d.dbHandle.Close() - } - return nil -} - -func runSteps(d *Dumper, steps ...func(*Dumper) error) error { - for _, st := range steps { - err := st(d) - if err != nil { - return err - } - } - return nil -} - -func initLogger(d *Dumper) error { - conf := d.conf - var ( - logger log.Logger - err error - props *pclog.ZapProperties - ) - // conf.Logger != nil means dumpling is used as a library - if conf.Logger != nil { - logger = log.NewAppLogger(conf.Logger) - } else { - logger, props, err = log.InitAppLogger(&log.Config{ - Level: conf.LogLevel, - File: conf.LogFile, - Format: conf.LogFormat, - }) - if err != nil { - return errors.Trace(err) - } - pclog.ReplaceGlobals(logger.Logger, props) - cli.LogLongVersion(logger) - } - d.tctx = d.tctx.WithLogger(logger) - return nil -} - -// createExternalStore is an initialization step of Dumper. -func createExternalStore(d *Dumper) error { - tctx, conf := d.tctx, d.conf - extStore, err := conf.createExternalStorage(tctx) - if err != nil { - return errors.Trace(err) - } - d.extStore = extStore - return nil -} - -// startHTTPService is an initialization step of Dumper. -func startHTTPService(d *Dumper) error { - conf := d.conf - if conf.StatusAddr != "" { - go func() { - err := startDumplingService(d.tctx, conf.StatusAddr) - if err != nil { - d.L().Info("meet error when stopping dumpling http service", log.ShortError(err)) - } - }() - } - return nil -} - -// openSQLDB is an initialization step of Dumper. -func openSQLDB(d *Dumper) error { - if d.conf.IOTotalBytes != nil { - mysql.RegisterDialContext(d.conf.Net, func(ctx context.Context, addr string) (net.Conn, error) { - dial := &net.Dialer{} - conn, err := dial.DialContext(ctx, "tcp", addr) - if err != nil { - return nil, err - } - tcpConn := conn.(*net.TCPConn) - // try https://github.com/go-sql-driver/mysql/blob/bcc459a906419e2890a50fc2c99ea6dd927a88f2/connector.go#L56-L64 - err = tcpConn.SetKeepAlive(true) - if err != nil { - d.tctx.L().Logger.Warn("fail to keep alive", zap.Error(err)) - } - return util.NewTCPConnWithIOCounter(tcpConn, d.conf.IOTotalBytes), nil - }) - } - conf := d.conf - c, err := mysql.NewConnector(conf.GetDriverConfig("")) - if err != nil { - return errors.Trace(err) - } - d.dbHandle = sql.OpenDB(c) - return nil -} - -// detectServerInfo is an initialization step of Dumper. -func detectServerInfo(d *Dumper) error { - db, conf := d.dbHandle, d.conf - versionStr, err := version.FetchVersion(d.tctx.Context, db) - if err != nil { - conf.ServerInfo = ServerInfoUnknown - return err - } - conf.ServerInfo = version.ParseServerInfo(versionStr) - return nil -} - -// resolveAutoConsistency is an initialization step of Dumper. 
-func resolveAutoConsistency(d *Dumper) error { - conf := d.conf - if conf.Consistency != ConsistencyTypeAuto { - return nil - } - switch conf.ServerInfo.ServerType { - case version.ServerTypeTiDB: - conf.Consistency = ConsistencyTypeSnapshot - case version.ServerTypeMySQL, version.ServerTypeMariaDB: - conf.Consistency = ConsistencyTypeFlush - default: - conf.Consistency = ConsistencyTypeNone - } - - if conf.Consistency == ConsistencyTypeFlush { - timeout := time.Second * 5 - ctx, cancel := context.WithTimeout(d.tctx.Context, timeout) - defer cancel() - - // probe if upstream has enough privilege to FLUSH TABLE WITH READ LOCK - conn, err := d.dbHandle.Conn(ctx) - if err != nil { - return errors.New("failed to get connection from db pool after 5 seconds") - } - //nolint: errcheck - defer conn.Close() - - err = FlushTableWithReadLock(d.tctx, conn) - //nolint: errcheck - defer UnlockTables(d.tctx, conn) - if err != nil { - // fallback to ConsistencyTypeLock - d.tctx.L().Warn("error when use FLUSH TABLE WITH READ LOCK, fallback to LOCK TABLES", - zap.Error(err)) - conf.Consistency = ConsistencyTypeLock - } - } - return nil -} - -func validateResolveAutoConsistency(d *Dumper) error { - conf := d.conf - if conf.Consistency != ConsistencyTypeSnapshot && conf.Snapshot != "" { - return errors.Errorf("can't specify --snapshot when --consistency isn't snapshot, resolved consistency: %s", conf.Consistency) - } - return nil -} - -// tidbSetPDClientForGC is an initialization step of Dumper. -func tidbSetPDClientForGC(d *Dumper) error { - tctx, si, pool := d.tctx, d.conf.ServerInfo, d.dbHandle - if si.ServerType != version.ServerTypeTiDB || - si.ServerVersion == nil || - si.ServerVersion.Compare(*gcSafePointVersion) < 0 { - return nil - } - pdAddrs, err := GetPdAddrs(tctx, pool) - if err != nil { - tctx.L().Info("meet some problem while fetching pd addrs. This won't affect dump process", log.ShortError(err)) - return nil - } - if len(pdAddrs) > 0 { - doPdGC, err := checkSameCluster(tctx, pool, pdAddrs) - if err != nil { - tctx.L().Info("meet error while check whether fetched pd addr and TiDB belong to one cluster. This won't affect dump process", log.ShortError(err), zap.Strings("pdAddrs", pdAddrs)) - } else if doPdGC { - pdClient, err := pd.NewClientWithContext(tctx, pdAddrs, pd.SecurityOption{}) - if err != nil { - tctx.L().Info("create pd client to control GC failed. This won't affect dump process", log.ShortError(err), zap.Strings("pdAddrs", pdAddrs)) - } - d.tidbPDClientForGC = pdClient - } - } - return nil -} - -// tidbGetSnapshot is an initialization step of Dumper. -func tidbGetSnapshot(d *Dumper) error { - conf, doPdGC := d.conf, d.tidbPDClientForGC != nil - consistency := conf.Consistency - pool, tctx := d.dbHandle, d.tctx - snapshotConsistency := consistency == "snapshot" - if conf.Snapshot == "" && (doPdGC || snapshotConsistency) { - conn, err := pool.Conn(tctx) - if err != nil { - tctx.L().Warn("fail to open connection to get snapshot from TiDB", log.ShortError(err)) - // for consistency snapshot, we must get a snapshot here, or we will dump inconsistent data, but for other consistency we can ignore this error. - if !snapshotConsistency { - err = nil - } - return err - } - snapshot, err := getSnapshot(conn) - _ = conn.Close() - if err != nil { - tctx.L().Warn("fail to get snapshot from TiDB", log.ShortError(err)) - // for consistency snapshot, we must get a snapshot here, or we will dump inconsistent data, but for other consistency we can ignore this error. 
- if !snapshotConsistency { - err = nil - } - return err - } - conf.Snapshot = snapshot - } - return nil -} - -// tidbStartGCSavepointUpdateService is an initialization step of Dumper. -func tidbStartGCSavepointUpdateService(d *Dumper) error { - tctx, pool, conf := d.tctx, d.dbHandle, d.conf - snapshot, si := conf.Snapshot, conf.ServerInfo - if d.tidbPDClientForGC != nil { - snapshotTS, err := parseSnapshotToTSO(pool, snapshot) - if err != nil { - return err - } - go updateServiceSafePoint(tctx, d.tidbPDClientForGC, defaultDumpGCSafePointTTL, snapshotTS) - } else if si.ServerType == version.ServerTypeTiDB { - tctx.L().Warn("If the amount of data to dump is large, criteria: (data more than 60GB or dumped time more than 10 minutes)\n" + - "you'd better adjust the tikv_gc_life_time to avoid export failure due to TiDB GC during the dump process.\n" + - "Before dumping: run sql `update mysql.tidb set VARIABLE_VALUE = '720h' where VARIABLE_NAME = 'tikv_gc_life_time';` in tidb.\n" + - "After dumping: run sql `update mysql.tidb set VARIABLE_VALUE = '10m' where VARIABLE_NAME = 'tikv_gc_life_time';` in tidb.\n") - } - return nil -} - -func updateServiceSafePoint(tctx *tcontext.Context, pdClient pd.Client, ttl int64, snapshotTS uint64) { - updateInterval := time.Duration(ttl/2) * time.Second - tick := time.NewTicker(updateInterval) - dumplingServiceSafePointID := fmt.Sprintf("%s_%d", dumplingServiceSafePointPrefix, time.Now().UnixNano()) - tctx.L().Info("generate dumpling gc safePoint id", zap.String("id", dumplingServiceSafePointID)) - - for { - tctx.L().Debug("update PD safePoint limit with ttl", - zap.Uint64("safePoint", snapshotTS), - zap.Int64("ttl", ttl)) - for retryCnt := 0; retryCnt <= 10; retryCnt++ { - _, err := pdClient.UpdateServiceGCSafePoint(tctx, dumplingServiceSafePointID, ttl, snapshotTS) - if err == nil { - break - } - tctx.L().Debug("update PD safePoint failed", log.ShortError(err), zap.Int("retryTime", retryCnt)) - select { - case <-tctx.Done(): - return - case <-time.After(time.Second): - } - } - select { - case <-tctx.Done(): - return - case <-tick.C: - } - } -} - -// setDefaultSessionParams is a step to set default params for session params. -func setDefaultSessionParams(si version.ServerInfo, sessionParams map[string]any) { - defaultSessionParams := map[string]any{} - if si.ServerType == version.ServerTypeTiDB && si.HasTiKV && si.ServerVersion.Compare(*enablePagingVersion) >= 0 { - defaultSessionParams["tidb_enable_paging"] = "ON" - } - for k, v := range defaultSessionParams { - if _, ok := sessionParams[k]; !ok { - sessionParams[k] = v - } - } -} - -// setSessionParam is an initialization step of Dumper. -func setSessionParam(d *Dumper) error { - conf, pool := d.conf, d.dbHandle - si := conf.ServerInfo - consistency, snapshot := conf.Consistency, conf.Snapshot - sessionParam := conf.SessionParams - if si.ServerType == version.ServerTypeTiDB && conf.TiDBMemQuotaQuery != UnspecifiedSize { - sessionParam[TiDBMemQuotaQueryName] = conf.TiDBMemQuotaQuery - } - var err error - if snapshot != "" { - if si.ServerType != version.ServerTypeTiDB { - return errors.New("snapshot consistency is not supported for this server") - } - if consistency == ConsistencyTypeSnapshot { - conf.ServerInfo.HasTiKV, err = CheckTiDBWithTiKV(pool) - if err != nil { - d.L().Info("cannot check whether TiDB has TiKV, will apply tidb_snapshot by default. 
This won't affect dump process", log.ShortError(err)) - } - if conf.ServerInfo.HasTiKV { - sessionParam[snapshotVar] = snapshot - } - } - } - if d.dbHandle, err = resetDBWithSessionParams(d.tctx, pool, conf.GetDriverConfig(""), conf.SessionParams); err != nil { - return errors.Trace(err) - } - return nil -} - -func openDB(cfg *mysql.Config) (*sql.DB, error) { - c, err := mysql.NewConnector(cfg) - if err != nil { - return nil, errors.Trace(err) - } - return sql.OpenDB(c), nil -} - -func (d *Dumper) renewSelectTableRegionFuncForLowerTiDB(tctx *tcontext.Context) error { - conf := d.conf - if !(conf.ServerInfo.ServerType == version.ServerTypeTiDB && conf.ServerInfo.ServerVersion != nil && conf.ServerInfo.HasTiKV && - conf.ServerInfo.ServerVersion.Compare(*decodeRegionVersion) >= 0 && - conf.ServerInfo.ServerVersion.Compare(*gcSafePointVersion) < 0) { - tctx.L().Debug("no need to build region info because database is not TiDB 3.x") - return nil - } - // for TiDB v3.0+, the original selectTiDBTableRegionFunc will always fail, - // because TiDB v3.0 doesn't have `tidb_decode_key` function nor `DB_NAME`,`TABLE_NAME` columns in `INFORMATION_SCHEMA.TIKV_REGION_STATUS`. - // reference: https://github.com/pingcap/tidb/blob/c497d5c/dumpling/export/dump.go#L775 - // To avoid this function continuously returning errors and confusing users because we fail to init this function at first, - // selectTiDBTableRegionFunc is set to always return an ignorable error at first. - d.selectTiDBTableRegionFunc = func(_ *tcontext.Context, _ *BaseConn, meta TableMeta) (pkFields []string, pkVals [][]string, err error) { - return nil, nil, errors.Annotatef(errEmptyHandleVals, "table: `%s`.`%s`", escapeString(meta.DatabaseName()), escapeString(meta.TableName())) - } - dbHandle, err := openDBFunc(conf.GetDriverConfig("")) - if err != nil { - return errors.Trace(err) - } - defer func() { - _ = dbHandle.Close() - }() - conn, err := dbHandle.Conn(tctx) - if err != nil { - return errors.Trace(err) - } - defer func() { - _ = conn.Close() - }() - dbInfos, err := GetDBInfo(conn, DatabaseTablesToMap(conf.Tables)) - if err != nil { - return errors.Trace(err) - } - regionsInfo, err := GetRegionInfos(conn) - if err != nil { - return errors.Trace(err) - } - tikvHelper := &helper.Helper{} - tableInfos := tikvHelper.GetRegionsTableInfo(regionsInfo, infoschema.DBInfoAsInfoSchema(dbInfos), nil) - - tableInfoMap := make(map[string]map[string][]int64, len(conf.Tables)) - for _, region := range regionsInfo.Regions { - tableList := tableInfos[region.ID] - for _, table := range tableList { - db, tbl := table.DB.Name.O, table.Table.Name.O - if _, ok := tableInfoMap[db]; !ok { - tableInfoMap[db] = make(map[string][]int64, len(conf.Tables[db])) - } - - key, err := hex.DecodeString(region.StartKey) - if err != nil { - d.L().Debug("invalid region start key", log.ShortError(err), zap.String("key", region.StartKey)) - continue - } - // Auto decode byte if needed. - _, bs, err := codec.DecodeBytes(key, nil) - if err == nil { - key = bs - } - // Try to decode it as a record key. 
- tableID, handle, err := tablecodec.DecodeRecordKey(key) - if err != nil { - d.L().Debug("cannot decode region start key", log.ShortError(err), zap.String("key", region.StartKey), zap.Int64("tableID", tableID)) - continue - } - if handle.IsInt() { - tableInfoMap[db][tbl] = append(tableInfoMap[db][tbl], handle.IntValue()) - } else { - d.L().Debug("not an int handle", log.ShortError(err), zap.Stringer("handle", handle)) - } - } - } - for _, tbInfos := range tableInfoMap { - for _, tbInfoLoop := range tbInfos { - // make sure tbInfo is only used in this loop - tbInfo := tbInfoLoop - slices.Sort(tbInfo) - } - } - - d.selectTiDBTableRegionFunc = func(tctx *tcontext.Context, conn *BaseConn, meta TableMeta) (pkFields []string, pkVals [][]string, err error) { - pkFields, _, err = selectTiDBRowKeyFields(tctx, conn, meta, checkTiDBTableRegionPkFields) - if err != nil { - return - } - dbName, tableName := meta.DatabaseName(), meta.TableName() - if tbInfos, ok := tableInfoMap[dbName]; ok { - if tbInfo, ok := tbInfos[tableName]; ok { - pkVals = make([][]string, len(tbInfo)) - for i, val := range tbInfo { - pkVals[i] = []string{strconv.FormatInt(val, 10)} - } - } - } - return - } - - return nil -} - -func (d *Dumper) newTaskTableData(meta TableMeta, data TableDataIR, currentChunk, totalChunks int) *TaskTableData { - d.metrics.totalChunks.Add(1) - return NewTaskTableData(meta, data, currentChunk, totalChunks) -} diff --git a/dumpling/export/sql.go b/dumpling/export/sql.go index e0c8fb1682d7f..690ef65fe054f 100644 --- a/dumpling/export/sql.go +++ b/dumpling/export/sql.go @@ -609,9 +609,9 @@ func GetColumnTypes(tctx *tcontext.Context, db *BaseConn, fields, database, tabl if err == nil { err = rows.Close() } - if _, _err_ := failpoint.Eval(_curpkg_("ChaosBrokenMetaConn")); _err_ == nil { - return errors.New("connection is closed") - } + failpoint.Inject("ChaosBrokenMetaConn", func(_ failpoint.Value) { + failpoint.Return(errors.New("connection is closed")) + }) return errors.Annotatef(err, "sql: %s", query) }, func() { colTypes = nil @@ -992,9 +992,9 @@ func resetDBWithSessionParams(tctx *tcontext.Context, db *sql.DB, cfg *mysql.Con } cfg.Params[k] = s } - if _, _err_ := failpoint.Eval(_curpkg_("SkipResetDB")); _err_ == nil { - return db, nil - } + failpoint.Inject("SkipResetDB", func(_ failpoint.Value) { + failpoint.Return(db, nil) + }) db.Close() c, err := mysql.NewConnector(cfg) diff --git a/dumpling/export/sql.go__failpoint_stash__ b/dumpling/export/sql.go__failpoint_stash__ deleted file mode 100644 index 690ef65fe054f..0000000000000 --- a/dumpling/export/sql.go__failpoint_stash__ +++ /dev/null @@ -1,1643 +0,0 @@ -// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. 
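Note on the failpoint hunks above: they restore the hand-written failpoint.Inject call sites in place of the machine-generated failpoint.Eval(_curpkg_(...)) form, and the deletion of sql.go__failpoint_stash__ removes what appears to be the backup left behind by the failpoint rewrite tool. As a purely illustrative sketch (not part of this patch), a failpoint such as ChaosBrokenMetaConn would typically be driven from a test along these lines, assuming the standard github.com/pingcap/failpoint API, a failpoint-enabled build, and a failpoint path derived from the package import path:

package export_test

import (
	"testing"

	"github.com/pingcap/failpoint"
	"github.com/stretchr/testify/require"
)

// TestBrokenMetaConn is a hypothetical test name; the failpoint path below is
// an assumption based on the usual <import path>/<failpoint name> convention.
func TestBrokenMetaConn(t *testing.T) {
	const fp = "github.com/pingcap/tidb/dumpling/export/ChaosBrokenMetaConn"
	// Once enabled (and with the failpoint rewrite applied to the build), the
	// failpoint.Inject("ChaosBrokenMetaConn", ...) marker in GetColumnTypes
	// runs its callback and surfaces the injected "connection is closed" error.
	require.NoError(t, failpoint.Enable(fp, `return(true)`))
	defer func() {
		require.NoError(t, failpoint.Disable(fp))
	}()
	// ... call GetColumnTypes here and assert that the returned error is the injected one ...
}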
- -package export - -import ( - "bytes" - "context" - "database/sql" - "encoding/json" - "fmt" - "io" - "math" - "strconv" - "strings" - - "github.com/go-sql-driver/mysql" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/version" - tcontext "github.com/pingcap/tidb/dumpling/context" - "github.com/pingcap/tidb/dumpling/log" - dbconfig "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/errno" - "github.com/pingcap/tidb/pkg/parser/model" - pd "github.com/tikv/pd/client/http" - "go.uber.org/multierr" - "go.uber.org/zap" -) - -const ( - orderByTiDBRowID = "ORDER BY `_tidb_rowid`" - snapshotVar = "tidb_snapshot" -) - -type listTableType int - -const ( - listTableByInfoSchema listTableType = iota - listTableByShowFullTables - listTableByShowTableStatus -) - -// ShowDatabases shows the databases of a database server. -func ShowDatabases(db *sql.Conn) ([]string, error) { - var res oneStrColumnTable - if err := simpleQuery(db, "SHOW DATABASES", res.handleOneRow); err != nil { - return nil, err - } - return res.data, nil -} - -// ShowTables shows the tables of a database, the caller should use the correct database. -func ShowTables(db *sql.Conn) ([]string, error) { - var res oneStrColumnTable - if err := simpleQuery(db, "SHOW TABLES", res.handleOneRow); err != nil { - return nil, err - } - return res.data, nil -} - -// ShowCreateDatabase constructs the create database SQL for a specified database -// returns (createDatabaseSQL, error) -func ShowCreateDatabase(tctx *tcontext.Context, db *BaseConn, database string) (string, error) { - var oneRow [2]string - handleOneRow := func(rows *sql.Rows) error { - return rows.Scan(&oneRow[0], &oneRow[1]) - } - query := fmt.Sprintf("SHOW CREATE DATABASE `%s`", escapeString(database)) - err := db.QuerySQL(tctx, handleOneRow, func() { - oneRow[0], oneRow[1] = "", "" - }, query) - if multiErrs := multierr.Errors(err); len(multiErrs) > 0 { - for _, multiErr := range multiErrs { - if mysqlErr, ok := errors.Cause(multiErr).(*mysql.MySQLError); ok { - // Falling back to simple create statement for MemSQL/SingleStore, because of this: - // ERROR 1706 (HY000): Feature 'SHOW CREATE DATABASE' is not supported by MemSQL. 
- if strings.Contains(mysqlErr.Error(), "SHOW CREATE DATABASE") { - return fmt.Sprintf("CREATE DATABASE `%s`", escapeString(database)), nil - } - } - } - } - return oneRow[1], err -} - -// ShowCreateTable constructs the create table SQL for a specified table -// returns (createTableSQL, error) -func ShowCreateTable(tctx *tcontext.Context, db *BaseConn, database, table string) (string, error) { - var oneRow [2]string - handleOneRow := func(rows *sql.Rows) error { - return rows.Scan(&oneRow[0], &oneRow[1]) - } - query := fmt.Sprintf("SHOW CREATE TABLE `%s`.`%s`", escapeString(database), escapeString(table)) - err := db.QuerySQL(tctx, handleOneRow, func() { - oneRow[0], oneRow[1] = "", "" - }, query) - if err != nil { - return "", err - } - return oneRow[1], nil -} - -// ShowCreatePlacementPolicy constructs the create policy SQL for a specified table -// returns (createPolicySQL, error) -func ShowCreatePlacementPolicy(tctx *tcontext.Context, db *BaseConn, policy string) (string, error) { - var oneRow [2]string - handleOneRow := func(rows *sql.Rows) error { - return rows.Scan(&oneRow[0], &oneRow[1]) - } - query := fmt.Sprintf("SHOW CREATE PLACEMENT POLICY `%s`", escapeString(policy)) - err := db.QuerySQL(tctx, handleOneRow, func() { - oneRow[0], oneRow[1] = "", "" - }, query) - return oneRow[1], err -} - -// ShowCreateView constructs the create view SQL for a specified view -// returns (createFakeTableSQL, createViewSQL, error) -func ShowCreateView(tctx *tcontext.Context, db *BaseConn, database, view string) (createFakeTableSQL string, createRealViewSQL string, err error) { - var fieldNames []string - handleFieldRow := func(rows *sql.Rows) error { - var oneRow [6]sql.NullString - scanErr := rows.Scan(&oneRow[0], &oneRow[1], &oneRow[2], &oneRow[3], &oneRow[4], &oneRow[5]) - if scanErr != nil { - return errors.Trace(scanErr) - } - if oneRow[0].Valid { - fieldNames = append(fieldNames, fmt.Sprintf("`%s` int", escapeString(oneRow[0].String))) - } - return nil - } - var oneRow [4]string - handleOneRow := func(rows *sql.Rows) error { - return rows.Scan(&oneRow[0], &oneRow[1], &oneRow[2], &oneRow[3]) - } - var createTableSQL, createViewSQL strings.Builder - - // Build createTableSQL - query := fmt.Sprintf("SHOW FIELDS FROM `%s`.`%s`", escapeString(database), escapeString(view)) - err = db.QuerySQL(tctx, handleFieldRow, func() { - fieldNames = []string{} - }, query) - if err != nil { - return "", "", err - } - fmt.Fprintf(&createTableSQL, "CREATE TABLE `%s`(\n", escapeString(view)) - createTableSQL.WriteString(strings.Join(fieldNames, ",\n")) - createTableSQL.WriteString("\n)ENGINE=MyISAM;\n") - - // Build createViewSQL - fmt.Fprintf(&createViewSQL, "DROP TABLE IF EXISTS `%s`;\n", escapeString(view)) - fmt.Fprintf(&createViewSQL, "DROP VIEW IF EXISTS `%s`;\n", escapeString(view)) - query = fmt.Sprintf("SHOW CREATE VIEW `%s`.`%s`", escapeString(database), escapeString(view)) - err = db.QuerySQL(tctx, handleOneRow, func() { - for i := range oneRow { - oneRow[i] = "" - } - }, query) - if err != nil { - return "", "", err - } - // The result for `show create view` SQL - // mysql> show create view v1; - // +------+-------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------------------+ - // | View | Create View | character_set_client | collation_connection | - // 
+------+-------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------------------+ - // | v1 | CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`localhost` SQL SECURITY DEFINER VIEW `v1` (`a`) AS SELECT `t`.`a` AS `a` FROM `test`.`t` | utf8 | utf8_general_ci | - // +------+-------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------------------+ - SetCharset(&createViewSQL, oneRow[2], oneRow[3]) - createViewSQL.WriteString(oneRow[1]) - createViewSQL.WriteString(";\n") - RestoreCharset(&createViewSQL) - - return createTableSQL.String(), createViewSQL.String(), nil -} - -// ShowCreateSequence constructs the create sequence SQL for a specified sequence -// returns (createSequenceSQL, error) -func ShowCreateSequence(tctx *tcontext.Context, db *BaseConn, database, sequence string, conf *Config) (string, error) { - var oneRow [2]string - handleOneRow := func(rows *sql.Rows) error { - return rows.Scan(&oneRow[0], &oneRow[1]) - } - var ( - createSequenceSQL strings.Builder - nextNotCachedValue int64 - ) - query := fmt.Sprintf("SHOW CREATE SEQUENCE `%s`.`%s`", escapeString(database), escapeString(sequence)) - err := db.QuerySQL(tctx, handleOneRow, func() { - oneRow[0], oneRow[1] = "", "" - }, query) - if err != nil { - return "", err - } - createSequenceSQL.WriteString(oneRow[1]) - createSequenceSQL.WriteString(";\n") - - switch conf.ServerInfo.ServerType { - case version.ServerTypeTiDB: - // Get next not allocated auto increment id of the whole cluster - query := fmt.Sprintf("SHOW TABLE `%s`.`%s` NEXT_ROW_ID", escapeString(database), escapeString(sequence)) - results, err := db.QuerySQLWithColumns(tctx, []string{"NEXT_GLOBAL_ROW_ID", "ID_TYPE"}, query) - if err != nil { - return "", err - } - for _, oneRow := range results { - nextGlobalRowID, idType := oneRow[0], oneRow[1] - if idType == "SEQUENCE" { - nextNotCachedValue, _ = strconv.ParseInt(nextGlobalRowID, 10, 64) - } - } - fmt.Fprintf(&createSequenceSQL, "SELECT SETVAL(`%s`,%d);\n", escapeString(sequence), nextNotCachedValue) - case version.ServerTypeMariaDB: - var oneRow1 string - handleOneRow1 := func(rows *sql.Rows) error { - return rows.Scan(&oneRow1) - } - query := fmt.Sprintf("SELECT NEXT_NOT_CACHED_VALUE FROM `%s`.`%s`", escapeString(database), escapeString(sequence)) - err := db.QuerySQL(tctx, handleOneRow1, func() { - oneRow1 = "" - }, query) - if err != nil { - return "", err - } - nextNotCachedValue, _ = strconv.ParseInt(oneRow1, 10, 64) - fmt.Fprintf(&createSequenceSQL, "SELECT SETVAL(`%s`,%d);\n", escapeString(sequence), nextNotCachedValue) - } - return createSequenceSQL.String(), nil -} - -// SetCharset builds the set charset SQLs -func SetCharset(w *strings.Builder, characterSet, collationConnection string) { - w.WriteString("SET @PREV_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT;\n") - w.WriteString("SET @PREV_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS;\n") - w.WriteString("SET @PREV_COLLATION_CONNECTION=@@COLLATION_CONNECTION;\n") - - fmt.Fprintf(w, "SET character_set_client = %s;\n", characterSet) - fmt.Fprintf(w, "SET character_set_results = %s;\n", characterSet) - fmt.Fprintf(w, "SET collation_connection = %s;\n", collationConnection) -} - -// RestoreCharset builds the restore charset SQLs -func RestoreCharset(w io.StringWriter) { - _, _ = w.WriteString("SET character_set_client = @PREV_CHARACTER_SET_CLIENT;\n") - _, _ 
= w.WriteString("SET character_set_results = @PREV_CHARACTER_SET_RESULTS;\n") - _, _ = w.WriteString("SET collation_connection = @PREV_COLLATION_CONNECTION;\n") -} - -// updateSpecifiedTablesMeta updates DatabaseTables with correct table type and avg row size. -func updateSpecifiedTablesMeta(tctx *tcontext.Context, db *sql.Conn, dbTables DatabaseTables, listType listTableType) error { - var ( - schema, table, tableTypeStr string - tableType TableType - avgRowLength uint64 - err error - ) - switch listType { - case listTableByInfoSchema: - dbNames := make([]string, 0, len(dbTables)) - for db := range dbTables { - dbNames = append(dbNames, fmt.Sprintf("'%s'", db)) - } - query := fmt.Sprintf("SELECT TABLE_SCHEMA,TABLE_NAME,TABLE_TYPE,AVG_ROW_LENGTH FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA IN (%s)", strings.Join(dbNames, ",")) - if err := simpleQueryWithArgs(tctx, db, func(rows *sql.Rows) error { - var ( - sqlAvgRowLength sql.NullInt64 - err2 error - ) - if err2 = rows.Scan(&schema, &table, &tableTypeStr, &sqlAvgRowLength); err != nil { - return errors.Trace(err2) - } - - tbls, ok := dbTables[schema] - if !ok { - return nil - } - for _, tbl := range tbls { - if tbl.Name == table { - tableType, err2 = ParseTableType(tableTypeStr) - if err2 != nil { - return errors.Trace(err2) - } - if sqlAvgRowLength.Valid { - avgRowLength = uint64(sqlAvgRowLength.Int64) - } else { - avgRowLength = 0 - } - tbl.Type = tableType - tbl.AvgRowLength = avgRowLength - } - } - return nil - }, query); err != nil { - return errors.Annotatef(err, "sql: %s", query) - } - return nil - case listTableByShowFullTables: - for schema, tbls := range dbTables { - query := fmt.Sprintf("SHOW FULL TABLES FROM `%s`", - escapeString(schema)) - if err := simpleQueryWithArgs(tctx, db, func(rows *sql.Rows) error { - var err2 error - if err2 = rows.Scan(&table, &tableTypeStr); err != nil { - return errors.Trace(err2) - } - for _, tbl := range tbls { - if tbl.Name == table { - tableType, err2 = ParseTableType(tableTypeStr) - if err2 != nil { - return errors.Trace(err2) - } - tbl.Type = tableType - } - } - return nil - }, query); err != nil { - return errors.Annotatef(err, "sql: %s", query) - } - } - return nil - default: - const queryTemplate = "SHOW TABLE STATUS FROM `%s`" - for schema, tbls := range dbTables { - query := fmt.Sprintf(queryTemplate, escapeString(schema)) - rows, err := db.QueryContext(tctx, query) - if err != nil { - return errors.Annotatef(err, "sql: %s", query) - } - results, err := GetSpecifiedColumnValuesAndClose(rows, "NAME", "ENGINE", "AVG_ROW_LENGTH", "COMMENT") - if err != nil { - return errors.Annotatef(err, "sql: %s", query) - } - for _, oneRow := range results { - table, engine, avgRowLengthStr, comment := oneRow[0], oneRow[1], oneRow[2], oneRow[3] - for _, tbl := range tbls { - if tbl.Name == table { - if avgRowLengthStr != "" { - avgRowLength, err = strconv.ParseUint(avgRowLengthStr, 10, 64) - if err != nil { - return errors.Annotatef(err, "sql: %s", query) - } - } else { - avgRowLength = 0 - } - tbl.AvgRowLength = avgRowLength - tableType = TableTypeBase - if engine == "" && (comment == "" || comment == TableTypeViewStr) { - tableType = TableTypeView - } else if engine == "" { - tctx.L().Warn("invalid table without engine found", zap.String("database", schema), zap.String("table", table)) - continue - } - tbl.Type = tableType - } - } - } - } - return nil - } -} - -// ListAllDatabasesTables lists all the databases and tables from the database -// listTableByInfoSchema list tables by table 
information_schema in MySQL -// listTableByShowTableStatus has better performance than listTableByInfoSchema -// listTableByShowFullTables is used in mysql8 version [8.0.3,8.0.23), more details can be found in the comments of func matchMysqlBugversion -func ListAllDatabasesTables(tctx *tcontext.Context, db *sql.Conn, databaseNames []string, - listType listTableType, tableTypes ...TableType) (DatabaseTables, error) { // revive:disable-line:flag-parameter - dbTables := DatabaseTables{} - var ( - schema, table, tableTypeStr string - tableType TableType - avgRowLength uint64 - err error - ) - - tableTypeConditions := make([]string, len(tableTypes)) - for i, tableType := range tableTypes { - tableTypeConditions[i] = fmt.Sprintf("TABLE_TYPE='%s'", tableType) - } - switch listType { - case listTableByInfoSchema: - query := fmt.Sprintf("SELECT TABLE_SCHEMA,TABLE_NAME,TABLE_TYPE,AVG_ROW_LENGTH FROM INFORMATION_SCHEMA.TABLES WHERE %s", strings.Join(tableTypeConditions, " OR ")) - for _, schema := range databaseNames { - dbTables[schema] = make([]*TableInfo, 0) - } - if err = simpleQueryWithArgs(tctx, db, func(rows *sql.Rows) error { - var ( - sqlAvgRowLength sql.NullInt64 - err2 error - ) - if err2 = rows.Scan(&schema, &table, &tableTypeStr, &sqlAvgRowLength); err != nil { - return errors.Trace(err2) - } - tableType, err2 = ParseTableType(tableTypeStr) - if err2 != nil { - return errors.Trace(err2) - } - - if sqlAvgRowLength.Valid { - avgRowLength = uint64(sqlAvgRowLength.Int64) - } else { - avgRowLength = 0 - } - // only append tables to schemas in databaseNames - if _, ok := dbTables[schema]; ok { - dbTables[schema] = append(dbTables[schema], &TableInfo{table, avgRowLength, tableType}) - } - return nil - }, query); err != nil { - return nil, errors.Annotatef(err, "sql: %s", query) - } - case listTableByShowFullTables: - for _, schema = range databaseNames { - dbTables[schema] = make([]*TableInfo, 0) - query := fmt.Sprintf("SHOW FULL TABLES FROM `%s` WHERE %s", - escapeString(schema), strings.Join(tableTypeConditions, " OR ")) - if err = simpleQueryWithArgs(tctx, db, func(rows *sql.Rows) error { - var err2 error - if err2 = rows.Scan(&table, &tableTypeStr); err != nil { - return errors.Trace(err2) - } - tableType, err2 = ParseTableType(tableTypeStr) - if err2 != nil { - return errors.Trace(err2) - } - avgRowLength = 0 // can't get avgRowLength from the result of `show full tables` so hardcode to 0 here - dbTables[schema] = append(dbTables[schema], &TableInfo{table, avgRowLength, tableType}) - return nil - }, query); err != nil { - return nil, errors.Annotatef(err, "sql: %s", query) - } - } - default: - const queryTemplate = "SHOW TABLE STATUS FROM `%s`" - selectedTableType := make(map[TableType]struct{}) - for _, tableType = range tableTypes { - selectedTableType[tableType] = struct{}{} - } - for _, schema = range databaseNames { - dbTables[schema] = make([]*TableInfo, 0) - query := fmt.Sprintf(queryTemplate, escapeString(schema)) - rows, err := db.QueryContext(tctx, query) - if err != nil { - return nil, errors.Annotatef(err, "sql: %s", query) - } - results, err := GetSpecifiedColumnValuesAndClose(rows, "NAME", "ENGINE", "AVG_ROW_LENGTH", "COMMENT") - if err != nil { - return nil, errors.Annotatef(err, "sql: %s", query) - } - for _, oneRow := range results { - table, engine, avgRowLengthStr, comment := oneRow[0], oneRow[1], oneRow[2], oneRow[3] - if avgRowLengthStr != "" { - avgRowLength, err = strconv.ParseUint(avgRowLengthStr, 10, 64) - if err != nil { - return nil, errors.Annotatef(err, "sql: 
%s", query) - } - } else { - avgRowLength = 0 - } - tableType = TableTypeBase - if engine == "" && (comment == "" || comment == TableTypeViewStr) { - tableType = TableTypeView - } else if engine == "" { - tctx.L().Warn("invalid table without engine found", zap.String("database", schema), zap.String("table", table)) - continue - } - if _, ok := selectedTableType[tableType]; !ok { - continue - } - dbTables[schema] = append(dbTables[schema], &TableInfo{table, avgRowLength, tableType}) - } - } - } - return dbTables, nil -} - -// ListAllPlacementPolicyNames returns all placement policy names. -func ListAllPlacementPolicyNames(tctx *tcontext.Context, db *BaseConn) ([]string, error) { - var policyList []string - var policy string - const query = "select distinct policy_name from information_schema.placement_policies where policy_name is not null;" - err := db.QuerySQL(tctx, func(rows *sql.Rows) error { - err := rows.Scan(&policy) - if err != nil { - return errors.Trace(err) - } - policyList = append(policyList, policy) - return nil - }, func() { - policyList = policyList[:0] - }, query) - return policyList, errors.Annotatef(err, "sql: %s", query) -} - -// SelectVersion gets the version information from the database server -func SelectVersion(db *sql.DB) (string, error) { - var versionInfo string - const query = "SELECT version()" - row := db.QueryRow(query) - err := row.Scan(&versionInfo) - if err != nil { - return "", errors.Annotatef(err, "sql: %s", query) - } - return versionInfo, nil -} - -// SelectAllFromTable dumps data serialized from a specified table -func SelectAllFromTable(conf *Config, meta TableMeta, partition, orderByClause string) TableDataIR { - database, table := meta.DatabaseName(), meta.TableName() - selectedField, selectLen := meta.SelectedField(), meta.SelectedLen() - query := buildSelectQuery(database, table, selectedField, partition, buildWhereCondition(conf, ""), orderByClause) - - return &tableData{ - query: query, - colLen: selectLen, - } -} - -func buildSelectQuery(database, table, fields, partition, where, orderByClause string) string { - var query strings.Builder - query.WriteString("SELECT ") - if fields == "" { - // If all of the columns are generated, - // we need to make sure the query is valid. 
- fields = "''" - } - query.WriteString(fields) - query.WriteString(" FROM `") - query.WriteString(escapeString(database)) - query.WriteString("`.`") - query.WriteString(escapeString(table)) - query.WriteByte('`') - if partition != "" { - query.WriteString(" PARTITION(`") - query.WriteString(escapeString(partition)) - query.WriteString("`)") - } - - if where != "" { - query.WriteString(" ") - query.WriteString(where) - } - - if orderByClause != "" { - query.WriteString(" ") - query.WriteString(orderByClause) - } - - return query.String() -} - -func buildOrderByClause(tctx *tcontext.Context, conf *Config, db *BaseConn, database, table string, hasImplicitRowID bool) (string, error) { // revive:disable-line:flag-parameter - if !conf.SortByPk { - return "", nil - } - if hasImplicitRowID { - return orderByTiDBRowID, nil - } - cols, err := GetPrimaryKeyColumns(tctx, db, database, table) - if err != nil { - return "", errors.Trace(err) - } - return buildOrderByClauseString(cols), nil -} - -// SelectTiDBRowID checks whether this table has _tidb_rowid column -func SelectTiDBRowID(tctx *tcontext.Context, db *BaseConn, database, table string) (bool, error) { - tiDBRowIDQuery := fmt.Sprintf("SELECT _tidb_rowid from `%s`.`%s` LIMIT 1", escapeString(database), escapeString(table)) - hasImplictRowID := false - err := db.ExecSQL(tctx, func(_ sql.Result, err error) error { - if err != nil { - hasImplictRowID = false - errMsg := strings.ToLower(err.Error()) - if strings.Contains(errMsg, fmt.Sprintf("%d", errno.ErrBadField)) { - return nil - } - return errors.Annotatef(err, "sql: %s", tiDBRowIDQuery) - } - hasImplictRowID = true - return nil - }, tiDBRowIDQuery) - return hasImplictRowID, err -} - -// GetSuitableRows gets suitable rows for each table -func GetSuitableRows(avgRowLength uint64) uint64 { - const ( - defaultRows = 200000 - maxRows = 1000000 - bytesPerFile = 128 * 1024 * 1024 // 128MB per file by default - ) - if avgRowLength == 0 { - return defaultRows - } - estimateRows := bytesPerFile / avgRowLength - if estimateRows > maxRows { - return maxRows - } - return estimateRows -} - -// GetColumnTypes gets *sql.ColumnTypes from a specified table -func GetColumnTypes(tctx *tcontext.Context, db *BaseConn, fields, database, table string) ([]*sql.ColumnType, error) { - query := fmt.Sprintf("SELECT %s FROM `%s`.`%s` LIMIT 1", fields, escapeString(database), escapeString(table)) - var colTypes []*sql.ColumnType - err := db.QuerySQL(tctx, func(rows *sql.Rows) error { - var err error - colTypes, err = rows.ColumnTypes() - if err == nil { - err = rows.Close() - } - failpoint.Inject("ChaosBrokenMetaConn", func(_ failpoint.Value) { - failpoint.Return(errors.New("connection is closed")) - }) - return errors.Annotatef(err, "sql: %s", query) - }, func() { - colTypes = nil - }, query) - if err != nil { - return nil, err - } - return colTypes, nil -} - -// GetPrimaryKeyAndColumnTypes gets all primary columns and their types in ordinal order -func GetPrimaryKeyAndColumnTypes(tctx *tcontext.Context, conn *BaseConn, meta TableMeta) ([]string, []string, error) { - var ( - colNames, colTypes []string - err error - ) - colNames, err = GetPrimaryKeyColumns(tctx, conn, meta.DatabaseName(), meta.TableName()) - if err != nil { - return nil, nil, err - } - colName2Type := string2Map(meta.ColumnNames(), meta.ColumnTypes()) - colTypes = make([]string, len(colNames)) - for i, colName := range colNames { - colTypes[i] = colName2Type[colName] - } - return colNames, colTypes, nil -} - -// GetPrimaryKeyColumns gets all primary columns 
in ordinal order -func GetPrimaryKeyColumns(tctx *tcontext.Context, db *BaseConn, database, table string) ([]string, error) { - priKeyColsQuery := fmt.Sprintf("SHOW INDEX FROM `%s`.`%s`", escapeString(database), escapeString(table)) - results, err := db.QuerySQLWithColumns(tctx, []string{"KEY_NAME", "COLUMN_NAME"}, priKeyColsQuery) - if err != nil { - return nil, err - } - - cols := make([]string, 0, len(results)) - for _, oneRow := range results { - keyName, columnName := oneRow[0], oneRow[1] - if keyName == "PRIMARY" { - cols = append(cols, columnName) - } - } - return cols, nil -} - -// getNumericIndex picks up indices according to the following priority: -// primary key > unique key with the smallest count > key with the max cardinality -// primary key with multi cols is before unique key with single col because we will sort result by primary keys -func getNumericIndex(tctx *tcontext.Context, db *BaseConn, meta TableMeta) (string, error) { - database, table := meta.DatabaseName(), meta.TableName() - colName2Type := string2Map(meta.ColumnNames(), meta.ColumnTypes()) - keyQuery := fmt.Sprintf("SHOW INDEX FROM `%s`.`%s`", escapeString(database), escapeString(table)) - results, err := db.QuerySQLWithColumns(tctx, []string{"NON_UNIQUE", "SEQ_IN_INDEX", "KEY_NAME", "COLUMN_NAME", "CARDINALITY"}, keyQuery) - if err != nil { - return "", err - } - type keyColumnPair struct { - colName string - count uint64 - } - var ( - uniqueKeyMap = map[string]keyColumnPair{} // unique key name -> key column name, unique key columns count - keyColumn string - maxCardinality int64 = -1 - ) - - // check primary key first, then unique key - for _, oneRow := range results { - nonUnique, seqInIndex, keyName, colName, cardinality := oneRow[0], oneRow[1], oneRow[2], oneRow[3], oneRow[4] - // only try pick the first column, because the second column of pk/uk in where condition will trigger a full table scan - if seqInIndex != "1" { - if pair, ok := uniqueKeyMap[keyName]; ok { - seqInIndexInt, err := strconv.ParseUint(seqInIndex, 10, 64) - if err == nil && seqInIndexInt > pair.count { - uniqueKeyMap[keyName] = keyColumnPair{pair.colName, seqInIndexInt} - } - } - continue - } - _, numberColumn := dataTypeInt[colName2Type[colName]] - if numberColumn { - switch { - case keyName == "PRIMARY": - return colName, nil - case nonUnique == "0": - uniqueKeyMap[keyName] = keyColumnPair{colName, 1} - // pick index column with max cardinality when there is no unique index - case len(uniqueKeyMap) == 0: - cardinalityInt, err := strconv.ParseInt(cardinality, 10, 64) - if err == nil && cardinalityInt > maxCardinality { - keyColumn = colName - maxCardinality = cardinalityInt - } - } - } - } - if len(uniqueKeyMap) > 0 { - var ( - minCols uint64 = math.MaxUint64 - uniqueKeyColumn string - ) - for _, pair := range uniqueKeyMap { - if pair.count < minCols { - uniqueKeyColumn = pair.colName - minCols = pair.count - } - } - return uniqueKeyColumn, nil - } - return keyColumn, nil -} - -// FlushTableWithReadLock flush tables with read lock -func FlushTableWithReadLock(ctx context.Context, db *sql.Conn) error { - const ftwrlQuery = "FLUSH TABLES WITH READ LOCK" - _, err := db.ExecContext(ctx, ftwrlQuery) - return errors.Annotatef(err, "sql: %s", ftwrlQuery) -} - -// LockTables locks table with read lock -func LockTables(ctx context.Context, db *sql.Conn, database, table string) error { - lockTableQuery := fmt.Sprintf("LOCK TABLES `%s`.`%s` READ", escapeString(database), escapeString(table)) - _, err := db.ExecContext(ctx, lockTableQuery) - 
return errors.Annotatef(err, "sql: %s", lockTableQuery) -} - -// UnlockTables unlocks all tables' lock -func UnlockTables(ctx context.Context, db *sql.Conn) error { - const unlockTableQuery = "UNLOCK TABLES" - _, err := db.ExecContext(ctx, unlockTableQuery) - return errors.Annotatef(err, "sql: %s", unlockTableQuery) -} - -// ShowMasterStatus get SHOW MASTER STATUS result from database -func ShowMasterStatus(db *sql.Conn) ([]string, error) { - var oneRow []string - handleOneRow := func(rows *sql.Rows) error { - cols, err := rows.Columns() - if err != nil { - return errors.Trace(err) - } - fieldNum := len(cols) - oneRow = make([]string, fieldNum) - addr := make([]any, fieldNum) - for i := range oneRow { - addr[i] = &oneRow[i] - } - return rows.Scan(addr...) - } - const showMasterStatusQuery = "SHOW MASTER STATUS" - err := simpleQuery(db, showMasterStatusQuery, handleOneRow) - if err != nil { - return nil, errors.Annotatef(err, "sql: %s", showMasterStatusQuery) - } - return oneRow, nil -} - -// GetSpecifiedColumnValueAndClose get columns' values whose name is equal to columnName and close the given rows -func GetSpecifiedColumnValueAndClose(rows *sql.Rows, columnName string) ([]string, error) { - if rows == nil { - return []string{}, nil - } - defer rows.Close() - var strs []string - columns, _ := rows.Columns() - addr := make([]any, len(columns)) - oneRow := make([]sql.NullString, len(columns)) - fieldIndex := -1 - for i, col := range columns { - if strings.EqualFold(col, columnName) { - fieldIndex = i - } - addr[i] = &oneRow[i] - } - if fieldIndex == -1 { - return strs, nil - } - for rows.Next() { - err := rows.Scan(addr...) - if err != nil { - return strs, errors.Trace(err) - } - if oneRow[fieldIndex].Valid { - strs = append(strs, oneRow[fieldIndex].String) - } - } - return strs, errors.Trace(rows.Err()) -} - -// GetSpecifiedColumnValuesAndClose get columns' values whose name is equal to columnName -func GetSpecifiedColumnValuesAndClose(rows *sql.Rows, columnName ...string) ([][]string, error) { - if rows == nil { - return [][]string{}, nil - } - defer rows.Close() - var strs [][]string - columns, err := rows.Columns() - if err != nil { - return strs, errors.Trace(err) - } - addr := make([]any, len(columns)) - oneRow := make([]sql.NullString, len(columns)) - fieldIndexMp := make(map[int]int) - for i, col := range columns { - addr[i] = &oneRow[i] - for j, name := range columnName { - if strings.EqualFold(col, name) { - fieldIndexMp[i] = j - } - } - } - if len(fieldIndexMp) == 0 { - return strs, nil - } - for rows.Next() { - err := rows.Scan(addr...) 
- if err != nil { - return strs, errors.Trace(err) - } - written := false - tmpStr := make([]string, len(columnName)) - for colPos, namePos := range fieldIndexMp { - if oneRow[colPos].Valid { - written = true - tmpStr[namePos] = oneRow[colPos].String - } - } - if written { - strs = append(strs, tmpStr) - } - } - return strs, errors.Trace(rows.Err()) -} - -// GetPdAddrs gets PD address from TiDB -func GetPdAddrs(tctx *tcontext.Context, db *sql.DB) ([]string, error) { - const query = "SELECT * FROM information_schema.cluster_info where type = 'pd';" - rows, err := db.QueryContext(tctx, query) - if err != nil { - return []string{}, errors.Annotatef(err, "sql: %s", query) - } - pdAddrs, err := GetSpecifiedColumnValueAndClose(rows, "STATUS_ADDRESS") - return pdAddrs, errors.Annotatef(err, "sql: %s", query) -} - -// GetTiDBDDLIDs gets DDL IDs from TiDB -func GetTiDBDDLIDs(tctx *tcontext.Context, db *sql.DB) ([]string, error) { - const query = "SELECT * FROM information_schema.tidb_servers_info;" - rows, err := db.QueryContext(tctx, query) - if err != nil { - return []string{}, errors.Annotatef(err, "sql: %s", query) - } - ddlIDs, err := GetSpecifiedColumnValueAndClose(rows, "DDL_ID") - return ddlIDs, errors.Annotatef(err, "sql: %s", query) -} - -// getTiDBConfig gets tidb config from TiDB server -// @@tidb_config details doc https://docs.pingcap.com/tidb/stable/system-variables#tidb_config -// this variable exists at least from v2.0.0, so this works in most existing tidb instances -func getTiDBConfig(db *sql.Conn) (dbconfig.Config, error) { - const query = "SELECT @@tidb_config;" - var ( - tidbConfig dbconfig.Config - tidbConfigBytes []byte - ) - row := db.QueryRowContext(context.Background(), query) - err := row.Scan(&tidbConfigBytes) - if err != nil { - return tidbConfig, errors.Annotatef(err, "sql: %s", query) - } - err = json.Unmarshal(tidbConfigBytes, &tidbConfig) - return tidbConfig, errors.Annotatef(err, "sql: %s", query) -} - -// CheckTiDBWithTiKV use sql to check whether current TiDB has TiKV -func CheckTiDBWithTiKV(db *sql.DB) (bool, error) { - conn, err := db.Conn(context.Background()) - if err == nil { - defer func() { - _ = conn.Close() - }() - tidbConfig, err := getTiDBConfig(conn) - if err == nil { - return tidbConfig.Store == "tikv", nil - } - } - var count int - const query = "SELECT COUNT(1) as c FROM MYSQL.TiDB WHERE VARIABLE_NAME='tikv_gc_safe_point'" - row := db.QueryRow(query) - err = row.Scan(&count) - if err != nil { - // still return true here. 
Because sometimes users may not have privileges for MySQL.TiDB database - // In most production cases TiDB has TiKV - return true, errors.Annotatef(err, "sql: %s", query) - } - return count > 0, nil -} - -// CheckIfSeqExists use sql to check whether sequence exists -func CheckIfSeqExists(db *sql.Conn) (bool, error) { - var count int - const query = "SELECT COUNT(1) as c FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='SEQUENCE'" - row := db.QueryRowContext(context.Background(), query) - err := row.Scan(&count) - if err != nil { - return false, errors.Annotatef(err, "sql: %s", query) - } - - return count > 0, nil -} - -// CheckTiDBEnableTableLock use sql variable to check whether current TiDB has TiKV -func CheckTiDBEnableTableLock(db *sql.Conn) (bool, error) { - tidbConfig, err := getTiDBConfig(db) - if err != nil { - return false, err - } - return tidbConfig.EnableTableLock, nil -} - -func getSnapshot(db *sql.Conn) (string, error) { - str, err := ShowMasterStatus(db) - if err != nil { - return "", err - } - return str[snapshotFieldIndex], nil -} - -func isUnknownSystemVariableErr(err error) bool { - return strings.Contains(err.Error(), "Unknown system variable") -} - -// resetDBWithSessionParams will return a new sql.DB as a replacement for input `db` with new session parameters. -// If returned error is nil, the input `db` will be closed. -func resetDBWithSessionParams(tctx *tcontext.Context, db *sql.DB, cfg *mysql.Config, params map[string]any) (*sql.DB, error) { - support := make(map[string]any) - for k, v := range params { - var pv any - if str, ok := v.(string); ok { - if pvi, err := strconv.ParseInt(str, 10, 64); err == nil { - pv = pvi - } else if pvf, err := strconv.ParseFloat(str, 64); err == nil { - pv = pvf - } else { - pv = str - } - } else { - pv = v - } - s := fmt.Sprintf("SET SESSION %s = ?", k) - _, err := db.ExecContext(tctx, s, pv) - if err != nil { - if k == snapshotVar { - err = errors.Annotate(err, "fail to set snapshot for tidb, please set --consistency=none/--consistency=lock or fix snapshot problem") - } else if isUnknownSystemVariableErr(err) { - tctx.L().Info("session variable is not supported by db", zap.String("variable", k), zap.Reflect("value", v)) - continue - } - return nil, errors.Trace(err) - } - - support[k] = pv - } - - if cfg.Params == nil { - cfg.Params = make(map[string]string) - } - - for k, v := range support { - var s string - // Wrap string with quote to handle string with space. 
For example, '2020-10-20 13:41:40' - // For --params argument, quote doesn't matter because it doesn't affect the actual value - if str, ok := v.(string); ok { - s = wrapStringWith(str, "'") - } else { - s = fmt.Sprintf("%v", v) - } - cfg.Params[k] = s - } - failpoint.Inject("SkipResetDB", func(_ failpoint.Value) { - failpoint.Return(db, nil) - }) - - db.Close() - c, err := mysql.NewConnector(cfg) - if err != nil { - return nil, errors.Trace(err) - } - newDB := sql.OpenDB(c) - // ping to make sure all session parameters are set correctly - err = newDB.PingContext(tctx) - if err != nil { - newDB.Close() - } - return newDB, nil -} - -func createConnWithConsistency(ctx context.Context, db *sql.DB, repeatableRead bool) (*sql.Conn, error) { - conn, err := db.Conn(ctx) - if err != nil { - return nil, errors.Trace(err) - } - var query string - if repeatableRead { - query = "SET SESSION TRANSACTION ISOLATION LEVEL REPEATABLE READ" - _, err = conn.ExecContext(ctx, query) - if err != nil { - return nil, errors.Annotatef(err, "sql: %s", query) - } - } - query = "START TRANSACTION /*!40108 WITH CONSISTENT SNAPSHOT */" - _, err = conn.ExecContext(ctx, query) - if err != nil { - // Some MySQL Compatible databases like Vitess and MemSQL/SingleStore - // are newer than 4.1.8 (the version comment) but don't actually support - // `WITH CONSISTENT SNAPSHOT`. So retry without that if the statement fails. - query = "START TRANSACTION" - _, err = conn.ExecContext(ctx, query) - if err != nil { - return nil, errors.Annotatef(err, "sql: %s", query) - } - } - return conn, nil -} - -// buildSelectField returns the selecting fields' string(joined by comma(`,`)), -// and the number of writable fields. -func buildSelectField(tctx *tcontext.Context, db *BaseConn, dbName, tableName string, completeInsert bool) (string, int, error) { // revive:disable-line:flag-parameter - query := fmt.Sprintf("SHOW COLUMNS FROM `%s`.`%s`", escapeString(dbName), escapeString(tableName)) - results, err := db.QuerySQLWithColumns(tctx, []string{"FIELD", "EXTRA"}, query) - if err != nil { - return "", 0, err - } - availableFields := make([]string, 0) - hasGenerateColumn := false - for _, oneRow := range results { - fieldName, extra := oneRow[0], oneRow[1] - switch extra { - case "STORED GENERATED", "VIRTUAL GENERATED": - hasGenerateColumn = true - continue - } - availableFields = append(availableFields, wrapBackTicks(escapeString(fieldName))) - } - if completeInsert || hasGenerateColumn { - return strings.Join(availableFields, ","), len(availableFields), nil - } - return "*", len(availableFields), nil -} - -func buildWhereClauses(handleColNames []string, handleVals [][]string) []string { - if len(handleColNames) == 0 || len(handleVals) == 0 { - return nil - } - quotaCols := make([]string, len(handleColNames)) - for i, s := range handleColNames { - quotaCols[i] = fmt.Sprintf("`%s`", escapeString(s)) - } - where := make([]string, 0, len(handleVals)+1) - buf := &bytes.Buffer{} - buildCompareClause(buf, quotaCols, handleVals[0], less, false) - where = append(where, buf.String()) - buf.Reset() - for i := 1; i < len(handleVals); i++ { - low, up := handleVals[i-1], handleVals[i] - buildBetweenClause(buf, quotaCols, low, up) - where = append(where, buf.String()) - buf.Reset() - } - buildCompareClause(buf, quotaCols, handleVals[len(handleVals)-1], greater, true) - where = append(where, buf.String()) - buf.Reset() - return where -} - -// return greater than TableRangeScan where clause -// the result doesn't contain brackets -const ( - greater = '>' - 
less = '<' - equal = '=' -) - -// buildCompareClause build clause with specified bounds. Usually we will use the following two conditions: -// (compare, writeEqual) == (less, false), return quotaCols < bound clause. In other words, (-inf, bound) -// (compare, writeEqual) == (greater, true), return quotaCols >= bound clause. In other words, [bound, +inf) -func buildCompareClause(buf *bytes.Buffer, quotaCols []string, bound []string, compare byte, writeEqual bool) { // revive:disable-line:flag-parameter - for i, col := range quotaCols { - if i > 0 { - buf.WriteString("or(") - } - for j := 0; j < i; j++ { - buf.WriteString(quotaCols[j]) - buf.WriteByte(equal) - buf.WriteString(bound[j]) - buf.WriteString(" and ") - } - buf.WriteString(col) - buf.WriteByte(compare) - if writeEqual && i == len(quotaCols)-1 { - buf.WriteByte(equal) - } - buf.WriteString(bound[i]) - if i > 0 { - buf.WriteByte(')') - } else if i != len(quotaCols)-1 { - buf.WriteByte(' ') - } - } -} - -// getCommonLength returns the common length of low and up -func getCommonLength(low []string, up []string) int { - for i := range low { - if low[i] != up[i] { - return i - } - } - return len(low) -} - -// buildBetweenClause build clause in a specified table range. -// the result where clause will be low <= quotaCols < up. In other words, [low, up) -func buildBetweenClause(buf *bytes.Buffer, quotaCols []string, low []string, up []string) { - singleBetween := func(writeEqual bool) { - buf.WriteString(quotaCols[0]) - buf.WriteByte(greater) - if writeEqual { - buf.WriteByte(equal) - } - buf.WriteString(low[0]) - buf.WriteString(" and ") - buf.WriteString(quotaCols[0]) - buf.WriteByte(less) - buf.WriteString(up[0]) - } - // handle special cases with common prefix - commonLen := getCommonLength(low, up) - if commonLen > 0 { - // unexpected case for low == up, return empty result - if commonLen == len(low) { - buf.WriteString("false") - return - } - for i := 0; i < commonLen; i++ { - if i > 0 { - buf.WriteString(" and ") - } - buf.WriteString(quotaCols[i]) - buf.WriteByte(equal) - buf.WriteString(low[i]) - } - buf.WriteString(" and(") - defer buf.WriteByte(')') - quotaCols = quotaCols[commonLen:] - low = low[commonLen:] - up = up[commonLen:] - } - - // handle special cases with only one column - if len(quotaCols) == 1 { - singleBetween(true) - return - } - buf.WriteByte('(') - singleBetween(false) - buf.WriteString(")or(") - buf.WriteString(quotaCols[0]) - buf.WriteByte(equal) - buf.WriteString(low[0]) - buf.WriteString(" and(") - buildCompareClause(buf, quotaCols[1:], low[1:], greater, true) - buf.WriteString("))or(") - buf.WriteString(quotaCols[0]) - buf.WriteByte(equal) - buf.WriteString(up[0]) - buf.WriteString(" and(") - buildCompareClause(buf, quotaCols[1:], up[1:], less, false) - buf.WriteString("))") -} - -func buildOrderByClauseString(handleColNames []string) string { - if len(handleColNames) == 0 { - return "" - } - separator := "," - quotaCols := make([]string, len(handleColNames)) - for i, col := range handleColNames { - quotaCols[i] = fmt.Sprintf("`%s`", escapeString(col)) - } - return fmt.Sprintf("ORDER BY %s", strings.Join(quotaCols, separator)) -} - -func buildLockTablesSQL(allTables DatabaseTables, blockList map[string]map[string]any) string { - // ,``.`` READ has 11 bytes, "LOCK TABLE" has 10 bytes - estimatedCap := len(allTables)*11 + 10 - s := bytes.NewBuffer(make([]byte, 0, estimatedCap)) - n := false - for dbName, tables := range allTables { - escapedDBName := escapeString(dbName) - for _, table := range tables { - // 
Lock views will lock related tables. However, we won't dump data only the create sql of view, so we needn't lock view here. - // Besides, mydumper also only lock base table here. https://github.com/maxbube/mydumper/blob/1fabdf87e3007e5934227b504ad673ba3697946c/mydumper.c#L1568 - if table.Type != TableTypeBase { - continue - } - if blockTable, ok := blockList[dbName]; ok { - if _, ok := blockTable[table.Name]; ok { - continue - } - } - if !n { - fmt.Fprintf(s, "LOCK TABLES `%s`.`%s` READ", escapedDBName, escapeString(table.Name)) - n = true - } else { - fmt.Fprintf(s, ",`%s`.`%s` READ", escapedDBName, escapeString(table.Name)) - } - } - } - return s.String() -} - -type oneStrColumnTable struct { - data []string -} - -func (o *oneStrColumnTable) handleOneRow(rows *sql.Rows) error { - var str string - if err := rows.Scan(&str); err != nil { - return errors.Trace(err) - } - o.data = append(o.data, str) - return nil -} - -func simpleQuery(conn *sql.Conn, query string, handleOneRow func(*sql.Rows) error) error { - return simpleQueryWithArgs(context.Background(), conn, handleOneRow, query) -} - -func simpleQueryWithArgs(ctx context.Context, conn *sql.Conn, handleOneRow func(*sql.Rows) error, query string, args ...any) error { - var ( - rows *sql.Rows - err error - ) - if len(args) > 0 { - rows, err = conn.QueryContext(ctx, query, args...) - } else { - rows, err = conn.QueryContext(ctx, query) - } - if err != nil { - return errors.Annotatef(err, "sql: %s, args: %s", query, args) - } - defer rows.Close() - - for rows.Next() { - if err := handleOneRow(rows); err != nil { - rows.Close() - return errors.Annotatef(err, "sql: %s, args: %s", query, args) - } - } - return errors.Annotatef(rows.Err(), "sql: %s, args: %s", query, args) -} - -func pickupPossibleField(tctx *tcontext.Context, meta TableMeta, db *BaseConn) (string, error) { - // try using _tidb_rowid first - if meta.HasImplicitRowID() { - return "_tidb_rowid", nil - } - // try to use pk or uk - fieldName, err := getNumericIndex(tctx, db, meta) - if err != nil { - return "", err - } - - // if fieldName == "", there is no proper index - return fieldName, nil -} - -func estimateCount(tctx *tcontext.Context, dbName, tableName string, db *BaseConn, field string, conf *Config) uint64 { - var query string - if strings.TrimSpace(field) == "*" || strings.TrimSpace(field) == "" { - query = fmt.Sprintf("EXPLAIN SELECT * FROM `%s`.`%s`", escapeString(dbName), escapeString(tableName)) - } else { - query = fmt.Sprintf("EXPLAIN SELECT `%s` FROM `%s`.`%s`", escapeString(field), escapeString(dbName), escapeString(tableName)) - } - - if conf.Where != "" { - query += " WHERE " - query += conf.Where - } - - estRows := detectEstimateRows(tctx, db, query, []string{"rows", "estRows", "count"}) - /* tidb results field name is estRows (before 4.0.0-beta.2: count) - +-----------------------+----------+-----------+---------------------------------------------------------+ - | id | estRows | task | access object | operator info | - +-----------------------+----------+-----------+---------------------------------------------------------+ - | tablereader_5 | 10000.00 | root | | data:tablefullscan_4 | - | └─tablefullscan_4 | 10000.00 | cop[tikv] | table:a | table:a, keep order:false, stats:pseudo | - +-----------------------+----------+-----------+---------------------------------------------------------- - - mariadb result field name is rows - +------+-------------+---------+-------+---------------+------+---------+------+----------+-------------+ - | id | select_type | 
table | type | possible_keys | key | key_len | ref | rows | Extra | - +------+-------------+---------+-------+---------------+------+---------+------+----------+-------------+ - | 1 | SIMPLE | sbtest1 | index | NULL | k_1 | 4 | NULL | 15000049 | Using index | - +------+-------------+---------+-------+---------------+------+---------+------+----------+-------------+ - - mysql result field name is rows - +----+-------------+-------+------------+-------+---------------+-----------+---------+------+------+----------+-------------+ - | id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra | - +----+-------------+-------+------------+-------+---------------+-----------+---------+------+------+----------+-------------+ - | 1 | SIMPLE | t1 | NULL | index | NULL | multi_col | 10 | NULL | 5 | 100.00 | Using index | - +----+-------------+-------+------------+-------+---------------+-----------+---------+------+------+----------+-------------+ - */ - if estRows > 0 { - return estRows - } - return 0 -} - -func detectEstimateRows(tctx *tcontext.Context, db *BaseConn, query string, fieldNames []string) uint64 { - var ( - fieldIndex int - oneRow []sql.NullString - ) - err := db.QuerySQL(tctx, func(rows *sql.Rows) error { - columns, err := rows.Columns() - if err != nil { - return errors.Trace(err) - } - addr := make([]any, len(columns)) - oneRow = make([]sql.NullString, len(columns)) - fieldIndex = -1 - found: - for i := range oneRow { - for _, fieldName := range fieldNames { - if strings.EqualFold(columns[i], fieldName) { - fieldIndex = i - break found - } - } - } - if fieldIndex == -1 { - rows.Close() - return nil - } - - for i := range oneRow { - addr[i] = &oneRow[i] - } - return rows.Scan(addr...) - }, func() {}, query) - if err != nil || fieldIndex == -1 { - tctx.L().Info("can't estimate rows from db", - zap.String("query", query), zap.Int("fieldIndex", fieldIndex), log.ShortError(err)) - return 0 - } - - estRows, err := strconv.ParseFloat(oneRow[fieldIndex].String, 64) - if err != nil { - tctx.L().Info("can't get parse estimate rows from db", - zap.String("query", query), zap.String("estRows", oneRow[fieldIndex].String), log.ShortError(err)) - return 0 - } - return uint64(estRows) -} - -func parseSnapshotToTSO(pool *sql.DB, snapshot string) (uint64, error) { - snapshotTS, err := strconv.ParseUint(snapshot, 10, 64) - if err == nil { - return snapshotTS, nil - } - var tso sql.NullInt64 - query := "SELECT unix_timestamp(?)" - row := pool.QueryRow(query, snapshot) - err = row.Scan(&tso) - if err != nil { - return 0, errors.Annotatef(err, "sql: %s", strings.ReplaceAll(query, "?", fmt.Sprintf(`"%s"`, snapshot))) - } - if !tso.Valid { - return 0, errors.Errorf("snapshot %s format not supported. 
please use tso or '2006-01-02 15:04:05' format time", snapshot) - } - return (uint64(tso.Int64) << 18) * 1000, nil -} - -func buildWhereCondition(conf *Config, where string) string { - var query strings.Builder - separator := "WHERE" - leftBracket := " " - rightBracket := " " - if conf.Where != "" && where != "" { - leftBracket = " (" - rightBracket = ") " - } - if conf.Where != "" { - query.WriteString(separator) - query.WriteString(leftBracket) - query.WriteString(conf.Where) - query.WriteString(rightBracket) - separator = "AND" - } - if where != "" { - query.WriteString(separator) - query.WriteString(leftBracket) - query.WriteString(where) - query.WriteString(rightBracket) - } - return query.String() -} - -func escapeString(s string) string { - return strings.ReplaceAll(s, "`", "``") -} - -// GetPartitionNames get partition names from a specified table -func GetPartitionNames(tctx *tcontext.Context, db *BaseConn, schema, table string) (partitions []string, err error) { - partitions = make([]string, 0) - var partitionName sql.NullString - err = db.QuerySQL(tctx, func(rows *sql.Rows) error { - err := rows.Scan(&partitionName) - if err != nil { - return errors.Trace(err) - } - if partitionName.Valid { - partitions = append(partitions, partitionName.String) - } - return nil - }, func() { - partitions = partitions[:0] - }, "SELECT PARTITION_NAME from INFORMATION_SCHEMA.PARTITIONS WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?", schema, table) - return -} - -// GetPartitionTableIDs get partition tableIDs through histograms. -// SHOW STATS_HISTOGRAMS has db_name,table_name,partition_name but doesn't have partition id -// mysql.stats_histograms has partition_id but doesn't have db_name,table_name,partition_name -// So we combine the results from these two sqls to get partition ids for each table -// If UPDATE_TIME,DISTINCT_COUNT are equal, we assume these two records can represent one line. -// If histograms are not accurate or (UPDATE_TIME,DISTINCT_COUNT) has duplicate data, it's still fine. -// Because the possibility is low and the effect is that we will select more than one regions in one time, -// this will not affect the correctness of the dumping data and will not affect the memory usage much. -// This method is tricky, but no better way is found. 
-// Because TiDB v3.0.0's information_schema.partition table doesn't have partition name or partition id info -// return (dbName -> tbName -> partitionName -> partitionID, error) -func GetPartitionTableIDs(db *sql.Conn, tables map[string]map[string]struct{}) (map[string]map[string]map[string]int64, error) { - const ( - showStatsHistogramsSQL = "SHOW STATS_HISTOGRAMS" - selectStatsHistogramsSQL = "SELECT TABLE_ID,FROM_UNIXTIME(VERSION DIV 262144 DIV 1000,'%Y-%m-%d %H:%i:%s') AS UPDATE_TIME,DISTINCT_COUNT FROM mysql.stats_histograms" - ) - partitionIDs := make(map[string]map[string]map[string]int64, len(tables)) - rows, err := db.QueryContext(context.Background(), showStatsHistogramsSQL) - if err != nil { - return nil, errors.Annotatef(err, "sql: %s", showStatsHistogramsSQL) - } - results, err := GetSpecifiedColumnValuesAndClose(rows, "DB_NAME", "TABLE_NAME", "PARTITION_NAME", "UPDATE_TIME", "DISTINCT_COUNT") - if err != nil { - return nil, errors.Annotatef(err, "sql: %s", showStatsHistogramsSQL) - } - type partitionInfo struct { - dbName, tbName, partitionName string - } - saveMap := make(map[string]map[string]partitionInfo) - for _, oneRow := range results { - dbName, tbName, partitionName, updateTime, distinctCount := oneRow[0], oneRow[1], oneRow[2], oneRow[3], oneRow[4] - if len(partitionName) == 0 { - continue - } - if tbm, ok := tables[dbName]; ok { - if _, ok = tbm[tbName]; ok { - if _, ok = saveMap[updateTime]; !ok { - saveMap[updateTime] = make(map[string]partitionInfo) - } - saveMap[updateTime][distinctCount] = partitionInfo{ - dbName: dbName, - tbName: tbName, - partitionName: partitionName, - } - } - } - } - if len(saveMap) == 0 { - return map[string]map[string]map[string]int64{}, nil - } - err = simpleQuery(db, selectStatsHistogramsSQL, func(rows *sql.Rows) error { - var ( - tableID int64 - updateTime, distinctCount string - ) - err2 := rows.Scan(&tableID, &updateTime, &distinctCount) - if err2 != nil { - return errors.Trace(err2) - } - if mpt, ok := saveMap[updateTime]; ok { - if partition, ok := mpt[distinctCount]; ok { - dbName, tbName, partitionName := partition.dbName, partition.tbName, partition.partitionName - if _, ok := partitionIDs[dbName]; !ok { - partitionIDs[dbName] = make(map[string]map[string]int64) - } - if _, ok := partitionIDs[dbName][tbName]; !ok { - partitionIDs[dbName][tbName] = make(map[string]int64) - } - partitionIDs[dbName][tbName][partitionName] = tableID - } - } - return nil - }) - return partitionIDs, err -} - -// GetDBInfo get model.DBInfos from database sql interface. 
-// We need table_id to check whether a region belongs to this table -func GetDBInfo(db *sql.Conn, tables map[string]map[string]struct{}) ([]*model.DBInfo, error) { - const tableIDSQL = "SELECT TABLE_SCHEMA,TABLE_NAME,TIDB_TABLE_ID FROM INFORMATION_SCHEMA.TABLES ORDER BY TABLE_SCHEMA" - - schemas := make([]*model.DBInfo, 0, len(tables)) - var ( - tableSchema, tableName string - tidbTableID int64 - ) - partitionIDs, err := GetPartitionTableIDs(db, tables) - if err != nil { - return nil, err - } - err = simpleQuery(db, tableIDSQL, func(rows *sql.Rows) error { - err2 := rows.Scan(&tableSchema, &tableName, &tidbTableID) - if err2 != nil { - return errors.Trace(err2) - } - if tbm, ok := tables[tableSchema]; !ok { - return nil - } else if _, ok = tbm[tableName]; !ok { - return nil - } - last := len(schemas) - 1 - if last < 0 || schemas[last].Name.O != tableSchema { - dbInfo := &model.DBInfo{Name: model.CIStr{O: tableSchema}} - dbInfo.Deprecated.Tables = make([]*model.TableInfo, 0, len(tables[tableSchema])) - schemas = append(schemas, dbInfo) - last++ - } - var partition *model.PartitionInfo - if tbm, ok := partitionIDs[tableSchema]; ok { - if ptm, ok := tbm[tableName]; ok { - partition = &model.PartitionInfo{Definitions: make([]model.PartitionDefinition, 0, len(ptm))} - for partitionName, partitionID := range ptm { - partition.Definitions = append(partition.Definitions, model.PartitionDefinition{ - ID: partitionID, - Name: model.CIStr{O: partitionName}, - }) - } - } - } - schemas[last].Deprecated.Tables = append(schemas[last].Deprecated.Tables, &model.TableInfo{ - ID: tidbTableID, - Name: model.CIStr{O: tableName}, - Partition: partition, - }) - return nil - }) - return schemas, err -} - -// GetRegionInfos get region info including regionID, start key, end key from database sql interface. -// start key, end key includes information to help split table -func GetRegionInfos(db *sql.Conn) (*pd.RegionsInfo, error) { - const tableRegionSQL = "SELECT REGION_ID,START_KEY,END_KEY FROM INFORMATION_SCHEMA.TIKV_REGION_STATUS ORDER BY START_KEY;" - var ( - regionID int64 - startKey, endKey string - ) - regionsInfo := &pd.RegionsInfo{Regions: make([]pd.RegionInfo, 0)} - err := simpleQuery(db, tableRegionSQL, func(rows *sql.Rows) error { - err := rows.Scan(®ionID, &startKey, &endKey) - if err != nil { - return errors.Trace(err) - } - regionsInfo.Regions = append(regionsInfo.Regions, pd.RegionInfo{ - ID: regionID, - StartKey: startKey, - EndKey: endKey, - }) - return nil - }) - return regionsInfo, err -} - -// GetCharsetAndDefaultCollation gets charset and default collation map. -func GetCharsetAndDefaultCollation(ctx context.Context, db *sql.Conn) (map[string]string, error) { - charsetAndDefaultCollation := make(map[string]string) - query := "SHOW CHARACTER SET" - - // Show an example. 
- /* - mysql> SHOW CHARACTER SET; - +----------+---------------------------------+---------------------+--------+ - | Charset | Description | Default collation | Maxlen | - +----------+---------------------------------+---------------------+--------+ - | armscii8 | ARMSCII-8 Armenian | armscii8_general_ci | 1 | - | ascii | US ASCII | ascii_general_ci | 1 | - | big5 | Big5 Traditional Chinese | big5_chinese_ci | 2 | - | binary | Binary pseudo charset | binary | 1 | - | cp1250 | Windows Central European | cp1250_general_ci | 1 | - | cp1251 | Windows Cyrillic | cp1251_general_ci | 1 | - +----------+---------------------------------+---------------------+--------+ - */ - - rows, err := db.QueryContext(ctx, query) - if err != nil { - return nil, errors.Annotatef(err, "sql: %s", query) - } - - defer rows.Close() - for rows.Next() { - var charset, description, collation string - var maxlen int - if scanErr := rows.Scan(&charset, &description, &collation, &maxlen); scanErr != nil { - return nil, errors.Annotatef(err, "sql: %s", query) - } - charsetAndDefaultCollation[strings.ToLower(charset)] = collation - } - if err = rows.Close(); err != nil { - return nil, errors.Annotatef(err, "sql: %s", query) - } - if err = rows.Err(); err != nil { - return nil, errors.Annotatef(err, "sql: %s", query) - } - return charsetAndDefaultCollation, err -} diff --git a/dumpling/export/status.go b/dumpling/export/status.go index 9c67964f69602..0a861f4c40677 100644 --- a/dumpling/export/status.go +++ b/dumpling/export/status.go @@ -18,11 +18,11 @@ const logProgressTick = 2 * time.Minute func (d *Dumper) runLogProgress(tctx *tcontext.Context) { logProgressTicker := time.NewTicker(logProgressTick) - if _, _err_ := failpoint.Eval(_curpkg_("EnableLogProgress")); _err_ == nil { + failpoint.Inject("EnableLogProgress", func() { logProgressTicker.Stop() logProgressTicker = time.NewTicker(time.Duration(1) * time.Second) tctx.L().Debug("EnableLogProgress") - } + }) lastCheckpoint := time.Now() lastBytes := float64(0) defer logProgressTicker.Stop() diff --git a/dumpling/export/status.go__failpoint_stash__ b/dumpling/export/status.go__failpoint_stash__ deleted file mode 100644 index 0a861f4c40677..0000000000000 --- a/dumpling/export/status.go__failpoint_stash__ +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. 
- -package export - -import ( - "fmt" - "sync" - "sync/atomic" - "time" - - "github.com/docker/go-units" - "github.com/pingcap/failpoint" - tcontext "github.com/pingcap/tidb/dumpling/context" - "go.uber.org/zap" -) - -const logProgressTick = 2 * time.Minute - -func (d *Dumper) runLogProgress(tctx *tcontext.Context) { - logProgressTicker := time.NewTicker(logProgressTick) - failpoint.Inject("EnableLogProgress", func() { - logProgressTicker.Stop() - logProgressTicker = time.NewTicker(time.Duration(1) * time.Second) - tctx.L().Debug("EnableLogProgress") - }) - lastCheckpoint := time.Now() - lastBytes := float64(0) - defer logProgressTicker.Stop() - for { - select { - case <-tctx.Done(): - tctx.L().Debug("stopping log progress") - return - case <-logProgressTicker.C: - nanoseconds := float64(time.Since(lastCheckpoint).Nanoseconds()) - s := d.GetStatus() - tctx.L().Info("progress", - zap.String("tables", fmt.Sprintf("%.0f/%.0f (%.1f%%)", s.CompletedTables, float64(d.totalTables), s.CompletedTables/float64(d.totalTables)*100)), - zap.String("finished rows", fmt.Sprintf("%.0f", s.FinishedRows)), - zap.String("estimate total rows", fmt.Sprintf("%.0f", s.EstimateTotalRows)), - zap.String("finished size", units.HumanSize(s.FinishedBytes)), - zap.Float64("average speed(MiB/s)", (s.FinishedBytes-lastBytes)/(1048576e-9*nanoseconds)), - zap.Float64("recent speed bps", s.CurrentSpeedBPS), - zap.String("chunks progress", s.Progress), - ) - - lastCheckpoint = time.Now() - lastBytes = s.FinishedBytes - } - } -} - -// DumpStatus is the status of dumping. -type DumpStatus struct { - CompletedTables float64 - FinishedBytes float64 - FinishedRows float64 - EstimateTotalRows float64 - TotalTables int64 - CurrentSpeedBPS float64 - Progress string -} - -// GetStatus returns the status of dumping by reading metrics. -func (d *Dumper) GetStatus() *DumpStatus { - ret := &DumpStatus{} - ret.TotalTables = atomic.LoadInt64(&d.totalTables) - ret.CompletedTables = ReadCounter(d.metrics.finishedTablesCounter) - ret.FinishedBytes = ReadGauge(d.metrics.finishedSizeGauge) - ret.FinishedRows = ReadGauge(d.metrics.finishedRowsGauge) - ret.EstimateTotalRows = ReadCounter(d.metrics.estimateTotalRowsCounter) - ret.CurrentSpeedBPS = d.speedRecorder.GetSpeed(ret.FinishedBytes) - if d.metrics.progressReady.Load() { - // chunks will be zero when upstream has no data - if d.metrics.totalChunks.Load() == 0 { - ret.Progress = "100 %" - return ret - } - progress := float64(d.metrics.completedChunks.Load()) / float64(d.metrics.totalChunks.Load()) - if progress > 1 { - ret.Progress = "100 %" - d.L().Warn("completedChunks is greater than totalChunks", zap.Int64("completedChunks", d.metrics.completedChunks.Load()), zap.Int64("totalChunks", d.metrics.totalChunks.Load())) - } else { - ret.Progress = fmt.Sprintf("%5.2f %%", progress*100) - } - } - return ret -} - -func calculateTableCount(m DatabaseTables) int { - cnt := 0 - for _, tables := range m { - for _, table := range tables { - if table.Type == TableTypeBase { - cnt++ - } - } - } - return cnt -} - -// SpeedRecorder record the finished bytes and calculate its speed. -type SpeedRecorder struct { - mu sync.Mutex - lastFinished float64 - lastUpdateTime time.Time - speedBPS float64 -} - -// NewSpeedRecorder new a SpeedRecorder. -func NewSpeedRecorder() *SpeedRecorder { - return &SpeedRecorder{ - lastUpdateTime: time.Now(), - } -} - -// GetSpeed calculate status speed. 
-func (s *SpeedRecorder) GetSpeed(finished float64) float64 { - s.mu.Lock() - defer s.mu.Unlock() - - if finished <= s.lastFinished { - // for finished bytes does not get forwarded, use old speed to avoid - // display zero. We may find better strategy in future. - return s.speedBPS - } - - now := time.Now() - elapsed := now.Sub(s.lastUpdateTime).Seconds() - if elapsed == 0 { - // if time is short, return last speed - return s.speedBPS - } - currentSpeed := (finished - s.lastFinished) / elapsed - if currentSpeed == 0 { - currentSpeed = 1 - } - - s.lastFinished = finished - s.lastUpdateTime = now - s.speedBPS = currentSpeed - - return currentSpeed -} diff --git a/dumpling/export/writer_util.go b/dumpling/export/writer_util.go index d6a0c3c69d609..e7ed2de2e611a 100644 --- a/dumpling/export/writer_util.go +++ b/dumpling/export/writer_util.go @@ -239,10 +239,10 @@ func WriteInsert( } counter++ wp.AddFileSize(uint64(bf.Len()-lastBfSize) + 2) // 2 is for ",\n" and ";\n" - if _, _err_ := failpoint.Eval(_curpkg_("ChaosBrokenWriterConn")); _err_ == nil { - return 0, errors.New("connection is closed") - } - failpoint.Eval(_curpkg_("AtEveryRow")) + failpoint.Inject("ChaosBrokenWriterConn", func(_ failpoint.Value) { + failpoint.Return(0, errors.New("connection is closed")) + }) + failpoint.Inject("AtEveryRow", nil) fileRowIter.Next() shouldSwitch := wp.ShouldSwitchStatement() @@ -464,9 +464,9 @@ func buildFileWriter(tctx *tcontext.Context, s storage.ExternalStorage, fileName tctx.L().Debug("opened file", zap.String("path", fullPath)) tearDownRoutine := func(ctx context.Context) error { err := writer.Close(ctx) - if _, _err_ := failpoint.Eval(_curpkg_("FailToCloseMetaFile")); _err_ == nil { + failpoint.Inject("FailToCloseMetaFile", func(_ failpoint.Value) { err = errors.New("injected error: fail to close meta file") - } + }) if err == nil { return nil } @@ -507,9 +507,9 @@ func buildInterceptFileWriter(pCtx *tcontext.Context, s storage.ExternalStorage, } pCtx.L().Debug("tear down lazy file writer...", zap.String("path", fullPath)) err := writer.Close(ctx) - if _, _err_ := failpoint.Eval(_curpkg_("FailToCloseDataFile")); _err_ == nil { + failpoint.Inject("FailToCloseDataFile", func(_ failpoint.Value) { err = errors.New("injected error: fail to close data file") - } + }) if err != nil { pCtx.L().Warn("fail to close file", zap.String("path", fullPath), diff --git a/dumpling/export/writer_util.go__failpoint_stash__ b/dumpling/export/writer_util.go__failpoint_stash__ deleted file mode 100644 index e7ed2de2e611a..0000000000000 --- a/dumpling/export/writer_util.go__failpoint_stash__ +++ /dev/null @@ -1,674 +0,0 @@ -// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0. 
- -package export - -import ( - "bytes" - "context" - "fmt" - "io" - "strings" - "sync" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/br/pkg/summary" - tcontext "github.com/pingcap/tidb/dumpling/context" - "github.com/pingcap/tidb/dumpling/log" - "github.com/prometheus/client_golang/prometheus" - "go.uber.org/zap" -) - -const lengthLimit = 1048576 - -var pool = sync.Pool{New: func() any { - return &bytes.Buffer{} -}} - -type writerPipe struct { - input chan *bytes.Buffer - closed chan struct{} - errCh chan error - metrics *metrics - labels prometheus.Labels - - finishedFileSize uint64 - currentFileSize uint64 - currentStatementSize uint64 - - fileSizeLimit uint64 - statementSizeLimit uint64 - - w storage.ExternalFileWriter -} - -func newWriterPipe( - w storage.ExternalFileWriter, - fileSizeLimit, - statementSizeLimit uint64, - metrics *metrics, - labels prometheus.Labels, -) *writerPipe { - return &writerPipe{ - input: make(chan *bytes.Buffer, 8), - closed: make(chan struct{}), - errCh: make(chan error, 1), - w: w, - metrics: metrics, - labels: labels, - - currentFileSize: 0, - currentStatementSize: 0, - fileSizeLimit: fileSizeLimit, - statementSizeLimit: statementSizeLimit, - } -} - -func (b *writerPipe) Run(tctx *tcontext.Context) { - defer close(b.closed) - var errOccurs bool - receiveChunkTime := time.Now() - for { - select { - case s, ok := <-b.input: - if !ok { - return - } - if errOccurs { - continue - } - ObserveHistogram(b.metrics.receiveWriteChunkTimeHistogram, time.Since(receiveChunkTime).Seconds()) - receiveChunkTime = time.Now() - err := writeBytes(tctx, b.w, s.Bytes()) - ObserveHistogram(b.metrics.writeTimeHistogram, time.Since(receiveChunkTime).Seconds()) - AddGauge(b.metrics.finishedSizeGauge, float64(s.Len())) - b.finishedFileSize += uint64(s.Len()) - s.Reset() - pool.Put(s) - if err != nil { - errOccurs = true - b.errCh <- err - } - receiveChunkTime = time.Now() - case <-tctx.Done(): - return - } - } -} - -func (b *writerPipe) AddFileSize(fileSize uint64) { - b.currentFileSize += fileSize - b.currentStatementSize += fileSize -} - -func (b *writerPipe) Error() error { - select { - case err := <-b.errCh: - return err - default: - return nil - } -} - -func (b *writerPipe) ShouldSwitchFile() bool { - return b.fileSizeLimit != UnspecifiedSize && b.currentFileSize >= b.fileSizeLimit -} - -func (b *writerPipe) ShouldSwitchStatement() bool { - return (b.fileSizeLimit != UnspecifiedSize && b.currentFileSize >= b.fileSizeLimit) || - (b.statementSizeLimit != UnspecifiedSize && b.currentStatementSize >= b.statementSizeLimit) -} - -// WriteMeta writes MetaIR to a storage.ExternalFileWriter -func WriteMeta(tctx *tcontext.Context, meta MetaIR, w storage.ExternalFileWriter) error { - tctx.L().Debug("start dumping meta data", zap.String("target", meta.TargetName())) - - specCmtIter := meta.SpecialComments() - for specCmtIter.HasNext() { - if err := write(tctx, w, fmt.Sprintf("%s\n", specCmtIter.Next())); err != nil { - return err - } - } - - if err := write(tctx, w, meta.MetaSQL()); err != nil { - return err - } - - tctx.L().Debug("finish dumping meta data", zap.String("target", meta.TargetName())) - return nil -} - -// WriteInsert writes TableDataIR to a storage.ExternalFileWriter in sql type -func WriteInsert( - pCtx *tcontext.Context, - cfg *Config, - meta TableMeta, - tblIR TableDataIR, - w storage.ExternalFileWriter, - metrics *metrics, -) (n uint64, err error) { - fileRowIter := 
tblIR.Rows() - if !fileRowIter.HasNext() { - return 0, fileRowIter.Error() - } - - bf := pool.Get().(*bytes.Buffer) - if bfCap := bf.Cap(); bfCap < lengthLimit { - bf.Grow(lengthLimit - bfCap) - } - - wp := newWriterPipe(w, cfg.FileSize, cfg.StatementSize, metrics, cfg.Labels) - - // use context.Background here to make sure writerPipe can deplete all the chunks in pipeline - ctx, cancel := tcontext.Background().WithLogger(pCtx.L()).WithCancel() - var wg sync.WaitGroup - wg.Add(1) - go func() { - wp.Run(ctx) - wg.Done() - }() - defer func() { - cancel() - wg.Wait() - }() - - specCmtIter := meta.SpecialComments() - for specCmtIter.HasNext() { - bf.WriteString(specCmtIter.Next()) - bf.WriteByte('\n') - } - wp.currentFileSize += uint64(bf.Len()) - - var ( - insertStatementPrefix string - row = MakeRowReceiver(meta.ColumnTypes()) - counter uint64 - lastCounter uint64 - escapeBackslash = cfg.EscapeBackslash - ) - - defer func() { - if err != nil { - pCtx.L().Warn("fail to dumping table(chunk), will revert some metrics and start a retry if possible", - zap.String("database", meta.DatabaseName()), - zap.String("table", meta.TableName()), - zap.Uint64("finished rows", lastCounter), - zap.Uint64("finished size", wp.finishedFileSize), - log.ShortError(err)) - SubGauge(metrics.finishedRowsGauge, float64(lastCounter)) - SubGauge(metrics.finishedSizeGauge, float64(wp.finishedFileSize)) - } else { - pCtx.L().Debug("finish dumping table(chunk)", - zap.String("database", meta.DatabaseName()), - zap.String("table", meta.TableName()), - zap.Uint64("finished rows", counter), - zap.Uint64("finished size", wp.finishedFileSize)) - summary.CollectSuccessUnit(summary.TotalBytes, 1, wp.finishedFileSize) - summary.CollectSuccessUnit("total rows", 1, counter) - } - }() - - selectedField := meta.SelectedField() - - // if has generated column - if selectedField != "" && selectedField != "*" { - insertStatementPrefix = fmt.Sprintf("INSERT INTO %s (%s) VALUES\n", - wrapBackTicks(escapeString(meta.TableName())), selectedField) - } else { - insertStatementPrefix = fmt.Sprintf("INSERT INTO %s VALUES\n", - wrapBackTicks(escapeString(meta.TableName()))) - } - insertStatementPrefixLen := uint64(len(insertStatementPrefix)) - - for fileRowIter.HasNext() { - wp.currentStatementSize = 0 - bf.WriteString(insertStatementPrefix) - wp.AddFileSize(insertStatementPrefixLen) - - for fileRowIter.HasNext() { - lastBfSize := bf.Len() - if selectedField != "" { - if err = fileRowIter.Decode(row); err != nil { - return counter, errors.Trace(err) - } - row.WriteToBuffer(bf, escapeBackslash) - } else { - bf.WriteString("()") - } - counter++ - wp.AddFileSize(uint64(bf.Len()-lastBfSize) + 2) // 2 is for ",\n" and ";\n" - failpoint.Inject("ChaosBrokenWriterConn", func(_ failpoint.Value) { - failpoint.Return(0, errors.New("connection is closed")) - }) - failpoint.Inject("AtEveryRow", nil) - - fileRowIter.Next() - shouldSwitch := wp.ShouldSwitchStatement() - if fileRowIter.HasNext() && !shouldSwitch { - bf.WriteString(",\n") - } else { - bf.WriteString(";\n") - } - if bf.Len() >= lengthLimit { - select { - case <-pCtx.Done(): - return counter, pCtx.Err() - case err = <-wp.errCh: - return counter, err - case wp.input <- bf: - bf = pool.Get().(*bytes.Buffer) - if bfCap := bf.Cap(); bfCap < lengthLimit { - bf.Grow(lengthLimit - bfCap) - } - AddGauge(metrics.finishedRowsGauge, float64(counter-lastCounter)) - lastCounter = counter - } - } - - if shouldSwitch { - break - } - } - if wp.ShouldSwitchFile() { - break - } - } - if bf.Len() > 0 { - wp.input <- 
bf - } - close(wp.input) - <-wp.closed - AddGauge(metrics.finishedRowsGauge, float64(counter-lastCounter)) - lastCounter = counter - if err = fileRowIter.Error(); err != nil { - return counter, errors.Trace(err) - } - return counter, wp.Error() -} - -// WriteInsertInCsv writes TableDataIR to a storage.ExternalFileWriter in csv type -func WriteInsertInCsv( - pCtx *tcontext.Context, - cfg *Config, - meta TableMeta, - tblIR TableDataIR, - w storage.ExternalFileWriter, - metrics *metrics, -) (n uint64, err error) { - fileRowIter := tblIR.Rows() - if !fileRowIter.HasNext() { - return 0, fileRowIter.Error() - } - - bf := pool.Get().(*bytes.Buffer) - if bfCap := bf.Cap(); bfCap < lengthLimit { - bf.Grow(lengthLimit - bfCap) - } - - wp := newWriterPipe(w, cfg.FileSize, UnspecifiedSize, metrics, cfg.Labels) - opt := &csvOption{ - nullValue: cfg.CsvNullValue, - separator: []byte(cfg.CsvSeparator), - delimiter: []byte(cfg.CsvDelimiter), - lineTerminator: []byte(cfg.CsvLineTerminator), - binaryFormat: DialectBinaryFormatMap[cfg.CsvOutputDialect], - } - - // use context.Background here to make sure writerPipe can deplete all the chunks in pipeline - ctx, cancel := tcontext.Background().WithLogger(pCtx.L()).WithCancel() - var wg sync.WaitGroup - wg.Add(1) - go func() { - wp.Run(ctx) - wg.Done() - }() - defer func() { - cancel() - wg.Wait() - }() - - var ( - row = MakeRowReceiver(meta.ColumnTypes()) - counter uint64 - lastCounter uint64 - escapeBackslash = cfg.EscapeBackslash - selectedFields = meta.SelectedField() - ) - - defer func() { - if err != nil { - pCtx.L().Warn("fail to dumping table(chunk), will revert some metrics and start a retry if possible", - zap.String("database", meta.DatabaseName()), - zap.String("table", meta.TableName()), - zap.Uint64("finished rows", lastCounter), - zap.Uint64("finished size", wp.finishedFileSize), - log.ShortError(err)) - SubGauge(metrics.finishedRowsGauge, float64(lastCounter)) - SubGauge(metrics.finishedSizeGauge, float64(wp.finishedFileSize)) - } else { - pCtx.L().Debug("finish dumping table(chunk)", - zap.String("database", meta.DatabaseName()), - zap.String("table", meta.TableName()), - zap.Uint64("finished rows", counter), - zap.Uint64("finished size", wp.finishedFileSize)) - summary.CollectSuccessUnit(summary.TotalBytes, 1, wp.finishedFileSize) - summary.CollectSuccessUnit("total rows", 1, counter) - } - }() - - if !cfg.NoHeader && len(meta.ColumnNames()) != 0 && selectedFields != "" { - for i, col := range meta.ColumnNames() { - bf.Write(opt.delimiter) - escapeCSV([]byte(col), bf, escapeBackslash, opt) - bf.Write(opt.delimiter) - if i != len(meta.ColumnTypes())-1 { - bf.Write(opt.separator) - } - } - bf.Write(opt.lineTerminator) - } - wp.currentFileSize += uint64(bf.Len()) - - for fileRowIter.HasNext() { - lastBfSize := bf.Len() - if selectedFields != "" { - if err = fileRowIter.Decode(row); err != nil { - return counter, errors.Trace(err) - } - row.WriteToBufferInCsv(bf, escapeBackslash, opt) - } - counter++ - wp.currentFileSize += uint64(bf.Len()-lastBfSize) + 1 // 1 is for "\n" - - bf.Write(opt.lineTerminator) - if bf.Len() >= lengthLimit { - select { - case <-pCtx.Done(): - return counter, pCtx.Err() - case err = <-wp.errCh: - return counter, err - case wp.input <- bf: - bf = pool.Get().(*bytes.Buffer) - if bfCap := bf.Cap(); bfCap < lengthLimit { - bf.Grow(lengthLimit - bfCap) - } - AddGauge(metrics.finishedRowsGauge, float64(counter-lastCounter)) - lastCounter = counter - } - } - - fileRowIter.Next() - if wp.ShouldSwitchFile() { - break - } - } - - 
if bf.Len() > 0 { - wp.input <- bf - } - close(wp.input) - <-wp.closed - AddGauge(metrics.finishedRowsGauge, float64(counter-lastCounter)) - lastCounter = counter - if err = fileRowIter.Error(); err != nil { - return counter, errors.Trace(err) - } - return counter, wp.Error() -} - -func write(tctx *tcontext.Context, writer storage.ExternalFileWriter, str string) error { - _, err := writer.Write(tctx, []byte(str)) - if err != nil { - // str might be very long, only output the first 200 chars - outputLength := len(str) - if outputLength >= 200 { - outputLength = 200 - } - tctx.L().Warn("fail to write", - zap.String("heading 200 characters", str[:outputLength]), - zap.Error(err)) - } - return errors.Trace(err) -} - -func writeBytes(tctx *tcontext.Context, writer storage.ExternalFileWriter, p []byte) error { - _, err := writer.Write(tctx, p) - if err != nil { - // str might be very long, only output the first 200 chars - outputLength := len(p) - if outputLength >= 200 { - outputLength = 200 - } - tctx.L().Warn("fail to write", - zap.ByteString("heading 200 characters", p[:outputLength]), - zap.Error(err)) - if strings.Contains(err.Error(), "Part number must be an integer between 1 and 10000") { - err = errors.Annotate(err, "workaround: dump file exceeding 50GB, please specify -F=256MB -r=200000 to avoid this problem") - } - } - return errors.Trace(err) -} - -func buildFileWriter(tctx *tcontext.Context, s storage.ExternalStorage, fileName string, compressType storage.CompressType) (storage.ExternalFileWriter, func(ctx context.Context) error, error) { - fileName += compressFileSuffix(compressType) - fullPath := s.URI() + "/" + fileName - writer, err := storage.WithCompression(s, compressType, storage.DecompressConfig{}).Create(tctx, fileName, nil) - if err != nil { - tctx.L().Warn("fail to open file", - zap.String("path", fullPath), - zap.Error(err)) - return nil, nil, errors.Trace(err) - } - tctx.L().Debug("opened file", zap.String("path", fullPath)) - tearDownRoutine := func(ctx context.Context) error { - err := writer.Close(ctx) - failpoint.Inject("FailToCloseMetaFile", func(_ failpoint.Value) { - err = errors.New("injected error: fail to close meta file") - }) - if err == nil { - return nil - } - err = errors.Trace(err) - tctx.L().Warn("fail to close file", - zap.String("path", fullPath), - zap.Error(err)) - return err - } - return writer, tearDownRoutine, nil -} - -func buildInterceptFileWriter(pCtx *tcontext.Context, s storage.ExternalStorage, fileName string, compressType storage.CompressType) (storage.ExternalFileWriter, func(context.Context) error) { - fileName += compressFileSuffix(compressType) - var writer storage.ExternalFileWriter - fullPath := s.URI() + "/" + fileName - fileWriter := &InterceptFileWriter{} - initRoutine := func() error { - // use separated context pCtx here to make sure context used in ExternalFile won't be canceled before close, - // which will cause a context canceled error when closing gcs's Writer - w, err := storage.WithCompression(s, compressType, storage.DecompressConfig{}).Create(pCtx, fileName, nil) - if err != nil { - pCtx.L().Warn("fail to open file", - zap.String("path", fullPath), - zap.Error(err)) - return newWriterError(err) - } - writer = w - pCtx.L().Debug("opened file", zap.String("path", fullPath)) - fileWriter.ExternalFileWriter = writer - return nil - } - fileWriter.initRoutine = initRoutine - - tearDownRoutine := func(ctx context.Context) error { - if writer == nil { - return nil - } - pCtx.L().Debug("tear down lazy file writer...", 
zap.String("path", fullPath)) - err := writer.Close(ctx) - failpoint.Inject("FailToCloseDataFile", func(_ failpoint.Value) { - err = errors.New("injected error: fail to close data file") - }) - if err != nil { - pCtx.L().Warn("fail to close file", - zap.String("path", fullPath), - zap.Error(err)) - } - return err - } - return fileWriter, tearDownRoutine -} - -// LazyStringWriter is an interceptor of io.StringWriter, -// will lazily create file the first time StringWriter need to write something. -type LazyStringWriter struct { - initRoutine func() error - sync.Once - io.StringWriter - err error -} - -// WriteString implements io.StringWriter. It check whether writer has written something and init a file at first time -func (l *LazyStringWriter) WriteString(str string) (int, error) { - l.Do(func() { l.err = l.initRoutine() }) - if l.err != nil { - return 0, errors.Errorf("open file error: %s", l.err.Error()) - } - return l.StringWriter.WriteString(str) -} - -type writerError struct { - error -} - -func (e *writerError) Error() string { - return e.error.Error() -} - -func newWriterError(err error) error { - if err == nil { - return nil - } - return &writerError{error: err} -} - -// InterceptFileWriter is an interceptor of os.File, -// tracking whether a StringWriter has written something. -type InterceptFileWriter struct { - storage.ExternalFileWriter - sync.Once - SomethingIsWritten bool - - initRoutine func() error - err error -} - -// Write implements storage.ExternalFileWriter.Write. It check whether writer has written something and init a file at first time -func (w *InterceptFileWriter) Write(ctx context.Context, p []byte) (int, error) { - w.Do(func() { w.err = w.initRoutine() }) - if len(p) > 0 { - w.SomethingIsWritten = true - } - if w.err != nil { - return 0, errors.Annotate(w.err, "open file error") - } - n, err := w.ExternalFileWriter.Write(ctx, p) - return n, newWriterError(err) -} - -// Close closes the InterceptFileWriter -func (w *InterceptFileWriter) Close(ctx context.Context) error { - return w.ExternalFileWriter.Close(ctx) -} - -func wrapBackTicks(identifier string) string { - if !strings.HasPrefix(identifier, "`") && !strings.HasSuffix(identifier, "`") { - return wrapStringWith(identifier, "`") - } - return identifier -} - -func wrapStringWith(str string, wrapper string) string { - return fmt.Sprintf("%s%s%s", wrapper, str, wrapper) -} - -func compressFileSuffix(compressType storage.CompressType) string { - switch compressType { - case storage.NoCompression: - return "" - case storage.Gzip: - return ".gz" - case storage.Snappy: - return ".snappy" - case storage.Zstd: - return ".zst" - default: - return "" - } -} - -// FileFormat is the format that output to file. Currently we support SQL text and CSV file format. -type FileFormat int32 - -const ( - // FileFormatUnknown indicates the given file type is unknown - FileFormatUnknown FileFormat = iota - // FileFormatSQLText indicates the given file type is sql type - FileFormatSQLText - // FileFormatCSV indicates the given file type is csv type - FileFormatCSV -) - -const ( - // FileFormatSQLTextString indicates the string/suffix of sql type file - FileFormatSQLTextString = "sql" - // FileFormatCSVString indicates the string/suffix of csv type file - FileFormatCSVString = "csv" -) - -// String implement Stringer.String method. 
-func (f FileFormat) String() string { - switch f { - case FileFormatSQLText: - return strings.ToUpper(FileFormatSQLTextString) - case FileFormatCSV: - return strings.ToUpper(FileFormatCSVString) - default: - return "unknown" - } -} - -// Extension returns the extension for specific format. -// -// text -> "sql" -// csv -> "csv" -func (f FileFormat) Extension() string { - switch f { - case FileFormatSQLText: - return FileFormatSQLTextString - case FileFormatCSV: - return FileFormatCSVString - default: - return "unknown_format" - } -} - -// WriteInsert writes TableDataIR to a storage.ExternalFileWriter in sql/csv type -func (f FileFormat) WriteInsert( - pCtx *tcontext.Context, - cfg *Config, - meta TableMeta, - tblIR TableDataIR, - w storage.ExternalFileWriter, - metrics *metrics, -) (uint64, error) { - switch f { - case FileFormatSQLText: - return WriteInsert(pCtx, cfg, meta, tblIR, w, metrics) - case FileFormatCSV: - return WriteInsertInCsv(pCtx, cfg, meta, tblIR, w, metrics) - default: - return 0, errors.Errorf("unknown file format") - } -} diff --git a/lightning/pkg/importer/binding__failpoint_binding__.go b/lightning/pkg/importer/binding__failpoint_binding__.go deleted file mode 100644 index 62be625525776..0000000000000 --- a/lightning/pkg/importer/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package importer - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/lightning/pkg/importer/chunk_process.go b/lightning/pkg/importer/chunk_process.go index fff2ef14cefd1..802e7c15d4b6c 100644 --- a/lightning/pkg/importer/chunk_process.go +++ b/lightning/pkg/importer/chunk_process.go @@ -480,9 +480,9 @@ func (cr *chunkProcessor) encodeLoop( kvPacket = append(kvPacket, deliveredKVs{kvs: kvs, columns: filteredColumns, offset: newOffset, rowID: rowID, realOffset: newScannedOffset}) kvSize += kvs.Size() - if val, _err_ := failpoint.Eval(_curpkg_("mock-kv-size")); _err_ == nil { + failpoint.Inject("mock-kv-size", func(val failpoint.Value) { kvSize += uint64(val.(int)) - } + }) // pebble cannot allow > 4.0G kv in one batch. // we will meet pebble panic when import sql file and each kv has the size larger than 4G / maxKvPairsCnt. // so add this check. @@ -735,25 +735,25 @@ func (cr *chunkProcessor) deliverLoop( // No need to save checkpoint if nothing was delivered. dataSynced = cr.maybeSaveCheckpoint(rc, t, engineID, cr.chunk, dataEngine, indexEngine) } - if _, _err_ := failpoint.Eval(_curpkg_("SlowDownWriteRows")); _err_ == nil { + failpoint.Inject("SlowDownWriteRows", func() { deliverLogger.Warn("Slowed down write rows") finished := rc.status.FinishedFileSize.Load() total := rc.status.TotalFileSize.Load() deliverLogger.Warn("PrintStatus Failpoint", zap.Int64("finished", finished), zap.Int64("total", total)) - } - failpoint.Eval(_curpkg_("FailAfterWriteRows")) + }) + failpoint.Inject("FailAfterWriteRows", nil) // TODO: for local backend, we may save checkpoint more frequently, e.g. after written // 10GB kv pairs to data engine, we can do a flush for both data & index engine, then we // can safely update current checkpoint. 
- if _, _err_ := failpoint.Eval(_curpkg_("LocalBackendSaveCheckpoint")); _err_ == nil { + failpoint.Inject("LocalBackendSaveCheckpoint", func() { if !isLocalBackend(rc.cfg) && (dataChecksum.SumKVS() != 0 || indexChecksum.SumKVS() != 0) { // No need to save checkpoint if nothing was delivered. saveCheckpoint(rc, t, engineID, cr.chunk) } - } + }) } return diff --git a/lightning/pkg/importer/chunk_process.go__failpoint_stash__ b/lightning/pkg/importer/chunk_process.go__failpoint_stash__ deleted file mode 100644 index 802e7c15d4b6c..0000000000000 --- a/lightning/pkg/importer/chunk_process.go__failpoint_stash__ +++ /dev/null @@ -1,778 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package importer - -import ( - "bytes" - "context" - "io" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/pkg/keyspace" - "github.com/pingcap/tidb/pkg/lightning/backend" - "github.com/pingcap/tidb/pkg/lightning/backend/encode" - "github.com/pingcap/tidb/pkg/lightning/backend/kv" - "github.com/pingcap/tidb/pkg/lightning/backend/tidb" - "github.com/pingcap/tidb/pkg/lightning/checkpoints" - "github.com/pingcap/tidb/pkg/lightning/common" - "github.com/pingcap/tidb/pkg/lightning/config" - "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/lightning/metric" - "github.com/pingcap/tidb/pkg/lightning/mydump" - verify "github.com/pingcap/tidb/pkg/lightning/verification" - "github.com/pingcap/tidb/pkg/lightning/worker" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/store/driver/txn" - "github.com/pingcap/tidb/pkg/table/tables" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/extsort" - "go.uber.org/zap" -) - -// chunkProcessor process data chunk -// for local backend it encodes and writes KV to local disk -// for tidb backend it transforms data into sql and executes them. 
-type chunkProcessor struct { - parser mydump.Parser - index int - chunk *checkpoints.ChunkCheckpoint -} - -func newChunkProcessor( - ctx context.Context, - index int, - cfg *config.Config, - chunk *checkpoints.ChunkCheckpoint, - ioWorkers *worker.Pool, - store storage.ExternalStorage, - tableInfo *model.TableInfo, -) (*chunkProcessor, error) { - parser, err := openParser(ctx, cfg, chunk, ioWorkers, store, tableInfo) - if err != nil { - return nil, err - } - return &chunkProcessor{ - parser: parser, - index: index, - chunk: chunk, - }, nil -} - -func openParser( - ctx context.Context, - cfg *config.Config, - chunk *checkpoints.ChunkCheckpoint, - ioWorkers *worker.Pool, - store storage.ExternalStorage, - tblInfo *model.TableInfo, -) (mydump.Parser, error) { - blockBufSize := int64(cfg.Mydumper.ReadBlockSize) - reader, err := mydump.OpenReader(ctx, &chunk.FileMeta, store, storage.DecompressConfig{ - ZStdDecodeConcurrency: 1, - }) - if err != nil { - return nil, err - } - - var parser mydump.Parser - switch chunk.FileMeta.Type { - case mydump.SourceTypeCSV: - hasHeader := cfg.Mydumper.CSV.Header && chunk.Chunk.Offset == 0 - // Create a utf8mb4 convertor to encode and decode data with the charset of CSV files. - charsetConvertor, err := mydump.NewCharsetConvertor(cfg.Mydumper.DataCharacterSet, cfg.Mydumper.DataInvalidCharReplace) - if err != nil { - return nil, err - } - parser, err = mydump.NewCSVParser(ctx, &cfg.Mydumper.CSV, reader, blockBufSize, ioWorkers, hasHeader, charsetConvertor) - if err != nil { - return nil, err - } - case mydump.SourceTypeSQL: - parser = mydump.NewChunkParser(ctx, cfg.TiDB.SQLMode, reader, blockBufSize, ioWorkers) - case mydump.SourceTypeParquet: - parser, err = mydump.NewParquetParser(ctx, store, reader, chunk.FileMeta.Path) - if err != nil { - return nil, err - } - default: - return nil, errors.Errorf("file '%s' with unknown source type '%s'", chunk.Key.Path, chunk.FileMeta.Type.String()) - } - - if chunk.FileMeta.Compression == mydump.CompressionNone { - if err = parser.SetPos(chunk.Chunk.Offset, chunk.Chunk.PrevRowIDMax); err != nil { - _ = parser.Close() - return nil, err - } - } else { - if err = mydump.ReadUntil(parser, chunk.Chunk.Offset); err != nil { - _ = parser.Close() - return nil, err - } - parser.SetRowID(chunk.Chunk.PrevRowIDMax) - } - if len(chunk.ColumnPermutation) > 0 { - parser.SetColumns(getColumnNames(tblInfo, chunk.ColumnPermutation)) - } - - return parser, nil -} - -func getColumnNames(tableInfo *model.TableInfo, permutation []int) []string { - colIndexes := make([]int, 0, len(permutation)) - for i := 0; i < len(permutation); i++ { - colIndexes = append(colIndexes, -1) - } - colCnt := 0 - for i, p := range permutation { - if p >= 0 { - colIndexes[p] = i - colCnt++ - } - } - - names := make([]string, 0, colCnt) - for _, idx := range colIndexes { - // skip columns with index -1 - if idx >= 0 { - // original fields contains _tidb_rowid field - if idx == len(tableInfo.Columns) { - names = append(names, model.ExtraHandleName.O) - } else { - names = append(names, tableInfo.Columns[idx].Name.O) - } - } - } - return names -} - -func (cr *chunkProcessor) process( - ctx context.Context, - t *TableImporter, - engineID int32, - dataEngine, indexEngine backend.EngineWriter, - rc *Controller, -) error { - logger := t.logger.With( - zap.Int32("engineNumber", engineID), - zap.Int("fileIndex", cr.index), - zap.Stringer("path", &cr.chunk.Key), - ) - // Create the encoder. 
- kvEncoder, err := rc.encBuilder.NewEncoder(ctx, &encode.EncodingConfig{ - SessionOptions: encode.SessionOptions{ - SQLMode: rc.cfg.TiDB.SQLMode, - Timestamp: cr.chunk.Timestamp, - SysVars: rc.sysVars, - // use chunk.PrevRowIDMax as the auto random seed, so it can stay the same value after recover from checkpoint. - AutoRandomSeed: cr.chunk.Chunk.PrevRowIDMax, - }, - Path: cr.chunk.Key.Path, - Table: t.encTable, - Logger: logger, - }) - if err != nil { - return err - } - defer kvEncoder.Close() - - kvsCh := make(chan []deliveredKVs, maxKVQueueSize) - deliverCompleteCh := make(chan deliverResult) - - go func() { - defer close(deliverCompleteCh) - dur, err := cr.deliverLoop(ctx, kvsCh, t, engineID, dataEngine, indexEngine, rc) - select { - case <-ctx.Done(): - case deliverCompleteCh <- deliverResult{dur, err}: - } - }() - - logTask := logger.Begin(zap.InfoLevel, "restore file") - - readTotalDur, encodeTotalDur, encodeErr := cr.encodeLoop( - ctx, - kvsCh, - t, - logger, - kvEncoder, - deliverCompleteCh, - rc, - ) - var deliverErr error - select { - case deliverResult, ok := <-deliverCompleteCh: - if ok { - logTask.End(zap.ErrorLevel, deliverResult.err, - zap.Duration("readDur", readTotalDur), - zap.Duration("encodeDur", encodeTotalDur), - zap.Duration("deliverDur", deliverResult.totalDur), - zap.Object("checksum", &cr.chunk.Checksum), - ) - deliverErr = deliverResult.err - } else { - // else, this must cause by ctx cancel - deliverErr = ctx.Err() - } - case <-ctx.Done(): - deliverErr = ctx.Err() - } - return errors.Trace(firstErr(encodeErr, deliverErr)) -} - -//nolint:nakedret // TODO: refactor -func (cr *chunkProcessor) encodeLoop( - ctx context.Context, - kvsCh chan<- []deliveredKVs, - t *TableImporter, - logger log.Logger, - kvEncoder encode.Encoder, - deliverCompleteCh <-chan deliverResult, - rc *Controller, -) (readTotalDur time.Duration, encodeTotalDur time.Duration, err error) { - defer close(kvsCh) - - // when AddIndexBySQL, we use all PK and UK to run pre-deduplication, and then we - // strip almost all secondary index to run encodeLoop. In encodeLoop when we meet - // a duplicated row marked by pre-deduplication, we need original table structure - // to generate the duplicate error message, so here create a new encoder with - // original table structure. - originalTableEncoder := kvEncoder - if rc.cfg.TikvImporter.AddIndexBySQL { - encTable, err := tables.TableFromMeta(t.alloc, t.tableInfo.Desired) - if err != nil { - return 0, 0, errors.Trace(err) - } - - originalTableEncoder, err = rc.encBuilder.NewEncoder(ctx, &encode.EncodingConfig{ - SessionOptions: encode.SessionOptions{ - SQLMode: rc.cfg.TiDB.SQLMode, - Timestamp: cr.chunk.Timestamp, - SysVars: rc.sysVars, - // use chunk.PrevRowIDMax as the auto random seed, so it can stay the same value after recover from checkpoint. 
- AutoRandomSeed: cr.chunk.Chunk.PrevRowIDMax, - }, - Path: cr.chunk.Key.Path, - Table: encTable, - Logger: logger, - }) - if err != nil { - return 0, 0, errors.Trace(err) - } - defer originalTableEncoder.Close() - } - - send := func(kvs []deliveredKVs) error { - select { - case kvsCh <- kvs: - return nil - case <-ctx.Done(): - return ctx.Err() - case deliverResult, ok := <-deliverCompleteCh: - if deliverResult.err == nil && !ok { - deliverResult.err = ctx.Err() - } - if deliverResult.err == nil { - deliverResult.err = errors.New("unexpected premature fulfillment") - logger.DPanic("unexpected: deliverCompleteCh prematurely fulfilled with no error", zap.Bool("chIsOpen", ok)) - } - return errors.Trace(deliverResult.err) - } - } - - pauser, maxKvPairsCnt := rc.pauser, rc.cfg.TikvImporter.MaxKVPairs - initializedColumns, reachEOF := false, false - // filteredColumns is column names that excluded ignored columns - // WARN: this might be not correct when different SQL statements contains different fields, - // but since ColumnPermutation also depends on the hypothesis that the columns in one source file is the same - // so this should be ok. - var ( - filteredColumns []string - extendVals []types.Datum - ) - ignoreColumns, err1 := rc.cfg.Mydumper.IgnoreColumns.GetIgnoreColumns(t.dbInfo.Name, t.tableInfo.Core.Name.O, rc.cfg.Mydumper.CaseSensitive) - if err1 != nil { - err = err1 - return - } - - var dupIgnoreRowsIter extsort.Iterator - if t.dupIgnoreRows != nil { - dupIgnoreRowsIter, err = t.dupIgnoreRows.NewIterator(ctx) - if err != nil { - return 0, 0, err - } - defer func() { - _ = dupIgnoreRowsIter.Close() - }() - } - - for !reachEOF { - if err = pauser.Wait(ctx); err != nil { - return - } - offset, _ := cr.parser.Pos() - if offset >= cr.chunk.Chunk.EndOffset { - break - } - - var readDur, encodeDur time.Duration - canDeliver := false - kvPacket := make([]deliveredKVs, 0, maxKvPairsCnt) - curOffset := offset - var newOffset, rowID, newScannedOffset int64 - var scannedOffset int64 = -1 - var kvSize uint64 - var scannedOffsetErr error - outLoop: - for !canDeliver { - readDurStart := time.Now() - err = cr.parser.ReadRow() - columnNames := cr.parser.Columns() - newOffset, rowID = cr.parser.Pos() - if cr.chunk.FileMeta.Compression != mydump.CompressionNone || cr.chunk.FileMeta.Type == mydump.SourceTypeParquet { - newScannedOffset, scannedOffsetErr = cr.parser.ScannedPos() - if scannedOffsetErr != nil { - logger.Warn("fail to get data engine ScannedPos, progress may not be accurate", - log.ShortError(scannedOffsetErr), zap.String("file", cr.chunk.FileMeta.Path)) - } - if scannedOffset == -1 { - scannedOffset = newScannedOffset - } - } - - switch errors.Cause(err) { - case nil: - if !initializedColumns { - if len(cr.chunk.ColumnPermutation) == 0 { - if err = t.initializeColumns(columnNames, cr.chunk); err != nil { - return - } - } - filteredColumns = columnNames - ignoreColsMap := ignoreColumns.ColumnsMap() - if len(ignoreColsMap) > 0 || len(cr.chunk.FileMeta.ExtendData.Columns) > 0 { - filteredColumns, extendVals = filterColumns(columnNames, cr.chunk.FileMeta.ExtendData, ignoreColsMap, t.tableInfo.Core) - } - lastRow := cr.parser.LastRow() - lastRowLen := len(lastRow.Row) - extendColsMap := make(map[string]int) - for i, c := range cr.chunk.FileMeta.ExtendData.Columns { - extendColsMap[c] = lastRowLen + i - } - for i, col := range t.tableInfo.Core.Columns { - if p, ok := extendColsMap[col.Name.O]; ok { - cr.chunk.ColumnPermutation[i] = p - } - } - initializedColumns = true - - if dupIgnoreRowsIter 
!= nil { - dupIgnoreRowsIter.Seek(common.EncodeIntRowID(lastRow.RowID)) - } - } - case io.EOF: - reachEOF = true - break outLoop - default: - err = common.ErrEncodeKV.Wrap(err).GenWithStackByArgs(&cr.chunk.Key, newOffset) - return - } - readDur += time.Since(readDurStart) - encodeDurStart := time.Now() - lastRow := cr.parser.LastRow() - lastRow.Row = append(lastRow.Row, extendVals...) - - // Skip duplicated rows. - if dupIgnoreRowsIter != nil { - rowIDKey := common.EncodeIntRowID(lastRow.RowID) - isDupIgnored := false - dupDetectLoop: - for dupIgnoreRowsIter.Valid() { - switch bytes.Compare(rowIDKey, dupIgnoreRowsIter.UnsafeKey()) { - case 0: - isDupIgnored = true - break dupDetectLoop - case 1: - dupIgnoreRowsIter.Next() - case -1: - break dupDetectLoop - } - } - if dupIgnoreRowsIter.Error() != nil { - err = dupIgnoreRowsIter.Error() - return - } - if isDupIgnored { - cr.parser.RecycleRow(lastRow) - lastOffset := curOffset - curOffset = newOffset - - if rc.errorMgr.ConflictRecordsRemain() <= 0 { - continue - } - - dupMsg := cr.getDuplicateMessage( - originalTableEncoder, - lastRow, - lastOffset, - dupIgnoreRowsIter.UnsafeValue(), - t.tableInfo.Desired, - logger, - ) - rowText := tidb.EncodeRowForRecord(ctx, t.encTable, rc.cfg.TiDB.SQLMode, lastRow.Row, cr.chunk.ColumnPermutation) - err = rc.errorMgr.RecordDuplicate( - ctx, - logger, - t.tableName, - cr.chunk.Key.Path, - newOffset, - dupMsg, - lastRow.RowID, - rowText, - ) - if err != nil { - return 0, 0, err - } - continue - } - } - - // sql -> kv - kvs, encodeErr := kvEncoder.Encode(lastRow.Row, lastRow.RowID, cr.chunk.ColumnPermutation, curOffset) - encodeDur += time.Since(encodeDurStart) - - hasIgnoredEncodeErr := false - if encodeErr != nil { - rowText := tidb.EncodeRowForRecord(ctx, t.encTable, rc.cfg.TiDB.SQLMode, lastRow.Row, cr.chunk.ColumnPermutation) - encodeErr = rc.errorMgr.RecordTypeError(ctx, logger, t.tableName, cr.chunk.Key.Path, newOffset, rowText, encodeErr) - if encodeErr != nil { - err = common.ErrEncodeKV.Wrap(encodeErr).GenWithStackByArgs(&cr.chunk.Key, newOffset) - } - hasIgnoredEncodeErr = true - } - cr.parser.RecycleRow(lastRow) - curOffset = newOffset - - if err != nil { - return - } - if hasIgnoredEncodeErr { - continue - } - - kvPacket = append(kvPacket, deliveredKVs{kvs: kvs, columns: filteredColumns, offset: newOffset, - rowID: rowID, realOffset: newScannedOffset}) - kvSize += kvs.Size() - failpoint.Inject("mock-kv-size", func(val failpoint.Value) { - kvSize += uint64(val.(int)) - }) - // pebble cannot allow > 4.0G kv in one batch. - // we will meet pebble panic when import sql file and each kv has the size larger than 4G / maxKvPairsCnt. - // so add this check. 
- if kvSize >= minDeliverBytes || len(kvPacket) >= maxKvPairsCnt || newOffset == cr.chunk.Chunk.EndOffset { - canDeliver = true - kvSize = 0 - } - } - encodeTotalDur += encodeDur - readTotalDur += readDur - if m, ok := metric.FromContext(ctx); ok { - m.RowEncodeSecondsHistogram.Observe(encodeDur.Seconds()) - m.RowReadSecondsHistogram.Observe(readDur.Seconds()) - if cr.chunk.FileMeta.Type == mydump.SourceTypeParquet { - m.RowReadBytesHistogram.Observe(float64(newScannedOffset - scannedOffset)) - } else { - m.RowReadBytesHistogram.Observe(float64(newOffset - offset)) - } - } - - if len(kvPacket) != 0 { - deliverKvStart := time.Now() - if err = send(kvPacket); err != nil { - return - } - if m, ok := metric.FromContext(ctx); ok { - m.RowKVDeliverSecondsHistogram.Observe(time.Since(deliverKvStart).Seconds()) - } - } - } - - err = send([]deliveredKVs{{offset: cr.chunk.Chunk.EndOffset, realOffset: cr.chunk.FileMeta.FileSize}}) - return -} - -// getDuplicateMessage gets the duplicate message like a SQL error. When it meets -// internal error, the error message will be returned instead of the duplicate message. -// If the index is not found (which is not expected), an empty string will be returned. -func (cr *chunkProcessor) getDuplicateMessage( - kvEncoder encode.Encoder, - lastRow mydump.Row, - lastOffset int64, - encodedIdxID []byte, - tableInfo *model.TableInfo, - logger log.Logger, -) string { - _, idxID, err := codec.DecodeVarint(encodedIdxID) - if err != nil { - return err.Error() - } - kvs, err := kvEncoder.Encode(lastRow.Row, lastRow.RowID, cr.chunk.ColumnPermutation, lastOffset) - if err != nil { - return err.Error() - } - - if idxID == conflictOnHandle { - for _, kv := range kvs.(*kv.Pairs).Pairs { - if tablecodec.IsRecordKey(kv.Key) { - dupErr := txn.ExtractKeyExistsErrFromHandle(kv.Key, kv.Val, tableInfo) - return dupErr.Error() - } - } - // should not happen - logger.Warn("fail to find conflict record key", - zap.String("file", cr.chunk.FileMeta.Path), - zap.Any("row", lastRow.Row)) - } else { - for _, kv := range kvs.(*kv.Pairs).Pairs { - _, decodedIdxID, isRecordKey, err := tablecodec.DecodeKeyHead(kv.Key) - if err != nil { - return err.Error() - } - if !isRecordKey && decodedIdxID == idxID { - dupErr := txn.ExtractKeyExistsErrFromIndex(kv.Key, kv.Val, tableInfo, idxID) - return dupErr.Error() - } - } - // should not happen - logger.Warn("fail to find conflict index key", - zap.String("file", cr.chunk.FileMeta.Path), - zap.Int64("idxID", idxID), - zap.Any("row", lastRow.Row)) - } - return "" -} - -//nolint:nakedret // TODO: refactor -func (cr *chunkProcessor) deliverLoop( - ctx context.Context, - kvsCh <-chan []deliveredKVs, - t *TableImporter, - engineID int32, - dataEngine, indexEngine backend.EngineWriter, - rc *Controller, -) (deliverTotalDur time.Duration, err error) { - deliverLogger := t.logger.With( - zap.Int32("engineNumber", engineID), - zap.Int("fileIndex", cr.index), - zap.Stringer("path", &cr.chunk.Key), - zap.String("task", "deliver"), - ) - // Fetch enough KV pairs from the source. 
- dataKVs := rc.encBuilder.MakeEmptyRows() - indexKVs := rc.encBuilder.MakeEmptyRows() - - dataSynced := true - hasMoreKVs := true - var startRealOffset, currRealOffset int64 // save to 0 at first - - keyspace := keyspace.CodecV1.GetKeyspace() - if t.kvStore != nil { - keyspace = t.kvStore.GetCodec().GetKeyspace() - } - for hasMoreKVs { - var ( - dataChecksum = verify.NewKVChecksumWithKeyspace(keyspace) - indexChecksum = verify.NewKVChecksumWithKeyspace(keyspace) - ) - var columns []string - var kvPacket []deliveredKVs - // init these two field as checkpoint current value, so even if there are no kv pairs delivered, - // chunk checkpoint should stay the same - startOffset := cr.chunk.Chunk.Offset - currOffset := startOffset - startRealOffset = cr.chunk.Chunk.RealOffset - currRealOffset = startRealOffset - rowID := cr.chunk.Chunk.PrevRowIDMax - - populate: - for dataChecksum.SumSize()+indexChecksum.SumSize() < minDeliverBytes { - select { - case kvPacket = <-kvsCh: - if len(kvPacket) == 0 { - hasMoreKVs = false - break populate - } - for _, p := range kvPacket { - if p.kvs == nil { - // This is the last message. - currOffset = p.offset - currRealOffset = p.realOffset - hasMoreKVs = false - break populate - } - p.kvs.ClassifyAndAppend(&dataKVs, dataChecksum, &indexKVs, indexChecksum) - columns = p.columns - currOffset = p.offset - currRealOffset = p.realOffset - rowID = p.rowID - } - case <-ctx.Done(): - err = ctx.Err() - return - } - } - - err = func() error { - // We use `TryRLock` with sleep here to avoid blocking current goroutine during importing when disk-quota is - // triggered, so that we can save chunkCheckpoint as soon as possible after `FlushEngine` is called. - // This implementation may not be very elegant or even completely correct, but it is currently a relatively - // simple and effective solution. - for !rc.diskQuotaLock.TryRLock() { - // try to update chunk checkpoint, this can help save checkpoint after importing when disk-quota is triggered - if !dataSynced { - dataSynced = cr.maybeSaveCheckpoint(rc, t, engineID, cr.chunk, dataEngine, indexEngine) - } - time.Sleep(time.Millisecond) - } - defer rc.diskQuotaLock.RUnlock() - - // Write KVs into the engine - start := time.Now() - - if err = dataEngine.AppendRows(ctx, columns, dataKVs); err != nil { - if !common.IsContextCanceledError(err) { - deliverLogger.Error("write to data engine failed", log.ShortError(err)) - } - - return errors.Trace(err) - } - if err = indexEngine.AppendRows(ctx, columns, indexKVs); err != nil { - if !common.IsContextCanceledError(err) { - deliverLogger.Error("write to index engine failed", log.ShortError(err)) - } - return errors.Trace(err) - } - - if m, ok := metric.FromContext(ctx); ok { - deliverDur := time.Since(start) - deliverTotalDur += deliverDur - m.BlockDeliverSecondsHistogram.Observe(deliverDur.Seconds()) - m.BlockDeliverBytesHistogram.WithLabelValues(metric.BlockDeliverKindData).Observe(float64(dataChecksum.SumSize())) - m.BlockDeliverBytesHistogram.WithLabelValues(metric.BlockDeliverKindIndex).Observe(float64(indexChecksum.SumSize())) - m.BlockDeliverKVPairsHistogram.WithLabelValues(metric.BlockDeliverKindData).Observe(float64(dataChecksum.SumKVS())) - m.BlockDeliverKVPairsHistogram.WithLabelValues(metric.BlockDeliverKindIndex).Observe(float64(indexChecksum.SumKVS())) - } - return nil - }() - if err != nil { - return - } - dataSynced = false - - dataKVs = dataKVs.Clear() - indexKVs = indexKVs.Clear() - - // Update the table, and save a checkpoint. 
- // (the write to the importer is effective immediately, thus update these here) - // No need to apply a lock since this is the only thread updating `cr.chunk.**`. - // In local mode, we should write these checkpoints after engine flushed. - lastOffset := cr.chunk.Chunk.Offset - cr.chunk.Checksum.Add(dataChecksum) - cr.chunk.Checksum.Add(indexChecksum) - cr.chunk.Chunk.Offset = currOffset - cr.chunk.Chunk.RealOffset = currRealOffset - cr.chunk.Chunk.PrevRowIDMax = rowID - - if m, ok := metric.FromContext(ctx); ok { - // value of currOffset comes from parser.pos which increase monotonically. the init value of parser.pos - // comes from chunk.Chunk.Offset. so it shouldn't happen that currOffset - startOffset < 0. - // but we met it one time, but cannot reproduce it now, we add this check to make code more robust - // TODO: reproduce and find the root cause and fix it completely - var lowOffset, highOffset int64 - if cr.chunk.FileMeta.Compression != mydump.CompressionNone { - lowOffset, highOffset = startRealOffset, currRealOffset - } else { - lowOffset, highOffset = startOffset, currOffset - } - delta := highOffset - lowOffset - if delta >= 0 { - if cr.chunk.FileMeta.Type == mydump.SourceTypeParquet { - if currRealOffset > startRealOffset { - m.BytesCounter.WithLabelValues(metric.StateRestored).Add(float64(currRealOffset - startRealOffset)) - } - m.RowsCounter.WithLabelValues(metric.StateRestored, t.tableName).Add(float64(delta)) - } else { - m.BytesCounter.WithLabelValues(metric.StateRestored).Add(float64(delta)) - m.RowsCounter.WithLabelValues(metric.StateRestored, t.tableName).Add(float64(dataChecksum.SumKVS())) - } - if rc.status != nil && rc.status.backend == config.BackendTiDB { - rc.status.FinishedFileSize.Add(delta) - } - } else { - deliverLogger.Error("offset go back", zap.Int64("curr", highOffset), - zap.Int64("start", lowOffset)) - } - } - - if currOffset > lastOffset || dataChecksum.SumKVS() != 0 || indexChecksum.SumKVS() != 0 { - // No need to save checkpoint if nothing was delivered. - dataSynced = cr.maybeSaveCheckpoint(rc, t, engineID, cr.chunk, dataEngine, indexEngine) - } - failpoint.Inject("SlowDownWriteRows", func() { - deliverLogger.Warn("Slowed down write rows") - finished := rc.status.FinishedFileSize.Load() - total := rc.status.TotalFileSize.Load() - deliverLogger.Warn("PrintStatus Failpoint", - zap.Int64("finished", finished), - zap.Int64("total", total)) - }) - failpoint.Inject("FailAfterWriteRows", nil) - // TODO: for local backend, we may save checkpoint more frequently, e.g. after written - // 10GB kv pairs to data engine, we can do a flush for both data & index engine, then we - // can safely update current checkpoint. - - failpoint.Inject("LocalBackendSaveCheckpoint", func() { - if !isLocalBackend(rc.cfg) && (dataChecksum.SumKVS() != 0 || indexChecksum.SumKVS() != 0) { - // No need to save checkpoint if nothing was delivered. 
- saveCheckpoint(rc, t, engineID, cr.chunk) - } - }) - } - - return -} - -func (*chunkProcessor) maybeSaveCheckpoint( - rc *Controller, - t *TableImporter, - engineID int32, - chunk *checkpoints.ChunkCheckpoint, - data, index backend.EngineWriter, -) bool { - if data.IsSynced() && index.IsSynced() { - saveCheckpoint(rc, t, engineID, chunk) - return true - } - return false -} - -func (cr *chunkProcessor) close() { - _ = cr.parser.Close() -} diff --git a/lightning/pkg/importer/get_pre_info.go b/lightning/pkg/importer/get_pre_info.go index 64880d9c52750..5e34c6bf36186 100644 --- a/lightning/pkg/importer/get_pre_info.go +++ b/lightning/pkg/importer/get_pre_info.go @@ -187,9 +187,9 @@ func (g *TargetInfoGetterImpl) CheckVersionRequirements(ctx context.Context) err // It tries to select the row count from the target DB. func (g *TargetInfoGetterImpl) IsTableEmpty(ctx context.Context, schemaName string, tableName string) (*bool, error) { var result bool - if _, _err_ := failpoint.Eval(_curpkg_("CheckTableEmptyFailed")); _err_ == nil { - return nil, errors.New("mock error") - } + failpoint.Inject("CheckTableEmptyFailed", func() { + failpoint.Return(nil, errors.New("mock error")) + }) exec := common.SQLWithRetry{ DB: g.db, Logger: log.FromContext(ctx), @@ -365,18 +365,19 @@ func (p *PreImportInfoGetterImpl) GetAllTableStructures(ctx context.Context, opt func (p *PreImportInfoGetterImpl) getTableStructuresByFileMeta(ctx context.Context, dbSrcFileMeta *mydump.MDDatabaseMeta, getPreInfoCfg *ropts.GetPreInfoConfig) ([]*model.TableInfo, error) { dbName := dbSrcFileMeta.Name - if v, _err_ := failpoint.Eval(_curpkg_("getTableStructuresByFileMeta_BeforeFetchRemoteTableModels")); _err_ == nil { - - fmt.Println("failpoint: getTableStructuresByFileMeta_BeforeFetchRemoteTableModels") - const defaultMilliSeconds int = 5000 - sleepMilliSeconds, ok := v.(int) - if !ok || sleepMilliSeconds <= 0 || sleepMilliSeconds > 30000 { - sleepMilliSeconds = defaultMilliSeconds - } - //nolint: errcheck - failpoint.Enable("github.com/pingcap/tidb/pkg/lightning/backend/tidb/FetchRemoteTableModels_BeforeFetchTableAutoIDInfos", fmt.Sprintf("sleep(%d)", sleepMilliSeconds)) - - } + failpoint.Inject( + "getTableStructuresByFileMeta_BeforeFetchRemoteTableModels", + func(v failpoint.Value) { + fmt.Println("failpoint: getTableStructuresByFileMeta_BeforeFetchRemoteTableModels") + const defaultMilliSeconds int = 5000 + sleepMilliSeconds, ok := v.(int) + if !ok || sleepMilliSeconds <= 0 || sleepMilliSeconds > 30000 { + sleepMilliSeconds = defaultMilliSeconds + } + //nolint: errcheck + failpoint.Enable("github.com/pingcap/tidb/pkg/lightning/backend/tidb/FetchRemoteTableModels_BeforeFetchTableAutoIDInfos", fmt.Sprintf("sleep(%d)", sleepMilliSeconds)) + }, + ) currentTableInfosFromDB, err := p.targetInfoGetter.FetchRemoteTableModels(ctx, dbName) if err != nil { if getPreInfoCfg != nil && getPreInfoCfg.IgnoreDBNotExist { @@ -757,9 +758,9 @@ outloop: rowSize += uint64(lastRow.Length) parser.RecycleRow(lastRow) - if val, _err_ := failpoint.Eval(_curpkg_("mock-kv-size")); _err_ == nil { + failpoint.Inject("mock-kv-size", func(val failpoint.Value) { kvSize += uint64(val.(int)) - } + }) if rowSize > maxSampleDataSize || rowCount > maxSampleRowCount { break } diff --git a/lightning/pkg/importer/get_pre_info.go__failpoint_stash__ b/lightning/pkg/importer/get_pre_info.go__failpoint_stash__ deleted file mode 100644 index 5e34c6bf36186..0000000000000 --- a/lightning/pkg/importer/get_pre_info.go__failpoint_stash__ +++ /dev/null @@ -1,835 +0,0 @@ -// 
Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package importer - -import ( - "bytes" - "context" - "database/sql" - "fmt" - "io" - "strings" - - mysql_sql_driver "github.com/go-sql-driver/mysql" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/storage" - ropts "github.com/pingcap/tidb/lightning/pkg/importer/opts" - "github.com/pingcap/tidb/pkg/ddl" - "github.com/pingcap/tidb/pkg/errno" - "github.com/pingcap/tidb/pkg/lightning/backend" - "github.com/pingcap/tidb/pkg/lightning/backend/encode" - "github.com/pingcap/tidb/pkg/lightning/backend/kv" - "github.com/pingcap/tidb/pkg/lightning/backend/local" - "github.com/pingcap/tidb/pkg/lightning/backend/tidb" - "github.com/pingcap/tidb/pkg/lightning/checkpoints" - "github.com/pingcap/tidb/pkg/lightning/common" - "github.com/pingcap/tidb/pkg/lightning/config" - "github.com/pingcap/tidb/pkg/lightning/errormanager" - "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/lightning/mydump" - "github.com/pingcap/tidb/pkg/lightning/verification" - "github.com/pingcap/tidb/pkg/lightning/worker" - "github.com/pingcap/tidb/pkg/parser" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/model" - _ "github.com/pingcap/tidb/pkg/planner/core" // to setup expression.EvalAstExpr. Otherwise we cannot parse the default value - "github.com/pingcap/tidb/pkg/table/tables" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/dbterror" - "github.com/pingcap/tidb/pkg/util/mock" - pdhttp "github.com/tikv/pd/client/http" - "go.uber.org/zap" - "golang.org/x/exp/maps" -) - -// compressionRatio is the tikv/tiflash's compression ratio -const compressionRatio = float64(1) / 3 - -// EstimateSourceDataSizeResult is the object for estimated data size result. -type EstimateSourceDataSizeResult struct { - // SizeWithIndex is the tikv size with the index. - SizeWithIndex int64 - // SizeWithoutIndex is the tikv size without the index. - SizeWithoutIndex int64 - // HasUnsortedBigTables indicates whether the source data has unsorted big tables or not. - HasUnsortedBigTables bool - // TiFlashSize is the size of tiflash. - TiFlashSize int64 -} - -// PreImportInfoGetter defines the operations to get information from sources and target. -// These information are used in the preparation of the import ( like precheck ). -type PreImportInfoGetter interface { - TargetInfoGetter - // GetAllTableStructures gets all the table structures with the information from both the source and the target. - GetAllTableStructures(ctx context.Context, opts ...ropts.GetPreInfoOption) (map[string]*checkpoints.TidbDBInfo, error) - // ReadFirstNRowsByTableName reads the first N rows of data of an importing source table. - ReadFirstNRowsByTableName(ctx context.Context, schemaName string, tableName string, n int) (cols []string, rows [][]types.Datum, err error) - // ReadFirstNRowsByFileMeta reads the first N rows of an data file. 
- ReadFirstNRowsByFileMeta(ctx context.Context, dataFileMeta mydump.SourceFileMeta, n int) (cols []string, rows [][]types.Datum, err error) - // EstimateSourceDataSize estimates the datasize to generate during the import as well as some other sub-informaiton. - // It will return: - // * the estimated data size to generate during the import, - // which might include some extra index data to generate besides the source file data - // * the total data size of all the source files, - // * whether there are some unsorted big tables - EstimateSourceDataSize(ctx context.Context, opts ...ropts.GetPreInfoOption) (*EstimateSourceDataSizeResult, error) -} - -// TargetInfoGetter defines the operations to get information from target. -type TargetInfoGetter interface { - // FetchRemoteDBModels fetches the database structures from the remote target. - FetchRemoteDBModels(ctx context.Context) ([]*model.DBInfo, error) - // FetchRemoteTableModels fetches the table structures from the remote target. - FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) - // CheckVersionRequirements performs the check whether the target satisfies the version requirements. - CheckVersionRequirements(ctx context.Context) error - // IsTableEmpty checks whether the specified table on the target DB contains data or not. - IsTableEmpty(ctx context.Context, schemaName string, tableName string) (*bool, error) - // GetTargetSysVariablesForImport gets some important systam variables for importing on the target. - GetTargetSysVariablesForImport(ctx context.Context, opts ...ropts.GetPreInfoOption) map[string]string - // GetMaxReplica gets the max-replica from replication config on the target. - GetMaxReplica(ctx context.Context) (uint64, error) - // GetStorageInfo gets the storage information on the target. - GetStorageInfo(ctx context.Context) (*pdhttp.StoresInfo, error) - // GetEmptyRegionsInfo gets the region information of all the empty regions on the target. - GetEmptyRegionsInfo(ctx context.Context) (*pdhttp.RegionsInfo, error) -} - -type preInfoGetterKey string - -const ( - preInfoGetterKeyDBMetas preInfoGetterKey = "PRE_INFO_GETTER/DB_METAS" -) - -// WithPreInfoGetterDBMetas returns a new context with the specified dbMetas. -func WithPreInfoGetterDBMetas(ctx context.Context, dbMetas []*mydump.MDDatabaseMeta) context.Context { - return context.WithValue(ctx, preInfoGetterKeyDBMetas, dbMetas) -} - -// TargetInfoGetterImpl implements the operations to get information from the target. -type TargetInfoGetterImpl struct { - cfg *config.Config - db *sql.DB - backend backend.TargetInfoGetter - pdHTTPCli pdhttp.Client -} - -// NewTargetInfoGetterImpl creates a TargetInfoGetterImpl object. -func NewTargetInfoGetterImpl( - cfg *config.Config, - targetDB *sql.DB, - pdHTTPCli pdhttp.Client, -) (*TargetInfoGetterImpl, error) { - tls, err := cfg.ToTLS() - if err != nil { - return nil, errors.Trace(err) - } - var backendTargetInfoGetter backend.TargetInfoGetter - switch cfg.TikvImporter.Backend { - case config.BackendTiDB: - backendTargetInfoGetter = tidb.NewTargetInfoGetter(targetDB) - case config.BackendLocal: - backendTargetInfoGetter = local.NewTargetInfoGetter(tls, targetDB, pdHTTPCli) - default: - return nil, common.ErrUnknownBackend.GenWithStackByArgs(cfg.TikvImporter.Backend) - } - return &TargetInfoGetterImpl{ - cfg: cfg, - db: targetDB, - backend: backendTargetInfoGetter, - pdHTTPCli: pdHTTPCli, - }, nil -} - -// FetchRemoteDBModels implements TargetInfoGetter. 
-func (g *TargetInfoGetterImpl) FetchRemoteDBModels(ctx context.Context) ([]*model.DBInfo, error) { - return g.backend.FetchRemoteDBModels(ctx) -} - -// FetchRemoteTableModels fetches the table structures from the remote target. -// It implements the TargetInfoGetter interface. -func (g *TargetInfoGetterImpl) FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) { - return g.backend.FetchRemoteTableModels(ctx, schemaName) -} - -// CheckVersionRequirements performs the check whether the target satisfies the version requirements. -// It implements the TargetInfoGetter interface. -// Mydump database metas are retrieved from the context. -func (g *TargetInfoGetterImpl) CheckVersionRequirements(ctx context.Context) error { - var dbMetas []*mydump.MDDatabaseMeta - dbmetasVal := ctx.Value(preInfoGetterKeyDBMetas) - if dbmetasVal != nil { - if m, ok := dbmetasVal.([]*mydump.MDDatabaseMeta); ok { - dbMetas = m - } - } - return g.backend.CheckRequirements(ctx, &backend.CheckCtx{ - DBMetas: dbMetas, - }) -} - -// IsTableEmpty checks whether the specified table on the target DB contains data or not. -// It implements the TargetInfoGetter interface. -// It tries to select the row count from the target DB. -func (g *TargetInfoGetterImpl) IsTableEmpty(ctx context.Context, schemaName string, tableName string) (*bool, error) { - var result bool - failpoint.Inject("CheckTableEmptyFailed", func() { - failpoint.Return(nil, errors.New("mock error")) - }) - exec := common.SQLWithRetry{ - DB: g.db, - Logger: log.FromContext(ctx), - } - var dump int - err := exec.QueryRow(ctx, "check table empty", - // Here we use the `USE INDEX()` hint to skip fetch the record from index. - // In Lightning, if previous importing is halted half-way, it is possible that - // the data is partially imported, but the index data has not been imported. - // In this situation, if no hint is added, the SQL executor might fetch the record from index, - // which is empty. This will result in missing check. - common.SprintfWithIdentifiers("SELECT 1 FROM %s.%s USE INDEX() LIMIT 1", schemaName, tableName), - &dump, - ) - - isNoSuchTableErr := false - rootErr := errors.Cause(err) - if mysqlErr, ok := rootErr.(*mysql_sql_driver.MySQLError); ok && mysqlErr.Number == errno.ErrNoSuchTable { - isNoSuchTableErr = true - } - switch { - case isNoSuchTableErr: - result = true - case errors.ErrorEqual(err, sql.ErrNoRows): - result = true - case err != nil: - return nil, errors.Trace(err) - default: - result = false - } - return &result, nil -} - -// GetTargetSysVariablesForImport gets some important system variables for importing on the target. -// It implements the TargetInfoGetter interface. -// It uses the SQL to fetch sys variables from the target. -func (g *TargetInfoGetterImpl) GetTargetSysVariablesForImport(ctx context.Context, _ ...ropts.GetPreInfoOption) map[string]string { - sysVars := ObtainImportantVariables(ctx, g.db, !isTiDBBackend(g.cfg)) - // override by manually set vars - maps.Copy(sysVars, g.cfg.TiDB.Vars) - return sysVars -} - -// GetMaxReplica implements the TargetInfoGetter interface. -func (g *TargetInfoGetterImpl) GetMaxReplica(ctx context.Context) (uint64, error) { - cfg, err := g.pdHTTPCli.GetReplicateConfig(ctx) - if err != nil { - return 0, errors.Trace(err) - } - val := cfg["max-replicas"].(float64) - return uint64(val), nil -} - -// GetStorageInfo gets the storage information on the target. -// It implements the TargetInfoGetter interface. 
-// It uses the PD interface through TLS to get the information. -func (g *TargetInfoGetterImpl) GetStorageInfo(ctx context.Context) (*pdhttp.StoresInfo, error) { - return g.pdHTTPCli.GetStores(ctx) -} - -// GetEmptyRegionsInfo gets the region information of all the empty regions on the target. -// It implements the TargetInfoGetter interface. -// It uses the PD interface through TLS to get the information. -func (g *TargetInfoGetterImpl) GetEmptyRegionsInfo(ctx context.Context) (*pdhttp.RegionsInfo, error) { - return g.pdHTTPCli.GetEmptyRegions(ctx) -} - -// PreImportInfoGetterImpl implements the operations to get information used in importing preparation. -type PreImportInfoGetterImpl struct { - cfg *config.Config - getPreInfoCfg *ropts.GetPreInfoConfig - srcStorage storage.ExternalStorage - ioWorkers *worker.Pool - encBuilder encode.EncodingBuilder - targetInfoGetter TargetInfoGetter - - dbMetas []*mydump.MDDatabaseMeta - mdDBMetaMap map[string]*mydump.MDDatabaseMeta - mdDBTableMetaMap map[string]map[string]*mydump.MDTableMeta - - dbInfosCache map[string]*checkpoints.TidbDBInfo - sysVarsCache map[string]string - estimatedSizeCache *EstimateSourceDataSizeResult -} - -// NewPreImportInfoGetter creates a PreImportInfoGetterImpl object. -func NewPreImportInfoGetter( - cfg *config.Config, - dbMetas []*mydump.MDDatabaseMeta, - srcStorage storage.ExternalStorage, - targetInfoGetter TargetInfoGetter, - ioWorkers *worker.Pool, - encBuilder encode.EncodingBuilder, - opts ...ropts.GetPreInfoOption, -) (*PreImportInfoGetterImpl, error) { - if ioWorkers == nil { - ioWorkers = worker.NewPool(context.Background(), cfg.App.IOConcurrency, "pre_info_getter_io") - } - if encBuilder == nil { - switch cfg.TikvImporter.Backend { - case config.BackendTiDB: - encBuilder = tidb.NewEncodingBuilder() - case config.BackendLocal: - encBuilder = local.NewEncodingBuilder(context.Background()) - default: - return nil, common.ErrUnknownBackend.GenWithStackByArgs(cfg.TikvImporter.Backend) - } - } - - getPreInfoCfg := ropts.NewDefaultGetPreInfoConfig() - for _, o := range opts { - o(getPreInfoCfg) - } - result := &PreImportInfoGetterImpl{ - cfg: cfg, - getPreInfoCfg: getPreInfoCfg, - dbMetas: dbMetas, - srcStorage: srcStorage, - ioWorkers: ioWorkers, - encBuilder: encBuilder, - targetInfoGetter: targetInfoGetter, - } - result.Init() - return result, nil -} - -// Init initializes some internal data and states for PreImportInfoGetterImpl. -func (p *PreImportInfoGetterImpl) Init() { - mdDBMetaMap := make(map[string]*mydump.MDDatabaseMeta) - mdDBTableMetaMap := make(map[string]map[string]*mydump.MDTableMeta) - for _, dbMeta := range p.dbMetas { - dbName := dbMeta.Name - mdDBMetaMap[dbName] = dbMeta - mdTableMetaMap, ok := mdDBTableMetaMap[dbName] - if !ok { - mdTableMetaMap = make(map[string]*mydump.MDTableMeta) - mdDBTableMetaMap[dbName] = mdTableMetaMap - } - for _, tblMeta := range dbMeta.Tables { - tblName := tblMeta.Name - mdTableMetaMap[tblName] = tblMeta - } - } - p.mdDBMetaMap = mdDBMetaMap - p.mdDBTableMetaMap = mdDBTableMetaMap -} - -// GetAllTableStructures gets all the table structures with the information from both the source and the target. -// It implements the PreImportInfoGetter interface. -// It has a caching mechanism: the table structures will be obtained from the source only once. 
-func (p *PreImportInfoGetterImpl) GetAllTableStructures(ctx context.Context, opts ...ropts.GetPreInfoOption) (map[string]*checkpoints.TidbDBInfo, error) { - var ( - dbInfos map[string]*checkpoints.TidbDBInfo - err error - ) - getPreInfoCfg := p.getPreInfoCfg.Clone() - for _, o := range opts { - o(getPreInfoCfg) - } - dbInfos = p.dbInfosCache - if dbInfos != nil && !getPreInfoCfg.ForceReloadCache { - return dbInfos, nil - } - dbInfos, err = LoadSchemaInfo(ctx, p.dbMetas, func(ctx context.Context, dbName string) ([]*model.TableInfo, error) { - return p.getTableStructuresByFileMeta(ctx, p.mdDBMetaMap[dbName], getPreInfoCfg) - }) - if err != nil { - return nil, errors.Trace(err) - } - p.dbInfosCache = dbInfos - return dbInfos, nil -} - -func (p *PreImportInfoGetterImpl) getTableStructuresByFileMeta(ctx context.Context, dbSrcFileMeta *mydump.MDDatabaseMeta, getPreInfoCfg *ropts.GetPreInfoConfig) ([]*model.TableInfo, error) { - dbName := dbSrcFileMeta.Name - failpoint.Inject( - "getTableStructuresByFileMeta_BeforeFetchRemoteTableModels", - func(v failpoint.Value) { - fmt.Println("failpoint: getTableStructuresByFileMeta_BeforeFetchRemoteTableModels") - const defaultMilliSeconds int = 5000 - sleepMilliSeconds, ok := v.(int) - if !ok || sleepMilliSeconds <= 0 || sleepMilliSeconds > 30000 { - sleepMilliSeconds = defaultMilliSeconds - } - //nolint: errcheck - failpoint.Enable("github.com/pingcap/tidb/pkg/lightning/backend/tidb/FetchRemoteTableModels_BeforeFetchTableAutoIDInfos", fmt.Sprintf("sleep(%d)", sleepMilliSeconds)) - }, - ) - currentTableInfosFromDB, err := p.targetInfoGetter.FetchRemoteTableModels(ctx, dbName) - if err != nil { - if getPreInfoCfg != nil && getPreInfoCfg.IgnoreDBNotExist { - dbNotExistErr := dbterror.ClassSchema.NewStd(errno.ErrBadDB).FastGenByArgs(dbName) - // The returned error is an error showing get info request error, - // and attaches the detailed error response as a string. - // So we cannot get the error chain and use error comparison, - // and instead, we use the string comparison on error messages. - if strings.Contains(err.Error(), dbNotExistErr.Error()) { - log.L().Warn("DB not exists. 
But ignore it", zap.Error(err)) - goto get_struct_from_src - } - } - return nil, errors.Trace(err) - } -get_struct_from_src: - currentTableInfosMap := make(map[string]*model.TableInfo) - for _, tblInfo := range currentTableInfosFromDB { - currentTableInfosMap[tblInfo.Name.L] = tblInfo - } - resultInfos := make([]*model.TableInfo, len(dbSrcFileMeta.Tables)) - for i, tableFileMeta := range dbSrcFileMeta.Tables { - if curTblInfo, ok := currentTableInfosMap[strings.ToLower(tableFileMeta.Name)]; ok { - resultInfos[i] = curTblInfo - continue - } - createTblSQL, err := tableFileMeta.GetSchema(ctx, p.srcStorage) - if err != nil { - return nil, errors.Annotatef(err, "get create table statement from schema file error: %s", tableFileMeta.Name) - } - theTableInfo, err := newTableInfo(createTblSQL, 0) - log.L().Info("generate table info from SQL", zap.Error(err), zap.String("sql", createTblSQL), zap.String("table_name", tableFileMeta.Name), zap.String("db_name", dbSrcFileMeta.Name)) - if err != nil { - errMsg := "generate table info from SQL error" - log.L().Error(errMsg, zap.Error(err), zap.String("sql", createTblSQL), zap.String("table_name", tableFileMeta.Name)) - return nil, errors.Annotatef(err, "%s: %s", errMsg, tableFileMeta.Name) - } - resultInfos[i] = theTableInfo - } - return resultInfos, nil -} - -func newTableInfo(createTblSQL string, tableID int64) (*model.TableInfo, error) { - parser := parser.New() - astNode, err := parser.ParseOneStmt(createTblSQL, "", "") - if err != nil { - errMsg := "parse sql statement error" - log.L().Error(errMsg, zap.Error(err), zap.String("sql", createTblSQL)) - return nil, errors.Trace(err) - } - sctx := mock.NewContext() - createTableStmt, ok := astNode.(*ast.CreateTableStmt) - if !ok { - return nil, errors.New("cannot transfer the parsed SQL as an CREATE TABLE statement") - } - info, err := ddl.MockTableInfo(sctx, createTableStmt, tableID) - if err != nil { - return nil, errors.Trace(err) - } - info.State = model.StatePublic - return info, nil -} - -// ReadFirstNRowsByTableName reads the first N rows of data of an importing source table. -// It implements the PreImportInfoGetter interface. -func (p *PreImportInfoGetterImpl) ReadFirstNRowsByTableName(ctx context.Context, schemaName string, tableName string, n int) ([]string, [][]types.Datum, error) { - mdTableMetaMap, ok := p.mdDBTableMetaMap[schemaName] - if !ok { - return nil, nil, errors.Errorf("cannot find the schema: %s", schemaName) - } - mdTableMeta, ok := mdTableMetaMap[tableName] - if !ok { - return nil, nil, errors.Errorf("cannot find the table: %s.%s", schemaName, tableName) - } - if len(mdTableMeta.DataFiles) <= 0 { - return nil, [][]types.Datum{}, nil - } - return p.ReadFirstNRowsByFileMeta(ctx, mdTableMeta.DataFiles[0].FileMeta, n) -} - -// ReadFirstNRowsByFileMeta reads the first N rows of an data file. -// It implements the PreImportInfoGetter interface. -func (p *PreImportInfoGetterImpl) ReadFirstNRowsByFileMeta(ctx context.Context, dataFileMeta mydump.SourceFileMeta, n int) ([]string, [][]types.Datum, error) { - reader, err := mydump.OpenReader(ctx, &dataFileMeta, p.srcStorage, storage.DecompressConfig{ - ZStdDecodeConcurrency: 1, - }) - if err != nil { - return nil, nil, errors.Trace(err) - } - - var parser mydump.Parser - blockBufSize := int64(p.cfg.Mydumper.ReadBlockSize) - switch dataFileMeta.Type { - case mydump.SourceTypeCSV: - hasHeader := p.cfg.Mydumper.CSV.Header - // Create a utf8mb4 convertor to encode and decode data with the charset of CSV files. 
- charsetConvertor, err := mydump.NewCharsetConvertor(p.cfg.Mydumper.DataCharacterSet, p.cfg.Mydumper.DataInvalidCharReplace) - if err != nil { - return nil, nil, errors.Trace(err) - } - parser, err = mydump.NewCSVParser(ctx, &p.cfg.Mydumper.CSV, reader, blockBufSize, p.ioWorkers, hasHeader, charsetConvertor) - if err != nil { - return nil, nil, errors.Trace(err) - } - case mydump.SourceTypeSQL: - parser = mydump.NewChunkParser(ctx, p.cfg.TiDB.SQLMode, reader, blockBufSize, p.ioWorkers) - case mydump.SourceTypeParquet: - parser, err = mydump.NewParquetParser(ctx, p.srcStorage, reader, dataFileMeta.Path) - if err != nil { - return nil, nil, errors.Trace(err) - } - default: - panic(fmt.Sprintf("unknown file type '%s'", dataFileMeta.Type)) - } - //nolint: errcheck - defer parser.Close() - - rows := [][]types.Datum{} - for i := 0; i < n; i++ { - err := parser.ReadRow() - if err != nil { - if errors.Cause(err) != io.EOF { - return nil, nil, errors.Trace(err) - } - break - } - lastRowDatums := append([]types.Datum{}, parser.LastRow().Row...) - rows = append(rows, lastRowDatums) - } - return parser.Columns(), rows, nil -} - -// EstimateSourceDataSize estimates the datasize to generate during the import as well as some other sub-informaiton. -// It implements the PreImportInfoGetter interface. -// It has a cache mechanism. The estimated size will only calculated once. -// The caching behavior can be changed by appending the `ForceReloadCache(true)` option. -func (p *PreImportInfoGetterImpl) EstimateSourceDataSize(ctx context.Context, opts ...ropts.GetPreInfoOption) (*EstimateSourceDataSizeResult, error) { - var result *EstimateSourceDataSizeResult - - getPreInfoCfg := p.getPreInfoCfg.Clone() - for _, o := range opts { - o(getPreInfoCfg) - } - result = p.estimatedSizeCache - if result != nil && !getPreInfoCfg.ForceReloadCache { - return result, nil - } - - var ( - sizeWithIndex = int64(0) - tiflashSize = int64(0) - sourceTotalSize = int64(0) - tableCount = 0 - unSortedBigTableCount = 0 - errMgr = errormanager.New(nil, p.cfg, log.FromContext(ctx)) - ) - - dbInfos, err := p.GetAllTableStructures(ctx) - if err != nil { - return nil, errors.Trace(err) - } - sysVars := p.GetTargetSysVariablesForImport(ctx) - for _, db := range p.dbMetas { - info, ok := dbInfos[db.Name] - if !ok { - continue - } - for _, tbl := range db.Tables { - sourceTotalSize += tbl.TotalSize - tableInfo, ok := info.Tables[tbl.Name] - if ok { - tableSize := tbl.TotalSize - // Do not sample small table because there may a large number of small table and it will take a long - // time to sample data for all of them. 
- if isTiDBBackend(p.cfg) || tbl.TotalSize < int64(config.SplitRegionSize) { - tbl.IndexRatio = 1.0 - tbl.IsRowOrdered = false - } else { - sampledIndexRatio, isRowOrderedFromSample, err := p.sampleDataFromTable(ctx, db.Name, tbl, tableInfo.Core, errMgr, sysVars) - if err != nil { - return nil, errors.Trace(err) - } - tbl.IndexRatio = sampledIndexRatio - tbl.IsRowOrdered = isRowOrderedFromSample - - tableSize = int64(float64(tbl.TotalSize) * tbl.IndexRatio) - - if tbl.TotalSize > int64(config.DefaultBatchSize)*2 && !tbl.IsRowOrdered { - unSortedBigTableCount++ - } - } - - sizeWithIndex += tableSize - if tableInfo.Core.TiFlashReplica != nil && tableInfo.Core.TiFlashReplica.Available { - tiflashSize += tableSize * int64(tableInfo.Core.TiFlashReplica.Count) - } - tableCount++ - } - } - } - - if isLocalBackend(p.cfg) { - sizeWithIndex = int64(float64(sizeWithIndex) * compressionRatio) - tiflashSize = int64(float64(tiflashSize) * compressionRatio) - } - - result = &EstimateSourceDataSizeResult{ - SizeWithIndex: sizeWithIndex, - SizeWithoutIndex: sourceTotalSize, - HasUnsortedBigTables: (unSortedBigTableCount > 0), - TiFlashSize: tiflashSize, - } - p.estimatedSizeCache = result - return result, nil -} - -// sampleDataFromTable samples the source data file to get the extra data ratio for the index -// It returns: -// * the extra data ratio with index size accounted -// * is the sample data ordered by row -func (p *PreImportInfoGetterImpl) sampleDataFromTable( - ctx context.Context, - dbName string, - tableMeta *mydump.MDTableMeta, - tableInfo *model.TableInfo, - errMgr *errormanager.ErrorManager, - sysVars map[string]string, -) (float64, bool, error) { - resultIndexRatio := 1.0 - isRowOrdered := false - if len(tableMeta.DataFiles) == 0 { - return resultIndexRatio, isRowOrdered, nil - } - sampleFile := tableMeta.DataFiles[0].FileMeta - reader, err := mydump.OpenReader(ctx, &sampleFile, p.srcStorage, storage.DecompressConfig{ - ZStdDecodeConcurrency: 1, - }) - if err != nil { - return 0.0, false, errors.Trace(err) - } - idAlloc := kv.NewPanickingAllocators(tableInfo.SepAutoInc(), 0) - tbl, err := tables.TableFromMeta(idAlloc, tableInfo) - if err != nil { - return 0.0, false, errors.Trace(err) - } - logger := log.FromContext(ctx).With(zap.String("table", tableMeta.Name)) - kvEncoder, err := p.encBuilder.NewEncoder(ctx, &encode.EncodingConfig{ - SessionOptions: encode.SessionOptions{ - SQLMode: p.cfg.TiDB.SQLMode, - Timestamp: 0, - SysVars: sysVars, - AutoRandomSeed: 0, - }, - Table: tbl, - Logger: logger, - }) - if err != nil { - return 0.0, false, errors.Trace(err) - } - blockBufSize := int64(p.cfg.Mydumper.ReadBlockSize) - - var parser mydump.Parser - switch tableMeta.DataFiles[0].FileMeta.Type { - case mydump.SourceTypeCSV: - hasHeader := p.cfg.Mydumper.CSV.Header - // Create a utf8mb4 convertor to encode and decode data with the charset of CSV files. 
- charsetConvertor, err := mydump.NewCharsetConvertor(p.cfg.Mydumper.DataCharacterSet, p.cfg.Mydumper.DataInvalidCharReplace) - if err != nil { - return 0.0, false, errors.Trace(err) - } - parser, err = mydump.NewCSVParser(ctx, &p.cfg.Mydumper.CSV, reader, blockBufSize, p.ioWorkers, hasHeader, charsetConvertor) - if err != nil { - return 0.0, false, errors.Trace(err) - } - case mydump.SourceTypeSQL: - parser = mydump.NewChunkParser(ctx, p.cfg.TiDB.SQLMode, reader, blockBufSize, p.ioWorkers) - case mydump.SourceTypeParquet: - parser, err = mydump.NewParquetParser(ctx, p.srcStorage, reader, sampleFile.Path) - if err != nil { - return 0.0, false, errors.Trace(err) - } - default: - panic(fmt.Sprintf("file '%s' with unknown source type '%s'", sampleFile.Path, sampleFile.Type.String())) - } - //nolint: errcheck - defer parser.Close() - logger.Begin(zap.InfoLevel, "sample file") - igCols, err := p.cfg.Mydumper.IgnoreColumns.GetIgnoreColumns(dbName, tableMeta.Name, p.cfg.Mydumper.CaseSensitive) - if err != nil { - return 0.0, false, errors.Trace(err) - } - - initializedColumns := false - var ( - columnPermutation []int - kvSize uint64 - rowSize uint64 - extendVals []types.Datum - ) - rowCount := 0 - dataKVs := p.encBuilder.MakeEmptyRows() - indexKVs := p.encBuilder.MakeEmptyRows() - lastKey := make([]byte, 0) - isRowOrdered = true -outloop: - for { - offset, _ := parser.Pos() - err = parser.ReadRow() - columnNames := parser.Columns() - - switch errors.Cause(err) { - case nil: - if !initializedColumns { - ignoreColsMap := igCols.ColumnsMap() - if len(columnPermutation) == 0 { - columnPermutation, err = createColumnPermutation( - columnNames, - ignoreColsMap, - tableInfo, - log.FromContext(ctx)) - if err != nil { - return 0.0, false, errors.Trace(err) - } - } - if len(sampleFile.ExtendData.Columns) > 0 { - _, extendVals = filterColumns(columnNames, sampleFile.ExtendData, ignoreColsMap, tableInfo) - } - initializedColumns = true - lastRow := parser.LastRow() - lastRowLen := len(lastRow.Row) - extendColsMap := make(map[string]int) - for i, c := range sampleFile.ExtendData.Columns { - extendColsMap[c] = lastRowLen + i - } - for i, col := range tableInfo.Columns { - if p, ok := extendColsMap[col.Name.O]; ok { - columnPermutation[i] = p - } - } - } - case io.EOF: - break outloop - default: - err = errors.Annotatef(err, "in file offset %d", offset) - return 0.0, false, errors.Trace(err) - } - lastRow := parser.LastRow() - rowCount++ - lastRow.Row = append(lastRow.Row, extendVals...) 
- - var dataChecksum, indexChecksum verification.KVChecksum - kvs, encodeErr := kvEncoder.Encode(lastRow.Row, lastRow.RowID, columnPermutation, offset) - if encodeErr != nil { - encodeErr = errMgr.RecordTypeError(ctx, log.FromContext(ctx), tableInfo.Name.O, sampleFile.Path, offset, - "" /* use a empty string here because we don't actually record */, encodeErr) - if encodeErr != nil { - return 0.0, false, errors.Annotatef(encodeErr, "in file at offset %d", offset) - } - if rowCount < maxSampleRowCount { - continue - } - break - } - if isRowOrdered { - kvs.ClassifyAndAppend(&dataKVs, &dataChecksum, &indexKVs, &indexChecksum) - for _, kv := range kv.Rows2KvPairs(dataKVs) { - if len(lastKey) == 0 { - lastKey = kv.Key - } else if bytes.Compare(lastKey, kv.Key) > 0 { - isRowOrdered = false - break - } - } - dataKVs = dataKVs.Clear() - indexKVs = indexKVs.Clear() - } - kvSize += kvs.Size() - rowSize += uint64(lastRow.Length) - parser.RecycleRow(lastRow) - - failpoint.Inject("mock-kv-size", func(val failpoint.Value) { - kvSize += uint64(val.(int)) - }) - if rowSize > maxSampleDataSize || rowCount > maxSampleRowCount { - break - } - } - - if rowSize > 0 && kvSize > rowSize { - resultIndexRatio = float64(kvSize) / float64(rowSize) - } - log.FromContext(ctx).Info("Sample source data", zap.String("table", tableMeta.Name), zap.Float64("IndexRatio", tableMeta.IndexRatio), zap.Bool("IsSourceOrder", tableMeta.IsRowOrdered)) - return resultIndexRatio, isRowOrdered, nil -} - -// GetMaxReplica implements the PreImportInfoGetter interface. -func (p *PreImportInfoGetterImpl) GetMaxReplica(ctx context.Context) (uint64, error) { - return p.targetInfoGetter.GetMaxReplica(ctx) -} - -// GetStorageInfo gets the storage information on the target. -// It implements the PreImportInfoGetter interface. -func (p *PreImportInfoGetterImpl) GetStorageInfo(ctx context.Context) (*pdhttp.StoresInfo, error) { - return p.targetInfoGetter.GetStorageInfo(ctx) -} - -// GetEmptyRegionsInfo gets the region information of all the empty regions on the target. -// It implements the PreImportInfoGetter interface. -func (p *PreImportInfoGetterImpl) GetEmptyRegionsInfo(ctx context.Context) (*pdhttp.RegionsInfo, error) { - return p.targetInfoGetter.GetEmptyRegionsInfo(ctx) -} - -// IsTableEmpty checks whether the specified table on the target DB contains data or not. -// It implements the PreImportInfoGetter interface. -func (p *PreImportInfoGetterImpl) IsTableEmpty(ctx context.Context, schemaName string, tableName string) (*bool, error) { - return p.targetInfoGetter.IsTableEmpty(ctx, schemaName, tableName) -} - -// FetchRemoteDBModels fetches the database structures from the remote target. -// It implements the PreImportInfoGetter interface. -func (p *PreImportInfoGetterImpl) FetchRemoteDBModels(ctx context.Context) ([]*model.DBInfo, error) { - return p.targetInfoGetter.FetchRemoteDBModels(ctx) -} - -// FetchRemoteTableModels fetches the table structures from the remote target. -// It implements the PreImportInfoGetter interface. -func (p *PreImportInfoGetterImpl) FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) { - return p.targetInfoGetter.FetchRemoteTableModels(ctx, schemaName) -} - -// CheckVersionRequirements performs the check whether the target satisfies the version requirements. -// It implements the PreImportInfoGetter interface. -// Mydump database metas are retrieved from the context. 
-func (p *PreImportInfoGetterImpl) CheckVersionRequirements(ctx context.Context) error { - return p.targetInfoGetter.CheckVersionRequirements(ctx) -} - -// GetTargetSysVariablesForImport gets some important systam variables for importing on the target. -// It implements the PreImportInfoGetter interface. -// It has caching mechanism. -func (p *PreImportInfoGetterImpl) GetTargetSysVariablesForImport(ctx context.Context, opts ...ropts.GetPreInfoOption) map[string]string { - var sysVars map[string]string - - getPreInfoCfg := p.getPreInfoCfg.Clone() - for _, o := range opts { - o(getPreInfoCfg) - } - sysVars = p.sysVarsCache - if sysVars != nil && !getPreInfoCfg.ForceReloadCache { - return sysVars - } - sysVars = p.targetInfoGetter.GetTargetSysVariablesForImport(ctx) - p.sysVarsCache = sysVars - return sysVars -} diff --git a/lightning/pkg/importer/import.go b/lightning/pkg/importer/import.go index 60a1a3059c780..04218f2f1102e 100644 --- a/lightning/pkg/importer/import.go +++ b/lightning/pkg/importer/import.go @@ -133,9 +133,9 @@ var DeliverPauser = common.NewPauser() // nolint:gochecknoinits // TODO: refactor func init() { - if v, _err_ := failpoint.Eval(_curpkg_("SetMinDeliverBytes")); _err_ == nil { + failpoint.Inject("SetMinDeliverBytes", func(v failpoint.Value) { minDeliverBytes = uint64(v.(int)) - } + }) } type saveCp struct { @@ -541,7 +541,7 @@ func (rc *Controller) Close() { // Run starts the restore task. func (rc *Controller) Run(ctx context.Context) error { - failpoint.Eval(_curpkg_("beforeRun")) + failpoint.Inject("beforeRun", func() {}) opts := []func(context.Context) error{ rc.setGlobalVariables, @@ -636,10 +636,10 @@ func (rc *Controller) initCheckpoint(ctx context.Context) error { if err != nil { return common.ErrInitCheckpoint.Wrap(err).GenWithStackByArgs() } - if _, _err_ := failpoint.Eval(_curpkg_("InitializeCheckpointExit")); _err_ == nil { + failpoint.Inject("InitializeCheckpointExit", func() { log.FromContext(ctx).Warn("exit triggered", zap.String("failpoint", "InitializeCheckpointExit")) os.Exit(0) - } + }) if err := rc.loadDesiredTableInfos(ctx); err != nil { return err } @@ -864,7 +864,7 @@ func (rc *Controller) listenCheckpointUpdates(logger log.Logger) { lock.Unlock() //nolint:scopelint // This would be either INLINED or ERASED, at compile time. - failpoint.Eval(_curpkg_("SlowDownCheckpointUpdate")) + failpoint.Inject("SlowDownCheckpointUpdate", func() {}) if len(cpd) > 0 { err := rc.checkpointsDB.Update(rc.taskCtx, cpd) @@ -897,25 +897,25 @@ func (rc *Controller) listenCheckpointUpdates(logger log.Logger) { lock.Unlock() //nolint:scopelint // This would be either INLINED or ERASED, at compile time. - if _, _err_ := failpoint.Eval(_curpkg_("FailIfImportedChunk")); _err_ == nil { + failpoint.Inject("FailIfImportedChunk", func() { if merger, ok := scp.merger.(*checkpoints.ChunkCheckpointMerger); ok && merger.Pos >= merger.EndOffset { rc.checkpointsWg.Done() rc.checkpointsWg.Wait() panic("forcing failure due to FailIfImportedChunk") } - } + }) //nolint:scopelint // This would be either INLINED or ERASED, at compile time. 
- if val, _err_ := failpoint.Eval(_curpkg_("FailIfStatusBecomes")); _err_ == nil { + failpoint.Inject("FailIfStatusBecomes", func(val failpoint.Value) { if merger, ok := scp.merger.(*checkpoints.StatusCheckpointMerger); ok && merger.EngineID >= 0 && int(merger.Status) == val.(int) { rc.checkpointsWg.Done() rc.checkpointsWg.Wait() panic("forcing failure due to FailIfStatusBecomes") } - } + }) //nolint:scopelint // This would be either INLINED or ERASED, at compile time. - if val, _err_ := failpoint.Eval(_curpkg_("FailIfIndexEngineImported")); _err_ == nil { + failpoint.Inject("FailIfIndexEngineImported", func(val failpoint.Value) { if merger, ok := scp.merger.(*checkpoints.StatusCheckpointMerger); ok && merger.EngineID == checkpoints.WholeTableEngineID && merger.Status == checkpoints.CheckpointStatusIndexImported && val.(int) > 0 { @@ -923,10 +923,10 @@ func (rc *Controller) listenCheckpointUpdates(logger log.Logger) { rc.checkpointsWg.Wait() panic("forcing failure due to FailIfIndexEngineImported") } - } + }) //nolint:scopelint // This would be either INLINED or ERASED, at compile time. - if _, _err_ := failpoint.Eval(_curpkg_("KillIfImportedChunk")); _err_ == nil { + failpoint.Inject("KillIfImportedChunk", func() { if merger, ok := scp.merger.(*checkpoints.ChunkCheckpointMerger); ok && merger.Pos >= merger.EndOffset { rc.checkpointsWg.Done() rc.checkpointsWg.Wait() @@ -938,9 +938,9 @@ func (rc *Controller) listenCheckpointUpdates(logger log.Logger) { scp.waitCh <- context.Canceled } } - return + failpoint.Return() } - } + }) } // Don't put this statement in defer function at the beginning. failpoint function may call it manually. rc.checkpointsWg.Done() diff --git a/lightning/pkg/importer/import.go__failpoint_stash__ b/lightning/pkg/importer/import.go__failpoint_stash__ deleted file mode 100644 index 04218f2f1102e..0000000000000 --- a/lightning/pkg/importer/import.go__failpoint_stash__ +++ /dev/null @@ -1,2080 +0,0 @@ -// Copyright 2019 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
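// A minimal sketch (not part of this patch) of the failpoint marker pattern that the
// hunks above restore. Only the failpoint.Inject marker form is meant to live in the
// committed source; running failpoint-ctl over the package rewrites the marker into the
// failpoint.Eval(_curpkg_("...")) form seen on the removed lines and keeps the original
// file as *.go__failpoint_stash__, which appears to be why those generated stash files
// are deleted here. The package name and function below are illustrative; the
// "SetMinDeliverBytes" name and minDeliverBytes variable are taken from the surrounding diff.
package example

import "github.com/pingcap/failpoint"

// minDeliverBytes uses an arbitrary default purely for this sketch.
var minDeliverBytes = uint64(31 << 10)

// applyDeliverBytesFailpoint compiles to a no-op marker in normal builds; when the
// package has been processed by failpoint-ctl and the failpoint is activated with an
// integer term, the callback overrides minDeliverBytes for testing.
func applyDeliverBytesFailpoint() {
	failpoint.Inject("SetMinDeliverBytes", func(v failpoint.Value) {
		minDeliverBytes = uint64(v.(int))
	})
}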
- -package importer - -import ( - "context" - "database/sql" - "fmt" - "math" - "os" - "path/filepath" - "strings" - "sync" - "time" - - "github.com/coreos/go-semver/semver" - "github.com/docker/go-units" - "github.com/google/uuid" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/metapb" - berrors "github.com/pingcap/tidb/br/pkg/errors" - "github.com/pingcap/tidb/br/pkg/pdutil" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/br/pkg/utils" - "github.com/pingcap/tidb/br/pkg/version" - "github.com/pingcap/tidb/br/pkg/version/build" - "github.com/pingcap/tidb/lightning/pkg/web" - tidbconfig "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/distsql" - "github.com/pingcap/tidb/pkg/keyspace" - tidbkv "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/lightning/backend" - "github.com/pingcap/tidb/pkg/lightning/backend/encode" - "github.com/pingcap/tidb/pkg/lightning/backend/local" - "github.com/pingcap/tidb/pkg/lightning/backend/tidb" - "github.com/pingcap/tidb/pkg/lightning/checkpoints" - "github.com/pingcap/tidb/pkg/lightning/common" - "github.com/pingcap/tidb/pkg/lightning/config" - "github.com/pingcap/tidb/pkg/lightning/errormanager" - "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/lightning/metric" - "github.com/pingcap/tidb/pkg/lightning/mydump" - "github.com/pingcap/tidb/pkg/lightning/tikv" - "github.com/pingcap/tidb/pkg/lightning/worker" - "github.com/pingcap/tidb/pkg/meta/autoid" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/session" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/store/driver" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/collate" - "github.com/pingcap/tidb/pkg/util/etcd" - regexprrouter "github.com/pingcap/tidb/pkg/util/regexpr-router" - "github.com/pingcap/tidb/pkg/util/set" - "github.com/prometheus/client_golang/prometheus" - tikvconfig "github.com/tikv/client-go/v2/config" - kvutil "github.com/tikv/client-go/v2/util" - pd "github.com/tikv/pd/client" - pdhttp "github.com/tikv/pd/client/http" - "github.com/tikv/pd/client/retry" - clientv3 "go.etcd.io/etcd/client/v3" - "go.uber.org/atomic" - "go.uber.org/multierr" - "go.uber.org/zap" -) - -// compact levels -const ( - FullLevelCompact = -1 - Level1Compact = 1 -) - -const ( - compactStateIdle int32 = iota - compactStateDoing -) - -// task related table names and create table statements. 
-const ( - TaskMetaTableName = "task_meta_v2" - TableMetaTableName = "table_meta" - // CreateTableMetadataTable stores the per-table sub jobs information used by TiDB Lightning - CreateTableMetadataTable = `CREATE TABLE IF NOT EXISTS %s.%s ( - task_id BIGINT(20) UNSIGNED, - table_id BIGINT(64) NOT NULL, - table_name VARCHAR(64) NOT NULL, - row_id_base BIGINT(20) NOT NULL DEFAULT 0, - row_id_max BIGINT(20) NOT NULL DEFAULT 0, - total_kvs_base BIGINT(20) UNSIGNED NOT NULL DEFAULT 0, - total_bytes_base BIGINT(20) UNSIGNED NOT NULL DEFAULT 0, - checksum_base BIGINT(20) UNSIGNED NOT NULL DEFAULT 0, - total_kvs BIGINT(20) UNSIGNED NOT NULL DEFAULT 0, - total_bytes BIGINT(20) UNSIGNED NOT NULL DEFAULT 0, - checksum BIGINT(20) UNSIGNED NOT NULL DEFAULT 0, - status VARCHAR(32) NOT NULL, - has_duplicates BOOL NOT NULL DEFAULT 0, - PRIMARY KEY (table_id, task_id) - );` - // CreateTaskMetaTable stores the pre-lightning metadata used by TiDB Lightning - CreateTaskMetaTable = `CREATE TABLE IF NOT EXISTS %s.%s ( - task_id BIGINT(20) UNSIGNED NOT NULL, - pd_cfgs VARCHAR(2048) NOT NULL DEFAULT '', - status VARCHAR(32) NOT NULL, - state TINYINT(1) NOT NULL DEFAULT 0 COMMENT '0: normal, 1: exited before finish', - tikv_source_bytes BIGINT(20) UNSIGNED NOT NULL DEFAULT 0, - tiflash_source_bytes BIGINT(20) UNSIGNED NOT NULL DEFAULT 0, - tikv_avail BIGINT(20) UNSIGNED NOT NULL DEFAULT 0, - tiflash_avail BIGINT(20) UNSIGNED NOT NULL DEFAULT 0, - PRIMARY KEY (task_id) - );` -) - -var ( - minTiKVVersionForConflictStrategy = *semver.New("5.2.0") - maxTiKVVersionForConflictStrategy = version.NextMajorVersion() -) - -// DeliverPauser is a shared pauser to pause progress to (*chunkProcessor).encodeLoop -var DeliverPauser = common.NewPauser() - -// nolint:gochecknoinits // TODO: refactor -func init() { - failpoint.Inject("SetMinDeliverBytes", func(v failpoint.Value) { - minDeliverBytes = uint64(v.(int)) - }) -} - -type saveCp struct { - tableName string - merger checkpoints.TableCheckpointMerger - waitCh chan<- error -} - -type errorSummary struct { - status checkpoints.CheckpointStatus - err error -} - -type errorSummaries struct { - sync.Mutex - logger log.Logger - summary map[string]errorSummary -} - -// makeErrorSummaries returns an initialized errorSummaries instance -func makeErrorSummaries(logger log.Logger) errorSummaries { - return errorSummaries{ - logger: logger, - summary: make(map[string]errorSummary), - } -} - -func (es *errorSummaries) emitLog() { - es.Lock() - defer es.Unlock() - - if errorCount := len(es.summary); errorCount > 0 { - logger := es.logger - logger.Error("tables failed to be imported", zap.Int("count", errorCount)) - for tableName, errorSummary := range es.summary { - logger.Error("-", - zap.String("table", tableName), - zap.String("status", errorSummary.status.MetricName()), - log.ShortError(errorSummary.err), - ) - } - } -} - -func (es *errorSummaries) record(tableName string, err error, status checkpoints.CheckpointStatus) { - es.Lock() - defer es.Unlock() - es.summary[tableName] = errorSummary{status: status, err: err} -} - -const ( - diskQuotaStateIdle int32 = iota - diskQuotaStateChecking - diskQuotaStateImporting -) - -// Controller controls the whole import process. 
-type Controller struct { - taskCtx context.Context - cfg *config.Config - dbMetas []*mydump.MDDatabaseMeta - dbInfos map[string]*checkpoints.TidbDBInfo - tableWorkers *worker.Pool - indexWorkers *worker.Pool - regionWorkers *worker.Pool - ioWorkers *worker.Pool - checksumWorks *worker.Pool - pauser *common.Pauser - engineMgr backend.EngineManager - backend backend.Backend - db *sql.DB - pdCli pd.Client - pdHTTPCli pdhttp.Client - - sysVars map[string]string - tls *common.TLS - checkTemplate Template - - errorSummaries errorSummaries - - checkpointsDB checkpoints.DB - saveCpCh chan saveCp - checkpointsWg sync.WaitGroup - - closedEngineLimit *worker.Pool - addIndexLimit *worker.Pool - - store storage.ExternalStorage - ownStore bool - metaMgrBuilder metaMgrBuilder - errorMgr *errormanager.ErrorManager - taskMgr taskMetaMgr - - diskQuotaLock sync.RWMutex - diskQuotaState atomic.Int32 - compactState atomic.Int32 - status *LightningStatus - dupIndicator *atomic.Bool - - preInfoGetter PreImportInfoGetter - precheckItemBuilder *PrecheckItemBuilder - encBuilder encode.EncodingBuilder - tikvModeSwitcher local.TiKVModeSwitcher - - keyspaceName string - resourceGroupName string - taskType string -} - -// LightningStatus provides the finished bytes and total bytes of the current task. -// It should keep the value after restart from checkpoint. -// When it is tidb backend, FinishedFileSize can be counted after chunk data is -// restored to tidb. When it is local backend it's counted after whole engine is -// imported. -// TotalFileSize may be an estimated value, so when the task is finished, it may -// not equal to FinishedFileSize. -type LightningStatus struct { - backend string - FinishedFileSize atomic.Int64 - TotalFileSize atomic.Int64 -} - -// ControllerParam contains many parameters for creating a Controller. -type ControllerParam struct { - // databases that dumper created - DBMetas []*mydump.MDDatabaseMeta - // a pointer to status to report it to caller - Status *LightningStatus - // storage interface to read the dump data - DumpFileStorage storage.ExternalStorage - // true if DumpFileStorage is created by lightning. In some cases where lightning is a library, the framework may pass an DumpFileStorage - OwnExtStorage bool - // used by lightning server mode to pause tasks - Pauser *common.Pauser - // DB is a connection pool to TiDB - DB *sql.DB - // storage interface to write file checkpoints - CheckpointStorage storage.ExternalStorage - // when CheckpointStorage is not nil, save file checkpoint to it with this name - CheckpointName string - // DupIndicator can expose the duplicate detection result to the caller - DupIndicator *atomic.Bool - // Keyspace name - KeyspaceName string - // ResourceGroup name for current TiDB user - ResourceGroupName string - // TaskType is the source component name use for background task control. - TaskType string -} - -// NewImportController creates a new Controller instance. -func NewImportController( - ctx context.Context, - cfg *config.Config, - param *ControllerParam, -) (*Controller, error) { - param.Pauser = DeliverPauser - return NewImportControllerWithPauser(ctx, cfg, param) -} - -// NewImportControllerWithPauser creates a new Controller instance with a pauser. 
-func NewImportControllerWithPauser( - ctx context.Context, - cfg *config.Config, - p *ControllerParam, -) (*Controller, error) { - tls, err := cfg.ToTLS() - if err != nil { - return nil, err - } - - var cpdb checkpoints.DB - // if CheckpointStorage is set, we should use given ExternalStorage to create checkpoints. - if p.CheckpointStorage != nil { - cpdb, err = checkpoints.NewFileCheckpointsDBWithExstorageFileName(ctx, p.CheckpointStorage.URI(), p.CheckpointStorage, p.CheckpointName) - if err != nil { - return nil, common.ErrOpenCheckpoint.Wrap(err).GenWithStackByArgs() - } - } else { - cpdb, err = checkpoints.OpenCheckpointsDB(ctx, cfg) - if err != nil { - if berrors.Is(err, common.ErrUnknownCheckpointDriver) { - return nil, err - } - return nil, common.ErrOpenCheckpoint.Wrap(err).GenWithStackByArgs() - } - } - - taskCp, err := cpdb.TaskCheckpoint(ctx) - if err != nil { - return nil, common.ErrReadCheckpoint.Wrap(err).GenWithStack("get task checkpoint failed") - } - if err := verifyCheckpoint(cfg, taskCp); err != nil { - return nil, errors.Trace(err) - } - // reuse task id to reuse task meta correctly. - if taskCp != nil { - cfg.TaskID = taskCp.TaskID - } - - db := p.DB - errorMgr := errormanager.New(db, cfg, log.FromContext(ctx)) - if err := errorMgr.Init(ctx); err != nil { - return nil, common.ErrInitErrManager.Wrap(err).GenWithStackByArgs() - } - - var encodingBuilder encode.EncodingBuilder - var backendObj backend.Backend - var pdCli pd.Client - var pdHTTPCli pdhttp.Client - switch cfg.TikvImporter.Backend { - case config.BackendTiDB: - encodingBuilder = tidb.NewEncodingBuilder() - backendObj = tidb.NewTiDBBackend(ctx, db, cfg, errorMgr) - case config.BackendLocal: - var rLimit local.RlimT - rLimit, err = local.GetSystemRLimit() - if err != nil { - return nil, err - } - maxOpenFiles := int(rLimit / local.RlimT(cfg.App.TableConcurrency)) - // check overflow - if maxOpenFiles < 0 { - maxOpenFiles = math.MaxInt32 - } - - addrs := strings.Split(cfg.TiDB.PdAddr, ",") - pdCli, err = pd.NewClientWithContext(ctx, addrs, tls.ToPDSecurityOption()) - if err != nil { - return nil, errors.Trace(err) - } - pdHTTPCli = pdhttp.NewClientWithServiceDiscovery( - "lightning", - pdCli.GetServiceDiscovery(), - pdhttp.WithTLSConfig(tls.TLSConfig()), - ).WithBackoffer(retry.InitialBackoffer(time.Second, time.Second, pdutil.PDRequestRetryTime*time.Second)) - - if isLocalBackend(cfg) && cfg.Conflict.Strategy != config.NoneOnDup { - if err := tikv.CheckTiKVVersion(ctx, pdHTTPCli, minTiKVVersionForConflictStrategy, maxTiKVVersionForConflictStrategy); err != nil { - if !berrors.Is(err, berrors.ErrVersionMismatch) { - return nil, common.ErrCheckKVVersion.Wrap(err).GenWithStackByArgs() - } - log.FromContext(ctx).Warn("TiKV version doesn't support conflict strategy. The resolution algorithm will fall back to 'none'", zap.Error(err)) - cfg.Conflict.Strategy = config.NoneOnDup - } - } - - initGlobalConfig(tls.ToTiKVSecurityConfig()) - - encodingBuilder = local.NewEncodingBuilder(ctx) - - // get resource group name. 
- exec := common.SQLWithRetry{ - DB: db, - Logger: log.FromContext(ctx), - } - if err := exec.QueryRow(ctx, "", "select current_resource_group();", &p.ResourceGroupName); err != nil { - if common.IsFunctionNotExistErr(err, "current_resource_group") { - log.FromContext(ctx).Warn("current_resource_group() not supported, ignore this error", zap.Error(err)) - } - } - - taskType, err := common.GetExplicitRequestSourceTypeFromDB(ctx, db) - if err != nil { - return nil, errors.Annotatef(err, "get system variable '%s' failed", variable.TiDBExplicitRequestSourceType) - } - if taskType == "" { - taskType = kvutil.ExplicitTypeLightning - } - p.TaskType = taskType - - // TODO: we should not need to check config here. - // Instead, we should perform the following during switch mode: - // 1. for each tikv, try to switch mode without any ranges. - // 2. if it returns normally, it means the store is using a raft-v1 engine. - // 3. if it returns the `partitioned-raft-kv only support switch mode with range set` error, - // it means the store is a raft-v2 engine and we will include the ranges from now on. - isRaftKV2, err := common.IsRaftKV2(ctx, db) - if err != nil { - log.FromContext(ctx).Warn("check isRaftKV2 failed", zap.Error(err)) - } - var raftKV2SwitchModeDuration time.Duration - if isRaftKV2 { - raftKV2SwitchModeDuration = cfg.Cron.SwitchMode.Duration - } - backendConfig := local.NewBackendConfig(cfg, maxOpenFiles, p.KeyspaceName, p.ResourceGroupName, p.TaskType, raftKV2SwitchModeDuration) - backendObj, err = local.NewBackend(ctx, tls, backendConfig, pdCli.GetServiceDiscovery()) - if err != nil { - return nil, common.NormalizeOrWrapErr(common.ErrUnknown, err) - } - err = verifyLocalFile(ctx, cpdb, cfg.TikvImporter.SortedKVDir) - if err != nil { - return nil, err - } - default: - return nil, common.ErrUnknownBackend.GenWithStackByArgs(cfg.TikvImporter.Backend) - } - p.Status.backend = cfg.TikvImporter.Backend - - var metaBuilder metaMgrBuilder - isSSTImport := cfg.TikvImporter.Backend == config.BackendLocal - switch { - case isSSTImport && cfg.TikvImporter.ParallelImport: - metaBuilder = &dbMetaMgrBuilder{ - db: db, - taskID: cfg.TaskID, - schema: cfg.App.MetaSchemaName, - needChecksum: cfg.PostRestore.Checksum != config.OpLevelOff, - } - case isSSTImport: - metaBuilder = singleMgrBuilder{ - taskID: cfg.TaskID, - } - default: - metaBuilder = noopMetaMgrBuilder{} - } - - var wrapper backend.TargetInfoGetter - if cfg.TikvImporter.Backend == config.BackendLocal { - wrapper = local.NewTargetInfoGetter(tls, db, pdHTTPCli) - } else { - wrapper = tidb.NewTargetInfoGetter(db) - } - ioWorkers := worker.NewPool(ctx, cfg.App.IOConcurrency, "io") - targetInfoGetter := &TargetInfoGetterImpl{ - cfg: cfg, - db: db, - backend: wrapper, - pdHTTPCli: pdHTTPCli, - } - preInfoGetter, err := NewPreImportInfoGetter( - cfg, - p.DBMetas, - p.DumpFileStorage, - targetInfoGetter, - ioWorkers, - encodingBuilder, - ) - if err != nil { - return nil, errors.Trace(err) - } - - preCheckBuilder := NewPrecheckItemBuilder( - cfg, p.DBMetas, preInfoGetter, cpdb, pdHTTPCli, - ) - - rc := &Controller{ - taskCtx: ctx, - cfg: cfg, - dbMetas: p.DBMetas, - tableWorkers: nil, - indexWorkers: nil, - regionWorkers: worker.NewPool(ctx, cfg.App.RegionConcurrency, "region"), - ioWorkers: ioWorkers, - checksumWorks: worker.NewPool(ctx, cfg.TiDB.ChecksumTableConcurrency, "checksum"), - pauser: p.Pauser, - engineMgr: backend.MakeEngineManager(backendObj), - backend: backendObj, - pdCli: pdCli, - pdHTTPCli: pdHTTPCli, - db: db, - sysVars: 
common.DefaultImportantVariables, - tls: tls, - checkTemplate: NewSimpleTemplate(), - - errorSummaries: makeErrorSummaries(log.FromContext(ctx)), - checkpointsDB: cpdb, - saveCpCh: make(chan saveCp), - closedEngineLimit: worker.NewPool(ctx, cfg.App.TableConcurrency*2, "closed-engine"), - // Currently, TiDB add index acceration doesn't support multiple tables simultaneously. - // So we use a single worker to ensure at most one table is adding index at the same time. - addIndexLimit: worker.NewPool(ctx, 1, "add-index"), - - store: p.DumpFileStorage, - ownStore: p.OwnExtStorage, - metaMgrBuilder: metaBuilder, - errorMgr: errorMgr, - status: p.Status, - taskMgr: nil, - dupIndicator: p.DupIndicator, - - preInfoGetter: preInfoGetter, - precheckItemBuilder: preCheckBuilder, - encBuilder: encodingBuilder, - tikvModeSwitcher: local.NewTiKVModeSwitcher(tls.TLSConfig(), pdHTTPCli, log.FromContext(ctx).Logger), - - keyspaceName: p.KeyspaceName, - resourceGroupName: p.ResourceGroupName, - taskType: p.TaskType, - } - - return rc, nil -} - -// Close closes the controller. -func (rc *Controller) Close() { - rc.backend.Close() - _ = rc.db.Close() - if rc.pdCli != nil { - rc.pdCli.Close() - } -} - -// Run starts the restore task. -func (rc *Controller) Run(ctx context.Context) error { - failpoint.Inject("beforeRun", func() {}) - - opts := []func(context.Context) error{ - rc.setGlobalVariables, - rc.restoreSchema, - rc.preCheckRequirements, - rc.initCheckpoint, - rc.importTables, - rc.fullCompact, - rc.cleanCheckpoints, - } - - task := log.FromContext(ctx).Begin(zap.InfoLevel, "the whole procedure") - - var err error - finished := false -outside: - for i, process := range opts { - err = process(ctx) - if i == len(opts)-1 { - finished = true - } - logger := task.With(zap.Int("step", i), log.ShortError(err)) - - switch { - case err == nil: - case log.IsContextCanceledError(err): - logger.Info("task canceled") - break outside - default: - logger.Error("run failed") - break outside // ps : not continue - } - } - - // if process is cancelled, should make sure checkpoints are written to db. - if !finished { - rc.waitCheckpointFinish() - } - - task.End(zap.ErrorLevel, err) - rc.errorMgr.LogErrorDetails() - rc.errorSummaries.emitLog() - - return errors.Trace(err) -} - -func (rc *Controller) restoreSchema(ctx context.Context) error { - // create table with schema file - // we can handle the duplicated created with createIfNotExist statement - // and we will check the schema in TiDB is valid with the datafile in DataCheck later. - logger := log.FromContext(ctx) - concurrency := min(rc.cfg.App.RegionConcurrency, 8) - // sql.DB is a connection pool, we set it to concurrency + 1(for job generator) - // to reuse connections, as we might call db.Conn/conn.Close many times. - // there's no API to get sql.DB.MaxIdleConns, so we revert to its default which is 2 - rc.db.SetMaxIdleConns(concurrency + 1) - defer rc.db.SetMaxIdleConns(2) - schemaImp := mydump.NewSchemaImporter(logger, rc.cfg.TiDB.SQLMode, rc.db, rc.store, concurrency) - err := schemaImp.Run(ctx, rc.dbMetas) - if err != nil { - return err - } - - dbInfos, err := rc.preInfoGetter.GetAllTableStructures(ctx) - if err != nil { - return errors.Trace(err) - } - // For local backend, we need DBInfo.ID to operate the global autoid allocator. 
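Controller.Run above drives the whole import as a fixed pipeline of steps (set global variables, restore schema, pre-checks, checkpoint init, table import, full compaction, checkpoint cleanup), stopping at the first error. A minimal, self-contained sketch of that control flow, with placeholder step functions rather than the real Controller methods:

    package main

    import (
        "context"
        "errors"
        "fmt"
    )

    // run executes steps in order and stops at the first error, mirroring the
    // shape of Controller.Run above; the step functions are stand-ins.
    func run(ctx context.Context, steps []func(context.Context) error) error {
        for i, step := range steps {
            if err := step(ctx); err != nil {
                return fmt.Errorf("step %d failed: %w", i, err)
            }
        }
        return nil
    }

    func main() {
        steps := []func(context.Context) error{
            func(context.Context) error { fmt.Println("restore schema"); return nil },
            func(context.Context) error { fmt.Println("pre-check"); return nil },
            func(context.Context) error { return errors.New("import failed") },
            func(context.Context) error { fmt.Println("never reached"); return nil },
        }
        if err := run(context.Background(), steps); err != nil {
            fmt.Println(err)
        }
    }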
- if isLocalBackend(rc.cfg) { - dbs, err := tikv.FetchRemoteDBModelsFromTLS(ctx, rc.tls) - if err != nil { - return errors.Trace(err) - } - dbIDs := make(map[string]int64) - for _, db := range dbs { - dbIDs[db.Name.L] = db.ID - } - for _, dbInfo := range dbInfos { - dbInfo.ID = dbIDs[strings.ToLower(dbInfo.Name)] - } - } - rc.dbInfos = dbInfos - rc.sysVars = rc.preInfoGetter.GetTargetSysVariablesForImport(ctx) - - return nil -} - -// initCheckpoint initializes all tables' checkpoint data -func (rc *Controller) initCheckpoint(ctx context.Context) error { - // Load new checkpoints - err := rc.checkpointsDB.Initialize(ctx, rc.cfg, rc.dbInfos) - if err != nil { - return common.ErrInitCheckpoint.Wrap(err).GenWithStackByArgs() - } - failpoint.Inject("InitializeCheckpointExit", func() { - log.FromContext(ctx).Warn("exit triggered", zap.String("failpoint", "InitializeCheckpointExit")) - os.Exit(0) - }) - if err := rc.loadDesiredTableInfos(ctx); err != nil { - return err - } - - rc.checkpointsWg.Add(1) // checkpointsWg will be done in `rc.listenCheckpointUpdates` - go rc.listenCheckpointUpdates(log.FromContext(ctx)) - - // Estimate the number of chunks for progress reporting - return rc.estimateChunkCountIntoMetrics(ctx) -} - -func (rc *Controller) loadDesiredTableInfos(ctx context.Context) error { - for _, dbInfo := range rc.dbInfos { - for _, tableInfo := range dbInfo.Tables { - cp, err := rc.checkpointsDB.Get(ctx, common.UniqueTable(dbInfo.Name, tableInfo.Name)) - if err != nil { - return err - } - // If checkpoint is disabled, cp.TableInfo will be nil. - // In this case, we just use current tableInfo as desired tableInfo. - if cp.TableInfo != nil { - tableInfo.Desired = cp.TableInfo - } - } - } - return nil -} - -// verifyCheckpoint check whether previous task checkpoint is compatible with task config -func verifyCheckpoint(cfg *config.Config, taskCp *checkpoints.TaskCheckpoint) error { - if taskCp == nil { - return nil - } - // always check the backend value even with 'check-requirements = false' - retryUsage := "destroy all checkpoints" - if cfg.Checkpoint.Driver == config.CheckpointDriverFile { - retryUsage = fmt.Sprintf("delete the file '%s'", cfg.Checkpoint.DSN) - } - retryUsage += " and remove all restored tables and try again" - - if cfg.TikvImporter.Backend != taskCp.Backend { - return common.ErrInvalidCheckpoint.GenWithStack("config 'tikv-importer.backend' value '%s' different from checkpoint value '%s', please %s", cfg.TikvImporter.Backend, taskCp.Backend, retryUsage) - } - - if cfg.App.CheckRequirements { - if build.ReleaseVersion != taskCp.LightningVer { - var displayVer string - if len(taskCp.LightningVer) != 0 { - displayVer = fmt.Sprintf("at '%s'", taskCp.LightningVer) - } else { - displayVer = "before v4.0.6/v3.0.19" - } - return common.ErrInvalidCheckpoint.GenWithStack("lightning version is '%s', but checkpoint was created %s, please %s", build.ReleaseVersion, displayVer, retryUsage) - } - - errorFmt := "config '%s' value '%s' different from checkpoint value '%s'. 
You may set 'check-requirements = false' to skip this check or " + retryUsage - if cfg.Mydumper.SourceDir != taskCp.SourceDir { - return common.ErrInvalidCheckpoint.GenWithStack(errorFmt, "mydumper.data-source-dir", cfg.Mydumper.SourceDir, taskCp.SourceDir) - } - - if cfg.TikvImporter.Backend == config.BackendLocal && cfg.TikvImporter.SortedKVDir != taskCp.SortedKVDir { - return common.ErrInvalidCheckpoint.GenWithStack(errorFmt, "mydumper.sorted-kv-dir", cfg.TikvImporter.SortedKVDir, taskCp.SortedKVDir) - } - } - - return nil -} - -// for local backend, we should check if local SST exists in disk, otherwise we'll lost data -func verifyLocalFile(ctx context.Context, cpdb checkpoints.DB, dir string) error { - targetTables, err := cpdb.GetLocalStoringTables(ctx) - if err != nil { - return errors.Trace(err) - } - for tableName, engineIDs := range targetTables { - for _, engineID := range engineIDs { - _, eID := backend.MakeUUID(tableName, int64(engineID)) - file := local.Engine{UUID: eID} - err := file.Exist(dir) - if err != nil { - log.FromContext(ctx).Error("can't find local file", - zap.String("table name", tableName), - zap.Int32("engine ID", engineID)) - if os.IsNotExist(err) { - err = common.ErrCheckLocalFile.GenWithStackByArgs(tableName, dir) - } else { - err = common.ErrCheckLocalFile.Wrap(err).GenWithStackByArgs(tableName, dir) - } - return err - } - } - } - return nil -} - -func (rc *Controller) estimateChunkCountIntoMetrics(ctx context.Context) error { - estimatedChunkCount := 0.0 - estimatedEngineCnt := int64(0) - for _, dbMeta := range rc.dbMetas { - for _, tableMeta := range dbMeta.Tables { - batchSize := mydump.CalculateBatchSize(float64(rc.cfg.Mydumper.BatchSize), - tableMeta.IsRowOrdered, float64(tableMeta.TotalSize)) - tableName := common.UniqueTable(dbMeta.Name, tableMeta.Name) - dbCp, err := rc.checkpointsDB.Get(ctx, tableName) - if err != nil { - return errors.Trace(err) - } - - fileChunks := make(map[string]float64) - for engineID, eCp := range dbCp.Engines { - if eCp.Status < checkpoints.CheckpointStatusImported { - estimatedEngineCnt++ - } - if engineID == common.IndexEngineID { - continue - } - for _, c := range eCp.Chunks { - if _, ok := fileChunks[c.Key.Path]; !ok { - fileChunks[c.Key.Path] = 0.0 - } - remainChunkCnt := float64(c.UnfinishedSize()) / float64(c.TotalSize()) - fileChunks[c.Key.Path] += remainChunkCnt - } - } - // estimate engines count if engine cp is empty - if len(dbCp.Engines) == 0 { - estimatedEngineCnt += ((tableMeta.TotalSize + int64(batchSize) - 1) / int64(batchSize)) + 1 - } - for _, fileMeta := range tableMeta.DataFiles { - if cnt, ok := fileChunks[fileMeta.FileMeta.Path]; ok { - estimatedChunkCount += cnt - continue - } - if fileMeta.FileMeta.Type == mydump.SourceTypeCSV { - cfg := rc.cfg.Mydumper - if fileMeta.FileMeta.FileSize > int64(cfg.MaxRegionSize) && cfg.StrictFormat && - !cfg.CSV.Header && fileMeta.FileMeta.Compression == mydump.CompressionNone { - estimatedChunkCount += math.Round(float64(fileMeta.FileMeta.FileSize) / float64(cfg.MaxRegionSize)) - } else { - estimatedChunkCount++ - } - } else { - estimatedChunkCount++ - } - } - } - } - if m, ok := metric.FromContext(ctx); ok { - m.ChunkCounter.WithLabelValues(metric.ChunkStateEstimated).Add(estimatedChunkCount) - m.ProcessedEngineCounter.WithLabelValues(metric.ChunkStateEstimated, metric.TableResultSuccess). 
- Add(float64(estimatedEngineCnt)) - } - return nil -} - -func firstErr(errors ...error) error { - for _, err := range errors { - if err != nil { - return err - } - } - return nil -} - -func (rc *Controller) saveStatusCheckpoint(ctx context.Context, tableName string, engineID int32, err error, statusIfSucceed checkpoints.CheckpointStatus) error { - merger := &checkpoints.StatusCheckpointMerger{Status: statusIfSucceed, EngineID: engineID} - - logger := log.FromContext(ctx).With(zap.String("table", tableName), zap.Int32("engine_id", engineID), - zap.String("new_status", statusIfSucceed.MetricName()), zap.Error(err)) - logger.Debug("update checkpoint") - - switch { - case err == nil: - case utils.MessageIsRetryableStorageError(err.Error()), common.IsContextCanceledError(err): - // recoverable error, should not be recorded in checkpoint - // which will prevent lightning from automatically recovering - return nil - default: - // unrecoverable error - merger.SetInvalid() - rc.errorSummaries.record(tableName, err, statusIfSucceed) - } - - if m, ok := metric.FromContext(ctx); ok { - if engineID == checkpoints.WholeTableEngineID { - m.RecordTableCount(statusIfSucceed.MetricName(), err) - } else { - m.RecordEngineCount(statusIfSucceed.MetricName(), err) - } - } - - waitCh := make(chan error, 1) - rc.saveCpCh <- saveCp{tableName: tableName, merger: merger, waitCh: waitCh} - - select { - case saveCpErr := <-waitCh: - if saveCpErr != nil { - logger.Error("failed to save status checkpoint", log.ShortError(saveCpErr)) - } - return saveCpErr - case <-ctx.Done(): - return ctx.Err() - } -} - -// listenCheckpointUpdates will combine several checkpoints together to reduce database load. -func (rc *Controller) listenCheckpointUpdates(logger log.Logger) { - var lock sync.Mutex - coalesed := make(map[string]*checkpoints.TableCheckpointDiff) - var waiters []chan<- error - - hasCheckpoint := make(chan struct{}, 1) - defer close(hasCheckpoint) - - go func() { - for range hasCheckpoint { - lock.Lock() - cpd := coalesed - coalesed = make(map[string]*checkpoints.TableCheckpointDiff) - ws := waiters - waiters = nil - lock.Unlock() - - //nolint:scopelint // This would be either INLINED or ERASED, at compile time. - failpoint.Inject("SlowDownCheckpointUpdate", func() {}) - - if len(cpd) > 0 { - err := rc.checkpointsDB.Update(rc.taskCtx, cpd) - for _, w := range ws { - w <- common.NormalizeOrWrapErr(common.ErrUpdateCheckpoint, err) - } - web.BroadcastCheckpointDiff(cpd) - } - rc.checkpointsWg.Done() - } - }() - - for scp := range rc.saveCpCh { - lock.Lock() - cpd, ok := coalesed[scp.tableName] - if !ok { - cpd = checkpoints.NewTableCheckpointDiff() - coalesed[scp.tableName] = cpd - } - scp.merger.MergeInto(cpd) - if scp.waitCh != nil { - waiters = append(waiters, scp.waitCh) - } - - if len(hasCheckpoint) == 0 { - rc.checkpointsWg.Add(1) - hasCheckpoint <- struct{}{} - } - - lock.Unlock() - - //nolint:scopelint // This would be either INLINED or ERASED, at compile time. - failpoint.Inject("FailIfImportedChunk", func() { - if merger, ok := scp.merger.(*checkpoints.ChunkCheckpointMerger); ok && merger.Pos >= merger.EndOffset { - rc.checkpointsWg.Done() - rc.checkpointsWg.Wait() - panic("forcing failure due to FailIfImportedChunk") - } - }) - - //nolint:scopelint // This would be either INLINED or ERASED, at compile time. 
- failpoint.Inject("FailIfStatusBecomes", func(val failpoint.Value) { - if merger, ok := scp.merger.(*checkpoints.StatusCheckpointMerger); ok && merger.EngineID >= 0 && int(merger.Status) == val.(int) { - rc.checkpointsWg.Done() - rc.checkpointsWg.Wait() - panic("forcing failure due to FailIfStatusBecomes") - } - }) - - //nolint:scopelint // This would be either INLINED or ERASED, at compile time. - failpoint.Inject("FailIfIndexEngineImported", func(val failpoint.Value) { - if merger, ok := scp.merger.(*checkpoints.StatusCheckpointMerger); ok && - merger.EngineID == checkpoints.WholeTableEngineID && - merger.Status == checkpoints.CheckpointStatusIndexImported && val.(int) > 0 { - rc.checkpointsWg.Done() - rc.checkpointsWg.Wait() - panic("forcing failure due to FailIfIndexEngineImported") - } - }) - - //nolint:scopelint // This would be either INLINED or ERASED, at compile time. - failpoint.Inject("KillIfImportedChunk", func() { - if merger, ok := scp.merger.(*checkpoints.ChunkCheckpointMerger); ok && merger.Pos >= merger.EndOffset { - rc.checkpointsWg.Done() - rc.checkpointsWg.Wait() - if err := common.KillMySelf(); err != nil { - logger.Warn("KillMySelf() failed to kill itself", log.ShortError(err)) - } - for scp := range rc.saveCpCh { - if scp.waitCh != nil { - scp.waitCh <- context.Canceled - } - } - failpoint.Return() - } - }) - } - // Don't put this statement in defer function at the beginning. failpoint function may call it manually. - rc.checkpointsWg.Done() -} - -// buildRunPeriodicActionAndCancelFunc build the runPeriodicAction func and a cancel func -func (rc *Controller) buildRunPeriodicActionAndCancelFunc(ctx context.Context, stop <-chan struct{}) (func(), func(bool)) { - cancelFuncs := make([]func(bool), 0) - closeFuncs := make([]func(), 0) - // a nil channel blocks forever. - // if the cron duration is zero we use the nil channel to skip the action. - var logProgressChan <-chan time.Time - if rc.cfg.Cron.LogProgress.Duration > 0 { - logProgressTicker := time.NewTicker(rc.cfg.Cron.LogProgress.Duration) - closeFuncs = append(closeFuncs, func() { - logProgressTicker.Stop() - }) - logProgressChan = logProgressTicker.C - } - - var switchModeChan <-chan time.Time - // tidb backend don't need to switch tikv to import mode - if isLocalBackend(rc.cfg) && rc.cfg.Cron.SwitchMode.Duration > 0 { - switchModeTicker := time.NewTicker(rc.cfg.Cron.SwitchMode.Duration) - cancelFuncs = append(cancelFuncs, func(bool) { switchModeTicker.Stop() }) - cancelFuncs = append(cancelFuncs, func(do bool) { - if do { - rc.tikvModeSwitcher.ToNormalMode(ctx) - } - }) - switchModeChan = switchModeTicker.C - } - - var checkQuotaChan <-chan time.Time - // only local storage has disk quota concern. 
- if rc.cfg.TikvImporter.Backend == config.BackendLocal && rc.cfg.Cron.CheckDiskQuota.Duration > 0 { - checkQuotaTicker := time.NewTicker(rc.cfg.Cron.CheckDiskQuota.Duration) - cancelFuncs = append(cancelFuncs, func(bool) { checkQuotaTicker.Stop() }) - checkQuotaChan = checkQuotaTicker.C - } - - return func() { - defer func() { - for _, f := range closeFuncs { - f() - } - }() - if rc.cfg.Cron.SwitchMode.Duration > 0 && isLocalBackend(rc.cfg) { - rc.tikvModeSwitcher.ToImportMode(ctx) - } - start := time.Now() - for { - select { - case <-ctx.Done(): - log.FromContext(ctx).Warn("stopping periodic actions", log.ShortError(ctx.Err())) - return - case <-stop: - log.FromContext(ctx).Info("everything imported, stopping periodic actions") - return - - case <-switchModeChan: - // periodically switch to import mode, as requested by TiKV 3.0 - // TiKV will switch back to normal mode if we didn't call this again within 10 minutes - rc.tikvModeSwitcher.ToImportMode(ctx) - - case <-logProgressChan: - metrics, ok := metric.FromContext(ctx) - if !ok { - log.FromContext(ctx).Warn("couldn't find metrics from context, skip log progress") - continue - } - // log the current progress periodically, so OPS will know that we're still working - nanoseconds := float64(time.Since(start).Nanoseconds()) - totalRestoreBytes := metric.ReadCounter(metrics.BytesCounter.WithLabelValues(metric.StateTotalRestore)) - restoredBytes := metric.ReadCounter(metrics.BytesCounter.WithLabelValues(metric.StateRestored)) - totalRowsToRestore := metric.ReadAllCounters(metrics.RowsCounter.MetricVec, prometheus.Labels{"state": metric.StateTotalRestore}) - restoredRows := metric.ReadAllCounters(metrics.RowsCounter.MetricVec, prometheus.Labels{"state": metric.StateRestored}) - // the estimated chunk is not accurate(likely under estimated), but the actual count is not accurate - // before the last table start, so use the bigger of the two should be a workaround - estimated := metric.ReadCounter(metrics.ChunkCounter.WithLabelValues(metric.ChunkStateEstimated)) - pending := metric.ReadCounter(metrics.ChunkCounter.WithLabelValues(metric.ChunkStatePending)) - if estimated < pending { - estimated = pending - } - finished := metric.ReadCounter(metrics.ChunkCounter.WithLabelValues(metric.ChunkStateFinished)) - totalTables := metric.ReadCounter(metrics.TableCounter.WithLabelValues(metric.TableStatePending, metric.TableResultSuccess)) - completedTables := metric.ReadCounter(metrics.TableCounter.WithLabelValues(metric.TableStateCompleted, metric.TableResultSuccess)) - bytesRead := metric.ReadHistogramSum(metrics.RowReadBytesHistogram) - engineEstimated := metric.ReadCounter(metrics.ProcessedEngineCounter.WithLabelValues(metric.ChunkStateEstimated, metric.TableResultSuccess)) - enginePending := metric.ReadCounter(metrics.ProcessedEngineCounter.WithLabelValues(metric.ChunkStatePending, metric.TableResultSuccess)) - if engineEstimated < enginePending { - engineEstimated = enginePending - } - engineFinished := metric.ReadCounter(metrics.ProcessedEngineCounter.WithLabelValues(metric.TableStateImported, metric.TableResultSuccess)) - bytesWritten := metric.ReadCounter(metrics.BytesCounter.WithLabelValues(metric.StateRestoreWritten)) - bytesImported := metric.ReadCounter(metrics.BytesCounter.WithLabelValues(metric.StateImported)) - - var state string - var remaining zap.Field - switch { - case finished >= estimated: - if engineFinished < engineEstimated { - state = "importing" - } else { - state = "post-processing" - } - case finished > 0: - state = "writing" 
- default: - state = "preparing" - } - - // lightning restore is separated into restore engine and import engine, they are both parallelized - // and pipelined between engines, so we can only weight the progress of those 2 phase to get the - // total progress. - // - // for local & importer backend: - // in most case import engine is faster since there's little computations, but inside one engine - // restore and import is serialized, the progress of those two will not differ too much, and - // import engine determines the end time of the whole restore, so we average them for now. - // the result progress may fall behind the real progress if import is faster. - // - // for tidb backend, we do nothing during import engine, so we use restore engine progress as the - // total progress. - restoreBytesField := zap.Skip() - importBytesField := zap.Skip() - restoreRowsField := zap.Skip() - remaining = zap.Skip() - totalPercent := 0.0 - if restoredBytes > 0 || restoredRows > 0 { - var restorePercent float64 - if totalRowsToRestore > 0 { - restorePercent = math.Min(restoredRows/totalRowsToRestore, 1.0) - restoreRowsField = zap.String("restore-rows", fmt.Sprintf("%.0f/%.0f", - restoredRows, totalRowsToRestore)) - } else { - restorePercent = math.Min(restoredBytes/totalRestoreBytes, 1.0) - restoreRowsField = zap.String("restore-rows", fmt.Sprintf("%.0f/%.0f(estimated)", - restoredRows, restoredRows/restorePercent)) - } - metrics.ProgressGauge.WithLabelValues(metric.ProgressPhaseRestore).Set(restorePercent) - if rc.cfg.TikvImporter.Backend != config.BackendTiDB { - var importPercent float64 - if bytesWritten > 0 { - // estimate total import bytes from written bytes - // when importPercent = 1, totalImportBytes = bytesWritten, but there's case - // bytesImported may bigger or smaller than bytesWritten such as when deduplicate - // we calculate progress using engines then use the bigger one in case bytesImported is - // smaller. - totalImportBytes := bytesWritten / restorePercent - biggerPercent := math.Max(bytesImported/totalImportBytes, engineFinished/engineEstimated) - importPercent = math.Min(biggerPercent, 1.0) - importBytesField = zap.String("import-bytes", fmt.Sprintf("%s/%s(estimated)", - units.BytesSize(bytesImported), units.BytesSize(totalImportBytes))) - } - metrics.ProgressGauge.WithLabelValues(metric.ProgressPhaseImport).Set(importPercent) - totalPercent = (restorePercent + importPercent) / 2 - } else { - totalPercent = restorePercent - } - if totalPercent < 1.0 { - remainNanoseconds := (1.0 - totalPercent) / totalPercent * nanoseconds - remaining = zap.Duration("remaining", time.Duration(remainNanoseconds).Round(time.Second)) - } - restoreBytesField = zap.String("restore-bytes", fmt.Sprintf("%s/%s", - units.BytesSize(restoredBytes), units.BytesSize(totalRestoreBytes))) - } - metrics.ProgressGauge.WithLabelValues(metric.ProgressPhaseTotal).Set(totalPercent) - - formatPercent := func(num, denom float64) string { - if denom > 0 { - return fmt.Sprintf(" (%.1f%%)", num/denom*100) - } - return "" - } - - // avoid output bytes speed if there are no unfinished chunks - encodeSpeedField := zap.Skip() - if bytesRead > 0 { - encodeSpeedField = zap.Float64("encode speed(MiB/s)", bytesRead/(1048576e-9*nanoseconds)) - } - - // Note: a speed of 28 MiB/s roughly corresponds to 100 GiB/hour. 
- log.FromContext(ctx).Info("progress", - zap.String("total", fmt.Sprintf("%.1f%%", totalPercent*100)), - // zap.String("files", fmt.Sprintf("%.0f/%.0f (%.1f%%)", finished, estimated, finished/estimated*100)), - zap.String("tables", fmt.Sprintf("%.0f/%.0f%s", completedTables, totalTables, formatPercent(completedTables, totalTables))), - zap.String("chunks", fmt.Sprintf("%.0f/%.0f%s", finished, estimated, formatPercent(finished, estimated))), - zap.String("engines", fmt.Sprintf("%.f/%.f%s", engineFinished, engineEstimated, formatPercent(engineFinished, engineEstimated))), - restoreBytesField, restoreRowsField, importBytesField, - encodeSpeedField, - zap.String("state", state), - remaining, - ) - - case <-checkQuotaChan: - // verify the total space occupied by sorted-kv-dir is below the quota, - // otherwise we perform an emergency import. - rc.enforceDiskQuota(ctx) - } - } - }, func(do bool) { - log.FromContext(ctx).Info("cancel periodic actions", zap.Bool("do", do)) - for _, f := range cancelFuncs { - f(do) - } - } -} - -func (rc *Controller) buildTablesRanges() []tidbkv.KeyRange { - var keyRanges []tidbkv.KeyRange - for _, dbInfo := range rc.dbInfos { - for _, tableInfo := range dbInfo.Tables { - if ranges, err := distsql.BuildTableRanges(tableInfo.Core); err == nil { - keyRanges = append(keyRanges, ranges...) - } - } - } - return keyRanges -} - -type checksumManagerKeyType struct{} - -var checksumManagerKey checksumManagerKeyType - -const ( - pauseGCTTLForDupeRes = time.Hour - pauseGCIntervalForDupeRes = time.Minute -) - -func (rc *Controller) keepPauseGCForDupeRes(ctx context.Context) (<-chan struct{}, error) { - tlsOpt := rc.tls.ToPDSecurityOption() - addrs := strings.Split(rc.cfg.TiDB.PdAddr, ",") - pdCli, err := pd.NewClientWithContext(ctx, addrs, tlsOpt) - if err != nil { - return nil, errors.Trace(err) - } - - serviceID := "lightning-duplicate-resolution-" + uuid.New().String() - ttl := int64(pauseGCTTLForDupeRes / time.Second) - - var ( - safePoint uint64 - paused bool - ) - // Try to get the minimum safe point across all services as our GC safe point. 
- for i := 0; i < 10; i++ { - if i > 0 { - time.Sleep(time.Second * 3) - } - minSafePoint, err := pdCli.UpdateServiceGCSafePoint(ctx, serviceID, ttl, 1) - if err != nil { - pdCli.Close() - return nil, errors.Trace(err) - } - newMinSafePoint, err := pdCli.UpdateServiceGCSafePoint(ctx, serviceID, ttl, minSafePoint) - if err != nil { - pdCli.Close() - return nil, errors.Trace(err) - } - if newMinSafePoint <= minSafePoint { - safePoint = minSafePoint - paused = true - break - } - log.FromContext(ctx).Warn( - "Failed to register GC safe point because the current minimum safe point is newer"+ - " than what we assume, will retry newMinSafePoint next time", - zap.Uint64("minSafePoint", minSafePoint), - zap.Uint64("newMinSafePoint", newMinSafePoint), - ) - } - if !paused { - pdCli.Close() - return nil, common.ErrPauseGC.GenWithStack("failed to pause GC for duplicate resolution after all retries") - } - - exitCh := make(chan struct{}) - go func(safePoint uint64) { - defer pdCli.Close() - defer close(exitCh) - ticker := time.NewTicker(pauseGCIntervalForDupeRes) - defer ticker.Stop() - for { - select { - case <-ticker.C: - minSafePoint, err := pdCli.UpdateServiceGCSafePoint(ctx, serviceID, ttl, safePoint) - if err != nil { - log.FromContext(ctx).Warn("Failed to register GC safe point", zap.Error(err)) - continue - } - if minSafePoint > safePoint { - log.FromContext(ctx).Warn("The current minimum safe point is newer than what we hold, duplicate records are at"+ - "risk of being GC and not detectable", - zap.Uint64("safePoint", safePoint), - zap.Uint64("minSafePoint", minSafePoint), - ) - safePoint = minSafePoint - } - case <-ctx.Done(): - stopCtx, cancelFunc := context.WithTimeout(context.Background(), time.Second*5) - if _, err := pdCli.UpdateServiceGCSafePoint(stopCtx, serviceID, 0, safePoint); err != nil { - log.FromContext(ctx).Warn("Failed to reset safe point ttl to zero", zap.Error(err)) - } - // just make compiler happy - cancelFunc() - return - } - } - }(safePoint) - return exitCh, nil -} - -func (rc *Controller) importTables(ctx context.Context) (finalErr error) { - // output error summary - defer rc.outputErrorSummary() - - if isLocalBackend(rc.cfg) && rc.cfg.Conflict.Strategy != config.NoneOnDup { - subCtx, cancel := context.WithCancel(ctx) - exitCh, err := rc.keepPauseGCForDupeRes(subCtx) - if err != nil { - cancel() - return errors.Trace(err) - } - defer func() { - cancel() - <-exitCh - }() - } - - logTask := log.FromContext(ctx).Begin(zap.InfoLevel, "restore all tables data") - if rc.tableWorkers == nil { - rc.tableWorkers = worker.NewPool(ctx, rc.cfg.App.TableConcurrency, "table") - } - if rc.indexWorkers == nil { - rc.indexWorkers = worker.NewPool(ctx, rc.cfg.App.IndexConcurrency, "index") - } - - // for local backend, we should disable some pd scheduler and change some settings, to - // make split region and ingest sst more stable - // because importer backend is mostly use for v3.x cluster which doesn't support these api, - // so we also don't do this for import backend - finishSchedulers := func() { - if rc.taskMgr != nil { - rc.taskMgr.Close() - } - } - // if one lightning failed abnormally, and can't determine whether it needs to switch back, - // we do not do switch back automatically - switchBack := false - cleanup := false - postProgress := func() error { return nil } - var kvStore tidbkv.Storage - var etcdCli *clientv3.Client - - if isLocalBackend(rc.cfg) { - var ( - restoreFn pdutil.UndoFunc - err error - ) - - if rc.cfg.TikvImporter.PausePDSchedulerScope == 
config.PausePDSchedulerScopeGlobal { - logTask.Info("pause pd scheduler of global scope") - - restoreFn, err = rc.taskMgr.CheckAndPausePdSchedulers(ctx) - if err != nil { - return errors.Trace(err) - } - } - - finishSchedulers = func() { - taskFinished := finalErr == nil - // use context.Background to make sure this restore function can still be executed even if ctx is canceled - restoreCtx := context.Background() - needSwitchBack, needCleanup, err := rc.taskMgr.CheckAndFinishRestore(restoreCtx, taskFinished) - if err != nil { - logTask.Warn("check restore pd schedulers failed", zap.Error(err)) - return - } - switchBack = needSwitchBack - cleanup = needCleanup - - if needSwitchBack && restoreFn != nil { - logTask.Info("add back PD leader®ion schedulers") - if restoreE := restoreFn(restoreCtx); restoreE != nil { - logTask.Warn("failed to restore removed schedulers, you may need to restore them manually", zap.Error(restoreE)) - } - } - - if rc.taskMgr != nil { - rc.taskMgr.Close() - } - } - - // Disable GC because TiDB enables GC already. - urlsWithScheme := rc.pdCli.GetServiceDiscovery().GetServiceURLs() - // remove URL scheme - urlsWithoutScheme := make([]string, 0, len(urlsWithScheme)) - for _, u := range urlsWithScheme { - u = strings.TrimPrefix(u, "http://") - u = strings.TrimPrefix(u, "https://") - urlsWithoutScheme = append(urlsWithoutScheme, u) - } - kvStore, err = driver.TiKVDriver{}.OpenWithOptions( - fmt.Sprintf( - "tikv://%s?disableGC=true&keyspaceName=%s", - strings.Join(urlsWithoutScheme, ","), rc.keyspaceName, - ), - driver.WithSecurity(rc.tls.ToTiKVSecurityConfig()), - ) - if err != nil { - return errors.Trace(err) - } - etcdCli, err := clientv3.New(clientv3.Config{ - Endpoints: urlsWithScheme, - AutoSyncInterval: 30 * time.Second, - TLS: rc.tls.TLSConfig(), - }) - if err != nil { - return errors.Trace(err) - } - etcd.SetEtcdCliByNamespace(etcdCli, keyspace.MakeKeyspaceEtcdNamespace(kvStore.GetCodec())) - - manager, err := NewChecksumManager(ctx, rc, kvStore) - if err != nil { - return errors.Trace(err) - } - ctx = context.WithValue(ctx, &checksumManagerKey, manager) - - undo, err := rc.registerTaskToPD(ctx) - if err != nil { - return errors.Trace(err) - } - defer undo() - } - - type task struct { - tr *TableImporter - cp *checkpoints.TableCheckpoint - } - - totalTables := 0 - for _, dbMeta := range rc.dbMetas { - totalTables += len(dbMeta.Tables) - } - postProcessTaskChan := make(chan task, totalTables) - - var wg sync.WaitGroup - var restoreErr common.OnceError - - stopPeriodicActions := make(chan struct{}) - - periodicActions, cancelFunc := rc.buildRunPeriodicActionAndCancelFunc(ctx, stopPeriodicActions) - go periodicActions() - - defer close(stopPeriodicActions) - - defer func() { - finishSchedulers() - cancelFunc(switchBack) - - if err := postProgress(); err != nil { - logTask.End(zap.ErrorLevel, err) - finalErr = err - return - } - logTask.End(zap.ErrorLevel, nil) - // clean up task metas - if cleanup { - logTask.Info("cleanup task metas") - if cleanupErr := rc.taskMgr.Cleanup(context.Background()); cleanupErr != nil { - logTask.Warn("failed to clean task metas, you may need to restore them manually", zap.Error(cleanupErr)) - } - // cleanup table meta and schema db if needed. 
- if err := rc.taskMgr.CleanupAllMetas(context.Background()); err != nil { - logTask.Warn("failed to clean table task metas, you may need to restore them manually", zap.Error(err)) - } - } - if kvStore != nil { - if err := kvStore.Close(); err != nil { - logTask.Warn("failed to close kv store", zap.Error(err)) - } - } - if etcdCli != nil { - if err := etcdCli.Close(); err != nil { - logTask.Warn("failed to close etcd client", zap.Error(err)) - } - } - }() - - taskCh := make(chan task, rc.cfg.App.IndexConcurrency) - defer close(taskCh) - - for i := 0; i < rc.cfg.App.IndexConcurrency; i++ { - go func() { - for task := range taskCh { - tableLogTask := task.tr.logger.Begin(zap.InfoLevel, "restore table") - web.BroadcastTableCheckpoint(task.tr.tableName, task.cp) - - needPostProcess, err := task.tr.importTable(ctx, rc, task.cp) - if err != nil && !common.IsContextCanceledError(err) { - task.tr.logger.Error("failed to import table", zap.Error(err)) - } - - err = common.NormalizeOrWrapErr(common.ErrRestoreTable, err, task.tr.tableName) - tableLogTask.End(zap.ErrorLevel, err) - web.BroadcastError(task.tr.tableName, err) - if m, ok := metric.FromContext(ctx); ok { - m.RecordTableCount(metric.TableStateCompleted, err) - } - restoreErr.Set(err) - if needPostProcess { - postProcessTaskChan <- task - } - wg.Done() - } - }() - } - - var allTasks []task - var totalDataSizeToRestore int64 - for _, dbMeta := range rc.dbMetas { - dbInfo := rc.dbInfos[dbMeta.Name] - for _, tableMeta := range dbMeta.Tables { - tableInfo := dbInfo.Tables[tableMeta.Name] - tableName := common.UniqueTable(dbInfo.Name, tableInfo.Name) - cp, err := rc.checkpointsDB.Get(ctx, tableName) - if err != nil { - return errors.Trace(err) - } - if cp.Status < checkpoints.CheckpointStatusAllWritten && len(tableMeta.DataFiles) == 0 { - continue - } - igCols, err := rc.cfg.Mydumper.IgnoreColumns.GetIgnoreColumns(dbInfo.Name, tableInfo.Name, rc.cfg.Mydumper.CaseSensitive) - if err != nil { - return errors.Trace(err) - } - tr, err := NewTableImporter(tableName, tableMeta, dbInfo, tableInfo, cp, igCols.ColumnsMap(), kvStore, etcdCli, log.FromContext(ctx)) - if err != nil { - return errors.Trace(err) - } - - allTasks = append(allTasks, task{tr: tr, cp: cp}) - - if len(cp.Engines) == 0 { - for i, fi := range tableMeta.DataFiles { - totalDataSizeToRestore += fi.FileMeta.FileSize - if fi.FileMeta.Type == mydump.SourceTypeParquet { - numberRows, err := mydump.ReadParquetFileRowCountByFile(ctx, rc.store, fi.FileMeta) - if err != nil { - return errors.Trace(err) - } - if m, ok := metric.FromContext(ctx); ok { - m.RowsCounter.WithLabelValues(metric.StateTotalRestore, tableName).Add(float64(numberRows)) - } - fi.FileMeta.Rows = numberRows - tableMeta.DataFiles[i] = fi - } - } - } else { - for _, eng := range cp.Engines { - for _, chunk := range eng.Chunks { - // for parquet files filesize is more accurate, we can calculate correct unfinished bytes unless - // we set up the reader, so we directly use filesize here - if chunk.FileMeta.Type == mydump.SourceTypeParquet { - totalDataSizeToRestore += chunk.FileMeta.FileSize - if m, ok := metric.FromContext(ctx); ok { - m.RowsCounter.WithLabelValues(metric.StateTotalRestore, tableName).Add(float64(chunk.UnfinishedSize())) - } - } else { - totalDataSizeToRestore += chunk.UnfinishedSize() - } - } - } - } - } - } - - if m, ok := metric.FromContext(ctx); ok { - m.BytesCounter.WithLabelValues(metric.StateTotalRestore).Add(float64(totalDataSizeToRestore)) - } - - for i := range allTasks { - wg.Add(1) - select { - case 
taskCh <- allTasks[i]: - case <-ctx.Done(): - return ctx.Err() - } - } - - wg.Wait() - // if context is done, should return directly - select { - case <-ctx.Done(): - err := restoreErr.Get() - if err == nil { - err = ctx.Err() - } - logTask.End(zap.ErrorLevel, err) - return err - default: - } - - postProgress = func() error { - close(postProcessTaskChan) - // otherwise, we should run all tasks in the post-process task chan - for i := 0; i < rc.cfg.App.TableConcurrency; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for task := range postProcessTaskChan { - metaMgr := rc.metaMgrBuilder.TableMetaMgr(task.tr) - // force all the remain post-process tasks to be executed - _, err2 := task.tr.postProcess(ctx, rc, task.cp, true, metaMgr) - restoreErr.Set(err2) - } - }() - } - wg.Wait() - return restoreErr.Get() - } - - return nil -} - -func (rc *Controller) registerTaskToPD(ctx context.Context) (undo func(), _ error) { - etcdCli, err := dialEtcdWithCfg(ctx, rc.cfg, rc.pdCli.GetServiceDiscovery().GetServiceURLs()) - if err != nil { - return nil, errors.Trace(err) - } - - register := utils.NewTaskRegister(etcdCli, utils.RegisterLightning, fmt.Sprintf("lightning-%s", uuid.New())) - - undo = func() { - closeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - if err := register.Close(closeCtx); err != nil { - log.L().Warn("failed to unregister task", zap.Error(err)) - } - if err := etcdCli.Close(); err != nil { - log.L().Warn("failed to close etcd client", zap.Error(err)) - } - } - if err := register.RegisterTask(ctx); err != nil { - undo() - return nil, errors.Trace(err) - } - return undo, nil -} - -func addExtendDataForCheckpoint( - ctx context.Context, - cfg *config.Config, - cp *checkpoints.TableCheckpoint, -) error { - if len(cfg.Routes) == 0 { - return nil - } - hasExtractor := false - for _, route := range cfg.Routes { - hasExtractor = hasExtractor || route.TableExtractor != nil || route.SchemaExtractor != nil || route.SourceExtractor != nil - if hasExtractor { - break - } - } - if !hasExtractor { - return nil - } - - // Use default file router directly because fileRouter and router are not compatible - fileRouter, err := mydump.NewDefaultFileRouter(log.FromContext(ctx)) - if err != nil { - return err - } - var router *regexprrouter.RouteTable - router, err = regexprrouter.NewRegExprRouter(cfg.Mydumper.CaseSensitive, cfg.Routes) - if err != nil { - return err - } - for _, engine := range cp.Engines { - for _, chunk := range engine.Chunks { - _, file := filepath.Split(chunk.FileMeta.Path) - var res *mydump.RouteResult - res, err = fileRouter.Route(file) - if err != nil { - return err - } - extendCols, extendData := router.FetchExtendColumn(res.Schema, res.Name, cfg.Mydumper.SourceID) - chunk.FileMeta.ExtendData = mydump.ExtendColumnData{ - Columns: extendCols, - Values: extendData, - } - } - } - return nil -} - -func (rc *Controller) outputErrorSummary() { - if rc.errorMgr.HasError() { - fmt.Println(rc.errorMgr.Output()) - } -} - -// do full compaction for the whole data. -func (rc *Controller) fullCompact(ctx context.Context) error { - if !rc.cfg.PostRestore.Compact { - log.FromContext(ctx).Info("skip full compaction") - return nil - } - - // wait until any existing level-1 compact to complete first. 
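fullCompact above serializes compactions by spinning on an atomic compare-and-swap until the state flips from idle to doing, and enforceDiskQuota gates its quota check the same way. A minimal sketch of that pattern with stand-in state constants (the concrete values are assumptions; only the CAS gating matters):

    package main

    import (
        "fmt"
        "sync/atomic"
        "time"
    )

    // Illustrative states mirroring compactStateIdle/compactStateDoing above.
    const (
        stateIdle int32 = iota
        stateDoing
    )

    var compactState atomic.Int32

    // acquire spins until it flips the state from idle to doing, so at most one
    // compaction runs at a time and callers wait for an in-flight one to finish.
    func acquire() {
        for !compactState.CompareAndSwap(stateIdle, stateDoing) {
            time.Sleep(100 * time.Millisecond)
        }
    }

    func release() { compactState.Store(stateIdle) }

    func main() {
        acquire()
        fmt.Println("running exclusive compaction")
        release()
    }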
- task := log.FromContext(ctx).Begin(zap.InfoLevel, "wait for completion of existing level 1 compaction") - for !rc.compactState.CompareAndSwap(compactStateIdle, compactStateDoing) { - time.Sleep(100 * time.Millisecond) - } - task.End(zap.ErrorLevel, nil) - - return errors.Trace(rc.doCompact(ctx, FullLevelCompact)) -} - -func (rc *Controller) doCompact(ctx context.Context, level int32) error { - return tikv.ForAllStores( - ctx, - rc.pdHTTPCli, - metapb.StoreState_Offline, - func(c context.Context, store *pdhttp.MetaStore) error { - return tikv.Compact(c, rc.tls, store.Address, level, rc.resourceGroupName) - }, - ) -} - -func (rc *Controller) enforceDiskQuota(ctx context.Context) { - if !rc.diskQuotaState.CompareAndSwap(diskQuotaStateIdle, diskQuotaStateChecking) { - // do not run multiple the disk quota check / import simultaneously. - // (we execute the lock check in background to avoid blocking the cron thread) - return - } - - localBackend := rc.backend.(*local.Backend) - go func() { - // locker is assigned when we detect the disk quota is exceeded. - // before the disk quota is confirmed exceeded, we keep the diskQuotaLock - // unlocked to avoid periodically interrupting the writer threads. - var locker sync.Locker - defer func() { - rc.diskQuotaState.Store(diskQuotaStateIdle) - if locker != nil { - locker.Unlock() - } - }() - - isRetrying := false - - for { - // sleep for a cycle if we are retrying because there is nothing new to import. - if isRetrying { - select { - case <-ctx.Done(): - return - case <-time.After(rc.cfg.Cron.CheckDiskQuota.Duration): - } - } else { - isRetrying = true - } - - quota := int64(rc.cfg.TikvImporter.DiskQuota) - largeEngines, inProgressLargeEngines, totalDiskSize, totalMemSize := local.CheckDiskQuota(localBackend, quota) - if m, ok := metric.FromContext(ctx); ok { - m.LocalStorageUsageBytesGauge.WithLabelValues("disk").Set(float64(totalDiskSize)) - m.LocalStorageUsageBytesGauge.WithLabelValues("mem").Set(float64(totalMemSize)) - } - - logger := log.FromContext(ctx).With( - zap.Int64("diskSize", totalDiskSize), - zap.Int64("memSize", totalMemSize), - zap.Int64("quota", quota), - zap.Int("largeEnginesCount", len(largeEngines)), - zap.Int("inProgressLargeEnginesCount", inProgressLargeEngines)) - - if len(largeEngines) == 0 && inProgressLargeEngines == 0 { - logger.Debug("disk quota respected") - return - } - - if locker == nil { - // blocks all writers when we detected disk quota being exceeded. - rc.diskQuotaLock.Lock() - locker = &rc.diskQuotaLock - } - - logger.Warn("disk quota exceeded") - if len(largeEngines) == 0 { - logger.Warn("all large engines are already importing, keep blocking all writes") - continue - } - - // flush all engines so that checkpoints can be updated. - if err := rc.backend.FlushAllEngines(ctx); err != nil { - logger.Error("flush engine for disk quota failed, check again later", log.ShortError(err)) - return - } - - // at this point, all engines are synchronized on disk. - // we then import every large engines one by one and complete. - // if any engine failed to import, we just try again next time, since the data are still intact. - rc.diskQuotaState.Store(diskQuotaStateImporting) - task := logger.Begin(zap.WarnLevel, "importing large engines for disk quota") - var importErr error - for _, engine := range largeEngines { - // Use a larger split region size to avoid split the same region by many times. 
- if err := localBackend.UnsafeImportAndReset(ctx, engine, int64(config.SplitRegionSize)*int64(config.MaxSplitRegionSizeRatio), int64(config.SplitRegionKeys)*int64(config.MaxSplitRegionSizeRatio)); err != nil { - importErr = multierr.Append(importErr, err) - } - } - task.End(zap.ErrorLevel, importErr) - return - } - }() -} - -func (rc *Controller) setGlobalVariables(ctx context.Context) error { - // skip for tidb backend to be compatible with MySQL - if isTiDBBackend(rc.cfg) { - return nil - } - // set new collation flag base on tidb config - enabled, err := ObtainNewCollationEnabled(ctx, rc.db) - if err != nil { - return err - } - // we should enable/disable new collation here since in server mode, tidb config - // may be different in different tasks - collate.SetNewCollationEnabledForTest(enabled) - log.FromContext(ctx).Info(session.TidbNewCollationEnabled, zap.Bool("enabled", enabled)) - - return nil -} - -func (rc *Controller) waitCheckpointFinish() { - // wait checkpoint process finish so that we can do cleanup safely - close(rc.saveCpCh) - rc.checkpointsWg.Wait() -} - -func (rc *Controller) cleanCheckpoints(ctx context.Context) error { - rc.waitCheckpointFinish() - - if !rc.cfg.Checkpoint.Enable { - return nil - } - - logger := log.FromContext(ctx).With( - zap.Stringer("keepAfterSuccess", rc.cfg.Checkpoint.KeepAfterSuccess), - zap.Int64("taskID", rc.cfg.TaskID), - ) - - task := logger.Begin(zap.InfoLevel, "clean checkpoints") - var err error - switch rc.cfg.Checkpoint.KeepAfterSuccess { - case config.CheckpointRename: - err = rc.checkpointsDB.MoveCheckpoints(ctx, rc.cfg.TaskID) - case config.CheckpointRemove: - err = rc.checkpointsDB.RemoveCheckpoint(ctx, "all") - } - task.End(zap.ErrorLevel, err) - if err != nil { - return common.ErrCleanCheckpoint.Wrap(err).GenWithStackByArgs() - } - return nil -} - -func isLocalBackend(cfg *config.Config) bool { - return cfg.TikvImporter.Backend == config.BackendLocal -} - -func isTiDBBackend(cfg *config.Config) bool { - return cfg.TikvImporter.Backend == config.BackendTiDB -} - -// preCheckRequirements checks -// 1. Cluster resource -// 2. Local node resource -// 3. Cluster region -// 4. Lightning configuration -// before restore tables start. -func (rc *Controller) preCheckRequirements(ctx context.Context) error { - if err := rc.DataCheck(ctx); err != nil { - return errors.Trace(err) - } - - if rc.cfg.App.CheckRequirements { - if err := rc.ClusterIsAvailable(ctx); err != nil { - return errors.Trace(err) - } - - if rc.ownStore { - if err := rc.StoragePermission(ctx); err != nil { - return errors.Trace(err) - } - } - } - - if err := rc.metaMgrBuilder.Init(ctx); err != nil { - return common.ErrInitMetaManager.Wrap(err).GenWithStackByArgs() - } - taskExist := false - - // We still need to sample source data even if this task has existed, because we need to judge whether the - // source is in order as row key to decide how to sort local data. - estimatedSizeResult, err := rc.preInfoGetter.EstimateSourceDataSize(ctx) - if err != nil { - return common.ErrCheckDataSource.Wrap(err).GenWithStackByArgs() - } - estimatedDataSizeWithIndex := estimatedSizeResult.SizeWithIndex - estimatedTiflashDataSize := estimatedSizeResult.TiFlashSize - - // Do not import with too large concurrency because these data may be all unsorted. 
- if estimatedSizeResult.HasUnsortedBigTables { - if rc.cfg.App.TableConcurrency > rc.cfg.App.IndexConcurrency { - rc.cfg.App.TableConcurrency = rc.cfg.App.IndexConcurrency - } - } - if rc.status != nil { - rc.status.TotalFileSize.Store(estimatedSizeResult.SizeWithoutIndex) - } - if isLocalBackend(rc.cfg) { - pdAddrs := rc.pdCli.GetServiceDiscovery().GetServiceURLs() - pdController, err := pdutil.NewPdController( - ctx, pdAddrs, rc.tls.TLSConfig(), rc.tls.ToPDSecurityOption(), - ) - if err != nil { - return common.NormalizeOrWrapErr(common.ErrCreatePDClient, err) - } - - // PdController will be closed when `taskMetaMgr` closes. - rc.taskMgr = rc.metaMgrBuilder.TaskMetaMgr(pdController) - taskExist, err = rc.taskMgr.CheckTaskExist(ctx) - if err != nil { - return common.ErrMetaMgrUnknown.Wrap(err).GenWithStackByArgs() - } - if !taskExist { - if err = rc.taskMgr.InitTask(ctx, estimatedDataSizeWithIndex, estimatedTiflashDataSize); err != nil { - return common.ErrMetaMgrUnknown.Wrap(err).GenWithStackByArgs() - } - } - if rc.cfg.TikvImporter.PausePDSchedulerScope == config.PausePDSchedulerScopeTable && - !rc.taskMgr.CanPauseSchedulerByKeyRange() { - return errors.New("target cluster don't support pause-pd-scheduler-scope=table, the minimal version required is 6.1.0") - } - if rc.cfg.App.CheckRequirements { - needCheck := true - if rc.cfg.Checkpoint.Enable { - taskCheckpoints, err := rc.checkpointsDB.TaskCheckpoint(ctx) - if err != nil { - return common.ErrReadCheckpoint.Wrap(err).GenWithStack("get task checkpoint failed") - } - // If task checkpoint is initialized, it means check has been performed before. - // We don't need and shouldn't check again, because lightning may have already imported some data. - needCheck = taskCheckpoints == nil - } - if needCheck { - err = rc.localResource(ctx) - if err != nil { - return common.ErrCheckLocalResource.Wrap(err).GenWithStackByArgs() - } - if err := rc.clusterResource(ctx); err != nil { - if err1 := rc.taskMgr.CleanupTask(ctx); err1 != nil { - log.FromContext(ctx).Warn("cleanup task failed", zap.Error(err1)) - return common.ErrMetaMgrUnknown.Wrap(err).GenWithStackByArgs() - } - } - if err := rc.checkClusterRegion(ctx); err != nil { - return common.ErrCheckClusterRegion.Wrap(err).GenWithStackByArgs() - } - } - // even if checkpoint exists, we still need to make sure CDC/PiTR task is not running. - if err := rc.checkCDCPiTR(ctx); err != nil { - return common.ErrCheckCDCPiTR.Wrap(err).GenWithStackByArgs() - } - } - } - - if rc.cfg.App.CheckRequirements { - fmt.Println(rc.checkTemplate.Output()) - } - if !rc.checkTemplate.Success() { - if !taskExist && rc.taskMgr != nil { - err := rc.taskMgr.CleanupTask(ctx) - if err != nil { - log.FromContext(ctx).Warn("cleanup task failed", zap.Error(err)) - } - } - return common.ErrPreCheckFailed.GenWithStackByArgs(rc.checkTemplate.FailedMsg()) - } - return nil -} - -// DataCheck checks the data schema which needs #rc.restoreSchema finished. 
-func (rc *Controller) DataCheck(ctx context.Context) error { - if rc.cfg.App.CheckRequirements { - if err := rc.HasLargeCSV(ctx); err != nil { - return errors.Trace(err) - } - } - - if err := rc.checkCheckpoints(ctx); err != nil { - return errors.Trace(err) - } - - if rc.cfg.App.CheckRequirements { - if err := rc.checkSourceSchema(ctx); err != nil { - return errors.Trace(err) - } - } - - if err := rc.checkTableEmpty(ctx); err != nil { - return common.ErrCheckTableEmpty.Wrap(err).GenWithStackByArgs() - } - if err := rc.checkCSVHeader(ctx); err != nil { - return common.ErrCheckCSVHeader.Wrap(err).GenWithStackByArgs() - } - - return nil -} - -var ( - maxKVQueueSize = 32 // Cache at most this number of rows before blocking the encode loop - minDeliverBytes uint64 = 96 * units.KiB // 96 KB (data + index). batch at least this amount of bytes to reduce number of messages -) - -type deliveredKVs struct { - kvs encode.Row // if kvs is nil, this indicated we've got the last message. - columns []string - offset int64 - rowID int64 - - realOffset int64 // indicates file reader's current position, only used for compressed files -} - -type deliverResult struct { - totalDur time.Duration - err error -} - -func saveCheckpoint(rc *Controller, t *TableImporter, engineID int32, chunk *checkpoints.ChunkCheckpoint) { - // We need to update the AllocBase every time we've finished a file. - // The AllocBase is determined by the maximum of the "handle" (_tidb_rowid - // or integer primary key), which can only be obtained by reading all data. - - var base int64 - if t.tableInfo.Core.ContainsAutoRandomBits() { - base = t.alloc.Get(autoid.AutoRandomType).Base() + 1 - } else { - base = t.alloc.Get(autoid.RowIDAllocType).Base() + 1 - } - rc.saveCpCh <- saveCp{ - tableName: t.tableName, - merger: &checkpoints.RebaseCheckpointMerger{ - AllocBase: base, - }, - } - rc.saveCpCh <- saveCp{ - tableName: t.tableName, - merger: &checkpoints.ChunkCheckpointMerger{ - EngineID: engineID, - Key: chunk.Key, - Checksum: chunk.Checksum, - Pos: chunk.Chunk.Offset, - RowID: chunk.Chunk.PrevRowIDMax, - ColumnPermutation: chunk.ColumnPermutation, - EndOffset: chunk.Chunk.EndOffset, - }, - } -} - -// filterColumns filter columns and extend columns. -// It accepts: -// - columnsNames, header in the data files; -// - extendData, extendData fetched through data file name, that is to say, table info; -// - ignoreColsMap, columns to be ignored when we import; -// - tableInfo, tableInfo of the target table; -// It returns: -// - filteredColumns, columns of the original data to import. -// - extendValueDatums, extended Data to import. -// The data we import will use filteredColumns as columns, use (parser.LastRow+extendValueDatums) as data -// ColumnPermutation will be modified to make sure the correspondence relationship is correct. -// if len(columnsNames) > 0, it means users has specified each field definition, we can just use users -func filterColumns(columnNames []string, extendData mydump.ExtendColumnData, ignoreColsMap map[string]struct{}, tableInfo *model.TableInfo) ([]string, []types.Datum) { - extendCols, extendVals := extendData.Columns, extendData.Values - extendColsSet := set.NewStringSet(extendCols...) 
- filteredColumns := make([]string, 0, len(columnNames)) - if len(columnNames) > 0 { - if len(ignoreColsMap) > 0 { - for _, c := range columnNames { - _, ok := ignoreColsMap[c] - if !ok { - filteredColumns = append(filteredColumns, c) - } - } - } else { - filteredColumns = columnNames - } - } else if len(ignoreColsMap) > 0 || len(extendCols) > 0 { - // init column names by table schema - // after filtered out some columns, we must explicitly set the columns for TiDB backend - for _, col := range tableInfo.Columns { - _, ok := ignoreColsMap[col.Name.L] - // ignore all extend row values specified by users - if !col.Hidden && !ok && !extendColsSet.Exist(col.Name.O) { - filteredColumns = append(filteredColumns, col.Name.O) - } - } - } - extendValueDatums := make([]types.Datum, 0) - filteredColumns = append(filteredColumns, extendCols...) - for _, extendVal := range extendVals { - extendValueDatums = append(extendValueDatums, types.NewStringDatum(extendVal)) - } - return filteredColumns, extendValueDatums -} - -// check store liveness of tikv client-go requires GlobalConfig to work correctly, so we need to init it, -// else tikv will report SSL error when tls is enabled. -// and the SSL error seems affects normal logic of newer TiKV version, and cause the error "tikv: region is unavailable" -// during checksum. -// todo: DM relay on lightning physical mode too, but client-go doesn't support passing TLS data as bytes, -func initGlobalConfig(secCfg tikvconfig.Security) { - if secCfg.ClusterSSLCA != "" || secCfg.ClusterSSLCert != "" { - conf := tidbconfig.GetGlobalConfig() - conf.Security.ClusterSSLCA = secCfg.ClusterSSLCA - conf.Security.ClusterSSLCert = secCfg.ClusterSSLCert - conf.Security.ClusterSSLKey = secCfg.ClusterSSLKey - tidbconfig.StoreGlobalConfig(conf) - } -} diff --git a/lightning/pkg/importer/table_import.go b/lightning/pkg/importer/table_import.go index be00b6c6250e5..ccc0fcc088b3b 100644 --- a/lightning/pkg/importer/table_import.go +++ b/lightning/pkg/importer/table_import.go @@ -235,9 +235,9 @@ func (tr *TableImporter) importTable( } } - if _, _err_ := failpoint.Eval(_curpkg_("FailAfterDuplicateDetection")); _err_ == nil { + failpoint.Inject("FailAfterDuplicateDetection", func() { panic("forcing failure after duplicate detection") - } + }) } // 3. 
Drop indexes if add-index-by-sql is enabled @@ -279,9 +279,9 @@ func (tr *TableImporter) populateChunks(ctx context.Context, rc *Controller, cp tableRegions, err := mydump.MakeTableRegions(ctx, divideConfig) if err == nil { timestamp := time.Now().Unix() - if v, _err_ := failpoint.Eval(_curpkg_("PopulateChunkTimestamp")); _err_ == nil { + failpoint.Inject("PopulateChunkTimestamp", func(v failpoint.Value) { timestamp = int64(v.(int)) - } + }) for _, region := range tableRegions { engine, found := cp.Engines[region.EngineID] if !found { @@ -579,13 +579,13 @@ func (tr *TableImporter) importEngines(pCtx context.Context, rc *Controller, cp if cp.Status < checkpoints.CheckpointStatusIndexImported { var err error if indexEngineCp.Status < checkpoints.CheckpointStatusImported { - if _, _err_ := failpoint.Eval(_curpkg_("FailBeforeStartImportingIndexEngine")); _err_ == nil { + failpoint.Inject("FailBeforeStartImportingIndexEngine", func() { errMsg := "fail before importing index KV data" tr.logger.Warn(errMsg) - return errors.New(errMsg) - } + failpoint.Return(errors.New(errMsg)) + }) err = tr.importKV(ctx, closedIndexEngine, rc) - if _, _err_ := failpoint.Eval(_curpkg_("FailBeforeIndexEngineImported")); _err_ == nil { + failpoint.Inject("FailBeforeIndexEngineImported", func() { finished := rc.status.FinishedFileSize.Load() total := rc.status.TotalFileSize.Load() tr.logger.Warn("print lightning status", @@ -593,7 +593,7 @@ func (tr *TableImporter) importEngines(pCtx context.Context, rc *Controller, cp zap.Int64("total", total), zap.Bool("equal", finished == total)) panic("forcing failure due to FailBeforeIndexEngineImported") - } + }) } saveCpErr := rc.saveStatusCheckpoint(ctx, tr.tableName, checkpoints.WholeTableEngineID, err, checkpoints.CheckpointStatusIndexImported) @@ -722,11 +722,11 @@ ChunkLoop: } checkFlushLock.Unlock() - if _, _err_ := failpoint.Eval(_curpkg_("orphanWriterGoRoutine")); _err_ == nil { + failpoint.Inject("orphanWriterGoRoutine", func() { if chunkIndex > 0 { <-pCtx.Done() } - } + }) select { case <-pCtx.Done(): @@ -1038,12 +1038,12 @@ func (tr *TableImporter) postProcess( } hasDupe = hasLocalDupe } - if v, _err_ := failpoint.Eval(_curpkg_("SlowDownCheckDupe")); _err_ == nil { + failpoint.Inject("SlowDownCheckDupe", func(v failpoint.Value) { sec := v.(int) tr.logger.Warn("start to sleep several seconds before checking other dupe", zap.Int("seconds", sec)) time.Sleep(time.Duration(sec) * time.Second) - } + }) otherHasDupe, needRemoteDupe, baseTotalChecksum, err := metaMgr.CheckAndUpdateLocalChecksum(ctx, &localChecksum, hasDupe) if err != nil { @@ -1087,11 +1087,11 @@ func (tr *TableImporter) postProcess( var remoteChecksum *local.RemoteChecksum remoteChecksum, err = DoChecksum(ctx, tr.tableInfo) - if _, _err_ := failpoint.Eval(_curpkg_("checksum-error")); _err_ == nil { + failpoint.Inject("checksum-error", func() { tr.logger.Info("failpoint checksum-error injected.") remoteChecksum = nil err = status.Error(codes.Unknown, "Checksum meets error.") - } + }) if err != nil { if rc.cfg.PostRestore.Checksum != config.OpLevelOptional { return false, errors.Trace(err) @@ -1367,7 +1367,7 @@ func (tr *TableImporter) importKV( m.ImportSecondsHistogram.Observe(dur.Seconds()) } - failpoint.Eval(_curpkg_("SlowDownImport")) + failpoint.Inject("SlowDownImport", func() {}) return nil } @@ -1544,17 +1544,17 @@ func (*TableImporter) executeDDL( resultCh <- s.Exec(ctx, "add index", ddl) }() - if _, _err_ := failpoint.Eval(_curpkg_("AddIndexCrash")); _err_ == nil { + 
failpoint.Inject("AddIndexCrash", func() { _ = common.KillMySelf() - } + }) var ddlErr error for { select { case ddlErr = <-resultCh: - if _, _err_ := failpoint.Eval(_curpkg_("AddIndexFail")); _err_ == nil { + failpoint.Inject("AddIndexFail", func() { ddlErr = errors.New("injected error") - } + }) if ddlErr == nil { return nil } diff --git a/lightning/pkg/importer/table_import.go__failpoint_stash__ b/lightning/pkg/importer/table_import.go__failpoint_stash__ deleted file mode 100644 index ccc0fcc088b3b..0000000000000 --- a/lightning/pkg/importer/table_import.go__failpoint_stash__ +++ /dev/null @@ -1,1822 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package importer - -import ( - "cmp" - "context" - "database/sql" - "encoding/hex" - "fmt" - "path/filepath" - "slices" - "strings" - "sync" - "time" - - dmysql "github.com/go-sql-driver/mysql" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/br/pkg/version" - "github.com/pingcap/tidb/lightning/pkg/web" - "github.com/pingcap/tidb/pkg/errno" - tidbkv "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/lightning/backend" - "github.com/pingcap/tidb/pkg/lightning/backend/encode" - "github.com/pingcap/tidb/pkg/lightning/backend/kv" - "github.com/pingcap/tidb/pkg/lightning/backend/local" - "github.com/pingcap/tidb/pkg/lightning/checkpoints" - "github.com/pingcap/tidb/pkg/lightning/common" - "github.com/pingcap/tidb/pkg/lightning/config" - "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/lightning/metric" - "github.com/pingcap/tidb/pkg/lightning/mydump" - verify "github.com/pingcap/tidb/pkg/lightning/verification" - "github.com/pingcap/tidb/pkg/lightning/worker" - "github.com/pingcap/tidb/pkg/meta/autoid" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/table/tables" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/extsort" - clientv3 "go.etcd.io/etcd/client/v3" - "go.uber.org/multierr" - "go.uber.org/zap" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -// TableImporter is a helper struct to import a table. -type TableImporter struct { - // The unique table name in the form "`db`.`tbl`". - tableName string - dbInfo *checkpoints.TidbDBInfo - tableInfo *checkpoints.TidbTableInfo - tableMeta *mydump.MDTableMeta - encTable table.Table - alloc autoid.Allocators - logger log.Logger - kvStore tidbkv.Storage - etcdCli *clientv3.Client - autoidCli *autoid.ClientDiscover - - // dupIgnoreRows tracks the rowIDs of rows that are duplicated and should be ignored. - dupIgnoreRows extsort.ExternalSorter - - ignoreColumns map[string]struct{} -} - -// NewTableImporter creates a new TableImporter. 
-func NewTableImporter( - tableName string, - tableMeta *mydump.MDTableMeta, - dbInfo *checkpoints.TidbDBInfo, - tableInfo *checkpoints.TidbTableInfo, - cp *checkpoints.TableCheckpoint, - ignoreColumns map[string]struct{}, - kvStore tidbkv.Storage, - etcdCli *clientv3.Client, - logger log.Logger, -) (*TableImporter, error) { - idAlloc := kv.NewPanickingAllocators(tableInfo.Core.SepAutoInc(), cp.AllocBase) - tbl, err := tables.TableFromMeta(idAlloc, tableInfo.Core) - if err != nil { - return nil, errors.Annotatef(err, "failed to tables.TableFromMeta %s", tableName) - } - autoidCli := autoid.NewClientDiscover(etcdCli) - - return &TableImporter{ - tableName: tableName, - dbInfo: dbInfo, - tableInfo: tableInfo, - tableMeta: tableMeta, - encTable: tbl, - alloc: idAlloc, - kvStore: kvStore, - etcdCli: etcdCli, - autoidCli: autoidCli, - logger: logger.With(zap.String("table", tableName)), - ignoreColumns: ignoreColumns, - }, nil -} - -func (tr *TableImporter) importTable( - ctx context.Context, - rc *Controller, - cp *checkpoints.TableCheckpoint, -) (bool, error) { - // 1. Load the table info. - select { - case <-ctx.Done(): - return false, ctx.Err() - default: - } - - metaMgr := rc.metaMgrBuilder.TableMetaMgr(tr) - // no need to do anything if the chunks are already populated - if len(cp.Engines) > 0 { - tr.logger.Info("reusing engines and files info from checkpoint", - zap.Int("enginesCnt", len(cp.Engines)), - zap.Int("filesCnt", cp.CountChunks()), - ) - err := addExtendDataForCheckpoint(ctx, rc.cfg, cp) - if err != nil { - return false, errors.Trace(err) - } - } else if cp.Status < checkpoints.CheckpointStatusAllWritten { - if err := tr.populateChunks(ctx, rc, cp); err != nil { - return false, errors.Trace(err) - } - - // fetch the max chunk row_id max value as the global max row_id - rowIDMax := int64(0) - for _, engine := range cp.Engines { - if len(engine.Chunks) > 0 && engine.Chunks[len(engine.Chunks)-1].Chunk.RowIDMax > rowIDMax { - rowIDMax = engine.Chunks[len(engine.Chunks)-1].Chunk.RowIDMax - } - } - versionStr, err := version.FetchVersion(ctx, rc.db) - if err != nil { - return false, errors.Trace(err) - } - - versionInfo := version.ParseServerInfo(versionStr) - - // "show table next_row_id" is only available after tidb v4.0.0 - if versionInfo.ServerVersion.Major >= 4 && isLocalBackend(rc.cfg) { - // first, insert a new-line into meta table - if err = metaMgr.InitTableMeta(ctx); err != nil { - return false, err - } - - checksum, rowIDBase, err := metaMgr.AllocTableRowIDs(ctx, rowIDMax) - if err != nil { - return false, err - } - tr.RebaseChunkRowIDs(cp, rowIDBase) - - if checksum != nil { - if cp.Checksum != *checksum { - cp.Checksum = *checksum - rc.saveCpCh <- saveCp{ - tableName: tr.tableName, - merger: &checkpoints.TableChecksumMerger{ - Checksum: cp.Checksum, - }, - } - } - tr.logger.Info("checksum before restore table", zap.Object("checksum", &cp.Checksum)) - } - } - if err := rc.checkpointsDB.InsertEngineCheckpoints(ctx, tr.tableName, cp.Engines); err != nil { - return false, errors.Trace(err) - } - web.BroadcastTableCheckpoint(tr.tableName, cp) - - // rebase the allocator so it exceeds the number of rows. 
- if tr.tableInfo.Core.ContainsAutoRandomBits() { - cp.AllocBase = max(cp.AllocBase, tr.tableInfo.Core.AutoRandID) - if err := tr.alloc.Get(autoid.AutoRandomType).Rebase(context.Background(), cp.AllocBase, false); err != nil { - return false, err - } - } else { - cp.AllocBase = max(cp.AllocBase, tr.tableInfo.Core.AutoIncID) - if err := tr.alloc.Get(autoid.RowIDAllocType).Rebase(context.Background(), cp.AllocBase, false); err != nil { - return false, err - } - } - rc.saveCpCh <- saveCp{ - tableName: tr.tableName, - merger: &checkpoints.RebaseCheckpointMerger{ - AllocBase: cp.AllocBase, - }, - } - } - - // 2. Do duplicate detection if needed - if isLocalBackend(rc.cfg) && rc.cfg.Conflict.PrecheckConflictBeforeImport && rc.cfg.Conflict.Strategy != config.NoneOnDup { - _, uuid := backend.MakeUUID(tr.tableName, common.IndexEngineID) - workingDir := filepath.Join(rc.cfg.TikvImporter.SortedKVDir, uuid.String()+local.DupDetectDirSuffix) - resultDir := filepath.Join(rc.cfg.TikvImporter.SortedKVDir, uuid.String()+local.DupResultDirSuffix) - - dupIgnoreRows, err := extsort.OpenDiskSorter(resultDir, &extsort.DiskSorterOptions{ - Concurrency: rc.cfg.App.RegionConcurrency, - }) - if err != nil { - return false, errors.Trace(err) - } - tr.dupIgnoreRows = dupIgnoreRows - - if cp.Status < checkpoints.CheckpointStatusDupDetected { - err := tr.preDeduplicate(ctx, rc, cp, workingDir) - saveCpErr := rc.saveStatusCheckpoint(ctx, tr.tableName, checkpoints.WholeTableEngineID, err, checkpoints.CheckpointStatusDupDetected) - if err := firstErr(err, saveCpErr); err != nil { - return false, errors.Trace(err) - } - } - - if !dupIgnoreRows.IsSorted() { - if err := dupIgnoreRows.Sort(ctx); err != nil { - return false, errors.Trace(err) - } - } - - failpoint.Inject("FailAfterDuplicateDetection", func() { - panic("forcing failure after duplicate detection") - }) - } - - // 3. Drop indexes if add-index-by-sql is enabled - if cp.Status < checkpoints.CheckpointStatusIndexDropped && isLocalBackend(rc.cfg) && rc.cfg.TikvImporter.AddIndexBySQL { - err := tr.dropIndexes(ctx, rc.db) - saveCpErr := rc.saveStatusCheckpoint(ctx, tr.tableName, checkpoints.WholeTableEngineID, err, checkpoints.CheckpointStatusIndexDropped) - if err := firstErr(err, saveCpErr); err != nil { - return false, errors.Trace(err) - } - } - - // 4. Restore engines (if still needed) - err := tr.importEngines(ctx, rc, cp) - if err != nil { - return false, errors.Trace(err) - } - - err = metaMgr.UpdateTableStatus(ctx, metaStatusRestoreFinished) - if err != nil { - return false, errors.Trace(err) - } - - // 5. Post-process. With the last parameter set to false, we can allow delay analyze execute latter - return tr.postProcess(ctx, rc, cp, false /* force-analyze */, metaMgr) -} - -// Close implements the Importer interface. 
-func (tr *TableImporter) Close() { - tr.encTable = nil - if tr.dupIgnoreRows != nil { - _ = tr.dupIgnoreRows.Close() - } - tr.logger.Info("restore done") -} - -func (tr *TableImporter) populateChunks(ctx context.Context, rc *Controller, cp *checkpoints.TableCheckpoint) error { - task := tr.logger.Begin(zap.InfoLevel, "load engines and files") - divideConfig := mydump.NewDataDivideConfig(rc.cfg, len(tr.tableInfo.Core.Columns), rc.ioWorkers, rc.store, tr.tableMeta) - tableRegions, err := mydump.MakeTableRegions(ctx, divideConfig) - if err == nil { - timestamp := time.Now().Unix() - failpoint.Inject("PopulateChunkTimestamp", func(v failpoint.Value) { - timestamp = int64(v.(int)) - }) - for _, region := range tableRegions { - engine, found := cp.Engines[region.EngineID] - if !found { - engine = &checkpoints.EngineCheckpoint{ - Status: checkpoints.CheckpointStatusLoaded, - } - cp.Engines[region.EngineID] = engine - } - ccp := &checkpoints.ChunkCheckpoint{ - Key: checkpoints.ChunkCheckpointKey{ - Path: region.FileMeta.Path, - Offset: region.Chunk.Offset, - }, - FileMeta: region.FileMeta, - ColumnPermutation: nil, - Chunk: region.Chunk, - Timestamp: timestamp, - } - if len(region.Chunk.Columns) > 0 { - perms, err := parseColumnPermutations( - tr.tableInfo.Core, - region.Chunk.Columns, - tr.ignoreColumns, - log.FromContext(ctx)) - if err != nil { - return errors.Trace(err) - } - ccp.ColumnPermutation = perms - } - engine.Chunks = append(engine.Chunks, ccp) - } - - // Add index engine checkpoint - cp.Engines[common.IndexEngineID] = &checkpoints.EngineCheckpoint{Status: checkpoints.CheckpointStatusLoaded} - } - task.End(zap.ErrorLevel, err, - zap.Int("enginesCnt", len(cp.Engines)), - zap.Int("filesCnt", len(tableRegions)), - ) - return err -} - -// AutoIDRequirement implements autoid.Requirement. -var _ autoid.Requirement = &TableImporter{} - -// Store implements the autoid.Requirement interface. -func (tr *TableImporter) Store() tidbkv.Storage { - return tr.kvStore -} - -// AutoIDClient implements the autoid.Requirement interface. -func (tr *TableImporter) AutoIDClient() *autoid.ClientDiscover { - return tr.autoidCli -} - -// RebaseChunkRowIDs rebase the row id of the chunks. -func (*TableImporter) RebaseChunkRowIDs(cp *checkpoints.TableCheckpoint, rowIDBase int64) { - if rowIDBase == 0 { - return - } - for _, engine := range cp.Engines { - for _, chunk := range engine.Chunks { - chunk.Chunk.PrevRowIDMax += rowIDBase - chunk.Chunk.RowIDMax += rowIDBase - } - } -} - -// initializeColumns computes the "column permutation" for an INSERT INTO -// statement. Suppose a table has columns (a, b, c, d) in canonical order, and -// we execute `INSERT INTO (d, b, a) VALUES ...`, we will need to remap the -// columns as: -// -// - column `a` is at position 2 -// - column `b` is at position 1 -// - column `c` is missing -// - column `d` is at position 0 -// -// The column permutation of (d, b, a) is set to be [2, 1, -1, 0]. -// -// The argument `columns` _must_ be in lower case. 
-func (tr *TableImporter) initializeColumns(columns []string, ccp *checkpoints.ChunkCheckpoint) error { - colPerm, err := createColumnPermutation(columns, tr.ignoreColumns, tr.tableInfo.Core, tr.logger) - if err != nil { - return err - } - ccp.ColumnPermutation = colPerm - return nil -} - -func createColumnPermutation( - columns []string, - ignoreColumns map[string]struct{}, - tableInfo *model.TableInfo, - logger log.Logger, -) ([]int, error) { - var colPerm []int - if len(columns) == 0 { - colPerm = make([]int, 0, len(tableInfo.Columns)+1) - shouldIncludeRowID := common.TableHasAutoRowID(tableInfo) - - // no provided columns, so use identity permutation. - for i, col := range tableInfo.Columns { - idx := i - if _, ok := ignoreColumns[col.Name.L]; ok { - idx = -1 - } else if col.IsGenerated() { - idx = -1 - } - colPerm = append(colPerm, idx) - } - if shouldIncludeRowID { - colPerm = append(colPerm, -1) - } - } else { - var err error - colPerm, err = parseColumnPermutations(tableInfo, columns, ignoreColumns, logger) - if err != nil { - return nil, errors.Trace(err) - } - } - return colPerm, nil -} - -func (tr *TableImporter) importEngines(pCtx context.Context, rc *Controller, cp *checkpoints.TableCheckpoint) error { - indexEngineCp := cp.Engines[common.IndexEngineID] - if indexEngineCp == nil { - tr.logger.Error("fail to importEngines because indexengine is nil") - return common.ErrCheckpointNotFound.GenWithStack("table %v index engine checkpoint not found", tr.tableName) - } - - ctx, cancel := context.WithCancel(pCtx) - defer cancel() - - // The table checkpoint status set to `CheckpointStatusIndexImported` only if - // both all data engines and the index engine had been imported to TiKV. - // But persist index engine checkpoint status and table checkpoint status are - // not an atomic operation, so `cp.Status < CheckpointStatusIndexImported` - // but `indexEngineCp.Status == CheckpointStatusImported` could happen - // when kill lightning after saving index engine checkpoint status before saving - // table checkpoint status. - var closedIndexEngine *backend.ClosedEngine - var restoreErr error - // if index-engine checkpoint is lower than `CheckpointStatusClosed`, there must be - // data-engines that need to be restore or import. Otherwise, all data-engines should - // be finished already. - - handleDataEngineThisRun := false - idxEngineCfg := &backend.EngineConfig{ - TableInfo: tr.tableInfo, - } - if indexEngineCp.Status < checkpoints.CheckpointStatusClosed { - handleDataEngineThisRun = true - indexWorker := rc.indexWorkers.Apply() - defer rc.indexWorkers.Recycle(indexWorker) - - if rc.cfg.TikvImporter.Backend == config.BackendLocal { - // for index engine, the estimate factor is non-clustered index count - idxCnt := len(tr.tableInfo.Core.Indices) - if !common.TableHasAutoRowID(tr.tableInfo.Core) { - idxCnt-- - } - threshold := local.EstimateCompactionThreshold(tr.tableMeta.DataFiles, cp, int64(idxCnt)) - idxEngineCfg.Local = backend.LocalEngineConfig{ - Compact: threshold > 0, - CompactConcurrency: 4, - CompactThreshold: threshold, - BlockSize: int(rc.cfg.TikvImporter.BlockSize), - } - } - // import backend can't reopen engine if engine is closed, so - // only open index engine if any data engines don't finish writing. 
- var indexEngine *backend.OpenedEngine - var err error - for engineID, engine := range cp.Engines { - if engineID == common.IndexEngineID { - continue - } - if engine.Status < checkpoints.CheckpointStatusAllWritten { - indexEngine, err = rc.engineMgr.OpenEngine(ctx, idxEngineCfg, tr.tableName, common.IndexEngineID) - if err != nil { - return errors.Trace(err) - } - break - } - } - - logTask := tr.logger.Begin(zap.InfoLevel, "import whole table") - var wg sync.WaitGroup - var engineErr common.OnceError - setError := func(err error) { - engineErr.Set(err) - // cancel this context to fail fast - cancel() - } - - type engineCheckpoint struct { - engineID int32 - checkpoint *checkpoints.EngineCheckpoint - } - allEngines := make([]engineCheckpoint, 0, len(cp.Engines)) - for engineID, engine := range cp.Engines { - allEngines = append(allEngines, engineCheckpoint{engineID: engineID, checkpoint: engine}) - } - slices.SortFunc(allEngines, func(i, j engineCheckpoint) int { return cmp.Compare(i.engineID, j.engineID) }) - - for _, ecp := range allEngines { - engineID := ecp.engineID - engine := ecp.checkpoint - select { - case <-ctx.Done(): - // Set engineErr and break this for loop to wait all the sub-routines done before return. - // Directly return may cause panic because caller will close the pebble db but some sub routines - // are still reading from or writing to the pebble db. - engineErr.Set(ctx.Err()) - default: - } - if engineErr.Get() != nil { - break - } - - // Should skip index engine - if engineID < 0 { - continue - } - - if engine.Status < checkpoints.CheckpointStatusImported { - wg.Add(1) - - // If the number of chunks is small, it means that this engine may be finished in a few times. - // We do not limit it in TableConcurrency - restoreWorker := rc.tableWorkers.Apply() - go func(w *worker.Worker, eid int32, ecp *checkpoints.EngineCheckpoint) { - defer wg.Done() - engineLogTask := tr.logger.With(zap.Int32("engineNumber", eid)).Begin(zap.InfoLevel, "restore engine") - dataClosedEngine, err := tr.preprocessEngine(ctx, rc, indexEngine, eid, ecp) - engineLogTask.End(zap.ErrorLevel, err) - rc.tableWorkers.Recycle(w) - if err == nil { - dataWorker := rc.closedEngineLimit.Apply() - defer rc.closedEngineLimit.Recycle(dataWorker) - err = tr.importEngine(ctx, dataClosedEngine, rc, ecp) - if rc.status != nil && rc.status.backend == config.BackendLocal { - for _, chunk := range ecp.Chunks { - rc.status.FinishedFileSize.Add(chunk.TotalSize()) - } - } - } - if err != nil { - setError(err) - } - }(restoreWorker, engineID, engine) - } else { - for _, chunk := range engine.Chunks { - rc.status.FinishedFileSize.Add(chunk.TotalSize()) - } - } - } - - wg.Wait() - - restoreErr = engineErr.Get() - logTask.End(zap.ErrorLevel, restoreErr) - if restoreErr != nil { - return errors.Trace(restoreErr) - } - - if indexEngine != nil { - closedIndexEngine, restoreErr = indexEngine.Close(ctx) - } else { - closedIndexEngine, restoreErr = rc.engineMgr.UnsafeCloseEngine(ctx, idxEngineCfg, tr.tableName, common.IndexEngineID) - } - - if err = rc.saveStatusCheckpoint(ctx, tr.tableName, common.IndexEngineID, restoreErr, checkpoints.CheckpointStatusClosed); err != nil { - return errors.Trace(firstErr(restoreErr, err)) - } - } else if indexEngineCp.Status == checkpoints.CheckpointStatusClosed { - // If index engine file has been closed but not imported only if context cancel occurred - // when `importKV()` execution, so `UnsafeCloseEngine` and continue import it. 
- closedIndexEngine, restoreErr = rc.engineMgr.UnsafeCloseEngine(ctx, idxEngineCfg, tr.tableName, common.IndexEngineID) - } - if restoreErr != nil { - return errors.Trace(restoreErr) - } - - // if data engine is handled in previous run and we continue importing from checkpoint - if !handleDataEngineThisRun { - for _, engine := range cp.Engines { - for _, chunk := range engine.Chunks { - rc.status.FinishedFileSize.Add(chunk.Chunk.EndOffset - chunk.Key.Offset) - } - } - } - - if cp.Status < checkpoints.CheckpointStatusIndexImported { - var err error - if indexEngineCp.Status < checkpoints.CheckpointStatusImported { - failpoint.Inject("FailBeforeStartImportingIndexEngine", func() { - errMsg := "fail before importing index KV data" - tr.logger.Warn(errMsg) - failpoint.Return(errors.New(errMsg)) - }) - err = tr.importKV(ctx, closedIndexEngine, rc) - failpoint.Inject("FailBeforeIndexEngineImported", func() { - finished := rc.status.FinishedFileSize.Load() - total := rc.status.TotalFileSize.Load() - tr.logger.Warn("print lightning status", - zap.Int64("finished", finished), - zap.Int64("total", total), - zap.Bool("equal", finished == total)) - panic("forcing failure due to FailBeforeIndexEngineImported") - }) - } - - saveCpErr := rc.saveStatusCheckpoint(ctx, tr.tableName, checkpoints.WholeTableEngineID, err, checkpoints.CheckpointStatusIndexImported) - if err = firstErr(err, saveCpErr); err != nil { - return errors.Trace(err) - } - } - return nil -} - -// preprocessEngine do some preprocess work -// for local backend, it do local sort, for tidb backend it transforms data into sql and execute -// TODO: it's not a correct name for tidb backend, since there's no post-process for it -// TODO: after separate local/tidb backend more clearly, rename it. -func (tr *TableImporter) preprocessEngine( - pCtx context.Context, - rc *Controller, - indexEngine *backend.OpenedEngine, - engineID int32, - cp *checkpoints.EngineCheckpoint, -) (*backend.ClosedEngine, error) { - ctx, cancel := context.WithCancel(pCtx) - defer cancel() - // all data has finished written, we can close the engine directly. - if cp.Status >= checkpoints.CheckpointStatusAllWritten { - engineCfg := &backend.EngineConfig{ - TableInfo: tr.tableInfo, - } - closedEngine, err := rc.engineMgr.UnsafeCloseEngine(ctx, engineCfg, tr.tableName, engineID) - // If any error occurred, recycle worker immediately - if err != nil { - return closedEngine, errors.Trace(err) - } - if rc.status != nil && rc.status.backend == config.BackendTiDB { - for _, chunk := range cp.Chunks { - rc.status.FinishedFileSize.Add(chunk.Chunk.EndOffset - chunk.Key.Offset) - } - } - return closedEngine, nil - } - - // if the key are ordered, LocalWrite can optimize the writing. - // table has auto-incremented _tidb_rowid must satisfy following restrictions: - // - clustered index disable and primary key is not number - // - no auto random bits (auto random or shard row id) - // - no partition table - // - no explicit _tidb_rowid field (At this time we can't determine if the source file contains _tidb_rowid field, - // so we will do this check in LocalWriter when the first row is received.) 
- hasAutoIncrementAutoID := common.TableHasAutoRowID(tr.tableInfo.Core) && - tr.tableInfo.Core.AutoRandomBits == 0 && tr.tableInfo.Core.ShardRowIDBits == 0 && - tr.tableInfo.Core.Partition == nil - dataWriterCfg := &backend.LocalWriterConfig{} - dataWriterCfg.Local.IsKVSorted = hasAutoIncrementAutoID - dataWriterCfg.TiDB.TableName = tr.tableName - - logTask := tr.logger.With(zap.Int32("engineNumber", engineID)).Begin(zap.InfoLevel, "encode kv data and write") - dataEngineCfg := &backend.EngineConfig{ - TableInfo: tr.tableInfo, - } - if !tr.tableMeta.IsRowOrdered { - dataEngineCfg.Local.Compact = true - dataEngineCfg.Local.CompactConcurrency = 4 - dataEngineCfg.Local.CompactThreshold = local.CompactionUpperThreshold - } - dataEngine, err := rc.engineMgr.OpenEngine(ctx, dataEngineCfg, tr.tableName, engineID) - if err != nil { - return nil, errors.Trace(err) - } - - var wg sync.WaitGroup - var chunkErr common.OnceError - - type chunkFlushStatus struct { - dataStatus backend.ChunkFlushStatus - indexStatus backend.ChunkFlushStatus - chunkCp *checkpoints.ChunkCheckpoint - } - - // chunks that are finished writing, but checkpoints are not finished due to flush not finished. - var checkFlushLock sync.Mutex - flushPendingChunks := make([]chunkFlushStatus, 0, 16) - - chunkCpChan := make(chan *checkpoints.ChunkCheckpoint, 16) - go func() { - for { - select { - case cp, ok := <-chunkCpChan: - if !ok { - return - } - saveCheckpoint(rc, tr, engineID, cp) - case <-ctx.Done(): - return - } - } - }() - - setError := func(err error) { - chunkErr.Set(err) - cancel() - } - - metrics, _ := metric.FromContext(ctx) - - // Restore table data -ChunkLoop: - for chunkIndex, chunk := range cp.Chunks { - if rc.status != nil && rc.status.backend == config.BackendTiDB { - rc.status.FinishedFileSize.Add(chunk.Chunk.Offset - chunk.Key.Offset) - } - if chunk.Chunk.Offset >= chunk.Chunk.EndOffset { - continue - } - - checkFlushLock.Lock() - finished := 0 - for _, c := range flushPendingChunks { - if !(c.indexStatus.Flushed() && c.dataStatus.Flushed()) { - break - } - chunkCpChan <- c.chunkCp - finished++ - } - if finished > 0 { - flushPendingChunks = flushPendingChunks[finished:] - } - checkFlushLock.Unlock() - - failpoint.Inject("orphanWriterGoRoutine", func() { - if chunkIndex > 0 { - <-pCtx.Done() - } - }) - - select { - case <-pCtx.Done(): - break ChunkLoop - default: - } - - if chunkErr.Get() != nil { - break - } - - // Flows : - // 1. read mydump file - // 2. sql -> kvs - // 3. load kvs data (into kv deliver server) - // 4. 
flush kvs data (into tikv node) - var remainChunkCnt float64 - if chunk.Chunk.Offset < chunk.Chunk.EndOffset { - remainChunkCnt = float64(chunk.UnfinishedSize()) / float64(chunk.TotalSize()) - if metrics != nil { - metrics.ChunkCounter.WithLabelValues(metric.ChunkStatePending).Add(remainChunkCnt) - } - } - - dataWriter, err := dataEngine.LocalWriter(ctx, dataWriterCfg) - if err != nil { - setError(err) - break - } - - writerCfg := &backend.LocalWriterConfig{} - writerCfg.TiDB.TableName = tr.tableName - indexWriter, err := indexEngine.LocalWriter(ctx, writerCfg) - if err != nil { - _, _ = dataWriter.Close(ctx) - setError(err) - break - } - cr, err := newChunkProcessor(ctx, chunkIndex, rc.cfg, chunk, rc.ioWorkers, rc.store, tr.tableInfo.Core) - if err != nil { - setError(err) - break - } - - if chunk.FileMeta.Type == mydump.SourceTypeParquet { - // TODO: use the compressed size of the chunk to conduct memory control - if _, err = getChunkCompressedSizeForParquet(ctx, chunk, rc.store); err != nil { - return nil, errors.Trace(err) - } - } - - restoreWorker := rc.regionWorkers.Apply() - wg.Add(1) - go func(w *worker.Worker, cr *chunkProcessor) { - // Restore a chunk. - defer func() { - cr.close() - wg.Done() - rc.regionWorkers.Recycle(w) - }() - if metrics != nil { - metrics.ChunkCounter.WithLabelValues(metric.ChunkStateRunning).Add(remainChunkCnt) - } - err := cr.process(ctx, tr, engineID, dataWriter, indexWriter, rc) - var dataFlushStatus, indexFlushStaus backend.ChunkFlushStatus - if err == nil { - dataFlushStatus, err = dataWriter.Close(ctx) - } - if err == nil { - indexFlushStaus, err = indexWriter.Close(ctx) - } - if err == nil { - if metrics != nil { - metrics.ChunkCounter.WithLabelValues(metric.ChunkStateFinished).Add(remainChunkCnt) - metrics.BytesCounter.WithLabelValues(metric.StateRestoreWritten).Add(float64(cr.chunk.Checksum.SumSize())) - } - if dataFlushStatus != nil && indexFlushStaus != nil { - if dataFlushStatus.Flushed() && indexFlushStaus.Flushed() { - saveCheckpoint(rc, tr, engineID, cr.chunk) - } else { - checkFlushLock.Lock() - flushPendingChunks = append(flushPendingChunks, chunkFlushStatus{ - dataStatus: dataFlushStatus, - indexStatus: indexFlushStaus, - chunkCp: cr.chunk, - }) - checkFlushLock.Unlock() - } - } - } else { - if metrics != nil { - metrics.ChunkCounter.WithLabelValues(metric.ChunkStateFailed).Add(remainChunkCnt) - } - setError(err) - } - }(restoreWorker, cr) - } - - wg.Wait() - select { - case <-pCtx.Done(): - return nil, pCtx.Err() - default: - } - - // Report some statistics into the log for debugging. - totalKVSize := uint64(0) - totalSQLSize := int64(0) - logKeyName := "read(bytes)" - for _, chunk := range cp.Chunks { - totalKVSize += chunk.Checksum.SumSize() - totalSQLSize += chunk.UnfinishedSize() - if chunk.FileMeta.Type == mydump.SourceTypeParquet { - logKeyName = "read(rows)" - } - } - - err = chunkErr.Get() - logTask.End(zap.ErrorLevel, err, - zap.Int64(logKeyName, totalSQLSize), - zap.Uint64("written", totalKVSize), - ) - - trySavePendingChunks := func(context.Context) error { - checkFlushLock.Lock() - cnt := 0 - for _, chunk := range flushPendingChunks { - if !(chunk.dataStatus.Flushed() && chunk.indexStatus.Flushed()) { - break - } - saveCheckpoint(rc, tr, engineID, chunk.chunkCp) - cnt++ - } - flushPendingChunks = flushPendingChunks[cnt:] - checkFlushLock.Unlock() - return nil - } - - // in local mode, this check-point make no sense, because we don't do flush now, - // so there may be data lose if exit at here. 
So we don't write this checkpoint - // here like other mode. - if !isLocalBackend(rc.cfg) { - if saveCpErr := rc.saveStatusCheckpoint(ctx, tr.tableName, engineID, err, checkpoints.CheckpointStatusAllWritten); saveCpErr != nil { - return nil, errors.Trace(firstErr(err, saveCpErr)) - } - } - if err != nil { - // if process is canceled, we should flush all chunk checkpoints for local backend - if isLocalBackend(rc.cfg) && common.IsContextCanceledError(err) { - // ctx is canceled, so to avoid Close engine failed, we use `context.Background()` here - if _, err2 := dataEngine.Close(context.Background()); err2 != nil { - log.FromContext(ctx).Warn("flush all chunk checkpoints failed before manually exits", zap.Error(err2)) - return nil, errors.Trace(err) - } - if err2 := trySavePendingChunks(context.Background()); err2 != nil { - log.FromContext(ctx).Warn("flush all chunk checkpoints failed before manually exits", zap.Error(err2)) - } - } - return nil, errors.Trace(err) - } - - closedDataEngine, err := dataEngine.Close(ctx) - // For local backend, if checkpoint is enabled, we must flush index engine to avoid data loss. - // this flush action impact up to 10% of the performance, so we only do it if necessary. - if err == nil && rc.cfg.Checkpoint.Enable && isLocalBackend(rc.cfg) { - if err = indexEngine.Flush(ctx); err != nil { - return nil, errors.Trace(err) - } - if err = trySavePendingChunks(ctx); err != nil { - return nil, errors.Trace(err) - } - } - saveCpErr := rc.saveStatusCheckpoint(ctx, tr.tableName, engineID, err, checkpoints.CheckpointStatusClosed) - if err = firstErr(err, saveCpErr); err != nil { - // If any error occurred, recycle worker immediately - return nil, errors.Trace(err) - } - return closedDataEngine, nil -} - -func (tr *TableImporter) importEngine( - ctx context.Context, - closedEngine *backend.ClosedEngine, - rc *Controller, - cp *checkpoints.EngineCheckpoint, -) error { - if cp.Status >= checkpoints.CheckpointStatusImported { - return nil - } - - // 1. calling import - if err := tr.importKV(ctx, closedEngine, rc); err != nil { - return errors.Trace(err) - } - - // 2. perform a level-1 compact if idling. - if rc.cfg.PostRestore.Level1Compact && rc.compactState.CompareAndSwap(compactStateIdle, compactStateDoing) { - go func() { - // we ignore level-1 compact failure since it is not fatal. - // no need log the error, it is done in (*Importer).Compact already. - _ = rc.doCompact(ctx, Level1Compact) - rc.compactState.Store(compactStateIdle) - }() - } - - return nil -} - -// postProcess execute rebase-auto-id/checksum/analyze according to the task config. -// -// if the parameter forcePostProcess to true, postProcess force run checksum and analyze even if the -// post-process-at-last config is true. And if this two phases are skipped, the first return value will be true. -func (tr *TableImporter) postProcess( - ctx context.Context, - rc *Controller, - cp *checkpoints.TableCheckpoint, - forcePostProcess bool, - metaMgr tableMetaMgr, -) (bool, error) { - if !rc.backend.ShouldPostProcess() { - return false, nil - } - - // alter table set auto_increment - if cp.Status < checkpoints.CheckpointStatusAlteredAutoInc { - tblInfo := tr.tableInfo.Core - var err error - // TODO why we have to rebase id for tidb backend??? remove it later. 
- if tblInfo.ContainsAutoRandomBits() { - ft := &common.GetAutoRandomColumn(tblInfo).FieldType - shardFmt := autoid.NewShardIDFormat(ft, tblInfo.AutoRandomBits, tblInfo.AutoRandomRangeBits) - maxCap := shardFmt.IncrementalBitsCapacity() - err = AlterAutoRandom(ctx, rc.db, tr.tableName, uint64(tr.alloc.Get(autoid.AutoRandomType).Base())+1, maxCap) - } else if common.TableHasAutoRowID(tblInfo) || tblInfo.GetAutoIncrementColInfo() != nil { - if isLocalBackend(rc.cfg) { - // for TiDB version >= 6.5.0, a table might have separate allocators for auto_increment column and _tidb_rowid, - // especially when a table has auto_increment non-clustered PK, it will use both allocators. - // And in this case, ALTER TABLE xxx AUTO_INCREMENT = xxx only works on the allocator of auto_increment column, - // not for allocator of _tidb_rowid. - // So we need to rebase IDs for those 2 allocators explicitly. - err = common.RebaseTableAllocators(ctx, map[autoid.AllocatorType]int64{ - autoid.RowIDAllocType: tr.alloc.Get(autoid.RowIDAllocType).Base(), - autoid.AutoIncrementType: tr.alloc.Get(autoid.AutoIncrementType).Base(), - }, tr, tr.dbInfo.ID, tr.tableInfo.Core) - } else { - // only alter auto increment id iff table contains auto-increment column or generated handle. - // ALTER TABLE xxx AUTO_INCREMENT = yyy has a bad naming. - // if a table has implicit _tidb_rowid column & tbl.SepAutoID=false, then it works on _tidb_rowid - // allocator, even if the table has NO auto-increment column. - newBase := uint64(tr.alloc.Get(autoid.RowIDAllocType).Base()) + 1 - err = AlterAutoIncrement(ctx, rc.db, tr.tableName, newBase) - } - } - saveCpErr := rc.saveStatusCheckpoint(ctx, tr.tableName, checkpoints.WholeTableEngineID, err, checkpoints.CheckpointStatusAlteredAutoInc) - if err = firstErr(err, saveCpErr); err != nil { - return false, errors.Trace(err) - } - cp.Status = checkpoints.CheckpointStatusAlteredAutoInc - } - - // tidb backend don't need checksum & analyze - if rc.cfg.PostRestore.Checksum == config.OpLevelOff && rc.cfg.PostRestore.Analyze == config.OpLevelOff { - tr.logger.Debug("skip checksum & analyze, either because not supported by this backend or manually disabled") - err := rc.saveStatusCheckpoint(ctx, tr.tableName, checkpoints.WholeTableEngineID, nil, checkpoints.CheckpointStatusAnalyzeSkipped) - return false, errors.Trace(err) - } - - if !forcePostProcess && rc.cfg.PostRestore.PostProcessAtLast { - return true, nil - } - - w := rc.checksumWorks.Apply() - defer rc.checksumWorks.Recycle(w) - - shouldSkipAnalyze := false - estimatedModifyCnt := 100_000_000 - if cp.Status < checkpoints.CheckpointStatusChecksumSkipped { - // 4. do table checksum - var localChecksum verify.KVChecksum - for _, engine := range cp.Engines { - for _, chunk := range engine.Chunks { - localChecksum.Add(&chunk.Checksum) - } - } - indexNum := len(tr.tableInfo.Core.Indices) - if common.TableHasAutoRowID(tr.tableInfo.Core) { - indexNum++ - } - estimatedModifyCnt = int(localChecksum.SumKVS()) / (1 + indexNum) - tr.logger.Info("local checksum", zap.Object("checksum", &localChecksum)) - - // 4.5. do duplicate detection. - // if we came here, it must be a local backend. - // todo: remove this cast after we refactor the backend interface. Physical mode is so different, we shouldn't - // try to abstract it with logical mode. 
- localBackend := rc.backend.(*local.Backend) - dupeController := localBackend.GetDupeController(rc.cfg.TikvImporter.RangeConcurrency*2, rc.errorMgr) - hasDupe := false - if rc.cfg.Conflict.Strategy != config.NoneOnDup { - opts := &encode.SessionOptions{ - SQLMode: mysql.ModeStrictAllTables, - SysVars: rc.sysVars, - } - var err error - hasLocalDupe, err := dupeController.CollectLocalDuplicateRows(ctx, tr.encTable, tr.tableName, opts, rc.cfg.Conflict.Strategy) - if err != nil { - tr.logger.Error("collect local duplicate keys failed", log.ShortError(err)) - return false, errors.Trace(err) - } - hasDupe = hasLocalDupe - } - failpoint.Inject("SlowDownCheckDupe", func(v failpoint.Value) { - sec := v.(int) - tr.logger.Warn("start to sleep several seconds before checking other dupe", - zap.Int("seconds", sec)) - time.Sleep(time.Duration(sec) * time.Second) - }) - - otherHasDupe, needRemoteDupe, baseTotalChecksum, err := metaMgr.CheckAndUpdateLocalChecksum(ctx, &localChecksum, hasDupe) - if err != nil { - return false, errors.Trace(err) - } - needChecksum := !otherHasDupe && needRemoteDupe - hasDupe = hasDupe || otherHasDupe - - if needRemoteDupe && rc.cfg.Conflict.Strategy != config.NoneOnDup { - opts := &encode.SessionOptions{ - SQLMode: mysql.ModeStrictAllTables, - SysVars: rc.sysVars, - } - hasRemoteDupe, e := dupeController.CollectRemoteDuplicateRows(ctx, tr.encTable, tr.tableName, opts, rc.cfg.Conflict.Strategy) - if e != nil { - tr.logger.Error("collect remote duplicate keys failed", log.ShortError(e)) - return false, errors.Trace(e) - } - hasDupe = hasDupe || hasRemoteDupe - - if hasDupe { - if err = dupeController.ResolveDuplicateRows(ctx, tr.encTable, tr.tableName, rc.cfg.Conflict.Strategy); err != nil { - tr.logger.Error("resolve remote duplicate keys failed", log.ShortError(err)) - return false, errors.Trace(err) - } - } - } - - if rc.dupIndicator != nil { - tr.logger.Debug("set dupIndicator", zap.Bool("has-duplicate", hasDupe)) - rc.dupIndicator.CompareAndSwap(false, hasDupe) - } - - nextStage := checkpoints.CheckpointStatusChecksummed - if rc.cfg.PostRestore.Checksum != config.OpLevelOff && !hasDupe && needChecksum { - if cp.Checksum.SumKVS() > 0 || baseTotalChecksum.SumKVS() > 0 { - localChecksum.Add(&cp.Checksum) - localChecksum.Add(baseTotalChecksum) - tr.logger.Info("merged local checksum", zap.Object("checksum", &localChecksum)) - } - - var remoteChecksum *local.RemoteChecksum - remoteChecksum, err = DoChecksum(ctx, tr.tableInfo) - failpoint.Inject("checksum-error", func() { - tr.logger.Info("failpoint checksum-error injected.") - remoteChecksum = nil - err = status.Error(codes.Unknown, "Checksum meets error.") - }) - if err != nil { - if rc.cfg.PostRestore.Checksum != config.OpLevelOptional { - return false, errors.Trace(err) - } - tr.logger.Warn("do checksum failed, will skip this error and go on", log.ShortError(err)) - err = nil - } - if remoteChecksum != nil { - err = tr.compareChecksum(remoteChecksum, localChecksum) - // with post restore level 'optional', we will skip checksum error - if rc.cfg.PostRestore.Checksum == config.OpLevelOptional { - if err != nil { - tr.logger.Warn("compare checksum failed, will skip this error and go on", log.ShortError(err)) - err = nil - } - } - } - } else { - switch { - case rc.cfg.PostRestore.Checksum == config.OpLevelOff: - tr.logger.Info("skip checksum because the checksum option is off") - case hasDupe: - tr.logger.Info("skip checksum&analyze because duplicates were detected") - shouldSkipAnalyze = true - case !needChecksum: - 
tr.logger.Info("skip checksum&analyze because other lightning instance will do this") - shouldSkipAnalyze = true - } - err = nil - nextStage = checkpoints.CheckpointStatusChecksumSkipped - } - - // Don't call FinishTable when other lightning will calculate checksum. - if err == nil && needChecksum { - err = metaMgr.FinishTable(ctx) - } - - saveCpErr := rc.saveStatusCheckpoint(ctx, tr.tableName, checkpoints.WholeTableEngineID, err, nextStage) - if err = firstErr(err, saveCpErr); err != nil { - return false, errors.Trace(err) - } - cp.Status = nextStage - } - - if cp.Status < checkpoints.CheckpointStatusIndexAdded { - var err error - if rc.cfg.TikvImporter.AddIndexBySQL { - w := rc.addIndexLimit.Apply() - err = tr.addIndexes(ctx, rc.db) - rc.addIndexLimit.Recycle(w) - // Analyze will be automatically triggered after indexes are added by SQL. We can skip manual analyze. - shouldSkipAnalyze = true - } - saveCpErr := rc.saveStatusCheckpoint(ctx, tr.tableName, checkpoints.WholeTableEngineID, err, checkpoints.CheckpointStatusIndexAdded) - if err = firstErr(err, saveCpErr); err != nil { - return false, errors.Trace(err) - } - cp.Status = checkpoints.CheckpointStatusIndexAdded - } - - // do table analyze - if cp.Status < checkpoints.CheckpointStatusAnalyzeSkipped { - switch { - case shouldSkipAnalyze || rc.cfg.PostRestore.Analyze == config.OpLevelOff: - if !shouldSkipAnalyze { - updateStatsMeta(ctx, rc.db, tr.tableInfo.ID, estimatedModifyCnt) - } - tr.logger.Info("skip analyze") - if err := rc.saveStatusCheckpoint(ctx, tr.tableName, checkpoints.WholeTableEngineID, nil, checkpoints.CheckpointStatusAnalyzeSkipped); err != nil { - return false, errors.Trace(err) - } - cp.Status = checkpoints.CheckpointStatusAnalyzeSkipped - case forcePostProcess || !rc.cfg.PostRestore.PostProcessAtLast: - err := tr.analyzeTable(ctx, rc.db) - // witch post restore level 'optional', we will skip analyze error - if rc.cfg.PostRestore.Analyze == config.OpLevelOptional { - if err != nil { - tr.logger.Warn("analyze table failed, will skip this error and go on", log.ShortError(err)) - err = nil - } - } - saveCpErr := rc.saveStatusCheckpoint(ctx, tr.tableName, checkpoints.WholeTableEngineID, err, checkpoints.CheckpointStatusAnalyzed) - if err = firstErr(err, saveCpErr); err != nil { - return false, errors.Trace(err) - } - cp.Status = checkpoints.CheckpointStatusAnalyzed - } - } - - return true, nil -} - -func getChunkCompressedSizeForParquet( - ctx context.Context, - chunk *checkpoints.ChunkCheckpoint, - store storage.ExternalStorage, -) (int64, error) { - reader, err := mydump.OpenReader(ctx, &chunk.FileMeta, store, storage.DecompressConfig{ - ZStdDecodeConcurrency: 1, - }) - if err != nil { - return 0, errors.Trace(err) - } - parser, err := mydump.NewParquetParser(ctx, store, reader, chunk.FileMeta.Path) - if err != nil { - _ = reader.Close() - return 0, errors.Trace(err) - } - //nolint: errcheck - defer parser.Close() - err = parser.Reader.ReadFooter() - if err != nil { - return 0, errors.Trace(err) - } - rowGroups := parser.Reader.Footer.GetRowGroups() - var maxRowGroupSize int64 - for _, rowGroup := range rowGroups { - var rowGroupSize int64 - columnChunks := rowGroup.GetColumns() - for _, columnChunk := range columnChunks { - columnChunkSize := columnChunk.MetaData.GetTotalCompressedSize() - rowGroupSize += columnChunkSize - } - maxRowGroupSize = max(maxRowGroupSize, rowGroupSize) - } - return maxRowGroupSize, nil -} - -func updateStatsMeta(ctx context.Context, db *sql.DB, tableID int64, count int) { - s := 
common.SQLWithRetry{ - DB: db, - Logger: log.FromContext(ctx).With(zap.Int64("tableID", tableID)), - } - err := s.Transact(ctx, "update stats_meta", func(ctx context.Context, tx *sql.Tx) error { - rs, err := tx.ExecContext(ctx, ` -update mysql.stats_meta - set modify_count = ?, - count = ?, - version = @@tidb_current_ts - where table_id = ?; -`, count, count, tableID) - if err != nil { - return errors.Trace(err) - } - affected, err := rs.RowsAffected() - if err != nil { - return errors.Trace(err) - } - if affected == 0 { - return errors.Errorf("record with table_id %d not found", tableID) - } - return nil - }) - if err != nil { - s.Logger.Warn("failed to update stats_meta", zap.Error(err)) - } -} - -func parseColumnPermutations( - tableInfo *model.TableInfo, - columns []string, - ignoreColumns map[string]struct{}, - logger log.Logger, -) ([]int, error) { - colPerm := make([]int, 0, len(tableInfo.Columns)+1) - - columnMap := make(map[string]int) - for i, column := range columns { - columnMap[column] = i - } - - tableColumnMap := make(map[string]int) - for i, col := range tableInfo.Columns { - tableColumnMap[col.Name.L] = i - } - - // check if there are some unknown columns - var unknownCols []string - for _, c := range columns { - if _, ok := tableColumnMap[c]; !ok && c != model.ExtraHandleName.L { - if _, ignore := ignoreColumns[c]; !ignore { - unknownCols = append(unknownCols, c) - } - } - } - - if len(unknownCols) > 0 { - return colPerm, common.ErrUnknownColumns.GenWithStackByArgs(strings.Join(unknownCols, ","), tableInfo.Name) - } - - for _, colInfo := range tableInfo.Columns { - if i, ok := columnMap[colInfo.Name.L]; ok { - if _, ignore := ignoreColumns[colInfo.Name.L]; !ignore { - colPerm = append(colPerm, i) - } else { - logger.Debug("column ignored by user requirements", - zap.Stringer("table", tableInfo.Name), - zap.String("colName", colInfo.Name.O), - zap.Stringer("colType", &colInfo.FieldType), - ) - colPerm = append(colPerm, -1) - } - } else { - if len(colInfo.GeneratedExprString) == 0 { - logger.Warn("column missing from data file, going to fill with default value", - zap.Stringer("table", tableInfo.Name), - zap.String("colName", colInfo.Name.O), - zap.Stringer("colType", &colInfo.FieldType), - ) - } - colPerm = append(colPerm, -1) - } - } - // append _tidb_rowid column - rowIDIdx := -1 - if i, ok := columnMap[model.ExtraHandleName.L]; ok { - if _, ignored := ignoreColumns[model.ExtraHandleName.L]; !ignored { - rowIDIdx = i - } - } - // FIXME: the schema info for tidb backend is not complete, so always add the _tidb_rowid field. - // Other logic should ignore this extra field if not needed. 
- colPerm = append(colPerm, rowIDIdx) - - return colPerm, nil -} - -func (tr *TableImporter) importKV( - ctx context.Context, - closedEngine *backend.ClosedEngine, - rc *Controller, -) error { - task := closedEngine.Logger().Begin(zap.InfoLevel, "import and cleanup engine") - regionSplitSize := int64(rc.cfg.TikvImporter.RegionSplitSize) - regionSplitKeys := int64(rc.cfg.TikvImporter.RegionSplitKeys) - - if regionSplitSize == 0 && rc.taskMgr != nil { - regionSplitSize = int64(config.SplitRegionSize) - if err := rc.taskMgr.CheckTasksExclusively(ctx, func(tasks []taskMeta) ([]taskMeta, error) { - if len(tasks) > 0 { - regionSplitSize = int64(config.SplitRegionSize) * int64(min(len(tasks), config.MaxSplitRegionSizeRatio)) - } - return nil, nil - }); err != nil { - return errors.Trace(err) - } - } - if regionSplitKeys == 0 { - if regionSplitSize > int64(config.SplitRegionSize) { - regionSplitKeys = int64(float64(regionSplitSize) / float64(config.SplitRegionSize) * float64(config.SplitRegionKeys)) - } else { - regionSplitKeys = int64(config.SplitRegionKeys) - } - } - err := closedEngine.Import(ctx, regionSplitSize, regionSplitKeys) - if common.ErrFoundDuplicateKeys.Equal(err) { - err = local.ConvertToErrFoundConflictRecords(err, tr.encTable) - } - saveCpErr := rc.saveStatusCheckpoint(ctx, tr.tableName, closedEngine.GetID(), err, checkpoints.CheckpointStatusImported) - // Don't clean up when save checkpoint failed, because we will verifyLocalFile and import engine again after restart. - if err == nil && saveCpErr == nil { - err = multierr.Append(err, closedEngine.Cleanup(ctx)) - } - err = firstErr(err, saveCpErr) - - dur := task.End(zap.ErrorLevel, err) - - if err != nil { - return errors.Trace(err) - } - - if m, ok := metric.FromContext(ctx); ok { - m.ImportSecondsHistogram.Observe(dur.Seconds()) - } - - failpoint.Inject("SlowDownImport", func() {}) - - return nil -} - -// do checksum for each table. 
-func (tr *TableImporter) compareChecksum(remoteChecksum *local.RemoteChecksum, localChecksum verify.KVChecksum) error { - if remoteChecksum.Checksum != localChecksum.Sum() || - remoteChecksum.TotalKVs != localChecksum.SumKVS() || - remoteChecksum.TotalBytes != localChecksum.SumSize() { - return common.ErrChecksumMismatch.GenWithStackByArgs( - remoteChecksum.Checksum, localChecksum.Sum(), - remoteChecksum.TotalKVs, localChecksum.SumKVS(), - remoteChecksum.TotalBytes, localChecksum.SumSize(), - ) - } - - tr.logger.Info("checksum pass", zap.Object("local", &localChecksum)) - return nil -} - -func (tr *TableImporter) analyzeTable(ctx context.Context, db *sql.DB) error { - task := tr.logger.Begin(zap.InfoLevel, "analyze") - exec := common.SQLWithRetry{ - DB: db, - Logger: tr.logger, - } - err := exec.Exec(ctx, "analyze table", "ANALYZE TABLE "+tr.tableName) - task.End(zap.ErrorLevel, err) - return err -} - -func (tr *TableImporter) dropIndexes(ctx context.Context, db *sql.DB) error { - logger := log.FromContext(ctx).With(zap.String("table", tr.tableName)) - - tblInfo := tr.tableInfo - remainIndexes, dropIndexes := common.GetDropIndexInfos(tblInfo.Core) - for _, idxInfo := range dropIndexes { - sqlStr := common.BuildDropIndexSQL(tblInfo.DB, tblInfo.Name, idxInfo) - - logger.Info("drop index", zap.String("sql", sqlStr)) - - s := common.SQLWithRetry{ - DB: db, - Logger: logger, - } - if err := s.Exec(ctx, "drop index", sqlStr); err != nil { - if merr, ok := errors.Cause(err).(*dmysql.MySQLError); ok { - switch merr.Number { - case errno.ErrCantDropFieldOrKey, errno.ErrDropIndexNeededInForeignKey: - remainIndexes = append(remainIndexes, idxInfo) - logger.Info("can't drop index, skip", zap.String("index", idxInfo.Name.O), zap.Error(err)) - continue - } - } - return common.ErrDropIndexFailed.Wrap(err).GenWithStackByArgs(common.EscapeIdentifier(idxInfo.Name.O), tr.tableName) - } - } - if len(remainIndexes) < len(tblInfo.Core.Indices) { - // Must clone (*model.TableInfo) before modifying it, since it may be referenced in other place. - tblInfo.Core = tblInfo.Core.Clone() - tblInfo.Core.Indices = remainIndexes - - // Rebuild encTable. - encTable, err := tables.TableFromMeta(tr.alloc, tblInfo.Core) - if err != nil { - return errors.Trace(err) - } - tr.encTable = encTable - } - return nil -} - -func (tr *TableImporter) addIndexes(ctx context.Context, db *sql.DB) (retErr error) { - const progressStep = "add-index" - task := tr.logger.Begin(zap.InfoLevel, "add indexes") - defer func() { - task.End(zap.ErrorLevel, retErr) - }() - - tblInfo := tr.tableInfo - tableName := tr.tableName - - singleSQL, multiSQLs := common.BuildAddIndexSQL(tableName, tblInfo.Core, tblInfo.Desired) - if len(multiSQLs) == 0 { - return nil - } - - logger := log.FromContext(ctx).With(zap.String("table", tableName)) - - defer func() { - if retErr == nil { - web.BroadcastTableProgress(tr.tableName, progressStep, 1) - } else if !log.IsContextCanceledError(retErr) { - // Try to strip the prefix of the error message. - // e.g "add index failed: Error 1062 ..." -> "Error 1062 ..." 
- cause := errors.Cause(retErr) - if cause == nil { - cause = retErr - } - retErr = common.ErrAddIndexFailed.GenWithStack( - "add index failed on table %s: %v, you can add index manually by the following SQL: %s", - tableName, cause, singleSQL) - } - }() - - var totalRows int - if m, ok := metric.FromContext(ctx); ok { - totalRows = int(metric.ReadCounter(m.RowsCounter.WithLabelValues(metric.StateRestored, tableName))) - } - - // Try to add all indexes in one statement. - err := tr.executeDDL(ctx, db, singleSQL, func(status *ddlStatus) { - if totalRows > 0 { - progress := float64(status.rowCount) / float64(totalRows*len(multiSQLs)) - if progress > 1 { - progress = 1 - } - web.BroadcastTableProgress(tableName, progressStep, progress) - logger.Info("add index progress", zap.String("progress", fmt.Sprintf("%.1f%%", progress*100))) - } - }) - if err == nil { - return nil - } - if !common.IsDupKeyError(err) { - return err - } - if len(multiSQLs) == 1 { - return nil - } - logger.Warn("cannot add all indexes in one statement, try to add them one by one", zap.Strings("sqls", multiSQLs), zap.Error(err)) - - baseProgress := float64(0) - for _, ddl := range multiSQLs { - err := tr.executeDDL(ctx, db, ddl, func(status *ddlStatus) { - if totalRows > 0 { - p := float64(status.rowCount) / float64(totalRows) - progress := baseProgress + p/float64(len(multiSQLs)) - web.BroadcastTableProgress(tableName, progressStep, progress) - logger.Info("add index progress", zap.String("progress", fmt.Sprintf("%.1f%%", progress*100))) - } - }) - if err != nil && !common.IsDupKeyError(err) { - return err - } - baseProgress += 1.0 / float64(len(multiSQLs)) - web.BroadcastTableProgress(tableName, progressStep, baseProgress) - } - return nil -} - -func (*TableImporter) executeDDL( - ctx context.Context, - db *sql.DB, - ddl string, - updateProgress func(status *ddlStatus), -) error { - logger := log.FromContext(ctx).With(zap.String("ddl", ddl)) - logger.Info("execute ddl") - - s := common.SQLWithRetry{ - DB: db, - Logger: logger, - } - - var currentTS int64 - if err := s.QueryRow(ctx, "", "SELECT UNIX_TIMESTAMP()", ¤tTS); err != nil { - currentTS = time.Now().Unix() - logger.Warn("failed to query current timestamp, use current time instead", zap.Int64("currentTS", currentTS), zap.Error(err)) - } - - resultCh := make(chan error, 1) - go func() { - resultCh <- s.Exec(ctx, "add index", ddl) - }() - - failpoint.Inject("AddIndexCrash", func() { - _ = common.KillMySelf() - }) - - var ddlErr error - for { - select { - case ddlErr = <-resultCh: - failpoint.Inject("AddIndexFail", func() { - ddlErr = errors.New("injected error") - }) - if ddlErr == nil { - return nil - } - if log.IsContextCanceledError(ddlErr) { - return ddlErr - } - if isDeterminedError(ddlErr) { - return ddlErr - } - logger.Warn("failed to execute ddl, try to query ddl status", zap.Error(ddlErr)) - case <-time.After(getDDLStatusInterval): - } - - var status *ddlStatus - err := common.Retry("query ddl status", logger, func() error { - var err error - status, err = getDDLStatus(ctx, db, ddl, time.Unix(currentTS, 0)) - return err - }) - if err != nil || status == nil { - logger.Warn("failed to query ddl status", zap.Error(err)) - if ddlErr != nil { - return ddlErr - } - continue - } - updateProgress(status) - - if ddlErr != nil { - switch state := status.state; state { - case model.JobStateDone, model.JobStateSynced: - logger.Info("ddl job is finished", zap.Stringer("state", state)) - return nil - case model.JobStateRunning, model.JobStateQueueing, model.JobStateNone: 
- logger.Info("ddl job is running", zap.Stringer("state", state)) - default: - logger.Warn("ddl job is canceled or rollbacked", zap.Stringer("state", state)) - return ddlErr - } - } - } -} - -func isDeterminedError(err error) bool { - if merr, ok := errors.Cause(err).(*dmysql.MySQLError); ok { - switch merr.Number { - case errno.ErrDupKeyName, errno.ErrMultiplePriKey, errno.ErrDupUnique, errno.ErrDupEntry: - return true - } - } - return false -} - -const ( - getDDLStatusInterval = time.Minute - // Limit the number of jobs to query. Large limit may result in empty result. See https://github.com/pingcap/tidb/issues/42298. - // A new TiDB cluster has at least 40 jobs in the history queue, so 30 is a reasonable value. - getDDLStatusMaxJobs = 30 -) - -type ddlStatus struct { - state model.JobState - rowCount int64 -} - -func getDDLStatus( - ctx context.Context, - db *sql.DB, - query string, - minCreateTime time.Time, -) (*ddlStatus, error) { - jobID, err := getDDLJobIDByQuery(ctx, db, query) - if err != nil || jobID == 0 { - return nil, err - } - rows, err := db.QueryContext(ctx, fmt.Sprintf("ADMIN SHOW DDL JOBS %d WHERE job_id = %d", getDDLStatusMaxJobs, jobID)) - if err != nil { - return nil, errors.Trace(err) - } - defer rows.Close() - - cols, err := rows.Columns() - if err != nil { - return nil, errors.Trace(err) - } - - var ( - rowCount int64 - state string - createTimeStr sql.NullString - ) - dest := make([]any, len(cols)) - for i, col := range cols { - switch strings.ToLower(col) { - case "row_count": - dest[i] = &rowCount - case "state": - dest[i] = &state - case "create_time": - dest[i] = &createTimeStr - default: - var anyStr sql.NullString - dest[i] = &anyStr - } - } - status := &ddlStatus{} - - for rows.Next() { - if err := rows.Scan(dest...); err != nil { - return nil, errors.Trace(err) - } - status.rowCount += rowCount - // subjob doesn't have create_time, ignore it. - if !createTimeStr.Valid || createTimeStr.String == "" { - continue - } - createTime, err := time.Parse(time.DateTime, createTimeStr.String) - if err != nil { - return nil, errors.Trace(err) - } - // The job is not created by the current task, ignore it. 
- if createTime.Before(minCreateTime) { - return nil, nil - } - status.state = model.StrToJobState(state) - } - return status, errors.Trace(rows.Err()) -} - -func getDDLJobIDByQuery(ctx context.Context, db *sql.DB, wantQuery string) (int64, error) { - rows, err := db.QueryContext(ctx, fmt.Sprintf("ADMIN SHOW DDL JOB QUERIES LIMIT %d", getDDLStatusMaxJobs)) - if err != nil { - return 0, errors.Trace(err) - } - defer rows.Close() - - for rows.Next() { - var ( - jobID int64 - query string - ) - if err := rows.Scan(&jobID, &query); err != nil { - return 0, errors.Trace(err) - } - if query == wantQuery { - return jobID, errors.Trace(rows.Err()) - } - } - return 0, errors.Trace(rows.Err()) -} - -func (tr *TableImporter) preDeduplicate( - ctx context.Context, - rc *Controller, - cp *checkpoints.TableCheckpoint, - workingDir string, -) error { - d := &dupDetector{ - tr: tr, - rc: rc, - cp: cp, - logger: tr.logger, - } - originalErr := d.run(ctx, workingDir, tr.dupIgnoreRows) - if originalErr == nil { - return nil - } - - if !ErrDuplicateKey.Equal(originalErr) { - return errors.Trace(originalErr) - } - - var ( - idxName string - oneConflictMsg, otherConflictMsg string - ) - - // provide a more friendly error message - - dupErr := errors.Cause(originalErr).(*errors.Error) - conflictIdxID := dupErr.Args()[0].(int64) - if conflictIdxID == conflictOnHandle { - idxName = "PRIMARY" - } else { - for _, idxInfo := range tr.tableInfo.Core.Indices { - if idxInfo.ID == conflictIdxID { - idxName = idxInfo.Name.O - break - } - } - } - if idxName == "" { - tr.logger.Error("cannot find index name", zap.Int64("conflictIdxID", conflictIdxID)) - return errors.Trace(originalErr) - } - if !rc.cfg.Checkpoint.Enable { - err := errors.Errorf("duplicate key in table %s caused by index `%s`, but because checkpoint is off we can't have more details", - tr.tableName, idxName) - rc.errorMgr.RecordDuplicateOnce( - ctx, tr.logger, tr.tableName, "", -1, err.Error(), -1, "", - ) - return err - } - conflictEncodedRowIDs := dupErr.Args()[1].([][]byte) - if len(conflictEncodedRowIDs) < 2 { - tr.logger.Error("invalid conflictEncodedRowIDs", zap.Int("len", len(conflictEncodedRowIDs))) - return errors.Trace(originalErr) - } - rowID := make([]int64, 2) - var err error - _, rowID[0], err = codec.DecodeComparableVarint(conflictEncodedRowIDs[0]) - if err != nil { - rowIDHex := hex.EncodeToString(conflictEncodedRowIDs[0]) - tr.logger.Error("failed to decode rowID", - zap.String("rowID", rowIDHex), - zap.Error(err)) - return errors.Trace(originalErr) - } - _, rowID[1], err = codec.DecodeComparableVarint(conflictEncodedRowIDs[1]) - if err != nil { - rowIDHex := hex.EncodeToString(conflictEncodedRowIDs[1]) - tr.logger.Error("failed to decode rowID", - zap.String("rowID", rowIDHex), - zap.Error(err)) - return errors.Trace(originalErr) - } - - tableCp, err := rc.checkpointsDB.Get(ctx, tr.tableName) - if err != nil { - tr.logger.Error("failed to get table checkpoint", zap.Error(err)) - return errors.Trace(err) - } - var ( - secondConflictPath string - ) - for _, engineCp := range tableCp.Engines { - for _, chunkCp := range engineCp.Chunks { - if chunkCp.Chunk.PrevRowIDMax <= rowID[0] && rowID[0] < chunkCp.Chunk.RowIDMax { - oneConflictMsg = fmt.Sprintf("row %d counting from offset %d in file %s", - rowID[0]-chunkCp.Chunk.PrevRowIDMax, - chunkCp.Chunk.Offset, - chunkCp.FileMeta.Path) - } - if chunkCp.Chunk.PrevRowIDMax <= rowID[1] && rowID[1] < chunkCp.Chunk.RowIDMax { - secondConflictPath = chunkCp.FileMeta.Path - otherConflictMsg = 
fmt.Sprintf("row %d counting from offset %d in file %s", - rowID[1]-chunkCp.Chunk.PrevRowIDMax, - chunkCp.Chunk.Offset, - chunkCp.FileMeta.Path) - } - } - } - if oneConflictMsg == "" || otherConflictMsg == "" { - tr.logger.Error("cannot find conflict rows by rowID", - zap.Int64("rowID[0]", rowID[0]), - zap.Int64("rowID[1]", rowID[1])) - return errors.Trace(originalErr) - } - err = errors.Errorf("duplicate entry for key '%s', a pair of conflicting rows are (%s, %s)", - idxName, oneConflictMsg, otherConflictMsg) - rc.errorMgr.RecordDuplicateOnce( - ctx, tr.logger, tr.tableName, secondConflictPath, -1, err.Error(), rowID[1], "", - ) - return err -} diff --git a/lightning/pkg/server/binding__failpoint_binding__.go b/lightning/pkg/server/binding__failpoint_binding__.go deleted file mode 100644 index 884841332390a..0000000000000 --- a/lightning/pkg/server/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package server - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/lightning/pkg/server/lightning.go b/lightning/pkg/server/lightning.go index 465538267411f..0c20abf10c824 100644 --- a/lightning/pkg/server/lightning.go +++ b/lightning/pkg/server/lightning.go @@ -237,12 +237,12 @@ func (l *Lightning) goServe(statusAddr string, realAddrWriter io.Writer) error { mux.HandleFunc("/debug/pprof/trace", pprof.Trace) // Enable failpoint http API for testing. - if _, _err_ := failpoint.Eval(_curpkg_("EnableTestAPI")); _err_ == nil { + failpoint.Inject("EnableTestAPI", func() { mux.HandleFunc("/fail/", func(w http.ResponseWriter, r *http.Request) { r.URL.Path = strings.TrimPrefix(r.URL.Path, "/fail") new(failpoint.HttpHandler).ServeHTTP(w, r) }) - } + }) handleTasks := http.StripPrefix("/tasks", http.HandlerFunc(l.handleTask)) mux.Handle("/tasks", httpHandleWrapper(handleTasks.ServeHTTP)) @@ -329,7 +329,7 @@ func (l *Lightning) RunOnceWithOptions(taskCtx context.Context, taskCfg *config. opt(o) } - if val, _err_ := failpoint.Eval(_curpkg_("setExtStorage")); _err_ == nil { + failpoint.Inject("setExtStorage", func(val failpoint.Value) { path := val.(string) b, err := storage.ParseBackend(path, nil) if err != nil { @@ -341,11 +341,11 @@ func (l *Lightning) RunOnceWithOptions(taskCtx context.Context, taskCfg *config. } o.dumpFileStorage = s o.checkpointStorage = s - } - if val, _err_ := failpoint.Eval(_curpkg_("setCheckpointName")); _err_ == nil { + }) + failpoint.Inject("setCheckpointName", func(val failpoint.Value) { file := val.(string) o.checkpointName = file - } + }) if o.dumpFileStorage != nil { // we don't use it, set a value to pass Adjust @@ -357,11 +357,11 @@ func (l *Lightning) RunOnceWithOptions(taskCtx context.Context, taskCfg *config. 
} taskCfg.TaskID = time.Now().UnixNano() - if val, _err_ := failpoint.Eval(_curpkg_("SetTaskID")); _err_ == nil { + failpoint.Inject("SetTaskID", func(val failpoint.Value) { taskCfg.TaskID = int64(val.(int)) - } + }) - if _, _err_ := failpoint.Eval(_curpkg_("SetIOTotalBytes")); _err_ == nil { + failpoint.Inject("SetIOTotalBytes", func(_ failpoint.Value) { o.logger.Info("set io total bytes") taskCfg.TiDB.IOTotalBytes = atomic.NewUint64(0) taskCfg.TiDB.UUID = uuid.New().String() @@ -371,7 +371,7 @@ func (l *Lightning) RunOnceWithOptions(taskCtx context.Context, taskCfg *config. log.L().Info("IOTotalBytes", zap.Uint64("IOTotalBytes", taskCfg.TiDB.IOTotalBytes.Load())) } }() - } + }) if taskCfg.TiDB.IOTotalBytes != nil { o.logger.Info("found IO total bytes counter") mysql.RegisterDialContext(taskCfg.TiDB.UUID, func(ctx context.Context, addr string) (net.Conn, error) { @@ -463,7 +463,7 @@ func (l *Lightning) run(taskCtx context.Context, taskCfg *config.Config, o *opti web.BroadcastEndTask(err) }() - if _, _err_ := failpoint.Eval(_curpkg_("SkipRunTask")); _err_ == nil { + failpoint.Inject("SkipRunTask", func() { if notifyCh, ok := l.ctx.Value(taskRunNotifyKey).(chan struct{}); ok { select { case notifyCh <- struct{}{}: @@ -474,13 +474,13 @@ func (l *Lightning) run(taskCtx context.Context, taskCfg *config.Config, o *opti select { case recorder <- taskCfg: case <-ctx.Done(): - return ctx.Err() + failpoint.Return(ctx.Err()) } } - return nil - } + failpoint.Return(nil) + }) - if val, _err_ := failpoint.Eval(_curpkg_("SetCertExpiredSoon")); _err_ == nil { + failpoint.Inject("SetCertExpiredSoon", func(val failpoint.Value) { rootKeyPath := val.(string) rootCaPath := taskCfg.Security.CAPath keyPath := taskCfg.Security.KeyPath @@ -488,9 +488,9 @@ func (l *Lightning) run(taskCtx context.Context, taskCfg *config.Config, o *opti if err := updateCertExpiry(rootKeyPath, rootCaPath, keyPath, certPath, time.Second*10); err != nil { panic(err) } - } + }) - if _, _err_ := failpoint.Eval(_curpkg_("PrintStatus")); _err_ == nil { + failpoint.Inject("PrintStatus", func() { defer func() { finished, total := l.Status() o.logger.Warn("PrintStatus Failpoint", @@ -498,7 +498,7 @@ func (l *Lightning) run(taskCtx context.Context, taskCfg *config.Config, o *opti zap.Int64("total", total), zap.Bool("equal", finished == total)) }() - } + }) if err := taskCfg.TiDB.Security.BuildTLSConfig(); err != nil { return common.ErrInvalidTLSConfig.Wrap(err) @@ -602,10 +602,10 @@ func (l *Lightning) run(taskCtx context.Context, taskCfg *config.Config, o *opti return errors.Trace(err) } - if _, _err_ := failpoint.Eval(_curpkg_("orphanWriterGoRoutine")); _err_ == nil { + failpoint.Inject("orphanWriterGoRoutine", func() { // don't exit too quickly to expose panic defer time.Sleep(time.Second * 10) - } + }) defer procedure.Close() err = procedure.Run(ctx) diff --git a/lightning/pkg/server/lightning.go__failpoint_stash__ b/lightning/pkg/server/lightning.go__failpoint_stash__ deleted file mode 100644 index 0c20abf10c824..0000000000000 --- a/lightning/pkg/server/lightning.go__failpoint_stash__ +++ /dev/null @@ -1,1152 +0,0 @@ -// Copyright 2019 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package server - -import ( - "cmp" - "compress/gzip" - "context" - "crypto/ecdsa" - "crypto/rand" - "crypto/tls" - "crypto/x509" - "database/sql" - "encoding/json" - "encoding/pem" - "fmt" - "io" - "net" - "net/http" - "net/http/pprof" - "os" - "slices" - "strconv" - "strings" - "sync" - "time" - - "github.com/go-sql-driver/mysql" - "github.com/google/uuid" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/import_sstpb" - "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/tidb/br/pkg/restore/split" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/br/pkg/version/build" - "github.com/pingcap/tidb/lightning/pkg/importer" - "github.com/pingcap/tidb/lightning/pkg/web" - _ "github.com/pingcap/tidb/pkg/expression" // get rid of `import cycle`: just init expression.RewriteAstExpr,and called at package `backend.kv`. - "github.com/pingcap/tidb/pkg/lightning/backend/local" - "github.com/pingcap/tidb/pkg/lightning/checkpoints" - "github.com/pingcap/tidb/pkg/lightning/common" - "github.com/pingcap/tidb/pkg/lightning/config" - "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/lightning/metric" - "github.com/pingcap/tidb/pkg/lightning/mydump" - "github.com/pingcap/tidb/pkg/lightning/tikv" - _ "github.com/pingcap/tidb/pkg/planner/core" // init expression.EvalSimpleAst related function - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/promutil" - "github.com/pingcap/tidb/pkg/util/redact" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/collectors" - "github.com/prometheus/client_golang/prometheus/promhttp" - "github.com/shurcooL/httpgzip" - pdhttp "github.com/tikv/pd/client/http" - "go.uber.org/atomic" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" -) - -// Lightning is the main struct of the lightning package. -type Lightning struct { - globalCfg *config.GlobalConfig - globalTLS *common.TLS - // taskCfgs is the list of task configurations enqueued in the server mode - taskCfgs *config.List - ctx context.Context - shutdown context.CancelFunc // for whole lightning context - server http.Server - serverAddr net.Addr - serverLock sync.Mutex - status importer.LightningStatus - - promFactory promutil.Factory - promRegistry promutil.Registry - metrics *metric.Metrics - - cancelLock sync.Mutex - curTask *config.Config - cancel context.CancelFunc // for per task context, which maybe different from lightning context - - taskCanceled bool -} - -func initEnv(cfg *config.GlobalConfig) error { - if cfg.App.Config.File == "" { - return nil - } - return log.InitLogger(&cfg.App.Config, cfg.TiDB.LogLevel) -} - -// New creates a new Lightning instance. 
-func New(globalCfg *config.GlobalConfig) *Lightning { - if err := initEnv(globalCfg); err != nil { - fmt.Println("Failed to initialize environment:", err) - os.Exit(1) - } - - tls, err := common.NewTLS( - globalCfg.Security.CAPath, - globalCfg.Security.CertPath, - globalCfg.Security.KeyPath, - globalCfg.App.StatusAddr, - globalCfg.Security.CABytes, - globalCfg.Security.CertBytes, - globalCfg.Security.KeyBytes, - ) - if err != nil { - log.L().Fatal("failed to load TLS certificates", zap.Error(err)) - } - - redact.InitRedact(globalCfg.Security.RedactInfoLog) - - promFactory := promutil.NewDefaultFactory() - promRegistry := promutil.NewDefaultRegistry() - ctx, shutdown := context.WithCancel(context.Background()) - return &Lightning{ - globalCfg: globalCfg, - globalTLS: tls, - ctx: ctx, - shutdown: shutdown, - promFactory: promFactory, - promRegistry: promRegistry, - } -} - -// GoServe starts the HTTP server in a goroutine. The server will be closed -func (l *Lightning) GoServe() error { - handleSigUsr1(func() { - l.serverLock.Lock() - statusAddr := l.globalCfg.App.StatusAddr - shouldStartServer := len(statusAddr) == 0 - if shouldStartServer { - l.globalCfg.App.StatusAddr = ":" - } - l.serverLock.Unlock() - - if shouldStartServer { - // open a random port and start the server if SIGUSR1 is received. - if err := l.goServe(":", os.Stderr); err != nil { - log.L().Warn("failed to start HTTP server", log.ShortError(err)) - } - } else { - // just prints the server address if it is already started. - log.L().Info("already started HTTP server", zap.Stringer("address", l.serverAddr)) - } - }) - - l.serverLock.Lock() - statusAddr := l.globalCfg.App.StatusAddr - l.serverLock.Unlock() - - if len(statusAddr) == 0 { - return nil - } - return l.goServe(statusAddr, io.Discard) -} - -// TODO: maybe handle http request using gin -type loggingResponseWriter struct { - http.ResponseWriter - statusCode int - body string -} - -func newLoggingResponseWriter(w http.ResponseWriter) *loggingResponseWriter { - return &loggingResponseWriter{ResponseWriter: w, statusCode: http.StatusOK} -} - -// WriteHeader implements http.ResponseWriter. -func (lrw *loggingResponseWriter) WriteHeader(code int) { - lrw.statusCode = code - lrw.ResponseWriter.WriteHeader(code) -} - -// Write implements http.ResponseWriter. -func (lrw *loggingResponseWriter) Write(d []byte) (int, error) { - // keep first part of the response for logging, max 1K - if lrw.body == "" && len(d) > 0 { - length := len(d) - if length > 1024 { - length = 1024 - } - lrw.body = string(d[:length]) - } - return lrw.ResponseWriter.Write(d) -} - -func httpHandleWrapper(h http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - logger := log.L().With(zap.String("method", r.Method), zap.Stringer("url", r.URL)). 
- Begin(zapcore.InfoLevel, "process http request") - - newWriter := newLoggingResponseWriter(w) - h.ServeHTTP(newWriter, r) - - bodyField := zap.Skip() - if newWriter.Header().Get("Content-Encoding") != "gzip" { - bodyField = zap.String("body", newWriter.body) - } - logger.End(zapcore.InfoLevel, nil, zap.Int("status", newWriter.statusCode), bodyField) - } -} - -func (l *Lightning) goServe(statusAddr string, realAddrWriter io.Writer) error { - mux := http.NewServeMux() - mux.Handle("/", http.RedirectHandler("/web/", http.StatusFound)) - - registry := l.promRegistry - registry.MustRegister(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{})) - registry.MustRegister(collectors.NewGoCollector()) - if gatherer, ok := registry.(prometheus.Gatherer); ok { - handler := promhttp.InstrumentMetricHandler( - registry, promhttp.HandlerFor(gatherer, promhttp.HandlerOpts{}), - ) - mux.Handle("/metrics", handler) - } - - mux.HandleFunc("/debug/pprof/", pprof.Index) - mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) - mux.HandleFunc("/debug/pprof/profile", pprof.Profile) - mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) - mux.HandleFunc("/debug/pprof/trace", pprof.Trace) - - // Enable failpoint http API for testing. - failpoint.Inject("EnableTestAPI", func() { - mux.HandleFunc("/fail/", func(w http.ResponseWriter, r *http.Request) { - r.URL.Path = strings.TrimPrefix(r.URL.Path, "/fail") - new(failpoint.HttpHandler).ServeHTTP(w, r) - }) - }) - - handleTasks := http.StripPrefix("/tasks", http.HandlerFunc(l.handleTask)) - mux.Handle("/tasks", httpHandleWrapper(handleTasks.ServeHTTP)) - mux.Handle("/tasks/", httpHandleWrapper(handleTasks.ServeHTTP)) - mux.HandleFunc("/progress/task", httpHandleWrapper(handleProgressTask)) - mux.HandleFunc("/progress/table", httpHandleWrapper(handleProgressTable)) - mux.HandleFunc("/pause", httpHandleWrapper(handlePause)) - mux.HandleFunc("/resume", httpHandleWrapper(handleResume)) - mux.HandleFunc("/loglevel", httpHandleWrapper(handleLogLevel)) - - mux.Handle("/web/", http.StripPrefix("/web", httpgzip.FileServer(web.Res, httpgzip.FileServerOptions{ - IndexHTML: true, - ServeError: func(w http.ResponseWriter, req *http.Request, err error) { - if os.IsNotExist(err) && !strings.Contains(req.URL.Path, ".") { - http.Redirect(w, req, "/web/", http.StatusFound) - } else { - httpgzip.NonSpecific(w, req, err) - } - }, - }))) - - listener, err := net.Listen("tcp", statusAddr) - if err != nil { - return err - } - l.serverAddr = listener.Addr() - log.L().Info("starting HTTP server", zap.Stringer("address", l.serverAddr)) - fmt.Fprintln(realAddrWriter, "started HTTP server on", l.serverAddr) - l.server.Handler = mux - listener = l.globalTLS.WrapListener(listener) - - go func() { - err := l.server.Serve(listener) - log.L().Info("stopped HTTP server", log.ShortError(err)) - }() - return nil -} - -// RunServer is used by binary lightning to start a HTTP server to receive import tasks. 
-func (l *Lightning) RunServer() error { - l.serverLock.Lock() - l.taskCfgs = config.NewConfigList() - l.serverLock.Unlock() - log.L().Info( - "Lightning server is running, post to /tasks to start an import task", - zap.Stringer("address", l.serverAddr), - ) - - for { - task, err := l.taskCfgs.Pop(l.ctx) - if err != nil { - return err - } - o := &options{ - promFactory: l.promFactory, - promRegistry: l.promRegistry, - logger: log.L(), - } - err = l.run(context.Background(), task, o) - if err != nil && !common.IsContextCanceledError(err) { - importer.DeliverPauser.Pause() // force pause the progress on error - log.L().Error("tidb lightning encountered error", zap.Error(err)) - } - } -} - -// RunOnceWithOptions is used by binary lightning and host when using lightning as a library. -// - for binary lightning, taskCtx could be context.Background which means taskCtx wouldn't be canceled directly by its -// cancel function, but only by Lightning.Stop or HTTP DELETE using l.cancel. No need to set Options -// - for lightning as a library, taskCtx could be a meaningful context that get canceled outside, and there Options may -// be used: -// - WithGlue: set a caller implemented glue. Otherwise, lightning will use a default glue later. -// - WithDumpFileStorage: caller has opened an external storage for lightning. Otherwise, lightning will open a -// storage by config -// - WithCheckpointStorage: caller has opened an external storage for lightning and want to save checkpoint -// in it. Otherwise, lightning will save checkpoint by the Checkpoint.DSN in config -func (l *Lightning) RunOnceWithOptions(taskCtx context.Context, taskCfg *config.Config, opts ...Option) error { - o := &options{ - promFactory: l.promFactory, - promRegistry: l.promRegistry, - logger: log.L(), - } - for _, opt := range opts { - opt(o) - } - - failpoint.Inject("setExtStorage", func(val failpoint.Value) { - path := val.(string) - b, err := storage.ParseBackend(path, nil) - if err != nil { - panic(err) - } - s, err := storage.New(context.Background(), b, &storage.ExternalStorageOptions{}) - if err != nil { - panic(err) - } - o.dumpFileStorage = s - o.checkpointStorage = s - }) - failpoint.Inject("setCheckpointName", func(val failpoint.Value) { - file := val.(string) - o.checkpointName = file - }) - - if o.dumpFileStorage != nil { - // we don't use it, set a value to pass Adjust - taskCfg.Mydumper.SourceDir = "noop://" - } - - if err := taskCfg.Adjust(taskCtx); err != nil { - return err - } - - taskCfg.TaskID = time.Now().UnixNano() - failpoint.Inject("SetTaskID", func(val failpoint.Value) { - taskCfg.TaskID = int64(val.(int)) - }) - - failpoint.Inject("SetIOTotalBytes", func(_ failpoint.Value) { - o.logger.Info("set io total bytes") - taskCfg.TiDB.IOTotalBytes = atomic.NewUint64(0) - taskCfg.TiDB.UUID = uuid.New().String() - go func() { - for { - time.Sleep(time.Millisecond * 10) - log.L().Info("IOTotalBytes", zap.Uint64("IOTotalBytes", taskCfg.TiDB.IOTotalBytes.Load())) - } - }() - }) - if taskCfg.TiDB.IOTotalBytes != nil { - o.logger.Info("found IO total bytes counter") - mysql.RegisterDialContext(taskCfg.TiDB.UUID, func(ctx context.Context, addr string) (net.Conn, error) { - o.logger.Debug("connection with IO bytes counter") - d := &net.Dialer{} - conn, err := d.DialContext(ctx, "tcp", addr) - if err != nil { - return nil, err - } - tcpConn := conn.(*net.TCPConn) - // try https://github.com/go-sql-driver/mysql/blob/bcc459a906419e2890a50fc2c99ea6dd927a88f2/connector.go#L56-L64 - err = tcpConn.SetKeepAlive(true) - if err != nil 
{ - o.logger.Warn("set TCP keep alive failed", zap.Error(err)) - } - return util.NewTCPConnWithIOCounter(tcpConn, taskCfg.TiDB.IOTotalBytes), nil - }) - } - - return l.run(taskCtx, taskCfg, o) -} - -var ( - taskRunNotifyKey = "taskRunNotifyKey" - taskCfgRecorderKey = "taskCfgRecorderKey" -) - -func getKeyspaceName(db *sql.DB) (string, error) { - if db == nil { - return "", nil - } - - rows, err := db.Query("show config where Type = 'tidb' and name = 'keyspace-name'") - if err != nil { - return "", err - } - //nolint: errcheck - defer rows.Close() - - var ( - _type string - _instance string - _name string - value string - ) - if rows.Next() { - err = rows.Scan(&_type, &_instance, &_name, &value) - if err != nil { - return "", err - } - } - - return value, rows.Err() -} - -func (l *Lightning) run(taskCtx context.Context, taskCfg *config.Config, o *options) (err error) { - build.LogInfo(build.Lightning) - o.logger.Info("cfg", zap.Stringer("cfg", taskCfg)) - - logutil.LogEnvVariables() - - if split.WaitRegionOnlineAttemptTimes != taskCfg.TikvImporter.RegionCheckBackoffLimit { - // it will cause data race if lightning is used as a library, but this is a - // hidden config so we ignore that case - split.WaitRegionOnlineAttemptTimes = taskCfg.TikvImporter.RegionCheckBackoffLimit - } - - metrics := metric.NewMetrics(o.promFactory) - metrics.RegisterTo(o.promRegistry) - defer func() { - metrics.UnregisterFrom(o.promRegistry) - }() - l.metrics = metrics - - ctx := metric.WithMetric(taskCtx, metrics) - ctx = log.NewContext(ctx, o.logger) - ctx, cancel := context.WithCancel(ctx) - l.cancelLock.Lock() - l.cancel = cancel - l.curTask = taskCfg - l.cancelLock.Unlock() - web.BroadcastStartTask() - - defer func() { - cancel() - l.cancelLock.Lock() - l.cancel = nil - l.cancelLock.Unlock() - web.BroadcastEndTask(err) - }() - - failpoint.Inject("SkipRunTask", func() { - if notifyCh, ok := l.ctx.Value(taskRunNotifyKey).(chan struct{}); ok { - select { - case notifyCh <- struct{}{}: - default: - } - } - if recorder, ok := l.ctx.Value(taskCfgRecorderKey).(chan *config.Config); ok { - select { - case recorder <- taskCfg: - case <-ctx.Done(): - failpoint.Return(ctx.Err()) - } - } - failpoint.Return(nil) - }) - - failpoint.Inject("SetCertExpiredSoon", func(val failpoint.Value) { - rootKeyPath := val.(string) - rootCaPath := taskCfg.Security.CAPath - keyPath := taskCfg.Security.KeyPath - certPath := taskCfg.Security.CertPath - if err := updateCertExpiry(rootKeyPath, rootCaPath, keyPath, certPath, time.Second*10); err != nil { - panic(err) - } - }) - - failpoint.Inject("PrintStatus", func() { - defer func() { - finished, total := l.Status() - o.logger.Warn("PrintStatus Failpoint", - zap.Int64("finished", finished), - zap.Int64("total", total), - zap.Bool("equal", finished == total)) - }() - }) - - if err := taskCfg.TiDB.Security.BuildTLSConfig(); err != nil { - return common.ErrInvalidTLSConfig.Wrap(err) - } - - s := o.dumpFileStorage - if s == nil { - u, err := storage.ParseBackend(taskCfg.Mydumper.SourceDir, nil) - if err != nil { - return common.NormalizeError(err) - } - s, err = storage.New(ctx, u, &storage.ExternalStorageOptions{}) - if err != nil { - return common.NormalizeError(err) - } - } - - // return expectedErr means at least meet one file - expectedErr := errors.New("Stop Iter") - walkErr := s.WalkDir(ctx, &storage.WalkOption{ListCount: 1}, func(string, int64) error { - // return an error when meet the first regular file to break the walk loop - return expectedErr - }) - if !errors.ErrorEqual(walkErr, 
expectedErr) { - if walkErr == nil { - return common.ErrEmptySourceDir.GenWithStackByArgs(taskCfg.Mydumper.SourceDir) - } - return common.NormalizeOrWrapErr(common.ErrStorageUnknown, walkErr) - } - - loadTask := o.logger.Begin(zap.InfoLevel, "load data source") - var mdl *mydump.MDLoader - mdl, err = mydump.NewLoaderWithStore(ctx, mydump.NewLoaderCfg(taskCfg), s) - loadTask.End(zap.ErrorLevel, err) - if err != nil { - return errors.Trace(err) - } - err = checkSystemRequirement(taskCfg, mdl.GetDatabases()) - if err != nil { - o.logger.Error("check system requirements failed", zap.Error(err)) - return common.ErrSystemRequirementNotMet.Wrap(err).GenWithStackByArgs() - } - // check table schema conflicts - err = checkSchemaConflict(taskCfg, mdl.GetDatabases()) - if err != nil { - o.logger.Error("checkpoint schema conflicts with data files", zap.Error(err)) - return errors.Trace(err) - } - - dbMetas := mdl.GetDatabases() - web.BroadcastInitProgress(dbMetas) - - // db is only not nil in unit test - db := o.db - if db == nil { - // initiation of default db should be after BuildTLSConfig - db, err = importer.DBFromConfig(ctx, taskCfg.TiDB) - if err != nil { - return common.ErrDBConnect.Wrap(err) - } - } - - var keyspaceName string - if taskCfg.TikvImporter.Backend == config.BackendLocal { - keyspaceName = taskCfg.TikvImporter.KeyspaceName - if keyspaceName == "" { - keyspaceName, err = getKeyspaceName(db) - if err != nil && common.IsAccessDeniedNeedConfigPrivilegeError(err) { - // if the cluster is not multitenant we don't really need to know about the keyspace. - // since the doc does not say we require CONFIG privilege, - // spelling out the Access Denied error just confuses the users. - // hide such allowed errors unless log level is DEBUG. - o.logger.Info("keyspace is unspecified and target user has no config privilege, assuming dedicated cluster") - if o.logger.Level() > zapcore.DebugLevel { - err = nil - } - } - if err != nil { - o.logger.Warn("unable to get keyspace name, lightning will use empty keyspace name", zap.Error(err)) - } - } - o.logger.Info("acquired keyspace name", zap.String("keyspaceName", keyspaceName)) - } - - param := &importer.ControllerParam{ - DBMetas: dbMetas, - Status: &l.status, - DumpFileStorage: s, - OwnExtStorage: o.dumpFileStorage == nil, - DB: db, - CheckpointStorage: o.checkpointStorage, - CheckpointName: o.checkpointName, - DupIndicator: o.dupIndicator, - KeyspaceName: keyspaceName, - } - - var procedure *importer.Controller - procedure, err = importer.NewImportController(ctx, taskCfg, param) - if err != nil { - o.logger.Error("restore failed", log.ShortError(err)) - return errors.Trace(err) - } - - failpoint.Inject("orphanWriterGoRoutine", func() { - // don't exit too quickly to expose panic - defer time.Sleep(time.Second * 10) - }) - defer procedure.Close() - - err = procedure.Run(ctx) - return errors.Trace(err) -} - -// Stop stops the lightning server. -func (l *Lightning) Stop() { - l.cancelLock.Lock() - if l.cancel != nil { - l.taskCanceled = true - l.cancel() - } - l.cancelLock.Unlock() - if err := l.server.Shutdown(l.ctx); err != nil { - log.L().Warn("failed to shutdown HTTP server", log.ShortError(err)) - } - l.shutdown() -} - -// TaskCanceled return whether the current task is canceled. -func (l *Lightning) TaskCanceled() bool { - l.cancelLock.Lock() - defer l.cancelLock.Unlock() - return l.taskCanceled -} - -// Status return the sum size of file which has been imported to TiKV and the total size of source file. 
-func (l *Lightning) Status() (finished int64, total int64) { - finished = l.status.FinishedFileSize.Load() - total = l.status.TotalFileSize.Load() - return -} - -// Metrics returns the metrics of lightning. -// it's inited during `run`, so might return nil. -func (l *Lightning) Metrics() *metric.Metrics { - return l.metrics -} - -func writeJSONError(w http.ResponseWriter, code int, prefix string, err error) { - type errorResponse struct { - Error string `json:"error"` - } - - w.WriteHeader(code) - - if err != nil { - prefix += ": " + err.Error() - } - _ = json.NewEncoder(w).Encode(errorResponse{Error: prefix}) -} - -func parseTaskID(req *http.Request) (int64, string, error) { - path := strings.TrimPrefix(req.URL.Path, "/") - taskIDString := path - verb := "" - if i := strings.IndexByte(path, '/'); i >= 0 { - taskIDString = path[:i] - verb = path[i+1:] - } - - taskID, err := strconv.ParseInt(taskIDString, 10, 64) - if err != nil { - return 0, "", err - } - - return taskID, verb, nil -} - -func (l *Lightning) handleTask(w http.ResponseWriter, req *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch req.Method { - case http.MethodGet: - taskID, _, err := parseTaskID(req) - // golint tells us to refactor this with switch stmt. - // However switch stmt doesn't support init-statements, - // hence if we follow it things might be worse. - // Anyway, this chain of if-else isn't unacceptable. - //nolint:gocritic - if e, ok := err.(*strconv.NumError); ok && e.Num == "" { - l.handleGetTask(w) - } else if err == nil { - l.handleGetOneTask(w, req, taskID) - } else { - writeJSONError(w, http.StatusBadRequest, "invalid task ID", err) - } - case http.MethodPost: - l.handlePostTask(w, req) - case http.MethodDelete: - l.handleDeleteOneTask(w, req) - case http.MethodPatch: - l.handlePatchOneTask(w, req) - default: - w.Header().Set("Allow", http.MethodGet+", "+http.MethodPost+", "+http.MethodDelete+", "+http.MethodPatch) - writeJSONError(w, http.StatusMethodNotAllowed, "only GET, POST, DELETE and PATCH are allowed", nil) - } -} - -func (l *Lightning) handleGetTask(w http.ResponseWriter) { - var response struct { - Current *int64 `json:"current"` - QueuedIDs []int64 `json:"queue"` - } - l.serverLock.Lock() - if l.taskCfgs != nil { - response.QueuedIDs = l.taskCfgs.AllIDs() - } else { - response.QueuedIDs = []int64{} - } - l.serverLock.Unlock() - - l.cancelLock.Lock() - if l.cancel != nil && l.curTask != nil { - response.Current = new(int64) - *response.Current = l.curTask.TaskID - } - l.cancelLock.Unlock() - - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(response) -} - -func (l *Lightning) handleGetOneTask(w http.ResponseWriter, req *http.Request, taskID int64) { - var task *config.Config - - l.cancelLock.Lock() - if l.curTask != nil && l.curTask.TaskID == taskID { - task = l.curTask - } - l.cancelLock.Unlock() - - if task == nil && l.taskCfgs != nil { - task, _ = l.taskCfgs.Get(taskID) - } - - if task == nil { - writeJSONError(w, http.StatusNotFound, "task ID not found", nil) - return - } - - json, err := json.Marshal(task) - if err != nil { - writeJSONError(w, http.StatusInternalServerError, "unable to serialize task", err) - return - } - - writeBytesCompressed(w, req, json) -} - -func (l *Lightning) handlePostTask(w http.ResponseWriter, req *http.Request) { - w.Header().Set("Cache-Control", "no-store") - l.serverLock.Lock() - defer l.serverLock.Unlock() - if l.taskCfgs == nil { - // l.taskCfgs is non-nil only if Lightning is started with RunServer(). 
- // Without the server mode this pointer is default to be nil. - writeJSONError(w, http.StatusNotImplemented, "server-mode not enabled", nil) - return - } - - type taskResponse struct { - ID int64 `json:"id"` - } - - data, err := io.ReadAll(req.Body) - if err != nil { - writeJSONError(w, http.StatusBadRequest, "cannot read request", err) - return - } - log.L().Info("received task config") - - cfg := config.NewConfig() - if err = cfg.LoadFromGlobal(l.globalCfg); err != nil { - writeJSONError(w, http.StatusInternalServerError, "cannot restore from global config", err) - return - } - if err = cfg.LoadFromTOML(data); err != nil { - writeJSONError(w, http.StatusBadRequest, "cannot parse task (must be TOML)", err) - return - } - if err = cfg.Adjust(l.ctx); err != nil { - writeJSONError(w, http.StatusBadRequest, "invalid task configuration", err) - return - } - - l.taskCfgs.Push(cfg) - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(taskResponse{ID: cfg.TaskID}) -} - -func (l *Lightning) handleDeleteOneTask(w http.ResponseWriter, req *http.Request) { - w.Header().Set("Content-Type", "application/json") - - taskID, _, err := parseTaskID(req) - if err != nil { - writeJSONError(w, http.StatusBadRequest, "invalid task ID", err) - return - } - - var cancel context.CancelFunc - cancelSuccess := false - - l.cancelLock.Lock() - if l.cancel != nil && l.curTask != nil && l.curTask.TaskID == taskID { - cancel = l.cancel - l.cancel = nil - } - l.cancelLock.Unlock() - - if cancel != nil { - cancel() - cancelSuccess = true - } else if l.taskCfgs != nil { - cancelSuccess = l.taskCfgs.Remove(taskID) - } - - log.L().Info("canceled task", zap.Int64("taskID", taskID), zap.Bool("success", cancelSuccess)) - - if cancelSuccess { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("{}")) - } else { - writeJSONError(w, http.StatusNotFound, "task ID not found", nil) - } -} - -func (l *Lightning) handlePatchOneTask(w http.ResponseWriter, req *http.Request) { - if l.taskCfgs == nil { - writeJSONError(w, http.StatusNotImplemented, "server-mode not enabled", nil) - return - } - - taskID, verb, err := parseTaskID(req) - if err != nil { - writeJSONError(w, http.StatusBadRequest, "invalid task ID", err) - return - } - - moveSuccess := false - switch verb { - case "front": - moveSuccess = l.taskCfgs.MoveToFront(taskID) - case "back": - moveSuccess = l.taskCfgs.MoveToBack(taskID) - default: - writeJSONError(w, http.StatusBadRequest, "unknown patch action", nil) - return - } - - if moveSuccess { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("{}")) - } else { - writeJSONError(w, http.StatusNotFound, "task ID not found", nil) - } -} - -func writeBytesCompressed(w http.ResponseWriter, req *http.Request, b []byte) { - if !strings.Contains(req.Header.Get("Accept-Encoding"), "gzip") { - _, _ = w.Write(b) - return - } - - w.Header().Set("Content-Encoding", "gzip") - w.WriteHeader(http.StatusOK) - gw, _ := gzip.NewWriterLevel(w, gzip.BestSpeed) - _, _ = gw.Write(b) - _ = gw.Close() -} - -func handleProgressTask(w http.ResponseWriter, req *http.Request) { - w.Header().Set("Content-Type", "application/json") - res, err := web.MarshalTaskProgress() - if err == nil { - writeBytesCompressed(w, req, res) - } else { - w.WriteHeader(http.StatusInternalServerError) - _ = json.NewEncoder(w).Encode(err.Error()) - } -} - -func handleProgressTable(w http.ResponseWriter, req *http.Request) { - w.Header().Set("Content-Type", "application/json") - tableName := req.URL.Query().Get("t") - res, err := 
web.MarshalTableCheckpoints(tableName) - if err == nil { - writeBytesCompressed(w, req, res) - } else { - if errors.IsNotFound(err) { - w.WriteHeader(http.StatusNotFound) - } else { - w.WriteHeader(http.StatusInternalServerError) - } - _ = json.NewEncoder(w).Encode(err.Error()) - } -} - -func handlePause(w http.ResponseWriter, req *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch req.Method { - case http.MethodGet: - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `{"paused":%v}`, importer.DeliverPauser.IsPaused()) - - case http.MethodPut: - w.WriteHeader(http.StatusOK) - importer.DeliverPauser.Pause() - log.L().Info("progress paused") - _, _ = w.Write([]byte("{}")) - - default: - w.Header().Set("Allow", http.MethodGet+", "+http.MethodPut) - writeJSONError(w, http.StatusMethodNotAllowed, "only GET and PUT are allowed", nil) - } -} - -func handleResume(w http.ResponseWriter, req *http.Request) { - w.Header().Set("Content-Type", "application/json") - - switch req.Method { - case http.MethodPut: - w.WriteHeader(http.StatusOK) - importer.DeliverPauser.Resume() - log.L().Info("progress resumed") - _, _ = w.Write([]byte("{}")) - - default: - w.Header().Set("Allow", http.MethodPut) - writeJSONError(w, http.StatusMethodNotAllowed, "only PUT is allowed", nil) - } -} - -func handleLogLevel(w http.ResponseWriter, req *http.Request) { - w.Header().Set("Content-Type", "application/json") - - var logLevel struct { - Level zapcore.Level `json:"level"` - } - - switch req.Method { - case http.MethodGet: - logLevel.Level = log.Level() - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(logLevel) - - case http.MethodPut, http.MethodPost: - if err := json.NewDecoder(req.Body).Decode(&logLevel); err != nil { - writeJSONError(w, http.StatusBadRequest, "invalid log level", err) - return - } - oldLevel := log.SetLevel(zapcore.InfoLevel) - log.L().Info("changed log level. No effects if task has specified its logger", - zap.Stringer("old", oldLevel), - zap.Stringer("new", logLevel.Level)) - log.SetLevel(logLevel.Level) - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("{}")) - - default: - w.Header().Set("Allow", http.MethodGet+", "+http.MethodPut+", "+http.MethodPost) - writeJSONError(w, http.StatusMethodNotAllowed, "only GET, PUT and POST are allowed", nil) - } -} - -func checkSystemRequirement(cfg *config.Config, dbsMeta []*mydump.MDDatabaseMeta) error { - // in local mode, we need to read&write a lot of L0 sst files, so we need to check system max open files limit - if cfg.TikvImporter.Backend == config.BackendLocal { - // estimate max open files = {top N(TableConcurrency) table sizes} / {MemoryTableSize} - tableTotalSizes := make([]int64, 0) - for _, dbs := range dbsMeta { - for _, tb := range dbs.Tables { - tableTotalSizes = append(tableTotalSizes, tb.TotalSize) - } - } - slices.SortFunc(tableTotalSizes, func(i, j int64) int { - return cmp.Compare(j, i) - }) - topNTotalSize := int64(0) - for i := 0; i < len(tableTotalSizes) && i < cfg.App.TableConcurrency; i++ { - topNTotalSize += tableTotalSizes[i] - } - - // region-concurrency: number of LocalWriters writing SST files. - // 2*totalSize/memCacheSize: number of Pebble MemCache files. - maxDBFiles := topNTotalSize / int64(cfg.TikvImporter.LocalWriterMemCacheSize) * 2 - // the pebble db and all import routine need upto maxDBFiles fds for read and write. 
- maxOpenDBFiles := maxDBFiles * (1 + int64(cfg.TikvImporter.RangeConcurrency)) - estimateMaxFiles := local.RlimT(cfg.App.RegionConcurrency) + local.RlimT(maxOpenDBFiles) - if err := local.VerifyRLimit(estimateMaxFiles); err != nil { - return err - } - } - - return nil -} - -// checkSchemaConflict return error if checkpoint table scheme is conflict with data files -func checkSchemaConflict(cfg *config.Config, dbsMeta []*mydump.MDDatabaseMeta) error { - if cfg.Checkpoint.Enable && cfg.Checkpoint.Driver == config.CheckpointDriverMySQL { - for _, db := range dbsMeta { - if db.Name == cfg.Checkpoint.Schema { - for _, tb := range db.Tables { - if checkpoints.IsCheckpointTable(tb.Name) { - return common.ErrCheckpointSchemaConflict.GenWithStack("checkpoint table `%s`.`%s` conflict with data files. Please change the `checkpoint.schema` config or set `checkpoint.driver` to \"file\" instead", db.Name, tb.Name) - } - } - } - } - } - return nil -} - -// CheckpointRemove removes the checkpoint of the given table. -func CheckpointRemove(ctx context.Context, cfg *config.Config, tableName string) error { - cpdb, err := checkpoints.OpenCheckpointsDB(ctx, cfg) - if err != nil { - return errors.Trace(err) - } - //nolint: errcheck - defer cpdb.Close() - - // try to remove the metadata first. - taskCp, err := cpdb.TaskCheckpoint(ctx) - if err != nil { - return errors.Trace(err) - } - // a empty id means this task is not inited, we needn't further check metas. - if taskCp != nil && taskCp.TaskID != 0 { - // try to clean up table metas if exists - if err = CleanupMetas(ctx, cfg, tableName); err != nil { - return errors.Trace(err) - } - } - - return errors.Trace(cpdb.RemoveCheckpoint(ctx, tableName)) -} - -// CleanupMetas removes the table metas of the given table. -func CleanupMetas(ctx context.Context, cfg *config.Config, tableName string) error { - if tableName == "all" { - tableName = "" - } - // try to clean up table metas if exists - db, err := importer.DBFromConfig(ctx, cfg.TiDB) - if err != nil { - return errors.Trace(err) - } - - tableMetaExist, err := common.TableExists(ctx, db, cfg.App.MetaSchemaName, importer.TableMetaTableName) - if err != nil { - return errors.Trace(err) - } - if tableMetaExist { - metaTableName := common.UniqueTable(cfg.App.MetaSchemaName, importer.TableMetaTableName) - if err = importer.RemoveTableMetaByTableName(ctx, db, metaTableName, tableName); err != nil { - return errors.Trace(err) - } - } - - exist, err := common.TableExists(ctx, db, cfg.App.MetaSchemaName, importer.TaskMetaTableName) - if err != nil || !exist { - return errors.Trace(err) - } - return errors.Trace(importer.MaybeCleanupAllMetas(ctx, log.L(), db, cfg.App.MetaSchemaName, tableMetaExist)) -} - -// SwitchMode switches the mode of the TiKV cluster. -func SwitchMode(ctx context.Context, cli pdhttp.Client, tls *tls.Config, mode string, ranges ...*import_sstpb.Range) error { - var m import_sstpb.SwitchMode - switch mode { - case config.ImportMode: - m = import_sstpb.SwitchMode_Import - case config.NormalMode: - m = import_sstpb.SwitchMode_Normal - default: - return errors.Errorf("invalid mode %s, must use %s or %s", mode, config.ImportMode, config.NormalMode) - } - - return tikv.ForAllStores( - ctx, - cli, - metapb.StoreState_Offline, - func(c context.Context, store *pdhttp.MetaStore) error { - return tikv.SwitchMode(c, tls, store.Address, m, ranges...) 
- }, - ) -} - -func updateCertExpiry(rootKeyPath, rootCaPath, keyPath, certPath string, expiry time.Duration) error { - rootKey, err := parsePrivateKey(rootKeyPath) - if err != nil { - return err - } - rootCaPem, err := os.ReadFile(rootCaPath) - if err != nil { - return err - } - rootCaDer, _ := pem.Decode(rootCaPem) - rootCa, err := x509.ParseCertificate(rootCaDer.Bytes) - if err != nil { - return err - } - key, err := parsePrivateKey(keyPath) - if err != nil { - return err - } - certPem, err := os.ReadFile(certPath) - if err != nil { - panic(err) - } - certDer, _ := pem.Decode(certPem) - cert, err := x509.ParseCertificate(certDer.Bytes) - if err != nil { - return err - } - cert.NotBefore = time.Now() - cert.NotAfter = time.Now().Add(expiry) - derBytes, err := x509.CreateCertificate(rand.Reader, cert, rootCa, &key.PublicKey, rootKey) - if err != nil { - return err - } - return os.WriteFile(certPath, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}), 0o600) -} - -func parsePrivateKey(keyPath string) (*ecdsa.PrivateKey, error) { - keyPemBlock, err := os.ReadFile(keyPath) - if err != nil { - return nil, err - } - var keyDERBlock *pem.Block - for { - keyDERBlock, keyPemBlock = pem.Decode(keyPemBlock) - if keyDERBlock == nil { - return nil, errors.New("failed to find PEM block with type ending in \"PRIVATE KEY\"") - } - if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") { - break - } - } - return x509.ParseECPrivateKey(keyDERBlock.Bytes) -} diff --git a/pkg/autoid_service/autoid.go b/pkg/autoid_service/autoid.go index 528712b638642..dbbe7b8ee3353 100644 --- a/pkg/autoid_service/autoid.go +++ b/pkg/autoid_service/autoid.go @@ -462,11 +462,11 @@ func (s *Service) allocAutoID(ctx context.Context, req *autoid.AutoIDRequest) (* return nil, errors.New("not leader") } - if val, _err_ := failpoint.Eval(_curpkg_("mockErr")); _err_ == nil { + failpoint.Inject("mockErr", func(val failpoint.Value) { if val.(bool) { - return nil, errors.New("mock reload failed") + failpoint.Return(nil, errors.New("mock reload failed")) } - } + }) val := s.getAlloc(req.DbID, req.TblID, req.IsUnsigned) val.Lock() diff --git a/pkg/autoid_service/autoid.go__failpoint_stash__ b/pkg/autoid_service/autoid.go__failpoint_stash__ deleted file mode 100644 index dbbe7b8ee3353..0000000000000 --- a/pkg/autoid_service/autoid.go__failpoint_stash__ +++ /dev/null @@ -1,612 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package autoid - -import ( - "context" - "crypto/tls" - "math" - "sync" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/autoid" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/keyspace" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - autoid1 "github.com/pingcap/tidb/pkg/meta/autoid" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/owner" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/util/etcd" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/mathutil" - clientv3 "go.etcd.io/etcd/client/v3" - "go.uber.org/zap" - "google.golang.org/grpc" - "google.golang.org/grpc/keepalive" -) - -var ( - errAutoincReadFailed = errors.New("auto increment action failed") -) - -const ( - autoIDLeaderPath = "tidb/autoid/leader" -) - -type autoIDKey struct { - dbID int64 - tblID int64 -} - -type autoIDValue struct { - sync.Mutex - base int64 - end int64 - isUnsigned bool - token chan struct{} -} - -func (alloc *autoIDValue) alloc4Unsigned(ctx context.Context, store kv.Storage, dbID, tblID int64, isUnsigned bool, - n uint64, increment, offset int64) (min int64, max int64, err error) { - // Check offset rebase if necessary. - if uint64(offset-1) > uint64(alloc.base) { - if err := alloc.rebase4Unsigned(ctx, store, dbID, tblID, uint64(offset-1)); err != nil { - return 0, 0, err - } - } - // calcNeededBatchSize calculates the total batch size needed. - n1 := calcNeededBatchSize(alloc.base, int64(n), increment, offset, isUnsigned) - - // The local rest is not enough for alloc. - if uint64(alloc.base)+uint64(n1) > uint64(alloc.end) || alloc.base == 0 { - var newBase, newEnd int64 - nextStep := int64(batch) - fromBase := alloc.base - - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnMeta) - err := kv.RunInNewTxn(ctx, store, true, func(_ context.Context, txn kv.Transaction) error { - idAcc := meta.NewMeta(txn).GetAutoIDAccessors(dbID, tblID).IncrementID(model.TableInfoVersion5) - var err1 error - newBase, err1 = idAcc.Get() - if err1 != nil { - return err1 - } - // calcNeededBatchSize calculates the total batch size needed on new base. - if alloc.base == 0 || newBase != alloc.end { - alloc.base = newBase - alloc.end = newBase - n1 = calcNeededBatchSize(newBase, int64(n), increment, offset, isUnsigned) - } - - // Although the step is customized by user, we still need to make sure nextStep is big enough for insert batch. - if nextStep < n1 { - nextStep = n1 - } - tmpStep := int64(mathutil.Min(math.MaxUint64-uint64(newBase), uint64(nextStep))) - // The global rest is not enough for alloc. - if tmpStep < n1 { - return errAutoincReadFailed - } - newEnd, err1 = idAcc.Inc(tmpStep) - return err1 - }) - if err != nil { - return 0, 0, err - } - if uint64(newBase) == math.MaxUint64 { - return 0, 0, errAutoincReadFailed - } - logutil.BgLogger().Info("alloc4Unsigned from", - zap.String("category", "autoid service"), - zap.Int64("dbID", dbID), - zap.Int64("tblID", tblID), - zap.Int64("from base", fromBase), - zap.Int64("from end", alloc.end), - zap.Int64("to base", newBase), - zap.Int64("to end", newEnd)) - alloc.end = newEnd - } - min = alloc.base - // Use uint64 n directly. 
- alloc.base = int64(uint64(alloc.base) + uint64(n1)) - return min, alloc.base, nil -} - -func (alloc *autoIDValue) alloc4Signed(ctx context.Context, - store kv.Storage, - dbID, tblID int64, - isUnsigned bool, - n uint64, increment, offset int64) (min int64, max int64, err error) { - // Check offset rebase if necessary. - if offset-1 > alloc.base { - if err := alloc.rebase4Signed(ctx, store, dbID, tblID, offset-1); err != nil { - return 0, 0, err - } - } - // calcNeededBatchSize calculates the total batch size needed. - n1 := calcNeededBatchSize(alloc.base, int64(n), increment, offset, isUnsigned) - - // Condition alloc.base+N1 > alloc.end will overflow when alloc.base + N1 > MaxInt64. So need this. - if math.MaxInt64-alloc.base <= n1 { - return 0, 0, errAutoincReadFailed - } - - // The local rest is not enough for allocN. - // If alloc.base is 0, the alloc may not be initialized, force fetch from remote. - if alloc.base+n1 > alloc.end || alloc.base == 0 { - var newBase, newEnd int64 - nextStep := int64(batch) - fromBase := alloc.base - - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnMeta) - err := kv.RunInNewTxn(ctx, store, true, func(_ context.Context, txn kv.Transaction) error { - idAcc := meta.NewMeta(txn).GetAutoIDAccessors(dbID, tblID).IncrementID(model.TableInfoVersion5) - var err1 error - newBase, err1 = idAcc.Get() - if err1 != nil { - return err1 - } - // calcNeededBatchSize calculates the total batch size needed on global base. - // alloc.base == 0 means uninitialized - // newBase != alloc.end means something abnormal, maybe transaction conflict and retry? - if alloc.base == 0 || newBase != alloc.end { - alloc.base = newBase - alloc.end = newBase - n1 = calcNeededBatchSize(newBase, int64(n), increment, offset, isUnsigned) - } - // Although the step is customized by user, we still need to make sure nextStep is big enough for insert batch. - if nextStep < n1 { - nextStep = n1 - } - tmpStep := mathutil.Min(math.MaxInt64-newBase, nextStep) - // The global rest is not enough for alloc. - if tmpStep < n1 { - return errAutoincReadFailed - } - newEnd, err1 = idAcc.Inc(tmpStep) - return err1 - }) - if err != nil { - return 0, 0, err - } - if newBase == math.MaxInt64 { - return 0, 0, errAutoincReadFailed - } - logutil.BgLogger().Info("alloc4Signed from", - zap.String("category", "autoid service"), - zap.Int64("dbID", dbID), - zap.Int64("tblID", tblID), - zap.Int64("from base", fromBase), - zap.Int64("from end", alloc.end), - zap.Int64("to base", newBase), - zap.Int64("to end", newEnd)) - alloc.end = newEnd - } - min = alloc.base - alloc.base += n1 - return min, alloc.base, nil -} - -func (alloc *autoIDValue) rebase4Unsigned(ctx context.Context, - store kv.Storage, - dbID, tblID int64, - requiredBase uint64) error { - // Satisfied by alloc.base, nothing to do. - if requiredBase <= uint64(alloc.base) { - return nil - } - // Satisfied by alloc.end, need to update alloc.base. 
- if requiredBase > uint64(alloc.base) && requiredBase <= uint64(alloc.end) { - alloc.base = int64(requiredBase) - return nil - } - - var newBase, newEnd uint64 - var oldValue int64 - startTime := time.Now() - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnMeta) - err := kv.RunInNewTxn(ctx, store, true, func(_ context.Context, txn kv.Transaction) error { - idAcc := meta.NewMeta(txn).GetAutoIDAccessors(dbID, tblID).IncrementID(model.TableInfoVersion5) - currentEnd, err1 := idAcc.Get() - if err1 != nil { - return err1 - } - oldValue = currentEnd - uCurrentEnd := uint64(currentEnd) - newBase = mathutil.Max(uCurrentEnd, requiredBase) - newEnd = mathutil.Min(math.MaxUint64-uint64(batch), newBase) + uint64(batch) - _, err1 = idAcc.Inc(int64(newEnd - uCurrentEnd)) - return err1 - }) - metrics.AutoIDHistogram.WithLabelValues(metrics.TableAutoIDRebase, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) - if err != nil { - return err - } - - logutil.BgLogger().Info("rebase4Unsigned from", - zap.String("category", "autoid service"), - zap.Int64("dbID", dbID), - zap.Int64("tblID", tblID), - zap.Int64("from", oldValue), - zap.Uint64("to", newEnd)) - alloc.base, alloc.end = int64(newBase), int64(newEnd) - return nil -} - -func (alloc *autoIDValue) rebase4Signed(ctx context.Context, store kv.Storage, dbID, tblID int64, requiredBase int64) error { - // Satisfied by alloc.base, nothing to do. - if requiredBase <= alloc.base { - return nil - } - // Satisfied by alloc.end, need to update alloc.base. - if requiredBase > alloc.base && requiredBase <= alloc.end { - alloc.base = requiredBase - return nil - } - - var oldValue, newBase, newEnd int64 - startTime := time.Now() - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnMeta) - err := kv.RunInNewTxn(ctx, store, true, func(_ context.Context, txn kv.Transaction) error { - idAcc := meta.NewMeta(txn).GetAutoIDAccessors(dbID, tblID).IncrementID(model.TableInfoVersion5) - currentEnd, err1 := idAcc.Get() - if err1 != nil { - return err1 - } - oldValue = currentEnd - newBase = mathutil.Max(currentEnd, requiredBase) - newEnd = mathutil.Min(math.MaxInt64-batch, newBase) + batch - _, err1 = idAcc.Inc(newEnd - currentEnd) - return err1 - }) - metrics.AutoIDHistogram.WithLabelValues(metrics.TableAutoIDRebase, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) - if err != nil { - return err - } - - logutil.BgLogger().Info("rebase4Signed from", - zap.Int64("dbID", dbID), - zap.Int64("tblID", tblID), - zap.Int64("from", oldValue), - zap.Int64("to", newEnd), - zap.String("category", "autoid service")) - alloc.base, alloc.end = newBase, newEnd - return nil -} - -// Service implement the grpc AutoIDAlloc service, defined in kvproto/pkg/autoid. -type Service struct { - autoIDLock sync.Mutex - autoIDMap map[autoIDKey]*autoIDValue - - leaderShip owner.Manager - store kv.Storage -} - -// New return a Service instance. 
-func New(selfAddr string, etcdAddr []string, store kv.Storage, tlsConfig *tls.Config) *Service { - cfg := config.GetGlobalConfig() - etcdLogCfg := zap.NewProductionConfig() - - cli, err := clientv3.New(clientv3.Config{ - LogConfig: &etcdLogCfg, - Endpoints: etcdAddr, - AutoSyncInterval: 30 * time.Second, - DialTimeout: 5 * time.Second, - DialOptions: []grpc.DialOption{ - grpc.WithBackoffMaxDelay(time.Second * 3), - grpc.WithKeepaliveParams(keepalive.ClientParameters{ - Time: time.Duration(cfg.TiKVClient.GrpcKeepAliveTime) * time.Second, - Timeout: time.Duration(cfg.TiKVClient.GrpcKeepAliveTimeout) * time.Second, - }), - }, - TLS: tlsConfig, - }) - if store.GetCodec().GetKeyspace() != nil { - etcd.SetEtcdCliByNamespace(cli, keyspace.MakeKeyspaceEtcdNamespaceSlash(store.GetCodec())) - } - if err != nil { - panic(err) - } - return newWithCli(selfAddr, cli, store) -} - -func newWithCli(selfAddr string, cli *clientv3.Client, store kv.Storage) *Service { - l := owner.NewOwnerManager(context.Background(), cli, "autoid", selfAddr, autoIDLeaderPath) - service := &Service{ - autoIDMap: make(map[autoIDKey]*autoIDValue), - leaderShip: l, - store: store, - } - l.SetListener(&ownerListener{ - Service: service, - selfAddr: selfAddr, - }) - // 10 means that autoid service's etcd lease is 10s. - err := l.CampaignOwner(10) - if err != nil { - panic(err) - } - - return service -} - -type mockClient struct { - Service -} - -func (m *mockClient) AllocAutoID(ctx context.Context, in *autoid.AutoIDRequest, _ ...grpc.CallOption) (*autoid.AutoIDResponse, error) { - return m.Service.AllocAutoID(ctx, in) -} - -func (m *mockClient) Rebase(ctx context.Context, in *autoid.RebaseRequest, _ ...grpc.CallOption) (*autoid.RebaseResponse, error) { - return m.Service.Rebase(ctx, in) -} - -var global = make(map[string]*mockClient) - -// MockForTest is used for testing, the UT test and unistore use this. -func MockForTest(store kv.Storage) autoid.AutoIDAllocClient { - uuid := store.UUID() - ret, ok := global[uuid] - if !ok { - ret = &mockClient{ - Service{ - autoIDMap: make(map[autoIDKey]*autoIDValue), - leaderShip: nil, - store: store, - }, - } - global[uuid] = ret - } - return ret -} - -// Close closes the Service and clean up resource. -func (s *Service) Close() { - if s.leaderShip != nil && s.leaderShip.IsOwner() { - s.leaderShip.Cancel() - } -} - -// seekToFirstAutoIDSigned seeks to the next valid signed position. -func seekToFirstAutoIDSigned(base, increment, offset int64) int64 { - nr := (base + increment - offset) / increment - nr = nr*increment + offset - return nr -} - -// seekToFirstAutoIDUnSigned seeks to the next valid unsigned position. -func seekToFirstAutoIDUnSigned(base, increment, offset uint64) uint64 { - nr := (base + increment - offset) / increment - nr = nr*increment + offset - return nr -} - -func calcNeededBatchSize(base, n, increment, offset int64, isUnsigned bool) int64 { - if increment == 1 { - return n - } - if isUnsigned { - // SeekToFirstAutoIDUnSigned seeks to the next unsigned valid position. - nr := seekToFirstAutoIDUnSigned(uint64(base), uint64(increment), uint64(offset)) - // calculate the total batch size needed. - nr += (uint64(n) - 1) * uint64(increment) - return int64(nr - uint64(base)) - } - nr := seekToFirstAutoIDSigned(base, increment, offset) - // calculate the total batch size needed. - nr += (n - 1) * increment - return nr - base -} - -const batch = 4000 - -// AllocAutoID implements gRPC AutoIDAlloc interface. 
-func (s *Service) AllocAutoID(ctx context.Context, req *autoid.AutoIDRequest) (*autoid.AutoIDResponse, error) { - serviceKeyspaceID := uint32(s.store.GetCodec().GetKeyspaceID()) - if req.KeyspaceID != serviceKeyspaceID { - logutil.BgLogger().Info("Current service is not request keyspace leader.", zap.Uint32("req-keyspace-id", req.KeyspaceID), zap.Uint32("service-keyspace-id", serviceKeyspaceID)) - return nil, errors.Trace(errors.New("not leader")) - } - var res *autoid.AutoIDResponse - for { - var err error - res, err = s.allocAutoID(ctx, req) - if err != nil { - return nil, errors.Trace(err) - } - if res != nil { - break - } - } - return res, nil -} - -func (s *Service) getAlloc(dbID, tblID int64, isUnsigned bool) *autoIDValue { - key := autoIDKey{dbID: dbID, tblID: tblID} - s.autoIDLock.Lock() - defer s.autoIDLock.Unlock() - - val, ok := s.autoIDMap[key] - if !ok { - val = &autoIDValue{ - isUnsigned: isUnsigned, - token: make(chan struct{}, 1), - } - s.autoIDMap[key] = val - } - - return val -} - -func (s *Service) allocAutoID(ctx context.Context, req *autoid.AutoIDRequest) (*autoid.AutoIDResponse, error) { - if s.leaderShip != nil && !s.leaderShip.IsOwner() { - logutil.BgLogger().Info("Alloc AutoID fail, not leader", zap.String("category", "autoid service")) - return nil, errors.New("not leader") - } - - failpoint.Inject("mockErr", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(nil, errors.New("mock reload failed")) - } - }) - - val := s.getAlloc(req.DbID, req.TblID, req.IsUnsigned) - val.Lock() - defer val.Unlock() - - if req.N == 0 { - if val.base != 0 { - return &autoid.AutoIDResponse{ - Min: val.base, - Max: val.base, - }, nil - } - // This item is not initialized, get the data from remote. - var currentEnd int64 - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnMeta) - err := kv.RunInNewTxn(ctx, s.store, true, func(_ context.Context, txn kv.Transaction) error { - idAcc := meta.NewMeta(txn).GetAutoIDAccessors(req.DbID, req.TblID).IncrementID(model.TableInfoVersion5) - var err1 error - currentEnd, err1 = idAcc.Get() - if err1 != nil { - return err1 - } - val.base = currentEnd - val.end = currentEnd - return nil - }) - if err != nil { - return &autoid.AutoIDResponse{Errmsg: []byte(err.Error())}, nil - } - return &autoid.AutoIDResponse{ - Min: currentEnd, - Max: currentEnd, - }, nil - } - - var min, max int64 - var err error - if req.IsUnsigned { - min, max, err = val.alloc4Unsigned(ctx, s.store, req.DbID, req.TblID, req.IsUnsigned, req.N, req.Increment, req.Offset) - } else { - min, max, err = val.alloc4Signed(ctx, s.store, req.DbID, req.TblID, req.IsUnsigned, req.N, req.Increment, req.Offset) - } - - if err != nil { - return &autoid.AutoIDResponse{Errmsg: []byte(err.Error())}, nil - } - return &autoid.AutoIDResponse{ - Min: min, - Max: max, - }, nil -} - -func (alloc *autoIDValue) forceRebase(ctx context.Context, store kv.Storage, dbID, tblID, requiredBase int64, isUnsigned bool) error { - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnMeta) - var oldValue int64 - err := kv.RunInNewTxn(ctx, store, true, func(_ context.Context, txn kv.Transaction) error { - idAcc := meta.NewMeta(txn).GetAutoIDAccessors(dbID, tblID).IncrementID(model.TableInfoVersion5) - currentEnd, err1 := idAcc.Get() - if err1 != nil { - return err1 - } - oldValue = currentEnd - var step int64 - if !isUnsigned { - step = requiredBase - currentEnd - } else { - uRequiredBase, uCurrentEnd := uint64(requiredBase), uint64(currentEnd) - step = int64(uRequiredBase - uCurrentEnd) - } - _, err1 = 
idAcc.Inc(step) - return err1 - }) - if err != nil { - return err - } - logutil.BgLogger().Info("forceRebase from", - zap.Int64("dbID", dbID), - zap.Int64("tblID", tblID), - zap.Int64("from", oldValue), - zap.Int64("to", requiredBase), - zap.Bool("isUnsigned", isUnsigned), - zap.String("category", "autoid service")) - alloc.base, alloc.end = requiredBase, requiredBase - return nil -} - -// Rebase implements gRPC AutoIDAlloc interface. -// req.N = 0 is handled specially, it is used to return the current auto ID value. -func (s *Service) Rebase(ctx context.Context, req *autoid.RebaseRequest) (*autoid.RebaseResponse, error) { - if s.leaderShip != nil && !s.leaderShip.IsOwner() { - logutil.BgLogger().Info("Rebase() fail, not leader", zap.String("category", "autoid service")) - return nil, errors.New("not leader") - } - - val := s.getAlloc(req.DbID, req.TblID, req.IsUnsigned) - val.Lock() - defer val.Unlock() - - if req.Force { - err := val.forceRebase(ctx, s.store, req.DbID, req.TblID, req.Base, req.IsUnsigned) - if err != nil { - return &autoid.RebaseResponse{Errmsg: []byte(err.Error())}, nil - } - } - - var err error - if req.IsUnsigned { - err = val.rebase4Unsigned(ctx, s.store, req.DbID, req.TblID, uint64(req.Base)) - } else { - err = val.rebase4Signed(ctx, s.store, req.DbID, req.TblID, req.Base) - } - if err != nil { - return &autoid.RebaseResponse{Errmsg: []byte(err.Error())}, nil - } - return &autoid.RebaseResponse{}, nil -} - -type ownerListener struct { - *Service - selfAddr string -} - -var _ owner.Listener = (*ownerListener)(nil) - -func (l *ownerListener) OnBecomeOwner() { - // Reset the map to avoid a case that a node lose leadership and regain it, then - // improperly use the stale map to serve the autoid requests. - // See https://github.com/pingcap/tidb/issues/52600 - l.autoIDLock.Lock() - clear(l.autoIDMap) - l.autoIDLock.Unlock() - - logutil.BgLogger().Info("leader change of autoid service, this node become owner", - zap.String("addr", l.selfAddr), - zap.String("category", "autoid service")) -} - -func (*ownerListener) OnRetireOwner() { -} - -func init() { - autoid1.MockForTest = MockForTest -} diff --git a/pkg/autoid_service/binding__failpoint_binding__.go b/pkg/autoid_service/binding__failpoint_binding__.go deleted file mode 100644 index 2c1025c7f434f..0000000000000 --- a/pkg/autoid_service/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package autoid - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/bindinfo/binding__failpoint_binding__.go b/pkg/bindinfo/binding__failpoint_binding__.go deleted file mode 100644 index 331f3106c1860..0000000000000 --- a/pkg/bindinfo/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package bindinfo - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/bindinfo/global_handle.go b/pkg/bindinfo/global_handle.go index 8e23689fd5c58..db027b6c5d2f7 100644 --- a/pkg/bindinfo/global_handle.go +++ b/pkg/bindinfo/global_handle.go @@ 
-716,9 +716,9 @@ func (h *globalBindingHandle) LoadBindingsFromStorage(sctx sessionctx.Context, s } func (h *globalBindingHandle) loadBindingsFromStorageInternal(sqlDigest string) (any, error) { - if _, _err_ := failpoint.Eval(_curpkg_("load_bindings_from_storage_internal_timeout")); _err_ == nil { + failpoint.Inject("load_bindings_from_storage_internal_timeout", func() { time.Sleep(time.Second) - } + }) var bindings Bindings selectStmt := fmt.Sprintf("SELECT original_sql, bind_sql, default_db, status, create_time, update_time, charset, collation, source, sql_digest, plan_digest FROM mysql.bind_info where sql_digest = '%s'", sqlDigest) err := h.callWithSCtx(false, func(sctx sessionctx.Context) error { diff --git a/pkg/bindinfo/global_handle.go__failpoint_stash__ b/pkg/bindinfo/global_handle.go__failpoint_stash__ deleted file mode 100644 index db027b6c5d2f7..0000000000000 --- a/pkg/bindinfo/global_handle.go__failpoint_stash__ +++ /dev/null @@ -1,745 +0,0 @@ -// Copyright 2019 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package bindinfo - -import ( - "context" - "fmt" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/bindinfo/internal/logutil" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/format" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/types" - driver "github.com/pingcap/tidb/pkg/types/parser_driver" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/hint" - utilparser "github.com/pingcap/tidb/pkg/util/parser" - "go.uber.org/zap" - "golang.org/x/sync/singleflight" -) - -// GlobalBindingHandle is used to handle all global sql bind operations. -type GlobalBindingHandle interface { - // Methods for create, get, drop global sql bindings. - - // MatchGlobalBinding returns the matched binding for this statement. - MatchGlobalBinding(sctx sessionctx.Context, fuzzyDigest string, tableNames []*ast.TableName) (matchedBinding Binding, isMatched bool) - - // GetAllGlobalBindings returns all bind records in cache. - GetAllGlobalBindings() (bindings Bindings) - - // CreateGlobalBinding creates a Bindings to the storage and the cache. - // It replaces all the exists bindings for the same normalized SQL. - CreateGlobalBinding(sctx sessionctx.Context, binding Binding) (err error) - - // DropGlobalBinding drop Bindings to the storage and Bindings int the cache. - DropGlobalBinding(sqlDigest string) (deletedRows uint64, err error) - - // SetGlobalBindingStatus set a Bindings's status to the storage and bind cache. 
- SetGlobalBindingStatus(newStatus, sqlDigest string) (ok bool, err error) - - // AddInvalidGlobalBinding adds Bindings which needs to be deleted into invalidBindingCache. - AddInvalidGlobalBinding(invalidBinding Binding) - - // DropInvalidGlobalBinding executes the drop Bindings tasks. - DropInvalidGlobalBinding() - - // Methods for load and clear global sql bindings. - - // Reset is to reset the BindHandle and clean old info. - Reset() - - // LoadFromStorageToCache loads global bindings from storage to the memory cache. - LoadFromStorageToCache(fullLoad bool) (err error) - - // GCGlobalBinding physically removes the deleted bind records in mysql.bind_info. - GCGlobalBinding() (err error) - - // Methods for memory control. - - // Size returns the size of bind info cache. - Size() int - - // SetBindingCacheCapacity reset the capacity for the bindingCache. - SetBindingCacheCapacity(capacity int64) - - // GetMemUsage returns the memory usage for the bind cache. - GetMemUsage() (memUsage int64) - - // GetMemCapacity returns the memory capacity for the bind cache. - GetMemCapacity() (memCapacity int64) - - // Clear resets the bind handle. It is only used for test. - Clear() - - // FlushGlobalBindings flushes the Bindings in temp maps to storage and loads them into cache. - FlushGlobalBindings() error - - // Methods for Auto Capture. - - // CaptureBaselines is used to automatically capture plan baselines. - CaptureBaselines() - - variable.Statistics -} - -// globalBindingHandle is used to handle all global sql bind operations. -type globalBindingHandle struct { - sPool util.SessionPool - - fuzzyBindingCache atomic.Value - - // lastTaskTime records the last update time for the global sql bind cache. - // This value is used to avoid reload duplicated bindings from storage. - lastUpdateTime atomic.Value - - // invalidBindings indicates the invalid bindings found during querying. - // A binding will be deleted from this map, after 2 bind-lease, after it is dropped from the kv. - invalidBindings *invalidBindingCache - - // syncBindingSingleflight is used to synchronize the execution of `LoadFromStorageToCache` method. - syncBindingSingleflight singleflight.Group -} - -// Lease influences the duration of loading bind info and handling invalid bind. -var Lease = 3 * time.Second - -const ( - // OwnerKey is the bindinfo owner path that is saved to etcd. - OwnerKey = "/tidb/bindinfo/owner" - // Prompt is the prompt for bindinfo owner manager. - Prompt = "bindinfo" - // BuiltinPseudoSQL4BindLock is used to simulate LOCK TABLE for mysql.bind_info. - BuiltinPseudoSQL4BindLock = "builtin_pseudo_sql_for_bind_lock" - - // LockBindInfoSQL simulates LOCK TABLE by updating a same row in each pessimistic transaction. - LockBindInfoSQL = `UPDATE mysql.bind_info SET source= 'builtin' WHERE original_sql= 'builtin_pseudo_sql_for_bind_lock'` - - // StmtRemoveDuplicatedPseudoBinding is used to remove duplicated pseudo binding. - // After using BR to sync bind_info between two clusters, the pseudo binding may be duplicated, and - // BR use this statement to remove duplicated rows, and this SQL should only be executed by BR. - StmtRemoveDuplicatedPseudoBinding = `DELETE FROM mysql.bind_info - WHERE original_sql='builtin_pseudo_sql_for_bind_lock' AND - _tidb_rowid NOT IN ( -- keep one arbitrary pseudo binding - SELECT _tidb_rowid FROM mysql.bind_info WHERE original_sql='builtin_pseudo_sql_for_bind_lock' limit 1)` -) - -// NewGlobalBindingHandle creates a new GlobalBindingHandle. 
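
LockBindInfoSQL above emulates LOCK TABLE by updating one well-known row inside a pessimistic transaction, so concurrent binding DDL on different tidb-server instances serializes on that row lock. A rough sketch of the same pattern with database/sql; the table name, column, and transaction wiring are hypothetical:

package example

import (
	"context"
	"database/sql"
)

// acquireBindLock updates a single fixed row; whichever pessimistic
// transaction performs the update first holds the row lock until it commits,
// which serializes the critical section across instances.
func acquireBindLock(ctx context.Context, tx *sql.Tx) error {
	_, err := tx.ExecContext(ctx,
		"UPDATE meta_lock SET source = 'builtin' WHERE name = 'bind_lock'")
	return err
}
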
-func NewGlobalBindingHandle(sPool util.SessionPool) GlobalBindingHandle { - handle := &globalBindingHandle{sPool: sPool} - handle.Reset() - return handle -} - -func (h *globalBindingHandle) getCache() FuzzyBindingCache { - return h.fuzzyBindingCache.Load().(FuzzyBindingCache) -} - -func (h *globalBindingHandle) setCache(c FuzzyBindingCache) { - // TODO: update the global cache in-place instead of replacing it and remove this function. - h.fuzzyBindingCache.Store(c) -} - -// Reset is to reset the BindHandle and clean old info. -func (h *globalBindingHandle) Reset() { - h.lastUpdateTime.Store(types.ZeroTimestamp) - h.invalidBindings = newInvalidBindingCache() - h.setCache(newFuzzyBindingCache(h.LoadBindingsFromStorage)) - variable.RegisterStatistics(h) -} - -func (h *globalBindingHandle) getLastUpdateTime() types.Time { - return h.lastUpdateTime.Load().(types.Time) -} - -func (h *globalBindingHandle) setLastUpdateTime(t types.Time) { - h.lastUpdateTime.Store(t) -} - -// LoadFromStorageToCache loads bindings from the storage into the cache. -func (h *globalBindingHandle) LoadFromStorageToCache(fullLoad bool) (err error) { - var lastUpdateTime types.Time - var timeCondition string - var newCache FuzzyBindingCache - if fullLoad { - lastUpdateTime = types.ZeroTimestamp - timeCondition = "" - newCache = newFuzzyBindingCache(h.LoadBindingsFromStorage) - } else { - lastUpdateTime = h.getLastUpdateTime() - timeCondition = fmt.Sprintf("WHERE update_time>'%s'", lastUpdateTime.String()) - newCache, err = h.getCache().Copy() - if err != nil { - return err - } - } - - selectStmt := fmt.Sprintf(`SELECT original_sql, bind_sql, default_db, status, create_time, - update_time, charset, collation, source, sql_digest, plan_digest FROM mysql.bind_info - %s ORDER BY update_time, create_time`, timeCondition) - - return h.callWithSCtx(false, func(sctx sessionctx.Context) error { - rows, _, err := execRows(sctx, selectStmt) - if err != nil { - return err - } - - defer func() { - h.setLastUpdateTime(lastUpdateTime) - h.setCache(newCache) - - metrics.BindingCacheMemUsage.Set(float64(h.GetMemUsage())) - metrics.BindingCacheMemLimit.Set(float64(h.GetMemCapacity())) - metrics.BindingCacheNumBindings.Set(float64(h.Size())) - }() - - for _, row := range rows { - // Skip the builtin record which is designed for binding synchronization. - if row.GetString(0) == BuiltinPseudoSQL4BindLock { - continue - } - sqlDigest, binding, err := newBinding(sctx, row) - - // Update lastUpdateTime to the newest one. - // Even if this one is an invalid bind. - if binding.UpdateTime.Compare(lastUpdateTime) > 0 { - lastUpdateTime = binding.UpdateTime - } - - if err != nil { - logutil.BindLogger().Warn("failed to generate bind record from data row", zap.Error(err)) - continue - } - - oldBinding := newCache.GetBinding(sqlDigest) - newBinding := removeDeletedBindings(merge(oldBinding, []Binding{binding})) - if len(newBinding) > 0 { - err = newCache.SetBinding(sqlDigest, newBinding) - if err != nil { - // When the memory capacity of bing_cache is not enough, - // there will be some memory-related errors in multiple places. - // Only needs to be handled once. - logutil.BindLogger().Warn("BindHandle.Update", zap.Error(err)) - } - } else { - newCache.RemoveBinding(sqlDigest) - } - } - return nil - }) -} - -// CreateGlobalBinding creates a Bindings to the storage and the cache. -// It replaces all the exists bindings for the same normalized SQL. 
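
The handle above keeps its cache in an atomic.Value and swaps in a whole new snapshot on reload (see the TODO in setCache). A stripped-down sketch of that copy-on-write pattern; the map type and names are illustrative, and writers are assumed to be serialized by the caller:

package example

import "sync/atomic"

// snapshotCache stores an immutable map; readers Load it lock-free and
// writers build a fresh copy before Store-ing it.
type snapshotCache struct {
	v atomic.Value // holds map[string]string
}

func newSnapshotCache() *snapshotCache {
	c := &snapshotCache{}
	c.v.Store(map[string]string{})
	return c
}

func (c *snapshotCache) get(key string) (string, bool) {
	m := c.v.Load().(map[string]string)
	val, ok := m[key]
	return val, ok
}

// set must not be called concurrently with other writers; the copy keeps
// existing readers safe because they still see the old snapshot.
func (c *snapshotCache) set(key, val string) {
	old := c.v.Load().(map[string]string)
	next := make(map[string]string, len(old)+1)
	for k, v := range old {
		next[k] = v
	}
	next[key] = val
	c.v.Store(next)
}
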
-func (h *globalBindingHandle) CreateGlobalBinding(sctx sessionctx.Context, binding Binding) (err error) { - if err := prepareHints(sctx, &binding); err != nil { - return err - } - defer func() { - if err == nil { - err = h.LoadFromStorageToCache(false) - } - }() - - return h.callWithSCtx(true, func(sctx sessionctx.Context) error { - // Lock mysql.bind_info to synchronize with CreateBinding / AddBinding / DropBinding on other tidb instances. - if err = lockBindInfoTable(sctx); err != nil { - return err - } - - now := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 3) - - updateTs := now.String() - _, err = exec(sctx, `UPDATE mysql.bind_info SET status = %?, update_time = %? WHERE original_sql = %? AND update_time < %?`, - deleted, updateTs, binding.OriginalSQL, updateTs) - if err != nil { - return err - } - - binding.CreateTime = now - binding.UpdateTime = now - - // Insert the Bindings to the storage. - _, err = exec(sctx, `INSERT INTO mysql.bind_info VALUES (%?,%?, %?, %?, %?, %?, %?, %?, %?, %?, %?)`, - binding.OriginalSQL, - binding.BindSQL, - strings.ToLower(binding.Db), - binding.Status, - binding.CreateTime.String(), - binding.UpdateTime.String(), - binding.Charset, - binding.Collation, - binding.Source, - binding.SQLDigest, - binding.PlanDigest, - ) - if err != nil { - return err - } - return nil - }) -} - -// dropGlobalBinding drops a Bindings to the storage and Bindings int the cache. -func (h *globalBindingHandle) dropGlobalBinding(sqlDigest string) (deletedRows uint64, err error) { - err = h.callWithSCtx(false, func(sctx sessionctx.Context) error { - // Lock mysql.bind_info to synchronize with CreateBinding / AddBinding / DropBinding on other tidb instances. - if err = lockBindInfoTable(sctx); err != nil { - return err - } - - updateTs := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 3).String() - - _, err = exec(sctx, `UPDATE mysql.bind_info SET status = %?, update_time = %? WHERE sql_digest = %? AND update_time < %? AND status != %?`, - deleted, updateTs, sqlDigest, updateTs, deleted) - if err != nil { - return err - } - deletedRows = sctx.GetSessionVars().StmtCtx.AffectedRows() - return nil - }) - return -} - -// DropGlobalBinding drop Bindings to the storage and Bindings int the cache. -func (h *globalBindingHandle) DropGlobalBinding(sqlDigest string) (deletedRows uint64, err error) { - if sqlDigest == "" { - return 0, errors.New("sql digest is empty") - } - defer func() { - if err == nil { - err = h.LoadFromStorageToCache(false) - } - }() - return h.dropGlobalBinding(sqlDigest) -} - -// SetGlobalBindingStatus set a Bindings's status to the storage and bind cache. -func (h *globalBindingHandle) SetGlobalBindingStatus(newStatus, sqlDigest string) (ok bool, err error) { - var ( - updateTs types.Time - oldStatus0, oldStatus1 string - ) - if newStatus == Disabled { - // For compatibility reasons, when we need to 'set binding disabled for ', - // we need to consider both the 'enabled' and 'using' status. - oldStatus0 = Using - oldStatus1 = Enabled - } else if newStatus == Enabled { - // In order to unify the code, two identical old statuses are set. - oldStatus0 = Disabled - oldStatus1 = Disabled - } - - defer func() { - if err == nil { - err = h.LoadFromStorageToCache(false) - } - }() - - err = h.callWithSCtx(true, func(sctx sessionctx.Context) error { - // Lock mysql.bind_info to synchronize with SetBindingStatus on other tidb instances. 
- if err = lockBindInfoTable(sctx); err != nil { - return err - } - - updateTs = types.NewTime(types.FromGoTime(time.Now()), mysql.TypeTimestamp, 3) - updateTsStr := updateTs.String() - - _, err = exec(sctx, `UPDATE mysql.bind_info SET status = %?, update_time = %? WHERE sql_digest = %? AND update_time < %? AND status IN (%?, %?)`, - newStatus, updateTsStr, sqlDigest, updateTsStr, oldStatus0, oldStatus1) - return err - }) - return -} - -// GCGlobalBinding physically removes the deleted bind records in mysql.bind_info. -func (h *globalBindingHandle) GCGlobalBinding() (err error) { - return h.callWithSCtx(true, func(sctx sessionctx.Context) error { - // Lock mysql.bind_info to synchronize with CreateBinding / AddBinding / DropBinding on other tidb instances. - if err = lockBindInfoTable(sctx); err != nil { - return err - } - - // To make sure that all the deleted bind records have been acknowledged to all tidb, - // we only garbage collect those records with update_time before 10 leases. - updateTime := time.Now().Add(-(10 * Lease)) - updateTimeStr := types.NewTime(types.FromGoTime(updateTime), mysql.TypeTimestamp, 3).String() - _, err = exec(sctx, `DELETE FROM mysql.bind_info WHERE status = 'deleted' and update_time < %?`, updateTimeStr) - return err - }) -} - -// lockBindInfoTable simulates `LOCK TABLE mysql.bind_info WRITE` by acquiring a pessimistic lock on a -// special builtin row of mysql.bind_info. Note that this function must be called with h.sctx.Lock() held. -// We can replace this implementation to normal `LOCK TABLE mysql.bind_info WRITE` if that feature is -// generally available later. -// This lock would enforce the CREATE / DROP GLOBAL BINDING statements to be executed sequentially, -// even if they come from different tidb instances. -func lockBindInfoTable(sctx sessionctx.Context) error { - // h.sctx already locked. - _, err := exec(sctx, LockBindInfoSQL) - return err -} - -// invalidBindingCache is used to store invalid bindings temporarily. -type invalidBindingCache struct { - mu sync.RWMutex - m map[string]Binding // key: sqlDigest -} - -func newInvalidBindingCache() *invalidBindingCache { - return &invalidBindingCache{ - m: make(map[string]Binding), - } -} - -func (c *invalidBindingCache) add(binding Binding) { - c.mu.Lock() - defer c.mu.Unlock() - c.m[binding.SQLDigest] = binding -} - -func (c *invalidBindingCache) getAll() Bindings { - c.mu.Lock() - defer c.mu.Unlock() - bindings := make(Bindings, 0, len(c.m)) - for _, binding := range c.m { - bindings = append(bindings, binding) - } - return bindings -} - -func (c *invalidBindingCache) reset() { - c.mu.Lock() - defer c.mu.Unlock() - c.m = make(map[string]Binding) -} - -// DropInvalidGlobalBinding executes the drop Bindings tasks. -func (h *globalBindingHandle) DropInvalidGlobalBinding() { - defer func() { - if err := h.LoadFromStorageToCache(false); err != nil { - logutil.BindLogger().Warn("drop invalid global binding error", zap.Error(err)) - } - }() - - invalidBindings := h.invalidBindings.getAll() - h.invalidBindings.reset() - for _, invalidBinding := range invalidBindings { - if _, err := h.dropGlobalBinding(invalidBinding.SQLDigest); err != nil { - logutil.BindLogger().Debug("flush bind record failed", zap.Error(err)) - } - } -} - -// AddInvalidGlobalBinding adds Bindings which needs to be deleted into invalidBindings. -func (h *globalBindingHandle) AddInvalidGlobalBinding(invalidBinding Binding) { - h.invalidBindings.add(invalidBinding) -} - -// Size returns the size of bind info cache. 
-func (h *globalBindingHandle) Size() int { - size := len(h.getCache().GetAllBindings()) - return size -} - -// MatchGlobalBinding returns the matched binding for this statement. -func (h *globalBindingHandle) MatchGlobalBinding(sctx sessionctx.Context, fuzzyDigest string, tableNames []*ast.TableName) (matchedBinding Binding, isMatched bool) { - return h.getCache().FuzzyMatchingBinding(sctx, fuzzyDigest, tableNames) -} - -// GetAllGlobalBindings returns all bind records in cache. -func (h *globalBindingHandle) GetAllGlobalBindings() (bindings Bindings) { - return h.getCache().GetAllBindings() -} - -// SetBindingCacheCapacity reset the capacity for the bindingCache. -// It will not affect already cached Bindings. -func (h *globalBindingHandle) SetBindingCacheCapacity(capacity int64) { - h.getCache().SetMemCapacity(capacity) -} - -// GetMemUsage returns the memory usage for the bind cache. -func (h *globalBindingHandle) GetMemUsage() (memUsage int64) { - return h.getCache().GetMemUsage() -} - -// GetMemCapacity returns the memory capacity for the bind cache. -func (h *globalBindingHandle) GetMemCapacity() (memCapacity int64) { - return h.getCache().GetMemCapacity() -} - -// newBinding builds Bindings from a tuple in storage. -func newBinding(sctx sessionctx.Context, row chunk.Row) (string, Binding, error) { - status := row.GetString(3) - // For compatibility, the 'Using' status binding will be converted to the 'Enabled' status binding. - if status == Using { - status = Enabled - } - binding := Binding{ - OriginalSQL: row.GetString(0), - Db: strings.ToLower(row.GetString(2)), - BindSQL: row.GetString(1), - Status: status, - CreateTime: row.GetTime(4), - UpdateTime: row.GetTime(5), - Charset: row.GetString(6), - Collation: row.GetString(7), - Source: row.GetString(8), - SQLDigest: row.GetString(9), - PlanDigest: row.GetString(10), - } - sqlDigest := parser.DigestNormalized(binding.OriginalSQL) - err := prepareHints(sctx, &binding) - sctx.GetSessionVars().CurrentDB = binding.Db - return sqlDigest.String(), binding, err -} - -func getHintsForSQL(sctx sessionctx.Context, sql string) (string, error) { - origVals := sctx.GetSessionVars().UsePlanBaselines - sctx.GetSessionVars().UsePlanBaselines = false - - // Usually passing a sprintf to ExecuteInternal is not recommended, but in this case - // it is safe because ExecuteInternal does not permit MultiStatement execution. Thus, - // the statement won't be able to "break out" from EXPLAIN. - rs, err := exec(sctx, fmt.Sprintf("EXPLAIN FORMAT='hint' %s", sql)) - sctx.GetSessionVars().UsePlanBaselines = origVals - if rs != nil { - defer func() { - // Audit log is collected in Close(), set InRestrictedSQL to avoid 'create sql binding' been recorded as 'explain'. - origin := sctx.GetSessionVars().InRestrictedSQL - sctx.GetSessionVars().InRestrictedSQL = true - terror.Call(rs.Close) - sctx.GetSessionVars().InRestrictedSQL = origin - }() - } - if err != nil { - return "", err - } - chk := rs.NewChunk(nil) - err = rs.Next(context.TODO(), chk) - if err != nil { - return "", err - } - return chk.GetRow(0).GetString(0), nil -} - -// GenerateBindingSQL generates binding sqls from stmt node and plan hints. -func GenerateBindingSQL(stmtNode ast.StmtNode, planHint string, skipCheckIfHasParam bool, defaultDB string) string { - // If would be nil for very simple cases such as point get, we do not need to evolve for them. 
- if planHint == "" { - return "" - } - if !skipCheckIfHasParam { - paramChecker := ¶mMarkerChecker{} - stmtNode.Accept(paramChecker) - // We need to evolve on current sql, but we cannot restore values for paramMarkers yet, - // so just ignore them now. - if paramChecker.hasParamMarker { - return "" - } - } - // We need to evolve plan based on the current sql, not the original sql which may have different parameters. - // So here we would remove the hint and inject the current best plan hint. - hint.BindHint(stmtNode, &hint.HintsSet{}) - bindSQL := utilparser.RestoreWithDefaultDB(stmtNode, defaultDB, "") - if bindSQL == "" { - return "" - } - switch n := stmtNode.(type) { - case *ast.DeleteStmt: - deleteIdx := strings.Index(bindSQL, "DELETE") - // Remove possible `explain` prefix. - bindSQL = bindSQL[deleteIdx:] - return strings.Replace(bindSQL, "DELETE", fmt.Sprintf("DELETE /*+ %s*/", planHint), 1) - case *ast.UpdateStmt: - updateIdx := strings.Index(bindSQL, "UPDATE") - // Remove possible `explain` prefix. - bindSQL = bindSQL[updateIdx:] - return strings.Replace(bindSQL, "UPDATE", fmt.Sprintf("UPDATE /*+ %s*/", planHint), 1) - case *ast.SelectStmt: - var selectIdx int - if n.With != nil { - var withSb strings.Builder - withIdx := strings.Index(bindSQL, "WITH") - restoreCtx := format.NewRestoreCtx(format.RestoreStringSingleQuotes|format.RestoreSpacesAroundBinaryOperation|format.RestoreStringWithoutCharset|format.RestoreNameBackQuotes, &withSb) - restoreCtx.DefaultDB = defaultDB - if err := n.With.Restore(restoreCtx); err != nil { - logutil.BindLogger().Debug("restore SQL failed", zap.Error(err)) - return "" - } - withEnd := withIdx + len(withSb.String()) - tmp := strings.Replace(bindSQL[withEnd:], "SELECT", fmt.Sprintf("SELECT /*+ %s*/", planHint), 1) - return strings.Join([]string{bindSQL[withIdx:withEnd], tmp}, "") - } - selectIdx = strings.Index(bindSQL, "SELECT") - // Remove possible `explain` prefix. - bindSQL = bindSQL[selectIdx:] - return strings.Replace(bindSQL, "SELECT", fmt.Sprintf("SELECT /*+ %s*/", planHint), 1) - case *ast.InsertStmt: - insertIdx := int(0) - if n.IsReplace { - insertIdx = strings.Index(bindSQL, "REPLACE") - } else { - insertIdx = strings.Index(bindSQL, "INSERT") - } - // Remove possible `explain` prefix. - bindSQL = bindSQL[insertIdx:] - return strings.Replace(bindSQL, "SELECT", fmt.Sprintf("SELECT /*+ %s*/", planHint), 1) - } - logutil.BindLogger().Debug("unexpected statement type when generating bind SQL", zap.Any("statement", stmtNode)) - return "" -} - -type paramMarkerChecker struct { - hasParamMarker bool -} - -func (e *paramMarkerChecker) Enter(in ast.Node) (ast.Node, bool) { - if _, ok := in.(*driver.ParamMarkerExpr); ok { - e.hasParamMarker = true - return in, true - } - return in, false -} - -func (*paramMarkerChecker) Leave(in ast.Node) (ast.Node, bool) { - return in, true -} - -// Clear resets the bind handle. It is only used for test. -func (h *globalBindingHandle) Clear() { - h.setCache(newFuzzyBindingCache(h.LoadBindingsFromStorage)) - h.setLastUpdateTime(types.ZeroTimestamp) - h.invalidBindings.reset() -} - -// FlushGlobalBindings flushes the Bindings in temp maps to storage and loads them into cache. 
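
GenerateBindingSQL above splices the plan hint into the restored SQL text with plain string surgery. A small, hypothetical sketch of the SELECT case only (the real code also handles DELETE, UPDATE, INSERT/REPLACE, and CTEs):

package main

import (
	"fmt"
	"strings"
)

// injectHint drops anything before the first SELECT keyword (such as an
// EXPLAIN prefix) and injects the hint right after it, mirroring the
// strings.Replace calls above.
func injectHint(restoredSQL, planHint string) string {
	if planHint == "" {
		return ""
	}
	idx := strings.Index(restoredSQL, "SELECT")
	if idx < 0 {
		return restoredSQL
	}
	restoredSQL = restoredSQL[idx:]
	return strings.Replace(restoredSQL, "SELECT", fmt.Sprintf("SELECT /*+ %s*/", planHint), 1)
}

func main() {
	fmt.Println(injectHint("EXPLAIN SELECT * FROM `t` WHERE `a` = 1", "use_index(@`sel_1` `test`.`t` `ia`)"))
	// SELECT /*+ use_index(@`sel_1` `test`.`t` `ia`)*/ * FROM `t` WHERE `a` = 1
}
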
-func (h *globalBindingHandle) FlushGlobalBindings() error { - h.DropInvalidGlobalBinding() - return h.LoadFromStorageToCache(false) -} - -func (h *globalBindingHandle) callWithSCtx(wrapTxn bool, f func(sctx sessionctx.Context) error) (err error) { - resource, err := h.sPool.Get() - if err != nil { - return err - } - defer func() { - if err == nil { // only recycle when no error - h.sPool.Put(resource) - } - }() - sctx := resource.(sessionctx.Context) - if wrapTxn { - if _, err = exec(sctx, "BEGIN PESSIMISTIC"); err != nil { - return - } - defer func() { - if err == nil { - _, err = exec(sctx, "COMMIT") - } else { - _, err1 := exec(sctx, "ROLLBACK") - terror.Log(errors.Trace(err1)) - } - }() - } - - err = f(sctx) - return -} - -var ( - lastPlanBindingUpdateTime = "last_plan_binding_update_time" -) - -// GetScope gets the status variables scope. -func (*globalBindingHandle) GetScope(_ string) variable.ScopeFlag { - return variable.ScopeSession -} - -// Stats returns the server statistics. -func (h *globalBindingHandle) Stats(_ *variable.SessionVars) (map[string]any, error) { - m := make(map[string]any) - m[lastPlanBindingUpdateTime] = h.getLastUpdateTime().String() - return m, nil -} - -// LoadBindingsFromStorageToCache loads global bindings from storage to the memory cache. -func (h *globalBindingHandle) LoadBindingsFromStorage(sctx sessionctx.Context, sqlDigest string) (Bindings, error) { - if sqlDigest == "" { - return nil, nil - } - timeout := time.Duration(sctx.GetSessionVars().LoadBindingTimeout) * time.Millisecond - resultChan := h.syncBindingSingleflight.DoChan(sqlDigest, func() (any, error) { - return h.loadBindingsFromStorageInternal(sqlDigest) - }) - select { - case result := <-resultChan: - if result.Err != nil { - return nil, result.Err - } - bindings := result.Val - if bindings == nil { - return nil, nil - } - return bindings.(Bindings), nil - case <-time.After(timeout): - return nil, errors.New("load bindings from storage timeout") - } -} - -func (h *globalBindingHandle) loadBindingsFromStorageInternal(sqlDigest string) (any, error) { - failpoint.Inject("load_bindings_from_storage_internal_timeout", func() { - time.Sleep(time.Second) - }) - var bindings Bindings - selectStmt := fmt.Sprintf("SELECT original_sql, bind_sql, default_db, status, create_time, update_time, charset, collation, source, sql_digest, plan_digest FROM mysql.bind_info where sql_digest = '%s'", sqlDigest) - err := h.callWithSCtx(false, func(sctx sessionctx.Context) error { - rows, _, err := execRows(sctx, selectStmt) - if err != nil { - return err - } - bindings = make([]Binding, 0, len(rows)) - for _, row := range rows { - // Skip the builtin record which is designed for binding synchronization. 
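
LoadBindingsFromStorage above deduplicates concurrent on-demand loads through singleflight and bounds each caller's wait with a select on a timer. A runnable sketch of that pattern, assuming golang.org/x/sync/singleflight; the key, sleep, and payload are placeholders:

package main

import (
	"errors"
	"fmt"
	"time"

	"golang.org/x/sync/singleflight"
)

var group singleflight.Group

// loadWithTimeout shows the DoChan + select shape: concurrent callers for the
// same key share one storage load, and each caller gives up after its own
// timeout without cancelling the shared work.
func loadWithTimeout(key string, timeout time.Duration) (string, error) {
	ch := group.DoChan(key, func() (any, error) {
		time.Sleep(50 * time.Millisecond) // stand-in for the storage read
		return "bindings-for-" + key, nil
	})
	select {
	case res := <-ch:
		if res.Err != nil {
			return "", res.Err
		}
		return res.Val.(string), nil
	case <-time.After(timeout):
		return "", errors.New("load bindings from storage timeout")
	}
}

func main() {
	fmt.Println(loadWithTimeout("digest-1", time.Second))
	fmt.Println(loadWithTimeout("digest-2", 10*time.Millisecond))
}
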
- if row.GetString(0) == BuiltinPseudoSQL4BindLock { - continue - } - _, binding, err := newBinding(sctx, row) - if err != nil { - logutil.BindLogger().Warn("failed to generate bind record from data row", zap.Error(err)) - continue - } - bindings = append(bindings, binding) - } - return nil - }) - return bindings, err -} diff --git a/pkg/ddl/add_column.go b/pkg/ddl/add_column.go index 3f95a7de11901..54f519b7731af 100644 --- a/pkg/ddl/add_column.go +++ b/pkg/ddl/add_column.go @@ -60,12 +60,12 @@ func onAddColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) return ver, nil } - if val, _err_ := failpoint.Eval(_curpkg_("errorBeforeDecodeArgs")); _err_ == nil { + failpoint.Inject("errorBeforeDecodeArgs", func(val failpoint.Value) { //nolint:forcetypeassert if val.(bool) { - return ver, errors.New("occur an error before decode args") + failpoint.Return(ver, errors.New("occur an error before decode args")) } - } + }) tblInfo, columnInfo, colFromArgs, pos, ifNotExists, err := checkAddColumn(t, job) if err != nil { @@ -117,7 +117,7 @@ func onAddColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) case model.StateWriteReorganization: // reorganization -> public // Adjust table column offset. - failpoint.Call(_curpkg_("onAddColumnStateWriteReorg")) + failpoint.InjectCall("onAddColumnStateWriteReorg") offset, err := LocateOffsetToMove(columnInfo.Offset, pos, tblInfo) if err != nil { return ver, errors.Trace(err) diff --git a/pkg/ddl/add_column.go__failpoint_stash__ b/pkg/ddl/add_column.go__failpoint_stash__ deleted file mode 100644 index 54f519b7731af..0000000000000 --- a/pkg/ddl/add_column.go__failpoint_stash__ +++ /dev/null @@ -1,1288 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
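
The hunks above and in add_column.go revert the expanded failpoint.Eval/_curpkg_ form back to the failpoint.Inject and failpoint.InjectCall markers and delete the generated binding and stash files. For context, a minimal sketch of how such a marker is written; the failpoint name and function here are hypothetical:

package example

import (
	"errors"

	"github.com/pingcap/failpoint"
)

// mightFail carries an inert marker: in a normal build failpoint.Inject does
// nothing, while `failpoint-ctl enable` rewrites it into the Eval form seen
// on the removed lines of these hunks, so the closure (and failpoint.Return)
// only takes effect when the failpoint is activated in tests.
func mightFail() (int, error) {
	failpoint.Inject("demoFailpoint", func(val failpoint.Value) {
		if val.(bool) {
			failpoint.Return(0, errors.New("injected error"))
		}
	})
	return 42, nil
}
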
- -package ddl - -import ( - "fmt" - "strconv" - "strings" - "time" - "unicode/utf8" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/ddl/logutil" - "github.com/pingcap/tidb/pkg/errctx" - "github.com/pingcap/tidb/pkg/expression" - exprctx "github.com/pingcap/tidb/pkg/expression/context" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/meta/autoid" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/charset" - "github.com/pingcap/tidb/pkg/parser/format" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - field_types "github.com/pingcap/tidb/pkg/parser/types" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - statsutil "github.com/pingcap/tidb/pkg/statistics/handle/util" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/types" - driver "github.com/pingcap/tidb/pkg/types/parser_driver" - "github.com/pingcap/tidb/pkg/util/collate" - "github.com/pingcap/tidb/pkg/util/dbterror" - "github.com/pingcap/tidb/pkg/util/hack" - "go.uber.org/zap" -) - -func onAddColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { - // Handle the rolling back job. - if job.IsRollingback() { - ver, err = onDropColumn(d, t, job) - if err != nil { - return ver, errors.Trace(err) - } - return ver, nil - } - - failpoint.Inject("errorBeforeDecodeArgs", func(val failpoint.Value) { - //nolint:forcetypeassert - if val.(bool) { - failpoint.Return(ver, errors.New("occur an error before decode args")) - } - }) - - tblInfo, columnInfo, colFromArgs, pos, ifNotExists, err := checkAddColumn(t, job) - if err != nil { - if ifNotExists && infoschema.ErrColumnExists.Equal(err) { - job.Warning = toTError(err) - job.State = model.JobStateDone - return ver, nil - } - return ver, errors.Trace(err) - } - if columnInfo == nil { - columnInfo = InitAndAddColumnToTable(tblInfo, colFromArgs) - logutil.DDLLogger().Info("run add column job", zap.Stringer("job", job), zap.Reflect("columnInfo", *columnInfo)) - if err = checkAddColumnTooManyColumns(len(tblInfo.Columns)); err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - } - - originalState := columnInfo.State - switch columnInfo.State { - case model.StateNone: - // none -> delete only - columnInfo.State = model.StateDeleteOnly - ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, originalState != columnInfo.State) - if err != nil { - return ver, errors.Trace(err) - } - job.SchemaState = model.StateDeleteOnly - case model.StateDeleteOnly: - // delete only -> write only - columnInfo.State = model.StateWriteOnly - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != columnInfo.State) - if err != nil { - return ver, errors.Trace(err) - } - // Update the job state when all affairs done. - job.SchemaState = model.StateWriteOnly - case model.StateWriteOnly: - // write only -> reorganization - columnInfo.State = model.StateWriteReorganization - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != columnInfo.State) - if err != nil { - return ver, errors.Trace(err) - } - // Update the job state when all affairs done. 
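
onAddColumn above walks the new column through the online-DDL state ladder one schema version at a time: none -> delete only -> write only -> write reorganization -> public. A toy illustration of that ladder; the state names mirror the switch above, everything else is illustrative:

package main

import "fmt"

type schemaState int

const (
	stateNone schemaState = iota
	stateDeleteOnly
	stateWriteOnly
	stateWriteReorganization
	statePublic
)

var stateNames = [...]string{"none", "delete only", "write only", "write reorganization", "public"}

// next returns the following state in the add-column ladder; each hop
// corresponds to one schema-version bump performed by the job above.
func next(s schemaState) schemaState {
	if s < statePublic {
		return s + 1
	}
	return s
}

func main() {
	for s := stateNone; s != statePublic; s = next(s) {
		fmt.Printf("%s -> %s\n", stateNames[s], stateNames[next(s)])
	}
}
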
- job.SchemaState = model.StateWriteReorganization - job.MarkNonRevertible() - case model.StateWriteReorganization: - // reorganization -> public - // Adjust table column offset. - failpoint.InjectCall("onAddColumnStateWriteReorg") - offset, err := LocateOffsetToMove(columnInfo.Offset, pos, tblInfo) - if err != nil { - return ver, errors.Trace(err) - } - tblInfo.MoveColumnInfo(columnInfo.Offset, offset) - columnInfo.State = model.StatePublic - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != columnInfo.State) - if err != nil { - return ver, errors.Trace(err) - } - - // Finish this job. - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - addColumnEvent := statsutil.NewAddColumnEvent( - job.SchemaID, - tblInfo, - []*model.ColumnInfo{columnInfo}, - ) - asyncNotifyEvent(d, addColumnEvent) - default: - err = dbterror.ErrInvalidDDLState.GenWithStackByArgs("column", columnInfo.State) - } - - return ver, errors.Trace(err) -} - -func checkAndCreateNewColumn(ctx sessionctx.Context, ti ast.Ident, schema *model.DBInfo, spec *ast.AlterTableSpec, t table.Table, specNewColumn *ast.ColumnDef) (*table.Column, error) { - err := checkUnsupportedColumnConstraint(specNewColumn, ti) - if err != nil { - return nil, errors.Trace(err) - } - - colName := specNewColumn.Name.Name.O - // Check whether added column has existed. - col := table.FindCol(t.Cols(), colName) - if col != nil { - err = infoschema.ErrColumnExists.GenWithStackByArgs(colName) - if spec.IfNotExists { - ctx.GetSessionVars().StmtCtx.AppendNote(err) - return nil, nil - } - return nil, err - } - if err = checkColumnAttributes(colName, specNewColumn.Tp); err != nil { - return nil, errors.Trace(err) - } - if utf8.RuneCountInString(colName) > mysql.MaxColumnNameLength { - return nil, dbterror.ErrTooLongIdent.GenWithStackByArgs(colName) - } - - return CreateNewColumn(ctx, schema, spec, t, specNewColumn) -} - -func checkUnsupportedColumnConstraint(col *ast.ColumnDef, ti ast.Ident) error { - for _, constraint := range col.Options { - switch constraint.Tp { - case ast.ColumnOptionAutoIncrement: - return dbterror.ErrUnsupportedAddColumn.GenWithStack("unsupported add column '%s' constraint AUTO_INCREMENT when altering '%s.%s'", col.Name, ti.Schema, ti.Name) - case ast.ColumnOptionPrimaryKey: - return dbterror.ErrUnsupportedAddColumn.GenWithStack("unsupported add column '%s' constraint PRIMARY KEY when altering '%s.%s'", col.Name, ti.Schema, ti.Name) - case ast.ColumnOptionUniqKey: - return dbterror.ErrUnsupportedAddColumn.GenWithStack("unsupported add column '%s' constraint UNIQUE KEY when altering '%s.%s'", col.Name, ti.Schema, ti.Name) - case ast.ColumnOptionAutoRandom: - errMsg := fmt.Sprintf(autoid.AutoRandomAlterAddColumn, col.Name, ti.Schema, ti.Name) - return dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(errMsg) - } - } - - return nil -} - -// CreateNewColumn creates a new column according to the column information. -func CreateNewColumn(ctx sessionctx.Context, schema *model.DBInfo, spec *ast.AlterTableSpec, t table.Table, specNewColumn *ast.ColumnDef) (*table.Column, error) { - // If new column is a generated column, do validation. - // NOTE: we do check whether the column refers other generated - // columns occurring later in a table, but we don't handle the col offset. 
- for _, option := range specNewColumn.Options { - if option.Tp == ast.ColumnOptionGenerated { - if err := checkIllegalFn4Generated(specNewColumn.Name.Name.L, typeColumn, option.Expr); err != nil { - return nil, errors.Trace(err) - } - - if option.Stored { - return nil, dbterror.ErrUnsupportedOnGeneratedColumn.GenWithStackByArgs("Adding generated stored column through ALTER TABLE") - } - - _, dependColNames, err := findDependedColumnNames(schema.Name, t.Meta().Name, specNewColumn) - if err != nil { - return nil, errors.Trace(err) - } - if !ctx.GetSessionVars().EnableAutoIncrementInGenerated { - if err := checkAutoIncrementRef(specNewColumn.Name.Name.L, dependColNames, t.Meta()); err != nil { - return nil, errors.Trace(err) - } - } - duplicateColNames := make(map[string]struct{}, len(dependColNames)) - for k := range dependColNames { - duplicateColNames[k] = struct{}{} - } - cols := t.Cols() - - if err := checkDependedColExist(dependColNames, cols); err != nil { - return nil, errors.Trace(err) - } - - if err := verifyColumnGenerationSingle(duplicateColNames, cols, spec.Position); err != nil { - return nil, errors.Trace(err) - } - } - // Specially, since sequence has been supported, if a newly added column has a - // sequence nextval function as it's default value option, it won't fill the - // known rows with specific sequence next value under current add column logic. - // More explanation can refer: TestSequenceDefaultLogic's comment in sequence_test.go - if option.Tp == ast.ColumnOptionDefaultValue { - if f, ok := option.Expr.(*ast.FuncCallExpr); ok { - switch f.FnName.L { - case ast.NextVal: - if _, err := getSequenceDefaultValue(option); err != nil { - return nil, errors.Trace(err) - } - return nil, errors.Trace(dbterror.ErrAddColumnWithSequenceAsDefault.GenWithStackByArgs(specNewColumn.Name.Name.O)) - case ast.Rand, ast.UUID, ast.UUIDToBin, ast.Replace, ast.Upper: - return nil, errors.Trace(dbterror.ErrBinlogUnsafeSystemFunction.GenWithStackByArgs()) - } - } - } - } - - tableCharset, tableCollate, err := ResolveCharsetCollation(ctx.GetSessionVars(), - ast.CharsetOpt{Chs: t.Meta().Charset, Col: t.Meta().Collate}, - ast.CharsetOpt{Chs: schema.Charset, Col: schema.Collate}, - ) - if err != nil { - return nil, errors.Trace(err) - } - // Ignore table constraints now, they will be checked later. - // We use length(t.Cols()) as the default offset firstly, we will change the column's offset later. - col, _, err := buildColumnAndConstraint( - ctx, - len(t.Cols()), - specNewColumn, - nil, - tableCharset, - tableCollate, - ) - if err != nil { - return nil, errors.Trace(err) - } - - originDefVal, err := generateOriginDefaultValue(col.ToInfo(), ctx) - if err != nil { - return nil, errors.Trace(err) - } - - err = col.SetOriginDefaultValue(originDefVal) - return col, err -} - -// buildColumnAndConstraint builds table.Column and ast.Constraint from the parameters. -// outPriKeyConstraint is the primary key constraint out of column definition. For example: -// `create table t1 (id int , age int, primary key(id));` -func buildColumnAndConstraint( - ctx sessionctx.Context, - offset int, - colDef *ast.ColumnDef, - outPriKeyConstraint *ast.Constraint, - tblCharset string, - tblCollate string, -) (*table.Column, []*ast.Constraint, error) { - if colName := colDef.Name.Name.L; colName == model.ExtraHandleName.L { - return nil, nil, dbterror.ErrWrongColumnName.GenWithStackByArgs(colName) - } - - // specifiedCollate refers to the last collate specified in colDef.Options. 
- chs, coll, err := getCharsetAndCollateInColumnDef(ctx.GetSessionVars(), colDef) - if err != nil { - return nil, nil, errors.Trace(err) - } - chs, coll, err = ResolveCharsetCollation(ctx.GetSessionVars(), - ast.CharsetOpt{Chs: chs, Col: coll}, - ast.CharsetOpt{Chs: tblCharset, Col: tblCollate}, - ) - chs, coll = OverwriteCollationWithBinaryFlag(ctx.GetSessionVars(), colDef, chs, coll) - if err != nil { - return nil, nil, errors.Trace(err) - } - - if err := setCharsetCollationFlenDecimal(colDef.Tp, colDef.Name.Name.O, chs, coll, ctx.GetSessionVars()); err != nil { - return nil, nil, errors.Trace(err) - } - decodeEnumSetBinaryLiteralToUTF8(colDef.Tp, chs) - col, cts, err := columnDefToCol(ctx, offset, colDef, outPriKeyConstraint) - if err != nil { - return nil, nil, errors.Trace(err) - } - return col, cts, nil -} - -// getCharsetAndCollateInColumnDef will iterate collate in the options, validate it by checking the charset -// of column definition. If there's no collate in the option, the default collate of column's charset will be used. -func getCharsetAndCollateInColumnDef(sessVars *variable.SessionVars, def *ast.ColumnDef) (chs, coll string, err error) { - chs = def.Tp.GetCharset() - coll = def.Tp.GetCollate() - if chs != "" && coll == "" { - if coll, err = GetDefaultCollation(sessVars, chs); err != nil { - return "", "", errors.Trace(err) - } - } - for _, opt := range def.Options { - if opt.Tp == ast.ColumnOptionCollate { - info, err := collate.GetCollationByName(opt.StrValue) - if err != nil { - return "", "", errors.Trace(err) - } - if chs == "" { - chs = info.CharsetName - } else if chs != info.CharsetName { - return "", "", dbterror.ErrCollationCharsetMismatch.GenWithStackByArgs(info.Name, chs) - } - coll = info.Name - } - } - return -} - -// OverwriteCollationWithBinaryFlag is used to handle the case like -// -// CREATE TABLE t (a VARCHAR(255) BINARY) CHARSET utf8 COLLATE utf8_general_ci; -// -// The 'BINARY' sets the column collation to *_bin according to the table charset. -func OverwriteCollationWithBinaryFlag(sessVars *variable.SessionVars, colDef *ast.ColumnDef, chs, coll string) (newChs string, newColl string) { - ignoreBinFlag := colDef.Tp.GetCharset() != "" && (colDef.Tp.GetCollate() != "" || containsColumnOption(colDef, ast.ColumnOptionCollate)) - if ignoreBinFlag { - return chs, coll - } - needOverwriteBinColl := types.IsString(colDef.Tp.GetType()) && mysql.HasBinaryFlag(colDef.Tp.GetFlag()) - if needOverwriteBinColl { - newColl, err := GetDefaultCollation(sessVars, chs) - if err != nil { - return chs, coll - } - return chs, newColl - } - return chs, coll -} - -func setCharsetCollationFlenDecimal(tp *types.FieldType, colName, colCharset, colCollate string, sessVars *variable.SessionVars) error { - var err error - if typesNeedCharset(tp.GetType()) { - tp.SetCharset(colCharset) - tp.SetCollate(colCollate) - } else { - tp.SetCharset(charset.CharsetBin) - tp.SetCollate(charset.CharsetBin) - } - - // Use default value for flen or decimal when they are unspecified. - defaultFlen, defaultDecimal := mysql.GetDefaultFieldLengthAndDecimal(tp.GetType()) - if tp.GetDecimal() == types.UnspecifiedLength { - tp.SetDecimal(defaultDecimal) - } - if tp.GetFlen() == types.UnspecifiedLength { - tp.SetFlen(defaultFlen) - if mysql.HasUnsignedFlag(tp.GetFlag()) && tp.GetType() != mysql.TypeLonglong && mysql.IsIntegerType(tp.GetType()) { - // Issue #4684: the flen of unsigned integer(except bigint) is 1 digit shorter than signed integer - // because it has no prefix "+" or "-" character. 
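
buildColumnAndConstraint above resolves the column charset and collation by precedence: the column definition first, then the table defaults, then the database defaults. A simplified, hypothetical sketch of that precedence walk; the fallback values are placeholders and the real ResolveCharsetCollation also derives a default collation when only a charset is given:

package main

import "fmt"

type charsetOpt struct{ chs, coll string }

// resolve returns the first option that specifies a charset, falling back to
// a hard-coded default pair.
func resolve(opts ...charsetOpt) charsetOpt {
	for _, o := range opts {
		if o.chs != "" {
			return o
		}
	}
	return charsetOpt{chs: "utf8mb4", coll: "utf8mb4_bin"}
}

func main() {
	column := charsetOpt{}                                       // column did not specify a charset
	table := charsetOpt{chs: "latin1", coll: "latin1_bin"}       // table default
	db := charsetOpt{chs: "utf8mb4", coll: "utf8mb4_general_ci"} // database default
	fmt.Println(resolve(column, table, db)) // {latin1 latin1_bin}
}
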
- tp.SetFlen(tp.GetFlen() - 1) - } - } else { - // Adjust the field type for blob/text types if the flen is set. - if err = adjustBlobTypesFlen(tp, colCharset); err != nil { - return err - } - } - return checkTooBigFieldLengthAndTryAutoConvert(tp, colName, sessVars) -} - -func decodeEnumSetBinaryLiteralToUTF8(tp *types.FieldType, chs string) { - if tp.GetType() != mysql.TypeEnum && tp.GetType() != mysql.TypeSet { - return - } - enc := charset.FindEncoding(chs) - for i, elem := range tp.GetElems() { - if !tp.GetElemIsBinaryLit(i) { - continue - } - s, err := enc.Transform(nil, hack.Slice(elem), charset.OpDecodeReplace) - if err != nil { - logutil.DDLLogger().Warn("decode enum binary literal to utf-8 failed", zap.Error(err)) - } - tp.SetElem(i, string(hack.String(s))) - } - tp.CleanElemIsBinaryLit() -} - -func typesNeedCharset(tp byte) bool { - switch tp { - case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, - mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, - mysql.TypeEnum, mysql.TypeSet: - return true - } - return false -} - -// checkTooBigFieldLengthAndTryAutoConvert will check whether the field length is too big -// in non-strict mode and varchar column. If it is, will try to adjust to blob or text, see issue #30328 -func checkTooBigFieldLengthAndTryAutoConvert(tp *types.FieldType, colName string, sessVars *variable.SessionVars) error { - if sessVars != nil && !sessVars.SQLMode.HasStrictMode() && tp.GetType() == mysql.TypeVarchar { - err := types.IsVarcharTooBigFieldLength(tp.GetFlen(), colName, tp.GetCharset()) - if err != nil && terror.ErrorEqual(types.ErrTooBigFieldLength, err) { - tp.SetType(mysql.TypeBlob) - if err = adjustBlobTypesFlen(tp, tp.GetCharset()); err != nil { - return err - } - if tp.GetCharset() == charset.CharsetBin { - sessVars.StmtCtx.AppendWarning(dbterror.ErrAutoConvert.FastGenByArgs(colName, "VARBINARY", "BLOB")) - } else { - sessVars.StmtCtx.AppendWarning(dbterror.ErrAutoConvert.FastGenByArgs(colName, "VARCHAR", "TEXT")) - } - } - } - return nil -} - -// columnDefToCol converts ColumnDef to Col and TableConstraints. -// outPriKeyConstraint is the primary key constraint out of column definition. such as: create table t1 (id int , age int, primary key(id)); -func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, outPriKeyConstraint *ast.Constraint) (*table.Column, []*ast.Constraint, error) { - var constraints = make([]*ast.Constraint, 0) - col := table.ToColumn(&model.ColumnInfo{ - Offset: offset, - Name: colDef.Name.Name, - FieldType: *colDef.Tp, - // TODO: remove this version field after there is no old version. - Version: model.CurrLatestColumnInfoVersion, - }) - - if !isExplicitTimeStamp() { - // Check and set TimestampFlag, OnUpdateNowFlag and NotNullFlag. 
- if col.GetType() == mysql.TypeTimestamp { - col.AddFlag(mysql.TimestampFlag | mysql.OnUpdateNowFlag | mysql.NotNullFlag) - } - } - var err error - setOnUpdateNow := false - hasDefaultValue := false - hasNullFlag := false - if colDef.Options != nil { - length := types.UnspecifiedLength - - keys := []*ast.IndexPartSpecification{ - { - Column: colDef.Name, - Length: length, - }, - } - - var sb strings.Builder - restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | - format.RestoreSpacesAroundBinaryOperation | format.RestoreWithoutSchemaName | format.RestoreWithoutTableName - restoreCtx := format.NewRestoreCtx(restoreFlags, &sb) - - for _, v := range colDef.Options { - switch v.Tp { - case ast.ColumnOptionNotNull: - col.AddFlag(mysql.NotNullFlag) - case ast.ColumnOptionNull: - col.DelFlag(mysql.NotNullFlag) - removeOnUpdateNowFlag(col) - hasNullFlag = true - case ast.ColumnOptionAutoIncrement: - col.AddFlag(mysql.AutoIncrementFlag | mysql.NotNullFlag) - case ast.ColumnOptionPrimaryKey: - // Check PriKeyFlag first to avoid extra duplicate constraints. - if col.GetFlag()&mysql.PriKeyFlag == 0 { - constraint := &ast.Constraint{Tp: ast.ConstraintPrimaryKey, Keys: keys, - Option: &ast.IndexOption{PrimaryKeyTp: v.PrimaryKeyTp}} - constraints = append(constraints, constraint) - col.AddFlag(mysql.PriKeyFlag) - // Add NotNullFlag early so that processColumnFlags() can see it. - col.AddFlag(mysql.NotNullFlag) - } - case ast.ColumnOptionUniqKey: - // Check UniqueFlag first to avoid extra duplicate constraints. - if col.GetFlag()&mysql.UniqueFlag == 0 { - constraint := &ast.Constraint{Tp: ast.ConstraintUniqKey, Keys: keys} - constraints = append(constraints, constraint) - col.AddFlag(mysql.UniqueKeyFlag) - } - case ast.ColumnOptionDefaultValue: - hasDefaultValue, err = SetDefaultValue(ctx, col, v) - if err != nil { - return nil, nil, errors.Trace(err) - } - removeOnUpdateNowFlag(col) - case ast.ColumnOptionOnUpdate: - // TODO: Support other time functions. - if !(col.GetType() == mysql.TypeTimestamp || col.GetType() == mysql.TypeDatetime) { - return nil, nil, dbterror.ErrInvalidOnUpdate.GenWithStackByArgs(col.Name) - } - if !expression.IsValidCurrentTimestampExpr(v.Expr, colDef.Tp) { - return nil, nil, dbterror.ErrInvalidOnUpdate.GenWithStackByArgs(col.Name) - } - col.AddFlag(mysql.OnUpdateNowFlag) - setOnUpdateNow = true - case ast.ColumnOptionComment: - err := setColumnComment(ctx, col, v) - if err != nil { - return nil, nil, errors.Trace(err) - } - case ast.ColumnOptionGenerated: - sb.Reset() - err = v.Expr.Restore(restoreCtx) - if err != nil { - return nil, nil, errors.Trace(err) - } - col.GeneratedExprString = sb.String() - col.GeneratedStored = v.Stored - _, dependColNames, err := findDependedColumnNames(model.NewCIStr(""), model.NewCIStr(""), colDef) - if err != nil { - return nil, nil, errors.Trace(err) - } - col.Dependences = dependColNames - case ast.ColumnOptionCollate: - if field_types.HasCharset(colDef.Tp) { - col.FieldType.SetCollate(v.StrValue) - } - case ast.ColumnOptionFulltext: - ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrTableCantHandleFt.FastGenByArgs()) - case ast.ColumnOptionCheck: - if !variable.EnableCheckConstraint.Load() { - ctx.GetSessionVars().StmtCtx.AppendWarning(errCheckConstraintIsOff) - } else { - // Check the column CHECK constraint dependency lazily, after fill all the name. - // Extract column constraint from column option. 
- constraint := &ast.Constraint{ - Tp: ast.ConstraintCheck, - Expr: v.Expr, - Enforced: v.Enforced, - Name: v.ConstraintName, - InColumn: true, - InColumnName: colDef.Name.Name.O, - } - constraints = append(constraints, constraint) - } - } - } - } - - if err = processAndCheckDefaultValueAndColumn(ctx, col, outPriKeyConstraint, hasDefaultValue, setOnUpdateNow, hasNullFlag); err != nil { - return nil, nil, errors.Trace(err) - } - return col, constraints, nil -} - -// isExplicitTimeStamp is used to check if explicit_defaults_for_timestamp is on or off. -// Check out this link for more details. -// https://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_explicit_defaults_for_timestamp -func isExplicitTimeStamp() bool { - // TODO: implement the behavior as MySQL when explicit_defaults_for_timestamp = off, then this function could return false. - return true -} - -// SetDefaultValue sets the default value of the column. -func SetDefaultValue(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) (hasDefaultValue bool, err error) { - var value any - var isSeqExpr bool - value, isSeqExpr, err = getDefaultValue( - exprctx.CtxWithHandleTruncateErrLevel(ctx.GetExprCtx(), errctx.LevelError), - col, option, - ) - if err != nil { - return false, errors.Trace(err) - } - if isSeqExpr { - if err := checkSequenceDefaultValue(col); err != nil { - return false, errors.Trace(err) - } - col.DefaultIsExpr = isSeqExpr - } - - // When the default value is expression, we skip check and convert. - if !col.DefaultIsExpr { - if hasDefaultValue, value, err = checkColumnDefaultValue(ctx.GetExprCtx(), col, value); err != nil { - return hasDefaultValue, errors.Trace(err) - } - value, err = convertTimestampDefaultValToUTC(ctx, value, col) - if err != nil { - return hasDefaultValue, errors.Trace(err) - } - } else { - hasDefaultValue = true - } - err = setDefaultValueWithBinaryPadding(col, value) - if err != nil { - return hasDefaultValue, errors.Trace(err) - } - return hasDefaultValue, nil -} - -// getFuncCallDefaultValue gets the default column value of function-call expression. -func getFuncCallDefaultValue(col *table.Column, option *ast.ColumnOption, expr *ast.FuncCallExpr) (any, bool, error) { - switch expr.FnName.L { - case ast.CurrentTimestamp, ast.CurrentDate: // CURRENT_TIMESTAMP() and CURRENT_DATE() - tp, fsp := col.FieldType.GetType(), col.FieldType.GetDecimal() - if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime { - defaultFsp := 0 - if len(expr.Args) == 1 { - if val := expr.Args[0].(*driver.ValueExpr); val != nil { - defaultFsp = int(val.GetInt64()) - } - } - if defaultFsp != fsp { - return nil, false, dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) - } - } - return nil, false, nil - case ast.NextVal: - // handle default next value of sequence. 
(keep the expr string) - str, err := getSequenceDefaultValue(option) - if err != nil { - return nil, false, errors.Trace(err) - } - return str, true, nil - case ast.Rand, ast.UUID, ast.UUIDToBin: // RAND(), UUID() and UUID_TO_BIN() - if err := expression.VerifyArgsWrapper(expr.FnName.L, len(expr.Args)); err != nil { - return nil, false, errors.Trace(err) - } - str, err := restoreFuncCall(expr) - if err != nil { - return nil, false, errors.Trace(err) - } - col.DefaultIsExpr = true - return str, false, nil - case ast.DateFormat: // DATE_FORMAT() - if err := expression.VerifyArgsWrapper(expr.FnName.L, len(expr.Args)); err != nil { - return nil, false, errors.Trace(err) - } - // Support DATE_FORMAT(NOW(),'%Y-%m'), DATE_FORMAT(NOW(),'%Y-%m-%d'), - // DATE_FORMAT(NOW(),'%Y-%m-%d %H.%i.%s'), DATE_FORMAT(NOW(),'%Y-%m-%d %H:%i:%s'). - nowFunc, ok := expr.Args[0].(*ast.FuncCallExpr) - if ok && nowFunc.FnName.L == ast.Now { - if err := expression.VerifyArgsWrapper(nowFunc.FnName.L, len(nowFunc.Args)); err != nil { - return nil, false, errors.Trace(err) - } - valExpr, isValue := expr.Args[1].(ast.ValueExpr) - if !isValue || (valExpr.GetString() != "%Y-%m" && valExpr.GetString() != "%Y-%m-%d" && - valExpr.GetString() != "%Y-%m-%d %H.%i.%s" && valExpr.GetString() != "%Y-%m-%d %H:%i:%s") { - return nil, false, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String(), valExpr) - } - str, err := restoreFuncCall(expr) - if err != nil { - return nil, false, errors.Trace(err) - } - col.DefaultIsExpr = true - return str, false, nil - } - return nil, false, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String(), - fmt.Sprintf("%s with disallowed args", expr.FnName.String())) - case ast.Replace: - if err := expression.VerifyArgsWrapper(expr.FnName.L, len(expr.Args)); err != nil { - return nil, false, errors.Trace(err) - } - funcCall := expr.Args[0] - // Support REPLACE(CONVERT(UPPER(UUID()) USING UTF8MB4), '-', '')) - if convertFunc, ok := funcCall.(*ast.FuncCallExpr); ok && convertFunc.FnName.L == ast.Convert { - if err := expression.VerifyArgsWrapper(convertFunc.FnName.L, len(convertFunc.Args)); err != nil { - return nil, false, errors.Trace(err) - } - funcCall = convertFunc.Args[0] - } - // Support REPLACE(UPPER(UUID()), '-', ''). - if upperFunc, ok := funcCall.(*ast.FuncCallExpr); ok && upperFunc.FnName.L == ast.Upper { - if err := expression.VerifyArgsWrapper(upperFunc.FnName.L, len(upperFunc.Args)); err != nil { - return nil, false, errors.Trace(err) - } - if uuidFunc, ok := upperFunc.Args[0].(*ast.FuncCallExpr); ok && uuidFunc.FnName.L == ast.UUID { - if err := expression.VerifyArgsWrapper(uuidFunc.FnName.L, len(uuidFunc.Args)); err != nil { - return nil, false, errors.Trace(err) - } - str, err := restoreFuncCall(expr) - if err != nil { - return nil, false, errors.Trace(err) - } - col.DefaultIsExpr = true - return str, false, nil - } - } - return nil, false, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String(), - fmt.Sprintf("%s with disallowed args", expr.FnName.String())) - case ast.Upper: - if err := expression.VerifyArgsWrapper(expr.FnName.L, len(expr.Args)); err != nil { - return nil, false, errors.Trace(err) - } - // Support UPPER(SUBSTRING_INDEX(USER(), '@', 1)). 
- if substringIndexFunc, ok := expr.Args[0].(*ast.FuncCallExpr); ok && substringIndexFunc.FnName.L == ast.SubstringIndex { - if err := expression.VerifyArgsWrapper(substringIndexFunc.FnName.L, len(substringIndexFunc.Args)); err != nil { - return nil, false, errors.Trace(err) - } - if userFunc, ok := substringIndexFunc.Args[0].(*ast.FuncCallExpr); ok && userFunc.FnName.L == ast.User { - if err := expression.VerifyArgsWrapper(userFunc.FnName.L, len(userFunc.Args)); err != nil { - return nil, false, errors.Trace(err) - } - valExpr, isValue := substringIndexFunc.Args[1].(ast.ValueExpr) - if !isValue || valExpr.GetString() != "@" { - return nil, false, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String(), valExpr) - } - str, err := restoreFuncCall(expr) - if err != nil { - return nil, false, errors.Trace(err) - } - col.DefaultIsExpr = true - return str, false, nil - } - } - return nil, false, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String(), - fmt.Sprintf("%s with disallowed args", expr.FnName.String())) - case ast.StrToDate: // STR_TO_DATE() - if err := expression.VerifyArgsWrapper(expr.FnName.L, len(expr.Args)); err != nil { - return nil, false, errors.Trace(err) - } - // Support STR_TO_DATE('1980-01-01', '%Y-%m-%d'). - if _, ok1 := expr.Args[0].(ast.ValueExpr); ok1 { - if _, ok2 := expr.Args[1].(ast.ValueExpr); ok2 { - str, err := restoreFuncCall(expr) - if err != nil { - return nil, false, errors.Trace(err) - } - col.DefaultIsExpr = true - return str, false, nil - } - } - return nil, false, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String(), - fmt.Sprintf("%s with disallowed args", expr.FnName.String())) - case ast.JSONObject, ast.JSONArray, ast.JSONQuote: // JSON_OBJECT(), JSON_ARRAY(), JSON_QUOTE() - if err := expression.VerifyArgsWrapper(expr.FnName.L, len(expr.Args)); err != nil { - return nil, false, errors.Trace(err) - } - str, err := restoreFuncCall(expr) - if err != nil { - return nil, false, errors.Trace(err) - } - col.DefaultIsExpr = true - return str, false, nil - - default: - return nil, false, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String(), expr.FnName.String()) - } -} - -// getDefaultValue will get the default value for column. -// 1: get the expr restored string for the column which uses sequence next value as default value. -// 2: get specific default value for the other column. -func getDefaultValue(ctx exprctx.BuildContext, col *table.Column, option *ast.ColumnOption) (any, bool, error) { - // handle default value with function call - tp, fsp := col.FieldType.GetType(), col.FieldType.GetDecimal() - if x, ok := option.Expr.(*ast.FuncCallExpr); ok { - val, isSeqExpr, err := getFuncCallDefaultValue(col, option, x) - if val != nil || isSeqExpr || err != nil { - return val, isSeqExpr, err - } - // If the function call is ast.CurrentTimestamp, it needs to be continuously processed. - } - - if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime || tp == mysql.TypeDate { - vd, err := expression.GetTimeValue(ctx, option.Expr, tp, fsp, nil) - value := vd.GetValue() - if err != nil { - return nil, false, dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) - } - - // Value is nil means `default null`. - if value == nil { - return nil, false, nil - } - - // If value is types.Time, convert it to string. 
- if vv, ok := value.(types.Time); ok { - return vv.String(), false, nil - } - - return value, false, nil - } - - // evaluate the non-function-call expr to a certain value. - v, err := expression.EvalSimpleAst(ctx, option.Expr) - if err != nil { - return nil, false, errors.Trace(err) - } - - if v.IsNull() { - return nil, false, nil - } - - if v.Kind() == types.KindBinaryLiteral || v.Kind() == types.KindMysqlBit { - if types.IsTypeBlob(tp) || tp == mysql.TypeJSON { - // BLOB/TEXT/JSON column cannot have a default value. - // Skip the unnecessary decode procedure. - return v.GetString(), false, err - } - if tp == mysql.TypeBit || tp == mysql.TypeString || tp == mysql.TypeVarchar || - tp == mysql.TypeVarString || tp == mysql.TypeEnum || tp == mysql.TypeSet { - // For BinaryLiteral or bit fields, we decode the default value to utf8 string. - str, err := v.GetBinaryStringDecoded(types.StrictFlags, col.GetCharset()) - if err != nil { - // Overwrite the decoding error with invalid default value error. - err = dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) - } - return str, false, err - } - // For other kind of fields (e.g. INT), we supply its integer as string value. - value, err := v.GetBinaryLiteral().ToInt(ctx.GetEvalCtx().TypeCtx()) - if err != nil { - return nil, false, err - } - return strconv.FormatUint(value, 10), false, nil - } - - switch tp { - case mysql.TypeSet: - val, err := getSetDefaultValue(v, col) - return val, false, err - case mysql.TypeEnum: - val, err := getEnumDefaultValue(v, col) - return val, false, err - case mysql.TypeDuration, mysql.TypeDate: - if v, err = v.ConvertTo(ctx.GetEvalCtx().TypeCtx(), &col.FieldType); err != nil { - return "", false, errors.Trace(err) - } - case mysql.TypeBit: - if v.Kind() == types.KindInt64 || v.Kind() == types.KindUint64 { - // For BIT fields, convert int into BinaryLiteral. - return types.NewBinaryLiteralFromUint(v.GetUint64(), -1).ToString(), false, nil - } - case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeFloat, mysql.TypeDouble: - // For these types, convert it to standard format firstly. - // like integer fields, convert it into integer string literals. like convert "1.25" into "1" and "2.8" into "3". - // if raise a error, we will use original expression. We will handle it in check phase - if temp, err := v.ConvertTo(ctx.GetEvalCtx().TypeCtx(), &col.FieldType); err == nil { - v = temp - } - } - - val, err := v.ToString() - return val, false, err -} - -func getSequenceDefaultValue(c *ast.ColumnOption) (expr string, err error) { - var sb strings.Builder - restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | - format.RestoreSpacesAroundBinaryOperation - restoreCtx := format.NewRestoreCtx(restoreFlags, &sb) - if err := c.Expr.Restore(restoreCtx); err != nil { - return "", err - } - return sb.String(), nil -} - -func setDefaultValueWithBinaryPadding(col *table.Column, value any) error { - err := col.SetDefaultValue(value) - if err != nil { - return err - } - // https://dev.mysql.com/doc/refman/8.0/en/binary-varbinary.html - // Set the default value for binary type should append the paddings. 
- if value != nil { - if col.GetType() == mysql.TypeString && types.IsBinaryStr(&col.FieldType) && len(value.(string)) < col.GetFlen() { - padding := make([]byte, col.GetFlen()-len(value.(string))) - col.DefaultValue = string(append([]byte(col.DefaultValue.(string)), padding...)) - } - } - return nil -} - -func setColumnComment(ctx sessionctx.Context, col *table.Column, option *ast.ColumnOption) error { - value, err := expression.EvalSimpleAst(ctx.GetExprCtx(), option.Expr) - if err != nil { - return errors.Trace(err) - } - if col.Comment, err = value.ToString(); err != nil { - return errors.Trace(err) - } - - sessionVars := ctx.GetSessionVars() - col.Comment, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, col.Name.L, &col.Comment, dbterror.ErrTooLongFieldComment) - return errors.Trace(err) -} - -func processAndCheckDefaultValueAndColumn(ctx sessionctx.Context, col *table.Column, - outPriKeyConstraint *ast.Constraint, hasDefaultValue, setOnUpdateNow, hasNullFlag bool) error { - processDefaultValue(col, hasDefaultValue, setOnUpdateNow) - processColumnFlags(col) - - err := checkPriKeyConstraint(col, hasDefaultValue, hasNullFlag, outPriKeyConstraint) - if err != nil { - return errors.Trace(err) - } - if err = checkColumnValueConstraint(col, col.GetCollate()); err != nil { - return errors.Trace(err) - } - if err = checkDefaultValue(ctx.GetExprCtx(), col, hasDefaultValue); err != nil { - return errors.Trace(err) - } - if err = checkColumnFieldLength(col); err != nil { - return errors.Trace(err) - } - return nil -} - -func restoreFuncCall(expr *ast.FuncCallExpr) (string, error) { - var sb strings.Builder - restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | - format.RestoreSpacesAroundBinaryOperation - restoreCtx := format.NewRestoreCtx(restoreFlags, &sb) - if err := expr.Restore(restoreCtx); err != nil { - return "", err - } - return sb.String(), nil -} - -// getSetDefaultValue gets the default value for the set type. See https://dev.mysql.com/doc/refman/5.7/en/set.html. -func getSetDefaultValue(v types.Datum, col *table.Column) (string, error) { - if v.Kind() == types.KindInt64 { - setCnt := len(col.GetElems()) - maxLimit := int64(1< maxLimit { - return "", dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) - } - setVal, err := types.ParseSetValue(col.GetElems(), uint64(val)) - if err != nil { - return "", errors.Trace(err) - } - v.SetMysqlSet(setVal, col.GetCollate()) - return v.ToString() - } - - str, err := v.ToString() - if err != nil { - return "", errors.Trace(err) - } - if str == "" { - return str, nil - } - setVal, err := types.ParseSetName(col.GetElems(), str, col.GetCollate()) - if err != nil { - return "", dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) - } - v.SetMysqlSet(setVal, col.GetCollate()) - - return v.ToString() -} - -// getEnumDefaultValue gets the default value for the enum type. See https://dev.mysql.com/doc/refman/5.7/en/enum.html. 
-func getEnumDefaultValue(v types.Datum, col *table.Column) (string, error) { - if v.Kind() == types.KindInt64 { - val := v.GetInt64() - if val < 1 || val > int64(len(col.GetElems())) { - return "", dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) - } - enumVal, err := types.ParseEnumValue(col.GetElems(), uint64(val)) - if err != nil { - return "", errors.Trace(err) - } - v.SetMysqlEnum(enumVal, col.GetCollate()) - return v.ToString() - } - str, err := v.ToString() - if err != nil { - return "", errors.Trace(err) - } - // Ref: https://dev.mysql.com/doc/refman/8.0/en/enum.html - // Trailing spaces are automatically deleted from ENUM member values in the table definition when a table is created. - str = strings.TrimRight(str, " ") - enumVal, err := types.ParseEnumName(col.GetElems(), str, col.GetCollate()) - if err != nil { - return "", dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(col.Name.O) - } - v.SetMysqlEnum(enumVal, col.GetCollate()) - - return v.ToString() -} - -func removeOnUpdateNowFlag(c *table.Column) { - // For timestamp Col, if it is set null or default value, - // OnUpdateNowFlag should be removed. - if mysql.HasTimestampFlag(c.GetFlag()) { - c.DelFlag(mysql.OnUpdateNowFlag) - } -} - -func processDefaultValue(c *table.Column, hasDefaultValue bool, setOnUpdateNow bool) { - setTimestampDefaultValue(c, hasDefaultValue, setOnUpdateNow) - - setYearDefaultValue(c, hasDefaultValue) - - // Set `NoDefaultValueFlag` if this field doesn't have a default value and - // it is `not null` and not an `AUTO_INCREMENT` field or `TIMESTAMP` field. - setNoDefaultValueFlag(c, hasDefaultValue) -} - -func setYearDefaultValue(c *table.Column, hasDefaultValue bool) { - if hasDefaultValue { - return - } - - if c.GetType() == mysql.TypeYear && mysql.HasNotNullFlag(c.GetFlag()) { - if err := c.SetDefaultValue("0000"); err != nil { - logutil.DDLLogger().Error("set default value failed", zap.Error(err)) - } - } -} - -func setTimestampDefaultValue(c *table.Column, hasDefaultValue bool, setOnUpdateNow bool) { - if hasDefaultValue { - return - } - - // For timestamp Col, if is not set default value or not set null, use current timestamp. - if mysql.HasTimestampFlag(c.GetFlag()) && mysql.HasNotNullFlag(c.GetFlag()) { - if setOnUpdateNow { - if err := c.SetDefaultValue(types.ZeroDatetimeStr); err != nil { - logutil.DDLLogger().Error("set default value failed", zap.Error(err)) - } - } else { - if err := c.SetDefaultValue(strings.ToUpper(ast.CurrentTimestamp)); err != nil { - logutil.DDLLogger().Error("set default value failed", zap.Error(err)) - } - } - } -} - -func setNoDefaultValueFlag(c *table.Column, hasDefaultValue bool) { - if hasDefaultValue { - return - } - - if !mysql.HasNotNullFlag(c.GetFlag()) { - return - } - - // Check if it is an `AUTO_INCREMENT` field or `TIMESTAMP` field. 
- if !mysql.HasAutoIncrementFlag(c.GetFlag()) && !mysql.HasTimestampFlag(c.GetFlag()) { - c.AddFlag(mysql.NoDefaultValueFlag) - } -} - -func checkDefaultValue(ctx exprctx.BuildContext, c *table.Column, hasDefaultValue bool) (err error) { - if !hasDefaultValue { - return nil - } - - if c.GetDefaultValue() != nil { - if c.DefaultIsExpr { - if mysql.HasAutoIncrementFlag(c.GetFlag()) { - return types.ErrInvalidDefault.GenWithStackByArgs(c.Name) - } - return nil - } - _, err = table.GetColDefaultValue( - exprctx.CtxWithHandleTruncateErrLevel(ctx, errctx.LevelError), - c.ToInfo(), - ) - if err != nil { - return types.ErrInvalidDefault.GenWithStackByArgs(c.Name) - } - return nil - } - // Primary key default null is invalid. - if mysql.HasPriKeyFlag(c.GetFlag()) { - return dbterror.ErrPrimaryCantHaveNull - } - - // Set not null but default null is invalid. - if mysql.HasNotNullFlag(c.GetFlag()) { - return types.ErrInvalidDefault.GenWithStackByArgs(c.Name) - } - - return nil -} - -func checkColumnFieldLength(col *table.Column) error { - if col.GetType() == mysql.TypeVarchar { - if err := types.IsVarcharTooBigFieldLength(col.GetFlen(), col.Name.O, col.GetCharset()); err != nil { - return errors.Trace(err) - } - } - - return nil -} - -// checkPriKeyConstraint check all parts of a PRIMARY KEY must be NOT NULL -func checkPriKeyConstraint(col *table.Column, hasDefaultValue, hasNullFlag bool, outPriKeyConstraint *ast.Constraint) error { - // Primary key should not be null. - if mysql.HasPriKeyFlag(col.GetFlag()) && hasDefaultValue && col.GetDefaultValue() == nil { - return types.ErrInvalidDefault.GenWithStackByArgs(col.Name) - } - // Set primary key flag for outer primary key constraint. - // Such as: create table t1 (id int , age int, primary key(id)) - if !mysql.HasPriKeyFlag(col.GetFlag()) && outPriKeyConstraint != nil { - for _, key := range outPriKeyConstraint.Keys { - if key.Expr == nil && key.Column.Name.L != col.Name.L { - continue - } - col.AddFlag(mysql.PriKeyFlag) - break - } - } - // Primary key should not be null. - if mysql.HasPriKeyFlag(col.GetFlag()) && hasNullFlag { - return dbterror.ErrPrimaryCantHaveNull - } - return nil -} - -func checkColumnValueConstraint(col *table.Column, collation string) error { - if col.GetType() != mysql.TypeEnum && col.GetType() != mysql.TypeSet { - return nil - } - valueMap := make(map[string]bool, len(col.GetElems())) - ctor := collate.GetCollator(collation) - enumLengthLimit := config.GetGlobalConfig().EnableEnumLengthLimit - desc, err := charset.GetCharsetInfo(col.GetCharset()) - if err != nil { - return errors.Trace(err) - } - for i := range col.GetElems() { - val := string(ctor.Key(col.GetElems()[i])) - // According to MySQL 8.0 Refman: - // The maximum supported length of an individual ENUM element is M <= 255 and (M x w) <= 1020, - // where M is the element literal length and w is the number of bytes required for the maximum-length character in the character set. - // See https://dev.mysql.com/doc/refman/8.0/en/string-type-syntax.html for more details. - if enumLengthLimit && (len(val) > 255 || len(val)*desc.Maxlen > 1020) { - return dbterror.ErrTooLongValueForType.GenWithStackByArgs(col.Name) - } - if _, ok := valueMap[val]; ok { - tpStr := "ENUM" - if col.GetType() == mysql.TypeSet { - tpStr = "SET" - } - return types.ErrDuplicatedValueInType.GenWithStackByArgs(col.Name, col.GetElems()[i], tpStr) - } - valueMap[val] = true - } - return nil -} - -// checkColumnDefaultValue checks the default value of the column. 
-// In non-strict SQL mode, if the default value of the column is an empty string, the default value can be ignored. -// In strict SQL mode, TEXT/BLOB/JSON can't have not null default values. -// In NO_ZERO_DATE SQL mode, TIMESTAMP/DATE/DATETIME type can't have zero date like '0000-00-00' or '0000-00-00 00:00:00'. -func checkColumnDefaultValue(ctx exprctx.BuildContext, col *table.Column, value any) (bool, any, error) { - hasDefaultValue := true - if value != nil && (col.GetType() == mysql.TypeJSON || - col.GetType() == mysql.TypeTinyBlob || col.GetType() == mysql.TypeMediumBlob || - col.GetType() == mysql.TypeLongBlob || col.GetType() == mysql.TypeBlob) { - // In non-strict SQL mode. - if !ctx.GetEvalCtx().SQLMode().HasStrictMode() && value == "" { - if col.GetType() == mysql.TypeBlob || col.GetType() == mysql.TypeLongBlob { - // The TEXT/BLOB default value can be ignored. - hasDefaultValue = false - } - // In non-strict SQL mode, if the column type is json and the default value is null, it is initialized to an empty array. - if col.GetType() == mysql.TypeJSON { - value = `null` - } - ctx.GetEvalCtx().AppendWarning(dbterror.ErrBlobCantHaveDefault.FastGenByArgs(col.Name.O)) - return hasDefaultValue, value, nil - } - // In strict SQL mode or default value is not an empty string. - return hasDefaultValue, value, dbterror.ErrBlobCantHaveDefault.GenWithStackByArgs(col.Name.O) - } - if value != nil && ctx.GetEvalCtx().SQLMode().HasNoZeroDateMode() && - ctx.GetEvalCtx().SQLMode().HasStrictMode() && types.IsTypeTime(col.GetType()) { - if vv, ok := value.(string); ok { - timeValue, err := expression.GetTimeValue(ctx, vv, col.GetType(), col.GetDecimal(), nil) - if err != nil { - return hasDefaultValue, value, errors.Trace(err) - } - if timeValue.GetMysqlTime().CoreTime() == types.ZeroCoreTime { - return hasDefaultValue, value, types.ErrInvalidDefault.GenWithStackByArgs(col.Name.O) - } - } - } - return hasDefaultValue, value, nil -} - -func checkSequenceDefaultValue(col *table.Column) error { - if mysql.IsIntegerType(col.GetType()) { - return nil - } - return dbterror.ErrColumnTypeUnsupportedNextValue.GenWithStackByArgs(col.ColumnInfo.Name.O) -} - -func convertTimestampDefaultValToUTC(ctx sessionctx.Context, defaultVal any, col *table.Column) (any, error) { - if defaultVal == nil || col.GetType() != mysql.TypeTimestamp { - return defaultVal, nil - } - if vv, ok := defaultVal.(string); ok { - if vv != types.ZeroDatetimeStr && !strings.EqualFold(vv, ast.CurrentTimestamp) { - t, err := types.ParseTime(ctx.GetSessionVars().StmtCtx.TypeCtx(), vv, col.GetType(), col.GetDecimal()) - if err != nil { - return defaultVal, errors.Trace(err) - } - err = t.ConvertTimeZone(ctx.GetSessionVars().Location(), time.UTC) - if err != nil { - return defaultVal, errors.Trace(err) - } - defaultVal = t.String() - } - } - return defaultVal, nil -} - -// processColumnFlags is used by columnDefToCol and processColumnOptions. It is intended to unify behaviors on `create/add` and `modify/change` statements. Check tidb#issue#19342. -func processColumnFlags(col *table.Column) { - if col.FieldType.EvalType().IsStringKind() { - if col.GetCharset() == charset.CharsetBin { - col.AddFlag(mysql.BinaryFlag) - } else { - col.DelFlag(mysql.BinaryFlag) - } - } - if col.GetType() == mysql.TypeBit { - // For BIT field, it's charset is binary but does not have binary flag. 
- col.DelFlag(mysql.BinaryFlag) - col.AddFlag(mysql.UnsignedFlag) - } - if col.GetType() == mysql.TypeYear { - // For Year field, it's charset is binary but does not have binary flag. - col.DelFlag(mysql.BinaryFlag) - col.AddFlag(mysql.ZerofillFlag) - } - - // If you specify ZEROFILL for a numeric column, MySQL automatically adds the UNSIGNED attribute to the column. - // See https://dev.mysql.com/doc/refman/5.7/en/numeric-type-overview.html for more details. - // But some types like bit and year, won't show its unsigned flag in `show create table`. - if mysql.HasZerofillFlag(col.GetFlag()) { - col.AddFlag(mysql.UnsignedFlag) - } -} - -func adjustBlobTypesFlen(tp *types.FieldType, colCharset string) error { - cs, err := charset.GetCharsetInfo(colCharset) - // when we meet the unsupported charset, we do not adjust. - if err != nil { - return err - } - l := tp.GetFlen() * cs.Maxlen - if tp.GetType() == mysql.TypeBlob { - if l <= tinyBlobMaxLength { - logutil.DDLLogger().Info(fmt.Sprintf("Automatically convert BLOB(%d) to TINYBLOB", tp.GetFlen())) - tp.SetFlen(tinyBlobMaxLength) - tp.SetType(mysql.TypeTinyBlob) - } else if l <= blobMaxLength { - tp.SetFlen(blobMaxLength) - } else if l <= mediumBlobMaxLength { - logutil.DDLLogger().Info(fmt.Sprintf("Automatically convert BLOB(%d) to MEDIUMBLOB", tp.GetFlen())) - tp.SetFlen(mediumBlobMaxLength) - tp.SetType(mysql.TypeMediumBlob) - } else if l <= longBlobMaxLength { - logutil.DDLLogger().Info(fmt.Sprintf("Automatically convert BLOB(%d) to LONGBLOB", tp.GetFlen())) - tp.SetFlen(longBlobMaxLength) - tp.SetType(mysql.TypeLongBlob) - } - } - return nil -} diff --git a/pkg/ddl/backfilling.go b/pkg/ddl/backfilling.go index 1f2d3dc37fd3c..4fbed0422726d 100644 --- a/pkg/ddl/backfilling.go +++ b/pkg/ddl/backfilling.go @@ -397,22 +397,22 @@ func (w *backfillWorker) run(d *ddlCtx, bf backfiller, job *model.Job) { d.setDDLLabelForTopSQL(job.ID, job.Query) logger.Debug("backfill worker got task", zap.Int("workerID", w.GetCtx().id), zap.Stringer("task", task)) - if _, _err_ := failpoint.Eval(_curpkg_("mockBackfillRunErr")); _err_ == nil { + failpoint.Inject("mockBackfillRunErr", func() { if w.GetCtx().id == 0 { result := &backfillResult{taskID: task.id, addedCount: 0, nextKey: nil, err: errors.Errorf("mock backfill error")} w.sendResult(result) - continue + failpoint.Continue() } - } + }) - if _, _err_ := failpoint.Eval(_curpkg_("mockHighLoadForAddIndex")); _err_ == nil { + failpoint.Inject("mockHighLoadForAddIndex", func() { sqlPrefixes := []string{"alter"} topsql.MockHighCPULoad(job.Query, sqlPrefixes, 5) - } + }) - if _, _err_ := failpoint.Eval(_curpkg_("mockBackfillSlow")); _err_ == nil { + failpoint.Inject("mockBackfillSlow", func() { time.Sleep(100 * time.Millisecond) - } + }) // Change the batch size dynamically. 
w.GetCtx().batchCnt = int(variable.GetDDLReorgBatchSize()) @@ -828,12 +828,12 @@ func (dc *ddlCtx) writePhysicalTableRecord( return errors.Trace(err) } - if val, _err_ := failpoint.Eval(_curpkg_("MockCaseWhenParseFailure")); _err_ == nil { + failpoint.Inject("MockCaseWhenParseFailure", func(val failpoint.Value) { //nolint:forcetypeassert if val.(bool) { - return errors.New("job.ErrCount:" + strconv.Itoa(int(reorgInfo.Job.ErrorCount)) + ", mock unknown type: ast.whenClause.") + failpoint.Return(errors.New("job.ErrCount:" + strconv.Itoa(int(reorgInfo.Job.ErrorCount)) + ", mock unknown type: ast.whenClause.")) } - } + }) if bfWorkerType == typeAddIndexWorker && reorgInfo.ReorgMeta.ReorgTp == model.ReorgTypeLitMerge { return dc.runAddIndexInLocalIngestMode(ctx, sessPool, t, reorgInfo) } @@ -954,13 +954,13 @@ func injectCheckBackfillWorkerNum(curWorkerSize int, isMergeWorker bool) error { if isMergeWorker { return nil } - if val, _err_ := failpoint.Eval(_curpkg_("checkBackfillWorkerNum")); _err_ == nil { + failpoint.Inject("checkBackfillWorkerNum", func(val failpoint.Value) { //nolint:forcetypeassert if val.(bool) { num := int(atomic.LoadInt32(&TestCheckWorkerNumber)) if num != 0 { if num != curWorkerSize { - return errors.Errorf("expected backfill worker num: %v, actual record num: %v", num, curWorkerSize) + failpoint.Return(errors.Errorf("expected backfill worker num: %v, actual record num: %v", num, curWorkerSize)) } var wg sync.WaitGroup wg.Add(1) @@ -968,7 +968,7 @@ func injectCheckBackfillWorkerNum(curWorkerSize int, isMergeWorker bool) error { wg.Wait() } } - } + }) return nil } diff --git a/pkg/ddl/backfilling.go__failpoint_stash__ b/pkg/ddl/backfilling.go__failpoint_stash__ deleted file mode 100644 index 4fbed0422726d..0000000000000 --- a/pkg/ddl/backfilling.go__failpoint_stash__ +++ /dev/null @@ -1,1124 +0,0 @@ -// Copyright 2020 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package ddl - -import ( - "bytes" - "context" - "encoding/hex" - "fmt" - "strconv" - "sync" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/ddl/ingest" - "github.com/pingcap/tidb/pkg/ddl/logutil" - sess "github.com/pingcap/tidb/pkg/ddl/session" - ddlutil "github.com/pingcap/tidb/pkg/ddl/util" - "github.com/pingcap/tidb/pkg/disttask/operator" - "github.com/pingcap/tidb/pkg/expression" - exprctx "github.com/pingcap/tidb/pkg/expression/context" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/util" - contextutil "github.com/pingcap/tidb/pkg/util/context" - "github.com/pingcap/tidb/pkg/util/dbterror" - decoder "github.com/pingcap/tidb/pkg/util/rowDecoder" - "github.com/pingcap/tidb/pkg/util/topsql" - "github.com/prometheus/client_golang/prometheus" - "github.com/tikv/client-go/v2/tikv" - kvutil "github.com/tikv/client-go/v2/util" - "go.uber.org/zap" -) - -type backfillerType byte - -const ( - typeAddIndexWorker backfillerType = iota - typeUpdateColumnWorker - typeCleanUpIndexWorker - typeAddIndexMergeTmpWorker - typeReorgPartitionWorker - - typeCount -) - -// BackupFillerTypeCount represents the count of ddl jobs that need to do backfill. -func BackupFillerTypeCount() int { - return int(typeCount) -} - -func (bT backfillerType) String() string { - switch bT { - case typeAddIndexWorker: - return "add index" - case typeUpdateColumnWorker: - return "update column" - case typeCleanUpIndexWorker: - return "clean up index" - case typeAddIndexMergeTmpWorker: - return "merge temporary index" - case typeReorgPartitionWorker: - return "reorganize partition" - default: - return "unknown" - } -} - -// By now the DDL jobs that need backfilling include: -// 1: add-index -// 2: modify-column-type -// 3: clean-up global index -// 4: reorganize partition -// -// They all have a write reorganization state to back fill data into the rows existed. -// Backfilling is time consuming, to accelerate this process, TiDB has built some sub -// workers to do this in the DDL owner node. -// -// DDL owner thread (also see comments before runReorgJob func) -// ^ -// | (reorgCtx.doneCh) -// | -// worker master -// ^ (waitTaskResults) -// | -// | -// v (sendRangeTask) -// +--------------------+---------+---------+------------------+--------------+ -// | | | | | -// backfillworker1 backfillworker2 backfillworker3 backfillworker4 ... -// -// The worker master is responsible for scaling the backfilling workers according to the -// system variable "tidb_ddl_reorg_worker_cnt". Essentially, reorg job is mainly based -// on the [start, end] range of the table to backfill data. We did not do it all at once, -// there were several ddl rounds. -// -// [start1---end1 start2---end2 start3---end3 start4---end4 ... ... ] -// | | | | | | | | -// +-------+ +-------+ +-------+ +-------+ ... ... -// | | | | -// bfworker1 bfworker2 bfworker3 bfworker4 ... ... -// | | | | | | -// +---------------- (round1)----------------+ +--(round2)--+ -// -// The main range [start, end] will be split into small ranges. -// Each small range corresponds to a region and it will be delivered to a backfillworker. 
-// Each worker can only be assigned with one range at one round, those remaining ranges -// will be cached until all the backfill workers have had their previous range jobs done. -// -// [ region start --------------------- region end ] -// | -// v -// [ batch ] [ batch ] [ batch ] [ batch ] ... -// | | | | -// v v v v -// (a kv txn) -> -> -> -// -// For a single range, backfill worker doesn't backfill all the data in one kv transaction. -// Instead, it is divided into batches, each time a kv transaction completes the backfilling -// of a partial batch. - -// backfillTaskContext is the context of the batch adding indices or updating column values. -// After finishing the batch adding indices or updating column values, result in backfillTaskContext will be merged into backfillResult. -type backfillTaskContext struct { - nextKey kv.Key - done bool - addedCount int - scanCount int - warnings map[errors.ErrorID]*terror.Error - warningsCount map[errors.ErrorID]int64 - finishTS uint64 -} - -type backfillCtx struct { - id int - *ddlCtx - sessCtx sessionctx.Context - warnings contextutil.WarnHandlerExt - loc *time.Location - exprCtx exprctx.BuildContext - tblCtx table.MutateContext - schemaName string - table table.Table - batchCnt int - jobContext *JobContext - metricCounter prometheus.Counter -} - -func newBackfillCtx(id int, rInfo *reorgInfo, - schemaName string, tbl table.Table, jobCtx *JobContext, label string, isDistributed bool) (*backfillCtx, error) { - sessCtx, err := newSessCtx(rInfo.d.store, rInfo.ReorgMeta) - if err != nil { - return nil, err - } - - if isDistributed { - id = int(backfillContextID.Add(1)) - } - - exprCtx := sessCtx.GetExprCtx() - return &backfillCtx{ - id: id, - ddlCtx: rInfo.d, - sessCtx: sessCtx, - warnings: sessCtx.GetSessionVars().StmtCtx.WarnHandler, - exprCtx: exprCtx, - tblCtx: sessCtx.GetTableCtx(), - loc: exprCtx.GetEvalCtx().Location(), - schemaName: schemaName, - table: tbl, - batchCnt: int(variable.GetDDLReorgBatchSize()), - jobContext: jobCtx, - metricCounter: metrics.BackfillTotalCounter.WithLabelValues( - metrics.GenerateReorgLabel(label, schemaName, tbl.Meta().Name.String())), - }, nil -} - -func updateTxnEntrySizeLimitIfNeeded(txn kv.Transaction) { - if entrySizeLimit := variable.TxnEntrySizeLimit.Load(); entrySizeLimit > 0 { - txn.SetOption(kv.SizeLimits, kv.TxnSizeLimits{ - Entry: entrySizeLimit, - Total: kv.TxnTotalSizeLimit.Load(), - }) - } -} - -type backfiller interface { - BackfillData(handleRange reorgBackfillTask) (taskCtx backfillTaskContext, err error) - AddMetricInfo(float64) - GetCtx() *backfillCtx - String() string -} - -type backfillResult struct { - taskID int - addedCount int - scanCount int - totalCount int - nextKey kv.Key - err error -} - -type reorgBackfillTask struct { - physicalTable table.PhysicalTable - - // TODO: Remove the following fields after remove the function of run. - id int - startKey kv.Key - endKey kv.Key - jobID int64 - sqlQuery string - priority int -} - -func (r *reorgBackfillTask) getJobID() int64 { - return r.jobID -} - -func (r *reorgBackfillTask) String() string { - pID := r.physicalTable.GetPhysicalID() - start := hex.EncodeToString(r.startKey) - end := hex.EncodeToString(r.endKey) - jobID := r.getJobID() - return fmt.Sprintf("taskID: %d, physicalTableID: %d, range: [%s, %s), jobID: %d", r.id, pID, start, end, jobID) -} - -// mergeBackfillCtxToResult merge partial result in taskCtx into result. 
-func mergeBackfillCtxToResult(taskCtx *backfillTaskContext, result *backfillResult) { - result.nextKey = taskCtx.nextKey - result.addedCount += taskCtx.addedCount - result.scanCount += taskCtx.scanCount -} - -type backfillWorker struct { - backfiller - taskCh chan *reorgBackfillTask - resultCh chan *backfillResult - ctx context.Context - cancel func() - wg *sync.WaitGroup -} - -func newBackfillWorker(ctx context.Context, bf backfiller) *backfillWorker { - bfCtx, cancel := context.WithCancel(ctx) - return &backfillWorker{ - backfiller: bf, - taskCh: make(chan *reorgBackfillTask, 1), - resultCh: make(chan *backfillResult, 1), - ctx: bfCtx, - cancel: cancel, - } -} - -func (w *backfillWorker) String() string { - return fmt.Sprintf("backfill-worker %d, tp %s", w.GetCtx().id, w.backfiller.String()) -} - -func (w *backfillWorker) Close() { - if w.cancel != nil { - w.cancel() - w.cancel = nil - } -} - -func closeBackfillWorkers(workers []*backfillWorker) { - for _, worker := range workers { - worker.Close() - } -} - -// ResultCounterForTest is used for test. -var ResultCounterForTest *atomic.Int32 - -// handleBackfillTask backfills range [task.startHandle, task.endHandle) handle's index to table. -func (w *backfillWorker) handleBackfillTask(d *ddlCtx, task *reorgBackfillTask, bf backfiller) *backfillResult { - handleRange := *task - result := &backfillResult{ - taskID: task.id, - err: nil, - addedCount: 0, - nextKey: handleRange.startKey, - } - lastLogCount := 0 - lastLogTime := time.Now() - startTime := lastLogTime - jobID := task.getJobID() - rc := d.getReorgCtx(jobID) - - for { - // Give job chance to be canceled or paused, if we not check it here, - // we will never cancel the job once there is panic in bf.BackfillData. - // Because reorgRecordTask may run a long time, - // we should check whether this ddl job is still runnable. - err := d.isReorgRunnable(jobID, false) - if err != nil { - result.err = err - return result - } - - taskCtx, err := bf.BackfillData(handleRange) - if err != nil { - result.err = err - return result - } - - bf.AddMetricInfo(float64(taskCtx.addedCount)) - mergeBackfillCtxToResult(&taskCtx, result) - - // Although `handleRange` is for data in one region, but back fill worker still split it into many - // small reorg batch size slices and reorg them in many different kv txn. - // If a task failed, it may contained some committed small kv txn which has already finished the - // small range reorganization. - // In the next round of reorganization, the target handle range may overlap with last committed - // small ranges. This will cause the `redo` action in reorganization. - // So for added count and warnings collection, it is recommended to collect the statistics in every - // successfully committed small ranges rather than fetching it in the total result. 
- rc.increaseRowCount(int64(taskCtx.addedCount)) - rc.mergeWarnings(taskCtx.warnings, taskCtx.warningsCount) - - if num := result.scanCount - lastLogCount; num >= 90000 { - lastLogCount = result.scanCount - logutil.DDLLogger().Info("backfill worker back fill index", zap.Stringer("worker", w), - zap.Int("addedCount", result.addedCount), zap.Int("scanCount", result.scanCount), - zap.String("next key", hex.EncodeToString(taskCtx.nextKey)), - zap.Float64("speed(rows/s)", float64(num)/time.Since(lastLogTime).Seconds())) - lastLogTime = time.Now() - } - - handleRange.startKey = taskCtx.nextKey - if taskCtx.done { - break - } - } - logutil.DDLLogger().Info("backfill worker finish task", - zap.Stringer("worker", w), zap.Stringer("task", task), - zap.Int("added count", result.addedCount), - zap.Int("scan count", result.scanCount), - zap.String("next key", hex.EncodeToString(result.nextKey)), - zap.Stringer("take time", time.Since(startTime))) - if ResultCounterForTest != nil && result.err == nil { - ResultCounterForTest.Add(1) - } - return result -} - -func (w *backfillWorker) sendResult(result *backfillResult) { - select { - case <-w.ctx.Done(): - case w.resultCh <- result: - } -} - -func (w *backfillWorker) run(d *ddlCtx, bf backfiller, job *model.Job) { - logger := logutil.DDLLogger().With(zap.Stringer("worker", w), zap.Int64("jobID", job.ID)) - var ( - curTaskID int - task *reorgBackfillTask - ok bool - ) - - defer w.wg.Done() - defer util.Recover(metrics.LabelDDL, "backfillWorker.run", func() { - w.sendResult(&backfillResult{taskID: curTaskID, err: dbterror.ErrReorgPanic}) - }, false) - for { - select { - case <-w.ctx.Done(): - logger.Info("backfill worker exit on context done") - return - case task, ok = <-w.taskCh: - } - if !ok { - logger.Info("backfill worker exit") - return - } - curTaskID = task.id - d.setDDLLabelForTopSQL(job.ID, job.Query) - - logger.Debug("backfill worker got task", zap.Int("workerID", w.GetCtx().id), zap.Stringer("task", task)) - failpoint.Inject("mockBackfillRunErr", func() { - if w.GetCtx().id == 0 { - result := &backfillResult{taskID: task.id, addedCount: 0, nextKey: nil, err: errors.Errorf("mock backfill error")} - w.sendResult(result) - failpoint.Continue() - } - }) - - failpoint.Inject("mockHighLoadForAddIndex", func() { - sqlPrefixes := []string{"alter"} - topsql.MockHighCPULoad(job.Query, sqlPrefixes, 5) - }) - - failpoint.Inject("mockBackfillSlow", func() { - time.Sleep(100 * time.Millisecond) - }) - - // Change the batch size dynamically. - w.GetCtx().batchCnt = int(variable.GetDDLReorgBatchSize()) - result := w.handleBackfillTask(d, task, bf) - w.sendResult(result) - - if result.err != nil { - logger.Info("backfill worker exit on error", - zap.Error(result.err)) - return - } - } -} - -// loadTableRanges load table key ranges from PD between given start key and end key. -// It returns up to `limit` ranges. -func loadTableRanges( - ctx context.Context, - t table.PhysicalTable, - store kv.Storage, - startKey, endKey kv.Key, - limit int, -) ([]kv.KeyRange, error) { - if len(startKey) == 0 && len(endKey) == 0 { - logutil.DDLLogger().Info("load noop table range", - zap.Int64("physicalTableID", t.GetPhysicalID())) - return []kv.KeyRange{}, nil - } - - s, ok := store.(tikv.Storage) - if !ok { - // Only support split ranges in tikv.Storage now. 
- logutil.DDLLogger().Info("load table ranges failed, unsupported storage", - zap.String("storage", fmt.Sprintf("%T", store)), - zap.Int64("physicalTableID", t.GetPhysicalID())) - return []kv.KeyRange{{StartKey: startKey, EndKey: endKey}}, nil - } - - rc := s.GetRegionCache() - maxSleep := 10000 // ms - bo := tikv.NewBackofferWithVars(ctx, maxSleep, nil) - var ranges []kv.KeyRange - err := util.RunWithRetry(util.DefaultMaxRetries, util.RetryInterval, func() (bool, error) { - logutil.DDLLogger().Info("load table ranges from PD", - zap.Int64("physicalTableID", t.GetPhysicalID()), - zap.String("start key", hex.EncodeToString(startKey)), - zap.String("end key", hex.EncodeToString(endKey))) - rs, err := rc.BatchLoadRegionsWithKeyRange(bo, startKey, endKey, limit) - if err != nil { - return false, errors.Trace(err) - } - ranges = make([]kv.KeyRange, 0, len(rs)) - for _, r := range rs { - ranges = append(ranges, kv.KeyRange{StartKey: r.StartKey(), EndKey: r.EndKey()}) - } - err = validateAndFillRanges(ranges, startKey, endKey) - if err != nil { - return true, errors.Trace(err) - } - return false, nil - }) - if err != nil { - return nil, errors.Trace(err) - } - logutil.DDLLogger().Info("load table ranges from PD done", - zap.Int64("physicalTableID", t.GetPhysicalID()), - zap.String("range start", hex.EncodeToString(ranges[0].StartKey)), - zap.String("range end", hex.EncodeToString(ranges[len(ranges)-1].EndKey)), - zap.Int("range count", len(ranges))) - return ranges, nil -} - -func validateAndFillRanges(ranges []kv.KeyRange, startKey, endKey []byte) error { - if len(ranges) == 0 { - errMsg := fmt.Sprintf("cannot find region in range [%s, %s]", - hex.EncodeToString(startKey), hex.EncodeToString(endKey)) - return dbterror.ErrInvalidSplitRegionRanges.GenWithStackByArgs(errMsg) - } - for i, r := range ranges { - if i == 0 { - s := r.StartKey - if len(s) == 0 || bytes.Compare(s, startKey) < 0 { - ranges[i].StartKey = startKey - } - } - if i == len(ranges)-1 { - e := r.EndKey - if len(e) == 0 || bytes.Compare(e, endKey) > 0 { - ranges[i].EndKey = endKey - } - } - if len(ranges[i].StartKey) == 0 || len(ranges[i].EndKey) == 0 { - return errors.Errorf("get empty start/end key in the middle of ranges") - } - if i > 0 && !bytes.Equal(ranges[i-1].EndKey, ranges[i].StartKey) { - return errors.Errorf("ranges are not continuous") - } - } - return nil -} - -func getBatchTasks( - t table.Table, - reorgInfo *reorgInfo, - kvRanges []kv.KeyRange, - taskIDAlloc *taskIDAllocator, - bfWorkerTp backfillerType, -) []*reorgBackfillTask { - batchTasks := make([]*reorgBackfillTask, 0, len(kvRanges)) - //nolint:forcetypeassert - phyTbl := t.(table.PhysicalTable) - for _, r := range kvRanges { - taskID := taskIDAlloc.alloc() - startKey := r.StartKey - endKey := r.EndKey - endKey = getActualEndKey(t, reorgInfo, bfWorkerTp, startKey, endKey, taskID) - task := &reorgBackfillTask{ - id: taskID, - jobID: reorgInfo.Job.ID, - physicalTable: phyTbl, - priority: reorgInfo.Priority, - startKey: startKey, - endKey: endKey, - } - batchTasks = append(batchTasks, task) - } - return batchTasks -} - -func getActualEndKey( - t table.Table, - reorgInfo *reorgInfo, - bfTp backfillerType, - rangeStart, rangeEnd kv.Key, - taskID int, -) kv.Key { - job := reorgInfo.Job - //nolint:forcetypeassert - phyTbl := t.(table.PhysicalTable) - - if bfTp == typeAddIndexMergeTmpWorker { - // Temp Index data does not grow infinitely, we can return the whole range - // and IndexMergeTmpWorker should still be finished in a bounded time. 
- return rangeEnd - } - if bfTp == typeAddIndexWorker && job.ReorgMeta.ReorgTp == model.ReorgTypeLitMerge { - // Ingest worker uses coprocessor to read table data. It is fast enough, - // we don't need to get the actual end key of this range. - return rangeEnd - } - - // Otherwise to avoid the future data written to key range of [backfillChunkEndKey, rangeEnd) and - // backfill worker can't catch up, we shrink the end key to the actual written key for now. - jobCtx := reorgInfo.NewJobContext() - - actualEndKey, err := GetRangeEndKey(jobCtx, reorgInfo.d.store, job.Priority, t.RecordPrefix(), rangeStart, rangeEnd) - if err != nil { - logutil.DDLLogger().Info("get backfill range task, get reverse key failed", zap.Error(err)) - return rangeEnd - } - logutil.DDLLogger().Info("get backfill range task, change end key", - zap.Int("id", taskID), - zap.Int64("pTbl", phyTbl.GetPhysicalID()), - zap.String("end key", hex.EncodeToString(rangeEnd)), - zap.String("current end key", hex.EncodeToString(actualEndKey))) - return actualEndKey -} - -// sendTasks sends tasks to workers, and returns remaining kvRanges that is not handled. -func sendTasks( - scheduler backfillScheduler, - t table.PhysicalTable, - kvRanges []kv.KeyRange, - reorgInfo *reorgInfo, - taskIDAlloc *taskIDAllocator, - bfWorkerTp backfillerType, -) error { - batchTasks := getBatchTasks(t, reorgInfo, kvRanges, taskIDAlloc, bfWorkerTp) - for _, task := range batchTasks { - if err := scheduler.sendTask(task); err != nil { - return errors.Trace(err) - } - } - return nil -} - -var ( - // TestCheckWorkerNumCh use for test adjust backfill worker. - TestCheckWorkerNumCh = make(chan *sync.WaitGroup) - // TestCheckWorkerNumber use for test adjust backfill worker. - TestCheckWorkerNumber = int32(variable.DefTiDBDDLReorgWorkerCount) - // TestCheckReorgTimeout is used to mock timeout when reorg data. - TestCheckReorgTimeout = int32(0) -) - -func loadDDLReorgVars(ctx context.Context, sessPool *sess.Pool) error { - // Get sessionctx from context resource pool. - sCtx, err := sessPool.Get() - if err != nil { - return errors.Trace(err) - } - defer sessPool.Put(sCtx) - return ddlutil.LoadDDLReorgVars(ctx, sCtx) -} - -func makeupDecodeColMap(dbName model.CIStr, t table.Table) (map[int64]decoder.Column, error) { - writableColInfos := make([]*model.ColumnInfo, 0, len(t.WritableCols())) - for _, col := range t.WritableCols() { - writableColInfos = append(writableColInfos, col.ColumnInfo) - } - exprCols, _, err := expression.ColumnInfos2ColumnsAndNames(newReorgExprCtx(), dbName, t.Meta().Name, writableColInfos, t.Meta()) - if err != nil { - return nil, err - } - mockSchema := expression.NewSchema(exprCols...) - - decodeColMap := decoder.BuildFullDecodeColMap(t.WritableCols(), mockSchema) - - return decodeColMap, nil -} - -var backfillTaskChanSize = 128 - -// SetBackfillTaskChanSizeForTest is only used for test. -func SetBackfillTaskChanSizeForTest(n int) { - backfillTaskChanSize = n -} - -func (dc *ddlCtx) runAddIndexInLocalIngestMode( - ctx context.Context, - sessPool *sess.Pool, - t table.PhysicalTable, - reorgInfo *reorgInfo, -) error { - // TODO(tangenta): support adjust worker count dynamically. 
- if err := dc.isReorgRunnable(reorgInfo.Job.ID, false); err != nil { - return errors.Trace(err) - } - job := reorgInfo.Job - opCtx := NewLocalOperatorCtx(ctx, job.ID) - idxCnt := len(reorgInfo.elements) - indexIDs := make([]int64, 0, idxCnt) - indexInfos := make([]*model.IndexInfo, 0, idxCnt) - uniques := make([]bool, 0, idxCnt) - hasUnique := false - for _, e := range reorgInfo.elements { - indexIDs = append(indexIDs, e.ID) - indexInfo := model.FindIndexInfoByID(t.Meta().Indices, e.ID) - if indexInfo == nil { - logutil.DDLIngestLogger().Warn("index info not found", - zap.Int64("jobID", job.ID), - zap.Int64("tableID", t.Meta().ID), - zap.Int64("indexID", e.ID)) - return errors.Errorf("index info not found: %d", e.ID) - } - indexInfos = append(indexInfos, indexInfo) - uniques = append(uniques, indexInfo.Unique) - hasUnique = hasUnique || indexInfo.Unique - } - - //nolint: forcetypeassert - discovery := dc.store.(tikv.Storage).GetRegionCache().PDClient().GetServiceDiscovery() - bcCtx, err := ingest.LitBackCtxMgr.Register( - ctx, job.ID, hasUnique, dc.etcdCli, discovery, job.ReorgMeta.ResourceGroupName) - if err != nil { - return errors.Trace(err) - } - defer ingest.LitBackCtxMgr.Unregister(job.ID) - sctx, err := sessPool.Get() - if err != nil { - return errors.Trace(err) - } - defer sessPool.Put(sctx) - - cpMgr, err := ingest.NewCheckpointManager( - ctx, - sessPool, - reorgInfo.PhysicalTableID, - job.ID, - indexIDs, - ingest.LitBackCtxMgr.EncodeJobSortPath(job.ID), - dc.store.(kv.StorageWithPD).GetPDClient(), - ) - if err != nil { - logutil.DDLIngestLogger().Warn("create checkpoint manager failed", - zap.Int64("jobID", job.ID), - zap.Error(err)) - } else { - defer cpMgr.Close() - bcCtx.AttachCheckpointManager(cpMgr) - } - - reorgCtx := dc.getReorgCtx(reorgInfo.Job.ID) - rowCntListener := &localRowCntListener{ - prevPhysicalRowCnt: reorgCtx.getRowCount(), - reorgCtx: dc.getReorgCtx(reorgInfo.Job.ID), - counter: metrics.BackfillTotalCounter.WithLabelValues( - metrics.GenerateReorgLabel("add_idx_rate", job.SchemaName, job.TableName)), - } - - avgRowSize := estimateTableRowSize(ctx, dc.store, sctx.GetRestrictedSQLExecutor(), t) - concurrency := int(variable.GetDDLReorgWorkerCounter()) - - engines, err := bcCtx.Register(indexIDs, uniques, t) - if err != nil { - logutil.DDLIngestLogger().Error("cannot register new engine", - zap.Int64("jobID", job.ID), - zap.Error(err), - zap.Int64s("index IDs", indexIDs)) - return errors.Trace(err) - } - - pipe, err := NewAddIndexIngestPipeline( - opCtx, - dc.store, - sessPool, - bcCtx, - engines, - job.ID, - t, - indexInfos, - reorgInfo.StartKey, - reorgInfo.EndKey, - job.ReorgMeta, - avgRowSize, - concurrency, - cpMgr, - rowCntListener, - ) - if err != nil { - return err - } - err = executeAndClosePipeline(opCtx, pipe) - if err != nil { - err1 := bcCtx.FinishAndUnregisterEngines(ingest.OptCloseEngines) - if err1 != nil { - logutil.DDLIngestLogger().Error("unregister engine failed", - zap.Int64("jobID", job.ID), - zap.Error(err1), - zap.Int64s("index IDs", indexIDs)) - } - return err - } - if cpMgr != nil { - cpMgr.AdvanceWatermark(true, true) - } - return bcCtx.FinishAndUnregisterEngines(ingest.OptCleanData | ingest.OptCheckDup) -} - -func executeAndClosePipeline(ctx *OperatorCtx, pipe *operator.AsyncPipeline) error { - err := pipe.Execute() - if err != nil { - return err - } - err = pipe.Close() - if opErr := ctx.OperatorErr(); opErr != nil { - return opErr - } - return err -} - -type localRowCntListener struct { - EmptyRowCntListener - reorgCtx *reorgCtx - 
counter prometheus.Counter - - // prevPhysicalRowCnt records the row count from previous physical tables (partitions). - prevPhysicalRowCnt int64 - // curPhysicalRowCnt records the row count of current physical table. - curPhysicalRowCnt struct { - cnt int64 - mu sync.Mutex - } -} - -func (s *localRowCntListener) Written(rowCnt int) { - s.curPhysicalRowCnt.mu.Lock() - s.curPhysicalRowCnt.cnt += int64(rowCnt) - s.reorgCtx.setRowCount(s.prevPhysicalRowCnt + s.curPhysicalRowCnt.cnt) - s.curPhysicalRowCnt.mu.Unlock() - s.counter.Add(float64(rowCnt)) -} - -func (s *localRowCntListener) SetTotal(total int) { - s.reorgCtx.setRowCount(s.prevPhysicalRowCnt + int64(total)) -} - -// writePhysicalTableRecord handles the "add index" or "modify/change column" reorganization state for a non-partitioned table or a partition. -// For a partitioned table, it should be handled partition by partition. -// -// How to "add index" or "update column value" in reorganization state? -// Concurrently process the @@tidb_ddl_reorg_worker_cnt tasks. Each task deals with a handle range of the index/row record. -// The handle range is split from PD regions now. Each worker deal with a region table key range one time. -// Each handle range by estimation, concurrent processing needs to perform after the handle range has been acquired. -// The operation flow is as follows: -// 1. Open numbers of defaultWorkers goroutines. -// 2. Split table key range from PD regions. -// 3. Send tasks to running workers by workers's task channel. Each task deals with a region key ranges. -// 4. Wait all these running tasks finished, then continue to step 3, until all tasks is done. -// -// The above operations are completed in a transaction. -// Finally, update the concurrent processing of the total number of rows, and store the completed handle value. 
-func (dc *ddlCtx) writePhysicalTableRecord( - ctx context.Context, - sessPool *sess.Pool, - t table.PhysicalTable, - bfWorkerType backfillerType, - reorgInfo *reorgInfo, -) error { - startKey, endKey := reorgInfo.StartKey, reorgInfo.EndKey - - if err := dc.isReorgRunnable(reorgInfo.Job.ID, false); err != nil { - return errors.Trace(err) - } - - failpoint.Inject("MockCaseWhenParseFailure", func(val failpoint.Value) { - //nolint:forcetypeassert - if val.(bool) { - failpoint.Return(errors.New("job.ErrCount:" + strconv.Itoa(int(reorgInfo.Job.ErrorCount)) + ", mock unknown type: ast.whenClause.")) - } - }) - if bfWorkerType == typeAddIndexWorker && reorgInfo.ReorgMeta.ReorgTp == model.ReorgTypeLitMerge { - return dc.runAddIndexInLocalIngestMode(ctx, sessPool, t, reorgInfo) - } - - jc := reorgInfo.NewJobContext() - - eg, egCtx := util.NewErrorGroupWithRecoverWithCtx(ctx) - - scheduler, err := newTxnBackfillScheduler(egCtx, reorgInfo, sessPool, bfWorkerType, t, jc) - if err != nil { - return errors.Trace(err) - } - defer scheduler.close(true) - - err = scheduler.setupWorkers() - if err != nil { - return errors.Trace(err) - } - - // process result goroutine - eg.Go(func() error { - totalAddedCount := reorgInfo.Job.GetRowCount() - keeper := newDoneTaskKeeper(startKey) - cnt := 0 - - for { - select { - case <-egCtx.Done(): - return egCtx.Err() - case result, ok := <-scheduler.resultChan(): - if !ok { - logutil.DDLLogger().Info("backfill workers successfully processed", - zap.Stringer("element", reorgInfo.currElement), - zap.Int64("total added count", totalAddedCount), - zap.String("start key", hex.EncodeToString(startKey))) - return nil - } - cnt++ - - if result.err != nil { - logutil.DDLLogger().Warn("backfill worker failed", - zap.Int64("job ID", reorgInfo.ID), - zap.Int64("total added count", totalAddedCount), - zap.String("start key", hex.EncodeToString(startKey)), - zap.String("result next key", hex.EncodeToString(result.nextKey)), - zap.Error(result.err)) - return result.err - } - - if result.totalCount > 0 { - totalAddedCount = int64(result.totalCount) - } else { - totalAddedCount += int64(result.addedCount) - } - dc.getReorgCtx(reorgInfo.Job.ID).setRowCount(totalAddedCount) - - keeper.updateNextKey(result.taskID, result.nextKey) - - if cnt%(scheduler.currentWorkerSize()*4) == 0 { - err2 := reorgInfo.UpdateReorgMeta(keeper.nextKey, sessPool) - if err2 != nil { - logutil.DDLLogger().Warn("update reorg meta failed", - zap.Int64("job ID", reorgInfo.ID), - zap.Error(err2)) - } - // We try to adjust the worker size regularly to reduce - // the overhead of loading the DDL related global variables. - err2 = scheduler.adjustWorkerSize() - if err2 != nil { - logutil.DDLLogger().Warn("cannot adjust backfill worker size", - zap.Int64("job ID", reorgInfo.ID), - zap.Error(err2)) - } - } - } - } - }) - - // generate task goroutine - eg.Go(func() error { - // we will modify the startKey in this goroutine, so copy them to avoid race. 
- start, end := startKey, endKey - taskIDAlloc := newTaskIDAllocator() - for { - kvRanges, err2 := loadTableRanges(egCtx, t, dc.store, start, end, backfillTaskChanSize) - if err2 != nil { - return errors.Trace(err2) - } - if len(kvRanges) == 0 { - break - } - logutil.DDLLogger().Info("start backfill workers to reorg record", - zap.Stringer("type", bfWorkerType), - zap.Int("workerCnt", scheduler.currentWorkerSize()), - zap.Int("regionCnt", len(kvRanges)), - zap.String("startKey", hex.EncodeToString(start)), - zap.String("endKey", hex.EncodeToString(end))) - - err2 = sendTasks(scheduler, t, kvRanges, reorgInfo, taskIDAlloc, bfWorkerType) - if err2 != nil { - return errors.Trace(err2) - } - - start = kvRanges[len(kvRanges)-1].EndKey - if start.Cmp(end) >= 0 { - break - } - } - - scheduler.close(false) - return nil - }) - - return eg.Wait() -} - -func injectCheckBackfillWorkerNum(curWorkerSize int, isMergeWorker bool) error { - if isMergeWorker { - return nil - } - failpoint.Inject("checkBackfillWorkerNum", func(val failpoint.Value) { - //nolint:forcetypeassert - if val.(bool) { - num := int(atomic.LoadInt32(&TestCheckWorkerNumber)) - if num != 0 { - if num != curWorkerSize { - failpoint.Return(errors.Errorf("expected backfill worker num: %v, actual record num: %v", num, curWorkerSize)) - } - var wg sync.WaitGroup - wg.Add(1) - TestCheckWorkerNumCh <- &wg - wg.Wait() - } - } - }) - return nil -} - -// recordIterFunc is used for low-level record iteration. -type recordIterFunc func(h kv.Handle, rowKey kv.Key, rawRecord []byte) (more bool, err error) - -func iterateSnapshotKeys(ctx *JobContext, store kv.Storage, priority int, keyPrefix kv.Key, version uint64, - startKey kv.Key, endKey kv.Key, fn recordIterFunc) error { - isRecord := tablecodec.IsRecordKey(keyPrefix.Next()) - var firstKey kv.Key - if startKey == nil { - firstKey = keyPrefix - } else { - firstKey = startKey - } - - var upperBound kv.Key - if endKey == nil { - upperBound = keyPrefix.PrefixNext() - } else { - upperBound = endKey.PrefixNext() - } - - ver := kv.Version{Ver: version} - snap := store.GetSnapshot(ver) - snap.SetOption(kv.Priority, priority) - snap.SetOption(kv.RequestSourceInternal, true) - snap.SetOption(kv.RequestSourceType, ctx.ddlJobSourceType()) - snap.SetOption(kv.ExplicitRequestSourceType, kvutil.ExplicitTypeDDL) - if tagger := ctx.getResourceGroupTaggerForTopSQL(); tagger != nil { - snap.SetOption(kv.ResourceGroupTagger, tagger) - } - snap.SetOption(kv.ResourceGroupName, ctx.resourceGroupName) - - it, err := snap.Iter(firstKey, upperBound) - if err != nil { - return errors.Trace(err) - } - defer it.Close() - - for it.Valid() { - if !it.Key().HasPrefix(keyPrefix) { - break - } - - var handle kv.Handle - if isRecord { - handle, err = tablecodec.DecodeRowKey(it.Key()) - if err != nil { - return errors.Trace(err) - } - } - - more, err := fn(handle, it.Key(), it.Value()) - if !more || err != nil { - return errors.Trace(err) - } - - err = kv.NextUntil(it, util.RowKeyPrefixFilter(it.Key())) - if err != nil { - if kv.ErrNotExist.Equal(err) { - break - } - return errors.Trace(err) - } - } - - return nil -} - -// GetRangeEndKey gets the actual end key for the range of [startKey, endKey). 
-func GetRangeEndKey(ctx *JobContext, store kv.Storage, priority int, keyPrefix kv.Key, startKey, endKey kv.Key) (kv.Key, error) { - snap := store.GetSnapshot(kv.MaxVersion) - snap.SetOption(kv.Priority, priority) - if tagger := ctx.getResourceGroupTaggerForTopSQL(); tagger != nil { - snap.SetOption(kv.ResourceGroupTagger, tagger) - } - snap.SetOption(kv.ResourceGroupName, ctx.resourceGroupName) - snap.SetOption(kv.RequestSourceInternal, true) - snap.SetOption(kv.RequestSourceType, ctx.ddlJobSourceType()) - snap.SetOption(kv.ExplicitRequestSourceType, kvutil.ExplicitTypeDDL) - it, err := snap.IterReverse(endKey, nil) - if err != nil { - return nil, errors.Trace(err) - } - defer it.Close() - - if !it.Valid() || !it.Key().HasPrefix(keyPrefix) { - return startKey.Next(), nil - } - if it.Key().Cmp(startKey) < 0 { - return startKey.Next(), nil - } - - return it.Key().Next(), nil -} - -func mergeWarningsAndWarningsCount(partWarnings, totalWarnings map[errors.ErrorID]*terror.Error, partWarningsCount, totalWarningsCount map[errors.ErrorID]int64) (map[errors.ErrorID]*terror.Error, map[errors.ErrorID]int64) { - for _, warn := range partWarnings { - if _, ok := totalWarningsCount[warn.ID()]; ok { - totalWarningsCount[warn.ID()] += partWarningsCount[warn.ID()] - } else { - totalWarningsCount[warn.ID()] = partWarningsCount[warn.ID()] - totalWarnings[warn.ID()] = warn - } - } - return totalWarnings, totalWarningsCount -} - -func logSlowOperations(elapsed time.Duration, slowMsg string, threshold uint32) { - if threshold == 0 { - threshold = atomic.LoadUint32(&variable.DDLSlowOprThreshold) - } - - if elapsed >= time.Duration(threshold)*time.Millisecond { - logutil.DDLLogger().Info("slow operations", - zap.Duration("takeTimes", elapsed), - zap.String("msg", slowMsg)) - } -} - -// doneTaskKeeper keeps the done tasks and update the latest next key. 
-type doneTaskKeeper struct { - doneTaskNextKey map[int]kv.Key - current int - nextKey kv.Key -} - -func newDoneTaskKeeper(start kv.Key) *doneTaskKeeper { - return &doneTaskKeeper{ - doneTaskNextKey: make(map[int]kv.Key), - current: 0, - nextKey: start, - } -} - -func (n *doneTaskKeeper) updateNextKey(doneTaskID int, next kv.Key) { - if doneTaskID == n.current { - n.current++ - n.nextKey = next - for { - nKey, ok := n.doneTaskNextKey[n.current] - if !ok { - break - } - delete(n.doneTaskNextKey, n.current) - n.current++ - n.nextKey = nKey - } - return - } - n.doneTaskNextKey[doneTaskID] = next -} diff --git a/pkg/ddl/backfilling_dist_scheduler.go b/pkg/ddl/backfilling_dist_scheduler.go index b74cbfa6f0cae..55abe4225da12 100644 --- a/pkg/ddl/backfilling_dist_scheduler.go +++ b/pkg/ddl/backfilling_dist_scheduler.go @@ -105,15 +105,15 @@ func (sch *BackfillingSchedulerExt) OnNextSubtasksBatch( return generateMergePlan(taskHandle, task, logger) case proto.BackfillStepWriteAndIngest: if sch.GlobalSort { - if _, _err_ := failpoint.Eval(_curpkg_("mockWriteIngest")); _err_ == nil { + failpoint.Inject("mockWriteIngest", func() { m := &BackfillSubTaskMeta{ MetaGroups: []*external.SortedKVMeta{}, } metaBytes, _ := json.Marshal(m) metaArr := make([][]byte, 0, 16) metaArr = append(metaArr, metaBytes) - return metaArr, nil - } + failpoint.Return(metaArr, nil) + }) return generateGlobalSortIngestPlan( ctx, sch.d.store.(kv.StorageWithPD), @@ -148,9 +148,9 @@ func (sch *BackfillingSchedulerExt) GetNextStep(task *proto.TaskBase) proto.Step } func skipMergeSort(stats []external.MultipleFilesStat) bool { - if _, _err_ := failpoint.Eval(_curpkg_("forceMergeSort")); _err_ == nil { - return false - } + failpoint.Inject("forceMergeSort", func() { + failpoint.Return(false) + }) return external.GetMaxOverlappingTotal(stats) <= external.MergeSortOverlapThreshold } @@ -330,9 +330,9 @@ func generateNonPartitionPlan( } func calculateRegionBatch(totalRegionCnt int, instanceCnt int, useLocalDisk bool) int { - if val, _err_ := failpoint.Eval(_curpkg_("mockRegionBatch")); _err_ == nil { - return val.(int) - } + failpoint.Inject("mockRegionBatch", func(val failpoint.Value) { + failpoint.Return(val.(int)) + }) var regionBatch int avgTasksPerInstance := (totalRegionCnt + instanceCnt - 1) / instanceCnt // ceiling if useLocalDisk { @@ -427,10 +427,10 @@ func splitSubtaskMetaForOneKVMetaGroup( return nil, err } ts := oracle.ComposeTS(p, l) - if val, _err_ := failpoint.Eval(_curpkg_("mockTSForGlobalSort")); _err_ == nil { + failpoint.Inject("mockTSForGlobalSort", func(val failpoint.Value) { i := val.(int) ts = uint64(i) - } + }) splitter, err := getRangeSplitter( ctx, store, cloudStorageURI, int64(kvMeta.TotalKVSize), instanceCnt, kvMeta.MultipleFilesStats, logger) if err != nil { diff --git a/pkg/ddl/backfilling_dist_scheduler.go__failpoint_stash__ b/pkg/ddl/backfilling_dist_scheduler.go__failpoint_stash__ deleted file mode 100644 index 55abe4225da12..0000000000000 --- a/pkg/ddl/backfilling_dist_scheduler.go__failpoint_stash__ +++ /dev/null @@ -1,639 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ddl - -import ( - "bytes" - "context" - "encoding/hex" - "encoding/json" - "math" - "sort" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/pkg/ddl/ingest" - "github.com/pingcap/tidb/pkg/ddl/logutil" - "github.com/pingcap/tidb/pkg/disttask/framework/handle" - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" - diststorage "github.com/pingcap/tidb/pkg/disttask/framework/storage" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/lightning/backend/external" - "github.com/pingcap/tidb/pkg/lightning/backend/local" - "github.com/pingcap/tidb/pkg/lightning/config" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/store/helper" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/util/backoff" - "github.com/tikv/client-go/v2/oracle" - "github.com/tikv/client-go/v2/tikv" - "go.uber.org/zap" -) - -// BackfillingSchedulerExt is an extension of litBackfillScheduler, exported for test. -type BackfillingSchedulerExt struct { - d *ddl - GlobalSort bool -} - -// NewBackfillingSchedulerExt creates a new backfillingSchedulerExt, only used for test now. -func NewBackfillingSchedulerExt(d DDL) (scheduler.Extension, error) { - ddl, ok := d.(*ddl) - if !ok { - return nil, errors.New("The getDDL result should be the type of *ddl") - } - return &BackfillingSchedulerExt{ - d: ddl, - }, nil -} - -var _ scheduler.Extension = (*BackfillingSchedulerExt)(nil) - -// OnTick implements scheduler.Extension interface. -func (*BackfillingSchedulerExt) OnTick(_ context.Context, _ *proto.Task) { -} - -// OnNextSubtasksBatch generate batch of next step's plan. -func (sch *BackfillingSchedulerExt) OnNextSubtasksBatch( - ctx context.Context, - taskHandle diststorage.TaskHandle, - task *proto.Task, - execIDs []string, - nextStep proto.Step, -) (subtaskMeta [][]byte, err error) { - logger := logutil.DDLLogger().With( - zap.Stringer("type", task.Type), - zap.Int64("task-id", task.ID), - zap.String("curr-step", proto.Step2Str(task.Type, task.Step)), - zap.String("next-step", proto.Step2Str(task.Type, nextStep)), - ) - var backfillMeta BackfillTaskMeta - if err := json.Unmarshal(task.Meta, &backfillMeta); err != nil { - return nil, err - } - job := &backfillMeta.Job - tblInfo, err := getTblInfo(ctx, sch.d, job) - if err != nil { - return nil, err - } - logger.Info("on next subtasks batch") - - // TODO: use planner. 
- switch nextStep { - case proto.BackfillStepReadIndex: - if tblInfo.Partition != nil { - return generatePartitionPlan(tblInfo) - } - return generateNonPartitionPlan(ctx, sch.d, tblInfo, job, sch.GlobalSort, len(execIDs)) - case proto.BackfillStepMergeSort: - return generateMergePlan(taskHandle, task, logger) - case proto.BackfillStepWriteAndIngest: - if sch.GlobalSort { - failpoint.Inject("mockWriteIngest", func() { - m := &BackfillSubTaskMeta{ - MetaGroups: []*external.SortedKVMeta{}, - } - metaBytes, _ := json.Marshal(m) - metaArr := make([][]byte, 0, 16) - metaArr = append(metaArr, metaBytes) - failpoint.Return(metaArr, nil) - }) - return generateGlobalSortIngestPlan( - ctx, - sch.d.store.(kv.StorageWithPD), - taskHandle, - task, - backfillMeta.CloudStorageURI, - logger) - } - return nil, nil - default: - return nil, nil - } -} - -// GetNextStep implements scheduler.Extension interface. -func (sch *BackfillingSchedulerExt) GetNextStep(task *proto.TaskBase) proto.Step { - switch task.Step { - case proto.StepInit: - return proto.BackfillStepReadIndex - case proto.BackfillStepReadIndex: - if sch.GlobalSort { - return proto.BackfillStepMergeSort - } - return proto.StepDone - case proto.BackfillStepMergeSort: - return proto.BackfillStepWriteAndIngest - case proto.BackfillStepWriteAndIngest: - return proto.StepDone - default: - return proto.StepDone - } -} - -func skipMergeSort(stats []external.MultipleFilesStat) bool { - failpoint.Inject("forceMergeSort", func() { - failpoint.Return(false) - }) - return external.GetMaxOverlappingTotal(stats) <= external.MergeSortOverlapThreshold -} - -// OnDone implements scheduler.Extension interface. -func (*BackfillingSchedulerExt) OnDone(_ context.Context, _ diststorage.TaskHandle, _ *proto.Task) error { - return nil -} - -// GetEligibleInstances implements scheduler.Extension interface. -func (*BackfillingSchedulerExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]string, error) { - return nil, nil -} - -// IsRetryableErr implements scheduler.Extension.IsRetryableErr interface. -func (*BackfillingSchedulerExt) IsRetryableErr(error) bool { - return true -} - -// LitBackfillScheduler wraps BaseScheduler. -type LitBackfillScheduler struct { - *scheduler.BaseScheduler - d *ddl -} - -func newLitBackfillScheduler(ctx context.Context, d *ddl, task *proto.Task, param scheduler.Param) scheduler.Scheduler { - sch := LitBackfillScheduler{ - d: d, - BaseScheduler: scheduler.NewBaseScheduler(ctx, task, param), - } - return &sch -} - -// Init implements BaseScheduler interface. -func (sch *LitBackfillScheduler) Init() (err error) { - taskMeta := &BackfillTaskMeta{} - if err = json.Unmarshal(sch.BaseScheduler.GetTask().Meta, taskMeta); err != nil { - return errors.Annotate(err, "unmarshal task meta failed") - } - sch.BaseScheduler.Extension = &BackfillingSchedulerExt{ - d: sch.d, - GlobalSort: len(taskMeta.CloudStorageURI) > 0} - return sch.BaseScheduler.Init() -} - -// Close implements BaseScheduler interface. 
-func (sch *LitBackfillScheduler) Close() { - sch.BaseScheduler.Close() -} - -func getTblInfo(ctx context.Context, d *ddl, job *model.Job) (tblInfo *model.TableInfo, err error) { - err = kv.RunInNewTxn(ctx, d.store, true, func(_ context.Context, txn kv.Transaction) error { - tblInfo, err = meta.NewMeta(txn).GetTable(job.SchemaID, job.TableID) - return err - }) - if err != nil { - return nil, err - } - - return tblInfo, nil -} - -func generatePartitionPlan(tblInfo *model.TableInfo) (metas [][]byte, err error) { - defs := tblInfo.Partition.Definitions - physicalIDs := make([]int64, len(defs)) - for i := range defs { - physicalIDs[i] = defs[i].ID - } - - subTaskMetas := make([][]byte, 0, len(physicalIDs)) - for _, physicalID := range physicalIDs { - subTaskMeta := &BackfillSubTaskMeta{ - PhysicalTableID: physicalID, - } - - metaBytes, err := json.Marshal(subTaskMeta) - if err != nil { - return nil, err - } - - subTaskMetas = append(subTaskMetas, metaBytes) - } - return subTaskMetas, nil -} - -const ( - scanRegionBackoffBase = 200 * time.Millisecond - scanRegionBackoffMax = 2 * time.Second -) - -func generateNonPartitionPlan( - ctx context.Context, - d *ddl, - tblInfo *model.TableInfo, - job *model.Job, - useCloud bool, - instanceCnt int, -) (metas [][]byte, err error) { - tbl, err := getTable(d.ddlCtx.getAutoIDRequirement(), job.SchemaID, tblInfo) - if err != nil { - return nil, err - } - ver, err := getValidCurrentVersion(d.store) - if err != nil { - return nil, errors.Trace(err) - } - - startKey, endKey, err := getTableRange(d.jobContext(job.ID, job.ReorgMeta), d.ddlCtx, tbl.(table.PhysicalTable), ver.Ver, job.Priority) - if startKey == nil && endKey == nil { - // Empty table. - return nil, nil - } - if err != nil { - return nil, errors.Trace(err) - } - - subTaskMetas := make([][]byte, 0, 4) - backoffer := backoff.NewExponential(scanRegionBackoffBase, 2, scanRegionBackoffMax) - err = handle.RunWithRetry(ctx, 8, backoffer, logutil.DDLLogger(), func(_ context.Context) (bool, error) { - regionCache := d.store.(helper.Storage).GetRegionCache() - recordRegionMetas, err := regionCache.LoadRegionsInKeyRange(tikv.NewBackofferWithVars(context.Background(), 20000, nil), startKey, endKey) - - if err != nil { - return false, err - } - sort.Slice(recordRegionMetas, func(i, j int) bool { - return bytes.Compare(recordRegionMetas[i].StartKey(), recordRegionMetas[j].StartKey()) < 0 - }) - - // Check if regions are continuous. 
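// The continuity check that follows verifies that, after sorting, every region
// starts exactly where the previous one ended; any gap triggers a retry of the
// region scan. A tiny standalone version of that predicate:
package main

import (
	"bytes"
	"fmt"
)

func contiguous(ranges [][2][]byte) bool {
	for i := 1; i < len(ranges); i++ {
		if !bytes.Equal(ranges[i-1][1], ranges[i][0]) {
			return false // a hole between two regions: the cached view is stale
		}
	}
	return true
}

func main() {
	fmt.Println(contiguous([][2][]byte{{[]byte("a"), []byte("c")}, {[]byte("c"), []byte("f")}})) // true
	fmt.Println(contiguous([][2][]byte{{[]byte("a"), []byte("c")}, {[]byte("d"), []byte("f")}})) // false
}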
- shouldRetry := false - cur := recordRegionMetas[0] - for _, m := range recordRegionMetas[1:] { - if !bytes.Equal(cur.EndKey(), m.StartKey()) { - shouldRetry = true - break - } - cur = m - } - - if shouldRetry { - return true, nil - } - - regionBatch := calculateRegionBatch(len(recordRegionMetas), instanceCnt, !useCloud) - - for i := 0; i < len(recordRegionMetas); i += regionBatch { - end := i + regionBatch - if end > len(recordRegionMetas) { - end = len(recordRegionMetas) - } - batch := recordRegionMetas[i:end] - subTaskMeta := &BackfillSubTaskMeta{ - RowStart: batch[0].StartKey(), - RowEnd: batch[len(batch)-1].EndKey(), - } - if i == 0 { - subTaskMeta.RowStart = startKey - } - if end == len(recordRegionMetas) { - subTaskMeta.RowEnd = endKey - } - metaBytes, err := json.Marshal(subTaskMeta) - if err != nil { - return false, err - } - subTaskMetas = append(subTaskMetas, metaBytes) - } - return false, nil - }) - if err != nil { - return nil, errors.Trace(err) - } - if len(subTaskMetas) == 0 { - return nil, errors.Errorf("regions are not continuous") - } - return subTaskMetas, nil -} - -func calculateRegionBatch(totalRegionCnt int, instanceCnt int, useLocalDisk bool) int { - failpoint.Inject("mockRegionBatch", func(val failpoint.Value) { - failpoint.Return(val.(int)) - }) - var regionBatch int - avgTasksPerInstance := (totalRegionCnt + instanceCnt - 1) / instanceCnt // ceiling - if useLocalDisk { - regionBatch = avgTasksPerInstance - } else { - // For cloud storage, each subtask should contain no more than 4000 regions. - regionBatch = min(4000, avgTasksPerInstance) - } - regionBatch = max(regionBatch, 1) - return regionBatch -} - -func generateGlobalSortIngestPlan( - ctx context.Context, - store kv.StorageWithPD, - taskHandle diststorage.TaskHandle, - task *proto.Task, - cloudStorageURI string, - logger *zap.Logger, -) ([][]byte, error) { - var ( - kvMetaGroups []*external.SortedKVMeta - eleIDs []int64 - ) - for _, step := range []proto.Step{proto.BackfillStepMergeSort, proto.BackfillStepReadIndex} { - hasSubtasks := false - err := forEachBackfillSubtaskMeta(taskHandle, task.ID, step, func(subtask *BackfillSubTaskMeta) { - hasSubtasks = true - if kvMetaGroups == nil { - kvMetaGroups = make([]*external.SortedKVMeta, len(subtask.MetaGroups)) - eleIDs = subtask.EleIDs - } - for i, cur := range subtask.MetaGroups { - if kvMetaGroups[i] == nil { - kvMetaGroups[i] = &external.SortedKVMeta{} - } - kvMetaGroups[i].Merge(cur) - } - }) - if err != nil { - return nil, err - } - if hasSubtasks { - break - } - // If there is no subtask for merge sort step, - // it means the merge sort step is skipped. - } - - instanceIDs, err := scheduler.GetLiveExecIDs(ctx) - if err != nil { - return nil, err - } - iCnt := int64(len(instanceIDs)) - metaArr := make([][]byte, 0, 16) - for i, g := range kvMetaGroups { - if g == nil { - logger.Error("meet empty kv group when getting subtask summary", - zap.Int64("taskID", task.ID)) - return nil, errors.Errorf("subtask kv group %d is empty", i) - } - eleID := int64(0) - // in case the subtask metadata is written by an old version of TiDB. - if i < len(eleIDs) { - eleID = eleIDs[i] - } - newMeta, err := splitSubtaskMetaForOneKVMetaGroup(ctx, store, g, eleID, cloudStorageURI, iCnt, logger) - if err != nil { - return nil, errors.Trace(err) - } - metaArr = append(metaArr, newMeta...) 
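// calculateRegionBatch above boils down to three steps: ceiling-divide the
// region count by the number of instances, cap a cloud-storage subtask at
// 4000 regions, and never return less than one. A standalone restatement with
// sample inputs (uses the Go 1.21+ min/max builtins, as the original does):
package main

import "fmt"

func regionBatch(totalRegions, instances int, useLocalDisk bool) int {
	avg := (totalRegions + instances - 1) / instances // ceiling division
	batch := avg
	if !useLocalDisk {
		batch = min(4000, avg) // keep cloud-storage subtasks bounded
	}
	return max(batch, 1)
}

func main() {
	fmt.Println(regionBatch(10, 3, true))     // 4  = ceil(10/3)
	fmt.Println(regionBatch(20000, 2, false)) // 4000, capped for cloud storage
	fmt.Println(regionBatch(0, 4, true))      // 1, floor value
}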
- } - return metaArr, nil -} - -func splitSubtaskMetaForOneKVMetaGroup( - ctx context.Context, - store kv.StorageWithPD, - kvMeta *external.SortedKVMeta, - eleID int64, - cloudStorageURI string, - instanceCnt int64, - logger *zap.Logger, -) (metaArr [][]byte, err error) { - if len(kvMeta.StartKey) == 0 && len(kvMeta.EndKey) == 0 { - // Skip global sort for empty table. - return nil, nil - } - pdCli := store.GetPDClient() - p, l, err := pdCli.GetTS(ctx) - if err != nil { - return nil, err - } - ts := oracle.ComposeTS(p, l) - failpoint.Inject("mockTSForGlobalSort", func(val failpoint.Value) { - i := val.(int) - ts = uint64(i) - }) - splitter, err := getRangeSplitter( - ctx, store, cloudStorageURI, int64(kvMeta.TotalKVSize), instanceCnt, kvMeta.MultipleFilesStats, logger) - if err != nil { - return nil, err - } - defer func() { - err := splitter.Close() - if err != nil { - logger.Error("failed to close range splitter", zap.Error(err)) - } - }() - - startKey := kvMeta.StartKey - var endKey kv.Key - for { - endKeyOfGroup, dataFiles, statFiles, rangeSplitKeys, err := splitter.SplitOneRangesGroup() - if err != nil { - return nil, err - } - if len(endKeyOfGroup) == 0 { - endKey = kvMeta.EndKey - } else { - endKey = kv.Key(endKeyOfGroup).Clone() - } - logger.Info("split subtask range", - zap.String("startKey", hex.EncodeToString(startKey)), - zap.String("endKey", hex.EncodeToString(endKey))) - - if bytes.Compare(startKey, endKey) >= 0 { - return nil, errors.Errorf("invalid range, startKey: %s, endKey: %s", - hex.EncodeToString(startKey), hex.EncodeToString(endKey)) - } - m := &BackfillSubTaskMeta{ - MetaGroups: []*external.SortedKVMeta{{ - StartKey: startKey, - EndKey: endKey, - TotalKVSize: kvMeta.TotalKVSize / uint64(instanceCnt), - }}, - DataFiles: dataFiles, - StatFiles: statFiles, - RangeSplitKeys: rangeSplitKeys, - TS: ts, - } - if eleID > 0 { - m.EleIDs = []int64{eleID} - } - metaBytes, err := json.Marshal(m) - if err != nil { - return nil, err - } - metaArr = append(metaArr, metaBytes) - if len(endKeyOfGroup) == 0 { - break - } - startKey = endKey - } - return metaArr, nil -} - -func generateMergePlan( - taskHandle diststorage.TaskHandle, - task *proto.Task, - logger *zap.Logger, -) ([][]byte, error) { - // check data files overlaps, - // if data files overlaps too much, we need a merge step. - var ( - multiStatsGroup [][]external.MultipleFilesStat - kvMetaGroups []*external.SortedKVMeta - eleIDs []int64 - ) - err := forEachBackfillSubtaskMeta(taskHandle, task.ID, proto.BackfillStepReadIndex, - func(subtask *BackfillSubTaskMeta) { - if kvMetaGroups == nil { - kvMetaGroups = make([]*external.SortedKVMeta, len(subtask.MetaGroups)) - multiStatsGroup = make([][]external.MultipleFilesStat, len(subtask.MetaGroups)) - eleIDs = subtask.EleIDs - } - for i, g := range subtask.MetaGroups { - if kvMetaGroups[i] == nil { - kvMetaGroups[i] = &external.SortedKVMeta{} - multiStatsGroup[i] = make([]external.MultipleFilesStat, 0, 100) - } - kvMetaGroups[i].Merge(g) - multiStatsGroup[i] = append(multiStatsGroup[i], g.MultipleFilesStats...) 
- } - }) - if err != nil { - return nil, err - } - - allSkip := true - for _, multiStats := range multiStatsGroup { - if !skipMergeSort(multiStats) { - allSkip = false - break - } - } - if allSkip { - logger.Info("skip merge sort") - return nil, nil - } - - metaArr := make([][]byte, 0, 16) - for i, g := range kvMetaGroups { - dataFiles := make([]string, 0, 1000) - if g == nil { - logger.Error("meet empty kv group when getting subtask summary", - zap.Int64("taskID", task.ID)) - return nil, errors.Errorf("subtask kv group %d is empty", i) - } - for _, m := range g.MultipleFilesStats { - for _, filePair := range m.Filenames { - dataFiles = append(dataFiles, filePair[0]) - } - } - var eleID []int64 - if i < len(eleIDs) { - eleID = []int64{eleIDs[i]} - } - start := 0 - step := external.MergeSortFileCountStep - for start < len(dataFiles) { - end := start + step - if end > len(dataFiles) { - end = len(dataFiles) - } - m := &BackfillSubTaskMeta{ - DataFiles: dataFiles[start:end], - EleIDs: eleID, - } - metaBytes, err := json.Marshal(m) - if err != nil { - return nil, err - } - metaArr = append(metaArr, metaBytes) - - start = end - } - } - return metaArr, nil -} - -func getRangeSplitter( - ctx context.Context, - store kv.StorageWithPD, - cloudStorageURI string, - totalSize int64, - instanceCnt int64, - multiFileStat []external.MultipleFilesStat, - logger *zap.Logger, -) (*external.RangeSplitter, error) { - backend, err := storage.ParseBackend(cloudStorageURI, nil) - if err != nil { - return nil, err - } - extStore, err := storage.NewWithDefaultOpt(ctx, backend) - if err != nil { - return nil, err - } - - rangeGroupSize := totalSize / instanceCnt - rangeGroupKeys := int64(math.MaxInt64) - - var maxSizePerRange = int64(config.SplitRegionSize) - var maxKeysPerRange = int64(config.SplitRegionKeys) - if store != nil { - pdCli := store.GetPDClient() - tls, err := ingest.NewDDLTLS() - if err == nil { - size, keys, err := local.GetRegionSplitSizeKeys(ctx, pdCli, tls) - if err == nil { - maxSizePerRange = max(maxSizePerRange, size) - maxKeysPerRange = max(maxKeysPerRange, keys) - } else { - logger.Warn("fail to get region split keys and size", zap.Error(err)) - } - } else { - logger.Warn("fail to get region split keys and size", zap.Error(err)) - } - } - - return external.NewRangeSplitter(ctx, multiFileStat, extStore, - rangeGroupSize, rangeGroupKeys, maxSizePerRange, maxKeysPerRange) -} - -func forEachBackfillSubtaskMeta( - taskHandle diststorage.TaskHandle, - gTaskID int64, - step proto.Step, - fn func(subtask *BackfillSubTaskMeta), -) error { - subTaskMetas, err := taskHandle.GetPreviousSubtaskMetas(gTaskID, step) - if err != nil { - return errors.Trace(err) - } - for _, subTaskMeta := range subTaskMetas { - subtask, err := decodeBackfillSubTaskMeta(subTaskMeta) - if err != nil { - logutil.DDLLogger().Error("unmarshal error", zap.Error(err)) - return errors.Trace(err) - } - fn(subtask) - } - return nil -} diff --git a/pkg/ddl/backfilling_operators.go b/pkg/ddl/backfilling_operators.go index a2d87df69a74f..5ae5554ef290b 100644 --- a/pkg/ddl/backfilling_operators.go +++ b/pkg/ddl/backfilling_operators.go @@ -243,9 +243,9 @@ func NewWriteIndexToExternalStoragePipeline( } memCap := resource.Mem.Capacity() memSizePerIndex := uint64(memCap / int64(writerCnt*2*len(idxInfos))) - if _, _err_ := failpoint.Eval(_curpkg_("mockWriterMemSize")); _err_ == nil { + failpoint.Inject("mockWriterMemSize", func() { memSizePerIndex = 1 * size.GB - } + }) srcOp := NewTableScanTaskSource(ctx, store, tbl, startKey, endKey, nil) 
scanOp := NewTableScanOperator(ctx, sessPool, copCtx, srcChkPool, readerCnt, nil) @@ -492,9 +492,9 @@ func (w *tableScanWorker) HandleTask(task TableScanTask, sender func(IndexRecord w.ctx.onError(dbterror.ErrReorgPanic) }, false) - if _, _err_ := failpoint.Eval(_curpkg_("injectPanicForTableScan")); _err_ == nil { + failpoint.Inject("injectPanicForTableScan", func() { panic("mock panic") - } + }) if w.se == nil { sessCtx, err := w.sessPool.Get() if err != nil { @@ -519,10 +519,10 @@ func (w *tableScanWorker) scanRecords(task TableScanTask, sender func(IndexRecor var idxResult IndexRecordChunk err := wrapInBeginRollback(w.se, func(startTS uint64) error { - if _, _err_ := failpoint.Eval(_curpkg_("mockScanRecordError")); _err_ == nil { - return errors.New("mock scan record error") - } - failpoint.Call(_curpkg_("scanRecordExec")) + failpoint.Inject("mockScanRecordError", func() { + failpoint.Return(errors.New("mock scan record error")) + }) + failpoint.InjectCall("scanRecordExec") rs, err := buildTableScan(w.ctx, w.copCtx.GetBase(), startTS, task.Start, task.End) if err != nil { return err @@ -789,9 +789,9 @@ type indexIngestBaseWorker struct { } func (w *indexIngestBaseWorker) HandleTask(rs IndexRecordChunk) (IndexWriteResult, error) { - if _, _err_ := failpoint.Eval(_curpkg_("injectPanicForIndexIngest")); _err_ == nil { + failpoint.Inject("injectPanicForIndexIngest", func() { panic("mock panic") - } + }) result := IndexWriteResult{ ID: rs.ID, @@ -851,10 +851,10 @@ func (w *indexIngestBaseWorker) Close() { // WriteChunk will write index records to lightning engine. func (w *indexIngestBaseWorker) WriteChunk(rs *IndexRecordChunk) (count int, nextKey kv.Key, err error) { - if _, _err_ := failpoint.Eval(_curpkg_("mockWriteLocalError")); _err_ == nil { - return 0, nil, errors.New("mock write local error") - } - failpoint.Call(_curpkg_("writeLocalExec"), rs.Done) + failpoint.Inject("mockWriteLocalError", func(_ failpoint.Value) { + failpoint.Return(0, nil, errors.New("mock write local error")) + }) + failpoint.InjectCall("writeLocalExec", rs.Done) oprStartTime := time.Now() vars := w.se.GetSessionVars() @@ -934,9 +934,9 @@ func (s *indexWriteResultSink) flush() error { if s.backendCtx == nil { return nil } - if _, _err_ := failpoint.Eval(_curpkg_("mockFlushError")); _err_ == nil { - return errors.New("mock flush error") - } + failpoint.Inject("mockFlushError", func(_ failpoint.Value) { + failpoint.Return(errors.New("mock flush error")) + }) flushed, imported, err := s.backendCtx.Flush(ingest.FlushModeForceFlushAndImport) if s.cpMgr != nil { // Try to advance watermark even if there is an error. diff --git a/pkg/ddl/backfilling_operators.go__failpoint_stash__ b/pkg/ddl/backfilling_operators.go__failpoint_stash__ deleted file mode 100644 index 5ae5554ef290b..0000000000000 --- a/pkg/ddl/backfilling_operators.go__failpoint_stash__ +++ /dev/null @@ -1,962 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package ddl - -import ( - "context" - "encoding/hex" - "fmt" - "path" - "strconv" - "sync/atomic" - "time" - - "github.com/docker/go-units" - "github.com/google/uuid" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/pkg/ddl/copr" - "github.com/pingcap/tidb/pkg/ddl/ingest" - "github.com/pingcap/tidb/pkg/ddl/session" - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - "github.com/pingcap/tidb/pkg/disttask/operator" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/lightning/backend/external" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/resourcemanager/pool/workerpool" - "github.com/pingcap/tidb/pkg/resourcemanager/util" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/table/tables" - "github.com/pingcap/tidb/pkg/tablecodec" - tidbutil "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/dbterror" - "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/size" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" -) - -var ( - _ operator.Operator = (*TableScanTaskSource)(nil) - _ operator.WithSink[TableScanTask] = (*TableScanTaskSource)(nil) - - _ operator.WithSource[TableScanTask] = (*TableScanOperator)(nil) - _ operator.Operator = (*TableScanOperator)(nil) - _ operator.WithSink[IndexRecordChunk] = (*TableScanOperator)(nil) - - _ operator.WithSource[IndexRecordChunk] = (*IndexIngestOperator)(nil) - _ operator.Operator = (*IndexIngestOperator)(nil) - _ operator.WithSink[IndexWriteResult] = (*IndexIngestOperator)(nil) - - _ operator.WithSource[IndexWriteResult] = (*indexWriteResultSink)(nil) - _ operator.Operator = (*indexWriteResultSink)(nil) -) - -type opSessPool interface { - Get() (sessionctx.Context, error) - Put(sessionctx.Context) -} - -// OperatorCtx is the context for AddIndexIngestPipeline. -// This is used to cancel the pipeline and collect errors. -type OperatorCtx struct { - context.Context - cancel context.CancelFunc - err atomic.Pointer[error] -} - -// NewDistTaskOperatorCtx is used for adding index with dist framework. -func NewDistTaskOperatorCtx(ctx context.Context, taskID, subtaskID int64) *OperatorCtx { - opCtx, cancel := context.WithCancel(ctx) - opCtx = logutil.WithFields(opCtx, zap.Int64("task-id", taskID), zap.Int64("subtask-id", subtaskID)) - return &OperatorCtx{ - Context: opCtx, - cancel: cancel, - } -} - -// NewLocalOperatorCtx is used for adding index with local ingest mode. -func NewLocalOperatorCtx(ctx context.Context, jobID int64) *OperatorCtx { - opCtx, cancel := context.WithCancel(ctx) - opCtx = logutil.WithFields(opCtx, zap.Int64("jobID", jobID)) - return &OperatorCtx{ - Context: opCtx, - cancel: cancel, - } -} - -func (ctx *OperatorCtx) onError(err error) { - tracedErr := errors.Trace(err) - ctx.cancel() - ctx.err.CompareAndSwap(nil, &tracedErr) -} - -// Cancel cancels the pipeline. -func (ctx *OperatorCtx) Cancel() { - ctx.cancel() -} - -// OperatorErr returns the error of the operator. 
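// OperatorCtx above records only the first error reported by any worker and
// cancels the shared context so the rest of the pipeline unwinds. A minimal
// sketch of that pattern; firstErrCtx and its methods are illustrative names,
// not TiDB APIs.
package main

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
)

type firstErrCtx struct {
	context.Context
	cancel context.CancelFunc
	err    atomic.Pointer[error]
}

func (c *firstErrCtx) onError(err error) {
	c.cancel()                      // stop every other worker
	c.err.CompareAndSwap(nil, &err) // later errors lose the race and are dropped
}

func (c *firstErrCtx) firstErr() error {
	if p := c.err.Load(); p != nil {
		return *p
	}
	return nil
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	fc := &firstErrCtx{Context: ctx, cancel: cancel}

	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			fc.onError(fmt.Errorf("worker %d failed", i))
		}(i)
	}
	wg.Wait()
	fmt.Println(fc.firstErr()) // exactly one of the four errors survives
}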
-func (ctx *OperatorCtx) OperatorErr() error { - err := ctx.err.Load() - if err == nil { - return nil - } - return *err -} - -var ( - _ RowCountListener = (*EmptyRowCntListener)(nil) - _ RowCountListener = (*distTaskRowCntListener)(nil) - _ RowCountListener = (*localRowCntListener)(nil) -) - -// RowCountListener is invoked when some index records are flushed to disk or imported to TiKV. -type RowCountListener interface { - Written(rowCnt int) - SetTotal(total int) -} - -// EmptyRowCntListener implements a noop RowCountListener. -type EmptyRowCntListener struct{} - -// Written implements RowCountListener. -func (*EmptyRowCntListener) Written(_ int) {} - -// SetTotal implements RowCountListener. -func (*EmptyRowCntListener) SetTotal(_ int) {} - -// NewAddIndexIngestPipeline creates a pipeline for adding index in ingest mode. -func NewAddIndexIngestPipeline( - ctx *OperatorCtx, - store kv.Storage, - sessPool opSessPool, - backendCtx ingest.BackendCtx, - engines []ingest.Engine, - jobID int64, - tbl table.PhysicalTable, - idxInfos []*model.IndexInfo, - startKey, endKey kv.Key, - reorgMeta *model.DDLReorgMeta, - avgRowSize int, - concurrency int, - cpMgr *ingest.CheckpointManager, - rowCntListener RowCountListener, -) (*operator.AsyncPipeline, error) { - indexes := make([]table.Index, 0, len(idxInfos)) - for _, idxInfo := range idxInfos { - index := tables.NewIndex(tbl.GetPhysicalID(), tbl.Meta(), idxInfo) - indexes = append(indexes, index) - } - reqSrc := getDDLRequestSource(model.ActionAddIndex) - copCtx, err := NewReorgCopContext(store, reorgMeta, tbl.Meta(), idxInfos, reqSrc) - if err != nil { - return nil, err - } - poolSize := copReadChunkPoolSize() - srcChkPool := make(chan *chunk.Chunk, poolSize) - for i := 0; i < poolSize; i++ { - srcChkPool <- chunk.NewChunkWithCapacity(copCtx.GetBase().FieldTypes, copReadBatchSize()) - } - readerCnt, writerCnt := expectedIngestWorkerCnt(concurrency, avgRowSize) - - srcOp := NewTableScanTaskSource(ctx, store, tbl, startKey, endKey, cpMgr) - scanOp := NewTableScanOperator(ctx, sessPool, copCtx, srcChkPool, readerCnt, cpMgr) - ingestOp := NewIndexIngestOperator(ctx, copCtx, backendCtx, sessPool, - tbl, indexes, engines, srcChkPool, writerCnt, reorgMeta, cpMgr, rowCntListener) - sinkOp := newIndexWriteResultSink(ctx, backendCtx, tbl, indexes, cpMgr, rowCntListener) - - operator.Compose[TableScanTask](srcOp, scanOp) - operator.Compose[IndexRecordChunk](scanOp, ingestOp) - operator.Compose[IndexWriteResult](ingestOp, sinkOp) - - logutil.Logger(ctx).Info("build add index local storage operators", - zap.Int64("jobID", jobID), - zap.Int("avgRowSize", avgRowSize), - zap.Int("reader", readerCnt), - zap.Int("writer", writerCnt)) - - return operator.NewAsyncPipeline( - srcOp, scanOp, ingestOp, sinkOp, - ), nil -} - -// NewWriteIndexToExternalStoragePipeline creates a pipeline for writing index to external storage. 
-func NewWriteIndexToExternalStoragePipeline( - ctx *OperatorCtx, - store kv.Storage, - extStoreURI string, - sessPool opSessPool, - jobID, subtaskID int64, - tbl table.PhysicalTable, - idxInfos []*model.IndexInfo, - startKey, endKey kv.Key, - onClose external.OnCloseFunc, - reorgMeta *model.DDLReorgMeta, - avgRowSize int, - concurrency int, - resource *proto.StepResource, - rowCntListener RowCountListener, -) (*operator.AsyncPipeline, error) { - indexes := make([]table.Index, 0, len(idxInfos)) - for _, idxInfo := range idxInfos { - index := tables.NewIndex(tbl.GetPhysicalID(), tbl.Meta(), idxInfo) - indexes = append(indexes, index) - } - reqSrc := getDDLRequestSource(model.ActionAddIndex) - copCtx, err := NewReorgCopContext(store, reorgMeta, tbl.Meta(), idxInfos, reqSrc) - if err != nil { - return nil, err - } - poolSize := copReadChunkPoolSize() - srcChkPool := make(chan *chunk.Chunk, poolSize) - for i := 0; i < poolSize; i++ { - srcChkPool <- chunk.NewChunkWithCapacity(copCtx.GetBase().FieldTypes, copReadBatchSize()) - } - readerCnt, writerCnt := expectedIngestWorkerCnt(concurrency, avgRowSize) - - backend, err := storage.ParseBackend(extStoreURI, nil) - if err != nil { - return nil, err - } - extStore, err := storage.NewWithDefaultOpt(ctx, backend) - if err != nil { - return nil, err - } - memCap := resource.Mem.Capacity() - memSizePerIndex := uint64(memCap / int64(writerCnt*2*len(idxInfos))) - failpoint.Inject("mockWriterMemSize", func() { - memSizePerIndex = 1 * size.GB - }) - - srcOp := NewTableScanTaskSource(ctx, store, tbl, startKey, endKey, nil) - scanOp := NewTableScanOperator(ctx, sessPool, copCtx, srcChkPool, readerCnt, nil) - writeOp := NewWriteExternalStoreOperator( - ctx, copCtx, sessPool, jobID, subtaskID, tbl, indexes, extStore, srcChkPool, writerCnt, onClose, memSizePerIndex, reorgMeta) - sinkOp := newIndexWriteResultSink(ctx, nil, tbl, indexes, nil, rowCntListener) - - operator.Compose[TableScanTask](srcOp, scanOp) - operator.Compose[IndexRecordChunk](scanOp, writeOp) - operator.Compose[IndexWriteResult](writeOp, sinkOp) - - logutil.Logger(ctx).Info("build add index cloud storage operators", - zap.Int64("jobID", jobID), - zap.String("memCap", units.BytesSize(float64(memCap))), - zap.String("memSizePerIdx", units.BytesSize(float64(memSizePerIndex))), - zap.Int("avgRowSize", avgRowSize), - zap.Int("reader", readerCnt), - zap.Int("writer", writerCnt)) - - return operator.NewAsyncPipeline( - srcOp, scanOp, writeOp, sinkOp, - ), nil -} - -// TableScanTask contains the start key and the end key of a region. -type TableScanTask struct { - ID int - Start kv.Key - End kv.Key -} - -// String implement fmt.Stringer interface. -func (t TableScanTask) String() string { - return fmt.Sprintf("TableScanTask: id=%d, startKey=%s, endKey=%s", - t.ID, hex.EncodeToString(t.Start), hex.EncodeToString(t.End)) -} - -// IndexRecordChunk contains one of the chunk read from corresponding TableScanTask. -type IndexRecordChunk struct { - ID int - Chunk *chunk.Chunk - Err error - Done bool -} - -// TableScanTaskSource produces TableScanTask by splitting table records into ranges. -type TableScanTaskSource struct { - ctx context.Context - - errGroup errgroup.Group - sink operator.DataChannel[TableScanTask] - - tbl table.PhysicalTable - store kv.Storage - startKey kv.Key - endKey kv.Key - - // only used in local ingest - cpMgr *ingest.CheckpointManager -} - -// NewTableScanTaskSource creates a new TableScanTaskSource. 
-func NewTableScanTaskSource( - ctx context.Context, - store kv.Storage, - physicalTable table.PhysicalTable, - startKey kv.Key, - endKey kv.Key, - cpMgr *ingest.CheckpointManager, -) *TableScanTaskSource { - return &TableScanTaskSource{ - ctx: ctx, - errGroup: errgroup.Group{}, - tbl: physicalTable, - store: store, - startKey: startKey, - endKey: endKey, - cpMgr: cpMgr, - } -} - -// SetSink implements WithSink interface. -func (src *TableScanTaskSource) SetSink(sink operator.DataChannel[TableScanTask]) { - src.sink = sink -} - -// Open implements Operator interface. -func (src *TableScanTaskSource) Open() error { - src.errGroup.Go(src.generateTasks) - return nil -} - -// adjustStartKey adjusts the start key so that we can skip the ranges that have been processed -// according to the information of checkpoint manager. -func (src *TableScanTaskSource) adjustStartKey(start, end kv.Key) (adjusted kv.Key, done bool) { - if src.cpMgr == nil { - return start, false - } - cpKey := src.cpMgr.LastProcessedKey() - if len(cpKey) == 0 { - return start, false - } - if cpKey.Cmp(start) < 0 || cpKey.Cmp(end) > 0 { - logutil.Logger(src.ctx).Error("invalid checkpoint key", - zap.String("last_process_key", hex.EncodeToString(cpKey)), - zap.String("start", hex.EncodeToString(start)), - zap.String("end", hex.EncodeToString(end)), - ) - if intest.InTest { - panic("invalid checkpoint key") - } - return start, false - } - if cpKey.Cmp(end) == 0 { - return cpKey, true - } - return cpKey.Next(), false -} - -func (src *TableScanTaskSource) generateTasks() error { - taskIDAlloc := newTaskIDAllocator() - defer src.sink.Finish() - - startKey, done := src.adjustStartKey(src.startKey, src.endKey) - if done { - // All table data are done. - return nil - } - for { - kvRanges, err := loadTableRanges( - src.ctx, - src.tbl, - src.store, - startKey, - src.endKey, - backfillTaskChanSize, - ) - if err != nil { - return err - } - if len(kvRanges) == 0 { - break - } - - batchTasks := src.getBatchTableScanTask(kvRanges, taskIDAlloc) - for _, task := range batchTasks { - select { - case <-src.ctx.Done(): - return src.ctx.Err() - case src.sink.Channel() <- task: - } - } - startKey = kvRanges[len(kvRanges)-1].EndKey - if startKey.Cmp(src.endKey) >= 0 { - break - } - } - return nil -} - -func (src *TableScanTaskSource) getBatchTableScanTask( - kvRanges []kv.KeyRange, - taskIDAlloc *taskIDAllocator, -) []TableScanTask { - batchTasks := make([]TableScanTask, 0, len(kvRanges)) - prefix := src.tbl.RecordPrefix() - // Build reorg tasks. - for _, keyRange := range kvRanges { - taskID := taskIDAlloc.alloc() - startKey := keyRange.StartKey - if len(startKey) == 0 { - startKey = prefix - } - endKey := keyRange.EndKey - if len(endKey) == 0 { - endKey = prefix.PrefixNext() - } - - task := TableScanTask{ - ID: taskID, - Start: startKey, - End: endKey, - } - batchTasks = append(batchTasks, task) - } - return batchTasks -} - -// Close implements Operator interface. -func (src *TableScanTaskSource) Close() error { - return src.errGroup.Wait() -} - -// String implements fmt.Stringer interface. -func (*TableScanTaskSource) String() string { - return "TableScanTaskSource" -} - -// TableScanOperator scans table records in given key ranges from kv store. -type TableScanOperator struct { - *operator.AsyncOperator[TableScanTask, IndexRecordChunk] -} - -// NewTableScanOperator creates a new TableScanOperator. 
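// adjustStartKey above resumes a scan from checkpoint state: with no (or an
// out-of-range) checkpoint it keeps the original start, a checkpoint equal to
// the end key means the range is already done, and otherwise the scan
// continues just past the last processed key. A simplified sketch; the
// appended 0x00 byte stands in for kv.Key.Next and resumeKey is not a TiDB API.
package main

import (
	"bytes"
	"fmt"
)

func resumeKey(start, end, checkpoint []byte) (adjusted []byte, done bool) {
	if len(checkpoint) == 0 || bytes.Compare(checkpoint, start) < 0 || bytes.Compare(checkpoint, end) > 0 {
		return start, false
	}
	if bytes.Equal(checkpoint, end) {
		return checkpoint, true
	}
	return append(append([]byte{}, checkpoint...), 0x00), false
}

func main() {
	start, end := []byte("a"), []byte("z")
	k, done := resumeKey(start, end, []byte("m"))
	fmt.Printf("%q %v\n", k, done) // "m\x00" false
	k, done = resumeKey(start, end, nil)
	fmt.Printf("%q %v\n", k, done) // "a" false
	_, done = resumeKey(start, end, end)
	fmt.Println(done) // true
}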
-func NewTableScanOperator( - ctx *OperatorCtx, - sessPool opSessPool, - copCtx copr.CopContext, - srcChkPool chan *chunk.Chunk, - concurrency int, - cpMgr *ingest.CheckpointManager, -) *TableScanOperator { - pool := workerpool.NewWorkerPool( - "TableScanOperator", - util.DDL, - concurrency, - func() workerpool.Worker[TableScanTask, IndexRecordChunk] { - return &tableScanWorker{ - ctx: ctx, - copCtx: copCtx, - sessPool: sessPool, - se: nil, - srcChkPool: srcChkPool, - cpMgr: cpMgr, - } - }) - return &TableScanOperator{ - AsyncOperator: operator.NewAsyncOperator[TableScanTask, IndexRecordChunk](ctx, pool), - } -} - -type tableScanWorker struct { - ctx *OperatorCtx - copCtx copr.CopContext - sessPool opSessPool - se *session.Session - srcChkPool chan *chunk.Chunk - - cpMgr *ingest.CheckpointManager -} - -func (w *tableScanWorker) HandleTask(task TableScanTask, sender func(IndexRecordChunk)) { - defer tidbutil.Recover(metrics.LblAddIndex, "handleTableScanTaskWithRecover", func() { - w.ctx.onError(dbterror.ErrReorgPanic) - }, false) - - failpoint.Inject("injectPanicForTableScan", func() { - panic("mock panic") - }) - if w.se == nil { - sessCtx, err := w.sessPool.Get() - if err != nil { - logutil.Logger(w.ctx).Error("tableScanWorker get session from pool failed", zap.Error(err)) - w.ctx.onError(err) - return - } - w.se = session.NewSession(sessCtx) - } - w.scanRecords(task, sender) -} - -func (w *tableScanWorker) Close() { - if w.se != nil { - w.sessPool.Put(w.se.Context) - } -} - -func (w *tableScanWorker) scanRecords(task TableScanTask, sender func(IndexRecordChunk)) { - logutil.Logger(w.ctx).Info("start a table scan task", - zap.Int("id", task.ID), zap.Stringer("task", task)) - - var idxResult IndexRecordChunk - err := wrapInBeginRollback(w.se, func(startTS uint64) error { - failpoint.Inject("mockScanRecordError", func() { - failpoint.Return(errors.New("mock scan record error")) - }) - failpoint.InjectCall("scanRecordExec") - rs, err := buildTableScan(w.ctx, w.copCtx.GetBase(), startTS, task.Start, task.End) - if err != nil { - return err - } - if w.cpMgr != nil { - w.cpMgr.Register(task.ID, task.End) - } - var done bool - for !done { - srcChk := w.getChunk() - done, err = fetchTableScanResult(w.ctx, w.copCtx.GetBase(), rs, srcChk) - if err != nil || w.ctx.Err() != nil { - w.recycleChunk(srcChk) - terror.Call(rs.Close) - return err - } - idxResult = IndexRecordChunk{ID: task.ID, Chunk: srcChk, Done: done} - if w.cpMgr != nil { - w.cpMgr.UpdateTotalKeys(task.ID, srcChk.NumRows(), done) - } - sender(idxResult) - } - return rs.Close() - }) - if err != nil { - w.ctx.onError(err) - } -} - -func (w *tableScanWorker) getChunk() *chunk.Chunk { - chk := <-w.srcChkPool - newCap := copReadBatchSize() - if chk.Capacity() != newCap { - chk = chunk.NewChunkWithCapacity(w.copCtx.GetBase().FieldTypes, newCap) - } - chk.Reset() - return chk -} - -func (w *tableScanWorker) recycleChunk(chk *chunk.Chunk) { - w.srcChkPool <- chk -} - -// WriteExternalStoreOperator writes index records to external storage. -type WriteExternalStoreOperator struct { - *operator.AsyncOperator[IndexRecordChunk, IndexWriteResult] -} - -// NewWriteExternalStoreOperator creates a new WriteExternalStoreOperator. 
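// The srcChkPool used by the scan and ingest workers above is a buffered
// channel acting as a fixed-size free list: receiving borrows a chunk (and
// blocks when all chunks are in use), sending returns it for reuse. A tiny
// sketch with a stand-in chunk type:
package main

import "fmt"

type chunk struct{ rows []int }

func main() {
	pool := make(chan *chunk, 2)
	for i := 0; i < cap(pool); i++ {
		pool <- &chunk{rows: make([]int, 0, 1024)}
	}

	chk := <-pool           // borrow a buffer from the pool
	chk.rows = chk.rows[:0] // reset before reuse, like chk.Reset() above
	chk.rows = append(chk.rows, 1, 2, 3)
	fmt.Println(len(chk.rows)) // 3
	pool <- chk // hand it back so another worker can pick it up
}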
-func NewWriteExternalStoreOperator( - ctx *OperatorCtx, - copCtx copr.CopContext, - sessPool opSessPool, - jobID int64, - subtaskID int64, - tbl table.PhysicalTable, - indexes []table.Index, - store storage.ExternalStorage, - srcChunkPool chan *chunk.Chunk, - concurrency int, - onClose external.OnCloseFunc, - memoryQuota uint64, - reorgMeta *model.DDLReorgMeta, -) *WriteExternalStoreOperator { - // due to multi-schema-change, we may merge processing multiple indexes into one - // local backend. - hasUnique := false - for _, index := range indexes { - if index.Meta().Unique { - hasUnique = true - break - } - } - - pool := workerpool.NewWorkerPool( - "WriteExternalStoreOperator", - util.DDL, - concurrency, - func() workerpool.Worker[IndexRecordChunk, IndexWriteResult] { - writers := make([]ingest.Writer, 0, len(indexes)) - for i := range indexes { - builder := external.NewWriterBuilder(). - SetOnCloseFunc(onClose). - SetKeyDuplicationEncoding(hasUnique). - SetMemorySizeLimit(memoryQuota). - SetGroupOffset(i) - writerID := uuid.New().String() - prefix := path.Join(strconv.Itoa(int(jobID)), strconv.Itoa(int(subtaskID))) - writer := builder.Build(store, prefix, writerID) - writers = append(writers, writer) - } - - return &indexIngestExternalWorker{ - indexIngestBaseWorker: indexIngestBaseWorker{ - ctx: ctx, - tbl: tbl, - indexes: indexes, - copCtx: copCtx, - se: nil, - sessPool: sessPool, - writers: writers, - srcChunkPool: srcChunkPool, - reorgMeta: reorgMeta, - }, - } - }) - return &WriteExternalStoreOperator{ - AsyncOperator: operator.NewAsyncOperator[IndexRecordChunk, IndexWriteResult](ctx, pool), - } -} - -// IndexWriteResult contains the result of writing index records to ingest engine. -type IndexWriteResult struct { - ID int - Added int - Total int - Next kv.Key -} - -// IndexIngestOperator writes index records to ingest engine. -type IndexIngestOperator struct { - *operator.AsyncOperator[IndexRecordChunk, IndexWriteResult] -} - -// NewIndexIngestOperator creates a new IndexIngestOperator. 
-func NewIndexIngestOperator( - ctx *OperatorCtx, - copCtx copr.CopContext, - backendCtx ingest.BackendCtx, - sessPool opSessPool, - tbl table.PhysicalTable, - indexes []table.Index, - engines []ingest.Engine, - srcChunkPool chan *chunk.Chunk, - concurrency int, - reorgMeta *model.DDLReorgMeta, - cpMgr *ingest.CheckpointManager, - rowCntListener RowCountListener, -) *IndexIngestOperator { - writerCfg := getLocalWriterConfig(len(indexes), concurrency) - - var writerIDAlloc atomic.Int32 - pool := workerpool.NewWorkerPool( - "indexIngestOperator", - util.DDL, - concurrency, - func() workerpool.Worker[IndexRecordChunk, IndexWriteResult] { - writers := make([]ingest.Writer, 0, len(indexes)) - for i := range indexes { - writerID := int(writerIDAlloc.Add(1)) - writer, err := engines[i].CreateWriter(writerID, writerCfg) - if err != nil { - logutil.Logger(ctx).Error("create index ingest worker failed", zap.Error(err)) - ctx.onError(err) - return nil - } - writers = append(writers, writer) - } - - indexIDs := make([]int64, len(indexes)) - for i := 0; i < len(indexes); i++ { - indexIDs[i] = indexes[i].Meta().ID - } - return &indexIngestLocalWorker{ - indexIngestBaseWorker: indexIngestBaseWorker{ - ctx: ctx, - tbl: tbl, - indexes: indexes, - copCtx: copCtx, - - se: nil, - sessPool: sessPool, - writers: writers, - srcChunkPool: srcChunkPool, - reorgMeta: reorgMeta, - }, - indexIDs: indexIDs, - backendCtx: backendCtx, - rowCntListener: rowCntListener, - cpMgr: cpMgr, - } - }) - return &IndexIngestOperator{ - AsyncOperator: operator.NewAsyncOperator[IndexRecordChunk, IndexWriteResult](ctx, pool), - } -} - -type indexIngestExternalWorker struct { - indexIngestBaseWorker -} - -func (w *indexIngestExternalWorker) HandleTask(ck IndexRecordChunk, send func(IndexWriteResult)) { - defer tidbutil.Recover(metrics.LblAddIndex, "indexIngestExternalWorkerRecover", func() { - w.ctx.onError(dbterror.ErrReorgPanic) - }, false) - defer func() { - if ck.Chunk != nil { - w.srcChunkPool <- ck.Chunk - } - }() - rs, err := w.indexIngestBaseWorker.HandleTask(ck) - if err != nil { - w.ctx.onError(err) - return - } - send(rs) -} - -type indexIngestLocalWorker struct { - indexIngestBaseWorker - indexIDs []int64 - backendCtx ingest.BackendCtx - rowCntListener RowCountListener - cpMgr *ingest.CheckpointManager -} - -func (w *indexIngestLocalWorker) HandleTask(ck IndexRecordChunk, send func(IndexWriteResult)) { - defer tidbutil.Recover(metrics.LblAddIndex, "indexIngestLocalWorkerRecover", func() { - w.ctx.onError(dbterror.ErrReorgPanic) - }, false) - defer func() { - if ck.Chunk != nil { - w.srcChunkPool <- ck.Chunk - } - }() - rs, err := w.indexIngestBaseWorker.HandleTask(ck) - if err != nil { - w.ctx.onError(err) - return - } - if rs.Added == 0 { - return - } - w.rowCntListener.Written(rs.Added) - flushed, imported, err := w.backendCtx.Flush(ingest.FlushModeAuto) - if err != nil { - w.ctx.onError(err) - return - } - if w.cpMgr != nil { - totalCnt, nextKey := w.cpMgr.Status() - rs.Total = totalCnt - rs.Next = nextKey - w.cpMgr.UpdateWrittenKeys(ck.ID, rs.Added) - w.cpMgr.AdvanceWatermark(flushed, imported) - } - send(rs) -} - -type indexIngestBaseWorker struct { - ctx *OperatorCtx - - tbl table.PhysicalTable - indexes []table.Index - reorgMeta *model.DDLReorgMeta - - copCtx copr.CopContext - sessPool opSessPool - se *session.Session - restore func(sessionctx.Context) - - writers []ingest.Writer - srcChunkPool chan *chunk.Chunk -} - -func (w *indexIngestBaseWorker) HandleTask(rs IndexRecordChunk) (IndexWriteResult, error) { - 
failpoint.Inject("injectPanicForIndexIngest", func() { - panic("mock panic") - }) - - result := IndexWriteResult{ - ID: rs.ID, - } - w.initSessCtx() - count, nextKey, err := w.WriteChunk(&rs) - if err != nil { - w.ctx.onError(err) - return result, err - } - if count == 0 { - logutil.Logger(w.ctx).Info("finish a index ingest task", zap.Int("id", rs.ID)) - return result, nil - } - result.Added = count - result.Next = nextKey - if ResultCounterForTest != nil { - ResultCounterForTest.Add(1) - } - return result, nil -} - -func (w *indexIngestBaseWorker) initSessCtx() { - if w.se == nil { - sessCtx, err := w.sessPool.Get() - if err != nil { - w.ctx.onError(err) - return - } - w.restore = restoreSessCtx(sessCtx) - if err := initSessCtx(sessCtx, w.reorgMeta); err != nil { - w.ctx.onError(err) - return - } - w.se = session.NewSession(sessCtx) - } -} - -func (w *indexIngestBaseWorker) Close() { - // TODO(lance6716): unify the real write action for engineInfo and external - // writer. - for _, writer := range w.writers { - ew, ok := writer.(*external.Writer) - if !ok { - break - } - err := ew.Close(w.ctx) - if err != nil { - w.ctx.onError(err) - } - } - if w.se != nil { - w.restore(w.se.Context) - w.sessPool.Put(w.se.Context) - } -} - -// WriteChunk will write index records to lightning engine. -func (w *indexIngestBaseWorker) WriteChunk(rs *IndexRecordChunk) (count int, nextKey kv.Key, err error) { - failpoint.Inject("mockWriteLocalError", func(_ failpoint.Value) { - failpoint.Return(0, nil, errors.New("mock write local error")) - }) - failpoint.InjectCall("writeLocalExec", rs.Done) - - oprStartTime := time.Now() - vars := w.se.GetSessionVars() - sc := vars.StmtCtx - cnt, lastHandle, err := writeChunkToLocal(w.ctx, w.writers, w.indexes, w.copCtx, sc.TimeZone(), sc.ErrCtx(), vars.GetWriteStmtBufs(), rs.Chunk) - if err != nil || cnt == 0 { - return 0, nil, err - } - logSlowOperations(time.Since(oprStartTime), "writeChunkToLocal", 3000) - nextKey = tablecodec.EncodeRecordKey(w.tbl.RecordPrefix(), lastHandle) - return cnt, nextKey, nil -} - -type indexWriteResultSink struct { - ctx *OperatorCtx - backendCtx ingest.BackendCtx - tbl table.PhysicalTable - indexes []table.Index - - cpMgr *ingest.CheckpointManager - rowCntListener RowCountListener - - errGroup errgroup.Group - source operator.DataChannel[IndexWriteResult] -} - -func newIndexWriteResultSink( - ctx *OperatorCtx, - backendCtx ingest.BackendCtx, - tbl table.PhysicalTable, - indexes []table.Index, - cpMgr *ingest.CheckpointManager, - rowCntListener RowCountListener, -) *indexWriteResultSink { - return &indexWriteResultSink{ - ctx: ctx, - backendCtx: backendCtx, - tbl: tbl, - indexes: indexes, - errGroup: errgroup.Group{}, - cpMgr: cpMgr, - rowCntListener: rowCntListener, - } -} - -func (s *indexWriteResultSink) SetSource(source operator.DataChannel[IndexWriteResult]) { - s.source = source -} - -func (s *indexWriteResultSink) Open() error { - s.errGroup.Go(s.collectResult) - return nil -} - -func (s *indexWriteResultSink) collectResult() error { - for { - select { - case <-s.ctx.Done(): - return s.ctx.Err() - case _, ok := <-s.source.Channel(): - if !ok { - err := s.flush() - if err != nil { - s.ctx.onError(err) - } - if s.cpMgr != nil { - total, _ := s.cpMgr.Status() - s.rowCntListener.SetTotal(total) - } - return err - } - } - } -} - -func (s *indexWriteResultSink) flush() error { - if s.backendCtx == nil { - return nil - } - failpoint.Inject("mockFlushError", func(_ failpoint.Value) { - failpoint.Return(errors.New("mock flush error")) - }) - 
flushed, imported, err := s.backendCtx.Flush(ingest.FlushModeForceFlushAndImport) - if s.cpMgr != nil { - // Try to advance watermark even if there is an error. - s.cpMgr.AdvanceWatermark(flushed, imported) - } - if err != nil { - msg := "flush error" - if flushed { - msg = "import error" - } - logutil.Logger(s.ctx).Error(msg, zap.String("category", "ddl"), zap.Error(err)) - return err - } - return nil -} - -func (s *indexWriteResultSink) Close() error { - return s.errGroup.Wait() -} - -func (*indexWriteResultSink) String() string { - return "indexWriteResultSink" -} diff --git a/pkg/ddl/backfilling_read_index.go b/pkg/ddl/backfilling_read_index.go index 41e0d3cc77ab5..0c5ce7a62a538 100644 --- a/pkg/ddl/backfilling_read_index.go +++ b/pkg/ddl/backfilling_read_index.go @@ -148,7 +148,7 @@ func (r *readIndexExecutor) Cleanup(ctx context.Context) error { } func (r *readIndexExecutor) OnFinished(ctx context.Context, subtask *proto.Subtask) error { - failpoint.Call(_curpkg_("mockDMLExecutionAddIndexSubTaskFinish")) + failpoint.InjectCall("mockDMLExecutionAddIndexSubTaskFinish") if len(r.cloudStorageURI) == 0 { return nil } diff --git a/pkg/ddl/backfilling_read_index.go__failpoint_stash__ b/pkg/ddl/backfilling_read_index.go__failpoint_stash__ deleted file mode 100644 index 0c5ce7a62a538..0000000000000 --- a/pkg/ddl/backfilling_read_index.go__failpoint_stash__ +++ /dev/null @@ -1,315 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
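// The hunks above restore the failpoint.Inject/InjectCall marker form that the
// __failpoint_stash__ files preserve, replacing the expanded failpoint.Eval
// calls. The markers are no-ops in normal builds and are rewritten into the
// Eval form when failpoints are enabled (e.g. via failpoint-ctl). A minimal
// sketch of the marker style; "mockBatch" and mockableBatch are illustrative
// names, not failpoints that exist in TiDB.
package main

import (
	"fmt"

	"github.com/pingcap/failpoint"
)

func mockableBatch() int {
	// When the failpoint is activated with a `return(<int>)` term, the injected
	// closure short-circuits the function with that value.
	failpoint.Inject("mockBatch", func(val failpoint.Value) {
		failpoint.Return(val.(int))
	})
	return 64
}

func main() {
	fmt.Println(mockableBatch()) // 64 while the marker has not been rewritten/enabled
}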
- -package ddl - -import ( - "context" - "encoding/hex" - "encoding/json" - "sync" - "sync/atomic" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/ddl/ingest" - "github.com/pingcap/tidb/pkg/ddl/logutil" - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - "github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor/execute" - "github.com/pingcap/tidb/pkg/disttask/operator" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/lightning/backend/external" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/table" - tidblogutil "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/prometheus/client_golang/prometheus" - "go.uber.org/zap" -) - -type readIndexExecutor struct { - execute.StepExecFrameworkInfo - d *ddl - job *model.Job - indexes []*model.IndexInfo - ptbl table.PhysicalTable - jc *JobContext - - avgRowSize int - cloudStorageURI string - - bc ingest.BackendCtx - curRowCount *atomic.Int64 - - subtaskSummary sync.Map // subtaskID => readIndexSummary -} - -type readIndexSummary struct { - metaGroups []*external.SortedKVMeta - mu sync.Mutex -} - -func newReadIndexExecutor( - d *ddl, - job *model.Job, - indexes []*model.IndexInfo, - ptbl table.PhysicalTable, - jc *JobContext, - bcGetter func() (ingest.BackendCtx, error), - cloudStorageURI string, - avgRowSize int, -) (*readIndexExecutor, error) { - bc, err := bcGetter() - if err != nil { - return nil, err - } - return &readIndexExecutor{ - d: d, - job: job, - indexes: indexes, - ptbl: ptbl, - jc: jc, - bc: bc, - cloudStorageURI: cloudStorageURI, - avgRowSize: avgRowSize, - curRowCount: &atomic.Int64{}, - }, nil -} - -func (*readIndexExecutor) Init(_ context.Context) error { - logutil.DDLLogger().Info("read index executor init subtask exec env") - return nil -} - -func (r *readIndexExecutor) RunSubtask(ctx context.Context, subtask *proto.Subtask) error { - logutil.DDLLogger().Info("read index executor run subtask", - zap.Bool("use cloud", len(r.cloudStorageURI) > 0)) - - r.subtaskSummary.Store(subtask.ID, &readIndexSummary{ - metaGroups: make([]*external.SortedKVMeta, len(r.indexes)), - }) - - sm, err := decodeBackfillSubTaskMeta(subtask.Meta) - if err != nil { - return err - } - - opCtx := NewDistTaskOperatorCtx(ctx, subtask.TaskID, subtask.ID) - defer opCtx.Cancel() - r.curRowCount.Store(0) - - if len(r.cloudStorageURI) > 0 { - pipe, err := r.buildExternalStorePipeline(opCtx, subtask.ID, sm, subtask.Concurrency) - if err != nil { - return err - } - return executeAndClosePipeline(opCtx, pipe) - } - - pipe, err := r.buildLocalStorePipeline(opCtx, sm, subtask.Concurrency) - if err != nil { - return err - } - err = executeAndClosePipeline(opCtx, pipe) - if err != nil { - // For dist task local based ingest, checkpoint is unsupported. - // If there is an error we should keep local sort dir clean. 
- err1 := r.bc.FinishAndUnregisterEngines(ingest.OptCleanData) - if err1 != nil { - logutil.DDLLogger().Warn("read index executor unregister engine failed", zap.Error(err1)) - } - return err - } - return r.bc.FinishAndUnregisterEngines(ingest.OptCleanData | ingest.OptCheckDup) -} - -func (r *readIndexExecutor) RealtimeSummary() *execute.SubtaskSummary { - return &execute.SubtaskSummary{ - RowCount: r.curRowCount.Load(), - } -} - -func (r *readIndexExecutor) Cleanup(ctx context.Context) error { - tidblogutil.Logger(ctx).Info("read index executor cleanup subtask exec env") - // cleanup backend context - ingest.LitBackCtxMgr.Unregister(r.job.ID) - return nil -} - -func (r *readIndexExecutor) OnFinished(ctx context.Context, subtask *proto.Subtask) error { - failpoint.InjectCall("mockDMLExecutionAddIndexSubTaskFinish") - if len(r.cloudStorageURI) == 0 { - return nil - } - // Rewrite the subtask meta to record statistics. - sm, err := decodeBackfillSubTaskMeta(subtask.Meta) - if err != nil { - return err - } - sum, _ := r.subtaskSummary.LoadAndDelete(subtask.ID) - s := sum.(*readIndexSummary) - sm.MetaGroups = s.metaGroups - sm.EleIDs = make([]int64, 0, len(r.indexes)) - for _, index := range r.indexes { - sm.EleIDs = append(sm.EleIDs, index.ID) - } - - all := external.SortedKVMeta{} - for _, g := range s.metaGroups { - all.Merge(g) - } - tidblogutil.Logger(ctx).Info("get key boundary on subtask finished", - zap.String("start", hex.EncodeToString(all.StartKey)), - zap.String("end", hex.EncodeToString(all.EndKey)), - zap.Int("fileCount", len(all.MultipleFilesStats)), - zap.Uint64("totalKVSize", all.TotalKVSize)) - - meta, err := json.Marshal(sm) - if err != nil { - return err - } - subtask.Meta = meta - return nil -} - -func (r *readIndexExecutor) getTableStartEndKey(sm *BackfillSubTaskMeta) ( - start, end kv.Key, tbl table.PhysicalTable, err error) { - currentVer, err1 := getValidCurrentVersion(r.d.store) - if err1 != nil { - return nil, nil, nil, errors.Trace(err1) - } - if parTbl, ok := r.ptbl.(table.PartitionedTable); ok { - pid := sm.PhysicalTableID - start, end, err = getTableRange(r.jc, r.d.ddlCtx, parTbl.GetPartition(pid), currentVer.Ver, r.job.Priority) - if err != nil { - logutil.DDLLogger().Error("get table range error", - zap.Error(err)) - return nil, nil, nil, err - } - tbl = parTbl.GetPartition(pid) - } else { - start, end = sm.RowStart, sm.RowEnd - tbl = r.ptbl - } - return start, end, tbl, nil -} - -func (r *readIndexExecutor) buildLocalStorePipeline( - opCtx *OperatorCtx, - sm *BackfillSubTaskMeta, - concurrency int, -) (*operator.AsyncPipeline, error) { - start, end, tbl, err := r.getTableStartEndKey(sm) - if err != nil { - return nil, err - } - d := r.d - indexIDs := make([]int64, 0, len(r.indexes)) - uniques := make([]bool, 0, len(r.indexes)) - for _, index := range r.indexes { - indexIDs = append(indexIDs, index.ID) - uniques = append(uniques, index.Unique) - } - engines, err := r.bc.Register(indexIDs, uniques, r.ptbl) - if err != nil { - tidblogutil.Logger(opCtx).Error("cannot register new engine", - zap.Error(err), - zap.Int64("job ID", r.job.ID), - zap.Int64s("index IDs", indexIDs)) - return nil, err - } - rowCntListener := newDistTaskRowCntListener(r.curRowCount, r.job.SchemaName, tbl.Meta().Name.O) - return NewAddIndexIngestPipeline( - opCtx, - d.store, - d.sessPool, - r.bc, - engines, - r.job.ID, - tbl, - r.indexes, - start, - end, - r.job.ReorgMeta, - r.avgRowSize, - concurrency, - nil, - rowCntListener, - ) -} - -func (r *readIndexExecutor) buildExternalStorePipeline( 
- opCtx *OperatorCtx, - subtaskID int64, - sm *BackfillSubTaskMeta, - concurrency int, -) (*operator.AsyncPipeline, error) { - start, end, tbl, err := r.getTableStartEndKey(sm) - if err != nil { - return nil, err - } - - d := r.d - onClose := func(summary *external.WriterSummary) { - sum, _ := r.subtaskSummary.Load(subtaskID) - s := sum.(*readIndexSummary) - s.mu.Lock() - kvMeta := s.metaGroups[summary.GroupOffset] - if kvMeta == nil { - kvMeta = &external.SortedKVMeta{} - s.metaGroups[summary.GroupOffset] = kvMeta - } - kvMeta.MergeSummary(summary) - s.mu.Unlock() - } - rowCntListener := newDistTaskRowCntListener(r.curRowCount, r.job.SchemaName, tbl.Meta().Name.O) - return NewWriteIndexToExternalStoragePipeline( - opCtx, - d.store, - r.cloudStorageURI, - r.d.sessPool, - r.job.ID, - subtaskID, - tbl, - r.indexes, - start, - end, - onClose, - r.job.ReorgMeta, - r.avgRowSize, - concurrency, - r.GetResource(), - rowCntListener, - ) -} - -type distTaskRowCntListener struct { - EmptyRowCntListener - totalRowCount *atomic.Int64 - counter prometheus.Counter -} - -func newDistTaskRowCntListener(totalRowCnt *atomic.Int64, dbName, tblName string) *distTaskRowCntListener { - counter := metrics.BackfillTotalCounter.WithLabelValues( - metrics.GenerateReorgLabel("add_idx_rate", dbName, tblName)) - return &distTaskRowCntListener{ - totalRowCount: totalRowCnt, - counter: counter, - } -} - -func (d *distTaskRowCntListener) Written(rowCnt int) { - d.totalRowCount.Add(int64(rowCnt)) - d.counter.Add(float64(rowCnt)) -} diff --git a/pkg/ddl/binding__failpoint_binding__.go b/pkg/ddl/binding__failpoint_binding__.go deleted file mode 100644 index d1a12631ba250..0000000000000 --- a/pkg/ddl/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package ddl - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/ddl/cluster.go b/pkg/ddl/cluster.go index 7386b26d98e58..b866023636484 100644 --- a/pkg/ddl/cluster.go +++ b/pkg/ddl/cluster.go @@ -104,10 +104,10 @@ func recoverPDSchedule(ctx context.Context, pdScheduleParam map[string]any) erro func getStoreGlobalMinSafeTS(s kv.Storage) time.Time { minSafeTS := s.GetMinSafeTS(kv.GlobalTxnScope) // Inject mocked SafeTS for test. 
-	if val, _err_ := failpoint.Eval(_curpkg_("injectSafeTS")); _err_ == nil {
+	failpoint.Inject("injectSafeTS", func(val failpoint.Value) {
 		injectTS := val.(int)
 		minSafeTS = uint64(injectTS)
-	}
+	})
 	return oracle.GetTimeFromTS(minSafeTS)
 }
 
@@ -129,10 +129,10 @@ func ValidateFlashbackTS(ctx context.Context, sctx sessionctx.Context, flashBack
 	}
 
 	flashbackGetMinSafeTimeTimeout := time.Minute
-	if val, _err_ := failpoint.Eval(_curpkg_("changeFlashbackGetMinSafeTimeTimeout")); _err_ == nil {
+	failpoint.Inject("changeFlashbackGetMinSafeTimeTimeout", func(val failpoint.Value) {
 		t := val.(int)
 		flashbackGetMinSafeTimeTimeout = time.Duration(t)
-	}
+	})
 
 	start := time.Now()
 	minSafeTime := getStoreGlobalMinSafeTS(sctx.GetStore())
@@ -535,14 +535,14 @@ func SendPrepareFlashbackToVersionRPC(
 		if err != nil {
 			return taskStat, err
 		}
-		if val, _err_ := failpoint.Eval(_curpkg_("mockPrepareMeetsEpochNotMatch")); _err_ == nil {
+		failpoint.Inject("mockPrepareMeetsEpochNotMatch", func(val failpoint.Value) {
 			if val.(bool) && bo.ErrorsNum() == 0 {
 				regionErr = &errorpb.Error{
 					Message: "stale epoch",
 					EpochNotMatch: &errorpb.EpochNotMatch{},
 				}
 			}
-		}
+		})
 		if regionErr != nil {
 			err = bo.Backoff(tikv.BoRegionMiss(), errors.New(regionErr.String()))
 			if err != nil {
@@ -702,11 +702,11 @@ func splitRegionsByKeyRanges(ctx context.Context, d *ddlCtx, keyRanges []kv.KeyR
 // 4. phase 2, send flashback RPC, do flashback jobs.
 func (w *worker) onFlashbackCluster(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) {
 	inFlashbackTest := false
-	if val, _err_ := failpoint.Eval(_curpkg_("mockFlashbackTest")); _err_ == nil {
+	failpoint.Inject("mockFlashbackTest", func(val failpoint.Value) {
 		if val.(bool) {
 			inFlashbackTest = true
 		}
-	}
+	})
 	// TODO: Support flashback in unistore.
 	if d.store.Name() != "TiKV" && !inFlashbackTest {
 		job.State = model.JobStateCancelled
diff --git a/pkg/ddl/cluster.go__failpoint_stash__ b/pkg/ddl/cluster.go__failpoint_stash__
deleted file mode 100644
index b866023636484..0000000000000
--- a/pkg/ddl/cluster.go__failpoint_stash__
+++ /dev/null
@@ -1,902 +0,0 @@
-// Copyright 2022 PingCAP, Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
- -package ddl - -import ( - "bytes" - "cmp" - "context" - "encoding/hex" - "fmt" - "slices" - "strings" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/errorpb" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/tidb/pkg/ddl/logutil" - sess "github.com/pingcap/tidb/pkg/ddl/session" - "github.com/pingcap/tidb/pkg/domain/infosync" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - statsutil "github.com/pingcap/tidb/pkg/statistics/handle/util" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/filter" - "github.com/pingcap/tidb/pkg/util/gcutil" - tikvstore "github.com/tikv/client-go/v2/kv" - "github.com/tikv/client-go/v2/oracle" - "github.com/tikv/client-go/v2/tikv" - "github.com/tikv/client-go/v2/tikvrpc" - "github.com/tikv/client-go/v2/txnkv/rangetask" - "go.uber.org/atomic" - "go.uber.org/zap" -) - -var pdScheduleKey = []string{ - "merge-schedule-limit", -} - -const ( - flashbackMaxBackoff = 1800000 // 1800s - flashbackTimeout = 3 * time.Minute // 3min -) - -const ( - pdScheduleArgsOffset = 1 + iota - gcEnabledOffset - autoAnalyzeOffset - readOnlyOffset - totalLockedRegionsOffset - startTSOffset - commitTSOffset - ttlJobEnableOffSet - keyRangesOffset -) - -func closePDSchedule(ctx context.Context) error { - closeMap := make(map[string]any) - for _, key := range pdScheduleKey { - closeMap[key] = 0 - } - return infosync.SetPDScheduleConfig(ctx, closeMap) -} - -func savePDSchedule(ctx context.Context, job *model.Job) error { - retValue, err := infosync.GetPDScheduleConfig(ctx) - if err != nil { - return err - } - saveValue := make(map[string]any) - for _, key := range pdScheduleKey { - saveValue[key] = retValue[key] - } - job.Args[pdScheduleArgsOffset] = &saveValue - return nil -} - -func recoverPDSchedule(ctx context.Context, pdScheduleParam map[string]any) error { - if pdScheduleParam == nil { - return nil - } - return infosync.SetPDScheduleConfig(ctx, pdScheduleParam) -} - -func getStoreGlobalMinSafeTS(s kv.Storage) time.Time { - minSafeTS := s.GetMinSafeTS(kv.GlobalTxnScope) - // Inject mocked SafeTS for test. - failpoint.Inject("injectSafeTS", func(val failpoint.Value) { - injectTS := val.(int) - minSafeTS = uint64(injectTS) - }) - return oracle.GetTimeFromTS(minSafeTS) -} - -// ValidateFlashbackTS validates that flashBackTS in range [gcSafePoint, currentTS). -func ValidateFlashbackTS(ctx context.Context, sctx sessionctx.Context, flashBackTS uint64) error { - currentTS, err := sctx.GetStore().GetOracle().GetStaleTimestamp(ctx, oracle.GlobalTxnScope, 0) - // If we fail to calculate currentTS from local time, fallback to get a timestamp from PD. 
- if err != nil { - metrics.ValidateReadTSFromPDCount.Inc() - currentVer, err := sctx.GetStore().CurrentVersion(oracle.GlobalTxnScope) - if err != nil { - return errors.Errorf("fail to validate flashback timestamp: %v", err) - } - currentTS = currentVer.Ver - } - oracleFlashbackTS := oracle.GetTimeFromTS(flashBackTS) - if oracleFlashbackTS.After(oracle.GetTimeFromTS(currentTS)) { - return errors.Errorf("cannot set flashback timestamp to future time") - } - - flashbackGetMinSafeTimeTimeout := time.Minute - failpoint.Inject("changeFlashbackGetMinSafeTimeTimeout", func(val failpoint.Value) { - t := val.(int) - flashbackGetMinSafeTimeTimeout = time.Duration(t) - }) - - start := time.Now() - minSafeTime := getStoreGlobalMinSafeTS(sctx.GetStore()) - ticker := time.NewTicker(time.Second) - defer ticker.Stop() - for oracleFlashbackTS.After(minSafeTime) { - if time.Since(start) >= flashbackGetMinSafeTimeTimeout { - return errors.Errorf("cannot set flashback timestamp after min-resolved-ts(%s)", minSafeTime) - } - select { - case <-ticker.C: - minSafeTime = getStoreGlobalMinSafeTS(sctx.GetStore()) - case <-ctx.Done(): - return ctx.Err() - } - } - - gcSafePoint, err := gcutil.GetGCSafePoint(sctx) - if err != nil { - return err - } - - return gcutil.ValidateSnapshotWithGCSafePoint(flashBackTS, gcSafePoint) -} - -func getTiDBTTLJobEnable(sess sessionctx.Context) (string, error) { - val, err := sess.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.TiDBTTLJobEnable) - if err != nil { - return "", errors.Trace(err) - } - return val, nil -} - -func setTiDBTTLJobEnable(ctx context.Context, sess sessionctx.Context, value string) error { - return sess.GetSessionVars().GlobalVarsAccessor.SetGlobalSysVar(ctx, variable.TiDBTTLJobEnable, value) -} - -func setTiDBEnableAutoAnalyze(ctx context.Context, sess sessionctx.Context, value string) error { - return sess.GetSessionVars().GlobalVarsAccessor.SetGlobalSysVar(ctx, variable.TiDBEnableAutoAnalyze, value) -} - -func getTiDBEnableAutoAnalyze(sess sessionctx.Context) (string, error) { - val, err := sess.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.TiDBEnableAutoAnalyze) - if err != nil { - return "", errors.Trace(err) - } - return val, nil -} - -func setTiDBSuperReadOnly(ctx context.Context, sess sessionctx.Context, value string) error { - return sess.GetSessionVars().GlobalVarsAccessor.SetGlobalSysVar(ctx, variable.TiDBSuperReadOnly, value) -} - -func getTiDBSuperReadOnly(sess sessionctx.Context) (string, error) { - val, err := sess.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.TiDBSuperReadOnly) - if err != nil { - return "", errors.Trace(err) - } - return val, nil -} - -func isFlashbackSupportedDDLAction(action model.ActionType) bool { - switch action { - case model.ActionSetTiFlashReplica, model.ActionUpdateTiFlashReplicaStatus, model.ActionAlterPlacementPolicy, - model.ActionAlterTablePlacement, model.ActionAlterTablePartitionPlacement, model.ActionCreatePlacementPolicy, - model.ActionDropPlacementPolicy, model.ActionModifySchemaDefaultPlacement, - model.ActionAlterTableAttributes, model.ActionAlterTablePartitionAttributes: - return false - default: - return true - } -} - -func checkSystemSchemaID(t *meta.Meta, schemaID int64, flashbackTSString string) error { - if schemaID <= 0 { - return nil - } - dbInfo, err := t.GetDatabase(schemaID) - if err != nil || dbInfo == nil { - return errors.Trace(err) - } - if filter.IsSystemSchema(dbInfo.Name.L) { - return errors.Errorf("Detected modified system table during [%s, now), 
can't do flashback", flashbackTSString) - } - return nil -} - -func checkAndSetFlashbackClusterInfo(ctx context.Context, se sessionctx.Context, d *ddlCtx, t *meta.Meta, job *model.Job, flashbackTS uint64) (err error) { - if err = ValidateFlashbackTS(ctx, se, flashbackTS); err != nil { - return err - } - - if err = gcutil.DisableGC(se); err != nil { - return err - } - if err = closePDSchedule(ctx); err != nil { - return err - } - if err = setTiDBEnableAutoAnalyze(ctx, se, variable.Off); err != nil { - return err - } - if err = setTiDBSuperReadOnly(ctx, se, variable.On); err != nil { - return err - } - if err = setTiDBTTLJobEnable(ctx, se, variable.Off); err != nil { - return err - } - - nowSchemaVersion, err := t.GetSchemaVersion() - if err != nil { - return errors.Trace(err) - } - - flashbackSnapshotMeta := meta.NewSnapshotMeta(d.store.GetSnapshot(kv.NewVersion(flashbackTS))) - flashbackSchemaVersion, err := flashbackSnapshotMeta.GetSchemaVersion() - if err != nil { - return errors.Trace(err) - } - - flashbackTSString := oracle.GetTimeFromTS(flashbackTS).Format(types.TimeFSPFormat) - - // Check if there is an upgrade during [flashbackTS, now) - sql := fmt.Sprintf("select VARIABLE_VALUE from mysql.tidb as of timestamp '%s' where VARIABLE_NAME='tidb_server_version'", flashbackTSString) - rows, err := sess.NewSession(se).Execute(ctx, sql, "check_tidb_server_version") - if err != nil || len(rows) == 0 { - return errors.Errorf("Get history `tidb_server_version` failed, can't do flashback") - } - sql = fmt.Sprintf("select 1 from mysql.tidb where VARIABLE_NAME='tidb_server_version' and VARIABLE_VALUE=%s", rows[0].GetString(0)) - rows, err = sess.NewSession(se).Execute(ctx, sql, "check_tidb_server_version") - if err != nil { - return errors.Trace(err) - } - if len(rows) == 0 { - return errors.Errorf("Detected TiDB upgrade during [%s, now), can't do flashback", flashbackTSString) - } - - // Check is there a DDL task at flashbackTS. - sql = fmt.Sprintf("select count(*) from mysql.%s as of timestamp '%s'", JobTable, flashbackTSString) - rows, err = sess.NewSession(se).Execute(ctx, sql, "check_history_job") - if err != nil || len(rows) == 0 { - return errors.Errorf("Get history ddl jobs failed, can't do flashback") - } - if rows[0].GetInt64(0) != 0 { - return errors.Errorf("Detected another DDL job at %s, can't do flashback", flashbackTSString) - } - - // If flashbackSchemaVersion not same as nowSchemaVersion, we should check all schema diffs during [flashbackTs, now). - for i := flashbackSchemaVersion + 1; i <= nowSchemaVersion; i++ { - diff, err := t.GetSchemaDiff(i) - if err != nil { - return errors.Trace(err) - } - if diff == nil { - continue - } - if !isFlashbackSupportedDDLAction(diff.Type) { - return errors.Errorf("Detected unsupported DDL job type(%s) during [%s, now), can't do flashback", diff.Type.String(), flashbackTSString) - } - err = checkSystemSchemaID(flashbackSnapshotMeta, diff.SchemaID, flashbackTSString) - if err != nil { - return errors.Trace(err) - } - } - - jobs, err := GetAllDDLJobs(se) - if err != nil { - return errors.Trace(err) - } - // Other ddl jobs in queue, return error. 
- if len(jobs) != 1 { - var otherJob *model.Job - for _, j := range jobs { - if j.ID != job.ID { - otherJob = j - break - } - } - return errors.Errorf("have other ddl jobs(jobID: %d) in queue, can't do flashback", otherJob.ID) - } - return nil -} - -func addToSlice(schema string, tableName string, tableID int64, flashbackIDs []int64) []int64 { - if filter.IsSystemSchema(schema) && !strings.HasPrefix(tableName, "stats_") && tableName != "gc_delete_range" { - flashbackIDs = append(flashbackIDs, tableID) - } - return flashbackIDs -} - -// getTableDataKeyRanges get keyRanges by `flashbackIDs`. -// This func will return all flashback table data key ranges. -func getTableDataKeyRanges(nonFlashbackTableIDs []int64) []kv.KeyRange { - var keyRanges []kv.KeyRange - - nonFlashbackTableIDs = append(nonFlashbackTableIDs, -1) - - slices.SortFunc(nonFlashbackTableIDs, func(a, b int64) int { - return cmp.Compare(a, b) - }) - - for i := 1; i < len(nonFlashbackTableIDs); i++ { - keyRanges = append(keyRanges, kv.KeyRange{ - StartKey: tablecodec.EncodeTablePrefix(nonFlashbackTableIDs[i-1] + 1), - EndKey: tablecodec.EncodeTablePrefix(nonFlashbackTableIDs[i]), - }) - } - - // Add all other key ranges. - keyRanges = append(keyRanges, kv.KeyRange{ - StartKey: tablecodec.EncodeTablePrefix(nonFlashbackTableIDs[len(nonFlashbackTableIDs)-1] + 1), - EndKey: tablecodec.EncodeTablePrefix(meta.MaxGlobalID), - }) - - return keyRanges -} - -type keyRangeMayExclude struct { - r kv.KeyRange - exclude bool -} - -// appendContinuousKeyRanges merges not exclude continuous key ranges and appends -// to given []kv.KeyRange, assuming the gap between key ranges has no data. -// -// Precondition: schemaKeyRanges is sorted by start key. schemaKeyRanges are -// non-overlapping. -func appendContinuousKeyRanges(result []kv.KeyRange, schemaKeyRanges []keyRangeMayExclude) []kv.KeyRange { - var ( - continuousStart, continuousEnd kv.Key - ) - - for _, r := range schemaKeyRanges { - if r.exclude { - if continuousStart != nil { - result = append(result, kv.KeyRange{ - StartKey: continuousStart, - EndKey: continuousEnd, - }) - continuousStart = nil - } - continue - } - - if continuousStart == nil { - continuousStart = r.r.StartKey - } - continuousEnd = r.r.EndKey - } - - if continuousStart != nil { - result = append(result, kv.KeyRange{ - StartKey: continuousStart, - EndKey: continuousEnd, - }) - } - return result -} - -// getFlashbackKeyRanges get keyRanges for flashback cluster. -// It contains all non system table key ranges and meta data key ranges. -// The time complexity is O(nlogn). -func getFlashbackKeyRanges(ctx context.Context, sess sessionctx.Context, flashbackTS uint64) ([]kv.KeyRange, error) { - is := sess.GetDomainInfoSchema().(infoschema.InfoSchema) - schemas := is.AllSchemas() - - // The semantic of keyRanges(output). - keyRanges := make([]kv.KeyRange, 0) - - // get snapshot schema IDs. 
- flashbackSnapshotMeta := meta.NewSnapshotMeta(sess.GetStore().GetSnapshot(kv.NewVersion(flashbackTS))) - snapshotSchemas, err := flashbackSnapshotMeta.ListDatabases() - if err != nil { - return nil, errors.Trace(err) - } - - schemaIDs := make(map[int64]struct{}) - excludeSchemaIDs := make(map[int64]struct{}) - for _, schema := range schemas { - if filter.IsSystemSchema(schema.Name.L) { - excludeSchemaIDs[schema.ID] = struct{}{} - } else { - schemaIDs[schema.ID] = struct{}{} - } - } - for _, schema := range snapshotSchemas { - if filter.IsSystemSchema(schema.Name.L) { - excludeSchemaIDs[schema.ID] = struct{}{} - } else { - schemaIDs[schema.ID] = struct{}{} - } - } - - schemaKeyRanges := make([]keyRangeMayExclude, 0, len(schemaIDs)+len(excludeSchemaIDs)) - for schemaID := range schemaIDs { - metaStartKey := tablecodec.EncodeMetaKeyPrefix(meta.DBkey(schemaID)) - metaEndKey := tablecodec.EncodeMetaKeyPrefix(meta.DBkey(schemaID + 1)) - schemaKeyRanges = append(schemaKeyRanges, keyRangeMayExclude{ - r: kv.KeyRange{ - StartKey: metaStartKey, - EndKey: metaEndKey, - }, - exclude: false, - }) - } - for schemaID := range excludeSchemaIDs { - metaStartKey := tablecodec.EncodeMetaKeyPrefix(meta.DBkey(schemaID)) - metaEndKey := tablecodec.EncodeMetaKeyPrefix(meta.DBkey(schemaID + 1)) - schemaKeyRanges = append(schemaKeyRanges, keyRangeMayExclude{ - r: kv.KeyRange{ - StartKey: metaStartKey, - EndKey: metaEndKey, - }, - exclude: true, - }) - } - - slices.SortFunc(schemaKeyRanges, func(a, b keyRangeMayExclude) int { - return bytes.Compare(a.r.StartKey, b.r.StartKey) - }) - - keyRanges = appendContinuousKeyRanges(keyRanges, schemaKeyRanges) - - startKey := tablecodec.EncodeMetaKeyPrefix([]byte("DBs")) - keyRanges = append(keyRanges, kv.KeyRange{ - StartKey: startKey, - EndKey: startKey.PrefixNext(), - }) - - var nonFlashbackTableIDs []int64 - for _, db := range schemas { - tbls, err2 := is.SchemaTableInfos(ctx, db.Name) - if err2 != nil { - return nil, errors.Trace(err2) - } - for _, table := range tbls { - if !table.IsBaseTable() || table.ID > meta.MaxGlobalID { - continue - } - nonFlashbackTableIDs = addToSlice(db.Name.L, table.Name.L, table.ID, nonFlashbackTableIDs) - if table.Partition != nil { - for _, partition := range table.Partition.Definitions { - nonFlashbackTableIDs = addToSlice(db.Name.L, table.Name.L, partition.ID, nonFlashbackTableIDs) - } - } - } - } - - return append(keyRanges, getTableDataKeyRanges(nonFlashbackTableIDs)...), nil -} - -// SendPrepareFlashbackToVersionRPC prepares regions for flashback, the purpose is to put region into flashback state which region stop write -// Function also be called by BR for volume snapshot backup and restore -func SendPrepareFlashbackToVersionRPC( - ctx context.Context, - s tikv.Storage, - flashbackTS, startTS uint64, - r tikvstore.KeyRange, -) (rangetask.TaskStat, error) { - startKey, rangeEndKey := r.StartKey, r.EndKey - var taskStat rangetask.TaskStat - bo := tikv.NewBackoffer(ctx, flashbackMaxBackoff) - for { - select { - case <-ctx.Done(): - return taskStat, errors.WithStack(ctx.Err()) - default: - } - - if len(rangeEndKey) > 0 && bytes.Compare(startKey, rangeEndKey) >= 0 { - break - } - - loc, err := s.GetRegionCache().LocateKey(bo, startKey) - if err != nil { - return taskStat, err - } - - endKey := loc.EndKey - isLast := len(endKey) == 0 || (len(rangeEndKey) > 0 && bytes.Compare(endKey, rangeEndKey) >= 0) - // If it is the last region. 
- if isLast { - endKey = rangeEndKey - } - - logutil.DDLLogger().Info("send prepare flashback request", zap.Uint64("region_id", loc.Region.GetID()), - zap.String("start_key", hex.EncodeToString(startKey)), zap.String("end_key", hex.EncodeToString(endKey))) - - req := tikvrpc.NewRequest(tikvrpc.CmdPrepareFlashbackToVersion, &kvrpcpb.PrepareFlashbackToVersionRequest{ - StartKey: startKey, - EndKey: endKey, - StartTs: startTS, - Version: flashbackTS, - }) - - resp, err := s.SendReq(bo, req, loc.Region, flashbackTimeout) - if err != nil { - return taskStat, err - } - regionErr, err := resp.GetRegionError() - if err != nil { - return taskStat, err - } - failpoint.Inject("mockPrepareMeetsEpochNotMatch", func(val failpoint.Value) { - if val.(bool) && bo.ErrorsNum() == 0 { - regionErr = &errorpb.Error{ - Message: "stale epoch", - EpochNotMatch: &errorpb.EpochNotMatch{}, - } - } - }) - if regionErr != nil { - err = bo.Backoff(tikv.BoRegionMiss(), errors.New(regionErr.String())) - if err != nil { - return taskStat, err - } - continue - } - if resp.Resp == nil { - logutil.DDLLogger().Warn("prepare flashback miss resp body", zap.Uint64("region_id", loc.Region.GetID())) - err = bo.Backoff(tikv.BoTiKVRPC(), errors.New("prepare flashback rpc miss resp body")) - if err != nil { - return taskStat, err - } - continue - } - prepareFlashbackToVersionResp := resp.Resp.(*kvrpcpb.PrepareFlashbackToVersionResponse) - if err := prepareFlashbackToVersionResp.GetError(); err != "" { - boErr := bo.Backoff(tikv.BoTiKVRPC(), errors.New(err)) - if boErr != nil { - return taskStat, boErr - } - continue - } - taskStat.CompletedRegions++ - if isLast { - break - } - bo = tikv.NewBackoffer(ctx, flashbackMaxBackoff) - startKey = endKey - } - return taskStat, nil -} - -// SendFlashbackToVersionRPC flashback the MVCC key to the version -// Function also be called by BR for volume snapshot backup and restore -func SendFlashbackToVersionRPC( - ctx context.Context, - s tikv.Storage, - version uint64, - startTS, commitTS uint64, - r tikvstore.KeyRange, -) (rangetask.TaskStat, error) { - startKey, rangeEndKey := r.StartKey, r.EndKey - var taskStat rangetask.TaskStat - bo := tikv.NewBackoffer(ctx, flashbackMaxBackoff) - for { - select { - case <-ctx.Done(): - return taskStat, errors.WithStack(ctx.Err()) - default: - } - - if len(rangeEndKey) > 0 && bytes.Compare(startKey, rangeEndKey) >= 0 { - break - } - - loc, err := s.GetRegionCache().LocateKey(bo, startKey) - if err != nil { - return taskStat, err - } - - endKey := loc.EndKey - isLast := len(endKey) == 0 || (len(rangeEndKey) > 0 && bytes.Compare(endKey, rangeEndKey) >= 0) - // If it is the last region. 
- if isLast { - endKey = rangeEndKey - } - - logutil.DDLLogger().Info("send flashback request", zap.Uint64("region_id", loc.Region.GetID()), - zap.String("start_key", hex.EncodeToString(startKey)), zap.String("end_key", hex.EncodeToString(endKey))) - - req := tikvrpc.NewRequest(tikvrpc.CmdFlashbackToVersion, &kvrpcpb.FlashbackToVersionRequest{ - Version: version, - StartKey: startKey, - EndKey: endKey, - StartTs: startTS, - CommitTs: commitTS, - }) - - resp, err := s.SendReq(bo, req, loc.Region, flashbackTimeout) - if err != nil { - logutil.DDLLogger().Warn("send request meets error", zap.Uint64("region_id", loc.Region.GetID()), zap.Error(err)) - if err.Error() != fmt.Sprintf("region %d is not prepared for the flashback", loc.Region.GetID()) { - return taskStat, err - } - } else { - regionErr, err := resp.GetRegionError() - if err != nil { - return taskStat, err - } - if regionErr != nil { - err = bo.Backoff(tikv.BoRegionMiss(), errors.New(regionErr.String())) - if err != nil { - return taskStat, err - } - continue - } - if resp.Resp == nil { - logutil.DDLLogger().Warn("flashback miss resp body", zap.Uint64("region_id", loc.Region.GetID())) - err = bo.Backoff(tikv.BoTiKVRPC(), errors.New("flashback rpc miss resp body")) - if err != nil { - return taskStat, err - } - continue - } - flashbackToVersionResp := resp.Resp.(*kvrpcpb.FlashbackToVersionResponse) - if respErr := flashbackToVersionResp.GetError(); respErr != "" { - boErr := bo.Backoff(tikv.BoTiKVRPC(), errors.New(respErr)) - if boErr != nil { - return taskStat, boErr - } - continue - } - } - taskStat.CompletedRegions++ - if isLast { - break - } - bo = tikv.NewBackoffer(ctx, flashbackMaxBackoff) - startKey = endKey - } - return taskStat, nil -} - -func flashbackToVersion( - ctx context.Context, - d *ddlCtx, - handler rangetask.TaskHandler, - startKey []byte, endKey []byte, -) (err error) { - return rangetask.NewRangeTaskRunner( - "flashback-to-version-runner", - d.store.(tikv.Storage), - int(variable.GetDDLFlashbackConcurrency()), - handler, - ).RunOnRange(ctx, startKey, endKey) -} - -func splitRegionsByKeyRanges(ctx context.Context, d *ddlCtx, keyRanges []kv.KeyRange) { - if s, ok := d.store.(kv.SplittableStore); ok { - for _, keys := range keyRanges { - for { - // tableID is useless when scatter == false - _, err := s.SplitRegions(ctx, [][]byte{keys.StartKey, keys.EndKey}, false, nil) - if err == nil { - break - } - } - } - } -} - -// A Flashback has 4 different stages. -// 1. before lock flashbackClusterJobID, check clusterJobID and lock it. -// 2. before flashback start, check timestamp, disable GC and close PD schedule, get flashback key ranges. -// 3. phase 1, lock flashback key ranges. -// 4. phase 2, send flashback RPC, do flashback jobs. -func (w *worker) onFlashbackCluster(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { - inFlashbackTest := false - failpoint.Inject("mockFlashbackTest", func(val failpoint.Value) { - if val.(bool) { - inFlashbackTest = true - } - }) - // TODO: Support flashback in unistore. 
- if d.store.Name() != "TiKV" && !inFlashbackTest { - job.State = model.JobStateCancelled - return ver, errors.Errorf("Not support flashback cluster in non-TiKV env") - } - - var flashbackTS, lockedRegions, startTS, commitTS uint64 - var pdScheduleValue map[string]any - var autoAnalyzeValue, readOnlyValue, ttlJobEnableValue string - var gcEnabledValue bool - var keyRanges []kv.KeyRange - if err := job.DecodeArgs(&flashbackTS, &pdScheduleValue, &gcEnabledValue, &autoAnalyzeValue, &readOnlyValue, &lockedRegions, &startTS, &commitTS, &ttlJobEnableValue, &keyRanges); err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - var totalRegions, completedRegions atomic.Uint64 - totalRegions.Store(lockedRegions) - - sess, err := w.sessPool.Get() - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - defer w.sessPool.Put(sess) - - switch job.SchemaState { - // Stage 1, check and set FlashbackClusterJobID, and update job args. - case model.StateNone: - if err = savePDSchedule(w.ctx, job); err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - gcEnableValue, err := gcutil.CheckGCEnable(sess) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - job.Args[gcEnabledOffset] = &gcEnableValue - autoAnalyzeValue, err = getTiDBEnableAutoAnalyze(sess) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - job.Args[autoAnalyzeOffset] = &autoAnalyzeValue - readOnlyValue, err = getTiDBSuperReadOnly(sess) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - job.Args[readOnlyOffset] = &readOnlyValue - ttlJobEnableValue, err = getTiDBTTLJobEnable(sess) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - job.Args[ttlJobEnableOffSet] = &ttlJobEnableValue - job.SchemaState = model.StateDeleteOnly - return ver, nil - // Stage 2, check flashbackTS, close GC and PD schedule, get flashback key ranges. - case model.StateDeleteOnly: - if err = checkAndSetFlashbackClusterInfo(w.ctx, sess, d, t, job, flashbackTS); err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - // We should get startTS here to avoid lost startTS when TiDB crashed during send prepare flashback RPC. - startTS, err = d.store.GetOracle().GetTimestamp(w.ctx, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - job.Args[startTSOffset] = startTS - keyRanges, err = getFlashbackKeyRanges(w.ctx, sess, flashbackTS) - if err != nil { - return ver, errors.Trace(err) - } - job.Args[keyRangesOffset] = keyRanges - job.SchemaState = model.StateWriteOnly - return updateSchemaVersion(d, t, job) - // Stage 3, lock related key ranges. - case model.StateWriteOnly: - // TODO: Support flashback in unistore. - if inFlashbackTest { - job.SchemaState = model.StateWriteReorganization - return updateSchemaVersion(d, t, job) - } - // Split region by keyRanges, make sure no unrelated key ranges be locked. 
- splitRegionsByKeyRanges(w.ctx, d, keyRanges) - totalRegions.Store(0) - for _, r := range keyRanges { - if err = flashbackToVersion(w.ctx, d, - func(ctx context.Context, r tikvstore.KeyRange) (rangetask.TaskStat, error) { - stats, err := SendPrepareFlashbackToVersionRPC(ctx, d.store.(tikv.Storage), flashbackTS, startTS, r) - totalRegions.Add(uint64(stats.CompletedRegions)) - return stats, err - }, r.StartKey, r.EndKey); err != nil { - logutil.DDLLogger().Warn("Get error when do flashback", zap.Error(err)) - return ver, err - } - } - job.Args[totalLockedRegionsOffset] = totalRegions.Load() - - // We should get commitTS here to avoid lost commitTS when TiDB crashed during send flashback RPC. - commitTS, err = d.store.GetOracle().GetTimestamp(w.ctx, &oracle.Option{TxnScope: oracle.GlobalTxnScope}) - if err != nil { - return ver, errors.Trace(err) - } - job.Args[commitTSOffset] = commitTS - job.SchemaState = model.StateWriteReorganization - return ver, nil - // Stage 4, get key ranges and send flashback RPC. - case model.StateWriteReorganization: - // TODO: Support flashback in unistore. - if inFlashbackTest { - asyncNotifyEvent(d, statsutil.NewFlashbackClusterEvent()) - job.State = model.JobStateDone - job.SchemaState = model.StatePublic - return ver, nil - } - - for _, r := range keyRanges { - if err = flashbackToVersion(w.ctx, d, - func(ctx context.Context, r tikvstore.KeyRange) (rangetask.TaskStat, error) { - // Use same startTS as prepare phase to simulate 1PC txn. - stats, err := SendFlashbackToVersionRPC(ctx, d.store.(tikv.Storage), flashbackTS, startTS, commitTS, r) - completedRegions.Add(uint64(stats.CompletedRegions)) - logutil.DDLLogger().Info("flashback cluster stats", - zap.Uint64("complete regions", completedRegions.Load()), - zap.Uint64("total regions", totalRegions.Load()), - zap.Error(err)) - return stats, err - }, r.StartKey, r.EndKey); err != nil { - logutil.DDLLogger().Warn("Get error when do flashback", zap.Error(err)) - return ver, errors.Trace(err) - } - } - - asyncNotifyEvent(d, statsutil.NewFlashbackClusterEvent()) - job.State = model.JobStateDone - job.SchemaState = model.StatePublic - return updateSchemaVersion(d, t, job) - } - return ver, nil -} - -func finishFlashbackCluster(w *worker, job *model.Job) error { - // Didn't do anything during flashback, return directly - if job.SchemaState == model.StateNone { - return nil - } - - var flashbackTS, lockedRegions, startTS, commitTS uint64 - var pdScheduleValue map[string]any - var autoAnalyzeValue, readOnlyValue, ttlJobEnableValue string - var gcEnabled bool - - if err := job.DecodeArgs(&flashbackTS, &pdScheduleValue, &gcEnabled, &autoAnalyzeValue, &readOnlyValue, &lockedRegions, &startTS, &commitTS, &ttlJobEnableValue); err != nil { - return errors.Trace(err) - } - sess, err := w.sessPool.Get() - if err != nil { - return errors.Trace(err) - } - defer w.sessPool.Put(sess) - - err = kv.RunInNewTxn(w.ctx, w.store, true, func(context.Context, kv.Transaction) error { - if err = recoverPDSchedule(w.ctx, pdScheduleValue); err != nil { - return err - } - if gcEnabled { - if err = gcutil.EnableGC(sess); err != nil { - return err - } - } - if err = setTiDBSuperReadOnly(w.ctx, sess, readOnlyValue); err != nil { - return err - } - - if job.IsCancelled() { - // only restore `tidb_ttl_job_enable` when flashback failed - if err = setTiDBTTLJobEnable(w.ctx, sess, ttlJobEnableValue); err != nil { - return err - } - } - - return setTiDBEnableAutoAnalyze(w.ctx, sess, autoAnalyzeValue) - }) - if err != nil { - return err - } - - return 
nil
-}
diff --git a/pkg/ddl/column.go b/pkg/ddl/column.go
index 4b5e656d0389a..5b889c8f73464 100644
--- a/pkg/ddl/column.go
+++ b/pkg/ddl/column.go
@@ -172,7 +172,7 @@ func onDropColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error)
 		}
 	case model.StateWriteOnly:
 		// write only -> delete only
-		failpoint.Call(_curpkg_("onDropColumnStateWriteOnly"))
+		failpoint.InjectCall("onDropColumnStateWriteOnly")
 		colInfo.State = model.StateDeleteOnly
 		tblInfo.MoveColumnInfo(colInfo.Offset, len(tblInfo.Columns)-1)
 		if len(idxInfos) > 0 {
@@ -511,7 +511,7 @@ var TestReorgGoroutineRunning = make(chan any)
 
 // updateCurrentElement update the current element for reorgInfo.
 func (w *worker) updateCurrentElement(t table.Table, reorgInfo *reorgInfo) error {
-	if val, _err_ := failpoint.Eval(_curpkg_("mockInfiniteReorgLogic")); _err_ == nil {
+	failpoint.Inject("mockInfiniteReorgLogic", func(val failpoint.Value) {
 		//nolint:forcetypeassert
 		if val.(bool) {
 			a := new(any)
@@ -520,11 +520,11 @@ func (w *worker) updateCurrentElement(t table.Table, reorgInfo *reorgInfo) error
 			time.Sleep(30 * time.Millisecond)
 			if w.isReorgCancelled(reorgInfo.Job.ID) {
 				// Job is cancelled. So it can't be done.
-				return dbterror.ErrCancelledDDLJob
+				failpoint.Return(dbterror.ErrCancelledDDLJob)
 			}
 		}
 	}
-	}
+	})
 	// TODO: Support partition tables.
 	if bytes.Equal(reorgInfo.currElement.TypeKey, meta.ColumnElementKey) {
 		//nolint:forcetypeassert
@@ -631,11 +631,11 @@ func newUpdateColumnWorker(id int, t table.PhysicalTable, decodeColMap map[int64
 		}
 	}
 	rowDecoder := decoder.NewRowDecoder(t, t.WritableCols(), decodeColMap)
-	if _, _err_ := failpoint.Eval(_curpkg_("forceRowLevelChecksumOnUpdateColumnBackfill")); _err_ == nil {
+	failpoint.Inject("forceRowLevelChecksumOnUpdateColumnBackfill", func() {
 		orig := variable.EnableRowLevelChecksum.Load()
 		defer variable.EnableRowLevelChecksum.Store(orig)
 		variable.EnableRowLevelChecksum.Store(true)
-	}
+	})
 	return &updateColumnWorker{
 		backfillCtx: bCtx,
 		oldColInfo: oldCol,
@@ -757,14 +757,14 @@ func (w *updateColumnWorker) getRowRecord(handle kv.Handle, recordKey []byte, ra
 		recordWarning = errors.Cause(w.reformatErrors(warn[0].Err)).(*terror.Error)
 	}
 
-	if val, _err_ := failpoint.Eval(_curpkg_("MockReorgTimeoutInOneRegion")); _err_ == nil {
+	failpoint.Inject("MockReorgTimeoutInOneRegion", func(val failpoint.Value) {
 		//nolint:forcetypeassert
 		if val.(bool) {
 			if handle.IntValue() == 3000 && atomic.CompareAndSwapInt32(&TestCheckReorgTimeout, 0, 1) {
-				return errors.Trace(dbterror.ErrWaitReorgTimeout)
+				failpoint.Return(errors.Trace(dbterror.ErrWaitReorgTimeout))
 			}
 		}
-	}
+	})
 	w.rowMap[w.newColInfo.ID] = newColVal
 
 	_, err = w.rowDecoder.EvalRemainedExprColumnMap(w.exprCtx, w.rowMap)
@@ -1158,12 +1158,12 @@ func modifyColsFromNull2NotNull(w *worker, dbInfo *model.DBInfo, tblInfo *model.
 	defer w.sessPool.Put(sctx)
 
 	skipCheck := false
-	if val, _err_ := failpoint.Eval(_curpkg_("skipMockContextDoExec")); _err_ == nil {
+	failpoint.Inject("skipMockContextDoExec", func(val failpoint.Value) {
 		//nolint:forcetypeassert
 		if val.(bool) {
 			skipCheck = true
 		}
-	}
+	})
 	if !skipCheck {
 		// If there is a null value inserted, it cannot be modified and needs to be rollback.
 		err = checkForNullValue(w.ctx, sctx, isDataTruncated, dbInfo.Name, tblInfo.Name, newCol, cols...)
diff --git a/pkg/ddl/column.go__failpoint_stash__ b/pkg/ddl/column.go__failpoint_stash__
deleted file mode 100644
index 5b889c8f73464..0000000000000
--- a/pkg/ddl/column.go__failpoint_stash__
+++ /dev/null
@@ -1,1320 +0,0 @@
-// Copyright 2015 PingCAP, Inc.
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ddl - -import ( - "bytes" - "context" - "encoding/hex" - "fmt" - "math/bits" - "strings" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/ddl/logutil" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/meta/autoid" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/types" - contextutil "github.com/pingcap/tidb/pkg/util/context" - "github.com/pingcap/tidb/pkg/util/dbterror" - decoder "github.com/pingcap/tidb/pkg/util/rowDecoder" - "github.com/pingcap/tidb/pkg/util/rowcodec" - kvutil "github.com/tikv/client-go/v2/util" - "go.uber.org/zap" -) - -// InitAndAddColumnToTable initializes the ColumnInfo in-place and adds it to the table. -func InitAndAddColumnToTable(tblInfo *model.TableInfo, colInfo *model.ColumnInfo) *model.ColumnInfo { - cols := tblInfo.Columns - colInfo.ID = AllocateColumnID(tblInfo) - colInfo.State = model.StateNone - // To support add column asynchronous, we should mark its offset as the last column. - // So that we can use origin column offset to get value from row. - colInfo.Offset = len(cols) - // Append the column info to the end of the tblInfo.Columns. - // It will reorder to the right offset in "Columns" when it state change to public. - tblInfo.Columns = append(cols, colInfo) - return colInfo -} - -func checkAddColumn(t *meta.Meta, job *model.Job) (*model.TableInfo, *model.ColumnInfo, *model.ColumnInfo, - *ast.ColumnPosition, bool /* ifNotExists */, error) { - schemaID := job.SchemaID - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) - if err != nil { - return nil, nil, nil, nil, false, errors.Trace(err) - } - col := &model.ColumnInfo{} - pos := &ast.ColumnPosition{} - offset := 0 - ifNotExists := false - err = job.DecodeArgs(col, pos, &offset, &ifNotExists) - if err != nil { - job.State = model.JobStateCancelled - return nil, nil, nil, nil, false, errors.Trace(err) - } - - columnInfo := model.FindColumnInfo(tblInfo.Columns, col.Name.L) - if columnInfo != nil { - if columnInfo.State == model.StatePublic { - // We already have a column with the same column name. 
- job.State = model.JobStateCancelled - return nil, nil, nil, nil, ifNotExists, infoschema.ErrColumnExists.GenWithStackByArgs(col.Name) - } - } - - err = CheckAfterPositionExists(tblInfo, pos) - if err != nil { - job.State = model.JobStateCancelled - return nil, nil, nil, nil, false, infoschema.ErrColumnExists.GenWithStackByArgs(col.Name) - } - - return tblInfo, columnInfo, col, pos, false, nil -} - -// CheckAfterPositionExists makes sure the column specified in AFTER clause is exists. -// For example, ALTER TABLE t ADD COLUMN c3 INT AFTER c1. -func CheckAfterPositionExists(tblInfo *model.TableInfo, pos *ast.ColumnPosition) error { - if pos != nil && pos.Tp == ast.ColumnPositionAfter { - c := model.FindColumnInfo(tblInfo.Columns, pos.RelativeColumn.Name.L) - if c == nil { - return infoschema.ErrColumnNotExists.GenWithStackByArgs(pos.RelativeColumn, tblInfo.Name) - } - } - return nil -} - -func setIndicesState(indexInfos []*model.IndexInfo, state model.SchemaState) { - for _, indexInfo := range indexInfos { - indexInfo.State = state - } -} - -func checkDropColumnForStatePublic(colInfo *model.ColumnInfo) (err error) { - // When the dropping column has not-null flag and it hasn't the default value, we can backfill the column value like "add column". - // NOTE: If the state of StateWriteOnly can be rollbacked, we'd better reconsider the original default value. - // And we need consider the column without not-null flag. - if colInfo.GetOriginDefaultValue() == nil && mysql.HasNotNullFlag(colInfo.GetFlag()) { - // If the column is timestamp default current_timestamp, and DDL owner is new version TiDB that set column.Version to 1, - // then old TiDB update record in the column write only stage will uses the wrong default value of the dropping column. - // Because new version of the column default value is UTC time, but old version TiDB will think the default value is the time in system timezone. - // But currently will be ok, because we can't cancel the drop column job when the job is running, - // so the column will be dropped succeed and client will never see the wrong default value of the dropped column. - // More info about this problem, see PR#9115. - originDefVal, err := generateOriginDefaultValue(colInfo, nil) - if err != nil { - return err - } - return colInfo.SetOriginDefaultValue(originDefVal) - } - return nil -} - -func onDropColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - tblInfo, colInfo, idxInfos, ifExists, err := checkDropColumn(d, t, job) - if err != nil { - if ifExists && dbterror.ErrCantDropFieldOrKey.Equal(err) { - // Convert the "not exists" error to a warning. - job.Warning = toTError(err) - job.State = model.JobStateDone - return ver, nil - } - return ver, errors.Trace(err) - } - if job.MultiSchemaInfo != nil && !job.IsRollingback() && job.MultiSchemaInfo.Revertible { - job.MarkNonRevertible() - job.SchemaState = colInfo.State - // Store the mark and enter the next DDL handling loop. 
- return updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, false) - } - - originalState := colInfo.State - switch colInfo.State { - case model.StatePublic: - // public -> write only - colInfo.State = model.StateWriteOnly - setIndicesState(idxInfos, model.StateWriteOnly) - tblInfo.MoveColumnInfo(colInfo.Offset, len(tblInfo.Columns)-1) - err = checkDropColumnForStatePublic(colInfo) - if err != nil { - return ver, errors.Trace(err) - } - ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, originalState != colInfo.State) - if err != nil { - return ver, errors.Trace(err) - } - case model.StateWriteOnly: - // write only -> delete only - failpoint.InjectCall("onDropColumnStateWriteOnly") - colInfo.State = model.StateDeleteOnly - tblInfo.MoveColumnInfo(colInfo.Offset, len(tblInfo.Columns)-1) - if len(idxInfos) > 0 { - newIndices := make([]*model.IndexInfo, 0, len(tblInfo.Indices)) - for _, idx := range tblInfo.Indices { - if !indexInfoContains(idx.ID, idxInfos) { - newIndices = append(newIndices, idx) - } - } - tblInfo.Indices = newIndices - } - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != colInfo.State) - if err != nil { - return ver, errors.Trace(err) - } - job.Args = append(job.Args, indexInfosToIDList(idxInfos)) - case model.StateDeleteOnly: - // delete only -> reorganization - colInfo.State = model.StateDeleteReorganization - tblInfo.MoveColumnInfo(colInfo.Offset, len(tblInfo.Columns)-1) - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != colInfo.State) - if err != nil { - return ver, errors.Trace(err) - } - case model.StateDeleteReorganization: - // reorganization -> absent - // All reorganization jobs are done, drop this column. - tblInfo.MoveColumnInfo(colInfo.Offset, len(tblInfo.Columns)-1) - tblInfo.Columns = tblInfo.Columns[:len(tblInfo.Columns)-1] - colInfo.State = model.StateNone - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != colInfo.State) - if err != nil { - return ver, errors.Trace(err) - } - - // Finish this job. - if job.IsRollingback() { - job.FinishTableJob(model.JobStateRollbackDone, model.StateNone, ver, tblInfo) - } else { - // We should set related index IDs for job - job.FinishTableJob(model.JobStateDone, model.StateNone, ver, tblInfo) - job.Args = append(job.Args, getPartitionIDs(tblInfo)) - } - default: - return ver, errors.Trace(dbterror.ErrInvalidDDLJob.GenWithStackByArgs("table", tblInfo.State)) - } - job.SchemaState = colInfo.State - return ver, errors.Trace(err) -} - -func checkDropColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (*model.TableInfo, *model.ColumnInfo, []*model.IndexInfo, bool /* ifExists */, error) { - schemaID := job.SchemaID - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) - if err != nil { - return nil, nil, nil, false, errors.Trace(err) - } - - var colName model.CIStr - var ifExists bool - // indexIDs is used to make sure we don't truncate args when decoding the rawArgs. 
- var indexIDs []int64 - err = job.DecodeArgs(&colName, &ifExists, &indexIDs) - if err != nil { - job.State = model.JobStateCancelled - return nil, nil, nil, false, errors.Trace(err) - } - - colInfo := model.FindColumnInfo(tblInfo.Columns, colName.L) - if colInfo == nil || colInfo.Hidden { - job.State = model.JobStateCancelled - return nil, nil, nil, ifExists, dbterror.ErrCantDropFieldOrKey.GenWithStack("column %s doesn't exist", colName) - } - if err = isDroppableColumn(tblInfo, colName); err != nil { - job.State = model.JobStateCancelled - return nil, nil, nil, false, errors.Trace(err) - } - if err = checkDropColumnWithForeignKeyConstraintInOwner(d, t, job, tblInfo, colName.L); err != nil { - return nil, nil, nil, false, errors.Trace(err) - } - if err = checkDropColumnWithTTLConfig(tblInfo, colName.L); err != nil { - return nil, nil, nil, false, errors.Trace(err) - } - idxInfos := listIndicesWithColumn(colName.L, tblInfo.Indices) - return tblInfo, colInfo, idxInfos, false, nil -} - -func isDroppableColumn(tblInfo *model.TableInfo, colName model.CIStr) error { - if ok, dep, isHidden := hasDependentByGeneratedColumn(tblInfo, colName); ok { - if isHidden { - return dbterror.ErrDependentByFunctionalIndex.GenWithStackByArgs(dep) - } - return dbterror.ErrDependentByGeneratedColumn.GenWithStackByArgs(dep) - } - - if len(tblInfo.Columns) == 1 { - return dbterror.ErrCantRemoveAllFields.GenWithStack("can't drop only column %s in table %s", - colName, tblInfo.Name) - } - // We only support dropping column with single-value none Primary Key index covered now. - err := isColumnCanDropWithIndex(colName.L, tblInfo.Indices) - if err != nil { - return err - } - err = IsColumnDroppableWithCheckConstraint(colName, tblInfo) - if err != nil { - return err - } - return nil -} - -func onSetDefaultValue(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - newCol := &model.ColumnInfo{} - err := job.DecodeArgs(newCol) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - return updateColumnDefaultValue(d, t, job, newCol, &newCol.Name) -} - -func setIdxIDName(idxInfo *model.IndexInfo, newID int64, newName model.CIStr) { - idxInfo.ID = newID - idxInfo.Name = newName -} - -// SetIdxColNameOffset sets index column name and offset from changing ColumnInfo. 
-func SetIdxColNameOffset(idxCol *model.IndexColumn, changingCol *model.ColumnInfo) { - idxCol.Name = changingCol.Name - idxCol.Offset = changingCol.Offset - canPrefix := types.IsTypePrefixable(changingCol.GetType()) - if !canPrefix || (changingCol.GetFlen() <= idxCol.Length) { - idxCol.Length = types.UnspecifiedLength - } -} - -func removeChangingColAndIdxs(tblInfo *model.TableInfo, changingColID int64) { - restIdx := tblInfo.Indices[:0] - for _, idx := range tblInfo.Indices { - if !idx.HasColumnInIndexColumns(tblInfo, changingColID) { - restIdx = append(restIdx, idx) - } - } - tblInfo.Indices = restIdx - - restCols := tblInfo.Columns[:0] - for _, c := range tblInfo.Columns { - if c.ID != changingColID { - restCols = append(restCols, c) - } - } - tblInfo.Columns = restCols -} - -func replaceOldColumn(tblInfo *model.TableInfo, oldCol, changingCol *model.ColumnInfo, - newName model.CIStr) *model.ColumnInfo { - tblInfo.MoveColumnInfo(changingCol.Offset, len(tblInfo.Columns)-1) - changingCol = updateChangingCol(changingCol, newName, oldCol.Offset) - tblInfo.Columns[oldCol.Offset] = changingCol - tblInfo.Columns = tblInfo.Columns[:len(tblInfo.Columns)-1] - return changingCol -} - -func replaceOldIndexes(tblInfo *model.TableInfo, changingIdxs []*model.IndexInfo) { - // Remove the changing indexes. - for i, idx := range tblInfo.Indices { - for _, cIdx := range changingIdxs { - if cIdx.ID == idx.ID { - tblInfo.Indices[i] = nil - break - } - } - } - tmp := tblInfo.Indices[:0] - for _, idx := range tblInfo.Indices { - if idx != nil { - tmp = append(tmp, idx) - } - } - tblInfo.Indices = tmp - // Replace the old indexes with changing indexes. - for _, cIdx := range changingIdxs { - // The index name should be changed from '_Idx$_name' to 'name'. - idxName := getChangingIndexOriginName(cIdx) - for i, idx := range tblInfo.Indices { - if strings.EqualFold(idxName, idx.Name.O) { - cIdx.Name = model.NewCIStr(idxName) - tblInfo.Indices[i] = cIdx - break - } - } - } -} - -// updateNewIdxColsNameOffset updates the name&offset of the index column. -func updateNewIdxColsNameOffset(changingIdxs []*model.IndexInfo, - oldName model.CIStr, changingCol *model.ColumnInfo) { - for _, idx := range changingIdxs { - for _, col := range idx.Columns { - if col.Name.L == oldName.L { - SetIdxColNameOffset(col, changingCol) - } - } - } -} - -// filterIndexesToRemove filters out the indexes that can be removed. -func filterIndexesToRemove(changingIdxs []*model.IndexInfo, colName model.CIStr, tblInfo *model.TableInfo) []*model.IndexInfo { - indexesToRemove := make([]*model.IndexInfo, 0, len(changingIdxs)) - for _, idx := range changingIdxs { - var hasOtherChangingCol bool - for _, col := range idx.Columns { - if col.Name.L == colName.L { - continue // ignore the current modifying column. - } - if !hasOtherChangingCol { - hasOtherChangingCol = tblInfo.Columns[col.Offset].ChangeStateInfo != nil - } - } - // For the indexes that still contains other changing column, skip removing it now. - // We leave the removal work to the last modify column job. - if !hasOtherChangingCol { - indexesToRemove = append(indexesToRemove, idx) - } - } - return indexesToRemove -} - -func updateChangingCol(col *model.ColumnInfo, newName model.CIStr, newOffset int) *model.ColumnInfo { - col.Name = newName - col.ChangeStateInfo = nil - col.Offset = newOffset - // After changing the column, the column's type is change, so it needs to set OriginDefaultValue back - // so that there is no error in getting the default value from OriginDefaultValue. 
- // Besides, nil data that was not backfilled in the "add column" is backfilled after the column is changed. - // So it can set OriginDefaultValue to nil. - col.OriginDefaultValue = nil - return col -} - -func buildRelatedIndexInfos(tblInfo *model.TableInfo, colID int64) []*model.IndexInfo { - var indexInfos []*model.IndexInfo - for _, idx := range tblInfo.Indices { - if idx.HasColumnInIndexColumns(tblInfo, colID) { - indexInfos = append(indexInfos, idx) - } - } - return indexInfos -} - -func buildRelatedIndexIDs(tblInfo *model.TableInfo, colID int64) []int64 { - var oldIdxIDs []int64 - for _, idx := range tblInfo.Indices { - if idx.HasColumnInIndexColumns(tblInfo, colID) { - oldIdxIDs = append(oldIdxIDs, idx.ID) - } - } - return oldIdxIDs -} - -// LocateOffsetToMove returns the offset of the column to move. -func LocateOffsetToMove(currentOffset int, pos *ast.ColumnPosition, tblInfo *model.TableInfo) (destOffset int, err error) { - if pos == nil { - return currentOffset, nil - } - // Get column offset. - switch pos.Tp { - case ast.ColumnPositionFirst: - return 0, nil - case ast.ColumnPositionAfter: - c := model.FindColumnInfo(tblInfo.Columns, pos.RelativeColumn.Name.L) - if c == nil || c.State != model.StatePublic { - return 0, infoschema.ErrColumnNotExists.GenWithStackByArgs(pos.RelativeColumn, tblInfo.Name) - } - if currentOffset <= c.Offset { - return c.Offset, nil - } - return c.Offset + 1, nil - case ast.ColumnPositionNone: - return currentOffset, nil - default: - return 0, errors.Errorf("unknown column position type") - } -} - -// BuildElements is exported for testing. -func BuildElements(changingCol *model.ColumnInfo, changingIdxs []*model.IndexInfo) []*meta.Element { - elements := make([]*meta.Element, 0, len(changingIdxs)+1) - elements = append(elements, &meta.Element{ID: changingCol.ID, TypeKey: meta.ColumnElementKey}) - for _, idx := range changingIdxs { - elements = append(elements, &meta.Element{ID: idx.ID, TypeKey: meta.IndexElementKey}) - } - return elements -} - -func (w *worker) updatePhysicalTableRow(t table.Table, reorgInfo *reorgInfo) error { - logutil.DDLLogger().Info("start to update table row", zap.Stringer("job", reorgInfo.Job), zap.Stringer("reorgInfo", reorgInfo)) - if tbl, ok := t.(table.PartitionedTable); ok { - done := false - for !done { - p := tbl.GetPartition(reorgInfo.PhysicalTableID) - if p == nil { - return dbterror.ErrCancelledDDLJob.GenWithStack("Can not find partition id %d for table %d", reorgInfo.PhysicalTableID, t.Meta().ID) - } - workType := typeReorgPartitionWorker - switch reorgInfo.Job.Type { - case model.ActionReorganizePartition, - model.ActionRemovePartitioning, - model.ActionAlterTablePartitioning: - // Expected - default: - // workType = typeUpdateColumnWorker - // TODO: Support Modify Column on partitioned table - // https://github.com/pingcap/tidb/issues/38297 - return dbterror.ErrCancelledDDLJob.GenWithStack("Modify Column on partitioned table / typeUpdateColumnWorker not yet supported.") - } - err := w.writePhysicalTableRecord(w.ctx, w.sessPool, p, workType, reorgInfo) - if err != nil { - return err - } - done, err = updateReorgInfo(w.sessPool, tbl, reorgInfo) - if err != nil { - return errors.Trace(err) - } - } - return nil - } - if tbl, ok := t.(table.PhysicalTable); ok { - return w.writePhysicalTableRecord(w.ctx, w.sessPool, tbl, typeUpdateColumnWorker, reorgInfo) - } - return dbterror.ErrCancelledDDLJob.GenWithStack("internal error for phys tbl id: %d tbl id: %d", reorgInfo.PhysicalTableID, t.Meta().ID) -} - -// 
TestReorgGoroutineRunning is only used in test to indicate the reorg goroutine has been started. -var TestReorgGoroutineRunning = make(chan any) - -// updateCurrentElement update the current element for reorgInfo. -func (w *worker) updateCurrentElement(t table.Table, reorgInfo *reorgInfo) error { - failpoint.Inject("mockInfiniteReorgLogic", func(val failpoint.Value) { - //nolint:forcetypeassert - if val.(bool) { - a := new(any) - TestReorgGoroutineRunning <- a - for { - time.Sleep(30 * time.Millisecond) - if w.isReorgCancelled(reorgInfo.Job.ID) { - // Job is cancelled. So it can't be done. - failpoint.Return(dbterror.ErrCancelledDDLJob) - } - } - } - }) - // TODO: Support partition tables. - if bytes.Equal(reorgInfo.currElement.TypeKey, meta.ColumnElementKey) { - //nolint:forcetypeassert - err := w.updatePhysicalTableRow(t.(table.PhysicalTable), reorgInfo) - if err != nil { - return errors.Trace(err) - } - } - - if _, ok := t.(table.PartitionedTable); ok { - // TODO: remove when modify column of partitioned table is supported - // https://github.com/pingcap/tidb/issues/38297 - return dbterror.ErrCancelledDDLJob.GenWithStack("Modify Column on partitioned table / typeUpdateColumnWorker not yet supported.") - } - // Get the original start handle and end handle. - currentVer, err := getValidCurrentVersion(reorgInfo.d.store) - if err != nil { - return errors.Trace(err) - } - //nolint:forcetypeassert - originalStartHandle, originalEndHandle, err := getTableRange(reorgInfo.NewJobContext(), reorgInfo.d, t.(table.PhysicalTable), currentVer.Ver, reorgInfo.Job.Priority) - if err != nil { - return errors.Trace(err) - } - - startElementOffset := 0 - startElementOffsetToResetHandle := -1 - // This backfill job starts with backfilling index data, whose index ID is currElement.ID. - if bytes.Equal(reorgInfo.currElement.TypeKey, meta.IndexElementKey) { - for i, element := range reorgInfo.elements[1:] { - if reorgInfo.currElement.ID == element.ID { - startElementOffset = i - startElementOffsetToResetHandle = i - break - } - } - } - - for i := startElementOffset; i < len(reorgInfo.elements[1:]); i++ { - // This backfill job has been exited during processing. At that time, the element is reorgInfo.elements[i+1] and handle range is [reorgInfo.StartHandle, reorgInfo.EndHandle]. - // Then the handle range of the rest elements' is [originalStartHandle, originalEndHandle]. - if i == startElementOffsetToResetHandle+1 { - reorgInfo.StartKey, reorgInfo.EndKey = originalStartHandle, originalEndHandle - } - - // Update the element in the reorgInfo for updating the reorg meta below. - reorgInfo.currElement = reorgInfo.elements[i+1] - // Write the reorg info to store so the whole reorganize process can recover from panic. - err := reorgInfo.UpdateReorgMeta(reorgInfo.StartKey, w.sessPool) - logutil.DDLLogger().Info("update column and indexes", - zap.Int64("job ID", reorgInfo.Job.ID), - zap.Stringer("element", reorgInfo.currElement), - zap.String("start key", hex.EncodeToString(reorgInfo.StartKey)), - zap.String("end key", hex.EncodeToString(reorgInfo.EndKey))) - if err != nil { - return errors.Trace(err) - } - err = w.addTableIndex(t, reorgInfo) - if err != nil { - return errors.Trace(err) - } - } - return nil -} - -type updateColumnWorker struct { - *backfillCtx - oldColInfo *model.ColumnInfo - newColInfo *model.ColumnInfo - - // The following attributes are used to reduce memory allocation. 
- rowRecords []*rowRecord - rowDecoder *decoder.RowDecoder - - rowMap map[int64]types.Datum - - checksumNeeded bool -} - -func newUpdateColumnWorker(id int, t table.PhysicalTable, decodeColMap map[int64]decoder.Column, reorgInfo *reorgInfo, jc *JobContext) (*updateColumnWorker, error) { - bCtx, err := newBackfillCtx(id, reorgInfo, reorgInfo.SchemaName, t, jc, "update_col_rate", false) - if err != nil { - return nil, err - } - - sessCtx := bCtx.sessCtx - sessCtx.GetSessionVars().StmtCtx.SetTypeFlags( - sessCtx.GetSessionVars().StmtCtx.TypeFlags(). - WithIgnoreZeroDateErr(!reorgInfo.ReorgMeta.SQLMode.HasStrictMode())) - bCtx.exprCtx = bCtx.sessCtx.GetExprCtx() - bCtx.tblCtx = bCtx.sessCtx.GetTableCtx() - - if !bytes.Equal(reorgInfo.currElement.TypeKey, meta.ColumnElementKey) { - logutil.DDLLogger().Error("Element type for updateColumnWorker incorrect", zap.String("jobQuery", reorgInfo.Query), - zap.Stringer("reorgInfo", reorgInfo)) - return nil, nil - } - var oldCol, newCol *model.ColumnInfo - for _, col := range t.WritableCols() { - if col.ID == reorgInfo.currElement.ID { - newCol = col.ColumnInfo - oldCol = table.FindCol(t.Cols(), getChangingColumnOriginName(newCol)).ColumnInfo - break - } - } - rowDecoder := decoder.NewRowDecoder(t, t.WritableCols(), decodeColMap) - failpoint.Inject("forceRowLevelChecksumOnUpdateColumnBackfill", func() { - orig := variable.EnableRowLevelChecksum.Load() - defer variable.EnableRowLevelChecksum.Store(orig) - variable.EnableRowLevelChecksum.Store(true) - }) - return &updateColumnWorker{ - backfillCtx: bCtx, - oldColInfo: oldCol, - newColInfo: newCol, - rowDecoder: rowDecoder, - rowMap: make(map[int64]types.Datum, len(decodeColMap)), - checksumNeeded: variable.EnableRowLevelChecksum.Load(), - }, nil -} - -func (w *updateColumnWorker) AddMetricInfo(cnt float64) { - w.metricCounter.Add(cnt) -} - -func (*updateColumnWorker) String() string { - return typeUpdateColumnWorker.String() -} - -func (w *updateColumnWorker) GetCtx() *backfillCtx { - return w.backfillCtx -} - -type rowRecord struct { - key []byte // It's used to lock a record. Record it to reduce the encoding time. - vals []byte // It's the record. - warning *terror.Error // It's used to record the cast warning of a record. -} - -// getNextHandleKey gets next handle of entry that we are going to process. -func getNextHandleKey(taskRange reorgBackfillTask, - taskDone bool, lastAccessedHandle kv.Key) (nextHandle kv.Key) { - if !taskDone { - // The task is not done. So we need to pick the last processed entry's handle and add one. - return lastAccessedHandle.Next() - } - - return taskRange.endKey.Next() -} - -func (w *updateColumnWorker) fetchRowColVals(txn kv.Transaction, taskRange reorgBackfillTask) ([]*rowRecord, kv.Key, bool, error) { - w.rowRecords = w.rowRecords[:0] - startTime := time.Now() - - // taskDone means that the added handle is out of taskRange.endHandle. 
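// Editor's sketch, not part of the patch: the resume logic of
// getNextHandleKey/fetchRowColVals in miniature. A batch either stops early
// (resume from the last processed record key, advanced by one) or runs past
// endKey (resume from endKey advanced by one, i.e. the task is done). The
// helper below is illustrative only; kv.Key.Next() is essentially "append a
// zero byte", the smallest key strictly greater than the current one.
func nextStartKeySketch(taskDone bool, lastKey, endKey []byte) []byte {
	next := func(k []byte) []byte { return append(append([]byte{}, k...), 0) }
	if !taskDone {
		// More records remain in [startKey, endKey); continue right after
		// the last record that was turned into a rowRecord.
		return next(lastKey)
	}
	// The whole range has been scanned; skip past it entirely.
	return next(endKey)
}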
- taskDone := false - var lastAccessedHandle kv.Key - oprStartTime := startTime - err := iterateSnapshotKeys(w.jobContext, w.ddlCtx.store, taskRange.priority, taskRange.physicalTable.RecordPrefix(), - txn.StartTS(), taskRange.startKey, taskRange.endKey, func(handle kv.Handle, recordKey kv.Key, rawRow []byte) (bool, error) { - oprEndTime := time.Now() - logSlowOperations(oprEndTime.Sub(oprStartTime), "iterateSnapshotKeys in updateColumnWorker fetchRowColVals", 0) - oprStartTime = oprEndTime - - taskDone = recordKey.Cmp(taskRange.endKey) >= 0 - - if taskDone || len(w.rowRecords) >= w.batchCnt { - return false, nil - } - - if err1 := w.getRowRecord(handle, recordKey, rawRow); err1 != nil { - return false, errors.Trace(err1) - } - lastAccessedHandle = recordKey - if recordKey.Cmp(taskRange.endKey) == 0 { - taskDone = true - return false, nil - } - return true, nil - }) - - if len(w.rowRecords) == 0 { - taskDone = true - } - - logutil.DDLLogger().Debug("txn fetches handle info", - zap.Uint64("txnStartTS", txn.StartTS()), - zap.String("taskRange", taskRange.String()), - zap.Duration("takeTime", time.Since(startTime))) - return w.rowRecords, getNextHandleKey(taskRange, taskDone, lastAccessedHandle), taskDone, errors.Trace(err) -} - -func (w *updateColumnWorker) getRowRecord(handle kv.Handle, recordKey []byte, rawRow []byte) error { - sysTZ := w.loc - _, err := w.rowDecoder.DecodeTheExistedColumnMap(w.exprCtx, handle, rawRow, sysTZ, w.rowMap) - if err != nil { - return errors.Trace(dbterror.ErrCantDecodeRecord.GenWithStackByArgs("column", err)) - } - - if _, ok := w.rowMap[w.newColInfo.ID]; ok { - // The column is already added by update or insert statement, skip it. - w.cleanRowMap() - return nil - } - - var recordWarning *terror.Error - // Since every updateColumnWorker handle their own work individually, we can cache warning in statement context when casting datum. - oldWarn := w.warnings.GetWarnings() - if oldWarn == nil { - oldWarn = []contextutil.SQLWarn{} - } else { - oldWarn = oldWarn[:0] - } - w.warnings.SetWarnings(oldWarn) - val := w.rowMap[w.oldColInfo.ID] - col := w.newColInfo - if val.Kind() == types.KindNull && col.FieldType.GetType() == mysql.TypeTimestamp && mysql.HasNotNullFlag(col.GetFlag()) { - if v, err := expression.GetTimeCurrentTimestamp(w.exprCtx.GetEvalCtx(), col.GetType(), col.GetDecimal()); err == nil { - // convert null value to timestamp should be substituted with current timestamp if NOT_NULL flag is set. 
- w.rowMap[w.oldColInfo.ID] = v - } - } - newColVal, err := table.CastColumnValue(w.exprCtx, w.rowMap[w.oldColInfo.ID], w.newColInfo, false, false) - if err != nil { - return w.reformatErrors(err) - } - warn := w.warnings.GetWarnings() - if len(warn) != 0 { - //nolint:forcetypeassert - recordWarning = errors.Cause(w.reformatErrors(warn[0].Err)).(*terror.Error) - } - - failpoint.Inject("MockReorgTimeoutInOneRegion", func(val failpoint.Value) { - //nolint:forcetypeassert - if val.(bool) { - if handle.IntValue() == 3000 && atomic.CompareAndSwapInt32(&TestCheckReorgTimeout, 0, 1) { - failpoint.Return(errors.Trace(dbterror.ErrWaitReorgTimeout)) - } - } - }) - - w.rowMap[w.newColInfo.ID] = newColVal - _, err = w.rowDecoder.EvalRemainedExprColumnMap(w.exprCtx, w.rowMap) - if err != nil { - return errors.Trace(err) - } - newColumnIDs := make([]int64, 0, len(w.rowMap)) - newRow := make([]types.Datum, 0, len(w.rowMap)) - for colID, val := range w.rowMap { - newColumnIDs = append(newColumnIDs, colID) - newRow = append(newRow, val) - } - rd := w.tblCtx.GetRowEncodingConfig().RowEncoder - ec := w.exprCtx.GetEvalCtx().ErrCtx() - var checksum rowcodec.Checksum - if w.checksumNeeded { - checksum = rowcodec.RawChecksum{Key: recordKey} - } - newRowVal, err := tablecodec.EncodeRow(sysTZ, newRow, newColumnIDs, nil, nil, checksum, rd) - err = ec.HandleError(err) - if err != nil { - return errors.Trace(err) - } - - w.rowRecords = append(w.rowRecords, &rowRecord{key: recordKey, vals: newRowVal, warning: recordWarning}) - w.cleanRowMap() - return nil -} - -// reformatErrors casted error because `convertTo` function couldn't package column name and datum value for some errors. -func (w *updateColumnWorker) reformatErrors(err error) error { - // Since row count is not precious in concurrent reorganization, here we substitute row count with datum value. - if types.ErrTruncated.Equal(err) || types.ErrDataTooLong.Equal(err) { - dStr := datumToStringNoErr(w.rowMap[w.oldColInfo.ID]) - err = types.ErrTruncated.GenWithStack("Data truncated for column '%s', value is '%s'", w.oldColInfo.Name, dStr) - } - - if types.ErrWarnDataOutOfRange.Equal(err) { - dStr := datumToStringNoErr(w.rowMap[w.oldColInfo.ID]) - err = types.ErrWarnDataOutOfRange.GenWithStack("Out of range value for column '%s', the value is '%s'", w.oldColInfo.Name, dStr) - } - return err -} - -func datumToStringNoErr(d types.Datum) string { - if v, err := d.ToString(); err == nil { - return v - } - return fmt.Sprintf("%v", d.GetValue()) -} - -func (w *updateColumnWorker) cleanRowMap() { - for id := range w.rowMap { - delete(w.rowMap, id) - } -} - -// BackfillData will backfill the table record in a transaction. A lock corresponds to a rowKey if the value of rowKey is changed. -func (w *updateColumnWorker) BackfillData(handleRange reorgBackfillTask) (taskCtx backfillTaskContext, errInTxn error) { - oprStartTime := time.Now() - ctx := kv.WithInternalSourceAndTaskType(context.Background(), w.jobContext.ddlJobSourceType(), kvutil.ExplicitTypeDDL) - errInTxn = kv.RunInNewTxn(ctx, w.ddlCtx.store, true, func(_ context.Context, txn kv.Transaction) error { - taskCtx.addedCount = 0 - taskCtx.scanCount = 0 - updateTxnEntrySizeLimitIfNeeded(txn) - - // Because TiCDC do not want this kind of change, - // so we set the lossy DDL reorg txn source to 1 to - // avoid TiCDC to replicate this kind of change. 
- var txnSource uint64 - if val := txn.GetOption(kv.TxnSource); val != nil { - txnSource, _ = val.(uint64) - } - err := kv.SetLossyDDLReorgSource(&txnSource, kv.LossyDDLColumnReorgSource) - if err != nil { - return errors.Trace(err) - } - txn.SetOption(kv.TxnSource, txnSource) - - txn.SetOption(kv.Priority, handleRange.priority) - if tagger := w.GetCtx().getResourceGroupTaggerForTopSQL(handleRange.getJobID()); tagger != nil { - txn.SetOption(kv.ResourceGroupTagger, tagger) - } - txn.SetOption(kv.ResourceGroupName, w.jobContext.resourceGroupName) - - rowRecords, nextKey, taskDone, err := w.fetchRowColVals(txn, handleRange) - if err != nil { - return errors.Trace(err) - } - taskCtx.nextKey = nextKey - taskCtx.done = taskDone - - // Optimize for few warnings! - warningsMap := make(map[errors.ErrorID]*terror.Error, 2) - warningsCountMap := make(map[errors.ErrorID]int64, 2) - for _, rowRecord := range rowRecords { - taskCtx.scanCount++ - - err = txn.Set(rowRecord.key, rowRecord.vals) - if err != nil { - return errors.Trace(err) - } - taskCtx.addedCount++ - if rowRecord.warning != nil { - if _, ok := warningsCountMap[rowRecord.warning.ID()]; ok { - warningsCountMap[rowRecord.warning.ID()]++ - } else { - warningsCountMap[rowRecord.warning.ID()] = 1 - warningsMap[rowRecord.warning.ID()] = rowRecord.warning - } - } - } - - // Collect the warnings. - taskCtx.warnings, taskCtx.warningsCount = warningsMap, warningsCountMap - - return nil - }) - logSlowOperations(time.Since(oprStartTime), "BackfillData", 3000) - - return -} - -func updateChangingObjState(changingCol *model.ColumnInfo, changingIdxs []*model.IndexInfo, schemaState model.SchemaState) { - changingCol.State = schemaState - for _, idx := range changingIdxs { - idx.State = schemaState - } -} - -func checkAndApplyAutoRandomBits(d *ddlCtx, m *meta.Meta, dbInfo *model.DBInfo, tblInfo *model.TableInfo, - oldCol *model.ColumnInfo, newCol *model.ColumnInfo, newAutoRandBits uint64) error { - if newAutoRandBits == 0 { - return nil - } - idAcc := m.GetAutoIDAccessors(dbInfo.ID, tblInfo.ID) - err := checkNewAutoRandomBits(idAcc, oldCol, newCol, newAutoRandBits, tblInfo.AutoRandomRangeBits, tblInfo.SepAutoInc()) - if err != nil { - return err - } - return applyNewAutoRandomBits(d, m, dbInfo, tblInfo, oldCol, newAutoRandBits) -} - -// checkNewAutoRandomBits checks whether the new auto_random bits number can cause overflow. -func checkNewAutoRandomBits(idAccessors meta.AutoIDAccessors, oldCol *model.ColumnInfo, - newCol *model.ColumnInfo, newShardBits, newRangeBits uint64, sepAutoInc bool) error { - shardFmt := autoid.NewShardIDFormat(&newCol.FieldType, newShardBits, newRangeBits) - - idAcc := idAccessors.RandomID() - convertedFromAutoInc := mysql.HasAutoIncrementFlag(oldCol.GetFlag()) - if convertedFromAutoInc { - if sepAutoInc { - idAcc = idAccessors.IncrementID(model.TableInfoVersion5) - } else { - idAcc = idAccessors.RowID() - } - } - // Generate a new auto ID first to prevent concurrent update in DML. - _, err := idAcc.Inc(1) - if err != nil { - return err - } - currentIncBitsVal, err := idAcc.Get() - if err != nil { - return err - } - // Find the max number of available shard bits by - // counting leading zeros in current inc part of auto_random ID. 
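// Editor's note, not part of the patch: a worked example for the
// leading-zero counting described above. If the current incremental part of
// the auto_random ID is 5000 (binary 1001110001000, 13 significant bits),
// then bits.LeadingZeros64(5000) == 51 and usedBits == 64 - 51 == 13; a new
// shard format that leaves fewer than 13 incremental bits would overflow and
// must be rejected. The helper is illustrative and assumes "math/bits".
func usedIncrementalBitsSketch(currentIncVal uint64) uint64 {
	// Bits actually needed to represent the largest incremental ID handed
	// out so far: 64 minus the count of leading zero bits.
	return uint64(64 - bits.LeadingZeros64(currentIncVal))
}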
- usedBits := uint64(64 - bits.LeadingZeros64(uint64(currentIncBitsVal))) - if usedBits > shardFmt.IncrementalBits { - overflowCnt := usedBits - shardFmt.IncrementalBits - errMsg := fmt.Sprintf(autoid.AutoRandomOverflowErrMsg, newShardBits-overflowCnt, newShardBits, oldCol.Name.O) - return dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(errMsg) - } - return nil -} - -func (d *ddlCtx) getAutoIDRequirement() autoid.Requirement { - return &asAutoIDRequirement{ - store: d.store, - autoidCli: d.autoidCli, - } -} - -type asAutoIDRequirement struct { - store kv.Storage - autoidCli *autoid.ClientDiscover -} - -var _ autoid.Requirement = &asAutoIDRequirement{} - -func (r *asAutoIDRequirement) Store() kv.Storage { - return r.store -} - -func (r *asAutoIDRequirement) AutoIDClient() *autoid.ClientDiscover { - return r.autoidCli -} - -// applyNewAutoRandomBits set auto_random bits to TableInfo and -// migrate auto_increment ID to auto_random ID if possible. -func applyNewAutoRandomBits(d *ddlCtx, m *meta.Meta, dbInfo *model.DBInfo, - tblInfo *model.TableInfo, oldCol *model.ColumnInfo, newAutoRandBits uint64) error { - tblInfo.AutoRandomBits = newAutoRandBits - needMigrateFromAutoIncToAutoRand := mysql.HasAutoIncrementFlag(oldCol.GetFlag()) - if !needMigrateFromAutoIncToAutoRand { - return nil - } - autoRandAlloc := autoid.NewAllocatorsFromTblInfo(d.getAutoIDRequirement(), dbInfo.ID, tblInfo).Get(autoid.AutoRandomType) - if autoRandAlloc == nil { - errMsg := fmt.Sprintf(autoid.AutoRandomAllocatorNotFound, dbInfo.Name.O, tblInfo.Name.O) - return dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(errMsg) - } - idAcc := m.GetAutoIDAccessors(dbInfo.ID, tblInfo.ID).RowID() - nextAutoIncID, err := idAcc.Get() - if err != nil { - return errors.Trace(err) - } - err = autoRandAlloc.Rebase(context.Background(), nextAutoIncID, false) - if err != nil { - return errors.Trace(err) - } - if err := idAcc.Del(); err != nil { - return errors.Trace(err) - } - return nil -} - -// checkForNullValue ensure there are no null values of the column of this table. -// `isDataTruncated` indicates whether the new field and the old field type are the same, in order to be compatible with mysql. -func checkForNullValue(ctx context.Context, sctx sessionctx.Context, isDataTruncated bool, schema, table model.CIStr, newCol *model.ColumnInfo, oldCols ...*model.ColumnInfo) error { - needCheckNullValue := false - for _, oldCol := range oldCols { - if oldCol.GetType() != mysql.TypeTimestamp && newCol.GetType() == mysql.TypeTimestamp { - // special case for convert null value of non-timestamp type to timestamp type, null value will be substituted with current timestamp. - continue - } - needCheckNullValue = true - } - if !needCheckNullValue { - return nil - } - var buf strings.Builder - buf.WriteString("select 1 from %n.%n where ") - paramsList := make([]any, 0, 2+len(oldCols)) - paramsList = append(paramsList, schema.L, table.L) - for i, col := range oldCols { - if i == 0 { - buf.WriteString("%n is null") - paramsList = append(paramsList, col.Name.L) - } else { - buf.WriteString(" or %n is null") - paramsList = append(paramsList, col.Name.L) - } - } - buf.WriteString(" limit 1") - //nolint:forcetypeassert - rows, _, err := sctx.GetRestrictedSQLExecutor().ExecRestrictedSQL(ctx, nil, buf.String(), paramsList...) 
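// Editor's note, not part of the patch: the shape of the statement built
// above, for two old columns a and b on table test.t (hypothetical names).
// The template is
//
//	select 1 from %n.%n where %n is null or %n is null limit 1
//
// and, assuming the restricted-SQL executor renders %n as a backquoted
// identifier, the executed query is roughly
//
//	select 1 from `test`.`t` where `a` is null or `b` is null limit 1
//
// A single returned row is enough to fail the NOT NULL change (or to report
// data truncation instead, depending on isDataTruncated).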
- if err != nil { - return errors.Trace(err) - } - rowCount := len(rows) - if rowCount != 0 { - if isDataTruncated { - return dbterror.ErrWarnDataTruncated.GenWithStackByArgs(newCol.Name.L, rowCount) - } - return dbterror.ErrInvalidUseOfNull - } - return nil -} - -func updateColumnDefaultValue(d *ddlCtx, t *meta.Meta, job *model.Job, newCol *model.ColumnInfo, oldColName *model.CIStr) (ver int64, _ error) { - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) - if err != nil { - return ver, errors.Trace(err) - } - - if job.MultiSchemaInfo != nil && job.MultiSchemaInfo.Revertible { - job.MarkNonRevertible() - // Store the mark and enter the next DDL handling loop. - return updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, false) - } - - oldCol := model.FindColumnInfo(tblInfo.Columns, oldColName.L) - if oldCol == nil || oldCol.State != model.StatePublic { - job.State = model.JobStateCancelled - return ver, infoschema.ErrColumnNotExists.GenWithStackByArgs(newCol.Name, tblInfo.Name) - } - - if hasDefaultValue, _, err := checkColumnDefaultValue(newReorgExprCtx(), table.ToColumn(oldCol.Clone()), newCol.DefaultValue); err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } else if !hasDefaultValue { - job.State = model.JobStateCancelled - return ver, dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(newCol.Name) - } - - // The newCol's offset may be the value of the old schema version, so we can't use newCol directly. - oldCol.DefaultValue = newCol.DefaultValue - oldCol.DefaultValueBit = newCol.DefaultValueBit - oldCol.DefaultIsExpr = newCol.DefaultIsExpr - if mysql.HasNoDefaultValueFlag(newCol.GetFlag()) { - oldCol.AddFlag(mysql.NoDefaultValueFlag) - } else { - oldCol.DelFlag(mysql.NoDefaultValueFlag) - err = checkDefaultValue(newReorgExprCtx(), table.ToColumn(oldCol), true) - if err != nil { - job.State = model.JobStateCancelled - return ver, err - } - } - - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - return ver, nil -} - -func isColumnWithIndex(colName string, indices []*model.IndexInfo) bool { - for _, indexInfo := range indices { - for _, col := range indexInfo.Columns { - if col.Name.L == colName { - return true - } - } - } - return false -} - -func isColumnCanDropWithIndex(colName string, indices []*model.IndexInfo) error { - for _, indexInfo := range indices { - if indexInfo.Primary || len(indexInfo.Columns) > 1 { - for _, col := range indexInfo.Columns { - if col.Name.L == colName { - return dbterror.ErrCantDropColWithIndex.GenWithStack("can't drop column %s with composite index covered or Primary Key covered now", colName) - } - } - } - } - return nil -} - -func listIndicesWithColumn(colName string, indices []*model.IndexInfo) []*model.IndexInfo { - ret := make([]*model.IndexInfo, 0) - for _, indexInfo := range indices { - if len(indexInfo.Columns) == 1 && colName == indexInfo.Columns[0].Name.L { - ret = append(ret, indexInfo) - } - } - return ret -} - -// GetColumnForeignKeyInfo returns the wanted foreign key info -func GetColumnForeignKeyInfo(colName string, fkInfos []*model.FKInfo) *model.FKInfo { - for _, fkInfo := range fkInfos { - for _, col := range fkInfo.Cols { - if col.L == colName { - return fkInfo - } - } - } - return nil -} - -// AllocateColumnID allocates next column ID from TableInfo. 
-func AllocateColumnID(tblInfo *model.TableInfo) int64 { - tblInfo.MaxColumnID++ - return tblInfo.MaxColumnID -} - -func checkAddColumnTooManyColumns(colNum int) error { - if uint32(colNum) > atomic.LoadUint32(&config.GetGlobalConfig().TableColumnCountLimit) { - return dbterror.ErrTooManyFields - } - return nil -} - -// modifyColsFromNull2NotNull modifies the type definitions of 'null' to 'not null'. -// Introduce the `mysql.PreventNullInsertFlag` flag to prevent users from inserting or updating null values. -func modifyColsFromNull2NotNull(w *worker, dbInfo *model.DBInfo, tblInfo *model.TableInfo, cols []*model.ColumnInfo, newCol *model.ColumnInfo, isDataTruncated bool) error { - // Get sessionctx from context resource pool. - var sctx sessionctx.Context - sctx, err := w.sessPool.Get() - if err != nil { - return errors.Trace(err) - } - defer w.sessPool.Put(sctx) - - skipCheck := false - failpoint.Inject("skipMockContextDoExec", func(val failpoint.Value) { - //nolint:forcetypeassert - if val.(bool) { - skipCheck = true - } - }) - if !skipCheck { - // If there is a null value inserted, it cannot be modified and needs to be rollback. - err = checkForNullValue(w.ctx, sctx, isDataTruncated, dbInfo.Name, tblInfo.Name, newCol, cols...) - if err != nil { - return errors.Trace(err) - } - } - - // Prevent this field from inserting null values. - for _, col := range cols { - col.AddFlag(mysql.PreventNullInsertFlag) - } - return nil -} - -func generateOriginDefaultValue(col *model.ColumnInfo, ctx sessionctx.Context) (any, error) { - var err error - odValue := col.GetDefaultValue() - if odValue == nil && mysql.HasNotNullFlag(col.GetFlag()) || - // It's for drop column and modify column. - (col.DefaultIsExpr && odValue != strings.ToUpper(ast.CurrentTimestamp) && ctx == nil) { - switch col.GetType() { - // Just use enum field's first element for OriginDefaultValue. - case mysql.TypeEnum: - defEnum, verr := types.ParseEnumValue(col.GetElems(), 1) - if verr != nil { - return nil, errors.Trace(verr) - } - defVal := types.NewCollateMysqlEnumDatum(defEnum, col.GetCollate()) - return defVal.ToString() - default: - zeroVal := table.GetZeroValue(col) - odValue, err = zeroVal.ToString() - if err != nil { - return nil, errors.Trace(err) - } - } - } - - if odValue == strings.ToUpper(ast.CurrentTimestamp) { - var t time.Time - if ctx == nil { - t = time.Now() - } else { - t, _ = expression.GetStmtTimestamp(ctx.GetExprCtx().GetEvalCtx()) - } - if col.GetType() == mysql.TypeTimestamp { - odValue = types.NewTime(types.FromGoTime(t.UTC()), col.GetType(), col.GetDecimal()).String() - } else if col.GetType() == mysql.TypeDatetime { - odValue = types.NewTime(types.FromGoTime(t), col.GetType(), col.GetDecimal()).String() - } - return odValue, nil - } - - if col.DefaultIsExpr && ctx != nil { - valStr, ok := odValue.(string) - if !ok { - return nil, dbterror.ErrDefValGeneratedNamedFunctionIsNotAllowed.GenWithStackByArgs(col.Name.String()) - } - oldValue := strings.ToLower(valStr) - // It's checked in getFuncCallDefaultValue. 
- if !strings.Contains(oldValue, fmt.Sprintf("%s(%s(),", ast.DateFormat, ast.Now)) && - !strings.Contains(oldValue, ast.StrToDate) { - return nil, errors.Trace(dbterror.ErrBinlogUnsafeSystemFunction) - } - - defVal, err := table.GetColDefaultValue(ctx.GetExprCtx(), col) - if err != nil { - return nil, errors.Trace(err) - } - odValue, err = defVal.ToString() - if err != nil { - return nil, errors.Trace(err) - } - } - return odValue, nil -} - -func indexInfoContains(idxID int64, idxInfos []*model.IndexInfo) bool { - for _, idxInfo := range idxInfos { - if idxID == idxInfo.ID { - return true - } - } - return false -} - -func indexInfosToIDList(idxInfos []*model.IndexInfo) []int64 { - ids := make([]int64, 0, len(idxInfos)) - for _, idxInfo := range idxInfos { - ids = append(ids, idxInfo.ID) - } - return ids -} - -func genChangingColumnUniqueName(tblInfo *model.TableInfo, oldCol *model.ColumnInfo) string { - suffix := 0 - newColumnNamePrefix := fmt.Sprintf("%s%s", changingColumnPrefix, oldCol.Name.O) - newColumnLowerName := fmt.Sprintf("%s_%d", strings.ToLower(newColumnNamePrefix), suffix) - // Check whether the new column name is used. - columnNameMap := make(map[string]bool, len(tblInfo.Columns)) - for _, col := range tblInfo.Columns { - columnNameMap[col.Name.L] = true - } - for columnNameMap[newColumnLowerName] { - suffix++ - newColumnLowerName = fmt.Sprintf("%s_%d", strings.ToLower(newColumnNamePrefix), suffix) - } - return fmt.Sprintf("%s_%d", newColumnNamePrefix, suffix) -} - -func genChangingIndexUniqueName(tblInfo *model.TableInfo, idxInfo *model.IndexInfo) string { - suffix := 0 - newIndexNamePrefix := fmt.Sprintf("%s%s", changingIndexPrefix, idxInfo.Name.O) - newIndexLowerName := fmt.Sprintf("%s_%d", strings.ToLower(newIndexNamePrefix), suffix) - // Check whether the new index name is used. - indexNameMap := make(map[string]bool, len(tblInfo.Indices)) - for _, idx := range tblInfo.Indices { - indexNameMap[idx.Name.L] = true - } - for indexNameMap[newIndexLowerName] { - suffix++ - newIndexLowerName = fmt.Sprintf("%s_%d", strings.ToLower(newIndexNamePrefix), suffix) - } - return fmt.Sprintf("%s_%d", newIndexNamePrefix, suffix) -} - -func getChangingIndexOriginName(changingIdx *model.IndexInfo) string { - idxName := strings.TrimPrefix(changingIdx.Name.O, changingIndexPrefix) - // Since the unique idxName may contain the suffix number (indexName_num), better trim the suffix. 
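// Editor's sketch, not part of the patch: the naming round trip behind the
// gen*UniqueName/get*OriginName helpers. Assuming the changing-column prefix
// is "_Col$_" (the index prefix "_Idx$_" is mentioned in replaceOldIndexes
// above; the column value is an assumption here),
// genChangingColumnUniqueName("my_col") yields "_Col$_my_col_0", or
// "_Col$_my_col_1" if that name is already taken, and the origin name is
// recovered by trimming the prefix and then the trailing "_<n>". The helper
// assumes "strings" and is illustrative only.
func originNameSketch(changing, prefix string) string {
	name := strings.TrimPrefix(changing, prefix) // "my_col_0"
	if i := strings.LastIndex(name, "_"); i >= 0 {
		return name[:i] // "my_col"
	}
	return name
}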
- var pos int - if pos = strings.LastIndex(idxName, "_"); pos == -1 { - return idxName - } - return idxName[:pos] -} - -func getChangingColumnOriginName(changingColumn *model.ColumnInfo) string { - columnName := strings.TrimPrefix(changingColumn.Name.O, changingColumnPrefix) - var pos int - if pos = strings.LastIndex(columnName, "_"); pos == -1 { - return columnName - } - return columnName[:pos] -} - -func getExpressionIndexOriginName(expressionIdx *model.ColumnInfo) string { - columnName := strings.TrimPrefix(expressionIdx.Name.O, expressionIndexPrefix+"_") - var pos int - if pos = strings.LastIndex(columnName, "_"); pos == -1 { - return columnName - } - return columnName[:pos] -} diff --git a/pkg/ddl/constraint.go b/pkg/ddl/constraint.go index 01364e895e641..572c66439889a 100644 --- a/pkg/ddl/constraint.go +++ b/pkg/ddl/constraint.go @@ -37,11 +37,11 @@ func (w *worker) onAddCheckConstraint(d *ddlCtx, t *meta.Meta, job *model.Job) ( return rollingBackAddConstraint(d, t, job) } - if val, _err_ := failpoint.Eval(_curpkg_("errorBeforeDecodeArgs")); _err_ == nil { + failpoint.Inject("errorBeforeDecodeArgs", func(val failpoint.Value) { if val.(bool) { - return ver, errors.New("occur an error before decode args") + failpoint.Return(ver, errors.New("occur an error before decode args")) } - } + }) dbInfo, tblInfo, constraintInfoInMeta, constraintInfoInJob, err := checkAddCheckConstraint(t, job) if err != nil { @@ -355,11 +355,11 @@ func findDependentColsInExpr(expr ast.ExprNode) map[string]struct{} { func (w *worker) verifyRemainRecordsForCheckConstraint(dbInfo *model.DBInfo, tableInfo *model.TableInfo, constr *model.ConstraintInfo) error { // Inject a fail-point to skip the remaining records check. - if val, _err_ := failpoint.Eval(_curpkg_("mockVerifyRemainDataSuccess")); _err_ == nil { + failpoint.Inject("mockVerifyRemainDataSuccess", func(val failpoint.Value) { if val.(bool) { - return nil + failpoint.Return(nil) } - } + }) // Get sessionctx from ddl context resource pool in ddl worker. var sctx sessionctx.Context sctx, err := w.sessPool.Get() diff --git a/pkg/ddl/constraint.go__failpoint_stash__ b/pkg/ddl/constraint.go__failpoint_stash__ deleted file mode 100644 index 572c66439889a..0000000000000 --- a/pkg/ddl/constraint.go__failpoint_stash__ +++ /dev/null @@ -1,432 +0,0 @@ -// Copyright 2023-2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
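// Editor's note, not part of the patch: the constraint.go hunks above (and
// the create_table.go hunks further down) revert a committed failpoint-ctl
// rewrite -- the expanded
// `if val, _err_ := failpoint.Eval(_curpkg_("name")); _err_ == nil { ... }`
// form goes back to the failpoint.Inject marker, and the generated
// __failpoint_stash__ backup files are deleted. A minimal sketch of the
// marker pattern, assuming github.com/pingcap/failpoint and "errors"; the
// point name and enclosing function are illustrative only.
func doWorkSketch() error {
	failpoint.Inject("mockPoint", func(val failpoint.Value) {
		if val.(bool) {
			// Rewritten by failpoint-ctl into a real return statement;
			// a no-op marker otherwise.
			failpoint.Return(errors.New("injected error"))
		}
	})
	return nil
}
// Tests typically arm such a point with
// failpoint.Enable("<package path>/mockPoint", "return(true)") and disarm it
// again with failpoint.Disable.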
- -package ddl - -import ( - "fmt" - "strings" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/format" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/util/dbterror" -) - -func (w *worker) onAddCheckConstraint(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { - // Handle the rolling back job. - if job.IsRollingback() { - return rollingBackAddConstraint(d, t, job) - } - - failpoint.Inject("errorBeforeDecodeArgs", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(ver, errors.New("occur an error before decode args")) - } - }) - - dbInfo, tblInfo, constraintInfoInMeta, constraintInfoInJob, err := checkAddCheckConstraint(t, job) - if err != nil { - return ver, errors.Trace(err) - } - if constraintInfoInMeta == nil { - // It's first time to run add constraint job, so there is no constraint info in meta. - // Use the raw constraint info from job directly and modify table info here. - constraintInfoInJob.ID = allocateConstraintID(tblInfo) - // Reset constraint name according to real-time constraints name at this point. - constrNames := map[string]bool{} - for _, constr := range tblInfo.Constraints { - constrNames[constr.Name.L] = true - } - setNameForConstraintInfo(tblInfo.Name.L, constrNames, []*model.ConstraintInfo{constraintInfoInJob}) - // Double check the constraint dependency. - existedColsMap := make(map[string]struct{}) - cols := tblInfo.Columns - for _, v := range cols { - if v.State == model.StatePublic { - existedColsMap[v.Name.L] = struct{}{} - } - } - dependedCols := constraintInfoInJob.ConstraintCols - for _, k := range dependedCols { - if _, ok := existedColsMap[k.L]; !ok { - // The table constraint depended on a non-existed column. - return ver, dbterror.ErrTableCheckConstraintReferUnknown.GenWithStackByArgs(constraintInfoInJob.Name, k) - } - } - - tblInfo.Constraints = append(tblInfo.Constraints, constraintInfoInJob) - constraintInfoInMeta = constraintInfoInJob - } - - // If not enforced, add it directly. - if !constraintInfoInMeta.Enforced { - constraintInfoInMeta.State = model.StatePublic - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - // Finish this job. - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - return ver, nil - } - - switch constraintInfoInMeta.State { - case model.StateNone: - job.SchemaState = model.StateWriteOnly - constraintInfoInMeta.State = model.StateWriteOnly - ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) - case model.StateWriteOnly: - job.SchemaState = model.StateWriteReorganization - constraintInfoInMeta.State = model.StateWriteReorganization - ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) - case model.StateWriteReorganization: - err = w.verifyRemainRecordsForCheckConstraint(dbInfo, tblInfo, constraintInfoInMeta) - if err != nil { - if dbterror.ErrCheckConstraintIsViolated.Equal(err) { - job.State = model.JobStateRollingback - } - return ver, errors.Trace(err) - } - constraintInfoInMeta.State = model.StatePublic - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - // Finish this job. 
- job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - default: - err = dbterror.ErrInvalidDDLState.GenWithStackByArgs("constraint", constraintInfoInMeta.State) - } - - return ver, errors.Trace(err) -} - -func checkAddCheckConstraint(t *meta.Meta, job *model.Job) (*model.DBInfo, *model.TableInfo, *model.ConstraintInfo, *model.ConstraintInfo, error) { - schemaID := job.SchemaID - dbInfo, err := t.GetDatabase(job.SchemaID) - if err != nil { - return nil, nil, nil, nil, errors.Trace(err) - } - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) - if err != nil { - return nil, nil, nil, nil, errors.Trace(err) - } - constraintInfo1 := &model.ConstraintInfo{} - err = job.DecodeArgs(constraintInfo1) - if err != nil { - job.State = model.JobStateCancelled - return nil, nil, nil, nil, errors.Trace(err) - } - // do the double-check with constraint existence. - constraintInfo2 := tblInfo.FindConstraintInfoByName(constraintInfo1.Name.L) - if constraintInfo2 != nil { - if constraintInfo2.State == model.StatePublic { - // We already have a constraint with the same constraint name. - job.State = model.JobStateCancelled - return nil, nil, nil, nil, infoschema.ErrColumnExists.GenWithStackByArgs(constraintInfo1.Name) - } - // if not, that means constraint was in intermediate state. - } - - err = checkConstraintNamesNotExists(t, schemaID, []*model.ConstraintInfo{constraintInfo1}) - if err != nil { - job.State = model.JobStateCancelled - return nil, nil, nil, nil, err - } - - return dbInfo, tblInfo, constraintInfo2, constraintInfo1, nil -} - -// onDropCheckConstraint can be called from two case: -// 1: rollback in add constraint.(in rollback function the job.args will be changed) -// 2: user drop constraint ddl. -func onDropCheckConstraint(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - tblInfo, constraintInfo, err := checkDropCheckConstraint(t, job) - if err != nil { - return ver, errors.Trace(err) - } - - switch constraintInfo.State { - case model.StatePublic: - job.SchemaState = model.StateWriteOnly - constraintInfo.State = model.StateWriteOnly - ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) - case model.StateWriteOnly: - // write only state constraint will still take effect to check the newly inserted data. - // So the dependent column shouldn't be dropped even in this intermediate state. - constraintInfo.State = model.StateNone - // remove the constraint from tableInfo. - for i, constr := range tblInfo.Constraints { - if constr.Name.L == constraintInfo.Name.L { - tblInfo.Constraints = append(tblInfo.Constraints[0:i], tblInfo.Constraints[i+1:]...) - } - } - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - job.FinishTableJob(model.JobStateDone, model.StateNone, ver, tblInfo) - default: - err = dbterror.ErrInvalidDDLJob.GenWithStackByArgs("constraint", tblInfo.State) - } - return ver, errors.Trace(err) -} - -func checkDropCheckConstraint(t *meta.Meta, job *model.Job) (*model.TableInfo, *model.ConstraintInfo, error) { - schemaID := job.SchemaID - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) - if err != nil { - return nil, nil, errors.Trace(err) - } - - var constrName model.CIStr - err = job.DecodeArgs(&constrName) - if err != nil { - job.State = model.JobStateCancelled - return nil, nil, errors.Trace(err) - } - - // double check with constraint existence. 
- constraintInfo := tblInfo.FindConstraintInfoByName(constrName.L) - if constraintInfo == nil { - job.State = model.JobStateCancelled - return nil, nil, dbterror.ErrConstraintNotFound.GenWithStackByArgs(constrName) - } - return tblInfo, constraintInfo, nil -} - -func (w *worker) onAlterCheckConstraint(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { - dbInfo, tblInfo, constraintInfo, enforced, err := checkAlterCheckConstraint(t, job) - if err != nil { - return ver, errors.Trace(err) - } - - if job.IsRollingback() { - return rollingBackAlterConstraint(d, t, job) - } - - // Current State is desired. - if constraintInfo.State == model.StatePublic && constraintInfo.Enforced == enforced { - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - return - } - - // enforced will fetch table data and check the constraint. - if enforced { - switch constraintInfo.State { - case model.StatePublic: - job.SchemaState = model.StateWriteReorganization - constraintInfo.State = model.StateWriteReorganization - constraintInfo.Enforced = enforced - ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) - case model.StateWriteReorganization: - job.SchemaState = model.StateWriteOnly - constraintInfo.State = model.StateWriteOnly - ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) - case model.StateWriteOnly: - err = w.verifyRemainRecordsForCheckConstraint(dbInfo, tblInfo, constraintInfo) - if err != nil { - if dbterror.ErrCheckConstraintIsViolated.Equal(err) { - job.State = model.JobStateRollingback - } - return ver, errors.Trace(err) - } - constraintInfo.State = model.StatePublic - ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - } - } else { - constraintInfo.Enforced = enforced - ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) - if err != nil { - // update version and tableInfo error will cause retry. - return ver, errors.Trace(err) - } - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - } - return ver, err -} - -func checkAlterCheckConstraint(t *meta.Meta, job *model.Job) (*model.DBInfo, *model.TableInfo, *model.ConstraintInfo, bool, error) { - schemaID := job.SchemaID - dbInfo, err := t.GetDatabase(job.SchemaID) - if err != nil { - return nil, nil, nil, false, errors.Trace(err) - } - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) - if err != nil { - return nil, nil, nil, false, errors.Trace(err) - } - - var ( - enforced bool - constrName model.CIStr - ) - err = job.DecodeArgs(&constrName, &enforced) - if err != nil { - job.State = model.JobStateCancelled - return nil, nil, nil, false, errors.Trace(err) - } - // do the double check with constraint existence. 
- constraintInfo := tblInfo.FindConstraintInfoByName(constrName.L) - if constraintInfo == nil { - job.State = model.JobStateCancelled - return nil, nil, nil, false, dbterror.ErrConstraintNotFound.GenWithStackByArgs(constrName) - } - return dbInfo, tblInfo, constraintInfo, enforced, nil -} - -func allocateConstraintID(tblInfo *model.TableInfo) int64 { - tblInfo.MaxConstraintID++ - return tblInfo.MaxConstraintID -} - -func buildConstraintInfo(tblInfo *model.TableInfo, dependedCols []model.CIStr, constr *ast.Constraint, state model.SchemaState) (*model.ConstraintInfo, error) { - constraintName := model.NewCIStr(constr.Name) - if err := checkTooLongConstraint(constraintName); err != nil { - return nil, errors.Trace(err) - } - - // Restore check constraint expression to string. - var sb strings.Builder - restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | - format.RestoreSpacesAroundBinaryOperation | format.RestoreWithoutSchemaName | format.RestoreWithoutTableName - restoreCtx := format.NewRestoreCtx(restoreFlags, &sb) - - sb.Reset() - err := constr.Expr.Restore(restoreCtx) - if err != nil { - return nil, errors.Trace(err) - } - - // Create constraint info. - constraintInfo := &model.ConstraintInfo{ - Name: constraintName, - Table: tblInfo.Name, - ConstraintCols: dependedCols, - ExprString: sb.String(), - Enforced: constr.Enforced, - InColumn: constr.InColumn, - State: state, - } - - return constraintInfo, nil -} - -func checkTooLongConstraint(constr model.CIStr) error { - if len(constr.L) > mysql.MaxConstraintIdentifierLen { - return dbterror.ErrTooLongIdent.GenWithStackByArgs(constr) - } - return nil -} - -// findDependentColsInExpr returns a set of string, which indicates -// the names of the columns that are dependent by exprNode. -func findDependentColsInExpr(expr ast.ExprNode) map[string]struct{} { - colNames := FindColumnNamesInExpr(expr) - colsMap := make(map[string]struct{}, len(colNames)) - for _, depCol := range colNames { - colsMap[depCol.Name.L] = struct{}{} - } - return colsMap -} - -func (w *worker) verifyRemainRecordsForCheckConstraint(dbInfo *model.DBInfo, tableInfo *model.TableInfo, constr *model.ConstraintInfo) error { - // Inject a fail-point to skip the remaining records check. - failpoint.Inject("mockVerifyRemainDataSuccess", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(nil) - } - }) - // Get sessionctx from ddl context resource pool in ddl worker. - var sctx sessionctx.Context - sctx, err := w.sessPool.Get() - if err != nil { - return errors.Trace(err) - } - defer w.sessPool.Put(sctx) - - // If there is any row can't pass the check expression, the add constraint action will error. - // It's no need to construct expression node out and pull the chunk rows through it. Here we - // can let the check expression restored string as the filter in where clause directly. - // Prepare internal SQL to fetch data from physical table under this filter. 
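// Editor's note, not part of the patch: a concrete instance of the filter
// trick described above. For a constraint defined as CHECK (a > 0) on table
// test.t (hypothetical names), constr.ExprString is restored as "`a` > 0",
// so the statement built below is roughly
//
//	select 1 from `test`.`t` where not `a` > 0 limit 1
//
// Any row it returns violates the constraint, and the job fails with
// ErrCheckConstraintIsViolated instead of decoding and re-evaluating the
// expression row by row inside TiDB.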
- sql := fmt.Sprintf("select 1 from `%s`.`%s` where not %s limit 1", dbInfo.Name.L, tableInfo.Name.L, constr.ExprString) - ctx := kv.WithInternalSourceType(w.ctx, kv.InternalTxnDDL) - rows, _, err := sctx.GetRestrictedSQLExecutor().ExecRestrictedSQL(ctx, nil, sql) - if err != nil { - return errors.Trace(err) - } - rowCount := len(rows) - if rowCount != 0 { - return dbterror.ErrCheckConstraintIsViolated.GenWithStackByArgs(constr.Name.L) - } - return nil -} - -func setNameForConstraintInfo(tableLowerName string, namesMap map[string]bool, infos []*model.ConstraintInfo) { - cnt := 1 - constraintPrefix := tableLowerName + "_chk_" - for _, constrInfo := range infos { - if constrInfo.Name.O == "" { - constrName := fmt.Sprintf("%s%d", constraintPrefix, cnt) - for { - // loop until find constrName that haven't been used. - if !namesMap[constrName] { - namesMap[constrName] = true - break - } - cnt++ - constrName = fmt.Sprintf("%s%d", constraintPrefix, cnt) - } - constrInfo.Name = model.NewCIStr(constrName) - } - } -} - -// IsColumnDroppableWithCheckConstraint check whether the column in check-constraint whose dependent col is more than 1 -func IsColumnDroppableWithCheckConstraint(col model.CIStr, tblInfo *model.TableInfo) error { - for _, cons := range tblInfo.Constraints { - if len(cons.ConstraintCols) > 1 { - for _, colName := range cons.ConstraintCols { - if colName.L == col.L { - return dbterror.ErrCantDropColWithCheckConstraint.GenWithStackByArgs(cons.Name, col) - } - } - } - } - return nil -} - -// IsColumnRenameableWithCheckConstraint check whether the column is referenced in check-constraint -func IsColumnRenameableWithCheckConstraint(col model.CIStr, tblInfo *model.TableInfo) error { - for _, cons := range tblInfo.Constraints { - for _, colName := range cons.ConstraintCols { - if colName.L == col.L { - return dbterror.ErrCantDropColWithCheckConstraint.GenWithStackByArgs(cons.Name, col) - } - } - } - return nil -} diff --git a/pkg/ddl/create_table.go b/pkg/ddl/create_table.go index cc6ac85b7ddd7..68d686ddc35dc 100644 --- a/pkg/ddl/create_table.go +++ b/pkg/ddl/create_table.go @@ -96,11 +96,11 @@ func createTable(d *ddlCtx, t *meta.Meta, job *model.Job, fkCheck bool) (*model. return tbInfo, errors.Trace(err) } - if val, _err_ := failpoint.Eval(_curpkg_("checkOwnerCheckAllVersionsWaitTime")); _err_ == nil { + failpoint.Inject("checkOwnerCheckAllVersionsWaitTime", func(val failpoint.Value) { if val.(bool) { - return tbInfo, errors.New("mock create table error") + failpoint.Return(tbInfo, errors.New("mock create table error")) } - } + }) // build table & partition bundles if any. if err = checkAllTablePlacementPoliciesExistAndCancelNonExistJob(t, job, tbInfo); err != nil { @@ -149,11 +149,11 @@ func createTable(d *ddlCtx, t *meta.Meta, job *model.Job, fkCheck bool) (*model. 
} func onCreateTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - if val, _err_ := failpoint.Eval(_curpkg_("mockExceedErrorLimit")); _err_ == nil { + failpoint.Inject("mockExceedErrorLimit", func(val failpoint.Value) { if val.(bool) { - return ver, errors.New("mock do job error") + failpoint.Return(ver, errors.New("mock do job error")) } - } + }) // just decode, createTable will use it as Args[0] tbInfo := &model.TableInfo{} @@ -306,7 +306,7 @@ func onCreateView(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) if infoschema.ErrTableNotExists.Equal(err) { err = nil } - failpoint.Call(_curpkg_("onDDLCreateView"), job) + failpoint.InjectCall("onDDLCreateView", job) if err != nil { if infoschema.ErrDatabaseNotExists.Equal(err) { job.State = model.JobStateCancelled diff --git a/pkg/ddl/create_table.go__failpoint_stash__ b/pkg/ddl/create_table.go__failpoint_stash__ deleted file mode 100644 index 68d686ddc35dc..0000000000000 --- a/pkg/ddl/create_table.go__failpoint_stash__ +++ /dev/null @@ -1,1527 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ddl - -import ( - "context" - "fmt" - "math" - "strings" - "sync/atomic" - "unicode/utf8" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/ddl/logutil" - "github.com/pingcap/tidb/pkg/ddl/placement" - "github.com/pingcap/tidb/pkg/domain/infosync" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/meta/autoid" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/format" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - field_types "github.com/pingcap/tidb/pkg/parser/types" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - statsutil "github.com/pingcap/tidb/pkg/statistics/handle/util" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/table/tables" - "github.com/pingcap/tidb/pkg/types" - driver "github.com/pingcap/tidb/pkg/types/parser_driver" - "github.com/pingcap/tidb/pkg/util/dbterror" - "github.com/pingcap/tidb/pkg/util/mock" - "github.com/pingcap/tidb/pkg/util/set" - "go.uber.org/zap" -) - -// DANGER: it is an internal function used by onCreateTable and onCreateTables, for reusing code. Be careful. -// 1. it expects the argument of job has been deserialized. -// 2. it won't call updateSchemaVersion, FinishTableJob and asyncNotifyEvent. 
-func createTable(d *ddlCtx, t *meta.Meta, job *model.Job, fkCheck bool) (*model.TableInfo, error) { - schemaID := job.SchemaID - tbInfo := job.Args[0].(*model.TableInfo) - - tbInfo.State = model.StateNone - err := checkTableNotExists(d, schemaID, tbInfo.Name.L) - if err != nil { - if infoschema.ErrDatabaseNotExists.Equal(err) || infoschema.ErrTableExists.Equal(err) { - job.State = model.JobStateCancelled - } - return tbInfo, errors.Trace(err) - } - - err = checkConstraintNamesNotExists(t, schemaID, tbInfo.Constraints) - if err != nil { - if infoschema.ErrCheckConstraintDupName.Equal(err) { - job.State = model.JobStateCancelled - } - return tbInfo, errors.Trace(err) - } - - retryable, err := checkTableForeignKeyValidInOwner(d, t, job, tbInfo, fkCheck) - if err != nil { - if !retryable { - job.State = model.JobStateCancelled - } - return tbInfo, errors.Trace(err) - } - // Allocate foreign key ID. - for _, fkInfo := range tbInfo.ForeignKeys { - fkInfo.ID = allocateFKIndexID(tbInfo) - fkInfo.State = model.StatePublic - } - switch tbInfo.State { - case model.StateNone: - // none -> public - tbInfo.State = model.StatePublic - tbInfo.UpdateTS = t.StartTS - err = createTableOrViewWithCheck(t, job, schemaID, tbInfo) - if err != nil { - return tbInfo, errors.Trace(err) - } - - failpoint.Inject("checkOwnerCheckAllVersionsWaitTime", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(tbInfo, errors.New("mock create table error")) - } - }) - - // build table & partition bundles if any. - if err = checkAllTablePlacementPoliciesExistAndCancelNonExistJob(t, job, tbInfo); err != nil { - return tbInfo, errors.Trace(err) - } - - if tbInfo.TiFlashReplica != nil { - replicaInfo := tbInfo.TiFlashReplica - if pi := tbInfo.GetPartitionInfo(); pi != nil { - logutil.DDLLogger().Info("Set TiFlash replica pd rule for partitioned table when creating", zap.Int64("tableID", tbInfo.ID)) - if e := infosync.ConfigureTiFlashPDForPartitions(false, &pi.Definitions, replicaInfo.Count, &replicaInfo.LocationLabels, tbInfo.ID); e != nil { - job.State = model.JobStateCancelled - return tbInfo, errors.Trace(e) - } - // Partitions that in adding mid-state. They have high priorities, so we should set accordingly pd rules. - if e := infosync.ConfigureTiFlashPDForPartitions(true, &pi.AddingDefinitions, replicaInfo.Count, &replicaInfo.LocationLabels, tbInfo.ID); e != nil { - job.State = model.JobStateCancelled - return tbInfo, errors.Trace(e) - } - } else { - logutil.DDLLogger().Info("Set TiFlash replica pd rule when creating", zap.Int64("tableID", tbInfo.ID)) - if e := infosync.ConfigureTiFlashPDForTable(tbInfo.ID, replicaInfo.Count, &replicaInfo.LocationLabels); e != nil { - job.State = model.JobStateCancelled - return tbInfo, errors.Trace(e) - } - } - } - - bundles, err := placement.NewFullTableBundles(t, tbInfo) - if err != nil { - job.State = model.JobStateCancelled - return tbInfo, errors.Trace(err) - } - - // Send the placement bundle to PD. 
- err = infosync.PutRuleBundlesWithDefaultRetry(context.TODO(), bundles) - if err != nil { - job.State = model.JobStateCancelled - return tbInfo, errors.Wrapf(err, "failed to notify PD the placement rules") - } - - return tbInfo, nil - default: - return tbInfo, dbterror.ErrInvalidDDLState.GenWithStackByArgs("table", tbInfo.State) - } -} - -func onCreateTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - failpoint.Inject("mockExceedErrorLimit", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(ver, errors.New("mock do job error")) - } - }) - - // just decode, createTable will use it as Args[0] - tbInfo := &model.TableInfo{} - fkCheck := false - if err := job.DecodeArgs(tbInfo, &fkCheck); err != nil { - // Invalid arguments, cancel this job. - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - if len(tbInfo.ForeignKeys) > 0 { - return createTableWithForeignKeys(d, t, job, tbInfo, fkCheck) - } - - tbInfo, err := createTable(d, t, job, fkCheck) - if err != nil { - return ver, errors.Trace(err) - } - - ver, err = updateSchemaVersion(d, t, job) - if err != nil { - return ver, errors.Trace(err) - } - - // Finish this job. - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tbInfo) - createTableEvent := statsutil.NewCreateTableEvent( - job.SchemaID, - tbInfo, - ) - asyncNotifyEvent(d, createTableEvent) - return ver, errors.Trace(err) -} - -func createTableWithForeignKeys(d *ddlCtx, t *meta.Meta, job *model.Job, tbInfo *model.TableInfo, fkCheck bool) (ver int64, err error) { - switch tbInfo.State { - case model.StateNone, model.StatePublic: - // create table in non-public or public state. The function `createTable` will always reset - // the `tbInfo.State` with `model.StateNone`, so it's fine to just call the `createTable` with - // public state. - // when `br` restores table, the state of `tbInfo` will be public. - tbInfo, err = createTable(d, t, job, fkCheck) - if err != nil { - return ver, errors.Trace(err) - } - tbInfo.State = model.StateWriteOnly - ver, err = updateVersionAndTableInfo(d, t, job, tbInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - job.SchemaState = model.StateWriteOnly - case model.StateWriteOnly: - tbInfo.State = model.StatePublic - ver, err = updateVersionAndTableInfo(d, t, job, tbInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tbInfo) - createTableEvent := statsutil.NewCreateTableEvent( - job.SchemaID, - tbInfo, - ) - asyncNotifyEvent(d, createTableEvent) - return ver, nil - default: - return ver, errors.Trace(dbterror.ErrInvalidDDLJob.GenWithStackByArgs("table", tbInfo.State)) - } - return ver, errors.Trace(err) -} - -func onCreateTables(d *ddlCtx, t *meta.Meta, job *model.Job) (int64, error) { - var ver int64 - - var args []*model.TableInfo - fkCheck := false - err := job.DecodeArgs(&args, &fkCheck) - if err != nil { - // Invalid arguments, cancel this job. 
- job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - // We don't construct jobs for every table, but only tableInfo - // The following loop creates a stub job for every table - // - // it clones a stub job from the ActionCreateTables job - stubJob := job.Clone() - stubJob.Args = make([]any, 1) - for i := range args { - stubJob.TableID = args[i].ID - stubJob.Args[0] = args[i] - if args[i].Sequence != nil { - err := createSequenceWithCheck(t, stubJob, args[i]) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - } else { - tbInfo, err := createTable(d, t, stubJob, fkCheck) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - args[i] = tbInfo - } - } - - ver, err = updateSchemaVersion(d, t, job) - if err != nil { - return ver, errors.Trace(err) - } - - job.State = model.JobStateDone - job.SchemaState = model.StatePublic - job.BinlogInfo.SetTableInfos(ver, args) - - for i := range args { - createTableEvent := statsutil.NewCreateTableEvent( - job.SchemaID, - args[i], - ) - asyncNotifyEvent(d, createTableEvent) - } - - return ver, errors.Trace(err) -} - -func createTableOrViewWithCheck(t *meta.Meta, job *model.Job, schemaID int64, tbInfo *model.TableInfo) error { - err := checkTableInfoValid(tbInfo) - if err != nil { - job.State = model.JobStateCancelled - return errors.Trace(err) - } - return t.CreateTableOrView(schemaID, tbInfo) -} - -func onCreateView(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - schemaID := job.SchemaID - tbInfo := &model.TableInfo{} - var orReplace bool - var _placeholder int64 // oldTblInfoID - if err := job.DecodeArgs(tbInfo, &orReplace, &_placeholder); err != nil { - // Invalid arguments, cancel this job. - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - tbInfo.State = model.StateNone - - oldTableID, err := findTableIDByName(d, t, schemaID, tbInfo.Name.L) - if infoschema.ErrTableNotExists.Equal(err) { - err = nil - } - failpoint.InjectCall("onDDLCreateView", job) - if err != nil { - if infoschema.ErrDatabaseNotExists.Equal(err) { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } else if !infoschema.ErrTableExists.Equal(err) { - return ver, errors.Trace(err) - } - if !orReplace { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - } - ver, err = updateSchemaVersion(d, t, job) - if err != nil { - return ver, errors.Trace(err) - } - switch tbInfo.State { - case model.StateNone: - // none -> public - tbInfo.State = model.StatePublic - tbInfo.UpdateTS = t.StartTS - if oldTableID > 0 && orReplace { - err = t.DropTableOrView(schemaID, oldTableID) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - err = t.GetAutoIDAccessors(schemaID, oldTableID).Del() - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - } - err = createTableOrViewWithCheck(t, job, schemaID, tbInfo) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - // Finish this job. - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tbInfo) - return ver, nil - default: - return ver, dbterror.ErrInvalidDDLState.GenWithStackByArgs("table", tbInfo.State) - } -} - -func findTableIDByName(d *ddlCtx, t *meta.Meta, schemaID int64, tableName string) (int64, error) { - // Try to use memory schema info to check first. 
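// Editor's sketch, not part of the patch: the cache-or-store decision made
// just below, in miniature. The in-memory infoschema is trusted only when its
// version matches the schema version read from meta; otherwise the table ID
// is looked up from storage. Parameter names are illustrative only.
func lookupTableIDSketch(cacheVer, currVer int64, fromCache, fromStore func() (int64, error)) (int64, error) {
	if fromCache != nil && cacheVer == currVer {
		// Fast path: the cached schema reflects the latest version.
		return fromCache()
	}
	// Slow path: list tables from the meta store.
	return fromStore()
}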
- currVer, err := t.GetSchemaVersion() - if err != nil { - return 0, err - } - is := d.infoCache.GetLatest() - if is != nil && is.SchemaMetaVersion() == currVer { - return findTableIDFromInfoSchema(is, schemaID, tableName) - } - - return findTableIDFromStore(t, schemaID, tableName) -} - -func findTableIDFromInfoSchema(is infoschema.InfoSchema, schemaID int64, tableName string) (int64, error) { - schema, ok := is.SchemaByID(schemaID) - if !ok { - return 0, infoschema.ErrDatabaseNotExists.GenWithStackByArgs("") - } - tbl, err := is.TableByName(context.Background(), schema.Name, model.NewCIStr(tableName)) - if err != nil { - return 0, err - } - return tbl.Meta().ID, nil -} - -func findTableIDFromStore(t *meta.Meta, schemaID int64, tableName string) (int64, error) { - tbls, err := t.ListSimpleTables(schemaID) - if err != nil { - if meta.ErrDBNotExists.Equal(err) { - return 0, infoschema.ErrDatabaseNotExists.GenWithStackByArgs("") - } - return 0, errors.Trace(err) - } - for _, tbl := range tbls { - if tbl.Name.L == tableName { - return tbl.ID, nil - } - } - return 0, infoschema.ErrTableNotExists.FastGenByArgs(tableName) -} - -// BuildTableInfoFromAST builds model.TableInfo from a SQL statement. -// Note: TableID and PartitionID are left as uninitialized value. -func BuildTableInfoFromAST(s *ast.CreateTableStmt) (*model.TableInfo, error) { - return buildTableInfoWithCheck(mock.NewContext(), s, mysql.DefaultCharset, "", nil) -} - -// buildTableInfoWithCheck builds model.TableInfo from a SQL statement. -// Note: TableID and PartitionIDs are left as uninitialized value. -func buildTableInfoWithCheck(ctx sessionctx.Context, s *ast.CreateTableStmt, dbCharset, dbCollate string, placementPolicyRef *model.PolicyRefInfo) (*model.TableInfo, error) { - tbInfo, err := BuildTableInfoWithStmt(ctx, s, dbCharset, dbCollate, placementPolicyRef) - if err != nil { - return nil, err - } - // Fix issue 17952 which will cause partition range expr can't be parsed as Int. - // checkTableInfoValidWithStmt will do the constant fold the partition expression first, - // then checkTableInfoValidExtra will pass the tableInfo check successfully. - if err = checkTableInfoValidWithStmt(ctx, tbInfo, s); err != nil { - return nil, err - } - if err = checkTableInfoValidExtra(tbInfo); err != nil { - return nil, err - } - return tbInfo, nil -} - -// CheckTableInfoValidWithStmt exposes checkTableInfoValidWithStmt to SchemaTracker. Maybe one day we can delete it. -func CheckTableInfoValidWithStmt(ctx sessionctx.Context, tbInfo *model.TableInfo, s *ast.CreateTableStmt) (err error) { - return checkTableInfoValidWithStmt(ctx, tbInfo, s) -} - -func checkTableInfoValidWithStmt(ctx sessionctx.Context, tbInfo *model.TableInfo, s *ast.CreateTableStmt) (err error) { - // All of these rely on the AST structure of expressions, which were - // lost in the model (got serialized into strings). - if err := checkGeneratedColumn(ctx, s.Table.Schema, tbInfo.Name, s.Cols); err != nil { - return errors.Trace(err) - } - - // Check if table has a primary key if required. 
- if !ctx.GetSessionVars().InRestrictedSQL && ctx.GetSessionVars().PrimaryKeyRequired && len(tbInfo.GetPkName().String()) == 0 { - return infoschema.ErrTableWithoutPrimaryKey - } - if tbInfo.Partition != nil { - if err := checkPartitionDefinitionConstraints(ctx, tbInfo); err != nil { - return errors.Trace(err) - } - if s.Partition != nil { - if err := checkPartitionFuncType(ctx, s.Partition.Expr, s.Table.Schema.O, tbInfo); err != nil { - return errors.Trace(err) - } - if err := checkPartitioningKeysConstraints(ctx, s, tbInfo); err != nil { - return errors.Trace(err) - } - } - } - if tbInfo.TTLInfo != nil { - if err := checkTTLInfoValid(ctx, s.Table.Schema, tbInfo); err != nil { - return errors.Trace(err) - } - } - - return nil -} - -func checkGeneratedColumn(ctx sessionctx.Context, schemaName model.CIStr, tableName model.CIStr, colDefs []*ast.ColumnDef) error { - var colName2Generation = make(map[string]columnGenerationInDDL, len(colDefs)) - var exists bool - var autoIncrementColumn string - for i, colDef := range colDefs { - for _, option := range colDef.Options { - if option.Tp == ast.ColumnOptionGenerated { - if err := checkIllegalFn4Generated(colDef.Name.Name.L, typeColumn, option.Expr); err != nil { - return errors.Trace(err) - } - } - } - if containsColumnOption(colDef, ast.ColumnOptionAutoIncrement) { - exists, autoIncrementColumn = true, colDef.Name.Name.L - } - generated, depCols, err := findDependedColumnNames(schemaName, tableName, colDef) - if err != nil { - return errors.Trace(err) - } - if !generated { - colName2Generation[colDef.Name.Name.L] = columnGenerationInDDL{ - position: i, - generated: false, - } - } else { - colName2Generation[colDef.Name.Name.L] = columnGenerationInDDL{ - position: i, - generated: true, - dependences: depCols, - } - } - } - - // Check whether the generated column refers to any auto-increment columns - if exists { - if !ctx.GetSessionVars().EnableAutoIncrementInGenerated { - for colName, generated := range colName2Generation { - if _, found := generated.dependences[autoIncrementColumn]; found { - return dbterror.ErrGeneratedColumnRefAutoInc.GenWithStackByArgs(colName) - } - } - } - } - - for _, colDef := range colDefs { - colName := colDef.Name.Name.L - if err := verifyColumnGeneration(colName2Generation, colName); err != nil { - return errors.Trace(err) - } - } - return nil -} - -// checkTableInfoValidExtra is like checkTableInfoValid, but also assumes the -// table info comes from untrusted source and performs further checks such as -// name length and column count. -// (checkTableInfoValid is also used in repairing objects which don't perform -// these checks. Perhaps the two functions should be merged together regardless?) 
-func checkTableInfoValidExtra(tbInfo *model.TableInfo) error { - if err := checkTooLongTable(tbInfo.Name); err != nil { - return err - } - - if err := checkDuplicateColumn(tbInfo.Columns); err != nil { - return err - } - if err := checkTooLongColumns(tbInfo.Columns); err != nil { - return err - } - if err := checkTooManyColumns(tbInfo.Columns); err != nil { - return errors.Trace(err) - } - if err := checkTooManyIndexes(tbInfo.Indices); err != nil { - return errors.Trace(err) - } - if err := checkColumnsAttributes(tbInfo.Columns); err != nil { - return errors.Trace(err) - } - - // FIXME: perform checkConstraintNames - if err := checkCharsetAndCollation(tbInfo.Charset, tbInfo.Collate); err != nil { - return errors.Trace(err) - } - - oldState := tbInfo.State - tbInfo.State = model.StatePublic - err := checkTableInfoValid(tbInfo) - tbInfo.State = oldState - return err -} - -// checkTableInfoValid uses to check table info valid. This is used to validate table info. -func checkTableInfoValid(tblInfo *model.TableInfo) error { - _, err := tables.TableFromMeta(autoid.NewAllocators(false), tblInfo) - if err != nil { - return err - } - return checkInvisibleIndexOnPK(tblInfo) -} - -func checkDuplicateColumn(cols []*model.ColumnInfo) error { - colNames := set.StringSet{} - for _, col := range cols { - colName := col.Name - if colNames.Exist(colName.L) { - return infoschema.ErrColumnExists.GenWithStackByArgs(colName.O) - } - colNames.Insert(colName.L) - } - return nil -} - -func checkTooLongColumns(cols []*model.ColumnInfo) error { - for _, col := range cols { - if err := checkTooLongColumn(col.Name); err != nil { - return err - } - } - return nil -} - -func checkTooManyColumns(colDefs []*model.ColumnInfo) error { - if uint32(len(colDefs)) > atomic.LoadUint32(&config.GetGlobalConfig().TableColumnCountLimit) { - return dbterror.ErrTooManyFields - } - return nil -} - -func checkTooManyIndexes(idxDefs []*model.IndexInfo) error { - if len(idxDefs) > config.GetGlobalConfig().IndexLimit { - return dbterror.ErrTooManyKeys.GenWithStackByArgs(config.GetGlobalConfig().IndexLimit) - } - return nil -} - -// checkColumnsAttributes checks attributes for multiple columns. -func checkColumnsAttributes(colDefs []*model.ColumnInfo) error { - for _, colDef := range colDefs { - if err := checkColumnAttributes(colDef.Name.O, &colDef.FieldType); err != nil { - return errors.Trace(err) - } - } - return nil -} - -// checkColumnAttributes check attributes for single column. -func checkColumnAttributes(colName string, tp *types.FieldType) error { - switch tp.GetType() { - case mysql.TypeNewDecimal, mysql.TypeDouble, mysql.TypeFloat: - if tp.GetFlen() < tp.GetDecimal() { - return types.ErrMBiggerThanD.GenWithStackByArgs(colName) - } - case mysql.TypeDatetime, mysql.TypeDuration, mysql.TypeTimestamp: - if tp.GetDecimal() != types.UnspecifiedFsp && (tp.GetDecimal() < types.MinFsp || tp.GetDecimal() > types.MaxFsp) { - return types.ErrTooBigPrecision.GenWithStackByArgs(tp.GetDecimal(), colName, types.MaxFsp) - } - } - return nil -} - -// BuildSessionTemporaryTableInfo builds model.TableInfo from a SQL statement. 
-func BuildSessionTemporaryTableInfo(ctx sessionctx.Context, is infoschema.InfoSchema, s *ast.CreateTableStmt, dbCharset, dbCollate string, placementPolicyRef *model.PolicyRefInfo) (*model.TableInfo, error) { - ident := ast.Ident{Schema: s.Table.Schema, Name: s.Table.Name} - //build tableInfo - var tbInfo *model.TableInfo - var referTbl table.Table - var err error - if s.ReferTable != nil { - referIdent := ast.Ident{Schema: s.ReferTable.Schema, Name: s.ReferTable.Name} - _, ok := is.SchemaByName(referIdent.Schema) - if !ok { - return nil, infoschema.ErrTableNotExists.GenWithStackByArgs(referIdent.Schema, referIdent.Name) - } - referTbl, err = is.TableByName(context.Background(), referIdent.Schema, referIdent.Name) - if err != nil { - return nil, infoschema.ErrTableNotExists.GenWithStackByArgs(referIdent.Schema, referIdent.Name) - } - tbInfo, err = BuildTableInfoWithLike(ctx, ident, referTbl.Meta(), s) - } else { - tbInfo, err = buildTableInfoWithCheck(ctx, s, dbCharset, dbCollate, placementPolicyRef) - } - return tbInfo, err -} - -// BuildTableInfoWithStmt builds model.TableInfo from a SQL statement without validity check -func BuildTableInfoWithStmt(ctx sessionctx.Context, s *ast.CreateTableStmt, dbCharset, dbCollate string, placementPolicyRef *model.PolicyRefInfo) (*model.TableInfo, error) { - colDefs := s.Cols - tableCharset, tableCollate, err := GetCharsetAndCollateInTableOption(ctx.GetSessionVars(), 0, s.Options) - if err != nil { - return nil, errors.Trace(err) - } - tableCharset, tableCollate, err = ResolveCharsetCollation(ctx.GetSessionVars(), - ast.CharsetOpt{Chs: tableCharset, Col: tableCollate}, - ast.CharsetOpt{Chs: dbCharset, Col: dbCollate}, - ) - if err != nil { - return nil, errors.Trace(err) - } - - // The column charset haven't been resolved here. - cols, newConstraints, err := buildColumnsAndConstraints(ctx, colDefs, s.Constraints, tableCharset, tableCollate) - if err != nil { - return nil, errors.Trace(err) - } - err = checkConstraintNames(s.Table.Name, newConstraints) - if err != nil { - return nil, errors.Trace(err) - } - - var tbInfo *model.TableInfo - tbInfo, err = BuildTableInfo(ctx, s.Table.Name, cols, newConstraints, tableCharset, tableCollate) - if err != nil { - return nil, errors.Trace(err) - } - if err = setTemporaryType(ctx, tbInfo, s); err != nil { - return nil, errors.Trace(err) - } - - if err = setTableAutoRandomBits(ctx, tbInfo, colDefs); err != nil { - return nil, errors.Trace(err) - } - - if err = handleTableOptions(s.Options, tbInfo); err != nil { - return nil, errors.Trace(err) - } - - sessionVars := ctx.GetSessionVars() - if _, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, tbInfo.Name.L, &tbInfo.Comment, dbterror.ErrTooLongTableComment); err != nil { - return nil, errors.Trace(err) - } - - if tbInfo.TempTableType == model.TempTableNone && tbInfo.PlacementPolicyRef == nil && placementPolicyRef != nil { - // Set the defaults from Schema. Note: they are mutual exclusive! 
- tbInfo.PlacementPolicyRef = placementPolicyRef - } - - // After handleTableOptions, so the partitions can get defaults from Table level - err = buildTablePartitionInfo(ctx, s.Partition, tbInfo) - if err != nil { - return nil, errors.Trace(err) - } - - return tbInfo, nil -} - -func setTableAutoRandomBits(ctx sessionctx.Context, tbInfo *model.TableInfo, colDefs []*ast.ColumnDef) error { - for _, col := range colDefs { - if containsColumnOption(col, ast.ColumnOptionAutoRandom) { - if col.Tp.GetType() != mysql.TypeLonglong { - return dbterror.ErrInvalidAutoRandom.GenWithStackByArgs( - fmt.Sprintf(autoid.AutoRandomOnNonBigIntColumn, types.TypeStr(col.Tp.GetType()))) - } - switch { - case tbInfo.PKIsHandle: - if tbInfo.GetPkName().L != col.Name.Name.L { - errMsg := fmt.Sprintf(autoid.AutoRandomMustFirstColumnInPK, col.Name.Name.O) - return dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(errMsg) - } - case tbInfo.IsCommonHandle: - pk := tables.FindPrimaryIndex(tbInfo) - if pk == nil { - return dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomNoClusteredPKErrMsg) - } - if col.Name.Name.L != pk.Columns[0].Name.L { - errMsg := fmt.Sprintf(autoid.AutoRandomMustFirstColumnInPK, col.Name.Name.O) - return dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(errMsg) - } - default: - return dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomNoClusteredPKErrMsg) - } - - if containsColumnOption(col, ast.ColumnOptionAutoIncrement) { - return dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomIncompatibleWithAutoIncErrMsg) - } - if containsColumnOption(col, ast.ColumnOptionDefaultValue) { - return dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomIncompatibleWithDefaultValueErrMsg) - } - - shardBits, rangeBits, err := extractAutoRandomBitsFromColDef(col) - if err != nil { - return errors.Trace(err) - } - tbInfo.AutoRandomBits = shardBits - tbInfo.AutoRandomRangeBits = rangeBits - - shardFmt := autoid.NewShardIDFormat(col.Tp, shardBits, rangeBits) - if shardFmt.IncrementalBits < autoid.AutoRandomIncBitsMin { - return dbterror.ErrInvalidAutoRandom.FastGenByArgs(autoid.AutoRandomIncrementalBitsTooSmall) - } - msg := fmt.Sprintf(autoid.AutoRandomAvailableAllocTimesNote, shardFmt.IncrementalBitsCapacity()) - ctx.GetSessionVars().StmtCtx.AppendNote(errors.NewNoStackError(msg)) - } - } - return nil -} - -func containsColumnOption(colDef *ast.ColumnDef, opTp ast.ColumnOptionType) bool { - for _, option := range colDef.Options { - if option.Tp == opTp { - return true - } - } - return false -} - -func extractAutoRandomBitsFromColDef(colDef *ast.ColumnDef) (shardBits, rangeBits uint64, err error) { - for _, op := range colDef.Options { - if op.Tp == ast.ColumnOptionAutoRandom { - shardBits, err = autoid.AutoRandomShardBitsNormalize(op.AutoRandOpt.ShardBits, colDef.Name.Name.O) - if err != nil { - return 0, 0, err - } - rangeBits, err = autoid.AutoRandomRangeBitsNormalize(op.AutoRandOpt.RangeBits) - if err != nil { - return 0, 0, err - } - return shardBits, rangeBits, nil - } - } - return 0, 0, nil -} - -// handleTableOptions updates tableInfo according to table options. -func handleTableOptions(options []*ast.TableOption, tbInfo *model.TableInfo) error { - var ttlOptionsHandled bool - - for _, op := range options { - switch op.Tp { - case ast.TableOptionAutoIncrement: - tbInfo.AutoIncID = int64(op.UintValue) - case ast.TableOptionAutoIdCache: - if op.UintValue > uint64(math.MaxInt64) { - // TODO: Refine this error. 
-				return errors.New("table option auto_id_cache overflows int64")
-			}
-			tbInfo.AutoIdCache = int64(op.UintValue)
-		case ast.TableOptionAutoRandomBase:
-			tbInfo.AutoRandID = int64(op.UintValue)
-		case ast.TableOptionComment:
-			tbInfo.Comment = op.StrValue
-		case ast.TableOptionCompression:
-			tbInfo.Compression = op.StrValue
-		case ast.TableOptionShardRowID:
-			if op.UintValue > 0 && tbInfo.HasClusteredIndex() {
-				return dbterror.ErrUnsupportedShardRowIDBits
-			}
-			tbInfo.ShardRowIDBits = op.UintValue
-			if tbInfo.ShardRowIDBits > shardRowIDBitsMax {
-				tbInfo.ShardRowIDBits = shardRowIDBitsMax
-			}
-			tbInfo.MaxShardRowIDBits = tbInfo.ShardRowIDBits
-		case ast.TableOptionPreSplitRegion:
-			if tbInfo.TempTableType != model.TempTableNone {
-				return errors.Trace(dbterror.ErrOptOnTemporaryTable.GenWithStackByArgs("pre split regions"))
-			}
-			tbInfo.PreSplitRegions = op.UintValue
-		case ast.TableOptionCharset, ast.TableOptionCollate:
-			// We don't handle charset and collate here since they're handled in `GetCharsetAndCollateInTableOption`.
-		case ast.TableOptionPlacementPolicy:
-			tbInfo.PlacementPolicyRef = &model.PolicyRefInfo{
-				Name: model.NewCIStr(op.StrValue),
-			}
-		case ast.TableOptionTTL, ast.TableOptionTTLEnable, ast.TableOptionTTLJobInterval:
-			if ttlOptionsHandled {
-				continue
-			}
-
-			ttlInfo, ttlEnable, ttlJobInterval, err := getTTLInfoInOptions(options)
-			if err != nil {
-				return err
-			}
-			// It's impossible that `ttlInfo` and `ttlEnable` are all nil, because we have met this option.
-			// After exclude the situation `ttlInfo == nil && ttlEnable != nil`, we could say `ttlInfo != nil`
-			if ttlInfo == nil {
-				if ttlEnable != nil {
-					return errors.Trace(dbterror.ErrSetTTLOptionForNonTTLTable.FastGenByArgs("TTL_ENABLE"))
-				}
-				if ttlJobInterval != nil {
-					return errors.Trace(dbterror.ErrSetTTLOptionForNonTTLTable.FastGenByArgs("TTL_JOB_INTERVAL"))
-				}
-			}
-
-			tbInfo.TTLInfo = ttlInfo
-			ttlOptionsHandled = true
-		}
-	}
-	shardingBits := shardingBits(tbInfo)
-	if tbInfo.PreSplitRegions > shardingBits {
-		tbInfo.PreSplitRegions = shardingBits
-	}
-	return nil
-}
-
-func setTemporaryType(_ sessionctx.Context, tbInfo *model.TableInfo, s *ast.CreateTableStmt) error {
-	switch s.TemporaryKeyword {
-	case ast.TemporaryGlobal:
-		tbInfo.TempTableType = model.TempTableGlobal
-		// "create global temporary table ... on commit preserve rows"
-		if !s.OnCommitDelete {
-			return errors.Trace(dbterror.ErrUnsupportedOnCommitPreserve)
-		}
-	case ast.TemporaryLocal:
-		tbInfo.TempTableType = model.TempTableLocal
-	default:
-		tbInfo.TempTableType = model.TempTableNone
-	}
-	return nil
-}
-
-func buildColumnsAndConstraints(
-	ctx sessionctx.Context,
-	colDefs []*ast.ColumnDef,
-	constraints []*ast.Constraint,
-	tblCharset string,
-	tblCollate string,
-) ([]*table.Column, []*ast.Constraint, error) {
-	// outPriKeyConstraint is the primary key constraint out of column definition. such as: create table t1 (id int , age int, primary key(id));
-	var outPriKeyConstraint *ast.Constraint
-	for _, v := range constraints {
-		if v.Tp == ast.ConstraintPrimaryKey {
-			outPriKeyConstraint = v
-			break
-		}
-	}
-	cols := make([]*table.Column, 0, len(colDefs))
-	colMap := make(map[string]*table.Column, len(colDefs))
-
-	for i, colDef := range colDefs {
-		if field_types.TiDBStrictIntegerDisplayWidth {
-			switch colDef.Tp.GetType() {
-			case mysql.TypeTiny:
-				// No warning for BOOL-like tinyint(1)
-				if colDef.Tp.GetFlen() != types.UnspecifiedLength && colDef.Tp.GetFlen() != 1 {
-					ctx.GetSessionVars().StmtCtx.AppendWarning(
-						dbterror.ErrWarnDeprecatedIntegerDisplayWidth.FastGenByArgs(),
-					)
-				}
-			case mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong:
-				if colDef.Tp.GetFlen() != types.UnspecifiedLength {
-					ctx.GetSessionVars().StmtCtx.AppendWarning(
-						dbterror.ErrWarnDeprecatedIntegerDisplayWidth.FastGenByArgs(),
-					)
-				}
-			}
-		}
-		col, cts, err := buildColumnAndConstraint(ctx, i, colDef, outPriKeyConstraint, tblCharset, tblCollate)
-		if err != nil {
-			return nil, nil, errors.Trace(err)
-		}
-		col.State = model.StatePublic
-		if mysql.HasZerofillFlag(col.GetFlag()) {
-			ctx.GetSessionVars().StmtCtx.AppendWarning(
-				dbterror.ErrWarnDeprecatedZerofill.FastGenByArgs(),
-			)
-		}
-		constraints = append(constraints, cts...)
-		cols = append(cols, col)
-		colMap[colDef.Name.Name.L] = col
-	}
-	// Traverse table Constraints and set col.flag.
-	for _, v := range constraints {
-		setColumnFlagWithConstraint(colMap, v)
-	}
-	return cols, constraints, nil
-}
-
-func setEmptyConstraintName(namesMap map[string]bool, constr *ast.Constraint) {
-	if constr.Name == "" && len(constr.Keys) > 0 {
-		var colName string
-		for _, keyPart := range constr.Keys {
-			if keyPart.Expr != nil {
-				colName = "expression_index"
-			}
-		}
-		if colName == "" {
-			colName = constr.Keys[0].Column.Name.O
-		}
-		constrName := colName
-		i := 2
-		if strings.EqualFold(constrName, mysql.PrimaryKeyName) {
-			constrName = fmt.Sprintf("%s_%d", constrName, 2)
-			i = 3
-		}
-		for namesMap[constrName] {
-			// We loop forever until we find constrName that haven't been used.
-			constrName = fmt.Sprintf("%s_%d", colName, i)
-			i++
-		}
-		constr.Name = constrName
-		namesMap[constrName] = true
-	}
-}
-
-func checkConstraintNames(tableName model.CIStr, constraints []*ast.Constraint) error {
-	constrNames := map[string]bool{}
-	fkNames := map[string]bool{}
-
-	// Check not empty constraint name whether is duplicated.
-	for _, constr := range constraints {
-		if constr.Tp == ast.ConstraintForeignKey {
-			err := checkDuplicateConstraint(fkNames, constr.Name, constr.Tp)
-			if err != nil {
-				return errors.Trace(err)
-			}
-		} else {
-			err := checkDuplicateConstraint(constrNames, constr.Name, constr.Tp)
-			if err != nil {
-				return errors.Trace(err)
-			}
-		}
-	}
-
-	// Set empty constraint names.
-	checkConstraints := make([]*ast.Constraint, 0, len(constraints))
-	for _, constr := range constraints {
-		if constr.Tp != ast.ConstraintForeignKey {
-			setEmptyConstraintName(constrNames, constr)
-		}
-		if constr.Tp == ast.ConstraintCheck {
-			checkConstraints = append(checkConstraints, constr)
-		}
-	}
-	// Set check constraint name under its order.
- if len(checkConstraints) > 0 { - setEmptyCheckConstraintName(tableName.L, constrNames, checkConstraints) - } - return nil -} - -func checkDuplicateConstraint(namesMap map[string]bool, name string, constraintType ast.ConstraintType) error { - if name == "" { - return nil - } - nameLower := strings.ToLower(name) - if namesMap[nameLower] { - switch constraintType { - case ast.ConstraintForeignKey: - return dbterror.ErrFkDupName.GenWithStackByArgs(name) - case ast.ConstraintCheck: - return dbterror.ErrCheckConstraintDupName.GenWithStackByArgs(name) - default: - return dbterror.ErrDupKeyName.GenWithStackByArgs(name) - } - } - namesMap[nameLower] = true - return nil -} - -func setEmptyCheckConstraintName(tableLowerName string, namesMap map[string]bool, constrs []*ast.Constraint) { - cnt := 1 - constraintPrefix := tableLowerName + "_chk_" - for _, constr := range constrs { - if constr.Name == "" { - constrName := fmt.Sprintf("%s%d", constraintPrefix, cnt) - for { - // loop until find constrName that haven't been used. - if !namesMap[constrName] { - namesMap[constrName] = true - break - } - cnt++ - constrName = fmt.Sprintf("%s%d", constraintPrefix, cnt) - } - constr.Name = constrName - } - } -} - -func setColumnFlagWithConstraint(colMap map[string]*table.Column, v *ast.Constraint) { - switch v.Tp { - case ast.ConstraintPrimaryKey: - for _, key := range v.Keys { - if key.Expr != nil { - continue - } - c, ok := colMap[key.Column.Name.L] - if !ok { - continue - } - c.AddFlag(mysql.PriKeyFlag) - // Primary key can not be NULL. - c.AddFlag(mysql.NotNullFlag) - setNoDefaultValueFlag(c, c.DefaultValue != nil) - } - case ast.ConstraintUniq, ast.ConstraintUniqIndex, ast.ConstraintUniqKey: - for i, key := range v.Keys { - if key.Expr != nil { - continue - } - c, ok := colMap[key.Column.Name.L] - if !ok { - continue - } - if i == 0 { - // Only the first column can be set - // if unique index has multi columns, - // the flag should be MultipleKeyFlag. - // See https://dev.mysql.com/doc/refman/5.7/en/show-columns.html - if len(v.Keys) > 1 { - c.AddFlag(mysql.MultipleKeyFlag) - } else { - c.AddFlag(mysql.UniqueKeyFlag) - } - } - } - case ast.ConstraintKey, ast.ConstraintIndex: - for i, key := range v.Keys { - if key.Expr != nil { - continue - } - c, ok := colMap[key.Column.Name.L] - if !ok { - continue - } - if i == 0 { - // Only the first column can be set. - c.AddFlag(mysql.MultipleKeyFlag) - } - } - } -} - -// BuildTableInfoWithLike builds a new table info according to CREATE TABLE ... LIKE statement. -func BuildTableInfoWithLike(ctx sessionctx.Context, ident ast.Ident, referTblInfo *model.TableInfo, s *ast.CreateTableStmt) (*model.TableInfo, error) { - // Check the referred table is a real table object. - if referTblInfo.IsSequence() || referTblInfo.IsView() { - return nil, dbterror.ErrWrongObject.GenWithStackByArgs(ident.Schema, referTblInfo.Name, "BASE TABLE") - } - tblInfo := *referTblInfo - if err := setTemporaryType(ctx, &tblInfo, s); err != nil { - return nil, errors.Trace(err) - } - // Check non-public column and adjust column offset. - newColumns := referTblInfo.Cols() - newIndices := make([]*model.IndexInfo, 0, len(tblInfo.Indices)) - for _, idx := range tblInfo.Indices { - if idx.State == model.StatePublic { - newIndices = append(newIndices, idx) - } - } - tblInfo.Columns = newColumns - tblInfo.Indices = newIndices - tblInfo.Name = ident.Name - tblInfo.AutoIncID = 0 - tblInfo.ForeignKeys = nil - // Ignore TiFlash replicas for temporary tables. 
- if s.TemporaryKeyword != ast.TemporaryNone { - tblInfo.TiFlashReplica = nil - } else if tblInfo.TiFlashReplica != nil { - replica := *tblInfo.TiFlashReplica - // Keep the tiflash replica setting, remove the replica available status. - replica.AvailablePartitionIDs = nil - replica.Available = false - tblInfo.TiFlashReplica = &replica - } - if referTblInfo.Partition != nil { - pi := *referTblInfo.Partition - pi.Definitions = make([]model.PartitionDefinition, len(referTblInfo.Partition.Definitions)) - copy(pi.Definitions, referTblInfo.Partition.Definitions) - tblInfo.Partition = &pi - } - - if referTblInfo.TTLInfo != nil { - tblInfo.TTLInfo = referTblInfo.TTLInfo.Clone() - } - renameCheckConstraint(&tblInfo) - return &tblInfo, nil -} - -func renameCheckConstraint(tblInfo *model.TableInfo) { - for _, cons := range tblInfo.Constraints { - cons.Name = model.NewCIStr("") - cons.Table = tblInfo.Name - } - setNameForConstraintInfo(tblInfo.Name.L, map[string]bool{}, tblInfo.Constraints) -} - -// BuildTableInfo creates a TableInfo. -func BuildTableInfo( - ctx sessionctx.Context, - tableName model.CIStr, - cols []*table.Column, - constraints []*ast.Constraint, - charset string, - collate string, -) (tbInfo *model.TableInfo, err error) { - tbInfo = &model.TableInfo{ - Name: tableName, - Version: model.CurrLatestTableInfoVersion, - Charset: charset, - Collate: collate, - } - tblColumns := make([]*table.Column, 0, len(cols)) - existedColsMap := make(map[string]struct{}, len(cols)) - for _, v := range cols { - v.ID = AllocateColumnID(tbInfo) - tbInfo.Columns = append(tbInfo.Columns, v.ToInfo()) - tblColumns = append(tblColumns, table.ToColumn(v.ToInfo())) - existedColsMap[v.Name.L] = struct{}{} - } - foreignKeyID := tbInfo.MaxForeignKeyID - for _, constr := range constraints { - // Build hidden columns if necessary. - hiddenCols, err := buildHiddenColumnInfoWithCheck(ctx, constr.Keys, model.NewCIStr(constr.Name), tbInfo, tblColumns) - if err != nil { - return nil, err - } - for _, hiddenCol := range hiddenCols { - hiddenCol.State = model.StatePublic - hiddenCol.ID = AllocateColumnID(tbInfo) - hiddenCol.Offset = len(tbInfo.Columns) - tbInfo.Columns = append(tbInfo.Columns, hiddenCol) - tblColumns = append(tblColumns, table.ToColumn(hiddenCol)) - } - // Check clustered on non-primary key. - if constr.Option != nil && constr.Option.PrimaryKeyTp != model.PrimaryKeyTypeDefault && - constr.Tp != ast.ConstraintPrimaryKey { - return nil, dbterror.ErrUnsupportedClusteredSecondaryKey - } - if constr.Tp == ast.ConstraintForeignKey { - var fkName model.CIStr - foreignKeyID++ - if constr.Name != "" { - fkName = model.NewCIStr(constr.Name) - } else { - fkName = model.NewCIStr(fmt.Sprintf("fk_%d", foreignKeyID)) - } - if model.FindFKInfoByName(tbInfo.ForeignKeys, fkName.L) != nil { - return nil, infoschema.ErrCannotAddForeign - } - fk, err := buildFKInfo(fkName, constr.Keys, constr.Refer, cols) - if err != nil { - return nil, err - } - fk.State = model.StatePublic - - tbInfo.ForeignKeys = append(tbInfo.ForeignKeys, fk) - continue - } - if constr.Tp == ast.ConstraintPrimaryKey { - lastCol, err := CheckPKOnGeneratedColumn(tbInfo, constr.Keys) - if err != nil { - return nil, err - } - isSingleIntPK := isSingleIntPK(constr, lastCol) - if ShouldBuildClusteredIndex(ctx, constr.Option, isSingleIntPK) { - if isSingleIntPK { - tbInfo.PKIsHandle = true - } else { - tbInfo.IsCommonHandle = true - tbInfo.CommonHandleVersion = 1 - } - } - if tbInfo.HasClusteredIndex() { - // Primary key cannot be invisible. 
- if constr.Option != nil && constr.Option.Visibility == ast.IndexVisibilityInvisible { - return nil, dbterror.ErrPKIndexCantBeInvisible - } - } - if tbInfo.PKIsHandle { - continue - } - } - - if constr.Tp == ast.ConstraintFulltext { - ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrTableCantHandleFt.FastGenByArgs()) - continue - } - - var ( - indexName = constr.Name - primary, unique bool - ) - - // Check if the index is primary or unique. - switch constr.Tp { - case ast.ConstraintPrimaryKey: - primary = true - unique = true - indexName = mysql.PrimaryKeyName - case ast.ConstraintUniq, ast.ConstraintUniqKey, ast.ConstraintUniqIndex: - unique = true - } - - // check constraint - if constr.Tp == ast.ConstraintCheck { - if !variable.EnableCheckConstraint.Load() { - ctx.GetSessionVars().StmtCtx.AppendWarning(errCheckConstraintIsOff) - continue - } - // Since column check constraint dependency has been done in columnDefToCol. - // Here do the table check constraint dependency check, table constraint - // can only refer the columns in defined columns of the table. - // Refer: https://dev.mysql.com/doc/refman/8.0/en/create-table-check-constraints.html - if ok, err := table.IsSupportedExpr(constr); !ok { - return nil, err - } - var dependedCols []model.CIStr - dependedColsMap := findDependentColsInExpr(constr.Expr) - if !constr.InColumn { - dependedCols = make([]model.CIStr, 0, len(dependedColsMap)) - for k := range dependedColsMap { - if _, ok := existedColsMap[k]; !ok { - // The table constraint depended on a non-existed column. - return nil, dbterror.ErrTableCheckConstraintReferUnknown.GenWithStackByArgs(constr.Name, k) - } - dependedCols = append(dependedCols, model.NewCIStr(k)) - } - } else { - // Check the column-type constraint dependency. - if len(dependedColsMap) > 1 { - return nil, dbterror.ErrColumnCheckConstraintReferOther.GenWithStackByArgs(constr.Name) - } else if len(dependedColsMap) == 0 { - // If dependedCols is empty, the expression must be true/false. - valExpr, ok := constr.Expr.(*driver.ValueExpr) - if !ok || !mysql.HasIsBooleanFlag(valExpr.GetType().GetFlag()) { - return nil, errors.Trace(errors.New("unsupported expression in check constraint")) - } - } else { - if _, ok := dependedColsMap[constr.InColumnName]; !ok { - return nil, dbterror.ErrColumnCheckConstraintReferOther.GenWithStackByArgs(constr.Name) - } - dependedCols = []model.CIStr{model.NewCIStr(constr.InColumnName)} - } - } - // check auto-increment column - if table.ContainsAutoIncrementCol(dependedCols, tbInfo) { - return nil, dbterror.ErrCheckConstraintRefersAutoIncrementColumn.GenWithStackByArgs(constr.Name) - } - // check foreign key - if err := table.HasForeignKeyRefAction(tbInfo.ForeignKeys, constraints, constr, dependedCols); err != nil { - return nil, err - } - // build constraint meta info. - constraintInfo, err := buildConstraintInfo(tbInfo, dependedCols, constr, model.StatePublic) - if err != nil { - return nil, errors.Trace(err) - } - // check if the expression is bool type - if err := table.IfCheckConstraintExprBoolType(ctx.GetExprCtx().GetEvalCtx(), constraintInfo, tbInfo); err != nil { - return nil, err - } - constraintInfo.ID = allocateConstraintID(tbInfo) - tbInfo.Constraints = append(tbInfo.Constraints, constraintInfo) - continue - } - - // build index info. 
- idxInfo, err := BuildIndexInfo( - ctx, - tbInfo.Columns, - model.NewCIStr(indexName), - primary, - unique, - false, - constr.Keys, - constr.Option, - model.StatePublic, - ) - if err != nil { - return nil, errors.Trace(err) - } - - if len(hiddenCols) > 0 { - AddIndexColumnFlag(tbInfo, idxInfo) - } - sessionVars := ctx.GetSessionVars() - _, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, idxInfo.Name.String(), &idxInfo.Comment, dbterror.ErrTooLongIndexComment) - if err != nil { - return nil, errors.Trace(err) - } - idxInfo.ID = AllocateIndexID(tbInfo) - tbInfo.Indices = append(tbInfo.Indices, idxInfo) - } - - err = addIndexForForeignKey(ctx, tbInfo) - return tbInfo, err -} - -func precheckBuildHiddenColumnInfo( - indexPartSpecifications []*ast.IndexPartSpecification, - indexName model.CIStr, -) error { - for i, idxPart := range indexPartSpecifications { - if idxPart.Expr == nil { - continue - } - name := fmt.Sprintf("%s_%s_%d", expressionIndexPrefix, indexName, i) - if utf8.RuneCountInString(name) > mysql.MaxColumnNameLength { - // TODO: Refine the error message. - return dbterror.ErrTooLongIdent.GenWithStackByArgs("hidden column") - } - // TODO: Refine the error message. - if err := checkIllegalFn4Generated(indexName.L, typeIndex, idxPart.Expr); err != nil { - return errors.Trace(err) - } - } - return nil -} - -func buildHiddenColumnInfoWithCheck(ctx sessionctx.Context, indexPartSpecifications []*ast.IndexPartSpecification, indexName model.CIStr, tblInfo *model.TableInfo, existCols []*table.Column) ([]*model.ColumnInfo, error) { - if err := precheckBuildHiddenColumnInfo(indexPartSpecifications, indexName); err != nil { - return nil, err - } - return BuildHiddenColumnInfo(ctx, indexPartSpecifications, indexName, tblInfo, existCols) -} - -// BuildHiddenColumnInfo builds hidden column info. -func BuildHiddenColumnInfo(ctx sessionctx.Context, indexPartSpecifications []*ast.IndexPartSpecification, indexName model.CIStr, tblInfo *model.TableInfo, existCols []*table.Column) ([]*model.ColumnInfo, error) { - hiddenCols := make([]*model.ColumnInfo, 0, len(indexPartSpecifications)) - for i, idxPart := range indexPartSpecifications { - if idxPart.Expr == nil { - continue - } - idxPart.Column = &ast.ColumnName{Name: model.NewCIStr(fmt.Sprintf("%s_%s_%d", expressionIndexPrefix, indexName, i))} - // Check whether the hidden columns have existed. - col := table.FindCol(existCols, idxPart.Column.Name.L) - if col != nil { - // TODO: Use expression index related error. - return nil, infoschema.ErrColumnExists.GenWithStackByArgs(col.Name.String()) - } - idxPart.Length = types.UnspecifiedLength - // The index part is an expression, prepare a hidden column for it. - - var sb strings.Builder - restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | - format.RestoreSpacesAroundBinaryOperation | format.RestoreWithoutSchemaName | format.RestoreWithoutTableName - restoreCtx := format.NewRestoreCtx(restoreFlags, &sb) - sb.Reset() - err := idxPart.Expr.Restore(restoreCtx) - if err != nil { - return nil, errors.Trace(err) - } - expr, err := expression.BuildSimpleExpr(ctx.GetExprCtx(), idxPart.Expr, - expression.WithTableInfo(ctx.GetSessionVars().CurrentDB, tblInfo), - expression.WithAllowCastArray(true), - ) - if err != nil { - // TODO: refine the error message. 
-			return nil, err
-		}
-		if _, ok := expr.(*expression.Column); ok {
-			return nil, dbterror.ErrFunctionalIndexOnField
-		}
-
-		colInfo := &model.ColumnInfo{
-			Name:                idxPart.Column.Name,
-			GeneratedExprString: sb.String(),
-			GeneratedStored:     false,
-			Version:             model.CurrLatestColumnInfoVersion,
-			Dependences:         make(map[string]struct{}),
-			Hidden:              true,
-			FieldType:           *expr.GetType(ctx.GetExprCtx().GetEvalCtx()),
-		}
-		// Reset some flag, it may be caused by wrong type infer. But it's not easy to fix them all, so reset them here for safety.
-		colInfo.DelFlag(mysql.PriKeyFlag | mysql.UniqueKeyFlag | mysql.AutoIncrementFlag)
-
-		if colInfo.GetType() == mysql.TypeDatetime || colInfo.GetType() == mysql.TypeDate || colInfo.GetType() == mysql.TypeTimestamp || colInfo.GetType() == mysql.TypeDuration {
-			if colInfo.FieldType.GetDecimal() == types.UnspecifiedLength {
-				colInfo.FieldType.SetDecimal(types.MaxFsp)
-			}
-		}
-		// For an array, the collation is set to "binary". The collation has no effect on the array itself (as it's usually
-		// regarded as a JSON), but will influence how TiKV handles the index value.
-		if colInfo.FieldType.IsArray() {
-			colInfo.SetCharset("binary")
-			colInfo.SetCollate("binary")
-		}
-		checkDependencies := make(map[string]struct{})
-		for _, colName := range FindColumnNamesInExpr(idxPart.Expr) {
-			colInfo.Dependences[colName.Name.L] = struct{}{}
-			checkDependencies[colName.Name.L] = struct{}{}
-		}
-		if err = checkDependedColExist(checkDependencies, existCols); err != nil {
-			return nil, errors.Trace(err)
-		}
-		if !ctx.GetSessionVars().EnableAutoIncrementInGenerated {
-			if err = checkExpressionIndexAutoIncrement(indexName.O, colInfo.Dependences, tblInfo); err != nil {
-				return nil, errors.Trace(err)
-			}
-		}
-		idxPart.Expr = nil
-		hiddenCols = append(hiddenCols, colInfo)
-	}
-	return hiddenCols, nil
-}
-
-// addIndexForForeignKey uses to auto create an index for the foreign key if the table doesn't have any index cover the
-// foreign key columns.
-func addIndexForForeignKey(ctx sessionctx.Context, tbInfo *model.TableInfo) error {
-	if len(tbInfo.ForeignKeys) == 0 {
-		return nil
-	}
-	var handleCol *model.ColumnInfo
-	if tbInfo.PKIsHandle {
-		handleCol = tbInfo.GetPkColInfo()
-	}
-	for _, fk := range tbInfo.ForeignKeys {
-		if fk.Version < model.FKVersion1 {
-			continue
-		}
-		if handleCol != nil && len(fk.Cols) == 1 && handleCol.Name.L == fk.Cols[0].L {
-			continue
-		}
-		if model.FindIndexByColumns(tbInfo, tbInfo.Indices, fk.Cols...) != nil {
-			continue
-		}
-		idxName := fk.Name
-		if tbInfo.FindIndexByName(idxName.L) != nil {
-			return dbterror.ErrDupKeyName.GenWithStack("duplicate key name %s", fk.Name.O)
-		}
-		keys := make([]*ast.IndexPartSpecification, 0, len(fk.Cols))
-		for _, col := range fk.Cols {
-			keys = append(keys, &ast.IndexPartSpecification{
-				Column: &ast.ColumnName{Name: col},
-				Length: types.UnspecifiedLength,
-			})
-		}
-		idxInfo, err := BuildIndexInfo(ctx, tbInfo.Columns, idxName, false, false, false, keys, nil, model.StatePublic)
-		if err != nil {
-			return errors.Trace(err)
-		}
-		idxInfo.ID = AllocateIndexID(tbInfo)
-		tbInfo.Indices = append(tbInfo.Indices, idxInfo)
-	}
-	return nil
-}
-
-func isSingleIntPK(constr *ast.Constraint, lastCol *model.ColumnInfo) bool {
-	if len(constr.Keys) != 1 {
-		return false
-	}
-	switch lastCol.GetType() {
-	case mysql.TypeLong, mysql.TypeLonglong,
-		mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24:
-		return true
-	}
-	return false
-}
-
-// ShouldBuildClusteredIndex is used to determine whether the CREATE TABLE statement should build a clustered index table.
-func ShouldBuildClusteredIndex(ctx sessionctx.Context, opt *ast.IndexOption, isSingleIntPK bool) bool {
-	if opt == nil || opt.PrimaryKeyTp == model.PrimaryKeyTypeDefault {
-		switch ctx.GetSessionVars().EnableClusteredIndex {
-		case variable.ClusteredIndexDefModeOn:
-			return true
-		case variable.ClusteredIndexDefModeIntOnly:
-			return !config.GetGlobalConfig().AlterPrimaryKey && isSingleIntPK
-		default:
-			return false
-		}
-	}
-	return opt.PrimaryKeyTp == model.PrimaryKeyTypeClustered
-}
-
-// BuildViewInfo builds a ViewInfo structure from an ast.CreateViewStmt.
-func BuildViewInfo(s *ast.CreateViewStmt) (*model.ViewInfo, error) {
-	// Always Use `format.RestoreNameBackQuotes` to restore `SELECT` statement despite the `ANSI_QUOTES` SQL Mode is enabled or not.
- restoreFlag := format.RestoreStringSingleQuotes | format.RestoreKeyWordUppercase | format.RestoreNameBackQuotes - var sb strings.Builder - if err := s.Select.Restore(format.NewRestoreCtx(restoreFlag, &sb)); err != nil { - return nil, err - } - - return &model.ViewInfo{Definer: s.Definer, Algorithm: s.Algorithm, - Security: s.Security, SelectStmt: sb.String(), CheckOption: s.CheckOption, Cols: nil}, nil -} diff --git a/pkg/ddl/ddl.go b/pkg/ddl/ddl.go index 12eea8d51efbc..4ec779e8cc0e7 100644 --- a/pkg/ddl/ddl.go +++ b/pkg/ddl/ddl.go @@ -1224,11 +1224,11 @@ func processJobs( ids []int64, byWho model.AdminCommandOperator, ) (jobErrs []error, err error) { - if val, _err_ := failpoint.Eval(_curpkg_("mockFailedCommandOnConcurencyDDL")); _err_ == nil { + failpoint.Inject("mockFailedCommandOnConcurencyDDL", func(val failpoint.Value) { if val.(bool) { - return nil, errors.New("mock failed admin command on ddl jobs") + failpoint.Return(nil, errors.New("mock failed admin command on ddl jobs")) } - } + }) if len(ids) == 0 { return nil, nil @@ -1279,11 +1279,11 @@ func processJobs( } } - if val, _err_ := failpoint.Eval(_curpkg_("mockCommitFailedOnDDLCommand")); _err_ == nil { + failpoint.Inject("mockCommitFailedOnDDLCommand", func(val failpoint.Value) { if val.(bool) { - return jobErrs, errors.New("mock commit failed on admin command on ddl jobs") + failpoint.Return(jobErrs, errors.New("mock commit failed on admin command on ddl jobs")) } - } + }) // There may be some conflict during the update, try it again if err = ns.Commit(context.Background()); err != nil { diff --git a/pkg/ddl/ddl.go__failpoint_stash__ b/pkg/ddl/ddl.go__failpoint_stash__ deleted file mode 100644 index 4ec779e8cc0e7..0000000000000 --- a/pkg/ddl/ddl.go__failpoint_stash__ +++ /dev/null @@ -1,1421 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2013 The ql Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSES/QL-LICENSE file. 
- -package ddl - -import ( - "context" - "fmt" - "strconv" - "strings" - "sync" - "time" - - "github.com/google/uuid" - "github.com/ngaut/pools" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/ddl/ingest" - "github.com/pingcap/tidb/pkg/ddl/logutil" - sess "github.com/pingcap/tidb/pkg/ddl/session" - "github.com/pingcap/tidb/pkg/ddl/syncer" - "github.com/pingcap/tidb/pkg/ddl/systable" - "github.com/pingcap/tidb/pkg/ddl/util" - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" - "github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor" - "github.com/pingcap/tidb/pkg/domain/infosync" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/meta/autoid" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/owner" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/binloginfo" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/statistics/handle" - statsutil "github.com/pingcap/tidb/pkg/statistics/handle/util" - "github.com/pingcap/tidb/pkg/table" - pumpcli "github.com/pingcap/tidb/pkg/tidb-binlog/pump_client" - tidbutil "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/dbterror" - "github.com/pingcap/tidb/pkg/util/gcutil" - "github.com/pingcap/tidb/pkg/util/generic" - "github.com/tikv/client-go/v2/tikvrpc" - clientv3 "go.etcd.io/etcd/client/v3" - atomicutil "go.uber.org/atomic" - "go.uber.org/zap" -) - -const ( - // currentVersion is for all new DDL jobs. - currentVersion = 1 - // DDLOwnerKey is the ddl owner path that is saved to etcd, and it's exported for testing. - DDLOwnerKey = "/tidb/ddl/fg/owner" - ddlSchemaVersionKeyLock = "/tidb/ddl/schema_version_lock" - // addingDDLJobPrefix is the path prefix used to record the newly added DDL job, and it's saved to etcd. - addingDDLJobPrefix = "/tidb/ddl/add_ddl_job_" - ddlPrompt = "ddl" - - shardRowIDBitsMax = 15 - - batchAddingJobs = 100 - - reorgWorkerCnt = 10 - generalWorkerCnt = 10 - - // checkFlagIndexInJobArgs is the recoverCheckFlag index used in RecoverTable/RecoverSchema job arg list. - checkFlagIndexInJobArgs = 1 -) - -const ( - // The recoverCheckFlag is used to judge the gc work status when RecoverTable/RecoverSchema. - recoverCheckFlagNone int64 = iota - recoverCheckFlagEnableGC - recoverCheckFlagDisableGC -) - -// OnExist specifies what to do when a new object has a name collision. -type OnExist uint8 - -// CreateTableConfig is the configuration of `CreateTableWithInfo`. -type CreateTableConfig struct { - OnExist OnExist - // IDAllocated indicates whether the job has allocated all IDs for tables affected - // in the job, if true, DDL will not allocate IDs for them again, it's only used - // by BR now. By reusing IDs BR can save a lot of works such as rewriting table - // IDs in backed up KVs. - IDAllocated bool -} - -// CreateTableOption is the option for creating table. -type CreateTableOption func(*CreateTableConfig) - -// GetCreateTableConfig applies the series of config options from default config -// and returns the final config. 
-func GetCreateTableConfig(cs []CreateTableOption) CreateTableConfig { - cfg := CreateTableConfig{} - for _, c := range cs { - c(&cfg) - } - return cfg -} - -// WithOnExist applies the OnExist option. -func WithOnExist(o OnExist) CreateTableOption { - return func(cfg *CreateTableConfig) { - cfg.OnExist = o - } -} - -// WithIDAllocated applies the IDAllocated option. -// WARNING!!!: if idAllocated == true, DDL will NOT allocate IDs by itself. That -// means if the caller can not promise ID is unique, then we got inconsistency. -// This option is only exposed to be used by BR. -func WithIDAllocated(idAllocated bool) CreateTableOption { - return func(cfg *CreateTableConfig) { - cfg.IDAllocated = idAllocated - } -} - -const ( - // OnExistError throws an error on name collision. - OnExistError OnExist = iota - // OnExistIgnore skips creating the new object. - OnExistIgnore - // OnExistReplace replaces the old object by the new object. This is only - // supported by VIEWs at the moment. For other object types, this is - // equivalent to OnExistError. - OnExistReplace - - jobRecordCapacity = 16 - jobOnceCapacity = 1000 -) - -var ( - // EnableSplitTableRegion is a flag to decide whether to split a new region for - // a newly created table. It takes effect only if the Storage supports split - // region. - EnableSplitTableRegion = uint32(0) -) - -// DDL is responsible for updating schema in data store and maintaining in-memory InfoSchema cache. -type DDL interface { - // Start campaigns the owner and starts workers. - // ctxPool is used for the worker's delRangeManager and creates sessions. - Start(ctxPool *pools.ResourcePool) error - // GetLease returns current schema lease time. - GetLease() time.Duration - // Stats returns the DDL statistics. - Stats(vars *variable.SessionVars) (map[string]any, error) - // GetScope gets the status variables scope. - GetScope(status string) variable.ScopeFlag - // Stop stops DDL worker. - Stop() error - // RegisterStatsHandle registers statistics handle and its corresponding event channel for ddl. - RegisterStatsHandle(*handle.Handle) - // SchemaSyncer gets the schema syncer. - SchemaSyncer() syncer.SchemaSyncer - // StateSyncer gets the cluster state syncer. - StateSyncer() syncer.StateSyncer - // OwnerManager gets the owner manager. - OwnerManager() owner.Manager - // GetID gets the ddl ID. - GetID() string - // GetTableMaxHandle gets the max row ID of a normal table or a partition. - GetTableMaxHandle(ctx *JobContext, startTS uint64, tbl table.PhysicalTable) (kv.Handle, bool, error) - // SetBinlogClient sets the binlog client for DDL worker. It's exported for testing. - SetBinlogClient(*pumpcli.PumpsClient) - // GetMinJobIDRefresher gets the MinJobIDRefresher, this api only works after Start. - GetMinJobIDRefresher() *systable.MinJobIDRefresher -} - -type jobSubmitResult struct { - err error - jobID int64 - // merged indicates whether the job is merged into another job together with - // other jobs. we only merge multiple create table jobs into one job when fast - // create table is enabled. - merged bool -} - -// JobWrapper is used to wrap a job and some other information. -// exported for testing. -type JobWrapper struct { - *model.Job - // IDAllocated see config of same name in CreateTableConfig. - // exported for test. - IDAllocated bool - // job submission is run in async, we use this channel to notify the caller. - // when fast create table enabled, we might combine multiple jobs into one, and - // append the channel to this slice. 
- ResultCh []chan jobSubmitResult - cacheErr error -} - -// NewJobWrapper creates a new JobWrapper. -// exported for testing. -func NewJobWrapper(job *model.Job, idAllocated bool) *JobWrapper { - return &JobWrapper{ - Job: job, - IDAllocated: idAllocated, - ResultCh: []chan jobSubmitResult{make(chan jobSubmitResult)}, - } -} - -// NotifyResult notifies the job submit result. -func (t *JobWrapper) NotifyResult(err error) { - merged := len(t.ResultCh) > 1 - for _, resultCh := range t.ResultCh { - resultCh <- jobSubmitResult{ - err: err, - jobID: t.ID, - merged: merged, - } - } -} - -// ddl is used to handle the statements that define the structure or schema of the database. -type ddl struct { - m sync.RWMutex - wg tidbutil.WaitGroupWrapper // It's only used to deal with data race in restart_test. - limitJobCh chan *JobWrapper - - *ddlCtx - sessPool *sess.Pool - delRangeMgr delRangeManager - enableTiFlashPoll *atomicutil.Bool - // get notification if any DDL job submitted or finished. - ddlJobNotifyCh chan struct{} - sysTblMgr systable.Manager - minJobIDRefresher *systable.MinJobIDRefresher - - // globalIDLock locks global id to reduce write conflict. - globalIDLock sync.Mutex - executor *executor -} - -// waitSchemaSyncedController is to control whether to waitSchemaSynced or not. -type waitSchemaSyncedController struct { - mu sync.RWMutex - job map[int64]struct{} - - // Use to check if the DDL job is the first run on this owner. - onceMap map[int64]struct{} -} - -func newWaitSchemaSyncedController() *waitSchemaSyncedController { - return &waitSchemaSyncedController{ - job: make(map[int64]struct{}, jobRecordCapacity), - onceMap: make(map[int64]struct{}, jobOnceCapacity), - } -} - -func (w *waitSchemaSyncedController) registerSync(job *model.Job) { - w.mu.Lock() - defer w.mu.Unlock() - w.job[job.ID] = struct{}{} -} - -func (w *waitSchemaSyncedController) isSynced(job *model.Job) bool { - w.mu.RLock() - defer w.mu.RUnlock() - _, ok := w.job[job.ID] - return !ok -} - -func (w *waitSchemaSyncedController) synced(job *model.Job) { - w.mu.Lock() - defer w.mu.Unlock() - delete(w.job, job.ID) -} - -// maybeAlreadyRunOnce returns true means that the job may be the first run on this owner. -// Returns false means that the job must not be the first run on this owner. -func (w *waitSchemaSyncedController) maybeAlreadyRunOnce(id int64) bool { - w.mu.Lock() - defer w.mu.Unlock() - _, ok := w.onceMap[id] - return ok -} - -func (w *waitSchemaSyncedController) setAlreadyRunOnce(id int64) { - w.mu.Lock() - defer w.mu.Unlock() - if len(w.onceMap) > jobOnceCapacity { - // If the map is too large, we reset it. These jobs may need to check schema synced again, but it's ok. - w.onceMap = make(map[int64]struct{}, jobRecordCapacity) - } - w.onceMap[id] = struct{}{} -} - -func (w *waitSchemaSyncedController) clearOnceMap() { - w.mu.Lock() - defer w.mu.Unlock() - w.onceMap = make(map[int64]struct{}, jobOnceCapacity) -} - -// ddlCtx is the context when we use worker to handle DDL jobs. -type ddlCtx struct { - ctx context.Context - cancel context.CancelFunc - uuid string - store kv.Storage - ownerManager owner.Manager - schemaSyncer syncer.SchemaSyncer - stateSyncer syncer.StateSyncer - // ddlJobDoneChMap is used to notify the session that the DDL job is finished. - // jobID -> chan struct{} - ddlJobDoneChMap generic.SyncMap[int64, chan struct{}] - ddlEventCh chan<- *statsutil.DDLEvent - lease time.Duration // lease is schema lease, default 45s, see config.Lease. 
- binlogCli *pumpcli.PumpsClient // binlogCli is used for Binlog. - infoCache *infoschema.InfoCache - statsHandle *handle.Handle - tableLockCkr util.DeadTableLockChecker - etcdCli *clientv3.Client - autoidCli *autoid.ClientDiscover - schemaLoader SchemaLoader - - *waitSchemaSyncedController - *schemaVersionManager - - // reorgCtx is used for reorganization. - reorgCtx reorgContexts - - jobCtx struct { - sync.RWMutex - // jobCtxMap maps job ID to job's ctx. - jobCtxMap map[int64]*JobContext - } -} - -// SchemaLoader is used to avoid import loop, the only impl is domain currently. -type SchemaLoader interface { - Reload() error -} - -// schemaVersionManager is used to manage the schema version. To prevent the conflicts on this key between different DDL job, -// we use another transaction to update the schema version, so that we need to lock the schema version and unlock it until the job is committed. -// for version2, we use etcd lock to lock the schema version between TiDB nodes now. -type schemaVersionManager struct { - schemaVersionMu sync.Mutex - // lockOwner stores the job ID that is holding the lock. - lockOwner atomicutil.Int64 -} - -func newSchemaVersionManager() *schemaVersionManager { - return &schemaVersionManager{} -} - -func (sv *schemaVersionManager) setSchemaVersion(job *model.Job, store kv.Storage) (schemaVersion int64, err error) { - err = sv.lockSchemaVersion(job.ID) - if err != nil { - return schemaVersion, errors.Trace(err) - } - err = kv.RunInNewTxn(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), store, true, func(_ context.Context, txn kv.Transaction) error { - var err error - m := meta.NewMeta(txn) - schemaVersion, err = m.GenSchemaVersion() - return err - }) - return schemaVersion, err -} - -// lockSchemaVersion gets the lock to prevent the schema version from being updated. -func (sv *schemaVersionManager) lockSchemaVersion(jobID int64) error { - ownerID := sv.lockOwner.Load() - // There may exist one job update schema version many times in multiple-schema-change, so we do not lock here again - // if they are the same job. - if ownerID != jobID { - sv.schemaVersionMu.Lock() - sv.lockOwner.Store(jobID) - } - return nil -} - -// unlockSchemaVersion releases the lock. 
-func (sv *schemaVersionManager) unlockSchemaVersion(jobID int64) { - ownerID := sv.lockOwner.Load() - if ownerID == jobID { - sv.lockOwner.Store(0) - sv.schemaVersionMu.Unlock() - } -} - -func (dc *ddlCtx) isOwner() bool { - isOwner := dc.ownerManager.IsOwner() - logutil.DDLLogger().Debug("check whether is the DDL owner", zap.Bool("isOwner", isOwner), zap.String("selfID", dc.uuid)) - if isOwner { - metrics.DDLCounter.WithLabelValues(metrics.DDLOwner + "_" + mysql.TiDBReleaseVersion).Inc() - } - return isOwner -} - -func (dc *ddlCtx) setDDLLabelForTopSQL(jobID int64, jobQuery string) { - dc.jobCtx.Lock() - defer dc.jobCtx.Unlock() - ctx, exists := dc.jobCtx.jobCtxMap[jobID] - if !exists { - ctx = NewJobContext() - dc.jobCtx.jobCtxMap[jobID] = ctx - } - ctx.setDDLLabelForTopSQL(jobQuery) -} - -func (dc *ddlCtx) setDDLSourceForDiagnosis(jobID int64, jobType model.ActionType) { - dc.jobCtx.Lock() - defer dc.jobCtx.Unlock() - ctx, exists := dc.jobCtx.jobCtxMap[jobID] - if !exists { - ctx = NewJobContext() - dc.jobCtx.jobCtxMap[jobID] = ctx - } - ctx.setDDLLabelForDiagnosis(jobType) -} - -func (dc *ddlCtx) getResourceGroupTaggerForTopSQL(jobID int64) tikvrpc.ResourceGroupTagger { - dc.jobCtx.Lock() - defer dc.jobCtx.Unlock() - ctx, exists := dc.jobCtx.jobCtxMap[jobID] - if !exists { - return nil - } - return ctx.getResourceGroupTaggerForTopSQL() -} - -func (dc *ddlCtx) removeJobCtx(job *model.Job) { - dc.jobCtx.Lock() - defer dc.jobCtx.Unlock() - delete(dc.jobCtx.jobCtxMap, job.ID) -} - -func (dc *ddlCtx) jobContext(jobID int64, reorgMeta *model.DDLReorgMeta) *JobContext { - dc.jobCtx.RLock() - defer dc.jobCtx.RUnlock() - var ctx *JobContext - if jobContext, exists := dc.jobCtx.jobCtxMap[jobID]; exists { - ctx = jobContext - } else { - ctx = NewJobContext() - } - if reorgMeta != nil && len(ctx.resourceGroupName) == 0 { - ctx.resourceGroupName = reorgMeta.ResourceGroupName - } - return ctx -} - -type reorgContexts struct { - sync.RWMutex - // reorgCtxMap maps job ID to reorg context. 
- reorgCtxMap map[int64]*reorgCtx - beOwnerTS int64 -} - -func (r *reorgContexts) getOwnerTS() int64 { - r.RLock() - defer r.RUnlock() - return r.beOwnerTS -} - -func (r *reorgContexts) setOwnerTS(ts int64) { - r.Lock() - r.beOwnerTS = ts - r.Unlock() -} - -func (dc *ddlCtx) getReorgCtx(jobID int64) *reorgCtx { - dc.reorgCtx.RLock() - defer dc.reorgCtx.RUnlock() - return dc.reorgCtx.reorgCtxMap[jobID] -} - -func (dc *ddlCtx) newReorgCtx(jobID int64, rowCount int64) *reorgCtx { - dc.reorgCtx.Lock() - defer dc.reorgCtx.Unlock() - existedRC, ok := dc.reorgCtx.reorgCtxMap[jobID] - if ok { - existedRC.references.Add(1) - return existedRC - } - rc := &reorgCtx{} - rc.doneCh = make(chan reorgFnResult, 1) - // initial reorgCtx - rc.setRowCount(rowCount) - rc.mu.warnings = make(map[errors.ErrorID]*terror.Error) - rc.mu.warningsCount = make(map[errors.ErrorID]int64) - rc.references.Add(1) - dc.reorgCtx.reorgCtxMap[jobID] = rc - return rc -} - -func (dc *ddlCtx) removeReorgCtx(jobID int64) { - dc.reorgCtx.Lock() - defer dc.reorgCtx.Unlock() - ctx, ok := dc.reorgCtx.reorgCtxMap[jobID] - if ok { - ctx.references.Sub(1) - if ctx.references.Load() == 0 { - delete(dc.reorgCtx.reorgCtxMap, jobID) - } - } -} - -func (dc *ddlCtx) notifyReorgWorkerJobStateChange(job *model.Job) { - rc := dc.getReorgCtx(job.ID) - if rc == nil { - logutil.DDLLogger().Warn("cannot find reorgCtx", zap.Int64("Job ID", job.ID)) - return - } - logutil.DDLLogger().Info("notify reorg worker the job's state", - zap.Int64("Job ID", job.ID), zap.Stringer("Job State", job.State), - zap.Stringer("Schema State", job.SchemaState)) - rc.notifyJobState(job.State) -} - -func (dc *ddlCtx) notifyJobDone(jobID int64) { - if ch, ok := dc.ddlJobDoneChMap.Delete(jobID); ok { - // broadcast done event as we might merge multiple jobs into one when fast - // create table is enabled. - close(ch) - } -} - -// EnableTiFlashPoll enables TiFlash poll loop aka PollTiFlashReplicaStatus. -func EnableTiFlashPoll(d any) { - if dd, ok := d.(*ddl); ok { - dd.enableTiFlashPoll.Store(true) - } -} - -// DisableTiFlashPoll disables TiFlash poll loop aka PollTiFlashReplicaStatus. -func DisableTiFlashPoll(d any) { - if dd, ok := d.(*ddl); ok { - dd.enableTiFlashPoll.Store(false) - } -} - -// IsTiFlashPollEnabled reveals enableTiFlashPoll -func (d *ddl) IsTiFlashPollEnabled() bool { - return d.enableTiFlashPoll.Load() -} - -// RegisterStatsHandle registers statistics handle and its corresponding even channel for ddl. -// TODO this is called after ddl started, will cause panic if related DDL are executed -// in between. -func (d *ddl) RegisterStatsHandle(h *handle.Handle) { - d.ddlCtx.statsHandle = h - d.executor.statsHandle = h - d.ddlEventCh = h.DDLEventCh() -} - -// asyncNotifyEvent will notify the ddl event to outside world, say statistic handle. When the channel is full, we may -// give up notify and log it. -func asyncNotifyEvent(d *ddlCtx, e *statsutil.DDLEvent) { - if d.ddlEventCh != nil { - if d.lease == 0 { - // If lease is 0, it's always used in test. - select { - case d.ddlEventCh <- e: - default: - } - return - } - for i := 0; i < 10; i++ { - select { - case d.ddlEventCh <- e: - return - default: - time.Sleep(time.Microsecond * 10) - } - } - logutil.DDLLogger().Warn("fail to notify DDL event", zap.Stringer("event", e)) - } -} - -// NewDDL creates a new DDL. -// TODO remove it, to simplify this PR we use this way. -func NewDDL(ctx context.Context, options ...Option) (DDL, Executor) { - return newDDL(ctx, options...) 
-} - -func newDDL(ctx context.Context, options ...Option) (*ddl, *executor) { - opt := &Options{} - for _, o := range options { - o(opt) - } - - id := uuid.New().String() - var manager owner.Manager - var schemaSyncer syncer.SchemaSyncer - var stateSyncer syncer.StateSyncer - var deadLockCkr util.DeadTableLockChecker - if etcdCli := opt.EtcdCli; etcdCli == nil { - // The etcdCli is nil if the store is localstore which is only used for testing. - // So we use mockOwnerManager and MockSchemaSyncer. - manager = owner.NewMockManager(ctx, id, opt.Store, DDLOwnerKey) - schemaSyncer = NewMockSchemaSyncer() - stateSyncer = NewMockStateSyncer() - } else { - manager = owner.NewOwnerManager(ctx, etcdCli, ddlPrompt, id, DDLOwnerKey) - schemaSyncer = syncer.NewSchemaSyncer(etcdCli, id) - stateSyncer = syncer.NewStateSyncer(etcdCli, util.ServerGlobalState) - deadLockCkr = util.NewDeadTableLockChecker(etcdCli) - } - - // TODO: make store and infoCache explicit arguments - // these two should be ensured to exist - if opt.Store == nil { - panic("store should not be nil") - } - if opt.InfoCache == nil { - panic("infoCache should not be nil") - } - - ddlCtx := &ddlCtx{ - uuid: id, - store: opt.Store, - lease: opt.Lease, - ddlJobDoneChMap: generic.NewSyncMap[int64, chan struct{}](10), - ownerManager: manager, - schemaSyncer: schemaSyncer, - stateSyncer: stateSyncer, - binlogCli: binloginfo.GetPumpsClient(), - infoCache: opt.InfoCache, - tableLockCkr: deadLockCkr, - etcdCli: opt.EtcdCli, - autoidCli: opt.AutoIDClient, - schemaLoader: opt.SchemaLoader, - waitSchemaSyncedController: newWaitSchemaSyncedController(), - } - ddlCtx.reorgCtx.reorgCtxMap = make(map[int64]*reorgCtx) - ddlCtx.jobCtx.jobCtxMap = make(map[int64]*JobContext) - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnDDL) - ddlCtx.ctx, ddlCtx.cancel = context.WithCancel(ctx) - ddlCtx.schemaVersionManager = newSchemaVersionManager() - - d := &ddl{ - ddlCtx: ddlCtx, - limitJobCh: make(chan *JobWrapper, batchAddingJobs), - enableTiFlashPoll: atomicutil.NewBool(true), - ddlJobNotifyCh: make(chan struct{}, 100), - } - - taskexecutor.RegisterTaskType(proto.Backfill, - func(ctx context.Context, id string, task *proto.Task, taskTable taskexecutor.TaskTable) taskexecutor.TaskExecutor { - return newBackfillDistExecutor(ctx, id, task, taskTable, d) - }, - ) - - scheduler.RegisterSchedulerFactory(proto.Backfill, - func(ctx context.Context, task *proto.Task, param scheduler.Param) scheduler.Scheduler { - return newLitBackfillScheduler(ctx, d, task, param) - }) - scheduler.RegisterSchedulerCleanUpFactory(proto.Backfill, newBackfillCleanUpS3) - // Register functions for enable/disable ddl when changing system variable `tidb_enable_ddl`. - variable.EnableDDL = d.EnableDDL - variable.DisableDDL = d.DisableDDL - variable.SwitchMDL = d.SwitchMDL - - e := &executor{ - ctx: d.ctx, - uuid: d.uuid, - store: d.store, - etcdCli: d.etcdCli, - autoidCli: d.autoidCli, - infoCache: d.infoCache, - limitJobCh: d.limitJobCh, - schemaLoader: d.schemaLoader, - lease: d.lease, - ownerManager: d.ownerManager, - ddlJobDoneChMap: &d.ddlJobDoneChMap, - ddlJobNotifyCh: d.ddlJobNotifyCh, - globalIDLock: &d.globalIDLock, - } - d.executor = e - - return d, e -} - -// Stop implements DDL.Stop interface. 
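As a point of reference, newDDL above consumes functional options; the short, self-contained sketch below shows that pattern in isolation (the field and constructor names here are illustrative stand-ins, not the ddl package's exact API):

package main

import (
	"fmt"
	"time"
)

// Options collects the dependencies a constructor needs; the real ddl.Options
// carries many more fields (store, info cache, etcd client, ...).
type Options struct {
	Lease time.Duration
}

// Option mutates Options before the instance is built.
type Option func(*Options)

// WithLease is an illustrative option constructor.
func WithLease(d time.Duration) Option {
	return func(o *Options) { o.Lease = d }
}

func main() {
	opt := &Options{}
	for _, o := range []Option{WithLease(45 * time.Second)} {
		o(opt) // same application loop as in newDDL
	}
	fmt.Println(opt.Lease) // 45s
}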
-func (d *ddl) Stop() error { - d.m.Lock() - defer d.m.Unlock() - - d.close() - logutil.DDLLogger().Info("stop DDL", zap.String("ID", d.uuid)) - return nil -} - -func (d *ddl) newDeleteRangeManager(mock bool) delRangeManager { - var delRangeMgr delRangeManager - if !mock { - delRangeMgr = newDelRangeManager(d.store, d.sessPool) - logutil.DDLLogger().Info("start delRangeManager OK", zap.Bool("is a emulator", !d.store.SupportDeleteRange())) - } else { - delRangeMgr = newMockDelRangeManager() - } - - delRangeMgr.start() - return delRangeMgr -} - -// Start implements DDL.Start interface. -func (d *ddl) Start(ctxPool *pools.ResourcePool) error { - logutil.DDLLogger().Info("start DDL", zap.String("ID", d.uuid), zap.Bool("runWorker", config.GetGlobalConfig().Instance.TiDBEnableDDL.Load())) - - d.sessPool = sess.NewSessionPool(ctxPool) - d.executor.sessPool = d.sessPool - d.sysTblMgr = systable.NewManager(d.sessPool) - d.minJobIDRefresher = systable.NewMinJobIDRefresher(d.sysTblMgr) - d.wg.Run(func() { - d.limitDDLJobs() - }) - d.wg.Run(func() { - d.minJobIDRefresher.Start(d.ctx) - }) - - d.delRangeMgr = d.newDeleteRangeManager(ctxPool == nil) - - if err := d.stateSyncer.Init(d.ctx); err != nil { - logutil.DDLLogger().Warn("start DDL init state syncer failed", zap.Error(err)) - return errors.Trace(err) - } - d.ownerManager.SetListener(&ownerListener{ - ddl: d, - }) - - if config.TableLockEnabled() { - d.wg.Add(1) - go d.startCleanDeadTableLock() - } - - // If tidb_enable_ddl is true, we need campaign owner and do DDL jobs. Besides, we also can do backfill jobs. - // Otherwise, we needn't do that. - if config.GetGlobalConfig().Instance.TiDBEnableDDL.Load() { - if err := d.EnableDDL(); err != nil { - return err - } - } - - variable.RegisterStatistics(d) - - metrics.DDLCounter.WithLabelValues(metrics.CreateDDLInstance).Inc() - - // Start some background routine to manage TiFlash replica. - d.wg.Run(d.PollTiFlashRoutine) - - ingestDataDir, err := ingest.GenIngestTempDataDir() - if err != nil { - logutil.DDLIngestLogger().Warn(ingest.LitWarnEnvInitFail, - zap.Error(err)) - } else { - ok := ingest.InitGlobalLightningEnv(ingestDataDir) - if ok { - d.wg.Run(func() { - d.CleanUpTempDirLoop(d.ctx, ingestDataDir) - }) - } - } - - return nil -} - -func (d *ddl) CleanUpTempDirLoop(ctx context.Context, path string) { - ticker := time.NewTicker(1 * time.Minute) - defer ticker.Stop() - for { - select { - case <-ticker.C: - se, err := d.sessPool.Get() - if err != nil { - logutil.DDLLogger().Warn("get session from pool failed", zap.Error(err)) - return - } - ingest.CleanUpTempDir(ctx, se, path) - d.sessPool.Put(se) - case <-d.ctx.Done(): - return - } - } -} - -// EnableDDL enable this node to execute ddl. -// Since ownerManager.CampaignOwner will start a new goroutine to run ownerManager.campaignLoop, -// we should make sure that before invoking EnableDDL(), ddl is DISABLE. -func (d *ddl) EnableDDL() error { - err := d.ownerManager.CampaignOwner() - return errors.Trace(err) -} - -// DisableDDL disable this node to execute ddl. -// We should make sure that before invoking DisableDDL(), ddl is ENABLE. -func (d *ddl) DisableDDL() error { - if d.ownerManager.IsOwner() { - // If there is only one node, we should NOT disable ddl. 
- serverInfo, err := infosync.GetAllServerInfo(d.ctx) - if err != nil { - logutil.DDLLogger().Error("error when GetAllServerInfo", zap.Error(err)) - return err - } - if len(serverInfo) <= 1 { - return dbterror.ErrDDLSetting.GenWithStackByArgs("disabling", "can not disable ddl owner when it is the only one tidb instance") - } - // FIXME: if possible, when this node is the only node with DDL, ths setting of DisableDDL should fail. - } - - // disable campaign by interrupting campaignLoop - d.ownerManager.CampaignCancel() - return nil -} - -func (d *ddl) close() { - if d.ctx.Err() != nil { - return - } - - startTime := time.Now() - d.cancel() - d.wg.Wait() - d.ownerManager.Cancel() - d.schemaSyncer.Close() - - // d.delRangeMgr using sessions from d.sessPool. - // Put it before d.sessPool.close to reduce the time spent by d.sessPool.close. - if d.delRangeMgr != nil { - d.delRangeMgr.clear() - } - if d.sessPool != nil { - d.sessPool.Close() - } - variable.UnregisterStatistics(d) - - logutil.DDLLogger().Info("DDL closed", zap.String("ID", d.uuid), zap.Duration("take time", time.Since(startTime))) -} - -// GetLease implements DDL.GetLease interface. -func (d *ddl) GetLease() time.Duration { - lease := d.lease - return lease -} - -// SchemaSyncer implements DDL.SchemaSyncer interface. -func (d *ddl) SchemaSyncer() syncer.SchemaSyncer { - return d.schemaSyncer -} - -// StateSyncer implements DDL.StateSyncer interface. -func (d *ddl) StateSyncer() syncer.StateSyncer { - return d.stateSyncer -} - -// OwnerManager implements DDL.OwnerManager interface. -func (d *ddl) OwnerManager() owner.Manager { - return d.ownerManager -} - -// GetID implements DDL.GetID interface. -func (d *ddl) GetID() string { - return d.uuid -} - -// SetBinlogClient implements DDL.SetBinlogClient interface. -func (d *ddl) SetBinlogClient(binlogCli *pumpcli.PumpsClient) { - d.binlogCli = binlogCli -} - -func (d *ddl) GetMinJobIDRefresher() *systable.MinJobIDRefresher { - return d.minJobIDRefresher -} - -func (d *ddl) startCleanDeadTableLock() { - defer func() { - d.wg.Done() - }() - - defer tidbutil.Recover(metrics.LabelDDL, "startCleanDeadTableLock", nil, false) - - ticker := time.NewTicker(time.Second * 10) - defer ticker.Stop() - for { - select { - case <-ticker.C: - if !d.ownerManager.IsOwner() { - continue - } - deadLockTables, err := d.tableLockCkr.GetDeadLockedTables(d.ctx, d.infoCache.GetLatest()) - if err != nil { - logutil.DDLLogger().Info("get dead table lock failed.", zap.Error(err)) - continue - } - for se, tables := range deadLockTables { - err := d.cleanDeadTableLock(tables, se) - if err != nil { - logutil.DDLLogger().Info("clean dead table lock failed.", zap.Error(err)) - } - } - case <-d.ctx.Done(): - return - } - } -} - -// cleanDeadTableLock uses to clean dead table locks. -func (d *ddl) cleanDeadTableLock(unlockTables []model.TableLockTpInfo, se model.SessionInfo) error { - if len(unlockTables) == 0 { - return nil - } - arg := &LockTablesArg{ - UnlockTables: unlockTables, - SessionInfo: se, - } - job := &model.Job{ - SchemaID: unlockTables[0].SchemaID, - TableID: unlockTables[0].TableID, - Type: model.ActionUnlockTable, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{arg}, - } - - ctx, err := d.sessPool.Get() - if err != nil { - return err - } - defer d.sessPool.Put(ctx) - err = d.executor.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -// SwitchMDL enables MDL or disable MDL. 
-func (d *ddl) SwitchMDL(enable bool) error { - isEnableBefore := variable.EnableMDL.Load() - if isEnableBefore == enable { - return nil - } - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) - defer cancel() - - // Check if there is any DDL running. - // This check can not cover every corner cases, so users need to guarantee that there is no DDL running by themselves. - sessCtx, err := d.sessPool.Get() - if err != nil { - return err - } - defer d.sessPool.Put(sessCtx) - se := sess.NewSession(sessCtx) - rows, err := se.Execute(ctx, "select 1 from mysql.tidb_ddl_job", "check job") - if err != nil { - return err - } - if len(rows) != 0 { - return errors.New("please wait for all jobs done") - } - - variable.EnableMDL.Store(enable) - err = kv.RunInNewTxn(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), d.store, true, func(_ context.Context, txn kv.Transaction) error { - m := meta.NewMeta(txn) - oldEnable, _, err := m.GetMetadataLock() - if err != nil { - return err - } - if oldEnable != enable { - err = m.SetMetadataLock(enable) - } - return err - }) - if err != nil { - logutil.DDLLogger().Warn("switch metadata lock feature", zap.Bool("enable", enable), zap.Error(err)) - return err - } - logutil.DDLLogger().Info("switch metadata lock feature", zap.Bool("enable", enable)) - return nil -} - -// RecoverInfo contains information needed by DDL.RecoverTable. -type RecoverInfo struct { - SchemaID int64 - TableInfo *model.TableInfo - DropJobID int64 - SnapshotTS uint64 - AutoIDs meta.AutoIDGroup - OldSchemaName string - OldTableName string -} - -// RecoverSchemaInfo contains information needed by DDL.RecoverSchema. -type RecoverSchemaInfo struct { - *model.DBInfo - RecoverTabsInfo []*RecoverInfo - // LoadTablesOnExecute is the new logic to avoid a large RecoverTabsInfo can't be - // persisted. If it's true, DDL owner will recover RecoverTabsInfo instead of the - // job submit node. - LoadTablesOnExecute bool - DropJobID int64 - SnapshotTS uint64 - OldSchemaName model.CIStr -} - -// delayForAsyncCommit sleeps `SafeWindow + AllowedClockDrift` before a DDL job finishes. -// It should be called before any DDL that could break data consistency. -// This provides a safe window for async commit and 1PC to commit with an old schema. -func delayForAsyncCommit() { - if variable.EnableMDL.Load() { - // If metadata lock is enabled. The transaction of DDL must begin after prewrite of the async commit transaction, - // then the commit ts of DDL must be greater than the async commit transaction. In this case, the corresponding schema of the async commit transaction - // is correct. But if metadata lock is disabled, we can't ensure that the corresponding schema of the async commit transaction isn't change. - return - } - cfg := config.GetGlobalConfig().TiKVClient.AsyncCommit - duration := cfg.SafeWindow + cfg.AllowedClockDrift - logutil.DDLLogger().Info("sleep before DDL finishes to make async commit and 1PC safe", - zap.Duration("duration", duration)) - time.Sleep(duration) -} - -var ( - // RunInGoTest is used to identify whether ddl in running in the test. - RunInGoTest bool -) - -// GetDropOrTruncateTableInfoFromJobsByStore implements GetDropOrTruncateTableInfoFromJobs -func GetDropOrTruncateTableInfoFromJobsByStore(jobs []*model.Job, gcSafePoint uint64, getTable func(uint64, int64, int64) (*model.TableInfo, error), fn func(*model.Job, *model.TableInfo) (bool, error)) (bool, error) { - for _, job := range jobs { - // Check GC safe point for getting snapshot infoSchema. 
- err := gcutil.ValidateSnapshotWithGCSafePoint(job.StartTS, gcSafePoint) - if err != nil { - return false, err - } - if job.Type != model.ActionDropTable && job.Type != model.ActionTruncateTable { - continue - } - - tbl, err := getTable(job.StartTS, job.SchemaID, job.TableID) - if err != nil { - if meta.ErrDBNotExists.Equal(err) { - // The dropped/truncated DDL maybe execute failed that caused by the parallel DDL execution, - // then can't find the table from the snapshot info-schema. Should just ignore error here, - // see more in TestParallelDropSchemaAndDropTable. - continue - } - return false, err - } - if tbl == nil { - // The dropped/truncated DDL maybe execute failed that caused by the parallel DDL execution, - // then can't find the table from the snapshot info-schema. Should just ignore error here, - // see more in TestParallelDropSchemaAndDropTable. - continue - } - finish, err := fn(job, tbl) - if err != nil || finish { - return finish, err - } - } - return false, nil -} - -// Info is for DDL information. -type Info struct { - SchemaVer int64 - ReorgHandle kv.Key // It's only used for DDL information. - Jobs []*model.Job // It's the currently running jobs. -} - -// GetDDLInfoWithNewTxn returns DDL information using a new txn. -func GetDDLInfoWithNewTxn(s sessionctx.Context) (*Info, error) { - se := sess.NewSession(s) - err := se.Begin(context.Background()) - if err != nil { - return nil, err - } - info, err := GetDDLInfo(s) - se.Rollback() - return info, err -} - -// GetDDLInfo returns DDL information and only uses for testing. -func GetDDLInfo(s sessionctx.Context) (*Info, error) { - var err error - info := &Info{} - se := sess.NewSession(s) - txn, err := se.Txn() - if err != nil { - return nil, errors.Trace(err) - } - t := meta.NewMeta(txn) - info.Jobs = make([]*model.Job, 0, 2) - var generalJob, reorgJob *model.Job - generalJob, reorgJob, err = get2JobsFromTable(se) - if err != nil { - return nil, errors.Trace(err) - } - - if generalJob != nil { - info.Jobs = append(info.Jobs, generalJob) - } - - if reorgJob != nil { - info.Jobs = append(info.Jobs, reorgJob) - } - - info.SchemaVer, err = t.GetSchemaVersionWithNonEmptyDiff() - if err != nil { - return nil, errors.Trace(err) - } - if reorgJob == nil { - return info, nil - } - - _, info.ReorgHandle, _, _, err = newReorgHandler(se).GetDDLReorgHandle(reorgJob) - if err != nil { - if meta.ErrDDLReorgElementNotExist.Equal(err) { - return info, nil - } - return nil, errors.Trace(err) - } - - return info, nil -} - -func get2JobsFromTable(sess *sess.Session) (*model.Job, *model.Job, error) { - var generalJob, reorgJob *model.Job - jobs, err := getJobsBySQL(sess, JobTable, "not reorg order by job_id limit 1") - if err != nil { - return nil, nil, errors.Trace(err) - } - - if len(jobs) != 0 { - generalJob = jobs[0] - } - jobs, err = getJobsBySQL(sess, JobTable, "reorg order by job_id limit 1") - if err != nil { - return nil, nil, errors.Trace(err) - } - if len(jobs) != 0 { - reorgJob = jobs[0] - } - return generalJob, reorgJob, nil -} - -// cancelRunningJob cancel a DDL job that is in the concurrent state. -func cancelRunningJob(_ *sess.Session, job *model.Job, - byWho model.AdminCommandOperator) (err error) { - // These states can't be cancelled. - if job.IsDone() || job.IsSynced() { - return dbterror.ErrCancelFinishedDDLJob.GenWithStackByArgs(job.ID) - } - - // If the state is rolling back, it means the work is cleaning the data after cancelling the job. 
- if job.IsCancelled() || job.IsRollingback() || job.IsRollbackDone() { - return nil - } - - if !job.IsRollbackable() { - return dbterror.ErrCannotCancelDDLJob.GenWithStackByArgs(job.ID) - } - job.State = model.JobStateCancelling - job.AdminOperator = byWho - return nil -} - -// pauseRunningJob check and pause the running Job -func pauseRunningJob(_ *sess.Session, job *model.Job, - byWho model.AdminCommandOperator) (err error) { - if job.IsPausing() || job.IsPaused() { - return dbterror.ErrPausedDDLJob.GenWithStackByArgs(job.ID) - } - if !job.IsPausable() { - errMsg := fmt.Sprintf("state [%s] or schema state [%s]", job.State.String(), job.SchemaState.String()) - err = dbterror.ErrCannotPauseDDLJob.GenWithStackByArgs(job.ID, errMsg) - if err != nil { - return err - } - } - - job.State = model.JobStatePausing - job.AdminOperator = byWho - return nil -} - -// resumePausedJob check and resume the Paused Job -func resumePausedJob(_ *sess.Session, job *model.Job, - byWho model.AdminCommandOperator) (err error) { - if !job.IsResumable() { - errMsg := fmt.Sprintf("job has not been paused, job state:%s, schema state:%s", - job.State, job.SchemaState) - return dbterror.ErrCannotResumeDDLJob.GenWithStackByArgs(job.ID, errMsg) - } - // The Paused job should only be resumed by who paused it - if job.AdminOperator != byWho { - errMsg := fmt.Sprintf("job has been paused by [%s], should not resumed by [%s]", - job.AdminOperator.String(), byWho.String()) - return dbterror.ErrCannotResumeDDLJob.GenWithStackByArgs(job.ID, errMsg) - } - - job.State = model.JobStateQueueing - - return nil -} - -// processJobs command on the Job according to the process -func processJobs( - process func(*sess.Session, *model.Job, model.AdminCommandOperator) (err error), - sessCtx sessionctx.Context, - ids []int64, - byWho model.AdminCommandOperator, -) (jobErrs []error, err error) { - failpoint.Inject("mockFailedCommandOnConcurencyDDL", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(nil, errors.New("mock failed admin command on ddl jobs")) - } - }) - - if len(ids) == 0 { - return nil, nil - } - - ns := sess.NewSession(sessCtx) - // We should process (and try) all the jobs in one Transaction. 
- for tryN := uint(0); tryN < 3; tryN++ { - jobErrs = make([]error, len(ids)) - // Need to figure out which one could not be paused - jobMap := make(map[int64]int, len(ids)) - idsStr := make([]string, 0, len(ids)) - for idx, id := range ids { - jobMap[id] = idx - idsStr = append(idsStr, strconv.FormatInt(id, 10)) - } - - err = ns.Begin(context.Background()) - if err != nil { - return nil, err - } - jobs, err := getJobsBySQL(ns, JobTable, fmt.Sprintf("job_id in (%s) order by job_id", strings.Join(idsStr, ", "))) - if err != nil { - ns.Rollback() - return nil, err - } - - for _, job := range jobs { - i, ok := jobMap[job.ID] - if !ok { - logutil.DDLLogger().Debug("Job ID from meta is not consistent with requested job id,", - zap.Int64("fetched job ID", job.ID)) - jobErrs[i] = dbterror.ErrInvalidDDLJob.GenWithStackByArgs(job.ID) - continue - } - delete(jobMap, job.ID) - - err = process(ns, job, byWho) - if err != nil { - jobErrs[i] = err - continue - } - - err = updateDDLJob2Table(ns, job, false) - if err != nil { - jobErrs[i] = err - continue - } - } - - failpoint.Inject("mockCommitFailedOnDDLCommand", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(jobErrs, errors.New("mock commit failed on admin command on ddl jobs")) - } - }) - - // There may be some conflict during the update, try it again - if err = ns.Commit(context.Background()); err != nil { - continue - } - - for id, idx := range jobMap { - jobErrs[idx] = dbterror.ErrDDLJobNotFound.GenWithStackByArgs(id) - } - - return jobErrs, nil - } - - return jobErrs, err -} - -// CancelJobs cancels the DDL jobs according to user command. -func CancelJobs(se sessionctx.Context, ids []int64) (errs []error, err error) { - return processJobs(cancelRunningJob, se, ids, model.AdminCommandByEndUser) -} - -// PauseJobs pause all the DDL jobs according to user command. -func PauseJobs(se sessionctx.Context, ids []int64) ([]error, error) { - return processJobs(pauseRunningJob, se, ids, model.AdminCommandByEndUser) -} - -// ResumeJobs resume all the DDL jobs according to user command. -func ResumeJobs(se sessionctx.Context, ids []int64) ([]error, error) { - return processJobs(resumePausedJob, se, ids, model.AdminCommandByEndUser) -} - -// CancelJobsBySystem cancels Jobs because of internal reasons. -func CancelJobsBySystem(se sessionctx.Context, ids []int64) (errs []error, err error) { - return processJobs(cancelRunningJob, se, ids, model.AdminCommandBySystem) -} - -// PauseJobsBySystem pauses Jobs because of internal reasons. -func PauseJobsBySystem(se sessionctx.Context, ids []int64) (errs []error, err error) { - return processJobs(pauseRunningJob, se, ids, model.AdminCommandBySystem) -} - -// ResumeJobsBySystem resumes Jobs that are paused by TiDB itself. -func ResumeJobsBySystem(se sessionctx.Context, ids []int64) (errs []error, err error) { - return processJobs(resumePausedJob, se, ids, model.AdminCommandBySystem) -} - -// pprocessAllJobs processes all the jobs in the job table, 100 jobs at a time in case of high memory usage. 
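processJobs above retries the whole read-modify-commit cycle up to three times when the commit fails; a standalone sketch of that bounded-retry shape, assuming a caller-supplied closure and a stand-in conflict error, looks like this:

package main

import (
	"errors"
	"fmt"
)

// errConflict stands in for a commit conflict returned by the storage layer.
var errConflict = errors.New("write conflict")

// runWithRetry retries fn up to three times on conflict, mirroring the
// bounded loop used by processJobs.
func runWithRetry(fn func() error) error {
	var err error
	for tryN := 0; tryN < 3; tryN++ {
		err = fn()
		if err == nil || !errors.Is(err, errConflict) {
			return err
		}
	}
	return err
}

func main() {
	attempts := 0
	err := runWithRetry(func() error {
		attempts++
		if attempts < 3 {
			return errConflict // first two attempts hit a conflict
		}
		return nil
	})
	fmt.Println(attempts, err) // 3 <nil>
}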
-func processAllJobs(process func(*sess.Session, *model.Job, model.AdminCommandOperator) (err error), - se sessionctx.Context, byWho model.AdminCommandOperator) (map[int64]error, error) { - var err error - var jobErrs = make(map[int64]error) - - ns := sess.NewSession(se) - err = ns.Begin(context.Background()) - if err != nil { - return nil, err - } - - var jobID int64 - var jobIDMax int64 - var limit = 100 - for { - var jobs []*model.Job - jobs, err = getJobsBySQL(ns, JobTable, - fmt.Sprintf("job_id >= %s order by job_id asc limit %s", - strconv.FormatInt(jobID, 10), - strconv.FormatInt(int64(limit), 10))) - if err != nil { - ns.Rollback() - return nil, err - } - - for _, job := range jobs { - err = process(ns, job, byWho) - if err != nil { - jobErrs[job.ID] = err - continue - } - - err = updateDDLJob2Table(ns, job, false) - if err != nil { - jobErrs[job.ID] = err - continue - } - } - - // Just in case the job ID is not sequential - if len(jobs) > 0 && jobs[len(jobs)-1].ID > jobIDMax { - jobIDMax = jobs[len(jobs)-1].ID - } - - // If rows returned is smaller than $limit, then there is no more records - if len(jobs) < limit { - break - } - - jobID = jobIDMax + 1 - } - - err = ns.Commit(context.Background()) - if err != nil { - return nil, err - } - return jobErrs, nil -} - -// PauseAllJobsBySystem pauses all running Jobs because of internal reasons. -func PauseAllJobsBySystem(se sessionctx.Context) (map[int64]error, error) { - return processAllJobs(pauseRunningJob, se, model.AdminCommandBySystem) -} - -// ResumeAllJobsBySystem resumes all paused Jobs because of internal reasons. -func ResumeAllJobsBySystem(se sessionctx.Context) (map[int64]error, error) { - return processAllJobs(resumePausedJob, se, model.AdminCommandBySystem) -} - -// GetAllDDLJobs get all DDL jobs and sorts jobs by job.ID. -func GetAllDDLJobs(se sessionctx.Context) ([]*model.Job, error) { - return getJobsBySQL(sess.NewSession(se), JobTable, "1 order by job_id") -} - -// IterAllDDLJobs will iterates running DDL jobs first, return directly if `finishFn` return true or error, -// then iterates history DDL jobs until the `finishFn` return true or error. -func IterAllDDLJobs(ctx sessionctx.Context, txn kv.Transaction, finishFn func([]*model.Job) (bool, error)) error { - jobs, err := GetAllDDLJobs(ctx) - if err != nil { - return err - } - - finish, err := finishFn(jobs) - if err != nil || finish { - return err - } - return IterHistoryDDLJobs(txn, finishFn) -} diff --git a/pkg/ddl/ddl_tiflash_api.go b/pkg/ddl/ddl_tiflash_api.go index d232d78758f91..9f4c2512f0019 100644 --- a/pkg/ddl/ddl_tiflash_api.go +++ b/pkg/ddl/ddl_tiflash_api.go @@ -352,9 +352,9 @@ func updateTiFlashStores(pollTiFlashContext *TiFlashManagementContext) error { // PollAvailableTableProgress will poll and check availability of available tables. 
func PollAvailableTableProgress(schemas infoschema.InfoSchema, _ sessionctx.Context, pollTiFlashContext *TiFlashManagementContext) { pollMaxCount := RefreshProgressMaxTableCount - if val, _err_ := failpoint.Eval(_curpkg_("PollAvailableTableProgressMaxCount")); _err_ == nil { + failpoint.Inject("PollAvailableTableProgressMaxCount", func(val failpoint.Value) { pollMaxCount = uint64(val.(int)) - } + }) for element := pollTiFlashContext.UpdatingProgressTables.Front(); element != nil && pollMaxCount > 0; pollMaxCount-- { availableTableID := element.Value.(AvailableTableID) var table table.Table @@ -431,13 +431,13 @@ func (d *ddl) refreshTiFlashTicker(ctx sessionctx.Context, pollTiFlashContext *T } } - if _, _err_ := failpoint.Eval(_curpkg_("OneTiFlashStoreDown")); _err_ == nil { + failpoint.Inject("OneTiFlashStoreDown", func() { for storeID, store := range pollTiFlashContext.TiFlashStores { store.Store.StateName = "Down" pollTiFlashContext.TiFlashStores[storeID] = store break } - } + }) pollTiFlashContext.PollCounter++ // Start to process every table. @@ -458,7 +458,7 @@ func (d *ddl) refreshTiFlashTicker(ctx sessionctx.Context, pollTiFlashContext *T } } - if val, _err_ := failpoint.Eval(_curpkg_("waitForAddPartition")); _err_ == nil { + failpoint.Inject("waitForAddPartition", func(val failpoint.Value) { for _, phyTable := range tableList { is := d.infoCache.GetLatest() _, ok := is.TableByID(phyTable.ID) @@ -471,7 +471,7 @@ func (d *ddl) refreshTiFlashTicker(ctx sessionctx.Context, pollTiFlashContext *T } } } - } + }) needPushPending := false if pollTiFlashContext.UpdatingProgressTables.Len() == 0 { @@ -482,9 +482,9 @@ func (d *ddl) refreshTiFlashTicker(ctx sessionctx.Context, pollTiFlashContext *T // For every region in each table, if it has one replica, we reckon it ready. // These request can be batched as an optimization. available := tb.Available - if val, _err_ := failpoint.Eval(_curpkg_("PollTiFlashReplicaStatusReplacePrevAvailableValue")); _err_ == nil { + failpoint.Inject("PollTiFlashReplicaStatusReplacePrevAvailableValue", func(val failpoint.Value) { available = val.(bool) - } + }) // We only check unavailable tables here, so doesn't include blocked add partition case. if !available && !tb.LogicalTableAvailable { enabled, inqueue, _ := pollTiFlashContext.Backoff.Tick(tb.ID) @@ -514,9 +514,9 @@ func (d *ddl) refreshTiFlashTicker(ctx sessionctx.Context, pollTiFlashContext *T } avail := progress == 1 - if val, _err_ := failpoint.Eval(_curpkg_("PollTiFlashReplicaStatusReplaceCurAvailableValue")); _err_ == nil { + failpoint.Inject("PollTiFlashReplicaStatusReplaceCurAvailableValue", func(val failpoint.Value) { avail = val.(bool) - } + }) if !avail { logutil.DDLLogger().Info("Tiflash replica is not available", zap.Int64("tableID", tb.ID), zap.Float64("progress", progress)) @@ -525,9 +525,9 @@ func (d *ddl) refreshTiFlashTicker(ctx sessionctx.Context, pollTiFlashContext *T logutil.DDLLogger().Info("Tiflash replica is available", zap.Int64("tableID", tb.ID), zap.Float64("progress", progress)) pollTiFlashContext.Backoff.Remove(tb.ID) } - if _, _err_ := failpoint.Eval(_curpkg_("skipUpdateTableReplicaInfoInLoop")); _err_ == nil { - continue - } + failpoint.Inject("skipUpdateTableReplicaInfoInLoop", func() { + failpoint.Continue() + }) // Will call `onUpdateFlashReplicaStatus` to update `TiFlashReplica`. 
if err := d.executor.UpdateTableReplicaInfo(ctx, tb.ID, avail); err != nil { if infoschema.ErrTableNotExists.Equal(err) && tb.IsPartition { @@ -566,9 +566,9 @@ func (d *ddl) PollTiFlashRoutine() { logutil.DDLLogger().Error("failed to get sessionPool for refreshTiFlashTicker") return } - if _, _err_ := failpoint.Eval(_curpkg_("BeforeRefreshTiFlashTickeLoop")); _err_ == nil { - continue - } + failpoint.Inject("BeforeRefreshTiFlashTickeLoop", func() { + failpoint.Continue() + }) if !hasSetTiFlashGroup && !time.Now().Before(nextSetTiFlashGroupTime) { // We should set tiflash rule group a higher index than other placement groups to forbid override by them. diff --git a/pkg/ddl/ddl_tiflash_api.go__failpoint_stash__ b/pkg/ddl/ddl_tiflash_api.go__failpoint_stash__ deleted file mode 100644 index 9f4c2512f0019..0000000000000 --- a/pkg/ddl/ddl_tiflash_api.go__failpoint_stash__ +++ /dev/null @@ -1,608 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2013 The ql Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSES/QL-LICENSE file. - -package ddl - -import ( - "bytes" - "container/list" - "context" - "encoding/json" - "fmt" - "net" - "strconv" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/ddl/logutil" - ddlutil "github.com/pingcap/tidb/pkg/ddl/util" - "github.com/pingcap/tidb/pkg/domain/infosync" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/engine" - "github.com/pingcap/tidb/pkg/util/intest" - pd "github.com/tikv/pd/client/http" - atomicutil "go.uber.org/atomic" - "go.uber.org/zap" -) - -// TiFlashReplicaStatus records status for each TiFlash replica. -type TiFlashReplicaStatus struct { - ID int64 - Count uint64 - LocationLabels []string - Available bool - LogicalTableAvailable bool - HighPriority bool - IsPartition bool -} - -// TiFlashTick is type for backoff threshold. -type TiFlashTick float64 - -// PollTiFlashBackoffElement records backoff for each TiFlash Table. -// `Counter` increases every `Tick`, if it reached `Threshold`, it will be reset to 0 while `Threshold` grows. -// `TotalCounter` records total `Tick`s this element has since created. -type PollTiFlashBackoffElement struct { - Counter int - Threshold TiFlashTick - TotalCounter int -} - -// NewPollTiFlashBackoffElement initialize backoff element for a TiFlash table. -func NewPollTiFlashBackoffElement() *PollTiFlashBackoffElement { - return &PollTiFlashBackoffElement{ - Counter: 0, - Threshold: PollTiFlashBackoffMinTick, - TotalCounter: 0, - } -} - -// PollTiFlashBackoffContext is a collection of all backoff states. 
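The hunks above, together with the deleted __failpoint_stash__ files, restore the plain failpoint.Inject markers that failpoint-ctl had expanded into failpoint.Eval calls. For context, a test would activate such a marker roughly as follows; the failpoint path shown is an assumed example built from the package import path plus the marker name:

package ddl_test

import (
	"testing"

	"github.com/pingcap/failpoint"
)

func TestOneTiFlashStoreDown(t *testing.T) {
	// The path is "<package import path>/<failpoint name>"; this value is assumed for illustration.
	fp := "github.com/pingcap/tidb/pkg/ddl/OneTiFlashStoreDown"
	if err := failpoint.Enable(fp, "return(true)"); err != nil {
		t.Fatal(err)
	}
	defer func() {
		if err := failpoint.Disable(fp); err != nil {
			t.Fatal(err)
		}
	}()
	// ... exercise refreshTiFlashTicker here; the injected body marks one TiFlash store as "Down".
}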
-type PollTiFlashBackoffContext struct { - MinThreshold TiFlashTick - MaxThreshold TiFlashTick - // Capacity limits tables a backoff pool can handle, in order to limit handling of big tables. - Capacity int - Rate TiFlashTick - elements map[int64]*PollTiFlashBackoffElement -} - -// NewPollTiFlashBackoffContext creates an instance of PollTiFlashBackoffContext. -func NewPollTiFlashBackoffContext(minThreshold, maxThreshold TiFlashTick, capacity int, rate TiFlashTick) (*PollTiFlashBackoffContext, error) { - if maxThreshold < minThreshold { - return nil, fmt.Errorf("`maxThreshold` should always be larger than `minThreshold`") - } - if minThreshold < 1 { - return nil, fmt.Errorf("`minThreshold` should not be less than 1") - } - if capacity < 0 { - return nil, fmt.Errorf("negative `capacity`") - } - if rate <= 1 { - return nil, fmt.Errorf("`rate` should always be larger than 1") - } - return &PollTiFlashBackoffContext{ - MinThreshold: minThreshold, - MaxThreshold: maxThreshold, - Capacity: capacity, - elements: make(map[int64]*PollTiFlashBackoffElement), - Rate: rate, - }, nil -} - -// TiFlashManagementContext is the context for TiFlash Replica Management -type TiFlashManagementContext struct { - TiFlashStores map[int64]pd.StoreInfo - PollCounter uint64 - Backoff *PollTiFlashBackoffContext - // tables waiting for updating progress after become available. - UpdatingProgressTables *list.List -} - -// AvailableTableID is the table id info of available table for waiting to update TiFlash replica progress. -type AvailableTableID struct { - ID int64 - IsPartition bool -} - -// Tick will first check increase Counter. -// It returns: -// 1. A bool indicates whether threshold is grown during this tick. -// 2. A bool indicates whether this ID exists. -// 3. A int indicates how many ticks ID has counted till now. -func (b *PollTiFlashBackoffContext) Tick(id int64) (grew bool, exist bool, cnt int) { - e, ok := b.Get(id) - if !ok { - return false, false, 0 - } - grew = e.MaybeGrow(b) - e.Counter++ - e.TotalCounter++ - return grew, true, e.TotalCounter -} - -// NeedGrow returns if we need to grow. -// It is exported for testing. -func (e *PollTiFlashBackoffElement) NeedGrow() bool { - return e.Counter >= int(e.Threshold) -} - -func (e *PollTiFlashBackoffElement) doGrow(b *PollTiFlashBackoffContext) { - if e.Threshold < b.MinThreshold { - e.Threshold = b.MinThreshold - } - if e.Threshold*b.Rate > b.MaxThreshold { - e.Threshold = b.MaxThreshold - } else { - e.Threshold *= b.Rate - } - e.Counter = 0 -} - -// MaybeGrow grows threshold and reset counter when needed. -func (e *PollTiFlashBackoffElement) MaybeGrow(b *PollTiFlashBackoffContext) bool { - if !e.NeedGrow() { - return false - } - e.doGrow(b) - return true -} - -// Remove will reset table from backoff. -func (b *PollTiFlashBackoffContext) Remove(id int64) bool { - _, ok := b.elements[id] - delete(b.elements, id) - return ok -} - -// Get returns pointer to inner PollTiFlashBackoffElement. -// Only exported for test. -func (b *PollTiFlashBackoffContext) Get(id int64) (*PollTiFlashBackoffElement, bool) { - res, ok := b.elements[id] - return res, ok -} - -// Put will record table into backoff pool, if there is enough room, or returns false. -func (b *PollTiFlashBackoffContext) Put(id int64) bool { - _, ok := b.elements[id] - if ok { - return true - } else if b.Len() < b.Capacity { - b.elements[id] = NewPollTiFlashBackoffElement() - return true - } - return false -} - -// Len gets size of PollTiFlashBackoffContext. 
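To make the backoff schedule concrete, the following standalone snippet replays the doGrow rule with the package defaults (min tick 1, max tick 10, rate 1.5); the numbers are plain arithmetic, not output captured from a test:

package main

import "fmt"

func main() {
	const (
		minThreshold = 1.0  // PollTiFlashBackoffMinTick
		maxThreshold = 10.0 // PollTiFlashBackoffMaxTick
		rate         = 1.5  // PollTiFlashBackoffRate
	)
	threshold := minThreshold
	for i := 0; i < 8; i++ {
		fmt.Printf("grow %d: threshold=%.2f\n", i, threshold)
		// Mirror doGrow: multiply by rate, capped at maxThreshold.
		if threshold*rate > maxThreshold {
			threshold = maxThreshold
		} else {
			threshold *= rate
		}
	}
	// Thresholds printed: 1.00, 1.50, 2.25, 3.38, 5.06, 7.59, 10.00, 10.00;
	// an unavailable table is re-checked less and less often, topping out at roughly every 10 ticks.
}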
-func (b *PollTiFlashBackoffContext) Len() int { - return len(b.elements) -} - -// NewTiFlashManagementContext creates an instance for TiFlashManagementContext. -func NewTiFlashManagementContext() (*TiFlashManagementContext, error) { - c, err := NewPollTiFlashBackoffContext(PollTiFlashBackoffMinTick, PollTiFlashBackoffMaxTick, PollTiFlashBackoffCapacity, PollTiFlashBackoffRate) - if err != nil { - return nil, err - } - return &TiFlashManagementContext{ - PollCounter: 0, - TiFlashStores: make(map[int64]pd.StoreInfo), - Backoff: c, - UpdatingProgressTables: list.New(), - }, nil -} - -var ( - // PollTiFlashInterval is the interval between every pollTiFlashReplicaStatus call. - PollTiFlashInterval = 2 * time.Second - // PullTiFlashPdTick indicates the number of intervals before we fully sync all TiFlash pd rules and tables. - PullTiFlashPdTick = atomicutil.NewUint64(30 * 5) - // UpdateTiFlashStoreTick indicates the number of intervals before we fully update TiFlash stores. - UpdateTiFlashStoreTick = atomicutil.NewUint64(5) - // PollTiFlashBackoffMaxTick is the max tick before we try to update TiFlash replica availability for one table. - PollTiFlashBackoffMaxTick TiFlashTick = 10 - // PollTiFlashBackoffMinTick is the min tick before we try to update TiFlash replica availability for one table. - PollTiFlashBackoffMinTick TiFlashTick = 1 - // PollTiFlashBackoffCapacity is the cache size of backoff struct. - PollTiFlashBackoffCapacity = 1000 - // PollTiFlashBackoffRate is growth rate of exponential backoff threshold. - PollTiFlashBackoffRate TiFlashTick = 1.5 - // RefreshProgressMaxTableCount is the max count of table to refresh progress after available each poll. - RefreshProgressMaxTableCount uint64 = 1000 -) - -func getTiflashHTTPAddr(host string, statusAddr string) (string, error) { - configURL := fmt.Sprintf("%s://%s/config", - util.InternalHTTPSchema(), - statusAddr, - ) - resp, err := util.InternalHTTPClient().Get(configURL) - if err != nil { - return "", errors.Trace(err) - } - defer func() { - resp.Body.Close() - }() - - buf := new(bytes.Buffer) - _, err = buf.ReadFrom(resp.Body) - if err != nil { - return "", errors.Trace(err) - } - - var j map[string]any - err = json.Unmarshal(buf.Bytes(), &j) - if err != nil { - return "", errors.Trace(err) - } - - engineStore, ok := j["engine-store"].(map[string]any) - if !ok { - return "", errors.New("Error json") - } - port64, ok := engineStore["http_port"].(float64) - if !ok { - return "", errors.New("Error json") - } - - addr := net.JoinHostPort(host, strconv.FormatUint(uint64(port64), 10)) - return addr, nil -} - -// LoadTiFlashReplicaInfo parses model.TableInfo into []TiFlashReplicaStatus. 
-func LoadTiFlashReplicaInfo(tblInfo *model.TableInfo, tableList *[]TiFlashReplicaStatus) { - if tblInfo.TiFlashReplica == nil { - // reject tables that has no tiflash replica such like `INFORMATION_SCHEMA` - return - } - if pi := tblInfo.GetPartitionInfo(); pi != nil { - for _, p := range pi.Definitions { - logutil.DDLLogger().Debug(fmt.Sprintf("Table %v has partition %v\n", tblInfo.ID, p.ID)) - *tableList = append(*tableList, TiFlashReplicaStatus{p.ID, - tblInfo.TiFlashReplica.Count, tblInfo.TiFlashReplica.LocationLabels, tblInfo.TiFlashReplica.IsPartitionAvailable(p.ID), tblInfo.TiFlashReplica.Available, false, true}) - } - // partitions that in adding mid-state - for _, p := range pi.AddingDefinitions { - logutil.DDLLogger().Debug(fmt.Sprintf("Table %v has partition adding %v\n", tblInfo.ID, p.ID)) - *tableList = append(*tableList, TiFlashReplicaStatus{p.ID, tblInfo.TiFlashReplica.Count, tblInfo.TiFlashReplica.LocationLabels, tblInfo.TiFlashReplica.IsPartitionAvailable(p.ID), tblInfo.TiFlashReplica.Available, true, true}) - } - } else { - logutil.DDLLogger().Debug(fmt.Sprintf("Table %v has no partition\n", tblInfo.ID)) - *tableList = append(*tableList, TiFlashReplicaStatus{tblInfo.ID, tblInfo.TiFlashReplica.Count, tblInfo.TiFlashReplica.LocationLabels, tblInfo.TiFlashReplica.Available, tblInfo.TiFlashReplica.Available, false, false}) - } -} - -// UpdateTiFlashHTTPAddress report TiFlash's StatusAddress's port to Pd's etcd. -func (d *ddl) UpdateTiFlashHTTPAddress(store *pd.StoreInfo) error { - host, _, err := net.SplitHostPort(store.Store.StatusAddress) - if err != nil { - return errors.Trace(err) - } - httpAddr, err := getTiflashHTTPAddr(host, store.Store.StatusAddress) - if err != nil { - return errors.Trace(err) - } - // Report to pd - key := fmt.Sprintf("/tiflash/cluster/http_port/%v", store.Store.Address) - if d.etcdCli == nil { - return errors.New("no etcdCli in ddl") - } - origin := "" - resp, err := d.etcdCli.Get(d.ctx, key) - if err != nil { - return errors.Trace(err) - } - // Try to update. - for _, kv := range resp.Kvs { - if string(kv.Key) == key { - origin = string(kv.Value) - break - } - } - if origin != httpAddr { - logutil.DDLLogger().Warn(fmt.Sprintf("Update status addr of %v from %v to %v", key, origin, httpAddr)) - err := ddlutil.PutKVToEtcd(d.ctx, d.etcdCli, 1, key, httpAddr) - if err != nil { - return errors.Trace(err) - } - } - - return nil -} - -func updateTiFlashStores(pollTiFlashContext *TiFlashManagementContext) error { - // We need the up-to-date information about TiFlash stores. - // Since TiFlash Replica synchronize may happen immediately after new TiFlash stores are added. - tikvStats, err := infosync.GetTiFlashStoresStat(context.Background()) - // If MockTiFlash is not set, will issue a MockTiFlashError here. - if err != nil { - return err - } - pollTiFlashContext.TiFlashStores = make(map[int64]pd.StoreInfo) - for _, store := range tikvStats.Stores { - if engine.IsTiFlashHTTPResp(&store.Store) { - pollTiFlashContext.TiFlashStores[store.Store.ID] = store - } - } - logutil.DDLLogger().Debug("updateTiFlashStores finished", zap.Int("TiFlash store count", len(pollTiFlashContext.TiFlashStores))) - return nil -} - -// PollAvailableTableProgress will poll and check availability of available tables. 
-func PollAvailableTableProgress(schemas infoschema.InfoSchema, _ sessionctx.Context, pollTiFlashContext *TiFlashManagementContext) { - pollMaxCount := RefreshProgressMaxTableCount - failpoint.Inject("PollAvailableTableProgressMaxCount", func(val failpoint.Value) { - pollMaxCount = uint64(val.(int)) - }) - for element := pollTiFlashContext.UpdatingProgressTables.Front(); element != nil && pollMaxCount > 0; pollMaxCount-- { - availableTableID := element.Value.(AvailableTableID) - var table table.Table - if availableTableID.IsPartition { - table, _, _ = schemas.FindTableByPartitionID(availableTableID.ID) - if table == nil { - logutil.DDLLogger().Info("get table by partition failed, may be dropped or truncated", - zap.Int64("partitionID", availableTableID.ID), - ) - pollTiFlashContext.UpdatingProgressTables.Remove(element) - element = element.Next() - continue - } - } else { - var ok bool - table, ok = schemas.TableByID(availableTableID.ID) - if !ok { - logutil.DDLLogger().Info("get table id failed, may be dropped or truncated", - zap.Int64("tableID", availableTableID.ID), - ) - pollTiFlashContext.UpdatingProgressTables.Remove(element) - element = element.Next() - continue - } - } - tableInfo := table.Meta() - if tableInfo.TiFlashReplica == nil { - logutil.DDLLogger().Info("table has no TiFlash replica", - zap.Int64("tableID or partitionID", availableTableID.ID), - zap.Bool("IsPartition", availableTableID.IsPartition), - ) - pollTiFlashContext.UpdatingProgressTables.Remove(element) - element = element.Next() - continue - } - - progress, err := infosync.CalculateTiFlashProgress(availableTableID.ID, tableInfo.TiFlashReplica.Count, pollTiFlashContext.TiFlashStores) - if err != nil { - if intest.InTest && err.Error() != "EOF" { - // In the test, the server cannot start up because the port is occupied. - // Although the port is random. so we need to quickly return when to - // fail to get tiflash sync. - // https://github.com/pingcap/tidb/issues/39949 - panic(err) - } - pollTiFlashContext.UpdatingProgressTables.Remove(element) - element = element.Next() - continue - } - err = infosync.UpdateTiFlashProgressCache(availableTableID.ID, progress) - if err != nil { - logutil.DDLLogger().Error("update tiflash sync progress cache failed", - zap.Error(err), - zap.Int64("tableID", availableTableID.ID), - zap.Bool("IsPartition", availableTableID.IsPartition), - zap.Float64("progress", progress), - ) - pollTiFlashContext.UpdatingProgressTables.Remove(element) - element = element.Next() - continue - } - next := element.Next() - pollTiFlashContext.UpdatingProgressTables.Remove(element) - element = next - } -} - -func (d *ddl) refreshTiFlashTicker(ctx sessionctx.Context, pollTiFlashContext *TiFlashManagementContext) error { - if pollTiFlashContext.PollCounter%UpdateTiFlashStoreTick.Load() == 0 { - if err := updateTiFlashStores(pollTiFlashContext); err != nil { - // If we failed to get from pd, retry everytime. - pollTiFlashContext.PollCounter = 0 - return err - } - } - - failpoint.Inject("OneTiFlashStoreDown", func() { - for storeID, store := range pollTiFlashContext.TiFlashStores { - store.Store.StateName = "Down" - pollTiFlashContext.TiFlashStores[storeID] = store - break - } - }) - pollTiFlashContext.PollCounter++ - - // Start to process every table. 
- schema := d.infoCache.GetLatest() - if schema == nil { - return errors.New("Schema is nil") - } - - PollAvailableTableProgress(schema, ctx, pollTiFlashContext) - - var tableList = make([]TiFlashReplicaStatus, 0) - - // Collect TiFlash Replica info, for every table. - ch := schema.ListTablesWithSpecialAttribute(infoschema.TiFlashAttribute) - for _, v := range ch { - for _, tblInfo := range v.TableInfos { - LoadTiFlashReplicaInfo(tblInfo, &tableList) - } - } - - failpoint.Inject("waitForAddPartition", func(val failpoint.Value) { - for _, phyTable := range tableList { - is := d.infoCache.GetLatest() - _, ok := is.TableByID(phyTable.ID) - if !ok { - tb, _, _ := is.FindTableByPartitionID(phyTable.ID) - if tb == nil { - logutil.DDLLogger().Info("waitForAddPartition") - sleepSecond := val.(int) - time.Sleep(time.Duration(sleepSecond) * time.Second) - } - } - } - }) - - needPushPending := false - if pollTiFlashContext.UpdatingProgressTables.Len() == 0 { - needPushPending = true - } - - for _, tb := range tableList { - // For every region in each table, if it has one replica, we reckon it ready. - // These request can be batched as an optimization. - available := tb.Available - failpoint.Inject("PollTiFlashReplicaStatusReplacePrevAvailableValue", func(val failpoint.Value) { - available = val.(bool) - }) - // We only check unavailable tables here, so doesn't include blocked add partition case. - if !available && !tb.LogicalTableAvailable { - enabled, inqueue, _ := pollTiFlashContext.Backoff.Tick(tb.ID) - if inqueue && !enabled { - logutil.DDLLogger().Info("Escape checking available status due to backoff", zap.Int64("tableId", tb.ID)) - continue - } - - progress, err := infosync.CalculateTiFlashProgress(tb.ID, tb.Count, pollTiFlashContext.TiFlashStores) - if err != nil { - logutil.DDLLogger().Error("get tiflash sync progress failed", - zap.Error(err), - zap.Int64("tableID", tb.ID), - ) - continue - } - - err = infosync.UpdateTiFlashProgressCache(tb.ID, progress) - if err != nil { - logutil.DDLLogger().Error("get tiflash sync progress from cache failed", - zap.Error(err), - zap.Int64("tableID", tb.ID), - zap.Bool("IsPartition", tb.IsPartition), - zap.Float64("progress", progress), - ) - continue - } - - avail := progress == 1 - failpoint.Inject("PollTiFlashReplicaStatusReplaceCurAvailableValue", func(val failpoint.Value) { - avail = val.(bool) - }) - - if !avail { - logutil.DDLLogger().Info("Tiflash replica is not available", zap.Int64("tableID", tb.ID), zap.Float64("progress", progress)) - pollTiFlashContext.Backoff.Put(tb.ID) - } else { - logutil.DDLLogger().Info("Tiflash replica is available", zap.Int64("tableID", tb.ID), zap.Float64("progress", progress)) - pollTiFlashContext.Backoff.Remove(tb.ID) - } - failpoint.Inject("skipUpdateTableReplicaInfoInLoop", func() { - failpoint.Continue() - }) - // Will call `onUpdateFlashReplicaStatus` to update `TiFlashReplica`. 
- if err := d.executor.UpdateTableReplicaInfo(ctx, tb.ID, avail); err != nil { - if infoschema.ErrTableNotExists.Equal(err) && tb.IsPartition { - // May be due to blocking add partition - logutil.DDLLogger().Info("updating TiFlash replica status err, maybe false alarm by blocking add", zap.Error(err), zap.Int64("tableID", tb.ID), zap.Bool("isPartition", tb.IsPartition)) - } else { - logutil.DDLLogger().Error("updating TiFlash replica status err", zap.Error(err), zap.Int64("tableID", tb.ID), zap.Bool("isPartition", tb.IsPartition)) - } - } - } else { - if needPushPending { - pollTiFlashContext.UpdatingProgressTables.PushFront(AvailableTableID{tb.ID, tb.IsPartition}) - } - } - } - - return nil -} - -func (d *ddl) PollTiFlashRoutine() { - pollTiflashContext, err := NewTiFlashManagementContext() - if err != nil { - logutil.DDLLogger().Fatal("TiFlashManagement init failed", zap.Error(err)) - } - - hasSetTiFlashGroup := false - nextSetTiFlashGroupTime := time.Now() - for { - select { - case <-d.ctx.Done(): - return - case <-time.After(PollTiFlashInterval): - } - if d.IsTiFlashPollEnabled() { - if d.sessPool == nil { - logutil.DDLLogger().Error("failed to get sessionPool for refreshTiFlashTicker") - return - } - failpoint.Inject("BeforeRefreshTiFlashTickeLoop", func() { - failpoint.Continue() - }) - - if !hasSetTiFlashGroup && !time.Now().Before(nextSetTiFlashGroupTime) { - // We should set tiflash rule group a higher index than other placement groups to forbid override by them. - // Once `SetTiFlashGroupConfig` succeed, we do not need to invoke it again. If failed, we should retry it util success. - if err = infosync.SetTiFlashGroupConfig(d.ctx); err != nil { - logutil.DDLLogger().Warn("SetTiFlashGroupConfig failed", zap.Error(err)) - nextSetTiFlashGroupTime = time.Now().Add(time.Minute) - } else { - hasSetTiFlashGroup = true - } - } - - sctx, err := d.sessPool.Get() - if err == nil { - if d.ownerManager.IsOwner() { - err := d.refreshTiFlashTicker(sctx, pollTiflashContext) - if err != nil { - switch err.(type) { - case *infosync.MockTiFlashError: - // If we have not set up MockTiFlash instance, for those tests without TiFlash, just suppress. 
- default: - logutil.DDLLogger().Warn("refreshTiFlashTicker returns error", zap.Error(err)) - } - } - } else { - infosync.CleanTiFlashProgressCache() - } - d.sessPool.Put(sctx) - } else { - if sctx != nil { - d.sessPool.Put(sctx) - } - logutil.DDLLogger().Error("failed to get session for pollTiFlashReplicaStatus", zap.Error(err)) - } - } - } -} diff --git a/pkg/ddl/delete_range.go b/pkg/ddl/delete_range.go index bd5e0920aa000..bafa01c847977 100644 --- a/pkg/ddl/delete_range.go +++ b/pkg/ddl/delete_range.go @@ -372,11 +372,11 @@ func insertJobIntoDeleteRangeTable(ctx context.Context, wrapper DelRangeExecWrap if len(partitionIDs) == 0 { return errors.Trace(doBatchDeleteIndiceRange(ctx, wrapper, job.ID, tableID, allIndexIDs, ea, "drop index: table ID")) } - if val, _err_ := failpoint.Eval(_curpkg_("checkDropGlobalIndex")); _err_ == nil { + failpoint.Inject("checkDropGlobalIndex", func(val failpoint.Value) { if val.(bool) { panic("drop global index must not delete partition index range") } - } + }) for _, pid := range partitionIDs { if err := doBatchDeleteIndiceRange(ctx, wrapper, job.ID, pid, allIndexIDs, ea, "drop index: partition table ID"); err != nil { return errors.Trace(err) diff --git a/pkg/ddl/delete_range.go__failpoint_stash__ b/pkg/ddl/delete_range.go__failpoint_stash__ deleted file mode 100644 index bafa01c847977..0000000000000 --- a/pkg/ddl/delete_range.go__failpoint_stash__ +++ /dev/null @@ -1,548 +0,0 @@ -// Copyright 2017 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ddl - -import ( - "context" - "encoding/hex" - "math" - "strings" - "sync" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/tidb/pkg/ddl/logutil" - sess "github.com/pingcap/tidb/pkg/ddl/session" - "github.com/pingcap/tidb/pkg/ddl/util" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/tablecodec" - topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" - "go.uber.org/zap" -) - -const ( - insertDeleteRangeSQLPrefix = `INSERT IGNORE INTO mysql.gc_delete_range VALUES ` - insertDeleteRangeSQLValue = `(%?, %?, %?, %?, %?)` - - delBatchSize = 65536 - delBackLog = 128 -) - -// Only used in the BR unit test. Once these const variables modified, please make sure compatible with BR. -const ( - BRInsertDeleteRangeSQLPrefix = insertDeleteRangeSQLPrefix - BRInsertDeleteRangeSQLValue = insertDeleteRangeSQLValue -) - -var ( - // batchInsertDeleteRangeSize is the maximum size for each batch insert statement in the delete-range. - batchInsertDeleteRangeSize = 256 -) - -type delRangeManager interface { - // addDelRangeJob add a DDL job into gc_delete_range table. - addDelRangeJob(ctx context.Context, job *model.Job) error - // removeFromGCDeleteRange removes the deleting table job from gc_delete_range table by jobID and tableID. 
- // It's use for recover the table that was mistakenly deleted. - removeFromGCDeleteRange(ctx context.Context, jobID int64) error - start() - clear() -} - -type delRange struct { - store kv.Storage - sessPool *sess.Pool - emulatorCh chan struct{} - keys []kv.Key - quitCh chan struct{} - - wait sync.WaitGroup // wait is only used when storeSupport is false. - storeSupport bool -} - -// newDelRangeManager returns a delRangeManager. -func newDelRangeManager(store kv.Storage, sessPool *sess.Pool) delRangeManager { - dr := &delRange{ - store: store, - sessPool: sessPool, - storeSupport: store.SupportDeleteRange(), - quitCh: make(chan struct{}), - } - if !dr.storeSupport { - dr.emulatorCh = make(chan struct{}, delBackLog) - dr.keys = make([]kv.Key, 0, delBatchSize) - } - return dr -} - -// addDelRangeJob implements delRangeManager interface. -func (dr *delRange) addDelRangeJob(ctx context.Context, job *model.Job) error { - sctx, err := dr.sessPool.Get() - if err != nil { - return errors.Trace(err) - } - defer dr.sessPool.Put(sctx) - - // The same Job ID uses the same element ID allocator - wrapper := newDelRangeExecWrapper(sctx) - err = AddDelRangeJobInternal(ctx, wrapper, job) - if err != nil { - logutil.DDLLogger().Error("add job into delete-range table failed", zap.Int64("jobID", job.ID), zap.String("jobType", job.Type.String()), zap.Error(err)) - return errors.Trace(err) - } - if !dr.storeSupport { - dr.emulatorCh <- struct{}{} - } - logutil.DDLLogger().Info("add job into delete-range table", zap.Int64("jobID", job.ID), zap.String("jobType", job.Type.String())) - return nil -} - -// AddDelRangeJobInternal implements the generation the delete ranges for the provided job and consumes the delete ranges through delRangeExecWrapper. -func AddDelRangeJobInternal(ctx context.Context, wrapper DelRangeExecWrapper, job *model.Job) error { - var err error - var ea elementIDAlloc - if job.MultiSchemaInfo != nil { - err = insertJobIntoDeleteRangeTableMultiSchema(ctx, wrapper, job, &ea) - } else { - err = insertJobIntoDeleteRangeTable(ctx, wrapper, job, &ea) - } - return errors.Trace(err) -} - -func insertJobIntoDeleteRangeTableMultiSchema(ctx context.Context, wrapper DelRangeExecWrapper, job *model.Job, ea *elementIDAlloc) error { - for i, sub := range job.MultiSchemaInfo.SubJobs { - proxyJob := sub.ToProxyJob(job, i) - if JobNeedGC(&proxyJob) { - err := insertJobIntoDeleteRangeTable(ctx, wrapper, &proxyJob, ea) - if err != nil { - return errors.Trace(err) - } - } - } - return nil -} - -// removeFromGCDeleteRange implements delRangeManager interface. -func (dr *delRange) removeFromGCDeleteRange(ctx context.Context, jobID int64) error { - sctx, err := dr.sessPool.Get() - if err != nil { - return errors.Trace(err) - } - defer dr.sessPool.Put(sctx) - err = util.RemoveMultiFromGCDeleteRange(ctx, sctx, jobID) - return errors.Trace(err) -} - -// start implements delRangeManager interface. -func (dr *delRange) start() { - if !dr.storeSupport { - dr.wait.Add(1) - go dr.startEmulator() - } -} - -// clear implements delRangeManager interface. -func (dr *delRange) clear() { - logutil.DDLLogger().Info("closing delRange") - close(dr.quitCh) - dr.wait.Wait() -} - -// startEmulator is only used for those storage engines which don't support -// delete-range. The emulator fetches records from gc_delete_range table and -// deletes all keys in each DelRangeTask. 
-func (dr *delRange) startEmulator() { - defer dr.wait.Done() - logutil.DDLLogger().Info("start delRange emulator") - for { - select { - case <-dr.emulatorCh: - case <-dr.quitCh: - return - } - if util.IsEmulatorGCEnable() { - err := dr.doDelRangeWork() - terror.Log(errors.Trace(err)) - } - } -} - -func (dr *delRange) doDelRangeWork() error { - sctx, err := dr.sessPool.Get() - if err != nil { - logutil.DDLLogger().Error("delRange emulator get session failed", zap.Error(err)) - return errors.Trace(err) - } - defer dr.sessPool.Put(sctx) - - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) - ranges, err := util.LoadDeleteRanges(ctx, sctx, math.MaxInt64) - if err != nil { - logutil.DDLLogger().Error("delRange emulator load tasks failed", zap.Error(err)) - return errors.Trace(err) - } - - for _, r := range ranges { - if err := dr.doTask(sctx, r); err != nil { - logutil.DDLLogger().Error("delRange emulator do task failed", zap.Error(err)) - return errors.Trace(err) - } - } - return nil -} - -func (dr *delRange) doTask(sctx sessionctx.Context, r util.DelRangeTask) error { - var oldStartKey, newStartKey kv.Key - oldStartKey = r.StartKey - for { - finish := true - dr.keys = dr.keys[:0] - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) - err := kv.RunInNewTxn(ctx, dr.store, false, func(_ context.Context, txn kv.Transaction) error { - if topsqlstate.TopSQLEnabled() { - // Only when TiDB run without PD(use unistore as storage for test) will run into here, so just set a mock internal resource tagger. - txn.SetOption(kv.ResourceGroupTagger, util.GetInternalResourceGroupTaggerForTopSQL()) - } - iter, err := txn.Iter(oldStartKey, r.EndKey) - if err != nil { - return errors.Trace(err) - } - defer iter.Close() - - txn.SetDiskFullOpt(kvrpcpb.DiskFullOpt_AllowedOnAlmostFull) - for i := 0; i < delBatchSize; i++ { - if !iter.Valid() { - break - } - finish = false - dr.keys = append(dr.keys, iter.Key().Clone()) - newStartKey = iter.Key().Next() - - if err := iter.Next(); err != nil { - return errors.Trace(err) - } - } - - for _, key := range dr.keys { - err := txn.Delete(key) - if err != nil && !kv.ErrNotExist.Equal(err) { - return errors.Trace(err) - } - } - return nil - }) - if err != nil { - return errors.Trace(err) - } - if finish { - if err := util.CompleteDeleteRange(sctx, r, true); err != nil { - logutil.DDLLogger().Error("delRange emulator complete task failed", zap.Error(err)) - return errors.Trace(err) - } - startKey, endKey := r.Range() - logutil.DDLLogger().Info("delRange emulator complete task", zap.String("category", "ddl"), - zap.Int64("jobID", r.JobID), - zap.Int64("elementID", r.ElementID), - zap.Stringer("startKey", startKey), - zap.Stringer("endKey", endKey)) - break - } - if err := util.UpdateDeleteRange(sctx, r, newStartKey, oldStartKey); err != nil { - logutil.DDLLogger().Error("delRange emulator update task failed", zap.Error(err)) - } - oldStartKey = newStartKey - } - return nil -} - -// insertJobIntoDeleteRangeTable parses the job into delete-range arguments, -// and inserts a new record into gc_delete_range table. The primary key is -// (job ID, element ID), so we ignore key conflict error. 
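Because the primary key of mysql.gc_delete_range is (job ID, element ID), the batched statement uses INSERT IGNORE and appends one value tuple per element, as doBatchDeleteTablesRange/doBatchDeleteIndiceRange do below. A small self-contained sketch of that SQL assembly, reusing the two constants from the stashed file (only buildDeleteRangeSQL is invented here):

    package main

    import (
        "fmt"
        "strings"
    )

    // Constants copied from the stashed file.
    const (
        insertDeleteRangeSQLPrefix = `INSERT IGNORE INTO mysql.gc_delete_range VALUES `
        insertDeleteRangeSQLValue  = `(%?, %?, %?, %?, %?)`
    )

    // buildDeleteRangeSQL assembles the batched INSERT IGNORE statement:
    // one comma-separated value tuple per element, with the parameters
    // (job_id, element_id, start_key, end_key, ts) bound separately.
    func buildDeleteRangeSQL(n int) string {
        var buf strings.Builder
        buf.WriteString(insertDeleteRangeSQLPrefix)
        for i := 0; i < n; i++ {
            buf.WriteString(insertDeleteRangeSQLValue)
            if i != n-1 {
                buf.WriteString(",")
            }
        }
        return buf.String()
    }

    func main() {
        // INSERT IGNORE INTO mysql.gc_delete_range VALUES (%?, %?, %?, %?, %?),(%?, %?, %?, %?, %?)
        fmt.Println(buildDeleteRangeSQL(2))
    }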
-func insertJobIntoDeleteRangeTable(ctx context.Context, wrapper DelRangeExecWrapper, job *model.Job, ea *elementIDAlloc) error { - if err := wrapper.UpdateTSOForJob(); err != nil { - return errors.Trace(err) - } - - ctx = kv.WithInternalSourceType(ctx, getDDLRequestSource(job.Type)) - switch job.Type { - case model.ActionDropSchema: - var tableIDs []int64 - if err := job.DecodeArgs(&tableIDs); err != nil { - return errors.Trace(err) - } - for i := 0; i < len(tableIDs); i += batchInsertDeleteRangeSize { - batchEnd := len(tableIDs) - if batchEnd > i+batchInsertDeleteRangeSize { - batchEnd = i + batchInsertDeleteRangeSize - } - if err := doBatchDeleteTablesRange(ctx, wrapper, job.ID, tableIDs[i:batchEnd], ea, "drop schema: table IDs"); err != nil { - return errors.Trace(err) - } - } - case model.ActionDropTable, model.ActionTruncateTable: - tableID := job.TableID - // The startKey here is for compatibility with previous versions, old version did not endKey so don't have to deal with. - var startKey kv.Key - var physicalTableIDs []int64 - var ruleIDs []string - if err := job.DecodeArgs(&startKey, &physicalTableIDs, &ruleIDs); err != nil { - return errors.Trace(err) - } - if len(physicalTableIDs) > 0 { - if err := doBatchDeleteTablesRange(ctx, wrapper, job.ID, physicalTableIDs, ea, "drop table: partition table IDs"); err != nil { - return errors.Trace(err) - } - // logical table may contain global index regions, so delete the logical table range. - return errors.Trace(doBatchDeleteTablesRange(ctx, wrapper, job.ID, []int64{tableID}, ea, "drop table: table ID")) - } - return errors.Trace(doBatchDeleteTablesRange(ctx, wrapper, job.ID, []int64{tableID}, ea, "drop table: table ID")) - case model.ActionDropTablePartition, model.ActionTruncateTablePartition, - model.ActionReorganizePartition, model.ActionRemovePartitioning, - model.ActionAlterTablePartitioning: - var physicalTableIDs []int64 - if err := job.DecodeArgs(&physicalTableIDs); err != nil { - return errors.Trace(err) - } - return errors.Trace(doBatchDeleteTablesRange(ctx, wrapper, job.ID, physicalTableIDs, ea, "reorganize/drop partition: physical table ID(s)")) - // ActionAddIndex, ActionAddPrimaryKey needs do it, because it needs to be rolled back when it's canceled. - case model.ActionAddIndex, model.ActionAddPrimaryKey: - allIndexIDs := make([]int64, 1) - ifExists := make([]bool, 1) - isGlobal := make([]bool, 0, 1) - var partitionIDs []int64 - if err := job.DecodeArgs(&allIndexIDs[0], &ifExists[0], &partitionIDs); err != nil { - if err = job.DecodeArgs(&allIndexIDs, &ifExists, &partitionIDs, &isGlobal); err != nil { - return errors.Trace(err) - } - } - // Determine the physicalIDs to be added. - physicalIDs := []int64{job.TableID} - if len(partitionIDs) > 0 { - physicalIDs = partitionIDs - } - for i, indexID := range allIndexIDs { - // Determine the index IDs to be added. 
- tempIdxID := tablecodec.TempIndexPrefix | indexID - var indexIDs []int64 - if job.State == model.JobStateRollbackDone { - indexIDs = []int64{indexID, tempIdxID} - } else { - indexIDs = []int64{tempIdxID} - } - if len(isGlobal) != 0 && isGlobal[i] { - if err := doBatchDeleteIndiceRange(ctx, wrapper, job.ID, job.TableID, indexIDs, ea, "add index: physical table ID(s)"); err != nil { - return errors.Trace(err) - } - } else { - for _, pid := range physicalIDs { - if err := doBatchDeleteIndiceRange(ctx, wrapper, job.ID, pid, indexIDs, ea, "add index: physical table ID(s)"); err != nil { - return errors.Trace(err) - } - } - } - } - case model.ActionDropIndex, model.ActionDropPrimaryKey: - tableID := job.TableID - var indexName any - var partitionIDs []int64 - ifExists := make([]bool, 1) - allIndexIDs := make([]int64, 1) - if err := job.DecodeArgs(&indexName, &ifExists[0], &allIndexIDs[0], &partitionIDs); err != nil { - if err = job.DecodeArgs(&indexName, &ifExists, &allIndexIDs, &partitionIDs); err != nil { - return errors.Trace(err) - } - } - // partitionIDs len is 0 if the dropped index is a global index, even if it is a partitioned table. - if len(partitionIDs) == 0 { - return errors.Trace(doBatchDeleteIndiceRange(ctx, wrapper, job.ID, tableID, allIndexIDs, ea, "drop index: table ID")) - } - failpoint.Inject("checkDropGlobalIndex", func(val failpoint.Value) { - if val.(bool) { - panic("drop global index must not delete partition index range") - } - }) - for _, pid := range partitionIDs { - if err := doBatchDeleteIndiceRange(ctx, wrapper, job.ID, pid, allIndexIDs, ea, "drop index: partition table ID"); err != nil { - return errors.Trace(err) - } - } - case model.ActionDropColumn: - var colName model.CIStr - var ifExists bool - var indexIDs []int64 - var partitionIDs []int64 - if err := job.DecodeArgs(&colName, &ifExists, &indexIDs, &partitionIDs); err != nil { - return errors.Trace(err) - } - if len(indexIDs) > 0 { - if len(partitionIDs) == 0 { - return errors.Trace(doBatchDeleteIndiceRange(ctx, wrapper, job.ID, job.TableID, indexIDs, ea, "drop column: table ID")) - } - for _, pid := range partitionIDs { - if err := doBatchDeleteIndiceRange(ctx, wrapper, job.ID, pid, indexIDs, ea, "drop column: partition table ID"); err != nil { - return errors.Trace(err) - } - } - } - case model.ActionModifyColumn: - var indexIDs []int64 - var partitionIDs []int64 - if err := job.DecodeArgs(&indexIDs, &partitionIDs); err != nil { - return errors.Trace(err) - } - if len(indexIDs) == 0 { - return nil - } - if len(partitionIDs) == 0 { - return doBatchDeleteIndiceRange(ctx, wrapper, job.ID, job.TableID, indexIDs, ea, "modify column: table ID") - } - for _, pid := range partitionIDs { - if err := doBatchDeleteIndiceRange(ctx, wrapper, job.ID, pid, indexIDs, ea, "modify column: partition table ID"); err != nil { - return errors.Trace(err) - } - } - } - return nil -} - -func doBatchDeleteIndiceRange(ctx context.Context, wrapper DelRangeExecWrapper, jobID, tableID int64, indexIDs []int64, ea *elementIDAlloc, comment string) error { - logutil.DDLLogger().Info("insert into delete-range indices", zap.Int64("jobID", jobID), zap.Int64("tableID", tableID), zap.Int64s("indexIDs", indexIDs), zap.String("comment", comment)) - var buf strings.Builder - buf.WriteString(insertDeleteRangeSQLPrefix) - wrapper.PrepareParamsList(len(indexIDs) * 5) - tableID, ok := wrapper.RewriteTableID(tableID) - if !ok { - return nil - } - for i, indexID := range indexIDs { - startKey := tablecodec.EncodeTableIndexPrefix(tableID, indexID) - 
endKey := tablecodec.EncodeTableIndexPrefix(tableID, indexID+1) - startKeyEncoded := hex.EncodeToString(startKey) - endKeyEncoded := hex.EncodeToString(endKey) - buf.WriteString(insertDeleteRangeSQLValue) - if i != len(indexIDs)-1 { - buf.WriteString(",") - } - elemID := ea.allocForIndexID(tableID, indexID) - wrapper.AppendParamsList(jobID, elemID, startKeyEncoded, endKeyEncoded) - } - - return errors.Trace(wrapper.ConsumeDeleteRange(ctx, buf.String())) -} - -func doBatchDeleteTablesRange(ctx context.Context, wrapper DelRangeExecWrapper, jobID int64, tableIDs []int64, ea *elementIDAlloc, comment string) error { - logutil.DDLLogger().Info("insert into delete-range table", zap.Int64("jobID", jobID), zap.Int64s("tableIDs", tableIDs), zap.String("comment", comment)) - var buf strings.Builder - buf.WriteString(insertDeleteRangeSQLPrefix) - wrapper.PrepareParamsList(len(tableIDs) * 5) - for i, tableID := range tableIDs { - tableID, ok := wrapper.RewriteTableID(tableID) - if !ok { - continue - } - startKey := tablecodec.EncodeTablePrefix(tableID) - endKey := tablecodec.EncodeTablePrefix(tableID + 1) - startKeyEncoded := hex.EncodeToString(startKey) - endKeyEncoded := hex.EncodeToString(endKey) - buf.WriteString(insertDeleteRangeSQLValue) - if i != len(tableIDs)-1 { - buf.WriteString(",") - } - elemID := ea.allocForPhysicalID(tableID) - wrapper.AppendParamsList(jobID, elemID, startKeyEncoded, endKeyEncoded) - } - - return errors.Trace(wrapper.ConsumeDeleteRange(ctx, buf.String())) -} - -// DelRangeExecWrapper consumes the delete ranges with the provided table ID(s) and index ID(s). -type DelRangeExecWrapper interface { - // generate a new tso for the next job - UpdateTSOForJob() error - - // initialize the paramsList - PrepareParamsList(sz int) - - // rewrite table id if necessary, used for BR - RewriteTableID(tableID int64) (int64, bool) - - // (job_id, element_id, start_key, end_key, ts) - // ts is generated by delRangeExecWrapper itself - AppendParamsList(jobID, elemID int64, startKey, endKey string) - - // consume the delete range. For TiDB Server, it insert rows into mysql.gc_delete_range. - ConsumeDeleteRange(ctx context.Context, sql string) error -} - -// sessionDelRangeExecWrapper is a lightweight wrapper that implements the DelRangeExecWrapper interface and used for TiDB Server. -// It consumes the delete ranges by directly insert rows into mysql.gc_delete_range. 
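DelRangeExecWrapper, defined in the stashed file above, is the seam that lets TiDB insert delete ranges into mysql.gc_delete_range while other consumers (such as BR) handle the same ranges differently. As a rough illustration of how a consumer plugs in, here is a toy implementation that only prints what it is asked to persist; the interface is copied from the stashed code, while printWrapper and the placeholder key strings are invented for this sketch:

    package main

    import (
        "context"
        "fmt"
    )

    // DelRangeExecWrapper is copied from the stashed file above.
    type DelRangeExecWrapper interface {
        UpdateTSOForJob() error
        PrepareParamsList(sz int)
        RewriteTableID(tableID int64) (int64, bool)
        AppendParamsList(jobID, elemID int64, startKey, endKey string)
        ConsumeDeleteRange(ctx context.Context, sql string) error
    }

    // printWrapper prints the delete ranges instead of persisting them.
    type printWrapper struct {
        ts     uint64
        params []any
    }

    func (w *printWrapper) UpdateTSOForJob() error   { w.ts++; return nil }
    func (w *printWrapper) PrepareParamsList(sz int) { w.params = make([]any, 0, sz) }
    func (*printWrapper) RewriteTableID(tableID int64) (int64, bool) { return tableID, true }
    func (w *printWrapper) AppendParamsList(jobID, elemID int64, startKey, endKey string) {
        w.params = append(w.params, jobID, elemID, startKey, endKey, w.ts)
    }
    func (w *printWrapper) ConsumeDeleteRange(_ context.Context, sql string) error {
        fmt.Println(sql, w.params)
        w.params = nil
        return nil
    }

    var _ DelRangeExecWrapper = (*printWrapper)(nil)

    func main() {
        w := &printWrapper{}
        _ = w.UpdateTSOForJob()
        w.PrepareParamsList(5)
        w.AppendParamsList(1, 1, "startkeyhex", "endkeyhex")
        _ = w.ConsumeDeleteRange(context.Background(), "INSERT IGNORE INTO mysql.gc_delete_range VALUES (%?, %?, %?, %?, %?)")
    }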
-type sessionDelRangeExecWrapper struct { - sctx sessionctx.Context - ts uint64 - - // temporary values - paramsList []any -} - -func newDelRangeExecWrapper(sctx sessionctx.Context) DelRangeExecWrapper { - return &sessionDelRangeExecWrapper{ - sctx: sctx, - paramsList: nil, - } -} - -func (sdr *sessionDelRangeExecWrapper) UpdateTSOForJob() error { - now, err := getNowTSO(sdr.sctx) - if err != nil { - return errors.Trace(err) - } - sdr.ts = now - return nil -} - -func (sdr *sessionDelRangeExecWrapper) PrepareParamsList(sz int) { - sdr.paramsList = make([]any, 0, sz) -} - -func (*sessionDelRangeExecWrapper) RewriteTableID(tableID int64) (int64, bool) { - return tableID, true -} - -func (sdr *sessionDelRangeExecWrapper) AppendParamsList(jobID, elemID int64, startKey, endKey string) { - sdr.paramsList = append(sdr.paramsList, jobID, elemID, startKey, endKey, sdr.ts) -} - -func (sdr *sessionDelRangeExecWrapper) ConsumeDeleteRange(ctx context.Context, sql string) error { - // set session disk full opt - sdr.sctx.GetSessionVars().SetDiskFullOpt(kvrpcpb.DiskFullOpt_AllowedOnAlmostFull) - _, err := sdr.sctx.GetSQLExecutor().ExecuteInternal(ctx, sql, sdr.paramsList...) - // clear session disk full opt - sdr.sctx.GetSessionVars().ClearDiskFullOpt() - sdr.paramsList = nil - return errors.Trace(err) -} - -// getNowTS gets the current timestamp, in TSO. -func getNowTSO(ctx sessionctx.Context) (uint64, error) { - currVer, err := ctx.GetStore().CurrentVersion(kv.GlobalTxnScope) - if err != nil { - return 0, errors.Trace(err) - } - return currVer.Ver, nil -} diff --git a/pkg/ddl/executor.go b/pkg/ddl/executor.go index 834d1e7048785..225ba84582730 100644 --- a/pkg/ddl/executor.go +++ b/pkg/ddl/executor.go @@ -435,19 +435,19 @@ func isSessionDone(sctx sessionctx.Context) (bool, uint32) { if killed { return true, 1 } - if val, _err_ := failpoint.Eval(_curpkg_("BatchAddTiFlashSendDone")); _err_ == nil { + failpoint.Inject("BatchAddTiFlashSendDone", func(val failpoint.Value) { done = val.(bool) - } + }) return done, 0 } func (e *executor) waitPendingTableThreshold(sctx sessionctx.Context, schemaID int64, tableID int64, originVersion int64, pendingCount uint32, threshold uint32) (bool, int64, uint32, bool) { configRetry := tiflashCheckPendingTablesRetry configWaitTime := tiflashCheckPendingTablesWaitTime - if value, _err_ := failpoint.Eval(_curpkg_("FastFailCheckTiFlashPendingTables")); _err_ == nil { + failpoint.Inject("FastFailCheckTiFlashPendingTables", func(value failpoint.Value) { configRetry = value.(int) configWaitTime = time.Millisecond * 200 - } + }) for retry := 0; retry < configRetry; retry++ { done, killed := isSessionDone(sctx) @@ -1190,12 +1190,12 @@ func (e *executor) BatchCreateTableWithInfo(ctx sessionctx.Context, infos []*model.TableInfo, cs ...CreateTableOption, ) error { - if val, _err_ := failpoint.Eval(_curpkg_("RestoreBatchCreateTableEntryTooLarge")); _err_ == nil { + failpoint.Inject("RestoreBatchCreateTableEntryTooLarge", func(val failpoint.Value) { injectBatchSize := val.(int) if len(infos) > injectBatchSize { - return kv.ErrEntryTooLarge + failpoint.Return(kv.ErrEntryTooLarge) } - } + }) c := GetCreateTableConfig(cs) jobW := NewJobWrapper( @@ -2181,7 +2181,7 @@ func (e *executor) AddColumn(ctx sessionctx.Context, ti ast.Ident, spec *ast.Alt if err != nil { return errors.Trace(err) } - failpoint.Call(_curpkg_("afterGetSchemaAndTableByIdent"), ctx) + failpoint.InjectCall("afterGetSchemaAndTableByIdent", ctx) tbInfo := t.Meta() if err = checkAddColumnTooManyColumns(len(t.Cols()) + 1); err 
!= nil { return errors.Trace(err) @@ -2577,7 +2577,7 @@ func (e *executor) ReorganizePartitions(ctx sessionctx.Context, ident ast.Ident, // No preSplitAndScatter here, it will be done by the worker in onReorganizePartition instead. err = e.DoDDLJob(ctx, job) - failpoint.Call(_curpkg_("afterReorganizePartition")) + failpoint.InjectCall("afterReorganizePartition") if err == nil { ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError("The statistics of related partitions will be outdated after reorganizing partitions. Please use 'ANALYZE TABLE' statement if you want to update it now")) } @@ -3172,7 +3172,7 @@ func (e *executor) DropColumn(ctx sessionctx.Context, ti ast.Ident, spec *ast.Al if err != nil { return errors.Trace(err) } - failpoint.Call(_curpkg_("afterGetSchemaAndTableByIdent"), ctx) + failpoint.InjectCall("afterGetSchemaAndTableByIdent", ctx) isDropable, err := checkIsDroppableColumn(ctx, e.infoCache.GetLatest(), schema, t, spec) if err != nil { @@ -4766,9 +4766,9 @@ func newReorgMetaFromVariables(job *model.Job, sctx sessionctx.Context) (*model. } reorgMeta.IsDistReorg = false reorgMeta.IsFastReorg = false - if _, _err_ := failpoint.Eval(_curpkg_("reorgMetaRecordFastReorgDisabled")); _err_ == nil { + failpoint.Inject("reorgMetaRecordFastReorgDisabled", func(_ failpoint.Value) { LastReorgMetaFastReorgDisabled = true - } + }) } return reorgMeta, nil } @@ -6296,7 +6296,7 @@ func (e *executor) DoDDLJobWrapper(ctx sessionctx.Context, jobW *JobWrapper) err setDDLJobQuery(ctx, job) e.deliverJobTask(jobW) - if val, _err_ := failpoint.Eval(_curpkg_("mockParallelSameDDLJobTwice")); _err_ == nil { + failpoint.Inject("mockParallelSameDDLJobTwice", func(val failpoint.Value) { if val.(bool) { <-jobW.ResultCh[0] // The same job will be put to the DDL queue twice. @@ -6306,7 +6306,7 @@ func (e *executor) DoDDLJobWrapper(ctx sessionctx.Context, jobW *JobWrapper) err // The second job result is used for test. jobW = newJobW } - } + }) // worker should restart to continue handling tasks in limitJobCh, and send back through jobW.err result := <-jobW.ResultCh[0] @@ -6317,7 +6317,7 @@ func (e *executor) DoDDLJobWrapper(ctx sessionctx.Context, jobW *JobWrapper) err // The transaction of enqueuing job is failed. return errors.Trace(err) } - failpoint.Call(_curpkg_("waitJobSubmitted")) + failpoint.InjectCall("waitJobSubmitted") sessVars := ctx.GetSessionVars() sessVars.StmtCtx.IsDDLJobInQueue = true @@ -6353,7 +6353,7 @@ func (e *executor) DoDDLJobWrapper(ctx sessionctx.Context, jobW *JobWrapper) err i := 0 notifyCh, _ := e.getJobDoneCh(jobID) for { - failpoint.Call(_curpkg_("storeCloseInLoop")) + failpoint.InjectCall("storeCloseInLoop") select { case _, ok := <-notifyCh: if !ok { diff --git a/pkg/ddl/executor.go__failpoint_stash__ b/pkg/ddl/executor.go__failpoint_stash__ deleted file mode 100644 index 225ba84582730..0000000000000 --- a/pkg/ddl/executor.go__failpoint_stash__ +++ /dev/null @@ -1,6540 +0,0 @@ -// Copyright 2016 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2013 The ql Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSES/QL-LICENSE file. - -package ddl - -import ( - "bytes" - "context" - "fmt" - "math" - "strings" - "sync" - "sync/atomic" - "time" - "unicode/utf8" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/ddl/label" - "github.com/pingcap/tidb/pkg/ddl/logutil" - "github.com/pingcap/tidb/pkg/ddl/resourcegroup" - sess "github.com/pingcap/tidb/pkg/ddl/session" - ddlutil "github.com/pingcap/tidb/pkg/ddl/util" - rg "github.com/pingcap/tidb/pkg/domain/resourcegroup" - "github.com/pingcap/tidb/pkg/errctx" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/meta/autoid" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/owner" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/charset" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/privilege" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/sessiontxn" - "github.com/pingcap/tidb/pkg/statistics/handle" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/table/tables" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/collate" - "github.com/pingcap/tidb/pkg/util/dbterror" - "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" - "github.com/pingcap/tidb/pkg/util/domainutil" - "github.com/pingcap/tidb/pkg/util/generic" - "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tidb/pkg/util/mathutil" - "github.com/pingcap/tidb/pkg/util/stringutil" - "github.com/tikv/client-go/v2/oracle" - clientv3 "go.etcd.io/etcd/client/v3" - "go.uber.org/zap" -) - -const ( - expressionIndexPrefix = "_V$" - changingColumnPrefix = "_Col$_" - changingIndexPrefix = "_Idx$_" - tableNotExist = -1 - tinyBlobMaxLength = 255 - blobMaxLength = 65535 - mediumBlobMaxLength = 16777215 - longBlobMaxLength = 4294967295 - // When setting the placement policy with "PLACEMENT POLICY `default`", - // it means to remove placement policy from the specified object. - defaultPlacementPolicyName = "default" - tiflashCheckPendingTablesWaitTime = 3000 * time.Millisecond - // Once tiflashCheckPendingTablesLimit is reached, we trigger a limiter detection. - tiflashCheckPendingTablesLimit = 100 - tiflashCheckPendingTablesRetry = 7 -) - -var errCheckConstraintIsOff = errors.NewNoStackError(variable.TiDBEnableCheckConstraint + " is off") - -// Executor is the interface for executing DDL statements. -// it's mostly called by SQL executor. 
-type Executor interface { - CreateSchema(ctx sessionctx.Context, stmt *ast.CreateDatabaseStmt) error - AlterSchema(sctx sessionctx.Context, stmt *ast.AlterDatabaseStmt) error - DropSchema(ctx sessionctx.Context, stmt *ast.DropDatabaseStmt) error - CreateTable(ctx sessionctx.Context, stmt *ast.CreateTableStmt) error - CreateView(ctx sessionctx.Context, stmt *ast.CreateViewStmt) error - DropTable(ctx sessionctx.Context, stmt *ast.DropTableStmt) (err error) - RecoverTable(ctx sessionctx.Context, recoverInfo *RecoverInfo) (err error) - RecoverSchema(ctx sessionctx.Context, recoverSchemaInfo *RecoverSchemaInfo) error - DropView(ctx sessionctx.Context, stmt *ast.DropTableStmt) (err error) - CreateIndex(ctx sessionctx.Context, stmt *ast.CreateIndexStmt) error - DropIndex(ctx sessionctx.Context, stmt *ast.DropIndexStmt) error - AlterTable(ctx context.Context, sctx sessionctx.Context, stmt *ast.AlterTableStmt) error - TruncateTable(ctx sessionctx.Context, tableIdent ast.Ident) error - RenameTable(ctx sessionctx.Context, stmt *ast.RenameTableStmt) error - LockTables(ctx sessionctx.Context, stmt *ast.LockTablesStmt) error - UnlockTables(ctx sessionctx.Context, lockedTables []model.TableLockTpInfo) error - CleanupTableLock(ctx sessionctx.Context, tables []*ast.TableName) error - UpdateTableReplicaInfo(ctx sessionctx.Context, physicalID int64, available bool) error - RepairTable(ctx sessionctx.Context, createStmt *ast.CreateTableStmt) error - CreateSequence(ctx sessionctx.Context, stmt *ast.CreateSequenceStmt) error - DropSequence(ctx sessionctx.Context, stmt *ast.DropSequenceStmt) (err error) - AlterSequence(ctx sessionctx.Context, stmt *ast.AlterSequenceStmt) error - CreatePlacementPolicy(ctx sessionctx.Context, stmt *ast.CreatePlacementPolicyStmt) error - DropPlacementPolicy(ctx sessionctx.Context, stmt *ast.DropPlacementPolicyStmt) error - AlterPlacementPolicy(ctx sessionctx.Context, stmt *ast.AlterPlacementPolicyStmt) error - AddResourceGroup(ctx sessionctx.Context, stmt *ast.CreateResourceGroupStmt) error - AlterResourceGroup(ctx sessionctx.Context, stmt *ast.AlterResourceGroupStmt) error - DropResourceGroup(ctx sessionctx.Context, stmt *ast.DropResourceGroupStmt) error - FlashbackCluster(ctx sessionctx.Context, flashbackTS uint64) error - - // CreateSchemaWithInfo creates a database (schema) given its database info. - // - // WARNING: the DDL owns the `info` after calling this function, and will modify its fields - // in-place. If you want to keep using `info`, please call Clone() first. - CreateSchemaWithInfo( - ctx sessionctx.Context, - info *model.DBInfo, - onExist OnExist) error - - // CreateTableWithInfo creates a table, view or sequence given its table info. - // - // WARNING: the DDL owns the `info` after calling this function, and will modify its fields - // in-place. If you want to keep using `info`, please call Clone() first. - CreateTableWithInfo( - ctx sessionctx.Context, - schema model.CIStr, - info *model.TableInfo, - involvingRef []model.InvolvingSchemaInfo, - cs ...CreateTableOption) error - - // BatchCreateTableWithInfo is like CreateTableWithInfo, but can handle multiple tables. - BatchCreateTableWithInfo(ctx sessionctx.Context, - schema model.CIStr, - info []*model.TableInfo, - cs ...CreateTableOption) error - - // CreatePlacementPolicyWithInfo creates a placement policy - // - // WARNING: the DDL owns the `policy` after calling this function, and will modify its fields - // in-place. If you want to keep using `policy`, please call Clone() first. 
- CreatePlacementPolicyWithInfo(ctx sessionctx.Context, policy *model.PolicyInfo, onExist OnExist) error -} - -// ExecutorForTest is the interface for executing DDL statements in tests. -// TODO remove it later -type ExecutorForTest interface { - // DoDDLJob does the DDL job, it's exported for test. - DoDDLJob(ctx sessionctx.Context, job *model.Job) error - // DoDDLJobWrapper similar to DoDDLJob, but with JobWrapper as input. - DoDDLJobWrapper(ctx sessionctx.Context, jobW *JobWrapper) error -} - -// all fields are shared with ddl now. -type executor struct { - sessPool *sess.Pool - statsHandle *handle.Handle - - ctx context.Context - uuid string - store kv.Storage - etcdCli *clientv3.Client - autoidCli *autoid.ClientDiscover - infoCache *infoschema.InfoCache - limitJobCh chan *JobWrapper - schemaLoader SchemaLoader - lease time.Duration // lease is schema lease, default 45s, see config.Lease. - ownerManager owner.Manager - ddlJobDoneChMap *generic.SyncMap[int64, chan struct{}] - ddlJobNotifyCh chan struct{} - globalIDLock *sync.Mutex -} - -var _ Executor = (*executor)(nil) -var _ ExecutorForTest = (*executor)(nil) - -func (e *executor) CreateSchema(ctx sessionctx.Context, stmt *ast.CreateDatabaseStmt) (err error) { - var placementPolicyRef *model.PolicyRefInfo - sessionVars := ctx.GetSessionVars() - - // If no charset and/or collation is specified use collation_server and character_set_server - charsetOpt := ast.CharsetOpt{} - if sessionVars.GlobalVarsAccessor != nil { - charsetOpt.Col, err = sessionVars.GetSessionOrGlobalSystemVar(context.Background(), variable.CollationServer) - if err != nil { - return err - } - charsetOpt.Chs, err = sessionVars.GetSessionOrGlobalSystemVar(context.Background(), variable.CharacterSetServer) - if err != nil { - return err - } - } - - explicitCharset := false - explicitCollation := false - for _, val := range stmt.Options { - switch val.Tp { - case ast.DatabaseOptionCharset: - charsetOpt.Chs = val.Value - explicitCharset = true - case ast.DatabaseOptionCollate: - charsetOpt.Col = val.Value - explicitCollation = true - case ast.DatabaseOptionPlacementPolicy: - placementPolicyRef = &model.PolicyRefInfo{ - Name: model.NewCIStr(val.Value), - } - } - } - - if charsetOpt.Col != "" { - coll, err := collate.GetCollationByName(charsetOpt.Col) - if err != nil { - return err - } - - // The collation is not valid for the specified character set. - // Try to remove any of them, but not if they are explicitly defined. - if coll.CharsetName != charsetOpt.Chs { - if explicitCollation && !explicitCharset { - // Use the explicitly set collation, not the implicit charset. - charsetOpt.Chs = "" - } - if !explicitCollation && explicitCharset { - // Use the explicitly set charset, not the (session) collation. 
- charsetOpt.Col = "" - } - } - } - if !explicitCollation && explicitCharset { - coll, err := getDefaultCollationForUTF8MB4(ctx.GetSessionVars(), charsetOpt.Chs) - if err != nil { - return err - } - if len(coll) != 0 { - charsetOpt.Col = coll - } - } - dbInfo := &model.DBInfo{Name: stmt.Name} - chs, coll, err := ResolveCharsetCollation(ctx.GetSessionVars(), charsetOpt) - if err != nil { - return errors.Trace(err) - } - dbInfo.Charset = chs - dbInfo.Collate = coll - dbInfo.PlacementPolicyRef = placementPolicyRef - - onExist := OnExistError - if stmt.IfNotExists { - onExist = OnExistIgnore - } - return e.CreateSchemaWithInfo(ctx, dbInfo, onExist) -} - -func (e *executor) CreateSchemaWithInfo( - ctx sessionctx.Context, - dbInfo *model.DBInfo, - onExist OnExist, -) error { - is := e.infoCache.GetLatest() - _, ok := is.SchemaByName(dbInfo.Name) - if ok { - // since this error may be seen as error, keep it stack info. - err := infoschema.ErrDatabaseExists.GenWithStackByArgs(dbInfo.Name) - switch onExist { - case OnExistIgnore: - ctx.GetSessionVars().StmtCtx.AppendNote(err) - return nil - case OnExistError, OnExistReplace: - // FIXME: can we implement MariaDB's CREATE OR REPLACE SCHEMA? - return err - } - } - - if err := checkTooLongSchema(dbInfo.Name); err != nil { - return errors.Trace(err) - } - - if err := checkCharsetAndCollation(dbInfo.Charset, dbInfo.Collate); err != nil { - return errors.Trace(err) - } - - if err := handleDatabasePlacement(ctx, dbInfo); err != nil { - return errors.Trace(err) - } - - job := &model.Job{ - SchemaName: dbInfo.Name.L, - Type: model.ActionCreateSchema, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{dbInfo}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ - Database: dbInfo.Name.L, - Table: model.InvolvingAll, - }}, - SQLMode: ctx.GetSessionVars().SQLMode, - } - if ref := dbInfo.PlacementPolicyRef; ref != nil { - job.InvolvingSchemaInfo = append(job.InvolvingSchemaInfo, model.InvolvingSchemaInfo{ - Policy: ref.Name.L, - Mode: model.SharedInvolving, - }) - } - - err := e.DoDDLJob(ctx, job) - - if infoschema.ErrDatabaseExists.Equal(err) && onExist == OnExistIgnore { - ctx.GetSessionVars().StmtCtx.AppendNote(err) - return nil - } - - return errors.Trace(err) -} - -func (e *executor) ModifySchemaCharsetAndCollate(ctx sessionctx.Context, stmt *ast.AlterDatabaseStmt, toCharset, toCollate string) (err error) { - if toCollate == "" { - if toCollate, err = GetDefaultCollation(ctx.GetSessionVars(), toCharset); err != nil { - return errors.Trace(err) - } - } - - // Check if need to change charset/collation. - dbName := stmt.Name - is := e.infoCache.GetLatest() - dbInfo, ok := is.SchemaByName(dbName) - if !ok { - return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(dbName.O) - } - if dbInfo.Charset == toCharset && dbInfo.Collate == toCollate { - return nil - } - // Do the DDL job. 
- job := &model.Job{ - SchemaID: dbInfo.ID, - SchemaName: dbInfo.Name.L, - Type: model.ActionModifySchemaCharsetAndCollate, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{toCharset, toCollate}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ - Database: dbInfo.Name.L, - Table: model.InvolvingAll, - }}, - SQLMode: ctx.GetSessionVars().SQLMode, - } - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -func (e *executor) ModifySchemaDefaultPlacement(ctx sessionctx.Context, stmt *ast.AlterDatabaseStmt, placementPolicyRef *model.PolicyRefInfo) (err error) { - dbName := stmt.Name - is := e.infoCache.GetLatest() - dbInfo, ok := is.SchemaByName(dbName) - if !ok { - return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(dbName.O) - } - - if checkIgnorePlacementDDL(ctx) { - return nil - } - - placementPolicyRef, err = checkAndNormalizePlacementPolicy(ctx, placementPolicyRef) - if err != nil { - return err - } - - // Do the DDL job. - job := &model.Job{ - SchemaID: dbInfo.ID, - SchemaName: dbInfo.Name.L, - Type: model.ActionModifySchemaDefaultPlacement, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{placementPolicyRef}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ - Database: dbInfo.Name.L, - Table: model.InvolvingAll, - }}, - SQLMode: ctx.GetSessionVars().SQLMode, - } - if placementPolicyRef != nil { - job.InvolvingSchemaInfo = append(job.InvolvingSchemaInfo, model.InvolvingSchemaInfo{ - Policy: placementPolicyRef.Name.L, - Mode: model.SharedInvolving, - }) - } - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -// getPendingTiFlashTableCount counts unavailable TiFlash replica by iterating all tables in infoCache. -func (e *executor) getPendingTiFlashTableCount(originVersion int64, pendingCount uint32) (int64, uint32) { - is := e.infoCache.GetLatest() - // If there are no schema change since last time(can be weird). 
- if is.SchemaMetaVersion() == originVersion { - return originVersion, pendingCount - } - cnt := uint32(0) - dbs := is.ListTablesWithSpecialAttribute(infoschema.TiFlashAttribute) - for _, db := range dbs { - if util.IsMemOrSysDB(db.DBName) { - continue - } - for _, tbl := range db.TableInfos { - if tbl.TiFlashReplica != nil && !tbl.TiFlashReplica.Available { - cnt++ - } - } - } - return is.SchemaMetaVersion(), cnt -} - -func isSessionDone(sctx sessionctx.Context) (bool, uint32) { - done := false - killed := sctx.GetSessionVars().SQLKiller.HandleSignal() == exeerrors.ErrQueryInterrupted - if killed { - return true, 1 - } - failpoint.Inject("BatchAddTiFlashSendDone", func(val failpoint.Value) { - done = val.(bool) - }) - return done, 0 -} - -func (e *executor) waitPendingTableThreshold(sctx sessionctx.Context, schemaID int64, tableID int64, originVersion int64, pendingCount uint32, threshold uint32) (bool, int64, uint32, bool) { - configRetry := tiflashCheckPendingTablesRetry - configWaitTime := tiflashCheckPendingTablesWaitTime - failpoint.Inject("FastFailCheckTiFlashPendingTables", func(value failpoint.Value) { - configRetry = value.(int) - configWaitTime = time.Millisecond * 200 - }) - - for retry := 0; retry < configRetry; retry++ { - done, killed := isSessionDone(sctx) - if done { - logutil.DDLLogger().Info("abort batch add TiFlash replica", zap.Int64("schemaID", schemaID), zap.Uint32("isKilled", killed)) - return true, originVersion, pendingCount, false - } - originVersion, pendingCount = e.getPendingTiFlashTableCount(originVersion, pendingCount) - delay := time.Duration(0) - if pendingCount < threshold { - // If there are not many unavailable tables, we don't need a force check. - return false, originVersion, pendingCount, false - } - logutil.DDLLogger().Info("too many unavailable tables, wait", - zap.Uint32("threshold", threshold), - zap.Uint32("currentPendingCount", pendingCount), - zap.Int64("schemaID", schemaID), - zap.Int64("tableID", tableID), - zap.Duration("time", configWaitTime)) - delay = configWaitTime - time.Sleep(delay) - } - logutil.DDLLogger().Info("too many unavailable tables, timeout", zap.Int64("schemaID", schemaID), zap.Int64("tableID", tableID)) - // If timeout here, we will trigger a ddl job, to force sync schema. However, it doesn't mean we remove limiter, - // so there is a force check immediately after that. 
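waitPendingTableThreshold throttles batch ADD TIFLASH REPLICA: it re-reads the unavailable-replica count up to tiflashCheckPendingTablesRetry times, sleeping tiflashCheckPendingTablesWaitTime between reads, and on timeout lets the caller proceed but asks for an immediate re-check. A minimal sketch of that throttle, with invented names and a tiny sleep so it runs quickly:

    package main

    import (
        "fmt"
        "time"
    )

    // waitBelowThreshold polls pending() up to maxRetry times, sleeping waitTime
    // between polls, and reports whether the caller should force another check
    // on its next iteration (illustrative, not the production logic verbatim).
    func waitBelowThreshold(pending func() uint32, threshold uint32, maxRetry int, waitTime time.Duration) (forceCheck bool) {
        for retry := 0; retry < maxRetry; retry++ {
            if pending() < threshold {
                return false // few enough pending tables, keep going at full speed
            }
            time.Sleep(waitTime)
        }
        // Timed out while still over the threshold: continue anyway,
        // but re-check immediately after the next submitted job.
        return true
    }

    func main() {
        n := uint32(5)
        force := waitBelowThreshold(func() uint32 { n--; return n }, 3, 7, time.Millisecond)
        fmt.Println(force) // false: the pending count drops below the threshold
    }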
- return false, originVersion, pendingCount, true -} - -func (e *executor) ModifySchemaSetTiFlashReplica(sctx sessionctx.Context, stmt *ast.AlterDatabaseStmt, tiflashReplica *ast.TiFlashReplicaSpec) error { - dbName := stmt.Name - is := e.infoCache.GetLatest() - dbInfo, ok := is.SchemaByName(dbName) - if !ok { - return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(dbName.O) - } - - if util.IsMemOrSysDB(dbInfo.Name.L) { - return errors.Trace(dbterror.ErrUnsupportedTiFlashOperationForSysOrMemTable) - } - - tbls, err := is.SchemaTableInfos(context.Background(), dbInfo.Name) - if err != nil { - return errors.Trace(err) - } - - total := len(tbls) - succ := 0 - skip := 0 - fail := 0 - oneFail := int64(0) - - if total == 0 { - return infoschema.ErrEmptyDatabase.GenWithStack("Empty database '%v'", dbName.O) - } - err = checkTiFlashReplicaCount(sctx, tiflashReplica.Count) - if err != nil { - return errors.Trace(err) - } - - var originVersion int64 - var pendingCount uint32 - forceCheck := false - - logutil.DDLLogger().Info("start batch add TiFlash replicas", zap.Int("total", total), zap.Int64("schemaID", dbInfo.ID)) - threshold := uint32(sctx.GetSessionVars().BatchPendingTiFlashCount) - - for _, tbl := range tbls { - done, killed := isSessionDone(sctx) - if done { - logutil.DDLLogger().Info("abort batch add TiFlash replica", zap.Int64("schemaID", dbInfo.ID), zap.Uint32("isKilled", killed)) - return nil - } - - tbReplicaInfo := tbl.TiFlashReplica - if !shouldModifyTiFlashReplica(tbReplicaInfo, tiflashReplica) { - logutil.DDLLogger().Info("skip repeated processing table", - zap.Int64("tableID", tbl.ID), - zap.Int64("schemaID", dbInfo.ID), - zap.String("tableName", tbl.Name.String()), - zap.String("schemaName", dbInfo.Name.String())) - skip++ - continue - } - - // If table is not supported, add err to warnings. - err = isTableTiFlashSupported(dbName, tbl) - if err != nil { - logutil.DDLLogger().Info("skip processing table", zap.Int64("tableID", tbl.ID), - zap.Int64("schemaID", dbInfo.ID), - zap.String("tableName", tbl.Name.String()), - zap.String("schemaName", dbInfo.Name.String()), - zap.Error(err)) - sctx.GetSessionVars().StmtCtx.AppendNote(err) - skip++ - continue - } - - // Alter `tiflashCheckPendingTablesLimit` tables are handled, we need to check if we have reached threshold. - if (succ+fail)%tiflashCheckPendingTablesLimit == 0 || forceCheck { - // We can execute one probing ddl to the latest schema, if we timeout in `pendingFunc`. - // However, we shall mark `forceCheck` to true, because we may still reach `threshold`. 
- finished := false - finished, originVersion, pendingCount, forceCheck = e.waitPendingTableThreshold(sctx, dbInfo.ID, tbl.ID, originVersion, pendingCount, threshold) - if finished { - logutil.DDLLogger().Info("abort batch add TiFlash replica", zap.Int64("schemaID", dbInfo.ID)) - return nil - } - } - - job := &model.Job{ - SchemaID: dbInfo.ID, - SchemaName: dbInfo.Name.L, - TableID: tbl.ID, - Type: model.ActionSetTiFlashReplica, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{*tiflashReplica}, - CDCWriteSource: sctx.GetSessionVars().CDCWriteSource, - InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ - Database: dbInfo.Name.L, - Table: model.InvolvingAll, - }}, - SQLMode: sctx.GetSessionVars().SQLMode, - } - err := e.DoDDLJob(sctx, job) - if err != nil { - oneFail = tbl.ID - fail++ - logutil.DDLLogger().Info("processing schema table error", - zap.Int64("tableID", tbl.ID), - zap.Int64("schemaID", dbInfo.ID), - zap.Stringer("tableName", tbl.Name), - zap.Stringer("schemaName", dbInfo.Name), - zap.Error(err)) - } else { - succ++ - } - } - failStmt := "" - if fail > 0 { - failStmt = fmt.Sprintf("(including table %v)", oneFail) - } - msg := fmt.Sprintf("In total %v tables: %v succeed, %v failed%v, %v skipped", total, succ, fail, failStmt, skip) - sctx.GetSessionVars().StmtCtx.SetMessage(msg) - logutil.DDLLogger().Info("finish batch add TiFlash replica", zap.Int64("schemaID", dbInfo.ID)) - return nil -} - -func (e *executor) AlterTablePlacement(ctx sessionctx.Context, ident ast.Ident, placementPolicyRef *model.PolicyRefInfo) (err error) { - is := e.infoCache.GetLatest() - schema, ok := is.SchemaByName(ident.Schema) - if !ok { - return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ident.Schema) - } - - tb, err := is.TableByName(e.ctx, ident.Schema, ident.Name) - if err != nil { - return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name)) - } - - if checkIgnorePlacementDDL(ctx) { - return nil - } - - tblInfo := tb.Meta() - if tblInfo.TempTableType != model.TempTableNone { - return errors.Trace(dbterror.ErrOptOnTemporaryTable.GenWithStackByArgs("placement")) - } - - placementPolicyRef, err = checkAndNormalizePlacementPolicy(ctx, placementPolicyRef) - if err != nil { - return err - } - - var involvingSchemaInfo []model.InvolvingSchemaInfo - if placementPolicyRef != nil { - involvingSchemaInfo = []model.InvolvingSchemaInfo{ - { - Database: schema.Name.L, - Table: tblInfo.Name.L, - }, - { - Policy: placementPolicyRef.Name.L, - Mode: model.SharedInvolving, - }, - } - } - - job := &model.Job{ - SchemaID: schema.ID, - TableID: tblInfo.ID, - SchemaName: schema.Name.L, - TableName: tblInfo.Name.L, - Type: model.ActionAlterTablePlacement, - BinlogInfo: &model.HistoryInfo{}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - Args: []any{placementPolicyRef}, - InvolvingSchemaInfo: involvingSchemaInfo, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -func checkMultiSchemaSpecs(_ sessionctx.Context, specs []*ast.DatabaseOption) error { - hasSetTiFlashReplica := false - if len(specs) == 1 { - return nil - } - for _, spec := range specs { - if spec.Tp == ast.DatabaseSetTiFlashReplica { - if hasSetTiFlashReplica { - return dbterror.ErrRunMultiSchemaChanges.FastGenByArgs(model.ActionSetTiFlashReplica.String()) - } - hasSetTiFlashReplica = true - } - } - return nil -} - -func (e *executor) AlterSchema(sctx sessionctx.Context, stmt *ast.AlterDatabaseStmt) (err error) { - // Resolve target charset and collation 
from options. - var ( - toCharset, toCollate string - isAlterCharsetAndCollate bool - placementPolicyRef *model.PolicyRefInfo - tiflashReplica *ast.TiFlashReplicaSpec - ) - - err = checkMultiSchemaSpecs(sctx, stmt.Options) - if err != nil { - return err - } - - for _, val := range stmt.Options { - switch val.Tp { - case ast.DatabaseOptionCharset: - if toCharset == "" { - toCharset = val.Value - } else if toCharset != val.Value { - return dbterror.ErrConflictingDeclarations.GenWithStackByArgs(toCharset, val.Value) - } - isAlterCharsetAndCollate = true - case ast.DatabaseOptionCollate: - info, errGetCollate := collate.GetCollationByName(val.Value) - if errGetCollate != nil { - return errors.Trace(errGetCollate) - } - if toCharset == "" { - toCharset = info.CharsetName - } else if toCharset != info.CharsetName { - return dbterror.ErrConflictingDeclarations.GenWithStackByArgs(toCharset, info.CharsetName) - } - toCollate = info.Name - isAlterCharsetAndCollate = true - case ast.DatabaseOptionPlacementPolicy: - placementPolicyRef = &model.PolicyRefInfo{Name: model.NewCIStr(val.Value)} - case ast.DatabaseSetTiFlashReplica: - tiflashReplica = val.TiFlashReplica - } - } - - if isAlterCharsetAndCollate { - if err = e.ModifySchemaCharsetAndCollate(sctx, stmt, toCharset, toCollate); err != nil { - return err - } - } - if placementPolicyRef != nil { - if err = e.ModifySchemaDefaultPlacement(sctx, stmt, placementPolicyRef); err != nil { - return err - } - } - if tiflashReplica != nil { - if err = e.ModifySchemaSetTiFlashReplica(sctx, stmt, tiflashReplica); err != nil { - return err - } - } - return nil -} - -func (e *executor) DropSchema(ctx sessionctx.Context, stmt *ast.DropDatabaseStmt) (err error) { - is := e.infoCache.GetLatest() - old, ok := is.SchemaByName(stmt.Name) - if !ok { - if stmt.IfExists { - return nil - } - return infoschema.ErrDatabaseDropExists.GenWithStackByArgs(stmt.Name) - } - fkCheck := ctx.GetSessionVars().ForeignKeyChecks - err = checkDatabaseHasForeignKeyReferred(e.ctx, is, old.Name, fkCheck) - if err != nil { - return err - } - job := &model.Job{ - SchemaID: old.ID, - SchemaName: old.Name.L, - SchemaState: old.State, - Type: model.ActionDropSchema, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{fkCheck}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ - Database: old.Name.L, - Table: model.InvolvingAll, - }}, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - if err != nil { - if infoschema.ErrDatabaseNotExists.Equal(err) { - if stmt.IfExists { - return nil - } - return infoschema.ErrDatabaseDropExists.GenWithStackByArgs(stmt.Name) - } - return errors.Trace(err) - } - if !config.TableLockEnabled() { - return nil - } - // Clear table locks hold by the session. 
- tbs, err := is.SchemaTableInfos(e.ctx, stmt.Name) - if err != nil { - return errors.Trace(err) - } - - lockTableIDs := make([]int64, 0) - for _, tb := range tbs { - if ok, _ := ctx.CheckTableLocked(tb.ID); ok { - lockTableIDs = append(lockTableIDs, tb.ID) - } - } - ctx.ReleaseTableLockByTableIDs(lockTableIDs) - return nil -} - -func (e *executor) RecoverSchema(ctx sessionctx.Context, recoverSchemaInfo *RecoverSchemaInfo) error { - involvedSchemas := []model.InvolvingSchemaInfo{{ - Database: recoverSchemaInfo.DBInfo.Name.L, - Table: model.InvolvingAll, - }} - if recoverSchemaInfo.OldSchemaName.L != recoverSchemaInfo.DBInfo.Name.L { - involvedSchemas = append(involvedSchemas, model.InvolvingSchemaInfo{ - Database: recoverSchemaInfo.OldSchemaName.L, - Table: model.InvolvingAll, - }) - } - recoverSchemaInfo.State = model.StateNone - job := &model.Job{ - Type: model.ActionRecoverSchema, - BinlogInfo: &model.HistoryInfo{}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - Args: []any{recoverSchemaInfo, recoverCheckFlagNone}, - InvolvingSchemaInfo: involvedSchemas, - SQLMode: ctx.GetSessionVars().SQLMode, - } - err := e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -func checkTooLongSchema(schema model.CIStr) error { - if utf8.RuneCountInString(schema.L) > mysql.MaxDatabaseNameLength { - return dbterror.ErrTooLongIdent.GenWithStackByArgs(schema) - } - return nil -} - -func checkTooLongTable(table model.CIStr) error { - if utf8.RuneCountInString(table.L) > mysql.MaxTableNameLength { - return dbterror.ErrTooLongIdent.GenWithStackByArgs(table) - } - return nil -} - -func checkTooLongIndex(index model.CIStr) error { - if utf8.RuneCountInString(index.L) > mysql.MaxIndexIdentifierLen { - return dbterror.ErrTooLongIdent.GenWithStackByArgs(index) - } - return nil -} - -func checkTooLongColumn(col model.CIStr) error { - if utf8.RuneCountInString(col.L) > mysql.MaxColumnNameLength { - return dbterror.ErrTooLongIdent.GenWithStackByArgs(col) - } - return nil -} - -func checkTooLongForeignKey(fk model.CIStr) error { - if utf8.RuneCountInString(fk.L) > mysql.MaxForeignKeyIdentifierLen { - return dbterror.ErrTooLongIdent.GenWithStackByArgs(fk) - } - return nil -} - -func getDefaultCollationForUTF8MB4(sessVars *variable.SessionVars, cs string) (string, error) { - if sessVars == nil || cs != charset.CharsetUTF8MB4 { - return "", nil - } - defaultCollation, err := sessVars.GetSessionOrGlobalSystemVar(context.Background(), variable.DefaultCollationForUTF8MB4) - if err != nil { - return "", err - } - return defaultCollation, nil -} - -// GetDefaultCollation returns the default collation for charset and handle the default collation for UTF8MB4. -func GetDefaultCollation(sessVars *variable.SessionVars, cs string) (string, error) { - coll, err := getDefaultCollationForUTF8MB4(sessVars, cs) - if err != nil { - return "", errors.Trace(err) - } - if coll != "" { - return coll, nil - } - - coll, err = charset.GetDefaultCollation(cs) - if err != nil { - return "", errors.Trace(err) - } - return coll, nil -} - -// ResolveCharsetCollation will resolve the charset and collate by the order of parameters: -// * If any given ast.CharsetOpt is not empty, the resolved charset and collate will be returned. -// * If all ast.CharsetOpts are empty, the default charset and collate will be returned. 
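The ResolveCharsetCollation contract documented above resolves in order: the first option with an explicit collation wins (its charset is derived from the collation), otherwise the first explicit charset picks its default collation, otherwise the server defaults apply. A simplified sketch of that precedence; the lookup maps are stand-ins for the collate/charset packages, and the DefaultCollationForUTF8MB4 session handling is deliberately omitted:

    package main

    import "fmt"

    // charsetOpt mirrors ast.CharsetOpt: an optional charset and/or collation.
    type charsetOpt struct{ chs, col string }

    func resolve(opts ...charsetOpt) (string, string) {
        collCharset := map[string]string{"utf8mb4_bin": "utf8mb4", "latin1_bin": "latin1"}
        defaultColl := map[string]string{"utf8mb4": "utf8mb4_bin", "latin1": "latin1_bin"}
        for _, o := range opts {
            if o.col != "" {
                return collCharset[o.col], o.col // explicit collation wins
            }
            if o.chs != "" {
                return o.chs, defaultColl[o.chs] // explicit charset, default collation
            }
        }
        return "utf8mb4", "utf8mb4_bin" // server defaults
    }

    func main() {
        fmt.Println(resolve(charsetOpt{}, charsetOpt{chs: "latin1"})) // latin1 latin1_bin
    }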
-func ResolveCharsetCollation(sessVars *variable.SessionVars, charsetOpts ...ast.CharsetOpt) (chs string, coll string, err error) { - for _, v := range charsetOpts { - if v.Col != "" { - collation, err := collate.GetCollationByName(v.Col) - if err != nil { - return "", "", errors.Trace(err) - } - if v.Chs != "" && collation.CharsetName != v.Chs { - return "", "", charset.ErrCollationCharsetMismatch.GenWithStackByArgs(v.Col, v.Chs) - } - return collation.CharsetName, v.Col, nil - } - if v.Chs != "" { - coll, err := GetDefaultCollation(sessVars, v.Chs) - if err != nil { - return "", "", errors.Trace(err) - } - return v.Chs, coll, nil - } - } - chs, coll = charset.GetDefaultCharsetAndCollate() - utf8mb4Coll, err := getDefaultCollationForUTF8MB4(sessVars, chs) - if err != nil { - return "", "", errors.Trace(err) - } - if utf8mb4Coll != "" { - return chs, utf8mb4Coll, nil - } - return chs, coll, nil -} - -// IsAutoRandomColumnID returns true if the given column ID belongs to an auto_random column. -func IsAutoRandomColumnID(tblInfo *model.TableInfo, colID int64) bool { - if !tblInfo.ContainsAutoRandomBits() { - return false - } - if tblInfo.PKIsHandle { - return tblInfo.GetPkColInfo().ID == colID - } else if tblInfo.IsCommonHandle { - pk := tables.FindPrimaryIndex(tblInfo) - if pk == nil { - return false - } - offset := pk.Columns[0].Offset - return tblInfo.Columns[offset].ID == colID - } - return false -} - -// checkInvisibleIndexOnPK check if primary key is invisible index. -// Note: PKIsHandle == true means the table already has a visible primary key, -// we do not need do a check for this case and return directly, -// because whether primary key is invisible has been check when creating table. -func checkInvisibleIndexOnPK(tblInfo *model.TableInfo) error { - if tblInfo.PKIsHandle { - return nil - } - pk := tblInfo.GetPrimaryKey() - if pk != nil && pk.Invisible { - return dbterror.ErrPKIndexCantBeInvisible - } - return nil -} - -func (e *executor) assignPartitionIDs(defs []model.PartitionDefinition) error { - genIDs, err := e.genGlobalIDs(len(defs)) - if err != nil { - return errors.Trace(err) - } - for i := range defs { - defs[i].ID = genIDs[i] - } - return nil -} - -func (e *executor) CreateTable(ctx sessionctx.Context, s *ast.CreateTableStmt) (err error) { - ident := ast.Ident{Schema: s.Table.Schema, Name: s.Table.Name} - is := e.infoCache.GetLatest() - schema, ok := is.SchemaByName(ident.Schema) - if !ok { - return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ident.Schema) - } - - var ( - referTbl table.Table - involvingRef []model.InvolvingSchemaInfo - ) - if s.ReferTable != nil { - referIdent := ast.Ident{Schema: s.ReferTable.Schema, Name: s.ReferTable.Name} - _, ok := is.SchemaByName(referIdent.Schema) - if !ok { - return infoschema.ErrTableNotExists.GenWithStackByArgs(referIdent.Schema, referIdent.Name) - } - referTbl, err = is.TableByName(e.ctx, referIdent.Schema, referIdent.Name) - if err != nil { - return infoschema.ErrTableNotExists.GenWithStackByArgs(referIdent.Schema, referIdent.Name) - } - involvingRef = append(involvingRef, model.InvolvingSchemaInfo{ - Database: s.ReferTable.Schema.L, - Table: s.ReferTable.Name.L, - Mode: model.SharedInvolving, - }) - } - - // build tableInfo - var tbInfo *model.TableInfo - if s.ReferTable != nil { - tbInfo, err = BuildTableInfoWithLike(ctx, ident, referTbl.Meta(), s) - } else { - tbInfo, err = BuildTableInfoWithStmt(ctx, s, schema.Charset, schema.Collate, schema.PlacementPolicyRef) - } - if err != nil { - return errors.Trace(err) - } - 
- if err = checkTableInfoValidWithStmt(ctx, tbInfo, s); err != nil { - return err - } - if err = checkTableForeignKeysValid(ctx, is, schema.Name.L, tbInfo); err != nil { - return err - } - - onExist := OnExistError - if s.IfNotExists { - onExist = OnExistIgnore - } - - return e.CreateTableWithInfo(ctx, schema.Name, tbInfo, involvingRef, WithOnExist(onExist)) -} - -// createTableWithInfoJob returns the table creation job. -// WARNING: it may return a nil job, which means you don't need to submit any DDL job. -func (e *executor) createTableWithInfoJob( - ctx sessionctx.Context, - dbName model.CIStr, - tbInfo *model.TableInfo, - involvingRef []model.InvolvingSchemaInfo, - onExist OnExist, -) (job *model.Job, err error) { - is := e.infoCache.GetLatest() - schema, ok := is.SchemaByName(dbName) - if !ok { - return nil, infoschema.ErrDatabaseNotExists.GenWithStackByArgs(dbName) - } - - if err = handleTablePlacement(ctx, tbInfo); err != nil { - return nil, errors.Trace(err) - } - - var oldViewTblID int64 - if oldTable, err := is.TableByName(e.ctx, schema.Name, tbInfo.Name); err == nil { - err = infoschema.ErrTableExists.GenWithStackByArgs(ast.Ident{Schema: schema.Name, Name: tbInfo.Name}) - switch onExist { - case OnExistIgnore: - ctx.GetSessionVars().StmtCtx.AppendNote(err) - return nil, nil - case OnExistReplace: - // only CREATE OR REPLACE VIEW is supported at the moment. - if tbInfo.View != nil { - if oldTable.Meta().IsView() { - oldViewTblID = oldTable.Meta().ID - break - } - // The object to replace isn't a view. - return nil, dbterror.ErrWrongObject.GenWithStackByArgs(dbName, tbInfo.Name, "VIEW") - } - return nil, err - default: - return nil, err - } - } - - if err := checkTableInfoValidExtra(tbInfo); err != nil { - return nil, err - } - - var actionType model.ActionType - args := []any{tbInfo} - switch { - case tbInfo.View != nil: - actionType = model.ActionCreateView - args = append(args, onExist == OnExistReplace, oldViewTblID) - case tbInfo.Sequence != nil: - actionType = model.ActionCreateSequence - default: - actionType = model.ActionCreateTable - args = append(args, ctx.GetSessionVars().ForeignKeyChecks) - } - - var involvingSchemas []model.InvolvingSchemaInfo - sharedInvolvingFromTableInfo := getSharedInvolvingSchemaInfo(tbInfo) - - if sum := len(involvingRef) + len(sharedInvolvingFromTableInfo); sum > 0 { - involvingSchemas = make([]model.InvolvingSchemaInfo, 0, sum+1) - involvingSchemas = append(involvingSchemas, model.InvolvingSchemaInfo{ - Database: schema.Name.L, - Table: tbInfo.Name.L, - }) - involvingSchemas = append(involvingSchemas, involvingRef...) - involvingSchemas = append(involvingSchemas, sharedInvolvingFromTableInfo...) 
- } - - job = &model.Job{ - SchemaID: schema.ID, - SchemaName: schema.Name.L, - TableName: tbInfo.Name.L, - Type: actionType, - BinlogInfo: &model.HistoryInfo{}, - Args: args, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - InvolvingSchemaInfo: involvingSchemas, - SQLMode: ctx.GetSessionVars().SQLMode, - } - return job, nil -} - -func getSharedInvolvingSchemaInfo(info *model.TableInfo) []model.InvolvingSchemaInfo { - ret := make([]model.InvolvingSchemaInfo, 0, len(info.ForeignKeys)+1) - for _, fk := range info.ForeignKeys { - ret = append(ret, model.InvolvingSchemaInfo{ - Database: fk.RefSchema.L, - Table: fk.RefTable.L, - Mode: model.SharedInvolving, - }) - } - if ref := info.PlacementPolicyRef; ref != nil { - ret = append(ret, model.InvolvingSchemaInfo{ - Policy: ref.Name.L, - Mode: model.SharedInvolving, - }) - } - return ret -} - -func (e *executor) createTableWithInfoPost( - ctx sessionctx.Context, - tbInfo *model.TableInfo, - schemaID int64, -) error { - var err error - var partitions []model.PartitionDefinition - if pi := tbInfo.GetPartitionInfo(); pi != nil { - partitions = pi.Definitions - } - preSplitAndScatter(ctx, e.store, tbInfo, partitions) - if tbInfo.AutoIncID > 1 { - // Default tableAutoIncID base is 0. - // If the first ID is expected to greater than 1, we need to do rebase. - newEnd := tbInfo.AutoIncID - 1 - var allocType autoid.AllocatorType - if tbInfo.SepAutoInc() { - allocType = autoid.AutoIncrementType - } else { - allocType = autoid.RowIDAllocType - } - if err = e.handleAutoIncID(tbInfo, schemaID, newEnd, allocType); err != nil { - return errors.Trace(err) - } - } - // For issue https://github.com/pingcap/tidb/issues/46093 - if tbInfo.AutoIncIDExtra != 0 { - if err = e.handleAutoIncID(tbInfo, schemaID, tbInfo.AutoIncIDExtra-1, autoid.RowIDAllocType); err != nil { - return errors.Trace(err) - } - } - if tbInfo.AutoRandID > 1 { - // Default tableAutoRandID base is 0. - // If the first ID is expected to greater than 1, we need to do rebase. - newEnd := tbInfo.AutoRandID - 1 - err = e.handleAutoIncID(tbInfo, schemaID, newEnd, autoid.AutoRandomType) - } - return err -} - -func (e *executor) CreateTableWithInfo( - ctx sessionctx.Context, - dbName model.CIStr, - tbInfo *model.TableInfo, - involvingRef []model.InvolvingSchemaInfo, - cs ...CreateTableOption, -) (err error) { - c := GetCreateTableConfig(cs) - - job, err := e.createTableWithInfoJob( - ctx, dbName, tbInfo, involvingRef, c.OnExist, - ) - if err != nil { - return err - } - if job == nil { - return nil - } - - jobW := NewJobWrapper(job, c.IDAllocated) - - err = e.DoDDLJobWrapper(ctx, jobW) - if err != nil { - // table exists, but if_not_exists flags is true, so we ignore this error. 
- if c.OnExist == OnExistIgnore && infoschema.ErrTableExists.Equal(err) { - ctx.GetSessionVars().StmtCtx.AppendNote(err) - err = nil - } - } else { - err = e.createTableWithInfoPost(ctx, tbInfo, job.SchemaID) - } - - return errors.Trace(err) -} - -func (e *executor) BatchCreateTableWithInfo(ctx sessionctx.Context, - dbName model.CIStr, - infos []*model.TableInfo, - cs ...CreateTableOption, -) error { - failpoint.Inject("RestoreBatchCreateTableEntryTooLarge", func(val failpoint.Value) { - injectBatchSize := val.(int) - if len(infos) > injectBatchSize { - failpoint.Return(kv.ErrEntryTooLarge) - } - }) - c := GetCreateTableConfig(cs) - - jobW := NewJobWrapper( - &model.Job{ - BinlogInfo: &model.HistoryInfo{}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - }, - c.IDAllocated, - ) - args := make([]*model.TableInfo, 0, len(infos)) - - var err error - - // check if there are any duplicated table names - duplication := make(map[string]struct{}) - // TODO filter those duplicated info out. - for _, info := range infos { - if _, ok := duplication[info.Name.L]; ok { - err = infoschema.ErrTableExists.FastGenByArgs("can not batch create tables with same name") - if c.OnExist == OnExistIgnore && infoschema.ErrTableExists.Equal(err) { - ctx.GetSessionVars().StmtCtx.AppendNote(err) - err = nil - } - } - if err != nil { - return errors.Trace(err) - } - - duplication[info.Name.L] = struct{}{} - } - - for _, info := range infos { - job, err := e.createTableWithInfoJob(ctx, dbName, info, nil, c.OnExist) - if err != nil { - return errors.Trace(err) - } - if job == nil { - continue - } - - // if jobW.Type == model.ActionCreateTables, it is initialized - // if not, initialize jobW by job.XXXX - if jobW.Type != model.ActionCreateTables { - jobW.Type = model.ActionCreateTables - jobW.SchemaID = job.SchemaID - jobW.SchemaName = job.SchemaName - } - - // append table job args - info, ok := job.Args[0].(*model.TableInfo) - if !ok { - return errors.Trace(fmt.Errorf("except table info")) - } - args = append(args, info) - jobW.InvolvingSchemaInfo = append(jobW.InvolvingSchemaInfo, model.InvolvingSchemaInfo{ - Database: dbName.L, - Table: info.Name.L, - }) - if sharedInv := getSharedInvolvingSchemaInfo(info); len(sharedInv) > 0 { - jobW.InvolvingSchemaInfo = append(jobW.InvolvingSchemaInfo, sharedInv...) - } - } - if len(args) == 0 { - return nil - } - jobW.Args = append(jobW.Args, args) - jobW.Args = append(jobW.Args, ctx.GetSessionVars().ForeignKeyChecks) - - err = e.DoDDLJobWrapper(ctx, jobW) - if err != nil { - // table exists, but if_not_exists flags is true, so we ignore this error. - if c.OnExist == OnExistIgnore && infoschema.ErrTableExists.Equal(err) { - ctx.GetSessionVars().StmtCtx.AppendNote(err) - err = nil - } - return errors.Trace(err) - } - - for j := range args { - if err = e.createTableWithInfoPost(ctx, args[j], jobW.SchemaID); err != nil { - return errors.Trace(err) - } - } - - return nil -} - -func (e *executor) CreatePlacementPolicyWithInfo(ctx sessionctx.Context, policy *model.PolicyInfo, onExist OnExist) error { - if checkIgnorePlacementDDL(ctx) { - return nil - } - - policyName := policy.Name - if policyName.L == defaultPlacementPolicyName { - return errors.Trace(infoschema.ErrReservedSyntax.GenWithStackByArgs(policyName)) - } - - // Check policy existence. 
- _, ok := e.infoCache.GetLatest().PolicyByName(policyName) - if ok { - err := infoschema.ErrPlacementPolicyExists.GenWithStackByArgs(policyName) - switch onExist { - case OnExistIgnore: - ctx.GetSessionVars().StmtCtx.AppendNote(err) - return nil - case OnExistError: - return err - } - } - - if err := checkPolicyValidation(policy.PlacementSettings); err != nil { - return err - } - - policyID, err := e.genPlacementPolicyID() - if err != nil { - return err - } - policy.ID = policyID - - job := &model.Job{ - SchemaName: policy.Name.L, - Type: model.ActionCreatePlacementPolicy, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{policy, onExist == OnExistReplace}, - InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ - Policy: policy.Name.L, - }}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -// preSplitAndScatter performs pre-split and scatter of the table's regions. -// If `pi` is not nil, will only split region for `pi`, this is used when add partition. -func preSplitAndScatter(ctx sessionctx.Context, store kv.Storage, tbInfo *model.TableInfo, parts []model.PartitionDefinition) { - if tbInfo.TempTableType != model.TempTableNone { - return - } - sp, ok := store.(kv.SplittableStore) - if !ok || atomic.LoadUint32(&EnableSplitTableRegion) == 0 { - return - } - var ( - preSplit func() - scatterRegion bool - ) - val, err := ctx.GetSessionVars().GetGlobalSystemVar(context.Background(), variable.TiDBScatterRegion) - if err != nil { - logutil.DDLLogger().Warn("won't scatter region", zap.Error(err)) - } else { - scatterRegion = variable.TiDBOptOn(val) - } - if len(parts) > 0 { - preSplit = func() { splitPartitionTableRegion(ctx, sp, tbInfo, parts, scatterRegion) } - } else { - preSplit = func() { splitTableRegion(ctx, sp, tbInfo, scatterRegion) } - } - if scatterRegion { - preSplit() - } else { - go preSplit() - } -} - -func (e *executor) FlashbackCluster(ctx sessionctx.Context, flashbackTS uint64) error { - logutil.DDLLogger().Info("get flashback cluster job", zap.Stringer("flashbackTS", oracle.GetTimeFromTS(flashbackTS))) - nowTS, err := ctx.GetStore().GetOracle().GetTimestamp(e.ctx, &oracle.Option{}) - if err != nil { - return errors.Trace(err) - } - gap := time.Until(oracle.GetTimeFromTS(nowTS)).Abs() - if gap > 1*time.Second { - ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("Gap between local time and PD TSO is %s, please check PD/system time", gap)) - } - job := &model.Job{ - Type: model.ActionFlashbackCluster, - BinlogInfo: &model.HistoryInfo{}, - // The value for global variables is meaningless, it will cover during flashback cluster. - Args: []any{ - flashbackTS, - map[string]any{}, - true, /* tidb_gc_enable */ - variable.On, /* tidb_enable_auto_analyze */ - variable.Off, /* tidb_super_read_only */ - 0, /* totalRegions */ - 0, /* startTS */ - 0, /* commitTS */ - variable.On, /* tidb_ttl_job_enable */ - []kv.KeyRange{} /* flashback key_ranges */}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - // FLASHBACK CLUSTER affects all schemas and tables. 
- InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ - Database: model.InvolvingAll, - Table: model.InvolvingAll, - }}, - SQLMode: ctx.GetSessionVars().SQLMode, - } - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -func (e *executor) RecoverTable(ctx sessionctx.Context, recoverInfo *RecoverInfo) (err error) { - is := e.infoCache.GetLatest() - schemaID, tbInfo := recoverInfo.SchemaID, recoverInfo.TableInfo - // Check schema exist. - schema, ok := is.SchemaByID(schemaID) - if !ok { - return errors.Trace(infoschema.ErrDatabaseNotExists.GenWithStackByArgs( - fmt.Sprintf("(Schema ID %d)", schemaID), - )) - } - // Check not exist table with same name. - if ok := is.TableExists(schema.Name, tbInfo.Name); ok { - return infoschema.ErrTableExists.GenWithStackByArgs(tbInfo.Name) - } - - // for "flashback table xxx to yyy" - // Note: this case only allow change table name, schema remains the same. - var involvedSchemas []model.InvolvingSchemaInfo - if recoverInfo.OldTableName != tbInfo.Name.L { - involvedSchemas = []model.InvolvingSchemaInfo{ - {Database: schema.Name.L, Table: recoverInfo.OldTableName}, - {Database: schema.Name.L, Table: tbInfo.Name.L}, - } - } - - tbInfo.State = model.StateNone - job := &model.Job{ - SchemaID: schemaID, - TableID: tbInfo.ID, - SchemaName: schema.Name.L, - TableName: tbInfo.Name.L, - - Type: model.ActionRecoverTable, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{recoverInfo, recoverCheckFlagNone}, - InvolvingSchemaInfo: involvedSchemas, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -func (e *executor) CreateView(ctx sessionctx.Context, s *ast.CreateViewStmt) (err error) { - viewInfo, err := BuildViewInfo(s) - if err != nil { - return err - } - - cols := make([]*table.Column, len(s.Cols)) - for i, v := range s.Cols { - cols[i] = table.ToColumn(&model.ColumnInfo{ - Name: v, - ID: int64(i), - Offset: i, - State: model.StatePublic, - }) - } - - tblCharset := "" - tblCollate := "" - if v, ok := ctx.GetSessionVars().GetSystemVar(variable.CharacterSetConnection); ok { - tblCharset = v - } - if v, ok := ctx.GetSessionVars().GetSystemVar(variable.CollationConnection); ok { - tblCollate = v - } - - tbInfo, err := BuildTableInfo(ctx, s.ViewName.Name, cols, nil, tblCharset, tblCollate) - if err != nil { - return err - } - tbInfo.View = viewInfo - - onExist := OnExistError - if s.OrReplace { - onExist = OnExistReplace - } - - return e.CreateTableWithInfo(ctx, s.ViewName.Schema, tbInfo, nil, WithOnExist(onExist)) -} - -func checkCharsetAndCollation(cs string, co string) error { - if !charset.ValidCharsetAndCollation(cs, co) { - return dbterror.ErrUnknownCharacterSet.GenWithStackByArgs(cs) - } - if co != "" { - if _, err := collate.GetCollationByName(co); err != nil { - return errors.Trace(err) - } - } - return nil -} - -// handleAutoIncID handles auto_increment option in DDL. It creates a ID counter for the table and initiates the counter to a proper value. -// For example if the option sets auto_increment to 10. The counter will be set to 9. So the next allocated ID will be 10. 
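The comment above can be made concrete with a small standalone sketch; rebaseEnd is a hypothetical helper illustrating the N-1 rebase rule, not a function from this patch:

package main

import "fmt"

// rebaseEnd models the rule described above: when a table is created with
// AUTO_INCREMENT = N and N > 1, the allocator is rebased to N-1 so that the
// next allocated ID is exactly N. With the default base, nothing is rebased.
func rebaseEnd(autoIncID int64) (newEnd int64, needRebase bool) {
	if autoIncID > 1 {
		return autoIncID - 1, true
	}
	return 0, false
}

func main() {
	end, ok := rebaseEnd(10)
	fmt.Println(end, ok) // 9 true -> the next allocated auto_increment value is 10
}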
-func (e *executor) handleAutoIncID(tbInfo *model.TableInfo, schemaID int64, newEnd int64, tp autoid.AllocatorType) error { - allocs := autoid.NewAllocatorsFromTblInfo(e.getAutoIDRequirement(), schemaID, tbInfo) - if alloc := allocs.Get(tp); alloc != nil { - err := alloc.Rebase(context.Background(), newEnd, false) - if err != nil { - return errors.Trace(err) - } - } - return nil -} - -// TODO we can unify this part with ddlCtx. -func (e *executor) getAutoIDRequirement() autoid.Requirement { - return &asAutoIDRequirement{ - store: e.store, - autoidCli: e.autoidCli, - } -} - -func shardingBits(tblInfo *model.TableInfo) uint64 { - if tblInfo.ShardRowIDBits > 0 { - return tblInfo.ShardRowIDBits - } - return tblInfo.AutoRandomBits -} - -// isIgnorableSpec checks if the spec type is ignorable. -// Some specs are parsed by ignored. This is for compatibility. -func isIgnorableSpec(tp ast.AlterTableType) bool { - // AlterTableLock/AlterTableAlgorithm are ignored. - return tp == ast.AlterTableLock || tp == ast.AlterTableAlgorithm -} - -// GetCharsetAndCollateInTableOption will iterate the charset and collate in the options, -// and returns the last charset and collate in options. If there is no charset in the options, -// the returns charset will be "", the same as collate. -func GetCharsetAndCollateInTableOption(sessVars *variable.SessionVars, startIdx int, options []*ast.TableOption) (chs, coll string, err error) { - for i := startIdx; i < len(options); i++ { - opt := options[i] - // we set the charset to the last option. example: alter table t charset latin1 charset utf8 collate utf8_bin; - // the charset will be utf8, collate will be utf8_bin - switch opt.Tp { - case ast.TableOptionCharset: - info, err := charset.GetCharsetInfo(opt.StrValue) - if err != nil { - return "", "", err - } - if len(chs) == 0 { - chs = info.Name - } else if chs != info.Name { - return "", "", dbterror.ErrConflictingDeclarations.GenWithStackByArgs(chs, info.Name) - } - if len(coll) == 0 { - defaultColl, err := getDefaultCollationForUTF8MB4(sessVars, chs) - if err != nil { - return "", "", errors.Trace(err) - } - if len(defaultColl) == 0 { - coll = info.DefaultCollation - } else { - coll = defaultColl - } - } - case ast.TableOptionCollate: - info, err := collate.GetCollationByName(opt.StrValue) - if err != nil { - return "", "", err - } - if len(chs) == 0 { - chs = info.CharsetName - } else if chs != info.CharsetName { - return "", "", dbterror.ErrCollationCharsetMismatch.GenWithStackByArgs(info.Name, chs) - } - coll = info.Name - } - } - return -} - -// NeedToOverwriteColCharset return true for altering charset and specified CONVERT TO. -func NeedToOverwriteColCharset(options []*ast.TableOption) bool { - for i := len(options) - 1; i >= 0; i-- { - opt := options[i] - if opt.Tp == ast.TableOptionCharset { - // Only overwrite columns charset if the option contains `CONVERT TO`. - return opt.UintValue == ast.TableOptionCharsetWithConvertTo - } - } - return false -} - -// resolveAlterTableAddColumns splits "add columns" to multiple spec. For example, -// `ALTER TABLE ADD COLUMN (c1 INT, c2 INT)` is split into -// `ALTER TABLE ADD COLUMN c1 INT, ADD COLUMN c2 INT`. 
-func resolveAlterTableAddColumns(spec *ast.AlterTableSpec) []*ast.AlterTableSpec { - specs := make([]*ast.AlterTableSpec, 0, len(spec.NewColumns)+len(spec.NewConstraints)) - for _, col := range spec.NewColumns { - t := *spec - t.NewColumns = []*ast.ColumnDef{col} - t.NewConstraints = []*ast.Constraint{} - specs = append(specs, &t) - } - // Split the add constraints from AlterTableSpec. - for _, con := range spec.NewConstraints { - t := *spec - t.NewColumns = []*ast.ColumnDef{} - t.NewConstraints = []*ast.Constraint{} - t.Constraint = con - t.Tp = ast.AlterTableAddConstraint - specs = append(specs, &t) - } - return specs -} - -// ResolveAlterTableSpec resolves alter table algorithm and removes ignore table spec in specs. -// returns valid specs, and the occurred error. -func ResolveAlterTableSpec(ctx sessionctx.Context, specs []*ast.AlterTableSpec) ([]*ast.AlterTableSpec, error) { - validSpecs := make([]*ast.AlterTableSpec, 0, len(specs)) - algorithm := ast.AlgorithmTypeDefault - for _, spec := range specs { - if spec.Tp == ast.AlterTableAlgorithm { - // Find the last AlterTableAlgorithm. - algorithm = spec.Algorithm - } - if isIgnorableSpec(spec.Tp) { - continue - } - if spec.Tp == ast.AlterTableAddColumns && (len(spec.NewColumns) > 1 || len(spec.NewConstraints) > 0) { - validSpecs = append(validSpecs, resolveAlterTableAddColumns(spec)...) - } else { - validSpecs = append(validSpecs, spec) - } - // TODO: Only allow REMOVE PARTITIONING as a single ALTER TABLE statement? - } - - // Verify whether the algorithm is supported. - for _, spec := range validSpecs { - resolvedAlgorithm, err := ResolveAlterAlgorithm(spec, algorithm) - if err != nil { - // If TiDB failed to choose a better algorithm, report the error - if resolvedAlgorithm == ast.AlgorithmTypeDefault { - return nil, errors.Trace(err) - } - // For the compatibility, we return warning instead of error when a better algorithm is chosed by TiDB - ctx.GetSessionVars().StmtCtx.AppendError(err) - } - - spec.Algorithm = resolvedAlgorithm - } - - // Only handle valid specs. 
- return validSpecs, nil -} - -func isMultiSchemaChanges(specs []*ast.AlterTableSpec) bool { - if len(specs) > 1 { - return true - } - if len(specs) == 1 && len(specs[0].NewColumns) > 1 && specs[0].Tp == ast.AlterTableAddColumns { - return true - } - return false -} - -func (e *executor) AlterTable(ctx context.Context, sctx sessionctx.Context, stmt *ast.AlterTableStmt) (err error) { - ident := ast.Ident{Schema: stmt.Table.Schema, Name: stmt.Table.Name} - validSpecs, err := ResolveAlterTableSpec(sctx, stmt.Specs) - if err != nil { - return errors.Trace(err) - } - - is := e.infoCache.GetLatest() - tb, err := is.TableByName(ctx, ident.Schema, ident.Name) - if err != nil { - return errors.Trace(err) - } - if tb.Meta().IsView() || tb.Meta().IsSequence() { - return dbterror.ErrWrongObject.GenWithStackByArgs(ident.Schema, ident.Name, "BASE TABLE") - } - if tb.Meta().TableCacheStatusType != model.TableCacheStatusDisable { - if len(validSpecs) != 1 { - return dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Alter Table") - } - if validSpecs[0].Tp != ast.AlterTableCache && validSpecs[0].Tp != ast.AlterTableNoCache { - return dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Alter Table") - } - } - if isMultiSchemaChanges(validSpecs) && (sctx.GetSessionVars().EnableRowLevelChecksum || variable.EnableRowLevelChecksum.Load()) { - return dbterror.ErrRunMultiSchemaChanges.GenWithStack("Unsupported multi schema change when row level checksum is enabled") - } - // set name for anonymous foreign key. - maxForeignKeyID := tb.Meta().MaxForeignKeyID - for _, spec := range validSpecs { - if spec.Tp == ast.AlterTableAddConstraint && spec.Constraint.Tp == ast.ConstraintForeignKey && spec.Constraint.Name == "" { - maxForeignKeyID++ - spec.Constraint.Name = fmt.Sprintf("fk_%d", maxForeignKeyID) - } - } - - if len(validSpecs) > 1 { - // after MultiSchemaInfo is set, DoDDLJob will collect all jobs into - // MultiSchemaInfo and skip running them. Then we will run them in - // d.multiSchemaChange all at once. 
- sctx.GetSessionVars().StmtCtx.MultiSchemaInfo = model.NewMultiSchemaInfo() - } - for _, spec := range validSpecs { - var handledCharsetOrCollate bool - var ttlOptionsHandled bool - switch spec.Tp { - case ast.AlterTableAddColumns: - err = e.AddColumn(sctx, ident, spec) - case ast.AlterTableAddPartitions, ast.AlterTableAddLastPartition: - err = e.AddTablePartitions(sctx, ident, spec) - case ast.AlterTableCoalescePartitions: - err = e.CoalescePartitions(sctx, ident, spec) - case ast.AlterTableReorganizePartition: - err = e.ReorganizePartitions(sctx, ident, spec) - case ast.AlterTableReorganizeFirstPartition: - err = dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("MERGE FIRST PARTITION") - case ast.AlterTableReorganizeLastPartition: - err = dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("SPLIT LAST PARTITION") - case ast.AlterTableCheckPartitions: - err = errors.Trace(dbterror.ErrUnsupportedCheckPartition) - case ast.AlterTableRebuildPartition: - err = errors.Trace(dbterror.ErrUnsupportedRebuildPartition) - case ast.AlterTableOptimizePartition: - err = errors.Trace(dbterror.ErrUnsupportedOptimizePartition) - case ast.AlterTableRemovePartitioning: - err = e.RemovePartitioning(sctx, ident, spec) - case ast.AlterTableRepairPartition: - err = errors.Trace(dbterror.ErrUnsupportedRepairPartition) - case ast.AlterTableDropColumn: - err = e.DropColumn(sctx, ident, spec) - case ast.AlterTableDropIndex: - err = e.dropIndex(sctx, ident, model.NewCIStr(spec.Name), spec.IfExists, false) - case ast.AlterTableDropPrimaryKey: - err = e.dropIndex(sctx, ident, model.NewCIStr(mysql.PrimaryKeyName), spec.IfExists, false) - case ast.AlterTableRenameIndex: - err = e.RenameIndex(sctx, ident, spec) - case ast.AlterTableDropPartition, ast.AlterTableDropFirstPartition: - err = e.DropTablePartition(sctx, ident, spec) - case ast.AlterTableTruncatePartition: - err = e.TruncateTablePartition(sctx, ident, spec) - case ast.AlterTableWriteable: - if !config.TableLockEnabled() { - return nil - } - tName := &ast.TableName{Schema: ident.Schema, Name: ident.Name} - if spec.Writeable { - err = e.CleanupTableLock(sctx, []*ast.TableName{tName}) - } else { - lockStmt := &ast.LockTablesStmt{ - TableLocks: []ast.TableLock{ - { - Table: tName, - Type: model.TableLockReadOnly, - }, - }, - } - err = e.LockTables(sctx, lockStmt) - } - case ast.AlterTableExchangePartition: - err = e.ExchangeTablePartition(sctx, ident, spec) - case ast.AlterTableAddConstraint: - constr := spec.Constraint - switch spec.Constraint.Tp { - case ast.ConstraintKey, ast.ConstraintIndex: - err = e.createIndex(sctx, ident, ast.IndexKeyTypeNone, model.NewCIStr(constr.Name), - spec.Constraint.Keys, constr.Option, constr.IfNotExists) - case ast.ConstraintUniq, ast.ConstraintUniqIndex, ast.ConstraintUniqKey: - err = e.createIndex(sctx, ident, ast.IndexKeyTypeUnique, model.NewCIStr(constr.Name), - spec.Constraint.Keys, constr.Option, false) // IfNotExists should be not applied - case ast.ConstraintForeignKey: - // NOTE: we do not handle `symbol` and `index_name` well in the parser and we do not check ForeignKey already exists, - // so we just also ignore the `if not exists` check. 
- err = e.CreateForeignKey(sctx, ident, model.NewCIStr(constr.Name), spec.Constraint.Keys, spec.Constraint.Refer) - case ast.ConstraintPrimaryKey: - err = e.CreatePrimaryKey(sctx, ident, model.NewCIStr(constr.Name), spec.Constraint.Keys, constr.Option) - case ast.ConstraintFulltext: - sctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrTableCantHandleFt) - case ast.ConstraintCheck: - if !variable.EnableCheckConstraint.Load() { - sctx.GetSessionVars().StmtCtx.AppendWarning(errCheckConstraintIsOff) - } else { - err = e.CreateCheckConstraint(sctx, ident, model.NewCIStr(constr.Name), spec.Constraint) - } - default: - // Nothing to do now. - } - case ast.AlterTableDropForeignKey: - // NOTE: we do not check `if not exists` and `if exists` for ForeignKey now. - err = e.DropForeignKey(sctx, ident, model.NewCIStr(spec.Name)) - case ast.AlterTableModifyColumn: - err = e.ModifyColumn(ctx, sctx, ident, spec) - case ast.AlterTableChangeColumn: - err = e.ChangeColumn(ctx, sctx, ident, spec) - case ast.AlterTableRenameColumn: - err = e.RenameColumn(sctx, ident, spec) - case ast.AlterTableAlterColumn: - err = e.AlterColumn(sctx, ident, spec) - case ast.AlterTableRenameTable: - newIdent := ast.Ident{Schema: spec.NewTable.Schema, Name: spec.NewTable.Name} - isAlterTable := true - err = e.renameTable(sctx, ident, newIdent, isAlterTable) - case ast.AlterTablePartition: - err = e.AlterTablePartitioning(sctx, ident, spec) - case ast.AlterTableOption: - var placementPolicyRef *model.PolicyRefInfo - for i, opt := range spec.Options { - switch opt.Tp { - case ast.TableOptionShardRowID: - if opt.UintValue > shardRowIDBitsMax { - opt.UintValue = shardRowIDBitsMax - } - err = e.ShardRowID(sctx, ident, opt.UintValue) - case ast.TableOptionAutoIncrement: - err = e.RebaseAutoID(sctx, ident, int64(opt.UintValue), autoid.AutoIncrementType, opt.BoolValue) - case ast.TableOptionAutoIdCache: - if opt.UintValue > uint64(math.MaxInt64) { - // TODO: Refine this error. - return errors.New("table option auto_id_cache overflows int64") - } - err = e.AlterTableAutoIDCache(sctx, ident, int64(opt.UintValue)) - case ast.TableOptionAutoRandomBase: - err = e.RebaseAutoID(sctx, ident, int64(opt.UintValue), autoid.AutoRandomType, opt.BoolValue) - case ast.TableOptionComment: - spec.Comment = opt.StrValue - err = e.AlterTableComment(sctx, ident, spec) - case ast.TableOptionCharset, ast.TableOptionCollate: - // GetCharsetAndCollateInTableOption will get the last charset and collate in the options, - // so it should be handled only once. 
- if handledCharsetOrCollate { - continue - } - var toCharset, toCollate string - toCharset, toCollate, err = GetCharsetAndCollateInTableOption(sctx.GetSessionVars(), i, spec.Options) - if err != nil { - return err - } - needsOverwriteCols := NeedToOverwriteColCharset(spec.Options) - err = e.AlterTableCharsetAndCollate(sctx, ident, toCharset, toCollate, needsOverwriteCols) - handledCharsetOrCollate = true - case ast.TableOptionPlacementPolicy: - placementPolicyRef = &model.PolicyRefInfo{ - Name: model.NewCIStr(opt.StrValue), - } - case ast.TableOptionEngine: - case ast.TableOptionRowFormat: - case ast.TableOptionTTL, ast.TableOptionTTLEnable, ast.TableOptionTTLJobInterval: - var ttlInfo *model.TTLInfo - var ttlEnable *bool - var ttlJobInterval *string - - if ttlOptionsHandled { - continue - } - ttlInfo, ttlEnable, ttlJobInterval, err = getTTLInfoInOptions(spec.Options) - if err != nil { - return err - } - err = e.AlterTableTTLInfoOrEnable(sctx, ident, ttlInfo, ttlEnable, ttlJobInterval) - - ttlOptionsHandled = true - default: - err = dbterror.ErrUnsupportedAlterTableOption - } - - if err != nil { - return errors.Trace(err) - } - } - - if placementPolicyRef != nil { - err = e.AlterTablePlacement(sctx, ident, placementPolicyRef) - } - case ast.AlterTableSetTiFlashReplica: - err = e.AlterTableSetTiFlashReplica(sctx, ident, spec.TiFlashReplica) - case ast.AlterTableOrderByColumns: - err = e.OrderByColumns(sctx, ident) - case ast.AlterTableIndexInvisible: - err = e.AlterIndexVisibility(sctx, ident, spec.IndexName, spec.Visibility) - case ast.AlterTableAlterCheck: - if !variable.EnableCheckConstraint.Load() { - sctx.GetSessionVars().StmtCtx.AppendWarning(errCheckConstraintIsOff) - } else { - err = e.AlterCheckConstraint(sctx, ident, model.NewCIStr(spec.Constraint.Name), spec.Constraint.Enforced) - } - case ast.AlterTableDropCheck: - if !variable.EnableCheckConstraint.Load() { - sctx.GetSessionVars().StmtCtx.AppendWarning(errCheckConstraintIsOff) - } else { - err = e.DropCheckConstraint(sctx, ident, model.NewCIStr(spec.Constraint.Name)) - } - case ast.AlterTableWithValidation: - sctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedAlterTableWithValidation) - case ast.AlterTableWithoutValidation: - sctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedAlterTableWithoutValidation) - case ast.AlterTableAddStatistics: - err = e.AlterTableAddStatistics(sctx, ident, spec.Statistics, spec.IfNotExists) - case ast.AlterTableDropStatistics: - err = e.AlterTableDropStatistics(sctx, ident, spec.Statistics, spec.IfExists) - case ast.AlterTableAttributes: - err = e.AlterTableAttributes(sctx, ident, spec) - case ast.AlterTablePartitionAttributes: - err = e.AlterTablePartitionAttributes(sctx, ident, spec) - case ast.AlterTablePartitionOptions: - err = e.AlterTablePartitionOptions(sctx, ident, spec) - case ast.AlterTableCache: - err = e.AlterTableCache(sctx, ident) - case ast.AlterTableNoCache: - err = e.AlterTableNoCache(sctx, ident) - case ast.AlterTableDisableKeys, ast.AlterTableEnableKeys: - // Nothing to do now, see https://github.com/pingcap/tidb/issues/1051 - // MyISAM specific - case ast.AlterTableRemoveTTL: - // the parser makes sure we have only one `ast.AlterTableRemoveTTL` in an alter statement - err = e.AlterTableRemoveTTL(sctx, ident) - default: - err = errors.Trace(dbterror.ErrUnsupportedAlterTableSpec) - } - - if err != nil { - return errors.Trace(err) - } - } - - if sctx.GetSessionVars().StmtCtx.MultiSchemaInfo != nil { - info := 
sctx.GetSessionVars().StmtCtx.MultiSchemaInfo - sctx.GetSessionVars().StmtCtx.MultiSchemaInfo = nil - err = e.multiSchemaChange(sctx, ident, info) - if err != nil { - return errors.Trace(err) - } - } - - return nil -} - -func (e *executor) multiSchemaChange(ctx sessionctx.Context, ti ast.Ident, info *model.MultiSchemaInfo) error { - subJobs := info.SubJobs - if len(subJobs) == 0 { - return nil - } - schema, t, err := e.getSchemaAndTableByIdent(ti) - if err != nil { - return errors.Trace(err) - } - - logFn := logutil.DDLLogger().Warn - if intest.InTest { - logFn = logutil.DDLLogger().Fatal - } - - var involvingSchemaInfo []model.InvolvingSchemaInfo - for _, j := range subJobs { - switch j.Type { - case model.ActionAlterTablePlacement: - ref, ok := j.Args[0].(*model.PolicyRefInfo) - if !ok { - logFn("unexpected type of policy reference info", - zap.Any("args[0]", j.Args[0]), - zap.String("type", fmt.Sprintf("%T", j.Args[0]))) - continue - } - if ref == nil { - continue - } - involvingSchemaInfo = append(involvingSchemaInfo, model.InvolvingSchemaInfo{ - Policy: ref.Name.L, - Mode: model.SharedInvolving, - }) - case model.ActionAddForeignKey: - ref, ok := j.Args[0].(*model.FKInfo) - if !ok { - logFn("unexpected type of foreign key info", - zap.Any("args[0]", j.Args[0]), - zap.String("type", fmt.Sprintf("%T", j.Args[0]))) - continue - } - involvingSchemaInfo = append(involvingSchemaInfo, model.InvolvingSchemaInfo{ - Database: ref.RefSchema.L, - Table: ref.RefTable.L, - Mode: model.SharedInvolving, - }) - case model.ActionAlterTablePartitionPlacement: - if len(j.Args) < 2 { - logFn("unexpected number of arguments for partition placement", - zap.Int("len(args)", len(j.Args)), - zap.Any("args", j.Args)) - continue - } - ref, ok := j.Args[1].(*model.PolicyRefInfo) - if !ok { - logFn("unexpected type of policy reference info", - zap.Any("args[0]", j.Args[0]), - zap.String("type", fmt.Sprintf("%T", j.Args[0]))) - continue - } - if ref == nil { - continue - } - involvingSchemaInfo = append(involvingSchemaInfo, model.InvolvingSchemaInfo{ - Policy: ref.Name.L, - Mode: model.SharedInvolving, - }) - } - } - - if len(involvingSchemaInfo) > 0 { - involvingSchemaInfo = append(involvingSchemaInfo, model.InvolvingSchemaInfo{ - Database: schema.Name.L, - Table: t.Meta().Name.L, - }) - } - - job := &model.Job{ - SchemaID: schema.ID, - TableID: t.Meta().ID, - SchemaName: schema.Name.L, - TableName: t.Meta().Name.L, - Type: model.ActionMultiSchemaChange, - BinlogInfo: &model.HistoryInfo{}, - Args: nil, - MultiSchemaInfo: info, - ReorgMeta: nil, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - InvolvingSchemaInfo: involvingSchemaInfo, - SQLMode: ctx.GetSessionVars().SQLMode, - } - if containsDistTaskSubJob(subJobs) { - job.ReorgMeta, err = newReorgMetaFromVariables(job, ctx) - if err != nil { - return err - } - } else { - job.ReorgMeta = NewDDLReorgMeta(ctx) - } - - err = checkMultiSchemaInfo(info, t) - if err != nil { - return errors.Trace(err) - } - mergeAddIndex(info) - return e.DoDDLJob(ctx, job) -} - -func containsDistTaskSubJob(subJobs []*model.SubJob) bool { - for _, sub := range subJobs { - if sub.Type == model.ActionAddIndex || - sub.Type == model.ActionAddPrimaryKey { - return true - } - } - return false -} - -func (e *executor) RebaseAutoID(ctx sessionctx.Context, ident ast.Ident, newBase int64, tp autoid.AllocatorType, force bool) error { - schema, t, err := e.getSchemaAndTableByIdent(ident) - if err != nil { - return errors.Trace(err) - } - tbInfo := t.Meta() - var actionType model.ActionType - 
switch tp { - case autoid.AutoRandomType: - pkCol := tbInfo.GetPkColInfo() - if tbInfo.AutoRandomBits == 0 || pkCol == nil { - return errors.Trace(dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomRebaseNotApplicable)) - } - shardFmt := autoid.NewShardIDFormat(&pkCol.FieldType, tbInfo.AutoRandomBits, tbInfo.AutoRandomRangeBits) - if shardFmt.IncrementalMask()&newBase != newBase { - errMsg := fmt.Sprintf(autoid.AutoRandomRebaseOverflow, newBase, shardFmt.IncrementalBitsCapacity()) - return errors.Trace(dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(errMsg)) - } - actionType = model.ActionRebaseAutoRandomBase - case autoid.RowIDAllocType: - actionType = model.ActionRebaseAutoID - case autoid.AutoIncrementType: - actionType = model.ActionRebaseAutoID - default: - panic(fmt.Sprintf("unimplemented rebase autoid type %s", tp)) - } - - if !force { - newBaseTemp, err := adjustNewBaseToNextGlobalID(ctx.GetTableCtx(), t, tp, newBase) - if err != nil { - return err - } - if newBase != newBaseTemp { - ctx.GetSessionVars().StmtCtx.AppendWarning( - errors.NewNoStackErrorf("Can't reset AUTO_INCREMENT to %d without FORCE option, using %d instead", - newBase, newBaseTemp, - )) - } - newBase = newBaseTemp - } - job := &model.Job{ - SchemaID: schema.ID, - TableID: tbInfo.ID, - SchemaName: schema.Name.L, - TableName: tbInfo.Name.L, - Type: actionType, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{newBase, force}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -func adjustNewBaseToNextGlobalID(ctx table.AllocatorContext, t table.Table, tp autoid.AllocatorType, newBase int64) (int64, error) { - alloc := t.Allocators(ctx).Get(tp) - if alloc == nil { - return newBase, nil - } - autoID, err := alloc.NextGlobalAutoID() - if err != nil { - return newBase, errors.Trace(err) - } - // If newBase < autoID, we need to do a rebase before returning. - // Assume there are 2 TiDB servers: TiDB-A with allocator range of 0 ~ 30000; TiDB-B with allocator range of 30001 ~ 60000. - // If the user sends SQL `alter table t1 auto_increment = 100` to TiDB-B, - // and TiDB-B finds 100 < 30001 but returns without any handling, - // then TiDB-A may still allocate 99 for auto_increment column. This doesn't make sense for the user. - return int64(mathutil.Max(uint64(newBase), uint64(autoID))), nil -} - -// ShardRowID shards the implicit row ID by adding shard value to the row ID's first few bits. -func (e *executor) ShardRowID(ctx sessionctx.Context, tableIdent ast.Ident, uVal uint64) error { - schema, t, err := e.getSchemaAndTableByIdent(tableIdent) - if err != nil { - return errors.Trace(err) - } - tbInfo := t.Meta() - if tbInfo.TempTableType != model.TempTableNone { - return dbterror.ErrOptOnTemporaryTable.GenWithStackByArgs("shard_row_id_bits") - } - if uVal == tbInfo.ShardRowIDBits { - // Nothing need to do. 
- return nil - } - if uVal > 0 && tbInfo.HasClusteredIndex() { - return dbterror.ErrUnsupportedShardRowIDBits - } - err = verifyNoOverflowShardBits(e.sessPool, t, uVal) - if err != nil { - return err - } - job := &model.Job{ - Type: model.ActionShardRowID, - SchemaID: schema.ID, - TableID: tbInfo.ID, - SchemaName: schema.Name.L, - TableName: tbInfo.Name.L, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{uVal}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -func (e *executor) getSchemaAndTableByIdent(tableIdent ast.Ident) (dbInfo *model.DBInfo, t table.Table, err error) { - is := e.infoCache.GetLatest() - schema, ok := is.SchemaByName(tableIdent.Schema) - if !ok { - return nil, nil, infoschema.ErrDatabaseNotExists.GenWithStackByArgs(tableIdent.Schema) - } - t, err = is.TableByName(e.ctx, tableIdent.Schema, tableIdent.Name) - if err != nil { - return nil, nil, infoschema.ErrTableNotExists.GenWithStackByArgs(tableIdent.Schema, tableIdent.Name) - } - return schema, t, nil -} - -// AddColumn will add a new column to the table. -func (e *executor) AddColumn(ctx sessionctx.Context, ti ast.Ident, spec *ast.AlterTableSpec) error { - specNewColumn := spec.NewColumns[0] - schema, t, err := e.getSchemaAndTableByIdent(ti) - if err != nil { - return errors.Trace(err) - } - failpoint.InjectCall("afterGetSchemaAndTableByIdent", ctx) - tbInfo := t.Meta() - if err = checkAddColumnTooManyColumns(len(t.Cols()) + 1); err != nil { - return errors.Trace(err) - } - col, err := checkAndCreateNewColumn(ctx, ti, schema, spec, t, specNewColumn) - if err != nil { - return errors.Trace(err) - } - // Added column has existed and if_not_exists flag is true. - if col == nil { - return nil - } - err = CheckAfterPositionExists(tbInfo, spec.Position) - if err != nil { - return errors.Trace(err) - } - - txn, err := ctx.Txn(true) - if err != nil { - return errors.Trace(err) - } - bdrRole, err := meta.NewMeta(txn).GetBDRRole() - if err != nil { - return errors.Trace(err) - } - if bdrRole == string(ast.BDRRolePrimary) && deniedByBDRWhenAddColumn(specNewColumn.Options) { - return dbterror.ErrBDRRestrictedDDL.FastGenByArgs(bdrRole) - } - - job := &model.Job{ - SchemaID: schema.ID, - TableID: tbInfo.ID, - SchemaName: schema.Name.L, - TableName: tbInfo.Name.L, - Type: model.ActionAddColumn, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{col, spec.Position, 0, spec.IfNotExists}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -// AddTablePartitions will add a new partition to the table. -func (e *executor) AddTablePartitions(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { - is := e.infoCache.GetLatest() - schema, ok := is.SchemaByName(ident.Schema) - if !ok { - return errors.Trace(infoschema.ErrDatabaseNotExists.GenWithStackByArgs(schema)) - } - t, err := is.TableByName(e.ctx, ident.Schema, ident.Name) - if err != nil { - return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name)) - } - - meta := t.Meta() - pi := meta.GetPartitionInfo() - if pi == nil { - return errors.Trace(dbterror.ErrPartitionMgmtOnNonpartitioned) - } - if pi.Type == model.PartitionTypeHash || pi.Type == model.PartitionTypeKey { - // Add partition for hash/key is actually a reorganize partition - // operation and not a metadata only change! 
- switch spec.Tp { - case ast.AlterTableAddLastPartition: - return errors.Trace(dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("LAST PARTITION of HASH/KEY partitioned table")) - case ast.AlterTableAddPartitions: - // only thing supported - default: - return errors.Trace(dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("ADD PARTITION of HASH/KEY partitioned table")) - } - return e.hashPartitionManagement(ctx, ident, spec, pi) - } - - partInfo, err := BuildAddedPartitionInfo(ctx.GetExprCtx(), meta, spec) - if err != nil { - return errors.Trace(err) - } - if pi.Type == model.PartitionTypeList { - // TODO: make sure that checks in ddl_api and ddl_worker is the same. - err = checkAddListPartitions(meta) - if err != nil { - return errors.Trace(err) - } - } - if err := e.assignPartitionIDs(partInfo.Definitions); err != nil { - return errors.Trace(err) - } - - // partInfo contains only the new added partition, we have to combine it with the - // old partitions to check all partitions is strictly increasing. - clonedMeta := meta.Clone() - tmp := *partInfo - tmp.Definitions = append(pi.Definitions, tmp.Definitions...) - clonedMeta.Partition = &tmp - if err := checkPartitionDefinitionConstraints(ctx, clonedMeta); err != nil { - if dbterror.ErrSameNamePartition.Equal(err) && spec.IfNotExists { - ctx.GetSessionVars().StmtCtx.AppendNote(err) - return nil - } - return errors.Trace(err) - } - - if err = handlePartitionPlacement(ctx, partInfo); err != nil { - return errors.Trace(err) - } - - job := &model.Job{ - SchemaID: schema.ID, - TableID: meta.ID, - SchemaName: schema.Name.L, - TableName: t.Meta().Name.L, - Type: model.ActionAddTablePartition, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{partInfo}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - if spec.Tp == ast.AlterTableAddLastPartition && spec.Partition != nil { - query, ok := ctx.Value(sessionctx.QueryString).(string) - if ok { - sqlMode := ctx.GetSessionVars().SQLMode - var buf bytes.Buffer - AppendPartitionDefs(partInfo, &buf, sqlMode) - - syntacticSugar := spec.Partition.PartitionMethod.OriginalText() - syntacticStart := spec.Partition.PartitionMethod.OriginTextPosition() - newQuery := query[:syntacticStart] + "ADD PARTITION (" + buf.String() + ")" + query[syntacticStart+len(syntacticSugar):] - defer ctx.SetValue(sessionctx.QueryString, query) - ctx.SetValue(sessionctx.QueryString, newQuery) - } - } - err = e.DoDDLJob(ctx, job) - if dbterror.ErrSameNamePartition.Equal(err) && spec.IfNotExists { - ctx.GetSessionVars().StmtCtx.AppendNote(err) - return nil - } - return errors.Trace(err) -} - -// getReorganizedDefinitions return the definitions as they would look like after the REORGANIZE PARTITION is done. -func getReorganizedDefinitions(pi *model.PartitionInfo, firstPartIdx, lastPartIdx int, idMap map[int]struct{}) []model.PartitionDefinition { - tmpDefs := make([]model.PartitionDefinition, 0, len(pi.Definitions)+len(pi.AddingDefinitions)-len(idMap)) - if pi.Type == model.PartitionTypeList { - replaced := false - for i := range pi.Definitions { - if _, ok := idMap[i]; ok { - if !replaced { - tmpDefs = append(tmpDefs, pi.AddingDefinitions...) - replaced = true - } - continue - } - tmpDefs = append(tmpDefs, pi.Definitions[i]) - } - if !replaced { - // For safety, for future non-partitioned table -> partitioned - tmpDefs = append(tmpDefs, pi.AddingDefinitions...) - } - return tmpDefs - } - // Range - tmpDefs = append(tmpDefs, pi.Definitions[:firstPartIdx]...) 
- tmpDefs = append(tmpDefs, pi.AddingDefinitions...) - if len(pi.Definitions) > (lastPartIdx + 1) { - tmpDefs = append(tmpDefs, pi.Definitions[lastPartIdx+1:]...) - } - return tmpDefs -} - -func getReplacedPartitionIDs(names []string, pi *model.PartitionInfo) (firstPartIdx int, lastPartIdx int, idMap map[int]struct{}, err error) { - idMap = make(map[int]struct{}) - firstPartIdx, lastPartIdx = -1, -1 - for _, name := range names { - nameL := strings.ToLower(name) - partIdx := pi.FindPartitionDefinitionByName(nameL) - if partIdx == -1 { - return 0, 0, nil, errors.Trace(dbterror.ErrWrongPartitionName) - } - if _, ok := idMap[partIdx]; ok { - return 0, 0, nil, errors.Trace(dbterror.ErrSameNamePartition) - } - idMap[partIdx] = struct{}{} - if firstPartIdx == -1 { - firstPartIdx = partIdx - } else { - firstPartIdx = mathutil.Min[int](firstPartIdx, partIdx) - } - if lastPartIdx == -1 { - lastPartIdx = partIdx - } else { - lastPartIdx = mathutil.Max[int](lastPartIdx, partIdx) - } - } - switch pi.Type { - case model.PartitionTypeRange: - if len(idMap) != (lastPartIdx - firstPartIdx + 1) { - return 0, 0, nil, errors.Trace(dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs( - "REORGANIZE PARTITION of RANGE; not adjacent partitions")) - } - case model.PartitionTypeHash, model.PartitionTypeKey: - if len(idMap) != len(pi.Definitions) { - return 0, 0, nil, errors.Trace(dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs( - "REORGANIZE PARTITION of HASH/RANGE; must reorganize all partitions")) - } - } - - return firstPartIdx, lastPartIdx, idMap, nil -} - -func getPartitionInfoTypeNone() *model.PartitionInfo { - return &model.PartitionInfo{ - Type: model.PartitionTypeNone, - Enable: true, - Definitions: []model.PartitionDefinition{{ - Name: model.NewCIStr("pFullTable"), - Comment: "Intermediate partition during ALTER TABLE ... PARTITION BY ...", - }}, - Num: 1, - } -} - -// AlterTablePartitioning reorganize one set of partitions to a new set of partitions. 
-func (e *executor) AlterTablePartitioning(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { - schema, t, err := e.getSchemaAndTableByIdent(ident) - if err != nil { - return errors.Trace(infoschema.ErrTableNotExists.FastGenByArgs(ident.Schema, ident.Name)) - } - - meta := t.Meta().Clone() - piOld := meta.GetPartitionInfo() - var partNames []string - if piOld != nil { - partNames = make([]string, 0, len(piOld.Definitions)) - for i := range piOld.Definitions { - partNames = append(partNames, piOld.Definitions[i].Name.L) - } - } else { - piOld = getPartitionInfoTypeNone() - meta.Partition = piOld - partNames = append(partNames, piOld.Definitions[0].Name.L) - } - newMeta := meta.Clone() - err = buildTablePartitionInfo(ctx, spec.Partition, newMeta) - if err != nil { - return err - } - newPartInfo := newMeta.Partition - - for _, index := range newMeta.Indices { - if index.Unique { - ck, err := checkPartitionKeysConstraint(newMeta.GetPartitionInfo(), index.Columns, newMeta) - if err != nil { - return err - } - if !ck { - indexTp := "" - if !ctx.GetSessionVars().EnableGlobalIndex { - if index.Primary { - indexTp = "PRIMARY KEY" - } else { - indexTp = "UNIQUE INDEX" - } - } else if t.Meta().IsCommonHandle { - indexTp = "CLUSTERED INDEX" - } - if indexTp != "" { - return dbterror.ErrUniqueKeyNeedAllFieldsInPf.GenWithStackByArgs(indexTp) - } - // Also mark the unique index as global index - index.Global = true - } - } - } - if newMeta.PKIsHandle { - // This case is covers when the Handle is the PK (only ints), since it would not - // have an entry in the tblInfo.Indices - indexCols := []*model.IndexColumn{{ - Name: newMeta.GetPkName(), - Length: types.UnspecifiedLength, - }} - ck, err := checkPartitionKeysConstraint(newMeta.GetPartitionInfo(), indexCols, newMeta) - if err != nil { - return err - } - if !ck { - if !ctx.GetSessionVars().EnableGlobalIndex { - return dbterror.ErrUniqueKeyNeedAllFieldsInPf.GenWithStackByArgs("PRIMARY KEY") - } - return dbterror.ErrUniqueKeyNeedAllFieldsInPf.GenWithStackByArgs("CLUSTERED INDEX") - } - } - - if err = handlePartitionPlacement(ctx, newPartInfo); err != nil { - return errors.Trace(err) - } - - if err = e.assignPartitionIDs(newPartInfo.Definitions); err != nil { - return errors.Trace(err) - } - // A new table ID would be needed for - // the global index, which cannot be the same as the current table id, - // since this table id will be removed in the final state when removing - // all the data with this table id. - var newID []int64 - newID, err = e.genGlobalIDs(1) - if err != nil { - return errors.Trace(err) - } - newPartInfo.NewTableID = newID[0] - newPartInfo.DDLType = piOld.Type - - job := &model.Job{ - SchemaID: schema.ID, - TableID: meta.ID, - SchemaName: schema.Name.L, - TableName: t.Meta().Name.L, - Type: model.ActionAlterTablePartitioning, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{partNames, newPartInfo}, - ReorgMeta: NewDDLReorgMeta(ctx), - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - // No preSplitAndScatter here, it will be done by the worker in onReorganizePartition instead. - err = e.DoDDLJob(ctx, job) - if err == nil { - ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError("The statistics of new partitions will be outdated after reorganizing partitions. Please use 'ANALYZE TABLE' statement if you want to update it now")) - } - return errors.Trace(err) -} - -// ReorganizePartitions reorganize one set of partitions to a new set of partitions. 
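A brief standalone sketch of the partition-selection rule enforced by getReplacedPartitionIDs above, which ReorganizePartitions below relies on; checkReplacedSet and its types are simplified stand-ins, not code from this patch. For RANGE, the reorganized partitions must form one adjacent run; for HASH/KEY, every partition must be included:

package main

import "fmt"

// checkReplacedSet models the rule enforced by getReplacedPartitionIDs: RANGE
// reorganization must pick a contiguous run of partition indexes, while
// HASH/KEY reorganization must cover all partitions at once.
func checkReplacedSet(isRange bool, totalParts int, picked []int) error {
	if len(picked) == 0 {
		return fmt.Errorf("no partitions given")
	}
	seen := make(map[int]struct{}, len(picked))
	first, last := picked[0], picked[0]
	for _, idx := range picked {
		if _, dup := seen[idx]; dup {
			return fmt.Errorf("duplicate partition %d", idx)
		}
		seen[idx] = struct{}{}
		if idx < first {
			first = idx
		}
		if idx > last {
			last = idx
		}
	}
	if isRange {
		if len(seen) != last-first+1 {
			return fmt.Errorf("RANGE: not adjacent partitions")
		}
		return nil
	}
	if len(seen) != totalParts {
		return fmt.Errorf("HASH/KEY: must reorganize all partitions")
	}
	return nil
}

func main() {
	fmt.Println(checkReplacedSet(true, 5, []int{1, 2, 3}))  // <nil>
	fmt.Println(checkReplacedSet(true, 5, []int{1, 3}))     // RANGE: not adjacent partitions
	fmt.Println(checkReplacedSet(false, 4, []int{0, 1, 2})) // HASH/KEY: must reorganize all partitions
}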
-func (e *executor) ReorganizePartitions(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { - schema, t, err := e.getSchemaAndTableByIdent(ident) - if err != nil { - return errors.Trace(infoschema.ErrTableNotExists.FastGenByArgs(ident.Schema, ident.Name)) - } - - meta := t.Meta() - pi := meta.GetPartitionInfo() - if pi == nil { - return dbterror.ErrPartitionMgmtOnNonpartitioned - } - switch pi.Type { - case model.PartitionTypeRange, model.PartitionTypeList: - case model.PartitionTypeHash, model.PartitionTypeKey: - if spec.Tp != ast.AlterTableCoalescePartitions && - spec.Tp != ast.AlterTableAddPartitions { - return errors.Trace(dbterror.ErrUnsupportedReorganizePartition) - } - default: - return errors.Trace(dbterror.ErrUnsupportedReorganizePartition) - } - partNames := make([]string, 0, len(spec.PartitionNames)) - for _, name := range spec.PartitionNames { - partNames = append(partNames, name.L) - } - firstPartIdx, lastPartIdx, idMap, err := getReplacedPartitionIDs(partNames, pi) - if err != nil { - return errors.Trace(err) - } - partInfo, err := BuildAddedPartitionInfo(ctx.GetExprCtx(), meta, spec) - if err != nil { - return errors.Trace(err) - } - if err = e.assignPartitionIDs(partInfo.Definitions); err != nil { - return errors.Trace(err) - } - if err = checkReorgPartitionDefs(ctx, model.ActionReorganizePartition, meta, partInfo, firstPartIdx, lastPartIdx, idMap); err != nil { - return errors.Trace(err) - } - if err = handlePartitionPlacement(ctx, partInfo); err != nil { - return errors.Trace(err) - } - - job := &model.Job{ - SchemaID: schema.ID, - TableID: meta.ID, - SchemaName: schema.Name.L, - TableName: t.Meta().Name.L, - Type: model.ActionReorganizePartition, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{partNames, partInfo}, - ReorgMeta: NewDDLReorgMeta(ctx), - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - // No preSplitAndScatter here, it will be done by the worker in onReorganizePartition instead. - err = e.DoDDLJob(ctx, job) - failpoint.InjectCall("afterReorganizePartition") - if err == nil { - ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError("The statistics of related partitions will be outdated after reorganizing partitions. Please use 'ANALYZE TABLE' statement if you want to update it now")) - } - return errors.Trace(err) -} - -// RemovePartitioning removes partitioning from a table. -func (e *executor) RemovePartitioning(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { - schema, t, err := e.getSchemaAndTableByIdent(ident) - if err != nil { - return errors.Trace(infoschema.ErrTableNotExists.FastGenByArgs(ident.Schema, ident.Name)) - } - - meta := t.Meta().Clone() - pi := meta.GetPartitionInfo() - if pi == nil { - return dbterror.ErrPartitionMgmtOnNonpartitioned - } - // TODO: Optimize for remove partitioning with a single partition - // TODO: Add the support for this in onReorganizePartition - // skip if only one partition - // If there are only one partition, then we can do: - // change the table id to the partition id - // and keep the statistics for the partition id (which should be similar to the global statistics) - // and it let the GC clean up the old table metadata including possible global index. 
- - newSpec := &ast.AlterTableSpec{} - newSpec.Tp = spec.Tp - defs := make([]*ast.PartitionDefinition, 1) - defs[0] = &ast.PartitionDefinition{} - defs[0].Name = model.NewCIStr("CollapsedPartitions") - newSpec.PartDefinitions = defs - partNames := make([]string, len(pi.Definitions)) - for i := range pi.Definitions { - partNames[i] = pi.Definitions[i].Name.L - } - meta.Partition.Type = model.PartitionTypeNone - partInfo, err := BuildAddedPartitionInfo(ctx.GetExprCtx(), meta, newSpec) - if err != nil { - return errors.Trace(err) - } - if err = e.assignPartitionIDs(partInfo.Definitions); err != nil { - return errors.Trace(err) - } - // TODO: check where the default placement comes from (i.e. table level) - if err = handlePartitionPlacement(ctx, partInfo); err != nil { - return errors.Trace(err) - } - partInfo.NewTableID = partInfo.Definitions[0].ID - - job := &model.Job{ - SchemaID: schema.ID, - TableID: meta.ID, - SchemaName: schema.Name.L, - TableName: meta.Name.L, - Type: model.ActionRemovePartitioning, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{partNames, partInfo}, - ReorgMeta: NewDDLReorgMeta(ctx), - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - // No preSplitAndScatter here, it will be done by the worker in onReorganizePartition instead. - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -func checkReorgPartitionDefs(ctx sessionctx.Context, action model.ActionType, tblInfo *model.TableInfo, partInfo *model.PartitionInfo, firstPartIdx, lastPartIdx int, idMap map[int]struct{}) error { - // partInfo contains only the new added partition, we have to combine it with the - // old partitions to check all partitions is strictly increasing. - pi := tblInfo.Partition - clonedMeta := tblInfo.Clone() - switch action { - case model.ActionRemovePartitioning, model.ActionAlterTablePartitioning: - clonedMeta.Partition = partInfo - clonedMeta.ID = partInfo.NewTableID - case model.ActionReorganizePartition: - clonedMeta.Partition.AddingDefinitions = partInfo.Definitions - clonedMeta.Partition.Definitions = getReorganizedDefinitions(clonedMeta.Partition, firstPartIdx, lastPartIdx, idMap) - default: - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("partition type") - } - if err := checkPartitionDefinitionConstraints(ctx, clonedMeta); err != nil { - return errors.Trace(err) - } - if action == model.ActionReorganizePartition { - if pi.Type == model.PartitionTypeRange { - if lastPartIdx == len(pi.Definitions)-1 { - // Last partition dropped, OK to change the end range - // Also includes MAXVALUE - return nil - } - // Check if the replaced end range is the same as before - lastAddingPartition := partInfo.Definitions[len(partInfo.Definitions)-1] - lastOldPartition := pi.Definitions[lastPartIdx] - if len(pi.Columns) > 0 { - newGtOld, err := checkTwoRangeColumns(ctx, &lastAddingPartition, &lastOldPartition, pi, tblInfo) - if err != nil { - return errors.Trace(err) - } - if newGtOld { - return errors.Trace(dbterror.ErrRangeNotIncreasing) - } - oldGtNew, err := checkTwoRangeColumns(ctx, &lastOldPartition, &lastAddingPartition, pi, tblInfo) - if err != nil { - return errors.Trace(err) - } - if oldGtNew { - return errors.Trace(dbterror.ErrRangeNotIncreasing) - } - return nil - } - - isUnsigned := isPartExprUnsigned(ctx.GetExprCtx().GetEvalCtx(), tblInfo) - currentRangeValue, _, err := getRangeValue(ctx.GetExprCtx(), pi.Definitions[lastPartIdx].LessThan[0], isUnsigned) - if err != nil { - return errors.Trace(err) - } - 
newRangeValue, _, err := getRangeValue(ctx.GetExprCtx(), partInfo.Definitions[len(partInfo.Definitions)-1].LessThan[0], isUnsigned) - if err != nil { - return errors.Trace(err) - } - - if currentRangeValue != newRangeValue { - return errors.Trace(dbterror.ErrRangeNotIncreasing) - } - } - } else { - if len(pi.Definitions) != (lastPartIdx - firstPartIdx + 1) { - // if not ActionReorganizePartition, require all partitions to be changed. - return errors.Trace(dbterror.ErrAlterOperationNotSupported) - } - } - return nil -} - -// CoalescePartitions coalesce partitions can be used with a table that is partitioned by hash or key to reduce the number of partitions by number. -func (e *executor) CoalescePartitions(sctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { - is := e.infoCache.GetLatest() - schema, ok := is.SchemaByName(ident.Schema) - if !ok { - return errors.Trace(infoschema.ErrDatabaseNotExists.GenWithStackByArgs(schema)) - } - t, err := is.TableByName(e.ctx, ident.Schema, ident.Name) - if err != nil { - return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name)) - } - - pi := t.Meta().GetPartitionInfo() - if pi == nil { - return errors.Trace(dbterror.ErrPartitionMgmtOnNonpartitioned) - } - - switch pi.Type { - case model.PartitionTypeHash, model.PartitionTypeKey: - return e.hashPartitionManagement(sctx, ident, spec, pi) - - // Coalesce partition can only be used on hash/key partitions. - default: - return errors.Trace(dbterror.ErrCoalesceOnlyOnHashPartition) - } -} - -func (e *executor) hashPartitionManagement(sctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec, pi *model.PartitionInfo) error { - newSpec := *spec - newSpec.PartitionNames = make([]model.CIStr, len(pi.Definitions)) - for i := 0; i < len(pi.Definitions); i++ { - // reorganize ALL partitions into the new number of partitions - newSpec.PartitionNames[i] = pi.Definitions[i].Name - } - for i := 0; i < len(newSpec.PartDefinitions); i++ { - switch newSpec.PartDefinitions[i].Clause.(type) { - case *ast.PartitionDefinitionClauseNone: - // OK, expected - case *ast.PartitionDefinitionClauseIn: - return errors.Trace(ast.ErrPartitionWrongValues.FastGenByArgs("LIST", "IN")) - case *ast.PartitionDefinitionClauseLessThan: - return errors.Trace(ast.ErrPartitionWrongValues.FastGenByArgs("RANGE", "LESS THAN")) - case *ast.PartitionDefinitionClauseHistory: - return errors.Trace(ast.ErrPartitionWrongValues.FastGenByArgs("SYSTEM_TIME", "HISTORY")) - - default: - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs( - "partitioning clause") - } - } - if newSpec.Num < uint64(len(newSpec.PartDefinitions)) { - newSpec.Num = uint64(len(newSpec.PartDefinitions)) - } - if spec.Tp == ast.AlterTableCoalescePartitions { - if newSpec.Num < 1 { - return ast.ErrCoalescePartitionNoPartition - } - if newSpec.Num >= uint64(len(pi.Definitions)) { - return dbterror.ErrDropLastPartition - } - if isNonDefaultPartitionOptionsUsed(pi.Definitions) { - // The partition definitions will be copied in buildHashPartitionDefinitions() - // if there is a non-empty list of definitions - newSpec.PartDefinitions = []*ast.PartitionDefinition{{}} - } - } - - return e.ReorganizePartitions(sctx, ident, &newSpec) -} - -func (e *executor) TruncateTablePartition(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { - is := e.infoCache.GetLatest() - schema, ok := is.SchemaByName(ident.Schema) - if !ok { - return errors.Trace(infoschema.ErrDatabaseNotExists.GenWithStackByArgs(schema)) 
- } - t, err := is.TableByName(e.ctx, ident.Schema, ident.Name) - if err != nil { - return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name)) - } - meta := t.Meta() - if meta.GetPartitionInfo() == nil { - return errors.Trace(dbterror.ErrPartitionMgmtOnNonpartitioned) - } - - getTruncatedParts := func(pi *model.PartitionInfo) (*model.PartitionInfo, error) { - if spec.OnAllPartitions { - return pi.Clone(), nil - } - var defs []model.PartitionDefinition - // MySQL allows duplicate partition names in truncate partition - // so we filter them out through a hash - posMap := make(map[int]bool) - for _, name := range spec.PartitionNames { - pos := pi.FindPartitionDefinitionByName(name.L) - if pos < 0 { - return nil, errors.Trace(table.ErrUnknownPartition.GenWithStackByArgs(name.L, ident.Name.O)) - } - if _, ok := posMap[pos]; !ok { - defs = append(defs, pi.Definitions[pos]) - posMap[pos] = true - } - } - pi = pi.Clone() - pi.Definitions = defs - return pi, nil - } - pi, err := getTruncatedParts(meta.GetPartitionInfo()) - if err != nil { - return err - } - pids := make([]int64, 0, len(pi.Definitions)) - for i := range pi.Definitions { - pids = append(pids, pi.Definitions[i].ID) - } - - genIDs, err := e.genGlobalIDs(len(pids)) - if err != nil { - return errors.Trace(err) - } - - job := &model.Job{ - SchemaID: schema.ID, - TableID: meta.ID, - SchemaName: schema.Name.L, - SchemaState: model.StatePublic, - TableName: t.Meta().Name.L, - Type: model.ActionTruncateTablePartition, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{pids, genIDs}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - if err != nil { - return errors.Trace(err) - } - return nil -} - -func (e *executor) DropTablePartition(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { - is := e.infoCache.GetLatest() - schema, ok := is.SchemaByName(ident.Schema) - if !ok { - return errors.Trace(infoschema.ErrDatabaseNotExists.GenWithStackByArgs(schema)) - } - t, err := is.TableByName(e.ctx, ident.Schema, ident.Name) - if err != nil { - return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name)) - } - meta := t.Meta() - if meta.GetPartitionInfo() == nil { - return errors.Trace(dbterror.ErrPartitionMgmtOnNonpartitioned) - } - - if spec.Tp == ast.AlterTableDropFirstPartition { - intervalOptions := getPartitionIntervalFromTable(ctx.GetExprCtx(), meta) - if intervalOptions == nil { - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs( - "FIRST PARTITION, does not seem like an INTERVAL partitioned table") - } - if len(spec.Partition.Definitions) != 0 { - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs( - "FIRST PARTITION, table info already contains partition definitions") - } - spec.Partition.Interval = intervalOptions - err = GeneratePartDefsFromInterval(ctx.GetExprCtx(), spec.Tp, meta, spec.Partition) - if err != nil { - return err - } - pNullOffset := 0 - if intervalOptions.NullPart { - pNullOffset = 1 - } - if len(spec.Partition.Definitions) == 0 || - len(spec.Partition.Definitions) >= len(meta.Partition.Definitions)-pNullOffset { - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs( - "FIRST PARTITION, number of partitions does not match") - } - if len(spec.PartitionNames) != 0 || len(spec.Partition.Definitions) <= 1 { - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs( - "FIRST PARTITION, given value does not generate a list of 
partition names to be dropped") - } - for i := range spec.Partition.Definitions { - spec.PartitionNames = append(spec.PartitionNames, meta.Partition.Definitions[i+pNullOffset].Name) - } - // Use the last generated partition as First, i.e. do not drop the last name in the slice - spec.PartitionNames = spec.PartitionNames[:len(spec.PartitionNames)-1] - - query, ok := ctx.Value(sessionctx.QueryString).(string) - if ok { - partNames := make([]string, 0, len(spec.PartitionNames)) - sqlMode := ctx.GetSessionVars().SQLMode - for i := range spec.PartitionNames { - partNames = append(partNames, stringutil.Escape(spec.PartitionNames[i].O, sqlMode)) - } - syntacticSugar := spec.Partition.PartitionMethod.OriginalText() - syntacticStart := spec.Partition.PartitionMethod.OriginTextPosition() - newQuery := query[:syntacticStart] + "DROP PARTITION " + strings.Join(partNames, ", ") + query[syntacticStart+len(syntacticSugar):] - defer ctx.SetValue(sessionctx.QueryString, query) - ctx.SetValue(sessionctx.QueryString, newQuery) - } - } - partNames := make([]string, len(spec.PartitionNames)) - for i, partCIName := range spec.PartitionNames { - partNames[i] = partCIName.L - } - err = CheckDropTablePartition(meta, partNames) - if err != nil { - if dbterror.ErrDropPartitionNonExistent.Equal(err) && spec.IfExists { - ctx.GetSessionVars().StmtCtx.AppendNote(err) - return nil - } - return errors.Trace(err) - } - - job := &model.Job{ - SchemaID: schema.ID, - TableID: meta.ID, - SchemaName: schema.Name.L, - SchemaState: model.StatePublic, - TableName: meta.Name.L, - Type: model.ActionDropTablePartition, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{partNames}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - if err != nil { - if dbterror.ErrDropPartitionNonExistent.Equal(err) && spec.IfExists { - ctx.GetSessionVars().StmtCtx.AppendNote(err) - return nil - } - return errors.Trace(err) - } - return errors.Trace(err) -} - -func checkFieldTypeCompatible(ft *types.FieldType, other *types.FieldType) bool { - // int(1) could match the type with int(8) - partialEqual := ft.GetType() == other.GetType() && - ft.GetDecimal() == other.GetDecimal() && - ft.GetCharset() == other.GetCharset() && - ft.GetCollate() == other.GetCollate() && - (ft.GetFlen() == other.GetFlen() || ft.StorageLength() != types.VarStorageLen) && - mysql.HasUnsignedFlag(ft.GetFlag()) == mysql.HasUnsignedFlag(other.GetFlag()) && - mysql.HasAutoIncrementFlag(ft.GetFlag()) == mysql.HasAutoIncrementFlag(other.GetFlag()) && - mysql.HasNotNullFlag(ft.GetFlag()) == mysql.HasNotNullFlag(other.GetFlag()) && - mysql.HasZerofillFlag(ft.GetFlag()) == mysql.HasZerofillFlag(other.GetFlag()) && - mysql.HasBinaryFlag(ft.GetFlag()) == mysql.HasBinaryFlag(other.GetFlag()) && - mysql.HasPriKeyFlag(ft.GetFlag()) == mysql.HasPriKeyFlag(other.GetFlag()) - if !partialEqual || len(ft.GetElems()) != len(other.GetElems()) { - return false - } - for i := range ft.GetElems() { - if ft.GetElems()[i] != other.GetElems()[i] { - return false - } - } - return true -} - -func checkTiFlashReplicaCompatible(source *model.TiFlashReplicaInfo, target *model.TiFlashReplicaInfo) bool { - if source == target { - return true - } - if source == nil || target == nil { - return false - } - if source.Count != target.Count || - source.Available != target.Available || len(source.LocationLabels) != len(target.LocationLabels) { - return false - } - for i, lable := range source.LocationLabels { - if target.LocationLabels[i] 
!= lable { - return false - } - } - return true -} - -func checkTableDefCompatible(source *model.TableInfo, target *model.TableInfo) error { - // check temp table - if target.TempTableType != model.TempTableNone { - return errors.Trace(dbterror.ErrPartitionExchangeTempTable.FastGenByArgs(target.Name)) - } - - // check auto_random - if source.AutoRandomBits != target.AutoRandomBits || - source.AutoRandomRangeBits != target.AutoRandomRangeBits || - source.Charset != target.Charset || - source.Collate != target.Collate || - source.ShardRowIDBits != target.ShardRowIDBits || - source.MaxShardRowIDBits != target.MaxShardRowIDBits || - !checkTiFlashReplicaCompatible(source.TiFlashReplica, target.TiFlashReplica) { - return errors.Trace(dbterror.ErrTablesDifferentMetadata) - } - if len(source.Cols()) != len(target.Cols()) { - return errors.Trace(dbterror.ErrTablesDifferentMetadata) - } - // Col compatible check - for i, sourceCol := range source.Cols() { - targetCol := target.Cols()[i] - if sourceCol.IsVirtualGenerated() != targetCol.IsVirtualGenerated() { - return dbterror.ErrUnsupportedOnGeneratedColumn.GenWithStackByArgs("Exchanging partitions for non-generated columns") - } - // It should strictyle compare expressions for generated columns - if sourceCol.Name.L != targetCol.Name.L || - sourceCol.Hidden != targetCol.Hidden || - !checkFieldTypeCompatible(&sourceCol.FieldType, &targetCol.FieldType) || - sourceCol.GeneratedExprString != targetCol.GeneratedExprString { - return errors.Trace(dbterror.ErrTablesDifferentMetadata) - } - if sourceCol.State != model.StatePublic || - targetCol.State != model.StatePublic { - return errors.Trace(dbterror.ErrTablesDifferentMetadata) - } - if sourceCol.ID != targetCol.ID { - return dbterror.ErrPartitionExchangeDifferentOption.GenWithStackByArgs(fmt.Sprintf("column: %s", sourceCol.Name)) - } - } - if len(source.Indices) != len(target.Indices) { - return errors.Trace(dbterror.ErrTablesDifferentMetadata) - } - for _, sourceIdx := range source.Indices { - if sourceIdx.Global { - return dbterror.ErrPartitionExchangeDifferentOption.GenWithStackByArgs(fmt.Sprintf("global index: %s", sourceIdx.Name)) - } - var compatIdx *model.IndexInfo - for _, targetIdx := range target.Indices { - if strings.EqualFold(sourceIdx.Name.L, targetIdx.Name.L) { - compatIdx = targetIdx - } - } - // No match index - if compatIdx == nil { - return errors.Trace(dbterror.ErrTablesDifferentMetadata) - } - // Index type is not compatible - if sourceIdx.Tp != compatIdx.Tp || - sourceIdx.Unique != compatIdx.Unique || - sourceIdx.Primary != compatIdx.Primary { - return errors.Trace(dbterror.ErrTablesDifferentMetadata) - } - // The index column - if len(sourceIdx.Columns) != len(compatIdx.Columns) { - return errors.Trace(dbterror.ErrTablesDifferentMetadata) - } - for i, sourceIdxCol := range sourceIdx.Columns { - compatIdxCol := compatIdx.Columns[i] - if sourceIdxCol.Length != compatIdxCol.Length || - sourceIdxCol.Name.L != compatIdxCol.Name.L { - return errors.Trace(dbterror.ErrTablesDifferentMetadata) - } - } - if sourceIdx.ID != compatIdx.ID { - return dbterror.ErrPartitionExchangeDifferentOption.GenWithStackByArgs(fmt.Sprintf("index: %s", sourceIdx.Name)) - } - } - - return nil -} - -func checkExchangePartition(pt *model.TableInfo, nt *model.TableInfo) error { - if nt.IsView() || nt.IsSequence() { - return errors.Trace(dbterror.ErrCheckNoSuchTable) - } - if pt.GetPartitionInfo() == nil { - return errors.Trace(dbterror.ErrPartitionMgmtOnNonpartitioned) - } - if nt.GetPartitionInfo() != nil { - 
return errors.Trace(dbterror.ErrPartitionExchangePartTable.GenWithStackByArgs(nt.Name)) - } - - if len(nt.ForeignKeys) > 0 { - return errors.Trace(dbterror.ErrPartitionExchangeForeignKey.GenWithStackByArgs(nt.Name)) - } - - return nil -} - -func (e *executor) ExchangeTablePartition(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { - ptSchema, pt, err := e.getSchemaAndTableByIdent(ident) - if err != nil { - return errors.Trace(err) - } - - ptMeta := pt.Meta() - - ntIdent := ast.Ident{Schema: spec.NewTable.Schema, Name: spec.NewTable.Name} - - // We should check local temporary here using session's info schema because the local temporary tables are only stored in session. - ntLocalTempTable, err := sessiontxn.GetTxnManager(ctx).GetTxnInfoSchema().TableByName(context.Background(), ntIdent.Schema, ntIdent.Name) - if err == nil && ntLocalTempTable.Meta().TempTableType == model.TempTableLocal { - return errors.Trace(dbterror.ErrPartitionExchangeTempTable.FastGenByArgs(ntLocalTempTable.Meta().Name)) - } - - ntSchema, nt, err := e.getSchemaAndTableByIdent(ntIdent) - if err != nil { - return errors.Trace(err) - } - - ntMeta := nt.Meta() - - err = checkExchangePartition(ptMeta, ntMeta) - if err != nil { - return errors.Trace(err) - } - - partName := spec.PartitionNames[0].L - - // NOTE: if pt is subPartitioned, it should be checked - - defID, err := tables.FindPartitionByName(ptMeta, partName) - if err != nil { - return errors.Trace(err) - } - - err = checkTableDefCompatible(ptMeta, ntMeta) - if err != nil { - return errors.Trace(err) - } - - job := &model.Job{ - SchemaID: ntSchema.ID, - TableID: ntMeta.ID, - SchemaName: ntSchema.Name.L, - TableName: ntMeta.Name.L, - Type: model.ActionExchangeTablePartition, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{defID, ptSchema.ID, ptMeta.ID, partName, spec.WithValidation}, - CtxVars: []any{[]int64{ntSchema.ID, ptSchema.ID}, []int64{ntMeta.ID, ptMeta.ID}}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - InvolvingSchemaInfo: []model.InvolvingSchemaInfo{ - {Database: ptSchema.Name.L, Table: ptMeta.Name.L}, - {Database: ntSchema.Name.L, Table: ntMeta.Name.L}, - }, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - if err != nil { - return errors.Trace(err) - } - ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError("after the exchange, please analyze related table of the exchange to update statistics")) - return nil -} - -// DropColumn will drop a column from the table, now we don't support drop the column with index covered. 
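// A minimal, self-contained sketch of the precondition order used by
// checkExchangePartition above. The TableMeta struct and its fields are hypothetical
// stand-ins for model.TableInfo; only the shape and order of the checks mirror the diff.
package sketch

import "errors"

type TableMeta struct {
	IsView, IsSequence, Partitioned bool
	ForeignKeys                     int
}

// canExchangePartition rejects views/sequences first, then requires the
// partitioned-table / plain-table pairing, and finally refuses a target that
// carries foreign keys.
func canExchangePartition(pt, nt TableMeta) error {
	if nt.IsView || nt.IsSequence {
		return errors.New("target is not a base table")
	}
	if !pt.Partitioned {
		return errors.New("partition management on a non-partitioned table")
	}
	if nt.Partitioned {
		return errors.New("exchange target must be non-partitioned")
	}
	if nt.ForeignKeys > 0 {
		return errors.New("exchange target must not have foreign keys")
	}
	return nil
}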
-func (e *executor) DropColumn(ctx sessionctx.Context, ti ast.Ident, spec *ast.AlterTableSpec) error { - schema, t, err := e.getSchemaAndTableByIdent(ti) - if err != nil { - return errors.Trace(err) - } - failpoint.InjectCall("afterGetSchemaAndTableByIdent", ctx) - - isDropable, err := checkIsDroppableColumn(ctx, e.infoCache.GetLatest(), schema, t, spec) - if err != nil { - return err - } - if !isDropable { - return nil - } - colName := spec.OldColumnName.Name - err = checkVisibleColumnCnt(t, 0, 1) - if err != nil { - return err - } - - job := &model.Job{ - SchemaID: schema.ID, - TableID: t.Meta().ID, - SchemaName: schema.Name.L, - SchemaState: model.StatePublic, - TableName: t.Meta().Name.L, - Type: model.ActionDropColumn, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{colName, spec.IfExists}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -func checkIsDroppableColumn(ctx sessionctx.Context, is infoschema.InfoSchema, schema *model.DBInfo, t table.Table, spec *ast.AlterTableSpec) (isDrapable bool, err error) { - tblInfo := t.Meta() - // Check whether dropped column has existed. - colName := spec.OldColumnName.Name - col := table.FindCol(t.VisibleCols(), colName.L) - if col == nil { - err = dbterror.ErrCantDropFieldOrKey.GenWithStackByArgs(colName) - if spec.IfExists { - ctx.GetSessionVars().StmtCtx.AppendNote(err) - return false, nil - } - return false, err - } - - if err = isDroppableColumn(tblInfo, colName); err != nil { - return false, errors.Trace(err) - } - if err = checkDropColumnWithPartitionConstraint(t, colName); err != nil { - return false, errors.Trace(err) - } - // Check the column with foreign key. - err = checkDropColumnWithForeignKeyConstraint(is, schema.Name.L, tblInfo, colName.L) - if err != nil { - return false, errors.Trace(err) - } - // Check the column with TTL config - err = checkDropColumnWithTTLConfig(tblInfo, colName.L) - if err != nil { - return false, errors.Trace(err) - } - // We don't support dropping column with PK handle covered now. - if col.IsPKHandleColumn(tblInfo) { - return false, dbterror.ErrUnsupportedPKHandle - } - if mysql.HasAutoIncrementFlag(col.GetFlag()) && !ctx.GetSessionVars().AllowRemoveAutoInc { - return false, dbterror.ErrCantDropColWithAutoInc - } - return true, nil -} - -// checkDropColumnWithPartitionConstraint is used to check the partition constraint of the drop column. -func checkDropColumnWithPartitionConstraint(t table.Table, colName model.CIStr) error { - if t.Meta().Partition == nil { - return nil - } - pt, ok := t.(table.PartitionedTable) - if !ok { - // Should never happen! - return errors.Trace(dbterror.ErrDependentByPartitionFunctional.GenWithStackByArgs(colName.L)) - } - for _, name := range pt.GetPartitionColumnNames() { - if strings.EqualFold(name.L, colName.L) { - return errors.Trace(dbterror.ErrDependentByPartitionFunctional.GenWithStackByArgs(colName.L)) - } - } - return nil -} - -func checkVisibleColumnCnt(t table.Table, addCnt, dropCnt int) error { - tblInfo := t.Meta() - visibleColumCnt := 0 - for _, column := range tblInfo.Columns { - if !column.Hidden { - visibleColumCnt++ - } - } - if visibleColumCnt+addCnt > dropCnt { - return nil - } - if len(tblInfo.Columns)-visibleColumCnt > 0 { - // There are only invisible columns. 
- return dbterror.ErrTableMustHaveColumns - } - return dbterror.ErrCantRemoveAllFields -} - -// checkModifyCharsetAndCollation returns error when the charset or collation is not modifiable. -// needRewriteCollationData is used when trying to modify the collation of a column, it is true when the column is with -// index because index of a string column is collation-aware. -func checkModifyCharsetAndCollation(toCharset, toCollate, origCharset, origCollate string, needRewriteCollationData bool) error { - if !charset.ValidCharsetAndCollation(toCharset, toCollate) { - return dbterror.ErrUnknownCharacterSet.GenWithStack("Unknown character set: '%s', collation: '%s'", toCharset, toCollate) - } - - if needRewriteCollationData && collate.NewCollationEnabled() && !collate.CompatibleCollate(origCollate, toCollate) { - return dbterror.ErrUnsupportedModifyCollation.GenWithStackByArgs(origCollate, toCollate) - } - - if (origCharset == charset.CharsetUTF8 && toCharset == charset.CharsetUTF8MB4) || - (origCharset == charset.CharsetUTF8 && toCharset == charset.CharsetUTF8) || - (origCharset == charset.CharsetUTF8MB4 && toCharset == charset.CharsetUTF8MB4) || - (origCharset == charset.CharsetLatin1 && toCharset == charset.CharsetUTF8MB4) { - // TiDB only allow utf8/latin1 to be changed to utf8mb4, or changing the collation when the charset is utf8/utf8mb4/latin1. - return nil - } - - if toCharset != origCharset { - msg := fmt.Sprintf("charset from %s to %s", origCharset, toCharset) - return dbterror.ErrUnsupportedModifyCharset.GenWithStackByArgs(msg) - } - if toCollate != origCollate { - msg := fmt.Sprintf("change collate from %s to %s", origCollate, toCollate) - return dbterror.ErrUnsupportedModifyCharset.GenWithStackByArgs(msg) - } - return nil -} - -func (e *executor) getModifiableColumnJob(ctx context.Context, sctx sessionctx.Context, ident ast.Ident, originalColName model.CIStr, - spec *ast.AlterTableSpec) (*model.Job, error) { - is := e.infoCache.GetLatest() - schema, ok := is.SchemaByName(ident.Schema) - if !ok { - return nil, errors.Trace(infoschema.ErrDatabaseNotExists) - } - t, err := is.TableByName(ctx, ident.Schema, ident.Name) - if err != nil { - return nil, errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name)) - } - - return GetModifiableColumnJob(ctx, sctx, is, ident, originalColName, schema, t, spec) -} - -// ChangeColumn renames an existing column and modifies the column's definition, -// currently we only support limited kind of changes -// that do not need to change or check data on the table. 
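// A minimal standalone sketch of the charset-conversion rule enforced by
// checkModifyCharsetAndCollation above; charset names are plain string literals here
// instead of the charset package constants used in the diff.
package sketch

// charsetChangeAllowed reports whether the target charset is acceptable for a table
// or column currently using orig: only utf8/latin1 -> utf8mb4 (or keeping
// utf8/utf8mb4 unchanged) passes; any other combination must keep the same charset,
// and collation changes are validated separately.
func charsetChangeAllowed(orig, target string) bool {
	switch {
	case orig == "utf8" && target == "utf8mb4",
		orig == "utf8" && target == "utf8",
		orig == "utf8mb4" && target == "utf8mb4",
		orig == "latin1" && target == "utf8mb4":
		return true
	}
	return orig == target
}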
-func (e *executor) ChangeColumn(ctx context.Context, sctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { - specNewColumn := spec.NewColumns[0] - if len(specNewColumn.Name.Schema.O) != 0 && ident.Schema.L != specNewColumn.Name.Schema.L { - return dbterror.ErrWrongDBName.GenWithStackByArgs(specNewColumn.Name.Schema.O) - } - if len(spec.OldColumnName.Schema.O) != 0 && ident.Schema.L != spec.OldColumnName.Schema.L { - return dbterror.ErrWrongDBName.GenWithStackByArgs(spec.OldColumnName.Schema.O) - } - if len(specNewColumn.Name.Table.O) != 0 && ident.Name.L != specNewColumn.Name.Table.L { - return dbterror.ErrWrongTableName.GenWithStackByArgs(specNewColumn.Name.Table.O) - } - if len(spec.OldColumnName.Table.O) != 0 && ident.Name.L != spec.OldColumnName.Table.L { - return dbterror.ErrWrongTableName.GenWithStackByArgs(spec.OldColumnName.Table.O) - } - - job, err := e.getModifiableColumnJob(ctx, sctx, ident, spec.OldColumnName.Name, spec) - if err != nil { - if infoschema.ErrColumnNotExists.Equal(err) && spec.IfExists { - sctx.GetSessionVars().StmtCtx.AppendNote(infoschema.ErrColumnNotExists.FastGenByArgs(spec.OldColumnName.Name, ident.Name)) - return nil - } - return errors.Trace(err) - } - - err = e.DoDDLJob(sctx, job) - // column not exists, but if_exists flags is true, so we ignore this error. - if infoschema.ErrColumnNotExists.Equal(err) && spec.IfExists { - sctx.GetSessionVars().StmtCtx.AppendNote(err) - return nil - } - return errors.Trace(err) -} - -// RenameColumn renames an existing column. -func (e *executor) RenameColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { - oldColName := spec.OldColumnName.Name - newColName := spec.NewColumnName.Name - - schema, tbl, err := e.getSchemaAndTableByIdent(ident) - if err != nil { - return errors.Trace(err) - } - - oldCol := table.FindCol(tbl.VisibleCols(), oldColName.L) - if oldCol == nil { - return infoschema.ErrColumnNotExists.GenWithStackByArgs(oldColName, ident.Name) - } - // check if column can rename with check constraint - err = IsColumnRenameableWithCheckConstraint(oldCol.Name, tbl.Meta()) - if err != nil { - return err - } - - if oldColName.L == newColName.L { - return nil - } - if newColName.L == model.ExtraHandleName.L { - return dbterror.ErrWrongColumnName.GenWithStackByArgs(newColName.L) - } - - allCols := tbl.Cols() - colWithNewNameAlreadyExist := table.FindCol(allCols, newColName.L) != nil - if colWithNewNameAlreadyExist { - return infoschema.ErrColumnExists.GenWithStackByArgs(newColName) - } - - // Check generated expression. - err = checkModifyColumnWithGeneratedColumnsConstraint(allCols, oldColName) - if err != nil { - return errors.Trace(err) - } - err = checkDropColumnWithPartitionConstraint(tbl, oldColName) - if err != nil { - return errors.Trace(err) - } - - newCol := oldCol.Clone() - newCol.Name = newColName - job := &model.Job{ - SchemaID: schema.ID, - TableID: tbl.Meta().ID, - SchemaName: schema.Name.L, - TableName: tbl.Meta().Name.L, - Type: model.ActionModifyColumn, - BinlogInfo: &model.HistoryInfo{}, - ReorgMeta: NewDDLReorgMeta(ctx), - Args: []any{&newCol, oldColName, spec.Position, 0, 0}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -// ModifyColumn does modification on an existing column, currently we only support limited kind of changes -// that do not need to change or check data on the table. 
-func (e *executor) ModifyColumn(ctx context.Context, sctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { - specNewColumn := spec.NewColumns[0] - if len(specNewColumn.Name.Schema.O) != 0 && ident.Schema.L != specNewColumn.Name.Schema.L { - return dbterror.ErrWrongDBName.GenWithStackByArgs(specNewColumn.Name.Schema.O) - } - if len(specNewColumn.Name.Table.O) != 0 && ident.Name.L != specNewColumn.Name.Table.L { - return dbterror.ErrWrongTableName.GenWithStackByArgs(specNewColumn.Name.Table.O) - } - - originalColName := specNewColumn.Name.Name - job, err := e.getModifiableColumnJob(ctx, sctx, ident, originalColName, spec) - if err != nil { - if infoschema.ErrColumnNotExists.Equal(err) && spec.IfExists { - sctx.GetSessionVars().StmtCtx.AppendNote(infoschema.ErrColumnNotExists.FastGenByArgs(originalColName, ident.Name)) - return nil - } - return errors.Trace(err) - } - - err = e.DoDDLJob(sctx, job) - // column not exists, but if_exists flags is true, so we ignore this error. - if infoschema.ErrColumnNotExists.Equal(err) && spec.IfExists { - sctx.GetSessionVars().StmtCtx.AppendNote(err) - return nil - } - return errors.Trace(err) -} - -func (e *executor) AlterColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { - specNewColumn := spec.NewColumns[0] - is := e.infoCache.GetLatest() - schema, ok := is.SchemaByName(ident.Schema) - if !ok { - return infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name) - } - t, err := is.TableByName(e.ctx, ident.Schema, ident.Name) - if err != nil { - return infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name) - } - - colName := specNewColumn.Name.Name - // Check whether alter column has existed. - oldCol := table.FindCol(t.Cols(), colName.L) - if oldCol == nil { - return dbterror.ErrBadField.GenWithStackByArgs(colName, ident.Name) - } - col := table.ToColumn(oldCol.Clone()) - - // Clean the NoDefaultValueFlag value. - col.DelFlag(mysql.NoDefaultValueFlag) - col.DefaultIsExpr = false - if len(specNewColumn.Options) == 0 { - err = col.SetDefaultValue(nil) - if err != nil { - return errors.Trace(err) - } - col.AddFlag(mysql.NoDefaultValueFlag) - } else { - if IsAutoRandomColumnID(t.Meta(), col.ID) { - return dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomIncompatibleWithDefaultValueErrMsg) - } - hasDefaultValue, err := SetDefaultValue(ctx, col, specNewColumn.Options[0]) - if err != nil { - return errors.Trace(err) - } - if err = checkDefaultValue(ctx.GetExprCtx(), col, hasDefaultValue); err != nil { - return errors.Trace(err) - } - } - - job := &model.Job{ - SchemaID: schema.ID, - TableID: t.Meta().ID, - SchemaName: schema.Name.L, - TableName: t.Meta().Name.L, - Type: model.ActionSetDefaultValue, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{col}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -// AlterTableComment updates the table comment information. 
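// A minimal sketch of the SET DEFAULT / DROP DEFAULT branch in AlterColumn above,
// assuming a simplified column representation: the boolean field stands in for the
// mysql.NoDefaultValueFlag bit on the real column.
package sketch

type colDefault struct {
	value        any
	noDefaultSet bool // stand-in for the NoDefaultValueFlag bit
}

// applyAlterDefault clears both the value and the flag first; with no option given
// (DROP DEFAULT) it re-marks the column as having no default, otherwise it records
// the new default value.
func applyAlterDefault(c *colDefault, newDefault any, hasOption bool) {
	c.noDefaultSet = false
	if !hasOption {
		c.value = nil
		c.noDefaultSet = true
		return
	}
	c.value = newDefault
}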
-func (e *executor) AlterTableComment(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { - is := e.infoCache.GetLatest() - schema, ok := is.SchemaByName(ident.Schema) - if !ok { - return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ident.Schema) - } - - tb, err := is.TableByName(e.ctx, ident.Schema, ident.Name) - if err != nil { - return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name)) - } - sessionVars := ctx.GetSessionVars() - if _, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, ident.Name.L, &spec.Comment, dbterror.ErrTooLongTableComment); err != nil { - return errors.Trace(err) - } - - job := &model.Job{ - SchemaID: schema.ID, - TableID: tb.Meta().ID, - SchemaName: schema.Name.L, - TableName: tb.Meta().Name.L, - Type: model.ActionModifyTableComment, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{spec.Comment}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -// AlterTableAutoIDCache updates the table comment information. -func (e *executor) AlterTableAutoIDCache(ctx sessionctx.Context, ident ast.Ident, newCache int64) error { - schema, tb, err := e.getSchemaAndTableByIdent(ident) - if err != nil { - return errors.Trace(err) - } - tbInfo := tb.Meta() - if (newCache == 1 && tbInfo.AutoIdCache != 1) || - (newCache != 1 && tbInfo.AutoIdCache == 1) { - return fmt.Errorf("Can't Alter AUTO_ID_CACHE between 1 and non-1, the underlying implementation is different") - } - - job := &model.Job{ - SchemaID: schema.ID, - TableID: tb.Meta().ID, - SchemaName: schema.Name.L, - TableName: tb.Meta().Name.L, - Type: model.ActionModifyTableAutoIdCache, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{newCache}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -// AlterTableCharsetAndCollate changes the table charset and collate. -func (e *executor) AlterTableCharsetAndCollate(ctx sessionctx.Context, ident ast.Ident, toCharset, toCollate string, needsOverwriteCols bool) error { - // use the last one. - if toCharset == "" && toCollate == "" { - return dbterror.ErrUnknownCharacterSet.GenWithStackByArgs(toCharset) - } - - is := e.infoCache.GetLatest() - schema, ok := is.SchemaByName(ident.Schema) - if !ok { - return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ident.Schema) - } - - tb, err := is.TableByName(e.ctx, ident.Schema, ident.Name) - if err != nil { - return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name)) - } - - if toCharset == "" { - // charset does not change. - toCharset = tb.Meta().Charset - } - - if toCollate == "" { - // Get the default collation of the charset. 
- toCollate, err = GetDefaultCollation(ctx.GetSessionVars(), toCharset) - if err != nil { - return errors.Trace(err) - } - } - doNothing, err := checkAlterTableCharset(tb.Meta(), schema, toCharset, toCollate, needsOverwriteCols) - if err != nil { - return err - } - if doNothing { - return nil - } - - job := &model.Job{ - SchemaID: schema.ID, - TableID: tb.Meta().ID, - SchemaName: schema.Name.L, - TableName: tb.Meta().Name.L, - Type: model.ActionModifyTableCharsetAndCollate, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{toCharset, toCollate, needsOverwriteCols}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -func shouldModifyTiFlashReplica(tbReplicaInfo *model.TiFlashReplicaInfo, replicaInfo *ast.TiFlashReplicaSpec) bool { - if tbReplicaInfo != nil && tbReplicaInfo.Count == replicaInfo.Count && - len(tbReplicaInfo.LocationLabels) == len(replicaInfo.Labels) { - for i, label := range tbReplicaInfo.LocationLabels { - if replicaInfo.Labels[i] != label { - return true - } - } - return false - } - return true -} - -// addHypoTiFlashReplicaIntoCtx adds this hypothetical tiflash replica into this ctx. -func (*executor) setHypoTiFlashReplica(ctx sessionctx.Context, schemaName, tableName model.CIStr, replicaInfo *ast.TiFlashReplicaSpec) error { - sctx := ctx.GetSessionVars() - if sctx.HypoTiFlashReplicas == nil { - sctx.HypoTiFlashReplicas = make(map[string]map[string]struct{}) - } - if sctx.HypoTiFlashReplicas[schemaName.L] == nil { - sctx.HypoTiFlashReplicas[schemaName.L] = make(map[string]struct{}) - } - if replicaInfo.Count > 0 { // add replicas - sctx.HypoTiFlashReplicas[schemaName.L][tableName.L] = struct{}{} - } else { // delete replicas - delete(sctx.HypoTiFlashReplicas[schemaName.L], tableName.L) - } - return nil -} - -// AlterTableSetTiFlashReplica sets the TiFlash replicas info. -func (e *executor) AlterTableSetTiFlashReplica(ctx sessionctx.Context, ident ast.Ident, replicaInfo *ast.TiFlashReplicaSpec) error { - schema, tb, err := e.getSchemaAndTableByIdent(ident) - if err != nil { - return errors.Trace(err) - } - - err = isTableTiFlashSupported(schema.Name, tb.Meta()) - if err != nil { - return errors.Trace(err) - } - - tbReplicaInfo := tb.Meta().TiFlashReplica - if !shouldModifyTiFlashReplica(tbReplicaInfo, replicaInfo) { - return nil - } - - if replicaInfo.Hypo { - return e.setHypoTiFlashReplica(ctx, schema.Name, tb.Meta().Name, replicaInfo) - } - - err = checkTiFlashReplicaCount(ctx, replicaInfo.Count) - if err != nil { - return errors.Trace(err) - } - - job := &model.Job{ - SchemaID: schema.ID, - TableID: tb.Meta().ID, - SchemaName: schema.Name.L, - TableName: tb.Meta().Name.L, - Type: model.ActionSetTiFlashReplica, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{*replicaInfo}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -// AlterTableTTLInfoOrEnable submit ddl job to change table info according to the ttlInfo, or ttlEnable -// at least one of the `ttlInfo`, `ttlEnable` or `ttlCronJobSchedule` should be not nil. -// When `ttlInfo` is nil, and `ttlEnable` is not, it will use the original `.TTLInfo` in the table info and modify the -// `.Enable`. If the `.TTLInfo` in the table info is empty, this function will return an error. 
-// When `ttlInfo` is nil, and `ttlCronJobSchedule` is not, it will use the original `.TTLInfo` in the table info and modify the -// `.JobInterval`. If the `.TTLInfo` in the table info is empty, this function will return an error. -// When `ttlInfo` is not nil, it simply submits the job with the `ttlInfo` and ignore the `ttlEnable`. -func (e *executor) AlterTableTTLInfoOrEnable(ctx sessionctx.Context, ident ast.Ident, ttlInfo *model.TTLInfo, ttlEnable *bool, ttlCronJobSchedule *string) error { - is := e.infoCache.GetLatest() - schema, ok := is.SchemaByName(ident.Schema) - if !ok { - return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ident.Schema) - } - - tb, err := is.TableByName(e.ctx, ident.Schema, ident.Name) - if err != nil { - return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name)) - } - - tblInfo := tb.Meta().Clone() - tableID := tblInfo.ID - tableName := tblInfo.Name.L - - var job *model.Job - if ttlInfo != nil { - tblInfo.TTLInfo = ttlInfo - err = checkTTLInfoValid(ctx, ident.Schema, tblInfo) - if err != nil { - return err - } - } else { - if tblInfo.TTLInfo == nil { - if ttlEnable != nil { - return errors.Trace(dbterror.ErrSetTTLOptionForNonTTLTable.FastGenByArgs("TTL_ENABLE")) - } - if ttlCronJobSchedule != nil { - return errors.Trace(dbterror.ErrSetTTLOptionForNonTTLTable.FastGenByArgs("TTL_JOB_INTERVAL")) - } - } - } - - job = &model.Job{ - SchemaID: schema.ID, - TableID: tableID, - SchemaName: schema.Name.L, - TableName: tableName, - Type: model.ActionAlterTTLInfo, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{ttlInfo, ttlEnable, ttlCronJobSchedule}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -func (e *executor) AlterTableRemoveTTL(ctx sessionctx.Context, ident ast.Ident) error { - is := e.infoCache.GetLatest() - - schema, ok := is.SchemaByName(ident.Schema) - if !ok { - return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ident.Schema) - } - - tb, err := is.TableByName(e.ctx, ident.Schema, ident.Name) - if err != nil { - return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name)) - } - - tblInfo := tb.Meta().Clone() - tableID := tblInfo.ID - tableName := tblInfo.Name.L - - if tblInfo.TTLInfo != nil { - job := &model.Job{ - SchemaID: schema.ID, - TableID: tableID, - SchemaName: schema.Name.L, - TableName: tableName, - Type: model.ActionAlterTTLRemove, - BinlogInfo: &model.HistoryInfo{}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) - } - - return nil -} - -func isTableTiFlashSupported(dbName model.CIStr, tbl *model.TableInfo) error { - // Memory tables and system tables are not supported by TiFlash - if util.IsMemOrSysDB(dbName.L) { - return errors.Trace(dbterror.ErrUnsupportedTiFlashOperationForSysOrMemTable) - } else if tbl.TempTableType != model.TempTableNone { - return dbterror.ErrOptOnTemporaryTable.GenWithStackByArgs("set on tiflash") - } else if tbl.IsView() || tbl.IsSequence() { - return dbterror.ErrWrongObject.GenWithStackByArgs(dbName, tbl.Name, "BASE TABLE") - } - - // Tables that has charset are not supported by TiFlash - for _, col := range tbl.Cols() { - _, ok := charset.TiFlashSupportedCharsets[col.GetCharset()] - if !ok { - return dbterror.ErrUnsupportedTiFlashOperationForUnsupportedCharsetTable.GenWithStackByArgs(col.GetCharset()) - } - } - - 
return nil -} - -func checkTiFlashReplicaCount(ctx sessionctx.Context, replicaCount uint64) error { - // Check the tiflash replica count should be less than the total tiflash stores. - tiflashStoreCnt, err := infoschema.GetTiFlashStoreCount(ctx) - if err != nil { - return errors.Trace(err) - } - if replicaCount > tiflashStoreCnt { - return errors.Errorf("the tiflash replica count: %d should be less than the total tiflash server count: %d", replicaCount, tiflashStoreCnt) - } - return nil -} - -// AlterTableAddStatistics registers extended statistics for a table. -func (e *executor) AlterTableAddStatistics(ctx sessionctx.Context, ident ast.Ident, stats *ast.StatisticsSpec, ifNotExists bool) error { - if !ctx.GetSessionVars().EnableExtendedStats { - return errors.New("Extended statistics feature is not generally available now, and tidb_enable_extended_stats is OFF") - } - // Not support Cardinality and Dependency statistics type for now. - if stats.StatsType == ast.StatsTypeCardinality || stats.StatsType == ast.StatsTypeDependency { - return errors.New("Cardinality and Dependency statistics types are not supported now") - } - _, tbl, err := e.getSchemaAndTableByIdent(ident) - if err != nil { - return err - } - tblInfo := tbl.Meta() - if tblInfo.GetPartitionInfo() != nil { - return errors.New("Extended statistics on partitioned tables are not supported now") - } - colIDs := make([]int64, 0, 2) - colIDSet := make(map[int64]struct{}, 2) - // Check whether columns exist. - for _, colName := range stats.Columns { - col := table.FindCol(tbl.VisibleCols(), colName.Name.L) - if col == nil { - return infoschema.ErrColumnNotExists.GenWithStackByArgs(colName.Name, ident.Name) - } - if stats.StatsType == ast.StatsTypeCorrelation && tblInfo.PKIsHandle && mysql.HasPriKeyFlag(col.GetFlag()) { - ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError("No need to create correlation statistics on the integer primary key column")) - return nil - } - if _, exist := colIDSet[col.ID]; exist { - return errors.Errorf("Cannot create extended statistics on duplicate column names '%s'", colName.Name.L) - } - colIDSet[col.ID] = struct{}{} - colIDs = append(colIDs, col.ID) - } - if len(colIDs) != 2 && (stats.StatsType == ast.StatsTypeCorrelation || stats.StatsType == ast.StatsTypeDependency) { - return errors.New("Only support Correlation and Dependency statistics types on 2 columns") - } - if len(colIDs) < 1 && stats.StatsType == ast.StatsTypeCardinality { - return errors.New("Only support Cardinality statistics type on at least 2 columns") - } - // TODO: check whether covering index exists for cardinality / dependency types. - - // Call utilities of statistics.Handle to modify system tables instead of doing DML directly, - // because locking in Handle can guarantee the correctness of `version` in system tables. - return e.statsHandle.InsertExtendedStats(stats.StatsName, colIDs, int(stats.StatsType), tblInfo.ID, ifNotExists) -} - -// AlterTableDropStatistics logically deletes extended statistics for a table. 
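// A minimal standalone sketch of the bound enforced by checkTiFlashReplicaCount
// above; the store count is passed in directly instead of being read through
// infoschema.GetTiFlashStoreCount.
package sketch

import "fmt"

// validateReplicaCount rejects a requested TiFlash replica count that exceeds the
// number of TiFlash stores, matching the error wording used in the diff.
func validateReplicaCount(replicaCount, tiflashStoreCnt uint64) error {
	if replicaCount > tiflashStoreCnt {
		return fmt.Errorf("the tiflash replica count: %d should be less than the total tiflash server count: %d", replicaCount, tiflashStoreCnt)
	}
	return nil
}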
-func (e *executor) AlterTableDropStatistics(ctx sessionctx.Context, ident ast.Ident, stats *ast.StatisticsSpec, ifExists bool) error { - if !ctx.GetSessionVars().EnableExtendedStats { - return errors.New("Extended statistics feature is not generally available now, and tidb_enable_extended_stats is OFF") - } - _, tbl, err := e.getSchemaAndTableByIdent(ident) - if err != nil { - return err - } - tblInfo := tbl.Meta() - // Call utilities of statistics.Handle to modify system tables instead of doing DML directly, - // because locking in Handle can guarantee the correctness of `version` in system tables. - return e.statsHandle.MarkExtendedStatsDeleted(stats.StatsName, tblInfo.ID, ifExists) -} - -// UpdateTableReplicaInfo updates the table flash replica infos. -func (e *executor) UpdateTableReplicaInfo(ctx sessionctx.Context, physicalID int64, available bool) error { - is := e.infoCache.GetLatest() - tb, ok := is.TableByID(physicalID) - if !ok { - tb, _, _ = is.FindTableByPartitionID(physicalID) - if tb == nil { - return infoschema.ErrTableNotExists.GenWithStack("Table which ID = %d does not exist.", physicalID) - } - } - tbInfo := tb.Meta() - if tbInfo.TiFlashReplica == nil || (tbInfo.ID == physicalID && tbInfo.TiFlashReplica.Available == available) || - (tbInfo.ID != physicalID && available == tbInfo.TiFlashReplica.IsPartitionAvailable(physicalID)) { - return nil - } - - db, ok := infoschema.SchemaByTable(is, tbInfo) - if !ok { - return infoschema.ErrDatabaseNotExists.GenWithStack("Database of table `%s` does not exist.", tb.Meta().Name) - } - - job := &model.Job{ - SchemaID: db.ID, - TableID: tb.Meta().ID, - SchemaName: db.Name.L, - TableName: tb.Meta().Name.L, - Type: model.ActionUpdateTiFlashReplicaStatus, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{available, physicalID}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - err := e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -// checkAlterTableCharset uses to check is it possible to change the charset of table. -// This function returns 2 variable: -// doNothing: if doNothing is true, means no need to change any more, because the target charset is same with the charset of table. -// err: if err is not nil, means it is not possible to change table charset to target charset. -func checkAlterTableCharset(tblInfo *model.TableInfo, dbInfo *model.DBInfo, toCharset, toCollate string, needsOverwriteCols bool) (doNothing bool, err error) { - origCharset := tblInfo.Charset - origCollate := tblInfo.Collate - // Old version schema charset maybe modified when load schema if TreatOldVersionUTF8AsUTF8MB4 was enable. - // So even if the origCharset equal toCharset, we still need to do the ddl for old version schema. - if origCharset == toCharset && origCollate == toCollate && tblInfo.Version >= model.TableInfoVersion2 { - // nothing to do. - doNothing = true - for _, col := range tblInfo.Columns { - if col.GetCharset() == charset.CharsetBin { - continue - } - if col.GetCharset() == toCharset && col.GetCollate() == toCollate { - continue - } - doNothing = false - } - if doNothing { - return doNothing, nil - } - } - - // This DDL will update the table charset to default charset. 
- origCharset, origCollate, err = ResolveCharsetCollation(nil, - ast.CharsetOpt{Chs: origCharset, Col: origCollate}, - ast.CharsetOpt{Chs: dbInfo.Charset, Col: dbInfo.Collate}, - ) - if err != nil { - return doNothing, err - } - - if err = checkModifyCharsetAndCollation(toCharset, toCollate, origCharset, origCollate, false); err != nil { - return doNothing, err - } - if !needsOverwriteCols { - // If we don't change the charset and collation of columns, skip the next checks. - return doNothing, nil - } - - for _, col := range tblInfo.Columns { - if col.GetType() == mysql.TypeVarchar { - if err = types.IsVarcharTooBigFieldLength(col.GetFlen(), col.Name.O, toCharset); err != nil { - return doNothing, err - } - } - if col.GetCharset() == charset.CharsetBin { - continue - } - if len(col.GetCharset()) == 0 { - continue - } - if err = checkModifyCharsetAndCollation(toCharset, toCollate, col.GetCharset(), col.GetCollate(), isColumnWithIndex(col.Name.L, tblInfo.Indices)); err != nil { - if strings.Contains(err.Error(), "Unsupported modifying collation") { - colErrMsg := "Unsupported converting collation of column '%s' from '%s' to '%s' when index is defined on it." - err = dbterror.ErrUnsupportedModifyCollation.GenWithStack(colErrMsg, col.Name.L, col.GetCollate(), toCollate) - } - return doNothing, err - } - } - return doNothing, nil -} - -// RenameIndex renames an index. -// In TiDB, indexes are case-insensitive (so index 'a' and 'A" are considered the same index), -// but index names are case-sensitive (we can rename index 'a' to 'A') -func (e *executor) RenameIndex(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { - is := e.infoCache.GetLatest() - schema, ok := is.SchemaByName(ident.Schema) - if !ok { - return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ident.Schema) - } - - tb, err := is.TableByName(e.ctx, ident.Schema, ident.Name) - if err != nil { - return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ident.Schema, ident.Name)) - } - if tb.Meta().TableCacheStatusType != model.TableCacheStatusDisable { - return errors.Trace(dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Rename Index")) - } - duplicate, err := ValidateRenameIndex(spec.FromKey, spec.ToKey, tb.Meta()) - if duplicate { - return nil - } - if err != nil { - return errors.Trace(err) - } - - job := &model.Job{ - SchemaID: schema.ID, - TableID: tb.Meta().ID, - SchemaName: schema.Name.L, - TableName: tb.Meta().Name.L, - Type: model.ActionRenameIndex, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{spec.FromKey, spec.ToKey}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -// If one drop those tables by mistake, it's difficult to recover. -// In the worst case, the whole TiDB cluster fails to bootstrap, so we prevent user from dropping them. -var systemTables = map[string]struct{}{ - "tidb": {}, - "gc_delete_range": {}, - "gc_delete_range_done": {}, -} - -func isUndroppableTable(schema, table string) bool { - if schema != mysql.SystemDB { - return false - } - if _, ok := systemTables[table]; ok { - return true - } - return false -} - -type objectType int - -const ( - tableObject objectType = iota - viewObject - sequenceObject -) - -// dropTableObject provides common logic to DROP TABLE/VIEW/SEQUENCE. 
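// A minimal standalone sketch of the protected-table lookup done by
// isUndroppableTable above; the mysql.SystemDB constant is replaced with the literal
// "mysql".
package sketch

var protectedTables = map[string]struct{}{
	"tidb":                 {},
	"gc_delete_range":      {},
	"gc_delete_range_done": {},
}

// isProtected reports whether dropping schema.table must be refused: only the fixed
// list of bootstrap-critical tables inside the mysql schema is protected.
func isProtected(schema, table string) bool {
	if schema != "mysql" {
		return false
	}
	_, ok := protectedTables[table]
	return ok
}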
-func (e *executor) dropTableObject( - ctx sessionctx.Context, - objects []*ast.TableName, - ifExists bool, - tableObjectType objectType, -) error { - var ( - notExistTables []string - sessVars = ctx.GetSessionVars() - is = e.infoCache.GetLatest() - dropExistErr *terror.Error - jobType model.ActionType - ) - - var jobArgs []any - switch tableObjectType { - case tableObject: - dropExistErr = infoschema.ErrTableDropExists - jobType = model.ActionDropTable - objectIdents := make([]ast.Ident, len(objects)) - fkCheck := ctx.GetSessionVars().ForeignKeyChecks - jobArgs = []any{objectIdents, fkCheck} - for i, tn := range objects { - objectIdents[i] = ast.Ident{Schema: tn.Schema, Name: tn.Name} - } - for _, tn := range objects { - if referredFK := checkTableHasForeignKeyReferred(is, tn.Schema.L, tn.Name.L, objectIdents, fkCheck); referredFK != nil { - return errors.Trace(dbterror.ErrForeignKeyCannotDropParent.GenWithStackByArgs(tn.Name, referredFK.ChildFKName, referredFK.ChildTable)) - } - } - case viewObject: - dropExistErr = infoschema.ErrTableDropExists - jobType = model.ActionDropView - case sequenceObject: - dropExistErr = infoschema.ErrSequenceDropExists - jobType = model.ActionDropSequence - } - for _, tn := range objects { - fullti := ast.Ident{Schema: tn.Schema, Name: tn.Name} - schema, ok := is.SchemaByName(tn.Schema) - if !ok { - // TODO: we should return special error for table not exist, checking "not exist" is not enough, - // because some other errors may contain this error string too. - notExistTables = append(notExistTables, fullti.String()) - continue - } - tableInfo, err := is.TableByName(e.ctx, tn.Schema, tn.Name) - if err != nil && infoschema.ErrTableNotExists.Equal(err) { - notExistTables = append(notExistTables, fullti.String()) - continue - } else if err != nil { - return err - } - - // prechecks before build DDL job - - // Protect important system table from been dropped by a mistake. - // I can hardly find a case that a user really need to do this. 
- if isUndroppableTable(tn.Schema.L, tn.Name.L) { - return errors.Errorf("Drop tidb system table '%s.%s' is forbidden", tn.Schema.L, tn.Name.L) - } - switch tableObjectType { - case tableObject: - if !tableInfo.Meta().IsBaseTable() { - notExistTables = append(notExistTables, fullti.String()) - continue - } - - tempTableType := tableInfo.Meta().TempTableType - if config.CheckTableBeforeDrop && tempTableType == model.TempTableNone { - logutil.DDLLogger().Warn("admin check table before drop", - zap.String("database", fullti.Schema.O), - zap.String("table", fullti.Name.O), - ) - exec := ctx.GetRestrictedSQLExecutor() - internalCtx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) - _, _, err := exec.ExecRestrictedSQL(internalCtx, nil, "admin check table %n.%n", fullti.Schema.O, fullti.Name.O) - if err != nil { - return err - } - } - - if tableInfo.Meta().TableCacheStatusType != model.TableCacheStatusDisable { - return dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Drop Table") - } - case viewObject: - if !tableInfo.Meta().IsView() { - return dbterror.ErrWrongObject.GenWithStackByArgs(fullti.Schema, fullti.Name, "VIEW") - } - case sequenceObject: - if !tableInfo.Meta().IsSequence() { - err = dbterror.ErrWrongObject.GenWithStackByArgs(fullti.Schema, fullti.Name, "SEQUENCE") - if ifExists { - ctx.GetSessionVars().StmtCtx.AppendNote(err) - continue - } - return err - } - } - - job := &model.Job{ - SchemaID: schema.ID, - TableID: tableInfo.Meta().ID, - SchemaName: schema.Name.L, - SchemaState: schema.State, - TableName: tableInfo.Meta().Name.L, - Type: jobType, - BinlogInfo: &model.HistoryInfo{}, - Args: jobArgs, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - if infoschema.ErrDatabaseNotExists.Equal(err) || infoschema.ErrTableNotExists.Equal(err) { - notExistTables = append(notExistTables, fullti.String()) - continue - } else if err != nil { - return errors.Trace(err) - } - - // unlock table after drop - if tableObjectType != tableObject { - continue - } - if !config.TableLockEnabled() { - continue - } - if ok, _ := ctx.CheckTableLocked(tableInfo.Meta().ID); ok { - ctx.ReleaseTableLockByTableIDs([]int64{tableInfo.Meta().ID}) - } - } - if len(notExistTables) > 0 && !ifExists { - return dropExistErr.FastGenByArgs(strings.Join(notExistTables, ",")) - } - // We need add warning when use if exists. - if len(notExistTables) > 0 && ifExists { - for _, table := range notExistTables { - sessVars.StmtCtx.AppendNote(dropExistErr.FastGenByArgs(table)) - } - } - return nil -} - -// DropTable will proceed even if some table in the list does not exists. -func (e *executor) DropTable(ctx sessionctx.Context, stmt *ast.DropTableStmt) (err error) { - return e.dropTableObject(ctx, stmt.Tables, stmt.IfExists, tableObject) -} - -// DropView will proceed even if some view in the list does not exists. 
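// A minimal sketch of the IF EXISTS bookkeeping used by dropTableObject above:
// missing names are collected during the loop and either turned into one error or
// into per-name notes. The message text here is a placeholder, not the exact
// infoschema error.
package sketch

import (
	"errors"
	"strings"
)

// reportMissing returns a single error listing all missing tables when ifExists is
// false; otherwise it returns the notes that would be appended as warnings.
func reportMissing(missing []string, ifExists bool) (notes []string, err error) {
	if len(missing) == 0 {
		return nil, nil
	}
	if !ifExists {
		return nil, errors.New("unknown tables: " + strings.Join(missing, ","))
	}
	for _, name := range missing {
		notes = append(notes, "unknown table: "+name)
	}
	return notes, nil
}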
-func (e *executor) DropView(ctx sessionctx.Context, stmt *ast.DropTableStmt) (err error) { - return e.dropTableObject(ctx, stmt.Tables, stmt.IfExists, viewObject) -} - -func (e *executor) TruncateTable(ctx sessionctx.Context, ti ast.Ident) error { - schema, tb, err := e.getSchemaAndTableByIdent(ti) - if err != nil { - return errors.Trace(err) - } - if tb.Meta().IsView() || tb.Meta().IsSequence() { - return infoschema.ErrTableNotExists.GenWithStackByArgs(schema.Name.O, tb.Meta().Name.O) - } - if tb.Meta().TableCacheStatusType != model.TableCacheStatusDisable { - return dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Truncate Table") - } - fkCheck := ctx.GetSessionVars().ForeignKeyChecks - referredFK := checkTableHasForeignKeyReferred(e.infoCache.GetLatest(), ti.Schema.L, ti.Name.L, []ast.Ident{{Name: ti.Name, Schema: ti.Schema}}, fkCheck) - if referredFK != nil { - msg := fmt.Sprintf("`%s`.`%s` CONSTRAINT `%s`", referredFK.ChildSchema, referredFK.ChildTable, referredFK.ChildFKName) - return errors.Trace(dbterror.ErrTruncateIllegalForeignKey.GenWithStackByArgs(msg)) - } - - ids := 1 - if tb.Meta().Partition != nil { - ids += len(tb.Meta().Partition.Definitions) - } - genIDs, err := e.genGlobalIDs(ids) - if err != nil { - return errors.Trace(err) - } - newTableID := genIDs[0] - job := &model.Job{ - SchemaID: schema.ID, - TableID: tb.Meta().ID, - SchemaName: schema.Name.L, - TableName: tb.Meta().Name.L, - Type: model.ActionTruncateTable, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{newTableID, fkCheck, genIDs[1:]}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - if ok, _ := ctx.CheckTableLocked(tb.Meta().ID); ok && config.TableLockEnabled() { - // AddTableLock here to avoid this ddl job was executed successfully but the session was been kill before return. - // The session will release all table locks it holds, if we don't add the new locking table id here, - // the session may forget to release the new locked table id when this ddl job was executed successfully - // but the session was killed before return. 
- ctx.AddTableLock([]model.TableLockTpInfo{{SchemaID: schema.ID, TableID: newTableID, Tp: tb.Meta().Lock.Tp}}) - } - err = e.DoDDLJob(ctx, job) - if err != nil { - if config.TableLockEnabled() { - ctx.ReleaseTableLockByTableIDs([]int64{newTableID}) - } - return errors.Trace(err) - } - - if !config.TableLockEnabled() { - return nil - } - if ok, _ := ctx.CheckTableLocked(tb.Meta().ID); ok { - ctx.ReleaseTableLockByTableIDs([]int64{tb.Meta().ID}) - } - return nil -} - -func (e *executor) RenameTable(ctx sessionctx.Context, s *ast.RenameTableStmt) error { - isAlterTable := false - var err error - if len(s.TableToTables) == 1 { - oldIdent := ast.Ident{Schema: s.TableToTables[0].OldTable.Schema, Name: s.TableToTables[0].OldTable.Name} - newIdent := ast.Ident{Schema: s.TableToTables[0].NewTable.Schema, Name: s.TableToTables[0].NewTable.Name} - err = e.renameTable(ctx, oldIdent, newIdent, isAlterTable) - } else { - oldIdents := make([]ast.Ident, 0, len(s.TableToTables)) - newIdents := make([]ast.Ident, 0, len(s.TableToTables)) - for _, tables := range s.TableToTables { - oldIdent := ast.Ident{Schema: tables.OldTable.Schema, Name: tables.OldTable.Name} - newIdent := ast.Ident{Schema: tables.NewTable.Schema, Name: tables.NewTable.Name} - oldIdents = append(oldIdents, oldIdent) - newIdents = append(newIdents, newIdent) - } - err = e.renameTables(ctx, oldIdents, newIdents, isAlterTable) - } - return err -} - -func (e *executor) renameTable(ctx sessionctx.Context, oldIdent, newIdent ast.Ident, isAlterTable bool) error { - is := e.infoCache.GetLatest() - tables := make(map[string]int64) - schemas, tableID, err := ExtractTblInfos(is, oldIdent, newIdent, isAlterTable, tables) - if err != nil { - return err - } - - if schemas == nil { - return nil - } - - if tbl, ok := is.TableByID(tableID); ok { - if tbl.Meta().TableCacheStatusType != model.TableCacheStatusDisable { - return errors.Trace(dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Rename Table")) - } - } - - job := &model.Job{ - SchemaID: schemas[1].ID, - TableID: tableID, - SchemaName: schemas[1].Name.L, - TableName: oldIdent.Name.L, - Type: model.ActionRenameTable, - BinlogInfo: &model.HistoryInfo{}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - Args: []any{schemas[0].ID, newIdent.Name, schemas[0].Name}, - CtxVars: []any{[]int64{schemas[0].ID, schemas[1].ID}, []int64{tableID}}, - InvolvingSchemaInfo: []model.InvolvingSchemaInfo{ - {Database: schemas[0].Name.L, Table: oldIdent.Name.L}, - {Database: schemas[1].Name.L, Table: newIdent.Name.L}, - }, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -func (e *executor) renameTables(ctx sessionctx.Context, oldIdents, newIdents []ast.Ident, isAlterTable bool) error { - is := e.infoCache.GetLatest() - oldTableNames := make([]*model.CIStr, 0, len(oldIdents)) - tableNames := make([]*model.CIStr, 0, len(oldIdents)) - oldSchemaIDs := make([]int64, 0, len(oldIdents)) - newSchemaIDs := make([]int64, 0, len(oldIdents)) - tableIDs := make([]int64, 0, len(oldIdents)) - oldSchemaNames := make([]*model.CIStr, 0, len(oldIdents)) - involveSchemaInfo := make([]model.InvolvingSchemaInfo, 0, len(oldIdents)*2) - - var schemas []*model.DBInfo - var tableID int64 - var err error - - tables := make(map[string]int64) - for i := 0; i < len(oldIdents); i++ { - schemas, tableID, err = ExtractTblInfos(is, oldIdents[i], newIdents[i], isAlterTable, tables) - if err != nil { - return err - } - - if t, ok := is.TableByID(tableID); ok { - if 
t.Meta().TableCacheStatusType != model.TableCacheStatusDisable { - return errors.Trace(dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Rename Tables")) - } - } - - tableIDs = append(tableIDs, tableID) - oldTableNames = append(oldTableNames, &oldIdents[i].Name) - tableNames = append(tableNames, &newIdents[i].Name) - oldSchemaIDs = append(oldSchemaIDs, schemas[0].ID) - newSchemaIDs = append(newSchemaIDs, schemas[1].ID) - oldSchemaNames = append(oldSchemaNames, &schemas[0].Name) - involveSchemaInfo = append(involveSchemaInfo, - model.InvolvingSchemaInfo{ - Database: schemas[0].Name.L, Table: oldIdents[i].Name.L, - }, - model.InvolvingSchemaInfo{ - Database: schemas[1].Name.L, Table: newIdents[i].Name.L, - }, - ) - } - - job := &model.Job{ - SchemaID: schemas[1].ID, - TableID: tableIDs[0], - SchemaName: schemas[1].Name.L, - Type: model.ActionRenameTables, - BinlogInfo: &model.HistoryInfo{}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - Args: []any{oldSchemaIDs, newSchemaIDs, tableNames, tableIDs, oldSchemaNames, oldTableNames}, - CtxVars: []any{append(oldSchemaIDs, newSchemaIDs...), tableIDs}, - InvolvingSchemaInfo: involveSchemaInfo, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -// ExtractTblInfos extracts the table information from the infoschema. -func ExtractTblInfos(is infoschema.InfoSchema, oldIdent, newIdent ast.Ident, isAlterTable bool, tables map[string]int64) ([]*model.DBInfo, int64, error) { - oldSchema, ok := is.SchemaByName(oldIdent.Schema) - if !ok { - if isAlterTable { - return nil, 0, infoschema.ErrTableNotExists.GenWithStackByArgs(oldIdent.Schema, oldIdent.Name) - } - if tableExists(is, newIdent, tables) { - return nil, 0, infoschema.ErrTableExists.GenWithStackByArgs(newIdent) - } - return nil, 0, infoschema.ErrTableNotExists.GenWithStackByArgs(oldIdent.Schema, oldIdent.Name) - } - if !tableExists(is, oldIdent, tables) { - if isAlterTable { - return nil, 0, infoschema.ErrTableNotExists.GenWithStackByArgs(oldIdent.Schema, oldIdent.Name) - } - if tableExists(is, newIdent, tables) { - return nil, 0, infoschema.ErrTableExists.GenWithStackByArgs(newIdent) - } - return nil, 0, infoschema.ErrTableNotExists.GenWithStackByArgs(oldIdent.Schema, oldIdent.Name) - } - if isAlterTable && newIdent.Schema.L == oldIdent.Schema.L && newIdent.Name.L == oldIdent.Name.L { - // oldIdent is equal to newIdent, do nothing - return nil, 0, nil - } - //View can be renamed only in the same schema. 
Compatible with mysql - if infoschema.TableIsView(is, oldIdent.Schema, oldIdent.Name) { - if oldIdent.Schema != newIdent.Schema { - return nil, 0, infoschema.ErrForbidSchemaChange.GenWithStackByArgs(oldIdent.Schema, newIdent.Schema) - } - } - - newSchema, ok := is.SchemaByName(newIdent.Schema) - if !ok { - return nil, 0, dbterror.ErrErrorOnRename.GenWithStackByArgs( - fmt.Sprintf("%s.%s", oldIdent.Schema, oldIdent.Name), - fmt.Sprintf("%s.%s", newIdent.Schema, newIdent.Name), - 168, - fmt.Sprintf("Database `%s` doesn't exist", newIdent.Schema)) - } - if tableExists(is, newIdent, tables) { - return nil, 0, infoschema.ErrTableExists.GenWithStackByArgs(newIdent) - } - if err := checkTooLongTable(newIdent.Name); err != nil { - return nil, 0, errors.Trace(err) - } - oldTableID := getTableID(is, oldIdent, tables) - oldIdentKey := getIdentKey(oldIdent) - tables[oldIdentKey] = tableNotExist - newIdentKey := getIdentKey(newIdent) - tables[newIdentKey] = oldTableID - return []*model.DBInfo{oldSchema, newSchema}, oldTableID, nil -} - -func tableExists(is infoschema.InfoSchema, ident ast.Ident, tables map[string]int64) bool { - identKey := getIdentKey(ident) - tableID, ok := tables[identKey] - if (ok && tableID != tableNotExist) || (!ok && is.TableExists(ident.Schema, ident.Name)) { - return true - } - return false -} - -func getTableID(is infoschema.InfoSchema, ident ast.Ident, tables map[string]int64) int64 { - identKey := getIdentKey(ident) - tableID, ok := tables[identKey] - if !ok { - oldTbl, err := is.TableByName(context.Background(), ident.Schema, ident.Name) - if err != nil { - return tableNotExist - } - tableID = oldTbl.Meta().ID - } - return tableID -} - -func getIdentKey(ident ast.Ident) string { - return fmt.Sprintf("%s.%s", ident.Schema.L, ident.Name.L) -} - -// GetName4AnonymousIndex returns a valid name for anonymous index. -func GetName4AnonymousIndex(t table.Table, colName model.CIStr, idxName model.CIStr) model.CIStr { - // `id` is used to indicated the index name's suffix. - id := 2 - l := len(t.Indices()) - indexName := colName - if idxName.O != "" { - // Use the provided index name, it only happens when the original index name is too long and be truncated. - indexName = idxName - id = 3 - } - if strings.EqualFold(indexName.L, mysql.PrimaryKeyName) { - indexName = model.NewCIStr(fmt.Sprintf("%s_%d", colName.O, id)) - id = 3 - } - for i := 0; i < l; i++ { - if t.Indices()[i].Meta().Name.L == indexName.L { - indexName = model.NewCIStr(fmt.Sprintf("%s_%d", colName.O, id)) - if err := checkTooLongIndex(indexName); err != nil { - indexName = GetName4AnonymousIndex(t, model.NewCIStr(colName.O[:30]), model.NewCIStr(fmt.Sprintf("%s_%d", colName.O[:30], 2))) - } - i = -1 - id++ - } - } - return indexName -} - -func (e *executor) CreatePrimaryKey(ctx sessionctx.Context, ti ast.Ident, indexName model.CIStr, - indexPartSpecifications []*ast.IndexPartSpecification, indexOption *ast.IndexOption) error { - if indexOption != nil && indexOption.PrimaryKeyTp == model.PrimaryKeyTypeClustered { - return dbterror.ErrUnsupportedModifyPrimaryKey.GenWithStack("Adding clustered primary key is not supported. 
" + - "Please consider adding NONCLUSTERED primary key instead") - } - schema, t, err := e.getSchemaAndTableByIdent(ti) - if err != nil { - return errors.Trace(err) - } - - if err = checkTooLongIndex(indexName); err != nil { - return dbterror.ErrTooLongIdent.GenWithStackByArgs(mysql.PrimaryKeyName) - } - - indexName = model.NewCIStr(mysql.PrimaryKeyName) - if indexInfo := t.Meta().FindIndexByName(indexName.L); indexInfo != nil || - // If the table's PKIsHandle is true, it also means that this table has a primary key. - t.Meta().PKIsHandle { - return infoschema.ErrMultiplePriKey - } - - // Primary keys cannot include expression index parts. A primary key requires the generated column to be stored, - // but expression index parts are implemented as virtual generated columns, not stored generated columns. - for _, idxPart := range indexPartSpecifications { - if idxPart.Expr != nil { - return dbterror.ErrFunctionalIndexPrimaryKey - } - } - - tblInfo := t.Meta() - // Check before the job is put to the queue. - // This check is redundant, but useful. If DDL check fail before the job is put - // to job queue, the fail path logic is super fast. - // After DDL job is put to the queue, and if the check fail, TiDB will run the DDL cancel logic. - // The recover step causes DDL wait a few seconds, makes the unit test painfully slow. - // For same reason, decide whether index is global here. - indexColumns, _, err := buildIndexColumns(ctx, tblInfo.Columns, indexPartSpecifications) - if err != nil { - return errors.Trace(err) - } - if _, err = CheckPKOnGeneratedColumn(tblInfo, indexPartSpecifications); err != nil { - return err - } - - global := false - if tblInfo.GetPartitionInfo() != nil { - ck, err := checkPartitionKeysConstraint(tblInfo.GetPartitionInfo(), indexColumns, tblInfo) - if err != nil { - return err - } - if !ck { - if !ctx.GetSessionVars().EnableGlobalIndex { - return dbterror.ErrUniqueKeyNeedAllFieldsInPf.GenWithStackByArgs("PRIMARY") - } - // index columns does not contain all partition columns, must set global - global = true - } - } - - // May be truncate comment here, when index comment too long and sql_mode is't strict. - if indexOption != nil { - sessionVars := ctx.GetSessionVars() - if _, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, indexName.String(), &indexOption.Comment, dbterror.ErrTooLongIndexComment); err != nil { - return errors.Trace(err) - } - } - - unique := true - sqlMode := ctx.GetSessionVars().SQLMode - job := &model.Job{ - SchemaID: schema.ID, - TableID: t.Meta().ID, - SchemaName: schema.Name.L, - TableName: t.Meta().Name.L, - Type: model.ActionAddPrimaryKey, - BinlogInfo: &model.HistoryInfo{}, - ReorgMeta: nil, - Args: []any{unique, indexName, indexPartSpecifications, indexOption, sqlMode, nil, global}, - Priority: ctx.GetSessionVars().DDLReorgPriority, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - reorgMeta, err := newReorgMetaFromVariables(job, ctx) - if err != nil { - return err - } - job.ReorgMeta = reorgMeta - - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -func (e *executor) CreateIndex(ctx sessionctx.Context, stmt *ast.CreateIndexStmt) error { - ident := ast.Ident{Schema: stmt.Table.Schema, Name: stmt.Table.Name} - return e.createIndex(ctx, ident, stmt.KeyType, model.NewCIStr(stmt.IndexName), - stmt.IndexPartSpecifications, stmt.IndexOption, stmt.IfNotExists) -} - -// addHypoIndexIntoCtx adds this index as a hypo-index into this ctx. 
-func (*executor) addHypoIndexIntoCtx(ctx sessionctx.Context, schemaName, tableName model.CIStr, indexInfo *model.IndexInfo) error { - sctx := ctx.GetSessionVars() - indexName := indexInfo.Name - - if sctx.HypoIndexes == nil { - sctx.HypoIndexes = make(map[string]map[string]map[string]*model.IndexInfo) - } - if sctx.HypoIndexes[schemaName.L] == nil { - sctx.HypoIndexes[schemaName.L] = make(map[string]map[string]*model.IndexInfo) - } - if sctx.HypoIndexes[schemaName.L][tableName.L] == nil { - sctx.HypoIndexes[schemaName.L][tableName.L] = make(map[string]*model.IndexInfo) - } - if _, exist := sctx.HypoIndexes[schemaName.L][tableName.L][indexName.L]; exist { - return errors.Trace(errors.Errorf("conflict hypo index name %s", indexName.L)) - } - - sctx.HypoIndexes[schemaName.L][tableName.L][indexName.L] = indexInfo - return nil -} - -func (e *executor) createIndex(ctx sessionctx.Context, ti ast.Ident, keyType ast.IndexKeyType, indexName model.CIStr, - indexPartSpecifications []*ast.IndexPartSpecification, indexOption *ast.IndexOption, ifNotExists bool) error { - // not support Spatial and FullText index - if keyType == ast.IndexKeyTypeFullText || keyType == ast.IndexKeyTypeSpatial { - return dbterror.ErrUnsupportedIndexType.GenWithStack("FULLTEXT and SPATIAL index is not supported") - } - unique := keyType == ast.IndexKeyTypeUnique - schema, t, err := e.getSchemaAndTableByIdent(ti) - if err != nil { - return errors.Trace(err) - } - - if t.Meta().TableCacheStatusType != model.TableCacheStatusDisable { - return errors.Trace(dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Create Index")) - } - // Deal with anonymous index. - if len(indexName.L) == 0 { - colName := model.NewCIStr("expression_index") - if indexPartSpecifications[0].Column != nil { - colName = indexPartSpecifications[0].Column.Name - } - indexName = GetName4AnonymousIndex(t, colName, model.NewCIStr("")) - } - - if indexInfo := t.Meta().FindIndexByName(indexName.L); indexInfo != nil { - if indexInfo.State != model.StatePublic { - // NOTE: explicit error message. See issue #18363. - err = dbterror.ErrDupKeyName.GenWithStack("Duplicate key name '%s'; "+ - "a background job is trying to add the same index, "+ - "please check by `ADMIN SHOW DDL JOBS`", indexName) - } else { - err = dbterror.ErrDupKeyName.GenWithStackByArgs(indexName) - } - if ifNotExists { - ctx.GetSessionVars().StmtCtx.AppendNote(err) - return nil - } - return err - } - - if err = checkTooLongIndex(indexName); err != nil { - return errors.Trace(err) - } - - tblInfo := t.Meta() - - // Build hidden columns if necessary. - hiddenCols, err := buildHiddenColumnInfoWithCheck(ctx, indexPartSpecifications, indexName, t.Meta(), t.Cols()) - if err != nil { - return err - } - if err = checkAddColumnTooManyColumns(len(t.Cols()) + len(hiddenCols)); err != nil { - return errors.Trace(err) - } - - finalColumns := make([]*model.ColumnInfo, len(tblInfo.Columns), len(tblInfo.Columns)+len(hiddenCols)) - copy(finalColumns, tblInfo.Columns) - finalColumns = append(finalColumns, hiddenCols...) - // Check before the job is put to the queue. - // This check is redundant, but useful. If DDL check fail before the job is put - // to job queue, the fail path logic is super fast. - // After DDL job is put to the queue, and if the check fail, TiDB will run the DDL cancel logic. - // The recover step causes DDL wait a few seconds, makes the unit test painfully slow. - // For same reason, decide whether index is global here. 
- indexColumns, _, err := buildIndexColumns(ctx, finalColumns, indexPartSpecifications) - if err != nil { - return errors.Trace(err) - } - - global := false - if unique && tblInfo.GetPartitionInfo() != nil { - ck, err := checkPartitionKeysConstraint(tblInfo.GetPartitionInfo(), indexColumns, tblInfo) - if err != nil { - return err - } - if !ck { - if !ctx.GetSessionVars().EnableGlobalIndex { - return dbterror.ErrUniqueKeyNeedAllFieldsInPf.GenWithStackByArgs("UNIQUE INDEX") - } - // index columns does not contain all partition columns, must set global - global = true - } - } - // May be truncate comment here, when index comment too long and sql_mode is't strict. - if indexOption != nil { - sessionVars := ctx.GetSessionVars() - if _, err = validateCommentLength(sessionVars.StmtCtx.ErrCtx(), sessionVars.SQLMode, indexName.String(), &indexOption.Comment, dbterror.ErrTooLongIndexComment); err != nil { - return errors.Trace(err) - } - } - - if indexOption != nil && indexOption.Tp == model.IndexTypeHypo { // for hypo-index - indexInfo, err := BuildIndexInfo(ctx, tblInfo.Columns, indexName, false, unique, global, - indexPartSpecifications, indexOption, model.StatePublic) - if err != nil { - return err - } - return e.addHypoIndexIntoCtx(ctx, ti.Schema, ti.Name, indexInfo) - } - - chs, coll := ctx.GetSessionVars().GetCharsetInfo() - job := &model.Job{ - SchemaID: schema.ID, - TableID: t.Meta().ID, - SchemaName: schema.Name.L, - TableName: t.Meta().Name.L, - Type: model.ActionAddIndex, - BinlogInfo: &model.HistoryInfo{}, - ReorgMeta: nil, - Args: []any{unique, indexName, indexPartSpecifications, indexOption, hiddenCols, global}, - Priority: ctx.GetSessionVars().DDLReorgPriority, - Charset: chs, - Collate: coll, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - reorgMeta, err := newReorgMetaFromVariables(job, ctx) - if err != nil { - return err - } - job.ReorgMeta = reorgMeta - - err = e.DoDDLJob(ctx, job) - // key exists, but if_not_exists flags is true, so we ignore this error. - if dbterror.ErrDupKeyName.Equal(err) && ifNotExists { - ctx.GetSessionVars().StmtCtx.AppendNote(err) - return nil - } - return errors.Trace(err) -} - -func newReorgMetaFromVariables(job *model.Job, sctx sessionctx.Context) (*model.DDLReorgMeta, error) { - reorgMeta := NewDDLReorgMeta(sctx) - reorgMeta.IsDistReorg = variable.EnableDistTask.Load() - reorgMeta.IsFastReorg = variable.EnableFastReorg.Load() - reorgMeta.TargetScope = variable.ServiceScope.Load() - - if reorgMeta.IsDistReorg && !reorgMeta.IsFastReorg { - return nil, dbterror.ErrUnsupportedDistTask - } - if hasSysDB(job) { - if reorgMeta.IsDistReorg { - logutil.DDLLogger().Info("cannot use distributed task execution on system DB", - zap.Stringer("job", job)) - } - reorgMeta.IsDistReorg = false - reorgMeta.IsFastReorg = false - failpoint.Inject("reorgMetaRecordFastReorgDisabled", func(_ failpoint.Value) { - LastReorgMetaFastReorgDisabled = true - }) - } - return reorgMeta, nil -} - -// LastReorgMetaFastReorgDisabled is used for test. 
-var LastReorgMetaFastReorgDisabled bool - -func buildFKInfo(fkName model.CIStr, keys []*ast.IndexPartSpecification, refer *ast.ReferenceDef, cols []*table.Column) (*model.FKInfo, error) { - if len(keys) != len(refer.IndexPartSpecifications) { - return nil, infoschema.ErrForeignKeyNotMatch.GenWithStackByArgs(fkName, "Key reference and table reference don't match") - } - if err := checkTooLongForeignKey(fkName); err != nil { - return nil, err - } - if err := checkTooLongSchema(refer.Table.Schema); err != nil { - return nil, err - } - if err := checkTooLongTable(refer.Table.Name); err != nil { - return nil, err - } - - // all base columns of stored generated columns - baseCols := make(map[string]struct{}) - for _, col := range cols { - if col.IsGenerated() && col.GeneratedStored { - for name := range col.Dependences { - baseCols[name] = struct{}{} - } - } - } - - fkInfo := &model.FKInfo{ - Name: fkName, - RefSchema: refer.Table.Schema, - RefTable: refer.Table.Name, - Cols: make([]model.CIStr, len(keys)), - } - if variable.EnableForeignKey.Load() { - fkInfo.Version = model.FKVersion1 - } - - for i, key := range keys { - // Check add foreign key to generated columns - // For more detail, see https://dev.mysql.com/doc/refman/8.0/en/innodb-foreign-key-constraints.html#innodb-foreign-key-generated-columns - for _, col := range cols { - if col.Name.L != key.Column.Name.L { - continue - } - if col.IsGenerated() { - // Check foreign key on virtual generated columns - if !col.GeneratedStored { - return nil, infoschema.ErrForeignKeyCannotUseVirtualColumn.GenWithStackByArgs(fkInfo.Name.O, col.Name.O) - } - - // Check wrong reference options of foreign key on stored generated columns - switch refer.OnUpdate.ReferOpt { - case model.ReferOptionCascade, model.ReferOptionSetNull, model.ReferOptionSetDefault: - //nolint: gosec - return nil, dbterror.ErrWrongFKOptionForGeneratedColumn.GenWithStackByArgs("ON UPDATE " + refer.OnUpdate.ReferOpt.String()) - } - switch refer.OnDelete.ReferOpt { - case model.ReferOptionSetNull, model.ReferOptionSetDefault: - //nolint: gosec - return nil, dbterror.ErrWrongFKOptionForGeneratedColumn.GenWithStackByArgs("ON DELETE " + refer.OnDelete.ReferOpt.String()) - } - continue - } - // Check wrong reference options of foreign key on base columns of stored generated columns - if _, ok := baseCols[col.Name.L]; ok { - switch refer.OnUpdate.ReferOpt { - case model.ReferOptionCascade, model.ReferOptionSetNull, model.ReferOptionSetDefault: - return nil, infoschema.ErrCannotAddForeign - } - switch refer.OnDelete.ReferOpt { - case model.ReferOptionCascade, model.ReferOptionSetNull, model.ReferOptionSetDefault: - return nil, infoschema.ErrCannotAddForeign - } - } - } - col := table.FindCol(cols, key.Column.Name.O) - if col == nil { - return nil, dbterror.ErrKeyColumnDoesNotExits.GenWithStackByArgs(key.Column.Name) - } - if mysql.HasNotNullFlag(col.GetFlag()) && (refer.OnDelete.ReferOpt == model.ReferOptionSetNull || refer.OnUpdate.ReferOpt == model.ReferOptionSetNull) { - return nil, infoschema.ErrForeignKeyColumnNotNull.GenWithStackByArgs(col.Name.O, fkName) - } - fkInfo.Cols[i] = key.Column.Name - } - - fkInfo.RefCols = make([]model.CIStr, len(refer.IndexPartSpecifications)) - for i, key := range refer.IndexPartSpecifications { - if err := checkTooLongColumn(key.Column.Name); err != nil { - return nil, err - } - fkInfo.RefCols[i] = key.Column.Name - } - - fkInfo.OnDelete = int(refer.OnDelete.ReferOpt) - fkInfo.OnUpdate = int(refer.OnUpdate.ReferOpt) - - return fkInfo, nil -} - -func (e 
*executor) CreateForeignKey(ctx sessionctx.Context, ti ast.Ident, fkName model.CIStr, keys []*ast.IndexPartSpecification, refer *ast.ReferenceDef) error { - is := e.infoCache.GetLatest() - schema, ok := is.SchemaByName(ti.Schema) - if !ok { - return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ti.Schema) - } - - t, err := is.TableByName(context.Background(), ti.Schema, ti.Name) - if err != nil { - return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ti.Schema, ti.Name)) - } - if t.Meta().TempTableType != model.TempTableNone { - return infoschema.ErrCannotAddForeign - } - - if fkName.L == "" { - fkName = model.NewCIStr(fmt.Sprintf("fk_%d", t.Meta().MaxForeignKeyID+1)) - } - err = checkFKDupName(t.Meta(), fkName) - if err != nil { - return err - } - fkInfo, err := buildFKInfo(fkName, keys, refer, t.Cols()) - if err != nil { - return errors.Trace(err) - } - fkCheck := ctx.GetSessionVars().ForeignKeyChecks - err = checkAddForeignKeyValid(is, schema.Name.L, t.Meta(), fkInfo, fkCheck) - if err != nil { - return err - } - if model.FindIndexByColumns(t.Meta(), t.Meta().Indices, fkInfo.Cols...) == nil { - // Need to auto create index for fk cols - if ctx.GetSessionVars().StmtCtx.MultiSchemaInfo == nil { - ctx.GetSessionVars().StmtCtx.MultiSchemaInfo = model.NewMultiSchemaInfo() - } - indexPartSpecifications := make([]*ast.IndexPartSpecification, 0, len(fkInfo.Cols)) - for _, col := range fkInfo.Cols { - indexPartSpecifications = append(indexPartSpecifications, &ast.IndexPartSpecification{ - Column: &ast.ColumnName{Name: col}, - Length: types.UnspecifiedLength, // Index prefixes on foreign key columns are not supported. - }) - } - indexOption := &ast.IndexOption{} - err = e.createIndex(ctx, ti, ast.IndexKeyTypeNone, fkInfo.Name, indexPartSpecifications, indexOption, false) - if err != nil { - return err - } - } - - job := &model.Job{ - SchemaID: schema.ID, - TableID: t.Meta().ID, - SchemaName: schema.Name.L, - TableName: t.Meta().Name.L, - Type: model.ActionAddForeignKey, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{fkInfo, fkCheck}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - InvolvingSchemaInfo: []model.InvolvingSchemaInfo{ - { - Database: schema.Name.L, - Table: t.Meta().Name.L, - }, - { - Database: fkInfo.RefSchema.L, - Table: fkInfo.RefTable.L, - Mode: model.SharedInvolving, - }, - }, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -func (e *executor) DropForeignKey(ctx sessionctx.Context, ti ast.Ident, fkName model.CIStr) error { - is := e.infoCache.GetLatest() - schema, ok := is.SchemaByName(ti.Schema) - if !ok { - return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ti.Schema) - } - - t, err := is.TableByName(context.Background(), ti.Schema, ti.Name) - if err != nil { - return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ti.Schema, ti.Name)) - } - - job := &model.Job{ - SchemaID: schema.ID, - TableID: t.Meta().ID, - SchemaName: schema.Name.L, - SchemaState: model.StatePublic, - TableName: t.Meta().Name.L, - Type: model.ActionDropForeignKey, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{fkName}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -func (e *executor) DropIndex(ctx sessionctx.Context, stmt *ast.DropIndexStmt) error { - ti := ast.Ident{Schema: stmt.Table.Schema, Name: stmt.Table.Name} - err := e.dropIndex(ctx, ti, model.NewCIStr(stmt.IndexName), 
stmt.IfExists, stmt.IsHypo) - if (infoschema.ErrDatabaseNotExists.Equal(err) || infoschema.ErrTableNotExists.Equal(err)) && stmt.IfExists { - err = nil - } - return err -} - -// dropHypoIndexFromCtx drops this hypo-index from this ctx. -func (*executor) dropHypoIndexFromCtx(ctx sessionctx.Context, schema, table, index model.CIStr, ifExists bool) error { - sctx := ctx.GetSessionVars() - if sctx.HypoIndexes != nil && - sctx.HypoIndexes[schema.L] != nil && - sctx.HypoIndexes[schema.L][table.L] != nil && - sctx.HypoIndexes[schema.L][table.L][index.L] != nil { - delete(sctx.HypoIndexes[schema.L][table.L], index.L) - return nil - } - if !ifExists { - return dbterror.ErrCantDropFieldOrKey.GenWithStack("index %s doesn't exist", index) - } - return nil -} - -// dropIndex drops the specified index. -// isHypo is used to indicate whether this operation is for a hypo-index. -func (e *executor) dropIndex(ctx sessionctx.Context, ti ast.Ident, indexName model.CIStr, ifExists, isHypo bool) error { - is := e.infoCache.GetLatest() - schema, ok := is.SchemaByName(ti.Schema) - if !ok { - return errors.Trace(infoschema.ErrDatabaseNotExists) - } - t, err := is.TableByName(context.Background(), ti.Schema, ti.Name) - if err != nil { - return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ti.Schema, ti.Name)) - } - if t.Meta().TableCacheStatusType != model.TableCacheStatusDisable { - return errors.Trace(dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Drop Index")) - } - - if isHypo { - return e.dropHypoIndexFromCtx(ctx, ti.Schema, ti.Name, indexName, ifExists) - } - - indexInfo := t.Meta().FindIndexByName(indexName.L) - - isPK, err := CheckIsDropPrimaryKey(indexName, indexInfo, t) - if err != nil { - return err - } - - if !ctx.GetSessionVars().InRestrictedSQL && ctx.GetSessionVars().PrimaryKeyRequired && isPK { - return infoschema.ErrTableWithoutPrimaryKey - } - - if indexInfo == nil { - err = dbterror.ErrCantDropFieldOrKey.GenWithStack("index %s doesn't exist", indexName) - if ifExists { - ctx.GetSessionVars().StmtCtx.AppendNote(err) - return nil - } - return err - } - - err = checkIndexNeededInForeignKey(is, schema.Name.L, t.Meta(), indexInfo) - if err != nil { - return err - } - - jobTp := model.ActionDropIndex - if isPK { - jobTp = model.ActionDropPrimaryKey - } - - job := &model.Job{ - SchemaID: schema.ID, - TableID: t.Meta().ID, - SchemaName: schema.Name.L, - SchemaState: indexInfo.State, - TableName: t.Meta().Name.L, - Type: jobTp, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{indexName, ifExists}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -// CheckIsDropPrimaryKey checks if we will drop PK, there are many PK implementations so we provide a helper function. -func CheckIsDropPrimaryKey(indexName model.CIStr, indexInfo *model.IndexInfo, t table.Table) (bool, error) { - var isPK bool - if indexName.L == strings.ToLower(mysql.PrimaryKeyName) && - // Before we fixed #14243, there might be a general index named `primary` but not a primary key. - (indexInfo == nil || indexInfo.Primary) { - isPK = true - } - if isPK { - // If the table's PKIsHandle is true, we can't find the index from the table. So we check the value of PKIsHandle. 
- if indexInfo == nil && !t.Meta().PKIsHandle { - return isPK, dbterror.ErrCantDropFieldOrKey.GenWithStackByArgs("PRIMARY") - } - if t.Meta().IsCommonHandle || t.Meta().PKIsHandle { - return isPK, dbterror.ErrUnsupportedModifyPrimaryKey.GenWithStack("Unsupported drop primary key when the table is using clustered index") - } - } - - return isPK, nil -} - -// validateCommentLength checks comment length of table, column, or index -// If comment length is more than the standard length truncate it -// and store the comment length upto the standard comment length size. -func validateCommentLength(ec errctx.Context, sqlMode mysql.SQLMode, name string, comment *string, errTooLongComment *terror.Error) (string, error) { - if comment == nil { - return "", nil - } - - maxLen := MaxCommentLength - // The maximum length of table comment in MySQL 5.7 is 2048 - // Other comment is 1024 - switch errTooLongComment { - case dbterror.ErrTooLongTableComment: - maxLen *= 2 - case dbterror.ErrTooLongFieldComment, dbterror.ErrTooLongIndexComment, dbterror.ErrTooLongTablePartitionComment: - default: - // add more types of terror.Error if need - } - if len(*comment) > maxLen { - err := errTooLongComment.GenWithStackByArgs(name, maxLen) - if sqlMode.HasStrictMode() { - // may be treated like an error. - return "", err - } - ec.AppendWarning(err) - *comment = (*comment)[:maxLen] - } - return *comment, nil -} - -// BuildAddedPartitionInfo build alter table add partition info -func BuildAddedPartitionInfo(ctx expression.BuildContext, meta *model.TableInfo, spec *ast.AlterTableSpec) (*model.PartitionInfo, error) { - numParts := uint64(0) - switch meta.Partition.Type { - case model.PartitionTypeNone: - // OK - case model.PartitionTypeList: - if len(spec.PartDefinitions) == 0 { - return nil, ast.ErrPartitionsMustBeDefined.GenWithStackByArgs(meta.Partition.Type) - } - err := checkListPartitions(spec.PartDefinitions) - if err != nil { - return nil, err - } - - case model.PartitionTypeRange: - if spec.Tp == ast.AlterTableAddLastPartition { - err := buildAddedPartitionDefs(ctx, meta, spec) - if err != nil { - return nil, err - } - spec.PartDefinitions = spec.Partition.Definitions - } else { - if len(spec.PartDefinitions) == 0 { - return nil, ast.ErrPartitionsMustBeDefined.GenWithStackByArgs(meta.Partition.Type) - } - } - case model.PartitionTypeHash, model.PartitionTypeKey: - switch spec.Tp { - case ast.AlterTableRemovePartitioning: - numParts = 1 - default: - return nil, errors.Trace(dbterror.ErrUnsupportedAddPartition) - case ast.AlterTableCoalescePartitions: - if int(spec.Num) >= len(meta.Partition.Definitions) { - return nil, dbterror.ErrDropLastPartition - } - numParts = uint64(len(meta.Partition.Definitions)) - spec.Num - case ast.AlterTableAddPartitions: - if len(spec.PartDefinitions) > 0 { - numParts = uint64(len(meta.Partition.Definitions)) + uint64(len(spec.PartDefinitions)) - } else { - numParts = uint64(len(meta.Partition.Definitions)) + spec.Num - } - } - default: - // we don't support ADD PARTITION for all other partition types yet. 
- return nil, errors.Trace(dbterror.ErrUnsupportedAddPartition) - } - - part := &model.PartitionInfo{ - Type: meta.Partition.Type, - Expr: meta.Partition.Expr, - Columns: meta.Partition.Columns, - Enable: meta.Partition.Enable, - } - - defs, err := buildPartitionDefinitionsInfo(ctx, spec.PartDefinitions, meta, numParts) - if err != nil { - return nil, err - } - - part.Definitions = defs - part.Num = uint64(len(defs)) - return part, nil -} - -func buildAddedPartitionDefs(ctx expression.BuildContext, meta *model.TableInfo, spec *ast.AlterTableSpec) error { - partInterval := getPartitionIntervalFromTable(ctx, meta) - if partInterval == nil { - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs( - "LAST PARTITION, does not seem like an INTERVAL partitioned table") - } - if partInterval.MaxValPart { - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("LAST PARTITION when MAXVALUE partition exists") - } - - spec.Partition.Interval = partInterval - - if len(spec.PartDefinitions) > 0 { - return errors.Trace(dbterror.ErrUnsupportedAddPartition) - } - return GeneratePartDefsFromInterval(ctx, spec.Tp, meta, spec.Partition) -} - -// LockTables uses to execute lock tables statement. -func (e *executor) LockTables(ctx sessionctx.Context, stmt *ast.LockTablesStmt) error { - lockTables := make([]model.TableLockTpInfo, 0, len(stmt.TableLocks)) - sessionInfo := model.SessionInfo{ - ServerID: e.uuid, - SessionID: ctx.GetSessionVars().ConnectionID, - } - uniqueTableID := make(map[int64]struct{}) - involveSchemaInfo := make([]model.InvolvingSchemaInfo, 0, len(stmt.TableLocks)) - // Check whether the table was already locked by another. - for _, tl := range stmt.TableLocks { - tb := tl.Table - err := throwErrIfInMemOrSysDB(ctx, tb.Schema.L) - if err != nil { - return err - } - schema, t, err := e.getSchemaAndTableByIdent(ast.Ident{Schema: tb.Schema, Name: tb.Name}) - if err != nil { - return errors.Trace(err) - } - if t.Meta().IsView() || t.Meta().IsSequence() { - return table.ErrUnsupportedOp.GenWithStackByArgs() - } - - err = checkTableLocked(t.Meta(), tl.Type, sessionInfo) - if err != nil { - return err - } - if _, ok := uniqueTableID[t.Meta().ID]; ok { - return infoschema.ErrNonuniqTable.GenWithStackByArgs(t.Meta().Name) - } - uniqueTableID[t.Meta().ID] = struct{}{} - lockTables = append(lockTables, model.TableLockTpInfo{SchemaID: schema.ID, TableID: t.Meta().ID, Tp: tl.Type}) - involveSchemaInfo = append(involveSchemaInfo, model.InvolvingSchemaInfo{ - Database: schema.Name.L, - Table: t.Meta().Name.L, - }) - } - - unlockTables := ctx.GetAllTableLocks() - arg := &LockTablesArg{ - LockTables: lockTables, - UnlockTables: unlockTables, - SessionInfo: sessionInfo, - } - job := &model.Job{ - SchemaID: lockTables[0].SchemaID, - TableID: lockTables[0].TableID, - Type: model.ActionLockTable, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{arg}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - InvolvingSchemaInfo: involveSchemaInfo, - SQLMode: ctx.GetSessionVars().SQLMode, - } - // AddTableLock here is avoiding this job was executed successfully but the session was killed before return. - ctx.AddTableLock(lockTables) - err := e.DoDDLJob(ctx, job) - if err == nil { - ctx.ReleaseTableLocks(unlockTables) - ctx.AddTableLock(lockTables) - } - return errors.Trace(err) -} - -// UnlockTables uses to execute unlock tables statement. 
-func (e *executor) UnlockTables(ctx sessionctx.Context, unlockTables []model.TableLockTpInfo) error { - if len(unlockTables) == 0 { - return nil - } - arg := &LockTablesArg{ - UnlockTables: unlockTables, - SessionInfo: model.SessionInfo{ - ServerID: e.uuid, - SessionID: ctx.GetSessionVars().ConnectionID, - }, - } - - involveSchemaInfo := make([]model.InvolvingSchemaInfo, 0, len(unlockTables)) - is := e.infoCache.GetLatest() - for _, t := range unlockTables { - schema, ok := is.SchemaByID(t.SchemaID) - if !ok { - continue - } - tbl, ok := is.TableByID(t.TableID) - if !ok { - continue - } - involveSchemaInfo = append(involveSchemaInfo, model.InvolvingSchemaInfo{ - Database: schema.Name.L, - Table: tbl.Meta().Name.L, - }) - } - job := &model.Job{ - SchemaID: unlockTables[0].SchemaID, - TableID: unlockTables[0].TableID, - Type: model.ActionUnlockTable, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{arg}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - InvolvingSchemaInfo: involveSchemaInfo, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err := e.DoDDLJob(ctx, job) - if err == nil { - ctx.ReleaseAllTableLocks() - } - return errors.Trace(err) -} - -func throwErrIfInMemOrSysDB(ctx sessionctx.Context, dbLowerName string) error { - if util.IsMemOrSysDB(dbLowerName) { - if ctx.GetSessionVars().User != nil { - return infoschema.ErrAccessDenied.GenWithStackByArgs(ctx.GetSessionVars().User.Username, ctx.GetSessionVars().User.Hostname) - } - return infoschema.ErrAccessDenied.GenWithStackByArgs("", "") - } - return nil -} - -func (e *executor) CleanupTableLock(ctx sessionctx.Context, tables []*ast.TableName) error { - uniqueTableID := make(map[int64]struct{}) - cleanupTables := make([]model.TableLockTpInfo, 0, len(tables)) - unlockedTablesNum := 0 - involvingSchemaInfo := make([]model.InvolvingSchemaInfo, 0, len(tables)) - // Check whether the table was already locked by another. - for _, tb := range tables { - err := throwErrIfInMemOrSysDB(ctx, tb.Schema.L) - if err != nil { - return err - } - schema, t, err := e.getSchemaAndTableByIdent(ast.Ident{Schema: tb.Schema, Name: tb.Name}) - if err != nil { - return errors.Trace(err) - } - if t.Meta().IsView() || t.Meta().IsSequence() { - return table.ErrUnsupportedOp - } - // Maybe the table t was not locked, but still try to unlock this table. - // If we skip unlock the table here, the job maybe not consistent with the job.Query. - // eg: unlock tables t1,t2; If t2 is not locked and skip here, then the job will only unlock table t1, - // and this behaviour is not consistent with the sql query. - if !t.Meta().IsLocked() { - unlockedTablesNum++ - } - if _, ok := uniqueTableID[t.Meta().ID]; ok { - return infoschema.ErrNonuniqTable.GenWithStackByArgs(t.Meta().Name) - } - uniqueTableID[t.Meta().ID] = struct{}{} - cleanupTables = append(cleanupTables, model.TableLockTpInfo{SchemaID: schema.ID, TableID: t.Meta().ID}) - involvingSchemaInfo = append(involvingSchemaInfo, model.InvolvingSchemaInfo{ - Database: schema.Name.L, - Table: t.Meta().Name.L, - }) - } - // If the num of cleanupTables is 0, or all cleanupTables is unlocked, just return here. 
- if len(cleanupTables) == 0 || len(cleanupTables) == unlockedTablesNum { - return nil - } - - arg := &LockTablesArg{ - UnlockTables: cleanupTables, - IsCleanup: true, - } - job := &model.Job{ - SchemaID: cleanupTables[0].SchemaID, - TableID: cleanupTables[0].TableID, - Type: model.ActionUnlockTable, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{arg}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - InvolvingSchemaInfo: involvingSchemaInfo, - SQLMode: ctx.GetSessionVars().SQLMode, - } - err := e.DoDDLJob(ctx, job) - if err == nil { - ctx.ReleaseTableLocks(cleanupTables) - } - return errors.Trace(err) -} - -// LockTablesArg is the argument for LockTables, export for test. -type LockTablesArg struct { - LockTables []model.TableLockTpInfo - IndexOfLock int - UnlockTables []model.TableLockTpInfo - IndexOfUnlock int - SessionInfo model.SessionInfo - IsCleanup bool -} - -func (e *executor) RepairTable(ctx sessionctx.Context, createStmt *ast.CreateTableStmt) error { - // Existence of DB and table has been checked in the preprocessor. - oldTableInfo, ok := (ctx.Value(domainutil.RepairedTable)).(*model.TableInfo) - if !ok || oldTableInfo == nil { - return dbterror.ErrRepairTableFail.GenWithStack("Failed to get the repaired table") - } - oldDBInfo, ok := (ctx.Value(domainutil.RepairedDatabase)).(*model.DBInfo) - if !ok || oldDBInfo == nil { - return dbterror.ErrRepairTableFail.GenWithStack("Failed to get the repaired database") - } - // By now only support same DB repair. - if createStmt.Table.Schema.L != oldDBInfo.Name.L { - return dbterror.ErrRepairTableFail.GenWithStack("Repaired table should in same database with the old one") - } - - // It is necessary to specify the table.ID and partition.ID manually. - newTableInfo, err := buildTableInfoWithCheck(ctx, createStmt, oldTableInfo.Charset, oldTableInfo.Collate, oldTableInfo.PlacementPolicyRef) - if err != nil { - return errors.Trace(err) - } - // Override newTableInfo with oldTableInfo's element necessary. - // TODO: There may be more element assignments here, and the new TableInfo should be verified with the actual data. - newTableInfo.ID = oldTableInfo.ID - if err = checkAndOverridePartitionID(newTableInfo, oldTableInfo); err != nil { - return err - } - newTableInfo.AutoIncID = oldTableInfo.AutoIncID - // If any old columnInfo has lost, that means the old column ID lost too, repair failed. - for i, newOne := range newTableInfo.Columns { - old := oldTableInfo.FindPublicColumnByName(newOne.Name.L) - if old == nil { - return dbterror.ErrRepairTableFail.GenWithStackByArgs("Column " + newOne.Name.L + " has lost") - } - if newOne.GetType() != old.GetType() { - return dbterror.ErrRepairTableFail.GenWithStackByArgs("Column " + newOne.Name.L + " type should be the same") - } - if newOne.GetFlen() != old.GetFlen() { - logutil.DDLLogger().Warn("admin repair table : Column " + newOne.Name.L + " flen is not equal to the old one") - } - newTableInfo.Columns[i].ID = old.ID - } - // If any old indexInfo has lost, that means the index ID lost too, so did the data, repair failed. 
- for i, newOne := range newTableInfo.Indices { - old := getIndexInfoByNameAndColumn(oldTableInfo, newOne) - if old == nil { - return dbterror.ErrRepairTableFail.GenWithStackByArgs("Index " + newOne.Name.L + " has lost") - } - if newOne.Tp != old.Tp { - return dbterror.ErrRepairTableFail.GenWithStackByArgs("Index " + newOne.Name.L + " type should be the same") - } - newTableInfo.Indices[i].ID = old.ID - } - - newTableInfo.State = model.StatePublic - err = checkTableInfoValid(newTableInfo) - if err != nil { - return err - } - newTableInfo.State = model.StateNone - - job := &model.Job{ - SchemaID: oldDBInfo.ID, - TableID: newTableInfo.ID, - SchemaName: oldDBInfo.Name.L, - TableName: newTableInfo.Name.L, - Type: model.ActionRepairTable, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{newTableInfo}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - err = e.DoDDLJob(ctx, job) - if err == nil { - // Remove the old TableInfo from repairInfo before domain reload. - domainutil.RepairInfo.RemoveFromRepairInfo(oldDBInfo.Name.L, oldTableInfo.Name.L) - } - return errors.Trace(err) -} - -func (e *executor) OrderByColumns(ctx sessionctx.Context, ident ast.Ident) error { - _, tb, err := e.getSchemaAndTableByIdent(ident) - if err != nil { - return errors.Trace(err) - } - if tb.Meta().GetPkColInfo() != nil { - ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("ORDER BY ignored as there is a user-defined clustered index in the table '%s'", ident.Name)) - } - return nil -} - -func (e *executor) CreateSequence(ctx sessionctx.Context, stmt *ast.CreateSequenceStmt) error { - ident := ast.Ident{Name: stmt.Name.Name, Schema: stmt.Name.Schema} - sequenceInfo, err := buildSequenceInfo(stmt, ident) - if err != nil { - return err - } - // TiDB describe the sequence within a tableInfo, as a same-level object of a table and view. - tbInfo, err := BuildTableInfo(ctx, ident.Name, nil, nil, "", "") - if err != nil { - return err - } - tbInfo.Sequence = sequenceInfo - - onExist := OnExistError - if stmt.IfNotExists { - onExist = OnExistIgnore - } - - return e.CreateTableWithInfo(ctx, ident.Schema, tbInfo, nil, WithOnExist(onExist)) -} - -func (e *executor) AlterSequence(ctx sessionctx.Context, stmt *ast.AlterSequenceStmt) error { - ident := ast.Ident{Name: stmt.Name.Name, Schema: stmt.Name.Schema} - is := e.infoCache.GetLatest() - // Check schema existence. - db, ok := is.SchemaByName(ident.Schema) - if !ok { - return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(ident.Schema) - } - // Check table existence. - tbl, err := is.TableByName(context.Background(), ident.Schema, ident.Name) - if err != nil { - if stmt.IfExists { - ctx.GetSessionVars().StmtCtx.AppendNote(err) - return nil - } - return err - } - if !tbl.Meta().IsSequence() { - return dbterror.ErrWrongObject.GenWithStackByArgs(ident.Schema, ident.Name, "SEQUENCE") - } - - // Validate the new sequence option value in old sequenceInfo. 
-	oldSequenceInfo := tbl.Meta().Sequence
-	copySequenceInfo := *oldSequenceInfo
-	_, _, err = alterSequenceOptions(stmt.SeqOptions, ident, &copySequenceInfo)
-	if err != nil {
-		return err
-	}
-
-	job := &model.Job{
-		SchemaID: db.ID,
-		TableID: tbl.Meta().ID,
-		SchemaName: db.Name.L,
-		TableName: tbl.Meta().Name.L,
-		Type: model.ActionAlterSequence,
-		BinlogInfo: &model.HistoryInfo{},
-		Args: []any{ident, stmt.SeqOptions},
-		CDCWriteSource: ctx.GetSessionVars().CDCWriteSource,
-		SQLMode: ctx.GetSessionVars().SQLMode,
-	}
-
-	err = e.DoDDLJob(ctx, job)
-	return errors.Trace(err)
-}
-
-func (e *executor) DropSequence(ctx sessionctx.Context, stmt *ast.DropSequenceStmt) (err error) {
-	return e.dropTableObject(ctx, stmt.Sequences, stmt.IfExists, sequenceObject)
-}
-
-func (e *executor) AlterIndexVisibility(ctx sessionctx.Context, ident ast.Ident, indexName model.CIStr, visibility ast.IndexVisibility) error {
-	schema, tb, err := e.getSchemaAndTableByIdent(ident)
-	if err != nil {
-		return err
-	}
-
-	invisible := false
-	if visibility == ast.IndexVisibilityInvisible {
-		invisible = true
-	}
-
-	skip, err := validateAlterIndexVisibility(ctx, indexName, invisible, tb.Meta())
-	if err != nil {
-		return errors.Trace(err)
-	}
-	if skip {
-		return nil
-	}
-
-	job := &model.Job{
-		SchemaID: schema.ID,
-		TableID: tb.Meta().ID,
-		SchemaName: schema.Name.L,
-		TableName: tb.Meta().Name.L,
-		Type: model.ActionAlterIndexVisibility,
-		BinlogInfo: &model.HistoryInfo{},
-		Args: []any{indexName, invisible},
-		CDCWriteSource: ctx.GetSessionVars().CDCWriteSource,
-		SQLMode: ctx.GetSessionVars().SQLMode,
-	}
-
-	err = e.DoDDLJob(ctx, job)
-	return errors.Trace(err)
-}
-
-func (e *executor) AlterTableAttributes(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error {
-	schema, tb, err := e.getSchemaAndTableByIdent(ident)
-	if err != nil {
-		return errors.Trace(err)
-	}
-	meta := tb.Meta()
-
-	rule := label.NewRule()
-	err = rule.ApplyAttributesSpec(spec.AttributesSpec)
-	if err != nil {
-		return dbterror.ErrInvalidAttributesSpec.GenWithStackByArgs(err)
-	}
-	ids := getIDs([]*model.TableInfo{meta})
-	rule.Reset(schema.Name.L, meta.Name.L, "", ids...)
- - job := &model.Job{ - SchemaID: schema.ID, - TableID: meta.ID, - SchemaName: schema.Name.L, - TableName: meta.Name.L, - Type: model.ActionAlterTableAttributes, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{rule}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - if err != nil { - return errors.Trace(err) - } - - return errors.Trace(err) -} - -func (e *executor) AlterTablePartitionAttributes(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) (err error) { - schema, tb, err := e.getSchemaAndTableByIdent(ident) - if err != nil { - return errors.Trace(err) - } - - meta := tb.Meta() - if meta.Partition == nil { - return errors.Trace(dbterror.ErrPartitionMgmtOnNonpartitioned) - } - - partitionID, err := tables.FindPartitionByName(meta, spec.PartitionNames[0].L) - if err != nil { - return errors.Trace(err) - } - - rule := label.NewRule() - err = rule.ApplyAttributesSpec(spec.AttributesSpec) - if err != nil { - return dbterror.ErrInvalidAttributesSpec.GenWithStackByArgs(err) - } - rule.Reset(schema.Name.L, meta.Name.L, spec.PartitionNames[0].L, partitionID) - - job := &model.Job{ - SchemaID: schema.ID, - TableID: meta.ID, - SchemaName: schema.Name.L, - TableName: meta.Name.L, - Type: model.ActionAlterTablePartitionAttributes, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{partitionID, rule}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - if err != nil { - return errors.Trace(err) - } - - return errors.Trace(err) -} - -func (e *executor) AlterTablePartitionOptions(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) (err error) { - var policyRefInfo *model.PolicyRefInfo - if spec.Options != nil { - for _, op := range spec.Options { - switch op.Tp { - case ast.TableOptionPlacementPolicy: - policyRefInfo = &model.PolicyRefInfo{ - Name: model.NewCIStr(op.StrValue), - } - default: - return errors.Trace(errors.New("unknown partition option")) - } - } - } - - if policyRefInfo != nil { - err = e.AlterTablePartitionPlacement(ctx, ident, spec, policyRefInfo) - if err != nil { - return errors.Trace(err) - } - } - - return nil -} - -func (e *executor) AlterTablePartitionPlacement(ctx sessionctx.Context, tableIdent ast.Ident, spec *ast.AlterTableSpec, policyRefInfo *model.PolicyRefInfo) (err error) { - schema, tb, err := e.getSchemaAndTableByIdent(tableIdent) - if err != nil { - return errors.Trace(err) - } - - tblInfo := tb.Meta() - if tblInfo.Partition == nil { - return errors.Trace(dbterror.ErrPartitionMgmtOnNonpartitioned) - } - - partitionID, err := tables.FindPartitionByName(tblInfo, spec.PartitionNames[0].L) - if err != nil { - return errors.Trace(err) - } - - if checkIgnorePlacementDDL(ctx) { - return nil - } - - policyRefInfo, err = checkAndNormalizePlacementPolicy(ctx, policyRefInfo) - if err != nil { - return errors.Trace(err) - } - - var involveSchemaInfo []model.InvolvingSchemaInfo - if policyRefInfo != nil { - involveSchemaInfo = []model.InvolvingSchemaInfo{ - { - Database: schema.Name.L, - Table: tblInfo.Name.L, - }, - { - Policy: policyRefInfo.Name.L, - Mode: model.SharedInvolving, - }, - } - } - - job := &model.Job{ - SchemaID: schema.ID, - TableID: tblInfo.ID, - SchemaName: schema.Name.L, - TableName: tblInfo.Name.L, - Type: model.ActionAlterTablePartitionPlacement, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{partitionID, policyRefInfo}, - CDCWriteSource: 
ctx.GetSessionVars().CDCWriteSource, - InvolvingSchemaInfo: involveSchemaInfo, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -// AddResourceGroup implements the DDL interface, creates a resource group. -func (e *executor) AddResourceGroup(ctx sessionctx.Context, stmt *ast.CreateResourceGroupStmt) (err error) { - groupName := stmt.ResourceGroupName - groupInfo := &model.ResourceGroupInfo{Name: groupName, ResourceGroupSettings: model.NewResourceGroupSettings()} - groupInfo, err = buildResourceGroup(groupInfo, stmt.ResourceGroupOptionList) - if err != nil { - return err - } - - if _, ok := e.infoCache.GetLatest().ResourceGroupByName(groupName); ok { - if stmt.IfNotExists { - err = infoschema.ErrResourceGroupExists.FastGenByArgs(groupName) - ctx.GetSessionVars().StmtCtx.AppendNote(err) - return nil - } - return infoschema.ErrResourceGroupExists.GenWithStackByArgs(groupName) - } - - if err := checkResourceGroupValidation(groupInfo); err != nil { - return err - } - - logutil.DDLLogger().Debug("create resource group", zap.String("name", groupName.O), zap.Stringer("resource group settings", groupInfo.ResourceGroupSettings)) - groupIDs, err := e.genGlobalIDs(1) - if err != nil { - return err - } - groupInfo.ID = groupIDs[0] - - job := &model.Job{ - SchemaName: groupName.L, - Type: model.ActionCreateResourceGroup, - BinlogInfo: &model.HistoryInfo{}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - Args: []any{groupInfo, false}, - InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ - ResourceGroup: groupInfo.Name.L, - }}, - SQLMode: ctx.GetSessionVars().SQLMode, - } - err = e.DoDDLJob(ctx, job) - return err -} - -// DropResourceGroup implements the DDL interface. -func (e *executor) DropResourceGroup(ctx sessionctx.Context, stmt *ast.DropResourceGroupStmt) (err error) { - groupName := stmt.ResourceGroupName - if groupName.L == rg.DefaultResourceGroupName { - return resourcegroup.ErrDroppingInternalResourceGroup - } - is := e.infoCache.GetLatest() - // Check group existence. - group, ok := is.ResourceGroupByName(groupName) - if !ok { - err = infoschema.ErrResourceGroupNotExists.GenWithStackByArgs(groupName) - if stmt.IfExists { - ctx.GetSessionVars().StmtCtx.AppendNote(err) - return nil - } - return err - } - - // check to see if some user has dependency on the group - checker := privilege.GetPrivilegeManager(ctx) - if checker == nil { - return errors.New("miss privilege checker") - } - user, matched := checker.MatchUserResourceGroupName(groupName.L) - if matched { - err = errors.Errorf("user [%s] depends on the resource group to drop", user) - return err - } - - job := &model.Job{ - SchemaID: group.ID, - SchemaName: group.Name.L, - Type: model.ActionDropResourceGroup, - BinlogInfo: &model.HistoryInfo{}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - Args: []any{groupName}, - InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ - ResourceGroup: groupName.L, - }}, - SQLMode: ctx.GetSessionVars().SQLMode, - } - err = e.DoDDLJob(ctx, job) - return err -} - -// AlterResourceGroup implements the DDL interface. -func (e *executor) AlterResourceGroup(ctx sessionctx.Context, stmt *ast.AlterResourceGroupStmt) (err error) { - groupName := stmt.ResourceGroupName - is := e.infoCache.GetLatest() - // Check group existence. 
- group, ok := is.ResourceGroupByName(groupName) - if !ok { - err := infoschema.ErrResourceGroupNotExists.GenWithStackByArgs(groupName) - if stmt.IfExists { - ctx.GetSessionVars().StmtCtx.AppendNote(err) - return nil - } - return err - } - newGroupInfo, err := buildResourceGroup(group, stmt.ResourceGroupOptionList) - if err != nil { - return errors.Trace(err) - } - - if err := checkResourceGroupValidation(newGroupInfo); err != nil { - return err - } - - logutil.DDLLogger().Debug("alter resource group", zap.String("name", groupName.L), zap.Stringer("new resource group settings", newGroupInfo.ResourceGroupSettings)) - - job := &model.Job{ - SchemaID: newGroupInfo.ID, - SchemaName: newGroupInfo.Name.L, - Type: model.ActionAlterResourceGroup, - BinlogInfo: &model.HistoryInfo{}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - Args: []any{newGroupInfo}, - InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ - ResourceGroup: newGroupInfo.Name.L, - }}, - SQLMode: ctx.GetSessionVars().SQLMode, - } - err = e.DoDDLJob(ctx, job) - return err -} - -func (e *executor) CreatePlacementPolicy(ctx sessionctx.Context, stmt *ast.CreatePlacementPolicyStmt) (err error) { - if checkIgnorePlacementDDL(ctx) { - return nil - } - - if stmt.OrReplace && stmt.IfNotExists { - return dbterror.ErrWrongUsage.GenWithStackByArgs("OR REPLACE", "IF NOT EXISTS") - } - - policyInfo, err := buildPolicyInfo(stmt.PolicyName, stmt.PlacementOptions) - if err != nil { - return errors.Trace(err) - } - - var onExists OnExist - switch { - case stmt.IfNotExists: - onExists = OnExistIgnore - case stmt.OrReplace: - onExists = OnExistReplace - default: - onExists = OnExistError - } - - return e.CreatePlacementPolicyWithInfo(ctx, policyInfo, onExists) -} - -func (e *executor) DropPlacementPolicy(ctx sessionctx.Context, stmt *ast.DropPlacementPolicyStmt) (err error) { - if checkIgnorePlacementDDL(ctx) { - return nil - } - policyName := stmt.PolicyName - is := e.infoCache.GetLatest() - // Check policy existence. - policy, ok := is.PolicyByName(policyName) - if !ok { - err = infoschema.ErrPlacementPolicyNotExists.GenWithStackByArgs(policyName) - if stmt.IfExists { - ctx.GetSessionVars().StmtCtx.AppendNote(err) - return nil - } - return err - } - - if err = CheckPlacementPolicyNotInUseFromInfoSchema(is, policy); err != nil { - return err - } - - job := &model.Job{ - SchemaID: policy.ID, - SchemaName: policy.Name.L, - Type: model.ActionDropPlacementPolicy, - BinlogInfo: &model.HistoryInfo{}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - Args: []any{policyName}, - InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ - Policy: policyName.L, - }}, - SQLMode: ctx.GetSessionVars().SQLMode, - } - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -func (e *executor) AlterPlacementPolicy(ctx sessionctx.Context, stmt *ast.AlterPlacementPolicyStmt) (err error) { - if checkIgnorePlacementDDL(ctx) { - return nil - } - policyName := stmt.PolicyName - is := e.infoCache.GetLatest() - // Check policy existence. 
- policy, ok := is.PolicyByName(policyName) - if !ok { - return infoschema.ErrPlacementPolicyNotExists.GenWithStackByArgs(policyName) - } - - newPolicyInfo, err := buildPolicyInfo(policy.Name, stmt.PlacementOptions) - if err != nil { - return errors.Trace(err) - } - - err = checkPolicyValidation(newPolicyInfo.PlacementSettings) - if err != nil { - return err - } - - job := &model.Job{ - SchemaID: policy.ID, - SchemaName: policy.Name.L, - Type: model.ActionAlterPlacementPolicy, - BinlogInfo: &model.HistoryInfo{}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - Args: []any{newPolicyInfo}, - InvolvingSchemaInfo: []model.InvolvingSchemaInfo{{ - Policy: newPolicyInfo.Name.L, - }}, - SQLMode: ctx.GetSessionVars().SQLMode, - } - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -func (e *executor) AlterTableCache(sctx sessionctx.Context, ti ast.Ident) (err error) { - schema, t, err := e.getSchemaAndTableByIdent(ti) - if err != nil { - return err - } - // if a table is already in cache state, return directly - if t.Meta().TableCacheStatusType == model.TableCacheStatusEnable { - return nil - } - - // forbidden cache table in system database. - if util.IsMemOrSysDB(schema.Name.L) { - return errors.Trace(dbterror.ErrUnsupportedAlterCacheForSysTable) - } else if t.Meta().TempTableType != model.TempTableNone { - return dbterror.ErrOptOnTemporaryTable.GenWithStackByArgs("alter temporary table cache") - } - - if t.Meta().Partition != nil { - return dbterror.ErrOptOnCacheTable.GenWithStackByArgs("partition mode") - } - - succ, err := checkCacheTableSize(e.store, t.Meta().ID) - if err != nil { - return errors.Trace(err) - } - if !succ { - return dbterror.ErrOptOnCacheTable.GenWithStackByArgs("table too large") - } - - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) - ddlQuery, _ := sctx.Value(sessionctx.QueryString).(string) - // Initialize the cached table meta lock info in `mysql.table_cache_meta`. - // The operation shouldn't fail in most cases, and if it does, return the error directly. - // This DML and the following DDL is not atomic, that's not a problem. 
- _, _, err = sctx.GetRestrictedSQLExecutor().ExecRestrictedSQL(ctx, nil, - "replace into mysql.table_cache_meta values (%?, 'NONE', 0, 0)", t.Meta().ID) - if err != nil { - return errors.Trace(err) - } - - sctx.SetValue(sessionctx.QueryString, ddlQuery) - - job := &model.Job{ - SchemaID: schema.ID, - SchemaName: schema.Name.L, - TableName: t.Meta().Name.L, - TableID: t.Meta().ID, - Type: model.ActionAlterCacheTable, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{}, - CDCWriteSource: sctx.GetSessionVars().CDCWriteSource, - SQLMode: sctx.GetSessionVars().SQLMode, - } - - return e.DoDDLJob(sctx, job) -} - -func checkCacheTableSize(store kv.Storage, tableID int64) (bool, error) { - const cacheTableSizeLimit = 64 * (1 << 20) // 64M - succ := true - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnCacheTable) - err := kv.RunInNewTxn(ctx, store, true, func(_ context.Context, txn kv.Transaction) error { - txn.SetOption(kv.RequestSourceType, kv.InternalTxnCacheTable) - prefix := tablecodec.GenTablePrefix(tableID) - it, err := txn.Iter(prefix, prefix.PrefixNext()) - if err != nil { - return errors.Trace(err) - } - defer it.Close() - - totalSize := 0 - for it.Valid() && it.Key().HasPrefix(prefix) { - key := it.Key() - value := it.Value() - totalSize += len(key) - totalSize += len(value) - - if totalSize > cacheTableSizeLimit { - succ = false - break - } - - err = it.Next() - if err != nil { - return errors.Trace(err) - } - } - return nil - }) - return succ, err -} - -func (e *executor) AlterTableNoCache(ctx sessionctx.Context, ti ast.Ident) (err error) { - schema, t, err := e.getSchemaAndTableByIdent(ti) - if err != nil { - return err - } - // if a table is not in cache state, return directly - if t.Meta().TableCacheStatusType == model.TableCacheStatusDisable { - return nil - } - - job := &model.Job{ - SchemaID: schema.ID, - SchemaName: schema.Name.L, - TableName: t.Meta().Name.L, - TableID: t.Meta().ID, - Type: model.ActionAlterNoCacheTable, - BinlogInfo: &model.HistoryInfo{}, - Args: []any{}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - return e.DoDDLJob(ctx, job) -} - -func (e *executor) CreateCheckConstraint(ctx sessionctx.Context, ti ast.Ident, constrName model.CIStr, constr *ast.Constraint) error { - schema, t, err := e.getSchemaAndTableByIdent(ti) - if err != nil { - return errors.Trace(err) - } - if constraintInfo := t.Meta().FindConstraintInfoByName(constrName.L); constraintInfo != nil { - return infoschema.ErrCheckConstraintDupName.GenWithStackByArgs(constrName.L) - } - - // allocate the temporary constraint name for dependency-check-error-output below. - constrNames := map[string]bool{} - for _, constr := range t.Meta().Constraints { - constrNames[constr.Name.L] = true - } - setEmptyCheckConstraintName(t.Meta().Name.L, constrNames, []*ast.Constraint{constr}) - - // existedColsMap can be used to check the existence of depended. - existedColsMap := make(map[string]struct{}) - cols := t.Cols() - for _, v := range cols { - existedColsMap[v.Name.L] = struct{}{} - } - // check expression if supported - if ok, err := table.IsSupportedExpr(constr); !ok { - return err - } - - dependedColsMap := findDependentColsInExpr(constr.Expr) - dependedCols := make([]model.CIStr, 0, len(dependedColsMap)) - for k := range dependedColsMap { - if _, ok := existedColsMap[k]; !ok { - // The table constraint depended on a non-existed column. 
- return dbterror.ErrBadField.GenWithStackByArgs(k, "check constraint "+constr.Name+" expression") - } - dependedCols = append(dependedCols, model.NewCIStr(k)) - } - - // build constraint meta info. - tblInfo := t.Meta() - - // check auto-increment column - if table.ContainsAutoIncrementCol(dependedCols, tblInfo) { - return dbterror.ErrCheckConstraintRefersAutoIncrementColumn.GenWithStackByArgs(constr.Name) - } - // check foreign key - if err := table.HasForeignKeyRefAction(tblInfo.ForeignKeys, nil, constr, dependedCols); err != nil { - return err - } - constraintInfo, err := buildConstraintInfo(tblInfo, dependedCols, constr, model.StateNone) - if err != nil { - return errors.Trace(err) - } - // check if the expression is bool type - if err := table.IfCheckConstraintExprBoolType(ctx.GetExprCtx().GetEvalCtx(), constraintInfo, tblInfo); err != nil { - return err - } - job := &model.Job{ - SchemaID: schema.ID, - TableID: tblInfo.ID, - SchemaName: schema.Name.L, - TableName: tblInfo.Name.L, - Type: model.ActionAddCheckConstraint, - BinlogInfo: &model.HistoryInfo{}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - Args: []any{constraintInfo}, - Priority: ctx.GetSessionVars().DDLReorgPriority, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -func (e *executor) DropCheckConstraint(ctx sessionctx.Context, ti ast.Ident, constrName model.CIStr) error { - is := e.infoCache.GetLatest() - schema, ok := is.SchemaByName(ti.Schema) - if !ok { - return errors.Trace(infoschema.ErrDatabaseNotExists) - } - t, err := is.TableByName(context.Background(), ti.Schema, ti.Name) - if err != nil { - return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ti.Schema, ti.Name)) - } - tblInfo := t.Meta() - - constraintInfo := tblInfo.FindConstraintInfoByName(constrName.L) - if constraintInfo == nil { - return dbterror.ErrConstraintNotFound.GenWithStackByArgs(constrName) - } - - job := &model.Job{ - SchemaID: schema.ID, - TableID: tblInfo.ID, - SchemaName: schema.Name.L, - TableName: tblInfo.Name.L, - Type: model.ActionDropCheckConstraint, - BinlogInfo: &model.HistoryInfo{}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - Args: []any{constrName}, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -func (e *executor) AlterCheckConstraint(ctx sessionctx.Context, ti ast.Ident, constrName model.CIStr, enforced bool) error { - is := e.infoCache.GetLatest() - schema, ok := is.SchemaByName(ti.Schema) - if !ok { - return errors.Trace(infoschema.ErrDatabaseNotExists) - } - t, err := is.TableByName(context.Background(), ti.Schema, ti.Name) - if err != nil { - return errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(ti.Schema, ti.Name)) - } - tblInfo := t.Meta() - - constraintInfo := tblInfo.FindConstraintInfoByName(constrName.L) - if constraintInfo == nil { - return dbterror.ErrConstraintNotFound.GenWithStackByArgs(constrName) - } - - job := &model.Job{ - SchemaID: schema.ID, - TableID: tblInfo.ID, - SchemaName: schema.Name.L, - TableName: tblInfo.Name.L, - Type: model.ActionAlterCheckConstraint, - BinlogInfo: &model.HistoryInfo{}, - CDCWriteSource: ctx.GetSessionVars().CDCWriteSource, - Args: []any{constrName, enforced}, - SQLMode: ctx.GetSessionVars().SQLMode, - } - - err = e.DoDDLJob(ctx, job) - return errors.Trace(err) -} - -func (e *executor) genGlobalIDs(count int) ([]int64, error) { - var ret []int64 - ctx := kv.WithInternalSourceType(context.Background(), 
kv.InternalTxnDDL) - // lock to reduce conflict - e.globalIDLock.Lock() - defer e.globalIDLock.Unlock() - err := kv.RunInNewTxn(ctx, e.store, true, func(_ context.Context, txn kv.Transaction) error { - m := meta.NewMeta(txn) - var err error - ret, err = m.GenGlobalIDs(count) - return err - }) - - return ret, err -} - -func (e *executor) genPlacementPolicyID() (int64, error) { - var ret int64 - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) - err := kv.RunInNewTxn(ctx, e.store, true, func(_ context.Context, txn kv.Transaction) error { - m := meta.NewMeta(txn) - var err error - ret, err = m.GenPlacementPolicyID() - return err - }) - - return ret, err -} - -// DoDDLJob will return -// - nil: found in history DDL job and no job error -// - context.Cancel: job has been sent to worker, but not found in history DDL job before cancel -// - other: found in history DDL job and return that job error -func (e *executor) DoDDLJob(ctx sessionctx.Context, job *model.Job) error { - return e.DoDDLJobWrapper(ctx, NewJobWrapper(job, false)) -} - -// DoDDLJobWrapper submit DDL job and wait it finishes. -// When fast create is enabled, we might merge multiple jobs into one, so do not -// depend on job.ID, use JobID from jobSubmitResult. -func (e *executor) DoDDLJobWrapper(ctx sessionctx.Context, jobW *JobWrapper) error { - job := jobW.Job - job.TraceInfo = &model.TraceInfo{ - ConnectionID: ctx.GetSessionVars().ConnectionID, - SessionAlias: ctx.GetSessionVars().SessionAlias, - } - if mci := ctx.GetSessionVars().StmtCtx.MultiSchemaInfo; mci != nil { - // In multiple schema change, we don't run the job. - // Instead, we merge all the jobs into one pending job. - return appendToSubJobs(mci, job) - } - // Get a global job ID and put the DDL job in the queue. - setDDLJobQuery(ctx, job) - e.deliverJobTask(jobW) - - failpoint.Inject("mockParallelSameDDLJobTwice", func(val failpoint.Value) { - if val.(bool) { - <-jobW.ResultCh[0] - // The same job will be put to the DDL queue twice. - job = job.Clone() - newJobW := NewJobWrapper(job, jobW.IDAllocated) - e.deliverJobTask(newJobW) - // The second job result is used for test. - jobW = newJobW - } - }) - - // worker should restart to continue handling tasks in limitJobCh, and send back through jobW.err - result := <-jobW.ResultCh[0] - // job.ID must be allocated after previous channel receive returns nil. - jobID, err := result.jobID, result.err - defer e.delJobDoneCh(jobID) - if err != nil { - // The transaction of enqueuing job is failed. - return errors.Trace(err) - } - failpoint.InjectCall("waitJobSubmitted") - - sessVars := ctx.GetSessionVars() - sessVars.StmtCtx.IsDDLJobInQueue = true - - ddlAction := job.Type - // Notice worker that we push a new job and wait the job done. - e.notifyNewJobSubmitted(e.ddlJobNotifyCh, addingDDLJobNotifyKey, jobID, ddlAction.String()) - if result.merged { - logutil.DDLLogger().Info("DDL job submitted", zap.Int64("job_id", jobID), zap.String("query", job.Query), zap.String("merged", "true")) - } else { - logutil.DDLLogger().Info("DDL job submitted", zap.Stringer("job", job), zap.String("query", job.Query)) - } - - var historyJob *model.Job - - // Attach the context of the jobId to the calling session so that - // KILL can cancel this DDL job. 
- ctx.GetSessionVars().StmtCtx.DDLJobID = jobID - - // For a job from start to end, the state of it will be none -> delete only -> write only -> reorganization -> public - // For every state changes, we will wait as lease 2 * lease time, so here the ticker check is 10 * lease. - // But we use etcd to speed up, normally it takes less than 0.5s now, so we use 0.5s or 1s or 3s as the max value. - initInterval, _ := getJobCheckInterval(ddlAction, 0) - ticker := time.NewTicker(chooseLeaseTime(10*e.lease, initInterval)) - startTime := time.Now() - metrics.JobsGauge.WithLabelValues(ddlAction.String()).Inc() - defer func() { - ticker.Stop() - metrics.JobsGauge.WithLabelValues(ddlAction.String()).Dec() - metrics.HandleJobHistogram.WithLabelValues(ddlAction.String(), metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) - recordLastDDLInfo(ctx, historyJob) - }() - i := 0 - notifyCh, _ := e.getJobDoneCh(jobID) - for { - failpoint.InjectCall("storeCloseInLoop") - select { - case _, ok := <-notifyCh: - if !ok { - // when fast create enabled, jobs might be merged, and we broadcast - // the result by closing the channel, to avoid this loop keeps running - // without sleeping on retryable error, we set it to nil. - notifyCh = nil - } - case <-ticker.C: - i++ - ticker = updateTickerInterval(ticker, 10*e.lease, ddlAction, i) - case <-e.ctx.Done(): - logutil.DDLLogger().Info("DoDDLJob will quit because context done") - return context.Canceled - } - - // If the connection being killed, we need to CANCEL the DDL job. - if sessVars.SQLKiller.HandleSignal() == exeerrors.ErrQueryInterrupted { - if atomic.LoadInt32(&sessVars.ConnectionStatus) == variable.ConnStatusShutdown { - logutil.DDLLogger().Info("DoDDLJob will quit because context done") - return context.Canceled - } - if sessVars.StmtCtx.DDLJobID != 0 { - se, err := e.sessPool.Get() - if err != nil { - logutil.DDLLogger().Error("get session failed, check again", zap.Error(err)) - continue - } - sessVars.StmtCtx.DDLJobID = 0 // Avoid repeat. - errs, err := CancelJobsBySystem(se, []int64{jobID}) - e.sessPool.Put(se) - if len(errs) > 0 { - logutil.DDLLogger().Warn("error canceling DDL job", zap.Error(errs[0])) - } - if err != nil { - logutil.DDLLogger().Warn("Kill command could not cancel DDL job", zap.Error(err)) - continue - } - } - } - - se, err := e.sessPool.Get() - if err != nil { - logutil.DDLLogger().Error("get session failed, check again", zap.Error(err)) - continue - } - historyJob, err = GetHistoryJobByID(se, jobID) - e.sessPool.Put(se) - if err != nil { - logutil.DDLLogger().Error("get history DDL job failed, check again", zap.Error(err)) - continue - } - if historyJob == nil { - logutil.DDLLogger().Debug("DDL job is not in history, maybe not run", zap.Int64("jobID", jobID)) - continue - } - - e.checkHistoryJobInTest(ctx, historyJob) - - // If a job is a history job, the state must be JobStateSynced or JobStateRollbackDone or JobStateCancelled. - if historyJob.IsSynced() { - // Judge whether there are some warnings when executing DDL under the certain SQL mode. 
- if historyJob.ReorgMeta != nil && len(historyJob.ReorgMeta.Warnings) != 0 { - if len(historyJob.ReorgMeta.Warnings) != len(historyJob.ReorgMeta.WarningsCount) { - logutil.DDLLogger().Info("DDL warnings doesn't match the warnings count", zap.Int64("jobID", jobID)) - } else { - for key, warning := range historyJob.ReorgMeta.Warnings { - keyCount := historyJob.ReorgMeta.WarningsCount[key] - if keyCount == 1 { - ctx.GetSessionVars().StmtCtx.AppendWarning(warning) - } else { - newMsg := fmt.Sprintf("%d warnings with this error code, first warning: "+warning.GetMsg(), keyCount) - newWarning := dbterror.ClassTypes.Synthesize(terror.ErrCode(warning.Code()), newMsg) - ctx.GetSessionVars().StmtCtx.AppendWarning(newWarning) - } - } - } - } - appendMultiChangeWarningsToOwnerCtx(ctx, historyJob) - - logutil.DDLLogger().Info("DDL job is finished", zap.Int64("jobID", jobID)) - return nil - } - - if historyJob.Error != nil { - logutil.DDLLogger().Info("DDL job is failed", zap.Int64("jobID", jobID)) - return errors.Trace(historyJob.Error) - } - panic("When the state is JobStateRollbackDone or JobStateCancelled, historyJob.Error should never be nil") - } -} - -func (e *executor) getJobDoneCh(jobID int64) (chan struct{}, bool) { - return e.ddlJobDoneChMap.Load(jobID) -} - -func (e *executor) delJobDoneCh(jobID int64) { - e.ddlJobDoneChMap.Delete(jobID) -} - -func (e *executor) deliverJobTask(task *JobWrapper) { - // TODO this might block forever, as the consumer part considers context cancel. - e.limitJobCh <- task -} - -func updateTickerInterval(ticker *time.Ticker, lease time.Duration, action model.ActionType, i int) *time.Ticker { - interval, changed := getJobCheckInterval(action, i) - if !changed { - return ticker - } - // For now we should stop old ticker and create a new ticker - ticker.Stop() - return time.NewTicker(chooseLeaseTime(lease, interval)) -} - -func recordLastDDLInfo(ctx sessionctx.Context, job *model.Job) { - if job == nil { - return - } - ctx.GetSessionVars().LastDDLInfo.Query = job.Query - ctx.GetSessionVars().LastDDLInfo.SeqNum = job.SeqNum -} - -func setDDLJobQuery(ctx sessionctx.Context, job *model.Job) { - switch job.Type { - case model.ActionUpdateTiFlashReplicaStatus, model.ActionUnlockTable: - job.Query = "" - default: - job.Query, _ = ctx.Value(sessionctx.QueryString).(string) - } -} - -var ( - fastDDLIntervalPolicy = []time.Duration{ - 500 * time.Millisecond, - } - normalDDLIntervalPolicy = []time.Duration{ - 500 * time.Millisecond, - 500 * time.Millisecond, - 1 * time.Second, - } - slowDDLIntervalPolicy = []time.Duration{ - 500 * time.Millisecond, - 500 * time.Millisecond, - 1 * time.Second, - 1 * time.Second, - 3 * time.Second, - } -) - -func getIntervalFromPolicy(policy []time.Duration, i int) (time.Duration, bool) { - plen := len(policy) - if i < plen { - return policy[i], true - } - return policy[plen-1], false -} - -func getJobCheckInterval(action model.ActionType, i int) (time.Duration, bool) { - switch action { - case model.ActionAddIndex, model.ActionAddPrimaryKey, model.ActionModifyColumn, - model.ActionReorganizePartition, - model.ActionRemovePartitioning, - model.ActionAlterTablePartitioning: - return getIntervalFromPolicy(slowDDLIntervalPolicy, i) - case model.ActionCreateTable, model.ActionCreateSchema: - return getIntervalFromPolicy(fastDDLIntervalPolicy, i) - default: - return getIntervalFromPolicy(normalDDLIntervalPolicy, i) - } -} - -// NewDDLReorgMeta create a DDL ReorgMeta. 
-func NewDDLReorgMeta(ctx sessionctx.Context) *model.DDLReorgMeta { - tzName, tzOffset := ddlutil.GetTimeZone(ctx) - return &model.DDLReorgMeta{ - SQLMode: ctx.GetSessionVars().SQLMode, - Warnings: make(map[errors.ErrorID]*terror.Error), - WarningsCount: make(map[errors.ErrorID]int64), - Location: &model.TimeZoneLocation{Name: tzName, Offset: tzOffset}, - ResourceGroupName: ctx.GetSessionVars().StmtCtx.ResourceGroupName, - Version: model.CurrentReorgMetaVersion, - } -} diff --git a/pkg/ddl/index.go b/pkg/ddl/index.go index 168022effcbdf..d9b2c97d50903 100644 --- a/pkg/ddl/index.go +++ b/pkg/ddl/index.go @@ -791,7 +791,7 @@ SwitchIndexState: } // Inject the failpoint to prevent the progress of index creation. - if v, _err_ := failpoint.Eval(_curpkg_("create-index-stuck-before-public")); _err_ == nil { + failpoint.Inject("create-index-stuck-before-public", func(v failpoint.Value) { if sigFile, ok := v.(string); ok { for { time.Sleep(1 * time.Second) @@ -799,12 +799,12 @@ SwitchIndexState: if os.IsNotExist(err) { continue } - return ver, errors.Trace(err) + failpoint.Return(ver, errors.Trace(err)) } break } } - } + }) ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != model.StatePublic) if err != nil { @@ -931,11 +931,11 @@ func doReorgWorkForCreateIndex( ver, err = updateVersionAndTableInfo(d, t, job, tbl.Meta(), true) return false, ver, errors.Trace(err) case model.BackfillStateReadyToMerge: - if _, _err_ := failpoint.Eval(_curpkg_("mockDMLExecutionStateBeforeMerge")); _err_ == nil { + failpoint.Inject("mockDMLExecutionStateBeforeMerge", func(_ failpoint.Value) { if MockDMLExecutionStateBeforeMerge != nil { MockDMLExecutionStateBeforeMerge() } - } + }) logutil.DDLLogger().Info("index backfill state ready to merge", zap.Int64("job ID", job.ID), zap.String("table", tbl.Meta().Name.O), @@ -993,7 +993,7 @@ func runIngestReorgJob(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job, } return false, ver, errors.Trace(err) } - failpoint.Call(_curpkg_("afterRunIngestReorgJob"), job, done) + failpoint.InjectCall("afterRunIngestReorgJob", job, done) return done, ver, nil } @@ -1025,13 +1025,13 @@ func runReorgJobAndHandleErr( elements = append(elements, &meta.Element{ID: indexInfo.ID, TypeKey: meta.IndexElementKey}) } - if val, _err_ := failpoint.Eval(_curpkg_("mockDMLExecutionStateMerging")); _err_ == nil { + failpoint.Inject("mockDMLExecutionStateMerging", func(val failpoint.Value) { //nolint:forcetypeassert if val.(bool) && allIndexInfos[0].BackfillState == model.BackfillStateMerging && MockDMLExecutionStateMerging != nil { MockDMLExecutionStateMerging() } - } + }) sctx, err1 := w.sessPool.Get() if err1 != nil { @@ -1080,11 +1080,11 @@ func runReorgJobAndHandleErr( } return false, ver, errors.Trace(err) } - if _, _err_ := failpoint.Eval(_curpkg_("mockDMLExecutionStateBeforeImport")); _err_ == nil { + failpoint.Inject("mockDMLExecutionStateBeforeImport", func(_ failpoint.Value) { if MockDMLExecutionStateBeforeImport != nil { MockDMLExecutionStateBeforeImport() } - } + }) return true, ver, nil } @@ -1149,12 +1149,12 @@ func onDropIndex(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { idxIDs = append(idxIDs, indexInfo.ID) } - if val, _err_ := failpoint.Eval(_curpkg_("mockExceedErrorLimit")); _err_ == nil { + failpoint.Inject("mockExceedErrorLimit", func(val failpoint.Value) { //nolint:forcetypeassert if val.(bool) { panic("panic test in cancelling add index") } - } + }) ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, originalState != 
model.StateNone) if err != nil { @@ -1423,25 +1423,25 @@ var mockNotOwnerErrOnce uint32 // getIndexRecord gets index columns values use w.rowDecoder, and generate indexRecord. func (w *baseIndexWorker) getIndexRecord(idxInfo *model.IndexInfo, handle kv.Handle, recordKey []byte) (*indexRecord, error) { cols := w.table.WritableCols() - if val, _err_ := failpoint.Eval(_curpkg_("MockGetIndexRecordErr")); _err_ == nil { + failpoint.Inject("MockGetIndexRecordErr", func(val failpoint.Value) { if valStr, ok := val.(string); ok { switch valStr { case "cantDecodeRecordErr": - return nil, errors.Trace(dbterror.ErrCantDecodeRecord.GenWithStackByArgs("index", - errors.New("mock can't decode record error"))) + failpoint.Return(nil, errors.Trace(dbterror.ErrCantDecodeRecord.GenWithStackByArgs("index", + errors.New("mock can't decode record error")))) case "modifyColumnNotOwnerErr": if idxInfo.Name.O == "_Idx$_idx_0" && handle.IntValue() == 7168 && atomic.CompareAndSwapUint32(&mockNotOwnerErrOnce, 0, 1) { - return nil, errors.Trace(dbterror.ErrNotOwner) + failpoint.Return(nil, errors.Trace(dbterror.ErrNotOwner)) } case "addIdxNotOwnerErr": // For the case of the old TiDB version(do not exist the element information) is upgraded to the new TiDB version. // First step, we need to exit "addPhysicalTableIndex". if idxInfo.Name.O == "idx2" && handle.IntValue() == 6144 && atomic.CompareAndSwapUint32(&mockNotOwnerErrOnce, 1, 2) { - return nil, errors.Trace(dbterror.ErrNotOwner) + failpoint.Return(nil, errors.Trace(dbterror.ErrNotOwner)) } } } - } + }) idxVal := make([]types.Datum, len(idxInfo.Columns)) var err error for j, v := range idxInfo.Columns { @@ -1768,16 +1768,16 @@ func writeOneKVToLocal( if err != nil { return errors.Trace(err) } - if _, _err_ := failpoint.Eval(_curpkg_("mockLocalWriterPanic")); _err_ == nil { + failpoint.Inject("mockLocalWriterPanic", func() { panic("mock panic") - } + }) err = writer.WriteRow(ctx, key, idxVal, handle) if err != nil { return errors.Trace(err) } - if _, _err_ := failpoint.Eval(_curpkg_("mockLocalWriterError")); _err_ == nil { - return errors.New("mock engine error") - } + failpoint.Inject("mockLocalWriterError", func() { + failpoint.Return(errors.New("mock engine error")) + }) writeBufs.IndexKeyBuf = key writeBufs.RowValBuf = idxVal } @@ -1788,12 +1788,12 @@ func writeOneKVToLocal( // Note that index columns values may change, and an index is not allowed to be added, so the txn will rollback and retry. // BackfillData will add w.batchCnt indices once, default value of w.batchCnt is 128. 
func (w *addIndexTxnWorker) BackfillData(handleRange reorgBackfillTask) (taskCtx backfillTaskContext, errInTxn error) { - if val, _err_ := failpoint.Eval(_curpkg_("errorMockPanic")); _err_ == nil { + failpoint.Inject("errorMockPanic", func(val failpoint.Value) { //nolint:forcetypeassert if val.(bool) { panic("panic test") } - } + }) oprStartTime := time.Now() jobID := handleRange.getJobID() @@ -1858,12 +1858,12 @@ func (w *addIndexTxnWorker) BackfillData(handleRange reorgBackfillTask) (taskCtx return nil }) logSlowOperations(time.Since(oprStartTime), "AddIndexBackfillData", 3000) - if val, _err_ := failpoint.Eval(_curpkg_("mockDMLExecution")); _err_ == nil { + failpoint.Inject("mockDMLExecution", func(val failpoint.Value) { //nolint:forcetypeassert if val.(bool) && MockDMLExecution != nil { MockDMLExecution() } - } + }) return } @@ -1923,7 +1923,7 @@ func (w *worker) addTableIndex(t table.Table, reorgInfo *reorgInfo) error { if err != nil { return errors.Trace(err) } - failpoint.Call(_curpkg_("afterUpdatePartitionReorgInfo"), reorgInfo.Job) + failpoint.InjectCall("afterUpdatePartitionReorgInfo", reorgInfo.Job) // Every time we finish a partition, we update the progress of the job. if rc := w.getReorgCtx(reorgInfo.Job.ID); rc != nil { reorgInfo.Job.SetRowCount(rc.getRowCount()) @@ -2055,7 +2055,7 @@ func (w *worker) executeDistTask(t table.Table, reorgInfo *reorgInfo) error { g.Go(func() error { defer close(done) err := submitAndWaitTask(ctx, taskKey, taskType, concurrency, reorgInfo.ReorgMeta.TargetScope, metaData) - failpoint.Call(_curpkg_("pauseAfterDistTaskFinished")) + failpoint.InjectCall("pauseAfterDistTaskFinished") if err := w.isReorgRunnable(reorgInfo.Job.ID, true); err != nil { if dbterror.ErrPausedDDLJob.Equal(err) { logutil.DDLLogger().Warn("job paused by user", zap.Error(err)) @@ -2083,7 +2083,7 @@ func (w *worker) executeDistTask(t table.Table, reorgInfo *reorgInfo) error { logutil.DDLLogger().Error("pause task error", zap.String("task_key", taskKey), zap.Error(err)) continue } - failpoint.Call(_curpkg_("syncDDLTaskPause")) + failpoint.InjectCall("syncDDLTaskPause") } if !dbterror.ErrCancelledDDLJob.Equal(err) { return errors.Trace(err) @@ -2264,7 +2264,7 @@ func getNextPartitionInfo(reorg *reorgInfo, t table.PartitionedTable, currPhysic return 0, nil, nil, nil } - if val, _err_ := failpoint.Eval(_curpkg_("mockUpdateCachedSafePoint")); _err_ == nil { + failpoint.Inject("mockUpdateCachedSafePoint", func(val failpoint.Value) { //nolint:forcetypeassert if val.(bool) { ts := oracle.GoTimeToTS(time.Now()) @@ -2273,7 +2273,7 @@ func getNextPartitionInfo(reorg *reorgInfo, t table.PartitionedTable, currPhysic s.UpdateSPCache(ts, time.Now()) time.Sleep(time.Second * 3) } - } + }) var startKey, endKey kv.Key if reorg.mergingTmpIdx { @@ -2414,12 +2414,12 @@ func newCleanUpIndexWorker(id int, t table.PhysicalTable, decodeColMap map[int64 } func (w *cleanUpIndexWorker) BackfillData(handleRange reorgBackfillTask) (taskCtx backfillTaskContext, errInTxn error) { - if val, _err_ := failpoint.Eval(_curpkg_("errorMockPanic")); _err_ == nil { + failpoint.Inject("errorMockPanic", func(val failpoint.Value) { //nolint:forcetypeassert if val.(bool) { panic("panic test") } - } + }) oprStartTime := time.Now() ctx := kv.WithInternalSourceAndTaskType(context.Background(), w.jobContext.ddlJobSourceType(), kvutil.ExplicitTypeDDL) @@ -2457,12 +2457,12 @@ func (w *cleanUpIndexWorker) BackfillData(handleRange reorgBackfillTask) (taskCt return nil }) logSlowOperations(time.Since(oprStartTime), 
"cleanUpIndexBackfillDataInTxn", 3000) - if val, _err_ := failpoint.Eval(_curpkg_("mockDMLExecution")); _err_ == nil { + failpoint.Inject("mockDMLExecution", func(val failpoint.Value) { //nolint:forcetypeassert if val.(bool) && MockDMLExecution != nil { MockDMLExecution() } - } + }) return } diff --git a/pkg/ddl/index.go__failpoint_stash__ b/pkg/ddl/index.go__failpoint_stash__ deleted file mode 100644 index d9b2c97d50903..0000000000000 --- a/pkg/ddl/index.go__failpoint_stash__ +++ /dev/null @@ -1,2616 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ddl - -import ( - "bytes" - "cmp" - "context" - "encoding/hex" - "encoding/json" - "fmt" - "os" - "slices" - "strings" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/ddl/copr" - "github.com/pingcap/tidb/pkg/ddl/ingest" - "github.com/pingcap/tidb/pkg/ddl/logutil" - sess "github.com/pingcap/tidb/pkg/ddl/session" - ddlutil "github.com/pingcap/tidb/pkg/ddl/util" - "github.com/pingcap/tidb/pkg/disttask/framework/handle" - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" - "github.com/pingcap/tidb/pkg/disttask/framework/storage" - "github.com/pingcap/tidb/pkg/errctx" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/lightning/backend" - litconfig "github.com/pingcap/tidb/pkg/lightning/config" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/charset" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/store/helper" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/table/tables" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/backoff" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/dbterror" - tidblogutil "github.com/pingcap/tidb/pkg/util/logutil" - decoder "github.com/pingcap/tidb/pkg/util/rowDecoder" - "github.com/pingcap/tidb/pkg/util/size" - "github.com/pingcap/tidb/pkg/util/sqlexec" - "github.com/pingcap/tidb/pkg/util/stringutil" - "github.com/tikv/client-go/v2/oracle" - "github.com/tikv/client-go/v2/tikv" - kvutil "github.com/tikv/client-go/v2/util" - pd "github.com/tikv/pd/client" - pdHttp "github.com/tikv/pd/client/http" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" -) - -const ( - // MaxCommentLength is exported for testing. 
- MaxCommentLength = 1024 -) - -var ( - // SuppressErrorTooLongKeyKey is used by SchemaTracker to suppress err too long key error - SuppressErrorTooLongKeyKey stringutil.StringerStr = "suppressErrorTooLongKeyKey" -) - -func suppressErrorTooLongKeyKey(sctx sessionctx.Context) bool { - if suppress, ok := sctx.Value(SuppressErrorTooLongKeyKey).(bool); ok && suppress { - return true - } - return false -} - -func buildIndexColumns(ctx sessionctx.Context, columns []*model.ColumnInfo, indexPartSpecifications []*ast.IndexPartSpecification) ([]*model.IndexColumn, bool, error) { - // Build offsets. - idxParts := make([]*model.IndexColumn, 0, len(indexPartSpecifications)) - var col *model.ColumnInfo - var mvIndex bool - maxIndexLength := config.GetGlobalConfig().MaxIndexLength - // The sum of length of all index columns. - sumLength := 0 - for _, ip := range indexPartSpecifications { - col = model.FindColumnInfo(columns, ip.Column.Name.L) - if col == nil { - return nil, false, dbterror.ErrKeyColumnDoesNotExits.GenWithStack("column does not exist: %s", ip.Column.Name) - } - - if err := checkIndexColumn(ctx, col, ip.Length); err != nil { - return nil, false, err - } - if col.FieldType.IsArray() { - if mvIndex { - return nil, false, dbterror.ErrNotSupportedYet.GenWithStackByArgs("more than one multi-valued key part per index") - } - mvIndex = true - } - indexColLen := ip.Length - if indexColLen != types.UnspecifiedLength && - types.IsTypeChar(col.FieldType.GetType()) && - indexColLen == col.FieldType.GetFlen() { - indexColLen = types.UnspecifiedLength - } - indexColumnLength, err := getIndexColumnLength(col, indexColLen) - if err != nil { - return nil, false, err - } - sumLength += indexColumnLength - - // The sum of all lengths must be shorter than the max length for prefix. - if sumLength > maxIndexLength { - // The multiple column index and the unique index in which the length sum exceeds the maximum size - // will return an error instead produce a warning. - if ctx == nil || (ctx.GetSessionVars().SQLMode.HasStrictMode() && !suppressErrorTooLongKeyKey(ctx)) || mysql.HasUniKeyFlag(col.GetFlag()) || len(indexPartSpecifications) > 1 { - return nil, false, dbterror.ErrTooLongKey.GenWithStackByArgs(sumLength, maxIndexLength) - } - // truncate index length and produce warning message in non-restrict sql mode. - colLenPerUint, err := getIndexColumnLength(col, 1) - if err != nil { - return nil, false, err - } - indexColLen = maxIndexLength / colLenPerUint - // produce warning message - ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrTooLongKey.FastGenByArgs(sumLength, maxIndexLength)) - } - - idxParts = append(idxParts, &model.IndexColumn{ - Name: col.Name, - Offset: col.Offset, - Length: indexColLen, - }) - } - - return idxParts, mvIndex, nil -} - -// CheckPKOnGeneratedColumn checks the specification of PK is valid. -func CheckPKOnGeneratedColumn(tblInfo *model.TableInfo, indexPartSpecifications []*ast.IndexPartSpecification) (*model.ColumnInfo, error) { - var lastCol *model.ColumnInfo - for _, colName := range indexPartSpecifications { - lastCol = tblInfo.FindPublicColumnByName(colName.Column.Name.L) - if lastCol == nil { - return nil, dbterror.ErrKeyColumnDoesNotExits.GenWithStackByArgs(colName.Column.Name) - } - // Virtual columns cannot be used in primary key. 
- if lastCol.IsGenerated() && !lastCol.GeneratedStored { - if lastCol.Hidden { - return nil, dbterror.ErrFunctionalIndexPrimaryKey - } - return nil, dbterror.ErrUnsupportedOnGeneratedColumn.GenWithStackByArgs("Defining a virtual generated column as primary key") - } - } - - return lastCol, nil -} - -func checkIndexPrefixLength(columns []*model.ColumnInfo, idxColumns []*model.IndexColumn) error { - idxLen, err := indexColumnsLen(columns, idxColumns) - if err != nil { - return err - } - if idxLen > config.GetGlobalConfig().MaxIndexLength { - return dbterror.ErrTooLongKey.GenWithStackByArgs(idxLen, config.GetGlobalConfig().MaxIndexLength) - } - return nil -} - -func indexColumnsLen(cols []*model.ColumnInfo, idxCols []*model.IndexColumn) (colLen int, err error) { - for _, idxCol := range idxCols { - col := model.FindColumnInfo(cols, idxCol.Name.L) - if col == nil { - err = dbterror.ErrKeyColumnDoesNotExits.GenWithStack("column does not exist: %s", idxCol.Name.L) - return - } - var l int - l, err = getIndexColumnLength(col, idxCol.Length) - if err != nil { - return - } - colLen += l - } - return -} - -func checkIndexColumn(ctx sessionctx.Context, col *model.ColumnInfo, indexColumnLen int) error { - if col.GetFlen() == 0 && (types.IsTypeChar(col.FieldType.GetType()) || types.IsTypeVarchar(col.FieldType.GetType())) { - if col.Hidden { - return errors.Trace(dbterror.ErrWrongKeyColumnFunctionalIndex.GenWithStackByArgs(col.GeneratedExprString)) - } - return errors.Trace(dbterror.ErrWrongKeyColumn.GenWithStackByArgs(col.Name)) - } - - // JSON column cannot index. - if col.FieldType.GetType() == mysql.TypeJSON && !col.FieldType.IsArray() { - if col.Hidden { - return dbterror.ErrFunctionalIndexOnJSONOrGeometryFunction - } - return errors.Trace(dbterror.ErrJSONUsedAsKey.GenWithStackByArgs(col.Name.O)) - } - - // Length must be specified and non-zero for BLOB and TEXT column indexes. - if types.IsTypeBlob(col.FieldType.GetType()) { - if indexColumnLen == types.UnspecifiedLength { - if col.Hidden { - return dbterror.ErrFunctionalIndexOnBlob - } - return errors.Trace(dbterror.ErrBlobKeyWithoutLength.GenWithStackByArgs(col.Name.O)) - } - if indexColumnLen == types.ErrorLength { - return errors.Trace(dbterror.ErrKeyPart0.GenWithStackByArgs(col.Name.O)) - } - } - - // Length can only be specified for specifiable types. - if indexColumnLen != types.UnspecifiedLength && !types.IsTypePrefixable(col.FieldType.GetType()) { - return errors.Trace(dbterror.ErrIncorrectPrefixKey) - } - - // Key length must be shorter or equal to the column length. - if indexColumnLen != types.UnspecifiedLength && - types.IsTypeChar(col.FieldType.GetType()) { - if col.GetFlen() < indexColumnLen { - return errors.Trace(dbterror.ErrIncorrectPrefixKey) - } - // Length must be non-zero for char. - if indexColumnLen == types.ErrorLength { - return errors.Trace(dbterror.ErrKeyPart0.GenWithStackByArgs(col.Name.O)) - } - } - - if types.IsString(col.FieldType.GetType()) { - desc, err := charset.GetCharsetInfo(col.GetCharset()) - if err != nil { - return err - } - indexColumnLen *= desc.Maxlen - } - // Specified length must be shorter than the max length for prefix. 
- maxIndexLength := config.GetGlobalConfig().MaxIndexLength - if indexColumnLen > maxIndexLength { - if ctx == nil || (ctx.GetSessionVars().SQLMode.HasStrictMode() && !suppressErrorTooLongKeyKey(ctx)) { - // return error in strict sql mode - return dbterror.ErrTooLongKey.GenWithStackByArgs(indexColumnLen, maxIndexLength) - } - } - return nil -} - -// getIndexColumnLength calculate the bytes number required in an index column. -func getIndexColumnLength(col *model.ColumnInfo, colLen int) (int, error) { - length := types.UnspecifiedLength - if colLen != types.UnspecifiedLength { - length = colLen - } else if col.GetFlen() != types.UnspecifiedLength { - length = col.GetFlen() - } - - switch col.GetType() { - case mysql.TypeBit: - return (length + 7) >> 3, nil - case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeBlob, mysql.TypeLongBlob: - // Different charsets occupy different numbers of bytes on each character. - desc, err := charset.GetCharsetInfo(col.GetCharset()) - if err != nil { - return 0, dbterror.ErrUnsupportedCharset.GenWithStackByArgs(col.GetCharset(), col.GetCollate()) - } - return desc.Maxlen * length, nil - case mysql.TypeTiny, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeDouble, mysql.TypeShort: - return mysql.DefaultLengthOfMysqlTypes[col.GetType()], nil - case mysql.TypeFloat: - if length <= mysql.MaxFloatPrecisionLength { - return mysql.DefaultLengthOfMysqlTypes[mysql.TypeFloat], nil - } - return mysql.DefaultLengthOfMysqlTypes[mysql.TypeDouble], nil - case mysql.TypeNewDecimal: - return calcBytesLengthForDecimal(length), nil - case mysql.TypeYear, mysql.TypeDate, mysql.TypeDuration, mysql.TypeDatetime, mysql.TypeTimestamp: - return mysql.DefaultLengthOfMysqlTypes[col.GetType()], nil - default: - return length, nil - } -} - -// decimal using a binary format that packs nine decimal (base 10) digits into four bytes. -func calcBytesLengthForDecimal(m int) int { - return (m / 9 * 4) + ((m%9)+1)/2 -} - -// BuildIndexInfo builds a new IndexInfo according to the index information. -func BuildIndexInfo( - ctx sessionctx.Context, - allTableColumns []*model.ColumnInfo, - indexName model.CIStr, - isPrimary bool, - isUnique bool, - isGlobal bool, - indexPartSpecifications []*ast.IndexPartSpecification, - indexOption *ast.IndexOption, - state model.SchemaState, -) (*model.IndexInfo, error) { - if err := checkTooLongIndex(indexName); err != nil { - return nil, errors.Trace(err) - } - - idxColumns, mvIndex, err := buildIndexColumns(ctx, allTableColumns, indexPartSpecifications) - if err != nil { - return nil, errors.Trace(err) - } - - // Create index info. - idxInfo := &model.IndexInfo{ - Name: indexName, - Columns: idxColumns, - State: state, - Primary: isPrimary, - Unique: isUnique, - Global: isGlobal, - MVIndex: mvIndex, - } - - if indexOption != nil { - idxInfo.Comment = indexOption.Comment - if indexOption.Visibility == ast.IndexVisibilityInvisible { - idxInfo.Invisible = true - } - if indexOption.Tp == model.IndexTypeInvalid { - // Use btree as default index type. - idxInfo.Tp = model.IndexTypeBtree - } else { - idxInfo.Tp = indexOption.Tp - } - } else { - // Use btree as default index type. - idxInfo.Tp = model.IndexTypeBtree - } - - return idxInfo, nil -} - -// AddIndexColumnFlag aligns the column flags of columns in TableInfo to IndexInfo. 
-func AddIndexColumnFlag(tblInfo *model.TableInfo, indexInfo *model.IndexInfo) { - if indexInfo.Primary { - for _, col := range indexInfo.Columns { - tblInfo.Columns[col.Offset].AddFlag(mysql.PriKeyFlag) - } - return - } - - col := indexInfo.Columns[0] - if indexInfo.Unique && len(indexInfo.Columns) == 1 { - tblInfo.Columns[col.Offset].AddFlag(mysql.UniqueKeyFlag) - } else { - tblInfo.Columns[col.Offset].AddFlag(mysql.MultipleKeyFlag) - } -} - -// DropIndexColumnFlag drops the column flag of columns in TableInfo according to the IndexInfo. -func DropIndexColumnFlag(tblInfo *model.TableInfo, indexInfo *model.IndexInfo) { - if indexInfo.Primary { - for _, col := range indexInfo.Columns { - tblInfo.Columns[col.Offset].DelFlag(mysql.PriKeyFlag) - } - } else if indexInfo.Unique && len(indexInfo.Columns) == 1 { - tblInfo.Columns[indexInfo.Columns[0].Offset].DelFlag(mysql.UniqueKeyFlag) - } else { - tblInfo.Columns[indexInfo.Columns[0].Offset].DelFlag(mysql.MultipleKeyFlag) - } - - col := indexInfo.Columns[0] - // other index may still cover this col - for _, index := range tblInfo.Indices { - if index.Name.L == indexInfo.Name.L { - continue - } - - if index.Columns[0].Name.L != col.Name.L { - continue - } - - AddIndexColumnFlag(tblInfo, index) - } -} - -// ValidateRenameIndex checks if index name is ok to be renamed. -func ValidateRenameIndex(from, to model.CIStr, tbl *model.TableInfo) (ignore bool, err error) { - if fromIdx := tbl.FindIndexByName(from.L); fromIdx == nil { - return false, errors.Trace(infoschema.ErrKeyNotExists.GenWithStackByArgs(from.O, tbl.Name)) - } - // Take case-sensitivity into account, if `FromKey` and `ToKey` are the same, nothing need to be changed - if from.O == to.O { - return true, nil - } - // If spec.FromKey.L == spec.ToKey.L, we operate on the same index(case-insensitive) and change its name (case-sensitive) - // e.g: from `inDex` to `IndEX`. Otherwise, we try to rename an index to another different index which already exists, - // that's illegal by rule. - if toIdx := tbl.FindIndexByName(to.L); toIdx != nil && from.L != to.L { - return false, errors.Trace(infoschema.ErrKeyNameDuplicate.GenWithStackByArgs(toIdx.Name.O)) - } - return false, nil -} - -func onRenameIndex(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - tblInfo, from, to, err := checkRenameIndex(t, job) - if err != nil || tblInfo == nil { - return ver, errors.Trace(err) - } - if tblInfo.TableCacheStatusType != model.TableCacheStatusDisable { - return ver, errors.Trace(dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Rename Index")) - } - - if job.MultiSchemaInfo != nil && job.MultiSchemaInfo.Revertible { - job.MarkNonRevertible() - // Store the mark and enter the next DDL handling loop. 
- return updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, false) - } - - renameIndexes(tblInfo, from, to) - renameHiddenColumns(tblInfo, from, to) - - if ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true); err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - return ver, nil -} - -func validateAlterIndexVisibility(ctx sessionctx.Context, indexName model.CIStr, invisible bool, tbl *model.TableInfo) (bool, error) { - var idx *model.IndexInfo - if idx = tbl.FindIndexByName(indexName.L); idx == nil || idx.State != model.StatePublic { - return false, errors.Trace(infoschema.ErrKeyNotExists.GenWithStackByArgs(indexName.O, tbl.Name)) - } - if ctx == nil || ctx.GetSessionVars() == nil || ctx.GetSessionVars().StmtCtx.MultiSchemaInfo == nil { - // Early return. - if idx.Invisible == invisible { - return true, nil - } - } - return false, nil -} - -func onAlterIndexVisibility(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - tblInfo, from, invisible, err := checkAlterIndexVisibility(t, job) - if err != nil || tblInfo == nil { - return ver, errors.Trace(err) - } - - if job.MultiSchemaInfo != nil && job.MultiSchemaInfo.Revertible { - job.MarkNonRevertible() - return updateVersionAndTableInfo(d, t, job, tblInfo, false) - } - - setIndexVisibility(tblInfo, from, invisible) - if ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true); err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - return ver, nil -} - -func setIndexVisibility(tblInfo *model.TableInfo, name model.CIStr, invisible bool) { - for _, idx := range tblInfo.Indices { - if idx.Name.L == name.L || (isTempIdxInfo(idx, tblInfo) && getChangingIndexOriginName(idx) == name.O) { - idx.Invisible = invisible - } - } -} - -func getNullColInfos(tblInfo *model.TableInfo, indexInfo *model.IndexInfo) ([]*model.ColumnInfo, error) { - nullCols := make([]*model.ColumnInfo, 0, len(indexInfo.Columns)) - for _, colName := range indexInfo.Columns { - col := model.FindColumnInfo(tblInfo.Columns, colName.Name.L) - if !mysql.HasNotNullFlag(col.GetFlag()) || mysql.HasPreventNullInsertFlag(col.GetFlag()) { - nullCols = append(nullCols, col) - } - } - return nullCols, nil -} - -func checkPrimaryKeyNotNull(d *ddlCtx, w *worker, t *meta.Meta, job *model.Job, - tblInfo *model.TableInfo, indexInfo *model.IndexInfo) (warnings []string, err error) { - if !indexInfo.Primary { - return nil, nil - } - - dbInfo, err := checkSchemaExistAndCancelNotExistJob(t, job) - if err != nil { - return nil, err - } - nullCols, err := getNullColInfos(tblInfo, indexInfo) - if err != nil { - return nil, err - } - if len(nullCols) == 0 { - return nil, nil - } - - err = modifyColsFromNull2NotNull(w, dbInfo, tblInfo, nullCols, &model.ColumnInfo{Name: model.NewCIStr("")}, false) - if err == nil { - return nil, nil - } - _, err = convertAddIdxJob2RollbackJob(d, t, job, tblInfo, []*model.IndexInfo{indexInfo}, err) - // TODO: Support non-strict mode. - // warnings = append(warnings, ErrWarnDataTruncated.GenWithStackByArgs(oldCol.Name.L, 0).Error()) - return nil, err -} - -// moveAndUpdateHiddenColumnsToPublic updates the hidden columns to public, and -// moves the hidden columns to proper offsets, so that Table.Columns' states meet the assumption of -// [public, public, ..., public, non-public, non-public, ..., non-public]. 
-func moveAndUpdateHiddenColumnsToPublic(tblInfo *model.TableInfo, idxInfo *model.IndexInfo) { - hiddenColOffset := make(map[int]struct{}, 0) - for _, col := range idxInfo.Columns { - if tblInfo.Columns[col.Offset].Hidden { - hiddenColOffset[col.Offset] = struct{}{} - } - } - if len(hiddenColOffset) == 0 { - return - } - // Find the first non-public column. - firstNonPublicPos := len(tblInfo.Columns) - 1 - for i, c := range tblInfo.Columns { - if c.State != model.StatePublic { - firstNonPublicPos = i - break - } - } - for _, col := range idxInfo.Columns { - tblInfo.Columns[col.Offset].State = model.StatePublic - if _, needMove := hiddenColOffset[col.Offset]; needMove { - tblInfo.MoveColumnInfo(col.Offset, firstNonPublicPos) - } - } -} - -func decodeAddIndexArgs(job *model.Job) ( - uniques []bool, - indexNames []model.CIStr, - indexPartSpecifications [][]*ast.IndexPartSpecification, - indexOptions []*ast.IndexOption, - hiddenCols [][]*model.ColumnInfo, - globals []bool, - err error, -) { - var ( - unique bool - indexName model.CIStr - indexPartSpecification []*ast.IndexPartSpecification - indexOption *ast.IndexOption - hiddenCol []*model.ColumnInfo - global bool - ) - err = job.DecodeArgs(&unique, &indexName, &indexPartSpecification, &indexOption, &hiddenCol, &global) - if err == nil { - return []bool{unique}, - []model.CIStr{indexName}, - [][]*ast.IndexPartSpecification{indexPartSpecification}, - []*ast.IndexOption{indexOption}, - [][]*model.ColumnInfo{hiddenCol}, - []bool{global}, - nil - } - - err = job.DecodeArgs(&uniques, &indexNames, &indexPartSpecifications, &indexOptions, &hiddenCols, &globals) - return -} - -func (w *worker) onCreateIndex(d *ddlCtx, t *meta.Meta, job *model.Job, isPK bool) (ver int64, err error) { - // Handle the rolling back job. - if job.IsRollingback() { - ver, err = onDropIndex(d, t, job) - if err != nil { - return ver, errors.Trace(err) - } - return ver, nil - } - - // Handle normal job. - schemaID := job.SchemaID - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) - if err != nil { - return ver, errors.Trace(err) - } - if tblInfo.TableCacheStatusType != model.TableCacheStatusDisable { - return ver, errors.Trace(dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Create Index")) - } - - uniques := make([]bool, 1) - global := make([]bool, 1) - indexNames := make([]model.CIStr, 1) - indexPartSpecifications := make([][]*ast.IndexPartSpecification, 1) - indexOption := make([]*ast.IndexOption, 1) - var sqlMode mysql.SQLMode - var warnings []string - hiddenCols := make([][]*model.ColumnInfo, 1) - - if isPK { - // Notice: sqlMode and warnings is used to support non-strict mode. 
- err = job.DecodeArgs(&uniques[0], &indexNames[0], &indexPartSpecifications[0], &indexOption[0], &sqlMode, &warnings, &global[0]) - } else { - uniques, indexNames, indexPartSpecifications, indexOption, hiddenCols, global, err = decodeAddIndexArgs(job) - } - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - allIndexInfos := make([]*model.IndexInfo, 0, len(indexNames)) - for i, indexName := range indexNames { - indexInfo := tblInfo.FindIndexByName(indexName.L) - if indexInfo != nil && indexInfo.State == model.StatePublic { - job.State = model.JobStateCancelled - err = dbterror.ErrDupKeyName.GenWithStack("index already exist %s", indexName) - if isPK { - err = infoschema.ErrMultiplePriKey - } - return ver, err - } - if indexInfo == nil { - for _, hiddenCol := range hiddenCols[i] { - columnInfo := model.FindColumnInfo(tblInfo.Columns, hiddenCol.Name.L) - if columnInfo != nil && columnInfo.State == model.StatePublic { - // We already have a column with the same column name. - job.State = model.JobStateCancelled - // TODO: refine the error message - return ver, infoschema.ErrColumnExists.GenWithStackByArgs(hiddenCol.Name) - } - } - } - if indexInfo == nil { - if len(hiddenCols) > 0 { - for _, hiddenCol := range hiddenCols[i] { - InitAndAddColumnToTable(tblInfo, hiddenCol) - } - } - if err = checkAddColumnTooManyColumns(len(tblInfo.Columns)); err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - indexInfo, err = BuildIndexInfo( - nil, - tblInfo.Columns, - indexName, - isPK, - uniques[i], - global[i], - indexPartSpecifications[i], - indexOption[i], - model.StateNone, - ) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - if isPK { - if _, err = CheckPKOnGeneratedColumn(tblInfo, indexPartSpecifications[i]); err != nil { - job.State = model.JobStateCancelled - return ver, err - } - } - indexInfo.ID = AllocateIndexID(tblInfo) - tblInfo.Indices = append(tblInfo.Indices, indexInfo) - if err = checkTooManyIndexes(tblInfo.Indices); err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - // Here we need do this check before set state to `DeleteOnly`, - // because if hidden columns has been set to `DeleteOnly`, - // the `DeleteOnly` columns are missing when we do this check. 
- if err := checkInvisibleIndexOnPK(tblInfo); err != nil { - job.State = model.JobStateCancelled - return ver, err - } - logutil.DDLLogger().Info("run add index job", zap.Stringer("job", job), zap.Reflect("indexInfo", indexInfo)) - } - allIndexInfos = append(allIndexInfos, indexInfo) - } - - originalState := allIndexInfos[0].State - -SwitchIndexState: - switch allIndexInfos[0].State { - case model.StateNone: - // none -> delete only - var reorgTp model.ReorgType - reorgTp, err = pickBackfillType(job) - if err != nil { - if !errorIsRetryable(err, job) { - job.State = model.JobStateCancelled - } - return ver, err - } - loadCloudStorageURI(w, job) - if reorgTp.NeedMergeProcess() { - for _, indexInfo := range allIndexInfos { - indexInfo.BackfillState = model.BackfillStateRunning - } - } - for _, indexInfo := range allIndexInfos { - indexInfo.State = model.StateDeleteOnly - moveAndUpdateHiddenColumnsToPublic(tblInfo, indexInfo) - } - ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, originalState != model.StateDeleteOnly) - if err != nil { - return ver, err - } - job.SchemaState = model.StateDeleteOnly - case model.StateDeleteOnly: - // delete only -> write only - for _, indexInfo := range allIndexInfos { - indexInfo.State = model.StateWriteOnly - _, err = checkPrimaryKeyNotNull(d, w, t, job, tblInfo, indexInfo) - if err != nil { - break SwitchIndexState - } - } - - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != model.StateWriteOnly) - if err != nil { - return ver, err - } - job.SchemaState = model.StateWriteOnly - case model.StateWriteOnly: - // write only -> reorganization - for _, indexInfo := range allIndexInfos { - indexInfo.State = model.StateWriteReorganization - _, err = checkPrimaryKeyNotNull(d, w, t, job, tblInfo, indexInfo) - if err != nil { - break SwitchIndexState - } - } - - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != model.StateWriteReorganization) - if err != nil { - return ver, err - } - // Initialize SnapshotVer to 0 for later reorganization check. - job.SnapshotVer = 0 - job.SchemaState = model.StateWriteReorganization - case model.StateWriteReorganization: - // reorganization -> public - tbl, err := getTable(d.getAutoIDRequirement(), schemaID, tblInfo) - if err != nil { - return ver, errors.Trace(err) - } - - var done bool - if job.MultiSchemaInfo != nil { - done, ver, err = doReorgWorkForCreateIndexMultiSchema(w, d, t, job, tbl, allIndexInfos) - } else { - done, ver, err = doReorgWorkForCreateIndex(w, d, t, job, tbl, allIndexInfos) - } - if !done { - return ver, err - } - - // Set column index flag. - for _, indexInfo := range allIndexInfos { - AddIndexColumnFlag(tblInfo, indexInfo) - if isPK { - if err = UpdateColsNull2NotNull(tblInfo, indexInfo); err != nil { - return ver, errors.Trace(err) - } - } - indexInfo.State = model.StatePublic - } - - // Inject the failpoint to prevent the progress of index creation. 
- failpoint.Inject("create-index-stuck-before-public", func(v failpoint.Value) { - if sigFile, ok := v.(string); ok { - for { - time.Sleep(1 * time.Second) - if _, err := os.Stat(sigFile); err != nil { - if os.IsNotExist(err) { - continue - } - failpoint.Return(ver, errors.Trace(err)) - } - break - } - } - }) - - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != model.StatePublic) - if err != nil { - return ver, errors.Trace(err) - } - - allIndexIDs := make([]int64, 0, len(allIndexInfos)) - ifExists := make([]bool, 0, len(allIndexInfos)) - isGlobal := make([]bool, 0, len(allIndexInfos)) - for _, indexInfo := range allIndexInfos { - allIndexIDs = append(allIndexIDs, indexInfo.ID) - ifExists = append(ifExists, false) - isGlobal = append(isGlobal, indexInfo.Global) - } - job.Args = []any{allIndexIDs, ifExists, getPartitionIDs(tbl.Meta()), isGlobal} - // Finish this job. - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - if !job.ReorgMeta.IsDistReorg && job.ReorgMeta.ReorgTp == model.ReorgTypeLitMerge { - ingest.LitBackCtxMgr.Unregister(job.ID) - } - logutil.DDLLogger().Info("run add index job done", - zap.String("charset", job.Charset), - zap.String("collation", job.Collate)) - default: - err = dbterror.ErrInvalidDDLState.GenWithStackByArgs("index", tblInfo.State) - } - - return ver, errors.Trace(err) -} - -// pickBackfillType determines which backfill process will be used. The result is -// both stored in job.ReorgMeta.ReorgTp and returned. -func pickBackfillType(job *model.Job) (model.ReorgType, error) { - if job.ReorgMeta.ReorgTp != model.ReorgTypeNone { - // The backfill task has been started. - // Don't change the backfill type. - return job.ReorgMeta.ReorgTp, nil - } - if !job.ReorgMeta.IsFastReorg { - job.ReorgMeta.ReorgTp = model.ReorgTypeTxn - return model.ReorgTypeTxn, nil - } - if ingest.LitInitialized { - if job.ReorgMeta.UseCloudStorage { - job.ReorgMeta.ReorgTp = model.ReorgTypeLitMerge - return model.ReorgTypeLitMerge, nil - } - available, err := ingest.LitBackCtxMgr.CheckMoreTasksAvailable() - if err != nil { - return model.ReorgTypeNone, err - } - if available { - job.ReorgMeta.ReorgTp = model.ReorgTypeLitMerge - return model.ReorgTypeLitMerge, nil - } - } - // The lightning environment is unavailable, but we can still use the txn-merge backfill. - logutil.DDLLogger().Info("fallback to txn-merge backfill process", - zap.Bool("lightning env initialized", ingest.LitInitialized)) - job.ReorgMeta.ReorgTp = model.ReorgTypeTxnMerge - return model.ReorgTypeTxnMerge, nil -} - -func loadCloudStorageURI(w *worker, job *model.Job) { - jc := w.jobContext(job.ID, job.ReorgMeta) - jc.cloudStorageURI = variable.CloudStorageURI.Load() - job.ReorgMeta.UseCloudStorage = len(jc.cloudStorageURI) > 0 -} - -func doReorgWorkForCreateIndexMultiSchema(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job, - tbl table.Table, allIndexInfos []*model.IndexInfo) (done bool, ver int64, err error) { - if job.MultiSchemaInfo.Revertible { - done, ver, err = doReorgWorkForCreateIndex(w, d, t, job, tbl, allIndexInfos) - if done { - job.MarkNonRevertible() - if err == nil { - ver, err = updateVersionAndTableInfo(d, t, job, tbl.Meta(), true) - } - } - // We need another round to wait for all the others sub-jobs to finish. 
- return false, ver, err - } - return true, ver, err -} - -func doReorgWorkForCreateIndex( - w *worker, - d *ddlCtx, - t *meta.Meta, - job *model.Job, - tbl table.Table, - allIndexInfos []*model.IndexInfo, -) (done bool, ver int64, err error) { - var reorgTp model.ReorgType - reorgTp, err = pickBackfillType(job) - if err != nil { - return false, ver, err - } - if !reorgTp.NeedMergeProcess() { - return runReorgJobAndHandleErr(w, d, t, job, tbl, allIndexInfos, false) - } - switch allIndexInfos[0].BackfillState { - case model.BackfillStateRunning: - logutil.DDLLogger().Info("index backfill state running", - zap.Int64("job ID", job.ID), zap.String("table", tbl.Meta().Name.O), - zap.Bool("ingest mode", reorgTp == model.ReorgTypeLitMerge), - zap.String("index", allIndexInfos[0].Name.O)) - switch reorgTp { - case model.ReorgTypeLitMerge: - if job.ReorgMeta.IsDistReorg { - done, ver, err = runIngestReorgJobDist(w, d, t, job, tbl, allIndexInfos) - } else { - done, ver, err = runIngestReorgJob(w, d, t, job, tbl, allIndexInfos) - } - case model.ReorgTypeTxnMerge: - done, ver, err = runReorgJobAndHandleErr(w, d, t, job, tbl, allIndexInfos, false) - } - if err != nil || !done { - return false, ver, errors.Trace(err) - } - for _, indexInfo := range allIndexInfos { - indexInfo.BackfillState = model.BackfillStateReadyToMerge - } - ver, err = updateVersionAndTableInfo(d, t, job, tbl.Meta(), true) - return false, ver, errors.Trace(err) - case model.BackfillStateReadyToMerge: - failpoint.Inject("mockDMLExecutionStateBeforeMerge", func(_ failpoint.Value) { - if MockDMLExecutionStateBeforeMerge != nil { - MockDMLExecutionStateBeforeMerge() - } - }) - logutil.DDLLogger().Info("index backfill state ready to merge", - zap.Int64("job ID", job.ID), - zap.String("table", tbl.Meta().Name.O), - zap.String("index", allIndexInfos[0].Name.O)) - for _, indexInfo := range allIndexInfos { - indexInfo.BackfillState = model.BackfillStateMerging - } - if reorgTp == model.ReorgTypeLitMerge { - ingest.LitBackCtxMgr.Unregister(job.ID) - } - job.SnapshotVer = 0 // Reset the snapshot version for merge index reorg. - ver, err = updateVersionAndTableInfo(d, t, job, tbl.Meta(), true) - return false, ver, errors.Trace(err) - case model.BackfillStateMerging: - done, ver, err = runReorgJobAndHandleErr(w, d, t, job, tbl, allIndexInfos, true) - if !done { - return false, ver, err - } - for _, indexInfo := range allIndexInfos { - indexInfo.BackfillState = model.BackfillStateInapplicable // Prevent double-write on this index. 
- } - return true, ver, err - default: - return false, 0, dbterror.ErrInvalidDDLState.GenWithStackByArgs("backfill", allIndexInfos[0].BackfillState) - } -} - -func runIngestReorgJobDist(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job, - tbl table.Table, allIndexInfos []*model.IndexInfo) (done bool, ver int64, err error) { - done, ver, err = runReorgJobAndHandleErr(w, d, t, job, tbl, allIndexInfos, false) - if err != nil { - return false, ver, errors.Trace(err) - } - - if !done { - return false, ver, nil - } - - return true, ver, nil -} - -func runIngestReorgJob(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job, - tbl table.Table, allIndexInfos []*model.IndexInfo) (done bool, ver int64, err error) { - done, ver, err = runReorgJobAndHandleErr(w, d, t, job, tbl, allIndexInfos, false) - if err != nil { - if kv.ErrKeyExists.Equal(err) { - logutil.DDLLogger().Warn("import index duplicate key, convert job to rollback", zap.Stringer("job", job), zap.Error(err)) - ver, err = convertAddIdxJob2RollbackJob(d, t, job, tbl.Meta(), allIndexInfos, err) - } else if !errorIsRetryable(err, job) { - logutil.DDLLogger().Warn("run reorg job failed, convert job to rollback", - zap.String("job", job.String()), zap.Error(err)) - ver, err = convertAddIdxJob2RollbackJob(d, t, job, tbl.Meta(), allIndexInfos, err) - } else { - logutil.DDLLogger().Warn("run add index ingest job error", zap.Error(err)) - } - return false, ver, errors.Trace(err) - } - failpoint.InjectCall("afterRunIngestReorgJob", job, done) - return done, ver, nil -} - -func errorIsRetryable(err error, job *model.Job) bool { - if job.ErrorCount+1 >= variable.GetDDLErrorCountLimit() { - return false - } - originErr := errors.Cause(err) - if tErr, ok := originErr.(*terror.Error); ok { - sqlErr := terror.ToSQLError(tErr) - _, ok := dbterror.ReorgRetryableErrCodes[sqlErr.Code] - return ok - } - // For the unknown errors, we should retry. - return true -} - -func runReorgJobAndHandleErr( - w *worker, - d *ddlCtx, - t *meta.Meta, - job *model.Job, - tbl table.Table, - allIndexInfos []*model.IndexInfo, - mergingTmpIdx bool, -) (done bool, ver int64, err error) { - elements := make([]*meta.Element, 0, len(allIndexInfos)) - for _, indexInfo := range allIndexInfos { - elements = append(elements, &meta.Element{ID: indexInfo.ID, TypeKey: meta.IndexElementKey}) - } - - failpoint.Inject("mockDMLExecutionStateMerging", func(val failpoint.Value) { - //nolint:forcetypeassert - if val.(bool) && allIndexInfos[0].BackfillState == model.BackfillStateMerging && - MockDMLExecutionStateMerging != nil { - MockDMLExecutionStateMerging() - } - }) - - sctx, err1 := w.sessPool.Get() - if err1 != nil { - err = err1 - return - } - defer w.sessPool.Put(sctx) - rh := newReorgHandler(sess.NewSession(sctx)) - dbInfo, err := t.GetDatabase(job.SchemaID) - if err != nil { - return false, ver, errors.Trace(err) - } - reorgInfo, err := getReorgInfo(d.jobContext(job.ID, job.ReorgMeta), d, rh, job, dbInfo, tbl, elements, mergingTmpIdx) - if err != nil || reorgInfo == nil || reorgInfo.first { - // If we run reorg firstly, we should update the job snapshot version - // and then run the reorg next time. 
- return false, ver, errors.Trace(err) - } - err = overwriteReorgInfoFromGlobalCheckpoint(w, rh.s, job, reorgInfo) - if err != nil { - return false, ver, errors.Trace(err) - } - err = w.runReorgJob(reorgInfo, tbl.Meta(), d.lease, func() (addIndexErr error) { - defer util.Recover(metrics.LabelDDL, "onCreateIndex", - func() { - addIndexErr = dbterror.ErrCancelledDDLJob.GenWithStack("add table `%v` index `%v` panic", tbl.Meta().Name, allIndexInfos[0].Name) - }, false) - return w.addTableIndex(tbl, reorgInfo) - }) - if err != nil { - if dbterror.ErrPausedDDLJob.Equal(err) { - return false, ver, nil - } - if dbterror.ErrWaitReorgTimeout.Equal(err) { - // if timeout, we should return, check for the owner and re-wait job done. - return false, ver, nil - } - // TODO(tangenta): get duplicate column and match index. - err = ingest.TryConvertToKeyExistsErr(err, allIndexInfos[0], tbl.Meta()) - if !errorIsRetryable(err, job) { - logutil.DDLLogger().Warn("run add index job failed, convert job to rollback", zap.Stringer("job", job), zap.Error(err)) - ver, err = convertAddIdxJob2RollbackJob(d, t, job, tbl.Meta(), allIndexInfos, err) - if err1 := rh.RemoveDDLReorgHandle(job, reorgInfo.elements); err1 != nil { - logutil.DDLLogger().Warn("run add index job failed, convert job to rollback, RemoveDDLReorgHandle failed", zap.Stringer("job", job), zap.Error(err1)) - } - } - return false, ver, errors.Trace(err) - } - failpoint.Inject("mockDMLExecutionStateBeforeImport", func(_ failpoint.Value) { - if MockDMLExecutionStateBeforeImport != nil { - MockDMLExecutionStateBeforeImport() - } - }) - return true, ver, nil -} - -func onDropIndex(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - tblInfo, allIndexInfos, ifExists, err := checkDropIndex(d, t, job) - if err != nil { - if ifExists && dbterror.ErrCantDropFieldOrKey.Equal(err) { - job.Warning = toTError(err) - job.State = model.JobStateDone - return ver, nil - } - return ver, errors.Trace(err) - } - if tblInfo.TableCacheStatusType != model.TableCacheStatusDisable { - return ver, errors.Trace(dbterror.ErrOptOnCacheTable.GenWithStackByArgs("Drop Index")) - } - - if job.MultiSchemaInfo != nil && !job.IsRollingback() && job.MultiSchemaInfo.Revertible { - job.MarkNonRevertible() - job.SchemaState = allIndexInfos[0].State - return updateVersionAndTableInfo(d, t, job, tblInfo, false) - } - - originalState := allIndexInfos[0].State - switch allIndexInfos[0].State { - case model.StatePublic: - // public -> write only - for _, indexInfo := range allIndexInfos { - indexInfo.State = model.StateWriteOnly - } - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != model.StateWriteOnly) - if err != nil { - return ver, errors.Trace(err) - } - case model.StateWriteOnly: - // write only -> delete only - for _, indexInfo := range allIndexInfos { - indexInfo.State = model.StateDeleteOnly - } - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != model.StateDeleteOnly) - if err != nil { - return ver, errors.Trace(err) - } - case model.StateDeleteOnly: - // delete only -> reorganization - for _, indexInfo := range allIndexInfos { - indexInfo.State = model.StateDeleteReorganization - } - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != model.StateDeleteReorganization) - if err != nil { - return ver, errors.Trace(err) - } - case model.StateDeleteReorganization: - // reorganization -> absent - idxIDs := make([]int64, 0, len(allIndexInfos)) - for _, indexInfo := range allIndexInfos { - indexInfo.State = 
model.StateNone - // Set column index flag. - DropIndexColumnFlag(tblInfo, indexInfo) - RemoveDependentHiddenColumns(tblInfo, indexInfo) - removeIndexInfo(tblInfo, indexInfo) - idxIDs = append(idxIDs, indexInfo.ID) - } - - failpoint.Inject("mockExceedErrorLimit", func(val failpoint.Value) { - //nolint:forcetypeassert - if val.(bool) { - panic("panic test in cancelling add index") - } - }) - - ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, originalState != model.StateNone) - if err != nil { - return ver, errors.Trace(err) - } - - // Finish this job. - if job.IsRollingback() { - job.FinishTableJob(model.JobStateRollbackDone, model.StateNone, ver, tblInfo) - job.Args[0] = idxIDs - } else { - // the partition ids were append by convertAddIdxJob2RollbackJob, it is weird, but for the compatibility, - // we should keep appending the partitions in the convertAddIdxJob2RollbackJob. - job.FinishTableJob(model.JobStateDone, model.StateNone, ver, tblInfo) - // Global index key has t{tableID}_ prefix. - // Assign partitionIDs empty to guarantee correct prefix in insertJobIntoDeleteRangeTable. - if allIndexInfos[0].Global { - job.Args = append(job.Args, idxIDs[0], []int64{}) - } else { - job.Args = append(job.Args, idxIDs[0], getPartitionIDs(tblInfo)) - } - } - default: - return ver, errors.Trace(dbterror.ErrInvalidDDLState.GenWithStackByArgs("index", allIndexInfos[0].State)) - } - job.SchemaState = allIndexInfos[0].State - return ver, errors.Trace(err) -} - -// RemoveDependentHiddenColumns removes hidden columns by the indexInfo. -func RemoveDependentHiddenColumns(tblInfo *model.TableInfo, idxInfo *model.IndexInfo) { - hiddenColOffs := make([]int, 0) - for _, indexColumn := range idxInfo.Columns { - col := tblInfo.Columns[indexColumn.Offset] - if col.Hidden { - hiddenColOffs = append(hiddenColOffs, col.Offset) - } - } - // Sort the offset in descending order. - slices.SortFunc(hiddenColOffs, func(a, b int) int { return cmp.Compare(b, a) }) - // Move all the dependent hidden columns to the end. - endOffset := len(tblInfo.Columns) - 1 - for _, offset := range hiddenColOffs { - tblInfo.MoveColumnInfo(offset, endOffset) - } - tblInfo.Columns = tblInfo.Columns[:len(tblInfo.Columns)-len(hiddenColOffs)] -} - -func removeIndexInfo(tblInfo *model.TableInfo, idxInfo *model.IndexInfo) { - indices := tblInfo.Indices - offset := -1 - for i, idx := range indices { - if idxInfo.ID == idx.ID { - offset = i - break - } - } - if offset == -1 { - // The target index has been removed. - return - } - // Remove the target index. - tblInfo.Indices = append(tblInfo.Indices[:offset], tblInfo.Indices[offset+1:]...) 
-} - -func checkDropIndex(d *ddlCtx, t *meta.Meta, job *model.Job) (*model.TableInfo, []*model.IndexInfo, bool /* ifExists */, error) { - schemaID := job.SchemaID - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) - if err != nil { - return nil, nil, false, errors.Trace(err) - } - - indexNames := make([]model.CIStr, 1) - ifExists := make([]bool, 1) - if err = job.DecodeArgs(&indexNames[0], &ifExists[0]); err != nil { - if err = job.DecodeArgs(&indexNames, &ifExists); err != nil { - job.State = model.JobStateCancelled - return nil, nil, false, errors.Trace(err) - } - } - - indexInfos := make([]*model.IndexInfo, 0, len(indexNames)) - for i, idxName := range indexNames { - indexInfo := tblInfo.FindIndexByName(idxName.L) - if indexInfo == nil { - job.State = model.JobStateCancelled - return nil, nil, ifExists[i], dbterror.ErrCantDropFieldOrKey.GenWithStack("index %s doesn't exist", idxName) - } - - // Check that drop primary index will not cause invisible implicit primary index. - if err := checkInvisibleIndexesOnPK(tblInfo, []*model.IndexInfo{indexInfo}, job); err != nil { - job.State = model.JobStateCancelled - return nil, nil, false, errors.Trace(err) - } - - // Double check for drop index needed in foreign key. - if err := checkIndexNeededInForeignKeyInOwner(d, t, job, job.SchemaName, tblInfo, indexInfo); err != nil { - return nil, nil, false, errors.Trace(err) - } - indexInfos = append(indexInfos, indexInfo) - } - return tblInfo, indexInfos, false, nil -} - -func checkInvisibleIndexesOnPK(tblInfo *model.TableInfo, indexInfos []*model.IndexInfo, job *model.Job) error { - newIndices := make([]*model.IndexInfo, 0, len(tblInfo.Indices)) - for _, oidx := range tblInfo.Indices { - needAppend := true - for _, idx := range indexInfos { - if idx.Name.L == oidx.Name.L { - needAppend = false - break - } - } - if needAppend { - newIndices = append(newIndices, oidx) - } - } - newTbl := tblInfo.Clone() - newTbl.Indices = newIndices - if err := checkInvisibleIndexOnPK(newTbl); err != nil { - job.State = model.JobStateCancelled - return err - } - - return nil -} - -func checkRenameIndex(t *meta.Meta, job *model.Job) (*model.TableInfo, model.CIStr, model.CIStr, error) { - var from, to model.CIStr - schemaID := job.SchemaID - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) - if err != nil { - return nil, from, to, errors.Trace(err) - } - - if err := job.DecodeArgs(&from, &to); err != nil { - job.State = model.JobStateCancelled - return nil, from, to, errors.Trace(err) - } - - // Double check. 
See function `RenameIndex` in executor.go - duplicate, err := ValidateRenameIndex(from, to, tblInfo) - if duplicate { - return nil, from, to, nil - } - if err != nil { - job.State = model.JobStateCancelled - return nil, from, to, errors.Trace(err) - } - return tblInfo, from, to, errors.Trace(err) -} - -func checkAlterIndexVisibility(t *meta.Meta, job *model.Job) (*model.TableInfo, model.CIStr, bool, error) { - var ( - indexName model.CIStr - invisible bool - ) - - schemaID := job.SchemaID - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) - if err != nil { - return nil, indexName, invisible, errors.Trace(err) - } - - if err := job.DecodeArgs(&indexName, &invisible); err != nil { - job.State = model.JobStateCancelled - return nil, indexName, invisible, errors.Trace(err) - } - - skip, err := validateAlterIndexVisibility(nil, indexName, invisible, tblInfo) - if err != nil { - job.State = model.JobStateCancelled - return nil, indexName, invisible, errors.Trace(err) - } - if skip { - job.State = model.JobStateDone - return nil, indexName, invisible, nil - } - return tblInfo, indexName, invisible, nil -} - -// indexRecord is the record information of an index. -type indexRecord struct { - handle kv.Handle - key []byte // It's used to lock a record. Record it to reduce the encoding time. - vals []types.Datum // It's the index values. - rsData []types.Datum // It's the restored data for handle. - skip bool // skip indicates that the index key is already exists, we should not add it. -} - -type baseIndexWorker struct { - *backfillCtx - indexes []table.Index - - tp backfillerType - // The following attributes are used to reduce memory allocation. - defaultVals []types.Datum - idxRecords []*indexRecord - rowMap map[int64]types.Datum - rowDecoder *decoder.RowDecoder -} - -type addIndexTxnWorker struct { - baseIndexWorker - - // The following attributes are used to reduce memory allocation. 
- idxKeyBufs [][]byte - batchCheckKeys []kv.Key - batchCheckValues [][]byte - distinctCheckFlags []bool - recordIdx []int -} - -func newAddIndexTxnWorker( - decodeColMap map[int64]decoder.Column, - t table.PhysicalTable, - bfCtx *backfillCtx, - jobID int64, - elements []*meta.Element, - eleTypeKey []byte, -) (*addIndexTxnWorker, error) { - if !bytes.Equal(eleTypeKey, meta.IndexElementKey) { - logutil.DDLLogger().Error("Element type for addIndexTxnWorker incorrect", - zap.Int64("job ID", jobID), zap.ByteString("element type", eleTypeKey), zap.Int64("element ID", elements[0].ID)) - return nil, errors.Errorf("element type is not index, typeKey: %v", eleTypeKey) - } - - allIndexes := make([]table.Index, 0, len(elements)) - for _, elem := range elements { - if !bytes.Equal(elem.TypeKey, meta.IndexElementKey) { - continue - } - indexInfo := model.FindIndexInfoByID(t.Meta().Indices, elem.ID) - index := tables.NewIndex(t.GetPhysicalID(), t.Meta(), indexInfo) - allIndexes = append(allIndexes, index) - } - rowDecoder := decoder.NewRowDecoder(t, t.WritableCols(), decodeColMap) - - return &addIndexTxnWorker{ - baseIndexWorker: baseIndexWorker{ - backfillCtx: bfCtx, - indexes: allIndexes, - rowDecoder: rowDecoder, - defaultVals: make([]types.Datum, len(t.WritableCols())), - rowMap: make(map[int64]types.Datum, len(decodeColMap)), - }, - }, nil -} - -func (w *baseIndexWorker) AddMetricInfo(cnt float64) { - w.metricCounter.Add(cnt) -} - -func (w *baseIndexWorker) String() string { - return w.tp.String() -} - -func (w *baseIndexWorker) GetCtx() *backfillCtx { - return w.backfillCtx -} - -// mockNotOwnerErrOnce uses to make sure `notOwnerErr` only mock error once. -var mockNotOwnerErrOnce uint32 - -// getIndexRecord gets index columns values use w.rowDecoder, and generate indexRecord. -func (w *baseIndexWorker) getIndexRecord(idxInfo *model.IndexInfo, handle kv.Handle, recordKey []byte) (*indexRecord, error) { - cols := w.table.WritableCols() - failpoint.Inject("MockGetIndexRecordErr", func(val failpoint.Value) { - if valStr, ok := val.(string); ok { - switch valStr { - case "cantDecodeRecordErr": - failpoint.Return(nil, errors.Trace(dbterror.ErrCantDecodeRecord.GenWithStackByArgs("index", - errors.New("mock can't decode record error")))) - case "modifyColumnNotOwnerErr": - if idxInfo.Name.O == "_Idx$_idx_0" && handle.IntValue() == 7168 && atomic.CompareAndSwapUint32(&mockNotOwnerErrOnce, 0, 1) { - failpoint.Return(nil, errors.Trace(dbterror.ErrNotOwner)) - } - case "addIdxNotOwnerErr": - // For the case of the old TiDB version(do not exist the element information) is upgraded to the new TiDB version. - // First step, we need to exit "addPhysicalTableIndex". 
- if idxInfo.Name.O == "idx2" && handle.IntValue() == 6144 && atomic.CompareAndSwapUint32(&mockNotOwnerErrOnce, 1, 2) { - failpoint.Return(nil, errors.Trace(dbterror.ErrNotOwner)) - } - } - } - }) - idxVal := make([]types.Datum, len(idxInfo.Columns)) - var err error - for j, v := range idxInfo.Columns { - col := cols[v.Offset] - idxColumnVal, ok := w.rowMap[col.ID] - if ok { - idxVal[j] = idxColumnVal - continue - } - idxColumnVal, err = tables.GetColDefaultValue(w.exprCtx, col, w.defaultVals) - if err != nil { - return nil, errors.Trace(err) - } - - idxVal[j] = idxColumnVal - } - - rsData := tables.TryGetHandleRestoredDataWrapper(w.table.Meta(), nil, w.rowMap, idxInfo) - idxRecord := &indexRecord{handle: handle, key: recordKey, vals: idxVal, rsData: rsData} - return idxRecord, nil -} - -func (w *baseIndexWorker) cleanRowMap() { - for id := range w.rowMap { - delete(w.rowMap, id) - } -} - -// getNextKey gets next key of entry that we are going to process. -func (w *baseIndexWorker) getNextKey(taskRange reorgBackfillTask, taskDone bool) (nextKey kv.Key) { - if !taskDone { - // The task is not done. So we need to pick the last processed entry's handle and add one. - lastHandle := w.idxRecords[len(w.idxRecords)-1].handle - recordKey := tablecodec.EncodeRecordKey(taskRange.physicalTable.RecordPrefix(), lastHandle) - return recordKey.Next() - } - return taskRange.endKey -} - -func (w *baseIndexWorker) updateRowDecoder(handle kv.Handle, rawRecord []byte) error { - sysZone := w.loc - _, err := w.rowDecoder.DecodeAndEvalRowWithMap(w.exprCtx, handle, rawRecord, sysZone, w.rowMap) - return errors.Trace(err) -} - -// fetchRowColVals fetch w.batchCnt count records that need to reorganize indices, and build the corresponding indexRecord slice. -// fetchRowColVals returns: -// 1. The corresponding indexRecord slice. -// 2. Next handle of entry that we need to process. -// 3. Boolean indicates whether the task is done. -// 4. error occurs in fetchRowColVals. nil if no error occurs. -func (w *baseIndexWorker) fetchRowColVals(txn kv.Transaction, taskRange reorgBackfillTask) ([]*indexRecord, kv.Key, bool, error) { - // TODO: use tableScan to prune columns. - w.idxRecords = w.idxRecords[:0] - startTime := time.Now() - - // taskDone means that the reorged handle is out of taskRange.endHandle. - taskDone := false - oprStartTime := startTime - err := iterateSnapshotKeys(w.jobContext, w.ddlCtx.store, taskRange.priority, taskRange.physicalTable.RecordPrefix(), txn.StartTS(), - taskRange.startKey, taskRange.endKey, func(handle kv.Handle, recordKey kv.Key, rawRow []byte) (bool, error) { - oprEndTime := time.Now() - logSlowOperations(oprEndTime.Sub(oprStartTime), "iterateSnapshotKeys in baseIndexWorker fetchRowColVals", 0) - oprStartTime = oprEndTime - - taskDone = recordKey.Cmp(taskRange.endKey) >= 0 - - if taskDone || len(w.idxRecords) >= w.batchCnt { - return false, nil - } - - // Decode one row, generate records of this row. - err := w.updateRowDecoder(handle, rawRow) - if err != nil { - return false, err - } - for _, index := range w.indexes { - idxRecord, err1 := w.getIndexRecord(index.Meta(), handle, recordKey) - if err1 != nil { - return false, errors.Trace(err1) - } - w.idxRecords = append(w.idxRecords, idxRecord) - } - // If there are generated column, rowDecoder will use column value that not in idxInfo.Columns to calculate - // the generated value, so we need to clear up the reusing map. 
- w.cleanRowMap() - - if recordKey.Cmp(taskRange.endKey) == 0 { - taskDone = true - return false, nil - } - return true, nil - }) - - if len(w.idxRecords) == 0 { - taskDone = true - } - - logutil.DDLLogger().Debug("txn fetches handle info", zap.Stringer("worker", w), zap.Uint64("txnStartTS", txn.StartTS()), - zap.String("taskRange", taskRange.String()), zap.Duration("takeTime", time.Since(startTime))) - return w.idxRecords, w.getNextKey(taskRange, taskDone), taskDone, errors.Trace(err) -} - -func (w *addIndexTxnWorker) initBatchCheckBufs(batchCount int) { - if len(w.idxKeyBufs) < batchCount { - w.idxKeyBufs = make([][]byte, batchCount) - } - - w.batchCheckKeys = w.batchCheckKeys[:0] - w.batchCheckValues = w.batchCheckValues[:0] - w.distinctCheckFlags = w.distinctCheckFlags[:0] - w.recordIdx = w.recordIdx[:0] -} - -func (w *addIndexTxnWorker) checkHandleExists(idxInfo *model.IndexInfo, key kv.Key, value []byte, handle kv.Handle) error { - tblInfo := w.table.Meta() - idxColLen := len(idxInfo.Columns) - h, err := tablecodec.DecodeIndexHandle(key, value, idxColLen) - if err != nil { - return errors.Trace(err) - } - hasBeenBackFilled := h.Equal(handle) - if hasBeenBackFilled { - return nil - } - return ddlutil.GenKeyExistsErr(key, value, idxInfo, tblInfo) -} - -// batchCheckUniqueKey checks the unique keys in the batch. -// Note that `idxRecords` may belong to multiple indexes. -func (w *addIndexTxnWorker) batchCheckUniqueKey(txn kv.Transaction, idxRecords []*indexRecord) error { - w.initBatchCheckBufs(len(idxRecords)) - evalCtx := w.exprCtx.GetEvalCtx() - ec := evalCtx.ErrCtx() - uniqueBatchKeys := make([]kv.Key, 0, len(idxRecords)) - cnt := 0 - for i, record := range idxRecords { - idx := w.indexes[i%len(w.indexes)] - if !idx.Meta().Unique { - // non-unique key need not to check, use `nil` as a placeholder to keep - // `idxRecords[i]` belonging to `indexes[i%len(indexes)]`. - w.batchCheckKeys = append(w.batchCheckKeys, nil) - w.batchCheckValues = append(w.batchCheckValues, nil) - w.distinctCheckFlags = append(w.distinctCheckFlags, false) - w.recordIdx = append(w.recordIdx, 0) - continue - } - // skip by default. - idxRecords[i].skip = true - iter := idx.GenIndexKVIter(ec, w.loc, record.vals, record.handle, idxRecords[i].rsData) - for iter.Valid() { - var buf []byte - if cnt < len(w.idxKeyBufs) { - buf = w.idxKeyBufs[cnt] - } - key, val, distinct, err := iter.Next(buf, nil) - if err != nil { - return errors.Trace(err) - } - if cnt < len(w.idxKeyBufs) { - w.idxKeyBufs[cnt] = key - } else { - w.idxKeyBufs = append(w.idxKeyBufs, key) - } - cnt++ - w.batchCheckKeys = append(w.batchCheckKeys, key) - w.batchCheckValues = append(w.batchCheckValues, val) - w.distinctCheckFlags = append(w.distinctCheckFlags, distinct) - w.recordIdx = append(w.recordIdx, i) - uniqueBatchKeys = append(uniqueBatchKeys, key) - } - } - - if len(uniqueBatchKeys) == 0 { - return nil - } - - batchVals, err := txn.BatchGet(context.Background(), uniqueBatchKeys) - if err != nil { - return errors.Trace(err) - } - - // 1. unique-key/primary-key is duplicate and the handle is equal, skip it. - // 2. unique-key/primary-key is duplicate and the handle is not equal, return duplicate error. - // 3. non-unique-key is duplicate, skip it. 
- for i, key := range w.batchCheckKeys { - if len(key) == 0 { - continue - } - idx := w.indexes[i%len(w.indexes)] - val, found := batchVals[string(key)] - if found { - if w.distinctCheckFlags[i] { - if err := w.checkHandleExists(idx.Meta(), key, val, idxRecords[w.recordIdx[i]].handle); err != nil { - return errors.Trace(err) - } - } - } else if w.distinctCheckFlags[i] { - // The keys in w.batchCheckKeys also maybe duplicate, - // so we need to backfill the not found key into `batchVals` map. - batchVals[string(key)] = w.batchCheckValues[i] - } - idxRecords[w.recordIdx[i]].skip = found && idxRecords[w.recordIdx[i]].skip - } - return nil -} - -func getLocalWriterConfig(indexCnt, writerCnt int) *backend.LocalWriterConfig { - writerCfg := &backend.LocalWriterConfig{} - // avoid unit test panic - memRoot := ingest.LitMemRoot - if memRoot == nil { - return writerCfg - } - - // leave some room for objects overhead - availMem := memRoot.MaxMemoryQuota() - memRoot.CurrentUsage() - int64(10*size.MB) - memLimitPerWriter := availMem / int64(indexCnt) / int64(writerCnt) - memLimitPerWriter = min(memLimitPerWriter, litconfig.DefaultLocalWriterMemCacheSize) - writerCfg.Local.MemCacheSize = memLimitPerWriter - return writerCfg -} - -func writeChunkToLocal( - ctx context.Context, - writers []ingest.Writer, - indexes []table.Index, - copCtx copr.CopContext, - loc *time.Location, - errCtx errctx.Context, - writeStmtBufs *variable.WriteStmtBufs, - copChunk *chunk.Chunk, -) (int, kv.Handle, error) { - iter := chunk.NewIterator4Chunk(copChunk) - c := copCtx.GetBase() - ectx := c.ExprCtx.GetEvalCtx() - - maxIdxColCnt := maxIndexColumnCount(indexes) - idxDataBuf := make([]types.Datum, maxIdxColCnt) - handleDataBuf := make([]types.Datum, len(c.HandleOutputOffsets)) - var restoreDataBuf []types.Datum - count := 0 - var lastHandle kv.Handle - - unlockFns := make([]func(), 0, len(writers)) - for _, w := range writers { - unlock := w.LockForWrite() - unlockFns = append(unlockFns, unlock) - } - defer func() { - for _, unlock := range unlockFns { - unlock() - } - }() - needRestoreForIndexes := make([]bool, len(indexes)) - restore, pkNeedRestore := false, false - if c.PrimaryKeyInfo != nil && c.TableInfo.IsCommonHandle && c.TableInfo.CommonHandleVersion != 0 { - pkNeedRestore = tables.NeedRestoredData(c.PrimaryKeyInfo.Columns, c.TableInfo.Columns) - } - for i, index := range indexes { - needRestore := pkNeedRestore || tables.NeedRestoredData(index.Meta().Columns, c.TableInfo.Columns) - needRestoreForIndexes[i] = needRestore - restore = restore || needRestore - } - if restore { - restoreDataBuf = make([]types.Datum, len(c.HandleOutputOffsets)) - } - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - handleDataBuf := extractDatumByOffsets(ectx, row, c.HandleOutputOffsets, c.ExprColumnInfos, handleDataBuf) - if restore { - // restoreDataBuf should not truncate index values. 
- for i, datum := range handleDataBuf { - restoreDataBuf[i] = *datum.Clone() - } - } - h, err := buildHandle(handleDataBuf, c.TableInfo, c.PrimaryKeyInfo, loc, errCtx) - if err != nil { - return 0, nil, errors.Trace(err) - } - for i, index := range indexes { - idxID := index.Meta().ID - idxDataBuf = extractDatumByOffsets(ectx, - row, copCtx.IndexColumnOutputOffsets(idxID), c.ExprColumnInfos, idxDataBuf) - idxData := idxDataBuf[:len(index.Meta().Columns)] - var rsData []types.Datum - if needRestoreForIndexes[i] { - rsData = getRestoreData(c.TableInfo, copCtx.IndexInfo(idxID), c.PrimaryKeyInfo, restoreDataBuf) - } - err = writeOneKVToLocal(ctx, writers[i], index, loc, errCtx, writeStmtBufs, idxData, rsData, h) - if err != nil { - return 0, nil, errors.Trace(err) - } - } - count++ - lastHandle = h - } - return count, lastHandle, nil -} - -func maxIndexColumnCount(indexes []table.Index) int { - maxCnt := 0 - for _, idx := range indexes { - colCnt := len(idx.Meta().Columns) - if colCnt > maxCnt { - maxCnt = colCnt - } - } - return maxCnt -} - -func writeOneKVToLocal( - ctx context.Context, - writer ingest.Writer, - index table.Index, - loc *time.Location, - errCtx errctx.Context, - writeBufs *variable.WriteStmtBufs, - idxDt, rsData []types.Datum, - handle kv.Handle, -) error { - iter := index.GenIndexKVIter(errCtx, loc, idxDt, handle, rsData) - for iter.Valid() { - key, idxVal, _, err := iter.Next(writeBufs.IndexKeyBuf, writeBufs.RowValBuf) - if err != nil { - return errors.Trace(err) - } - failpoint.Inject("mockLocalWriterPanic", func() { - panic("mock panic") - }) - err = writer.WriteRow(ctx, key, idxVal, handle) - if err != nil { - return errors.Trace(err) - } - failpoint.Inject("mockLocalWriterError", func() { - failpoint.Return(errors.New("mock engine error")) - }) - writeBufs.IndexKeyBuf = key - writeBufs.RowValBuf = idxVal - } - return nil -} - -// BackfillData will backfill table index in a transaction. A lock corresponds to a rowKey if the value of rowKey is changed, -// Note that index columns values may change, and an index is not allowed to be added, so the txn will rollback and retry. -// BackfillData will add w.batchCnt indices once, default value of w.batchCnt is 128. 
-func (w *addIndexTxnWorker) BackfillData(handleRange reorgBackfillTask) (taskCtx backfillTaskContext, errInTxn error) { - failpoint.Inject("errorMockPanic", func(val failpoint.Value) { - //nolint:forcetypeassert - if val.(bool) { - panic("panic test") - } - }) - - oprStartTime := time.Now() - jobID := handleRange.getJobID() - ctx := kv.WithInternalSourceAndTaskType(context.Background(), w.jobContext.ddlJobSourceType(), kvutil.ExplicitTypeDDL) - errInTxn = kv.RunInNewTxn(ctx, w.ddlCtx.store, true, func(_ context.Context, txn kv.Transaction) (err error) { - taskCtx.finishTS = txn.StartTS() - taskCtx.addedCount = 0 - taskCtx.scanCount = 0 - updateTxnEntrySizeLimitIfNeeded(txn) - txn.SetOption(kv.Priority, handleRange.priority) - if tagger := w.GetCtx().getResourceGroupTaggerForTopSQL(jobID); tagger != nil { - txn.SetOption(kv.ResourceGroupTagger, tagger) - } - txn.SetOption(kv.ResourceGroupName, w.jobContext.resourceGroupName) - - idxRecords, nextKey, taskDone, err := w.fetchRowColVals(txn, handleRange) - if err != nil { - return errors.Trace(err) - } - taskCtx.nextKey = nextKey - taskCtx.done = taskDone - - err = w.batchCheckUniqueKey(txn, idxRecords) - if err != nil { - return errors.Trace(err) - } - - for i, idxRecord := range idxRecords { - taskCtx.scanCount++ - // The index is already exists, we skip it, no needs to backfill it. - // The following update, delete, insert on these rows, TiDB can handle it correctly. - if idxRecord.skip { - continue - } - - // We need to add this lock to make sure pessimistic transaction can realize this operation. - // For the normal pessimistic transaction, it's ok. But if async commit is used, it may lead to inconsistent data and index. - // TODO: For global index, lock the correct key?! Currently it locks the partition (phyTblID) and the handle or actual key? - // but should really lock the table's ID + key col(s) - err := txn.LockKeys(context.Background(), new(kv.LockCtx), idxRecord.key) - if err != nil { - return errors.Trace(err) - } - - handle, err := w.indexes[i%len(w.indexes)].Create( - w.tblCtx, txn, idxRecord.vals, idxRecord.handle, idxRecord.rsData, - table.WithIgnoreAssertion, - table.FromBackfill, - // Constrains is already checked in batchCheckUniqueKey - table.DupKeyCheckSkip, - ) - if err != nil { - if kv.ErrKeyExists.Equal(err) && idxRecord.handle.Equal(handle) { - // Index already exists, skip it. - continue - } - return errors.Trace(err) - } - taskCtx.addedCount++ - } - - return nil - }) - logSlowOperations(time.Since(oprStartTime), "AddIndexBackfillData", 3000) - failpoint.Inject("mockDMLExecution", func(val failpoint.Value) { - //nolint:forcetypeassert - if val.(bool) && MockDMLExecution != nil { - MockDMLExecution() - } - }) - return -} - -// MockDMLExecution is only used for test. -var MockDMLExecution func() - -// MockDMLExecutionMerging is only used for test. -var MockDMLExecutionMerging func() - -// MockDMLExecutionStateMerging is only used for test. -var MockDMLExecutionStateMerging func() - -// MockDMLExecutionStateBeforeImport is only used for test. -var MockDMLExecutionStateBeforeImport func() - -// MockDMLExecutionStateBeforeMerge is only used for test. 
-var MockDMLExecutionStateBeforeMerge func() - -func (w *worker) addPhysicalTableIndex(t table.PhysicalTable, reorgInfo *reorgInfo) error { - if reorgInfo.mergingTmpIdx { - logutil.DDLLogger().Info("start to merge temp index", zap.Stringer("job", reorgInfo.Job), zap.Stringer("reorgInfo", reorgInfo)) - return w.writePhysicalTableRecord(w.ctx, w.sessPool, t, typeAddIndexMergeTmpWorker, reorgInfo) - } - logutil.DDLLogger().Info("start to add table index", zap.Stringer("job", reorgInfo.Job), zap.Stringer("reorgInfo", reorgInfo)) - return w.writePhysicalTableRecord(w.ctx, w.sessPool, t, typeAddIndexWorker, reorgInfo) -} - -// addTableIndex handles the add index reorganization state for a table. -func (w *worker) addTableIndex(t table.Table, reorgInfo *reorgInfo) error { - // TODO: Support typeAddIndexMergeTmpWorker. - if reorgInfo.ReorgMeta.IsDistReorg && !reorgInfo.mergingTmpIdx { - if reorgInfo.ReorgMeta.ReorgTp == model.ReorgTypeLitMerge { - err := w.executeDistTask(t, reorgInfo) - if err != nil { - return err - } - //nolint:forcetypeassert - discovery := w.store.(tikv.Storage).GetRegionCache().PDClient().GetServiceDiscovery() - return checkDuplicateForUniqueIndex(w.ctx, t, reorgInfo, discovery) - } - } - - var err error - if tbl, ok := t.(table.PartitionedTable); ok { - var finish bool - for !finish { - p := tbl.GetPartition(reorgInfo.PhysicalTableID) - if p == nil { - return dbterror.ErrCancelledDDLJob.GenWithStack("Can not find partition id %d for table %d", reorgInfo.PhysicalTableID, t.Meta().ID) - } - err = w.addPhysicalTableIndex(p, reorgInfo) - if err != nil { - break - } - - finish, err = updateReorgInfo(w.sessPool, tbl, reorgInfo) - if err != nil { - return errors.Trace(err) - } - failpoint.InjectCall("afterUpdatePartitionReorgInfo", reorgInfo.Job) - // Every time we finish a partition, we update the progress of the job. - if rc := w.getReorgCtx(reorgInfo.Job.ID); rc != nil { - reorgInfo.Job.SetRowCount(rc.getRowCount()) - } - } - } else { - //nolint:forcetypeassert - phyTbl := t.(table.PhysicalTable) - err = w.addPhysicalTableIndex(phyTbl, reorgInfo) - } - return errors.Trace(err) -} - -func checkDuplicateForUniqueIndex(ctx context.Context, t table.Table, reorgInfo *reorgInfo, discovery pd.ServiceDiscovery) error { - var bc ingest.BackendCtx - var err error - defer func() { - if bc != nil { - ingest.LitBackCtxMgr.Unregister(reorgInfo.ID) - } - }() - - for _, elem := range reorgInfo.elements { - indexInfo := model.FindIndexInfoByID(t.Meta().Indices, elem.ID) - if indexInfo == nil { - return errors.New("unexpected error, can't find index info") - } - if indexInfo.Unique { - ctx := tidblogutil.WithCategory(ctx, "ddl-ingest") - if bc == nil { - bc, err = ingest.LitBackCtxMgr.Register(ctx, reorgInfo.ID, indexInfo.Unique, nil, discovery, reorgInfo.ReorgMeta.ResourceGroupName) - if err != nil { - return err - } - } - err = bc.CollectRemoteDuplicateRows(indexInfo.ID, t) - if err != nil { - return err - } - } - } - return nil -} - -func (w *worker) executeDistTask(t table.Table, reorgInfo *reorgInfo) error { - if reorgInfo.mergingTmpIdx { - return errors.New("do not support merge index") - } - - taskType := proto.Backfill - taskKey := fmt.Sprintf("ddl/%s/%d", taskType, reorgInfo.Job.ID) - g, ctx := errgroup.WithContext(w.ctx) - ctx = kv.WithInternalSourceType(ctx, kv.InternalDistTask) - - done := make(chan struct{}) - - // generate taskKey for multi schema change. 
- if mInfo := reorgInfo.Job.MultiSchemaInfo; mInfo != nil { - taskKey = fmt.Sprintf("%s/%d", taskKey, mInfo.Seq) - } - - // For resuming add index task. - // Need to fetch task by taskKey in tidb_global_task and tidb_global_task_history tables. - // When pausing the related ddl job, it is possible that the task with taskKey is succeed and in tidb_global_task_history. - // As a result, when resuming the related ddl job, - // it is necessary to check task exits in tidb_global_task and tidb_global_task_history tables. - taskManager, err := storage.GetTaskManager() - if err != nil { - return err - } - task, err := taskManager.GetTaskByKeyWithHistory(w.ctx, taskKey) - if err != nil && err != storage.ErrTaskNotFound { - return err - } - if task != nil { - // It's possible that the task state is succeed but the ddl job is paused. - // When task in succeed state, we can skip the dist task execution/scheduing process. - if task.State == proto.TaskStateSucceed { - logutil.DDLLogger().Info( - "task succeed, start to resume the ddl job", - zap.String("task-key", taskKey)) - return nil - } - g.Go(func() error { - defer close(done) - backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) - err := handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, logutil.DDLLogger(), - func(context.Context) (bool, error) { - return true, handle.ResumeTask(w.ctx, taskKey) - }, - ) - if err != nil { - return err - } - err = handle.WaitTaskDoneOrPaused(ctx, task.ID) - if err := w.isReorgRunnable(reorgInfo.Job.ID, true); err != nil { - if dbterror.ErrPausedDDLJob.Equal(err) { - logutil.DDLLogger().Warn("job paused by user", zap.Error(err)) - return dbterror.ErrPausedDDLJob.GenWithStackByArgs(reorgInfo.Job.ID) - } - } - return err - }) - } else { - job := reorgInfo.Job - workerCntLimit := int(variable.GetDDLReorgWorkerCounter()) - cpuCount, err := handle.GetCPUCountOfNode(ctx) - if err != nil { - return err - } - concurrency := min(workerCntLimit, cpuCount) - logutil.DDLLogger().Info("adjusted add-index task concurrency", - zap.Int("worker-cnt", workerCntLimit), zap.Int("task-concurrency", concurrency), - zap.String("task-key", taskKey)) - rowSize := estimateTableRowSize(w.ctx, w.store, w.sess.GetRestrictedSQLExecutor(), t) - taskMeta := &BackfillTaskMeta{ - Job: *job.Clone(), - EleIDs: extractElemIDs(reorgInfo), - EleTypeKey: reorgInfo.currElement.TypeKey, - CloudStorageURI: w.jobContext(job.ID, job.ReorgMeta).cloudStorageURI, - EstimateRowSize: rowSize, - } - - metaData, err := json.Marshal(taskMeta) - if err != nil { - return err - } - - g.Go(func() error { - defer close(done) - err := submitAndWaitTask(ctx, taskKey, taskType, concurrency, reorgInfo.ReorgMeta.TargetScope, metaData) - failpoint.InjectCall("pauseAfterDistTaskFinished") - if err := w.isReorgRunnable(reorgInfo.Job.ID, true); err != nil { - if dbterror.ErrPausedDDLJob.Equal(err) { - logutil.DDLLogger().Warn("job paused by user", zap.Error(err)) - return dbterror.ErrPausedDDLJob.GenWithStackByArgs(reorgInfo.Job.ID) - } - } - return err - }) - } - - g.Go(func() error { - checkFinishTk := time.NewTicker(CheckBackfillJobFinishInterval) - defer checkFinishTk.Stop() - updateRowCntTk := time.NewTicker(UpdateBackfillJobRowCountInterval) - defer updateRowCntTk.Stop() - for { - select { - case <-done: - w.updateJobRowCount(taskKey, reorgInfo.Job.ID) - return nil - case <-checkFinishTk.C: - if err = w.isReorgRunnable(reorgInfo.Job.ID, true); err != nil { - if dbterror.ErrPausedDDLJob.Equal(err) { - if err = 
handle.PauseTask(w.ctx, taskKey); err != nil { - logutil.DDLLogger().Error("pause task error", zap.String("task_key", taskKey), zap.Error(err)) - continue - } - failpoint.InjectCall("syncDDLTaskPause") - } - if !dbterror.ErrCancelledDDLJob.Equal(err) { - return errors.Trace(err) - } - if err = handle.CancelTask(w.ctx, taskKey); err != nil { - logutil.DDLLogger().Error("cancel task error", zap.String("task_key", taskKey), zap.Error(err)) - // continue to cancel task. - continue - } - } - case <-updateRowCntTk.C: - w.updateJobRowCount(taskKey, reorgInfo.Job.ID) - } - } - }) - err = g.Wait() - return err -} - -// EstimateTableRowSizeForTest is used for test. -var EstimateTableRowSizeForTest = estimateTableRowSize - -// estimateTableRowSize estimates the row size in bytes of a table. -// This function tries to retrieve row size in following orders: -// 1. AVG_ROW_LENGTH column from information_schema.tables. -// 2. region info's approximate key size / key number. -func estimateTableRowSize( - ctx context.Context, - store kv.Storage, - exec sqlexec.RestrictedSQLExecutor, - tbl table.Table, -) (sizeInBytes int) { - defer util.Recover(metrics.LabelDDL, "estimateTableRowSize", nil, false) - var gErr error - defer func() { - tidblogutil.Logger(ctx).Info("estimate row size", - zap.Int64("tableID", tbl.Meta().ID), zap.Int("size", sizeInBytes), zap.Error(gErr)) - }() - rows, _, err := exec.ExecRestrictedSQL(ctx, nil, - "select AVG_ROW_LENGTH from information_schema.tables where TIDB_TABLE_ID = %?", tbl.Meta().ID) - if err != nil { - gErr = err - return 0 - } - if len(rows) == 0 { - gErr = errors.New("no average row data") - return 0 - } - avgRowSize := rows[0].GetInt64(0) - if avgRowSize != 0 { - return int(avgRowSize) - } - regionRowSize, err := estimateRowSizeFromRegion(ctx, store, tbl) - if err != nil { - gErr = err - return 0 - } - return regionRowSize -} - -func estimateRowSizeFromRegion(ctx context.Context, store kv.Storage, tbl table.Table) (int, error) { - hStore, ok := store.(helper.Storage) - if !ok { - return 0, fmt.Errorf("not a helper.Storage") - } - h := &helper.Helper{ - Store: hStore, - RegionCache: hStore.GetRegionCache(), - } - pdCli, err := h.TryGetPDHTTPClient() - if err != nil { - return 0, err - } - pid := tbl.Meta().ID - sk, ek := tablecodec.GetTableHandleKeyRange(pid) - sRegion, err := pdCli.GetRegionByKey(ctx, codec.EncodeBytes(nil, sk)) - if err != nil { - return 0, err - } - eRegion, err := pdCli.GetRegionByKey(ctx, codec.EncodeBytes(nil, ek)) - if err != nil { - return 0, err - } - sk, err = hex.DecodeString(sRegion.StartKey) - if err != nil { - return 0, err - } - ek, err = hex.DecodeString(eRegion.EndKey) - if err != nil { - return 0, err - } - // We use the second region to prevent the influence of the front and back tables. 
- regionLimit := 3 - regionInfos, err := pdCli.GetRegionsByKeyRange(ctx, pdHttp.NewKeyRange(sk, ek), regionLimit) - if err != nil { - return 0, err - } - if len(regionInfos.Regions) != regionLimit { - return 0, fmt.Errorf("less than 3 regions") - } - sample := regionInfos.Regions[1] - if sample.ApproximateKeys == 0 || sample.ApproximateSize == 0 { - return 0, fmt.Errorf("zero approximate size") - } - return int(uint64(sample.ApproximateSize)*size.MB) / int(sample.ApproximateKeys), nil -} - -func (w *worker) updateJobRowCount(taskKey string, jobID int64) { - taskMgr, err := storage.GetTaskManager() - if err != nil { - logutil.DDLLogger().Warn("cannot get task manager", zap.String("task_key", taskKey), zap.Error(err)) - return - } - task, err := taskMgr.GetTaskByKey(w.ctx, taskKey) - if err != nil { - logutil.DDLLogger().Warn("cannot get task", zap.String("task_key", taskKey), zap.Error(err)) - return - } - rowCount, err := taskMgr.GetSubtaskRowCount(w.ctx, task.ID, proto.BackfillStepReadIndex) - if err != nil { - logutil.DDLLogger().Warn("cannot get subtask row count", zap.String("task_key", taskKey), zap.Error(err)) - return - } - w.getReorgCtx(jobID).setRowCount(rowCount) -} - -// submitAndWaitTask submits a task and wait for it to finish. -func submitAndWaitTask(ctx context.Context, taskKey string, taskType proto.TaskType, concurrency int, targetScope string, taskMeta []byte) error { - task, err := handle.SubmitTask(ctx, taskKey, taskType, concurrency, targetScope, taskMeta) - if err != nil { - return err - } - return handle.WaitTaskDoneOrPaused(ctx, task.ID) -} - -func getNextPartitionInfo(reorg *reorgInfo, t table.PartitionedTable, currPhysicalTableID int64) (int64, kv.Key, kv.Key, error) { - pi := t.Meta().GetPartitionInfo() - if pi == nil { - return 0, nil, nil, nil - } - - // This will be used in multiple different scenarios/ALTER TABLE: - // ADD INDEX - no change in partitions, just use pi.Definitions (1) - // REORGANIZE PARTITION - copy data from partitions to be dropped (2) - // REORGANIZE PARTITION - (re)create indexes on partitions to be added (3) - // REORGANIZE PARTITION - Update new Global indexes with data from non-touched partitions (4) - // (i.e. pi.Definitions - pi.DroppingDefinitions) - var pid int64 - var err error - if bytes.Equal(reorg.currElement.TypeKey, meta.IndexElementKey) { - // case 1, 3 or 4 - if len(pi.AddingDefinitions) == 0 { - // case 1 - // Simply AddIndex, without any partitions added or dropped! - pid, err = findNextPartitionID(currPhysicalTableID, pi.Definitions) - } else { - // case 3 (or if not found AddingDefinitions; 4) - // check if recreating Global Index (during Reorg Partition) - pid, err = findNextPartitionID(currPhysicalTableID, pi.AddingDefinitions) - if err != nil { - // case 4 - // Not a partition in the AddingDefinitions, so it must be an existing - // non-touched partition, i.e. recreating Global Index for the non-touched partitions - pid, err = findNextNonTouchedPartitionID(currPhysicalTableID, pi) - } - } - } else { - // case 2 - pid, err = findNextPartitionID(currPhysicalTableID, pi.DroppingDefinitions) - } - if err != nil { - // Fatal error, should not run here. - logutil.DDLLogger().Error("find next partition ID failed", zap.Reflect("table", t), zap.Error(err)) - return 0, nil, nil, errors.Trace(err) - } - if pid == 0 { - // Next partition does not exist, all the job done. 
- return 0, nil, nil, nil - } - - failpoint.Inject("mockUpdateCachedSafePoint", func(val failpoint.Value) { - //nolint:forcetypeassert - if val.(bool) { - ts := oracle.GoTimeToTS(time.Now()) - //nolint:forcetypeassert - s := reorg.d.store.(tikv.Storage) - s.UpdateSPCache(ts, time.Now()) - time.Sleep(time.Second * 3) - } - }) - - var startKey, endKey kv.Key - if reorg.mergingTmpIdx { - elements := reorg.elements - firstElemTempID := tablecodec.TempIndexPrefix | elements[0].ID - lastElemTempID := tablecodec.TempIndexPrefix | elements[len(elements)-1].ID - startKey = tablecodec.EncodeIndexSeekKey(pid, firstElemTempID, nil) - endKey = tablecodec.EncodeIndexSeekKey(pid, lastElemTempID, []byte{255}) - } else { - currentVer, err := getValidCurrentVersion(reorg.d.store) - if err != nil { - return 0, nil, nil, errors.Trace(err) - } - startKey, endKey, err = getTableRange(reorg.NewJobContext(), reorg.d, t.GetPartition(pid), currentVer.Ver, reorg.Job.Priority) - if err != nil { - return 0, nil, nil, errors.Trace(err) - } - } - return pid, startKey, endKey, nil -} - -// updateReorgInfo will find the next partition according to current reorgInfo. -// If no more partitions, or table t is not a partitioned table, returns true to -// indicate that the reorganize work is finished. -func updateReorgInfo(sessPool *sess.Pool, t table.PartitionedTable, reorg *reorgInfo) (bool, error) { - pid, startKey, endKey, err := getNextPartitionInfo(reorg, t, reorg.PhysicalTableID) - if err != nil { - return false, errors.Trace(err) - } - if pid == 0 { - // Next partition does not exist, all the job done. - return true, nil - } - reorg.PhysicalTableID, reorg.StartKey, reorg.EndKey = pid, startKey, endKey - - // Write the reorg info to store so the whole reorganize process can recover from panic. - err = reorg.UpdateReorgMeta(reorg.StartKey, sessPool) - logutil.DDLLogger().Info("job update reorgInfo", - zap.Int64("jobID", reorg.Job.ID), - zap.Stringer("element", reorg.currElement), - zap.Int64("partitionTableID", pid), - zap.String("startKey", hex.EncodeToString(reorg.StartKey)), - zap.String("endKey", hex.EncodeToString(reorg.EndKey)), zap.Error(err)) - return false, errors.Trace(err) -} - -// findNextPartitionID finds the next partition ID in the PartitionDefinition array. -// Returns 0 if current partition is already the last one. -func findNextPartitionID(currentPartition int64, defs []model.PartitionDefinition) (int64, error) { - for i, def := range defs { - if currentPartition == def.ID { - if i == len(defs)-1 { - return 0, nil - } - return defs[i+1].ID, nil - } - } - return 0, errors.Errorf("partition id not found %d", currentPartition) -} - -func findNextNonTouchedPartitionID(currPartitionID int64, pi *model.PartitionInfo) (int64, error) { - pid, err := findNextPartitionID(currPartitionID, pi.Definitions) - if err != nil { - return 0, err - } - if pid == 0 { - return 0, nil - } - for _, notFoundErr := findNextPartitionID(pid, pi.DroppingDefinitions); notFoundErr == nil; { - // This can be optimized, but it is not frequently called, so keeping as-is - pid, err = findNextPartitionID(pid, pi.Definitions) - if pid == 0 { - break - } - } - return pid, err -} - -// AllocateIndexID allocates an index ID from TableInfo. 
-func AllocateIndexID(tblInfo *model.TableInfo) int64 { - tblInfo.MaxIndexID++ - return tblInfo.MaxIndexID -} - -func getIndexInfoByNameAndColumn(oldTableInfo *model.TableInfo, newOne *model.IndexInfo) *model.IndexInfo { - for _, oldOne := range oldTableInfo.Indices { - if newOne.Name.L == oldOne.Name.L && indexColumnSliceEqual(newOne.Columns, oldOne.Columns) { - return oldOne - } - } - return nil -} - -func indexColumnSliceEqual(a, b []*model.IndexColumn) bool { - if len(a) != len(b) { - return false - } - if len(a) == 0 { - logutil.DDLLogger().Warn("admin repair table : index's columns length equal to 0") - return true - } - // Accelerate the compare by eliminate index bound check. - b = b[:len(a)] - for i, v := range a { - if v.Name.L != b[i].Name.L { - return false - } - } - return true -} - -type cleanUpIndexWorker struct { - baseIndexWorker -} - -func newCleanUpIndexWorker(id int, t table.PhysicalTable, decodeColMap map[int64]decoder.Column, reorgInfo *reorgInfo, jc *JobContext) (*cleanUpIndexWorker, error) { - bCtx, err := newBackfillCtx(id, reorgInfo, reorgInfo.SchemaName, t, jc, "cleanup_idx_rate", false) - if err != nil { - return nil, err - } - - indexes := make([]table.Index, 0, len(t.Indices())) - rowDecoder := decoder.NewRowDecoder(t, t.WritableCols(), decodeColMap) - for _, index := range t.Indices() { - if index.Meta().Global { - indexes = append(indexes, index) - } - } - return &cleanUpIndexWorker{ - baseIndexWorker: baseIndexWorker{ - backfillCtx: bCtx, - indexes: indexes, - rowDecoder: rowDecoder, - defaultVals: make([]types.Datum, len(t.WritableCols())), - rowMap: make(map[int64]types.Datum, len(decodeColMap)), - }, - }, nil -} - -func (w *cleanUpIndexWorker) BackfillData(handleRange reorgBackfillTask) (taskCtx backfillTaskContext, errInTxn error) { - failpoint.Inject("errorMockPanic", func(val failpoint.Value) { - //nolint:forcetypeassert - if val.(bool) { - panic("panic test") - } - }) - - oprStartTime := time.Now() - ctx := kv.WithInternalSourceAndTaskType(context.Background(), w.jobContext.ddlJobSourceType(), kvutil.ExplicitTypeDDL) - errInTxn = kv.RunInNewTxn(ctx, w.ddlCtx.store, true, func(_ context.Context, txn kv.Transaction) error { - taskCtx.addedCount = 0 - taskCtx.scanCount = 0 - updateTxnEntrySizeLimitIfNeeded(txn) - txn.SetOption(kv.Priority, handleRange.priority) - if tagger := w.GetCtx().getResourceGroupTaggerForTopSQL(handleRange.getJobID()); tagger != nil { - txn.SetOption(kv.ResourceGroupTagger, tagger) - } - txn.SetOption(kv.ResourceGroupName, w.jobContext.resourceGroupName) - - idxRecords, nextKey, taskDone, err := w.fetchRowColVals(txn, handleRange) - if err != nil { - return errors.Trace(err) - } - taskCtx.nextKey = nextKey - taskCtx.done = taskDone - - txn.SetDiskFullOpt(kvrpcpb.DiskFullOpt_AllowedOnAlmostFull) - - n := len(w.indexes) - for i, idxRecord := range idxRecords { - taskCtx.scanCount++ - // we fetch records row by row, so records will belong to - // index[0], index[1] ... index[n-1], index[0], index[1] ... - // respectively. So indexes[i%n] is the index of idxRecords[i]. 
- err := w.indexes[i%n].Delete(w.tblCtx, txn, idxRecord.vals, idxRecord.handle) - if err != nil { - return errors.Trace(err) - } - taskCtx.addedCount++ - } - return nil - }) - logSlowOperations(time.Since(oprStartTime), "cleanUpIndexBackfillDataInTxn", 3000) - failpoint.Inject("mockDMLExecution", func(val failpoint.Value) { - //nolint:forcetypeassert - if val.(bool) && MockDMLExecution != nil { - MockDMLExecution() - } - }) - - return -} - -// cleanupPhysicalTableIndex handles the drop partition reorganization state for a non-partitioned table or a partition. -func (w *worker) cleanupPhysicalTableIndex(t table.PhysicalTable, reorgInfo *reorgInfo) error { - logutil.DDLLogger().Info("start to clean up index", zap.Stringer("job", reorgInfo.Job), zap.Stringer("reorgInfo", reorgInfo)) - return w.writePhysicalTableRecord(w.ctx, w.sessPool, t, typeCleanUpIndexWorker, reorgInfo) -} - -// cleanupGlobalIndex handles the drop partition reorganization state to clean up index entries of partitions. -func (w *worker) cleanupGlobalIndexes(tbl table.PartitionedTable, partitionIDs []int64, reorgInfo *reorgInfo) error { - var err error - var finish bool - for !finish { - p := tbl.GetPartition(reorgInfo.PhysicalTableID) - if p == nil { - return dbterror.ErrCancelledDDLJob.GenWithStack("Can not find partition id %d for table %d", reorgInfo.PhysicalTableID, tbl.Meta().ID) - } - err = w.cleanupPhysicalTableIndex(p, reorgInfo) - if err != nil { - break - } - finish, err = w.updateReorgInfoForPartitions(tbl, reorgInfo, partitionIDs) - if err != nil { - return errors.Trace(err) - } - } - - return errors.Trace(err) -} - -// updateReorgInfoForPartitions will find the next partition in partitionIDs according to current reorgInfo. -// If no more partitions, or table t is not a partitioned table, returns true to -// indicate that the reorganize work is finished. -func (w *worker) updateReorgInfoForPartitions(t table.PartitionedTable, reorg *reorgInfo, partitionIDs []int64) (bool, error) { - pi := t.Meta().GetPartitionInfo() - if pi == nil { - return true, nil - } - - var pid int64 - for i, pi := range partitionIDs { - if pi == reorg.PhysicalTableID { - if i == len(partitionIDs)-1 { - return true, nil - } - pid = partitionIDs[i+1] - break - } - } - - currentVer, err := getValidCurrentVersion(reorg.d.store) - if err != nil { - return false, errors.Trace(err) - } - start, end, err := getTableRange(reorg.NewJobContext(), reorg.d, t.GetPartition(pid), currentVer.Ver, reorg.Job.Priority) - if err != nil { - return false, errors.Trace(err) - } - reorg.StartKey, reorg.EndKey, reorg.PhysicalTableID = start, end, pid - - // Write the reorg info to store so the whole reorganize process can recover from panic. - err = reorg.UpdateReorgMeta(reorg.StartKey, w.sessPool) - logutil.DDLLogger().Info("job update reorg info", zap.Int64("jobID", reorg.Job.ID), - zap.Stringer("element", reorg.currElement), - zap.Int64("partition table ID", pid), zap.String("start key", hex.EncodeToString(start)), - zap.String("end key", hex.EncodeToString(end)), zap.Error(err)) - return false, errors.Trace(err) -} - -// changingIndex is used to store the index that need to be changed during modifying column. -type changingIndex struct { - IndexInfo *model.IndexInfo - // Column offset in idxInfo.Columns. - Offset int - // When the modifying column is contained in the index, a temp index is created. - // isTemp indicates whether the indexInfo is a temp index created by a previous modify column job. 
- isTemp bool -} - -// FindRelatedIndexesToChange finds the indexes that covering the given column. -// The normal one will be overridden by the temp one. -func FindRelatedIndexesToChange(tblInfo *model.TableInfo, colName model.CIStr) []changingIndex { - // In multi-schema change jobs that contains several "modify column" sub-jobs, there may be temp indexes for another temp index. - // To prevent reorganizing too many indexes, we should create the temp indexes that are really necessary. - var normalIdxInfos, tempIdxInfos []changingIndex - for _, idxInfo := range tblInfo.Indices { - if pos := findIdxCol(idxInfo, colName); pos != -1 { - isTemp := isTempIdxInfo(idxInfo, tblInfo) - r := changingIndex{IndexInfo: idxInfo, Offset: pos, isTemp: isTemp} - if isTemp { - tempIdxInfos = append(tempIdxInfos, r) - } else { - normalIdxInfos = append(normalIdxInfos, r) - } - } - } - // Overwrite if the index has the corresponding temp index. For example, - // we try to find the indexes that contain the column `b` and there are two indexes, `i(a, b)` and `$i($a, b)`. - // Note that the symbol `$` means temporary. The index `$i($a, b)` is temporarily created by the previous "modify a" statement. - // In this case, we would create a temporary index like $$i($a, $b), so the latter should be chosen. - result := normalIdxInfos - for _, tmpIdx := range tempIdxInfos { - origName := getChangingIndexOriginName(tmpIdx.IndexInfo) - for i, normIdx := range normalIdxInfos { - if normIdx.IndexInfo.Name.O == origName { - result[i] = tmpIdx - } - } - } - return result -} - -func isTempIdxInfo(idxInfo *model.IndexInfo, tblInfo *model.TableInfo) bool { - for _, idxCol := range idxInfo.Columns { - if tblInfo.Columns[idxCol.Offset].ChangeStateInfo != nil { - return true - } - } - return false -} - -func findIdxCol(idxInfo *model.IndexInfo, colName model.CIStr) int { - for offset, idxCol := range idxInfo.Columns { - if idxCol.Name.L == colName.L { - return offset - } - } - return -1 -} - -func renameIndexes(tblInfo *model.TableInfo, from, to model.CIStr) { - for _, idx := range tblInfo.Indices { - if idx.Name.L == from.L { - idx.Name = to - } else if isTempIdxInfo(idx, tblInfo) && getChangingIndexOriginName(idx) == from.O { - idx.Name.L = strings.Replace(idx.Name.L, from.L, to.L, 1) - idx.Name.O = strings.Replace(idx.Name.O, from.O, to.O, 1) - } - } -} - -func renameHiddenColumns(tblInfo *model.TableInfo, from, to model.CIStr) { - for _, col := range tblInfo.Columns { - if col.Hidden && getExpressionIndexOriginName(col) == from.O { - col.Name.L = strings.Replace(col.Name.L, from.L, to.L, 1) - col.Name.O = strings.Replace(col.Name.O, from.O, to.O, 1) - } - } -} diff --git a/pkg/ddl/index_cop.go b/pkg/ddl/index_cop.go index dd05fd7fee847..30d6f70c8b9e0 100644 --- a/pkg/ddl/index_cop.go +++ b/pkg/ddl/index_cop.go @@ -145,11 +145,11 @@ func scanRecords(p *copReqSenderPool, task *reorgBackfillTask, se *sess.Session) if err != nil { return err } - if val, _err_ := failpoint.Eval(_curpkg_("mockCopSenderPanic")); _err_ == nil { + failpoint.Inject("mockCopSenderPanic", func(val failpoint.Value) { if val.(bool) { panic("mock panic") } - } + }) if p.checkpointMgr != nil { p.checkpointMgr.Register(task.id, task.endKey) } @@ -169,9 +169,9 @@ func scanRecords(p *copReqSenderPool, task *reorgBackfillTask, se *sess.Session) idxRs := IndexRecordChunk{ID: task.id, Chunk: srcChk, Done: done} rate := float64(srcChk.MemoryUsage()) / 1024.0 / 1024.0 / time.Since(startTime).Seconds() 
metrics.AddIndexScanRate.WithLabelValues(metrics.LblAddIndex).Observe(rate) - if _, _err_ := failpoint.Eval(_curpkg_("mockCopSenderError")); _err_ == nil { + failpoint.Inject("mockCopSenderError", func() { idxRs.Err = errors.New("mock cop error") - } + }) p.chunkSender.AddTask(idxRs) startTime = time.Now() } diff --git a/pkg/ddl/index_cop.go__failpoint_stash__ b/pkg/ddl/index_cop.go__failpoint_stash__ deleted file mode 100644 index 30d6f70c8b9e0..0000000000000 --- a/pkg/ddl/index_cop.go__failpoint_stash__ +++ /dev/null @@ -1,392 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ddl - -import ( - "context" - "encoding/hex" - "sync" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/ddl/copr" - "github.com/pingcap/tidb/pkg/ddl/ingest" - sess "github.com/pingcap/tidb/pkg/ddl/session" - "github.com/pingcap/tidb/pkg/distsql" - distsqlctx "github.com/pingcap/tidb/pkg/distsql/context" - "github.com/pingcap/tidb/pkg/errctx" - "github.com/pingcap/tidb/pkg/expression" - exprctx "github.com/pingcap/tidb/pkg/expression/context" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/table/tables" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/collate" - "github.com/pingcap/tidb/pkg/util/dbterror" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/timeutil" - "github.com/pingcap/tipb/go-tipb" - kvutil "github.com/tikv/client-go/v2/util" - "go.uber.org/zap" -) - -// copReadBatchSize is the batch size of coprocessor read. -// It multiplies the tidb_ddl_reorg_batch_size by 10 to avoid -// sending too many cop requests for the same handle range. -func copReadBatchSize() int { - return 10 * int(variable.GetDDLReorgBatchSize()) -} - -// copReadChunkPoolSize is the size of chunk pool, which -// represents the max concurrent ongoing coprocessor requests. -// It multiplies the tidb_ddl_reorg_worker_cnt by 10. -func copReadChunkPoolSize() int { - return 10 * int(variable.GetDDLReorgWorkerCounter()) -} - -// chunkSender is used to receive the result of coprocessor request. 
-type chunkSender interface { - AddTask(IndexRecordChunk) -} - -type copReqSenderPool struct { - tasksCh chan *reorgBackfillTask - chunkSender chunkSender - checkpointMgr *ingest.CheckpointManager - sessPool *sess.Pool - - ctx context.Context - copCtx copr.CopContext - store kv.Storage - - senders []*copReqSender - wg sync.WaitGroup - closed bool - - srcChkPool chan *chunk.Chunk -} - -type copReqSender struct { - senderPool *copReqSenderPool - - ctx context.Context - cancel context.CancelFunc -} - -func (c *copReqSender) run() { - p := c.senderPool - defer p.wg.Done() - defer util.Recover(metrics.LabelDDL, "copReqSender.run", func() { - p.chunkSender.AddTask(IndexRecordChunk{Err: dbterror.ErrReorgPanic}) - }, false) - sessCtx, err := p.sessPool.Get() - if err != nil { - logutil.Logger(p.ctx).Error("copReqSender get session from pool failed", zap.Error(err)) - p.chunkSender.AddTask(IndexRecordChunk{Err: err}) - return - } - se := sess.NewSession(sessCtx) - defer p.sessPool.Put(sessCtx) - var ( - task *reorgBackfillTask - ok bool - ) - - for { - select { - case <-c.ctx.Done(): - return - case task, ok = <-p.tasksCh: - } - if !ok { - return - } - if p.checkpointMgr != nil && p.checkpointMgr.IsKeyProcessed(task.endKey) { - logutil.Logger(p.ctx).Info("checkpoint detected, skip a cop-request task", - zap.Int("task ID", task.id), - zap.String("task end key", hex.EncodeToString(task.endKey))) - continue - } - err := scanRecords(p, task, se) - if err != nil { - p.chunkSender.AddTask(IndexRecordChunk{ID: task.id, Err: err}) - return - } - } -} - -func scanRecords(p *copReqSenderPool, task *reorgBackfillTask, se *sess.Session) error { - logutil.Logger(p.ctx).Info("start a cop-request task", - zap.Int("id", task.id), zap.Stringer("task", task)) - - return wrapInBeginRollback(se, func(startTS uint64) error { - rs, err := buildTableScan(p.ctx, p.copCtx.GetBase(), startTS, task.startKey, task.endKey) - if err != nil { - return err - } - failpoint.Inject("mockCopSenderPanic", func(val failpoint.Value) { - if val.(bool) { - panic("mock panic") - } - }) - if p.checkpointMgr != nil { - p.checkpointMgr.Register(task.id, task.endKey) - } - var done bool - startTime := time.Now() - for !done { - srcChk := p.getChunk() - done, err = fetchTableScanResult(p.ctx, p.copCtx.GetBase(), rs, srcChk) - if err != nil { - p.recycleChunk(srcChk) - terror.Call(rs.Close) - return err - } - if p.checkpointMgr != nil { - p.checkpointMgr.UpdateTotalKeys(task.id, srcChk.NumRows(), done) - } - idxRs := IndexRecordChunk{ID: task.id, Chunk: srcChk, Done: done} - rate := float64(srcChk.MemoryUsage()) / 1024.0 / 1024.0 / time.Since(startTime).Seconds() - metrics.AddIndexScanRate.WithLabelValues(metrics.LblAddIndex).Observe(rate) - failpoint.Inject("mockCopSenderError", func() { - idxRs.Err = errors.New("mock cop error") - }) - p.chunkSender.AddTask(idxRs) - startTime = time.Now() - } - terror.Call(rs.Close) - return nil - }) -} - -func wrapInBeginRollback(se *sess.Session, f func(startTS uint64) error) error { - err := se.Begin(context.Background()) - if err != nil { - return errors.Trace(err) - } - defer se.Rollback() - var startTS uint64 - sessVars := se.GetSessionVars() - sessVars.TxnCtxMu.Lock() - startTS = sessVars.TxnCtx.StartTS - sessVars.TxnCtxMu.Unlock() - return f(startTS) -} - -func newCopReqSenderPool(ctx context.Context, copCtx copr.CopContext, store kv.Storage, - taskCh chan *reorgBackfillTask, sessPool *sess.Pool, - checkpointMgr *ingest.CheckpointManager) *copReqSenderPool { - poolSize := copReadChunkPoolSize() - 
srcChkPool := make(chan *chunk.Chunk, poolSize) - for i := 0; i < poolSize; i++ { - srcChkPool <- chunk.NewChunkWithCapacity(copCtx.GetBase().FieldTypes, copReadBatchSize()) - } - return &copReqSenderPool{ - tasksCh: taskCh, - ctx: ctx, - copCtx: copCtx, - store: store, - senders: make([]*copReqSender, 0, variable.GetDDLReorgWorkerCounter()), - wg: sync.WaitGroup{}, - srcChkPool: srcChkPool, - sessPool: sessPool, - checkpointMgr: checkpointMgr, - } -} - -func (c *copReqSenderPool) adjustSize(n int) { - // Add some senders. - for i := len(c.senders); i < n; i++ { - ctx, cancel := context.WithCancel(c.ctx) - c.senders = append(c.senders, &copReqSender{ - senderPool: c, - ctx: ctx, - cancel: cancel, - }) - c.wg.Add(1) - go c.senders[i].run() - } - // Remove some senders. - if n < len(c.senders) { - for i := n; i < len(c.senders); i++ { - c.senders[i].cancel() - } - c.senders = c.senders[:n] - } -} - -func (c *copReqSenderPool) close(force bool) { - if c.closed { - return - } - logutil.Logger(c.ctx).Info("close cop-request sender pool", zap.Bool("force", force)) - if force { - for _, w := range c.senders { - w.cancel() - } - } - // Wait for all cop-req senders to exit. - c.wg.Wait() - c.closed = true -} - -func (c *copReqSenderPool) getChunk() *chunk.Chunk { - chk := <-c.srcChkPool - newCap := copReadBatchSize() - if chk.Capacity() != newCap { - chk = chunk.NewChunkWithCapacity(c.copCtx.GetBase().FieldTypes, newCap) - } - chk.Reset() - return chk -} - -// recycleChunk puts the index record slice and the chunk back to the pool for reuse. -func (c *copReqSenderPool) recycleChunk(chk *chunk.Chunk) { - if chk == nil { - return - } - c.srcChkPool <- chk -} - -func buildTableScan(ctx context.Context, c *copr.CopContextBase, startTS uint64, start, end kv.Key) (distsql.SelectResult, error) { - dagPB, err := buildDAGPB(c.ExprCtx, c.DistSQLCtx, c.PushDownFlags, c.TableInfo, c.ColumnInfos) - if err != nil { - return nil, err - } - - var builder distsql.RequestBuilder - kvReq, err := builder. - SetDAGRequest(dagPB). - SetStartTS(startTS). - SetKeyRanges([]kv.KeyRange{{StartKey: start, EndKey: end}}). - SetKeepOrder(true). - SetFromSessionVars(c.DistSQLCtx). - SetConcurrency(1). 
- Build() - kvReq.RequestSource.RequestSourceInternal = true - kvReq.RequestSource.RequestSourceType = getDDLRequestSource(model.ActionAddIndex) - kvReq.RequestSource.ExplicitRequestSourceType = kvutil.ExplicitTypeDDL - if err != nil { - return nil, err - } - return distsql.Select(ctx, c.DistSQLCtx, kvReq, c.FieldTypes) -} - -func fetchTableScanResult( - ctx context.Context, - copCtx *copr.CopContextBase, - result distsql.SelectResult, - chk *chunk.Chunk, -) (bool, error) { - err := result.Next(ctx, chk) - if err != nil { - return false, errors.Trace(err) - } - if chk.NumRows() == 0 { - return true, nil - } - err = table.FillVirtualColumnValue( - copCtx.VirtualColumnsFieldTypes, copCtx.VirtualColumnsOutputOffsets, - copCtx.ExprColumnInfos, copCtx.ColumnInfos, copCtx.ExprCtx, chk) - return false, err -} - -func completeErr(err error, idxInfo *model.IndexInfo) error { - if expression.ErrInvalidJSONForFuncIndex.Equal(err) { - err = expression.ErrInvalidJSONForFuncIndex.GenWithStackByArgs(idxInfo.Name.O) - } - return errors.Trace(err) -} - -func getRestoreData(tblInfo *model.TableInfo, targetIdx, pkIdx *model.IndexInfo, handleDts []types.Datum) []types.Datum { - if !collate.NewCollationEnabled() || !tblInfo.IsCommonHandle || tblInfo.CommonHandleVersion == 0 { - return nil - } - if pkIdx == nil { - return nil - } - for i, pkIdxCol := range pkIdx.Columns { - pkCol := tblInfo.Columns[pkIdxCol.Offset] - if !types.NeedRestoredData(&pkCol.FieldType) { - // Since the handle data cannot be null, we can use SetNull to - // indicate that this column does not need to be restored. - handleDts[i].SetNull() - continue - } - tables.TryTruncateRestoredData(&handleDts[i], pkCol, pkIdxCol, targetIdx) - tables.ConvertDatumToTailSpaceCount(&handleDts[i], pkCol) - } - dtToRestored := handleDts[:0] - for _, handleDt := range handleDts { - if !handleDt.IsNull() { - dtToRestored = append(dtToRestored, handleDt) - } - } - return dtToRestored -} - -func buildDAGPB(exprCtx exprctx.BuildContext, distSQLCtx *distsqlctx.DistSQLContext, pushDownFlags uint64, tblInfo *model.TableInfo, colInfos []*model.ColumnInfo) (*tipb.DAGRequest, error) { - dagReq := &tipb.DAGRequest{} - dagReq.TimeZoneName, dagReq.TimeZoneOffset = timeutil.Zone(exprCtx.GetEvalCtx().Location()) - dagReq.Flags = pushDownFlags - for i := range colInfos { - dagReq.OutputOffsets = append(dagReq.OutputOffsets, uint32(i)) - } - execPB, err := constructTableScanPB(exprCtx, tblInfo, colInfos) - if err != nil { - return nil, err - } - dagReq.Executors = append(dagReq.Executors, execPB) - distsql.SetEncodeType(distSQLCtx, dagReq) - return dagReq, nil -} - -func constructTableScanPB(ctx exprctx.BuildContext, tblInfo *model.TableInfo, colInfos []*model.ColumnInfo) (*tipb.Executor, error) { - tblScan := tables.BuildTableScanFromInfos(tblInfo, colInfos) - tblScan.TableId = tblInfo.ID - err := tables.SetPBColumnsDefaultValue(ctx, tblScan.Columns, colInfos) - return &tipb.Executor{Tp: tipb.ExecType_TypeTableScan, TblScan: tblScan}, err -} - -func extractDatumByOffsets(ctx expression.EvalContext, row chunk.Row, offsets []int, expCols []*expression.Column, buf []types.Datum) []types.Datum { - for i, offset := range offsets { - c := expCols[offset] - row.DatumWithBuffer(offset, c.GetType(ctx), &buf[i]) - } - return buf -} - -func buildHandle(pkDts []types.Datum, tblInfo *model.TableInfo, - pkInfo *model.IndexInfo, loc *time.Location, errCtx errctx.Context) (kv.Handle, error) { - if tblInfo.IsCommonHandle { - tablecodec.TruncateIndexValues(tblInfo, pkInfo, pkDts) - 
handleBytes, err := codec.EncodeKey(loc, nil, pkDts...) - err = errCtx.HandleError(err) - if err != nil { - return nil, err - } - return kv.NewCommonHandle(handleBytes) - } - return kv.IntHandle(pkDts[0].GetInt64()), nil -} diff --git a/pkg/ddl/index_merge_tmp.go b/pkg/ddl/index_merge_tmp.go index e6c3c1c61cd33..c0001250b6c22 100644 --- a/pkg/ddl/index_merge_tmp.go +++ b/pkg/ddl/index_merge_tmp.go @@ -241,12 +241,12 @@ func (w *mergeIndexWorker) BackfillData(taskRange reorgBackfillTask) (taskCtx ba return nil }) - if val, _err_ := failpoint.Eval(_curpkg_("mockDMLExecutionMerging")); _err_ == nil { + failpoint.Inject("mockDMLExecutionMerging", func(val failpoint.Value) { //nolint:forcetypeassert if val.(bool) && MockDMLExecutionMerging != nil { MockDMLExecutionMerging() } - } + }) logSlowOperations(time.Since(oprStartTime), "AddIndexMergeDataInTxn", 3000) return } diff --git a/pkg/ddl/index_merge_tmp.go__failpoint_stash__ b/pkg/ddl/index_merge_tmp.go__failpoint_stash__ deleted file mode 100644 index c0001250b6c22..0000000000000 --- a/pkg/ddl/index_merge_tmp.go__failpoint_stash__ +++ /dev/null @@ -1,400 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ddl - -import ( - "bytes" - "context" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/ddl/logutil" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/parser/model" - driver "github.com/pingcap/tidb/pkg/store/driver/txn" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/table/tables" - "github.com/pingcap/tidb/pkg/tablecodec" - kvutil "github.com/tikv/client-go/v2/util" - "go.uber.org/zap" -) - -func (w *mergeIndexWorker) batchCheckTemporaryUniqueKey( - txn kv.Transaction, - idxRecords []*temporaryIndexRecord, -) error { - if !w.currentIndex.Unique { - // non-unique key need no check, just overwrite it, - // because in most case, backfilling indices is not exists. - return nil - } - - batchVals, err := txn.BatchGet(context.Background(), w.originIdxKeys) - if err != nil { - return errors.Trace(err) - } - - for i, key := range w.originIdxKeys { - if val, found := batchVals[string(key)]; found { - // Found a value in the original index key. - err := checkTempIndexKey(txn, idxRecords[i], val, w.table) - if err != nil { - if kv.ErrKeyExists.Equal(err) { - return driver.ExtractKeyExistsErrFromIndex(key, val, w.table.Meta(), w.currentIndex.ID) - } - return errors.Trace(err) - } - } else if idxRecords[i].distinct { - // The keys in w.batchCheckKeys also maybe duplicate, - // so we need to backfill the not found key into `batchVals` map. 
- batchVals[string(key)] = idxRecords[i].vals - } - } - return nil -} - -func checkTempIndexKey(txn kv.Transaction, tmpRec *temporaryIndexRecord, originIdxVal []byte, tblInfo table.Table) error { - if !tmpRec.delete { - if tmpRec.distinct && !bytes.Equal(originIdxVal, tmpRec.vals) { - return kv.ErrKeyExists - } - // The key has been found in the original index, skip merging it. - tmpRec.skip = true - return nil - } - // Delete operation. - distinct := tablecodec.IndexKVIsUnique(originIdxVal) - if !distinct { - // For non-distinct key, it is consist of a null value and the handle. - // Same as the non-unique indexes, replay the delete operation on non-distinct keys. - return nil - } - // For distinct index key values, prevent deleting an unexpected index KV in original index. - hdInVal, err := tablecodec.DecodeHandleInIndexValue(originIdxVal) - if err != nil { - return errors.Trace(err) - } - if !tmpRec.handle.Equal(hdInVal) { - // The inequality means multiple modifications happened in the same key. - // We use the handle in origin index value to check if the row exists. - rowKey := tablecodec.EncodeRecordKey(tblInfo.RecordPrefix(), hdInVal) - _, err := txn.Get(context.Background(), rowKey) - if err != nil { - if kv.IsErrNotFound(err) { - // The row is deleted, so we can merge the delete operation to the origin index. - tmpRec.skip = false - return nil - } - // Unexpected errors. - return errors.Trace(err) - } - // Don't delete the index key if the row exists. - tmpRec.skip = true - return nil - } - return nil -} - -// temporaryIndexRecord is the record information of an index. -type temporaryIndexRecord struct { - vals []byte - skip bool // skip indicates that the index key is already exists, we should not add it. - delete bool - unique bool - distinct bool - handle kv.Handle -} - -type mergeIndexWorker struct { - *backfillCtx - - indexes []table.Index - - tmpIdxRecords []*temporaryIndexRecord - originIdxKeys []kv.Key - tmpIdxKeys []kv.Key - - needValidateKey bool - currentTempIndexPrefix []byte - currentIndex *model.IndexInfo -} - -func newMergeTempIndexWorker(bfCtx *backfillCtx, t table.PhysicalTable, elements []*meta.Element) *mergeIndexWorker { - allIndexes := make([]table.Index, 0, len(elements)) - for _, elem := range elements { - indexInfo := model.FindIndexInfoByID(t.Meta().Indices, elem.ID) - index := tables.NewIndex(t.GetPhysicalID(), t.Meta(), indexInfo) - allIndexes = append(allIndexes, index) - } - - return &mergeIndexWorker{ - backfillCtx: bfCtx, - indexes: allIndexes, - } -} - -func (w *mergeIndexWorker) validateTaskRange(taskRange *reorgBackfillTask) (skip bool, err error) { - tmpID, err := tablecodec.DecodeIndexID(taskRange.startKey) - if err != nil { - return false, err - } - startIndexID := tmpID & tablecodec.IndexIDMask - tmpID, err = tablecodec.DecodeIndexID(taskRange.endKey) - if err != nil { - return false, err - } - endIndexID := tmpID & tablecodec.IndexIDMask - - w.needValidateKey = startIndexID != endIndexID - containsTargetID := false - for _, idx := range w.indexes { - idxInfo := idx.Meta() - if idxInfo.ID == startIndexID { - containsTargetID = true - w.currentIndex = idxInfo - break - } - if idxInfo.ID == endIndexID { - containsTargetID = true - } - } - return !containsTargetID, nil -} - -// BackfillData merge temp index data in txn. 
-func (w *mergeIndexWorker) BackfillData(taskRange reorgBackfillTask) (taskCtx backfillTaskContext, errInTxn error) { - skip, err := w.validateTaskRange(&taskRange) - if skip || err != nil { - return taskCtx, err - } - - oprStartTime := time.Now() - ctx := kv.WithInternalSourceAndTaskType(context.Background(), w.jobContext.ddlJobSourceType(), kvutil.ExplicitTypeDDL) - - errInTxn = kv.RunInNewTxn(ctx, w.ddlCtx.store, true, func(_ context.Context, txn kv.Transaction) error { - taskCtx.addedCount = 0 - taskCtx.scanCount = 0 - updateTxnEntrySizeLimitIfNeeded(txn) - txn.SetOption(kv.Priority, taskRange.priority) - if tagger := w.GetCtx().getResourceGroupTaggerForTopSQL(taskRange.getJobID()); tagger != nil { - txn.SetOption(kv.ResourceGroupTagger, tagger) - } - txn.SetOption(kv.ResourceGroupName, w.jobContext.resourceGroupName) - - tmpIdxRecords, nextKey, taskDone, err := w.fetchTempIndexVals(txn, taskRange) - if err != nil { - return errors.Trace(err) - } - taskCtx.nextKey = nextKey - taskCtx.done = taskDone - - err = w.batchCheckTemporaryUniqueKey(txn, tmpIdxRecords) - if err != nil { - return errors.Trace(err) - } - - for i, idxRecord := range tmpIdxRecords { - taskCtx.scanCount++ - // The index is already exists, we skip it, no needs to backfill it. - // The following update, delete, insert on these rows, TiDB can handle it correctly. - // If all batch are skipped, update first index key to make txn commit to release lock. - if idxRecord.skip { - continue - } - - // Lock the corresponding row keys so that it doesn't modify the index KVs - // that are changing by a pessimistic transaction. - rowKey := tablecodec.EncodeRecordKey(w.table.RecordPrefix(), idxRecord.handle) - err := txn.LockKeys(context.Background(), new(kv.LockCtx), rowKey) - if err != nil { - return errors.Trace(err) - } - - if idxRecord.delete { - if idxRecord.unique { - err = txn.GetMemBuffer().DeleteWithFlags(w.originIdxKeys[i], kv.SetNeedLocked) - } else { - err = txn.GetMemBuffer().Delete(w.originIdxKeys[i]) - } - } else { - err = txn.GetMemBuffer().Set(w.originIdxKeys[i], idxRecord.vals) - } - if err != nil { - return err - } - taskCtx.addedCount++ - } - return nil - }) - - failpoint.Inject("mockDMLExecutionMerging", func(val failpoint.Value) { - //nolint:forcetypeassert - if val.(bool) && MockDMLExecutionMerging != nil { - MockDMLExecutionMerging() - } - }) - logSlowOperations(time.Since(oprStartTime), "AddIndexMergeDataInTxn", 3000) - return -} - -func (*mergeIndexWorker) AddMetricInfo(float64) { -} - -func (*mergeIndexWorker) String() string { - return typeAddIndexMergeTmpWorker.String() -} - -func (w *mergeIndexWorker) GetCtx() *backfillCtx { - return w.backfillCtx -} - -func (w *mergeIndexWorker) prefixIsChanged(newKey kv.Key) bool { - return len(w.currentTempIndexPrefix) == 0 || !bytes.HasPrefix(newKey, w.currentTempIndexPrefix) -} - -func (w *mergeIndexWorker) updateCurrentIndexInfo(newIndexKey kv.Key) (skip bool, err error) { - tempIdxID, err := tablecodec.DecodeIndexID(newIndexKey) - if err != nil { - return false, err - } - idxID := tablecodec.IndexIDMask & tempIdxID - var curIdx *model.IndexInfo - for _, idx := range w.indexes { - if idx.Meta().ID == idxID { - curIdx = idx.Meta() - } - } - if curIdx == nil { - // Index IDs are always increasing, but not always continuous: - // if DDL adds another index between these indexes, it is possible that: - // multi-schema add index IDs = [1, 2, 4, 5] - // another index ID = [3] - // If the new index get rollback, temp index 0xFFxxx03 may have dirty records. 
- // We should skip these dirty records. - return true, nil - } - pfx := tablecodec.CutIndexPrefix(newIndexKey) - - w.currentTempIndexPrefix = kv.Key(pfx).Clone() - w.currentIndex = curIdx - - return false, nil -} - -func (w *mergeIndexWorker) fetchTempIndexVals( - txn kv.Transaction, - taskRange reorgBackfillTask, -) ([]*temporaryIndexRecord, kv.Key, bool, error) { - startTime := time.Now() - w.tmpIdxRecords = w.tmpIdxRecords[:0] - w.tmpIdxKeys = w.tmpIdxKeys[:0] - w.originIdxKeys = w.originIdxKeys[:0] - // taskDone means that the merged handle is out of taskRange.endHandle. - taskDone := false - oprStartTime := startTime - idxPrefix := w.table.IndexPrefix() - var lastKey kv.Key - err := iterateSnapshotKeys(w.jobContext, w.ddlCtx.store, taskRange.priority, idxPrefix, txn.StartTS(), - taskRange.startKey, taskRange.endKey, func(_ kv.Handle, indexKey kv.Key, rawValue []byte) (more bool, err error) { - oprEndTime := time.Now() - logSlowOperations(oprEndTime.Sub(oprStartTime), "iterate temporary index in merge process", 0) - oprStartTime = oprEndTime - - taskDone = indexKey.Cmp(taskRange.endKey) >= 0 - - if taskDone || len(w.tmpIdxRecords) >= w.batchCnt { - return false, nil - } - - if w.needValidateKey && w.prefixIsChanged(indexKey) { - skip, err := w.updateCurrentIndexInfo(indexKey) - if err != nil || skip { - return skip, err - } - } - - tempIdxVal, err := tablecodec.DecodeTempIndexValue(rawValue) - if err != nil { - return false, err - } - tempIdxVal, err = decodeTempIndexHandleFromIndexKV(indexKey, tempIdxVal, len(w.currentIndex.Columns)) - if err != nil { - return false, err - } - - tempIdxVal = tempIdxVal.FilterOverwritten() - - // Extract the operations on the original index and replay them later. - for _, elem := range tempIdxVal { - if elem.KeyVer == tables.TempIndexKeyTypeMerge || elem.KeyVer == tables.TempIndexKeyTypeDelete { - // For 'm' version kvs, they are double-written. - // For 'd' version kvs, they are written in the delete-only state and can be dropped safely. - continue - } - - originIdxKey := make([]byte, len(indexKey)) - copy(originIdxKey, indexKey) - tablecodec.TempIndexKey2IndexKey(originIdxKey) - - idxRecord := &temporaryIndexRecord{ - handle: elem.Handle, - delete: elem.Delete, - unique: elem.Distinct, - skip: false, - } - if !elem.Delete { - idxRecord.vals = elem.Value - idxRecord.distinct = tablecodec.IndexKVIsUnique(elem.Value) - } - w.tmpIdxRecords = append(w.tmpIdxRecords, idxRecord) - w.originIdxKeys = append(w.originIdxKeys, originIdxKey) - w.tmpIdxKeys = append(w.tmpIdxKeys, indexKey) - } - - lastKey = indexKey - return true, nil - }) - - if len(w.tmpIdxRecords) == 0 { - taskDone = true - } - var nextKey kv.Key - if taskDone { - nextKey = taskRange.endKey - } else { - nextKey = lastKey - } - - logutil.DDLLogger().Debug("merge temp index txn fetches handle info", zap.Uint64("txnStartTS", txn.StartTS()), - zap.String("taskRange", taskRange.String()), zap.Duration("takeTime", time.Since(startTime))) - return w.tmpIdxRecords, nextKey.Next(), taskDone, errors.Trace(err) -} - -func decodeTempIndexHandleFromIndexKV(indexKey kv.Key, tmpVal tablecodec.TempIndexValue, idxColLen int) (ret tablecodec.TempIndexValue, err error) { - for _, elem := range tmpVal { - if elem.Handle == nil { - // If the handle is not found in the value of the temp index, it means - // 1) This is not a deletion marker, the handle is in the key or the origin value. - // 2) This is a deletion marker, but the handle is in the key of temp index. 
- elem.Handle, err = tablecodec.DecodeIndexHandle(indexKey, elem.Value, idxColLen) - if err != nil { - return nil, err - } - } - } - return tmpVal, nil -} diff --git a/pkg/ddl/ingest/backend.go b/pkg/ddl/ingest/backend.go index 802a5424ad9d2..6a8a4e0e6666a 100644 --- a/pkg/ddl/ingest/backend.go +++ b/pkg/ddl/ingest/backend.go @@ -215,11 +215,11 @@ func (bc *litBackendCtx) Flush(mode FlushMode) (flushed, imported bool, err erro } }() } - if _, _err_ := failpoint.Eval(_curpkg_("mockDMLExecutionStateBeforeImport")); _err_ == nil { + failpoint.Inject("mockDMLExecutionStateBeforeImport", func(_ failpoint.Value) { if MockDMLExecutionStateBeforeImport != nil { MockDMLExecutionStateBeforeImport() } - } + }) for indexID, ei := range bc.engines { if err = bc.unsafeImportAndReset(ei); err != nil { @@ -286,9 +286,9 @@ func (bc *litBackendCtx) unsafeImportAndReset(ei *engineInfo) error { } err := resetFn(bc.ctx, ei.uuid) - if _, _err_ := failpoint.Eval(_curpkg_("mockResetEngineFailed")); _err_ == nil { + failpoint.Inject("mockResetEngineFailed", func() { err = fmt.Errorf("mock reset engine failed") - } + }) if err != nil { logutil.Logger(bc.ctx).Error(LitErrResetEngineFail, zap.Int64("index ID", ei.indexID)) err1 := closedEngine.Cleanup(bc.ctx) @@ -306,10 +306,10 @@ func (bc *litBackendCtx) unsafeImportAndReset(ei *engineInfo) error { var ForceSyncFlagForTest = false func (bc *litBackendCtx) checkFlush(mode FlushMode) (shouldFlush bool, shouldImport bool) { - if _, _err_ := failpoint.Eval(_curpkg_("forceSyncFlagForTest")); _err_ == nil { + failpoint.Inject("forceSyncFlagForTest", func() { // used in a manual test ForceSyncFlagForTest = true - } + }) if mode == FlushModeForceFlushAndImport || ForceSyncFlagForTest { return true, true } @@ -317,11 +317,11 @@ func (bc *litBackendCtx) checkFlush(mode FlushMode) (shouldFlush bool, shouldImp shouldImport = bc.diskRoot.ShouldImport() interval := bc.updateInterval // This failpoint will be manually set through HTTP status port. - if val, _err_ := failpoint.Eval(_curpkg_("mockSyncIntervalMs")); _err_ == nil { + failpoint.Inject("mockSyncIntervalMs", func(val failpoint.Value) { if v, ok := val.(int); ok { interval = time.Duration(v) * time.Millisecond } - } + }) shouldFlush = shouldImport || time.Since(bc.timeOfLastFlush.Load()) >= interval return shouldFlush, shouldImport diff --git a/pkg/ddl/ingest/backend.go__failpoint_stash__ b/pkg/ddl/ingest/backend.go__failpoint_stash__ deleted file mode 100644 index 6a8a4e0e6666a..0000000000000 --- a/pkg/ddl/ingest/backend.go__failpoint_stash__ +++ /dev/null @@ -1,343 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package ingest - -import ( - "context" - "fmt" - "sync" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - tikv "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/lightning/backend" - "github.com/pingcap/tidb/pkg/lightning/backend/encode" - "github.com/pingcap/tidb/pkg/lightning/backend/local" - "github.com/pingcap/tidb/pkg/lightning/common" - lightning "github.com/pingcap/tidb/pkg/lightning/config" - "github.com/pingcap/tidb/pkg/lightning/errormanager" - "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/util/logutil" - clientv3 "go.etcd.io/etcd/client/v3" - "go.etcd.io/etcd/client/v3/concurrency" - atomicutil "go.uber.org/atomic" - "go.uber.org/zap" -) - -// MockDMLExecutionStateBeforeImport is a failpoint to mock the DML execution state before import. -var MockDMLExecutionStateBeforeImport func() - -// BackendCtx is the backend context for one add index reorg task. -type BackendCtx interface { - // Register create a new engineInfo for each index ID and register it to the - // backend context. If the index ID is already registered, it will return the - // associated engines. Only one group of index ID is allowed to register for a - // BackendCtx. - // - // Register is only used in local disk based ingest. - Register(indexIDs []int64, uniques []bool, tbl table.Table) ([]Engine, error) - // FinishAndUnregisterEngines finishes the task and unregisters all engines that - // are Register-ed before. It's safe to call it multiple times. - // - // FinishAndUnregisterEngines is only used in local disk based ingest. - FinishAndUnregisterEngines(opt UnregisterOpt) error - - FlushController - - AttachCheckpointManager(*CheckpointManager) - GetCheckpointManager() *CheckpointManager - - // GetLocalBackend exposes local.Backend. It's only used in global sort based - // ingest. - GetLocalBackend() *local.Backend - // CollectRemoteDuplicateRows collects duplicate entry error for given index as - // the supplement of FlushController.Flush. - // - // CollectRemoteDuplicateRows is only used in global sort based ingest. - CollectRemoteDuplicateRows(indexID int64, tbl table.Table) error -} - -// FlushMode is used to control how to flush. -type FlushMode byte - -const ( - // FlushModeAuto means caller does not enforce any flush, the implementation can - // decide it. - FlushModeAuto FlushMode = iota - // FlushModeForceFlushAndImport means flush and import all data to TiKV. - FlushModeForceFlushAndImport -) - -// litBackendCtx implements BackendCtx. -type litBackendCtx struct { - engines map[int64]*engineInfo - memRoot MemRoot - diskRoot DiskRoot - jobID int64 - tbl table.Table - backend *local.Backend - ctx context.Context - cfg *lightning.Config - sysVars map[string]string - - flushing atomic.Bool - timeOfLastFlush atomicutil.Time - updateInterval time.Duration - checkpointMgr *CheckpointManager - etcdClient *clientv3.Client - - // unregisterMu prevents concurrent calls of `FinishAndUnregisterEngines`. - // For details, see https://github.com/pingcap/tidb/issues/53843. 
- unregisterMu sync.Mutex -} - -func (bc *litBackendCtx) handleErrorAfterCollectRemoteDuplicateRows( - err error, - indexID int64, - tbl table.Table, - hasDupe bool, -) error { - if err != nil && !common.ErrFoundIndexConflictRecords.Equal(err) { - logutil.Logger(bc.ctx).Error(LitInfoRemoteDupCheck, zap.Error(err), - zap.String("table", tbl.Meta().Name.O), zap.Int64("index ID", indexID)) - return errors.Trace(err) - } else if hasDupe { - logutil.Logger(bc.ctx).Error(LitErrRemoteDupExistErr, - zap.String("table", tbl.Meta().Name.O), zap.Int64("index ID", indexID)) - - if common.ErrFoundIndexConflictRecords.Equal(err) { - tErr, ok := errors.Cause(err).(*terror.Error) - if !ok { - return errors.Trace(tikv.ErrKeyExists) - } - if len(tErr.Args()) != 4 { - return errors.Trace(tikv.ErrKeyExists) - } - //nolint: forcetypeassert - indexName := tErr.Args()[1].(string) - //nolint: forcetypeassert - keyCols := tErr.Args()[2].([]string) - return errors.Trace(tikv.GenKeyExistsErr(keyCols, indexName)) - } - return errors.Trace(tikv.ErrKeyExists) - } - return nil -} - -// CollectRemoteDuplicateRows collects duplicate rows from remote TiKV. -func (bc *litBackendCtx) CollectRemoteDuplicateRows(indexID int64, tbl table.Table) error { - return bc.collectRemoteDuplicateRows(indexID, tbl) -} - -func (bc *litBackendCtx) collectRemoteDuplicateRows(indexID int64, tbl table.Table) error { - errorMgr := errormanager.New(nil, bc.cfg, log.Logger{Logger: logutil.Logger(bc.ctx)}) - dupeController := bc.backend.GetDupeController(bc.cfg.TikvImporter.RangeConcurrency*2, errorMgr) - hasDupe, err := dupeController.CollectRemoteDuplicateRows(bc.ctx, tbl, tbl.Meta().Name.L, &encode.SessionOptions{ - SQLMode: mysql.ModeStrictAllTables, - SysVars: bc.sysVars, - IndexID: indexID, - }, lightning.ErrorOnDup) - return bc.handleErrorAfterCollectRemoteDuplicateRows(err, indexID, tbl, hasDupe) -} - -func acquireLock(ctx context.Context, se *concurrency.Session, key string) (*concurrency.Mutex, error) { - mu := concurrency.NewMutex(se, key) - err := mu.Lock(ctx) - if err != nil { - return nil, err - } - return mu, nil -} - -// Flush implements FlushController. -func (bc *litBackendCtx) Flush(mode FlushMode) (flushed, imported bool, err error) { - shouldFlush, shouldImport := bc.checkFlush(mode) - if !shouldFlush { - return false, false, nil - } - if !bc.flushing.CompareAndSwap(false, true) { - return false, false, nil - } - defer bc.flushing.Store(false) - - for _, ei := range bc.engines { - ei.flushLock.Lock() - //nolint: all_revive,revive - defer ei.flushLock.Unlock() - - if err = ei.Flush(); err != nil { - return false, false, err - } - } - bc.timeOfLastFlush.Store(time.Now()) - - if !shouldImport { - return true, false, nil - } - - // Use distributed lock if run in distributed mode). 
- if bc.etcdClient != nil { - distLockKey := fmt.Sprintf("/tidb/distributeLock/%d", bc.jobID) - se, _ := concurrency.NewSession(bc.etcdClient) - mu, err := acquireLock(bc.ctx, se, distLockKey) - if err != nil { - return true, false, errors.Trace(err) - } - logutil.Logger(bc.ctx).Info("acquire distributed flush lock success", zap.Int64("jobID", bc.jobID)) - defer func() { - err = mu.Unlock(bc.ctx) - if err != nil { - logutil.Logger(bc.ctx).Warn("release distributed flush lock error", zap.Error(err), zap.Int64("jobID", bc.jobID)) - } else { - logutil.Logger(bc.ctx).Info("release distributed flush lock success", zap.Int64("jobID", bc.jobID)) - } - err = se.Close() - if err != nil { - logutil.Logger(bc.ctx).Warn("close session error", zap.Error(err)) - } - }() - } - failpoint.Inject("mockDMLExecutionStateBeforeImport", func(_ failpoint.Value) { - if MockDMLExecutionStateBeforeImport != nil { - MockDMLExecutionStateBeforeImport() - } - }) - - for indexID, ei := range bc.engines { - if err = bc.unsafeImportAndReset(ei); err != nil { - if common.ErrFoundDuplicateKeys.Equal(err) { - idxInfo := model.FindIndexInfoByID(bc.tbl.Meta().Indices, indexID) - if idxInfo == nil { - logutil.Logger(bc.ctx).Error( - "index not found", - zap.Int64("indexID", indexID)) - err = tikv.ErrKeyExists - } else { - err = TryConvertToKeyExistsErr(err, idxInfo, bc.tbl.Meta()) - } - } - return true, false, err - } - } - - var newTS uint64 - if mgr := bc.GetCheckpointManager(); mgr != nil { - // for local disk case, we need to refresh TS because duplicate detection - // requires each ingest to have a unique TS. - // - // TODO(lance6716): there's still a chance that data is imported but because of - // checkpoint is low-watermark, the data will still be imported again with - // another TS after failover. Need to refine the checkpoint mechanism. - newTS, err = mgr.refreshTSAndUpdateCP() - if err == nil { - for _, ei := range bc.engines { - ei.openedEngine.SetTS(newTS) - } - } - } - - return true, true, err -} - -func (bc *litBackendCtx) unsafeImportAndReset(ei *engineInfo) error { - logger := log.FromContext(bc.ctx).With( - zap.Stringer("engineUUID", ei.uuid), - ) - logger.Info(LitInfoUnsafeImport, - zap.Int64("index ID", ei.indexID), - zap.String("usage info", bc.diskRoot.UsageInfo())) - - closedEngine := backend.NewClosedEngine(bc.backend, logger, ei.uuid, 0) - - regionSplitSize := int64(lightning.SplitRegionSize) * int64(lightning.MaxSplitRegionSizeRatio) - regionSplitKeys := int64(lightning.SplitRegionKeys) - if err := closedEngine.Import(bc.ctx, regionSplitSize, regionSplitKeys); err != nil { - logutil.Logger(bc.ctx).Error(LitErrIngestDataErr, zap.Int64("index ID", ei.indexID), - zap.String("usage info", bc.diskRoot.UsageInfo())) - return err - } - - resetFn := bc.backend.ResetEngineSkipAllocTS - mgr := bc.GetCheckpointManager() - if mgr == nil { - // disttask case, no need to refresh TS. - // - // TODO(lance6716): for disttask local sort case, we need to use a fixed TS. But - // it doesn't have checkpoint, so we need to find a way to save TS. 
- resetFn = bc.backend.ResetEngine - } - - err := resetFn(bc.ctx, ei.uuid) - failpoint.Inject("mockResetEngineFailed", func() { - err = fmt.Errorf("mock reset engine failed") - }) - if err != nil { - logutil.Logger(bc.ctx).Error(LitErrResetEngineFail, zap.Int64("index ID", ei.indexID)) - err1 := closedEngine.Cleanup(bc.ctx) - if err1 != nil { - logutil.Logger(ei.ctx).Error(LitErrCleanEngineErr, zap.Error(err1), - zap.Int64("job ID", ei.jobID), zap.Int64("index ID", ei.indexID)) - } - ei.openedEngine = nil - return err - } - return nil -} - -// ForceSyncFlagForTest is a flag to force sync only for test. -var ForceSyncFlagForTest = false - -func (bc *litBackendCtx) checkFlush(mode FlushMode) (shouldFlush bool, shouldImport bool) { - failpoint.Inject("forceSyncFlagForTest", func() { - // used in a manual test - ForceSyncFlagForTest = true - }) - if mode == FlushModeForceFlushAndImport || ForceSyncFlagForTest { - return true, true - } - bc.diskRoot.UpdateUsage() - shouldImport = bc.diskRoot.ShouldImport() - interval := bc.updateInterval - // This failpoint will be manually set through HTTP status port. - failpoint.Inject("mockSyncIntervalMs", func(val failpoint.Value) { - if v, ok := val.(int); ok { - interval = time.Duration(v) * time.Millisecond - } - }) - shouldFlush = shouldImport || - time.Since(bc.timeOfLastFlush.Load()) >= interval - return shouldFlush, shouldImport -} - -// AttachCheckpointManager attaches a checkpoint manager to the backend context. -func (bc *litBackendCtx) AttachCheckpointManager(mgr *CheckpointManager) { - bc.checkpointMgr = mgr -} - -// GetCheckpointManager returns the checkpoint manager attached to the backend context. -func (bc *litBackendCtx) GetCheckpointManager() *CheckpointManager { - return bc.checkpointMgr -} - -// GetLocalBackend returns the local backend. -func (bc *litBackendCtx) GetLocalBackend() *local.Backend { - return bc.backend -} diff --git a/pkg/ddl/ingest/backend_mgr.go b/pkg/ddl/ingest/backend_mgr.go index 031a5e0da6886..068047e5a8710 100644 --- a/pkg/ddl/ingest/backend_mgr.go +++ b/pkg/ddl/ingest/backend_mgr.go @@ -136,9 +136,9 @@ func (m *litBackendCtxMgr) Register( logutil.Logger(ctx).Warn(LitWarnConfigError, zap.Int64("job ID", jobID), zap.Error(err)) return nil, err } - if _, _err_ := failpoint.Eval(_curpkg_("beforeCreateLocalBackend")); _err_ == nil { + failpoint.Inject("beforeCreateLocalBackend", func() { ResignOwnerForTest.Store(true) - } + }) // lock backends because createLocalBackend will let lightning create the sort // folder, which may cause cleanupSortPath wrongly delete the sort folder if only // checking the existence of the entry in backends. diff --git a/pkg/ddl/ingest/backend_mgr.go__failpoint_stash__ b/pkg/ddl/ingest/backend_mgr.go__failpoint_stash__ deleted file mode 100644 index 068047e5a8710..0000000000000 --- a/pkg/ddl/ingest/backend_mgr.go__failpoint_stash__ +++ /dev/null @@ -1,285 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package ingest - -import ( - "context" - "math" - "os" - "path/filepath" - "strconv" - "sync" - "time" - - "github.com/pingcap/failpoint" - ddllogutil "github.com/pingcap/tidb/pkg/ddl/logutil" - "github.com/pingcap/tidb/pkg/lightning/backend/local" - "github.com/pingcap/tidb/pkg/lightning/config" - "github.com/pingcap/tidb/pkg/util/logutil" - kvutil "github.com/tikv/client-go/v2/util" - pd "github.com/tikv/pd/client" - clientv3 "go.etcd.io/etcd/client/v3" - "go.uber.org/atomic" - "go.uber.org/zap" -) - -// BackendCtxMgr is used to manage the BackendCtx. -type BackendCtxMgr interface { - // CheckMoreTasksAvailable checks if it can run more ingest backfill tasks. - CheckMoreTasksAvailable() (bool, error) - // Register uses jobID to identify the BackendCtx. If there's already a - // BackendCtx with the same jobID, it will be returned. Otherwise, a new - // BackendCtx will be created and returned. - Register( - ctx context.Context, - jobID int64, - hasUnique bool, - etcdClient *clientv3.Client, - pdSvcDiscovery pd.ServiceDiscovery, - resourceGroupName string, - ) (BackendCtx, error) - Unregister(jobID int64) - // EncodeJobSortPath encodes the job ID to the local disk sort path. - EncodeJobSortPath(jobID int64) string - // Load returns the registered BackendCtx with the given jobID. - Load(jobID int64) (BackendCtx, bool) -} - -// litBackendCtxMgr manages multiple litBackendCtx for each DDL job. Each -// litBackendCtx can use some local disk space and memory resource which are -// controlled by litBackendCtxMgr. -type litBackendCtxMgr struct { - // the lifetime of entries in backends should cover all other resources so it can - // be used as a lightweight indicator when interacts with other resources. - // Currently, the entry must be created not after disk folder is created and - // memory usage is tracked, and vice versa when considering deletion. - backends struct { - mu sync.RWMutex - m map[int64]*litBackendCtx - } - // all disk resources of litBackendCtx should be used under path. Currently the - // hierarchy is ${path}/${jobID} for each litBackendCtx. - path string - memRoot MemRoot - diskRoot DiskRoot -} - -// NewLitBackendCtxMgr creates a new litBackendCtxMgr. -func NewLitBackendCtxMgr(path string, memQuota uint64) BackendCtxMgr { - mgr := &litBackendCtxMgr{ - path: path, - } - mgr.backends.m = make(map[int64]*litBackendCtx, 4) - mgr.memRoot = NewMemRootImpl(int64(memQuota), mgr) - mgr.diskRoot = NewDiskRootImpl(path, mgr) - LitMemRoot = mgr.memRoot - litDiskRoot = mgr.diskRoot - litDiskRoot.UpdateUsage() - err := litDiskRoot.StartupCheck() - if err != nil { - ddllogutil.DDLIngestLogger().Warn("ingest backfill may not be available", zap.Error(err)) - } - return mgr -} - -// CheckMoreTasksAvailable implements BackendCtxMgr.CheckMoreTaskAvailable interface. -func (m *litBackendCtxMgr) CheckMoreTasksAvailable() (bool, error) { - if err := m.diskRoot.PreCheckUsage(); err != nil { - ddllogutil.DDLIngestLogger().Info("ingest backfill is not available", zap.Error(err)) - return false, err - } - return true, nil -} - -// ResignOwnerForTest is only used for test. -var ResignOwnerForTest = atomic.NewBool(false) - -// Register creates a new backend and registers it to the backend context. 
-func (m *litBackendCtxMgr) Register( - ctx context.Context, - jobID int64, - hasUnique bool, - etcdClient *clientv3.Client, - pdSvcDiscovery pd.ServiceDiscovery, - resourceGroupName string, -) (BackendCtx, error) { - bc, exist := m.Load(jobID) - if exist { - return bc, nil - } - - m.memRoot.RefreshConsumption() - ok := m.memRoot.CheckConsume(structSizeBackendCtx) - if !ok { - return nil, genBackendAllocMemFailedErr(ctx, m.memRoot, jobID) - } - sortPath := m.EncodeJobSortPath(jobID) - err := os.MkdirAll(sortPath, 0700) - if err != nil { - logutil.Logger(ctx).Error(LitErrCreateDirFail, zap.Error(err)) - return nil, err - } - cfg, err := genConfig(ctx, sortPath, m.memRoot, hasUnique, resourceGroupName) - if err != nil { - logutil.Logger(ctx).Warn(LitWarnConfigError, zap.Int64("job ID", jobID), zap.Error(err)) - return nil, err - } - failpoint.Inject("beforeCreateLocalBackend", func() { - ResignOwnerForTest.Store(true) - }) - // lock backends because createLocalBackend will let lightning create the sort - // folder, which may cause cleanupSortPath wrongly delete the sort folder if only - // checking the existence of the entry in backends. - m.backends.mu.Lock() - bd, err := createLocalBackend(ctx, cfg, pdSvcDiscovery) - if err != nil { - m.backends.mu.Unlock() - logutil.Logger(ctx).Error(LitErrCreateBackendFail, zap.Int64("job ID", jobID), zap.Error(err)) - return nil, err - } - - bcCtx := newBackendContext(ctx, jobID, bd, cfg.lightning, defaultImportantVariables, m.memRoot, m.diskRoot, etcdClient) - m.backends.m[jobID] = bcCtx - m.memRoot.Consume(structSizeBackendCtx) - m.backends.mu.Unlock() - - logutil.Logger(ctx).Info(LitInfoCreateBackend, zap.Int64("job ID", jobID), - zap.Int64("current memory usage", m.memRoot.CurrentUsage()), - zap.Int64("max memory quota", m.memRoot.MaxMemoryQuota()), - zap.Bool("has unique index", hasUnique)) - return bcCtx, nil -} - -// EncodeJobSortPath implements BackendCtxMgr. -func (m *litBackendCtxMgr) EncodeJobSortPath(jobID int64) string { - return filepath.Join(m.path, encodeBackendTag(jobID)) -} - -func createLocalBackend( - ctx context.Context, - cfg *litConfig, - pdSvcDiscovery pd.ServiceDiscovery, -) (*local.Backend, error) { - tls, err := cfg.lightning.ToTLS() - if err != nil { - logutil.Logger(ctx).Error(LitErrCreateBackendFail, zap.Error(err)) - return nil, err - } - - ddllogutil.DDLIngestLogger().Info("create local backend for adding index", - zap.String("sortDir", cfg.lightning.TikvImporter.SortedKVDir), - zap.String("keyspaceName", cfg.keyspaceName)) - // We disable the switch TiKV mode feature for now, - // because the impact is not fully tested. 
- var raftKV2SwitchModeDuration time.Duration - backendConfig := local.NewBackendConfig(cfg.lightning, int(litRLimit), cfg.keyspaceName, cfg.resourceGroup, kvutil.ExplicitTypeDDL, raftKV2SwitchModeDuration) - return local.NewBackend(ctx, tls, backendConfig, pdSvcDiscovery) -} - -const checkpointUpdateInterval = 10 * time.Minute - -func newBackendContext( - ctx context.Context, - jobID int64, - be *local.Backend, - cfg *config.Config, - vars map[string]string, - memRoot MemRoot, - diskRoot DiskRoot, - etcdClient *clientv3.Client, -) *litBackendCtx { - bCtx := &litBackendCtx{ - engines: make(map[int64]*engineInfo, 10), - memRoot: memRoot, - diskRoot: diskRoot, - jobID: jobID, - backend: be, - ctx: ctx, - cfg: cfg, - sysVars: vars, - updateInterval: checkpointUpdateInterval, - etcdClient: etcdClient, - } - bCtx.timeOfLastFlush.Store(time.Now()) - return bCtx -} - -// Unregister removes a backend context from the backend context manager. -func (m *litBackendCtxMgr) Unregister(jobID int64) { - m.backends.mu.RLock() - _, exist := m.backends.m[jobID] - m.backends.mu.RUnlock() - if !exist { - return - } - - m.backends.mu.Lock() - defer m.backends.mu.Unlock() - bc, exist := m.backends.m[jobID] - if !exist { - return - } - _ = bc.FinishAndUnregisterEngines(OptCloseEngines) - bc.backend.Close() - m.memRoot.Release(structSizeBackendCtx) - m.memRoot.ReleaseWithTag(encodeBackendTag(jobID)) - logutil.Logger(bc.ctx).Info(LitInfoCloseBackend, zap.Int64("job ID", jobID), - zap.Int64("current memory usage", m.memRoot.CurrentUsage()), - zap.Int64("max memory quota", m.memRoot.MaxMemoryQuota())) - delete(m.backends.m, jobID) -} - -func (m *litBackendCtxMgr) Load(jobID int64) (BackendCtx, bool) { - m.backends.mu.RLock() - defer m.backends.mu.RUnlock() - ret, ok := m.backends.m[jobID] - return ret, ok -} - -// TotalDiskUsage returns the total disk usage of all backends. -func (m *litBackendCtxMgr) TotalDiskUsage() uint64 { - var totalDiskUsed uint64 - m.backends.mu.RLock() - defer m.backends.mu.RUnlock() - - for _, bc := range m.backends.m { - _, _, bcDiskUsed, _ := local.CheckDiskQuota(bc.backend, math.MaxInt64) - totalDiskUsed += uint64(bcDiskUsed) - } - return totalDiskUsed -} - -// UpdateMemoryUsage collects the memory usages from all the backend and updates it to the memRoot. -func (m *litBackendCtxMgr) UpdateMemoryUsage() { - m.backends.mu.RLock() - defer m.backends.mu.RUnlock() - - for _, bc := range m.backends.m { - curSize := bc.backend.TotalMemoryConsume() - m.memRoot.ReleaseWithTag(encodeBackendTag(bc.jobID)) - m.memRoot.ConsumeWithTag(encodeBackendTag(bc.jobID), curSize) - } -} - -// encodeBackendTag encodes the job ID to backend tag. -// The backend tag is also used as the file name of the local index data files. -func encodeBackendTag(jobID int64) string { - return strconv.FormatInt(jobID, 10) -} - -// decodeBackendTag decodes the backend tag to job ID. 
-func decodeBackendTag(name string) (int64, error) { - return strconv.ParseInt(name, 10, 64) -} diff --git a/pkg/ddl/ingest/binding__failpoint_binding__.go b/pkg/ddl/ingest/binding__failpoint_binding__.go deleted file mode 100644 index 207ba83505336..0000000000000 --- a/pkg/ddl/ingest/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package ingest - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/ddl/ingest/checkpoint.go b/pkg/ddl/ingest/checkpoint.go index 8560f616abe3d..2d1536905736c 100644 --- a/pkg/ddl/ingest/checkpoint.go +++ b/pkg/ddl/ingest/checkpoint.go @@ -221,14 +221,14 @@ func (s *CheckpointManager) AdvanceWatermark(flushed, imported bool) { return } - if _, _err_ := failpoint.Eval(_curpkg_("resignAfterFlush")); _err_ == nil { + failpoint.Inject("resignAfterFlush", func() { // used in a manual test ResignOwnerForTest.Store(true) // wait until ResignOwnerForTest is processed for ResignOwnerForTest.Load() { time.Sleep(100 * time.Millisecond) } - } + }) s.mu.Lock() defer s.mu.Unlock() @@ -445,10 +445,10 @@ func (s *CheckpointManager) updateCheckpointImpl() error { } func (s *CheckpointManager) updateCheckpointLoop() { - if _, _err_ := failpoint.Eval(_curpkg_("checkpointLoopExit")); _err_ == nil { + failpoint.Inject("checkpointLoopExit", func() { // used in a manual test - return - } + failpoint.Return() + }) ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() for { @@ -477,10 +477,10 @@ func (s *CheckpointManager) updateCheckpointLoop() { } func (s *CheckpointManager) updateCheckpoint() error { - if _, _err_ := failpoint.Eval(_curpkg_("checkpointLoopExit")); _err_ == nil { + failpoint.Inject("checkpointLoopExit", func() { // used in a manual test - return errors.New("failpoint triggered so can't update checkpoint") - } + failpoint.Return(errors.New("failpoint triggered so can't update checkpoint")) + }) finishCh := make(chan struct{}) select { case s.updaterCh <- finishCh: diff --git a/pkg/ddl/ingest/checkpoint.go__failpoint_stash__ b/pkg/ddl/ingest/checkpoint.go__failpoint_stash__ deleted file mode 100644 index 2d1536905736c..0000000000000 --- a/pkg/ddl/ingest/checkpoint.go__failpoint_stash__ +++ /dev/null @@ -1,509 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package ingest - -import ( - "context" - "encoding/hex" - "encoding/json" - "fmt" - "net" - "strconv" - "sync" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/ddl/logutil" - sess "github.com/pingcap/tidb/pkg/ddl/session" - "github.com/pingcap/tidb/pkg/ddl/util" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/tikv/client-go/v2/oracle" - pd "github.com/tikv/pd/client" - "go.uber.org/zap" -) - -// CheckpointManager is a checkpoint manager implementation that used by -// non-distributed reorganization. It manages the data as two-level checkpoints: -// "flush"ed to local storage and "import"ed to TiKV. The checkpoint is saved in -// a table in the TiDB cluster. -type CheckpointManager struct { - ctx context.Context - cancel context.CancelFunc - sessPool *sess.Pool - jobID int64 - indexIDs []int64 - localStoreDir string - pdCli pd.Client - logger *zap.Logger - physicalID int64 - - // Derived and unchanged after the initialization. - instanceAddr string - localDataIsValid bool - - // Live in memory. - mu sync.Mutex - checkpoints map[int]*taskCheckpoint // task ID -> checkpoint - // we require each task ID to be continuous and start from 1. - minTaskIDFinished int - dirty bool - - // Persisted to the storage. - flushedKeyLowWatermark kv.Key - importedKeyLowWatermark kv.Key - flushedKeyCnt int - importedKeyCnt int - - ts uint64 - - // For persisting the checkpoint periodically. - updaterWg sync.WaitGroup - updaterCh chan chan struct{} -} - -// taskCheckpoint is the checkpoint for a single task. -type taskCheckpoint struct { - totalKeys int - writtenKeys int - checksum int64 - endKey kv.Key - lastBatchRead bool -} - -// FlushController is an interface to control the flush of data so after it -// returns caller can save checkpoint. -type FlushController interface { - // Flush checks if al engines need to be flushed and imported based on given - // FlushMode. It's concurrent safe. - Flush(mode FlushMode) (flushed, imported bool, err error) -} - -// NewCheckpointManager creates a new checkpoint manager. -func NewCheckpointManager( - ctx context.Context, - sessPool *sess.Pool, - physicalID int64, - jobID int64, - indexIDs []int64, - localStoreDir string, - pdCli pd.Client, -) (*CheckpointManager, error) { - instanceAddr := InstanceAddr() - ctx2, cancel := context.WithCancel(ctx) - logger := logutil.DDLIngestLogger().With( - zap.Int64("jobID", jobID), zap.Int64s("indexIDs", indexIDs)) - - cm := &CheckpointManager{ - ctx: ctx2, - cancel: cancel, - sessPool: sessPool, - jobID: jobID, - indexIDs: indexIDs, - localStoreDir: localStoreDir, - pdCli: pdCli, - logger: logger, - checkpoints: make(map[int]*taskCheckpoint, 16), - mu: sync.Mutex{}, - instanceAddr: instanceAddr, - physicalID: physicalID, - updaterWg: sync.WaitGroup{}, - updaterCh: make(chan chan struct{}), - } - err := cm.resumeOrInitCheckpoint() - if err != nil { - return nil, err - } - cm.updaterWg.Add(1) - go func() { - cm.updateCheckpointLoop() - cm.updaterWg.Done() - }() - logger.Info("create checkpoint manager") - return cm, nil -} - -// InstanceAddr returns the string concat with instance address and temp-dir. -func InstanceAddr() string { - cfg := config.GetGlobalConfig() - dsn := net.JoinHostPort(cfg.AdvertiseAddress, strconv.Itoa(int(cfg.Port))) - return fmt.Sprintf("%s:%s", dsn, cfg.TempDir) -} - -// IsKeyProcessed checks if the key is processed. The key may not be imported. 
-// This is called before the reader reads the data and decides whether to skip -// the current task. -func (s *CheckpointManager) IsKeyProcessed(end kv.Key) bool { - s.mu.Lock() - defer s.mu.Unlock() - if len(s.importedKeyLowWatermark) > 0 && end.Cmp(s.importedKeyLowWatermark) <= 0 { - return true - } - return s.localDataIsValid && len(s.flushedKeyLowWatermark) > 0 && end.Cmp(s.flushedKeyLowWatermark) <= 0 -} - -// LastProcessedKey finds the last processed key in checkpoint. -// If there is no processed key, it returns nil. -func (s *CheckpointManager) LastProcessedKey() kv.Key { - s.mu.Lock() - defer s.mu.Unlock() - - if s.localDataIsValid && len(s.flushedKeyLowWatermark) > 0 { - return s.flushedKeyLowWatermark.Clone() - } - if len(s.importedKeyLowWatermark) > 0 { - return s.importedKeyLowWatermark.Clone() - } - return nil -} - -// Status returns the status of the checkpoint. -func (s *CheckpointManager) Status() (keyCnt int, minKeyImported kv.Key) { - s.mu.Lock() - defer s.mu.Unlock() - total := 0 - for _, cp := range s.checkpoints { - total += cp.writtenKeys - } - // TODO(lance6716): ??? - return s.flushedKeyCnt + total, s.importedKeyLowWatermark -} - -// Register registers a new task. taskID MUST be continuous ascending and start -// from 1. -// -// TODO(lance6716): remove this constraint, use endKey as taskID and use -// ordered map type for checkpoints. -func (s *CheckpointManager) Register(taskID int, end kv.Key) { - s.mu.Lock() - defer s.mu.Unlock() - s.checkpoints[taskID] = &taskCheckpoint{ - endKey: end, - } -} - -// UpdateTotalKeys updates the total keys of the task. -// This is called by the reader after reading the data to update the number of rows contained in the current chunk. -func (s *CheckpointManager) UpdateTotalKeys(taskID int, delta int, last bool) { - s.mu.Lock() - defer s.mu.Unlock() - cp := s.checkpoints[taskID] - cp.totalKeys += delta - cp.lastBatchRead = last -} - -// UpdateWrittenKeys updates the written keys of the task. -// This is called by the writer after writing the local engine to update the current number of rows written. -func (s *CheckpointManager) UpdateWrittenKeys(taskID int, delta int) { - s.mu.Lock() - cp := s.checkpoints[taskID] - cp.writtenKeys += delta - s.mu.Unlock() -} - -// AdvanceWatermark advances the watermark according to flushed or imported status. -func (s *CheckpointManager) AdvanceWatermark(flushed, imported bool) { - if !flushed { - return - } - - failpoint.Inject("resignAfterFlush", func() { - // used in a manual test - ResignOwnerForTest.Store(true) - // wait until ResignOwnerForTest is processed - for ResignOwnerForTest.Load() { - time.Sleep(100 * time.Millisecond) - } - }) - - s.mu.Lock() - defer s.mu.Unlock() - s.afterFlush() - - if imported { - s.afterImport() - } -} - -// afterFlush should be called after all engine is flushed. 
-func (s *CheckpointManager) afterFlush() { - for { - cp := s.checkpoints[s.minTaskIDFinished+1] - if cp == nil || !cp.lastBatchRead || cp.writtenKeys < cp.totalKeys { - break - } - s.minTaskIDFinished++ - s.flushedKeyLowWatermark = cp.endKey - s.flushedKeyCnt += cp.totalKeys - delete(s.checkpoints, s.minTaskIDFinished) - s.dirty = true - } -} - -func (s *CheckpointManager) afterImport() { - if s.importedKeyLowWatermark.Cmp(s.flushedKeyLowWatermark) > 0 { - s.logger.Warn("lower watermark of flushed key is less than imported key", - zap.String("flushed", hex.EncodeToString(s.flushedKeyLowWatermark)), - zap.String("imported", hex.EncodeToString(s.importedKeyLowWatermark)), - ) - return - } - s.importedKeyLowWatermark = s.flushedKeyLowWatermark - s.importedKeyCnt = s.flushedKeyCnt - s.dirty = true -} - -// Close closes the checkpoint manager. -func (s *CheckpointManager) Close() { - err := s.updateCheckpoint() - if err != nil { - s.logger.Error("update checkpoint failed", zap.Error(err)) - } - - s.cancel() - s.updaterWg.Wait() - s.logger.Info("checkpoint manager closed") -} - -// GetTS returns the TS saved in checkpoint. -func (s *CheckpointManager) GetTS() uint64 { - s.mu.Lock() - defer s.mu.Unlock() - return s.ts -} - -// JobReorgMeta is the metadata for a reorg job. -type JobReorgMeta struct { - Checkpoint *ReorgCheckpoint `json:"reorg_checkpoint"` -} - -// ReorgCheckpoint is the checkpoint for a reorg job. -type ReorgCheckpoint struct { - LocalSyncKey kv.Key `json:"local_sync_key"` - LocalKeyCount int `json:"local_key_count"` - GlobalSyncKey kv.Key `json:"global_sync_key"` - GlobalKeyCount int `json:"global_key_count"` - InstanceAddr string `json:"instance_addr"` - - PhysicalID int64 `json:"physical_id"` - // TS of next engine ingest. - TS uint64 `json:"ts"` - - Version int64 `json:"version"` -} - -// JobCheckpointVersionCurrent is the current version of the checkpoint. 
-const ( - JobCheckpointVersionCurrent = JobCheckpointVersion1 - JobCheckpointVersion1 = 1 -) - -func (s *CheckpointManager) resumeOrInitCheckpoint() error { - sessCtx, err := s.sessPool.Get() - if err != nil { - return errors.Trace(err) - } - defer s.sessPool.Put(sessCtx) - ddlSess := sess.NewSession(sessCtx) - err = ddlSess.RunInTxn(func(se *sess.Session) error { - template := "select reorg_meta from mysql.tidb_ddl_reorg where job_id = %d and ele_type = %s;" - sql := fmt.Sprintf(template, s.jobID, util.WrapKey2String(meta.IndexElementKey)) - ctx := kv.WithInternalSourceType(s.ctx, kv.InternalTxnBackfillDDLPrefix+"add_index") - rows, err := se.Execute(ctx, sql, "get_checkpoint") - if err != nil { - return errors.Trace(err) - } - - if len(rows) == 0 || rows[0].IsNull(0) { - return nil - } - rawReorgMeta := rows[0].GetBytes(0) - var reorgMeta JobReorgMeta - err = json.Unmarshal(rawReorgMeta, &reorgMeta) - if err != nil { - return errors.Trace(err) - } - if cp := reorgMeta.Checkpoint; cp != nil { - if cp.PhysicalID != s.physicalID { - s.logger.Info("checkpoint physical table ID mismatch", - zap.Int64("current", s.physicalID), - zap.Int64("get", cp.PhysicalID)) - return nil - } - s.importedKeyLowWatermark = cp.GlobalSyncKey - s.importedKeyCnt = cp.GlobalKeyCount - s.ts = cp.TS - folderNotEmpty := util.FolderNotEmpty(s.localStoreDir) - if folderNotEmpty && - (s.instanceAddr == cp.InstanceAddr || cp.InstanceAddr == "" /* initial state */) { - s.localDataIsValid = true - s.flushedKeyLowWatermark = cp.LocalSyncKey - s.flushedKeyCnt = cp.LocalKeyCount - } - s.logger.Info("resume checkpoint", - zap.String("flushed key low watermark", hex.EncodeToString(s.flushedKeyLowWatermark)), - zap.String("imported key low watermark", hex.EncodeToString(s.importedKeyLowWatermark)), - zap.Int64("physical table ID", cp.PhysicalID), - zap.String("previous instance", cp.InstanceAddr), - zap.String("current instance", s.instanceAddr), - zap.Bool("folder is empty", !folderNotEmpty)) - return nil - } - s.logger.Info("checkpoint not found") - return nil - }) - if err != nil { - return errors.Trace(err) - } - - if s.ts > 0 { - return nil - } - // if TS is not set, we need to allocate a TS and save it to the storage before - // continue. - p, l, err := s.pdCli.GetTS(s.ctx) - if err != nil { - return errors.Trace(err) - } - ts := oracle.ComposeTS(p, l) - s.ts = ts - return s.updateCheckpointImpl() -} - -// updateCheckpointImpl is only used by updateCheckpointLoop goroutine or in -// NewCheckpointManager. In other cases, use updateCheckpoint instead. 
-func (s *CheckpointManager) updateCheckpointImpl() error { - s.mu.Lock() - flushedKeyLowWatermark := s.flushedKeyLowWatermark - importedKeyLowWatermark := s.importedKeyLowWatermark - flushedKeyCnt := s.flushedKeyCnt - importedKeyCnt := s.importedKeyCnt - physicalID := s.physicalID - ts := s.ts - s.mu.Unlock() - - sessCtx, err := s.sessPool.Get() - if err != nil { - return errors.Trace(err) - } - defer s.sessPool.Put(sessCtx) - ddlSess := sess.NewSession(sessCtx) - err = ddlSess.RunInTxn(func(se *sess.Session) error { - template := "update mysql.tidb_ddl_reorg set reorg_meta = %s where job_id = %d and ele_type = %s;" - cp := &ReorgCheckpoint{ - LocalSyncKey: flushedKeyLowWatermark, - GlobalSyncKey: importedKeyLowWatermark, - LocalKeyCount: flushedKeyCnt, - GlobalKeyCount: importedKeyCnt, - InstanceAddr: s.instanceAddr, - PhysicalID: physicalID, - TS: ts, - Version: JobCheckpointVersionCurrent, - } - rawReorgMeta, err := json.Marshal(JobReorgMeta{Checkpoint: cp}) - if err != nil { - return errors.Trace(err) - } - sql := fmt.Sprintf(template, util.WrapKey2String(rawReorgMeta), s.jobID, util.WrapKey2String(meta.IndexElementKey)) - ctx := kv.WithInternalSourceType(s.ctx, kv.InternalTxnBackfillDDLPrefix+"add_index") - _, err = se.Execute(ctx, sql, "update_checkpoint") - if err != nil { - return errors.Trace(err) - } - s.mu.Lock() - s.dirty = false - s.mu.Unlock() - return nil - }) - - logFunc := s.logger.Info - if err != nil { - logFunc = s.logger.With(zap.Error(err)).Error - } - logFunc("update checkpoint", - zap.String("local checkpoint", hex.EncodeToString(flushedKeyLowWatermark)), - zap.String("global checkpoint", hex.EncodeToString(importedKeyLowWatermark)), - zap.Int("flushed keys", flushedKeyCnt), - zap.Int("imported keys", importedKeyCnt), - zap.Int64("global physical ID", physicalID), - zap.Uint64("ts", ts)) - return err -} - -func (s *CheckpointManager) updateCheckpointLoop() { - failpoint.Inject("checkpointLoopExit", func() { - // used in a manual test - failpoint.Return() - }) - ticker := time.NewTicker(10 * time.Second) - defer ticker.Stop() - for { - select { - case finishCh := <-s.updaterCh: - err := s.updateCheckpointImpl() - if err != nil { - s.logger.Error("update checkpoint failed", zap.Error(err)) - } - close(finishCh) - case <-ticker.C: - s.mu.Lock() - if !s.dirty { - s.mu.Unlock() - continue - } - s.mu.Unlock() - err := s.updateCheckpointImpl() - if err != nil { - s.logger.Error("periodically update checkpoint failed", zap.Error(err)) - } - case <-s.ctx.Done(): - return - } - } -} - -func (s *CheckpointManager) updateCheckpoint() error { - failpoint.Inject("checkpointLoopExit", func() { - // used in a manual test - failpoint.Return(errors.New("failpoint triggered so can't update checkpoint")) - }) - finishCh := make(chan struct{}) - select { - case s.updaterCh <- finishCh: - case <-s.ctx.Done(): - return s.ctx.Err() - } - // wait updateCheckpointLoop to finish checkpoint update. 
- select { - case <-finishCh: - case <-s.ctx.Done(): - return s.ctx.Err() - } - return nil -} - -func (s *CheckpointManager) refreshTSAndUpdateCP() (uint64, error) { - p, l, err := s.pdCli.GetTS(s.ctx) - if err != nil { - return 0, errors.Trace(err) - } - newTS := oracle.ComposeTS(p, l) - s.mu.Lock() - s.ts = newTS - s.mu.Unlock() - return newTS, s.updateCheckpoint() -} diff --git a/pkg/ddl/ingest/disk_root.go b/pkg/ddl/ingest/disk_root.go index adfacda863dab..90e3fa4f62922 100644 --- a/pkg/ddl/ingest/disk_root.go +++ b/pkg/ddl/ingest/disk_root.go @@ -116,9 +116,9 @@ func (d *diskRootImpl) usageInfo() string { // PreCheckUsage implements DiskRoot interface. func (d *diskRootImpl) PreCheckUsage() error { - if _, _err_ := failpoint.Eval(_curpkg_("mockIngestCheckEnvFailed")); _err_ == nil { - return dbterror.ErrIngestCheckEnvFailed.FastGenByArgs("mock error") - } + failpoint.Inject("mockIngestCheckEnvFailed", func(_ failpoint.Value) { + failpoint.Return(dbterror.ErrIngestCheckEnvFailed.FastGenByArgs("mock error")) + }) err := os.MkdirAll(d.path, 0700) if err != nil { return dbterror.ErrIngestCheckEnvFailed.FastGenByArgs(err.Error()) diff --git a/pkg/ddl/ingest/disk_root.go__failpoint_stash__ b/pkg/ddl/ingest/disk_root.go__failpoint_stash__ deleted file mode 100644 index 90e3fa4f62922..0000000000000 --- a/pkg/ddl/ingest/disk_root.go__failpoint_stash__ +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ingest - -import ( - "fmt" - "os" - "sync" - "sync/atomic" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/ddl/logutil" - lcom "github.com/pingcap/tidb/pkg/lightning/common" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/util/dbterror" - "go.uber.org/zap" -) - -// DiskRoot is used to track the disk usage for the lightning backfill process. -type DiskRoot interface { - UpdateUsage() - ShouldImport() bool - UsageInfo() string - PreCheckUsage() error - StartupCheck() error -} - -const capacityThreshold = 0.9 - -// diskRootImpl implements DiskRoot interface. -type diskRootImpl struct { - path string - capacity uint64 - used uint64 - bcUsed uint64 - bcCtx *litBackendCtxMgr - mu sync.RWMutex - updating atomic.Bool -} - -// NewDiskRootImpl creates a new DiskRoot. -func NewDiskRootImpl(path string, bcCtx *litBackendCtxMgr) DiskRoot { - return &diskRootImpl{ - path: path, - bcCtx: bcCtx, - } -} - -// UpdateUsage implements DiskRoot interface. 
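For readers unfamiliar with the transformation in the hunks above: the patch restores failpoint marker form. `failpoint-ctl enable` rewrites `failpoint.Inject(...)` markers into the expanded `failpoint.Eval(_curpkg_(...))` form and appears to keep the original source in a `*__failpoint_stash__` copy, which is why this change both edits the working files and deletes the stash files. Below is a rough sketch of the marker pattern, assuming the github.com/pingcap/failpoint marker API used in these hunks; the failpoint name and function are made up for illustration.

package main

import (
	"errors"
	"fmt"

	"github.com/pingcap/failpoint"
)

// preCheck sketches the marker style restored by this patch (compare the
// PreCheckUsage hunk above). Inject and Return are marker functions: they only
// short-circuit after `failpoint-ctl enable` has rewritten the source into the
// failpoint.Eval(_curpkg_(...)) form seen on the left side of these diffs.
func preCheck() error {
	failpoint.Inject("mockCheckFailed", func(_ failpoint.Value) {
		failpoint.Return(errors.New("mock error"))
	})
	// ... the real environment checks would go here ...
	return nil
}

func main() {
	// In tests the failpoint is activated by its full package path, e.g.
	//   failpoint.Enable("<package path>/mockCheckFailed", "return(true)")
	// and the value of the return(...) term is delivered as failpoint.Value.
	fmt.Println(preCheck())
}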
-func (d *diskRootImpl) UpdateUsage() { - if !d.updating.CompareAndSwap(false, true) { - return - } - bcUsed := d.bcCtx.TotalDiskUsage() - var capacity, used uint64 - sz, err := lcom.GetStorageSize(d.path) - if err != nil { - logutil.DDLIngestLogger().Error(LitErrGetStorageQuota, zap.Error(err)) - } else { - capacity, used = sz.Capacity, sz.Capacity-sz.Available - } - d.updating.Store(false) - d.mu.Lock() - d.bcUsed = bcUsed - d.capacity = capacity - d.used = used - d.mu.Unlock() -} - -// ShouldImport implements DiskRoot interface. -func (d *diskRootImpl) ShouldImport() bool { - d.mu.RLock() - defer d.mu.RUnlock() - if d.bcUsed > variable.DDLDiskQuota.Load() { - logutil.DDLIngestLogger().Info("disk usage is over quota", - zap.Uint64("quota", variable.DDLDiskQuota.Load()), - zap.String("usage", d.usageInfo())) - return true - } - if d.used == 0 && d.capacity == 0 { - return false - } - if float64(d.used) >= float64(d.capacity)*capacityThreshold { - logutil.DDLIngestLogger().Warn("available disk space is less than 10%, "+ - "this may degrade the performance, "+ - "please make sure the disk available space is larger than @@tidb_ddl_disk_quota before adding index", - zap.String("usage", d.usageInfo())) - return true - } - return false -} - -// UsageInfo implements DiskRoot interface. -func (d *diskRootImpl) UsageInfo() string { - d.mu.RLock() - defer d.mu.RUnlock() - return d.usageInfo() -} - -func (d *diskRootImpl) usageInfo() string { - return fmt.Sprintf("disk usage: %d/%d, backend usage: %d", d.used, d.capacity, d.bcUsed) -} - -// PreCheckUsage implements DiskRoot interface. -func (d *diskRootImpl) PreCheckUsage() error { - failpoint.Inject("mockIngestCheckEnvFailed", func(_ failpoint.Value) { - failpoint.Return(dbterror.ErrIngestCheckEnvFailed.FastGenByArgs("mock error")) - }) - err := os.MkdirAll(d.path, 0700) - if err != nil { - return dbterror.ErrIngestCheckEnvFailed.FastGenByArgs(err.Error()) - } - sz, err := lcom.GetStorageSize(d.path) - if err != nil { - return dbterror.ErrIngestCheckEnvFailed.FastGenByArgs(err.Error()) - } - if RiskOfDiskFull(sz.Available, sz.Capacity) { - logutil.DDLIngestLogger().Warn("available disk space is less than 10%, cannot use ingest mode", - zap.String("sort path", d.path), - zap.String("usage", d.usageInfo())) - msg := fmt.Sprintf("no enough space in %s", d.path) - return dbterror.ErrIngestCheckEnvFailed.FastGenByArgs(msg) - } - return nil -} - -// StartupCheck implements DiskRoot interface. -func (d *diskRootImpl) StartupCheck() error { - sz, err := lcom.GetStorageSize(d.path) - if err != nil { - return errors.Trace(err) - } - quota := variable.DDLDiskQuota.Load() - if sz.Available < quota { - return errors.Errorf("the available disk space(%d) in %s should be greater than @@tidb_ddl_disk_quota(%d)", - sz.Available, d.path, quota) - } - return nil -} - -// RiskOfDiskFull checks if the disk has less than 10% space. 
-func RiskOfDiskFull(available, capacity uint64) bool { - return float64(available) < (1-capacityThreshold)*float64(capacity) -} diff --git a/pkg/ddl/ingest/env.go b/pkg/ddl/ingest/env.go index 6aaff89dea58d..fda7e76720319 100644 --- a/pkg/ddl/ingest/env.go +++ b/pkg/ddl/ingest/env.go @@ -70,11 +70,11 @@ func InitGlobalLightningEnv(path string) (ok bool) { } else { memTotal = memTotal / 2 } - if val, _err_ := failpoint.Eval(_curpkg_("setMemTotalInMB")); _err_ == nil { + failpoint.Inject("setMemTotalInMB", func(val failpoint.Value) { //nolint: forcetypeassert i := val.(int) memTotal = uint64(i) * size.MB - } + }) LitBackCtxMgr = NewLitBackendCtxMgr(path, memTotal) litRLimit = util.GenRLimit("ddl-ingest") LitInitialized = true diff --git a/pkg/ddl/ingest/env.go__failpoint_stash__ b/pkg/ddl/ingest/env.go__failpoint_stash__ deleted file mode 100644 index fda7e76720319..0000000000000 --- a/pkg/ddl/ingest/env.go__failpoint_stash__ +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ingest - -import ( - "context" - "fmt" - "os" - "path/filepath" - "slices" - "strconv" - "strings" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/ddl/logutil" - sess "github.com/pingcap/tidb/pkg/ddl/session" - "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/memory" - "github.com/pingcap/tidb/pkg/util/size" - "go.uber.org/zap" - "golang.org/x/exp/maps" -) - -var ( - // LitBackCtxMgr is the entry for the lightning backfill process. - LitBackCtxMgr BackendCtxMgr - // LitMemRoot is used to track the memory usage of the lightning backfill process. - LitMemRoot MemRoot - // litDiskRoot is used to track the disk usage of the lightning backfill process. - litDiskRoot DiskRoot - // litRLimit is the max open file number of the lightning backfill process. - litRLimit uint64 - // LitInitialized is the flag indicates whether the lightning backfill process is initialized. - LitInitialized bool -) - -const defaultMemoryQuota = 2 * size.GB - -// InitGlobalLightningEnv initialize Lightning backfill environment. 
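As a quick illustration of the 10% rule implemented by capacityThreshold and RiskOfDiskFull in the disk_root.go stash above, here is a self-contained sketch; the constant and formula are copied from that code, while the sample sizes are made up.

package main

import "fmt"

const capacityThreshold = 0.9 // same constant as in disk_root.go above

// riskOfDiskFull mirrors RiskOfDiskFull: true when less than 10% of the disk
// is still available.
func riskOfDiskFull(available, capacity uint64) bool {
	return float64(available) < (1-capacityThreshold)*float64(capacity)
}

func main() {
	const gb = uint64(1) << 30
	// Exactly 10% free is not flagged; below that, the pre-check above rejects ingest.
	fmt.Println(riskOfDiskFull(50*gb, 500*gb)) // false
	fmt.Println(riskOfDiskFull(49*gb, 500*gb)) // true
}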
-func InitGlobalLightningEnv(path string) (ok bool) { - log.SetAppLogger(logutil.DDLIngestLogger()) - globalCfg := config.GetGlobalConfig() - if globalCfg.Store != "tikv" { - logutil.DDLIngestLogger().Warn(LitWarnEnvInitFail, - zap.String("storage limitation", "only support TiKV storage"), - zap.String("current storage", globalCfg.Store), - zap.Bool("lightning is initialized", LitInitialized)) - return false - } - memTotal, err := memory.MemTotal() - if err != nil { - logutil.DDLIngestLogger().Warn("get total memory fail", zap.Error(err)) - memTotal = defaultMemoryQuota - } else { - memTotal = memTotal / 2 - } - failpoint.Inject("setMemTotalInMB", func(val failpoint.Value) { - //nolint: forcetypeassert - i := val.(int) - memTotal = uint64(i) * size.MB - }) - LitBackCtxMgr = NewLitBackendCtxMgr(path, memTotal) - litRLimit = util.GenRLimit("ddl-ingest") - LitInitialized = true - logutil.DDLIngestLogger().Info(LitInfoEnvInitSucc, - zap.Uint64("memory limitation", memTotal), - zap.String("disk usage info", litDiskRoot.UsageInfo()), - zap.Uint64("max open file number", litRLimit), - zap.Bool("lightning is initialized", LitInitialized)) - return true -} - -// GenIngestTempDataDir generates a path for DDL ingest. -// Format: ${temp-dir}/tmp_ddl-{port} -func GenIngestTempDataDir() (string, error) { - tidbCfg := config.GetGlobalConfig() - sortPathSuffix := "/tmp_ddl-" + strconv.Itoa(int(tidbCfg.Port)) - sortPath := filepath.Join(tidbCfg.TempDir, sortPathSuffix) - - if _, err := os.Stat(sortPath); err != nil { - if !os.IsNotExist(err) { - logutil.DDLIngestLogger().Error(LitErrStatDirFail, - zap.String("sort path", sortPath), zap.Error(err)) - return "", err - } - } - err := os.MkdirAll(sortPath, 0o700) - if err != nil { - logutil.DDLIngestLogger().Error(LitErrCreateDirFail, - zap.String("sort path", sortPath), zap.Error(err)) - return "", err - } - logutil.DDLIngestLogger().Info(LitInfoSortDir, zap.String("data path", sortPath)) - return sortPath, nil -} - -// CleanUpTempDir is used to remove the stale index data. -// This function gets running DDL jobs from `mysql.tidb_ddl_job` and -// it only removes the folders that related to finished jobs. 
-func CleanUpTempDir(ctx context.Context, se sessionctx.Context, path string) { - entries, err := os.ReadDir(path) - if err != nil { - if strings.Contains(err.Error(), "no such file") { - return - } - logutil.DDLIngestLogger().Warn(LitErrCleanSortPath, zap.Error(err)) - return - } - toCheckJobIDs := make(map[int64]struct{}, len(entries)) - for _, entry := range entries { - if !entry.IsDir() { - continue - } - jobID, err := decodeBackendTag(entry.Name()) - if err != nil { - logutil.DDLIngestLogger().Error(LitErrCleanSortPath, zap.Error(err)) - continue - } - toCheckJobIDs[jobID] = struct{}{} - } - - if len(toCheckJobIDs) == 0 { - return - } - - idSlice := maps.Keys(toCheckJobIDs) - slices.Sort(idSlice) - processing, err := filterProcessingJobIDs(ctx, sess.NewSession(se), idSlice) - if err != nil { - logutil.DDLIngestLogger().Error(LitErrCleanSortPath, zap.Error(err)) - return - } - - for _, id := range processing { - delete(toCheckJobIDs, id) - } - - if len(toCheckJobIDs) == 0 { - return - } - - for id := range toCheckJobIDs { - logutil.DDLIngestLogger().Info("remove stale temp index data", - zap.Int64("jobID", id)) - p := filepath.Join(path, encodeBackendTag(id)) - err = os.RemoveAll(p) - if err != nil { - logutil.DDLIngestLogger().Error(LitErrCleanSortPath, zap.Error(err)) - } - } -} - -func filterProcessingJobIDs(ctx context.Context, se *sess.Session, jobIDs []int64) ([]int64, error) { - var sb strings.Builder - for i, id := range jobIDs { - if i != 0 { - sb.WriteString(",") - } - sb.WriteString(strconv.FormatInt(id, 10)) - } - sql := fmt.Sprintf( - "SELECT job_id FROM mysql.tidb_ddl_job WHERE job_id IN (%s)", - sb.String()) - rows, err := se.Execute(ctx, sql, "filter_processing_job_ids") - if err != nil { - return nil, errors.Trace(err) - } - ret := make([]int64, 0, len(rows)) - for _, row := range rows { - ret = append(ret, row.GetInt64(0)) - } - return ret, nil -} diff --git a/pkg/ddl/ingest/mock.go b/pkg/ddl/ingest/mock.go index f7361d490c833..4d9261ecfc672 100644 --- a/pkg/ddl/ingest/mock.go +++ b/pkg/ddl/ingest/mock.go @@ -206,7 +206,7 @@ func (m *MockWriter) WriteRow(_ context.Context, key, idxVal []byte, _ kv.Handle zap.String("key", hex.EncodeToString(key)), zap.String("idxVal", hex.EncodeToString(idxVal))) - failpoint.Call(_curpkg_("onMockWriterWriteRow")) + failpoint.InjectCall("onMockWriterWriteRow") m.mu.Lock() defer m.mu.Unlock() if m.onWrite != nil { diff --git a/pkg/ddl/ingest/mock.go__failpoint_stash__ b/pkg/ddl/ingest/mock.go__failpoint_stash__ deleted file mode 100644 index 4d9261ecfc672..0000000000000 --- a/pkg/ddl/ingest/mock.go__failpoint_stash__ +++ /dev/null @@ -1,236 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package ingest - -import ( - "context" - "encoding/hex" - "os" - "path/filepath" - "strconv" - "sync" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/ddl/logutil" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/lightning/backend" - "github.com/pingcap/tidb/pkg/lightning/backend/local" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/table" - pd "github.com/tikv/pd/client" - clientv3 "go.etcd.io/etcd/client/v3" - "go.uber.org/zap" -) - -// MockBackendCtxMgr is a mock backend context manager. -type MockBackendCtxMgr struct { - sessCtxProvider func() sessionctx.Context - runningJobs map[int64]*MockBackendCtx -} - -var _ BackendCtxMgr = (*MockBackendCtxMgr)(nil) - -// NewMockBackendCtxMgr creates a new mock backend context manager. -func NewMockBackendCtxMgr(sessCtxProvider func() sessionctx.Context) *MockBackendCtxMgr { - return &MockBackendCtxMgr{ - sessCtxProvider: sessCtxProvider, - runningJobs: make(map[int64]*MockBackendCtx), - } -} - -// CheckMoreTasksAvailable implements BackendCtxMgr.CheckMoreTaskAvailable interface. -func (m *MockBackendCtxMgr) CheckMoreTasksAvailable() (bool, error) { - return len(m.runningJobs) == 0, nil -} - -// Register implements BackendCtxMgr.Register interface. -func (m *MockBackendCtxMgr) Register(ctx context.Context, jobID int64, unique bool, etcdClient *clientv3.Client, pdSvcDiscovery pd.ServiceDiscovery, resourceGroupName string) (BackendCtx, error) { - logutil.DDLIngestLogger().Info("mock backend mgr register", zap.Int64("jobID", jobID)) - if mockCtx, ok := m.runningJobs[jobID]; ok { - return mockCtx, nil - } - sessCtx := m.sessCtxProvider() - mockCtx := &MockBackendCtx{ - mu: sync.Mutex{}, - sessCtx: sessCtx, - jobID: jobID, - } - m.runningJobs[jobID] = mockCtx - return mockCtx, nil -} - -// Unregister implements BackendCtxMgr.Unregister interface. -func (m *MockBackendCtxMgr) Unregister(jobID int64) { - if mCtx, ok := m.runningJobs[jobID]; ok { - mCtx.sessCtx.StmtCommit(context.Background()) - err := mCtx.sessCtx.CommitTxn(context.Background()) - logutil.DDLIngestLogger().Info("mock backend mgr unregister", zap.Int64("jobID", jobID), zap.Error(err)) - delete(m.runningJobs, jobID) - } -} - -// EncodeJobSortPath implements BackendCtxMgr interface. -func (m *MockBackendCtxMgr) EncodeJobSortPath(int64) string { - return "" -} - -// Load implements BackendCtxMgr.Load interface. -func (m *MockBackendCtxMgr) Load(jobID int64) (BackendCtx, bool) { - logutil.DDLIngestLogger().Info("mock backend mgr load", zap.Int64("jobID", jobID)) - if mockCtx, ok := m.runningJobs[jobID]; ok { - return mockCtx, true - } - return nil, false -} - -// ResetSessCtx is only used for mocking test. -func (m *MockBackendCtxMgr) ResetSessCtx() { - for _, mockCtx := range m.runningJobs { - mockCtx.sessCtx = m.sessCtxProvider() - } -} - -// MockBackendCtx is a mock backend context. -type MockBackendCtx struct { - sessCtx sessionctx.Context - mu sync.Mutex - jobID int64 - checkpointMgr *CheckpointManager -} - -// Register implements BackendCtx.Register interface. -func (m *MockBackendCtx) Register(indexIDs []int64, _ []bool, _ table.Table) ([]Engine, error) { - logutil.DDLIngestLogger().Info("mock backend ctx register", zap.Int64("jobID", m.jobID), zap.Int64s("indexIDs", indexIDs)) - ret := make([]Engine, 0, len(indexIDs)) - for range indexIDs { - ret = append(ret, &MockEngineInfo{sessCtx: m.sessCtx, mu: &m.mu}) - } - return ret, nil -} - -// FinishAndUnregisterEngines implements BackendCtx interface. 
-func (*MockBackendCtx) FinishAndUnregisterEngines(_ UnregisterOpt) error { - logutil.DDLIngestLogger().Info("mock backend ctx unregister") - return nil -} - -// CollectRemoteDuplicateRows implements BackendCtx.CollectRemoteDuplicateRows interface. -func (*MockBackendCtx) CollectRemoteDuplicateRows(indexID int64, _ table.Table) error { - logutil.DDLIngestLogger().Info("mock backend ctx collect remote duplicate rows", zap.Int64("indexID", indexID)) - return nil -} - -// Flush implements BackendCtx.Flush interface. -func (*MockBackendCtx) Flush(mode FlushMode) (flushed, imported bool, err error) { - return false, false, nil -} - -// AttachCheckpointManager attaches a checkpoint manager to the backend context. -func (m *MockBackendCtx) AttachCheckpointManager(mgr *CheckpointManager) { - m.checkpointMgr = mgr -} - -// GetCheckpointManager returns the checkpoint manager attached to the backend context. -func (m *MockBackendCtx) GetCheckpointManager() *CheckpointManager { - return m.checkpointMgr -} - -// GetLocalBackend returns the local backend. -func (m *MockBackendCtx) GetLocalBackend() *local.Backend { - b := &local.Backend{} - b.LocalStoreDir = filepath.Join(os.TempDir(), "mock_backend", strconv.FormatInt(m.jobID, 10)) - return b -} - -// MockWriteHook the hook for write in mock engine. -type MockWriteHook func(key, val []byte) - -// MockEngineInfo is a mock engine info. -type MockEngineInfo struct { - sessCtx sessionctx.Context - mu *sync.Mutex - - onWrite MockWriteHook -} - -// NewMockEngineInfo creates a new mock engine info. -func NewMockEngineInfo(sessCtx sessionctx.Context) *MockEngineInfo { - return &MockEngineInfo{ - sessCtx: sessCtx, - mu: &sync.Mutex{}, - } -} - -// Flush implements Engine.Flush interface. -func (*MockEngineInfo) Flush() error { - return nil -} - -// Close implements Engine.Close interface. -func (*MockEngineInfo) Close(_ bool) { -} - -// SetHook set the write hook. -func (m *MockEngineInfo) SetHook(onWrite func(key, val []byte)) { - m.onWrite = onWrite -} - -// CreateWriter implements Engine.CreateWriter interface. -func (m *MockEngineInfo) CreateWriter(id int, _ *backend.LocalWriterConfig) (Writer, error) { - logutil.DDLIngestLogger().Info("mock engine info create writer", zap.Int("id", id)) - return &MockWriter{sessCtx: m.sessCtx, mu: m.mu, onWrite: m.onWrite}, nil -} - -// MockWriter is a mock writer. -type MockWriter struct { - sessCtx sessionctx.Context - mu *sync.Mutex - onWrite MockWriteHook -} - -// WriteRow implements Writer.WriteRow interface. -func (m *MockWriter) WriteRow(_ context.Context, key, idxVal []byte, _ kv.Handle) error { - logutil.DDLIngestLogger().Info("mock writer write row", - zap.String("key", hex.EncodeToString(key)), - zap.String("idxVal", hex.EncodeToString(idxVal))) - - failpoint.InjectCall("onMockWriterWriteRow") - m.mu.Lock() - defer m.mu.Unlock() - if m.onWrite != nil { - m.onWrite(key, idxVal) - return nil - } - txn, err := m.sessCtx.Txn(true) - if err != nil { - return err - } - err = txn.Set(key, idxVal) - if err != nil { - return err - } - if MockExecAfterWriteRow != nil { - MockExecAfterWriteRow() - } - return nil -} - -// LockForWrite implements Writer.LockForWrite interface. -func (*MockWriter) LockForWrite() func() { - return func() {} -} - -// MockExecAfterWriteRow is only used for test. 
-var MockExecAfterWriteRow func() diff --git a/pkg/ddl/job_scheduler.go b/pkg/ddl/job_scheduler.go index 75d1257c70fec..1c3446b36704a 100644 --- a/pkg/ddl/job_scheduler.go +++ b/pkg/ddl/job_scheduler.go @@ -177,7 +177,7 @@ func (s *jobScheduler) close() { if s.generalDDLWorkerPool != nil { s.generalDDLWorkerPool.close() } - failpoint.Call(_curpkg_("afterSchedulerClose")) + failpoint.InjectCall("afterSchedulerClose") } func hasSysDB(job *model.Job) bool { @@ -286,7 +286,7 @@ func (s *jobScheduler) schedule() error { if err := s.schCtx.Err(); err != nil { return err } - if _, _err_ := failpoint.Eval(_curpkg_("ownerResignAfterDispatchLoopCheck")); _err_ == nil { + failpoint.Inject("ownerResignAfterDispatchLoopCheck", func() { if ingest.ResignOwnerForTest.Load() { err2 := s.ownerManager.ResignOwner(context.Background()) if err2 != nil { @@ -294,7 +294,7 @@ func (s *jobScheduler) schedule() error { } ingest.ResignOwnerForTest.Store(false) } - } + }) select { case <-s.ddlJobNotifyCh: case <-ticker.C: @@ -311,7 +311,7 @@ func (s *jobScheduler) schedule() error { if err := s.checkAndUpdateClusterState(false); err != nil { continue } - failpoint.Call(_curpkg_("beforeLoadAndDeliverJobs")) + failpoint.InjectCall("beforeLoadAndDeliverJobs") if err := s.loadAndDeliverJobs(se); err != nil { logutil.SampleLogger().Warn("load and deliver jobs failed", zap.Error(err)) } @@ -451,7 +451,7 @@ func (s *jobScheduler) mustReloadSchemas() { // the worker will run the job until it's finished, paused or another owner takes // over and finished it. func (s *jobScheduler) deliveryJob(wk *worker, pool *workerPool, job *model.Job) { - failpoint.Call(_curpkg_("beforeDeliveryJob"), job) + failpoint.InjectCall("beforeDeliveryJob", job) injectFailPointForGetJob(job) jobID, involvedSchemaInfos := job.ID, job.GetInvolvingSchemaInfo() s.runningJobs.addRunning(jobID, involvedSchemaInfos) @@ -462,7 +462,7 @@ func (s *jobScheduler) deliveryJob(wk *worker, pool *workerPool, job *model.Job) if r != nil { logutil.DDLLogger().Error("panic in deliveryJob", zap.Any("recover", r), zap.Stack("stack")) } - failpoint.Call(_curpkg_("afterDeliveryJob"), job) + failpoint.InjectCall("afterDeliveryJob", job) // Because there is a gap between `allIDs()` and `checkRunnable()`, // we append unfinished job to pending atomically to prevent `getJob()` // chosing another runnable job that involves the same schema object. @@ -483,10 +483,10 @@ func (s *jobScheduler) deliveryJob(wk *worker, pool *workerPool, job *model.Job) // or the job is finished by another owner. // TODO for JobStateRollbackDone we have to query 1 additional time when the // job is already moved to history. - failpoint.Call(_curpkg_("beforeRefreshJob"), job) + failpoint.InjectCall("beforeRefreshJob", job) for { job, err = s.sysTblMgr.GetJobByID(s.schCtx, jobID) - failpoint.Call(_curpkg_("mockGetJobByIDFail"), &err) + failpoint.InjectCall("mockGetJobByIDFail", &err) if err == nil { break } @@ -511,7 +511,7 @@ func (s *jobScheduler) deliveryJob(wk *worker, pool *workerPool, job *model.Job) // transitOneJobStepAndWaitSync runs one step of the DDL job, persist it and // waits for other TiDB node to synchronize. func (s *jobScheduler) transitOneJobStepAndWaitSync(wk *worker, job *model.Job) error { - failpoint.Call(_curpkg_("beforeRunOneJobStep")) + failpoint.InjectCall("beforeRunOneJobStep") ownerID := s.ownerManager.ID() // suppose we failed to sync version last time, we need to check and sync it // before run to maintain the 2-version invariant. 
@@ -545,16 +545,16 @@ func (s *jobScheduler) transitOneJobStepAndWaitSync(wk *worker, job *model.Job) tidblogutil.Logger(wk.logCtx).Info("handle ddl job failed", zap.Error(err), zap.Stringer("job", job)) return err } - if val, _err_ := failpoint.Eval(_curpkg_("mockDownBeforeUpdateGlobalVersion")); _err_ == nil { + failpoint.Inject("mockDownBeforeUpdateGlobalVersion", func(val failpoint.Value) { if val.(bool) { if mockDDLErrOnce == 0 { mockDDLErrOnce = schemaVer - return errors.New("mock down before update global version") + failpoint.Return(errors.New("mock down before update global version")) } } - } + }) - failpoint.Call(_curpkg_("beforeWaitSchemaChanged"), job, schemaVer) + failpoint.InjectCall("beforeWaitSchemaChanged", job, schemaVer) // Here means the job enters another state (delete only, write only, public, etc...) or is cancelled. // If the job is done or still running or rolling back, we will wait 2 * lease time or util MDL synced to guarantee other servers to update // the newest schema. @@ -564,7 +564,7 @@ func (s *jobScheduler) transitOneJobStepAndWaitSync(wk *worker, job *model.Job) s.cleanMDLInfo(job, ownerID) s.synced(job) - failpoint.Call(_curpkg_("onJobUpdated"), job) + failpoint.InjectCall("onJobUpdated", job) return nil } @@ -626,11 +626,11 @@ const ( ) func insertDDLJobs2Table(ctx context.Context, se *sess.Session, jobWs ...*JobWrapper) error { - if val, _err_ := failpoint.Eval(_curpkg_("mockAddBatchDDLJobsErr")); _err_ == nil { + failpoint.Inject("mockAddBatchDDLJobsErr", func(val failpoint.Value) { if val.(bool) { - return errors.Errorf("mockAddBatchDDLJobsErr") + failpoint.Return(errors.Errorf("mockAddBatchDDLJobsErr")) } - } + }) if len(jobWs) == 0 { return nil } diff --git a/pkg/ddl/job_scheduler.go__failpoint_stash__ b/pkg/ddl/job_scheduler.go__failpoint_stash__ deleted file mode 100644 index 1c3446b36704a..0000000000000 --- a/pkg/ddl/job_scheduler.go__failpoint_stash__ +++ /dev/null @@ -1,837 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package ddl - -import ( - "bytes" - "context" - "encoding/hex" - "encoding/json" - "fmt" - "runtime" - "slices" - "strconv" - "strings" - "sync/atomic" - "time" - - "github.com/ngaut/pools" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/tidb/pkg/ddl/ingest" - "github.com/pingcap/tidb/pkg/ddl/logutil" - sess "github.com/pingcap/tidb/pkg/ddl/session" - "github.com/pingcap/tidb/pkg/ddl/syncer" - "github.com/pingcap/tidb/pkg/ddl/systable" - "github.com/pingcap/tidb/pkg/ddl/util" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/meta/autoid" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/owner" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/table" - tidbutil "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/dbterror" - "github.com/pingcap/tidb/pkg/util/intest" - tidblogutil "github.com/pingcap/tidb/pkg/util/logutil" - clientv3 "go.etcd.io/etcd/client/v3" - "go.uber.org/zap" -) - -var ( - // addingDDLJobNotifyKey is the key in etcd to notify DDL scheduler that there - // is a new DDL job. - addingDDLJobNotifyKey = "/tidb/ddl/add_ddl_job_general" - dispatchLoopWaitingDuration = 1 * time.Second - schedulerLoopRetryInterval = time.Second -) - -func init() { - // In test the wait duration can be reduced to make test case run faster - if intest.InTest { - dispatchLoopWaitingDuration = 50 * time.Millisecond - } -} - -type jobType int - -func (t jobType) String() string { - switch t { - case jobTypeGeneral: - return "general" - case jobTypeReorg: - return "reorg" - } - return "unknown job type: " + strconv.Itoa(int(t)) -} - -const ( - jobTypeGeneral jobType = iota - jobTypeReorg -) - -type ownerListener struct { - ddl *ddl - scheduler *jobScheduler -} - -var _ owner.Listener = (*ownerListener)(nil) - -func (l *ownerListener) OnBecomeOwner() { - ctx, cancelFunc := context.WithCancel(l.ddl.ddlCtx.ctx) - sysTblMgr := systable.NewManager(l.ddl.sessPool) - l.scheduler = &jobScheduler{ - schCtx: ctx, - cancel: cancelFunc, - runningJobs: newRunningJobs(), - sysTblMgr: sysTblMgr, - schemaLoader: l.ddl.schemaLoader, - minJobIDRefresher: l.ddl.minJobIDRefresher, - - ddlCtx: l.ddl.ddlCtx, - ddlJobNotifyCh: l.ddl.ddlJobNotifyCh, - sessPool: l.ddl.sessPool, - delRangeMgr: l.ddl.delRangeMgr, - } - l.ddl.reorgCtx.setOwnerTS(time.Now().Unix()) - l.scheduler.start() -} - -func (l *ownerListener) OnRetireOwner() { - if l.scheduler == nil { - return - } - l.scheduler.close() -} - -// jobScheduler is used to schedule the DDL jobs, it's only run on the DDL owner. -type jobScheduler struct { - // *ddlCtx already have context named as "ctx", so we use "schCtx" here to avoid confusion. - schCtx context.Context - cancel context.CancelFunc - wg tidbutil.WaitGroupWrapper - runningJobs *runningJobs - sysTblMgr systable.Manager - schemaLoader SchemaLoader - minJobIDRefresher *systable.MinJobIDRefresher - - // those fields are created or initialized on start - reorgWorkerPool *workerPool - generalDDLWorkerPool *workerPool - seqAllocator atomic.Uint64 - - // those fields are shared with 'ddl' instance - // TODO ddlCtx is too large for here, we should remove dependency on it. 
- *ddlCtx - ddlJobNotifyCh chan struct{} - sessPool *sess.Pool - delRangeMgr delRangeManager -} - -func (s *jobScheduler) start() { - workerFactory := func(tp workerType) func() (pools.Resource, error) { - return func() (pools.Resource, error) { - wk := newWorker(s.schCtx, tp, s.sessPool, s.delRangeMgr, s.ddlCtx) - sessForJob, err := s.sessPool.Get() - if err != nil { - return nil, err - } - wk.seqAllocator = &s.seqAllocator - sessForJob.GetSessionVars().SetDiskFullOpt(kvrpcpb.DiskFullOpt_AllowedOnAlmostFull) - wk.sess = sess.NewSession(sessForJob) - metrics.DDLCounter.WithLabelValues(fmt.Sprintf("%s_%s", metrics.CreateDDL, wk.String())).Inc() - return wk, nil - } - } - // reorg worker count at least 1 at most 10. - reorgCnt := min(max(runtime.GOMAXPROCS(0)/4, 1), reorgWorkerCnt) - s.reorgWorkerPool = newDDLWorkerPool(pools.NewResourcePool(workerFactory(addIdxWorker), reorgCnt, reorgCnt, 0), jobTypeReorg) - s.generalDDLWorkerPool = newDDLWorkerPool(pools.NewResourcePool(workerFactory(generalWorker), generalWorkerCnt, generalWorkerCnt, 0), jobTypeGeneral) - s.wg.RunWithLog(s.scheduleLoop) - s.wg.RunWithLog(func() { - s.schemaSyncer.SyncJobSchemaVerLoop(s.schCtx) - }) -} - -func (s *jobScheduler) close() { - s.cancel() - s.wg.Wait() - if s.reorgWorkerPool != nil { - s.reorgWorkerPool.close() - } - if s.generalDDLWorkerPool != nil { - s.generalDDLWorkerPool.close() - } - failpoint.InjectCall("afterSchedulerClose") -} - -func hasSysDB(job *model.Job) bool { - for _, info := range job.GetInvolvingSchemaInfo() { - if tidbutil.IsSysDB(info.Database) { - return true - } - } - return false -} - -func (s *jobScheduler) processJobDuringUpgrade(sess *sess.Session, job *model.Job) (isRunnable bool, err error) { - if s.stateSyncer.IsUpgradingState() { - if job.IsPaused() { - return false, nil - } - // We need to turn the 'pausing' job to be 'paused' in ddl worker, - // and stop the reorganization workers - if job.IsPausing() || hasSysDB(job) { - return true, nil - } - var errs []error - // During binary upgrade, pause all running DDL jobs - errs, err = PauseJobsBySystem(sess.Session(), []int64{job.ID}) - if len(errs) > 0 && errs[0] != nil { - err = errs[0] - } - - if err != nil { - isCannotPauseDDLJobErr := dbterror.ErrCannotPauseDDLJob.Equal(err) - logutil.DDLUpgradingLogger().Warn("pause the job failed", zap.Stringer("job", job), - zap.Bool("isRunnable", isCannotPauseDDLJobErr), zap.Error(err)) - if isCannotPauseDDLJobErr { - return true, nil - } - } else { - logutil.DDLUpgradingLogger().Warn("pause the job successfully", zap.Stringer("job", job)) - } - - return false, nil - } - - if job.IsPausedBySystem() { - var errs []error - errs, err = ResumeJobsBySystem(sess.Session(), []int64{job.ID}) - if len(errs) > 0 && errs[0] != nil { - logutil.DDLUpgradingLogger().Warn("normal cluster state, resume the job failed", zap.Stringer("job", job), zap.Error(errs[0])) - return false, errs[0] - } - if err != nil { - logutil.DDLUpgradingLogger().Warn("normal cluster state, resume the job failed", zap.Stringer("job", job), zap.Error(err)) - return false, err - } - logutil.DDLUpgradingLogger().Warn("normal cluster state, resume the job successfully", zap.Stringer("job", job)) - return false, errors.Errorf("system paused job:%d need to be resumed", job.ID) - } - - if job.IsPaused() { - return false, nil - } - - return true, nil -} - -func (s *jobScheduler) scheduleLoop() { - const retryInterval = 3 * time.Second - for { - err := s.schedule() - if err == context.Canceled { - logutil.DDLLogger().Info("scheduleLoop quit 
due to context canceled") - return - } - logutil.DDLLogger().Warn("scheduleLoop failed, retrying", - zap.Error(err)) - - select { - case <-s.schCtx.Done(): - logutil.DDLLogger().Info("scheduleLoop quit due to context done") - return - case <-time.After(retryInterval): - } - } -} - -func (s *jobScheduler) schedule() error { - sessCtx, err := s.sessPool.Get() - if err != nil { - return errors.Trace(err) - } - defer s.sessPool.Put(sessCtx) - se := sess.NewSession(sessCtx) - var notifyDDLJobByEtcdCh clientv3.WatchChan - if s.etcdCli != nil { - notifyDDLJobByEtcdCh = s.etcdCli.Watch(s.schCtx, addingDDLJobNotifyKey) - } - if err := s.checkAndUpdateClusterState(true); err != nil { - return errors.Trace(err) - } - ticker := time.NewTicker(dispatchLoopWaitingDuration) - defer ticker.Stop() - // TODO move waitSchemaSyncedController out of ddlCtx. - s.clearOnceMap() - s.mustReloadSchemas() - - for { - if err := s.schCtx.Err(); err != nil { - return err - } - failpoint.Inject("ownerResignAfterDispatchLoopCheck", func() { - if ingest.ResignOwnerForTest.Load() { - err2 := s.ownerManager.ResignOwner(context.Background()) - if err2 != nil { - logutil.DDLLogger().Info("resign meet error", zap.Error(err2)) - } - ingest.ResignOwnerForTest.Store(false) - } - }) - select { - case <-s.ddlJobNotifyCh: - case <-ticker.C: - case _, ok := <-notifyDDLJobByEtcdCh: - if !ok { - logutil.DDLLogger().Warn("start worker watch channel closed", zap.String("watch key", addingDDLJobNotifyKey)) - notifyDDLJobByEtcdCh = s.etcdCli.Watch(s.schCtx, addingDDLJobNotifyKey) - time.Sleep(time.Second) - continue - } - case <-s.schCtx.Done(): - return s.schCtx.Err() - } - if err := s.checkAndUpdateClusterState(false); err != nil { - continue - } - failpoint.InjectCall("beforeLoadAndDeliverJobs") - if err := s.loadAndDeliverJobs(se); err != nil { - logutil.SampleLogger().Warn("load and deliver jobs failed", zap.Error(err)) - } - } -} - -// TODO make it run in a separate routine. -func (s *jobScheduler) checkAndUpdateClusterState(needUpdate bool) error { - select { - case _, ok := <-s.stateSyncer.WatchChan(): - if !ok { - // TODO stateSyncer should only be started when we are the owner, and use - // the context of scheduler, will refactor it later. 
- s.stateSyncer.Rewatch(s.ddlCtx.ctx) - } - default: - if !needUpdate { - return nil - } - } - - oldState := s.stateSyncer.IsUpgradingState() - stateInfo, err := s.stateSyncer.GetGlobalState(s.schCtx) - if err != nil { - logutil.DDLLogger().Warn("get global state failed", zap.Error(err)) - return errors.Trace(err) - } - logutil.DDLLogger().Info("get global state and global state change", - zap.Bool("oldState", oldState), zap.Bool("currState", s.stateSyncer.IsUpgradingState())) - - ownerOp := owner.OpNone - if stateInfo.State == syncer.StateUpgrading { - ownerOp = owner.OpSyncUpgradingState - } - err = s.ownerManager.SetOwnerOpValue(s.schCtx, ownerOp) - if err != nil { - logutil.DDLLogger().Warn("the owner sets global state to owner operator value failed", zap.Error(err)) - return errors.Trace(err) - } - logutil.DDLLogger().Info("the owner sets owner operator value", zap.Stringer("ownerOp", ownerOp)) - return nil -} - -func (s *jobScheduler) loadAndDeliverJobs(se *sess.Session) error { - if s.generalDDLWorkerPool.available() == 0 && s.reorgWorkerPool.available() == 0 { - return nil - } - - defer s.runningJobs.resetAllPending() - - const getJobSQL = `select reorg, job_meta from mysql.tidb_ddl_job where job_id >= %d %s order by job_id` - var whereClause string - if ids := s.runningJobs.allIDs(); len(ids) > 0 { - whereClause = fmt.Sprintf("and job_id not in (%s)", ids) - } - sql := fmt.Sprintf(getJobSQL, s.minJobIDRefresher.GetCurrMinJobID(), whereClause) - rows, err := se.Execute(context.Background(), sql, "load_ddl_jobs") - if err != nil { - return errors.Trace(err) - } - for _, row := range rows { - reorgJob := row.GetInt64(0) == 1 - targetPool := s.generalDDLWorkerPool - if reorgJob { - targetPool = s.reorgWorkerPool - } - jobBinary := row.GetBytes(1) - - job := model.Job{} - err = job.Decode(jobBinary) - if err != nil { - return errors.Trace(err) - } - - involving := job.GetInvolvingSchemaInfo() - if targetPool.available() == 0 { - s.runningJobs.addPending(involving) - continue - } - - isRunnable, err := s.processJobDuringUpgrade(se, &job) - if err != nil { - return errors.Trace(err) - } - if !isRunnable { - s.runningJobs.addPending(involving) - continue - } - - if !s.runningJobs.checkRunnable(job.ID, involving) { - s.runningJobs.addPending(involving) - continue - } - - wk, err := targetPool.get() - if err != nil { - return errors.Trace(err) - } - intest.Assert(wk != nil, "worker should not be nil") - if wk == nil { - // should not happen, we have checked available() before, and we are - // the only routine consumes worker. - logutil.DDLLogger().Info("no worker available now", zap.Stringer("type", targetPool.tp())) - s.runningJobs.addPending(involving) - continue - } - - s.deliveryJob(wk, targetPool, &job) - - if s.generalDDLWorkerPool.available() == 0 && s.reorgWorkerPool.available() == 0 { - break - } - } - return nil -} - -// mustReloadSchemas is used to reload schema when we become the DDL owner, in case -// the schema version is outdated before we become the owner. -// It will keep reloading schema until either success or context done. -// Domain also have a similar method 'mustReload', but its methods don't accept context. 
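The function that follows (mustReloadSchemas) and scheduleLoop earlier in this file share the same shape: retry an operation until it succeeds or the scheduler context is cancelled, sleeping between attempts. A generic, stand-alone sketch of that pattern is below; the names and intervals here are illustrative, not taken from the TiDB code.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// retryUntilDone keeps calling fn until it succeeds or ctx is cancelled,
// waiting interval between attempts.
func retryUntilDone(ctx context.Context, interval time.Duration, fn func() error) error {
	for {
		if err := fn(); err == nil {
			return nil
		} else {
			fmt.Println("attempt failed, will retry:", err)
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(interval):
		}
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond)
	defer cancel()
	attempts := 0
	err := retryUntilDone(ctx, 50*time.Millisecond, func() error {
		attempts++
		if attempts < 3 {
			return errors.New("not ready yet")
		}
		return nil
	})
	fmt.Println("attempts:", attempts, "err:", err)
}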
-func (s *jobScheduler) mustReloadSchemas() { - for { - err := s.schemaLoader.Reload() - if err == nil { - return - } - logutil.DDLLogger().Warn("reload schema failed, will retry later", zap.Error(err)) - select { - case <-s.schCtx.Done(): - return - case <-time.After(schedulerLoopRetryInterval): - } - } -} - -// deliveryJob deliver the job to the worker to run it asynchronously. -// the worker will run the job until it's finished, paused or another owner takes -// over and finished it. -func (s *jobScheduler) deliveryJob(wk *worker, pool *workerPool, job *model.Job) { - failpoint.InjectCall("beforeDeliveryJob", job) - injectFailPointForGetJob(job) - jobID, involvedSchemaInfos := job.ID, job.GetInvolvingSchemaInfo() - s.runningJobs.addRunning(jobID, involvedSchemaInfos) - metrics.DDLRunningJobCount.WithLabelValues(pool.tp().String()).Inc() - s.wg.Run(func() { - defer func() { - r := recover() - if r != nil { - logutil.DDLLogger().Error("panic in deliveryJob", zap.Any("recover", r), zap.Stack("stack")) - } - failpoint.InjectCall("afterDeliveryJob", job) - // Because there is a gap between `allIDs()` and `checkRunnable()`, - // we append unfinished job to pending atomically to prevent `getJob()` - // chosing another runnable job that involves the same schema object. - moveRunningJobsToPending := r != nil || (job != nil && !job.IsFinished()) - s.runningJobs.finishOrPendJob(jobID, involvedSchemaInfos, moveRunningJobsToPending) - asyncNotify(s.ddlJobNotifyCh) - metrics.DDLRunningJobCount.WithLabelValues(pool.tp().String()).Dec() - pool.put(wk) - }() - for { - err := s.transitOneJobStepAndWaitSync(wk, job) - if err != nil { - logutil.DDLLogger().Info("run job failed", zap.Error(err), zap.Stringer("job", job)) - } else if job.InFinalState() { - return - } - // we have to refresh the job, to handle cases like job cancel or pause - // or the job is finished by another owner. - // TODO for JobStateRollbackDone we have to query 1 additional time when the - // job is already moved to history. - failpoint.InjectCall("beforeRefreshJob", job) - for { - job, err = s.sysTblMgr.GetJobByID(s.schCtx, jobID) - failpoint.InjectCall("mockGetJobByIDFail", &err) - if err == nil { - break - } - - if err == systable.ErrNotFound { - logutil.DDLLogger().Info("job not found, might already finished", - zap.Int64("job_id", jobID)) - return - } - logutil.DDLLogger().Error("get job failed", zap.Int64("job_id", jobID), zap.Error(err)) - select { - case <-s.schCtx.Done(): - return - case <-time.After(500 * time.Millisecond): - continue - } - } - } - }) -} - -// transitOneJobStepAndWaitSync runs one step of the DDL job, persist it and -// waits for other TiDB node to synchronize. -func (s *jobScheduler) transitOneJobStepAndWaitSync(wk *worker, job *model.Job) error { - failpoint.InjectCall("beforeRunOneJobStep") - ownerID := s.ownerManager.ID() - // suppose we failed to sync version last time, we need to check and sync it - // before run to maintain the 2-version invariant. 
- if !job.NotStarted() && (!s.isSynced(job) || !s.maybeAlreadyRunOnce(job.ID)) { - if variable.EnableMDL.Load() { - version, err := s.sysTblMgr.GetMDLVer(s.schCtx, job.ID) - if err == nil { - err = waitSchemaSyncedForMDL(wk.ctx, s.ddlCtx, job, version) - if err != nil { - return err - } - s.setAlreadyRunOnce(job.ID) - s.cleanMDLInfo(job, ownerID) - return nil - } else if err != systable.ErrNotFound { - wk.jobLogger(job).Warn("check MDL info failed", zap.Error(err)) - return err - } - } else { - err := waitSchemaSynced(wk.ctx, s.ddlCtx, job) - if err != nil { - time.Sleep(time.Second) - return err - } - s.setAlreadyRunOnce(job.ID) - } - } - - schemaVer, err := wk.transitOneJobStep(s.ddlCtx, job) - if err != nil { - tidblogutil.Logger(wk.logCtx).Info("handle ddl job failed", zap.Error(err), zap.Stringer("job", job)) - return err - } - failpoint.Inject("mockDownBeforeUpdateGlobalVersion", func(val failpoint.Value) { - if val.(bool) { - if mockDDLErrOnce == 0 { - mockDDLErrOnce = schemaVer - failpoint.Return(errors.New("mock down before update global version")) - } - } - }) - - failpoint.InjectCall("beforeWaitSchemaChanged", job, schemaVer) - // Here means the job enters another state (delete only, write only, public, etc...) or is cancelled. - // If the job is done or still running or rolling back, we will wait 2 * lease time or util MDL synced to guarantee other servers to update - // the newest schema. - if err = waitSchemaChanged(wk.ctx, s.ddlCtx, schemaVer, job); err != nil { - return err - } - s.cleanMDLInfo(job, ownerID) - s.synced(job) - - failpoint.InjectCall("onJobUpdated", job) - return nil -} - -// cleanMDLInfo cleans metadata lock info. -func (s *jobScheduler) cleanMDLInfo(job *model.Job, ownerID string) { - if !variable.EnableMDL.Load() { - return - } - var sql string - if tidbutil.IsSysDB(strings.ToLower(job.SchemaName)) { - // DDLs that modify system tables could only happen in upgrade process, - // we should not reference 'owner_id'. Otherwise, there is a circular blocking problem. - sql = fmt.Sprintf("delete from mysql.tidb_mdl_info where job_id = %d", job.ID) - } else { - sql = fmt.Sprintf("delete from mysql.tidb_mdl_info where job_id = %d and owner_id = '%s'", job.ID, ownerID) - } - sctx, _ := s.sessPool.Get() - defer s.sessPool.Put(sctx) - se := sess.NewSession(sctx) - se.GetSessionVars().SetDiskFullOpt(kvrpcpb.DiskFullOpt_AllowedOnAlmostFull) - _, err := se.Execute(s.schCtx, sql, "delete-mdl-info") - if err != nil { - logutil.DDLLogger().Warn("unexpected error when clean mdl info", zap.Int64("job ID", job.ID), zap.Error(err)) - return - } - // TODO we need clean it when version of JobStateRollbackDone is synced also. 
- if job.State == model.JobStateSynced && s.etcdCli != nil { - path := fmt.Sprintf("%s/%d/", util.DDLAllSchemaVersionsByJob, job.ID) - _, err = s.etcdCli.Delete(s.schCtx, path, clientv3.WithPrefix()) - if err != nil { - logutil.DDLLogger().Warn("delete versions failed", zap.Int64("job ID", job.ID), zap.Error(err)) - } - } -} - -func (d *ddl) getTableByTxn(r autoid.Requirement, schemaID, tableID int64) (*model.DBInfo, table.Table, error) { - var tbl table.Table - var dbInfo *model.DBInfo - err := kv.RunInNewTxn(d.ctx, r.Store(), false, func(_ context.Context, txn kv.Transaction) error { - t := meta.NewMeta(txn) - var err1 error - dbInfo, err1 = t.GetDatabase(schemaID) - if err1 != nil { - return errors.Trace(err1) - } - tblInfo, err1 := getTableInfo(t, tableID, schemaID) - if err1 != nil { - return errors.Trace(err1) - } - tbl, err1 = getTable(r, schemaID, tblInfo) - return errors.Trace(err1) - }) - return dbInfo, tbl, err -} - -const ( - addDDLJobSQL = "insert into mysql.tidb_ddl_job(job_id, reorg, schema_ids, table_ids, job_meta, type, processing) values" - updateDDLJobSQL = "update mysql.tidb_ddl_job set job_meta = %s where job_id = %d" -) - -func insertDDLJobs2Table(ctx context.Context, se *sess.Session, jobWs ...*JobWrapper) error { - failpoint.Inject("mockAddBatchDDLJobsErr", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(errors.Errorf("mockAddBatchDDLJobsErr")) - } - }) - if len(jobWs) == 0 { - return nil - } - var sql bytes.Buffer - sql.WriteString(addDDLJobSQL) - for i, jobW := range jobWs { - b, err := jobW.Encode(true) - if err != nil { - return err - } - if i != 0 { - sql.WriteString(",") - } - fmt.Fprintf(&sql, "(%d, %t, %s, %s, %s, %d, %t)", jobW.ID, jobW.MayNeedReorg(), - strconv.Quote(job2SchemaIDs(jobW.Job)), strconv.Quote(job2TableIDs(jobW.Job)), - util.WrapKey2String(b), jobW.Type, !jobW.NotStarted()) - } - se.GetSessionVars().SetDiskFullOpt(kvrpcpb.DiskFullOpt_AllowedOnAlmostFull) - _, err := se.Execute(ctx, sql.String(), "insert_job") - logutil.DDLLogger().Debug("add job to mysql.tidb_ddl_job table", zap.String("sql", sql.String())) - return errors.Trace(err) -} - -func job2SchemaIDs(job *model.Job) string { - return job2UniqueIDs(job, true) -} - -func job2TableIDs(job *model.Job) string { - return job2UniqueIDs(job, false) -} - -func job2UniqueIDs(job *model.Job, schema bool) string { - switch job.Type { - case model.ActionExchangeTablePartition, model.ActionRenameTables, model.ActionRenameTable: - var ids []int64 - if schema { - ids = job.CtxVars[0].([]int64) - } else { - ids = job.CtxVars[1].([]int64) - } - set := make(map[int64]struct{}, len(ids)) - for _, id := range ids { - set[id] = struct{}{} - } - - s := make([]string, 0, len(set)) - for id := range set { - s = append(s, strconv.FormatInt(id, 10)) - } - slices.Sort(s) - return strings.Join(s, ",") - case model.ActionTruncateTable: - if schema { - return strconv.FormatInt(job.SchemaID, 10) - } - return strconv.FormatInt(job.TableID, 10) + "," + strconv.FormatInt(job.Args[0].(int64), 10) - } - if schema { - return strconv.FormatInt(job.SchemaID, 10) - } - return strconv.FormatInt(job.TableID, 10) -} - -func updateDDLJob2Table(se *sess.Session, job *model.Job, updateRawArgs bool) error { - b, err := job.Encode(updateRawArgs) - if err != nil { - return err - } - sql := fmt.Sprintf(updateDDLJobSQL, util.WrapKey2String(b), job.ID) - _, err = se.Execute(context.Background(), sql, "update_job") - return errors.Trace(err) -} - -// getDDLReorgHandle gets DDL reorg handle. 
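One detail of job2UniqueIDs above that is easy to miss: the IDs are de-duplicated through a map and then sorted as decimal strings, not as integers, so "110" sorts before "23". A small stand-alone sketch of just that step (the helper name is made up; the real function also handles the per-action cases):

package main

import (
	"fmt"
	"slices"
	"strconv"
	"strings"
)

// uniqueIDList mirrors the de-duplicate, sort, and join step in job2UniqueIDs.
func uniqueIDList(ids []int64) string {
	set := make(map[int64]struct{}, len(ids))
	for _, id := range ids {
		set[id] = struct{}{}
	}
	s := make([]string, 0, len(set))
	for id := range set {
		s = append(s, strconv.FormatInt(id, 10))
	}
	// Like the original, this sorts the decimal strings, not the integers.
	slices.Sort(s)
	return strings.Join(s, ",")
}

func main() {
	fmt.Println(uniqueIDList([]int64{110, 9, 110, 23})) // "110,23,9"
}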
-func getDDLReorgHandle(se *sess.Session, job *model.Job) (element *meta.Element, - startKey, endKey kv.Key, physicalTableID int64, err error) { - sql := fmt.Sprintf("select ele_id, ele_type, start_key, end_key, physical_id, reorg_meta from mysql.tidb_ddl_reorg where job_id = %d", job.ID) - ctx := kv.WithInternalSourceType(context.Background(), getDDLRequestSource(job.Type)) - rows, err := se.Execute(ctx, sql, "get_handle") - if err != nil { - return nil, nil, nil, 0, err - } - if len(rows) == 0 { - return nil, nil, nil, 0, meta.ErrDDLReorgElementNotExist - } - id := rows[0].GetInt64(0) - tp := rows[0].GetBytes(1) - element = &meta.Element{ - ID: id, - TypeKey: tp, - } - startKey = rows[0].GetBytes(2) - endKey = rows[0].GetBytes(3) - physicalTableID = rows[0].GetInt64(4) - return -} - -func getImportedKeyFromCheckpoint(se *sess.Session, job *model.Job) (imported kv.Key, physicalTableID int64, err error) { - sql := fmt.Sprintf("select reorg_meta from mysql.tidb_ddl_reorg where job_id = %d", job.ID) - ctx := kv.WithInternalSourceType(context.Background(), getDDLRequestSource(job.Type)) - rows, err := se.Execute(ctx, sql, "get_handle") - if err != nil { - return nil, 0, err - } - if len(rows) == 0 { - return nil, 0, meta.ErrDDLReorgElementNotExist - } - if !rows[0].IsNull(0) { - rawReorgMeta := rows[0].GetBytes(0) - var reorgMeta ingest.JobReorgMeta - err = json.Unmarshal(rawReorgMeta, &reorgMeta) - if err != nil { - return nil, 0, errors.Trace(err) - } - if cp := reorgMeta.Checkpoint; cp != nil { - logutil.DDLIngestLogger().Info("resume physical table ID from checkpoint", - zap.Int64("jobID", job.ID), - zap.String("global sync key", hex.EncodeToString(cp.GlobalSyncKey)), - zap.Int64("checkpoint physical ID", cp.PhysicalID)) - return cp.GlobalSyncKey, cp.PhysicalID, nil - } - } - return -} - -// updateDDLReorgHandle update startKey, endKey physicalTableID and element of the handle. -// Caller should wrap this in a separate transaction, to avoid conflicts. -func updateDDLReorgHandle(se *sess.Session, jobID int64, startKey kv.Key, endKey kv.Key, physicalTableID int64, element *meta.Element) error { - sql := fmt.Sprintf("update mysql.tidb_ddl_reorg set ele_id = %d, ele_type = %s, start_key = %s, end_key = %s, physical_id = %d where job_id = %d", - element.ID, util.WrapKey2String(element.TypeKey), util.WrapKey2String(startKey), util.WrapKey2String(endKey), physicalTableID, jobID) - _, err := se.Execute(context.Background(), sql, "update_handle") - return err -} - -// initDDLReorgHandle initializes the handle for ddl reorg. 
-func initDDLReorgHandle(s *sess.Session, jobID int64, startKey kv.Key, endKey kv.Key, physicalTableID int64, element *meta.Element) error { - rawReorgMeta, err := json.Marshal(ingest.JobReorgMeta{ - Checkpoint: &ingest.ReorgCheckpoint{ - PhysicalID: physicalTableID, - Version: ingest.JobCheckpointVersionCurrent, - }}) - if err != nil { - return errors.Trace(err) - } - del := fmt.Sprintf("delete from mysql.tidb_ddl_reorg where job_id = %d", jobID) - ins := fmt.Sprintf("insert into mysql.tidb_ddl_reorg(job_id, ele_id, ele_type, start_key, end_key, physical_id, reorg_meta) values (%d, %d, %s, %s, %s, %d, %s)", - jobID, element.ID, util.WrapKey2String(element.TypeKey), util.WrapKey2String(startKey), util.WrapKey2String(endKey), physicalTableID, util.WrapKey2String(rawReorgMeta)) - return s.RunInTxn(func(se *sess.Session) error { - _, err := se.Execute(context.Background(), del, "init_handle") - if err != nil { - logutil.DDLLogger().Info("initDDLReorgHandle failed to delete", zap.Int64("jobID", jobID), zap.Error(err)) - } - _, err = se.Execute(context.Background(), ins, "init_handle") - return err - }) -} - -// deleteDDLReorgHandle deletes the handle for ddl reorg. -func removeDDLReorgHandle(se *sess.Session, job *model.Job, elements []*meta.Element) error { - if len(elements) == 0 { - return nil - } - sql := fmt.Sprintf("delete from mysql.tidb_ddl_reorg where job_id = %d", job.ID) - return se.RunInTxn(func(se *sess.Session) error { - _, err := se.Execute(context.Background(), sql, "remove_handle") - return err - }) -} - -// removeReorgElement removes the element from ddl reorg, it is the same with removeDDLReorgHandle, only used in failpoint -func removeReorgElement(se *sess.Session, job *model.Job) error { - sql := fmt.Sprintf("delete from mysql.tidb_ddl_reorg where job_id = %d", job.ID) - return se.RunInTxn(func(se *sess.Session) error { - _, err := se.Execute(context.Background(), sql, "remove_handle") - return err - }) -} - -// cleanDDLReorgHandles removes handles that are no longer needed. 
-func cleanDDLReorgHandles(se *sess.Session, job *model.Job) error { - sql := "delete from mysql.tidb_ddl_reorg where job_id = " + strconv.FormatInt(job.ID, 10) - return se.RunInTxn(func(se *sess.Session) error { - _, err := se.Execute(context.Background(), sql, "clean_handle") - return err - }) -} - -func getJobsBySQL(se *sess.Session, tbl, condition string) ([]*model.Job, error) { - rows, err := se.Execute(context.Background(), fmt.Sprintf("select job_meta from mysql.%s where %s", tbl, condition), "get_job") - if err != nil { - return nil, errors.Trace(err) - } - jobs := make([]*model.Job, 0, 16) - for _, row := range rows { - jobBinary := row.GetBytes(0) - job := model.Job{} - err := job.Decode(jobBinary) - if err != nil { - return nil, errors.Trace(err) - } - jobs = append(jobs, &job) - } - return jobs, nil -} diff --git a/pkg/ddl/job_submitter.go b/pkg/ddl/job_submitter.go index dd99d93376ca9..da78b30bc5a3e 100644 --- a/pkg/ddl/job_submitter.go +++ b/pkg/ddl/job_submitter.go @@ -55,7 +55,7 @@ func (d *ddl) limitDDLJobs() { // the channel is never closed case jobW := <-ch: jobWs = jobWs[:0] - failpoint.Call(_curpkg_("afterGetJobFromLimitCh"), ch) + failpoint.InjectCall("afterGetJobFromLimitCh", ch) jobLen := len(ch) jobWs = append(jobWs, jobW) for i := 0; i < jobLen; i++ { @@ -369,11 +369,11 @@ func (d *ddl) addBatchDDLJobs2Queue(jobWs []*JobWrapper) error { return errors.Trace(err) } } - if val, _err_ := failpoint.Eval(_curpkg_("mockAddBatchDDLJobsErr")); _err_ == nil { + failpoint.Inject("mockAddBatchDDLJobsErr", func(val failpoint.Value) { if val.(bool) { - return errors.Errorf("mockAddBatchDDLJobsErr") + failpoint.Return(errors.Errorf("mockAddBatchDDLJobsErr")) } - } + }) return nil }) } @@ -399,11 +399,11 @@ func (*ddl) checkFlashbackJobInQueue(t *meta.Meta) error { func GenGIDAndInsertJobsWithRetry(ctx context.Context, ddlSe *sess.Session, jobWs []*JobWrapper) error { count := getRequiredGIDCount(jobWs) return genGIDAndCallWithRetry(ctx, ddlSe, count, func(ids []int64) error { - if val, _err_ := failpoint.Eval(_curpkg_("mockGenGlobalIDFail")); _err_ == nil { + failpoint.Inject("mockGenGlobalIDFail", func(val failpoint.Value) { if val.(bool) { - return errors.New("gofail genGlobalIDs error") + failpoint.Return(errors.New("gofail genGlobalIDs error")) } - } + }) assignGIDsForJobs(jobWs, ids) injectModifyJobArgFailPoint(jobWs) return insertDDLJobs2Table(ctx, ddlSe, jobWs...) @@ -578,7 +578,7 @@ func lockGlobalIDKey(ctx context.Context, ddlSe *sess.Session, txn kv.Transactio // TODO this failpoint is only checking how job scheduler handle // corrupted job args, we should test it there by UT, not here. func injectModifyJobArgFailPoint(jobWs []*JobWrapper) { - if val, _err_ := failpoint.Eval(_curpkg_("MockModifyJobArg")); _err_ == nil { + failpoint.Inject("MockModifyJobArg", func(val failpoint.Value) { if val.(bool) { for _, jobW := range jobWs { job := jobW.Job @@ -592,7 +592,7 @@ func injectModifyJobArgFailPoint(jobWs []*JobWrapper) { } } } - } + }) } func setJobStateToQueueing(job *model.Job) { diff --git a/pkg/ddl/job_submitter.go__failpoint_stash__ b/pkg/ddl/job_submitter.go__failpoint_stash__ deleted file mode 100644 index da78b30bc5a3e..0000000000000 --- a/pkg/ddl/job_submitter.go__failpoint_stash__ +++ /dev/null @@ -1,669 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ddl - -import ( - "context" - "fmt" - "math" - "strconv" - "strings" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/ddl/logutil" - sess "github.com/pingcap/tidb/pkg/ddl/session" - ddlutil "github.com/pingcap/tidb/pkg/ddl/util" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/dbterror" - "github.com/pingcap/tidb/pkg/util/mathutil" - tikv "github.com/tikv/client-go/v2/kv" - "github.com/tikv/client-go/v2/oracle" - "go.uber.org/zap" -) - -func (d *ddl) limitDDLJobs() { - defer util.Recover(metrics.LabelDDL, "limitDDLJobs", nil, true) - - jobWs := make([]*JobWrapper, 0, batchAddingJobs) - ch := d.limitJobCh - for { - select { - // the channel is never closed - case jobW := <-ch: - jobWs = jobWs[:0] - failpoint.InjectCall("afterGetJobFromLimitCh", ch) - jobLen := len(ch) - jobWs = append(jobWs, jobW) - for i := 0; i < jobLen; i++ { - jobWs = append(jobWs, <-ch) - } - d.addBatchDDLJobs(jobWs) - case <-d.ctx.Done(): - return - } - } -} - -// addBatchDDLJobs gets global job IDs and puts the DDL jobs in the DDL queue. -func (d *ddl) addBatchDDLJobs(jobWs []*JobWrapper) { - startTime := time.Now() - var ( - err error - newWs []*JobWrapper - ) - // DDLForce2Queue is a flag to tell DDL worker to always push the job to the DDL queue. - toTable := !variable.DDLForce2Queue.Load() - fastCreate := variable.EnableFastCreateTable.Load() - if toTable { - if fastCreate { - newWs, err = mergeCreateTableJobs(jobWs) - if err != nil { - logutil.DDLLogger().Warn("failed to merge create table jobs", zap.Error(err)) - } else { - jobWs = newWs - } - } - err = d.addBatchDDLJobs2Table(jobWs) - } else { - err = d.addBatchDDLJobs2Queue(jobWs) - } - var jobs string - for _, jobW := range jobWs { - if err == nil { - err = jobW.cacheErr - } - jobW.NotifyResult(err) - jobs += jobW.Job.String() + "; " - metrics.DDLWorkerHistogram.WithLabelValues(metrics.WorkerAddDDLJob, jobW.Job.Type.String(), - metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) - } - if err != nil { - logutil.DDLLogger().Warn("add DDL jobs failed", zap.String("jobs", jobs), zap.Error(err)) - } else { - logutil.DDLLogger().Info("add DDL jobs", - zap.Int("batch count", len(jobWs)), - zap.String("jobs", jobs), - zap.Bool("table", toTable), - zap.Bool("fast_create", fastCreate)) - } -} - -// mergeCreateTableJobs merges CreateTable jobs to CreateTables. 
-func mergeCreateTableJobs(jobWs []*JobWrapper) ([]*JobWrapper, error) { - if len(jobWs) <= 1 { - return jobWs, nil - } - resJobWs := make([]*JobWrapper, 0, len(jobWs)) - mergeableJobWs := make(map[string][]*JobWrapper, len(jobWs)) - for _, jobW := range jobWs { - // we don't merge jobs with ID pre-allocated. - if jobW.Type != model.ActionCreateTable || jobW.IDAllocated { - resJobWs = append(resJobWs, jobW) - continue - } - // ActionCreateTables doesn't support foreign key now. - tbInfo, ok := jobW.Args[0].(*model.TableInfo) - if !ok || len(tbInfo.ForeignKeys) > 0 { - resJobWs = append(resJobWs, jobW) - continue - } - // CreateTables only support tables of same schema now. - mergeableJobWs[jobW.Job.SchemaName] = append(mergeableJobWs[jobW.Job.SchemaName], jobW) - } - - for schema, jobs := range mergeableJobWs { - total := len(jobs) - if total <= 1 { - resJobWs = append(resJobWs, jobs...) - continue - } - const maxBatchSize = 8 - batchCount := (total + maxBatchSize - 1) / maxBatchSize - start := 0 - for _, batchSize := range mathutil.Divide2Batches(total, batchCount) { - batch := jobs[start : start+batchSize] - job, err := mergeCreateTableJobsOfSameSchema(batch) - if err != nil { - return nil, err - } - start += batchSize - logutil.DDLLogger().Info("merge create table jobs", zap.String("schema", schema), - zap.Int("total", total), zap.Int("batch_size", batchSize)) - - newJobW := &JobWrapper{ - Job: job, - ResultCh: make([]chan jobSubmitResult, 0, batchSize), - } - // merge the result channels. - for _, j := range batch { - newJobW.ResultCh = append(newJobW.ResultCh, j.ResultCh...) - } - resJobWs = append(resJobWs, newJobW) - } - } - return resJobWs, nil -} - -// buildQueryStringFromJobs takes a slice of Jobs and concatenates their -// queries into a single query string. -// Each query is separated by a semicolon and a space. -// Trailing spaces are removed from each query, and a semicolon is appended -// if it's not already present. -func buildQueryStringFromJobs(jobs []*JobWrapper) string { - var queryBuilder strings.Builder - for i, job := range jobs { - q := strings.TrimSpace(job.Query) - if !strings.HasSuffix(q, ";") { - q += ";" - } - queryBuilder.WriteString(q) - - if i < len(jobs)-1 { - queryBuilder.WriteString(" ") - } - } - return queryBuilder.String() -} - -// mergeCreateTableJobsOfSameSchema combine CreateTableJobs to BatchCreateTableJob. 
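
// buildQueryStringFromJobs above normalizes each statement (trim, ensure a trailing
// semicolon) and joins them with single spaces. A standalone sketch of the same rule;
// joinDDLQueries and the sample statements are illustrative, not part of the patch.
package main

import (
	"fmt"
	"strings"
)

func joinDDLQueries(queries []string) string {
	var b strings.Builder
	for i, q := range queries {
		q = strings.TrimSpace(q)
		if !strings.HasSuffix(q, ";") {
			q += ";"
		}
		b.WriteString(q)
		if i < len(queries)-1 {
			b.WriteString(" ")
		}
	}
	return b.String()
}

func main() {
	// Prints: create table t1 (a int); create table t2 (a int);
	fmt.Println(joinDDLQueries([]string{"create table t1 (a int) ", "create table t2 (a int);"}))
}
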
-func mergeCreateTableJobsOfSameSchema(jobWs []*JobWrapper) (*model.Job, error) { - if len(jobWs) == 0 { - return nil, errors.Trace(fmt.Errorf("expect non-empty jobs")) - } - - var combinedJob *model.Job - - args := make([]*model.TableInfo, 0, len(jobWs)) - involvingSchemaInfo := make([]model.InvolvingSchemaInfo, 0, len(jobWs)) - var foreignKeyChecks bool - - // if there is any duplicated table name - duplication := make(map[string]struct{}) - for _, job := range jobWs { - if combinedJob == nil { - combinedJob = job.Clone() - combinedJob.Type = model.ActionCreateTables - combinedJob.Args = combinedJob.Args[:0] - foreignKeyChecks = job.Args[1].(bool) - } - // append table job args - info, ok := job.Args[0].(*model.TableInfo) - if !ok { - return nil, errors.Trace(fmt.Errorf("expect model.TableInfo, but got %T", job.Args[0])) - } - args = append(args, info) - - if _, ok := duplication[info.Name.L]; ok { - // return err even if create table if not exists - return nil, infoschema.ErrTableExists.FastGenByArgs("can not batch create tables with same name") - } - - duplication[info.Name.L] = struct{}{} - - involvingSchemaInfo = append(involvingSchemaInfo, - model.InvolvingSchemaInfo{ - Database: job.SchemaName, - Table: info.Name.L, - }) - } - - combinedJob.Args = append(combinedJob.Args, args) - combinedJob.Args = append(combinedJob.Args, foreignKeyChecks) - combinedJob.InvolvingSchemaInfo = involvingSchemaInfo - combinedJob.Query = buildQueryStringFromJobs(jobWs) - - return combinedJob, nil -} - -// addBatchDDLJobs2Table gets global job IDs and puts the DDL jobs in the DDL job table. -func (d *ddl) addBatchDDLJobs2Table(jobWs []*JobWrapper) error { - var err error - - if len(jobWs) == 0 { - return nil - } - - ctx := kv.WithInternalSourceType(d.ctx, kv.InternalTxnDDL) - se, err := d.sessPool.Get() - if err != nil { - return errors.Trace(err) - } - defer d.sessPool.Put(se) - found, err := d.sysTblMgr.HasFlashbackClusterJob(ctx, d.minJobIDRefresher.GetCurrMinJobID()) - if err != nil { - return errors.Trace(err) - } - if found { - return errors.Errorf("Can't add ddl job, have flashback cluster job") - } - - var ( - startTS = uint64(0) - bdrRole = string(ast.BDRRoleNone) - ) - - err = kv.RunInNewTxn(ctx, d.store, true, func(_ context.Context, txn kv.Transaction) error { - t := meta.NewMeta(txn) - - bdrRole, err = t.GetBDRRole() - if err != nil { - return errors.Trace(err) - } - startTS = txn.StartTS() - - if variable.DDLForce2Queue.Load() { - if err := d.checkFlashbackJobInQueue(t); err != nil { - return err - } - } - - return nil - }) - if err != nil { - return errors.Trace(err) - } - - for _, jobW := range jobWs { - job := jobW.Job - job.Version = currentVersion - job.StartTS = startTS - job.BDRRole = bdrRole - - // BDR mode only affects the DDL not from CDC - if job.CDCWriteSource == 0 && bdrRole != string(ast.BDRRoleNone) { - if job.Type == model.ActionMultiSchemaChange && job.MultiSchemaInfo != nil { - for _, subJob := range job.MultiSchemaInfo.SubJobs { - if ast.DeniedByBDR(ast.BDRRole(bdrRole), subJob.Type, job) { - return dbterror.ErrBDRRestrictedDDL.FastGenByArgs(bdrRole) - } - } - } else if ast.DeniedByBDR(ast.BDRRole(bdrRole), job.Type, job) { - return dbterror.ErrBDRRestrictedDDL.FastGenByArgs(bdrRole) - } - } - - setJobStateToQueueing(job) - - if d.stateSyncer.IsUpgradingState() && !hasSysDB(job) { - if err = pauseRunningJob(sess.NewSession(se), job, model.AdminCommandBySystem); err != nil { - logutil.DDLUpgradingLogger().Warn("pause user DDL by system failed", zap.Stringer("job", job), 
zap.Error(err)) - jobW.cacheErr = err - continue - } - logutil.DDLUpgradingLogger().Info("pause user DDL by system successful", zap.Stringer("job", job)) - } - } - - se.GetSessionVars().SetDiskFullOpt(kvrpcpb.DiskFullOpt_AllowedOnAlmostFull) - ddlSe := sess.NewSession(se) - if err = GenGIDAndInsertJobsWithRetry(ctx, ddlSe, jobWs); err != nil { - return errors.Trace(err) - } - for _, jobW := range jobWs { - d.initJobDoneCh(jobW.ID) - } - - return nil -} - -func (d *ddl) initJobDoneCh(jobID int64) { - d.ddlJobDoneChMap.Store(jobID, make(chan struct{}, 1)) -} - -func (d *ddl) addBatchDDLJobs2Queue(jobWs []*JobWrapper) error { - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) - // lock to reduce conflict - d.globalIDLock.Lock() - defer d.globalIDLock.Unlock() - return kv.RunInNewTxn(ctx, d.store, true, func(_ context.Context, txn kv.Transaction) error { - t := meta.NewMeta(txn) - - count := getRequiredGIDCount(jobWs) - ids, err := t.GenGlobalIDs(count) - if err != nil { - return errors.Trace(err) - } - assignGIDsForJobs(jobWs, ids) - - if err := d.checkFlashbackJobInQueue(t); err != nil { - return errors.Trace(err) - } - - for _, jobW := range jobWs { - job := jobW.Job - job.Version = currentVersion - job.StartTS = txn.StartTS() - setJobStateToQueueing(job) - if err = buildJobDependence(t, job); err != nil { - return errors.Trace(err) - } - jobListKey := meta.DefaultJobListKey - if job.MayNeedReorg() { - jobListKey = meta.AddIndexJobListKey - } - if err = t.EnQueueDDLJob(job, jobListKey); err != nil { - return errors.Trace(err) - } - } - failpoint.Inject("mockAddBatchDDLJobsErr", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(errors.Errorf("mockAddBatchDDLJobsErr")) - } - }) - return nil - }) -} - -func (*ddl) checkFlashbackJobInQueue(t *meta.Meta) error { - jobs, err := t.GetAllDDLJobsInQueue(meta.DefaultJobListKey) - if err != nil { - return errors.Trace(err) - } - for _, job := range jobs { - if job.Type == model.ActionFlashbackCluster { - return errors.Errorf("Can't add ddl job, have flashback cluster job") - } - } - return nil -} - -// GenGIDAndInsertJobsWithRetry generate job related global ID and inserts DDL jobs to the DDL job -// table with retry. job id allocation and job insertion are in the same transaction, -// as we want to make sure DDL jobs are inserted in id order, then we can query from -// a min job ID when scheduling DDL jobs to mitigate https://github.com/pingcap/tidb/issues/52905. -// so this function has side effect, it will set table/db/job id of 'jobs'. -func GenGIDAndInsertJobsWithRetry(ctx context.Context, ddlSe *sess.Session, jobWs []*JobWrapper) error { - count := getRequiredGIDCount(jobWs) - return genGIDAndCallWithRetry(ctx, ddlSe, count, func(ids []int64) error { - failpoint.Inject("mockGenGlobalIDFail", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(errors.New("gofail genGlobalIDs error")) - } - }) - assignGIDsForJobs(jobWs, ids) - injectModifyJobArgFailPoint(jobWs) - return insertDDLJobs2Table(ctx, ddlSe, jobWs...) - }) -} - -// getRequiredGIDCount returns the count of required global IDs for the jobs. it's calculated -// as: the count of jobs + the count of IDs for the jobs which do NOT have pre-allocated ID. 
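
// The comment above defines the global-ID budget: one ID per job, plus the IDs needed by
// jobs without pre-allocated IDs (table ID plus one per partition definition). A tiny worked
// example of that arithmetic; the counts and the helper name are made up for illustration.
package main

import "fmt"

func requiredGIDs(jobCount, tableCount, partitionCount int) int {
	// one ID per job, one per table object, one per partition definition
	return jobCount + tableCount + partitionCount
}

func main() {
	// A single CREATE TABLE job for a table with 4 partitions:
	// 1 (job ID) + 1 (table ID) + 4 (partition IDs) = 6 global IDs.
	fmt.Println(requiredGIDs(1, 1, 4)) // 6
}
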
-func getRequiredGIDCount(jobWs []*JobWrapper) int { - count := len(jobWs) - idCountForTable := func(info *model.TableInfo) int { - c := 1 - if partitionInfo := info.GetPartitionInfo(); partitionInfo != nil { - c += len(partitionInfo.Definitions) - } - return c - } - for _, jobW := range jobWs { - if jobW.IDAllocated { - continue - } - switch jobW.Type { - case model.ActionCreateView, model.ActionCreateSequence, model.ActionCreateTable: - info := jobW.Args[0].(*model.TableInfo) - count += idCountForTable(info) - case model.ActionCreateTables: - infos := jobW.Args[0].([]*model.TableInfo) - for _, info := range infos { - count += idCountForTable(info) - } - case model.ActionCreateSchema: - count++ - } - // TODO support other type of jobs - } - return count -} - -// assignGIDsForJobs should be used with getRequiredGIDCount, and len(ids) must equal -// what getRequiredGIDCount returns. -func assignGIDsForJobs(jobWs []*JobWrapper, ids []int64) { - idx := 0 - - assignIDsForTable := func(info *model.TableInfo) { - info.ID = ids[idx] - idx++ - if partitionInfo := info.GetPartitionInfo(); partitionInfo != nil { - for i := range partitionInfo.Definitions { - partitionInfo.Definitions[i].ID = ids[idx] - idx++ - } - } - } - for _, jobW := range jobWs { - switch jobW.Type { - case model.ActionCreateView, model.ActionCreateSequence, model.ActionCreateTable: - info := jobW.Args[0].(*model.TableInfo) - if !jobW.IDAllocated { - assignIDsForTable(info) - } - jobW.TableID = info.ID - case model.ActionCreateTables: - if !jobW.IDAllocated { - infos := jobW.Args[0].([]*model.TableInfo) - for _, info := range infos { - assignIDsForTable(info) - } - } - case model.ActionCreateSchema: - dbInfo := jobW.Args[0].(*model.DBInfo) - if !jobW.IDAllocated { - dbInfo.ID = ids[idx] - idx++ - } - jobW.SchemaID = dbInfo.ID - } - // TODO support other type of jobs - jobW.ID = ids[idx] - idx++ - } -} - -// genGIDAndCallWithRetry generates global IDs and calls the function with retry. -// generate ID and call function runs in the same transaction. -func genGIDAndCallWithRetry(ctx context.Context, ddlSe *sess.Session, count int, fn func(ids []int64) error) error { - var resErr error - for i := uint(0); i < kv.MaxRetryCnt; i++ { - resErr = func() (err error) { - if err := ddlSe.Begin(ctx); err != nil { - return errors.Trace(err) - } - defer func() { - if err != nil { - ddlSe.Rollback() - } - }() - txn, err := ddlSe.Txn() - if err != nil { - return errors.Trace(err) - } - txn.SetOption(kv.Pessimistic, true) - forUpdateTS, err := lockGlobalIDKey(ctx, ddlSe, txn) - if err != nil { - return errors.Trace(err) - } - txn.GetSnapshot().SetOption(kv.SnapshotTS, forUpdateTS) - - m := meta.NewMeta(txn) - ids, err := m.GenGlobalIDs(count) - if err != nil { - return errors.Trace(err) - } - if err = fn(ids); err != nil { - return errors.Trace(err) - } - return ddlSe.Commit(ctx) - }() - - if resErr != nil && kv.IsTxnRetryableError(resErr) { - logutil.DDLLogger().Warn("insert job meet retryable error", zap.Error(resErr)) - kv.BackOff(i) - continue - } - break - } - return resErr -} - -// lockGlobalIDKey locks the global ID key in the meta store. it keeps trying if -// meet write conflict, we cannot have a fixed retry count for this error, see this -// https://github.com/pingcap/tidb/issues/27197#issuecomment-2216315057. -// this part is same as how we implement pessimistic + repeatable read isolation -// level in SQL executor, see doLockKeys. 
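
// genGIDAndCallWithRetry above runs ID allocation plus the callback in one transaction and
// retries only on retryable transaction errors, backing off between attempts. A dependency-free
// skeleton of that shape; maxRetries, isRetryable and the linear backoff are stand-ins for
// kv.MaxRetryCnt, kv.IsTxnRetryableError and kv.BackOff, not the real helpers.
package main

import (
	"errors"
	"fmt"
	"time"
)

func withRetry(maxRetries int, isRetryable func(error) bool, fn func() error) error {
	var err error
	for i := 0; i < maxRetries; i++ {
		err = fn()
		if err != nil && isRetryable(err) {
			// back off a little longer after every failed attempt
			time.Sleep(time.Duration(i+1) * 10 * time.Millisecond)
			continue
		}
		break
	}
	return err
}

func main() {
	attempts := 0
	err := withRetry(5, func(error) bool { return true }, func() error {
		attempts++
		if attempts < 3 {
			return errors.New("write conflict")
		}
		return nil
	})
	fmt.Println(attempts, err) // 3 <nil>
}
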
-// NextGlobalID is a meta key, so we cannot use "select xx for update", if we store -// it into a table row or using advisory lock, we will depends on a system table -// that is created by us, cyclic. although we can create a system table without using -// DDL logic, we will only consider change it when we have data dictionary and keep -// it this way now. -// TODO maybe we can unify the lock mechanism with SQL executor in the future, or -// implement it inside TiKV client-go. -func lockGlobalIDKey(ctx context.Context, ddlSe *sess.Session, txn kv.Transaction) (uint64, error) { - var ( - iteration uint - forUpdateTs = txn.StartTS() - ver kv.Version - err error - ) - waitTime := ddlSe.GetSessionVars().LockWaitTimeout - m := meta.NewMeta(txn) - idKey := m.GlobalIDKey() - for { - lockCtx := tikv.NewLockCtx(forUpdateTs, waitTime, time.Now()) - err = txn.LockKeys(ctx, lockCtx, idKey) - if err == nil || !terror.ErrorEqual(kv.ErrWriteConflict, err) { - break - } - // ErrWriteConflict contains a conflict-commit-ts in most case, but it cannot - // be used as forUpdateTs, see comments inside handleAfterPessimisticLockError - ver, err = ddlSe.GetStore().CurrentVersion(oracle.GlobalTxnScope) - if err != nil { - break - } - forUpdateTs = ver.Ver - - kv.BackOff(iteration) - // avoid it keep growing and overflow. - iteration = min(iteration+1, math.MaxInt) - } - return forUpdateTs, err -} - -// TODO this failpoint is only checking how job scheduler handle -// corrupted job args, we should test it there by UT, not here. -func injectModifyJobArgFailPoint(jobWs []*JobWrapper) { - failpoint.Inject("MockModifyJobArg", func(val failpoint.Value) { - if val.(bool) { - for _, jobW := range jobWs { - job := jobW.Job - // Corrupt the DDL job argument. - if job.Type == model.ActionMultiSchemaChange { - if len(job.MultiSchemaInfo.SubJobs) > 0 && len(job.MultiSchemaInfo.SubJobs[0].Args) > 0 { - job.MultiSchemaInfo.SubJobs[0].Args[0] = 1 - } - } else if len(job.Args) > 0 { - job.Args[0] = 1 - } - } - } - }) -} - -func setJobStateToQueueing(job *model.Job) { - if job.Type == model.ActionMultiSchemaChange && job.MultiSchemaInfo != nil { - for _, sub := range job.MultiSchemaInfo.SubJobs { - sub.State = model.JobStateQueueing - } - } - job.State = model.JobStateQueueing -} - -// buildJobDependence sets the curjob's dependency-ID. -// The dependency-job's ID must less than the current job's ID, and we need the largest one in the list. -func buildJobDependence(t *meta.Meta, curJob *model.Job) error { - // Jobs in the same queue are ordered. If we want to find a job's dependency-job, we need to look for - // it from the other queue. So if the job is "ActionAddIndex" job, we need find its dependency-job from DefaultJobList. - jobListKey := meta.DefaultJobListKey - if !curJob.MayNeedReorg() { - jobListKey = meta.AddIndexJobListKey - } - jobs, err := t.GetAllDDLJobsInQueue(jobListKey) - if err != nil { - return errors.Trace(err) - } - - for _, job := range jobs { - if curJob.ID < job.ID { - continue - } - isDependent, err := curJob.IsDependentOn(job) - if err != nil { - return errors.Trace(err) - } - if isDependent { - logutil.DDLLogger().Info("current DDL job depends on other job", - zap.Stringer("currentJob", curJob), - zap.Stringer("dependentJob", job)) - curJob.DependencyID = job.ID - break - } - } - return nil -} - -func (e *executor) notifyNewJobSubmitted(ch chan struct{}, etcdPath string, jobID int64, jobType string) { - // If the workers don't run, we needn't notify workers. 
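
// lockGlobalIDKey above keeps retrying the pessimistic lock and refreshes forUpdateTS from
// the store after every write conflict. A minimal standalone sketch of that refresh-and-retry
// loop; lockOnce and currentTS stand in for txn.LockKeys and store.CurrentVersion, and the
// real loop additionally backs off between attempts.
package main

import (
	"errors"
	"fmt"
)

var errWriteConflict = errors.New("write conflict")

func lockWithRefresh(startTS uint64, lockOnce func(ts uint64) error, currentTS func() uint64) (uint64, error) {
	forUpdateTS := startTS
	for {
		err := lockOnce(forUpdateTS)
		if err == nil || !errors.Is(err, errWriteConflict) {
			return forUpdateTS, err
		}
		// conflict: take a newer timestamp and try again
		forUpdateTS = currentTS()
	}
}

func main() {
	ts := uint64(100)
	calls := 0
	got, err := lockWithRefresh(ts,
		func(uint64) error {
			calls++
			if calls < 3 {
				return errWriteConflict
			}
			return nil
		},
		func() uint64 { ts++; return ts },
	)
	fmt.Println(got, err, calls) // 102 <nil> 3
}
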
- // TODO: It does not affect informing the backfill worker. - if !config.GetGlobalConfig().Instance.TiDBEnableDDL.Load() { - return - } - if e.ownerManager.IsOwner() { - asyncNotify(ch) - } else { - e.notifyNewJobByEtcd(etcdPath, jobID, jobType) - } -} - -func (e *executor) notifyNewJobByEtcd(etcdPath string, jobID int64, jobType string) { - if e.etcdCli == nil { - return - } - - jobIDStr := strconv.FormatInt(jobID, 10) - timeStart := time.Now() - err := ddlutil.PutKVToEtcd(e.ctx, e.etcdCli, 1, etcdPath, jobIDStr) - if err != nil { - logutil.DDLLogger().Info("notify handling DDL job failed", - zap.String("etcdPath", etcdPath), - zap.Int64("jobID", jobID), - zap.String("type", jobType), - zap.Error(err)) - } - metrics.DDLWorkerHistogram.WithLabelValues(metrics.WorkerNotifyDDLJob, jobType, metrics.RetLabel(err)).Observe(time.Since(timeStart).Seconds()) -} diff --git a/pkg/ddl/job_worker.go b/pkg/ddl/job_worker.go index ff2e392a2557e..c204711ce5217 100644 --- a/pkg/ddl/job_worker.go +++ b/pkg/ddl/job_worker.go @@ -191,12 +191,12 @@ func injectFailPointForGetJob(job *model.Job) { if job == nil { return } - if val, _err_ := failpoint.Eval(_curpkg_("mockModifyJobSchemaId")); _err_ == nil { + failpoint.Inject("mockModifyJobSchemaId", func(val failpoint.Value) { job.SchemaID = int64(val.(int)) - } - if val, _err_ := failpoint.Eval(_curpkg_("MockModifyJobTableId")); _err_ == nil { + }) + failpoint.Inject("MockModifyJobTableId", func(val failpoint.Value) { job.TableID = int64(val.(int)) - } + }) } // handleUpdateJobError handles the too large DDL job. @@ -224,11 +224,11 @@ func (w *worker) handleUpdateJobError(t *meta.Meta, job *model.Job, err error) e // updateDDLJob updates the DDL job information. func (w *worker) updateDDLJob(job *model.Job, updateRawArgs bool) error { - if val, _err_ := failpoint.Eval(_curpkg_("mockErrEntrySizeTooLarge")); _err_ == nil { + failpoint.Inject("mockErrEntrySizeTooLarge", func(val failpoint.Value) { if val.(bool) { - return kv.ErrEntryTooLarge + failpoint.Return(kv.ErrEntryTooLarge) } - } + }) if !updateRawArgs { w.jobLogger(job).Info("meet something wrong before update DDL job, shouldn't update raw args", @@ -348,7 +348,7 @@ func (w *worker) finishDDLJob(t *meta.Meta, job *model.Job) (err error) { } job.SeqNum = w.seqAllocator.Add(1) w.removeJobCtx(job) - failpoint.Call(_curpkg_("afterFinishDDLJob"), job) + failpoint.InjectCall("afterFinishDDLJob", job) err = AddHistoryDDLJob(w.ctx, w.sess, t, job, updateRawArgs) return errors.Trace(err) } @@ -456,11 +456,11 @@ func (w *worker) prepareTxn(job *model.Job) (kv.Transaction, error) { if err != nil { return nil, err } - if val, _err_ := failpoint.Eval(_curpkg_("mockRunJobTime")); _err_ == nil { + failpoint.Inject("mockRunJobTime", func(val failpoint.Value) { if val.(bool) { time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond) // #nosec G404 } - } + }) txn, err := w.sess.Txn() if err != nil { w.sess.Rollback() @@ -503,7 +503,7 @@ func (w *worker) transitOneJobStep(d *ddlCtx, job *model.Job) (int64, error) { job.State = model.JobStateSynced } // Inject the failpoint to prevent the progress of index creation. 
- if v, _err_ := failpoint.Eval(_curpkg_("create-index-stuck-before-ddlhistory")); _err_ == nil { + failpoint.Inject("create-index-stuck-before-ddlhistory", func(v failpoint.Value) { if sigFile, ok := v.(string); ok && job.Type == model.ActionAddIndex { for { time.Sleep(1 * time.Second) @@ -511,21 +511,21 @@ func (w *worker) transitOneJobStep(d *ddlCtx, job *model.Job) (int64, error) { if os.IsNotExist(err) { continue } - return 0, errors.Trace(err) + failpoint.Return(0, errors.Trace(err)) } break } } - } + }) return 0, w.handleJobDone(d, job, t) } - failpoint.Call(_curpkg_("onJobRunBefore"), job) + failpoint.InjectCall("onJobRunBefore", job) // If running job meets error, we will save this error in job Error and retry // later if the job is not cancelled. schemaVer, updateRawArgs, runJobErr := w.runOneJobStep(d, t, job) - failpoint.Call(_curpkg_("onJobRunAfter"), job) + failpoint.InjectCall("onJobRunAfter", job) if job.IsCancelled() { defer d.unlockSchemaVersion(job.ID) @@ -748,7 +748,7 @@ func (w *worker) runOneJobStep( }, false) // Mock for run ddl job panic. - failpoint.Eval(_curpkg_("mockPanicInRunDDLJob")) + failpoint.Inject("mockPanicInRunDDLJob", func(failpoint.Value) {}) if job.Type != model.ActionMultiSchemaChange { w.jobLogger(job).Info("run DDL job", zap.String("category", "ddl"), zap.String("job", job.String())) diff --git a/pkg/ddl/job_worker.go__failpoint_stash__ b/pkg/ddl/job_worker.go__failpoint_stash__ deleted file mode 100644 index c204711ce5217..0000000000000 --- a/pkg/ddl/job_worker.go__failpoint_stash__ +++ /dev/null @@ -1,1013 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ddl - -import ( - "context" - "fmt" - "math/rand" - "os" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/tidb/pkg/ddl/logutil" - sess "github.com/pingcap/tidb/pkg/ddl/session" - "github.com/pingcap/tidb/pkg/ddl/util" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/binloginfo" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - pumpcli "github.com/pingcap/tidb/pkg/tidb-binlog/pump_client" - tidbutil "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/dbterror" - tidblogutil "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/resourcegrouptag" - "github.com/pingcap/tidb/pkg/util/topsql" - topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" - "github.com/tikv/client-go/v2/tikvrpc" - kvutil "github.com/tikv/client-go/v2/util" - atomicutil "go.uber.org/atomic" - "go.uber.org/zap" -) - -var ( - // ddlWorkerID is used for generating the next DDL worker ID. 
- ddlWorkerID = atomicutil.NewInt32(0) - // backfillContextID is used for generating the next backfill context ID. - backfillContextID = atomicutil.NewInt32(0) - // WaitTimeWhenErrorOccurred is waiting interval when processing DDL jobs encounter errors. - WaitTimeWhenErrorOccurred = int64(1 * time.Second) - - mockDDLErrOnce = int64(0) - // TestNotifyBeginTxnCh is used for if the txn is beginning in runInTxn. - TestNotifyBeginTxnCh = make(chan struct{}) -) - -// GetWaitTimeWhenErrorOccurred return waiting interval when processing DDL jobs encounter errors. -func GetWaitTimeWhenErrorOccurred() time.Duration { - return time.Duration(atomic.LoadInt64(&WaitTimeWhenErrorOccurred)) -} - -// SetWaitTimeWhenErrorOccurred update waiting interval when processing DDL jobs encounter errors. -func SetWaitTimeWhenErrorOccurred(dur time.Duration) { - atomic.StoreInt64(&WaitTimeWhenErrorOccurred, int64(dur)) -} - -type workerType byte - -const ( - // generalWorker is the worker who handles all DDL statements except “add index”. - generalWorker workerType = 0 - // addIdxWorker is the worker who handles the operation of adding indexes. - addIdxWorker workerType = 1 -) - -// worker is used for handling DDL jobs. -// Now we have two kinds of workers. -type worker struct { - id int32 - tp workerType - addingDDLJobKey string - ddlJobCh chan struct{} - // it's the ctx of 'job scheduler'. - ctx context.Context - wg sync.WaitGroup - - sessPool *sess.Pool // sessPool is used to new sessions to execute SQL in ddl package. - sess *sess.Session // sess is used and only used in running DDL job. - delRangeManager delRangeManager - logCtx context.Context - seqAllocator *atomic.Uint64 - - *ddlCtx -} - -// JobContext is the ddl job execution context. -type JobContext struct { - // below fields are cache for top sql - ddlJobCtx context.Context - cacheSQL string - cacheNormalizedSQL string - cacheDigest *parser.Digest - tp string - - resourceGroupName string - cloudStorageURI string -} - -// NewJobContext returns a new ddl job context. 
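
// WaitTimeWhenErrorOccurred above is held as an int64 so the error-retry wait can be read and
// overridden atomically (tests shrink it through SetWaitTimeWhenErrorOccurred). A dependency-free
// sketch of that pattern with illustrative names; the 1s default mirrors the value above.
package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

var waitOnError = int64(1 * time.Second)

func getWait() time.Duration  { return time.Duration(atomic.LoadInt64(&waitOnError)) }
func setWait(d time.Duration) { atomic.StoreInt64(&waitOnError, int64(d)) }

func main() {
	setWait(10 * time.Millisecond)
	fmt.Println(getWait()) // 10ms
}
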
-func NewJobContext() *JobContext { - return &JobContext{ - ddlJobCtx: context.Background(), - cacheSQL: "", - cacheNormalizedSQL: "", - cacheDigest: nil, - tp: "", - } -} - -func newWorker(ctx context.Context, tp workerType, sessPool *sess.Pool, delRangeMgr delRangeManager, dCtx *ddlCtx) *worker { - worker := &worker{ - id: ddlWorkerID.Add(1), - tp: tp, - ddlJobCh: make(chan struct{}, 1), - ctx: ctx, - ddlCtx: dCtx, - sessPool: sessPool, - delRangeManager: delRangeMgr, - } - worker.addingDDLJobKey = addingDDLJobPrefix + worker.typeStr() - worker.logCtx = tidblogutil.WithFields(context.Background(), zap.String("worker", worker.String()), zap.String("category", "ddl")) - return worker -} - -func (w *worker) jobLogger(job *model.Job) *zap.Logger { - logger := tidblogutil.Logger(w.logCtx) - if job != nil { - logger = tidblogutil.LoggerWithTraceInfo( - logger.With(zap.Int64("jobID", job.ID)), - job.TraceInfo, - ) - } - return logger -} - -func (w *worker) typeStr() string { - var str string - switch w.tp { - case generalWorker: - str = "general" - case addIdxWorker: - str = "add index" - default: - str = "unknown" - } - return str -} - -func (w *worker) String() string { - return fmt.Sprintf("worker %d, tp %s", w.id, w.typeStr()) -} - -func (w *worker) Close() { - startTime := time.Now() - if w.sess != nil { - w.sessPool.Put(w.sess.Session()) - } - w.wg.Wait() - tidblogutil.Logger(w.logCtx).Info("DDL worker closed", zap.Duration("take time", time.Since(startTime))) -} - -func asyncNotify(ch chan struct{}) { - select { - case ch <- struct{}{}: - default: - } -} - -func injectFailPointForGetJob(job *model.Job) { - if job == nil { - return - } - failpoint.Inject("mockModifyJobSchemaId", func(val failpoint.Value) { - job.SchemaID = int64(val.(int)) - }) - failpoint.Inject("MockModifyJobTableId", func(val failpoint.Value) { - job.TableID = int64(val.(int)) - }) -} - -// handleUpdateJobError handles the too large DDL job. -func (w *worker) handleUpdateJobError(t *meta.Meta, job *model.Job, err error) error { - if err == nil { - return nil - } - if kv.ErrEntryTooLarge.Equal(err) { - w.jobLogger(job).Warn("update DDL job failed", zap.String("job", job.String()), zap.Error(err)) - w.sess.Rollback() - err1 := w.sess.Begin(w.ctx) - if err1 != nil { - return errors.Trace(err1) - } - // Reduce this txn entry size. - job.BinlogInfo.Clean() - job.Error = toTError(err) - job.ErrorCount++ - job.SchemaState = model.StateNone - job.State = model.JobStateCancelled - err = w.finishDDLJob(t, job) - } - return errors.Trace(err) -} - -// updateDDLJob updates the DDL job information. -func (w *worker) updateDDLJob(job *model.Job, updateRawArgs bool) error { - failpoint.Inject("mockErrEntrySizeTooLarge", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(kv.ErrEntryTooLarge) - } - }) - - if !updateRawArgs { - w.jobLogger(job).Info("meet something wrong before update DDL job, shouldn't update raw args", - zap.String("job", job.String())) - } - return errors.Trace(updateDDLJob2Table(w.sess, job, updateRawArgs)) -} - -// registerMDLInfo registers metadata lock info. 
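
// asyncNotify above is a non-blocking send on a 1-buffered channel: the sender never blocks
// and a pending notification is never duplicated. A tiny standalone illustration of that
// pattern; the channel and call sites here are made up.
package main

import "fmt"

func asyncNotify(ch chan struct{}) {
	select {
	case ch <- struct{}{}:
	default:
	}
}

func main() {
	ch := make(chan struct{}, 1)
	asyncNotify(ch)
	asyncNotify(ch)      // no-op: a notification is already pending
	fmt.Println(len(ch)) // 1
}
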
-func (w *worker) registerMDLInfo(job *model.Job, ver int64) error { - if !variable.EnableMDL.Load() { - return nil - } - if ver == 0 { - return nil - } - rows, err := w.sess.Execute(w.ctx, fmt.Sprintf("select table_ids from mysql.tidb_ddl_job where job_id = %d", job.ID), "register-mdl-info") - if err != nil { - return err - } - if len(rows) == 0 { - return errors.Errorf("can't find ddl job %d", job.ID) - } - ownerID := w.ownerManager.ID() - ids := rows[0].GetString(0) - var sql string - if tidbutil.IsSysDB(strings.ToLower(job.SchemaName)) { - // DDLs that modify system tables could only happen in upgrade process, - // we should not reference 'owner_id'. Otherwise, there is a circular blocking problem. - sql = fmt.Sprintf("replace into mysql.tidb_mdl_info (job_id, version, table_ids) values (%d, %d, '%s')", job.ID, ver, ids) - } else { - sql = fmt.Sprintf("replace into mysql.tidb_mdl_info (job_id, version, table_ids, owner_id) values (%d, %d, '%s', '%s')", job.ID, ver, ids, ownerID) - } - _, err = w.sess.Execute(w.ctx, sql, "register-mdl-info") - return err -} - -// JobNeedGC is called to determine whether delete-ranges need to be generated for the provided job. -// -// NOTICE: BR also uses jobNeedGC to determine whether delete-ranges need to be generated for the provided job. -// Therefore, please make sure any modification is compatible with BR. -func JobNeedGC(job *model.Job) bool { - if !job.IsCancelled() { - if job.Warning != nil && dbterror.ErrCantDropFieldOrKey.Equal(job.Warning) { - // For the field/key not exists warnings, there is no need to - // delete the ranges. - return false - } - switch job.Type { - case model.ActionDropSchema, model.ActionDropTable, - model.ActionTruncateTable, model.ActionDropIndex, - model.ActionDropPrimaryKey, - model.ActionDropTablePartition, model.ActionTruncateTablePartition, - model.ActionDropColumn, model.ActionModifyColumn, - model.ActionAddIndex, model.ActionAddPrimaryKey, - model.ActionReorganizePartition, model.ActionRemovePartitioning, - model.ActionAlterTablePartitioning: - return true - case model.ActionMultiSchemaChange: - for i, sub := range job.MultiSchemaInfo.SubJobs { - proxyJob := sub.ToProxyJob(job, i) - needGC := JobNeedGC(&proxyJob) - if needGC { - return true - } - } - return false - } - } - return false -} - -// finishDDLJob deletes the finished DDL job in the ddl queue and puts it to history queue. -// If the DDL job need to handle in background, it will prepare a background job. 
-func (w *worker) finishDDLJob(t *meta.Meta, job *model.Job) (err error) { - startTime := time.Now() - defer func() { - metrics.DDLWorkerHistogram.WithLabelValues(metrics.WorkerFinishDDLJob, job.Type.String(), metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) - }() - - if JobNeedGC(job) { - err = w.delRangeManager.addDelRangeJob(w.ctx, job) - if err != nil { - return errors.Trace(err) - } - } - - switch job.Type { - case model.ActionRecoverTable: - err = finishRecoverTable(w, job) - case model.ActionFlashbackCluster: - err = finishFlashbackCluster(w, job) - case model.ActionRecoverSchema: - err = finishRecoverSchema(w, job) - case model.ActionCreateTables: - if job.IsCancelled() { - // it may be too large that it can not be added to the history queue, so - // delete its arguments - job.Args = nil - } - } - if err != nil { - return errors.Trace(err) - } - err = w.deleteDDLJob(job) - if err != nil { - return errors.Trace(err) - } - - job.BinlogInfo.FinishedTS = t.StartTS - w.jobLogger(job).Info("finish DDL job", zap.String("job", job.String())) - updateRawArgs := true - if job.Type == model.ActionAddPrimaryKey && !job.IsCancelled() { - // ActionAddPrimaryKey needs to check the warnings information in job.Args. - // Notice: warnings is used to support non-strict mode. - updateRawArgs = false - } - job.SeqNum = w.seqAllocator.Add(1) - w.removeJobCtx(job) - failpoint.InjectCall("afterFinishDDLJob", job) - err = AddHistoryDDLJob(w.ctx, w.sess, t, job, updateRawArgs) - return errors.Trace(err) -} - -func (w *worker) deleteDDLJob(job *model.Job) error { - sql := fmt.Sprintf("delete from mysql.tidb_ddl_job where job_id = %d", job.ID) - _, err := w.sess.Execute(context.Background(), sql, "delete_job") - return errors.Trace(err) -} - -func finishRecoverTable(w *worker, job *model.Job) error { - var ( - recoverInfo *RecoverInfo - recoverTableCheckFlag int64 - ) - err := job.DecodeArgs(&recoverInfo, &recoverTableCheckFlag) - if err != nil { - return errors.Trace(err) - } - if recoverTableCheckFlag == recoverCheckFlagEnableGC { - err = enableGC(w) - if err != nil { - return errors.Trace(err) - } - } - return nil -} - -func finishRecoverSchema(w *worker, job *model.Job) error { - var ( - recoverSchemaInfo *RecoverSchemaInfo - recoverSchemaCheckFlag int64 - ) - err := job.DecodeArgs(&recoverSchemaInfo, &recoverSchemaCheckFlag) - if err != nil { - return errors.Trace(err) - } - if recoverSchemaCheckFlag == recoverCheckFlagEnableGC { - err = enableGC(w) - if err != nil { - return errors.Trace(err) - } - } - return nil -} - -func (w *JobContext) setDDLLabelForTopSQL(jobQuery string) { - if !topsqlstate.TopSQLEnabled() || jobQuery == "" { - return - } - - if jobQuery != w.cacheSQL || w.cacheDigest == nil { - w.cacheNormalizedSQL, w.cacheDigest = parser.NormalizeDigest(jobQuery) - w.cacheSQL = jobQuery - w.ddlJobCtx = topsql.AttachAndRegisterSQLInfo(context.Background(), w.cacheNormalizedSQL, w.cacheDigest, false) - } else { - topsql.AttachAndRegisterSQLInfo(w.ddlJobCtx, w.cacheNormalizedSQL, w.cacheDigest, false) - } -} - -// DDLBackfillers contains the DDL need backfill step. 
-var DDLBackfillers = map[model.ActionType]string{ - model.ActionAddIndex: "add_index", - model.ActionModifyColumn: "modify_column", - model.ActionDropIndex: "drop_index", - model.ActionReorganizePartition: "reorganize_partition", -} - -func getDDLRequestSource(jobType model.ActionType) string { - if tp, ok := DDLBackfillers[jobType]; ok { - return kv.InternalTxnBackfillDDLPrefix + tp - } - return kv.InternalTxnDDL -} - -func (w *JobContext) setDDLLabelForDiagnosis(jobType model.ActionType) { - if w.tp != "" { - return - } - w.tp = getDDLRequestSource(jobType) - w.ddlJobCtx = kv.WithInternalSourceAndTaskType(w.ddlJobCtx, w.ddlJobSourceType(), kvutil.ExplicitTypeDDL) -} - -func (w *worker) handleJobDone(d *ddlCtx, job *model.Job, t *meta.Meta) error { - if err := w.checkBeforeCommit(); err != nil { - return err - } - err := w.finishDDLJob(t, job) - if err != nil { - w.sess.Rollback() - return err - } - - err = w.sess.Commit(w.ctx) - if err != nil { - return err - } - cleanupDDLReorgHandles(job, w.sess) - d.notifyJobDone(job.ID) - return nil -} - -func (w *worker) prepareTxn(job *model.Job) (kv.Transaction, error) { - err := w.sess.Begin(w.ctx) - if err != nil { - return nil, err - } - failpoint.Inject("mockRunJobTime", func(val failpoint.Value) { - if val.(bool) { - time.Sleep(time.Duration(rand.Intn(500)) * time.Millisecond) // #nosec G404 - } - }) - txn, err := w.sess.Txn() - if err != nil { - w.sess.Rollback() - return txn, err - } - // Only general DDLs are allowed to be executed when TiKV is disk full. - if w.tp == addIdxWorker && job.IsRunning() { - txn.SetDiskFullOpt(kvrpcpb.DiskFullOpt_NotAllowedOnFull) - } - w.setDDLLabelForTopSQL(job.ID, job.Query) - w.setDDLSourceForDiagnosis(job.ID, job.Type) - jobContext := w.jobContext(job.ID, job.ReorgMeta) - if tagger := w.getResourceGroupTaggerForTopSQL(job.ID); tagger != nil { - txn.SetOption(kv.ResourceGroupTagger, tagger) - } - txn.SetOption(kv.ResourceGroupName, jobContext.resourceGroupName) - // set request source type to DDL type - txn.SetOption(kv.RequestSourceType, jobContext.ddlJobSourceType()) - return txn, err -} - -// transitOneJobStep runs one step of the DDL job and persist the new job -// information. -// -// The first return value is the schema version after running the job. If it's -// non-zero, caller should wait for other nodes to catch up. -func (w *worker) transitOneJobStep(d *ddlCtx, job *model.Job) (int64, error) { - var ( - err error - ) - - txn, err := w.prepareTxn(job) - if err != nil { - return 0, err - } - t := meta.NewMeta(txn) - - if job.IsDone() || job.IsRollbackDone() || job.IsCancelled() { - if job.IsDone() { - job.State = model.JobStateSynced - } - // Inject the failpoint to prevent the progress of index creation. - failpoint.Inject("create-index-stuck-before-ddlhistory", func(v failpoint.Value) { - if sigFile, ok := v.(string); ok && job.Type == model.ActionAddIndex { - for { - time.Sleep(1 * time.Second) - if _, err := os.Stat(sigFile); err != nil { - if os.IsNotExist(err) { - continue - } - failpoint.Return(0, errors.Trace(err)) - } - break - } - } - }) - return 0, w.handleJobDone(d, job, t) - } - failpoint.InjectCall("onJobRunBefore", job) - - // If running job meets error, we will save this error in job Error and retry - // later if the job is not cancelled. 
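
// The "create-index-stuck-before-ddlhistory" marker above expects a string value (a signal-file
// path) and loops until that file exists. A sketch of driving it from a test; the test name and
// the path are illustrative, and as with the earlier sketch this needs a failpoint-enabled build.
package ddl_test

import (
	"testing"

	"github.com/pingcap/failpoint"
	"github.com/stretchr/testify/require"
)

func TestStuckBeforeDDLHistorySketch(t *testing.T) {
	fp := "github.com/pingcap/tidb/pkg/ddl/create-index-stuck-before-ddlhistory"
	// return("...") delivers the quoted string as the failpoint.Value, so the worker
	// waits until a file at that path appears before writing the DDL history record.
	require.NoError(t, failpoint.Enable(fp, `return("/tmp/continue-ddl-history")`))
	defer func() {
		require.NoError(t, failpoint.Disable(fp))
	}()
	// ... start an ADD INDEX, then create /tmp/continue-ddl-history to let it finish ...
}
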
- schemaVer, updateRawArgs, runJobErr := w.runOneJobStep(d, t, job) - - failpoint.InjectCall("onJobRunAfter", job) - - if job.IsCancelled() { - defer d.unlockSchemaVersion(job.ID) - w.sess.Reset() - return 0, w.handleJobDone(d, job, t) - } - - if err = w.checkBeforeCommit(); err != nil { - d.unlockSchemaVersion(job.ID) - return 0, err - } - - if runJobErr != nil && !job.IsRollingback() && !job.IsRollbackDone() { - // If the running job meets an error - // and the job state is rolling back, it means that we have already handled this error. - // Some DDL jobs (such as adding indexes) may need to update the table info and the schema version, - // then shouldn't discard the KV modification. - // And the job state is rollback done, it means the job was already finished, also shouldn't discard too. - // Otherwise, we should discard the KV modification when running job. - w.sess.Reset() - // If error happens after updateSchemaVersion(), then the schemaVer is updated. - // Result in the retry duration is up to 2 * lease. - schemaVer = 0 - } - - err = w.registerMDLInfo(job, schemaVer) - if err != nil { - w.sess.Rollback() - d.unlockSchemaVersion(job.ID) - return 0, err - } - err = w.updateDDLJob(job, updateRawArgs) - if err = w.handleUpdateJobError(t, job, err); err != nil { - w.sess.Rollback() - d.unlockSchemaVersion(job.ID) - return 0, err - } - writeBinlog(d.binlogCli, txn, job) - // reset the SQL digest to make topsql work right. - w.sess.GetSessionVars().StmtCtx.ResetSQLDigest(job.Query) - err = w.sess.Commit(w.ctx) - d.unlockSchemaVersion(job.ID) - if err != nil { - return 0, err - } - w.registerSync(job) - - // If error is non-retryable, we can ignore the sleep. - if runJobErr != nil && errorIsRetryable(runJobErr, job) { - w.jobLogger(job).Info("run DDL job failed, sleeps a while then retries it.", - zap.Duration("waitTime", GetWaitTimeWhenErrorOccurred()), zap.Error(runJobErr)) - // wait a while to retry again. If we don't wait here, DDL will retry this job immediately, - // which may act like a deadlock. - select { - case <-time.After(GetWaitTimeWhenErrorOccurred()): - case <-w.ctx.Done(): - } - } - - return schemaVer, nil -} - -func (w *worker) checkBeforeCommit() error { - if !w.ddlCtx.isOwner() { - // Since this TiDB instance is not a DDL owner anymore, - // it should not commit any transaction. - w.sess.Rollback() - return dbterror.ErrNotOwner - } - - if err := w.ctx.Err(); err != nil { - // The worker context is canceled, it should not commit any transaction. - return err - } - return nil -} - -func (w *JobContext) getResourceGroupTaggerForTopSQL() tikvrpc.ResourceGroupTagger { - if !topsqlstate.TopSQLEnabled() || w.cacheDigest == nil { - return nil - } - - digest := w.cacheDigest - tagger := func(req *tikvrpc.Request) { - req.ResourceGroupTag = resourcegrouptag.EncodeResourceGroupTag(digest, nil, - resourcegrouptag.GetResourceGroupLabelByKey(resourcegrouptag.GetFirstKeyFromRequest(req))) - } - return tagger -} - -func (w *JobContext) ddlJobSourceType() string { - return w.tp -} - -func skipWriteBinlog(job *model.Job) bool { - switch job.Type { - // ActionUpdateTiFlashReplicaStatus is a TiDB internal DDL, - // it's used to update table's TiFlash replica available status. - case model.ActionUpdateTiFlashReplicaStatus: - return true - // Don't sync 'alter table cache|nocache' to other tools. - // It's internal to the current cluster. 
- case model.ActionAlterCacheTable, model.ActionAlterNoCacheTable: - return true - } - - return false -} - -func writeBinlog(binlogCli *pumpcli.PumpsClient, txn kv.Transaction, job *model.Job) { - if job.IsDone() || job.IsRollbackDone() || - // When this column is in the "delete only" and "delete reorg" states, the binlog of "drop column" has not been written yet, - // but the column has been removed from the binlog of the write operation. - // So we add this binlog to enable downstream components to handle DML correctly in this schema state. - (job.Type == model.ActionDropColumn && job.SchemaState == model.StateDeleteOnly) { - if skipWriteBinlog(job) { - return - } - binloginfo.SetDDLBinlog(binlogCli, txn, job.ID, int32(job.SchemaState), job.Query) - } -} - -func chooseLeaseTime(t, max time.Duration) time.Duration { - if t == 0 || t > max { - return max - } - return t -} - -// countForPanic records the error count for DDL job. -func (w *worker) countForPanic(job *model.Job) { - // If run DDL job panic, just cancel the DDL jobs. - if job.State == model.JobStateRollingback { - job.State = model.JobStateCancelled - } else { - job.State = model.JobStateCancelling - } - job.ErrorCount++ - - logger := w.jobLogger(job) - // Load global DDL variables. - if err1 := loadDDLVars(w); err1 != nil { - logger.Error("load DDL global variable failed", zap.Error(err1)) - } - errorCount := variable.GetDDLErrorCountLimit() - - if job.ErrorCount > errorCount { - msg := fmt.Sprintf("panic in handling DDL logic and error count beyond the limitation %d, cancelled", errorCount) - logger.Warn(msg) - job.Error = toTError(errors.New(msg)) - job.State = model.JobStateCancelled - } -} - -// countForError records the error count for DDL job. -func (w *worker) countForError(err error, job *model.Job) error { - job.Error = toTError(err) - job.ErrorCount++ - - logger := w.jobLogger(job) - // If job is cancelled, we shouldn't return an error and shouldn't load DDL variables. - if job.State == model.JobStateCancelled { - logger.Info("DDL job is cancelled normally", zap.Error(err)) - return nil - } - logger.Warn("run DDL job error", zap.Error(err)) - - // Load global DDL variables. - if err1 := loadDDLVars(w); err1 != nil { - logger.Error("load DDL global variable failed", zap.Error(err1)) - } - // Check error limit to avoid falling into an infinite loop. - if job.ErrorCount > variable.GetDDLErrorCountLimit() && job.State == model.JobStateRunning && job.IsRollbackable() { - logger.Warn("DDL job error count exceed the limit, cancelling it now", zap.Int64("errorCountLimit", variable.GetDDLErrorCountLimit())) - job.State = model.JobStateCancelling - } - return err -} - -func (w *worker) processJobPausingRequest(d *ddlCtx, job *model.Job) (isRunnable bool, err error) { - if job.IsPaused() { - w.jobLogger(job).Debug("paused DDL job ", zap.String("job", job.String())) - return false, err - } - if job.IsPausing() { - w.jobLogger(job).Debug("pausing DDL job ", zap.String("job", job.String())) - job.State = model.JobStatePaused - return false, pauseReorgWorkers(w, d, job) - } - return true, nil -} - -// runOneJobStep runs a DDL job *step*. It returns the current schema version in -// this transaction, if the given job.Args has changed, and the error. The *step* -// is defined as the following reasons: -// -// - TiDB uses "Asynchronous Schema Change in F1", one job may have multiple -// *steps* each for a schema state change such as 'delete only' -> 'write only'. 
-// Combined with caller transitOneJobStepAndWaitSync waiting for other nodes to -// catch up with the returned schema version, we can make sure the cluster will -// only have two adjacent schema state for a DDL object. -// -// - Some types of DDL jobs has defined its own *step*s other than F1 paper. -// These *step*s may not be schema state change, and their purposes are various. -// For example, onLockTables updates the lock state of one table every *step*. -// -// - To provide linearizability we have added extra job state change *step*. For -// example, if job becomes JobStateDone in runOneJobStep, we cannot return to -// user that the job is finished because other nodes in cluster may not be -// synchronized. So JobStateSynced *step* is added to make sure there is -// waitSchemaChanged to wait for all nodes to catch up JobStateDone. -func (w *worker) runOneJobStep( - d *ddlCtx, - t *meta.Meta, - job *model.Job, -) (ver int64, updateRawArgs bool, err error) { - defer tidbutil.Recover(metrics.LabelDDLWorker, fmt.Sprintf("%s runOneJobStep", w), - func() { - w.countForPanic(job) - }, false) - - // Mock for run ddl job panic. - failpoint.Inject("mockPanicInRunDDLJob", func(failpoint.Value) {}) - - if job.Type != model.ActionMultiSchemaChange { - w.jobLogger(job).Info("run DDL job", zap.String("category", "ddl"), zap.String("job", job.String())) - } - timeStart := time.Now() - if job.RealStartTS == 0 { - job.RealStartTS = t.StartTS - } - defer func() { - metrics.DDLWorkerHistogram.WithLabelValues(metrics.WorkerRunDDLJob, job.Type.String(), metrics.RetLabel(err)).Observe(time.Since(timeStart).Seconds()) - }() - - if job.IsCancelling() { - w.jobLogger(job).Debug("cancel DDL job", zap.String("job", job.String())) - ver, err = convertJob2RollbackJob(w, d, t, job) - // if job is converted to rollback job, the job.Args may be changed for the - // rollback logic, so we let caller persist the new arguments. - updateRawArgs = job.IsRollingback() - return - } - - isRunnable, err := w.processJobPausingRequest(d, job) - if !isRunnable { - return ver, false, err - } - - // It would be better to do the positive check, but no idea to list all valid states here now. - if !job.IsRollingback() { - job.State = model.JobStateRunning - } - - prevState := job.State - - // For every type, `schema/table` modification and `job` modification are conducted - // in the one kv transaction. The `schema/table` modification can be always discarded - // by kv reset when meets an unhandled error, but the `job` modification can't. - // So make sure job state and args change is after all other checks or make sure these - // change has no effect when retrying it. 
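
// The doc comment above describes one schema-state change per *step*, with at most two
// adjacent states visible in the cluster at a time. For ADD INDEX the usual ladder is
// none -> delete only -> write only -> write reorganization -> public; a tiny sketch of that
// progression (the string slice is illustrative, the real states live in model.SchemaState).
package main

import "fmt"

func main() {
	states := []string{"none", "delete only", "write only", "write reorganization", "public"}
	for i := 0; i+1 < len(states); i++ {
		fmt.Printf("step %d: %s -> %s\n", i+1, states[i], states[i+1])
	}
}
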
- switch job.Type { - case model.ActionCreateSchema: - ver, err = onCreateSchema(d, t, job) - case model.ActionModifySchemaCharsetAndCollate: - ver, err = onModifySchemaCharsetAndCollate(d, t, job) - case model.ActionDropSchema: - ver, err = onDropSchema(d, t, job) - case model.ActionRecoverSchema: - ver, err = w.onRecoverSchema(d, t, job) - case model.ActionModifySchemaDefaultPlacement: - ver, err = onModifySchemaDefaultPlacement(d, t, job) - case model.ActionCreateTable: - ver, err = onCreateTable(d, t, job) - case model.ActionCreateTables: - ver, err = onCreateTables(d, t, job) - case model.ActionRepairTable: - ver, err = onRepairTable(d, t, job) - case model.ActionCreateView: - ver, err = onCreateView(d, t, job) - case model.ActionDropTable, model.ActionDropView, model.ActionDropSequence: - ver, err = onDropTableOrView(d, t, job) - case model.ActionDropTablePartition: - ver, err = w.onDropTablePartition(d, t, job) - case model.ActionTruncateTablePartition: - ver, err = w.onTruncateTablePartition(d, t, job) - case model.ActionExchangeTablePartition: - ver, err = w.onExchangeTablePartition(d, t, job) - case model.ActionAddColumn: - ver, err = onAddColumn(d, t, job) - case model.ActionDropColumn: - ver, err = onDropColumn(d, t, job) - case model.ActionModifyColumn: - ver, err = w.onModifyColumn(d, t, job) - case model.ActionSetDefaultValue: - ver, err = onSetDefaultValue(d, t, job) - case model.ActionAddIndex: - ver, err = w.onCreateIndex(d, t, job, false) - case model.ActionAddPrimaryKey: - ver, err = w.onCreateIndex(d, t, job, true) - case model.ActionDropIndex, model.ActionDropPrimaryKey: - ver, err = onDropIndex(d, t, job) - case model.ActionRenameIndex: - ver, err = onRenameIndex(d, t, job) - case model.ActionAddForeignKey: - ver, err = w.onCreateForeignKey(d, t, job) - case model.ActionDropForeignKey: - ver, err = onDropForeignKey(d, t, job) - case model.ActionTruncateTable: - ver, err = w.onTruncateTable(d, t, job) - case model.ActionRebaseAutoID: - ver, err = onRebaseAutoIncrementIDType(d, t, job) - case model.ActionRebaseAutoRandomBase: - ver, err = onRebaseAutoRandomType(d, t, job) - case model.ActionRenameTable: - ver, err = onRenameTable(d, t, job) - case model.ActionShardRowID: - ver, err = w.onShardRowID(d, t, job) - case model.ActionModifyTableComment: - ver, err = onModifyTableComment(d, t, job) - case model.ActionModifyTableAutoIdCache: - ver, err = onModifyTableAutoIDCache(d, t, job) - case model.ActionAddTablePartition: - ver, err = w.onAddTablePartition(d, t, job) - case model.ActionModifyTableCharsetAndCollate: - ver, err = onModifyTableCharsetAndCollate(d, t, job) - case model.ActionRecoverTable: - ver, err = w.onRecoverTable(d, t, job) - case model.ActionLockTable: - ver, err = onLockTables(d, t, job) - case model.ActionUnlockTable: - ver, err = onUnlockTables(d, t, job) - case model.ActionSetTiFlashReplica: - ver, err = w.onSetTableFlashReplica(d, t, job) - case model.ActionUpdateTiFlashReplicaStatus: - ver, err = onUpdateFlashReplicaStatus(d, t, job) - case model.ActionCreateSequence: - ver, err = onCreateSequence(d, t, job) - case model.ActionAlterIndexVisibility: - ver, err = onAlterIndexVisibility(d, t, job) - case model.ActionAlterSequence: - ver, err = onAlterSequence(d, t, job) - case model.ActionRenameTables: - ver, err = onRenameTables(d, t, job) - case model.ActionAlterTableAttributes: - ver, err = onAlterTableAttributes(d, t, job) - case model.ActionAlterTablePartitionAttributes: - ver, err = onAlterTablePartitionAttributes(d, t, job) - case 
model.ActionCreatePlacementPolicy: - ver, err = onCreatePlacementPolicy(d, t, job) - case model.ActionDropPlacementPolicy: - ver, err = onDropPlacementPolicy(d, t, job) - case model.ActionAlterPlacementPolicy: - ver, err = onAlterPlacementPolicy(d, t, job) - case model.ActionAlterTablePartitionPlacement: - ver, err = onAlterTablePartitionPlacement(d, t, job) - case model.ActionAlterTablePlacement: - ver, err = onAlterTablePlacement(d, t, job) - case model.ActionCreateResourceGroup: - ver, err = onCreateResourceGroup(w.ctx, d, t, job) - case model.ActionAlterResourceGroup: - ver, err = onAlterResourceGroup(d, t, job) - case model.ActionDropResourceGroup: - ver, err = onDropResourceGroup(d, t, job) - case model.ActionAlterCacheTable: - ver, err = onAlterCacheTable(d, t, job) - case model.ActionAlterNoCacheTable: - ver, err = onAlterNoCacheTable(d, t, job) - case model.ActionFlashbackCluster: - ver, err = w.onFlashbackCluster(d, t, job) - case model.ActionMultiSchemaChange: - ver, err = onMultiSchemaChange(w, d, t, job) - case model.ActionReorganizePartition, model.ActionRemovePartitioning, - model.ActionAlterTablePartitioning: - ver, err = w.onReorganizePartition(d, t, job) - case model.ActionAlterTTLInfo: - ver, err = onTTLInfoChange(d, t, job) - case model.ActionAlterTTLRemove: - ver, err = onTTLInfoRemove(d, t, job) - case model.ActionAddCheckConstraint: - ver, err = w.onAddCheckConstraint(d, t, job) - case model.ActionDropCheckConstraint: - ver, err = onDropCheckConstraint(d, t, job) - case model.ActionAlterCheckConstraint: - ver, err = w.onAlterCheckConstraint(d, t, job) - default: - // Invalid job, cancel it. - job.State = model.JobStateCancelled - err = dbterror.ErrInvalidDDLJob.GenWithStack("invalid ddl job type: %v", job.Type) - } - - // there are too many job types, instead let every job type output its own - // updateRawArgs, we try to use these rules as a generalization: - // - // if job has no error, some arguments may be changed, there's no harm to update - // it. - updateRawArgs = err == nil - // if job changed from running to rolling back, arguments may be changed - if prevState == model.JobStateRunning && job.IsRollingback() { - updateRawArgs = true - } - - // Save errors in job if any, so that others can know errors happened. - if err != nil { - err = w.countForError(err, job) - } - return ver, updateRawArgs, err -} - -func loadDDLVars(w *worker) error { - // Get sessionctx from context resource pool. - var ctx sessionctx.Context - ctx, err := w.sessPool.Get() - if err != nil { - return errors.Trace(err) - } - defer w.sessPool.Put(ctx) - return util.LoadDDLVars(ctx) -} - -func toTError(err error) *terror.Error { - originErr := errors.Cause(err) - tErr, ok := originErr.(*terror.Error) - if ok { - return tErr - } - - // TODO: Add the error code. - return dbterror.ClassDDL.Synthesize(terror.CodeUnknown, err.Error()) -} - -// waitSchemaChanged waits for the completion of updating all servers' schema or MDL synced. In order to make sure that happens, -// we wait at most 2 * lease time(sessionTTL, 90 seconds). 
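
// waitSchemaChanged above bounds the wait at roughly 2 * lease. A generic sketch of
// "poll a sync condition under a hard deadline"; the lease value, ticker interval and the
// waitSynced helper are illustrative, not the schema-syncer implementation.
package main

import (
	"context"
	"fmt"
	"time"
)

func waitSynced(ctx context.Context, synced func() bool) error {
	ticker := time.NewTicker(20 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
			if synced() {
				return nil
			}
		}
	}
}

func main() {
	lease := 100 * time.Millisecond
	ctx, cancel := context.WithTimeout(context.Background(), 2*lease)
	defer cancel()
	start := time.Now()
	err := waitSynced(ctx, func() bool { return time.Since(start) > 50*time.Millisecond })
	fmt.Println(err) // <nil>
}
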
-func waitSchemaChanged(ctx context.Context, d *ddlCtx, latestSchemaVersion int64, job *model.Job) error { - if !job.IsRunning() && !job.IsRollingback() && !job.IsDone() && !job.IsRollbackDone() { - return nil - } - - timeStart := time.Now() - var err error - defer func() { - metrics.DDLWorkerHistogram.WithLabelValues(metrics.WorkerWaitSchemaChanged, job.Type.String(), metrics.RetLabel(err)).Observe(time.Since(timeStart).Seconds()) - }() - - if latestSchemaVersion == 0 { - logutil.DDLLogger().Info("schema version doesn't change", zap.Int64("jobID", job.ID)) - return nil - } - - err = d.schemaSyncer.OwnerUpdateGlobalVersion(ctx, latestSchemaVersion) - if err != nil { - logutil.DDLLogger().Info("update latest schema version failed", zap.Int64("ver", latestSchemaVersion), zap.Error(err)) - if variable.EnableMDL.Load() { - return err - } - if terror.ErrorEqual(err, context.DeadlineExceeded) { - // If err is context.DeadlineExceeded, it means waitTime(2 * lease) is elapsed. So all the schemas are synced by ticker. - // There is no need to use etcd to sync. The function returns directly. - return nil - } - } - - return checkAllVersions(ctx, d, job, latestSchemaVersion, timeStart) -} - -// waitSchemaSyncedForMDL likes waitSchemaSynced, but it waits for getting the metadata lock of the latest version of this DDL. -func waitSchemaSyncedForMDL(ctx context.Context, d *ddlCtx, job *model.Job, latestSchemaVersion int64) error { - timeStart := time.Now() - return checkAllVersions(ctx, d, job, latestSchemaVersion, timeStart) -} - -func buildPlacementAffects(oldIDs []int64, newIDs []int64) []*model.AffectedOption { - if len(oldIDs) == 0 { - return nil - } - - affects := make([]*model.AffectedOption, len(oldIDs)) - for i := 0; i < len(oldIDs); i++ { - affects[i] = &model.AffectedOption{ - OldTableID: oldIDs[i], - TableID: newIDs[i], - } - } - return affects -} diff --git a/pkg/ddl/mock.go b/pkg/ddl/mock.go index 330aa62ce6980..7be8f499fa01e 100644 --- a/pkg/ddl/mock.go +++ b/pkg/ddl/mock.go @@ -72,11 +72,11 @@ func (*MockSchemaSyncer) WatchGlobalSchemaVer(context.Context) {} // UpdateSelfVersion implements SchemaSyncer.UpdateSelfVersion interface. 
func (s *MockSchemaSyncer) UpdateSelfVersion(_ context.Context, jobID int64, version int64) error { - if val, _err_ := failpoint.Eval(_curpkg_("mockUpdateMDLToETCDError")); _err_ == nil { + failpoint.Inject("mockUpdateMDLToETCDError", func(val failpoint.Value) { if val.(bool) { - return errors.New("mock update mdl to etcd error") + failpoint.Return(errors.New("mock update mdl to etcd error")) } - } + }) if variable.EnableMDL.Load() { s.mdlSchemaVersions.Store(jobID, version) } else { @@ -115,20 +115,20 @@ func (s *MockSchemaSyncer) OwnerCheckAllVersions(ctx context.Context, jobID int6 ticker := time.NewTicker(mockCheckVersInterval) defer ticker.Stop() - if val, _err_ := failpoint.Eval(_curpkg_("mockOwnerCheckAllVersionSlow")); _err_ == nil { + failpoint.Inject("mockOwnerCheckAllVersionSlow", func(val failpoint.Value) { if v, ok := val.(int); ok && v == int(jobID) { time.Sleep(2 * time.Second) } - } + }) for { select { case <-ctx.Done(): - if v, _err_ := failpoint.Eval(_curpkg_("checkOwnerCheckAllVersionsWaitTime")); _err_ == nil { + failpoint.Inject("checkOwnerCheckAllVersionsWaitTime", func(v failpoint.Value) { if v.(bool) { panic("shouldn't happen") } - } + }) return errors.Trace(ctx.Err()) case <-ticker.C: if variable.EnableMDL.Load() { @@ -181,12 +181,12 @@ func (s *MockStateSyncer) Init(context.Context) error { // UpdateGlobalState implements StateSyncer.UpdateGlobalState interface. func (s *MockStateSyncer) UpdateGlobalState(_ context.Context, stateInfo *syncer.StateInfo) error { - if val, _err_ := failpoint.Eval(_curpkg_("mockUpgradingState")); _err_ == nil { + failpoint.Inject("mockUpgradingState", func(val failpoint.Value) { if val.(bool) { clusterState.Store(stateInfo) - return nil + failpoint.Return(nil) } - } + }) s.globalVerCh <- clientv3.WatchResponse{} clusterState.Store(stateInfo) return nil diff --git a/pkg/ddl/mock.go__failpoint_stash__ b/pkg/ddl/mock.go__failpoint_stash__ deleted file mode 100644 index 7be8f499fa01e..0000000000000 --- a/pkg/ddl/mock.go__failpoint_stash__ +++ /dev/null @@ -1,260 +0,0 @@ -// Copyright 2017 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ddl - -import ( - "context" - "sync" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/ddl/syncer" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/charset" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - clientv3 "go.etcd.io/etcd/client/v3" - atomicutil "go.uber.org/atomic" -) - -// SetBatchInsertDeleteRangeSize sets the batch insert/delete range size in the test -func SetBatchInsertDeleteRangeSize(i int) { - batchInsertDeleteRangeSize = i -} - -var _ syncer.SchemaSyncer = &MockSchemaSyncer{} - -const mockCheckVersInterval = 2 * time.Millisecond - -// MockSchemaSyncer is a mock schema syncer, it is exported for testing. 
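// A minimal, hypothetical test sketch of driving one of the failpoints restored to their
// `failpoint.Inject` form above, assuming the standard github.com/pingcap/failpoint API:
// Enable/Disable take the declaring package path plus the marker name, and the term
// `return(true)` feeds the boolean checked inside the callback. The test name and the use
// of testify are illustrative only.
package ddl_test

import (
	"testing"

	"github.com/pingcap/failpoint"
	"github.com/stretchr/testify/require"
)

func TestMockUpdateMDLToETCDErrorSketch(t *testing.T) {
	const fp = "github.com/pingcap/tidb/pkg/ddl/mockUpdateMDLToETCDError"
	require.NoError(t, failpoint.Enable(fp, `return(true)`))
	defer func() { require.NoError(t, failpoint.Disable(fp)) }()
	// While enabled, MockSchemaSyncer.UpdateSelfVersion returns the mocked etcd error
	// instead of recording the MDL schema version for the job.
}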
-type MockSchemaSyncer struct { - selfSchemaVersion int64 - mdlSchemaVersions sync.Map - globalVerCh chan clientv3.WatchResponse - mockSession chan struct{} -} - -// NewMockSchemaSyncer creates a new mock SchemaSyncer. -func NewMockSchemaSyncer() syncer.SchemaSyncer { - return &MockSchemaSyncer{} -} - -// Init implements SchemaSyncer.Init interface. -func (s *MockSchemaSyncer) Init(_ context.Context) error { - s.mdlSchemaVersions = sync.Map{} - s.globalVerCh = make(chan clientv3.WatchResponse, 1) - s.mockSession = make(chan struct{}, 1) - return nil -} - -// GlobalVersionCh implements SchemaSyncer.GlobalVersionCh interface. -func (s *MockSchemaSyncer) GlobalVersionCh() clientv3.WatchChan { - return s.globalVerCh -} - -// WatchGlobalSchemaVer implements SchemaSyncer.WatchGlobalSchemaVer interface. -func (*MockSchemaSyncer) WatchGlobalSchemaVer(context.Context) {} - -// UpdateSelfVersion implements SchemaSyncer.UpdateSelfVersion interface. -func (s *MockSchemaSyncer) UpdateSelfVersion(_ context.Context, jobID int64, version int64) error { - failpoint.Inject("mockUpdateMDLToETCDError", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(errors.New("mock update mdl to etcd error")) - } - }) - if variable.EnableMDL.Load() { - s.mdlSchemaVersions.Store(jobID, version) - } else { - atomic.StoreInt64(&s.selfSchemaVersion, version) - } - return nil -} - -// Done implements SchemaSyncer.Done interface. -func (s *MockSchemaSyncer) Done() <-chan struct{} { - return s.mockSession -} - -// CloseSession mockSession, it is exported for testing. -func (s *MockSchemaSyncer) CloseSession() { - close(s.mockSession) -} - -// Restart implements SchemaSyncer.Restart interface. -func (s *MockSchemaSyncer) Restart(_ context.Context) error { - s.mockSession = make(chan struct{}, 1) - return nil -} - -// OwnerUpdateGlobalVersion implements SchemaSyncer.OwnerUpdateGlobalVersion interface. -func (s *MockSchemaSyncer) OwnerUpdateGlobalVersion(_ context.Context, _ int64) error { - select { - case s.globalVerCh <- clientv3.WatchResponse{}: - default: - } - return nil -} - -// OwnerCheckAllVersions implements SchemaSyncer.OwnerCheckAllVersions interface. -func (s *MockSchemaSyncer) OwnerCheckAllVersions(ctx context.Context, jobID int64, latestVer int64) error { - ticker := time.NewTicker(mockCheckVersInterval) - defer ticker.Stop() - - failpoint.Inject("mockOwnerCheckAllVersionSlow", func(val failpoint.Value) { - if v, ok := val.(int); ok && v == int(jobID) { - time.Sleep(2 * time.Second) - } - }) - - for { - select { - case <-ctx.Done(): - failpoint.Inject("checkOwnerCheckAllVersionsWaitTime", func(v failpoint.Value) { - if v.(bool) { - panic("shouldn't happen") - } - }) - return errors.Trace(ctx.Err()) - case <-ticker.C: - if variable.EnableMDL.Load() { - ver, ok := s.mdlSchemaVersions.Load(jobID) - if ok && ver.(int64) >= latestVer { - return nil - } - } else { - ver := atomic.LoadInt64(&s.selfSchemaVersion) - if ver >= latestVer { - return nil - } - } - } - } -} - -// SyncJobSchemaVerLoop implements SchemaSyncer.SyncJobSchemaVerLoop interface. -func (*MockSchemaSyncer) SyncJobSchemaVerLoop(context.Context) { -} - -// Close implements SchemaSyncer.Close interface. -func (*MockSchemaSyncer) Close() {} - -// NewMockStateSyncer creates a new mock StateSyncer. -func NewMockStateSyncer() syncer.StateSyncer { - return &MockStateSyncer{} -} - -// clusterState mocks cluster state. -// We move it from MockStateSyncer to here. Because we want to make it unaffected by ddl close. 
-var clusterState *atomicutil.Pointer[syncer.StateInfo] - -// MockStateSyncer is a mock state syncer, it is exported for testing. -type MockStateSyncer struct { - globalVerCh chan clientv3.WatchResponse - mockSession chan struct{} -} - -// Init implements StateSyncer.Init interface. -func (s *MockStateSyncer) Init(context.Context) error { - s.globalVerCh = make(chan clientv3.WatchResponse, 1) - s.mockSession = make(chan struct{}, 1) - state := syncer.NewStateInfo(syncer.StateNormalRunning) - if clusterState == nil { - clusterState = atomicutil.NewPointer(state) - } - return nil -} - -// UpdateGlobalState implements StateSyncer.UpdateGlobalState interface. -func (s *MockStateSyncer) UpdateGlobalState(_ context.Context, stateInfo *syncer.StateInfo) error { - failpoint.Inject("mockUpgradingState", func(val failpoint.Value) { - if val.(bool) { - clusterState.Store(stateInfo) - failpoint.Return(nil) - } - }) - s.globalVerCh <- clientv3.WatchResponse{} - clusterState.Store(stateInfo) - return nil -} - -// GetGlobalState implements StateSyncer.GetGlobalState interface. -func (*MockStateSyncer) GetGlobalState(context.Context) (*syncer.StateInfo, error) { - return clusterState.Load(), nil -} - -// IsUpgradingState implements StateSyncer.IsUpgradingState interface. -func (*MockStateSyncer) IsUpgradingState() bool { - return clusterState.Load().State == syncer.StateUpgrading -} - -// WatchChan implements StateSyncer.WatchChan interface. -func (s *MockStateSyncer) WatchChan() clientv3.WatchChan { - return s.globalVerCh -} - -// Rewatch implements StateSyncer.Rewatch interface. -func (*MockStateSyncer) Rewatch(context.Context) {} - -type mockDelRange struct { -} - -// newMockDelRangeManager creates a mock delRangeManager only used for test. -func newMockDelRangeManager() delRangeManager { - return &mockDelRange{} -} - -// addDelRangeJob implements delRangeManager interface. -func (*mockDelRange) addDelRangeJob(_ context.Context, _ *model.Job) error { - return nil -} - -// removeFromGCDeleteRange implements delRangeManager interface. -func (*mockDelRange) removeFromGCDeleteRange(_ context.Context, _ int64) error { - return nil -} - -// start implements delRangeManager interface. -func (*mockDelRange) start() {} - -// clear implements delRangeManager interface. -func (*mockDelRange) clear() {} - -// MockTableInfo mocks a table info by create table stmt ast and a specified table id. 
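// A minimal sketch of how the MockStateSyncer above can be exercised. The exported
// constructor, the syncer.StateSyncer methods, and syncer.StateUpgrading are taken from
// the code in this file; the surrounding main function is hypothetical scaffolding.
package main

import (
	"context"
	"fmt"

	"github.com/pingcap/tidb/pkg/ddl"
	"github.com/pingcap/tidb/pkg/ddl/syncer"
)

func main() {
	st := ddl.NewMockStateSyncer()
	if err := st.Init(context.Background()); err != nil {
		panic(err)
	}
	// Flip the shared cluster state to "upgrading" and observe it through the syncer.
	if err := st.UpdateGlobalState(context.Background(), syncer.NewStateInfo(syncer.StateUpgrading)); err != nil {
		panic(err)
	}
	fmt.Println(st.IsUpgradingState()) // true
}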
-func MockTableInfo(ctx sessionctx.Context, stmt *ast.CreateTableStmt, tableID int64) (*model.TableInfo, error) { - chs, coll := charset.GetDefaultCharsetAndCollate() - cols, newConstraints, err := buildColumnsAndConstraints(ctx, stmt.Cols, stmt.Constraints, chs, coll) - if err != nil { - return nil, errors.Trace(err) - } - tbl, err := BuildTableInfo(ctx, stmt.Table.Name, cols, newConstraints, "", "") - if err != nil { - return nil, errors.Trace(err) - } - tbl.ID = tableID - - if err = setTableAutoRandomBits(ctx, tbl, stmt.Cols); err != nil { - return nil, errors.Trace(err) - } - - // The specified charset will be handled in handleTableOptions - if err = handleTableOptions(stmt.Options, tbl); err != nil { - return nil, errors.Trace(err) - } - - return tbl, nil -} diff --git a/pkg/ddl/modify_column.go b/pkg/ddl/modify_column.go index 1c95073988983..09a654d133a26 100644 --- a/pkg/ddl/modify_column.go +++ b/pkg/ddl/modify_column.go @@ -84,14 +84,14 @@ func (w *worker) onModifyColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver in } } - if val, _err_ := failpoint.Eval(_curpkg_("uninitializedOffsetAndState")); _err_ == nil { + failpoint.Inject("uninitializedOffsetAndState", func(val failpoint.Value) { //nolint:forcetypeassert if val.(bool) { if modifyInfo.newCol.State != model.StatePublic { - return ver, errors.New("the column state is wrong") + failpoint.Return(ver, errors.New("the column state is wrong")) } } - } + }) err = checkAndApplyAutoRandomBits(d, t, dbInfo, tblInfo, oldCol, modifyInfo.newCol, modifyInfo.updatedAutoRandomBits) if err != nil { @@ -442,12 +442,12 @@ func (w *worker) doModifyColumnTypeWithData( } // none -> delete only updateChangingObjState(changingCol, changingIdxs, model.StateDeleteOnly) - if val, _err_ := failpoint.Eval(_curpkg_("mockInsertValueAfterCheckNull")); _err_ == nil { + failpoint.Inject("mockInsertValueAfterCheckNull", func(val failpoint.Value) { if valStr, ok := val.(string); ok { var sctx sessionctx.Context sctx, err := w.sessPool.Get() if err != nil { - return ver, err + failpoint.Return(ver, err) } defer w.sessPool.Put(sctx) @@ -456,10 +456,10 @@ func (w *worker) doModifyColumnTypeWithData( _, _, err = sctx.GetRestrictedSQLExecutor().ExecRestrictedSQL(ctx, nil, valStr) if err != nil { job.State = model.JobStateCancelled - return ver, err + failpoint.Return(ver, err) } } - } + }) ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, originalState != changingCol.State) if err != nil { return ver, errors.Trace(err) @@ -488,7 +488,7 @@ func (w *worker) doModifyColumnTypeWithData( return ver, errors.Trace(err) } job.SchemaState = model.StateWriteOnly - failpoint.Call(_curpkg_("afterModifyColumnStateDeleteOnly"), job.ID) + failpoint.InjectCall("afterModifyColumnStateDeleteOnly", job.ID) case model.StateWriteOnly: // write only -> reorganization updateChangingObjState(changingCol, changingIdxs, model.StateWriteReorganization) @@ -587,7 +587,7 @@ func doReorgWorkForModifyColumn(w *worker, d *ddlCtx, t *meta.Meta, job *model.J // With a failpoint-enabled version of TiDB, you can trigger this failpoint by the following command: // enable: curl -X PUT -d "pause" "http://127.0.0.1:10080/fail/github.com/pingcap/tidb/pkg/ddl/mockDelayInModifyColumnTypeWithData". 
// disable: curl -X DELETE "http://127.0.0.1:10080/fail/github.com/pingcap/tidb/pkg/ddl/mockDelayInModifyColumnTypeWithData" - failpoint.Eval(_curpkg_("mockDelayInModifyColumnTypeWithData")) + failpoint.Inject("mockDelayInModifyColumnTypeWithData", func() {}) err = w.runReorgJob(reorgInfo, tbl.Meta(), d.lease, func() (addIndexErr error) { defer util.Recover(metrics.LabelDDL, "onModifyColumn", func() { diff --git a/pkg/ddl/modify_column.go__failpoint_stash__ b/pkg/ddl/modify_column.go__failpoint_stash__ deleted file mode 100644 index 09a654d133a26..0000000000000 --- a/pkg/ddl/modify_column.go__failpoint_stash__ +++ /dev/null @@ -1,1318 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ddl - -import ( - "bytes" - "context" - "fmt" - "strings" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/ddl/logutil" - sess "github.com/pingcap/tidb/pkg/ddl/session" - "github.com/pingcap/tidb/pkg/errctx" - "github.com/pingcap/tidb/pkg/expression" - exprctx "github.com/pingcap/tidb/pkg/expression/context" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/meta/autoid" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/charset" - "github.com/pingcap/tidb/pkg/parser/format" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - statsutil "github.com/pingcap/tidb/pkg/statistics/handle/util" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/table/tables" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/dbterror" - "go.uber.org/zap" -) - -type modifyingColInfo struct { - newCol *model.ColumnInfo - oldColName *model.CIStr - modifyColumnTp byte - updatedAutoRandomBits uint64 - changingCol *model.ColumnInfo - changingIdxs []*model.IndexInfo - pos *ast.ColumnPosition - removedIdxs []int64 -} - -func (w *worker) onModifyColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - dbInfo, tblInfo, oldCol, modifyInfo, err := getModifyColumnInfo(t, job) - if err != nil { - return ver, err - } - - if job.IsRollingback() { - // For those column-type-change jobs which don't reorg the data. - if !needChangeColumnData(oldCol, modifyInfo.newCol) { - return rollbackModifyColumnJob(d, t, tblInfo, job, modifyInfo.newCol, oldCol, modifyInfo.modifyColumnTp) - } - // For those column-type-change jobs which reorg the data. - return rollbackModifyColumnJobWithData(d, t, tblInfo, job, oldCol, modifyInfo) - } - - // If we want to rename the column name, we need to check whether it already exists. 
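// A minimal sketch, in Go instead of curl, of toggling the mockDelayInModifyColumnTypeWithData
// failpoint through the TiDB status port, using the same URL and "pause" payload as the
// enable/disable commands quoted in the comment above. The address 127.0.0.1:10080 is the
// default status port that comment assumes.
package main

import (
	"fmt"
	"net/http"
	"strings"
)

func main() {
	const url = "http://127.0.0.1:10080/fail/github.com/pingcap/tidb/pkg/ddl/mockDelayInModifyColumnTypeWithData"

	// Enable: pause the DDL worker when it reaches the failpoint.
	req, err := http.NewRequest(http.MethodPut, url, strings.NewReader("pause"))
	if err != nil {
		panic(err)
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	resp.Body.Close()
	fmt.Println("enable:", resp.Status)

	// Disable: remove the failpoint so the paused worker can continue.
	req, err = http.NewRequest(http.MethodDelete, url, nil)
	if err != nil {
		panic(err)
	}
	resp, err = http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	resp.Body.Close()
	fmt.Println("disable:", resp.Status)
}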
- if modifyInfo.newCol.Name.L != modifyInfo.oldColName.L { - c := model.FindColumnInfo(tblInfo.Columns, modifyInfo.newCol.Name.L) - if c != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(infoschema.ErrColumnExists.GenWithStackByArgs(modifyInfo.newCol.Name)) - } - } - - failpoint.Inject("uninitializedOffsetAndState", func(val failpoint.Value) { - //nolint:forcetypeassert - if val.(bool) { - if modifyInfo.newCol.State != model.StatePublic { - failpoint.Return(ver, errors.New("the column state is wrong")) - } - } - }) - - err = checkAndApplyAutoRandomBits(d, t, dbInfo, tblInfo, oldCol, modifyInfo.newCol, modifyInfo.updatedAutoRandomBits) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - if !needChangeColumnData(oldCol, modifyInfo.newCol) { - return w.doModifyColumn(d, t, job, dbInfo, tblInfo, modifyInfo.newCol, oldCol, modifyInfo.pos) - } - - if err = isGeneratedRelatedColumn(tblInfo, modifyInfo.newCol, oldCol); err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - if tblInfo.Partition != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs("table is partition table")) - } - - changingCol := modifyInfo.changingCol - if changingCol == nil { - newColName := model.NewCIStr(genChangingColumnUniqueName(tblInfo, oldCol)) - if mysql.HasPriKeyFlag(oldCol.GetFlag()) { - job.State = model.JobStateCancelled - msg := "this column has primary key flag" - return ver, dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs(msg) - } - - changingCol = modifyInfo.newCol.Clone() - changingCol.Name = newColName - changingCol.ChangeStateInfo = &model.ChangeStateInfo{DependencyColumnOffset: oldCol.Offset} - - originDefVal, err := GetOriginDefaultValueForModifyColumn(newReorgExprCtx(), changingCol, oldCol) - if err != nil { - return ver, errors.Trace(err) - } - if err = changingCol.SetOriginDefaultValue(originDefVal); err != nil { - return ver, errors.Trace(err) - } - - InitAndAddColumnToTable(tblInfo, changingCol) - indexesToChange := FindRelatedIndexesToChange(tblInfo, oldCol.Name) - for _, info := range indexesToChange { - newIdxID := AllocateIndexID(tblInfo) - if !info.isTemp { - // We create a temp index for each normal index. - tmpIdx := info.IndexInfo.Clone() - tmpIdxName := genChangingIndexUniqueName(tblInfo, info.IndexInfo) - setIdxIDName(tmpIdx, newIdxID, model.NewCIStr(tmpIdxName)) - SetIdxColNameOffset(tmpIdx.Columns[info.Offset], changingCol) - tblInfo.Indices = append(tblInfo.Indices, tmpIdx) - } else { - // The index is a temp index created by previous modify column job(s). - // We can overwrite it to reduce reorg cost, because it will be dropped eventually. 
- tmpIdx := info.IndexInfo - oldTempIdxID := tmpIdx.ID - setIdxIDName(tmpIdx, newIdxID, tmpIdx.Name /* unchanged */) - SetIdxColNameOffset(tmpIdx.Columns[info.Offset], changingCol) - modifyInfo.removedIdxs = append(modifyInfo.removedIdxs, oldTempIdxID) - } - } - } else { - changingCol = model.FindColumnInfoByID(tblInfo.Columns, modifyInfo.changingCol.ID) - if changingCol == nil { - logutil.DDLLogger().Error("the changing column has been removed", zap.Error(err)) - job.State = model.JobStateCancelled - return ver, errors.Trace(infoschema.ErrColumnNotExists.GenWithStackByArgs(oldCol.Name, tblInfo.Name)) - } - } - - return w.doModifyColumnTypeWithData(d, t, job, dbInfo, tblInfo, changingCol, oldCol, modifyInfo.newCol.Name, modifyInfo.pos, modifyInfo.removedIdxs) -} - -// rollbackModifyColumnJob rollbacks the job when an error occurs. -func rollbackModifyColumnJob(d *ddlCtx, t *meta.Meta, tblInfo *model.TableInfo, job *model.Job, newCol, oldCol *model.ColumnInfo, modifyColumnTp byte) (ver int64, _ error) { - var err error - if oldCol.ID == newCol.ID && modifyColumnTp == mysql.TypeNull { - // field NotNullFlag flag reset. - tblInfo.Columns[oldCol.Offset].SetFlag(oldCol.GetFlag() &^ mysql.NotNullFlag) - // field PreventNullInsertFlag flag reset. - tblInfo.Columns[oldCol.Offset].SetFlag(oldCol.GetFlag() &^ mysql.PreventNullInsertFlag) - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - } - job.FinishTableJob(model.JobStateRollbackDone, model.StateNone, ver, tblInfo) - // For those column-type-change type which doesn't need reorg data, we should also mock the job args for delete range. - job.Args = []any{[]int64{}, []int64{}} - return ver, nil -} - -func getModifyColumnInfo(t *meta.Meta, job *model.Job) (*model.DBInfo, *model.TableInfo, *model.ColumnInfo, *modifyingColInfo, error) { - modifyInfo := &modifyingColInfo{pos: &ast.ColumnPosition{}} - err := job.DecodeArgs(&modifyInfo.newCol, &modifyInfo.oldColName, modifyInfo.pos, &modifyInfo.modifyColumnTp, - &modifyInfo.updatedAutoRandomBits, &modifyInfo.changingCol, &modifyInfo.changingIdxs, &modifyInfo.removedIdxs) - if err != nil { - job.State = model.JobStateCancelled - return nil, nil, nil, modifyInfo, errors.Trace(err) - } - - dbInfo, err := checkSchemaExistAndCancelNotExistJob(t, job) - if err != nil { - return nil, nil, nil, modifyInfo, errors.Trace(err) - } - - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) - if err != nil { - return nil, nil, nil, modifyInfo, errors.Trace(err) - } - - oldCol := model.FindColumnInfo(tblInfo.Columns, modifyInfo.oldColName.L) - if oldCol == nil || oldCol.State != model.StatePublic { - job.State = model.JobStateCancelled - return nil, nil, nil, modifyInfo, errors.Trace(infoschema.ErrColumnNotExists.GenWithStackByArgs(*(modifyInfo.oldColName), tblInfo.Name)) - } - - return dbInfo, tblInfo, oldCol, modifyInfo, errors.Trace(err) -} - -// GetOriginDefaultValueForModifyColumn gets the original default value for modifying column. -// Since column type change is implemented as adding a new column then substituting the old one. -// Case exists when update-where statement fetch a NULL for not-null column without any default data, -// it will errors. -// So we set original default value here to prevent this error. If the oldCol has the original default value, we use it. -// Otherwise we set the zero value as original default value. 
-// Besides, in insert & update records, we have already implement using the casted value of relative column to insert -// rather than the original default value. -func GetOriginDefaultValueForModifyColumn(ctx exprctx.BuildContext, changingCol, oldCol *model.ColumnInfo) (any, error) { - var err error - originDefVal := oldCol.GetOriginDefaultValue() - if originDefVal != nil { - odv, err := table.CastColumnValue(ctx, types.NewDatum(originDefVal), changingCol, false, false) - if err != nil { - logutil.DDLLogger().Info("cast origin default value failed", zap.Error(err)) - } - if !odv.IsNull() { - if originDefVal, err = odv.ToString(); err != nil { - originDefVal = nil - logutil.DDLLogger().Info("convert default value to string failed", zap.Error(err)) - } - } - } - if originDefVal == nil { - originDefVal, err = generateOriginDefaultValue(changingCol, nil) - if err != nil { - return nil, errors.Trace(err) - } - } - return originDefVal, nil -} - -// rollbackModifyColumnJobWithData is used to rollback modify-column job which need to reorg the data. -func rollbackModifyColumnJobWithData(d *ddlCtx, t *meta.Meta, tblInfo *model.TableInfo, job *model.Job, oldCol *model.ColumnInfo, modifyInfo *modifyingColInfo) (ver int64, err error) { - // If the not-null change is included, we should clean the flag info in oldCol. - if modifyInfo.modifyColumnTp == mysql.TypeNull { - // Reset NotNullFlag flag. - tblInfo.Columns[oldCol.Offset].SetFlag(oldCol.GetFlag() &^ mysql.NotNullFlag) - // Reset PreventNullInsertFlag flag. - tblInfo.Columns[oldCol.Offset].SetFlag(oldCol.GetFlag() &^ mysql.PreventNullInsertFlag) - } - var changingIdxIDs []int64 - if modifyInfo.changingCol != nil { - changingIdxIDs = buildRelatedIndexIDs(tblInfo, modifyInfo.changingCol.ID) - // The job is in the middle state. The appended changingCol and changingIndex should - // be removed from the tableInfo as well. - removeChangingColAndIdxs(tblInfo, modifyInfo.changingCol.ID) - } - ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - job.FinishTableJob(model.JobStateRollbackDone, model.StateNone, ver, tblInfo) - // Reconstruct the job args to add the temporary index ids into delete range table. - job.Args = []any{changingIdxIDs, getPartitionIDs(tblInfo)} - return ver, nil -} - -// doModifyColumn updates the column information and reorders all columns. It does not support modifying column data. -func (w *worker) doModifyColumn( - d *ddlCtx, t *meta.Meta, job *model.Job, dbInfo *model.DBInfo, tblInfo *model.TableInfo, - newCol, oldCol *model.ColumnInfo, pos *ast.ColumnPosition) (ver int64, _ error) { - if oldCol.ID != newCol.ID { - job.State = model.JobStateRollingback - return ver, dbterror.ErrColumnInChange.GenWithStackByArgs(oldCol.Name, newCol.ID) - } - // Column from null to not null. - if !mysql.HasNotNullFlag(oldCol.GetFlag()) && mysql.HasNotNullFlag(newCol.GetFlag()) { - noPreventNullFlag := !mysql.HasPreventNullInsertFlag(oldCol.GetFlag()) - - // lease = 0 means it's in an integration test. In this case we don't delay so the test won't run too slowly. - // We need to check after the flag is set - if d.lease > 0 && !noPreventNullFlag { - delayForAsyncCommit() - } - - // Introduce the `mysql.PreventNullInsertFlag` flag to prevent users from inserting or updating null values. 
- err := modifyColsFromNull2NotNull(w, dbInfo, tblInfo, []*model.ColumnInfo{oldCol}, newCol, oldCol.GetType() != newCol.GetType()) - if err != nil { - if dbterror.ErrWarnDataTruncated.Equal(err) || dbterror.ErrInvalidUseOfNull.Equal(err) { - job.State = model.JobStateRollingback - } - return ver, err - } - // The column should get into prevent null status first. - if noPreventNullFlag { - return updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) - } - } - - if job.MultiSchemaInfo != nil && job.MultiSchemaInfo.Revertible { - job.MarkNonRevertible() - // Store the mark and enter the next DDL handling loop. - return updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, false) - } - - if err := adjustTableInfoAfterModifyColumn(tblInfo, newCol, oldCol, pos); err != nil { - job.State = model.JobStateRollingback - return ver, errors.Trace(err) - } - - childTableInfos, err := adjustForeignKeyChildTableInfoAfterModifyColumn(d, t, job, tblInfo, newCol, oldCol) - if err != nil { - return ver, errors.Trace(err) - } - ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true, childTableInfos...) - if err != nil { - // Modified the type definition of 'null' to 'not null' before this, so rollBack the job when an error occurs. - job.State = model.JobStateRollingback - return ver, errors.Trace(err) - } - - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - // For those column-type-change type which doesn't need reorg data, we should also mock the job args for delete range. - job.Args = []any{[]int64{}, []int64{}} - return ver, nil -} - -func adjustTableInfoAfterModifyColumn( - tblInfo *model.TableInfo, newCol, oldCol *model.ColumnInfo, pos *ast.ColumnPosition) error { - // We need the latest column's offset and state. This information can be obtained from the store. - newCol.Offset = oldCol.Offset - newCol.State = oldCol.State - if pos != nil && pos.RelativeColumn != nil && oldCol.Name.L == pos.RelativeColumn.Name.L { - // For cases like `modify column b after b`, it should report this error. 
- return errors.Trace(infoschema.ErrColumnNotExists.GenWithStackByArgs(oldCol.Name, tblInfo.Name)) - } - destOffset, err := LocateOffsetToMove(oldCol.Offset, pos, tblInfo) - if err != nil { - return errors.Trace(infoschema.ErrColumnNotExists.GenWithStackByArgs(oldCol.Name, tblInfo.Name)) - } - tblInfo.Columns[oldCol.Offset] = newCol - tblInfo.MoveColumnInfo(oldCol.Offset, destOffset) - updateNewIdxColsNameOffset(tblInfo.Indices, oldCol.Name, newCol) - updateFKInfoWhenModifyColumn(tblInfo, oldCol.Name, newCol.Name) - updateTTLInfoWhenModifyColumn(tblInfo, oldCol.Name, newCol.Name) - return nil -} - -func updateFKInfoWhenModifyColumn(tblInfo *model.TableInfo, oldCol, newCol model.CIStr) { - if oldCol.L == newCol.L { - return - } - for _, fk := range tblInfo.ForeignKeys { - for i := range fk.Cols { - if fk.Cols[i].L == oldCol.L { - fk.Cols[i] = newCol - } - } - } -} - -func updateTTLInfoWhenModifyColumn(tblInfo *model.TableInfo, oldCol, newCol model.CIStr) { - if oldCol.L == newCol.L { - return - } - if tblInfo.TTLInfo != nil { - if tblInfo.TTLInfo.ColumnName.L == oldCol.L { - tblInfo.TTLInfo.ColumnName = newCol - } - } -} - -func adjustForeignKeyChildTableInfoAfterModifyColumn(d *ddlCtx, t *meta.Meta, job *model.Job, tblInfo *model.TableInfo, newCol, oldCol *model.ColumnInfo) ([]schemaIDAndTableInfo, error) { - if !variable.EnableForeignKey.Load() || newCol.Name.L == oldCol.Name.L { - return nil, nil - } - is, err := getAndCheckLatestInfoSchema(d, t) - if err != nil { - return nil, err - } - referredFKs := is.GetTableReferredForeignKeys(job.SchemaName, tblInfo.Name.L) - if len(referredFKs) == 0 { - return nil, nil - } - fkh := newForeignKeyHelper() - fkh.addLoadedTable(job.SchemaName, tblInfo.Name.L, job.SchemaID, tblInfo) - for _, referredFK := range referredFKs { - info, err := fkh.getTableFromStorage(is, t, referredFK.ChildSchema, referredFK.ChildTable) - if err != nil { - if infoschema.ErrTableNotExists.Equal(err) || infoschema.ErrDatabaseNotExists.Equal(err) { - continue - } - return nil, err - } - fkInfo := model.FindFKInfoByName(info.tblInfo.ForeignKeys, referredFK.ChildFKName.L) - if fkInfo == nil { - continue - } - for i := range fkInfo.RefCols { - if fkInfo.RefCols[i].L == oldCol.Name.L { - fkInfo.RefCols[i] = newCol.Name - } - } - } - infoList := make([]schemaIDAndTableInfo, 0, len(fkh.loaded)) - for _, info := range fkh.loaded { - if info.tblInfo.ID == tblInfo.ID { - continue - } - infoList = append(infoList, info) - } - return infoList, nil -} - -func (w *worker) doModifyColumnTypeWithData( - d *ddlCtx, t *meta.Meta, job *model.Job, - dbInfo *model.DBInfo, tblInfo *model.TableInfo, changingCol, oldCol *model.ColumnInfo, - colName model.CIStr, pos *ast.ColumnPosition, rmIdxIDs []int64) (ver int64, _ error) { - var err error - originalState := changingCol.State - targetCol := changingCol.Clone() - targetCol.Name = colName - changingIdxs := buildRelatedIndexInfos(tblInfo, changingCol.ID) - switch changingCol.State { - case model.StateNone: - // Column from null to not null. - if !mysql.HasNotNullFlag(oldCol.GetFlag()) && mysql.HasNotNullFlag(changingCol.GetFlag()) { - // Introduce the `mysql.PreventNullInsertFlag` flag to prevent users from inserting or updating null values. 
- err := modifyColsFromNull2NotNull(w, dbInfo, tblInfo, []*model.ColumnInfo{oldCol}, targetCol, oldCol.GetType() != changingCol.GetType()) - if err != nil { - if dbterror.ErrWarnDataTruncated.Equal(err) || dbterror.ErrInvalidUseOfNull.Equal(err) { - job.State = model.JobStateRollingback - } - return ver, err - } - } - // none -> delete only - updateChangingObjState(changingCol, changingIdxs, model.StateDeleteOnly) - failpoint.Inject("mockInsertValueAfterCheckNull", func(val failpoint.Value) { - if valStr, ok := val.(string); ok { - var sctx sessionctx.Context - sctx, err := w.sessPool.Get() - if err != nil { - failpoint.Return(ver, err) - } - defer w.sessPool.Put(sctx) - - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) - //nolint:forcetypeassert - _, _, err = sctx.GetRestrictedSQLExecutor().ExecRestrictedSQL(ctx, nil, valStr) - if err != nil { - job.State = model.JobStateCancelled - failpoint.Return(ver, err) - } - } - }) - ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, originalState != changingCol.State) - if err != nil { - return ver, errors.Trace(err) - } - // Make sure job args change after `updateVersionAndTableInfoWithCheck`, otherwise, the job args will - // be updated in `updateDDLJob` even if it meets an error in `updateVersionAndTableInfoWithCheck`. - job.SchemaState = model.StateDeleteOnly - metrics.GetBackfillProgressByLabel(metrics.LblModifyColumn, job.SchemaName, tblInfo.Name.String()).Set(0) - job.Args = append(job.Args, changingCol, changingIdxs, rmIdxIDs) - case model.StateDeleteOnly: - // Column from null to not null. - if !mysql.HasNotNullFlag(oldCol.GetFlag()) && mysql.HasNotNullFlag(changingCol.GetFlag()) { - // Introduce the `mysql.PreventNullInsertFlag` flag to prevent users from inserting or updating null values. - err := modifyColsFromNull2NotNull(w, dbInfo, tblInfo, []*model.ColumnInfo{oldCol}, targetCol, oldCol.GetType() != changingCol.GetType()) - if err != nil { - if dbterror.ErrWarnDataTruncated.Equal(err) || dbterror.ErrInvalidUseOfNull.Equal(err) { - job.State = model.JobStateRollingback - } - return ver, err - } - } - // delete only -> write only - updateChangingObjState(changingCol, changingIdxs, model.StateWriteOnly) - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != changingCol.State) - if err != nil { - return ver, errors.Trace(err) - } - job.SchemaState = model.StateWriteOnly - failpoint.InjectCall("afterModifyColumnStateDeleteOnly", job.ID) - case model.StateWriteOnly: - // write only -> reorganization - updateChangingObjState(changingCol, changingIdxs, model.StateWriteReorganization) - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != changingCol.State) - if err != nil { - return ver, errors.Trace(err) - } - // Initialize SnapshotVer to 0 for later reorganization check. - job.SnapshotVer = 0 - job.SchemaState = model.StateWriteReorganization - case model.StateWriteReorganization: - tbl, err := getTable(d.getAutoIDRequirement(), dbInfo.ID, tblInfo) - if err != nil { - return ver, errors.Trace(err) - } - - var done bool - if job.MultiSchemaInfo != nil { - done, ver, err = doReorgWorkForModifyColumnMultiSchema(w, d, t, job, tbl, oldCol, changingCol, changingIdxs) - } else { - done, ver, err = doReorgWorkForModifyColumn(w, d, t, job, tbl, oldCol, changingCol, changingIdxs) - } - if !done { - return ver, err - } - - rmIdxIDs = append(buildRelatedIndexIDs(tblInfo, oldCol.ID), rmIdxIDs...) 
- - err = adjustTableInfoAfterModifyColumnWithData(tblInfo, pos, oldCol, changingCol, colName, changingIdxs) - if err != nil { - job.State = model.JobStateRollingback - return ver, errors.Trace(err) - } - - updateChangingObjState(changingCol, changingIdxs, model.StatePublic) - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != changingCol.State) - if err != nil { - return ver, errors.Trace(err) - } - - // Finish this job. - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - // Refactor the job args to add the old index ids into delete range table. - job.Args = []any{rmIdxIDs, getPartitionIDs(tblInfo)} - modifyColumnEvent := statsutil.NewModifyColumnEvent( - job.SchemaID, - tblInfo, - []*model.ColumnInfo{changingCol}, - ) - asyncNotifyEvent(d, modifyColumnEvent) - default: - err = dbterror.ErrInvalidDDLState.GenWithStackByArgs("column", changingCol.State) - } - - return ver, errors.Trace(err) -} - -func doReorgWorkForModifyColumnMultiSchema(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job, tbl table.Table, - oldCol, changingCol *model.ColumnInfo, changingIdxs []*model.IndexInfo) (done bool, ver int64, err error) { - if job.MultiSchemaInfo.Revertible { - done, ver, err = doReorgWorkForModifyColumn(w, d, t, job, tbl, oldCol, changingCol, changingIdxs) - if done { - // We need another round to wait for all the others sub-jobs to finish. - job.MarkNonRevertible() - } - // We need another round to run the reorg process. - return false, ver, err - } - // Non-revertible means all the sub jobs finished. - return true, ver, err -} - -func doReorgWorkForModifyColumn(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job, tbl table.Table, - oldCol, changingCol *model.ColumnInfo, changingIdxs []*model.IndexInfo) (done bool, ver int64, err error) { - job.ReorgMeta.ReorgTp = model.ReorgTypeTxn - sctx, err1 := w.sessPool.Get() - if err1 != nil { - err = errors.Trace(err1) - return - } - defer w.sessPool.Put(sctx) - rh := newReorgHandler(sess.NewSession(sctx)) - dbInfo, err := t.GetDatabase(job.SchemaID) - if err != nil { - return false, ver, errors.Trace(err) - } - reorgInfo, err := getReorgInfo(d.jobContext(job.ID, job.ReorgMeta), - d, rh, job, dbInfo, tbl, BuildElements(changingCol, changingIdxs), false) - if err != nil || reorgInfo == nil || reorgInfo.first { - // If we run reorg firstly, we should update the job snapshot version - // and then run the reorg next time. - return false, ver, errors.Trace(err) - } - - // Inject a failpoint so that we can pause here and do verification on other components. - // With a failpoint-enabled version of TiDB, you can trigger this failpoint by the following command: - // enable: curl -X PUT -d "pause" "http://127.0.0.1:10080/fail/github.com/pingcap/tidb/pkg/ddl/mockDelayInModifyColumnTypeWithData". - // disable: curl -X DELETE "http://127.0.0.1:10080/fail/github.com/pingcap/tidb/pkg/ddl/mockDelayInModifyColumnTypeWithData" - failpoint.Inject("mockDelayInModifyColumnTypeWithData", func() {}) - err = w.runReorgJob(reorgInfo, tbl.Meta(), d.lease, func() (addIndexErr error) { - defer util.Recover(metrics.LabelDDL, "onModifyColumn", - func() { - addIndexErr = dbterror.ErrCancelledDDLJob.GenWithStack("modify table `%v` column `%v` panic", tbl.Meta().Name, oldCol.Name) - }, false) - // Use old column name to generate less confusing error messages. 
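// A minimal sketch of the schema-state ladder that doModifyColumnTypeWithData above walks
// the hidden "changing" column and its temporary indexes through. The constants are the
// model.SchemaState values used in its switch; printing them here is purely illustrative.
package main

import (
	"fmt"

	"github.com/pingcap/tidb/pkg/parser/model"
)

func main() {
	ladder := []model.SchemaState{
		model.StateNone,                // changing column and temp indexes are appended to the table info
		model.StateDeleteOnly,          // none -> delete only
		model.StateWriteOnly,           // delete only -> write only
		model.StateWriteReorganization, // write only -> reorganization (backfill via runReorgJob)
		model.StatePublic,              // reorg done: the changing objects replace the old column and indexes
	}
	for _, state := range ladder {
		fmt.Println(state.String())
	}
}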
- changingColCpy := changingCol.Clone() - changingColCpy.Name = oldCol.Name - return w.updateCurrentElement(tbl, reorgInfo) - }) - if err != nil { - if dbterror.ErrPausedDDLJob.Equal(err) { - return false, ver, nil - } - - if dbterror.ErrWaitReorgTimeout.Equal(err) { - // If timeout, we should return, check for the owner and re-wait job done. - return false, ver, nil - } - if kv.IsTxnRetryableError(err) || dbterror.ErrNotOwner.Equal(err) { - return false, ver, errors.Trace(err) - } - if err1 := rh.RemoveDDLReorgHandle(job, reorgInfo.elements); err1 != nil { - logutil.DDLLogger().Warn("run modify column job failed, RemoveDDLReorgHandle failed, can't convert job to rollback", - zap.String("job", job.String()), zap.Error(err1)) - } - logutil.DDLLogger().Warn("run modify column job failed, convert job to rollback", zap.Stringer("job", job), zap.Error(err)) - job.State = model.JobStateRollingback - return false, ver, errors.Trace(err) - } - return true, ver, nil -} - -func adjustTableInfoAfterModifyColumnWithData(tblInfo *model.TableInfo, pos *ast.ColumnPosition, - oldCol, changingCol *model.ColumnInfo, newName model.CIStr, changingIdxs []*model.IndexInfo) (err error) { - if pos != nil && pos.RelativeColumn != nil && oldCol.Name.L == pos.RelativeColumn.Name.L { - // For cases like `modify column b after b`, it should report this error. - return errors.Trace(infoschema.ErrColumnNotExists.GenWithStackByArgs(oldCol.Name, tblInfo.Name)) - } - internalColName := changingCol.Name - changingCol = replaceOldColumn(tblInfo, oldCol, changingCol, newName) - if len(changingIdxs) > 0 { - updateNewIdxColsNameOffset(changingIdxs, internalColName, changingCol) - indexesToRemove := filterIndexesToRemove(changingIdxs, newName, tblInfo) - replaceOldIndexes(tblInfo, indexesToRemove) - } - if tblInfo.TTLInfo != nil { - updateTTLInfoWhenModifyColumn(tblInfo, oldCol.Name, changingCol.Name) - } - // Move the new column to a correct offset. - destOffset, err := LocateOffsetToMove(changingCol.Offset, pos, tblInfo) - if err != nil { - return errors.Trace(err) - } - tblInfo.MoveColumnInfo(changingCol.Offset, destOffset) - return nil -} - -func checkModifyColumnWithGeneratedColumnsConstraint(allCols []*table.Column, oldColName model.CIStr) error { - for _, col := range allCols { - if col.GeneratedExpr == nil { - continue - } - dependedColNames := FindColumnNamesInExpr(col.GeneratedExpr.Internal()) - for _, name := range dependedColNames { - if name.Name.L == oldColName.L { - if col.Hidden { - return dbterror.ErrDependentByFunctionalIndex.GenWithStackByArgs(oldColName.O) - } - return dbterror.ErrDependentByGeneratedColumn.GenWithStackByArgs(oldColName.O) - } - } - } - return nil -} - -// GetModifiableColumnJob returns a DDL job of model.ActionModifyColumn. -func GetModifiableColumnJob( - ctx context.Context, - sctx sessionctx.Context, - is infoschema.InfoSchema, // WARN: is maybe nil here. 
- ident ast.Ident, - originalColName model.CIStr, - schema *model.DBInfo, - t table.Table, - spec *ast.AlterTableSpec, -) (*model.Job, error) { - var err error - specNewColumn := spec.NewColumns[0] - - col := table.FindCol(t.Cols(), originalColName.L) - if col == nil { - return nil, infoschema.ErrColumnNotExists.GenWithStackByArgs(originalColName, ident.Name) - } - newColName := specNewColumn.Name.Name - if newColName.L == model.ExtraHandleName.L { - return nil, dbterror.ErrWrongColumnName.GenWithStackByArgs(newColName.L) - } - errG := checkModifyColumnWithGeneratedColumnsConstraint(t.Cols(), originalColName) - - // If we want to rename the column name, we need to check whether it already exists. - if newColName.L != originalColName.L { - c := table.FindCol(t.Cols(), newColName.L) - if c != nil { - return nil, infoschema.ErrColumnExists.GenWithStackByArgs(newColName) - } - - // And also check the generated columns dependency, if some generated columns - // depend on this column, we can't rename the column name. - if errG != nil { - return nil, errors.Trace(errG) - } - } - - // Constraints in the new column means adding new constraints. Errors should thrown, - // which will be done by `processColumnOptions` later. - if specNewColumn.Tp == nil { - // Make sure the column definition is simple field type. - return nil, errors.Trace(dbterror.ErrUnsupportedModifyColumn) - } - - if err = checkColumnAttributes(specNewColumn.Name.OrigColName(), specNewColumn.Tp); err != nil { - return nil, errors.Trace(err) - } - - newCol := table.ToColumn(&model.ColumnInfo{ - ID: col.ID, - // We use this PR(https://github.com/pingcap/tidb/pull/6274) as the dividing line to define whether it is a new version or an old version TiDB. - // The old version TiDB initializes the column's offset and state here. - // The new version TiDB doesn't initialize the column's offset and state, and it will do the initialization in run DDL function. - // When we do the rolling upgrade the following may happen: - // a new version TiDB builds the DDL job that doesn't be set the column's offset and state, - // and the old version TiDB is the DDL owner, it doesn't get offset and state from the store. Then it will encounter errors. - // So here we set offset and state to support the rolling upgrade. - Offset: col.Offset, - State: col.State, - OriginDefaultValue: col.OriginDefaultValue, - OriginDefaultValueBit: col.OriginDefaultValueBit, - FieldType: *specNewColumn.Tp, - Name: newColName, - Version: col.Version, - }) - - if err = ProcessColumnCharsetAndCollation(sctx, col, newCol, t.Meta(), specNewColumn, schema); err != nil { - return nil, err - } - - if err = checkModifyColumnWithForeignKeyConstraint(is, schema.Name.L, t.Meta(), col.ColumnInfo, newCol.ColumnInfo); err != nil { - return nil, errors.Trace(err) - } - - // Copy index related options to the new spec. - indexFlags := col.FieldType.GetFlag() & (mysql.PriKeyFlag | mysql.UniqueKeyFlag | mysql.MultipleKeyFlag) - newCol.FieldType.AddFlag(indexFlags) - if mysql.HasPriKeyFlag(col.FieldType.GetFlag()) { - newCol.FieldType.AddFlag(mysql.NotNullFlag) - // TODO: If user explicitly set NULL, we should throw error ErrPrimaryCantHaveNull. 
- } - - if err = ProcessModifyColumnOptions(sctx, newCol, specNewColumn.Options); err != nil { - return nil, errors.Trace(err) - } - - if err = checkModifyTypes(&col.FieldType, &newCol.FieldType, isColumnWithIndex(col.Name.L, t.Meta().Indices)); err != nil { - if strings.Contains(err.Error(), "Unsupported modifying collation") { - colErrMsg := "Unsupported modifying collation of column '%s' from '%s' to '%s' when index is defined on it." - err = dbterror.ErrUnsupportedModifyCollation.GenWithStack(colErrMsg, col.Name.L, col.GetCollate(), newCol.GetCollate()) - } - return nil, errors.Trace(err) - } - needChangeColData := needChangeColumnData(col.ColumnInfo, newCol.ColumnInfo) - if needChangeColData { - if err = isGeneratedRelatedColumn(t.Meta(), newCol.ColumnInfo, col.ColumnInfo); err != nil { - return nil, errors.Trace(err) - } - if t.Meta().Partition != nil { - return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs("table is partition table") - } - } - - // Check that the column change does not affect the partitioning column - // It must keep the same type, int [unsigned], [var]char, date[time] - if t.Meta().Partition != nil { - pt, ok := t.(table.PartitionedTable) - if !ok { - // Should never happen! - return nil, dbterror.ErrNotAllowedTypeInPartition.GenWithStackByArgs(newCol.Name.O) - } - isPartitioningColumn := false - for _, name := range pt.GetPartitionColumnNames() { - if strings.EqualFold(name.L, col.Name.L) { - isPartitioningColumn = true - break - } - } - if isPartitioningColumn { - // TODO: update the partitioning columns with new names if column is renamed - // Would be an extension from MySQL which does not support it. - if col.Name.L != newCol.Name.L { - return nil, dbterror.ErrDependentByPartitionFunctional.GenWithStackByArgs(col.Name.L) - } - if !isColTypeAllowedAsPartitioningCol(t.Meta().Partition.Type, newCol.FieldType) { - return nil, dbterror.ErrNotAllowedTypeInPartition.GenWithStackByArgs(newCol.Name.O) - } - pi := pt.Meta().GetPartitionInfo() - if len(pi.Columns) == 0 { - // non COLUMNS partitioning, only checks INTs, not their actual range - // There are many edge cases, like when truncating SQL Mode is allowed - // which will change the partitioning expression value resulting in a - // different partition. Better be safe and not allow decreasing of length. - // TODO: Should we allow it in strict mode? Wait for a use case / request. - if newCol.FieldType.GetFlen() < col.FieldType.GetFlen() { - return nil, dbterror.ErrUnsupportedModifyCollation.GenWithStack("Unsupported modify column, decreasing length of int may result in truncation and change of partition") - } - } - // Basically only allow changes of the length/decimals for the column - // Note that enum is not allowed, so elems are not checked - // TODO: support partition by ENUM - if newCol.FieldType.EvalType() != col.FieldType.EvalType() || - newCol.FieldType.GetFlag() != col.FieldType.GetFlag() || - newCol.FieldType.GetCollate() != col.FieldType.GetCollate() || - newCol.FieldType.GetCharset() != col.FieldType.GetCharset() { - return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs("can't change the partitioning column, since it would require reorganize all partitions") - } - // Generate a new PartitionInfo and validate it together with the new column definition - // Checks if all partition definition values are compatible. - // Similar to what buildRangePartitionDefinitions would do in terms of checks. 
- - tblInfo := pt.Meta() - newTblInfo := *tblInfo - // Replace col with newCol and see if we can generate a new SHOW CREATE TABLE - // and reparse it and build new partition definitions (which will do additional - // checks columns vs partition definition values - newCols := make([]*model.ColumnInfo, 0, len(newTblInfo.Columns)) - for _, c := range newTblInfo.Columns { - if c.ID == col.ID { - newCols = append(newCols, newCol.ColumnInfo) - continue - } - newCols = append(newCols, c) - } - newTblInfo.Columns = newCols - - var buf bytes.Buffer - AppendPartitionInfo(tblInfo.GetPartitionInfo(), &buf, mysql.ModeNone) - // The parser supports ALTER TABLE ... PARTITION BY ... even if the ddl code does not yet :) - // Ignoring warnings - stmt, _, err := parser.New().ParseSQL("ALTER TABLE t " + buf.String()) - if err != nil { - // Should never happen! - return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStack("cannot parse generated PartitionInfo") - } - at, ok := stmt[0].(*ast.AlterTableStmt) - if !ok || len(at.Specs) != 1 || at.Specs[0].Partition == nil { - return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStack("cannot parse generated PartitionInfo") - } - pAst := at.Specs[0].Partition - _, err = buildPartitionDefinitionsInfo( - exprctx.CtxWithHandleTruncateErrLevel(sctx.GetExprCtx(), errctx.LevelError), - pAst.Definitions, &newTblInfo, uint64(len(newTblInfo.Partition.Definitions)), - ) - if err != nil { - return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStack("New column does not match partition definitions: %s", err.Error()) - } - } - } - - // We don't support modifying column from not_auto_increment to auto_increment. - if !mysql.HasAutoIncrementFlag(col.GetFlag()) && mysql.HasAutoIncrementFlag(newCol.GetFlag()) { - return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs("can't set auto_increment") - } - // Not support auto id with default value. - if mysql.HasAutoIncrementFlag(newCol.GetFlag()) && newCol.GetDefaultValue() != nil { - return nil, dbterror.ErrInvalidDefaultValue.GenWithStackByArgs(newCol.Name) - } - // Disallow modifying column from auto_increment to not auto_increment if the session variable `AllowRemoveAutoInc` is false. - if !sctx.GetSessionVars().AllowRemoveAutoInc && mysql.HasAutoIncrementFlag(col.GetFlag()) && !mysql.HasAutoIncrementFlag(newCol.GetFlag()) { - return nil, dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs("can't remove auto_increment without @@tidb_allow_remove_auto_inc enabled") - } - - // We support modifying the type definitions of 'null' to 'not null' now. - var modifyColumnTp byte - if !mysql.HasNotNullFlag(col.GetFlag()) && mysql.HasNotNullFlag(newCol.GetFlag()) { - if err = checkForNullValue(ctx, sctx, true, ident.Schema, ident.Name, newCol.ColumnInfo, col.ColumnInfo); err != nil { - return nil, errors.Trace(err) - } - // `modifyColumnTp` indicates that there is a type modification. - modifyColumnTp = mysql.TypeNull - } - - if err = checkColumnWithIndexConstraint(t.Meta(), col.ColumnInfo, newCol.ColumnInfo); err != nil { - return nil, err - } - - // As same with MySQL, we don't support modifying the stored status for generated columns. - if err = checkModifyGeneratedColumn(sctx, schema.Name, t, col, newCol, specNewColumn, spec.Position); err != nil { - return nil, errors.Trace(err) - } - if errG != nil { - // According to issue https://github.com/pingcap/tidb/issues/24321, - // changing the type of a column involving generating a column is prohibited. 
- return nil, dbterror.ErrUnsupportedOnGeneratedColumn.GenWithStackByArgs(errG.Error()) - } - - if t.Meta().TTLInfo != nil { - // the column referenced by TTL should be a time type - if t.Meta().TTLInfo.ColumnName.L == originalColName.L && !types.IsTypeTime(newCol.ColumnInfo.FieldType.GetType()) { - return nil, errors.Trace(dbterror.ErrUnsupportedColumnInTTLConfig.GenWithStackByArgs(newCol.ColumnInfo.Name.O)) - } - } - - var newAutoRandBits uint64 - if newAutoRandBits, err = checkAutoRandom(t.Meta(), col, specNewColumn); err != nil { - return nil, errors.Trace(err) - } - - txn, err := sctx.Txn(true) - if err != nil { - return nil, errors.Trace(err) - } - bdrRole, err := meta.NewMeta(txn).GetBDRRole() - if err != nil { - return nil, errors.Trace(err) - } - if bdrRole == string(ast.BDRRolePrimary) && - deniedByBDRWhenModifyColumn(newCol.FieldType, col.FieldType, specNewColumn.Options) { - return nil, dbterror.ErrBDRRestrictedDDL.FastGenByArgs(bdrRole) - } - - job := &model.Job{ - SchemaID: schema.ID, - TableID: t.Meta().ID, - SchemaName: schema.Name.L, - TableName: t.Meta().Name.L, - Type: model.ActionModifyColumn, - BinlogInfo: &model.HistoryInfo{}, - ReorgMeta: NewDDLReorgMeta(sctx), - CtxVars: []any{needChangeColData}, - Args: []any{&newCol.ColumnInfo, originalColName, spec.Position, modifyColumnTp, newAutoRandBits}, - CDCWriteSource: sctx.GetSessionVars().CDCWriteSource, - SQLMode: sctx.GetSessionVars().SQLMode, - } - return job, nil -} - -func needChangeColumnData(oldCol, newCol *model.ColumnInfo) bool { - toUnsigned := mysql.HasUnsignedFlag(newCol.GetFlag()) - originUnsigned := mysql.HasUnsignedFlag(oldCol.GetFlag()) - needTruncationOrToggleSign := func() bool { - return (newCol.GetFlen() > 0 && (newCol.GetFlen() < oldCol.GetFlen() || newCol.GetDecimal() < oldCol.GetDecimal())) || - (toUnsigned != originUnsigned) - } - // Ignore the potential max display length represented by integer's flen, use default flen instead. - defaultOldColFlen, _ := mysql.GetDefaultFieldLengthAndDecimal(oldCol.GetType()) - defaultNewColFlen, _ := mysql.GetDefaultFieldLengthAndDecimal(newCol.GetType()) - needTruncationOrToggleSignForInteger := func() bool { - return (defaultNewColFlen > 0 && defaultNewColFlen < defaultOldColFlen) || (toUnsigned != originUnsigned) - } - - // Deal with the same type. - if oldCol.GetType() == newCol.GetType() { - switch oldCol.GetType() { - case mysql.TypeNewDecimal: - // Since type decimal will encode the precision, frac, negative(signed) and wordBuf into storage together, there is no short - // cut to eliminate data reorg change for column type change between decimal. - return oldCol.GetFlen() != newCol.GetFlen() || oldCol.GetDecimal() != newCol.GetDecimal() || toUnsigned != originUnsigned - case mysql.TypeEnum, mysql.TypeSet: - return IsElemsChangedToModifyColumn(oldCol.GetElems(), newCol.GetElems()) - case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: - return toUnsigned != originUnsigned - case mysql.TypeString: - // Due to the behavior of padding \x00 at binary type, always change column data when binary length changed - if types.IsBinaryStr(&oldCol.FieldType) { - return newCol.GetFlen() != oldCol.GetFlen() - } - } - - return needTruncationOrToggleSign() - } - - if ConvertBetweenCharAndVarchar(oldCol.GetType(), newCol.GetType()) { - return true - } - - // Deal with the different type. 
- switch oldCol.GetType() { - case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: - switch newCol.GetType() { - case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: - return needTruncationOrToggleSign() - } - case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: - switch newCol.GetType() { - case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: - return needTruncationOrToggleSignForInteger() - } - // conversion between float and double needs reorganization, see issue #31372 - } - - return true -} - -// ConvertBetweenCharAndVarchar check whether column converted between char and varchar -// TODO: it is used for plugins. so change plugin's using and remove it. -func ConvertBetweenCharAndVarchar(oldCol, newCol byte) bool { - return types.ConvertBetweenCharAndVarchar(oldCol, newCol) -} - -// IsElemsChangedToModifyColumn check elems changed -func IsElemsChangedToModifyColumn(oldElems, newElems []string) bool { - if len(newElems) < len(oldElems) { - return true - } - for index, oldElem := range oldElems { - newElem := newElems[index] - if oldElem != newElem { - return true - } - } - return false -} - -// ProcessColumnCharsetAndCollation process column charset and collation -func ProcessColumnCharsetAndCollation(sctx sessionctx.Context, col *table.Column, newCol *table.Column, meta *model.TableInfo, specNewColumn *ast.ColumnDef, schema *model.DBInfo) error { - var chs, coll string - var err error - // TODO: Remove it when all table versions are greater than or equal to TableInfoVersion1. - // If newCol's charset is empty and the table's version less than TableInfoVersion1, - // we will not modify the charset of the column. This behavior is not compatible with MySQL. - if len(newCol.FieldType.GetCharset()) == 0 && meta.Version < model.TableInfoVersion1 { - chs = col.FieldType.GetCharset() - coll = col.FieldType.GetCollate() - } else { - chs, coll, err = getCharsetAndCollateInColumnDef(sctx.GetSessionVars(), specNewColumn) - if err != nil { - return errors.Trace(err) - } - chs, coll, err = ResolveCharsetCollation(sctx.GetSessionVars(), - ast.CharsetOpt{Chs: chs, Col: coll}, - ast.CharsetOpt{Chs: meta.Charset, Col: meta.Collate}, - ast.CharsetOpt{Chs: schema.Charset, Col: schema.Collate}, - ) - chs, coll = OverwriteCollationWithBinaryFlag(sctx.GetSessionVars(), specNewColumn, chs, coll) - if err != nil { - return errors.Trace(err) - } - } - - if err = setCharsetCollationFlenDecimal(&newCol.FieldType, newCol.Name.O, chs, coll, sctx.GetSessionVars()); err != nil { - return errors.Trace(err) - } - decodeEnumSetBinaryLiteralToUTF8(&newCol.FieldType, chs) - return nil -} - -// checkColumnWithIndexConstraint is used to check the related index constraint of the modified column. -// Index has a max-prefix-length constraint. eg: a varchar(100), index idx(a), modifying column a to a varchar(4000) -// will cause index idx to break the max-prefix-length constraint. -func checkColumnWithIndexConstraint(tbInfo *model.TableInfo, originalCol, newCol *model.ColumnInfo) error { - columns := make([]*model.ColumnInfo, 0, len(tbInfo.Columns)) - columns = append(columns, tbInfo.Columns...) - // Replace old column with new column. 
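// A hypothetical in-package sketch of needChangeColumnData above: for two columns of the
// same VARCHAR type, only a change that could truncate data (or flip the UNSIGNED flag)
// forces a reorganization of existing rows. The helper below is illustrative and not part
// of the original file.
package ddl

import (
	"fmt"

	"github.com/pingcap/tidb/pkg/parser/model"
	"github.com/pingcap/tidb/pkg/parser/mysql"
	"github.com/pingcap/tidb/pkg/types"
)

// varcharColForSketch builds a throwaway column info with the given display length.
func varcharColForSketch(flen int) *model.ColumnInfo {
	ft := types.NewFieldType(mysql.TypeVarchar)
	ft.SetFlen(flen)
	return &model.ColumnInfo{FieldType: *ft}
}

func sketchNeedChangeColumnData() {
	// Widening VARCHAR(10) -> VARCHAR(20): no possible truncation, data stays in place.
	fmt.Println(needChangeColumnData(varcharColForSketch(10), varcharColForSketch(20))) // false
	// Shrinking VARCHAR(20) -> VARCHAR(10): may truncate, so rows must be reorganized.
	fmt.Println(needChangeColumnData(varcharColForSketch(20), varcharColForSketch(10))) // true
}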
- for i, col := range columns { - if col.Name.L != originalCol.Name.L { - continue - } - columns[i] = newCol.Clone() - columns[i].Name = originalCol.Name - break - } - - pkIndex := tables.FindPrimaryIndex(tbInfo) - - checkOneIndex := func(indexInfo *model.IndexInfo) (err error) { - var modified bool - for _, col := range indexInfo.Columns { - if col.Name.L == originalCol.Name.L { - modified = true - break - } - } - if !modified { - return - } - err = checkIndexInModifiableColumns(columns, indexInfo.Columns) - if err != nil { - return - } - err = checkIndexPrefixLength(columns, indexInfo.Columns) - return - } - - // Check primary key first. - var err error - - if pkIndex != nil { - err = checkOneIndex(pkIndex) - if err != nil { - return err - } - } - - // Check secondary indexes. - for _, indexInfo := range tbInfo.Indices { - if indexInfo.Primary { - continue - } - // the second param should always be set to true, check index length only if it was modified - // checkOneIndex needs one param only. - err = checkOneIndex(indexInfo) - if err != nil { - return err - } - } - return nil -} - -func checkIndexInModifiableColumns(columns []*model.ColumnInfo, idxColumns []*model.IndexColumn) error { - for _, ic := range idxColumns { - col := model.FindColumnInfo(columns, ic.Name.L) - if col == nil { - return dbterror.ErrKeyColumnDoesNotExits.GenWithStack("column does not exist: %s", ic.Name) - } - - prefixLength := types.UnspecifiedLength - if types.IsTypePrefixable(col.FieldType.GetType()) && col.FieldType.GetFlen() > ic.Length { - // When the index column is changed, prefix length is only valid - // if the type is still prefixable and larger than old prefix length. - prefixLength = ic.Length - } - if err := checkIndexColumn(nil, col, prefixLength); err != nil { - return err - } - } - return nil -} - -// checkModifyTypes checks if the 'origin' type can be modified to 'to' type no matter directly change -// or change by reorg. It returns error if the two types are incompatible and correlated change are not -// supported. However, even the two types can be change, if the "origin" type contains primary key, error will be returned. -func checkModifyTypes(origin *types.FieldType, to *types.FieldType, needRewriteCollationData bool) error { - canReorg, err := types.CheckModifyTypeCompatible(origin, to) - if err != nil { - if !canReorg { - return errors.Trace(dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs(err.Error())) - } - if mysql.HasPriKeyFlag(origin.GetFlag()) { - msg := "this column has primary key flag" - return dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs(msg) - } - } - - err = checkModifyCharsetAndCollation(to.GetCharset(), to.GetCollate(), origin.GetCharset(), origin.GetCollate(), needRewriteCollationData) - - if err != nil { - if to.GetCharset() == charset.CharsetGBK || origin.GetCharset() == charset.CharsetGBK { - return errors.Trace(err) - } - // column type change can handle the charset change between these two types in the process of the reorg. - if dbterror.ErrUnsupportedModifyCharset.Equal(err) && canReorg { - return nil - } - } - return errors.Trace(err) -} - -// ProcessModifyColumnOptions process column options. 
-func ProcessModifyColumnOptions(ctx sessionctx.Context, col *table.Column, options []*ast.ColumnOption) error { - var sb strings.Builder - restoreFlags := format.RestoreStringSingleQuotes | format.RestoreKeyWordLowercase | format.RestoreNameBackQuotes | - format.RestoreSpacesAroundBinaryOperation | format.RestoreWithoutSchemaName | format.RestoreWithoutSchemaName - restoreCtx := format.NewRestoreCtx(restoreFlags, &sb) - - var hasDefaultValue, setOnUpdateNow bool - var err error - var hasNullFlag bool - for _, opt := range options { - switch opt.Tp { - case ast.ColumnOptionDefaultValue: - hasDefaultValue, err = SetDefaultValue(ctx, col, opt) - if err != nil { - return errors.Trace(err) - } - case ast.ColumnOptionComment: - err := setColumnComment(ctx, col, opt) - if err != nil { - return errors.Trace(err) - } - case ast.ColumnOptionNotNull: - col.AddFlag(mysql.NotNullFlag) - case ast.ColumnOptionNull: - hasNullFlag = true - col.DelFlag(mysql.NotNullFlag) - case ast.ColumnOptionAutoIncrement: - col.AddFlag(mysql.AutoIncrementFlag) - case ast.ColumnOptionPrimaryKey: - return errors.Trace(dbterror.ErrUnsupportedModifyColumn.GenWithStack("can't change column constraint (PRIMARY KEY)")) - case ast.ColumnOptionUniqKey: - return errors.Trace(dbterror.ErrUnsupportedModifyColumn.GenWithStack("can't change column constraint (UNIQUE KEY)")) - case ast.ColumnOptionOnUpdate: - // TODO: Support other time functions. - if !(col.GetType() == mysql.TypeTimestamp || col.GetType() == mysql.TypeDatetime) { - return dbterror.ErrInvalidOnUpdate.GenWithStackByArgs(col.Name) - } - if !expression.IsValidCurrentTimestampExpr(opt.Expr, &col.FieldType) { - return dbterror.ErrInvalidOnUpdate.GenWithStackByArgs(col.Name) - } - col.AddFlag(mysql.OnUpdateNowFlag) - setOnUpdateNow = true - case ast.ColumnOptionGenerated: - sb.Reset() - err = opt.Expr.Restore(restoreCtx) - if err != nil { - return errors.Trace(err) - } - col.GeneratedExprString = sb.String() - col.GeneratedStored = opt.Stored - col.Dependences = make(map[string]struct{}) - // Only used by checkModifyGeneratedColumn, there is no need to set a ctor for it. - col.GeneratedExpr = table.NewClonableExprNode(nil, opt.Expr) - for _, colName := range FindColumnNamesInExpr(opt.Expr) { - col.Dependences[colName.Name.L] = struct{}{} - } - case ast.ColumnOptionCollate: - col.SetCollate(opt.StrValue) - case ast.ColumnOptionReference: - return errors.Trace(dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs("can't modify with references")) - case ast.ColumnOptionFulltext: - return errors.Trace(dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs("can't modify with full text")) - case ast.ColumnOptionCheck: - return errors.Trace(dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs("can't modify with check")) - // Ignore ColumnOptionAutoRandom. It will be handled later. 
- case ast.ColumnOptionAutoRandom: - default: - return errors.Trace(dbterror.ErrUnsupportedModifyColumn.GenWithStackByArgs(fmt.Sprintf("unknown column option type: %d", opt.Tp))) - } - } - - if err = processAndCheckDefaultValueAndColumn(ctx, col, nil, hasDefaultValue, setOnUpdateNow, hasNullFlag); err != nil { - return errors.Trace(err) - } - - return nil -} - -func checkAutoRandom(tableInfo *model.TableInfo, originCol *table.Column, specNewColumn *ast.ColumnDef) (uint64, error) { - var oldShardBits, oldRangeBits uint64 - if isClusteredPKColumn(originCol, tableInfo) { - oldShardBits = tableInfo.AutoRandomBits - oldRangeBits = tableInfo.AutoRandomRangeBits - } - newShardBits, newRangeBits, err := extractAutoRandomBitsFromColDef(specNewColumn) - if err != nil { - return 0, errors.Trace(err) - } - switch { - case oldShardBits == newShardBits: - case oldShardBits < newShardBits: - addingAutoRandom := oldShardBits == 0 - if addingAutoRandom { - convFromAutoInc := mysql.HasAutoIncrementFlag(originCol.GetFlag()) && originCol.IsPKHandleColumn(tableInfo) - if !convFromAutoInc { - return 0, dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomAlterChangeFromAutoInc) - } - } - if autoid.AutoRandomShardBitsMax < newShardBits { - errMsg := fmt.Sprintf(autoid.AutoRandomOverflowErrMsg, - autoid.AutoRandomShardBitsMax, newShardBits, specNewColumn.Name.Name.O) - return 0, dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(errMsg) - } - // increasing auto_random shard bits is allowed. - case oldShardBits > newShardBits: - if newShardBits == 0 { - return 0, dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomAlterErrMsg) - } - return 0, dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomDecreaseBitErrMsg) - } - - modifyingAutoRandCol := oldShardBits > 0 || newShardBits > 0 - if modifyingAutoRandCol { - // Disallow changing the column field type. - if originCol.GetType() != specNewColumn.Tp.GetType() { - return 0, dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomModifyColTypeErrMsg) - } - if originCol.GetType() != mysql.TypeLonglong { - return 0, dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(fmt.Sprintf(autoid.AutoRandomOnNonBigIntColumn, types.TypeStr(originCol.GetType()))) - } - // Disallow changing from auto_random to auto_increment column. - if containsColumnOption(specNewColumn, ast.ColumnOptionAutoIncrement) { - return 0, dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomIncompatibleWithAutoIncErrMsg) - } - // Disallow specifying a default value on auto_random column. 
- if containsColumnOption(specNewColumn, ast.ColumnOptionDefaultValue) { - return 0, dbterror.ErrInvalidAutoRandom.GenWithStackByArgs(autoid.AutoRandomIncompatibleWithDefaultValueErrMsg) - } - } - if rangeBitsIsChanged(oldRangeBits, newRangeBits) { - return 0, dbterror.ErrInvalidAutoRandom.FastGenByArgs(autoid.AutoRandomUnsupportedAlterRangeBits) - } - return newShardBits, nil -} - -func isClusteredPKColumn(col *table.Column, tblInfo *model.TableInfo) bool { - switch { - case tblInfo.PKIsHandle: - return mysql.HasPriKeyFlag(col.GetFlag()) - case tblInfo.IsCommonHandle: - pk := tables.FindPrimaryIndex(tblInfo) - for _, c := range pk.Columns { - if c.Name.L == col.Name.L { - return true - } - } - return false - default: - return false - } -} - -func rangeBitsIsChanged(oldBits, newBits uint64) bool { - if oldBits == 0 { - oldBits = autoid.AutoRandomRangeBitsDefault - } - if newBits == 0 { - newBits = autoid.AutoRandomRangeBitsDefault - } - return oldBits != newBits -} diff --git a/pkg/ddl/partition.go b/pkg/ddl/partition.go index dcc59bf5cd8ae..a0ee830c0e870 100644 --- a/pkg/ddl/partition.go +++ b/pkg/ddl/partition.go @@ -179,10 +179,10 @@ func (w *worker) onAddTablePartition(d *ddlCtx, t *meta.Meta, job *model.Job) (v job.SchemaState = model.StateReplicaOnly case model.StateReplicaOnly: // replica only -> public - if val, _err_ := failpoint.Eval(_curpkg_("sleepBeforeReplicaOnly")); _err_ == nil { + failpoint.Inject("sleepBeforeReplicaOnly", func(val failpoint.Value) { sleepSecond := val.(int) time.Sleep(time.Duration(sleepSecond) * time.Second) - } + }) // Here need do some tiflash replica complement check. // TODO: If a table is with no TiFlashReplica or it is not available, the replica-only state can be eliminated. if tblInfo.TiFlashReplica != nil && tblInfo.TiFlashReplica.Available { @@ -410,16 +410,16 @@ func checkAddPartitionValue(meta *model.TableInfo, part *model.PartitionInfo) er } func checkPartitionReplica(replicaCount uint64, addingDefinitions []model.PartitionDefinition, d *ddlCtx) (needWait bool, err error) { - if val, _err_ := failpoint.Eval(_curpkg_("mockWaitTiFlashReplica")); _err_ == nil { + failpoint.Inject("mockWaitTiFlashReplica", func(val failpoint.Value) { if val.(bool) { - return true, nil + failpoint.Return(true, nil) } - } - if val, _err_ := failpoint.Eval(_curpkg_("mockWaitTiFlashReplicaOK")); _err_ == nil { + }) + failpoint.Inject("mockWaitTiFlashReplicaOK", func(val failpoint.Value) { if val.(bool) { - return false, nil + failpoint.Return(false, nil) } - } + }) ctx := context.Background() pdCli := d.store.(tikv.Storage).GetRegionCache().PDClient() @@ -451,9 +451,9 @@ func checkPartitionReplica(replicaCount uint64, addingDefinitions []model.Partit return needWait, errors.Trace(err) } tiflashPeerAtLeastOne := checkTiFlashPeerStoreAtLeastOne(stores, regionState.Meta.Peers) - if v, _err_ := failpoint.Eval(_curpkg_("ForceTiflashNotAvailable")); _err_ == nil { + failpoint.Inject("ForceTiflashNotAvailable", func(v failpoint.Value) { tiflashPeerAtLeastOne = v.(bool) - } + }) // It's unnecessary to wait all tiflash peer to be replicated. // Here only make sure that tiflash peer count > 0 (at least one). if tiflashPeerAtLeastOne { @@ -2516,9 +2516,9 @@ func clearTruncatePartitionTiflashStatus(tblInfo *model.TableInfo, newPartitions // Clear the tiflash replica available status. 
if tblInfo.TiFlashReplica != nil { e := infosync.ConfigureTiFlashPDForPartitions(true, &newPartitions, tblInfo.TiFlashReplica.Count, &tblInfo.TiFlashReplica.LocationLabels, tblInfo.ID) - if _, _err_ := failpoint.Eval(_curpkg_("FailTiFlashTruncatePartition")); _err_ == nil { + failpoint.Inject("FailTiFlashTruncatePartition", func() { e = errors.New("enforced error") - } + }) if e != nil { logutil.DDLLogger().Error("ConfigureTiFlashPDForPartitions fails", zap.Error(e)) return e @@ -2778,11 +2778,11 @@ func (w *worker) onExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Jo return ver, errors.Trace(err) } - if val, _err_ := failpoint.Eval(_curpkg_("exchangePartitionErr")); _err_ == nil { + failpoint.Inject("exchangePartitionErr", func(val failpoint.Value) { if val.(bool) { - return ver, errors.New("occur an error after updating partition id") + failpoint.Return(ver, errors.New("occur an error after updating partition id")) } - } + }) // Set both tables to the maximum auto IDs between normal table and partitioned table. // TODO: Fix the issue of big transactions during EXCHANGE PARTITION with AutoID. @@ -2801,20 +2801,20 @@ func (w *worker) onExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Jo return ver, errors.Trace(err) } - if val, _err_ := failpoint.Eval(_curpkg_("exchangePartitionAutoID")); _err_ == nil { + failpoint.Inject("exchangePartitionAutoID", func(val failpoint.Value) { if val.(bool) { seCtx, err := w.sessPool.Get() defer w.sessPool.Put(seCtx) if err != nil { - return ver, err + failpoint.Return(ver, err) } se := sess.NewSession(seCtx) _, err = se.Execute(context.Background(), "insert ignore into test.pt values (40000000)", "exchange_partition_test") if err != nil { - return ver, err + failpoint.Return(ver, err) } } - } + }) // the follow code is a swap function for rules of two partitions // though partitions has exchanged their ID, swap still take effect @@ -3246,11 +3246,11 @@ func (w *worker) onReorganizePartition(d *ddlCtx, t *meta.Meta, job *model.Job) } } firstPartIdx, lastPartIdx, idMap, err2 := getReplacedPartitionIDs(partNames, tblInfo.Partition) - if val, _err_ := failpoint.Eval(_curpkg_("reorgPartWriteReorgReplacedPartIDsFail")); _err_ == nil { + failpoint.Inject("reorgPartWriteReorgReplacedPartIDsFail", func(val failpoint.Value) { if val.(bool) { err2 = errors.New("Injected error by reorgPartWriteReorgReplacedPartIDsFail") } - } + }) if err2 != nil { return ver, err2 } @@ -3354,11 +3354,11 @@ func (w *worker) onReorganizePartition(d *ddlCtx, t *meta.Meta, job *model.Job) } job.CtxVars = []any{physicalTableIDs, newIDs} ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - if val, _err_ := failpoint.Eval(_curpkg_("reorgPartWriteReorgSchemaVersionUpdateFail")); _err_ == nil { + failpoint.Inject("reorgPartWriteReorgSchemaVersionUpdateFail", func(val failpoint.Value) { if val.(bool) { err = errors.New("Injected error by reorgPartWriteReorgSchemaVersionUpdateFail") } - } + }) if err != nil { return ver, errors.Trace(err) } @@ -3717,12 +3717,12 @@ func (w *worker) reorgPartitionDataAndIndex(t table.Table, reorgInfo *reorgInfo) } } - if val, _err_ := failpoint.Eval(_curpkg_("reorgPartitionAfterDataCopy")); _err_ == nil { + failpoint.Inject("reorgPartitionAfterDataCopy", func(val failpoint.Value) { //nolint:forcetypeassert if val.(bool) { panic("panic test in reorgPartitionAfterDataCopy") } - } + }) if !bytes.Equal(reorgInfo.currElement.TypeKey, meta.IndexElementKey) { // row data has been copied, now proceed with creating the indexes @@ -4800,10 
+4800,10 @@ func checkPartitionByHash(ctx sessionctx.Context, tbInfo *model.TableInfo) error // checkPartitionByRange checks validity of a "BY RANGE" partition. func checkPartitionByRange(ctx sessionctx.Context, tbInfo *model.TableInfo) error { - if _, _err_ := failpoint.Eval(_curpkg_("CheckPartitionByRangeErr")); _err_ == nil { + failpoint.Inject("CheckPartitionByRangeErr", func() { ctx.GetSessionVars().SQLKiller.SendKillSignal(sqlkiller.QueryMemoryExceeded) panic(ctx.GetSessionVars().SQLKiller.HandleSignal()) - } + }) pi := tbInfo.Partition if len(pi.Columns) == 0 { diff --git a/pkg/ddl/partition.go__failpoint_stash__ b/pkg/ddl/partition.go__failpoint_stash__ deleted file mode 100644 index a0ee830c0e870..0000000000000 --- a/pkg/ddl/partition.go__failpoint_stash__ +++ /dev/null @@ -1,4922 +0,0 @@ -// Copyright 2018 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ddl - -import ( - "bytes" - "context" - "encoding/hex" - "fmt" - "math" - "strconv" - "strings" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/tidb/pkg/ddl/label" - "github.com/pingcap/tidb/pkg/ddl/logutil" - "github.com/pingcap/tidb/pkg/ddl/placement" - sess "github.com/pingcap/tidb/pkg/ddl/session" - "github.com/pingcap/tidb/pkg/domain/infosync" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/charset" - "github.com/pingcap/tidb/pkg/parser/format" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/opcode" - "github.com/pingcap/tidb/pkg/parser/terror" - field_types "github.com/pingcap/tidb/pkg/parser/types" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - statsutil "github.com/pingcap/tidb/pkg/statistics/handle/util" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/table/tables" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/types" - driver "github.com/pingcap/tidb/pkg/types/parser_driver" - tidbutil "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/collate" - "github.com/pingcap/tidb/pkg/util/dbterror" - "github.com/pingcap/tidb/pkg/util/hack" - "github.com/pingcap/tidb/pkg/util/mathutil" - decoder "github.com/pingcap/tidb/pkg/util/rowDecoder" - "github.com/pingcap/tidb/pkg/util/slice" - "github.com/pingcap/tidb/pkg/util/sqlkiller" - "github.com/pingcap/tidb/pkg/util/stringutil" - "github.com/tikv/client-go/v2/tikv" - kvutil "github.com/tikv/client-go/v2/util" - pd "github.com/tikv/pd/client" - "go.uber.org/zap" -) - -const ( - partitionMaxValue = "MAXVALUE" -) - -func checkAddPartition(t *meta.Meta, job 
*model.Job) (*model.TableInfo, *model.PartitionInfo, []model.PartitionDefinition, error) { - schemaID := job.SchemaID - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) - if err != nil { - return nil, nil, nil, errors.Trace(err) - } - partInfo := &model.PartitionInfo{} - err = job.DecodeArgs(&partInfo) - if err != nil { - job.State = model.JobStateCancelled - return nil, nil, nil, errors.Trace(err) - } - if len(tblInfo.Partition.AddingDefinitions) > 0 { - return tblInfo, partInfo, tblInfo.Partition.AddingDefinitions, nil - } - return tblInfo, partInfo, []model.PartitionDefinition{}, nil -} - -// TODO: Move this into reorganize partition! -func (w *worker) onAddTablePartition(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - // Handle the rolling back job - if job.IsRollingback() { - ver, err := w.onDropTablePartition(d, t, job) - if err != nil { - return ver, errors.Trace(err) - } - return ver, nil - } - - // notice: addingDefinitions is empty when job is in state model.StateNone - tblInfo, partInfo, addingDefinitions, err := checkAddPartition(t, job) - if err != nil { - return ver, err - } - - // In order to skip maintaining the state check in partitionDefinition, TiDB use addingDefinition instead of state field. - // So here using `job.SchemaState` to judge what the stage of this job is. - switch job.SchemaState { - case model.StateNone: - // job.SchemaState == model.StateNone means the job is in the initial state of add partition. - // Here should use partInfo from job directly and do some check action. - err = checkAddPartitionTooManyPartitions(uint64(len(tblInfo.Partition.Definitions) + len(partInfo.Definitions))) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - err = checkAddPartitionValue(tblInfo, partInfo) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - err = checkAddPartitionNameUnique(tblInfo, partInfo) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - // move the adding definition into tableInfo. - updateAddingPartitionInfo(partInfo, tblInfo) - ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - - // modify placement settings - for _, def := range tblInfo.Partition.AddingDefinitions { - if _, err = checkPlacementPolicyRefValidAndCanNonValidJob(t, job, def.PlacementPolicyRef); err != nil { - return ver, errors.Trace(err) - } - } - - if tblInfo.TiFlashReplica != nil { - // Must set placement rule, and make sure it succeeds. 
- if err := infosync.ConfigureTiFlashPDForPartitions(true, &tblInfo.Partition.AddingDefinitions, tblInfo.TiFlashReplica.Count, &tblInfo.TiFlashReplica.LocationLabels, tblInfo.ID); err != nil { - logutil.DDLLogger().Error("ConfigureTiFlashPDForPartitions fails", zap.Error(err)) - return ver, errors.Trace(err) - } - } - - bundles, err := alterTablePartitionBundles(t, tblInfo, tblInfo.Partition.AddingDefinitions) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - if err = infosync.PutRuleBundlesWithDefaultRetry(context.TODO(), bundles); err != nil { - job.State = model.JobStateCancelled - return ver, errors.Wrapf(err, "failed to notify PD the placement rules") - } - - ids := getIDs([]*model.TableInfo{tblInfo}) - for _, p := range tblInfo.Partition.AddingDefinitions { - ids = append(ids, p.ID) - } - if _, err := alterTableLabelRule(job.SchemaName, tblInfo, ids); err != nil { - job.State = model.JobStateCancelled - return ver, err - } - - // none -> replica only - job.SchemaState = model.StateReplicaOnly - case model.StateReplicaOnly: - // replica only -> public - failpoint.Inject("sleepBeforeReplicaOnly", func(val failpoint.Value) { - sleepSecond := val.(int) - time.Sleep(time.Duration(sleepSecond) * time.Second) - }) - // Here need do some tiflash replica complement check. - // TODO: If a table is with no TiFlashReplica or it is not available, the replica-only state can be eliminated. - if tblInfo.TiFlashReplica != nil && tblInfo.TiFlashReplica.Available { - // For available state, the new added partition should wait it's replica to - // be finished. Otherwise the query to this partition will be blocked. - needRetry, err := checkPartitionReplica(tblInfo.TiFlashReplica.Count, addingDefinitions, d) - if err != nil { - return convertAddTablePartitionJob2RollbackJob(d, t, job, err, tblInfo) - } - if needRetry { - // The new added partition hasn't been replicated. - // Do nothing to the job this time, wait next worker round. - time.Sleep(tiflashCheckTiDBHTTPAPIHalfInterval) - // Set the error here which will lead this job exit when it's retry times beyond the limitation. - return ver, errors.Errorf("[ddl] add partition wait for tiflash replica to complete") - } - } - - // When TiFlash Replica is ready, we must move them into `AvailablePartitionIDs`. - if tblInfo.TiFlashReplica != nil && tblInfo.TiFlashReplica.Available { - for _, d := range partInfo.Definitions { - tblInfo.TiFlashReplica.AvailablePartitionIDs = append(tblInfo.TiFlashReplica.AvailablePartitionIDs, d.ID) - err = infosync.UpdateTiFlashProgressCache(d.ID, 1) - if err != nil { - // just print log, progress will be updated in `refreshTiFlashTicker` - logutil.DDLLogger().Error("update tiflash sync progress cache failed", - zap.Error(err), - zap.Int64("tableID", tblInfo.ID), - zap.Int64("partitionID", d.ID), - ) - } - } - } - // For normal and replica finished table, move the `addingDefinitions` into `Definitions`. - updatePartitionInfo(tblInfo) - - preSplitAndScatter(w.sess.Context, d.store, tblInfo, addingDefinitions) - - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - - // Finish this job. 
- job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - addPartitionEvent := statsutil.NewAddPartitionEvent( - job.SchemaID, - tblInfo, - partInfo, - ) - asyncNotifyEvent(d, addPartitionEvent) - default: - err = dbterror.ErrInvalidDDLState.GenWithStackByArgs("partition", job.SchemaState) - } - - return ver, errors.Trace(err) -} - -// alterTableLabelRule updates Label Rules if they exists -// returns true if changed. -func alterTableLabelRule(schemaName string, meta *model.TableInfo, ids []int64) (bool, error) { - tableRuleID := fmt.Sprintf(label.TableIDFormat, label.IDPrefix, schemaName, meta.Name.L) - oldRule, err := infosync.GetLabelRules(context.TODO(), []string{tableRuleID}) - if err != nil { - return false, errors.Trace(err) - } - if len(oldRule) == 0 { - return false, nil - } - - r, ok := oldRule[tableRuleID] - if ok { - rule := r.Reset(schemaName, meta.Name.L, "", ids...) - err = infosync.PutLabelRule(context.TODO(), rule) - if err != nil { - return false, errors.Wrapf(err, "failed to notify PD label rule") - } - return true, nil - } - return false, nil -} - -func alterTablePartitionBundles(t *meta.Meta, tblInfo *model.TableInfo, addingDefinitions []model.PartitionDefinition) ([]*placement.Bundle, error) { - var bundles []*placement.Bundle - - // tblInfo do not include added partitions, so we should add them first - tblInfo = tblInfo.Clone() - p := *tblInfo.Partition - p.Definitions = append([]model.PartitionDefinition{}, p.Definitions...) - p.Definitions = append(tblInfo.Partition.Definitions, addingDefinitions...) - tblInfo.Partition = &p - - // bundle for table should be recomputed because it includes some default configs for partitions - tblBundle, err := placement.NewTableBundle(t, tblInfo) - if err != nil { - return nil, errors.Trace(err) - } - - if tblBundle != nil { - bundles = append(bundles, tblBundle) - } - - partitionBundles, err := placement.NewPartitionListBundles(t, addingDefinitions) - if err != nil { - return nil, errors.Trace(err) - } - - bundles = append(bundles, partitionBundles...) - return bundles, nil -} - -// When drop/truncate a partition, we should still keep the dropped partition's placement settings to avoid unnecessary region schedules. -// When a partition is not configured with a placement policy directly, its rule is in the table's placement group which will be deleted after -// partition truncated/dropped. So it is necessary to create a standalone placement group with partition id after it. -func droppedPartitionBundles(t *meta.Meta, tblInfo *model.TableInfo, dropPartitions []model.PartitionDefinition) ([]*placement.Bundle, error) { - partitions := make([]model.PartitionDefinition, 0, len(dropPartitions)) - for _, def := range dropPartitions { - def = def.Clone() - if def.PlacementPolicyRef == nil { - def.PlacementPolicyRef = tblInfo.PlacementPolicyRef - } - - if def.PlacementPolicyRef != nil { - partitions = append(partitions, def) - } - } - - return placement.NewPartitionListBundles(t, partitions) -} - -// updatePartitionInfo merge `addingDefinitions` into `Definitions` in the tableInfo. -func updatePartitionInfo(tblInfo *model.TableInfo) { - parInfo := &model.PartitionInfo{} - oldDefs, newDefs := tblInfo.Partition.Definitions, tblInfo.Partition.AddingDefinitions - parInfo.Definitions = make([]model.PartitionDefinition, 0, len(newDefs)+len(oldDefs)) - parInfo.Definitions = append(parInfo.Definitions, oldDefs...) - parInfo.Definitions = append(parInfo.Definitions, newDefs...) 
- tblInfo.Partition.Definitions = parInfo.Definitions - tblInfo.Partition.AddingDefinitions = nil -} - -// updateAddingPartitionInfo write adding partitions into `addingDefinitions` field in the tableInfo. -func updateAddingPartitionInfo(partitionInfo *model.PartitionInfo, tblInfo *model.TableInfo) { - newDefs := partitionInfo.Definitions - tblInfo.Partition.AddingDefinitions = make([]model.PartitionDefinition, 0, len(newDefs)) - tblInfo.Partition.AddingDefinitions = append(tblInfo.Partition.AddingDefinitions, newDefs...) -} - -// rollbackAddingPartitionInfo remove the `addingDefinitions` in the tableInfo. -func rollbackAddingPartitionInfo(tblInfo *model.TableInfo) ([]int64, []string, []*placement.Bundle) { - physicalTableIDs := make([]int64, 0, len(tblInfo.Partition.AddingDefinitions)) - partNames := make([]string, 0, len(tblInfo.Partition.AddingDefinitions)) - rollbackBundles := make([]*placement.Bundle, 0, len(tblInfo.Partition.AddingDefinitions)) - for _, one := range tblInfo.Partition.AddingDefinitions { - physicalTableIDs = append(physicalTableIDs, one.ID) - partNames = append(partNames, one.Name.L) - if one.PlacementPolicyRef != nil { - rollbackBundles = append(rollbackBundles, placement.NewBundle(one.ID)) - } - } - tblInfo.Partition.AddingDefinitions = nil - return physicalTableIDs, partNames, rollbackBundles -} - -// Check if current table already contains DEFAULT list partition -func checkAddListPartitions(tblInfo *model.TableInfo) error { - for i := range tblInfo.Partition.Definitions { - for j := range tblInfo.Partition.Definitions[i].InValues { - for _, val := range tblInfo.Partition.Definitions[i].InValues[j] { - if val == "DEFAULT" { // should already be normalized - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("ADD List partition, already contains DEFAULT partition. Please use REORGANIZE PARTITION instead") - } - } - } - } - return nil -} - -// checkAddPartitionValue check add Partition Values, -// For Range: values less than value must be strictly increasing for each partition. -// For List: if a Default partition exists, -// -// no ADD partition can be allowed -// (needs reorganize partition instead). 
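A hedged sketch of the RANGE check described in the doc comment above: every new LESS THAN bound must be strictly greater than the current maximum, and MAXVALUE may only appear as the very last new partition. The plain string/int64 representation and the checkNewRangeBounds helper are simplifications for illustration only.

package main

import (
	"errors"
	"fmt"
	"strconv"
	"strings"
)

// checkNewRangeBounds validates that the proposed LESS THAN bounds keep the
// range strictly increasing and that MAXVALUE, if present, comes last.
func checkNewRangeBounds(currentMax string, newBounds []string) error {
	if strings.EqualFold(currentMax, "MAXVALUE") {
		return errors.New("cannot add partitions after MAXVALUE")
	}
	prev, err := strconv.ParseInt(currentMax, 10, 64)
	if err != nil {
		return err
	}
	for i, b := range newBounds {
		if strings.EqualFold(b, "MAXVALUE") {
			if i != len(newBounds)-1 {
				return errors.New("MAXVALUE must be the last partition bound")
			}
			return nil
		}
		v, err := strconv.ParseInt(b, 10, 64)
		if err != nil {
			return err
		}
		if v <= prev {
			return errors.New("VALUES LESS THAN must be strictly increasing")
		}
		prev = v
	}
	return nil
}

func main() {
	fmt.Println(checkNewRangeBounds("100", []string{"200", "300", "MAXVALUE"})) // <nil>
	fmt.Println(checkNewRangeBounds("100", []string{"100"}))                    // strictly increasing error
}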
-func checkAddPartitionValue(meta *model.TableInfo, part *model.PartitionInfo) error { - switch meta.Partition.Type { - case model.PartitionTypeRange: - if len(meta.Partition.Columns) == 0 { - newDefs, oldDefs := part.Definitions, meta.Partition.Definitions - rangeValue := oldDefs[len(oldDefs)-1].LessThan[0] - if strings.EqualFold(rangeValue, "MAXVALUE") { - return errors.Trace(dbterror.ErrPartitionMaxvalue) - } - - currentRangeValue, err := strconv.Atoi(rangeValue) - if err != nil { - return errors.Trace(err) - } - - for i := 0; i < len(newDefs); i++ { - ifMaxvalue := strings.EqualFold(newDefs[i].LessThan[0], "MAXVALUE") - if ifMaxvalue && i == len(newDefs)-1 { - return nil - } else if ifMaxvalue && i != len(newDefs)-1 { - return errors.Trace(dbterror.ErrPartitionMaxvalue) - } - - nextRangeValue, err := strconv.Atoi(newDefs[i].LessThan[0]) - if err != nil { - return errors.Trace(err) - } - if nextRangeValue <= currentRangeValue { - return errors.Trace(dbterror.ErrRangeNotIncreasing) - } - currentRangeValue = nextRangeValue - } - } - case model.PartitionTypeList: - err := checkAddListPartitions(meta) - if err != nil { - return err - } - } - return nil -} - -func checkPartitionReplica(replicaCount uint64, addingDefinitions []model.PartitionDefinition, d *ddlCtx) (needWait bool, err error) { - failpoint.Inject("mockWaitTiFlashReplica", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(true, nil) - } - }) - failpoint.Inject("mockWaitTiFlashReplicaOK", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(false, nil) - } - }) - - ctx := context.Background() - pdCli := d.store.(tikv.Storage).GetRegionCache().PDClient() - stores, err := pdCli.GetAllStores(ctx) - if err != nil { - return needWait, errors.Trace(err) - } - // Check whether stores have `count` tiflash engines. - tiFlashStoreCount := uint64(0) - for _, store := range stores { - if storeHasEngineTiFlashLabel(store) { - tiFlashStoreCount++ - } - } - if replicaCount > tiFlashStoreCount { - return false, errors.Errorf("[ddl] the tiflash replica count: %d should be less than the total tiflash server count: %d", replicaCount, tiFlashStoreCount) - } - for _, pDef := range addingDefinitions { - startKey, endKey := tablecodec.GetTableHandleKeyRange(pDef.ID) - regions, err := pdCli.BatchScanRegions(ctx, []pd.KeyRange{{StartKey: startKey, EndKey: endKey}}, -1) - if err != nil { - return needWait, errors.Trace(err) - } - // For every region in the partition, if it has some corresponding peers and - // no pending peers, that means the replication has completed. - for _, region := range regions { - regionState, err := pdCli.GetRegionByID(ctx, region.Meta.Id) - if err != nil { - return needWait, errors.Trace(err) - } - tiflashPeerAtLeastOne := checkTiFlashPeerStoreAtLeastOne(stores, regionState.Meta.Peers) - failpoint.Inject("ForceTiflashNotAvailable", func(v failpoint.Value) { - tiflashPeerAtLeastOne = v.(bool) - }) - // It's unnecessary to wait all tiflash peer to be replicated. - // Here only make sure that tiflash peer count > 0 (at least one). 
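The hunks in this patch restore the marker form of failpoints. A minimal sketch of that pattern, assuming the github.com/pingcap/failpoint toolchain: the closure is inert in a normal build and is only turned into an evaluated branch (the failpoint.Eval/_curpkg_ form visible elsewhere in this patch) by the failpoint rewriter. The failpoint name "mockSkipCheck" and the checkSomething function are made up for the example.

package main

import (
	"fmt"

	"github.com/pingcap/failpoint"
)

func checkSomething() (needWait bool, err error) {
	// With the failpoint enabled and the source rewritten, this returns early;
	// otherwise the closure is never executed.
	failpoint.Inject("mockSkipCheck", func(val failpoint.Value) {
		if val.(bool) {
			failpoint.Return(true, nil)
		}
	})
	// ... the real replica check would go here ...
	return false, nil
}

func main() {
	fmt.Println(checkSomething())
}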
- if tiflashPeerAtLeastOne { - continue - } - needWait = true - logutil.DDLLogger().Info("partition replicas check failed in replica-only DDL state", zap.Int64("pID", pDef.ID), zap.Uint64("wait region ID", region.Meta.Id), zap.Bool("tiflash peer at least one", tiflashPeerAtLeastOne), zap.Time("check time", time.Now())) - return needWait, nil - } - } - logutil.DDLLogger().Info("partition replicas check ok in replica-only DDL state") - return needWait, nil -} - -func checkTiFlashPeerStoreAtLeastOne(stores []*metapb.Store, peers []*metapb.Peer) bool { - for _, peer := range peers { - for _, store := range stores { - if peer.StoreId == store.Id && storeHasEngineTiFlashLabel(store) { - return true - } - } - } - return false -} - -func storeHasEngineTiFlashLabel(store *metapb.Store) bool { - for _, label := range store.Labels { - if label.Key == placement.EngineLabelKey && label.Value == placement.EngineLabelTiFlash { - return true - } - } - return false -} - -func checkListPartitions(defs []*ast.PartitionDefinition) error { - for _, def := range defs { - _, ok := def.Clause.(*ast.PartitionDefinitionClauseIn) - if !ok { - switch def.Clause.(type) { - case *ast.PartitionDefinitionClauseLessThan: - return ast.ErrPartitionWrongValues.GenWithStackByArgs("RANGE", "LESS THAN") - case *ast.PartitionDefinitionClauseNone: - return ast.ErrPartitionRequiresValues.GenWithStackByArgs("LIST", "IN") - default: - return dbterror.ErrUnsupportedCreatePartition.GenWithStack("Only VALUES IN () is supported for LIST partitioning") - } - } - } - return nil -} - -// buildTablePartitionInfo builds partition info and checks for some errors. -func buildTablePartitionInfo(ctx sessionctx.Context, s *ast.PartitionOptions, tbInfo *model.TableInfo) error { - if s == nil { - return nil - } - - if strings.EqualFold(ctx.GetSessionVars().EnableTablePartition, "OFF") { - ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrTablePartitionDisabled) - return nil - } - - var enable bool - switch s.Tp { - case model.PartitionTypeRange: - enable = true - case model.PartitionTypeList: - // Partition by list is enabled only when tidb_enable_list_partition is 'ON'. - enable = ctx.GetSessionVars().EnableListTablePartition - if enable { - err := checkListPartitions(s.Definitions) - if err != nil { - return err - } - } - case model.PartitionTypeHash, model.PartitionTypeKey: - // Partition by hash and key is enabled by default. - if s.Sub != nil { - // Subpartitioning only allowed with Range or List - return ast.ErrSubpartition - } - // Note that linear hash is simply ignored, and creates non-linear hash/key. 
- if s.Linear { - ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedCreatePartition.FastGen(fmt.Sprintf("LINEAR %s is not supported, using non-linear %s instead", s.Tp.String(), s.Tp.String()))) - } - if s.Tp == model.PartitionTypeHash || len(s.ColumnNames) != 0 { - enable = true - } - if s.Tp == model.PartitionTypeKey && len(s.ColumnNames) == 0 { - enable = true - } - } - - if !enable { - ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedCreatePartition.FastGen(fmt.Sprintf("Unsupported partition type %v, treat as normal table", s.Tp))) - return nil - } - if s.Sub != nil { - ctx.GetSessionVars().StmtCtx.AppendWarning(dbterror.ErrUnsupportedCreatePartition.FastGen(fmt.Sprintf("Unsupported subpartitioning, only using %v partitioning", s.Tp))) - } - - pi := &model.PartitionInfo{ - Type: s.Tp, - Enable: enable, - Num: s.Num, - } - tbInfo.Partition = pi - if s.Expr != nil { - if err := checkPartitionFuncValid(ctx.GetExprCtx(), tbInfo, s.Expr); err != nil { - return errors.Trace(err) - } - buf := new(bytes.Buffer) - restoreFlags := format.DefaultRestoreFlags | format.RestoreBracketAroundBinaryOperation | - format.RestoreWithoutSchemaName | format.RestoreWithoutTableName - restoreCtx := format.NewRestoreCtx(restoreFlags, buf) - if err := s.Expr.Restore(restoreCtx); err != nil { - return err - } - pi.Expr = buf.String() - } else if s.ColumnNames != nil { - pi.Columns = make([]model.CIStr, 0, len(s.ColumnNames)) - for _, cn := range s.ColumnNames { - pi.Columns = append(pi.Columns, cn.Name) - } - if pi.Type == model.PartitionTypeKey && len(s.ColumnNames) == 0 { - if tbInfo.PKIsHandle { - pi.Columns = append(pi.Columns, tbInfo.GetPkName()) - pi.IsEmptyColumns = true - } else if key := tbInfo.GetPrimaryKey(); key != nil { - for _, col := range key.Columns { - pi.Columns = append(pi.Columns, col.Name) - } - pi.IsEmptyColumns = true - } - } - if err := checkColumnsPartitionType(tbInfo); err != nil { - return err - } - } - - exprCtx := ctx.GetExprCtx() - err := generatePartitionDefinitionsFromInterval(exprCtx, s, tbInfo) - if err != nil { - return errors.Trace(err) - } - - defs, err := buildPartitionDefinitionsInfo(exprCtx, s.Definitions, tbInfo, s.Num) - if err != nil { - return errors.Trace(err) - } - - tbInfo.Partition.Definitions = defs - - if s.Interval != nil { - // Syntactic sugar for INTERVAL partitioning - // Generate the resulting CREATE TABLE as the query string - query, ok := ctx.Value(sessionctx.QueryString).(string) - if ok { - sqlMode := ctx.GetSessionVars().SQLMode - var buf bytes.Buffer - AppendPartitionDefs(tbInfo.Partition, &buf, sqlMode) - - syntacticSugar := s.Interval.OriginalText() - syntacticStart := s.Interval.OriginTextPosition() - newQuery := query[:syntacticStart] + "(" + buf.String() + ")" + query[syntacticStart+len(syntacticSugar):] - ctx.SetValue(sessionctx.QueryString, newQuery) - } - } - - partCols, err := getPartitionColSlices(exprCtx, tbInfo, s) - if err != nil { - return errors.Trace(err) - } - - for _, index := range tbInfo.Indices { - if index.Unique && !checkUniqueKeyIncludePartKey(partCols, index.Columns) { - index.Global = ctx.GetSessionVars().EnableGlobalIndex - } - } - return nil -} - -func getPartitionColSlices(sctx expression.BuildContext, tblInfo *model.TableInfo, s *ast.PartitionOptions) (partCols stringSlice, err error) { - if s.Expr != nil { - extractCols := newPartitionExprChecker(sctx, tblInfo) - s.Expr.Accept(extractCols) - partColumns, err := extractCols.columns, extractCols.err - if err != nil { - return nil, err - } - 
return columnInfoSlice(partColumns), nil - } else if len(s.ColumnNames) > 0 { - return columnNameSlice(s.ColumnNames), nil - } else if len(s.ColumnNames) == 0 { - if tblInfo.PKIsHandle { - return columnInfoSlice([]*model.ColumnInfo{tblInfo.GetPkColInfo()}), nil - } else if key := tblInfo.GetPrimaryKey(); key != nil { - colInfos := make([]*model.ColumnInfo, 0, len(key.Columns)) - for _, col := range key.Columns { - colInfos = append(colInfos, model.FindColumnInfo(tblInfo.Cols(), col.Name.L)) - } - return columnInfoSlice(colInfos), nil - } - } - return nil, errors.Errorf("Table partition metadata not correct, neither partition expression or list of partition columns") -} - -func checkColumnsPartitionType(tbInfo *model.TableInfo) error { - for _, col := range tbInfo.Partition.Columns { - colInfo := tbInfo.FindPublicColumnByName(col.L) - if colInfo == nil { - return errors.Trace(dbterror.ErrFieldNotFoundPart) - } - if !isColTypeAllowedAsPartitioningCol(tbInfo.Partition.Type, colInfo.FieldType) { - return dbterror.ErrNotAllowedTypeInPartition.GenWithStackByArgs(col.O) - } - } - return nil -} - -func isValidKeyPartitionColType(fieldType types.FieldType) bool { - switch fieldType.GetType() { - case mysql.TypeBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeJSON, mysql.TypeGeometry, mysql.TypeTiDBVectorFloat32: - return false - default: - return true - } -} - -func isColTypeAllowedAsPartitioningCol(partType model.PartitionType, fieldType types.FieldType) bool { - // For key partition, the permitted partition field types can be all field types except - // BLOB, JSON, Geometry - if partType == model.PartitionTypeKey { - return isValidKeyPartitionColType(fieldType) - } - // The permitted data types are shown in the following list: - // All integer types - // DATE and DATETIME - // CHAR, VARCHAR, BINARY, and VARBINARY - // See https://dev.mysql.com/doc/mysql-partitioning-excerpt/5.7/en/partitioning-columns.html - // Note that also TIME is allowed in MySQL. Also see https://bugs.mysql.com/bug.php?id=84362 - switch fieldType.GetType() { - case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: - case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeDuration: - case mysql.TypeVarchar, mysql.TypeString: - default: - return false - } - return true -} - -// getPartitionIntervalFromTable checks if a partitioned table matches a generated INTERVAL partitioned scheme -// will return nil if error occurs, i.e. 
not an INTERVAL partitioned table -func getPartitionIntervalFromTable(ctx expression.BuildContext, tbInfo *model.TableInfo) *ast.PartitionInterval { - if tbInfo.Partition == nil || - tbInfo.Partition.Type != model.PartitionTypeRange { - return nil - } - if len(tbInfo.Partition.Columns) > 1 { - // Multi-column RANGE COLUMNS is not supported with INTERVAL - return nil - } - if len(tbInfo.Partition.Definitions) < 2 { - // Must have at least two partitions to calculate an INTERVAL - return nil - } - - var ( - interval ast.PartitionInterval - startIdx = 0 - endIdx = len(tbInfo.Partition.Definitions) - 1 - isIntType = true - minVal = "0" - ) - if len(tbInfo.Partition.Columns) > 0 { - partCol := findColumnByName(tbInfo.Partition.Columns[0].L, tbInfo) - if partCol.FieldType.EvalType() == types.ETInt { - min := getLowerBoundInt(partCol) - minVal = strconv.FormatInt(min, 10) - } else if partCol.FieldType.EvalType() == types.ETDatetime { - isIntType = false - minVal = "0000-01-01" - } else { - // Only INT and Datetime columns are supported for INTERVAL partitioning - return nil - } - } else { - if !isPartExprUnsigned(ctx.GetEvalCtx(), tbInfo) { - minVal = "-9223372036854775808" - } - } - - // Check if possible null partition - firstPartLessThan := driver.UnwrapFromSingleQuotes(tbInfo.Partition.Definitions[0].LessThan[0]) - if strings.EqualFold(firstPartLessThan, minVal) { - interval.NullPart = true - startIdx++ - firstPartLessThan = driver.UnwrapFromSingleQuotes(tbInfo.Partition.Definitions[startIdx].LessThan[0]) - } - // flag if MAXVALUE partition - lastPartLessThan := driver.UnwrapFromSingleQuotes(tbInfo.Partition.Definitions[endIdx].LessThan[0]) - if strings.EqualFold(lastPartLessThan, partitionMaxValue) { - interval.MaxValPart = true - endIdx-- - lastPartLessThan = driver.UnwrapFromSingleQuotes(tbInfo.Partition.Definitions[endIdx].LessThan[0]) - } - // Guess the interval - if startIdx >= endIdx { - // Must have at least two partitions to calculate an INTERVAL - return nil - } - var firstExpr, lastExpr ast.ExprNode - if isIntType { - exprStr := fmt.Sprintf("((%s) - (%s)) DIV %d", lastPartLessThan, firstPartLessThan, endIdx-startIdx) - expr, err := expression.ParseSimpleExpr(ctx, exprStr) - if err != nil { - return nil - } - val, isNull, err := expr.EvalInt(ctx.GetEvalCtx(), chunk.Row{}) - if isNull || err != nil || val < 1 { - // If NULL, error or interval < 1 then cannot be an INTERVAL partitioned table - return nil - } - interval.IntervalExpr.Expr = ast.NewValueExpr(val, "", "") - interval.IntervalExpr.TimeUnit = ast.TimeUnitInvalid - firstExpr, err = astIntValueExprFromStr(firstPartLessThan, minVal == "0") - if err != nil { - return nil - } - interval.FirstRangeEnd = &firstExpr - lastExpr, err = astIntValueExprFromStr(lastPartLessThan, minVal == "0") - if err != nil { - return nil - } - interval.LastRangeEnd = &lastExpr - } else { // types.ETDatetime - exprStr := fmt.Sprintf("TIMESTAMPDIFF(SECOND, '%s', '%s')", firstPartLessThan, lastPartLessThan) - expr, err := expression.ParseSimpleExpr(ctx, exprStr) - if err != nil { - return nil - } - val, isNull, err := expr.EvalInt(ctx.GetEvalCtx(), chunk.Row{}) - if isNull || err != nil || val < 1 { - // If NULL, error or interval < 1 then cannot be an INTERVAL partitioned table - return nil - } - - // This will not find all matches > 28 days, since INTERVAL 1 MONTH can generate - // 2022-01-31, 2022-02-28, 2022-03-31 etc. so we just assume that if there is a - // diff >= 28 days, we will try with Month and not retry with something else... 
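A sketch of the heuristic spelled out in the comment above: divide the total seconds between the FIRST and LAST range ends by the number of intervals, and treat anything of at least 28 days as a MONTH-based interval, otherwise keep it in seconds. guessUnit is an illustrative helper under those assumptions, not a TiDB API.

package main

import "fmt"

const daySeconds = 24 * 60 * 60

// guessUnit returns an interval count and time unit guessed from the average
// number of seconds per generated partition.
func guessUnit(totalSeconds, intervals int64) (count int64, unit string) {
	per := totalSeconds / intervals
	if per < 28*daySeconds {
		return per, "SECOND"
	}
	if intervals <= 3 {
		// A short range may include February, so divide by the shortest month.
		return per / (28 * daySeconds), "MONTH"
	}
	return per / (30 * daySeconds), "MONTH"
}

func main() {
	fmt.Println(guessUnit(3*daySeconds, 3))    // 86400 SECOND (one day, kept in seconds)
	fmt.Println(guessUnit(365*daySeconds, 12)) // 1 MONTH
}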
- i := val / int64(endIdx-startIdx) - if i < (28 * 24 * 60 * 60) { - // Since it is not stored or displayed, non need to try Minute..Week! - interval.IntervalExpr.Expr = ast.NewValueExpr(i, "", "") - interval.IntervalExpr.TimeUnit = ast.TimeUnitSecond - } else { - // Since it is not stored or displayed, non need to try to match Quarter or Year! - if (endIdx - startIdx) <= 3 { - // in case February is in the range - i = i / (28 * 24 * 60 * 60) - } else { - // This should be good for intervals up to 5 years - i = i / (30 * 24 * 60 * 60) - } - interval.IntervalExpr.Expr = ast.NewValueExpr(i, "", "") - interval.IntervalExpr.TimeUnit = ast.TimeUnitMonth - } - - firstExpr = ast.NewValueExpr(firstPartLessThan, "", "") - lastExpr = ast.NewValueExpr(lastPartLessThan, "", "") - interval.FirstRangeEnd = &firstExpr - interval.LastRangeEnd = &lastExpr - } - - partitionMethod := ast.PartitionMethod{ - Tp: model.PartitionTypeRange, - Interval: &interval, - } - partOption := &ast.PartitionOptions{PartitionMethod: partitionMethod} - // Generate the definitions from interval, first and last - err := generatePartitionDefinitionsFromInterval(ctx, partOption, tbInfo) - if err != nil { - return nil - } - - return &interval -} - -// comparePartitionAstAndModel compares a generated *ast.PartitionOptions and a *model.PartitionInfo -func comparePartitionAstAndModel(ctx expression.BuildContext, pAst *ast.PartitionOptions, pModel *model.PartitionInfo, partCol *model.ColumnInfo) error { - a := pAst.Definitions - m := pModel.Definitions - if len(pAst.Definitions) != len(pModel.Definitions) { - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL partitioning: number of partitions generated != partition defined (%d != %d)", len(a), len(m)) - } - - evalCtx := ctx.GetEvalCtx() - evalFn := func(expr ast.ExprNode) (types.Datum, error) { - val, err := expression.EvalSimpleAst(ctx, ast.NewValueExpr(expr, "", "")) - if err != nil || partCol == nil { - return val, err - } - return val.ConvertTo(evalCtx.TypeCtx(), &partCol.FieldType) - } - for i := range pAst.Definitions { - // Allow options to differ! (like Placement Rules) - // Allow names to differ! 
- - // Check MAXVALUE - maxVD := false - if strings.EqualFold(m[i].LessThan[0], partitionMaxValue) { - maxVD = true - } - generatedExpr := a[i].Clause.(*ast.PartitionDefinitionClauseLessThan).Exprs[0] - _, maxVG := generatedExpr.(*ast.MaxValueExpr) - if maxVG || maxVD { - if maxVG && maxVD { - continue - } - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs(fmt.Sprintf("INTERVAL partitioning: MAXVALUE clause defined for partition %s differs between generated and defined", m[i].Name.O)) - } - - lessThan := m[i].LessThan[0] - if len(lessThan) > 1 && lessThan[:1] == "'" && lessThan[len(lessThan)-1:] == "'" { - lessThan = driver.UnwrapFromSingleQuotes(lessThan) - } - lessThanVal, err := evalFn(ast.NewValueExpr(lessThan, "", "")) - if err != nil { - return err - } - generatedExprVal, err := evalFn(generatedExpr) - if err != nil { - return err - } - cmp, err := lessThanVal.Compare(evalCtx.TypeCtx(), &generatedExprVal, collate.GetBinaryCollator()) - if err != nil { - return err - } - if cmp != 0 { - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs(fmt.Sprintf("INTERVAL partitioning: LESS THAN for partition %s differs between generated and defined", m[i].Name.O)) - } - } - return nil -} - -// comparePartitionDefinitions check if generated definitions are the same as the given ones -// Allow names to differ -// returns error in case of error or non-accepted difference -func comparePartitionDefinitions(ctx expression.BuildContext, a, b []*ast.PartitionDefinition) error { - if len(a) != len(b) { - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("number of partitions generated != partition defined (%d != %d)", len(a), len(b)) - } - for i := range a { - if len(b[i].Sub) > 0 { - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs(fmt.Sprintf("partition %s does have unsupported subpartitions", b[i].Name.O)) - } - // TODO: We could extend the syntax to allow for table options too, like: - // CREATE TABLE t ... INTERVAL ... 
LAST PARTITION LESS THAN ('2015-01-01') PLACEMENT POLICY = 'cheapStorage' - // ALTER TABLE t LAST PARTITION LESS THAN ('2022-01-01') PLACEMENT POLICY 'defaultStorage' - // ALTER TABLE t LAST PARTITION LESS THAN ('2023-01-01') PLACEMENT POLICY 'fastStorage' - if len(b[i].Options) > 0 { - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs(fmt.Sprintf("partition %s does have unsupported options", b[i].Name.O)) - } - lessThan, ok := b[i].Clause.(*ast.PartitionDefinitionClauseLessThan) - if !ok { - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs(fmt.Sprintf("partition %s does not have the right type for LESS THAN", b[i].Name.O)) - } - definedExpr := lessThan.Exprs[0] - generatedExpr := a[i].Clause.(*ast.PartitionDefinitionClauseLessThan).Exprs[0] - _, maxVD := definedExpr.(*ast.MaxValueExpr) - _, maxVG := generatedExpr.(*ast.MaxValueExpr) - if maxVG || maxVD { - if maxVG && maxVD { - continue - } - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs(fmt.Sprintf("partition %s differs between generated and defined for MAXVALUE", b[i].Name.O)) - } - cmpExpr := &ast.BinaryOperationExpr{ - Op: opcode.EQ, - L: definedExpr, - R: generatedExpr, - } - cmp, err := expression.EvalSimpleAst(ctx, cmpExpr) - if err != nil { - return err - } - if cmp.GetInt64() != 1 { - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs(fmt.Sprintf("partition %s differs between generated and defined for expression", b[i].Name.O)) - } - } - return nil -} - -func getLowerBoundInt(partCols ...*model.ColumnInfo) int64 { - ret := int64(0) - for _, col := range partCols { - if mysql.HasUnsignedFlag(col.FieldType.GetFlag()) { - return 0 - } - ret = min(ret, types.IntergerSignedLowerBound(col.GetType())) - } - return ret -} - -// generatePartitionDefinitionsFromInterval generates partition Definitions according to INTERVAL options on partOptions -func generatePartitionDefinitionsFromInterval(ctx expression.BuildContext, partOptions *ast.PartitionOptions, tbInfo *model.TableInfo) error { - if partOptions.Interval == nil { - return nil - } - if tbInfo.Partition.Type != model.PartitionTypeRange { - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL partitioning, only allowed on RANGE partitioning") - } - if len(partOptions.ColumnNames) > 1 || len(tbInfo.Partition.Columns) > 1 { - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL partitioning, does not allow RANGE COLUMNS with more than one column") - } - var partCol *model.ColumnInfo - if len(tbInfo.Partition.Columns) > 0 { - partCol = findColumnByName(tbInfo.Partition.Columns[0].L, tbInfo) - if partCol == nil { - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL partitioning, could not find any RANGE COLUMNS") - } - // Only support Datetime, date and INT column types for RANGE INTERVAL! - switch partCol.FieldType.EvalType() { - case types.ETInt, types.ETDatetime: - default: - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL partitioning, only supports Date, Datetime and INT types") - } - } - // Allow given partition definitions, but check it later! 
- definedPartDefs := partOptions.Definitions - partOptions.Definitions = make([]*ast.PartitionDefinition, 0, 1) - if partOptions.Interval.FirstRangeEnd == nil || partOptions.Interval.LastRangeEnd == nil { - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL partitioning, currently requires FIRST and LAST partitions to be defined") - } - switch partOptions.Interval.IntervalExpr.TimeUnit { - case ast.TimeUnitInvalid, ast.TimeUnitYear, ast.TimeUnitQuarter, ast.TimeUnitMonth, ast.TimeUnitWeek, ast.TimeUnitDay, ast.TimeUnitHour, ast.TimeUnitDayMinute, ast.TimeUnitSecond: - default: - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL partitioning, only supports YEAR, QUARTER, MONTH, WEEK, DAY, HOUR, MINUTE and SECOND as time unit") - } - first := ast.PartitionDefinitionClauseLessThan{ - Exprs: []ast.ExprNode{*partOptions.Interval.FirstRangeEnd}, - } - last := ast.PartitionDefinitionClauseLessThan{ - Exprs: []ast.ExprNode{*partOptions.Interval.LastRangeEnd}, - } - if len(tbInfo.Partition.Columns) > 0 { - colTypes := collectColumnsType(tbInfo) - if len(colTypes) != len(tbInfo.Partition.Columns) { - return dbterror.ErrWrongPartitionName.GenWithStack("partition column name cannot be found") - } - if _, err := checkAndGetColumnsTypeAndValuesMatch(ctx, colTypes, first.Exprs); err != nil { - return err - } - if _, err := checkAndGetColumnsTypeAndValuesMatch(ctx, colTypes, last.Exprs); err != nil { - return err - } - } else { - if err := checkPartitionValuesIsInt(ctx, "FIRST PARTITION", first.Exprs, tbInfo); err != nil { - return err - } - if err := checkPartitionValuesIsInt(ctx, "LAST PARTITION", last.Exprs, tbInfo); err != nil { - return err - } - } - if partOptions.Interval.NullPart { - var partExpr ast.ExprNode - if len(tbInfo.Partition.Columns) == 1 && partOptions.Interval.IntervalExpr.TimeUnit != ast.TimeUnitInvalid { - // Notice compatibility with MySQL, keyword here is 'supported range' but MySQL seems to work from 0000-01-01 too - // https://dev.mysql.com/doc/refman/8.0/en/datetime.html says range 1000-01-01 - 9999-12-31 - // https://docs.pingcap.com/tidb/dev/data-type-date-and-time says The supported range is '0000-01-01' to '9999-12-31' - // set LESS THAN to ZeroTime - partExpr = ast.NewValueExpr("0000-01-01", "", "") - } else { - var min int64 - if partCol != nil { - min = getLowerBoundInt(partCol) - } else { - if !isPartExprUnsigned(ctx.GetEvalCtx(), tbInfo) { - min = math.MinInt64 - } - } - partExpr = ast.NewValueExpr(min, "", "") - } - partOptions.Definitions = append(partOptions.Definitions, &ast.PartitionDefinition{ - Name: model.NewCIStr("P_NULL"), - Clause: &ast.PartitionDefinitionClauseLessThan{ - Exprs: []ast.ExprNode{partExpr}, - }, - }) - } - - err := GeneratePartDefsFromInterval(ctx, ast.AlterTablePartition, tbInfo, partOptions) - if err != nil { - return err - } - - if partOptions.Interval.MaxValPart { - partOptions.Definitions = append(partOptions.Definitions, &ast.PartitionDefinition{ - Name: model.NewCIStr("P_MAXVALUE"), - Clause: &ast.PartitionDefinitionClauseLessThan{ - Exprs: []ast.ExprNode{&ast.MaxValueExpr{}}, - }, - }) - } - - if len(definedPartDefs) > 0 { - err := comparePartitionDefinitions(ctx, partOptions.Definitions, definedPartDefs) - if err != nil { - return err - } - // Seems valid, so keep the defined so that the user defined names are kept etc. 
- partOptions.Definitions = definedPartDefs - } else if len(tbInfo.Partition.Definitions) > 0 { - err := comparePartitionAstAndModel(ctx, partOptions, tbInfo.Partition, partCol) - if err != nil { - return err - } - } - - return nil -} - -func checkAndGetColumnsTypeAndValuesMatch(ctx expression.BuildContext, colTypes []types.FieldType, exprs []ast.ExprNode) ([]types.Datum, error) { - // Validate() has already checked len(colNames) = len(exprs) - // create table ... partition by range columns (cols) - // partition p0 values less than (expr) - // check the type of cols[i] and expr is consistent. - valDatums := make([]types.Datum, 0, len(colTypes)) - for i, colExpr := range exprs { - if _, ok := colExpr.(*ast.MaxValueExpr); ok { - valDatums = append(valDatums, types.NewStringDatum(partitionMaxValue)) - continue - } - if d, ok := colExpr.(*ast.DefaultExpr); ok { - if d.Name != nil { - return nil, dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs() - } - continue - } - colType := colTypes[i] - val, err := expression.EvalSimpleAst(ctx, colExpr) - if err != nil { - return nil, err - } - // Check val.ConvertTo(colType) doesn't work, so we need this case by case check. - vkind := val.Kind() - switch colType.GetType() { - case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeDuration: - switch vkind { - case types.KindString, types.KindBytes, types.KindNull: - default: - return nil, dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs() - } - case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: - switch vkind { - case types.KindInt64, types.KindUint64, types.KindNull: - default: - return nil, dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs() - } - case mysql.TypeFloat, mysql.TypeDouble: - switch vkind { - case types.KindFloat32, types.KindFloat64, types.KindNull: - default: - return nil, dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs() - } - case mysql.TypeString, mysql.TypeVarString: - switch vkind { - case types.KindString, types.KindBytes, types.KindNull, types.KindBinaryLiteral: - default: - return nil, dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs() - } - } - evalCtx := ctx.GetEvalCtx() - newVal, err := val.ConvertTo(evalCtx.TypeCtx(), &colType) - if err != nil { - return nil, dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs() - } - valDatums = append(valDatums, newVal) - } - return valDatums, nil -} - -func astIntValueExprFromStr(s string, unsigned bool) (ast.ExprNode, error) { - if unsigned { - u, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return nil, err - } - return ast.NewValueExpr(u, "", ""), nil - } - i, err := strconv.ParseInt(s, 10, 64) - if err != nil { - return nil, err - } - return ast.NewValueExpr(i, "", ""), nil -} - -// GeneratePartDefsFromInterval generates range partitions from INTERVAL partitioning. -// Handles -// - CREATE TABLE: all partitions are generated -// - ALTER TABLE FIRST PARTITION (expr): Drops all partitions before the partition matching the expr (i.e. sets that partition as the new first partition) -// i.e. 
will return the partitions from old FIRST partition to (and including) new FIRST partition -// - ALTER TABLE LAST PARTITION (expr): Creates new partitions from (excluding) old LAST partition to (including) new LAST partition -// -// partition definitions will be set on partitionOptions -func GeneratePartDefsFromInterval(ctx expression.BuildContext, tp ast.AlterTableType, tbInfo *model.TableInfo, partitionOptions *ast.PartitionOptions) error { - if partitionOptions == nil { - return nil - } - var sb strings.Builder - err := partitionOptions.Interval.IntervalExpr.Expr.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)) - if err != nil { - return err - } - intervalString := driver.UnwrapFromSingleQuotes(sb.String()) - if len(intervalString) < 1 || intervalString[:1] < "1" || intervalString[:1] > "9" { - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL, should be a positive number") - } - var currVal types.Datum - var startExpr, lastExpr, currExpr ast.ExprNode - var timeUnit ast.TimeUnitType - var partCol *model.ColumnInfo - if len(tbInfo.Partition.Columns) == 1 { - partCol = findColumnByName(tbInfo.Partition.Columns[0].L, tbInfo) - if partCol == nil { - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL COLUMNS partitioning: could not find partitioning column") - } - } - timeUnit = partitionOptions.Interval.IntervalExpr.TimeUnit - switch tp { - case ast.AlterTablePartition: - // CREATE TABLE - startExpr = *partitionOptions.Interval.FirstRangeEnd - lastExpr = *partitionOptions.Interval.LastRangeEnd - case ast.AlterTableDropFirstPartition: - startExpr = *partitionOptions.Interval.FirstRangeEnd - lastExpr = partitionOptions.Expr - case ast.AlterTableAddLastPartition: - startExpr = *partitionOptions.Interval.LastRangeEnd - lastExpr = partitionOptions.Expr - default: - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL partitioning: Internal error during generating altered INTERVAL partitions, no known alter type") - } - lastVal, err := expression.EvalSimpleAst(ctx, lastExpr) - if err != nil { - return err - } - evalCtx := ctx.GetEvalCtx() - if partCol != nil { - lastVal, err = lastVal.ConvertTo(evalCtx.TypeCtx(), &partCol.FieldType) - if err != nil { - return err - } - } - var partDefs []*ast.PartitionDefinition - if len(partitionOptions.Definitions) != 0 { - partDefs = partitionOptions.Definitions - } else { - partDefs = make([]*ast.PartitionDefinition, 0, 1) - } - for i := 0; i < mysql.PartitionCountLimit; i++ { - if i == 0 { - currExpr = startExpr - // TODO: adjust the startExpr and have an offset for interval to handle - // Month/Quarters with start partition on day 28/29/30 - if tp == ast.AlterTableAddLastPartition { - // ALTER TABLE LAST PARTITION ... 
- // Current LAST PARTITION/start already exists, skip to next partition - continue - } - } else { - currExpr = &ast.BinaryOperationExpr{ - Op: opcode.Mul, - L: ast.NewValueExpr(i, "", ""), - R: partitionOptions.Interval.IntervalExpr.Expr, - } - if timeUnit == ast.TimeUnitInvalid { - currExpr = &ast.BinaryOperationExpr{ - Op: opcode.Plus, - L: startExpr, - R: currExpr, - } - } else { - currExpr = &ast.FuncCallExpr{ - FnName: model.NewCIStr("DATE_ADD"), - Args: []ast.ExprNode{ - startExpr, - currExpr, - &ast.TimeUnitExpr{Unit: timeUnit}, - }, - } - } - } - currVal, err = expression.EvalSimpleAst(ctx, currExpr) - if err != nil { - return err - } - if partCol != nil { - currVal, err = currVal.ConvertTo(evalCtx.TypeCtx(), &partCol.FieldType) - if err != nil { - return err - } - } - cmp, err := currVal.Compare(evalCtx.TypeCtx(), &lastVal, collate.GetBinaryCollator()) - if err != nil { - return err - } - if cmp > 0 { - lastStr, err := lastVal.ToString() - if err != nil { - return err - } - sb.Reset() - err = startExpr.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)) - if err != nil { - return err - } - startStr := sb.String() - errStr := fmt.Sprintf("INTERVAL: expr (%s) not matching FIRST + n INTERVALs (%s + n * %s", - lastStr, startStr, intervalString) - if timeUnit != ast.TimeUnitInvalid { - errStr = errStr + " " + timeUnit.String() - } - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs(errStr + ")") - } - valStr, err := currVal.ToString() - if err != nil { - return err - } - if len(valStr) == 0 || valStr[0:1] == "'" { - return dbterror.ErrGeneralUnsupportedDDL.GenWithStackByArgs("INTERVAL partitioning: Error when generating partition values") - } - partName := "P_LT_" + valStr - if timeUnit != ast.TimeUnitInvalid { - currExpr = ast.NewValueExpr(valStr, "", "") - } else { - if valStr[:1] == "-" { - currExpr = ast.NewValueExpr(currVal.GetInt64(), "", "") - } else { - currExpr = ast.NewValueExpr(currVal.GetUint64(), "", "") - } - } - partDefs = append(partDefs, &ast.PartitionDefinition{ - Name: model.NewCIStr(partName), - Clause: &ast.PartitionDefinitionClauseLessThan{ - Exprs: []ast.ExprNode{currExpr}, - }, - }) - if cmp == 0 { - // Last partition! - break - } - // The last loop still not reach the max value, return error. - if i == mysql.PartitionCountLimit-1 { - return errors.Trace(dbterror.ErrTooManyPartitions) - } - } - if len(tbInfo.Partition.Definitions)+len(partDefs) > mysql.PartitionCountLimit { - return errors.Trace(dbterror.ErrTooManyPartitions) - } - partitionOptions.Definitions = partDefs - return nil -} - -// buildPartitionDefinitionsInfo build partition definitions info without assign partition id. 
tbInfo will be constant -func buildPartitionDefinitionsInfo(ctx expression.BuildContext, defs []*ast.PartitionDefinition, tbInfo *model.TableInfo, numParts uint64) (partitions []model.PartitionDefinition, err error) { - switch tbInfo.Partition.Type { - case model.PartitionTypeNone: - if len(defs) != 1 { - return nil, dbterror.ErrUnsupportedPartitionType - } - partitions = []model.PartitionDefinition{{Name: defs[0].Name}} - if comment, set := defs[0].Comment(); set { - partitions[0].Comment = comment - } - case model.PartitionTypeRange: - partitions, err = buildRangePartitionDefinitions(ctx, defs, tbInfo) - case model.PartitionTypeHash, model.PartitionTypeKey: - partitions, err = buildHashPartitionDefinitions(defs, tbInfo, numParts) - case model.PartitionTypeList: - partitions, err = buildListPartitionDefinitions(ctx, defs, tbInfo) - default: - err = dbterror.ErrUnsupportedPartitionType - } - - if err != nil { - return nil, err - } - - return partitions, nil -} - -func setPartitionPlacementFromOptions(partition *model.PartitionDefinition, options []*ast.TableOption) error { - // the partition inheritance of placement rules don't have to copy the placement elements to themselves. - // For example: - // t placement policy x (p1 placement policy y, p2) - // p2 will share the same rule as table t does, but it won't copy the meta to itself. we will - // append p2 range to the coverage of table t's rules. This mechanism is good for cascading change - // when policy x is altered. - for _, opt := range options { - if opt.Tp == ast.TableOptionPlacementPolicy { - partition.PlacementPolicyRef = &model.PolicyRefInfo{ - Name: model.NewCIStr(opt.StrValue), - } - } - } - - return nil -} - -func isNonDefaultPartitionOptionsUsed(defs []model.PartitionDefinition) bool { - for i := range defs { - orgDef := defs[i] - if orgDef.Name.O != fmt.Sprintf("p%d", i) { - return true - } - if len(orgDef.Comment) > 0 { - return true - } - if orgDef.PlacementPolicyRef != nil { - return true - } - } - return false -} - -func buildHashPartitionDefinitions(defs []*ast.PartitionDefinition, tbInfo *model.TableInfo, numParts uint64) ([]model.PartitionDefinition, error) { - if err := checkAddPartitionTooManyPartitions(tbInfo.Partition.Num); err != nil { - return nil, err - } - - definitions := make([]model.PartitionDefinition, numParts) - oldParts := uint64(len(tbInfo.Partition.Definitions)) - for i := uint64(0); i < numParts; i++ { - if i < oldParts { - // Use the existing definitions - def := tbInfo.Partition.Definitions[i] - definitions[i].Name = def.Name - definitions[i].Comment = def.Comment - definitions[i].PlacementPolicyRef = def.PlacementPolicyRef - } else if i < oldParts+uint64(len(defs)) { - // Use the new defs - def := defs[i-oldParts] - definitions[i].Name = def.Name - definitions[i].Comment, _ = def.Comment() - if err := setPartitionPlacementFromOptions(&definitions[i], def.Options); err != nil { - return nil, err - } - } else { - // Use the default - definitions[i].Name = model.NewCIStr(fmt.Sprintf("p%d", i)) - } - } - return definitions, nil -} - -func buildListPartitionDefinitions(ctx expression.BuildContext, defs []*ast.PartitionDefinition, tbInfo *model.TableInfo) ([]model.PartitionDefinition, error) { - definitions := make([]model.PartitionDefinition, 0, len(defs)) - exprChecker := newPartitionExprChecker(ctx, nil, checkPartitionExprAllowed) - colTypes := collectColumnsType(tbInfo) - if len(colTypes) != len(tbInfo.Partition.Columns) { - return nil, dbterror.ErrWrongPartitionName.GenWithStack("partition column 
name cannot be found") - } - for _, def := range defs { - if err := def.Clause.Validate(model.PartitionTypeList, len(tbInfo.Partition.Columns)); err != nil { - return nil, err - } - clause := def.Clause.(*ast.PartitionDefinitionClauseIn) - partVals := make([][]types.Datum, 0, len(clause.Values)) - if len(tbInfo.Partition.Columns) > 0 { - for _, vs := range clause.Values { - vals, err := checkAndGetColumnsTypeAndValuesMatch(ctx, colTypes, vs) - if err != nil { - return nil, err - } - partVals = append(partVals, vals) - } - } else { - for _, vs := range clause.Values { - if err := checkPartitionValuesIsInt(ctx, def.Name, vs, tbInfo); err != nil { - return nil, err - } - } - } - comment, _ := def.Comment() - err := checkTooLongTable(def.Name) - if err != nil { - return nil, err - } - piDef := model.PartitionDefinition{ - Name: def.Name, - Comment: comment, - } - - if err = setPartitionPlacementFromOptions(&piDef, def.Options); err != nil { - return nil, err - } - - buf := new(bytes.Buffer) - for valIdx, vs := range clause.Values { - inValue := make([]string, 0, len(vs)) - isDefault := false - if len(vs) == 1 { - if _, ok := vs[0].(*ast.DefaultExpr); ok { - isDefault = true - } - } - if len(partVals) > valIdx && !isDefault { - for colIdx := range partVals[valIdx] { - partVal, err := generatePartValuesWithTp(partVals[valIdx][colIdx], colTypes[colIdx]) - if err != nil { - return nil, err - } - inValue = append(inValue, partVal) - } - } else { - for i := range vs { - vs[i].Accept(exprChecker) - if exprChecker.err != nil { - return nil, exprChecker.err - } - buf.Reset() - vs[i].Format(buf) - inValue = append(inValue, buf.String()) - } - } - piDef.InValues = append(piDef.InValues, inValue) - buf.Reset() - } - definitions = append(definitions, piDef) - } - return definitions, nil -} - -func collectColumnsType(tbInfo *model.TableInfo) []types.FieldType { - if len(tbInfo.Partition.Columns) > 0 { - colTypes := make([]types.FieldType, 0, len(tbInfo.Partition.Columns)) - for _, col := range tbInfo.Partition.Columns { - c := findColumnByName(col.L, tbInfo) - if c == nil { - return nil - } - colTypes = append(colTypes, c.FieldType) - } - - return colTypes - } - - return nil -} - -func buildRangePartitionDefinitions(ctx expression.BuildContext, defs []*ast.PartitionDefinition, tbInfo *model.TableInfo) ([]model.PartitionDefinition, error) { - definitions := make([]model.PartitionDefinition, 0, len(defs)) - exprChecker := newPartitionExprChecker(ctx, nil, checkPartitionExprAllowed) - colTypes := collectColumnsType(tbInfo) - if len(colTypes) != len(tbInfo.Partition.Columns) { - return nil, dbterror.ErrWrongPartitionName.GenWithStack("partition column name cannot be found") - } - for _, def := range defs { - if err := def.Clause.Validate(model.PartitionTypeRange, len(tbInfo.Partition.Columns)); err != nil { - return nil, err - } - clause := def.Clause.(*ast.PartitionDefinitionClauseLessThan) - var partValDatums []types.Datum - if len(tbInfo.Partition.Columns) > 0 { - var err error - if partValDatums, err = checkAndGetColumnsTypeAndValuesMatch(ctx, colTypes, clause.Exprs); err != nil { - return nil, err - } - } else { - if err := checkPartitionValuesIsInt(ctx, def.Name, clause.Exprs, tbInfo); err != nil { - return nil, err - } - } - comment, _ := def.Comment() - evalCtx := ctx.GetEvalCtx() - comment, err := validateCommentLength(evalCtx.ErrCtx(), evalCtx.SQLMode(), def.Name.L, &comment, dbterror.ErrTooLongTablePartitionComment) - if err != nil { - return nil, err - } - err = checkTooLongTable(def.Name) - if err 
!= nil { - return nil, err - } - piDef := model.PartitionDefinition{ - Name: def.Name, - Comment: comment, - } - - if err = setPartitionPlacementFromOptions(&piDef, def.Options); err != nil { - return nil, err - } - - buf := new(bytes.Buffer) - // Range columns partitions support multi-column partitions. - for i, expr := range clause.Exprs { - expr.Accept(exprChecker) - if exprChecker.err != nil { - return nil, exprChecker.err - } - // If multi-column use new evaluated+normalized output, instead of just formatted expression - if len(partValDatums) > i { - var partVal string - if partValDatums[i].Kind() == types.KindNull { - return nil, dbterror.ErrNullInValuesLessThan - } - if _, ok := clause.Exprs[i].(*ast.MaxValueExpr); ok { - partVal, err = partValDatums[i].ToString() - if err != nil { - return nil, err - } - } else { - partVal, err = generatePartValuesWithTp(partValDatums[i], colTypes[i]) - if err != nil { - return nil, err - } - } - - piDef.LessThan = append(piDef.LessThan, partVal) - } else { - expr.Format(buf) - piDef.LessThan = append(piDef.LessThan, buf.String()) - buf.Reset() - } - } - definitions = append(definitions, piDef) - } - return definitions, nil -} - -func checkPartitionValuesIsInt(ctx expression.BuildContext, defName any, exprs []ast.ExprNode, tbInfo *model.TableInfo) error { - tp := types.NewFieldType(mysql.TypeLonglong) - if isPartExprUnsigned(ctx.GetEvalCtx(), tbInfo) { - tp.AddFlag(mysql.UnsignedFlag) - } - for _, exp := range exprs { - if _, ok := exp.(*ast.MaxValueExpr); ok { - continue - } - if d, ok := exp.(*ast.DefaultExpr); ok { - if d.Name != nil { - return dbterror.ErrPartitionConstDomain.GenWithStackByArgs() - } - continue - } - val, err := expression.EvalSimpleAst(ctx, exp) - if err != nil { - return err - } - switch val.Kind() { - case types.KindUint64, types.KindNull: - case types.KindInt64: - if mysql.HasUnsignedFlag(tp.GetFlag()) && val.GetInt64() < 0 { - return dbterror.ErrPartitionConstDomain.GenWithStackByArgs() - } - default: - return dbterror.ErrValuesIsNotIntType.GenWithStackByArgs(defName) - } - - evalCtx := ctx.GetEvalCtx() - _, err = val.ConvertTo(evalCtx.TypeCtx(), tp) - if err != nil && !types.ErrOverflow.Equal(err) { - return dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs() - } - } - - return nil -} - -func checkPartitionNameUnique(pi *model.PartitionInfo) error { - newPars := pi.Definitions - partNames := make(map[string]struct{}, len(newPars)) - for _, newPar := range newPars { - if _, ok := partNames[newPar.Name.L]; ok { - return dbterror.ErrSameNamePartition.GenWithStackByArgs(newPar.Name) - } - partNames[newPar.Name.L] = struct{}{} - } - return nil -} - -func checkAddPartitionNameUnique(tbInfo *model.TableInfo, pi *model.PartitionInfo) error { - partNames := make(map[string]struct{}) - if tbInfo.Partition != nil { - oldPars := tbInfo.Partition.Definitions - for _, oldPar := range oldPars { - partNames[oldPar.Name.L] = struct{}{} - } - } - newPars := pi.Definitions - for _, newPar := range newPars { - if _, ok := partNames[newPar.Name.L]; ok { - return dbterror.ErrSameNamePartition.GenWithStackByArgs(newPar.Name) - } - partNames[newPar.Name.L] = struct{}{} - } - return nil -} - -func checkReorgPartitionNames(p *model.PartitionInfo, droppedNames []string, pi *model.PartitionInfo) error { - partNames := make(map[string]struct{}) - oldDefs := p.Definitions - for _, oldDef := range oldDefs { - partNames[oldDef.Name.L] = struct{}{} - } - for _, delName := range droppedNames { - droppedName := strings.ToLower(delName) - if _, ok := 
partNames[droppedName]; !ok { - return dbterror.ErrSameNamePartition.GenWithStackByArgs(delName) - } - delete(partNames, droppedName) - } - newDefs := pi.Definitions - for _, newDef := range newDefs { - if _, ok := partNames[newDef.Name.L]; ok { - return dbterror.ErrSameNamePartition.GenWithStackByArgs(newDef.Name) - } - partNames[newDef.Name.L] = struct{}{} - } - return nil -} - -func checkAndOverridePartitionID(newTableInfo, oldTableInfo *model.TableInfo) error { - // If any old partitionInfo has lost, that means the partition ID lost too, so did the data, repair failed. - if newTableInfo.Partition == nil { - return nil - } - if oldTableInfo.Partition == nil { - return dbterror.ErrRepairTableFail.GenWithStackByArgs("Old table doesn't have partitions") - } - if newTableInfo.Partition.Type != oldTableInfo.Partition.Type { - return dbterror.ErrRepairTableFail.GenWithStackByArgs("Partition type should be the same") - } - // Check whether partitionType is hash partition. - if newTableInfo.Partition.Type == model.PartitionTypeHash { - if newTableInfo.Partition.Num != oldTableInfo.Partition.Num { - return dbterror.ErrRepairTableFail.GenWithStackByArgs("Hash partition num should be the same") - } - } - for i, newOne := range newTableInfo.Partition.Definitions { - found := false - for _, oldOne := range oldTableInfo.Partition.Definitions { - // Fix issue 17952 which wanna substitute partition range expr. - // So eliminate stringSliceEqual(newOne.LessThan, oldOne.LessThan) here. - if newOne.Name.L == oldOne.Name.L { - newTableInfo.Partition.Definitions[i].ID = oldOne.ID - found = true - break - } - } - if !found { - return dbterror.ErrRepairTableFail.GenWithStackByArgs("Partition " + newOne.Name.L + " has lost") - } - } - return nil -} - -// checkPartitionFuncValid checks partition function validly. -func checkPartitionFuncValid(ctx expression.BuildContext, tblInfo *model.TableInfo, expr ast.ExprNode) error { - if expr == nil { - return nil - } - exprChecker := newPartitionExprChecker(ctx, tblInfo, checkPartitionExprArgs, checkPartitionExprAllowed) - expr.Accept(exprChecker) - if exprChecker.err != nil { - return errors.Trace(exprChecker.err) - } - if len(exprChecker.columns) == 0 { - return errors.Trace(dbterror.ErrWrongExprInPartitionFunc) - } - return nil -} - -// checkResultOK derives from https://github.com/mysql/mysql-server/blob/5.7/sql/item_timefunc -// For partition tables, mysql do not support Constant, random or timezone-dependent expressions -// Based on mysql code to check whether field is valid, every time related type has check_valid_arguments_processor function. -func checkResultOK(ok bool) error { - if !ok { - return errors.Trace(dbterror.ErrWrongExprInPartitionFunc) - } - - return nil -} - -// checkPartitionFuncType checks partition function return type. 
-func checkPartitionFuncType(ctx sessionctx.Context, expr ast.ExprNode, schema string, tblInfo *model.TableInfo) error { - if expr == nil { - return nil - } - - if schema == "" { - schema = ctx.GetSessionVars().CurrentDB - } - - e, err := expression.BuildSimpleExpr(ctx.GetExprCtx(), expr, expression.WithTableInfo(schema, tblInfo)) - if err != nil { - return errors.Trace(err) - } - if e.GetType(ctx.GetExprCtx().GetEvalCtx()).EvalType() == types.ETInt { - return nil - } - - if col, ok := expr.(*ast.ColumnNameExpr); ok { - return errors.Trace(dbterror.ErrNotAllowedTypeInPartition.GenWithStackByArgs(col.Name.Name.L)) - } - - return errors.Trace(dbterror.ErrPartitionFuncNotAllowed.GenWithStackByArgs("PARTITION")) -} - -// checkRangePartitionValue checks whether `less than value` is strictly increasing for each partition. -// Side effect: it may simplify the partition range definition from a constant expression to an integer. -func checkRangePartitionValue(ctx sessionctx.Context, tblInfo *model.TableInfo) error { - pi := tblInfo.Partition - defs := pi.Definitions - if len(defs) == 0 { - return nil - } - - if strings.EqualFold(defs[len(defs)-1].LessThan[0], partitionMaxValue) { - defs = defs[:len(defs)-1] - } - isUnsigned := isPartExprUnsigned(ctx.GetExprCtx().GetEvalCtx(), tblInfo) - var prevRangeValue any - for i := 0; i < len(defs); i++ { - if strings.EqualFold(defs[i].LessThan[0], partitionMaxValue) { - return errors.Trace(dbterror.ErrPartitionMaxvalue) - } - - currentRangeValue, fromExpr, err := getRangeValue(ctx.GetExprCtx(), defs[i].LessThan[0], isUnsigned) - if err != nil { - return errors.Trace(err) - } - if fromExpr { - // Constant fold the expression. - defs[i].LessThan[0] = fmt.Sprintf("%d", currentRangeValue) - } - - if i == 0 { - prevRangeValue = currentRangeValue - continue - } - - if isUnsigned { - if currentRangeValue.(uint64) <= prevRangeValue.(uint64) { - return errors.Trace(dbterror.ErrRangeNotIncreasing) - } - } else { - if currentRangeValue.(int64) <= prevRangeValue.(int64) { - return errors.Trace(dbterror.ErrRangeNotIncreasing) - } - } - prevRangeValue = currentRangeValue - } - return nil -} - -func checkListPartitionValue(ctx expression.BuildContext, tblInfo *model.TableInfo) error { - pi := tblInfo.Partition - if len(pi.Definitions) == 0 { - return ast.ErrPartitionsMustBeDefined.GenWithStackByArgs("LIST") - } - expStr, err := formatListPartitionValue(ctx, tblInfo) - if err != nil { - return errors.Trace(err) - } - - partitionsValuesMap := make(map[string]struct{}) - for _, s := range expStr { - if _, ok := partitionsValuesMap[s]; ok { - return errors.Trace(dbterror.ErrMultipleDefConstInListPart) - } - partitionsValuesMap[s] = struct{}{} - } - - return nil -} - -func formatListPartitionValue(ctx expression.BuildContext, tblInfo *model.TableInfo) ([]string, error) { - defs := tblInfo.Partition.Definitions - pi := tblInfo.Partition - var colTps []*types.FieldType - cols := make([]*model.ColumnInfo, 0, len(pi.Columns)) - if len(pi.Columns) == 0 { - tp := types.NewFieldType(mysql.TypeLonglong) - if isPartExprUnsigned(ctx.GetEvalCtx(), tblInfo) { - tp.AddFlag(mysql.UnsignedFlag) - } - colTps = []*types.FieldType{tp} - } else { - colTps = make([]*types.FieldType, 0, len(pi.Columns)) - for _, colName := range pi.Columns { - colInfo := findColumnByName(colName.L, tblInfo) - if colInfo == nil { - return nil, errors.Trace(dbterror.ErrFieldNotFoundPart) - } - colTps = append(colTps, colInfo.FieldType.Clone()) - cols = append(cols, colInfo) - } - } - - haveDefault := false - exprStrs 
:= make([]string, 0) - inValueStrs := make([]string, 0, mathutil.Max(len(pi.Columns), 1)) - for i := range defs { - inValuesLoop: - for j, vs := range defs[i].InValues { - inValueStrs = inValueStrs[:0] - for k, v := range vs { - // if DEFAULT would be given as string, like "DEFAULT", - // it would be stored as "'DEFAULT'", - if strings.EqualFold(v, "DEFAULT") && k == 0 && len(vs) == 1 { - if haveDefault { - return nil, dbterror.ErrMultipleDefConstInListPart - } - haveDefault = true - continue inValuesLoop - } - if strings.EqualFold(v, "MAXVALUE") { - return nil, errors.Trace(dbterror.ErrMaxvalueInValuesIn) - } - expr, err := expression.ParseSimpleExpr(ctx, v, expression.WithCastExprTo(colTps[k])) - if err != nil { - return nil, errors.Trace(err) - } - eval, err := expr.Eval(ctx.GetEvalCtx(), chunk.Row{}) - if err != nil { - return nil, errors.Trace(err) - } - s, err := eval.ToString() - if err != nil { - return nil, errors.Trace(err) - } - if eval.IsNull() { - s = "NULL" - } else { - if colTps[k].EvalType() == types.ETInt { - defs[i].InValues[j][k] = s - } - if colTps[k].EvalType() == types.ETString { - s = string(hack.String(collate.GetCollator(cols[k].GetCollate()).Key(s))) - s = driver.WrapInSingleQuotes(s) - } - } - inValueStrs = append(inValueStrs, s) - } - exprStrs = append(exprStrs, strings.Join(inValueStrs, ",")) - } - } - return exprStrs, nil -} - -// getRangeValue gets an integer from the range value string. -// The returned boolean value indicates whether the input string is a constant expression. -func getRangeValue(ctx expression.BuildContext, str string, unsigned bool) (any, bool, error) { - // Unsigned bigint was converted to uint64 handle. - if unsigned { - if value, err := strconv.ParseUint(str, 10, 64); err == nil { - return value, false, nil - } - - e, err1 := expression.ParseSimpleExpr(ctx, str) - if err1 != nil { - return 0, false, err1 - } - res, isNull, err2 := e.EvalInt(ctx.GetEvalCtx(), chunk.Row{}) - if err2 == nil && !isNull { - return uint64(res), true, nil - } - } else { - if value, err := strconv.ParseInt(str, 10, 64); err == nil { - return value, false, nil - } - // The range value maybe not an integer, it could be a constant expression. - // For example, the following two cases are the same: - // PARTITION p0 VALUES LESS THAN (TO_SECONDS('2004-01-01')) - // PARTITION p0 VALUES LESS THAN (63340531200) - e, err1 := expression.ParseSimpleExpr(ctx, str) - if err1 != nil { - return 0, false, err1 - } - res, isNull, err2 := e.EvalInt(ctx.GetEvalCtx(), chunk.Row{}) - if err2 == nil && !isNull { - return res, true, nil - } - } - return 0, false, dbterror.ErrNotAllowedTypeInPartition.GenWithStackByArgs(str) -} - -// CheckDropTablePartition checks if the partition exists and does not allow deleting the last existing partition in the table. -func CheckDropTablePartition(meta *model.TableInfo, partLowerNames []string) error { - pi := meta.Partition - if pi.Type != model.PartitionTypeRange && pi.Type != model.PartitionTypeList { - return dbterror.ErrOnlyOnRangeListPartition.GenWithStackByArgs("DROP") - } - - // To be error compatible with MySQL, we need to do this first! 
- // see https://github.com/pingcap/tidb/issues/31681#issuecomment-1015536214 - oldDefs := pi.Definitions - if len(oldDefs) <= len(partLowerNames) { - return errors.Trace(dbterror.ErrDropLastPartition) - } - - dupCheck := make(map[string]bool) - for _, pn := range partLowerNames { - found := false - for _, def := range oldDefs { - if def.Name.L == pn { - if _, ok := dupCheck[pn]; ok { - return errors.Trace(dbterror.ErrDropPartitionNonExistent.GenWithStackByArgs("DROP")) - } - dupCheck[pn] = true - found = true - break - } - } - if !found { - return errors.Trace(dbterror.ErrDropPartitionNonExistent.GenWithStackByArgs("DROP")) - } - } - return nil -} - -// updateDroppingPartitionInfo move dropping partitions to DroppingDefinitions, and return partitionIDs -func updateDroppingPartitionInfo(tblInfo *model.TableInfo, partLowerNames []string) []int64 { - oldDefs := tblInfo.Partition.Definitions - newDefs := make([]model.PartitionDefinition, 0, len(oldDefs)-len(partLowerNames)) - droppingDefs := make([]model.PartitionDefinition, 0, len(partLowerNames)) - pids := make([]int64, 0, len(partLowerNames)) - - // consider using a map to probe partLowerNames if too many partLowerNames - for i := range oldDefs { - found := false - for _, partName := range partLowerNames { - if oldDefs[i].Name.L == partName { - found = true - break - } - } - if found { - pids = append(pids, oldDefs[i].ID) - droppingDefs = append(droppingDefs, oldDefs[i]) - } else { - newDefs = append(newDefs, oldDefs[i]) - } - } - - tblInfo.Partition.Definitions = newDefs - tblInfo.Partition.DroppingDefinitions = droppingDefs - return pids -} - -func getPartitionDef(tblInfo *model.TableInfo, partName string) (index int, def *model.PartitionDefinition, _ error) { - defs := tblInfo.Partition.Definitions - for i := 0; i < len(defs); i++ { - if strings.EqualFold(defs[i].Name.L, strings.ToLower(partName)) { - return i, &(defs[i]), nil - } - } - return index, nil, table.ErrUnknownPartition.GenWithStackByArgs(partName, tblInfo.Name.O) -} - -func getPartitionIDsFromDefinitions(defs []model.PartitionDefinition) []int64 { - pids := make([]int64, 0, len(defs)) - for _, def := range defs { - pids = append(pids, def.ID) - } - return pids -} - -func hasGlobalIndex(tblInfo *model.TableInfo) bool { - for _, idxInfo := range tblInfo.Indices { - if idxInfo.Global { - return true - } - } - return false -} - -// getTableInfoWithDroppingPartitions builds oldTableInfo including dropping partitions, only used by onDropTablePartition. -func getTableInfoWithDroppingPartitions(t *model.TableInfo) *model.TableInfo { - p := t.Partition - nt := t.Clone() - np := *p - npd := make([]model.PartitionDefinition, 0, len(p.Definitions)+len(p.DroppingDefinitions)) - npd = append(npd, p.Definitions...) - npd = append(npd, p.DroppingDefinitions...) - np.Definitions = npd - np.DroppingDefinitions = nil - nt.Partition = &np - return nt -} - -// getTableInfoWithOriginalPartitions builds oldTableInfo including truncating partitions, only used by onTruncateTablePartition. 
-func getTableInfoWithOriginalPartitions(t *model.TableInfo, oldIDs []int64, newIDs []int64) *model.TableInfo { - nt := t.Clone() - np := nt.Partition - - // reconstruct original definitions - for _, oldDef := range np.DroppingDefinitions { - var newID int64 - for i := range newIDs { - if oldDef.ID == oldIDs[i] { - newID = newIDs[i] - break - } - } - for i := range np.Definitions { - newDef := &np.Definitions[i] - if newDef.ID == newID { - newDef.ID = oldDef.ID - break - } - } - } - - np.DroppingDefinitions = nil - np.NewPartitionIDs = nil - return nt -} - -func dropLabelRules(ctx context.Context, schemaName, tableName string, partNames []string) error { - deleteRules := make([]string, 0, len(partNames)) - for _, partName := range partNames { - deleteRules = append(deleteRules, fmt.Sprintf(label.PartitionIDFormat, label.IDPrefix, schemaName, tableName, partName)) - } - // delete batch rules - patch := label.NewRulePatch([]*label.Rule{}, deleteRules) - return infosync.UpdateLabelRules(ctx, patch) -} - -// onDropTablePartition deletes old partition meta. -func (w *worker) onDropTablePartition(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - var partNames []string - partInfo := model.PartitionInfo{} - if err := job.DecodeArgs(&partNames, &partInfo); err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) - if err != nil { - return ver, errors.Trace(err) - } - if job.Type != model.ActionDropTablePartition { - // If rollback from reorganize partition, remove DroppingDefinitions from tableInfo - tblInfo.Partition.DroppingDefinitions = nil - // If rollback from adding table partition, remove addingDefinitions from tableInfo. - physicalTableIDs, pNames, rollbackBundles := rollbackAddingPartitionInfo(tblInfo) - err = infosync.PutRuleBundlesWithDefaultRetry(context.TODO(), rollbackBundles) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Wrapf(err, "failed to notify PD the placement rules") - } - // TODO: Will this drop LabelRules for existing partitions, if the new partitions have the same name? - err = dropLabelRules(w.ctx, job.SchemaName, tblInfo.Name.L, pNames) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Wrapf(err, "failed to notify PD the label rules") - } - - if _, err := alterTableLabelRule(job.SchemaName, tblInfo, getIDs([]*model.TableInfo{tblInfo})); err != nil { - job.State = model.JobStateCancelled - return ver, err - } - // ALTER TABLE ... PARTITION BY - if partInfo.Type != model.PartitionTypeNone { - // Also remove anything with the new table id - physicalTableIDs = append(physicalTableIDs, partInfo.NewTableID) - // Reset if it was normal table before - if tblInfo.Partition.Type == model.PartitionTypeNone || - tblInfo.Partition.DDLType == model.PartitionTypeNone { - tblInfo.Partition = nil - } else { - tblInfo.Partition.ClearReorgIntermediateInfo() - } - } else { - // REMOVE PARTITIONING - tblInfo.Partition.ClearReorgIntermediateInfo() - } - - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - job.FinishTableJob(model.JobStateRollbackDone, model.StateNone, ver, tblInfo) - job.Args = []any{physicalTableIDs} - return ver, nil - } - - var physicalTableIDs []int64 - // In order to skip maintaining the state check in partitionDefinition, TiDB use droppingDefinition instead of state field. 
- // So here using `job.SchemaState` to judge what the stage of this job is. - originalState := job.SchemaState - switch job.SchemaState { - case model.StatePublic: - // If an error occurs, it returns that it cannot delete all partitions or that the partition doesn't exist. - err = CheckDropTablePartition(tblInfo, partNames) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - physicalTableIDs = updateDroppingPartitionInfo(tblInfo, partNames) - err = dropLabelRules(w.ctx, job.SchemaName, tblInfo.Name.L, partNames) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Wrapf(err, "failed to notify PD the label rules") - } - - if _, err := alterTableLabelRule(job.SchemaName, tblInfo, getIDs([]*model.TableInfo{tblInfo})); err != nil { - job.State = model.JobStateCancelled - return ver, err - } - - var bundles []*placement.Bundle - // create placement groups for each dropped partition to keep the data's placement before GC - // These placements groups will be deleted after GC - bundles, err = droppedPartitionBundles(t, tblInfo, tblInfo.Partition.DroppingDefinitions) - if err != nil { - job.State = model.JobStateCancelled - return ver, err - } - - var tableBundle *placement.Bundle - // Recompute table bundle to remove dropped partitions rules from its group - tableBundle, err = placement.NewTableBundle(t, tblInfo) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - if tableBundle != nil { - bundles = append(bundles, tableBundle) - } - - if err = infosync.PutRuleBundlesWithDefaultRetry(context.TODO(), bundles); err != nil { - job.State = model.JobStateCancelled - return ver, err - } - - job.SchemaState = model.StateDeleteOnly - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != job.SchemaState) - case model.StateDeleteOnly: - // This state is not a real 'DeleteOnly' state, because tidb does not maintaining the state check in partitionDefinition. - // Insert this state to confirm all servers can not see the old partitions when reorg is running, - // so that no new data will be inserted into old partitions when reorganizing. - job.SchemaState = model.StateDeleteReorganization - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != job.SchemaState) - case model.StateDeleteReorganization: - oldTblInfo := getTableInfoWithDroppingPartitions(tblInfo) - physicalTableIDs = getPartitionIDsFromDefinitions(tblInfo.Partition.DroppingDefinitions) - tbl, err := getTable(d.getAutoIDRequirement(), job.SchemaID, oldTblInfo) - if err != nil { - return ver, errors.Trace(err) - } - dbInfo, err := t.GetDatabase(job.SchemaID) - if err != nil { - return ver, errors.Trace(err) - } - // If table has global indexes, we need reorg to clean up them. - if pt, ok := tbl.(table.PartitionedTable); ok && hasGlobalIndex(tblInfo) { - // Build elements for compatible with modify column type. elements will not be used when reorganizing. 
- elements := make([]*meta.Element, 0, len(tblInfo.Indices)) - for _, idxInfo := range tblInfo.Indices { - if idxInfo.Global { - elements = append(elements, &meta.Element{ID: idxInfo.ID, TypeKey: meta.IndexElementKey}) - } - } - sctx, err1 := w.sessPool.Get() - if err1 != nil { - return ver, err1 - } - defer w.sessPool.Put(sctx) - rh := newReorgHandler(sess.NewSession(sctx)) - reorgInfo, err := getReorgInfoFromPartitions(d.jobContext(job.ID, job.ReorgMeta), d, rh, job, dbInfo, pt, physicalTableIDs, elements) - - if err != nil || reorgInfo.first { - // If we run reorg firstly, we should update the job snapshot version - // and then run the reorg next time. - return ver, errors.Trace(err) - } - err = w.runReorgJob(reorgInfo, tbl.Meta(), d.lease, func() (dropIndexErr error) { - defer tidbutil.Recover(metrics.LabelDDL, "onDropTablePartition", - func() { - dropIndexErr = dbterror.ErrCancelledDDLJob.GenWithStack("drop partition panic") - }, false) - return w.cleanupGlobalIndexes(pt, physicalTableIDs, reorgInfo) - }) - if err != nil { - if dbterror.ErrWaitReorgTimeout.Equal(err) { - // if timeout, we should return, check for the owner and re-wait job done. - return ver, nil - } - if dbterror.ErrPausedDDLJob.Equal(err) { - // if ErrPausedDDLJob, we should return, check for the owner and re-wait job done. - return ver, nil - } - return ver, errors.Trace(err) - } - } - if tblInfo.TiFlashReplica != nil { - removeTiFlashAvailablePartitionIDs(tblInfo, physicalTableIDs) - } - droppedDefs := tblInfo.Partition.DroppingDefinitions - tblInfo.Partition.DroppingDefinitions = nil - // used by ApplyDiff in updateSchemaVersion - job.CtxVars = []any{physicalTableIDs} - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - job.SchemaState = model.StateNone - job.FinishTableJob(model.JobStateDone, model.StateNone, ver, tblInfo) - dropPartitionEvent := statsutil.NewDropPartitionEvent( - job.SchemaID, - tblInfo, - &model.PartitionInfo{Definitions: droppedDefs}, - ) - asyncNotifyEvent(d, dropPartitionEvent) - // A background job will be created to delete old partition data. - job.Args = []any{physicalTableIDs} - default: - err = dbterror.ErrInvalidDDLState.GenWithStackByArgs("partition", job.SchemaState) - } - return ver, errors.Trace(err) -} - -func removeTiFlashAvailablePartitionIDs(tblInfo *model.TableInfo, pids []int64) { - // Remove the partitions - ids := tblInfo.TiFlashReplica.AvailablePartitionIDs - // Rarely called, so OK to take some time, to make it easy - for _, id := range pids { - for i, avail := range ids { - if id == avail { - tmp := ids[:i] - tmp = append(tmp, ids[i+1:]...) - ids = tmp - break - } - } - } - tblInfo.TiFlashReplica.AvailablePartitionIDs = ids -} - -// onTruncateTablePartition truncates old partition meta. 
-func (w *worker) onTruncateTablePartition(d *ddlCtx, t *meta.Meta, job *model.Job) (int64, error) { - var ver int64 - var oldIDs, newIDs []int64 - if err := job.DecodeArgs(&oldIDs, &newIDs); err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - if len(oldIDs) != len(newIDs) { - job.State = model.JobStateCancelled - return ver, errors.Trace(errors.New("len(oldIDs) must be the same as len(newIDs)")) - } - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) - if err != nil { - return ver, errors.Trace(err) - } - pi := tblInfo.GetPartitionInfo() - if pi == nil { - return ver, errors.Trace(dbterror.ErrPartitionMgmtOnNonpartitioned) - } - - if !hasGlobalIndex(tblInfo) { - oldPartitions := make([]model.PartitionDefinition, 0, len(oldIDs)) - newPartitions := make([]model.PartitionDefinition, 0, len(oldIDs)) - for k, oldID := range oldIDs { - for i := 0; i < len(pi.Definitions); i++ { - def := &pi.Definitions[i] - if def.ID == oldID { - oldPartitions = append(oldPartitions, def.Clone()) - def.ID = newIDs[k] - // Shallow copy only use the def.ID in event handle. - newPartitions = append(newPartitions, *def) - break - } - } - } - if len(newPartitions) == 0 { - job.State = model.JobStateCancelled - return ver, table.ErrUnknownPartition.GenWithStackByArgs(fmt.Sprintf("pid:%v", oldIDs), tblInfo.Name.O) - } - - if err = clearTruncatePartitionTiflashStatus(tblInfo, newPartitions, oldIDs); err != nil { - job.State = model.JobStateCancelled - return ver, err - } - - if err = updateTruncatePartitionLabelRules(job, t, oldPartitions, newPartitions, tblInfo, oldIDs); err != nil { - job.State = model.JobStateCancelled - return ver, err - } - - preSplitAndScatter(w.sess.Context, d.store, tblInfo, newPartitions) - - job.CtxVars = []any{oldIDs, newIDs} - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - - // Finish this job. - job.FinishTableJob(model.JobStateDone, model.StateNone, ver, tblInfo) - truncatePartitionEvent := statsutil.NewTruncatePartitionEvent( - job.SchemaID, - tblInfo, - &model.PartitionInfo{Definitions: newPartitions}, - &model.PartitionInfo{Definitions: oldPartitions}, - ) - asyncNotifyEvent(d, truncatePartitionEvent) - // A background job will be created to delete old partition data. - job.Args = []any{oldIDs} - - return ver, err - } - - // When table has global index, public->deleteOnly->deleteReorg->none schema changes should be handled. - switch job.SchemaState { - case model.StatePublic: - // Step1: generate new partition ids - truncatingDefinitions := make([]model.PartitionDefinition, 0, len(oldIDs)) - for i, oldID := range oldIDs { - for j := 0; j < len(pi.Definitions); j++ { - def := &pi.Definitions[j] - if def.ID == oldID { - truncatingDefinitions = append(truncatingDefinitions, def.Clone()) - def.ID = newIDs[i] - break - } - } - } - pi.DroppingDefinitions = truncatingDefinitions - pi.NewPartitionIDs = newIDs[:] - - job.SchemaState = model.StateDeleteOnly - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - case model.StateDeleteOnly: - // This state is not a real 'DeleteOnly' state, because tidb does not maintaining the state check in partitionDefinition. - // Insert this state to confirm all servers can not see the old partitions when reorg is running, - // so that no new data will be inserted into old partitions when reorganizing. 
- job.SchemaState = model.StateDeleteReorganization - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - case model.StateDeleteReorganization: - // Step2: clear global index rows. - physicalTableIDs := oldIDs - oldTblInfo := getTableInfoWithOriginalPartitions(tblInfo, oldIDs, newIDs) - - tbl, err := getTable(d.getAutoIDRequirement(), job.SchemaID, oldTblInfo) - if err != nil { - return ver, errors.Trace(err) - } - dbInfo, err := t.GetDatabase(job.SchemaID) - if err != nil { - return ver, errors.Trace(err) - } - // If table has global indexes, we need reorg to clean up them. - if pt, ok := tbl.(table.PartitionedTable); ok && hasGlobalIndex(tblInfo) { - // Build elements for compatible with modify column type. elements will not be used when reorganizing. - elements := make([]*meta.Element, 0, len(tblInfo.Indices)) - for _, idxInfo := range tblInfo.Indices { - if idxInfo.Global { - elements = append(elements, &meta.Element{ID: idxInfo.ID, TypeKey: meta.IndexElementKey}) - } - } - sctx, err1 := w.sessPool.Get() - if err1 != nil { - return ver, err1 - } - defer w.sessPool.Put(sctx) - rh := newReorgHandler(sess.NewSession(sctx)) - reorgInfo, err := getReorgInfoFromPartitions(d.jobContext(job.ID, job.ReorgMeta), d, rh, job, dbInfo, pt, physicalTableIDs, elements) - - if err != nil || reorgInfo.first { - // If we run reorg firstly, we should update the job snapshot version - // and then run the reorg next time. - return ver, errors.Trace(err) - } - err = w.runReorgJob(reorgInfo, tbl.Meta(), d.lease, func() (dropIndexErr error) { - defer tidbutil.Recover(metrics.LabelDDL, "onDropTablePartition", - func() { - dropIndexErr = dbterror.ErrCancelledDDLJob.GenWithStack("drop partition panic") - }, false) - return w.cleanupGlobalIndexes(pt, physicalTableIDs, reorgInfo) - }) - if err != nil { - if dbterror.ErrWaitReorgTimeout.Equal(err) { - // if timeout, we should return, check for the owner and re-wait job done. - return ver, nil - } - return ver, errors.Trace(err) - } - } - - // Step3: generate new partition ids and finish rest works - oldPartitions := make([]model.PartitionDefinition, 0, len(oldIDs)) - newPartitions := make([]model.PartitionDefinition, 0, len(oldIDs)) - for _, oldDef := range pi.DroppingDefinitions { - var newID int64 - for i := range oldIDs { - if oldDef.ID == oldIDs[i] { - newID = newIDs[i] - break - } - } - for i := 0; i < len(pi.Definitions); i++ { - def := &pi.Definitions[i] - if newID == def.ID { - oldPartitions = append(oldPartitions, oldDef.Clone()) - newPartitions = append(newPartitions, def.Clone()) - break - } - } - } - if len(newPartitions) == 0 { - job.State = model.JobStateCancelled - return ver, table.ErrUnknownPartition.GenWithStackByArgs(fmt.Sprintf("pid:%v", oldIDs), tblInfo.Name.O) - } - - if err = clearTruncatePartitionTiflashStatus(tblInfo, newPartitions, oldIDs); err != nil { - job.State = model.JobStateCancelled - return ver, err - } - - if err = updateTruncatePartitionLabelRules(job, t, oldPartitions, newPartitions, tblInfo, oldIDs); err != nil { - job.State = model.JobStateCancelled - return ver, err - } - - // Step4: clear DroppingDefinitions and finish job. - tblInfo.Partition.DroppingDefinitions = nil - tblInfo.Partition.NewPartitionIDs = nil - - preSplitAndScatter(w.sess.Context, d.store, tblInfo, newPartitions) - - // used by ApplyDiff in updateSchemaVersion - job.CtxVars = []any{oldIDs, newIDs} - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - // Finish this job. 
- job.FinishTableJob(model.JobStateDone, model.StateNone, ver, tblInfo) - truncatePartitionEvent := statsutil.NewTruncatePartitionEvent( - job.SchemaID, - tblInfo, - &model.PartitionInfo{Definitions: newPartitions}, - &model.PartitionInfo{Definitions: oldPartitions}, - ) - asyncNotifyEvent(d, truncatePartitionEvent) - // A background job will be created to delete old partition data. - job.Args = []any{oldIDs} - default: - err = dbterror.ErrInvalidDDLState.GenWithStackByArgs("partition", job.SchemaState) - } - - return ver, errors.Trace(err) -} - -func clearTruncatePartitionTiflashStatus(tblInfo *model.TableInfo, newPartitions []model.PartitionDefinition, oldIDs []int64) error { - // Clear the tiflash replica available status. - if tblInfo.TiFlashReplica != nil { - e := infosync.ConfigureTiFlashPDForPartitions(true, &newPartitions, tblInfo.TiFlashReplica.Count, &tblInfo.TiFlashReplica.LocationLabels, tblInfo.ID) - failpoint.Inject("FailTiFlashTruncatePartition", func() { - e = errors.New("enforced error") - }) - if e != nil { - logutil.DDLLogger().Error("ConfigureTiFlashPDForPartitions fails", zap.Error(e)) - return e - } - tblInfo.TiFlashReplica.Available = false - // Set partition replica become unavailable. - removeTiFlashAvailablePartitionIDs(tblInfo, oldIDs) - } - return nil -} - -func updateTruncatePartitionLabelRules(job *model.Job, t *meta.Meta, oldPartitions, newPartitions []model.PartitionDefinition, tblInfo *model.TableInfo, oldIDs []int64) error { - bundles, err := placement.NewPartitionListBundles(t, newPartitions) - if err != nil { - return errors.Trace(err) - } - - tableBundle, err := placement.NewTableBundle(t, tblInfo) - if err != nil { - job.State = model.JobStateCancelled - return errors.Trace(err) - } - - if tableBundle != nil { - bundles = append(bundles, tableBundle) - } - - // create placement groups for each dropped partition to keep the data's placement before GC - // These placements groups will be deleted after GC - keepDroppedBundles, err := droppedPartitionBundles(t, tblInfo, oldPartitions) - if err != nil { - job.State = model.JobStateCancelled - return errors.Trace(err) - } - bundles = append(bundles, keepDroppedBundles...) 
- - err = infosync.PutRuleBundlesWithDefaultRetry(context.TODO(), bundles) - if err != nil { - return errors.Wrapf(err, "failed to notify PD the placement rules") - } - - tableID := fmt.Sprintf(label.TableIDFormat, label.IDPrefix, job.SchemaName, tblInfo.Name.L) - oldPartRules := make([]string, 0, len(oldIDs)) - for _, newPartition := range newPartitions { - oldPartRuleID := fmt.Sprintf(label.PartitionIDFormat, label.IDPrefix, job.SchemaName, tblInfo.Name.L, newPartition.Name.L) - oldPartRules = append(oldPartRules, oldPartRuleID) - } - - rules, err := infosync.GetLabelRules(context.TODO(), append(oldPartRules, tableID)) - if err != nil { - return errors.Wrapf(err, "failed to get label rules from PD") - } - - newPartIDs := getPartitionIDs(tblInfo) - newRules := make([]*label.Rule, 0, len(oldIDs)+1) - if tr, ok := rules[tableID]; ok { - newRules = append(newRules, tr.Clone().Reset(job.SchemaName, tblInfo.Name.L, "", append(newPartIDs, tblInfo.ID)...)) - } - - for idx, newPartition := range newPartitions { - if pr, ok := rules[oldPartRules[idx]]; ok { - newRules = append(newRules, pr.Clone().Reset(job.SchemaName, tblInfo.Name.L, newPartition.Name.L, newPartition.ID)) - } - } - - patch := label.NewRulePatch(newRules, []string{}) - err = infosync.UpdateLabelRules(context.TODO(), patch) - if err != nil { - return errors.Wrapf(err, "failed to notify PD the label rules") - } - - return nil -} - -// onExchangeTablePartition exchange partition data -func (w *worker) onExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - var ( - // defID only for updateSchemaVersion - defID int64 - ptSchemaID int64 - ptID int64 - partName string - withValidation bool - ) - - if err := job.DecodeArgs(&defID, &ptSchemaID, &ptID, &partName, &withValidation); err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - ntDbInfo, err := checkSchemaExistAndCancelNotExistJob(t, job) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - ptDbInfo, err := t.GetDatabase(ptSchemaID) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - nt, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) - if err != nil { - return ver, errors.Trace(err) - } - - if job.IsRollingback() { - return rollbackExchangeTablePartition(d, t, job, nt) - } - pt, err := getTableInfo(t, ptID, ptSchemaID) - if err != nil { - if infoschema.ErrDatabaseNotExists.Equal(err) || infoschema.ErrTableNotExists.Equal(err) { - job.State = model.JobStateCancelled - } - return ver, errors.Trace(err) - } - - _, partDef, err := getPartitionDef(pt, partName) - if err != nil { - return ver, errors.Trace(err) - } - if job.SchemaState == model.StateNone { - if pt.State != model.StatePublic { - job.State = model.JobStateCancelled - return ver, dbterror.ErrInvalidDDLState.GenWithStack("table %s is not in public, but %s", pt.Name, pt.State) - } - err = checkExchangePartition(pt, nt) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - err = checkTableDefCompatible(pt, nt) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - err = checkExchangePartitionPlacementPolicy(t, nt.PlacementPolicyRef, pt.PlacementPolicyRef, partDef.PlacementPolicyRef) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - if defID != partDef.ID { - logutil.DDLLogger().Info("Exchange partition id changed, updating to actual 
id", - zap.Stringer("job", job), zap.Int64("defID", defID), zap.Int64("partDef.ID", partDef.ID)) - job.Args[0] = partDef.ID - defID = partDef.ID - err = updateDDLJob2Table(w.sess, job, true) - if err != nil { - return ver, errors.Trace(err) - } - } - var ptInfo []schemaIDAndTableInfo - if len(nt.Constraints) > 0 { - pt.ExchangePartitionInfo = &model.ExchangePartitionInfo{ - ExchangePartitionTableID: nt.ID, - ExchangePartitionDefID: defID, - } - ptInfo = append(ptInfo, schemaIDAndTableInfo{ - schemaID: ptSchemaID, - tblInfo: pt, - }) - } - nt.ExchangePartitionInfo = &model.ExchangePartitionInfo{ - ExchangePartitionTableID: ptID, - ExchangePartitionDefID: defID, - } - // We need an interim schema version, - // so there are no non-matching rows inserted - // into the table using the schema version - // before the exchange is made. - job.SchemaState = model.StateWriteOnly - return updateVersionAndTableInfoWithCheck(d, t, job, nt, true, ptInfo...) - } - // From now on, nt (the non-partitioned table) has - // ExchangePartitionInfo set, meaning it is restricted - // to only allow writes that would match the - // partition to be exchange with. - // So we need to rollback that change, instead of just cancelling. - - if d.lease > 0 { - delayForAsyncCommit() - } - - if defID != partDef.ID { - // Should never happen, should have been updated above, in previous state! - logutil.DDLLogger().Error("Exchange partition id changed, updating to actual id", - zap.Stringer("job", job), zap.Int64("defID", defID), zap.Int64("partDef.ID", partDef.ID)) - job.Args[0] = partDef.ID - defID = partDef.ID - err = updateDDLJob2Table(w.sess, job, true) - if err != nil { - return ver, errors.Trace(err) - } - } - - if withValidation { - ntbl, err := getTable(d.getAutoIDRequirement(), job.SchemaID, nt) - if err != nil { - return ver, errors.Trace(err) - } - ptbl, err := getTable(d.getAutoIDRequirement(), ptSchemaID, pt) - if err != nil { - return ver, errors.Trace(err) - } - err = checkExchangePartitionRecordValidation(w, ptbl, ntbl, ptDbInfo.Name.L, ntDbInfo.Name.L, partName) - if err != nil { - job.State = model.JobStateRollingback - return ver, errors.Trace(err) - } - } - - // partition table auto IDs. - ptAutoIDs, err := t.GetAutoIDAccessors(ptSchemaID, ptID).Get() - if err != nil { - return ver, errors.Trace(err) - } - // non-partition table auto IDs. - ntAutoIDs, err := t.GetAutoIDAccessors(job.SchemaID, nt.ID).Get() - if err != nil { - return ver, errors.Trace(err) - } - - if pt.TiFlashReplica != nil { - for i, id := range pt.TiFlashReplica.AvailablePartitionIDs { - if id == partDef.ID { - pt.TiFlashReplica.AvailablePartitionIDs[i] = nt.ID - break - } - } - } - - // Recreate non-partition table meta info, - // by first delete it with the old table id - err = t.DropTableOrView(job.SchemaID, nt.ID) - if err != nil { - return ver, errors.Trace(err) - } - - // exchange table meta id - pt.ExchangePartitionInfo = nil - // Used below to update the partitioned table's stats meta. 
- originalPartitionDef := partDef.Clone() - originalNt := nt.Clone() - partDef.ID, nt.ID = nt.ID, partDef.ID - - err = t.UpdateTable(ptSchemaID, pt) - if err != nil { - return ver, errors.Trace(err) - } - - err = t.CreateTableOrView(job.SchemaID, nt) - if err != nil { - return ver, errors.Trace(err) - } - - failpoint.Inject("exchangePartitionErr", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(ver, errors.New("occur an error after updating partition id")) - } - }) - - // Set both tables to the maximum auto IDs between normal table and partitioned table. - // TODO: Fix the issue of big transactions during EXCHANGE PARTITION with AutoID. - // Similar to https://github.com/pingcap/tidb/issues/46904 - newAutoIDs := meta.AutoIDGroup{ - RowID: mathutil.Max(ptAutoIDs.RowID, ntAutoIDs.RowID), - IncrementID: mathutil.Max(ptAutoIDs.IncrementID, ntAutoIDs.IncrementID), - RandomID: mathutil.Max(ptAutoIDs.RandomID, ntAutoIDs.RandomID), - } - err = t.GetAutoIDAccessors(ptSchemaID, pt.ID).Put(newAutoIDs) - if err != nil { - return ver, errors.Trace(err) - } - err = t.GetAutoIDAccessors(job.SchemaID, nt.ID).Put(newAutoIDs) - if err != nil { - return ver, errors.Trace(err) - } - - failpoint.Inject("exchangePartitionAutoID", func(val failpoint.Value) { - if val.(bool) { - seCtx, err := w.sessPool.Get() - defer w.sessPool.Put(seCtx) - if err != nil { - failpoint.Return(ver, err) - } - se := sess.NewSession(seCtx) - _, err = se.Execute(context.Background(), "insert ignore into test.pt values (40000000)", "exchange_partition_test") - if err != nil { - failpoint.Return(ver, err) - } - } - }) - - // the follow code is a swap function for rules of two partitions - // though partitions has exchanged their ID, swap still take effect - - bundles, err := bundlesForExchangeTablePartition(t, pt, partDef, nt) - if err != nil { - return ver, errors.Trace(err) - } - - if err = infosync.PutRuleBundlesWithDefaultRetry(context.TODO(), bundles); err != nil { - return ver, errors.Wrapf(err, "failed to notify PD the placement rules") - } - - ntrID := fmt.Sprintf(label.TableIDFormat, label.IDPrefix, job.SchemaName, nt.Name.L) - ptrID := fmt.Sprintf(label.PartitionIDFormat, label.IDPrefix, job.SchemaName, pt.Name.L, partDef.Name.L) - - rules, err := infosync.GetLabelRules(context.TODO(), []string{ntrID, ptrID}) - if err != nil { - return 0, errors.Wrapf(err, "failed to get PD the label rules") - } - - ntr := rules[ntrID] - ptr := rules[ptrID] - - // This must be a bug, nt cannot be partitioned! 
- partIDs := getPartitionIDs(nt) - - var setRules []*label.Rule - var deleteRules []string - if ntr != nil && ptr != nil { - setRules = append(setRules, ntr.Clone().Reset(job.SchemaName, pt.Name.L, partDef.Name.L, partDef.ID)) - setRules = append(setRules, ptr.Clone().Reset(job.SchemaName, nt.Name.L, "", append(partIDs, nt.ID)...)) - } else if ptr != nil { - setRules = append(setRules, ptr.Clone().Reset(job.SchemaName, nt.Name.L, "", append(partIDs, nt.ID)...)) - // delete ptr - deleteRules = append(deleteRules, ptrID) - } else if ntr != nil { - setRules = append(setRules, ntr.Clone().Reset(job.SchemaName, pt.Name.L, partDef.Name.L, partDef.ID)) - // delete ntr - deleteRules = append(deleteRules, ntrID) - } - - patch := label.NewRulePatch(setRules, deleteRules) - err = infosync.UpdateLabelRules(context.TODO(), patch) - if err != nil { - return ver, errors.Wrapf(err, "failed to notify PD the label rules") - } - - job.SchemaState = model.StatePublic - nt.ExchangePartitionInfo = nil - ver, err = updateVersionAndTableInfoWithCheck(d, t, job, nt, true) - if err != nil { - return ver, errors.Trace(err) - } - - job.FinishTableJob(model.JobStateDone, model.StateNone, ver, pt) - exchangePartitionEvent := statsutil.NewExchangePartitionEvent( - job.SchemaID, - pt, - &model.PartitionInfo{Definitions: []model.PartitionDefinition{originalPartitionDef}}, - originalNt, - ) - asyncNotifyEvent(d, exchangePartitionEvent) - return ver, nil -} - -func getReorgPartitionInfo(t *meta.Meta, job *model.Job) (*model.TableInfo, []string, *model.PartitionInfo, []model.PartitionDefinition, []model.PartitionDefinition, error) { - schemaID := job.SchemaID - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) - if err != nil { - return nil, nil, nil, nil, nil, errors.Trace(err) - } - partInfo := &model.PartitionInfo{} - var partNames []string - err = job.DecodeArgs(&partNames, &partInfo) - if err != nil { - job.State = model.JobStateCancelled - return nil, nil, nil, nil, nil, errors.Trace(err) - } - var addingDefs, droppingDefs []model.PartitionDefinition - if tblInfo.Partition != nil { - addingDefs = tblInfo.Partition.AddingDefinitions - droppingDefs = tblInfo.Partition.DroppingDefinitions - tblInfo.Partition.NewTableID = partInfo.NewTableID - tblInfo.Partition.DDLType = partInfo.Type - tblInfo.Partition.DDLExpr = partInfo.Expr - tblInfo.Partition.DDLColumns = partInfo.Columns - } else { - tblInfo.Partition = getPartitionInfoTypeNone() - tblInfo.Partition.NewTableID = partInfo.NewTableID - tblInfo.Partition.Definitions[0].ID = tblInfo.ID - tblInfo.Partition.DDLType = partInfo.Type - tblInfo.Partition.DDLExpr = partInfo.Expr - tblInfo.Partition.DDLColumns = partInfo.Columns - } - if len(addingDefs) == 0 { - addingDefs = []model.PartitionDefinition{} - } - if len(droppingDefs) == 0 { - droppingDefs = []model.PartitionDefinition{} - } - return tblInfo, partNames, partInfo, droppingDefs, addingDefs, nil -} - -// onReorganizePartition reorganized the partitioning of a table including its indexes. -// ALTER TABLE t REORGANIZE PARTITION p0 [, p1...] INTO (PARTITION p0 ...) -// -// Takes one set of partitions and copies the data to a newly defined set of partitions -// -// ALTER TABLE t REMOVE PARTITIONING -// -// Makes a partitioned table non-partitioned, by first collapsing all partitions into a -// single partition and then converts that partition to a non-partitioned table -// -// ALTER TABLE t PARTITION BY ... 
-// -// Changes the partitioning to the newly defined partitioning type and definitions, -// works for both partitioned and non-partitioned tables. -// If the table is non-partitioned, then it will first convert it to a partitioned -// table with a single partition, i.e. the full table as a single partition. -// -// job.SchemaState goes through the following SchemaState(s): -// StateNone -> StateDeleteOnly -> StateWriteOnly -> StateWriteReorganization -// -> StateDeleteOrganization -> StatePublic -// There are more details embedded in the implementation, but the high level changes are: -// StateNone -> StateDeleteOnly: -// -// Various checks and validations. -// Add possible new unique/global indexes. They share the same state as job.SchemaState -// until end of StateWriteReorganization -> StateDeleteReorganization. -// Set DroppingDefinitions and AddingDefinitions. -// So both the new partitions and new indexes will be included in new delete/update DML. -// -// StateDeleteOnly -> StateWriteOnly: -// -// So both the new partitions and new indexes will be included also in update/insert DML. -// -// StateWriteOnly -> StateWriteReorganization: -// -// To make sure that when we are reorganizing the data, -// both the old and new partitions/indexes will be updated. -// -// StateWriteReorganization -> StateDeleteOrganization: -// -// Here is where all data is reorganized, both partitions and indexes. -// It copies all data from the old set of partitions into the new set of partitions, -// and creates the local indexes on the new set of partitions, -// and if new unique indexes are added, it also updates them with the rest of data from -// the non-touched partitions. -// For indexes that are to be replaced with new ones (old/new global index), -// mark the old indexes as StateDeleteReorganization and new ones as StatePublic -// Finally make the table visible with the new partition definitions. -// I.e. in this state clients will read from the old set of partitions, -// and will read the new set of partitions in StateDeleteReorganization. -// -// StateDeleteOrganization -> StatePublic: -// -// Now all heavy lifting is done, and we just need to finalize and drop things, while still doing -// double writes, since previous state sees the old partitions/indexes. -// Remove the old indexes and old partitions from the TableInfo. -// Add the old indexes and old partitions to the queue for later cleanup (see delete_range.go). -// Queue new partitions for statistics update. -// if ALTER TABLE t PARTITION BY/REMOVE PARTITIONING: -// Recreate the table with the new TableID, by DropTableOrView+CreateTableOrView -// -// StatePublic: -// -// Everything now looks as it should, no memory of old partitions/indexes, -// and no more double writing, since the previous state is only reading the new partitions/indexes. -func (w *worker) onReorganizePartition(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - // Handle the rolling back job - if job.IsRollingback() { - ver, err := w.onDropTablePartition(d, t, job) - if err != nil { - return ver, errors.Trace(err) - } - return ver, nil - } - - tblInfo, partNames, partInfo, _, addingDefinitions, err := getReorgPartitionInfo(t, job) - if err != nil { - return ver, err - } - - switch job.SchemaState { - case model.StateNone: - // job.SchemaState == model.StateNone means the job is in the initial state of reorg partition. - // Here should use partInfo from job directly and do some check action. 
- // In case there was a race for queueing different schema changes on the same - // table and the checks was not done on the current schema version. - // The partInfo may have been checked against an older schema version for example. - // If the check is done here, it does not need to be repeated, since no other - // DDL on the same table can be run concurrently. - num := len(partInfo.Definitions) - len(partNames) + len(tblInfo.Partition.Definitions) - err = checkAddPartitionTooManyPartitions(uint64(num)) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - err = checkReorgPartitionNames(tblInfo.Partition, partNames, partInfo) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - // Re-check that the dropped/added partitions are compatible with current definition - firstPartIdx, lastPartIdx, idMap, err := getReplacedPartitionIDs(partNames, tblInfo.Partition) - if err != nil { - job.State = model.JobStateCancelled - return ver, err - } - sctx := w.sess.Context - if err = checkReorgPartitionDefs(sctx, job.Type, tblInfo, partInfo, firstPartIdx, lastPartIdx, idMap); err != nil { - job.State = model.JobStateCancelled - return ver, err - } - - // move the adding definition into tableInfo. - updateAddingPartitionInfo(partInfo, tblInfo) - orgDefs := tblInfo.Partition.Definitions - _ = updateDroppingPartitionInfo(tblInfo, partNames) - // Reset original partitions, and keep DroppedDefinitions - tblInfo.Partition.Definitions = orgDefs - - // modify placement settings - for _, def := range tblInfo.Partition.AddingDefinitions { - if _, err = checkPlacementPolicyRefValidAndCanNonValidJob(t, job, def.PlacementPolicyRef); err != nil { - // job.State = model.JobStateCancelled may be set depending on error in function above. - return ver, errors.Trace(err) - } - } - - // All global indexes must be recreated, we cannot update them in-place, since we must have - // both old and new set of partition ids in the unique index at the same time! - for _, index := range tblInfo.Indices { - if !index.Unique { - // for now, only unique index can be global, non-unique indexes are 'local' - continue - } - inAllPartitionColumns, err := checkPartitionKeysConstraint(partInfo, index.Columns, tblInfo) - if err != nil { - return ver, errors.Trace(err) - } - if index.Global || !inAllPartitionColumns { - // Duplicate the unique indexes with new index ids. - // If previously was Global or will be Global: - // it must be recreated with new index ID - newIndex := index.Clone() - newIndex.State = model.StateDeleteOnly - newIndex.ID = AllocateIndexID(tblInfo) - if inAllPartitionColumns { - newIndex.Global = false - } else { - // If not including all partitioning columns, make it Global - newIndex.Global = true - } - tblInfo.Indices = append(tblInfo.Indices, newIndex) - } - } - // From now on we cannot just cancel the DDL, we must roll back if changesMade! - changesMade := false - if tblInfo.TiFlashReplica != nil { - // Must set placement rule, and make sure it succeeds. 
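// A minimal sketch of the decision above, assuming plain string slices for the
// column names (hypothetical helper, not part of this patch): a recreated
// unique index becomes global exactly when it does not include every
// partitioning column.
func needsGlobalIndex(indexCols, partCols []string) bool {
	have := make(map[string]bool, len(indexCols))
	for _, c := range indexCols {
		have[c] = true
	}
	for _, c := range partCols {
		if !have[c] {
			return true // a partitioning column is missing from the unique key
		}
	}
	return false
}

// Example: UNIQUE KEY(a) on a table partitioned by HASH(b):
//   needsGlobalIndex([]string{"a"}, []string{"b"}) == true       -> rebuilt with Global = true
//   needsGlobalIndex([]string{"a", "b"}, []string{"b"}) == false -> stays a local index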
- if err := infosync.ConfigureTiFlashPDForPartitions(true, &tblInfo.Partition.AddingDefinitions, tblInfo.TiFlashReplica.Count, &tblInfo.TiFlashReplica.LocationLabels, tblInfo.ID); err != nil { - logutil.DDLLogger().Error("ConfigureTiFlashPDForPartitions fails", zap.Error(err)) - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - changesMade = true - // In the next step, StateDeleteOnly, wait to verify the TiFlash replicas are OK - } - - bundles, err := alterTablePartitionBundles(t, tblInfo, tblInfo.Partition.AddingDefinitions) - if err != nil { - if !changesMade { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - return convertAddTablePartitionJob2RollbackJob(d, t, job, err, tblInfo) - } - - if len(bundles) > 0 { - if err = infosync.PutRuleBundlesWithDefaultRetry(context.TODO(), bundles); err != nil { - if !changesMade { - job.State = model.JobStateCancelled - return ver, errors.Wrapf(err, "failed to notify PD the placement rules") - } - return convertAddTablePartitionJob2RollbackJob(d, t, job, err, tblInfo) - } - changesMade = true - } - - ids := getIDs([]*model.TableInfo{tblInfo}) - for _, p := range tblInfo.Partition.AddingDefinitions { - ids = append(ids, p.ID) - } - changed, err := alterTableLabelRule(job.SchemaName, tblInfo, ids) - changesMade = changesMade || changed - if err != nil { - if !changesMade { - job.State = model.JobStateCancelled - return ver, err - } - return convertAddTablePartitionJob2RollbackJob(d, t, job, err, tblInfo) - } - - // Doing the preSplitAndScatter here, since all checks are completed, - // and we will soon start writing to the new partitions. - if s, ok := d.store.(kv.SplittableStore); ok && s != nil { - // partInfo only contains the AddingPartitions - splitPartitionTableRegion(w.sess.Context, s, tblInfo, partInfo.Definitions, true) - } - - // Assume we cannot have more than MaxUint64 rows, set the progress to 1/10 of that. - metrics.GetBackfillProgressByLabel(metrics.LblReorgPartition, job.SchemaName, tblInfo.Name.String()).Set(0.1 / float64(math.MaxUint64)) - job.SchemaState = model.StateDeleteOnly - tblInfo.Partition.DDLState = model.StateDeleteOnly - ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - - // Is really both StateDeleteOnly AND StateWriteOnly needed? - // If transaction A in WriteOnly inserts row 1 (into both new and old partition set) - // and then transaction B in DeleteOnly deletes that row (in both new and old) - // does really transaction B need to do the delete in the new partition? - // Yes, otherwise it would still be there when the WriteReorg happens, - // and WriteReorg would only copy existing rows to the new table, so unless it is - // deleted it would result in a ghost row! - // What about update then? - // Updates also need to be handled for new partitions in DeleteOnly, - // since it would not be overwritten during Reorganize phase. - // BUT if the update results in adding in one partition and deleting in another, - // THEN only the delete must happen in the new partition set, not the insert! - case model.StateDeleteOnly: - // This state is to confirm all servers can not see the new partitions when reorg is running, - // so that all deletes will be done in both old and new partitions when in either DeleteOnly - // or WriteOnly state. 
- // Also using the state for checking that the optional TiFlash replica is available, making it - // in a state without (much) data and easy to retry without side effects. - - // Reason for having it here, is to make it easy for retry, and better to make sure it is in-sync - // as early as possible, to avoid a long wait after the data copying. - if tblInfo.TiFlashReplica != nil && tblInfo.TiFlashReplica.Available { - // For available state, the new added partition should wait its replica to - // be finished, otherwise the query to this partition will be blocked. - count := tblInfo.TiFlashReplica.Count - needRetry, err := checkPartitionReplica(count, addingDefinitions, d) - if err != nil { - // need to rollback, since we tried to register the new - // partitions before! - return convertAddTablePartitionJob2RollbackJob(d, t, job, err, tblInfo) - } - if needRetry { - // The new added partition hasn't been replicated. - // Do nothing to the job this time, wait next worker round. - time.Sleep(tiflashCheckTiDBHTTPAPIHalfInterval) - // Set the error here which will lead this job exit when it's retry times beyond the limitation. - return ver, errors.Errorf("[ddl] add partition wait for tiflash replica to complete") - } - - // When TiFlash Replica is ready, we must move them into `AvailablePartitionIDs`. - // Since onUpdateFlashReplicaStatus cannot see the partitions yet (not public) - for _, d := range addingDefinitions { - tblInfo.TiFlashReplica.AvailablePartitionIDs = append(tblInfo.TiFlashReplica.AvailablePartitionIDs, d.ID) - } - } - - for i := range tblInfo.Indices { - if tblInfo.Indices[i].Unique && tblInfo.Indices[i].State == model.StateDeleteOnly { - tblInfo.Indices[i].State = model.StateWriteOnly - } - } - tblInfo.Partition.DDLState = model.StateWriteOnly - metrics.GetBackfillProgressByLabel(metrics.LblReorgPartition, job.SchemaName, tblInfo.Name.String()).Set(0.2 / float64(math.MaxUint64)) - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - job.SchemaState = model.StateWriteOnly - case model.StateWriteOnly: - // Insert this state to confirm all servers can see the new partitions when reorg is running, - // so that new data will be updated in both old and new partitions when reorganizing. 
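// Rough summary of the double-write overlap discussed above (descriptive note,
// not code from this patch): at any time two adjacent schema versions are live,
// so per state the DML applied to the new (Adding) partitions is roughly:
//
//   delete-only           deletes/updates remove rows there, inserts do not add rows
//   write-only            inserts, updates and deletes all maintained there
//   write-reorganization  same as write-only, plus the background data copy
//
// Without delete-only, a node one version behind could miss a delete for a row
// that another node already wrote to the new partitions, leaving a ghost row
// after the copy.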
- job.SnapshotVer = 0 - for i := range tblInfo.Indices { - if tblInfo.Indices[i].Unique && tblInfo.Indices[i].State == model.StateWriteOnly { - tblInfo.Indices[i].State = model.StateWriteReorganization - } - } - tblInfo.Partition.DDLState = model.StateWriteReorganization - metrics.GetBackfillProgressByLabel(metrics.LblReorgPartition, job.SchemaName, tblInfo.Name.String()).Set(0.3 / float64(math.MaxUint64)) - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - job.SchemaState = model.StateWriteReorganization - case model.StateWriteReorganization: - physicalTableIDs := getPartitionIDsFromDefinitions(tblInfo.Partition.DroppingDefinitions) - tbl, err2 := getTable(d.getAutoIDRequirement(), job.SchemaID, tblInfo) - if err2 != nil { - return ver, errors.Trace(err2) - } - var done bool - done, ver, err = doPartitionReorgWork(w, d, t, job, tbl, physicalTableIDs) - - if !done { - return ver, err - } - - for _, index := range tblInfo.Indices { - if !index.Unique { - continue - } - switch index.State { - case model.StateWriteReorganization: - // Newly created index, replacing old unique/global index - index.State = model.StatePublic - case model.StatePublic: - if index.Global { - // Mark the old global index as non-readable, and to be dropped - index.State = model.StateDeleteReorganization - } else { - inAllPartitionColumns, err := checkPartitionKeysConstraint(partInfo, index.Columns, tblInfo) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - if !inAllPartitionColumns { - // Mark the old unique index as non-readable, and to be dropped, - // since it is replaced by a global index - index.State = model.StateDeleteReorganization - } - } - } - } - firstPartIdx, lastPartIdx, idMap, err2 := getReplacedPartitionIDs(partNames, tblInfo.Partition) - failpoint.Inject("reorgPartWriteReorgReplacedPartIDsFail", func(val failpoint.Value) { - if val.(bool) { - err2 = errors.New("Injected error by reorgPartWriteReorgReplacedPartIDsFail") - } - }) - if err2 != nil { - return ver, err2 - } - newDefs := getReorganizedDefinitions(tblInfo.Partition, firstPartIdx, lastPartIdx, idMap) - - // From now on, use the new partitioning, but keep the Adding and Dropping for double write - tblInfo.Partition.Definitions = newDefs - tblInfo.Partition.Num = uint64(len(newDefs)) - if job.Type == model.ActionAlterTablePartitioning || - job.Type == model.ActionRemovePartitioning { - tblInfo.Partition.Type, tblInfo.Partition.DDLType = tblInfo.Partition.DDLType, tblInfo.Partition.Type - tblInfo.Partition.Expr, tblInfo.Partition.DDLExpr = tblInfo.Partition.DDLExpr, tblInfo.Partition.Expr - tblInfo.Partition.Columns, tblInfo.Partition.DDLColumns = tblInfo.Partition.DDLColumns, tblInfo.Partition.Columns - } - - // Now all the data copying is done, but we cannot simply remove the droppingDefinitions - // since they are a part of the normal Definitions that other nodes with - // the current schema version. 
So we need to double write for one more schema version - tblInfo.Partition.DDLState = model.StateDeleteReorganization - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - job.SchemaState = model.StateDeleteReorganization - - case model.StateDeleteReorganization: - // Drop the droppingDefinitions and finish the DDL - // This state is needed for the case where client A sees the schema - // with version of StateWriteReorg and would not see updates of - // client B that writes to the new partitions, previously - // addingDefinitions, since it would not double write to - // the droppingDefinitions during this time - // By adding StateDeleteReorg state, client B will write to both - // the new (previously addingDefinitions) AND droppingDefinitions - - // Register the droppingDefinitions ids for rangeDelete - // and the addingDefinitions for handling in the updateSchemaVersion - physicalTableIDs := getPartitionIDsFromDefinitions(tblInfo.Partition.DroppingDefinitions) - newIDs := getPartitionIDsFromDefinitions(partInfo.Definitions) - statisticsPartInfo := &model.PartitionInfo{Definitions: tblInfo.Partition.AddingDefinitions} - droppedPartInfo := &model.PartitionInfo{Definitions: tblInfo.Partition.DroppingDefinitions} - - tblInfo.Partition.DroppingDefinitions = nil - tblInfo.Partition.AddingDefinitions = nil - tblInfo.Partition.DDLState = model.StateNone - - var dropIndices []*model.IndexInfo - for _, indexInfo := range tblInfo.Indices { - if indexInfo.Unique && indexInfo.State == model.StateDeleteReorganization { - // Drop the old unique (possible global) index, see onDropIndex - indexInfo.State = model.StateNone - DropIndexColumnFlag(tblInfo, indexInfo) - RemoveDependentHiddenColumns(tblInfo, indexInfo) - dropIndices = append(dropIndices, indexInfo) - } - } - for _, indexInfo := range dropIndices { - removeIndexInfo(tblInfo, indexInfo) - } - var oldTblID int64 - if job.Type != model.ActionReorganizePartition { - // ALTER TABLE ... PARTITION BY - // REMOVE PARTITIONING - // Storing the old table ID, used for updating statistics. - oldTblID = tblInfo.ID - // TODO: Handle bundles? - // TODO: Add concurrent test! - // TODO: Will this result in big gaps? - // TODO: How to carrie over AUTO_INCREMENT etc.? - // Check if they are carried over in ApplyDiff?!? - autoIDs, err := t.GetAutoIDAccessors(job.SchemaID, tblInfo.ID).Get() - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - err = t.DropTableOrView(job.SchemaID, tblInfo.ID) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - tblInfo.ID = partInfo.NewTableID - if partInfo.DDLType != model.PartitionTypeNone { - // if partitioned before, then also add the old table ID, - // otherwise it will be the already included first partition - physicalTableIDs = append(physicalTableIDs, oldTblID) - } - if job.Type == model.ActionRemovePartitioning { - tblInfo.Partition = nil - } else { - // ALTER TABLE ... PARTITION BY - tblInfo.Partition.ClearReorgIntermediateInfo() - } - err = t.GetAutoIDAccessors(job.SchemaID, tblInfo.ID).Put(autoIDs) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - // TODO: Add failpoint here? 
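// A toy sketch of the auto-ID carry-over above, assuming a plain map as the
// allocator store (hypothetical helper, not part of this patch): the counters
// are read under the old table ID and re-registered under the new one, so
// AUTO_INCREMENT/AUTO_RANDOM continue where they left off after
// REMOVE PARTITIONING / PARTITION BY re-creates the table.
func carryOverAutoIDs(store map[int64]int64, oldTblID, newTblID int64) {
	base := store[oldTblID] // read the counters kept under the old table ID
	delete(store, oldTblID) // the old entry goes away with the dropped table
	store[newTblID] = base  // continue allocation under the new table ID
}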
- err = t.CreateTableOrView(job.SchemaID, tblInfo) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - } - job.CtxVars = []any{physicalTableIDs, newIDs} - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - failpoint.Inject("reorgPartWriteReorgSchemaVersionUpdateFail", func(val failpoint.Value) { - if val.(bool) { - err = errors.New("Injected error by reorgPartWriteReorgSchemaVersionUpdateFail") - } - }) - if err != nil { - return ver, errors.Trace(err) - } - job.FinishTableJob(model.JobStateDone, model.StateNone, ver, tblInfo) - // How to handle this? - // Seems to only trigger asynchronous update of statistics. - // Should it actually be synchronous? - // Include the old table ID, if changed, which may contain global statistics, - // so it can be reused for the new (non)partitioned table. - event, err := newStatsDDLEventForJob( - job.SchemaID, - job.Type, oldTblID, tblInfo, statisticsPartInfo, droppedPartInfo, - ) - if err != nil { - return ver, errors.Trace(err) - } - asyncNotifyEvent(d, event) - // A background job will be created to delete old partition data. - job.Args = []any{physicalTableIDs} - - default: - err = dbterror.ErrInvalidDDLState.GenWithStackByArgs("partition", job.SchemaState) - } - - return ver, errors.Trace(err) -} - -// newStatsDDLEventForJob creates a statsutil.DDLEvent for a job. -// It is used for reorganize partition, add partitioning and remove partitioning. -func newStatsDDLEventForJob( - schemaID int64, - jobType model.ActionType, - oldTblID int64, - tblInfo *model.TableInfo, - addedPartInfo *model.PartitionInfo, - droppedPartInfo *model.PartitionInfo, -) (*statsutil.DDLEvent, error) { - var event *statsutil.DDLEvent - switch jobType { - case model.ActionReorganizePartition: - event = statsutil.NewReorganizePartitionEvent( - schemaID, - tblInfo, - addedPartInfo, - droppedPartInfo, - ) - case model.ActionAlterTablePartitioning: - event = statsutil.NewAddPartitioningEvent( - schemaID, - oldTblID, - tblInfo, - addedPartInfo, - ) - case model.ActionRemovePartitioning: - event = statsutil.NewRemovePartitioningEvent( - schemaID, - oldTblID, - tblInfo, - droppedPartInfo, - ) - default: - return nil, errors.Errorf("unknown job type: %s", jobType.String()) - } - return event, nil -} - -func doPartitionReorgWork(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job, tbl table.Table, physTblIDs []int64) (done bool, ver int64, err error) { - job.ReorgMeta.ReorgTp = model.ReorgTypeTxn - sctx, err1 := w.sessPool.Get() - if err1 != nil { - return done, ver, err1 - } - defer w.sessPool.Put(sctx) - rh := newReorgHandler(sess.NewSession(sctx)) - indices := make([]*model.IndexInfo, 0, len(tbl.Meta().Indices)) - for _, index := range tbl.Meta().Indices { - if index.Global && index.State == model.StatePublic { - // Skip old global indexes, but rebuild all other indexes - continue - } - indices = append(indices, index) - } - elements := BuildElements(tbl.Meta().Columns[0], indices) - partTbl, ok := tbl.(table.PartitionedTable) - if !ok { - return false, ver, dbterror.ErrUnsupportedReorganizePartition.GenWithStackByArgs() - } - dbInfo, err := t.GetDatabase(job.SchemaID) - if err != nil { - return false, ver, errors.Trace(err) - } - reorgInfo, err := getReorgInfoFromPartitions(d.jobContext(job.ID, job.ReorgMeta), d, rh, job, dbInfo, partTbl, physTblIDs, elements) - err = w.runReorgJob(reorgInfo, tbl.Meta(), d.lease, func() (reorgErr error) { - defer tidbutil.Recover(metrics.LabelDDL, "doPartitionReorgWork", - func() { - reorgErr = 
dbterror.ErrCancelledDDLJob.GenWithStack("reorganize partition for table `%v` panic", tbl.Meta().Name) - }, false) - return w.reorgPartitionDataAndIndex(tbl, reorgInfo) - }) - if err != nil { - if dbterror.ErrPausedDDLJob.Equal(err) { - return false, ver, nil - } - - if dbterror.ErrWaitReorgTimeout.Equal(err) { - // If timeout, we should return, check for the owner and re-wait job done. - return false, ver, nil - } - if kv.IsTxnRetryableError(err) { - return false, ver, errors.Trace(err) - } - if err1 := rh.RemoveDDLReorgHandle(job, reorgInfo.elements); err1 != nil { - logutil.DDLLogger().Warn("reorg partition job failed, RemoveDDLReorgHandle failed, can't convert job to rollback", - zap.Stringer("job", job), zap.Error(err1)) - } - logutil.DDLLogger().Warn("reorg partition job failed, convert job to rollback", zap.Stringer("job", job), zap.Error(err)) - // TODO: rollback new global indexes! TODO: How to handle new index ids? - ver, err = convertAddTablePartitionJob2RollbackJob(d, t, job, err, tbl.Meta()) - return false, ver, errors.Trace(err) - } - return true, ver, err -} - -type reorgPartitionWorker struct { - *backfillCtx - // Static allocated to limit memory allocations - rowRecords []*rowRecord - rowDecoder *decoder.RowDecoder - rowMap map[int64]types.Datum - writeColOffsetMap map[int64]int - maxOffset int - reorgedTbl table.PartitionedTable -} - -func newReorgPartitionWorker(i int, t table.PhysicalTable, decodeColMap map[int64]decoder.Column, reorgInfo *reorgInfo, jc *JobContext) (*reorgPartitionWorker, error) { - bCtx, err := newBackfillCtx(i, reorgInfo, reorgInfo.SchemaName, t, jc, "reorg_partition_rate", false) - if err != nil { - return nil, err - } - reorgedTbl, err := tables.GetReorganizedPartitionedTable(t) - if err != nil { - return nil, errors.Trace(err) - } - pt := t.GetPartitionedTable() - if pt == nil { - return nil, dbterror.ErrUnsupportedReorganizePartition.GenWithStackByArgs() - } - partColIDs := reorgedTbl.GetPartitionColumnIDs() - writeColOffsetMap := make(map[int64]int, len(partColIDs)) - maxOffset := 0 - for _, id := range partColIDs { - var offset int - for _, col := range pt.Cols() { - if col.ID == id { - offset = col.Offset - break - } - } - writeColOffsetMap[id] = offset - maxOffset = mathutil.Max[int](maxOffset, offset) - } - return &reorgPartitionWorker{ - backfillCtx: bCtx, - rowDecoder: decoder.NewRowDecoder(t, t.WritableCols(), decodeColMap), - rowMap: make(map[int64]types.Datum, len(decodeColMap)), - writeColOffsetMap: writeColOffsetMap, - maxOffset: maxOffset, - reorgedTbl: reorgedTbl, - }, nil -} - -func (w *reorgPartitionWorker) BackfillData(handleRange reorgBackfillTask) (taskCtx backfillTaskContext, errInTxn error) { - oprStartTime := time.Now() - ctx := kv.WithInternalSourceAndTaskType(context.Background(), w.jobContext.ddlJobSourceType(), kvutil.ExplicitTypeDDL) - errInTxn = kv.RunInNewTxn(ctx, w.ddlCtx.store, true, func(_ context.Context, txn kv.Transaction) error { - taskCtx.addedCount = 0 - taskCtx.scanCount = 0 - updateTxnEntrySizeLimitIfNeeded(txn) - txn.SetOption(kv.Priority, handleRange.priority) - if tagger := w.GetCtx().getResourceGroupTaggerForTopSQL(handleRange.getJobID()); tagger != nil { - txn.SetOption(kv.ResourceGroupTagger, tagger) - } - txn.SetOption(kv.ResourceGroupName, w.jobContext.resourceGroupName) - - rowRecords, nextKey, taskDone, err := w.fetchRowColVals(txn, handleRange) - if err != nil { - return errors.Trace(err) - } - taskCtx.nextKey = nextKey - taskCtx.done = taskDone - - warningsMap := 
make(map[errors.ErrorID]*terror.Error) - warningsCountMap := make(map[errors.ErrorID]int64) - for _, prr := range rowRecords { - taskCtx.scanCount++ - - err = txn.Set(prr.key, prr.vals) - if err != nil { - return errors.Trace(err) - } - taskCtx.addedCount++ - if prr.warning != nil { - if _, ok := warningsCountMap[prr.warning.ID()]; ok { - warningsCountMap[prr.warning.ID()]++ - } else { - warningsCountMap[prr.warning.ID()] = 1 - warningsMap[prr.warning.ID()] = prr.warning - } - } - // TODO: Future optimization: also write the indexes here? - // What if the transaction limit is just enough for a single row, without index? - // Hmm, how could that be in the first place? - // For now, implement the batch-txn w.addTableIndex, - // since it already exists and is in use - } - - // Collect the warnings. - taskCtx.warnings, taskCtx.warningsCount = warningsMap, warningsCountMap - - // also add the index entries here? And make sure they are not added somewhere else - - return nil - }) - logSlowOperations(time.Since(oprStartTime), "BackfillData", 3000) - - return -} - -func (w *reorgPartitionWorker) fetchRowColVals(txn kv.Transaction, taskRange reorgBackfillTask) ([]*rowRecord, kv.Key, bool, error) { - w.rowRecords = w.rowRecords[:0] - startTime := time.Now() - - // taskDone means that the added handle is out of taskRange.endHandle. - taskDone := false - sysTZ := w.loc - - tmpRow := make([]types.Datum, w.maxOffset+1) - var lastAccessedHandle kv.Key - oprStartTime := startTime - err := iterateSnapshotKeys(w.jobContext, w.ddlCtx.store, taskRange.priority, w.table.RecordPrefix(), txn.StartTS(), taskRange.startKey, taskRange.endKey, - func(handle kv.Handle, recordKey kv.Key, rawRow []byte) (bool, error) { - oprEndTime := time.Now() - logSlowOperations(oprEndTime.Sub(oprStartTime), "iterateSnapshotKeys in reorgPartitionWorker fetchRowColVals", 0) - oprStartTime = oprEndTime - - taskDone = recordKey.Cmp(taskRange.endKey) >= 0 - - if taskDone || len(w.rowRecords) >= w.batchCnt { - return false, nil - } - - _, err := w.rowDecoder.DecodeTheExistedColumnMap(w.exprCtx, handle, rawRow, sysTZ, w.rowMap) - if err != nil { - return false, errors.Trace(err) - } - - // Set the partitioning columns and calculate which partition to write to - for colID, offset := range w.writeColOffsetMap { - d, ok := w.rowMap[colID] - if !ok { - return false, dbterror.ErrUnsupportedReorganizePartition.GenWithStackByArgs() - } - tmpRow[offset] = d - } - p, err := w.reorgedTbl.GetPartitionByRow(w.exprCtx.GetEvalCtx(), tmpRow) - if err != nil { - return false, errors.Trace(err) - } - var newKey kv.Key - if w.reorgedTbl.Meta().PKIsHandle || w.reorgedTbl.Meta().IsCommonHandle { - pid := p.GetPhysicalID() - newKey = tablecodec.EncodeTablePrefix(pid) - newKey = append(newKey, recordKey[len(newKey):]...) - } else { - // Non-clustered table / not unique _tidb_rowid for the whole table - // Generate new _tidb_rowid if exists. - // Due to EXCHANGE PARTITION, the existing _tidb_rowid may collide between partitions! - if reserved, ok := w.tblCtx.GetReservedRowIDAlloc(); ok && reserved.Exhausted() { - // TODO: Which autoid allocator to use? 
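// A simplified sketch of the key rewrite above, using plain byte slices instead
// of tablecodec (hypothetical helper, not part of this patch): for clustered
// tables the handle suffix is kept and only the table/partition prefix changes.
func rewriteRecordKey(oldKey, newPartitionPrefix []byte, prefixLen int) []byte {
	newKey := append([]byte{}, newPartitionPrefix...)
	return append(newKey, oldKey[prefixLen:]...) // keep the handle suffix as-is
}

// For non-clustered tables the old _tidb_rowid cannot simply be kept: after an
// earlier EXCHANGE PARTITION two source partitions may both contain rowid 1,
// so a fresh handle is allocated below instead.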
- ids := uint64(max(1, w.batchCnt-len(w.rowRecords))) - // Keep using the original table's allocator - var baseRowID, maxRowID int64 - baseRowID, maxRowID, err = tables.AllocHandleIDs(w.ctx, w.tblCtx, w.reorgedTbl, ids) - if err != nil { - return false, errors.Trace(err) - } - reserved.Reset(baseRowID, maxRowID) - } - recordID, err := tables.AllocHandle(w.ctx, w.tblCtx, w.reorgedTbl) - if err != nil { - return false, errors.Trace(err) - } - newKey = tablecodec.EncodeRecordKey(p.RecordPrefix(), recordID) - } - w.rowRecords = append(w.rowRecords, &rowRecord{ - key: newKey, vals: rawRow, - }) - - w.cleanRowMap() - lastAccessedHandle = recordKey - if recordKey.Cmp(taskRange.endKey) == 0 { - taskDone = true - return false, nil - } - return true, nil - }) - - if len(w.rowRecords) == 0 { - taskDone = true - } - - logutil.DDLLogger().Debug("txn fetches handle info", - zap.Uint64("txnStartTS", txn.StartTS()), - zap.Stringer("taskRange", &taskRange), - zap.Duration("takeTime", time.Since(startTime))) - return w.rowRecords, getNextHandleKey(taskRange, taskDone, lastAccessedHandle), taskDone, errors.Trace(err) -} - -func (w *reorgPartitionWorker) cleanRowMap() { - for id := range w.rowMap { - delete(w.rowMap, id) - } -} - -func (w *reorgPartitionWorker) AddMetricInfo(cnt float64) { - w.metricCounter.Add(cnt) -} - -func (*reorgPartitionWorker) String() string { - return typeReorgPartitionWorker.String() -} - -func (w *reorgPartitionWorker) GetCtx() *backfillCtx { - return w.backfillCtx -} - -func (w *worker) reorgPartitionDataAndIndex(t table.Table, reorgInfo *reorgInfo) (err error) { - // First copy all table data to the new AddingDefinitions partitions - // from each of the DroppingDefinitions partitions. - // Then create all indexes on the AddingDefinitions partitions, - // both new local and new global indexes - // And last update new global indexes from the non-touched partitions - // Note it is hard to update global indexes in-place due to: - // - Transactions on different TiDB nodes/domains may see different states of the table/partitions - // - We cannot have multiple partition ids for a unique index entry. - - // Copy the data from the DroppingDefinitions to the AddingDefinitions - if bytes.Equal(reorgInfo.currElement.TypeKey, meta.ColumnElementKey) { - err = w.updatePhysicalTableRow(t, reorgInfo) - if err != nil { - return errors.Trace(err) - } - if len(reorgInfo.elements) <= 1 { - // No indexes to (re)create, all done! - return nil - } - } - - failpoint.Inject("reorgPartitionAfterDataCopy", func(val failpoint.Value) { - //nolint:forcetypeassert - if val.(bool) { - panic("panic test in reorgPartitionAfterDataCopy") - } - }) - - if !bytes.Equal(reorgInfo.currElement.TypeKey, meta.IndexElementKey) { - // row data has been copied, now proceed with creating the indexes - // on the new AddingDefinitions partitions - reorgInfo.PhysicalTableID = t.Meta().Partition.AddingDefinitions[0].ID - reorgInfo.currElement = reorgInfo.elements[1] - var physTbl table.PhysicalTable - if tbl, ok := t.(table.PartitionedTable); ok { - physTbl = tbl.GetPartition(reorgInfo.PhysicalTableID) - } else if tbl, ok := t.(table.PhysicalTable); ok { - // This may be used when partitioning a non-partitioned table - physTbl = tbl - } - // Get the original start handle and end handle. 
- currentVer, err := getValidCurrentVersion(reorgInfo.d.store) - if err != nil { - return errors.Trace(err) - } - startHandle, endHandle, err := getTableRange(reorgInfo.NewJobContext(), reorgInfo.d, physTbl, currentVer.Ver, reorgInfo.Job.Priority) - if err != nil { - return errors.Trace(err) - } - - // Always (re)start with the full PhysicalTable range - reorgInfo.StartKey, reorgInfo.EndKey = startHandle, endHandle - - // Write the reorg info to store so the whole reorganize process can recover from panic. - err = reorgInfo.UpdateReorgMeta(reorgInfo.StartKey, w.sessPool) - logutil.DDLLogger().Info("update column and indexes", - zap.Int64("jobID", reorgInfo.Job.ID), - zap.ByteString("elementType", reorgInfo.currElement.TypeKey), - zap.Int64("elementID", reorgInfo.currElement.ID), - zap.Int64("partitionTableId", physTbl.GetPhysicalID()), - zap.String("startHandle", hex.EncodeToString(reorgInfo.StartKey)), - zap.String("endHandle", hex.EncodeToString(reorgInfo.EndKey))) - if err != nil { - return errors.Trace(err) - } - } - - pi := t.Meta().GetPartitionInfo() - if _, err = findNextPartitionID(reorgInfo.PhysicalTableID, pi.AddingDefinitions); err == nil { - // Now build all the indexes in the new partitions - err = w.addTableIndex(t, reorgInfo) - if err != nil { - return errors.Trace(err) - } - // All indexes are up-to-date for new partitions, - // now we only need to add the existing non-touched partitions - // to the global indexes - reorgInfo.elements = reorgInfo.elements[:0] - for _, indexInfo := range t.Meta().Indices { - if indexInfo.Global && indexInfo.State == model.StateWriteReorganization { - reorgInfo.elements = append(reorgInfo.elements, &meta.Element{ID: indexInfo.ID, TypeKey: meta.IndexElementKey}) - } - } - if len(reorgInfo.elements) == 0 { - // No global indexes - return nil - } - reorgInfo.currElement = reorgInfo.elements[0] - pid := pi.Definitions[0].ID - if _, err = findNextPartitionID(pid, pi.DroppingDefinitions); err == nil { - // Skip all dropped partitions - pid, err = findNextNonTouchedPartitionID(pid, pi) - if err != nil { - return errors.Trace(err) - } - } - if pid == 0 { - // All partitions will be dropped, nothing more to add to global indexes. - return nil - } - reorgInfo.PhysicalTableID = pid - var physTbl table.PhysicalTable - if tbl, ok := t.(table.PartitionedTable); ok { - physTbl = tbl.GetPartition(reorgInfo.PhysicalTableID) - } else if tbl, ok := t.(table.PhysicalTable); ok { - // This may be used when partitioning a non-partitioned table - physTbl = tbl - } - // Get the original start handle and end handle. - currentVer, err := getValidCurrentVersion(reorgInfo.d.store) - if err != nil { - return errors.Trace(err) - } - startHandle, endHandle, err := getTableRange(reorgInfo.NewJobContext(), reorgInfo.d, physTbl, currentVer.Ver, reorgInfo.Job.Priority) - if err != nil { - return errors.Trace(err) - } - - // Always (re)start with the full PhysicalTable range - reorgInfo.StartKey, reorgInfo.EndKey = startHandle, endHandle - - // Write the reorg info to store so the whole reorganize process can recover from panic. 
- err = reorgInfo.UpdateReorgMeta(reorgInfo.StartKey, w.sessPool) - logutil.DDLLogger().Info("update column and indexes", - zap.Int64("jobID", reorgInfo.Job.ID), - zap.ByteString("elementType", reorgInfo.currElement.TypeKey), - zap.Int64("elementID", reorgInfo.currElement.ID), - zap.Int64("partitionTableId", physTbl.GetPhysicalID()), - zap.String("startHandle", hex.EncodeToString(reorgInfo.StartKey)), - zap.String("endHandle", hex.EncodeToString(reorgInfo.EndKey))) - if err != nil { - return errors.Trace(err) - } - } - return w.addTableIndex(t, reorgInfo) -} - -func bundlesForExchangeTablePartition(t *meta.Meta, pt *model.TableInfo, newPar *model.PartitionDefinition, nt *model.TableInfo) ([]*placement.Bundle, error) { - bundles := make([]*placement.Bundle, 0, 3) - - ptBundle, err := placement.NewTableBundle(t, pt) - if err != nil { - return nil, errors.Trace(err) - } - if ptBundle != nil { - bundles = append(bundles, ptBundle) - } - - parBundle, err := placement.NewPartitionBundle(t, *newPar) - if err != nil { - return nil, errors.Trace(err) - } - if parBundle != nil { - bundles = append(bundles, parBundle) - } - - ntBundle, err := placement.NewTableBundle(t, nt) - if err != nil { - return nil, errors.Trace(err) - } - if ntBundle != nil { - bundles = append(bundles, ntBundle) - } - - if parBundle == nil && ntBundle != nil { - // newPar.ID is the ID of old table to exchange, so ntBundle != nil means it has some old placement settings. - // We should remove it in this situation - bundles = append(bundles, placement.NewBundle(newPar.ID)) - } - - if parBundle != nil && ntBundle == nil { - // nt.ID is the ID of old partition to exchange, so parBundle != nil means it has some old placement settings. - // We should remove it in this situation - bundles = append(bundles, placement.NewBundle(nt.ID)) - } - - return bundles, nil -} - -func checkExchangePartitionRecordValidation(w *worker, ptbl, ntbl table.Table, pschemaName, nschemaName, partitionName string) error { - verifyFunc := func(sql string, params ...any) error { - var ctx sessionctx.Context - ctx, err := w.sessPool.Get() - if err != nil { - return errors.Trace(err) - } - defer w.sessPool.Put(ctx) - - rows, _, err := ctx.GetRestrictedSQLExecutor().ExecRestrictedSQL(w.ctx, nil, sql, params...) - if err != nil { - return errors.Trace(err) - } - rowCount := len(rows) - if rowCount != 0 { - return errors.Trace(dbterror.ErrRowDoesNotMatchPartition) - } - // Check warnings! - // Is it possible to check how many rows where checked as well? - return nil - } - genConstraintCondition := func(constraints []*table.Constraint) string { - var buf strings.Builder - buf.WriteString("not (") - for i, cons := range constraints { - if i != 0 { - buf.WriteString(" and ") - } - buf.WriteString(fmt.Sprintf("(%s)", cons.ExprString)) - } - buf.WriteString(")") - return buf.String() - } - type CheckConstraintTable interface { - WritableConstraint() []*table.Constraint - } - - pt := ptbl.Meta() - index, _, err := getPartitionDef(pt, partitionName) - if err != nil { - return errors.Trace(err) - } - - var buf strings.Builder - buf.WriteString("select 1 from %n.%n where ") - paramList := []any{nschemaName, ntbl.Meta().Name.L} - checkNt := true - - pi := pt.Partition - switch pi.Type { - case model.PartitionTypeHash: - if pi.Num == 1 { - checkNt = false - } else { - buf.WriteString("mod(") - buf.WriteString(pi.Expr) - buf.WriteString(", %?) 
!= %?") - paramList = append(paramList, pi.Num, index) - if index != 0 { - // TODO: if hash result can't be NULL, we can remove the check part. - // For example hash(id), but id is defined not NULL. - buf.WriteString(" or mod(") - buf.WriteString(pi.Expr) - buf.WriteString(", %?) is null") - paramList = append(paramList, pi.Num, index) - } - } - case model.PartitionTypeRange: - // Table has only one partition and has the maximum value - if len(pi.Definitions) == 1 && strings.EqualFold(pi.Definitions[index].LessThan[0], partitionMaxValue) { - checkNt = false - } else { - // For range expression and range columns - if len(pi.Columns) == 0 { - conds, params := buildCheckSQLConditionForRangeExprPartition(pi, index) - buf.WriteString(conds) - paramList = append(paramList, params...) - } else { - conds, params := buildCheckSQLConditionForRangeColumnsPartition(pi, index) - buf.WriteString(conds) - paramList = append(paramList, params...) - } - } - case model.PartitionTypeList: - if len(pi.Columns) == 0 { - conds := buildCheckSQLConditionForListPartition(pi, index) - buf.WriteString(conds) - } else { - conds := buildCheckSQLConditionForListColumnsPartition(pi, index) - buf.WriteString(conds) - } - default: - return dbterror.ErrUnsupportedPartitionType.GenWithStackByArgs(pt.Name.O) - } - - if variable.EnableCheckConstraint.Load() { - pcc, ok := ptbl.(CheckConstraintTable) - if !ok { - return errors.Errorf("exchange partition process assert table partition failed") - } - pCons := pcc.WritableConstraint() - if len(pCons) > 0 { - if !checkNt { - checkNt = true - } else { - buf.WriteString(" or ") - } - buf.WriteString(genConstraintCondition(pCons)) - } - } - // Check non-partition table records. - if checkNt { - buf.WriteString(" limit 1") - err = verifyFunc(buf.String(), paramList...) - if err != nil { - return errors.Trace(err) - } - } - - // Check partition table records. - if variable.EnableCheckConstraint.Load() { - ncc, ok := ntbl.(CheckConstraintTable) - if !ok { - return errors.Errorf("exchange partition process assert table partition failed") - } - nCons := ncc.WritableConstraint() - if len(nCons) > 0 { - buf.Reset() - buf.WriteString("select 1 from %n.%n partition(%n) where ") - buf.WriteString(genConstraintCondition(nCons)) - buf.WriteString(" limit 1") - err = verifyFunc(buf.String(), pschemaName, pt.Name.L, partitionName) - if err != nil { - return errors.Trace(err) - } - } - } - return nil -} - -func checkExchangePartitionPlacementPolicy(t *meta.Meta, ntPPRef, ptPPRef, partPPRef *model.PolicyRefInfo) error { - partitionPPRef := partPPRef - if partitionPPRef == nil { - partitionPPRef = ptPPRef - } - - if ntPPRef == nil && partitionPPRef == nil { - return nil - } - if ntPPRef == nil || partitionPPRef == nil { - return dbterror.ErrTablesDifferentMetadata - } - - ptPlacementPolicyInfo, _ := getPolicyInfo(t, partitionPPRef.ID) - ntPlacementPolicyInfo, _ := getPolicyInfo(t, ntPPRef.ID) - if ntPlacementPolicyInfo == nil && ptPlacementPolicyInfo == nil { - return nil - } - if ntPlacementPolicyInfo == nil || ptPlacementPolicyInfo == nil { - return dbterror.ErrTablesDifferentMetadata - } - if ntPlacementPolicyInfo.Name.L != ptPlacementPolicyInfo.Name.L { - return dbterror.ErrTablesDifferentMetadata - } - - return nil -} - -func buildCheckSQLConditionForRangeExprPartition(pi *model.PartitionInfo, index int) (string, []any) { - var buf strings.Builder - paramList := make([]any, 0, 2) - // Since the pi.Expr string may contain the identifier, which couldn't be escaped in our ParseWithParams(...) 
- // So we write it to the origin sql string here. - if index == 0 { - buf.WriteString(pi.Expr) - buf.WriteString(" >= %?") - paramList = append(paramList, driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0])) - } else if index == len(pi.Definitions)-1 && strings.EqualFold(pi.Definitions[index].LessThan[0], partitionMaxValue) { - buf.WriteString(pi.Expr) - buf.WriteString(" < %? or ") - buf.WriteString(pi.Expr) - buf.WriteString(" is null") - paramList = append(paramList, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0])) - } else { - buf.WriteString(pi.Expr) - buf.WriteString(" < %? or ") - buf.WriteString(pi.Expr) - buf.WriteString(" >= %? or ") - buf.WriteString(pi.Expr) - buf.WriteString(" is null") - paramList = append(paramList, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0]), driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0])) - } - return buf.String(), paramList -} - -func buildCheckSQLConditionForRangeColumnsPartition(pi *model.PartitionInfo, index int) (string, []any) { - paramList := make([]any, 0, 2) - colName := pi.Columns[0].L - if index == 0 { - paramList = append(paramList, colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0])) - return "%n >= %?", paramList - } else if index == len(pi.Definitions)-1 && strings.EqualFold(pi.Definitions[index].LessThan[0], partitionMaxValue) { - paramList = append(paramList, colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0])) - return "%n < %?", paramList - } - paramList = append(paramList, colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index-1].LessThan[0]), colName, driver.UnwrapFromSingleQuotes(pi.Definitions[index].LessThan[0])) - return "%n < %? or %n >= %?", paramList -} - -func buildCheckSQLConditionForListPartition(pi *model.PartitionInfo, index int) string { - var buf strings.Builder - buf.WriteString("not (") - for i, inValue := range pi.Definitions[index].InValues { - if i != 0 { - buf.WriteString(" OR ") - } - // AND has higher priority than OR, so no need for parentheses - for j, val := range inValue { - if j != 0 { - // Should never happen, since there should be no multi-columns, only a single expression :) - buf.WriteString(" AND ") - } - // null-safe compare '<=>' - buf.WriteString(fmt.Sprintf("(%s) <=> %s", pi.Expr, val)) - } - } - buf.WriteString(")") - return buf.String() -} - -func buildCheckSQLConditionForListColumnsPartition(pi *model.PartitionInfo, index int) string { - var buf strings.Builder - // How to find a match? - // (row <=> vals1) OR (row <=> vals2) - // How to find a non-matching row: - // NOT ( (row <=> vals1) OR (row <=> vals2) ... ) - buf.WriteString("not (") - colNames := make([]string, 0, len(pi.Columns)) - for i := range pi.Columns { - // TODO: check if there are no proper quoting function for this? 
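// For a concrete picture of the condition built above: a partition defined as
//   PARTITION p1 VALUES IN (1, 2)   with partition expression a
// yields roughly
//   not ((a) <=> 1 OR (a) <=> 2)
// and the surrounding "select 1 ... where <condition> limit 1" returns a row
// only if some row does NOT belong to the target partition. The quoting applied
// next is plain MySQL backtick escaping; a standalone equivalent (hypothetical
// helper, not part of this patch):
func quoteIdent(name string) string {
	return "`" + strings.ReplaceAll(name, "`", "``") + "`"
}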
- n := "`" + strings.ReplaceAll(pi.Columns[i].O, "`", "``") + "`" - colNames = append(colNames, n) - } - for i, colValues := range pi.Definitions[index].InValues { - if i != 0 { - buf.WriteString(" OR ") - } - // AND has higher priority than OR, so no need for parentheses - for j, val := range colValues { - if j != 0 { - buf.WriteString(" AND ") - } - // null-safe compare '<=>' - buf.WriteString(fmt.Sprintf("%s <=> %s", colNames[j], val)) - } - } - buf.WriteString(")") - return buf.String() -} - -func checkAddPartitionTooManyPartitions(piDefs uint64) error { - if piDefs > uint64(mysql.PartitionCountLimit) { - return errors.Trace(dbterror.ErrTooManyPartitions) - } - return nil -} - -func checkAddPartitionOnTemporaryMode(tbInfo *model.TableInfo) error { - if tbInfo.Partition != nil && tbInfo.TempTableType != model.TempTableNone { - return dbterror.ErrPartitionNoTemporary - } - return nil -} - -func checkPartitionColumnsUnique(tbInfo *model.TableInfo) error { - if len(tbInfo.Partition.Columns) <= 1 { - return nil - } - var columnsMap = make(map[string]struct{}) - for _, col := range tbInfo.Partition.Columns { - if _, ok := columnsMap[col.L]; ok { - return dbterror.ErrSameNamePartitionField.GenWithStackByArgs(col.L) - } - columnsMap[col.L] = struct{}{} - } - return nil -} - -func checkNoHashPartitions(_ sessionctx.Context, partitionNum uint64) error { - if partitionNum == 0 { - return ast.ErrNoParts.GenWithStackByArgs("partitions") - } - return nil -} - -func getPartitionIDs(table *model.TableInfo) []int64 { - if table.GetPartitionInfo() == nil { - return []int64{} - } - physicalTableIDs := make([]int64, 0, len(table.Partition.Definitions)) - for _, def := range table.Partition.Definitions { - physicalTableIDs = append(physicalTableIDs, def.ID) - } - return physicalTableIDs -} - -func getPartitionRuleIDs(dbName string, table *model.TableInfo) []string { - if table.GetPartitionInfo() == nil { - return []string{} - } - partRuleIDs := make([]string, 0, len(table.Partition.Definitions)) - for _, def := range table.Partition.Definitions { - partRuleIDs = append(partRuleIDs, fmt.Sprintf(label.PartitionIDFormat, label.IDPrefix, dbName, table.Name.L, def.Name.L)) - } - return partRuleIDs -} - -// checkPartitioningKeysConstraints checks that the range partitioning key is included in the table constraint. -func checkPartitioningKeysConstraints(sctx sessionctx.Context, s *ast.CreateTableStmt, tblInfo *model.TableInfo) error { - // Returns directly if there are no unique keys in the table. - if len(tblInfo.Indices) == 0 && !tblInfo.PKIsHandle { - return nil - } - - partCols, err := getPartitionColSlices(sctx.GetExprCtx(), tblInfo, s.Partition) - if err != nil { - return errors.Trace(err) - } - - // Checks that the partitioning key is included in the constraint. - // Every unique key on the table must use every column in the table's partitioning expression. 
- // See https://dev.mysql.com/doc/refman/5.7/en/partitioning-limitations-partitioning-keys-unique-keys.html - for _, index := range tblInfo.Indices { - if index.Unique && !checkUniqueKeyIncludePartKey(partCols, index.Columns) { - if index.Primary { - // global index does not support clustered index - if tblInfo.IsCommonHandle { - return dbterror.ErrUniqueKeyNeedAllFieldsInPf.GenWithStackByArgs("CLUSTERED INDEX") - } - if !sctx.GetSessionVars().EnableGlobalIndex { - return dbterror.ErrUniqueKeyNeedAllFieldsInPf.GenWithStackByArgs("PRIMARY KEY") - } - } - if !sctx.GetSessionVars().EnableGlobalIndex { - return dbterror.ErrUniqueKeyNeedAllFieldsInPf.GenWithStackByArgs("UNIQUE INDEX") - } - } - } - // when PKIsHandle, tblInfo.Indices will not contain the primary key. - if tblInfo.PKIsHandle { - indexCols := []*model.IndexColumn{{ - Name: tblInfo.GetPkName(), - Length: types.UnspecifiedLength, - }} - if !checkUniqueKeyIncludePartKey(partCols, indexCols) { - return dbterror.ErrUniqueKeyNeedAllFieldsInPf.GenWithStackByArgs("CLUSTERED INDEX") - } - } - return nil -} - -func checkPartitionKeysConstraint(pi *model.PartitionInfo, indexColumns []*model.IndexColumn, tblInfo *model.TableInfo) (bool, error) { - var ( - partCols []*model.ColumnInfo - err error - ) - if pi.Type == model.PartitionTypeNone { - return true, nil - } - // The expr will be an empty string if the partition is defined by: - // CREATE TABLE t (...) PARTITION BY RANGE COLUMNS(...) - if partExpr := pi.Expr; partExpr != "" { - // Parse partitioning key, extract the column names in the partitioning key to slice. - partCols, err = extractPartitionColumns(partExpr, tblInfo) - if err != nil { - return false, err - } - } else { - partCols = make([]*model.ColumnInfo, 0, len(pi.Columns)) - for _, col := range pi.Columns { - colInfo := tblInfo.FindPublicColumnByName(col.L) - if colInfo == nil { - return false, infoschema.ErrColumnNotExists.GenWithStackByArgs(col, tblInfo.Name) - } - partCols = append(partCols, colInfo) - } - } - - // In MySQL, every unique key on the table must use every column in the table's partitioning expression.(This - // also includes the table's primary key.) - // In TiDB, global index will be built when this constraint is not satisfied and EnableGlobalIndex is set. 
- // See https://dev.mysql.com/doc/refman/5.7/en/partitioning-limitations-partitioning-keys-unique-keys.html - return checkUniqueKeyIncludePartKey(columnInfoSlice(partCols), indexColumns), nil -} - -type columnNameExtractor struct { - extractedColumns []*model.ColumnInfo - tblInfo *model.TableInfo - err error -} - -func (*columnNameExtractor) Enter(node ast.Node) (ast.Node, bool) { - return node, false -} - -func (cne *columnNameExtractor) Leave(node ast.Node) (ast.Node, bool) { - if c, ok := node.(*ast.ColumnNameExpr); ok { - info := findColumnByName(c.Name.Name.L, cne.tblInfo) - if info != nil { - cne.extractedColumns = append(cne.extractedColumns, info) - return node, true - } - cne.err = dbterror.ErrBadField.GenWithStackByArgs(c.Name.Name.O, "expression") - return nil, false - } - return node, true -} - -func findColumnByName(colName string, tblInfo *model.TableInfo) *model.ColumnInfo { - if tblInfo == nil { - return nil - } - for _, info := range tblInfo.Columns { - if info.Name.L == colName { - return info - } - } - return nil -} - -func extractPartitionColumns(partExpr string, tblInfo *model.TableInfo) ([]*model.ColumnInfo, error) { - partExpr = "select " + partExpr - stmts, _, err := parser.New().ParseSQL(partExpr) - if err != nil { - return nil, errors.Trace(err) - } - extractor := &columnNameExtractor{ - tblInfo: tblInfo, - extractedColumns: make([]*model.ColumnInfo, 0), - } - stmts[0].Accept(extractor) - if extractor.err != nil { - return nil, errors.Trace(extractor.err) - } - return extractor.extractedColumns, nil -} - -// stringSlice is defined for checkUniqueKeyIncludePartKey. -// if Go supports covariance, the code shouldn't be so complex. -type stringSlice interface { - Len() int - At(i int) string -} - -// checkUniqueKeyIncludePartKey checks that the partitioning key is included in the constraint. -func checkUniqueKeyIncludePartKey(partCols stringSlice, idxCols []*model.IndexColumn) bool { - for i := 0; i < partCols.Len(); i++ { - partCol := partCols.At(i) - _, idxCol := model.FindIndexColumnByName(idxCols, partCol) - if idxCol == nil { - // Partition column is not found in the index columns. - return false - } - if idxCol.Length > 0 { - // The partition column is found in the index columns, but the index column is a prefix index - return false - } - } - return true -} - -// columnInfoSlice implements the stringSlice interface. -type columnInfoSlice []*model.ColumnInfo - -func (cis columnInfoSlice) Len() int { - return len(cis) -} - -func (cis columnInfoSlice) At(i int) string { - return cis[i].Name.L -} - -// columnNameSlice implements the stringSlice interface. -type columnNameSlice []*ast.ColumnName - -func (cns columnNameSlice) Len() int { - return len(cns) -} - -func (cns columnNameSlice) At(i int) string { - return cns[i].Name.L -} - -func isPartExprUnsigned(ectx expression.EvalContext, tbInfo *model.TableInfo) bool { - ctx := tables.NewPartitionExprBuildCtx() - expr, err := expression.ParseSimpleExpr(ctx, tbInfo.Partition.Expr, expression.WithTableInfo("", tbInfo)) - if err != nil { - logutil.DDLLogger().Error("isPartExpr failed parsing expression!", zap.Error(err)) - return false - } - if mysql.HasUnsignedFlag(expr.GetType(ectx).GetFlag()) { - return true - } - return false -} - -// truncateTableByReassignPartitionIDs reassigns new partition ids. 
-func truncateTableByReassignPartitionIDs(t *meta.Meta, tblInfo *model.TableInfo, pids []int64) (err error) { - if len(pids) < len(tblInfo.Partition.Definitions) { - // To make it compatible with older versions when pids was not given - // and if there has been any add/reorganize partition increasing the number of partitions - morePids, err := t.GenGlobalIDs(len(tblInfo.Partition.Definitions) - len(pids)) - if err != nil { - return errors.Trace(err) - } - pids = append(pids, morePids...) - } - newDefs := make([]model.PartitionDefinition, 0, len(tblInfo.Partition.Definitions)) - for i, def := range tblInfo.Partition.Definitions { - newDef := def - newDef.ID = pids[i] - newDefs = append(newDefs, newDef) - } - tblInfo.Partition.Definitions = newDefs - return nil -} - -type partitionExprProcessor func(expression.BuildContext, *model.TableInfo, ast.ExprNode) error - -type partitionExprChecker struct { - processors []partitionExprProcessor - ctx expression.BuildContext - tbInfo *model.TableInfo - err error - - columns []*model.ColumnInfo -} - -func newPartitionExprChecker(ctx expression.BuildContext, tbInfo *model.TableInfo, processor ...partitionExprProcessor) *partitionExprChecker { - p := &partitionExprChecker{processors: processor, ctx: ctx, tbInfo: tbInfo} - p.processors = append(p.processors, p.extractColumns) - return p -} - -func (p *partitionExprChecker) Enter(n ast.Node) (node ast.Node, skipChildren bool) { - expr, ok := n.(ast.ExprNode) - if !ok { - return n, true - } - for _, processor := range p.processors { - if err := processor(p.ctx, p.tbInfo, expr); err != nil { - p.err = err - return n, true - } - } - - return n, false -} - -func (p *partitionExprChecker) Leave(n ast.Node) (node ast.Node, ok bool) { - return n, p.err == nil -} - -func (p *partitionExprChecker) extractColumns(_ expression.BuildContext, _ *model.TableInfo, expr ast.ExprNode) error { - columnNameExpr, ok := expr.(*ast.ColumnNameExpr) - if !ok { - return nil - } - colInfo := findColumnByName(columnNameExpr.Name.Name.L, p.tbInfo) - if colInfo == nil { - return errors.Trace(dbterror.ErrBadField.GenWithStackByArgs(columnNameExpr.Name.Name.L, "partition function")) - } - - p.columns = append(p.columns, colInfo) - return nil -} - -func checkPartitionExprAllowed(_ expression.BuildContext, tb *model.TableInfo, e ast.ExprNode) error { - switch v := e.(type) { - case *ast.FuncCallExpr: - if _, ok := expression.AllowedPartitionFuncMap[v.FnName.L]; ok { - return nil - } - case *ast.BinaryOperationExpr: - if _, ok := expression.AllowedPartition4BinaryOpMap[v.Op]; ok { - return errors.Trace(checkNoTimestampArgs(tb, v.L, v.R)) - } - case *ast.UnaryOperationExpr: - if _, ok := expression.AllowedPartition4UnaryOpMap[v.Op]; ok { - return errors.Trace(checkNoTimestampArgs(tb, v.V)) - } - case *ast.ColumnNameExpr, *ast.ParenthesesExpr, *driver.ValueExpr, *ast.MaxValueExpr, - *ast.DefaultExpr, *ast.TimeUnitExpr: - return nil - } - return errors.Trace(dbterror.ErrPartitionFunctionIsNotAllowed) -} - -func checkPartitionExprArgs(_ expression.BuildContext, tblInfo *model.TableInfo, e ast.ExprNode) error { - expr, ok := e.(*ast.FuncCallExpr) - if !ok { - return nil - } - argsType, err := collectArgsType(tblInfo, expr.Args...) 
- if err != nil { - return errors.Trace(err) - } - switch expr.FnName.L { - case ast.ToDays, ast.ToSeconds, ast.DayOfMonth, ast.Month, ast.DayOfYear, ast.Quarter, ast.YearWeek, - ast.Year, ast.Weekday, ast.DayOfWeek, ast.Day: - return errors.Trace(checkResultOK(hasDateArgs(argsType...))) - case ast.Hour, ast.Minute, ast.Second, ast.TimeToSec, ast.MicroSecond: - return errors.Trace(checkResultOK(hasTimeArgs(argsType...))) - case ast.UnixTimestamp: - return errors.Trace(checkResultOK(hasTimestampArgs(argsType...))) - case ast.FromDays: - return errors.Trace(checkResultOK(hasDateArgs(argsType...) || hasTimeArgs(argsType...))) - case ast.Extract: - switch expr.Args[0].(*ast.TimeUnitExpr).Unit { - case ast.TimeUnitYear, ast.TimeUnitYearMonth, ast.TimeUnitQuarter, ast.TimeUnitMonth, ast.TimeUnitDay: - return errors.Trace(checkResultOK(hasDateArgs(argsType...))) - case ast.TimeUnitDayMicrosecond, ast.TimeUnitDayHour, ast.TimeUnitDayMinute, ast.TimeUnitDaySecond: - return errors.Trace(checkResultOK(hasDatetimeArgs(argsType...))) - case ast.TimeUnitHour, ast.TimeUnitHourMinute, ast.TimeUnitHourSecond, ast.TimeUnitMinute, ast.TimeUnitMinuteSecond, - ast.TimeUnitSecond, ast.TimeUnitMicrosecond, ast.TimeUnitHourMicrosecond, ast.TimeUnitMinuteMicrosecond, ast.TimeUnitSecondMicrosecond: - return errors.Trace(checkResultOK(hasTimeArgs(argsType...))) - default: - return errors.Trace(dbterror.ErrWrongExprInPartitionFunc) - } - case ast.DateDiff: - return errors.Trace(checkResultOK(slice.AllOf(argsType, func(i int) bool { - return hasDateArgs(argsType[i]) - }))) - - case ast.Abs, ast.Ceiling, ast.Floor, ast.Mod: - has := hasTimestampArgs(argsType...) - if has { - return errors.Trace(dbterror.ErrWrongExprInPartitionFunc) - } - } - return nil -} - -func collectArgsType(tblInfo *model.TableInfo, exprs ...ast.ExprNode) ([]byte, error) { - ts := make([]byte, 0, len(exprs)) - for _, arg := range exprs { - col, ok := arg.(*ast.ColumnNameExpr) - if !ok { - continue - } - columnInfo := findColumnByName(col.Name.Name.L, tblInfo) - if columnInfo == nil { - return nil, errors.Trace(dbterror.ErrBadField.GenWithStackByArgs(col.Name.Name.L, "partition function")) - } - ts = append(ts, columnInfo.GetType()) - } - - return ts, nil -} - -func hasDateArgs(argsType ...byte) bool { - return slice.AnyOf(argsType, func(i int) bool { - return argsType[i] == mysql.TypeDate || argsType[i] == mysql.TypeDatetime - }) -} - -func hasTimeArgs(argsType ...byte) bool { - return slice.AnyOf(argsType, func(i int) bool { - return argsType[i] == mysql.TypeDuration || argsType[i] == mysql.TypeDatetime - }) -} - -func hasTimestampArgs(argsType ...byte) bool { - return slice.AnyOf(argsType, func(i int) bool { - return argsType[i] == mysql.TypeTimestamp - }) -} - -func hasDatetimeArgs(argsType ...byte) bool { - return slice.AnyOf(argsType, func(i int) bool { - return argsType[i] == mysql.TypeDatetime - }) -} - -func checkNoTimestampArgs(tbInfo *model.TableInfo, exprs ...ast.ExprNode) error { - argsType, err := collectArgsType(tbInfo, exprs...) - if err != nil { - return err - } - if hasTimestampArgs(argsType...) { - return errors.Trace(dbterror.ErrWrongExprInPartitionFunc) - } - return nil -} - -// hexIfNonPrint checks if printable UTF-8 characters from a single quoted string, -// if so, just returns the string -// else returns a hex string of the binary string (i.e. actual encoding, not unicode code points!) -func hexIfNonPrint(s string) string { - isPrint := true - // https://go.dev/blog/strings `for range` of string converts to runes! 
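// A trimmed-down sketch of the printability check below (hypothetical helper,
// not part of this patch), with sample results:
func isPrintable(s string) bool {
	for _, r := range s { // range over a string yields runes, not bytes
		if !strconv.IsPrint(r) {
			return false
		}
	}
	return true
}

// isPrintable("p0")    == true   -> the value is returned unchanged
// isPrintable("a\tb")  == false  -> re-emitted with \t escaped
// isPrintable("\x01")  == false  -> no simple escape exists, so the 0x... hex form is used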
- for _, runeVal := range s { - if !strconv.IsPrint(runeVal) { - isPrint = false - break - } - } - if isPrint { - return s - } - // To avoid 'simple' MySQL accepted escape characters, to be showed as hex, just escape them - // \0 \b \n \r \t \Z, see https://dev.mysql.com/doc/refman/8.0/en/string-literals.html - isPrint = true - res := "" - for _, runeVal := range s { - switch runeVal { - case 0: // Null - res += `\0` - case 7: // Bell - res += `\b` - case '\t': // 9 - res += `\t` - case '\n': // 10 - res += `\n` - case '\r': // 13 - res += `\r` - case 26: // ctrl-z / Substitute - res += `\Z` - default: - if !strconv.IsPrint(runeVal) { - isPrint = false - break - } - res += string(runeVal) - } - } - if isPrint { - return res - } - // Not possible to create an easy interpreted MySQL string, return as hex string - // Can be converted to string in MySQL like: CAST(UNHEX('') AS CHAR(255)) - return "0x" + hex.EncodeToString([]byte(driver.UnwrapFromSingleQuotes(s))) -} - -func writeColumnListToBuffer(partitionInfo *model.PartitionInfo, sqlMode mysql.SQLMode, buf *bytes.Buffer) { - if partitionInfo.IsEmptyColumns { - return - } - for i, col := range partitionInfo.Columns { - buf.WriteString(stringutil.Escape(col.O, sqlMode)) - if i < len(partitionInfo.Columns)-1 { - buf.WriteString(",") - } - } -} - -// AppendPartitionInfo is used in SHOW CREATE TABLE as well as generation the SQL syntax -// for the PartitionInfo during validation of various DDL commands -func AppendPartitionInfo(partitionInfo *model.PartitionInfo, buf *bytes.Buffer, sqlMode mysql.SQLMode) { - if partitionInfo == nil { - return - } - // Since MySQL 5.1/5.5 is very old and TiDB aims for 5.7/8.0 compatibility, we will not - // include the /*!50100 or /*!50500 comments for TiDB. - // This also solves the issue with comments within comments that would happen for - // PLACEMENT POLICY options. - defaultPartitionDefinitions := true - if partitionInfo.Type == model.PartitionTypeHash || - partitionInfo.Type == model.PartitionTypeKey { - for i, def := range partitionInfo.Definitions { - if def.Name.O != fmt.Sprintf("p%d", i) { - defaultPartitionDefinitions = false - break - } - if len(def.Comment) > 0 || def.PlacementPolicyRef != nil { - defaultPartitionDefinitions = false - break - } - } - - if defaultPartitionDefinitions { - if partitionInfo.Type == model.PartitionTypeHash { - fmt.Fprintf(buf, "\nPARTITION BY HASH (%s) PARTITIONS %d", partitionInfo.Expr, partitionInfo.Num) - } else { - buf.WriteString("\nPARTITION BY KEY (") - writeColumnListToBuffer(partitionInfo, sqlMode, buf) - buf.WriteString(")") - fmt.Fprintf(buf, " PARTITIONS %d", partitionInfo.Num) - } - return - } - } - // this if statement takes care of lists/range/key columns case - if len(partitionInfo.Columns) > 0 { - // partitionInfo.Type == model.PartitionTypeRange || partitionInfo.Type == model.PartitionTypeList - // || partitionInfo.Type == model.PartitionTypeKey - // Notice that MySQL uses two spaces between LIST and COLUMNS... 
- if partitionInfo.Type == model.PartitionTypeKey { - fmt.Fprintf(buf, "\nPARTITION BY %s (", partitionInfo.Type.String()) - } else { - fmt.Fprintf(buf, "\nPARTITION BY %s COLUMNS(", partitionInfo.Type.String()) - } - writeColumnListToBuffer(partitionInfo, sqlMode, buf) - buf.WriteString(")\n(") - } else { - fmt.Fprintf(buf, "\nPARTITION BY %s (%s)\n(", partitionInfo.Type.String(), partitionInfo.Expr) - } - - AppendPartitionDefs(partitionInfo, buf, sqlMode) - buf.WriteString(")") -} - -// AppendPartitionDefs generates a list of partition definitions needed for SHOW CREATE TABLE (in executor/show.go) -// as well as needed for generating the ADD PARTITION query for INTERVAL partitioning of ALTER TABLE t LAST PARTITION -// and generating the CREATE TABLE query from CREATE TABLE ... INTERVAL -func AppendPartitionDefs(partitionInfo *model.PartitionInfo, buf *bytes.Buffer, sqlMode mysql.SQLMode) { - for i, def := range partitionInfo.Definitions { - if i > 0 { - fmt.Fprintf(buf, ",\n ") - } - fmt.Fprintf(buf, "PARTITION %s", stringutil.Escape(def.Name.O, sqlMode)) - // PartitionTypeHash and PartitionTypeKey do not have any VALUES definition - if partitionInfo.Type == model.PartitionTypeRange { - lessThans := make([]string, len(def.LessThan)) - for idx, v := range def.LessThan { - lessThans[idx] = hexIfNonPrint(v) - } - fmt.Fprintf(buf, " VALUES LESS THAN (%s)", strings.Join(lessThans, ",")) - } else if partitionInfo.Type == model.PartitionTypeList { - if len(def.InValues) == 0 { - fmt.Fprintf(buf, " DEFAULT") - } else if len(def.InValues) == 1 && - len(def.InValues[0]) == 1 && - strings.EqualFold(def.InValues[0][0], "DEFAULT") { - fmt.Fprintf(buf, " DEFAULT") - } else { - values := bytes.NewBuffer(nil) - for j, inValues := range def.InValues { - if j > 0 { - values.WriteString(",") - } - if len(inValues) > 1 { - values.WriteString("(") - tmpVals := make([]string, len(inValues)) - for idx, v := range inValues { - tmpVals[idx] = hexIfNonPrint(v) - } - values.WriteString(strings.Join(tmpVals, ",")) - values.WriteString(")") - } else if len(inValues) == 1 { - values.WriteString(hexIfNonPrint(inValues[0])) - } - } - fmt.Fprintf(buf, " VALUES IN (%s)", values.String()) - } - } - if len(def.Comment) > 0 { - fmt.Fprintf(buf, " COMMENT '%s'", format.OutputFormat(def.Comment)) - } - if def.PlacementPolicyRef != nil { - // add placement ref info here - fmt.Fprintf(buf, " /*T![placement] PLACEMENT POLICY=%s */", stringutil.Escape(def.PlacementPolicyRef.Name.O, sqlMode)) - } - } -} - -func generatePartValuesWithTp(partVal types.Datum, tp types.FieldType) (string, error) { - if partVal.Kind() == types.KindNull { - return "NULL", nil - } - - s, err := partVal.ToString() - if err != nil { - return "", err - } - - switch tp.EvalType() { - case types.ETInt: - return s, nil - case types.ETString: - // The `partVal` can be an invalid utf8 string if it's converted to BINARY, then the content will be lost after - // marshaling and storing in the schema. In this case, we use a hex literal to work around this issue. 
- if tp.GetCharset() == charset.CharsetBin { - return fmt.Sprintf("_binary 0x%x", s), nil - } - return driver.WrapInSingleQuotes(s), nil - case types.ETDatetime, types.ETDuration: - return driver.WrapInSingleQuotes(s), nil - } - - return "", dbterror.ErrWrongTypeColumnValue.GenWithStackByArgs() -} - -func checkPartitionDefinitionConstraints(ctx sessionctx.Context, tbInfo *model.TableInfo) error { - var err error - if err = checkPartitionNameUnique(tbInfo.Partition); err != nil { - return errors.Trace(err) - } - if err = checkAddPartitionTooManyPartitions(uint64(len(tbInfo.Partition.Definitions))); err != nil { - return err - } - if err = checkAddPartitionOnTemporaryMode(tbInfo); err != nil { - return err - } - if err = checkPartitionColumnsUnique(tbInfo); err != nil { - return err - } - - switch tbInfo.Partition.Type { - case model.PartitionTypeRange: - err = checkPartitionByRange(ctx, tbInfo) - case model.PartitionTypeHash, model.PartitionTypeKey: - err = checkPartitionByHash(ctx, tbInfo) - case model.PartitionTypeList: - err = checkPartitionByList(ctx, tbInfo) - } - return errors.Trace(err) -} - -func checkPartitionByHash(ctx sessionctx.Context, tbInfo *model.TableInfo) error { - return checkNoHashPartitions(ctx, tbInfo.Partition.Num) -} - -// checkPartitionByRange checks validity of a "BY RANGE" partition. -func checkPartitionByRange(ctx sessionctx.Context, tbInfo *model.TableInfo) error { - failpoint.Inject("CheckPartitionByRangeErr", func() { - ctx.GetSessionVars().SQLKiller.SendKillSignal(sqlkiller.QueryMemoryExceeded) - panic(ctx.GetSessionVars().SQLKiller.HandleSignal()) - }) - pi := tbInfo.Partition - - if len(pi.Columns) == 0 { - return checkRangePartitionValue(ctx, tbInfo) - } - - return checkRangeColumnsPartitionValue(ctx, tbInfo) -} - -func checkRangeColumnsPartitionValue(ctx sessionctx.Context, tbInfo *model.TableInfo) error { - // Range columns partition key supports multiple data types with integer、datetime、string. - pi := tbInfo.Partition - defs := pi.Definitions - if len(defs) < 1 { - return ast.ErrPartitionsMustBeDefined.GenWithStackByArgs("RANGE") - } - - curr := &defs[0] - if len(curr.LessThan) != len(pi.Columns) { - return errors.Trace(ast.ErrPartitionColumnList) - } - var prev *model.PartitionDefinition - for i := 1; i < len(defs); i++ { - prev, curr = curr, &defs[i] - succ, err := checkTwoRangeColumns(ctx, curr, prev, pi, tbInfo) - if err != nil { - return err - } - if !succ { - return errors.Trace(dbterror.ErrRangeNotIncreasing) - } - } - return nil -} - -func checkTwoRangeColumns(ctx sessionctx.Context, curr, prev *model.PartitionDefinition, pi *model.PartitionInfo, tbInfo *model.TableInfo) (bool, error) { - if len(curr.LessThan) != len(pi.Columns) { - return false, errors.Trace(ast.ErrPartitionColumnList) - } - for i := 0; i < len(pi.Columns); i++ { - // Special handling for MAXVALUE. - if strings.EqualFold(curr.LessThan[i], partitionMaxValue) && !strings.EqualFold(prev.LessThan[i], partitionMaxValue) { - // If current is maxvalue, it certainly >= previous. - return true, nil - } - if strings.EqualFold(prev.LessThan[i], partitionMaxValue) { - // Current is not maxvalue, and previous is maxvalue. 
- return false, nil - } - - // The tuples of column values used to define the partitions are strictly increasing: - // PARTITION p0 VALUES LESS THAN (5,10,'ggg') - // PARTITION p1 VALUES LESS THAN (10,20,'mmm') - // PARTITION p2 VALUES LESS THAN (15,30,'sss') - colInfo := findColumnByName(pi.Columns[i].L, tbInfo) - cmp, err := parseAndEvalBoolExpr(ctx.GetExprCtx(), curr.LessThan[i], prev.LessThan[i], colInfo, tbInfo) - if err != nil { - return false, err - } - - if cmp > 0 { - return true, nil - } - - if cmp < 0 { - return false, nil - } - } - return false, nil -} - -// equal, return 0 -// greater, return 1 -// less, return -1 -func parseAndEvalBoolExpr(ctx expression.BuildContext, l, r string, colInfo *model.ColumnInfo, tbInfo *model.TableInfo) (int64, error) { - lexpr, err := expression.ParseSimpleExpr(ctx, l, expression.WithTableInfo("", tbInfo), expression.WithCastExprTo(&colInfo.FieldType)) - if err != nil { - return 0, err - } - rexpr, err := expression.ParseSimpleExpr(ctx, r, expression.WithTableInfo("", tbInfo), expression.WithCastExprTo(&colInfo.FieldType)) - if err != nil { - return 0, err - } - - e, err := expression.NewFunctionBase(ctx, ast.EQ, field_types.NewFieldType(mysql.TypeLonglong), lexpr, rexpr) - if err != nil { - return 0, err - } - e.SetCharsetAndCollation(colInfo.GetCharset(), colInfo.GetCollate()) - res, _, err1 := e.EvalInt(ctx.GetEvalCtx(), chunk.Row{}) - if err1 != nil { - return 0, err1 - } - if res == 1 { - return 0, nil - } - - e, err = expression.NewFunctionBase(ctx, ast.GT, field_types.NewFieldType(mysql.TypeLonglong), lexpr, rexpr) - if err != nil { - return 0, err - } - e.SetCharsetAndCollation(colInfo.GetCharset(), colInfo.GetCollate()) - res, _, err1 = e.EvalInt(ctx.GetEvalCtx(), chunk.Row{}) - if err1 != nil { - return 0, err1 - } - if res > 0 { - return 1, nil - } - return -1, nil -} - -// checkPartitionByList checks validity of a "BY LIST" partition. -func checkPartitionByList(ctx sessionctx.Context, tbInfo *model.TableInfo) error { - return checkListPartitionValue(ctx.GetExprCtx(), tbInfo) -} diff --git a/pkg/ddl/placement/binding__failpoint_binding__.go b/pkg/ddl/placement/binding__failpoint_binding__.go deleted file mode 100644 index 72feaf66a5d4f..0000000000000 --- a/pkg/ddl/placement/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package placement - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/ddl/placement/bundle.go b/pkg/ddl/placement/bundle.go index dbd8a70418f46..427bbdf2a7d98 100644 --- a/pkg/ddl/placement/bundle.go +++ b/pkg/ddl/placement/bundle.go @@ -301,11 +301,11 @@ func NewBundleFromOptions(options *model.PlacementSettings) (bundle *Bundle, err // String implements fmt.Stringer. 
func (b *Bundle) String() string { t, err := json.Marshal(b) - if val, _err_ := failpoint.Eval(_curpkg_("MockMarshalFailure")); _err_ == nil { + failpoint.Inject("MockMarshalFailure", func(val failpoint.Value) { if _, ok := val.(bool); ok { err = errors.New("test") } - } + }) if err != nil { return "" } diff --git a/pkg/ddl/placement/bundle.go__failpoint_stash__ b/pkg/ddl/placement/bundle.go__failpoint_stash__ deleted file mode 100644 index 427bbdf2a7d98..0000000000000 --- a/pkg/ddl/placement/bundle.go__failpoint_stash__ +++ /dev/null @@ -1,712 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package placement - -import ( - "encoding/hex" - "encoding/json" - "errors" - "fmt" - "math" - "slices" - "sort" - "strconv" - "strings" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/util/codec" - pd "github.com/tikv/pd/client/http" - "gopkg.in/yaml.v2" -) - -// Bundle is a group of all rules and configurations. It is used to support rule cache. -// Alias `pd.GroupBundle` is to wrap more methods. -type Bundle pd.GroupBundle - -// NewBundle will create a bundle with the provided ID. -// Note that you should never pass negative id. -func NewBundle(id int64) *Bundle { - return &Bundle{ - ID: GroupID(id), - } -} - -// NewBundleFromConstraintsOptions will transform constraints options into the bundle. -func NewBundleFromConstraintsOptions(options *model.PlacementSettings) (*Bundle, error) { - if options == nil { - return nil, fmt.Errorf("%w: options can not be nil", ErrInvalidPlacementOptions) - } - - if len(options.PrimaryRegion) > 0 || len(options.Regions) > 0 || len(options.Schedule) > 0 { - return nil, fmt.Errorf("%w: should be [LEADER/VOTER/LEARNER/FOLLOWER]_CONSTRAINTS=.. [VOTERS/FOLLOWERS/LEARNERS]=.., mixed other sugar options %s", ErrInvalidPlacementOptions, options) - } - - constraints := options.Constraints - leaderConst := options.LeaderConstraints - learnerConstraints := options.LearnerConstraints - followerConstraints := options.FollowerConstraints - explicitFollowerCount := options.Followers - explicitLearnerCount := options.Learners - - rules := []*pd.Rule{} - commonConstraints, err := NewConstraintsFromYaml([]byte(constraints)) - if err != nil { - // If it's not in array format, attempt to parse it as a dictionary for more detailed definitions. - // The dictionary format specifies details for each replica. Constraints are used to define normal - // replicas that should act as voters. - // For example: CONSTRAINTS='{ "+region=us-east-1":2, "+region=us-east-2": 2, "+region=us-west-1": 1}' - normalReplicasRules, err := NewRuleBuilder(). - SetRole(pd.Voter). - SetConstraintStr(constraints). - BuildRulesWithDictConstraintsOnly() - if err != nil { - return nil, err - } - rules = append(rules, normalReplicasRules...) 
- } - needCreateDefault := len(rules) == 0 - leaderConstraints, err := NewConstraintsFromYaml([]byte(leaderConst)) - if err != nil { - return nil, fmt.Errorf("%w: 'LeaderConstraints' should be [constraint1, ...] or any yaml compatible array representation", err) - } - for _, cnst := range commonConstraints { - if err := AddConstraint(&leaderConstraints, cnst); err != nil { - return nil, fmt.Errorf("%w: LeaderConstraints conflicts with Constraints", err) - } - } - leaderReplicas, followerReplicas := uint64(1), uint64(2) - if explicitFollowerCount > 0 { - followerReplicas = explicitFollowerCount - } - if !needCreateDefault { - if len(leaderConst) == 0 { - leaderReplicas = 0 - } - if len(followerConstraints) == 0 { - if explicitFollowerCount > 0 { - return nil, fmt.Errorf("%w: specify follower count without specify follower constraints when specify other constraints", ErrInvalidPlacementOptions) - } - followerReplicas = 0 - } - } - - // create leader rule. - // if no constraints, we need create default leader rule. - if leaderReplicas > 0 { - leaderRule := NewRule(pd.Leader, leaderReplicas, leaderConstraints) - rules = append(rules, leaderRule) - } - - // create follower rules. - // if no constraints, we need create default follower rules. - if followerReplicas > 0 { - builder := NewRuleBuilder(). - SetRole(pd.Voter). - SetReplicasNum(followerReplicas). - SetSkipCheckReplicasConsistent(needCreateDefault && (explicitFollowerCount == 0)). - SetConstraintStr(followerConstraints) - followerRules, err := builder.BuildRules() - if err != nil { - return nil, fmt.Errorf("%w: invalid FollowerConstraints", err) - } - for _, followerRule := range followerRules { - for _, cnst := range commonConstraints { - if err := AddConstraint(&followerRule.LabelConstraints, cnst); err != nil { - return nil, fmt.Errorf("%w: FollowerConstraints conflicts with Constraints", err) - } - } - } - rules = append(rules, followerRules...) - } - - // create learner rules. - builder := NewRuleBuilder(). - SetRole(pd.Learner). - SetReplicasNum(explicitLearnerCount). - SetConstraintStr(learnerConstraints) - learnerRules, err := builder.BuildRules() - if err != nil { - return nil, fmt.Errorf("%w: invalid LearnerConstraints", err) - } - for _, rule := range learnerRules { - for _, cnst := range commonConstraints { - if err := AddConstraint(&rule.LabelConstraints, cnst); err != nil { - return nil, fmt.Errorf("%w: LearnerConstraints conflicts with Constraints", err) - } - } - } - rules = append(rules, learnerRules...) - labels, err := newLocationLabelsFromSurvivalPreferences(options.SurvivalPreferences) - if err != nil { - return nil, err - } - for _, rule := range rules { - rule.LocationLabels = labels - } - return &Bundle{Rules: rules}, nil -} - -// NewBundleFromSugarOptions will transform syntax sugar options into the bundle. -func NewBundleFromSugarOptions(options *model.PlacementSettings) (*Bundle, error) { - if options == nil { - return nil, fmt.Errorf("%w: options can not be nil", ErrInvalidPlacementOptions) - } - - if len(options.LeaderConstraints) > 0 || len(options.LearnerConstraints) > 0 || len(options.FollowerConstraints) > 0 || len(options.Constraints) > 0 || options.Learners > 0 { - return nil, fmt.Errorf("%w: should be PRIMARY_REGION=.. REGIONS=.. FOLLOWERS=.. 
SCHEDULE=.., mixed other constraints into options %s", ErrInvalidPlacementOptions, options) - } - - primaryRegion := strings.TrimSpace(options.PrimaryRegion) - - var regions []string - if k := strings.TrimSpace(options.Regions); len(k) > 0 { - regions = strings.Split(k, ",") - for i, r := range regions { - regions[i] = strings.TrimSpace(r) - } - } - - followers := options.Followers - if followers == 0 { - followers = 2 - } - schedule := options.Schedule - - var rules []*pd.Rule - - locationLabels, err := newLocationLabelsFromSurvivalPreferences(options.SurvivalPreferences) - if err != nil { - return nil, err - } - - // in case empty primaryRegion and regions, just return an empty bundle - if primaryRegion == "" && len(regions) == 0 { - rules = append(rules, NewRule(pd.Voter, followers+1, NewConstraintsDirect())) - for _, rule := range rules { - rule.LocationLabels = locationLabels - } - return &Bundle{Rules: rules}, nil - } - - // regions must include the primary - slices.Sort(regions) - primaryIndex := sort.SearchStrings(regions, primaryRegion) - if primaryIndex >= len(regions) || regions[primaryIndex] != primaryRegion { - return nil, fmt.Errorf("%w: primary region must be included in regions", ErrInvalidPlacementOptions) - } - - // primaryCount only makes sense when len(regions) > 0 - // but we will compute it here anyway for reusing code - var primaryCount uint64 - switch strings.ToLower(schedule) { - case "", "even": - primaryCount = uint64(math.Ceil(float64(followers+1) / float64(len(regions)))) - case "majority_in_primary": - // calculate how many replicas need to be in the primary region for quorum - primaryCount = (followers+1)/2 + 1 - default: - return nil, fmt.Errorf("%w: unsupported schedule %s", ErrInvalidPlacementOptions, schedule) - } - - rules = append(rules, NewRule(pd.Leader, 1, NewConstraintsDirect(NewConstraintDirect("region", pd.In, primaryRegion)))) - if primaryCount > 1 { - rules = append(rules, NewRule(pd.Voter, primaryCount-1, NewConstraintsDirect(NewConstraintDirect("region", pd.In, primaryRegion)))) - } - if cnt := followers + 1 - primaryCount; cnt > 0 { - // delete primary from regions - regions = regions[:primaryIndex+copy(regions[primaryIndex:], regions[primaryIndex+1:])] - if len(regions) > 0 { - rules = append(rules, NewRule(pd.Voter, cnt, NewConstraintsDirect(NewConstraintDirect("region", pd.In, regions...)))) - } else { - rules = append(rules, NewRule(pd.Voter, cnt, NewConstraintsDirect())) - } - } - - // set location labels - for _, rule := range rules { - rule.LocationLabels = locationLabels - } - - return &Bundle{Rules: rules}, nil -} - -// Non-Exported functionality function, do not use it directly but NewBundleFromOptions -// here is for only directly used in the test. 
-func newBundleFromOptions(options *model.PlacementSettings) (bundle *Bundle, err error) { - if options == nil { - return nil, fmt.Errorf("%w: options can not be nil", ErrInvalidPlacementOptions) - } - - if options.Followers > uint64(8) { - return nil, fmt.Errorf("%w: followers should be less than or equal to 8: %d", ErrInvalidPlacementOptions, options.Followers) - } - - // always prefer the sugar syntax, which gives better schedule results most of the time - isSyntaxSugar := true - if len(options.LeaderConstraints) > 0 || len(options.LearnerConstraints) > 0 || len(options.FollowerConstraints) > 0 || len(options.Constraints) > 0 || options.Learners > 0 { - isSyntaxSugar = false - } - - if isSyntaxSugar { - bundle, err = NewBundleFromSugarOptions(options) - } else { - bundle, err = NewBundleFromConstraintsOptions(options) - } - return bundle, err -} - -// newLocationLabelsFromSurvivalPreferences will parse the survival preferences into location labels. -func newLocationLabelsFromSurvivalPreferences(survivalPreferenceStr string) ([]string, error) { - if len(survivalPreferenceStr) > 0 { - labels := []string{} - err := yaml.UnmarshalStrict([]byte(survivalPreferenceStr), &labels) - if err != nil { - return nil, ErrInvalidSurvivalPreferenceFormat - } - return labels, nil - } - return nil, nil -} - -// NewBundleFromOptions will transform options into the bundle. -func NewBundleFromOptions(options *model.PlacementSettings) (bundle *Bundle, err error) { - bundle, err = newBundleFromOptions(options) - if err != nil { - return nil, err - } - if bundle == nil { - return nil, nil - } - err = bundle.Tidy() - if err != nil { - return nil, err - } - return bundle, err -} - -// String implements fmt.Stringer. -func (b *Bundle) String() string { - t, err := json.Marshal(b) - failpoint.Inject("MockMarshalFailure", func(val failpoint.Value) { - if _, ok := val.(bool); ok { - err = errors.New("test") - } - }) - if err != nil { - return "" - } - return string(t) -} - -// Tidy will post optimize Rules, trying to generate rules that suits PD. -func (b *Bundle) Tidy() error { - tempRules := b.Rules[:0] - id := 0 - for _, rule := range b.Rules { - // useless Rule - if rule.Count <= 0 { - continue - } - // refer to tidb#22065. - // add -engine=tiflash to every rule to avoid schedules to tiflash instances. - // placement rules in SQL is not compatible with `set tiflash replica` yet - err := AddConstraint(&rule.LabelConstraints, pd.LabelConstraint{ - Op: pd.NotIn, - Key: EngineLabelKey, - Values: []string{EngineLabelTiFlash}, - }) - if err != nil { - return err - } - rule.ID = strconv.Itoa(id) - tempRules = append(tempRules, rule) - id++ - } - - groups := make(map[string]*constraintsGroup) - finalRules := tempRules[:0] - for _, rule := range tempRules { - key := ConstraintsFingerPrint(&rule.LabelConstraints) - existing, ok := groups[key] - if !ok { - groups[key] = &constraintsGroup{rules: []*pd.Rule{rule}} - continue - } - existing.rules = append(existing.rules, rule) - } - for _, group := range groups { - group.MergeRulesByRole() - } - if err := transformableLeaderConstraint(groups); err != nil { - return err - } - for _, group := range groups { - finalRules = append(finalRules, group.rules...) - } - // sort by id - sort.SliceStable(finalRules, func(i, j int) bool { - return finalRules[i].ID < finalRules[j].ID - }) - b.Rules = finalRules - return nil -} - -// constraintsGroup is a group of rules with the same constraints. 
-type constraintsGroup struct { - rules []*pd.Rule - // canBecameLeader means the group has leader/voter role, - // it's valid if it has leader. - canBecameLeader bool - // isLeaderGroup means it has specified leader role in this group. - isLeaderGroup bool -} - -func transformableLeaderConstraint(groups map[string]*constraintsGroup) error { - var leaderGroup *constraintsGroup - canBecameLeaderNum := 0 - for _, group := range groups { - if group.isLeaderGroup { - if leaderGroup != nil { - return ErrInvalidPlacementOptions - } - leaderGroup = group - } - if group.canBecameLeader { - canBecameLeaderNum++ - } - } - // If there is a specified group should have leader, and only this group can be a leader, that means - // the leader's priority is certain, so we can merge the transformable rules into one. - // eg: - // - [ group1 (L F), group2 (F) ], after merging is [group1 (2*V), group2 (F)], we still know the leader prefers group1. - // - [ group1 (L F), group2 (V) ], after merging is [group1 (2*V), group2 (V)], we can't know leader priority after merge. - if leaderGroup != nil && canBecameLeaderNum == 1 { - leaderGroup.MergeTransformableRoles() - } - return nil -} - -// MergeRulesByRole merges the rules with the same role. -func (c *constraintsGroup) MergeRulesByRole() { - // Create a map to store rules by role - rulesByRole := make(map[pd.PeerRoleType][]*pd.Rule) - - // Iterate through each rule - for _, rule := range c.rules { - // Add the rule to the map based on its role - rulesByRole[rule.Role] = append(rulesByRole[rule.Role], rule) - if rule.Role == pd.Leader || rule.Role == pd.Voter { - c.canBecameLeader = true - } - if rule.Role == pd.Leader { - c.isLeaderGroup = true - } - } - - // Clear existing rules - c.rules = nil - - // Iterate through each role and merge the rules - for _, rules := range rulesByRole { - mergedRule := rules[0] - for i, rule := range rules { - if i == 0 { - continue - } - mergedRule.Count += rule.Count - if mergedRule.ID > rule.ID { - mergedRule.ID = rule.ID - } - } - c.rules = append(c.rules, mergedRule) - } -} - -// MergeTransformableRoles merges all the rules to one that can be transformed to other roles. -func (c *constraintsGroup) MergeTransformableRoles() { - if len(c.rules) == 0 || len(c.rules) == 1 { - return - } - var mergedRule *pd.Rule - newRules := make([]*pd.Rule, 0, len(c.rules)) - for _, rule := range c.rules { - // Learner is not transformable, it should be promote by PD. - if rule.Role == pd.Learner { - newRules = append(newRules, rule) - continue - } - if mergedRule == nil { - mergedRule = rule - continue - } - mergedRule.Count += rule.Count - if mergedRule.ID > rule.ID { - mergedRule.ID = rule.ID - } - } - if mergedRule != nil { - mergedRule.Role = pd.Voter - newRules = append(newRules, mergedRule) - } - c.rules = newRules -} - -// GetRangeStartAndEndKeyHex get startKeyHex and endKeyHex of range by rangeBundleID. -func GetRangeStartAndEndKeyHex(rangeBundleID string) (startKey string, endKey string) { - startKey, endKey = "", "" - if rangeBundleID == TiDBBundleRangePrefixForMeta { - startKey = hex.EncodeToString(metaPrefix) - endKey = hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(0))) - } - return startKey, endKey -} - -// RebuildForRange rebuilds the bundle for system range. 
-func (b *Bundle) RebuildForRange(rangeName string, policyName string) *Bundle { - rule := b.Rules - switch rangeName { - case KeyRangeGlobal: - b.ID = TiDBBundleRangePrefixForGlobal - b.Index = RuleIndexKeyRangeForGlobal - case KeyRangeMeta: - b.ID = TiDBBundleRangePrefixForMeta - b.Index = RuleIndexKeyRangeForMeta - } - - startKey, endKey := GetRangeStartAndEndKeyHex(b.ID) - b.Override = true - newRules := make([]*pd.Rule, 0, len(rule)) - for i, r := range b.Rules { - cp := r.Clone() - cp.ID = fmt.Sprintf("%s_rule_%d", strings.ToLower(policyName), i) - cp.GroupID = b.ID - cp.StartKeyHex = startKey - cp.EndKeyHex = endKey - cp.Index = i - newRules = append(newRules, cp) - } - b.Rules = newRules - return b -} - -// Reset resets the bundle ID and keyrange of all rules. -func (b *Bundle) Reset(ruleIndex int, newIDs []int64) *Bundle { - // eliminate the redundant rules. - var basicRules []*pd.Rule - if len(b.Rules) != 0 { - // Make priority for rules with RuleIndexTable cause of duplication rules existence with RuleIndexPartition. - // If RuleIndexTable doesn't exist, bundle itself is a independent series of rules for a partition. - for _, rule := range b.Rules { - if rule.Index == RuleIndexTable { - basicRules = append(basicRules, rule) - } - } - if len(basicRules) == 0 { - basicRules = b.Rules - } - } - - // extend and reset basic rules for all new ids, the first id should be the group id. - b.ID = GroupID(newIDs[0]) - b.Index = ruleIndex - b.Override = true - newRules := make([]*pd.Rule, 0, len(basicRules)*len(newIDs)) - for i, newID := range newIDs { - // rule.id should be distinguished with each other, otherwise it will be de-duplicated in pd http api. - var ruleID string - if ruleIndex == RuleIndexPartition { - ruleID = "partition_rule_" + strconv.FormatInt(newID, 10) - } else { - if i == 0 { - ruleID = "table_rule_" + strconv.FormatInt(newID, 10) - } else { - ruleID = "partition_rule_" + strconv.FormatInt(newID, 10) - } - } - // Involve all the table level objects. - startKey := hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(newID))) - endKey := hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(newID+1))) - for j, rule := range basicRules { - clone := rule.Clone() - // for the rules of one element id, distinguishing the rule ids to avoid the PD's overlap. - clone.ID = ruleID + "_" + strconv.FormatInt(int64(j), 10) - clone.GroupID = b.ID - clone.StartKeyHex = startKey - clone.EndKeyHex = endKey - if i == 0 { - clone.Index = RuleIndexTable - } else { - clone.Index = RuleIndexPartition - } - newRules = append(newRules, clone) - } - } - b.Rules = newRules - return b -} - -// Clone is used to duplicate a bundle. -func (b *Bundle) Clone() *Bundle { - newBundle := &Bundle{} - *newBundle = *b - if len(b.Rules) > 0 { - newBundle.Rules = make([]*pd.Rule, 0, len(b.Rules)) - for i := range b.Rules { - newBundle.Rules = append(newBundle.Rules, b.Rules[i].Clone()) - } - } - return newBundle -} - -// IsEmpty is used to check if a bundle is empty. -func (b *Bundle) IsEmpty() bool { - return len(b.Rules) == 0 && b.Index == 0 && !b.Override -} - -// ObjectID extracts the db/table/partition ID from the group ID -func (b *Bundle) ObjectID() (int64, error) { - // If the rule doesn't come from TiDB, skip it. 
- if !strings.HasPrefix(b.ID, BundleIDPrefix) { - return 0, ErrInvalidBundleIDFormat - } - id, err := strconv.ParseInt(b.ID[len(BundleIDPrefix):], 10, 64) - if err != nil { - return 0, fmt.Errorf("%w: %s", ErrInvalidBundleID, err) - } - if id <= 0 { - return 0, fmt.Errorf("%w: %s doesn't include an id", ErrInvalidBundleID, b.ID) - } - return id, nil -} - -func isValidLeaderRule(rule *pd.Rule, dcLabelKey string) bool { - if rule.Role == pd.Leader && rule.Count == 1 { - for _, con := range rule.LabelConstraints { - if con.Op == pd.In && con.Key == dcLabelKey && len(con.Values) == 1 { - return true - } - } - } - return false -} - -// GetLeaderDC returns the leader's DC by Bundle if found. -func (b *Bundle) GetLeaderDC(dcLabelKey string) (string, bool) { - for _, rule := range b.Rules { - if isValidLeaderRule(rule, dcLabelKey) { - return rule.LabelConstraints[0].Values[0], true - } - } - return "", false -} - -// PolicyGetter is the interface to get the policy -type PolicyGetter interface { - GetPolicy(policyID int64) (*model.PolicyInfo, error) -} - -// NewTableBundle creates a bundle for table key range. -// If table is a partitioned table, it also contains the rules that inherited from table for every partition. -// The bundle does not contain the rules specified independently by each partition -func NewTableBundle(getter PolicyGetter, tbInfo *model.TableInfo) (*Bundle, error) { - bundle, err := newBundleFromPolicy(getter, tbInfo.PlacementPolicyRef) - if err != nil { - return nil, err - } - - if bundle == nil { - return nil, nil - } - ids := []int64{tbInfo.ID} - // build the default partition rules in the table-level bundle. - if tbInfo.Partition != nil { - for _, pDef := range tbInfo.Partition.Definitions { - ids = append(ids, pDef.ID) - } - } - bundle.Reset(RuleIndexTable, ids) - return bundle, nil -} - -// NewPartitionBundle creates a bundle for partition key range. -// It only contains the rules specified independently by the partition. -// That is to say the inherited rules from table is not included. -func NewPartitionBundle(getter PolicyGetter, def model.PartitionDefinition) (*Bundle, error) { - bundle, err := newBundleFromPolicy(getter, def.PlacementPolicyRef) - if err != nil { - return nil, err - } - - if bundle != nil { - bundle.Reset(RuleIndexPartition, []int64{def.ID}) - } - - return bundle, nil -} - -// NewPartitionListBundles creates a bundle list for a partition list -func NewPartitionListBundles(getter PolicyGetter, defs []model.PartitionDefinition) ([]*Bundle, error) { - bundles := make([]*Bundle, 0, len(defs)) - // If the partition has the placement rules on their own, build the partition-level bundles additionally. - for _, def := range defs { - bundle, err := NewPartitionBundle(getter, def) - if err != nil { - return nil, err - } - - if bundle != nil { - bundles = append(bundles, bundle) - } - } - return bundles, nil -} - -// NewFullTableBundles returns a bundle list with both table bundle and partition bundles -func NewFullTableBundles(getter PolicyGetter, tbInfo *model.TableInfo) ([]*Bundle, error) { - var bundles []*Bundle - tableBundle, err := NewTableBundle(getter, tbInfo) - if err != nil { - return nil, err - } - - if tableBundle != nil { - bundles = append(bundles, tableBundle) - } - - if tbInfo.Partition != nil { - partitionBundles, err := NewPartitionListBundles(getter, tbInfo.Partition.Definitions) - if err != nil { - return nil, err - } - bundles = append(bundles, partitionBundles...) 
- } - - return bundles, nil -} - -func newBundleFromPolicy(getter PolicyGetter, ref *model.PolicyRefInfo) (*Bundle, error) { - if ref != nil { - policy, err := getter.GetPolicy(ref.ID) - if err != nil { - return nil, err - } - - return NewBundleFromOptions(policy.PlacementSettings) - } - - return nil, nil -} diff --git a/pkg/ddl/reorg.go b/pkg/ddl/reorg.go index b00762f6ce8bd..2154974826a5a 100644 --- a/pkg/ddl/reorg.go +++ b/pkg/ddl/reorg.go @@ -733,15 +733,15 @@ func getReorgInfo(ctx *JobContext, d *ddlCtx, rh *reorgHandler, job *model.Job, if job.SnapshotVer == 0 { // For the case of the old TiDB version(do not exist the element information) is upgraded to the new TiDB version. // Third step, we need to remove the element information to make sure we can save the reorganized information to storage. - if val, _err_ := failpoint.Eval(_curpkg_("MockGetIndexRecordErr")); _err_ == nil { + failpoint.Inject("MockGetIndexRecordErr", func(val failpoint.Value) { if val.(string) == "addIdxNotOwnerErr" && atomic.CompareAndSwapUint32(&mockNotOwnerErrOnce, 3, 4) { if err := rh.RemoveReorgElementFailPoint(job); err != nil { - return nil, errors.Trace(err) + failpoint.Return(nil, errors.Trace(err)) } info.first = true - return &info, nil + failpoint.Return(&info, nil) } - } + }) info.first = true if d.lease > 0 { // Only delay when it's not in test. @@ -776,9 +776,9 @@ func getReorgInfo(ctx *JobContext, d *ddlCtx, rh *reorgHandler, job *model.Job, zap.String("startKey", hex.EncodeToString(start)), zap.String("endKey", hex.EncodeToString(end))) - if _, _err_ := failpoint.Eval(_curpkg_("errorUpdateReorgHandle")); _err_ == nil { + failpoint.Inject("errorUpdateReorgHandle", func() (*reorgInfo, error) { return &info, errors.New("occur an error when update reorg handle") - } + }) err = rh.InitDDLReorgHandle(job, start, end, pid, elements[0]) if err != nil { return &info, errors.Trace(err) @@ -787,16 +787,16 @@ func getReorgInfo(ctx *JobContext, d *ddlCtx, rh *reorgHandler, job *model.Job, job.SnapshotVer = ver.Ver element = elements[0] } else { - if val, _err_ := failpoint.Eval(_curpkg_("MockGetIndexRecordErr")); _err_ == nil { + failpoint.Inject("MockGetIndexRecordErr", func(val failpoint.Value) { // For the case of the old TiDB version(do not exist the element information) is upgraded to the new TiDB version. // Second step, we need to remove the element information to make sure we can get the error of "ErrDDLReorgElementNotExist". // However, since "txn.Reset()" will be called later, the reorganized information cannot be saved to storage. if val.(string) == "addIdxNotOwnerErr" && atomic.CompareAndSwapUint32(&mockNotOwnerErrOnce, 2, 3) { if err := rh.RemoveReorgElementFailPoint(job); err != nil { - return nil, errors.Trace(err) + failpoint.Return(nil, errors.Trace(err)) } } - } + }) var err error element, start, end, pid, err = rh.GetDDLReorgHandle(job) diff --git a/pkg/ddl/reorg.go__failpoint_stash__ b/pkg/ddl/reorg.go__failpoint_stash__ deleted file mode 100644 index 2154974826a5a..0000000000000 --- a/pkg/ddl/reorg.go__failpoint_stash__ +++ /dev/null @@ -1,982 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ddl - -import ( - "context" - "encoding/hex" - "fmt" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/ddl/ingest" - "github.com/pingcap/tidb/pkg/ddl/logutil" - sess "github.com/pingcap/tidb/pkg/ddl/session" - "github.com/pingcap/tidb/pkg/distsql" - distsqlctx "github.com/pingcap/tidb/pkg/distsql/context" - "github.com/pingcap/tidb/pkg/errctx" - exprctx "github.com/pingcap/tidb/pkg/expression/context" - "github.com/pingcap/tidb/pkg/expression/contextstatic" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - "github.com/pingcap/tidb/pkg/statistics" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/table/tables" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/codec" - contextutil "github.com/pingcap/tidb/pkg/util/context" - "github.com/pingcap/tidb/pkg/util/dbterror" - "github.com/pingcap/tidb/pkg/util/mock" - "github.com/pingcap/tidb/pkg/util/ranger" - "github.com/pingcap/tidb/pkg/util/timeutil" - "github.com/pingcap/tipb/go-tipb" - atomicutil "go.uber.org/atomic" - "go.uber.org/zap" -) - -// reorgCtx is for reorganization. -type reorgCtx struct { - // doneCh is used to notify. - // If the reorganization job is done, we will use this channel to notify outer. - // TODO: Now we use goroutine to simulate reorganization jobs, later we may - // use a persistent job list. - doneCh chan reorgFnResult - // rowCount is used to simulate a job's row count. - rowCount int64 - jobState model.JobState - - mu struct { - sync.Mutex - // warnings are used to store the warnings when doing the reorg job under certain SQL modes. - warnings map[errors.ErrorID]*terror.Error - warningsCount map[errors.ErrorID]int64 - } - - references atomicutil.Int32 -} - -// reorgFnResult records the DDL owner TS before executing reorg function, in order to help -// receiver determine if the result is from reorg function of previous DDL owner in this instance. -type reorgFnResult struct { - ownerTS int64 - err error -} - -func newReorgExprCtx() exprctx.ExprContext { - evalCtx := contextstatic.NewStaticEvalContext( - contextstatic.WithSQLMode(mysql.ModeNone), - contextstatic.WithTypeFlags(types.DefaultStmtFlags), - contextstatic.WithErrLevelMap(stmtctx.DefaultStmtErrLevels), - ) - - planCacheTracker := contextutil.NewPlanCacheTracker(contextutil.IgnoreWarn) - - return contextstatic.NewStaticExprContext( - contextstatic.WithEvalCtx(evalCtx), - contextstatic.WithPlanCacheTracker(&planCacheTracker), - ) -} - -func reorgTypeFlagsWithSQLMode(mode mysql.SQLMode) types.Flags { - return types.StrictFlags. - WithTruncateAsWarning(!mode.HasStrictMode()). 
- WithIgnoreInvalidDateErr(mode.HasAllowInvalidDatesMode()). - WithIgnoreZeroInDate(!mode.HasStrictMode() || mode.HasAllowInvalidDatesMode()). - WithCastTimeToYearThroughConcat(true) -} - -func reorgErrLevelsWithSQLMode(mode mysql.SQLMode) errctx.LevelMap { - return errctx.LevelMap{ - errctx.ErrGroupTruncate: errctx.ResolveErrLevel(false, !mode.HasStrictMode()), - errctx.ErrGroupBadNull: errctx.ResolveErrLevel(false, !mode.HasStrictMode()), - errctx.ErrGroupDividedByZero: errctx.ResolveErrLevel( - !mode.HasErrorForDivisionByZeroMode(), - !mode.HasStrictMode(), - ), - } -} - -func reorgTimeZoneWithTzLoc(tzLoc *model.TimeZoneLocation) (*time.Location, error) { - if tzLoc == nil { - // It is set to SystemLocation to be compatible with nil LocationInfo. - return timeutil.SystemLocation(), nil - } - return tzLoc.GetLocation() -} - -func newReorgSessCtx(store kv.Storage) sessionctx.Context { - c := mock.NewContext() - c.Store = store - c.GetSessionVars().SetStatusFlag(mysql.ServerStatusAutocommit, false) - - tz := *time.UTC - c.GetSessionVars().TimeZone = &tz - c.GetSessionVars().StmtCtx.SetTimeZone(&tz) - return c -} - -const defaultWaitReorgTimeout = 10 * time.Second - -// ReorgWaitTimeout is the timeout that wait ddl in write reorganization stage. -var ReorgWaitTimeout = 5 * time.Second - -func (rc *reorgCtx) notifyJobState(state model.JobState) { - atomic.StoreInt32((*int32)(&rc.jobState), int32(state)) -} - -func (rc *reorgCtx) isReorgCanceled() bool { - s := atomic.LoadInt32((*int32)(&rc.jobState)) - return int32(model.JobStateCancelled) == s || int32(model.JobStateCancelling) == s -} - -func (rc *reorgCtx) isReorgPaused() bool { - s := atomic.LoadInt32((*int32)(&rc.jobState)) - return int32(model.JobStatePaused) == s || int32(model.JobStatePausing) == s -} - -func (rc *reorgCtx) setRowCount(count int64) { - atomic.StoreInt64(&rc.rowCount, count) -} - -func (rc *reorgCtx) mergeWarnings(warnings map[errors.ErrorID]*terror.Error, warningsCount map[errors.ErrorID]int64) { - if len(warnings) == 0 || len(warningsCount) == 0 { - return - } - rc.mu.Lock() - defer rc.mu.Unlock() - rc.mu.warnings, rc.mu.warningsCount = mergeWarningsAndWarningsCount(warnings, rc.mu.warnings, warningsCount, rc.mu.warningsCount) -} - -func (rc *reorgCtx) resetWarnings() { - rc.mu.Lock() - defer rc.mu.Unlock() - rc.mu.warnings = make(map[errors.ErrorID]*terror.Error) - rc.mu.warningsCount = make(map[errors.ErrorID]int64) -} - -func (rc *reorgCtx) increaseRowCount(count int64) { - atomic.AddInt64(&rc.rowCount, count) -} - -func (rc *reorgCtx) getRowCount() int64 { - row := atomic.LoadInt64(&rc.rowCount) - return row -} - -// runReorgJob is used as a portal to do the reorganization work. -// eg: -// 1: add index -// 2: alter column type -// 3: clean global index -// 4: reorganize partitions -/* - ddl goroutine >---------+ - ^ | - | | - | | - | | <---(doneCh)--- f() - HandleDDLQueue(...) | <---(regular timeout) - | | <---(ctx done) - | | - | | - A more ddl round <-----+ -*/ -// How can we cancel reorg job? -// -// The background reorg is continuously running except for several factors, for instances, ddl owner change, -// logic error (kv duplicate when insert index / cast error when alter column), ctx done, and cancel signal. -// -// When `admin cancel ddl jobs xxx` takes effect, we will give this kind of reorg ddl one more round. -// because we should pull the result from doneCh out, otherwise, the reorg worker will hang on `f()` logic, -// which is a kind of goroutine leak. 
-// -// That's why we couldn't set the job to rollingback state directly in `convertJob2RollbackJob`, which is a -// cancelling portal for admin cancel action. -// -// In other words, the cancelling signal is informed from the bottom up, we set the atomic cancel variable -// in the cancelling portal to notify the lower worker goroutine, and fetch the cancel error from them in -// the additional ddl round. -// -// After that, we can make sure that the worker goroutine is correctly shut down. -func (w *worker) runReorgJob( - reorgInfo *reorgInfo, - tblInfo *model.TableInfo, - lease time.Duration, - reorgFn func() error, -) error { - job := reorgInfo.Job - d := reorgInfo.d - // This is for tests compatible, because most of the early tests try to build the reorg job manually - // without reorg meta info, which will cause nil pointer in here. - if job.ReorgMeta == nil { - job.ReorgMeta = &model.DDLReorgMeta{ - SQLMode: mysql.ModeNone, - Warnings: make(map[errors.ErrorID]*terror.Error), - WarningsCount: make(map[errors.ErrorID]int64), - Location: &model.TimeZoneLocation{Name: time.UTC.String(), Offset: 0}, - Version: model.CurrentReorgMetaVersion, - } - } - - rc := w.getReorgCtx(job.ID) - if rc == nil { - // This job is cancelling, we should return ErrCancelledDDLJob directly. - // - // Q: Is there any possibility that the job is cancelling and has no reorgCtx? - // A: Yes, consider the case that : - // - we cancel the job when backfilling the last batch of data, the cancel txn is commit first, - // - and then the backfill workers send signal to the `doneCh` of the reorgCtx, - // - and then the DDL worker will remove the reorgCtx - // - and update the DDL job to `done` - // - but at the commit time, the DDL txn will raise a "write conflict" error and retry, and it happens. - if job.IsCancelling() { - return dbterror.ErrCancelledDDLJob - } - - beOwnerTS := w.ddlCtx.reorgCtx.getOwnerTS() - rc = w.newReorgCtx(reorgInfo.Job.ID, reorgInfo.Job.GetRowCount()) - w.wg.Add(1) - go func() { - defer w.wg.Done() - err := reorgFn() - rc.doneCh <- reorgFnResult{ownerTS: beOwnerTS, err: err} - }() - } - - waitTimeout := defaultWaitReorgTimeout - // if lease is 0, we are using a local storage, - // and we can wait the reorganization to be done here. - // if lease > 0, we don't need to wait here because - // we should update some job's progress context and try checking again, - // so we use a very little timeout here. - if lease > 0 { - waitTimeout = ReorgWaitTimeout - } - - // wait reorganization job done or timeout - select { - case res := <-rc.doneCh: - err := res.err - curTS := w.ddlCtx.reorgCtx.getOwnerTS() - if res.ownerTS != curTS { - d.removeReorgCtx(job.ID) - logutil.DDLLogger().Warn("owner ts mismatch, return timeout error and retry", - zap.Int64("prevTS", res.ownerTS), - zap.Int64("curTS", curTS)) - return dbterror.ErrWaitReorgTimeout - } - // Since job is cancelled,we don't care about its partial counts. - if rc.isReorgCanceled() || terror.ErrorEqual(err, dbterror.ErrCancelledDDLJob) { - d.removeReorgCtx(job.ID) - return dbterror.ErrCancelledDDLJob - } - rowCount := rc.getRowCount() - job.SetRowCount(rowCount) - if err != nil { - logutil.DDLLogger().Warn("run reorg job done", zap.Int64("handled rows", rowCount), zap.Error(err)) - } else { - logutil.DDLLogger().Info("run reorg job done", zap.Int64("handled rows", rowCount)) - } - - // Update a job's warnings. 
- w.mergeWarningsIntoJob(job) - - d.removeReorgCtx(job.ID) - - updateBackfillProgress(w, reorgInfo, tblInfo, rowCount) - - // For other errors, even err is not nil here, we still wait the partial counts to be collected. - // since in the next round, the startKey is brand new which is stored by last time. - if err != nil { - return errors.Trace(err) - } - case <-time.After(waitTimeout): - rowCount := rc.getRowCount() - job.SetRowCount(rowCount) - updateBackfillProgress(w, reorgInfo, tblInfo, rowCount) - - // Update a job's warnings. - w.mergeWarningsIntoJob(job) - - rc.resetWarnings() - - logutil.DDLLogger().Info("run reorg job wait timeout", - zap.Duration("wait time", waitTimeout), - zap.Int64("total added row count", rowCount)) - // If timeout, we will return, check the owner and retry to wait job done again. - return dbterror.ErrWaitReorgTimeout - } - return nil -} - -func overwriteReorgInfoFromGlobalCheckpoint(w *worker, sess *sess.Session, job *model.Job, reorgInfo *reorgInfo) error { - if job.ReorgMeta.ReorgTp != model.ReorgTypeLitMerge { - // Only used for the ingest mode job. - return nil - } - if reorgInfo.mergingTmpIdx { - // Merging the temporary index uses txn mode, so we don't need to consider the checkpoint. - return nil - } - if job.ReorgMeta.IsDistReorg { - // The global checkpoint is not used in distributed tasks. - return nil - } - if w.getReorgCtx(job.ID) != nil { - // We only overwrite from checkpoint when the job runs for the first time on this TiDB instance. - return nil - } - start, pid, err := getImportedKeyFromCheckpoint(sess, job) - if err != nil { - return errors.Trace(err) - } - if pid != reorgInfo.PhysicalTableID { - // Current physical ID does not match checkpoint physical ID. - // Don't overwrite reorgInfo.StartKey. - return nil - } - if len(start) > 0 { - reorgInfo.StartKey = start - } - return nil -} - -func extractElemIDs(r *reorgInfo) []int64 { - elemIDs := make([]int64, 0, len(r.elements)) - for _, elem := range r.elements { - elemIDs = append(elemIDs, elem.ID) - } - return elemIDs -} - -func (w *worker) mergeWarningsIntoJob(job *model.Job) { - rc := w.getReorgCtx(job.ID) - rc.mu.Lock() - partWarnings := rc.mu.warnings - partWarningsCount := rc.mu.warningsCount - rc.mu.Unlock() - warnings, warningsCount := job.GetWarnings() - warnings, warningsCount = mergeWarningsAndWarningsCount(partWarnings, warnings, partWarningsCount, warningsCount) - job.SetWarnings(warnings, warningsCount) -} - -func updateBackfillProgress(w *worker, reorgInfo *reorgInfo, tblInfo *model.TableInfo, - addedRowCount int64) { - if tblInfo == nil { - return - } - progress := float64(0) - if addedRowCount != 0 { - totalCount := getTableTotalCount(w, tblInfo) - if totalCount > 0 { - progress = float64(addedRowCount) / float64(totalCount) - } else { - progress = 0 - } - if progress > 1 { - progress = 1 - } - logutil.DDLLogger().Debug("update progress", - zap.Float64("progress", progress), - zap.Int64("addedRowCount", addedRowCount), - zap.Int64("totalCount", totalCount)) - } - switch reorgInfo.Type { - case model.ActionAddIndex, model.ActionAddPrimaryKey: - var label string - if reorgInfo.mergingTmpIdx { - label = metrics.LblAddIndexMerge - } else { - label = metrics.LblAddIndex - } - metrics.GetBackfillProgressByLabel(label, reorgInfo.SchemaName, tblInfo.Name.String()).Set(progress * 100) - case model.ActionModifyColumn: - metrics.GetBackfillProgressByLabel(metrics.LblModifyColumn, reorgInfo.SchemaName, tblInfo.Name.String()).Set(progress * 100) - case model.ActionReorganizePartition, 
model.ActionRemovePartitioning, - model.ActionAlterTablePartitioning: - metrics.GetBackfillProgressByLabel(metrics.LblReorgPartition, reorgInfo.SchemaName, tblInfo.Name.String()).Set(progress * 100) - } -} - -func getTableTotalCount(w *worker, tblInfo *model.TableInfo) int64 { - var ctx sessionctx.Context - ctx, err := w.sessPool.Get() - if err != nil { - return statistics.PseudoRowCount - } - defer w.sessPool.Put(ctx) - - // `mock.Context` is used in tests, which doesn't support sql exec - if _, ok := ctx.(*mock.Context); ok { - return statistics.PseudoRowCount - } - - executor := ctx.GetRestrictedSQLExecutor() - var rows []chunk.Row - if tblInfo.Partition != nil && len(tblInfo.Partition.DroppingDefinitions) > 0 { - // if Reorganize Partition, only select number of rows from the selected partitions! - defs := tblInfo.Partition.DroppingDefinitions - partIDs := make([]string, 0, len(defs)) - for _, def := range defs { - partIDs = append(partIDs, strconv.FormatInt(def.ID, 10)) - } - sql := "select sum(table_rows) from information_schema.partitions where tidb_partition_id in (%?);" - rows, _, err = executor.ExecRestrictedSQL(w.ctx, nil, sql, strings.Join(partIDs, ",")) - } else { - sql := "select table_rows from information_schema.tables where tidb_table_id=%?;" - rows, _, err = executor.ExecRestrictedSQL(w.ctx, nil, sql, tblInfo.ID) - } - if err != nil { - return statistics.PseudoRowCount - } - if len(rows) != 1 { - return statistics.PseudoRowCount - } - return rows[0].GetInt64(0) -} - -func (dc *ddlCtx) isReorgCancelled(jobID int64) bool { - return dc.getReorgCtx(jobID).isReorgCanceled() -} -func (dc *ddlCtx) isReorgPaused(jobID int64) bool { - return dc.getReorgCtx(jobID).isReorgPaused() -} - -func (dc *ddlCtx) isReorgRunnable(jobID int64, isDistReorg bool) error { - if dc.ctx.Err() != nil { - // Worker is closed. So it can't do the reorganization. - return dbterror.ErrInvalidWorker.GenWithStack("worker is closed") - } - - if dc.isReorgCancelled(jobID) { - // Job is cancelled. So it can't be done. - return dbterror.ErrCancelledDDLJob - } - - if dc.isReorgPaused(jobID) { - logutil.DDLLogger().Warn("job paused by user", zap.String("ID", dc.uuid)) - return dbterror.ErrPausedDDLJob.GenWithStackByArgs(jobID) - } - - // If isDistReorg is true, we needn't check if it is owner. - if isDistReorg { - return nil - } - if !dc.isOwner() { - // If it's not the owner, we will try later, so here just returns an error. - logutil.DDLLogger().Info("DDL is not the DDL owner", zap.String("ID", dc.uuid)) - return errors.Trace(dbterror.ErrNotOwner) - } - return nil -} - -type reorgInfo struct { - *model.Job - - StartKey kv.Key - EndKey kv.Key - d *ddlCtx - first bool - mergingTmpIdx bool - // PhysicalTableID is used for partitioned table. - // DDL reorganize for a partitioned table will handle partitions one by one, - // PhysicalTableID is used to trace the current partition we are handling. - // If the table is not partitioned, PhysicalTableID would be TableID. 
- PhysicalTableID int64 - dbInfo *model.DBInfo - elements []*meta.Element - currElement *meta.Element -} - -func (r *reorgInfo) NewJobContext() *JobContext { - return r.d.jobContext(r.Job.ID, r.Job.ReorgMeta) -} - -func (r *reorgInfo) String() string { - var isEnabled bool - if ingest.LitInitialized { - _, isEnabled = ingest.LitBackCtxMgr.Load(r.Job.ID) - } - return "CurrElementType:" + string(r.currElement.TypeKey) + "," + - "CurrElementID:" + strconv.FormatInt(r.currElement.ID, 10) + "," + - "StartKey:" + hex.EncodeToString(r.StartKey) + "," + - "EndKey:" + hex.EncodeToString(r.EndKey) + "," + - "First:" + strconv.FormatBool(r.first) + "," + - "PhysicalTableID:" + strconv.FormatInt(r.PhysicalTableID, 10) + "," + - "Ingest mode:" + strconv.FormatBool(isEnabled) -} - -func constructDescTableScanPB(physicalTableID int64, tblInfo *model.TableInfo, handleCols []*model.ColumnInfo) *tipb.Executor { - tblScan := tables.BuildTableScanFromInfos(tblInfo, handleCols) - tblScan.TableId = physicalTableID - tblScan.Desc = true - return &tipb.Executor{Tp: tipb.ExecType_TypeTableScan, TblScan: tblScan} -} - -func constructLimitPB(count uint64) *tipb.Executor { - limitExec := &tipb.Limit{ - Limit: count, - } - return &tipb.Executor{Tp: tipb.ExecType_TypeLimit, Limit: limitExec} -} - -func buildDescTableScanDAG(distSQLCtx *distsqlctx.DistSQLContext, tbl table.PhysicalTable, handleCols []*model.ColumnInfo, limit uint64) (*tipb.DAGRequest, error) { - dagReq := &tipb.DAGRequest{} - _, timeZoneOffset := time.Now().In(time.UTC).Zone() - dagReq.TimeZoneOffset = int64(timeZoneOffset) - for i := range handleCols { - dagReq.OutputOffsets = append(dagReq.OutputOffsets, uint32(i)) - } - dagReq.Flags |= model.FlagInSelectStmt - - tblScanExec := constructDescTableScanPB(tbl.GetPhysicalID(), tbl.Meta(), handleCols) - dagReq.Executors = append(dagReq.Executors, tblScanExec) - dagReq.Executors = append(dagReq.Executors, constructLimitPB(limit)) - distsql.SetEncodeType(distSQLCtx, dagReq) - return dagReq, nil -} - -func getColumnsTypes(columns []*model.ColumnInfo) []*types.FieldType { - colTypes := make([]*types.FieldType, 0, len(columns)) - for _, col := range columns { - colTypes = append(colTypes, &col.FieldType) - } - return colTypes -} - -// buildDescTableScan builds a desc table scan upon tblInfo. -func (dc *ddlCtx) buildDescTableScan(ctx *JobContext, startTS uint64, tbl table.PhysicalTable, - handleCols []*model.ColumnInfo, limit uint64) (distsql.SelectResult, error) { - distSQLCtx := newDefaultReorgDistSQLCtx(dc.store.GetClient()) - dagPB, err := buildDescTableScanDAG(distSQLCtx, tbl, handleCols, limit) - if err != nil { - return nil, errors.Trace(err) - } - var b distsql.RequestBuilder - var builder *distsql.RequestBuilder - var ranges []*ranger.Range - if tbl.Meta().IsCommonHandle { - ranges = ranger.FullNotNullRange() - } else { - ranges = ranger.FullIntRange(false) - } - builder = b.SetHandleRanges(distSQLCtx, tbl.GetPhysicalID(), tbl.Meta().IsCommonHandle, ranges) - builder.SetDAGRequest(dagPB). - SetStartTS(startTS). - SetKeepOrder(true). - SetConcurrency(1). - SetDesc(true). - SetResourceGroupTagger(ctx.getResourceGroupTaggerForTopSQL()). 
- SetResourceGroupName(ctx.resourceGroupName) - - builder.Request.NotFillCache = true - builder.Request.Priority = kv.PriorityLow - builder.RequestSource.RequestSourceInternal = true - builder.RequestSource.RequestSourceType = ctx.ddlJobSourceType() - - kvReq, err := builder.Build() - if err != nil { - return nil, errors.Trace(err) - } - - result, err := distsql.Select(ctx.ddlJobCtx, distSQLCtx, kvReq, getColumnsTypes(handleCols)) - if err != nil { - return nil, errors.Trace(err) - } - return result, nil -} - -// GetTableMaxHandle gets the max handle of a PhysicalTable. -func (dc *ddlCtx) GetTableMaxHandle(ctx *JobContext, startTS uint64, tbl table.PhysicalTable) (maxHandle kv.Handle, emptyTable bool, err error) { - var handleCols []*model.ColumnInfo - var pkIdx *model.IndexInfo - tblInfo := tbl.Meta() - switch { - case tblInfo.PKIsHandle: - for _, col := range tbl.Meta().Columns { - if mysql.HasPriKeyFlag(col.GetFlag()) { - handleCols = []*model.ColumnInfo{col} - break - } - } - case tblInfo.IsCommonHandle: - pkIdx = tables.FindPrimaryIndex(tblInfo) - cols := tblInfo.Cols() - for _, idxCol := range pkIdx.Columns { - handleCols = append(handleCols, cols[idxCol.Offset]) - } - default: - handleCols = []*model.ColumnInfo{model.NewExtraHandleColInfo()} - } - - // build a desc scan of tblInfo, which limit is 1, we can use it to retrieve the last handle of the table. - result, err := dc.buildDescTableScan(ctx, startTS, tbl, handleCols, 1) - if err != nil { - return nil, false, errors.Trace(err) - } - defer terror.Call(result.Close) - - chk := chunk.New(getColumnsTypes(handleCols), 1, 1) - err = result.Next(ctx.ddlJobCtx, chk) - if err != nil { - return nil, false, errors.Trace(err) - } - - if chk.NumRows() == 0 { - // empty table - return nil, true, nil - } - row := chk.GetRow(0) - if tblInfo.IsCommonHandle { - maxHandle, err = buildCommonHandleFromChunkRow(time.UTC, tblInfo, pkIdx, handleCols, row) - return maxHandle, false, err - } - return kv.IntHandle(row.GetInt64(0)), false, nil -} - -func buildCommonHandleFromChunkRow(loc *time.Location, tblInfo *model.TableInfo, idxInfo *model.IndexInfo, - cols []*model.ColumnInfo, row chunk.Row) (kv.Handle, error) { - fieldTypes := make([]*types.FieldType, 0, len(cols)) - for _, col := range cols { - fieldTypes = append(fieldTypes, &col.FieldType) - } - datumRow := row.GetDatumRow(fieldTypes) - tablecodec.TruncateIndexValues(tblInfo, idxInfo, datumRow) - - var handleBytes []byte - handleBytes, err := codec.EncodeKey(loc, nil, datumRow...) - if err != nil { - return nil, err - } - return kv.NewCommonHandle(handleBytes) -} - -// getTableRange gets the start and end handle of a table (or partition). -func getTableRange(ctx *JobContext, d *ddlCtx, tbl table.PhysicalTable, snapshotVer uint64, priority int) (startHandleKey, endHandleKey kv.Key, err error) { - // Get the start handle of this partition. 
- err = iterateSnapshotKeys(ctx, d.store, priority, tbl.RecordPrefix(), snapshotVer, nil, nil, - func(_ kv.Handle, rowKey kv.Key, _ []byte) (bool, error) { - startHandleKey = rowKey - return false, nil - }) - if err != nil { - return startHandleKey, endHandleKey, errors.Trace(err) - } - maxHandle, isEmptyTable, err := d.GetTableMaxHandle(ctx, snapshotVer, tbl) - if err != nil { - return startHandleKey, nil, errors.Trace(err) - } - if maxHandle != nil { - endHandleKey = tablecodec.EncodeRecordKey(tbl.RecordPrefix(), maxHandle).Next() - } - if isEmptyTable || endHandleKey.Cmp(startHandleKey) <= 0 { - logutil.DDLLogger().Info("get noop table range", - zap.String("table", fmt.Sprintf("%v", tbl.Meta())), - zap.Int64("table/partition ID", tbl.GetPhysicalID()), - zap.String("start key", hex.EncodeToString(startHandleKey)), - zap.String("end key", hex.EncodeToString(endHandleKey)), - zap.Bool("is empty table", isEmptyTable)) - if startHandleKey == nil { - endHandleKey = nil - } else { - endHandleKey = startHandleKey.Next() - } - } - return -} - -func getValidCurrentVersion(store kv.Storage) (ver kv.Version, err error) { - ver, err = store.CurrentVersion(kv.GlobalTxnScope) - if err != nil { - return ver, errors.Trace(err) - } else if ver.Ver <= 0 { - return ver, dbterror.ErrInvalidStoreVer.GenWithStack("invalid storage current version %d", ver.Ver) - } - return ver, nil -} - -func getReorgInfo(ctx *JobContext, d *ddlCtx, rh *reorgHandler, job *model.Job, dbInfo *model.DBInfo, - tbl table.Table, elements []*meta.Element, mergingTmpIdx bool) (*reorgInfo, error) { - var ( - element *meta.Element - start kv.Key - end kv.Key - pid int64 - info reorgInfo - ) - - if job.SnapshotVer == 0 { - // For the case of the old TiDB version(do not exist the element information) is upgraded to the new TiDB version. - // Third step, we need to remove the element information to make sure we can save the reorganized information to storage. - failpoint.Inject("MockGetIndexRecordErr", func(val failpoint.Value) { - if val.(string) == "addIdxNotOwnerErr" && atomic.CompareAndSwapUint32(&mockNotOwnerErrOnce, 3, 4) { - if err := rh.RemoveReorgElementFailPoint(job); err != nil { - failpoint.Return(nil, errors.Trace(err)) - } - info.first = true - failpoint.Return(&info, nil) - } - }) - - info.first = true - if d.lease > 0 { // Only delay when it's not in test. 
- delayForAsyncCommit() - } - ver, err := getValidCurrentVersion(d.store) - if err != nil { - return nil, errors.Trace(err) - } - tblInfo := tbl.Meta() - pid = tblInfo.ID - var tb table.PhysicalTable - if pi := tblInfo.GetPartitionInfo(); pi != nil { - pid = pi.Definitions[0].ID - tb = tbl.(table.PartitionedTable).GetPartition(pid) - } else { - tb = tbl.(table.PhysicalTable) - } - if mergingTmpIdx { - firstElemTempID := tablecodec.TempIndexPrefix | elements[0].ID - lastElemTempID := tablecodec.TempIndexPrefix | elements[len(elements)-1].ID - start = tablecodec.EncodeIndexSeekKey(pid, firstElemTempID, nil) - end = tablecodec.EncodeIndexSeekKey(pid, lastElemTempID, []byte{255}) - } else { - start, end, err = getTableRange(ctx, d, tb, ver.Ver, job.Priority) - if err != nil { - return nil, errors.Trace(err) - } - } - logutil.DDLLogger().Info("job get table range", - zap.Int64("jobID", job.ID), zap.Int64("physicalTableID", pid), - zap.String("startKey", hex.EncodeToString(start)), - zap.String("endKey", hex.EncodeToString(end))) - - failpoint.Inject("errorUpdateReorgHandle", func() (*reorgInfo, error) { - return &info, errors.New("occur an error when update reorg handle") - }) - err = rh.InitDDLReorgHandle(job, start, end, pid, elements[0]) - if err != nil { - return &info, errors.Trace(err) - } - // Update info should after data persistent. - job.SnapshotVer = ver.Ver - element = elements[0] - } else { - failpoint.Inject("MockGetIndexRecordErr", func(val failpoint.Value) { - // For the case of the old TiDB version(do not exist the element information) is upgraded to the new TiDB version. - // Second step, we need to remove the element information to make sure we can get the error of "ErrDDLReorgElementNotExist". - // However, since "txn.Reset()" will be called later, the reorganized information cannot be saved to storage. - if val.(string) == "addIdxNotOwnerErr" && atomic.CompareAndSwapUint32(&mockNotOwnerErrOnce, 2, 3) { - if err := rh.RemoveReorgElementFailPoint(job); err != nil { - failpoint.Return(nil, errors.Trace(err)) - } - } - }) - - var err error - element, start, end, pid, err = rh.GetDDLReorgHandle(job) - if err != nil { - // If the reorg element doesn't exist, this reorg info should be saved by the older TiDB versions. - // It's compatible with the older TiDB versions. - // We'll try to remove it in the next major TiDB version. - if meta.ErrDDLReorgElementNotExist.Equal(err) { - job.SnapshotVer = 0 - logutil.DDLLogger().Warn("get reorg info, the element does not exist", zap.Stringer("job", job)) - if job.IsCancelling() { - return nil, nil - } - } - return &info, errors.Trace(err) - } - } - info.Job = job - info.d = d - info.StartKey = start - info.EndKey = end - info.PhysicalTableID = pid - info.currElement = element - info.elements = elements - info.mergingTmpIdx = mergingTmpIdx - info.dbInfo = dbInfo - - return &info, nil -} - -func getReorgInfoFromPartitions(ctx *JobContext, d *ddlCtx, rh *reorgHandler, job *model.Job, dbInfo *model.DBInfo, tbl table.PartitionedTable, partitionIDs []int64, elements []*meta.Element) (*reorgInfo, error) { - var ( - element *meta.Element - start kv.Key - end kv.Key - pid int64 - info reorgInfo - ) - if job.SnapshotVer == 0 { - info.first = true - if d.lease > 0 { // Only delay when it's not in test. 
- delayForAsyncCommit() - } - ver, err := getValidCurrentVersion(d.store) - if err != nil { - return nil, errors.Trace(err) - } - pid = partitionIDs[0] - physTbl := tbl.GetPartition(pid) - - start, end, err = getTableRange(ctx, d, physTbl, ver.Ver, job.Priority) - if err != nil { - return nil, errors.Trace(err) - } - logutil.DDLLogger().Info("job get table range", - zap.Int64("job ID", job.ID), zap.Int64("physical table ID", pid), - zap.String("start key", hex.EncodeToString(start)), - zap.String("end key", hex.EncodeToString(end))) - - err = rh.InitDDLReorgHandle(job, start, end, pid, elements[0]) - if err != nil { - return &info, errors.Trace(err) - } - // Update info should after data persistent. - job.SnapshotVer = ver.Ver - element = elements[0] - } else { - var err error - element, start, end, pid, err = rh.GetDDLReorgHandle(job) - if err != nil { - // If the reorg element doesn't exist, this reorg info should be saved by the older TiDB versions. - // It's compatible with the older TiDB versions. - // We'll try to remove it in the next major TiDB version. - if meta.ErrDDLReorgElementNotExist.Equal(err) { - job.SnapshotVer = 0 - logutil.DDLLogger().Warn("get reorg info, the element does not exist", zap.Stringer("job", job)) - } - return &info, errors.Trace(err) - } - } - info.Job = job - info.d = d - info.StartKey = start - info.EndKey = end - info.PhysicalTableID = pid - info.currElement = element - info.elements = elements - info.dbInfo = dbInfo - - return &info, nil -} - -// UpdateReorgMeta creates a new transaction and updates tidb_ddl_reorg table, -// so the reorg can restart in case of issues. -func (r *reorgInfo) UpdateReorgMeta(startKey kv.Key, pool *sess.Pool) (err error) { - if startKey == nil && r.EndKey == nil { - return nil - } - sctx, err := pool.Get() - if err != nil { - return - } - defer pool.Put(sctx) - - se := sess.NewSession(sctx) - err = se.Begin(context.Background()) - if err != nil { - return - } - rh := newReorgHandler(se) - err = updateDDLReorgHandle(rh.s, r.Job.ID, startKey, r.EndKey, r.PhysicalTableID, r.currElement) - err1 := se.Commit(context.Background()) - if err == nil { - err = err1 - } - return errors.Trace(err) -} - -// reorgHandler is used to handle the reorg information duration reorganization DDL job. -type reorgHandler struct { - s *sess.Session -} - -// NewReorgHandlerForTest creates a new reorgHandler, only used in test. -func NewReorgHandlerForTest(se sessionctx.Context) *reorgHandler { - return newReorgHandler(sess.NewSession(se)) -} - -func newReorgHandler(sess *sess.Session) *reorgHandler { - return &reorgHandler{s: sess} -} - -// InitDDLReorgHandle initializes the job reorganization information. -func (r *reorgHandler) InitDDLReorgHandle(job *model.Job, startKey, endKey kv.Key, physicalTableID int64, element *meta.Element) error { - return initDDLReorgHandle(r.s, job.ID, startKey, endKey, physicalTableID, element) -} - -// RemoveReorgElementFailPoint removes the element of the reorganization information. -func (r *reorgHandler) RemoveReorgElementFailPoint(job *model.Job) error { - return removeReorgElement(r.s, job) -} - -// RemoveDDLReorgHandle removes the job reorganization related handles. -func (r *reorgHandler) RemoveDDLReorgHandle(job *model.Job, elements []*meta.Element) error { - return removeDDLReorgHandle(r.s, job, elements) -} - -// cleanupDDLReorgHandles removes the job reorganization related handles. 
-func cleanupDDLReorgHandles(job *model.Job, s *sess.Session) { - if job != nil && !job.IsFinished() && !job.IsSynced() { - // Job is given, but it is neither finished nor synced; do nothing - return - } - - err := cleanDDLReorgHandles(s, job) - if err != nil { - // ignore error, cleanup is not that critical - logutil.DDLLogger().Warn("Failed removing the DDL reorg entry in tidb_ddl_reorg", zap.Stringer("job", job), zap.Error(err)) - } -} - -// GetDDLReorgHandle gets the latest processed DDL reorganize position. -func (r *reorgHandler) GetDDLReorgHandle(job *model.Job) (element *meta.Element, startKey, endKey kv.Key, physicalTableID int64, err error) { - element, startKey, endKey, physicalTableID, err = getDDLReorgHandle(r.s, job) - if err != nil { - return element, startKey, endKey, physicalTableID, err - } - adjustedEndKey := adjustEndKeyAcrossVersion(job, endKey) - return element, startKey, adjustedEndKey, physicalTableID, nil -} - -// #46306 changes the table range from [start_key, end_key] to [start_key, end_key.next). -// For old version TiDB, the semantic is still [start_key, end_key], we need to adjust it in new version TiDB. -func adjustEndKeyAcrossVersion(job *model.Job, endKey kv.Key) kv.Key { - if job.ReorgMeta != nil && job.ReorgMeta.Version == model.ReorgMetaVersion0 { - logutil.DDLLogger().Info("adjust range end key for old version ReorgMetas", - zap.Int64("jobID", job.ID), - zap.Int64("reorgMetaVersion", job.ReorgMeta.Version), - zap.String("endKey", hex.EncodeToString(endKey))) - return endKey.Next() - } - return endKey -} diff --git a/pkg/ddl/rollingback.go b/pkg/ddl/rollingback.go index 459955cc150d8..b72b2ee76d392 100644 --- a/pkg/ddl/rollingback.go +++ b/pkg/ddl/rollingback.go @@ -52,11 +52,11 @@ func convertAddIdxJob2RollbackJob( allIndexInfos []*model.IndexInfo, err error, ) (int64, error) { - if val, _err_ := failpoint.Eval(_curpkg_("mockConvertAddIdxJob2RollbackJobError")); _err_ == nil { + failpoint.Inject("mockConvertAddIdxJob2RollbackJobError", func(val failpoint.Value) { if val.(bool) { - return 0, errors.New("mock convert add index job to rollback job error") + failpoint.Return(0, errors.New("mock convert add index job to rollback job error")) } - } + }) originalState := allIndexInfos[0].State idxNames := make([]model.CIStr, 0, len(allIndexInfos)) diff --git a/pkg/ddl/rollingback.go__failpoint_stash__ b/pkg/ddl/rollingback.go__failpoint_stash__ deleted file mode 100644 index b72b2ee76d392..0000000000000 --- a/pkg/ddl/rollingback.go__failpoint_stash__ +++ /dev/null @@ -1,629 +0,0 @@ -// Copyright 2018 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package ddl - -import ( - "fmt" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/ddl/ingest" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/util/dbterror" - "go.uber.org/zap" -) - -// UpdateColsNull2NotNull changes the null option of columns of an index. -func UpdateColsNull2NotNull(tblInfo *model.TableInfo, indexInfo *model.IndexInfo) error { - nullCols, err := getNullColInfos(tblInfo, indexInfo) - if err != nil { - return errors.Trace(err) - } - - for _, col := range nullCols { - col.AddFlag(mysql.NotNullFlag) - col.DelFlag(mysql.PreventNullInsertFlag) - } - return nil -} - -func convertAddIdxJob2RollbackJob( - d *ddlCtx, - t *meta.Meta, - job *model.Job, - tblInfo *model.TableInfo, - allIndexInfos []*model.IndexInfo, - err error, -) (int64, error) { - failpoint.Inject("mockConvertAddIdxJob2RollbackJobError", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(0, errors.New("mock convert add index job to rollback job error")) - } - }) - - originalState := allIndexInfos[0].State - idxNames := make([]model.CIStr, 0, len(allIndexInfos)) - ifExists := make([]bool, 0, len(allIndexInfos)) - for _, indexInfo := range allIndexInfos { - if indexInfo.Primary { - nullCols, err := getNullColInfos(tblInfo, indexInfo) - if err != nil { - return 0, errors.Trace(err) - } - for _, col := range nullCols { - // Field PreventNullInsertFlag flag reset. - col.DelFlag(mysql.PreventNullInsertFlag) - } - } - // If add index job rollbacks in write reorganization state, its need to delete all keys which has been added. - // Its work is the same as drop index job do. - // The write reorganization state in add index job that likes write only state in drop index job. - // So the next state is delete only state. - indexInfo.State = model.StateDeleteOnly - idxNames = append(idxNames, indexInfo.Name) - ifExists = append(ifExists, false) - } - - // the second and the third args will be used in onDropIndex. - job.Args = []any{idxNames, ifExists, getPartitionIDs(tblInfo)} - job.SchemaState = model.StateDeleteOnly - ver, err1 := updateVersionAndTableInfo(d, t, job, tblInfo, originalState != model.StateDeleteOnly) - if err1 != nil { - return ver, errors.Trace(err1) - } - job.State = model.JobStateRollingback - // TODO(tangenta): get duplicate column and match index. - err = completeErr(err, allIndexInfos[0]) - if ingest.LitBackCtxMgr != nil { - ingest.LitBackCtxMgr.Unregister(job.ID) - } - return ver, errors.Trace(err) -} - -// convertNotReorgAddIdxJob2RollbackJob converts the add index job that are not started workers to rollingbackJob, -// to rollback add index operations. job.SnapshotVer == 0 indicates the workers are not started. 
-func convertNotReorgAddIdxJob2RollbackJob(d *ddlCtx, t *meta.Meta, job *model.Job, occuredErr error) (ver int64, err error) { - defer func() { - if ingest.LitBackCtxMgr != nil { - ingest.LitBackCtxMgr.Unregister(job.ID) - } - }() - schemaID := job.SchemaID - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) - if err != nil { - return ver, errors.Trace(err) - } - - unique := make([]bool, 1) - indexName := make([]model.CIStr, 1) - indexPartSpecifications := make([][]*ast.IndexPartSpecification, 1) - indexOption := make([]*ast.IndexOption, 1) - - err = job.DecodeArgs(&unique[0], &indexName[0], &indexPartSpecifications[0], &indexOption[0]) - if err != nil { - err = job.DecodeArgs(&unique, &indexName, &indexPartSpecifications, &indexOption) - } - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - var indexesInfo []*model.IndexInfo - for _, idxName := range indexName { - indexInfo := tblInfo.FindIndexByName(idxName.L) - if indexInfo != nil { - indexesInfo = append(indexesInfo, indexInfo) - } - } - if len(indexesInfo) == 0 { - job.State = model.JobStateCancelled - return ver, dbterror.ErrCancelledDDLJob - } - return convertAddIdxJob2RollbackJob(d, t, job, tblInfo, indexesInfo, occuredErr) -} - -// rollingbackModifyColumn change the modifying-column job into rolling back state. -// Since modifying column job has two types: normal-type and reorg-type, we should handle it respectively. -// normal-type has only two states: None -> Public -// reorg-type has five states: None -> Delete-only -> Write-only -> Write-org -> Public -func rollingbackModifyColumn(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { - if needNotifyAndStopReorgWorker(job) { - // column type change workers are started. we have to ask them to exit. - w.jobLogger(job).Info("run the cancelling DDL job", zap.String("job", job.String())) - d.notifyReorgWorkerJobStateChange(job) - // Give the this kind of ddl one more round to run, the dbterror.ErrCancelledDDLJob should be fetched from the bottom up. - return w.onModifyColumn(d, t, job) - } - _, tblInfo, oldCol, jp, err := getModifyColumnInfo(t, job) - if err != nil { - return ver, err - } - if !needChangeColumnData(oldCol, jp.newCol) { - // Normal-type rolling back - if job.SchemaState == model.StateNone { - // When change null to not null, although state is unchanged with none, the oldCol flag's has been changed to preNullInsertFlag. - // To roll back this kind of normal job, it is necessary to mark the state as JobStateRollingback to restore the old col's flag. - if jp.modifyColumnTp == mysql.TypeNull && tblInfo.Columns[oldCol.Offset].GetFlag()|mysql.PreventNullInsertFlag != 0 { - job.State = model.JobStateRollingback - return ver, dbterror.ErrCancelledDDLJob - } - // Normal job with stateNone can be cancelled directly. - job.State = model.JobStateCancelled - return ver, dbterror.ErrCancelledDDLJob - } - // StatePublic couldn't be cancelled. - job.State = model.JobStateRunning - return ver, nil - } - // reorg-type rolling back - if jp.changingCol == nil { - // The job hasn't been handled and we cancel it directly. - job.State = model.JobStateCancelled - return ver, dbterror.ErrCancelledDDLJob - } - // The job has been in its middle state (but the reorg worker hasn't started) and we roll it back here. 
- job.State = model.JobStateRollingback - return ver, dbterror.ErrCancelledDDLJob -} - -func rollingbackAddColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { - tblInfo, columnInfo, col, _, _, err := checkAddColumn(t, job) - if err != nil { - return ver, errors.Trace(err) - } - if columnInfo == nil { - job.State = model.JobStateCancelled - return ver, dbterror.ErrCancelledDDLJob - } - - originalState := columnInfo.State - columnInfo.State = model.StateDeleteOnly - job.SchemaState = model.StateDeleteOnly - - job.Args = []any{col.Name} - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != columnInfo.State) - if err != nil { - return ver, errors.Trace(err) - } - - job.State = model.JobStateRollingback - return ver, dbterror.ErrCancelledDDLJob -} - -func rollingbackDropColumn(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { - _, colInfo, idxInfos, _, err := checkDropColumn(d, t, job) - if err != nil { - return ver, errors.Trace(err) - } - - for _, indexInfo := range idxInfos { - switch indexInfo.State { - case model.StateWriteOnly, model.StateDeleteOnly, model.StateDeleteReorganization, model.StateNone: - // We can not rollback now, so just continue to drop index. - // In function isJobRollbackable will let job rollback when state is StateNone. - // When there is no index related to the drop column job it is OK, but when there has indices, we should - // make sure the job is not rollback. - job.State = model.JobStateRunning - return ver, nil - case model.StatePublic: - default: - return ver, dbterror.ErrInvalidDDLState.GenWithStackByArgs("index", indexInfo.State) - } - } - - // StatePublic means when the job is not running yet. - if colInfo.State == model.StatePublic { - job.State = model.JobStateCancelled - return ver, dbterror.ErrCancelledDDLJob - } - // In the state of drop column `write only -> delete only -> reorganization`, - // We can not rollback now, so just continue to drop column. - job.State = model.JobStateRunning - return ver, nil -} - -func rollingbackDropIndex(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { - _, indexInfo, _, err := checkDropIndex(d, t, job) - if err != nil { - return ver, errors.Trace(err) - } - - switch indexInfo[0].State { - case model.StateWriteOnly, model.StateDeleteOnly, model.StateDeleteReorganization, model.StateNone: - // We can not rollback now, so just continue to drop index. - // Normally won't fetch here, because there is check when cancel ddl jobs. see function: isJobRollbackable. - job.State = model.JobStateRunning - return ver, nil - case model.StatePublic: - job.State = model.JobStateCancelled - return ver, dbterror.ErrCancelledDDLJob - default: - return ver, dbterror.ErrInvalidDDLState.GenWithStackByArgs("index", indexInfo[0].State) - } -} - -func rollingbackAddIndex(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job, isPK bool) (ver int64, err error) { - if needNotifyAndStopReorgWorker(job) { - // add index workers are started. need to ask them to exit. - w.jobLogger(job).Info("run the cancelling DDL job", zap.String("job", job.String())) - d.notifyReorgWorkerJobStateChange(job) - ver, err = w.onCreateIndex(d, t, job, isPK) - } else { - // add index's reorg workers are not running, remove the indexInfo in tableInfo. 
- ver, err = convertNotReorgAddIdxJob2RollbackJob(d, t, job, dbterror.ErrCancelledDDLJob) - } - return -} - -func needNotifyAndStopReorgWorker(job *model.Job) bool { - if job.SchemaState == model.StateWriteReorganization && job.SnapshotVer != 0 { - // If the value of SnapshotVer isn't zero, it means the reorg workers have been started. - if job.MultiSchemaInfo != nil { - // However, if the sub-job is non-revertible, it means the reorg process is finished. - // We don't need to start another round to notify reorg workers to exit. - return job.MultiSchemaInfo.Revertible - } - return true - } - return false -} - -// rollbackExchangeTablePartition will clear the non-partitioned -// table's ExchangePartitionInfo state. -func rollbackExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Job, tblInfo *model.TableInfo) (ver int64, err error) { - tblInfo.ExchangePartitionInfo = nil - job.State = model.JobStateRollbackDone - job.SchemaState = model.StatePublic - if len(tblInfo.Constraints) == 0 { - return updateVersionAndTableInfo(d, t, job, tblInfo, true) - } - var ( - defID int64 - ptSchemaID int64 - ptID int64 - partName string - withValidation bool - ) - if err = job.DecodeArgs(&defID, &ptSchemaID, &ptID, &partName, &withValidation); err != nil { - return ver, errors.Trace(err) - } - pt, err := getTableInfo(t, ptID, ptSchemaID) - if err != nil { - return ver, errors.Trace(err) - } - pt.ExchangePartitionInfo = nil - var ptInfo []schemaIDAndTableInfo - ptInfo = append(ptInfo, schemaIDAndTableInfo{ - schemaID: ptSchemaID, - tblInfo: pt, - }) - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true, ptInfo...) - return ver, errors.Trace(err) -} - -func rollingbackExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { - if job.SchemaState == model.StateNone { - // Nothing is changed - job.State = model.JobStateCancelled - return ver, dbterror.ErrCancelledDDLJob - } - var nt *model.TableInfo - nt, err = GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) - if err != nil { - return ver, errors.Trace(err) - } - ver, err = rollbackExchangeTablePartition(d, t, job, nt) - return ver, errors.Trace(err) -} - -func convertAddTablePartitionJob2RollbackJob(d *ddlCtx, t *meta.Meta, job *model.Job, otherwiseErr error, tblInfo *model.TableInfo) (ver int64, err error) { - addingDefinitions := tblInfo.Partition.AddingDefinitions - partNames := make([]string, 0, len(addingDefinitions)) - for _, pd := range addingDefinitions { - partNames = append(partNames, pd.Name.L) - } - if job.Type == model.ActionReorganizePartition || - job.Type == model.ActionAlterTablePartitioning || - job.Type == model.ActionRemovePartitioning { - partInfo := &model.PartitionInfo{} - var pNames []string - err = job.DecodeArgs(&pNames, &partInfo) - if err != nil { - return ver, err - } - job.Args = []any{partNames, partInfo} - } else { - job.Args = []any{partNames} - } - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - job.State = model.JobStateRollingback - return ver, errors.Trace(otherwiseErr) -} - -func rollingbackAddTablePartition(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { - tblInfo, _, addingDefinitions, err := checkAddPartition(t, job) - if err != nil { - return ver, errors.Trace(err) - } - // addingDefinitions' len = 0 means the job hasn't reached the replica-only state. 
- if len(addingDefinitions) == 0 { - job.State = model.JobStateCancelled - return ver, errors.Trace(dbterror.ErrCancelledDDLJob) - } - // addingDefinitions is also in tblInfo, here pass the tblInfo as parameter directly. - return convertAddTablePartitionJob2RollbackJob(d, t, job, dbterror.ErrCancelledDDLJob, tblInfo) -} - -func rollingbackDropTableOrView(t *meta.Meta, job *model.Job) error { - tblInfo, err := checkTableExistAndCancelNonExistJob(t, job, job.SchemaID) - if err != nil { - return errors.Trace(err) - } - // To simplify the rollback logic, cannot be canceled after job start to run. - // Normally won't fetch here, because there is check when cancel ddl jobs. see function: isJobRollbackable. - if tblInfo.State == model.StatePublic { - job.State = model.JobStateCancelled - return dbterror.ErrCancelledDDLJob - } - job.State = model.JobStateRunning - return nil -} - -func rollingbackDropTablePartition(t *meta.Meta, job *model.Job) (ver int64, err error) { - _, err = GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) - if err != nil { - return ver, errors.Trace(err) - } - return cancelOnlyNotHandledJob(job, model.StatePublic) -} - -func rollingbackDropSchema(t *meta.Meta, job *model.Job) error { - dbInfo, err := checkSchemaExistAndCancelNotExistJob(t, job) - if err != nil { - return errors.Trace(err) - } - // To simplify the rollback logic, cannot be canceled after job start to run. - // Normally won't fetch here, because there is check when cancel ddl jobs. see function: isJobRollbackable. - if dbInfo.State == model.StatePublic { - job.State = model.JobStateCancelled - return dbterror.ErrCancelledDDLJob - } - job.State = model.JobStateRunning - return nil -} - -func rollingbackRenameIndex(t *meta.Meta, job *model.Job) (ver int64, err error) { - tblInfo, from, _, err := checkRenameIndex(t, job) - if err != nil { - return ver, errors.Trace(err) - } - // Here rename index is done in a transaction, if the job is not completed, it can be canceled. - idx := tblInfo.FindIndexByName(from.L) - if idx.State == model.StatePublic { - job.State = model.JobStateCancelled - return ver, dbterror.ErrCancelledDDLJob - } - job.State = model.JobStateRunning - return ver, errors.Trace(err) -} - -func cancelOnlyNotHandledJob(job *model.Job, initialState model.SchemaState) (ver int64, err error) { - // We can only cancel the not handled job. - if job.SchemaState == initialState { - job.State = model.JobStateCancelled - return ver, dbterror.ErrCancelledDDLJob - } - - job.State = model.JobStateRunning - - return ver, nil -} - -func rollingbackTruncateTable(t *meta.Meta, job *model.Job) (ver int64, err error) { - _, err = GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) - if err != nil { - return ver, errors.Trace(err) - } - return cancelOnlyNotHandledJob(job, model.StateNone) -} - -func rollingbackReorganizePartition(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { - if job.SchemaState == model.StateNone { - job.State = model.JobStateCancelled - return ver, dbterror.ErrCancelledDDLJob - } - - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) - if err != nil { - return ver, errors.Trace(err) - } - - // addingDefinitions is also in tblInfo, here pass the tblInfo as parameter directly. - // TODO: Test this with reorganize partition p1 into (partition p1 ...)! 
- return convertAddTablePartitionJob2RollbackJob(d, t, job, dbterror.ErrCancelledDDLJob, tblInfo) -} - -func pauseReorgWorkers(w *worker, d *ddlCtx, job *model.Job) (err error) { - if needNotifyAndStopReorgWorker(job) { - w.jobLogger(job).Info("pausing the DDL job", zap.String("job", job.String())) - d.notifyReorgWorkerJobStateChange(job) - } - - return dbterror.ErrPausedDDLJob.GenWithStackByArgs(job.ID) -} - -func convertJob2RollbackJob(w *worker, d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { - switch job.Type { - case model.ActionAddColumn: - ver, err = rollingbackAddColumn(d, t, job) - case model.ActionAddIndex: - ver, err = rollingbackAddIndex(w, d, t, job, false) - case model.ActionAddPrimaryKey: - ver, err = rollingbackAddIndex(w, d, t, job, true) - case model.ActionAddTablePartition: - ver, err = rollingbackAddTablePartition(d, t, job) - case model.ActionReorganizePartition, model.ActionRemovePartitioning, - model.ActionAlterTablePartitioning: - ver, err = rollingbackReorganizePartition(d, t, job) - case model.ActionDropColumn: - ver, err = rollingbackDropColumn(d, t, job) - case model.ActionDropIndex, model.ActionDropPrimaryKey: - ver, err = rollingbackDropIndex(d, t, job) - case model.ActionDropTable, model.ActionDropView, model.ActionDropSequence: - err = rollingbackDropTableOrView(t, job) - case model.ActionDropTablePartition: - ver, err = rollingbackDropTablePartition(t, job) - case model.ActionExchangeTablePartition: - ver, err = rollingbackExchangeTablePartition(d, t, job) - case model.ActionDropSchema: - err = rollingbackDropSchema(t, job) - case model.ActionRenameIndex: - ver, err = rollingbackRenameIndex(t, job) - case model.ActionTruncateTable: - ver, err = rollingbackTruncateTable(t, job) - case model.ActionModifyColumn: - ver, err = rollingbackModifyColumn(w, d, t, job) - case model.ActionDropForeignKey, model.ActionTruncateTablePartition: - ver, err = cancelOnlyNotHandledJob(job, model.StatePublic) - case model.ActionRebaseAutoID, model.ActionShardRowID, model.ActionAddForeignKey, - model.ActionRenameTable, model.ActionRenameTables, - model.ActionModifyTableCharsetAndCollate, - model.ActionModifySchemaCharsetAndCollate, model.ActionRepairTable, - model.ActionModifyTableAutoIdCache, model.ActionAlterIndexVisibility, - model.ActionModifySchemaDefaultPlacement, model.ActionRecoverSchema: - ver, err = cancelOnlyNotHandledJob(job, model.StateNone) - case model.ActionMultiSchemaChange: - err = rollingBackMultiSchemaChange(job) - case model.ActionAddCheckConstraint: - ver, err = rollingBackAddConstraint(d, t, job) - case model.ActionDropCheckConstraint: - ver, err = rollingBackDropConstraint(t, job) - case model.ActionAlterCheckConstraint: - ver, err = rollingBackAlterConstraint(d, t, job) - default: - job.State = model.JobStateCancelled - err = dbterror.ErrCancelledDDLJob - } - - logger := w.jobLogger(job) - if err != nil { - if job.Error == nil { - job.Error = toTError(err) - } - job.ErrorCount++ - - if dbterror.ErrCancelledDDLJob.Equal(err) { - // The job is normally cancelled. - if !job.Error.Equal(dbterror.ErrCancelledDDLJob) { - job.Error = terror.GetErrClass(job.Error).Synthesize(terror.ErrCode(job.Error.Code()), - fmt.Sprintf("DDL job rollback, error msg: %s", terror.ToSQLError(job.Error).Message)) - } - } else { - // A job canceling meet other error. - // - // Once `convertJob2RollbackJob` meets an error, the job state can't be set as `JobStateRollingback` since - // job state and args may not be correctly overwritten. 
The job will be fetched to run with the cancelling - // state again. So we should check the error count here. - if err1 := loadDDLVars(w); err1 != nil { - logger.Error("load DDL global variable failed", zap.Error(err1)) - } - errorCount := variable.GetDDLErrorCountLimit() - if job.ErrorCount > errorCount { - logger.Warn("rollback DDL job error count exceed the limit, cancelled it now", zap.Int64("errorCountLimit", errorCount)) - job.Error = toTError(errors.Errorf("rollback DDL job error count exceed the limit %d, cancelled it now", errorCount)) - job.State = model.JobStateCancelled - } - } - - if !(job.State != model.JobStateRollingback && job.State != model.JobStateCancelled) { - logger.Info("the DDL job is cancelled normally", zap.String("job", job.String()), zap.Error(err)) - // If job is cancelled, we shouldn't return an error. - return ver, nil - } - logger.Error("run DDL job failed", zap.String("job", job.String()), zap.Error(err)) - } - - return -} - -func rollingBackAddConstraint(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { - _, tblInfo, constrInfoInMeta, _, err := checkAddCheckConstraint(t, job) - if err != nil { - return ver, errors.Trace(err) - } - if constrInfoInMeta == nil { - // Add constraint hasn't stored constraint info into meta, so we can cancel the job - // directly without further rollback action. - job.State = model.JobStateCancelled - return ver, dbterror.ErrCancelledDDLJob - } - for i, constr := range tblInfo.Constraints { - if constr.Name.L == constrInfoInMeta.Name.L { - tblInfo.Constraints = append(tblInfo.Constraints[0:i], tblInfo.Constraints[i+1:]...) - break - } - } - if job.IsRollingback() { - job.State = model.JobStateRollbackDone - } - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - return ver, errors.Trace(err) -} - -func rollingBackDropConstraint(t *meta.Meta, job *model.Job) (ver int64, err error) { - _, constrInfoInMeta, err := checkDropCheckConstraint(t, job) - if err != nil { - return ver, errors.Trace(err) - } - - // StatePublic means when the job is not running yet. - if constrInfoInMeta.State == model.StatePublic { - job.State = model.JobStateCancelled - return ver, dbterror.ErrCancelledDDLJob - } - // Can not rollback like drop other element, so just continue to drop constraint. - job.State = model.JobStateRunning - return ver, nil -} - -func rollingBackAlterConstraint(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { - _, tblInfo, constraintInfo, enforced, err := checkAlterCheckConstraint(t, job) - if err != nil { - return ver, errors.Trace(err) - } - - // StatePublic means when the job is not running yet. - if constraintInfo.State == model.StatePublic { - job.State = model.JobStateCancelled - return ver, dbterror.ErrCancelledDDLJob - } - - // Only alter check constraints ENFORCED can get here. - constraintInfo.Enforced = !enforced - constraintInfo.State = model.StatePublic - if job.IsRollingback() { - job.State = model.JobStateRollbackDone - } - ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tblInfo, true) - return ver, errors.Trace(err) -} diff --git a/pkg/ddl/schema_version.go b/pkg/ddl/schema_version.go index 782c6397e124b..96fbdad427d81 100644 --- a/pkg/ddl/schema_version.go +++ b/pkg/ddl/schema_version.go @@ -362,14 +362,14 @@ func updateSchemaVersion(d *ddlCtx, t *meta.Meta, job *model.Job, multiInfos ... 
} func checkAllVersions(ctx context.Context, d *ddlCtx, job *model.Job, latestSchemaVersion int64, timeStart time.Time) error { - if val, _err_ := failpoint.Eval(_curpkg_("checkDownBeforeUpdateGlobalVersion")); _err_ == nil { + failpoint.Inject("checkDownBeforeUpdateGlobalVersion", func(val failpoint.Value) { if val.(bool) { if mockDDLErrOnce > 0 && mockDDLErrOnce != latestSchemaVersion { panic("check down before update global version failed") } mockDDLErrOnce = -1 } - } + }) // OwnerCheckAllVersions returns only when all TiDB schemas are synced(exclude the isolated TiDB). err := d.schemaSyncer.OwnerCheckAllVersions(ctx, job.ID, latestSchemaVersion) @@ -405,14 +405,14 @@ func waitSchemaSynced(ctx context.Context, d *ddlCtx, job *model.Job) error { return err } - if val, _err_ := failpoint.Eval(_curpkg_("checkDownBeforeUpdateGlobalVersion")); _err_ == nil { + failpoint.Inject("checkDownBeforeUpdateGlobalVersion", func(val failpoint.Value) { if val.(bool) { if mockDDLErrOnce > 0 && mockDDLErrOnce != latestSchemaVersion { panic("check down before update global version failed") } mockDDLErrOnce = -1 } - } + }) return waitSchemaChanged(ctx, d, latestSchemaVersion, job) } diff --git a/pkg/ddl/schema_version.go__failpoint_stash__ b/pkg/ddl/schema_version.go__failpoint_stash__ deleted file mode 100644 index 96fbdad427d81..0000000000000 --- a/pkg/ddl/schema_version.go__failpoint_stash__ +++ /dev/null @@ -1,418 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ddl - -import ( - "context" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/ddl/logutil" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/util/mathutil" - "go.uber.org/zap" -) - -// SetSchemaDiffForCreateTables set SchemaDiff for ActionCreateTables. -func SetSchemaDiffForCreateTables(diff *model.SchemaDiff, job *model.Job) error { - var tableInfos []*model.TableInfo - err := job.DecodeArgs(&tableInfos) - if err != nil { - return errors.Trace(err) - } - diff.AffectedOpts = make([]*model.AffectedOption, len(tableInfos)) - for i := range tableInfos { - diff.AffectedOpts[i] = &model.AffectedOption{ - SchemaID: job.SchemaID, - OldSchemaID: job.SchemaID, - TableID: tableInfos[i].ID, - OldTableID: tableInfos[i].ID, - } - } - return nil -} - -// SetSchemaDiffForTruncateTable set SchemaDiff for ActionTruncateTable. -func SetSchemaDiffForTruncateTable(diff *model.SchemaDiff, job *model.Job) error { - // Truncate table has two table ID, should be handled differently. 
- err := job.DecodeArgs(&diff.TableID) - if err != nil { - return errors.Trace(err) - } - diff.OldTableID = job.TableID - - // affects are used to update placement rule cache - if len(job.CtxVars) > 0 { - oldIDs := job.CtxVars[0].([]int64) - newIDs := job.CtxVars[1].([]int64) - diff.AffectedOpts = buildPlacementAffects(oldIDs, newIDs) - } - return nil -} - -// SetSchemaDiffForCreateView set SchemaDiff for ActionCreateView. -func SetSchemaDiffForCreateView(diff *model.SchemaDiff, job *model.Job) error { - tbInfo := &model.TableInfo{} - var orReplace bool - var oldTbInfoID int64 - if err := job.DecodeArgs(tbInfo, &orReplace, &oldTbInfoID); err != nil { - return errors.Trace(err) - } - // When the statement is "create or replace view " and we need to drop the old view, - // it has two table IDs and should be handled differently. - if oldTbInfoID > 0 && orReplace { - diff.OldTableID = oldTbInfoID - } - diff.TableID = tbInfo.ID - return nil -} - -// SetSchemaDiffForRenameTable set SchemaDiff for ActionRenameTable. -func SetSchemaDiffForRenameTable(diff *model.SchemaDiff, job *model.Job) error { - err := job.DecodeArgs(&diff.OldSchemaID) - if err != nil { - return errors.Trace(err) - } - diff.TableID = job.TableID - return nil -} - -// SetSchemaDiffForRenameTables set SchemaDiff for ActionRenameTables. -func SetSchemaDiffForRenameTables(diff *model.SchemaDiff, job *model.Job) error { - var ( - oldSchemaIDs, newSchemaIDs, tableIDs []int64 - tableNames, oldSchemaNames []*model.CIStr - ) - err := job.DecodeArgs(&oldSchemaIDs, &newSchemaIDs, &tableNames, &tableIDs, &oldSchemaNames) - if err != nil { - return errors.Trace(err) - } - affects := make([]*model.AffectedOption, len(newSchemaIDs)-1) - for i, newSchemaID := range newSchemaIDs { - // Do not add the first table to AffectedOpts. Related issue tidb#47064. - if i == 0 { - continue - } - affects[i-1] = &model.AffectedOption{ - SchemaID: newSchemaID, - TableID: tableIDs[i], - OldTableID: tableIDs[i], - OldSchemaID: oldSchemaIDs[i], - } - } - diff.TableID = tableIDs[0] - diff.SchemaID = newSchemaIDs[0] - diff.OldSchemaID = oldSchemaIDs[0] - diff.AffectedOpts = affects - return nil -} - -// SetSchemaDiffForExchangeTablePartition set SchemaDiff for ActionExchangeTablePartition. -func SetSchemaDiffForExchangeTablePartition(diff *model.SchemaDiff, job *model.Job, multiInfos ...schemaIDAndTableInfo) error { - // From start of function: diff.SchemaID = job.SchemaID - // Old is original non partitioned table - diff.OldTableID = job.TableID - diff.OldSchemaID = job.SchemaID - // Update the partitioned table (it is only done in the last state) - var ( - ptSchemaID int64 - ptTableID int64 - ptDefID int64 - partName string // Not used - withValidation bool // Not used - ) - // See ddl.ExchangeTablePartition - err := job.DecodeArgs(&ptDefID, &ptSchemaID, &ptTableID, &partName, &withValidation) - if err != nil { - return errors.Trace(err) - } - // This is needed for not crashing TiFlash! - // TODO: Update TiFlash, to handle StateWriteOnly - diff.AffectedOpts = []*model.AffectedOption{{ - TableID: ptTableID, - }} - if job.SchemaState != model.StatePublic { - // No change, just to refresh the non-partitioned table - // with its new ExchangePartitionInfo. - diff.TableID = job.TableID - // Keep this as Schema ID of non-partitioned table - // to avoid trigger early rename in TiFlash - diff.AffectedOpts[0].SchemaID = job.SchemaID - // Need reload partition table, use diff.AffectedOpts[0].OldSchemaID to mark it. 
- if len(multiInfos) > 0 { - diff.AffectedOpts[0].OldSchemaID = ptSchemaID - } - } else { - // Swap - diff.TableID = ptDefID - // Also add correct SchemaID in case different schemas - diff.AffectedOpts[0].SchemaID = ptSchemaID - } - return nil -} - -// SetSchemaDiffForTruncateTablePartition set SchemaDiff for ActionTruncateTablePartition. -func SetSchemaDiffForTruncateTablePartition(diff *model.SchemaDiff, job *model.Job) { - diff.TableID = job.TableID - if len(job.CtxVars) > 0 { - oldIDs := job.CtxVars[0].([]int64) - newIDs := job.CtxVars[1].([]int64) - diff.AffectedOpts = buildPlacementAffects(oldIDs, newIDs) - } -} - -// SetSchemaDiffForDropTable set SchemaDiff for ActionDropTablePartition, ActionRecoverTable, ActionDropTable. -func SetSchemaDiffForDropTable(diff *model.SchemaDiff, job *model.Job) { - // affects are used to update placement rule cache - diff.TableID = job.TableID - if len(job.CtxVars) > 0 { - if oldIDs, ok := job.CtxVars[0].([]int64); ok { - diff.AffectedOpts = buildPlacementAffects(oldIDs, oldIDs) - } - } -} - -// SetSchemaDiffForReorganizePartition set SchemaDiff for ActionReorganizePartition. -func SetSchemaDiffForReorganizePartition(diff *model.SchemaDiff, job *model.Job) { - diff.TableID = job.TableID - // TODO: should this be for every state of Reorganize? - if len(job.CtxVars) > 0 { - if droppedIDs, ok := job.CtxVars[0].([]int64); ok { - if addedIDs, ok := job.CtxVars[1].([]int64); ok { - // to use AffectedOpts we need both new and old to have the same length - maxParts := mathutil.Max[int](len(droppedIDs), len(addedIDs)) - // Also initialize them to 0! - oldIDs := make([]int64, maxParts) - copy(oldIDs, droppedIDs) - newIDs := make([]int64, maxParts) - copy(newIDs, addedIDs) - diff.AffectedOpts = buildPlacementAffects(oldIDs, newIDs) - } - } - } -} - -// SetSchemaDiffForPartitionModify set SchemaDiff for ActionRemovePartitioning, ActionAlterTablePartitioning. -func SetSchemaDiffForPartitionModify(diff *model.SchemaDiff, job *model.Job) error { - diff.TableID = job.TableID - diff.OldTableID = job.TableID - if job.SchemaState == model.StateDeleteReorganization { - partInfo := &model.PartitionInfo{} - var partNames []string - err := job.DecodeArgs(&partNames, &partInfo) - if err != nil { - return errors.Trace(err) - } - // Final part, new table id is assigned - diff.TableID = partInfo.NewTableID - if len(job.CtxVars) > 0 { - if droppedIDs, ok := job.CtxVars[0].([]int64); ok { - if addedIDs, ok := job.CtxVars[1].([]int64); ok { - // to use AffectedOpts we need both new and old to have the same length - maxParts := mathutil.Max[int](len(droppedIDs), len(addedIDs)) - // Also initialize them to 0! - oldIDs := make([]int64, maxParts) - copy(oldIDs, droppedIDs) - newIDs := make([]int64, maxParts) - copy(newIDs, addedIDs) - diff.AffectedOpts = buildPlacementAffects(oldIDs, newIDs) - } - } - } - } - return nil -} - -// SetSchemaDiffForCreateTable set SchemaDiff for ActionCreateTable. -func SetSchemaDiffForCreateTable(diff *model.SchemaDiff, job *model.Job) { - diff.TableID = job.TableID - if len(job.Args) > 0 { - tbInfo, _ := job.Args[0].(*model.TableInfo) - // When create table with foreign key, there are two schema status change: - // 1. none -> write-only - // 2. write-only -> public - // In the second status change write-only -> public, infoschema loader should apply drop old table first, then - // apply create new table. So need to set diff.OldTableID here to make sure it. 
- if tbInfo != nil && tbInfo.State == model.StatePublic && len(tbInfo.ForeignKeys) > 0 { - diff.OldTableID = job.TableID - } - } -} - -// SetSchemaDiffForRecoverSchema set SchemaDiff for ActionRecoverSchema. -func SetSchemaDiffForRecoverSchema(diff *model.SchemaDiff, job *model.Job) error { - var ( - recoverSchemaInfo *RecoverSchemaInfo - recoverSchemaCheckFlag int64 - ) - err := job.DecodeArgs(&recoverSchemaInfo, &recoverSchemaCheckFlag) - if err != nil { - return errors.Trace(err) - } - // Reserved recoverSchemaCheckFlag value for gc work judgment. - job.Args[checkFlagIndexInJobArgs] = recoverSchemaCheckFlag - recoverTabsInfo := recoverSchemaInfo.RecoverTabsInfo - diff.AffectedOpts = make([]*model.AffectedOption, len(recoverTabsInfo)) - for i := range recoverTabsInfo { - diff.AffectedOpts[i] = &model.AffectedOption{ - SchemaID: job.SchemaID, - OldSchemaID: job.SchemaID, - TableID: recoverTabsInfo[i].TableInfo.ID, - OldTableID: recoverTabsInfo[i].TableInfo.ID, - } - } - diff.ReadTableFromMeta = true - return nil -} - -// SetSchemaDiffForFlashbackCluster set SchemaDiff for ActionFlashbackCluster. -func SetSchemaDiffForFlashbackCluster(diff *model.SchemaDiff, job *model.Job) { - diff.TableID = -1 - if job.SchemaState == model.StatePublic { - diff.RegenerateSchemaMap = true - } -} - -// SetSchemaDiffForMultiInfos set SchemaDiff for multiInfos. -func SetSchemaDiffForMultiInfos(diff *model.SchemaDiff, multiInfos ...schemaIDAndTableInfo) { - if len(multiInfos) > 0 { - existsMap := make(map[int64]struct{}) - existsMap[diff.TableID] = struct{}{} - for _, affect := range diff.AffectedOpts { - existsMap[affect.TableID] = struct{}{} - } - for _, info := range multiInfos { - _, exist := existsMap[info.tblInfo.ID] - if exist { - continue - } - existsMap[info.tblInfo.ID] = struct{}{} - diff.AffectedOpts = append(diff.AffectedOpts, &model.AffectedOption{ - SchemaID: info.schemaID, - OldSchemaID: info.schemaID, - TableID: info.tblInfo.ID, - OldTableID: info.tblInfo.ID, - }) - } - } -} - -// updateSchemaVersion increments the schema version by 1 and sets SchemaDiff. -func updateSchemaVersion(d *ddlCtx, t *meta.Meta, job *model.Job, multiInfos ...schemaIDAndTableInfo) (int64, error) { - schemaVersion, err := d.setSchemaVersion(job, d.store) - if err != nil { - return 0, errors.Trace(err) - } - diff := &model.SchemaDiff{ - Version: schemaVersion, - Type: job.Type, - SchemaID: job.SchemaID, - } - switch job.Type { - case model.ActionCreateTables: - err = SetSchemaDiffForCreateTables(diff, job) - case model.ActionTruncateTable: - err = SetSchemaDiffForTruncateTable(diff, job) - case model.ActionCreateView: - err = SetSchemaDiffForCreateView(diff, job) - case model.ActionRenameTable: - err = SetSchemaDiffForRenameTable(diff, job) - case model.ActionRenameTables: - err = SetSchemaDiffForRenameTables(diff, job) - case model.ActionExchangeTablePartition: - err = SetSchemaDiffForExchangeTablePartition(diff, job, multiInfos...) 
- case model.ActionTruncateTablePartition: - SetSchemaDiffForTruncateTablePartition(diff, job) - case model.ActionDropTablePartition, model.ActionRecoverTable, model.ActionDropTable: - SetSchemaDiffForDropTable(diff, job) - case model.ActionReorganizePartition: - SetSchemaDiffForReorganizePartition(diff, job) - case model.ActionRemovePartitioning, model.ActionAlterTablePartitioning: - err = SetSchemaDiffForPartitionModify(diff, job) - case model.ActionCreateTable: - SetSchemaDiffForCreateTable(diff, job) - case model.ActionRecoverSchema: - err = SetSchemaDiffForRecoverSchema(diff, job) - case model.ActionFlashbackCluster: - SetSchemaDiffForFlashbackCluster(diff, job) - default: - diff.TableID = job.TableID - } - if err != nil { - return 0, err - } - SetSchemaDiffForMultiInfos(diff, multiInfos...) - err = t.SetSchemaDiff(diff) - return schemaVersion, errors.Trace(err) -} - -func checkAllVersions(ctx context.Context, d *ddlCtx, job *model.Job, latestSchemaVersion int64, timeStart time.Time) error { - failpoint.Inject("checkDownBeforeUpdateGlobalVersion", func(val failpoint.Value) { - if val.(bool) { - if mockDDLErrOnce > 0 && mockDDLErrOnce != latestSchemaVersion { - panic("check down before update global version failed") - } - mockDDLErrOnce = -1 - } - }) - - // OwnerCheckAllVersions returns only when all TiDB schemas are synced(exclude the isolated TiDB). - err := d.schemaSyncer.OwnerCheckAllVersions(ctx, job.ID, latestSchemaVersion) - if err != nil { - logutil.DDLLogger().Info("wait latest schema version encounter error", zap.Int64("ver", latestSchemaVersion), - zap.Int64("jobID", job.ID), zap.Duration("take time", time.Since(timeStart)), zap.Error(err)) - return err - } - logutil.DDLLogger().Info("wait latest schema version changed(get the metadata lock if tidb_enable_metadata_lock is true)", - zap.Int64("ver", latestSchemaVersion), - zap.Duration("take time", time.Since(timeStart)), - zap.String("job", job.String())) - return nil -} - -// waitSchemaSynced handles the following situation: -// If the job enters a new state, and the worker crash when it's in the process of -// version sync, then the worker restarts quickly, we may run the job immediately again, -// but schema version might not sync. -// So here we get the latest schema version to make sure all servers' schema version -// update to the latest schema version in a cluster. 
-func waitSchemaSynced(ctx context.Context, d *ddlCtx, job *model.Job) error { - if !job.IsRunning() && !job.IsRollingback() && !job.IsDone() && !job.IsRollbackDone() { - return nil - } - - ver, _ := d.store.CurrentVersion(kv.GlobalTxnScope) - snapshot := d.store.GetSnapshot(ver) - m := meta.NewSnapshotMeta(snapshot) - latestSchemaVersion, err := m.GetSchemaVersionWithNonEmptyDiff() - if err != nil { - logutil.DDLLogger().Warn("get global version failed", zap.Int64("jobID", job.ID), zap.Error(err)) - return err - } - - failpoint.Inject("checkDownBeforeUpdateGlobalVersion", func(val failpoint.Value) { - if val.(bool) { - if mockDDLErrOnce > 0 && mockDDLErrOnce != latestSchemaVersion { - panic("check down before update global version failed") - } - mockDDLErrOnce = -1 - } - }) - - return waitSchemaChanged(ctx, d, latestSchemaVersion, job) -} diff --git a/pkg/ddl/session/binding__failpoint_binding__.go b/pkg/ddl/session/binding__failpoint_binding__.go deleted file mode 100644 index 9ef59b452261c..0000000000000 --- a/pkg/ddl/session/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package session - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/ddl/session/session.go b/pkg/ddl/session/session.go index 17ebdf16bae53..bc16de5b0c484 100644 --- a/pkg/ddl/session/session.go +++ b/pkg/ddl/session/session.go @@ -109,7 +109,7 @@ func (s *Session) RunInTxn(f func(*Session) error) (err error) { if err != nil { return err } - if val, _err_ := failpoint.Eval(_curpkg_("NotifyBeginTxnCh")); _err_ == nil { + failpoint.Inject("NotifyBeginTxnCh", func(val failpoint.Value) { //nolint:forcetypeassert v := val.(int) if v == 1 { @@ -119,7 +119,7 @@ func (s *Session) RunInTxn(f func(*Session) error) (err error) { <-TestNotifyBeginTxnCh MockDDLOnce = 0 } - } + }) err = f(s) if err != nil { diff --git a/pkg/ddl/session/session.go__failpoint_stash__ b/pkg/ddl/session/session.go__failpoint_stash__ deleted file mode 100644 index bc16de5b0c484..0000000000000 --- a/pkg/ddl/session/session.go__failpoint_stash__ +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package session - -import ( - "context" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessiontxn" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/sqlexec" -) - -// Session wraps sessionctx.Context for transaction usage. -type Session struct { - sessionctx.Context -} - -// NewSession creates a new Session. 
-func NewSession(s sessionctx.Context) *Session { - return &Session{s} -} - -// Begin starts a transaction. -func (s *Session) Begin(ctx context.Context) error { - err := sessiontxn.NewTxn(ctx, s.Context) - if err != nil { - return err - } - s.GetSessionVars().SetInTxn(true) - return nil -} - -// Commit commits the transaction. -func (s *Session) Commit(ctx context.Context) error { - s.StmtCommit(ctx) - return s.CommitTxn(ctx) -} - -// Txn activate and returns the current transaction. -func (s *Session) Txn() (kv.Transaction, error) { - return s.Context.Txn(true) -} - -// Rollback aborts the transaction. -func (s *Session) Rollback() { - s.StmtRollback(context.Background(), false) - s.RollbackTxn(context.Background()) -} - -// Reset resets the session. -func (s *Session) Reset() { - s.StmtRollback(context.Background(), false) -} - -// Execute executes a query. -func (s *Session) Execute(ctx context.Context, query string, label string) ([]chunk.Row, error) { - startTime := time.Now() - var err error - defer func() { - metrics.DDLJobTableDuration.WithLabelValues(label + "-" + metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) - }() - - if ctx.Value(kv.RequestSourceKey) == nil { - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnDDL) - } - rs, err := s.Context.GetSQLExecutor().ExecuteInternal(ctx, query) - if err != nil { - return nil, errors.Trace(err) - } - - if rs == nil { - return nil, nil - } - var rows []chunk.Row - defer terror.Call(rs.Close) - if rows, err = sqlexec.DrainRecordSet(ctx, rs, 8); err != nil { - return nil, errors.Trace(err) - } - return rows, nil -} - -// Session returns the sessionctx.Context. -func (s *Session) Session() sessionctx.Context { - return s.Context -} - -// RunInTxn runs a function in a transaction. -func (s *Session) RunInTxn(f func(*Session) error) (err error) { - err = s.Begin(context.Background()) - if err != nil { - return err - } - failpoint.Inject("NotifyBeginTxnCh", func(val failpoint.Value) { - //nolint:forcetypeassert - v := val.(int) - if v == 1 { - MockDDLOnce = 1 - TestNotifyBeginTxnCh <- struct{}{} - } else if v == 2 && MockDDLOnce == 1 { - <-TestNotifyBeginTxnCh - MockDDLOnce = 0 - } - }) - - err = f(s) - if err != nil { - s.Rollback() - return - } - return errors.Trace(s.Commit(context.Background())) -} - -var ( - // MockDDLOnce is only used for test. - MockDDLOnce = int64(0) - // TestNotifyBeginTxnCh is used for if the txn is beginning in RunInTxn. - TestNotifyBeginTxnCh = make(chan struct{}) -) diff --git a/pkg/ddl/syncer/binding__failpoint_binding__.go b/pkg/ddl/syncer/binding__failpoint_binding__.go deleted file mode 100644 index 3db8a06874be4..0000000000000 --- a/pkg/ddl/syncer/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package syncer - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/ddl/syncer/syncer.go b/pkg/ddl/syncer/syncer.go index d9e34b6fdf5c4..818d4747eed06 100644 --- a/pkg/ddl/syncer/syncer.go +++ b/pkg/ddl/syncer/syncer.go @@ -273,12 +273,12 @@ func (s *schemaVersionSyncer) storeSession(session *concurrency.Session) { // Done implements SchemaSyncer.Done interface. 
func (s *schemaVersionSyncer) Done() <-chan struct{} { - if val, _err_ := failpoint.Eval(_curpkg_("ErrorMockSessionDone")); _err_ == nil { + failpoint.Inject("ErrorMockSessionDone", func(val failpoint.Value) { if val.(bool) { err := s.loadSession().Close() logutil.DDLLogger().Error("close session failed", zap.Error(err)) } - } + }) return s.loadSession().Done() } @@ -530,9 +530,9 @@ func (s *schemaVersionSyncer) syncJobSchemaVer(ctx context.Context) { return } } - if _, _err_ := failpoint.Eval(_curpkg_("mockCompaction")); _err_ == nil { + failpoint.Inject("mockCompaction", func() { wresp.CompactRevision = 123 - } + }) if err := wresp.Err(); err != nil { logutil.DDLLogger().Warn("watch job version failed", zap.Error(err)) return diff --git a/pkg/ddl/syncer/syncer.go__failpoint_stash__ b/pkg/ddl/syncer/syncer.go__failpoint_stash__ deleted file mode 100644 index 818d4747eed06..0000000000000 --- a/pkg/ddl/syncer/syncer.go__failpoint_stash__ +++ /dev/null @@ -1,629 +0,0 @@ -// Copyright 2017 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package syncer - -import ( - "context" - "fmt" - "math" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - "unsafe" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/ddl/logutil" - "github.com/pingcap/tidb/pkg/ddl/util" - "github.com/pingcap/tidb/pkg/domain/infosync" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - tidbutil "github.com/pingcap/tidb/pkg/util" - disttaskutil "github.com/pingcap/tidb/pkg/util/disttask" - "go.etcd.io/etcd/api/v3/mvccpb" - clientv3 "go.etcd.io/etcd/client/v3" - "go.etcd.io/etcd/client/v3/concurrency" - "go.uber.org/zap" -) - -const ( - // InitialVersion is the initial schema version for every server. - // It's exported for testing. - InitialVersion = "0" - putKeyNoRetry = 1 - keyOpDefaultRetryCnt = 3 - putKeyRetryUnlimited = math.MaxInt64 - checkVersInterval = 20 * time.Millisecond - ddlPrompt = "ddl-syncer" -) - -var ( - // CheckVersFirstWaitTime is a waitting time before the owner checks all the servers of the schema version, - // and it's an exported variable for testing. - CheckVersFirstWaitTime = 50 * time.Millisecond -) - -// Watcher is responsible for watching the etcd path related operations. -type Watcher interface { - // WatchChan returns the chan for watching etcd path. - WatchChan() clientv3.WatchChan - // Watch watches the etcd path. - Watch(ctx context.Context, etcdCli *clientv3.Client, path string) - // Rewatch rewatches the etcd path. - Rewatch(ctx context.Context, etcdCli *clientv3.Client, path string) -} - -type watcher struct { - sync.RWMutex - wCh clientv3.WatchChan -} - -// WatchChan implements SyncerWatch.WatchChan interface. -func (w *watcher) WatchChan() clientv3.WatchChan { - w.RLock() - defer w.RUnlock() - return w.wCh -} - -// Watch implements SyncerWatch.Watch interface. 
-func (w *watcher) Watch(ctx context.Context, etcdCli *clientv3.Client, path string) { - w.Lock() - w.wCh = etcdCli.Watch(ctx, path) - w.Unlock() -} - -// Rewatch implements SyncerWatch.Rewatch interface. -func (w *watcher) Rewatch(ctx context.Context, etcdCli *clientv3.Client, path string) { - startTime := time.Now() - // Make sure the wCh doesn't receive the information of 'close' before we finish the rewatch. - w.Lock() - w.wCh = nil - w.Unlock() - - go func() { - defer func() { - metrics.DeploySyncerHistogram.WithLabelValues(metrics.SyncerRewatch, metrics.RetLabel(nil)).Observe(time.Since(startTime).Seconds()) - }() - wCh := etcdCli.Watch(ctx, path) - - w.Lock() - w.wCh = wCh - w.Unlock() - logutil.DDLLogger().Info("syncer rewatch global info finished") - }() -} - -// SchemaSyncer is used to synchronize schema version between the DDL worker leader and followers through etcd. -type SchemaSyncer interface { - // Init sets the global schema version path to etcd if it isn't exist, - // then watch this path, and initializes the self schema version to etcd. - Init(ctx context.Context) error - // UpdateSelfVersion updates the current version to the self path on etcd. - UpdateSelfVersion(ctx context.Context, jobID int64, version int64) error - // OwnerUpdateGlobalVersion updates the latest version to the global path on etcd until updating is successful or the ctx is done. - OwnerUpdateGlobalVersion(ctx context.Context, version int64) error - // GlobalVersionCh gets the chan for watching global version. - GlobalVersionCh() clientv3.WatchChan - // WatchGlobalSchemaVer watches the global schema version. - WatchGlobalSchemaVer(ctx context.Context) - // Done returns a channel that closes when the syncer is no longer being refreshed. - Done() <-chan struct{} - // Restart restarts the syncer when it's on longer being refreshed. - Restart(ctx context.Context) error - // OwnerCheckAllVersions checks whether all followers' schema version are equal to - // the latest schema version. (exclude the isolated TiDB) - // It returns until all servers' versions are equal to the latest version. - OwnerCheckAllVersions(ctx context.Context, jobID int64, latestVer int64) error - // SyncJobSchemaVerLoop syncs the schema versions on all TiDB nodes for DDL jobs. - SyncJobSchemaVerLoop(ctx context.Context) - // Close ends SchemaSyncer. - Close() -} - -// nodeVersions is used to record the schema versions of all TiDB nodes for a DDL job. -type nodeVersions struct { - sync.Mutex - nodeVersions map[string]int64 - // onceMatchFn is used to check if all the servers report the least version. - // If all the servers report the least version, i.e. return true, it will be - // set to nil. - onceMatchFn func(map[string]int64) bool -} - -func newNodeVersions(initialCap int, fn func(map[string]int64) bool) *nodeVersions { - return &nodeVersions{ - nodeVersions: make(map[string]int64, initialCap), - onceMatchFn: fn, - } -} - -func (v *nodeVersions) add(nodeID string, ver int64) { - v.Lock() - defer v.Unlock() - v.nodeVersions[nodeID] = ver - if v.onceMatchFn != nil { - if ok := v.onceMatchFn(v.nodeVersions); ok { - v.onceMatchFn = nil - } - } -} - -func (v *nodeVersions) del(nodeID string) { - v.Lock() - defer v.Unlock() - delete(v.nodeVersions, nodeID) - // we don't call onceMatchFn here, for only "add" can cause onceMatchFn return - // true currently. -} - -func (v *nodeVersions) len() int { - v.Lock() - defer v.Unlock() - return len(v.nodeVersions) -} - -// matchOrSet onceMatchFn must be nil before calling this method. 
-func (v *nodeVersions) matchOrSet(fn func(nodeVersions map[string]int64) bool) { - v.Lock() - defer v.Unlock() - if ok := fn(v.nodeVersions); !ok { - v.onceMatchFn = fn - } -} - -func (v *nodeVersions) clearData() { - v.Lock() - defer v.Unlock() - v.nodeVersions = make(map[string]int64, len(v.nodeVersions)) -} - -func (v *nodeVersions) clearMatchFn() { - v.Lock() - defer v.Unlock() - v.onceMatchFn = nil -} - -func (v *nodeVersions) emptyAndNotUsed() bool { - v.Lock() - defer v.Unlock() - return len(v.nodeVersions) == 0 && v.onceMatchFn == nil -} - -// for test -func (v *nodeVersions) getMatchFn() func(map[string]int64) bool { - v.Lock() - defer v.Unlock() - return v.onceMatchFn -} - -type schemaVersionSyncer struct { - selfSchemaVerPath string - etcdCli *clientv3.Client - session unsafe.Pointer - globalVerWatcher watcher - ddlID string - - mu sync.RWMutex - jobNodeVersions map[int64]*nodeVersions - jobNodeVerPrefix string -} - -// NewSchemaSyncer creates a new SchemaSyncer. -func NewSchemaSyncer(etcdCli *clientv3.Client, id string) SchemaSyncer { - return &schemaVersionSyncer{ - etcdCli: etcdCli, - selfSchemaVerPath: fmt.Sprintf("%s/%s", util.DDLAllSchemaVersions, id), - ddlID: id, - - jobNodeVersions: make(map[int64]*nodeVersions), - jobNodeVerPrefix: util.DDLAllSchemaVersionsByJob + "/", - } -} - -// Init implements SchemaSyncer.Init interface. -func (s *schemaVersionSyncer) Init(ctx context.Context) error { - startTime := time.Now() - var err error - defer func() { - metrics.DeploySyncerHistogram.WithLabelValues(metrics.SyncerInit, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) - }() - - _, err = s.etcdCli.Txn(ctx). - If(clientv3.Compare(clientv3.CreateRevision(util.DDLGlobalSchemaVersion), "=", 0)). - Then(clientv3.OpPut(util.DDLGlobalSchemaVersion, InitialVersion)). - Commit() - if err != nil { - return errors.Trace(err) - } - logPrefix := fmt.Sprintf("[%s] %s", ddlPrompt, s.selfSchemaVerPath) - session, err := tidbutil.NewSession(ctx, logPrefix, s.etcdCli, tidbutil.NewSessionDefaultRetryCnt, util.SessionTTL) - if err != nil { - return errors.Trace(err) - } - s.storeSession(session) - - s.globalVerWatcher.Watch(ctx, s.etcdCli, util.DDLGlobalSchemaVersion) - - err = util.PutKVToEtcd(ctx, s.etcdCli, keyOpDefaultRetryCnt, s.selfSchemaVerPath, InitialVersion, - clientv3.WithLease(s.loadSession().Lease())) - return errors.Trace(err) -} - -func (s *schemaVersionSyncer) loadSession() *concurrency.Session { - return (*concurrency.Session)(atomic.LoadPointer(&s.session)) -} - -func (s *schemaVersionSyncer) storeSession(session *concurrency.Session) { - atomic.StorePointer(&s.session, (unsafe.Pointer)(session)) -} - -// Done implements SchemaSyncer.Done interface. -func (s *schemaVersionSyncer) Done() <-chan struct{} { - failpoint.Inject("ErrorMockSessionDone", func(val failpoint.Value) { - if val.(bool) { - err := s.loadSession().Close() - logutil.DDLLogger().Error("close session failed", zap.Error(err)) - } - }) - - return s.loadSession().Done() -} - -// Restart implements SchemaSyncer.Restart interface. -func (s *schemaVersionSyncer) Restart(ctx context.Context) error { - startTime := time.Now() - var err error - defer func() { - metrics.DeploySyncerHistogram.WithLabelValues(metrics.SyncerRestart, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) - }() - - logPrefix := fmt.Sprintf("[%s] %s", ddlPrompt, s.selfSchemaVerPath) - // NewSession's context will affect the exit of the session. 
- session, err := tidbutil.NewSession(ctx, logPrefix, s.etcdCli, tidbutil.NewSessionRetryUnlimited, util.SessionTTL) - if err != nil { - return errors.Trace(err) - } - s.storeSession(session) - - childCtx, cancel := context.WithTimeout(ctx, util.KeyOpDefaultTimeout) - defer cancel() - err = util.PutKVToEtcd(childCtx, s.etcdCli, putKeyRetryUnlimited, s.selfSchemaVerPath, InitialVersion, - clientv3.WithLease(s.loadSession().Lease())) - - return errors.Trace(err) -} - -// GlobalVersionCh implements SchemaSyncer.GlobalVersionCh interface. -func (s *schemaVersionSyncer) GlobalVersionCh() clientv3.WatchChan { - return s.globalVerWatcher.WatchChan() -} - -// WatchGlobalSchemaVer implements SchemaSyncer.WatchGlobalSchemaVer interface. -func (s *schemaVersionSyncer) WatchGlobalSchemaVer(ctx context.Context) { - s.globalVerWatcher.Rewatch(ctx, s.etcdCli, util.DDLGlobalSchemaVersion) -} - -// UpdateSelfVersion implements SchemaSyncer.UpdateSelfVersion interface. -func (s *schemaVersionSyncer) UpdateSelfVersion(ctx context.Context, jobID int64, version int64) error { - startTime := time.Now() - ver := strconv.FormatInt(version, 10) - var err error - var path string - if variable.EnableMDL.Load() { - path = fmt.Sprintf("%s/%d/%s", util.DDLAllSchemaVersionsByJob, jobID, s.ddlID) - err = util.PutKVToEtcdMono(ctx, s.etcdCli, keyOpDefaultRetryCnt, path, ver) - } else { - path = s.selfSchemaVerPath - err = util.PutKVToEtcd(ctx, s.etcdCli, putKeyNoRetry, path, ver, - clientv3.WithLease(s.loadSession().Lease())) - } - - metrics.UpdateSelfVersionHistogram.WithLabelValues(metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) - return errors.Trace(err) -} - -// OwnerUpdateGlobalVersion implements SchemaSyncer.OwnerUpdateGlobalVersion interface. -func (s *schemaVersionSyncer) OwnerUpdateGlobalVersion(ctx context.Context, version int64) error { - startTime := time.Now() - ver := strconv.FormatInt(version, 10) - // TODO: If the version is larger than the original global version, we need set the version. - // Otherwise, we'd better set the original global version. - err := util.PutKVToEtcd(ctx, s.etcdCli, putKeyRetryUnlimited, util.DDLGlobalSchemaVersion, ver) - metrics.OwnerHandleSyncerHistogram.WithLabelValues(metrics.OwnerUpdateGlobalVersion, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) - return errors.Trace(err) -} - -// removeSelfVersionPath remove the self path from etcd. -func (s *schemaVersionSyncer) removeSelfVersionPath() error { - startTime := time.Now() - var err error - defer func() { - metrics.DeploySyncerHistogram.WithLabelValues(metrics.SyncerClear, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) - }() - - err = util.DeleteKeyFromEtcd(s.selfSchemaVerPath, s.etcdCli, keyOpDefaultRetryCnt, util.KeyOpDefaultTimeout) - return errors.Trace(err) -} - -// OwnerCheckAllVersions implements SchemaSyncer.OwnerCheckAllVersions interface. -func (s *schemaVersionSyncer) OwnerCheckAllVersions(ctx context.Context, jobID int64, latestVer int64) error { - startTime := time.Now() - if !variable.EnableMDL.Load() { - time.Sleep(CheckVersFirstWaitTime) - } - notMatchVerCnt := 0 - intervalCnt := int(time.Second / checkVersInterval) - - var err error - defer func() { - metrics.OwnerHandleSyncerHistogram.WithLabelValues(metrics.OwnerCheckAllVersions, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) - }() - - // If MDL is disabled, updatedMap is a cache. We need to ensure all the keys equal to the least version. 
- // We can skip checking the key if it is checked in the cache(set by the previous loop). - // If MDL is enabled, updatedMap is used to check if all the servers report the least version. - // updatedMap is initialed to record all the server in every loop. We delete a server from the map if it gets the metadata lock(the key version equal the given version. - // updatedMap should be empty if all the servers get the metadata lock. - updatedMap := make(map[string]string) - for { - if err := ctx.Err(); err != nil { - // ctx is canceled or timeout. - return errors.Trace(err) - } - - if variable.EnableMDL.Load() { - serverInfos, err := infosync.GetAllServerInfo(ctx) - if err != nil { - return err - } - updatedMap = make(map[string]string) - instance2id := make(map[string]string) - - for _, info := range serverInfos { - instance := disttaskutil.GenerateExecID(info) - // if some node shutdown abnormally and start, we might see some - // instance with different id, we should use the latest one. - if id, ok := instance2id[instance]; ok { - if info.StartTimestamp > serverInfos[id].StartTimestamp { - // Replace it. - delete(updatedMap, id) - updatedMap[info.ID] = fmt.Sprintf("instance ip %s, port %d, id %s", info.IP, info.Port, info.ID) - instance2id[instance] = info.ID - } - } else { - updatedMap[info.ID] = fmt.Sprintf("instance ip %s, port %d, id %s", info.IP, info.Port, info.ID) - instance2id[instance] = info.ID - } - } - } - - // Check all schema versions. - if variable.EnableMDL.Load() { - notifyCh := make(chan struct{}) - var unmatchedNodeID atomic.Pointer[string] - matchFn := func(nodeVersions map[string]int64) bool { - if len(nodeVersions) < len(updatedMap) { - return false - } - for tidbID := range updatedMap { - if nodeVer, ok := nodeVersions[tidbID]; !ok || nodeVer < latestVer { - id := tidbID - unmatchedNodeID.Store(&id) - return false - } - } - close(notifyCh) - return true - } - item := s.jobSchemaVerMatchOrSet(jobID, matchFn) - select { - case <-notifyCh: - return nil - case <-ctx.Done(): - item.clearMatchFn() - return errors.Trace(ctx.Err()) - case <-time.After(time.Second): - item.clearMatchFn() - if id := unmatchedNodeID.Load(); id != nil { - logutil.DDLLogger().Info("syncer check all versions, someone is not synced", - zap.String("info", *id), - zap.Int64("ddl job id", jobID), - zap.Int64("ver", latestVer)) - } else { - logutil.DDLLogger().Info("syncer check all versions, all nodes are not synced", - zap.Int64("ddl job id", jobID), - zap.Int64("ver", latestVer)) - } - } - } else { - // Get all the schema versions from ETCD. - resp, err := s.etcdCli.Get(ctx, util.DDLAllSchemaVersions, clientv3.WithPrefix()) - if err != nil { - logutil.DDLLogger().Info("syncer check all versions failed, continue checking.", zap.Error(err)) - continue - } - succ := true - for _, kv := range resp.Kvs { - if _, ok := updatedMap[string(kv.Key)]; ok { - continue - } - - succ = isUpdatedLatestVersion(string(kv.Key), string(kv.Value), latestVer, notMatchVerCnt, intervalCnt, true) - if !succ { - break - } - updatedMap[string(kv.Key)] = "" - } - - if succ { - return nil - } - time.Sleep(checkVersInterval) - notMatchVerCnt++ - } - } -} - -// SyncJobSchemaVerLoop implements SchemaSyncer.SyncJobSchemaVerLoop interface. 
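When the metadata lock is enabled, OwnerCheckAllVersions above does not poll etcd; it parks a predicate through jobSchemaVerMatchOrSet and blocks on a channel until every live TiDB node has reported a schema version of at least latestVer, falling back to a log-and-retry path after one second. Below is a stripped-down sketch of that park-a-predicate, notify-once idiom; the names are illustrative, not the real types:

package example

import (
	"sync"
	"time"
)

// verTracker is a stand-in for the nodeVersions bookkeeping: nodes report
// schema versions, and a waiter can park a predicate that is re-evaluated on
// every report and fires at most once.
type verTracker struct {
	mu      sync.Mutex
	vers    map[string]int64
	matchFn func(map[string]int64) bool
}

func newVerTracker() *verTracker { return &verTracker{vers: make(map[string]int64)} }

// report records one node's version and fires the parked predicate if it now holds.
func (t *verTracker) report(node string, ver int64) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.vers[node] = ver
	if t.matchFn != nil && t.matchFn(t.vers) {
		t.matchFn = nil // fired once, clear it
	}
}

// waitAll blocks until every node in want reports at least minVer, or timeout expires.
func (t *verTracker) waitAll(want []string, minVer int64, timeout time.Duration) bool {
	done := make(chan struct{})
	pred := func(vers map[string]int64) bool {
		for _, n := range want {
			if vers[n] < minVer {
				return false
			}
		}
		close(done)
		return true
	}
	t.mu.Lock()
	if !pred(t.vers) { // not satisfied yet: park it for future report() calls
		t.matchFn = pred
	}
	t.mu.Unlock()
	select {
	case <-done:
		return true
	case <-time.After(timeout):
		t.mu.Lock()
		t.matchFn = nil // mirror clearMatchFn: stop evaluating a stale predicate
		t.mu.Unlock()
		return false
	}
}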
-func (s *schemaVersionSyncer) SyncJobSchemaVerLoop(ctx context.Context) { - for { - s.syncJobSchemaVer(ctx) - logutil.DDLLogger().Info("schema version sync loop interrupted, retrying...") - select { - case <-ctx.Done(): - return - case <-time.After(time.Second): - } - } -} - -func (s *schemaVersionSyncer) syncJobSchemaVer(ctx context.Context) { - resp, err := s.etcdCli.Get(ctx, s.jobNodeVerPrefix, clientv3.WithPrefix()) - if err != nil { - logutil.DDLLogger().Info("get all job versions failed", zap.Error(err)) - return - } - s.mu.Lock() - for jobID, item := range s.jobNodeVersions { - item.clearData() - // we might miss some DELETE events during retry, some items might be emptyAndNotUsed, remove them. - if item.emptyAndNotUsed() { - delete(s.jobNodeVersions, jobID) - } - } - s.mu.Unlock() - for _, oneKV := range resp.Kvs { - s.handleJobSchemaVerKV(oneKV, mvccpb.PUT) - } - - startRev := resp.Header.Revision + 1 - watchCtx, watchCtxCancel := context.WithCancel(ctx) - defer watchCtxCancel() - watchCtx = clientv3.WithRequireLeader(watchCtx) - watchCh := s.etcdCli.Watch(watchCtx, s.jobNodeVerPrefix, clientv3.WithPrefix(), clientv3.WithRev(startRev)) - for { - var ( - wresp clientv3.WatchResponse - ok bool - ) - select { - case <-watchCtx.Done(): - return - case wresp, ok = <-watchCh: - if !ok { - // ctx must be cancelled, else we should have received a response - // with err and caught by below err check. - return - } - } - failpoint.Inject("mockCompaction", func() { - wresp.CompactRevision = 123 - }) - if err := wresp.Err(); err != nil { - logutil.DDLLogger().Warn("watch job version failed", zap.Error(err)) - return - } - for _, ev := range wresp.Events { - s.handleJobSchemaVerKV(ev.Kv, ev.Type) - } - } -} - -func (s *schemaVersionSyncer) handleJobSchemaVerKV(kv *mvccpb.KeyValue, tp mvccpb.Event_EventType) { - jobID, tidbID, schemaVer, valid := decodeJobVersionEvent(kv, tp, s.jobNodeVerPrefix) - if !valid { - logutil.DDLLogger().Error("invalid job version kv", zap.Stringer("kv", kv), zap.Stringer("type", tp)) - return - } - if tp == mvccpb.PUT { - s.mu.Lock() - item, exists := s.jobNodeVersions[jobID] - if !exists { - item = newNodeVersions(1, nil) - s.jobNodeVersions[jobID] = item - } - s.mu.Unlock() - item.add(tidbID, schemaVer) - } else { // DELETE - s.mu.Lock() - if item, exists := s.jobNodeVersions[jobID]; exists { - item.del(tidbID) - if item.len() == 0 { - delete(s.jobNodeVersions, jobID) - } - } - s.mu.Unlock() - } -} - -func (s *schemaVersionSyncer) jobSchemaVerMatchOrSet(jobID int64, matchFn func(map[string]int64) bool) *nodeVersions { - s.mu.Lock() - defer s.mu.Unlock() - - item, exists := s.jobNodeVersions[jobID] - if exists { - item.matchOrSet(matchFn) - } else { - item = newNodeVersions(1, matchFn) - s.jobNodeVersions[jobID] = item - } - return item -} - -func decodeJobVersionEvent(kv *mvccpb.KeyValue, tp mvccpb.Event_EventType, prefix string) (jobID int64, tidbID string, schemaVer int64, valid bool) { - left := strings.TrimPrefix(string(kv.Key), prefix) - parts := strings.Split(left, "/") - if len(parts) != 2 { - return 0, "", 0, false - } - jobID, err := strconv.ParseInt(parts[0], 10, 64) - if err != nil { - return 0, "", 0, false - } - // there is no Value in DELETE event, so we need to check it. 
- if tp == mvccpb.PUT { - schemaVer, err = strconv.ParseInt(string(kv.Value), 10, 64) - if err != nil { - return 0, "", 0, false - } - } - return jobID, parts[1], schemaVer, true -} - -func isUpdatedLatestVersion(key, val string, latestVer int64, notMatchVerCnt, intervalCnt int, nodeAlive bool) bool { - ver, err := strconv.Atoi(val) - if err != nil { - logutil.DDLLogger().Info("syncer check all versions, convert value to int failed, continue checking.", - zap.String("ddl", key), zap.String("value", val), zap.Error(err)) - return false - } - if int64(ver) < latestVer && nodeAlive { - if notMatchVerCnt%intervalCnt == 0 { - logutil.DDLLogger().Info("syncer check all versions, someone is not synced, continue checking", - zap.String("ddl", key), zap.Int("currentVer", ver), zap.Int64("latestVer", latestVer)) - } - return false - } - return true -} - -func (s *schemaVersionSyncer) Close() { - err := s.removeSelfVersionPath() - if err != nil { - logutil.DDLLogger().Error("remove self version path failed", zap.Error(err)) - } -} diff --git a/pkg/ddl/table.go b/pkg/ddl/table.go index bcc0a0bb5d91c..4d4696b43339c 100644 --- a/pkg/ddl/table.go +++ b/pkg/ddl/table.go @@ -273,14 +273,14 @@ func (w *worker) recoverTable(t *meta.Meta, job *model.Job, recoverInfo *Recover return ver, errors.Trace(err) } - if val, _err_ := failpoint.Eval(_curpkg_("mockRecoverTableCommitErr")); _err_ == nil { + failpoint.Inject("mockRecoverTableCommitErr", func(val failpoint.Value) { if val.(bool) && atomic.CompareAndSwapUint32(&mockRecoverTableCommitErrOnce, 0, 1) { err = failpoint.Enable(`tikvclient/mockCommitErrorOpt`, "return(true)") if err != nil { return } } - } + }) err = updateLabelRules(job, recoverInfo.TableInfo, oldRules, tableRuleID, partRuleIDs, oldRuleIDs, recoverInfo.TableInfo.ID) if err != nil { @@ -292,9 +292,9 @@ func (w *worker) recoverTable(t *meta.Meta, job *model.Job, recoverInfo *Recover } func clearTablePlacementAndBundles(ctx context.Context, tblInfo *model.TableInfo) error { - if _, _err_ := failpoint.Eval(_curpkg_("mockClearTablePlacementAndBundlesErr")); _err_ == nil { - return errors.New("mock error for clearTablePlacementAndBundles") - } + failpoint.Inject("mockClearTablePlacementAndBundlesErr", func() { + failpoint.Return(errors.New("mock error for clearTablePlacementAndBundles")) + }) var bundles []*placement.Bundle if tblInfo.PlacementPolicyRef != nil { tblInfo.PlacementPolicyRef = nil @@ -459,12 +459,12 @@ func (w *worker) onTruncateTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver i job.State = model.JobStateCancelled return ver, errors.Trace(err) } - if val, _err_ := failpoint.Eval(_curpkg_("truncateTableErr")); _err_ == nil { + failpoint.Inject("truncateTableErr", func(val failpoint.Value) { if val.(bool) { job.State = model.JobStateCancelled - return ver, errors.New("occur an error after dropping table") + failpoint.Return(ver, errors.New("occur an error after dropping table")) } - } + }) // Clear the TiFlash replica progress from ETCD. 
if tblInfo.TiFlashReplica != nil { @@ -552,11 +552,11 @@ func (w *worker) onTruncateTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver i return ver, errors.Trace(err) } - if val, _err_ := failpoint.Eval(_curpkg_("mockTruncateTableUpdateVersionError")); _err_ == nil { + failpoint.Inject("mockTruncateTableUpdateVersionError", func(val failpoint.Value) { if val.(bool) { - return ver, errors.New("mock update version error") + failpoint.Return(ver, errors.New("mock update version error")) } - } + }) var partitions []model.PartitionDefinition if pi := tblInfo.GetPartitionInfo(); pi != nil { @@ -829,14 +829,14 @@ func checkAndRenameTables(t *meta.Meta, job *model.Job, tblInfo *model.TableInfo return ver, errors.Trace(err) } - if val, _err_ := failpoint.Eval(_curpkg_("renameTableErr")); _err_ == nil { + failpoint.Inject("renameTableErr", func(val failpoint.Value) { if valStr, ok := val.(string); ok { if tableName.L == valStr { job.State = model.JobStateCancelled - return ver, errors.New("occur an error after renaming table") + failpoint.Return(ver, errors.New("occur an error after renaming table")) } } - } + }) oldTableName := tblInfo.Name tableRuleID, partRuleIDs, oldRuleIDs, oldRules, err := getOldLabelRules(tblInfo, oldSchemaName.L, oldTableName.L) @@ -1291,18 +1291,18 @@ func updateVersionAndTableInfoWithCheck(d *ddlCtx, t *meta.Meta, job *model.Job, // updateVersionAndTableInfo updates the schema version and the table information. func updateVersionAndTableInfo(d *ddlCtx, t *meta.Meta, job *model.Job, tblInfo *model.TableInfo, shouldUpdateVer bool, multiInfos ...schemaIDAndTableInfo) ( ver int64, err error) { - if val, _err_ := failpoint.Eval(_curpkg_("mockUpdateVersionAndTableInfoErr")); _err_ == nil { + failpoint.Inject("mockUpdateVersionAndTableInfoErr", func(val failpoint.Value) { switch val.(int) { case 1: - return ver, errors.New("mock update version and tableInfo error") + failpoint.Return(ver, errors.New("mock update version and tableInfo error")) case 2: // We change it cancelled directly here, because we want to get the original error with the job id appended. // The job ID will be used to get the job from history queue and we will assert it's args. job.State = model.JobStateCancelled - return ver, errors.New("mock update version and tableInfo error, jobID=" + strconv.Itoa(int(job.ID))) + failpoint.Return(ver, errors.New("mock update version and tableInfo error, jobID="+strconv.Itoa(int(job.ID)))) default: } - } + }) if shouldUpdateVer && (job.MultiSchemaInfo == nil || !job.MultiSchemaInfo.SkipVersion) { ver, err = updateSchemaVersion(d, t, job, multiInfos...) if err != nil { diff --git a/pkg/ddl/table.go__failpoint_stash__ b/pkg/ddl/table.go__failpoint_stash__ deleted file mode 100644 index 4d4696b43339c..0000000000000 --- a/pkg/ddl/table.go__failpoint_stash__ +++ /dev/null @@ -1,1681 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
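Several table.go hunks above (truncateTableErr, mockTruncateTableUpdateVersionError, renameTableErr, mockUpdateVersionAndTableInfoErr) also restore failpoint.Return inside the Inject closures. A bare return there would only leave the closure, so the marker form uses failpoint.Return, and the rewrite expands it into the plain return ver, errors.New(...) statements visible in the generated code being removed. A small hypothetical sketch:

package example

import (
	"errors"

	"github.com/pingcap/failpoint"
)

// truncateStep fails early when the (made-up) "mockTruncateErr" failpoint is on.
func truncateStep() (ver int64, err error) {
	failpoint.Inject("mockTruncateErr", func(val failpoint.Value) {
		if val.(bool) {
			// failpoint-ctl rewrites failpoint.Return into a real
			// `return ver, errors.New(...)` of truncateStep itself; a bare
			// return here would only leave this closure.
			failpoint.Return(ver, errors.New("injected truncate error"))
		}
	})
	return ver, nil
}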
- -package ddl - -import ( - "context" - "encoding/json" - "fmt" - "strconv" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/ddl/label" - "github.com/pingcap/tidb/pkg/ddl/logutil" - "github.com/pingcap/tidb/pkg/ddl/placement" - sess "github.com/pingcap/tidb/pkg/ddl/session" - "github.com/pingcap/tidb/pkg/domain/infosync" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/meta/autoid" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/charset" - "github.com/pingcap/tidb/pkg/parser/model" - field_types "github.com/pingcap/tidb/pkg/parser/types" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - statsutil "github.com/pingcap/tidb/pkg/statistics/handle/util" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/table/tables" - "github.com/pingcap/tidb/pkg/tablecodec" - tidb_util "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/dbterror" - "github.com/pingcap/tidb/pkg/util/gcutil" - "go.uber.org/zap" -) - -const tiflashCheckTiDBHTTPAPIHalfInterval = 2500 * time.Millisecond - -func repairTableOrViewWithCheck(t *meta.Meta, job *model.Job, schemaID int64, tbInfo *model.TableInfo) error { - err := checkTableInfoValid(tbInfo) - if err != nil { - job.State = model.JobStateCancelled - return errors.Trace(err) - } - return t.UpdateTable(schemaID, tbInfo) -} - -func onDropTableOrView(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - tblInfo, err := checkTableExistAndCancelNonExistJob(t, job, job.SchemaID) - if err != nil { - return ver, errors.Trace(err) - } - - originalState := job.SchemaState - switch tblInfo.State { - case model.StatePublic: - // public -> write only - if job.Type == model.ActionDropTable { - err = checkDropTableHasForeignKeyReferredInOwner(d, t, job) - if err != nil { - return ver, err - } - } - tblInfo.State = model.StateWriteOnly - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != tblInfo.State) - if err != nil { - return ver, errors.Trace(err) - } - case model.StateWriteOnly: - // write only -> delete only - tblInfo.State = model.StateDeleteOnly - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != tblInfo.State) - if err != nil { - return ver, errors.Trace(err) - } - case model.StateDeleteOnly: - tblInfo.State = model.StateNone - oldIDs := getPartitionIDs(tblInfo) - ruleIDs := append(getPartitionRuleIDs(job.SchemaName, tblInfo), fmt.Sprintf(label.TableIDFormat, label.IDPrefix, job.SchemaName, tblInfo.Name.L)) - job.CtxVars = []any{oldIDs} - - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, originalState != tblInfo.State) - if err != nil { - return ver, errors.Trace(err) - } - if tblInfo.IsSequence() { - if err = t.DropSequence(job.SchemaID, job.TableID); err != nil { - return ver, errors.Trace(err) - } - } else { - if err = t.DropTableOrView(job.SchemaID, job.TableID); err != nil { - return ver, errors.Trace(err) - } - if err = t.GetAutoIDAccessors(job.SchemaID, job.TableID).Del(); err != nil { - return ver, errors.Trace(err) - } - } - if tblInfo.TiFlashReplica != nil { - e := infosync.DeleteTiFlashTableSyncProgress(tblInfo) - if e != nil { - logutil.DDLLogger().Error("DeleteTiFlashTableSyncProgress fails", zap.Error(e)) - } - } - // Placement rules cannot be removed immediately after drop table / truncate table, because the - // tables can be flashed back or recovered, 
therefore it moved to doGCPlacementRules in gc_worker.go. - - // Finish this job. - job.FinishTableJob(model.JobStateDone, model.StateNone, ver, tblInfo) - startKey := tablecodec.EncodeTablePrefix(job.TableID) - job.Args = append(job.Args, startKey, oldIDs, ruleIDs) - if !tblInfo.IsSequence() && !tblInfo.IsView() { - dropTableEvent := statsutil.NewDropTableEvent( - job.SchemaID, - tblInfo, - ) - asyncNotifyEvent(d, dropTableEvent) - } - default: - return ver, errors.Trace(dbterror.ErrInvalidDDLState.GenWithStackByArgs("table", tblInfo.State)) - } - job.SchemaState = tblInfo.State - return ver, errors.Trace(err) -} - -func (w *worker) onRecoverTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { - var ( - recoverInfo *RecoverInfo - recoverTableCheckFlag int64 - ) - if err = job.DecodeArgs(&recoverInfo, &recoverTableCheckFlag); err != nil { - // Invalid arguments, cancel this job. - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - schemaID := recoverInfo.SchemaID - tblInfo := recoverInfo.TableInfo - if tblInfo.TTLInfo != nil { - // force disable TTL job schedule for recovered table - tblInfo.TTLInfo.Enable = false - } - - // check GC and safe point - gcEnable, err := checkGCEnable(w) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - err = checkTableNotExists(d, schemaID, tblInfo.Name.L) - if err != nil { - if infoschema.ErrDatabaseNotExists.Equal(err) || infoschema.ErrTableExists.Equal(err) { - job.State = model.JobStateCancelled - } - return ver, errors.Trace(err) - } - - err = checkTableIDNotExists(t, schemaID, tblInfo.ID) - if err != nil { - if infoschema.ErrDatabaseNotExists.Equal(err) || infoschema.ErrTableExists.Equal(err) { - job.State = model.JobStateCancelled - } - return ver, errors.Trace(err) - } - - // Recover table divide into 2 steps: - // 1. Check GC enable status, to decided whether enable GC after recover table. - // a. Why not disable GC before put the job to DDL job queue? - // Think about concurrency problem. If a recover job-1 is doing and already disabled GC, - // then, another recover table job-2 check GC enable will get disable before into the job queue. - // then, after recover table job-2 finished, the GC will be disabled. - // b. Why split into 2 steps? 1 step also can finish this job: check GC -> disable GC -> recover table -> finish job. - // What if the transaction commit failed? then, the job will retry, but the GC already disabled when first running. - // So, after this job retry succeed, the GC will be disabled. - // 2. Do recover table job. - // a. Check whether GC enabled, if enabled, disable GC first. - // b. Check GC safe point. If drop table time if after safe point time, then can do recover. - // otherwise, can't recover table, because the records of the table may already delete by gc. - // c. Remove GC task of the table from gc_delete_range table. - // d. Create table and rebase table auto ID. - // e. Finish. - switch tblInfo.State { - case model.StateNone: - // none -> write only - // check GC enable and update flag. - if gcEnable { - job.Args[checkFlagIndexInJobArgs] = recoverCheckFlagEnableGC - } else { - job.Args[checkFlagIndexInJobArgs] = recoverCheckFlagDisableGC - } - - job.SchemaState = model.StateWriteOnly - tblInfo.State = model.StateWriteOnly - case model.StateWriteOnly: - // write only -> public - // do recover table. 
- if gcEnable { - err = disableGC(w) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Errorf("disable gc failed, try again later. err: %v", err) - } - } - // check GC safe point - err = checkSafePoint(w, recoverInfo.SnapshotTS) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - ver, err = w.recoverTable(t, job, recoverInfo) - if err != nil { - return ver, errors.Trace(err) - } - tableInfo := tblInfo.Clone() - tableInfo.State = model.StatePublic - tableInfo.UpdateTS = t.StartTS - ver, err = updateVersionAndTableInfo(d, t, job, tableInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - tblInfo.State = model.StatePublic - tblInfo.UpdateTS = t.StartTS - // Finish this job. - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - default: - return ver, dbterror.ErrInvalidDDLState.GenWithStackByArgs("table", tblInfo.State) - } - return ver, nil -} - -func (w *worker) recoverTable(t *meta.Meta, job *model.Job, recoverInfo *RecoverInfo) (ver int64, err error) { - var tids []int64 - if recoverInfo.TableInfo.GetPartitionInfo() != nil { - tids = getPartitionIDs(recoverInfo.TableInfo) - tids = append(tids, recoverInfo.TableInfo.ID) - } else { - tids = []int64{recoverInfo.TableInfo.ID} - } - tableRuleID, partRuleIDs, oldRuleIDs, oldRules, err := getOldLabelRules(recoverInfo.TableInfo, recoverInfo.OldSchemaName, recoverInfo.OldTableName) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Wrapf(err, "failed to get old label rules from PD") - } - // Remove dropped table DDL job from gc_delete_range table. - err = w.delRangeManager.removeFromGCDeleteRange(w.ctx, recoverInfo.DropJobID) - if err != nil { - return ver, errors.Trace(err) - } - err = clearTablePlacementAndBundles(w.ctx, recoverInfo.TableInfo) - if err != nil { - return ver, errors.Trace(err) - } - - tableInfo := recoverInfo.TableInfo.Clone() - tableInfo.State = model.StatePublic - tableInfo.UpdateTS = t.StartTS - err = t.CreateTableAndSetAutoID(recoverInfo.SchemaID, tableInfo, recoverInfo.AutoIDs) - if err != nil { - return ver, errors.Trace(err) - } - - failpoint.Inject("mockRecoverTableCommitErr", func(val failpoint.Value) { - if val.(bool) && atomic.CompareAndSwapUint32(&mockRecoverTableCommitErrOnce, 0, 1) { - err = failpoint.Enable(`tikvclient/mockCommitErrorOpt`, "return(true)") - if err != nil { - return - } - } - }) - - err = updateLabelRules(job, recoverInfo.TableInfo, oldRules, tableRuleID, partRuleIDs, oldRuleIDs, recoverInfo.TableInfo.ID) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Wrapf(err, "failed to update the label rule to PD") - } - job.CtxVars = []any{tids} - return ver, nil -} - -func clearTablePlacementAndBundles(ctx context.Context, tblInfo *model.TableInfo) error { - failpoint.Inject("mockClearTablePlacementAndBundlesErr", func() { - failpoint.Return(errors.New("mock error for clearTablePlacementAndBundles")) - }) - var bundles []*placement.Bundle - if tblInfo.PlacementPolicyRef != nil { - tblInfo.PlacementPolicyRef = nil - bundles = append(bundles, placement.NewBundle(tblInfo.ID)) - } - - if tblInfo.Partition != nil { - for i := range tblInfo.Partition.Definitions { - par := &tblInfo.Partition.Definitions[i] - if par.PlacementPolicyRef != nil { - par.PlacementPolicyRef = nil - bundles = append(bundles, placement.NewBundle(par.ID)) - } - } - } - - if len(bundles) == 0 { - return nil - } - - return infosync.PutRuleBundlesWithDefaultRetry(ctx, bundles) -} - -// 
mockRecoverTableCommitErrOnce uses to make sure -// `mockRecoverTableCommitErr` only mock error once. -var mockRecoverTableCommitErrOnce uint32 - -func enableGC(w *worker) error { - ctx, err := w.sessPool.Get() - if err != nil { - return errors.Trace(err) - } - defer w.sessPool.Put(ctx) - - return gcutil.EnableGC(ctx) -} - -func disableGC(w *worker) error { - ctx, err := w.sessPool.Get() - if err != nil { - return errors.Trace(err) - } - defer w.sessPool.Put(ctx) - - return gcutil.DisableGC(ctx) -} - -func checkGCEnable(w *worker) (enable bool, err error) { - ctx, err := w.sessPool.Get() - if err != nil { - return false, errors.Trace(err) - } - defer w.sessPool.Put(ctx) - - return gcutil.CheckGCEnable(ctx) -} - -func checkSafePoint(w *worker, snapshotTS uint64) error { - ctx, err := w.sessPool.Get() - if err != nil { - return errors.Trace(err) - } - defer w.sessPool.Put(ctx) - - return gcutil.ValidateSnapshot(ctx, snapshotTS) -} - -func getTable(r autoid.Requirement, schemaID int64, tblInfo *model.TableInfo) (table.Table, error) { - allocs := autoid.NewAllocatorsFromTblInfo(r, schemaID, tblInfo) - tbl, err := table.TableFromMeta(allocs, tblInfo) - return tbl, errors.Trace(err) -} - -// GetTableInfoAndCancelFaultJob is exported for test. -func GetTableInfoAndCancelFaultJob(t *meta.Meta, job *model.Job, schemaID int64) (*model.TableInfo, error) { - tblInfo, err := checkTableExistAndCancelNonExistJob(t, job, schemaID) - if err != nil { - return nil, errors.Trace(err) - } - - if tblInfo.State != model.StatePublic { - job.State = model.JobStateCancelled - return nil, dbterror.ErrInvalidDDLState.GenWithStack("table %s is not in public, but %s", tblInfo.Name, tblInfo.State) - } - - return tblInfo, nil -} - -func checkTableExistAndCancelNonExistJob(t *meta.Meta, job *model.Job, schemaID int64) (*model.TableInfo, error) { - tblInfo, err := getTableInfo(t, job.TableID, schemaID) - if err == nil { - // Check if table name is renamed. - if job.TableName != "" && tblInfo.Name.L != job.TableName && job.Type != model.ActionRepairTable { - job.State = model.JobStateCancelled - return nil, infoschema.ErrTableNotExists.GenWithStackByArgs(job.SchemaName, job.TableName) - } - return tblInfo, nil - } - if infoschema.ErrDatabaseNotExists.Equal(err) || infoschema.ErrTableNotExists.Equal(err) { - job.State = model.JobStateCancelled - } - return nil, err -} - -func getTableInfo(t *meta.Meta, tableID, schemaID int64) (*model.TableInfo, error) { - // Check this table's database. - tblInfo, err := t.GetTable(schemaID, tableID) - if err != nil { - if meta.ErrDBNotExists.Equal(err) { - return nil, errors.Trace(infoschema.ErrDatabaseNotExists.GenWithStackByArgs( - fmt.Sprintf("(Schema ID %d)", schemaID), - )) - } - return nil, errors.Trace(err) - } - - // Check the table. - if tblInfo == nil { - return nil, errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs( - fmt.Sprintf("(Schema ID %d)", schemaID), - fmt.Sprintf("(Table ID %d)", tableID), - )) - } - return tblInfo, nil -} - -// onTruncateTable delete old table meta, and creates a new table identical to old table except for table ID. -// As all the old data is encoded with old table ID, it can not be accessed anymore. -// A background job will be created to delete old data. 
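The doc comment above is the invariant that makes truncate cheap: row keys carry the table ID as their prefix (tablecodec.EncodeTablePrefix, the same helper onTruncateTable later uses for the start key it hands to delete-range), so allocating a new table ID immediately hides the old rows and leaves their key range to background GC. A rough sketch under that assumption:

package example

import (
	"github.com/pingcap/tidb/pkg/kv"
	"github.com/pingcap/tidb/pkg/tablecodec"
)

// truncateKeyRanges shows why swapping in a new table ID hides the old rows:
// the new table reads and writes under newPrefix, while everything in
// [gcStart, gcEnd) becomes garbage for a later delete-range job.
func truncateKeyRanges(oldTableID, newTableID int64) (gcStart, gcEnd, newPrefix kv.Key) {
	oldPrefix := tablecodec.EncodeTablePrefix(oldTableID) // "t" + encoded old ID
	newPrefix = tablecodec.EncodeTablePrefix(newTableID)  // "t" + encoded new ID
	return oldPrefix, oldPrefix.PrefixNext(), newPrefix
}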
-func (w *worker) onTruncateTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - schemaID := job.SchemaID - tableID := job.TableID - var newTableID int64 - var fkCheck bool - var newPartitionIDs []int64 - err := job.DecodeArgs(&newTableID, &fkCheck, &newPartitionIDs) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) - if err != nil { - return ver, errors.Trace(err) - } - if tblInfo.IsView() || tblInfo.IsSequence() { - job.State = model.JobStateCancelled - return ver, infoschema.ErrTableNotExists.GenWithStackByArgs(job.SchemaName, tblInfo.Name.O) - } - // Copy the old tableInfo for later usage. - oldTblInfo := tblInfo.Clone() - err = checkTruncateTableHasForeignKeyReferredInOwner(d, t, job, tblInfo, fkCheck) - if err != nil { - return ver, err - } - err = t.DropTableOrView(schemaID, tblInfo.ID) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - err = t.GetAutoIDAccessors(schemaID, tblInfo.ID).Del() - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - failpoint.Inject("truncateTableErr", func(val failpoint.Value) { - if val.(bool) { - job.State = model.JobStateCancelled - failpoint.Return(ver, errors.New("occur an error after dropping table")) - } - }) - - // Clear the TiFlash replica progress from ETCD. - if tblInfo.TiFlashReplica != nil { - e := infosync.DeleteTiFlashTableSyncProgress(tblInfo) - if e != nil { - logutil.DDLLogger().Error("DeleteTiFlashTableSyncProgress fails", zap.Error(e)) - } - } - - var oldPartitionIDs []int64 - if tblInfo.GetPartitionInfo() != nil { - oldPartitionIDs = getPartitionIDs(tblInfo) - // We use the new partition ID because all the old data is encoded with the old partition ID, it can not be accessed anymore. - err = truncateTableByReassignPartitionIDs(t, tblInfo, newPartitionIDs) - if err != nil { - return ver, errors.Trace(err) - } - } - - if pi := tblInfo.GetPartitionInfo(); pi != nil { - oldIDs := make([]int64, 0, len(oldPartitionIDs)) - newIDs := make([]int64, 0, len(oldPartitionIDs)) - newDefs := pi.Definitions - for i := range oldPartitionIDs { - newDef := &newDefs[i] - newID := newDef.ID - if newDef.PlacementPolicyRef != nil { - oldIDs = append(oldIDs, oldPartitionIDs[i]) - newIDs = append(newIDs, newID) - } - } - job.CtxVars = []any{oldIDs, newIDs} - } - - tableRuleID, partRuleIDs, _, oldRules, err := getOldLabelRules(tblInfo, job.SchemaName, tblInfo.Name.L) - if err != nil { - job.State = model.JobStateCancelled - return 0, errors.Wrapf(err, "failed to get old label rules from PD") - } - - err = updateLabelRules(job, tblInfo, oldRules, tableRuleID, partRuleIDs, []string{}, newTableID) - if err != nil { - job.State = model.JobStateCancelled - return 0, errors.Wrapf(err, "failed to update the label rule to PD") - } - - // Clear the TiFlash replica available status. 
- if tblInfo.TiFlashReplica != nil { - // Set PD rules for TiFlash - if pi := tblInfo.GetPartitionInfo(); pi != nil { - if e := infosync.ConfigureTiFlashPDForPartitions(true, &pi.Definitions, tblInfo.TiFlashReplica.Count, &tblInfo.TiFlashReplica.LocationLabels, tblInfo.ID); e != nil { - logutil.DDLLogger().Error("ConfigureTiFlashPDForPartitions fails", zap.Error(err)) - job.State = model.JobStateCancelled - return ver, errors.Trace(e) - } - } else { - if e := infosync.ConfigureTiFlashPDForTable(newTableID, tblInfo.TiFlashReplica.Count, &tblInfo.TiFlashReplica.LocationLabels); e != nil { - logutil.DDLLogger().Error("ConfigureTiFlashPDForTable fails", zap.Error(err)) - job.State = model.JobStateCancelled - return ver, errors.Trace(e) - } - } - tblInfo.TiFlashReplica.AvailablePartitionIDs = nil - tblInfo.TiFlashReplica.Available = false - } - - tblInfo.ID = newTableID - - // build table & partition bundles if any. - bundles, err := placement.NewFullTableBundles(t, tblInfo) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - err = infosync.PutRuleBundlesWithDefaultRetry(context.TODO(), bundles) - if err != nil { - job.State = model.JobStateCancelled - return 0, errors.Wrapf(err, "failed to notify PD the placement rules") - } - - err = t.CreateTableOrView(schemaID, tblInfo) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - failpoint.Inject("mockTruncateTableUpdateVersionError", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(ver, errors.New("mock update version error")) - } - }) - - var partitions []model.PartitionDefinition - if pi := tblInfo.GetPartitionInfo(); pi != nil { - partitions = tblInfo.GetPartitionInfo().Definitions - } - preSplitAndScatter(w.sess.Context, d.store, tblInfo, partitions) - - ver, err = updateSchemaVersion(d, t, job) - if err != nil { - return ver, errors.Trace(err) - } - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - truncateTableEvent := statsutil.NewTruncateTableEvent( - job.SchemaID, - tblInfo, - oldTblInfo, - ) - asyncNotifyEvent(d, truncateTableEvent) - startKey := tablecodec.EncodeTablePrefix(tableID) - job.Args = []any{startKey, oldPartitionIDs} - return ver, nil -} - -func onRebaseAutoIncrementIDType(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - return onRebaseAutoID(d, d.store, t, job, autoid.AutoIncrementType) -} - -func onRebaseAutoRandomType(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - return onRebaseAutoID(d, d.store, t, job, autoid.AutoRandomType) -} - -func onRebaseAutoID(d *ddlCtx, _ kv.Storage, t *meta.Meta, job *model.Job, tp autoid.AllocatorType) (ver int64, _ error) { - schemaID := job.SchemaID - var ( - newBase int64 - force bool - ) - err := job.DecodeArgs(&newBase, &force) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - if job.MultiSchemaInfo != nil && job.MultiSchemaInfo.Revertible { - job.MarkNonRevertible() - return ver, nil - } - - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - tbl, err := getTable(d.getAutoIDRequirement(), schemaID, tblInfo) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - if !force { - newBaseTemp, err := adjustNewBaseToNextGlobalID(nil, tbl, tp, newBase) - if err != nil { - return ver, errors.Trace(err) - } - if newBase != newBaseTemp { - 
job.Warning = toTError(fmt.Errorf("Can't reset AUTO_INCREMENT to %d without FORCE option, using %d instead", - newBase, newBaseTemp, - )) - } - newBase = newBaseTemp - } - - if tp == autoid.AutoIncrementType { - tblInfo.AutoIncID = newBase - } else { - tblInfo.AutoRandID = newBase - } - - if alloc := tbl.Allocators(nil).Get(tp); alloc != nil { - // The next value to allocate is `newBase`. - newEnd := newBase - 1 - if force { - err = alloc.ForceRebase(newEnd) - } else { - err = alloc.Rebase(context.Background(), newEnd, false) - } - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - } - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - return ver, nil -} - -func onModifyTableAutoIDCache(d *ddlCtx, t *meta.Meta, job *model.Job) (int64, error) { - var cache int64 - if err := job.DecodeArgs(&cache); err != nil { - job.State = model.JobStateCancelled - return 0, errors.Trace(err) - } - - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) - if err != nil { - return 0, errors.Trace(err) - } - - tblInfo.AutoIdCache = cache - ver, err := updateVersionAndTableInfo(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - return ver, nil -} - -func (w *worker) onShardRowID(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - var shardRowIDBits uint64 - err := job.DecodeArgs(&shardRowIDBits) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - if shardRowIDBits < tblInfo.ShardRowIDBits { - tblInfo.ShardRowIDBits = shardRowIDBits - } else { - tbl, err := getTable(d.getAutoIDRequirement(), job.SchemaID, tblInfo) - if err != nil { - return ver, errors.Trace(err) - } - err = verifyNoOverflowShardBits(w.sessPool, tbl, shardRowIDBits) - if err != nil { - job.State = model.JobStateCancelled - return ver, err - } - tblInfo.ShardRowIDBits = shardRowIDBits - // MaxShardRowIDBits use to check the overflow of auto ID. - tblInfo.MaxShardRowIDBits = shardRowIDBits - } - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - return ver, nil -} - -func verifyNoOverflowShardBits(s *sess.Pool, tbl table.Table, shardRowIDBits uint64) error { - if shardRowIDBits == 0 { - return nil - } - ctx, err := s.Get() - if err != nil { - return errors.Trace(err) - } - defer s.Put(ctx) - // Check next global max auto ID first. 
- autoIncID, err := tbl.Allocators(ctx.GetTableCtx()).Get(autoid.RowIDAllocType).NextGlobalAutoID() - if err != nil { - return errors.Trace(err) - } - if tables.OverflowShardBits(autoIncID, shardRowIDBits, autoid.RowIDBitLength, true) { - return autoid.ErrAutoincReadFailed.GenWithStack("shard_row_id_bits %d will cause next global auto ID %v overflow", shardRowIDBits, autoIncID) - } - return nil -} - -func onRenameTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - var oldSchemaID int64 - var oldSchemaName model.CIStr - var tableName model.CIStr - if err := job.DecodeArgs(&oldSchemaID, &tableName, &oldSchemaName); err != nil { - // Invalid arguments, cancel this job. - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - if job.SchemaState == model.StatePublic { - return finishJobRenameTable(d, t, job) - } - newSchemaID := job.SchemaID - err := checkTableNotExists(d, newSchemaID, tableName.L) - if err != nil { - if infoschema.ErrDatabaseNotExists.Equal(err) || infoschema.ErrTableExists.Equal(err) { - job.State = model.JobStateCancelled - } - return ver, errors.Trace(err) - } - - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, oldSchemaID) - if err != nil { - return ver, errors.Trace(err) - } - oldTableName := tblInfo.Name - ver, err = checkAndRenameTables(t, job, tblInfo, oldSchemaID, job.SchemaID, &oldSchemaName, &tableName) - if err != nil { - return ver, errors.Trace(err) - } - fkh := newForeignKeyHelper() - err = adjustForeignKeyChildTableInfoAfterRenameTable(d, t, job, &fkh, tblInfo, oldSchemaName, oldTableName, tableName, newSchemaID) - if err != nil { - return ver, errors.Trace(err) - } - ver, err = updateSchemaVersion(d, t, job, fkh.getLoadedTables()...) - if err != nil { - return ver, errors.Trace(err) - } - job.SchemaState = model.StatePublic - return ver, nil -} - -func onRenameTables(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - oldSchemaIDs := []int64{} - newSchemaIDs := []int64{} - tableNames := []*model.CIStr{} - tableIDs := []int64{} - oldSchemaNames := []*model.CIStr{} - oldTableNames := []*model.CIStr{} - if err := job.DecodeArgs(&oldSchemaIDs, &newSchemaIDs, &tableNames, &tableIDs, &oldSchemaNames, &oldTableNames); err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - if job.SchemaState == model.StatePublic { - return finishJobRenameTables(d, t, job, tableNames, tableIDs, newSchemaIDs) - } - - var err error - fkh := newForeignKeyHelper() - for i, oldSchemaID := range oldSchemaIDs { - job.TableID = tableIDs[i] - job.TableName = oldTableNames[i].L - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, oldSchemaID) - if err != nil { - return ver, errors.Trace(err) - } - ver, err := checkAndRenameTables(t, job, tblInfo, oldSchemaID, newSchemaIDs[i], oldSchemaNames[i], tableNames[i]) - if err != nil { - return ver, errors.Trace(err) - } - err = adjustForeignKeyChildTableInfoAfterRenameTable(d, t, job, &fkh, tblInfo, *oldSchemaNames[i], *oldTableNames[i], *tableNames[i], newSchemaIDs[i]) - if err != nil { - return ver, errors.Trace(err) - } - } - - ver, err = updateSchemaVersion(d, t, job, fkh.getLoadedTables()...) 
- if err != nil { - return ver, errors.Trace(err) - } - job.SchemaState = model.StatePublic - return ver, nil -} - -func checkAndRenameTables(t *meta.Meta, job *model.Job, tblInfo *model.TableInfo, oldSchemaID, newSchemaID int64, oldSchemaName, tableName *model.CIStr) (ver int64, _ error) { - err := t.DropTableOrView(oldSchemaID, tblInfo.ID) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - failpoint.Inject("renameTableErr", func(val failpoint.Value) { - if valStr, ok := val.(string); ok { - if tableName.L == valStr { - job.State = model.JobStateCancelled - failpoint.Return(ver, errors.New("occur an error after renaming table")) - } - } - }) - - oldTableName := tblInfo.Name - tableRuleID, partRuleIDs, oldRuleIDs, oldRules, err := getOldLabelRules(tblInfo, oldSchemaName.L, oldTableName.L) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Wrapf(err, "failed to get old label rules from PD") - } - - if tblInfo.AutoIDSchemaID == 0 && newSchemaID != oldSchemaID { - // The auto id is referenced by a schema id + table id - // Table ID is not changed between renames, but schema id can change. - // To allow concurrent use of the auto id during rename, keep the auto id - // by always reference it with the schema id it was originally created in. - tblInfo.AutoIDSchemaID = oldSchemaID - } - if newSchemaID == tblInfo.AutoIDSchemaID { - // Back to the original schema id, no longer needed. - tblInfo.AutoIDSchemaID = 0 - } - - tblInfo.Name = *tableName - err = t.CreateTableOrView(newSchemaID, tblInfo) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - err = updateLabelRules(job, tblInfo, oldRules, tableRuleID, partRuleIDs, oldRuleIDs, tblInfo.ID) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Wrapf(err, "failed to update the label rule to PD") - } - - return ver, nil -} - -func adjustForeignKeyChildTableInfoAfterRenameTable(d *ddlCtx, t *meta.Meta, job *model.Job, fkh *foreignKeyHelper, tblInfo *model.TableInfo, oldSchemaName, oldTableName, newTableName model.CIStr, newSchemaID int64) error { - if !variable.EnableForeignKey.Load() || newTableName.L == oldTableName.L { - return nil - } - is, err := getAndCheckLatestInfoSchema(d, t) - if err != nil { - return err - } - newDB, ok := is.SchemaByID(newSchemaID) - if !ok { - job.State = model.JobStateCancelled - return infoschema.ErrDatabaseNotExists.GenWithStackByArgs(fmt.Sprintf("schema-ID: %v", newSchemaID)) - } - referredFKs := is.GetTableReferredForeignKeys(oldSchemaName.L, oldTableName.L) - if len(referredFKs) == 0 { - return nil - } - fkh.addLoadedTable(oldSchemaName.L, oldTableName.L, newDB.ID, tblInfo) - for _, referredFK := range referredFKs { - childTableInfo, err := fkh.getTableFromStorage(is, t, referredFK.ChildSchema, referredFK.ChildTable) - if err != nil { - if infoschema.ErrTableNotExists.Equal(err) || infoschema.ErrDatabaseNotExists.Equal(err) { - continue - } - return err - } - childFKInfo := model.FindFKInfoByName(childTableInfo.tblInfo.ForeignKeys, referredFK.ChildFKName.L) - if childFKInfo == nil { - continue - } - childFKInfo.RefSchema = newDB.Name - childFKInfo.RefTable = newTableName - } - for _, info := range fkh.loaded { - err = updateTable(t, info.schemaID, info.tblInfo) - if err != nil { - return err - } - } - return nil -} - -// We split the renaming table job into two steps: -// 1. rename table and update the schema version. -// 2. update the job state to JobStateDone. 
-// This is the requirement from TiCDC because -// - it uses the job state to check whether the DDL is finished. -// - there is a gap between schema reloading and job state updating: -// when the job state is updated to JobStateDone, before the new schema reloaded, -// there may be DMLs that use the old schema. -// - TiCDC cannot handle the DMLs that use the old schema, because -// the commit TS of the DMLs are greater than the job state updating TS. -func finishJobRenameTable(d *ddlCtx, t *meta.Meta, job *model.Job) (int64, error) { - tblInfo, err := getTableInfo(t, job.TableID, job.SchemaID) - if err != nil { - job.State = model.JobStateCancelled - return 0, errors.Trace(err) - } - // Before updating the schema version, we need to reset the old schema ID to new schema ID, so that - // the table info can be dropped normally in `ApplyDiff`. This is because renaming table requires two - // schema versions to complete. - oldRawArgs := job.RawArgs - job.Args[0] = job.SchemaID - job.RawArgs, err = json.Marshal(job.Args) - if err != nil { - return 0, errors.Trace(err) - } - ver, err := updateSchemaVersion(d, t, job) - if err != nil { - return ver, errors.Trace(err) - } - job.RawArgs = oldRawArgs - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - return ver, nil -} - -func finishJobRenameTables(d *ddlCtx, t *meta.Meta, job *model.Job, - tableNames []*model.CIStr, tableIDs, newSchemaIDs []int64) (int64, error) { - tblSchemaIDs := make(map[int64]int64, len(tableIDs)) - for i := range tableIDs { - tblSchemaIDs[tableIDs[i]] = newSchemaIDs[i] - } - tblInfos := make([]*model.TableInfo, 0, len(tableNames)) - for i := range tableIDs { - tblID := tableIDs[i] - tblInfo, err := getTableInfo(t, tblID, tblSchemaIDs[tblID]) - if err != nil { - job.State = model.JobStateCancelled - return 0, errors.Trace(err) - } - tblInfos = append(tblInfos, tblInfo) - } - // Before updating the schema version, we need to reset the old schema ID to new schema ID, so that - // the table info can be dropped normally in `ApplyDiff`. This is because renaming table requires two - // schema versions to complete. 
- var err error - oldRawArgs := job.RawArgs - job.Args[0] = newSchemaIDs - job.RawArgs, err = json.Marshal(job.Args) - if err != nil { - return 0, errors.Trace(err) - } - ver, err := updateSchemaVersion(d, t, job) - if err != nil { - return ver, errors.Trace(err) - } - job.RawArgs = oldRawArgs - job.FinishMultipleTableJob(model.JobStateDone, model.StatePublic, ver, tblInfos) - return ver, nil -} - -func onModifyTableComment(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - var comment string - if err := job.DecodeArgs(&comment); err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) - if err != nil { - return ver, errors.Trace(err) - } - - if job.MultiSchemaInfo != nil && job.MultiSchemaInfo.Revertible { - job.MarkNonRevertible() - return ver, nil - } - - tblInfo.Comment = comment - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - return ver, nil -} - -func onModifyTableCharsetAndCollate(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - var toCharset, toCollate string - var needsOverwriteCols bool - if err := job.DecodeArgs(&toCharset, &toCollate, &needsOverwriteCols); err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - dbInfo, err := checkSchemaExistAndCancelNotExistJob(t, job) - if err != nil { - return ver, errors.Trace(err) - } - - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) - if err != nil { - return ver, errors.Trace(err) - } - - // double check. - _, err = checkAlterTableCharset(tblInfo, dbInfo, toCharset, toCollate, needsOverwriteCols) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - if job.MultiSchemaInfo != nil && job.MultiSchemaInfo.Revertible { - job.MarkNonRevertible() - return ver, nil - } - - tblInfo.Charset = toCharset - tblInfo.Collate = toCollate - - if needsOverwriteCols { - // update column charset. - for _, col := range tblInfo.Columns { - if field_types.HasCharset(&col.FieldType) { - col.SetCharset(toCharset) - col.SetCollate(toCollate) - } else { - col.SetCharset(charset.CharsetBin) - col.SetCollate(charset.CharsetBin) - } - } - } - - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - return ver, nil -} - -func (w *worker) onSetTableFlashReplica(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - var replicaInfo ast.TiFlashReplicaSpec - if err := job.DecodeArgs(&replicaInfo); err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) - if err != nil { - return ver, errors.Trace(err) - } - - // Ban setting replica count for tables in system database. - if tidb_util.IsMemOrSysDB(job.SchemaName) { - return ver, errors.Trace(dbterror.ErrUnsupportedTiFlashOperationForSysOrMemTable) - } - - err = w.checkTiFlashReplicaCount(replicaInfo.Count) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - // We should check this first, in order to avoid creating redundant DDL jobs. 
- if pi := tblInfo.GetPartitionInfo(); pi != nil { - logutil.DDLLogger().Info("Set TiFlash replica pd rule for partitioned table", zap.Int64("tableID", tblInfo.ID)) - if e := infosync.ConfigureTiFlashPDForPartitions(false, &pi.Definitions, replicaInfo.Count, &replicaInfo.Labels, tblInfo.ID); e != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(e) - } - // Partitions that in adding mid-state. They have high priorities, so we should set accordingly pd rules. - if e := infosync.ConfigureTiFlashPDForPartitions(true, &pi.AddingDefinitions, replicaInfo.Count, &replicaInfo.Labels, tblInfo.ID); e != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(e) - } - } else { - logutil.DDLLogger().Info("Set TiFlash replica pd rule", zap.Int64("tableID", tblInfo.ID)) - if e := infosync.ConfigureTiFlashPDForTable(tblInfo.ID, replicaInfo.Count, &replicaInfo.Labels); e != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(e) - } - } - - available := false - if tblInfo.TiFlashReplica != nil { - available = tblInfo.TiFlashReplica.Available - } - if replicaInfo.Count > 0 { - tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ - Count: replicaInfo.Count, - LocationLabels: replicaInfo.Labels, - Available: available, - } - } else { - if tblInfo.TiFlashReplica != nil { - err = infosync.DeleteTiFlashTableSyncProgress(tblInfo) - if err != nil { - logutil.DDLLogger().Error("DeleteTiFlashTableSyncProgress fails", zap.Error(err)) - } - } - tblInfo.TiFlashReplica = nil - } - - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - return ver, nil -} - -func (w *worker) checkTiFlashReplicaCount(replicaCount uint64) error { - ctx, err := w.sessPool.Get() - if err != nil { - return errors.Trace(err) - } - defer w.sessPool.Put(ctx) - - return checkTiFlashReplicaCount(ctx, replicaCount) -} - -func onUpdateFlashReplicaStatus(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - var available bool - var physicalID int64 - if err := job.DecodeArgs(&available, &physicalID); err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) - if err != nil { - return ver, errors.Trace(err) - } - if tblInfo.TiFlashReplica == nil || (tblInfo.ID == physicalID && tblInfo.TiFlashReplica.Available == available) || - (tblInfo.ID != physicalID && available == tblInfo.TiFlashReplica.IsPartitionAvailable(physicalID)) { - job.State = model.JobStateCancelled - return ver, errors.Errorf("the replica available status of table %s is already updated", tblInfo.Name.String()) - } - - if tblInfo.ID == physicalID { - tblInfo.TiFlashReplica.Available = available - } else if pi := tblInfo.GetPartitionInfo(); pi != nil { - // Partition replica become available. - if available { - allAvailable := true - for _, p := range pi.Definitions { - if p.ID == physicalID { - tblInfo.TiFlashReplica.AvailablePartitionIDs = append(tblInfo.TiFlashReplica.AvailablePartitionIDs, physicalID) - } - allAvailable = allAvailable && tblInfo.TiFlashReplica.IsPartitionAvailable(p.ID) - } - tblInfo.TiFlashReplica.Available = allAvailable - } else { - // Partition replica become unavailable. 
- for i, id := range tblInfo.TiFlashReplica.AvailablePartitionIDs { - if id == physicalID { - newIDs := tblInfo.TiFlashReplica.AvailablePartitionIDs[:i] - newIDs = append(newIDs, tblInfo.TiFlashReplica.AvailablePartitionIDs[i+1:]...) - tblInfo.TiFlashReplica.AvailablePartitionIDs = newIDs - tblInfo.TiFlashReplica.Available = false - logutil.DDLLogger().Info("TiFlash replica become unavailable", zap.Int64("tableID", tblInfo.ID), zap.Int64("partitionID", id)) - break - } - } - } - } else { - job.State = model.JobStateCancelled - return ver, errors.Errorf("unknown physical ID %v in table %v", physicalID, tblInfo.Name.O) - } - - if tblInfo.TiFlashReplica.Available { - logutil.DDLLogger().Info("TiFlash replica available", zap.Int64("tableID", tblInfo.ID)) - } - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - return ver, nil -} - -// checking using cached info schema should be enough, as: -// - we will reload schema until success when become the owner -// - existing tables are correctly checked in the first place -// - we calculate job dependencies before running jobs, so there will not be 2 -// jobs creating same table running concurrently. -// -// if there are 2 owners A and B, we have 2 consecutive jobs J1 and J2 which -// are creating the same table T. those 2 jobs might be running concurrently when -// A sees J1 first and B sees J2 first. But for B sees J2 first, J1 must already -// be done and synced, and been deleted from tidb_ddl_job table, as we are querying -// jobs in the order of job id. During syncing J1, B should have synced the schema -// with the latest schema version, so when B runs J2, below check will see the table -// T already exists, and J2 will fail. -func checkTableNotExists(d *ddlCtx, schemaID int64, tableName string) error { - is := d.infoCache.GetLatest() - return checkTableNotExistsFromInfoSchema(is, schemaID, tableName) -} - -func checkConstraintNamesNotExists(t *meta.Meta, schemaID int64, constraints []*model.ConstraintInfo) error { - if len(constraints) == 0 { - return nil - } - tbInfos, err := t.ListTables(schemaID) - if err != nil { - return err - } - - for _, tb := range tbInfos { - for _, constraint := range constraints { - if constraint.State != model.StateWriteOnly { - if constraintInfo := tb.FindConstraintInfoByName(constraint.Name.L); constraintInfo != nil { - return infoschema.ErrCheckConstraintDupName.GenWithStackByArgs(constraint.Name.L) - } - } - } - } - - return nil -} - -func checkTableIDNotExists(t *meta.Meta, schemaID, tableID int64) error { - tbl, err := t.GetTable(schemaID, tableID) - if err != nil { - if meta.ErrDBNotExists.Equal(err) { - return infoschema.ErrDatabaseNotExists.GenWithStackByArgs("") - } - return errors.Trace(err) - } - if tbl != nil { - return infoschema.ErrTableExists.GenWithStackByArgs(tbl.Name) - } - return nil -} - -func checkTableNotExistsFromInfoSchema(is infoschema.InfoSchema, schemaID int64, tableName string) error { - // Check this table's database. 
- schema, ok := is.SchemaByID(schemaID) - if !ok { - return infoschema.ErrDatabaseNotExists.GenWithStackByArgs("") - } - if is.TableExists(schema.Name, model.NewCIStr(tableName)) { - return infoschema.ErrTableExists.GenWithStackByArgs(tableName) - } - return nil -} - -// updateVersionAndTableInfoWithCheck checks table info validate and updates the schema version and the table information -func updateVersionAndTableInfoWithCheck(d *ddlCtx, t *meta.Meta, job *model.Job, tblInfo *model.TableInfo, shouldUpdateVer bool, multiInfos ...schemaIDAndTableInfo) ( - ver int64, err error) { - err = checkTableInfoValid(tblInfo) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - for _, info := range multiInfos { - err = checkTableInfoValid(info.tblInfo) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - } - return updateVersionAndTableInfo(d, t, job, tblInfo, shouldUpdateVer, multiInfos...) -} - -// updateVersionAndTableInfo updates the schema version and the table information. -func updateVersionAndTableInfo(d *ddlCtx, t *meta.Meta, job *model.Job, tblInfo *model.TableInfo, shouldUpdateVer bool, multiInfos ...schemaIDAndTableInfo) ( - ver int64, err error) { - failpoint.Inject("mockUpdateVersionAndTableInfoErr", func(val failpoint.Value) { - switch val.(int) { - case 1: - failpoint.Return(ver, errors.New("mock update version and tableInfo error")) - case 2: - // We change it cancelled directly here, because we want to get the original error with the job id appended. - // The job ID will be used to get the job from history queue and we will assert it's args. - job.State = model.JobStateCancelled - failpoint.Return(ver, errors.New("mock update version and tableInfo error, jobID="+strconv.Itoa(int(job.ID)))) - default: - } - }) - if shouldUpdateVer && (job.MultiSchemaInfo == nil || !job.MultiSchemaInfo.SkipVersion) { - ver, err = updateSchemaVersion(d, t, job, multiInfos...) - if err != nil { - return 0, errors.Trace(err) - } - } - - err = updateTable(t, job.SchemaID, tblInfo) - if err != nil { - return 0, errors.Trace(err) - } - for _, info := range multiInfos { - err = updateTable(t, info.schemaID, info.tblInfo) - if err != nil { - return 0, errors.Trace(err) - } - } - return ver, nil -} - -func updateTable(t *meta.Meta, schemaID int64, tblInfo *model.TableInfo) error { - if tblInfo.State == model.StatePublic { - tblInfo.UpdateTS = t.StartTS - } - return t.UpdateTable(schemaID, tblInfo) -} - -type schemaIDAndTableInfo struct { - schemaID int64 - tblInfo *model.TableInfo -} - -func onRepairTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, _ error) { - schemaID := job.SchemaID - tblInfo := &model.TableInfo{} - - if err := job.DecodeArgs(tblInfo); err != nil { - // Invalid arguments, cancel this job. - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - tblInfo.State = model.StateNone - - // Check the old DB and old table exist. - _, err := GetTableInfoAndCancelFaultJob(t, job, schemaID) - if err != nil { - return ver, errors.Trace(err) - } - - // When in repair mode, the repaired table in a server is not access to user, - // the table after repairing will be removed from repair list. Other server left - // behind alive may need to restart to get the latest schema version. 
- ver, err = updateSchemaVersion(d, t, job) - if err != nil { - return ver, errors.Trace(err) - } - switch tblInfo.State { - case model.StateNone: - // none -> public - tblInfo.State = model.StatePublic - tblInfo.UpdateTS = t.StartTS - err = repairTableOrViewWithCheck(t, job, schemaID, tblInfo) - if err != nil { - return ver, errors.Trace(err) - } - // Finish this job. - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - return ver, nil - default: - return ver, dbterror.ErrInvalidDDLState.GenWithStackByArgs("table", tblInfo.State) - } -} - -func onAlterTableAttributes(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { - rule := label.NewRule() - err = job.DecodeArgs(rule) - if err != nil { - job.State = model.JobStateCancelled - return 0, errors.Trace(err) - } - - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) - if err != nil { - return 0, err - } - - if len(rule.Labels) == 0 { - patch := label.NewRulePatch([]*label.Rule{}, []string{rule.ID}) - err = infosync.UpdateLabelRules(context.TODO(), patch) - } else { - err = infosync.PutLabelRule(context.TODO(), rule) - } - if err != nil { - job.State = model.JobStateCancelled - return 0, errors.Wrapf(err, "failed to notify PD the label rules") - } - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - - return ver, nil -} - -func onAlterTablePartitionAttributes(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { - var partitionID int64 - rule := label.NewRule() - err = job.DecodeArgs(&partitionID, rule) - if err != nil { - job.State = model.JobStateCancelled - return 0, errors.Trace(err) - } - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) - if err != nil { - return 0, err - } - - ptInfo := tblInfo.GetPartitionInfo() - if ptInfo.GetNameByID(partitionID) == "" { - job.State = model.JobStateCancelled - return 0, errors.Trace(table.ErrUnknownPartition.GenWithStackByArgs("drop?", tblInfo.Name.O)) - } - - if len(rule.Labels) == 0 { - patch := label.NewRulePatch([]*label.Rule{}, []string{rule.ID}) - err = infosync.UpdateLabelRules(context.TODO(), patch) - } else { - err = infosync.PutLabelRule(context.TODO(), rule) - } - if err != nil { - job.State = model.JobStateCancelled - return 0, errors.Wrapf(err, "failed to notify PD the label rules") - } - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - - return ver, nil -} - -func onAlterTablePartitionPlacement(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { - var partitionID int64 - policyRefInfo := &model.PolicyRefInfo{} - err = job.DecodeArgs(&partitionID, &policyRefInfo) - if err != nil { - job.State = model.JobStateCancelled - return 0, errors.Trace(err) - } - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) - if err != nil { - return 0, err - } - - ptInfo := tblInfo.GetPartitionInfo() - var partitionDef *model.PartitionDefinition - definitions := ptInfo.Definitions - oldPartitionEnablesPlacement := false - for i := range definitions { - if partitionID == definitions[i].ID { - def := &definitions[i] - oldPartitionEnablesPlacement = def.PlacementPolicyRef != nil - def.PlacementPolicyRef = policyRefInfo - partitionDef = &definitions[i] - break - } - } - - if partitionDef == nil { - job.State = 
model.JobStateCancelled - return 0, errors.Trace(table.ErrUnknownPartition.GenWithStackByArgs("drop?", tblInfo.Name.O)) - } - - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - - if _, err = checkPlacementPolicyRefValidAndCanNonValidJob(t, job, partitionDef.PlacementPolicyRef); err != nil { - return ver, errors.Trace(err) - } - - bundle, err := placement.NewPartitionBundle(t, *partitionDef) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - if bundle == nil && oldPartitionEnablesPlacement { - bundle = placement.NewBundle(partitionDef.ID) - } - - // Send the placement bundle to PD. - if bundle != nil { - err = infosync.PutRuleBundlesWithDefaultRetry(context.TODO(), []*placement.Bundle{bundle}) - } - - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Wrapf(err, "failed to notify PD the placement rules") - } - - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - return ver, nil -} - -func onAlterTablePlacement(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { - policyRefInfo := &model.PolicyRefInfo{} - err = job.DecodeArgs(&policyRefInfo) - if err != nil { - job.State = model.JobStateCancelled - return 0, errors.Trace(err) - } - - tblInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) - if err != nil { - return 0, err - } - - if _, err = checkPlacementPolicyRefValidAndCanNonValidJob(t, job, policyRefInfo); err != nil { - return 0, errors.Trace(err) - } - - oldTableEnablesPlacement := tblInfo.PlacementPolicyRef != nil - tblInfo.PlacementPolicyRef = policyRefInfo - ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true) - if err != nil { - return ver, errors.Trace(err) - } - - bundle, err := placement.NewTableBundle(t, tblInfo) - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - if bundle == nil && oldTableEnablesPlacement { - bundle = placement.NewBundle(tblInfo.ID) - } - - // Send the placement bundle to PD. - if bundle != nil { - err = infosync.PutRuleBundlesWithDefaultRetry(context.TODO(), []*placement.Bundle{bundle}) - } - - if err != nil { - job.State = model.JobStateCancelled - return ver, errors.Trace(err) - } - - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tblInfo) - - return ver, nil -} - -func getOldLabelRules(tblInfo *model.TableInfo, oldSchemaName, oldTableName string) (string, []string, []string, map[string]*label.Rule, error) { - tableRuleID := fmt.Sprintf(label.TableIDFormat, label.IDPrefix, oldSchemaName, oldTableName) - oldRuleIDs := []string{tableRuleID} - var partRuleIDs []string - if tblInfo.GetPartitionInfo() != nil { - for _, def := range tblInfo.GetPartitionInfo().Definitions { - partRuleIDs = append(partRuleIDs, fmt.Sprintf(label.PartitionIDFormat, label.IDPrefix, oldSchemaName, oldTableName, def.Name.L)) - } - } - - oldRuleIDs = append(oldRuleIDs, partRuleIDs...) 
- oldRules, err := infosync.GetLabelRules(context.TODO(), oldRuleIDs) - return tableRuleID, partRuleIDs, oldRuleIDs, oldRules, err -} - -func updateLabelRules(job *model.Job, tblInfo *model.TableInfo, oldRules map[string]*label.Rule, tableRuleID string, partRuleIDs, oldRuleIDs []string, tID int64) error { - if oldRules == nil { - return nil - } - var newRules []*label.Rule - if tblInfo.GetPartitionInfo() != nil { - for idx, def := range tblInfo.GetPartitionInfo().Definitions { - if r, ok := oldRules[partRuleIDs[idx]]; ok { - newRules = append(newRules, r.Clone().Reset(job.SchemaName, tblInfo.Name.L, def.Name.L, def.ID)) - } - } - } - ids := []int64{tID} - if r, ok := oldRules[tableRuleID]; ok { - if tblInfo.GetPartitionInfo() != nil { - for _, def := range tblInfo.GetPartitionInfo().Definitions { - ids = append(ids, def.ID) - } - } - newRules = append(newRules, r.Clone().Reset(job.SchemaName, tblInfo.Name.L, "", ids...)) - } - - patch := label.NewRulePatch(newRules, oldRuleIDs) - return infosync.UpdateLabelRules(context.TODO(), patch) -} - -func onAlterCacheTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { - tbInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) - if err != nil { - return 0, errors.Trace(err) - } - // If the table is already in the cache state - if tbInfo.TableCacheStatusType == model.TableCacheStatusEnable { - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tbInfo) - return ver, nil - } - - if tbInfo.TempTableType != model.TempTableNone { - return ver, errors.Trace(dbterror.ErrOptOnTemporaryTable.GenWithStackByArgs("alter temporary table cache")) - } - - if tbInfo.Partition != nil { - return ver, errors.Trace(dbterror.ErrOptOnCacheTable.GenWithStackByArgs("partition mode")) - } - - switch tbInfo.TableCacheStatusType { - case model.TableCacheStatusDisable: - // disable -> switching - tbInfo.TableCacheStatusType = model.TableCacheStatusSwitching - ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tbInfo, true) - if err != nil { - return ver, err - } - case model.TableCacheStatusSwitching: - // switching -> enable - tbInfo.TableCacheStatusType = model.TableCacheStatusEnable - ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tbInfo, true) - if err != nil { - return ver, err - } - // Finish this job. - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tbInfo) - default: - job.State = model.JobStateCancelled - err = dbterror.ErrInvalidDDLState.GenWithStackByArgs("alter table cache", tbInfo.TableCacheStatusType.String()) - } - return ver, err -} - -func onAlterNoCacheTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) { - tbInfo, err := GetTableInfoAndCancelFaultJob(t, job, job.SchemaID) - if err != nil { - return 0, errors.Trace(err) - } - // If the table is not in the cache state - if tbInfo.TableCacheStatusType == model.TableCacheStatusDisable { - job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tbInfo) - return ver, nil - } - - switch tbInfo.TableCacheStatusType { - case model.TableCacheStatusEnable: - // enable -> switching - tbInfo.TableCacheStatusType = model.TableCacheStatusSwitching - ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tbInfo, true) - if err != nil { - return ver, err - } - case model.TableCacheStatusSwitching: - // switching -> disable - tbInfo.TableCacheStatusType = model.TableCacheStatusDisable - ver, err = updateVersionAndTableInfoWithCheck(d, t, job, tbInfo, true) - if err != nil { - return ver, err - } - // Finish this job. 
- job.FinishTableJob(model.JobStateDone, model.StatePublic, ver, tbInfo) - default: - job.State = model.JobStateCancelled - err = dbterror.ErrInvalidDDLState.GenWithStackByArgs("alter table no cache", tbInfo.TableCacheStatusType.String()) - } - return ver, err -} diff --git a/pkg/ddl/util/binding__failpoint_binding__.go b/pkg/ddl/util/binding__failpoint_binding__.go deleted file mode 100644 index c7fcdb8c0fcf4..0000000000000 --- a/pkg/ddl/util/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package util - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/ddl/util/util.go b/pkg/ddl/util/util.go index 7342d8b020d51..6f088e0646854 100644 --- a/pkg/ddl/util/util.go +++ b/pkg/ddl/util/util.go @@ -381,10 +381,10 @@ const ( func IsRaftKv2(ctx context.Context, sctx sessionctx.Context) (bool, error) { // Mock store does not support `show config` now, so we use failpoint here // to control whether we are in raft-kv2 - if v, _err_ := failpoint.Eval(_curpkg_("IsRaftKv2")); _err_ == nil { + failpoint.Inject("IsRaftKv2", func(v failpoint.Value) (bool, error) { v2, _ := v.(bool) return v2, nil - } + }) rs, err := sctx.GetSQLExecutor().ExecuteInternal(ctx, getRaftKvVersionSQL) if err != nil { diff --git a/pkg/ddl/util/util.go__failpoint_stash__ b/pkg/ddl/util/util.go__failpoint_stash__ deleted file mode 100644 index 6f088e0646854..0000000000000 --- a/pkg/ddl/util/util.go__failpoint_stash__ +++ /dev/null @@ -1,427 +0,0 @@ -// Copyright 2017 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package util - -import ( - "bytes" - "context" - "encoding/hex" - "fmt" - "os" - "strings" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/ddl/logutil" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/table/tables" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/mock" - "github.com/pingcap/tidb/pkg/util/sqlexec" - "github.com/tikv/client-go/v2/tikvrpc" - clientv3 "go.etcd.io/etcd/client/v3" - atomicutil "go.uber.org/atomic" - "go.uber.org/zap" -) - -const ( - deleteRangesTable = `gc_delete_range` - doneDeleteRangesTable = `gc_delete_range_done` - loadDeleteRangeSQL = `SELECT HIGH_PRIORITY job_id, element_id, start_key, end_key FROM mysql.%n WHERE ts < %?` - recordDoneDeletedRangeSQL = `INSERT IGNORE INTO mysql.gc_delete_range_done SELECT * FROM mysql.gc_delete_range WHERE job_id = %? AND element_id = %?` - completeDeleteRangeSQL = `DELETE FROM mysql.gc_delete_range WHERE job_id = %? 
AND element_id = %?` - completeDeleteMultiRangesSQL = `DELETE FROM mysql.gc_delete_range WHERE job_id = %?` - updateDeleteRangeSQL = `UPDATE mysql.gc_delete_range SET start_key = %? WHERE job_id = %? AND element_id = %? AND start_key = %?` - deleteDoneRecordSQL = `DELETE FROM mysql.gc_delete_range_done WHERE job_id = %? AND element_id = %?` - loadGlobalVars = `SELECT HIGH_PRIORITY variable_name, variable_value from mysql.global_variables where variable_name in (` // + nameList + ")" - // KeyOpDefaultTimeout is the default timeout for each key operation. - KeyOpDefaultTimeout = 2 * time.Second - // KeyOpRetryInterval is the interval between two key operations. - KeyOpRetryInterval = 30 * time.Millisecond - // DDLAllSchemaVersions is the path on etcd that is used to store all servers current schema versions. - DDLAllSchemaVersions = "/tidb/ddl/all_schema_versions" - // DDLAllSchemaVersionsByJob is the path on etcd that is used to store all servers current schema versions. - // /tidb/ddl/all_schema_by_job_versions// ---> - DDLAllSchemaVersionsByJob = "/tidb/ddl/all_schema_by_job_versions" - // DDLGlobalSchemaVersion is the path on etcd that is used to store the latest schema versions. - DDLGlobalSchemaVersion = "/tidb/ddl/global_schema_version" - // ServerGlobalState is the path on etcd that is used to store the server global state. - ServerGlobalState = "/tidb/server/global_state" - // SessionTTL is the etcd session's TTL in seconds. - SessionTTL = 90 -) - -// DelRangeTask is for run delete-range command in gc_worker. -type DelRangeTask struct { - StartKey kv.Key - EndKey kv.Key - JobID int64 - ElementID int64 -} - -// Range returns the range [start, end) to delete. -func (t DelRangeTask) Range() (kv.Key, kv.Key) { - return t.StartKey, t.EndKey -} - -// LoadDeleteRanges loads delete range tasks from gc_delete_range table. -func LoadDeleteRanges(ctx context.Context, sctx sessionctx.Context, safePoint uint64) (ranges []DelRangeTask, _ error) { - return loadDeleteRangesFromTable(ctx, sctx, deleteRangesTable, safePoint) -} - -// LoadDoneDeleteRanges loads deleted ranges from gc_delete_range_done table. -func LoadDoneDeleteRanges(ctx context.Context, sctx sessionctx.Context, safePoint uint64) (ranges []DelRangeTask, _ error) { - return loadDeleteRangesFromTable(ctx, sctx, doneDeleteRangesTable, safePoint) -} - -func loadDeleteRangesFromTable(ctx context.Context, sctx sessionctx.Context, table string, safePoint uint64) (ranges []DelRangeTask, _ error) { - rs, err := sctx.GetSQLExecutor().ExecuteInternal(ctx, loadDeleteRangeSQL, table, safePoint) - if rs != nil { - defer terror.Call(rs.Close) - } - if err != nil { - return nil, errors.Trace(err) - } - - req := rs.NewChunk(nil) - it := chunk.NewIterator4Chunk(req) - for { - err = rs.Next(context.TODO(), req) - if err != nil { - return nil, errors.Trace(err) - } - if req.NumRows() == 0 { - break - } - - for row := it.Begin(); row != it.End(); row = it.Next() { - startKey, err := hex.DecodeString(row.GetString(2)) - if err != nil { - return nil, errors.Trace(err) - } - endKey, err := hex.DecodeString(row.GetString(3)) - if err != nil { - return nil, errors.Trace(err) - } - ranges = append(ranges, DelRangeTask{ - JobID: row.GetInt64(0), - ElementID: row.GetInt64(1), - StartKey: startKey, - EndKey: endKey, - }) - } - } - return ranges, nil -} - -// CompleteDeleteRange moves a record from gc_delete_range table to gc_delete_range_done table. 
-func CompleteDeleteRange(sctx sessionctx.Context, dr DelRangeTask, needToRecordDone bool) error { - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) - - _, err := sctx.GetSQLExecutor().ExecuteInternal(ctx, "BEGIN") - if err != nil { - return errors.Trace(err) - } - - if needToRecordDone { - _, err = sctx.GetSQLExecutor().ExecuteInternal(ctx, recordDoneDeletedRangeSQL, dr.JobID, dr.ElementID) - if err != nil { - return errors.Trace(err) - } - } - - err = RemoveFromGCDeleteRange(sctx, dr.JobID, dr.ElementID) - if err != nil { - return errors.Trace(err) - } - _, err = sctx.GetSQLExecutor().ExecuteInternal(ctx, "COMMIT") - return errors.Trace(err) -} - -// RemoveFromGCDeleteRange is exported for ddl pkg to use. -func RemoveFromGCDeleteRange(sctx sessionctx.Context, jobID, elementID int64) error { - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) - _, err := sctx.GetSQLExecutor().ExecuteInternal(ctx, completeDeleteRangeSQL, jobID, elementID) - return errors.Trace(err) -} - -// RemoveMultiFromGCDeleteRange is exported for ddl pkg to use. -func RemoveMultiFromGCDeleteRange(ctx context.Context, sctx sessionctx.Context, jobID int64) error { - _, err := sctx.GetSQLExecutor().ExecuteInternal(ctx, completeDeleteMultiRangesSQL, jobID) - return errors.Trace(err) -} - -// DeleteDoneRecord removes a record from gc_delete_range_done table. -func DeleteDoneRecord(sctx sessionctx.Context, dr DelRangeTask) error { - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) - _, err := sctx.GetSQLExecutor().ExecuteInternal(ctx, deleteDoneRecordSQL, dr.JobID, dr.ElementID) - return errors.Trace(err) -} - -// UpdateDeleteRange is only for emulator. -func UpdateDeleteRange(sctx sessionctx.Context, dr DelRangeTask, newStartKey, oldStartKey kv.Key) error { - newStartKeyHex := hex.EncodeToString(newStartKey) - oldStartKeyHex := hex.EncodeToString(oldStartKey) - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) - _, err := sctx.GetSQLExecutor().ExecuteInternal(ctx, updateDeleteRangeSQL, newStartKeyHex, dr.JobID, dr.ElementID, oldStartKeyHex) - return errors.Trace(err) -} - -// LoadDDLReorgVars loads ddl reorg variable from mysql.global_variables. -func LoadDDLReorgVars(ctx context.Context, sctx sessionctx.Context) error { - // close issue #21391 - // variable.TiDBRowFormatVersion is used to encode the new row for column type change. - return LoadGlobalVars(ctx, sctx, []string{variable.TiDBDDLReorgWorkerCount, variable.TiDBDDLReorgBatchSize, variable.TiDBRowFormatVersion}) -} - -// LoadDDLVars loads ddl variable from mysql.global_variables. -func LoadDDLVars(ctx sessionctx.Context) error { - return LoadGlobalVars(context.Background(), ctx, []string{variable.TiDBDDLErrorCountLimit}) -} - -// LoadGlobalVars loads global variable from mysql.global_variables. -func LoadGlobalVars(ctx context.Context, sctx sessionctx.Context, varNames []string) error { - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnDDL) - // *mock.Context does not support SQL execution. Only do it when sctx is not `mock.Context` - if _, ok := sctx.(*mock.Context); !ok { - e := sctx.GetRestrictedSQLExecutor() - var buf strings.Builder - buf.WriteString(loadGlobalVars) - paramNames := make([]any, 0, len(varNames)) - for i, name := range varNames { - if i > 0 { - buf.WriteString(", ") - } - buf.WriteString("%?") - paramNames = append(paramNames, name) - } - buf.WriteString(")") - rows, _, err := e.ExecRestrictedSQL(ctx, nil, buf.String(), paramNames...) 
- if err != nil { - return errors.Trace(err) - } - for _, row := range rows { - varName := row.GetString(0) - varValue := row.GetString(1) - if err = sctx.GetSessionVars().SetSystemVarWithoutValidation(varName, varValue); err != nil { - return err - } - } - } - return nil -} - -// GetTimeZone gets the session location's zone name and offset. -func GetTimeZone(sctx sessionctx.Context) (string, int) { - loc := sctx.GetSessionVars().Location() - name := loc.String() - if name != "" { - _, err := time.LoadLocation(name) - if err == nil { - return name, 0 - } - } - _, offset := time.Now().In(loc).Zone() - return "", offset -} - -// enableEmulatorGC means whether to enable emulator GC. The default is enable. -// In some unit tests, we want to stop emulator GC, then wen can set enableEmulatorGC to 0. -var emulatorGCEnable = atomicutil.NewInt32(1) - -// EmulatorGCEnable enables emulator gc. It exports for testing. -func EmulatorGCEnable() { - emulatorGCEnable.Store(1) -} - -// EmulatorGCDisable disables emulator gc. It exports for testing. -func EmulatorGCDisable() { - emulatorGCEnable.Store(0) -} - -// IsEmulatorGCEnable indicates whether emulator GC enabled. It exports for testing. -func IsEmulatorGCEnable() bool { - return emulatorGCEnable.Load() == 1 -} - -var internalResourceGroupTag = []byte{0} - -// GetInternalResourceGroupTaggerForTopSQL only use for testing. -func GetInternalResourceGroupTaggerForTopSQL() tikvrpc.ResourceGroupTagger { - tagger := func(req *tikvrpc.Request) { - req.ResourceGroupTag = internalResourceGroupTag - } - return tagger -} - -// IsInternalResourceGroupTaggerForTopSQL use for testing. -func IsInternalResourceGroupTaggerForTopSQL(tag []byte) bool { - return bytes.Equal(tag, internalResourceGroupTag) -} - -// DeleteKeyFromEtcd deletes key value from etcd. -func DeleteKeyFromEtcd(key string, etcdCli *clientv3.Client, retryCnt int, timeout time.Duration) error { - var err error - ctx := context.Background() - for i := 0; i < retryCnt; i++ { - childCtx, cancel := context.WithTimeout(ctx, timeout) - _, err = etcdCli.Delete(childCtx, key) - cancel() - if err == nil { - return nil - } - logutil.DDLLogger().Warn("etcd-cli delete key failed", zap.String("key", key), zap.Error(err), zap.Int("retryCnt", i)) - } - return errors.Trace(err) -} - -// PutKVToEtcdMono puts key value to etcd monotonously. -// etcdCli is client of etcd. -// retryCnt is retry time when an error occurs. -// opts are configures of etcd Operations. -func PutKVToEtcdMono(ctx context.Context, etcdCli *clientv3.Client, retryCnt int, key, val string, - opts ...clientv3.OpOption) error { - var err error - for i := 0; i < retryCnt; i++ { - if err = ctx.Err(); err != nil { - return errors.Trace(err) - } - - childCtx, cancel := context.WithTimeout(ctx, KeyOpDefaultTimeout) - var resp *clientv3.GetResponse - resp, err = etcdCli.Get(childCtx, key) - if err != nil { - cancel() - logutil.DDLLogger().Warn("etcd-cli put kv failed", zap.String("key", key), zap.String("value", val), zap.Error(err), zap.Int("retryCnt", i)) - time.Sleep(KeyOpRetryInterval) - continue - } - prevRevision := int64(0) - if len(resp.Kvs) > 0 { - prevRevision = resp.Kvs[0].ModRevision - } - - var txnResp *clientv3.TxnResponse - txnResp, err = etcdCli.Txn(childCtx). - If(clientv3.Compare(clientv3.ModRevision(key), "=", prevRevision)). - Then(clientv3.OpPut(key, val, opts...)). 
- Commit() - - cancel() - - if err == nil && txnResp.Succeeded { - return nil - } - - if err == nil { - err = errors.New("performing compare-and-swap during PutKVToEtcd failed") - } - - logutil.DDLLogger().Warn("etcd-cli put kv failed", zap.String("key", key), zap.String("value", val), zap.Error(err), zap.Int("retryCnt", i)) - time.Sleep(KeyOpRetryInterval) - } - return errors.Trace(err) -} - -// PutKVToEtcd puts key value to etcd. -// etcdCli is client of etcd. -// retryCnt is retry time when an error occurs. -// opts are configures of etcd Operations. -func PutKVToEtcd(ctx context.Context, etcdCli *clientv3.Client, retryCnt int, key, val string, - opts ...clientv3.OpOption) error { - var err error - for i := 0; i < retryCnt; i++ { - if err = ctx.Err(); err != nil { - return errors.Trace(err) - } - - childCtx, cancel := context.WithTimeout(ctx, KeyOpDefaultTimeout) - _, err = etcdCli.Put(childCtx, key, val, opts...) - cancel() - if err == nil { - return nil - } - logutil.DDLLogger().Warn("etcd-cli put kv failed", zap.String("key", key), zap.String("value", val), zap.Error(err), zap.Int("retryCnt", i)) - time.Sleep(KeyOpRetryInterval) - } - return errors.Trace(err) -} - -// WrapKey2String wraps the key to a string. -func WrapKey2String(key []byte) string { - if len(key) == 0 { - return "''" - } - return fmt.Sprintf("0x%x", key) -} - -const ( - getRaftKvVersionSQL = "select `value` from information_schema.cluster_config where type = 'tikv' and `key` = 'storage.engine'" - raftKv2 = "raft-kv2" -) - -// IsRaftKv2 checks whether the raft-kv2 is enabled -func IsRaftKv2(ctx context.Context, sctx sessionctx.Context) (bool, error) { - // Mock store does not support `show config` now, so we use failpoint here - // to control whether we are in raft-kv2 - failpoint.Inject("IsRaftKv2", func(v failpoint.Value) (bool, error) { - v2, _ := v.(bool) - return v2, nil - }) - - rs, err := sctx.GetSQLExecutor().ExecuteInternal(ctx, getRaftKvVersionSQL) - if err != nil { - return false, err - } - if rs == nil { - return false, nil - } - - defer terror.Call(rs.Close) - rows, err := sqlexec.DrainRecordSet(ctx, rs, sctx.GetSessionVars().MaxChunkSize) - if err != nil { - return false, errors.Trace(err) - } - if len(rows) == 0 { - return false, nil - } - - // All nodes should have the same type of engine - raftVersion := rows[0].GetString(0) - return raftVersion == raftKv2, nil -} - -// FolderNotEmpty returns true only when the folder is not empty. -func FolderNotEmpty(path string) bool { - entries, _ := os.ReadDir(path) - return len(entries) > 0 -} - -// GenKeyExistsErr builds a ErrKeyExists error. 
-func GenKeyExistsErr(key, value []byte, idxInfo *model.IndexInfo, tblInfo *model.TableInfo) error { - indexName := fmt.Sprintf("%s.%s", tblInfo.Name.String(), idxInfo.Name.String()) - valueStr, err := tables.GenIndexValueFromIndex(key, value, tblInfo, idxInfo) - if err != nil { - logutil.DDLLogger().Warn("decode index key value / column value failed", zap.String("index", indexName), - zap.String("key", hex.EncodeToString(key)), zap.String("value", hex.EncodeToString(value)), zap.Error(err)) - return errors.Trace(kv.ErrKeyExists.FastGenByArgs(key, indexName)) - } - return kv.GenKeyExistsErr(valueStr, indexName) -} diff --git a/pkg/distsql/binding__failpoint_binding__.go b/pkg/distsql/binding__failpoint_binding__.go deleted file mode 100644 index b2e21b1442daa..0000000000000 --- a/pkg/distsql/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package distsql - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/distsql/request_builder.go b/pkg/distsql/request_builder.go index 9c9f05296038c..9f0b8b24a2a18 100644 --- a/pkg/distsql/request_builder.go +++ b/pkg/distsql/request_builder.go @@ -65,12 +65,12 @@ func (builder *RequestBuilder) Build() (*kv.Request, error) { }, } } - if val, _err_ := failpoint.Eval(_curpkg_("assertRequestBuilderReplicaOption")); _err_ == nil { + failpoint.Inject("assertRequestBuilderReplicaOption", func(val failpoint.Value) { assertScope := val.(string) if builder.ReplicaRead.IsClosestRead() && assertScope != builder.ReadReplicaScope { panic("request builder get staleness option fail") } - } + }) err := builder.verifyTxnScope() if err != nil { builder.err = err @@ -90,12 +90,12 @@ func (builder *RequestBuilder) Build() (*kv.Request, error) { switch dag.Executors[0].Tp { case tipb.ExecType_TypeTableScan, tipb.ExecType_TypeIndexScan, tipb.ExecType_TypePartitionTableScan: builder.Request.Concurrency = 2 - if val, _err_ := failpoint.Eval(_curpkg_("testRateLimitActionMockConsumeAndAssert")); _err_ == nil { + failpoint.Inject("testRateLimitActionMockConsumeAndAssert", func(val failpoint.Value) { if val.(bool) { // When the concurrency is too small, test case tests/realtikvtest/sessiontest.TestCoprocessorOOMAction can't trigger OOM condition builder.Request.Concurrency = oldConcurrency } - } + }) } } } diff --git a/pkg/distsql/request_builder.go__failpoint_stash__ b/pkg/distsql/request_builder.go__failpoint_stash__ deleted file mode 100644 index 9f0b8b24a2a18..0000000000000 --- a/pkg/distsql/request_builder.go__failpoint_stash__ +++ /dev/null @@ -1,862 +0,0 @@ -// Copyright 2018 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package distsql - -import ( - "fmt" - "math" - "sort" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/tidb/pkg/ddl/placement" - distsqlctx "github.com/pingcap/tidb/pkg/distsql/context" - "github.com/pingcap/tidb/pkg/errctx" - infoschema "github.com/pingcap/tidb/pkg/infoschema/context" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/collate" - "github.com/pingcap/tidb/pkg/util/memory" - "github.com/pingcap/tidb/pkg/util/ranger" - "github.com/pingcap/tipb/go-tipb" - "github.com/tikv/client-go/v2/tikvrpc" -) - -// RequestBuilder is used to build a "kv.Request". -// It is called before we issue a kv request by "Select". -type RequestBuilder struct { - kv.Request - is infoschema.MetaOnlyInfoSchema - err error - - // When SetDAGRequest is called, builder will also this field. - dag *tipb.DAGRequest -} - -// Build builds a "kv.Request". -func (builder *RequestBuilder) Build() (*kv.Request, error) { - if builder.ReadReplicaScope == "" { - builder.ReadReplicaScope = kv.GlobalReplicaScope - } - if builder.ReplicaRead.IsClosestRead() && builder.ReadReplicaScope != kv.GlobalReplicaScope { - builder.MatchStoreLabels = []*metapb.StoreLabel{ - { - Key: placement.DCLabelKey, - Value: builder.ReadReplicaScope, - }, - } - } - failpoint.Inject("assertRequestBuilderReplicaOption", func(val failpoint.Value) { - assertScope := val.(string) - if builder.ReplicaRead.IsClosestRead() && assertScope != builder.ReadReplicaScope { - panic("request builder get staleness option fail") - } - }) - err := builder.verifyTxnScope() - if err != nil { - builder.err = err - } - if builder.Request.KeyRanges == nil { - builder.Request.KeyRanges = kv.NewNonPartitionedKeyRanges(nil) - } - - if dag := builder.dag; dag != nil { - if execCnt := len(dag.Executors); execCnt == 1 { - oldConcurrency := builder.Request.Concurrency - // select * from t order by id - if builder.Request.KeepOrder { - // When the DAG is just simple scan and keep order, set concurrency to 2. - // If a lot data are returned to client, mysql protocol is the bottleneck so concurrency 2 is enough. - // If very few data are returned to client, the speed is not optimal but good enough. - switch dag.Executors[0].Tp { - case tipb.ExecType_TypeTableScan, tipb.ExecType_TypeIndexScan, tipb.ExecType_TypePartitionTableScan: - builder.Request.Concurrency = 2 - failpoint.Inject("testRateLimitActionMockConsumeAndAssert", func(val failpoint.Value) { - if val.(bool) { - // When the concurrency is too small, test case tests/realtikvtest/sessiontest.TestCoprocessorOOMAction can't trigger OOM condition - builder.Request.Concurrency = oldConcurrency - } - }) - } - } - } - } - - return &builder.Request, builder.err -} - -// SetMemTracker sets a memTracker for this request. -func (builder *RequestBuilder) SetMemTracker(tracker *memory.Tracker) *RequestBuilder { - builder.Request.MemTracker = tracker - return builder -} - -// SetTableRanges sets "KeyRanges" for "kv.Request" by converting "tableRanges" -// to "KeyRanges" firstly. -// Note this function should be deleted or at least not exported, but currently -// br refers it, so have to keep it. 
-func (builder *RequestBuilder) SetTableRanges(tid int64, tableRanges []*ranger.Range) *RequestBuilder { - if builder.err == nil { - builder.Request.KeyRanges = kv.NewNonPartitionedKeyRanges(TableRangesToKVRanges(tid, tableRanges)) - } - return builder -} - -// SetIndexRanges sets "KeyRanges" for "kv.Request" by converting index range -// "ranges" to "KeyRanges" firstly. -func (builder *RequestBuilder) SetIndexRanges(dctx *distsqlctx.DistSQLContext, tid, idxID int64, ranges []*ranger.Range) *RequestBuilder { - if builder.err == nil { - builder.Request.KeyRanges, builder.err = IndexRangesToKVRanges(dctx, tid, idxID, ranges) - } - return builder -} - -// SetIndexRangesForTables sets "KeyRanges" for "kv.Request" by converting multiple indexes range -// "ranges" to "KeyRanges" firstly. -func (builder *RequestBuilder) SetIndexRangesForTables(dctx *distsqlctx.DistSQLContext, tids []int64, idxID int64, ranges []*ranger.Range) *RequestBuilder { - if builder.err == nil { - builder.Request.KeyRanges, builder.err = IndexRangesToKVRangesForTables(dctx, tids, idxID, ranges) - } - return builder -} - -// SetHandleRanges sets "KeyRanges" for "kv.Request" by converting table handle range -// "ranges" to "KeyRanges" firstly. -func (builder *RequestBuilder) SetHandleRanges(dctx *distsqlctx.DistSQLContext, tid int64, isCommonHandle bool, ranges []*ranger.Range) *RequestBuilder { - builder = builder.SetHandleRangesForTables(dctx, []int64{tid}, isCommonHandle, ranges) - builder.err = builder.Request.KeyRanges.SetToNonPartitioned() - return builder -} - -// SetHandleRangesForTables sets "KeyRanges" for "kv.Request" by converting table handle range -// "ranges" to "KeyRanges" firstly for multiple tables. -func (builder *RequestBuilder) SetHandleRangesForTables(dctx *distsqlctx.DistSQLContext, tid []int64, isCommonHandle bool, ranges []*ranger.Range) *RequestBuilder { - if builder.err == nil { - builder.Request.KeyRanges, builder.err = TableHandleRangesToKVRanges(dctx, tid, isCommonHandle, ranges) - } - return builder -} - -// SetTableHandles sets "KeyRanges" for "kv.Request" by converting table handles -// "handles" to "KeyRanges" firstly. -func (builder *RequestBuilder) SetTableHandles(tid int64, handles []kv.Handle) *RequestBuilder { - keyRanges, hints := TableHandlesToKVRanges(tid, handles) - builder.Request.KeyRanges = kv.NewNonParitionedKeyRangesWithHint(keyRanges, hints) - return builder -} - -// SetPartitionsAndHandles sets "KeyRanges" for "kv.Request" by converting ParitionHandles to KeyRanges. -// handles in slice must be kv.PartitionHandle. -func (builder *RequestBuilder) SetPartitionsAndHandles(handles []kv.Handle) *RequestBuilder { - keyRanges, hints := PartitionHandlesToKVRanges(handles) - builder.Request.KeyRanges = kv.NewNonParitionedKeyRangesWithHint(keyRanges, hints) - return builder -} - -const estimatedRegionRowCount = 100000 - -// SetDAGRequest sets the request type to "ReqTypeDAG" and construct request data. -func (builder *RequestBuilder) SetDAGRequest(dag *tipb.DAGRequest) *RequestBuilder { - if builder.err == nil { - builder.Request.Tp = kv.ReqTypeDAG - builder.Request.Cacheable = true - builder.Request.Data, builder.err = dag.Marshal() - builder.dag = dag - if execCnt := len(dag.Executors); execCnt != 0 && dag.Executors[execCnt-1].GetLimit() != nil { - limit := dag.Executors[execCnt-1].GetLimit() - builder.Request.LimitSize = limit.GetLimit() - // When the DAG is just simple scan and small limit, set concurrency to 1 would be sufficient. 
- if execCnt == 2 { - if limit.Limit < estimatedRegionRowCount { - if kr := builder.Request.KeyRanges; kr != nil { - builder.Request.Concurrency = kr.PartitionNum() - } else { - builder.Request.Concurrency = 1 - } - } - } - } - } - return builder -} - -// SetAnalyzeRequest sets the request type to "ReqTypeAnalyze" and construct request data. -func (builder *RequestBuilder) SetAnalyzeRequest(ana *tipb.AnalyzeReq, isoLevel kv.IsoLevel) *RequestBuilder { - if builder.err == nil { - builder.Request.Tp = kv.ReqTypeAnalyze - builder.Request.Data, builder.err = ana.Marshal() - builder.Request.NotFillCache = true - builder.Request.IsolationLevel = isoLevel - builder.Request.Priority = kv.PriorityLow - } - - return builder -} - -// SetChecksumRequest sets the request type to "ReqTypeChecksum" and construct request data. -func (builder *RequestBuilder) SetChecksumRequest(checksum *tipb.ChecksumRequest) *RequestBuilder { - if builder.err == nil { - builder.Request.Tp = kv.ReqTypeChecksum - builder.Request.Data, builder.err = checksum.Marshal() - builder.Request.NotFillCache = true - } - - return builder -} - -// SetKeyRanges sets "KeyRanges" for "kv.Request". -func (builder *RequestBuilder) SetKeyRanges(keyRanges []kv.KeyRange) *RequestBuilder { - builder.Request.KeyRanges = kv.NewNonPartitionedKeyRanges(keyRanges) - return builder -} - -// SetKeyRangesWithHints sets "KeyRanges" for "kv.Request" with row count hints. -func (builder *RequestBuilder) SetKeyRangesWithHints(keyRanges []kv.KeyRange, hints []int) *RequestBuilder { - builder.Request.KeyRanges = kv.NewNonParitionedKeyRangesWithHint(keyRanges, hints) - return builder -} - -// SetWrappedKeyRanges sets "KeyRanges" for "kv.Request". -func (builder *RequestBuilder) SetWrappedKeyRanges(keyRanges *kv.KeyRanges) *RequestBuilder { - builder.Request.KeyRanges = keyRanges - return builder -} - -// SetPartitionKeyRanges sets the "KeyRanges" for "kv.Request" on partitioned table cases. -func (builder *RequestBuilder) SetPartitionKeyRanges(keyRanges [][]kv.KeyRange) *RequestBuilder { - builder.Request.KeyRanges = kv.NewPartitionedKeyRanges(keyRanges) - return builder -} - -// SetStartTS sets "StartTS" for "kv.Request". -func (builder *RequestBuilder) SetStartTS(startTS uint64) *RequestBuilder { - builder.Request.StartTs = startTS - return builder -} - -// SetDesc sets "Desc" for "kv.Request". -func (builder *RequestBuilder) SetDesc(desc bool) *RequestBuilder { - builder.Request.Desc = desc - return builder -} - -// SetKeepOrder sets "KeepOrder" for "kv.Request". -func (builder *RequestBuilder) SetKeepOrder(order bool) *RequestBuilder { - builder.Request.KeepOrder = order - return builder -} - -// SetStoreType sets "StoreType" for "kv.Request". -func (builder *RequestBuilder) SetStoreType(storeType kv.StoreType) *RequestBuilder { - builder.Request.StoreType = storeType - return builder -} - -// SetAllowBatchCop sets `BatchCop` property. -func (builder *RequestBuilder) SetAllowBatchCop(batchCop bool) *RequestBuilder { - builder.Request.BatchCop = batchCop - return builder -} - -// SetPartitionIDAndRanges sets `PartitionIDAndRanges` property. 
-func (builder *RequestBuilder) SetPartitionIDAndRanges(partitionIDAndRanges []kv.PartitionIDAndRanges) *RequestBuilder { - builder.PartitionIDAndRanges = partitionIDAndRanges - return builder -} - -func (builder *RequestBuilder) getIsolationLevel() kv.IsoLevel { - if builder.Tp == kv.ReqTypeAnalyze { - return kv.RC - } - return kv.SI -} - -func (*RequestBuilder) getKVPriority(dctx *distsqlctx.DistSQLContext) int { - switch dctx.Priority { - case mysql.NoPriority, mysql.DelayedPriority: - return kv.PriorityNormal - case mysql.LowPriority: - return kv.PriorityLow - case mysql.HighPriority: - return kv.PriorityHigh - } - return kv.PriorityNormal -} - -// SetFromSessionVars sets the following fields for "kv.Request" from session variables: -// "Concurrency", "IsolationLevel", "NotFillCache", "TaskID", "Priority", "ReplicaRead", -// "ResourceGroupTagger", "ResourceGroupName" -func (builder *RequestBuilder) SetFromSessionVars(dctx *distsqlctx.DistSQLContext) *RequestBuilder { - distsqlConcurrency := dctx.DistSQLConcurrency - if builder.Request.Concurrency == 0 { - // Concurrency unset. - builder.Request.Concurrency = distsqlConcurrency - } else if builder.Request.Concurrency > distsqlConcurrency { - // Concurrency is set in SetDAGRequest, check the upper limit. - builder.Request.Concurrency = distsqlConcurrency - } - replicaReadType := dctx.ReplicaReadType - if dctx.WeakConsistency { - builder.Request.IsolationLevel = kv.RC - } else if dctx.RCCheckTS { - builder.Request.IsolationLevel = kv.RCCheckTS - replicaReadType = kv.ReplicaReadLeader - } else { - builder.Request.IsolationLevel = builder.getIsolationLevel() - } - builder.Request.NotFillCache = dctx.NotFillCache - builder.Request.TaskID = dctx.TaskID - builder.Request.Priority = builder.getKVPriority(dctx) - builder.Request.ReplicaRead = replicaReadType - builder.SetResourceGroupTagger(dctx.ResourceGroupTagger) - { - builder.SetPaging(dctx.EnablePaging) - builder.Request.Paging.MinPagingSize = uint64(dctx.MinPagingSize) - builder.Request.Paging.MaxPagingSize = uint64(dctx.MaxPagingSize) - } - builder.RequestSource.RequestSourceInternal = dctx.InRestrictedSQL - builder.RequestSource.RequestSourceType = dctx.RequestSourceType - builder.RequestSource.ExplicitRequestSourceType = dctx.ExplicitRequestSourceType - builder.StoreBatchSize = dctx.StoreBatchSize - builder.Request.ResourceGroupName = dctx.ResourceGroupName - builder.Request.StoreBusyThreshold = dctx.LoadBasedReplicaReadThreshold - builder.Request.RunawayChecker = dctx.RunawayChecker - builder.Request.TiKVClientReadTimeout = dctx.TiKVClientReadTimeout - return builder -} - -// SetPaging sets "Paging" flag for "kv.Request". -func (builder *RequestBuilder) SetPaging(paging bool) *RequestBuilder { - builder.Request.Paging.Enable = paging - return builder -} - -// SetConcurrency sets "Concurrency" for "kv.Request". -func (builder *RequestBuilder) SetConcurrency(concurrency int) *RequestBuilder { - builder.Request.Concurrency = concurrency - return builder -} - -// SetTiDBServerID sets "TiDBServerID" for "kv.Request" -// -// ServerID is a unique id of TiDB instance among the cluster. 
-// See https://github.com/pingcap/tidb/blob/master/docs/design/2020-06-01-global-kill.md -func (builder *RequestBuilder) SetTiDBServerID(serverID uint64) *RequestBuilder { - builder.Request.TiDBServerID = serverID - return builder -} - -// SetFromInfoSchema sets the following fields from infoSchema: -// "bundles" -func (builder *RequestBuilder) SetFromInfoSchema(is infoschema.MetaOnlyInfoSchema) *RequestBuilder { - builder.is = is - builder.Request.SchemaVar = is.SchemaMetaVersion() - return builder -} - -// SetResourceGroupTagger sets the request resource group tagger. -func (builder *RequestBuilder) SetResourceGroupTagger(tagger tikvrpc.ResourceGroupTagger) *RequestBuilder { - builder.Request.ResourceGroupTagger = tagger - return builder -} - -// SetResourceGroupName sets the request resource group name. -func (builder *RequestBuilder) SetResourceGroupName(name string) *RequestBuilder { - builder.Request.ResourceGroupName = name - return builder -} - -// SetExplicitRequestSourceType sets the explicit request source type. -func (builder *RequestBuilder) SetExplicitRequestSourceType(sourceType string) *RequestBuilder { - builder.RequestSource.ExplicitRequestSourceType = sourceType - return builder -} - -func (builder *RequestBuilder) verifyTxnScope() error { - txnScope := builder.TxnScope - if txnScope == "" || txnScope == kv.GlobalReplicaScope || builder.is == nil { - return nil - } - visitPhysicalTableID := make(map[int64]struct{}) - tids, err := tablecodec.VerifyTableIDForRanges(builder.Request.KeyRanges) - if err != nil { - return err - } - for _, tid := range tids { - visitPhysicalTableID[tid] = struct{}{} - } - - for phyTableID := range visitPhysicalTableID { - valid := VerifyTxnScope(txnScope, phyTableID, builder.is) - if !valid { - var tblName string - var partName string - tblInfo, _, partInfo := builder.is.FindTableInfoByPartitionID(phyTableID) - if tblInfo != nil && partInfo != nil { - tblName = tblInfo.Name.String() - partName = partInfo.Name.String() - } else { - tblInfo, _ = builder.is.TableInfoByID(phyTableID) - tblName = tblInfo.Name.String() - } - err := fmt.Errorf("table %v can not be read by %v txn_scope", tblName, txnScope) - if len(partName) > 0 { - err = fmt.Errorf("table %v's partition %v can not be read by %v txn_scope", - tblName, partName, txnScope) - } - return err - } - } - return nil -} - -// SetTxnScope sets request TxnScope -func (builder *RequestBuilder) SetTxnScope(scope string) *RequestBuilder { - builder.TxnScope = scope - return builder -} - -// SetReadReplicaScope sets request readReplicaScope -func (builder *RequestBuilder) SetReadReplicaScope(scope string) *RequestBuilder { - builder.ReadReplicaScope = scope - return builder -} - -// SetIsStaleness sets request IsStaleness -func (builder *RequestBuilder) SetIsStaleness(is bool) *RequestBuilder { - builder.IsStaleness = is - return builder -} - -// SetClosestReplicaReadAdjuster sets request CoprRequestAdjuster -func (builder *RequestBuilder) SetClosestReplicaReadAdjuster(chkFn kv.CoprRequestAdjuster) *RequestBuilder { - builder.ClosestReplicaReadAdjuster = chkFn - return builder -} - -// SetConnIDAndConnAlias sets connection id for the builder. -func (builder *RequestBuilder) SetConnIDAndConnAlias(connID uint64, connAlias string) *RequestBuilder { - builder.ConnID = connID - builder.ConnAlias = connAlias - return builder -} - -// TableHandleRangesToKVRanges convert table handle ranges to "KeyRanges" for multiple tables. 
-func TableHandleRangesToKVRanges(dctx *distsqlctx.DistSQLContext, tid []int64, isCommonHandle bool, ranges []*ranger.Range) (*kv.KeyRanges, error) { - if !isCommonHandle { - return tablesRangesToKVRanges(tid, ranges), nil - } - return CommonHandleRangesToKVRanges(dctx, tid, ranges) -} - -// TableRangesToKVRanges converts table ranges to "KeyRange". -// Note this function should not be exported, but currently -// br refers to it, so have to keep it. -func TableRangesToKVRanges(tid int64, ranges []*ranger.Range) []kv.KeyRange { - if len(ranges) == 0 { - return []kv.KeyRange{} - } - return tablesRangesToKVRanges([]int64{tid}, ranges).FirstPartitionRange() -} - -// tablesRangesToKVRanges converts table ranges to "KeyRange". -func tablesRangesToKVRanges(tids []int64, ranges []*ranger.Range) *kv.KeyRanges { - return tableRangesToKVRangesWithoutSplit(tids, ranges) -} - -func tableRangesToKVRangesWithoutSplit(tids []int64, ranges []*ranger.Range) *kv.KeyRanges { - krs := make([][]kv.KeyRange, len(tids)) - for i := range krs { - krs[i] = make([]kv.KeyRange, 0, len(ranges)) - } - for _, ran := range ranges { - low, high := encodeHandleKey(ran) - for i, tid := range tids { - startKey := tablecodec.EncodeRowKey(tid, low) - endKey := tablecodec.EncodeRowKey(tid, high) - krs[i] = append(krs[i], kv.KeyRange{StartKey: startKey, EndKey: endKey}) - } - } - return kv.NewPartitionedKeyRanges(krs) -} - -func encodeHandleKey(ran *ranger.Range) ([]byte, []byte) { - low := codec.EncodeInt(nil, ran.LowVal[0].GetInt64()) - high := codec.EncodeInt(nil, ran.HighVal[0].GetInt64()) - if ran.LowExclude { - low = kv.Key(low).PrefixNext() - } - if !ran.HighExclude { - high = kv.Key(high).PrefixNext() - } - return low, high -} - -// SplitRangesAcrossInt64Boundary split the ranges into two groups: -// 1. signedRanges is less or equal than MaxInt64 -// 2. unsignedRanges is greater than MaxInt64 -// -// We do this because every key of tikv is encoded as an int64. As a result, MaxUInt64 is smaller than zero when -// interpreted as an int64 variable. -// -// This function does the following: -// 1. split ranges into two groups as described above. -// 2. if there's a range that straddles the int64 boundary, split it into two ranges, which results in one smaller and -// one greater than MaxInt64. -// -// if `KeepOrder` is false, we merge the two groups of ranges into one group, to save a rpc call later -// if `desc` is false, return signed ranges first, vice versa. -func SplitRangesAcrossInt64Boundary(ranges []*ranger.Range, keepOrder bool, desc bool, isCommonHandle bool) ([]*ranger.Range, []*ranger.Range) { - if isCommonHandle || len(ranges) == 0 || ranges[0].LowVal[0].Kind() == types.KindInt64 { - return ranges, nil - } - idx := sort.Search(len(ranges), func(i int) bool { return ranges[i].HighVal[0].GetUint64() > math.MaxInt64 }) - if idx == len(ranges) { - return ranges, nil - } - if ranges[idx].LowVal[0].GetUint64() > math.MaxInt64 { - signedRanges := ranges[0:idx] - unsignedRanges := ranges[idx:] - if !keepOrder { - return append(unsignedRanges, signedRanges...), nil - } - if desc { - return unsignedRanges, signedRanges - } - return signedRanges, unsignedRanges - } - // need to split the range that straddles the int64 boundary - signedRanges := make([]*ranger.Range, 0, idx+1) - unsignedRanges := make([]*ranger.Range, 0, len(ranges)-idx) - signedRanges = append(signedRanges, ranges[0:idx]...) 
- if !(ranges[idx].LowVal[0].GetUint64() == math.MaxInt64 && ranges[idx].LowExclude) { - signedRanges = append(signedRanges, &ranger.Range{ - LowVal: ranges[idx].LowVal, - LowExclude: ranges[idx].LowExclude, - HighVal: []types.Datum{types.NewUintDatum(math.MaxInt64)}, - Collators: ranges[idx].Collators, - }) - } - if !(ranges[idx].HighVal[0].GetUint64() == math.MaxInt64+1 && ranges[idx].HighExclude) { - unsignedRanges = append(unsignedRanges, &ranger.Range{ - LowVal: []types.Datum{types.NewUintDatum(math.MaxInt64 + 1)}, - HighVal: ranges[idx].HighVal, - HighExclude: ranges[idx].HighExclude, - Collators: ranges[idx].Collators, - }) - } - if idx < len(ranges) { - unsignedRanges = append(unsignedRanges, ranges[idx+1:]...) - } - if !keepOrder { - return append(unsignedRanges, signedRanges...), nil - } - if desc { - return unsignedRanges, signedRanges - } - return signedRanges, unsignedRanges -} - -// TableHandlesToKVRanges converts sorted handle to kv ranges. -// For continuous handles, we should merge them to a single key range. -func TableHandlesToKVRanges(tid int64, handles []kv.Handle) ([]kv.KeyRange, []int) { - krs := make([]kv.KeyRange, 0, len(handles)) - hints := make([]int, 0, len(handles)) - i := 0 - for i < len(handles) { - var isCommonHandle bool - var commonHandle *kv.CommonHandle - if partitionHandle, ok := handles[i].(kv.PartitionHandle); ok { - tid = partitionHandle.PartitionID - commonHandle, isCommonHandle = partitionHandle.Handle.(*kv.CommonHandle) - } else { - commonHandle, isCommonHandle = handles[i].(*kv.CommonHandle) - } - if isCommonHandle { - ran := kv.KeyRange{ - StartKey: tablecodec.EncodeRowKey(tid, commonHandle.Encoded()), - EndKey: tablecodec.EncodeRowKey(tid, kv.Key(commonHandle.Encoded()).Next()), - } - krs = append(krs, ran) - hints = append(hints, 1) - i++ - continue - } - j := i + 1 - for ; j < len(handles) && handles[j-1].IntValue() != math.MaxInt64; j++ { - if p, ok := handles[j].(kv.PartitionHandle); ok && p.PartitionID != tid { - break - } - if handles[j].IntValue() != handles[j-1].IntValue()+1 { - break - } - } - low := codec.EncodeInt(nil, handles[i].IntValue()) - high := codec.EncodeInt(nil, handles[j-1].IntValue()) - high = kv.Key(high).PrefixNext() - startKey := tablecodec.EncodeRowKey(tid, low) - endKey := tablecodec.EncodeRowKey(tid, high) - krs = append(krs, kv.KeyRange{StartKey: startKey, EndKey: endKey}) - hints = append(hints, j-i) - i = j - } - return krs, hints -} - -// PartitionHandlesToKVRanges convert ParitionHandles to kv ranges. 
-// Handle in slices must be kv.PartitionHandle -func PartitionHandlesToKVRanges(handles []kv.Handle) ([]kv.KeyRange, []int) { - krs := make([]kv.KeyRange, 0, len(handles)) - hints := make([]int, 0, len(handles)) - i := 0 - for i < len(handles) { - ph := handles[i].(kv.PartitionHandle) - h := ph.Handle - pid := ph.PartitionID - if commonHandle, ok := h.(*kv.CommonHandle); ok { - ran := kv.KeyRange{ - StartKey: tablecodec.EncodeRowKey(pid, commonHandle.Encoded()), - EndKey: tablecodec.EncodeRowKey(pid, append(commonHandle.Encoded(), 0)), - } - krs = append(krs, ran) - hints = append(hints, 1) - i++ - continue - } - j := i + 1 - for ; j < len(handles) && handles[j-1].IntValue() != math.MaxInt64; j++ { - if handles[j].IntValue() != handles[j-1].IntValue()+1 { - break - } - if handles[j].(kv.PartitionHandle).PartitionID != pid { - break - } - } - low := codec.EncodeInt(nil, handles[i].IntValue()) - high := codec.EncodeInt(nil, handles[j-1].IntValue()) - high = kv.Key(high).PrefixNext() - startKey := tablecodec.EncodeRowKey(pid, low) - endKey := tablecodec.EncodeRowKey(pid, high) - krs = append(krs, kv.KeyRange{StartKey: startKey, EndKey: endKey}) - hints = append(hints, j-i) - i = j - } - return krs, hints -} - -// IndexRangesToKVRanges converts index ranges to "KeyRange". -func IndexRangesToKVRanges(dctx *distsqlctx.DistSQLContext, tid, idxID int64, ranges []*ranger.Range) (*kv.KeyRanges, error) { - return IndexRangesToKVRangesWithInterruptSignal(dctx, tid, idxID, ranges, nil, nil) -} - -// IndexRangesToKVRangesWithInterruptSignal converts index ranges to "KeyRange". -// The process can be interrupted by set `interruptSignal` to true. -func IndexRangesToKVRangesWithInterruptSignal(dctx *distsqlctx.DistSQLContext, tid, idxID int64, ranges []*ranger.Range, memTracker *memory.Tracker, interruptSignal *atomic.Value) (*kv.KeyRanges, error) { - keyRanges, err := indexRangesToKVRangesForTablesWithInterruptSignal(dctx, []int64{tid}, idxID, ranges, memTracker, interruptSignal) - if err != nil { - return nil, err - } - err = keyRanges.SetToNonPartitioned() - return keyRanges, err -} - -// IndexRangesToKVRangesForTables converts indexes ranges to "KeyRange". -func IndexRangesToKVRangesForTables(dctx *distsqlctx.DistSQLContext, tids []int64, idxID int64, ranges []*ranger.Range) (*kv.KeyRanges, error) { - return indexRangesToKVRangesForTablesWithInterruptSignal(dctx, tids, idxID, ranges, nil, nil) -} - -// IndexRangesToKVRangesForTablesWithInterruptSignal converts indexes ranges to "KeyRange". -// The process can be interrupted by set `interruptSignal` to true. -func indexRangesToKVRangesForTablesWithInterruptSignal(dctx *distsqlctx.DistSQLContext, tids []int64, idxID int64, ranges []*ranger.Range, memTracker *memory.Tracker, interruptSignal *atomic.Value) (*kv.KeyRanges, error) { - return indexRangesToKVWithoutSplit(dctx, tids, idxID, ranges, memTracker, interruptSignal) -} - -// CommonHandleRangesToKVRanges converts common handle ranges to "KeyRange". 
-func CommonHandleRangesToKVRanges(dctx *distsqlctx.DistSQLContext, tids []int64, ranges []*ranger.Range) (*kv.KeyRanges, error) { - rans := make([]*ranger.Range, 0, len(ranges)) - for _, ran := range ranges { - low, high, err := EncodeIndexKey(dctx, ran) - if err != nil { - return nil, err - } - rans = append(rans, &ranger.Range{ - LowVal: []types.Datum{types.NewBytesDatum(low)}, - HighVal: []types.Datum{types.NewBytesDatum(high)}, - LowExclude: false, - HighExclude: true, - Collators: collate.GetBinaryCollatorSlice(1), - }) - } - krs := make([][]kv.KeyRange, len(tids)) - for i := range krs { - krs[i] = make([]kv.KeyRange, 0, len(ranges)) - } - for _, ran := range rans { - low, high := ran.LowVal[0].GetBytes(), ran.HighVal[0].GetBytes() - if ran.LowExclude { - low = kv.Key(low).PrefixNext() - } - ran.LowVal[0].SetBytes(low) - for i, tid := range tids { - startKey := tablecodec.EncodeRowKey(tid, low) - endKey := tablecodec.EncodeRowKey(tid, high) - krs[i] = append(krs[i], kv.KeyRange{StartKey: startKey, EndKey: endKey}) - } - } - return kv.NewPartitionedKeyRanges(krs), nil -} - -// VerifyTxnScope verify whether the txnScope and visited physical table break the leader rule's dcLocation. -func VerifyTxnScope(txnScope string, physicalTableID int64, is infoschema.MetaOnlyInfoSchema) bool { - if txnScope == "" || txnScope == kv.GlobalTxnScope { - return true - } - bundle, ok := is.PlacementBundleByPhysicalTableID(physicalTableID) - if !ok { - return true - } - leaderDC, ok := bundle.GetLeaderDC(placement.DCLabelKey) - if !ok { - return true - } - if leaderDC != txnScope { - return false - } - return true -} - -func indexRangesToKVWithoutSplit(dctx *distsqlctx.DistSQLContext, tids []int64, idxID int64, ranges []*ranger.Range, memTracker *memory.Tracker, interruptSignal *atomic.Value) (*kv.KeyRanges, error) { - krs := make([][]kv.KeyRange, len(tids)) - for i := range krs { - krs[i] = make([]kv.KeyRange, 0, len(ranges)) - } - - const checkSignalStep = 8 - var estimatedMemUsage int64 - // encodeIndexKey and EncodeIndexSeekKey is time-consuming, thus we need to - // check the interrupt signal periodically. - for i, ran := range ranges { - low, high, err := EncodeIndexKey(dctx, ran) - if err != nil { - return nil, err - } - if i == 0 { - estimatedMemUsage += int64(cap(low) + cap(high)) - } - for j, tid := range tids { - startKey := tablecodec.EncodeIndexSeekKey(tid, idxID, low) - endKey := tablecodec.EncodeIndexSeekKey(tid, idxID, high) - if i == 0 { - estimatedMemUsage += int64(cap(startKey)) + int64(cap(endKey)) - } - krs[j] = append(krs[j], kv.KeyRange{StartKey: startKey, EndKey: endKey}) - } - if i%checkSignalStep == 0 { - if i == 0 && memTracker != nil { - estimatedMemUsage *= int64(len(ranges)) - memTracker.Consume(estimatedMemUsage) - } - if interruptSignal != nil && interruptSignal.Load().(bool) { - return kv.NewPartitionedKeyRanges(nil), nil - } - } - } - return kv.NewPartitionedKeyRanges(krs), nil -} - -// EncodeIndexKey gets encoded keys containing low and high -func EncodeIndexKey(dctx *distsqlctx.DistSQLContext, ran *ranger.Range) ([]byte, []byte, error) { - tz := time.UTC - errCtx := errctx.StrictNoWarningContext - if dctx != nil { - tz = dctx.Location - errCtx = dctx.ErrCtx - } - - low, err := codec.EncodeKey(tz, nil, ran.LowVal...) - err = errCtx.HandleError(err) - if err != nil { - return nil, nil, err - } - if ran.LowExclude { - low = kv.Key(low).PrefixNext() - } - high, err := codec.EncodeKey(tz, nil, ran.HighVal...) 
- err = errCtx.HandleError(err) - if err != nil { - return nil, nil, err - } - - if !ran.HighExclude { - high = kv.Key(high).PrefixNext() - } - return low, high, nil -} - -// BuildTableRanges returns the key ranges encompassing the entire table, -// and its partitions if exists. -func BuildTableRanges(tbl *model.TableInfo) ([]kv.KeyRange, error) { - pis := tbl.GetPartitionInfo() - if pis == nil { - // Short path, no partition. - return appendRanges(tbl, tbl.ID) - } - - ranges := make([]kv.KeyRange, 0, len(pis.Definitions)*(len(tbl.Indices)+1)+1) - for _, def := range pis.Definitions { - rgs, err := appendRanges(tbl, def.ID) - if err != nil { - return nil, errors.Trace(err) - } - ranges = append(ranges, rgs...) - } - return ranges, nil -} - -func appendRanges(tbl *model.TableInfo, tblID int64) ([]kv.KeyRange, error) { - var ranges []*ranger.Range - if tbl.IsCommonHandle { - ranges = ranger.FullNotNullRange() - } else { - ranges = ranger.FullIntRange(false) - } - - retRanges := make([]kv.KeyRange, 0, 1+len(tbl.Indices)) - kvRanges, err := TableHandleRangesToKVRanges(nil, []int64{tblID}, tbl.IsCommonHandle, ranges) - if err != nil { - return nil, errors.Trace(err) - } - retRanges = kvRanges.AppendSelfTo(retRanges) - - for _, index := range tbl.Indices { - if index.State != model.StatePublic { - continue - } - ranges = ranger.FullRange() - idxRanges, err := IndexRangesToKVRanges(nil, tblID, index.ID, ranges) - if err != nil { - return nil, errors.Trace(err) - } - retRanges = idxRanges.AppendSelfTo(retRanges) - } - return retRanges, nil -} diff --git a/pkg/distsql/select_result.go b/pkg/distsql/select_result.go index 1f854a80ea236..5d485d4cf1483 100644 --- a/pkg/distsql/select_result.go +++ b/pkg/distsql/select_result.go @@ -402,11 +402,11 @@ func (r *selectResult) Next(ctx context.Context, chk *chunk.Chunk) error { // NextRaw returns the next raw partial result. func (r *selectResult) NextRaw(ctx context.Context) (data []byte, err error) { - if val, _err_ := failpoint.Eval(_curpkg_("mockNextRawError")); _err_ == nil { + failpoint.Inject("mockNextRawError", func(val failpoint.Value) { if val.(bool) { - return nil, errors.New("mockNextRawError") + failpoint.Return(nil, errors.New("mockNextRawError")) } - } + }) resultSubset, err := r.resp.Next(ctx) r.partialCount++ diff --git a/pkg/distsql/select_result.go__failpoint_stash__ b/pkg/distsql/select_result.go__failpoint_stash__ deleted file mode 100644 index 5d485d4cf1483..0000000000000 --- a/pkg/distsql/select_result.go__failpoint_stash__ +++ /dev/null @@ -1,815 +0,0 @@ -// Copyright 2018 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
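[Illustrative aside, not part of the patch] The LowExclude/HighExclude handling in encodeHandleKey and EncodeIndexKey above maps ranger bounds onto half-open [start, end) key ranges: an exclusive low bound and an inclusive high bound are both advanced with kv.Key.PrefixNext. Below is a minimal, self-contained Go sketch of that idea; prefixNext is a simplified local stand-in for kv.Key.PrefixNext, not the real implementation.

package example

import "fmt"

// prefixNext returns the smallest byte slice strictly greater than every key
// that has k as a prefix: increment the last byte, carrying leftwards.
func prefixNext(k []byte) []byte {
	next := append([]byte{}, k...)
	for i := len(next) - 1; i >= 0; i-- {
		next[i]++
		if next[i] != 0 {
			return next
		}
	}
	// Every byte overflowed (all 0xff); fall back to appending a zero byte.
	return append(append([]byte{}, k...), 0)
}

func main() {
	low, high := []byte("a1"), []byte("a9")
	// The range (low, high] becomes the scan range [prefixNext(low), prefixNext(high)).
	fmt.Printf("start=%q end=%q\n", prefixNext(low), prefixNext(high))
}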
- -package distsql - -import ( - "bytes" - "container/heap" - "context" - "fmt" - "strconv" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/config" - dcontext "github.com/pingcap/tidb/pkg/distsql/context" - "github.com/pingcap/tidb/pkg/errno" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/store/copr" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/dbterror" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "github.com/pingcap/tipb/go-tipb" - tikvmetrics "github.com/tikv/client-go/v2/metrics" - "github.com/tikv/client-go/v2/tikv" - clientutil "github.com/tikv/client-go/v2/util" - "go.uber.org/zap" - "golang.org/x/exp/maps" -) - -var ( - errQueryInterrupted = dbterror.ClassExecutor.NewStd(errno.ErrQueryInterrupted) -) - -var ( - _ SelectResult = (*selectResult)(nil) - _ SelectResult = (*serialSelectResults)(nil) - _ SelectResult = (*sortedSelectResults)(nil) -) - -// SelectResult is an iterator of coprocessor partial results. -type SelectResult interface { - // NextRaw gets the next raw result. - NextRaw(context.Context) ([]byte, error) - // Next reads the data into chunk. - Next(context.Context, *chunk.Chunk) error - // Close closes the iterator. - Close() error -} - -type chunkRowHeap struct { - *sortedSelectResults -} - -func (h chunkRowHeap) Len() int { - return len(h.rowPtrs) -} - -func (h chunkRowHeap) Less(i, j int) bool { - iPtr := h.rowPtrs[i] - jPtr := h.rowPtrs[j] - return h.lessRow(h.cachedChunks[iPtr.ChkIdx].GetRow(int(iPtr.RowIdx)), - h.cachedChunks[jPtr.ChkIdx].GetRow(int(jPtr.RowIdx))) -} - -func (h chunkRowHeap) Swap(i, j int) { - h.rowPtrs[i], h.rowPtrs[j] = h.rowPtrs[j], h.rowPtrs[i] -} - -func (h *chunkRowHeap) Push(x any) { - h.rowPtrs = append(h.rowPtrs, x.(chunk.RowPtr)) -} - -func (h *chunkRowHeap) Pop() any { - ret := h.rowPtrs[len(h.rowPtrs)-1] - h.rowPtrs = h.rowPtrs[0 : len(h.rowPtrs)-1] - return ret -} - -// NewSortedSelectResults is only for partition table -// If schema == nil, sort by first few columns. 
-func NewSortedSelectResults(ectx expression.EvalContext, selectResult []SelectResult, schema *expression.Schema, byitems []*util.ByItems, memTracker *memory.Tracker) SelectResult { - s := &sortedSelectResults{ - schema: schema, - selectResult: selectResult, - byItems: byitems, - memTracker: memTracker, - } - s.initCompareFuncs(ectx) - s.buildKeyColumns() - s.heap = &chunkRowHeap{s} - s.cachedChunks = make([]*chunk.Chunk, len(selectResult)) - return s -} - -type sortedSelectResults struct { - schema *expression.Schema - selectResult []SelectResult - compareFuncs []chunk.CompareFunc - byItems []*util.ByItems - keyColumns []int - - cachedChunks []*chunk.Chunk - rowPtrs []chunk.RowPtr - heap *chunkRowHeap - - memTracker *memory.Tracker -} - -func (ssr *sortedSelectResults) updateCachedChunk(ctx context.Context, idx uint32) error { - prevMemUsage := ssr.cachedChunks[idx].MemoryUsage() - if err := ssr.selectResult[idx].Next(ctx, ssr.cachedChunks[idx]); err != nil { - return err - } - ssr.memTracker.Consume(ssr.cachedChunks[idx].MemoryUsage() - prevMemUsage) - if ssr.cachedChunks[idx].NumRows() == 0 { - return nil - } - heap.Push(ssr.heap, chunk.RowPtr{ChkIdx: idx, RowIdx: 0}) - return nil -} - -func (ssr *sortedSelectResults) initCompareFuncs(ectx expression.EvalContext) { - ssr.compareFuncs = make([]chunk.CompareFunc, len(ssr.byItems)) - for i, item := range ssr.byItems { - keyType := item.Expr.GetType(ectx) - ssr.compareFuncs[i] = chunk.GetCompareFunc(keyType) - } -} - -func (ssr *sortedSelectResults) buildKeyColumns() { - ssr.keyColumns = make([]int, 0, len(ssr.byItems)) - for i, by := range ssr.byItems { - col := by.Expr.(*expression.Column) - if ssr.schema == nil { - ssr.keyColumns = append(ssr.keyColumns, i) - } else { - ssr.keyColumns = append(ssr.keyColumns, ssr.schema.ColumnIndex(col)) - } - } -} - -func (ssr *sortedSelectResults) lessRow(rowI, rowJ chunk.Row) bool { - for i, colIdx := range ssr.keyColumns { - cmpFunc := ssr.compareFuncs[i] - cmp := cmpFunc(rowI, colIdx, rowJ, colIdx) - if ssr.byItems[i].Desc { - cmp = -cmp - } - if cmp < 0 { - return true - } else if cmp > 0 { - return false - } - } - return false -} - -func (*sortedSelectResults) NextRaw(context.Context) ([]byte, error) { - panic("Not support NextRaw for sortedSelectResults") -} - -func (ssr *sortedSelectResults) Next(ctx context.Context, c *chunk.Chunk) (err error) { - c.Reset() - for i := range ssr.cachedChunks { - if ssr.cachedChunks[i] == nil { - ssr.cachedChunks[i] = c.CopyConstruct() - ssr.memTracker.Consume(ssr.cachedChunks[i].MemoryUsage()) - } - } - - if ssr.heap.Len() == 0 { - for i := range ssr.cachedChunks { - if err = ssr.updateCachedChunk(ctx, uint32(i)); err != nil { - return err - } - } - } - - for c.NumRows() < c.RequiredRows() { - if ssr.heap.Len() == 0 { - break - } - - idx := heap.Pop(ssr.heap).(chunk.RowPtr) - c.AppendRow(ssr.cachedChunks[idx.ChkIdx].GetRow(int(idx.RowIdx))) - if int(idx.RowIdx) >= ssr.cachedChunks[idx.ChkIdx].NumRows()-1 { - if err = ssr.updateCachedChunk(ctx, idx.ChkIdx); err != nil { - return err - } - } else { - heap.Push(ssr.heap, chunk.RowPtr{ChkIdx: idx.ChkIdx, RowIdx: idx.RowIdx + 1}) - } - } - return nil -} - -func (ssr *sortedSelectResults) Close() (err error) { - for i, sr := range ssr.selectResult { - err = sr.Close() - if err != nil { - return err - } - ssr.memTracker.Consume(-ssr.cachedChunks[i].MemoryUsage()) - ssr.cachedChunks[i] = nil - } - return nil -} - -// NewSerialSelectResults create a SelectResult which will read each SelectResult serially. 
-func NewSerialSelectResults(selectResults []SelectResult) SelectResult { - return &serialSelectResults{ - selectResults: selectResults, - cur: 0, - } -} - -// serialSelectResults reads each SelectResult serially -type serialSelectResults struct { - selectResults []SelectResult - cur int -} - -func (ssr *serialSelectResults) NextRaw(ctx context.Context) ([]byte, error) { - for ssr.cur < len(ssr.selectResults) { - resultSubset, err := ssr.selectResults[ssr.cur].NextRaw(ctx) - if err != nil { - return nil, err - } - if len(resultSubset) > 0 { - return resultSubset, nil - } - ssr.cur++ // move to the next SelectResult - } - return nil, nil -} - -func (ssr *serialSelectResults) Next(ctx context.Context, chk *chunk.Chunk) error { - for ssr.cur < len(ssr.selectResults) { - if err := ssr.selectResults[ssr.cur].Next(ctx, chk); err != nil { - return err - } - if chk.NumRows() > 0 { - return nil - } - ssr.cur++ // move to the next SelectResult - } - return nil -} - -func (ssr *serialSelectResults) Close() (err error) { - for _, r := range ssr.selectResults { - if rerr := r.Close(); rerr != nil { - err = rerr - } - } - return -} - -type selectResult struct { - label string - resp kv.Response - - rowLen int - fieldTypes []*types.FieldType - ctx *dcontext.DistSQLContext - - selectResp *tipb.SelectResponse - selectRespSize int64 // record the selectResp.Size() when it is initialized. - respChkIdx int - respChunkDecoder *chunk.Decoder - - partialCount int64 // number of partial results. - sqlType string - - // copPlanIDs contains all copTasks' planIDs, - // which help to collect copTasks' runtime stats. - copPlanIDs []int - rootPlanID int - - storeType kv.StoreType - - fetchDuration time.Duration - durationReported bool - memTracker *memory.Tracker - - stats *selectResultRuntimeStats - // distSQLConcurrency and paging are only for collecting information, and they don't affect the process of execution. - distSQLConcurrency int - paging bool -} - -func (r *selectResult) fetchResp(ctx context.Context) error { - for { - r.respChkIdx = 0 - startTime := time.Now() - resultSubset, err := r.resp.Next(ctx) - duration := time.Since(startTime) - r.fetchDuration += duration - if err != nil { - return errors.Trace(err) - } - if r.selectResp != nil { - r.memConsume(-atomic.LoadInt64(&r.selectRespSize)) - } - if resultSubset == nil { - r.selectResp = nil - atomic.StoreInt64(&r.selectRespSize, 0) - if !r.durationReported { - // final round of fetch - // TODO: Add a label to distinguish between success or failure. 
- // https://github.com/pingcap/tidb/issues/11397 - if r.paging { - metrics.DistSQLQueryHistogram.WithLabelValues(r.label, r.sqlType, "paging").Observe(r.fetchDuration.Seconds()) - } else { - metrics.DistSQLQueryHistogram.WithLabelValues(r.label, r.sqlType, "common").Observe(r.fetchDuration.Seconds()) - } - r.durationReported = true - } - return nil - } - r.selectResp = new(tipb.SelectResponse) - err = r.selectResp.Unmarshal(resultSubset.GetData()) - if err != nil { - return errors.Trace(err) - } - respSize := int64(r.selectResp.Size()) - atomic.StoreInt64(&r.selectRespSize, respSize) - r.memConsume(respSize) - if err := r.selectResp.Error; err != nil { - return dbterror.ClassTiKV.Synthesize(terror.ErrCode(err.Code), err.Msg) - } - if err = r.ctx.SQLKiller.HandleSignal(); err != nil { - return err - } - for _, warning := range r.selectResp.Warnings { - r.ctx.AppendWarning(dbterror.ClassTiKV.Synthesize(terror.ErrCode(warning.Code), warning.Msg)) - } - - r.partialCount++ - - hasStats, ok := resultSubset.(CopRuntimeStats) - if ok { - copStats := hasStats.GetCopRuntimeStats() - if copStats != nil { - if err := r.updateCopRuntimeStats(ctx, copStats, resultSubset.RespTime()); err != nil { - return err - } - copStats.CopTime = duration - r.ctx.ExecDetails.MergeExecDetails(&copStats.ExecDetails, nil) - } - } - if len(r.selectResp.Chunks) != 0 { - break - } - } - return nil -} - -func (r *selectResult) Next(ctx context.Context, chk *chunk.Chunk) error { - chk.Reset() - if r.selectResp == nil || r.respChkIdx == len(r.selectResp.Chunks) { - err := r.fetchResp(ctx) - if err != nil { - return err - } - if r.selectResp == nil { - return nil - } - } - // TODO(Shenghui Wu): add metrics - encodeType := r.selectResp.GetEncodeType() - switch encodeType { - case tipb.EncodeType_TypeDefault: - return r.readFromDefault(ctx, chk) - case tipb.EncodeType_TypeChunk: - return r.readFromChunk(ctx, chk) - } - return errors.Errorf("unsupported encode type:%v", encodeType) -} - -// NextRaw returns the next raw partial result. -func (r *selectResult) NextRaw(ctx context.Context) (data []byte, err error) { - failpoint.Inject("mockNextRawError", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(nil, errors.New("mockNextRawError")) - } - }) - - resultSubset, err := r.resp.Next(ctx) - r.partialCount++ - if resultSubset != nil && err == nil { - data = resultSubset.GetData() - } - return data, err -} - -func (r *selectResult) readFromDefault(ctx context.Context, chk *chunk.Chunk) error { - for !chk.IsFull() { - if r.respChkIdx == len(r.selectResp.Chunks) { - err := r.fetchResp(ctx) - if err != nil || r.selectResp == nil { - return err - } - } - err := r.readRowsData(chk) - if err != nil { - return err - } - if len(r.selectResp.Chunks[r.respChkIdx].RowsData) == 0 { - r.respChkIdx++ - } - } - return nil -} - -func (r *selectResult) readFromChunk(ctx context.Context, chk *chunk.Chunk) error { - if r.respChunkDecoder == nil { - r.respChunkDecoder = chunk.NewDecoder( - chunk.NewChunkWithCapacity(r.fieldTypes, 0), - r.fieldTypes, - ) - } - - for !chk.IsFull() { - if r.respChkIdx == len(r.selectResp.Chunks) { - err := r.fetchResp(ctx) - if err != nil || r.selectResp == nil { - return err - } - } - - if r.respChunkDecoder.IsFinished() { - r.respChunkDecoder.Reset(r.selectResp.Chunks[r.respChkIdx].RowsData) - } - // If the next chunk size is greater than required rows * 0.8, reuse the memory of the next chunk and return - // immediately. Otherwise, splice the data to one chunk and wait the next chunk. 
- if r.respChunkDecoder.RemainedRows() > int(float64(chk.RequiredRows())*0.8) { - if chk.NumRows() > 0 { - return nil - } - r.respChunkDecoder.ReuseIntermChk(chk) - r.respChkIdx++ - return nil - } - r.respChunkDecoder.Decode(chk) - if r.respChunkDecoder.IsFinished() { - r.respChkIdx++ - } - } - return nil -} - -// FillDummySummariesForTiFlashTasks fills dummy execution summaries for mpp tasks which lack summaries -func FillDummySummariesForTiFlashTasks(runtimeStatsColl *execdetails.RuntimeStatsColl, callee string, storeTypeName string, allPlanIDs []int, recordedPlanIDs map[int]int) { - num := uint64(0) - dummySummary := &tipb.ExecutorExecutionSummary{TimeProcessedNs: &num, NumProducedRows: &num, NumIterations: &num, ExecutorId: nil} - for _, planID := range allPlanIDs { - if _, ok := recordedPlanIDs[planID]; !ok { - runtimeStatsColl.RecordOneCopTask(planID, storeTypeName, callee, dummySummary) - } - } -} - -// recordExecutionSummariesForTiFlashTasks records mpp task execution summaries -func recordExecutionSummariesForTiFlashTasks(runtimeStatsColl *execdetails.RuntimeStatsColl, executionSummaries []*tipb.ExecutorExecutionSummary, callee string, storeTypeName string, allPlanIDs []int) { - var recordedPlanIDs = make(map[int]int) - for _, detail := range executionSummaries { - if detail != nil && detail.TimeProcessedNs != nil && - detail.NumProducedRows != nil && detail.NumIterations != nil { - recordedPlanIDs[runtimeStatsColl. - RecordOneCopTask(-1, storeTypeName, callee, detail)] = 0 - } - } - FillDummySummariesForTiFlashTasks(runtimeStatsColl, callee, storeTypeName, allPlanIDs, recordedPlanIDs) -} - -func (r *selectResult) updateCopRuntimeStats(ctx context.Context, copStats *copr.CopRuntimeStats, respTime time.Duration) (err error) { - callee := copStats.CalleeAddress - if r.rootPlanID <= 0 || r.ctx.RuntimeStatsColl == nil || (callee == "" && (copStats.ReqStats == nil || len(copStats.ReqStats.RPCStats) == 0)) { - return - } - - if copStats.ScanDetail != nil { - readKeys := copStats.ScanDetail.ProcessedKeys - readTime := copStats.TimeDetail.KvReadWallTime.Seconds() - readSize := float64(copStats.ScanDetail.ProcessedKeysSize) - tikvmetrics.ObserveReadSLI(uint64(readKeys), readTime, readSize) - } - - if r.stats == nil { - r.stats = &selectResultRuntimeStats{ - backoffSleep: make(map[string]time.Duration), - reqStat: tikv.NewRegionRequestRuntimeStats(), - distSQLConcurrency: r.distSQLConcurrency, - } - if ci, ok := r.resp.(copr.CopInfo); ok { - conc, extraConc := ci.GetConcurrency() - r.stats.distSQLConcurrency = conc - r.stats.extraConcurrency = extraConc - } - } - r.stats.mergeCopRuntimeStats(copStats, respTime) - - if copStats.ScanDetail != nil && len(r.copPlanIDs) > 0 { - r.ctx.RuntimeStatsColl.RecordScanDetail(r.copPlanIDs[len(r.copPlanIDs)-1], r.storeType.Name(), copStats.ScanDetail) - } - if len(r.copPlanIDs) > 0 { - r.ctx.RuntimeStatsColl.RecordTimeDetail(r.copPlanIDs[len(r.copPlanIDs)-1], r.storeType.Name(), &copStats.TimeDetail) - } - - // If hasExecutor is true, it means the summary is returned from TiFlash. 
- hasExecutor := false - for _, detail := range r.selectResp.GetExecutionSummaries() { - if detail != nil && detail.TimeProcessedNs != nil && - detail.NumProducedRows != nil && detail.NumIterations != nil { - if detail.ExecutorId != nil { - hasExecutor = true - } - break - } - } - - if ruDetailsRaw := ctx.Value(clientutil.RUDetailsCtxKey); ruDetailsRaw != nil && r.storeType == kv.TiFlash { - if err = execdetails.MergeTiFlashRUConsumption(r.selectResp.GetExecutionSummaries(), ruDetailsRaw.(*clientutil.RUDetails)); err != nil { - return err - } - } - if hasExecutor { - recordExecutionSummariesForTiFlashTasks(r.ctx.RuntimeStatsColl, r.selectResp.GetExecutionSummaries(), callee, r.storeType.Name(), r.copPlanIDs) - } else { - // For cop task cases, we still need this protection. - if len(r.selectResp.GetExecutionSummaries()) != len(r.copPlanIDs) { - // for TiFlash streaming call(BatchCop and MPP), it is by design that only the last response will - // carry the execution summaries, so it is ok if some responses have no execution summaries, should - // not trigger an error log in this case. - if !(r.storeType == kv.TiFlash && len(r.selectResp.GetExecutionSummaries()) == 0) { - logutil.Logger(ctx).Error("invalid cop task execution summaries length", - zap.Int("expected", len(r.copPlanIDs)), - zap.Int("received", len(r.selectResp.GetExecutionSummaries()))) - } - return - } - for i, detail := range r.selectResp.GetExecutionSummaries() { - if detail != nil && detail.TimeProcessedNs != nil && - detail.NumProducedRows != nil && detail.NumIterations != nil { - planID := r.copPlanIDs[i] - r.ctx.RuntimeStatsColl. - RecordOneCopTask(planID, r.storeType.Name(), callee, detail) - } - } - } - return -} - -func (r *selectResult) readRowsData(chk *chunk.Chunk) (err error) { - rowsData := r.selectResp.Chunks[r.respChkIdx].RowsData - decoder := codec.NewDecoder(chk, r.ctx.Location) - for !chk.IsFull() && len(rowsData) > 0 { - for i := 0; i < r.rowLen; i++ { - rowsData, err = decoder.DecodeOne(rowsData, i, r.fieldTypes[i]) - if err != nil { - return err - } - } - } - r.selectResp.Chunks[r.respChkIdx].RowsData = rowsData - return nil -} - -func (r *selectResult) memConsume(bytes int64) { - if r.memTracker != nil { - r.memTracker.Consume(bytes) - } -} - -// Close closes selectResult. -func (r *selectResult) Close() error { - metrics.DistSQLPartialCountHistogram.Observe(float64(r.partialCount)) - respSize := atomic.SwapInt64(&r.selectRespSize, 0) - if respSize > 0 { - r.memConsume(-respSize) - } - if r.ctx != nil { - if unconsumed, ok := r.resp.(copr.HasUnconsumedCopRuntimeStats); ok && unconsumed != nil { - unconsumedCopStats := unconsumed.CollectUnconsumedCopRuntimeStats() - for _, copStats := range unconsumedCopStats { - _ = r.updateCopRuntimeStats(context.Background(), copStats, time.Duration(0)) - r.ctx.ExecDetails.MergeExecDetails(&copStats.ExecDetails, nil) - } - } - } - if r.stats != nil && r.ctx != nil { - defer func() { - if ci, ok := r.resp.(copr.CopInfo); ok { - r.stats.buildTaskDuration = ci.GetBuildTaskElapsed() - batched, fallback := ci.GetStoreBatchInfo() - if batched != 0 || fallback != 0 { - r.stats.storeBatchedNum, r.stats.storeBatchedFallbackNum = batched, fallback - } - } - r.ctx.RuntimeStatsColl.RegisterStats(r.rootPlanID, r.stats) - }() - } - return r.resp.Close() -} - -// CopRuntimeStats is an interface uses to check whether the result has cop runtime stats. -type CopRuntimeStats interface { - // GetCopRuntimeStats gets the cop runtime stats information. 
- GetCopRuntimeStats() *copr.CopRuntimeStats -} - -type selectResultRuntimeStats struct { - copRespTime execdetails.Percentile[execdetails.Duration] - procKeys execdetails.Percentile[execdetails.Int64] - backoffSleep map[string]time.Duration - totalProcessTime time.Duration - totalWaitTime time.Duration - reqStat *tikv.RegionRequestRuntimeStats - distSQLConcurrency int - extraConcurrency int - CoprCacheHitNum int64 - storeBatchedNum uint64 - storeBatchedFallbackNum uint64 - buildTaskDuration time.Duration -} - -func (s *selectResultRuntimeStats) mergeCopRuntimeStats(copStats *copr.CopRuntimeStats, respTime time.Duration) { - s.copRespTime.Add(execdetails.Duration(respTime)) - if copStats.ScanDetail != nil { - s.procKeys.Add(execdetails.Int64(copStats.ScanDetail.ProcessedKeys)) - } else { - s.procKeys.Add(0) - } - maps.Copy(s.backoffSleep, copStats.BackoffSleep) - s.totalProcessTime += copStats.TimeDetail.ProcessTime - s.totalWaitTime += copStats.TimeDetail.WaitTime - s.reqStat.Merge(copStats.ReqStats) - if copStats.CoprCacheHit { - s.CoprCacheHitNum++ - } -} - -func (s *selectResultRuntimeStats) Clone() execdetails.RuntimeStats { - newRs := selectResultRuntimeStats{ - copRespTime: execdetails.Percentile[execdetails.Duration]{}, - procKeys: execdetails.Percentile[execdetails.Int64]{}, - backoffSleep: make(map[string]time.Duration, len(s.backoffSleep)), - reqStat: tikv.NewRegionRequestRuntimeStats(), - distSQLConcurrency: s.distSQLConcurrency, - extraConcurrency: s.extraConcurrency, - CoprCacheHitNum: s.CoprCacheHitNum, - storeBatchedNum: s.storeBatchedNum, - storeBatchedFallbackNum: s.storeBatchedFallbackNum, - buildTaskDuration: s.buildTaskDuration, - } - newRs.copRespTime.MergePercentile(&s.copRespTime) - newRs.procKeys.MergePercentile(&s.procKeys) - for k, v := range s.backoffSleep { - newRs.backoffSleep[k] += v - } - newRs.totalProcessTime += s.totalProcessTime - newRs.totalWaitTime += s.totalWaitTime - newRs.reqStat = s.reqStat.Clone() - return &newRs -} - -func (s *selectResultRuntimeStats) Merge(rs execdetails.RuntimeStats) { - other, ok := rs.(*selectResultRuntimeStats) - if !ok { - return - } - s.copRespTime.MergePercentile(&other.copRespTime) - s.procKeys.MergePercentile(&other.procKeys) - - for k, v := range other.backoffSleep { - s.backoffSleep[k] += v - } - s.totalProcessTime += other.totalProcessTime - s.totalWaitTime += other.totalWaitTime - s.reqStat.Merge(other.reqStat) - s.CoprCacheHitNum += other.CoprCacheHitNum - if other.distSQLConcurrency > s.distSQLConcurrency { - s.distSQLConcurrency = other.distSQLConcurrency - } - if other.extraConcurrency > s.extraConcurrency { - s.extraConcurrency = other.extraConcurrency - } - s.storeBatchedNum += other.storeBatchedNum - s.storeBatchedFallbackNum += other.storeBatchedFallbackNum - s.buildTaskDuration += other.buildTaskDuration -} - -func (s *selectResultRuntimeStats) String() string { - buf := bytes.NewBuffer(nil) - reqStat := s.reqStat - if s.copRespTime.Size() > 0 { - size := s.copRespTime.Size() - if size == 1 { - fmt.Fprintf(buf, "cop_task: {num: 1, max: %v, proc_keys: %v", execdetails.FormatDuration(time.Duration(s.copRespTime.GetPercentile(0))), s.procKeys.GetPercentile(0)) - } else { - vMax, vMin := s.copRespTime.GetMax(), s.copRespTime.GetMin() - vP95 := s.copRespTime.GetPercentile(0.95) - sum := s.copRespTime.Sum() - vAvg := time.Duration(sum / float64(size)) - - keyMax := s.procKeys.GetMax() - keyP95 := s.procKeys.GetPercentile(0.95) - fmt.Fprintf(buf, "cop_task: {num: %v, max: %v, min: %v, avg: %v, p95: %v", size, - 
execdetails.FormatDuration(time.Duration(vMax.GetFloat64())), execdetails.FormatDuration(time.Duration(vMin.GetFloat64())), - execdetails.FormatDuration(vAvg), execdetails.FormatDuration(time.Duration(vP95))) - if keyMax > 0 { - buf.WriteString(", max_proc_keys: ") - buf.WriteString(strconv.FormatInt(int64(keyMax), 10)) - buf.WriteString(", p95_proc_keys: ") - buf.WriteString(strconv.FormatInt(int64(keyP95), 10)) - } - } - if s.totalProcessTime > 0 { - buf.WriteString(", tot_proc: ") - buf.WriteString(execdetails.FormatDuration(s.totalProcessTime)) - if s.totalWaitTime > 0 { - buf.WriteString(", tot_wait: ") - buf.WriteString(execdetails.FormatDuration(s.totalWaitTime)) - } - } - if config.GetGlobalConfig().TiKVClient.CoprCache.CapacityMB > 0 { - fmt.Fprintf(buf, ", copr_cache_hit_ratio: %v", - strconv.FormatFloat(s.calcCacheHit(), 'f', 2, 64)) - } else { - buf.WriteString(", copr_cache: disabled") - } - if s.buildTaskDuration > 0 { - buf.WriteString(", build_task_duration: ") - buf.WriteString(execdetails.FormatDuration(s.buildTaskDuration)) - } - if s.distSQLConcurrency > 0 { - buf.WriteString(", max_distsql_concurrency: ") - buf.WriteString(strconv.FormatInt(int64(s.distSQLConcurrency), 10)) - } - if s.extraConcurrency > 0 { - buf.WriteString(", max_extra_concurrency: ") - buf.WriteString(strconv.FormatInt(int64(s.extraConcurrency), 10)) - } - if s.storeBatchedNum > 0 { - buf.WriteString(", store_batch_num: ") - buf.WriteString(strconv.FormatInt(int64(s.storeBatchedNum), 10)) - } - if s.storeBatchedFallbackNum > 0 { - buf.WriteString(", store_batch_fallback_num: ") - buf.WriteString(strconv.FormatInt(int64(s.storeBatchedFallbackNum), 10)) - } - buf.WriteString("}") - } - - rpcStatsStr := reqStat.String() - if len(rpcStatsStr) > 0 { - buf.WriteString(", rpc_info:{") - buf.WriteString(rpcStatsStr) - buf.WriteString("}") - } - - if len(s.backoffSleep) > 0 { - buf.WriteString(", backoff{") - idx := 0 - for k, d := range s.backoffSleep { - if idx > 0 { - buf.WriteString(", ") - } - idx++ - fmt.Fprintf(buf, "%s: %s", k, execdetails.FormatDuration(d)) - } - buf.WriteString("}") - } - return buf.String() -} - -// Tp implements the RuntimeStats interface. 
-func (*selectResultRuntimeStats) Tp() int { - return execdetails.TpSelectResultRuntimeStats -} - -func (s *selectResultRuntimeStats) calcCacheHit() float64 { - hit := s.CoprCacheHitNum - tot := s.copRespTime.Size() - if s.storeBatchedNum > 0 { - tot += int(s.storeBatchedNum) - } - if tot == 0 { - return 0 - } - return float64(hit) / float64(tot) -} diff --git a/pkg/disttask/framework/scheduler/binding__failpoint_binding__.go b/pkg/disttask/framework/scheduler/binding__failpoint_binding__.go deleted file mode 100644 index c960860aa5565..0000000000000 --- a/pkg/disttask/framework/scheduler/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package scheduler - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/disttask/framework/scheduler/nodes.go b/pkg/disttask/framework/scheduler/nodes.go index 99d81656a4804..7a8dc8cb401c4 100644 --- a/pkg/disttask/framework/scheduler/nodes.go +++ b/pkg/disttask/framework/scheduler/nodes.go @@ -150,9 +150,9 @@ func (nm *NodeManager) refreshNodes(ctx context.Context, taskMgr TaskManager, sl slotMgr.updateCapacity(cpuCount) nm.nodes.Store(&newNodes) - if _, _err_ := failpoint.Eval(_curpkg_("syncRefresh")); _err_ == nil { + failpoint.Inject("syncRefresh", func() { TestRefreshedChan <- struct{}{} - } + }) } // GetNodes returns the nodes managed by the framework. diff --git a/pkg/disttask/framework/scheduler/nodes.go__failpoint_stash__ b/pkg/disttask/framework/scheduler/nodes.go__failpoint_stash__ deleted file mode 100644 index 7a8dc8cb401c4..0000000000000 --- a/pkg/disttask/framework/scheduler/nodes.go__failpoint_stash__ +++ /dev/null @@ -1,191 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package scheduler - -import ( - "context" - "sync/atomic" - "time" - - "github.com/pingcap/failpoint" - "github.com/pingcap/log" - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - llog "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/util/intest" - "go.uber.org/zap" -) - -var ( - // liveNodesCheckInterval is the tick interval of fetching all server infos from etcs. - nodesCheckInterval = 2 * CheckTaskFinishedInterval -) - -// NodeManager maintains live TiDB nodes in the cluster, and maintains the nodes -// managed by the framework. -type NodeManager struct { - logger *zap.Logger - // prevLiveNodes is used to record the live nodes in last checking. - prevLiveNodes map[string]struct{} - // nodes is the cached nodes managed by the framework. - // see TaskManager.GetNodes for more details. 
- nodes atomic.Pointer[[]proto.ManagedNode] -} - -func newNodeManager(serverID string) *NodeManager { - logger := log.L() - if intest.InTest { - logger = log.L().With(zap.String("server-id", serverID)) - } - nm := &NodeManager{ - logger: logger, - prevLiveNodes: make(map[string]struct{}), - } - nodes := make([]proto.ManagedNode, 0, 10) - nm.nodes.Store(&nodes) - return nm -} - -func (nm *NodeManager) maintainLiveNodesLoop(ctx context.Context, taskMgr TaskManager) { - ticker := time.NewTicker(nodesCheckInterval) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - nm.maintainLiveNodes(ctx, taskMgr) - } - } -} - -// maintainLiveNodes manages live node info in dist_framework_meta table -// see recoverMetaLoop in task executor for when node is inserted into dist_framework_meta. -func (nm *NodeManager) maintainLiveNodes(ctx context.Context, taskMgr TaskManager) { - // Safe to discard errors since this function can be called at regular intervals. - liveExecIDs, err := GetLiveExecIDs(ctx) - if err != nil { - nm.logger.Warn("generate task executor nodes met error", llog.ShortError(err)) - return - } - nodeChanged := len(liveExecIDs) != len(nm.prevLiveNodes) - currLiveNodes := make(map[string]struct{}, len(liveExecIDs)) - for _, execID := range liveExecIDs { - if _, ok := nm.prevLiveNodes[execID]; !ok { - nodeChanged = true - } - currLiveNodes[execID] = struct{}{} - } - if !nodeChanged { - return - } - - oldNodes, err := taskMgr.GetAllNodes(ctx) - if err != nil { - nm.logger.Warn("get all nodes met error", llog.ShortError(err)) - return - } - - deadNodes := make([]string, 0) - for _, node := range oldNodes { - if _, ok := currLiveNodes[node.ID]; !ok { - deadNodes = append(deadNodes, node.ID) - } - } - if len(deadNodes) == 0 { - nm.prevLiveNodes = currLiveNodes - return - } - nm.logger.Info("delete dead nodes from dist_framework_meta", - zap.Strings("dead-nodes", deadNodes)) - err = taskMgr.DeleteDeadNodes(ctx, deadNodes) - if err != nil { - nm.logger.Warn("delete dead nodes from dist_framework_meta failed", llog.ShortError(err)) - return - } - nm.prevLiveNodes = currLiveNodes -} - -func (nm *NodeManager) refreshNodesLoop(ctx context.Context, taskMgr TaskManager, slotMgr *SlotManager) { - ticker := time.NewTicker(nodesCheckInterval) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - nm.refreshNodes(ctx, taskMgr, slotMgr) - } - } -} - -// TestRefreshedChan is used to sync the test. -var TestRefreshedChan = make(chan struct{}) - -// refreshNodes maintains the nodes managed by the framework. -func (nm *NodeManager) refreshNodes(ctx context.Context, taskMgr TaskManager, slotMgr *SlotManager) { - newNodes, err := taskMgr.GetAllNodes(ctx) - if err != nil { - nm.logger.Warn("get managed nodes met error", llog.ShortError(err)) - return - } - - var cpuCount int - for _, node := range newNodes { - if node.CPUCount > 0 { - cpuCount = node.CPUCount - } - } - slotMgr.updateCapacity(cpuCount) - nm.nodes.Store(&newNodes) - - failpoint.Inject("syncRefresh", func() { - TestRefreshedChan <- struct{}{} - }) -} - -// GetNodes returns the nodes managed by the framework. -// return a copy of the nodes. 
-func (nm *NodeManager) getNodes() []proto.ManagedNode { - nodes := *nm.nodes.Load() - res := make([]proto.ManagedNode, len(nodes)) - copy(res, nodes) - return res -} - -func filterByScope(nodes []proto.ManagedNode, targetScope string) []string { - var nodeIDs []string - haveBackground := false - for _, node := range nodes { - if node.Role == "background" { - haveBackground = true - } - } - // prefer to use "background" node instead of "" node. - if targetScope == "" && haveBackground { - for _, node := range nodes { - if node.Role == "background" { - nodeIDs = append(nodeIDs, node.ID) - } - } - return nodeIDs - } - - for _, node := range nodes { - if node.Role == targetScope { - nodeIDs = append(nodeIDs, node.ID) - } - } - return nodeIDs -} diff --git a/pkg/disttask/framework/scheduler/scheduler.go b/pkg/disttask/framework/scheduler/scheduler.go index 39f3be3a48134..476339645762d 100644 --- a/pkg/disttask/framework/scheduler/scheduler.go +++ b/pkg/disttask/framework/scheduler/scheduler.go @@ -177,19 +177,19 @@ func (s *BaseScheduler) scheduleTask() { } task := *s.GetTask() // TODO: refine failpoints below. - if _, _err_ := failpoint.Eval(_curpkg_("exitScheduler")); _err_ == nil { - return - } - if val, _err_ := failpoint.Eval(_curpkg_("cancelTaskAfterRefreshTask")); _err_ == nil { + failpoint.Inject("exitScheduler", func() { + failpoint.Return() + }) + failpoint.Inject("cancelTaskAfterRefreshTask", func(val failpoint.Value) { if val.(bool) && task.State == proto.TaskStateRunning { err := s.taskMgr.CancelTask(s.ctx, task.ID) if err != nil { s.logger.Error("cancel task failed", zap.Error(err)) } } - } + }) - if val, _err_ := failpoint.Eval(_curpkg_("pausePendingTask")); _err_ == nil { + failpoint.Inject("pausePendingTask", func(val failpoint.Value) { if val.(bool) && task.State == proto.TaskStatePending { _, err := s.taskMgr.PauseTask(s.ctx, task.Key) if err != nil { @@ -198,9 +198,9 @@ func (s *BaseScheduler) scheduleTask() { task.State = proto.TaskStatePausing s.task.Store(&task) } - } + }) - if val, _err_ := failpoint.Eval(_curpkg_("pauseTaskAfterRefreshTask")); _err_ == nil { + failpoint.Inject("pauseTaskAfterRefreshTask", func(val failpoint.Value) { if val.(bool) && task.State == proto.TaskStateRunning { _, err := s.taskMgr.PauseTask(s.ctx, task.Key) if err != nil { @@ -209,7 +209,7 @@ func (s *BaseScheduler) scheduleTask() { task.State = proto.TaskStatePausing s.task.Store(&task) } - } + }) switch task.State { case proto.TaskStateCancelling: @@ -261,7 +261,7 @@ func (s *BaseScheduler) scheduleTask() { s.logger.Info("schedule task meet err, reschedule it", zap.Error(err)) } - failpoint.Call(_curpkg_("mockOwnerChange")) + failpoint.InjectCall("mockOwnerChange") } } } @@ -302,7 +302,7 @@ func (s *BaseScheduler) onPausing() error { func (s *BaseScheduler) onPaused() error { task := s.GetTask() s.logger.Info("on paused state", zap.Stringer("state", task.State), zap.String("step", proto.Step2Str(task.Type, task.Step))) - failpoint.Call(_curpkg_("mockDMLExecutionOnPausedState")) + failpoint.InjectCall("mockDMLExecutionOnPausedState") return nil } @@ -483,9 +483,9 @@ func (s *BaseScheduler) scheduleSubTask( size += uint64(len(meta)) } - if _, _err_ := failpoint.Eval(_curpkg_("cancelBeforeUpdateTask")); _err_ == nil { + failpoint.Inject("cancelBeforeUpdateTask", func() { _ = s.taskMgr.CancelTask(s.ctx, task.ID) - } + }) // as other fields and generated key and index KV takes space too, we limit // the size of subtasks to 80% of the transaction limit. 
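[Illustrative aside, not part of the patch] The hunks in this file replace the generated failpoint.Eval(_curpkg_(...)) form with the hand-written failpoint.Inject marker form, and the series deletes the matching binding__failpoint_binding__.go and *__failpoint_stash__ artifacts. Below is a minimal sketch of the two forms, assuming a hypothetical failpoint named "mockReadError"; the marker functions compile to no-ops until a tool such as failpoint-ctl rewrites them.

package example

import (
	"errors"

	"github.com/pingcap/failpoint"
)

// readOnce uses the marker form as a developer writes it by hand.
func readOnce() ([]byte, error) {
	failpoint.Inject("mockReadError", func(val failpoint.Value) {
		if val.(bool) {
			failpoint.Return(nil, errors.New("mockReadError"))
		}
	})
	// After code generation the marker above is rewritten roughly into:
	//
	//	if val, _err_ := failpoint.Eval(_curpkg_("mockReadError")); _err_ == nil {
	//		if val.(bool) {
	//			return nil, errors.New("mockReadError")
	//		}
	//	}
	//
	// with _curpkg_ coming from a generated binding__failpoint_binding__.go and
	// the untouched source kept aside as <file>.go__failpoint_stash__.
	return []byte("ok"), nil
}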
@@ -537,9 +537,9 @@ var MockServerInfo atomic.Pointer[[]string] // GetLiveExecIDs returns all live executor node IDs. func GetLiveExecIDs(ctx context.Context) ([]string, error) { - if _, _err_ := failpoint.Eval(_curpkg_("mockTaskExecutorNodes")); _err_ == nil { - return *MockServerInfo.Load(), nil - } + failpoint.Inject("mockTaskExecutorNodes", func() { + failpoint.Return(*MockServerInfo.Load(), nil) + }) serverInfos, err := generateTaskExecutorNodes(ctx) if err != nil { return nil, err diff --git a/pkg/disttask/framework/scheduler/scheduler.go__failpoint_stash__ b/pkg/disttask/framework/scheduler/scheduler.go__failpoint_stash__ deleted file mode 100644 index 476339645762d..0000000000000 --- a/pkg/disttask/framework/scheduler/scheduler.go__failpoint_stash__ +++ /dev/null @@ -1,624 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package scheduler - -import ( - "context" - "math/rand" - "strings" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/log" - "github.com/pingcap/tidb/pkg/disttask/framework/handle" - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - "github.com/pingcap/tidb/pkg/disttask/framework/storage" - "github.com/pingcap/tidb/pkg/domain/infosync" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/util/backoff" - disttaskutil "github.com/pingcap/tidb/pkg/util/disttask" - "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tidb/pkg/util/logutil" - "go.uber.org/zap" -) - -const ( - // for a cancelled task, it's terminal state is reverted or reverted_failed, - // so we use a special error message to indicate that the task is cancelled - // by user. - taskCancelMsg = "cancelled by user" -) - -var ( - // CheckTaskFinishedInterval is the interval for scheduler. - // exported for testing. - CheckTaskFinishedInterval = 500 * time.Millisecond - // RetrySQLTimes is the max retry times when executing SQL. - RetrySQLTimes = 30 - // RetrySQLInterval is the initial interval between two SQL retries. - RetrySQLInterval = 3 * time.Second - // RetrySQLMaxInterval is the max interval between two SQL retries. - RetrySQLMaxInterval = 30 * time.Second -) - -// Scheduler manages the lifetime of a task -// including submitting subtasks and updating the status of a task. -type Scheduler interface { - // Init initializes the scheduler, should be called before ExecuteTask. - // if Init returns error, scheduler manager will fail the task directly, - // so the returned error should be a fatal error. - Init() error - // ScheduleTask schedules the task execution step by step. - ScheduleTask() - // Close closes the scheduler, should be called if Init returns nil. - Close() - // GetTask returns the task that the scheduler is managing. - GetTask() *proto.Task - Extension -} - -// BaseScheduler is the base struct for Scheduler. 
-// each task type embed this struct and implement the Extension interface. -type BaseScheduler struct { - ctx context.Context - Param - // task might be accessed by multiple goroutines, so don't change its fields - // directly, make a copy, update and store it back to the atomic pointer. - task atomic.Pointer[proto.Task] - logger *zap.Logger - // when RegisterSchedulerFactory, the factory MUST initialize this fields. - Extension - - balanceSubtaskTick int - // rand is for generating random selection of nodes. - rand *rand.Rand -} - -// NewBaseScheduler creates a new BaseScheduler. -func NewBaseScheduler(ctx context.Context, task *proto.Task, param Param) *BaseScheduler { - logger := log.L().With(zap.Int64("task-id", task.ID), zap.Stringer("task-type", task.Type), zap.Bool("allocated-slots", param.allocatedSlots)) - if intest.InTest { - logger = logger.With(zap.String("server-id", param.serverID)) - } - s := &BaseScheduler{ - ctx: ctx, - Param: param, - logger: logger, - rand: rand.New(rand.NewSource(time.Now().UnixNano())), - } - s.task.Store(task) - return s -} - -// Init implements the Scheduler interface. -func (*BaseScheduler) Init() error { - return nil -} - -// ScheduleTask implements the Scheduler interface. -func (s *BaseScheduler) ScheduleTask() { - task := s.GetTask() - s.logger.Info("schedule task", - zap.Stringer("state", task.State), zap.Int("concurrency", task.Concurrency)) - s.scheduleTask() -} - -// Close closes the scheduler. -func (*BaseScheduler) Close() { -} - -// GetTask implements the Scheduler interface. -func (s *BaseScheduler) GetTask() *proto.Task { - return s.task.Load() -} - -// refreshTaskIfNeeded fetch task state from tidb_global_task table. -func (s *BaseScheduler) refreshTaskIfNeeded() error { - task := s.GetTask() - // we only query the base fields of task to reduce memory usage, other fields - // are refreshed when needed. - newTaskBase, err := s.taskMgr.GetTaskBaseByID(s.ctx, task.ID) - if err != nil { - return err - } - // state might be changed by user to pausing/resuming/cancelling, or - // in case of network partition, state/step/meta might be changed by other scheduler, - // in both cases we refresh the whole task object. - if newTaskBase.State != task.State || newTaskBase.Step != task.Step { - s.logger.Info("task state/step changed by user or other scheduler", - zap.Stringer("old-state", task.State), - zap.Stringer("new-state", newTaskBase.State), - zap.String("old-step", proto.Step2Str(task.Type, task.Step)), - zap.String("new-step", proto.Step2Str(task.Type, newTaskBase.Step))) - newTask, err := s.taskMgr.GetTaskByID(s.ctx, task.ID) - if err != nil { - return err - } - s.task.Store(newTask) - } - return nil -} - -// scheduleTask schedule the task execution step by step. -func (s *BaseScheduler) scheduleTask() { - ticker := time.NewTicker(CheckTaskFinishedInterval) - defer ticker.Stop() - for { - select { - case <-s.ctx.Done(): - s.logger.Info("schedule task exits") - return - case <-ticker.C: - err := s.refreshTaskIfNeeded() - if err != nil { - if errors.Cause(err) == storage.ErrTaskNotFound { - // this can happen when task is reverted/succeed, but before - // we reach here, cleanup routine move it to history. - return - } - s.logger.Error("refresh task failed", zap.Error(err)) - continue - } - task := *s.GetTask() - // TODO: refine failpoints below. 
- failpoint.Inject("exitScheduler", func() { - failpoint.Return() - }) - failpoint.Inject("cancelTaskAfterRefreshTask", func(val failpoint.Value) { - if val.(bool) && task.State == proto.TaskStateRunning { - err := s.taskMgr.CancelTask(s.ctx, task.ID) - if err != nil { - s.logger.Error("cancel task failed", zap.Error(err)) - } - } - }) - - failpoint.Inject("pausePendingTask", func(val failpoint.Value) { - if val.(bool) && task.State == proto.TaskStatePending { - _, err := s.taskMgr.PauseTask(s.ctx, task.Key) - if err != nil { - s.logger.Error("pause task failed", zap.Error(err)) - } - task.State = proto.TaskStatePausing - s.task.Store(&task) - } - }) - - failpoint.Inject("pauseTaskAfterRefreshTask", func(val failpoint.Value) { - if val.(bool) && task.State == proto.TaskStateRunning { - _, err := s.taskMgr.PauseTask(s.ctx, task.Key) - if err != nil { - s.logger.Error("pause task failed", zap.Error(err)) - } - task.State = proto.TaskStatePausing - s.task.Store(&task) - } - }) - - switch task.State { - case proto.TaskStateCancelling: - err = s.onCancelling() - case proto.TaskStatePausing: - err = s.onPausing() - case proto.TaskStatePaused: - err = s.onPaused() - // close the scheduler. - if err == nil { - return - } - case proto.TaskStateResuming: - // Case with 2 nodes. - // Here is the timeline - // 1. task in pausing state. - // 2. node1 and node2 start schedulers with task in pausing state without allocatedSlots. - // 3. node1's scheduler transfer the node from pausing to paused state. - // 4. resume the task. - // 5. node2 scheduler call refreshTask and get task with resuming state. - if !s.allocatedSlots { - s.logger.Info("scheduler exit since not allocated slots", zap.Stringer("state", task.State)) - return - } - err = s.onResuming() - case proto.TaskStateReverting: - err = s.onReverting() - case proto.TaskStatePending: - err = s.onPending() - case proto.TaskStateRunning: - // Case with 2 nodes. - // Here is the timeline - // 1. task in pausing state. - // 2. node1 and node2 start schedulers with task in pausing state without allocatedSlots. - // 3. node1's scheduler transfer the node from pausing to paused state. - // 4. resume the task. - // 5. node1 start another scheduler and transfer the node from resuming to running state. - // 6. node2 scheduler call refreshTask and get task with running state. - if !s.allocatedSlots { - s.logger.Info("scheduler exit since not allocated slots", zap.Stringer("state", task.State)) - return - } - err = s.onRunning() - case proto.TaskStateSucceed, proto.TaskStateReverted, proto.TaskStateFailed: - s.onFinished() - return - } - if err != nil { - s.logger.Info("schedule task meet err, reschedule it", zap.Error(err)) - } - - failpoint.InjectCall("mockOwnerChange") - } - } -} - -// handle task in cancelling state, schedule revert subtasks. -func (s *BaseScheduler) onCancelling() error { - task := s.GetTask() - s.logger.Info("on cancelling state", zap.Stringer("state", task.State), zap.String("step", proto.Step2Str(task.Type, task.Step))) - - return s.revertTask(errors.New(taskCancelMsg)) -} - -// handle task in pausing state, cancel all running subtasks. 
-func (s *BaseScheduler) onPausing() error { - task := *s.GetTask() - s.logger.Info("on pausing state", zap.Stringer("state", task.State), zap.String("step", proto.Step2Str(task.Type, task.Step))) - cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, task.ID, task.Step) - if err != nil { - s.logger.Warn("check task failed", zap.Error(err)) - return err - } - runningPendingCnt := cntByStates[proto.SubtaskStateRunning] + cntByStates[proto.SubtaskStatePending] - if runningPendingCnt > 0 { - s.logger.Debug("on pausing state, this task keeps current state", zap.Stringer("state", task.State)) - return nil - } - - s.logger.Info("all running subtasks paused, update the task to paused state") - if err = s.taskMgr.PausedTask(s.ctx, task.ID); err != nil { - return err - } - task.State = proto.TaskStatePaused - s.task.Store(&task) - return nil -} - -// handle task in paused state. -func (s *BaseScheduler) onPaused() error { - task := s.GetTask() - s.logger.Info("on paused state", zap.Stringer("state", task.State), zap.String("step", proto.Step2Str(task.Type, task.Step))) - failpoint.InjectCall("mockDMLExecutionOnPausedState") - return nil -} - -// handle task in resuming state. -func (s *BaseScheduler) onResuming() error { - task := *s.GetTask() - s.logger.Info("on resuming state", zap.Stringer("state", task.State), zap.String("step", proto.Step2Str(task.Type, task.Step))) - cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, task.ID, task.Step) - if err != nil { - s.logger.Warn("check task failed", zap.Error(err)) - return err - } - if cntByStates[proto.SubtaskStatePaused] == 0 { - // Finish the resuming process. - s.logger.Info("all paused tasks converted to pending state, update the task to running state") - if err = s.taskMgr.ResumedTask(s.ctx, task.ID); err != nil { - return err - } - task.State = proto.TaskStateRunning - s.task.Store(&task) - return nil - } - - return s.taskMgr.ResumeSubtasks(s.ctx, task.ID) -} - -// handle task in reverting state, check all revert subtasks finishes. -func (s *BaseScheduler) onReverting() error { - task := *s.GetTask() - s.logger.Debug("on reverting state", zap.Stringer("state", task.State), zap.String("step", proto.Step2Str(task.Type, task.Step))) - cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, task.ID, task.Step) - if err != nil { - s.logger.Warn("check task failed", zap.Error(err)) - return err - } - runnableSubtaskCnt := cntByStates[proto.SubtaskStatePending] + cntByStates[proto.SubtaskStateRunning] - if runnableSubtaskCnt == 0 { - if err = s.OnDone(s.ctx, s, &task); err != nil { - return errors.Trace(err) - } - if err = s.taskMgr.RevertedTask(s.ctx, task.ID); err != nil { - return errors.Trace(err) - } - task.State = proto.TaskStateReverted - s.task.Store(&task) - return nil - } - // Wait all subtasks in this step finishes. - s.OnTick(s.ctx, &task) - s.logger.Debug("on reverting state, this task keeps current state", zap.Stringer("state", task.State)) - return nil -} - -// handle task in pending state, schedule subtasks. -func (s *BaseScheduler) onPending() error { - task := s.GetTask() - s.logger.Debug("on pending state", zap.Stringer("state", task.State), zap.String("step", proto.Step2Str(task.Type, task.Step))) - return s.switch2NextStep() -} - -// handle task in running state, check all running subtasks finishes. -// If subtasks finished, run into the next step. 
-func (s *BaseScheduler) onRunning() error { - task := s.GetTask() - s.logger.Debug("on running state", - zap.Stringer("state", task.State), - zap.String("step", proto.Step2Str(task.Type, task.Step))) - // check current step finishes. - cntByStates, err := s.taskMgr.GetSubtaskCntGroupByStates(s.ctx, task.ID, task.Step) - if err != nil { - s.logger.Warn("check task failed", zap.Error(err)) - return err - } - if cntByStates[proto.SubtaskStateFailed] > 0 || cntByStates[proto.SubtaskStateCanceled] > 0 { - subTaskErrs, err := s.taskMgr.GetSubtaskErrors(s.ctx, task.ID) - if err != nil { - s.logger.Warn("collect subtask error failed", zap.Error(err)) - return err - } - if len(subTaskErrs) > 0 { - s.logger.Warn("subtasks encounter errors", zap.Errors("subtask-errs", subTaskErrs)) - // we only store the first error as task error. - return s.revertTask(subTaskErrs[0]) - } - } else if s.isStepSucceed(cntByStates) { - return s.switch2NextStep() - } - - // Wait all subtasks in this step finishes. - s.OnTick(s.ctx, task) - s.logger.Debug("on running state, this task keeps current state", zap.Stringer("state", task.State)) - return nil -} - -func (s *BaseScheduler) onFinished() { - task := s.GetTask() - metrics.UpdateMetricsForFinishTask(task) - s.logger.Debug("schedule task, task is finished", zap.Stringer("state", task.State)) -} - -func (s *BaseScheduler) switch2NextStep() error { - task := *s.GetTask() - nextStep := s.GetNextStep(&task.TaskBase) - s.logger.Info("switch to next step", - zap.String("current-step", proto.Step2Str(task.Type, task.Step)), - zap.String("next-step", proto.Step2Str(task.Type, nextStep))) - - if nextStep == proto.StepDone { - if err := s.OnDone(s.ctx, s, &task); err != nil { - return errors.Trace(err) - } - if err := s.taskMgr.SucceedTask(s.ctx, task.ID); err != nil { - return errors.Trace(err) - } - task.Step = nextStep - task.State = proto.TaskStateSucceed - s.task.Store(&task) - return nil - } - - nodes := s.nodeMgr.getNodes() - nodeIDs := filterByScope(nodes, task.TargetScope) - eligibleNodes, err := getEligibleNodes(s.ctx, s, nodeIDs) - if err != nil { - return err - } - - s.logger.Info("eligible instances", zap.Int("num", len(eligibleNodes))) - if len(eligibleNodes) == 0 { - return errors.New("no available TiDB node to dispatch subtasks") - } - - metas, err := s.OnNextSubtasksBatch(s.ctx, s, &task, eligibleNodes, nextStep) - if err != nil { - s.logger.Warn("generate part of subtasks failed", zap.Error(err)) - return s.handlePlanErr(err) - } - - if err = s.scheduleSubTask(&task, nextStep, metas, eligibleNodes); err != nil { - return err - } - task.Step = nextStep - task.State = proto.TaskStateRunning - // and OnNextSubtasksBatch might change meta of task. - s.task.Store(&task) - return nil -} - -func (s *BaseScheduler) scheduleSubTask( - task *proto.Task, - subtaskStep proto.Step, - metas [][]byte, - eligibleNodes []string) error { - s.logger.Info("schedule subtasks", - zap.Stringer("state", task.State), - zap.String("step", proto.Step2Str(task.Type, subtaskStep)), - zap.Int("concurrency", task.Concurrency), - zap.Int("subtasks", len(metas))) - - // the scheduled node of the subtask might not be optimal, as we run all - // scheduler in parallel, and update might be called too many times when - // multiple tasks are switching to next step. - // balancer will assign the subtasks to the right instance according to - // the system load of all nodes. 
- if err := s.slotMgr.update(s.ctx, s.nodeMgr, s.taskMgr); err != nil { - return err - } - adjustedEligibleNodes := s.slotMgr.adjustEligibleNodes(eligibleNodes, task.Concurrency) - var size uint64 - subTasks := make([]*proto.Subtask, 0, len(metas)) - for i, meta := range metas { - // we assign the subtask to the instance in a round-robin way. - pos := i % len(adjustedEligibleNodes) - instanceID := adjustedEligibleNodes[pos] - s.logger.Debug("create subtasks", zap.String("instanceID", instanceID)) - subTasks = append(subTasks, proto.NewSubtask( - subtaskStep, task.ID, task.Type, instanceID, task.Concurrency, meta, i+1)) - - size += uint64(len(meta)) - } - failpoint.Inject("cancelBeforeUpdateTask", func() { - _ = s.taskMgr.CancelTask(s.ctx, task.ID) - }) - - // as other fields and generated key and index KV takes space too, we limit - // the size of subtasks to 80% of the transaction limit. - limit := max(uint64(float64(kv.TxnTotalSizeLimit.Load())*0.8), 1) - fn := s.taskMgr.SwitchTaskStep - if size >= limit { - // On default, transaction size limit is controlled by tidb_mem_quota_query - // which is 1G on default, so it's unlikely to reach this limit, but in - // case user set txn-total-size-limit explicitly, we insert in batch. - s.logger.Info("subtasks size exceed limit, will insert in batch", - zap.Uint64("size", size), zap.Uint64("limit", limit)) - fn = s.taskMgr.SwitchTaskStepInBatch - } - - backoffer := backoff.NewExponential(RetrySQLInterval, 2, RetrySQLMaxInterval) - return handle.RunWithRetry(s.ctx, RetrySQLTimes, backoffer, s.logger, - func(context.Context) (bool, error) { - err := fn(s.ctx, task, proto.TaskStateRunning, subtaskStep, subTasks) - if errors.Cause(err) == storage.ErrUnstableSubtasks { - return false, err - } - return true, err - }, - ) -} - -func (s *BaseScheduler) handlePlanErr(err error) error { - task := *s.GetTask() - s.logger.Warn("generate plan failed", zap.Error(err), zap.Stringer("state", task.State)) - if s.IsRetryableErr(err) { - return err - } - return s.revertTask(err) -} - -func (s *BaseScheduler) revertTask(taskErr error) error { - task := *s.GetTask() - if err := s.taskMgr.RevertTask(s.ctx, task.ID, task.State, taskErr); err != nil { - return err - } - task.State = proto.TaskStateReverting - task.Error = taskErr - s.task.Store(&task) - return nil -} - -// MockServerInfo exported for scheduler_test.go -var MockServerInfo atomic.Pointer[[]string] - -// GetLiveExecIDs returns all live executor node IDs. 
-func GetLiveExecIDs(ctx context.Context) ([]string, error) { - failpoint.Inject("mockTaskExecutorNodes", func() { - failpoint.Return(*MockServerInfo.Load(), nil) - }) - serverInfos, err := generateTaskExecutorNodes(ctx) - if err != nil { - return nil, err - } - execIDs := make([]string, 0, len(serverInfos)) - for _, info := range serverInfos { - execIDs = append(execIDs, disttaskutil.GenerateExecID(info)) - } - return execIDs, nil -} - -func generateTaskExecutorNodes(ctx context.Context) (serverNodes []*infosync.ServerInfo, err error) { - var serverInfos map[string]*infosync.ServerInfo - _, etcd := ctx.Value("etcd").(bool) - if intest.InTest && !etcd { - serverInfos = infosync.MockGlobalServerInfoManagerEntry.GetAllServerInfo() - } else { - serverInfos, err = infosync.GetAllServerInfo(ctx) - } - if err != nil { - return nil, err - } - if len(serverInfos) == 0 { - return nil, errors.New("not found instance") - } - - serverNodes = make([]*infosync.ServerInfo, 0, len(serverInfos)) - for _, serverInfo := range serverInfos { - serverNodes = append(serverNodes, serverInfo) - } - return serverNodes, nil -} - -// GetPreviousSubtaskMetas get subtask metas from specific step. -func (s *BaseScheduler) GetPreviousSubtaskMetas(taskID int64, step proto.Step) ([][]byte, error) { - previousSubtasks, err := s.taskMgr.GetAllSubtasksByStepAndState(s.ctx, taskID, step, proto.SubtaskStateSucceed) - if err != nil { - s.logger.Warn("get previous succeed subtask failed", zap.String("step", proto.Step2Str(s.GetTask().Type, step))) - return nil, err - } - previousSubtaskMetas := make([][]byte, 0, len(previousSubtasks)) - for _, subtask := range previousSubtasks { - previousSubtaskMetas = append(previousSubtaskMetas, subtask.Meta) - } - return previousSubtaskMetas, nil -} - -// WithNewSession executes the function with a new session. -func (s *BaseScheduler) WithNewSession(fn func(se sessionctx.Context) error) error { - return s.taskMgr.WithNewSession(fn) -} - -// WithNewTxn executes the fn in a new transaction. -func (s *BaseScheduler) WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error { - return s.taskMgr.WithNewTxn(ctx, fn) -} - -func (*BaseScheduler) isStepSucceed(cntByStates map[proto.SubtaskState]int64) bool { - _, ok := cntByStates[proto.SubtaskStateSucceed] - return len(cntByStates) == 0 || (len(cntByStates) == 1 && ok) -} - -// IsCancelledErr checks if the error is a cancelled error. -func IsCancelledErr(err error) bool { - return strings.Contains(err.Error(), taskCancelMsg) -} - -// getEligibleNodes returns the eligible(live) nodes for the task. -// if the task can only be scheduled to some specific nodes, return them directly, -// we don't care liveliness of them. 
-func getEligibleNodes(ctx context.Context, sch Scheduler, managedNodes []string) ([]string, error) { - serverNodes, err := sch.GetEligibleInstances(ctx, sch.GetTask()) - if err != nil { - return nil, err - } - logutil.BgLogger().Debug("eligible instances", zap.Int("num", len(serverNodes))) - if len(serverNodes) == 0 { - serverNodes = managedNodes - } - - return serverNodes, nil -} diff --git a/pkg/disttask/framework/scheduler/scheduler_manager.go b/pkg/disttask/framework/scheduler/scheduler_manager.go index 8e7b4933fbe61..186fc77702a7d 100644 --- a/pkg/disttask/framework/scheduler/scheduler_manager.go +++ b/pkg/disttask/framework/scheduler/scheduler_manager.go @@ -301,13 +301,13 @@ func (sm *Manager) failTask(id int64, currState proto.TaskState, err error) { func (sm *Manager) gcSubtaskHistoryTableLoop() { historySubtaskTableGcInterval := defaultHistorySubtaskTableGcInterval - if val, _err_ := failpoint.Eval(_curpkg_("historySubtaskTableGcInterval")); _err_ == nil { + failpoint.Inject("historySubtaskTableGcInterval", func(val failpoint.Value) { if seconds, ok := val.(int); ok { historySubtaskTableGcInterval = time.Second * time.Duration(seconds) } <-WaitTaskFinished - } + }) sm.logger.Info("subtask table gc loop start") ticker := time.NewTicker(historySubtaskTableGcInterval) @@ -413,9 +413,9 @@ func (sm *Manager) doCleanupTask() { sm.logger.Warn("cleanup routine failed", zap.Error(err)) return } - if _, _err_ := failpoint.Eval(_curpkg_("WaitCleanUpFinished")); _err_ == nil { + failpoint.Inject("WaitCleanUpFinished", func() { WaitCleanUpFinished <- struct{}{} - } + }) sm.logger.Info("cleanup routine success") } @@ -442,9 +442,9 @@ func (sm *Manager) cleanupFinishedTasks(tasks []*proto.Task) error { sm.logger.Warn("cleanup routine failed", zap.Error(errors.Trace(firstErr))) } - if _, _err_ := failpoint.Eval(_curpkg_("mockTransferErr")); _err_ == nil { - return errors.New("transfer err") - } + failpoint.Inject("mockTransferErr", func() { + failpoint.Return(errors.New("transfer err")) + }) return sm.taskMgr.TransferTasks2History(sm.ctx, cleanedTasks) } diff --git a/pkg/disttask/framework/scheduler/scheduler_manager.go__failpoint_stash__ b/pkg/disttask/framework/scheduler/scheduler_manager.go__failpoint_stash__ deleted file mode 100644 index 186fc77702a7d..0000000000000 --- a/pkg/disttask/framework/scheduler/scheduler_manager.go__failpoint_stash__ +++ /dev/null @@ -1,485 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package scheduler - -import ( - "context" - "slices" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/log" - "github.com/pingcap/tidb/pkg/disttask/framework/handle" - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - "github.com/pingcap/tidb/pkg/metrics" - tidbutil "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tidb/pkg/util/syncutil" - "go.uber.org/zap" -) - -var ( - // CheckTaskRunningInterval is the interval for loading tasks. - // It is exported for testing. - CheckTaskRunningInterval = 3 * time.Second - // defaultHistorySubtaskTableGcInterval is the interval of gc history subtask table. - defaultHistorySubtaskTableGcInterval = 24 * time.Hour - // DefaultCleanUpInterval is the interval of cleanup routine. - DefaultCleanUpInterval = 10 * time.Minute - defaultCollectMetricsInterval = 5 * time.Second -) - -// WaitTaskFinished is used to sync the test. -var WaitTaskFinished = make(chan struct{}) - -func (sm *Manager) getSchedulerCount() int { - sm.mu.RLock() - defer sm.mu.RUnlock() - return len(sm.mu.schedulerMap) -} - -func (sm *Manager) addScheduler(taskID int64, scheduler Scheduler) { - sm.mu.Lock() - defer sm.mu.Unlock() - sm.mu.schedulerMap[taskID] = scheduler - sm.mu.schedulers = append(sm.mu.schedulers, scheduler) - slices.SortFunc(sm.mu.schedulers, func(i, j Scheduler) int { - return i.GetTask().CompareTask(j.GetTask()) - }) -} - -func (sm *Manager) hasScheduler(taskID int64) bool { - sm.mu.Lock() - defer sm.mu.Unlock() - _, ok := sm.mu.schedulerMap[taskID] - return ok -} - -func (sm *Manager) delScheduler(taskID int64) { - sm.mu.Lock() - defer sm.mu.Unlock() - delete(sm.mu.schedulerMap, taskID) - for i, scheduler := range sm.mu.schedulers { - if scheduler.GetTask().ID == taskID { - sm.mu.schedulers = append(sm.mu.schedulers[:i], sm.mu.schedulers[i+1:]...) - break - } - } -} - -func (sm *Manager) clearSchedulers() { - sm.mu.Lock() - defer sm.mu.Unlock() - sm.mu.schedulerMap = make(map[int64]Scheduler) - sm.mu.schedulers = sm.mu.schedulers[:0] -} - -// getSchedulers returns a copy of schedulers. -func (sm *Manager) getSchedulers() []Scheduler { - sm.mu.RLock() - defer sm.mu.RUnlock() - res := make([]Scheduler, len(sm.mu.schedulers)) - copy(res, sm.mu.schedulers) - return res -} - -// Manager manage a bunch of schedulers. -// Scheduler schedule and monitor tasks. -// The scheduling task number is limited by size of gPool. -type Manager struct { - ctx context.Context - cancel context.CancelFunc - taskMgr TaskManager - wg tidbutil.WaitGroupWrapper - schedulerWG tidbutil.WaitGroupWrapper - slotMgr *SlotManager - nodeMgr *NodeManager - balancer *balancer - initialized bool - // serverID, it's value is ip:port now. - serverID string - logger *zap.Logger - - finishCh chan struct{} - - mu struct { - syncutil.RWMutex - schedulerMap map[int64]Scheduler - // in task order - schedulers []Scheduler - } -} - -// NewManager creates a scheduler struct. 
-func NewManager(ctx context.Context, taskMgr TaskManager, serverID string) *Manager { - logger := log.L() - if intest.InTest { - logger = log.L().With(zap.String("server-id", serverID)) - } - subCtx, cancel := context.WithCancel(ctx) - slotMgr := newSlotManager() - nodeMgr := newNodeManager(serverID) - schedulerManager := &Manager{ - ctx: subCtx, - cancel: cancel, - taskMgr: taskMgr, - serverID: serverID, - slotMgr: slotMgr, - nodeMgr: nodeMgr, - balancer: newBalancer(Param{ - taskMgr: taskMgr, - nodeMgr: nodeMgr, - slotMgr: slotMgr, - serverID: serverID, - }), - logger: logger, - finishCh: make(chan struct{}, proto.MaxConcurrentTask), - } - schedulerManager.mu.schedulerMap = make(map[int64]Scheduler) - - return schedulerManager -} - -// Start the schedulerManager, start the scheduleTaskLoop to start multiple schedulers. -func (sm *Manager) Start() { - // init cached managed nodes - sm.nodeMgr.refreshNodes(sm.ctx, sm.taskMgr, sm.slotMgr) - - sm.wg.Run(sm.scheduleTaskLoop) - sm.wg.Run(sm.gcSubtaskHistoryTableLoop) - sm.wg.Run(sm.cleanupTaskLoop) - sm.wg.Run(sm.collectLoop) - sm.wg.Run(func() { - sm.nodeMgr.maintainLiveNodesLoop(sm.ctx, sm.taskMgr) - }) - sm.wg.Run(func() { - sm.nodeMgr.refreshNodesLoop(sm.ctx, sm.taskMgr, sm.slotMgr) - }) - sm.wg.Run(func() { - sm.balancer.balanceLoop(sm.ctx, sm) - }) - sm.initialized = true -} - -// Cancel cancels the scheduler manager. -// used in test to simulate tidb node shutdown. -func (sm *Manager) Cancel() { - sm.cancel() -} - -// Stop the schedulerManager. -func (sm *Manager) Stop() { - sm.cancel() - sm.schedulerWG.Wait() - sm.wg.Wait() - sm.clearSchedulers() - sm.initialized = false - close(sm.finishCh) -} - -// Initialized check the manager initialized. -func (sm *Manager) Initialized() bool { - return sm.initialized -} - -// scheduleTaskLoop schedules the tasks. -func (sm *Manager) scheduleTaskLoop() { - sm.logger.Info("schedule task loop start") - ticker := time.NewTicker(CheckTaskRunningInterval) - defer ticker.Stop() - for { - select { - case <-sm.ctx.Done(): - sm.logger.Info("schedule task loop exits") - return - case <-ticker.C: - case <-handle.TaskChangedCh: - } - - taskCnt := sm.getSchedulerCount() - if taskCnt >= proto.MaxConcurrentTask { - sm.logger.Debug("scheduled tasks reached limit", - zap.Int("current", taskCnt), zap.Int("max", proto.MaxConcurrentTask)) - continue - } - - schedulableTasks, err := sm.getSchedulableTasks() - if err != nil { - continue - } - - err = sm.startSchedulers(schedulableTasks) - if err != nil { - continue - } - } -} - -func (sm *Manager) getSchedulableTasks() ([]*proto.TaskBase, error) { - tasks, err := sm.taskMgr.GetTopUnfinishedTasks(sm.ctx) - if err != nil { - sm.logger.Warn("get unfinished tasks failed", zap.Error(err)) - return nil, err - } - - schedulableTasks := make([]*proto.TaskBase, 0, len(tasks)) - for _, task := range tasks { - if sm.hasScheduler(task.ID) { - continue - } - // we check it before start scheduler, so no need to check it again. - // see startScheduler. - // this should not happen normally, unless user modify system table - // directly. 
- if getSchedulerFactory(task.Type) == nil { - sm.logger.Warn("unknown task type", zap.Int64("task-id", task.ID), - zap.Stringer("task-type", task.Type)) - sm.failTask(task.ID, task.State, errors.New("unknown task type")) - continue - } - schedulableTasks = append(schedulableTasks, task) - } - return schedulableTasks, nil -} - -func (sm *Manager) startSchedulers(schedulableTasks []*proto.TaskBase) error { - if len(schedulableTasks) == 0 { - return nil - } - if err := sm.slotMgr.update(sm.ctx, sm.nodeMgr, sm.taskMgr); err != nil { - sm.logger.Warn("update used slot failed", zap.Error(err)) - return err - } - for _, task := range schedulableTasks { - taskCnt := sm.getSchedulerCount() - if taskCnt >= proto.MaxConcurrentTask { - break - } - var reservedExecID string - allocateSlots := true - var ok bool - switch task.State { - case proto.TaskStatePending, proto.TaskStateRunning, proto.TaskStateResuming: - reservedExecID, ok = sm.slotMgr.canReserve(task) - if !ok { - // task of lower rank might be able to be scheduled. - continue - } - // reverting/cancelling/pausing - default: - allocateSlots = false - sm.logger.Info("start scheduler without allocating slots", - zap.Int64("task-id", task.ID), zap.Stringer("state", task.State)) - } - - metrics.DistTaskGauge.WithLabelValues(task.Type.String(), metrics.SchedulingStatus).Inc() - metrics.UpdateMetricsForScheduleTask(task.ID, task.Type) - sm.startScheduler(task, allocateSlots, reservedExecID) - } - return nil -} - -func (sm *Manager) failTask(id int64, currState proto.TaskState, err error) { - if err2 := sm.taskMgr.FailTask(sm.ctx, id, currState, err); err2 != nil { - sm.logger.Warn("failed to update task state to failed", - zap.Int64("task-id", id), zap.Error(err2)) - } -} - -func (sm *Manager) gcSubtaskHistoryTableLoop() { - historySubtaskTableGcInterval := defaultHistorySubtaskTableGcInterval - failpoint.Inject("historySubtaskTableGcInterval", func(val failpoint.Value) { - if seconds, ok := val.(int); ok { - historySubtaskTableGcInterval = time.Second * time.Duration(seconds) - } - - <-WaitTaskFinished - }) - - sm.logger.Info("subtask table gc loop start") - ticker := time.NewTicker(historySubtaskTableGcInterval) - defer ticker.Stop() - for { - select { - case <-sm.ctx.Done(): - sm.logger.Info("subtask history table gc loop exits") - return - case <-ticker.C: - err := sm.taskMgr.GCSubtasks(sm.ctx) - if err != nil { - sm.logger.Warn("subtask history table gc failed", zap.Error(err)) - } else { - sm.logger.Info("subtask history table gc success") - } - } - } -} - -func (sm *Manager) startScheduler(basicTask *proto.TaskBase, allocateSlots bool, reservedExecID string) { - task, err := sm.taskMgr.GetTaskByID(sm.ctx, basicTask.ID) - if err != nil { - sm.logger.Error("get task failed", zap.Int64("task-id", basicTask.ID), zap.Error(err)) - return - } - - schedulerFactory := getSchedulerFactory(task.Type) - scheduler := schedulerFactory(sm.ctx, task, Param{ - taskMgr: sm.taskMgr, - nodeMgr: sm.nodeMgr, - slotMgr: sm.slotMgr, - serverID: sm.serverID, - allocatedSlots: allocateSlots, - }) - if err = scheduler.Init(); err != nil { - sm.logger.Error("init scheduler failed", zap.Error(err)) - sm.failTask(task.ID, task.State, err) - return - } - sm.addScheduler(task.ID, scheduler) - if allocateSlots { - sm.slotMgr.reserve(basicTask, reservedExecID) - } - sm.logger.Info("task scheduler started", zap.Int64("task-id", task.ID)) - sm.schedulerWG.RunWithLog(func() { - defer func() { - scheduler.Close() - sm.delScheduler(task.ID) - if allocateSlots { - 
sm.slotMgr.unReserve(basicTask, reservedExecID) - } - handle.NotifyTaskChange() - sm.logger.Info("task scheduler exit", zap.Int64("task-id", task.ID)) - }() - metrics.UpdateMetricsForRunTask(task) - scheduler.ScheduleTask() - sm.finishCh <- struct{}{} - }) -} - -func (sm *Manager) cleanupTaskLoop() { - sm.logger.Info("cleanup loop start") - ticker := time.NewTicker(DefaultCleanUpInterval) - defer ticker.Stop() - for { - select { - case <-sm.ctx.Done(): - sm.logger.Info("cleanup loop exits") - return - case <-sm.finishCh: - sm.doCleanupTask() - case <-ticker.C: - sm.doCleanupTask() - } - } -} - -// WaitCleanUpFinished is used to sync the test. -var WaitCleanUpFinished = make(chan struct{}, 1) - -// doCleanupTask processes clean up routine defined by each type of tasks and cleanupMeta. -// For example: -// -// tasks with global sort should clean up tmp files stored on S3. -func (sm *Manager) doCleanupTask() { - tasks, err := sm.taskMgr.GetTasksInStates( - sm.ctx, - proto.TaskStateFailed, - proto.TaskStateReverted, - proto.TaskStateSucceed, - ) - if err != nil { - sm.logger.Warn("get task in states failed", zap.Error(err)) - return - } - if len(tasks) == 0 { - return - } - sm.logger.Info("cleanup routine start") - err = sm.cleanupFinishedTasks(tasks) - if err != nil { - sm.logger.Warn("cleanup routine failed", zap.Error(err)) - return - } - failpoint.Inject("WaitCleanUpFinished", func() { - WaitCleanUpFinished <- struct{}{} - }) - sm.logger.Info("cleanup routine success") -} - -func (sm *Manager) cleanupFinishedTasks(tasks []*proto.Task) error { - cleanedTasks := make([]*proto.Task, 0) - var firstErr error - for _, task := range tasks { - sm.logger.Info("cleanup task", zap.Int64("task-id", task.ID)) - cleanupFactory := getSchedulerCleanUpFactory(task.Type) - if cleanupFactory != nil { - cleanup := cleanupFactory() - err := cleanup.CleanUp(sm.ctx, task) - if err != nil { - firstErr = err - break - } - cleanedTasks = append(cleanedTasks, task) - } else { - // if task doesn't register cleanup function, mark it as cleaned. - cleanedTasks = append(cleanedTasks, task) - } - } - if firstErr != nil { - sm.logger.Warn("cleanup routine failed", zap.Error(errors.Trace(firstErr))) - } - - failpoint.Inject("mockTransferErr", func() { - failpoint.Return(errors.New("transfer err")) - }) - - return sm.taskMgr.TransferTasks2History(sm.ctx, cleanedTasks) -} - -func (sm *Manager) collectLoop() { - sm.logger.Info("collect loop start") - ticker := time.NewTicker(defaultCollectMetricsInterval) - defer ticker.Stop() - for { - select { - case <-sm.ctx.Done(): - sm.logger.Info("collect loop exits") - return - case <-ticker.C: - sm.collect() - } - } -} - -func (sm *Manager) collect() { - subtasks, err := sm.taskMgr.GetAllSubtasks(sm.ctx) - if err != nil { - sm.logger.Warn("get all subtasks failed", zap.Error(err)) - return - } - - subtaskCollector.subtaskInfo.Store(&subtasks) -} - -// MockScheduler mock one scheduler for one task, only used for tests. 
-func (sm *Manager) MockScheduler(task *proto.Task) *BaseScheduler { - return NewBaseScheduler(sm.ctx, task, Param{ - taskMgr: sm.taskMgr, - nodeMgr: sm.nodeMgr, - slotMgr: sm.slotMgr, - serverID: sm.serverID, - }) -} diff --git a/pkg/disttask/framework/storage/binding__failpoint_binding__.go b/pkg/disttask/framework/storage/binding__failpoint_binding__.go deleted file mode 100644 index a1a747a15d57f..0000000000000 --- a/pkg/disttask/framework/storage/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package storage - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/disttask/framework/storage/history.go b/pkg/disttask/framework/storage/history.go index de6da56d7b2c7..9129a002da60e 100644 --- a/pkg/disttask/framework/storage/history.go +++ b/pkg/disttask/framework/storage/history.go @@ -83,11 +83,11 @@ func (mgr *TaskManager) TransferTasks2History(ctx context.Context, tasks []*prot // GCSubtasks deletes the history subtask which is older than the given days. func (mgr *TaskManager) GCSubtasks(ctx context.Context) error { subtaskHistoryKeepSeconds := defaultSubtaskKeepDays * 24 * 60 * 60 - if val, _err_ := failpoint.Eval(_curpkg_("subtaskHistoryKeepSeconds")); _err_ == nil { + failpoint.Inject("subtaskHistoryKeepSeconds", func(val failpoint.Value) { if val, ok := val.(int); ok { subtaskHistoryKeepSeconds = val } - } + }) _, err := mgr.ExecuteSQLWithNewSession( ctx, fmt.Sprintf("DELETE FROM mysql.tidb_background_subtask_history WHERE state_update_time < UNIX_TIMESTAMP() - %d ;", subtaskHistoryKeepSeconds), diff --git a/pkg/disttask/framework/storage/history.go__failpoint_stash__ b/pkg/disttask/framework/storage/history.go__failpoint_stash__ deleted file mode 100644 index 9129a002da60e..0000000000000 --- a/pkg/disttask/framework/storage/history.go__failpoint_stash__ +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import ( - "context" - "fmt" - "strings" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/util/sqlexec" -) - -// TransferSubtasks2HistoryWithSession transfer the selected subtasks into tidb_background_subtask_history table by taskID. 
-func (*TaskManager) TransferSubtasks2HistoryWithSession(ctx context.Context, se sessionctx.Context, taskID int64) error { - exec := se.GetSQLExecutor() - _, err := sqlexec.ExecSQL(ctx, exec, `insert into mysql.tidb_background_subtask_history select * from mysql.tidb_background_subtask where task_key = %?`, taskID) - if err != nil { - return err - } - // delete taskID subtask - _, err = sqlexec.ExecSQL(ctx, exec, "delete from mysql.tidb_background_subtask where task_key = %?", taskID) - return err -} - -// TransferTasks2History transfer the selected tasks into tidb_global_task_history table by taskIDs. -func (mgr *TaskManager) TransferTasks2History(ctx context.Context, tasks []*proto.Task) error { - if len(tasks) == 0 { - return nil - } - taskIDStrs := make([]string, 0, len(tasks)) - for _, task := range tasks { - taskIDStrs = append(taskIDStrs, fmt.Sprintf("%d", task.ID)) - } - return mgr.WithNewTxn(ctx, func(se sessionctx.Context) error { - // sensitive data in meta might be redacted, need update first. - exec := se.GetSQLExecutor() - for _, t := range tasks { - _, err := sqlexec.ExecSQL(ctx, exec, ` - update mysql.tidb_global_task - set meta= %?, state_update_time = CURRENT_TIMESTAMP() - where id = %?`, t.Meta, t.ID) - if err != nil { - return err - } - } - _, err := sqlexec.ExecSQL(ctx, exec, ` - insert into mysql.tidb_global_task_history - select * from mysql.tidb_global_task - where id in(`+strings.Join(taskIDStrs, `, `)+`)`) - if err != nil { - return err - } - - _, err = sqlexec.ExecSQL(ctx, exec, ` - delete from mysql.tidb_global_task - where id in(`+strings.Join(taskIDStrs, `, `)+`)`) - - for _, t := range tasks { - err = mgr.TransferSubtasks2HistoryWithSession(ctx, se, t.ID) - if err != nil { - return err - } - } - return err - }) -} - -// GCSubtasks deletes the history subtask which is older than the given days. 
-func (mgr *TaskManager) GCSubtasks(ctx context.Context) error { - subtaskHistoryKeepSeconds := defaultSubtaskKeepDays * 24 * 60 * 60 - failpoint.Inject("subtaskHistoryKeepSeconds", func(val failpoint.Value) { - if val, ok := val.(int); ok { - subtaskHistoryKeepSeconds = val - } - }) - _, err := mgr.ExecuteSQLWithNewSession( - ctx, - fmt.Sprintf("DELETE FROM mysql.tidb_background_subtask_history WHERE state_update_time < UNIX_TIMESTAMP() - %d ;", subtaskHistoryKeepSeconds), - ) - return err -} diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index c28a42b595cf2..d2c6f77e232c3 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -227,9 +227,7 @@ func (mgr *TaskManager) CreateTaskWithSession( } taskID = int64(rs[0].GetUint64(0)) - if _, _err_ := failpoint.Eval(_curpkg_("testSetLastTaskID")); _err_ == nil { - TestLastTaskID.Store(taskID) - } + failpoint.Inject("testSetLastTaskID", func() { TestLastTaskID.Store(taskID) }) return taskID, nil } @@ -647,10 +645,10 @@ func (*TaskManager) insertSubtasks(ctx context.Context, se sessionctx.Context, s if len(subtasks) == 0 { return nil } - if _, _err_ := failpoint.Eval(_curpkg_("waitBeforeInsertSubtasks")); _err_ == nil { + failpoint.Inject("waitBeforeInsertSubtasks", func() { <-TestChannel <-TestChannel - } + }) var ( sb strings.Builder markerList = make([]string, 0, len(subtasks)) diff --git a/pkg/disttask/framework/storage/task_table.go__failpoint_stash__ b/pkg/disttask/framework/storage/task_table.go__failpoint_stash__ deleted file mode 100644 index d2c6f77e232c3..0000000000000 --- a/pkg/disttask/framework/storage/task_table.go__failpoint_stash__ +++ /dev/null @@ -1,809 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package storage - -import ( - "context" - "strconv" - "strings" - "sync/atomic" - - "github.com/docker/go-units" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/sqlexec" - clitutil "github.com/tikv/client-go/v2/util" -) - -const ( - defaultSubtaskKeepDays = 14 - - basicTaskColumns = `t.id, t.task_key, t.type, t.state, t.step, t.priority, t.concurrency, t.create_time, t.target_scope` - // TaskColumns is the columns for task. - // TODO: dispatcher_id will update to scheduler_id later - TaskColumns = basicTaskColumns + `, t.start_time, t.state_update_time, t.meta, t.dispatcher_id, t.error` - // InsertTaskColumns is the columns used in insert task. 
- InsertTaskColumns = `task_key, type, state, priority, concurrency, step, meta, create_time, target_scope` - basicSubtaskColumns = `id, step, task_key, type, exec_id, state, concurrency, create_time, ordinal, start_time` - // SubtaskColumns is the columns for subtask. - SubtaskColumns = basicSubtaskColumns + `, state_update_time, meta, summary` - // InsertSubtaskColumns is the columns used in insert subtask. - InsertSubtaskColumns = `step, task_key, exec_id, meta, state, type, concurrency, ordinal, create_time, checkpoint, summary` -) - -var ( - maxSubtaskBatchSize = 16 * units.MiB - - // ErrUnstableSubtasks is the error when we detected that the subtasks are - // unstable, i.e. count, order and content of the subtasks are changed on - // different call. - ErrUnstableSubtasks = errors.New("unstable subtasks") - - // ErrTaskNotFound is the error when we can't found task. - // i.e. TransferTasks2History move task from tidb_global_task to tidb_global_task_history. - ErrTaskNotFound = errors.New("task not found") - - // ErrTaskAlreadyExists is the error when we submit a task with the same task key. - // i.e. SubmitTask in handle may submit a task twice. - ErrTaskAlreadyExists = errors.New("task already exists") - - // ErrSubtaskNotFound is the error when can't find subtask by subtask_id and execId, - // i.e. scheduler change the subtask's execId when subtask need to balance to other nodes. - ErrSubtaskNotFound = errors.New("subtask not found") -) - -// TaskExecInfo is the execution information of a task, on some exec node. -type TaskExecInfo struct { - *proto.TaskBase - // SubtaskConcurrency is the concurrency of subtask in current task step. - // TODO: will be used when support subtask have smaller concurrency than task, - // TODO: such as post-process of import-into. - // TODO: we might need create one task executor for each step in this case, to alloc - // TODO: minimal resource - SubtaskConcurrency int -} - -// SessionExecutor defines the interface for executing SQLs in a session. -type SessionExecutor interface { - // WithNewSession executes the function with a new session. - WithNewSession(fn func(se sessionctx.Context) error) error - // WithNewTxn executes the fn in a new transaction. - WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error -} - -// TaskHandle provides the interface for operations needed by Scheduler. -// Then we can use scheduler's function in Scheduler interface. -type TaskHandle interface { - // GetPreviousSubtaskMetas gets previous subtask metas. - GetPreviousSubtaskMetas(taskID int64, step proto.Step) ([][]byte, error) - SessionExecutor -} - -// TaskManager is the manager of task and subtask. -type TaskManager struct { - sePool util.SessionPool -} - -var _ SessionExecutor = &TaskManager{} - -var taskManagerInstance atomic.Pointer[TaskManager] - -var ( - // TestLastTaskID is used for test to set the last task ID. - TestLastTaskID atomic.Int64 -) - -// NewTaskManager creates a new task manager. -func NewTaskManager(sePool util.SessionPool) *TaskManager { - return &TaskManager{ - sePool: sePool, - } -} - -// GetTaskManager gets the task manager. -func GetTaskManager() (*TaskManager, error) { - v := taskManagerInstance.Load() - if v == nil { - return nil, errors.New("task manager is not initialized") - } - return v, nil -} - -// SetTaskManager sets the task manager. -func SetTaskManager(is *TaskManager) { - taskManagerInstance.Store(is) -} - -// WithNewSession executes the function with a new session. 
-func (mgr *TaskManager) WithNewSession(fn func(se sessionctx.Context) error) error { - se, err := mgr.sePool.Get() - if err != nil { - return err - } - defer mgr.sePool.Put(se) - return fn(se.(sessionctx.Context)) -} - -// WithNewTxn executes the fn in a new transaction. -func (mgr *TaskManager) WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error { - ctx = clitutil.WithInternalSourceType(ctx, kv.InternalDistTask) - return mgr.WithNewSession(func(se sessionctx.Context) (err error) { - _, err = sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), "begin") - if err != nil { - return err - } - - success := false - defer func() { - sql := "rollback" - if success { - sql = "commit" - } - _, commitErr := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), sql) - if err == nil && commitErr != nil { - err = commitErr - } - }() - - if err = fn(se); err != nil { - return err - } - - success = true - return nil - }) -} - -// ExecuteSQLWithNewSession executes one SQL with new session. -func (mgr *TaskManager) ExecuteSQLWithNewSession(ctx context.Context, sql string, args ...any) (rs []chunk.Row, err error) { - err = mgr.WithNewSession(func(se sessionctx.Context) error { - rs, err = sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), sql, args...) - return err - }) - - if err != nil { - return nil, err - } - - return -} - -// CreateTask adds a new task to task table. -func (mgr *TaskManager) CreateTask(ctx context.Context, key string, tp proto.TaskType, concurrency int, targetScope string, meta []byte) (taskID int64, err error) { - err = mgr.WithNewSession(func(se sessionctx.Context) error { - var err2 error - taskID, err2 = mgr.CreateTaskWithSession(ctx, se, key, tp, concurrency, targetScope, meta) - return err2 - }) - return -} - -// CreateTaskWithSession adds a new task to task table with session. -func (mgr *TaskManager) CreateTaskWithSession( - ctx context.Context, - se sessionctx.Context, - key string, - tp proto.TaskType, - concurrency int, - targetScope string, - meta []byte, -) (taskID int64, err error) { - cpuCount, err := mgr.getCPUCountOfNode(ctx, se) - if err != nil { - return 0, err - } - if concurrency > cpuCount { - return 0, errors.Errorf("task concurrency(%d) larger than cpu count(%d) of managed node", concurrency, cpuCount) - } - _, err = sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), ` - insert into mysql.tidb_global_task(`+InsertTaskColumns+`) - values (%?, %?, %?, %?, %?, %?, %?, CURRENT_TIMESTAMP(), %?)`, - key, tp, proto.TaskStatePending, proto.NormalPriority, concurrency, proto.StepInit, meta, targetScope) - if err != nil { - return 0, err - } - - rs, err := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), "select @@last_insert_id") - if err != nil { - return 0, err - } - - taskID = int64(rs[0].GetUint64(0)) - failpoint.Inject("testSetLastTaskID", func() { TestLastTaskID.Store(taskID) }) - - return taskID, nil -} - -// GetTopUnfinishedTasks implements the scheduler.TaskManager interface. -func (mgr *TaskManager) GetTopUnfinishedTasks(ctx context.Context) ([]*proto.TaskBase, error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, - `select `+basicTaskColumns+` from mysql.tidb_global_task t - where state in (%?, %?, %?, %?, %?, %?) 
- order by priority asc, create_time asc, id asc - limit %?`, - proto.TaskStatePending, - proto.TaskStateRunning, - proto.TaskStateReverting, - proto.TaskStateCancelling, - proto.TaskStatePausing, - proto.TaskStateResuming, - proto.MaxConcurrentTask*2, - ) - if err != nil { - return nil, err - } - - tasks := make([]*proto.TaskBase, 0, len(rs)) - for _, r := range rs { - tasks = append(tasks, row2TaskBasic(r)) - } - return tasks, nil -} - -// GetTaskExecInfoByExecID implements the scheduler.TaskManager interface. -func (mgr *TaskManager) GetTaskExecInfoByExecID(ctx context.Context, execID string) ([]*TaskExecInfo, error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, - `select `+basicTaskColumns+`, max(st.concurrency) - from mysql.tidb_global_task t join mysql.tidb_background_subtask st - on t.id = st.task_key and t.step = st.step - where t.state in (%?, %?, %?) and st.state in (%?, %?) and st.exec_id = %? - group by t.id - order by priority asc, create_time asc, id asc`, - proto.TaskStateRunning, proto.TaskStateReverting, proto.TaskStatePausing, - proto.SubtaskStatePending, proto.SubtaskStateRunning, execID) - if err != nil { - return nil, err - } - - res := make([]*TaskExecInfo, 0, len(rs)) - for _, r := range rs { - res = append(res, &TaskExecInfo{ - TaskBase: row2TaskBasic(r), - SubtaskConcurrency: int(r.GetInt64(9)), - }) - } - return res, nil -} - -// GetTasksInStates gets the tasks in the states(order by priority asc, create_time acs, id asc). -func (mgr *TaskManager) GetTasksInStates(ctx context.Context, states ...any) (task []*proto.Task, err error) { - if len(states) == 0 { - return task, nil - } - - rs, err := mgr.ExecuteSQLWithNewSession(ctx, - "select "+TaskColumns+" from mysql.tidb_global_task t "+ - "where state in ("+strings.Repeat("%?,", len(states)-1)+"%?)"+ - " order by priority asc, create_time asc, id asc", states...) - if err != nil { - return task, err - } - - for _, r := range rs { - task = append(task, Row2Task(r)) - } - return task, nil -} - -// GetTaskByID gets the task by the task ID. -func (mgr *TaskManager) GetTaskByID(ctx context.Context, taskID int64) (task *proto.Task, err error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task t where id = %?", taskID) - if err != nil { - return task, err - } - if len(rs) == 0 { - return nil, ErrTaskNotFound - } - - return Row2Task(rs[0]), nil -} - -// GetTaskBaseByID implements the TaskManager.GetTaskBaseByID interface. -func (mgr *TaskManager) GetTaskBaseByID(ctx context.Context, taskID int64) (task *proto.TaskBase, err error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+basicTaskColumns+" from mysql.tidb_global_task t where id = %?", taskID) - if err != nil { - return task, err - } - if len(rs) == 0 { - return nil, ErrTaskNotFound - } - - return row2TaskBasic(rs[0]), nil -} - -// GetTaskByIDWithHistory gets the task by the task ID from both tidb_global_task and tidb_global_task_history. -func (mgr *TaskManager) GetTaskByIDWithHistory(ctx context.Context, taskID int64) (task *proto.Task, err error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task t where id = %? 
"+ - "union select "+TaskColumns+" from mysql.tidb_global_task_history t where id = %?", taskID, taskID) - if err != nil { - return task, err - } - if len(rs) == 0 { - return nil, ErrTaskNotFound - } - - return Row2Task(rs[0]), nil -} - -// GetTaskBaseByIDWithHistory gets the task by the task ID from both tidb_global_task and tidb_global_task_history. -func (mgr *TaskManager) GetTaskBaseByIDWithHistory(ctx context.Context, taskID int64) (task *proto.TaskBase, err error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+basicTaskColumns+" from mysql.tidb_global_task t where id = %? "+ - "union select "+basicTaskColumns+" from mysql.tidb_global_task_history t where id = %?", taskID, taskID) - if err != nil { - return task, err - } - if len(rs) == 0 { - return nil, ErrTaskNotFound - } - - return row2TaskBasic(rs[0]), nil -} - -// GetTaskByKey gets the task by the task key. -func (mgr *TaskManager) GetTaskByKey(ctx context.Context, key string) (task *proto.Task, err error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task t where task_key = %?", key) - if err != nil { - return task, err - } - if len(rs) == 0 { - return nil, ErrTaskNotFound - } - - return Row2Task(rs[0]), nil -} - -// GetTaskByKeyWithHistory gets the task from history table by the task key. -func (mgr *TaskManager) GetTaskByKeyWithHistory(ctx context.Context, key string) (task *proto.Task, err error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task t where task_key = %?"+ - "union select "+TaskColumns+" from mysql.tidb_global_task_history t where task_key = %?", key, key) - if err != nil { - return task, err - } - if len(rs) == 0 { - return nil, ErrTaskNotFound - } - - return Row2Task(rs[0]), nil -} - -// GetTaskBaseByKeyWithHistory gets the task base from history table by the task key. -func (mgr *TaskManager) GetTaskBaseByKeyWithHistory(ctx context.Context, key string) (task *proto.TaskBase, err error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+basicTaskColumns+" from mysql.tidb_global_task t where task_key = %?"+ - "union select "+basicTaskColumns+" from mysql.tidb_global_task_history t where task_key = %?", key, key) - if err != nil { - return task, err - } - if len(rs) == 0 { - return nil, ErrTaskNotFound - } - - return row2TaskBasic(rs[0]), nil -} - -// GetSubtasksByExecIDAndStepAndStates gets all subtasks by given states on one node. -func (mgr *TaskManager) GetSubtasksByExecIDAndStepAndStates(ctx context.Context, execID string, taskID int64, step proto.Step, states ...proto.SubtaskState) ([]*proto.Subtask, error) { - args := []any{execID, taskID, step} - for _, state := range states { - args = append(args, state) - } - rs, err := mgr.ExecuteSQLWithNewSession(ctx, `select `+SubtaskColumns+` from mysql.tidb_background_subtask - where exec_id = %? and task_key = %? and step = %? - and state in (`+strings.Repeat("%?,", len(states)-1)+"%?)", args...) - if err != nil { - return nil, err - } - - subtasks := make([]*proto.Subtask, len(rs)) - for i, row := range rs { - subtasks[i] = Row2SubTask(row) - } - return subtasks, nil -} - -// GetFirstSubtaskInStates gets the first subtask by given states. 
-func (mgr *TaskManager) GetFirstSubtaskInStates(ctx context.Context, tidbID string, taskID int64, step proto.Step, states ...proto.SubtaskState) (*proto.Subtask, error) { - args := []any{tidbID, taskID, step} - for _, state := range states { - args = append(args, state) - } - rs, err := mgr.ExecuteSQLWithNewSession(ctx, `select `+SubtaskColumns+` from mysql.tidb_background_subtask - where exec_id = %? and task_key = %? and step = %? - and state in (`+strings.Repeat("%?,", len(states)-1)+"%?) limit 1", args...) - if err != nil { - return nil, err - } - - if len(rs) == 0 { - return nil, nil - } - return Row2SubTask(rs[0]), nil -} - -// GetActiveSubtasks implements TaskManager.GetActiveSubtasks. -func (mgr *TaskManager) GetActiveSubtasks(ctx context.Context, taskID int64) ([]*proto.SubtaskBase, error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, ` - select `+basicSubtaskColumns+` from mysql.tidb_background_subtask - where task_key = %? and state in (%?, %?)`, - taskID, proto.SubtaskStatePending, proto.SubtaskStateRunning) - if err != nil { - return nil, err - } - subtasks := make([]*proto.SubtaskBase, 0, len(rs)) - for _, r := range rs { - subtasks = append(subtasks, row2BasicSubTask(r)) - } - return subtasks, nil -} - -// GetAllSubtasksByStepAndState gets the subtask by step and state. -func (mgr *TaskManager) GetAllSubtasksByStepAndState(ctx context.Context, taskID int64, step proto.Step, state proto.SubtaskState) ([]*proto.Subtask, error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, `select `+SubtaskColumns+` from mysql.tidb_background_subtask - where task_key = %? and state = %? and step = %?`, - taskID, state, step) - if err != nil { - return nil, err - } - if len(rs) == 0 { - return nil, nil - } - subtasks := make([]*proto.Subtask, 0, len(rs)) - for _, r := range rs { - subtasks = append(subtasks, Row2SubTask(r)) - } - return subtasks, nil -} - -// GetSubtaskRowCount gets the subtask row count. -func (mgr *TaskManager) GetSubtaskRowCount(ctx context.Context, taskID int64, step proto.Step) (int64, error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, `select - cast(sum(json_extract(summary, '$.row_count')) as signed) as row_count - from mysql.tidb_background_subtask where task_key = %? and step = %?`, - taskID, step) - if err != nil { - return 0, err - } - if len(rs) == 0 { - return 0, nil - } - return rs[0].GetInt64(0), nil -} - -// UpdateSubtaskRowCount updates the subtask row count. -func (mgr *TaskManager) UpdateSubtaskRowCount(ctx context.Context, subtaskID int64, rowCount int64) error { - _, err := mgr.ExecuteSQLWithNewSession(ctx, - `update mysql.tidb_background_subtask - set summary = json_set(summary, '$.row_count', %?) where id = %?`, - rowCount, subtaskID) - return err -} - -// GetSubtaskCntGroupByStates gets the subtask count by states. -func (mgr *TaskManager) GetSubtaskCntGroupByStates(ctx context.Context, taskID int64, step proto.Step) (map[proto.SubtaskState]int64, error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, ` - select state, count(*) - from mysql.tidb_background_subtask - where task_key = %? and step = %? - group by state`, - taskID, step) - if err != nil { - return nil, err - } - - res := make(map[proto.SubtaskState]int64, len(rs)) - for _, r := range rs { - state := proto.SubtaskState(r.GetString(0)) - res[state] = r.GetInt64(1) - } - - return res, nil -} - -// GetSubtaskErrors gets subtasks' errors. 
-func (mgr *TaskManager) GetSubtaskErrors(ctx context.Context, taskID int64) ([]error, error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, - `select error from mysql.tidb_background_subtask - where task_key = %? AND state in (%?, %?)`, taskID, proto.SubtaskStateFailed, proto.SubtaskStateCanceled) - if err != nil { - return nil, err - } - subTaskErrors := make([]error, 0, len(rs)) - for _, row := range rs { - if row.IsNull(0) { - subTaskErrors = append(subTaskErrors, nil) - continue - } - errBytes := row.GetBytes(0) - if len(errBytes) == 0 { - subTaskErrors = append(subTaskErrors, nil) - continue - } - stdErr := errors.Normalize("") - err := stdErr.UnmarshalJSON(errBytes) - if err != nil { - return nil, err - } - subTaskErrors = append(subTaskErrors, stdErr) - } - - return subTaskErrors, nil -} - -// HasSubtasksInStates checks if there are subtasks in the states. -func (mgr *TaskManager) HasSubtasksInStates(ctx context.Context, tidbID string, taskID int64, step proto.Step, states ...proto.SubtaskState) (bool, error) { - args := []any{tidbID, taskID, step} - for _, state := range states { - args = append(args, state) - } - rs, err := mgr.ExecuteSQLWithNewSession(ctx, `select 1 from mysql.tidb_background_subtask - where exec_id = %? and task_key = %? and step = %? - and state in (`+strings.Repeat("%?,", len(states)-1)+"%?) limit 1", args...) - if err != nil { - return false, err - } - - return len(rs) > 0, nil -} - -// UpdateSubtasksExecIDs update subtasks' execID. -func (mgr *TaskManager) UpdateSubtasksExecIDs(ctx context.Context, subtasks []*proto.SubtaskBase) error { - // skip the update process. - if len(subtasks) == 0 { - return nil - } - err := mgr.WithNewTxn(ctx, func(se sessionctx.Context) error { - for _, subtask := range subtasks { - _, err := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), ` - update mysql.tidb_background_subtask - set exec_id = %? - where id = %? and state = %?`, - subtask.ExecID, subtask.ID, subtask.State) - if err != nil { - return err - } - } - return nil - }) - return err -} - -// SwitchTaskStep implements the scheduler.TaskManager interface. -func (mgr *TaskManager) SwitchTaskStep( - ctx context.Context, - task *proto.Task, - nextState proto.TaskState, - nextStep proto.Step, - subtasks []*proto.Subtask, -) error { - return mgr.WithNewTxn(ctx, func(se sessionctx.Context) error { - vars := se.GetSessionVars() - if vars.MemQuotaQuery < variable.DefTiDBMemQuotaQuery { - bak := vars.MemQuotaQuery - if err := vars.SetSystemVar(variable.TiDBMemQuotaQuery, - strconv.Itoa(variable.DefTiDBMemQuotaQuery)); err != nil { - return err - } - defer func() { - _ = vars.SetSystemVar(variable.TiDBMemQuotaQuery, strconv.Itoa(int(bak))) - }() - } - err := mgr.updateTaskStateStep(ctx, se, task, nextState, nextStep) - if err != nil { - return err - } - if vars.StmtCtx.AffectedRows() == 0 { - // on network partition or owner change, there might be multiple - // schedulers for the same task, if other scheduler has switched - // the task to next step, skip the update process. - // Or when there is no such task. - return nil - } - return mgr.insertSubtasks(ctx, se, subtasks) - }) -} - -func (*TaskManager) updateTaskStateStep(ctx context.Context, se sessionctx.Context, - task *proto.Task, nextState proto.TaskState, nextStep proto.Step) error { - var extraUpdateStr string - if task.State == proto.TaskStatePending { - extraUpdateStr = `start_time = CURRENT_TIMESTAMP(),` - } - // TODO: during generating subtask, task meta might change, maybe move meta - // update to another place. 
- _, err := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), ` - update mysql.tidb_global_task - set state = %?, - step = %?, `+extraUpdateStr+` - state_update_time = CURRENT_TIMESTAMP(), - meta = %? - where id = %? and state = %? and step = %?`, - nextState, nextStep, task.Meta, task.ID, task.State, task.Step) - return err -} - -// TestChannel is used for test. -var TestChannel = make(chan struct{}) - -func (*TaskManager) insertSubtasks(ctx context.Context, se sessionctx.Context, subtasks []*proto.Subtask) error { - if len(subtasks) == 0 { - return nil - } - failpoint.Inject("waitBeforeInsertSubtasks", func() { - <-TestChannel - <-TestChannel - }) - var ( - sb strings.Builder - markerList = make([]string, 0, len(subtasks)) - args = make([]any, 0, len(subtasks)*7) - ) - sb.WriteString(`insert into mysql.tidb_background_subtask(` + InsertSubtaskColumns + `) values `) - for _, subtask := range subtasks { - markerList = append(markerList, "(%?, %?, %?, %?, %?, %?, %?, %?, CURRENT_TIMESTAMP(), '{}', '{}')") - args = append(args, subtask.Step, subtask.TaskID, subtask.ExecID, subtask.Meta, - proto.SubtaskStatePending, proto.Type2Int(subtask.Type), subtask.Concurrency, subtask.Ordinal) - } - sb.WriteString(strings.Join(markerList, ",")) - _, err := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), sb.String(), args...) - return err -} - -// SwitchTaskStepInBatch implements the scheduler.TaskManager interface. -func (mgr *TaskManager) SwitchTaskStepInBatch( - ctx context.Context, - task *proto.Task, - nextState proto.TaskState, - nextStep proto.Step, - subtasks []*proto.Subtask, -) error { - return mgr.WithNewSession(func(se sessionctx.Context) error { - // some subtasks may be inserted by other schedulers, we can skip them. - rs, err := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), ` - select count(1) from mysql.tidb_background_subtask - where task_key = %? and step = %?`, task.ID, nextStep) - if err != nil { - return err - } - existingTaskCnt := int(rs[0].GetInt64(0)) - if existingTaskCnt > len(subtasks) { - return errors.Annotatef(ErrUnstableSubtasks, "expected %d, got %d", - len(subtasks), existingTaskCnt) - } - subtaskBatches := mgr.splitSubtasks(subtasks[existingTaskCnt:]) - for _, batch := range subtaskBatches { - if err = mgr.insertSubtasks(ctx, se, batch); err != nil { - return err - } - } - return mgr.updateTaskStateStep(ctx, se, task, nextState, nextStep) - }) -} - -func (*TaskManager) splitSubtasks(subtasks []*proto.Subtask) [][]*proto.Subtask { - var ( - res = make([][]*proto.Subtask, 0, 10) - currBatch = make([]*proto.Subtask, 0, 10) - size int - ) - maxSize := int(min(kv.TxnTotalSizeLimit.Load(), uint64(maxSubtaskBatchSize))) - for _, s := range subtasks { - if size+len(s.Meta) > maxSize { - res = append(res, currBatch) - currBatch = nil - size = 0 - } - currBatch = append(currBatch, s) - size += len(s.Meta) - } - if len(currBatch) > 0 { - res = append(res, currBatch) - } - return res -} - -func serializeErr(err error) []byte { - if err == nil { - return nil - } - originErr := errors.Cause(err) - tErr, ok := originErr.(*errors.Error) - if !ok { - tErr = errors.Normalize(originErr.Error()) - } - errBytes, err := tErr.MarshalJSON() - if err != nil { - return nil - } - return errBytes -} - -// GetSubtasksWithHistory gets the subtasks from tidb_global_task and tidb_global_task_history. 
-func (mgr *TaskManager) GetSubtasksWithHistory(ctx context.Context, taskID int64, step proto.Step) ([]*proto.Subtask, error) { - var ( - rs []chunk.Row - err error - ) - err = mgr.WithNewTxn(ctx, func(se sessionctx.Context) error { - rs, err = sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), - `select `+SubtaskColumns+` from mysql.tidb_background_subtask where task_key = %? and step = %?`, - taskID, step, - ) - if err != nil { - return err - } - - // To avoid the situation that the subtasks has been `TransferTasks2History` - // when the user show import jobs, we need to check the history table. - rsFromHistory, err := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), - `select `+SubtaskColumns+` from mysql.tidb_background_subtask_history where task_key = %? and step = %?`, - taskID, step, - ) - if err != nil { - return err - } - - rs = append(rs, rsFromHistory...) - return nil - }) - - if err != nil { - return nil, err - } - if len(rs) == 0 { - return nil, nil - } - subtasks := make([]*proto.Subtask, 0, len(rs)) - for _, r := range rs { - subtasks = append(subtasks, Row2SubTask(r)) - } - return subtasks, nil -} - -// GetAllSubtasks gets all subtasks with basic columns. -func (mgr *TaskManager) GetAllSubtasks(ctx context.Context) ([]*proto.SubtaskBase, error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, `select `+basicSubtaskColumns+` from mysql.tidb_background_subtask`) - if err != nil { - return nil, err - } - if len(rs) == 0 { - return nil, nil - } - subtasks := make([]*proto.SubtaskBase, 0, len(rs)) - for _, r := range rs { - subtasks = append(subtasks, row2BasicSubTask(r)) - } - return subtasks, nil -} - -// AdjustTaskOverflowConcurrency change the task concurrency to a max value supported by current cluster. -// This is a workaround for an upgrade bug: in v7.5.x, the task concurrency is hard-coded to 16, resulting in -// a stuck issue if the new version TiDB has less than 16 CPU count. -// We don't adjust the concurrency in subtask table because this field does not exist in v7.5.0. -// For details, see https://github.com/pingcap/tidb/issues/50894. -// For the following versions, there is a check when submitting a new task. This function should be a no-op. -func (mgr *TaskManager) AdjustTaskOverflowConcurrency(ctx context.Context, se sessionctx.Context) error { - cpuCount, err := mgr.getCPUCountOfNode(ctx, se) - if err != nil { - return err - } - sql := "update mysql.tidb_global_task set concurrency = %? 
where concurrency > %?;" - _, err = sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), sql, cpuCount, cpuCount) - return err -} diff --git a/pkg/disttask/framework/taskexecutor/binding__failpoint_binding__.go b/pkg/disttask/framework/taskexecutor/binding__failpoint_binding__.go deleted file mode 100644 index 42a8c5efc3dee..0000000000000 --- a/pkg/disttask/framework/taskexecutor/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package taskexecutor - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/disttask/framework/taskexecutor/task_executor.go b/pkg/disttask/framework/taskexecutor/task_executor.go index 911a1f9f23c98..2551173673f38 100644 --- a/pkg/disttask/framework/taskexecutor/task_executor.go +++ b/pkg/disttask/framework/taskexecutor/task_executor.go @@ -306,9 +306,9 @@ func (e *BaseTaskExecutor) runStep(resource *proto.StepResource) (resErr error) } execute.SetFrameworkInfo(stepExecutor, resource) - if _, _err_ := failpoint.Eval(_curpkg_("mockExecSubtaskInitEnvErr")); _err_ == nil { - return errors.New("mockExecSubtaskInitEnvErr") - } + failpoint.Inject("mockExecSubtaskInitEnvErr", func() { + failpoint.Return(errors.New("mockExecSubtaskInitEnvErr")) + }) if err := stepExecutor.Init(runStepCtx); err != nil { e.onError(err) return e.getError() @@ -367,9 +367,9 @@ func (e *BaseTaskExecutor) runStep(resource *proto.StepResource) (resErr error) } } - if _, _err_ := failpoint.Eval(_curpkg_("cancelBeforeRunSubtask")); _err_ == nil { + failpoint.Inject("cancelBeforeRunSubtask", func() { runStepCancel(nil) - } + }) e.runSubtask(runStepCtx, stepExecutor, subtask) } @@ -402,17 +402,17 @@ func (e *BaseTaskExecutor) runSubtask(ctx context.Context, stepExecutor execute. }() return stepExecutor.RunSubtask(ctx, subtask) }() - if val, _err_ := failpoint.Eval(_curpkg_("MockRunSubtaskCancel")); _err_ == nil { + failpoint.Inject("MockRunSubtaskCancel", func(val failpoint.Value) { if val.(bool) { err = ErrCancelSubtask } - } + }) - if val, _err_ := failpoint.Eval(_curpkg_("MockRunSubtaskContextCanceled")); _err_ == nil { + failpoint.Inject("MockRunSubtaskContextCanceled", func(val failpoint.Value) { if val.(bool) { err = context.Canceled } - } + }) if err != nil { e.onError(err) @@ -423,18 +423,18 @@ func (e *BaseTaskExecutor) runSubtask(ctx context.Context, stepExecutor execute. return } - if _, _err_ := failpoint.Eval(_curpkg_("mockTiDBShutdown")); _err_ == nil { + failpoint.Inject("mockTiDBShutdown", func() { if MockTiDBDown(e.id, e.GetTaskBase()) { - return + failpoint.Return() } - } + }) - if val, _err_ := failpoint.Eval(_curpkg_("MockExecutorRunErr")); _err_ == nil { + failpoint.Inject("MockExecutorRunErr", func(val failpoint.Value) { if val.(bool) { e.onError(errors.New("MockExecutorRunErr")) } - } - if val, _err_ := failpoint.Eval(_curpkg_("MockExecutorRunCancel")); _err_ == nil { + }) + failpoint.Inject("MockExecutorRunCancel", func(val failpoint.Value) { if taskID, ok := val.(int); ok { mgr, err := storage.GetTaskManager() if err != nil { @@ -446,7 +446,7 @@ func (e *BaseTaskExecutor) runSubtask(ctx context.Context, stepExecutor execute. 
} } } - } + }) e.onSubtaskFinished(ctx, stepExecutor, subtask) } @@ -456,11 +456,11 @@ func (e *BaseTaskExecutor) onSubtaskFinished(ctx context.Context, executor execu e.onError(err) } } - if val, _err_ := failpoint.Eval(_curpkg_("MockSubtaskFinishedCancel")); _err_ == nil { + failpoint.Inject("MockSubtaskFinishedCancel", func(val failpoint.Value) { if val.(bool) { e.onError(ErrCancelSubtask) } - } + }) finished := e.markSubTaskCanceledOrFailed(ctx, subtask) if finished { @@ -474,7 +474,7 @@ func (e *BaseTaskExecutor) onSubtaskFinished(ctx context.Context, executor execu return } - failpoint.Call(_curpkg_("syncAfterSubtaskFinish")) + failpoint.InjectCall("syncAfterSubtaskFinish") } // GetTaskBase implements TaskExecutor.GetTaskBase. diff --git a/pkg/disttask/framework/taskexecutor/task_executor.go__failpoint_stash__ b/pkg/disttask/framework/taskexecutor/task_executor.go__failpoint_stash__ deleted file mode 100644 index 2551173673f38..0000000000000 --- a/pkg/disttask/framework/taskexecutor/task_executor.go__failpoint_stash__ +++ /dev/null @@ -1,683 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package taskexecutor - -import ( - "context" - "sync" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/log" - "github.com/pingcap/tidb/pkg/disttask/framework/handle" - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" - "github.com/pingcap/tidb/pkg/disttask/framework/storage" - "github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor/execute" - "github.com/pingcap/tidb/pkg/lightning/common" - llog "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/backoff" - "github.com/pingcap/tidb/pkg/util/gctuner" - "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tidb/pkg/util/memory" - "go.uber.org/zap" -) - -var ( - // checkBalanceSubtaskInterval is the default check interval for checking - // subtasks balance to/away from this node. - checkBalanceSubtaskInterval = 2 * time.Second - - // updateSubtaskSummaryInterval is the interval for updating the subtask summary to - // subtask table. - updateSubtaskSummaryInterval = 3 * time.Second -) - -var ( - // ErrCancelSubtask is the cancel cause when cancelling subtasks. - ErrCancelSubtask = errors.New("cancel subtasks") - // ErrFinishSubtask is the cancel cause when TaskExecutor successfully processed subtasks. - ErrFinishSubtask = errors.New("finish subtasks") - // ErrNonIdempotentSubtask means the subtask is left in running state and is not idempotent, - // so cannot be run again. - ErrNonIdempotentSubtask = errors.New("subtask in running state and is not idempotent") - - // MockTiDBDown is used to mock TiDB node down, return true if it's chosen. - MockTiDBDown func(execID string, task *proto.TaskBase) bool -) - -// BaseTaskExecutor is the base implementation of TaskExecutor. 
-type BaseTaskExecutor struct { - // id, it's the same as server id now, i.e. host:port. - id string - // we only store task base here to reduce overhead of refreshing it. - // task meta is loaded when we do execute subtasks, see GetStepExecutor. - taskBase atomic.Pointer[proto.TaskBase] - taskTable TaskTable - logger *zap.Logger - ctx context.Context - cancel context.CancelFunc - Extension - - currSubtaskID atomic.Int64 - - mu struct { - sync.RWMutex - err error - // handled indicates whether the error has been updated to one of the subtask. - handled bool - // runtimeCancel is used to cancel the Run/Rollback when error occurs. - runtimeCancel context.CancelCauseFunc - } -} - -// NewBaseTaskExecutor creates a new BaseTaskExecutor. -// see TaskExecutor.Init for why we want to use task-base to create TaskExecutor. -// TODO: we can refactor this part to pass task base only, but currently ADD-INDEX -// depends on it to init, so we keep it for now. -func NewBaseTaskExecutor(ctx context.Context, id string, task *proto.Task, taskTable TaskTable) *BaseTaskExecutor { - logger := log.L().With(zap.Int64("task-id", task.ID), zap.String("task-type", string(task.Type))) - if intest.InTest { - logger = logger.With(zap.String("server-id", id)) - } - subCtx, cancelFunc := context.WithCancel(ctx) - taskExecutorImpl := &BaseTaskExecutor{ - id: id, - taskTable: taskTable, - ctx: subCtx, - cancel: cancelFunc, - logger: logger, - } - taskExecutorImpl.taskBase.Store(&task.TaskBase) - return taskExecutorImpl -} - -// checkBalanceSubtask check whether the subtasks are balanced to or away from this node. -// - If other subtask of `running` state is scheduled to this node, try changed to -// `pending` state, to make sure subtasks can be balanced later when node scale out. -// - If current running subtask are scheduled away from this node, i.e. this node -// is taken as down, cancel running. -func (e *BaseTaskExecutor) checkBalanceSubtask(ctx context.Context) { - ticker := time.NewTicker(checkBalanceSubtaskInterval) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - } - - task := e.taskBase.Load() - subtasks, err := e.taskTable.GetSubtasksByExecIDAndStepAndStates(ctx, e.id, task.ID, task.Step, - proto.SubtaskStateRunning) - if err != nil { - e.logger.Error("get subtasks failed", zap.Error(err)) - continue - } - if len(subtasks) == 0 { - e.logger.Info("subtask is scheduled away, cancel running") - // cancels runStep, but leave the subtask state unchanged. 
- e.cancelRunStepWith(nil) - return - } - - extraRunningSubtasks := make([]*proto.SubtaskBase, 0, len(subtasks)) - for _, st := range subtasks { - if st.ID == e.currSubtaskID.Load() { - continue - } - if !e.IsIdempotent(st) { - e.updateSubtaskStateAndErrorImpl(ctx, st.ExecID, st.ID, proto.SubtaskStateFailed, ErrNonIdempotentSubtask) - return - } - extraRunningSubtasks = append(extraRunningSubtasks, &st.SubtaskBase) - } - if len(extraRunningSubtasks) > 0 { - if err = e.taskTable.RunningSubtasksBack2Pending(ctx, extraRunningSubtasks); err != nil { - e.logger.Error("update running subtasks back to pending failed", zap.Error(err)) - } else { - e.logger.Info("update extra running subtasks back to pending", - zap.Stringers("subtasks", extraRunningSubtasks)) - } - } - } -} - -func (e *BaseTaskExecutor) updateSubtaskSummaryLoop( - checkCtx, runStepCtx context.Context, stepExec execute.StepExecutor) { - taskMgr := e.taskTable.(*storage.TaskManager) - ticker := time.NewTicker(updateSubtaskSummaryInterval) - defer ticker.Stop() - curSubtaskID := e.currSubtaskID.Load() - update := func() { - summary := stepExec.RealtimeSummary() - err := taskMgr.UpdateSubtaskRowCount(runStepCtx, curSubtaskID, summary.RowCount) - if err != nil { - e.logger.Info("update subtask row count failed", zap.Error(err)) - } - } - for { - select { - case <-checkCtx.Done(): - update() - return - case <-ticker.C: - } - update() - } -} - -// Init implements the TaskExecutor interface. -func (*BaseTaskExecutor) Init(_ context.Context) error { - return nil -} - -// Ctx returns the context of the task executor. -// TODO: remove it when add-index.taskexecutor.Init don't depends on it. -func (e *BaseTaskExecutor) Ctx() context.Context { - return e.ctx -} - -// Run implements the TaskExecutor interface. -func (e *BaseTaskExecutor) Run(resource *proto.StepResource) { - var err error - // task executor occupies resources, if there's no subtask to run for 10s, - // we release the resources so that other tasks can use them. - // 300ms + 600ms + 1.2s + 2s * 4 = 10.1s - backoffer := backoff.NewExponential(SubtaskCheckInterval, 2, MaxSubtaskCheckInterval) - checkInterval, noSubtaskCheckCnt := SubtaskCheckInterval, 0 - for { - select { - case <-e.ctx.Done(): - return - case <-time.After(checkInterval): - } - if err = e.refreshTask(); err != nil { - if errors.Cause(err) == storage.ErrTaskNotFound { - return - } - e.logger.Error("refresh task failed", zap.Error(err)) - continue - } - task := e.taskBase.Load() - if task.State != proto.TaskStateRunning { - return - } - if exist, err := e.taskTable.HasSubtasksInStates(e.ctx, e.id, task.ID, task.Step, - unfinishedSubtaskStates...); err != nil { - e.logger.Error("check whether there are subtasks to run failed", zap.Error(err)) - continue - } else if !exist { - if noSubtaskCheckCnt >= maxChecksWhenNoSubtask { - e.logger.Info("no subtask to run for a while, exit") - break - } - checkInterval = backoffer.Backoff(noSubtaskCheckCnt) - noSubtaskCheckCnt++ - continue - } - // reset it when we get a subtask - checkInterval, noSubtaskCheckCnt = SubtaskCheckInterval, 0 - err = e.RunStep(resource) - if err != nil { - e.logger.Error("failed to handle task", zap.Error(err)) - } - } -} - -// RunStep start to fetch and run all subtasks for the step of task on the node. -// return if there's no subtask to run. 
-func (e *BaseTaskExecutor) RunStep(resource *proto.StepResource) (err error) { - defer func() { - if r := recover(); r != nil { - e.logger.Error("BaseTaskExecutor panicked", zap.Any("recover", r), zap.Stack("stack")) - err4Panic := errors.Errorf("%v", r) - err1 := e.updateSubtask(err4Panic) - if err == nil { - err = err1 - } - } - }() - err = e.runStep(resource) - if e.mu.handled { - return err - } - if err == nil { - // may have error in - // 1. defer function in run(ctx, task) - // 2. cancel ctx - // TODO: refine onError/getError - if e.getError() != nil { - err = e.getError() - } else if e.ctx.Err() != nil { - err = e.ctx.Err() - } else { - return nil - } - } - - return e.updateSubtask(err) -} - -func (e *BaseTaskExecutor) runStep(resource *proto.StepResource) (resErr error) { - runStepCtx, runStepCancel := context.WithCancelCause(e.ctx) - e.registerRunStepCancelFunc(runStepCancel) - defer func() { - runStepCancel(ErrFinishSubtask) - e.unregisterRunStepCancelFunc() - }() - e.resetError() - taskBase := e.taskBase.Load() - task, err := e.taskTable.GetTaskByID(e.ctx, taskBase.ID) - if err != nil { - e.onError(err) - return e.getError() - } - stepLogger := llog.BeginTask(e.logger.With( - zap.String("step", proto.Step2Str(task.Type, task.Step)), - zap.Float64("mem-limit-percent", gctuner.GlobalMemoryLimitTuner.GetPercentage()), - zap.String("server-mem-limit", memory.ServerMemoryLimitOriginText.Load()), - zap.Stringer("resource", resource), - ), "execute task step") - // log as info level, subtask might be cancelled, let caller check it. - defer func() { - stepLogger.End(zap.InfoLevel, resErr) - }() - - stepExecutor, err := e.GetStepExecutor(task) - if err != nil { - e.onError(err) - return e.getError() - } - execute.SetFrameworkInfo(stepExecutor, resource) - - failpoint.Inject("mockExecSubtaskInitEnvErr", func() { - failpoint.Return(errors.New("mockExecSubtaskInitEnvErr")) - }) - if err := stepExecutor.Init(runStepCtx); err != nil { - e.onError(err) - return e.getError() - } - - defer func() { - err := stepExecutor.Cleanup(runStepCtx) - if err != nil { - e.logger.Error("cleanup subtask exec env failed", zap.Error(err)) - e.onError(err) - } - }() - - for { - // check if any error occurs. - if err := e.getError(); err != nil { - break - } - if runStepCtx.Err() != nil { - break - } - - subtask, err := e.taskTable.GetFirstSubtaskInStates(runStepCtx, e.id, task.ID, task.Step, - proto.SubtaskStatePending, proto.SubtaskStateRunning) - if err != nil { - e.logger.Warn("GetFirstSubtaskInStates meets error", zap.Error(err)) - continue - } - if subtask == nil { - break - } - - if subtask.State == proto.SubtaskStateRunning { - if !e.IsIdempotent(subtask) { - e.logger.Info("subtask in running state and is not idempotent, fail it", - zap.Int64("subtask-id", subtask.ID)) - e.onError(ErrNonIdempotentSubtask) - e.updateSubtaskStateAndErrorImpl(runStepCtx, subtask.ExecID, subtask.ID, proto.SubtaskStateFailed, ErrNonIdempotentSubtask) - e.markErrorHandled() - break - } - e.logger.Info("subtask in running state and is idempotent", - zap.Int64("subtask-id", subtask.ID)) - } else { - // subtask.State == proto.SubtaskStatePending - err := e.startSubtask(runStepCtx, subtask.ID) - if err != nil { - e.logger.Warn("startSubtask meets error", zap.Error(err)) - // should ignore ErrSubtaskNotFound - // since it only means that the subtask not owned by current task executor. 
- if err == storage.ErrSubtaskNotFound { - continue - } - e.onError(err) - continue - } - } - - failpoint.Inject("cancelBeforeRunSubtask", func() { - runStepCancel(nil) - }) - - e.runSubtask(runStepCtx, stepExecutor, subtask) - } - return e.getError() -} - -func (e *BaseTaskExecutor) hasRealtimeSummary(stepExecutor execute.StepExecutor) bool { - _, ok := e.taskTable.(*storage.TaskManager) - return ok && stepExecutor.RealtimeSummary() != nil -} - -func (e *BaseTaskExecutor) runSubtask(ctx context.Context, stepExecutor execute.StepExecutor, subtask *proto.Subtask) { - err := func() error { - e.currSubtaskID.Store(subtask.ID) - - var wg util.WaitGroupWrapper - checkCtx, checkCancel := context.WithCancel(ctx) - wg.RunWithLog(func() { - e.checkBalanceSubtask(checkCtx) - }) - - if e.hasRealtimeSummary(stepExecutor) { - wg.RunWithLog(func() { - e.updateSubtaskSummaryLoop(checkCtx, ctx, stepExecutor) - }) - } - defer func() { - checkCancel() - wg.Wait() - }() - return stepExecutor.RunSubtask(ctx, subtask) - }() - failpoint.Inject("MockRunSubtaskCancel", func(val failpoint.Value) { - if val.(bool) { - err = ErrCancelSubtask - } - }) - - failpoint.Inject("MockRunSubtaskContextCanceled", func(val failpoint.Value) { - if val.(bool) { - err = context.Canceled - } - }) - - if err != nil { - e.onError(err) - } - - finished := e.markSubTaskCanceledOrFailed(ctx, subtask) - if finished { - return - } - - failpoint.Inject("mockTiDBShutdown", func() { - if MockTiDBDown(e.id, e.GetTaskBase()) { - failpoint.Return() - } - }) - - failpoint.Inject("MockExecutorRunErr", func(val failpoint.Value) { - if val.(bool) { - e.onError(errors.New("MockExecutorRunErr")) - } - }) - failpoint.Inject("MockExecutorRunCancel", func(val failpoint.Value) { - if taskID, ok := val.(int); ok { - mgr, err := storage.GetTaskManager() - if err != nil { - e.logger.Error("get task manager failed", zap.Error(err)) - } else { - err = mgr.CancelTask(ctx, int64(taskID)) - if err != nil { - e.logger.Error("cancel task failed", zap.Error(err)) - } - } - } - }) - e.onSubtaskFinished(ctx, stepExecutor, subtask) -} - -func (e *BaseTaskExecutor) onSubtaskFinished(ctx context.Context, executor execute.StepExecutor, subtask *proto.Subtask) { - if err := e.getError(); err == nil { - if err = executor.OnFinished(ctx, subtask); err != nil { - e.onError(err) - } - } - failpoint.Inject("MockSubtaskFinishedCancel", func(val failpoint.Value) { - if val.(bool) { - e.onError(ErrCancelSubtask) - } - }) - - finished := e.markSubTaskCanceledOrFailed(ctx, subtask) - if finished { - return - } - - e.finishSubtask(ctx, subtask) - - finished = e.markSubTaskCanceledOrFailed(ctx, subtask) - if finished { - return - } - - failpoint.InjectCall("syncAfterSubtaskFinish") -} - -// GetTaskBase implements TaskExecutor.GetTaskBase. -func (e *BaseTaskExecutor) GetTaskBase() *proto.TaskBase { - return e.taskBase.Load() -} - -// CancelRunningSubtask implements TaskExecutor.CancelRunningSubtask. -func (e *BaseTaskExecutor) CancelRunningSubtask() { - e.cancelRunStepWith(ErrCancelSubtask) -} - -// Cancel implements TaskExecutor.Cancel. -func (e *BaseTaskExecutor) Cancel() { - e.cancel() -} - -// Close closes the TaskExecutor when all the subtasks are complete. -func (e *BaseTaskExecutor) Close() { - e.Cancel() -} - -// refreshTask fetch task state from tidb_global_task table. 
-func (e *BaseTaskExecutor) refreshTask() error { - task := e.GetTaskBase() - newTaskBase, err := e.taskTable.GetTaskBaseByID(e.ctx, task.ID) - if err != nil { - return err - } - e.taskBase.Store(newTaskBase) - return nil -} - -func (e *BaseTaskExecutor) registerRunStepCancelFunc(cancel context.CancelCauseFunc) { - e.mu.Lock() - defer e.mu.Unlock() - e.mu.runtimeCancel = cancel -} - -func (e *BaseTaskExecutor) unregisterRunStepCancelFunc() { - e.mu.Lock() - defer e.mu.Unlock() - e.mu.runtimeCancel = nil -} - -func (e *BaseTaskExecutor) cancelRunStepWith(cause error) { - e.mu.Lock() - defer e.mu.Unlock() - if e.mu.runtimeCancel != nil { - e.mu.runtimeCancel(cause) - } -} - -func (e *BaseTaskExecutor) onError(err error) { - if err == nil { - return - } - err = errors.Trace(err) - e.logger.Error("onError", zap.Error(err), zap.Stack("stack")) - e.mu.Lock() - defer e.mu.Unlock() - - if e.mu.err == nil { - e.mu.err = err - e.logger.Error("taskExecutor met first error", zap.Error(err)) - } - - if e.mu.runtimeCancel != nil { - e.mu.runtimeCancel(err) - } -} - -func (e *BaseTaskExecutor) markErrorHandled() { - e.mu.Lock() - defer e.mu.Unlock() - e.mu.handled = true -} - -func (e *BaseTaskExecutor) getError() error { - e.mu.RLock() - defer e.mu.RUnlock() - return e.mu.err -} - -func (e *BaseTaskExecutor) resetError() { - e.mu.Lock() - defer e.mu.Unlock() - e.mu.err = nil - e.mu.handled = false -} - -func (e *BaseTaskExecutor) updateSubtaskStateAndErrorImpl(ctx context.Context, execID string, subtaskID int64, state proto.SubtaskState, subTaskErr error) { - // retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes - backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) - err := handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, e.logger, - func(ctx context.Context) (bool, error) { - return true, e.taskTable.UpdateSubtaskStateAndError(ctx, execID, subtaskID, state, subTaskErr) - }, - ) - if err != nil { - e.onError(err) - } -} - -// startSubtask try to change the state of the subtask to running. -// If the subtask is not owned by the task executor, -// the update will fail and task executor should not run the subtask. -func (e *BaseTaskExecutor) startSubtask(ctx context.Context, subtaskID int64) error { - // retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes - backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) - return handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, e.logger, - func(ctx context.Context) (bool, error) { - err := e.taskTable.StartSubtask(ctx, subtaskID, e.id) - if err == storage.ErrSubtaskNotFound { - // No need to retry. - return false, err - } - return true, err - }, - ) -} - -func (e *BaseTaskExecutor) finishSubtask(ctx context.Context, subtask *proto.Subtask) { - backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) - err := handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, e.logger, - func(ctx context.Context) (bool, error) { - return true, e.taskTable.FinishSubtask(ctx, subtask.ExecID, subtask.ID, subtask.Meta) - }, - ) - if err != nil { - e.onError(err) - } -} - -// markSubTaskCanceledOrFailed check the error type and decide the subtasks' state. -// 1. Only cancel subtasks when meet ErrCancelSubtask. -// 2. Only fail subtasks when meet non retryable error. -// 3. When meet other errors, don't change subtasks' state. 
-func (e *BaseTaskExecutor) markSubTaskCanceledOrFailed(ctx context.Context, subtask *proto.Subtask) bool { - if err := e.getError(); err != nil { - err := errors.Cause(err) - if ctx.Err() != nil && context.Cause(ctx) == ErrCancelSubtask { - e.logger.Warn("subtask canceled", zap.Error(err)) - e.updateSubtaskStateAndErrorImpl(e.ctx, subtask.ExecID, subtask.ID, proto.SubtaskStateCanceled, nil) - } else if e.IsRetryableError(err) { - e.logger.Warn("meet retryable error", zap.Error(err)) - } else if common.IsContextCanceledError(err) { - e.logger.Info("meet context canceled for gracefully shutdown", zap.Error(err)) - } else { - e.logger.Warn("subtask failed", zap.Error(err)) - e.updateSubtaskStateAndErrorImpl(e.ctx, subtask.ExecID, subtask.ID, proto.SubtaskStateFailed, err) - } - e.markErrorHandled() - return true - } - return false -} - -func (e *BaseTaskExecutor) failSubtaskWithRetry(ctx context.Context, taskID int64, err error) error { - backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) - err1 := handle.RunWithRetry(e.ctx, scheduler.RetrySQLTimes, backoffer, e.logger, - func(_ context.Context) (bool, error) { - return true, e.taskTable.FailSubtask(ctx, e.id, taskID, err) - }, - ) - if err1 == nil { - e.logger.Info("failed one subtask succeed", zap.NamedError("subtask-err", err)) - } - return err1 -} - -func (e *BaseTaskExecutor) cancelSubtaskWithRetry(ctx context.Context, taskID int64, err error) error { - e.logger.Warn("subtask canceled", zap.NamedError("subtask-cancel", err)) - backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) - err1 := handle.RunWithRetry(e.ctx, scheduler.RetrySQLTimes, backoffer, e.logger, - func(_ context.Context) (bool, error) { - return true, e.taskTable.CancelSubtask(ctx, e.id, taskID) - }, - ) - if err1 == nil { - e.logger.Info("canceled one subtask succeed", zap.NamedError("subtask-cancel", err)) - } - return err1 -} - -// updateSubtask check the error type and decide the subtasks' state. -// 1. Only cancel subtasks when meet ErrCancelSubtask. -// 2. Only fail subtasks when meet non retryable error. -// 3. When meet other errors, don't change subtasks' state. -// Handled errors should not happen during subtasks execution. -// Only handle errors before subtasks execution and after subtasks execution. -func (e *BaseTaskExecutor) updateSubtask(err error) error { - task := e.taskBase.Load() - err = errors.Cause(err) - // TODO this branch is unreachable now, remove it when we refactor error handling. 
- if e.ctx.Err() != nil && context.Cause(e.ctx) == ErrCancelSubtask { - return e.cancelSubtaskWithRetry(e.ctx, task.ID, ErrCancelSubtask) - } else if e.IsRetryableError(err) { - e.logger.Warn("meet retryable error", zap.Error(err)) - } else if common.IsContextCanceledError(err) { - e.logger.Info("meet context canceled for gracefully shutdown", zap.Error(err)) - } else { - return e.failSubtaskWithRetry(e.ctx, task.ID, err) - } - return nil -} diff --git a/pkg/disttask/importinto/binding__failpoint_binding__.go b/pkg/disttask/importinto/binding__failpoint_binding__.go deleted file mode 100644 index f923a525842a5..0000000000000 --- a/pkg/disttask/importinto/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package importinto - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/disttask/importinto/planner.go b/pkg/disttask/importinto/planner.go index cc4bec874a880..8fda8686facbb 100644 --- a/pkg/disttask/importinto/planner.go +++ b/pkg/disttask/importinto/planner.go @@ -292,12 +292,12 @@ func generateImportSpecs(pCtx planner.PlanCtx, p *LogicalPlan) ([]planner.Pipeli } func skipMergeSort(kvGroup string, stats []external.MultipleFilesStat) bool { - if val, _err_ := failpoint.Eval(_curpkg_("forceMergeSort")); _err_ == nil { + failpoint.Inject("forceMergeSort", func(val failpoint.Value) { in := val.(string) if in == kvGroup || in == "*" { - return false + failpoint.Return(false) } - } + }) return external.GetMaxOverlappingTotal(stats) <= external.MergeSortOverlapThreshold } @@ -349,8 +349,8 @@ func generateWriteIngestSpecs(planCtx planner.PlanCtx, p *LogicalPlan) ([]planne if err != nil { return nil, err } - if _, _err_ := failpoint.Eval(_curpkg_("mockWriteIngestSpecs")); _err_ == nil { - return []planner.PipelineSpec{ + failpoint.Inject("mockWriteIngestSpecs", func() { + failpoint.Return([]planner.PipelineSpec{ &WriteIngestSpec{ WriteIngestStepMeta: &WriteIngestStepMeta{ KVGroup: dataKVGroup, @@ -361,8 +361,8 @@ func generateWriteIngestSpecs(planCtx planner.PlanCtx, p *LogicalPlan) ([]planne KVGroup: "1", }, }, - }, nil - } + }, nil) + }) pTS, lTS, err := planCtx.Store.GetPDClient().GetTS(ctx) if err != nil { diff --git a/pkg/disttask/importinto/planner.go__failpoint_stash__ b/pkg/disttask/importinto/planner.go__failpoint_stash__ deleted file mode 100644 index 8fda8686facbb..0000000000000 --- a/pkg/disttask/importinto/planner.go__failpoint_stash__ +++ /dev/null @@ -1,528 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
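// The planner.go hunks above revert what appear to be failpoint-ctl rewrites
// (the `_curpkg_`/`failpoint.Eval` form plus the generated binding__failpoint_binding__.go
// and __failpoint_stash__ files) back to plain `failpoint.Inject` markers.
// A minimal sketch of the two equivalent forms, assuming the standard
// pingcap/failpoint API and a simplified condition rather than the exact
// skipMergeSort check:
//
//	// marker form, the one kept in source:
//	failpoint.Inject("forceMergeSort", func(val failpoint.Value) {
//		if val.(string) == "*" {
//			failpoint.Return(false)
//		}
//	})
//
//	// expanded form emitted by `failpoint-ctl enable`, removed by this patch:
//	if val, _err_ := failpoint.Eval(_curpkg_("forceMergeSort")); _err_ == nil {
//		if val.(string) == "*" {
//			return false
//		}
//	}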
- -package importinto - -import ( - "context" - "encoding/hex" - "encoding/json" - "math" - "strconv" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/pkg/disttask/framework/planner" - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - "github.com/pingcap/tidb/pkg/domain/infosync" - "github.com/pingcap/tidb/pkg/executor/importer" - tidbkv "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/lightning/backend/external" - "github.com/pingcap/tidb/pkg/lightning/backend/kv" - "github.com/pingcap/tidb/pkg/lightning/common" - "github.com/pingcap/tidb/pkg/lightning/config" - verify "github.com/pingcap/tidb/pkg/lightning/verification" - "github.com/pingcap/tidb/pkg/meta/autoid" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/table/tables" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/tikv/client-go/v2/oracle" - "go.uber.org/zap" -) - -var ( - _ planner.LogicalPlan = &LogicalPlan{} - _ planner.PipelineSpec = &ImportSpec{} - _ planner.PipelineSpec = &PostProcessSpec{} -) - -// LogicalPlan represents a logical plan for import into. -type LogicalPlan struct { - JobID int64 - Plan importer.Plan - Stmt string - EligibleInstances []*infosync.ServerInfo - ChunkMap map[int32][]Chunk -} - -// ToTaskMeta converts the logical plan to task meta. -func (p *LogicalPlan) ToTaskMeta() ([]byte, error) { - taskMeta := TaskMeta{ - JobID: p.JobID, - Plan: p.Plan, - Stmt: p.Stmt, - EligibleInstances: p.EligibleInstances, - ChunkMap: p.ChunkMap, - } - return json.Marshal(taskMeta) -} - -// FromTaskMeta converts the task meta to logical plan. -func (p *LogicalPlan) FromTaskMeta(bs []byte) error { - var taskMeta TaskMeta - if err := json.Unmarshal(bs, &taskMeta); err != nil { - return errors.Trace(err) - } - p.JobID = taskMeta.JobID - p.Plan = taskMeta.Plan - p.Stmt = taskMeta.Stmt - p.EligibleInstances = taskMeta.EligibleInstances - p.ChunkMap = taskMeta.ChunkMap - return nil -} - -// ToPhysicalPlan converts the logical plan to physical plan. -func (p *LogicalPlan) ToPhysicalPlan(planCtx planner.PlanCtx) (*planner.PhysicalPlan, error) { - physicalPlan := &planner.PhysicalPlan{} - inputLinks := make([]planner.LinkSpec, 0) - addSpecs := func(specs []planner.PipelineSpec) { - for i, spec := range specs { - physicalPlan.AddProcessor(planner.ProcessorSpec{ - ID: i, - Pipeline: spec, - Output: planner.OutputSpec{ - Links: []planner.LinkSpec{ - { - ProcessorID: len(specs), - }, - }, - }, - Step: planCtx.NextTaskStep, - }) - inputLinks = append(inputLinks, planner.LinkSpec{ - ProcessorID: i, - }) - } - } - // physical plan only needs to be generated once. - // However, our current implementation requires generating it for each step. - // we only generate needed plans for the next step. 
- switch planCtx.NextTaskStep { - case proto.ImportStepImport, proto.ImportStepEncodeAndSort: - specs, err := generateImportSpecs(planCtx, p) - if err != nil { - return nil, err - } - - addSpecs(specs) - case proto.ImportStepMergeSort: - specs, err := generateMergeSortSpecs(planCtx, p) - if err != nil { - return nil, err - } - - addSpecs(specs) - case proto.ImportStepWriteAndIngest: - specs, err := generateWriteIngestSpecs(planCtx, p) - if err != nil { - return nil, err - } - - addSpecs(specs) - case proto.ImportStepPostProcess: - physicalPlan.AddProcessor(planner.ProcessorSpec{ - ID: len(inputLinks), - Input: planner.InputSpec{ - ColumnTypes: []byte{ - // Checksum_crc64_xor, Total_kvs, Total_bytes, ReadRowCnt, LoadedRowCnt, ColSizeMap - mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeJSON, - }, - Links: inputLinks, - }, - Pipeline: &PostProcessSpec{ - Schema: p.Plan.DBName, - Table: p.Plan.TableInfo.Name.L, - }, - Step: planCtx.NextTaskStep, - }) - } - - return physicalPlan, nil -} - -// ImportSpec is the specification of an import pipeline. -type ImportSpec struct { - ID int32 - Plan importer.Plan - Chunks []Chunk -} - -// ToSubtaskMeta converts the import spec to subtask meta. -func (s *ImportSpec) ToSubtaskMeta(planner.PlanCtx) ([]byte, error) { - importStepMeta := ImportStepMeta{ - ID: s.ID, - Chunks: s.Chunks, - } - return json.Marshal(importStepMeta) -} - -// WriteIngestSpec is the specification of a write-ingest pipeline. -type WriteIngestSpec struct { - *WriteIngestStepMeta -} - -// ToSubtaskMeta converts the write-ingest spec to subtask meta. -func (s *WriteIngestSpec) ToSubtaskMeta(planner.PlanCtx) ([]byte, error) { - return json.Marshal(s.WriteIngestStepMeta) -} - -// MergeSortSpec is the specification of a merge-sort pipeline. -type MergeSortSpec struct { - *MergeSortStepMeta -} - -// ToSubtaskMeta converts the merge-sort spec to subtask meta. -func (s *MergeSortSpec) ToSubtaskMeta(planner.PlanCtx) ([]byte, error) { - return json.Marshal(s.MergeSortStepMeta) -} - -// PostProcessSpec is the specification of a post process pipeline. -type PostProcessSpec struct { - // for checksum request - Schema string - Table string -} - -// ToSubtaskMeta converts the post process spec to subtask meta. 
-func (*PostProcessSpec) ToSubtaskMeta(planCtx planner.PlanCtx) ([]byte, error) { - encodeStep := getStepOfEncode(planCtx.GlobalSort) - subtaskMetas := make([]*ImportStepMeta, 0, len(planCtx.PreviousSubtaskMetas)) - for _, bs := range planCtx.PreviousSubtaskMetas[encodeStep] { - var subtaskMeta ImportStepMeta - if err := json.Unmarshal(bs, &subtaskMeta); err != nil { - return nil, errors.Trace(err) - } - subtaskMetas = append(subtaskMetas, &subtaskMeta) - } - localChecksum := verify.NewKVGroupChecksumForAdd() - maxIDs := make(map[autoid.AllocatorType]int64, 3) - for _, subtaskMeta := range subtaskMetas { - for id, c := range subtaskMeta.Checksum { - localChecksum.AddRawGroup(id, c.Size, c.KVs, c.Sum) - } - - for key, val := range subtaskMeta.MaxIDs { - if maxIDs[key] < val { - maxIDs[key] = val - } - } - } - c := localChecksum.GetInnerChecksums() - postProcessStepMeta := &PostProcessStepMeta{ - Checksum: make(map[int64]Checksum, len(c)), - MaxIDs: maxIDs, - } - for id, cksum := range c { - postProcessStepMeta.Checksum[id] = Checksum{ - Size: cksum.SumSize(), - KVs: cksum.SumKVS(), - Sum: cksum.Sum(), - } - } - return json.Marshal(postProcessStepMeta) -} - -func buildControllerForPlan(p *LogicalPlan) (*importer.LoadDataController, error) { - return buildController(&p.Plan, p.Stmt) -} - -func buildController(plan *importer.Plan, stmt string) (*importer.LoadDataController, error) { - idAlloc := kv.NewPanickingAllocators(plan.TableInfo.SepAutoInc(), 0) - tbl, err := tables.TableFromMeta(idAlloc, plan.TableInfo) - if err != nil { - return nil, err - } - - astArgs, err := importer.ASTArgsFromStmt(stmt) - if err != nil { - return nil, err - } - controller, err := importer.NewLoadDataController(plan, tbl, astArgs) - if err != nil { - return nil, err - } - return controller, nil -} - -func generateImportSpecs(pCtx planner.PlanCtx, p *LogicalPlan) ([]planner.PipelineSpec, error) { - var chunkMap map[int32][]Chunk - if len(p.ChunkMap) > 0 { - chunkMap = p.ChunkMap - } else { - controller, err2 := buildControllerForPlan(p) - if err2 != nil { - return nil, err2 - } - if err2 = controller.InitDataFiles(pCtx.Ctx); err2 != nil { - return nil, err2 - } - - controller.SetExecuteNodeCnt(pCtx.ExecuteNodesCnt) - engineCheckpoints, err2 := controller.PopulateChunks(pCtx.Ctx) - if err2 != nil { - return nil, err2 - } - chunkMap = toChunkMap(engineCheckpoints) - } - importSpecs := make([]planner.PipelineSpec, 0, len(chunkMap)) - for id := range chunkMap { - if id == common.IndexEngineID { - continue - } - importSpec := &ImportSpec{ - ID: id, - Plan: p.Plan, - Chunks: chunkMap[id], - } - importSpecs = append(importSpecs, importSpec) - } - return importSpecs, nil -} - -func skipMergeSort(kvGroup string, stats []external.MultipleFilesStat) bool { - failpoint.Inject("forceMergeSort", func(val failpoint.Value) { - in := val.(string) - if in == kvGroup || in == "*" { - failpoint.Return(false) - } - }) - return external.GetMaxOverlappingTotal(stats) <= external.MergeSortOverlapThreshold -} - -func generateMergeSortSpecs(planCtx planner.PlanCtx, p *LogicalPlan) ([]planner.PipelineSpec, error) { - step := external.MergeSortFileCountStep - result := make([]planner.PipelineSpec, 0, 16) - kvMetas, err := getSortedKVMetasOfEncodeStep(planCtx.PreviousSubtaskMetas[proto.ImportStepEncodeAndSort]) - if err != nil { - return nil, err - } - for kvGroup, kvMeta := range kvMetas { - if !p.Plan.ForceMergeStep && skipMergeSort(kvGroup, kvMeta.MultipleFilesStats) { - logutil.Logger(planCtx.Ctx).Info("skip merge sort for kv group", - 
zap.Int64("task-id", planCtx.TaskID), - zap.String("kv-group", kvGroup)) - continue - } - dataFiles := kvMeta.GetDataFiles() - length := len(dataFiles) - for start := 0; start < length; start += step { - end := start + step - if end > length { - end = length - } - result = append(result, &MergeSortSpec{ - MergeSortStepMeta: &MergeSortStepMeta{ - KVGroup: kvGroup, - DataFiles: dataFiles[start:end], - }, - }) - } - } - return result, nil -} - -func generateWriteIngestSpecs(planCtx planner.PlanCtx, p *LogicalPlan) ([]planner.PipelineSpec, error) { - ctx := planCtx.Ctx - controller, err2 := buildControllerForPlan(p) - if err2 != nil { - return nil, err2 - } - if err2 = controller.InitDataStore(ctx); err2 != nil { - return nil, err2 - } - // kvMetas contains data kv meta and all index kv metas. - // each kvMeta will be split into multiple range group individually, - // i.e. data and index kv will NOT be in the same subtask. - kvMetas, err := getSortedKVMetasForIngest(planCtx, p) - if err != nil { - return nil, err - } - failpoint.Inject("mockWriteIngestSpecs", func() { - failpoint.Return([]planner.PipelineSpec{ - &WriteIngestSpec{ - WriteIngestStepMeta: &WriteIngestStepMeta{ - KVGroup: dataKVGroup, - }, - }, - &WriteIngestSpec{ - WriteIngestStepMeta: &WriteIngestStepMeta{ - KVGroup: "1", - }, - }, - }, nil) - }) - - pTS, lTS, err := planCtx.Store.GetPDClient().GetTS(ctx) - if err != nil { - return nil, err - } - ts := oracle.ComposeTS(pTS, lTS) - - specs := make([]planner.PipelineSpec, 0, 16) - for kvGroup, kvMeta := range kvMetas { - splitter, err1 := getRangeSplitter(ctx, controller.GlobalSortStore, kvMeta) - if err1 != nil { - return nil, err1 - } - - err1 = func() error { - defer func() { - err2 := splitter.Close() - if err2 != nil { - logutil.Logger(ctx).Warn("close range splitter failed", zap.Error(err2)) - } - }() - startKey := tidbkv.Key(kvMeta.StartKey) - var endKey tidbkv.Key - for { - endKeyOfGroup, dataFiles, statFiles, rangeSplitKeys, err2 := splitter.SplitOneRangesGroup() - if err2 != nil { - return err2 - } - if len(endKeyOfGroup) == 0 { - endKey = kvMeta.EndKey - } else { - endKey = tidbkv.Key(endKeyOfGroup).Clone() - } - logutil.Logger(ctx).Info("kv range as subtask", - zap.String("startKey", hex.EncodeToString(startKey)), - zap.String("endKey", hex.EncodeToString(endKey)), - zap.Int("dataFiles", len(dataFiles))) - if startKey.Cmp(endKey) >= 0 { - return errors.Errorf("invalid kv range, startKey: %s, endKey: %s", - hex.EncodeToString(startKey), hex.EncodeToString(endKey)) - } - // each subtask will write and ingest one range group - m := &WriteIngestStepMeta{ - KVGroup: kvGroup, - SortedKVMeta: external.SortedKVMeta{ - StartKey: startKey, - EndKey: endKey, - // this is actually an estimate, we don't know the exact size of the data - TotalKVSize: uint64(config.DefaultBatchSize), - }, - DataFiles: dataFiles, - StatFiles: statFiles, - RangeSplitKeys: rangeSplitKeys, - RangeSplitSize: splitter.GetRangeSplitSize(), - TS: ts, - } - specs = append(specs, &WriteIngestSpec{m}) - - startKey = endKey - if len(endKeyOfGroup) == 0 { - break - } - } - return nil - }() - if err1 != nil { - return nil, err1 - } - } - return specs, nil -} - -func getSortedKVMetasOfEncodeStep(subTaskMetas [][]byte) (map[string]*external.SortedKVMeta, error) { - dataKVMeta := &external.SortedKVMeta{} - indexKVMetas := make(map[int64]*external.SortedKVMeta) - for _, subTaskMeta := range subTaskMetas { - var stepMeta ImportStepMeta - err := json.Unmarshal(subTaskMeta, &stepMeta) - if err != nil { - return nil, 
errors.Trace(err) - } - dataKVMeta.Merge(stepMeta.SortedDataMeta) - for indexID, sortedIndexMeta := range stepMeta.SortedIndexMetas { - if item, ok := indexKVMetas[indexID]; !ok { - indexKVMetas[indexID] = sortedIndexMeta - } else { - item.Merge(sortedIndexMeta) - } - } - } - res := make(map[string]*external.SortedKVMeta, 1+len(indexKVMetas)) - res[dataKVGroup] = dataKVMeta - for indexID, item := range indexKVMetas { - res[strconv.Itoa(int(indexID))] = item - } - return res, nil -} - -func getSortedKVMetasOfMergeStep(subTaskMetas [][]byte) (map[string]*external.SortedKVMeta, error) { - result := make(map[string]*external.SortedKVMeta, len(subTaskMetas)) - for _, subTaskMeta := range subTaskMetas { - var stepMeta MergeSortStepMeta - err := json.Unmarshal(subTaskMeta, &stepMeta) - if err != nil { - return nil, errors.Trace(err) - } - meta, ok := result[stepMeta.KVGroup] - if !ok { - result[stepMeta.KVGroup] = &stepMeta.SortedKVMeta - continue - } - meta.Merge(&stepMeta.SortedKVMeta) - } - return result, nil -} - -func getSortedKVMetasForIngest(planCtx planner.PlanCtx, p *LogicalPlan) (map[string]*external.SortedKVMeta, error) { - kvMetasOfMergeSort, err := getSortedKVMetasOfMergeStep(planCtx.PreviousSubtaskMetas[proto.ImportStepMergeSort]) - if err != nil { - return nil, err - } - kvMetasOfEncodeStep, err := getSortedKVMetasOfEncodeStep(planCtx.PreviousSubtaskMetas[proto.ImportStepEncodeAndSort]) - if err != nil { - return nil, err - } - for kvGroup, kvMeta := range kvMetasOfEncodeStep { - // only part of kv files are merge sorted. we need to merge kv metas that - // are not merged into the kvMetasOfMergeSort. - if !p.Plan.ForceMergeStep && skipMergeSort(kvGroup, kvMeta.MultipleFilesStats) { - if _, ok := kvMetasOfMergeSort[kvGroup]; ok { - // this should not happen, because we only generate merge sort - // subtasks for those kv groups with MaxOverlappingTotal > MergeSortOverlapThreshold - logutil.Logger(planCtx.Ctx).Error("kv group of encode step conflict with merge sort step") - return nil, errors.New("kv group of encode step conflict with merge sort step") - } - kvMetasOfMergeSort[kvGroup] = kvMeta - } - } - return kvMetasOfMergeSort, nil -} - -func getRangeSplitter(ctx context.Context, store storage.ExternalStorage, kvMeta *external.SortedKVMeta) ( - *external.RangeSplitter, error) { - regionSplitSize, regionSplitKeys, err := importer.GetRegionSplitSizeKeys(ctx) - if err != nil { - logutil.Logger(ctx).Warn("fail to get region split size and keys", zap.Error(err)) - } - regionSplitSize = max(regionSplitSize, int64(config.SplitRegionSize)) - regionSplitKeys = max(regionSplitKeys, int64(config.SplitRegionKeys)) - logutil.Logger(ctx).Info("split kv range with split size and keys", - zap.Int64("region-split-size", regionSplitSize), - zap.Int64("region-split-keys", regionSplitKeys)) - - return external.NewRangeSplitter( - ctx, - kvMeta.MultipleFilesStats, - store, - int64(config.DefaultBatchSize), - int64(math.MaxInt64), - regionSplitSize, - regionSplitKeys, - ) -} diff --git a/pkg/disttask/importinto/scheduler.go b/pkg/disttask/importinto/scheduler.go index 52a98941d99df..1dd1c0021ae7f 100644 --- a/pkg/disttask/importinto/scheduler.go +++ b/pkg/disttask/importinto/scheduler.go @@ -249,9 +249,9 @@ func (sch *ImportSchedulerExt) OnNextSubtasksBatch( } previousSubtaskMetas[proto.ImportStepEncodeAndSort] = sortAndEncodeMeta case proto.ImportStepWriteAndIngest: - if _, _err_ := failpoint.Eval(_curpkg_("failWhenDispatchWriteIngestSubtask")); _err_ == nil { - return nil, errors.New("injected error") 
- } + failpoint.Inject("failWhenDispatchWriteIngestSubtask", func() { + failpoint.Return(nil, errors.New("injected error")) + }) // merge sort might be skipped for some kv groups, so we need to get all // subtask metas of ImportStepEncodeAndSort step too. encodeAndSortMetas, err := taskHandle.GetPreviousSubtaskMetas(task.ID, proto.ImportStepEncodeAndSort) @@ -269,15 +269,15 @@ func (sch *ImportSchedulerExt) OnNextSubtasksBatch( } case proto.ImportStepPostProcess: sch.switchTiKV2NormalMode(ctx, task, logger) - if _, _err_ := failpoint.Eval(_curpkg_("clearLastSwitchTime")); _err_ == nil { + failpoint.Inject("clearLastSwitchTime", func() { sch.lastSwitchTime.Store(time.Time{}) - } + }) if err = job2Step(ctx, logger, taskMeta, importer.JobStepValidating); err != nil { return nil, err } - if _, _err_ := failpoint.Eval(_curpkg_("failWhenDispatchPostProcessSubtask")); _err_ == nil { - return nil, errors.New("injected error after ImportStepImport") - } + failpoint.Inject("failWhenDispatchPostProcessSubtask", func() { + failpoint.Return(nil, errors.New("injected error after ImportStepImport")) + }) // we need get metas where checksum is stored. if err := updateResult(taskHandle, task, taskMeta, sch.GlobalSort); err != nil { return nil, err @@ -614,7 +614,7 @@ func getLoadedRowCountOnGlobalSort(handle storage.TaskHandle, task *proto.Task) } func startJob(ctx context.Context, logger *zap.Logger, taskHandle storage.TaskHandle, taskMeta *TaskMeta, jobStep string) error { - failpoint.Call(_curpkg_("syncBeforeJobStarted"), taskMeta.JobID) + failpoint.InjectCall("syncBeforeJobStarted", taskMeta.JobID) // retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes // we consider all errors as retryable errors, except context done. // the errors include errors happened when communicate with PD and TiKV. @@ -628,7 +628,7 @@ func startJob(ctx context.Context, logger *zap.Logger, taskHandle storage.TaskHa }) }, ) - failpoint.Call(_curpkg_("syncAfterJobStarted")) + failpoint.InjectCall("syncAfterJobStarted") return err } diff --git a/pkg/disttask/importinto/scheduler.go__failpoint_stash__ b/pkg/disttask/importinto/scheduler.go__failpoint_stash__ deleted file mode 100644 index 1dd1c0021ae7f..0000000000000 --- a/pkg/disttask/importinto/scheduler.go__failpoint_stash__ +++ /dev/null @@ -1,719 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
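// The retry comments in the scheduler and task executor code above ("retry for
// 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes") describe an exponential backoff
// that starts around 3s, doubles per attempt, caps at 30s, and runs for 30
// attempts; the exact constants live in scheduler.RetrySQLInterval,
// RetrySQLMaxInterval and RetrySQLTimes and are only inferred here from that
// comment, not restated from source. A minimal sketch of the arithmetic:
//
//	total, interval := 0, 3 // seconds
//	for i := 0; i < 30; i++ {
//		total += interval
//		interval = min(interval*2, 30)
//	}
//	// total == 3 + 6 + 12 + 24 + 26*30 == 825 seconds, roughly 14 minutes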
- -package importinto - -import ( - "context" - "encoding/json" - "strconv" - "strings" - "sync" - "time" - - dmysql "github.com/go-sql-driver/mysql" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/utils" - tidb "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/disttask/framework/handle" - "github.com/pingcap/tidb/pkg/disttask/framework/planner" - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" - "github.com/pingcap/tidb/pkg/disttask/framework/storage" - "github.com/pingcap/tidb/pkg/errno" - "github.com/pingcap/tidb/pkg/executor/importer" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/lightning/checkpoints" - "github.com/pingcap/tidb/pkg/lightning/common" - "github.com/pingcap/tidb/pkg/lightning/config" - "github.com/pingcap/tidb/pkg/lightning/metric" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/backoff" - disttaskutil "github.com/pingcap/tidb/pkg/util/disttask" - "github.com/pingcap/tidb/pkg/util/etcd" - "github.com/pingcap/tidb/pkg/util/logutil" - "go.uber.org/atomic" - "go.uber.org/zap" -) - -const ( - registerTaskTTL = 10 * time.Minute - refreshTaskTTLInterval = 3 * time.Minute - registerTimeout = 5 * time.Second -) - -// NewTaskRegisterWithTTL is the ctor for TaskRegister. -// It is exported for testing. -var NewTaskRegisterWithTTL = utils.NewTaskRegisterWithTTL - -type taskInfo struct { - taskID int64 - - // operation on taskInfo is run inside detect-task goroutine, so no need to synchronize. - lastRegisterTime time.Time - - // initialized lazily in register() - etcdClient *etcd.Client - taskRegister utils.TaskRegister -} - -func (t *taskInfo) register(ctx context.Context) { - if time.Since(t.lastRegisterTime) < refreshTaskTTLInterval { - return - } - - if time.Since(t.lastRegisterTime) < refreshTaskTTLInterval { - return - } - logger := logutil.BgLogger().With(zap.Int64("task-id", t.taskID)) - if t.taskRegister == nil { - client, err := importer.GetEtcdClient() - if err != nil { - logger.Warn("get etcd client failed", zap.Error(err)) - return - } - t.etcdClient = client - t.taskRegister = NewTaskRegisterWithTTL(client.GetClient(), registerTaskTTL, - utils.RegisterImportInto, strconv.FormatInt(t.taskID, 10)) - } - timeoutCtx, cancel := context.WithTimeout(ctx, registerTimeout) - defer cancel() - if err := t.taskRegister.RegisterTaskOnce(timeoutCtx); err != nil { - logger.Warn("register task failed", zap.Error(err)) - } else { - logger.Info("register task to pd or refresh lease success") - } - // we set it even if register failed, TTL is 10min, refresh interval is 3min, - // we can try 2 times before the lease is expired. - t.lastRegisterTime = time.Now() -} - -func (t *taskInfo) close(ctx context.Context) { - logger := logutil.BgLogger().With(zap.Int64("task-id", t.taskID)) - if t.taskRegister != nil { - timeoutCtx, cancel := context.WithTimeout(ctx, registerTimeout) - defer cancel() - if err := t.taskRegister.Close(timeoutCtx); err != nil { - logger.Warn("unregister task failed", zap.Error(err)) - } else { - logger.Info("unregister task success") - } - t.taskRegister = nil - } - if t.etcdClient != nil { - if err := t.etcdClient.Close(); err != nil { - logger.Warn("close etcd client failed", zap.Error(err)) - } - t.etcdClient = nil - } -} - -// ImportSchedulerExt is an extension of ImportScheduler, exported for test. 
-type ImportSchedulerExt struct { - GlobalSort bool - mu sync.RWMutex - // NOTE: there's no need to sync for below 2 fields actually, since we add a restriction that only one - // task can be running at a time. but we might support task queuing in the future, leave it for now. - // the last time we switch TiKV into IMPORT mode, this is a global operation, do it for one task makes - // no difference to do it for all tasks. So we do not need to record the switch time for each task. - lastSwitchTime atomic.Time - // taskInfoMap is a map from taskID to taskInfo - taskInfoMap sync.Map - - // currTaskID is the taskID of the current running task. - // It may be changed when we switch to a new task or switch to a new owner. - currTaskID atomic.Int64 - disableTiKVImportMode atomic.Bool - - storeWithPD kv.StorageWithPD -} - -var _ scheduler.Extension = (*ImportSchedulerExt)(nil) - -// OnTick implements scheduler.Extension interface. -func (sch *ImportSchedulerExt) OnTick(ctx context.Context, task *proto.Task) { - // only switch TiKV mode or register task when task is running - if task.State != proto.TaskStateRunning { - return - } - sch.switchTiKVMode(ctx, task) - sch.registerTask(ctx, task) -} - -func (*ImportSchedulerExt) isImporting2TiKV(task *proto.Task) bool { - return task.Step == proto.ImportStepImport || task.Step == proto.ImportStepWriteAndIngest -} - -func (sch *ImportSchedulerExt) switchTiKVMode(ctx context.Context, task *proto.Task) { - sch.updateCurrentTask(task) - // only import step need to switch to IMPORT mode, - // If TiKV is in IMPORT mode during checksum, coprocessor will time out. - if sch.disableTiKVImportMode.Load() || !sch.isImporting2TiKV(task) { - return - } - - if time.Since(sch.lastSwitchTime.Load()) < config.DefaultSwitchTiKVModeInterval { - return - } - - sch.mu.Lock() - defer sch.mu.Unlock() - if time.Since(sch.lastSwitchTime.Load()) < config.DefaultSwitchTiKVModeInterval { - return - } - - logger := logutil.BgLogger().With(zap.Int64("task-id", task.ID)) - // TODO: use the TLS object from TiDB server - tidbCfg := tidb.GetGlobalConfig() - tls, err := util.NewTLSConfig( - util.WithCAPath(tidbCfg.Security.ClusterSSLCA), - util.WithCertAndKeyPath(tidbCfg.Security.ClusterSSLCert, tidbCfg.Security.ClusterSSLKey), - ) - if err != nil { - logger.Warn("get tikv mode switcher failed", zap.Error(err)) - return - } - pdHTTPCli := sch.storeWithPD.GetPDHTTPClient() - switcher := importer.NewTiKVModeSwitcher(tls, pdHTTPCli, logger) - - switcher.ToImportMode(ctx) - sch.lastSwitchTime.Store(time.Now()) -} - -func (sch *ImportSchedulerExt) registerTask(ctx context.Context, task *proto.Task) { - val, _ := sch.taskInfoMap.LoadOrStore(task.ID, &taskInfo{taskID: task.ID}) - info := val.(*taskInfo) - info.register(ctx) -} - -func (sch *ImportSchedulerExt) unregisterTask(ctx context.Context, task *proto.Task) { - if val, loaded := sch.taskInfoMap.LoadAndDelete(task.ID); loaded { - info := val.(*taskInfo) - info.close(ctx) - } -} - -// OnNextSubtasksBatch generate batch of next stage's plan. 
-func (sch *ImportSchedulerExt) OnNextSubtasksBatch( - ctx context.Context, - taskHandle storage.TaskHandle, - task *proto.Task, - execIDs []string, - nextStep proto.Step, -) ( - resSubtaskMeta [][]byte, err error) { - logger := logutil.BgLogger().With( - zap.Stringer("type", task.Type), - zap.Int64("task-id", task.ID), - zap.String("curr-step", proto.Step2Str(task.Type, task.Step)), - zap.String("next-step", proto.Step2Str(task.Type, nextStep)), - ) - taskMeta := &TaskMeta{} - err = json.Unmarshal(task.Meta, taskMeta) - if err != nil { - return nil, errors.Trace(err) - } - logger.Info("on next subtasks batch") - - previousSubtaskMetas := make(map[proto.Step][][]byte, 1) - switch nextStep { - case proto.ImportStepImport, proto.ImportStepEncodeAndSort: - if metrics, ok := metric.GetCommonMetric(ctx); ok { - metrics.BytesCounter.WithLabelValues(metric.StateTotalRestore).Add(float64(taskMeta.Plan.TotalFileSize)) - } - jobStep := importer.JobStepImporting - if sch.GlobalSort { - jobStep = importer.JobStepGlobalSorting - } - if err = startJob(ctx, logger, taskHandle, taskMeta, jobStep); err != nil { - return nil, err - } - case proto.ImportStepMergeSort: - sortAndEncodeMeta, err := taskHandle.GetPreviousSubtaskMetas(task.ID, proto.ImportStepEncodeAndSort) - if err != nil { - return nil, err - } - previousSubtaskMetas[proto.ImportStepEncodeAndSort] = sortAndEncodeMeta - case proto.ImportStepWriteAndIngest: - failpoint.Inject("failWhenDispatchWriteIngestSubtask", func() { - failpoint.Return(nil, errors.New("injected error")) - }) - // merge sort might be skipped for some kv groups, so we need to get all - // subtask metas of ImportStepEncodeAndSort step too. - encodeAndSortMetas, err := taskHandle.GetPreviousSubtaskMetas(task.ID, proto.ImportStepEncodeAndSort) - if err != nil { - return nil, err - } - mergeSortMetas, err := taskHandle.GetPreviousSubtaskMetas(task.ID, proto.ImportStepMergeSort) - if err != nil { - return nil, err - } - previousSubtaskMetas[proto.ImportStepEncodeAndSort] = encodeAndSortMetas - previousSubtaskMetas[proto.ImportStepMergeSort] = mergeSortMetas - if err = job2Step(ctx, logger, taskMeta, importer.JobStepImporting); err != nil { - return nil, err - } - case proto.ImportStepPostProcess: - sch.switchTiKV2NormalMode(ctx, task, logger) - failpoint.Inject("clearLastSwitchTime", func() { - sch.lastSwitchTime.Store(time.Time{}) - }) - if err = job2Step(ctx, logger, taskMeta, importer.JobStepValidating); err != nil { - return nil, err - } - failpoint.Inject("failWhenDispatchPostProcessSubtask", func() { - failpoint.Return(nil, errors.New("injected error after ImportStepImport")) - }) - // we need get metas where checksum is stored. 
- if err := updateResult(taskHandle, task, taskMeta, sch.GlobalSort); err != nil { - return nil, err - } - step := getStepOfEncode(sch.GlobalSort) - metas, err := taskHandle.GetPreviousSubtaskMetas(task.ID, step) - if err != nil { - return nil, err - } - previousSubtaskMetas[step] = metas - logger.Info("move to post-process step ", zap.Any("result", taskMeta.Result)) - case proto.StepDone: - return nil, nil - default: - return nil, errors.Errorf("unknown step %d", task.Step) - } - - planCtx := planner.PlanCtx{ - Ctx: ctx, - TaskID: task.ID, - PreviousSubtaskMetas: previousSubtaskMetas, - GlobalSort: sch.GlobalSort, - NextTaskStep: nextStep, - ExecuteNodesCnt: len(execIDs), - Store: sch.storeWithPD, - } - logicalPlan := &LogicalPlan{} - if err := logicalPlan.FromTaskMeta(task.Meta); err != nil { - return nil, err - } - physicalPlan, err := logicalPlan.ToPhysicalPlan(planCtx) - if err != nil { - return nil, err - } - metaBytes, err := physicalPlan.ToSubtaskMetas(planCtx, nextStep) - if err != nil { - return nil, err - } - logger.Info("generate subtasks", zap.Int("subtask-count", len(metaBytes))) - return metaBytes, nil -} - -// OnDone implements scheduler.Extension interface. -func (sch *ImportSchedulerExt) OnDone(ctx context.Context, handle storage.TaskHandle, task *proto.Task) error { - logger := logutil.BgLogger().With( - zap.Stringer("type", task.Type), - zap.Int64("task-id", task.ID), - zap.String("step", proto.Step2Str(task.Type, task.Step)), - ) - logger.Info("task done", zap.Stringer("state", task.State), zap.Error(task.Error)) - taskMeta := &TaskMeta{} - err := json.Unmarshal(task.Meta, taskMeta) - if err != nil { - return errors.Trace(err) - } - if task.Error == nil { - return sch.finishJob(ctx, logger, handle, task, taskMeta) - } - if scheduler.IsCancelledErr(task.Error) { - return sch.cancelJob(ctx, handle, task, taskMeta, logger) - } - return sch.failJob(ctx, handle, task, taskMeta, logger, task.Error.Error()) -} - -// GetEligibleInstances implements scheduler.Extension interface. -func (*ImportSchedulerExt) GetEligibleInstances(_ context.Context, task *proto.Task) ([]string, error) { - taskMeta := &TaskMeta{} - err := json.Unmarshal(task.Meta, taskMeta) - if err != nil { - return nil, errors.Trace(err) - } - res := make([]string, 0, len(taskMeta.EligibleInstances)) - for _, instance := range taskMeta.EligibleInstances { - res = append(res, disttaskutil.GenerateExecID(instance)) - } - return res, nil -} - -// IsRetryableErr implements scheduler.Extension interface. -func (*ImportSchedulerExt) IsRetryableErr(error) bool { - // TODO: check whether the error is retryable. - return false -} - -// GetNextStep implements scheduler.Extension interface. 
-func (sch *ImportSchedulerExt) GetNextStep(task *proto.TaskBase) proto.Step { - switch task.Step { - case proto.StepInit: - if sch.GlobalSort { - return proto.ImportStepEncodeAndSort - } - return proto.ImportStepImport - case proto.ImportStepEncodeAndSort: - return proto.ImportStepMergeSort - case proto.ImportStepMergeSort: - return proto.ImportStepWriteAndIngest - case proto.ImportStepImport, proto.ImportStepWriteAndIngest: - return proto.ImportStepPostProcess - default: - // current step must be ImportStepPostProcess - return proto.StepDone - } -} - -func (sch *ImportSchedulerExt) switchTiKV2NormalMode(ctx context.Context, task *proto.Task, logger *zap.Logger) { - sch.updateCurrentTask(task) - if sch.disableTiKVImportMode.Load() { - return - } - - sch.mu.Lock() - defer sch.mu.Unlock() - - // TODO: use the TLS object from TiDB server - tidbCfg := tidb.GetGlobalConfig() - tls, err := util.NewTLSConfig( - util.WithCAPath(tidbCfg.Security.ClusterSSLCA), - util.WithCertAndKeyPath(tidbCfg.Security.ClusterSSLCert, tidbCfg.Security.ClusterSSLKey), - ) - if err != nil { - logger.Warn("get tikv mode switcher failed", zap.Error(err)) - return - } - pdHTTPCli := sch.storeWithPD.GetPDHTTPClient() - switcher := importer.NewTiKVModeSwitcher(tls, pdHTTPCli, logger) - - switcher.ToNormalMode(ctx) - - // clear it, so next task can switch TiKV mode again. - sch.lastSwitchTime.Store(time.Time{}) -} - -func (sch *ImportSchedulerExt) updateCurrentTask(task *proto.Task) { - if sch.currTaskID.Swap(task.ID) != task.ID { - taskMeta := &TaskMeta{} - if err := json.Unmarshal(task.Meta, taskMeta); err == nil { - // for raftkv2, switch mode in local backend - sch.disableTiKVImportMode.Store(taskMeta.Plan.DisableTiKVImportMode || taskMeta.Plan.IsRaftKV2) - } - } -} - -type importScheduler struct { - *scheduler.BaseScheduler - storeWithPD kv.StorageWithPD -} - -// NewImportScheduler creates a new import scheduler. -func NewImportScheduler( - ctx context.Context, - task *proto.Task, - param scheduler.Param, - storeWithPD kv.StorageWithPD, -) scheduler.Scheduler { - metrics := metricsManager.getOrCreateMetrics(task.ID) - subCtx := metric.WithCommonMetric(ctx, metrics) - sch := importScheduler{ - BaseScheduler: scheduler.NewBaseScheduler(subCtx, task, param), - storeWithPD: storeWithPD, - } - return &sch -} - -func (sch *importScheduler) Init() (err error) { - defer func() { - if err != nil { - // if init failed, close is not called, so we need to unregister here. 
- metricsManager.unregister(sch.GetTask().ID) - } - }() - taskMeta := &TaskMeta{} - if err = json.Unmarshal(sch.BaseScheduler.GetTask().Meta, taskMeta); err != nil { - return errors.Annotate(err, "unmarshal task meta failed") - } - - sch.BaseScheduler.Extension = &ImportSchedulerExt{ - GlobalSort: taskMeta.Plan.CloudStorageURI != "", - storeWithPD: sch.storeWithPD, - } - return sch.BaseScheduler.Init() -} - -func (sch *importScheduler) Close() { - metricsManager.unregister(sch.GetTask().ID) - sch.BaseScheduler.Close() -} - -// nolint:deadcode -func dropTableIndexes(ctx context.Context, handle storage.TaskHandle, taskMeta *TaskMeta, logger *zap.Logger) error { - tblInfo := taskMeta.Plan.TableInfo - - remainIndexes, dropIndexes := common.GetDropIndexInfos(tblInfo) - for _, idxInfo := range dropIndexes { - sqlStr := common.BuildDropIndexSQL(taskMeta.Plan.DBName, tblInfo.Name.L, idxInfo) - if err := executeSQL(ctx, handle, logger, sqlStr); err != nil { - if merr, ok := errors.Cause(err).(*dmysql.MySQLError); ok { - switch merr.Number { - case errno.ErrCantDropFieldOrKey, errno.ErrDropIndexNeededInForeignKey: - remainIndexes = append(remainIndexes, idxInfo) - logger.Warn("can't drop index, skip", zap.String("index", idxInfo.Name.O), zap.Error(err)) - continue - } - } - return err - } - } - if len(remainIndexes) < len(tblInfo.Indices) { - taskMeta.Plan.TableInfo = taskMeta.Plan.TableInfo.Clone() - taskMeta.Plan.TableInfo.Indices = remainIndexes - } - return nil -} - -// nolint:deadcode -func createTableIndexes(ctx context.Context, executor storage.SessionExecutor, taskMeta *TaskMeta, logger *zap.Logger) error { - tableName := common.UniqueTable(taskMeta.Plan.DBName, taskMeta.Plan.TableInfo.Name.L) - singleSQL, multiSQLs := common.BuildAddIndexSQL(tableName, taskMeta.Plan.TableInfo, taskMeta.Plan.DesiredTableInfo) - logger.Info("build add index sql", zap.String("singleSQL", singleSQL), zap.Strings("multiSQLs", multiSQLs)) - if len(multiSQLs) == 0 { - return nil - } - - err := executeSQL(ctx, executor, logger, singleSQL) - if err == nil { - return nil - } - if !common.IsDupKeyError(err) { - // TODO: refine err msg and error code according to spec. - return errors.Errorf("Failed to create index: %v, please execute the SQL manually, sql: %s", err, singleSQL) - } - if len(multiSQLs) == 1 { - return nil - } - logger.Warn("cannot add all indexes in one statement, try to add them one by one", zap.Strings("sqls", multiSQLs), zap.Error(err)) - - for i, ddl := range multiSQLs { - err := executeSQL(ctx, executor, logger, ddl) - if err != nil && !common.IsDupKeyError(err) { - // TODO: refine err msg and error code according to spec. - return errors.Errorf("Failed to create index: %v, please execute the SQLs manually, sqls: %s", err, strings.Join(multiSQLs[i:], ";")) - } - } - return nil -} - -// TODO: return the result of sql. -func executeSQL(ctx context.Context, executor storage.SessionExecutor, logger *zap.Logger, sql string, args ...any) (err error) { - logger.Info("execute sql", zap.String("sql", sql), zap.Any("args", args)) - return executor.WithNewSession(func(se sessionctx.Context) error { - _, err := se.GetSQLExecutor().ExecuteInternal(ctx, sql, args...) - return err - }) -} - -func updateMeta(task *proto.Task, taskMeta *TaskMeta) error { - bs, err := json.Marshal(taskMeta) - if err != nil { - return errors.Trace(err) - } - task.Meta = bs - - return nil -} - -// todo: converting back and forth, we should unify struct and remove this function later. 
-func toChunkMap(engineCheckpoints map[int32]*checkpoints.EngineCheckpoint) map[int32][]Chunk { - chunkMap := make(map[int32][]Chunk, len(engineCheckpoints)) - for id, ecp := range engineCheckpoints { - chunkMap[id] = make([]Chunk, 0, len(ecp.Chunks)) - for _, chunkCheckpoint := range ecp.Chunks { - chunkMap[id] = append(chunkMap[id], toChunk(*chunkCheckpoint)) - } - } - return chunkMap -} - -func getStepOfEncode(globalSort bool) proto.Step { - if globalSort { - return proto.ImportStepEncodeAndSort - } - return proto.ImportStepImport -} - -// we will update taskMeta in place and make task.Meta point to the new taskMeta. -func updateResult(handle storage.TaskHandle, task *proto.Task, taskMeta *TaskMeta, globalSort bool) error { - stepOfEncode := getStepOfEncode(globalSort) - metas, err := handle.GetPreviousSubtaskMetas(task.ID, stepOfEncode) - if err != nil { - return err - } - - subtaskMetas := make([]*ImportStepMeta, 0, len(metas)) - for _, bs := range metas { - var subtaskMeta ImportStepMeta - if err := json.Unmarshal(bs, &subtaskMeta); err != nil { - return errors.Trace(err) - } - subtaskMetas = append(subtaskMetas, &subtaskMeta) - } - columnSizeMap := make(map[int64]int64) - for _, subtaskMeta := range subtaskMetas { - taskMeta.Result.LoadedRowCnt += subtaskMeta.Result.LoadedRowCnt - for key, val := range subtaskMeta.Result.ColSizeMap { - columnSizeMap[key] += val - } - } - taskMeta.Result.ColSizeMap = columnSizeMap - - if globalSort { - taskMeta.Result.LoadedRowCnt, err = getLoadedRowCountOnGlobalSort(handle, task) - if err != nil { - return err - } - } - - return updateMeta(task, taskMeta) -} - -func getLoadedRowCountOnGlobalSort(handle storage.TaskHandle, task *proto.Task) (uint64, error) { - metas, err := handle.GetPreviousSubtaskMetas(task.ID, proto.ImportStepWriteAndIngest) - if err != nil { - return 0, err - } - - var loadedRowCount uint64 - for _, bs := range metas { - var subtaskMeta WriteIngestStepMeta - if err = json.Unmarshal(bs, &subtaskMeta); err != nil { - return 0, errors.Trace(err) - } - loadedRowCount += subtaskMeta.Result.LoadedRowCnt - } - return loadedRowCount, nil -} - -func startJob(ctx context.Context, logger *zap.Logger, taskHandle storage.TaskHandle, taskMeta *TaskMeta, jobStep string) error { - failpoint.InjectCall("syncBeforeJobStarted", taskMeta.JobID) - // retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes - // we consider all errors as retryable errors, except context done. - // the errors include errors happened when communicate with PD and TiKV. - // we didn't consider system corrupt cases like system table dropped/altered. - backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) - err := handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, logger, - func(ctx context.Context) (bool, error) { - return true, taskHandle.WithNewSession(func(se sessionctx.Context) error { - exec := se.GetSQLExecutor() - return importer.StartJob(ctx, exec, taskMeta.JobID, jobStep) - }) - }, - ) - failpoint.InjectCall("syncAfterJobStarted") - return err -} - -func job2Step(ctx context.Context, logger *zap.Logger, taskMeta *TaskMeta, step string) error { - taskManager, err := storage.GetTaskManager() - if err != nil { - return err - } - // todo: use scheduler.TaskHandle - // we might call this in taskExecutor later, there's no scheduler.Extension, so we use taskManager here. 
- // retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes - backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) - return handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, logger, - func(ctx context.Context) (bool, error) { - return true, taskManager.WithNewSession(func(se sessionctx.Context) error { - exec := se.GetSQLExecutor() - return importer.Job2Step(ctx, exec, taskMeta.JobID, step) - }) - }, - ) -} - -func (sch *ImportSchedulerExt) finishJob(ctx context.Context, logger *zap.Logger, - taskHandle storage.TaskHandle, task *proto.Task, taskMeta *TaskMeta) error { - // we have already switch import-mode when switch to post-process step. - sch.unregisterTask(ctx, task) - summary := &importer.JobSummary{ImportedRows: taskMeta.Result.LoadedRowCnt} - // retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes - backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) - return handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, logger, - func(ctx context.Context) (bool, error) { - return true, taskHandle.WithNewSession(func(se sessionctx.Context) error { - if err := importer.FlushTableStats(ctx, se, taskMeta.Plan.TableInfo.ID, &importer.JobImportResult{ - Affected: taskMeta.Result.LoadedRowCnt, - ColSizeMap: taskMeta.Result.ColSizeMap, - }); err != nil { - logger.Warn("flush table stats failed", zap.Error(err)) - } - exec := se.GetSQLExecutor() - return importer.FinishJob(ctx, exec, taskMeta.JobID, summary) - }) - }, - ) -} - -func (sch *ImportSchedulerExt) failJob(ctx context.Context, taskHandle storage.TaskHandle, task *proto.Task, - taskMeta *TaskMeta, logger *zap.Logger, errorMsg string) error { - sch.switchTiKV2NormalMode(ctx, task, logger) - sch.unregisterTask(ctx, task) - // retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes - backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) - return handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, logger, - func(ctx context.Context) (bool, error) { - return true, taskHandle.WithNewSession(func(se sessionctx.Context) error { - exec := se.GetSQLExecutor() - return importer.FailJob(ctx, exec, taskMeta.JobID, errorMsg) - }) - }, - ) -} - -func (sch *ImportSchedulerExt) cancelJob(ctx context.Context, taskHandle storage.TaskHandle, task *proto.Task, - meta *TaskMeta, logger *zap.Logger) error { - sch.switchTiKV2NormalMode(ctx, task, logger) - sch.unregisterTask(ctx, task) - // retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes - backoffer := backoff.NewExponential(scheduler.RetrySQLInterval, 2, scheduler.RetrySQLMaxInterval) - return handle.RunWithRetry(ctx, scheduler.RetrySQLTimes, backoffer, logger, - func(ctx context.Context) (bool, error) { - return true, taskHandle.WithNewSession(func(se sessionctx.Context) error { - exec := se.GetSQLExecutor() - return importer.CancelJob(ctx, exec, meta.JobID) - }) - }, - ) -} - -func redactSensitiveInfo(task *proto.Task, taskMeta *TaskMeta) { - taskMeta.Stmt = "" - taskMeta.Plan.Path = ast.RedactURL(taskMeta.Plan.Path) - if taskMeta.Plan.CloudStorageURI != "" { - taskMeta.Plan.CloudStorageURI = ast.RedactURL(taskMeta.Plan.CloudStorageURI) - } - if err := updateMeta(task, taskMeta); err != nil { - // marshal failed, should not happen - logutil.BgLogger().Warn("failed to update task meta", zap.Error(err)) - } -} diff --git a/pkg/disttask/importinto/subtask_executor.go b/pkg/disttask/importinto/subtask_executor.go index 92300944ef8ac..2ca7fc241cc0b 100644 --- 
a/pkg/disttask/importinto/subtask_executor.go +++ b/pkg/disttask/importinto/subtask_executor.go @@ -55,13 +55,13 @@ func newImportMinimalTaskExecutor0(t *importStepMinimalTask) MiniTaskExecutor { func (e *importMinimalTaskExecutor) Run(ctx context.Context, dataWriter, indexWriter backend.EngineWriter) error { logger := logutil.BgLogger().With(zap.Stringer("type", proto.ImportInto), zap.Int64("table-id", e.mTtask.Plan.TableInfo.ID)) logger.Info("execute chunk") - if _, _err_ := failpoint.Eval(_curpkg_("waitBeforeSortChunk")); _err_ == nil { + failpoint.Inject("waitBeforeSortChunk", func() { time.Sleep(3 * time.Second) - } - if _, _err_ := failpoint.Eval(_curpkg_("errorWhenSortChunk")); _err_ == nil { - return errors.New("occur an error when sort chunk") - } - failpoint.Call(_curpkg_("syncBeforeSortChunk")) + }) + failpoint.Inject("errorWhenSortChunk", func() { + failpoint.Return(errors.New("occur an error when sort chunk")) + }) + failpoint.InjectCall("syncBeforeSortChunk") chunkCheckpoint := toChunkCheckpoint(e.mTtask.Chunk) sharedVars := e.mTtask.SharedVars checksum := verify.NewKVGroupChecksumWithKeyspace(sharedVars.TableImporter.GetKeySpace()) @@ -101,7 +101,7 @@ func (e *importMinimalTaskExecutor) Run(ctx context.Context, dataWriter, indexWr // postProcess does the post-processing for the task. func postProcess(ctx context.Context, store kv.Storage, taskMeta *TaskMeta, subtaskMeta *PostProcessStepMeta, logger *zap.Logger) (err error) { - failpoint.Call(_curpkg_("syncBeforePostProcess"), taskMeta.JobID) + failpoint.InjectCall("syncBeforePostProcess", taskMeta.JobID) callLog := log.BeginTask(logger, "post process") defer func() { diff --git a/pkg/disttask/importinto/subtask_executor.go__failpoint_stash__ b/pkg/disttask/importinto/subtask_executor.go__failpoint_stash__ deleted file mode 100644 index 2ca7fc241cc0b..0000000000000 --- a/pkg/disttask/importinto/subtask_executor.go__failpoint_stash__ +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package importinto - -import ( - "context" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - "github.com/pingcap/tidb/pkg/disttask/framework/storage" - "github.com/pingcap/tidb/pkg/executor/importer" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/lightning/backend" - "github.com/pingcap/tidb/pkg/lightning/log" - verify "github.com/pingcap/tidb/pkg/lightning/verification" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/tikv/client-go/v2/util" - "go.uber.org/zap" -) - -// MiniTaskExecutor is the interface for a minimal task executor. -// exported for testing. -type MiniTaskExecutor interface { - Run(ctx context.Context, dataWriter, indexWriter backend.EngineWriter) error -} - -// importMinimalTaskExecutor is a minimal task executor for IMPORT INTO. 
-type importMinimalTaskExecutor struct { - mTtask *importStepMinimalTask -} - -var newImportMinimalTaskExecutor = newImportMinimalTaskExecutor0 - -func newImportMinimalTaskExecutor0(t *importStepMinimalTask) MiniTaskExecutor { - return &importMinimalTaskExecutor{ - mTtask: t, - } -} - -func (e *importMinimalTaskExecutor) Run(ctx context.Context, dataWriter, indexWriter backend.EngineWriter) error { - logger := logutil.BgLogger().With(zap.Stringer("type", proto.ImportInto), zap.Int64("table-id", e.mTtask.Plan.TableInfo.ID)) - logger.Info("execute chunk") - failpoint.Inject("waitBeforeSortChunk", func() { - time.Sleep(3 * time.Second) - }) - failpoint.Inject("errorWhenSortChunk", func() { - failpoint.Return(errors.New("occur an error when sort chunk")) - }) - failpoint.InjectCall("syncBeforeSortChunk") - chunkCheckpoint := toChunkCheckpoint(e.mTtask.Chunk) - sharedVars := e.mTtask.SharedVars - checksum := verify.NewKVGroupChecksumWithKeyspace(sharedVars.TableImporter.GetKeySpace()) - if sharedVars.TableImporter.IsLocalSort() { - if err := importer.ProcessChunk( - ctx, - &chunkCheckpoint, - sharedVars.TableImporter, - sharedVars.DataEngine, - sharedVars.IndexEngine, - sharedVars.Progress, - logger, - checksum, - ); err != nil { - return err - } - } else { - if err := importer.ProcessChunkWithWriter( - ctx, - &chunkCheckpoint, - sharedVars.TableImporter, - dataWriter, - indexWriter, - sharedVars.Progress, - logger, - checksum, - ); err != nil { - return err - } - } - - sharedVars.mu.Lock() - defer sharedVars.mu.Unlock() - sharedVars.Checksum.Add(checksum) - return nil -} - -// postProcess does the post-processing for the task. -func postProcess(ctx context.Context, store kv.Storage, taskMeta *TaskMeta, subtaskMeta *PostProcessStepMeta, logger *zap.Logger) (err error) { - failpoint.InjectCall("syncBeforePostProcess", taskMeta.JobID) - - callLog := log.BeginTask(logger, "post process") - defer func() { - callLog.End(zap.ErrorLevel, err) - }() - - if err = importer.RebaseAllocatorBases(ctx, store, subtaskMeta.MaxIDs, &taskMeta.Plan, logger); err != nil { - return err - } - - // TODO: create table indexes depends on the option. - // create table indexes even if the post process is failed. 
- // defer func() { - // err2 := createTableIndexes(ctx, globalTaskManager, taskMeta, logger) - // err = multierr.Append(err, err2) - // }() - - localChecksum := verify.NewKVGroupChecksumForAdd() - for id, cksum := range subtaskMeta.Checksum { - callLog.Info( - "kv group checksum", - zap.Int64("groupId", id), - zap.Uint64("size", cksum.Size), - zap.Uint64("kvs", cksum.KVs), - zap.Uint64("checksum", cksum.Sum), - ) - localChecksum.AddRawGroup(id, cksum.Size, cksum.KVs, cksum.Sum) - } - - taskManager, err := storage.GetTaskManager() - ctx = util.WithInternalSourceType(ctx, kv.InternalDistTask) - if err != nil { - return err - } - return taskManager.WithNewSession(func(se sessionctx.Context) error { - return importer.VerifyChecksum(ctx, &taskMeta.Plan, localChecksum.MergedChecksum(), se, logger) - }) -} diff --git a/pkg/disttask/importinto/task_executor.go b/pkg/disttask/importinto/task_executor.go index bb0ea5da6b3c8..78f41ff03c0f7 100644 --- a/pkg/disttask/importinto/task_executor.go +++ b/pkg/disttask/importinto/task_executor.go @@ -492,9 +492,9 @@ func (p *postProcessStepExecutor) RunSubtask(ctx context.Context, subtask *proto if err = json.Unmarshal(subtask.Meta, &stepMeta); err != nil { return errors.Trace(err) } - if _, _err_ := failpoint.Eval(_curpkg_("waitBeforePostProcess")); _err_ == nil { + failpoint.Inject("waitBeforePostProcess", func() { time.Sleep(5 * time.Second) - } + }) return postProcess(ctx, p.store, p.taskMeta, &stepMeta, logger) } diff --git a/pkg/disttask/importinto/task_executor.go__failpoint_stash__ b/pkg/disttask/importinto/task_executor.go__failpoint_stash__ deleted file mode 100644 index 78f41ff03c0f7..0000000000000 --- a/pkg/disttask/importinto/task_executor.go__failpoint_stash__ +++ /dev/null @@ -1,577 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package importinto - -import ( - "context" - "encoding/json" - "strconv" - "sync" - "time" - - "github.com/docker/go-units" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - brlogutil "github.com/pingcap/tidb/br/pkg/logutil" - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - "github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor" - "github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor/execute" - "github.com/pingcap/tidb/pkg/disttask/operator" - "github.com/pingcap/tidb/pkg/executor/importer" - tidbkv "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/lightning/backend" - "github.com/pingcap/tidb/pkg/lightning/backend/external" - "github.com/pingcap/tidb/pkg/lightning/backend/kv" - "github.com/pingcap/tidb/pkg/lightning/common" - "github.com/pingcap/tidb/pkg/lightning/config" - "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/lightning/metric" - "github.com/pingcap/tidb/pkg/lightning/verification" - "github.com/pingcap/tidb/pkg/meta/autoid" - "github.com/pingcap/tidb/pkg/table/tables" - "github.com/pingcap/tidb/pkg/util/logutil" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" -) - -// importStepExecutor is a executor for import step. -// StepExecutor is equivalent to a Lightning instance. -type importStepExecutor struct { - execute.StepExecFrameworkInfo - - taskID int64 - taskMeta *TaskMeta - tableImporter *importer.TableImporter - store tidbkv.Storage - sharedVars sync.Map - logger *zap.Logger - - dataKVMemSizePerCon uint64 - perIndexKVMemSizePerCon uint64 - - importCtx context.Context - importCancel context.CancelFunc - wg sync.WaitGroup -} - -func getTableImporter( - ctx context.Context, - taskID int64, - taskMeta *TaskMeta, - store tidbkv.Storage, -) (*importer.TableImporter, error) { - idAlloc := kv.NewPanickingAllocators(taskMeta.Plan.TableInfo.SepAutoInc(), 0) - tbl, err := tables.TableFromMeta(idAlloc, taskMeta.Plan.TableInfo) - if err != nil { - return nil, err - } - astArgs, err := importer.ASTArgsFromStmt(taskMeta.Stmt) - if err != nil { - return nil, err - } - controller, err := importer.NewLoadDataController(&taskMeta.Plan, tbl, astArgs) - if err != nil { - return nil, err - } - if err = controller.InitDataStore(ctx); err != nil { - return nil, err - } - - return importer.NewTableImporter(ctx, controller, strconv.FormatInt(taskID, 10), store) -} - -func (s *importStepExecutor) Init(ctx context.Context) error { - s.logger.Info("init subtask env") - tableImporter, err := getTableImporter(ctx, s.taskID, s.taskMeta, s.store) - if err != nil { - return err - } - s.tableImporter = tableImporter - - // we need this sub context since Cleanup which wait on this routine is called - // before parent context is canceled in normal flow. - s.importCtx, s.importCancel = context.WithCancel(ctx) - // only need to check disk quota when we are using local sort. 
- if s.tableImporter.IsLocalSort() { - s.wg.Add(1) - go func() { - defer s.wg.Done() - s.tableImporter.CheckDiskQuota(s.importCtx) - }() - } - s.dataKVMemSizePerCon, s.perIndexKVMemSizePerCon = getWriterMemorySizeLimit(s.GetResource(), s.tableImporter.Plan) - s.logger.Info("KV writer memory size limit per concurrency", - zap.String("data", units.BytesSize(float64(s.dataKVMemSizePerCon))), - zap.String("per-index", units.BytesSize(float64(s.perIndexKVMemSizePerCon)))) - return nil -} - -func (s *importStepExecutor) RunSubtask(ctx context.Context, subtask *proto.Subtask) (err error) { - logger := s.logger.With(zap.Int64("subtask-id", subtask.ID)) - task := log.BeginTask(logger, "run subtask") - defer func() { - task.End(zapcore.ErrorLevel, err) - }() - bs := subtask.Meta - var subtaskMeta ImportStepMeta - err = json.Unmarshal(bs, &subtaskMeta) - if err != nil { - return errors.Trace(err) - } - - var dataEngine, indexEngine *backend.OpenedEngine - if s.tableImporter.IsLocalSort() { - dataEngine, err = s.tableImporter.OpenDataEngine(ctx, subtaskMeta.ID) - if err != nil { - return err - } - // Unlike in Lightning, we start an index engine for each subtask, - // whereas previously there was only a single index engine globally. - // This is because the executor currently does not have a post-processing mechanism. - // If we import the index in `cleanupSubtaskEnv`, the scheduler will not wait for the import to complete. - // Multiple index engines may suffer performance degradation due to range overlap. - // These issues will be alleviated after we integrate s3 sorter. - // engineID = -1, -2, -3, ... - indexEngine, err = s.tableImporter.OpenIndexEngine(ctx, common.IndexEngineID-subtaskMeta.ID) - if err != nil { - return err - } - } - sharedVars := &SharedVars{ - TableImporter: s.tableImporter, - DataEngine: dataEngine, - IndexEngine: indexEngine, - Progress: importer.NewProgress(), - Checksum: verification.NewKVGroupChecksumWithKeyspace(s.tableImporter.GetKeySpace()), - SortedDataMeta: &external.SortedKVMeta{}, - SortedIndexMetas: make(map[int64]*external.SortedKVMeta), - } - s.sharedVars.Store(subtaskMeta.ID, sharedVars) - - source := operator.NewSimpleDataChannel(make(chan *importStepMinimalTask)) - op := newEncodeAndSortOperator(ctx, s, sharedVars, subtask.ID, subtask.Concurrency) - op.SetSource(source) - pipeline := operator.NewAsyncPipeline(op) - if err = pipeline.Execute(); err != nil { - return err - } - -outer: - for _, chunk := range subtaskMeta.Chunks { - // TODO: current workpool impl doesn't drain the input channel, it will - // just return on context cancel(error happened), so we add this select. 
- select { - case source.Channel() <- &importStepMinimalTask{ - Plan: s.taskMeta.Plan, - Chunk: chunk, - SharedVars: sharedVars, - }: - case <-op.Done(): - break outer - } - } - source.Finish() - - return pipeline.Close() -} - -func (*importStepExecutor) RealtimeSummary() *execute.SubtaskSummary { - return nil -} - -func (s *importStepExecutor) OnFinished(ctx context.Context, subtask *proto.Subtask) error { - var subtaskMeta ImportStepMeta - if err := json.Unmarshal(subtask.Meta, &subtaskMeta); err != nil { - return errors.Trace(err) - } - s.logger.Info("on subtask finished", zap.Int32("engine-id", subtaskMeta.ID)) - - val, ok := s.sharedVars.Load(subtaskMeta.ID) - if !ok { - return errors.Errorf("sharedVars %d not found", subtaskMeta.ID) - } - sharedVars, ok := val.(*SharedVars) - if !ok { - return errors.Errorf("sharedVars %d not found", subtaskMeta.ID) - } - - var dataKVCount uint64 - if s.tableImporter.IsLocalSort() { - // TODO: we should close and cleanup engine in all case, since there's no checkpoint. - s.logger.Info("import data engine", zap.Int32("engine-id", subtaskMeta.ID)) - closedDataEngine, err := sharedVars.DataEngine.Close(ctx) - if err != nil { - return err - } - dataKVCount2, err := s.tableImporter.ImportAndCleanup(ctx, closedDataEngine) - if err != nil { - return err - } - dataKVCount = uint64(dataKVCount2) - - s.logger.Info("import index engine", zap.Int32("engine-id", subtaskMeta.ID)) - if closedEngine, err := sharedVars.IndexEngine.Close(ctx); err != nil { - return err - } else if _, err := s.tableImporter.ImportAndCleanup(ctx, closedEngine); err != nil { - return err - } - } - // there's no imported dataKVCount on this stage when using global sort. - - sharedVars.mu.Lock() - defer sharedVars.mu.Unlock() - subtaskMeta.Checksum = map[int64]Checksum{} - for id, c := range sharedVars.Checksum.GetInnerChecksums() { - subtaskMeta.Checksum[id] = Checksum{ - Sum: c.Sum(), - KVs: c.SumKVS(), - Size: c.SumSize(), - } - } - subtaskMeta.Result = Result{ - LoadedRowCnt: dataKVCount, - ColSizeMap: sharedVars.Progress.GetColSize(), - } - allocators := sharedVars.TableImporter.Allocators() - subtaskMeta.MaxIDs = map[autoid.AllocatorType]int64{ - autoid.RowIDAllocType: allocators.Get(autoid.RowIDAllocType).Base(), - autoid.AutoIncrementType: allocators.Get(autoid.AutoIncrementType).Base(), - autoid.AutoRandomType: allocators.Get(autoid.AutoRandomType).Base(), - } - subtaskMeta.SortedDataMeta = sharedVars.SortedDataMeta - subtaskMeta.SortedIndexMetas = sharedVars.SortedIndexMetas - s.sharedVars.Delete(subtaskMeta.ID) - newMeta, err := json.Marshal(subtaskMeta) - if err != nil { - return errors.Trace(err) - } - subtask.Meta = newMeta - return nil -} - -func (s *importStepExecutor) Cleanup(_ context.Context) (err error) { - s.logger.Info("cleanup subtask env") - s.importCancel() - s.wg.Wait() - return s.tableImporter.Close() -} - -type mergeSortStepExecutor struct { - taskexecutor.EmptyStepExecutor - taskID int64 - taskMeta *TaskMeta - logger *zap.Logger - controller *importer.LoadDataController - // subtask of a task is run in serial now, so we don't need lock here. - // change to SyncMap when we support parallel subtask in the future. 
- subtaskSortedKVMeta *external.SortedKVMeta - // part-size for uploading merged files, it's calculated by: - // max(max-merged-files * max-file-size / max-part-num(10000), min-part-size) - dataKVPartSize int64 - indexKVPartSize int64 -} - -var _ execute.StepExecutor = &mergeSortStepExecutor{} - -func (m *mergeSortStepExecutor) Init(ctx context.Context) error { - controller, err := buildController(&m.taskMeta.Plan, m.taskMeta.Stmt) - if err != nil { - return err - } - if err = controller.InitDataStore(ctx); err != nil { - return err - } - m.controller = controller - dataKVMemSizePerCon, perIndexKVMemSizePerCon := getWriterMemorySizeLimit(m.GetResource(), &m.taskMeta.Plan) - m.dataKVPartSize = max(external.MinUploadPartSize, int64(dataKVMemSizePerCon*uint64(external.MaxMergingFilesPerThread)/10000)) - m.indexKVPartSize = max(external.MinUploadPartSize, int64(perIndexKVMemSizePerCon*uint64(external.MaxMergingFilesPerThread)/10000)) - - m.logger.Info("merge sort partSize", - zap.String("data-kv", units.BytesSize(float64(m.dataKVPartSize))), - zap.String("index-kv", units.BytesSize(float64(m.indexKVPartSize))), - ) - return nil -} - -func (m *mergeSortStepExecutor) RunSubtask(ctx context.Context, subtask *proto.Subtask) (err error) { - sm := &MergeSortStepMeta{} - err = json.Unmarshal(subtask.Meta, sm) - if err != nil { - return errors.Trace(err) - } - logger := m.logger.With(zap.Int64("subtask-id", subtask.ID), zap.String("kv-group", sm.KVGroup)) - task := log.BeginTask(logger, "run subtask") - defer func() { - task.End(zapcore.ErrorLevel, err) - }() - - var mu sync.Mutex - m.subtaskSortedKVMeta = &external.SortedKVMeta{} - onClose := func(summary *external.WriterSummary) { - mu.Lock() - defer mu.Unlock() - m.subtaskSortedKVMeta.MergeSummary(summary) - } - - prefix := subtaskPrefix(m.taskID, subtask.ID) - - partSize := m.dataKVPartSize - if sm.KVGroup != dataKVGroup { - partSize = m.indexKVPartSize - } - err = external.MergeOverlappingFiles( - logutil.WithFields(ctx, zap.String("kv-group", sm.KVGroup), zap.Int64("subtask-id", subtask.ID)), - sm.DataFiles, - m.controller.GlobalSortStore, - partSize, - prefix, - getKVGroupBlockSize(sm.KVGroup), - onClose, - subtask.Concurrency, - false) - logger.Info( - "merge sort finished", - zap.Uint64("total-kv-size", m.subtaskSortedKVMeta.TotalKVSize), - zap.Uint64("total-kv-count", m.subtaskSortedKVMeta.TotalKVCnt), - brlogutil.Key("start-key", m.subtaskSortedKVMeta.StartKey), - brlogutil.Key("end-key", m.subtaskSortedKVMeta.EndKey), - ) - return err -} - -func (m *mergeSortStepExecutor) OnFinished(_ context.Context, subtask *proto.Subtask) error { - var subtaskMeta MergeSortStepMeta - if err := json.Unmarshal(subtask.Meta, &subtaskMeta); err != nil { - return errors.Trace(err) - } - subtaskMeta.SortedKVMeta = *m.subtaskSortedKVMeta - m.subtaskSortedKVMeta = nil - newMeta, err := json.Marshal(subtaskMeta) - if err != nil { - return errors.Trace(err) - } - subtask.Meta = newMeta - return nil -} - -type writeAndIngestStepExecutor struct { - execute.StepExecFrameworkInfo - - taskID int64 - taskMeta *TaskMeta - logger *zap.Logger - tableImporter *importer.TableImporter - store tidbkv.Storage -} - -var _ execute.StepExecutor = &writeAndIngestStepExecutor{} - -func (e *writeAndIngestStepExecutor) Init(ctx context.Context) error { - tableImporter, err := getTableImporter(ctx, e.taskID, e.taskMeta, e.store) - if err != nil { - return err - } - e.tableImporter = tableImporter - return nil -} - -func (e *writeAndIngestStepExecutor) RunSubtask(ctx context.Context, 
subtask *proto.Subtask) (err error) { - sm := &WriteIngestStepMeta{} - err = json.Unmarshal(subtask.Meta, sm) - if err != nil { - return errors.Trace(err) - } - - logger := e.logger.With(zap.Int64("subtask-id", subtask.ID), - zap.String("kv-group", sm.KVGroup)) - task := log.BeginTask(logger, "run subtask") - defer func() { - task.End(zapcore.ErrorLevel, err) - }() - - _, engineUUID := backend.MakeUUID("", subtask.ID) - localBackend := e.tableImporter.Backend() - localBackend.WorkerConcurrency = subtask.Concurrency * 2 - err = localBackend.CloseEngine(ctx, &backend.EngineConfig{ - External: &backend.ExternalEngineConfig{ - StorageURI: e.taskMeta.Plan.CloudStorageURI, - DataFiles: sm.DataFiles, - StatFiles: sm.StatFiles, - StartKey: sm.StartKey, - EndKey: sm.EndKey, - SplitKeys: sm.RangeSplitKeys, - RegionSplitSize: sm.RangeSplitSize, - TotalFileSize: int64(sm.TotalKVSize), - TotalKVCount: 0, - CheckHotspot: false, - }, - TS: sm.TS, - }, engineUUID) - if err != nil { - return err - } - return localBackend.ImportEngine(ctx, engineUUID, int64(config.SplitRegionSize), int64(config.SplitRegionKeys)) -} - -func (*writeAndIngestStepExecutor) RealtimeSummary() *execute.SubtaskSummary { - return nil -} - -func (e *writeAndIngestStepExecutor) OnFinished(ctx context.Context, subtask *proto.Subtask) error { - var subtaskMeta WriteIngestStepMeta - if err := json.Unmarshal(subtask.Meta, &subtaskMeta); err != nil { - return errors.Trace(err) - } - if subtaskMeta.KVGroup != dataKVGroup { - return nil - } - - // only data kv group has loaded row count - _, engineUUID := backend.MakeUUID("", subtask.ID) - localBackend := e.tableImporter.Backend() - _, kvCount := localBackend.GetExternalEngineKVStatistics(engineUUID) - subtaskMeta.Result.LoadedRowCnt = uint64(kvCount) - err := localBackend.CleanupEngine(ctx, engineUUID) - if err != nil { - e.logger.Warn("failed to cleanup engine", zap.Error(err)) - } - - newMeta, err := json.Marshal(subtaskMeta) - if err != nil { - return errors.Trace(err) - } - subtask.Meta = newMeta - return nil -} - -func (e *writeAndIngestStepExecutor) Cleanup(_ context.Context) (err error) { - e.logger.Info("cleanup subtask env") - return e.tableImporter.Close() -} - -type postProcessStepExecutor struct { - taskexecutor.EmptyStepExecutor - taskID int64 - store tidbkv.Storage - taskMeta *TaskMeta - logger *zap.Logger -} - -var _ execute.StepExecutor = &postProcessStepExecutor{} - -// NewPostProcessStepExecutor creates a new post process step executor. -// exported for testing. -func NewPostProcessStepExecutor(taskID int64, store tidbkv.Storage, taskMeta *TaskMeta, logger *zap.Logger) execute.StepExecutor { - return &postProcessStepExecutor{ - taskID: taskID, - store: store, - taskMeta: taskMeta, - logger: logger, - } -} - -func (p *postProcessStepExecutor) RunSubtask(ctx context.Context, subtask *proto.Subtask) (err error) { - logger := p.logger.With(zap.Int64("subtask-id", subtask.ID)) - task := log.BeginTask(logger, "run subtask") - defer func() { - task.End(zapcore.ErrorLevel, err) - }() - stepMeta := PostProcessStepMeta{} - if err = json.Unmarshal(subtask.Meta, &stepMeta); err != nil { - return errors.Trace(err) - } - failpoint.Inject("waitBeforePostProcess", func() { - time.Sleep(5 * time.Second) - }) - return postProcess(ctx, p.store, p.taskMeta, &stepMeta, logger) -} - -type importExecutor struct { - *taskexecutor.BaseTaskExecutor - store tidbkv.Storage -} - -// NewImportExecutor creates a new import task executor. 
-func NewImportExecutor( - ctx context.Context, - id string, - task *proto.Task, - taskTable taskexecutor.TaskTable, - store tidbkv.Storage, -) taskexecutor.TaskExecutor { - metrics := metricsManager.getOrCreateMetrics(task.ID) - subCtx := metric.WithCommonMetric(ctx, metrics) - s := &importExecutor{ - BaseTaskExecutor: taskexecutor.NewBaseTaskExecutor(subCtx, id, task, taskTable), - store: store, - } - s.BaseTaskExecutor.Extension = s - return s -} - -func (*importExecutor) IsIdempotent(*proto.Subtask) bool { - // import don't have conflict detection and resolution now, so it's ok - // to import data twice. - return true -} - -func (*importExecutor) IsRetryableError(err error) bool { - return common.IsRetryableError(err) -} - -func (e *importExecutor) GetStepExecutor(task *proto.Task) (execute.StepExecutor, error) { - taskMeta := TaskMeta{} - if err := json.Unmarshal(task.Meta, &taskMeta); err != nil { - return nil, errors.Trace(err) - } - logger := logutil.BgLogger().With( - zap.Stringer("type", proto.ImportInto), - zap.Int64("task-id", task.ID), - zap.String("step", proto.Step2Str(task.Type, task.Step)), - ) - - switch task.Step { - case proto.ImportStepImport, proto.ImportStepEncodeAndSort: - return &importStepExecutor{ - taskID: task.ID, - taskMeta: &taskMeta, - logger: logger, - store: e.store, - }, nil - case proto.ImportStepMergeSort: - return &mergeSortStepExecutor{ - taskID: task.ID, - taskMeta: &taskMeta, - logger: logger, - }, nil - case proto.ImportStepWriteAndIngest: - return &writeAndIngestStepExecutor{ - taskID: task.ID, - taskMeta: &taskMeta, - logger: logger, - store: e.store, - }, nil - case proto.ImportStepPostProcess: - return NewPostProcessStepExecutor(task.ID, e.store, &taskMeta, logger), nil - default: - return nil, errors.Errorf("unknown step %d for import task %d", task.Step, task.ID) - } -} - -func (e *importExecutor) Close() { - task := e.GetTaskBase() - metricsManager.unregister(task.ID) - e.BaseTaskExecutor.Close() -} diff --git a/pkg/domain/binding__failpoint_binding__.go b/pkg/domain/binding__failpoint_binding__.go deleted file mode 100644 index 17ae9cb59498c..0000000000000 --- a/pkg/domain/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package domain - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/domain/domain.go b/pkg/domain/domain.go index e1889a312a056..8fe52df356fd4 100644 --- a/pkg/domain/domain.go +++ b/pkg/domain/domain.go @@ -566,22 +566,22 @@ func (do *Domain) tryLoadSchemaDiffs(builder *infoschema.Builder, m *meta.Meta, diffs = append(diffs, diff) } - if val, _err_ := failpoint.Eval(_curpkg_("MockTryLoadDiffError")); _err_ == nil { + failpoint.Inject("MockTryLoadDiffError", func(val failpoint.Value) { switch val.(string) { case "exchangepartition": if diffs[0].Type == model.ActionExchangeTablePartition { - return nil, nil, nil, errors.New("mock error") + failpoint.Return(nil, nil, nil, errors.New("mock error")) } case "renametable": if diffs[0].Type == model.ActionRenameTable { - return nil, nil, nil, errors.New("mock error") + failpoint.Return(nil, nil, nil, errors.New("mock error")) } case "dropdatabase": if diffs[0].Type == model.ActionDropSchema { - return nil, nil, nil, errors.New("mock error") + failpoint.Return(nil, nil, 
nil, errors.New("mock error")) } } - } + }) err := builder.InitWithOldInfoSchema(do.infoCache.GetLatest()) if err != nil { @@ -720,11 +720,11 @@ func getFlashbackStartTSFromErrorMsg(err error) uint64 { // Reload reloads InfoSchema. // It's public in order to do the test. func (do *Domain) Reload() error { - if val, _err_ := failpoint.Eval(_curpkg_("ErrorMockReloadFailed")); _err_ == nil { + failpoint.Inject("ErrorMockReloadFailed", func(val failpoint.Value) { if val.(bool) { - return errors.New("mock reload failed") + failpoint.Return(errors.New("mock reload failed")) } - } + }) // Lock here for only once at the same time. do.m.Lock() @@ -1048,9 +1048,9 @@ func (do *Domain) loadSchemaInLoop(ctx context.Context, lease time.Duration) { for { select { case <-ticker.C: - if _, _err_ := failpoint.Eval(_curpkg_("disableOnTickReload")); _err_ == nil { - continue - } + failpoint.Inject("disableOnTickReload", func() { + failpoint.Continue() + }) err := do.Reload() if err != nil { logutil.BgLogger().Error("reload schema in loop failed", zap.Error(err)) @@ -1346,12 +1346,12 @@ func (do *Domain) Init( ddl.WithSchemaLoader(do), ) - if val, _err_ := failpoint.Eval(_curpkg_("MockReplaceDDL")); _err_ == nil { + failpoint.Inject("MockReplaceDDL", func(val failpoint.Value) { if val.(bool) { do.ddl = d do.ddlExecutor = eBak } - } + }) if ddlInjector != nil { checker := ddlInjector(do.ddl, do.ddlExecutor, do.infoCache) checker.CreateTestDB(nil) @@ -1645,11 +1645,11 @@ func (do *Domain) checkReplicaRead(ctx context.Context, pdClient pd.Client) erro // InitDistTaskLoop initializes the distributed task framework. func (do *Domain) InitDistTaskLoop() error { ctx := kv.WithInternalSourceType(context.Background(), kv.InternalDistTask) - if val, _err_ := failpoint.Eval(_curpkg_("MockDisableDistTask")); _err_ == nil { + failpoint.Inject("MockDisableDistTask", func(val failpoint.Value) { if val.(bool) { - return nil + failpoint.Return(nil) } - } + }) taskManager := storage.NewTaskManager(do.sysSessionPool) var serverID string @@ -1857,7 +1857,7 @@ func (do *Domain) LoadSysVarCacheLoop(ctx sessionctx.Context) error { case <-time.After(duration): } - if val, _err_ := failpoint.Eval(_curpkg_("skipLoadSysVarCacheLoop")); _err_ == nil { + failpoint.Inject("skipLoadSysVarCacheLoop", func(val failpoint.Value) { // In some pkg integration test, there are many testSuite, and each testSuite has separate storage and // `LoadSysVarCacheLoop` background goroutine. Then each testSuite `RebuildSysVarCache` from it's // own storage. @@ -1865,9 +1865,9 @@ func (do *Domain) LoadSysVarCacheLoop(ctx sessionctx.Context) error { // That's the problem, each testSuit use different storage to update some same local variables. // So just skip `RebuildSysVarCache` in some integration testing. if val.(bool) { - continue + failpoint.Continue() } - } + }) if !ok { logutil.BgLogger().Error("LoadSysVarCacheLoop loop watch channel closed") diff --git a/pkg/domain/domain.go__failpoint_stash__ b/pkg/domain/domain.go__failpoint_stash__ deleted file mode 100644 index 8fe52df356fd4..0000000000000 --- a/pkg/domain/domain.go__failpoint_stash__ +++ /dev/null @@ -1,3239 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package domain - -import ( - "context" - "fmt" - "math" - "math/rand" - "sort" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/ngaut/pools" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/kvproto/pkg/pdpb" - "github.com/pingcap/log" - "github.com/pingcap/tidb/br/pkg/streamhelper" - "github.com/pingcap/tidb/br/pkg/streamhelper/daemon" - "github.com/pingcap/tidb/pkg/bindinfo" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/ddl" - "github.com/pingcap/tidb/pkg/ddl/placement" - "github.com/pingcap/tidb/pkg/ddl/schematracker" - "github.com/pingcap/tidb/pkg/ddl/systable" - ddlutil "github.com/pingcap/tidb/pkg/ddl/util" - "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" - "github.com/pingcap/tidb/pkg/disttask/framework/storage" - "github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor" - "github.com/pingcap/tidb/pkg/domain/globalconfigsync" - "github.com/pingcap/tidb/pkg/domain/infosync" - "github.com/pingcap/tidb/pkg/domain/resourcegroup" - "github.com/pingcap/tidb/pkg/errno" - "github.com/pingcap/tidb/pkg/infoschema" - infoschema_metrics "github.com/pingcap/tidb/pkg/infoschema/metrics" - "github.com/pingcap/tidb/pkg/infoschema/perfschema" - "github.com/pingcap/tidb/pkg/keyspace" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/meta/autoid" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/owner" - "github.com/pingcap/tidb/pkg/parser" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - metrics2 "github.com/pingcap/tidb/pkg/planner/core/metrics" - "github.com/pingcap/tidb/pkg/privilege/privileges" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/sessionstates" - "github.com/pingcap/tidb/pkg/sessionctx/sysproctrack" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/statistics/handle" - statslogutil "github.com/pingcap/tidb/pkg/statistics/handle/logutil" - "github.com/pingcap/tidb/pkg/store/helper" - "github.com/pingcap/tidb/pkg/ttl/ttlworker" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/dbterror" - disttaskutil "github.com/pingcap/tidb/pkg/util/disttask" - "github.com/pingcap/tidb/pkg/util/domainutil" - "github.com/pingcap/tidb/pkg/util/engine" - "github.com/pingcap/tidb/pkg/util/etcd" - "github.com/pingcap/tidb/pkg/util/expensivequery" - "github.com/pingcap/tidb/pkg/util/gctuner" - "github.com/pingcap/tidb/pkg/util/globalconn" - "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "github.com/pingcap/tidb/pkg/util/memoryusagealarm" - "github.com/pingcap/tidb/pkg/util/replayer" - "github.com/pingcap/tidb/pkg/util/servermemorylimit" - "github.com/pingcap/tidb/pkg/util/sqlkiller" - "github.com/pingcap/tidb/pkg/util/syncutil" - 
"github.com/tikv/client-go/v2/tikv" - "github.com/tikv/client-go/v2/txnkv/transaction" - pd "github.com/tikv/pd/client" - pdhttp "github.com/tikv/pd/client/http" - rmclient "github.com/tikv/pd/client/resource_group/controller" - clientv3 "go.etcd.io/etcd/client/v3" - "go.etcd.io/etcd/client/v3/concurrency" - atomicutil "go.uber.org/atomic" - "go.uber.org/zap" - "google.golang.org/grpc" - "google.golang.org/grpc/backoff" - "google.golang.org/grpc/keepalive" -) - -var ( - mdlCheckLookDuration = 50 * time.Millisecond - - // LoadSchemaDiffVersionGapThreshold is the threshold for version gap to reload domain by loading schema diffs - LoadSchemaDiffVersionGapThreshold int64 = 10000 - - // NewInstancePlanCache creates a new instance level plan cache, this function is designed to avoid cycle-import. - NewInstancePlanCache func(softMemLimit, hardMemLimit int64) sessionctx.InstancePlanCache -) - -const ( - indexUsageGCDuration = 30 * time.Minute -) - -func init() { - if intest.InTest { - // In test we can set duration lower to make test faster. - mdlCheckLookDuration = 2 * time.Millisecond - } -} - -// NewMockDomain is only used for test -func NewMockDomain() *Domain { - do := &Domain{} - do.infoCache = infoschema.NewCache(do, 1) - do.infoCache.Insert(infoschema.MockInfoSchema(nil), 0) - return do -} - -// Domain represents a storage space. Different domains can use the same database name. -// Multiple domains can be used in parallel without synchronization. -type Domain struct { - store kv.Storage - infoCache *infoschema.InfoCache - privHandle *privileges.Handle - bindHandle atomic.Value - statsHandle atomic.Pointer[handle.Handle] - statsLease time.Duration - ddl ddl.DDL - ddlExecutor ddl.Executor - info *infosync.InfoSyncer - globalCfgSyncer *globalconfigsync.GlobalConfigSyncer - m syncutil.Mutex - SchemaValidator SchemaValidator - sysSessionPool util.SessionPool - exit chan struct{} - // `etcdClient` must be used when keyspace is not set, or when the logic to each etcd path needs to be separated by keyspace. - etcdClient *clientv3.Client - // autoidClient is used when there are tables with AUTO_ID_CACHE=1, it is the client to the autoid service. - autoidClient *autoid.ClientDiscover - // `unprefixedEtcdCli` will never set the etcd namespace prefix by keyspace. - // It is only used in storeMinStartTS and RemoveMinStartTS now. - // It must be used when the etcd path isn't needed to separate by keyspace. - // See keyspace RFC: https://github.com/pingcap/tidb/pull/39685 - unprefixedEtcdCli *clientv3.Client - sysVarCache sysVarCache // replaces GlobalVariableCache - slowQuery *topNSlowQueries - expensiveQueryHandle *expensivequery.Handle - memoryUsageAlarmHandle *memoryusagealarm.Handle - serverMemoryLimitHandle *servermemorylimit.Handle - // TODO: use Run for each process in future pr - wg *util.WaitGroupEnhancedWrapper - statsUpdating atomicutil.Int32 - // this is the parent context of DDL, and also used by other loops such as closestReplicaReadCheckLoop. - // there are other top level contexts in the domain, such as the ones used in - // InitDistTaskLoop and loadStatsWorker, domain only stores the cancelFns of them. - // TODO unify top level context. 
- ctx context.Context - cancelFns struct { - mu sync.Mutex - fns []context.CancelFunc - } - dumpFileGcChecker *dumpFileGcChecker - planReplayerHandle *planReplayerHandle - extractTaskHandle *ExtractHandle - expiredTimeStamp4PC struct { - // let `expiredTimeStamp4PC` use its own lock to avoid any block across domain.Reload() - // and compiler.Compile(), see issue https://github.com/pingcap/tidb/issues/45400 - sync.RWMutex - expiredTimeStamp types.Time - } - - logBackupAdvancer *daemon.OwnerDaemon - historicalStatsWorker *HistoricalStatsWorker - ttlJobManager atomic.Pointer[ttlworker.JobManager] - runawayManager *resourcegroup.RunawayManager - runawaySyncer *runawaySyncer - resourceGroupsController *rmclient.ResourceGroupsController - - serverID uint64 - serverIDSession *concurrency.Session - isLostConnectionToPD atomicutil.Int32 // !0: true, 0: false. - connIDAllocator globalconn.Allocator - - onClose func() - sysExecutorFactory func(*Domain) (pools.Resource, error) - - sysProcesses SysProcesses - - mdlCheckTableInfo *mdlCheckTableInfo - - analyzeMu struct { - sync.Mutex - sctxs map[sessionctx.Context]bool - } - - mdlCheckCh chan struct{} - stopAutoAnalyze atomicutil.Bool - minJobIDRefresher *systable.MinJobIDRefresher - - instancePlanCache sessionctx.InstancePlanCache // the instance level plan cache - - // deferFn is used to release infoschema object lazily during v1 and v2 switch - deferFn -} - -type deferFn struct { - sync.Mutex - data []deferFnRecord -} - -type deferFnRecord struct { - fn func() - fire time.Time -} - -func (df *deferFn) add(fn func(), fire time.Time) { - df.Lock() - defer df.Unlock() - df.data = append(df.data, deferFnRecord{fn: fn, fire: fire}) -} - -func (df *deferFn) check() { - now := time.Now() - df.Lock() - defer df.Unlock() - - // iterate the slice, call the defer function and remove it. - rm := 0 - for i := 0; i < len(df.data); i++ { - record := &df.data[i] - if now.After(record.fire) { - record.fn() - rm++ - } else { - df.data[i-rm] = df.data[i] - } - } - df.data = df.data[:len(df.data)-rm] -} - -type mdlCheckTableInfo struct { - mu sync.Mutex - newestVer int64 - jobsVerMap map[int64]int64 - jobsIDsMap map[int64]string -} - -// InfoCache export for test. -func (do *Domain) InfoCache() *infoschema.InfoCache { - return do.infoCache -} - -// EtcdClient export for test. -func (do *Domain) EtcdClient() *clientv3.Client { - return do.etcdClient -} - -// loadInfoSchema loads infoschema at startTS. -// It returns: -// 1. the needed infoschema -// 2. cache hit indicator -// 3. currentSchemaVersion(before loading) -// 4. the changed table IDs if it is not full load -// 5. an error if any -func (do *Domain) loadInfoSchema(startTS uint64, isSnapshot bool) (infoschema.InfoSchema, bool, int64, *transaction.RelatedSchemaChange, error) { - beginTime := time.Now() - defer func() { - infoschema_metrics.LoadSchemaDurationTotal.Observe(time.Since(beginTime).Seconds()) - }() - snapshot := do.store.GetSnapshot(kv.NewVersion(startTS)) - // Using the KV timeout read feature to address the issue of potential DDL lease expiration when - // the meta region leader is slow. - snapshot.SetOption(kv.TiKVClientReadTimeout, uint64(3000)) // 3000ms. 
- m := meta.NewSnapshotMeta(snapshot) - neededSchemaVersion, err := m.GetSchemaVersionWithNonEmptyDiff() - if err != nil { - return nil, false, 0, nil, err - } - // fetch the commit timestamp of the schema diff - schemaTs, err := do.getTimestampForSchemaVersionWithNonEmptyDiff(m, neededSchemaVersion, startTS) - if err != nil { - logutil.BgLogger().Warn("failed to get schema version", zap.Error(err), zap.Int64("version", neededSchemaVersion)) - schemaTs = 0 - } - - if is := do.infoCache.GetByVersion(neededSchemaVersion); is != nil { - isV2, raw := infoschema.IsV2(is) - if isV2 { - // Copy the infoschema V2 instance and update its ts. - // For example, the DDL run 30 minutes ago, GC happened 10 minutes ago. If we use - // that infoschema it would get error "GC life time is shorter than transaction - // duration" when visiting TiKV. - // So we keep updating the ts of the infoschema v2. - is = raw.CloneAndUpdateTS(startTS) - } - - // try to insert here as well to correct the schemaTs if previous is wrong - // the insert method check if schemaTs is zero - do.infoCache.Insert(is, schemaTs) - - return is, true, 0, nil, nil - } - - var oldIsV2 bool - enableV2 := variable.SchemaCacheSize.Load() > 0 - currentSchemaVersion := int64(0) - if oldInfoSchema := do.infoCache.GetLatest(); oldInfoSchema != nil { - currentSchemaVersion = oldInfoSchema.SchemaMetaVersion() - oldIsV2, _ = infoschema.IsV2(oldInfoSchema) - } - useV2, isV1V2Switch := shouldUseV2(enableV2, oldIsV2, isSnapshot) - builder := infoschema.NewBuilder(do, do.sysFacHack, do.infoCache.Data, useV2) - - // TODO: tryLoadSchemaDiffs has potential risks of failure. And it becomes worse in history reading cases. - // It is only kept because there is no alternative diff/partial loading solution. - // And it is only used to diff upgrading the current latest infoschema, if: - // 1. Not first time bootstrap loading, which needs a full load. - // 2. It is newer than the current one, so it will be "the current one" after this function call. - // 3. There are less 100 diffs. - // 4. No regenerated schema diff. - startTime := time.Now() - if !isV1V2Switch && currentSchemaVersion != 0 && neededSchemaVersion > currentSchemaVersion && neededSchemaVersion-currentSchemaVersion < LoadSchemaDiffVersionGapThreshold { - is, relatedChanges, diffTypes, err := do.tryLoadSchemaDiffs(builder, m, currentSchemaVersion, neededSchemaVersion, startTS) - if err == nil { - infoschema_metrics.LoadSchemaDurationLoadDiff.Observe(time.Since(startTime).Seconds()) - isV2, _ := infoschema.IsV2(is) - do.infoCache.Insert(is, schemaTs) - logutil.BgLogger().Info("diff load InfoSchema success", - zap.Bool("isV2", isV2), - zap.Int64("currentSchemaVersion", currentSchemaVersion), - zap.Int64("neededSchemaVersion", neededSchemaVersion), - zap.Duration("elapsed time", time.Since(startTime)), - zap.Int64("gotSchemaVersion", is.SchemaMetaVersion()), - zap.Int64s("phyTblIDs", relatedChanges.PhyTblIDS), - zap.Uint64s("actionTypes", relatedChanges.ActionTypes), - zap.Strings("diffTypes", diffTypes)) - return is, false, currentSchemaVersion, relatedChanges, nil - } - // We can fall back to full load, don't need to return the error. - logutil.BgLogger().Error("failed to load schema diff", zap.Error(err)) - } - // full load. 
- schemas, err := do.fetchAllSchemasWithTables(m) - if err != nil { - return nil, false, currentSchemaVersion, nil, err - } - - policies, err := do.fetchPolicies(m) - if err != nil { - return nil, false, currentSchemaVersion, nil, err - } - - resourceGroups, err := do.fetchResourceGroups(m) - if err != nil { - return nil, false, currentSchemaVersion, nil, err - } - infoschema_metrics.LoadSchemaDurationLoadAll.Observe(time.Since(startTime).Seconds()) - - err = builder.InitWithDBInfos(schemas, policies, resourceGroups, neededSchemaVersion) - if err != nil { - return nil, false, currentSchemaVersion, nil, err - } - is := builder.Build(startTS) - isV2, _ := infoschema.IsV2(is) - logutil.BgLogger().Info("full load InfoSchema success", - zap.Bool("isV2", isV2), - zap.Int64("currentSchemaVersion", currentSchemaVersion), - zap.Int64("neededSchemaVersion", neededSchemaVersion), - zap.Duration("elapsed time", time.Since(startTime))) - - if isV1V2Switch && schemaTs > 0 { - // Reset the whole info cache to avoid co-existing of both v1 and v2, causing the memory usage doubled. - fn := do.infoCache.Upsert(is, schemaTs) - do.deferFn.add(fn, time.Now().Add(10*time.Minute)) - } else { - do.infoCache.Insert(is, schemaTs) - } - return is, false, currentSchemaVersion, nil, nil -} - -// Returns the timestamp of a schema version, which is the commit timestamp of the schema diff -func (do *Domain) getTimestampForSchemaVersionWithNonEmptyDiff(m *meta.Meta, version int64, startTS uint64) (uint64, error) { - tikvStore, ok := do.Store().(helper.Storage) - if ok { - newHelper := helper.NewHelper(tikvStore) - mvccResp, err := newHelper.GetMvccByEncodedKeyWithTS(m.EncodeSchemaDiffKey(version), startTS) - if err != nil { - return 0, err - } - if mvccResp == nil || mvccResp.Info == nil || len(mvccResp.Info.Writes) == 0 { - return 0, errors.Errorf("There is no Write MVCC info for the schema version") - } - return mvccResp.Info.Writes[0].CommitTs, nil - } - return 0, errors.Errorf("cannot get store from domain") -} - -func (do *Domain) sysFacHack() (pools.Resource, error) { - // TODO: Here we create new sessions with sysFac in DDL, - // which will use `do` as Domain instead of call `domap.Get`. - // That's because `domap.Get` requires a lock, but before - // we initialize Domain finish, we can't require that again. - // After we remove the lazy logic of creating Domain, we - // can simplify code here. 
- return do.sysExecutorFactory(do) -} - -func (*Domain) fetchPolicies(m *meta.Meta) ([]*model.PolicyInfo, error) { - allPolicies, err := m.ListPolicies() - if err != nil { - return nil, err - } - return allPolicies, nil -} - -func (*Domain) fetchResourceGroups(m *meta.Meta) ([]*model.ResourceGroupInfo, error) { - allResourceGroups, err := m.ListResourceGroups() - if err != nil { - return nil, err - } - return allResourceGroups, nil -} - -func (do *Domain) fetchAllSchemasWithTables(m *meta.Meta) ([]*model.DBInfo, error) { - allSchemas, err := m.ListDatabases() - if err != nil { - return nil, err - } - splittedSchemas := do.splitForConcurrentFetch(allSchemas) - doneCh := make(chan error, len(splittedSchemas)) - for _, schemas := range splittedSchemas { - go do.fetchSchemasWithTables(schemas, m, doneCh) - } - for range splittedSchemas { - err = <-doneCh - if err != nil { - return nil, err - } - } - return allSchemas, nil -} - -// fetchSchemaConcurrency controls the goroutines to load schemas, but more goroutines -// increase the memory usage when calling json.Unmarshal(), which would cause OOM, -// so we decrease the concurrency. -const fetchSchemaConcurrency = 1 - -func (*Domain) splitForConcurrentFetch(schemas []*model.DBInfo) [][]*model.DBInfo { - groupSize := (len(schemas) + fetchSchemaConcurrency - 1) / fetchSchemaConcurrency - if variable.SchemaCacheSize.Load() > 0 && len(schemas) > 1000 { - // TODO: Temporary solution to speed up when too many databases, will refactor it later. - groupSize = 8 - } - splitted := make([][]*model.DBInfo, 0, fetchSchemaConcurrency) - schemaCnt := len(schemas) - for i := 0; i < schemaCnt; i += groupSize { - end := i + groupSize - if end > schemaCnt { - end = schemaCnt - } - splitted = append(splitted, schemas[i:end]) - } - return splitted -} - -func (*Domain) fetchSchemasWithTables(schemas []*model.DBInfo, m *meta.Meta, done chan error) { - for _, di := range schemas { - if di.State != model.StatePublic { - // schema is not public, can't be used outside. - continue - } - var tables []*model.TableInfo - var err error - if variable.SchemaCacheSize.Load() > 0 && !infoschema.IsSpecialDB(di.Name.L) { - name2ID, specialTableInfos, err := meta.GetAllNameToIDAndTheMustLoadedTableInfo(m, di.ID) - if err != nil { - done <- err - return - } - di.TableName2ID = name2ID - tables = specialTableInfos - } else { - tables, err = m.ListTables(di.ID) - if err != nil { - done <- err - return - } - } - // If TreatOldVersionUTF8AsUTF8MB4 was enable, need to convert the old version schema UTF8 charset to UTF8MB4. - if config.GetGlobalConfig().TreatOldVersionUTF8AsUTF8MB4 { - for _, tbInfo := range tables { - infoschema.ConvertOldVersionUTF8ToUTF8MB4IfNeed(tbInfo) - } - } - diTables := make([]*model.TableInfo, 0, len(tables)) - for _, tbl := range tables { - if tbl.State != model.StatePublic { - // schema is not public, can't be used outside. - continue - } - infoschema.ConvertCharsetCollateToLowerCaseIfNeed(tbl) - // Check whether the table is in repair mode. - if domainutil.RepairInfo.InRepairMode() && domainutil.RepairInfo.CheckAndFetchRepairedTable(di, tbl) { - if tbl.State != model.StatePublic { - // Do not load it because we are reparing the table and the table info could be `bad` - // before repair is done. - continue - } - // If the state is public, it means that the DDL job is done, but the table - // haven't been deleted from the repair table list. - // Since the repairment is done and table is visible, we should load it. 
- } - diTables = append(diTables, tbl) - } - di.Deprecated.Tables = diTables - } - done <- nil -} - -// shouldUseV2 decides whether to use infoschema v2. -// When loading snapshot, infoschema should keep the same as before to avoid v1/v2 switch. -// Otherwise, it is decided by enabledV2. -func shouldUseV2(enableV2 bool, oldIsV2 bool, isSnapshot bool) (useV2 bool, isV1V2Switch bool) { - if isSnapshot { - return oldIsV2, false - } - return enableV2, enableV2 != oldIsV2 -} - -// tryLoadSchemaDiffs tries to only load latest schema changes. -// Return true if the schema is loaded successfully. -// Return false if the schema can not be loaded by schema diff, then we need to do full load. -// The second returned value is the delta updated table and partition IDs. -func (do *Domain) tryLoadSchemaDiffs(builder *infoschema.Builder, m *meta.Meta, usedVersion, newVersion int64, startTS uint64) (infoschema.InfoSchema, *transaction.RelatedSchemaChange, []string, error) { - var diffs []*model.SchemaDiff - for usedVersion < newVersion { - usedVersion++ - diff, err := m.GetSchemaDiff(usedVersion) - if err != nil { - return nil, nil, nil, err - } - if diff == nil { - // Empty diff means the txn of generating schema version is committed, but the txn of `runDDLJob` is not or fail. - // It is safe to skip the empty diff because the infoschema is new enough and consistent. - logutil.BgLogger().Info("diff load InfoSchema get empty schema diff", zap.Int64("version", usedVersion)) - do.infoCache.InsertEmptySchemaVersion(usedVersion) - continue - } - diffs = append(diffs, diff) - } - - failpoint.Inject("MockTryLoadDiffError", func(val failpoint.Value) { - switch val.(string) { - case "exchangepartition": - if diffs[0].Type == model.ActionExchangeTablePartition { - failpoint.Return(nil, nil, nil, errors.New("mock error")) - } - case "renametable": - if diffs[0].Type == model.ActionRenameTable { - failpoint.Return(nil, nil, nil, errors.New("mock error")) - } - case "dropdatabase": - if diffs[0].Type == model.ActionDropSchema { - failpoint.Return(nil, nil, nil, errors.New("mock error")) - } - } - }) - - err := builder.InitWithOldInfoSchema(do.infoCache.GetLatest()) - if err != nil { - return nil, nil, nil, errors.Trace(err) - } - - builder.WithStore(do.store).SetDeltaUpdateBundles() - phyTblIDs := make([]int64, 0, len(diffs)) - actions := make([]uint64, 0, len(diffs)) - diffTypes := make([]string, 0, len(diffs)) - for _, diff := range diffs { - if diff.RegenerateSchemaMap { - return nil, nil, nil, errors.Errorf("Meets a schema diff with RegenerateSchemaMap flag") - } - ids, err := builder.ApplyDiff(m, diff) - if err != nil { - return nil, nil, nil, err - } - if canSkipSchemaCheckerDDL(diff.Type) { - continue - } - diffTypes = append(diffTypes, diff.Type.String()) - phyTblIDs = append(phyTblIDs, ids...) - for i := 0; i < len(ids); i++ { - actions = append(actions, uint64(diff.Type)) - } - } - - is := builder.Build(startTS) - relatedChange := transaction.RelatedSchemaChange{} - relatedChange.PhyTblIDS = phyTblIDs - relatedChange.ActionTypes = actions - return is, &relatedChange, diffTypes, nil -} - -func canSkipSchemaCheckerDDL(tp model.ActionType) bool { - switch tp { - case model.ActionUpdateTiFlashReplicaStatus, model.ActionSetTiFlashReplica: - return true - } - return false -} - -// InfoSchema gets the latest information schema from domain. -func (do *Domain) InfoSchema() infoschema.InfoSchema { - return do.infoCache.GetLatest() -} - -// GetSnapshotInfoSchema gets a snapshot information schema. 
-func (do *Domain) GetSnapshotInfoSchema(snapshotTS uint64) (infoschema.InfoSchema, error) { - // if the snapshotTS is new enough, we can get infoschema directly through snapshotTS. - if is := do.infoCache.GetBySnapshotTS(snapshotTS); is != nil { - return is, nil - } - is, _, _, _, err := do.loadInfoSchema(snapshotTS, true) - infoschema_metrics.LoadSchemaCounterSnapshot.Inc() - return is, err -} - -// GetSnapshotMeta gets a new snapshot meta at startTS. -func (do *Domain) GetSnapshotMeta(startTS uint64) *meta.Meta { - snapshot := do.store.GetSnapshot(kv.NewVersion(startTS)) - return meta.NewSnapshotMeta(snapshot) -} - -// ExpiredTimeStamp4PC gets expiredTimeStamp4PC from domain. -func (do *Domain) ExpiredTimeStamp4PC() types.Time { - do.expiredTimeStamp4PC.RLock() - defer do.expiredTimeStamp4PC.RUnlock() - - return do.expiredTimeStamp4PC.expiredTimeStamp -} - -// SetExpiredTimeStamp4PC sets the expiredTimeStamp4PC from domain. -func (do *Domain) SetExpiredTimeStamp4PC(time types.Time) { - do.expiredTimeStamp4PC.Lock() - defer do.expiredTimeStamp4PC.Unlock() - - do.expiredTimeStamp4PC.expiredTimeStamp = time -} - -// DDL gets DDL from domain. -func (do *Domain) DDL() ddl.DDL { - return do.ddl -} - -// DDLExecutor gets the ddl executor from domain. -func (do *Domain) DDLExecutor() ddl.Executor { - return do.ddlExecutor -} - -// SetDDL sets DDL to domain, it's only used in tests. -func (do *Domain) SetDDL(d ddl.DDL, executor ddl.Executor) { - do.ddl = d - do.ddlExecutor = executor -} - -// InfoSyncer gets infoSyncer from domain. -func (do *Domain) InfoSyncer() *infosync.InfoSyncer { - return do.info -} - -// NotifyGlobalConfigChange notify global config syncer to store the global config into PD. -func (do *Domain) NotifyGlobalConfigChange(name, value string) { - do.globalCfgSyncer.Notify(pd.GlobalConfigItem{Name: name, Value: value, EventType: pdpb.EventType_PUT}) -} - -// GetGlobalConfigSyncer exports for testing. -func (do *Domain) GetGlobalConfigSyncer() *globalconfigsync.GlobalConfigSyncer { - return do.globalCfgSyncer -} - -// Store gets KV store from domain. -func (do *Domain) Store() kv.Storage { - return do.store -} - -// GetScope gets the status variables scope. -func (*Domain) GetScope(string) variable.ScopeFlag { - // Now domain status variables scope are all default scope. - return variable.DefaultStatusVarScopeFlag -} - -func getFlashbackStartTSFromErrorMsg(err error) uint64 { - slices := strings.Split(err.Error(), "is in flashback progress, FlashbackStartTS is ") - if len(slices) != 2 { - return 0 - } - version, err := strconv.ParseUint(slices[1], 10, 0) - if err != nil { - return 0 - } - return version -} - -// Reload reloads InfoSchema. -// It's public in order to do the test. -func (do *Domain) Reload() error { - failpoint.Inject("ErrorMockReloadFailed", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(errors.New("mock reload failed")) - } - }) - - // Lock here for only once at the same time. 
- do.m.Lock() - defer do.m.Unlock() - - startTime := time.Now() - ver, err := do.store.CurrentVersion(kv.GlobalTxnScope) - if err != nil { - return err - } - - version := ver.Ver - is, hitCache, oldSchemaVersion, changes, err := do.loadInfoSchema(version, false) - if err != nil { - if version = getFlashbackStartTSFromErrorMsg(err); version != 0 { - // use the latest available version to create domain - version-- - is, hitCache, oldSchemaVersion, changes, err = do.loadInfoSchema(version, false) - } - } - if err != nil { - metrics.LoadSchemaCounter.WithLabelValues("failed").Inc() - return err - } - metrics.LoadSchemaCounter.WithLabelValues("succ").Inc() - - // only update if it is not from cache - if !hitCache { - // loaded newer schema - if oldSchemaVersion < is.SchemaMetaVersion() { - // Update self schema version to etcd. - err = do.ddl.SchemaSyncer().UpdateSelfVersion(context.Background(), 0, is.SchemaMetaVersion()) - if err != nil { - logutil.BgLogger().Info("update self version failed", - zap.Int64("oldSchemaVersion", oldSchemaVersion), - zap.Int64("neededSchemaVersion", is.SchemaMetaVersion()), zap.Error(err)) - } - } - - // it is full load - if changes == nil { - logutil.BgLogger().Info("full load and reset schema validator") - do.SchemaValidator.Reset() - } - } - - // lease renew, so it must be executed despite it is cache or not - do.SchemaValidator.Update(version, oldSchemaVersion, is.SchemaMetaVersion(), changes) - lease := do.DDL().GetLease() - sub := time.Since(startTime) - // Reload interval is lease / 2, if load schema time elapses more than this interval, - // some query maybe responded by ErrInfoSchemaExpired error. - if sub > (lease/2) && lease > 0 { - logutil.BgLogger().Warn("loading schema takes a long time", zap.Duration("take time", sub)) - } - - return nil -} - -// LogSlowQuery keeps topN recent slow queries in domain. -func (do *Domain) LogSlowQuery(query *SlowQueryInfo) { - do.slowQuery.mu.RLock() - defer do.slowQuery.mu.RUnlock() - if do.slowQuery.mu.closed { - return - } - - select { - case do.slowQuery.ch <- query: - default: - } -} - -// ShowSlowQuery returns the slow queries. 
-func (do *Domain) ShowSlowQuery(showSlow *ast.ShowSlow) []*SlowQueryInfo { - msg := &showSlowMessage{ - request: showSlow, - } - msg.Add(1) - do.slowQuery.msgCh <- msg - msg.Wait() - return msg.result -} - -func (do *Domain) topNSlowQueryLoop() { - defer util.Recover(metrics.LabelDomain, "topNSlowQueryLoop", nil, false) - ticker := time.NewTicker(time.Minute * 10) - defer func() { - ticker.Stop() - logutil.BgLogger().Info("topNSlowQueryLoop exited.") - }() - for { - select { - case now := <-ticker.C: - do.slowQuery.RemoveExpired(now) - case info, ok := <-do.slowQuery.ch: - if !ok { - return - } - do.slowQuery.Append(info) - case msg := <-do.slowQuery.msgCh: - req := msg.request - switch req.Tp { - case ast.ShowSlowTop: - msg.result = do.slowQuery.QueryTop(int(req.Count), req.Kind) - case ast.ShowSlowRecent: - msg.result = do.slowQuery.QueryRecent(int(req.Count)) - default: - msg.result = do.slowQuery.QueryAll() - } - msg.Done() - } - } -} - -func (do *Domain) infoSyncerKeeper() { - defer func() { - logutil.BgLogger().Info("infoSyncerKeeper exited.") - }() - - defer util.Recover(metrics.LabelDomain, "infoSyncerKeeper", nil, false) - - ticker := time.NewTicker(infosync.ReportInterval) - defer ticker.Stop() - for { - select { - case <-ticker.C: - do.info.ReportMinStartTS(do.Store()) - case <-do.info.Done(): - logutil.BgLogger().Info("server info syncer need to restart") - if err := do.info.Restart(context.Background()); err != nil { - logutil.BgLogger().Error("server info syncer restart failed", zap.Error(err)) - } else { - logutil.BgLogger().Info("server info syncer restarted") - } - case <-do.exit: - return - } - } -} - -func (do *Domain) globalConfigSyncerKeeper() { - defer func() { - logutil.BgLogger().Info("globalConfigSyncerKeeper exited.") - }() - - defer util.Recover(metrics.LabelDomain, "globalConfigSyncerKeeper", nil, false) - - for { - select { - case entry := <-do.globalCfgSyncer.NotifyCh: - err := do.globalCfgSyncer.StoreGlobalConfig(context.Background(), entry) - if err != nil { - logutil.BgLogger().Error("global config syncer store failed", zap.Error(err)) - } - // TODO(crazycs520): Add owner to maintain global config is consistency with global variable. - case <-do.exit: - return - } - } -} - -func (do *Domain) topologySyncerKeeper() { - defer util.Recover(metrics.LabelDomain, "topologySyncerKeeper", nil, false) - ticker := time.NewTicker(infosync.TopologyTimeToRefresh) - defer func() { - ticker.Stop() - logutil.BgLogger().Info("topologySyncerKeeper exited.") - }() - - for { - select { - case <-ticker.C: - err := do.info.StoreTopologyInfo(context.Background()) - if err != nil { - logutil.BgLogger().Error("refresh topology in loop failed", zap.Error(err)) - } - case <-do.info.TopologyDone(): - logutil.BgLogger().Info("server topology syncer need to restart") - if err := do.info.RestartTopology(context.Background()); err != nil { - logutil.BgLogger().Error("server topology syncer restart failed", zap.Error(err)) - } else { - logutil.BgLogger().Info("server topology syncer restarted") - } - case <-do.exit: - return - } - } -} - -func (do *Domain) refreshMDLCheckTableInfo() { - se, err := do.sysSessionPool.Get() - - if err != nil { - logutil.BgLogger().Warn("get system session failed", zap.Error(err)) - return - } - // Make sure the session is new. 
- sctx := se.(sessionctx.Context) - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnMeta) - if _, err := sctx.GetSQLExecutor().ExecuteInternal(ctx, "rollback"); err != nil { - se.Close() - return - } - defer do.sysSessionPool.Put(se) - exec := sctx.GetRestrictedSQLExecutor() - domainSchemaVer := do.InfoSchema().SchemaMetaVersion() - // the job must stay inside tidb_ddl_job if we need to wait schema version for it. - sql := fmt.Sprintf(`select job_id, version, table_ids from mysql.tidb_mdl_info - where job_id >= %d and version <= %d`, do.minJobIDRefresher.GetCurrMinJobID(), domainSchemaVer) - rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql) - if err != nil { - logutil.BgLogger().Warn("get mdl info from tidb_mdl_info failed", zap.Error(err)) - return - } - do.mdlCheckTableInfo.mu.Lock() - defer do.mdlCheckTableInfo.mu.Unlock() - - do.mdlCheckTableInfo.newestVer = domainSchemaVer - do.mdlCheckTableInfo.jobsVerMap = make(map[int64]int64, len(rows)) - do.mdlCheckTableInfo.jobsIDsMap = make(map[int64]string, len(rows)) - for i := 0; i < len(rows); i++ { - do.mdlCheckTableInfo.jobsVerMap[rows[i].GetInt64(0)] = rows[i].GetInt64(1) - do.mdlCheckTableInfo.jobsIDsMap[rows[i].GetInt64(0)] = rows[i].GetString(2) - } -} - -func (do *Domain) mdlCheckLoop() { - ticker := time.Tick(mdlCheckLookDuration) - var saveMaxSchemaVersion int64 - jobNeedToSync := false - jobCache := make(map[int64]int64, 1000) - - for { - // Wait for channels - select { - case <-do.mdlCheckCh: - case <-ticker: - case <-do.exit: - return - } - - if !variable.EnableMDL.Load() { - continue - } - - do.mdlCheckTableInfo.mu.Lock() - maxVer := do.mdlCheckTableInfo.newestVer - if maxVer > saveMaxSchemaVersion { - saveMaxSchemaVersion = maxVer - } else if !jobNeedToSync { - // Schema doesn't change, and no job to check in the last run. - do.mdlCheckTableInfo.mu.Unlock() - continue - } - - jobNeedToCheckCnt := len(do.mdlCheckTableInfo.jobsVerMap) - if jobNeedToCheckCnt == 0 { - jobNeedToSync = false - do.mdlCheckTableInfo.mu.Unlock() - continue - } - - jobsVerMap := make(map[int64]int64, len(do.mdlCheckTableInfo.jobsVerMap)) - jobsIDsMap := make(map[int64]string, len(do.mdlCheckTableInfo.jobsIDsMap)) - for k, v := range do.mdlCheckTableInfo.jobsVerMap { - jobsVerMap[k] = v - } - for k, v := range do.mdlCheckTableInfo.jobsIDsMap { - jobsIDsMap[k] = v - } - do.mdlCheckTableInfo.mu.Unlock() - - jobNeedToSync = true - - sm := do.InfoSyncer().GetSessionManager() - if sm == nil { - logutil.BgLogger().Info("session manager is nil") - } else { - sm.CheckOldRunningTxn(jobsVerMap, jobsIDsMap) - } - - if len(jobsVerMap) == jobNeedToCheckCnt { - jobNeedToSync = false - } - - // Try to gc jobCache. - if len(jobCache) > 1000 { - jobCache = make(map[int64]int64, 1000) - } - - for jobID, ver := range jobsVerMap { - if cver, ok := jobCache[jobID]; ok && cver >= ver { - // Already update, skip it. 
- continue - } - logutil.BgLogger().Info("mdl gets lock, update self version to owner", zap.Int64("jobID", jobID), zap.Int64("version", ver)) - err := do.ddl.SchemaSyncer().UpdateSelfVersion(context.Background(), jobID, ver) - if err != nil { - jobNeedToSync = true - logutil.BgLogger().Warn("mdl gets lock, update self version to owner failed", - zap.Int64("jobID", jobID), zap.Int64("version", ver), zap.Error(err)) - } else { - jobCache[jobID] = ver - } - } - } -} - -func (do *Domain) loadSchemaInLoop(ctx context.Context, lease time.Duration) { - defer util.Recover(metrics.LabelDomain, "loadSchemaInLoop", nil, true) - // Lease renewal can run at any frequency. - // Use lease/2 here as recommend by paper. - ticker := time.NewTicker(lease / 2) - defer func() { - ticker.Stop() - logutil.BgLogger().Info("loadSchemaInLoop exited.") - }() - syncer := do.ddl.SchemaSyncer() - - for { - select { - case <-ticker.C: - failpoint.Inject("disableOnTickReload", func() { - failpoint.Continue() - }) - err := do.Reload() - if err != nil { - logutil.BgLogger().Error("reload schema in loop failed", zap.Error(err)) - } - do.deferFn.check() - case _, ok := <-syncer.GlobalVersionCh(): - err := do.Reload() - if err != nil { - logutil.BgLogger().Error("reload schema in loop failed", zap.Error(err)) - } - if !ok { - logutil.BgLogger().Warn("reload schema in loop, schema syncer need rewatch") - // Make sure the rewatch doesn't affect load schema, so we watch the global schema version asynchronously. - syncer.WatchGlobalSchemaVer(context.Background()) - } - case <-syncer.Done(): - // The schema syncer stops, we need stop the schema validator to synchronize the schema version. - logutil.BgLogger().Info("reload schema in loop, schema syncer need restart") - // The etcd is responsible for schema synchronization, we should ensure there is at most two different schema version - // in the TiDB cluster, to make the data/schema be consistent. If we lost connection/session to etcd, the cluster - // will treats this TiDB as a down instance, and etcd will remove the key of `/tidb/ddl/all_schema_versions/tidb-id`. - // Say the schema version now is 1, the owner is changing the schema version to 2, it will not wait for this down TiDB syncing the schema, - // then continue to change the TiDB schema to version 3. Unfortunately, this down TiDB schema version will still be version 1. - // And version 1 is not consistent to version 3. So we need to stop the schema validator to prohibit the DML executing. - do.SchemaValidator.Stop() - err := do.mustRestartSyncer(ctx) - if err != nil { - logutil.BgLogger().Error("reload schema in loop, schema syncer restart failed", zap.Error(err)) - break - } - // The schema maybe changed, must reload schema then the schema validator can restart. - exitLoop := do.mustReload() - // domain is closed. - if exitLoop { - logutil.BgLogger().Error("domain is closed, exit loadSchemaInLoop") - return - } - do.SchemaValidator.Restart() - logutil.BgLogger().Info("schema syncer restarted") - case <-do.exit: - return - } - do.refreshMDLCheckTableInfo() - select { - case do.mdlCheckCh <- struct{}{}: - default: - } - } -} - -// mustRestartSyncer tries to restart the SchemaSyncer. -// It returns until it's successful or the domain is stopped. -func (do *Domain) mustRestartSyncer(ctx context.Context) error { - syncer := do.ddl.SchemaSyncer() - - for { - err := syncer.Restart(ctx) - if err == nil { - return nil - } - // If the domain has stopped, we return an error immediately. 
- if do.isClose() { - return err - } - logutil.BgLogger().Error("restart the schema syncer failed", zap.Error(err)) - time.Sleep(time.Second) - } -} - -// mustReload tries to Reload the schema, it returns until it's successful or the domain is closed. -// it returns false when it is successful, returns true when the domain is closed. -func (do *Domain) mustReload() (exitLoop bool) { - for { - err := do.Reload() - if err == nil { - logutil.BgLogger().Info("mustReload succeed") - return false - } - - // If the domain is closed, we returns immediately. - logutil.BgLogger().Info("reload the schema failed", zap.Error(err)) - if do.isClose() { - return true - } - time.Sleep(200 * time.Millisecond) - } -} - -func (do *Domain) isClose() bool { - select { - case <-do.exit: - logutil.BgLogger().Info("domain is closed") - return true - default: - } - return false -} - -// Close closes the Domain and release its resource. -func (do *Domain) Close() { - if do == nil { - return - } - startTime := time.Now() - if do.ddl != nil { - terror.Log(do.ddl.Stop()) - } - if do.info != nil { - do.info.RemoveServerInfo() - do.info.RemoveMinStartTS() - } - ttlJobManager := do.ttlJobManager.Load() - if ttlJobManager != nil { - logutil.BgLogger().Info("stopping ttlJobManager") - ttlJobManager.Stop() - err := ttlJobManager.WaitStopped(context.Background(), func() time.Duration { - if intest.InTest { - return 10 * time.Second - } - return 30 * time.Second - }()) - if err != nil { - logutil.BgLogger().Warn("fail to wait until the ttl job manager stop", zap.Error(err)) - } else { - logutil.BgLogger().Info("ttlJobManager exited.") - } - } - do.releaseServerID(context.Background()) - close(do.exit) - if do.etcdClient != nil { - terror.Log(errors.Trace(do.etcdClient.Close())) - } - - do.runawayManager.Stop() - - if do.unprefixedEtcdCli != nil { - terror.Log(errors.Trace(do.unprefixedEtcdCli.Close())) - } - - do.slowQuery.Close() - do.cancelFns.mu.Lock() - for _, f := range do.cancelFns.fns { - f() - } - do.cancelFns.mu.Unlock() - do.wg.Wait() - do.sysSessionPool.Close() - variable.UnregisterStatistics(do.BindHandle()) - if do.onClose != nil { - do.onClose() - } - gctuner.WaitMemoryLimitTunerExitInTest() - close(do.mdlCheckCh) - - // close MockGlobalServerInfoManagerEntry in order to refresh mock server info. - if intest.InTest { - infosync.MockGlobalServerInfoManagerEntry.Close() - } - if handle := do.statsHandle.Load(); handle != nil { - handle.Close() - } - - logutil.BgLogger().Info("domain closed", zap.Duration("take time", time.Since(startTime))) -} - -const resourceIdleTimeout = 3 * time.Minute // resources in the ResourcePool will be recycled after idleTimeout - -// NewDomain creates a new domain. Should not create multiple domains for the same store. 
-func NewDomain(store kv.Storage, ddlLease time.Duration, statsLease time.Duration, dumpFileGcLease time.Duration, factory pools.Factory) *Domain { - capacity := 200 // capacity of the sysSessionPool size - do := &Domain{ - store: store, - exit: make(chan struct{}), - sysSessionPool: util.NewSessionPool( - capacity, factory, - func(r pools.Resource) { - _, ok := r.(sessionctx.Context) - intest.Assert(ok) - infosync.StoreInternalSession(r) - }, - func(r pools.Resource) { - _, ok := r.(sessionctx.Context) - intest.Assert(ok) - infosync.DeleteInternalSession(r) - }, - ), - statsLease: statsLease, - slowQuery: newTopNSlowQueries(config.GetGlobalConfig().InMemSlowQueryTopNNum, time.Hour*24*7, config.GetGlobalConfig().InMemSlowQueryRecentNum), - dumpFileGcChecker: &dumpFileGcChecker{gcLease: dumpFileGcLease, paths: []string{replayer.GetPlanReplayerDirName(), GetOptimizerTraceDirName(), GetExtractTaskDirName()}}, - mdlCheckTableInfo: &mdlCheckTableInfo{ - mu: sync.Mutex{}, - jobsVerMap: make(map[int64]int64), - jobsIDsMap: make(map[int64]string), - }, - mdlCheckCh: make(chan struct{}), - } - - do.infoCache = infoschema.NewCache(do, int(variable.SchemaVersionCacheLimit.Load())) - do.stopAutoAnalyze.Store(false) - do.wg = util.NewWaitGroupEnhancedWrapper("domain", do.exit, config.GetGlobalConfig().TiDBEnableExitCheck) - do.SchemaValidator = NewSchemaValidator(ddlLease, do) - do.expensiveQueryHandle = expensivequery.NewExpensiveQueryHandle(do.exit) - do.memoryUsageAlarmHandle = memoryusagealarm.NewMemoryUsageAlarmHandle(do.exit) - do.serverMemoryLimitHandle = servermemorylimit.NewServerMemoryLimitHandle(do.exit) - do.sysProcesses = SysProcesses{mu: &sync.RWMutex{}, procMap: make(map[uint64]sysproctrack.TrackProc)} - do.initDomainSysVars() - do.expiredTimeStamp4PC.expiredTimeStamp = types.NewTime(types.ZeroCoreTime, mysql.TypeTimestamp, types.DefaultFsp) - return do -} - -const serverIDForStandalone = 1 // serverID for standalone deployment. - -func newEtcdCli(addrs []string, ebd kv.EtcdBackend) (*clientv3.Client, error) { - cfg := config.GetGlobalConfig() - etcdLogCfg := zap.NewProductionConfig() - etcdLogCfg.Level = zap.NewAtomicLevelAt(zap.ErrorLevel) - backoffCfg := backoff.DefaultConfig - backoffCfg.MaxDelay = 3 * time.Second - cli, err := clientv3.New(clientv3.Config{ - LogConfig: &etcdLogCfg, - Endpoints: addrs, - AutoSyncInterval: 30 * time.Second, - DialTimeout: 5 * time.Second, - DialOptions: []grpc.DialOption{ - grpc.WithConnectParams(grpc.ConnectParams{ - Backoff: backoffCfg, - }), - grpc.WithKeepaliveParams(keepalive.ClientParameters{ - Time: time.Duration(cfg.TiKVClient.GrpcKeepAliveTime) * time.Second, - Timeout: time.Duration(cfg.TiKVClient.GrpcKeepAliveTimeout) * time.Second, - }), - }, - TLS: ebd.TLSConfig(), - }) - return cli, err -} - -// Init initializes a domain. after return, session can be used to do DMLs but not -// DDLs which can be used after domain Start. -func (do *Domain) Init( - ddlLease time.Duration, - sysExecutorFactory func(*Domain) (pools.Resource, error), - ddlInjector func(ddl.DDL, ddl.Executor, *infoschema.InfoCache) *schematracker.Checker, -) error { - // TODO there are many place set ddlLease to 0, remove them completely, we want - // UT and even local uni-store to run similar code path as normal. 
- if ddlLease == 0 { - ddlLease = time.Second - } - - do.sysExecutorFactory = sysExecutorFactory - perfschema.Init() - if ebd, ok := do.store.(kv.EtcdBackend); ok { - var addrs []string - var err error - if addrs, err = ebd.EtcdAddrs(); err != nil { - return err - } - if addrs != nil { - cli, err := newEtcdCli(addrs, ebd) - if err != nil { - return errors.Trace(err) - } - - etcd.SetEtcdCliByNamespace(cli, keyspace.MakeKeyspaceEtcdNamespace(do.store.GetCodec())) - - do.etcdClient = cli - - do.autoidClient = autoid.NewClientDiscover(cli) - - unprefixedEtcdCli, err := newEtcdCli(addrs, ebd) - if err != nil { - return errors.Trace(err) - } - do.unprefixedEtcdCli = unprefixedEtcdCli - } - } - - ctx, cancelFunc := context.WithCancel(context.Background()) - do.ctx = ctx - do.cancelFns.mu.Lock() - do.cancelFns.fns = append(do.cancelFns.fns, cancelFunc) - do.cancelFns.mu.Unlock() - d := do.ddl - eBak := do.ddlExecutor - do.ddl, do.ddlExecutor = ddl.NewDDL( - ctx, - ddl.WithEtcdClient(do.etcdClient), - ddl.WithStore(do.store), - ddl.WithAutoIDClient(do.autoidClient), - ddl.WithInfoCache(do.infoCache), - ddl.WithLease(ddlLease), - ddl.WithSchemaLoader(do), - ) - - failpoint.Inject("MockReplaceDDL", func(val failpoint.Value) { - if val.(bool) { - do.ddl = d - do.ddlExecutor = eBak - } - }) - if ddlInjector != nil { - checker := ddlInjector(do.ddl, do.ddlExecutor, do.infoCache) - checker.CreateTestDB(nil) - do.ddl = checker - do.ddlExecutor = checker - } - - // step 1: prepare the info/schema syncer which domain reload needed. - pdCli, pdHTTPCli := do.GetPDClient(), do.GetPDHTTPClient() - skipRegisterToDashboard := config.GetGlobalConfig().SkipRegisterToDashboard - var err error - do.info, err = infosync.GlobalInfoSyncerInit(ctx, do.ddl.GetID(), do.ServerID, - do.etcdClient, do.unprefixedEtcdCli, pdCli, pdHTTPCli, - do.Store().GetCodec(), skipRegisterToDashboard) - if err != nil { - return err - } - do.globalCfgSyncer = globalconfigsync.NewGlobalConfigSyncer(pdCli) - err = do.ddl.SchemaSyncer().Init(ctx) - if err != nil { - return err - } - - // step 2: initialize the global kill, which depends on `globalInfoSyncer`.` - if config.GetGlobalConfig().EnableGlobalKill { - do.connIDAllocator = globalconn.NewGlobalAllocator(do.ServerID, config.GetGlobalConfig().Enable32BitsConnectionID) - - if do.etcdClient != nil { - err := do.acquireServerID(ctx) - if err != nil { - logutil.BgLogger().Error("acquire serverID failed", zap.Error(err)) - do.isLostConnectionToPD.Store(1) // will retry in `do.serverIDKeeper` - } else { - if err := do.info.StoreServerInfo(context.Background()); err != nil { - return errors.Trace(err) - } - do.isLostConnectionToPD.Store(0) - } - } else { - // set serverID for standalone deployment to enable 'KILL'. - atomic.StoreUint64(&do.serverID, serverIDForStandalone) - } - } else { - do.connIDAllocator = globalconn.NewSimpleAllocator() - } - - // should put `initResourceGroupsController` after fetching server ID - err = do.initResourceGroupsController(ctx, pdCli, do.ServerID()) - if err != nil { - return err - } - - startReloadTime := time.Now() - // step 3: domain reload the infoSchema. - err = do.Reload() - if err != nil { - return err - } - - sub := time.Since(startReloadTime) - // The reload(in step 2) operation takes more than ddlLease and a new reload operation was not performed, - // the next query will respond by ErrInfoSchemaExpired error. So we do a new reload to update schemaValidator.latestSchemaExpire. 
- if sub > (ddlLease / 2) { - logutil.BgLogger().Warn("loading schema and starting ddl take a long time, we do a new reload", zap.Duration("take time", sub)) - err = do.Reload() - if err != nil { - return err - } - } - return nil -} - -// Start starts the domain. After start, DDLs can be executed using session, see -// Init also. -func (do *Domain) Start() error { - gCfg := config.GetGlobalConfig() - if gCfg.EnableGlobalKill && do.etcdClient != nil { - do.wg.Add(1) - go do.serverIDKeeper() - } - - // TODO: Here we create new sessions with sysFac in DDL, - // which will use `do` as Domain instead of call `domap.Get`. - // That's because `domap.Get` requires a lock, but before - // we initialize Domain finish, we can't require that again. - // After we remove the lazy logic of creating Domain, we - // can simplify code here. - sysFac := func() (pools.Resource, error) { - return do.sysExecutorFactory(do) - } - sysCtxPool := pools.NewResourcePool(sysFac, 512, 512, resourceIdleTimeout) - - // start the ddl after the domain reload, avoiding some internal sql running before infoSchema construction. - err := do.ddl.Start(sysCtxPool) - if err != nil { - return err - } - do.minJobIDRefresher = do.ddl.GetMinJobIDRefresher() - - // Local store needs to get the change information for every DDL state in each session. - do.wg.Run(func() { - do.loadSchemaInLoop(do.ctx, do.ddl.GetLease()) - }, "loadSchemaInLoop") - do.wg.Run(do.mdlCheckLoop, "mdlCheckLoop") - do.wg.Run(do.topNSlowQueryLoop, "topNSlowQueryLoop") - do.wg.Run(do.infoSyncerKeeper, "infoSyncerKeeper") - do.wg.Run(do.globalConfigSyncerKeeper, "globalConfigSyncerKeeper") - do.wg.Run(do.runawayStartLoop, "runawayStartLoop") - do.wg.Run(do.requestUnitsWriterLoop, "requestUnitsWriterLoop") - skipRegisterToDashboard := gCfg.SkipRegisterToDashboard - if !skipRegisterToDashboard { - do.wg.Run(do.topologySyncerKeeper, "topologySyncerKeeper") - } - pdCli := do.GetPDClient() - if pdCli != nil { - do.wg.Run(func() { - do.closestReplicaReadCheckLoop(do.ctx, pdCli) - }, "closestReplicaReadCheckLoop") - } - - err = do.initLogBackup(do.ctx, pdCli) - if err != nil { - return err - } - - return nil -} - -// InitInfo4Test init infosync for distributed execution test. -func (do *Domain) InitInfo4Test() { - infosync.MockGlobalServerInfoManagerEntry.Add(do.ddl.GetID(), do.ServerID) -} - -// SetOnClose used to set do.onClose func. -func (do *Domain) SetOnClose(onClose func()) { - do.onClose = onClose -} - -func (do *Domain) initLogBackup(ctx context.Context, pdClient pd.Client) error { - cfg := config.GetGlobalConfig() - if pdClient == nil || do.etcdClient == nil { - log.Warn("pd / etcd client not provided, won't begin Advancer.") - return nil - } - tikvStore, ok := do.Store().(tikv.Storage) - if !ok { - log.Warn("non tikv store, stop begin Advancer.") - return nil - } - env, err := streamhelper.TiDBEnv(tikvStore, pdClient, do.etcdClient, cfg) - if err != nil { - return err - } - adv := streamhelper.NewCheckpointAdvancer(env) - do.logBackupAdvancer = daemon.New(adv, streamhelper.OwnerManagerForLogBackup(ctx, do.etcdClient), adv.Config().TickDuration) - loop, err := do.logBackupAdvancer.Begin(ctx) - if err != nil { - return err - } - do.wg.Run(loop, "logBackupAdvancer") - return nil -} - -// when tidb_replica_read = 'closest-adaptive', check tidb and tikv's zone label matches. -// if not match, disable replica_read to avoid uneven read traffic distribution. 
-func (do *Domain) closestReplicaReadCheckLoop(ctx context.Context, pdClient pd.Client) { - defer util.Recover(metrics.LabelDomain, "closestReplicaReadCheckLoop", nil, false) - - // trigger check once instantly. - if err := do.checkReplicaRead(ctx, pdClient); err != nil { - logutil.BgLogger().Warn("refresh replicaRead flag failed", zap.Error(err)) - } - - ticker := time.NewTicker(time.Minute) - defer func() { - ticker.Stop() - logutil.BgLogger().Info("closestReplicaReadCheckLoop exited.") - }() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - if err := do.checkReplicaRead(ctx, pdClient); err != nil { - logutil.BgLogger().Warn("refresh replicaRead flag failed", zap.Error(err)) - } - } - } -} - -// Periodically check and update the replica-read status when `tidb_replica_read` is set to "closest-adaptive" -// We disable "closest-adaptive" in following conditions to ensure the read traffic is evenly distributed across -// all AZs: -// - There are no TiKV servers in the AZ of this tidb instance -// - The AZ if this tidb contains more tidb than other AZ and this tidb's id is the bigger one. -func (do *Domain) checkReplicaRead(ctx context.Context, pdClient pd.Client) error { - do.sysVarCache.RLock() - replicaRead := do.sysVarCache.global[variable.TiDBReplicaRead] - do.sysVarCache.RUnlock() - - if !strings.EqualFold(replicaRead, "closest-adaptive") { - logutil.BgLogger().Debug("closest replica read is not enabled, skip check!", zap.String("mode", replicaRead)) - return nil - } - - serverInfo, err := infosync.GetServerInfo() - if err != nil { - return err - } - zone := "" - for k, v := range serverInfo.Labels { - if k == placement.DCLabelKey && v != "" { - zone = v - break - } - } - if zone == "" { - logutil.BgLogger().Debug("server contains no 'zone' label, disable closest replica read", zap.Any("labels", serverInfo.Labels)) - variable.SetEnableAdaptiveReplicaRead(false) - return nil - } - - stores, err := pdClient.GetAllStores(ctx, pd.WithExcludeTombstone()) - if err != nil { - return err - } - - storeZones := make(map[string]int) - for _, s := range stores { - // skip tumbstone stores or tiflash - if s.NodeState == metapb.NodeState_Removing || s.NodeState == metapb.NodeState_Removed || engine.IsTiFlash(s) { - continue - } - for _, label := range s.Labels { - if label.Key == placement.DCLabelKey && label.Value != "" { - storeZones[label.Value] = 0 - break - } - } - } - - // no stores in this AZ - if _, ok := storeZones[zone]; !ok { - variable.SetEnableAdaptiveReplicaRead(false) - return nil - } - - servers, err := infosync.GetAllServerInfo(ctx) - if err != nil { - return err - } - svrIDsInThisZone := make([]string, 0) - for _, s := range servers { - if v, ok := s.Labels[placement.DCLabelKey]; ok && v != "" { - if _, ok := storeZones[v]; ok { - storeZones[v]++ - if v == zone { - svrIDsInThisZone = append(svrIDsInThisZone, s.ID) - } - } - } - } - enabledCount := math.MaxInt - for _, count := range storeZones { - if count < enabledCount { - enabledCount = count - } - } - // sort tidb in the same AZ by ID and disable the tidb with bigger ID - // because ID is unchangeable, so this is a simple and stable algorithm to select - // some instances across all tidb servers. 
- if enabledCount < len(svrIDsInThisZone) { - sort.Slice(svrIDsInThisZone, func(i, j int) bool { - return strings.Compare(svrIDsInThisZone[i], svrIDsInThisZone[j]) < 0 - }) - } - enabled := true - for _, s := range svrIDsInThisZone[enabledCount:] { - if s == serverInfo.ID { - enabled = false - break - } - } - - if variable.SetEnableAdaptiveReplicaRead(enabled) { - logutil.BgLogger().Info("tidb server adaptive closest replica read is changed", zap.Bool("enable", enabled)) - } - return nil -} - -// InitDistTaskLoop initializes the distributed task framework. -func (do *Domain) InitDistTaskLoop() error { - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalDistTask) - failpoint.Inject("MockDisableDistTask", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(nil) - } - }) - - taskManager := storage.NewTaskManager(do.sysSessionPool) - var serverID string - if intest.InTest { - do.InitInfo4Test() - serverID = disttaskutil.GenerateSubtaskExecID4Test(do.ddl.GetID()) - } else { - serverID = disttaskutil.GenerateSubtaskExecID(ctx, do.ddl.GetID()) - } - - if serverID == "" { - errMsg := fmt.Sprintf("TiDB node ID( = %s ) not found in available TiDB nodes list", do.ddl.GetID()) - return errors.New(errMsg) - } - managerCtx, cancel := context.WithCancel(ctx) - do.cancelFns.mu.Lock() - do.cancelFns.fns = append(do.cancelFns.fns, cancel) - do.cancelFns.mu.Unlock() - executorManager, err := taskexecutor.NewManager(managerCtx, serverID, taskManager) - if err != nil { - return err - } - - storage.SetTaskManager(taskManager) - if err = executorManager.InitMeta(); err != nil { - // executor manager loop will try to recover meta repeatedly, so we can - // just log the error here. - logutil.BgLogger().Warn("init task executor manager meta failed", zap.Error(err)) - } - do.wg.Run(func() { - defer func() { - storage.SetTaskManager(nil) - }() - do.distTaskFrameworkLoop(ctx, taskManager, executorManager, serverID) - }, "distTaskFrameworkLoop") - return nil -} - -func (do *Domain) distTaskFrameworkLoop(ctx context.Context, taskManager *storage.TaskManager, executorManager *taskexecutor.Manager, serverID string) { - err := executorManager.Start() - if err != nil { - logutil.BgLogger().Error("dist task executor manager start failed", zap.Error(err)) - return - } - logutil.BgLogger().Info("dist task executor manager started") - defer func() { - logutil.BgLogger().Info("stopping dist task executor manager") - executorManager.Stop() - logutil.BgLogger().Info("dist task executor manager stopped") - }() - - var schedulerManager *scheduler.Manager - startSchedulerMgrIfNeeded := func() { - if schedulerManager != nil && schedulerManager.Initialized() { - return - } - schedulerManager = scheduler.NewManager(ctx, taskManager, serverID) - schedulerManager.Start() - } - stopSchedulerMgrIfNeeded := func() { - if schedulerManager != nil && schedulerManager.Initialized() { - logutil.BgLogger().Info("stopping dist task scheduler manager because the current node is not DDL owner anymore", zap.String("id", do.ddl.GetID())) - schedulerManager.Stop() - logutil.BgLogger().Info("dist task scheduler manager stopped", zap.String("id", do.ddl.GetID())) - } - } - - ticker := time.NewTicker(time.Second) - for { - select { - case <-do.exit: - stopSchedulerMgrIfNeeded() - return - case <-ticker.C: - if do.ddl.OwnerManager().IsOwner() { - startSchedulerMgrIfNeeded() - } else { - stopSchedulerMgrIfNeeded() - } - } - } -} - -// SysSessionPool returns the system session pool. 
-func (do *Domain) SysSessionPool() util.SessionPool { - return do.sysSessionPool -} - -// SysProcTracker returns the system processes tracker. -func (do *Domain) SysProcTracker() sysproctrack.Tracker { - return &do.sysProcesses -} - -// GetEtcdClient returns the etcd client. -func (do *Domain) GetEtcdClient() *clientv3.Client { - return do.etcdClient -} - -// AutoIDClient returns the autoid client. -func (do *Domain) AutoIDClient() *autoid.ClientDiscover { - return do.autoidClient -} - -// GetPDClient returns the PD client. -func (do *Domain) GetPDClient() pd.Client { - if store, ok := do.store.(kv.StorageWithPD); ok { - return store.GetPDClient() - } - return nil -} - -// GetPDHTTPClient returns the PD HTTP client. -func (do *Domain) GetPDHTTPClient() pdhttp.Client { - if store, ok := do.store.(kv.StorageWithPD); ok { - return store.GetPDHTTPClient() - } - return nil -} - -// LoadPrivilegeLoop create a goroutine loads privilege tables in a loop, it -// should be called only once in BootstrapSession. -func (do *Domain) LoadPrivilegeLoop(sctx sessionctx.Context) error { - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnPrivilege) - sctx.GetSessionVars().InRestrictedSQL = true - _, err := sctx.GetSQLExecutor().ExecuteInternal(ctx, "set @@autocommit = 1") - if err != nil { - return err - } - do.privHandle = privileges.NewHandle() - err = do.privHandle.Update(sctx) - if err != nil { - return err - } - - var watchCh clientv3.WatchChan - duration := 5 * time.Minute - if do.etcdClient != nil { - watchCh = do.etcdClient.Watch(context.Background(), privilegeKey) - duration = 10 * time.Minute - } - - do.wg.Run(func() { - defer func() { - logutil.BgLogger().Info("loadPrivilegeInLoop exited.") - }() - defer util.Recover(metrics.LabelDomain, "loadPrivilegeInLoop", nil, false) - - var count int - for { - ok := true - select { - case <-do.exit: - return - case _, ok = <-watchCh: - case <-time.After(duration): - } - if !ok { - logutil.BgLogger().Error("load privilege loop watch channel closed") - watchCh = do.etcdClient.Watch(context.Background(), privilegeKey) - count++ - if count > 10 { - time.Sleep(time.Duration(count) * time.Second) - } - continue - } - - count = 0 - err := do.privHandle.Update(sctx) - metrics.LoadPrivilegeCounter.WithLabelValues(metrics.RetLabel(err)).Inc() - if err != nil { - logutil.BgLogger().Error("load privilege failed", zap.Error(err)) - } - } - }, "loadPrivilegeInLoop") - return nil -} - -// LoadSysVarCacheLoop create a goroutine loads sysvar cache in a loop, -// it should be called only once in BootstrapSession. -func (do *Domain) LoadSysVarCacheLoop(ctx sessionctx.Context) error { - ctx.GetSessionVars().InRestrictedSQL = true - err := do.rebuildSysVarCache(ctx) - if err != nil { - return err - } - var watchCh clientv3.WatchChan - duration := 30 * time.Second - if do.etcdClient != nil { - watchCh = do.etcdClient.Watch(context.Background(), sysVarCacheKey) - } - - do.wg.Run(func() { - defer func() { - logutil.BgLogger().Info("LoadSysVarCacheLoop exited.") - }() - defer util.Recover(metrics.LabelDomain, "LoadSysVarCacheLoop", nil, false) - - var count int - for { - ok := true - select { - case <-do.exit: - return - case _, ok = <-watchCh: - case <-time.After(duration): - } - - failpoint.Inject("skipLoadSysVarCacheLoop", func(val failpoint.Value) { - // In some pkg integration test, there are many testSuite, and each testSuite has separate storage and - // `LoadSysVarCacheLoop` background goroutine. 
Then each testSuite `RebuildSysVarCache` from it's - // own storage. - // Each testSuit will also call `checkEnableServerGlobalVar` to update some local variables. - // That's the problem, each testSuit use different storage to update some same local variables. - // So just skip `RebuildSysVarCache` in some integration testing. - if val.(bool) { - failpoint.Continue() - } - }) - - if !ok { - logutil.BgLogger().Error("LoadSysVarCacheLoop loop watch channel closed") - watchCh = do.etcdClient.Watch(context.Background(), sysVarCacheKey) - count++ - if count > 10 { - time.Sleep(time.Duration(count) * time.Second) - } - continue - } - count = 0 - logutil.BgLogger().Debug("Rebuilding sysvar cache from etcd watch event.") - err := do.rebuildSysVarCache(ctx) - metrics.LoadSysVarCacheCounter.WithLabelValues(metrics.RetLabel(err)).Inc() - if err != nil { - logutil.BgLogger().Error("LoadSysVarCacheLoop failed", zap.Error(err)) - } - } - }, "LoadSysVarCacheLoop") - return nil -} - -// WatchTiFlashComputeNodeChange create a routine to watch if the topology of tiflash_compute node is changed. -// TODO: tiflashComputeNodeKey is not put to etcd yet(finish this when AutoScaler is done) -// -// store cache will only be invalidated every n seconds. -func (do *Domain) WatchTiFlashComputeNodeChange() error { - var watchCh clientv3.WatchChan - if do.etcdClient != nil { - watchCh = do.etcdClient.Watch(context.Background(), tiflashComputeNodeKey) - } - duration := 10 * time.Second - do.wg.Run(func() { - defer func() { - logutil.BgLogger().Info("WatchTiFlashComputeNodeChange exit") - }() - defer util.Recover(metrics.LabelDomain, "WatchTiFlashComputeNodeChange", nil, false) - - var count int - var logCount int - for { - ok := true - var watched bool - select { - case <-do.exit: - return - case _, ok = <-watchCh: - watched = true - case <-time.After(duration): - } - if !ok { - logutil.BgLogger().Error("WatchTiFlashComputeNodeChange watch channel closed") - watchCh = do.etcdClient.Watch(context.Background(), tiflashComputeNodeKey) - count++ - if count > 10 { - time.Sleep(time.Duration(count) * time.Second) - } - continue - } - count = 0 - switch s := do.store.(type) { - case tikv.Storage: - logCount++ - s.GetRegionCache().InvalidateTiFlashComputeStores() - if logCount == 6 { - // Print log every 6*duration seconds. - logutil.BgLogger().Debug("tiflash_compute store cache invalied, will update next query", zap.Bool("watched", watched)) - logCount = 0 - } - default: - logutil.BgLogger().Debug("No need to watch tiflash_compute store cache for non-tikv store") - return - } - } - }, "WatchTiFlashComputeNodeChange") - return nil -} - -// PrivilegeHandle returns the MySQLPrivilege. -func (do *Domain) PrivilegeHandle() *privileges.Handle { - return do.privHandle -} - -// BindHandle returns domain's bindHandle. -func (do *Domain) BindHandle() bindinfo.GlobalBindingHandle { - v := do.bindHandle.Load() - if v == nil { - return nil - } - return v.(bindinfo.GlobalBindingHandle) -} - -// LoadBindInfoLoop create a goroutine loads BindInfo in a loop, it should -// be called only once in BootstrapSession. 
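loadPrivilegeInLoop and LoadSysVarCacheLoop above share one shape: reload on a watch event or on a timer fallback, and re-create the watch channel when it is closed, sleeping progressively longer after repeated failures. The sketch below isolates that shape with the etcd client hidden behind a watch function; runReloadLoop, watchFn, reload, and fallback are illustrative names, not the patch's API.

// Sketch only: the watch-or-timeout reload loop with re-watch and backoff.
package sketch

import "time"

// runReloadLoop reloads on a watch event or on a timer fallback, and when the
// watch channel is closed it re-creates the channel and backs off once the
// closure has happened repeatedly.
func runReloadLoop(exit <-chan struct{}, fallback time.Duration, watchFn func() <-chan struct{}, reload func() error) {
	watchCh := watchFn()
	var failures int
	for {
		ok := true
		select {
		case <-exit:
			return
		case _, ok = <-watchCh:
		case <-time.After(fallback):
		}
		if !ok {
			// Watch channel closed: re-subscribe, and sleep a growing amount
			// of time after repeated failures.
			watchCh = watchFn()
			failures++
			if failures > 10 {
				time.Sleep(time.Duration(failures) * time.Second)
			}
			continue
		}
		failures = 0
		_ = reload() // the real loops only log reload errors
	}
}

The failure counter resets on every successful receive, so a single transient closure does not keep the loop in backoff.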
-func (do *Domain) LoadBindInfoLoop(ctxForHandle sessionctx.Context, ctxForEvolve sessionctx.Context) error { - ctxForHandle.GetSessionVars().InRestrictedSQL = true - ctxForEvolve.GetSessionVars().InRestrictedSQL = true - if !do.bindHandle.CompareAndSwap(nil, bindinfo.NewGlobalBindingHandle(do.sysSessionPool)) { - do.BindHandle().Reset() - } - - err := do.BindHandle().LoadFromStorageToCache(true) - if err != nil || bindinfo.Lease == 0 { - return err - } - - owner := do.newOwnerManager(bindinfo.Prompt, bindinfo.OwnerKey) - do.globalBindHandleWorkerLoop(owner) - return nil -} - -func (do *Domain) globalBindHandleWorkerLoop(owner owner.Manager) { - do.wg.Run(func() { - defer func() { - logutil.BgLogger().Info("globalBindHandleWorkerLoop exited.") - }() - defer util.Recover(metrics.LabelDomain, "globalBindHandleWorkerLoop", nil, false) - - bindWorkerTicker := time.NewTicker(bindinfo.Lease) - gcBindTicker := time.NewTicker(100 * bindinfo.Lease) - defer func() { - bindWorkerTicker.Stop() - gcBindTicker.Stop() - }() - for { - select { - case <-do.exit: - owner.Cancel() - return - case <-bindWorkerTicker.C: - bindHandle := do.BindHandle() - err := bindHandle.LoadFromStorageToCache(false) - if err != nil { - logutil.BgLogger().Error("update bindinfo failed", zap.Error(err)) - } - bindHandle.DropInvalidGlobalBinding() - // Get Global - optVal, err := do.GetGlobalVar(variable.TiDBCapturePlanBaseline) - if err == nil && variable.TiDBOptOn(optVal) { - bindHandle.CaptureBaselines() - } - case <-gcBindTicker.C: - if !owner.IsOwner() { - continue - } - err := do.BindHandle().GCGlobalBinding() - if err != nil { - logutil.BgLogger().Error("GC bind record failed", zap.Error(err)) - } - } - } - }, "globalBindHandleWorkerLoop") -} - -// SetupPlanReplayerHandle setup plan replayer handle -func (do *Domain) SetupPlanReplayerHandle(collectorSctx sessionctx.Context, workersSctxs []sessionctx.Context) { - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) - do.planReplayerHandle = &planReplayerHandle{} - do.planReplayerHandle.planReplayerTaskCollectorHandle = &planReplayerTaskCollectorHandle{ - ctx: ctx, - sctx: collectorSctx, - } - taskCH := make(chan *PlanReplayerDumpTask, 16) - taskStatus := &planReplayerDumpTaskStatus{} - taskStatus.finishedTaskMu.finishedTask = map[replayer.PlanReplayerTaskKey]struct{}{} - taskStatus.runningTaskMu.runningTasks = map[replayer.PlanReplayerTaskKey]struct{}{} - - do.planReplayerHandle.planReplayerTaskDumpHandle = &planReplayerTaskDumpHandle{ - taskCH: taskCH, - status: taskStatus, - } - do.planReplayerHandle.planReplayerTaskDumpHandle.workers = make([]*planReplayerTaskDumpWorker, 0) - for i := 0; i < len(workersSctxs); i++ { - worker := &planReplayerTaskDumpWorker{ - ctx: ctx, - sctx: workersSctxs[i], - taskCH: taskCH, - status: taskStatus, - } - do.planReplayerHandle.planReplayerTaskDumpHandle.workers = append(do.planReplayerHandle.planReplayerTaskDumpHandle.workers, worker) - } -} - -// RunawayManager returns the runaway manager. -func (do *Domain) RunawayManager() *resourcegroup.RunawayManager { - return do.runawayManager -} - -// ResourceGroupsController returns the resource groups controller. -func (do *Domain) ResourceGroupsController() *rmclient.ResourceGroupsController { - return do.resourceGroupsController -} - -// SetResourceGroupsController is only used in test. 
-func (do *Domain) SetResourceGroupsController(controller *rmclient.ResourceGroupsController) { - do.resourceGroupsController = controller -} - -// SetupHistoricalStatsWorker setups worker -func (do *Domain) SetupHistoricalStatsWorker(ctx sessionctx.Context) { - do.historicalStatsWorker = &HistoricalStatsWorker{ - tblCH: make(chan int64, 16), - sctx: ctx, - } -} - -// SetupDumpFileGCChecker setup sctx -func (do *Domain) SetupDumpFileGCChecker(ctx sessionctx.Context) { - do.dumpFileGcChecker.setupSctx(ctx) - do.dumpFileGcChecker.planReplayerTaskStatus = do.planReplayerHandle.status -} - -// SetupExtractHandle setups extract handler -func (do *Domain) SetupExtractHandle(sctxs []sessionctx.Context) { - do.extractTaskHandle = newExtractHandler(do.ctx, sctxs) -} - -var planReplayerHandleLease atomic.Uint64 - -func init() { - planReplayerHandleLease.Store(uint64(10 * time.Second)) - enableDumpHistoricalStats.Store(true) -} - -// DisablePlanReplayerBackgroundJob4Test disable plan replayer handle for test -func DisablePlanReplayerBackgroundJob4Test() { - planReplayerHandleLease.Store(0) -} - -// DisableDumpHistoricalStats4Test disable historical dump worker for test -func DisableDumpHistoricalStats4Test() { - enableDumpHistoricalStats.Store(false) -} - -// StartPlanReplayerHandle start plan replayer handle job -func (do *Domain) StartPlanReplayerHandle() { - lease := planReplayerHandleLease.Load() - if lease < 1 { - return - } - do.wg.Run(func() { - logutil.BgLogger().Info("PlanReplayerTaskCollectHandle started") - tikcer := time.NewTicker(time.Duration(lease)) - defer func() { - tikcer.Stop() - logutil.BgLogger().Info("PlanReplayerTaskCollectHandle exited.") - }() - defer util.Recover(metrics.LabelDomain, "PlanReplayerTaskCollectHandle", nil, false) - - for { - select { - case <-do.exit: - return - case <-tikcer.C: - err := do.planReplayerHandle.CollectPlanReplayerTask() - if err != nil { - logutil.BgLogger().Warn("plan replayer handle collect tasks failed", zap.Error(err)) - } - } - } - }, "PlanReplayerTaskCollectHandle") - - do.wg.Run(func() { - logutil.BgLogger().Info("PlanReplayerTaskDumpHandle started") - defer func() { - logutil.BgLogger().Info("PlanReplayerTaskDumpHandle exited.") - }() - defer util.Recover(metrics.LabelDomain, "PlanReplayerTaskDumpHandle", nil, false) - - for _, worker := range do.planReplayerHandle.planReplayerTaskDumpHandle.workers { - go worker.run() - } - <-do.exit - do.planReplayerHandle.planReplayerTaskDumpHandle.Close() - }, "PlanReplayerTaskDumpHandle") -} - -// GetPlanReplayerHandle returns plan replayer handle -func (do *Domain) GetPlanReplayerHandle() *planReplayerHandle { - return do.planReplayerHandle -} - -// GetExtractHandle returns extract handle -func (do *Domain) GetExtractHandle() *ExtractHandle { - return do.extractTaskHandle -} - -// GetDumpFileGCChecker returns dump file GC checker for plan replayer and plan trace -func (do *Domain) GetDumpFileGCChecker() *dumpFileGcChecker { - return do.dumpFileGcChecker -} - -// DumpFileGcCheckerLoop creates a goroutine that handles `exit` and `gc`. 
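StartPlanReplayerHandle above fans the dump work out to a fixed set of workers that all consume a single task channel and shuts them down when the domain exits. The sketch below shows that fan-out and shutdown pattern with plain goroutines; task, runDumpWorkers, and handle are illustrative names and the logic is simplified relative to the patch.

// Sketch only: N workers draining one task channel, closed on exit.
package sketch

import (
	"log"
	"sync"
)

type task struct{ id int64 }

// runDumpWorkers starts `workers` goroutines that all drain taskCh, blocks
// until exit fires, closes taskCh, and waits for queued tasks to finish.
// Callers must stop sending once exit has fired.
func runDumpWorkers(exit <-chan struct{}, taskCh chan task, workers int, handle func(task)) {
	var wg sync.WaitGroup
	for i := 0; i < workers; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			for t := range taskCh {
				handle(t)
			}
			log.Printf("dump worker %d exited", id)
		}(i)
	}
	<-exit
	close(taskCh)
	wg.Wait()
}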
-func (do *Domain) DumpFileGcCheckerLoop() { - do.wg.Run(func() { - logutil.BgLogger().Info("dumpFileGcChecker started") - gcTicker := time.NewTicker(do.dumpFileGcChecker.gcLease) - defer func() { - gcTicker.Stop() - logutil.BgLogger().Info("dumpFileGcChecker exited.") - }() - defer util.Recover(metrics.LabelDomain, "dumpFileGcCheckerLoop", nil, false) - - for { - select { - case <-do.exit: - return - case <-gcTicker.C: - do.dumpFileGcChecker.GCDumpFiles(time.Hour, time.Hour*24*7) - } - } - }, "dumpFileGcChecker") -} - -// GetHistoricalStatsWorker gets historical workers -func (do *Domain) GetHistoricalStatsWorker() *HistoricalStatsWorker { - return do.historicalStatsWorker -} - -// EnableDumpHistoricalStats used to control whether enable dump stats for unit test -var enableDumpHistoricalStats atomic.Bool - -// StartHistoricalStatsWorker start historical workers running -func (do *Domain) StartHistoricalStatsWorker() { - if !enableDumpHistoricalStats.Load() { - return - } - do.wg.Run(func() { - logutil.BgLogger().Info("HistoricalStatsWorker started") - defer func() { - logutil.BgLogger().Info("HistoricalStatsWorker exited.") - }() - defer util.Recover(metrics.LabelDomain, "HistoricalStatsWorkerLoop", nil, false) - - for { - select { - case <-do.exit: - close(do.historicalStatsWorker.tblCH) - return - case tblID, ok := <-do.historicalStatsWorker.tblCH: - if !ok { - return - } - err := do.historicalStatsWorker.DumpHistoricalStats(tblID, do.StatsHandle()) - if err != nil { - logutil.BgLogger().Warn("dump historical stats failed", zap.Error(err), zap.Int64("tableID", tblID)) - } - } - } - }, "HistoricalStatsWorker") -} - -// StatsHandle returns the statistic handle. -func (do *Domain) StatsHandle() *handle.Handle { - return do.statsHandle.Load() -} - -// CreateStatsHandle is used only for test. -func (do *Domain) CreateStatsHandle(ctx, initStatsCtx sessionctx.Context) error { - h, err := handle.NewHandle(ctx, initStatsCtx, do.statsLease, do.sysSessionPool, &do.sysProcesses, do.NextConnID, do.ReleaseConnID) - if err != nil { - return err - } - h.StartWorker() - do.statsHandle.Store(h) - return nil -} - -// StatsUpdating checks if the stats worker is updating. -func (do *Domain) StatsUpdating() bool { - return do.statsUpdating.Load() > 0 -} - -// SetStatsUpdating sets the value of stats updating. -func (do *Domain) SetStatsUpdating(val bool) { - if val { - do.statsUpdating.Store(1) - } else { - do.statsUpdating.Store(0) - } -} - -// ReleaseAnalyzeExec returned extra exec for Analyze -func (do *Domain) ReleaseAnalyzeExec(sctxs []sessionctx.Context) { - do.analyzeMu.Lock() - defer do.analyzeMu.Unlock() - for _, ctx := range sctxs { - do.analyzeMu.sctxs[ctx] = false - } -} - -// FetchAnalyzeExec get needed exec for analyze -func (do *Domain) FetchAnalyzeExec(need int) []sessionctx.Context { - if need < 1 { - return nil - } - count := 0 - r := make([]sessionctx.Context, 0) - do.analyzeMu.Lock() - defer do.analyzeMu.Unlock() - for sctx, used := range do.analyzeMu.sctxs { - if used { - continue - } - r = append(r, sctx) - do.analyzeMu.sctxs[sctx] = true - count++ - if count >= need { - break - } - } - return r -} - -// SetupAnalyzeExec setups exec for Analyze Executor -func (do *Domain) SetupAnalyzeExec(ctxs []sessionctx.Context) { - do.analyzeMu.sctxs = make(map[sessionctx.Context]bool) - for _, ctx := range ctxs { - do.analyzeMu.sctxs[ctx] = false - } -} - -// LoadAndUpdateStatsLoop loads and updates stats info. 
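FetchAnalyzeExec and ReleaseAnalyzeExec above implement a small checkout pool: a mutex-guarded map from session to an in-use flag, handing out idle entries and marking them busy until they are released. The generic sketch below captures the same pattern with the element type abstracted; pool, newPool, Fetch, and Release are illustrative names, not the patch's API.

// Sketch only: a mutex-guarded checkout pool over a fixed set of items.
package sketch

import "sync"

type pool[T comparable] struct {
	mu    sync.Mutex
	inUse map[T]bool
}

func newPool[T comparable](items []T) *pool[T] {
	p := &pool[T]{inUse: make(map[T]bool, len(items))}
	for _, it := range items {
		p.inUse[it] = false
	}
	return p
}

// Fetch hands out up to `need` idle items and marks them busy.
func (p *pool[T]) Fetch(need int) []T {
	if need < 1 {
		return nil
	}
	p.mu.Lock()
	defer p.mu.Unlock()
	out := make([]T, 0, need)
	for it, used := range p.inUse {
		if used {
			continue
		}
		p.inUse[it] = true
		out = append(out, it)
		if len(out) >= need {
			break
		}
	}
	return out
}

// Release marks the items idle again so later callers can fetch them.
func (p *pool[T]) Release(items []T) {
	p.mu.Lock()
	defer p.mu.Unlock()
	for _, it := range items {
		p.inUse[it] = false
	}
}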
-func (do *Domain) LoadAndUpdateStatsLoop(ctxs []sessionctx.Context, initStatsCtx sessionctx.Context) error { - if err := do.UpdateTableStatsLoop(ctxs[0], initStatsCtx); err != nil { - return err - } - do.StartLoadStatsSubWorkers(ctxs[1:]) - return nil -} - -// UpdateTableStatsLoop creates a goroutine loads stats info and updates stats info in a loop. -// It will also start a goroutine to analyze tables automatically. -// It should be called only once in BootstrapSession. -func (do *Domain) UpdateTableStatsLoop(ctx, initStatsCtx sessionctx.Context) error { - ctx.GetSessionVars().InRestrictedSQL = true - statsHandle, err := handle.NewHandle(ctx, initStatsCtx, do.statsLease, do.sysSessionPool, &do.sysProcesses, do.NextConnID, do.ReleaseConnID) - if err != nil { - return err - } - statsHandle.StartWorker() - do.statsHandle.Store(statsHandle) - do.ddl.RegisterStatsHandle(statsHandle) - // Negative stats lease indicates that it is in test or in br binary mode, it does not need update. - if do.statsLease >= 0 { - do.wg.Run(do.loadStatsWorker, "loadStatsWorker") - } - owner := do.newOwnerManager(handle.StatsPrompt, handle.StatsOwnerKey) - do.wg.Run(func() { - do.indexUsageWorker() - }, "indexUsageWorker") - if do.statsLease <= 0 { - // For statsLease > 0, `updateStatsWorker` handles the quit of stats owner. - do.wg.Run(func() { quitStatsOwner(do, owner) }, "quitStatsOwner") - return nil - } - do.SetStatsUpdating(true) - // The stats updated worker doesn't require the stats initialization to be completed. - // This is because the updated worker's primary responsibilities are to update the change delta and handle DDL operations. - // These tasks do not interfere with or depend on the initialization process. - do.wg.Run(func() { do.updateStatsWorker(ctx, owner) }, "updateStatsWorker") - do.wg.Run(func() { - do.handleDDLEvent() - }, "handleDDLEvent") - // Wait for the stats worker to finish the initialization. - // Otherwise, we may start the auto analyze worker before the stats cache is initialized. - do.wg.Run( - func() { - select { - case <-do.StatsHandle().InitStatsDone: - case <-do.exit: // It may happen that before initStatsDone, tidb receive Ctrl+C - return - } - do.autoAnalyzeWorker(owner) - }, - "autoAnalyzeWorker", - ) - do.wg.Run( - func() { - select { - case <-do.StatsHandle().InitStatsDone: - case <-do.exit: // It may happen that before initStatsDone, tidb receive Ctrl+C - return - } - do.analyzeJobsCleanupWorker(owner) - }, - "analyzeJobsCleanupWorker", - ) - do.wg.Run( - func() { - // The initStatsCtx is used to store the internal session for initializing stats, - // so we need the gc min start ts calculation to track it as an internal session. - // Since the session manager may not be ready at this moment, `infosync.StoreInternalSession` can fail. - // we need to retry until the session manager is ready or the init stats completes. - for !infosync.StoreInternalSession(initStatsCtx) { - waitRetry := time.After(time.Second) - select { - case <-do.StatsHandle().InitStatsDone: - return - case <-waitRetry: - } - } - select { - case <-do.StatsHandle().InitStatsDone: - case <-do.exit: // It may happen that before initStatsDone, tidb receive Ctrl+C - return - } - infosync.DeleteInternalSession(initStatsCtx) - }, - "RemoveInitStatsFromInternalSessions", - ) - return nil -} - -func quitStatsOwner(do *Domain, mgr owner.Manager) { - <-do.exit - mgr.Cancel() -} - -// StartLoadStatsSubWorkers starts sub workers with new sessions to load stats concurrently. 
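UpdateTableStatsLoop above guards the auto-analyze and cleanup workers with the same two-way wait: run once stats initialization completes, or return early if the domain is already exiting. The sketch below is that gate on its own; runAfterInit and the channel names are illustrative.

// Sketch only: start a worker after initialization, unless shutdown wins.
package sketch

// runAfterInit blocks until initDone is closed and then runs worker, or
// returns immediately if exit fires first (for example, Ctrl+C arriving
// before initialization has finished).
func runAfterInit(initDone, exit <-chan struct{}, worker func()) {
	select {
	case <-initDone:
	case <-exit:
		return
	}
	worker()
}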
-func (do *Domain) StartLoadStatsSubWorkers(ctxList []sessionctx.Context) { - statsHandle := do.StatsHandle() - for _, ctx := range ctxList { - do.wg.Add(1) - go statsHandle.SubLoadWorker(ctx, do.exit, do.wg) - } - logutil.BgLogger().Info("start load stats sub workers", zap.Int("worker count", len(ctxList))) -} - -func (do *Domain) newOwnerManager(prompt, ownerKey string) owner.Manager { - id := do.ddl.OwnerManager().ID() - var statsOwner owner.Manager - if do.etcdClient == nil { - statsOwner = owner.NewMockManager(context.Background(), id, do.store, ownerKey) - } else { - statsOwner = owner.NewOwnerManager(context.Background(), do.etcdClient, prompt, id, ownerKey) - } - // TODO: Need to do something when err is not nil. - err := statsOwner.CampaignOwner() - if err != nil { - logutil.BgLogger().Warn("campaign owner failed", zap.Error(err)) - } - return statsOwner -} - -func (do *Domain) initStats(ctx context.Context) { - statsHandle := do.StatsHandle() - defer func() { - if r := recover(); r != nil { - logutil.BgLogger().Error("panic when initiating stats", zap.Any("r", r), - zap.Stack("stack")) - } - close(statsHandle.InitStatsDone) - }() - t := time.Now() - liteInitStats := config.GetGlobalConfig().Performance.LiteInitStats - var err error - if liteInitStats { - err = statsHandle.InitStatsLite(ctx, do.InfoSchema()) - } else { - err = statsHandle.InitStats(ctx, do.InfoSchema()) - } - if err != nil { - logutil.BgLogger().Error("init stats info failed", zap.Bool("lite", liteInitStats), zap.Duration("take time", time.Since(t)), zap.Error(err)) - } else { - logutil.BgLogger().Info("init stats info time", zap.Bool("lite", liteInitStats), zap.Duration("take time", time.Since(t))) - } -} - -func (do *Domain) loadStatsWorker() { - defer util.Recover(metrics.LabelDomain, "loadStatsWorker", nil, false) - lease := do.statsLease - if lease == 0 { - lease = 3 * time.Second - } - loadTicker := time.NewTicker(lease) - updStatsHealthyTicker := time.NewTicker(20 * lease) - defer func() { - loadTicker.Stop() - updStatsHealthyTicker.Stop() - logutil.BgLogger().Info("loadStatsWorker exited.") - }() - - ctx, cancelFunc := context.WithCancel(context.Background()) - do.cancelFns.mu.Lock() - do.cancelFns.fns = append(do.cancelFns.fns, cancelFunc) - do.cancelFns.mu.Unlock() - - do.initStats(ctx) - statsHandle := do.StatsHandle() - var err error - for { - select { - case <-loadTicker.C: - err = statsHandle.Update(ctx, do.InfoSchema()) - if err != nil { - logutil.BgLogger().Debug("update stats info failed", zap.Error(err)) - } - err = statsHandle.LoadNeededHistograms() - if err != nil { - logutil.BgLogger().Debug("load histograms failed", zap.Error(err)) - } - case <-updStatsHealthyTicker.C: - statsHandle.UpdateStatsHealthyMetrics() - case <-do.exit: - return - } - } -} - -func (do *Domain) indexUsageWorker() { - defer util.Recover(metrics.LabelDomain, "indexUsageWorker", nil, false) - gcStatsTicker := time.NewTicker(indexUsageGCDuration) - handle := do.StatsHandle() - defer func() { - logutil.BgLogger().Info("indexUsageWorker exited.") - }() - for { - select { - case <-do.exit: - return - case <-gcStatsTicker.C: - if err := handle.GCIndexUsage(); err != nil { - statslogutil.StatsLogger().Error("gc index usage failed", zap.Error(err)) - } - } - } -} - -func (*Domain) updateStatsWorkerExitPreprocessing(statsHandle *handle.Handle, owner owner.Manager) { - ch := make(chan struct{}, 1) - timeout, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - go func() { - 
logutil.BgLogger().Info("updateStatsWorker is going to exit, start to flush stats") - statsHandle.FlushStats() - logutil.BgLogger().Info("updateStatsWorker ready to release owner") - owner.Cancel() - ch <- struct{}{} - }() - select { - case <-ch: - logutil.BgLogger().Info("updateStatsWorker exit preprocessing finished") - return - case <-timeout.Done(): - logutil.BgLogger().Warn("updateStatsWorker exit preprocessing timeout, force exiting") - return - } -} - -func (do *Domain) handleDDLEvent() { - logutil.BgLogger().Info("handleDDLEvent started.") - defer util.Recover(metrics.LabelDomain, "handleDDLEvent", nil, false) - statsHandle := do.StatsHandle() - for { - select { - case <-do.exit: - return - // This channel is sent only by ddl owner. - case t := <-statsHandle.DDLEventCh(): - err := statsHandle.HandleDDLEvent(t) - if err != nil { - logutil.BgLogger().Error("handle ddl event failed", zap.String("event", t.String()), zap.Error(err)) - } - } - } -} - -func (do *Domain) updateStatsWorker(_ sessionctx.Context, owner owner.Manager) { - defer util.Recover(metrics.LabelDomain, "updateStatsWorker", nil, false) - logutil.BgLogger().Info("updateStatsWorker started.") - lease := do.statsLease - // We need to have different nodes trigger tasks at different times to avoid the herd effect. - randDuration := time.Duration(rand.Int63n(int64(time.Minute))) - deltaUpdateTicker := time.NewTicker(20*lease + randDuration) - gcStatsTicker := time.NewTicker(100 * lease) - dumpColStatsUsageTicker := time.NewTicker(100 * lease) - readMemTicker := time.NewTicker(memory.ReadMemInterval) - statsHandle := do.StatsHandle() - defer func() { - dumpColStatsUsageTicker.Stop() - gcStatsTicker.Stop() - deltaUpdateTicker.Stop() - readMemTicker.Stop() - do.SetStatsUpdating(false) - logutil.BgLogger().Info("updateStatsWorker exited.") - }() - defer util.Recover(metrics.LabelDomain, "updateStatsWorker", nil, false) - - for { - select { - case <-do.exit: - do.updateStatsWorkerExitPreprocessing(statsHandle, owner) - return - case <-deltaUpdateTicker.C: - err := statsHandle.DumpStatsDeltaToKV(false) - if err != nil { - logutil.BgLogger().Debug("dump stats delta failed", zap.Error(err)) - } - case <-gcStatsTicker.C: - if !owner.IsOwner() { - continue - } - err := statsHandle.GCStats(do.InfoSchema(), do.DDL().GetLease()) - if err != nil { - logutil.BgLogger().Debug("GC stats failed", zap.Error(err)) - } - case <-dumpColStatsUsageTicker.C: - err := statsHandle.DumpColStatsUsageToKV() - if err != nil { - logutil.BgLogger().Debug("dump column stats usage failed", zap.Error(err)) - } - - case <-readMemTicker.C: - memory.ForceReadMemStats() - } - } -} - -func (do *Domain) autoAnalyzeWorker(owner owner.Manager) { - defer util.Recover(metrics.LabelDomain, "autoAnalyzeWorker", nil, false) - statsHandle := do.StatsHandle() - analyzeTicker := time.NewTicker(do.statsLease) - defer func() { - analyzeTicker.Stop() - logutil.BgLogger().Info("autoAnalyzeWorker exited.") - }() - for { - select { - case <-analyzeTicker.C: - if variable.RunAutoAnalyze.Load() && !do.stopAutoAnalyze.Load() && owner.IsOwner() { - statsHandle.HandleAutoAnalyze() - } - case <-do.exit: - return - } - } -} - -// analyzeJobsCleanupWorker is a background worker that periodically performs two main tasks: -// -// 1. Garbage Collection: It removes outdated analyze jobs from the statistics handle. -// This operation is performed every hour and only if the current instance is the owner. -// Analyze jobs older than 7 days are considered outdated and are removed. -// -// 2. 
Cleanup: It cleans up corrupted analyze jobs. -// A corrupted analyze job is one that is in a 'pending' or 'running' state, -// but is associated with a TiDB instance that is either not currently running or has been restarted. -// Also, if the analyze job is killed by the user, it is considered corrupted. -// This operation is performed every 100 stats leases. -// It first retrieves the list of current analyze processes, then removes any analyze job -// that is not associated with a current process. Additionally, if the current instance is the owner, -// it also cleans up corrupted analyze jobs on dead instances. -func (do *Domain) analyzeJobsCleanupWorker(owner owner.Manager) { - defer util.Recover(metrics.LabelDomain, "analyzeJobsCleanupWorker", nil, false) - // For GC. - const gcInterval = time.Hour - const daysToKeep = 7 - gcTicker := time.NewTicker(gcInterval) - // For clean up. - // Default stats lease is 3 * time.Second. - // So cleanupInterval is 100 * 3 * time.Second = 5 * time.Minute. - var cleanupInterval = do.statsLease * 100 - cleanupTicker := time.NewTicker(cleanupInterval) - defer func() { - gcTicker.Stop() - cleanupTicker.Stop() - logutil.BgLogger().Info("analyzeJobsCleanupWorker exited.") - }() - statsHandle := do.StatsHandle() - for { - select { - case <-gcTicker.C: - // Only the owner should perform this operation. - if owner.IsOwner() { - updateTime := time.Now().AddDate(0, 0, -daysToKeep) - err := statsHandle.DeleteAnalyzeJobs(updateTime) - if err != nil { - logutil.BgLogger().Warn("gc analyze history failed", zap.Error(err)) - } - } - case <-cleanupTicker.C: - sm := do.InfoSyncer().GetSessionManager() - if sm == nil { - continue - } - analyzeProcessIDs := make(map[uint64]struct{}, 8) - for _, process := range sm.ShowProcessList() { - if isAnalyzeTableSQL(process.Info) { - analyzeProcessIDs[process.ID] = struct{}{} - } - } - - err := statsHandle.CleanupCorruptedAnalyzeJobsOnCurrentInstance(analyzeProcessIDs) - if err != nil { - logutil.BgLogger().Warn("cleanup analyze jobs on current instance failed", zap.Error(err)) - } - - if owner.IsOwner() { - err = statsHandle.CleanupCorruptedAnalyzeJobsOnDeadInstances() - if err != nil { - logutil.BgLogger().Warn("cleanup analyze jobs on dead instances failed", zap.Error(err)) - } - } - case <-do.exit: - return - } - } -} - -func isAnalyzeTableSQL(sql string) bool { - // Get rid of the comments. - normalizedSQL := parser.Normalize(sql, "ON") - return strings.HasPrefix(normalizedSQL, "analyze table") -} - -// ExpensiveQueryHandle returns the expensive query handle. -func (do *Domain) ExpensiveQueryHandle() *expensivequery.Handle { - return do.expensiveQueryHandle -} - -// MemoryUsageAlarmHandle returns the memory usage alarm handle. -func (do *Domain) MemoryUsageAlarmHandle() *memoryusagealarm.Handle { - return do.memoryUsageAlarmHandle -} - -// ServerMemoryLimitHandle returns the expensive query handle. -func (do *Domain) ServerMemoryLimitHandle() *servermemorylimit.Handle { - return do.serverMemoryLimitHandle -} - -const ( - privilegeKey = "/tidb/privilege" - sysVarCacheKey = "/tidb/sysvars" - tiflashComputeNodeKey = "/tiflash/new_tiflash_compute_nodes" -) - -// NotifyUpdatePrivilege updates privilege key in etcd, TiDB client that watches -// the key will get notification. -func (do *Domain) NotifyUpdatePrivilege() error { - // No matter skip-grant-table is configured or not, sending an etcd message is required. 
- // Because we need to tell other TiDB instances to update privilege data, say, we're changing the - // password using a special TiDB instance and want the new password to take effect. - if do.etcdClient != nil { - row := do.etcdClient.KV - _, err := row.Put(context.Background(), privilegeKey, "") - if err != nil { - logutil.BgLogger().Warn("notify update privilege failed", zap.Error(err)) - } - } - - // If skip-grant-table is configured, do not flush privileges. - // Because LoadPrivilegeLoop does not run and the privilege Handle is nil, - // the call to do.PrivilegeHandle().Update would panic. - if config.GetGlobalConfig().Security.SkipGrantTable { - return nil - } - - // update locally - ctx, err := do.sysSessionPool.Get() - if err != nil { - return err - } - defer do.sysSessionPool.Put(ctx) - return do.PrivilegeHandle().Update(ctx.(sessionctx.Context)) -} - -// NotifyUpdateSysVarCache updates the sysvar cache key in etcd, which other TiDB -// clients are subscribed to for updates. For the caller, the cache is also built -// synchronously so that the effect is immediate. -func (do *Domain) NotifyUpdateSysVarCache(updateLocal bool) { - if do.etcdClient != nil { - row := do.etcdClient.KV - _, err := row.Put(context.Background(), sysVarCacheKey, "") - if err != nil { - logutil.BgLogger().Warn("notify update sysvar cache failed", zap.Error(err)) - } - } - // update locally - if updateLocal { - if err := do.rebuildSysVarCache(nil); err != nil { - logutil.BgLogger().Error("rebuilding sysvar cache failed", zap.Error(err)) - } - } -} - -// LoadSigningCertLoop loads the signing cert periodically to make sure it's fresh new. -func (do *Domain) LoadSigningCertLoop(signingCert, signingKey string) { - sessionstates.SetCertPath(signingCert) - sessionstates.SetKeyPath(signingKey) - - do.wg.Run(func() { - defer func() { - logutil.BgLogger().Debug("loadSigningCertLoop exited.") - }() - defer util.Recover(metrics.LabelDomain, "LoadSigningCertLoop", nil, false) - - for { - select { - case <-time.After(sessionstates.LoadCertInterval): - sessionstates.ReloadSigningCert() - case <-do.exit: - return - } - } - }, "loadSigningCertLoop") -} - -// ServerID gets serverID. -func (do *Domain) ServerID() uint64 { - return atomic.LoadUint64(&do.serverID) -} - -// IsLostConnectionToPD indicates lost connection to PD or not. -func (do *Domain) IsLostConnectionToPD() bool { - return do.isLostConnectionToPD.Load() != 0 -} - -// NextConnID return next connection ID. -func (do *Domain) NextConnID() uint64 { - return do.connIDAllocator.NextID() -} - -// ReleaseConnID releases connection ID. -func (do *Domain) ReleaseConnID(connID uint64) { - do.connIDAllocator.Release(connID) -} - -const ( - serverIDEtcdPath = "/tidb/server_id" - refreshServerIDRetryCnt = 3 - acquireServerIDRetryInterval = 300 * time.Millisecond - acquireServerIDTimeout = 10 * time.Second - retrieveServerIDSessionTimeout = 10 * time.Second - - acquire32BitsServerIDRetryCnt = 3 -) - -var ( - // serverIDTTL should be LONG ENOUGH to avoid barbarically killing an on-going long-run SQL. - serverIDTTL = 12 * time.Hour - // serverIDTimeToKeepAlive is the interval that we keep serverID TTL alive periodically. - serverIDTimeToKeepAlive = 5 * time.Minute - // serverIDTimeToCheckPDConnectionRestored is the interval that we check connection to PD restored (after broken) periodically. 
- serverIDTimeToCheckPDConnectionRestored = 10 * time.Second - // lostConnectionToPDTimeout is the duration that when TiDB cannot connect to PD excceeds this limit, - // we realize the connection to PD is lost utterly, and server ID acquired before should be released. - // Must be SHORTER than `serverIDTTL`. - lostConnectionToPDTimeout = 6 * time.Hour -) - -var ( - ldflagIsGlobalKillTest = "0" // 1:Yes, otherwise:No. - ldflagServerIDTTL = "10" // in seconds. - ldflagServerIDTimeToKeepAlive = "1" // in seconds. - ldflagServerIDTimeToCheckPDConnectionRestored = "1" // in seconds. - ldflagLostConnectionToPDTimeout = "5" // in seconds. -) - -func initByLDFlagsForGlobalKill() { - if ldflagIsGlobalKillTest == "1" { - var ( - i int - err error - ) - - if i, err = strconv.Atoi(ldflagServerIDTTL); err != nil { - panic("invalid ldflagServerIDTTL") - } - serverIDTTL = time.Duration(i) * time.Second - - if i, err = strconv.Atoi(ldflagServerIDTimeToKeepAlive); err != nil { - panic("invalid ldflagServerIDTimeToKeepAlive") - } - serverIDTimeToKeepAlive = time.Duration(i) * time.Second - - if i, err = strconv.Atoi(ldflagServerIDTimeToCheckPDConnectionRestored); err != nil { - panic("invalid ldflagServerIDTimeToCheckPDConnectionRestored") - } - serverIDTimeToCheckPDConnectionRestored = time.Duration(i) * time.Second - - if i, err = strconv.Atoi(ldflagLostConnectionToPDTimeout); err != nil { - panic("invalid ldflagLostConnectionToPDTimeout") - } - lostConnectionToPDTimeout = time.Duration(i) * time.Second - - logutil.BgLogger().Info("global_kill_test is enabled", zap.Duration("serverIDTTL", serverIDTTL), - zap.Duration("serverIDTimeToKeepAlive", serverIDTimeToKeepAlive), - zap.Duration("serverIDTimeToCheckPDConnectionRestored", serverIDTimeToCheckPDConnectionRestored), - zap.Duration("lostConnectionToPDTimeout", lostConnectionToPDTimeout)) - } -} - -func (do *Domain) retrieveServerIDSession(ctx context.Context) (*concurrency.Session, error) { - if do.serverIDSession != nil { - return do.serverIDSession, nil - } - - // `etcdClient.Grant` needs a shortterm timeout, to avoid blocking if connection to PD lost, - // while `etcdClient.KeepAlive` should be longterm. - // So we separately invoke `etcdClient.Grant` and `concurrency.NewSession` with leaseID. 
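The comment above explains why lease creation and session keep-alive are handled separately: a lost PD connection should only be able to block the short Grant call, not the long-lived keep-alive session. The sketch below shows just that split, using the etcd client packages these files already import; newServerIDSession is an illustrative name and the timeout is inlined for brevity.

// Sketch only: grant the lease with its own short timeout, then bind the
// long-lived concurrency session to that lease.
package sketch

import (
	"context"
	"time"

	clientv3 "go.etcd.io/etcd/client/v3"
	"go.etcd.io/etcd/client/v3/concurrency"
)

func newServerIDSession(cli *clientv3.Client, ttl time.Duration) (*concurrency.Session, error) {
	// Grant can hang if etcd/PD is unreachable, so give it a short deadline.
	grantCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	resp, err := cli.Grant(grantCtx, int64(ttl.Seconds()))
	cancel()
	if err != nil {
		return nil, err
	}
	// The keep-alive session should live as long as the process, so it is
	// created with a background context on the lease granted above.
	return concurrency.NewSession(cli,
		concurrency.WithLease(resp.ID),
		concurrency.WithContext(context.Background()))
}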
- childCtx, cancel := context.WithTimeout(ctx, retrieveServerIDSessionTimeout) - resp, err := do.etcdClient.Grant(childCtx, int64(serverIDTTL.Seconds())) - cancel() - if err != nil { - logutil.BgLogger().Error("retrieveServerIDSession.Grant fail", zap.Error(err)) - return nil, err - } - leaseID := resp.ID - - session, err := concurrency.NewSession(do.etcdClient, - concurrency.WithLease(leaseID), concurrency.WithContext(context.Background())) - if err != nil { - logutil.BgLogger().Error("retrieveServerIDSession.NewSession fail", zap.Error(err)) - return nil, err - } - do.serverIDSession = session - return session, nil -} - -func (do *Domain) acquireServerID(ctx context.Context) error { - atomic.StoreUint64(&do.serverID, 0) - - session, err := do.retrieveServerIDSession(ctx) - if err != nil { - return err - } - - conflictCnt := 0 - for { - var proposeServerID uint64 - if config.GetGlobalConfig().Enable32BitsConnectionID { - proposeServerID, err = do.proposeServerID(ctx, conflictCnt) - if err != nil { - return errors.Trace(err) - } - } else { - // get a random serverID: [1, MaxServerID64] - proposeServerID = uint64(rand.Int63n(int64(globalconn.MaxServerID64)) + 1) // #nosec G404 - } - - key := fmt.Sprintf("%s/%v", serverIDEtcdPath, proposeServerID) - cmp := clientv3.Compare(clientv3.CreateRevision(key), "=", 0) - value := "0" - - childCtx, cancel := context.WithTimeout(ctx, acquireServerIDTimeout) - txn := do.etcdClient.Txn(childCtx) - t := txn.If(cmp) - resp, err := t.Then(clientv3.OpPut(key, value, clientv3.WithLease(session.Lease()))).Commit() - cancel() - if err != nil { - return err - } - if !resp.Succeeded { - logutil.BgLogger().Info("propose serverID exists, try again", zap.Uint64("proposeServerID", proposeServerID)) - time.Sleep(acquireServerIDRetryInterval) - conflictCnt++ - continue - } - - atomic.StoreUint64(&do.serverID, proposeServerID) - logutil.BgLogger().Info("acquireServerID", zap.Uint64("serverID", do.ServerID()), - zap.String("lease id", strconv.FormatInt(int64(session.Lease()), 16))) - return nil - } -} - -func (do *Domain) releaseServerID(context.Context) { - serverID := do.ServerID() - if serverID == 0 { - return - } - atomic.StoreUint64(&do.serverID, 0) - - if do.etcdClient == nil { - return - } - key := fmt.Sprintf("%s/%v", serverIDEtcdPath, serverID) - err := ddlutil.DeleteKeyFromEtcd(key, do.etcdClient, refreshServerIDRetryCnt, acquireServerIDTimeout) - if err != nil { - logutil.BgLogger().Error("releaseServerID fail", zap.Uint64("serverID", serverID), zap.Error(err)) - } else { - logutil.BgLogger().Info("releaseServerID succeed", zap.Uint64("serverID", serverID)) - } -} - -// propose server ID by random. -func (*Domain) proposeServerID(ctx context.Context, conflictCnt int) (uint64, error) { - // get a random server ID in range [min, max] - randomServerID := func(min uint64, max uint64) uint64 { - return uint64(rand.Int63n(int64(max-min+1)) + int64(min)) // #nosec G404 - } - - if conflictCnt < acquire32BitsServerIDRetryCnt { - // get existing server IDs. - allServerInfo, err := infosync.GetAllServerInfo(ctx) - if err != nil { - return 0, errors.Trace(err) - } - // `allServerInfo` contains current TiDB. 
- if float32(len(allServerInfo)) <= 0.9*float32(globalconn.MaxServerID32) { - serverIDs := make(map[uint64]struct{}, len(allServerInfo)) - for _, info := range allServerInfo { - serverID := info.ServerIDGetter() - if serverID <= globalconn.MaxServerID32 { - serverIDs[serverID] = struct{}{} - } - } - - for retry := 0; retry < 15; retry++ { - randServerID := randomServerID(1, globalconn.MaxServerID32) - if _, ok := serverIDs[randServerID]; !ok { - return randServerID, nil - } - } - } - logutil.BgLogger().Info("upgrade to 64 bits server ID due to used up", zap.Int("len(allServerInfo)", len(allServerInfo))) - } else { - logutil.BgLogger().Info("upgrade to 64 bits server ID due to conflict", zap.Int("conflictCnt", conflictCnt)) - } - - // upgrade to 64 bits. - return randomServerID(globalconn.MaxServerID32+1, globalconn.MaxServerID64), nil -} - -func (do *Domain) refreshServerIDTTL(ctx context.Context) error { - session, err := do.retrieveServerIDSession(ctx) - if err != nil { - return err - } - - key := fmt.Sprintf("%s/%v", serverIDEtcdPath, do.ServerID()) - value := "0" - err = ddlutil.PutKVToEtcd(ctx, do.etcdClient, refreshServerIDRetryCnt, key, value, clientv3.WithLease(session.Lease())) - if err != nil { - logutil.BgLogger().Error("refreshServerIDTTL fail", zap.Uint64("serverID", do.ServerID()), zap.Error(err)) - } else { - logutil.BgLogger().Info("refreshServerIDTTL succeed", zap.Uint64("serverID", do.ServerID()), - zap.String("lease id", strconv.FormatInt(int64(session.Lease()), 16))) - } - return err -} - -func (do *Domain) serverIDKeeper() { - defer func() { - do.wg.Done() - logutil.BgLogger().Info("serverIDKeeper exited.") - }() - defer util.Recover(metrics.LabelDomain, "serverIDKeeper", func() { - logutil.BgLogger().Info("recover serverIDKeeper.") - // should be called before `do.wg.Done()`, to ensure that Domain.Close() waits for the new `serverIDKeeper()` routine. - do.wg.Add(1) - go do.serverIDKeeper() - }, false) - - tickerKeepAlive := time.NewTicker(serverIDTimeToKeepAlive) - tickerCheckRestored := time.NewTicker(serverIDTimeToCheckPDConnectionRestored) - defer func() { - tickerKeepAlive.Stop() - tickerCheckRestored.Stop() - }() - - blocker := make(chan struct{}) // just used for blocking the sessionDone() when session is nil. - sessionDone := func() <-chan struct{} { - if do.serverIDSession == nil { - return blocker - } - return do.serverIDSession.Done() - } - - var lastSucceedTimestamp time.Time - - onConnectionToPDRestored := func() { - logutil.BgLogger().Info("restored connection to PD") - do.isLostConnectionToPD.Store(0) - lastSucceedTimestamp = time.Now() - - if err := do.info.StoreServerInfo(context.Background()); err != nil { - logutil.BgLogger().Error("StoreServerInfo failed", zap.Error(err)) - } - } - - onConnectionToPDLost := func() { - logutil.BgLogger().Warn("lost connection to PD") - do.isLostConnectionToPD.Store(1) - - // Kill all connections when lost connection to PD, - // to avoid the possibility that another TiDB instance acquires the same serverID and generates a same connection ID, - // which will lead to a wrong connection killed. 
- do.InfoSyncer().GetSessionManager().KillAllConnections() - } - - for { - select { - case <-tickerKeepAlive.C: - if !do.IsLostConnectionToPD() { - if err := do.refreshServerIDTTL(context.Background()); err == nil { - lastSucceedTimestamp = time.Now() - } else { - if lostConnectionToPDTimeout > 0 && time.Since(lastSucceedTimestamp) > lostConnectionToPDTimeout { - onConnectionToPDLost() - } - } - } - case <-tickerCheckRestored.C: - if do.IsLostConnectionToPD() { - if err := do.acquireServerID(context.Background()); err == nil { - onConnectionToPDRestored() - } - } - case <-sessionDone(): - // inform that TTL of `serverID` is expired. See https://godoc.org/github.com/coreos/etcd/clientv3/concurrency#Session.Done - // Should be in `IsLostConnectionToPD` state, as `lostConnectionToPDTimeout` is shorter than `serverIDTTL`. - // So just set `do.serverIDSession = nil` to restart `serverID` session in `retrieveServerIDSession()`. - logutil.BgLogger().Info("serverIDSession need restart") - do.serverIDSession = nil - case <-do.exit: - return - } - } -} - -// StartTTLJobManager creates and starts the ttl job manager -func (do *Domain) StartTTLJobManager() { - ttlJobManager := ttlworker.NewJobManager(do.ddl.GetID(), do.sysSessionPool, do.store, do.etcdClient, do.ddl.OwnerManager().IsOwner) - do.ttlJobManager.Store(ttlJobManager) - ttlJobManager.Start() -} - -// TTLJobManager returns the ttl job manager on this domain -func (do *Domain) TTLJobManager() *ttlworker.JobManager { - return do.ttlJobManager.Load() -} - -// StopAutoAnalyze stops (*Domain).autoAnalyzeWorker to launch new auto analyze jobs. -func (do *Domain) StopAutoAnalyze() { - do.stopAutoAnalyze.Store(true) -} - -// InitInstancePlanCache initializes the instance level plan cache for this Domain. -func (do *Domain) InitInstancePlanCache() { - softLimit := variable.InstancePlanCacheTargetMemSize.Load() - hardLimit := variable.InstancePlanCacheMaxMemSize.Load() - do.instancePlanCache = NewInstancePlanCache(softLimit, hardLimit) - // use a separate goroutine to avoid the eviction blocking other operations. - do.wg.Run(do.planCacheEvictTrigger, "planCacheEvictTrigger") - do.wg.Run(do.planCacheMetricsAndVars, "planCacheMetricsAndVars") -} - -// GetInstancePlanCache returns the instance level plan cache in this Domain. -func (do *Domain) GetInstancePlanCache() sessionctx.InstancePlanCache { - return do.instancePlanCache -} - -// planCacheMetricsAndVars updates metrics and variables for Instance Plan Cache periodically. -func (do *Domain) planCacheMetricsAndVars() { - defer util.Recover(metrics.LabelDomain, "planCacheMetricsAndVars", nil, false) - ticker := time.NewTicker(time.Second * 15) // 15s by default - defer func() { - ticker.Stop() - logutil.BgLogger().Info("planCacheMetricsAndVars exited.") - }() - - for { - select { - case <-ticker.C: - // update limits - softLimit := variable.InstancePlanCacheTargetMemSize.Load() - hardLimit := variable.InstancePlanCacheMaxMemSize.Load() - curSoft, curHard := do.instancePlanCache.GetLimits() - if curSoft != softLimit || curHard != hardLimit { - do.instancePlanCache.SetLimits(softLimit, hardLimit) - } - - // update the metrics - size := do.instancePlanCache.Size() - memUsage := do.instancePlanCache.MemUsage() - metrics2.GetPlanCacheInstanceNumCounter(true).Set(float64(size)) - metrics2.GetPlanCacheInstanceMemoryUsage(true).Set(float64(memUsage)) - case <-do.exit: - return - } - } -} - -// planCacheEvictTrigger triggers the plan cache eviction periodically. 
-func (do *Domain) planCacheEvictTrigger() { - defer util.Recover(metrics.LabelDomain, "planCacheEvictTrigger", nil, false) - ticker := time.NewTicker(time.Second * 15) // 15s by default - defer func() { - ticker.Stop() - logutil.BgLogger().Info("planCacheEvictTrigger exited.") - }() - - for { - select { - case <-ticker.C: - // trigger the eviction - do.instancePlanCache.Evict() - case <-do.exit: - return - } - } -} - -func init() { - initByLDFlagsForGlobalKill() -} - -var ( - // ErrInfoSchemaExpired returns the error that information schema is out of date. - ErrInfoSchemaExpired = dbterror.ClassDomain.NewStd(errno.ErrInfoSchemaExpired) - // ErrInfoSchemaChanged returns the error that information schema is changed. - ErrInfoSchemaChanged = dbterror.ClassDomain.NewStdErr(errno.ErrInfoSchemaChanged, - mysql.Message(errno.MySQLErrName[errno.ErrInfoSchemaChanged].Raw+". "+kv.TxnRetryableMark, nil)) -) - -// SysProcesses holds the sys processes infos -type SysProcesses struct { - mu *sync.RWMutex - procMap map[uint64]sysproctrack.TrackProc -} - -// Track tracks the sys process into procMap -func (s *SysProcesses) Track(id uint64, proc sysproctrack.TrackProc) error { - s.mu.Lock() - defer s.mu.Unlock() - if oldProc, ok := s.procMap[id]; ok && oldProc != proc { - return errors.Errorf("The ID is in use: %v", id) - } - s.procMap[id] = proc - proc.GetSessionVars().ConnectionID = id - proc.GetSessionVars().SQLKiller.Reset() - return nil -} - -// UnTrack removes the sys process from procMap -func (s *SysProcesses) UnTrack(id uint64) { - s.mu.Lock() - defer s.mu.Unlock() - if proc, ok := s.procMap[id]; ok { - delete(s.procMap, id) - proc.GetSessionVars().ConnectionID = 0 - proc.GetSessionVars().SQLKiller.Reset() - } -} - -// GetSysProcessList gets list of system ProcessInfo -func (s *SysProcesses) GetSysProcessList() map[uint64]*util.ProcessInfo { - s.mu.RLock() - defer s.mu.RUnlock() - rs := make(map[uint64]*util.ProcessInfo) - for connID, proc := range s.procMap { - // if session is still tracked in this map, it's not returned to sysSessionPool yet - if pi := proc.ShowProcess(); pi != nil && pi.ID == connID { - rs[connID] = pi - } - } - return rs -} - -// KillSysProcess kills sys process with specified ID -func (s *SysProcesses) KillSysProcess(id uint64) { - s.mu.Lock() - defer s.mu.Unlock() - if proc, ok := s.procMap[id]; ok { - proc.GetSessionVars().SQLKiller.SendKillSignal(sqlkiller.QueryInterrupted) - } -} diff --git a/pkg/domain/historical_stats.go b/pkg/domain/historical_stats.go index 4f8cb16ce8ff9..9b4dd016d2711 100644 --- a/pkg/domain/historical_stats.go +++ b/pkg/domain/historical_stats.go @@ -35,11 +35,11 @@ type HistoricalStatsWorker struct { // SendTblToDumpHistoricalStats send tableID to worker to dump historical stats func (w *HistoricalStatsWorker) SendTblToDumpHistoricalStats(tableID int64) { send := enableDumpHistoricalStats.Load() - if val, _err_ := failpoint.Eval(_curpkg_("sendHistoricalStats")); _err_ == nil { + failpoint.Inject("sendHistoricalStats", func(val failpoint.Value) { if val.(bool) { send = true } - } + }) if !send { return } diff --git a/pkg/domain/historical_stats.go__failpoint_stash__ b/pkg/domain/historical_stats.go__failpoint_stash__ deleted file mode 100644 index 9b4dd016d2711..0000000000000 --- a/pkg/domain/historical_stats.go__failpoint_stash__ +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package domain - -import ( - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - domain_metrics "github.com/pingcap/tidb/pkg/domain/metrics" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/statistics/handle" - "github.com/pingcap/tidb/pkg/util/logutil" - "go.uber.org/zap" -) - -// HistoricalStatsWorker indicates for dump historical stats -type HistoricalStatsWorker struct { - tblCH chan int64 - sctx sessionctx.Context -} - -// SendTblToDumpHistoricalStats send tableID to worker to dump historical stats -func (w *HistoricalStatsWorker) SendTblToDumpHistoricalStats(tableID int64) { - send := enableDumpHistoricalStats.Load() - failpoint.Inject("sendHistoricalStats", func(val failpoint.Value) { - if val.(bool) { - send = true - } - }) - if !send { - return - } - select { - case w.tblCH <- tableID: - return - default: - logutil.BgLogger().Warn("discard dump historical stats task", zap.Int64("table-id", tableID)) - } -} - -// DumpHistoricalStats dump stats by given tableID -func (w *HistoricalStatsWorker) DumpHistoricalStats(tableID int64, statsHandle *handle.Handle) error { - historicalStatsEnabled, err := statsHandle.CheckHistoricalStatsEnable() - if err != nil { - return errors.Errorf("check tidb_enable_historical_stats failed: %v", err) - } - if !historicalStatsEnabled { - return nil - } - sctx := w.sctx - is := GetDomain(sctx).InfoSchema() - isPartition := false - var tblInfo *model.TableInfo - tbl, existed := is.TableByID(tableID) - if !existed { - tbl, db, p := is.FindTableByPartitionID(tableID) - if !(tbl != nil && db != nil && p != nil) { - return errors.Errorf("cannot get table by id %d", tableID) - } - isPartition = true - tblInfo = tbl.Meta() - } else { - tblInfo = tbl.Meta() - } - dbInfo, existed := infoschema.SchemaByTable(is, tblInfo) - if !existed { - return errors.Errorf("cannot get DBInfo by TableID %d", tableID) - } - if _, err := statsHandle.RecordHistoricalStatsToStorage(dbInfo.Name.O, tblInfo, tableID, isPartition); err != nil { - domain_metrics.GenerateHistoricalStatsFailedCounter.Inc() - return errors.Errorf("record table %s.%s's historical stats failed, err:%v", dbInfo.Name.O, tblInfo.Name.O, err) - } - domain_metrics.GenerateHistoricalStatsSuccessCounter.Inc() - return nil -} - -// GetOneHistoricalStatsTable gets one tableID from channel, only used for test -func (w *HistoricalStatsWorker) GetOneHistoricalStatsTable() int64 { - select { - case tblID := <-w.tblCH: - return tblID - default: - return -1 - } -} diff --git a/pkg/domain/infosync/binding__failpoint_binding__.go b/pkg/domain/infosync/binding__failpoint_binding__.go deleted file mode 100644 index c4cc7873735bf..0000000000000 --- a/pkg/domain/infosync/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package infosync - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} 
-func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/domain/infosync/info.go b/pkg/domain/infosync/info.go index 9845ee180a7e1..1e5e0cb51e6e6 100644 --- a/pkg/domain/infosync/info.go +++ b/pkg/domain/infosync/info.go @@ -276,7 +276,7 @@ func (is *InfoSyncer) initResourceManagerClient(pdCli pd.Client) { if pdCli == nil { cli = NewMockResourceManagerClient() } - if val, _err_ := failpoint.Eval(_curpkg_("managerAlreadyCreateSomeGroups")); _err_ == nil { + failpoint.Inject("managerAlreadyCreateSomeGroups", func(val failpoint.Value) { if val.(bool) { _, err := cli.AddResourceGroup(context.TODO(), &rmpb.ResourceGroup{ @@ -305,7 +305,7 @@ func (is *InfoSyncer) initResourceManagerClient(pdCli pd.Client) { log.Warn("fail to create default group", zap.Error(err)) } } - } + }) is.resourceManagerClient = cli } @@ -355,11 +355,11 @@ func SetMockTiFlash(tiflash *MockTiFlash) { // GetServerInfo gets self server static information. func GetServerInfo() (*ServerInfo, error) { - if v, _err_ := failpoint.Eval(_curpkg_("mockGetServerInfo")); _err_ == nil { + failpoint.Inject("mockGetServerInfo", func(v failpoint.Value) { var res ServerInfo err := json.Unmarshal([]byte(v.(string)), &res) - return &res, err - } + failpoint.Return(&res, err) + }) is, err := getGlobalInfoSyncer() if err != nil { return nil, err @@ -394,11 +394,11 @@ func (is *InfoSyncer) getServerInfoByID(ctx context.Context, id string) (*Server // GetAllServerInfo gets all servers static information from etcd. func GetAllServerInfo(ctx context.Context) (map[string]*ServerInfo, error) { - if val, _err_ := failpoint.Eval(_curpkg_("mockGetAllServerInfo")); _err_ == nil { + failpoint.Inject("mockGetAllServerInfo", func(val failpoint.Value) { res := make(map[string]*ServerInfo) err := json.Unmarshal([]byte(val.(string)), &res) - return res, err - } + failpoint.Return(res, err) + }) is, err := getGlobalInfoSyncer() if err != nil { return nil, err @@ -538,15 +538,15 @@ func GetRuleBundle(ctx context.Context, name string) (*placement.Bundle, error) // PutRuleBundles is used to post specific rule bundles to PD. func PutRuleBundles(ctx context.Context, bundles []*placement.Bundle) error { - if isServiceError, _err_ := failpoint.Eval(_curpkg_("putRuleBundlesError")); _err_ == nil { + failpoint.Inject("putRuleBundlesError", func(isServiceError failpoint.Value) { var err error if isServiceError.(bool) { err = ErrHTTPServiceError.FastGen("mock service error") } else { err = errors.New("mock other error") } - return err - } + failpoint.Return(err) + }) is, err := getGlobalInfoSyncer() if err != nil { @@ -1034,14 +1034,14 @@ func getServerInfo(id string, serverIDGetter func() uint64) *ServerInfo { metrics.ServerInfo.WithLabelValues(mysql.TiDBReleaseVersion, info.GitHash).Set(float64(info.StartTimestamp)) - if val, _err_ := failpoint.Eval(_curpkg_("mockServerInfo")); _err_ == nil { + failpoint.Inject("mockServerInfo", func(val failpoint.Value) { if val.(bool) { info.StartTimestamp = 1282967700 info.Labels = map[string]string{ "foo": "bar", } } - } + }) return info } @@ -1337,11 +1337,11 @@ type TiProxyServerInfo struct { // GetTiProxyServerInfo gets all TiProxy servers information from etcd. 
func GetTiProxyServerInfo(ctx context.Context) (map[string]*TiProxyServerInfo, error) { - if val, _err_ := failpoint.Eval(_curpkg_("mockGetTiProxyServerInfo")); _err_ == nil { + failpoint.Inject("mockGetTiProxyServerInfo", func(val failpoint.Value) { res := make(map[string]*TiProxyServerInfo) err := json.Unmarshal([]byte(val.(string)), &res) - return res, err - } + failpoint.Return(res, err) + }) is, err := getGlobalInfoSyncer() if err != nil { return nil, err diff --git a/pkg/domain/infosync/info.go__failpoint_stash__ b/pkg/domain/infosync/info.go__failpoint_stash__ deleted file mode 100644 index 1e5e0cb51e6e6..0000000000000 --- a/pkg/domain/infosync/info.go__failpoint_stash__ +++ /dev/null @@ -1,1457 +0,0 @@ -// Copyright 2018 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package infosync - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net" - "net/http" - "os" - "path" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - rmpb "github.com/pingcap/kvproto/pkg/resource_manager" - "github.com/pingcap/log" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/ddl/label" - "github.com/pingcap/tidb/pkg/ddl/placement" - "github.com/pingcap/tidb/pkg/ddl/util" - "github.com/pingcap/tidb/pkg/domain/resourcegroup" - "github.com/pingcap/tidb/pkg/errno" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/session/cursor" - "github.com/pingcap/tidb/pkg/sessionctx/binloginfo" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - util2 "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/dbterror" - "github.com/pingcap/tidb/pkg/util/engine" - "github.com/pingcap/tidb/pkg/util/hack" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/versioninfo" - "github.com/tikv/client-go/v2/oracle" - "github.com/tikv/client-go/v2/tikv" - pd "github.com/tikv/pd/client" - pdhttp "github.com/tikv/pd/client/http" - clientv3 "go.etcd.io/etcd/client/v3" - "go.etcd.io/etcd/client/v3/concurrency" - "go.uber.org/zap" -) - -const ( - // ServerInformationPath store server information such as IP, port and so on. - ServerInformationPath = "/tidb/server/info" - // ServerMinStartTSPath store the server min start timestamp. - ServerMinStartTSPath = "/tidb/server/minstartts" - // TiFlashTableSyncProgressPath store the tiflash table replica sync progress. - TiFlashTableSyncProgressPath = "/tiflash/table/sync" - // keyOpDefaultRetryCnt is the default retry count for etcd store. - keyOpDefaultRetryCnt = 5 - // keyOpDefaultTimeout is the default time out for etcd store. - keyOpDefaultTimeout = 1 * time.Second - // ReportInterval is interval of infoSyncerKeeper reporting min startTS. 
- ReportInterval = 30 * time.Second - // TopologyInformationPath means etcd path for storing topology info. - TopologyInformationPath = "/topology/tidb" - // TopologySessionTTL is ttl for topology, ant it's the ETCD session's TTL in seconds. - TopologySessionTTL = 45 - // TopologyTimeToRefresh means time to refresh etcd. - TopologyTimeToRefresh = 30 * time.Second - // TopologyPrometheus means address of prometheus. - TopologyPrometheus = "/topology/prometheus" - // TopologyTiProxy means address of TiProxy. - TopologyTiProxy = "/topology/tiproxy" - // infoSuffix is the suffix of TiDB/TiProxy topology info. - infoSuffix = "/info" - // TopologyTiCDC means address of TiCDC. - TopologyTiCDC = "/topology/ticdc" - // TablePrometheusCacheExpiry is the expiry time for prometheus address cache. - TablePrometheusCacheExpiry = 10 * time.Second - // RequestRetryInterval is the sleep time before next retry for http request - RequestRetryInterval = 200 * time.Millisecond - // SyncBundlesMaxRetry is the max retry times for sync placement bundles - SyncBundlesMaxRetry = 3 -) - -// ErrPrometheusAddrIsNotSet is the error that Prometheus address is not set in PD and etcd -var ErrPrometheusAddrIsNotSet = dbterror.ClassDomain.NewStd(errno.ErrPrometheusAddrIsNotSet) - -// InfoSyncer stores server info to etcd when the tidb-server starts and delete when tidb-server shuts down. -type InfoSyncer struct { - // `etcdClient` must be used when keyspace is not set, or when the logic to each etcd path needs to be separated by keyspace. - etcdCli *clientv3.Client - // `unprefixedEtcdCli` will never set the etcd namespace prefix by keyspace. - // It is only used in storeMinStartTS and RemoveMinStartTS now. - // It must be used when the etcd path isn't needed to separate by keyspace. - // See keyspace RFC: https://github.com/pingcap/tidb/pull/39685 - unprefixedEtcdCli *clientv3.Client - pdHTTPCli pdhttp.Client - info *ServerInfo - serverInfoPath string - minStartTS uint64 - minStartTSPath string - managerMu struct { - mu sync.RWMutex - util2.SessionManager - } - session *concurrency.Session - topologySession *concurrency.Session - prometheusAddr string - modifyTime time.Time - labelRuleManager LabelRuleManager - placementManager PlacementManager - scheduleManager ScheduleManager - tiflashReplicaManager TiFlashReplicaManager - resourceManagerClient pd.ResourceManagerClient -} - -// ServerInfo is server static information. -// It will not be updated when tidb-server running. So please only put static information in ServerInfo struct. -type ServerInfo struct { - ServerVersionInfo - ID string `json:"ddl_id"` - IP string `json:"ip"` - Port uint `json:"listening_port"` - StatusPort uint `json:"status_port"` - Lease string `json:"lease"` - BinlogStatus string `json:"binlog_status"` - StartTimestamp int64 `json:"start_timestamp"` - Labels map[string]string `json:"labels"` - // ServerID is a function, to always retrieve latest serverID from `Domain`, - // which will be changed on occasions such as connection to PD is restored after broken. - ServerIDGetter func() uint64 `json:"-"` - - // JSONServerID is `serverID` for json marshal/unmarshal ONLY. - JSONServerID uint64 `json:"server_id"` -} - -// Marshal `ServerInfo` into bytes. -func (info *ServerInfo) Marshal() ([]byte, error) { - info.JSONServerID = info.ServerIDGetter() - infoBuf, err := json.Marshal(info) - if err != nil { - return nil, errors.Trace(err) - } - return infoBuf, nil -} - -// Unmarshal `ServerInfo` from bytes. 
-func (info *ServerInfo) Unmarshal(v []byte) error { - if err := json.Unmarshal(v, info); err != nil { - return err - } - info.ServerIDGetter = func() uint64 { - return info.JSONServerID - } - return nil -} - -// ServerVersionInfo is the server version and git_hash. -type ServerVersionInfo struct { - Version string `json:"version"` - GitHash string `json:"git_hash"` -} - -// globalInfoSyncer stores the global infoSyncer. -// Use a global variable for simply the code, use the domain.infoSyncer will have circle import problem in some pkg. -// Use atomic.Pointer to avoid data race in the test. -var globalInfoSyncer atomic.Pointer[InfoSyncer] - -func getGlobalInfoSyncer() (*InfoSyncer, error) { - v := globalInfoSyncer.Load() - if v == nil { - return nil, errors.New("infoSyncer is not initialized") - } - return v, nil -} - -func setGlobalInfoSyncer(is *InfoSyncer) { - globalInfoSyncer.Store(is) -} - -// GlobalInfoSyncerInit return a new InfoSyncer. It is exported for testing. -func GlobalInfoSyncerInit( - ctx context.Context, - id string, - serverIDGetter func() uint64, - etcdCli, unprefixedEtcdCli *clientv3.Client, - pdCli pd.Client, pdHTTPCli pdhttp.Client, - codec tikv.Codec, - skipRegisterToDashBoard bool, -) (*InfoSyncer, error) { - if pdHTTPCli != nil { - pdHTTPCli = pdHTTPCli. - WithCallerID("tidb-info-syncer"). - WithRespHandler(pdResponseHandler) - } - is := &InfoSyncer{ - etcdCli: etcdCli, - unprefixedEtcdCli: unprefixedEtcdCli, - pdHTTPCli: pdHTTPCli, - info: getServerInfo(id, serverIDGetter), - serverInfoPath: fmt.Sprintf("%s/%s", ServerInformationPath, id), - minStartTSPath: fmt.Sprintf("%s/%s", ServerMinStartTSPath, id), - } - err := is.init(ctx, skipRegisterToDashBoard) - if err != nil { - return nil, err - } - is.initLabelRuleManager() - is.initPlacementManager() - is.initScheduleManager() - is.initTiFlashReplicaManager(codec) - is.initResourceManagerClient(pdCli) - setGlobalInfoSyncer(is) - return is, nil -} - -// Init creates a new etcd session and stores server info to etcd. -func (is *InfoSyncer) init(ctx context.Context, skipRegisterToDashboard bool) error { - err := is.newSessionAndStoreServerInfo(ctx, util2.NewSessionDefaultRetryCnt) - if err != nil { - return err - } - if skipRegisterToDashboard { - return nil - } - return is.newTopologySessionAndStoreServerInfo(ctx, util2.NewSessionDefaultRetryCnt) -} - -// SetSessionManager set the session manager for InfoSyncer. -func (is *InfoSyncer) SetSessionManager(manager util2.SessionManager) { - is.managerMu.mu.Lock() - defer is.managerMu.mu.Unlock() - is.managerMu.SessionManager = manager -} - -// GetSessionManager get the session manager. 
-func (is *InfoSyncer) GetSessionManager() util2.SessionManager { - is.managerMu.mu.RLock() - defer is.managerMu.mu.RUnlock() - return is.managerMu.SessionManager -} - -func (is *InfoSyncer) initLabelRuleManager() { - if is.pdHTTPCli == nil { - is.labelRuleManager = &mockLabelManager{labelRules: map[string][]byte{}} - return - } - is.labelRuleManager = &PDLabelManager{is.pdHTTPCli} -} - -func (is *InfoSyncer) initPlacementManager() { - if is.pdHTTPCli == nil { - is.placementManager = &mockPlacementManager{} - return - } - is.placementManager = &PDPlacementManager{is.pdHTTPCli} -} - -func (is *InfoSyncer) initResourceManagerClient(pdCli pd.Client) { - var cli pd.ResourceManagerClient = pdCli - if pdCli == nil { - cli = NewMockResourceManagerClient() - } - failpoint.Inject("managerAlreadyCreateSomeGroups", func(val failpoint.Value) { - if val.(bool) { - _, err := cli.AddResourceGroup(context.TODO(), - &rmpb.ResourceGroup{ - Name: resourcegroup.DefaultResourceGroupName, - Mode: rmpb.GroupMode_RUMode, - RUSettings: &rmpb.GroupRequestUnitSettings{ - RU: &rmpb.TokenBucket{ - Settings: &rmpb.TokenLimitSettings{FillRate: 1000000, BurstLimit: -1}, - }, - }, - }) - if err != nil { - log.Warn("fail to create default group", zap.Error(err)) - } - _, err = cli.AddResourceGroup(context.TODO(), - &rmpb.ResourceGroup{ - Name: "oltp", - Mode: rmpb.GroupMode_RUMode, - RUSettings: &rmpb.GroupRequestUnitSettings{ - RU: &rmpb.TokenBucket{ - Settings: &rmpb.TokenLimitSettings{FillRate: 1000000, BurstLimit: -1}, - }, - }, - }) - if err != nil { - log.Warn("fail to create default group", zap.Error(err)) - } - } - }) - is.resourceManagerClient = cli -} - -func (is *InfoSyncer) initTiFlashReplicaManager(codec tikv.Codec) { - if is.pdHTTPCli == nil { - is.tiflashReplicaManager = &mockTiFlashReplicaManagerCtx{tiflashProgressCache: make(map[int64]float64)} - return - } - logutil.BgLogger().Warn("init TiFlashReplicaManager") - is.tiflashReplicaManager = &TiFlashReplicaManagerCtx{pdHTTPCli: is.pdHTTPCli, tiflashProgressCache: make(map[int64]float64), codec: codec} -} - -func (is *InfoSyncer) initScheduleManager() { - if is.pdHTTPCli == nil { - is.scheduleManager = &mockScheduleManager{} - return - } - is.scheduleManager = &PDScheduleManager{is.pdHTTPCli} -} - -// GetMockTiFlash can only be used in tests to get MockTiFlash -func GetMockTiFlash() *MockTiFlash { - is, err := getGlobalInfoSyncer() - if err != nil { - return nil - } - - m, ok := is.tiflashReplicaManager.(*mockTiFlashReplicaManagerCtx) - if ok { - return m.tiflash - } - return nil -} - -// SetMockTiFlash can only be used in tests to set MockTiFlash -func SetMockTiFlash(tiflash *MockTiFlash) { - is, err := getGlobalInfoSyncer() - if err != nil { - return - } - - m, ok := is.tiflashReplicaManager.(*mockTiFlashReplicaManagerCtx) - if ok { - m.SetMockTiFlash(tiflash) - } -} - -// GetServerInfo gets self server static information. -func GetServerInfo() (*ServerInfo, error) { - failpoint.Inject("mockGetServerInfo", func(v failpoint.Value) { - var res ServerInfo - err := json.Unmarshal([]byte(v.(string)), &res) - failpoint.Return(&res, err) - }) - is, err := getGlobalInfoSyncer() - if err != nil { - return nil, err - } - return is.info, nil -} - -// GetServerInfoByID gets specified server static information from etcd. 
-func GetServerInfoByID(ctx context.Context, id string) (*ServerInfo, error) { - is, err := getGlobalInfoSyncer() - if err != nil { - return nil, err - } - return is.getServerInfoByID(ctx, id) -} - -func (is *InfoSyncer) getServerInfoByID(ctx context.Context, id string) (*ServerInfo, error) { - if is.etcdCli == nil || id == is.info.ID { - return is.info, nil - } - key := fmt.Sprintf("%s/%s", ServerInformationPath, id) - infoMap, err := getInfo(ctx, is.etcdCli, key, keyOpDefaultRetryCnt, keyOpDefaultTimeout) - if err != nil { - return nil, err - } - info, ok := infoMap[id] - if !ok { - return nil, errors.Errorf("[info-syncer] get %s failed", key) - } - return info, nil -} - -// GetAllServerInfo gets all servers static information from etcd. -func GetAllServerInfo(ctx context.Context) (map[string]*ServerInfo, error) { - failpoint.Inject("mockGetAllServerInfo", func(val failpoint.Value) { - res := make(map[string]*ServerInfo) - err := json.Unmarshal([]byte(val.(string)), &res) - failpoint.Return(res, err) - }) - is, err := getGlobalInfoSyncer() - if err != nil { - return nil, err - } - return is.getAllServerInfo(ctx) -} - -// UpdateServerLabel updates the server label for global info syncer. -func UpdateServerLabel(ctx context.Context, labels map[string]string) error { - is, err := getGlobalInfoSyncer() - if err != nil { - return err - } - // when etcdCli is nil, the server infos are generated from the latest config, no need to update. - if is.etcdCli == nil { - return nil - } - selfInfo, err := is.getServerInfoByID(ctx, is.info.ID) - if err != nil { - return err - } - changed := false - for k, v := range labels { - if selfInfo.Labels[k] != v { - changed = true - selfInfo.Labels[k] = v - } - } - if !changed { - return nil - } - infoBuf, err := selfInfo.Marshal() - if err != nil { - return errors.Trace(err) - } - str := string(hack.String(infoBuf)) - err = util.PutKVToEtcd(ctx, is.etcdCli, keyOpDefaultRetryCnt, is.serverInfoPath, str, clientv3.WithLease(is.session.Lease())) - return err -} - -// DeleteTiFlashTableSyncProgress is used to delete the tiflash table replica sync progress. -func DeleteTiFlashTableSyncProgress(tableInfo *model.TableInfo) error { - is, err := getGlobalInfoSyncer() - if err != nil { - return err - } - if pi := tableInfo.GetPartitionInfo(); pi != nil { - for _, p := range pi.Definitions { - is.tiflashReplicaManager.DeleteTiFlashProgressFromCache(p.ID) - } - } else { - is.tiflashReplicaManager.DeleteTiFlashProgressFromCache(tableInfo.ID) - } - return nil -} - -// MustGetTiFlashProgress gets tiflash replica progress from tiflashProgressCache, if cache not exist, it calculates progress from PD and TiFlash and inserts progress into cache. -func MustGetTiFlashProgress(tableID int64, replicaCount uint64, tiFlashStores *map[int64]pdhttp.StoreInfo) (float64, error) { - is, err := getGlobalInfoSyncer() - if err != nil { - return 0, err - } - progressCache, isExist := is.tiflashReplicaManager.GetTiFlashProgressFromCache(tableID) - if isExist { - return progressCache, nil - } - if *tiFlashStores == nil { - // We need the up-to-date information about TiFlash stores. - // Since TiFlash Replica synchronize may happen immediately after new TiFlash stores are added. - tikvStats, err := is.tiflashReplicaManager.GetStoresStat(context.Background()) - // If MockTiFlash is not set, will issue a MockTiFlashError here. 
- if err != nil { - return 0, err - } - stores := make(map[int64]pdhttp.StoreInfo) - for _, store := range tikvStats.Stores { - if engine.IsTiFlashHTTPResp(&store.Store) { - stores[store.Store.ID] = store - } - } - *tiFlashStores = stores - logutil.BgLogger().Debug("updateTiFlashStores finished", zap.Int("TiFlash store count", len(*tiFlashStores))) - } - progress, err := is.tiflashReplicaManager.CalculateTiFlashProgress(tableID, replicaCount, *tiFlashStores) - if err != nil { - return 0, err - } - is.tiflashReplicaManager.UpdateTiFlashProgressCache(tableID, progress) - return progress, nil -} - -// pdResponseHandler will be injected into the PD HTTP client to handle the response, -// this is to maintain consistency with the original logic without the PD HTTP client. -func pdResponseHandler(resp *http.Response, res any) error { - defer func() { terror.Log(resp.Body.Close()) }() - bodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - if resp.StatusCode == http.StatusOK { - if res != nil && bodyBytes != nil { - return json.Unmarshal(bodyBytes, res) - } - return nil - } - logutil.BgLogger().Warn("response not 200", - zap.String("method", resp.Request.Method), - zap.String("host", resp.Request.URL.Host), - zap.String("url", resp.Request.URL.RequestURI()), - zap.Int("http status", resp.StatusCode), - ) - if resp.StatusCode != http.StatusNotFound && resp.StatusCode != http.StatusPreconditionFailed { - return ErrHTTPServiceError.FastGen("%s", bodyBytes) - } - return nil -} - -// GetAllRuleBundles is used to get all rule bundles from PD It is used to load full rules from PD while fullload infoschema. -func GetAllRuleBundles(ctx context.Context) ([]*placement.Bundle, error) { - is, err := getGlobalInfoSyncer() - if err != nil { - return nil, err - } - - return is.placementManager.GetAllRuleBundles(ctx) -} - -// GetRuleBundle is used to get one specific rule bundle from PD. -func GetRuleBundle(ctx context.Context, name string) (*placement.Bundle, error) { - is, err := getGlobalInfoSyncer() - if err != nil { - return nil, err - } - - return is.placementManager.GetRuleBundle(ctx, name) -} - -// PutRuleBundles is used to post specific rule bundles to PD. -func PutRuleBundles(ctx context.Context, bundles []*placement.Bundle) error { - failpoint.Inject("putRuleBundlesError", func(isServiceError failpoint.Value) { - var err error - if isServiceError.(bool) { - err = ErrHTTPServiceError.FastGen("mock service error") - } else { - err = errors.New("mock other error") - } - failpoint.Return(err) - }) - - is, err := getGlobalInfoSyncer() - if err != nil { - return err - } - - return is.placementManager.PutRuleBundles(ctx, bundles) -} - -// PutRuleBundlesWithRetry will retry for specified times when PutRuleBundles failed -func PutRuleBundlesWithRetry(ctx context.Context, bundles []*placement.Bundle, maxRetry int, interval time.Duration) (err error) { - if maxRetry < 0 { - maxRetry = 0 - } - - for i := 0; i <= maxRetry; i++ { - if err = PutRuleBundles(ctx, bundles); err == nil || ErrHTTPServiceError.Equal(err) { - return err - } - - if i != maxRetry { - logutil.BgLogger().Warn("Error occurs when PutRuleBundles, retry", zap.Error(err)) - time.Sleep(interval) - } - } - - return -} - -// GetResourceGroup is used to get one specific resource group from resource manager. 
-func GetResourceGroup(ctx context.Context, name string) (*rmpb.ResourceGroup, error) { - is, err := getGlobalInfoSyncer() - if err != nil { - return nil, err - } - - return is.resourceManagerClient.GetResourceGroup(ctx, name) -} - -// ListResourceGroups is used to get all resource groups from resource manager. -func ListResourceGroups(ctx context.Context) ([]*rmpb.ResourceGroup, error) { - is, err := getGlobalInfoSyncer() - if err != nil { - return nil, err - } - - return is.resourceManagerClient.ListResourceGroups(ctx) -} - -// AddResourceGroup is used to create one specific resource group to resource manager. -func AddResourceGroup(ctx context.Context, group *rmpb.ResourceGroup) error { - is, err := getGlobalInfoSyncer() - if err != nil { - return err - } - _, err = is.resourceManagerClient.AddResourceGroup(ctx, group) - return err -} - -// ModifyResourceGroup is used to modify one specific resource group to resource manager. -func ModifyResourceGroup(ctx context.Context, group *rmpb.ResourceGroup) error { - is, err := getGlobalInfoSyncer() - if err != nil { - return err - } - _, err = is.resourceManagerClient.ModifyResourceGroup(ctx, group) - return err -} - -// DeleteResourceGroup is used to delete one specific resource group from resource manager. -func DeleteResourceGroup(ctx context.Context, name string) error { - is, err := getGlobalInfoSyncer() - if err != nil { - return err - } - _, err = is.resourceManagerClient.DeleteResourceGroup(ctx, name) - return err -} - -// PutRuleBundlesWithDefaultRetry will retry for default times -func PutRuleBundlesWithDefaultRetry(ctx context.Context, bundles []*placement.Bundle) (err error) { - return PutRuleBundlesWithRetry(ctx, bundles, SyncBundlesMaxRetry, RequestRetryInterval) -} - -func (is *InfoSyncer) getAllServerInfo(ctx context.Context) (map[string]*ServerInfo, error) { - allInfo := make(map[string]*ServerInfo) - if is.etcdCli == nil { - allInfo[is.info.ID] = getServerInfo(is.info.ID, is.info.ServerIDGetter) - return allInfo, nil - } - allInfo, err := getInfo(ctx, is.etcdCli, ServerInformationPath, keyOpDefaultRetryCnt, keyOpDefaultTimeout, clientv3.WithPrefix()) - if err != nil { - return nil, err - } - return allInfo, nil -} - -// StoreServerInfo stores self server static information to etcd. -func (is *InfoSyncer) StoreServerInfo(ctx context.Context) error { - if is.etcdCli == nil { - return nil - } - infoBuf, err := is.info.Marshal() - if err != nil { - return errors.Trace(err) - } - str := string(hack.String(infoBuf)) - err = util.PutKVToEtcd(ctx, is.etcdCli, keyOpDefaultRetryCnt, is.serverInfoPath, str, clientv3.WithLease(is.session.Lease())) - return err -} - -// RemoveServerInfo remove self server static information from etcd. 
-func (is *InfoSyncer) RemoveServerInfo() { - if is.etcdCli == nil { - return - } - err := util.DeleteKeyFromEtcd(is.serverInfoPath, is.etcdCli, keyOpDefaultRetryCnt, keyOpDefaultTimeout) - if err != nil { - logutil.BgLogger().Error("remove server info failed", zap.Error(err)) - } -} - -// TopologyInfo is the topology info -type TopologyInfo struct { - ServerVersionInfo - IP string `json:"ip"` - StatusPort uint `json:"status_port"` - DeployPath string `json:"deploy_path"` - StartTimestamp int64 `json:"start_timestamp"` - Labels map[string]string `json:"labels"` -} - -func (is *InfoSyncer) getTopologyInfo() TopologyInfo { - s, err := os.Executable() - if err != nil { - s = "" - } - dir := path.Dir(s) - return TopologyInfo{ - ServerVersionInfo: ServerVersionInfo{ - Version: mysql.TiDBReleaseVersion, - GitHash: is.info.ServerVersionInfo.GitHash, - }, - IP: is.info.IP, - StatusPort: is.info.StatusPort, - DeployPath: dir, - StartTimestamp: is.info.StartTimestamp, - Labels: is.info.Labels, - } -} - -// StoreTopologyInfo stores the topology of tidb to etcd. -func (is *InfoSyncer) StoreTopologyInfo(ctx context.Context) error { - if is.etcdCli == nil { - return nil - } - topologyInfo := is.getTopologyInfo() - infoBuf, err := json.Marshal(topologyInfo) - if err != nil { - return errors.Trace(err) - } - str := string(hack.String(infoBuf)) - key := fmt.Sprintf("%s/%s/info", TopologyInformationPath, net.JoinHostPort(is.info.IP, strconv.Itoa(int(is.info.Port)))) - // Note: no lease is required here. - err = util.PutKVToEtcd(ctx, is.etcdCli, keyOpDefaultRetryCnt, key, str) - if err != nil { - return err - } - // Initialize ttl. - return is.updateTopologyAliveness(ctx) -} - -// GetMinStartTS get min start timestamp. -// Export for testing. -func (is *InfoSyncer) GetMinStartTS() uint64 { - return is.minStartTS -} - -// storeMinStartTS stores self server min start timestamp to etcd. -func (is *InfoSyncer) storeMinStartTS(ctx context.Context) error { - if is.unprefixedEtcdCli == nil { - return nil - } - return util.PutKVToEtcd(ctx, is.unprefixedEtcdCli, keyOpDefaultRetryCnt, is.minStartTSPath, - strconv.FormatUint(is.minStartTS, 10), - clientv3.WithLease(is.session.Lease())) -} - -// RemoveMinStartTS removes self server min start timestamp from etcd. -func (is *InfoSyncer) RemoveMinStartTS() { - if is.unprefixedEtcdCli == nil { - return - } - err := util.DeleteKeyFromEtcd(is.minStartTSPath, is.unprefixedEtcdCli, keyOpDefaultRetryCnt, keyOpDefaultTimeout) - if err != nil { - logutil.BgLogger().Error("remove minStartTS failed", zap.Error(err)) - } -} - -// ReportMinStartTS reports self server min start timestamp to ETCD. -func (is *InfoSyncer) ReportMinStartTS(store kv.Storage) { - sm := is.GetSessionManager() - if sm == nil { - return - } - pl := sm.ShowProcessList() - innerSessionStartTSList := sm.GetInternalSessionStartTSList() - - // Calculate the lower limit of the start timestamp to avoid extremely old transaction delaying GC. - currentVer, err := store.CurrentVersion(kv.GlobalTxnScope) - if err != nil { - logutil.BgLogger().Error("update minStartTS failed", zap.Error(err)) - return - } - now := oracle.GetTimeFromTS(currentVer.Ver) - // GCMaxWaitTime is in seconds, GCMaxWaitTime * 1000 converts it to milliseconds. 
- startTSLowerLimit := oracle.GoTimeToLowerLimitStartTS(now, variable.GCMaxWaitTime.Load()*1000) - minStartTS := oracle.GoTimeToTS(now) - logutil.BgLogger().Debug("ReportMinStartTS", zap.Uint64("initial minStartTS", minStartTS), - zap.Uint64("StartTSLowerLimit", startTSLowerLimit)) - for _, info := range pl { - if info.CurTxnStartTS > startTSLowerLimit && info.CurTxnStartTS < minStartTS { - minStartTS = info.CurTxnStartTS - } - - if info.CursorTracker != nil { - info.CursorTracker.RangeCursor(func(c cursor.Handle) bool { - startTS := c.GetState().StartTS - if startTS > startTSLowerLimit && startTS < minStartTS { - minStartTS = startTS - } - return true - }) - } - } - - for _, innerTS := range innerSessionStartTSList { - logutil.BgLogger().Debug("ReportMinStartTS", zap.Uint64("Internal Session Transaction StartTS", innerTS)) - kv.PrintLongTimeInternalTxn(now, innerTS, false) - if innerTS > startTSLowerLimit && innerTS < minStartTS { - minStartTS = innerTS - } - } - - is.minStartTS = kv.GetMinInnerTxnStartTS(now, startTSLowerLimit, minStartTS) - - err = is.storeMinStartTS(context.Background()) - if err != nil { - logutil.BgLogger().Error("update minStartTS failed", zap.Error(err)) - } - logutil.BgLogger().Debug("ReportMinStartTS", zap.Uint64("final minStartTS", is.minStartTS)) -} - -// Done returns a channel that closes when the info syncer is no longer being refreshed. -func (is *InfoSyncer) Done() <-chan struct{} { - if is.etcdCli == nil { - return make(chan struct{}, 1) - } - return is.session.Done() -} - -// TopologyDone returns a channel that closes when the topology syncer is no longer being refreshed. -func (is *InfoSyncer) TopologyDone() <-chan struct{} { - if is.etcdCli == nil { - return make(chan struct{}, 1) - } - return is.topologySession.Done() -} - -// Restart restart the info syncer with new session leaseID and store server info to etcd again. -func (is *InfoSyncer) Restart(ctx context.Context) error { - return is.newSessionAndStoreServerInfo(ctx, util2.NewSessionDefaultRetryCnt) -} - -// RestartTopology restart the topology syncer with new session leaseID and store server info to etcd again. -func (is *InfoSyncer) RestartTopology(ctx context.Context) error { - return is.newTopologySessionAndStoreServerInfo(ctx, util2.NewSessionDefaultRetryCnt) -} - -// GetAllTiDBTopology gets all tidb topology -func (is *InfoSyncer) GetAllTiDBTopology(ctx context.Context) ([]*TopologyInfo, error) { - topos := make([]*TopologyInfo, 0) - response, err := is.etcdCli.Get(ctx, TopologyInformationPath, clientv3.WithPrefix()) - if err != nil { - return nil, err - } - for _, kv := range response.Kvs { - if !strings.HasSuffix(string(kv.Key), "/info") { - continue - } - var topo *TopologyInfo - err = json.Unmarshal(kv.Value, &topo) - if err != nil { - return nil, err - } - topos = append(topos, topo) - } - return topos, nil -} - -// newSessionAndStoreServerInfo creates a new etcd session and stores server info to etcd. 
-func (is *InfoSyncer) newSessionAndStoreServerInfo(ctx context.Context, retryCnt int) error { - if is.etcdCli == nil { - return nil - } - logPrefix := fmt.Sprintf("[Info-syncer] %s", is.serverInfoPath) - session, err := util2.NewSession(ctx, logPrefix, is.etcdCli, retryCnt, util.SessionTTL) - if err != nil { - return err - } - is.session = session - binloginfo.RegisterStatusListener(func(status binloginfo.BinlogStatus) error { - is.info.BinlogStatus = status.String() - err := is.StoreServerInfo(ctx) - return errors.Trace(err) - }) - return is.StoreServerInfo(ctx) -} - -// newTopologySessionAndStoreServerInfo creates a new etcd session and stores server info to etcd. -func (is *InfoSyncer) newTopologySessionAndStoreServerInfo(ctx context.Context, retryCnt int) error { - if is.etcdCli == nil { - return nil - } - logPrefix := fmt.Sprintf("[topology-syncer] %s/%s", TopologyInformationPath, net.JoinHostPort(is.info.IP, strconv.Itoa(int(is.info.Port)))) - session, err := util2.NewSession(ctx, logPrefix, is.etcdCli, retryCnt, TopologySessionTTL) - if err != nil { - return err - } - - is.topologySession = session - return is.StoreTopologyInfo(ctx) -} - -// refreshTopology refreshes etcd topology with ttl stored in "/topology/tidb/ip:port/ttl". -func (is *InfoSyncer) updateTopologyAliveness(ctx context.Context) error { - if is.etcdCli == nil { - return nil - } - key := fmt.Sprintf("%s/%s/ttl", TopologyInformationPath, net.JoinHostPort(is.info.IP, strconv.Itoa(int(is.info.Port)))) - return util.PutKVToEtcd(ctx, is.etcdCli, keyOpDefaultRetryCnt, key, - fmt.Sprintf("%v", time.Now().UnixNano()), - clientv3.WithLease(is.topologySession.Lease())) -} - -// GetPrometheusAddr gets prometheus Address -func GetPrometheusAddr() (string, error) { - is, err := getGlobalInfoSyncer() - if err != nil { - return "", err - } - - // if the cache of prometheusAddr is over 10s, update the prometheusAddr - if time.Since(is.modifyTime) < TablePrometheusCacheExpiry { - return is.prometheusAddr, nil - } - return is.getPrometheusAddr() -} - -type prometheus struct { - IP string `json:"ip"` - BinaryPath string `json:"binary_path"` - Port int `json:"port"` -} - -type metricStorage struct { - PDServer struct { - MetricStorage string `json:"metric-storage"` - } `json:"pd-server"` -} - -func (is *InfoSyncer) getPrometheusAddr() (string, error) { - // Get PD servers info. - clientAvailable := is.etcdCli != nil - var pdAddrs []string - if clientAvailable { - pdAddrs = is.etcdCli.Endpoints() - } - if !clientAvailable || len(pdAddrs) == 0 { - return "", errors.Errorf("pd unavailable") - } - // Get prometheus address from pdhttp. - url := util2.ComposeURL(pdAddrs[0], pdhttp.Config) - resp, err := util2.InternalHTTPClient().Get(url) - if err != nil { - return "", err - } - defer resp.Body.Close() - var metricStorage metricStorage - dec := json.NewDecoder(resp.Body) - err = dec.Decode(&metricStorage) - if err != nil { - return "", err - } - res := metricStorage.PDServer.MetricStorage - - // Get prometheus address from etcdApi. 
- if res == "" { - values, err := is.getPrometheusAddrFromEtcd(TopologyPrometheus) - if err != nil { - return "", errors.Trace(err) - } - if values == "" { - return "", ErrPrometheusAddrIsNotSet - } - var prometheus prometheus - err = json.Unmarshal([]byte(values), &prometheus) - if err != nil { - return "", errors.Trace(err) - } - res = fmt.Sprintf("http://%s", net.JoinHostPort(prometheus.IP, strconv.Itoa(prometheus.Port))) - } - is.prometheusAddr = res - is.modifyTime = time.Now() - setGlobalInfoSyncer(is) - return res, nil -} - -func (is *InfoSyncer) getPrometheusAddrFromEtcd(k string) (string, error) { - ctx, cancel := context.WithTimeout(context.Background(), keyOpDefaultTimeout) - resp, err := is.etcdCli.Get(ctx, k) - cancel() - if err != nil { - return "", errors.Trace(err) - } - if len(resp.Kvs) > 0 { - return string(resp.Kvs[0].Value), nil - } - return "", nil -} - -// getInfo gets server information from etcd according to the key and opts. -func getInfo(ctx context.Context, etcdCli *clientv3.Client, key string, retryCnt int, timeout time.Duration, opts ...clientv3.OpOption) (map[string]*ServerInfo, error) { - var err error - var resp *clientv3.GetResponse - allInfo := make(map[string]*ServerInfo) - for i := 0; i < retryCnt; i++ { - select { - case <-ctx.Done(): - err = errors.Trace(ctx.Err()) - return nil, err - default: - } - childCtx, cancel := context.WithTimeout(ctx, timeout) - resp, err = etcdCli.Get(childCtx, key, opts...) - cancel() - if err != nil { - logutil.BgLogger().Info("get key failed", zap.String("key", key), zap.Error(err)) - time.Sleep(200 * time.Millisecond) - continue - } - for _, kv := range resp.Kvs { - info := &ServerInfo{ - BinlogStatus: binloginfo.BinlogStatusUnknown.String(), - } - err = info.Unmarshal(kv.Value) - if err != nil { - logutil.BgLogger().Info("get key failed", zap.String("key", string(kv.Key)), zap.ByteString("value", kv.Value), - zap.Error(err)) - return nil, errors.Trace(err) - } - allInfo[info.ID] = info - } - return allInfo, nil - } - return nil, errors.Trace(err) -} - -// getServerInfo gets self tidb server information. -func getServerInfo(id string, serverIDGetter func() uint64) *ServerInfo { - cfg := config.GetGlobalConfig() - info := &ServerInfo{ - ID: id, - IP: cfg.AdvertiseAddress, - Port: cfg.Port, - StatusPort: cfg.Status.StatusPort, - Lease: cfg.Lease, - BinlogStatus: binloginfo.GetStatus().String(), - StartTimestamp: time.Now().Unix(), - Labels: cfg.Labels, - ServerIDGetter: serverIDGetter, - } - info.Version = mysql.ServerVersion - info.GitHash = versioninfo.TiDBGitHash - - metrics.ServerInfo.WithLabelValues(mysql.TiDBReleaseVersion, info.GitHash).Set(float64(info.StartTimestamp)) - - failpoint.Inject("mockServerInfo", func(val failpoint.Value) { - if val.(bool) { - info.StartTimestamp = 1282967700 - info.Labels = map[string]string{ - "foo": "bar", - } - } - }) - - return info -} - -// PutLabelRule synchronizes the label rule to PD. -func PutLabelRule(ctx context.Context, rule *label.Rule) error { - if rule == nil { - return nil - } - - is, err := getGlobalInfoSyncer() - if err != nil { - return err - } - if is.labelRuleManager == nil { - return nil - } - return is.labelRuleManager.PutLabelRule(ctx, rule) -} - -// UpdateLabelRules synchronizes the label rule to PD. 
-func UpdateLabelRules(ctx context.Context, patch *pdhttp.LabelRulePatch) error { - if patch == nil || (len(patch.DeleteRules) == 0 && len(patch.SetRules) == 0) { - return nil - } - - is, err := getGlobalInfoSyncer() - if err != nil { - return err - } - if is.labelRuleManager == nil { - return nil - } - return is.labelRuleManager.UpdateLabelRules(ctx, patch) -} - -// GetAllLabelRules gets all label rules from PD. -func GetAllLabelRules(ctx context.Context) ([]*label.Rule, error) { - is, err := getGlobalInfoSyncer() - if err != nil { - return nil, err - } - if is.labelRuleManager == nil { - return nil, nil - } - return is.labelRuleManager.GetAllLabelRules(ctx) -} - -// GetLabelRules gets the label rules according to the given IDs from PD. -func GetLabelRules(ctx context.Context, ruleIDs []string) (map[string]*label.Rule, error) { - if len(ruleIDs) == 0 { - return nil, nil - } - - is, err := getGlobalInfoSyncer() - if err != nil { - return nil, err - } - if is.labelRuleManager == nil { - return nil, nil - } - return is.labelRuleManager.GetLabelRules(ctx, ruleIDs) -} - -// CalculateTiFlashProgress calculates TiFlash replica progress -func CalculateTiFlashProgress(tableID int64, replicaCount uint64, tiFlashStores map[int64]pdhttp.StoreInfo) (float64, error) { - is, err := getGlobalInfoSyncer() - if err != nil { - return 0, errors.Trace(err) - } - return is.tiflashReplicaManager.CalculateTiFlashProgress(tableID, replicaCount, tiFlashStores) -} - -// UpdateTiFlashProgressCache updates tiflashProgressCache -func UpdateTiFlashProgressCache(tableID int64, progress float64) error { - is, err := getGlobalInfoSyncer() - if err != nil { - return errors.Trace(err) - } - is.tiflashReplicaManager.UpdateTiFlashProgressCache(tableID, progress) - return nil -} - -// GetTiFlashProgressFromCache gets tiflash replica progress from tiflashProgressCache -func GetTiFlashProgressFromCache(tableID int64) (float64, bool) { - is, err := getGlobalInfoSyncer() - if err != nil { - logutil.BgLogger().Error("GetTiFlashProgressFromCache get info sync failed", zap.Int64("tableID", tableID), zap.Error(err)) - return 0, false - } - return is.tiflashReplicaManager.GetTiFlashProgressFromCache(tableID) -} - -// CleanTiFlashProgressCache clean progress cache -func CleanTiFlashProgressCache() { - is, err := getGlobalInfoSyncer() - if err != nil { - return - } - is.tiflashReplicaManager.CleanTiFlashProgressCache() -} - -// SetTiFlashGroupConfig is a helper function to set tiflash rule group config -func SetTiFlashGroupConfig(ctx context.Context) error { - is, err := getGlobalInfoSyncer() - if err != nil { - return errors.Trace(err) - } - logutil.BgLogger().Info("SetTiFlashGroupConfig") - return is.tiflashReplicaManager.SetTiFlashGroupConfig(ctx) -} - -// SetTiFlashPlacementRule is a helper function to set placement rule. -// It is discouraged to use SetTiFlashPlacementRule directly, -// use `ConfigureTiFlashPDForTable`/`ConfigureTiFlashPDForPartitions` instead. -func SetTiFlashPlacementRule(ctx context.Context, rule pdhttp.Rule) error { - is, err := getGlobalInfoSyncer() - if err != nil { - return errors.Trace(err) - } - logutil.BgLogger().Info("SetTiFlashPlacementRule", zap.String("ruleID", rule.ID)) - return is.tiflashReplicaManager.SetPlacementRule(ctx, &rule) -} - -// DeleteTiFlashPlacementRules is a helper function to delete TiFlash placement rules of given physical table IDs. 
-func DeleteTiFlashPlacementRules(ctx context.Context, physicalTableIDs []int64) error { - is, err := getGlobalInfoSyncer() - if err != nil { - return errors.Trace(err) - } - logutil.BgLogger().Info("DeleteTiFlashPlacementRules", zap.Int64s("physicalTableIDs", physicalTableIDs)) - rules := make([]*pdhttp.Rule, 0, len(physicalTableIDs)) - for _, id := range physicalTableIDs { - // make a rule with count 0 to delete the rule - rule := MakeNewRule(id, 0, nil) - rules = append(rules, &rule) - } - return is.tiflashReplicaManager.SetPlacementRuleBatch(ctx, rules) -} - -// GetTiFlashGroupRules to get all placement rule in a certain group. -func GetTiFlashGroupRules(ctx context.Context, group string) ([]*pdhttp.Rule, error) { - is, err := getGlobalInfoSyncer() - if err != nil { - return nil, errors.Trace(err) - } - return is.tiflashReplicaManager.GetGroupRules(ctx, group) -} - -// GetTiFlashRegionCountFromPD is a helper function calling `/stats/region`. -func GetTiFlashRegionCountFromPD(ctx context.Context, tableID int64, regionCount *int) error { - is, err := getGlobalInfoSyncer() - if err != nil { - return errors.Trace(err) - } - return is.tiflashReplicaManager.GetRegionCountFromPD(ctx, tableID, regionCount) -} - -// GetTiFlashStoresStat gets the TiKV store information by accessing PD's api. -func GetTiFlashStoresStat(ctx context.Context) (*pdhttp.StoresInfo, error) { - is, err := getGlobalInfoSyncer() - if err != nil { - return nil, errors.Trace(err) - } - return is.tiflashReplicaManager.GetStoresStat(ctx) -} - -// CloseTiFlashManager closes TiFlash manager. -func CloseTiFlashManager(ctx context.Context) { - is, err := getGlobalInfoSyncer() - if err != nil { - return - } - is.tiflashReplicaManager.Close(ctx) -} - -// ConfigureTiFlashPDForTable configures pd rule for unpartitioned tables. -func ConfigureTiFlashPDForTable(id int64, count uint64, locationLabels *[]string) error { - is, err := getGlobalInfoSyncer() - if err != nil { - return errors.Trace(err) - } - ctx := context.Background() - logutil.BgLogger().Info("ConfigureTiFlashPDForTable", zap.Int64("tableID", id), zap.Uint64("count", count)) - ruleNew := MakeNewRule(id, count, *locationLabels) - if e := is.tiflashReplicaManager.SetPlacementRule(ctx, &ruleNew); e != nil { - return errors.Trace(e) - } - return nil -} - -// ConfigureTiFlashPDForPartitions configures pd rule for all partition in partitioned tables. -func ConfigureTiFlashPDForPartitions(accel bool, definitions *[]model.PartitionDefinition, count uint64, locationLabels *[]string, tableID int64) error { - is, err := getGlobalInfoSyncer() - if err != nil { - return errors.Trace(err) - } - ctx := context.Background() - rules := make([]*pdhttp.Rule, 0, len(*definitions)) - pids := make([]int64, 0, len(*definitions)) - for _, p := range *definitions { - logutil.BgLogger().Info("ConfigureTiFlashPDForPartitions", zap.Int64("tableID", tableID), zap.Int64("partID", p.ID), zap.Bool("accel", accel), zap.Uint64("count", count)) - ruleNew := MakeNewRule(p.ID, count, *locationLabels) - rules = append(rules, &ruleNew) - pids = append(pids, p.ID) - } - if e := is.tiflashReplicaManager.SetPlacementRuleBatch(ctx, rules); e != nil { - return errors.Trace(e) - } - if accel { - if e := is.tiflashReplicaManager.PostAccelerateScheduleBatch(ctx, pids); e != nil { - return errors.Trace(e) - } - } - return nil -} - -// StoreInternalSession is the entry function for store an internal session to SessionManager. -// return whether the session is stored successfully. 
-func StoreInternalSession(se any) bool { - is, err := getGlobalInfoSyncer() - if err != nil { - return false - } - sm := is.GetSessionManager() - if sm == nil { - return false - } - sm.StoreInternalSession(se) - return true -} - -// DeleteInternalSession is the entry function for delete an internal session from SessionManager. -func DeleteInternalSession(se any) { - is, err := getGlobalInfoSyncer() - if err != nil { - return - } - sm := is.GetSessionManager() - if sm == nil { - return - } - sm.DeleteInternalSession(se) -} - -// SetEtcdClient is only used for test. -func SetEtcdClient(etcdCli *clientv3.Client) { - is, err := getGlobalInfoSyncer() - - if err != nil { - return - } - is.etcdCli = etcdCli -} - -// GetEtcdClient is only used for test. -func GetEtcdClient() *clientv3.Client { - is, err := getGlobalInfoSyncer() - - if err != nil { - return nil - } - return is.etcdCli -} - -// GetPDScheduleConfig gets the schedule information from pd -func GetPDScheduleConfig(ctx context.Context) (map[string]any, error) { - is, err := getGlobalInfoSyncer() - if err != nil { - return nil, errors.Trace(err) - } - return is.scheduleManager.GetScheduleConfig(ctx) -} - -// SetPDScheduleConfig sets the schedule information for pd -func SetPDScheduleConfig(ctx context.Context, config map[string]any) error { - is, err := getGlobalInfoSyncer() - if err != nil { - return errors.Trace(err) - } - return is.scheduleManager.SetScheduleConfig(ctx, config) -} - -// TiProxyServerInfo is the server info for TiProxy. -type TiProxyServerInfo struct { - Version string `json:"version"` - GitHash string `json:"git_hash"` - IP string `json:"ip"` - Port string `json:"port"` - StatusPort string `json:"status_port"` - StartTimestamp int64 `json:"start_timestamp"` -} - -// GetTiProxyServerInfo gets all TiProxy servers information from etcd. -func GetTiProxyServerInfo(ctx context.Context) (map[string]*TiProxyServerInfo, error) { - failpoint.Inject("mockGetTiProxyServerInfo", func(val failpoint.Value) { - res := make(map[string]*TiProxyServerInfo) - err := json.Unmarshal([]byte(val.(string)), &res) - failpoint.Return(res, err) - }) - is, err := getGlobalInfoSyncer() - if err != nil { - return nil, err - } - return is.getTiProxyServerInfo(ctx) -} - -func (is *InfoSyncer) getTiProxyServerInfo(ctx context.Context) (map[string]*TiProxyServerInfo, error) { - // In test. - if is.etcdCli == nil { - return nil, nil - } - - var err error - var resp *clientv3.GetResponse - allInfo := make(map[string]*TiProxyServerInfo) - for i := 0; i < keyOpDefaultRetryCnt; i++ { - if ctx.Err() != nil { - return nil, errors.Trace(ctx.Err()) - } - childCtx, cancel := context.WithTimeout(ctx, keyOpDefaultTimeout) - resp, err = is.etcdCli.Get(childCtx, TopologyTiProxy, clientv3.WithPrefix()) - cancel() - if err != nil { - logutil.BgLogger().Info("get key failed", zap.String("key", TopologyTiProxy), zap.Error(err)) - time.Sleep(200 * time.Millisecond) - continue - } - for _, kv := range resp.Kvs { - key := string(kv.Key) - if !strings.HasSuffix(key, infoSuffix) { - continue - } - addr := key[len(TopologyTiProxy)+1 : len(key)-len(infoSuffix)] - var info TiProxyServerInfo - err = json.Unmarshal(kv.Value, &info) - if err != nil { - logutil.BgLogger().Info("unmarshal key failed", zap.String("key", key), zap.ByteString("value", kv.Value), - zap.Error(err)) - return nil, errors.Trace(err) - } - allInfo[addr] = &info - } - return allInfo, nil - } - return nil, errors.Trace(err) -} - -// TiCDCInfo is the server info for TiCDC. 
-type TiCDCInfo struct { - ID string `json:"id"` - Address string `json:"address"` - Version string `json:"version"` - GitHash string `json:"git-hash"` - DeployPath string `json:"deploy-path"` - StartTimestamp int64 `json:"start-timestamp"` - ClusterID string `json:"cluster-id"` -} - -// GetTiCDCServerInfo gets all TiCDC servers information from etcd. -func GetTiCDCServerInfo(ctx context.Context) ([]*TiCDCInfo, error) { - is, err := getGlobalInfoSyncer() - if err != nil { - return nil, err - } - return is.getTiCDCServerInfo(ctx) -} - -func (is *InfoSyncer) getTiCDCServerInfo(ctx context.Context) ([]*TiCDCInfo, error) { - // In test. - if is.etcdCli == nil { - return nil, nil - } - - var err error - var resp *clientv3.GetResponse - allInfo := make([]*TiCDCInfo, 0) - for i := 0; i < keyOpDefaultRetryCnt; i++ { - if ctx.Err() != nil { - return nil, errors.Trace(ctx.Err()) - } - childCtx, cancel := context.WithTimeout(ctx, keyOpDefaultTimeout) - resp, err = is.etcdCli.Get(childCtx, TopologyTiCDC, clientv3.WithPrefix()) - cancel() - if err != nil { - logutil.BgLogger().Info("get key failed", zap.String("key", TopologyTiCDC), zap.Error(err)) - time.Sleep(200 * time.Millisecond) - continue - } - for _, kv := range resp.Kvs { - key := string(kv.Key) - keyParts := strings.Split(key, "/") - if len(keyParts) < 3 { - logutil.BgLogger().Info("invalid ticdc key", zap.String("key", key)) - continue - } - clusterID := keyParts[1] - - var info TiCDCInfo - err := json.Unmarshal(kv.Value, &info) - if err != nil { - logutil.BgLogger().Info("unmarshal key failed", zap.String("key", key), zap.ByteString("value", kv.Value), - zap.Error(err)) - return nil, errors.Trace(err) - } - info.Version = strings.TrimPrefix(info.Version, "v") - info.ClusterID = clusterID - allInfo = append(allInfo, &info) - } - return allInfo, nil - } - return nil, errors.Trace(err) -} diff --git a/pkg/domain/infosync/tiflash_manager.go b/pkg/domain/infosync/tiflash_manager.go index 040a7611e50ad..bd84d5fb4c043 100644 --- a/pkg/domain/infosync/tiflash_manager.go +++ b/pkg/domain/infosync/tiflash_manager.go @@ -95,11 +95,11 @@ func getTiFlashPeerWithoutLagCount(tiFlashStores map[int64]pd.StoreInfo, keyspac for _, store := range tiFlashStores { regionReplica := make(map[int64]int) err := helper.CollectTiFlashStatus(store.Store.StatusAddress, keyspaceID, tableID, ®ionReplica) - if _, _err_ := failpoint.Eval(_curpkg_("OneTiFlashStoreDown")); _err_ == nil { + failpoint.Inject("OneTiFlashStoreDown", func() { if store.Store.StateName == "Down" { err = errors.New("mock TiFlasah down") } - } + }) if err != nil { logutil.BgLogger().Error("Fail to get peer status from TiFlash.", zap.Int64("tableID", tableID)) diff --git a/pkg/domain/infosync/tiflash_manager.go__failpoint_stash__ b/pkg/domain/infosync/tiflash_manager.go__failpoint_stash__ deleted file mode 100644 index bd84d5fb4c043..0000000000000 --- a/pkg/domain/infosync/tiflash_manager.go__failpoint_stash__ +++ /dev/null @@ -1,893 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package infosync - -import ( - "bytes" - "context" - "encoding/hex" - "fmt" - "net/http" - "net/http/httptest" - "strconv" - "strings" - "sync" - "time" - - "github.com/gorilla/mux" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/tidb/pkg/ddl/placement" - "github.com/pingcap/tidb/pkg/store/helper" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/syncutil" - "github.com/tikv/client-go/v2/tikv" - pd "github.com/tikv/pd/client/http" - "go.uber.org/zap" -) - -var ( - _ TiFlashReplicaManager = &TiFlashReplicaManagerCtx{} - _ TiFlashReplicaManager = &mockTiFlashReplicaManagerCtx{} -) - -// TiFlashReplicaManager manages placement settings and replica progress for TiFlash. -type TiFlashReplicaManager interface { - // SetTiFlashGroupConfig sets the group index of the tiflash placement rule - SetTiFlashGroupConfig(ctx context.Context) error - // SetPlacementRule is a helper function to set placement rule. - SetPlacementRule(ctx context.Context, rule *pd.Rule) error - // SetPlacementRuleBatch is a helper function to set a batch of placement rules. - SetPlacementRuleBatch(ctx context.Context, rules []*pd.Rule) error - // DeletePlacementRule is to delete placement rule for certain group. - DeletePlacementRule(ctx context.Context, group string, ruleID string) error - // GetGroupRules to get all placement rule in a certain group. - GetGroupRules(ctx context.Context, group string) ([]*pd.Rule, error) - // PostAccelerateScheduleBatch sends `regions/accelerate-schedule/batch` request. - PostAccelerateScheduleBatch(ctx context.Context, tableIDs []int64) error - // GetRegionCountFromPD is a helper function calling `/stats/region`. - GetRegionCountFromPD(ctx context.Context, tableID int64, regionCount *int) error - // GetStoresStat gets the TiKV store information by accessing PD's api. - GetStoresStat(ctx context.Context) (*pd.StoresInfo, error) - // CalculateTiFlashProgress calculates TiFlash replica progress - CalculateTiFlashProgress(tableID int64, replicaCount uint64, TiFlashStores map[int64]pd.StoreInfo) (float64, error) - // UpdateTiFlashProgressCache updates tiflashProgressCache - UpdateTiFlashProgressCache(tableID int64, progress float64) - // GetTiFlashProgressFromCache gets tiflash replica progress from tiflashProgressCache - GetTiFlashProgressFromCache(tableID int64) (float64, bool) - // DeleteTiFlashProgressFromCache delete tiflash replica progress from tiflashProgressCache - DeleteTiFlashProgressFromCache(tableID int64) - // CleanTiFlashProgressCache clean progress cache - CleanTiFlashProgressCache() - // Close is to close TiFlashReplicaManager - Close(ctx context.Context) -} - -// TiFlashReplicaManagerCtx manages placement with pd and replica progress for TiFlash. -type TiFlashReplicaManagerCtx struct { - pdHTTPCli pd.Client - sync.RWMutex // protect tiflashProgressCache - tiflashProgressCache map[int64]float64 - codec tikv.Codec -} - -// Close is called to close TiFlashReplicaManagerCtx. 
-func (*TiFlashReplicaManagerCtx) Close(context.Context) {} - -func getTiFlashPeerWithoutLagCount(tiFlashStores map[int64]pd.StoreInfo, keyspaceID tikv.KeyspaceID, tableID int64) (int, error) { - // storeIDs -> regionID, PD will not create two peer on the same store - var flashPeerCount int - for _, store := range tiFlashStores { - regionReplica := make(map[int64]int) - err := helper.CollectTiFlashStatus(store.Store.StatusAddress, keyspaceID, tableID, ®ionReplica) - failpoint.Inject("OneTiFlashStoreDown", func() { - if store.Store.StateName == "Down" { - err = errors.New("mock TiFlasah down") - } - }) - if err != nil { - logutil.BgLogger().Error("Fail to get peer status from TiFlash.", - zap.Int64("tableID", tableID)) - // Just skip down or offline or tomestone stores, because PD will migrate regions from these stores. - if store.Store.StateName == "Up" || store.Store.StateName == "Disconnected" { - return 0, err - } - continue - } - flashPeerCount += len(regionReplica) - } - return flashPeerCount, nil -} - -// calculateTiFlashProgress calculates progress based on the region status from PD and TiFlash. -func calculateTiFlashProgress(keyspaceID tikv.KeyspaceID, tableID int64, replicaCount uint64, tiFlashStores map[int64]pd.StoreInfo) (float64, error) { - var regionCount int - if err := GetTiFlashRegionCountFromPD(context.Background(), tableID, ®ionCount); err != nil { - logutil.BgLogger().Error("Fail to get regionCount from PD.", - zap.Int64("tableID", tableID)) - return 0, errors.Trace(err) - } - - if regionCount == 0 { - logutil.BgLogger().Warn("region count getting from PD is 0.", - zap.Int64("tableID", tableID)) - return 0, fmt.Errorf("region count getting from PD is 0") - } - - tiflashPeerCount, err := getTiFlashPeerWithoutLagCount(tiFlashStores, keyspaceID, tableID) - if err != nil { - logutil.BgLogger().Error("Fail to get peer count from TiFlash.", - zap.Int64("tableID", tableID)) - return 0, errors.Trace(err) - } - progress := float64(tiflashPeerCount) / float64(regionCount*int(replicaCount)) - if progress > 1 { // when pd do balance - logutil.BgLogger().Debug("TiFlash peer count > pd peer count, maybe doing balance.", - zap.Int64("tableID", tableID), zap.Int("tiflashPeerCount", tiflashPeerCount), zap.Int("regionCount", regionCount), zap.Uint64("replicaCount", replicaCount)) - progress = 1 - } - if progress < 1 { - logutil.BgLogger().Debug("TiFlash replica progress < 1.", - zap.Int64("tableID", tableID), zap.Int("tiflashPeerCount", tiflashPeerCount), zap.Int("regionCount", regionCount), zap.Uint64("replicaCount", replicaCount)) - } - return progress, nil -} - -func encodeRule(c tikv.Codec, rule *pd.Rule) { - rule.StartKey, rule.EndKey = c.EncodeRange(rule.StartKey, rule.EndKey) - rule.ID = encodeRuleID(c, rule.ID) -} - -// encodeRule encodes the rule ID by the following way: -// 1. if the codec is in API V1 then the rule ID is not encoded, should be like "table--r". -// 2. if the codec is in API V2 then the rule ID is encoded, -// should be like "keyspace--table--r". -func encodeRuleID(c tikv.Codec, ruleID string) string { - if c.GetAPIVersion() == kvrpcpb.APIVersion_V2 { - return fmt.Sprintf("keyspace-%v-%s", c.GetKeyspaceID(), ruleID) - } - return ruleID -} - -// CalculateTiFlashProgress calculates TiFlash replica progress. 
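For reference, the ratio computed in calculateTiFlashProgress above is the number of synced TiFlash peers divided by regionCount * replicaCount, capped at 1 because PD rebalancing can temporarily report more peers than expected. A small self-contained sketch of that arithmetic (the function name progressOf is an illustrative assumption):

package example

// progressOf mirrors the calculation in calculateTiFlashProgress: peers
// divided by (regions * replicas), with values above 1 clamped to 1.
func progressOf(tiflashPeerCount, regionCount int, replicaCount uint64) float64 {
	progress := float64(tiflashPeerCount) / float64(regionCount*int(replicaCount))
	if progress > 1 {
		progress = 1
	}
	return progress
}

For example, 3 synced peers over 2 regions with 2 replicas gives 3/4 = 0.75, while 4 or more peers reports 1.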
-func (m *TiFlashReplicaManagerCtx) CalculateTiFlashProgress(tableID int64, replicaCount uint64, tiFlashStores map[int64]pd.StoreInfo) (float64, error) { - return calculateTiFlashProgress(m.codec.GetKeyspaceID(), tableID, replicaCount, tiFlashStores) -} - -// UpdateTiFlashProgressCache updates tiflashProgressCache -func (m *TiFlashReplicaManagerCtx) UpdateTiFlashProgressCache(tableID int64, progress float64) { - m.Lock() - defer m.Unlock() - m.tiflashProgressCache[tableID] = progress -} - -// GetTiFlashProgressFromCache gets tiflash replica progress from tiflashProgressCache -func (m *TiFlashReplicaManagerCtx) GetTiFlashProgressFromCache(tableID int64) (float64, bool) { - m.RLock() - defer m.RUnlock() - progress, ok := m.tiflashProgressCache[tableID] - return progress, ok -} - -// DeleteTiFlashProgressFromCache delete tiflash replica progress from tiflashProgressCache -func (m *TiFlashReplicaManagerCtx) DeleteTiFlashProgressFromCache(tableID int64) { - m.Lock() - defer m.Unlock() - delete(m.tiflashProgressCache, tableID) -} - -// CleanTiFlashProgressCache clean progress cache -func (m *TiFlashReplicaManagerCtx) CleanTiFlashProgressCache() { - m.Lock() - defer m.Unlock() - m.tiflashProgressCache = make(map[int64]float64) -} - -// SetTiFlashGroupConfig sets the tiflash's rule group config -func (m *TiFlashReplicaManagerCtx) SetTiFlashGroupConfig(ctx context.Context) error { - groupConfig, err := m.pdHTTPCli.GetPlacementRuleGroupByID(ctx, placement.TiFlashRuleGroupID) - if err != nil { - return errors.Trace(err) - } - if groupConfig != nil && groupConfig.Index == placement.RuleIndexTiFlash && !groupConfig.Override { - return nil - } - groupConfig = &pd.RuleGroup{ - ID: placement.TiFlashRuleGroupID, - Index: placement.RuleIndexTiFlash, - Override: false, - } - return m.pdHTTPCli.SetPlacementRuleGroup(ctx, groupConfig) -} - -// SetPlacementRule is a helper function to set placement rule. -func (m *TiFlashReplicaManagerCtx) SetPlacementRule(ctx context.Context, rule *pd.Rule) error { - encodeRule(m.codec, rule) - return m.doSetPlacementRule(ctx, rule) -} - -func (m *TiFlashReplicaManagerCtx) doSetPlacementRule(ctx context.Context, rule *pd.Rule) error { - if err := m.SetTiFlashGroupConfig(ctx); err != nil { - return err - } - if rule.Count == 0 { - return m.pdHTTPCli.DeletePlacementRule(ctx, rule.GroupID, rule.ID) - } - return m.pdHTTPCli.SetPlacementRule(ctx, rule) -} - -// SetPlacementRuleBatch is a helper function to set a batch of placement rules. -func (m *TiFlashReplicaManagerCtx) SetPlacementRuleBatch(ctx context.Context, rules []*pd.Rule) error { - r := make([]*pd.Rule, 0, len(rules)) - for _, rule := range rules { - encodeRule(m.codec, rule) - r = append(r, rule) - } - return m.doSetPlacementRuleBatch(ctx, r) -} - -func (m *TiFlashReplicaManagerCtx) doSetPlacementRuleBatch(ctx context.Context, rules []*pd.Rule) error { - if err := m.SetTiFlashGroupConfig(ctx); err != nil { - return err - } - ruleOps := make([]*pd.RuleOp, 0, len(rules)) - for i, r := range rules { - if r.Count == 0 { - ruleOps = append(ruleOps, &pd.RuleOp{ - Rule: rules[i], - Action: pd.RuleOpDel, - }) - } else { - ruleOps = append(ruleOps, &pd.RuleOp{ - Rule: rules[i], - Action: pd.RuleOpAdd, - }) - } - } - return m.pdHTTPCli.SetPlacementRuleInBatch(ctx, ruleOps) -} - -// DeletePlacementRule is to delete placement rule for certain group. 
-func (m *TiFlashReplicaManagerCtx) DeletePlacementRule(ctx context.Context, group string, ruleID string) error { - ruleID = encodeRuleID(m.codec, ruleID) - return m.pdHTTPCli.DeletePlacementRule(ctx, group, ruleID) -} - -// GetGroupRules to get all placement rule in a certain group. -func (m *TiFlashReplicaManagerCtx) GetGroupRules(ctx context.Context, group string) ([]*pd.Rule, error) { - return m.pdHTTPCli.GetPlacementRulesByGroup(ctx, group) -} - -// PostAccelerateScheduleBatch sends `regions/batch-accelerate-schedule` request. -func (m *TiFlashReplicaManagerCtx) PostAccelerateScheduleBatch(ctx context.Context, tableIDs []int64) error { - if len(tableIDs) == 0 { - return nil - } - input := make([]*pd.KeyRange, 0, len(tableIDs)) - for _, tableID := range tableIDs { - startKey := tablecodec.GenTableRecordPrefix(tableID) - endKey := tablecodec.EncodeTablePrefix(tableID + 1) - startKey, endKey = m.codec.EncodeRegionRange(startKey, endKey) - input = append(input, pd.NewKeyRange(startKey, endKey)) - } - return m.pdHTTPCli.AccelerateScheduleInBatch(ctx, input) -} - -// GetRegionCountFromPD is a helper function calling `/stats/region`. -func (m *TiFlashReplicaManagerCtx) GetRegionCountFromPD(ctx context.Context, tableID int64, regionCount *int) error { - startKey := tablecodec.GenTableRecordPrefix(tableID) - endKey := tablecodec.EncodeTablePrefix(tableID + 1) - startKey, endKey = m.codec.EncodeRegionRange(startKey, endKey) - stats, err := m.pdHTTPCli.GetRegionStatusByKeyRange(ctx, pd.NewKeyRange(startKey, endKey), true) - if err != nil { - return err - } - *regionCount = stats.Count - return nil -} - -// GetStoresStat gets the TiKV store information by accessing PD's api. -func (m *TiFlashReplicaManagerCtx) GetStoresStat(ctx context.Context) (*pd.StoresInfo, error) { - return m.pdHTTPCli.GetStores(ctx) -} - -type mockTiFlashReplicaManagerCtx struct { - sync.RWMutex - // Set to nil if there is no need to set up a mock TiFlash server. - // Otherwise use NewMockTiFlash to create one. - tiflash *MockTiFlash - tiflashProgressCache map[int64]float64 -} - -func makeBaseRule() pd.Rule { - return pd.Rule{ - GroupID: placement.TiFlashRuleGroupID, - ID: "", - Index: placement.RuleIndexTiFlash, - Override: false, - Role: pd.Learner, - Count: 2, - LabelConstraints: []pd.LabelConstraint{ - { - Key: "engine", - Op: pd.In, - Values: []string{"tiflash"}, - }, - }, - } -} - -// MakeNewRule creates a pd rule for TiFlash. -func MakeNewRule(id int64, count uint64, locationLabels []string) pd.Rule { - ruleID := MakeRuleID(id) - startKey := tablecodec.GenTableRecordPrefix(id) - endKey := tablecodec.EncodeTablePrefix(id + 1) - - ruleNew := makeBaseRule() - ruleNew.ID = ruleID - ruleNew.StartKey = startKey - ruleNew.EndKey = endKey - ruleNew.Count = int(count) - ruleNew.LocationLabels = locationLabels - - return ruleNew -} - -// MakeRuleID creates a rule ID for TiFlash with given TableID. -// This interface is exported for the module who wants to manipulate the TiFlash rule. -// The rule ID is in the format of "table--r". -// NOTE: PLEASE DO NOT write the rule ID manually, use this interface instead. 
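As a companion to encodeRuleID above and MakeRuleID just below: an unencoded TiFlash placement rule ID has the form table-<tableID>-r, and under API V2 the keyspace is prepended, giving keyspace-<keyspaceID>-table-<tableID>-r. A short sketch of the two forms (the function names and sample IDs are illustrative assumptions):

package example

import "fmt"

// tiflashRuleID reproduces the "table-<id>-r" convention used by MakeRuleID.
func tiflashRuleID(tableID int64) string {
	return fmt.Sprintf("table-%v-r", tableID)
}

// withKeyspace reproduces the API V2 prefix added by encodeRuleID.
func withKeyspace(keyspaceID uint32, ruleID string) string {
	return fmt.Sprintf("keyspace-%v-%s", keyspaceID, ruleID)
}

So tiflashRuleID(110) yields "table-110-r", and withKeyspace(3, tiflashRuleID(110)) yields "keyspace-3-table-110-r".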
-func MakeRuleID(id int64) string { - return fmt.Sprintf("table-%v-r", id) -} - -type mockTiFlashTableInfo struct { - Regions []int - Accel bool -} - -func (m *mockTiFlashTableInfo) String() string { - regionStr := "" - for _, s := range m.Regions { - regionStr = regionStr + strconv.Itoa(s) + "\n" - } - if regionStr == "" { - regionStr = "\n" - } - return fmt.Sprintf("%v\n%v", len(m.Regions), regionStr) -} - -// MockTiFlash mocks a TiFlash, with necessary Pd support. -type MockTiFlash struct { - syncutil.Mutex - groupIndex int - StatusAddr string - StatusServer *httptest.Server - SyncStatus map[int]mockTiFlashTableInfo - StoreInfo map[uint64]pd.MetaStore - GlobalTiFlashPlacementRules map[string]*pd.Rule - PdEnabled bool - TiflashDelay time.Duration - StartTime time.Time - NotAvailable bool - NetworkError bool -} - -func (tiflash *MockTiFlash) setUpMockTiFlashHTTPServer() { - tiflash.Lock() - defer tiflash.Unlock() - // mock TiFlash http server - router := mux.NewRouter() - server := httptest.NewServer(router) - // mock store stats stat - statusAddr := strings.TrimPrefix(server.URL, "http://") - statusAddrVec := strings.Split(statusAddr, ":") - statusPort, _ := strconv.Atoi(statusAddrVec[1]) - router.HandleFunc("/tiflash/sync-status/keyspace/{keyspaceid:\\d+}/table/{tableid:\\d+}", func(w http.ResponseWriter, req *http.Request) { - tiflash.Lock() - defer tiflash.Unlock() - if tiflash.NetworkError { - w.WriteHeader(http.StatusNotFound) - return - } - params := mux.Vars(req) - tableID, err := strconv.Atoi(params["tableid"]) - if err != nil { - w.WriteHeader(http.StatusNotFound) - return - } - table, ok := tiflash.SyncStatus[tableID] - if tiflash.NotAvailable { - // No region is available, so the table is not available. - table.Regions = []int{} - } - if !ok { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("0\n\n")) - return - } - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(table.String())) - }) - router.HandleFunc("/config", func(w http.ResponseWriter, _ *http.Request) { - tiflash.Lock() - defer tiflash.Unlock() - s := fmt.Sprintf("{\n \"engine-store\": {\n \"http_port\": %v\n }\n}", statusPort) - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte(s)) - }) - tiflash.StatusServer = server - tiflash.StatusAddr = statusAddr -} - -// NewMockTiFlash creates a MockTiFlash with a mocked TiFlash server. -func NewMockTiFlash() *MockTiFlash { - tiflash := &MockTiFlash{ - StatusAddr: "", - StatusServer: nil, - SyncStatus: make(map[int]mockTiFlashTableInfo), - StoreInfo: make(map[uint64]pd.MetaStore), - GlobalTiFlashPlacementRules: make(map[string]*pd.Rule), - PdEnabled: true, - TiflashDelay: 0, - StartTime: time.Now(), - NotAvailable: false, - } - tiflash.setUpMockTiFlashHTTPServer() - return tiflash -} - -// HandleSetPlacementRule is mock function for SetTiFlashPlacementRule. 
-func (tiflash *MockTiFlash) HandleSetPlacementRule(rule *pd.Rule) error { - tiflash.Lock() - defer tiflash.Unlock() - tiflash.groupIndex = placement.RuleIndexTiFlash - if !tiflash.PdEnabled { - logutil.BgLogger().Info("pd server is manually disabled, just quit") - return nil - } - - if rule.Count == 0 { - delete(tiflash.GlobalTiFlashPlacementRules, rule.ID) - } else { - tiflash.GlobalTiFlashPlacementRules[rule.ID] = rule - } - // Pd shall schedule TiFlash, we can mock here - tid := 0 - _, err := fmt.Sscanf(rule.ID, "table-%d-r", &tid) - if err != nil { - return errors.New("Can't parse rule") - } - // Set up mock TiFlash replica - f := func() { - if z, ok := tiflash.SyncStatus[tid]; ok { - z.Regions = []int{1} - tiflash.SyncStatus[tid] = z - } else { - tiflash.SyncStatus[tid] = mockTiFlashTableInfo{ - Regions: []int{1}, - Accel: false, - } - } - } - if tiflash.TiflashDelay > 0 { - go func() { - time.Sleep(tiflash.TiflashDelay) - logutil.BgLogger().Warn("TiFlash replica is available after delay", zap.Duration("duration", tiflash.TiflashDelay)) - f() - }() - } else { - f() - } - return nil -} - -// HandleSetPlacementRuleBatch is mock function for batch SetTiFlashPlacementRule. -func (tiflash *MockTiFlash) HandleSetPlacementRuleBatch(rules []*pd.Rule) error { - for _, r := range rules { - if err := tiflash.HandleSetPlacementRule(r); err != nil { - return err - } - } - return nil -} - -// ResetSyncStatus is mock function for reset sync status. -func (tiflash *MockTiFlash) ResetSyncStatus(tableID int, canAvailable bool) { - tiflash.Lock() - defer tiflash.Unlock() - if canAvailable { - if z, ok := tiflash.SyncStatus[tableID]; ok { - z.Regions = []int{1} - tiflash.SyncStatus[tableID] = z - } else { - tiflash.SyncStatus[tableID] = mockTiFlashTableInfo{ - Regions: []int{1}, - Accel: false, - } - } - } else { - delete(tiflash.SyncStatus, tableID) - } -} - -// HandleDeletePlacementRule is mock function for DeleteTiFlashPlacementRule. -func (tiflash *MockTiFlash) HandleDeletePlacementRule(_ string, ruleID string) { - tiflash.Lock() - defer tiflash.Unlock() - delete(tiflash.GlobalTiFlashPlacementRules, ruleID) -} - -// HandleGetGroupRules is mock function for GetTiFlashGroupRules. -func (tiflash *MockTiFlash) HandleGetGroupRules(_ string) ([]*pd.Rule, error) { - tiflash.Lock() - defer tiflash.Unlock() - var result = make([]*pd.Rule, 0) - for _, item := range tiflash.GlobalTiFlashPlacementRules { - result = append(result, item) - } - return result, nil -} - -// HandlePostAccelerateSchedule is mock function for PostAccelerateSchedule -func (tiflash *MockTiFlash) HandlePostAccelerateSchedule(endKey string) error { - tiflash.Lock() - defer tiflash.Unlock() - tableID := helper.GetTiFlashTableIDFromEndKey(endKey) - - table, ok := tiflash.SyncStatus[int(tableID)] - if ok { - table.Accel = true - tiflash.SyncStatus[int(tableID)] = table - } else { - tiflash.SyncStatus[int(tableID)] = mockTiFlashTableInfo{ - Regions: []int{}, - Accel: true, - } - } - return nil -} - -// HandleGetPDRegionRecordStats is mock function for GetRegionCountFromPD. -// It currently always returns 1 Region for convenience. -func (*MockTiFlash) HandleGetPDRegionRecordStats(int64) pd.RegionStats { - return pd.RegionStats{ - Count: 1, - } -} - -// AddStore is mock function for adding store info into MockTiFlash. 
-func (tiflash *MockTiFlash) AddStore(storeID uint64, address string) { - tiflash.StoreInfo[storeID] = pd.MetaStore{ - ID: int64(storeID), - Address: address, - State: 0, - StateName: "Up", - Version: "4.0.0-alpha", - StatusAddress: tiflash.StatusAddr, - GitHash: "mock-tikv-githash", - StartTimestamp: tiflash.StartTime.Unix(), - Labels: []pd.StoreLabel{{ - Key: "engine", - Value: "tiflash", - }}, - } -} - -// HandleGetStoresStat is mock function for GetStoresStat. -// It returns address of our mocked TiFlash server. -func (tiflash *MockTiFlash) HandleGetStoresStat() *pd.StoresInfo { - tiflash.Lock() - defer tiflash.Unlock() - if len(tiflash.StoreInfo) == 0 { - // default Store - return &pd.StoresInfo{ - Count: 1, - Stores: []pd.StoreInfo{ - { - Store: pd.MetaStore{ - ID: 1, - Address: "127.0.0.1:3930", - State: 0, - StateName: "Up", - Version: "4.0.0-alpha", - StatusAddress: tiflash.StatusAddr, - GitHash: "mock-tikv-githash", - StartTimestamp: tiflash.StartTime.Unix(), - Labels: []pd.StoreLabel{{ - Key: "engine", - Value: "tiflash", - }}, - }, - }, - }, - } - } - stores := make([]pd.StoreInfo, 0, len(tiflash.StoreInfo)) - for _, storeInfo := range tiflash.StoreInfo { - stores = append(stores, pd.StoreInfo{Store: storeInfo, Status: pd.StoreStatus{}}) - } - return &pd.StoresInfo{ - Count: len(tiflash.StoreInfo), - Stores: stores, - } -} - -// SetRuleGroupIndex sets the group index of tiflash -func (tiflash *MockTiFlash) SetRuleGroupIndex(groupIndex int) { - tiflash.Lock() - defer tiflash.Unlock() - tiflash.groupIndex = groupIndex -} - -// GetRuleGroupIndex gets the group index of tiflash -func (tiflash *MockTiFlash) GetRuleGroupIndex() int { - tiflash.Lock() - defer tiflash.Unlock() - return tiflash.groupIndex -} - -// Compare supposed rule, and we actually get from TableInfo -func isRuleMatch(rule pd.Rule, startKey []byte, endKey []byte, count int, labels []string) bool { - // Compute startKey - if !(bytes.Equal(rule.StartKey, startKey) && bytes.Equal(rule.EndKey, endKey)) { - return false - } - ok := false - for _, c := range rule.LabelConstraints { - if c.Key == "engine" && len(c.Values) == 1 && c.Values[0] == "tiflash" && c.Op == pd.In { - ok = true - break - } - } - if !ok { - return false - } - - if len(rule.LocationLabels) != len(labels) { - return false - } - for i, lb := range labels { - if lb != rule.LocationLabels[i] { - return false - } - } - if rule.Count != count { - return false - } - if rule.Role != pd.Learner { - return false - } - return true -} - -// CheckPlacementRule find if a given rule precisely matches already set rules. -func (tiflash *MockTiFlash) CheckPlacementRule(rule pd.Rule) bool { - tiflash.Lock() - defer tiflash.Unlock() - for _, r := range tiflash.GlobalTiFlashPlacementRules { - if isRuleMatch(rule, r.StartKey, r.EndKey, r.Count, r.LocationLabels) { - return true - } - } - return false -} - -// GetPlacementRule find a rule by name. -func (tiflash *MockTiFlash) GetPlacementRule(ruleName string) (*pd.Rule, bool) { - tiflash.Lock() - defer tiflash.Unlock() - if r, ok := tiflash.GlobalTiFlashPlacementRules[ruleName]; ok { - p := r - return p, ok - } - return nil, false -} - -// CleanPlacementRules cleans all placement rules. -func (tiflash *MockTiFlash) CleanPlacementRules() { - tiflash.Lock() - defer tiflash.Unlock() - tiflash.GlobalTiFlashPlacementRules = make(map[string]*pd.Rule) -} - -// PlacementRulesLen gets length of all currently set placement rules. 
-func (tiflash *MockTiFlash) PlacementRulesLen() int { - tiflash.Lock() - defer tiflash.Unlock() - return len(tiflash.GlobalTiFlashPlacementRules) -} - -// GetTableSyncStatus returns table sync status by given tableID. -func (tiflash *MockTiFlash) GetTableSyncStatus(tableID int) (*mockTiFlashTableInfo, bool) { - tiflash.Lock() - defer tiflash.Unlock() - if r, ok := tiflash.SyncStatus[tableID]; ok { - p := r - return &p, ok - } - return nil, false -} - -// PdSwitch controls if pd is enabled. -func (tiflash *MockTiFlash) PdSwitch(enabled bool) { - tiflash.Lock() - defer tiflash.Unlock() - tiflash.PdEnabled = enabled -} - -// SetNetworkError sets network error state. -func (tiflash *MockTiFlash) SetNetworkError(e bool) { - tiflash.Lock() - defer tiflash.Unlock() - tiflash.NetworkError = e -} - -// CalculateTiFlashProgress return truncated string to avoid float64 comparison. -func (*mockTiFlashReplicaManagerCtx) CalculateTiFlashProgress(tableID int64, replicaCount uint64, tiFlashStores map[int64]pd.StoreInfo) (float64, error) { - return calculateTiFlashProgress(tikv.NullspaceID, tableID, replicaCount, tiFlashStores) -} - -// UpdateTiFlashProgressCache updates tiflashProgressCache -func (m *mockTiFlashReplicaManagerCtx) UpdateTiFlashProgressCache(tableID int64, progress float64) { - m.Lock() - defer m.Unlock() - m.tiflashProgressCache[tableID] = progress -} - -// GetTiFlashProgressFromCache gets tiflash replica progress from tiflashProgressCache -func (m *mockTiFlashReplicaManagerCtx) GetTiFlashProgressFromCache(tableID int64) (float64, bool) { - m.RLock() - defer m.RUnlock() - progress, ok := m.tiflashProgressCache[tableID] - return progress, ok -} - -// DeleteTiFlashProgressFromCache delete tiflash replica progress from tiflashProgressCache -func (m *mockTiFlashReplicaManagerCtx) DeleteTiFlashProgressFromCache(tableID int64) { - m.Lock() - defer m.Unlock() - delete(m.tiflashProgressCache, tableID) -} - -// CleanTiFlashProgressCache clean progress cache -func (m *mockTiFlashReplicaManagerCtx) CleanTiFlashProgressCache() { - m.Lock() - defer m.Unlock() - m.tiflashProgressCache = make(map[int64]float64) -} - -// SetMockTiFlash is set a mock TiFlash server. -func (m *mockTiFlashReplicaManagerCtx) SetMockTiFlash(tiflash *MockTiFlash) { - m.Lock() - defer m.Unlock() - m.tiflash = tiflash -} - -// SetTiFlashGroupConfig sets the tiflash's rule group config -func (m *mockTiFlashReplicaManagerCtx) SetTiFlashGroupConfig(_ context.Context) error { - m.Lock() - defer m.Unlock() - if m.tiflash == nil { - return nil - } - m.tiflash.SetRuleGroupIndex(placement.RuleIndexTiFlash) - return nil -} - -// SetPlacementRule is a helper function to set placement rule. -func (m *mockTiFlashReplicaManagerCtx) SetPlacementRule(_ context.Context, rule *pd.Rule) error { - m.Lock() - defer m.Unlock() - if m.tiflash == nil { - return nil - } - return m.tiflash.HandleSetPlacementRule(rule) -} - -// SetPlacementRuleBatch is a helper function to set a batch of placement rules. -func (m *mockTiFlashReplicaManagerCtx) SetPlacementRuleBatch(_ context.Context, rules []*pd.Rule) error { - m.Lock() - defer m.Unlock() - if m.tiflash == nil { - return nil - } - return m.tiflash.HandleSetPlacementRuleBatch(rules) -} - -// DeletePlacementRule is to delete placement rule for certain group. 
-func (m *mockTiFlashReplicaManagerCtx) DeletePlacementRule(_ context.Context, group string, ruleID string) error { - m.Lock() - defer m.Unlock() - if m.tiflash == nil { - return nil - } - logutil.BgLogger().Info("Remove TiFlash rule", zap.String("ruleID", ruleID)) - m.tiflash.HandleDeletePlacementRule(group, ruleID) - return nil -} - -// GetGroupRules to get all placement rule in a certain group. -func (m *mockTiFlashReplicaManagerCtx) GetGroupRules(_ context.Context, group string) ([]*pd.Rule, error) { - m.Lock() - defer m.Unlock() - if m.tiflash == nil { - return []*pd.Rule{}, nil - } - return m.tiflash.HandleGetGroupRules(group) -} - -// PostAccelerateScheduleBatch sends `regions/batch-accelerate-schedule` request. -func (m *mockTiFlashReplicaManagerCtx) PostAccelerateScheduleBatch(_ context.Context, tableIDs []int64) error { - m.Lock() - defer m.Unlock() - if m.tiflash == nil { - return nil - } - for _, tableID := range tableIDs { - endKey := tablecodec.EncodeTablePrefix(tableID + 1) - endKey = codec.EncodeBytes([]byte{}, endKey) - if err := m.tiflash.HandlePostAccelerateSchedule(hex.EncodeToString(endKey)); err != nil { - return err - } - } - return nil -} - -// GetRegionCountFromPD is a helper function calling `/stats/region`. -func (m *mockTiFlashReplicaManagerCtx) GetRegionCountFromPD(_ context.Context, tableID int64, regionCount *int) error { - m.Lock() - defer m.Unlock() - if m.tiflash == nil { - return nil - } - stats := m.tiflash.HandleGetPDRegionRecordStats(tableID) - *regionCount = stats.Count - return nil -} - -// GetStoresStat gets the TiKV store information by accessing PD's api. -func (m *mockTiFlashReplicaManagerCtx) GetStoresStat(_ context.Context) (*pd.StoresInfo, error) { - m.Lock() - defer m.Unlock() - if m.tiflash == nil { - return nil, &MockTiFlashError{"MockTiFlash is not accessible"} - } - return m.tiflash.HandleGetStoresStat(), nil -} - -// Close is called to close mockTiFlashReplicaManager. -func (m *mockTiFlashReplicaManagerCtx) Close(_ context.Context) { - m.Lock() - defer m.Unlock() - if m.tiflash == nil { - return - } - if m.tiflash.StatusServer != nil { - m.tiflash.StatusServer.Close() - } -} - -// MockTiFlashError represents MockTiFlash error -type MockTiFlashError struct { - Message string -} - -func (me *MockTiFlashError) Error() string { - return me.Message -} diff --git a/pkg/domain/plan_replayer_dump.go b/pkg/domain/plan_replayer_dump.go index 55a777ab7d556..1a5b95e4eb0c2 100644 --- a/pkg/domain/plan_replayer_dump.go +++ b/pkg/domain/plan_replayer_dump.go @@ -308,11 +308,11 @@ func DumpPlanReplayerInfo(ctx context.Context, sctx sessionctx.Context, errMsgs = append(errMsgs, fallbackMsg) } } else { - if val, _err_ := failpoint.Eval(_curpkg_("shouldDumpStats")); _err_ == nil { + failpoint.Inject("shouldDumpStats", func(val failpoint.Value) { if val.(bool) { panic("shouldDumpStats") } - } + }) } } else { // Dump stats diff --git a/pkg/domain/plan_replayer_dump.go__failpoint_stash__ b/pkg/domain/plan_replayer_dump.go__failpoint_stash__ deleted file mode 100644 index 1a5b95e4eb0c2..0000000000000 --- a/pkg/domain/plan_replayer_dump.go__failpoint_stash__ +++ /dev/null @@ -1,923 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package domain - -import ( - "archive/zip" - "context" - "encoding/json" - "fmt" - "io" - "strconv" - "strings" - - "github.com/BurntSushi/toml" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/bindinfo" - "github.com/pingcap/tidb/pkg/config" - domain_metrics "github.com/pingcap/tidb/pkg/domain/metrics" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/statistics" - "github.com/pingcap/tidb/pkg/statistics/handle/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/printer" - "github.com/pingcap/tidb/pkg/util/sqlexec" - "go.uber.org/zap" -) - -const ( - // PlanReplayerSQLMetaFile indicates sql meta path for plan replayer - PlanReplayerSQLMetaFile = "sql_meta.toml" - // PlanReplayerConfigFile indicates config file path for plan replayer - PlanReplayerConfigFile = "config.toml" - // PlanReplayerMetaFile meta file path for plan replayer - PlanReplayerMetaFile = "meta.txt" - // PlanReplayerVariablesFile indicates for session variables file path for plan replayer - PlanReplayerVariablesFile = "variables.toml" - // PlanReplayerTiFlashReplicasFile indicates for table tiflash replica file path for plan replayer - PlanReplayerTiFlashReplicasFile = "table_tiflash_replica.txt" - // PlanReplayerSessionBindingFile indicates session binding file path for plan replayer - PlanReplayerSessionBindingFile = "session_bindings.sql" - // PlanReplayerGlobalBindingFile indicates global binding file path for plan replayer - PlanReplayerGlobalBindingFile = "global_bindings.sql" - // PlanReplayerSchemaMetaFile indicates the schema meta - PlanReplayerSchemaMetaFile = "schema_meta.txt" - // PlanReplayerErrorMessageFile is the file name for error messages - PlanReplayerErrorMessageFile = "errors.txt" -) - -const ( - // PlanReplayerSQLMetaStartTS indicates the startTS in plan replayer sql meta - PlanReplayerSQLMetaStartTS = "startTS" - // PlanReplayerTaskMetaIsCapture indicates whether this task is capture task - PlanReplayerTaskMetaIsCapture = "isCapture" - // PlanReplayerTaskMetaIsContinues indicates whether this task is continues task - PlanReplayerTaskMetaIsContinues = "isContinues" - // PlanReplayerTaskMetaSQLDigest indicates the sql digest of this task - PlanReplayerTaskMetaSQLDigest = "sqlDigest" - // PlanReplayerTaskMetaPlanDigest indicates the plan digest of this task - PlanReplayerTaskMetaPlanDigest = "planDigest" - // PlanReplayerTaskEnableHistoricalStats indicates whether the task is using historical stats - PlanReplayerTaskEnableHistoricalStats = "enableHistoricalStats" - // PlanReplayerHistoricalStatsTS indicates the expected TS of the historical stats if it's specified by the user. 
- PlanReplayerHistoricalStatsTS = "historicalStatsTS" -) - -type tableNamePair struct { - DBName string - TableName string - IsView bool -} - -type tableNameExtractor struct { - ctx context.Context - executor sqlexec.RestrictedSQLExecutor - is infoschema.InfoSchema - curDB model.CIStr - names map[tableNamePair]struct{} - cteNames map[string]struct{} - err error -} - -func (tne *tableNameExtractor) getTablesAndViews() map[tableNamePair]struct{} { - r := make(map[tableNamePair]struct{}) - for tablePair := range tne.names { - if tablePair.IsView { - r[tablePair] = struct{}{} - continue - } - // remove cte in table names - _, ok := tne.cteNames[tablePair.TableName] - if !ok { - r[tablePair] = struct{}{} - } - } - return r -} - -func (*tableNameExtractor) Enter(in ast.Node) (ast.Node, bool) { - if _, ok := in.(*ast.TableName); ok { - return in, true - } - return in, false -} - -func (tne *tableNameExtractor) Leave(in ast.Node) (ast.Node, bool) { - if tne.err != nil { - return in, true - } - if t, ok := in.(*ast.TableName); ok { - isView, err := tne.handleIsView(t) - if err != nil { - tne.err = err - return in, true - } - if tne.is.TableExists(t.Schema, t.Name) { - tp := tableNamePair{DBName: t.Schema.L, TableName: t.Name.L, IsView: isView} - if tp.DBName == "" { - tp.DBName = tne.curDB.L - } - tne.names[tp] = struct{}{} - } - } else if s, ok := in.(*ast.SelectStmt); ok { - if s.With != nil && len(s.With.CTEs) > 0 { - for _, cte := range s.With.CTEs { - tne.cteNames[cte.Name.L] = struct{}{} - } - } - } - return in, true -} - -func (tne *tableNameExtractor) handleIsView(t *ast.TableName) (bool, error) { - schema := t.Schema - if schema.L == "" { - schema = tne.curDB - } - table := t.Name - isView := infoschema.TableIsView(tne.is, schema, table) - if !isView { - return false, nil - } - viewTbl, err := tne.is.TableByName(context.Background(), schema, table) - if err != nil { - return false, err - } - sql := viewTbl.Meta().View.SelectStmt - node, err := tne.executor.ParseWithParams(tne.ctx, sql) - if err != nil { - return false, err - } - node.Accept(tne) - return true, nil -} - -// DumpPlanReplayerInfo will dump the information about sqls. -// The files will be organized into the following format: -/* - |-sql_meta.toml - |-meta.txt - |-schema - | |-schema_meta.txt - | |-db1.table1.schema.txt - | |-db2.table2.schema.txt - | |-.... - |-view - | |-db1.view1.view.txt - | |-db2.view2.view.txt - | |-.... - |-stats - | |-stats1.json - | |-stats2.json - | |-.... - |-statsMem - | |-stats1.txt - | |-stats2.txt - | |-.... - |-config.toml - |-table_tiflash_replica.txt - |-variables.toml - |-bindings.sql - |-sql - | |-sql1.sql - | |-sql2.sql - | |-.... 
- |-explain.txt -*/ -func DumpPlanReplayerInfo(ctx context.Context, sctx sessionctx.Context, - task *PlanReplayerDumpTask) (err error) { - zf := task.Zf - fileName := task.FileName - sessionVars := task.SessionVars - execStmts := task.ExecStmts - zw := zip.NewWriter(zf) - var records []PlanReplayerStatusRecord - var errMsgs []string - sqls := make([]string, 0) - for _, execStmt := range task.ExecStmts { - sqls = append(sqls, execStmt.Text()) - } - if task.IsCapture { - logutil.BgLogger().Info("start to dump plan replayer result", zap.String("category", "plan-replayer-dump"), - zap.String("sql-digest", task.SQLDigest), - zap.String("plan-digest", task.PlanDigest), - zap.Strings("sql", sqls), - zap.Bool("isContinues", task.IsContinuesCapture)) - } else { - logutil.BgLogger().Info("start to dump plan replayer result", zap.String("category", "plan-replayer-dump"), - zap.Strings("sqls", sqls)) - } - defer func() { - errMsg := "" - if err != nil { - if task.IsCapture { - logutil.BgLogger().Info("dump file failed", zap.String("category", "plan-replayer-dump"), - zap.String("sql-digest", task.SQLDigest), - zap.String("plan-digest", task.PlanDigest), - zap.Strings("sql", sqls), - zap.Bool("isContinues", task.IsContinuesCapture)) - } else { - logutil.BgLogger().Info("start to dump plan replayer result", zap.String("category", "plan-replayer-dump"), - zap.Strings("sqls", sqls)) - } - errMsg = err.Error() - domain_metrics.PlanReplayerDumpTaskFailed.Inc() - } else { - domain_metrics.PlanReplayerDumpTaskSuccess.Inc() - } - err1 := zw.Close() - if err1 != nil { - logutil.BgLogger().Error("Closing zip writer failed", zap.String("category", "plan-replayer-dump"), zap.Error(err), zap.String("filename", fileName)) - errMsg = errMsg + "," + err1.Error() - } - err2 := zf.Close() - if err2 != nil { - logutil.BgLogger().Error("Closing zip file failed", zap.String("category", "plan-replayer-dump"), zap.Error(err), zap.String("filename", fileName)) - errMsg = errMsg + "," + err2.Error() - } - if len(errMsg) > 0 { - for i, record := range records { - record.FailedReason = errMsg - records[i] = record - } - } - insertPlanReplayerStatus(ctx, sctx, records) - }() - // Dump SQLMeta - if err = dumpSQLMeta(zw, task); err != nil { - return err - } - - // Dump config - if err = dumpConfig(zw); err != nil { - return err - } - - // Dump meta - if err = dumpMeta(zw); err != nil { - return err - } - // Retrieve current DB - dbName := model.NewCIStr(sessionVars.CurrentDB) - do := GetDomain(sctx) - - // Retrieve all tables - pairs, err := extractTableNames(ctx, sctx, execStmts, dbName) - if err != nil { - return errors.AddStack(fmt.Errorf("plan replayer: invalid SQL text, err: %v", err)) - } - - // Dump Schema and View - if err = dumpSchemas(sctx, zw, pairs); err != nil { - return err - } - - // Dump tables tiflash replicas - if err = dumpTiFlashReplica(sctx, zw, pairs); err != nil { - return err - } - - // For continuous capture task, we dump stats in storage only if EnableHistoricalStatsForCapture is disabled. 
- // For manual plan replayer dump command or capture, we directly dump stats in storage - if task.IsCapture && task.IsContinuesCapture { - if !variable.EnableHistoricalStatsForCapture.Load() { - // Dump stats - fallbackMsg, err := dumpStats(zw, pairs, do, 0) - if err != nil { - return err - } - if len(fallbackMsg) > 0 { - errMsgs = append(errMsgs, fallbackMsg) - } - } else { - failpoint.Inject("shouldDumpStats", func(val failpoint.Value) { - if val.(bool) { - panic("shouldDumpStats") - } - }) - } - } else { - // Dump stats - fallbackMsg, err := dumpStats(zw, pairs, do, task.HistoricalStatsTS) - if err != nil { - return err - } - if len(fallbackMsg) > 0 { - errMsgs = append(errMsgs, fallbackMsg) - } - } - - if err = dumpStatsMemStatus(zw, pairs, do); err != nil { - return err - } - - // Dump variables - if err = dumpVariables(sctx, sessionVars, zw); err != nil { - return err - } - - // Dump sql - if err = dumpSQLs(execStmts, zw); err != nil { - return err - } - - // Dump session bindings - if len(task.SessionBindings) > 0 { - if err = dumpSessionBindRecords(task.SessionBindings, zw); err != nil { - return err - } - } else { - if err = dumpSessionBindings(sctx, zw); err != nil { - return err - } - } - - // Dump global bindings - if err = dumpGlobalBindings(sctx, zw); err != nil { - return err - } - - if len(task.EncodedPlan) > 0 { - records = generateRecords(task) - if err = dumpEncodedPlan(sctx, zw, task.EncodedPlan); err != nil { - return err - } - } else { - // Dump explain - if err = dumpPlanReplayerExplain(sctx, zw, task, &records); err != nil { - return err - } - } - - if task.DebugTrace != nil { - if err = dumpDebugTrace(zw, task.DebugTrace); err != nil { - return err - } - } - - if len(errMsgs) > 0 { - if err = dumpErrorMsgs(zw, errMsgs); err != nil { - return err - } - } - return nil -} - -func generateRecords(task *PlanReplayerDumpTask) []PlanReplayerStatusRecord { - records := make([]PlanReplayerStatusRecord, 0) - if len(task.ExecStmts) > 0 { - for _, execStmt := range task.ExecStmts { - records = append(records, PlanReplayerStatusRecord{ - SQLDigest: task.SQLDigest, - PlanDigest: task.PlanDigest, - OriginSQL: execStmt.Text(), - Token: task.FileName, - }) - } - } - return records -} - -func dumpSQLMeta(zw *zip.Writer, task *PlanReplayerDumpTask) error { - cf, err := zw.Create(PlanReplayerSQLMetaFile) - if err != nil { - return errors.AddStack(err) - } - varMap := make(map[string]string) - varMap[PlanReplayerSQLMetaStartTS] = strconv.FormatUint(task.StartTS, 10) - varMap[PlanReplayerTaskMetaIsCapture] = strconv.FormatBool(task.IsCapture) - varMap[PlanReplayerTaskMetaIsContinues] = strconv.FormatBool(task.IsContinuesCapture) - varMap[PlanReplayerTaskMetaSQLDigest] = task.SQLDigest - varMap[PlanReplayerTaskMetaPlanDigest] = task.PlanDigest - varMap[PlanReplayerTaskEnableHistoricalStats] = strconv.FormatBool(variable.EnableHistoricalStatsForCapture.Load()) - if task.HistoricalStatsTS > 0 { - varMap[PlanReplayerHistoricalStatsTS] = strconv.FormatUint(task.HistoricalStatsTS, 10) - } - if err := toml.NewEncoder(cf).Encode(varMap); err != nil { - return errors.AddStack(err) - } - return nil -} - -func dumpConfig(zw *zip.Writer) error { - cf, err := zw.Create(PlanReplayerConfigFile) - if err != nil { - return errors.AddStack(err) - } - if err := toml.NewEncoder(cf).Encode(config.GetGlobalConfig()); err != nil { - return errors.AddStack(err) - } - return nil -} - -func dumpMeta(zw *zip.Writer) error { - mt, err := zw.Create(PlanReplayerMetaFile) - if err != nil { - return 
errors.AddStack(err) - } - _, err = mt.Write([]byte(printer.GetTiDBInfo())) - if err != nil { - return errors.AddStack(err) - } - return nil -} - -func dumpTiFlashReplica(ctx sessionctx.Context, zw *zip.Writer, pairs map[tableNamePair]struct{}) error { - bf, err := zw.Create(PlanReplayerTiFlashReplicasFile) - if err != nil { - return errors.AddStack(err) - } - is := GetDomain(ctx).InfoSchema() - for pair := range pairs { - dbName := model.NewCIStr(pair.DBName) - tableName := model.NewCIStr(pair.TableName) - t, err := is.TableByName(context.Background(), dbName, tableName) - if err != nil { - logutil.BgLogger().Warn("failed to find table info", zap.Error(err), - zap.String("dbName", dbName.L), zap.String("tableName", tableName.L)) - continue - } - if t.Meta().TiFlashReplica != nil && t.Meta().TiFlashReplica.Count > 0 { - row := []string{ - pair.DBName, pair.TableName, strconv.FormatUint(t.Meta().TiFlashReplica.Count, 10), - } - fmt.Fprintf(bf, "%s\n", strings.Join(row, "\t")) - } - } - return nil -} - -func dumpSchemas(ctx sessionctx.Context, zw *zip.Writer, pairs map[tableNamePair]struct{}) error { - tables := make(map[tableNamePair]struct{}) - for pair := range pairs { - err := getShowCreateTable(pair, zw, ctx) - if err != nil { - return err - } - if !pair.IsView { - tables[pair] = struct{}{} - } - } - return dumpSchemaMeta(zw, tables) -} - -func dumpSchemaMeta(zw *zip.Writer, tables map[tableNamePair]struct{}) error { - zf, err := zw.Create(fmt.Sprintf("schema/%v", PlanReplayerSchemaMetaFile)) - if err != nil { - return err - } - for table := range tables { - _, err := fmt.Fprintf(zf, "%s.%s;", table.DBName, table.TableName) - if err != nil { - return err - } - } - return nil -} - -func dumpStatsMemStatus(zw *zip.Writer, pairs map[tableNamePair]struct{}, do *Domain) error { - statsHandle := do.StatsHandle() - is := do.InfoSchema() - for pair := range pairs { - if pair.IsView { - continue - } - tbl, err := is.TableByName(context.Background(), model.NewCIStr(pair.DBName), model.NewCIStr(pair.TableName)) - if err != nil { - return err - } - tblStats := statsHandle.GetTableStats(tbl.Meta()) - if tblStats == nil { - continue - } - statsMemFw, err := zw.Create(fmt.Sprintf("statsMem/%v.%v.txt", pair.DBName, pair.TableName)) - if err != nil { - return errors.AddStack(err) - } - fmt.Fprintf(statsMemFw, "[INDEX]\n") - tblStats.ForEachIndexImmutable(func(_ int64, idx *statistics.Index) bool { - fmt.Fprintf(statsMemFw, "%s\n", fmt.Sprintf("%s=%s", idx.Info.Name.String(), idx.StatusToString())) - return false - }) - fmt.Fprintf(statsMemFw, "[COLUMN]\n") - tblStats.ForEachColumnImmutable(func(_ int64, c *statistics.Column) bool { - fmt.Fprintf(statsMemFw, "%s\n", fmt.Sprintf("%s=%s", c.Info.Name.String(), c.StatusToString())) - return false - }) - } - return nil -} - -func dumpStats(zw *zip.Writer, pairs map[tableNamePair]struct{}, do *Domain, historyStatsTS uint64) (string, error) { - allFallBackTbls := make([]string, 0) - for pair := range pairs { - if pair.IsView { - continue - } - jsonTbl, fallBackTbls, err := getStatsForTable(do, pair, historyStatsTS) - if err != nil { - return "", err - } - statsFw, err := zw.Create(fmt.Sprintf("stats/%v.%v.json", pair.DBName, pair.TableName)) - if err != nil { - return "", errors.AddStack(err) - } - data, err := json.Marshal(jsonTbl) - if err != nil { - return "", errors.AddStack(err) - } - _, err = statsFw.Write(data) - if err != nil { - return "", errors.AddStack(err) - } - allFallBackTbls = append(allFallBackTbls, fallBackTbls...) 
- } - var msg string - if len(allFallBackTbls) > 0 { - msg = "Historical stats for " + strings.Join(allFallBackTbls, ", ") + " are unavailable, fallback to latest stats" - } - return msg, nil -} - -func dumpSQLs(execStmts []ast.StmtNode, zw *zip.Writer) error { - for i, stmtExec := range execStmts { - zf, err := zw.Create(fmt.Sprintf("sql/sql%v.sql", i)) - if err != nil { - return err - } - _, err = zf.Write([]byte(stmtExec.Text())) - if err != nil { - return err - } - } - return nil -} - -func dumpVariables(sctx sessionctx.Context, sessionVars *variable.SessionVars, zw *zip.Writer) error { - varMap := make(map[string]string) - for _, v := range variable.GetSysVars() { - if v.IsNoop && !variable.EnableNoopVariables.Load() { - continue - } - if infoschema.SysVarHiddenForSem(sctx, v.Name) { - continue - } - value, err := sessionVars.GetSessionOrGlobalSystemVar(context.Background(), v.Name) - if err != nil { - return errors.Trace(err) - } - varMap[v.Name] = value - } - vf, err := zw.Create(PlanReplayerVariablesFile) - if err != nil { - return errors.AddStack(err) - } - if err := toml.NewEncoder(vf).Encode(varMap); err != nil { - return errors.AddStack(err) - } - return nil -} - -func dumpSessionBindRecords(records []bindinfo.Bindings, zw *zip.Writer) error { - sRows := make([][]string, 0) - for _, bindData := range records { - for _, hint := range bindData { - sRows = append(sRows, []string{ - hint.OriginalSQL, - hint.BindSQL, - hint.Db, - hint.Status, - hint.CreateTime.String(), - hint.UpdateTime.String(), - hint.Charset, - hint.Collation, - hint.Source, - }) - } - } - bf, err := zw.Create(PlanReplayerSessionBindingFile) - if err != nil { - return errors.AddStack(err) - } - for _, row := range sRows { - fmt.Fprintf(bf, "%s\n", strings.Join(row, "\t")) - } - return nil -} - -func dumpSessionBindings(ctx sessionctx.Context, zw *zip.Writer) error { - recordSets, err := ctx.GetSQLExecutor().Execute(context.Background(), "show bindings") - if err != nil { - return err - } - sRows, err := resultSetToStringSlice(context.Background(), recordSets[0], true) - if err != nil { - return err - } - bf, err := zw.Create(PlanReplayerSessionBindingFile) - if err != nil { - return errors.AddStack(err) - } - for _, row := range sRows { - fmt.Fprintf(bf, "%s\n", strings.Join(row, "\t")) - } - if len(recordSets) > 0 { - if err := recordSets[0].Close(); err != nil { - return err - } - } - return nil -} - -func dumpGlobalBindings(ctx sessionctx.Context, zw *zip.Writer) error { - recordSets, err := ctx.GetSQLExecutor().Execute(context.Background(), "show global bindings") - if err != nil { - return err - } - sRows, err := resultSetToStringSlice(context.Background(), recordSets[0], false) - if err != nil { - return err - } - bf, err := zw.Create(PlanReplayerGlobalBindingFile) - if err != nil { - return errors.AddStack(err) - } - for _, row := range sRows { - fmt.Fprintf(bf, "%s\n", strings.Join(row, "\t")) - } - if len(recordSets) > 0 { - if err := recordSets[0].Close(); err != nil { - return err - } - } - return nil -} - -func dumpEncodedPlan(ctx sessionctx.Context, zw *zip.Writer, encodedPlan string) error { - var recordSets []sqlexec.RecordSet - var err error - recordSets, err = ctx.GetSQLExecutor().Execute(context.Background(), fmt.Sprintf("select tidb_decode_plan('%s')", encodedPlan)) - if err != nil { - return err - } - sRows, err := resultSetToStringSlice(context.Background(), recordSets[0], false) - if err != nil { - return err - } - fw, err := zw.Create("explain/sql.txt") - if err != nil { - return 
errors.AddStack(err) - } - for _, row := range sRows { - fmt.Fprintf(fw, "%s\n", strings.Join(row, "\t")) - } - if len(recordSets) > 0 { - if err := recordSets[0].Close(); err != nil { - return err - } - } - return nil -} - -func dumpExplain(ctx sessionctx.Context, zw *zip.Writer, isAnalyze bool, sqls []string, emptyAsNil bool) (debugTraces []any, err error) { - fw, err := zw.Create("explain.txt") - if err != nil { - return nil, errors.AddStack(err) - } - ctx.GetSessionVars().InPlanReplayer = true - defer func() { - ctx.GetSessionVars().InPlanReplayer = false - }() - for i, sql := range sqls { - var recordSets []sqlexec.RecordSet - if isAnalyze { - // Explain analyze - recordSets, err = ctx.GetSQLExecutor().Execute(context.Background(), fmt.Sprintf("explain analyze %s", sql)) - if err != nil { - return nil, err - } - } else { - // Explain - recordSets, err = ctx.GetSQLExecutor().Execute(context.Background(), fmt.Sprintf("explain %s", sql)) - if err != nil { - return nil, err - } - } - debugTrace := ctx.GetSessionVars().StmtCtx.OptimizerDebugTrace - debugTraces = append(debugTraces, debugTrace) - sRows, err := resultSetToStringSlice(context.Background(), recordSets[0], emptyAsNil) - if err != nil { - return nil, err - } - for _, row := range sRows { - fmt.Fprintf(fw, "%s\n", strings.Join(row, "\t")) - } - if len(recordSets) > 0 { - if err := recordSets[0].Close(); err != nil { - return nil, err - } - } - if i < len(sqls)-1 { - fmt.Fprintf(fw, "<--------->\n") - } - } - return -} - -func dumpPlanReplayerExplain(ctx sessionctx.Context, zw *zip.Writer, task *PlanReplayerDumpTask, records *[]PlanReplayerStatusRecord) error { - sqls := make([]string, 0) - for _, execStmt := range task.ExecStmts { - sql := execStmt.Text() - sqls = append(sqls, sql) - *records = append(*records, PlanReplayerStatusRecord{ - OriginSQL: sql, - Token: task.FileName, - }) - } - debugTraces, err := dumpExplain(ctx, zw, task.Analyze, sqls, false) - task.DebugTrace = debugTraces - return err -} - -// extractTableNames extracts table names from the given stmts. 
-func extractTableNames(ctx context.Context, sctx sessionctx.Context, - execStmts []ast.StmtNode, curDB model.CIStr) (map[tableNamePair]struct{}, error) { - tableExtractor := &tableNameExtractor{ - ctx: ctx, - executor: sctx.GetRestrictedSQLExecutor(), - is: GetDomain(sctx).InfoSchema(), - curDB: curDB, - names: make(map[tableNamePair]struct{}), - cteNames: make(map[string]struct{}), - } - for _, execStmt := range execStmts { - execStmt.Accept(tableExtractor) - } - if tableExtractor.err != nil { - return nil, tableExtractor.err - } - return tableExtractor.getTablesAndViews(), nil -} - -func getStatsForTable(do *Domain, pair tableNamePair, historyStatsTS uint64) (*util.JSONTable, []string, error) { - is := do.InfoSchema() - h := do.StatsHandle() - tbl, err := is.TableByName(context.Background(), model.NewCIStr(pair.DBName), model.NewCIStr(pair.TableName)) - if err != nil { - return nil, nil, err - } - if historyStatsTS > 0 { - return h.DumpHistoricalStatsBySnapshot(pair.DBName, tbl.Meta(), historyStatsTS) - } - jt, err := h.DumpStatsToJSON(pair.DBName, tbl.Meta(), nil, true) - return jt, nil, err -} - -func getShowCreateTable(pair tableNamePair, zw *zip.Writer, ctx sessionctx.Context) error { - recordSets, err := ctx.GetSQLExecutor().Execute(context.Background(), fmt.Sprintf("show create table `%v`.`%v`", pair.DBName, pair.TableName)) - if err != nil { - return err - } - sRows, err := resultSetToStringSlice(context.Background(), recordSets[0], false) - if err != nil { - return err - } - var fw io.Writer - if pair.IsView { - fw, err = zw.Create(fmt.Sprintf("view/%v.%v.view.txt", pair.DBName, pair.TableName)) - if err != nil { - return errors.AddStack(err) - } - if len(sRows) == 0 || len(sRows[0]) != 4 { - return fmt.Errorf("plan replayer: get create view %v.%v failed", pair.DBName, pair.TableName) - } - } else { - fw, err = zw.Create(fmt.Sprintf("schema/%v.%v.schema.txt", pair.DBName, pair.TableName)) - if err != nil { - return errors.AddStack(err) - } - if len(sRows) == 0 || len(sRows[0]) != 2 { - return fmt.Errorf("plan replayer: get create table %v.%v failed", pair.DBName, pair.TableName) - } - } - fmt.Fprintf(fw, "create database if not exists `%v`; use `%v`;", pair.DBName, pair.DBName) - fmt.Fprintf(fw, "%s", sRows[0][1]) - if len(recordSets) > 0 { - if err := recordSets[0].Close(); err != nil { - return err - } - } - return nil -} - -func resultSetToStringSlice(ctx context.Context, rs sqlexec.RecordSet, emptyAsNil bool) ([][]string, error) { - rows, err := getRows(ctx, rs) - if err != nil { - return nil, err - } - err = rs.Close() - if err != nil { - return nil, err - } - sRows := make([][]string, len(rows)) - for i, row := range rows { - iRow := make([]string, row.Len()) - for j := 0; j < row.Len(); j++ { - if row.IsNull(j) { - iRow[j] = "" - } else { - d := row.GetDatum(j, &rs.Fields()[j].Column.FieldType) - iRow[j], err = d.ToString() - if err != nil { - return nil, err - } - if len(iRow[j]) < 1 && emptyAsNil { - iRow[j] = "" - } - } - } - sRows[i] = iRow - } - return sRows, nil -} - -func getRows(ctx context.Context, rs sqlexec.RecordSet) ([]chunk.Row, error) { - if rs == nil { - return nil, nil - } - var rows []chunk.Row - req := rs.NewChunk(nil) - // Must reuse `req` for imitating server.(*clientConn).writeChunks - for { - err := rs.Next(ctx, req) - if err != nil { - return nil, err - } - if req.NumRows() == 0 { - break - } - - iter := chunk.NewIterator4Chunk(req.CopyConstruct()) - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - rows = append(rows, row) - } - } 
- return rows, nil -} - -func dumpDebugTrace(zw *zip.Writer, debugTraces []any) error { - for i, trace := range debugTraces { - fw, err := zw.Create(fmt.Sprintf("debug_trace/debug_trace%d.json", i)) - if err != nil { - return errors.AddStack(err) - } - err = dumpOneDebugTrace(fw, trace) - if err != nil { - return errors.AddStack(err) - } - } - return nil -} - -func dumpOneDebugTrace(w io.Writer, debugTrace any) error { - jsonEncoder := json.NewEncoder(w) - // If we do not set this to false, ">", "<", "&"... will be escaped to "\u003c","\u003e", "\u0026"... - jsonEncoder.SetEscapeHTML(false) - return jsonEncoder.Encode(debugTrace) -} - -func dumpErrorMsgs(zw *zip.Writer, msgs []string) error { - mt, err := zw.Create(PlanReplayerErrorMessageFile) - if err != nil { - return errors.AddStack(err) - } - for _, msg := range msgs { - _, err = mt.Write([]byte(msg)) - if err != nil { - return errors.AddStack(err) - } - _, err = mt.Write([]byte{'\n'}) - if err != nil { - return errors.AddStack(err) - } - } - return nil -} diff --git a/pkg/domain/runaway.go b/pkg/domain/runaway.go index f6dfa79d96049..ae5d50724a047 100644 --- a/pkg/domain/runaway.go +++ b/pkg/domain/runaway.go @@ -64,9 +64,9 @@ func (do *Domain) deleteExpiredRows(tableName, colName string, expiredDuration t if !do.DDL().OwnerManager().IsOwner() { return } - if _, _err_ := failpoint.Eval(_curpkg_("FastRunawayGC")); _err_ == nil { + failpoint.Inject("FastRunawayGC", func() { expiredDuration = time.Second * 1 - } + }) expiredTime := time.Now().Add(-expiredDuration) tbCIStr := model.NewCIStr(tableName) tbl, err := do.InfoSchema().TableByName(context.Background(), systemSchemaCIStr, tbCIStr) @@ -244,12 +244,12 @@ func (do *Domain) runawayRecordFlushLoop() { // we can guarantee a watch record can be seen by the user within 1s. runawayRecordFlushTimer := time.NewTimer(runawayRecordFlushInterval) runawayRecordGCTicker := time.NewTicker(runawayRecordGCInterval) - if _, _err_ := failpoint.Eval(_curpkg_("FastRunawayGC")); _err_ == nil { + failpoint.Inject("FastRunawayGC", func() { runawayRecordFlushTimer.Stop() runawayRecordGCTicker.Stop() runawayRecordFlushTimer = time.NewTimer(time.Millisecond * 50) runawayRecordGCTicker = time.NewTicker(time.Millisecond * 200) - } + }) fired := false recordCh := do.runawayManager.RunawayRecordChan() @@ -278,9 +278,9 @@ func (do *Domain) runawayRecordFlushLoop() { fired = true case r := <-recordCh: records = append(records, r) - if _, _err_ := failpoint.Eval(_curpkg_("FastRunawayGC")); _err_ == nil { + failpoint.Inject("FastRunawayGC", func() { flushRunawayRecords() - } + }) if len(records) >= flushThreshold { flushRunawayRecords() } else if fired { diff --git a/pkg/domain/runaway.go__failpoint_stash__ b/pkg/domain/runaway.go__failpoint_stash__ deleted file mode 100644 index ae5d50724a047..0000000000000 --- a/pkg/domain/runaway.go__failpoint_stash__ +++ /dev/null @@ -1,659 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package domain - -import ( - "context" - "net" - "strconv" - "strings" - "sync" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - rmpb "github.com/pingcap/kvproto/pkg/resource_manager" - "github.com/pingcap/tidb/pkg/domain/infosync" - "github.com/pingcap/tidb/pkg/domain/resourcegroup" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/ttl/cache" - "github.com/pingcap/tidb/pkg/ttl/sqlbuilder" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/sqlexec" - "github.com/tikv/client-go/v2/tikv" - pd "github.com/tikv/pd/client" - rmclient "github.com/tikv/pd/client/resource_group/controller" - "go.uber.org/zap" -) - -const ( - runawayRecordFlushInterval = time.Second - runawayRecordGCInterval = time.Hour * 24 - runawayRecordExpiredDuration = time.Hour * 24 * 7 - runawayWatchSyncInterval = time.Second - - runawayRecordGCBatchSize = 100 - runawayRecordGCSelectBatchSize = runawayRecordGCBatchSize * 5 - - maxIDRetries = 3 - runawayLoopLogErrorIntervalCount = 1800 -) - -var systemSchemaCIStr = model.NewCIStr("mysql") - -func (do *Domain) deleteExpiredRows(tableName, colName string, expiredDuration time.Duration) { - if !do.DDL().OwnerManager().IsOwner() { - return - } - failpoint.Inject("FastRunawayGC", func() { - expiredDuration = time.Second * 1 - }) - expiredTime := time.Now().Add(-expiredDuration) - tbCIStr := model.NewCIStr(tableName) - tbl, err := do.InfoSchema().TableByName(context.Background(), systemSchemaCIStr, tbCIStr) - if err != nil { - logutil.BgLogger().Error("delete system table failed", zap.String("table", tableName), zap.Error(err)) - return - } - tbInfo := tbl.Meta() - col := tbInfo.FindPublicColumnByName(colName) - if col == nil { - logutil.BgLogger().Error("time column is not public in table", zap.String("table", tableName), zap.String("column", colName)) - return - } - tb, err := cache.NewBasePhysicalTable(systemSchemaCIStr, tbInfo, model.NewCIStr(""), col) - if err != nil { - logutil.BgLogger().Error("delete system table failed", zap.String("table", tableName), zap.Error(err)) - return - } - generator, err := sqlbuilder.NewScanQueryGenerator(tb, expiredTime, nil, nil) - if err != nil { - logutil.BgLogger().Error("delete system table failed", zap.String("table", tableName), zap.Error(err)) - return - } - var leftRows [][]types.Datum - for { - sql := "" - if sql, err = generator.NextSQL(leftRows, runawayRecordGCSelectBatchSize); err != nil { - logutil.BgLogger().Error("delete system table failed", zap.String("table", tableName), zap.Error(err)) - return - } - // to remove - if len(sql) == 0 { - return - } - - rows, sqlErr := execRestrictedSQL(do.sysSessionPool, sql, nil) - if sqlErr != nil { - logutil.BgLogger().Error("delete system table failed", zap.String("table", tableName), zap.Error(err)) - return - } - leftRows = make([][]types.Datum, len(rows)) - for i, row := range rows { - leftRows[i] = row.GetDatumRow(tb.KeyColumnTypes) - } - - for len(leftRows) > 0 { - var delBatch [][]types.Datum - if len(leftRows) < runawayRecordGCBatchSize { - delBatch = leftRows - leftRows = nil - } else { - delBatch = leftRows[0:runawayRecordGCBatchSize] - leftRows = leftRows[runawayRecordGCBatchSize:] - } - sql, err := 
sqlbuilder.BuildDeleteSQL(tb, delBatch, expiredTime) - if err != nil { - logutil.BgLogger().Error( - "build delete SQL failed when deleting system table", - zap.Error(err), - zap.String("table", tb.Schema.O+"."+tb.Name.O), - ) - return - } - - _, err = execRestrictedSQL(do.sysSessionPool, sql, nil) - if err != nil { - logutil.BgLogger().Error( - "delete SQL failed when deleting system table", zap.Error(err), zap.String("SQL", sql), - ) - } - } - } -} - -func (do *Domain) runawayStartLoop() { - defer util.Recover(metrics.LabelDomain, "runawayStartLoop", nil, false) - runawayWatchSyncTicker := time.NewTicker(runawayWatchSyncInterval) - count := 0 - var err error - logutil.BgLogger().Info("try to start runaway manager loop") - for { - select { - case <-do.exit: - return - case <-runawayWatchSyncTicker.C: - // Due to the watch and watch done tables is created later than runaway queries table - err = do.updateNewAndDoneWatch() - if err == nil { - logutil.BgLogger().Info("preparations for the runaway manager are finished and start runaway manager loop") - do.wg.Run(do.runawayRecordFlushLoop, "runawayRecordFlushLoop") - do.wg.Run(do.runawayWatchSyncLoop, "runawayWatchSyncLoop") - do.runawayManager.MarkSyncerInitialized() - return - } - } - if count %= runawayLoopLogErrorIntervalCount; count == 0 { - logutil.BgLogger().Warn( - "failed to start runaway manager loop, please check whether the bootstrap or update is finished", - zap.Error(err)) - } - count++ - } -} - -func (do *Domain) updateNewAndDoneWatch() error { - do.runawaySyncer.mu.Lock() - defer do.runawaySyncer.mu.Unlock() - records, err := do.runawaySyncer.getNewWatchRecords() - if err != nil { - return err - } - for _, r := range records { - do.runawayManager.AddWatch(r) - } - doneRecords, err := do.runawaySyncer.getNewWatchDoneRecords() - if err != nil { - return err - } - for _, r := range doneRecords { - do.runawayManager.RemoveWatch(r) - } - return nil -} - -func (do *Domain) runawayWatchSyncLoop() { - defer util.Recover(metrics.LabelDomain, "runawayWatchSyncLoop", nil, false) - runawayWatchSyncTicker := time.NewTicker(runawayWatchSyncInterval) - count := 0 - for { - select { - case <-do.exit: - return - case <-runawayWatchSyncTicker.C: - err := do.updateNewAndDoneWatch() - if err != nil { - if count %= runawayLoopLogErrorIntervalCount; count == 0 { - logutil.BgLogger().Warn("get runaway watch record failed", zap.Error(err)) - } - count++ - } - } - } -} - -// GetRunawayWatchList is used to get all items from runaway watch list. -func (do *Domain) GetRunawayWatchList() []*resourcegroup.QuarantineRecord { - return do.runawayManager.GetWatchList() -} - -// TryToUpdateRunawayWatch is used to update watch list including -// creation and deletion by manual trigger. -func (do *Domain) TryToUpdateRunawayWatch() error { - return do.updateNewAndDoneWatch() -} - -// RemoveRunawayWatch is used to remove runaway watch item manually. 
-func (do *Domain) RemoveRunawayWatch(recordID int64) error { - do.runawaySyncer.mu.Lock() - defer do.runawaySyncer.mu.Unlock() - records, err := do.runawaySyncer.getWatchRecordByID(recordID) - if err != nil { - return err - } - if len(records) != 1 { - return errors.Errorf("no runaway watch with the specific ID") - } - err = do.handleRunawayWatchDone(records[0]) - return err -} - -func (do *Domain) runawayRecordFlushLoop() { - defer util.Recover(metrics.LabelDomain, "runawayRecordFlushLoop", nil, false) - - // this times is used to batch flushing records, with 1s duration, - // we can guarantee a watch record can be seen by the user within 1s. - runawayRecordFlushTimer := time.NewTimer(runawayRecordFlushInterval) - runawayRecordGCTicker := time.NewTicker(runawayRecordGCInterval) - failpoint.Inject("FastRunawayGC", func() { - runawayRecordFlushTimer.Stop() - runawayRecordGCTicker.Stop() - runawayRecordFlushTimer = time.NewTimer(time.Millisecond * 50) - runawayRecordGCTicker = time.NewTicker(time.Millisecond * 200) - }) - - fired := false - recordCh := do.runawayManager.RunawayRecordChan() - quarantineRecordCh := do.runawayManager.QuarantineRecordChan() - staleQuarantineRecordCh := do.runawayManager.StaleQuarantineRecordChan() - flushThreshold := do.runawayManager.FlushThreshold() - records := make([]*resourcegroup.RunawayRecord, 0, flushThreshold) - - flushRunawayRecords := func() { - if len(records) == 0 { - return - } - sql, params := resourcegroup.GenRunawayQueriesStmt(records) - if _, err := execRestrictedSQL(do.sysSessionPool, sql, params); err != nil { - logutil.BgLogger().Error("flush runaway records failed", zap.Error(err), zap.Int("count", len(records))) - } - records = records[:0] - } - - for { - select { - case <-do.exit: - return - case <-runawayRecordFlushTimer.C: - flushRunawayRecords() - fired = true - case r := <-recordCh: - records = append(records, r) - failpoint.Inject("FastRunawayGC", func() { - flushRunawayRecords() - }) - if len(records) >= flushThreshold { - flushRunawayRecords() - } else if fired { - fired = false - // meet a new record, reset the timer. - runawayRecordFlushTimer.Reset(runawayRecordFlushInterval) - } - case <-runawayRecordGCTicker.C: - go do.deleteExpiredRows("tidb_runaway_queries", "time", runawayRecordExpiredDuration) - case r := <-quarantineRecordCh: - go func() { - _, err := do.AddRunawayWatch(r) - if err != nil { - logutil.BgLogger().Error("add runaway watch", zap.Error(err)) - } - }() - case r := <-staleQuarantineRecordCh: - go func() { - for i := 0; i < 3; i++ { - err := do.handleRemoveStaleRunawayWatch(r) - if err == nil { - break - } - logutil.BgLogger().Error("remove stale runaway watch", zap.Error(err)) - time.Sleep(time.Second) - } - }() - } - } -} - -// AddRunawayWatch is used to add runaway watch item manually. 
-func (do *Domain) AddRunawayWatch(record *resourcegroup.QuarantineRecord) (uint64, error) { - se, err := do.sysSessionPool.Get() - defer func() { - do.sysSessionPool.Put(se) - }() - if err != nil { - return 0, errors.Annotate(err, "get session failed") - } - exec := se.(sessionctx.Context).GetSQLExecutor() - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnOthers) - _, err = exec.ExecuteInternal(ctx, "BEGIN") - if err != nil { - return 0, errors.Trace(err) - } - defer func() { - if err != nil { - _, err1 := exec.ExecuteInternal(ctx, "ROLLBACK") - terror.Log(err1) - return - } - _, err = exec.ExecuteInternal(ctx, "COMMIT") - if err != nil { - return - } - }() - sql, params := record.GenInsertionStmt() - _, err = exec.ExecuteInternal(ctx, sql, params...) - if err != nil { - return 0, err - } - for retry := 0; retry < maxIDRetries; retry++ { - if retry > 0 { - select { - case <-do.exit: - return 0, err - case <-time.After(time.Millisecond * time.Duration(retry*100)): - logutil.BgLogger().Warn("failed to get last insert id when adding runaway watch", zap.Error(err)) - } - } - var rs sqlexec.RecordSet - rs, err = exec.ExecuteInternal(ctx, `SELECT LAST_INSERT_ID();`) - if err != nil { - continue - } - var rows []chunk.Row - rows, err = sqlexec.DrainRecordSet(ctx, rs, 1) - //nolint: errcheck - rs.Close() - if err != nil { - continue - } - if len(rows) != 1 { - err = errors.Errorf("unexpected result length: %d", len(rows)) - continue - } - return rows[0].GetUint64(0), nil - } - return 0, errors.Errorf("An error: %v occurred while getting the ID of the newly added watch record. Try querying information_schema.runaway_watches later", err) -} - -func (do *Domain) handleRunawayWatchDone(record *resourcegroup.QuarantineRecord) error { - se, err := do.sysSessionPool.Get() - defer func() { - do.sysSessionPool.Put(se) - }() - if err != nil { - return errors.Annotate(err, "get session failed") - } - sctx, _ := se.(sessionctx.Context) - exec := sctx.GetSQLExecutor() - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnOthers) - _, err = exec.ExecuteInternal(ctx, "BEGIN") - if err != nil { - return errors.Trace(err) - } - defer func() { - if err != nil { - _, err1 := exec.ExecuteInternal(ctx, "ROLLBACK") - terror.Log(err1) - return - } - _, err = exec.ExecuteInternal(ctx, "COMMIT") - if err != nil { - return - } - }() - sql, params := record.GenInsertionDoneStmt() - _, err = exec.ExecuteInternal(ctx, sql, params...) - if err != nil { - return err - } - sql, params = record.GenDeletionStmt() - _, err = exec.ExecuteInternal(ctx, sql, params...) - return err -} - -func (do *Domain) handleRemoveStaleRunawayWatch(record *resourcegroup.QuarantineRecord) error { - se, err := do.sysSessionPool.Get() - defer func() { - do.sysSessionPool.Put(se) - }() - if err != nil { - return errors.Annotate(err, "get session failed") - } - sctx, _ := se.(sessionctx.Context) - exec := sctx.GetSQLExecutor() - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnOthers) - _, err = exec.ExecuteInternal(ctx, "BEGIN") - if err != nil { - return errors.Trace(err) - } - defer func() { - if err != nil { - _, err1 := exec.ExecuteInternal(ctx, "ROLLBACK") - terror.Log(err1) - return - } - _, err = exec.ExecuteInternal(ctx, "COMMIT") - if err != nil { - return - } - }() - sql, params := record.GenDeletionStmt() - _, err = exec.ExecuteInternal(ctx, sql, params...) 
- return err -} - -func execRestrictedSQL(sessPool util.SessionPool, sql string, params []any) ([]chunk.Row, error) { - se, err := sessPool.Get() - defer func() { - sessPool.Put(se) - }() - if err != nil { - return nil, errors.Annotate(err, "get session failed") - } - sctx := se.(sessionctx.Context) - exec := sctx.GetRestrictedSQLExecutor() - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnOthers) - r, _, err := exec.ExecRestrictedSQL(ctx, []sqlexec.OptionFuncAlias{sqlexec.ExecOptionUseCurSession}, - sql, params..., - ) - return r, err -} - -func (do *Domain) initResourceGroupsController(ctx context.Context, pdClient pd.Client, uniqueID uint64) error { - if pdClient == nil { - logutil.BgLogger().Warn("cannot setup up resource controller, not using tikv storage") - // return nil as unistore doesn't support it - return nil - } - - control, err := rmclient.NewResourceGroupController(ctx, uniqueID, pdClient, nil, rmclient.WithMaxWaitDuration(resourcegroup.MaxWaitDuration)) - if err != nil { - return err - } - control.Start(ctx) - serverInfo, err := infosync.GetServerInfo() - if err != nil { - return err - } - serverAddr := net.JoinHostPort(serverInfo.IP, strconv.Itoa(int(serverInfo.Port))) - do.runawayManager = resourcegroup.NewRunawayManager(control, serverAddr) - do.runawaySyncer = newRunawaySyncer(do.sysSessionPool) - do.resourceGroupsController = control - tikv.SetResourceControlInterceptor(control) - return nil -} - -type runawaySyncer struct { - newWatchReader *SystemTableReader - deletionWatchReader *SystemTableReader - sysSessionPool util.SessionPool - mu sync.Mutex -} - -func newRunawaySyncer(sysSessionPool util.SessionPool) *runawaySyncer { - return &runawaySyncer{ - sysSessionPool: sysSessionPool, - newWatchReader: &SystemTableReader{ - resourcegroup.RunawayWatchTableName, - "start_time", - resourcegroup.NullTime}, - deletionWatchReader: &SystemTableReader{resourcegroup.RunawayWatchDoneTableName, - "done_time", - resourcegroup.NullTime}, - } -} - -func (s *runawaySyncer) getWatchRecordByID(id int64) ([]*resourcegroup.QuarantineRecord, error) { - return s.getWatchRecord(s.newWatchReader, s.newWatchReader.genSelectByIDStmt(id), false) -} - -func (s *runawaySyncer) getNewWatchRecords() ([]*resourcegroup.QuarantineRecord, error) { - return s.getWatchRecord(s.newWatchReader, s.newWatchReader.genSelectStmt, true) -} - -func (s *runawaySyncer) getNewWatchDoneRecords() ([]*resourcegroup.QuarantineRecord, error) { - return s.getWatchDoneRecord(s.deletionWatchReader, s.deletionWatchReader.genSelectStmt, true) -} - -func (s *runawaySyncer) getWatchRecord(reader *SystemTableReader, sqlGenFn func() (string, []any), push bool) ([]*resourcegroup.QuarantineRecord, error) { - se, err := s.sysSessionPool.Get() - defer func() { - s.sysSessionPool.Put(se) - }() - if err != nil { - return nil, errors.Annotate(err, "get session failed") - } - sctx := se.(sessionctx.Context) - exec := sctx.GetRestrictedSQLExecutor() - return getRunawayWatchRecord(exec, reader, sqlGenFn, push) -} - -func (s *runawaySyncer) getWatchDoneRecord(reader *SystemTableReader, sqlGenFn func() (string, []any), push bool) ([]*resourcegroup.QuarantineRecord, error) { - se, err := s.sysSessionPool.Get() - defer func() { - s.sysSessionPool.Put(se) - }() - if err != nil { - return nil, errors.Annotate(err, "get session failed") - } - sctx := se.(sessionctx.Context) - exec := sctx.GetRestrictedSQLExecutor() - return getRunawayWatchDoneRecord(exec, reader, sqlGenFn, push) -} - -func getRunawayWatchRecord(exec 
sqlexec.RestrictedSQLExecutor, reader *SystemTableReader, sqlGenFn func() (string, []any), push bool) ([]*resourcegroup.QuarantineRecord, error) { - rs, err := reader.Read(exec, sqlGenFn) - if err != nil { - return nil, err - } - ret := make([]*resourcegroup.QuarantineRecord, 0, len(rs)) - now := time.Now().UTC() - for _, r := range rs { - startTime, err := r.GetTime(2).GoTime(time.UTC) - if err != nil { - continue - } - var endTime time.Time - if !r.IsNull(3) { - endTime, err = r.GetTime(3).GoTime(time.UTC) - if err != nil { - continue - } - } - qr := &resourcegroup.QuarantineRecord{ - ID: r.GetInt64(0), - ResourceGroupName: r.GetString(1), - StartTime: startTime, - EndTime: endTime, - Watch: rmpb.RunawayWatchType(r.GetInt64(4)), - WatchText: r.GetString(5), - Source: r.GetString(6), - Action: rmpb.RunawayAction(r.GetInt64(7)), - } - // If a TiDB write record slow, it will occur that the record which has earlier start time is inserted later than others. - // So we start the scan a little earlier. - if push { - reader.CheckPoint = now.Add(-3 * runawayWatchSyncInterval) - } - ret = append(ret, qr) - } - return ret, nil -} - -func getRunawayWatchDoneRecord(exec sqlexec.RestrictedSQLExecutor, reader *SystemTableReader, sqlGenFn func() (string, []any), push bool) ([]*resourcegroup.QuarantineRecord, error) { - rs, err := reader.Read(exec, sqlGenFn) - if err != nil { - return nil, err - } - length := len(rs) - ret := make([]*resourcegroup.QuarantineRecord, 0, length) - now := time.Now().UTC() - for _, r := range rs { - startTime, err := r.GetTime(3).GoTime(time.UTC) - if err != nil { - continue - } - var endTime time.Time - if !r.IsNull(4) { - endTime, err = r.GetTime(4).GoTime(time.UTC) - if err != nil { - continue - } - } - qr := &resourcegroup.QuarantineRecord{ - ID: r.GetInt64(1), - ResourceGroupName: r.GetString(2), - StartTime: startTime, - EndTime: endTime, - Watch: rmpb.RunawayWatchType(r.GetInt64(5)), - WatchText: r.GetString(6), - Source: r.GetString(7), - Action: rmpb.RunawayAction(r.GetInt64(8)), - } - // Ditto as getRunawayWatchRecord. - if push { - reader.CheckPoint = now.Add(-3 * runawayWatchSyncInterval) - } - ret = append(ret, qr) - } - return ret, nil -} - -// SystemTableReader is used to read table `runaway_watch` and `runaway_watch_done`. -type SystemTableReader struct { - TableName string - KeyCol string - CheckPoint time.Time -} - -func (r *SystemTableReader) genSelectByIDStmt(id int64) func() (string, []any) { - return func() (string, []any) { - var builder strings.Builder - params := make([]any, 0, 1) - builder.WriteString("select * from ") - builder.WriteString(r.TableName) - builder.WriteString(" where id = %?") - params = append(params, id) - return builder.String(), params - } -} - -func (r *SystemTableReader) genSelectStmt() (string, []any) { - var builder strings.Builder - params := make([]any, 0, 1) - builder.WriteString("select * from ") - builder.WriteString(r.TableName) - builder.WriteString(" where ") - builder.WriteString(r.KeyCol) - builder.WriteString(" > %? 
order by ") - builder.WriteString(r.KeyCol) - params = append(params, r.CheckPoint) - return builder.String(), params -} - -func (*SystemTableReader) Read(exec sqlexec.RestrictedSQLExecutor, genFn func() (string, []any)) ([]chunk.Row, error) { - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnOthers) - sql, params := genFn() - rows, _, err := exec.ExecRestrictedSQL(ctx, []sqlexec.OptionFuncAlias{sqlexec.ExecOptionUseCurSession}, - sql, params..., - ) - return rows, err -} diff --git a/pkg/executor/adapter.go b/pkg/executor/adapter.go index 2c1c5cd305918..5ea9fdcafb178 100644 --- a/pkg/executor/adapter.go +++ b/pkg/executor/adapter.go @@ -299,12 +299,12 @@ func (a *ExecStmt) PointGet(ctx context.Context) (*recordSet, error) { r.Span.LogKV("sql", a.OriginText()) } - if _, _err_ := failpoint.Eval(_curpkg_("assertTxnManagerInShortPointGetPlan")); _err_ == nil { + failpoint.Inject("assertTxnManagerInShortPointGetPlan", func() { sessiontxn.RecordAssert(a.Ctx, "assertTxnManagerInShortPointGetPlan", true) // stale read should not reach here staleread.AssertStmtStaleness(a.Ctx, false) sessiontxn.AssertTxnManagerInfoSchema(a.Ctx, a.InfoSchema) - } + }) ctx = a.observeStmtBeginForTopSQL(ctx) startTs, err := sessiontxn.GetTxnManager(a.Ctx).GetStmtReadTS() @@ -399,7 +399,7 @@ func (a *ExecStmt) RebuildPlan(ctx context.Context) (int64, error) { return 0, err } - if _, _err_ := failpoint.Eval(_curpkg_("assertTxnManagerInRebuildPlan")); _err_ == nil { + failpoint.Inject("assertTxnManagerInRebuildPlan", func() { if is, ok := a.Ctx.Value(sessiontxn.AssertTxnInfoSchemaAfterRetryKey).(infoschema.InfoSchema); ok { a.Ctx.SetValue(sessiontxn.AssertTxnInfoSchemaKey, is) a.Ctx.SetValue(sessiontxn.AssertTxnInfoSchemaAfterRetryKey, nil) @@ -410,7 +410,7 @@ func (a *ExecStmt) RebuildPlan(ctx context.Context) (int64, error) { if ret.IsStaleness { sessiontxn.AssertTxnManagerReadTS(a.Ctx, ret.LastSnapshotTS) } - } + }) a.InfoSchema = sessiontxn.GetTxnManager(a.Ctx).GetTxnInfoSchema() replicaReadScope := sessiontxn.GetTxnManager(a.Ctx).GetReadReplicaScope() @@ -488,7 +488,7 @@ func (a *ExecStmt) Exec(ctx context.Context) (_ sqlexec.RecordSet, err error) { logutil.Logger(ctx).Error("execute sql panic", zap.String("sql", a.GetTextToLog(false)), zap.Stack("stack")) }() - if val, _err_ := failpoint.Eval(_curpkg_("assertStaleTSO")); _err_ == nil { + failpoint.Inject("assertStaleTSO", func(val failpoint.Value) { if n, ok := val.(int); ok && staleread.IsStmtStaleness(a.Ctx) { txnManager := sessiontxn.GetTxnManager(a.Ctx) ts, err := txnManager.GetStmtReadTS() @@ -500,7 +500,7 @@ func (a *ExecStmt) Exec(ctx context.Context) (_ sqlexec.RecordSet, err error) { panic(fmt.Sprintf("different tso %d != %d", n, startTS)) } } - } + }) sctx := a.Ctx ctx = util.SetSessionID(ctx, sctx.GetSessionVars().ConnectionID) if _, ok := a.Plan.(*plannercore.Analyze); ok && sctx.GetSessionVars().InRestrictedSQL { @@ -729,12 +729,12 @@ func (a *ExecStmt) handleForeignKeyCascade(ctx context.Context, fkc *FKCascadeEx return err } err = exec.Next(ctx, e, exec.NewFirstChunk(e)) - if val, _err_ := failpoint.Eval(_curpkg_("handleForeignKeyCascadeError")); _err_ == nil { + failpoint.Inject("handleForeignKeyCascadeError", func(val failpoint.Value) { // Next can recover panic and convert it to error. So we inject error directly here. 
if val.(bool) && err == nil { err = errors.New("handleForeignKeyCascadeError") } - } + }) closeErr := exec.Close(e) if err == nil { err = closeErr @@ -943,7 +943,7 @@ func (a *ExecStmt) handlePessimisticSelectForUpdate(ctx context.Context, e exec. return rs, nil } - failpoint.Eval(_curpkg_("pessimisticSelectForUpdateRetry")) + failpoint.Inject("pessimisticSelectForUpdateRetry", nil) } } @@ -1050,7 +1050,7 @@ func (a *ExecStmt) handlePessimisticDML(ctx context.Context, e exec.Executor) (e for { if !isFirstAttempt { - failpoint.Eval(_curpkg_("pessimisticDMLRetry")) + failpoint.Inject("pessimisticDMLRetry", nil) } startTime := time.Now() @@ -1122,13 +1122,13 @@ func (a *ExecStmt) handlePessimisticLockError(ctx context.Context, lockErr error if lockErr == nil { return nil, nil } - if _, _err_ := failpoint.Eval(_curpkg_("assertPessimisticLockErr")); _err_ == nil { + failpoint.Inject("assertPessimisticLockErr", func() { if terror.ErrorEqual(kv.ErrWriteConflict, lockErr) { sessiontxn.AddAssertEntranceForLockError(a.Ctx, "errWriteConflict") } else if terror.ErrorEqual(kv.ErrKeyExists, lockErr) { sessiontxn.AddAssertEntranceForLockError(a.Ctx, "errDuplicateKey") } - } + }) defer func() { if _, ok := errors.Cause(err).(*tikverr.ErrDeadlock); ok { @@ -1177,9 +1177,9 @@ func (a *ExecStmt) handlePessimisticLockError(ctx context.Context, lockErr error a.Ctx.GetSessionVars().StmtCtx.ResetForRetry() a.Ctx.GetSessionVars().RetryInfo.ResetOffset() - if _, _err_ := failpoint.Eval(_curpkg_("assertTxnManagerAfterPessimisticLockErrorRetry")); _err_ == nil { + failpoint.Inject("assertTxnManagerAfterPessimisticLockErrorRetry", func() { sessiontxn.RecordAssert(a.Ctx, "assertTxnManagerAfterPessimisticLockErrorRetry", true) - } + }) if err = a.openExecutor(ctx, e); err != nil { return nil, err @@ -1213,10 +1213,10 @@ func (a *ExecStmt) buildExecutor() (exec.Executor, error) { return nil, errors.Trace(b.err) } - if _, _err_ := failpoint.Eval(_curpkg_("assertTxnManagerAfterBuildExecutor")); _err_ == nil { + failpoint.Inject("assertTxnManagerAfterBuildExecutor", func() { sessiontxn.RecordAssert(a.Ctx, "assertTxnManagerAfterBuildExecutor", true) sessiontxn.AssertTxnManagerInfoSchema(b.ctx, b.is) - } + }) // ExecuteExec is not a real Executor, we only use it to build another Executor from a prepared statement. if executorExec, ok := e.(*ExecuteExec); ok { @@ -1474,9 +1474,9 @@ func (a *ExecStmt) recordLastQueryInfo(err error) { ruDetail := ruDetailRaw.(*util.RUDetails) lastRUConsumption = ruDetail.RRU() + ruDetail.WRU() } - if _, _err_ := failpoint.Eval(_curpkg_("mockRUConsumption")); _err_ == nil { + failpoint.Inject("mockRUConsumption", func(_ failpoint.Value) { lastRUConsumption = float64(len(sessVars.StmtCtx.OriginalSQL)) - } + }) // Keep the previous queryInfo for `show session_states` because the statement needs to encode it. 
sessVars.LastQueryInfo = sessionstates.QueryInfo{ TxnScope: sessVars.CheckAndGetTxnScope(), @@ -1668,13 +1668,13 @@ func (a *ExecStmt) LogSlowQuery(txnTS uint64, succ bool, hasMoreResults bool) { WRU: ruDetails.WRU(), WaitRUDuration: ruDetails.RUWaitDuration(), } - if val, _err_ := failpoint.Eval(_curpkg_("assertSyncStatsFailed")); _err_ == nil { + failpoint.Inject("assertSyncStatsFailed", func(val failpoint.Value) { if val.(bool) { if !slowItems.IsSyncStatsFailed { panic("isSyncStatsFailed should be true") } } - } + }) if a.retryCount > 0 { slowItems.ExecRetryTime = costTime - sessVars.DurationParse - sessVars.DurationCompile - time.Since(a.retryStartTime) } diff --git a/pkg/executor/adapter.go__failpoint_stash__ b/pkg/executor/adapter.go__failpoint_stash__ deleted file mode 100644 index 5ea9fdcafb178..0000000000000 --- a/pkg/executor/adapter.go__failpoint_stash__ +++ /dev/null @@ -1,2240 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package executor - -import ( - "bytes" - "context" - "fmt" - "math" - "runtime/trace" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/log" - "github.com/pingcap/tidb/pkg/bindinfo" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/ddl/placement" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - executor_metrics "github.com/pingcap/tidb/pkg/executor/metrics" - "github.com/pingcap/tidb/pkg/executor/staticrecordset" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/keyspace" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/planner" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" - "github.com/pingcap/tidb/pkg/plugin" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/sessionstates" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/sessiontxn" - "github.com/pingcap/tidb/pkg/sessiontxn/staleread" - "github.com/pingcap/tidb/pkg/types" - util2 "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/breakpoint" - "github.com/pingcap/tidb/pkg/util/chunk" - contextutil "github.com/pingcap/tidb/pkg/util/context" - "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/hint" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/plancodec" - 
"github.com/pingcap/tidb/pkg/util/redact" - "github.com/pingcap/tidb/pkg/util/replayer" - "github.com/pingcap/tidb/pkg/util/sqlexec" - "github.com/pingcap/tidb/pkg/util/stmtsummary" - stmtsummaryv2 "github.com/pingcap/tidb/pkg/util/stmtsummary/v2" - "github.com/pingcap/tidb/pkg/util/stringutil" - "github.com/pingcap/tidb/pkg/util/topsql" - topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" - "github.com/pingcap/tidb/pkg/util/tracing" - "github.com/prometheus/client_golang/prometheus" - tikverr "github.com/tikv/client-go/v2/error" - "github.com/tikv/client-go/v2/oracle" - "github.com/tikv/client-go/v2/util" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" -) - -// processinfoSetter is the interface use to set current running process info. -type processinfoSetter interface { - SetProcessInfo(string, time.Time, byte, uint64) - UpdateProcessInfo() -} - -// recordSet wraps an executor, implements sqlexec.RecordSet interface -type recordSet struct { - fields []*ast.ResultField - executor exec.Executor - // The `Fields` method may be called after `Close`, and the executor is cleared in the `Close` function. - // Therefore, we need to store the schema in `recordSet` to avoid a null pointer exception when calling `executor.Schema()`. - schema *expression.Schema - stmt *ExecStmt - lastErrs []error - txnStartTS uint64 - once sync.Once -} - -func (a *recordSet) Fields() []*ast.ResultField { - if len(a.fields) == 0 { - a.fields = colNames2ResultFields(a.schema, a.stmt.OutputNames, a.stmt.Ctx.GetSessionVars().CurrentDB) - } - return a.fields -} - -func colNames2ResultFields(schema *expression.Schema, names []*types.FieldName, defaultDB string) []*ast.ResultField { - rfs := make([]*ast.ResultField, 0, schema.Len()) - defaultDBCIStr := model.NewCIStr(defaultDB) - for i := 0; i < schema.Len(); i++ { - dbName := names[i].DBName - if dbName.L == "" && names[i].TblName.L != "" { - dbName = defaultDBCIStr - } - origColName := names[i].OrigColName - emptyOrgName := false - if origColName.L == "" { - origColName = names[i].ColName - emptyOrgName = true - } - rf := &ast.ResultField{ - Column: &model.ColumnInfo{Name: origColName, FieldType: *schema.Columns[i].RetType}, - ColumnAsName: names[i].ColName, - EmptyOrgName: emptyOrgName, - Table: &model.TableInfo{Name: names[i].OrigTblName}, - TableAsName: names[i].TblName, - DBName: dbName, - } - // This is for compatibility. - // See issue https://github.com/pingcap/tidb/issues/10513 . - if len(rf.ColumnAsName.O) > mysql.MaxAliasIdentifierLen { - rf.ColumnAsName.O = rf.ColumnAsName.O[:mysql.MaxAliasIdentifierLen] - } - // Usually the length of O equals the length of L. - // Add this len judgement to avoid panic. - if len(rf.ColumnAsName.L) > mysql.MaxAliasIdentifierLen { - rf.ColumnAsName.L = rf.ColumnAsName.L[:mysql.MaxAliasIdentifierLen] - } - rfs = append(rfs, rf) - } - return rfs -} - -// Next use uses recordSet's executor to get next available chunk for later usage. -// If chunk does not contain any rows, then we update last query found rows in session variable as current found rows. -// The reason we need update is that chunk with 0 rows indicating we already finished current query, we need prepare for -// next query. -// If stmt is not nil and chunk with some rows inside, we simply update last query found rows by the number of row in chunk. 
-func (a *recordSet) Next(ctx context.Context, req *chunk.Chunk) (err error) { - defer func() { - r := recover() - if r == nil { - return - } - err = util2.GetRecoverError(r) - logutil.Logger(ctx).Error("execute sql panic", zap.String("sql", a.stmt.GetTextToLog(false)), zap.Stack("stack")) - }() - if a.stmt != nil { - if err := a.stmt.Ctx.GetSessionVars().SQLKiller.HandleSignal(); err != nil { - return err - } - } - - err = a.stmt.next(ctx, a.executor, req) - if err != nil { - a.lastErrs = append(a.lastErrs, err) - return err - } - numRows := req.NumRows() - if numRows == 0 { - if a.stmt != nil { - a.stmt.Ctx.GetSessionVars().LastFoundRows = a.stmt.Ctx.GetSessionVars().StmtCtx.FoundRows() - } - return nil - } - if a.stmt != nil { - a.stmt.Ctx.GetSessionVars().StmtCtx.AddFoundRows(uint64(numRows)) - } - return nil -} - -// NewChunk create a chunk base on top-level executor's exec.NewFirstChunk(). -func (a *recordSet) NewChunk(alloc chunk.Allocator) *chunk.Chunk { - if alloc == nil { - return exec.NewFirstChunk(a.executor) - } - - return alloc.Alloc(a.executor.RetFieldTypes(), a.executor.InitCap(), a.executor.MaxChunkSize()) -} - -func (a *recordSet) Finish() error { - var err error - a.once.Do(func() { - err = exec.Close(a.executor) - cteErr := resetCTEStorageMap(a.stmt.Ctx) - if cteErr != nil { - logutil.BgLogger().Error("got error when reset cte storage, should check if the spill disk file deleted or not", zap.Error(cteErr)) - } - if err == nil { - err = cteErr - } - a.executor = nil - if a.stmt != nil { - status := a.stmt.Ctx.GetSessionVars().SQLKiller.GetKillSignal() - inWriteResultSet := a.stmt.Ctx.GetSessionVars().SQLKiller.InWriteResultSet.Load() - if status > 0 && inWriteResultSet { - logutil.BgLogger().Warn("kill, this SQL might be stuck in the network stack while writing packets to the client.", zap.Uint64("connection ID", a.stmt.Ctx.GetSessionVars().ConnectionID)) - } - } - }) - if err != nil { - a.lastErrs = append(a.lastErrs, err) - } - return err -} - -func (a *recordSet) Close() error { - err := a.Finish() - if err != nil { - logutil.BgLogger().Error("close recordSet error", zap.Error(err)) - } - a.stmt.CloseRecordSet(a.txnStartTS, errors.Join(a.lastErrs...)) - return err -} - -// OnFetchReturned implements commandLifeCycle#OnFetchReturned -func (a *recordSet) OnFetchReturned() { - a.stmt.LogSlowQuery(a.txnStartTS, len(a.lastErrs) == 0, true) -} - -// TryDetach creates a new `RecordSet` which doesn't depend on the current session context. -func (a *recordSet) TryDetach() (sqlexec.RecordSet, bool, error) { - e, ok := Detach(a.executor) - if !ok { - return nil, false, nil - } - return staticrecordset.New(a.Fields(), e, a.stmt.GetTextToLog(false)), true, nil -} - -// GetExecutor4Test exports the internal executor for test purpose. -func (a *recordSet) GetExecutor4Test() any { - return a.executor -} - -// ExecStmt implements the sqlexec.Statement interface, it builds a planner.Plan to an sqlexec.Statement. -type ExecStmt struct { - // GoCtx stores parent go context.Context for a stmt. - GoCtx context.Context - // InfoSchema stores a reference to the schema information. - InfoSchema infoschema.InfoSchema - // Plan stores a reference to the final physical plan. - Plan base.Plan - // Text represents the origin query text. - Text string - - StmtNode ast.StmtNode - - Ctx sessionctx.Context - - // LowerPriority represents whether to lower the execution priority of a query. 
- LowerPriority bool - isPreparedStmt bool - isSelectForUpdate bool - retryCount uint - retryStartTime time.Time - - // Phase durations are splited into two parts: 1. trying to lock keys (but - // failed); 2. the final iteration of the retry loop. Here we use - // [2]time.Duration to record such info for each phase. The first duration - // is increased only within the current iteration. When we meet a - // pessimistic lock error and decide to retry, we add the first duration to - // the second and reset the first to 0 by calling `resetPhaseDurations`. - phaseBuildDurations [2]time.Duration - phaseOpenDurations [2]time.Duration - phaseNextDurations [2]time.Duration - phaseLockDurations [2]time.Duration - - // OutputNames will be set if using cached plan - OutputNames []*types.FieldName - PsStmt *plannercore.PlanCacheStmt -} - -// GetStmtNode returns the stmtNode inside Statement -func (a *ExecStmt) GetStmtNode() ast.StmtNode { - return a.StmtNode -} - -// PointGet short path for point exec directly from plan, keep only necessary steps -func (a *ExecStmt) PointGet(ctx context.Context) (*recordSet, error) { - r, ctx := tracing.StartRegionEx(ctx, "ExecStmt.PointGet") - defer r.End() - if r.Span != nil { - r.Span.LogKV("sql", a.OriginText()) - } - - failpoint.Inject("assertTxnManagerInShortPointGetPlan", func() { - sessiontxn.RecordAssert(a.Ctx, "assertTxnManagerInShortPointGetPlan", true) - // stale read should not reach here - staleread.AssertStmtStaleness(a.Ctx, false) - sessiontxn.AssertTxnManagerInfoSchema(a.Ctx, a.InfoSchema) - }) - - ctx = a.observeStmtBeginForTopSQL(ctx) - startTs, err := sessiontxn.GetTxnManager(a.Ctx).GetStmtReadTS() - if err != nil { - return nil, err - } - a.Ctx.GetSessionVars().StmtCtx.Priority = kv.PriorityHigh - - var executor exec.Executor - useMaxTS := startTs == math.MaxUint64 - - // try to reuse point get executor - // We should only use the cached the executor when the startTS is MaxUint64 - if a.PsStmt.PointGet.Executor != nil && useMaxTS { - exec, ok := a.PsStmt.PointGet.Executor.(*PointGetExecutor) - if !ok { - logutil.Logger(ctx).Error("invalid executor type, not PointGetExecutor for point get path") - a.PsStmt.PointGet.Executor = nil - } else { - // CachedPlan type is already checked in last step - pointGetPlan := a.Plan.(*plannercore.PointGetPlan) - exec.Init(pointGetPlan) - a.PsStmt.PointGet.Executor = exec - executor = exec - } - } - - if executor == nil { - b := newExecutorBuilder(a.Ctx, a.InfoSchema) - executor = b.build(a.Plan) - if b.err != nil { - return nil, b.err - } - pointExecutor, ok := executor.(*PointGetExecutor) - - // Don't cache the executor for non point-get (table dual) or partitioned tables - if ok && useMaxTS && pointExecutor.partitionDefIdx == nil { - a.PsStmt.PointGet.Executor = pointExecutor - } - } - - if err = exec.Open(ctx, executor); err != nil { - terror.Log(exec.Close(executor)) - return nil, err - } - - sctx := a.Ctx - cmd32 := atomic.LoadUint32(&sctx.GetSessionVars().CommandValue) - cmd := byte(cmd32) - var pi processinfoSetter - if raw, ok := sctx.(processinfoSetter); ok { - pi = raw - sql := a.OriginText() - maxExecutionTime := sctx.GetSessionVars().GetMaxExecutionTime() - // Update processinfo, ShowProcess() will use it. 
- pi.SetProcessInfo(sql, time.Now(), cmd, maxExecutionTime) - if sctx.GetSessionVars().StmtCtx.StmtType == "" { - sctx.GetSessionVars().StmtCtx.StmtType = ast.GetStmtLabel(a.StmtNode) - } - } - - return &recordSet{ - executor: executor, - schema: executor.Schema(), - stmt: a, - txnStartTS: startTs, - }, nil -} - -// OriginText returns original statement as a string. -func (a *ExecStmt) OriginText() string { - return a.Text -} - -// IsPrepared returns true if stmt is a prepare statement. -func (a *ExecStmt) IsPrepared() bool { - return a.isPreparedStmt -} - -// IsReadOnly returns true if a statement is read only. -// If current StmtNode is an ExecuteStmt, we can get its prepared stmt, -// then using ast.IsReadOnly function to determine a statement is read only or not. -func (a *ExecStmt) IsReadOnly(vars *variable.SessionVars) bool { - return planner.IsReadOnly(a.StmtNode, vars) -} - -// RebuildPlan rebuilds current execute statement plan. -// It returns the current information schema version that 'a' is using. -func (a *ExecStmt) RebuildPlan(ctx context.Context) (int64, error) { - ret := &plannercore.PreprocessorReturn{} - if err := plannercore.Preprocess(ctx, a.Ctx, a.StmtNode, plannercore.InTxnRetry, plannercore.InitTxnContextProvider, plannercore.WithPreprocessorReturn(ret)); err != nil { - return 0, err - } - - failpoint.Inject("assertTxnManagerInRebuildPlan", func() { - if is, ok := a.Ctx.Value(sessiontxn.AssertTxnInfoSchemaAfterRetryKey).(infoschema.InfoSchema); ok { - a.Ctx.SetValue(sessiontxn.AssertTxnInfoSchemaKey, is) - a.Ctx.SetValue(sessiontxn.AssertTxnInfoSchemaAfterRetryKey, nil) - } - sessiontxn.RecordAssert(a.Ctx, "assertTxnManagerInRebuildPlan", true) - sessiontxn.AssertTxnManagerInfoSchema(a.Ctx, ret.InfoSchema) - staleread.AssertStmtStaleness(a.Ctx, ret.IsStaleness) - if ret.IsStaleness { - sessiontxn.AssertTxnManagerReadTS(a.Ctx, ret.LastSnapshotTS) - } - }) - - a.InfoSchema = sessiontxn.GetTxnManager(a.Ctx).GetTxnInfoSchema() - replicaReadScope := sessiontxn.GetTxnManager(a.Ctx).GetReadReplicaScope() - if a.Ctx.GetSessionVars().GetReplicaRead().IsClosestRead() && replicaReadScope == kv.GlobalReplicaScope { - logutil.BgLogger().Warn(fmt.Sprintf("tidb can't read closest replicas due to it haven't %s label", placement.DCLabelKey)) - } - p, names, err := planner.Optimize(ctx, a.Ctx, a.StmtNode, a.InfoSchema) - if err != nil { - return 0, err - } - a.OutputNames = names - a.Plan = p - a.Ctx.GetSessionVars().StmtCtx.SetPlan(p) - return a.InfoSchema.SchemaMetaVersion(), nil -} - -// IsFastPlan exports for testing. -func IsFastPlan(p base.Plan) bool { - if proj, ok := p.(*plannercore.PhysicalProjection); ok { - p = proj.Children()[0] - } - switch p.(type) { - case *plannercore.PointGetPlan: - return true - case *plannercore.PhysicalTableDual: - // Plan of following SQL is PhysicalTableDual: - // select 1; - // select @@autocommit; - return true - case *plannercore.Set: - // Plan of following SQL is Set: - // set @a=1; - // set @@autocommit=1; - return true - } - return false -} - -// Exec builds an Executor from a plan. If the Executor doesn't return result, -// like the INSERT, UPDATE statements, it executes in this function. If the Executor returns -// result, execution is done after this function returns, in the returned sqlexec.RecordSet Next method. 
-func (a *ExecStmt) Exec(ctx context.Context) (_ sqlexec.RecordSet, err error) { - defer func() { - r := recover() - if r == nil { - if a.retryCount > 0 { - metrics.StatementPessimisticRetryCount.Observe(float64(a.retryCount)) - } - lockKeysCnt := a.Ctx.GetSessionVars().StmtCtx.LockKeysCount - if lockKeysCnt > 0 { - metrics.StatementLockKeysCount.Observe(float64(lockKeysCnt)) - } - - execDetails := a.Ctx.GetSessionVars().StmtCtx.GetExecDetails() - if err == nil && execDetails.LockKeysDetail != nil && - (execDetails.LockKeysDetail.AggressiveLockNewCount > 0 || execDetails.LockKeysDetail.AggressiveLockDerivedCount > 0) { - a.Ctx.GetSessionVars().TxnCtx.FairLockingUsed = true - // If this statement is finished when some of the keys are locked with conflict in the last retry, or - // some of the keys are derived from the previous retry, we consider the optimization of fair locking - // takes effect on this statement. - if execDetails.LockKeysDetail.LockedWithConflictCount > 0 || execDetails.LockKeysDetail.AggressiveLockDerivedCount > 0 { - a.Ctx.GetSessionVars().TxnCtx.FairLockingEffective = true - } - } - return - } - recoverdErr, ok := r.(error) - if !ok || !(exeerrors.ErrMemoryExceedForQuery.Equal(recoverdErr) || - exeerrors.ErrMemoryExceedForInstance.Equal(recoverdErr) || - exeerrors.ErrQueryInterrupted.Equal(recoverdErr) || - exeerrors.ErrMaxExecTimeExceeded.Equal(recoverdErr)) { - panic(r) - } - err = recoverdErr - logutil.Logger(ctx).Error("execute sql panic", zap.String("sql", a.GetTextToLog(false)), zap.Stack("stack")) - }() - - failpoint.Inject("assertStaleTSO", func(val failpoint.Value) { - if n, ok := val.(int); ok && staleread.IsStmtStaleness(a.Ctx) { - txnManager := sessiontxn.GetTxnManager(a.Ctx) - ts, err := txnManager.GetStmtReadTS() - if err != nil { - panic(err) - } - startTS := oracle.ExtractPhysical(ts) / 1000 - if n != int(startTS) { - panic(fmt.Sprintf("different tso %d != %d", n, startTS)) - } - } - }) - sctx := a.Ctx - ctx = util.SetSessionID(ctx, sctx.GetSessionVars().ConnectionID) - if _, ok := a.Plan.(*plannercore.Analyze); ok && sctx.GetSessionVars().InRestrictedSQL { - oriStats, ok := sctx.GetSessionVars().GetSystemVar(variable.TiDBBuildStatsConcurrency) - if !ok { - oriStats = strconv.Itoa(variable.DefBuildStatsConcurrency) - } - oriScan := sctx.GetSessionVars().AnalyzeDistSQLScanConcurrency() - oriIso, ok := sctx.GetSessionVars().GetSystemVar(variable.TxnIsolation) - if !ok { - oriIso = "REPEATABLE-READ" - } - autoConcurrency, err1 := sctx.GetSessionVars().GetSessionOrGlobalSystemVar(ctx, variable.TiDBAutoBuildStatsConcurrency) - terror.Log(err1) - if err1 == nil { - terror.Log(sctx.GetSessionVars().SetSystemVar(variable.TiDBBuildStatsConcurrency, autoConcurrency)) - } - sVal, err2 := sctx.GetSessionVars().GetSessionOrGlobalSystemVar(ctx, variable.TiDBSysProcScanConcurrency) - terror.Log(err2) - if err2 == nil { - concurrency, err3 := strconv.ParseInt(sVal, 10, 64) - terror.Log(err3) - if err3 == nil { - sctx.GetSessionVars().SetAnalyzeDistSQLScanConcurrency(int(concurrency)) - } - } - terror.Log(sctx.GetSessionVars().SetSystemVar(variable.TxnIsolation, ast.ReadCommitted)) - defer func() { - terror.Log(sctx.GetSessionVars().SetSystemVar(variable.TiDBBuildStatsConcurrency, oriStats)) - sctx.GetSessionVars().SetAnalyzeDistSQLScanConcurrency(oriScan) - terror.Log(sctx.GetSessionVars().SetSystemVar(variable.TxnIsolation, oriIso)) - }() - } - - if sctx.GetSessionVars().StmtCtx.HasMemQuotaHint { - 
sctx.GetSessionVars().MemTracker.SetBytesLimit(sctx.GetSessionVars().StmtCtx.MemQuotaQuery) - } - - // must set plan according to the `Execute` plan before getting planDigest - a.inheritContextFromExecuteStmt() - if rm := domain.GetDomain(sctx).RunawayManager(); variable.EnableResourceControl.Load() && rm != nil { - sessionVars := sctx.GetSessionVars() - stmtCtx := sessionVars.StmtCtx - _, planDigest := GetPlanDigest(stmtCtx) - _, sqlDigest := stmtCtx.SQLDigest() - stmtCtx.RunawayChecker = rm.DeriveChecker(stmtCtx.ResourceGroupName, stmtCtx.OriginalSQL, sqlDigest.String(), planDigest.String(), sessionVars.StartTime) - if err := stmtCtx.RunawayChecker.BeforeExecutor(); err != nil { - return nil, err - } - } - ctx = a.observeStmtBeginForTopSQL(ctx) - - e, err := a.buildExecutor() - if err != nil { - return nil, err - } - - cmd32 := atomic.LoadUint32(&sctx.GetSessionVars().CommandValue) - cmd := byte(cmd32) - var pi processinfoSetter - if raw, ok := sctx.(processinfoSetter); ok { - pi = raw - sql := a.getSQLForProcessInfo() - maxExecutionTime := sctx.GetSessionVars().GetMaxExecutionTime() - // Update processinfo, ShowProcess() will use it. - if a.Ctx.GetSessionVars().StmtCtx.StmtType == "" { - a.Ctx.GetSessionVars().StmtCtx.StmtType = ast.GetStmtLabel(a.StmtNode) - } - // Since maxExecutionTime is used only for query statement, here we limit it affect scope. - if !a.IsReadOnly(a.Ctx.GetSessionVars()) { - maxExecutionTime = 0 - } - pi.SetProcessInfo(sql, time.Now(), cmd, maxExecutionTime) - } - - breakpoint.Inject(a.Ctx, sessiontxn.BreakPointBeforeExecutorFirstRun) - if err = a.openExecutor(ctx, e); err != nil { - terror.Log(exec.Close(e)) - return nil, err - } - - isPessimistic := sctx.GetSessionVars().TxnCtx.IsPessimistic - - if a.isSelectForUpdate { - if sctx.GetSessionVars().UseLowResolutionTSO() { - return nil, errors.New("can not execute select for update statement when 'tidb_low_resolution_tso' is set") - } - // Special handle for "select for update statement" in pessimistic transaction. - if isPessimistic { - return a.handlePessimisticSelectForUpdate(ctx, e) - } - } - - a.prepareFKCascadeContext(e) - if handled, result, err := a.handleNoDelay(ctx, e, isPessimistic); handled || err != nil { - return result, err - } - - var txnStartTS uint64 - txn, err := sctx.Txn(false) - if err != nil { - return nil, err - } - if txn.Valid() { - txnStartTS = txn.StartTS() - } - - return &recordSet{ - executor: e, - schema: e.Schema(), - stmt: a, - txnStartTS: txnStartTS, - }, nil -} - -func (a *ExecStmt) inheritContextFromExecuteStmt() { - if executePlan, ok := a.Plan.(*plannercore.Execute); ok { - a.Ctx.SetValue(sessionctx.QueryString, executePlan.Stmt.Text()) - a.OutputNames = executePlan.OutputNames() - a.isPreparedStmt = true - a.Plan = executePlan.Plan - a.Ctx.GetSessionVars().StmtCtx.SetPlan(executePlan.Plan) - } -} - -func (a *ExecStmt) getSQLForProcessInfo() string { - sql := a.OriginText() - if simple, ok := a.Plan.(*plannercore.Simple); ok && simple.Statement != nil { - if ss, ok := simple.Statement.(ast.SensitiveStmtNode); ok { - // Use SecureText to avoid leak password information. 
- sql = ss.SecureText() - } - } else if sn, ok2 := a.StmtNode.(ast.SensitiveStmtNode); ok2 { - // such as import into statement - sql = sn.SecureText() - } - return sql -} - -func (a *ExecStmt) handleStmtForeignKeyTrigger(ctx context.Context, e exec.Executor) error { - stmtCtx := a.Ctx.GetSessionVars().StmtCtx - if stmtCtx.ForeignKeyTriggerCtx.HasFKCascades { - // If the ExecStmt has foreign key cascade to be executed, we need call `StmtCommit` to commit the ExecStmt itself - // change first. - // Since `UnionScanExec` use `SnapshotIter` and `SnapshotGetter` to read txn mem-buffer, if we don't do `StmtCommit`, - // then the fk cascade executor can't read the mem-buffer changed by the ExecStmt. - a.Ctx.StmtCommit(ctx) - } - err := a.handleForeignKeyTrigger(ctx, e, 1) - if err != nil { - err1 := a.handleFKTriggerError(stmtCtx) - if err1 != nil { - return errors.Errorf("handle foreign key trigger error failed, err: %v, original_err: %v", err1, err) - } - return err - } - if stmtCtx.ForeignKeyTriggerCtx.SavepointName != "" { - a.Ctx.GetSessionVars().TxnCtx.ReleaseSavepoint(stmtCtx.ForeignKeyTriggerCtx.SavepointName) - } - return nil -} - -var maxForeignKeyCascadeDepth = 15 - -func (a *ExecStmt) handleForeignKeyTrigger(ctx context.Context, e exec.Executor, depth int) error { - exec, ok := e.(WithForeignKeyTrigger) - if !ok { - return nil - } - fkChecks := exec.GetFKChecks() - for _, fkCheck := range fkChecks { - err := fkCheck.doCheck(ctx) - if err != nil { - return err - } - } - fkCascades := exec.GetFKCascades() - for _, fkCascade := range fkCascades { - err := a.handleForeignKeyCascade(ctx, fkCascade, depth) - if err != nil { - return err - } - } - return nil -} - -// handleForeignKeyCascade uses to execute foreign key cascade behaviour, the progress is: -// 1. Build delete/update executor for foreign key on delete/update behaviour. -// a. Construct delete/update AST. We used to try generated SQL string first and then parse the SQL to get AST, -// but we need convert Datum to string, there may be some risks here, since assert_eq(datum_a, parse(datum_a.toString())) may be broken. -// so we chose to construct AST directly. -// b. Build plan by the delete/update AST. -// c. Build executor by the delete/update plan. -// 2. Execute the delete/update executor. -// 3. Close the executor. -// 4. `StmtCommit` to commit the kv change to transaction mem-buffer. -// 5. If the foreign key cascade behaviour has more fk value need to be cascaded, go to step 1. 
-func (a *ExecStmt) handleForeignKeyCascade(ctx context.Context, fkc *FKCascadeExec, depth int) error { - if a.Ctx.GetSessionVars().StmtCtx.RuntimeStatsColl != nil { - fkc.stats = &FKCascadeRuntimeStats{} - defer a.Ctx.GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(fkc.plan.ID(), fkc.stats) - } - if len(fkc.fkValues) == 0 && len(fkc.fkUpdatedValuesMap) == 0 { - return nil - } - if depth > maxForeignKeyCascadeDepth { - return exeerrors.ErrForeignKeyCascadeDepthExceeded.GenWithStackByArgs(maxForeignKeyCascadeDepth) - } - a.Ctx.GetSessionVars().StmtCtx.InHandleForeignKeyTrigger = true - defer func() { - a.Ctx.GetSessionVars().StmtCtx.InHandleForeignKeyTrigger = false - }() - if fkc.stats != nil { - start := time.Now() - defer func() { - fkc.stats.Total += time.Since(start) - }() - } - for { - e, err := fkc.buildExecutor(ctx) - if err != nil || e == nil { - return err - } - if err := exec.Open(ctx, e); err != nil { - terror.Log(exec.Close(e)) - return err - } - err = exec.Next(ctx, e, exec.NewFirstChunk(e)) - failpoint.Inject("handleForeignKeyCascadeError", func(val failpoint.Value) { - // Next can recover panic and convert it to error. So we inject error directly here. - if val.(bool) && err == nil { - err = errors.New("handleForeignKeyCascadeError") - } - }) - closeErr := exec.Close(e) - if err == nil { - err = closeErr - } - if err != nil { - return err - } - // Call `StmtCommit` uses to flush the fk cascade executor change into txn mem-buffer, - // then the later fk cascade executors can see the mem-buffer changes. - a.Ctx.StmtCommit(ctx) - err = a.handleForeignKeyTrigger(ctx, e, depth+1) - if err != nil { - return err - } - } -} - -// prepareFKCascadeContext records a transaction savepoint for foreign key cascade when this ExecStmt has foreign key -// cascade behaviour and this ExecStmt is in transaction. -func (a *ExecStmt) prepareFKCascadeContext(e exec.Executor) { - exec, ok := e.(WithForeignKeyTrigger) - if !ok || !exec.HasFKCascades() { - return - } - sessVar := a.Ctx.GetSessionVars() - sessVar.StmtCtx.ForeignKeyTriggerCtx.HasFKCascades = true - if !sessVar.InTxn() { - return - } - txn, err := a.Ctx.Txn(false) - if err != nil || !txn.Valid() { - return - } - // Record a txn savepoint if ExecStmt in transaction, the savepoint is use to do rollback when handle foreign key - // cascade failed. - savepointName := "fk_sp_" + strconv.FormatUint(txn.StartTS(), 10) - memDBCheckpoint := txn.GetMemDBCheckpoint() - sessVar.TxnCtx.AddSavepoint(savepointName, memDBCheckpoint) - sessVar.StmtCtx.ForeignKeyTriggerCtx.SavepointName = savepointName -} - -func (a *ExecStmt) handleFKTriggerError(sc *stmtctx.StatementContext) error { - if sc.ForeignKeyTriggerCtx.SavepointName == "" { - return nil - } - txn, err := a.Ctx.Txn(false) - if err != nil || !txn.Valid() { - return err - } - savepointRecord := a.Ctx.GetSessionVars().TxnCtx.RollbackToSavepoint(sc.ForeignKeyTriggerCtx.SavepointName) - if savepointRecord == nil { - // Normally should never run into here, but just in case, rollback the transaction. 
- err = txn.Rollback() - if err != nil { - return err - } - return errors.Errorf("foreign key cascade savepoint '%s' not found, transaction is rollback, should never happen", sc.ForeignKeyTriggerCtx.SavepointName) - } - txn.RollbackMemDBToCheckpoint(savepointRecord.MemDBCheckpoint) - a.Ctx.GetSessionVars().TxnCtx.ReleaseSavepoint(sc.ForeignKeyTriggerCtx.SavepointName) - return nil -} - -func (a *ExecStmt) handleNoDelay(ctx context.Context, e exec.Executor, isPessimistic bool) (handled bool, rs sqlexec.RecordSet, err error) { - sc := a.Ctx.GetSessionVars().StmtCtx - defer func() { - // If the stmt have no rs like `insert`, The session tracker detachment will be directly - // done in the `defer` function. If the rs is not nil, the detachment will be done in - // `rs.Close` in `handleStmt` - if handled && sc != nil && rs == nil { - sc.DetachMemDiskTracker() - cteErr := resetCTEStorageMap(a.Ctx) - if err == nil { - // Only overwrite err when it's nil. - err = cteErr - } - } - }() - - toCheck := e - isExplainAnalyze := false - if explain, ok := e.(*ExplainExec); ok { - if analyze := explain.getAnalyzeExecToExecutedNoDelay(); analyze != nil { - toCheck = analyze - isExplainAnalyze = true - a.Ctx.GetSessionVars().StmtCtx.IsExplainAnalyzeDML = isExplainAnalyze - } - } - - // If the executor doesn't return any result to the client, we execute it without delay. - if toCheck.Schema().Len() == 0 { - handled = !isExplainAnalyze - if isPessimistic { - err := a.handlePessimisticDML(ctx, toCheck) - return handled, nil, err - } - r, err := a.handleNoDelayExecutor(ctx, toCheck) - return handled, r, err - } else if proj, ok := toCheck.(*ProjectionExec); ok && proj.calculateNoDelay { - // Currently this is only for the "DO" statement. Take "DO 1, @a=2;" as an example: - // the Projection has two expressions and two columns in the schema, but we should - // not return the result of the two expressions. - r, err := a.handleNoDelayExecutor(ctx, e) - return true, r, err - } - - return false, nil, nil -} - -func isNoResultPlan(p base.Plan) bool { - if p.Schema().Len() == 0 { - return true - } - - // Currently this is only for the "DO" statement. Take "DO 1, @a=2;" as an example: - // the Projection has two expressions and two columns in the schema, but we should - // not return the result of the two expressions. 
- switch raw := p.(type) { - case *logicalop.LogicalProjection: - if raw.CalculateNoDelay { - return true - } - case *plannercore.PhysicalProjection: - if raw.CalculateNoDelay { - return true - } - } - return false -} - -type chunkRowRecordSet struct { - rows []chunk.Row - idx int - fields []*ast.ResultField - e exec.Executor - execStmt *ExecStmt -} - -func (c *chunkRowRecordSet) Fields() []*ast.ResultField { - if c.fields == nil { - c.fields = colNames2ResultFields(c.e.Schema(), c.execStmt.OutputNames, c.execStmt.Ctx.GetSessionVars().CurrentDB) - } - return c.fields -} - -func (c *chunkRowRecordSet) Next(_ context.Context, chk *chunk.Chunk) error { - chk.Reset() - if !chk.IsFull() && c.idx < len(c.rows) { - numToAppend := min(len(c.rows)-c.idx, chk.RequiredRows()-chk.NumRows()) - chk.AppendRows(c.rows[c.idx : c.idx+numToAppend]) - c.idx += numToAppend - } - return nil -} - -func (c *chunkRowRecordSet) NewChunk(alloc chunk.Allocator) *chunk.Chunk { - if alloc == nil { - return exec.NewFirstChunk(c.e) - } - - return alloc.Alloc(c.e.RetFieldTypes(), c.e.InitCap(), c.e.MaxChunkSize()) -} - -func (c *chunkRowRecordSet) Close() error { - c.execStmt.CloseRecordSet(c.execStmt.Ctx.GetSessionVars().TxnCtx.StartTS, nil) - return nil -} - -func (a *ExecStmt) handlePessimisticSelectForUpdate(ctx context.Context, e exec.Executor) (_ sqlexec.RecordSet, retErr error) { - if snapshotTS := a.Ctx.GetSessionVars().SnapshotTS; snapshotTS != 0 { - terror.Log(exec.Close(e)) - return nil, errors.New("can not execute write statement when 'tidb_snapshot' is set") - } - - txnManager := sessiontxn.GetTxnManager(a.Ctx) - err := txnManager.OnPessimisticStmtStart(ctx) - if err != nil { - return nil, err - } - defer func() { - isSuccessful := retErr == nil - err1 := txnManager.OnPessimisticStmtEnd(ctx, isSuccessful) - if retErr == nil && err1 != nil { - retErr = err1 - } - }() - - isFirstAttempt := true - - for { - startTime := time.Now() - rs, err := a.runPessimisticSelectForUpdate(ctx, e) - - if isFirstAttempt { - executor_metrics.SelectForUpdateFirstAttemptDuration.Observe(time.Since(startTime).Seconds()) - isFirstAttempt = false - } else { - executor_metrics.SelectForUpdateRetryDuration.Observe(time.Since(startTime).Seconds()) - } - - e, err = a.handlePessimisticLockError(ctx, err) - if err != nil { - return nil, err - } - if e == nil { - return rs, nil - } - - failpoint.Inject("pessimisticSelectForUpdateRetry", nil) - } -} - -func (a *ExecStmt) runPessimisticSelectForUpdate(ctx context.Context, e exec.Executor) (sqlexec.RecordSet, error) { - defer func() { - terror.Log(exec.Close(e)) - }() - var rows []chunk.Row - var err error - req := exec.TryNewCacheChunk(e) - for { - err = a.next(ctx, e, req) - if err != nil { - // Handle 'write conflict' error. - break - } - if req.NumRows() == 0 { - return &chunkRowRecordSet{rows: rows, e: e, execStmt: a}, nil - } - iter := chunk.NewIterator4Chunk(req) - for r := iter.Begin(); r != iter.End(); r = iter.Next() { - rows = append(rows, r) - } - req = chunk.Renew(req, a.Ctx.GetSessionVars().MaxChunkSize) - } - return nil, err -} - -func (a *ExecStmt) handleNoDelayExecutor(ctx context.Context, e exec.Executor) (sqlexec.RecordSet, error) { - sctx := a.Ctx - r, ctx := tracing.StartRegionEx(ctx, "executor.handleNoDelayExecutor") - defer r.End() - - var err error - defer func() { - terror.Log(exec.Close(e)) - a.logAudit() - }() - - // Check if "tidb_snapshot" is set for the write executors. - // In history read mode, we can not do write operations. 
- switch e.(type) { - case *DeleteExec, *InsertExec, *UpdateExec, *ReplaceExec, *LoadDataExec, *DDLExec, *ImportIntoExec: - snapshotTS := sctx.GetSessionVars().SnapshotTS - if snapshotTS != 0 { - return nil, errors.New("can not execute write statement when 'tidb_snapshot' is set") - } - if sctx.GetSessionVars().UseLowResolutionTSO() { - return nil, errors.New("can not execute write statement when 'tidb_low_resolution_tso' is set") - } - } - - err = a.next(ctx, e, exec.TryNewCacheChunk(e)) - if err != nil { - return nil, err - } - err = a.handleStmtForeignKeyTrigger(ctx, e) - return nil, err -} - -func (a *ExecStmt) handlePessimisticDML(ctx context.Context, e exec.Executor) (err error) { - sctx := a.Ctx - // Do not activate the transaction here. - // When autocommit = 0 and transaction in pessimistic mode, - // statements like set xxx = xxx; should not active the transaction. - txn, err := sctx.Txn(false) - if err != nil { - return err - } - txnCtx := sctx.GetSessionVars().TxnCtx - defer func() { - if err != nil && !sctx.GetSessionVars().ConstraintCheckInPlacePessimistic && sctx.GetSessionVars().InTxn() { - // If it's not a retryable error, rollback current transaction instead of rolling back current statement like - // in normal transactions, because we cannot locate and rollback the statement that leads to the lock error. - // This is too strict, but since the feature is not for everyone, it's the easiest way to guarantee safety. - stmtText := parser.Normalize(a.OriginText(), sctx.GetSessionVars().EnableRedactLog) - logutil.Logger(ctx).Info("Transaction abort for the safety of lazy uniqueness check. "+ - "Note this may not be a uniqueness violation.", - zap.Error(err), - zap.String("statement", stmtText), - zap.Uint64("conn", sctx.GetSessionVars().ConnectionID), - zap.Uint64("txnStartTS", txnCtx.StartTS), - zap.Uint64("forUpdateTS", txnCtx.GetForUpdateTS()), - ) - sctx.GetSessionVars().SetInTxn(false) - err = exeerrors.ErrLazyUniquenessCheckFailure.GenWithStackByArgs(err.Error()) - } - }() - - txnManager := sessiontxn.GetTxnManager(a.Ctx) - err = txnManager.OnPessimisticStmtStart(ctx) - if err != nil { - return err - } - defer func() { - isSuccessful := err == nil - err1 := txnManager.OnPessimisticStmtEnd(ctx, isSuccessful) - if err == nil && err1 != nil { - err = err1 - } - }() - - isFirstAttempt := true - - for { - if !isFirstAttempt { - failpoint.Inject("pessimisticDMLRetry", nil) - } - - startTime := time.Now() - _, err = a.handleNoDelayExecutor(ctx, e) - if !txn.Valid() { - return err - } - - if isFirstAttempt { - executor_metrics.DmlFirstAttemptDuration.Observe(time.Since(startTime).Seconds()) - isFirstAttempt = false - } else { - executor_metrics.DmlRetryDuration.Observe(time.Since(startTime).Seconds()) - } - - if err != nil { - // It is possible the DML has point get plan that locks the key. 
- e, err = a.handlePessimisticLockError(ctx, err) - if err != nil { - if exeerrors.ErrDeadlock.Equal(err) { - metrics.StatementDeadlockDetectDuration.Observe(time.Since(startTime).Seconds()) - } - return err - } - continue - } - keys, err1 := txn.(pessimisticTxn).KeysNeedToLock() - if err1 != nil { - return err1 - } - keys = txnCtx.CollectUnchangedKeysForLock(keys) - if len(keys) == 0 { - return nil - } - keys = filterTemporaryTableKeys(sctx.GetSessionVars(), keys) - seVars := sctx.GetSessionVars() - keys = filterLockTableKeys(seVars.StmtCtx, keys) - lockCtx, err := newLockCtx(sctx, seVars.LockWaitTimeout, len(keys)) - if err != nil { - return err - } - var lockKeyStats *util.LockKeysDetails - ctx = context.WithValue(ctx, util.LockKeysDetailCtxKey, &lockKeyStats) - startLocking := time.Now() - err = txn.LockKeys(ctx, lockCtx, keys...) - a.phaseLockDurations[0] += time.Since(startLocking) - if e.RuntimeStats() != nil { - e.RuntimeStats().Record(time.Since(startLocking), 0) - } - if lockKeyStats != nil { - seVars.StmtCtx.MergeLockKeysExecDetails(lockKeyStats) - } - if err == nil { - return nil - } - e, err = a.handlePessimisticLockError(ctx, err) - if err != nil { - // todo: Report deadlock - if exeerrors.ErrDeadlock.Equal(err) { - metrics.StatementDeadlockDetectDuration.Observe(time.Since(startLocking).Seconds()) - } - return err - } - } -} - -// handlePessimisticLockError updates TS and rebuild executor if the err is write conflict. -func (a *ExecStmt) handlePessimisticLockError(ctx context.Context, lockErr error) (_ exec.Executor, err error) { - if lockErr == nil { - return nil, nil - } - failpoint.Inject("assertPessimisticLockErr", func() { - if terror.ErrorEqual(kv.ErrWriteConflict, lockErr) { - sessiontxn.AddAssertEntranceForLockError(a.Ctx, "errWriteConflict") - } else if terror.ErrorEqual(kv.ErrKeyExists, lockErr) { - sessiontxn.AddAssertEntranceForLockError(a.Ctx, "errDuplicateKey") - } - }) - - defer func() { - if _, ok := errors.Cause(err).(*tikverr.ErrDeadlock); ok { - err = exeerrors.ErrDeadlock - } - }() - - txnManager := sessiontxn.GetTxnManager(a.Ctx) - action, err := txnManager.OnStmtErrorForNextAction(ctx, sessiontxn.StmtErrAfterPessimisticLock, lockErr) - if err != nil { - return nil, err - } - - if action != sessiontxn.StmtActionRetryReady { - return nil, lockErr - } - - if a.retryCount >= config.GetGlobalConfig().PessimisticTxn.MaxRetryCount { - return nil, errors.New("pessimistic lock retry limit reached") - } - a.retryCount++ - a.retryStartTime = time.Now() - - err = txnManager.OnStmtRetry(ctx) - if err != nil { - return nil, err - } - - // Without this line of code, the result will still be correct. But it can ensure that the update time of for update read - // is determined which is beneficial for testing. - if _, err = txnManager.GetStmtForUpdateTS(); err != nil { - return nil, err - } - - breakpoint.Inject(a.Ctx, sessiontxn.BreakPointOnStmtRetryAfterLockError) - - a.resetPhaseDurations() - - a.inheritContextFromExecuteStmt() - e, err := a.buildExecutor() - if err != nil { - return nil, err - } - // Rollback the statement change before retry it. 
- a.Ctx.StmtRollback(ctx, true) - a.Ctx.GetSessionVars().StmtCtx.ResetForRetry() - a.Ctx.GetSessionVars().RetryInfo.ResetOffset() - - failpoint.Inject("assertTxnManagerAfterPessimisticLockErrorRetry", func() { - sessiontxn.RecordAssert(a.Ctx, "assertTxnManagerAfterPessimisticLockErrorRetry", true) - }) - - if err = a.openExecutor(ctx, e); err != nil { - return nil, err - } - return e, nil -} - -type pessimisticTxn interface { - kv.Transaction - // KeysNeedToLock returns the keys need to be locked. - KeysNeedToLock() ([]kv.Key, error) -} - -// buildExecutor build an executor from plan, prepared statement may need additional procedure. -func (a *ExecStmt) buildExecutor() (exec.Executor, error) { - defer func(start time.Time) { a.phaseBuildDurations[0] += time.Since(start) }(time.Now()) - ctx := a.Ctx - stmtCtx := ctx.GetSessionVars().StmtCtx - if _, ok := a.Plan.(*plannercore.Execute); !ok { - if stmtCtx.Priority == mysql.NoPriority && a.LowerPriority { - stmtCtx.Priority = kv.PriorityLow - } - } - if _, ok := a.Plan.(*plannercore.Analyze); ok && ctx.GetSessionVars().InRestrictedSQL { - ctx.GetSessionVars().StmtCtx.Priority = kv.PriorityLow - } - - b := newExecutorBuilder(ctx, a.InfoSchema) - e := b.build(a.Plan) - if b.err != nil { - return nil, errors.Trace(b.err) - } - - failpoint.Inject("assertTxnManagerAfterBuildExecutor", func() { - sessiontxn.RecordAssert(a.Ctx, "assertTxnManagerAfterBuildExecutor", true) - sessiontxn.AssertTxnManagerInfoSchema(b.ctx, b.is) - }) - - // ExecuteExec is not a real Executor, we only use it to build another Executor from a prepared statement. - if executorExec, ok := e.(*ExecuteExec); ok { - err := executorExec.Build(b) - if err != nil { - return nil, err - } - if executorExec.lowerPriority { - ctx.GetSessionVars().StmtCtx.Priority = kv.PriorityLow - } - e = executorExec.stmtExec - } - a.isSelectForUpdate = b.hasLock && (!stmtCtx.InDeleteStmt && !stmtCtx.InUpdateStmt && !stmtCtx.InInsertStmt) - return e, nil -} - -func (a *ExecStmt) openExecutor(ctx context.Context, e exec.Executor) (err error) { - defer func() { - if r := recover(); r != nil { - err = util2.GetRecoverError(r) - } - }() - start := time.Now() - err = exec.Open(ctx, e) - a.phaseOpenDurations[0] += time.Since(start) - return err -} - -func (a *ExecStmt) next(ctx context.Context, e exec.Executor, req *chunk.Chunk) error { - start := time.Now() - err := exec.Next(ctx, e, req) - a.phaseNextDurations[0] += time.Since(start) - return err -} - -func (a *ExecStmt) resetPhaseDurations() { - a.phaseBuildDurations[1] += a.phaseBuildDurations[0] - a.phaseBuildDurations[0] = 0 - a.phaseOpenDurations[1] += a.phaseOpenDurations[0] - a.phaseOpenDurations[0] = 0 - a.phaseNextDurations[1] += a.phaseNextDurations[0] - a.phaseNextDurations[0] = 0 - a.phaseLockDurations[1] += a.phaseLockDurations[0] - a.phaseLockDurations[0] = 0 -} - -// QueryReplacer replaces new line and tab for grep result including query string. 
-var QueryReplacer = strings.NewReplacer("\r", " ", "\n", " ", "\t", " ") - -func (a *ExecStmt) logAudit() { - sessVars := a.Ctx.GetSessionVars() - if sessVars.InRestrictedSQL { - return - } - - err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { - audit := plugin.DeclareAuditManifest(p.Manifest) - if audit.OnGeneralEvent != nil { - cmd := mysql.Command2Str[byte(atomic.LoadUint32(&a.Ctx.GetSessionVars().CommandValue))] - ctx := context.WithValue(context.Background(), plugin.ExecStartTimeCtxKey, a.Ctx.GetSessionVars().StartTime) - audit.OnGeneralEvent(ctx, sessVars, plugin.Completed, cmd) - } - return nil - }) - if err != nil { - log.Error("log audit log failure", zap.Error(err)) - } -} - -// FormatSQL is used to format the original SQL, e.g. truncating long SQL, appending prepared arguments. -func FormatSQL(sql string) stringutil.StringerFunc { - return func() string { return formatSQL(sql) } -} - -func formatSQL(sql string) string { - length := len(sql) - maxQueryLen := variable.QueryLogMaxLen.Load() - if maxQueryLen <= 0 { - return QueryReplacer.Replace(sql) // no limit - } - if int32(length) > maxQueryLen { - var result strings.Builder - result.Grow(int(maxQueryLen)) - result.WriteString(sql[:maxQueryLen]) - fmt.Fprintf(&result, "(len:%d)", length) - return QueryReplacer.Replace(result.String()) - } - return QueryReplacer.Replace(sql) -} - -func getPhaseDurationObserver(phase string, internal bool) prometheus.Observer { - if internal { - if ob, found := executor_metrics.PhaseDurationObserverMapInternal[phase]; found { - return ob - } - return executor_metrics.ExecUnknownInternal - } - if ob, found := executor_metrics.PhaseDurationObserverMap[phase]; found { - return ob - } - return executor_metrics.ExecUnknown -} - -func (a *ExecStmt) observePhaseDurations(internal bool, commitDetails *util.CommitDetails) { - for _, it := range []struct { - duration time.Duration - phase string - }{ - {a.phaseBuildDurations[0], executor_metrics.PhaseBuildFinal}, - {a.phaseBuildDurations[1], executor_metrics.PhaseBuildLocking}, - {a.phaseOpenDurations[0], executor_metrics.PhaseOpenFinal}, - {a.phaseOpenDurations[1], executor_metrics.PhaseOpenLocking}, - {a.phaseNextDurations[0], executor_metrics.PhaseNextFinal}, - {a.phaseNextDurations[1], executor_metrics.PhaseNextLocking}, - {a.phaseLockDurations[0], executor_metrics.PhaseLockFinal}, - {a.phaseLockDurations[1], executor_metrics.PhaseLockLocking}, - } { - if it.duration > 0 { - getPhaseDurationObserver(it.phase, internal).Observe(it.duration.Seconds()) - } - } - if commitDetails != nil { - for _, it := range []struct { - duration time.Duration - phase string - }{ - {commitDetails.PrewriteTime, executor_metrics.PhaseCommitPrewrite}, - {commitDetails.CommitTime, executor_metrics.PhaseCommitCommit}, - {commitDetails.GetCommitTsTime, executor_metrics.PhaseCommitWaitCommitTS}, - {commitDetails.GetLatestTsTime, executor_metrics.PhaseCommitWaitLatestTS}, - {commitDetails.LocalLatchTime, executor_metrics.PhaseCommitWaitLatch}, - {commitDetails.WaitPrewriteBinlogTime, executor_metrics.PhaseCommitWaitBinlog}, - } { - if it.duration > 0 { - getPhaseDurationObserver(it.phase, internal).Observe(it.duration.Seconds()) - } - } - } - if stmtDetailsRaw := a.GoCtx.Value(execdetails.StmtExecDetailKey); stmtDetailsRaw != nil { - d := stmtDetailsRaw.(*execdetails.StmtExecDetails).WriteSQLRespDuration - if d > 0 { - getPhaseDurationObserver(executor_metrics.PhaseWriteResponse, internal).Observe(d.Seconds()) - } - } -} - -// FinishExecuteStmt is used to 
record some information after `ExecStmt` execution finished: -// 1. record slow log if needed. -// 2. record summary statement. -// 3. record execute duration metric. -// 4. update the `PrevStmt` in session variable. -// 5. reset `DurationParse` in session variable. -func (a *ExecStmt) FinishExecuteStmt(txnTS uint64, err error, hasMoreResults bool) { - a.checkPlanReplayerCapture(txnTS) - - sessVars := a.Ctx.GetSessionVars() - execDetail := sessVars.StmtCtx.GetExecDetails() - // Attach commit/lockKeys runtime stats to executor runtime stats. - if (execDetail.CommitDetail != nil || execDetail.LockKeysDetail != nil) && sessVars.StmtCtx.RuntimeStatsColl != nil { - statsWithCommit := &execdetails.RuntimeStatsWithCommit{ - Commit: execDetail.CommitDetail, - LockKeys: execDetail.LockKeysDetail, - } - sessVars.StmtCtx.RuntimeStatsColl.RegisterStats(a.Plan.ID(), statsWithCommit) - } - // Record related SLI metrics. - if execDetail.CommitDetail != nil && execDetail.CommitDetail.WriteSize > 0 { - a.Ctx.GetTxnWriteThroughputSLI().AddTxnWriteSize(execDetail.CommitDetail.WriteSize, execDetail.CommitDetail.WriteKeys) - } - if execDetail.ScanDetail != nil && sessVars.StmtCtx.AffectedRows() > 0 { - processedKeys := atomic.LoadInt64(&execDetail.ScanDetail.ProcessedKeys) - if processedKeys > 0 { - // Only record the read keys in write statement which affect row more than 0. - a.Ctx.GetTxnWriteThroughputSLI().AddReadKeys(processedKeys) - } - } - succ := err == nil - if a.Plan != nil { - // If this statement has a Plan, the StmtCtx.plan should have been set when it comes here, - // but we set it again in case we missed some code paths. - sessVars.StmtCtx.SetPlan(a.Plan) - } - // `LowSlowQuery` and `SummaryStmt` must be called before recording `PrevStmt`. - a.LogSlowQuery(txnTS, succ, hasMoreResults) - a.SummaryStmt(succ) - a.observeStmtFinishedForTopSQL() - if sessVars.StmtCtx.IsTiFlash.Load() { - if succ { - executor_metrics.TotalTiFlashQuerySuccCounter.Inc() - } else { - metrics.TiFlashQueryTotalCounter.WithLabelValues(metrics.ExecuteErrorToLabel(err), metrics.LblError).Inc() - } - } - a.updatePrevStmt() - a.recordLastQueryInfo(err) - a.observePhaseDurations(sessVars.InRestrictedSQL, execDetail.CommitDetail) - executeDuration := sessVars.GetExecuteDuration() - if sessVars.InRestrictedSQL { - executor_metrics.SessionExecuteRunDurationInternal.Observe(executeDuration.Seconds()) - } else { - executor_metrics.SessionExecuteRunDurationGeneral.Observe(executeDuration.Seconds()) - } - // Reset DurationParse due to the next statement may not need to be parsed (not a text protocol query). 
- sessVars.DurationParse = 0 - // Clean the stale read flag when statement execution finish - sessVars.StmtCtx.IsStaleness = false - // Clean the MPP query info - sessVars.StmtCtx.MPPQueryInfo.QueryID.Store(0) - sessVars.StmtCtx.MPPQueryInfo.QueryTS.Store(0) - sessVars.StmtCtx.MPPQueryInfo.AllocatedMPPTaskID.Store(0) - sessVars.StmtCtx.MPPQueryInfo.AllocatedMPPGatherID.Store(0) - - if sessVars.StmtCtx.ReadFromTableCache { - metrics.ReadFromTableCacheCounter.Inc() - } - - // Update fair locking related counters by stmt - if execDetail.LockKeysDetail != nil { - if execDetail.LockKeysDetail.AggressiveLockNewCount > 0 || execDetail.LockKeysDetail.AggressiveLockDerivedCount > 0 { - executor_metrics.FairLockingStmtUsedCount.Inc() - // If this statement is finished when some of the keys are locked with conflict in the last retry, or - // some of the keys are derived from the previous retry, we consider the optimization of fair locking - // takes effect on this statement. - if execDetail.LockKeysDetail.LockedWithConflictCount > 0 || execDetail.LockKeysDetail.AggressiveLockDerivedCount > 0 { - executor_metrics.FairLockingStmtEffectiveCount.Inc() - } - } - } - // If the transaction is committed, update fair locking related counters by txn - if execDetail.CommitDetail != nil { - if sessVars.TxnCtx.FairLockingUsed { - executor_metrics.FairLockingTxnUsedCount.Inc() - } - if sessVars.TxnCtx.FairLockingEffective { - executor_metrics.FairLockingTxnEffectiveCount.Inc() - } - } - - a.Ctx.ReportUsageStats() -} - -func (a *ExecStmt) recordLastQueryInfo(err error) { - sessVars := a.Ctx.GetSessionVars() - // Record diagnostic information for DML statements - recordLastQuery := false - switch typ := a.StmtNode.(type) { - case *ast.ShowStmt: - recordLastQuery = typ.Tp != ast.ShowSessionStates - case *ast.ExecuteStmt, ast.DMLNode: - recordLastQuery = true - } - if recordLastQuery { - var lastRUConsumption float64 - if ruDetailRaw := a.GoCtx.Value(util.RUDetailsCtxKey); ruDetailRaw != nil { - ruDetail := ruDetailRaw.(*util.RUDetails) - lastRUConsumption = ruDetail.RRU() + ruDetail.WRU() - } - failpoint.Inject("mockRUConsumption", func(_ failpoint.Value) { - lastRUConsumption = float64(len(sessVars.StmtCtx.OriginalSQL)) - }) - // Keep the previous queryInfo for `show session_states` because the statement needs to encode it. - sessVars.LastQueryInfo = sessionstates.QueryInfo{ - TxnScope: sessVars.CheckAndGetTxnScope(), - StartTS: sessVars.TxnCtx.StartTS, - ForUpdateTS: sessVars.TxnCtx.GetForUpdateTS(), - RUConsumption: lastRUConsumption, - } - if err != nil { - sessVars.LastQueryInfo.ErrMsg = err.Error() - } - } -} - -func (a *ExecStmt) checkPlanReplayerCapture(txnTS uint64) { - if kv.GetInternalSourceType(a.GoCtx) == kv.InternalTxnStats { - return - } - se := a.Ctx - if !se.GetSessionVars().InRestrictedSQL && se.GetSessionVars().IsPlanReplayerCaptureEnabled() { - stmtNode := a.GetStmtNode() - if se.GetSessionVars().EnablePlanReplayedContinuesCapture { - if checkPlanReplayerContinuesCaptureValidStmt(stmtNode) { - checkPlanReplayerContinuesCapture(se, stmtNode, txnTS) - } - } else { - checkPlanReplayerCaptureTask(se, stmtNode, txnTS) - } - } -} - -// CloseRecordSet will finish the execution of current statement and do some record work -func (a *ExecStmt) CloseRecordSet(txnStartTS uint64, lastErr error) { - a.FinishExecuteStmt(txnStartTS, lastErr, false) - a.logAudit() - a.Ctx.GetSessionVars().StmtCtx.DetachMemDiskTracker() -} - -// Clean CTE storage shared by different CTEFullScan executor within a SQL stmt. 
-// Will return err in two situations: -// 1. Got err when remove disk spill file. -// 2. Some logical error like ref count of CTEStorage is less than 0. -func resetCTEStorageMap(se sessionctx.Context) error { - tmp := se.GetSessionVars().StmtCtx.CTEStorageMap - if tmp == nil { - // Close() is already called, so no need to reset. Such as TraceExec. - return nil - } - storageMap, ok := tmp.(map[int]*CTEStorages) - if !ok { - return errors.New("type assertion for CTEStorageMap failed") - } - for _, v := range storageMap { - v.ResTbl.Lock() - err1 := v.ResTbl.DerefAndClose() - // Make sure we do not hold the lock for longer than necessary. - v.ResTbl.Unlock() - // No need to lock IterInTbl. - err2 := v.IterInTbl.DerefAndClose() - if err1 != nil { - return err1 - } - if err2 != nil { - return err2 - } - } - se.GetSessionVars().StmtCtx.CTEStorageMap = nil - return nil -} - -// LogSlowQuery is used to print the slow query in the log files. -func (a *ExecStmt) LogSlowQuery(txnTS uint64, succ bool, hasMoreResults bool) { - sessVars := a.Ctx.GetSessionVars() - stmtCtx := sessVars.StmtCtx - level := log.GetLevel() - cfg := config.GetGlobalConfig() - costTime := sessVars.GetTotalCostDuration() - threshold := time.Duration(atomic.LoadUint64(&cfg.Instance.SlowThreshold)) * time.Millisecond - enable := cfg.Instance.EnableSlowLog.Load() - // if the level is Debug, or trace is enabled, print slow logs anyway - force := level <= zapcore.DebugLevel || trace.IsEnabled() - if (!enable || costTime < threshold) && !force { - return - } - sql := FormatSQL(a.GetTextToLog(true)) - _, digest := stmtCtx.SQLDigest() - - var indexNames string - if len(stmtCtx.IndexNames) > 0 { - // remove duplicate index. - idxMap := make(map[string]struct{}) - buf := bytes.NewBuffer(make([]byte, 0, 4)) - buf.WriteByte('[') - for _, idx := range stmtCtx.IndexNames { - _, ok := idxMap[idx] - if ok { - continue - } - idxMap[idx] = struct{}{} - if buf.Len() > 1 { - buf.WriteByte(',') - } - buf.WriteString(idx) - } - buf.WriteByte(']') - indexNames = buf.String() - } - var stmtDetail execdetails.StmtExecDetails - stmtDetailRaw := a.GoCtx.Value(execdetails.StmtExecDetailKey) - if stmtDetailRaw != nil { - stmtDetail = *(stmtDetailRaw.(*execdetails.StmtExecDetails)) - } - var tikvExecDetail util.ExecDetails - tikvExecDetailRaw := a.GoCtx.Value(util.ExecDetailsKey) - if tikvExecDetailRaw != nil { - tikvExecDetail = *(tikvExecDetailRaw.(*util.ExecDetails)) - } - ruDetails := util.NewRUDetails() - if ruDetailsVal := a.GoCtx.Value(util.RUDetailsCtxKey); ruDetailsVal != nil { - ruDetails = ruDetailsVal.(*util.RUDetails) - } - - execDetail := stmtCtx.GetExecDetails() - copTaskInfo := stmtCtx.CopTasksDetails() - memMax := sessVars.MemTracker.MaxConsumed() - diskMax := sessVars.DiskTracker.MaxConsumed() - _, planDigest := GetPlanDigest(stmtCtx) - - binaryPlan := "" - if variable.GenerateBinaryPlan.Load() { - binaryPlan = getBinaryPlan(a.Ctx) - if len(binaryPlan) > 0 { - binaryPlan = variable.SlowLogBinaryPlanPrefix + binaryPlan + variable.SlowLogPlanSuffix - } - } - - resultRows := GetResultRowsCount(stmtCtx, a.Plan) - - var ( - keyspaceName string - keyspaceID uint32 - ) - keyspaceName = keyspace.GetKeyspaceNameBySettings() - if !keyspace.IsKeyspaceNameEmpty(keyspaceName) { - keyspaceID = uint32(a.Ctx.GetStore().GetCodec().GetKeyspaceID()) - } - if txnTS == 0 { - // TODO: txnTS maybe ambiguous, consider logging stale-read-ts with a new field in the slow log. 
- txnTS = sessVars.TxnCtx.StaleReadTs - } - - slowItems := &variable.SlowQueryLogItems{ - TxnTS: txnTS, - KeyspaceName: keyspaceName, - KeyspaceID: keyspaceID, - SQL: sql.String(), - Digest: digest.String(), - TimeTotal: costTime, - TimeParse: sessVars.DurationParse, - TimeCompile: sessVars.DurationCompile, - TimeOptimize: sessVars.DurationOptimization, - TimeWaitTS: sessVars.DurationWaitTS, - IndexNames: indexNames, - CopTasks: &copTaskInfo, - ExecDetail: execDetail, - MemMax: memMax, - DiskMax: diskMax, - Succ: succ, - Plan: getPlanTree(stmtCtx), - PlanDigest: planDigest.String(), - BinaryPlan: binaryPlan, - Prepared: a.isPreparedStmt, - HasMoreResults: hasMoreResults, - PlanFromCache: sessVars.FoundInPlanCache, - PlanFromBinding: sessVars.FoundInBinding, - RewriteInfo: sessVars.RewritePhaseInfo, - KVTotal: time.Duration(atomic.LoadInt64(&tikvExecDetail.WaitKVRespDuration)), - PDTotal: time.Duration(atomic.LoadInt64(&tikvExecDetail.WaitPDRespDuration)), - BackoffTotal: time.Duration(atomic.LoadInt64(&tikvExecDetail.BackoffDuration)), - WriteSQLRespTotal: stmtDetail.WriteSQLRespDuration, - ResultRows: resultRows, - ExecRetryCount: a.retryCount, - IsExplicitTxn: sessVars.TxnCtx.IsExplicit, - IsWriteCacheTable: stmtCtx.WaitLockLeaseTime > 0, - UsedStats: stmtCtx.GetUsedStatsInfo(false), - IsSyncStatsFailed: stmtCtx.IsSyncStatsFailed, - Warnings: collectWarningsForSlowLog(stmtCtx), - ResourceGroupName: sessVars.StmtCtx.ResourceGroupName, - RRU: ruDetails.RRU(), - WRU: ruDetails.WRU(), - WaitRUDuration: ruDetails.RUWaitDuration(), - } - failpoint.Inject("assertSyncStatsFailed", func(val failpoint.Value) { - if val.(bool) { - if !slowItems.IsSyncStatsFailed { - panic("isSyncStatsFailed should be true") - } - } - }) - if a.retryCount > 0 { - slowItems.ExecRetryTime = costTime - sessVars.DurationParse - sessVars.DurationCompile - time.Since(a.retryStartTime) - } - if _, ok := a.StmtNode.(*ast.CommitStmt); ok && sessVars.PrevStmt != nil { - slowItems.PrevStmt = sessVars.PrevStmt.String() - } - slowLog := sessVars.SlowLogFormat(slowItems) - if trace.IsEnabled() { - trace.Log(a.GoCtx, "details", slowLog) - } - logutil.SlowQueryLogger.Warn(slowLog) - if costTime >= threshold { - if sessVars.InRestrictedSQL { - executor_metrics.TotalQueryProcHistogramInternal.Observe(costTime.Seconds()) - executor_metrics.TotalCopProcHistogramInternal.Observe(execDetail.TimeDetail.ProcessTime.Seconds()) - executor_metrics.TotalCopWaitHistogramInternal.Observe(execDetail.TimeDetail.WaitTime.Seconds()) - } else { - executor_metrics.TotalQueryProcHistogramGeneral.Observe(costTime.Seconds()) - executor_metrics.TotalCopProcHistogramGeneral.Observe(execDetail.TimeDetail.ProcessTime.Seconds()) - executor_metrics.TotalCopWaitHistogramGeneral.Observe(execDetail.TimeDetail.WaitTime.Seconds()) - if execDetail.ScanDetail != nil && execDetail.ScanDetail.ProcessedKeys != 0 { - executor_metrics.CopMVCCRatioHistogramGeneral.Observe(float64(execDetail.ScanDetail.TotalKeys) / float64(execDetail.ScanDetail.ProcessedKeys)) - } - } - var userString string - if sessVars.User != nil { - userString = sessVars.User.String() - } - var tableIDs string - if len(stmtCtx.TableIDs) > 0 { - tableIDs = strings.ReplaceAll(fmt.Sprintf("%v", stmtCtx.TableIDs), " ", ",") - } - domain.GetDomain(a.Ctx).LogSlowQuery(&domain.SlowQueryInfo{ - SQL: sql.String(), - Digest: digest.String(), - Start: sessVars.StartTime, - Duration: costTime, - Detail: stmtCtx.GetExecDetails(), - Succ: succ, - ConnID: sessVars.ConnectionID, - SessAlias: sessVars.SessionAlias, - 
TxnTS: txnTS, - User: userString, - DB: sessVars.CurrentDB, - TableIDs: tableIDs, - IndexNames: indexNames, - Internal: sessVars.InRestrictedSQL, - }) - } -} - -func extractMsgFromSQLWarn(sqlWarn *contextutil.SQLWarn) string { - // Currently, this function is only used in collectWarningsForSlowLog. - // collectWarningsForSlowLog can make sure SQLWarn is not nil so no need to add a nil check here. - warn := errors.Cause(sqlWarn.Err) - if x, ok := warn.(*terror.Error); ok && x != nil { - sqlErr := terror.ToSQLError(x) - return sqlErr.Message - } - return warn.Error() -} - -func collectWarningsForSlowLog(stmtCtx *stmtctx.StatementContext) []variable.JSONSQLWarnForSlowLog { - warnings := stmtCtx.GetWarnings() - extraWarnings := stmtCtx.GetExtraWarnings() - res := make([]variable.JSONSQLWarnForSlowLog, len(warnings)+len(extraWarnings)) - for i := range warnings { - res[i].Level = warnings[i].Level - res[i].Message = extractMsgFromSQLWarn(&warnings[i]) - } - for i := range extraWarnings { - res[len(warnings)+i].Level = extraWarnings[i].Level - res[len(warnings)+i].Message = extractMsgFromSQLWarn(&extraWarnings[i]) - res[len(warnings)+i].IsExtra = true - } - return res -} - -// GetResultRowsCount gets the count of the statement result rows. -func GetResultRowsCount(stmtCtx *stmtctx.StatementContext, p base.Plan) int64 { - runtimeStatsColl := stmtCtx.RuntimeStatsColl - if runtimeStatsColl == nil { - return 0 - } - rootPlanID := p.ID() - if !runtimeStatsColl.ExistsRootStats(rootPlanID) { - return 0 - } - rootStats := runtimeStatsColl.GetRootStats(rootPlanID) - return rootStats.GetActRows() -} - -// getFlatPlan generates a FlatPhysicalPlan from the plan stored in stmtCtx.plan, -// then stores it in stmtCtx.flatPlan. -func getFlatPlan(stmtCtx *stmtctx.StatementContext) *plannercore.FlatPhysicalPlan { - pp := stmtCtx.GetPlan() - if pp == nil { - return nil - } - if flat := stmtCtx.GetFlatPlan(); flat != nil { - f := flat.(*plannercore.FlatPhysicalPlan) - return f - } - p := pp.(base.Plan) - flat := plannercore.FlattenPhysicalPlan(p, false) - if flat != nil { - stmtCtx.SetFlatPlan(flat) - return flat - } - return nil -} - -func getBinaryPlan(sCtx sessionctx.Context) string { - stmtCtx := sCtx.GetSessionVars().StmtCtx - binaryPlan := stmtCtx.GetBinaryPlan() - if len(binaryPlan) > 0 { - return binaryPlan - } - flat := getFlatPlan(stmtCtx) - binaryPlan = plannercore.BinaryPlanStrFromFlatPlan(sCtx.GetPlanCtx(), flat) - stmtCtx.SetBinaryPlan(binaryPlan) - return binaryPlan -} - -// getPlanTree will try to get the select plan tree if the plan is select or the select plan of delete/update/insert statement. -func getPlanTree(stmtCtx *stmtctx.StatementContext) string { - cfg := config.GetGlobalConfig() - if atomic.LoadUint32(&cfg.Instance.RecordPlanInSlowLog) == 0 { - return "" - } - planTree, _ := getEncodedPlan(stmtCtx, false) - if len(planTree) == 0 { - return planTree - } - return variable.SlowLogPlanPrefix + planTree + variable.SlowLogPlanSuffix -} - -// GetPlanDigest will try to get the select plan tree if the plan is select or the select plan of delete/update/insert statement. 
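// getFlatPlan and getBinaryPlan above, and GetPlanDigest and getEncodedPlan below,
// all share one compute-once-and-cache shape: return the value already stored in
// the StatementContext when present, otherwise build it and store it back so the
// slow log, statement summary and top-SQL paths reuse a single encoding. A generic
// sketch of that shape (illustrative only, not TiDB API):
type lazyValue[T any] struct {
	val   T
	built bool
}

func (l *lazyValue[T]) get(build func() T) T {
	if !l.built {
		l.val, l.built = build(), true // build once, then serve the cached copy
	}
	return l.val
}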
-func GetPlanDigest(stmtCtx *stmtctx.StatementContext) (string, *parser.Digest) { - normalized, planDigest := stmtCtx.GetPlanDigest() - if len(normalized) > 0 && planDigest != nil { - return normalized, planDigest - } - flat := getFlatPlan(stmtCtx) - normalized, planDigest = plannercore.NormalizeFlatPlan(flat) - stmtCtx.SetPlanDigest(normalized, planDigest) - return normalized, planDigest -} - -// GetEncodedPlan returned same as getEncodedPlan -func GetEncodedPlan(stmtCtx *stmtctx.StatementContext, genHint bool) (encodedPlan, hintStr string) { - return getEncodedPlan(stmtCtx, genHint) -} - -// getEncodedPlan gets the encoded plan, and generates the hint string if indicated. -func getEncodedPlan(stmtCtx *stmtctx.StatementContext, genHint bool) (encodedPlan, hintStr string) { - var hintSet bool - encodedPlan = stmtCtx.GetEncodedPlan() - hintStr, hintSet = stmtCtx.GetPlanHint() - if len(encodedPlan) > 0 && (!genHint || hintSet) { - return - } - flat := getFlatPlan(stmtCtx) - if len(encodedPlan) == 0 { - encodedPlan = plannercore.EncodeFlatPlan(flat) - stmtCtx.SetEncodedPlan(encodedPlan) - } - if genHint { - hints := plannercore.GenHintsFromFlatPlan(flat) - for _, tableHint := range stmtCtx.OriginalTableHints { - // some hints like 'memory_quota' cannot be extracted from the PhysicalPlan directly, - // so we have to iterate all hints from the customer and keep some other necessary hints. - switch tableHint.HintName.L { - case hint.HintMemoryQuota, hint.HintUseToja, hint.HintNoIndexMerge, - hint.HintMaxExecutionTime, hint.HintIgnoreIndex, hint.HintReadFromStorage, - hint.HintMerge, hint.HintSemiJoinRewrite, hint.HintNoDecorrelate: - hints = append(hints, tableHint) - } - } - - hintStr = hint.RestoreOptimizerHints(hints) - stmtCtx.SetPlanHint(hintStr) - } - return -} - -// SummaryStmt collects statements for information_schema.statements_summary -func (a *ExecStmt) SummaryStmt(succ bool) { - sessVars := a.Ctx.GetSessionVars() - var userString string - if sessVars.User != nil { - userString = sessVars.User.Username - } - - // Internal SQLs must also be recorded to keep the consistency of `PrevStmt` and `PrevStmtDigest`. - if !stmtsummaryv2.Enabled() || ((sessVars.InRestrictedSQL || len(userString) == 0) && !stmtsummaryv2.EnabledInternal()) { - sessVars.SetPrevStmtDigest("") - return - } - // Ignore `PREPARE` statements, but record `EXECUTE` statements. - if _, ok := a.StmtNode.(*ast.PrepareStmt); ok { - return - } - stmtCtx := sessVars.StmtCtx - // Make sure StmtType is filled even if succ is false. - if stmtCtx.StmtType == "" { - stmtCtx.StmtType = ast.GetStmtLabel(a.StmtNode) - } - normalizedSQL, digest := stmtCtx.SQLDigest() - costTime := sessVars.GetTotalCostDuration() - charset, collation := sessVars.GetCharsetInfo() - - var prevSQL, prevSQLDigest string - if _, ok := a.StmtNode.(*ast.CommitStmt); ok { - // If prevSQLDigest is not recorded, it means this `commit` is the first SQL once stmt summary is enabled, - // so it's OK just to ignore it. - if prevSQLDigest = sessVars.GetPrevStmtDigest(); len(prevSQLDigest) == 0 { - return - } - prevSQL = sessVars.PrevStmt.String() - } - sessVars.SetPrevStmtDigest(digest.String()) - - // No need to encode every time, so encode lazily. 
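// The planGenerator closure just below recovers from any panic raised while
// encoding the plan, so a failure there degrades to a logged warning instead of
// aborting statement-summary collection. The same defer/recover shape in isolation
// (the function name is an assumption for the sketch):
func capturePanic(build func() string) (result string, panicked any) {
	defer func() {
		if r := recover(); r != nil {
			panicked = r // hand the panic back as a value for the caller to log
		}
	}()
	result = build()
	return
}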
- planGenerator := func() (p string, h string, e any) { - defer func() { - e = recover() - if e != nil { - logutil.BgLogger().Warn("fail to generate plan info", - zap.Stack("backtrace"), - zap.Any("error", e)) - } - }() - p, h = getEncodedPlan(stmtCtx, !sessVars.InRestrictedSQL) - return - } - var binPlanGen func() string - if variable.GenerateBinaryPlan.Load() { - binPlanGen = func() string { - binPlan := getBinaryPlan(a.Ctx) - return binPlan - } - } - // Generating plan digest is slow, only generate it once if it's 'Point_Get'. - // If it's a point get, different SQLs leads to different plans, so SQL digest - // is enough to distinguish different plans in this case. - var planDigest string - var planDigestGen func() string - if a.Plan.TP() == plancodec.TypePointGet { - planDigestGen = func() string { - _, planDigest := GetPlanDigest(stmtCtx) - return planDigest.String() - } - } else { - _, tmp := GetPlanDigest(stmtCtx) - planDigest = tmp.String() - } - - execDetail := stmtCtx.GetExecDetails() - copTaskInfo := stmtCtx.CopTasksDetails() - memMax := sessVars.MemTracker.MaxConsumed() - diskMax := sessVars.DiskTracker.MaxConsumed() - sql := a.getLazyStmtText() - var stmtDetail execdetails.StmtExecDetails - stmtDetailRaw := a.GoCtx.Value(execdetails.StmtExecDetailKey) - if stmtDetailRaw != nil { - stmtDetail = *(stmtDetailRaw.(*execdetails.StmtExecDetails)) - } - var tikvExecDetail util.ExecDetails - tikvExecDetailRaw := a.GoCtx.Value(util.ExecDetailsKey) - if tikvExecDetailRaw != nil { - tikvExecDetail = *(tikvExecDetailRaw.(*util.ExecDetails)) - } - var ruDetail *util.RUDetails - if ruDetailRaw := a.GoCtx.Value(util.RUDetailsCtxKey); ruDetailRaw != nil { - ruDetail = ruDetailRaw.(*util.RUDetails) - } - - if stmtCtx.WaitLockLeaseTime > 0 { - if execDetail.BackoffSleep == nil { - execDetail.BackoffSleep = make(map[string]time.Duration) - } - execDetail.BackoffSleep["waitLockLeaseForCacheTable"] = stmtCtx.WaitLockLeaseTime - execDetail.BackoffTime += stmtCtx.WaitLockLeaseTime - execDetail.TimeDetail.WaitTime += stmtCtx.WaitLockLeaseTime - } - - resultRows := GetResultRowsCount(stmtCtx, a.Plan) - - var ( - keyspaceName string - keyspaceID uint32 - ) - keyspaceName = keyspace.GetKeyspaceNameBySettings() - if !keyspace.IsKeyspaceNameEmpty(keyspaceName) { - keyspaceID = uint32(a.Ctx.GetStore().GetCodec().GetKeyspaceID()) - } - - stmtExecInfo := &stmtsummary.StmtExecInfo{ - SchemaName: strings.ToLower(sessVars.CurrentDB), - OriginalSQL: &sql, - Charset: charset, - Collation: collation, - NormalizedSQL: normalizedSQL, - Digest: digest.String(), - PrevSQL: prevSQL, - PrevSQLDigest: prevSQLDigest, - PlanGenerator: planGenerator, - BinaryPlanGenerator: binPlanGen, - PlanDigest: planDigest, - PlanDigestGen: planDigestGen, - User: userString, - TotalLatency: costTime, - ParseLatency: sessVars.DurationParse, - CompileLatency: sessVars.DurationCompile, - StmtCtx: stmtCtx, - CopTasks: &copTaskInfo, - ExecDetail: &execDetail, - MemMax: memMax, - DiskMax: diskMax, - StartTime: sessVars.StartTime, - IsInternal: sessVars.InRestrictedSQL, - Succeed: succ, - PlanInCache: sessVars.FoundInPlanCache, - PlanInBinding: sessVars.FoundInBinding, - ExecRetryCount: a.retryCount, - StmtExecDetails: stmtDetail, - ResultRows: resultRows, - TiKVExecDetails: tikvExecDetail, - Prepared: a.isPreparedStmt, - KeyspaceName: keyspaceName, - KeyspaceID: keyspaceID, - RUDetail: ruDetail, - ResourceGroupName: sessVars.StmtCtx.ResourceGroupName, - - PlanCacheUnqualified: sessVars.StmtCtx.PlanCacheUnqualified(), - } - if a.retryCount > 0 
{ - stmtExecInfo.ExecRetryTime = costTime - sessVars.DurationParse - sessVars.DurationCompile - time.Since(a.retryStartTime) - } - stmtsummaryv2.Add(stmtExecInfo) -} - -// GetTextToLog return the query text to log. -func (a *ExecStmt) GetTextToLog(keepHint bool) string { - var sql string - sessVars := a.Ctx.GetSessionVars() - rmode := sessVars.EnableRedactLog - if rmode == errors.RedactLogEnable { - if keepHint { - sql = parser.NormalizeKeepHint(sessVars.StmtCtx.OriginalSQL) - } else { - sql, _ = sessVars.StmtCtx.SQLDigest() - } - } else if sensitiveStmt, ok := a.StmtNode.(ast.SensitiveStmtNode); ok { - sql = sensitiveStmt.SecureText() - } else { - sql = redact.String(rmode, sessVars.StmtCtx.OriginalSQL+sessVars.PlanCacheParams.String()) - } - return sql -} - -// getLazyText is equivalent to `a.GetTextToLog(false)`. Note that the s.Params is a shallow copy of -// `sessVars.PlanCacheParams`, so you can only use the lazy text within the current stmt context. -func (a *ExecStmt) getLazyStmtText() (s variable.LazyStmtText) { - sessVars := a.Ctx.GetSessionVars() - rmode := sessVars.EnableRedactLog - if rmode == errors.RedactLogEnable { - sql, _ := sessVars.StmtCtx.SQLDigest() - s.SetText(sql) - } else if sensitiveStmt, ok := a.StmtNode.(ast.SensitiveStmtNode); ok { - sql := sensitiveStmt.SecureText() - s.SetText(sql) - } else { - s.Redact = rmode - s.SQL = sessVars.StmtCtx.OriginalSQL - s.Params = *sessVars.PlanCacheParams - } - return -} - -// updatePrevStmt is equivalent to `sessVars.PrevStmt = FormatSQL(a.GetTextToLog(false))` -func (a *ExecStmt) updatePrevStmt() { - sessVars := a.Ctx.GetSessionVars() - if sessVars.PrevStmt == nil { - sessVars.PrevStmt = &variable.LazyStmtText{Format: formatSQL} - } - rmode := sessVars.EnableRedactLog - if rmode == errors.RedactLogEnable { - sql, _ := sessVars.StmtCtx.SQLDigest() - sessVars.PrevStmt.SetText(sql) - } else if sensitiveStmt, ok := a.StmtNode.(ast.SensitiveStmtNode); ok { - sql := sensitiveStmt.SecureText() - sessVars.PrevStmt.SetText(sql) - } else { - sessVars.PrevStmt.Update(rmode, sessVars.StmtCtx.OriginalSQL, sessVars.PlanCacheParams) - } -} - -func (a *ExecStmt) observeStmtBeginForTopSQL(ctx context.Context) context.Context { - vars := a.Ctx.GetSessionVars() - sc := vars.StmtCtx - normalizedSQL, sqlDigest := sc.SQLDigest() - normalizedPlan, planDigest := GetPlanDigest(sc) - var sqlDigestByte, planDigestByte []byte - if sqlDigest != nil { - sqlDigestByte = sqlDigest.Bytes() - } - if planDigest != nil { - planDigestByte = planDigest.Bytes() - } - stats := a.Ctx.GetStmtStats() - if !topsqlstate.TopSQLEnabled() { - // To reduce the performance impact on fast plan. - // Drop them does not cause notable accuracy issue in TopSQL. - if IsFastPlan(a.Plan) { - return ctx - } - // Always attach the SQL and plan info uses to catch the running SQL when Top SQL is enabled in execution. - if stats != nil { - stats.OnExecutionBegin(sqlDigestByte, planDigestByte) - } - return topsql.AttachSQLAndPlanInfo(ctx, sqlDigest, planDigest) - } - - if stats != nil { - stats.OnExecutionBegin(sqlDigestByte, planDigestByte) - // This is a special logic prepared for TiKV's SQLExecCount. 
- sc.KvExecCounter = stats.CreateKvExecCounter(sqlDigestByte, planDigestByte) - } - - isSQLRegistered := sc.IsSQLRegistered.Load() - if !isSQLRegistered { - topsql.RegisterSQL(normalizedSQL, sqlDigest, vars.InRestrictedSQL) - } - sc.IsSQLAndPlanRegistered.Store(true) - if len(normalizedPlan) == 0 { - return ctx - } - topsql.RegisterPlan(normalizedPlan, planDigest) - return topsql.AttachSQLAndPlanInfo(ctx, sqlDigest, planDigest) -} - -func (a *ExecStmt) observeStmtFinishedForTopSQL() { - vars := a.Ctx.GetSessionVars() - if vars == nil { - return - } - if stats := a.Ctx.GetStmtStats(); stats != nil && topsqlstate.TopSQLEnabled() { - sqlDigest, planDigest := a.getSQLPlanDigest() - execDuration := vars.GetTotalCostDuration() - stats.OnExecutionFinished(sqlDigest, planDigest, execDuration) - } -} - -func (a *ExecStmt) getSQLPlanDigest() ([]byte, []byte) { - var sqlDigest, planDigest []byte - vars := a.Ctx.GetSessionVars() - if _, d := vars.StmtCtx.SQLDigest(); d != nil { - sqlDigest = d.Bytes() - } - if _, d := vars.StmtCtx.GetPlanDigest(); d != nil { - planDigest = d.Bytes() - } - return sqlDigest, planDigest -} - -// only allow select/delete/update/insert/execute stmt captured by continues capture -func checkPlanReplayerContinuesCaptureValidStmt(stmtNode ast.StmtNode) bool { - switch stmtNode.(type) { - case *ast.SelectStmt, *ast.DeleteStmt, *ast.UpdateStmt, *ast.InsertStmt, *ast.ExecuteStmt: - return true - default: - return false - } -} - -func checkPlanReplayerCaptureTask(sctx sessionctx.Context, stmtNode ast.StmtNode, startTS uint64) { - dom := domain.GetDomain(sctx) - if dom == nil { - return - } - handle := dom.GetPlanReplayerHandle() - if handle == nil { - return - } - tasks := handle.GetTasks() - if len(tasks) == 0 { - return - } - _, sqlDigest := sctx.GetSessionVars().StmtCtx.SQLDigest() - _, planDigest := sctx.GetSessionVars().StmtCtx.GetPlanDigest() - if sqlDigest == nil || planDigest == nil { - return - } - key := replayer.PlanReplayerTaskKey{ - SQLDigest: sqlDigest.String(), - PlanDigest: planDigest.String(), - } - for _, task := range tasks { - if task.SQLDigest == sqlDigest.String() { - if task.PlanDigest == "*" || task.PlanDigest == planDigest.String() { - sendPlanReplayerDumpTask(key, sctx, stmtNode, startTS, false) - return - } - } - } -} - -func checkPlanReplayerContinuesCapture(sctx sessionctx.Context, stmtNode ast.StmtNode, startTS uint64) { - dom := domain.GetDomain(sctx) - if dom == nil { - return - } - handle := dom.GetPlanReplayerHandle() - if handle == nil { - return - } - _, sqlDigest := sctx.GetSessionVars().StmtCtx.SQLDigest() - _, planDigest := sctx.GetSessionVars().StmtCtx.GetPlanDigest() - key := replayer.PlanReplayerTaskKey{ - SQLDigest: sqlDigest.String(), - PlanDigest: planDigest.String(), - } - existed := sctx.GetSessionVars().CheckPlanReplayerFinishedTaskKey(key) - if existed { - return - } - sendPlanReplayerDumpTask(key, sctx, stmtNode, startTS, true) - sctx.GetSessionVars().AddPlanReplayerFinishedTaskKey(key) -} - -func sendPlanReplayerDumpTask(key replayer.PlanReplayerTaskKey, sctx sessionctx.Context, stmtNode ast.StmtNode, - startTS uint64, isContinuesCapture bool) { - stmtCtx := sctx.GetSessionVars().StmtCtx - handle := sctx.Value(bindinfo.SessionBindInfoKeyType).(bindinfo.SessionBindingHandle) - bindings := handle.GetAllSessionBindings() - dumpTask := &domain.PlanReplayerDumpTask{ - PlanReplayerTaskKey: key, - StartTS: startTS, - TblStats: stmtCtx.TableStats, - SessionBindings: []bindinfo.Bindings{bindings}, - SessionVars: sctx.GetSessionVars(), - 
ExecStmts: []ast.StmtNode{stmtNode}, - DebugTrace: []any{stmtCtx.OptimizerDebugTrace}, - Analyze: false, - IsCapture: true, - IsContinuesCapture: isContinuesCapture, - } - dumpTask.EncodedPlan, _ = GetEncodedPlan(stmtCtx, false) - if execStmtAst, ok := stmtNode.(*ast.ExecuteStmt); ok { - planCacheStmt, err := plannercore.GetPreparedStmt(execStmtAst, sctx.GetSessionVars()) - if err != nil { - logutil.BgLogger().Warn("fail to find prepared ast for dumping plan replayer", zap.String("category", "plan-replayer-capture"), - zap.String("sqlDigest", key.SQLDigest), - zap.String("planDigest", key.PlanDigest), - zap.Error(err)) - } else { - dumpTask.ExecStmts = []ast.StmtNode{planCacheStmt.PreparedAst.Stmt} - } - } - domain.GetDomain(sctx).GetPlanReplayerHandle().SendTask(dumpTask) -} diff --git a/pkg/executor/aggregate/agg_hash_executor.go b/pkg/executor/aggregate/agg_hash_executor.go index 80c78de579c3d..9b05153015803 100644 --- a/pkg/executor/aggregate/agg_hash_executor.go +++ b/pkg/executor/aggregate/agg_hash_executor.go @@ -215,23 +215,23 @@ func (e *HashAggExec) Close() error { } err := e.BaseExecutor.Close() - if val, _err_ := failpoint.Eval(_curpkg_("injectHashAggClosePanic")); _err_ == nil { + failpoint.Inject("injectHashAggClosePanic", func(val failpoint.Value) { if enabled := val.(bool); enabled { if e.Ctx().GetSessionVars().ConnectionID != 0 { panic(errors.New("test")) } } - } + }) return err } // Open implements the Executor Open interface. func (e *HashAggExec) Open(ctx context.Context) error { - if val, _err_ := failpoint.Eval(_curpkg_("mockHashAggExecBaseExecutorOpenReturnedError")); _err_ == nil { + failpoint.Inject("mockHashAggExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { if val, _ := val.(bool); val { - return errors.New("mock HashAggExec.baseExecutor.Open returned error") + failpoint.Return(errors.New("mock HashAggExec.baseExecutor.Open returned error")) } - } + }) if err := e.BaseExecutor.Open(ctx); err != nil { return err @@ -264,7 +264,7 @@ func (e *HashAggExec) initForUnparallelExec() { e.groupSet, setSize = set.NewStringSetWithMemoryUsage() e.partialResultMap = make(aggfuncs.AggPartialResultMapper) e.bInMap = 0 - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + failpoint.Inject("ConsumeRandomPanic", nil) e.memTracker.Consume(hack.DefBucketMemoryUsageForMapStrToSlice*(1< | | | | ...... | | partialInputChs - +--------------+ +-+ +-+ +-+ -*/ -type HashAggExec struct { - exec.BaseExecutor - - Sc *stmtctx.StatementContext - PartialAggFuncs []aggfuncs.AggFunc - FinalAggFuncs []aggfuncs.AggFunc - partialResultMap aggfuncs.AggPartialResultMapper - bInMap int64 // indicate there are 2^bInMap buckets in partialResultMap - groupSet set.StringSetWithMemoryUsage - groupKeys []string - cursor4GroupKey int - GroupByItems []expression.Expression - groupKeyBuffer [][]byte - - finishCh chan struct{} - finalOutputCh chan *AfFinalResult - partialOutputChs []chan *aggfuncs.AggPartialResultMapper - inputCh chan *HashAggInput - partialInputChs []chan *chunk.Chunk - partialWorkers []HashAggPartialWorker - finalWorkers []HashAggFinalWorker - DefaultVal *chunk.Chunk - childResult *chunk.Chunk - - // IsChildReturnEmpty indicates whether the child executor only returns an empty input. - IsChildReturnEmpty bool - // After we support parallel execution for aggregation functions with distinct, - // we can remove this attribute. - IsUnparallelExec bool - parallelExecValid bool - prepared atomic.Bool - executed atomic.Bool - - memTracker *memory.Tracker // track memory usage. 
- diskTracker *disk.Tracker - - stats *HashAggRuntimeStats - - // dataInDisk is the chunks to store row values for spilled data. - // The HashAggExec may be set to `spill mode` multiple times, and all spilled data will be appended to DataInDiskByRows. - dataInDisk *chunk.DataInDiskByChunks - // numOfSpilledChks indicates the number of all the spilled chunks. - numOfSpilledChks int - // offsetOfSpilledChks indicates the offset of the chunk be read from the disk. - // In each round of processing, we need to re-fetch all the chunks spilled in the last one. - offsetOfSpilledChks int - // inSpillMode indicates whether HashAgg is in `spill mode`. - // When HashAgg is in `spill mode`, the size of `partialResultMap` is no longer growing and all the data fetched - // from the child executor is spilled to the disk. - inSpillMode uint32 - // tmpChkForSpill is the temp chunk for spilling. - tmpChkForSpill *chunk.Chunk - // The `inflightChunkSync` calls `Add(1)` when the data fetcher goroutine inserts a chunk into the channel, - // and `Done()` when any partial worker retrieves a chunk from the channel and updates it in the `partialResultMap`. - // In scenarios where it is necessary to wait for all partial workers to finish processing the inflight chunk, - // `inflightChunkSync` can be used for synchronization. - inflightChunkSync *sync.WaitGroup - // spillAction save the Action for spilling. - spillAction *AggSpillDiskAction - // parallelAggSpillAction save the Action for spilling of parallel aggregation. - parallelAggSpillAction *ParallelAggSpillDiskAction - // spillHelper helps to carry out the spill action - spillHelper *parallelHashAggSpillHelper - // isChildDrained indicates whether the all data from child has been taken out. - isChildDrained bool -} - -// Close implements the Executor Close interface. -func (e *HashAggExec) Close() error { - if e.stats != nil { - defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), e.stats) - } - - if e.IsUnparallelExec { - e.childResult = nil - e.groupSet, _ = set.NewStringSetWithMemoryUsage() - e.partialResultMap = nil - if e.memTracker != nil { - e.memTracker.ReplaceBytesUsed(0) - } - if e.dataInDisk != nil { - e.dataInDisk.Close() - } - if e.spillAction != nil { - e.spillAction.SetFinished() - } - e.spillAction, e.tmpChkForSpill = nil, nil - err := e.BaseExecutor.Close() - if err != nil { - return err - } - return nil - } - if e.parallelExecValid { - // `Close` may be called after `Open` without calling `Next` in test. - if e.prepared.CompareAndSwap(false, true) { - close(e.inputCh) - for _, ch := range e.partialOutputChs { - close(ch) - } - for _, ch := range e.partialInputChs { - close(ch) - } - close(e.finalOutputCh) - } - close(e.finishCh) - for _, ch := range e.partialOutputChs { - channel.Clear(ch) - } - for _, ch := range e.partialInputChs { - channel.Clear(ch) - } - channel.Clear(e.finalOutputCh) - e.executed.Store(false) - if e.memTracker != nil { - e.memTracker.ReplaceBytesUsed(0) - } - e.parallelExecValid = false - if e.parallelAggSpillAction != nil { - e.parallelAggSpillAction.SetFinished() - e.parallelAggSpillAction = nil - e.spillHelper.close() - } - } - - err := e.BaseExecutor.Close() - failpoint.Inject("injectHashAggClosePanic", func(val failpoint.Value) { - if enabled := val.(bool); enabled { - if e.Ctx().GetSessionVars().ConnectionID != 0 { - panic(errors.New("test")) - } - } - }) - return err -} - -// Open implements the Executor Open interface. 
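// Failpoints such as injectHashAggClosePanic in Close above (and the Open and
// spill failpoints below) are no-ops unless a test activates them. A sketch of the
// usual activation pattern with the github.com/pingcap/failpoint API; the failpoint
// path is assumed from the package import path, and the test name and use of
// stretchr/testify's require are illustrative only:
func TestHashAggClosePanicInjected(t *testing.T) {
	fp := "github.com/pingcap/tidb/pkg/executor/aggregate/injectHashAggClosePanic"
	require.NoError(t, failpoint.Enable(fp, `return(true)`))
	defer func() {
		require.NoError(t, failpoint.Disable(fp))
	}()
	// ... run a query whose plan contains HashAggExec and assert that the
	// injected panic surfaces as expected ...
}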
-func (e *HashAggExec) Open(ctx context.Context) error { - failpoint.Inject("mockHashAggExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { - if val, _ := val.(bool); val { - failpoint.Return(errors.New("mock HashAggExec.baseExecutor.Open returned error")) - } - }) - - if err := e.BaseExecutor.Open(ctx); err != nil { - return err - } - return e.OpenSelf() -} - -// OpenSelf just opens the hash aggregation executor. -func (e *HashAggExec) OpenSelf() error { - e.prepared.Store(false) - - if e.memTracker != nil { - e.memTracker.Reset() - } else { - e.memTracker = memory.NewTracker(e.ID(), -1) - } - if e.Ctx().GetSessionVars().TrackAggregateMemoryUsage { - e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) - } - - if e.IsUnparallelExec { - e.initForUnparallelExec() - return nil - } - return e.initForParallelExec(e.Ctx()) -} - -func (e *HashAggExec) initForUnparallelExec() { - var setSize int64 - e.groupSet, setSize = set.NewStringSetWithMemoryUsage() - e.partialResultMap = make(aggfuncs.AggPartialResultMapper) - e.bInMap = 0 - failpoint.Inject("ConsumeRandomPanic", nil) - e.memTracker.Consume(hack.DefBucketMemoryUsageForMapStrToSlice*(1< 0 { - e.IsChildReturnEmpty = false - return nil - } - } -} - -// unparallelExec executes hash aggregation algorithm in single thread. -func (e *HashAggExec) unparallelExec(ctx context.Context, chk *chunk.Chunk) error { - chk.Reset() - for { - exprCtx := e.Ctx().GetExprCtx() - if e.prepared.Load() { - // Since we return e.MaxChunkSize() rows every time, so we should not traverse - // `groupSet` because of its randomness. - for ; e.cursor4GroupKey < len(e.groupKeys); e.cursor4GroupKey++ { - partialResults := e.getPartialResults(e.groupKeys[e.cursor4GroupKey]) - if len(e.PartialAggFuncs) == 0 { - chk.SetNumVirtualRows(chk.NumRows() + 1) - } - for i, af := range e.PartialAggFuncs { - if err := af.AppendFinalResult2Chunk(exprCtx.GetEvalCtx(), partialResults[i], chk); err != nil { - return err - } - } - if chk.IsFull() { - e.cursor4GroupKey++ - return nil - } - } - e.resetSpillMode() - } - if e.executed.Load() { - return nil - } - if err := e.execute(ctx); err != nil { - return err - } - if (len(e.groupSet.StringSet) == 0) && len(e.GroupByItems) == 0 { - // If no groupby and no data, we should add an empty group. - // For example: - // "select count(c) from t;" should return one row [0] - // "select count(c) from t group by c1;" should return empty result set. - e.memTracker.Consume(e.groupSet.Insert("")) - e.groupKeys = append(e.groupKeys, "") - } - e.prepared.Store(true) - } -} - -func (e *HashAggExec) resetSpillMode() { - e.cursor4GroupKey, e.groupKeys = 0, e.groupKeys[:0] - var setSize int64 - e.groupSet, setSize = set.NewStringSetWithMemoryUsage() - e.partialResultMap = make(aggfuncs.AggPartialResultMapper) - e.bInMap = 0 - e.prepared.Store(false) - e.executed.Store(e.numOfSpilledChks == e.dataInDisk.NumChunks()) // No data is spilling again, all data have been processed. - e.numOfSpilledChks = e.dataInDisk.NumChunks() - e.memTracker.ReplaceBytesUsed(setSize) - atomic.StoreUint32(&e.inSpillMode, 0) -} - -// execute fetches Chunks from src and update each aggregate function for each row in Chunk. 
-func (e *HashAggExec) execute(ctx context.Context) (err error) { - defer func() { - if e.tmpChkForSpill.NumRows() > 0 && err == nil { - err = e.dataInDisk.Add(e.tmpChkForSpill) - e.tmpChkForSpill.Reset() - } - }() - exprCtx := e.Ctx().GetExprCtx() - for { - mSize := e.childResult.MemoryUsage() - if err := e.getNextChunk(ctx); err != nil { - return err - } - failpoint.Inject("ConsumeRandomPanic", nil) - e.memTracker.Consume(e.childResult.MemoryUsage() - mSize) - if err != nil { - return err - } - - failpoint.Inject("unparallelHashAggError", func(val failpoint.Value) { - if val, _ := val.(bool); val { - failpoint.Return(errors.New("HashAggExec.unparallelExec error")) - } - }) - - // no more data. - if e.childResult.NumRows() == 0 { - return nil - } - e.groupKeyBuffer, err = GetGroupKey(e.Ctx(), e.childResult, e.groupKeyBuffer, e.GroupByItems) - if err != nil { - return err - } - - allMemDelta := int64(0) - sel := make([]int, 0, e.childResult.NumRows()) - var tmpBuf [1]chunk.Row - for j := 0; j < e.childResult.NumRows(); j++ { - groupKey := string(e.groupKeyBuffer[j]) // do memory copy here, because e.groupKeyBuffer may be reused. - if !e.groupSet.Exist(groupKey) { - if atomic.LoadUint32(&e.inSpillMode) == 1 && e.groupSet.Count() > 0 { - sel = append(sel, j) - continue - } - allMemDelta += e.groupSet.Insert(groupKey) - e.groupKeys = append(e.groupKeys, groupKey) - } - partialResults := e.getPartialResults(groupKey) - for i, af := range e.PartialAggFuncs { - tmpBuf[0] = e.childResult.GetRow(j) - memDelta, err := af.UpdatePartialResult(exprCtx.GetEvalCtx(), tmpBuf[:], partialResults[i]) - if err != nil { - return err - } - allMemDelta += memDelta - } - } - - // spill unprocessed data when exceeded. - if len(sel) > 0 { - e.childResult.SetSel(sel) - err = e.spillUnprocessedData(len(sel) == cap(sel)) - if err != nil { - return err - } - } - - failpoint.Inject("ConsumeRandomPanic", nil) - e.memTracker.Consume(allMemDelta) - } -} - -func (e *HashAggExec) spillUnprocessedData(isFullChk bool) (err error) { - if isFullChk { - return e.dataInDisk.Add(e.childResult) - } - for i := 0; i < e.childResult.NumRows(); i++ { - e.tmpChkForSpill.AppendRow(e.childResult.GetRow(i)) - if e.tmpChkForSpill.IsFull() { - err = e.dataInDisk.Add(e.tmpChkForSpill) - if err != nil { - return err - } - e.tmpChkForSpill.Reset() - } - } - return nil -} - -func (e *HashAggExec) getNextChunk(ctx context.Context) (err error) { - e.childResult.Reset() - if !e.isChildDrained { - if err := exec.Next(ctx, e.Children(0), e.childResult); err != nil { - return err - } - if e.childResult.NumRows() != 0 { - return nil - } - e.isChildDrained = true - } - if e.offsetOfSpilledChks < e.numOfSpilledChks { - e.childResult, err = e.dataInDisk.GetChunk(e.offsetOfSpilledChks) - if err != nil { - return err - } - e.offsetOfSpilledChks++ - } - return nil -} - -func (e *HashAggExec) getPartialResults(groupKey string) []aggfuncs.PartialResult { - partialResults, ok := e.partialResultMap[groupKey] - allMemDelta := int64(0) - if !ok { - partialResults = make([]aggfuncs.PartialResult, 0, len(e.PartialAggFuncs)) - for _, af := range e.PartialAggFuncs { - partialResult, memDelta := af.AllocPartialResult() - partialResults = append(partialResults, partialResult) - allMemDelta += memDelta - } - // Map will expand when count > bucketNum * loadFactor. The memory usage will doubled. 
- if len(e.partialResultMap)+1 > (1< 0 { - return true - } - } - return false -} diff --git a/pkg/executor/aggregate/agg_hash_final_worker.go b/pkg/executor/aggregate/agg_hash_final_worker.go index 6cb21c8045069..d2cc2cb047f10 100644 --- a/pkg/executor/aggregate/agg_hash_final_worker.go +++ b/pkg/executor/aggregate/agg_hash_final_worker.go @@ -128,7 +128,7 @@ func (w *HashAggFinalWorker) consumeIntermData(sctx sessionctx.Context) error { return nil } - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + failpoint.Inject("ConsumeRandomPanic", nil) if err := w.mergeInputIntoResultMap(sctx, input); err != nil { return err @@ -168,7 +168,7 @@ func (w *HashAggFinalWorker) sendFinalResult(sctx sessionctx.Context) { return } - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + failpoint.Inject("ConsumeRandomPanic", nil) execStart := time.Now() updateExecTime(w.stats, execStart) @@ -253,7 +253,7 @@ func (w *HashAggFinalWorker) cleanup(start time.Time, waitGroup *sync.WaitGroup) } func intestBeforeFinalWorkerStart() { - if val, _err_ := failpoint.Eval(_curpkg_("enableAggSpillIntest")); _err_ == nil { + failpoint.Inject("enableAggSpillIntest", func(val failpoint.Value) { if val.(bool) { num := rand.Intn(50) if num < 3 { @@ -262,11 +262,11 @@ func intestBeforeFinalWorkerStart() { time.Sleep(1 * time.Millisecond) } } - } + }) } func (w *HashAggFinalWorker) intestDuringFinalWorkerRun(err *error) { - if val, _err_ := failpoint.Eval(_curpkg_("enableAggSpillIntest")); _err_ == nil { + failpoint.Inject("enableAggSpillIntest", func(val failpoint.Value) { if val.(bool) { num := rand.Intn(10000) if num < 5 { @@ -279,5 +279,5 @@ func (w *HashAggFinalWorker) intestDuringFinalWorkerRun(err *error) { *err = errors.New("Random fail is triggered in final worker") } } - } + }) } diff --git a/pkg/executor/aggregate/agg_hash_final_worker.go__failpoint_stash__ b/pkg/executor/aggregate/agg_hash_final_worker.go__failpoint_stash__ deleted file mode 100644 index d2cc2cb047f10..0000000000000 --- a/pkg/executor/aggregate/agg_hash_final_worker.go__failpoint_stash__ +++ /dev/null @@ -1,283 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package aggregate - -import ( - "math/rand" - "sync" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/executor/aggfuncs" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/hack" - "github.com/pingcap/tidb/pkg/util/logutil" - "go.uber.org/zap" -) - -// AfFinalResult indicates aggregation functions final result. -type AfFinalResult struct { - chk *chunk.Chunk - err error - giveBackCh chan *chunk.Chunk -} - -// HashAggFinalWorker indicates the final workers of parallel hash agg execution, -// the number of the worker can be set by `tidb_hashagg_final_concurrency`. 
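// The bInMap/BInMap counters kept by getPartialResults above (and by initBInMap
// and the partial workers below) approximate Go map growth for memory accounting:
// a map roughly doubles its bucket array once the entry count exceeds
// bucketCount*loadFactor, so one extra generation of buckets is charged to the
// memory tracker each time that threshold is crossed. A generic sketch with
// illustrative constants (the real code takes its constants from the util/hack
// package):
const (
	bucketBytes = 64  // assumed per-bucket cost, for the sketch only
	loadFactor  = 0.8 // assumed load factor, for the sketch only
)

func chargeMapGrowth(entriesAfterInsert int, bInMap *int, consume func(bytes int64)) {
	for float64(entriesAfterInsert) > float64(int64(1)<<*bInMap)*loadFactor {
		consume(bucketBytes * (int64(1) << *bInMap)) // account for the doubling
		*bInMap++
	}
}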
-type HashAggFinalWorker struct { - baseHashAggWorker - - mutableRow chunk.MutRow - partialResultMap aggfuncs.AggPartialResultMapper - BInMap int - inputCh chan *aggfuncs.AggPartialResultMapper - outputCh chan *AfFinalResult - finalResultHolderCh chan *chunk.Chunk - - spillHelper *parallelHashAggSpillHelper - - restoredAggResultMapperMem int64 -} - -func (w *HashAggFinalWorker) getInputFromDisk(sctx sessionctx.Context) (ret aggfuncs.AggPartialResultMapper, restoredMem int64, err error) { - ret, restoredMem, err = w.spillHelper.restoreOnePartition(sctx) - w.intestDuringFinalWorkerRun(&err) - return ret, restoredMem, err -} - -func (w *HashAggFinalWorker) getPartialInput() (input *aggfuncs.AggPartialResultMapper, ok bool) { - waitStart := time.Now() - defer updateWaitTime(w.stats, waitStart) - select { - case <-w.finishCh: - return nil, false - case input, ok = <-w.inputCh: - if !ok { - return nil, false - } - } - return -} - -func (w *HashAggFinalWorker) initBInMap() { - w.BInMap = 0 - mapLen := len(w.partialResultMap) - for mapLen > (1< (1<= 2 && num < 4 { - time.Sleep(1 * time.Millisecond) - } - } - }) -} - -func (w *HashAggPartialWorker) finalizeWorkerProcess(needShuffle bool, finalConcurrency int, hasError bool) { - // Consume all chunks to avoid hang of fetcher - for range w.inputCh { - w.inflightChunkSync.Done() - } - - if w.checkFinishChClosed() { - return - } - - if hasError { - return - } - - if needShuffle && w.spillHelper.isSpilledChunksIOEmpty() { - w.shuffleIntermData(finalConcurrency) - } -} - -func (w *HashAggPartialWorker) run(ctx sessionctx.Context, waitGroup *sync.WaitGroup, finalConcurrency int) { - start := time.Now() - hasError := false - needShuffle := false - - defer func() { - if r := recover(); r != nil { - recoveryHashAgg(w.globalOutputCh, r) - } - - w.finalizeWorkerProcess(needShuffle, finalConcurrency, hasError) - - w.memTracker.Consume(-w.chk.MemoryUsage()) - updateWorkerTime(w.stats, start) - - // We must ensure that there is no panic before `waitGroup.Done()` or there will be hang - waitGroup.Done() - - tryRecycleBuffer(&w.partialResultsBuffer, &w.groupKeyBuf) - }() - - intestBeforePartialWorkerRun() - - for w.fetchChunkAndProcess(ctx, &hasError, &needShuffle) { - } -} - -// If the group key has appeared before, reuse the partial result. -// If the group key has not appeared before, create empty partial results. -func (w *HashAggPartialWorker) getPartialResultsOfEachRow(groupKey [][]byte, finalConcurrency int) [][]aggfuncs.PartialResult { - mapper := w.partialResultsMap - numRows := len(groupKey) - allMemDelta := int64(0) - w.partialResultsBuffer = w.partialResultsBuffer[0:0] - - for i := 0; i < numRows; i++ { - finalWorkerIdx := int(murmur3.Sum32(groupKey[i])) % finalConcurrency - tmp, ok := mapper[finalWorkerIdx][string(hack.String(groupKey[i]))] - - // This group by key has appeared before, reuse the partial result. - if ok { - w.partialResultsBuffer = append(w.partialResultsBuffer, tmp) - continue - } - - // It's the first time that this group by key appeared, create it - w.partialResultsBuffer = append(w.partialResultsBuffer, make([]aggfuncs.PartialResult, w.partialResultNumInRow)) - lastIdx := len(w.partialResultsBuffer) - 1 - for j, af := range w.aggFuncs { - partialResult, memDelta := af.AllocPartialResult() - w.partialResultsBuffer[lastIdx][j] = partialResult - allMemDelta += memDelta // the memory usage of PartialResult - } - allMemDelta += int64(w.partialResultNumInRow * 8) - - // Map will expand when count > bucketNum * loadFactor. 
The memory usage will double. - if len(mapper[finalWorkerIdx])+1 > (1< 0 { - err := w.spilledChunksIO[i].Add(w.tmpChksForSpill[i]) - if err != nil { - return err - } - w.tmpChksForSpill[i].Reset() - } - } - return nil -} - -func (w *HashAggPartialWorker) processError(err error) { - w.globalOutputCh <- &AfFinalResult{err: err} - w.spillHelper.setError() -} diff --git a/pkg/executor/aggregate/agg_stream_executor.go b/pkg/executor/aggregate/agg_stream_executor.go index 9416d8cd60b38..6e08503325731 100644 --- a/pkg/executor/aggregate/agg_stream_executor.go +++ b/pkg/executor/aggregate/agg_stream_executor.go @@ -53,11 +53,11 @@ type StreamAggExec struct { // Open implements the Executor Open interface. func (e *StreamAggExec) Open(ctx context.Context) error { - if val, _err_ := failpoint.Eval(_curpkg_("mockStreamAggExecBaseExecutorOpenReturnedError")); _err_ == nil { + failpoint.Inject("mockStreamAggExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { if val, _ := val.(bool); val { - return errors.New("mock StreamAggExec.baseExecutor.Open returned error") + failpoint.Return(errors.New("mock StreamAggExec.baseExecutor.Open returned error")) } - } + }) if err := e.BaseExecutor.Open(ctx); err != nil { return err @@ -91,7 +91,7 @@ func (e *StreamAggExec) OpenSelf() error { if e.Ctx().GetSessionVars().TrackAggregateMemoryUsage { e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) } - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + failpoint.Inject("ConsumeRandomPanic", nil) e.memTracker.Consume(e.childResult.MemoryUsage() + e.memUsageOfInitialPartialResult) return nil } @@ -179,7 +179,7 @@ func (e *StreamAggExec) consumeGroupRows() error { } allMemDelta += memDelta } - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + failpoint.Inject("ConsumeRandomPanic", nil) e.memTracker.Consume(allMemDelta) e.groupRows = e.groupRows[:0] return nil @@ -194,7 +194,7 @@ func (e *StreamAggExec) consumeCurGroupRowsAndFetchChild(ctx context.Context, ch mSize := e.childResult.MemoryUsage() err = exec.Next(ctx, e.Children(0), e.childResult) - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + failpoint.Inject("ConsumeRandomPanic", nil) e.memTracker.Consume(e.childResult.MemoryUsage() - mSize) if err != nil { return err @@ -227,7 +227,7 @@ func (e *StreamAggExec) appendResult2Chunk(chk *chunk.Chunk) error { } aggFunc.ResetPartialResult(e.partialResults[i]) } - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + failpoint.Inject("ConsumeRandomPanic", nil) // All partial results have been reset, so reset the memory usage. e.memTracker.ReplaceBytesUsed(e.childResult.MemoryUsage() + e.memUsageOfInitialPartialResult) if len(e.AggFuncs) == 0 { diff --git a/pkg/executor/aggregate/agg_stream_executor.go__failpoint_stash__ b/pkg/executor/aggregate/agg_stream_executor.go__failpoint_stash__ deleted file mode 100644 index 6e08503325731..0000000000000 --- a/pkg/executor/aggregate/agg_stream_executor.go__failpoint_stash__ +++ /dev/null @@ -1,237 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package aggregate - -import ( - "context" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/executor/aggfuncs" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/executor/internal/vecgroupchecker" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/memory" -) - -// StreamAggExec deals with all the aggregate functions. -// It assumes all the input data is sorted by group by key. -// When Next() is called, it will return a result for the same group. -type StreamAggExec struct { - exec.BaseExecutor - - executed bool - // IsChildReturnEmpty indicates whether the child executor only returns an empty input. - IsChildReturnEmpty bool - DefaultVal *chunk.Chunk - GroupChecker *vecgroupchecker.VecGroupChecker - inputIter *chunk.Iterator4Chunk - inputRow chunk.Row - AggFuncs []aggfuncs.AggFunc - partialResults []aggfuncs.PartialResult - groupRows []chunk.Row - childResult *chunk.Chunk - - memTracker *memory.Tracker // track memory usage. - // memUsageOfInitialPartialResult indicates the memory usage of all partial results after initialization. - // All partial results will be reset after processing one group data, and the memory usage should also be reset. - // We can't get memory delta from ResetPartialResult, so record the memory usage here. - memUsageOfInitialPartialResult int64 -} - -// Open implements the Executor Open interface. -func (e *StreamAggExec) Open(ctx context.Context) error { - failpoint.Inject("mockStreamAggExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { - if val, _ := val.(bool); val { - failpoint.Return(errors.New("mock StreamAggExec.baseExecutor.Open returned error")) - } - }) - - if err := e.BaseExecutor.Open(ctx); err != nil { - return err - } - // If panic in Open, the children executor should be closed because they are open. - defer closeBaseExecutor(&e.BaseExecutor) - return e.OpenSelf() -} - -// OpenSelf just opens the StreamAggExec. -func (e *StreamAggExec) OpenSelf() error { - e.childResult = exec.TryNewCacheChunk(e.Children(0)) - e.executed = false - e.IsChildReturnEmpty = true - e.inputIter = chunk.NewIterator4Chunk(e.childResult) - e.inputRow = e.inputIter.End() - - e.partialResults = make([]aggfuncs.PartialResult, 0, len(e.AggFuncs)) - for _, aggFunc := range e.AggFuncs { - partialResult, memDelta := aggFunc.AllocPartialResult() - e.partialResults = append(e.partialResults, partialResult) - e.memUsageOfInitialPartialResult += memDelta - } - - if e.memTracker != nil { - e.memTracker.Reset() - } else { - // bytesLimit <= 0 means no limit, for now we just track the memory footprint - e.memTracker = memory.NewTracker(e.ID(), -1) - } - if e.Ctx().GetSessionVars().TrackAggregateMemoryUsage { - e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) - } - failpoint.Inject("ConsumeRandomPanic", nil) - e.memTracker.Consume(e.childResult.MemoryUsage() + e.memUsageOfInitialPartialResult) - return nil -} - -// Close implements the Executor Close interface. -func (e *StreamAggExec) Close() error { - if e.childResult != nil { - e.memTracker.Consume(-e.childResult.MemoryUsage() - e.memUsageOfInitialPartialResult) - e.childResult = nil - } - e.GroupChecker.Reset() - return e.BaseExecutor.Close() -} - -// Next implements the Executor Next interface. 
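// As the type comment above says, StreamAggExec relies on its input arriving
// sorted by the group-by key, so a group is complete as soon as the key changes
// and no hash table is needed; consumeOneGroup below implements exactly that over
// chunks. The same idea over plain slices, as a self-contained sketch (assumes
// sortedKeys and vals have equal length):
func streamSum(sortedKeys []string, vals []int64) (keys []string, sums []int64) {
	var cur string
	var acc int64
	open := false
	for i, k := range sortedKeys {
		if open && k != cur {
			keys, sums = append(keys, cur), append(sums, acc) // key changed: emit the finished group
			acc = 0
		}
		cur, open = k, true
		acc += vals[i]
	}
	if open {
		keys, sums = append(keys, cur), append(sums, acc) // flush the last group
	}
	return
}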
-func (e *StreamAggExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { - req.Reset() - for !e.executed && !req.IsFull() { - err = e.consumeOneGroup(ctx, req) - if err != nil { - e.executed = true - return err - } - } - return nil -} - -func (e *StreamAggExec) consumeOneGroup(ctx context.Context, chk *chunk.Chunk) (err error) { - if e.GroupChecker.IsExhausted() { - if err = e.consumeCurGroupRowsAndFetchChild(ctx, chk); err != nil { - return err - } - if e.executed { - return nil - } - _, err := e.GroupChecker.SplitIntoGroups(e.childResult) - if err != nil { - return err - } - } - begin, end := e.GroupChecker.GetNextGroup() - for i := begin; i < end; i++ { - e.groupRows = append(e.groupRows, e.childResult.GetRow(i)) - } - - for meetLastGroup := end == e.childResult.NumRows(); meetLastGroup; { - meetLastGroup = false - if err = e.consumeCurGroupRowsAndFetchChild(ctx, chk); err != nil || e.executed { - return err - } - - isFirstGroupSameAsPrev, err := e.GroupChecker.SplitIntoGroups(e.childResult) - if err != nil { - return err - } - - if isFirstGroupSameAsPrev { - begin, end = e.GroupChecker.GetNextGroup() - for i := begin; i < end; i++ { - e.groupRows = append(e.groupRows, e.childResult.GetRow(i)) - } - meetLastGroup = end == e.childResult.NumRows() - } - } - - err = e.consumeGroupRows() - if err != nil { - return err - } - - return e.appendResult2Chunk(chk) -} - -func (e *StreamAggExec) consumeGroupRows() error { - if len(e.groupRows) == 0 { - return nil - } - - allMemDelta := int64(0) - exprCtx := e.Ctx().GetExprCtx() - for i, aggFunc := range e.AggFuncs { - memDelta, err := aggFunc.UpdatePartialResult(exprCtx.GetEvalCtx(), e.groupRows, e.partialResults[i]) - if err != nil { - return err - } - allMemDelta += memDelta - } - failpoint.Inject("ConsumeRandomPanic", nil) - e.memTracker.Consume(allMemDelta) - e.groupRows = e.groupRows[:0] - return nil -} - -func (e *StreamAggExec) consumeCurGroupRowsAndFetchChild(ctx context.Context, chk *chunk.Chunk) (err error) { - // Before fetching a new batch of input, we should consume the last group. - err = e.consumeGroupRows() - if err != nil { - return err - } - - mSize := e.childResult.MemoryUsage() - err = exec.Next(ctx, e.Children(0), e.childResult) - failpoint.Inject("ConsumeRandomPanic", nil) - e.memTracker.Consume(e.childResult.MemoryUsage() - mSize) - if err != nil { - return err - } - - // No more data. - if e.childResult.NumRows() == 0 { - if !e.IsChildReturnEmpty { - err = e.appendResult2Chunk(chk) - } else if e.DefaultVal != nil { - chk.Append(e.DefaultVal, 0, 1) - } - e.executed = true - return err - } - // Reach here, "e.childrenResults[0].NumRows() > 0" is guaranteed. - e.IsChildReturnEmpty = false - e.inputRow = e.inputIter.Begin() - return nil -} - -// appendResult2Chunk appends result of all the aggregation functions to the -// result chunk, and reset the evaluation context for each aggregation. -func (e *StreamAggExec) appendResult2Chunk(chk *chunk.Chunk) error { - exprCtx := e.Ctx().GetExprCtx() - for i, aggFunc := range e.AggFuncs { - err := aggFunc.AppendFinalResult2Chunk(exprCtx.GetEvalCtx(), e.partialResults[i], chk) - if err != nil { - return err - } - aggFunc.ResetPartialResult(e.partialResults[i]) - } - failpoint.Inject("ConsumeRandomPanic", nil) - // All partial results have been reset, so reset the memory usage. 
- e.memTracker.ReplaceBytesUsed(e.childResult.MemoryUsage() + e.memUsageOfInitialPartialResult) - if len(e.AggFuncs) == 0 { - chk.SetNumVirtualRows(chk.NumRows() + 1) - } - return nil -} diff --git a/pkg/executor/aggregate/agg_util.go b/pkg/executor/aggregate/agg_util.go index 9b1568b88bf2a..81b06147ee1d0 100644 --- a/pkg/executor/aggregate/agg_util.go +++ b/pkg/executor/aggregate/agg_util.go @@ -281,14 +281,14 @@ func (e *HashAggExec) ActionSpill() memory.ActionOnExceed { func failpointError() error { var err error - if val, _err_ := failpoint.Eval(_curpkg_("enableAggSpillIntest")); _err_ == nil { + failpoint.Inject("enableAggSpillIntest", func(val failpoint.Value) { if val.(bool) { num := rand.Intn(1000) if num < 3 { err = errors.Errorf("Random fail is triggered in ParallelAggSpillDiskAction") } } - } + }) return err } diff --git a/pkg/executor/aggregate/agg_util.go__failpoint_stash__ b/pkg/executor/aggregate/agg_util.go__failpoint_stash__ deleted file mode 100644 index 81b06147ee1d0..0000000000000 --- a/pkg/executor/aggregate/agg_util.go__failpoint_stash__ +++ /dev/null @@ -1,312 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package aggregate - -import ( - "bytes" - "cmp" - "fmt" - "math/rand" - "slices" - "sync" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/executor/aggfuncs" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "go.uber.org/zap" -) - -const defaultPartialResultsBufferCap = 2048 -const defaultGroupKeyCap = 8 - -var partialResultsBufferPool = sync.Pool{ - New: func() any { - s := make([][]aggfuncs.PartialResult, 0, defaultPartialResultsBufferCap) - return &s - }, -} - -var groupKeyPool = sync.Pool{ - New: func() any { - s := make([][]byte, 0, defaultGroupKeyCap) - return &s - }, -} - -func getBuffer() (*[][]aggfuncs.PartialResult, *[][]byte) { - partialResultsBuffer := partialResultsBufferPool.Get().(*[][]aggfuncs.PartialResult) - *partialResultsBuffer = (*partialResultsBuffer)[:0] - groupKey := groupKeyPool.Get().(*[][]byte) - *groupKey = (*groupKey)[:0] - return partialResultsBuffer, groupKey -} - -// tryRecycleBuffer recycles small buffers only. This approach reduces the CPU pressure -// from memory allocation during high concurrency aggregation computations (like DDL's scheduled tasks), -// and also prevents the pool from holding too much memory and causing memory pressure. 
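// A standalone sketch of the capped sync.Pool recycling described in the
// comment above: buffers are fetched from a pool, reset to length zero for
// reuse, and only returned when their capacity stayed small, so the pool never
// pins very large allocations. The element type and cap threshold are
// illustrative.
package main

import (
	"fmt"
	"sync"
)

const defaultBufCap = 2048

var bufPool = sync.Pool{
	New: func() any {
		s := make([]int, 0, defaultBufCap)
		return &s
	},
}

func getBuf() *[]int {
	buf := bufPool.Get().(*[]int)
	*buf = (*buf)[:0] // reuse the backing array, drop old contents
	return buf
}

func putBuf(buf *[]int) {
	if cap(*buf) <= defaultBufCap { // recycle small buffers only
		bufPool.Put(buf)
	}
}

func main() {
	buf := getBuf()
	*buf = append(*buf, 1, 2, 3)
	fmt.Println(len(*buf), cap(*buf))
	putBuf(buf)
}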
-func tryRecycleBuffer(buf *[][]aggfuncs.PartialResult, groupKey *[][]byte) { - if cap(*buf) <= defaultPartialResultsBufferCap { - partialResultsBufferPool.Put(buf) - } - if cap(*groupKey) <= defaultGroupKeyCap { - groupKeyPool.Put(groupKey) - } -} - -func closeBaseExecutor(b *exec.BaseExecutor) { - if r := recover(); r != nil { - // Release the resource, but throw the panic again and let the top level handle it. - terror.Log(b.Close()) - logutil.BgLogger().Warn("panic in Open(), close base executor and throw exception again") - panic(r) - } -} - -func recoveryHashAgg(output chan *AfFinalResult, r any) { - err := util.GetRecoverError(r) - output <- &AfFinalResult{err: err} - logutil.BgLogger().Error("parallel hash aggregation panicked", zap.Error(err), zap.Stack("stack")) -} - -func getGroupKeyMemUsage(groupKey [][]byte) int64 { - mem := int64(0) - for _, key := range groupKey { - mem += int64(cap(key)) - } - mem += aggfuncs.DefSliceSize * int64(cap(groupKey)) - return mem -} - -// GetGroupKey evaluates the group items and args of aggregate functions. -func GetGroupKey(ctx sessionctx.Context, input *chunk.Chunk, groupKey [][]byte, groupByItems []expression.Expression) ([][]byte, error) { - numRows := input.NumRows() - avlGroupKeyLen := min(len(groupKey), numRows) - for i := 0; i < avlGroupKeyLen; i++ { - groupKey[i] = groupKey[i][:0] - } - for i := avlGroupKeyLen; i < numRows; i++ { - groupKey = append(groupKey, make([]byte, 0, 10*len(groupByItems))) - } - - errCtx := ctx.GetSessionVars().StmtCtx.ErrCtx() - exprCtx := ctx.GetExprCtx() - for _, item := range groupByItems { - tp := item.GetType(ctx.GetExprCtx().GetEvalCtx()) - - buf, err := expression.GetColumn(tp.EvalType(), numRows) - if err != nil { - return nil, err - } - - // In strict sql mode like ‘STRICT_TRANS_TABLES’,can not insert an invalid enum value like 0. - // While in sql mode like '', can insert an invalid enum value like 0, - // then the enum value 0 will have the enum name '', which maybe conflict with user defined enum ''. - // Ref to issue #26885. - // This check is used to handle invalid enum name same with user defined enum name. - // Use enum value as groupKey instead of enum name. - if item.GetType(ctx.GetExprCtx().GetEvalCtx()).GetType() == mysql.TypeEnum { - newTp := *tp - newTp.AddFlag(mysql.EnumSetAsIntFlag) - tp = &newTp - } - - if err := expression.EvalExpr(exprCtx.GetEvalCtx(), ctx.GetSessionVars().EnableVectorizedExpression, item, tp.EvalType(), input, buf); err != nil { - expression.PutColumn(buf) - return nil, err - } - // This check is used to avoid error during the execution of `EncodeDecimal`. 
- if item.GetType(ctx.GetExprCtx().GetEvalCtx()).GetType() == mysql.TypeNewDecimal { - newTp := *tp - newTp.SetFlen(0) - tp = &newTp - } - - groupKey, err = codec.HashGroupKey(ctx.GetSessionVars().StmtCtx.TimeZone(), input.NumRows(), buf, groupKey, tp) - err = errCtx.HandleError(err) - if err != nil { - expression.PutColumn(buf) - return nil, err - } - expression.PutColumn(buf) - } - return groupKey[:numRows], nil -} - -// HashAggRuntimeStats record the HashAggExec runtime stat -type HashAggRuntimeStats struct { - PartialConcurrency int - PartialWallTime int64 - FinalConcurrency int - FinalWallTime int64 - PartialStats []*AggWorkerStat - FinalStats []*AggWorkerStat -} - -func (*HashAggRuntimeStats) workerString(buf *bytes.Buffer, prefix string, concurrency int, wallTime int64, workerStats []*AggWorkerStat) { - var totalTime, totalWait, totalExec, totalTaskNum int64 - for _, w := range workerStats { - totalTime += w.WorkerTime - totalWait += w.WaitTime - totalExec += w.ExecTime - totalTaskNum += w.TaskNum - } - buf.WriteString(prefix) - fmt.Fprintf(buf, "_worker:{wall_time:%s, concurrency:%d, task_num:%d, tot_wait:%s, tot_exec:%s, tot_time:%s", - time.Duration(wallTime), concurrency, totalTaskNum, time.Duration(totalWait), time.Duration(totalExec), time.Duration(totalTime)) - n := len(workerStats) - if n > 0 { - slices.SortFunc(workerStats, func(i, j *AggWorkerStat) int { return cmp.Compare(i.WorkerTime, j.WorkerTime) }) - fmt.Fprintf(buf, ", max:%v, p95:%v", - time.Duration(workerStats[n-1].WorkerTime), time.Duration(workerStats[n*19/20].WorkerTime)) - } - buf.WriteString("}") -} - -// String implements the RuntimeStats interface. -func (e *HashAggRuntimeStats) String() string { - buf := bytes.NewBuffer(make([]byte, 0, 64)) - e.workerString(buf, "partial", e.PartialConcurrency, atomic.LoadInt64(&e.PartialWallTime), e.PartialStats) - buf.WriteString(", ") - e.workerString(buf, "final", e.FinalConcurrency, atomic.LoadInt64(&e.FinalWallTime), e.FinalStats) - return buf.String() -} - -// Clone implements the RuntimeStats interface. -func (e *HashAggRuntimeStats) Clone() execdetails.RuntimeStats { - newRs := &HashAggRuntimeStats{ - PartialConcurrency: e.PartialConcurrency, - PartialWallTime: atomic.LoadInt64(&e.PartialWallTime), - FinalConcurrency: e.FinalConcurrency, - FinalWallTime: atomic.LoadInt64(&e.FinalWallTime), - PartialStats: make([]*AggWorkerStat, 0, e.PartialConcurrency), - FinalStats: make([]*AggWorkerStat, 0, e.FinalConcurrency), - } - for _, s := range e.PartialStats { - newRs.PartialStats = append(newRs.PartialStats, s.Clone()) - } - for _, s := range e.FinalStats { - newRs.FinalStats = append(newRs.FinalStats, s.Clone()) - } - return newRs -} - -// Merge implements the RuntimeStats interface. -func (e *HashAggRuntimeStats) Merge(other execdetails.RuntimeStats) { - tmp, ok := other.(*HashAggRuntimeStats) - if !ok { - return - } - atomic.AddInt64(&e.PartialWallTime, atomic.LoadInt64(&tmp.PartialWallTime)) - atomic.AddInt64(&e.FinalWallTime, atomic.LoadInt64(&tmp.FinalWallTime)) - e.PartialStats = append(e.PartialStats, tmp.PartialStats...) - e.FinalStats = append(e.FinalStats, tmp.FinalStats...) -} - -// Tp implements the RuntimeStats interface. -func (*HashAggRuntimeStats) Tp() int { - return execdetails.TpHashAggRuntimeStat -} - -// AggWorkerInfo contains the agg worker information. 
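// A standalone sketch of how the runtime stats above derive max and p95 worker
// time: sort the per-worker durations and index the sorted slice, using the
// same n*19/20 position for p95. The sample durations are illustrative.
package main

import (
	"cmp"
	"fmt"
	"slices"
	"time"
)

func maxAndP95(workerTimes []time.Duration) (maxT, p95T time.Duration) {
	if len(workerTimes) == 0 {
		return 0, 0
	}
	slices.SortFunc(workerTimes, func(a, b time.Duration) int { return cmp.Compare(a, b) })
	n := len(workerTimes)
	return workerTimes[n-1], workerTimes[n*19/20]
}

func main() {
	times := []time.Duration{
		3 * time.Millisecond, 7 * time.Millisecond, 5 * time.Millisecond,
		40 * time.Millisecond, 6 * time.Millisecond,
	}
	maxT, p95T := maxAndP95(times)
	fmt.Println("max:", maxT, "p95:", p95T)
}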
-type AggWorkerInfo struct { - Concurrency int - WallTime int64 -} - -// AggWorkerStat record the AggWorker runtime stat -type AggWorkerStat struct { - TaskNum int64 - WaitTime int64 - ExecTime int64 - WorkerTime int64 -} - -// Clone implements the RuntimeStats interface. -func (w *AggWorkerStat) Clone() *AggWorkerStat { - return &AggWorkerStat{ - TaskNum: w.TaskNum, - WaitTime: w.WaitTime, - ExecTime: w.ExecTime, - WorkerTime: w.WorkerTime, - } -} - -func (e *HashAggExec) actionSpillForUnparallel() memory.ActionOnExceed { - e.spillAction = &AggSpillDiskAction{ - e: e, - } - return e.spillAction -} - -func (e *HashAggExec) actionSpillForParallel() memory.ActionOnExceed { - e.parallelAggSpillAction = &ParallelAggSpillDiskAction{ - e: e, - spillHelper: e.spillHelper, - } - return e.parallelAggSpillAction -} - -// ActionSpill returns an action for spilling intermediate data for hashAgg. -func (e *HashAggExec) ActionSpill() memory.ActionOnExceed { - if e.IsUnparallelExec { - return e.actionSpillForUnparallel() - } - return e.actionSpillForParallel() -} - -func failpointError() error { - var err error - failpoint.Inject("enableAggSpillIntest", func(val failpoint.Value) { - if val.(bool) { - num := rand.Intn(1000) - if num < 3 { - err = errors.Errorf("Random fail is triggered in ParallelAggSpillDiskAction") - } - } - }) - return err -} - -func updateWaitTime(stats *AggWorkerStat, startTime time.Time) { - if stats != nil { - stats.WaitTime += int64(time.Since(startTime)) - } -} - -func updateWorkerTime(stats *AggWorkerStat, startTime time.Time) { - if stats != nil { - stats.WorkerTime += int64(time.Since(startTime)) - } -} - -func updateExecTime(stats *AggWorkerStat, startTime time.Time) { - if stats != nil { - stats.ExecTime += int64(time.Since(startTime)) - stats.TaskNum++ - } -} diff --git a/pkg/executor/aggregate/binding__failpoint_binding__.go b/pkg/executor/aggregate/binding__failpoint_binding__.go deleted file mode 100644 index f2796e9412af6..0000000000000 --- a/pkg/executor/aggregate/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package aggregate - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/executor/analyze.go b/pkg/executor/analyze.go index ed6c7c31ca7b1..9ab4963870dec 100644 --- a/pkg/executor/analyze.go +++ b/pkg/executor/analyze.go @@ -124,12 +124,12 @@ func (e *AnalyzeExec) Next(ctx context.Context, _ *chunk.Chunk) error { prepareV2AnalyzeJobInfo(task.colExec) AddNewAnalyzeJob(e.Ctx(), task.job) } - if _, _err_ := failpoint.Eval(_curpkg_("mockKillPendingAnalyzeJob")); _err_ == nil { + failpoint.Inject("mockKillPendingAnalyzeJob", func() { dom := domain.GetDomain(e.Ctx()) for _, id := range handleutil.GlobalAutoAnalyzeProcessList.All() { dom.SysProcTracker().KillSysProcess(id) } - } + }) TASKLOOP: for _, task := range tasks { select { @@ -154,12 +154,12 @@ TASKLOOP: return err } - if _, _err_ := failpoint.Eval(_curpkg_("mockKillFinishedAnalyzeJob")); _err_ == nil { + failpoint.Inject("mockKillFinishedAnalyzeJob", func() { dom := domain.GetDomain(e.Ctx()) for _, id := range handleutil.GlobalAutoAnalyzeProcessList.All() { dom.SysProcTracker().KillSysProcess(id) } - } + }) // If we enabled dynamic prune mode, then we need to generate global stats here for partition 
tables. if needGlobalStats { err = e.handleGlobalStats(statsHandle, globalStatsMap) @@ -415,7 +415,7 @@ func (e *AnalyzeExec) handleResultsError( } } logutil.BgLogger().Info("use single session to save analyze results") - failpoint.Eval(_curpkg_("handleResultsErrorSingleThreadPanic")) + failpoint.Inject("handleResultsErrorSingleThreadPanic", nil) subSctxs := []sessionctx.Context{e.Ctx()} return e.handleResultsErrorWithConcurrency(internalCtx, concurrency, needGlobalStats, subSctxs, globalStatsMap, resultsCh) } @@ -510,7 +510,7 @@ func (e *AnalyzeExec) analyzeWorker(taskCh <-chan *analyzeTask, resultsCh chan<- if !ok { break } - failpoint.Eval(_curpkg_("handleAnalyzeWorkerPanic")) + failpoint.Inject("handleAnalyzeWorkerPanic", nil) statsHandle.StartAnalyzeJob(task.job) switch task.taskType { case colTask: diff --git a/pkg/executor/analyze.go__failpoint_stash__ b/pkg/executor/analyze.go__failpoint_stash__ deleted file mode 100644 index 9ab4963870dec..0000000000000 --- a/pkg/executor/analyze.go__failpoint_stash__ +++ /dev/null @@ -1,619 +0,0 @@ -// Copyright 2017 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package executor - -import ( - "context" - stderrors "errors" - "fmt" - "math" - "net" - "strconv" - "strings" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/domain/infosync" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/sessiontxn" - "github.com/pingcap/tidb/pkg/statistics" - "github.com/pingcap/tidb/pkg/statistics/handle" - statstypes "github.com/pingcap/tidb/pkg/statistics/handle/types" - handleutil "github.com/pingcap/tidb/pkg/statistics/handle/util" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/sqlescape" - "github.com/pingcap/tipb/go-tipb" - "github.com/tiancaiamao/gp" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" -) - -var _ exec.Executor = &AnalyzeExec{} - -// AnalyzeExec represents Analyze executor. -type AnalyzeExec struct { - exec.BaseExecutor - tasks []*analyzeTask - wg *util.WaitGroupPool - opts map[ast.AnalyzeOptionType]uint64 - OptionsMap map[int64]core.V2AnalyzeOptions - gp *gp.Pool - // errExitCh is used to notice the worker that the whole analyze task is finished when to meet error. - errExitCh chan struct{} -} - -var ( - // RandSeed is the seed for randing package. - // It's public for test. - RandSeed = int64(1) - - // MaxRegionSampleSize is the max sample size for one region when analyze v1 collects samples from table. - // It's public for test. 
- MaxRegionSampleSize = int64(1000) -) - -type taskType int - -const ( - colTask taskType = iota - idxTask -) - -// Next implements the Executor Next interface. -// It will collect all the sample task and run them concurrently. -func (e *AnalyzeExec) Next(ctx context.Context, _ *chunk.Chunk) error { - statsHandle := domain.GetDomain(e.Ctx()).StatsHandle() - infoSchema := sessiontxn.GetTxnManager(e.Ctx()).GetTxnInfoSchema() - sessionVars := e.Ctx().GetSessionVars() - - // Filter the locked tables. - tasks, needAnalyzeTableCnt, skippedTables, err := filterAndCollectTasks(e.tasks, statsHandle, infoSchema) - if err != nil { - return err - } - warnLockedTableMsg(sessionVars, needAnalyzeTableCnt, skippedTables) - - if len(tasks) == 0 { - return nil - } - - // Get the min number of goroutines for parallel execution. - concurrency, err := getBuildStatsConcurrency(e.Ctx()) - if err != nil { - return err - } - concurrency = min(len(tasks), concurrency) - - // Start workers with channel to collect results. - taskCh := make(chan *analyzeTask, concurrency) - resultsCh := make(chan *statistics.AnalyzeResults, 1) - for i := 0; i < concurrency; i++ { - e.wg.Run(func() { e.analyzeWorker(taskCh, resultsCh) }) - } - pruneMode := variable.PartitionPruneMode(sessionVars.PartitionPruneMode.Load()) - // needGlobalStats used to indicate whether we should merge the partition-level stats to global-level stats. - needGlobalStats := pruneMode == variable.Dynamic - globalStatsMap := make(map[globalStatsKey]statstypes.GlobalStatsInfo) - g, gctx := errgroup.WithContext(ctx) - g.Go(func() error { - return e.handleResultsError(ctx, concurrency, needGlobalStats, globalStatsMap, resultsCh, len(tasks)) - }) - for _, task := range tasks { - prepareV2AnalyzeJobInfo(task.colExec) - AddNewAnalyzeJob(e.Ctx(), task.job) - } - failpoint.Inject("mockKillPendingAnalyzeJob", func() { - dom := domain.GetDomain(e.Ctx()) - for _, id := range handleutil.GlobalAutoAnalyzeProcessList.All() { - dom.SysProcTracker().KillSysProcess(id) - } - }) -TASKLOOP: - for _, task := range tasks { - select { - case taskCh <- task: - case <-e.errExitCh: - break TASKLOOP - case <-gctx.Done(): - break TASKLOOP - } - } - close(taskCh) - defer func() { - for _, task := range tasks { - if task.colExec != nil && task.colExec.memTracker != nil { - task.colExec.memTracker.Detach() - } - } - }() - - err = e.waitFinish(ctx, g, resultsCh) - if err != nil { - return err - } - - failpoint.Inject("mockKillFinishedAnalyzeJob", func() { - dom := domain.GetDomain(e.Ctx()) - for _, id := range handleutil.GlobalAutoAnalyzeProcessList.All() { - dom.SysProcTracker().KillSysProcess(id) - } - }) - // If we enabled dynamic prune mode, then we need to generate global stats here for partition tables. - if needGlobalStats { - err = e.handleGlobalStats(statsHandle, globalStatsMap) - if err != nil { - return err - } - } - - // Update analyze options to mysql.analyze_options for auto analyze. - err = e.saveV2AnalyzeOpts() - if err != nil { - sessionVars.StmtCtx.AppendWarning(err) - } - return statsHandle.Update(ctx, infoSchema) -} - -func (e *AnalyzeExec) waitFinish(ctx context.Context, g *errgroup.Group, resultsCh chan *statistics.AnalyzeResults) error { - checkwg, _ := errgroup.WithContext(ctx) - checkwg.Go(func() error { - // It is to wait for the completion of the result handler. if the result handler meets error, we should cancel - // the analyze process by closing the errExitCh. 
- err := g.Wait() - if err != nil { - close(e.errExitCh) - return err - } - return nil - }) - checkwg.Go(func() error { - // Wait all workers done and close the results channel. - e.wg.Wait() - close(resultsCh) - return nil - }) - return checkwg.Wait() -} - -// filterAndCollectTasks filters the tasks that are not locked and collects the table IDs. -func filterAndCollectTasks(tasks []*analyzeTask, statsHandle *handle.Handle, is infoschema.InfoSchema) ([]*analyzeTask, uint, []string, error) { - var ( - filteredTasks []*analyzeTask - skippedTables []string - needAnalyzeTableCnt uint - // tidMap is used to deduplicate table IDs. - // In stats v1, analyze for each index is a single task, and they have the same table id. - tidAndPidsMap = make(map[int64]struct{}, len(tasks)) - ) - - lockedTableAndPartitionIDs, err := getLockedTableAndPartitionIDs(statsHandle, tasks) - if err != nil { - return nil, 0, nil, err - } - - for _, task := range tasks { - // Check if the table or partition is locked. - tableID := getTableIDFromTask(task) - _, isLocked := lockedTableAndPartitionIDs[tableID.TableID] - // If the whole table is not locked, we should check whether the partition is locked. - if !isLocked && tableID.IsPartitionTable() { - _, isLocked = lockedTableAndPartitionIDs[tableID.PartitionID] - } - - // Only analyze the table that is not locked. - if !isLocked { - filteredTasks = append(filteredTasks, task) - } - - // Get the physical table ID. - physicalTableID := tableID.TableID - if tableID.IsPartitionTable() { - physicalTableID = tableID.PartitionID - } - if _, ok := tidAndPidsMap[physicalTableID]; !ok { - if isLocked { - if tableID.IsPartitionTable() { - tbl, _, def := is.FindTableByPartitionID(tableID.PartitionID) - if def == nil { - logutil.BgLogger().Warn("Unknown partition ID in analyze task", zap.Int64("pid", tableID.PartitionID)) - } else { - schema, _ := infoschema.SchemaByTable(is, tbl.Meta()) - skippedTables = append(skippedTables, fmt.Sprintf("%s.%s partition (%s)", schema.Name, tbl.Meta().Name.O, def.Name.O)) - } - } else { - tbl, ok := is.TableByID(physicalTableID) - if !ok { - logutil.BgLogger().Warn("Unknown table ID in analyze task", zap.Int64("tid", physicalTableID)) - } else { - schema, _ := infoschema.SchemaByTable(is, tbl.Meta()) - skippedTables = append(skippedTables, fmt.Sprintf("%s.%s", schema.Name, tbl.Meta().Name.O)) - } - } - } else { - needAnalyzeTableCnt++ - } - tidAndPidsMap[physicalTableID] = struct{}{} - } - } - - return filteredTasks, needAnalyzeTableCnt, skippedTables, nil -} - -// getLockedTableAndPartitionIDs queries the locked tables and partitions. -func getLockedTableAndPartitionIDs(statsHandle *handle.Handle, tasks []*analyzeTask) (map[int64]struct{}, error) { - tidAndPids := make([]int64, 0, len(tasks)) - // Check the locked tables in one transaction. - // We need to check all tables and its partitions. - // Because if the whole table is locked, we should skip all partitions. - for _, task := range tasks { - tableID := getTableIDFromTask(task) - tidAndPids = append(tidAndPids, tableID.TableID) - if tableID.IsPartitionTable() { - tidAndPids = append(tidAndPids, tableID.PartitionID) - } - } - return statsHandle.GetLockedTables(tidAndPids...) -} - -// warnLockedTableMsg warns the locked table IDs. 
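// A standalone sketch of the errgroup coordination used around waitFinish
// above: worker goroutines produce results, a separate goroutine drains them,
// the results channel is closed once all producers are done, and the first
// worker error cancels the shared context. Channel payloads and worker counts
// are illustrative.
package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	results := make(chan int, 4)

	workers, ctx := errgroup.WithContext(context.Background())
	for i := 0; i < 3; i++ {
		i := i
		workers.Go(func() error {
			select {
			case results <- i * i:
				return nil
			case <-ctx.Done():
				return ctx.Err()
			}
		})
	}

	var consumer errgroup.Group
	consumer.Go(func() error {
		for r := range results {
			fmt.Println("got", r)
		}
		return nil
	})

	if err := workers.Wait(); err != nil {
		fmt.Println("worker error:", err)
	}
	close(results) // all producers done; let the consumer finish
	if err := consumer.Wait(); err != nil {
		fmt.Println("consumer error:", err)
	}
}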
-func warnLockedTableMsg(sessionVars *variable.SessionVars, needAnalyzeTableCnt uint, skippedTables []string) { - if len(skippedTables) > 0 { - tables := strings.Join(skippedTables, ", ") - var msg string - if len(skippedTables) > 1 { - msg = "skip analyze locked tables: %s" - if needAnalyzeTableCnt > 0 { - msg = "skip analyze locked tables: %s, other tables will be analyzed" - } - } else { - msg = "skip analyze locked table: %s" - } - sessionVars.StmtCtx.AppendWarning(errors.NewNoStackErrorf(msg, tables)) - } -} - -func getTableIDFromTask(task *analyzeTask) statistics.AnalyzeTableID { - switch task.taskType { - case colTask: - return task.colExec.tableID - case idxTask: - return task.idxExec.tableID - } - - panic("unreachable") -} - -func (e *AnalyzeExec) saveV2AnalyzeOpts() error { - if !variable.PersistAnalyzeOptions.Load() || len(e.OptionsMap) == 0 { - return nil - } - // only to save table options if dynamic prune mode - dynamicPrune := variable.PartitionPruneMode(e.Ctx().GetSessionVars().PartitionPruneMode.Load()) == variable.Dynamic - toSaveMap := make(map[int64]core.V2AnalyzeOptions) - for id, opts := range e.OptionsMap { - if !opts.IsPartition || !dynamicPrune { - toSaveMap[id] = opts - } - } - sql := new(strings.Builder) - sqlescape.MustFormatSQL(sql, "REPLACE INTO mysql.analyze_options (table_id,sample_num,sample_rate,buckets,topn,column_choice,column_ids) VALUES ") - idx := 0 - for _, opts := range toSaveMap { - sampleNum := opts.RawOpts[ast.AnalyzeOptNumSamples] - sampleRate := float64(0) - if val, ok := opts.RawOpts[ast.AnalyzeOptSampleRate]; ok { - sampleRate = math.Float64frombits(val) - } - buckets := opts.RawOpts[ast.AnalyzeOptNumBuckets] - topn := int64(-1) - if val, ok := opts.RawOpts[ast.AnalyzeOptNumTopN]; ok { - topn = int64(val) - } - colChoice := opts.ColChoice.String() - colIDs := make([]string, 0, len(opts.ColumnList)) - for _, colInfo := range opts.ColumnList { - colIDs = append(colIDs, strconv.FormatInt(colInfo.ID, 10)) - } - colIDStrs := strings.Join(colIDs, ",") - sqlescape.MustFormatSQL(sql, "(%?,%?,%?,%?,%?,%?,%?)", opts.PhyTableID, sampleNum, sampleRate, buckets, topn, colChoice, colIDStrs) - if idx < len(toSaveMap)-1 { - sqlescape.MustFormatSQL(sql, ",") - } - idx++ - } - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL) - exec := e.Ctx().GetRestrictedSQLExecutor() - _, _, err := exec.ExecRestrictedSQL(ctx, nil, sql.String()) - if err != nil { - return err - } - return nil -} - -func recordHistoricalStats(sctx sessionctx.Context, tableID int64) error { - statsHandle := domain.GetDomain(sctx).StatsHandle() - historicalStatsEnabled, err := statsHandle.CheckHistoricalStatsEnable() - if err != nil { - return errors.Errorf("check tidb_enable_historical_stats failed: %v", err) - } - if !historicalStatsEnabled { - return nil - } - historicalStatsWorker := domain.GetDomain(sctx).GetHistoricalStatsWorker() - historicalStatsWorker.SendTblToDumpHistoricalStats(tableID) - return nil -} - -// handleResultsError will handle the error fetch from resultsCh and record it in log -func (e *AnalyzeExec) handleResultsError( - ctx context.Context, - concurrency int, - needGlobalStats bool, - globalStatsMap globalStatsMap, - resultsCh <-chan *statistics.AnalyzeResults, - taskNum int, -) (err error) { - defer func() { - if r := recover(); r != nil { - logutil.BgLogger().Error("analyze save stats panic", zap.Any("recover", r), zap.Stack("stack")) - if err != nil { - err = stderrors.Join(err, getAnalyzePanicErr(r)) - } else { - err = getAnalyzePanicErr(r) - 
} - } - }() - partitionStatsConcurrency := e.Ctx().GetSessionVars().AnalyzePartitionConcurrency - // the concurrency of handleResultsError cannot be more than partitionStatsConcurrency - partitionStatsConcurrency = min(taskNum, partitionStatsConcurrency) - // If partitionStatsConcurrency > 1, we will try to demand extra session from Domain to save Analyze results in concurrency. - // If there is no extra session we can use, we will save analyze results in single-thread. - dom := domain.GetDomain(e.Ctx()) - internalCtx := kv.WithInternalSourceType(ctx, kv.InternalTxnStats) - if partitionStatsConcurrency > 1 { - // FIXME: Since we don't use it either to save analysis results or to store job history, it has no effect. Please remove this :( - subSctxs := dom.FetchAnalyzeExec(partitionStatsConcurrency) - warningMessage := "Insufficient sessions to save analyze results. Consider increasing the 'analyze-partition-concurrency-quota' configuration to improve analyze performance. " + - "This value should typically be greater than or equal to the 'tidb_analyze_partition_concurrency' variable." - if len(subSctxs) < partitionStatsConcurrency { - e.Ctx().GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError(warningMessage)) - logutil.BgLogger().Warn( - warningMessage, - zap.Int("sessionCount", len(subSctxs)), - zap.Int("needSessionCount", partitionStatsConcurrency), - ) - } - if len(subSctxs) > 0 { - sessionCount := len(subSctxs) - logutil.BgLogger().Info("use multiple sessions to save analyze results", zap.Int("sessionCount", sessionCount)) - defer func() { - dom.ReleaseAnalyzeExec(subSctxs) - }() - return e.handleResultsErrorWithConcurrency(internalCtx, concurrency, needGlobalStats, subSctxs, globalStatsMap, resultsCh) - } - } - logutil.BgLogger().Info("use single session to save analyze results") - failpoint.Inject("handleResultsErrorSingleThreadPanic", nil) - subSctxs := []sessionctx.Context{e.Ctx()} - return e.handleResultsErrorWithConcurrency(internalCtx, concurrency, needGlobalStats, subSctxs, globalStatsMap, resultsCh) -} - -func (e *AnalyzeExec) handleResultsErrorWithConcurrency( - ctx context.Context, - statsConcurrency int, - needGlobalStats bool, - subSctxs []sessionctx.Context, - globalStatsMap globalStatsMap, - resultsCh <-chan *statistics.AnalyzeResults, -) error { - partitionStatsConcurrency := len(subSctxs) - statsHandle := domain.GetDomain(e.Ctx()).StatsHandle() - wg := util.NewWaitGroupPool(e.gp) - saveResultsCh := make(chan *statistics.AnalyzeResults, partitionStatsConcurrency) - errCh := make(chan error, partitionStatsConcurrency) - for i := 0; i < partitionStatsConcurrency; i++ { - worker := newAnalyzeSaveStatsWorker(saveResultsCh, subSctxs[i], errCh, &e.Ctx().GetSessionVars().SQLKiller) - ctx1 := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) - wg.Run(func() { - worker.run(ctx1, statsHandle, e.Ctx().GetSessionVars().EnableAnalyzeSnapshot) - }) - } - tableIDs := map[int64]struct{}{} - panicCnt := 0 - var err error - for panicCnt < statsConcurrency { - if err := e.Ctx().GetSessionVars().SQLKiller.HandleSignal(); err != nil { - close(saveResultsCh) - return err - } - results, ok := <-resultsCh - if !ok { - break - } - if results.Err != nil { - err = results.Err - if isAnalyzeWorkerPanic(err) { - panicCnt++ - } else { - logutil.Logger(ctx).Error("analyze failed", zap.Error(err)) - } - finishJobWithLog(statsHandle, results.Job, err) - continue - } - handleGlobalStats(needGlobalStats, globalStatsMap, results) - tableIDs[results.TableID.GetStatisticsID()] = 
struct{}{} - saveResultsCh <- results - } - close(saveResultsCh) - wg.Wait() - close(errCh) - if len(errCh) > 0 { - errMsg := make([]string, 0) - for err1 := range errCh { - errMsg = append(errMsg, err1.Error()) - } - err = errors.New(strings.Join(errMsg, ",")) - } - for tableID := range tableIDs { - // Dump stats to historical storage. - if err := recordHistoricalStats(e.Ctx(), tableID); err != nil { - logutil.BgLogger().Error("record historical stats failed", zap.Error(err)) - } - } - return err -} - -func (e *AnalyzeExec) analyzeWorker(taskCh <-chan *analyzeTask, resultsCh chan<- *statistics.AnalyzeResults) { - var task *analyzeTask - statsHandle := domain.GetDomain(e.Ctx()).StatsHandle() - defer func() { - if r := recover(); r != nil { - logutil.BgLogger().Error("analyze worker panicked", zap.Any("recover", r), zap.Stack("stack")) - metrics.PanicCounter.WithLabelValues(metrics.LabelAnalyze).Inc() - // If errExitCh is closed, it means the whole analyze task is aborted. So we do not need to send the result to resultsCh. - err := getAnalyzePanicErr(r) - select { - case resultsCh <- &statistics.AnalyzeResults{ - Err: err, - Job: task.job, - }: - case <-e.errExitCh: - logutil.BgLogger().Error("analyze worker exits because the whole analyze task is aborted", zap.Error(err)) - } - } - }() - for { - var ok bool - task, ok = <-taskCh - if !ok { - break - } - failpoint.Inject("handleAnalyzeWorkerPanic", nil) - statsHandle.StartAnalyzeJob(task.job) - switch task.taskType { - case colTask: - select { - case <-e.errExitCh: - return - case resultsCh <- analyzeColumnsPushDownEntry(e.gp, task.colExec): - } - case idxTask: - select { - case <-e.errExitCh: - return - case resultsCh <- analyzeIndexPushdown(task.idxExec): - } - } - } -} - -type analyzeTask struct { - taskType taskType - idxExec *AnalyzeIndexExec - colExec *AnalyzeColumnsExec - job *statistics.AnalyzeJob -} - -type baseAnalyzeExec struct { - ctx sessionctx.Context - tableID statistics.AnalyzeTableID - concurrency int - analyzePB *tipb.AnalyzeReq - opts map[ast.AnalyzeOptionType]uint64 - job *statistics.AnalyzeJob - snapshot uint64 -} - -// AddNewAnalyzeJob records the new analyze job. 
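// A standalone sketch of the save-worker pattern above: a fixed number of
// workers drain a results channel, any failures go into an error channel
// buffered to the worker count, and the errors are joined after all workers
// exit. The "save" step (failing on negative values) is illustrative.
package main

import (
	"errors"
	"fmt"
	"sync"
)

func saveAll(results <-chan int, workers int) error {
	errCh := make(chan error, workers)
	var wg sync.WaitGroup
	for i := 0; i < workers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for r := range results {
				if r < 0 { // pretend negative results fail to save
					errCh <- fmt.Errorf("save failed for %d", r)
				}
			}
		}()
	}
	wg.Wait()
	close(errCh)

	var errs []error
	for err := range errCh {
		errs = append(errs, err)
	}
	return errors.Join(errs...) // nil when no worker reported an error
}

func main() {
	results := make(chan int, 5)
	for _, r := range []int{1, 2, -3, 4, -5} {
		results <- r
	}
	close(results)
	fmt.Println(saveAll(results, 2))
}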
-func AddNewAnalyzeJob(ctx sessionctx.Context, job *statistics.AnalyzeJob) { - if job == nil { - return - } - var instance string - serverInfo, err := infosync.GetServerInfo() - if err != nil { - logutil.BgLogger().Error("failed to get server info", zap.Error(err)) - instance = "unknown" - } else { - instance = net.JoinHostPort(serverInfo.IP, strconv.Itoa(int(serverInfo.Port))) - } - statsHandle := domain.GetDomain(ctx).StatsHandle() - err = statsHandle.InsertAnalyzeJob(job, instance, ctx.GetSessionVars().ConnectionID) - if err != nil { - logutil.BgLogger().Error("failed to insert analyze job", zap.Error(err)) - } -} - -func finishJobWithLog(statsHandle *handle.Handle, job *statistics.AnalyzeJob, analyzeErr error) { - statsHandle.FinishAnalyzeJob(job, analyzeErr, statistics.TableAnalysisJob) - if job != nil { - var state string - if analyzeErr != nil { - state = statistics.AnalyzeFailed - logutil.BgLogger().Warn(fmt.Sprintf("analyze table `%s`.`%s` has %s", job.DBName, job.TableName, state), - zap.String("partition", job.PartitionName), - zap.String("job info", job.JobInfo), - zap.Time("start time", job.StartTime), - zap.Time("end time", job.EndTime), - zap.String("cost", job.EndTime.Sub(job.StartTime).String()), - zap.String("sample rate reason", job.SampleRateReason), - zap.Error(analyzeErr)) - } else { - state = statistics.AnalyzeFinished - logutil.BgLogger().Info(fmt.Sprintf("analyze table `%s`.`%s` has %s", job.DBName, job.TableName, state), - zap.String("partition", job.PartitionName), - zap.String("job info", job.JobInfo), - zap.Time("start time", job.StartTime), - zap.Time("end time", job.EndTime), - zap.String("cost", job.EndTime.Sub(job.StartTime).String()), - zap.String("sample rate reason", job.SampleRateReason)) - } - } -} - -func handleGlobalStats(needGlobalStats bool, globalStatsMap globalStatsMap, results *statistics.AnalyzeResults) { - if results.TableID.IsPartitionTable() && needGlobalStats { - for _, result := range results.Ars { - if result.IsIndex == 0 { - // If it does not belong to the statistics of index, we need to set it to -1 to distinguish. - globalStatsID := globalStatsKey{tableID: results.TableID.TableID, indexID: int64(-1)} - histIDs := make([]int64, 0, len(result.Hist)) - for _, hg := range result.Hist { - // It's normal virtual column, skip. 
- if hg == nil { - continue - } - histIDs = append(histIDs, hg.ID) - } - globalStatsMap[globalStatsID] = statstypes.GlobalStatsInfo{IsIndex: result.IsIndex, HistIDs: histIDs, StatsVersion: results.StatsVer} - } else { - for _, hg := range result.Hist { - globalStatsID := globalStatsKey{tableID: results.TableID.TableID, indexID: hg.ID} - globalStatsMap[globalStatsID] = statstypes.GlobalStatsInfo{IsIndex: result.IsIndex, HistIDs: []int64{hg.ID}, StatsVersion: results.StatsVer} - } - } - } - } -} diff --git a/pkg/executor/analyze_col.go b/pkg/executor/analyze_col.go index cab8a5101e3d9..78a5190b8d21b 100644 --- a/pkg/executor/analyze_col.go +++ b/pkg/executor/analyze_col.go @@ -176,18 +176,18 @@ func (e *AnalyzeColumnsExec) buildStats(ranges []*ranger.Range, needExtStats boo } statsHandle := domain.GetDomain(e.ctx).StatsHandle() for { - if _, _err_ := failpoint.Eval(_curpkg_("mockKillRunningV1AnalyzeJob")); _err_ == nil { + failpoint.Inject("mockKillRunningV1AnalyzeJob", func() { dom := domain.GetDomain(e.ctx) for _, id := range handleutil.GlobalAutoAnalyzeProcessList.All() { dom.SysProcTracker().KillSysProcess(id) } - } + }) if err := e.ctx.GetSessionVars().SQLKiller.HandleSignal(); err != nil { return nil, nil, nil, nil, nil, err } - if _, _err_ := failpoint.Eval(_curpkg_("mockSlowAnalyzeV1")); _err_ == nil { + failpoint.Inject("mockSlowAnalyzeV1", func() { time.Sleep(1000 * time.Second) - } + }) data, err1 := e.resultHandler.nextRaw(context.TODO()) if err1 != nil { return nil, nil, nil, nil, nil, err1 diff --git a/pkg/executor/analyze_col.go__failpoint_stash__ b/pkg/executor/analyze_col.go__failpoint_stash__ deleted file mode 100644 index 78a5190b8d21b..0000000000000 --- a/pkg/executor/analyze_col.go__failpoint_stash__ +++ /dev/null @@ -1,494 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package executor - -import ( - "context" - "fmt" - "math" - "strings" - "time" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/distsql" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/planner/core" - plannerutil "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/statistics" - handleutil "github.com/pingcap/tidb/pkg/statistics/handle/util" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/memory" - "github.com/pingcap/tidb/pkg/util/ranger" - "github.com/pingcap/tipb/go-tipb" - "github.com/tiancaiamao/gp" -) - -// AnalyzeColumnsExec represents Analyze columns push down executor. 
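// A standalone sketch of the failpoint marker pattern restored by the hunks
// above: failpoint.Inject is an inert marker in a normal build, and
// failpoint-ctl rewrites it into the failpoint.Eval/_curpkg_ form (keeping the
// original source in a __failpoint_stash__ file) when failpoints are enabled.
// The failpoint name and the fetchRows helper below are illustrative.
package main

import (
	"fmt"

	"github.com/pingcap/failpoint"
)

func fetchRows() ([]int, error) {
	var err error
	// No effect unless the build was processed by failpoint-ctl and the
	// "mockFetchRowsError" failpoint is activated, e.g. with return(true).
	failpoint.Inject("mockFetchRowsError", func(val failpoint.Value) {
		if ok, _ := val.(bool); ok {
			err = fmt.Errorf("mock fetch error")
		}
	})
	if err != nil {
		return nil, err
	}
	return []int{1, 2, 3}, nil
}

func main() {
	rows, err := fetchRows()
	fmt.Println(rows, err)
}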
-type AnalyzeColumnsExec struct { - baseAnalyzeExec - - tableInfo *model.TableInfo - colsInfo []*model.ColumnInfo - handleCols plannerutil.HandleCols - commonHandle *model.IndexInfo - resultHandler *tableResultHandler - indexes []*model.IndexInfo - core.AnalyzeInfo - - samplingBuilderWg *notifyErrorWaitGroupWrapper - samplingMergeWg *util.WaitGroupWrapper - - schemaForVirtualColEval *expression.Schema - baseCount int64 - baseModifyCnt int64 - - memTracker *memory.Tracker -} - -func analyzeColumnsPushDownEntry(gp *gp.Pool, e *AnalyzeColumnsExec) *statistics.AnalyzeResults { - if e.AnalyzeInfo.StatsVersion >= statistics.Version2 { - return e.toV2().analyzeColumnsPushDownV2(gp) - } - return e.toV1().analyzeColumnsPushDownV1() -} - -func (e *AnalyzeColumnsExec) toV1() *AnalyzeColumnsExecV1 { - return &AnalyzeColumnsExecV1{ - AnalyzeColumnsExec: e, - } -} - -func (e *AnalyzeColumnsExec) toV2() *AnalyzeColumnsExecV2 { - return &AnalyzeColumnsExecV2{ - AnalyzeColumnsExec: e, - } -} - -func (e *AnalyzeColumnsExec) open(ranges []*ranger.Range) error { - e.memTracker = memory.NewTracker(int(e.ctx.GetSessionVars().PlanID.Load()), -1) - e.memTracker.AttachTo(e.ctx.GetSessionVars().StmtCtx.MemTracker) - e.resultHandler = &tableResultHandler{} - firstPartRanges, secondPartRanges := distsql.SplitRangesAcrossInt64Boundary(ranges, true, false, !hasPkHist(e.handleCols)) - firstResult, err := e.buildResp(firstPartRanges) - if err != nil { - return err - } - if len(secondPartRanges) == 0 { - e.resultHandler.open(nil, firstResult) - return nil - } - var secondResult distsql.SelectResult - secondResult, err = e.buildResp(secondPartRanges) - if err != nil { - return err - } - e.resultHandler.open(firstResult, secondResult) - - return nil -} - -func (e *AnalyzeColumnsExec) buildResp(ranges []*ranger.Range) (distsql.SelectResult, error) { - var builder distsql.RequestBuilder - reqBuilder := builder.SetHandleRangesForTables(e.ctx.GetDistSQLCtx(), []int64{e.TableID.GetStatisticsID()}, e.handleCols != nil && !e.handleCols.IsInt(), ranges) - builder.SetResourceGroupTagger(e.ctx.GetSessionVars().StmtCtx.GetResourceGroupTagger()) - startTS := uint64(math.MaxUint64) - isoLevel := kv.RC - if e.ctx.GetSessionVars().EnableAnalyzeSnapshot { - startTS = e.snapshot - isoLevel = kv.SI - } - // Always set KeepOrder of the request to be true, in order to compute - // correct `correlation` of columns. - kvReq, err := reqBuilder. - SetAnalyzeRequest(e.analyzePB, isoLevel). - SetStartTS(startTS). - SetKeepOrder(true). - SetConcurrency(e.concurrency). - SetMemTracker(e.memTracker). - SetResourceGroupName(e.ctx.GetSessionVars().StmtCtx.ResourceGroupName). - SetExplicitRequestSourceType(e.ctx.GetSessionVars().ExplicitRequestSourceType). 
- Build() - if err != nil { - return nil, err - } - ctx := context.TODO() - result, err := distsql.Analyze(ctx, e.ctx.GetClient(), kvReq, e.ctx.GetSessionVars().KVVars, e.ctx.GetSessionVars().InRestrictedSQL, e.ctx.GetDistSQLCtx()) - if err != nil { - return nil, err - } - return result, nil -} - -func (e *AnalyzeColumnsExec) buildStats(ranges []*ranger.Range, needExtStats bool) (hists []*statistics.Histogram, cms []*statistics.CMSketch, topNs []*statistics.TopN, fms []*statistics.FMSketch, extStats *statistics.ExtendedStatsColl, err error) { - if err = e.open(ranges); err != nil { - return nil, nil, nil, nil, nil, err - } - defer func() { - if err1 := e.resultHandler.Close(); err1 != nil { - hists = nil - cms = nil - extStats = nil - err = err1 - } - }() - var handleHist *statistics.Histogram - var handleCms *statistics.CMSketch - var handleFms *statistics.FMSketch - var handleTopn *statistics.TopN - statsVer := statistics.Version1 - if e.analyzePB.Tp == tipb.AnalyzeType_TypeMixed { - handleHist = &statistics.Histogram{} - handleCms = statistics.NewCMSketch(int32(e.opts[ast.AnalyzeOptCMSketchDepth]), int32(e.opts[ast.AnalyzeOptCMSketchWidth])) - handleTopn = statistics.NewTopN(int(e.opts[ast.AnalyzeOptNumTopN])) - handleFms = statistics.NewFMSketch(statistics.MaxSketchSize) - if e.analyzePB.IdxReq.Version != nil { - statsVer = int(*e.analyzePB.IdxReq.Version) - } - } - pkHist := &statistics.Histogram{} - collectors := make([]*statistics.SampleCollector, len(e.colsInfo)) - for i := range collectors { - collectors[i] = &statistics.SampleCollector{ - IsMerger: true, - FMSketch: statistics.NewFMSketch(statistics.MaxSketchSize), - MaxSampleSize: int64(e.opts[ast.AnalyzeOptNumSamples]), - CMSketch: statistics.NewCMSketch(int32(e.opts[ast.AnalyzeOptCMSketchDepth]), int32(e.opts[ast.AnalyzeOptCMSketchWidth])), - } - } - statsHandle := domain.GetDomain(e.ctx).StatsHandle() - for { - failpoint.Inject("mockKillRunningV1AnalyzeJob", func() { - dom := domain.GetDomain(e.ctx) - for _, id := range handleutil.GlobalAutoAnalyzeProcessList.All() { - dom.SysProcTracker().KillSysProcess(id) - } - }) - if err := e.ctx.GetSessionVars().SQLKiller.HandleSignal(); err != nil { - return nil, nil, nil, nil, nil, err - } - failpoint.Inject("mockSlowAnalyzeV1", func() { - time.Sleep(1000 * time.Second) - }) - data, err1 := e.resultHandler.nextRaw(context.TODO()) - if err1 != nil { - return nil, nil, nil, nil, nil, err1 - } - if data == nil { - break - } - var colResp *tipb.AnalyzeColumnsResp - if e.analyzePB.Tp == tipb.AnalyzeType_TypeMixed { - resp := &tipb.AnalyzeMixedResp{} - err = resp.Unmarshal(data) - if err != nil { - return nil, nil, nil, nil, nil, err - } - colResp = resp.ColumnsResp - handleHist, handleCms, handleFms, handleTopn, err = updateIndexResult(e.ctx, resp.IndexResp, nil, handleHist, - handleCms, handleFms, handleTopn, e.commonHandle, int(e.opts[ast.AnalyzeOptNumBuckets]), - int(e.opts[ast.AnalyzeOptNumTopN]), statsVer) - - if err != nil { - return nil, nil, nil, nil, nil, err - } - } else { - colResp = &tipb.AnalyzeColumnsResp{} - err = colResp.Unmarshal(data) - } - sc := e.ctx.GetSessionVars().StmtCtx - rowCount := int64(0) - if hasPkHist(e.handleCols) { - respHist := statistics.HistogramFromProto(colResp.PkHist) - rowCount = int64(respHist.TotalRowCount()) - pkHist, err = statistics.MergeHistograms(sc, pkHist, respHist, int(e.opts[ast.AnalyzeOptNumBuckets]), statistics.Version1) - if err != nil { - return nil, nil, nil, nil, nil, err - } - } - for i, rc := range colResp.Collectors { - respSample 
:= statistics.SampleCollectorFromProto(rc) - rowCount = respSample.Count + respSample.NullCount - collectors[i].MergeSampleCollector(sc, respSample) - } - statsHandle.UpdateAnalyzeJobProgress(e.job, rowCount) - } - timeZone := e.ctx.GetSessionVars().Location() - if hasPkHist(e.handleCols) { - pkInfo := e.handleCols.GetCol(0) - pkHist.ID = pkInfo.ID - err = pkHist.DecodeTo(pkInfo.RetType, timeZone) - if err != nil { - return nil, nil, nil, nil, nil, err - } - hists = append(hists, pkHist) - cms = append(cms, nil) - topNs = append(topNs, nil) - fms = append(fms, nil) - } - for i, col := range e.colsInfo { - if e.StatsVersion < 2 { - // In analyze version 2, we don't collect TopN this way. We will collect TopN from samples in `BuildColumnHistAndTopN()` below. - err := collectors[i].ExtractTopN(uint32(e.opts[ast.AnalyzeOptNumTopN]), e.ctx.GetSessionVars().StmtCtx, &col.FieldType, timeZone) - if err != nil { - return nil, nil, nil, nil, nil, err - } - topNs = append(topNs, collectors[i].TopN) - } - for j, s := range collectors[i].Samples { - s.Ordinal = j - s.Value, err = tablecodec.DecodeColumnValue(s.Value.GetBytes(), &col.FieldType, timeZone) - if err != nil { - return nil, nil, nil, nil, nil, err - } - // When collation is enabled, we store the Key representation of the sampling data. So we set it to kind `Bytes` here - // to avoid to convert it to its Key representation once more. - if s.Value.Kind() == types.KindString { - s.Value.SetBytes(s.Value.GetBytes()) - } - } - var hg *statistics.Histogram - var err error - var topn *statistics.TopN - if e.StatsVersion < 2 { - hg, err = statistics.BuildColumn(e.ctx, int64(e.opts[ast.AnalyzeOptNumBuckets]), col.ID, collectors[i], &col.FieldType) - } else { - hg, topn, err = statistics.BuildHistAndTopN(e.ctx, int(e.opts[ast.AnalyzeOptNumBuckets]), int(e.opts[ast.AnalyzeOptNumTopN]), col.ID, collectors[i], &col.FieldType, true, nil, true) - topNs = append(topNs, topn) - } - if err != nil { - return nil, nil, nil, nil, nil, err - } - hists = append(hists, hg) - collectors[i].CMSketch.CalcDefaultValForAnalyze(uint64(hg.NDV)) - cms = append(cms, collectors[i].CMSketch) - fms = append(fms, collectors[i].FMSketch) - } - if needExtStats { - extStats, err = statistics.BuildExtendedStats(e.ctx, e.TableID.GetStatisticsID(), e.colsInfo, collectors) - if err != nil { - return nil, nil, nil, nil, nil, err - } - } - if handleHist != nil { - handleHist.ID = e.commonHandle.ID - if handleTopn != nil && handleTopn.TotalCount() > 0 { - handleHist.RemoveVals(handleTopn.TopN) - } - if handleCms != nil { - handleCms.CalcDefaultValForAnalyze(uint64(handleHist.NDV)) - } - hists = append([]*statistics.Histogram{handleHist}, hists...) - cms = append([]*statistics.CMSketch{handleCms}, cms...) - fms = append([]*statistics.FMSketch{handleFms}, fms...) - topNs = append([]*statistics.TopN{handleTopn}, topNs...) 
- } - return hists, cms, topNs, fms, extStats, nil -} - -// AnalyzeColumnsExecV1 is used to maintain v1 analyze process -type AnalyzeColumnsExecV1 struct { - *AnalyzeColumnsExec -} - -func (e *AnalyzeColumnsExecV1) analyzeColumnsPushDownV1() *statistics.AnalyzeResults { - var ranges []*ranger.Range - if hc := e.handleCols; hc != nil { - if hc.IsInt() { - ranges = ranger.FullIntRange(mysql.HasUnsignedFlag(hc.GetCol(0).RetType.GetFlag())) - } else { - ranges = ranger.FullNotNullRange() - } - } else { - ranges = ranger.FullIntRange(false) - } - collExtStats := e.ctx.GetSessionVars().EnableExtendedStats - hists, cms, topNs, fms, extStats, err := e.buildStats(ranges, collExtStats) - if err != nil { - return &statistics.AnalyzeResults{Err: err, Job: e.job} - } - - if hasPkHist(e.handleCols) { - pkResult := &statistics.AnalyzeResult{ - Hist: hists[:1], - Cms: cms[:1], - TopNs: topNs[:1], - Fms: fms[:1], - } - restResult := &statistics.AnalyzeResult{ - Hist: hists[1:], - Cms: cms[1:], - TopNs: topNs[1:], - Fms: fms[1:], - } - return &statistics.AnalyzeResults{ - TableID: e.tableID, - Ars: []*statistics.AnalyzeResult{pkResult, restResult}, - ExtStats: extStats, - Job: e.job, - StatsVer: e.StatsVersion, - Count: int64(pkResult.Hist[0].TotalRowCount()), - Snapshot: e.snapshot, - } - } - var ars []*statistics.AnalyzeResult - if e.analyzePB.Tp == tipb.AnalyzeType_TypeMixed { - ars = append(ars, &statistics.AnalyzeResult{ - Hist: []*statistics.Histogram{hists[0]}, - Cms: []*statistics.CMSketch{cms[0]}, - TopNs: []*statistics.TopN{topNs[0]}, - Fms: []*statistics.FMSketch{nil}, - IsIndex: 1, - }) - hists = hists[1:] - cms = cms[1:] - topNs = topNs[1:] - } - colResult := &statistics.AnalyzeResult{ - Hist: hists, - Cms: cms, - TopNs: topNs, - Fms: fms, - } - ars = append(ars, colResult) - cnt := int64(hists[0].TotalRowCount()) - if e.StatsVersion >= statistics.Version2 { - cnt += int64(topNs[0].TotalCount()) - } - return &statistics.AnalyzeResults{ - TableID: e.tableID, - Ars: ars, - Job: e.job, - StatsVer: e.StatsVersion, - ExtStats: extStats, - Count: cnt, - Snapshot: e.snapshot, - } -} - -func hasPkHist(handleCols plannerutil.HandleCols) bool { - return handleCols != nil && handleCols.IsInt() -} - -// prepareColumns prepares the columns for the analyze job. -func prepareColumns(e *AnalyzeColumnsExec, b *strings.Builder) { - cols := e.colsInfo - // Ignore the _row_id column. - if len(cols) > 0 && cols[len(cols)-1].ID == model.ExtraHandleID { - cols = cols[:len(cols)-1] - } - // If there are no columns, skip the process. - if len(cols) == 0 { - return - } - if len(cols) < len(e.tableInfo.Columns) { - if len(cols) > 1 { - b.WriteString(" columns ") - } else { - b.WriteString(" column ") - } - for i, col := range cols { - if i > 0 { - b.WriteString(", ") - } - b.WriteString(col.Name.O) - } - } else { - b.WriteString(" all columns") - } -} - -// prepareIndexes prepares the indexes for the analyze job. -func prepareIndexes(e *AnalyzeColumnsExec, b *strings.Builder) { - indexes := e.indexes - - // If there are no indexes, skip the process. - if len(indexes) == 0 { - return - } - if len(indexes) < len(e.tableInfo.Indices) { - if len(indexes) > 1 { - b.WriteString(" indexes ") - } else { - b.WriteString(" index ") - } - for i, index := range indexes { - if i > 0 { - b.WriteString(", ") - } - b.WriteString(index.Name.O) - } - } else { - b.WriteString(" all indexes") - } -} - -// prepareV2AnalyzeJobInfo prepares the job info for the analyze job. 
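// A standalone sketch of how the job-info string built below is assembled:
// options are written into a strings.Builder, falling back to the sample rate
// when no fixed sample count is set. The option names and values are
// illustrative, not the exact output of prepareV2AnalyzeJobInfo.
package main

import (
	"fmt"
	"strings"
)

func jobInfo(buckets, topN, samples uint64, sampleRate float64) string {
	var b strings.Builder
	b.WriteString("analyze table with ")
	fmt.Fprintf(&b, "%d buckets, %d topn", buckets, topN)
	if samples != 0 {
		fmt.Fprintf(&b, ", %d samples", samples)
	} else {
		fmt.Fprintf(&b, ", %v samplerate", sampleRate)
	}
	return b.String()
}

func main() {
	fmt.Println(jobInfo(256, 500, 0, 0.25))
}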
-func prepareV2AnalyzeJobInfo(e *AnalyzeColumnsExec) { - // For v1, we analyze all columns in a single job, so we don't need to set the job info. - if e == nil || e.StatsVersion != statistics.Version2 { - return - } - - opts := e.opts - if e.V2Options != nil { - opts = e.V2Options.FilledOpts - } - sampleRate := *e.analyzePB.ColReq.SampleRate - var b strings.Builder - // If it is an internal SQL, it means it is triggered by the system itself(auto-analyze). - if e.ctx.GetSessionVars().InRestrictedSQL { - b.WriteString("auto ") - } - b.WriteString("analyze table") - - prepareIndexes(e, &b) - if len(e.indexes) > 0 && len(e.colsInfo) > 0 { - b.WriteString(",") - } - prepareColumns(e, &b) - - var needComma bool - b.WriteString(" with ") - printOption := func(optType ast.AnalyzeOptionType) { - if val, ok := opts[optType]; ok { - if needComma { - b.WriteString(", ") - } else { - needComma = true - } - b.WriteString(fmt.Sprintf("%v %s", val, strings.ToLower(ast.AnalyzeOptionString[optType]))) - } - } - printOption(ast.AnalyzeOptNumBuckets) - printOption(ast.AnalyzeOptNumTopN) - if opts[ast.AnalyzeOptNumSamples] != 0 { - printOption(ast.AnalyzeOptNumSamples) - } else { - if needComma { - b.WriteString(", ") - } else { - needComma = true - } - b.WriteString(fmt.Sprintf("%v samplerate", sampleRate)) - } - e.job.JobInfo = b.String() -} diff --git a/pkg/executor/analyze_col_v2.go b/pkg/executor/analyze_col_v2.go index 4e41f0b1bbabb..6eefe269fc865 100644 --- a/pkg/executor/analyze_col_v2.go +++ b/pkg/executor/analyze_col_v2.go @@ -616,16 +616,16 @@ func (e *AnalyzeColumnsExecV2) subMergeWorker(resultCh chan<- *samplingMergeResu close(resultCh) } }() - if _, _err_ := failpoint.Eval(_curpkg_("mockAnalyzeSamplingMergeWorkerPanic")); _err_ == nil { + failpoint.Inject("mockAnalyzeSamplingMergeWorkerPanic", func() { panic("failpoint triggered") - } - if val, _err_ := failpoint.Eval(_curpkg_("mockAnalyzeMergeWorkerSlowConsume")); _err_ == nil { + }) + failpoint.Inject("mockAnalyzeMergeWorkerSlowConsume", func(val failpoint.Value) { times := val.(int) for i := 0; i < times; i++ { e.memTracker.Consume(5 << 20) time.Sleep(100 * time.Millisecond) } - } + }) retCollector := statistics.NewRowSampleCollector(int(e.analyzePB.ColReq.SampleSize), e.analyzePB.ColReq.GetSampleRate(), l) for i := 0; i < l; i++ { retCollector.Base().FMSketches = append(retCollector.Base().FMSketches, statistics.NewFMSketch(statistics.MaxSketchSize)) @@ -682,9 +682,9 @@ func (e *AnalyzeColumnsExecV2) subBuildWorker(resultCh chan error, taskCh chan * resultCh <- getAnalyzePanicErr(r) } }() - if _, _err_ := failpoint.Eval(_curpkg_("mockAnalyzeSamplingBuildWorkerPanic")); _err_ == nil { + failpoint.Inject("mockAnalyzeSamplingBuildWorkerPanic", func() { panic("failpoint triggered") - } + }) colLen := len(e.colsInfo) bufferedMemSize := int64(0) @@ -856,18 +856,18 @@ func readDataAndSendTask(ctx sessionctx.Context, handler *tableResultHandler, me // After all tasks are sent, close the mergeTaskCh to notify the mergeWorker that all tasks have been sent. 
defer close(mergeTaskCh) for { - if _, _err_ := failpoint.Eval(_curpkg_("mockKillRunningV2AnalyzeJob")); _err_ == nil { + failpoint.Inject("mockKillRunningV2AnalyzeJob", func() { dom := domain.GetDomain(ctx) for _, id := range handleutil.GlobalAutoAnalyzeProcessList.All() { dom.SysProcTracker().KillSysProcess(id) } - } + }) if err := ctx.GetSessionVars().SQLKiller.HandleSignal(); err != nil { return err } - if _, _err_ := failpoint.Eval(_curpkg_("mockSlowAnalyzeV2")); _err_ == nil { + failpoint.Inject("mockSlowAnalyzeV2", func() { time.Sleep(1000 * time.Second) - } + }) data, err := handler.nextRaw(context.TODO()) if err != nil { diff --git a/pkg/executor/analyze_col_v2.go__failpoint_stash__ b/pkg/executor/analyze_col_v2.go__failpoint_stash__ deleted file mode 100644 index 6eefe269fc865..0000000000000 --- a/pkg/executor/analyze_col_v2.go__failpoint_stash__ +++ /dev/null @@ -1,885 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package executor - -import ( - "context" - stderrors "errors" - "slices" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/statistics" - handleutil "github.com/pingcap/tidb/pkg/statistics/handle/util" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/collate" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "github.com/pingcap/tidb/pkg/util/ranger" - "github.com/pingcap/tidb/pkg/util/timeutil" - "github.com/pingcap/tipb/go-tipb" - "github.com/tiancaiamao/gp" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" -) - -// AnalyzeColumnsExecV2 is used to maintain v2 analyze process -type AnalyzeColumnsExecV2 struct { - *AnalyzeColumnsExec -} - -func (e *AnalyzeColumnsExecV2) analyzeColumnsPushDownV2(gp *gp.Pool) *statistics.AnalyzeResults { - var ranges []*ranger.Range - if hc := e.handleCols; hc != nil { - if hc.IsInt() { - ranges = ranger.FullIntRange(mysql.HasUnsignedFlag(hc.GetCol(0).RetType.GetFlag())) - } else { - ranges = ranger.FullNotNullRange() - } - } else { - ranges = ranger.FullIntRange(false) - } - - collExtStats := e.ctx.GetSessionVars().EnableExtendedStats - // specialIndexes holds indexes that include virtual or prefix columns. For these indexes, - // only the number of distinct values (NDV) is computed using TiKV. Other statistic - // are derived from sample data processed within TiDB. - // The reason is that we want to keep the same row sampling for all columns. 
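// A toy sketch of sample-rate based row collection, in the spirit of the
// sampling analyze path above: each row is kept with probability equal to the
// sample rate. The real RowSampleCollector also supports a fixed sample size
// and merging of per-worker collectors, which this sketch does not model; the
// sampler type and seed below are illustrative.
package main

import (
	"fmt"
	"math/rand"
)

type sampler struct {
	rate    float64
	rng     *rand.Rand
	samples []int
}

func newSampler(rate float64, seed int64) *sampler {
	return &sampler{rate: rate, rng: rand.New(rand.NewSource(seed))}
}

func (s *sampler) collect(row int) {
	if s.rng.Float64() < s.rate {
		s.samples = append(s.samples, row)
	}
}

func main() {
	s := newSampler(0.1, 1)
	for row := 0; row < 1000; row++ {
		s.collect(row)
	}
	fmt.Println("kept", len(s.samples), "of 1000 rows")
}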
- specialIndexes := make([]*model.IndexInfo, 0, len(e.indexes)) - specialIndexesOffsets := make([]int, 0, len(e.indexes)) - for i, idx := range e.indexes { - isSpecial := false - for _, col := range idx.Columns { - colInfo := e.colsInfo[col.Offset] - isVirtualCol := colInfo.IsGenerated() && !colInfo.GeneratedStored - isPrefixCol := col.Length != types.UnspecifiedLength - if isVirtualCol || isPrefixCol { - isSpecial = true - break - } - } - if isSpecial { - specialIndexesOffsets = append(specialIndexesOffsets, i) - specialIndexes = append(specialIndexes, idx) - } - } - samplingStatsConcurrency, err := getBuildSamplingStatsConcurrency(e.ctx) - if err != nil { - e.memTracker.Release(e.memTracker.BytesConsumed()) - return &statistics.AnalyzeResults{Err: err, Job: e.job} - } - statsConcurrncy, err := getBuildStatsConcurrency(e.ctx) - if err != nil { - e.memTracker.Release(e.memTracker.BytesConsumed()) - return &statistics.AnalyzeResults{Err: err, Job: e.job} - } - idxNDVPushDownCh := make(chan analyzeIndexNDVTotalResult, 1) - // subIndexWorkerWg is better to be initialized in handleNDVForSpecialIndexes, however if we do so, golang would - // report unexpected/unreasonable data race error on subIndexWorkerWg when running TestAnalyzeVirtualCol test - // case with `-race` flag now. - wg := util.NewWaitGroupPool(gp) - wg.Run(func() { - e.handleNDVForSpecialIndexes(specialIndexes, idxNDVPushDownCh, statsConcurrncy) - }) - defer wg.Wait() - - count, hists, topNs, fmSketches, extStats, err := e.buildSamplingStats(gp, ranges, collExtStats, specialIndexesOffsets, idxNDVPushDownCh, samplingStatsConcurrency) - if err != nil { - e.memTracker.Release(e.memTracker.BytesConsumed()) - return &statistics.AnalyzeResults{Err: err, Job: e.job} - } - cLen := len(e.analyzePB.ColReq.ColumnsInfo) - colGroupResult := &statistics.AnalyzeResult{ - Hist: hists[cLen:], - TopNs: topNs[cLen:], - Fms: fmSketches[cLen:], - IsIndex: 1, - } - // Discard stats of _tidb_rowid. - // Because the process of analyzing will keep the order of results be the same as the colsInfo in the analyze task, - // and in `buildAnalyzeFullSamplingTask` we always place the _tidb_rowid at the last of colsInfo, so if there are - // stats for _tidb_rowid, it must be at the end of the column stats. - // Virtual column has no histogram yet. So we check nil here. - if hists[cLen-1] != nil && hists[cLen-1].ID == -1 { - cLen-- - } - colResult := &statistics.AnalyzeResult{ - Hist: hists[:cLen], - TopNs: topNs[:cLen], - Fms: fmSketches[:cLen], - } - - return &statistics.AnalyzeResults{ - TableID: e.tableID, - Ars: []*statistics.AnalyzeResult{colResult, colGroupResult}, - Job: e.job, - StatsVer: e.StatsVersion, - Count: count, - Snapshot: e.snapshot, - ExtStats: extStats, - BaseCount: e.baseCount, - BaseModifyCnt: e.baseModifyCnt, - } -} - -// decodeSampleDataWithVirtualColumn constructs the virtual column by evaluating from the decoded normal columns. 
-func (e *AnalyzeColumnsExecV2) decodeSampleDataWithVirtualColumn( - collector statistics.RowSampleCollector, - fieldTps []*types.FieldType, - virtualColIdx []int, - schema *expression.Schema, -) error { - totFts := make([]*types.FieldType, 0, e.schemaForVirtualColEval.Len()) - for _, col := range e.schemaForVirtualColEval.Columns { - totFts = append(totFts, col.RetType) - } - chk := chunk.NewChunkWithCapacity(totFts, len(collector.Base().Samples)) - decoder := codec.NewDecoder(chk, e.ctx.GetSessionVars().Location()) - for _, sample := range collector.Base().Samples { - for i, columns := range sample.Columns { - if schema.Columns[i].VirtualExpr != nil { - continue - } - _, err := decoder.DecodeOne(columns.GetBytes(), i, e.schemaForVirtualColEval.Columns[i].RetType) - if err != nil { - return err - } - } - } - err := table.FillVirtualColumnValue(fieldTps, virtualColIdx, schema.Columns, e.colsInfo, e.ctx.GetExprCtx(), chk) - if err != nil { - return err - } - iter := chunk.NewIterator4Chunk(chk) - for row, i := iter.Begin(), 0; row != iter.End(); row, i = iter.Next(), i+1 { - datums := row.GetDatumRow(totFts) - collector.Base().Samples[i].Columns = datums - } - return nil -} - -func printAnalyzeMergeCollectorLog(oldRootCount, newRootCount, subCount, tableID, partitionID int64, isPartition bool, info string, index int) { - if index < 0 { - logutil.BgLogger().Debug(info, - zap.Int64("tableID", tableID), - zap.Int64("partitionID", partitionID), - zap.Bool("isPartitionTable", isPartition), - zap.Int64("oldRootCount", oldRootCount), - zap.Int64("newRootCount", newRootCount), - zap.Int64("subCount", subCount)) - } else { - logutil.BgLogger().Debug(info, - zap.Int64("tableID", tableID), - zap.Int64("partitionID", partitionID), - zap.Bool("isPartitionTable", isPartition), - zap.Int64("oldRootCount", oldRootCount), - zap.Int64("newRootCount", newRootCount), - zap.Int64("subCount", subCount), - zap.Int("subCollectorIndex", index)) - } -} - -func (e *AnalyzeColumnsExecV2) buildSamplingStats( - gp *gp.Pool, - ranges []*ranger.Range, - needExtStats bool, - indexesWithVirtualColOffsets []int, - idxNDVPushDownCh chan analyzeIndexNDVTotalResult, - samplingStatsConcurrency int, -) ( - count int64, - hists []*statistics.Histogram, - topns []*statistics.TopN, - fmSketches []*statistics.FMSketch, - extStats *statistics.ExtendedStatsColl, - err error, -) { - // Open memory tracker and resultHandler. - if err = e.open(ranges); err != nil { - return 0, nil, nil, nil, nil, err - } - defer func() { - if err1 := e.resultHandler.Close(); err1 != nil { - err = err1 - } - }() - - l := len(e.analyzePB.ColReq.ColumnsInfo) + len(e.analyzePB.ColReq.ColumnGroups) - rootRowCollector := statistics.NewRowSampleCollector(int(e.analyzePB.ColReq.SampleSize), e.analyzePB.ColReq.GetSampleRate(), l) - for i := 0; i < l; i++ { - rootRowCollector.Base().FMSketches = append(rootRowCollector.Base().FMSketches, statistics.NewFMSketch(statistics.MaxSketchSize)) - } - - sc := e.ctx.GetSessionVars().StmtCtx - - // Start workers to merge the result from collectors. - mergeResultCh := make(chan *samplingMergeResult, 1) - mergeTaskCh := make(chan []byte, 1) - var taskEg errgroup.Group - // Start read data from resultHandler and send them to mergeTaskCh. 
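The sampling merge above is a fan-out/fan-in pipeline: one reader goroutine feeds raw coprocessor responses into a task channel, several merge workers each fold their share into a partial collector, and a single collector goroutine merges the partials while tolerating panicked workers. A minimal generic sketch of that shape, with toy task and result types standing in for the real collectors (illustrative only):

package main

import "fmt"

type result struct {
	sum int
	err error
}

func worker(tasks <-chan int, results chan<- result) {
	// A recover guard turns a panic into an error on the result channel,
	// so the collector is never left waiting for a dead worker.
	defer func() {
		if r := recover(); r != nil {
			results <- result{err: fmt.Errorf("worker panicked: %v", r)}
		}
	}()
	partial := 0
	for t := range tasks {
		partial += t // stands in for merging one sub-collector
	}
	results <- result{sum: partial}
}

func main() {
	const workers = 3
	tasks := make(chan int)
	results := make(chan result, workers)

	for i := 0; i < workers; i++ {
		go worker(tasks, results)
	}
	// Reader: feed tasks, then close the channel so the workers drain and exit.
	go func() {
		for i := 1; i <= 10; i++ {
			tasks <- i
		}
		close(tasks)
	}()

	total := 0
	var firstErr error
	for received := 0; received < workers; received++ {
		r := <-results
		if r.err != nil {
			if firstErr == nil {
				firstErr = r.err
			}
			continue // keep draining so no worker blocks on its send
		}
		total += r.sum
	}
	fmt.Println("merged total:", total, "err:", firstErr)
}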
- taskEg.Go(func() (err error) { - defer func() { - if r := recover(); r != nil { - err = getAnalyzePanicErr(r) - } - }() - return readDataAndSendTask(e.ctx, e.resultHandler, mergeTaskCh, e.memTracker) - }) - e.samplingMergeWg = &util.WaitGroupWrapper{} - e.samplingMergeWg.Add(samplingStatsConcurrency) - for i := 0; i < samplingStatsConcurrency; i++ { - id := i - gp.Go(func() { - e.subMergeWorker(mergeResultCh, mergeTaskCh, l, id) - }) - } - // Merge the result from collectors. - mergeWorkerPanicCnt := 0 - mergeEg, mergeCtx := errgroup.WithContext(context.Background()) - mergeEg.Go(func() (err error) { - defer func() { - if r := recover(); r != nil { - err = getAnalyzePanicErr(r) - } - }() - for mergeWorkerPanicCnt < samplingStatsConcurrency { - mergeResult, ok := <-mergeResultCh - if !ok { - break - } - if mergeResult.err != nil { - err = mergeResult.err - if isAnalyzeWorkerPanic(mergeResult.err) { - mergeWorkerPanicCnt++ - } - continue - } - oldRootCollectorSize := rootRowCollector.Base().MemSize - oldRootCollectorCount := rootRowCollector.Base().Count - // Merge the result from sub-collectors. - rootRowCollector.MergeCollector(mergeResult.collector) - newRootCollectorCount := rootRowCollector.Base().Count - printAnalyzeMergeCollectorLog(oldRootCollectorCount, newRootCollectorCount, - mergeResult.collector.Base().Count, e.tableID.TableID, e.tableID.PartitionID, e.tableID.IsPartitionTable(), - "merge subMergeWorker in AnalyzeColumnsExecV2", -1) - e.memTracker.Consume(rootRowCollector.Base().MemSize - oldRootCollectorSize - mergeResult.collector.Base().MemSize) - mergeResult.collector.DestroyAndPutToPool() - } - return err - }) - err = taskEg.Wait() - if err != nil { - mergeCtx.Done() - if err1 := mergeEg.Wait(); err1 != nil { - err = stderrors.Join(err, err1) - } - return 0, nil, nil, nil, nil, getAnalyzePanicErr(err) - } - err = mergeEg.Wait() - defer e.memTracker.Release(rootRowCollector.Base().MemSize) - if err != nil { - return 0, nil, nil, nil, nil, err - } - - // Decode the data from sample collectors. - virtualColIdx := buildVirtualColumnIndex(e.schemaForVirtualColEval, e.colsInfo) - // Filling virtual columns is necessary here because these samples are used to build statistics for indexes that constructed by virtual columns. - if len(virtualColIdx) > 0 { - fieldTps := make([]*types.FieldType, 0, len(virtualColIdx)) - for _, colOffset := range virtualColIdx { - fieldTps = append(fieldTps, e.schemaForVirtualColEval.Columns[colOffset].RetType) - } - err = e.decodeSampleDataWithVirtualColumn(rootRowCollector, fieldTps, virtualColIdx, e.schemaForVirtualColEval) - if err != nil { - return 0, nil, nil, nil, nil, err - } - } else { - // If there's no virtual column, normal decode way is enough. - for _, sample := range rootRowCollector.Base().Samples { - for i := range sample.Columns { - sample.Columns[i], err = tablecodec.DecodeColumnValue(sample.Columns[i].GetBytes(), &e.colsInfo[i].FieldType, sc.TimeZone()) - if err != nil { - return 0, nil, nil, nil, nil, err - } - } - } - } - - // Calculate handle from the row data for each row. It will be used to sort the samples. - for _, sample := range rootRowCollector.Base().Samples { - sample.Handle, err = e.handleCols.BuildHandleByDatums(sample.Columns) - if err != nil { - return 0, nil, nil, nil, nil, err - } - } - colLen := len(e.colsInfo) - // The order of the samples are broken when merging samples from sub-collectors. - // So now we need to sort the samples according to the handle in order to calculate correlation. 
- slices.SortFunc(rootRowCollector.Base().Samples, func(i, j *statistics.ReservoirRowSampleItem) int { - return i.Handle.Compare(j.Handle) - }) - - totalLen := len(e.colsInfo) + len(e.indexes) - hists = make([]*statistics.Histogram, totalLen) - topns = make([]*statistics.TopN, totalLen) - fmSketches = make([]*statistics.FMSketch, 0, totalLen) - buildResultChan := make(chan error, totalLen) - buildTaskChan := make(chan *samplingBuildTask, totalLen) - if totalLen < samplingStatsConcurrency { - samplingStatsConcurrency = totalLen - } - e.samplingBuilderWg = newNotifyErrorWaitGroupWrapper(gp, buildResultChan) - sampleCollectors := make([]*statistics.SampleCollector, len(e.colsInfo)) - exitCh := make(chan struct{}) - e.samplingBuilderWg.Add(samplingStatsConcurrency) - - // Start workers to build stats. - for i := 0; i < samplingStatsConcurrency; i++ { - e.samplingBuilderWg.Run(func() { - e.subBuildWorker(buildResultChan, buildTaskChan, hists, topns, sampleCollectors, exitCh) - }) - } - // Generate tasks for building stats. - for i, col := range e.colsInfo { - buildTaskChan <- &samplingBuildTask{ - id: col.ID, - rootRowCollector: rootRowCollector, - tp: &col.FieldType, - isColumn: true, - slicePos: i, - } - fmSketches = append(fmSketches, rootRowCollector.Base().FMSketches[i]) - } - - indexPushedDownResult := <-idxNDVPushDownCh - if indexPushedDownResult.err != nil { - close(exitCh) - e.samplingBuilderWg.Wait() - return 0, nil, nil, nil, nil, indexPushedDownResult.err - } - for _, offset := range indexesWithVirtualColOffsets { - ret := indexPushedDownResult.results[e.indexes[offset].ID] - rootRowCollector.Base().NullCount[colLen+offset] = ret.Count - rootRowCollector.Base().FMSketches[colLen+offset] = ret.Ars[0].Fms[0] - } - - // Generate tasks for building stats for indexes. - for i, idx := range e.indexes { - buildTaskChan <- &samplingBuildTask{ - id: idx.ID, - rootRowCollector: rootRowCollector, - tp: types.NewFieldType(mysql.TypeBlob), - isColumn: false, - slicePos: colLen + i, - } - fmSketches = append(fmSketches, rootRowCollector.Base().FMSketches[colLen+i]) - } - close(buildTaskChan) - - panicCnt := 0 - for panicCnt < samplingStatsConcurrency { - err1, ok := <-buildResultChan - if !ok { - break - } - if err1 != nil { - err = err1 - if isAnalyzeWorkerPanic(err1) { - panicCnt++ - } - continue - } - } - defer func() { - totalSampleCollectorSize := int64(0) - for _, sampleCollector := range sampleCollectors { - if sampleCollector != nil { - totalSampleCollectorSize += sampleCollector.MemSize - } - } - e.memTracker.Release(totalSampleCollectorSize) - }() - if err != nil { - return 0, nil, nil, nil, nil, err - } - - count = rootRowCollector.Base().Count - if needExtStats { - extStats, err = statistics.BuildExtendedStats(e.ctx, e.TableID.GetStatisticsID(), e.colsInfo, sampleCollectors) - if err != nil { - return 0, nil, nil, nil, nil, err - } - } - - return -} - -// handleNDVForSpecialIndexes deals with the logic to analyze the index containing the virtual column when the mode is full sampling. 
-func (e *AnalyzeColumnsExecV2) handleNDVForSpecialIndexes(indexInfos []*model.IndexInfo, totalResultCh chan analyzeIndexNDVTotalResult, statsConcurrncy int) { - defer func() { - if r := recover(); r != nil { - logutil.BgLogger().Error("analyze ndv for special index panicked", zap.Any("recover", r), zap.Stack("stack")) - metrics.PanicCounter.WithLabelValues(metrics.LabelAnalyze).Inc() - totalResultCh <- analyzeIndexNDVTotalResult{ - err: getAnalyzePanicErr(r), - } - } - }() - tasks := e.buildSubIndexJobForSpecialIndex(indexInfos) - taskCh := make(chan *analyzeTask, len(tasks)) - for _, task := range tasks { - AddNewAnalyzeJob(e.ctx, task.job) - } - resultsCh := make(chan *statistics.AnalyzeResults, len(tasks)) - if len(tasks) < statsConcurrncy { - statsConcurrncy = len(tasks) - } - var subIndexWorkerWg = NewAnalyzeResultsNotifyWaitGroupWrapper(resultsCh) - subIndexWorkerWg.Add(statsConcurrncy) - for i := 0; i < statsConcurrncy; i++ { - subIndexWorkerWg.Run(func() { e.subIndexWorkerForNDV(taskCh, resultsCh) }) - } - for _, task := range tasks { - taskCh <- task - } - close(taskCh) - panicCnt := 0 - totalResult := analyzeIndexNDVTotalResult{ - results: make(map[int64]*statistics.AnalyzeResults, len(indexInfos)), - } - var err error - statsHandle := domain.GetDomain(e.ctx).StatsHandle() - for panicCnt < statsConcurrncy { - results, ok := <-resultsCh - if !ok { - break - } - if results.Err != nil { - err = results.Err - statsHandle.FinishAnalyzeJob(results.Job, err, statistics.TableAnalysisJob) - if isAnalyzeWorkerPanic(err) { - panicCnt++ - } - continue - } - statsHandle.FinishAnalyzeJob(results.Job, nil, statistics.TableAnalysisJob) - totalResult.results[results.Ars[0].Hist[0].ID] = results - } - if err != nil { - totalResult.err = err - } - totalResultCh <- totalResult -} - -// subIndexWorker receive the task for each index and return the result for them. -func (e *AnalyzeColumnsExecV2) subIndexWorkerForNDV(taskCh chan *analyzeTask, resultsCh chan *statistics.AnalyzeResults) { - var task *analyzeTask - statsHandle := domain.GetDomain(e.ctx).StatsHandle() - defer func() { - if r := recover(); r != nil { - logutil.BgLogger().Error("analyze worker panicked", zap.Any("recover", r), zap.Stack("stack")) - metrics.PanicCounter.WithLabelValues(metrics.LabelAnalyze).Inc() - resultsCh <- &statistics.AnalyzeResults{ - Err: getAnalyzePanicErr(r), - Job: task.job, - } - } - }() - for { - var ok bool - task, ok = <-taskCh - if !ok { - break - } - statsHandle.StartAnalyzeJob(task.job) - if task.taskType != idxTask { - resultsCh <- &statistics.AnalyzeResults{ - Err: errors.Errorf("incorrect analyze type"), - Job: task.job, - } - continue - } - task.idxExec.job = task.job - resultsCh <- analyzeIndexNDVPushDown(task.idxExec) - } -} - -// buildSubIndexJobForSpecialIndex builds sub index pushed down task to calculate the NDV information for indexes containing virtual column. -// This is because we cannot push the calculation of the virtual column down to the tikv side. 
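The NDV handling above ultimately keeps one FMSketch per special index and asks it for an estimated distinct count. For readers unfamiliar with the idea, here is a toy version of the classic Flajolet-Martin estimator; it only illustrates the concept, a single sketch like this is quite noisy, and it is not TiDB's statistics.FMSketch implementation.

package main

import (
	"fmt"
	"hash/fnv"
	"math"
)

// toyFM records, over all hashed values seen, which counts of trailing zero
// bits occurred, and estimates the distinct count from the smallest missing one.
type toyFM struct{ bitmap uint64 }

func (s *toyFM) insert(v string) {
	h := fnv.New64a()
	h.Write([]byte(v))
	x := h.Sum64()
	r := 0
	for x&1 == 0 && r < 63 {
		x >>= 1
		r++
	}
	s.bitmap |= 1 << r
}

func (s *toyFM) ndv() float64 {
	r := 0
	for s.bitmap&(1<<r) != 0 {
		r++
	}
	return math.Pow(2, float64(r)) / 0.77351
}

func main() {
	s := &toyFM{}
	for i := 0; i < 1000; i++ {
		s.insert(fmt.Sprintf("row-%d", i%100)) // 1000 rows, 100 distinct values
	}
	fmt.Printf("estimated NDV ~ %.0f (true value 100)\n", s.ndv())
}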
-func (e *AnalyzeColumnsExecV2) buildSubIndexJobForSpecialIndex(indexInfos []*model.IndexInfo) []*analyzeTask { - _, offset := timeutil.Zone(e.ctx.GetSessionVars().Location()) - tasks := make([]*analyzeTask, 0, len(indexInfos)) - sc := e.ctx.GetSessionVars().StmtCtx - concurrency := adaptiveAnlayzeDistSQLConcurrency(context.Background(), e.ctx) - for _, indexInfo := range indexInfos { - base := baseAnalyzeExec{ - ctx: e.ctx, - tableID: e.TableID, - concurrency: concurrency, - analyzePB: &tipb.AnalyzeReq{ - Tp: tipb.AnalyzeType_TypeIndex, - Flags: sc.PushDownFlags(), - TimeZoneOffset: offset, - }, - snapshot: e.snapshot, - } - idxExec := &AnalyzeIndexExec{ - baseAnalyzeExec: base, - isCommonHandle: e.tableInfo.IsCommonHandle, - idxInfo: indexInfo, - } - idxExec.opts = make(map[ast.AnalyzeOptionType]uint64, len(ast.AnalyzeOptionString)) - idxExec.opts[ast.AnalyzeOptNumTopN] = 0 - idxExec.opts[ast.AnalyzeOptCMSketchDepth] = 0 - idxExec.opts[ast.AnalyzeOptCMSketchWidth] = 0 - idxExec.opts[ast.AnalyzeOptNumSamples] = 0 - idxExec.opts[ast.AnalyzeOptNumBuckets] = 1 - statsVersion := new(int32) - *statsVersion = statistics.Version1 - // No Top-N - topnSize := int32(0) - idxExec.analyzePB.IdxReq = &tipb.AnalyzeIndexReq{ - // One bucket to store the null for null histogram. - BucketSize: 1, - NumColumns: int32(len(indexInfo.Columns)), - TopNSize: &topnSize, - Version: statsVersion, - SketchSize: statistics.MaxSketchSize, - } - if idxExec.isCommonHandle && indexInfo.Primary { - idxExec.analyzePB.Tp = tipb.AnalyzeType_TypeCommonHandle - } - // No CM-Sketch. - depth := int32(0) - width := int32(0) - idxExec.analyzePB.IdxReq.CmsketchDepth = &depth - idxExec.analyzePB.IdxReq.CmsketchWidth = &width - autoAnalyze := "" - if e.ctx.GetSessionVars().InRestrictedSQL { - autoAnalyze = "auto " - } - job := &statistics.AnalyzeJob{DBName: e.job.DBName, TableName: e.job.TableName, PartitionName: e.job.PartitionName, JobInfo: autoAnalyze + "analyze ndv for index " + indexInfo.Name.O} - idxExec.job = job - tasks = append(tasks, &analyzeTask{ - taskType: idxTask, - idxExec: idxExec, - job: job, - }) - } - return tasks -} - -func (e *AnalyzeColumnsExecV2) subMergeWorker(resultCh chan<- *samplingMergeResult, taskCh <-chan []byte, l int, index int) { - // Only close the resultCh in the first worker. - closeTheResultCh := index == 0 - defer func() { - if r := recover(); r != nil { - logutil.BgLogger().Error("analyze worker panicked", zap.Any("recover", r), zap.Stack("stack")) - metrics.PanicCounter.WithLabelValues(metrics.LabelAnalyze).Inc() - resultCh <- &samplingMergeResult{err: getAnalyzePanicErr(r)} - } - // Consume the remaining things. - for { - _, ok := <-taskCh - if !ok { - break - } - } - e.samplingMergeWg.Done() - if closeTheResultCh { - e.samplingMergeWg.Wait() - close(resultCh) - } - }() - failpoint.Inject("mockAnalyzeSamplingMergeWorkerPanic", func() { - panic("failpoint triggered") - }) - failpoint.Inject("mockAnalyzeMergeWorkerSlowConsume", func(val failpoint.Value) { - times := val.(int) - for i := 0; i < times; i++ { - e.memTracker.Consume(5 << 20) - time.Sleep(100 * time.Millisecond) - } - }) - retCollector := statistics.NewRowSampleCollector(int(e.analyzePB.ColReq.SampleSize), e.analyzePB.ColReq.GetSampleRate(), l) - for i := 0; i < l; i++ { - retCollector.Base().FMSketches = append(retCollector.Base().FMSketches, statistics.NewFMSketch(statistics.MaxSketchSize)) - } - statsHandle := domain.GetDomain(e.ctx).StatsHandle() - for { - data, ok := <-taskCh - if !ok { - break - } - - // Unmarshal the data. 
- dataSize := int64(cap(data)) - colResp := &tipb.AnalyzeColumnsResp{} - err := colResp.Unmarshal(data) - if err != nil { - resultCh <- &samplingMergeResult{err: err} - return - } - // Consume the memory of the data. - colRespSize := int64(colResp.Size()) - e.memTracker.Consume(colRespSize) - - // Update processed rows. - subCollector := statistics.NewRowSampleCollector(int(e.analyzePB.ColReq.SampleSize), e.analyzePB.ColReq.GetSampleRate(), l) - subCollector.Base().FromProto(colResp.RowCollector, e.memTracker) - statsHandle.UpdateAnalyzeJobProgress(e.job, subCollector.Base().Count) - - // Print collect log. - oldRetCollectorSize := retCollector.Base().MemSize - oldRetCollectorCount := retCollector.Base().Count - retCollector.MergeCollector(subCollector) - newRetCollectorCount := retCollector.Base().Count - printAnalyzeMergeCollectorLog(oldRetCollectorCount, newRetCollectorCount, subCollector.Base().Count, - e.tableID.TableID, e.tableID.PartitionID, e.TableID.IsPartitionTable(), - "merge subCollector in concurrency in AnalyzeColumnsExecV2", index) - - // Consume the memory of the result. - newRetCollectorSize := retCollector.Base().MemSize - subCollectorSize := subCollector.Base().MemSize - e.memTracker.Consume(newRetCollectorSize - oldRetCollectorSize - subCollectorSize) - e.memTracker.Release(dataSize + colRespSize) - subCollector.DestroyAndPutToPool() - } - - resultCh <- &samplingMergeResult{collector: retCollector} -} - -func (e *AnalyzeColumnsExecV2) subBuildWorker(resultCh chan error, taskCh chan *samplingBuildTask, hists []*statistics.Histogram, topns []*statistics.TopN, collectors []*statistics.SampleCollector, exitCh chan struct{}) { - defer func() { - if r := recover(); r != nil { - logutil.BgLogger().Error("analyze worker panicked", zap.Any("recover", r), zap.Stack("stack")) - metrics.PanicCounter.WithLabelValues(metrics.LabelAnalyze).Inc() - resultCh <- getAnalyzePanicErr(r) - } - }() - failpoint.Inject("mockAnalyzeSamplingBuildWorkerPanic", func() { - panic("failpoint triggered") - }) - - colLen := len(e.colsInfo) - bufferedMemSize := int64(0) - bufferedReleaseSize := int64(0) - defer e.memTracker.Consume(bufferedMemSize) - defer e.memTracker.Release(bufferedReleaseSize) - -workLoop: - for { - select { - case task, ok := <-taskCh: - if !ok { - break workLoop - } - var collector *statistics.SampleCollector - if task.isColumn { - if e.colsInfo[task.slicePos].IsGenerated() && !e.colsInfo[task.slicePos].GeneratedStored { - hists[task.slicePos] = nil - topns[task.slicePos] = nil - continue - } - sampleNum := task.rootRowCollector.Base().Samples.Len() - sampleItems := make([]*statistics.SampleItem, 0, sampleNum) - // consume mandatory memory at the beginning, including empty SampleItems of all sample rows, if exceeds, fast fail - collectorMemSize := int64(sampleNum) * (8 + statistics.EmptySampleItemSize) - e.memTracker.Consume(collectorMemSize) - var collator collate.Collator - ft := e.colsInfo[task.slicePos].FieldType - // When it's new collation data, we need to use its collate key instead of original value because only - // the collate key can ensure the correct ordering. - // This is also corresponding to similar operation in (*statistics.Column).GetColumnRowCount(). 
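The collate-key substitution described above exists because sorting raw UTF-8 bytes does not match a non-binary collation, which is why the code that follows swaps the sampled string for collate.GetCollator(ft.GetCollate()).Key(...) before histograms are built. A stand-in illustration of the effect, using a toy lower-casing key instead of a real collator (illustrative only):

package main

import (
	"bytes"
	"fmt"
	"sort"
	"strings"
)

// toyKey stands in for a collator's sort key: under a case-insensitive
// collation, "Apple" and "apple" must compare equal, which their raw bytes do not.
func toyKey(s string) []byte { return []byte(strings.ToLower(s)) }

func main() {
	vals := []string{"banana", "Apple", "apple", "Banana"}

	byRawBytes := append([]string(nil), vals...)
	sort.Strings(byRawBytes)
	fmt.Println(byRawBytes) // [Apple Banana apple banana]: all uppercase sorts first

	byCollateKey := append([]string(nil), vals...)
	sort.SliceStable(byCollateKey, func(i, j int) bool {
		return bytes.Compare(toyKey(byCollateKey[i]), toyKey(byCollateKey[j])) < 0
	})
	fmt.Println(byCollateKey) // [Apple apple banana Banana]: case-insensitive order
}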
- if ft.EvalType() == types.ETString && ft.GetType() != mysql.TypeEnum && ft.GetType() != mysql.TypeSet { - collator = collate.GetCollator(ft.GetCollate()) - } - for j, row := range task.rootRowCollector.Base().Samples { - if row.Columns[task.slicePos].IsNull() { - continue - } - val := row.Columns[task.slicePos] - // If this value is very big, we think that it is not a value that can occur many times. So we don't record it. - if len(val.GetBytes()) > statistics.MaxSampleValueLength { - continue - } - if collator != nil { - val.SetBytes(collator.Key(val.GetString())) - deltaSize := int64(cap(val.GetBytes())) - collectorMemSize += deltaSize - e.memTracker.BufferedConsume(&bufferedMemSize, deltaSize) - } - sampleItems = append(sampleItems, &statistics.SampleItem{ - Value: val, - Ordinal: j, - }) - // tmp memory usage - deltaSize := val.MemUsage() + 4 // content of SampleItem is copied - e.memTracker.BufferedConsume(&bufferedMemSize, deltaSize) - e.memTracker.BufferedRelease(&bufferedReleaseSize, deltaSize) - } - collector = &statistics.SampleCollector{ - Samples: sampleItems, - NullCount: task.rootRowCollector.Base().NullCount[task.slicePos], - Count: task.rootRowCollector.Base().Count - task.rootRowCollector.Base().NullCount[task.slicePos], - FMSketch: task.rootRowCollector.Base().FMSketches[task.slicePos], - TotalSize: task.rootRowCollector.Base().TotalSizes[task.slicePos], - MemSize: collectorMemSize, - } - } else { - var tmpDatum types.Datum - var err error - idx := e.indexes[task.slicePos-colLen] - sampleNum := task.rootRowCollector.Base().Samples.Len() - sampleItems := make([]*statistics.SampleItem, 0, sampleNum) - // consume mandatory memory at the beginning, including all SampleItems, if exceeds, fast fail - // 8 is size of reference, 8 is the size of "b := make([]byte, 0, 8)" - collectorMemSize := int64(sampleNum) * (8 + statistics.EmptySampleItemSize + 8) - e.memTracker.Consume(collectorMemSize) - errCtx := e.ctx.GetSessionVars().StmtCtx.ErrCtx() - indexSampleCollectLoop: - for _, row := range task.rootRowCollector.Base().Samples { - if len(idx.Columns) == 1 && row.Columns[idx.Columns[0].Offset].IsNull() { - continue - } - b := make([]byte, 0, 8) - for _, col := range idx.Columns { - // If the index value contains one value which is too long, we think that it's a value that doesn't occur many times. 
- if len(row.Columns[col.Offset].GetBytes()) > statistics.MaxSampleValueLength { - continue indexSampleCollectLoop - } - if col.Length != types.UnspecifiedLength { - row.Columns[col.Offset].Copy(&tmpDatum) - ranger.CutDatumByPrefixLen(&tmpDatum, col.Length, &e.colsInfo[col.Offset].FieldType) - b, err = codec.EncodeKey(e.ctx.GetSessionVars().StmtCtx.TimeZone(), b, tmpDatum) - err = errCtx.HandleError(err) - if err != nil { - resultCh <- err - continue workLoop - } - continue - } - b, err = codec.EncodeKey(e.ctx.GetSessionVars().StmtCtx.TimeZone(), b, row.Columns[col.Offset]) - err = errCtx.HandleError(err) - if err != nil { - resultCh <- err - continue workLoop - } - } - sampleItems = append(sampleItems, &statistics.SampleItem{ - Value: types.NewBytesDatum(b), - }) - // tmp memory usage - deltaSize := sampleItems[len(sampleItems)-1].Value.MemUsage() - e.memTracker.BufferedConsume(&bufferedMemSize, deltaSize) - e.memTracker.BufferedRelease(&bufferedReleaseSize, deltaSize) - } - collector = &statistics.SampleCollector{ - Samples: sampleItems, - NullCount: task.rootRowCollector.Base().NullCount[task.slicePos], - Count: task.rootRowCollector.Base().Count - task.rootRowCollector.Base().NullCount[task.slicePos], - FMSketch: task.rootRowCollector.Base().FMSketches[task.slicePos], - TotalSize: task.rootRowCollector.Base().TotalSizes[task.slicePos], - MemSize: collectorMemSize, - } - } - if task.isColumn { - collectors[task.slicePos] = collector - } - releaseCollectorMemory := func() { - if !task.isColumn { - e.memTracker.Release(collector.MemSize) - } - } - hist, topn, err := statistics.BuildHistAndTopN(e.ctx, int(e.opts[ast.AnalyzeOptNumBuckets]), int(e.opts[ast.AnalyzeOptNumTopN]), task.id, collector, task.tp, task.isColumn, e.memTracker, e.ctx.GetSessionVars().EnableExtendedStats) - if err != nil { - resultCh <- err - releaseCollectorMemory() - continue - } - finalMemSize := hist.MemoryUsage() + topn.MemoryUsage() - e.memTracker.Consume(finalMemSize) - hists[task.slicePos] = hist - topns[task.slicePos] = topn - resultCh <- nil - releaseCollectorMemory() - case <-exitCh: - return - } - } -} - -type analyzeIndexNDVTotalResult struct { - results map[int64]*statistics.AnalyzeResults - err error -} - -type samplingMergeResult struct { - collector statistics.RowSampleCollector - err error -} - -type samplingBuildTask struct { - id int64 - rootRowCollector statistics.RowSampleCollector - tp *types.FieldType - isColumn bool - slicePos int -} - -func readDataAndSendTask(ctx sessionctx.Context, handler *tableResultHandler, mergeTaskCh chan []byte, memTracker *memory.Tracker) error { - // After all tasks are sent, close the mergeTaskCh to notify the mergeWorker that all tasks have been sent. 
- defer close(mergeTaskCh) - for { - failpoint.Inject("mockKillRunningV2AnalyzeJob", func() { - dom := domain.GetDomain(ctx) - for _, id := range handleutil.GlobalAutoAnalyzeProcessList.All() { - dom.SysProcTracker().KillSysProcess(id) - } - }) - if err := ctx.GetSessionVars().SQLKiller.HandleSignal(); err != nil { - return err - } - failpoint.Inject("mockSlowAnalyzeV2", func() { - time.Sleep(1000 * time.Second) - }) - - data, err := handler.nextRaw(context.TODO()) - if err != nil { - return errors.Trace(err) - } - if data == nil { - break - } - - memTracker.Consume(int64(cap(data))) - mergeTaskCh <- data - } - - return nil -} diff --git a/pkg/executor/analyze_idx.go b/pkg/executor/analyze_idx.go index 65b1850491e6d..1c26be7782d4f 100644 --- a/pkg/executor/analyze_idx.go +++ b/pkg/executor/analyze_idx.go @@ -180,11 +180,11 @@ func (e *AnalyzeIndexExec) fetchAnalyzeResult(ranges []*ranger.Range, isNullRang } func (e *AnalyzeIndexExec) buildStatsFromResult(result distsql.SelectResult, needCMS bool) (*statistics.Histogram, *statistics.CMSketch, *statistics.FMSketch, *statistics.TopN, error) { - if val, _err_ := failpoint.Eval(_curpkg_("buildStatsFromResult")); _err_ == nil { + failpoint.Inject("buildStatsFromResult", func(val failpoint.Value) { if val.(bool) { - return nil, nil, nil, nil, errors.New("mock buildStatsFromResult error") + failpoint.Return(nil, nil, nil, nil, errors.New("mock buildStatsFromResult error")) } - } + }) hist := &statistics.Histogram{} var cms *statistics.CMSketch var topn *statistics.TopN @@ -198,18 +198,18 @@ func (e *AnalyzeIndexExec) buildStatsFromResult(result distsql.SelectResult, nee statsVer = int(*e.analyzePB.IdxReq.Version) } for { - if _, _err_ := failpoint.Eval(_curpkg_("mockKillRunningAnalyzeIndexJob")); _err_ == nil { + failpoint.Inject("mockKillRunningAnalyzeIndexJob", func() { dom := domain.GetDomain(e.ctx) for _, id := range handleutil.GlobalAutoAnalyzeProcessList.All() { dom.SysProcTracker().KillSysProcess(id) } - } + }) if err := e.ctx.GetSessionVars().SQLKiller.HandleSignal(); err != nil { return nil, nil, nil, nil, err } - if _, _err_ := failpoint.Eval(_curpkg_("mockSlowAnalyzeIndex")); _err_ == nil { + failpoint.Inject("mockSlowAnalyzeIndex", func() { time.Sleep(1000 * time.Second) - } + }) data, err := result.NextRaw(context.TODO()) if err != nil { return nil, nil, nil, nil, err diff --git a/pkg/executor/analyze_idx.go__failpoint_stash__ b/pkg/executor/analyze_idx.go__failpoint_stash__ deleted file mode 100644 index 1c26be7782d4f..0000000000000 --- a/pkg/executor/analyze_idx.go__failpoint_stash__ +++ /dev/null @@ -1,344 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package executor - -import ( - "context" - "math" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/distsql" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/statistics" - handleutil "github.com/pingcap/tidb/pkg/statistics/handle/util" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/ranger" - "github.com/pingcap/tipb/go-tipb" - "go.uber.org/zap" -) - -// AnalyzeIndexExec represents analyze index push down executor. -type AnalyzeIndexExec struct { - baseAnalyzeExec - - idxInfo *model.IndexInfo - isCommonHandle bool - result distsql.SelectResult - countNullRes distsql.SelectResult -} - -func analyzeIndexPushdown(idxExec *AnalyzeIndexExec) *statistics.AnalyzeResults { - ranges := ranger.FullRange() - // For single-column index, we do not load null rows from TiKV, so the built histogram would not include - // null values, and its `NullCount` would be set by result of another distsql call to get null rows. - // For multi-column index, we cannot define null for the rows, so we still use full range, and the rows - // containing null fields would exist in built histograms. Note that, the `NullCount` of histograms for - // multi-column index is always 0 then. - if len(idxExec.idxInfo.Columns) == 1 { - ranges = ranger.FullNotNullRange() - } - hist, cms, fms, topN, err := idxExec.buildStats(ranges, true) - if err != nil { - return &statistics.AnalyzeResults{Err: err, Job: idxExec.job} - } - var statsVer = statistics.Version1 - if idxExec.analyzePB.IdxReq.Version != nil { - statsVer = int(*idxExec.analyzePB.IdxReq.Version) - } - idxResult := &statistics.AnalyzeResult{ - Hist: []*statistics.Histogram{hist}, - TopNs: []*statistics.TopN{topN}, - Fms: []*statistics.FMSketch{fms}, - IsIndex: 1, - } - if statsVer != statistics.Version2 { - idxResult.Cms = []*statistics.CMSketch{cms} - } - cnt := hist.NullCount - if hist.Len() > 0 { - cnt += hist.Buckets[hist.Len()-1].Count - } - if topN.TotalCount() > 0 { - cnt += int64(topN.TotalCount()) - } - result := &statistics.AnalyzeResults{ - TableID: idxExec.tableID, - Ars: []*statistics.AnalyzeResult{idxResult}, - Job: idxExec.job, - StatsVer: statsVer, - Count: cnt, - Snapshot: idxExec.snapshot, - } - if idxExec.idxInfo.MVIndex { - result.ForMVIndex = true - } - return result -} - -func (e *AnalyzeIndexExec) buildStats(ranges []*ranger.Range, considerNull bool) (hist *statistics.Histogram, cms *statistics.CMSketch, fms *statistics.FMSketch, topN *statistics.TopN, err error) { - if err = e.open(ranges, considerNull); err != nil { - return nil, nil, nil, nil, err - } - defer func() { - err1 := closeAll(e.result, e.countNullRes) - if err == nil { - err = err1 - } - }() - hist, cms, fms, topN, err = e.buildStatsFromResult(e.result, true) - if err != nil { - return nil, nil, nil, nil, err - } - if e.countNullRes != nil { - nullHist, _, _, _, err := e.buildStatsFromResult(e.countNullRes, false) - if err != nil { - return nil, nil, nil, nil, err - } - if l := nullHist.Len(); l > 0 { - hist.NullCount = nullHist.Buckets[l-1].Count - } - } - hist.ID = e.idxInfo.ID - return hist, cms, fms, topN, nil -} - -func (e *AnalyzeIndexExec) open(ranges []*ranger.Range, considerNull bool) error { - err := 
e.fetchAnalyzeResult(ranges, false) - if err != nil { - return err - } - if considerNull && len(e.idxInfo.Columns) == 1 { - ranges = ranger.NullRange() - err = e.fetchAnalyzeResult(ranges, true) - if err != nil { - return err - } - } - return nil -} - -// fetchAnalyzeResult builds and dispatches the `kv.Request` from given ranges, and stores the `SelectResult` -// in corresponding fields based on the input `isNullRange` argument, which indicates if the range is the -// special null range for single-column index to get the null count. -func (e *AnalyzeIndexExec) fetchAnalyzeResult(ranges []*ranger.Range, isNullRange bool) error { - var builder distsql.RequestBuilder - var kvReqBuilder *distsql.RequestBuilder - if e.isCommonHandle && e.idxInfo.Primary { - kvReqBuilder = builder.SetHandleRangesForTables(e.ctx.GetDistSQLCtx(), []int64{e.tableID.GetStatisticsID()}, true, ranges) - } else { - kvReqBuilder = builder.SetIndexRangesForTables(e.ctx.GetDistSQLCtx(), []int64{e.tableID.GetStatisticsID()}, e.idxInfo.ID, ranges) - } - kvReqBuilder.SetResourceGroupTagger(e.ctx.GetSessionVars().StmtCtx.GetResourceGroupTagger()) - startTS := uint64(math.MaxUint64) - isoLevel := kv.RC - if e.ctx.GetSessionVars().EnableAnalyzeSnapshot { - startTS = e.snapshot - isoLevel = kv.SI - } - kvReq, err := kvReqBuilder. - SetAnalyzeRequest(e.analyzePB, isoLevel). - SetStartTS(startTS). - SetKeepOrder(true). - SetConcurrency(e.concurrency). - SetResourceGroupName(e.ctx.GetSessionVars().StmtCtx.ResourceGroupName). - SetExplicitRequestSourceType(e.ctx.GetSessionVars().ExplicitRequestSourceType). - Build() - if err != nil { - return err - } - ctx := context.TODO() - result, err := distsql.Analyze(ctx, e.ctx.GetClient(), kvReq, e.ctx.GetSessionVars().KVVars, e.ctx.GetSessionVars().InRestrictedSQL, e.ctx.GetDistSQLCtx()) - if err != nil { - return err - } - if isNullRange { - e.countNullRes = result - } else { - e.result = result - } - return nil -} - -func (e *AnalyzeIndexExec) buildStatsFromResult(result distsql.SelectResult, needCMS bool) (*statistics.Histogram, *statistics.CMSketch, *statistics.FMSketch, *statistics.TopN, error) { - failpoint.Inject("buildStatsFromResult", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(nil, nil, nil, nil, errors.New("mock buildStatsFromResult error")) - } - }) - hist := &statistics.Histogram{} - var cms *statistics.CMSketch - var topn *statistics.TopN - if needCMS { - cms = statistics.NewCMSketch(int32(e.opts[ast.AnalyzeOptCMSketchDepth]), int32(e.opts[ast.AnalyzeOptCMSketchWidth])) - topn = statistics.NewTopN(int(e.opts[ast.AnalyzeOptNumTopN])) - } - fms := statistics.NewFMSketch(statistics.MaxSketchSize) - statsVer := statistics.Version1 - if e.analyzePB.IdxReq.Version != nil { - statsVer = int(*e.analyzePB.IdxReq.Version) - } - for { - failpoint.Inject("mockKillRunningAnalyzeIndexJob", func() { - dom := domain.GetDomain(e.ctx) - for _, id := range handleutil.GlobalAutoAnalyzeProcessList.All() { - dom.SysProcTracker().KillSysProcess(id) - } - }) - if err := e.ctx.GetSessionVars().SQLKiller.HandleSignal(); err != nil { - return nil, nil, nil, nil, err - } - failpoint.Inject("mockSlowAnalyzeIndex", func() { - time.Sleep(1000 * time.Second) - }) - data, err := result.NextRaw(context.TODO()) - if err != nil { - return nil, nil, nil, nil, err - } - if data == nil { - break - } - resp := &tipb.AnalyzeIndexResp{} - err = resp.Unmarshal(data) - if err != nil { - return nil, nil, nil, nil, err - } - hist, cms, fms, topn, err = updateIndexResult(e.ctx, resp, e.job, hist, cms, 
fms, topn, - e.idxInfo, int(e.opts[ast.AnalyzeOptNumBuckets]), int(e.opts[ast.AnalyzeOptNumTopN]), statsVer) - if err != nil { - return nil, nil, nil, nil, err - } - } - if needCMS && topn.TotalCount() > 0 { - hist.RemoveVals(topn.TopN) - } - if statsVer == statistics.Version2 { - hist.StandardizeForV2AnalyzeIndex() - } - if needCMS && cms != nil { - cms.CalcDefaultValForAnalyze(uint64(hist.NDV)) - } - return hist, cms, fms, topn, nil -} - -func (e *AnalyzeIndexExec) buildSimpleStats(ranges []*ranger.Range, considerNull bool) (fms *statistics.FMSketch, nullHist *statistics.Histogram, err error) { - if err = e.open(ranges, considerNull); err != nil { - return nil, nil, err - } - defer func() { - err1 := closeAll(e.result, e.countNullRes) - if err == nil { - err = err1 - } - }() - _, _, fms, _, err = e.buildStatsFromResult(e.result, false) - if e.countNullRes != nil { - nullHist, _, _, _, err := e.buildStatsFromResult(e.countNullRes, false) - if err != nil { - return nil, nil, err - } - if l := nullHist.Len(); l > 0 { - return fms, nullHist, nil - } - } - return fms, nil, nil -} - -func analyzeIndexNDVPushDown(idxExec *AnalyzeIndexExec) *statistics.AnalyzeResults { - ranges := ranger.FullRange() - // For single-column index, we do not load null rows from TiKV, so the built histogram would not include - // null values, and its `NullCount` would be set by result of another distsql call to get null rows. - // For multi-column index, we cannot define null for the rows, so we still use full range, and the rows - // containing null fields would exist in built histograms. Note that, the `NullCount` of histograms for - // multi-column index is always 0 then. - if len(idxExec.idxInfo.Columns) == 1 { - ranges = ranger.FullNotNullRange() - } - fms, nullHist, err := idxExec.buildSimpleStats(ranges, len(idxExec.idxInfo.Columns) == 1) - if err != nil { - return &statistics.AnalyzeResults{Err: err, Job: idxExec.job} - } - result := &statistics.AnalyzeResult{ - Fms: []*statistics.FMSketch{fms}, - // We use histogram to get the Index's ID. - Hist: []*statistics.Histogram{statistics.NewHistogram(idxExec.idxInfo.ID, 0, 0, statistics.Version1, types.NewFieldType(mysql.TypeBlob), 0, 0)}, - IsIndex: 1, - } - r := &statistics.AnalyzeResults{ - TableID: idxExec.tableID, - Ars: []*statistics.AnalyzeResult{result}, - Job: idxExec.job, - // TODO: avoid reusing Version1. 
- StatsVer: statistics.Version1, - } - if nullHist != nil && nullHist.Len() > 0 { - r.Count = nullHist.Buckets[nullHist.Len()-1].Count - } - return r -} - -func updateIndexResult( - ctx sessionctx.Context, - resp *tipb.AnalyzeIndexResp, - job *statistics.AnalyzeJob, - hist *statistics.Histogram, - cms *statistics.CMSketch, - fms *statistics.FMSketch, - topn *statistics.TopN, - idxInfo *model.IndexInfo, - numBuckets int, - numTopN int, - statsVer int, -) ( - *statistics.Histogram, - *statistics.CMSketch, - *statistics.FMSketch, - *statistics.TopN, - error, -) { - var err error - needCMS := cms != nil - respHist := statistics.HistogramFromProto(resp.Hist) - if job != nil { - statsHandle := domain.GetDomain(ctx).StatsHandle() - statsHandle.UpdateAnalyzeJobProgress(job, int64(respHist.TotalRowCount())) - } - hist, err = statistics.MergeHistograms(ctx.GetSessionVars().StmtCtx, hist, respHist, numBuckets, statsVer) - if err != nil { - return nil, nil, nil, nil, err - } - if needCMS { - if resp.Cms == nil { - logutil.Logger(context.TODO()).Warn("nil CMS in response", zap.String("table", idxInfo.Table.O), zap.String("index", idxInfo.Name.O)) - } else { - cm, tmpTopN := statistics.CMSketchAndTopNFromProto(resp.Cms) - if err := cms.MergeCMSketch(cm); err != nil { - return nil, nil, nil, nil, err - } - statistics.MergeTopNAndUpdateCMSketch(topn, tmpTopN, cms, uint32(numTopN)) - } - } - if fms != nil && resp.Collector != nil && resp.Collector.FmSketch != nil { - fms.MergeFMSketch(statistics.FMSketchFromProto(resp.Collector.FmSketch)) - } - return hist, cms, fms, topn, nil -} diff --git a/pkg/executor/batch_point_get.go b/pkg/executor/batch_point_get.go index f7d8a425590ac..88a2a442f158d 100644 --- a/pkg/executor/batch_point_get.go +++ b/pkg/executor/batch_point_get.go @@ -327,14 +327,14 @@ func (e *BatchPointGetExec) initialize(ctx context.Context) error { // 2. Session B create an UPDATE query to update the record that will be obtained in step 1 // 3. Then point get retrieve data from backend after step 2 finished // 4. Check the result - if _, _err_ := failpoint.EvalContext(ctx, _curpkg_("batchPointGetRepeatableReadTest-step1")); _err_ == nil { + failpoint.InjectContext(ctx, "batchPointGetRepeatableReadTest-step1", func() { if ch, ok := ctx.Value("batchPointGetRepeatableReadTest").(chan struct{}); ok { // Make `UPDATE` continue close(ch) } // Wait `UPDATE` finished - failpoint.EvalContext(ctx, _curpkg_("batchPointGetRepeatableReadTest-step2")) - } + failpoint.InjectContext(ctx, "batchPointGetRepeatableReadTest-step2", nil) + }) } else if e.keepOrder { less := func(i, j kv.Handle) int { if e.desc { diff --git a/pkg/executor/batch_point_get.go__failpoint_stash__ b/pkg/executor/batch_point_get.go__failpoint_stash__ deleted file mode 100644 index 88a2a442f158d..0000000000000 --- a/pkg/executor/batch_point_get.go__failpoint_stash__ +++ /dev/null @@ -1,528 +0,0 @@ -// Copyright 2018 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package executor - -import ( - "context" - "fmt" - "slices" - "sync/atomic" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - driver "github.com/pingcap/tidb/pkg/store/driver/txn" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/hack" - "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tidb/pkg/util/logutil/consistency" - "github.com/pingcap/tidb/pkg/util/rowcodec" - "github.com/tikv/client-go/v2/tikvrpc" -) - -// BatchPointGetExec executes a bunch of point select queries. -type BatchPointGetExec struct { - exec.BaseExecutor - indexUsageReporter *exec.IndexUsageReporter - - tblInfo *model.TableInfo - idxInfo *model.IndexInfo - handles []kv.Handle - // table/partition IDs for handle or index read - // (can be secondary unique key, - // and need lookup through handle) - planPhysIDs []int64 - // If != 0 then it is a single partition under Static Prune mode. - singlePartID int64 - partitionNames []model.CIStr - idxVals [][]types.Datum - txn kv.Transaction - lock bool - waitTime int64 - inited uint32 - values [][]byte - index int - rowDecoder *rowcodec.ChunkDecoder - keepOrder bool - desc bool - batchGetter kv.BatchGetter - - columns []*model.ColumnInfo - // virtualColumnIndex records all the indices of virtual columns and sort them in definition - // to make sure we can compute the virtual column in right order. - virtualColumnIndex []int - - // virtualColumnRetFieldTypes records the RetFieldTypes of virtual columns. - virtualColumnRetFieldTypes []*types.FieldType - - snapshot kv.Snapshot - stats *runtimeStatsWithSnapshot -} - -// buildVirtualColumnInfo saves virtual column indices and sort them in definition order -func (e *BatchPointGetExec) buildVirtualColumnInfo() { - e.virtualColumnIndex = buildVirtualColumnIndex(e.Schema(), e.columns) - if len(e.virtualColumnIndex) > 0 { - e.virtualColumnRetFieldTypes = make([]*types.FieldType, len(e.virtualColumnIndex)) - for i, idx := range e.virtualColumnIndex { - e.virtualColumnRetFieldTypes[i] = e.Schema().Columns[idx].RetType - } - } -} - -// Open implements the Executor interface. -func (e *BatchPointGetExec) Open(context.Context) error { - sessVars := e.Ctx().GetSessionVars() - txnCtx := sessVars.TxnCtx - txn, err := e.Ctx().Txn(false) - if err != nil { - return err - } - e.txn = txn - - setOptionForTopSQL(e.Ctx().GetSessionVars().StmtCtx, e.snapshot) - var batchGetter kv.BatchGetter = e.snapshot - if txn.Valid() { - lock := e.tblInfo.Lock - if e.lock { - batchGetter = driver.NewBufferBatchGetter(txn.GetMemBuffer(), &PessimisticLockCacheGetter{txnCtx: txnCtx}, e.snapshot) - } else if lock != nil && (lock.Tp == model.TableLockRead || lock.Tp == model.TableLockReadOnly) && e.Ctx().GetSessionVars().EnablePointGetCache { - batchGetter = newCacheBatchGetter(e.Ctx(), e.tblInfo.ID, e.snapshot) - } else { - batchGetter = driver.NewBufferBatchGetter(txn.GetMemBuffer(), nil, e.snapshot) - } - } - e.batchGetter = batchGetter - return nil -} - -// CacheTable always use memBuffer in session as snapshot. 
-// cacheTableSnapshot inherits kv.Snapshot and override the BatchGet methods and Get methods. -type cacheTableSnapshot struct { - kv.Snapshot - memBuffer kv.MemBuffer -} - -func (s cacheTableSnapshot) BatchGet(ctx context.Context, keys []kv.Key) (map[string][]byte, error) { - values := make(map[string][]byte) - if s.memBuffer == nil { - return values, nil - } - - for _, key := range keys { - val, err := s.memBuffer.Get(ctx, key) - if kv.ErrNotExist.Equal(err) { - continue - } - - if err != nil { - return nil, err - } - - if len(val) == 0 { - continue - } - - values[string(key)] = val - } - - return values, nil -} - -func (s cacheTableSnapshot) Get(ctx context.Context, key kv.Key) ([]byte, error) { - return s.memBuffer.Get(ctx, key) -} - -// MockNewCacheTableSnapShot only serves for test. -func MockNewCacheTableSnapShot(snapshot kv.Snapshot, memBuffer kv.MemBuffer) *cacheTableSnapshot { - return &cacheTableSnapshot{snapshot, memBuffer} -} - -// Close implements the Executor interface. -func (e *BatchPointGetExec) Close() error { - if e.RuntimeStats() != nil { - defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), e.stats) - } - if e.RuntimeStats() != nil && e.snapshot != nil { - e.snapshot.SetOption(kv.CollectRuntimeStats, nil) - } - if e.indexUsageReporter != nil && e.idxInfo != nil { - kvReqTotal := e.stats.GetCmdRPCCount(tikvrpc.CmdBatchGet) - // We cannot distinguish how many rows are coming from each partition. Here, we calculate all index usages - // percentage according to the row counts for the whole table. - e.indexUsageReporter.ReportPointGetIndexUsage(e.tblInfo.ID, e.tblInfo.ID, e.idxInfo.ID, e.ID(), kvReqTotal) - } - e.inited = 0 - e.index = 0 - return nil -} - -// Next implements the Executor interface. -func (e *BatchPointGetExec) Next(ctx context.Context, req *chunk.Chunk) error { - req.Reset() - if atomic.CompareAndSwapUint32(&e.inited, 0, 1) { - if err := e.initialize(ctx); err != nil { - return err - } - if e.lock { - e.UpdateDeltaForTableID(e.tblInfo.ID) - } - } - - if e.index >= len(e.values) { - return nil - } - - schema := e.Schema() - sctx := e.BaseExecutor.Ctx() - start := e.index - for !req.IsFull() && e.index < len(e.values) { - handle, val := e.handles[e.index], e.values[e.index] - err := DecodeRowValToChunk(sctx, schema, e.tblInfo, handle, val, req, e.rowDecoder) - if err != nil { - return err - } - e.index++ - } - - err := fillRowChecksum(sctx, start, e.index, schema, e.tblInfo, e.values, e.handles, req, nil) - if err != nil { - return err - } - err = table.FillVirtualColumnValue(e.virtualColumnRetFieldTypes, e.virtualColumnIndex, schema.Columns, e.columns, sctx.GetExprCtx(), req) - if err != nil { - return err - } - return nil -} - -func (e *BatchPointGetExec) initialize(ctx context.Context) error { - var handleVals map[string][]byte - var indexKeys []kv.Key - var err error - batchGetter := e.batchGetter - rc := e.Ctx().GetSessionVars().IsPessimisticReadConsistency() - if e.idxInfo != nil && !isCommonHandleRead(e.tblInfo, e.idxInfo) { - // `SELECT a, b FROM t WHERE (a, b) IN ((1, 2), (1, 2), (2, 1), (1, 2))` should not return duplicated rows - dedup := make(map[hack.MutableString]struct{}) - toFetchIndexKeys := make([]kv.Key, 0, len(e.idxVals)) - for i, idxVals := range e.idxVals { - physID := e.tblInfo.ID - if e.singlePartID != 0 { - physID = e.singlePartID - } else if len(e.planPhysIDs) > i { - physID = e.planPhysIDs[i] - } - idxKey, err1 := plannercore.EncodeUniqueIndexKey(e.Ctx(), e.tblInfo, e.idxInfo, idxVals, physID) - if err1 != 
nil && !kv.ErrNotExist.Equal(err1) { - return err1 - } - if idxKey == nil { - continue - } - s := hack.String(idxKey) - if _, found := dedup[s]; found { - continue - } - dedup[s] = struct{}{} - toFetchIndexKeys = append(toFetchIndexKeys, idxKey) - } - if e.keepOrder { - // TODO: if multiple partitions, then the IDs needs to be - // in the same order as the index keys - // and should skip table id part when comparing - intest.Assert(e.singlePartID != 0 || len(e.planPhysIDs) <= 1 || e.idxInfo.Global) - slices.SortFunc(toFetchIndexKeys, func(i, j kv.Key) int { - if e.desc { - return j.Cmp(i) - } - return i.Cmp(j) - }) - } - - // lock all keys in repeatable read isolation. - // for read consistency, only lock exist keys, - // indexKeys will be generated after getting handles. - if !rc { - indexKeys = toFetchIndexKeys - } else { - indexKeys = make([]kv.Key, 0, len(toFetchIndexKeys)) - } - - // SELECT * FROM t WHERE x IN (null), in this case there is no key. - if len(toFetchIndexKeys) == 0 { - return nil - } - - // Fetch all handles. - handleVals, err = batchGetter.BatchGet(ctx, toFetchIndexKeys) - if err != nil { - return err - } - - e.handles = make([]kv.Handle, 0, len(toFetchIndexKeys)) - if e.tblInfo.Partition != nil { - e.planPhysIDs = e.planPhysIDs[:0] - } - for _, key := range toFetchIndexKeys { - handleVal := handleVals[string(key)] - if len(handleVal) == 0 { - continue - } - handle, err1 := tablecodec.DecodeHandleInIndexValue(handleVal) - if err1 != nil { - return err1 - } - if e.tblInfo.Partition != nil { - var pid int64 - if e.idxInfo.Global { - _, pid, err = codec.DecodeInt(tablecodec.SplitIndexValue(handleVal).PartitionID) - if err != nil { - return err - } - if e.singlePartID != 0 && e.singlePartID != pid { - continue - } - if !matchPartitionNames(pid, e.partitionNames, e.tblInfo.GetPartitionInfo()) { - continue - } - e.planPhysIDs = append(e.planPhysIDs, pid) - } else { - pid = tablecodec.DecodeTableID(key) - e.planPhysIDs = append(e.planPhysIDs, pid) - } - if e.lock { - e.UpdateDeltaForTableID(pid) - } - } - e.handles = append(e.handles, handle) - if rc { - indexKeys = append(indexKeys, key) - } - } - - // The injection is used to simulate following scenario: - // 1. Session A create a point get query but pause before second time `GET` kv from backend - // 2. Session B create an UPDATE query to update the record that will be obtained in step 1 - // 3. Then point get retrieve data from backend after step 2 finished - // 4. 
Check the result - failpoint.InjectContext(ctx, "batchPointGetRepeatableReadTest-step1", func() { - if ch, ok := ctx.Value("batchPointGetRepeatableReadTest").(chan struct{}); ok { - // Make `UPDATE` continue - close(ch) - } - // Wait `UPDATE` finished - failpoint.InjectContext(ctx, "batchPointGetRepeatableReadTest-step2", nil) - }) - } else if e.keepOrder { - less := func(i, j kv.Handle) int { - if e.desc { - return j.Compare(i) - } - return i.Compare(j) - } - if e.tblInfo.PKIsHandle && mysql.HasUnsignedFlag(e.tblInfo.GetPkColInfo().GetFlag()) { - uintComparator := func(i, h kv.Handle) int { - if !i.IsInt() || !h.IsInt() { - panic(fmt.Sprintf("both handles need be IntHandle, but got %T and %T ", i, h)) - } - ihVal := uint64(i.IntValue()) - hVal := uint64(h.IntValue()) - if ihVal > hVal { - return 1 - } - if ihVal < hVal { - return -1 - } - return 0 - } - less = func(i, j kv.Handle) int { - if e.desc { - return uintComparator(j, i) - } - return uintComparator(i, j) - } - } - slices.SortFunc(e.handles, less) - // TODO: if partitioned table, sorting the handles would also - // need to have the physIDs rearranged in the same order! - intest.Assert(e.singlePartID != 0 || len(e.planPhysIDs) <= 1) - } - - keys := make([]kv.Key, 0, len(e.handles)) - newHandles := make([]kv.Handle, 0, len(e.handles)) - for i, handle := range e.handles { - tID := e.tblInfo.ID - if e.singlePartID != 0 { - tID = e.singlePartID - } else if len(e.planPhysIDs) > 0 { - // Direct handle read - tID = e.planPhysIDs[i] - } - if tID <= 0 { - // not matching any partition - continue - } - key := tablecodec.EncodeRowKeyWithHandle(tID, handle) - keys = append(keys, key) - newHandles = append(newHandles, handle) - } - e.handles = newHandles - - var values map[string][]byte - // Lock keys (include exists and non-exists keys) before fetch all values for Repeatable Read Isolation. - if e.lock && !rc { - lockKeys := make([]kv.Key, len(keys)+len(indexKeys)) - copy(lockKeys, keys) - copy(lockKeys[len(keys):], indexKeys) - err = LockKeys(ctx, e.Ctx(), e.waitTime, lockKeys...) - if err != nil { - return err - } - } - // Fetch all values. 
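The uintComparator above is needed because an unsigned primary key travels inside a signed int64 handle, so large unsigned values look negative to a plain signed comparison. A small illustration with hypothetical values (plain Go, not the actual kv.Handle type):

package main

import "fmt"

func main() {
	// An unsigned PK value of 18446744073709551615 round-trips through an
	// int64 handle as -1.
	big := uint64(18446744073709551615)
	bigHandle := int64(big) // -1

	small := uint64(1)
	smallHandle := int64(small) // 1

	fmt.Println(bigHandle < smallHandle)                 // true: signed order is wrong
	fmt.Println(uint64(bigHandle) < uint64(smallHandle)) // false: unsigned order is right
}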
- values, err = batchGetter.BatchGet(ctx, keys) - if err != nil { - return err - } - handles := make([]kv.Handle, 0, len(values)) - var existKeys []kv.Key - if e.lock && rc { - existKeys = make([]kv.Key, 0, 2*len(values)) - } - e.values = make([][]byte, 0, len(values)) - for i, key := range keys { - val := values[string(key)] - if len(val) == 0 { - if e.idxInfo != nil && (!e.tblInfo.IsCommonHandle || !e.idxInfo.Primary) && - !e.Ctx().GetSessionVars().StmtCtx.WeakConsistency { - return (&consistency.Reporter{ - HandleEncode: func(_ kv.Handle) kv.Key { - return key - }, - IndexEncode: func(_ *consistency.RecordData) kv.Key { - return indexKeys[i] - }, - Tbl: e.tblInfo, - Idx: e.idxInfo, - EnableRedactLog: e.Ctx().GetSessionVars().EnableRedactLog, - Storage: e.Ctx().GetStore(), - }).ReportLookupInconsistent(ctx, - 1, 0, - e.handles[i:i+1], - e.handles, - []consistency.RecordData{{}}, - ) - } - continue - } - e.values = append(e.values, val) - handles = append(handles, e.handles[i]) - if e.lock && rc { - existKeys = append(existKeys, key) - // when e.handles is set in builder directly, index should be primary key and the plan is CommonHandleRead - // with clustered index enabled, indexKeys is empty in this situation - // lock primary key for clustered index table is redundant - if len(indexKeys) != 0 { - existKeys = append(existKeys, indexKeys[i]) - } - } - } - // Lock exists keys only for Read Committed Isolation. - if e.lock && rc { - err = LockKeys(ctx, e.Ctx(), e.waitTime, existKeys...) - if err != nil { - return err - } - } - e.handles = handles - return nil -} - -// LockKeys locks the keys for pessimistic transaction. -func LockKeys(ctx context.Context, sctx sessionctx.Context, lockWaitTime int64, keys ...kv.Key) error { - txnCtx := sctx.GetSessionVars().TxnCtx - lctx, err := newLockCtx(sctx, lockWaitTime, len(keys)) - if err != nil { - return err - } - if txnCtx.IsPessimistic { - lctx.InitReturnValues(len(keys)) - } - err = doLockKeys(ctx, sctx, lctx, keys...) - if err != nil { - return err - } - if txnCtx.IsPessimistic { - // When doLockKeys returns without error, no other goroutines access the map, - // it's safe to read it without mutex. - for _, key := range keys { - if v, ok := lctx.GetValueNotLocked(key); ok { - txnCtx.SetPessimisticLockCache(key, v) - } - } - } - return nil -} - -// PessimisticLockCacheGetter implements the kv.Getter interface. -// It is used as a middle cache to construct the BufferedBatchGetter. -type PessimisticLockCacheGetter struct { - txnCtx *variable.TransactionContext -} - -// Get implements the kv.Getter interface. 
-func (getter *PessimisticLockCacheGetter) Get(_ context.Context, key kv.Key) ([]byte, error) { - val, ok := getter.txnCtx.GetKeyInPessimisticLockCache(key) - if ok { - return val, nil - } - return nil, kv.ErrNotExist -} - -type cacheBatchGetter struct { - ctx sessionctx.Context - tid int64 - snapshot kv.Snapshot -} - -func (b *cacheBatchGetter) BatchGet(ctx context.Context, keys []kv.Key) (map[string][]byte, error) { - cacheDB := b.ctx.GetStore().GetMemCache() - vals := make(map[string][]byte) - for _, key := range keys { - val, err := cacheDB.UnionGet(ctx, b.tid, b.snapshot, key) - if err != nil { - if !kv.ErrNotExist.Equal(err) { - return nil, err - } - continue - } - vals[string(key)] = val - } - return vals, nil -} - -func newCacheBatchGetter(ctx sessionctx.Context, tid int64, snapshot kv.Snapshot) *cacheBatchGetter { - return &cacheBatchGetter{ctx, tid, snapshot} -} diff --git a/pkg/executor/binding__failpoint_binding__.go b/pkg/executor/binding__failpoint_binding__.go deleted file mode 100644 index 4ed4af3ddbf4a..0000000000000 --- a/pkg/executor/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package executor - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/executor/brie.go b/pkg/executor/brie.go index 478ba583f4dd2..130f627bc6609 100644 --- a/pkg/executor/brie.go +++ b/pkg/executor/brie.go @@ -313,9 +313,9 @@ func (b *executorBuilder) buildBRIE(s *ast.BRIEStmt, schema *expression.Schema) } store := tidbCfg.Store - if v, _err_ := failpoint.Eval(_curpkg_("modifyStore")); _err_ == nil { + failpoint.Inject("modifyStore", func(v failpoint.Value) { store = v.(string) - } + }) if store != "tikv" { b.err = errors.Errorf("%s requires tikv store, not %s", s.Kind, store) return nil @@ -581,13 +581,13 @@ func (e *BRIEExec) Next(ctx context.Context, req *chunk.Chunk) error { e.info.queueTime = types.CurrentTime(mysql.TypeDatetime) taskCtx, taskID := bq.registerTask(ctx, e.info) defer bq.cancelTask(taskID) - if _, _err_ := failpoint.Eval(_curpkg_("block-on-brie")); _err_ == nil { + failpoint.Inject("block-on-brie", func() { log.Warn("You shall not pass, nya. :3") <-taskCtx.Done() if taskCtx.Err() != nil { - return taskCtx.Err() + failpoint.Return(taskCtx.Err()) } - } + }) // manually monitor the Killed status... go func() { ticker := time.NewTicker(3 * time.Second) diff --git a/pkg/executor/brie.go__failpoint_stash__ b/pkg/executor/brie.go__failpoint_stash__ deleted file mode 100644 index 130f627bc6609..0000000000000 --- a/pkg/executor/brie.go__failpoint_stash__ +++ /dev/null @@ -1,826 +0,0 @@ -// Copyright 2020 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package executor - -import ( - "bytes" - "context" - "fmt" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - backuppb "github.com/pingcap/kvproto/pkg/brpb" - "github.com/pingcap/kvproto/pkg/encryptionpb" - "github.com/pingcap/log" - "github.com/pingcap/tidb/br/pkg/glue" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/br/pkg/task" - "github.com/pingcap/tidb/br/pkg/task/show" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/ddl" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/format" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" - "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" - "github.com/pingcap/tidb/pkg/util/printer" - "github.com/pingcap/tidb/pkg/util/sem" - "github.com/pingcap/tidb/pkg/util/syncutil" - filter "github.com/pingcap/tidb/pkg/util/table-filter" - "github.com/tikv/client-go/v2/oracle" - pd "github.com/tikv/pd/client" - "go.uber.org/zap" -) - -const clearInterval = 10 * time.Minute - -var outdatedDuration = types.Duration{ - Duration: 30 * time.Minute, - Fsp: types.DefaultFsp, -} - -// brieTaskProgress tracks a task's current progress. -type brieTaskProgress struct { - // current progress of the task. - // this field is atomically updated outside of the lock below. - current int64 - - // lock is the mutex protected the two fields below. - lock syncutil.Mutex - // cmd is the name of the step the BRIE task is currently performing. - cmd string - // total is the total progress of the task. - // the percentage of completeness is `(100%) * current / total`. - total int64 -} - -// Inc implements glue.Progress -func (p *brieTaskProgress) Inc() { - atomic.AddInt64(&p.current, 1) -} - -// IncBy implements glue.Progress -func (p *brieTaskProgress) IncBy(cnt int64) { - atomic.AddInt64(&p.current, cnt) -} - -// GetCurrent implements glue.Progress -func (p *brieTaskProgress) GetCurrent() int64 { - return atomic.LoadInt64(&p.current) -} - -// Close implements glue.Progress -func (p *brieTaskProgress) Close() { - p.lock.Lock() - current := atomic.LoadInt64(&p.current) - if current < p.total { - p.cmd = fmt.Sprintf("%s Canceled", p.cmd) - } - atomic.StoreInt64(&p.current, p.total) - p.lock.Unlock() -} - -type brieTaskInfo struct { - id uint64 - query string - queueTime types.Time - execTime types.Time - finishTime types.Time - kind ast.BRIEKind - storage string - connID uint64 - backupTS uint64 - restoreTS uint64 - archiveSize uint64 - message string -} - -type brieQueueItem struct { - info *brieTaskInfo - progress *brieTaskProgress - cancel func() -} - -type brieQueue struct { - nextID uint64 - tasks sync.Map - - lastClearTime time.Time - - workerCh chan struct{} -} - -// globalBRIEQueue is the BRIE execution queue. Only one BRIE task can be executed each time. -// TODO: perhaps copy the DDL Job queue so only one task can be executed in the whole cluster. 
-var globalBRIEQueue = &brieQueue{ - workerCh: make(chan struct{}, 1), -} - -// ResetGlobalBRIEQueueForTest resets the ID allocation for the global BRIE queue. -// In some of our test cases, we rely on the ID is allocated from 1. -// When batch executing test cases, the assumation may be broken and make the cases fail. -func ResetGlobalBRIEQueueForTest() { - globalBRIEQueue = &brieQueue{ - workerCh: make(chan struct{}, 1), - } -} - -// registerTask registers a BRIE task in the queue. -func (bq *brieQueue) registerTask( - ctx context.Context, - info *brieTaskInfo, -) (context.Context, uint64) { - taskCtx, taskCancel := context.WithCancel(ctx) - item := &brieQueueItem{ - info: info, - cancel: taskCancel, - progress: &brieTaskProgress{ - cmd: "Wait", - total: 1, - }, - } - - taskID := atomic.AddUint64(&bq.nextID, 1) - bq.tasks.Store(taskID, item) - info.id = taskID - - return taskCtx, taskID -} - -// query task queries a task from the queue. -func (bq *brieQueue) queryTask(taskID uint64) (*brieTaskInfo, bool) { - if item, ok := bq.tasks.Load(taskID); ok { - return item.(*brieQueueItem).info, true - } - return nil, false -} - -// acquireTask prepares to execute a BRIE task. Only one BRIE task can be -// executed at a time, and this function blocks until the task is ready. -// -// Returns an object to track the task's progress. -func (bq *brieQueue) acquireTask(taskCtx context.Context, taskID uint64) (*brieTaskProgress, error) { - // wait until we are at the front of the queue. - select { - case bq.workerCh <- struct{}{}: - if item, ok := bq.tasks.Load(taskID); ok { - return item.(*brieQueueItem).progress, nil - } - // cannot find task, perhaps it has been canceled. allow the next task to run. - bq.releaseTask() - return nil, errors.Errorf("backup/restore task %d is canceled", taskID) - case <-taskCtx.Done(): - return nil, taskCtx.Err() - } -} - -func (bq *brieQueue) releaseTask() { - <-bq.workerCh -} - -func (bq *brieQueue) cancelTask(taskID uint64) bool { - item, ok := bq.tasks.Load(taskID) - if !ok { - return false - } - i := item.(*brieQueueItem) - i.cancel() - i.progress.Close() - log.Info("BRIE job canceled.", zap.Uint64("ID", i.info.id)) - return true -} - -func (bq *brieQueue) clearTask(sc *stmtctx.StatementContext) { - if time.Since(bq.lastClearTime) < clearInterval { - return - } - - bq.lastClearTime = time.Now() - currTime := types.CurrentTime(mysql.TypeDatetime) - - bq.tasks.Range(func(key, value any) bool { - item := value.(*brieQueueItem) - if d := currTime.Sub(sc.TypeCtx(), &item.info.finishTime); d.Compare(outdatedDuration) > 0 { - bq.tasks.Delete(key) - } - return true - }) -} - -func (b *executorBuilder) parseTSString(ts string) (uint64, error) { - sc := stmtctx.NewStmtCtxWithTimeZone(b.ctx.GetSessionVars().Location()) - t, err := types.ParseTime(sc.TypeCtx(), ts, mysql.TypeTimestamp, types.MaxFsp) - if err != nil { - return 0, err - } - t1, err := t.GoTime(sc.TimeZone()) - if err != nil { - return 0, err - } - return oracle.GoTimeToTS(t1), nil -} - -func (b *executorBuilder) buildBRIE(s *ast.BRIEStmt, schema *expression.Schema) exec.Executor { - if s.Kind == ast.BRIEKindShowBackupMeta { - return execOnce(&showMetaExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, schema, 0), - showConfig: buildShowMetadataConfigFrom(s), - }) - } - - if s.Kind == ast.BRIEKindShowQuery { - return execOnce(&showQueryExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, schema, 0), - targetID: uint64(s.JobID), - }) - } - - if s.Kind == ast.BRIEKindCancelJob { - return &cancelJobExec{ - BaseExecutor: 
exec.NewBaseExecutor(b.ctx, schema, 0), - targetID: uint64(s.JobID), - } - } - - e := &BRIEExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, schema, 0), - info: &brieTaskInfo{ - kind: s.Kind, - }, - } - - tidbCfg := config.GetGlobalConfig() - tlsCfg := task.TLSConfig{ - CA: tidbCfg.Security.ClusterSSLCA, - Cert: tidbCfg.Security.ClusterSSLCert, - Key: tidbCfg.Security.ClusterSSLKey, - } - pds := strings.Split(tidbCfg.Path, ",") - cfg := task.DefaultConfig() - cfg.PD = pds - cfg.TLS = tlsCfg - - storageURL, err := storage.ParseRawURL(s.Storage) - if err != nil { - b.err = errors.Annotate(err, "invalid destination URL") - return nil - } - - switch storageURL.Scheme { - case "s3": - storage.ExtractQueryParameters(storageURL, &cfg.S3) - case "gs", "gcs": - storage.ExtractQueryParameters(storageURL, &cfg.GCS) - case "hdfs": - if sem.IsEnabled() { - // Storage is not permitted to be hdfs when SEM is enabled. - b.err = plannererrors.ErrNotSupportedWithSem.GenWithStackByArgs("hdfs storage") - return nil - } - case "local", "file", "": - if sem.IsEnabled() { - // Storage is not permitted to be local when SEM is enabled. - b.err = plannererrors.ErrNotSupportedWithSem.GenWithStackByArgs("local storage") - return nil - } - default: - } - - store := tidbCfg.Store - failpoint.Inject("modifyStore", func(v failpoint.Value) { - store = v.(string) - }) - if store != "tikv" { - b.err = errors.Errorf("%s requires tikv store, not %s", s.Kind, store) - return nil - } - - cfg.Storage = storageURL.String() - e.info.storage = cfg.Storage - - for _, opt := range s.Options { - switch opt.Tp { - case ast.BRIEOptionRateLimit: - cfg.RateLimit = opt.UintValue - case ast.BRIEOptionConcurrency: - cfg.Concurrency = uint32(opt.UintValue) - case ast.BRIEOptionChecksum: - cfg.Checksum = opt.UintValue != 0 - case ast.BRIEOptionSendCreds: - cfg.SendCreds = opt.UintValue != 0 - case ast.BRIEOptionChecksumConcurrency: - cfg.ChecksumConcurrency = uint(opt.UintValue) - case ast.BRIEOptionEncryptionKeyFile: - cfg.CipherInfo.CipherKey, err = task.GetCipherKeyContent("", opt.StrValue) - if err != nil { - b.err = err - return nil - } - case ast.BRIEOptionEncryptionMethod: - switch opt.StrValue { - case "aes128-ctr": - cfg.CipherInfo.CipherType = encryptionpb.EncryptionMethod_AES128_CTR - case "aes192-ctr": - cfg.CipherInfo.CipherType = encryptionpb.EncryptionMethod_AES192_CTR - case "aes256-ctr": - cfg.CipherInfo.CipherType = encryptionpb.EncryptionMethod_AES256_CTR - case "plaintext": - cfg.CipherInfo.CipherType = encryptionpb.EncryptionMethod_PLAINTEXT - default: - b.err = errors.Errorf("unsupported encryption method: %s", opt.StrValue) - return nil - } - } - } - - switch { - case len(s.Tables) != 0: - tables := make([]filter.Table, 0, len(s.Tables)) - for _, tbl := range s.Tables { - tables = append(tables, filter.Table{Name: tbl.Name.O, Schema: tbl.Schema.O}) - } - cfg.TableFilter = filter.NewTablesFilter(tables...) - case len(s.Schemas) != 0: - cfg.TableFilter = filter.NewSchemasFilter(s.Schemas...) - default: - cfg.TableFilter = filter.All() - } - - // table options are stored in original case, but comparison - // is expected to be performed insensitive. - cfg.TableFilter = filter.CaseInsensitive(cfg.TableFilter) - - // We cannot directly use the query string, or the secret may be print. - // NOTE: the ownership of `s.Storage` is taken here. 
- s.Storage = e.info.storage - e.info.query = restoreQuery(s) - - switch s.Kind { - case ast.BRIEKindBackup: - bcfg := task.DefaultBackupConfig() - bcfg.Config = cfg - e.backupCfg = &bcfg - - for _, opt := range s.Options { - switch opt.Tp { - case ast.BRIEOptionLastBackupTS: - tso, err := b.parseTSString(opt.StrValue) - if err != nil { - b.err = err - return nil - } - e.backupCfg.LastBackupTS = tso - case ast.BRIEOptionLastBackupTSO: - e.backupCfg.LastBackupTS = opt.UintValue - case ast.BRIEOptionBackupTimeAgo: - e.backupCfg.TimeAgo = time.Duration(opt.UintValue) - case ast.BRIEOptionBackupTSO: - e.backupCfg.BackupTS = opt.UintValue - case ast.BRIEOptionBackupTS: - tso, err := b.parseTSString(opt.StrValue) - if err != nil { - b.err = err - return nil - } - e.backupCfg.BackupTS = tso - case ast.BRIEOptionCompression: - switch opt.StrValue { - case "zstd": - e.backupCfg.CompressionConfig.CompressionType = backuppb.CompressionType_ZSTD - case "snappy": - e.backupCfg.CompressionConfig.CompressionType = backuppb.CompressionType_SNAPPY - case "lz4": - e.backupCfg.CompressionConfig.CompressionType = backuppb.CompressionType_LZ4 - default: - b.err = errors.Errorf("unsupported compression type: %s", opt.StrValue) - return nil - } - case ast.BRIEOptionCompressionLevel: - e.backupCfg.CompressionConfig.CompressionLevel = int32(opt.UintValue) - case ast.BRIEOptionIgnoreStats: - e.backupCfg.IgnoreStats = opt.UintValue != 0 - } - } - - case ast.BRIEKindRestore: - rcfg := task.DefaultRestoreConfig() - rcfg.Config = cfg - e.restoreCfg = &rcfg - for _, opt := range s.Options { - switch opt.Tp { - case ast.BRIEOptionOnline: - e.restoreCfg.Online = opt.UintValue != 0 - case ast.BRIEOptionWaitTiflashReady: - e.restoreCfg.WaitTiflashReady = opt.UintValue != 0 - case ast.BRIEOptionWithSysTable: - e.restoreCfg.WithSysTable = opt.UintValue != 0 - case ast.BRIEOptionLoadStats: - e.restoreCfg.LoadStats = opt.UintValue != 0 - } - } - - default: - b.err = errors.Errorf("unsupported BRIE statement kind: %s", s.Kind) - return nil - } - - return e -} - -// oneshotExecutor wraps a executor, making its `Next` would only be called once. 
-type oneshotExecutor struct { - exec.Executor - finished bool -} - -func (o *oneshotExecutor) Next(ctx context.Context, req *chunk.Chunk) error { - if o.finished { - req.Reset() - return nil - } - - if err := o.Executor.Next(ctx, req); err != nil { - return err - } - o.finished = true - return nil -} - -func execOnce(ex exec.Executor) exec.Executor { - return &oneshotExecutor{Executor: ex} -} - -type showQueryExec struct { - exec.BaseExecutor - - targetID uint64 -} - -func (s *showQueryExec) Next(_ context.Context, req *chunk.Chunk) error { - req.Reset() - - tsk, ok := globalBRIEQueue.queryTask(s.targetID) - if !ok { - return nil - } - - req.AppendString(0, tsk.query) - return nil -} - -type cancelJobExec struct { - exec.BaseExecutor - - targetID uint64 -} - -func (s cancelJobExec) Next(_ context.Context, req *chunk.Chunk) error { - req.Reset() - if !globalBRIEQueue.cancelTask(s.targetID) { - s.Ctx().GetSessionVars().StmtCtx.AppendWarning(exeerrors.ErrLoadDataJobNotFound.FastGenByArgs(s.targetID)) - } - return nil -} - -type showMetaExec struct { - exec.BaseExecutor - - showConfig show.Config -} - -// BRIEExec represents an executor for BRIE statements (BACKUP, RESTORE, etc) -type BRIEExec struct { - exec.BaseExecutor - - backupCfg *task.BackupConfig - restoreCfg *task.RestoreConfig - showConfig *show.Config - info *brieTaskInfo -} - -func buildShowMetadataConfigFrom(s *ast.BRIEStmt) show.Config { - if s.Kind != ast.BRIEKindShowBackupMeta { - panic(fmt.Sprintf("precondition failed: `fillByShowMetadata` should always called by a ast.BRIEKindShowBackupMeta, but it is %s.", s.Kind)) - } - - store := s.Storage - cfg := show.Config{ - Storage: store, - Cipher: backuppb.CipherInfo{ - CipherType: encryptionpb.EncryptionMethod_PLAINTEXT, - }, - } - return cfg -} - -func (e *showMetaExec) Next(ctx context.Context, req *chunk.Chunk) error { - exe, err := show.CreateExec(ctx, e.showConfig) - if err != nil { - return errors.Annotate(err, "failed to create show exec") - } - res, err := exe.Read(ctx) - if err != nil { - return errors.Annotate(err, "failed to read metadata from backupmeta") - } - - startTime := oracle.GetTimeFromTS(uint64(res.StartVersion)) - endTime := oracle.GetTimeFromTS(uint64(res.EndVersion)) - - for _, table := range res.Tables { - req.AppendString(0, table.DBName) - req.AppendString(1, table.TableName) - req.AppendInt64(2, int64(table.KVCount)) - req.AppendInt64(3, int64(table.KVSize)) - if res.StartVersion > 0 { - req.AppendTime(4, types.NewTime(types.FromGoTime(startTime.In(e.Ctx().GetSessionVars().Location())), mysql.TypeDatetime, 0)) - } else { - req.AppendNull(4) - } - req.AppendTime(5, types.NewTime(types.FromGoTime(endTime.In(e.Ctx().GetSessionVars().Location())), mysql.TypeDatetime, 0)) - } - return nil -} - -// Next implements the Executor Next interface. -func (e *BRIEExec) Next(ctx context.Context, req *chunk.Chunk) error { - req.Reset() - if e.info == nil { - return nil - } - - bq := globalBRIEQueue - bq.clearTask(e.Ctx().GetSessionVars().StmtCtx) - - e.info.connID = e.Ctx().GetSessionVars().ConnectionID - e.info.queueTime = types.CurrentTime(mysql.TypeDatetime) - taskCtx, taskID := bq.registerTask(ctx, e.info) - defer bq.cancelTask(taskID) - failpoint.Inject("block-on-brie", func() { - log.Warn("You shall not pass, nya. :3") - <-taskCtx.Done() - if taskCtx.Err() != nil { - failpoint.Return(taskCtx.Err()) - } - }) - // manually monitor the Killed status... 
- go func() { - ticker := time.NewTicker(3 * time.Second) - defer ticker.Stop() - for { - select { - case <-ticker.C: - if e.Ctx().GetSessionVars().SQLKiller.HandleSignal() == exeerrors.ErrQueryInterrupted { - bq.cancelTask(taskID) - return - } - case <-taskCtx.Done(): - return - } - } - }() - - progress, err := bq.acquireTask(taskCtx, taskID) - if err != nil { - return err - } - defer bq.releaseTask() - - e.info.execTime = types.CurrentTime(mysql.TypeDatetime) - glue := &tidbGlue{se: e.Ctx(), progress: progress, info: e.info} - - switch e.info.kind { - case ast.BRIEKindBackup: - err = handleBRIEError(task.RunBackup(taskCtx, glue, "Backup", e.backupCfg), exeerrors.ErrBRIEBackupFailed) - case ast.BRIEKindRestore: - err = handleBRIEError(task.RunRestore(taskCtx, glue, "Restore", e.restoreCfg), exeerrors.ErrBRIERestoreFailed) - default: - err = errors.Errorf("unsupported BRIE statement kind: %s", e.info.kind) - } - e.info.finishTime = types.CurrentTime(mysql.TypeDatetime) - if err != nil { - e.info.message = err.Error() - return err - } - e.info.message = "" - - req.AppendString(0, e.info.storage) - req.AppendUint64(1, e.info.archiveSize) - switch e.info.kind { - case ast.BRIEKindBackup: - req.AppendUint64(2, e.info.backupTS) - req.AppendTime(3, e.info.queueTime) - req.AppendTime(4, e.info.execTime) - case ast.BRIEKindRestore: - req.AppendUint64(2, e.info.backupTS) - req.AppendUint64(3, e.info.restoreTS) - req.AppendTime(4, e.info.queueTime) - req.AppendTime(5, e.info.execTime) - } - e.info = nil - return nil -} - -func handleBRIEError(err error, terror *terror.Error) error { - if err == nil { - return nil - } - return terror.GenWithStackByArgs(err) -} - -func (e *ShowExec) fetchShowBRIE(kind ast.BRIEKind) error { - globalBRIEQueue.tasks.Range(func(_, value any) bool { - item := value.(*brieQueueItem) - if item.info.kind == kind { - item.progress.lock.Lock() - defer item.progress.lock.Unlock() - current := atomic.LoadInt64(&item.progress.current) - e.result.AppendUint64(0, item.info.id) - e.result.AppendString(1, item.info.storage) - e.result.AppendString(2, item.progress.cmd) - e.result.AppendFloat64(3, 100.0*float64(current)/float64(item.progress.total)) - e.result.AppendTime(4, item.info.queueTime) - e.result.AppendTime(5, item.info.execTime) - e.result.AppendTime(6, item.info.finishTime) - e.result.AppendUint64(7, item.info.connID) - if len(item.info.message) > 0 { - e.result.AppendString(8, item.info.message) - } else { - e.result.AppendNull(8) - } - } - return true - }) - globalBRIEQueue.clearTask(e.Ctx().GetSessionVars().StmtCtx) - return nil -} - -type tidbGlue struct { - // the session context of the brie task - se sessionctx.Context - progress *brieTaskProgress - info *brieTaskInfo -} - -// GetDomain implements glue.Glue -func (gs *tidbGlue) GetDomain(_ kv.Storage) (*domain.Domain, error) { - return domain.GetDomain(gs.se), nil -} - -// CreateSession implements glue.Glue -func (gs *tidbGlue) CreateSession(_ kv.Storage) (glue.Session, error) { - newSCtx, err := CreateSession(gs.se) - if err != nil { - return nil, err - } - return &tidbGlueSession{se: newSCtx}, nil -} - -// Open implements glue.Glue -func (gs *tidbGlue) Open(string, pd.SecurityOption) (kv.Storage, error) { - return gs.se.GetStore(), nil -} - -// OwnsStorage implements glue.Glue -func (*tidbGlue) OwnsStorage() bool { - return false -} - -// StartProgress implements glue.Glue -func (gs *tidbGlue) StartProgress(_ context.Context, cmdName string, total int64, _ bool) glue.Progress { - gs.progress.lock.Lock() - 
gs.progress.cmd = cmdName - gs.progress.total = total - atomic.StoreInt64(&gs.progress.current, 0) - gs.progress.lock.Unlock() - return gs.progress -} - -// Record implements glue.Glue -func (gs *tidbGlue) Record(name string, value uint64) { - switch name { - case "BackupTS": - gs.info.backupTS = value - case "RestoreTS": - gs.info.restoreTS = value - case "Size": - gs.info.archiveSize = value - } -} - -func (*tidbGlue) GetVersion() string { - return "TiDB\n" + printer.GetTiDBInfo() -} - -// UseOneShotSession implements glue.Glue -func (gs *tidbGlue) UseOneShotSession(_ kv.Storage, _ bool, fn func(se glue.Session) error) error { - // In SQL backup, we don't need to close domain, - // but need to create an new session. - newSCtx, err := CreateSession(gs.se) - if err != nil { - return err - } - glueSession := &tidbGlueSession{se: newSCtx} - defer func() { - CloseSession(newSCtx) - log.Info("one shot session from brie closed") - }() - return fn(glueSession) -} - -type tidbGlueSession struct { - // the session context of the brie task's subtask, such as `CREATE TABLE`. - se sessionctx.Context -} - -// Execute implements glue.Session -// These queries execute without privilege checking, since the calling statements -// such as BACKUP and RESTORE have already been privilege checked. -// NOTE: Maybe drain the restult too? See `gluetidb.tidbSession.ExecuteInternal` for more details. -func (gs *tidbGlueSession) Execute(ctx context.Context, sql string) error { - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnBR) - _, _, err := gs.se.GetRestrictedSQLExecutor().ExecRestrictedSQL(ctx, nil, sql) - return err -} - -func (gs *tidbGlueSession) ExecuteInternal(ctx context.Context, sql string, args ...any) error { - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnBR) - exec := gs.se.GetSQLExecutor() - _, err := exec.ExecuteInternal(ctx, sql, args...) - return err -} - -// CreateDatabase implements glue.Session -func (gs *tidbGlueSession) CreateDatabase(_ context.Context, schema *model.DBInfo) error { - return BRIECreateDatabase(gs.se, schema, "") -} - -// CreateTable implements glue.Session -func (gs *tidbGlueSession) CreateTable(_ context.Context, dbName model.CIStr, table *model.TableInfo, cs ...ddl.CreateTableOption) error { - return BRIECreateTable(gs.se, dbName, table, "", cs...) -} - -// CreateTables implements glue.BatchCreateTableSession. -func (gs *tidbGlueSession) CreateTables(_ context.Context, - tables map[string][]*model.TableInfo, cs ...ddl.CreateTableOption) error { - return BRIECreateTables(gs.se, tables, "", cs...) -} - -// CreatePlacementPolicy implements glue.Session -func (gs *tidbGlueSession) CreatePlacementPolicy(_ context.Context, policy *model.PolicyInfo) error { - originQueryString := gs.se.Value(sessionctx.QueryString) - defer gs.se.SetValue(sessionctx.QueryString, originQueryString) - gs.se.SetValue(sessionctx.QueryString, ConstructResultOfShowCreatePlacementPolicy(policy)) - d := domain.GetDomain(gs.se).DDLExecutor() - // the default behaviour is ignoring duplicated policy during restore. - return d.CreatePlacementPolicyWithInfo(gs.se, policy, ddl.OnExistIgnore) -} - -// Close implements glue.Session -func (gs *tidbGlueSession) Close() { - CloseSession(gs.se) -} - -// GetGlobalVariables implements glue.Session. 
-func (gs *tidbGlueSession) GetGlobalVariable(name string) (string, error) { - return gs.se.GetSessionVars().GlobalVarsAccessor.GetTiDBTableValue(name) -} - -// GetSessionCtx implements glue.Glue -func (gs *tidbGlueSession) GetSessionCtx() sessionctx.Context { - return gs.se -} - -func restoreQuery(stmt *ast.BRIEStmt) string { - out := bytes.NewBuffer(nil) - rc := format.NewRestoreCtx(format.RestoreNameBackQuotes|format.RestoreStringSingleQuotes, out) - if err := stmt.Restore(rc); err != nil { - return "N/A" - } - return out.String() -} diff --git a/pkg/executor/builder.go b/pkg/executor/builder.go index 95fb0b69094a4..6bcee9f5f8be3 100644 --- a/pkg/executor/builder.go +++ b/pkg/executor/builder.go @@ -846,7 +846,7 @@ func (b *executorBuilder) buildExecute(v *plannercore.Execute) exec.Executor { outputNames: v.OutputNames(), } - if val, _err_ := failpoint.Eval(_curpkg_("assertExecutePrepareStatementStalenessOption")); _err_ == nil { + failpoint.Inject("assertExecutePrepareStatementStalenessOption", func(val failpoint.Value) { vs := strings.Split(val.(string), "_") assertTS, assertReadReplicaScope := vs[0], vs[1] staleread.AssertStmtStaleness(b.ctx, true) @@ -859,7 +859,7 @@ func (b *executorBuilder) buildExecute(v *plannercore.Execute) exec.Executor { assertReadReplicaScope != b.readReplicaScope { panic("execute prepare statement have wrong staleness option") } - } + }) return e } @@ -2729,9 +2729,9 @@ func (b *executorBuilder) buildAnalyzeIndexPushdown(task plannercore.AnalyzeInde b.err = err return nil } - if val, _err_ := failpoint.Eval(_curpkg_("injectAnalyzeSnapshot")); _err_ == nil { + failpoint.Inject("injectAnalyzeSnapshot", func(val failpoint.Value) { startTS = uint64(val.(int)) - } + }) concurrency := adaptiveAnlayzeDistSQLConcurrency(context.Background(), b.ctx) base := baseAnalyzeExec{ ctx: b.ctx, @@ -2802,21 +2802,21 @@ func (b *executorBuilder) buildAnalyzeSamplingPushdown( b.err = err return nil } - if val, _err_ := failpoint.Eval(_curpkg_("injectAnalyzeSnapshot")); _err_ == nil { + failpoint.Inject("injectAnalyzeSnapshot", func(val failpoint.Value) { startTS = uint64(val.(int)) - } + }) statsHandle := domain.GetDomain(b.ctx).StatsHandle() count, modifyCount, err := statsHandle.StatsMetaCountAndModifyCount(task.TableID.GetStatisticsID()) if err != nil { b.err = err return nil } - if val, _err_ := failpoint.Eval(_curpkg_("injectBaseCount")); _err_ == nil { + failpoint.Inject("injectBaseCount", func(val failpoint.Value) { count = int64(val.(int)) - } - if val, _err_ := failpoint.Eval(_curpkg_("injectBaseModifyCount")); _err_ == nil { + }) + failpoint.Inject("injectBaseModifyCount", func(val failpoint.Value) { modifyCount = int64(val.(int)) - } + }) sampleRate := new(float64) var sampleRateReason string if opts[ast.AnalyzeOptNumSamples] == 0 { @@ -2980,9 +2980,9 @@ func (b *executorBuilder) buildAnalyzeColumnsPushdown( b.err = err return nil } - if val, _err_ := failpoint.Eval(_curpkg_("injectAnalyzeSnapshot")); _err_ == nil { + failpoint.Inject("injectAnalyzeSnapshot", func(val failpoint.Value) { startTS = uint64(val.(int)) - } + }) concurrency := adaptiveAnlayzeDistSQLConcurrency(context.Background(), b.ctx) base := baseAnalyzeExec{ ctx: b.ctx, @@ -3588,16 +3588,16 @@ func (b *executorBuilder) buildMPPGather(v *plannercore.PhysicalTableReader) exe // buildTableReader builds a table reader executor. It first build a no range table reader, // and then update it ranges from table scan plan. 
func (b *executorBuilder) buildTableReader(v *plannercore.PhysicalTableReader) exec.Executor { - if val, _err_ := failpoint.Eval(_curpkg_("checkUseMPP")); _err_ == nil { + failpoint.Inject("checkUseMPP", func(val failpoint.Value) { if !b.ctx.GetSessionVars().InRestrictedSQL && val.(bool) != useMPPExecution(b.ctx, v) { if val.(bool) { b.err = errors.New("expect mpp but not used") } else { b.err = errors.New("don't expect mpp but we used it") } - return nil + failpoint.Return(nil) } - } + }) // https://github.com/pingcap/tidb/issues/50358 if len(v.Schema().Columns) == 0 && len(v.GetTablePlan().Schema().Columns) > 0 { v.SetSchema(v.GetTablePlan().Schema()) @@ -4790,11 +4790,11 @@ func (builder *dataReaderBuilder) buildProjectionForIndexJoin( if int64(v.StatsCount()) < int64(builder.ctx.GetSessionVars().MaxChunkSize) { e.numWorkers = 0 } - if val, _err_ := failpoint.Eval(_curpkg_("buildProjectionForIndexJoinPanic")); _err_ == nil { + failpoint.Inject("buildProjectionForIndexJoinPanic", func(val failpoint.Value) { if v, ok := val.(bool); ok && v { panic("buildProjectionForIndexJoinPanic") } - } + }) err = e.open(ctx) if err != nil { return nil, err @@ -4892,9 +4892,9 @@ func buildKvRangesForIndexJoin(dctx *distsqlctx.DistSQLContext, pctx *rangerctx. } } if len(kvRanges) != 0 && memTracker != nil { - if _, _err_ := failpoint.Eval(_curpkg_("testIssue49033")); _err_ == nil { + failpoint.Inject("testIssue49033", func() { panic("testIssue49033") - } + }) memTracker.Consume(int64(2 * cap(kvRanges[0].StartKey) * len(kvRanges))) } if len(tmpDatumRanges) != 0 && memTracker != nil { @@ -5251,12 +5251,12 @@ func (b *executorBuilder) buildBatchPointGet(plan *plannercore.BatchPointGetPlan sctx.IndexNames = append(sctx.IndexNames, plan.TblInfo.Name.O+":"+plan.IndexInfo.Name.O) } - if val, _err_ := failpoint.Eval(_curpkg_("assertBatchPointReplicaOption")); _err_ == nil { + failpoint.Inject("assertBatchPointReplicaOption", func(val failpoint.Value) { assertScope := val.(string) if e.Ctx().GetSessionVars().GetReplicaRead().IsClosestRead() && assertScope != b.readReplicaScope { panic("batch point get replica option fail") } - } + }) snapshotTS, err := b.getSnapshotTS() if err != nil { diff --git a/pkg/executor/builder.go__failpoint_stash__ b/pkg/executor/builder.go__failpoint_stash__ deleted file mode 100644 index 6bcee9f5f8be3..0000000000000 --- a/pkg/executor/builder.go__failpoint_stash__ +++ /dev/null @@ -1,5659 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package executor - -import ( - "bytes" - "cmp" - "context" - "fmt" - "math" - "slices" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - "unsafe" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/diagnosticspb" - "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/ddl" - "github.com/pingcap/tidb/pkg/ddl/placement" - "github.com/pingcap/tidb/pkg/distsql" - distsqlctx "github.com/pingcap/tidb/pkg/distsql/context" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/executor/aggfuncs" - "github.com/pingcap/tidb/pkg/executor/aggregate" - "github.com/pingcap/tidb/pkg/executor/internal/builder" - "github.com/pingcap/tidb/pkg/executor/internal/calibrateresource" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/executor/internal/pdhelper" - "github.com/pingcap/tidb/pkg/executor/internal/querywatch" - "github.com/pingcap/tidb/pkg/executor/internal/testutil" - "github.com/pingcap/tidb/pkg/executor/internal/vecgroupchecker" - "github.com/pingcap/tidb/pkg/executor/join" - "github.com/pingcap/tidb/pkg/executor/lockstats" - executor_metrics "github.com/pingcap/tidb/pkg/executor/metrics" - "github.com/pingcap/tidb/pkg/executor/sortexec" - "github.com/pingcap/tidb/pkg/executor/unionexec" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/expression/aggregation" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" - plannerutil "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/planner/util/coreusage" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/sessiontxn" - "github.com/pingcap/tidb/pkg/sessiontxn/staleread" - "github.com/pingcap/tidb/pkg/statistics" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/table/tables" - "github.com/pingcap/tidb/pkg/table/temptable" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/collate" - "github.com/pingcap/tidb/pkg/util/cteutil" - "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" - "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/memory" - "github.com/pingcap/tidb/pkg/util/ranger" - rangerctx "github.com/pingcap/tidb/pkg/util/ranger/context" - "github.com/pingcap/tidb/pkg/util/rowcodec" - "github.com/pingcap/tidb/pkg/util/tiflash" - "github.com/pingcap/tidb/pkg/util/timeutil" - "github.com/pingcap/tipb/go-tipb" - clientkv "github.com/tikv/client-go/v2/kv" - "github.com/tikv/client-go/v2/tikv" - "github.com/tikv/client-go/v2/txnkv" - "github.com/tikv/client-go/v2/txnkv/txnsnapshot" -) - -// executorBuilder builds an Executor from a Plan. -// The InfoSchema must not change during execution. 
-type executorBuilder struct { - ctx sessionctx.Context - is infoschema.InfoSchema - err error // err is set when there is error happened during Executor building process. - hasLock bool - // isStaleness means whether this statement use stale read. - isStaleness bool - txnScope string - readReplicaScope string - inUpdateStmt bool - inDeleteStmt bool - inInsertStmt bool - inSelectLockStmt bool - - // forDataReaderBuilder indicates whether the builder is used by a dataReaderBuilder. - // When forDataReader is true, the builder should use the dataReaderTS as the executor read ts. This is because - // dataReaderBuilder can be used in concurrent goroutines, so we must ensure that getting the ts should be thread safe and - // can return a correct value even if the session context has already been destroyed - forDataReaderBuilder bool - dataReaderTS uint64 - - // Used when building MPPGather. - encounterUnionScan bool -} - -// CTEStorages stores resTbl and iterInTbl for CTEExec. -// There will be a map[CTEStorageID]*CTEStorages in StmtCtx, -// which will store all CTEStorages to make all shared CTEs use same the CTEStorages. -type CTEStorages struct { - ResTbl cteutil.Storage - IterInTbl cteutil.Storage - Producer *cteProducer -} - -func newExecutorBuilder(ctx sessionctx.Context, is infoschema.InfoSchema) *executorBuilder { - txnManager := sessiontxn.GetTxnManager(ctx) - return &executorBuilder{ - ctx: ctx, - is: is, - isStaleness: staleread.IsStmtStaleness(ctx), - txnScope: txnManager.GetTxnScope(), - readReplicaScope: txnManager.GetReadReplicaScope(), - } -} - -// MockExecutorBuilder is a wrapper for executorBuilder. -// ONLY used in test. -type MockExecutorBuilder struct { - *executorBuilder -} - -// NewMockExecutorBuilderForTest is ONLY used in test. -func NewMockExecutorBuilderForTest(ctx sessionctx.Context, is infoschema.InfoSchema) *MockExecutorBuilder { - return &MockExecutorBuilder{ - executorBuilder: newExecutorBuilder(ctx, is)} -} - -// Build builds an executor tree according to `p`. 
-func (b *MockExecutorBuilder) Build(p base.Plan) exec.Executor { - return b.build(p) -} - -func (b *executorBuilder) build(p base.Plan) exec.Executor { - switch v := p.(type) { - case nil: - return nil - case *plannercore.Change: - return b.buildChange(v) - case *plannercore.CheckTable: - return b.buildCheckTable(v) - case *plannercore.RecoverIndex: - return b.buildRecoverIndex(v) - case *plannercore.CleanupIndex: - return b.buildCleanupIndex(v) - case *plannercore.CheckIndexRange: - return b.buildCheckIndexRange(v) - case *plannercore.ChecksumTable: - return b.buildChecksumTable(v) - case *plannercore.ReloadExprPushdownBlacklist: - return b.buildReloadExprPushdownBlacklist(v) - case *plannercore.ReloadOptRuleBlacklist: - return b.buildReloadOptRuleBlacklist(v) - case *plannercore.AdminPlugins: - return b.buildAdminPlugins(v) - case *plannercore.DDL: - return b.buildDDL(v) - case *plannercore.Deallocate: - return b.buildDeallocate(v) - case *plannercore.Delete: - return b.buildDelete(v) - case *plannercore.Execute: - return b.buildExecute(v) - case *plannercore.Trace: - return b.buildTrace(v) - case *plannercore.Explain: - return b.buildExplain(v) - case *plannercore.PointGetPlan: - return b.buildPointGet(v) - case *plannercore.BatchPointGetPlan: - return b.buildBatchPointGet(v) - case *plannercore.Insert: - return b.buildInsert(v) - case *plannercore.ImportInto: - return b.buildImportInto(v) - case *plannercore.LoadData: - return b.buildLoadData(v) - case *plannercore.LoadStats: - return b.buildLoadStats(v) - case *plannercore.LockStats: - return b.buildLockStats(v) - case *plannercore.UnlockStats: - return b.buildUnlockStats(v) - case *plannercore.IndexAdvise: - return b.buildIndexAdvise(v) - case *plannercore.PlanReplayer: - return b.buildPlanReplayer(v) - case *plannercore.PhysicalLimit: - return b.buildLimit(v) - case *plannercore.Prepare: - return b.buildPrepare(v) - case *plannercore.PhysicalLock: - return b.buildSelectLock(v) - case *plannercore.CancelDDLJobs: - return b.buildCancelDDLJobs(v) - case *plannercore.PauseDDLJobs: - return b.buildPauseDDLJobs(v) - case *plannercore.ResumeDDLJobs: - return b.buildResumeDDLJobs(v) - case *plannercore.ShowNextRowID: - return b.buildShowNextRowID(v) - case *plannercore.ShowDDL: - return b.buildShowDDL(v) - case *plannercore.PhysicalShowDDLJobs: - return b.buildShowDDLJobs(v) - case *plannercore.ShowDDLJobQueries: - return b.buildShowDDLJobQueries(v) - case *plannercore.ShowDDLJobQueriesWithRange: - return b.buildShowDDLJobQueriesWithRange(v) - case *plannercore.ShowSlow: - return b.buildShowSlow(v) - case *plannercore.PhysicalShow: - return b.buildShow(v) - case *plannercore.Simple: - return b.buildSimple(v) - case *plannercore.PhysicalSimpleWrapper: - return b.buildSimple(&v.Inner) - case *plannercore.Set: - return b.buildSet(v) - case *plannercore.SetConfig: - return b.buildSetConfig(v) - case *plannercore.PhysicalSort: - return b.buildSort(v) - case *plannercore.PhysicalTopN: - return b.buildTopN(v) - case *plannercore.PhysicalUnionAll: - return b.buildUnionAll(v) - case *plannercore.Update: - return b.buildUpdate(v) - case *plannercore.PhysicalUnionScan: - return b.buildUnionScanExec(v) - case *plannercore.PhysicalHashJoin: - return b.buildHashJoin(v) - case *plannercore.PhysicalMergeJoin: - return b.buildMergeJoin(v) - case *plannercore.PhysicalIndexJoin: - return b.buildIndexLookUpJoin(v) - case *plannercore.PhysicalIndexMergeJoin: - return b.buildIndexLookUpMergeJoin(v) - case *plannercore.PhysicalIndexHashJoin: - return 
b.buildIndexNestedLoopHashJoin(v) - case *plannercore.PhysicalSelection: - return b.buildSelection(v) - case *plannercore.PhysicalHashAgg: - return b.buildHashAgg(v) - case *plannercore.PhysicalStreamAgg: - return b.buildStreamAgg(v) - case *plannercore.PhysicalProjection: - return b.buildProjection(v) - case *plannercore.PhysicalMemTable: - return b.buildMemTable(v) - case *plannercore.PhysicalTableDual: - return b.buildTableDual(v) - case *plannercore.PhysicalApply: - return b.buildApply(v) - case *plannercore.PhysicalMaxOneRow: - return b.buildMaxOneRow(v) - case *plannercore.Analyze: - return b.buildAnalyze(v) - case *plannercore.PhysicalTableReader: - return b.buildTableReader(v) - case *plannercore.PhysicalTableSample: - return b.buildTableSample(v) - case *plannercore.PhysicalIndexReader: - return b.buildIndexReader(v) - case *plannercore.PhysicalIndexLookUpReader: - return b.buildIndexLookUpReader(v) - case *plannercore.PhysicalWindow: - return b.buildWindow(v) - case *plannercore.PhysicalShuffle: - return b.buildShuffle(v) - case *plannercore.PhysicalShuffleReceiverStub: - return b.buildShuffleReceiverStub(v) - case *plannercore.SQLBindPlan: - return b.buildSQLBindExec(v) - case *plannercore.SplitRegion: - return b.buildSplitRegion(v) - case *plannercore.PhysicalIndexMergeReader: - return b.buildIndexMergeReader(v) - case *plannercore.SelectInto: - return b.buildSelectInto(v) - case *plannercore.PhysicalCTE: - return b.buildCTE(v) - case *plannercore.PhysicalCTETable: - return b.buildCTETableReader(v) - case *plannercore.CompactTable: - return b.buildCompactTable(v) - case *plannercore.AdminShowBDRRole: - return b.buildAdminShowBDRRole(v) - case *plannercore.PhysicalExpand: - return b.buildExpand(v) - default: - if mp, ok := p.(testutil.MockPhysicalPlan); ok { - return mp.GetExecutor() - } - - b.err = exeerrors.ErrUnknownPlan.GenWithStack("Unknown Plan %T", p) - return nil - } -} - -func (b *executorBuilder) buildCancelDDLJobs(v *plannercore.CancelDDLJobs) exec.Executor { - e := &CancelDDLJobsExec{ - CommandDDLJobsExec: &CommandDDLJobsExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - jobIDs: v.JobIDs, - execute: ddl.CancelJobs, - }, - } - return e -} - -func (b *executorBuilder) buildPauseDDLJobs(v *plannercore.PauseDDLJobs) exec.Executor { - e := &PauseDDLJobsExec{ - CommandDDLJobsExec: &CommandDDLJobsExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - jobIDs: v.JobIDs, - execute: ddl.PauseJobs, - }, - } - return e -} - -func (b *executorBuilder) buildResumeDDLJobs(v *plannercore.ResumeDDLJobs) exec.Executor { - e := &ResumeDDLJobsExec{ - CommandDDLJobsExec: &CommandDDLJobsExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - jobIDs: v.JobIDs, - execute: ddl.ResumeJobs, - }, - } - return e -} - -func (b *executorBuilder) buildChange(v *plannercore.Change) exec.Executor { - return &ChangeExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - ChangeStmt: v.ChangeStmt, - } -} - -func (b *executorBuilder) buildShowNextRowID(v *plannercore.ShowNextRowID) exec.Executor { - e := &ShowNextRowIDExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - tblName: v.TableName, - } - return e -} - -func (b *executorBuilder) buildShowDDL(v *plannercore.ShowDDL) exec.Executor { - // We get Info here because for Executors that returns result set, - // next will be called after transaction has been committed. - // We need the transaction to get Info. 
- e := &ShowDDLExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - } - - var err error - ownerManager := domain.GetDomain(e.Ctx()).DDL().OwnerManager() - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - e.ddlOwnerID, err = ownerManager.GetOwnerID(ctx) - cancel() - if err != nil { - b.err = err - return nil - } - - session, err := e.GetSysSession() - if err != nil { - b.err = err - return nil - } - ddlInfo, err := ddl.GetDDLInfoWithNewTxn(session) - e.ReleaseSysSession(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), session) - if err != nil { - b.err = err - return nil - } - e.ddlInfo = ddlInfo - e.selfID = ownerManager.ID() - return e -} - -func (b *executorBuilder) buildShowDDLJobs(v *plannercore.PhysicalShowDDLJobs) exec.Executor { - loc := b.ctx.GetSessionVars().Location() - ddlJobRetriever := DDLJobRetriever{TZLoc: loc} - e := &ShowDDLJobsExec{ - jobNumber: int(v.JobNumber), - is: b.is, - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - DDLJobRetriever: ddlJobRetriever, - } - return e -} - -func (b *executorBuilder) buildShowDDLJobQueries(v *plannercore.ShowDDLJobQueries) exec.Executor { - e := &ShowDDLJobQueriesExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - jobIDs: v.JobIDs, - } - return e -} - -func (b *executorBuilder) buildShowDDLJobQueriesWithRange(v *plannercore.ShowDDLJobQueriesWithRange) exec.Executor { - e := &ShowDDLJobQueriesWithRangeExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - offset: v.Offset, - limit: v.Limit, - } - return e -} - -func (b *executorBuilder) buildShowSlow(v *plannercore.ShowSlow) exec.Executor { - e := &ShowSlowExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - ShowSlow: v.ShowSlow, - } - return e -} - -// buildIndexLookUpChecker builds check information to IndexLookUpReader. -func buildIndexLookUpChecker(b *executorBuilder, p *plannercore.PhysicalIndexLookUpReader, - e *IndexLookUpExecutor) { - is := p.IndexPlans[0].(*plannercore.PhysicalIndexScan) - fullColLen := len(is.Index.Columns) + len(p.CommonHandleCols) - if !e.isCommonHandle() { - fullColLen++ - } - if e.index.Global { - fullColLen++ - } - e.dagPB.OutputOffsets = make([]uint32, fullColLen) - for i := 0; i < fullColLen; i++ { - e.dagPB.OutputOffsets[i] = uint32(i) - } - - ts := p.TablePlans[0].(*plannercore.PhysicalTableScan) - e.handleIdx = ts.HandleIdx - - e.ranges = ranger.FullRange() - - tps := make([]*types.FieldType, 0, fullColLen) - for _, col := range is.Columns { - // tps is used to decode the index, we should use the element type of the array if any. 
- tps = append(tps, col.FieldType.ArrayType()) - } - - if !e.isCommonHandle() { - tps = append(tps, types.NewFieldType(mysql.TypeLonglong)) - } - if e.index.Global { - tps = append(tps, types.NewFieldType(mysql.TypeLonglong)) - } - - e.checkIndexValue = &checkIndexValue{idxColTps: tps} - - colNames := make([]string, 0, len(is.IdxCols)) - for i := range is.IdxCols { - colNames = append(colNames, is.Columns[i].Name.L) - } - if cols, missingColOffset := table.FindColumns(e.table.Cols(), colNames, true); missingColOffset >= 0 { - b.err = plannererrors.ErrUnknownColumn.GenWithStack("Unknown column %s", is.Columns[missingColOffset].Name.O) - } else { - e.idxTblCols = cols - } -} - -func (b *executorBuilder) buildCheckTable(v *plannercore.CheckTable) exec.Executor { - noMVIndexOrPrefixIndex := true - for _, idx := range v.IndexInfos { - if idx.MVIndex { - noMVIndexOrPrefixIndex = false - break - } - for _, col := range idx.Columns { - if col.Length != types.UnspecifiedLength { - noMVIndexOrPrefixIndex = false - break - } - } - if !noMVIndexOrPrefixIndex { - break - } - } - if b.ctx.GetSessionVars().FastCheckTable && noMVIndexOrPrefixIndex { - e := &FastCheckTableExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - dbName: v.DBName, - table: v.Table, - indexInfos: v.IndexInfos, - is: b.is, - err: &atomic.Pointer[error]{}, - } - return e - } - - readerExecs := make([]*IndexLookUpExecutor, 0, len(v.IndexLookUpReaders)) - for _, readerPlan := range v.IndexLookUpReaders { - readerExec, err := buildNoRangeIndexLookUpReader(b, readerPlan) - if err != nil { - b.err = errors.Trace(err) - return nil - } - buildIndexLookUpChecker(b, readerPlan, readerExec) - - readerExecs = append(readerExecs, readerExec) - } - - e := &CheckTableExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - dbName: v.DBName, - table: v.Table, - indexInfos: v.IndexInfos, - is: b.is, - srcs: readerExecs, - exitCh: make(chan struct{}), - retCh: make(chan error, len(readerExecs)), - checkIndex: v.CheckIndex, - } - return e -} - -func buildIdxColsConcatHandleCols(tblInfo *model.TableInfo, indexInfo *model.IndexInfo, hasGenedCol bool) []*model.ColumnInfo { - var pkCols []*model.IndexColumn - if tblInfo.IsCommonHandle { - pkIdx := tables.FindPrimaryIndex(tblInfo) - pkCols = pkIdx.Columns - } - - columns := make([]*model.ColumnInfo, 0, len(indexInfo.Columns)+len(pkCols)) - if hasGenedCol { - columns = tblInfo.Columns - } else { - for _, idxCol := range indexInfo.Columns { - if tblInfo.PKIsHandle && tblInfo.GetPkColInfo().Offset == idxCol.Offset { - continue - } - columns = append(columns, tblInfo.Columns[idxCol.Offset]) - } - } - - if tblInfo.IsCommonHandle { - for _, c := range pkCols { - if model.FindColumnInfo(columns, c.Name.L) == nil { - columns = append(columns, tblInfo.Columns[c.Offset]) - } - } - return columns - } - if tblInfo.PKIsHandle { - columns = append(columns, tblInfo.Columns[tblInfo.GetPkColInfo().Offset]) - return columns - } - handleOffset := len(columns) - handleColsInfo := &model.ColumnInfo{ - ID: model.ExtraHandleID, - Name: model.ExtraHandleName, - Offset: handleOffset, - } - handleColsInfo.FieldType = *types.NewFieldType(mysql.TypeLonglong) - columns = append(columns, handleColsInfo) - return columns -} - -func (b *executorBuilder) buildRecoverIndex(v *plannercore.RecoverIndex) exec.Executor { - tblInfo := v.Table.TableInfo - t, err := b.is.TableByName(context.Background(), v.Table.Schema, tblInfo.Name) - if err != nil { - b.err = err - return nil - } - idxName := 
strings.ToLower(v.IndexName) - index := tables.GetWritableIndexByName(idxName, t) - if index == nil { - b.err = errors.Errorf("secondary index `%v` is not found in table `%v`", v.IndexName, v.Table.Name.O) - return nil - } - var hasGenedCol bool - for _, iCol := range index.Meta().Columns { - if tblInfo.Columns[iCol.Offset].IsGenerated() { - hasGenedCol = true - } - } - cols := buildIdxColsConcatHandleCols(tblInfo, index.Meta(), hasGenedCol) - e := &RecoverIndexExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - columns: cols, - containsGenedCol: hasGenedCol, - index: index, - table: t, - physicalID: t.Meta().ID, - } - sessCtx := e.Ctx().GetSessionVars().StmtCtx - e.handleCols = buildHandleColsForExec(sessCtx, tblInfo, e.columns) - return e -} - -func buildHandleColsForExec(sctx *stmtctx.StatementContext, tblInfo *model.TableInfo, - allColInfo []*model.ColumnInfo) plannerutil.HandleCols { - if !tblInfo.IsCommonHandle { - extraColPos := len(allColInfo) - 1 - intCol := &expression.Column{ - Index: extraColPos, - RetType: types.NewFieldType(mysql.TypeLonglong), - } - return plannerutil.NewIntHandleCols(intCol) - } - tblCols := make([]*expression.Column, len(tblInfo.Columns)) - for i := 0; i < len(tblInfo.Columns); i++ { - c := tblInfo.Columns[i] - tblCols[i] = &expression.Column{ - RetType: &c.FieldType, - ID: c.ID, - } - } - pkIdx := tables.FindPrimaryIndex(tblInfo) - for _, c := range pkIdx.Columns { - for j, colInfo := range allColInfo { - if colInfo.Name.L == c.Name.L { - tblCols[c.Offset].Index = j - } - } - } - return plannerutil.NewCommonHandleCols(sctx, tblInfo, pkIdx, tblCols) -} - -func (b *executorBuilder) buildCleanupIndex(v *plannercore.CleanupIndex) exec.Executor { - tblInfo := v.Table.TableInfo - t, err := b.is.TableByName(context.Background(), v.Table.Schema, tblInfo.Name) - if err != nil { - b.err = err - return nil - } - idxName := strings.ToLower(v.IndexName) - var index table.Index - for _, idx := range t.Indices() { - if idx.Meta().State != model.StatePublic { - continue - } - if idxName == idx.Meta().Name.L { - index = idx - break - } - } - - if index == nil { - b.err = errors.Errorf("secondary index `%v` is not found in table `%v`", v.IndexName, v.Table.Name.O) - return nil - } - e := &CleanupIndexExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - columns: buildIdxColsConcatHandleCols(tblInfo, index.Meta(), false), - index: index, - table: t, - physicalID: t.Meta().ID, - batchSize: 20000, - } - sessCtx := e.Ctx().GetSessionVars().StmtCtx - e.handleCols = buildHandleColsForExec(sessCtx, tblInfo, e.columns) - if e.index.Meta().Global { - e.columns = append(e.columns, model.NewExtraPhysTblIDColInfo()) - } - return e -} - -func (b *executorBuilder) buildCheckIndexRange(v *plannercore.CheckIndexRange) exec.Executor { - tb, err := b.is.TableByName(context.Background(), v.Table.Schema, v.Table.Name) - if err != nil { - b.err = err - return nil - } - e := &CheckIndexRangeExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - handleRanges: v.HandleRanges, - table: tb.Meta(), - is: b.is, - } - idxName := strings.ToLower(v.IndexName) - for _, idx := range tb.Indices() { - if idx.Meta().Name.L == idxName { - e.index = idx.Meta() - e.startKey = make([]types.Datum, len(e.index.Columns)) - break - } - } - return e -} - -func (b *executorBuilder) buildChecksumTable(v *plannercore.ChecksumTable) exec.Executor { - e := &ChecksumTableExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - tables: 
make(map[int64]*checksumContext), - done: false, - } - startTs, err := b.getSnapshotTS() - if err != nil { - b.err = err - return nil - } - for _, t := range v.Tables { - e.tables[t.TableInfo.ID] = newChecksumContext(t.DBInfo, t.TableInfo, startTs) - } - return e -} - -func (b *executorBuilder) buildReloadExprPushdownBlacklist(_ *plannercore.ReloadExprPushdownBlacklist) exec.Executor { - base := exec.NewBaseExecutor(b.ctx, nil, 0) - return &ReloadExprPushdownBlacklistExec{base} -} - -func (b *executorBuilder) buildReloadOptRuleBlacklist(_ *plannercore.ReloadOptRuleBlacklist) exec.Executor { - base := exec.NewBaseExecutor(b.ctx, nil, 0) - return &ReloadOptRuleBlacklistExec{BaseExecutor: base} -} - -func (b *executorBuilder) buildAdminPlugins(v *plannercore.AdminPlugins) exec.Executor { - base := exec.NewBaseExecutor(b.ctx, nil, 0) - return &AdminPluginsExec{BaseExecutor: base, Action: v.Action, Plugins: v.Plugins} -} - -func (b *executorBuilder) buildDeallocate(v *plannercore.Deallocate) exec.Executor { - base := exec.NewBaseExecutor(b.ctx, nil, v.ID()) - base.SetInitCap(chunk.ZeroCapacity) - e := &DeallocateExec{ - BaseExecutor: base, - Name: v.Name, - } - return e -} - -func (b *executorBuilder) buildSelectLock(v *plannercore.PhysicalLock) exec.Executor { - if !b.inSelectLockStmt { - b.inSelectLockStmt = true - defer func() { b.inSelectLockStmt = false }() - } - if b.err = b.updateForUpdateTS(); b.err != nil { - return nil - } - - src := b.build(v.Children()[0]) - if b.err != nil { - return nil - } - if !b.ctx.GetSessionVars().PessimisticLockEligible() { - // Locking of rows for update using SELECT FOR UPDATE only applies when autocommit - // is disabled (either by beginning transaction with START TRANSACTION or by setting - // autocommit to 0. If autocommit is enabled, the rows matching the specification are not locked. - // See https://dev.mysql.com/doc/refman/5.7/en/innodb-locking-reads.html - return src - } - // If the `PhysicalLock` is not ignored by the above logic, set the `hasLock` flag. - b.hasLock = true - e := &SelectLockExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), src), - Lock: v.Lock, - tblID2Handle: v.TblID2Handle, - tblID2PhysTblIDCol: v.TblID2PhysTblIDCol, - } - - // filter out temporary tables because they do not store any record in tikv and should not write any lock - is := e.Ctx().GetInfoSchema().(infoschema.InfoSchema) - for tblID := range e.tblID2Handle { - tblInfo, ok := is.TableByID(tblID) - if !ok { - b.err = errors.Errorf("Can not get table %d", tblID) - } - - if tblInfo.Meta().TempTableType != model.TempTableNone { - delete(e.tblID2Handle, tblID) - } - } - - return e -} - -func (b *executorBuilder) buildLimit(v *plannercore.PhysicalLimit) exec.Executor { - childExec := b.build(v.Children()[0]) - if b.err != nil { - return nil - } - n := int(min(v.Count, uint64(b.ctx.GetSessionVars().MaxChunkSize))) - base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), childExec) - base.SetInitCap(n) - e := &LimitExec{ - BaseExecutor: base, - begin: v.Offset, - end: v.Offset + v.Count, - } - - childUsedSchemaLen := v.Children()[0].Schema().Len() - childUsedSchema := markChildrenUsedCols(v.Schema().Columns, v.Children()[0].Schema())[0] - e.columnIdxsUsedByChild = make([]int, 0, len(childUsedSchema)) - e.columnIdxsUsedByChild = append(e.columnIdxsUsedByChild, childUsedSchema...) - if len(e.columnIdxsUsedByChild) == childUsedSchemaLen { - e.columnIdxsUsedByChild = nil // indicates that all columns are used. 
LimitExec will improve performance for this condition. - } - return e -} - -func (b *executorBuilder) buildPrepare(v *plannercore.Prepare) exec.Executor { - base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()) - base.SetInitCap(chunk.ZeroCapacity) - return &PrepareExec{ - BaseExecutor: base, - name: v.Name, - sqlText: v.SQLText, - } -} - -func (b *executorBuilder) buildExecute(v *plannercore.Execute) exec.Executor { - e := &ExecuteExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - is: b.is, - name: v.Name, - usingVars: v.Params, - stmt: v.Stmt, - plan: v.Plan, - outputNames: v.OutputNames(), - } - - failpoint.Inject("assertExecutePrepareStatementStalenessOption", func(val failpoint.Value) { - vs := strings.Split(val.(string), "_") - assertTS, assertReadReplicaScope := vs[0], vs[1] - staleread.AssertStmtStaleness(b.ctx, true) - ts, err := sessiontxn.GetTxnManager(b.ctx).GetStmtReadTS() - if err != nil { - panic(e) - } - - if strconv.FormatUint(ts, 10) != assertTS || - assertReadReplicaScope != b.readReplicaScope { - panic("execute prepare statement have wrong staleness option") - } - }) - - return e -} - -func (b *executorBuilder) buildShow(v *plannercore.PhysicalShow) exec.Executor { - e := &ShowExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - Tp: v.Tp, - CountWarningsOrErrors: v.CountWarningsOrErrors, - DBName: model.NewCIStr(v.DBName), - Table: v.Table, - Partition: v.Partition, - Column: v.Column, - IndexName: v.IndexName, - ResourceGroupName: model.NewCIStr(v.ResourceGroupName), - Flag: v.Flag, - Roles: v.Roles, - User: v.User, - is: b.is, - Full: v.Full, - IfNotExists: v.IfNotExists, - GlobalScope: v.GlobalScope, - Extended: v.Extended, - Extractor: v.Extractor, - ImportJobID: v.ImportJobID, - } - if e.Tp == ast.ShowMasterStatus || e.Tp == ast.ShowBinlogStatus { - // show master status need start ts. 
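-		// Calling Txn(true) activates the pending transaction if needed, so a start TS has been allocated before the statement runs.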
- if _, err := e.Ctx().Txn(true); err != nil { - b.err = err - } - } - return e -} - -func (b *executorBuilder) buildSimple(v *plannercore.Simple) exec.Executor { - switch s := v.Statement.(type) { - case *ast.GrantStmt: - return b.buildGrant(s) - case *ast.RevokeStmt: - return b.buildRevoke(s) - case *ast.BRIEStmt: - return b.buildBRIE(s, v.Schema()) - case *ast.CalibrateResourceStmt: - return &calibrateresource.Executor{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), 0), - WorkloadType: s.Tp, - OptionList: s.DynamicCalibrateResourceOptionList, - } - case *ast.AddQueryWatchStmt: - return &querywatch.AddExecutor{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), 0), - QueryWatchOptionList: s.QueryWatchOptionList, - } - case *ast.ImportIntoActionStmt: - return &ImportIntoActionExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, nil, 0), - tp: s.Tp, - jobID: s.JobID, - } - } - base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()) - base.SetInitCap(chunk.ZeroCapacity) - e := &SimpleExec{ - BaseExecutor: base, - Statement: v.Statement, - IsFromRemote: v.IsFromRemote, - is: b.is, - staleTxnStartTS: v.StaleTxnStartTS, - } - return e -} - -func (b *executorBuilder) buildSet(v *plannercore.Set) exec.Executor { - base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()) - base.SetInitCap(chunk.ZeroCapacity) - e := &SetExecutor{ - BaseExecutor: base, - vars: v.VarAssigns, - } - return e -} - -func (b *executorBuilder) buildSetConfig(v *plannercore.SetConfig) exec.Executor { - return &SetConfigExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - p: v, - } -} - -func (b *executorBuilder) buildInsert(v *plannercore.Insert) exec.Executor { - b.inInsertStmt = true - if b.err = b.updateForUpdateTS(); b.err != nil { - return nil - } - - selectExec := b.build(v.SelectPlan) - if b.err != nil { - return nil - } - var baseExec exec.BaseExecutor - if selectExec != nil { - baseExec = exec.NewBaseExecutor(b.ctx, nil, v.ID(), selectExec) - } else { - baseExec = exec.NewBaseExecutor(b.ctx, nil, v.ID()) - } - baseExec.SetInitCap(chunk.ZeroCapacity) - - ivs := &InsertValues{ - BaseExecutor: baseExec, - Table: v.Table, - Columns: v.Columns, - Lists: v.Lists, - GenExprs: v.GenCols.Exprs, - allAssignmentsAreConstant: v.AllAssignmentsAreConstant, - hasRefCols: v.NeedFillDefaultValue, - SelectExec: selectExec, - rowLen: v.RowLen, - } - err := ivs.initInsertColumns() - if err != nil { - b.err = err - return nil - } - ivs.fkChecks, b.err = buildFKCheckExecs(b.ctx, ivs.Table, v.FKChecks) - if b.err != nil { - return nil - } - ivs.fkCascades, b.err = b.buildFKCascadeExecs(ivs.Table, v.FKCascades) - if b.err != nil { - return nil - } - - if v.IsReplace { - return b.buildReplace(ivs) - } - insert := &InsertExec{ - InsertValues: ivs, - OnDuplicate: append(v.OnDuplicate, v.GenCols.OnDuplicates...), - } - return insert -} - -func (b *executorBuilder) buildImportInto(v *plannercore.ImportInto) exec.Executor { - // see planBuilder.buildImportInto for detail why we use the latest schema here. 
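-	// The target table is therefore looked up in the domain's latest InfoSchema rather than in the builder's snapshot schema (b.is).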
- latestIS := b.ctx.GetDomainInfoSchema().(infoschema.InfoSchema) - tbl, ok := latestIS.TableByID(v.Table.TableInfo.ID) - if !ok { - b.err = errors.Errorf("Can not get table %d", v.Table.TableInfo.ID) - return nil - } - if !tbl.Meta().IsBaseTable() { - b.err = plannererrors.ErrNonUpdatableTable.GenWithStackByArgs(tbl.Meta().Name.O, "IMPORT") - return nil - } - - var ( - selectExec exec.Executor - base exec.BaseExecutor - ) - if v.SelectPlan != nil { - selectExec = b.build(v.SelectPlan) - if b.err != nil { - return nil - } - base = exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), selectExec) - } else { - base = exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()) - } - executor, err := newImportIntoExec(base, selectExec, b.ctx, v, tbl) - if err != nil { - b.err = err - return nil - } - - return executor -} - -func (b *executorBuilder) buildLoadData(v *plannercore.LoadData) exec.Executor { - tbl, ok := b.is.TableByID(v.Table.TableInfo.ID) - if !ok { - b.err = errors.Errorf("Can not get table %d", v.Table.TableInfo.ID) - return nil - } - if !tbl.Meta().IsBaseTable() { - b.err = plannererrors.ErrNonUpdatableTable.GenWithStackByArgs(tbl.Meta().Name.O, "LOAD") - return nil - } - - base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()) - worker, err := NewLoadDataWorker(b.ctx, v, tbl) - if err != nil { - b.err = err - return nil - } - - return &LoadDataExec{ - BaseExecutor: base, - loadDataWorker: worker, - FileLocRef: v.FileLocRef, - } -} - -func (b *executorBuilder) buildLoadStats(v *plannercore.LoadStats) exec.Executor { - e := &LoadStatsExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, nil, v.ID()), - info: &LoadStatsInfo{v.Path, b.ctx}, - } - return e -} - -func (b *executorBuilder) buildLockStats(v *plannercore.LockStats) exec.Executor { - e := &lockstats.LockExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, nil, v.ID()), - Tables: v.Tables, - } - return e -} - -func (b *executorBuilder) buildUnlockStats(v *plannercore.UnlockStats) exec.Executor { - e := &lockstats.UnlockExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, nil, v.ID()), - Tables: v.Tables, - } - return e -} - -func (b *executorBuilder) buildIndexAdvise(v *plannercore.IndexAdvise) exec.Executor { - e := &IndexAdviseExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, nil, v.ID()), - IsLocal: v.IsLocal, - indexAdviseInfo: &IndexAdviseInfo{ - Path: v.Path, - MaxMinutes: v.MaxMinutes, - MaxIndexNum: v.MaxIndexNum, - LineFieldsInfo: v.LineFieldsInfo, - Ctx: b.ctx, - }, - } - return e -} - -func (b *executorBuilder) buildPlanReplayer(v *plannercore.PlanReplayer) exec.Executor { - if v.Load { - e := &PlanReplayerLoadExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, nil, v.ID()), - info: &PlanReplayerLoadInfo{Path: v.File, Ctx: b.ctx}, - } - return e - } - if v.Capture { - e := &PlanReplayerExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, nil, v.ID()), - CaptureInfo: &PlanReplayerCaptureInfo{ - SQLDigest: v.SQLDigest, - PlanDigest: v.PlanDigest, - }, - } - return e - } - if v.Remove { - e := &PlanReplayerExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, nil, v.ID()), - CaptureInfo: &PlanReplayerCaptureInfo{ - SQLDigest: v.SQLDigest, - PlanDigest: v.PlanDigest, - Remove: true, - }, - } - return e - } - - e := &PlanReplayerExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - DumpInfo: &PlanReplayerDumpInfo{ - Analyze: v.Analyze, - Path: v.File, - ctx: b.ctx, - HistoricalStatsTS: v.HistoricalStatsTS, - }, - } - if v.ExecStmt != nil { - e.DumpInfo.ExecStmts = []ast.StmtNode{v.ExecStmt} - } else { - e.BaseExecutor = 
exec.NewBaseExecutor(b.ctx, nil, v.ID()) - } - return e -} - -func (*executorBuilder) buildReplace(vals *InsertValues) exec.Executor { - replaceExec := &ReplaceExec{ - InsertValues: vals, - } - return replaceExec -} - -func (b *executorBuilder) buildGrant(grant *ast.GrantStmt) exec.Executor { - e := &GrantExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, nil, 0), - Privs: grant.Privs, - ObjectType: grant.ObjectType, - Level: grant.Level, - Users: grant.Users, - WithGrant: grant.WithGrant, - AuthTokenOrTLSOptions: grant.AuthTokenOrTLSOptions, - is: b.is, - } - return e -} - -func (b *executorBuilder) buildRevoke(revoke *ast.RevokeStmt) exec.Executor { - e := &RevokeExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, nil, 0), - ctx: b.ctx, - Privs: revoke.Privs, - ObjectType: revoke.ObjectType, - Level: revoke.Level, - Users: revoke.Users, - is: b.is, - } - return e -} - -func (b *executorBuilder) buildDDL(v *plannercore.DDL) exec.Executor { - e := &DDLExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - ddlExecutor: domain.GetDomain(b.ctx).DDLExecutor(), - stmt: v.Statement, - is: b.is, - tempTableDDL: temptable.GetTemporaryTableDDL(b.ctx), - } - return e -} - -// buildTrace builds a TraceExec for future executing. This method will be called -// at build(). -func (b *executorBuilder) buildTrace(v *plannercore.Trace) exec.Executor { - t := &TraceExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - stmtNode: v.StmtNode, - builder: b, - format: v.Format, - - optimizerTrace: v.OptimizerTrace, - optimizerTraceTarget: v.OptimizerTraceTarget, - } - if t.format == plannercore.TraceFormatLog && !t.optimizerTrace { - return &sortexec.SortExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), t), - ByItems: []*plannerutil.ByItems{ - {Expr: &expression.Column{ - Index: 0, - RetType: types.NewFieldType(mysql.TypeTimestamp), - }}, - }, - ExecSchema: v.Schema(), - } - } - return t -} - -// buildExplain builds a explain executor. `e.rows` collects final result to `ExplainExec`. -func (b *executorBuilder) buildExplain(v *plannercore.Explain) exec.Executor { - explainExec := &ExplainExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - explain: v, - } - if v.Analyze { - if b.ctx.GetSessionVars().StmtCtx.RuntimeStatsColl == nil { - b.ctx.GetSessionVars().StmtCtx.RuntimeStatsColl = execdetails.NewRuntimeStatsColl(nil) - } - } - // Needs to build the target plan, even if not executing it - // to get partition pruning. - explainExec.analyzeExec = b.build(v.TargetPlan) - return explainExec -} - -func (b *executorBuilder) buildSelectInto(v *plannercore.SelectInto) exec.Executor { - child := b.build(v.TargetPlan) - if b.err != nil { - return nil - } - return &SelectIntoExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), child), - intoOpt: v.IntoOpt, - LineFieldsInfo: v.LineFieldsInfo, - } -} - -func (b *executorBuilder) buildUnionScanExec(v *plannercore.PhysicalUnionScan) exec.Executor { - oriEncounterUnionScan := b.encounterUnionScan - b.encounterUnionScan = true - defer func() { - b.encounterUnionScan = oriEncounterUnionScan - }() - reader := b.build(v.Children()[0]) - if b.err != nil { - return nil - } - - return b.buildUnionScanFromReader(reader, v) -} - -// buildUnionScanFromReader builds union scan executor from child executor. -// Note that this function may be called by inner workers of index lookup join concurrently. -// Be careful to avoid data race. 
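-// The returned UnionScanExec merges the current transaction's in-memory (not yet committed) modifications with the rows produced by the wrapped reader.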
-func (b *executorBuilder) buildUnionScanFromReader(reader exec.Executor, v *plannercore.PhysicalUnionScan) exec.Executor { - // If reader is union, it means a partition table and we should transfer as above. - if x, ok := reader.(*unionexec.UnionExec); ok { - for i, child := range x.AllChildren() { - x.SetChildren(i, b.buildUnionScanFromReader(child, v)) - if b.err != nil { - return nil - } - } - return x - } - us := &UnionScanExec{BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), reader)} - // Get the handle column index of the below Plan. - us.handleCols = v.HandleCols - us.mutableRow = chunk.MutRowFromTypes(exec.RetTypes(us)) - - // If the push-downed condition contains virtual column, we may build a selection upon reader - originReader := reader - if sel, ok := reader.(*SelectionExec); ok { - reader = sel.Children(0) - } - - us.collators = make([]collate.Collator, 0, len(us.columns)) - for _, tp := range exec.RetTypes(us) { - us.collators = append(us.collators, collate.GetCollator(tp.GetCollate())) - } - - startTS, err := b.getSnapshotTS() - sessionVars := b.ctx.GetSessionVars() - if err != nil { - b.err = err - return nil - } - - switch x := reader.(type) { - case *MPPGather: - us.desc = false - us.keepOrder = false - us.conditions, us.conditionsWithVirCol = plannercore.SplitSelCondsWithVirtualColumn(v.Conditions) - us.columns = x.columns - us.table = x.table - us.virtualColumnIndex = x.virtualColumnIndex - us.handleCachedTable(b, x, sessionVars, startTS) - case *TableReaderExecutor: - us.desc = x.desc - us.keepOrder = x.keepOrder - us.conditions, us.conditionsWithVirCol = plannercore.SplitSelCondsWithVirtualColumn(v.Conditions) - us.columns = x.columns - us.table = x.table - us.virtualColumnIndex = x.virtualColumnIndex - us.handleCachedTable(b, x, sessionVars, startTS) - case *IndexReaderExecutor: - us.desc = x.desc - us.keepOrder = x.keepOrder - for _, ic := range x.index.Columns { - for i, col := range x.columns { - if col.Name.L == ic.Name.L { - us.usedIndex = append(us.usedIndex, i) - break - } - } - } - us.conditions, us.conditionsWithVirCol = plannercore.SplitSelCondsWithVirtualColumn(v.Conditions) - us.columns = x.columns - us.partitionIDMap = x.partitionIDMap - us.table = x.table - us.handleCachedTable(b, x, sessionVars, startTS) - case *IndexLookUpExecutor: - us.desc = x.desc - us.keepOrder = x.keepOrder - for _, ic := range x.index.Columns { - for i, col := range x.columns { - if col.Name.L == ic.Name.L { - us.usedIndex = append(us.usedIndex, i) - break - } - } - } - us.conditions, us.conditionsWithVirCol = plannercore.SplitSelCondsWithVirtualColumn(v.Conditions) - us.columns = x.columns - us.table = x.table - us.partitionIDMap = x.partitionIDMap - us.virtualColumnIndex = buildVirtualColumnIndex(us.Schema(), us.columns) - us.handleCachedTable(b, x, sessionVars, startTS) - case *IndexMergeReaderExecutor: - if len(x.byItems) != 0 { - us.keepOrder = x.keepOrder - us.desc = x.byItems[0].Desc - for _, item := range x.byItems { - c, ok := item.Expr.(*expression.Column) - if !ok { - b.err = errors.Errorf("Not support non-column in orderBy pushed down") - return nil - } - for i, col := range x.columns { - if col.ID == c.ID { - us.usedIndex = append(us.usedIndex, i) - break - } - } - } - } - us.partitionIDMap = x.partitionIDMap - us.conditions, us.conditionsWithVirCol = plannercore.SplitSelCondsWithVirtualColumn(v.Conditions) - us.columns = x.columns - us.table = x.table - us.virtualColumnIndex = buildVirtualColumnIndex(us.Schema(), us.columns) - case *PointGetExecutor, 
*BatchPointGetExec, // PointGet and BatchPoint can handle virtual columns and dirty txn data themselves. - *TableDualExec, // If TableDual, the result must be empty, so we can skip UnionScan and use TableDual directly here. - *TableSampleExecutor: // TableSample only supports sampling from disk, don't need to consider in-memory txn data for simplicity. - return originReader - default: - // TODO: consider more operators like Projection. - b.err = errors.NewNoStackErrorf("unexpected operator %T under UnionScan", reader) - return nil - } - return us -} - -type bypassDataSourceExecutor interface { - dataSourceExecutor - setDummy() -} - -func (us *UnionScanExec) handleCachedTable(b *executorBuilder, x bypassDataSourceExecutor, vars *variable.SessionVars, startTS uint64) { - tbl := x.Table() - if tbl.Meta().TableCacheStatusType == model.TableCacheStatusEnable { - cachedTable := tbl.(table.CachedTable) - // Determine whether the cache can be used. - leaseDuration := time.Duration(variable.TableCacheLease.Load()) * time.Second - cacheData, loading := cachedTable.TryReadFromCache(startTS, leaseDuration) - if cacheData != nil { - vars.StmtCtx.ReadFromTableCache = true - x.setDummy() - us.cacheTable = cacheData - } else if loading { - return - } else { - if !b.inUpdateStmt && !b.inDeleteStmt && !b.inInsertStmt && !vars.StmtCtx.InExplainStmt { - store := b.ctx.GetStore() - cachedTable.UpdateLockForRead(context.Background(), store, startTS, leaseDuration) - } - } - } -} - -// buildMergeJoin builds MergeJoinExec executor. -func (b *executorBuilder) buildMergeJoin(v *plannercore.PhysicalMergeJoin) exec.Executor { - leftExec := b.build(v.Children()[0]) - if b.err != nil { - return nil - } - - rightExec := b.build(v.Children()[1]) - if b.err != nil { - return nil - } - - defaultValues := v.DefaultValues - if defaultValues == nil { - if v.JoinType == plannercore.RightOuterJoin { - defaultValues = make([]types.Datum, leftExec.Schema().Len()) - } else { - defaultValues = make([]types.Datum, rightExec.Schema().Len()) - } - } - - colsFromChildren := v.Schema().Columns - if v.JoinType == plannercore.LeftOuterSemiJoin || v.JoinType == plannercore.AntiLeftOuterSemiJoin { - colsFromChildren = colsFromChildren[:len(colsFromChildren)-1] - } - - e := &join.MergeJoinExec{ - StmtCtx: b.ctx.GetSessionVars().StmtCtx, - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), leftExec, rightExec), - CompareFuncs: v.CompareFuncs, - Joiner: join.NewJoiner( - b.ctx, - v.JoinType, - v.JoinType == plannercore.RightOuterJoin, - defaultValues, - v.OtherConditions, - exec.RetTypes(leftExec), - exec.RetTypes(rightExec), - markChildrenUsedCols(colsFromChildren, v.Children()[0].Schema(), v.Children()[1].Schema()), - false, - ), - IsOuterJoin: v.JoinType.IsOuterJoin(), - Desc: v.Desc, - } - - leftTable := &join.MergeJoinTable{ - ChildIndex: 0, - JoinKeys: v.LeftJoinKeys, - Filters: v.LeftConditions, - } - rightTable := &join.MergeJoinTable{ - ChildIndex: 1, - JoinKeys: v.RightJoinKeys, - Filters: v.RightConditions, - } - - if v.JoinType == plannercore.RightOuterJoin { - e.InnerTable = leftTable - e.OuterTable = rightTable - } else { - e.InnerTable = rightTable - e.OuterTable = leftTable - } - e.InnerTable.IsInner = true - - // optimizer should guarantee that filters on inner table are pushed down - // to tikv or extracted to a Selection. 
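-	// A non-empty inner filter therefore indicates a planner bug and is reported as a build error below.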
- if len(e.InnerTable.Filters) != 0 { - b.err = errors.Annotate(exeerrors.ErrBuildExecutor, "merge join's inner filter should be empty.") - return nil - } - - executor_metrics.ExecutorCounterMergeJoinExec.Inc() - return e -} - -func collectColumnIndexFromExpr(expr expression.Expression, leftColumnSize int, leftColumnIndex []int, rightColumnIndex []int) ([]int, []int) { - switch x := expr.(type) { - case *expression.Column: - colIndex := x.Index - if colIndex >= leftColumnSize { - rightColumnIndex = append(rightColumnIndex, colIndex-leftColumnSize) - } else { - leftColumnIndex = append(leftColumnIndex, colIndex) - } - return leftColumnIndex, rightColumnIndex - case *expression.Constant: - return leftColumnIndex, rightColumnIndex - case *expression.ScalarFunction: - for _, arg := range x.GetArgs() { - leftColumnIndex, rightColumnIndex = collectColumnIndexFromExpr(arg, leftColumnSize, leftColumnIndex, rightColumnIndex) - } - return leftColumnIndex, rightColumnIndex - default: - panic("unsupported expression") - } -} - -func extractUsedColumnsInJoinOtherCondition(expr expression.CNFExprs, leftColumnSize int) ([]int, []int) { - leftColumnIndex := make([]int, 0, 1) - rightColumnIndex := make([]int, 0, 1) - for _, subExpr := range expr { - leftColumnIndex, rightColumnIndex = collectColumnIndexFromExpr(subExpr, leftColumnSize, leftColumnIndex, rightColumnIndex) - } - return leftColumnIndex, rightColumnIndex -} - -func (b *executorBuilder) buildHashJoinV2(v *plannercore.PhysicalHashJoin) exec.Executor { - leftExec := b.build(v.Children()[0]) - if b.err != nil { - return nil - } - - rightExec := b.build(v.Children()[1]) - if b.err != nil { - return nil - } - - e := &join.HashJoinV2Exec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), leftExec, rightExec), - ProbeSideTupleFetcher: &join.ProbeSideTupleFetcherV2{}, - ProbeWorkers: make([]*join.ProbeWorkerV2, v.Concurrency), - BuildWorkers: make([]*join.BuildWorkerV2, v.Concurrency), - HashJoinCtxV2: &join.HashJoinCtxV2{ - OtherCondition: v.OtherConditions, - }, - } - e.HashJoinCtxV2.SessCtx = b.ctx - e.HashJoinCtxV2.JoinType = v.JoinType - e.HashJoinCtxV2.Concurrency = v.Concurrency - e.HashJoinCtxV2.SetupPartitionInfo() - e.ChunkAllocPool = e.AllocPool - e.HashJoinCtxV2.RightAsBuildSide = true - if v.InnerChildIdx == 1 && v.UseOuterToBuild { - e.HashJoinCtxV2.RightAsBuildSide = false - } else if v.InnerChildIdx == 0 && !v.UseOuterToBuild { - e.HashJoinCtxV2.RightAsBuildSide = false - } - - lhsTypes, rhsTypes := exec.RetTypes(leftExec), exec.RetTypes(rightExec) - joinedTypes := make([]*types.FieldType, 0, len(lhsTypes)+len(rhsTypes)) - joinedTypes = append(joinedTypes, lhsTypes...) - joinedTypes = append(joinedTypes, rhsTypes...) 
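-	// joinedTypes is the left-then-right concatenation of the child output types; it is later passed to the probe workers via NewJoinProbe.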
- - if v.InnerChildIdx == 1 { - if len(v.RightConditions) > 0 { - b.err = errors.Annotate(exeerrors.ErrBuildExecutor, "join's inner condition should be empty") - return nil - } - } else { - if len(v.LeftConditions) > 0 { - b.err = errors.Annotate(exeerrors.ErrBuildExecutor, "join's inner condition should be empty") - return nil - } - } - - var probeKeys, buildKeys []*expression.Column - var buildSideExec exec.Executor - if v.UseOuterToBuild { - if v.InnerChildIdx == 1 { - buildSideExec, buildKeys = leftExec, v.LeftJoinKeys - e.ProbeSideTupleFetcher.ProbeSideExec, probeKeys = rightExec, v.RightJoinKeys - e.HashJoinCtxV2.BuildFilter = v.LeftConditions - } else { - buildSideExec, buildKeys = rightExec, v.RightJoinKeys - e.ProbeSideTupleFetcher.ProbeSideExec, probeKeys = leftExec, v.LeftJoinKeys - e.HashJoinCtxV2.BuildFilter = v.RightConditions - } - } else { - if v.InnerChildIdx == 0 { - buildSideExec, buildKeys = leftExec, v.LeftJoinKeys - e.ProbeSideTupleFetcher.ProbeSideExec, probeKeys = rightExec, v.RightJoinKeys - e.HashJoinCtxV2.ProbeFilter = v.RightConditions - } else { - buildSideExec, buildKeys = rightExec, v.RightJoinKeys - e.ProbeSideTupleFetcher.ProbeSideExec, probeKeys = leftExec, v.LeftJoinKeys - e.HashJoinCtxV2.ProbeFilter = v.LeftConditions - } - } - probeKeyColIdx := make([]int, len(probeKeys)) - buildKeyColIdx := make([]int, len(buildKeys)) - for i := range buildKeys { - buildKeyColIdx[i] = buildKeys[i].Index - } - for i := range probeKeys { - probeKeyColIdx[i] = probeKeys[i].Index - } - - colsFromChildren := v.Schema().Columns - if v.JoinType == plannercore.LeftOuterSemiJoin || v.JoinType == plannercore.AntiLeftOuterSemiJoin { - // the matched column is added inside join - colsFromChildren = colsFromChildren[:len(colsFromChildren)-1] - } - childrenUsedSchema := markChildrenUsedCols(colsFromChildren, v.Children()[0].Schema(), v.Children()[1].Schema()) - if childrenUsedSchema == nil { - b.err = errors.New("children used should never be nil") - return nil - } - e.LUsed = make([]int, 0, len(childrenUsedSchema[0])) - e.LUsed = append(e.LUsed, childrenUsedSchema[0]...) - e.RUsed = make([]int, 0, len(childrenUsedSchema[1])) - e.RUsed = append(e.RUsed, childrenUsedSchema[1]...) 
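-	// LUsed and RUsed record which columns of the left and right children are actually referenced by the join's output schema after column pruning.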
- if v.OtherConditions != nil { - leftColumnSize := v.Children()[0].Schema().Len() - e.LUsedInOtherCondition, e.RUsedInOtherCondition = extractUsedColumnsInJoinOtherCondition(v.OtherConditions, leftColumnSize) - } - // todo add partition hash join exec - executor_metrics.ExecutorCountHashJoinExec.Inc() - - leftExecTypes, rightExecTypes := exec.RetTypes(leftExec), exec.RetTypes(rightExec) - leftTypes, rightTypes := make([]*types.FieldType, 0, len(v.LeftJoinKeys)+len(v.LeftNAJoinKeys)), make([]*types.FieldType, 0, len(v.RightJoinKeys)+len(v.RightNAJoinKeys)) - for i, col := range v.LeftJoinKeys { - leftTypes = append(leftTypes, leftExecTypes[col.Index].Clone()) - leftTypes[i].SetFlag(col.RetType.GetFlag()) - } - offset := len(v.LeftJoinKeys) - for i, col := range v.LeftNAJoinKeys { - leftTypes = append(leftTypes, leftExecTypes[col.Index].Clone()) - leftTypes[i+offset].SetFlag(col.RetType.GetFlag()) - } - for i, col := range v.RightJoinKeys { - rightTypes = append(rightTypes, rightExecTypes[col.Index].Clone()) - rightTypes[i].SetFlag(col.RetType.GetFlag()) - } - offset = len(v.RightJoinKeys) - for i, col := range v.RightNAJoinKeys { - rightTypes = append(rightTypes, rightExecTypes[col.Index].Clone()) - rightTypes[i+offset].SetFlag(col.RetType.GetFlag()) - } - - // consider collations - for i := range v.EqualConditions { - chs, coll := v.EqualConditions[i].CharsetAndCollation() - leftTypes[i].SetCharset(chs) - leftTypes[i].SetCollate(coll) - rightTypes[i].SetCharset(chs) - rightTypes[i].SetCollate(coll) - } - offset = len(v.EqualConditions) - for i := range v.NAEqualConditions { - chs, coll := v.NAEqualConditions[i].CharsetAndCollation() - leftTypes[i+offset].SetCharset(chs) - leftTypes[i+offset].SetCollate(coll) - rightTypes[i+offset].SetCharset(chs) - rightTypes[i+offset].SetCollate(coll) - } - if e.RightAsBuildSide { - e.BuildKeyTypes, e.ProbeKeyTypes = rightTypes, leftTypes - } else { - e.BuildKeyTypes, e.ProbeKeyTypes = leftTypes, rightTypes - } - for i := uint(0); i < e.Concurrency; i++ { - e.ProbeWorkers[i] = &join.ProbeWorkerV2{ - HashJoinCtx: e.HashJoinCtxV2, - JoinProbe: join.NewJoinProbe(e.HashJoinCtxV2, i, v.JoinType, probeKeyColIdx, joinedTypes, e.ProbeKeyTypes, e.RightAsBuildSide), - } - e.ProbeWorkers[i].WorkerID = i - - e.BuildWorkers[i] = join.NewJoinBuildWorkerV2(e.HashJoinCtxV2, i, buildSideExec, buildKeyColIdx, exec.RetTypes(buildSideExec)) - } - return e -} - -func (b *executorBuilder) buildHashJoin(v *plannercore.PhysicalHashJoin) exec.Executor { - if join.IsHashJoinV2Enabled() && v.CanUseHashJoinV2() { - return b.buildHashJoinV2(v) - } - leftExec := b.build(v.Children()[0]) - if b.err != nil { - return nil - } - - rightExec := b.build(v.Children()[1]) - if b.err != nil { - return nil - } - - e := &join.HashJoinV1Exec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), leftExec, rightExec), - ProbeSideTupleFetcher: &join.ProbeSideTupleFetcherV1{}, - ProbeWorkers: make([]*join.ProbeWorkerV1, v.Concurrency), - BuildWorker: &join.BuildWorkerV1{}, - HashJoinCtxV1: &join.HashJoinCtxV1{ - IsOuterJoin: v.JoinType.IsOuterJoin(), - UseOuterToBuild: v.UseOuterToBuild, - }, - } - e.HashJoinCtxV1.SessCtx = b.ctx - e.HashJoinCtxV1.JoinType = v.JoinType - e.HashJoinCtxV1.Concurrency = v.Concurrency - e.HashJoinCtxV1.ChunkAllocPool = e.AllocPool - defaultValues := v.DefaultValues - lhsTypes, rhsTypes := exec.RetTypes(leftExec), exec.RetTypes(rightExec) - if v.InnerChildIdx == 1 { - if len(v.RightConditions) > 0 { - b.err = errors.Annotate(exeerrors.ErrBuildExecutor, "join's 
inner condition should be empty") - return nil - } - } else { - if len(v.LeftConditions) > 0 { - b.err = errors.Annotate(exeerrors.ErrBuildExecutor, "join's inner condition should be empty") - return nil - } - } - - leftIsBuildSide := true - - e.IsNullEQ = v.IsNullEQ - var probeKeys, probeNAKeys, buildKeys, buildNAKeys []*expression.Column - var buildSideExec exec.Executor - if v.UseOuterToBuild { - // update the buildSideEstCount due to changing the build side - if v.InnerChildIdx == 1 { - buildSideExec, buildKeys, buildNAKeys = leftExec, v.LeftJoinKeys, v.LeftNAJoinKeys - e.ProbeSideTupleFetcher.ProbeSideExec, probeKeys, probeNAKeys = rightExec, v.RightJoinKeys, v.RightNAJoinKeys - e.OuterFilter = v.LeftConditions - } else { - buildSideExec, buildKeys, buildNAKeys = rightExec, v.RightJoinKeys, v.RightNAJoinKeys - e.ProbeSideTupleFetcher.ProbeSideExec, probeKeys, probeNAKeys = leftExec, v.LeftJoinKeys, v.LeftNAJoinKeys - e.OuterFilter = v.RightConditions - leftIsBuildSide = false - } - if defaultValues == nil { - defaultValues = make([]types.Datum, e.ProbeSideTupleFetcher.ProbeSideExec.Schema().Len()) - } - } else { - if v.InnerChildIdx == 0 { - buildSideExec, buildKeys, buildNAKeys = leftExec, v.LeftJoinKeys, v.LeftNAJoinKeys - e.ProbeSideTupleFetcher.ProbeSideExec, probeKeys, probeNAKeys = rightExec, v.RightJoinKeys, v.RightNAJoinKeys - e.OuterFilter = v.RightConditions - } else { - buildSideExec, buildKeys, buildNAKeys = rightExec, v.RightJoinKeys, v.RightNAJoinKeys - e.ProbeSideTupleFetcher.ProbeSideExec, probeKeys, probeNAKeys = leftExec, v.LeftJoinKeys, v.LeftNAJoinKeys - e.OuterFilter = v.LeftConditions - leftIsBuildSide = false - } - if defaultValues == nil { - defaultValues = make([]types.Datum, buildSideExec.Schema().Len()) - } - } - probeKeyColIdx := make([]int, len(probeKeys)) - probeNAKeColIdx := make([]int, len(probeNAKeys)) - buildKeyColIdx := make([]int, len(buildKeys)) - buildNAKeyColIdx := make([]int, len(buildNAKeys)) - for i := range buildKeys { - buildKeyColIdx[i] = buildKeys[i].Index - } - for i := range buildNAKeys { - buildNAKeyColIdx[i] = buildNAKeys[i].Index - } - for i := range probeKeys { - probeKeyColIdx[i] = probeKeys[i].Index - } - for i := range probeNAKeys { - probeNAKeColIdx[i] = probeNAKeys[i].Index - } - isNAJoin := len(v.LeftNAJoinKeys) > 0 - colsFromChildren := v.Schema().Columns - if v.JoinType == plannercore.LeftOuterSemiJoin || v.JoinType == plannercore.AntiLeftOuterSemiJoin { - colsFromChildren = colsFromChildren[:len(colsFromChildren)-1] - } - childrenUsedSchema := markChildrenUsedCols(colsFromChildren, v.Children()[0].Schema(), v.Children()[1].Schema()) - for i := uint(0); i < e.Concurrency; i++ { - e.ProbeWorkers[i] = &join.ProbeWorkerV1{ - HashJoinCtx: e.HashJoinCtxV1, - Joiner: join.NewJoiner(b.ctx, v.JoinType, v.InnerChildIdx == 0, defaultValues, v.OtherConditions, lhsTypes, rhsTypes, childrenUsedSchema, isNAJoin), - ProbeKeyColIdx: probeKeyColIdx, - ProbeNAKeyColIdx: probeNAKeColIdx, - } - e.ProbeWorkers[i].WorkerID = i - } - e.BuildWorker.BuildKeyColIdx, e.BuildWorker.BuildNAKeyColIdx, e.BuildWorker.BuildSideExec, e.BuildWorker.HashJoinCtx = buildKeyColIdx, buildNAKeyColIdx, buildSideExec, e.HashJoinCtxV1 - e.HashJoinCtxV1.IsNullAware = isNAJoin - executor_metrics.ExecutorCountHashJoinExec.Inc() - - // We should use JoinKey to construct the type information using by hashing, instead of using the child's schema directly. - // When a hybrid type column is hashed multiple times, we need to distinguish what field types are used. 
- // For example, the condition `enum = int and enum = string`, we should use ETInt to hash the first column, - // and use ETString to hash the second column, although they may be the same column. - leftExecTypes, rightExecTypes := exec.RetTypes(leftExec), exec.RetTypes(rightExec) - leftTypes, rightTypes := make([]*types.FieldType, 0, len(v.LeftJoinKeys)+len(v.LeftNAJoinKeys)), make([]*types.FieldType, 0, len(v.RightJoinKeys)+len(v.RightNAJoinKeys)) - // set left types and right types for joiner. - for i, col := range v.LeftJoinKeys { - leftTypes = append(leftTypes, leftExecTypes[col.Index].Clone()) - leftTypes[i].SetFlag(col.RetType.GetFlag()) - } - offset := len(v.LeftJoinKeys) - for i, col := range v.LeftNAJoinKeys { - leftTypes = append(leftTypes, leftExecTypes[col.Index].Clone()) - leftTypes[i+offset].SetFlag(col.RetType.GetFlag()) - } - for i, col := range v.RightJoinKeys { - rightTypes = append(rightTypes, rightExecTypes[col.Index].Clone()) - rightTypes[i].SetFlag(col.RetType.GetFlag()) - } - offset = len(v.RightJoinKeys) - for i, col := range v.RightNAJoinKeys { - rightTypes = append(rightTypes, rightExecTypes[col.Index].Clone()) - rightTypes[i+offset].SetFlag(col.RetType.GetFlag()) - } - - // consider collations - for i := range v.EqualConditions { - chs, coll := v.EqualConditions[i].CharsetAndCollation() - leftTypes[i].SetCharset(chs) - leftTypes[i].SetCollate(coll) - rightTypes[i].SetCharset(chs) - rightTypes[i].SetCollate(coll) - } - offset = len(v.EqualConditions) - for i := range v.NAEqualConditions { - chs, coll := v.NAEqualConditions[i].CharsetAndCollation() - leftTypes[i+offset].SetCharset(chs) - leftTypes[i+offset].SetCollate(coll) - rightTypes[i+offset].SetCharset(chs) - rightTypes[i+offset].SetCollate(coll) - } - if leftIsBuildSide { - e.BuildTypes, e.ProbeTypes = leftTypes, rightTypes - } else { - e.BuildTypes, e.ProbeTypes = rightTypes, leftTypes - } - return e -} - -func (b *executorBuilder) buildHashAgg(v *plannercore.PhysicalHashAgg) exec.Executor { - src := b.build(v.Children()[0]) - if b.err != nil { - return nil - } - return b.buildHashAggFromChildExec(src, v) -} - -func (b *executorBuilder) buildHashAggFromChildExec(childExec exec.Executor, v *plannercore.PhysicalHashAgg) *aggregate.HashAggExec { - sessionVars := b.ctx.GetSessionVars() - e := &aggregate.HashAggExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), childExec), - Sc: sessionVars.StmtCtx, - PartialAggFuncs: make([]aggfuncs.AggFunc, 0, len(v.AggFuncs)), - GroupByItems: v.GroupByItems, - } - // We take `create table t(a int, b int);` as example. - // - // 1. If all the aggregation functions are FIRST_ROW, we do not need to set the defaultVal for them: - // e.g. - // mysql> select distinct a, b from t; - // 0 rows in set (0.00 sec) - // - // 2. If there exists group by items, we do not need to set the defaultVal for them either: - // e.g. 
- // mysql> select avg(a) from t group by b; - // Empty set (0.00 sec) - // - // mysql> select avg(a) from t group by a; - // +--------+ - // | avg(a) | - // +--------+ - // | NULL | - // +--------+ - // 1 row in set (0.00 sec) - if len(v.GroupByItems) != 0 || aggregation.IsAllFirstRow(v.AggFuncs) { - e.DefaultVal = nil - } else { - if v.IsFinalAgg() { - e.DefaultVal = e.AllocPool.Alloc(exec.RetTypes(e), 1, 1) - } - } - for _, aggDesc := range v.AggFuncs { - if aggDesc.HasDistinct || len(aggDesc.OrderByItems) > 0 { - e.IsUnparallelExec = true - } - } - // When we set both tidb_hashagg_final_concurrency and tidb_hashagg_partial_concurrency to 1, - // we do not need to parallelly execute hash agg, - // and this action can be a workaround when meeting some unexpected situation using parallelExec. - if finalCon, partialCon := sessionVars.HashAggFinalConcurrency(), sessionVars.HashAggPartialConcurrency(); finalCon <= 0 || partialCon <= 0 || finalCon == 1 && partialCon == 1 { - e.IsUnparallelExec = true - } - partialOrdinal := 0 - exprCtx := b.ctx.GetExprCtx() - for i, aggDesc := range v.AggFuncs { - if e.IsUnparallelExec { - e.PartialAggFuncs = append(e.PartialAggFuncs, aggfuncs.Build(exprCtx, aggDesc, i)) - } else { - ordinal := []int{partialOrdinal} - partialOrdinal++ - if aggDesc.Name == ast.AggFuncAvg { - ordinal = append(ordinal, partialOrdinal+1) - partialOrdinal++ - } - partialAggDesc, finalDesc := aggDesc.Split(ordinal) - partialAggFunc := aggfuncs.Build(exprCtx, partialAggDesc, i) - finalAggFunc := aggfuncs.Build(exprCtx, finalDesc, i) - e.PartialAggFuncs = append(e.PartialAggFuncs, partialAggFunc) - e.FinalAggFuncs = append(e.FinalAggFuncs, finalAggFunc) - if partialAggDesc.Name == ast.AggFuncGroupConcat { - // For group_concat, finalAggFunc and partialAggFunc need shared `truncate` flag to do duplicate. 
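-				// Sharing the pointer keeps the truncation state consistent across the partial and final aggregation stages.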
- finalAggFunc.(interface{ SetTruncated(t *int32) }).SetTruncated( - partialAggFunc.(interface{ GetTruncated() *int32 }).GetTruncated(), - ) - } - } - if e.DefaultVal != nil { - value := aggDesc.GetDefaultValue() - e.DefaultVal.AppendDatum(i, &value) - } - } - - executor_metrics.ExecutorCounterHashAggExec.Inc() - return e -} - -func (b *executorBuilder) buildStreamAgg(v *plannercore.PhysicalStreamAgg) exec.Executor { - src := b.build(v.Children()[0]) - if b.err != nil { - return nil - } - return b.buildStreamAggFromChildExec(src, v) -} - -func (b *executorBuilder) buildStreamAggFromChildExec(childExec exec.Executor, v *plannercore.PhysicalStreamAgg) *aggregate.StreamAggExec { - exprCtx := b.ctx.GetExprCtx() - e := &aggregate.StreamAggExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), childExec), - GroupChecker: vecgroupchecker.NewVecGroupChecker(exprCtx.GetEvalCtx(), b.ctx.GetSessionVars().EnableVectorizedExpression, v.GroupByItems), - AggFuncs: make([]aggfuncs.AggFunc, 0, len(v.AggFuncs)), - } - - if len(v.GroupByItems) != 0 || aggregation.IsAllFirstRow(v.AggFuncs) { - e.DefaultVal = nil - } else { - // Only do this for final agg, see issue #35295, #30923 - if v.IsFinalAgg() { - e.DefaultVal = e.AllocPool.Alloc(exec.RetTypes(e), 1, 1) - } - } - for i, aggDesc := range v.AggFuncs { - aggFunc := aggfuncs.Build(exprCtx, aggDesc, i) - e.AggFuncs = append(e.AggFuncs, aggFunc) - if e.DefaultVal != nil { - value := aggDesc.GetDefaultValue() - e.DefaultVal.AppendDatum(i, &value) - } - } - - executor_metrics.ExecutorStreamAggExec.Inc() - return e -} - -func (b *executorBuilder) buildSelection(v *plannercore.PhysicalSelection) exec.Executor { - childExec := b.build(v.Children()[0]) - if b.err != nil { - return nil - } - e := &SelectionExec{ - selectionExecutorContext: newSelectionExecutorContext(b.ctx), - BaseExecutorV2: exec.NewBaseExecutorV2(b.ctx.GetSessionVars(), v.Schema(), v.ID(), childExec), - filters: v.Conditions, - } - return e -} - -func (b *executorBuilder) buildExpand(v *plannercore.PhysicalExpand) exec.Executor { - childExec := b.build(v.Children()[0]) - if b.err != nil { - return nil - } - levelES := make([]*expression.EvaluatorSuite, 0, len(v.LevelExprs)) - for _, exprs := range v.LevelExprs { - // column evaluator can always refer others inside expand. - // grouping column's nullability change should be seen as a new column projecting. - // since input inside expand logic should be targeted and reused for N times. - // column evaluator's swapping columns logic will pollute the input data. - levelE := expression.NewEvaluatorSuite(exprs, true) - levelES = append(levelES, levelE) - } - e := &ExpandExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), childExec), - numWorkers: int64(b.ctx.GetSessionVars().ProjectionConcurrency()), - levelEvaluatorSuits: levelES, - } - - // If the calculation row count for this Projection operator is smaller - // than a Chunk size, we turn back to the un-parallel Projection - // implementation to reduce the goroutine overhead. - if int64(v.StatsCount()) < int64(b.ctx.GetSessionVars().MaxChunkSize) { - e.numWorkers = 0 - } - - // Use un-parallel projection for query that write on memdb to avoid data race. 
- // See also https://github.com/pingcap/tidb/issues/26832 - if b.inUpdateStmt || b.inDeleteStmt || b.inInsertStmt || b.hasLock { - e.numWorkers = 0 - } - return e -} - -func (b *executorBuilder) buildProjection(v *plannercore.PhysicalProjection) exec.Executor { - childExec := b.build(v.Children()[0]) - if b.err != nil { - return nil - } - e := &ProjectionExec{ - projectionExecutorContext: newProjectionExecutorContext(b.ctx), - BaseExecutorV2: exec.NewBaseExecutorV2(b.ctx.GetSessionVars(), v.Schema(), v.ID(), childExec), - numWorkers: int64(b.ctx.GetSessionVars().ProjectionConcurrency()), - evaluatorSuit: expression.NewEvaluatorSuite(v.Exprs, v.AvoidColumnEvaluator), - calculateNoDelay: v.CalculateNoDelay, - } - - // If the calculation row count for this Projection operator is smaller - // than a Chunk size, we turn back to the un-parallel Projection - // implementation to reduce the goroutine overhead. - if int64(v.StatsCount()) < int64(b.ctx.GetSessionVars().MaxChunkSize) { - e.numWorkers = 0 - } - - // Use un-parallel projection for query that write on memdb to avoid data race. - // See also https://github.com/pingcap/tidb/issues/26832 - if b.inUpdateStmt || b.inDeleteStmt || b.inInsertStmt || b.hasLock { - e.numWorkers = 0 - } - return e -} - -func (b *executorBuilder) buildTableDual(v *plannercore.PhysicalTableDual) exec.Executor { - if v.RowCount != 0 && v.RowCount != 1 { - b.err = errors.Errorf("buildTableDual failed, invalid row count for dual table: %v", v.RowCount) - return nil - } - base := exec.NewBaseExecutorV2(b.ctx.GetSessionVars(), v.Schema(), v.ID()) - base.SetInitCap(v.RowCount) - e := &TableDualExec{ - BaseExecutorV2: base, - numDualRows: v.RowCount, - } - return e -} - -// `getSnapshotTS` returns for-update-ts if in insert/update/delete/lock statement otherwise the isolation read ts -// Please notice that in RC isolation, the above two ts are the same -func (b *executorBuilder) getSnapshotTS() (ts uint64, err error) { - if b.forDataReaderBuilder { - return b.dataReaderTS, nil - } - - txnManager := sessiontxn.GetTxnManager(b.ctx) - if b.inInsertStmt || b.inUpdateStmt || b.inDeleteStmt || b.inSelectLockStmt { - return txnManager.GetStmtForUpdateTS() - } - return txnManager.GetStmtReadTS() -} - -// getSnapshot get the appropriate snapshot from txnManager and set -// the relevant snapshot options before return. 
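-// Like getSnapshotTS above, it uses the statement's for-update timestamp for insert/update/delete/lock statements and the read timestamp otherwise.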
-func (b *executorBuilder) getSnapshot() (kv.Snapshot, error) { - var snapshot kv.Snapshot - var err error - - txnManager := sessiontxn.GetTxnManager(b.ctx) - if b.inInsertStmt || b.inUpdateStmt || b.inDeleteStmt || b.inSelectLockStmt { - snapshot, err = txnManager.GetSnapshotWithStmtForUpdateTS() - } else { - snapshot, err = txnManager.GetSnapshotWithStmtReadTS() - } - if err != nil { - return nil, err - } - - sessVars := b.ctx.GetSessionVars() - replicaReadType := sessVars.GetReplicaRead() - snapshot.SetOption(kv.ReadReplicaScope, b.readReplicaScope) - snapshot.SetOption(kv.TaskID, sessVars.StmtCtx.TaskID) - snapshot.SetOption(kv.TiKVClientReadTimeout, sessVars.GetTiKVClientReadTimeout()) - snapshot.SetOption(kv.ResourceGroupName, sessVars.StmtCtx.ResourceGroupName) - snapshot.SetOption(kv.ExplicitRequestSourceType, sessVars.ExplicitRequestSourceType) - - if replicaReadType.IsClosestRead() && b.readReplicaScope != kv.GlobalTxnScope { - snapshot.SetOption(kv.MatchStoreLabels, []*metapb.StoreLabel{ - { - Key: placement.DCLabelKey, - Value: b.readReplicaScope, - }, - }) - } - - return snapshot, nil -} - -func (b *executorBuilder) buildMemTable(v *plannercore.PhysicalMemTable) exec.Executor { - switch v.DBName.L { - case util.MetricSchemaName.L: - return &MemTableReaderExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - table: v.Table, - retriever: &MetricRetriever{ - table: v.Table, - extractor: v.Extractor.(*plannercore.MetricTableExtractor), - }, - } - case util.InformationSchemaName.L: - switch v.Table.Name.L { - case strings.ToLower(infoschema.TableClusterConfig): - return &MemTableReaderExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - table: v.Table, - retriever: &clusterConfigRetriever{ - extractor: v.Extractor.(*plannercore.ClusterTableExtractor), - }, - } - case strings.ToLower(infoschema.TableClusterLoad): - return &MemTableReaderExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - table: v.Table, - retriever: &clusterServerInfoRetriever{ - extractor: v.Extractor.(*plannercore.ClusterTableExtractor), - serverInfoType: diagnosticspb.ServerInfoType_LoadInfo, - }, - } - case strings.ToLower(infoschema.TableClusterHardware): - return &MemTableReaderExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - table: v.Table, - retriever: &clusterServerInfoRetriever{ - extractor: v.Extractor.(*plannercore.ClusterTableExtractor), - serverInfoType: diagnosticspb.ServerInfoType_HardwareInfo, - }, - } - case strings.ToLower(infoschema.TableClusterSystemInfo): - return &MemTableReaderExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - table: v.Table, - retriever: &clusterServerInfoRetriever{ - extractor: v.Extractor.(*plannercore.ClusterTableExtractor), - serverInfoType: diagnosticspb.ServerInfoType_SystemInfo, - }, - } - case strings.ToLower(infoschema.TableClusterLog): - return &MemTableReaderExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - table: v.Table, - retriever: &clusterLogRetriever{ - extractor: v.Extractor.(*plannercore.ClusterLogTableExtractor), - }, - } - case strings.ToLower(infoschema.TableTiDBHotRegionsHistory): - return &MemTableReaderExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - table: v.Table, - retriever: &hotRegionsHistoryRetriver{ - extractor: v.Extractor.(*plannercore.HotRegionsHistoryTableExtractor), - }, - } - case strings.ToLower(infoschema.TableInspectionResult): - return &MemTableReaderExec{ - BaseExecutor: 
exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - table: v.Table, - retriever: &inspectionResultRetriever{ - extractor: v.Extractor.(*plannercore.InspectionResultTableExtractor), - timeRange: v.QueryTimeRange, - }, - } - case strings.ToLower(infoschema.TableInspectionSummary): - return &MemTableReaderExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - table: v.Table, - retriever: &inspectionSummaryRetriever{ - table: v.Table, - extractor: v.Extractor.(*plannercore.InspectionSummaryTableExtractor), - timeRange: v.QueryTimeRange, - }, - } - case strings.ToLower(infoschema.TableInspectionRules): - return &MemTableReaderExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - table: v.Table, - retriever: &inspectionRuleRetriever{ - extractor: v.Extractor.(*plannercore.InspectionRuleTableExtractor), - }, - } - case strings.ToLower(infoschema.TableMetricSummary): - return &MemTableReaderExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - table: v.Table, - retriever: &MetricsSummaryRetriever{ - table: v.Table, - extractor: v.Extractor.(*plannercore.MetricSummaryTableExtractor), - timeRange: v.QueryTimeRange, - }, - } - case strings.ToLower(infoschema.TableMetricSummaryByLabel): - return &MemTableReaderExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - table: v.Table, - retriever: &MetricsSummaryByLabelRetriever{ - table: v.Table, - extractor: v.Extractor.(*plannercore.MetricSummaryTableExtractor), - timeRange: v.QueryTimeRange, - }, - } - case strings.ToLower(infoschema.TableTiKVRegionPeers): - return &MemTableReaderExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - table: v.Table, - retriever: &tikvRegionPeersRetriever{ - extractor: v.Extractor.(*plannercore.TikvRegionPeersExtractor), - }, - } - case strings.ToLower(infoschema.TableSchemata), - strings.ToLower(infoschema.TableStatistics), - strings.ToLower(infoschema.TableTiDBIndexes), - strings.ToLower(infoschema.TableViews), - strings.ToLower(infoschema.TableTables), - strings.ToLower(infoschema.TableReferConst), - strings.ToLower(infoschema.TableSequences), - strings.ToLower(infoschema.TablePartitions), - strings.ToLower(infoschema.TableEngines), - strings.ToLower(infoschema.TableCollations), - strings.ToLower(infoschema.TableAnalyzeStatus), - strings.ToLower(infoschema.TableClusterInfo), - strings.ToLower(infoschema.TableProfiling), - strings.ToLower(infoschema.TableCharacterSets), - strings.ToLower(infoschema.TableKeyColumn), - strings.ToLower(infoschema.TableUserPrivileges), - strings.ToLower(infoschema.TableMetricTables), - strings.ToLower(infoschema.TableCollationCharacterSetApplicability), - strings.ToLower(infoschema.TableProcesslist), - strings.ToLower(infoschema.ClusterTableProcesslist), - strings.ToLower(infoschema.TableTiKVRegionStatus), - strings.ToLower(infoschema.TableTiDBHotRegions), - strings.ToLower(infoschema.TableSessionVar), - strings.ToLower(infoschema.TableConstraints), - strings.ToLower(infoschema.TableTiFlashReplica), - strings.ToLower(infoschema.TableTiDBServersInfo), - strings.ToLower(infoschema.TableTiKVStoreStatus), - strings.ToLower(infoschema.TableClientErrorsSummaryGlobal), - strings.ToLower(infoschema.TableClientErrorsSummaryByUser), - strings.ToLower(infoschema.TableClientErrorsSummaryByHost), - strings.ToLower(infoschema.TableAttributes), - strings.ToLower(infoschema.TablePlacementPolicies), - strings.ToLower(infoschema.TableTrxSummary), - strings.ToLower(infoschema.TableVariablesInfo), - 
strings.ToLower(infoschema.TableUserAttributes), - strings.ToLower(infoschema.ClusterTableTrxSummary), - strings.ToLower(infoschema.TableMemoryUsage), - strings.ToLower(infoschema.TableMemoryUsageOpsHistory), - strings.ToLower(infoschema.ClusterTableMemoryUsage), - strings.ToLower(infoschema.ClusterTableMemoryUsageOpsHistory), - strings.ToLower(infoschema.TableResourceGroups), - strings.ToLower(infoschema.TableRunawayWatches), - strings.ToLower(infoschema.TableCheckConstraints), - strings.ToLower(infoschema.TableTiDBCheckConstraints), - strings.ToLower(infoschema.TableKeywords), - strings.ToLower(infoschema.TableTiDBIndexUsage), - strings.ToLower(infoschema.ClusterTableTiDBIndexUsage): - memTracker := memory.NewTracker(v.ID(), -1) - memTracker.AttachTo(b.ctx.GetSessionVars().StmtCtx.MemTracker) - return &MemTableReaderExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - table: v.Table, - retriever: &memtableRetriever{ - table: v.Table, - columns: v.Columns, - extractor: v.Extractor, - memTracker: memTracker, - }, - } - case strings.ToLower(infoschema.TableTiDBTrx), - strings.ToLower(infoschema.ClusterTableTiDBTrx): - return &MemTableReaderExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - table: v.Table, - retriever: &tidbTrxTableRetriever{ - table: v.Table, - columns: v.Columns, - }, - } - case strings.ToLower(infoschema.TableDataLockWaits): - return &MemTableReaderExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - table: v.Table, - retriever: &dataLockWaitsTableRetriever{ - table: v.Table, - columns: v.Columns, - }, - } - case strings.ToLower(infoschema.TableDeadlocks), - strings.ToLower(infoschema.ClusterTableDeadlocks): - return &MemTableReaderExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - table: v.Table, - retriever: &deadlocksTableRetriever{ - table: v.Table, - columns: v.Columns, - }, - } - case strings.ToLower(infoschema.TableStatementsSummary), - strings.ToLower(infoschema.TableStatementsSummaryHistory), - strings.ToLower(infoschema.TableStatementsSummaryEvicted), - strings.ToLower(infoschema.ClusterTableStatementsSummary), - strings.ToLower(infoschema.ClusterTableStatementsSummaryHistory), - strings.ToLower(infoschema.ClusterTableStatementsSummaryEvicted): - var extractor *plannercore.StatementsSummaryExtractor - if v.Extractor != nil { - extractor = v.Extractor.(*plannercore.StatementsSummaryExtractor) - } - return &MemTableReaderExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - table: v.Table, - retriever: buildStmtSummaryRetriever(v.Table, v.Columns, extractor), - } - case strings.ToLower(infoschema.TableColumns): - return &MemTableReaderExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - table: v.Table, - retriever: &hugeMemTableRetriever{ - table: v.Table, - columns: v.Columns, - extractor: v.Extractor.(*plannercore.ColumnsTableExtractor), - viewSchemaMap: make(map[int64]*expression.Schema), - viewOutputNamesMap: make(map[int64]types.NameSlice), - }, - } - case strings.ToLower(infoschema.TableSlowQuery), strings.ToLower(infoschema.ClusterTableSlowLog): - memTracker := memory.NewTracker(v.ID(), -1) - memTracker.AttachTo(b.ctx.GetSessionVars().StmtCtx.MemTracker) - return &MemTableReaderExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - table: v.Table, - retriever: &slowQueryRetriever{ - table: v.Table, - outputCols: v.Columns, - extractor: v.Extractor.(*plannercore.SlowQueryExtractor), - memTracker: memTracker, - }, - } - 
case strings.ToLower(infoschema.TableStorageStats): - return &MemTableReaderExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - table: v.Table, - retriever: &tableStorageStatsRetriever{ - table: v.Table, - outputCols: v.Columns, - extractor: v.Extractor.(*plannercore.TableStorageStatsExtractor), - }, - } - case strings.ToLower(infoschema.TableDDLJobs): - loc := b.ctx.GetSessionVars().Location() - ddlJobRetriever := DDLJobRetriever{TZLoc: loc} - return &DDLJobsReaderExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - is: b.is, - DDLJobRetriever: ddlJobRetriever, - } - case strings.ToLower(infoschema.TableTiFlashTables), - strings.ToLower(infoschema.TableTiFlashSegments): - return &MemTableReaderExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - table: v.Table, - retriever: &TiFlashSystemTableRetriever{ - table: v.Table, - outputCols: v.Columns, - extractor: v.Extractor.(*plannercore.TiFlashSystemTableExtractor), - }, - } - } - } - tb, _ := b.is.TableByID(v.Table.ID) - return &TableScanExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - t: tb, - columns: v.Columns, - } -} - -func (b *executorBuilder) buildSort(v *plannercore.PhysicalSort) exec.Executor { - childExec := b.build(v.Children()[0]) - if b.err != nil { - return nil - } - sortExec := sortexec.SortExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), childExec), - ByItems: v.ByItems, - ExecSchema: v.Schema(), - } - executor_metrics.ExecutorCounterSortExec.Inc() - return &sortExec -} - -func (b *executorBuilder) buildTopN(v *plannercore.PhysicalTopN) exec.Executor { - childExec := b.build(v.Children()[0]) - if b.err != nil { - return nil - } - sortExec := sortexec.SortExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), childExec), - ByItems: v.ByItems, - ExecSchema: v.Schema(), - } - executor_metrics.ExecutorCounterTopNExec.Inc() - return &sortexec.TopNExec{ - SortExec: sortExec, - Limit: &plannercore.PhysicalLimit{Count: v.Count, Offset: v.Offset}, - Concurrency: b.ctx.GetSessionVars().Concurrency.ExecutorConcurrency, - } -} - -func (b *executorBuilder) buildApply(v *plannercore.PhysicalApply) exec.Executor { - var ( - innerPlan base.PhysicalPlan - outerPlan base.PhysicalPlan - ) - if v.InnerChildIdx == 0 { - innerPlan = v.Children()[0] - outerPlan = v.Children()[1] - } else { - innerPlan = v.Children()[1] - outerPlan = v.Children()[0] - } - v.OuterSchema = coreusage.ExtractCorColumnsBySchema4PhysicalPlan(innerPlan, outerPlan.Schema()) - leftChild := b.build(v.Children()[0]) - if b.err != nil { - return nil - } - rightChild := b.build(v.Children()[1]) - if b.err != nil { - return nil - } - // test is in the explain/naaj.test#part5. - // although we prepared the NAEqualConditions, but for Apply mode, we still need move it to other conditions like eq condition did here. - otherConditions := append(expression.ScalarFuncs2Exprs(v.EqualConditions), expression.ScalarFuncs2Exprs(v.NAEqualConditions)...) - otherConditions = append(otherConditions, v.OtherConditions...) 
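-	// Apply evaluates its join conditions row by row, so both the EQ and NA-EQ conditions are folded into the generic other-condition list above.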
- defaultValues := v.DefaultValues - if defaultValues == nil { - defaultValues = make([]types.Datum, v.Children()[v.InnerChildIdx].Schema().Len()) - } - outerExec, innerExec := leftChild, rightChild - outerFilter, innerFilter := v.LeftConditions, v.RightConditions - if v.InnerChildIdx == 0 { - outerExec, innerExec = rightChild, leftChild - outerFilter, innerFilter = v.RightConditions, v.LeftConditions - } - tupleJoiner := join.NewJoiner(b.ctx, v.JoinType, v.InnerChildIdx == 0, - defaultValues, otherConditions, exec.RetTypes(leftChild), exec.RetTypes(rightChild), nil, false) - serialExec := &join.NestedLoopApplyExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), outerExec, innerExec), - InnerExec: innerExec, - OuterExec: outerExec, - OuterFilter: outerFilter, - InnerFilter: innerFilter, - Outer: v.JoinType != plannercore.InnerJoin, - Joiner: tupleJoiner, - OuterSchema: v.OuterSchema, - Sctx: b.ctx, - CanUseCache: v.CanUseCache, - } - executor_metrics.ExecutorCounterNestedLoopApplyExec.Inc() - - // try parallel mode - if v.Concurrency > 1 { - innerExecs := make([]exec.Executor, 0, v.Concurrency) - innerFilters := make([]expression.CNFExprs, 0, v.Concurrency) - corCols := make([][]*expression.CorrelatedColumn, 0, v.Concurrency) - joiners := make([]join.Joiner, 0, v.Concurrency) - for i := 0; i < v.Concurrency; i++ { - clonedInnerPlan, err := plannercore.SafeClone(v.SCtx(), innerPlan) - if err != nil { - b.err = nil - return serialExec - } - corCol := coreusage.ExtractCorColumnsBySchema4PhysicalPlan(clonedInnerPlan, outerPlan.Schema()) - clonedInnerExec := b.build(clonedInnerPlan) - if b.err != nil { - b.err = nil - return serialExec - } - innerExecs = append(innerExecs, clonedInnerExec) - corCols = append(corCols, corCol) - innerFilters = append(innerFilters, innerFilter.Clone()) - joiners = append(joiners, join.NewJoiner(b.ctx, v.JoinType, v.InnerChildIdx == 0, - defaultValues, otherConditions, exec.RetTypes(leftChild), exec.RetTypes(rightChild), nil, false)) - } - - allExecs := append([]exec.Executor{outerExec}, innerExecs...) 
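-		// The outer executor and every cloned inner executor are registered as children of the parallel apply executor.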
- - return &ParallelNestedLoopApplyExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), allExecs...), - innerExecs: innerExecs, - outerExec: outerExec, - outerFilter: outerFilter, - innerFilter: innerFilters, - outer: v.JoinType != plannercore.InnerJoin, - joiners: joiners, - corCols: corCols, - concurrency: v.Concurrency, - useCache: v.CanUseCache, - } - } - return serialExec -} - -func (b *executorBuilder) buildMaxOneRow(v *plannercore.PhysicalMaxOneRow) exec.Executor { - childExec := b.build(v.Children()[0]) - if b.err != nil { - return nil - } - base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), childExec) - base.SetInitCap(2) - base.SetMaxChunkSize(2) - e := &MaxOneRowExec{BaseExecutor: base} - return e -} - -func (b *executorBuilder) buildUnionAll(v *plannercore.PhysicalUnionAll) exec.Executor { - childExecs := make([]exec.Executor, len(v.Children())) - for i, child := range v.Children() { - childExecs[i] = b.build(child) - if b.err != nil { - return nil - } - } - e := &unionexec.UnionExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), childExecs...), - Concurrency: b.ctx.GetSessionVars().UnionConcurrency(), - } - return e -} - -func buildHandleColsForSplit(sc *stmtctx.StatementContext, tbInfo *model.TableInfo) plannerutil.HandleCols { - if tbInfo.IsCommonHandle { - primaryIdx := tables.FindPrimaryIndex(tbInfo) - tableCols := make([]*expression.Column, len(tbInfo.Columns)) - for i, col := range tbInfo.Columns { - tableCols[i] = &expression.Column{ - ID: col.ID, - RetType: &col.FieldType, - } - } - for i, pkCol := range primaryIdx.Columns { - tableCols[pkCol.Offset].Index = i - } - return plannerutil.NewCommonHandleCols(sc, tbInfo, primaryIdx, tableCols) - } - intCol := &expression.Column{ - RetType: types.NewFieldType(mysql.TypeLonglong), - } - return plannerutil.NewIntHandleCols(intCol) -} - -func (b *executorBuilder) buildSplitRegion(v *plannercore.SplitRegion) exec.Executor { - base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()) - base.SetInitCap(1) - base.SetMaxChunkSize(1) - if v.IndexInfo != nil { - return &SplitIndexRegionExec{ - BaseExecutor: base, - tableInfo: v.TableInfo, - partitionNames: v.PartitionNames, - indexInfo: v.IndexInfo, - lower: v.Lower, - upper: v.Upper, - num: v.Num, - valueLists: v.ValueLists, - } - } - handleCols := buildHandleColsForSplit(b.ctx.GetSessionVars().StmtCtx, v.TableInfo) - if len(v.ValueLists) > 0 { - return &SplitTableRegionExec{ - BaseExecutor: base, - tableInfo: v.TableInfo, - partitionNames: v.PartitionNames, - handleCols: handleCols, - valueLists: v.ValueLists, - } - } - return &SplitTableRegionExec{ - BaseExecutor: base, - tableInfo: v.TableInfo, - partitionNames: v.PartitionNames, - handleCols: handleCols, - lower: v.Lower, - upper: v.Upper, - num: v.Num, - } -} - -func (b *executorBuilder) buildUpdate(v *plannercore.Update) exec.Executor { - b.inUpdateStmt = true - tblID2table := make(map[int64]table.Table, len(v.TblColPosInfos)) - multiUpdateOnSameTable := make(map[int64]bool) - for _, info := range v.TblColPosInfos { - tbl, _ := b.is.TableByID(info.TblID) - if _, ok := tblID2table[info.TblID]; ok { - multiUpdateOnSameTable[info.TblID] = true - } - tblID2table[info.TblID] = tbl - if len(v.PartitionedTable) > 0 { - // The v.PartitionedTable collects the partitioned table. - // Replace the original table with the partitioned table to support partition selection. - // e.g. 
update t partition (p0, p1), the new values are not belong to the given set p0, p1 - // Using the table in v.PartitionedTable returns a proper error, while using the original table can't. - for _, p := range v.PartitionedTable { - if info.TblID == p.Meta().ID { - tblID2table[info.TblID] = p - } - } - } - } - if b.err = b.updateForUpdateTS(); b.err != nil { - return nil - } - - selExec := b.build(v.SelectPlan) - if b.err != nil { - return nil - } - base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), selExec) - base.SetInitCap(chunk.ZeroCapacity) - var assignFlag []int - assignFlag, b.err = getAssignFlag(b.ctx, v, selExec.Schema().Len()) - if b.err != nil { - return nil - } - // should use the new tblID2table, since the update's schema may have been changed in Execstmt. - b.err = plannercore.CheckUpdateList(assignFlag, v, tblID2table) - if b.err != nil { - return nil - } - updateExec := &UpdateExec{ - BaseExecutor: base, - OrderedList: v.OrderedList, - allAssignmentsAreConstant: v.AllAssignmentsAreConstant, - virtualAssignmentsOffset: v.VirtualAssignmentsOffset, - multiUpdateOnSameTable: multiUpdateOnSameTable, - tblID2table: tblID2table, - tblColPosInfos: v.TblColPosInfos, - assignFlag: assignFlag, - } - updateExec.fkChecks, b.err = buildTblID2FKCheckExecs(b.ctx, tblID2table, v.FKChecks) - if b.err != nil { - return nil - } - updateExec.fkCascades, b.err = b.buildTblID2FKCascadeExecs(tblID2table, v.FKCascades) - if b.err != nil { - return nil - } - return updateExec -} - -func getAssignFlag(ctx sessionctx.Context, v *plannercore.Update, schemaLen int) ([]int, error) { - assignFlag := make([]int, schemaLen) - for i := range assignFlag { - assignFlag[i] = -1 - } - for _, assign := range v.OrderedList { - if !ctx.GetSessionVars().AllowWriteRowID && assign.Col.ID == model.ExtraHandleID { - return nil, errors.Errorf("insert, update and replace statements for _tidb_rowid are not supported") - } - tblIdx, found := v.TblColPosInfos.FindTblIdx(assign.Col.Index) - if found { - colIdx := assign.Col.Index - assignFlag[colIdx] = tblIdx - } - } - return assignFlag, nil -} - -func (b *executorBuilder) buildDelete(v *plannercore.Delete) exec.Executor { - b.inDeleteStmt = true - tblID2table := make(map[int64]table.Table, len(v.TblColPosInfos)) - for _, info := range v.TblColPosInfos { - tblID2table[info.TblID], _ = b.is.TableByID(info.TblID) - } - - if b.err = b.updateForUpdateTS(); b.err != nil { - return nil - } - - selExec := b.build(v.SelectPlan) - if b.err != nil { - return nil - } - base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), selExec) - base.SetInitCap(chunk.ZeroCapacity) - deleteExec := &DeleteExec{ - BaseExecutor: base, - tblID2Table: tblID2table, - IsMultiTable: v.IsMultiTable, - tblColPosInfos: v.TblColPosInfos, - } - deleteExec.fkChecks, b.err = buildTblID2FKCheckExecs(b.ctx, tblID2table, v.FKChecks) - if b.err != nil { - return nil - } - deleteExec.fkCascades, b.err = b.buildTblID2FKCascadeExecs(tblID2table, v.FKCascades) - if b.err != nil { - return nil - } - return deleteExec -} - -func (b *executorBuilder) updateForUpdateTS() error { - // GetStmtForUpdateTS will auto update the for update ts if it is necessary - _, err := sessiontxn.GetTxnManager(b.ctx).GetStmtForUpdateTS() - return err -} - -func (b *executorBuilder) buildAnalyzeIndexPushdown(task plannercore.AnalyzeIndexTask, opts map[ast.AnalyzeOptionType]uint64, autoAnalyze string) *analyzeTask { - job := &statistics.AnalyzeJob{DBName: task.DBName, TableName: task.TableName, PartitionName: task.PartitionName, JobInfo: 
autoAnalyze + "analyze index " + task.IndexInfo.Name.O} - _, offset := timeutil.Zone(b.ctx.GetSessionVars().Location()) - sc := b.ctx.GetSessionVars().StmtCtx - startTS, err := b.getSnapshotTS() - if err != nil { - b.err = err - return nil - } - failpoint.Inject("injectAnalyzeSnapshot", func(val failpoint.Value) { - startTS = uint64(val.(int)) - }) - concurrency := adaptiveAnlayzeDistSQLConcurrency(context.Background(), b.ctx) - base := baseAnalyzeExec{ - ctx: b.ctx, - tableID: task.TableID, - concurrency: concurrency, - analyzePB: &tipb.AnalyzeReq{ - Tp: tipb.AnalyzeType_TypeIndex, - Flags: sc.PushDownFlags(), - TimeZoneOffset: offset, - }, - opts: opts, - job: job, - snapshot: startTS, - } - e := &AnalyzeIndexExec{ - baseAnalyzeExec: base, - isCommonHandle: task.TblInfo.IsCommonHandle, - idxInfo: task.IndexInfo, - } - topNSize := new(int32) - *topNSize = int32(opts[ast.AnalyzeOptNumTopN]) - statsVersion := new(int32) - *statsVersion = int32(task.StatsVersion) - e.analyzePB.IdxReq = &tipb.AnalyzeIndexReq{ - BucketSize: int64(opts[ast.AnalyzeOptNumBuckets]), - NumColumns: int32(len(task.IndexInfo.Columns)), - TopNSize: topNSize, - Version: statsVersion, - SketchSize: statistics.MaxSketchSize, - } - if e.isCommonHandle && e.idxInfo.Primary { - e.analyzePB.Tp = tipb.AnalyzeType_TypeCommonHandle - } - depth := int32(opts[ast.AnalyzeOptCMSketchDepth]) - width := int32(opts[ast.AnalyzeOptCMSketchWidth]) - e.analyzePB.IdxReq.CmsketchDepth = &depth - e.analyzePB.IdxReq.CmsketchWidth = &width - return &analyzeTask{taskType: idxTask, idxExec: e, job: job} -} - -func (b *executorBuilder) buildAnalyzeSamplingPushdown( - task plannercore.AnalyzeColumnsTask, - opts map[ast.AnalyzeOptionType]uint64, - schemaForVirtualColEval *expression.Schema, -) *analyzeTask { - if task.V2Options != nil { - opts = task.V2Options.FilledOpts - } - availableIdx := make([]*model.IndexInfo, 0, len(task.Indexes)) - colGroups := make([]*tipb.AnalyzeColumnGroup, 0, len(task.Indexes)) - if len(task.Indexes) > 0 { - for _, idx := range task.Indexes { - availableIdx = append(availableIdx, idx) - colGroup := &tipb.AnalyzeColumnGroup{ - ColumnOffsets: make([]int64, 0, len(idx.Columns)), - } - for _, col := range idx.Columns { - colGroup.ColumnOffsets = append(colGroup.ColumnOffsets, int64(col.Offset)) - } - colGroups = append(colGroups, colGroup) - } - } - - _, offset := timeutil.Zone(b.ctx.GetSessionVars().Location()) - sc := b.ctx.GetSessionVars().StmtCtx - startTS, err := b.getSnapshotTS() - if err != nil { - b.err = err - return nil - } - failpoint.Inject("injectAnalyzeSnapshot", func(val failpoint.Value) { - startTS = uint64(val.(int)) - }) - statsHandle := domain.GetDomain(b.ctx).StatsHandle() - count, modifyCount, err := statsHandle.StatsMetaCountAndModifyCount(task.TableID.GetStatisticsID()) - if err != nil { - b.err = err - return nil - } - failpoint.Inject("injectBaseCount", func(val failpoint.Value) { - count = int64(val.(int)) - }) - failpoint.Inject("injectBaseModifyCount", func(val failpoint.Value) { - modifyCount = int64(val.(int)) - }) - sampleRate := new(float64) - var sampleRateReason string - if opts[ast.AnalyzeOptNumSamples] == 0 { - *sampleRate = math.Float64frombits(opts[ast.AnalyzeOptSampleRate]) - if *sampleRate < 0 { - *sampleRate, sampleRateReason = b.getAdjustedSampleRate(task) - if task.PartitionName != "" { - sc.AppendNote(errors.NewNoStackErrorf( - `Analyze use auto adjusted sample rate %f for table %s.%s's partition %s, reason to use this rate is "%s"`, - *sampleRate, - task.DBName, - task.TableName, 
- task.PartitionName, - sampleRateReason, - )) - } else { - sc.AppendNote(errors.NewNoStackErrorf( - `Analyze use auto adjusted sample rate %f for table %s.%s, reason to use this rate is "%s"`, - *sampleRate, - task.DBName, - task.TableName, - sampleRateReason, - )) - } - } - } - job := &statistics.AnalyzeJob{ - DBName: task.DBName, - TableName: task.TableName, - PartitionName: task.PartitionName, - SampleRateReason: sampleRateReason, - } - concurrency := adaptiveAnlayzeDistSQLConcurrency(context.Background(), b.ctx) - base := baseAnalyzeExec{ - ctx: b.ctx, - tableID: task.TableID, - concurrency: concurrency, - analyzePB: &tipb.AnalyzeReq{ - Tp: tipb.AnalyzeType_TypeFullSampling, - Flags: sc.PushDownFlags(), - TimeZoneOffset: offset, - }, - opts: opts, - job: job, - snapshot: startTS, - } - e := &AnalyzeColumnsExec{ - baseAnalyzeExec: base, - tableInfo: task.TblInfo, - colsInfo: task.ColsInfo, - handleCols: task.HandleCols, - indexes: availableIdx, - AnalyzeInfo: task.AnalyzeInfo, - schemaForVirtualColEval: schemaForVirtualColEval, - baseCount: count, - baseModifyCnt: modifyCount, - } - e.analyzePB.ColReq = &tipb.AnalyzeColumnsReq{ - BucketSize: int64(opts[ast.AnalyzeOptNumBuckets]), - SampleSize: int64(opts[ast.AnalyzeOptNumSamples]), - SampleRate: sampleRate, - SketchSize: statistics.MaxSketchSize, - ColumnsInfo: util.ColumnsToProto(task.ColsInfo, task.TblInfo.PKIsHandle, false), - ColumnGroups: colGroups, - } - if task.TblInfo != nil { - e.analyzePB.ColReq.PrimaryColumnIds = tables.TryGetCommonPkColumnIds(task.TblInfo) - if task.TblInfo.IsCommonHandle { - e.analyzePB.ColReq.PrimaryPrefixColumnIds = tables.PrimaryPrefixColumnIDs(task.TblInfo) - } - } - b.err = tables.SetPBColumnsDefaultValue(b.ctx.GetExprCtx(), e.analyzePB.ColReq.ColumnsInfo, task.ColsInfo) - return &analyzeTask{taskType: colTask, colExec: e, job: job} -} - -// getAdjustedSampleRate calculate the sample rate by the table size. If we cannot get the table size. We use the 0.001 as the default sample rate. -// From the paper "Random sampling for histogram construction: how much is enough?"'s Corollary 1 to Theorem 5, -// for a table size n, histogram size k, maximum relative error in bin size f, and error probability gamma, -// the minimum random sample size is -// -// r = 4 * k * ln(2*n/gamma) / f^2 -// -// If we take f = 0.5, gamma = 0.01, n =1e6, we would got r = 305.82* k. -// Since the there's log function over the table size n, the r grows slowly when the n increases. -// If we take n = 1e12, a 300*k sample still gives <= 0.66 bin size error with probability 0.99. -// So if we don't consider the top-n values, we can keep the sample size at 300*256. -// But we may take some top-n before building the histogram, so we increase the sample a little. -func (b *executorBuilder) getAdjustedSampleRate(task plannercore.AnalyzeColumnsTask) (sampleRate float64, reason string) { - statsHandle := domain.GetDomain(b.ctx).StatsHandle() - defaultRate := 0.001 - if statsHandle == nil { - return defaultRate, fmt.Sprintf("statsHandler is nil, use the default-rate=%v", defaultRate) - } - var statsTbl *statistics.Table - tid := task.TableID.GetStatisticsID() - if tid == task.TblInfo.ID { - statsTbl = statsHandle.GetTableStats(task.TblInfo) - } else { - statsTbl = statsHandle.GetPartitionStats(task.TblInfo, tid) - } - approxiCount, hasPD := b.getApproximateTableCountFromStorage(tid, task) - // If there's no stats meta and no pd, return the default rate. 
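// A quick arithmetic check of the figures quoted in the getAdjustedSampleRate comment
// above (natural logarithm assumed): with f = 0.5, gamma = 0.01 and n = 1e6,
//   r = 4*k*ln(2e8)/0.25 = 16*k*ln(2e8) ≈ 16*19.11*k ≈ 305.8*k,
// which matches the 305.82*k figure, and holding r = 300*k at n = 1e12 gives
//   f = sqrt(4*ln(2e14)/300) ≈ sqrt(0.44) ≈ 0.66,
// matching the <= 0.66 bin-size error mentioned there.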
- if statsTbl == nil && !hasPD { - return defaultRate, fmt.Sprintf("TiDB cannot get the row count of the table, use the default-rate=%v", defaultRate) - } - // If the count in stats_meta is still 0 and there's no information from pd side, we scan all rows. - if statsTbl.RealtimeCount == 0 && !hasPD { - return 1, "TiDB assumes that the table is empty and cannot get row count from PD, use sample-rate=1" - } - // we have issue https://github.com/pingcap/tidb/issues/29216. - // To do a workaround for this issue, we check the approxiCount from the pd side to do a comparison. - // If the count from the stats_meta is extremely smaller than the approximate count from the pd, - // we think that we meet this issue and use the approximate count to calculate the sample rate. - if float64(statsTbl.RealtimeCount*5) < approxiCount { - // Confirmed by TiKV side, the experience error rate of the approximate count is about 20%. - // So we increase the number to 150000 to reduce this error rate. - sampleRate = math.Min(1, 150000/approxiCount) - return sampleRate, fmt.Sprintf("Row count in stats_meta is much smaller compared with the row count got by PD, use min(1, 15000/%v) as the sample-rate=%v", approxiCount, sampleRate) - } - // If we don't go into the above if branch and we still detect the count is zero. Return 1 to prevent the dividing zero. - if statsTbl.RealtimeCount == 0 { - return 1, "TiDB assumes that the table is empty, use sample-rate=1" - } - // We are expected to scan about 100000 rows or so. - // Since there's tiny error rate around the count from the stats meta, we use 110000 to get a little big result - sampleRate = math.Min(1, config.DefRowsForSampleRate/float64(statsTbl.RealtimeCount)) - return sampleRate, fmt.Sprintf("use min(1, %v/%v) as the sample-rate=%v", config.DefRowsForSampleRate, statsTbl.RealtimeCount, sampleRate) -} - -func (b *executorBuilder) getApproximateTableCountFromStorage(tid int64, task plannercore.AnalyzeColumnsTask) (float64, bool) { - return pdhelper.GlobalPDHelper.GetApproximateTableCountFromStorage(context.Background(), b.ctx, tid, task.DBName, task.TableName, task.PartitionName) -} - -func (b *executorBuilder) buildAnalyzeColumnsPushdown( - task plannercore.AnalyzeColumnsTask, - opts map[ast.AnalyzeOptionType]uint64, - autoAnalyze string, - schemaForVirtualColEval *expression.Schema, -) *analyzeTask { - if task.StatsVersion == statistics.Version2 { - return b.buildAnalyzeSamplingPushdown(task, opts, schemaForVirtualColEval) - } - job := &statistics.AnalyzeJob{DBName: task.DBName, TableName: task.TableName, PartitionName: task.PartitionName, JobInfo: autoAnalyze + "analyze columns"} - cols := task.ColsInfo - if hasPkHist(task.HandleCols) { - colInfo := task.TblInfo.Columns[task.HandleCols.GetCol(0).Index] - cols = append([]*model.ColumnInfo{colInfo}, cols...) - } else if task.HandleCols != nil && !task.HandleCols.IsInt() { - cols = make([]*model.ColumnInfo, 0, len(task.ColsInfo)+task.HandleCols.NumCols()) - for i := 0; i < task.HandleCols.NumCols(); i++ { - cols = append(cols, task.TblInfo.Columns[task.HandleCols.GetCol(i).Index]) - } - cols = append(cols, task.ColsInfo...) 
- task.ColsInfo = cols - } - - _, offset := timeutil.Zone(b.ctx.GetSessionVars().Location()) - sc := b.ctx.GetSessionVars().StmtCtx - startTS, err := b.getSnapshotTS() - if err != nil { - b.err = err - return nil - } - failpoint.Inject("injectAnalyzeSnapshot", func(val failpoint.Value) { - startTS = uint64(val.(int)) - }) - concurrency := adaptiveAnlayzeDistSQLConcurrency(context.Background(), b.ctx) - base := baseAnalyzeExec{ - ctx: b.ctx, - tableID: task.TableID, - concurrency: concurrency, - analyzePB: &tipb.AnalyzeReq{ - Tp: tipb.AnalyzeType_TypeColumn, - Flags: sc.PushDownFlags(), - TimeZoneOffset: offset, - }, - opts: opts, - job: job, - snapshot: startTS, - } - e := &AnalyzeColumnsExec{ - baseAnalyzeExec: base, - colsInfo: task.ColsInfo, - handleCols: task.HandleCols, - AnalyzeInfo: task.AnalyzeInfo, - } - depth := int32(opts[ast.AnalyzeOptCMSketchDepth]) - width := int32(opts[ast.AnalyzeOptCMSketchWidth]) - e.analyzePB.ColReq = &tipb.AnalyzeColumnsReq{ - BucketSize: int64(opts[ast.AnalyzeOptNumBuckets]), - SampleSize: MaxRegionSampleSize, - SketchSize: statistics.MaxSketchSize, - ColumnsInfo: util.ColumnsToProto(cols, task.HandleCols != nil && task.HandleCols.IsInt(), false), - CmsketchDepth: &depth, - CmsketchWidth: &width, - } - if task.TblInfo != nil { - e.analyzePB.ColReq.PrimaryColumnIds = tables.TryGetCommonPkColumnIds(task.TblInfo) - if task.TblInfo.IsCommonHandle { - e.analyzePB.ColReq.PrimaryPrefixColumnIds = tables.PrimaryPrefixColumnIDs(task.TblInfo) - } - } - if task.CommonHandleInfo != nil { - topNSize := new(int32) - *topNSize = int32(opts[ast.AnalyzeOptNumTopN]) - statsVersion := new(int32) - *statsVersion = int32(task.StatsVersion) - e.analyzePB.IdxReq = &tipb.AnalyzeIndexReq{ - BucketSize: int64(opts[ast.AnalyzeOptNumBuckets]), - NumColumns: int32(len(task.CommonHandleInfo.Columns)), - TopNSize: topNSize, - Version: statsVersion, - } - depth := int32(opts[ast.AnalyzeOptCMSketchDepth]) - width := int32(opts[ast.AnalyzeOptCMSketchWidth]) - e.analyzePB.IdxReq.CmsketchDepth = &depth - e.analyzePB.IdxReq.CmsketchWidth = &width - e.analyzePB.IdxReq.SketchSize = statistics.MaxSketchSize - e.analyzePB.ColReq.PrimaryColumnIds = tables.TryGetCommonPkColumnIds(task.TblInfo) - e.analyzePB.Tp = tipb.AnalyzeType_TypeMixed - e.commonHandle = task.CommonHandleInfo - } - b.err = tables.SetPBColumnsDefaultValue(b.ctx.GetExprCtx(), e.analyzePB.ColReq.ColumnsInfo, cols) - return &analyzeTask{taskType: colTask, colExec: e, job: job} -} - -func (b *executorBuilder) buildAnalyze(v *plannercore.Analyze) exec.Executor { - gp := domain.GetDomain(b.ctx).StatsHandle().GPool() - e := &AnalyzeExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - tasks: make([]*analyzeTask, 0, len(v.ColTasks)+len(v.IdxTasks)), - opts: v.Opts, - OptionsMap: v.OptionsMap, - wg: util.NewWaitGroupPool(gp), - gp: gp, - errExitCh: make(chan struct{}), - } - autoAnalyze := "" - if b.ctx.GetSessionVars().InRestrictedSQL { - autoAnalyze = "auto " - } - exprCtx := b.ctx.GetExprCtx() - for _, task := range v.ColTasks { - columns, _, err := expression.ColumnInfos2ColumnsAndNames( - exprCtx, - model.NewCIStr(task.AnalyzeInfo.DBName), - task.TblInfo.Name, - task.ColsInfo, - task.TblInfo, - ) - if err != nil { - b.err = err - return nil - } - schema := expression.NewSchema(columns...) - e.tasks = append(e.tasks, b.buildAnalyzeColumnsPushdown(task, v.Opts, autoAnalyze, schema)) - // Other functions may set b.err, so we need to check it here. 
- if b.err != nil { - return nil - } - } - for _, task := range v.IdxTasks { - e.tasks = append(e.tasks, b.buildAnalyzeIndexPushdown(task, v.Opts, autoAnalyze)) - if b.err != nil { - return nil - } - } - return e -} - -// markChildrenUsedCols compares each child with the output schema, and mark -// each column of the child is used by output or not. -func markChildrenUsedCols(outputCols []*expression.Column, childSchemas ...*expression.Schema) (childrenUsed [][]int) { - childrenUsed = make([][]int, 0, len(childSchemas)) - markedOffsets := make(map[int]int) - // keep the original maybe reversed order. - for originalIdx, col := range outputCols { - markedOffsets[col.Index] = originalIdx - } - prefixLen := 0 - type intPair struct { - first int - second int - } - // for example here. - // left child schema: [col11] - // right child schema: [col21, col22] - // output schema is [col11, col22, col21], if not records the original derived order after physical resolve index. - // the lused will be [0], the rused will be [0,1], while the actual order is dismissed, [1,0] is correct for rused. - for _, childSchema := range childSchemas { - usedIdxPair := make([]intPair, 0, len(childSchema.Columns)) - for i := range childSchema.Columns { - if originalIdx, ok := markedOffsets[prefixLen+i]; ok { - usedIdxPair = append(usedIdxPair, intPair{first: originalIdx, second: i}) - } - } - // sort the used idxes according their original indexes derived after resolveIndex. - slices.SortFunc(usedIdxPair, func(a, b intPair) int { - return cmp.Compare(a.first, b.first) - }) - usedIdx := make([]int, 0, len(childSchema.Columns)) - for _, one := range usedIdxPair { - usedIdx = append(usedIdx, one.second) - } - childrenUsed = append(childrenUsed, usedIdx) - prefixLen += childSchema.Len() - } - return -} - -func (*executorBuilder) corColInDistPlan(plans []base.PhysicalPlan) bool { - for _, p := range plans { - switch x := p.(type) { - case *plannercore.PhysicalSelection: - for _, cond := range x.Conditions { - if len(expression.ExtractCorColumns(cond)) > 0 { - return true - } - } - case *plannercore.PhysicalProjection: - for _, expr := range x.Exprs { - if len(expression.ExtractCorColumns(expr)) > 0 { - return true - } - } - case *plannercore.PhysicalTopN: - for _, byItem := range x.ByItems { - if len(expression.ExtractCorColumns(byItem.Expr)) > 0 { - return true - } - } - case *plannercore.PhysicalTableScan: - for _, cond := range x.LateMaterializationFilterCondition { - if len(expression.ExtractCorColumns(cond)) > 0 { - return true - } - } - } - } - return false -} - -// corColInAccess checks whether there's correlated column in access conditions. 
-func (*executorBuilder) corColInAccess(p base.PhysicalPlan) bool { - var access []expression.Expression - switch x := p.(type) { - case *plannercore.PhysicalTableScan: - access = x.AccessCondition - case *plannercore.PhysicalIndexScan: - access = x.AccessCondition - } - for _, cond := range access { - if len(expression.ExtractCorColumns(cond)) > 0 { - return true - } - } - return false -} - -func (b *executorBuilder) newDataReaderBuilder(p base.PhysicalPlan) (*dataReaderBuilder, error) { - ts, err := b.getSnapshotTS() - if err != nil { - return nil, err - } - - builderForDataReader := *b - builderForDataReader.forDataReaderBuilder = true - builderForDataReader.dataReaderTS = ts - - return &dataReaderBuilder{ - plan: p, - executorBuilder: &builderForDataReader, - }, nil -} - -func (b *executorBuilder) buildIndexLookUpJoin(v *plannercore.PhysicalIndexJoin) exec.Executor { - outerExec := b.build(v.Children()[1-v.InnerChildIdx]) - if b.err != nil { - return nil - } - outerTypes := exec.RetTypes(outerExec) - innerPlan := v.Children()[v.InnerChildIdx] - innerTypes := make([]*types.FieldType, innerPlan.Schema().Len()) - for i, col := range innerPlan.Schema().Columns { - innerTypes[i] = col.RetType.Clone() - // The `innerTypes` would be called for `Datum.ConvertTo` when converting the columns from outer table - // to build hash map or construct lookup keys. So we need to modify its flen otherwise there would be - // truncate error. See issue https://github.com/pingcap/tidb/issues/21232 for example. - if innerTypes[i].EvalType() == types.ETString { - innerTypes[i].SetFlen(types.UnspecifiedLength) - } - } - - // Use the probe table's collation. - for i, col := range v.OuterHashKeys { - outerTypes[col.Index] = outerTypes[col.Index].Clone() - outerTypes[col.Index].SetCollate(innerTypes[v.InnerHashKeys[i].Index].GetCollate()) - outerTypes[col.Index].SetFlag(col.RetType.GetFlag()) - } - - // We should use JoinKey to construct the type information using by hashing, instead of using the child's schema directly. - // When a hybrid type column is hashed multiple times, we need to distinguish what field types are used. - // For example, the condition `enum = int and enum = string`, we should use ETInt to hash the first column, - // and use ETString to hash the second column, although they may be the same column. 
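// A hypothetical illustration: with an inner ENUM column e joined through both
// `e = t2.i` (an INT column) and `e = t2.s` (a VARCHAR column), the two hash keys refer
// to the same inner column, yet the first must be hashed as an integer (ETInt) and the
// second as a string (ETString), which is why the per-key hash types below are cloned
// per join key instead of being taken once from the child schema.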
- innerHashTypes := make([]*types.FieldType, len(v.InnerHashKeys)) - outerHashTypes := make([]*types.FieldType, len(v.OuterHashKeys)) - for i, col := range v.InnerHashKeys { - innerHashTypes[i] = innerTypes[col.Index].Clone() - innerHashTypes[i].SetFlag(col.RetType.GetFlag()) - } - for i, col := range v.OuterHashKeys { - outerHashTypes[i] = outerTypes[col.Index].Clone() - outerHashTypes[i].SetFlag(col.RetType.GetFlag()) - } - - var ( - outerFilter []expression.Expression - leftTypes, rightTypes []*types.FieldType - ) - - if v.InnerChildIdx == 0 { - leftTypes, rightTypes = innerTypes, outerTypes - outerFilter = v.RightConditions - if len(v.LeftConditions) > 0 { - b.err = errors.Annotate(exeerrors.ErrBuildExecutor, "join's inner condition should be empty") - return nil - } - } else { - leftTypes, rightTypes = outerTypes, innerTypes - outerFilter = v.LeftConditions - if len(v.RightConditions) > 0 { - b.err = errors.Annotate(exeerrors.ErrBuildExecutor, "join's inner condition should be empty") - return nil - } - } - defaultValues := v.DefaultValues - if defaultValues == nil { - defaultValues = make([]types.Datum, len(innerTypes)) - } - hasPrefixCol := false - for _, l := range v.IdxColLens { - if l != types.UnspecifiedLength { - hasPrefixCol = true - break - } - } - - readerBuilder, err := b.newDataReaderBuilder(innerPlan) - if err != nil { - b.err = err - return nil - } - - e := &join.IndexLookUpJoin{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), outerExec), - OuterCtx: join.OuterCtx{ - RowTypes: outerTypes, - HashTypes: outerHashTypes, - Filter: outerFilter, - }, - InnerCtx: join.InnerCtx{ - ReaderBuilder: readerBuilder, - RowTypes: innerTypes, - HashTypes: innerHashTypes, - ColLens: v.IdxColLens, - HasPrefixCol: hasPrefixCol, - }, - WorkerWg: new(sync.WaitGroup), - IsOuterJoin: v.JoinType.IsOuterJoin(), - IndexRanges: v.Ranges, - KeyOff2IdxOff: v.KeyOff2IdxOff, - LastColHelper: v.CompareFilters, - Finished: &atomic.Value{}, - } - colsFromChildren := v.Schema().Columns - if v.JoinType == plannercore.LeftOuterSemiJoin || v.JoinType == plannercore.AntiLeftOuterSemiJoin { - colsFromChildren = colsFromChildren[:len(colsFromChildren)-1] - } - childrenUsedSchema := markChildrenUsedCols(colsFromChildren, v.Children()[0].Schema(), v.Children()[1].Schema()) - e.Joiner = join.NewJoiner(b.ctx, v.JoinType, v.InnerChildIdx == 0, defaultValues, v.OtherConditions, leftTypes, rightTypes, childrenUsedSchema, false) - outerKeyCols := make([]int, len(v.OuterJoinKeys)) - for i := 0; i < len(v.OuterJoinKeys); i++ { - outerKeyCols[i] = v.OuterJoinKeys[i].Index - } - innerKeyCols := make([]int, len(v.InnerJoinKeys)) - innerKeyColIDs := make([]int64, len(v.InnerJoinKeys)) - keyCollators := make([]collate.Collator, 0, len(v.InnerJoinKeys)) - for i := 0; i < len(v.InnerJoinKeys); i++ { - innerKeyCols[i] = v.InnerJoinKeys[i].Index - innerKeyColIDs[i] = v.InnerJoinKeys[i].ID - keyCollators = append(keyCollators, collate.GetCollator(v.InnerJoinKeys[i].RetType.GetCollate())) - } - e.OuterCtx.KeyCols = outerKeyCols - e.InnerCtx.KeyCols = innerKeyCols - e.InnerCtx.KeyColIDs = innerKeyColIDs - e.InnerCtx.KeyCollators = keyCollators - - outerHashCols, innerHashCols := make([]int, len(v.OuterHashKeys)), make([]int, len(v.InnerHashKeys)) - hashCollators := make([]collate.Collator, 0, len(v.InnerHashKeys)) - for i := 0; i < len(v.OuterHashKeys); i++ { - outerHashCols[i] = v.OuterHashKeys[i].Index - } - for i := 0; i < len(v.InnerHashKeys); i++ { - innerHashCols[i] = v.InnerHashKeys[i].Index - hashCollators = 
append(hashCollators, collate.GetCollator(v.InnerHashKeys[i].RetType.GetCollate())) - } - e.OuterCtx.HashCols = outerHashCols - e.InnerCtx.HashCols = innerHashCols - e.InnerCtx.HashCollators = hashCollators - - e.JoinResult = exec.TryNewCacheChunk(e) - executor_metrics.ExecutorCounterIndexLookUpJoin.Inc() - return e -} - -func (b *executorBuilder) buildIndexLookUpMergeJoin(v *plannercore.PhysicalIndexMergeJoin) exec.Executor { - outerExec := b.build(v.Children()[1-v.InnerChildIdx]) - if b.err != nil { - return nil - } - outerTypes := exec.RetTypes(outerExec) - innerPlan := v.Children()[v.InnerChildIdx] - innerTypes := make([]*types.FieldType, innerPlan.Schema().Len()) - for i, col := range innerPlan.Schema().Columns { - innerTypes[i] = col.RetType.Clone() - // The `innerTypes` would be called for `Datum.ConvertTo` when converting the columns from outer table - // to build hash map or construct lookup keys. So we need to modify its flen otherwise there would be - // truncate error. See issue https://github.com/pingcap/tidb/issues/21232 for example. - if innerTypes[i].EvalType() == types.ETString { - innerTypes[i].SetFlen(types.UnspecifiedLength) - } - } - var ( - outerFilter []expression.Expression - leftTypes, rightTypes []*types.FieldType - ) - if v.InnerChildIdx == 0 { - leftTypes, rightTypes = innerTypes, outerTypes - outerFilter = v.RightConditions - if len(v.LeftConditions) > 0 { - b.err = errors.Annotate(exeerrors.ErrBuildExecutor, "join's inner condition should be empty") - return nil - } - } else { - leftTypes, rightTypes = outerTypes, innerTypes - outerFilter = v.LeftConditions - if len(v.RightConditions) > 0 { - b.err = errors.Annotate(exeerrors.ErrBuildExecutor, "join's inner condition should be empty") - return nil - } - } - defaultValues := v.DefaultValues - if defaultValues == nil { - defaultValues = make([]types.Datum, len(innerTypes)) - } - outerKeyCols := make([]int, len(v.OuterJoinKeys)) - for i := 0; i < len(v.OuterJoinKeys); i++ { - outerKeyCols[i] = v.OuterJoinKeys[i].Index - } - innerKeyCols := make([]int, len(v.InnerJoinKeys)) - keyCollators := make([]collate.Collator, 0, len(v.InnerJoinKeys)) - for i := 0; i < len(v.InnerJoinKeys); i++ { - innerKeyCols[i] = v.InnerJoinKeys[i].Index - keyCollators = append(keyCollators, collate.GetCollator(v.InnerJoinKeys[i].RetType.GetCollate())) - } - executor_metrics.ExecutorCounterIndexLookUpJoin.Inc() - - readerBuilder, err := b.newDataReaderBuilder(innerPlan) - if err != nil { - b.err = err - return nil - } - - e := &join.IndexLookUpMergeJoin{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), outerExec), - OuterMergeCtx: join.OuterMergeCtx{ - RowTypes: outerTypes, - Filter: outerFilter, - JoinKeys: v.OuterJoinKeys, - KeyCols: outerKeyCols, - NeedOuterSort: v.NeedOuterSort, - CompareFuncs: v.OuterCompareFuncs, - }, - InnerMergeCtx: join.InnerMergeCtx{ - ReaderBuilder: readerBuilder, - RowTypes: innerTypes, - JoinKeys: v.InnerJoinKeys, - KeyCols: innerKeyCols, - KeyCollators: keyCollators, - CompareFuncs: v.CompareFuncs, - ColLens: v.IdxColLens, - Desc: v.Desc, - KeyOff2KeyOffOrderByIdx: v.KeyOff2KeyOffOrderByIdx, - }, - WorkerWg: new(sync.WaitGroup), - IsOuterJoin: v.JoinType.IsOuterJoin(), - IndexRanges: v.Ranges, - KeyOff2IdxOff: v.KeyOff2IdxOff, - LastColHelper: v.CompareFilters, - } - colsFromChildren := v.Schema().Columns - if v.JoinType == plannercore.LeftOuterSemiJoin || v.JoinType == plannercore.AntiLeftOuterSemiJoin { - colsFromChildren = colsFromChildren[:len(colsFromChildren)-1] - } - childrenUsedSchema 
:= markChildrenUsedCols(colsFromChildren, v.Children()[0].Schema(), v.Children()[1].Schema()) - joiners := make([]join.Joiner, e.Ctx().GetSessionVars().IndexLookupJoinConcurrency()) - for i := 0; i < len(joiners); i++ { - joiners[i] = join.NewJoiner(b.ctx, v.JoinType, v.InnerChildIdx == 0, defaultValues, v.OtherConditions, leftTypes, rightTypes, childrenUsedSchema, false) - } - e.Joiners = joiners - return e -} - -func (b *executorBuilder) buildIndexNestedLoopHashJoin(v *plannercore.PhysicalIndexHashJoin) exec.Executor { - joinExec := b.buildIndexLookUpJoin(&(v.PhysicalIndexJoin)) - if b.err != nil { - return nil - } - e := joinExec.(*join.IndexLookUpJoin) - idxHash := &join.IndexNestedLoopHashJoin{ - IndexLookUpJoin: *e, - KeepOuterOrder: v.KeepOuterOrder, - } - concurrency := e.Ctx().GetSessionVars().IndexLookupJoinConcurrency() - idxHash.Joiners = make([]join.Joiner, concurrency) - for i := 0; i < concurrency; i++ { - idxHash.Joiners[i] = e.Joiner.Clone() - } - return idxHash -} - -func buildNoRangeTableReader(b *executorBuilder, v *plannercore.PhysicalTableReader) (*TableReaderExecutor, error) { - tablePlans := v.TablePlans - if v.StoreType == kv.TiFlash { - tablePlans = []base.PhysicalPlan{v.GetTablePlan()} - } - dagReq, err := builder.ConstructDAGReq(b.ctx, tablePlans, v.StoreType) - if err != nil { - return nil, err - } - ts, err := v.GetTableScan() - if err != nil { - return nil, err - } - if err = b.validCanReadTemporaryOrCacheTable(ts.Table); err != nil { - return nil, err - } - - tbl, _ := b.is.TableByID(ts.Table.ID) - isPartition, physicalTableID := ts.IsPartition() - if isPartition { - pt := tbl.(table.PartitionedTable) - tbl = pt.GetPartition(physicalTableID) - } - startTS, err := b.getSnapshotTS() - if err != nil { - return nil, err - } - paging := b.ctx.GetSessionVars().EnablePaging - - e := &TableReaderExecutor{ - BaseExecutorV2: exec.NewBaseExecutorV2(b.ctx.GetSessionVars(), v.Schema(), v.ID()), - tableReaderExecutorContext: newTableReaderExecutorContext(b.ctx), - dagPB: dagReq, - startTS: startTS, - txnScope: b.txnScope, - readReplicaScope: b.readReplicaScope, - isStaleness: b.isStaleness, - netDataSize: v.GetNetDataSize(), - table: tbl, - keepOrder: ts.KeepOrder, - desc: ts.Desc, - byItems: ts.ByItems, - columns: ts.Columns, - paging: paging, - corColInFilter: b.corColInDistPlan(v.TablePlans), - corColInAccess: b.corColInAccess(v.TablePlans[0]), - plans: v.TablePlans, - tablePlan: v.GetTablePlan(), - storeType: v.StoreType, - batchCop: v.ReadReqType == plannercore.BatchCop, - } - e.buildVirtualColumnInfo() - - if v.StoreType == kv.TiDB && b.ctx.GetSessionVars().User != nil { - // User info is used to do privilege check. It is only used in TiDB cluster memory table. 
- e.dagPB.User = &tipb.UserIdentity{ - UserName: b.ctx.GetSessionVars().User.Username, - UserHost: b.ctx.GetSessionVars().User.Hostname, - } - } - - for i := range v.Schema().Columns { - dagReq.OutputOffsets = append(dagReq.OutputOffsets, uint32(i)) - } - - return e, nil -} - -func (b *executorBuilder) buildMPPGather(v *plannercore.PhysicalTableReader) exec.Executor { - startTs, err := b.getSnapshotTS() - if err != nil { - b.err = err - return nil - } - - gather := &MPPGather{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - is: b.is, - originalPlan: v.GetTablePlan(), - startTS: startTs, - mppQueryID: kv.MPPQueryID{QueryTs: getMPPQueryTS(b.ctx), LocalQueryID: getMPPQueryID(b.ctx), ServerID: domain.GetDomain(b.ctx).ServerID()}, - memTracker: memory.NewTracker(v.ID(), -1), - - columns: []*model.ColumnInfo{}, - virtualColumnIndex: []int{}, - virtualColumnRetFieldTypes: []*types.FieldType{}, - } - - gather.memTracker.AttachTo(b.ctx.GetSessionVars().StmtCtx.MemTracker) - - var hasVirtualCol bool - for _, col := range v.Schema().Columns { - if col.VirtualExpr != nil { - hasVirtualCol = true - break - } - } - - var isSingleDataSource bool - tableScans := v.GetTableScans() - if len(tableScans) == 1 { - isSingleDataSource = true - } - - // 1. hasVirtualCol: when got virtual column in TableScan, will generate plan like the following, - // and there will be no other operators in the MPP fragment. - // MPPGather - // ExchangeSender - // PhysicalTableScan - // 2. UnionScan: there won't be any operators like Join between UnionScan and TableScan. - // and UnionScan cannot push down to tiflash. - if !isSingleDataSource { - if hasVirtualCol || b.encounterUnionScan { - b.err = errors.Errorf("should only have one TableScan in MPP fragment(hasVirtualCol: %v, encounterUnionScan: %v)", hasVirtualCol, b.encounterUnionScan) - return nil - } - return gather - } - - // Setup MPPGather.table if isSingleDataSource. - // Virtual Column or UnionScan need to use it. - ts := tableScans[0] - gather.columns = ts.Columns - if hasVirtualCol { - gather.virtualColumnIndex, gather.virtualColumnRetFieldTypes = buildVirtualColumnInfo(gather.Schema(), gather.columns) - } - tbl, _ := b.is.TableByID(ts.Table.ID) - isPartition, physicalTableID := ts.IsPartition() - if isPartition { - // Only for static pruning partition table. - pt := tbl.(table.PartitionedTable) - tbl = pt.GetPartition(physicalTableID) - } - gather.table = tbl - return gather -} - -// buildTableReader builds a table reader executor. It first build a no range table reader, -// and then update it ranges from table scan plan. 
-func (b *executorBuilder) buildTableReader(v *plannercore.PhysicalTableReader) exec.Executor { - failpoint.Inject("checkUseMPP", func(val failpoint.Value) { - if !b.ctx.GetSessionVars().InRestrictedSQL && val.(bool) != useMPPExecution(b.ctx, v) { - if val.(bool) { - b.err = errors.New("expect mpp but not used") - } else { - b.err = errors.New("don't expect mpp but we used it") - } - failpoint.Return(nil) - } - }) - // https://github.com/pingcap/tidb/issues/50358 - if len(v.Schema().Columns) == 0 && len(v.GetTablePlan().Schema().Columns) > 0 { - v.SetSchema(v.GetTablePlan().Schema()) - } - useMPP := useMPPExecution(b.ctx, v) - useTiFlashBatchCop := v.ReadReqType == plannercore.BatchCop - useTiFlash := useMPP || useTiFlashBatchCop - if useTiFlash { - if _, isTiDBZoneLabelSet := config.GetGlobalConfig().Labels[placement.DCLabelKey]; b.ctx.GetSessionVars().TiFlashReplicaRead != tiflash.AllReplicas && !isTiDBZoneLabelSet { - b.ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("the variable tiflash_replica_read is ignored, because the entry TiDB[%s] does not set the zone attribute and tiflash_replica_read is '%s'", config.GetGlobalConfig().AdvertiseAddress, tiflash.GetTiFlashReplicaRead(b.ctx.GetSessionVars().TiFlashReplicaRead))) - } - } - if useMPP { - return b.buildMPPGather(v) - } - ts, err := v.GetTableScan() - if err != nil { - b.err = err - return nil - } - ret, err := buildNoRangeTableReader(b, v) - if err != nil { - b.err = err - return nil - } - if err = b.validCanReadTemporaryOrCacheTable(ts.Table); err != nil { - b.err = err - return nil - } - - if ret.table.Meta().TempTableType != model.TempTableNone { - ret.dummy = true - } - - ret.ranges = ts.Ranges - sctx := b.ctx.GetSessionVars().StmtCtx - sctx.TableIDs = append(sctx.TableIDs, ts.Table.ID) - - if !b.ctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { - return ret - } - // When isPartition is set, it means the union rewriting is done, so a partition reader is preferred. - if ok, _ := ts.IsPartition(); ok { - return ret - } - - pi := ts.Table.GetPartitionInfo() - if pi == nil { - return ret - } - - tmp, _ := b.is.TableByID(ts.Table.ID) - tbl := tmp.(table.PartitionedTable) - partitions, err := partitionPruning(b.ctx, tbl, v.PlanPartInfo) - if err != nil { - b.err = err - return nil - } - if v.StoreType == kv.TiFlash { - sctx.IsTiFlash.Store(true) - } - - if len(partitions) == 0 { - return &TableDualExec{BaseExecutorV2: ret.BaseExecutorV2} - } - - // Sort the partition is necessary to make the final multiple partition key ranges ordered. 
- slices.SortFunc(partitions, func(i, j table.PhysicalTable) int { - return cmp.Compare(i.GetPhysicalID(), j.GetPhysicalID()) - }) - ret.kvRangeBuilder = kvRangeBuilderFromRangeAndPartition{ - partitions: partitions, - } - - return ret -} - -func buildIndexRangeForEachPartition(rctx *rangerctx.RangerContext, usedPartitions []table.PhysicalTable, contentPos []int64, - lookUpContent []*join.IndexJoinLookUpContent, indexRanges []*ranger.Range, keyOff2IdxOff []int, cwc *plannercore.ColWithCmpFuncManager) (map[int64][]*ranger.Range, error) { - contentBucket := make(map[int64][]*join.IndexJoinLookUpContent) - for _, p := range usedPartitions { - contentBucket[p.GetPhysicalID()] = make([]*join.IndexJoinLookUpContent, 0, 8) - } - for i, pos := range contentPos { - if _, ok := contentBucket[pos]; ok { - contentBucket[pos] = append(contentBucket[pos], lookUpContent[i]) - } - } - nextRange := make(map[int64][]*ranger.Range) - for _, p := range usedPartitions { - ranges, err := buildRangesForIndexJoin(rctx, contentBucket[p.GetPhysicalID()], indexRanges, keyOff2IdxOff, cwc) - if err != nil { - return nil, err - } - nextRange[p.GetPhysicalID()] = ranges - } - return nextRange, nil -} - -func getPartitionKeyColOffsets(keyColIDs []int64, pt table.PartitionedTable) []int { - keyColOffsets := make([]int, len(keyColIDs)) - for i, colID := range keyColIDs { - offset := -1 - for j, col := range pt.Cols() { - if colID == col.ID { - offset = j - break - } - } - if offset == -1 { - return nil - } - keyColOffsets[i] = offset - } - - t, ok := pt.(interface { - PartitionExpr() *tables.PartitionExpr - }) - if !ok { - return nil - } - pe := t.PartitionExpr() - if pe == nil { - return nil - } - - offsetMap := make(map[int]struct{}) - for _, offset := range keyColOffsets { - offsetMap[offset] = struct{}{} - } - for _, offset := range pe.ColumnOffset { - if _, ok := offsetMap[offset]; !ok { - return nil - } - } - return keyColOffsets -} - -func (builder *dataReaderBuilder) prunePartitionForInnerExecutor(tbl table.Table, physPlanPartInfo *plannercore.PhysPlanPartInfo, - lookUpContent []*join.IndexJoinLookUpContent) (usedPartition []table.PhysicalTable, canPrune bool, contentPos []int64, err error) { - partitionTbl := tbl.(table.PartitionedTable) - - // In index join, this is called by multiple goroutines simultaneously, but partitionPruning is not thread-safe. - // Use once.Do to avoid DATA RACE here. - // TODO: condition based pruning can be do in advance. 
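// A minimal sketch of the intended once-guarded call, assuming builder.partitionPruning
// wraps the pruning with the `once` field declared on dataReaderBuilder further below:
//   builder.once.Do(func() {
//       builder.once.condPruneResult, builder.once.err = partitionPruning(builder.ctx, partitionTbl, physPlanPartInfo)
//   })
//   // concurrent callers then all read builder.once.condPruneResult / builder.once.err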
- condPruneResult, err := builder.partitionPruning(partitionTbl, physPlanPartInfo) - if err != nil { - return nil, false, nil, err - } - - // recalculate key column offsets - if len(lookUpContent) == 0 { - return nil, false, nil, nil - } - if lookUpContent[0].KeyColIDs == nil { - return nil, false, nil, plannererrors.ErrInternal.GenWithStack("cannot get column IDs when dynamic pruning") - } - keyColOffsets := getPartitionKeyColOffsets(lookUpContent[0].KeyColIDs, partitionTbl) - if len(keyColOffsets) == 0 { - return condPruneResult, false, nil, nil - } - - locateKey := make([]types.Datum, len(partitionTbl.Cols())) - partitions := make(map[int64]table.PhysicalTable) - contentPos = make([]int64, len(lookUpContent)) - exprCtx := builder.ctx.GetExprCtx() - for idx, content := range lookUpContent { - for i, data := range content.Keys { - locateKey[keyColOffsets[i]] = data - } - p, err := partitionTbl.GetPartitionByRow(exprCtx.GetEvalCtx(), locateKey) - if table.ErrNoPartitionForGivenValue.Equal(err) { - continue - } - if err != nil { - return nil, false, nil, err - } - if _, ok := partitions[p.GetPhysicalID()]; !ok { - partitions[p.GetPhysicalID()] = p - } - contentPos[idx] = p.GetPhysicalID() - } - - usedPartition = make([]table.PhysicalTable, 0, len(partitions)) - for _, p := range condPruneResult { - if _, ok := partitions[p.GetPhysicalID()]; ok { - usedPartition = append(usedPartition, p) - } - } - - // To make the final key ranges involving multiple partitions ordered. - slices.SortFunc(usedPartition, func(i, j table.PhysicalTable) int { - return cmp.Compare(i.GetPhysicalID(), j.GetPhysicalID()) - }) - return usedPartition, true, contentPos, nil -} - -func buildNoRangeIndexReader(b *executorBuilder, v *plannercore.PhysicalIndexReader) (*IndexReaderExecutor, error) { - dagReq, err := builder.ConstructDAGReq(b.ctx, v.IndexPlans, kv.TiKV) - if err != nil { - return nil, err - } - is := v.IndexPlans[0].(*plannercore.PhysicalIndexScan) - tbl, _ := b.is.TableByID(is.Table.ID) - isPartition, physicalTableID := is.IsPartition() - if isPartition { - pt := tbl.(table.PartitionedTable) - tbl = pt.GetPartition(physicalTableID) - } else { - physicalTableID = is.Table.ID - } - startTS, err := b.getSnapshotTS() - if err != nil { - return nil, err - } - paging := b.ctx.GetSessionVars().EnablePaging - - e := &IndexReaderExecutor{ - indexReaderExecutorContext: newIndexReaderExecutorContext(b.ctx), - BaseExecutorV2: exec.NewBaseExecutorV2(b.ctx.GetSessionVars(), v.Schema(), v.ID()), - indexUsageReporter: b.buildIndexUsageReporter(v), - dagPB: dagReq, - startTS: startTS, - txnScope: b.txnScope, - readReplicaScope: b.readReplicaScope, - isStaleness: b.isStaleness, - netDataSize: v.GetNetDataSize(), - physicalTableID: physicalTableID, - table: tbl, - index: is.Index, - keepOrder: is.KeepOrder, - desc: is.Desc, - columns: is.Columns, - byItems: is.ByItems, - paging: paging, - corColInFilter: b.corColInDistPlan(v.IndexPlans), - corColInAccess: b.corColInAccess(v.IndexPlans[0]), - idxCols: is.IdxCols, - colLens: is.IdxColLens, - plans: v.IndexPlans, - outputColumns: v.OutputColumns, - } - - for _, col := range v.OutputColumns { - dagReq.OutputOffsets = append(dagReq.OutputOffsets, uint32(col.Index)) - } - - return e, nil -} - -func (b *executorBuilder) buildIndexReader(v *plannercore.PhysicalIndexReader) exec.Executor { - is := v.IndexPlans[0].(*plannercore.PhysicalIndexScan) - if err := b.validCanReadTemporaryOrCacheTable(is.Table); err != nil { - b.err = err - return nil - } - - ret, err := 
buildNoRangeIndexReader(b, v) - if err != nil { - b.err = err - return nil - } - - if ret.table.Meta().TempTableType != model.TempTableNone { - ret.dummy = true - } - - ret.ranges = is.Ranges - sctx := b.ctx.GetSessionVars().StmtCtx - sctx.IndexNames = append(sctx.IndexNames, is.Table.Name.O+":"+is.Index.Name.O) - - if !b.ctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { - return ret - } - // When isPartition is set, it means the union rewriting is done, so a partition reader is preferred. - if ok, _ := is.IsPartition(); ok { - return ret - } - - pi := is.Table.GetPartitionInfo() - if pi == nil { - return ret - } - - if is.Index.Global { - ret.partitionIDMap, err = getPartitionIDsAfterPruning(b.ctx, ret.table.(table.PartitionedTable), v.PlanPartInfo) - if err != nil { - b.err = err - return nil - } - return ret - } - - tmp, _ := b.is.TableByID(is.Table.ID) - tbl := tmp.(table.PartitionedTable) - partitions, err := partitionPruning(b.ctx, tbl, v.PlanPartInfo) - if err != nil { - b.err = err - return nil - } - ret.partitions = partitions - return ret -} - -func buildTableReq(b *executorBuilder, schemaLen int, plans []base.PhysicalPlan) (dagReq *tipb.DAGRequest, val table.Table, err error) { - tableReq, err := builder.ConstructDAGReq(b.ctx, plans, kv.TiKV) - if err != nil { - return nil, nil, err - } - for i := 0; i < schemaLen; i++ { - tableReq.OutputOffsets = append(tableReq.OutputOffsets, uint32(i)) - } - ts := plans[0].(*plannercore.PhysicalTableScan) - tbl, _ := b.is.TableByID(ts.Table.ID) - isPartition, physicalTableID := ts.IsPartition() - if isPartition { - pt := tbl.(table.PartitionedTable) - tbl = pt.GetPartition(physicalTableID) - } - return tableReq, tbl, err -} - -// buildIndexReq is designed to create a DAG for index request. -// If len(ByItems) != 0 means index request should return related columns -// to sort result rows in TiDB side for partition tables. 
-func buildIndexReq(ctx sessionctx.Context, columns []*model.IndexColumn, handleLen int, plans []base.PhysicalPlan) (dagReq *tipb.DAGRequest, err error) { - indexReq, err := builder.ConstructDAGReq(ctx, plans, kv.TiKV) - if err != nil { - return nil, err - } - - indexReq.OutputOffsets = []uint32{} - idxScan := plans[0].(*plannercore.PhysicalIndexScan) - if len(idxScan.ByItems) != 0 { - schema := idxScan.Schema() - for _, item := range idxScan.ByItems { - c, ok := item.Expr.(*expression.Column) - if !ok { - return nil, errors.Errorf("Not support non-column in orderBy pushed down") - } - find := false - for i, schemaColumn := range schema.Columns { - if schemaColumn.ID == c.ID { - indexReq.OutputOffsets = append(indexReq.OutputOffsets, uint32(i)) - find = true - break - } - } - if !find { - return nil, errors.Errorf("Not found order by related columns in indexScan.schema") - } - } - } - - for i := 0; i < handleLen; i++ { - indexReq.OutputOffsets = append(indexReq.OutputOffsets, uint32(len(columns)+i)) - } - - if idxScan.NeedExtraOutputCol() { - // need add one more column for pid or physical table id - indexReq.OutputOffsets = append(indexReq.OutputOffsets, uint32(len(columns)+handleLen)) - } - return indexReq, err -} - -func buildNoRangeIndexLookUpReader(b *executorBuilder, v *plannercore.PhysicalIndexLookUpReader) (*IndexLookUpExecutor, error) { - is := v.IndexPlans[0].(*plannercore.PhysicalIndexScan) - var handleLen int - if len(v.CommonHandleCols) != 0 { - handleLen = len(v.CommonHandleCols) - } else { - handleLen = 1 - } - indexReq, err := buildIndexReq(b.ctx, is.Index.Columns, handleLen, v.IndexPlans) - if err != nil { - return nil, err - } - indexPaging := false - if v.Paging { - indexPaging = true - } - tableReq, tbl, err := buildTableReq(b, v.Schema().Len(), v.TablePlans) - if err != nil { - return nil, err - } - ts := v.TablePlans[0].(*plannercore.PhysicalTableScan) - startTS, err := b.getSnapshotTS() - if err != nil { - return nil, err - } - - readerBuilder, err := b.newDataReaderBuilder(nil) - if err != nil { - return nil, err - } - - e := &IndexLookUpExecutor{ - indexLookUpExecutorContext: newIndexLookUpExecutorContext(b.ctx), - BaseExecutorV2: exec.NewBaseExecutorV2(b.ctx.GetSessionVars(), v.Schema(), v.ID()), - indexUsageReporter: b.buildIndexUsageReporter(v), - dagPB: indexReq, - startTS: startTS, - table: tbl, - index: is.Index, - keepOrder: is.KeepOrder, - byItems: is.ByItems, - desc: is.Desc, - tableRequest: tableReq, - columns: ts.Columns, - indexPaging: indexPaging, - dataReaderBuilder: readerBuilder, - corColInIdxSide: b.corColInDistPlan(v.IndexPlans), - corColInTblSide: b.corColInDistPlan(v.TablePlans), - corColInAccess: b.corColInAccess(v.IndexPlans[0]), - idxCols: is.IdxCols, - colLens: is.IdxColLens, - idxPlans: v.IndexPlans, - tblPlans: v.TablePlans, - PushedLimit: v.PushedLimit, - idxNetDataSize: v.GetAvgTableRowSize(), - avgRowSize: v.GetAvgTableRowSize(), - } - - if v.ExtraHandleCol != nil { - e.handleIdx = append(e.handleIdx, v.ExtraHandleCol.Index) - e.handleCols = []*expression.Column{v.ExtraHandleCol} - } else { - for _, handleCol := range v.CommonHandleCols { - e.handleIdx = append(e.handleIdx, handleCol.Index) - } - e.handleCols = v.CommonHandleCols - e.primaryKeyIndex = tables.FindPrimaryIndex(tbl.Meta()) - } - return e, nil -} - -func (b *executorBuilder) buildIndexLookUpReader(v *plannercore.PhysicalIndexLookUpReader) exec.Executor { - is := v.IndexPlans[0].(*plannercore.PhysicalIndexScan) - if err := b.validCanReadTemporaryOrCacheTable(is.Table); err 
!= nil { - b.err = err - return nil - } - - ret, err := buildNoRangeIndexLookUpReader(b, v) - if err != nil { - b.err = err - return nil - } - - if ret.table.Meta().TempTableType != model.TempTableNone { - ret.dummy = true - } - - ts := v.TablePlans[0].(*plannercore.PhysicalTableScan) - - ret.ranges = is.Ranges - executor_metrics.ExecutorCounterIndexLookUpExecutor.Inc() - - sctx := b.ctx.GetSessionVars().StmtCtx - sctx.IndexNames = append(sctx.IndexNames, is.Table.Name.O+":"+is.Index.Name.O) - sctx.TableIDs = append(sctx.TableIDs, ts.Table.ID) - - if !b.ctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { - return ret - } - - if pi := is.Table.GetPartitionInfo(); pi == nil { - return ret - } - - if is.Index.Global { - ret.partitionIDMap, err = getPartitionIDsAfterPruning(b.ctx, ret.table.(table.PartitionedTable), v.PlanPartInfo) - if err != nil { - b.err = err - return nil - } - - return ret - } - if ok, _ := is.IsPartition(); ok { - // Already pruned when translated to logical union. - return ret - } - - tmp, _ := b.is.TableByID(is.Table.ID) - tbl := tmp.(table.PartitionedTable) - partitions, err := partitionPruning(b.ctx, tbl, v.PlanPartInfo) - if err != nil { - b.err = err - return nil - } - ret.partitionTableMode = true - ret.prunedPartitions = partitions - return ret -} - -func buildNoRangeIndexMergeReader(b *executorBuilder, v *plannercore.PhysicalIndexMergeReader) (*IndexMergeReaderExecutor, error) { - partialPlanCount := len(v.PartialPlans) - partialReqs := make([]*tipb.DAGRequest, 0, partialPlanCount) - partialDataSizes := make([]float64, 0, partialPlanCount) - indexes := make([]*model.IndexInfo, 0, partialPlanCount) - descs := make([]bool, 0, partialPlanCount) - ts := v.TablePlans[0].(*plannercore.PhysicalTableScan) - isCorColInPartialFilters := make([]bool, 0, partialPlanCount) - isCorColInPartialAccess := make([]bool, 0, partialPlanCount) - hasGlobalIndex := false - for i := 0; i < partialPlanCount; i++ { - var tempReq *tipb.DAGRequest - var err error - - if is, ok := v.PartialPlans[i][0].(*plannercore.PhysicalIndexScan); ok { - tempReq, err = buildIndexReq(b.ctx, is.Index.Columns, ts.HandleCols.NumCols(), v.PartialPlans[i]) - descs = append(descs, is.Desc) - indexes = append(indexes, is.Index) - if is.Index.Global { - hasGlobalIndex = true - } - } else { - ts := v.PartialPlans[i][0].(*plannercore.PhysicalTableScan) - tempReq, _, err = buildTableReq(b, len(ts.Columns), v.PartialPlans[i]) - descs = append(descs, ts.Desc) - indexes = append(indexes, nil) - } - if err != nil { - return nil, err - } - collect := false - tempReq.CollectRangeCounts = &collect - partialReqs = append(partialReqs, tempReq) - isCorColInPartialFilters = append(isCorColInPartialFilters, b.corColInDistPlan(v.PartialPlans[i])) - isCorColInPartialAccess = append(isCorColInPartialAccess, b.corColInAccess(v.PartialPlans[i][0])) - partialDataSizes = append(partialDataSizes, v.GetPartialReaderNetDataSize(v.PartialPlans[i][0])) - } - tableReq, tblInfo, err := buildTableReq(b, v.Schema().Len(), v.TablePlans) - isCorColInTableFilter := b.corColInDistPlan(v.TablePlans) - if err != nil { - return nil, err - } - startTS, err := b.getSnapshotTS() - if err != nil { - return nil, err - } - - readerBuilder, err := b.newDataReaderBuilder(nil) - if err != nil { - return nil, err - } - - e := &IndexMergeReaderExecutor{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - indexUsageReporter: b.buildIndexUsageReporter(v), - dagPBs: partialReqs, - startTS: startTS, - table: tblInfo, - indexes: indexes, - 
descs: descs, - tableRequest: tableReq, - columns: ts.Columns, - partialPlans: v.PartialPlans, - tblPlans: v.TablePlans, - partialNetDataSizes: partialDataSizes, - dataAvgRowSize: v.GetAvgTableRowSize(), - dataReaderBuilder: readerBuilder, - handleCols: v.HandleCols, - isCorColInPartialFilters: isCorColInPartialFilters, - isCorColInTableFilter: isCorColInTableFilter, - isCorColInPartialAccess: isCorColInPartialAccess, - isIntersection: v.IsIntersectionType, - byItems: v.ByItems, - pushedLimit: v.PushedLimit, - keepOrder: v.KeepOrder, - hasGlobalIndex: hasGlobalIndex, - } - collectTable := false - e.tableRequest.CollectRangeCounts = &collectTable - return e, nil -} - -type tableStatsPreloader interface { - LoadTableStats(sessionctx.Context) -} - -func (b *executorBuilder) buildIndexUsageReporter(plan tableStatsPreloader) (indexUsageReporter *exec.IndexUsageReporter) { - sc := b.ctx.GetSessionVars().StmtCtx - if b.ctx.GetSessionVars().StmtCtx.IndexUsageCollector != nil && - sc.RuntimeStatsColl != nil { - // Preload the table stats. If the statement is a point-get or execute, the planner may not have loaded the - // stats. - plan.LoadTableStats(b.ctx) - - statsMap := sc.GetUsedStatsInfo(false) - indexUsageReporter = exec.NewIndexUsageReporter( - sc.IndexUsageCollector, - sc.RuntimeStatsColl, statsMap) - } - - return indexUsageReporter -} - -func (b *executorBuilder) buildIndexMergeReader(v *plannercore.PhysicalIndexMergeReader) exec.Executor { - ts := v.TablePlans[0].(*plannercore.PhysicalTableScan) - if err := b.validCanReadTemporaryOrCacheTable(ts.Table); err != nil { - b.err = err - return nil - } - - ret, err := buildNoRangeIndexMergeReader(b, v) - if err != nil { - b.err = err - return nil - } - ret.ranges = make([][]*ranger.Range, 0, len(v.PartialPlans)) - sctx := b.ctx.GetSessionVars().StmtCtx - hasGlobalIndex := false - for i := 0; i < len(v.PartialPlans); i++ { - if is, ok := v.PartialPlans[i][0].(*plannercore.PhysicalIndexScan); ok { - ret.ranges = append(ret.ranges, is.Ranges) - sctx.IndexNames = append(sctx.IndexNames, is.Table.Name.O+":"+is.Index.Name.O) - if is.Index.Global { - hasGlobalIndex = true - } - } else { - ret.ranges = append(ret.ranges, v.PartialPlans[i][0].(*plannercore.PhysicalTableScan).Ranges) - if ret.table.Meta().IsCommonHandle { - tblInfo := ret.table.Meta() - sctx.IndexNames = append(sctx.IndexNames, tblInfo.Name.O+":"+tables.FindPrimaryIndex(tblInfo).Name.O) - } - } - } - sctx.TableIDs = append(sctx.TableIDs, ts.Table.ID) - executor_metrics.ExecutorCounterIndexMergeReaderExecutor.Inc() - - if !b.ctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { - return ret - } - - if pi := ts.Table.GetPartitionInfo(); pi == nil { - return ret - } - - tmp, _ := b.is.TableByID(ts.Table.ID) - partitions, err := partitionPruning(b.ctx, tmp.(table.PartitionedTable), v.PlanPartInfo) - if err != nil { - b.err = err - return nil - } - ret.partitionTableMode, ret.prunedPartitions = true, partitions - if hasGlobalIndex { - ret.partitionIDMap = make(map[int64]struct{}) - for _, p := range partitions { - ret.partitionIDMap[p.GetPhysicalID()] = struct{}{} - } - } - return ret -} - -// dataReaderBuilder build an executor. -// The executor can be used to read data in the ranges which are constructed by datums. -// Differences from executorBuilder: -// 1. dataReaderBuilder calculate data range from argument, rather than plan. -// 2. the result executor is already opened. 
-type dataReaderBuilder struct { - plan base.Plan - *executorBuilder - - selectResultHook // for testing - once struct { - sync.Once - condPruneResult []table.PhysicalTable - err error - } -} - -type mockPhysicalIndexReader struct { - base.PhysicalPlan - - e exec.Executor -} - -// MemoryUsage of mockPhysicalIndexReader is only for testing -func (*mockPhysicalIndexReader) MemoryUsage() (sum int64) { - return -} - -func (builder *dataReaderBuilder) BuildExecutorForIndexJoin(ctx context.Context, lookUpContents []*join.IndexJoinLookUpContent, - indexRanges []*ranger.Range, keyOff2IdxOff []int, cwc *plannercore.ColWithCmpFuncManager, canReorderHandles bool, memTracker *memory.Tracker, interruptSignal *atomic.Value) (exec.Executor, error) { - return builder.buildExecutorForIndexJoinInternal(ctx, builder.plan, lookUpContents, indexRanges, keyOff2IdxOff, cwc, canReorderHandles, memTracker, interruptSignal) -} - -func (builder *dataReaderBuilder) buildExecutorForIndexJoinInternal(ctx context.Context, plan base.Plan, lookUpContents []*join.IndexJoinLookUpContent, - indexRanges []*ranger.Range, keyOff2IdxOff []int, cwc *plannercore.ColWithCmpFuncManager, canReorderHandles bool, memTracker *memory.Tracker, interruptSignal *atomic.Value) (exec.Executor, error) { - switch v := plan.(type) { - case *plannercore.PhysicalTableReader: - return builder.buildTableReaderForIndexJoin(ctx, v, lookUpContents, indexRanges, keyOff2IdxOff, cwc, canReorderHandles, memTracker, interruptSignal) - case *plannercore.PhysicalIndexReader: - return builder.buildIndexReaderForIndexJoin(ctx, v, lookUpContents, indexRanges, keyOff2IdxOff, cwc, memTracker, interruptSignal) - case *plannercore.PhysicalIndexLookUpReader: - return builder.buildIndexLookUpReaderForIndexJoin(ctx, v, lookUpContents, indexRanges, keyOff2IdxOff, cwc, memTracker, interruptSignal) - case *plannercore.PhysicalUnionScan: - return builder.buildUnionScanForIndexJoin(ctx, v, lookUpContents, indexRanges, keyOff2IdxOff, cwc, canReorderHandles, memTracker, interruptSignal) - case *plannercore.PhysicalProjection: - return builder.buildProjectionForIndexJoin(ctx, v, lookUpContents, indexRanges, keyOff2IdxOff, cwc, canReorderHandles, memTracker, interruptSignal) - // Need to support physical selection because after PR 16389, TiDB will push down all the expr supported by TiKV or TiFlash - // in predicate push down stage, so if there is an expr which only supported by TiFlash, a physical selection will be added after index read - case *plannercore.PhysicalSelection: - childExec, err := builder.buildExecutorForIndexJoinInternal(ctx, v.Children()[0], lookUpContents, indexRanges, keyOff2IdxOff, cwc, canReorderHandles, memTracker, interruptSignal) - if err != nil { - return nil, err - } - exec := &SelectionExec{ - selectionExecutorContext: newSelectionExecutorContext(builder.ctx), - BaseExecutorV2: exec.NewBaseExecutorV2(builder.ctx.GetSessionVars(), v.Schema(), v.ID(), childExec), - filters: v.Conditions, - } - err = exec.open(ctx) - return exec, err - case *plannercore.PhysicalHashAgg: - childExec, err := builder.buildExecutorForIndexJoinInternal(ctx, v.Children()[0], lookUpContents, indexRanges, keyOff2IdxOff, cwc, canReorderHandles, memTracker, interruptSignal) - if err != nil { - return nil, err - } - exec := builder.buildHashAggFromChildExec(childExec, v) - err = exec.OpenSelf() - return exec, err - case *plannercore.PhysicalStreamAgg: - childExec, err := builder.buildExecutorForIndexJoinInternal(ctx, v.Children()[0], lookUpContents, indexRanges, keyOff2IdxOff, 
cwc, canReorderHandles, memTracker, interruptSignal) - if err != nil { - return nil, err - } - exec := builder.buildStreamAggFromChildExec(childExec, v) - err = exec.OpenSelf() - return exec, err - case *mockPhysicalIndexReader: - return v.e, nil - } - return nil, errors.New("Wrong plan type for dataReaderBuilder") -} - -func (builder *dataReaderBuilder) buildUnionScanForIndexJoin(ctx context.Context, v *plannercore.PhysicalUnionScan, - values []*join.IndexJoinLookUpContent, indexRanges []*ranger.Range, keyOff2IdxOff []int, - cwc *plannercore.ColWithCmpFuncManager, canReorderHandles bool, memTracker *memory.Tracker, interruptSignal *atomic.Value) (exec.Executor, error) { - childBuilder, err := builder.newDataReaderBuilder(v.Children()[0]) - if err != nil { - return nil, err - } - - reader, err := childBuilder.BuildExecutorForIndexJoin(ctx, values, indexRanges, keyOff2IdxOff, cwc, canReorderHandles, memTracker, interruptSignal) - if err != nil { - return nil, err - } - - ret := builder.buildUnionScanFromReader(reader, v) - if builder.err != nil { - return nil, builder.err - } - if us, ok := ret.(*UnionScanExec); ok { - err = us.open(ctx) - } - return ret, err -} - -func (builder *dataReaderBuilder) buildTableReaderForIndexJoin(ctx context.Context, v *plannercore.PhysicalTableReader, - lookUpContents []*join.IndexJoinLookUpContent, indexRanges []*ranger.Range, keyOff2IdxOff []int, - cwc *plannercore.ColWithCmpFuncManager, canReorderHandles bool, memTracker *memory.Tracker, interruptSignal *atomic.Value) (exec.Executor, error) { - e, err := buildNoRangeTableReader(builder.executorBuilder, v) - if !canReorderHandles { - // `canReorderHandles` is set to false only in IndexMergeJoin. IndexMergeJoin will trigger a dead loop problem - // when enabling paging(tidb/issues/35831). But IndexMergeJoin is not visible to the user and is deprecated - // for now. Thus, we disable paging here. - e.paging = false - } - if err != nil { - return nil, err - } - tbInfo := e.table.Meta() - if tbInfo.GetPartitionInfo() == nil || !builder.ctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { - if v.IsCommonHandle { - kvRanges, err := buildKvRangesForIndexJoin(e.dctx, e.rctx, getPhysicalTableID(e.table), -1, lookUpContents, indexRanges, keyOff2IdxOff, cwc, memTracker, interruptSignal) - if err != nil { - return nil, err - } - return builder.buildTableReaderFromKvRanges(ctx, e, kvRanges) - } - handles, _ := dedupHandles(lookUpContents) - return builder.buildTableReaderFromHandles(ctx, e, handles, canReorderHandles) - } - tbl, _ := builder.is.TableByID(tbInfo.ID) - pt := tbl.(table.PartitionedTable) - usedPartitionList, err := builder.partitionPruning(pt, v.PlanPartInfo) - if err != nil { - return nil, err - } - usedPartitions := make(map[int64]table.PhysicalTable, len(usedPartitionList)) - for _, p := range usedPartitionList { - usedPartitions[p.GetPhysicalID()] = p - } - var kvRanges []kv.KeyRange - var keyColOffsets []int - if len(lookUpContents) > 0 { - keyColOffsets = getPartitionKeyColOffsets(lookUpContents[0].KeyColIDs, pt) - } - if v.IsCommonHandle { - if len(keyColOffsets) > 0 { - locateKey := make([]types.Datum, len(pt.Cols())) - kvRanges = make([]kv.KeyRange, 0, len(lookUpContents)) - // lookUpContentsByPID groups lookUpContents by pid(partition) so that kv ranges for same partition can be merged. 
- lookUpContentsByPID := make(map[int64][]*join.IndexJoinLookUpContent) - exprCtx := e.ectx - for _, content := range lookUpContents { - for i, data := range content.Keys { - locateKey[keyColOffsets[i]] = data - } - p, err := pt.GetPartitionByRow(exprCtx.GetEvalCtx(), locateKey) - if table.ErrNoPartitionForGivenValue.Equal(err) { - continue - } - if err != nil { - return nil, err - } - pid := p.GetPhysicalID() - if _, ok := usedPartitions[pid]; !ok { - continue - } - lookUpContentsByPID[pid] = append(lookUpContentsByPID[pid], content) - } - for pid, contents := range lookUpContentsByPID { - // buildKvRanges for each partition. - tmp, err := buildKvRangesForIndexJoin(e.dctx, e.rctx, pid, -1, contents, indexRanges, keyOff2IdxOff, cwc, nil, interruptSignal) - if err != nil { - return nil, err - } - kvRanges = append(kvRanges, tmp...) - } - } else { - kvRanges = make([]kv.KeyRange, 0, len(usedPartitions)*len(lookUpContents)) - for _, p := range usedPartitionList { - tmp, err := buildKvRangesForIndexJoin(e.dctx, e.rctx, p.GetPhysicalID(), -1, lookUpContents, indexRanges, keyOff2IdxOff, cwc, memTracker, interruptSignal) - if err != nil { - return nil, err - } - kvRanges = append(tmp, kvRanges...) - } - } - // The key ranges should be ordered. - slices.SortFunc(kvRanges, func(i, j kv.KeyRange) int { - return bytes.Compare(i.StartKey, j.StartKey) - }) - return builder.buildTableReaderFromKvRanges(ctx, e, kvRanges) - } - - handles, lookUpContents := dedupHandles(lookUpContents) - - if len(keyColOffsets) > 0 { - locateKey := make([]types.Datum, len(pt.Cols())) - kvRanges = make([]kv.KeyRange, 0, len(lookUpContents)) - exprCtx := e.ectx - for _, content := range lookUpContents { - for i, data := range content.Keys { - locateKey[keyColOffsets[i]] = data - } - p, err := pt.GetPartitionByRow(exprCtx.GetEvalCtx(), locateKey) - if table.ErrNoPartitionForGivenValue.Equal(err) { - continue - } - if err != nil { - return nil, err - } - pid := p.GetPhysicalID() - if _, ok := usedPartitions[pid]; !ok { - continue - } - handle := kv.IntHandle(content.Keys[0].GetInt64()) - ranges, _ := distsql.TableHandlesToKVRanges(pid, []kv.Handle{handle}) - kvRanges = append(kvRanges, ranges...) - } - } else { - for _, p := range usedPartitionList { - ranges, _ := distsql.TableHandlesToKVRanges(p.GetPhysicalID(), handles) - kvRanges = append(kvRanges, ranges...) - } - } - - // The key ranges should be ordered. 
- slices.SortFunc(kvRanges, func(i, j kv.KeyRange) int { - return bytes.Compare(i.StartKey, j.StartKey) - }) - return builder.buildTableReaderFromKvRanges(ctx, e, kvRanges) -} - -func dedupHandles(lookUpContents []*join.IndexJoinLookUpContent) ([]kv.Handle, []*join.IndexJoinLookUpContent) { - handles := make([]kv.Handle, 0, len(lookUpContents)) - validLookUpContents := make([]*join.IndexJoinLookUpContent, 0, len(lookUpContents)) - for _, content := range lookUpContents { - isValidHandle := true - handle := kv.IntHandle(content.Keys[0].GetInt64()) - for _, key := range content.Keys { - if handle.IntValue() != key.GetInt64() { - isValidHandle = false - break - } - } - if isValidHandle { - handles = append(handles, handle) - validLookUpContents = append(validLookUpContents, content) - } - } - return handles, validLookUpContents -} - -type kvRangeBuilderFromRangeAndPartition struct { - partitions []table.PhysicalTable -} - -func (h kvRangeBuilderFromRangeAndPartition) buildKeyRangeSeparately(dctx *distsqlctx.DistSQLContext, ranges []*ranger.Range) ([]int64, [][]kv.KeyRange, error) { - ret := make([][]kv.KeyRange, len(h.partitions)) - pids := make([]int64, 0, len(h.partitions)) - for i, p := range h.partitions { - pid := p.GetPhysicalID() - pids = append(pids, pid) - meta := p.Meta() - if len(ranges) == 0 { - continue - } - kvRange, err := distsql.TableHandleRangesToKVRanges(dctx, []int64{pid}, meta != nil && meta.IsCommonHandle, ranges) - if err != nil { - return nil, nil, err - } - ret[i] = kvRange.AppendSelfTo(ret[i]) - } - return pids, ret, nil -} - -func (h kvRangeBuilderFromRangeAndPartition) buildKeyRange(dctx *distsqlctx.DistSQLContext, ranges []*ranger.Range) ([][]kv.KeyRange, error) { - ret := make([][]kv.KeyRange, len(h.partitions)) - if len(ranges) == 0 { - return ret, nil - } - for i, p := range h.partitions { - pid := p.GetPhysicalID() - meta := p.Meta() - kvRange, err := distsql.TableHandleRangesToKVRanges(dctx, []int64{pid}, meta != nil && meta.IsCommonHandle, ranges) - if err != nil { - return nil, err - } - ret[i] = kvRange.AppendSelfTo(ret[i]) - } - return ret, nil -} - -// newClosestReadAdjuster let the request be sent to closest replica(within the same zone) -// if response size exceeds certain threshold. -func newClosestReadAdjuster(dctx *distsqlctx.DistSQLContext, req *kv.Request, netDataSize float64) kv.CoprRequestAdjuster { - if req.ReplicaRead != kv.ReplicaReadClosestAdaptive { - return nil - } - return func(req *kv.Request, copTaskCount int) bool { - // copTaskCount is the number of coprocessor requests - if int64(netDataSize/float64(copTaskCount)) >= dctx.ReplicaClosestReadThreshold { - req.MatchStoreLabels = append(req.MatchStoreLabels, &metapb.StoreLabel{ - Key: placement.DCLabelKey, - Value: config.GetTxnScopeFromConfig(), - }) - return true - } - // reset to read from leader when the data size is small. - req.ReplicaRead = kv.ReplicaReadLeader - return false - } -} - -func (builder *dataReaderBuilder) buildTableReaderBase(ctx context.Context, e *TableReaderExecutor, reqBuilderWithRange distsql.RequestBuilder) (*TableReaderExecutor, error) { - startTS, err := builder.getSnapshotTS() - if err != nil { - return nil, err - } - kvReq, err := reqBuilderWithRange. - SetDAGRequest(e.dagPB). - SetStartTS(startTS). - SetDesc(e.desc). - SetKeepOrder(e.keepOrder). - SetTxnScope(e.txnScope). - SetReadReplicaScope(e.readReplicaScope). - SetIsStaleness(e.isStaleness). - SetFromSessionVars(e.dctx). - SetFromInfoSchema(e.GetInfoSchema()). 
- SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.dctx, &reqBuilderWithRange.Request, e.netDataSize)). - SetPaging(e.paging). - SetConnIDAndConnAlias(e.dctx.ConnectionID, e.dctx.SessionAlias). - Build() - if err != nil { - return nil, err - } - e.kvRanges = kvReq.KeyRanges.AppendSelfTo(e.kvRanges) - e.resultHandler = &tableResultHandler{} - result, err := builder.SelectResult(ctx, builder.ctx.GetDistSQLCtx(), kvReq, exec.RetTypes(e), getPhysicalPlanIDs(e.plans), e.ID()) - if err != nil { - return nil, err - } - e.resultHandler.open(nil, result) - return e, nil -} - -func (builder *dataReaderBuilder) buildTableReaderFromHandles(ctx context.Context, e *TableReaderExecutor, handles []kv.Handle, canReorderHandles bool) (*TableReaderExecutor, error) { - if canReorderHandles { - slices.SortFunc(handles, func(i, j kv.Handle) int { - return i.Compare(j) - }) - } - var b distsql.RequestBuilder - if len(handles) > 0 { - if _, ok := handles[0].(kv.PartitionHandle); ok { - b.SetPartitionsAndHandles(handles) - } else { - b.SetTableHandles(getPhysicalTableID(e.table), handles) - } - } else { - b.SetKeyRanges(nil) - } - return builder.buildTableReaderBase(ctx, e, b) -} - -func (builder *dataReaderBuilder) buildTableReaderFromKvRanges(ctx context.Context, e *TableReaderExecutor, ranges []kv.KeyRange) (exec.Executor, error) { - var b distsql.RequestBuilder - b.SetKeyRanges(ranges) - return builder.buildTableReaderBase(ctx, e, b) -} - -func (builder *dataReaderBuilder) buildIndexReaderForIndexJoin(ctx context.Context, v *plannercore.PhysicalIndexReader, - lookUpContents []*join.IndexJoinLookUpContent, indexRanges []*ranger.Range, keyOff2IdxOff []int, cwc *plannercore.ColWithCmpFuncManager, memoryTracker *memory.Tracker, interruptSignal *atomic.Value) (exec.Executor, error) { - e, err := buildNoRangeIndexReader(builder.executorBuilder, v) - if err != nil { - return nil, err - } - tbInfo := e.table.Meta() - if tbInfo.GetPartitionInfo() == nil || !builder.ctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { - kvRanges, err := buildKvRangesForIndexJoin(e.dctx, e.rctx, e.physicalTableID, e.index.ID, lookUpContents, indexRanges, keyOff2IdxOff, cwc, memoryTracker, interruptSignal) - if err != nil { - return nil, err - } - err = e.open(ctx, kvRanges) - return e, err - } - - is := v.IndexPlans[0].(*plannercore.PhysicalIndexScan) - if is.Index.Global { - e.partitionIDMap, err = getPartitionIDsAfterPruning(builder.ctx, e.table.(table.PartitionedTable), v.PlanPartInfo) - if err != nil { - return nil, err - } - if e.ranges, err = buildRangesForIndexJoin(e.rctx, lookUpContents, indexRanges, keyOff2IdxOff, cwc); err != nil { - return nil, err - } - if err := exec.Open(ctx, e); err != nil { - return nil, err - } - return e, nil - } - - tbl, _ := builder.executorBuilder.is.TableByID(tbInfo.ID) - usedPartition, canPrune, contentPos, err := builder.prunePartitionForInnerExecutor(tbl, v.PlanPartInfo, lookUpContents) - if err != nil { - return nil, err - } - if len(usedPartition) != 0 { - if canPrune { - rangeMap, err := buildIndexRangeForEachPartition(e.rctx, usedPartition, contentPos, lookUpContents, indexRanges, keyOff2IdxOff, cwc) - if err != nil { - return nil, err - } - e.partitions = usedPartition - e.ranges = indexRanges - e.partRangeMap = rangeMap - } else { - e.partitions = usedPartition - if e.ranges, err = buildRangesForIndexJoin(e.rctx, lookUpContents, indexRanges, keyOff2IdxOff, cwc); err != nil { - return nil, err - } - } - if err := exec.Open(ctx, e); err != nil { - return nil, err - } - return e, 
nil - } - ret := &TableDualExec{BaseExecutorV2: e.BaseExecutorV2} - err = exec.Open(ctx, ret) - return ret, err -} - -func (builder *dataReaderBuilder) buildIndexLookUpReaderForIndexJoin(ctx context.Context, v *plannercore.PhysicalIndexLookUpReader, - lookUpContents []*join.IndexJoinLookUpContent, indexRanges []*ranger.Range, keyOff2IdxOff []int, cwc *plannercore.ColWithCmpFuncManager, memTracker *memory.Tracker, interruptSignal *atomic.Value) (exec.Executor, error) { - e, err := buildNoRangeIndexLookUpReader(builder.executorBuilder, v) - if err != nil { - return nil, err - } - - tbInfo := e.table.Meta() - if tbInfo.GetPartitionInfo() == nil || !builder.ctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { - e.kvRanges, err = buildKvRangesForIndexJoin(e.dctx, e.rctx, getPhysicalTableID(e.table), e.index.ID, lookUpContents, indexRanges, keyOff2IdxOff, cwc, memTracker, interruptSignal) - if err != nil { - return nil, err - } - err = e.open(ctx) - return e, err - } - - is := v.IndexPlans[0].(*plannercore.PhysicalIndexScan) - if is.Index.Global { - e.partitionIDMap, err = getPartitionIDsAfterPruning(builder.ctx, e.table.(table.PartitionedTable), v.PlanPartInfo) - if err != nil { - return nil, err - } - e.ranges, err = buildRangesForIndexJoin(e.rctx, lookUpContents, indexRanges, keyOff2IdxOff, cwc) - if err != nil { - return nil, err - } - if err := exec.Open(ctx, e); err != nil { - return nil, err - } - return e, err - } - - tbl, _ := builder.executorBuilder.is.TableByID(tbInfo.ID) - usedPartition, canPrune, contentPos, err := builder.prunePartitionForInnerExecutor(tbl, v.PlanPartInfo, lookUpContents) - if err != nil { - return nil, err - } - if len(usedPartition) != 0 { - if canPrune { - rangeMap, err := buildIndexRangeForEachPartition(e.rctx, usedPartition, contentPos, lookUpContents, indexRanges, keyOff2IdxOff, cwc) - if err != nil { - return nil, err - } - e.prunedPartitions = usedPartition - e.ranges = indexRanges - e.partitionRangeMap = rangeMap - } else { - e.prunedPartitions = usedPartition - e.ranges, err = buildRangesForIndexJoin(e.rctx, lookUpContents, indexRanges, keyOff2IdxOff, cwc) - if err != nil { - return nil, err - } - } - e.partitionTableMode = true - if err := exec.Open(ctx, e); err != nil { - return nil, err - } - return e, err - } - ret := &TableDualExec{BaseExecutorV2: e.BaseExecutorV2} - err = exec.Open(ctx, ret) - return ret, err -} - -func (builder *dataReaderBuilder) buildProjectionForIndexJoin( - ctx context.Context, - v *plannercore.PhysicalProjection, - lookUpContents []*join.IndexJoinLookUpContent, - indexRanges []*ranger.Range, - keyOff2IdxOff []int, - cwc *plannercore.ColWithCmpFuncManager, - canReorderHandles bool, - memTracker *memory.Tracker, - interruptSignal *atomic.Value, -) (executor exec.Executor, err error) { - var childExec exec.Executor - childExec, err = builder.buildExecutorForIndexJoinInternal(ctx, v.Children()[0], lookUpContents, indexRanges, keyOff2IdxOff, cwc, canReorderHandles, memTracker, interruptSignal) - if err != nil { - return nil, err - } - defer func() { - if r := recover(); r != nil { - err = util.GetRecoverError(r) - } - if err != nil { - terror.Log(exec.Close(childExec)) - } - }() - - e := &ProjectionExec{ - projectionExecutorContext: newProjectionExecutorContext(builder.ctx), - BaseExecutorV2: exec.NewBaseExecutorV2(builder.ctx.GetSessionVars(), v.Schema(), v.ID(), childExec), - numWorkers: int64(builder.ctx.GetSessionVars().ProjectionConcurrency()), - evaluatorSuit: expression.NewEvaluatorSuite(v.Exprs, 
v.AvoidColumnEvaluator), - calculateNoDelay: v.CalculateNoDelay, - } - - // If the calculation row count for this Projection operator is smaller - // than a Chunk size, we turn back to the un-parallel Projection - // implementation to reduce the goroutine overhead. - if int64(v.StatsCount()) < int64(builder.ctx.GetSessionVars().MaxChunkSize) { - e.numWorkers = 0 - } - failpoint.Inject("buildProjectionForIndexJoinPanic", func(val failpoint.Value) { - if v, ok := val.(bool); ok && v { - panic("buildProjectionForIndexJoinPanic") - } - }) - err = e.open(ctx) - if err != nil { - return nil, err - } - return e, nil -} - -// buildRangesForIndexJoin builds kv ranges for index join when the inner plan is index scan plan. -func buildRangesForIndexJoin(rctx *rangerctx.RangerContext, lookUpContents []*join.IndexJoinLookUpContent, - ranges []*ranger.Range, keyOff2IdxOff []int, cwc *plannercore.ColWithCmpFuncManager) ([]*ranger.Range, error) { - retRanges := make([]*ranger.Range, 0, len(ranges)*len(lookUpContents)) - lastPos := len(ranges[0].LowVal) - 1 - tmpDatumRanges := make([]*ranger.Range, 0, len(lookUpContents)) - for _, content := range lookUpContents { - for _, ran := range ranges { - for keyOff, idxOff := range keyOff2IdxOff { - ran.LowVal[idxOff] = content.Keys[keyOff] - ran.HighVal[idxOff] = content.Keys[keyOff] - } - } - if cwc == nil { - // A deep copy is need here because the old []*range.Range is overwritten - for _, ran := range ranges { - retRanges = append(retRanges, ran.Clone()) - } - continue - } - nextColRanges, err := cwc.BuildRangesByRow(rctx, content.Row) - if err != nil { - return nil, err - } - for _, nextColRan := range nextColRanges { - for _, ran := range ranges { - ran.LowVal[lastPos] = nextColRan.LowVal[0] - ran.HighVal[lastPos] = nextColRan.HighVal[0] - ran.LowExclude = nextColRan.LowExclude - ran.HighExclude = nextColRan.HighExclude - ran.Collators = nextColRan.Collators - tmpDatumRanges = append(tmpDatumRanges, ran.Clone()) - } - } - } - - if cwc == nil { - return retRanges, nil - } - - return ranger.UnionRanges(rctx, tmpDatumRanges, true) -} - -// buildKvRangesForIndexJoin builds kv ranges for index join when the inner plan is index scan plan. -func buildKvRangesForIndexJoin(dctx *distsqlctx.DistSQLContext, pctx *rangerctx.RangerContext, tableID, indexID int64, lookUpContents []*join.IndexJoinLookUpContent, - ranges []*ranger.Range, keyOff2IdxOff []int, cwc *plannercore.ColWithCmpFuncManager, memTracker *memory.Tracker, interruptSignal *atomic.Value) (_ []kv.KeyRange, err error) { - kvRanges := make([]kv.KeyRange, 0, len(ranges)*len(lookUpContents)) - if len(ranges) == 0 { - return []kv.KeyRange{}, nil - } - lastPos := len(ranges[0].LowVal) - 1 - tmpDatumRanges := make([]*ranger.Range, 0, len(lookUpContents)) - for _, content := range lookUpContents { - for _, ran := range ranges { - for keyOff, idxOff := range keyOff2IdxOff { - ran.LowVal[idxOff] = content.Keys[keyOff] - ran.HighVal[idxOff] = content.Keys[keyOff] - } - } - if cwc == nil { - // Index id is -1 means it's a common handle. 
- var tmpKvRanges *kv.KeyRanges - var err error - if indexID == -1 { - tmpKvRanges, err = distsql.CommonHandleRangesToKVRanges(dctx, []int64{tableID}, ranges) - } else { - tmpKvRanges, err = distsql.IndexRangesToKVRangesWithInterruptSignal(dctx, tableID, indexID, ranges, memTracker, interruptSignal) - } - if err != nil { - return nil, err - } - kvRanges = tmpKvRanges.AppendSelfTo(kvRanges) - continue - } - nextColRanges, err := cwc.BuildRangesByRow(pctx, content.Row) - if err != nil { - return nil, err - } - for _, nextColRan := range nextColRanges { - for _, ran := range ranges { - ran.LowVal[lastPos] = nextColRan.LowVal[0] - ran.HighVal[lastPos] = nextColRan.HighVal[0] - ran.LowExclude = nextColRan.LowExclude - ran.HighExclude = nextColRan.HighExclude - ran.Collators = nextColRan.Collators - tmpDatumRanges = append(tmpDatumRanges, ran.Clone()) - } - } - } - if len(kvRanges) != 0 && memTracker != nil { - failpoint.Inject("testIssue49033", func() { - panic("testIssue49033") - }) - memTracker.Consume(int64(2 * cap(kvRanges[0].StartKey) * len(kvRanges))) - } - if len(tmpDatumRanges) != 0 && memTracker != nil { - memTracker.Consume(2 * types.EstimatedMemUsage(tmpDatumRanges[0].LowVal, len(tmpDatumRanges))) - } - if cwc == nil { - slices.SortFunc(kvRanges, func(i, j kv.KeyRange) int { - return bytes.Compare(i.StartKey, j.StartKey) - }) - return kvRanges, nil - } - - tmpDatumRanges, err = ranger.UnionRanges(pctx, tmpDatumRanges, true) - if err != nil { - return nil, err - } - // Index id is -1 means it's a common handle. - if indexID == -1 { - tmpKeyRanges, err := distsql.CommonHandleRangesToKVRanges(dctx, []int64{tableID}, tmpDatumRanges) - return tmpKeyRanges.FirstPartitionRange(), err - } - tmpKeyRanges, err := distsql.IndexRangesToKVRangesWithInterruptSignal(dctx, tableID, indexID, tmpDatumRanges, memTracker, interruptSignal) - return tmpKeyRanges.FirstPartitionRange(), err -} - -func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) exec.Executor { - childExec := b.build(v.Children()[0]) - if b.err != nil { - return nil - } - base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID(), childExec) - groupByItems := make([]expression.Expression, 0, len(v.PartitionBy)) - for _, item := range v.PartitionBy { - groupByItems = append(groupByItems, item.Col) - } - orderByCols := make([]*expression.Column, 0, len(v.OrderBy)) - for _, item := range v.OrderBy { - orderByCols = append(orderByCols, item.Col) - } - windowFuncs := make([]aggfuncs.AggFunc, 0, len(v.WindowFuncDescs)) - partialResults := make([]aggfuncs.PartialResult, 0, len(v.WindowFuncDescs)) - resultColIdx := v.Schema().Len() - len(v.WindowFuncDescs) - exprCtx := b.ctx.GetExprCtx() - for _, desc := range v.WindowFuncDescs { - aggDesc, err := aggregation.NewAggFuncDescForWindowFunc(exprCtx, desc, false) - if err != nil { - b.err = err - return nil - } - agg := aggfuncs.BuildWindowFunctions(exprCtx, aggDesc, resultColIdx, orderByCols) - windowFuncs = append(windowFuncs, agg) - partialResult, _ := agg.AllocPartialResult() - partialResults = append(partialResults, partialResult) - resultColIdx++ - } - - var err error - if b.ctx.GetSessionVars().EnablePipelinedWindowExec { - exec := &PipelinedWindowExec{ - BaseExecutor: base, - groupChecker: vecgroupchecker.NewVecGroupChecker(b.ctx.GetExprCtx().GetEvalCtx(), b.ctx.GetSessionVars().EnableVectorizedExpression, groupByItems), - numWindowFuncs: len(v.WindowFuncDescs), - } - - exec.windowFuncs = windowFuncs - exec.partialResults = partialResults - if v.Frame == nil { - exec.start = 
&logicalop.FrameBound{ - Type: ast.Preceding, - UnBounded: true, - } - exec.end = &logicalop.FrameBound{ - Type: ast.Following, - UnBounded: true, - } - } else { - exec.start = v.Frame.Start - exec.end = v.Frame.End - if v.Frame.Type == ast.Ranges { - cmpResult := int64(-1) - if len(v.OrderBy) > 0 && v.OrderBy[0].Desc { - cmpResult = 1 - } - exec.orderByCols = orderByCols - exec.expectedCmpResult = cmpResult - exec.isRangeFrame = true - err = exec.start.UpdateCompareCols(b.ctx, exec.orderByCols) - if err != nil { - return nil - } - err = exec.end.UpdateCompareCols(b.ctx, exec.orderByCols) - if err != nil { - return nil - } - } - } - return exec - } - var processor windowProcessor - if v.Frame == nil { - processor = &aggWindowProcessor{ - windowFuncs: windowFuncs, - partialResults: partialResults, - } - } else if v.Frame.Type == ast.Rows { - processor = &rowFrameWindowProcessor{ - windowFuncs: windowFuncs, - partialResults: partialResults, - start: v.Frame.Start, - end: v.Frame.End, - } - } else { - cmpResult := int64(-1) - if len(v.OrderBy) > 0 && v.OrderBy[0].Desc { - cmpResult = 1 - } - tmpProcessor := &rangeFrameWindowProcessor{ - windowFuncs: windowFuncs, - partialResults: partialResults, - start: v.Frame.Start, - end: v.Frame.End, - orderByCols: orderByCols, - expectedCmpResult: cmpResult, - } - - err = tmpProcessor.start.UpdateCompareCols(b.ctx, orderByCols) - if err != nil { - return nil - } - err = tmpProcessor.end.UpdateCompareCols(b.ctx, orderByCols) - if err != nil { - return nil - } - - processor = tmpProcessor - } - return &WindowExec{BaseExecutor: base, - processor: processor, - groupChecker: vecgroupchecker.NewVecGroupChecker(b.ctx.GetExprCtx().GetEvalCtx(), b.ctx.GetSessionVars().EnableVectorizedExpression, groupByItems), - numWindowFuncs: len(v.WindowFuncDescs), - } -} - -func (b *executorBuilder) buildShuffle(v *plannercore.PhysicalShuffle) *ShuffleExec { - base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()) - shuffle := &ShuffleExec{ - BaseExecutor: base, - concurrency: v.Concurrency, - } - - // 1. initialize the splitters - splitters := make([]partitionSplitter, len(v.ByItemArrays)) - switch v.SplitterType { - case plannercore.PartitionHashSplitterType: - for i, byItems := range v.ByItemArrays { - splitters[i] = buildPartitionHashSplitter(shuffle.concurrency, byItems) - } - case plannercore.PartitionRangeSplitterType: - for i, byItems := range v.ByItemArrays { - splitters[i] = buildPartitionRangeSplitter(b.ctx, shuffle.concurrency, byItems) - } - default: - panic("Not implemented. Should not reach here.") - } - shuffle.splitters = splitters - - // 2. initialize the data sources (build the data sources from physical plan to executors) - shuffle.dataSources = make([]exec.Executor, len(v.DataSources)) - for i, dataSource := range v.DataSources { - shuffle.dataSources[i] = b.build(dataSource) - if b.err != nil { - return nil - } - } - - // 3. initialize the workers - head := v.Children()[0] - // A `PhysicalShuffleReceiverStub` for every worker have the same `DataSource` but different `Receiver`. - // We preallocate `PhysicalShuffleReceiverStub`s here and reuse them below. 
- stubs := make([]*plannercore.PhysicalShuffleReceiverStub, 0, len(v.DataSources)) - for _, dataSource := range v.DataSources { - stub := plannercore.PhysicalShuffleReceiverStub{ - DataSource: dataSource, - }.Init(b.ctx.GetPlanCtx(), dataSource.StatsInfo(), dataSource.QueryBlockOffset(), nil) - stub.SetSchema(dataSource.Schema()) - stubs = append(stubs, stub) - } - shuffle.workers = make([]*shuffleWorker, shuffle.concurrency) - for i := range shuffle.workers { - receivers := make([]*shuffleReceiver, len(v.DataSources)) - for j, dataSource := range v.DataSources { - receivers[j] = &shuffleReceiver{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, dataSource.Schema(), stubs[j].ID()), - } - } - - w := &shuffleWorker{ - receivers: receivers, - } - - for j := range v.DataSources { - stub := stubs[j] - stub.Receiver = (unsafe.Pointer)(receivers[j]) - v.Tails[j].SetChildren(stub) - } - - w.childExec = b.build(head) - if b.err != nil { - return nil - } - - shuffle.workers[i] = w - } - - return shuffle -} - -func (*executorBuilder) buildShuffleReceiverStub(v *plannercore.PhysicalShuffleReceiverStub) *shuffleReceiver { - return (*shuffleReceiver)(v.Receiver) -} - -func (b *executorBuilder) buildSQLBindExec(v *plannercore.SQLBindPlan) exec.Executor { - base := exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()) - base.SetInitCap(chunk.ZeroCapacity) - - e := &SQLBindExec{ - BaseExecutor: base, - sqlBindOp: v.SQLBindOp, - normdOrigSQL: v.NormdOrigSQL, - bindSQL: v.BindSQL, - charset: v.Charset, - collation: v.Collation, - db: v.Db, - isGlobal: v.IsGlobal, - bindAst: v.BindStmt, - newStatus: v.NewStatus, - source: v.Source, - sqlDigest: v.SQLDigest, - planDigest: v.PlanDigest, - } - return e -} - -// NewRowDecoder creates a chunk decoder for new row format row value decode. -func NewRowDecoder(ctx sessionctx.Context, schema *expression.Schema, tbl *model.TableInfo) *rowcodec.ChunkDecoder { - getColInfoByID := func(tbl *model.TableInfo, colID int64) *model.ColumnInfo { - for _, col := range tbl.Columns { - if col.ID == colID { - return col - } - } - return nil - } - var pkCols []int64 - reqCols := make([]rowcodec.ColInfo, len(schema.Columns)) - for i := range schema.Columns { - idx, col := i, schema.Columns[i] - isPK := (tbl.PKIsHandle && mysql.HasPriKeyFlag(col.RetType.GetFlag())) || col.ID == model.ExtraHandleID - if isPK { - pkCols = append(pkCols, col.ID) - } - isGeneratedCol := false - if col.VirtualExpr != nil { - isGeneratedCol = true - } - reqCols[idx] = rowcodec.ColInfo{ - ID: col.ID, - VirtualGenCol: isGeneratedCol, - Ft: col.RetType, - } - } - if len(pkCols) == 0 { - pkCols = tables.TryGetCommonPkColumnIds(tbl) - if len(pkCols) == 0 { - pkCols = []int64{-1} - } - } - defVal := func(i int, chk *chunk.Chunk) error { - if reqCols[i].ID < 0 { - // model.ExtraHandleID, ExtraPhysTblID... etc - // Don't set the default value for that column. 
- chk.AppendNull(i) - return nil - } - - ci := getColInfoByID(tbl, reqCols[i].ID) - d, err := table.GetColOriginDefaultValue(ctx.GetExprCtx(), ci) - if err != nil { - return err - } - chk.AppendDatum(i, &d) - return nil - } - return rowcodec.NewChunkDecoder(reqCols, pkCols, defVal, ctx.GetSessionVars().Location()) -} - -func (b *executorBuilder) buildBatchPointGet(plan *plannercore.BatchPointGetPlan) exec.Executor { - var err error - if err = b.validCanReadTemporaryOrCacheTable(plan.TblInfo); err != nil { - b.err = err - return nil - } - - if plan.Lock && !b.inSelectLockStmt { - b.inSelectLockStmt = true - defer func() { - b.inSelectLockStmt = false - }() - } - handles, isTableDual := plan.PrunePartitionsAndValues(b.ctx) - if isTableDual { - // No matching partitions - return &TableDualExec{ - BaseExecutorV2: exec.NewBaseExecutorV2(b.ctx.GetSessionVars(), plan.Schema(), plan.ID()), - numDualRows: 0, - } - } - - decoder := NewRowDecoder(b.ctx, plan.Schema(), plan.TblInfo) - e := &BatchPointGetExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, plan.Schema(), plan.ID()), - indexUsageReporter: b.buildIndexUsageReporter(plan), - tblInfo: plan.TblInfo, - idxInfo: plan.IndexInfo, - rowDecoder: decoder, - keepOrder: plan.KeepOrder, - desc: plan.Desc, - lock: plan.Lock, - waitTime: plan.LockWaitTime, - columns: plan.Columns, - handles: handles, - idxVals: plan.IndexValues, - partitionNames: plan.PartitionNames, - } - - e.snapshot, err = b.getSnapshot() - if err != nil { - b.err = err - return nil - } - if e.Ctx().GetSessionVars().IsReplicaReadClosestAdaptive() { - e.snapshot.SetOption(kv.ReplicaReadAdjuster, newReplicaReadAdjuster(e.Ctx(), plan.GetAvgRowSize())) - } - if e.RuntimeStats() != nil { - snapshotStats := &txnsnapshot.SnapshotRuntimeStats{} - e.stats = &runtimeStatsWithSnapshot{ - SnapshotRuntimeStats: snapshotStats, - } - e.snapshot.SetOption(kv.CollectRuntimeStats, snapshotStats) - } - - if plan.IndexInfo != nil { - sctx := b.ctx.GetSessionVars().StmtCtx - sctx.IndexNames = append(sctx.IndexNames, plan.TblInfo.Name.O+":"+plan.IndexInfo.Name.O) - } - - failpoint.Inject("assertBatchPointReplicaOption", func(val failpoint.Value) { - assertScope := val.(string) - if e.Ctx().GetSessionVars().GetReplicaRead().IsClosestRead() && assertScope != b.readReplicaScope { - panic("batch point get replica option fail") - } - }) - - snapshotTS, err := b.getSnapshotTS() - if err != nil { - b.err = err - return nil - } - if plan.TblInfo.TableCacheStatusType == model.TableCacheStatusEnable { - if cacheTable := b.getCacheTable(plan.TblInfo, snapshotTS); cacheTable != nil { - e.snapshot = cacheTableSnapshot{e.snapshot, cacheTable} - } - } - - if plan.TblInfo.TempTableType != model.TempTableNone { - // Temporary table should not do any lock operations - e.lock = false - e.waitTime = 0 - } - - if e.lock { - b.hasLock = true - } - if pi := plan.TblInfo.GetPartitionInfo(); pi != nil && len(plan.PartitionIdxs) > 0 { - defs := plan.TblInfo.GetPartitionInfo().Definitions - if plan.SinglePartition { - e.singlePartID = defs[plan.PartitionIdxs[0]].ID - } else { - e.planPhysIDs = make([]int64, len(plan.PartitionIdxs)) - for i, idx := range plan.PartitionIdxs { - e.planPhysIDs[i] = defs[idx].ID - } - } - } - - capacity := len(e.handles) - if capacity == 0 { - capacity = len(e.idxVals) - } - e.SetInitCap(capacity) - e.SetMaxChunkSize(capacity) - e.buildVirtualColumnInfo() - return e -} - -func newReplicaReadAdjuster(ctx sessionctx.Context, avgRowSize float64) txnkv.ReplicaReadAdjuster { - return func(count int) 
(tikv.StoreSelectorOption, clientkv.ReplicaReadType) { - if int64(avgRowSize*float64(count)) >= ctx.GetSessionVars().ReplicaClosestReadThreshold { - return tikv.WithMatchLabels([]*metapb.StoreLabel{ - { - Key: placement.DCLabelKey, - Value: config.GetTxnScopeFromConfig(), - }, - }), clientkv.ReplicaReadMixed - } - // fallback to read from leader if the request is small - return nil, clientkv.ReplicaReadLeader - } -} - -func isCommonHandleRead(tbl *model.TableInfo, idx *model.IndexInfo) bool { - return tbl.IsCommonHandle && idx.Primary -} - -func getPhysicalTableID(t table.Table) int64 { - if p, ok := t.(table.PhysicalTable); ok { - return p.GetPhysicalID() - } - return t.Meta().ID -} - -func (builder *dataReaderBuilder) partitionPruning(tbl table.PartitionedTable, planPartInfo *plannercore.PhysPlanPartInfo) ([]table.PhysicalTable, error) { - builder.once.Do(func() { - condPruneResult, err := partitionPruning(builder.executorBuilder.ctx, tbl, planPartInfo) - builder.once.condPruneResult = condPruneResult - builder.once.err = err - }) - return builder.once.condPruneResult, builder.once.err -} - -func partitionPruning(ctx sessionctx.Context, tbl table.PartitionedTable, planPartInfo *plannercore.PhysPlanPartInfo) ([]table.PhysicalTable, error) { - var pruningConds []expression.Expression - var partitionNames []model.CIStr - var columns []*expression.Column - var columnNames types.NameSlice - if planPartInfo != nil { - pruningConds = planPartInfo.PruningConds - partitionNames = planPartInfo.PartitionNames - columns = planPartInfo.Columns - columnNames = planPartInfo.ColumnNames - } - idxArr, err := plannercore.PartitionPruning(ctx.GetPlanCtx(), tbl, pruningConds, partitionNames, columns, columnNames) - if err != nil { - return nil, err - } - - pi := tbl.Meta().GetPartitionInfo() - var ret []table.PhysicalTable - if fullRangePartition(idxArr) { - ret = make([]table.PhysicalTable, 0, len(pi.Definitions)) - for _, def := range pi.Definitions { - p := tbl.GetPartition(def.ID) - ret = append(ret, p) - } - } else { - ret = make([]table.PhysicalTable, 0, len(idxArr)) - for _, idx := range idxArr { - pid := pi.Definitions[idx].ID - p := tbl.GetPartition(pid) - ret = append(ret, p) - } - } - return ret, nil -} - -func getPartitionIDsAfterPruning(ctx sessionctx.Context, tbl table.PartitionedTable, physPlanPartInfo *plannercore.PhysPlanPartInfo) (map[int64]struct{}, error) { - if physPlanPartInfo == nil { - return nil, errors.New("physPlanPartInfo in getPartitionIDsAfterPruning must not be nil") - } - idxArr, err := plannercore.PartitionPruning(ctx.GetPlanCtx(), tbl, physPlanPartInfo.PruningConds, physPlanPartInfo.PartitionNames, physPlanPartInfo.Columns, physPlanPartInfo.ColumnNames) - if err != nil { - return nil, err - } - - var ret map[int64]struct{} - - pi := tbl.Meta().GetPartitionInfo() - if fullRangePartition(idxArr) { - ret = make(map[int64]struct{}, len(pi.Definitions)) - for _, def := range pi.Definitions { - ret[def.ID] = struct{}{} - } - } else { - ret = make(map[int64]struct{}, len(idxArr)) - for _, idx := range idxArr { - pid := pi.Definitions[idx].ID - ret[pid] = struct{}{} - } - } - return ret, nil -} - -func fullRangePartition(idxArr []int) bool { - return len(idxArr) == 1 && idxArr[0] == plannercore.FullRange -} - -type emptySampler struct{} - -func (*emptySampler) writeChunk(_ *chunk.Chunk) error { - return nil -} - -func (*emptySampler) finished() bool { - return true -} - -func (b *executorBuilder) buildTableSample(v *plannercore.PhysicalTableSample) *TableSampleExecutor { - startTS, 
err := b.getSnapshotTS() - if err != nil { - b.err = err - return nil - } - e := &TableSampleExecutor{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - table: v.TableInfo, - startTS: startTS, - } - - tblInfo := v.TableInfo.Meta() - if tblInfo.TempTableType != model.TempTableNone { - if tblInfo.TempTableType != model.TempTableGlobal { - b.err = errors.New("TABLESAMPLE clause can not be applied to local temporary tables") - return nil - } - e.sampler = &emptySampler{} - } else if v.TableSampleInfo.AstNode.SampleMethod == ast.SampleMethodTypeTiDBRegion { - e.sampler = newTableRegionSampler( - b.ctx, v.TableInfo, startTS, v.PhysicalTableID, v.TableSampleInfo.Partitions, v.Schema(), - v.TableSampleInfo.FullSchema, e.RetFieldTypes(), v.Desc) - } - - return e -} - -func (b *executorBuilder) buildCTE(v *plannercore.PhysicalCTE) exec.Executor { - storageMap, ok := b.ctx.GetSessionVars().StmtCtx.CTEStorageMap.(map[int]*CTEStorages) - if !ok { - b.err = errors.New("type assertion for CTEStorageMap failed") - return nil - } - - chkSize := b.ctx.GetSessionVars().MaxChunkSize - // iterOutTbl will be constructed in CTEExec.Open(). - var resTbl cteutil.Storage - var iterInTbl cteutil.Storage - var producer *cteProducer - storages, ok := storageMap[v.CTE.IDForStorage] - if ok { - // Storage already setup. - resTbl = storages.ResTbl - iterInTbl = storages.IterInTbl - producer = storages.Producer - } else { - if v.SeedPlan == nil { - b.err = errors.New("cte.seedPlan cannot be nil") - return nil - } - // Build seed part. - corCols := plannercore.ExtractOuterApplyCorrelatedCols(v.SeedPlan) - seedExec := b.build(v.SeedPlan) - if b.err != nil { - return nil - } - - // Setup storages. - tps := seedExec.RetFieldTypes() - resTbl = cteutil.NewStorageRowContainer(tps, chkSize) - if err := resTbl.OpenAndRef(); err != nil { - b.err = err - return nil - } - iterInTbl = cteutil.NewStorageRowContainer(tps, chkSize) - if err := iterInTbl.OpenAndRef(); err != nil { - b.err = err - return nil - } - storageMap[v.CTE.IDForStorage] = &CTEStorages{ResTbl: resTbl, IterInTbl: iterInTbl} - - // Build recursive part. - var recursiveExec exec.Executor - if v.RecurPlan != nil { - recursiveExec = b.build(v.RecurPlan) - if b.err != nil { - return nil - } - corCols = append(corCols, plannercore.ExtractOuterApplyCorrelatedCols(v.RecurPlan)...) 
- } - - var sel []int - if v.CTE.IsDistinct { - sel = make([]int, chkSize) - for i := 0; i < chkSize; i++ { - sel[i] = i - } - } - - var corColHashCodes [][]byte - for _, corCol := range corCols { - corColHashCodes = append(corColHashCodes, getCorColHashCode(corCol)) - } - - producer = &cteProducer{ - ctx: b.ctx, - seedExec: seedExec, - recursiveExec: recursiveExec, - resTbl: resTbl, - iterInTbl: iterInTbl, - isDistinct: v.CTE.IsDistinct, - sel: sel, - hasLimit: v.CTE.HasLimit, - limitBeg: v.CTE.LimitBeg, - limitEnd: v.CTE.LimitEnd, - corCols: corCols, - corColHashCodes: corColHashCodes, - } - storageMap[v.CTE.IDForStorage].Producer = producer - } - - return &CTEExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - producer: producer, - } -} - -func (b *executorBuilder) buildCTETableReader(v *plannercore.PhysicalCTETable) exec.Executor { - storageMap, ok := b.ctx.GetSessionVars().StmtCtx.CTEStorageMap.(map[int]*CTEStorages) - if !ok { - b.err = errors.New("type assertion for CTEStorageMap failed") - return nil - } - storages, ok := storageMap[v.IDForStorage] - if !ok { - b.err = errors.Errorf("iterInTbl should already be set up by CTEExec(id: %d)", v.IDForStorage) - return nil - } - return &CTETableReaderExec{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()), - iterInTbl: storages.IterInTbl, - chkIdx: 0, - } -} -func (b *executorBuilder) validCanReadTemporaryOrCacheTable(tbl *model.TableInfo) error { - err := b.validCanReadTemporaryTable(tbl) - if err != nil { - return err - } - return b.validCanReadCacheTable(tbl) -} - -func (b *executorBuilder) validCanReadCacheTable(tbl *model.TableInfo) error { - if tbl.TableCacheStatusType == model.TableCacheStatusDisable { - return nil - } - - sessionVars := b.ctx.GetSessionVars() - - // Temporary table can't switch into cache table. so the following code will not cause confusion - if sessionVars.TxnCtx.IsStaleness || b.isStaleness { - return errors.Trace(errors.New("can not stale read cache table")) - } - - return nil -} - -func (b *executorBuilder) validCanReadTemporaryTable(tbl *model.TableInfo) error { - if tbl.TempTableType == model.TempTableNone { - return nil - } - - // Some tools like dumpling use history read to dump all table's records and will be fail if we return an error. 
-	// So we do not check SnapshotTS here
-
-	sessionVars := b.ctx.GetSessionVars()
-
-	if tbl.TempTableType == model.TempTableLocal && sessionVars.SnapshotTS != 0 {
-		return errors.New("can not read local temporary table when 'tidb_snapshot' is set")
-	}
-
-	if sessionVars.TxnCtx.IsStaleness || b.isStaleness {
-		return errors.New("can not stale read temporary table")
-	}
-
-	return nil
-}
-
-func (b *executorBuilder) getCacheTable(tblInfo *model.TableInfo, startTS uint64) kv.MemBuffer {
-	tbl, ok := b.is.TableByID(tblInfo.ID)
-	if !ok {
-		b.err = errors.Trace(infoschema.ErrTableNotExists.GenWithStackByArgs(b.ctx.GetSessionVars().CurrentDB, tblInfo.Name))
-		return nil
-	}
-	sessVars := b.ctx.GetSessionVars()
-	leaseDuration := time.Duration(variable.TableCacheLease.Load()) * time.Second
-	cacheData, loading := tbl.(table.CachedTable).TryReadFromCache(startTS, leaseDuration)
-	if cacheData != nil {
-		sessVars.StmtCtx.ReadFromTableCache = true
-		return cacheData
-	} else if loading {
-		return nil
-	}
-	if !b.ctx.GetSessionVars().StmtCtx.InExplainStmt && !b.inDeleteStmt && !b.inUpdateStmt {
-		tbl.(table.CachedTable).UpdateLockForRead(context.Background(), b.ctx.GetStore(), startTS, leaseDuration)
-	}
-	return nil
-}
-
-func (b *executorBuilder) buildCompactTable(v *plannercore.CompactTable) exec.Executor {
-	if v.ReplicaKind != ast.CompactReplicaKindTiFlash && v.ReplicaKind != ast.CompactReplicaKindAll {
-		b.err = errors.Errorf("compact %v replica is not supported", strings.ToLower(string(v.ReplicaKind)))
-		return nil
-	}
-
-	store := b.ctx.GetStore()
-	tikvStore, ok := store.(tikv.Storage)
-	if !ok {
-		b.err = errors.New("compact tiflash replica can only run with tikv compatible storage")
-		return nil
-	}
-
-	var partitionIDs []int64
-	if v.PartitionNames != nil {
-		if v.TableInfo.Partition == nil {
-			b.err = errors.Errorf("table:%s is not a partition table, but user specify partition name list:%+v", v.TableInfo.Name.O, v.PartitionNames)
-			return nil
-		}
-		// use map to avoid FindPartitionDefinitionByName
-		partitionMap := map[string]int64{}
-		for _, partition := range v.TableInfo.Partition.Definitions {
-			partitionMap[partition.Name.L] = partition.ID
-		}
-
-		for _, partitionName := range v.PartitionNames {
-			partitionID, ok := partitionMap[partitionName.L]
-			if !ok {
-				b.err = table.ErrUnknownPartition.GenWithStackByArgs(partitionName.O, v.TableInfo.Name.O)
-				return nil
-			}
-			partitionIDs = append(partitionIDs, partitionID)
-		}
-	}
-
-	return &CompactTableTiFlashExec{
-		BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID()),
-		tableInfo: v.TableInfo,
-		partitionIDs: partitionIDs,
-		tikvStore: tikvStore,
-	}
-}
-
-func (b *executorBuilder) buildAdminShowBDRRole(v *plannercore.AdminShowBDRRole) exec.Executor {
-	return &AdminShowBDRRoleExec{BaseExecutor: exec.NewBaseExecutor(b.ctx, v.Schema(), v.ID())}
-}
diff --git a/pkg/executor/checksum.go b/pkg/executor/checksum.go
index 2fde5a48869fb..c4d349359a74d 100644
--- a/pkg/executor/checksum.go
+++ b/pkg/executor/checksum.go
@@ -159,7 +159,7 @@ func (e *ChecksumTableExec) handleChecksumRequest(req *kv.Request) (resp *tipb.C
 		if err1 := res.Close(); err1 != nil {
 			err = err1
 		}
-		failpoint.Eval(_curpkg_("afterHandleChecksumRequest"))
+		failpoint.Inject("afterHandleChecksumRequest", nil)
 	}()
 
 	resp = &tipb.ChecksumResponse{}
diff --git a/pkg/executor/checksum.go__failpoint_stash__ b/pkg/executor/checksum.go__failpoint_stash__
deleted file mode 100644
index c4d349359a74d..0000000000000
--- a/pkg/executor/checksum.go__failpoint_stash__
+++ /dev/null
@@ -1,336 +0,0 @@ -// Copyright 2018 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package executor - -import ( - "context" - "strconv" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/distsql" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/ranger" - "github.com/pingcap/tipb/go-tipb" - "go.uber.org/zap" -) - -var _ exec.Executor = &ChecksumTableExec{} - -// ChecksumTableExec represents ChecksumTable executor. -type ChecksumTableExec struct { - exec.BaseExecutor - - tables map[int64]*checksumContext // tableID -> checksumContext - done bool -} - -// Open implements the Executor Open interface. -func (e *ChecksumTableExec) Open(ctx context.Context) error { - if err := e.BaseExecutor.Open(ctx); err != nil { - return err - } - - concurrency, err := getChecksumTableConcurrency(e.Ctx()) - if err != nil { - return err - } - - tasks, err := e.buildTasks() - if err != nil { - return err - } - - taskCh := make(chan *checksumTask, len(tasks)) - resultCh := make(chan *checksumResult, len(tasks)) - for i := 0; i < concurrency; i++ { - go e.checksumWorker(taskCh, resultCh) - } - - for _, task := range tasks { - taskCh <- task - } - close(taskCh) - - for i := 0; i < len(tasks); i++ { - result := <-resultCh - if result.err != nil { - err = result.err - logutil.Logger(ctx).Error("checksum failed", zap.Error(err)) - continue - } - logutil.Logger(ctx).Info( - "got one checksum result", - zap.Int64("tableID", result.tableID), - zap.Int64("physicalTableID", result.physicalTableID), - zap.Int64("indexID", result.indexID), - zap.Uint64("checksum", result.response.Checksum), - zap.Uint64("totalKvs", result.response.TotalKvs), - zap.Uint64("totalBytes", result.response.TotalBytes), - ) - e.handleResult(result) - } - if err != nil { - return err - } - - return nil -} - -// Next implements the Executor Next interface. -func (e *ChecksumTableExec) Next(_ context.Context, req *chunk.Chunk) error { - req.Reset() - if e.done { - return nil - } - for _, t := range e.tables { - req.AppendString(0, t.dbInfo.Name.O) - req.AppendString(1, t.tableInfo.Name.O) - req.AppendUint64(2, t.response.Checksum) - req.AppendUint64(3, t.response.TotalKvs) - req.AppendUint64(4, t.response.TotalBytes) - } - e.done = true - return nil -} - -func (e *ChecksumTableExec) buildTasks() ([]*checksumTask, error) { - allTasks := make([][]*checksumTask, 0, len(e.tables)) - taskCnt := 0 - for _, t := range e.tables { - tasks, err := t.buildTasks(e.Ctx()) - if err != nil { - return nil, err - } - allTasks = append(allTasks, tasks) - taskCnt += len(tasks) - } - ret := make([]*checksumTask, 0, taskCnt) - for _, task := range allTasks { - ret = append(ret, task...) 
- } - return ret, nil -} - -func (e *ChecksumTableExec) handleResult(result *checksumResult) { - table := e.tables[result.tableID] - table.handleResponse(result.response) -} - -func (e *ChecksumTableExec) checksumWorker(taskCh <-chan *checksumTask, resultCh chan<- *checksumResult) { - for task := range taskCh { - result := &checksumResult{ - tableID: task.tableID, - physicalTableID: task.physicalTableID, - indexID: task.indexID, - } - result.response, result.err = e.handleChecksumRequest(task.request) - resultCh <- result - } -} - -func (e *ChecksumTableExec) handleChecksumRequest(req *kv.Request) (resp *tipb.ChecksumResponse, err error) { - if err = e.Ctx().GetSessionVars().SQLKiller.HandleSignal(); err != nil { - return nil, err - } - ctx := distsql.WithSQLKvExecCounterInterceptor(context.TODO(), e.Ctx().GetSessionVars().StmtCtx.KvExecCounter) - res, err := distsql.Checksum(ctx, e.Ctx().GetClient(), req, e.Ctx().GetSessionVars().KVVars) - if err != nil { - return nil, err - } - defer func() { - if err1 := res.Close(); err1 != nil { - err = err1 - } - failpoint.Inject("afterHandleChecksumRequest", nil) - }() - - resp = &tipb.ChecksumResponse{} - - for { - data, err := res.NextRaw(ctx) - if err != nil { - return nil, err - } - if data == nil { - break - } - checksum := &tipb.ChecksumResponse{} - if err = checksum.Unmarshal(data); err != nil { - return nil, err - } - updateChecksumResponse(resp, checksum) - if err = e.Ctx().GetSessionVars().SQLKiller.HandleSignal(); err != nil { - return nil, err - } - } - - return resp, nil -} - -type checksumTask struct { - tableID int64 - physicalTableID int64 - indexID int64 - request *kv.Request -} - -type checksumResult struct { - err error - tableID int64 - physicalTableID int64 - indexID int64 - response *tipb.ChecksumResponse -} - -type checksumContext struct { - dbInfo *model.DBInfo - tableInfo *model.TableInfo - startTs uint64 - response *tipb.ChecksumResponse -} - -func newChecksumContext(db *model.DBInfo, table *model.TableInfo, startTs uint64) *checksumContext { - return &checksumContext{ - dbInfo: db, - tableInfo: table, - startTs: startTs, - response: &tipb.ChecksumResponse{}, - } -} - -func (c *checksumContext) buildTasks(ctx sessionctx.Context) ([]*checksumTask, error) { - var partDefs []model.PartitionDefinition - if part := c.tableInfo.Partition; part != nil { - partDefs = part.Definitions - } - - reqs := make([]*checksumTask, 0, (len(c.tableInfo.Indices)+1)*(len(partDefs)+1)) - if err := c.appendRequest4PhysicalTable(ctx, c.tableInfo.ID, c.tableInfo.ID, &reqs); err != nil { - return nil, err - } - - for _, partDef := range partDefs { - if err := c.appendRequest4PhysicalTable(ctx, c.tableInfo.ID, partDef.ID, &reqs); err != nil { - return nil, err - } - } - - return reqs, nil -} - -func (c *checksumContext) appendRequest4PhysicalTable( - ctx sessionctx.Context, - tableID int64, - physicalTableID int64, - reqs *[]*checksumTask, -) error { - req, err := c.buildTableRequest(ctx, physicalTableID) - if err != nil { - return err - } - - *reqs = append(*reqs, &checksumTask{ - tableID: tableID, - physicalTableID: physicalTableID, - indexID: -1, - request: req, - }) - for _, indexInfo := range c.tableInfo.Indices { - if indexInfo.State != model.StatePublic { - continue - } - req, err = c.buildIndexRequest(ctx, physicalTableID, indexInfo) - if err != nil { - return err - } - *reqs = append(*reqs, &checksumTask{ - tableID: tableID, - physicalTableID: physicalTableID, - indexID: indexInfo.ID, - request: req, - }) - } - - return nil -} - -func (c 
*checksumContext) buildTableRequest(ctx sessionctx.Context, physicalTableID int64) (*kv.Request, error) {
-	checksum := &tipb.ChecksumRequest{
-		ScanOn: tipb.ChecksumScanOn_Table,
-		Algorithm: tipb.ChecksumAlgorithm_Crc64_Xor,
-	}
-
-	var ranges []*ranger.Range
-	if c.tableInfo.IsCommonHandle {
-		ranges = ranger.FullNotNullRange()
-	} else {
-		ranges = ranger.FullIntRange(false)
-	}
-
-	var builder distsql.RequestBuilder
-	builder.SetResourceGroupTagger(ctx.GetSessionVars().StmtCtx.GetResourceGroupTagger())
-	return builder.SetHandleRanges(ctx.GetDistSQLCtx(), physicalTableID, c.tableInfo.IsCommonHandle, ranges).
-		SetChecksumRequest(checksum).
-		SetStartTS(c.startTs).
-		SetConcurrency(ctx.GetSessionVars().DistSQLScanConcurrency()).
-		SetResourceGroupName(ctx.GetSessionVars().StmtCtx.ResourceGroupName).
-		SetExplicitRequestSourceType(ctx.GetSessionVars().ExplicitRequestSourceType).
-		Build()
-}
-
-func (c *checksumContext) buildIndexRequest(ctx sessionctx.Context, physicalTableID int64, indexInfo *model.IndexInfo) (*kv.Request, error) {
-	checksum := &tipb.ChecksumRequest{
-		ScanOn: tipb.ChecksumScanOn_Index,
-		Algorithm: tipb.ChecksumAlgorithm_Crc64_Xor,
-	}
-
-	ranges := ranger.FullRange()
-
-	var builder distsql.RequestBuilder
-	builder.SetResourceGroupTagger(ctx.GetSessionVars().StmtCtx.GetResourceGroupTagger())
-	return builder.SetIndexRanges(ctx.GetDistSQLCtx(), physicalTableID, indexInfo.ID, ranges).
-		SetChecksumRequest(checksum).
-		SetStartTS(c.startTs).
-		SetConcurrency(ctx.GetSessionVars().DistSQLScanConcurrency()).
-		SetResourceGroupName(ctx.GetSessionVars().StmtCtx.ResourceGroupName).
-		SetExplicitRequestSourceType(ctx.GetSessionVars().ExplicitRequestSourceType).
-		Build()
-}
-
-func (c *checksumContext) handleResponse(update *tipb.ChecksumResponse) {
-	updateChecksumResponse(c.response, update)
-}
-
-func getChecksumTableConcurrency(ctx sessionctx.Context) (int, error) {
-	sessionVars := ctx.GetSessionVars()
-	concurrency, err := sessionVars.GetSessionOrGlobalSystemVar(context.Background(), variable.TiDBChecksumTableConcurrency)
-	if err != nil {
-		return 0, err
-	}
-	c, err := strconv.ParseInt(concurrency, 10, 64)
-	return int(c), err
-}
-
-func updateChecksumResponse(resp, update *tipb.ChecksumResponse) {
-	resp.Checksum ^= update.Checksum
-	resp.TotalKvs += update.TotalKvs
-	resp.TotalBytes += update.TotalBytes
-}
diff --git a/pkg/executor/compiler.go b/pkg/executor/compiler.go
index ea0a6cf1d4b7d..8771753e78626 100644
--- a/pkg/executor/compiler.go
+++ b/pkg/executor/compiler.go
@@ -75,14 +75,14 @@ func (c *Compiler) Compile(ctx context.Context, stmtNode ast.StmtNode) (_ *ExecS
 		return nil, err
 	}
 
-	if _, _err_ := failpoint.Eval(_curpkg_("assertTxnManagerInCompile")); _err_ == nil {
+	failpoint.Inject("assertTxnManagerInCompile", func() {
 		sessiontxn.RecordAssert(c.Ctx, "assertTxnManagerInCompile", true)
 		sessiontxn.AssertTxnManagerInfoSchema(c.Ctx, ret.InfoSchema)
 		if ret.LastSnapshotTS != 0 {
 			staleread.AssertStmtStaleness(c.Ctx, true)
 			sessiontxn.AssertTxnManagerReadTS(c.Ctx, ret.LastSnapshotTS)
 		}
-	}
+	})
 
 	is := sessiontxn.GetTxnManager(c.Ctx).GetTxnInfoSchema()
 	sessVars := c.Ctx.GetSessionVars()
@@ -101,9 +101,9 @@ func (c *Compiler) Compile(ctx context.Context, stmtNode ast.StmtNode) (_ *ExecS
 		return nil, err
 	}
 
-	if val, _err_ := failpoint.Eval(_curpkg_("assertStmtCtxIsStaleness")); _err_ == nil {
+	failpoint.Inject("assertStmtCtxIsStaleness", func(val failpoint.Value) {
 		staleread.AssertStmtStaleness(c.Ctx, val.(bool))
-	}
+	})
 
 	if preparedObj != nil {
CountStmtNode(preparedObj.PreparedAst.Stmt, sessVars.InRestrictedSQL, stmtCtx.ResourceGroupName) diff --git a/pkg/executor/compiler.go__failpoint_stash__ b/pkg/executor/compiler.go__failpoint_stash__ deleted file mode 100644 index 8771753e78626..0000000000000 --- a/pkg/executor/compiler.go__failpoint_stash__ +++ /dev/null @@ -1,568 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package executor - -import ( - "context" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/planner" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessiontxn" - "github.com/pingcap/tidb/pkg/sessiontxn/staleread" - "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/tracing" - "go.uber.org/zap" -) - -// Compiler compiles an ast.StmtNode to a physical plan. -type Compiler struct { - Ctx sessionctx.Context -} - -// Compile compiles an ast.StmtNode to a physical plan. -func (c *Compiler) Compile(ctx context.Context, stmtNode ast.StmtNode) (_ *ExecStmt, err error) { - r, ctx := tracing.StartRegionEx(ctx, "executor.Compile") - defer r.End() - - defer func() { - r := recover() - if r == nil { - return - } - recoveredErr, ok := r.(error) - if !ok || !(exeerrors.ErrMemoryExceedForQuery.Equal(recoveredErr) || - exeerrors.ErrMemoryExceedForInstance.Equal(recoveredErr) || - exeerrors.ErrQueryInterrupted.Equal(recoveredErr) || - exeerrors.ErrMaxExecTimeExceeded.Equal(recoveredErr)) { - panic(r) - } - err = recoveredErr - logutil.Logger(ctx).Error("compile SQL panic", zap.String("SQL", stmtNode.Text()), zap.Stack("stack"), zap.Any("recover", r)) - }() - - c.Ctx.GetSessionVars().StmtCtx.IsReadOnly = plannercore.IsReadOnly(stmtNode, c.Ctx.GetSessionVars()) - - // Do preprocess and validate. 
- ret := &plannercore.PreprocessorReturn{} - err = plannercore.Preprocess( - ctx, - c.Ctx, - stmtNode, - plannercore.WithPreprocessorReturn(ret), - plannercore.InitTxnContextProvider, - ) - if err != nil { - return nil, err - } - - failpoint.Inject("assertTxnManagerInCompile", func() { - sessiontxn.RecordAssert(c.Ctx, "assertTxnManagerInCompile", true) - sessiontxn.AssertTxnManagerInfoSchema(c.Ctx, ret.InfoSchema) - if ret.LastSnapshotTS != 0 { - staleread.AssertStmtStaleness(c.Ctx, true) - sessiontxn.AssertTxnManagerReadTS(c.Ctx, ret.LastSnapshotTS) - } - }) - - is := sessiontxn.GetTxnManager(c.Ctx).GetTxnInfoSchema() - sessVars := c.Ctx.GetSessionVars() - stmtCtx := sessVars.StmtCtx - // handle the execute statement - var preparedObj *plannercore.PlanCacheStmt - - if execStmt, ok := stmtNode.(*ast.ExecuteStmt); ok { - if preparedObj, err = plannercore.GetPreparedStmt(execStmt, sessVars); err != nil { - return nil, err - } - } - // Build the final physical plan. - finalPlan, names, err := planner.Optimize(ctx, c.Ctx, stmtNode, is) - if err != nil { - return nil, err - } - - failpoint.Inject("assertStmtCtxIsStaleness", func(val failpoint.Value) { - staleread.AssertStmtStaleness(c.Ctx, val.(bool)) - }) - - if preparedObj != nil { - CountStmtNode(preparedObj.PreparedAst.Stmt, sessVars.InRestrictedSQL, stmtCtx.ResourceGroupName) - } else { - CountStmtNode(stmtNode, sessVars.InRestrictedSQL, stmtCtx.ResourceGroupName) - } - var lowerPriority bool - if c.Ctx.GetSessionVars().StmtCtx.Priority == mysql.NoPriority { - lowerPriority = needLowerPriority(finalPlan) - } - stmtCtx.SetPlan(finalPlan) - stmt := &ExecStmt{ - GoCtx: ctx, - InfoSchema: is, - Plan: finalPlan, - LowerPriority: lowerPriority, - Text: stmtNode.Text(), - StmtNode: stmtNode, - Ctx: c.Ctx, - OutputNames: names, - } - // Use cached plan if possible. - if preparedObj != nil && plannercore.IsSafeToReusePointGetExecutor(c.Ctx, is, preparedObj) { - if exec, isExec := finalPlan.(*plannercore.Execute); isExec { - if pointPlan, isPointPlan := exec.Plan.(*plannercore.PointGetPlan); isPointPlan { - stmt.PsStmt, stmt.Plan = preparedObj, pointPlan // notify to re-use the cached plan - } - } - } - - // Perform optimization and initialization related to the transaction level. - if err = sessiontxn.AdviseOptimizeWithPlanAndThenWarmUp(c.Ctx, stmt.Plan); err != nil { - return nil, err - } - - return stmt, nil -} - -// needLowerPriority checks whether it's needed to lower the execution priority -// of a query. -// If the estimated output row count of any operator in the physical plan tree -// is greater than the specific threshold, we'll set it to lowPriority when -// sending it to the coprocessor. 
-func needLowerPriority(p base.Plan) bool { - switch x := p.(type) { - case base.PhysicalPlan: - return isPhysicalPlanNeedLowerPriority(x) - case *plannercore.Execute: - return needLowerPriority(x.Plan) - case *plannercore.Insert: - if x.SelectPlan != nil { - return isPhysicalPlanNeedLowerPriority(x.SelectPlan) - } - case *plannercore.Delete: - if x.SelectPlan != nil { - return isPhysicalPlanNeedLowerPriority(x.SelectPlan) - } - case *plannercore.Update: - if x.SelectPlan != nil { - return isPhysicalPlanNeedLowerPriority(x.SelectPlan) - } - } - return false -} - -func isPhysicalPlanNeedLowerPriority(p base.PhysicalPlan) bool { - expensiveThreshold := int64(config.GetGlobalConfig().Log.ExpensiveThreshold) - if int64(p.StatsCount()) > expensiveThreshold { - return true - } - - for _, child := range p.Children() { - if isPhysicalPlanNeedLowerPriority(child) { - return true - } - } - - return false -} - -// CountStmtNode records the number of statements with the same type. -func CountStmtNode(stmtNode ast.StmtNode, inRestrictedSQL bool, resourceGroup string) { - if inRestrictedSQL { - return - } - - typeLabel := ast.GetStmtLabel(stmtNode) - - if config.GetGlobalConfig().Status.RecordQPSbyDB || config.GetGlobalConfig().Status.RecordDBLabel { - dbLabels := getStmtDbLabel(stmtNode) - switch { - case config.GetGlobalConfig().Status.RecordQPSbyDB: - for dbLabel := range dbLabels { - metrics.DbStmtNodeCounter.WithLabelValues(dbLabel, typeLabel).Inc() - } - case config.GetGlobalConfig().Status.RecordDBLabel: - for dbLabel := range dbLabels { - metrics.StmtNodeCounter.WithLabelValues(typeLabel, dbLabel, resourceGroup).Inc() - } - } - } else { - metrics.StmtNodeCounter.WithLabelValues(typeLabel, "", resourceGroup).Inc() - } -} - -func getStmtDbLabel(stmtNode ast.StmtNode) map[string]struct{} { - dbLabelSet := make(map[string]struct{}) - - switch x := stmtNode.(type) { - case *ast.AlterTableStmt: - if x.Table != nil { - dbLabel := x.Table.Schema.O - dbLabelSet[dbLabel] = struct{}{} - } - case *ast.CreateIndexStmt: - if x.Table != nil { - dbLabel := x.Table.Schema.O - dbLabelSet[dbLabel] = struct{}{} - } - case *ast.CreateTableStmt: - if x.Table != nil { - dbLabel := x.Table.Schema.O - dbLabelSet[dbLabel] = struct{}{} - } - case *ast.InsertStmt: - var dbLabels []string - if x.Table != nil { - dbLabels = getDbFromResultNode(x.Table.TableRefs) - for _, db := range dbLabels { - dbLabelSet[db] = struct{}{} - } - } - dbLabels = getDbFromResultNode(x.Select) - for _, db := range dbLabels { - dbLabelSet[db] = struct{}{} - } - case *ast.DropIndexStmt: - if x.Table != nil { - dbLabel := x.Table.Schema.O - dbLabelSet[dbLabel] = struct{}{} - } - case *ast.TruncateTableStmt: - if x.Table != nil { - dbLabel := x.Table.Schema.O - dbLabelSet[dbLabel] = struct{}{} - } - case *ast.RepairTableStmt: - if x.Table != nil { - dbLabel := x.Table.Schema.O - dbLabelSet[dbLabel] = struct{}{} - } - case *ast.FlashBackTableStmt: - if x.Table != nil { - dbLabel := x.Table.Schema.O - dbLabelSet[dbLabel] = struct{}{} - } - case *ast.RecoverTableStmt: - if x.Table != nil { - dbLabel := x.Table.Schema.O - dbLabelSet[dbLabel] = struct{}{} - } - case *ast.CreateViewStmt: - if x.ViewName != nil { - dbLabel := x.ViewName.Schema.O - dbLabelSet[dbLabel] = struct{}{} - } - case *ast.RenameTableStmt: - tables := x.TableToTables - for _, table := range tables { - if table.OldTable != nil { - dbLabel := table.OldTable.Schema.O - if _, ok := dbLabelSet[dbLabel]; !ok { - dbLabelSet[dbLabel] = struct{}{} - } - } - } - case *ast.DropTableStmt: - 
tables := x.Tables - for _, table := range tables { - dbLabel := table.Schema.O - if _, ok := dbLabelSet[dbLabel]; !ok { - dbLabelSet[dbLabel] = struct{}{} - } - } - case *ast.SelectStmt: - dbLabels := getDbFromResultNode(x) - for _, db := range dbLabels { - dbLabelSet[db] = struct{}{} - } - case *ast.SetOprStmt: - dbLabels := getDbFromResultNode(x) - for _, db := range dbLabels { - dbLabelSet[db] = struct{}{} - } - case *ast.UpdateStmt: - if x.TableRefs != nil { - dbLabels := getDbFromResultNode(x.TableRefs.TableRefs) - for _, db := range dbLabels { - dbLabelSet[db] = struct{}{} - } - } - case *ast.DeleteStmt: - if x.TableRefs != nil { - dbLabels := getDbFromResultNode(x.TableRefs.TableRefs) - for _, db := range dbLabels { - dbLabelSet[db] = struct{}{} - } - } - case *ast.CallStmt: - if x.Procedure != nil { - dbLabel := x.Procedure.Schema.O - dbLabelSet[dbLabel] = struct{}{} - } - case *ast.ShowStmt: - dbLabelSet[x.DBName] = struct{}{} - if x.Table != nil { - dbLabel := x.Table.Schema.O - dbLabelSet[dbLabel] = struct{}{} - } - case *ast.LoadDataStmt: - if x.Table != nil { - dbLabel := x.Table.Schema.O - dbLabelSet[dbLabel] = struct{}{} - } - case *ast.ImportIntoStmt: - if x.Table != nil { - dbLabel := x.Table.Schema.O - dbLabelSet[dbLabel] = struct{}{} - } - case *ast.SplitRegionStmt: - if x.Table != nil { - dbLabel := x.Table.Schema.O - dbLabelSet[dbLabel] = struct{}{} - } - case *ast.NonTransactionalDMLStmt: - if x.ShardColumn != nil { - dbLabel := x.ShardColumn.Schema.O - dbLabelSet[dbLabel] = struct{}{} - } - case *ast.AnalyzeTableStmt: - tables := x.TableNames - for _, table := range tables { - dbLabel := table.Schema.O - if _, ok := dbLabelSet[dbLabel]; !ok { - dbLabelSet[dbLabel] = struct{}{} - } - } - case *ast.DropStatsStmt: - tables := x.Tables - for _, table := range tables { - dbLabel := table.Schema.O - if _, ok := dbLabelSet[dbLabel]; !ok { - dbLabelSet[dbLabel] = struct{}{} - } - } - case *ast.AdminStmt: - tables := x.Tables - for _, table := range tables { - dbLabel := table.Schema.O - if _, ok := dbLabelSet[dbLabel]; !ok { - dbLabelSet[dbLabel] = struct{}{} - } - } - case *ast.UseStmt: - if _, ok := dbLabelSet[x.DBName]; !ok { - dbLabelSet[x.DBName] = struct{}{} - } - case *ast.FlushStmt: - tables := x.Tables - for _, table := range tables { - dbLabel := table.Schema.O - if _, ok := dbLabelSet[dbLabel]; !ok { - dbLabelSet[dbLabel] = struct{}{} - } - } - case *ast.CompactTableStmt: - if x.Table != nil { - dbLabel := x.Table.Schema.O - dbLabelSet[dbLabel] = struct{}{} - } - case *ast.CreateBindingStmt: - var resNode ast.ResultSetNode - var tableRef *ast.TableRefsClause - if x.OriginNode != nil { - switch n := x.OriginNode.(type) { - case *ast.SelectStmt: - tableRef = n.From - case *ast.DeleteStmt: - tableRef = n.TableRefs - case *ast.UpdateStmt: - tableRef = n.TableRefs - case *ast.InsertStmt: - tableRef = n.Table - } - if tableRef != nil { - resNode = tableRef.TableRefs - } else { - resNode = nil - } - dbLabels := getDbFromResultNode(resNode) - for _, db := range dbLabels { - dbLabelSet[db] = struct{}{} - } - } - if len(dbLabelSet) == 0 && x.HintedNode != nil { - switch n := x.HintedNode.(type) { - case *ast.SelectStmt: - tableRef = n.From - case *ast.DeleteStmt: - tableRef = n.TableRefs - case *ast.UpdateStmt: - tableRef = n.TableRefs - case *ast.InsertStmt: - tableRef = n.Table - } - if tableRef != nil { - resNode = tableRef.TableRefs - } else { - resNode = nil - } - dbLabels := getDbFromResultNode(resNode) - for _, db := range dbLabels { - dbLabelSet[db] = struct{}{} - 
} - } - case *ast.DropBindingStmt: - var resNode ast.ResultSetNode - var tableRef *ast.TableRefsClause - if x.OriginNode != nil { - switch n := x.OriginNode.(type) { - case *ast.SelectStmt: - tableRef = n.From - case *ast.DeleteStmt: - tableRef = n.TableRefs - case *ast.UpdateStmt: - tableRef = n.TableRefs - case *ast.InsertStmt: - tableRef = n.Table - } - if tableRef != nil { - resNode = tableRef.TableRefs - } else { - resNode = nil - } - dbLabels := getDbFromResultNode(resNode) - for _, db := range dbLabels { - dbLabelSet[db] = struct{}{} - } - } - if len(dbLabelSet) == 0 && x.HintedNode != nil { - switch n := x.HintedNode.(type) { - case *ast.SelectStmt: - tableRef = n.From - case *ast.DeleteStmt: - tableRef = n.TableRefs - case *ast.UpdateStmt: - tableRef = n.TableRefs - case *ast.InsertStmt: - tableRef = n.Table - } - if tableRef != nil { - resNode = tableRef.TableRefs - } else { - resNode = nil - } - dbLabels := getDbFromResultNode(resNode) - for _, db := range dbLabels { - dbLabelSet[db] = struct{}{} - } - } - case *ast.SetBindingStmt: - var resNode ast.ResultSetNode - var tableRef *ast.TableRefsClause - if x.OriginNode != nil { - switch n := x.OriginNode.(type) { - case *ast.SelectStmt: - tableRef = n.From - case *ast.DeleteStmt: - tableRef = n.TableRefs - case *ast.UpdateStmt: - tableRef = n.TableRefs - case *ast.InsertStmt: - tableRef = n.Table - } - if tableRef != nil { - resNode = tableRef.TableRefs - } else { - resNode = nil - } - dbLabels := getDbFromResultNode(resNode) - for _, db := range dbLabels { - dbLabelSet[db] = struct{}{} - } - } - - if len(dbLabelSet) == 0 && x.HintedNode != nil { - switch n := x.HintedNode.(type) { - case *ast.SelectStmt: - tableRef = n.From - case *ast.DeleteStmt: - tableRef = n.TableRefs - case *ast.UpdateStmt: - tableRef = n.TableRefs - case *ast.InsertStmt: - tableRef = n.Table - } - if tableRef != nil { - resNode = tableRef.TableRefs - } else { - resNode = nil - } - dbLabels := getDbFromResultNode(resNode) - for _, db := range dbLabels { - dbLabelSet[db] = struct{}{} - } - } - } - - // add "" db label - if len(dbLabelSet) == 0 { - dbLabelSet[""] = struct{}{} - } - - return dbLabelSet -} - -func getDbFromResultNode(resultNode ast.ResultSetNode) []string { // may have duplicate db name - var dbLabels []string - - if resultNode == nil { - return dbLabels - } - - switch x := resultNode.(type) { - case *ast.TableSource: - return getDbFromResultNode(x.Source) - case *ast.SelectStmt: - if x.From != nil { - return getDbFromResultNode(x.From.TableRefs) - } - case *ast.TableName: - if x.DBInfo != nil { - dbLabels = append(dbLabels, x.DBInfo.Name.O) - } - case *ast.Join: - if x.Left != nil { - dbs := getDbFromResultNode(x.Left) - if dbs != nil { - dbLabels = append(dbLabels, dbs...) - } - } - - if x.Right != nil { - dbs := getDbFromResultNode(x.Right) - if dbs != nil { - dbLabels = append(dbLabels, dbs...) - } - } - } - - return dbLabels -} diff --git a/pkg/executor/cte.go b/pkg/executor/cte.go index bd92fa3796176..4f57b89d3bd92 100644 --- a/pkg/executor/cte.go +++ b/pkg/executor/cte.go @@ -134,13 +134,13 @@ func (e *CTEExec) Close() (firstErr error) { e.producer.resTbl.Lock() defer e.producer.resTbl.Unlock() if !e.producer.closed { - if v, _err_ := failpoint.Eval(_curpkg_("mock_cte_exec_panic_avoid_deadlock")); _err_ == nil { + failpoint.Inject("mock_cte_exec_panic_avoid_deadlock", func(v failpoint.Value) { ok := v.(bool) if ok { // mock an oom panic, returning ErrMemoryExceedForQuery for error identification in recovery work. 
panic(exeerrors.ErrMemoryExceedForQuery) } - } + }) // closeProducer() only close seedExec and recursiveExec, will not touch resTbl. // It means you can still read resTbl after call closeProducer(). // You can even call all three functions(openProducer/produce/closeProducer) in CTEExec.Next(). @@ -350,7 +350,7 @@ func (p *cteProducer) produce(ctx context.Context) (err error) { iterOutAction = setupCTEStorageTracker(p.iterOutTbl, p.ctx, p.memTracker, p.diskTracker) } - if val, _err_ := failpoint.Eval(_curpkg_("testCTEStorageSpill")); _err_ == nil { + failpoint.Inject("testCTEStorageSpill", func(val failpoint.Value) { if val.(bool) && variable.EnableTmpStorageOnOOM.Load() { defer resAction.WaitForTest() defer iterInAction.WaitForTest() @@ -358,7 +358,7 @@ func (p *cteProducer) produce(ctx context.Context) (err error) { defer iterOutAction.WaitForTest() } } - } + }) if err = p.computeSeedPart(ctx); err != nil { p.resTbl.SetError(err) @@ -378,7 +378,7 @@ func (p *cteProducer) computeSeedPart(ctx context.Context) (err error) { err = util.GetRecoverError(r) } }() - failpoint.Eval(_curpkg_("testCTESeedPanic")) + failpoint.Inject("testCTESeedPanic", nil) p.curIter = 0 p.iterInTbl.SetIter(p.curIter) chks := make([]*chunk.Chunk, 0, 10) @@ -417,7 +417,7 @@ func (p *cteProducer) computeRecursivePart(ctx context.Context) (err error) { err = util.GetRecoverError(r) } }() - failpoint.Eval(_curpkg_("testCTERecursivePanic")) + failpoint.Inject("testCTERecursivePanic", nil) if p.recursiveExec == nil || p.iterInTbl.NumChunks() == 0 { return } @@ -442,14 +442,14 @@ func (p *cteProducer) computeRecursivePart(ctx context.Context) (err error) { p.logTbls(ctx, err, iterNum, zapcore.DebugLevel) } iterNum++ - if maxIter, _err_ := failpoint.Eval(_curpkg_("assertIterTableSpillToDisk")); _err_ == nil { + failpoint.Inject("assertIterTableSpillToDisk", func(maxIter failpoint.Value) { if iterNum > 0 && iterNum < uint64(maxIter.(int)) && err == nil { if p.iterInTbl.GetDiskBytes() == 0 && p.iterOutTbl.GetDiskBytes() == 0 && p.resTbl.GetDiskBytes() == 0 { p.logTbls(ctx, err, iterNum, zapcore.InfoLevel) panic("assert row container spill disk failed") } } - } + }) if err = p.setupTblsForNewIteration(); err != nil { return @@ -582,11 +582,11 @@ func setupCTEStorageTracker(tbl cteutil.Storage, ctx sessionctx.Context, parentM if variable.EnableTmpStorageOnOOM.Load() { actionSpill = tbl.ActionSpill() - if val, _err_ := failpoint.Eval(_curpkg_("testCTEStorageSpill")); _err_ == nil { + failpoint.Inject("testCTEStorageSpill", func(val failpoint.Value) { if val.(bool) { actionSpill = tbl.(*cteutil.StorageRC).ActionSpillForTest() } - } + }) ctx.GetSessionVars().MemTracker.FallbackOldAndSetNewAction(actionSpill) } return actionSpill diff --git a/pkg/executor/cte.go__failpoint_stash__ b/pkg/executor/cte.go__failpoint_stash__ deleted file mode 100644 index 4f57b89d3bd92..0000000000000 --- a/pkg/executor/cte.go__failpoint_stash__ +++ /dev/null @@ -1,770 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package executor - -import ( - "bytes" - "context" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/executor/join" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/cteutil" - "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" - "github.com/pingcap/tidb/pkg/util/disk" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" -) - -var _ exec.Executor = &CTEExec{} - -// CTEExec implements CTE. -// Following diagram describes how CTEExec works. -// -// `iterInTbl` is shared by `CTEExec` and `CTETableReaderExec`. -// `CTETableReaderExec` reads data from `iterInTbl`, -// and its output will be stored `iterOutTbl` by `CTEExec`. -// -// When an iteration ends, `CTEExec` will move all data from `iterOutTbl` into `iterInTbl`, -// which will be the input for new iteration. -// At the end of each iteration, data in `iterOutTbl` will also be added into `resTbl`. -// `resTbl` stores data of all iteration. -/* - +----------+ - write |iterOutTbl| - CTEExec ------------------->| | - | +----+-----+ - ------------- | write - | | v - other op other op +----------+ - (seed) (recursive) | resTbl | - ^ | | - | +----------+ - CTETableReaderExec - ^ - | read +----------+ - +---------------+iterInTbl | - | | - +----------+ -*/ -type CTEExec struct { - exec.BaseExecutor - - chkIdx int - producer *cteProducer - - // limit in recursive CTE. - cursor uint64 - meetFirstBatch bool -} - -// Open implements the Executor interface. -func (e *CTEExec) Open(ctx context.Context) (err error) { - e.reset() - if err := e.BaseExecutor.Open(ctx); err != nil { - return err - } - - e.producer.resTbl.Lock() - defer e.producer.resTbl.Unlock() - - if e.producer.checkAndUpdateCorColHashCode() { - e.producer.reset() - if err = e.producer.reopenTbls(); err != nil { - return err - } - } - if e.producer.openErr != nil { - return e.producer.openErr - } - if !e.producer.opened { - if err = e.producer.openProducer(ctx, e); err != nil { - return err - } - } - return nil -} - -// Next implements the Executor interface. -func (e *CTEExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { - e.producer.resTbl.Lock() - defer e.producer.resTbl.Unlock() - if !e.producer.resTbl.Done() { - if err = e.producer.produce(ctx); err != nil { - return err - } - } - return e.producer.getChunk(e, req) -} - -func setFirstErr(firstErr error, newErr error, msg string) error { - if newErr != nil { - logutil.BgLogger().Error("cte got error", zap.Any("err", newErr), zap.Any("extra msg", msg)) - if firstErr == nil { - firstErr = newErr - } - } - return firstErr -} - -// Close implements the Executor interface. -func (e *CTEExec) Close() (firstErr error) { - func() { - e.producer.resTbl.Lock() - defer e.producer.resTbl.Unlock() - if !e.producer.closed { - failpoint.Inject("mock_cte_exec_panic_avoid_deadlock", func(v failpoint.Value) { - ok := v.(bool) - if ok { - // mock an oom panic, returning ErrMemoryExceedForQuery for error identification in recovery work. 
- panic(exeerrors.ErrMemoryExceedForQuery) - } - }) - // closeProducer() only close seedExec and recursiveExec, will not touch resTbl. - // It means you can still read resTbl after call closeProducer(). - // You can even call all three functions(openProducer/produce/closeProducer) in CTEExec.Next(). - // Separating these three function calls is only to follow the abstraction of the volcano model. - err := e.producer.closeProducer() - firstErr = setFirstErr(firstErr, err, "close cte producer error") - } - }() - err := e.BaseExecutor.Close() - firstErr = setFirstErr(firstErr, err, "close cte children error") - return -} - -func (e *CTEExec) reset() { - e.chkIdx = 0 - e.cursor = 0 - e.meetFirstBatch = false -} - -type cteProducer struct { - // opened should be false when not open or open fail(a.k.a. openErr != nil) - opened bool - produced bool - closed bool - - // cteProducer is shared by multiple operators, so if the first operator tries to open - // and got error, the second should return open error directly instead of open again. - // Otherwise there may be resource leak because Close() only clean resource for the last Open(). - openErr error - - ctx sessionctx.Context - - seedExec exec.Executor - recursiveExec exec.Executor - - // `resTbl` and `iterInTbl` are shared by all CTEExec which reference to same the CTE. - // `iterInTbl` is also shared by CTETableReaderExec. - resTbl cteutil.Storage - iterInTbl cteutil.Storage - iterOutTbl cteutil.Storage - - hashTbl join.BaseHashTable - - // UNION ALL or UNION DISTINCT. - isDistinct bool - curIter int - hCtx *join.HashContext - sel []int - - // Limit related info. - hasLimit bool - limitBeg uint64 - limitEnd uint64 - - memTracker *memory.Tracker - diskTracker *disk.Tracker - - // Correlated Column. - corCols []*expression.CorrelatedColumn - corColHashCodes [][]byte -} - -func (p *cteProducer) openProducer(ctx context.Context, cteExec *CTEExec) (err error) { - defer func() { - p.openErr = err - if err == nil { - p.opened = true - } else { - p.opened = false - } - }() - if p.seedExec == nil { - return errors.New("seedExec for CTEExec is nil") - } - if err = exec.Open(ctx, p.seedExec); err != nil { - return err - } - - p.resetTracker() - p.memTracker = memory.NewTracker(cteExec.ID(), -1) - p.diskTracker = disk.NewTracker(cteExec.ID(), -1) - p.memTracker.AttachTo(p.ctx.GetSessionVars().StmtCtx.MemTracker) - p.diskTracker.AttachTo(p.ctx.GetSessionVars().StmtCtx.DiskTracker) - - if p.recursiveExec != nil { - if err = exec.Open(ctx, p.recursiveExec); err != nil { - return err - } - // For non-recursive CTE, the result will be put into resTbl directly. - // So no need to build iterOutTbl. - // Construct iterOutTbl in Open() instead of buildCTE(), because its destruct is in Close(). - recursiveTypes := p.recursiveExec.RetFieldTypes() - p.iterOutTbl = cteutil.NewStorageRowContainer(recursiveTypes, cteExec.MaxChunkSize()) - if err = p.iterOutTbl.OpenAndRef(); err != nil { - return err - } - } - - if p.isDistinct { - p.hashTbl = join.NewConcurrentMapHashTable() - p.hCtx = &join.HashContext{ - AllTypes: cteExec.RetFieldTypes(), - } - // We use all columns to compute hash. 
- p.hCtx.KeyColIdx = make([]int, len(p.hCtx.AllTypes)) - for i := range p.hCtx.KeyColIdx { - p.hCtx.KeyColIdx[i] = i - } - } - return nil -} - -func (p *cteProducer) closeProducer() (firstErr error) { - err := exec.Close(p.seedExec) - firstErr = setFirstErr(firstErr, err, "close seedExec err") - - if p.recursiveExec != nil { - err = exec.Close(p.recursiveExec) - firstErr = setFirstErr(firstErr, err, "close recursiveExec err") - - // `iterInTbl` and `resTbl` are shared by multiple operators, - // so will be closed when the SQL finishes. - if p.iterOutTbl != nil { - err = p.iterOutTbl.DerefAndClose() - firstErr = setFirstErr(firstErr, err, "deref iterOutTbl err") - } - } - // Reset to nil instead of calling Detach(), - // because ExplainExec still needs tracker to get mem usage info. - p.memTracker = nil - p.diskTracker = nil - p.closed = true - return -} - -func (p *cteProducer) getChunk(cteExec *CTEExec, req *chunk.Chunk) (err error) { - req.Reset() - if p.hasLimit { - return p.nextChunkLimit(cteExec, req) - } - if cteExec.chkIdx < p.resTbl.NumChunks() { - res, err := p.resTbl.GetChunk(cteExec.chkIdx) - if err != nil { - return err - } - // Need to copy chunk to make sure upper operator will not change chunk in resTbl. - // Also we ignore copying rows not selected, because some operators like Projection - // doesn't support swap column if chunk.sel is no nil. - req.SwapColumns(res.CopyConstructSel()) - cteExec.chkIdx++ - } - return nil -} - -func (p *cteProducer) nextChunkLimit(cteExec *CTEExec, req *chunk.Chunk) error { - if !cteExec.meetFirstBatch { - for cteExec.chkIdx < p.resTbl.NumChunks() { - res, err := p.resTbl.GetChunk(cteExec.chkIdx) - if err != nil { - return err - } - cteExec.chkIdx++ - numRows := uint64(res.NumRows()) - if newCursor := cteExec.cursor + numRows; newCursor >= p.limitBeg { - cteExec.meetFirstBatch = true - begInChk, endInChk := p.limitBeg-cteExec.cursor, numRows - if newCursor > p.limitEnd { - endInChk = p.limitEnd - cteExec.cursor - } - cteExec.cursor += endInChk - if begInChk == endInChk { - break - } - tmpChk := res.CopyConstructSel() - req.Append(tmpChk, int(begInChk), int(endInChk)) - return nil - } - cteExec.cursor += numRows - } - } - if cteExec.chkIdx < p.resTbl.NumChunks() && cteExec.cursor < p.limitEnd { - res, err := p.resTbl.GetChunk(cteExec.chkIdx) - if err != nil { - return err - } - cteExec.chkIdx++ - numRows := uint64(res.NumRows()) - if cteExec.cursor+numRows > p.limitEnd { - numRows = p.limitEnd - cteExec.cursor - req.Append(res.CopyConstructSel(), 0, int(numRows)) - } else { - req.SwapColumns(res.CopyConstructSel()) - } - cteExec.cursor += numRows - } - return nil -} - -func (p *cteProducer) produce(ctx context.Context) (err error) { - if p.resTbl.Error() != nil { - return p.resTbl.Error() - } - resAction := setupCTEStorageTracker(p.resTbl, p.ctx, p.memTracker, p.diskTracker) - iterInAction := setupCTEStorageTracker(p.iterInTbl, p.ctx, p.memTracker, p.diskTracker) - var iterOutAction *chunk.SpillDiskAction - if p.iterOutTbl != nil { - iterOutAction = setupCTEStorageTracker(p.iterOutTbl, p.ctx, p.memTracker, p.diskTracker) - } - - failpoint.Inject("testCTEStorageSpill", func(val failpoint.Value) { - if val.(bool) && variable.EnableTmpStorageOnOOM.Load() { - defer resAction.WaitForTest() - defer iterInAction.WaitForTest() - if iterOutAction != nil { - defer iterOutAction.WaitForTest() - } - } - }) - - if err = p.computeSeedPart(ctx); err != nil { - p.resTbl.SetError(err) - return err - } - if err = p.computeRecursivePart(ctx); err != nil { - 
p.resTbl.SetError(err) - return err - } - p.resTbl.SetDone() - return nil -} - -func (p *cteProducer) computeSeedPart(ctx context.Context) (err error) { - defer func() { - if r := recover(); r != nil && err == nil { - err = util.GetRecoverError(r) - } - }() - failpoint.Inject("testCTESeedPanic", nil) - p.curIter = 0 - p.iterInTbl.SetIter(p.curIter) - chks := make([]*chunk.Chunk, 0, 10) - for { - if p.limitDone(p.iterInTbl) { - break - } - chk := exec.TryNewCacheChunk(p.seedExec) - if err = exec.Next(ctx, p.seedExec, chk); err != nil { - return - } - if chk.NumRows() == 0 { - break - } - if chk, err = p.tryDedupAndAdd(chk, p.iterInTbl, p.hashTbl); err != nil { - return - } - chks = append(chks, chk) - } - // Initial resTbl is empty, so no need to deduplicate chk using resTbl. - // Just adding is ok. - for _, chk := range chks { - if err = p.resTbl.Add(chk); err != nil { - return - } - } - p.curIter++ - p.iterInTbl.SetIter(p.curIter) - - return -} - -func (p *cteProducer) computeRecursivePart(ctx context.Context) (err error) { - defer func() { - if r := recover(); r != nil && err == nil { - err = util.GetRecoverError(r) - } - }() - failpoint.Inject("testCTERecursivePanic", nil) - if p.recursiveExec == nil || p.iterInTbl.NumChunks() == 0 { - return - } - - if p.curIter > p.ctx.GetSessionVars().CTEMaxRecursionDepth { - return exeerrors.ErrCTEMaxRecursionDepth.GenWithStackByArgs(p.curIter) - } - - if p.limitDone(p.resTbl) { - return - } - - var iterNum uint64 - for { - chk := exec.TryNewCacheChunk(p.recursiveExec) - if err = exec.Next(ctx, p.recursiveExec, chk); err != nil { - return - } - if chk.NumRows() == 0 { - if iterNum%1000 == 0 { - // To avoid too many logs. - p.logTbls(ctx, err, iterNum, zapcore.DebugLevel) - } - iterNum++ - failpoint.Inject("assertIterTableSpillToDisk", func(maxIter failpoint.Value) { - if iterNum > 0 && iterNum < uint64(maxIter.(int)) && err == nil { - if p.iterInTbl.GetDiskBytes() == 0 && p.iterOutTbl.GetDiskBytes() == 0 && p.resTbl.GetDiskBytes() == 0 { - p.logTbls(ctx, err, iterNum, zapcore.InfoLevel) - panic("assert row container spill disk failed") - } - } - }) - - if err = p.setupTblsForNewIteration(); err != nil { - return - } - if p.limitDone(p.resTbl) { - break - } - if p.iterInTbl.NumChunks() == 0 { - break - } - // Next iteration begins. Need use iterOutTbl as input of next iteration. - p.curIter++ - p.iterInTbl.SetIter(p.curIter) - if p.curIter > p.ctx.GetSessionVars().CTEMaxRecursionDepth { - return exeerrors.ErrCTEMaxRecursionDepth.GenWithStackByArgs(p.curIter) - } - // Make sure iterInTbl is setup before Close/Open, - // because some executors will read iterInTbl in Open() (like IndexLookupJoin). - if err = exec.Close(p.recursiveExec); err != nil { - return - } - if err = exec.Open(ctx, p.recursiveExec); err != nil { - return - } - } else { - if err = p.iterOutTbl.Add(chk); err != nil { - return - } - } - } - return -} - -func (p *cteProducer) setupTblsForNewIteration() (err error) { - num := p.iterOutTbl.NumChunks() - chks := make([]*chunk.Chunk, 0, num) - // Setup resTbl's data. - for i := 0; i < num; i++ { - chk, err := p.iterOutTbl.GetChunk(i) - if err != nil { - return err - } - // Data should be copied in UNION DISTINCT. - // Because deduplicate() will change data in iterOutTbl, - // which will cause panic when spilling data into disk concurrently. 
- if p.isDistinct { - chk = chk.CopyConstruct() - } - chk, err = p.tryDedupAndAdd(chk, p.resTbl, p.hashTbl) - if err != nil { - return err - } - chks = append(chks, chk) - } - - // Setup new iteration data in iterInTbl. - if err = p.iterInTbl.Reopen(); err != nil { - return err - } - setupCTEStorageTracker(p.iterInTbl, p.ctx, p.memTracker, p.diskTracker) - - if p.isDistinct { - // Already deduplicated by resTbl, adding directly is ok. - for _, chk := range chks { - if err = p.iterInTbl.Add(chk); err != nil { - return err - } - } - } else { - if err = p.iterInTbl.SwapData(p.iterOutTbl); err != nil { - return err - } - } - - // Clear data in iterOutTbl. - if err = p.iterOutTbl.Reopen(); err != nil { - return err - } - setupCTEStorageTracker(p.iterOutTbl, p.ctx, p.memTracker, p.diskTracker) - return nil -} - -func (p *cteProducer) reset() { - p.curIter = 0 - p.hashTbl = nil - - p.opened = false - p.openErr = nil - p.produced = false - p.closed = false -} - -func (p *cteProducer) resetTracker() { - if p.memTracker != nil { - p.memTracker.Reset() - p.memTracker = nil - } - if p.diskTracker != nil { - p.diskTracker.Reset() - p.diskTracker = nil - } -} - -func (p *cteProducer) reopenTbls() (err error) { - if p.isDistinct { - p.hashTbl = join.NewConcurrentMapHashTable() - } - // Normally we need to setup tracker after calling Reopen(), - // But reopen resTbl means we need to call produce() again, it will setup tracker. - if err := p.resTbl.Reopen(); err != nil { - return err - } - return p.iterInTbl.Reopen() -} - -// Check if tbl meets the requirement of limit. -func (p *cteProducer) limitDone(tbl cteutil.Storage) bool { - return p.hasLimit && uint64(tbl.NumRows()) >= p.limitEnd -} - -func setupCTEStorageTracker(tbl cteutil.Storage, ctx sessionctx.Context, parentMemTracker *memory.Tracker, - parentDiskTracker *disk.Tracker) (actionSpill *chunk.SpillDiskAction) { - memTracker := tbl.GetMemTracker() - memTracker.SetLabel(memory.LabelForCTEStorage) - memTracker.AttachTo(parentMemTracker) - - diskTracker := tbl.GetDiskTracker() - diskTracker.SetLabel(memory.LabelForCTEStorage) - diskTracker.AttachTo(parentDiskTracker) - - if variable.EnableTmpStorageOnOOM.Load() { - actionSpill = tbl.ActionSpill() - failpoint.Inject("testCTEStorageSpill", func(val failpoint.Value) { - if val.(bool) { - actionSpill = tbl.(*cteutil.StorageRC).ActionSpillForTest() - } - }) - ctx.GetSessionVars().MemTracker.FallbackOldAndSetNewAction(actionSpill) - } - return actionSpill -} - -func (p *cteProducer) tryDedupAndAdd(chk *chunk.Chunk, - storage cteutil.Storage, - hashTbl join.BaseHashTable) (res *chunk.Chunk, err error) { - if p.isDistinct { - if chk, err = p.deduplicate(chk, storage, hashTbl); err != nil { - return nil, err - } - } - return chk, storage.Add(chk) -} - -// Compute hash values in chk and put it in hCtx.hashVals. -// Use the returned sel to choose the computed hash values. -func (p *cteProducer) computeChunkHash(chk *chunk.Chunk) (sel []int, err error) { - numRows := chk.NumRows() - p.hCtx.InitHash(numRows) - // Continue to reset to make sure all hasher is new. - for i := numRows; i < len(p.hCtx.HashVals); i++ { - p.hCtx.HashVals[i].Reset() - } - sel = chk.Sel() - var hashBitMap []bool - if sel != nil { - hashBitMap = make([]bool, chk.Capacity()) - for _, val := range sel { - hashBitMap[val] = true - } - } else { - // Length of p.sel is init as MaxChunkSize, but the row num of chunk may still exceeds MaxChunkSize. - // So needs to handle here to make sure len(p.sel) == chk.NumRows(). 
- if len(p.sel) < numRows { - tmpSel := make([]int, numRows-len(p.sel)) - for i := 0; i < len(tmpSel); i++ { - tmpSel[i] = i + len(p.sel) - } - p.sel = append(p.sel, tmpSel...) - } - - // All rows is selected, sel will be [0....numRows). - // e.sel is setup when building executor. - sel = p.sel - } - - for i := 0; i < chk.NumCols(); i++ { - if err = codec.HashChunkSelected(p.ctx.GetSessionVars().StmtCtx.TypeCtx(), p.hCtx.HashVals, - chk, p.hCtx.AllTypes[i], i, p.hCtx.Buf, p.hCtx.HasNull, - hashBitMap, false); err != nil { - return nil, err - } - } - return sel, nil -} - -// Use hashTbl to deduplicate rows, and unique rows will be added to hashTbl. -// Duplicated rows are only marked to be removed by sel in Chunk, instead of really deleted. -func (p *cteProducer) deduplicate(chk *chunk.Chunk, - storage cteutil.Storage, - hashTbl join.BaseHashTable) (chkNoDup *chunk.Chunk, err error) { - numRows := chk.NumRows() - if numRows == 0 { - return chk, nil - } - - // 1. Compute hash values for chunk. - chkHashTbl := join.NewConcurrentMapHashTable() - selOri, err := p.computeChunkHash(chk) - if err != nil { - return nil, err - } - - // 2. Filter rows duplicated in input chunk. - // This sel is for filtering rows duplicated in cur chk. - selChk := make([]int, 0, numRows) - for i := 0; i < numRows; i++ { - key := p.hCtx.HashVals[selOri[i]].Sum64() - row := chk.GetRow(i) - - hasDup, err := p.checkHasDup(key, row, chk, storage, chkHashTbl) - if err != nil { - return nil, err - } - if hasDup { - continue - } - - selChk = append(selChk, selOri[i]) - - rowPtr := chunk.RowPtr{ChkIdx: uint32(0), RowIdx: uint32(i)} - chkHashTbl.Put(key, rowPtr) - } - chk.SetSel(selChk) - chkIdx := storage.NumChunks() - - // 3. Filter rows duplicated in RowContainer. - // This sel is for filtering rows duplicated in cteutil.Storage. - selStorage := make([]int, 0, len(selChk)) - for i := 0; i < len(selChk); i++ { - key := p.hCtx.HashVals[selChk[i]].Sum64() - row := chk.GetRow(i) - - hasDup, err := p.checkHasDup(key, row, nil, storage, hashTbl) - if err != nil { - return nil, err - } - if hasDup { - continue - } - - rowIdx := len(selStorage) - selStorage = append(selStorage, selChk[i]) - - rowPtr := chunk.RowPtr{ChkIdx: uint32(chkIdx), RowIdx: uint32(rowIdx)} - hashTbl.Put(key, rowPtr) - } - - chk.SetSel(selStorage) - return chk, nil -} - -// Use the row's probe key to check if it already exists in chk or storage. -// We also need to compare the row's real encoding value to avoid hash collision. -func (p *cteProducer) checkHasDup(probeKey uint64, - row chunk.Row, - curChk *chunk.Chunk, - storage cteutil.Storage, - hashTbl join.BaseHashTable) (hasDup bool, err error) { - entry := hashTbl.Get(probeKey) - - for ; entry != nil; entry = entry.Next { - ptr := entry.Ptr - var matchedRow chunk.Row - if curChk != nil { - matchedRow = curChk.GetRow(int(ptr.RowIdx)) - } else { - matchedRow, err = storage.GetRow(ptr) - } - if err != nil { - return false, err - } - isEqual, err := codec.EqualChunkRow(p.ctx.GetSessionVars().StmtCtx.TypeCtx(), - row, p.hCtx.AllTypes, p.hCtx.KeyColIdx, - matchedRow, p.hCtx.AllTypes, p.hCtx.KeyColIdx) - if err != nil { - return false, err - } - if isEqual { - return true, nil - } - } - return false, nil -} - -func getCorColHashCode(corCol *expression.CorrelatedColumn) (res []byte) { - return codec.HashCode(res, *corCol.Data) -} - -// Return true if cor col has changed. 
-func (p *cteProducer) checkAndUpdateCorColHashCode() bool { - var changed bool - for i, corCol := range p.corCols { - newHashCode := getCorColHashCode(corCol) - if !bytes.Equal(newHashCode, p.corColHashCodes[i]) { - changed = true - p.corColHashCodes[i] = newHashCode - } - } - return changed -} - -func (p *cteProducer) logTbls(ctx context.Context, err error, iterNum uint64, lvl zapcore.Level) { - logutil.Logger(ctx).Log(lvl, "cte iteration info", - zap.Any("iterInTbl mem usage", p.iterInTbl.GetMemBytes()), zap.Any("iterInTbl disk usage", p.iterInTbl.GetDiskBytes()), - zap.Any("iterOutTbl mem usage", p.iterOutTbl.GetMemBytes()), zap.Any("iterOutTbl disk usage", p.iterOutTbl.GetDiskBytes()), - zap.Any("resTbl mem usage", p.resTbl.GetMemBytes()), zap.Any("resTbl disk usage", p.resTbl.GetDiskBytes()), - zap.Any("resTbl rows", p.resTbl.NumRows()), zap.Any("iteration num", iterNum), zap.Error(err)) -} diff --git a/pkg/executor/executor.go b/pkg/executor/executor.go index 27731493fb0e5..d6e6252d51834 100644 --- a/pkg/executor/executor.go +++ b/pkg/executor/executor.go @@ -1561,11 +1561,11 @@ func (e *SelectionExec) Open(ctx context.Context) error { if err := e.BaseExecutorV2.Open(ctx); err != nil { return err } - if val, _err_ := failpoint.Eval(_curpkg_("mockSelectionExecBaseExecutorOpenReturnedError")); _err_ == nil { + failpoint.Inject("mockSelectionExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { if val.(bool) { - return errors.New("mock SelectionExec.baseExecutor.Open returned error") + failpoint.Return(errors.New("mock SelectionExec.baseExecutor.Open returned error")) } - } + }) return e.open(ctx) } diff --git a/pkg/executor/executor.go__failpoint_stash__ b/pkg/executor/executor.go__failpoint_stash__ deleted file mode 100644 index d6e6252d51834..0000000000000 --- a/pkg/executor/executor.go__failpoint_stash__ +++ /dev/null @@ -1,2673 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package executor - -import ( - "cmp" - "context" - stderrors "errors" - "fmt" - "math" - "runtime/pprof" - "slices" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/opentracing/opentracing-go" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/ddl" - "github.com/pingcap/tidb/pkg/ddl/schematracker" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/domain/infosync" - "github.com/pingcap/tidb/pkg/errctx" - "github.com/pingcap/tidb/pkg/executor/aggregate" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/executor/internal/pdhelper" - "github.com/pingcap/tidb/pkg/executor/sortexec" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/meta/autoid" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/auth" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - planctx "github.com/pingcap/tidb/pkg/planner/context" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" - plannerutil "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/planner/util/fixcontrol" - "github.com/pingcap/tidb/pkg/privilege" - "github.com/pingcap/tidb/pkg/resourcemanager/pool/workerpool" - poolutil "github.com/pingcap/tidb/pkg/resourcemanager/util" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/sessiontxn" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/table/tables" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/admin" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" - "github.com/pingcap/tidb/pkg/util/deadlockhistory" - "github.com/pingcap/tidb/pkg/util/disk" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/logutil/consistency" - "github.com/pingcap/tidb/pkg/util/memory" - "github.com/pingcap/tidb/pkg/util/resourcegrouptag" - "github.com/pingcap/tidb/pkg/util/sqlexec" - "github.com/pingcap/tidb/pkg/util/syncutil" - "github.com/pingcap/tidb/pkg/util/topsql" - topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" - "github.com/pingcap/tidb/pkg/util/tracing" - tikverr "github.com/tikv/client-go/v2/error" - tikvstore "github.com/tikv/client-go/v2/kv" - tikvutil "github.com/tikv/client-go/v2/util" - atomicutil "go.uber.org/atomic" - "go.uber.org/zap" -) - -var ( - _ exec.Executor = &CheckTableExec{} - _ exec.Executor = &aggregate.HashAggExec{} - _ exec.Executor = &IndexLookUpExecutor{} - _ exec.Executor = &IndexReaderExecutor{} - _ exec.Executor = &LimitExec{} - _ exec.Executor = &MaxOneRowExec{} - _ exec.Executor = &ProjectionExec{} - _ exec.Executor = &SelectionExec{} - _ exec.Executor = &SelectLockExec{} - _ exec.Executor = &ShowNextRowIDExec{} - _ exec.Executor = 
&ShowDDLExec{} - _ exec.Executor = &ShowDDLJobsExec{} - _ exec.Executor = &ShowDDLJobQueriesExec{} - _ exec.Executor = &sortexec.SortExec{} - _ exec.Executor = &aggregate.StreamAggExec{} - _ exec.Executor = &TableDualExec{} - _ exec.Executor = &TableReaderExecutor{} - _ exec.Executor = &TableScanExec{} - _ exec.Executor = &sortexec.TopNExec{} - _ exec.Executor = &FastCheckTableExec{} - _ exec.Executor = &AdminShowBDRRoleExec{} - - // GlobalMemoryUsageTracker is the ancestor of all the Executors' memory tracker and GlobalMemory Tracker - GlobalMemoryUsageTracker *memory.Tracker - // GlobalDiskUsageTracker is the ancestor of all the Executors' disk tracker - GlobalDiskUsageTracker *disk.Tracker - // GlobalAnalyzeMemoryTracker is the ancestor of all the Analyze jobs' memory tracker and child of global Tracker - GlobalAnalyzeMemoryTracker *memory.Tracker -) - -var ( - _ dataSourceExecutor = &TableReaderExecutor{} - _ dataSourceExecutor = &IndexReaderExecutor{} - _ dataSourceExecutor = &IndexLookUpExecutor{} - _ dataSourceExecutor = &IndexMergeReaderExecutor{} - - // CheckTableFastBucketSize is the bucket size of fast check table. - CheckTableFastBucketSize = atomic.Int64{} -) - -// dataSourceExecutor is a table DataSource converted Executor. -// Currently, there are TableReader/IndexReader/IndexLookUp/IndexMergeReader. -// Note, partition reader is special and the caller should handle it carefully. -type dataSourceExecutor interface { - exec.Executor - Table() table.Table -} - -const ( - // globalPanicStorageExceed represents the panic message when out of storage quota. - globalPanicStorageExceed string = "Out Of Quota For Local Temporary Space!" - // globalPanicMemoryExceed represents the panic message when out of memory limit. - globalPanicMemoryExceed string = "Out Of Global Memory Limit!" - // globalPanicAnalyzeMemoryExceed represents the panic message when out of analyze memory limit. - globalPanicAnalyzeMemoryExceed string = "Out Of Global Analyze Memory Limit!" -) - -// globalPanicOnExceed panics when GlobalDisTracker storage usage exceeds storage quota. -type globalPanicOnExceed struct { - memory.BaseOOMAction - mutex syncutil.Mutex // For synchronization. -} - -func init() { - action := &globalPanicOnExceed{} - GlobalMemoryUsageTracker = memory.NewGlobalTracker(memory.LabelForGlobalMemory, -1) - GlobalMemoryUsageTracker.SetActionOnExceed(action) - GlobalDiskUsageTracker = disk.NewGlobalTrcaker(memory.LabelForGlobalStorage, -1) - GlobalDiskUsageTracker.SetActionOnExceed(action) - GlobalAnalyzeMemoryTracker = memory.NewTracker(memory.LabelForGlobalAnalyzeMemory, -1) - GlobalAnalyzeMemoryTracker.SetActionOnExceed(action) - // register quota funcs - variable.SetMemQuotaAnalyze = GlobalAnalyzeMemoryTracker.SetBytesLimit - variable.GetMemQuotaAnalyze = GlobalAnalyzeMemoryTracker.GetBytesLimit - // TODO: do not attach now to avoid impact to global, will attach later when analyze memory track is stable - //GlobalAnalyzeMemoryTracker.AttachToGlobalTracker(GlobalMemoryUsageTracker) - - schematracker.ConstructResultOfShowCreateDatabase = ConstructResultOfShowCreateDatabase - schematracker.ConstructResultOfShowCreateTable = ConstructResultOfShowCreateTable - - // CheckTableFastBucketSize is used to set the fast analyze bucket size for check table. 
- CheckTableFastBucketSize.Store(1024) -} - -// Start the backend components -func Start() { - pdhelper.GlobalPDHelper.Start() -} - -// Stop the backend components -func Stop() { - pdhelper.GlobalPDHelper.Stop() -} - -// Action panics when storage usage exceeds storage quota. -func (a *globalPanicOnExceed) Action(t *memory.Tracker) { - a.mutex.Lock() - defer a.mutex.Unlock() - msg := "" - switch t.Label() { - case memory.LabelForGlobalStorage: - msg = globalPanicStorageExceed - case memory.LabelForGlobalMemory: - msg = globalPanicMemoryExceed - case memory.LabelForGlobalAnalyzeMemory: - msg = globalPanicAnalyzeMemoryExceed - default: - msg = "Out of Unknown Resource Quota!" - } - // TODO(hawkingrei): should return error instead. - panic(msg) -} - -// GetPriority get the priority of the Action -func (*globalPanicOnExceed) GetPriority() int64 { - return memory.DefPanicPriority -} - -// CommandDDLJobsExec is the general struct for Cancel/Pause/Resume commands on -// DDL jobs. These command currently by admin have the very similar struct and -// operations, it should be a better idea to have them in the same struct. -type CommandDDLJobsExec struct { - exec.BaseExecutor - - cursor int - jobIDs []int64 - errs []error - - execute func(se sessionctx.Context, ids []int64) (errs []error, err error) -} - -// Open implements the Executor for all Cancel/Pause/Resume command on DDL jobs -// just with different processes. And, it should not be called directly by the -// Executor. -func (e *CommandDDLJobsExec) Open(context.Context) error { - // We want to use a global transaction to execute the admin command, so we don't use e.Ctx() here. - newSess, err := e.GetSysSession() - if err != nil { - return err - } - e.errs, err = e.execute(newSess, e.jobIDs) - e.ReleaseSysSession(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), newSess) - return err -} - -// Next implements the Executor Next interface for Cancel/Pause/Resume -func (e *CommandDDLJobsExec) Next(_ context.Context, req *chunk.Chunk) error { - req.GrowAndReset(e.MaxChunkSize()) - if e.cursor >= len(e.jobIDs) { - return nil - } - numCurBatch := min(req.Capacity(), len(e.jobIDs)-e.cursor) - for i := e.cursor; i < e.cursor+numCurBatch; i++ { - req.AppendString(0, strconv.FormatInt(e.jobIDs[i], 10)) - if e.errs != nil && e.errs[i] != nil { - req.AppendString(1, fmt.Sprintf("error: %v", e.errs[i])) - } else { - req.AppendString(1, "successful") - } - } - e.cursor += numCurBatch - return nil -} - -// CancelDDLJobsExec represents a cancel DDL jobs executor. -type CancelDDLJobsExec struct { - *CommandDDLJobsExec -} - -// PauseDDLJobsExec indicates an Executor for Pause a DDL Job. -type PauseDDLJobsExec struct { - *CommandDDLJobsExec -} - -// ResumeDDLJobsExec indicates an Executor for Resume a DDL Job. -type ResumeDDLJobsExec struct { - *CommandDDLJobsExec -} - -// ShowNextRowIDExec represents a show the next row ID executor. -type ShowNextRowIDExec struct { - exec.BaseExecutor - tblName *ast.TableName - done bool -} - -// Next implements the Executor Next interface. 
-func (e *ShowNextRowIDExec) Next(ctx context.Context, req *chunk.Chunk) error { - req.Reset() - if e.done { - return nil - } - is := domain.GetDomain(e.Ctx()).InfoSchema() - tbl, err := is.TableByName(ctx, e.tblName.Schema, e.tblName.Name) - if err != nil { - return err - } - tblMeta := tbl.Meta() - - allocators := tbl.Allocators(e.Ctx().GetTableCtx()) - for _, alloc := range allocators.Allocs { - nextGlobalID, err := alloc.NextGlobalAutoID() - if err != nil { - return err - } - - var colName, idType string - switch alloc.GetType() { - case autoid.RowIDAllocType: - idType = "_TIDB_ROWID" - if tblMeta.PKIsHandle { - if col := tblMeta.GetAutoIncrementColInfo(); col != nil { - colName = col.Name.O - } - } else { - colName = model.ExtraHandleName.O - } - case autoid.AutoIncrementType: - idType = "AUTO_INCREMENT" - if tblMeta.PKIsHandle { - if col := tblMeta.GetAutoIncrementColInfo(); col != nil { - colName = col.Name.O - } - } else { - colName = model.ExtraHandleName.O - } - case autoid.AutoRandomType: - idType = "AUTO_RANDOM" - colName = tblMeta.GetPkName().O - case autoid.SequenceType: - idType = "SEQUENCE" - colName = "" - default: - return autoid.ErrInvalidAllocatorType.GenWithStackByArgs() - } - - req.AppendString(0, e.tblName.Schema.O) - req.AppendString(1, e.tblName.Name.O) - req.AppendString(2, colName) - req.AppendInt64(3, nextGlobalID) - req.AppendString(4, idType) - } - - e.done = true - return nil -} - -// ShowDDLExec represents a show DDL executor. -type ShowDDLExec struct { - exec.BaseExecutor - - ddlOwnerID string - selfID string - ddlInfo *ddl.Info - done bool -} - -// Next implements the Executor Next interface. -func (e *ShowDDLExec) Next(ctx context.Context, req *chunk.Chunk) error { - req.Reset() - if e.done { - return nil - } - - ddlJobs := "" - query := "" - l := len(e.ddlInfo.Jobs) - for i, job := range e.ddlInfo.Jobs { - ddlJobs += job.String() - query += job.Query - if i != l-1 { - ddlJobs += "\n" - query += "\n" - } - } - - serverInfo, err := infosync.GetServerInfoByID(ctx, e.ddlOwnerID) - if err != nil { - return err - } - - serverAddress := serverInfo.IP + ":" + - strconv.FormatUint(uint64(serverInfo.Port), 10) - - req.AppendInt64(0, e.ddlInfo.SchemaVer) - req.AppendString(1, e.ddlOwnerID) - req.AppendString(2, serverAddress) - req.AppendString(3, ddlJobs) - req.AppendString(4, e.selfID) - req.AppendString(5, query) - - e.done = true - return nil -} - -// ShowDDLJobsExec represent a show DDL jobs executor. -type ShowDDLJobsExec struct { - exec.BaseExecutor - DDLJobRetriever - - jobNumber int - is infoschema.InfoSchema - sess sessionctx.Context -} - -// DDLJobRetriever retrieve the DDLJobs. 
-// nolint:structcheck -type DDLJobRetriever struct { - runningJobs []*model.Job - historyJobIter meta.LastJobIterator - cursor int - is infoschema.InfoSchema - activeRoles []*auth.RoleIdentity - cacheJobs []*model.Job - TZLoc *time.Location -} - -func (e *DDLJobRetriever) initial(txn kv.Transaction, sess sessionctx.Context) error { - m := meta.NewMeta(txn) - jobs, err := ddl.GetAllDDLJobs(sess) - if err != nil { - return err - } - e.historyJobIter, err = ddl.GetLastHistoryDDLJobsIterator(m) - if err != nil { - return err - } - e.runningJobs = jobs - e.cursor = 0 - return nil -} - -func (e *DDLJobRetriever) appendJobToChunk(req *chunk.Chunk, job *model.Job, checker privilege.Manager) { - schemaName := job.SchemaName - tableName := "" - finishTS := uint64(0) - if job.BinlogInfo != nil { - finishTS = job.BinlogInfo.FinishedTS - if job.BinlogInfo.TableInfo != nil { - tableName = job.BinlogInfo.TableInfo.Name.L - } - if job.BinlogInfo.MultipleTableInfos != nil { - tablenames := new(strings.Builder) - for i, affect := range job.BinlogInfo.MultipleTableInfos { - if i > 0 { - fmt.Fprintf(tablenames, ",") - } - fmt.Fprintf(tablenames, "%s", affect.Name.L) - } - tableName = tablenames.String() - } - if len(schemaName) == 0 && job.BinlogInfo.DBInfo != nil { - schemaName = job.BinlogInfo.DBInfo.Name.L - } - } - if len(tableName) == 0 { - tableName = job.TableName - } - // For compatibility, the old version of DDL Job wasn't store the schema name and table name. - if len(schemaName) == 0 { - schemaName = getSchemaName(e.is, job.SchemaID) - } - if len(tableName) == 0 { - tableName = getTableName(e.is, job.TableID) - } - - createTime := ts2Time(job.StartTS, e.TZLoc) - startTime := ts2Time(job.RealStartTS, e.TZLoc) - finishTime := ts2Time(finishTS, e.TZLoc) - - // Check the privilege. 
- if checker != nil && !checker.RequestVerification(e.activeRoles, strings.ToLower(schemaName), strings.ToLower(tableName), "", mysql.AllPrivMask) { - return - } - - req.AppendInt64(0, job.ID) - req.AppendString(1, schemaName) - req.AppendString(2, tableName) - req.AppendString(3, job.Type.String()+showAddIdxReorgTp(job)) - req.AppendString(4, job.SchemaState.String()) - req.AppendInt64(5, job.SchemaID) - req.AppendInt64(6, job.TableID) - req.AppendInt64(7, job.RowCount) - req.AppendTime(8, createTime) - if job.RealStartTS > 0 { - req.AppendTime(9, startTime) - } else { - req.AppendNull(9) - } - if finishTS > 0 { - req.AppendTime(10, finishTime) - } else { - req.AppendNull(10) - } - req.AppendString(11, job.State.String()) - if job.Type == model.ActionMultiSchemaChange { - isDistTask := job.ReorgMeta != nil && job.ReorgMeta.IsDistReorg - for _, subJob := range job.MultiSchemaInfo.SubJobs { - req.AppendInt64(0, job.ID) - req.AppendString(1, schemaName) - req.AppendString(2, tableName) - req.AppendString(3, subJob.Type.String()+" /* subjob */"+showAddIdxReorgTpInSubJob(subJob, isDistTask)) - req.AppendString(4, subJob.SchemaState.String()) - req.AppendInt64(5, job.SchemaID) - req.AppendInt64(6, job.TableID) - req.AppendInt64(7, subJob.RowCount) - req.AppendTime(8, createTime) - if subJob.RealStartTS > 0 { - realStartTS := ts2Time(subJob.RealStartTS, e.TZLoc) - req.AppendTime(9, realStartTS) - } else { - req.AppendNull(9) - } - if finishTS > 0 { - req.AppendTime(10, finishTime) - } else { - req.AppendNull(10) - } - req.AppendString(11, subJob.State.String()) - } - } -} - -func showAddIdxReorgTp(job *model.Job) string { - if job.Type == model.ActionAddIndex || job.Type == model.ActionAddPrimaryKey { - if job.ReorgMeta != nil { - sb := strings.Builder{} - tp := job.ReorgMeta.ReorgTp.String() - if len(tp) > 0 { - sb.WriteString(" /* ") - sb.WriteString(tp) - if job.ReorgMeta.ReorgTp == model.ReorgTypeLitMerge && - job.ReorgMeta.IsDistReorg && - job.ReorgMeta.UseCloudStorage { - sb.WriteString(" cloud") - } - sb.WriteString(" */") - } - return sb.String() - } - } - return "" -} - -func showAddIdxReorgTpInSubJob(subJob *model.SubJob, useDistTask bool) string { - if subJob.Type == model.ActionAddIndex || subJob.Type == model.ActionAddPrimaryKey { - sb := strings.Builder{} - tp := subJob.ReorgTp.String() - if len(tp) > 0 { - sb.WriteString(" /* ") - sb.WriteString(tp) - if subJob.ReorgTp == model.ReorgTypeLitMerge && useDistTask && subJob.UseCloud { - sb.WriteString(" cloud") - } - sb.WriteString(" */") - } - return sb.String() - } - return "" -} - -func ts2Time(timestamp uint64, loc *time.Location) types.Time { - duration := time.Duration(math.Pow10(9-types.DefaultFsp)) * time.Nanosecond - t := model.TSConvert2Time(timestamp) - t.Truncate(duration) - return types.NewTime(types.FromGoTime(t.In(loc)), mysql.TypeDatetime, types.MaxFsp) -} - -// ShowDDLJobQueriesExec represents a show DDL job queries executor. -// The jobs id that is given by 'admin show ddl job queries' statement, -// only be searched in the latest 10 history jobs. -type ShowDDLJobQueriesExec struct { - exec.BaseExecutor - - cursor int - jobs []*model.Job - jobIDs []int64 -} - -// Open implements the Executor Open interface. 
-func (e *ShowDDLJobQueriesExec) Open(ctx context.Context) error { - var err error - var jobs []*model.Job - if err := e.BaseExecutor.Open(ctx); err != nil { - return err - } - session, err := e.GetSysSession() - if err != nil { - return err - } - err = sessiontxn.NewTxn(context.Background(), session) - if err != nil { - return err - } - defer func() { - // ReleaseSysSession will rollbacks txn automatically. - e.ReleaseSysSession(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), session) - }() - txn, err := session.Txn(true) - if err != nil { - return err - } - session.GetSessionVars().SetInTxn(true) - - m := meta.NewMeta(txn) - jobs, err = ddl.GetAllDDLJobs(session) - if err != nil { - return err - } - - historyJobs, err := ddl.GetLastNHistoryDDLJobs(m, ddl.DefNumHistoryJobs) - if err != nil { - return err - } - - appendedJobID := make(map[int64]struct{}) - // deduplicate job results - // for situations when this operation happens at the same time with new DDLs being executed - for _, job := range jobs { - if _, ok := appendedJobID[job.ID]; !ok { - appendedJobID[job.ID] = struct{}{} - e.jobs = append(e.jobs, job) - } - } - for _, historyJob := range historyJobs { - if _, ok := appendedJobID[historyJob.ID]; !ok { - appendedJobID[historyJob.ID] = struct{}{} - e.jobs = append(e.jobs, historyJob) - } - } - - return nil -} - -// Next implements the Executor Next interface. -func (e *ShowDDLJobQueriesExec) Next(_ context.Context, req *chunk.Chunk) error { - req.GrowAndReset(e.MaxChunkSize()) - if e.cursor >= len(e.jobs) { - return nil - } - if len(e.jobIDs) >= len(e.jobs) { - return nil - } - numCurBatch := min(req.Capacity(), len(e.jobs)-e.cursor) - for _, id := range e.jobIDs { - for i := e.cursor; i < e.cursor+numCurBatch; i++ { - if id == e.jobs[i].ID { - req.AppendString(0, e.jobs[i].Query) - } - } - } - e.cursor += numCurBatch - return nil -} - -// ShowDDLJobQueriesWithRangeExec represents a show DDL job queries with range executor. -// The jobs id that is given by 'admin show ddl job queries' statement, -// can be searched within a specified range in history jobs using offset and limit. -type ShowDDLJobQueriesWithRangeExec struct { - exec.BaseExecutor - - cursor int - jobs []*model.Job - offset uint64 - limit uint64 -} - -// Open implements the Executor Open interface. -func (e *ShowDDLJobQueriesWithRangeExec) Open(ctx context.Context) error { - var err error - var jobs []*model.Job - if err := e.BaseExecutor.Open(ctx); err != nil { - return err - } - session, err := e.GetSysSession() - if err != nil { - return err - } - err = sessiontxn.NewTxn(context.Background(), session) - if err != nil { - return err - } - defer func() { - // ReleaseSysSession will rollbacks txn automatically. 
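The Open method above merges the currently running jobs with the latest history jobs and deduplicates them by job ID, since a job can show up in both lists while it is finishing. A self-contained sketch of that merge (the job struct and helper are illustrative stand-ins, not TiDB types):

package main

import "fmt"

type job struct {
    ID    int64
    Query string
}

// mergeDedup appends history entries after running ones, skipping IDs already
// seen, mirroring the appendedJobID bookkeeping in the executor above.
func mergeDedup(running, history []job) []job {
    seen := make(map[int64]struct{}, len(running)+len(history))
    out := make([]job, 0, len(running)+len(history))
    for _, list := range [][]job{running, history} {
        for _, j := range list {
            if _, ok := seen[j.ID]; ok {
                continue
            }
            seen[j.ID] = struct{}{}
            out = append(out, j)
        }
    }
    return out
}

func main() {
    running := []job{{1, "create table t"}, {2, "add index"}}
    history := []job{{2, "add index"}, {3, "drop table s"}}
    fmt.Println(mergeDedup(running, history)) // job 2 appears once
}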
- e.ReleaseSysSession(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), session) - }() - txn, err := session.Txn(true) - if err != nil { - return err - } - session.GetSessionVars().SetInTxn(true) - - m := meta.NewMeta(txn) - jobs, err = ddl.GetAllDDLJobs(session) - if err != nil { - return err - } - - historyJobs, err := ddl.GetLastNHistoryDDLJobs(m, int(e.offset+e.limit)) - if err != nil { - return err - } - - appendedJobID := make(map[int64]struct{}) - // deduplicate job results - // for situations when this operation happens at the same time with new DDLs being executed - for _, job := range jobs { - if _, ok := appendedJobID[job.ID]; !ok { - appendedJobID[job.ID] = struct{}{} - e.jobs = append(e.jobs, job) - } - } - for _, historyJob := range historyJobs { - if _, ok := appendedJobID[historyJob.ID]; !ok { - appendedJobID[historyJob.ID] = struct{}{} - e.jobs = append(e.jobs, historyJob) - } - } - - if e.cursor < int(e.offset) { - e.cursor = int(e.offset) - } - - return nil -} - -// Next implements the Executor Next interface. -func (e *ShowDDLJobQueriesWithRangeExec) Next(_ context.Context, req *chunk.Chunk) error { - req.GrowAndReset(e.MaxChunkSize()) - if e.cursor >= len(e.jobs) { - return nil - } - if int(e.offset) > len(e.jobs) { - return nil - } - numCurBatch := min(req.Capacity(), len(e.jobs)-e.cursor) - for i := e.cursor; i < e.cursor+numCurBatch; i++ { - // i is make true to be >= int(e.offset) - if i >= int(e.offset+e.limit) { - break - } - req.AppendString(0, strconv.FormatInt(e.jobs[i].ID, 10)) - req.AppendString(1, e.jobs[i].Query) - } - e.cursor += numCurBatch - return nil -} - -// Open implements the Executor Open interface. -func (e *ShowDDLJobsExec) Open(ctx context.Context) error { - if err := e.BaseExecutor.Open(ctx); err != nil { - return err - } - e.DDLJobRetriever.is = e.is - if e.jobNumber == 0 { - e.jobNumber = ddl.DefNumHistoryJobs - } - sess, err := e.GetSysSession() - if err != nil { - return err - } - e.sess = sess - err = sessiontxn.NewTxn(context.Background(), sess) - if err != nil { - return err - } - txn, err := sess.Txn(true) - if err != nil { - return err - } - sess.GetSessionVars().SetInTxn(true) - err = e.DDLJobRetriever.initial(txn, sess) - return err -} - -// Next implements the Executor Next interface. -func (e *ShowDDLJobsExec) Next(_ context.Context, req *chunk.Chunk) error { - req.GrowAndReset(e.MaxChunkSize()) - if (e.cursor - len(e.runningJobs)) >= e.jobNumber { - return nil - } - count := 0 - - // Append running ddl jobs. - if e.cursor < len(e.runningJobs) { - numCurBatch := min(req.Capacity(), len(e.runningJobs)-e.cursor) - for i := e.cursor; i < e.cursor+numCurBatch; i++ { - e.appendJobToChunk(req, e.runningJobs[i], nil) - } - e.cursor += numCurBatch - count += numCurBatch - } - - // Append history ddl jobs. - var err error - if count < req.Capacity() { - num := req.Capacity() - count - remainNum := e.jobNumber - (e.cursor - len(e.runningJobs)) - num = min(num, remainNum) - e.cacheJobs, err = e.historyJobIter.GetLastJobs(num, e.cacheJobs) - if err != nil { - return err - } - for _, job := range e.cacheJobs { - e.appendJobToChunk(req, job, nil) - } - e.cursor += len(e.cacheJobs) - } - return nil -} - -// Close implements the Executor Close interface. 
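ShowDDLJobQueriesWithRangeExec.Next above effectively pages over the merged job list with an offset and a limit, clamping the cursor so it never reads past either bound. A toy version of the same windowing over a plain slice (generic helper, illustrative only):

package main

import "fmt"

// pageOf returns rows[offset : offset+limit] with both bounds clamped to the
// slice length, matching the cursor arithmetic in the executor above.
func pageOf[T any](rows []T, offset, limit uint64) []T {
    if offset >= uint64(len(rows)) {
        return nil
    }
    end := offset + limit
    if end > uint64(len(rows)) {
        end = uint64(len(rows))
    }
    return rows[offset:end]
}

func main() {
    jobs := []string{"j1", "j2", "j3", "j4", "j5"}
    fmt.Println(pageOf(jobs, 3, 10)) // [j4 j5]
}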
-func (e *ShowDDLJobsExec) Close() error { - e.ReleaseSysSession(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), e.sess) - return e.BaseExecutor.Close() -} - -func getSchemaName(is infoschema.InfoSchema, id int64) string { - var schemaName string - dbInfo, ok := is.SchemaByID(id) - if ok { - schemaName = dbInfo.Name.O - return schemaName - } - - return schemaName -} - -func getTableName(is infoschema.InfoSchema, id int64) string { - var tableName string - table, ok := is.TableByID(id) - if ok { - tableName = table.Meta().Name.O - return tableName - } - - return tableName -} - -// CheckTableExec represents a check table executor. -// It is built from the "admin check table" statement, and it checks if the -// index matches the records in the table. -type CheckTableExec struct { - exec.BaseExecutor - - dbName string - table table.Table - indexInfos []*model.IndexInfo - srcs []*IndexLookUpExecutor - done bool - is infoschema.InfoSchema - exitCh chan struct{} - retCh chan error - checkIndex bool -} - -// Open implements the Executor Open interface. -func (e *CheckTableExec) Open(ctx context.Context) error { - if err := e.BaseExecutor.Open(ctx); err != nil { - return err - } - for _, src := range e.srcs { - if err := exec.Open(ctx, src); err != nil { - return errors.Trace(err) - } - } - e.done = false - return nil -} - -// Close implements the Executor Close interface. -func (e *CheckTableExec) Close() error { - var firstErr error - close(e.exitCh) - for _, src := range e.srcs { - if err := exec.Close(src); err != nil && firstErr == nil { - firstErr = err - } - } - return firstErr -} - -func (e *CheckTableExec) checkTableIndexHandle(ctx context.Context, idxInfo *model.IndexInfo) error { - // For partition table, there will be multi same index indexLookUpReaders on different partitions. - for _, src := range e.srcs { - if src.index.Name.L == idxInfo.Name.L { - err := e.checkIndexHandle(ctx, src) - if err != nil { - return err - } - } - } - return nil -} - -func (e *CheckTableExec) checkIndexHandle(ctx context.Context, src *IndexLookUpExecutor) error { - cols := src.Schema().Columns - retFieldTypes := make([]*types.FieldType, len(cols)) - for i := range cols { - retFieldTypes[i] = cols[i].RetType - } - chk := chunk.New(retFieldTypes, e.InitCap(), e.MaxChunkSize()) - - var err error - for { - err = exec.Next(ctx, src, chk) - if err != nil { - e.retCh <- errors.Trace(err) - break - } - if chk.NumRows() == 0 { - break - } - } - return errors.Trace(err) -} - -func (e *CheckTableExec) handlePanic(r any) { - if r != nil { - e.retCh <- errors.Errorf("%v", r) - } -} - -// Next implements the Executor Next interface. -func (e *CheckTableExec) Next(ctx context.Context, _ *chunk.Chunk) error { - if e.done || len(e.srcs) == 0 { - return nil - } - defer func() { e.done = true }() - - idxNames := make([]string, 0, len(e.indexInfos)) - for _, idx := range e.indexInfos { - if idx.MVIndex { - continue - } - idxNames = append(idxNames, idx.Name.O) - } - greater, idxOffset, err := admin.CheckIndicesCount(e.Ctx(), e.dbName, e.table.Meta().Name.O, idxNames) - if err != nil { - // For admin check index statement, for speed up and compatibility, doesn't do below checks. 
- if e.checkIndex { - return errors.Trace(err) - } - if greater == admin.IdxCntGreater { - err = e.checkTableIndexHandle(ctx, e.indexInfos[idxOffset]) - } else if greater == admin.TblCntGreater { - err = e.checkTableRecord(ctx, idxOffset) - } - return errors.Trace(err) - } - - // The number of table rows is equal to the number of index rows. - // TODO: Make the value of concurrency adjustable. And we can consider the number of records. - if len(e.srcs) == 1 { - err = e.checkIndexHandle(ctx, e.srcs[0]) - if err == nil && e.srcs[0].index.MVIndex { - err = e.checkTableRecord(ctx, 0) - } - if err != nil { - return err - } - } - taskCh := make(chan *IndexLookUpExecutor, len(e.srcs)) - failure := atomicutil.NewBool(false) - concurrency := min(3, len(e.srcs)) - var wg util.WaitGroupWrapper - for _, src := range e.srcs { - taskCh <- src - } - for i := 0; i < concurrency; i++ { - wg.Run(func() { - util.WithRecovery(func() { - for { - if fail := failure.Load(); fail { - return - } - select { - case src := <-taskCh: - err1 := e.checkIndexHandle(ctx, src) - if err1 == nil && src.index.MVIndex { - for offset, idx := range e.indexInfos { - if idx.ID == src.index.ID { - err1 = e.checkTableRecord(ctx, offset) - break - } - } - } - if err1 != nil { - failure.Store(true) - logutil.Logger(ctx).Info("check index handle failed", zap.Error(err1)) - return - } - case <-e.exitCh: - return - default: - return - } - } - }, e.handlePanic) - }) - } - wg.Wait() - select { - case err := <-e.retCh: - return errors.Trace(err) - default: - return nil - } -} - -func (e *CheckTableExec) checkTableRecord(ctx context.Context, idxOffset int) error { - idxInfo := e.indexInfos[idxOffset] - txn, err := e.Ctx().Txn(true) - if err != nil { - return err - } - if e.table.Meta().GetPartitionInfo() == nil { - idx := tables.NewIndex(e.table.Meta().ID, e.table.Meta(), idxInfo) - return admin.CheckRecordAndIndex(ctx, e.Ctx(), txn, e.table, idx) - } - - info := e.table.Meta().GetPartitionInfo() - for _, def := range info.Definitions { - pid := def.ID - partition := e.table.(table.PartitionedTable).GetPartition(pid) - idx := tables.NewIndex(def.ID, e.table.Meta(), idxInfo) - if err := admin.CheckRecordAndIndex(ctx, e.Ctx(), txn, partition, idx); err != nil { - return errors.Trace(err) - } - } - return nil -} - -// ShowSlowExec represents the executor of showing the slow queries. -// It is build from the "admin show slow" statement: -// -// admin show slow top [internal | all] N -// admin show slow recent N -type ShowSlowExec struct { - exec.BaseExecutor - - ShowSlow *ast.ShowSlow - result []*domain.SlowQueryInfo - cursor int -} - -// Open implements the Executor Open interface. -func (e *ShowSlowExec) Open(ctx context.Context) error { - if err := e.BaseExecutor.Open(ctx); err != nil { - return err - } - - dom := domain.GetDomain(e.Ctx()) - e.result = dom.ShowSlowQuery(e.ShowSlow) - return nil -} - -// Next implements the Executor Next interface. 
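CheckTableExec.Next above drains a buffered task channel with a small pool of workers (at most three), flips an atomic failure flag so the remaining workers stop early, and surfaces the first error through a channel. A stripped-down sketch of that coordination, with illustrative names and a caller-supplied check function in place of the index lookups:

package main

import (
    "fmt"
    "sync"
    "sync/atomic"
)

// checkAll fans tasks out to at most three workers and records the first
// failure, loosely mirroring the taskCh/failure/retCh pattern above.
func checkAll(tasks []string, check func(string) error) error {
    taskCh := make(chan string, len(tasks))
    for _, t := range tasks {
        taskCh <- t
    }
    close(taskCh)

    errCh := make(chan error, 1)
    var failed atomic.Bool
    var wg sync.WaitGroup
    for i := 0; i < min(3, len(tasks)); i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            for t := range taskCh {
                if failed.Load() {
                    return
                }
                if err := check(t); err != nil {
                    failed.Store(true)
                    select {
                    case errCh <- err: // keep only the first error
                    default:
                    }
                    return
                }
            }
        }()
    }
    wg.Wait()

    select {
    case err := <-errCh:
        return err
    default:
        return nil
    }
}

func main() {
    err := checkAll([]string{"idx_a", "idx_b", "idx_c"}, func(name string) error {
        fmt.Println("checking", name)
        return nil
    })
    fmt.Println("all checked, err =", err)
}

Buffering the task channel up front keeps the producer from blocking, and the single-slot error channel is enough because only the first failure is reported.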
-func (e *ShowSlowExec) Next(_ context.Context, req *chunk.Chunk) error { - req.Reset() - if e.cursor >= len(e.result) { - return nil - } - - for e.cursor < len(e.result) && req.NumRows() < e.MaxChunkSize() { - slow := e.result[e.cursor] - req.AppendString(0, slow.SQL) - req.AppendTime(1, types.NewTime(types.FromGoTime(slow.Start), mysql.TypeTimestamp, types.MaxFsp)) - req.AppendDuration(2, types.Duration{Duration: slow.Duration, Fsp: types.MaxFsp}) - req.AppendString(3, slow.Detail.String()) - if slow.Succ { - req.AppendInt64(4, 1) - } else { - req.AppendInt64(4, 0) - } - req.AppendUint64(5, slow.ConnID) - req.AppendUint64(6, slow.TxnTS) - req.AppendString(7, slow.User) - req.AppendString(8, slow.DB) - req.AppendString(9, slow.TableIDs) - req.AppendString(10, slow.IndexNames) - if slow.Internal { - req.AppendInt64(11, 1) - } else { - req.AppendInt64(11, 0) - } - req.AppendString(12, slow.Digest) - req.AppendString(13, slow.SessAlias) - e.cursor++ - } - return nil -} - -// SelectLockExec represents a select lock executor. -// It is built from the "SELECT .. FOR UPDATE" or the "SELECT .. LOCK IN SHARE MODE" statement. -// For "SELECT .. FOR UPDATE" statement, it locks every row key from source Executor. -// After the execution, the keys are buffered in transaction, and will be sent to KV -// when doing commit. If there is any key already locked by another transaction, -// the transaction will rollback and retry. -type SelectLockExec struct { - exec.BaseExecutor - - Lock *ast.SelectLockInfo - keys []kv.Key - - // The children may be a join of multiple tables, so we need a map. - tblID2Handle map[int64][]plannerutil.HandleCols - - // When SelectLock work on a partition table, we need the partition ID - // (Physical Table ID) instead of the 'logical' table ID to calculate - // the lock KV. In that case, the Physical Table ID is extracted - // from the row key in the store and as an extra column in the chunk row. - - // tblID2PhyTblIDCol is used for partitioned tables. - // The child executor need to return an extra column containing - // the Physical Table ID (i.e. from which partition the row came from) - // Used during building - tblID2PhysTblIDCol map[int64]*expression.Column - - // Used during execution - // Map from logic tableID to column index where the physical table id is stored - // For dynamic prune mode, model.ExtraPhysTblID columns are requested from - // storage and used for physical table id - // For static prune mode, model.ExtraPhysTblID is still sent to storage/Protobuf - // but could be filled in by the partitions TableReaderExecutor - // due to issues with chunk handling between the TableReaderExecutor and the - // SelectReader result. - tblID2PhysTblIDColIdx map[int64]int -} - -// Open implements the Executor Open interface. -func (e *SelectLockExec) Open(ctx context.Context) error { - if len(e.tblID2PhysTblIDCol) > 0 { - e.tblID2PhysTblIDColIdx = make(map[int64]int) - cols := e.Schema().Columns - for i := len(cols) - 1; i >= 0; i-- { - if cols[i].ID == model.ExtraPhysTblID { - for tblID, col := range e.tblID2PhysTblIDCol { - if cols[i].UniqueID == col.UniqueID { - e.tblID2PhysTblIDColIdx[tblID] = i - break - } - } - } - } - } - return e.BaseExecutor.Open(ctx) -} - -// Next implements the Executor Next interface. -func (e *SelectLockExec) Next(ctx context.Context, req *chunk.Chunk) error { - req.GrowAndReset(e.MaxChunkSize()) - err := exec.Next(ctx, e.Children(0), req) - if err != nil { - return err - } - // If there's no handle or it's not a `SELECT FOR UPDATE` statement. 
- if len(e.tblID2Handle) == 0 || (!logicalop.IsSelectForUpdateLockType(e.Lock.LockType)) { - return nil - } - - if req.NumRows() > 0 { - iter := chunk.NewIterator4Chunk(req) - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - for tblID, cols := range e.tblID2Handle { - for _, col := range cols { - handle, err := col.BuildHandle(row) - if err != nil { - return err - } - physTblID := tblID - if physTblColIdx, ok := e.tblID2PhysTblIDColIdx[tblID]; ok { - physTblID = row.GetInt64(physTblColIdx) - if physTblID == 0 { - // select * from t1 left join t2 on t1.c = t2.c for update - // The join right side might be added NULL in left join - // In that case, physTblID is 0, so skip adding the lock. - // - // Note, we can't distinguish whether it's the left join case, - // or a bug that TiKV return without correct physical ID column. - continue - } - } - e.keys = append(e.keys, tablecodec.EncodeRowKeyWithHandle(physTblID, handle)) - } - } - } - return nil - } - lockWaitTime := e.Ctx().GetSessionVars().LockWaitTimeout - if e.Lock.LockType == ast.SelectLockForUpdateNoWait { - lockWaitTime = tikvstore.LockNoWait - } else if e.Lock.LockType == ast.SelectLockForUpdateWaitN { - lockWaitTime = int64(e.Lock.WaitSec) * 1000 - } - - for id := range e.tblID2Handle { - e.UpdateDeltaForTableID(id) - } - lockCtx, err := newLockCtx(e.Ctx(), lockWaitTime, len(e.keys)) - if err != nil { - return err - } - return doLockKeys(ctx, e.Ctx(), lockCtx, e.keys...) -} - -func newLockCtx(sctx sessionctx.Context, lockWaitTime int64, numKeys int) (*tikvstore.LockCtx, error) { - seVars := sctx.GetSessionVars() - forUpdateTS, err := sessiontxn.GetTxnManager(sctx).GetStmtForUpdateTS() - if err != nil { - return nil, err - } - lockCtx := tikvstore.NewLockCtx(forUpdateTS, lockWaitTime, seVars.StmtCtx.GetLockWaitStartTime()) - lockCtx.Killed = &seVars.SQLKiller.Signal - lockCtx.PessimisticLockWaited = &seVars.StmtCtx.PessimisticLockWaited - lockCtx.LockKeysDuration = &seVars.StmtCtx.LockKeysDuration - lockCtx.LockKeysCount = &seVars.StmtCtx.LockKeysCount - lockCtx.LockExpired = &seVars.TxnCtx.LockExpire - lockCtx.ResourceGroupTagger = func(req *kvrpcpb.PessimisticLockRequest) []byte { - if req == nil { - return nil - } - if len(req.Mutations) == 0 { - return nil - } - if mutation := req.Mutations[0]; mutation != nil { - label := resourcegrouptag.GetResourceGroupLabelByKey(mutation.Key) - normalized, digest := seVars.StmtCtx.SQLDigest() - if len(normalized) == 0 { - return nil - } - _, planDigest := seVars.StmtCtx.GetPlanDigest() - return resourcegrouptag.EncodeResourceGroupTag(digest, planDigest, label) - } - return nil - } - lockCtx.OnDeadlock = func(deadlock *tikverr.ErrDeadlock) { - cfg := config.GetGlobalConfig() - if deadlock.IsRetryable && !cfg.PessimisticTxn.DeadlockHistoryCollectRetryable { - return - } - rec := deadlockhistory.ErrDeadlockToDeadlockRecord(deadlock) - deadlockhistory.GlobalDeadlockHistory.Push(rec) - } - if lockCtx.ForUpdateTS > 0 && seVars.AssertionLevel != variable.AssertionLevelOff { - lockCtx.InitCheckExistence(numKeys) - } - return lockCtx, nil -} - -// doLockKeys is the main entry for pessimistic lock keys -// waitTime means the lock operation will wait in milliseconds if target key is already -// locked by others. 
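SelectLockExec.Next above picks the lock wait budget from the statement: NOWAIT maps to the client's no-wait sentinel, WAIT n maps to n seconds expressed in milliseconds, and everything else falls back to the session's lock wait timeout. A hedged sketch of that mapping, where the string tags and the sentinel value stand in for the ast lock types and tikvstore.LockNoWait:

package main

import "fmt"

// lockNoWait is an illustrative sentinel; the real constant comes from the
// tikv client (tikvstore.LockNoWait) and may differ.
const lockNoWait = int64(-1)

// lockWaitMillis mirrors the branch in SelectLockExec.Next that turns the
// SELECT ... FOR UPDATE variant into a wait budget in milliseconds.
func lockWaitMillis(lockType string, waitSec uint64, sessionDefault int64) int64 {
    switch lockType {
    case "FOR UPDATE NOWAIT":
        return lockNoWait
    case "FOR UPDATE WAIT":
        return int64(waitSec) * 1000
    default:
        return sessionDefault
    }
}

func main() {
    fmt.Println(lockWaitMillis("FOR UPDATE WAIT", 5, 50000)) // 5000
}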
used for (select for update nowait) situation -func doLockKeys(ctx context.Context, se sessionctx.Context, lockCtx *tikvstore.LockCtx, keys ...kv.Key) error { - sessVars := se.GetSessionVars() - sctx := sessVars.StmtCtx - if !sctx.InUpdateStmt && !sctx.InDeleteStmt { - atomic.StoreUint32(&se.GetSessionVars().TxnCtx.ForUpdate, 1) - } - // Lock keys only once when finished fetching all results. - txn, err := se.Txn(true) - if err != nil { - return err - } - - // Skip the temporary table keys. - keys = filterTemporaryTableKeys(sessVars, keys) - - keys = filterLockTableKeys(sessVars.StmtCtx, keys) - var lockKeyStats *tikvutil.LockKeysDetails - ctx = context.WithValue(ctx, tikvutil.LockKeysDetailCtxKey, &lockKeyStats) - err = txn.LockKeys(tikvutil.SetSessionID(ctx, se.GetSessionVars().ConnectionID), lockCtx, keys...) - if lockKeyStats != nil { - sctx.MergeLockKeysExecDetails(lockKeyStats) - } - return err -} - -func filterTemporaryTableKeys(vars *variable.SessionVars, keys []kv.Key) []kv.Key { - txnCtx := vars.TxnCtx - if txnCtx == nil || txnCtx.TemporaryTables == nil { - return keys - } - - newKeys := keys[:0:len(keys)] - for _, key := range keys { - tblID := tablecodec.DecodeTableID(key) - if _, ok := txnCtx.TemporaryTables[tblID]; !ok { - newKeys = append(newKeys, key) - } - } - return newKeys -} - -func filterLockTableKeys(stmtCtx *stmtctx.StatementContext, keys []kv.Key) []kv.Key { - if len(stmtCtx.LockTableIDs) == 0 { - return keys - } - newKeys := keys[:0:len(keys)] - for _, key := range keys { - tblID := tablecodec.DecodeTableID(key) - if _, ok := stmtCtx.LockTableIDs[tblID]; ok { - newKeys = append(newKeys, key) - } - } - return newKeys -} - -// LimitExec represents limit executor -// It ignores 'Offset' rows from src, then returns 'Count' rows at maximum. -type LimitExec struct { - exec.BaseExecutor - - begin uint64 - end uint64 - cursor uint64 - - // meetFirstBatch represents whether we have met the first valid Chunk from child. - meetFirstBatch bool - - childResult *chunk.Chunk - - // columnIdxsUsedByChild keep column indexes of child executor used for inline projection - columnIdxsUsedByChild []int - - // Log the close time when opentracing is enabled. - span opentracing.Span -} - -// Next implements the Executor Next interface. -func (e *LimitExec) Next(ctx context.Context, req *chunk.Chunk) error { - req.Reset() - if e.cursor >= e.end { - return nil - } - for !e.meetFirstBatch { - // transfer req's requiredRows to childResult and then adjust it in childResult - e.childResult = e.childResult.SetRequiredRows(req.RequiredRows(), e.MaxChunkSize()) - err := exec.Next(ctx, e.Children(0), e.adjustRequiredRows(e.childResult)) - if err != nil { - return err - } - batchSize := uint64(e.childResult.NumRows()) - // no more data. 
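filterTemporaryTableKeys and filterLockTableKeys above reuse the input slice's backing array via the three-index expression keys[:0:len(keys)], so filtering allocates nothing and can never grow past the original capacity. The same idiom on a plain int slice (illustrative helper):

package main

import "fmt"

// filterEven keeps only even values, reusing the input's backing array the
// same way the key filters above do with keys[:0:len(keys)].
func filterEven(in []int) []int {
    out := in[:0:len(in)]
    for _, v := range in {
        if v%2 == 0 {
            out = append(out, v)
        }
    }
    return out
}

func main() {
    fmt.Println(filterEven([]int{1, 2, 3, 4, 6})) // [2 4 6]
}

Because the write index can never pass the read index, overwriting the input in place is safe; the capped capacity guarantees any accidental extra append would reallocate rather than clobber memory beyond the original slice.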
- if batchSize == 0 { - return nil - } - if newCursor := e.cursor + batchSize; newCursor >= e.begin { - e.meetFirstBatch = true - begin, end := e.begin-e.cursor, batchSize - if newCursor > e.end { - end = e.end - e.cursor - } - e.cursor += end - if begin == end { - break - } - if e.columnIdxsUsedByChild != nil { - req.Append(e.childResult.Prune(e.columnIdxsUsedByChild), int(begin), int(end)) - } else { - req.Append(e.childResult, int(begin), int(end)) - } - return nil - } - e.cursor += batchSize - } - e.childResult.Reset() - e.childResult = e.childResult.SetRequiredRows(req.RequiredRows(), e.MaxChunkSize()) - e.adjustRequiredRows(e.childResult) - err := exec.Next(ctx, e.Children(0), e.childResult) - if err != nil { - return err - } - batchSize := uint64(e.childResult.NumRows()) - // no more data. - if batchSize == 0 { - return nil - } - if e.cursor+batchSize > e.end { - e.childResult.TruncateTo(int(e.end - e.cursor)) - batchSize = e.end - e.cursor - } - e.cursor += batchSize - - if e.columnIdxsUsedByChild != nil { - for i, childIdx := range e.columnIdxsUsedByChild { - if err = req.SwapColumn(i, e.childResult, childIdx); err != nil { - return err - } - } - } else { - req.SwapColumns(e.childResult) - } - return nil -} - -// Open implements the Executor Open interface. -func (e *LimitExec) Open(ctx context.Context) error { - if err := e.BaseExecutor.Open(ctx); err != nil { - return err - } - e.childResult = exec.TryNewCacheChunk(e.Children(0)) - e.cursor = 0 - e.meetFirstBatch = e.begin == 0 - if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { - e.span = span - } - return nil -} - -// Close implements the Executor Close interface. -func (e *LimitExec) Close() error { - start := time.Now() - - e.childResult = nil - err := e.BaseExecutor.Close() - - elapsed := time.Since(start) - if elapsed > time.Millisecond { - logutil.BgLogger().Info("limit executor close takes a long time", - zap.Duration("elapsed", elapsed)) - if e.span != nil { - span1 := e.span.Tracer().StartSpan("limitExec.Close", opentracing.ChildOf(e.span.Context()), opentracing.StartTime(start)) - defer span1.Finish() - } - } - return err -} - -func (e *LimitExec) adjustRequiredRows(chk *chunk.Chunk) *chunk.Chunk { - // the limit of maximum number of rows the LimitExec should read - limitTotal := int(e.end - e.cursor) - - var limitRequired int - if e.cursor < e.begin { - // if cursor is less than begin, it have to read (begin-cursor) rows to ignore - // and then read chk.RequiredRows() rows to return, - // so the limit is (begin-cursor)+chk.RequiredRows(). - limitRequired = int(e.begin) - int(e.cursor) + chk.RequiredRows() - } else { - // if cursor is equal or larger than begin, just read chk.RequiredRows() rows to return. - limitRequired = chk.RequiredRows() - } - - return chk.SetRequiredRows(min(limitTotal, limitRequired), e.MaxChunkSize()) -} - -func init() { - // While doing optimization in the plan package, we need to execute uncorrelated subquery, - // but the plan package cannot import the executor package because of the dependency cycle. - // So we assign a function implemented in the executor package to the plan package to avoid the dependency cycle. 
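LimitExec above keeps a running cursor across child chunks and copies only the slice of each batch that overlaps [begin, end). A toy reimplementation of that window arithmetic over ints rather than chunks (names are illustrative):

package main

import "fmt"

// limitWindow streams fixed-size batches and keeps only rows in
// [offset, offset+count), mimicking how LimitExec tracks its cursor.
func limitWindow(rows []int, batch, offset, count int) []int {
    begin, end := offset, offset+count
    cursor := 0
    var out []int
    for i := 0; i < len(rows) && cursor < end; i += batch {
        b := rows[i:min(i+batch, len(rows))]
        next := cursor + len(b)
        if next > begin {
            lo := max(begin-cursor, 0)
            hi := min(end-cursor, len(b))
            out = append(out, b[lo:hi]...)
        }
        cursor = next
    }
    return out
}

func main() {
    rows := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
    fmt.Println(limitWindow(rows, 3, 4, 3)) // [4 5 6]
}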
- plannercore.EvalSubqueryFirstRow = func(ctx context.Context, p base.PhysicalPlan, is infoschema.InfoSchema, pctx planctx.PlanContext) ([]types.Datum, error) { - if fixcontrol.GetBoolWithDefault(pctx.GetSessionVars().OptimizerFixControl, fixcontrol.Fix43817, false) { - return nil, errors.NewNoStackError("evaluate non-correlated sub-queries during optimization phase is not allowed by fix-control 43817") - } - - defer func(begin time.Time) { - s := pctx.GetSessionVars() - s.StmtCtx.SetSkipPlanCache("query has uncorrelated sub-queries is un-cacheable") - s.RewritePhaseInfo.PreprocessSubQueries++ - s.RewritePhaseInfo.DurationPreprocessSubQuery += time.Since(begin) - }(time.Now()) - - r, ctx := tracing.StartRegionEx(ctx, "executor.EvalSubQuery") - defer r.End() - - sctx, err := plannercore.AsSctx(pctx) - intest.AssertNoError(err) - if err != nil { - return nil, err - } - - e := newExecutorBuilder(sctx, is) - executor := e.build(p) - if e.err != nil { - return nil, e.err - } - err = exec.Open(ctx, executor) - defer func() { terror.Log(exec.Close(executor)) }() - if err != nil { - return nil, err - } - if pi, ok := sctx.(processinfoSetter); ok { - // Before executing the sub-query, we need update the processinfo to make the progress bar more accurate. - // because the sub-query may take a long time. - pi.UpdateProcessInfo() - } - chk := exec.TryNewCacheChunk(executor) - err = exec.Next(ctx, executor, chk) - if err != nil { - return nil, err - } - if chk.NumRows() == 0 { - return nil, nil - } - row := chk.GetRow(0).GetDatumRow(exec.RetTypes(executor)) - return row, err - } -} - -// TableDualExec represents a dual table executor. -type TableDualExec struct { - exec.BaseExecutorV2 - - // numDualRows can only be 0 or 1. - numDualRows int - numReturned int -} - -// Open implements the Executor Open interface. -func (e *TableDualExec) Open(context.Context) error { - e.numReturned = 0 - return nil -} - -// Next implements the Executor Next interface. -func (e *TableDualExec) Next(_ context.Context, req *chunk.Chunk) error { - req.Reset() - if e.numReturned >= e.numDualRows { - return nil - } - if e.Schema().Len() == 0 { - req.SetNumVirtualRows(1) - } else { - for i := range e.Schema().Columns { - req.AppendNull(i) - } - } - e.numReturned = e.numDualRows - return nil -} - -type selectionExecutorContext struct { - stmtMemTracker *memory.Tracker - evalCtx expression.EvalContext - enableVectorizedExpression bool -} - -func newSelectionExecutorContext(sctx sessionctx.Context) selectionExecutorContext { - return selectionExecutorContext{ - stmtMemTracker: sctx.GetSessionVars().StmtCtx.MemTracker, - evalCtx: sctx.GetExprCtx().GetEvalCtx(), - enableVectorizedExpression: sctx.GetSessionVars().EnableVectorizedExpression, - } -} - -// SelectionExec represents a filter executor. -type SelectionExec struct { - selectionExecutorContext - exec.BaseExecutorV2 - - batched bool - filters []expression.Expression - selected []bool - inputIter *chunk.Iterator4Chunk - inputRow chunk.Row - childResult *chunk.Chunk - - memTracker *memory.Tracker -} - -// Open implements the Executor Open interface. 
-func (e *SelectionExec) Open(ctx context.Context) error { - if err := e.BaseExecutorV2.Open(ctx); err != nil { - return err - } - failpoint.Inject("mockSelectionExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(errors.New("mock SelectionExec.baseExecutor.Open returned error")) - } - }) - return e.open(ctx) -} - -func (e *SelectionExec) open(context.Context) error { - if e.memTracker != nil { - e.memTracker.Reset() - } else { - e.memTracker = memory.NewTracker(e.ID(), -1) - } - e.memTracker.AttachTo(e.stmtMemTracker) - e.childResult = exec.TryNewCacheChunk(e.Children(0)) - e.memTracker.Consume(e.childResult.MemoryUsage()) - e.batched = expression.Vectorizable(e.filters) - if e.batched { - e.selected = make([]bool, 0, chunk.InitialCapacity) - } - e.inputIter = chunk.NewIterator4Chunk(e.childResult) - e.inputRow = e.inputIter.End() - return nil -} - -// Close implements plannercore.Plan Close interface. -func (e *SelectionExec) Close() error { - if e.childResult != nil { - e.memTracker.Consume(-e.childResult.MemoryUsage()) - e.childResult = nil - } - e.selected = nil - return e.BaseExecutorV2.Close() -} - -// Next implements the Executor Next interface. -func (e *SelectionExec) Next(ctx context.Context, req *chunk.Chunk) error { - req.GrowAndReset(e.MaxChunkSize()) - - if !e.batched { - return e.unBatchedNext(ctx, req) - } - - for { - for ; e.inputRow != e.inputIter.End(); e.inputRow = e.inputIter.Next() { - if req.IsFull() { - return nil - } - - if !e.selected[e.inputRow.Idx()] { - continue - } - - req.AppendRow(e.inputRow) - } - mSize := e.childResult.MemoryUsage() - err := exec.Next(ctx, e.Children(0), e.childResult) - e.memTracker.Consume(e.childResult.MemoryUsage() - mSize) - if err != nil { - return err - } - // no more data. - if e.childResult.NumRows() == 0 { - return nil - } - e.selected, err = expression.VectorizedFilter(e.evalCtx, e.enableVectorizedExpression, e.filters, e.inputIter, e.selected) - if err != nil { - return err - } - e.inputRow = e.inputIter.Begin() - } -} - -// unBatchedNext filters input rows one by one and returns once an input row is selected. -// For sql with "SETVAR" in filter and "GETVAR" in projection, for example: "SELECT @a FROM t WHERE (@a := 2) > 0", -// we have to set batch size to 1 to do the evaluation of filter and projection. -func (e *SelectionExec) unBatchedNext(ctx context.Context, chk *chunk.Chunk) error { - evalCtx := e.evalCtx - for { - for ; e.inputRow != e.inputIter.End(); e.inputRow = e.inputIter.Next() { - selected, _, err := expression.EvalBool(evalCtx, e.filters, e.inputRow) - if err != nil { - return err - } - if selected { - chk.AppendRow(e.inputRow) - e.inputRow = e.inputIter.Next() - return nil - } - } - mSize := e.childResult.MemoryUsage() - err := exec.Next(ctx, e.Children(0), e.childResult) - e.memTracker.Consume(e.childResult.MemoryUsage() - mSize) - if err != nil { - return err - } - e.inputRow = e.inputIter.Begin() - // no more data. - if e.childResult.NumRows() == 0 { - return nil - } - } -} - -// TableScanExec is a table scan executor without result fields. -type TableScanExec struct { - exec.BaseExecutor - - t table.Table - columns []*model.ColumnInfo - virtualTableChunkList *chunk.List - virtualTableChunkIdx int -} - -// Next implements the Executor Next interface. 
-func (e *TableScanExec) Next(ctx context.Context, req *chunk.Chunk) error { - req.GrowAndReset(e.MaxChunkSize()) - return e.nextChunk4InfoSchema(ctx, req) -} - -func (e *TableScanExec) nextChunk4InfoSchema(ctx context.Context, chk *chunk.Chunk) error { - chk.GrowAndReset(e.MaxChunkSize()) - if e.virtualTableChunkList == nil { - e.virtualTableChunkList = chunk.NewList(exec.RetTypes(e), e.InitCap(), e.MaxChunkSize()) - columns := make([]*table.Column, e.Schema().Len()) - for i, colInfo := range e.columns { - columns[i] = table.ToColumn(colInfo) - } - mutableRow := chunk.MutRowFromTypes(exec.RetTypes(e)) - type tableIter interface { - IterRecords(ctx context.Context, sctx sessionctx.Context, cols []*table.Column, fn table.RecordIterFunc) error - } - err := (e.t.(tableIter)).IterRecords(ctx, e.Ctx(), columns, func(_ kv.Handle, rec []types.Datum, _ []*table.Column) (bool, error) { - mutableRow.SetDatums(rec...) - e.virtualTableChunkList.AppendRow(mutableRow.ToRow()) - return true, nil - }) - if err != nil { - return err - } - } - // no more data. - if e.virtualTableChunkIdx >= e.virtualTableChunkList.NumChunks() { - return nil - } - virtualTableChunk := e.virtualTableChunkList.GetChunk(e.virtualTableChunkIdx) - e.virtualTableChunkIdx++ - chk.SwapColumns(virtualTableChunk) - return nil -} - -// Open implements the Executor Open interface. -func (e *TableScanExec) Open(context.Context) error { - e.virtualTableChunkList = nil - return nil -} - -// MaxOneRowExec checks if the number of rows that a query returns is at maximum one. -// It's built from subquery expression. -type MaxOneRowExec struct { - exec.BaseExecutor - - evaluated bool -} - -// Open implements the Executor Open interface. -func (e *MaxOneRowExec) Open(ctx context.Context) error { - if err := e.BaseExecutor.Open(ctx); err != nil { - return err - } - e.evaluated = false - return nil -} - -// Next implements the Executor Next interface. -func (e *MaxOneRowExec) Next(ctx context.Context, req *chunk.Chunk) error { - req.Reset() - if e.evaluated { - return nil - } - e.evaluated = true - err := exec.Next(ctx, e.Children(0), req) - if err != nil { - return err - } - - if num := req.NumRows(); num == 0 { - for i := range e.Schema().Columns { - req.AppendNull(i) - } - return nil - } else if num != 1 { - return exeerrors.ErrSubqueryMoreThan1Row - } - - childChunk := exec.TryNewCacheChunk(e.Children(0)) - err = exec.Next(ctx, e.Children(0), childChunk) - if err != nil { - return err - } - if childChunk.NumRows() != 0 { - return exeerrors.ErrSubqueryMoreThan1Row - } - - return nil -} - -// ResetContextOfStmt resets the StmtContext and session variables. -// Before every execution, we must clear statement context. 
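MaxOneRowExec above enforces the scalar-subquery contract: zero rows degrade to NULL, one row passes through, and anything more is an error. A generic stand-alone sketch of that contract (not a TiDB API):

package main

import (
    "errors"
    "fmt"
)

// firstAndOnly returns the single element of a result set and fails when more
// than one row is present, the same contract MaxOneRowExec enforces.
func firstAndOnly[T any](rows []T) (T, error) {
    var zero T
    switch len(rows) {
    case 0:
        return zero, nil // a missing row becomes NULL upstream
    case 1:
        return rows[0], nil
    default:
        return zero, errors.New("subquery returns more than 1 row")
    }
}

func main() {
    v, err := firstAndOnly([]int{42})
    fmt.Println(v, err)
    _, err = firstAndOnly([]int{1, 2})
    fmt.Println(err)
}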
-func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { - defer func() { - if r := recover(); r != nil { - logutil.BgLogger().Warn("ResetContextOfStmt panicked", zap.Stack("stack"), zap.Any("recover", r), zap.Error(err)) - if err != nil { - err = stderrors.Join(err, util.GetRecoverError(r)) - } else { - err = util.GetRecoverError(r) - } - } - }() - vars := ctx.GetSessionVars() - for name, val := range vars.StmtCtx.SetVarHintRestore { - err := vars.SetSystemVar(name, val) - if err != nil { - logutil.BgLogger().Warn("Failed to restore the variable after SET_VAR hint", zap.String("variable name", name), zap.String("expected value", val)) - } - } - vars.StmtCtx.SetVarHintRestore = nil - var sc *stmtctx.StatementContext - if vars.TxnCtx.CouldRetry || vars.HasStatusFlag(mysql.ServerStatusCursorExists) { - // Must construct new statement context object, the retry history need context for every statement. - // TODO: Maybe one day we can get rid of transaction retry, then this logic can be deleted. - sc = stmtctx.NewStmtCtx() - } else { - sc = vars.InitStatementContext() - } - sc.SetTimeZone(vars.Location()) - sc.TaskID = stmtctx.AllocateTaskID() - if sc.CTEStorageMap == nil { - sc.CTEStorageMap = map[int]*CTEStorages{} - } else { - clear(sc.CTEStorageMap.(map[int]*CTEStorages)) - } - if sc.LockTableIDs == nil { - sc.LockTableIDs = make(map[int64]struct{}) - } else { - clear(sc.LockTableIDs) - } - if sc.TableStats == nil { - sc.TableStats = make(map[int64]any) - } else { - clear(sc.TableStats) - } - if sc.MDLRelatedTableIDs == nil { - sc.MDLRelatedTableIDs = make(map[int64]struct{}) - } else { - clear(sc.MDLRelatedTableIDs) - } - if sc.TblInfo2UnionScan == nil { - sc.TblInfo2UnionScan = make(map[*model.TableInfo]bool) - } else { - clear(sc.TblInfo2UnionScan) - } - sc.IsStaleness = false - sc.EnableOptimizeTrace = false - sc.OptimizeTracer = nil - sc.OptimizerCETrace = nil - sc.IsSyncStatsFailed = false - sc.IsExplainAnalyzeDML = false - sc.ResourceGroupName = vars.ResourceGroupName - // Firstly we assume that UseDynamicPruneMode can be enabled according session variable, then we will check other conditions - // in PlanBuilder.buildDataSource - if ctx.GetSessionVars().IsDynamicPartitionPruneEnabled() { - sc.UseDynamicPruneMode = true - } else { - sc.UseDynamicPruneMode = false - } - - sc.StatsLoad.Timeout = 0 - sc.StatsLoad.NeededItems = nil - sc.StatsLoad.ResultCh = nil - - sc.SysdateIsNow = ctx.GetSessionVars().SysdateIsNow - - vars.MemTracker.Detach() - vars.MemTracker.UnbindActions() - vars.MemTracker.SetBytesLimit(vars.MemQuotaQuery) - vars.MemTracker.ResetMaxConsumed() - vars.DiskTracker.Detach() - vars.DiskTracker.ResetMaxConsumed() - vars.MemTracker.SessionID.Store(vars.ConnectionID) - vars.MemTracker.Killer = &vars.SQLKiller - vars.DiskTracker.Killer = &vars.SQLKiller - vars.SQLKiller.Reset() - vars.SQLKiller.ConnID = vars.ConnectionID - - isAnalyze := false - if execStmt, ok := s.(*ast.ExecuteStmt); ok { - prepareStmt, err := plannercore.GetPreparedStmt(execStmt, vars) - if err != nil { - return err - } - _, isAnalyze = prepareStmt.PreparedAst.Stmt.(*ast.AnalyzeTableStmt) - } else if _, ok := s.(*ast.AnalyzeTableStmt); ok { - isAnalyze = true - } - if isAnalyze { - sc.InitMemTracker(memory.LabelForAnalyzeMemory, -1) - vars.MemTracker.SetBytesLimit(-1) - vars.MemTracker.AttachTo(GlobalAnalyzeMemoryTracker) - } else { - sc.InitMemTracker(memory.LabelForSQLText, -1) - } - logOnQueryExceedMemQuota := domain.GetDomain(ctx).ExpensiveQueryHandle().LogOnQueryExceedMemQuota - 
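ResetContextOfStmt above resets its per-statement maps with the clear builtin when they already exist and only allocates them on first use, trading a small reset loop for fewer allocations per statement. A minimal sketch of that reuse pattern:

package main

import "fmt"

// resetCache reuses the existing map allocation when possible instead of
// reallocating per statement, the same clear-or-make shape used above.
func resetCache(m map[int64]struct{}) map[int64]struct{} {
    if m == nil {
        return make(map[int64]struct{})
    }
    clear(m)
    return m
}

func main() {
    m := map[int64]struct{}{1: {}, 2: {}}
    m = resetCache(m)
    fmt.Println(len(m)) // 0, same underlying map
}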
switch variable.OOMAction.Load() { - case variable.OOMActionCancel: - action := &memory.PanicOnExceed{ConnID: vars.ConnectionID, Killer: vars.MemTracker.Killer} - action.SetLogHook(logOnQueryExceedMemQuota) - vars.MemTracker.SetActionOnExceed(action) - case variable.OOMActionLog: - fallthrough - default: - action := &memory.LogOnExceed{ConnID: vars.ConnectionID} - action.SetLogHook(logOnQueryExceedMemQuota) - vars.MemTracker.SetActionOnExceed(action) - } - sc.MemTracker.SessionID.Store(vars.ConnectionID) - sc.MemTracker.AttachTo(vars.MemTracker) - sc.InitDiskTracker(memory.LabelForSQLText, -1) - globalConfig := config.GetGlobalConfig() - if variable.EnableTmpStorageOnOOM.Load() && sc.DiskTracker != nil { - sc.DiskTracker.AttachTo(vars.DiskTracker) - if GlobalDiskUsageTracker != nil { - vars.DiskTracker.AttachTo(GlobalDiskUsageTracker) - } - } - if execStmt, ok := s.(*ast.ExecuteStmt); ok { - prepareStmt, err := plannercore.GetPreparedStmt(execStmt, vars) - if err != nil { - return err - } - s = prepareStmt.PreparedAst.Stmt - sc.InitSQLDigest(prepareStmt.NormalizedSQL, prepareStmt.SQLDigest) - // For `execute stmt` SQL, should reset the SQL digest with the prepare SQL digest. - goCtx := context.Background() - if variable.EnablePProfSQLCPU.Load() && len(prepareStmt.NormalizedSQL) > 0 { - goCtx = pprof.WithLabels(goCtx, pprof.Labels("sql", FormatSQL(prepareStmt.NormalizedSQL).String())) - pprof.SetGoroutineLabels(goCtx) - } - if topsqlstate.TopSQLEnabled() && prepareStmt.SQLDigest != nil { - sc.IsSQLRegistered.Store(true) - topsql.AttachAndRegisterSQLInfo(goCtx, prepareStmt.NormalizedSQL, prepareStmt.SQLDigest, vars.InRestrictedSQL) - } - if s, ok := prepareStmt.PreparedAst.Stmt.(*ast.SelectStmt); ok { - if s.LockInfo == nil { - sc.WeakConsistency = isWeakConsistencyRead(ctx, execStmt) - } - } - } - // execute missed stmtID uses empty sql - sc.OriginalSQL = s.Text() - if explainStmt, ok := s.(*ast.ExplainStmt); ok { - sc.InExplainStmt = true - sc.ExplainFormat = explainStmt.Format - sc.InExplainAnalyzeStmt = explainStmt.Analyze - sc.IgnoreExplainIDSuffix = strings.ToLower(explainStmt.Format) == types.ExplainFormatBrief - sc.InVerboseExplain = strings.ToLower(explainStmt.Format) == types.ExplainFormatVerbose - s = explainStmt.Stmt - } else { - sc.ExplainFormat = "" - } - if explainForStmt, ok := s.(*ast.ExplainForStmt); ok { - sc.InExplainStmt = true - sc.InExplainAnalyzeStmt = true - sc.InVerboseExplain = strings.ToLower(explainForStmt.Format) == types.ExplainFormatVerbose - } - - // TODO: Many same bool variables here. - // We should set only two variables ( - // IgnoreErr and StrictSQLMode) to avoid setting the same bool variables and - // pushing them down to TiKV as flags. - - sc.InRestrictedSQL = vars.InRestrictedSQL - strictSQLMode := vars.SQLMode.HasStrictMode() - - errLevels := sc.ErrLevels() - errLevels[errctx.ErrGroupDividedByZero] = errctx.LevelWarn - switch stmt := s.(type) { - // `ResetUpdateStmtCtx` and `ResetDeleteStmtCtx` may modify the flags, so we'll need to store them. - case *ast.UpdateStmt: - ResetUpdateStmtCtx(sc, stmt, vars) - errLevels = sc.ErrLevels() - case *ast.DeleteStmt: - ResetDeleteStmtCtx(sc, stmt, vars) - errLevels = sc.ErrLevels() - case *ast.InsertStmt: - sc.InInsertStmt = true - // For insert statement (not for update statement), disabling the StrictSQLMode - // should make TruncateAsWarning and DividedByZeroAsWarning, - // but should not make DupKeyAsWarning. 
- if stmt.IgnoreErr { - errLevels[errctx.ErrGroupDupKey] = errctx.LevelWarn - errLevels[errctx.ErrGroupAutoIncReadFailed] = errctx.LevelWarn - errLevels[errctx.ErrGroupNoMatchedPartition] = errctx.LevelWarn - } - errLevels[errctx.ErrGroupBadNull] = errctx.ResolveErrLevel(false, !strictSQLMode || stmt.IgnoreErr) - errLevels[errctx.ErrGroupDividedByZero] = errctx.ResolveErrLevel( - !vars.SQLMode.HasErrorForDivisionByZeroMode(), - !strictSQLMode || stmt.IgnoreErr, - ) - sc.Priority = stmt.Priority - sc.SetTypeFlags(sc.TypeFlags(). - WithTruncateAsWarning(!strictSQLMode || stmt.IgnoreErr). - WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode()). - WithIgnoreZeroInDate(!vars.SQLMode.HasNoZeroInDateMode() || - !vars.SQLMode.HasNoZeroDateMode() || !strictSQLMode || stmt.IgnoreErr || - vars.SQLMode.HasAllowInvalidDatesMode())) - case *ast.CreateTableStmt, *ast.AlterTableStmt: - sc.InCreateOrAlterStmt = true - sc.SetTypeFlags(sc.TypeFlags(). - WithTruncateAsWarning(!strictSQLMode). - WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode()). - WithIgnoreZeroInDate(!vars.SQLMode.HasNoZeroInDateMode() || !strictSQLMode || - vars.SQLMode.HasAllowInvalidDatesMode()). - WithIgnoreZeroDateErr(!vars.SQLMode.HasNoZeroDateMode() || !strictSQLMode)) - - case *ast.LoadDataStmt: - sc.InLoadDataStmt = true - // return warning instead of error when load data meet no partition for value - errLevels[errctx.ErrGroupNoMatchedPartition] = errctx.LevelWarn - case *ast.SelectStmt: - sc.InSelectStmt = true - - // Return warning for truncate error in selection. - sc.SetTypeFlags(sc.TypeFlags(). - WithTruncateAsWarning(true). - WithIgnoreZeroInDate(true). - WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode())) - if opts := stmt.SelectStmtOpts; opts != nil { - sc.Priority = opts.Priority - sc.NotFillCache = !opts.SQLCache - } - sc.WeakConsistency = isWeakConsistencyRead(ctx, stmt) - case *ast.SetOprStmt: - sc.InSelectStmt = true - sc.SetTypeFlags(sc.TypeFlags(). - WithTruncateAsWarning(true). - WithIgnoreZeroInDate(true). - WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode())) - case *ast.ShowStmt: - sc.SetTypeFlags(sc.TypeFlags(). - WithIgnoreTruncateErr(true). - WithIgnoreZeroInDate(true). - WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode())) - if stmt.Tp == ast.ShowWarnings || stmt.Tp == ast.ShowErrors || stmt.Tp == ast.ShowSessionStates { - sc.InShowWarning = true - sc.SetWarnings(vars.StmtCtx.GetWarnings()) - } - case *ast.SplitRegionStmt: - sc.SetTypeFlags(sc.TypeFlags(). - WithIgnoreTruncateErr(false). - WithIgnoreZeroInDate(true). - WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode())) - case *ast.SetSessionStatesStmt: - sc.InSetSessionStatesStmt = true - sc.SetTypeFlags(sc.TypeFlags(). - WithIgnoreTruncateErr(true). - WithIgnoreZeroInDate(true). - WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode())) - default: - sc.SetTypeFlags(sc.TypeFlags(). - WithIgnoreTruncateErr(true). - WithIgnoreZeroInDate(true). - WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode())) - } - - if errLevels != sc.ErrLevels() { - sc.SetErrLevels(errLevels) - } - - sc.SetTypeFlags(sc.TypeFlags(). - WithSkipUTF8Check(vars.SkipUTF8Check). - WithSkipSACIICheck(vars.SkipASCIICheck). - WithSkipUTF8MB4Check(!globalConfig.Instance.CheckMb4ValueInUTF8.Load()). - // WithAllowNegativeToUnsigned with false value indicates values less than 0 should be clipped to 0 for unsigned integer types. 
- // This is the case for `insert`, `update`, `alter table`, `create table` and `load data infile` statements, when not in strict SQL mode. - // see https://dev.mysql.com/doc/refman/5.7/en/out-of-range-and-overflow.html - WithAllowNegativeToUnsigned(!sc.InInsertStmt && !sc.InLoadDataStmt && !sc.InUpdateStmt && !sc.InCreateOrAlterStmt), - ) - - vars.PlanCacheParams.Reset() - if priority := mysql.PriorityEnum(atomic.LoadInt32(&variable.ForcePriority)); priority != mysql.NoPriority { - sc.Priority = priority - } - if vars.StmtCtx.LastInsertID > 0 { - sc.PrevLastInsertID = vars.StmtCtx.LastInsertID - } else { - sc.PrevLastInsertID = vars.StmtCtx.PrevLastInsertID - } - sc.PrevAffectedRows = 0 - if vars.StmtCtx.InUpdateStmt || vars.StmtCtx.InDeleteStmt || vars.StmtCtx.InInsertStmt || vars.StmtCtx.InSetSessionStatesStmt { - sc.PrevAffectedRows = int64(vars.StmtCtx.AffectedRows()) - } else if vars.StmtCtx.InSelectStmt { - sc.PrevAffectedRows = -1 - } - if globalConfig.Instance.EnableCollectExecutionInfo.Load() { - // In ExplainFor case, RuntimeStatsColl should not be reset for reuse, - // because ExplainFor need to display the last statement information. - reuseObj := vars.StmtCtx.RuntimeStatsColl - if _, ok := s.(*ast.ExplainForStmt); ok { - reuseObj = nil - } - sc.RuntimeStatsColl = execdetails.NewRuntimeStatsColl(reuseObj) - - // also enable index usage collector - if sc.IndexUsageCollector == nil { - sc.IndexUsageCollector = ctx.NewStmtIndexUsageCollector() - } else { - sc.IndexUsageCollector.Reset() - } - } else { - // turn off the index usage collector - sc.IndexUsageCollector = nil - } - - sc.SetForcePlanCache(fixcontrol.GetBoolWithDefault(vars.OptimizerFixControl, fixcontrol.Fix49736, false)) - sc.SetAlwaysWarnSkipCache(sc.InExplainStmt && sc.ExplainFormat == "plan_cache") - errCount, warnCount := vars.StmtCtx.NumErrorWarnings() - vars.SysErrorCount = errCount - vars.SysWarningCount = warnCount - vars.ExchangeChunkStatus() - vars.StmtCtx = sc - vars.PrevFoundInPlanCache = vars.FoundInPlanCache - vars.FoundInPlanCache = false - vars.PrevFoundInBinding = vars.FoundInBinding - vars.FoundInBinding = false - vars.DurationWaitTS = 0 - vars.CurrInsertBatchExtraCols = nil - vars.CurrInsertValues = chunk.Row{} - - return -} - -// ResetUpdateStmtCtx resets statement context for UpdateStmt. -func ResetUpdateStmtCtx(sc *stmtctx.StatementContext, stmt *ast.UpdateStmt, vars *variable.SessionVars) { - strictSQLMode := vars.SQLMode.HasStrictMode() - sc.InUpdateStmt = true - errLevels := sc.ErrLevels() - errLevels[errctx.ErrGroupDupKey] = errctx.ResolveErrLevel(false, stmt.IgnoreErr) - errLevels[errctx.ErrGroupBadNull] = errctx.ResolveErrLevel(false, !strictSQLMode || stmt.IgnoreErr) - errLevels[errctx.ErrGroupDividedByZero] = errctx.ResolveErrLevel( - !vars.SQLMode.HasErrorForDivisionByZeroMode(), - !strictSQLMode || stmt.IgnoreErr, - ) - errLevels[errctx.ErrGroupNoMatchedPartition] = errctx.ResolveErrLevel(false, stmt.IgnoreErr) - sc.SetErrLevels(errLevels) - sc.Priority = stmt.Priority - sc.SetTypeFlags(sc.TypeFlags(). - WithTruncateAsWarning(!strictSQLMode || stmt.IgnoreErr). - WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode()). - WithIgnoreZeroInDate(!vars.SQLMode.HasNoZeroInDateMode() || !vars.SQLMode.HasNoZeroDateMode() || - !strictSQLMode || stmt.IgnoreErr || vars.SQLMode.HasAllowInvalidDatesMode())) -} - -// ResetDeleteStmtCtx resets statement context for DeleteStmt. 
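ResetUpdateStmtCtx above derives each error group's level from two booleans: whether the error should be silenced outright and whether it should be downgraded to a warning (driven by strict SQL mode and IGNORE). A hedged reimplementation of that shape, assuming ResolveErrLevel prioritizes ignore over warn over error (illustrative types, not the errctx package):

package main

import "fmt"

type errLevel int

const (
    levelError errLevel = iota
    levelWarn
    levelIgnore
)

// resolveErrLevel sketches the assumed precedence: ignore wins over warn,
// which wins over the default of reporting an error.
func resolveErrLevel(ignore, warn bool) errLevel {
    switch {
    case ignore:
        return levelIgnore
    case warn:
        return levelWarn
    default:
        return levelError
    }
}

func main() {
    strictSQLMode, ignoreErr := true, false
    // Division by zero in a strict-mode UPDATE without IGNORE stays an error.
    fmt.Println(resolveErrLevel(false, !strictSQLMode || ignoreErr) == levelError) // true
}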
-func ResetDeleteStmtCtx(sc *stmtctx.StatementContext, stmt *ast.DeleteStmt, vars *variable.SessionVars) { - strictSQLMode := vars.SQLMode.HasStrictMode() - sc.InDeleteStmt = true - errLevels := sc.ErrLevels() - errLevels[errctx.ErrGroupDupKey] = errctx.ResolveErrLevel(false, stmt.IgnoreErr) - errLevels[errctx.ErrGroupBadNull] = errctx.ResolveErrLevel(false, !strictSQLMode || stmt.IgnoreErr) - errLevels[errctx.ErrGroupDividedByZero] = errctx.ResolveErrLevel( - !vars.SQLMode.HasErrorForDivisionByZeroMode(), - !strictSQLMode || stmt.IgnoreErr, - ) - sc.SetErrLevels(errLevels) - sc.Priority = stmt.Priority - sc.SetTypeFlags(sc.TypeFlags(). - WithTruncateAsWarning(!strictSQLMode || stmt.IgnoreErr). - WithIgnoreInvalidDateErr(vars.SQLMode.HasAllowInvalidDatesMode()). - WithIgnoreZeroInDate(!vars.SQLMode.HasNoZeroInDateMode() || !vars.SQLMode.HasNoZeroDateMode() || - !strictSQLMode || stmt.IgnoreErr || vars.SQLMode.HasAllowInvalidDatesMode())) -} - -func setOptionForTopSQL(sc *stmtctx.StatementContext, snapshot kv.Snapshot) { - if snapshot == nil { - return - } - // pipelined dml may already flush in background, don't touch it to avoid race. - if txn, ok := snapshot.(kv.Transaction); ok && txn.IsPipelined() { - return - } - snapshot.SetOption(kv.ResourceGroupTagger, sc.GetResourceGroupTagger()) - if sc.KvExecCounter != nil { - snapshot.SetOption(kv.RPCInterceptor, sc.KvExecCounter.RPCInterceptor()) - } -} - -func isWeakConsistencyRead(ctx sessionctx.Context, node ast.Node) bool { - sessionVars := ctx.GetSessionVars() - return sessionVars.ConnectionID > 0 && sessionVars.ReadConsistency.IsWeak() && - plannercore.IsAutoCommitTxn(sessionVars) && plannercore.IsReadOnly(node, sessionVars) -} - -// FastCheckTableExec represents a check table executor. -// It is built from the "admin check table" statement, and it checks if the -// index matches the records in the table. -// It uses a new algorithms to check table data, which is faster than the old one(CheckTableExec). -type FastCheckTableExec struct { - exec.BaseExecutor - - dbName string - table table.Table - indexInfos []*model.IndexInfo - done bool - is infoschema.InfoSchema - err *atomic.Pointer[error] - wg sync.WaitGroup - contextCtx context.Context -} - -// Open implements the Executor Open interface. 
-func (e *FastCheckTableExec) Open(ctx context.Context) error { - if err := e.BaseExecutor.Open(ctx); err != nil { - return err - } - - e.done = false - e.contextCtx = ctx - return nil -} - -type checkIndexTask struct { - indexOffset int -} - -type checkIndexWorker struct { - sctx sessionctx.Context - dbName string - table table.Table - indexInfos []*model.IndexInfo - e *FastCheckTableExec -} - -type groupByChecksum struct { - bucket uint64 - checksum uint64 - count int64 -} - -func getCheckSum(ctx context.Context, se sessionctx.Context, sql string) ([]groupByChecksum, error) { - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnAdmin) - rs, err := se.GetSQLExecutor().ExecuteInternal(ctx, sql) - if err != nil { - return nil, err - } - defer func(rs sqlexec.RecordSet) { - err := rs.Close() - if err != nil { - logutil.BgLogger().Error("close record set failed", zap.Error(err)) - } - }(rs) - rows, err := sqlexec.DrainRecordSet(ctx, rs, 256) - if err != nil { - return nil, err - } - checksums := make([]groupByChecksum, 0, len(rows)) - for _, row := range rows { - checksums = append(checksums, groupByChecksum{bucket: row.GetUint64(1), checksum: row.GetUint64(0), count: row.GetInt64(2)}) - } - return checksums, nil -} - -func (w *checkIndexWorker) initSessCtx(se sessionctx.Context) (restore func()) { - sessVars := se.GetSessionVars() - originOptUseInvisibleIdx := sessVars.OptimizerUseInvisibleIndexes - originMemQuotaQuery := sessVars.MemQuotaQuery - - sessVars.OptimizerUseInvisibleIndexes = true - sessVars.MemQuotaQuery = w.sctx.GetSessionVars().MemQuotaQuery - return func() { - sessVars.OptimizerUseInvisibleIndexes = originOptUseInvisibleIdx - sessVars.MemQuotaQuery = originMemQuotaQuery - } -} - -// HandleTask implements the Worker interface. -func (w *checkIndexWorker) HandleTask(task checkIndexTask, _ func(workerpool.None)) { - defer w.e.wg.Done() - idxInfo := w.indexInfos[task.indexOffset] - bucketSize := int(CheckTableFastBucketSize.Load()) - - ctx := kv.WithInternalSourceType(w.e.contextCtx, kv.InternalTxnAdmin) - - trySaveErr := func(err error) { - w.e.err.CompareAndSwap(nil, &err) - } - - se, err := w.e.BaseExecutor.GetSysSession() - if err != nil { - trySaveErr(err) - return - } - restoreCtx := w.initSessCtx(se) - defer func() { - restoreCtx() - w.e.BaseExecutor.ReleaseSysSession(ctx, se) - }() - - var pkCols []string - var pkTypes []*types.FieldType - switch { - case w.e.table.Meta().IsCommonHandle: - pkColsInfo := w.e.table.Meta().GetPrimaryKey().Columns - for _, colInfo := range pkColsInfo { - colStr := colInfo.Name.O - pkCols = append(pkCols, colStr) - pkTypes = append(pkTypes, &w.e.table.Meta().Columns[colInfo.Offset].FieldType) - } - case w.e.table.Meta().PKIsHandle: - pkCols = append(pkCols, w.e.table.Meta().GetPkName().O) - default: // support decoding _tidb_rowid. - pkCols = append(pkCols, model.ExtraHandleName.O) - } - - // CheckSum of (handle + index columns). 
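initSessCtx above snapshots the session variables it is about to change and returns a closure that restores them, which the caller defers alongside releasing the session. The same save-and-restore shape on a toy struct (illustrative fields, not TiDB session variables):

package main

import "fmt"

type sessVars struct {
    useInvisibleIndexes bool
    memQuota            int64
}

// override tweaks the variables and hands back a closure that undoes the
// change, mirroring the restore func returned by initSessCtx.
func override(v *sessVars, quota int64) (restore func()) {
    origInvisible, origQuota := v.useInvisibleIndexes, v.memQuota
    v.useInvisibleIndexes = true
    v.memQuota = quota
    return func() {
        v.useInvisibleIndexes = origInvisible
        v.memQuota = origQuota
    }
}

func main() {
    v := &sessVars{memQuota: 1 << 30}
    restore := override(v, 4<<30)
    fmt.Printf("%+v\n", *v)
    restore()
    fmt.Printf("%+v\n", *v)
}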
- var md5HandleAndIndexCol strings.Builder - md5HandleAndIndexCol.WriteString("crc32(md5(concat_ws(0x2, ") - for _, col := range pkCols { - md5HandleAndIndexCol.WriteString(ColumnName(col)) - md5HandleAndIndexCol.WriteString(", ") - } - for offset, col := range idxInfo.Columns { - tblCol := w.table.Meta().Columns[col.Offset] - if tblCol.IsGenerated() && !tblCol.GeneratedStored { - md5HandleAndIndexCol.WriteString(tblCol.GeneratedExprString) - } else { - md5HandleAndIndexCol.WriteString(ColumnName(col.Name.O)) - } - if offset != len(idxInfo.Columns)-1 { - md5HandleAndIndexCol.WriteString(", ") - } - } - md5HandleAndIndexCol.WriteString(")))") - - // Used to group by and order. - var md5Handle strings.Builder - md5Handle.WriteString("crc32(md5(concat_ws(0x2, ") - for i, col := range pkCols { - md5Handle.WriteString(ColumnName(col)) - if i != len(pkCols)-1 { - md5Handle.WriteString(", ") - } - } - md5Handle.WriteString(")))") - - handleColumnField := strings.Join(pkCols, ", ") - var indexColumnField strings.Builder - for offset, col := range idxInfo.Columns { - indexColumnField.WriteString(ColumnName(col.Name.O)) - if offset != len(idxInfo.Columns)-1 { - indexColumnField.WriteString(", ") - } - } - - tableRowCntToCheck := int64(0) - - offset := 0 - mod := 1 - meetError := false - - lookupCheckThreshold := int64(100) - checkOnce := false - - if w.e.Ctx().GetSessionVars().SnapshotTS != 0 { - se.GetSessionVars().SnapshotTS = w.e.Ctx().GetSessionVars().SnapshotTS - defer func() { - se.GetSessionVars().SnapshotTS = 0 - }() - } - _, err = se.GetSQLExecutor().ExecuteInternal(ctx, "begin") - if err != nil { - trySaveErr(err) - return - } - - times := 0 - const maxTimes = 10 - for tableRowCntToCheck > lookupCheckThreshold || !checkOnce { - times++ - if times == maxTimes { - logutil.BgLogger().Warn("compare checksum by group reaches time limit", zap.Int("times", times)) - break - } - whereKey := fmt.Sprintf("((cast(%s as signed) - %d) %% %d)", md5Handle.String(), offset, mod) - groupByKey := fmt.Sprintf("((cast(%s as signed) - %d) div %d %% %d)", md5Handle.String(), offset, mod, bucketSize) - if !checkOnce { - whereKey = "0" - } - checkOnce = true - - tblQuery := fmt.Sprintf("select /*+ read_from_storage(tikv[%s]) */ bit_xor(%s), %s, count(*) from %s use index() where %s = 0 group by %s", TableName(w.e.dbName, w.e.table.Meta().Name.String()), md5HandleAndIndexCol.String(), groupByKey, TableName(w.e.dbName, w.e.table.Meta().Name.String()), whereKey, groupByKey) - idxQuery := fmt.Sprintf("select bit_xor(%s), %s, count(*) from %s use index(`%s`) where %s = 0 group by %s", md5HandleAndIndexCol.String(), groupByKey, TableName(w.e.dbName, w.e.table.Meta().Name.String()), idxInfo.Name, whereKey, groupByKey) - - logutil.BgLogger().Info("fast check table by group", zap.String("table name", w.table.Meta().Name.String()), zap.String("index name", idxInfo.Name.String()), zap.Int("times", times), zap.Int("current offset", offset), zap.Int("current mod", mod), zap.String("table sql", tblQuery), zap.String("index sql", idxQuery)) - - // compute table side checksum. - tableChecksum, err := getCheckSum(w.e.contextCtx, se, tblQuery) - if err != nil { - trySaveErr(err) - return - } - slices.SortFunc(tableChecksum, func(i, j groupByChecksum) int { - return cmp.Compare(i.bucket, j.bucket) - }) - - // compute index side checksum. 
- indexChecksum, err := getCheckSum(w.e.contextCtx, se, idxQuery) - if err != nil { - trySaveErr(err) - return - } - slices.SortFunc(indexChecksum, func(i, j groupByChecksum) int { - return cmp.Compare(i.bucket, j.bucket) - }) - - currentOffset := 0 - - // Every checksum in table side should be the same as the index side. - i := 0 - for i < len(tableChecksum) && i < len(indexChecksum) { - if tableChecksum[i].bucket != indexChecksum[i].bucket || tableChecksum[i].checksum != indexChecksum[i].checksum { - if tableChecksum[i].bucket <= indexChecksum[i].bucket { - currentOffset = int(tableChecksum[i].bucket) - tableRowCntToCheck = tableChecksum[i].count - } else { - currentOffset = int(indexChecksum[i].bucket) - tableRowCntToCheck = indexChecksum[i].count - } - meetError = true - break - } - i++ - } - - if !meetError && i < len(indexChecksum) && i == len(tableChecksum) { - // Table side has fewer buckets. - currentOffset = int(indexChecksum[i].bucket) - tableRowCntToCheck = indexChecksum[i].count - meetError = true - } else if !meetError && i < len(tableChecksum) && i == len(indexChecksum) { - // Index side has fewer buckets. - currentOffset = int(tableChecksum[i].bucket) - tableRowCntToCheck = tableChecksum[i].count - meetError = true - } - - if !meetError { - if times != 1 { - logutil.BgLogger().Error("unexpected result, no error detected in this round, but an error is detected in the previous round", zap.Int("times", times), zap.Int("offset", offset), zap.Int("mod", mod)) - } - break - } - - offset += currentOffset * mod - mod *= bucketSize - } - - queryToRow := func(se sessionctx.Context, sql string) ([]chunk.Row, error) { - rs, err := se.GetSQLExecutor().ExecuteInternal(ctx, sql) - if err != nil { - return nil, err - } - row, err := sqlexec.DrainRecordSet(ctx, rs, 4096) - if err != nil { - return nil, err - } - err = rs.Close() - if err != nil { - logutil.BgLogger().Warn("close result set failed", zap.Error(err)) - } - return row, nil - } - - if meetError { - groupByKey := fmt.Sprintf("((cast(%s as signed) - %d) %% %d)", md5Handle.String(), offset, mod) - indexSQL := fmt.Sprintf("select %s, %s, %s from %s use index(`%s`) where %s = 0 order by %s", handleColumnField, indexColumnField.String(), md5HandleAndIndexCol.String(), TableName(w.e.dbName, w.e.table.Meta().Name.String()), idxInfo.Name, groupByKey, handleColumnField) - tableSQL := fmt.Sprintf("select /*+ read_from_storage(tikv[%s]) */ %s, %s, %s from %s use index() where %s = 0 order by %s", TableName(w.e.dbName, w.e.table.Meta().Name.String()), handleColumnField, indexColumnField.String(), md5HandleAndIndexCol.String(), TableName(w.e.dbName, w.e.table.Meta().Name.String()), groupByKey, handleColumnField) - - idxRow, err := queryToRow(se, indexSQL) - if err != nil { - trySaveErr(err) - return - } - tblRow, err := queryToRow(se, tableSQL) - if err != nil { - trySaveErr(err) - return - } - - errCtx := w.sctx.GetSessionVars().StmtCtx.ErrCtx() - getHandleFromRow := func(row chunk.Row) (kv.Handle, error) { - handleDatum := make([]types.Datum, 0) - for i, t := range pkTypes { - handleDatum = append(handleDatum, row.GetDatum(i, t)) - } - if w.table.Meta().IsCommonHandle { - handleBytes, err := codec.EncodeKey(w.sctx.GetSessionVars().StmtCtx.TimeZone(), nil, handleDatum...) 
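The comparison loop above walks the table-side and index-side bucket lists in lockstep (both sorted by bucket id); the first bucket whose checksum or presence differs becomes the region the next round drills into with a larger modulus. A compact sketch of that narrowing step with an illustrative struct:

package main

import "fmt"

type bucketChecksum struct {
    bucket   uint64
    checksum uint64
    count    int64
}

// firstMismatch reports the first bucket where the two sides diverge, covering
// both a checksum mismatch and one side having extra buckets, in the same
// spirit as the loop above.
func firstMismatch(table, index []bucketChecksum) (uint64, bool) {
    i := 0
    for i < len(table) && i < len(index) {
        if table[i].bucket != index[i].bucket || table[i].checksum != index[i].checksum {
            if table[i].bucket <= index[i].bucket {
                return table[i].bucket, true
            }
            return index[i].bucket, true
        }
        i++
    }
    if i < len(index) { // index side has extra buckets
        return index[i].bucket, true
    }
    if i < len(table) { // table side has extra buckets
        return table[i].bucket, true
    }
    return 0, false
}

func main() {
    tbl := []bucketChecksum{{0, 11, 3}, {1, 22, 4}}
    idx := []bucketChecksum{{0, 11, 3}, {1, 99, 4}}
    fmt.Println(firstMismatch(tbl, idx)) // 1 true
}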
- err = errCtx.HandleError(err) - if err != nil { - return nil, err - } - return kv.NewCommonHandle(handleBytes) - } - return kv.IntHandle(row.GetInt64(0)), nil - } - getValueFromRow := func(row chunk.Row) ([]types.Datum, error) { - valueDatum := make([]types.Datum, 0) - for i, t := range idxInfo.Columns { - valueDatum = append(valueDatum, row.GetDatum(i+len(pkCols), &w.table.Meta().Columns[t.Offset].FieldType)) - } - return valueDatum, nil - } - - ir := func() *consistency.Reporter { - return &consistency.Reporter{ - HandleEncode: func(handle kv.Handle) kv.Key { - return tablecodec.EncodeRecordKey(w.table.RecordPrefix(), handle) - }, - IndexEncode: func(idxRow *consistency.RecordData) kv.Key { - var idx table.Index - for _, v := range w.table.Indices() { - if strings.EqualFold(v.Meta().Name.String(), idxInfo.Name.O) { - idx = v - break - } - } - if idx == nil { - return nil - } - sc := w.sctx.GetSessionVars().StmtCtx - k, _, err := idx.GenIndexKey(sc.ErrCtx(), sc.TimeZone(), idxRow.Values[:len(idx.Meta().Columns)], idxRow.Handle, nil) - if err != nil { - return nil - } - return k - }, - Tbl: w.table.Meta(), - Idx: idxInfo, - EnableRedactLog: w.sctx.GetSessionVars().EnableRedactLog, - Storage: w.sctx.GetStore(), - } - } - - getCheckSum := func(row chunk.Row) uint64 { - return row.GetUint64(len(pkCols) + len(idxInfo.Columns)) - } - - var handle kv.Handle - var tableRecord *consistency.RecordData - var lastTableRecord *consistency.RecordData - var indexRecord *consistency.RecordData - i := 0 - for i < len(tblRow) || i < len(idxRow) { - if i == len(tblRow) { - // No more rows in table side. - tableRecord = nil - } else { - handle, err = getHandleFromRow(tblRow[i]) - if err != nil { - trySaveErr(err) - return - } - value, err := getValueFromRow(tblRow[i]) - if err != nil { - trySaveErr(err) - return - } - tableRecord = &consistency.RecordData{Handle: handle, Values: value} - } - if i == len(idxRow) { - // No more rows in index side. 
- indexRecord = nil - } else { - indexHandle, err := getHandleFromRow(idxRow[i]) - if err != nil { - trySaveErr(err) - return - } - indexValue, err := getValueFromRow(idxRow[i]) - if err != nil { - trySaveErr(err) - return - } - indexRecord = &consistency.RecordData{Handle: indexHandle, Values: indexValue} - } - - if tableRecord == nil { - if lastTableRecord != nil && lastTableRecord.Handle.Equal(indexRecord.Handle) { - tableRecord = lastTableRecord - } - err = ir().ReportAdminCheckInconsistent(w.e.contextCtx, indexRecord.Handle, indexRecord, tableRecord) - } else if indexRecord == nil { - err = ir().ReportAdminCheckInconsistent(w.e.contextCtx, tableRecord.Handle, indexRecord, tableRecord) - } else if tableRecord.Handle.Equal(indexRecord.Handle) && getCheckSum(tblRow[i]) != getCheckSum(idxRow[i]) { - err = ir().ReportAdminCheckInconsistent(w.e.contextCtx, tableRecord.Handle, indexRecord, tableRecord) - } else if !tableRecord.Handle.Equal(indexRecord.Handle) { - if tableRecord.Handle.Compare(indexRecord.Handle) < 0 { - err = ir().ReportAdminCheckInconsistent(w.e.contextCtx, tableRecord.Handle, nil, tableRecord) - } else { - if lastTableRecord != nil && lastTableRecord.Handle.Equal(indexRecord.Handle) { - err = ir().ReportAdminCheckInconsistent(w.e.contextCtx, indexRecord.Handle, indexRecord, lastTableRecord) - } else { - err = ir().ReportAdminCheckInconsistent(w.e.contextCtx, indexRecord.Handle, indexRecord, nil) - } - } - } - if err != nil { - trySaveErr(err) - return - } - i++ - if tableRecord != nil { - lastTableRecord = &consistency.RecordData{Handle: tableRecord.Handle, Values: tableRecord.Values} - } else { - lastTableRecord = nil - } - } - } -} - -// Close implements the Worker interface. -func (*checkIndexWorker) Close() {} - -func (e *FastCheckTableExec) createWorker() workerpool.Worker[checkIndexTask, workerpool.None] { - return &checkIndexWorker{sctx: e.Ctx(), dbName: e.dbName, table: e.table, indexInfos: e.indexInfos, e: e} -} - -// Next implements the Executor Next interface. -func (e *FastCheckTableExec) Next(ctx context.Context, _ *chunk.Chunk) error { - if e.done || len(e.indexInfos) == 0 { - return nil - } - defer func() { e.done = true }() - - // Here we need check all indexes, includes invisible index - e.Ctx().GetSessionVars().OptimizerUseInvisibleIndexes = true - defer func() { - e.Ctx().GetSessionVars().OptimizerUseInvisibleIndexes = false - }() - - workerPool := workerpool.NewWorkerPool[checkIndexTask]("checkIndex", - poolutil.CheckTable, 3, e.createWorker) - workerPool.Start(ctx) - - e.wg.Add(len(e.indexInfos)) - for i := range e.indexInfos { - workerPool.AddTask(checkIndexTask{indexOffset: i}) - } - - e.wg.Wait() - workerPool.ReleaseAndWait() - - p := e.err.Load() - if p == nil { - return nil - } - return *p -} - -// TableName returns `schema`.`table` -func TableName(schema, table string) string { - return fmt.Sprintf("`%s`.`%s`", escapeName(schema), escapeName(table)) -} - -// ColumnName returns `column` -func ColumnName(column string) string { - return fmt.Sprintf("`%s`", escapeName(column)) -} - -func escapeName(name string) string { - return strings.ReplaceAll(name, "`", "``") -} - -// AdminShowBDRRoleExec represents a show BDR role executor. -type AdminShowBDRRoleExec struct { - exec.BaseExecutor - - done bool -} - -// Next implements the Executor Next interface. 
-func (e *AdminShowBDRRoleExec) Next(ctx context.Context, req *chunk.Chunk) error { - req.Reset() - if e.done { - return nil - } - - return kv.RunInNewTxn(kv.WithInternalSourceType(ctx, kv.InternalTxnAdmin), e.Ctx().GetStore(), true, func(_ context.Context, txn kv.Transaction) error { - role, err := meta.NewMeta(txn).GetBDRRole() - if err != nil { - return err - } - - req.AppendString(0, role) - e.done = true - return nil - }) -} diff --git a/pkg/executor/import_into.go b/pkg/executor/import_into.go index 6e2a979c6f2fe..7d9a4f92efd95 100644 --- a/pkg/executor/import_into.go +++ b/pkg/executor/import_into.go @@ -119,7 +119,7 @@ func (e *ImportIntoExec) Next(ctx context.Context, req *chunk.Chunk) (err error) return err } - failpoint.Call(_curpkg_("cancellableCtx"), &ctx) + failpoint.InjectCall("cancellableCtx", &ctx) jobID, task, err := e.submitTask(ctx) if err != nil { diff --git a/pkg/executor/import_into.go__failpoint_stash__ b/pkg/executor/import_into.go__failpoint_stash__ deleted file mode 100644 index 7d9a4f92efd95..0000000000000 --- a/pkg/executor/import_into.go__failpoint_stash__ +++ /dev/null @@ -1,344 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package executor - -import ( - "context" - "fmt" - - "github.com/google/uuid" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/pkg/disttask/framework/handle" - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - fstorage "github.com/pingcap/tidb/pkg/disttask/framework/storage" - "github.com/pingcap/tidb/pkg/disttask/importinto" - "github.com/pingcap/tidb/pkg/executor/importer" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/mysql" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/privilege" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/tikv/client-go/v2/util" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" -) - -const unknownImportedRowCount = -1 - -// ImportIntoExec represents a IMPORT INTO executor. 
-type ImportIntoExec struct { - exec.BaseExecutor - selectExec exec.Executor - userSctx sessionctx.Context - controller *importer.LoadDataController - stmt string - - plan *plannercore.ImportInto - tbl table.Table - dataFilled bool -} - -var ( - _ exec.Executor = (*ImportIntoExec)(nil) -) - -func newImportIntoExec(b exec.BaseExecutor, selectExec exec.Executor, userSctx sessionctx.Context, - plan *plannercore.ImportInto, tbl table.Table) (*ImportIntoExec, error) { - return &ImportIntoExec{ - BaseExecutor: b, - selectExec: selectExec, - userSctx: userSctx, - stmt: plan.Stmt, - plan: plan, - tbl: tbl, - }, nil -} - -// Next implements the Executor Next interface. -func (e *ImportIntoExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { - req.GrowAndReset(e.MaxChunkSize()) - ctx = kv.WithInternalSourceType(ctx, kv.InternalImportInto) - if e.dataFilled { - // need to return an empty req to indicate all results have been written - return nil - } - importPlan, err := importer.NewImportPlan(ctx, e.userSctx, e.plan, e.tbl) - if err != nil { - return err - } - astArgs := importer.ASTArgsFromImportPlan(e.plan) - controller, err := importer.NewLoadDataController(importPlan, e.tbl, astArgs) - if err != nil { - return err - } - e.controller = controller - - if e.selectExec != nil { - // `import from select` doesn't return rows, so no need to set dataFilled. - return e.importFromSelect(ctx) - } - - if err2 := e.controller.InitDataFiles(ctx); err2 != nil { - return err2 - } - - // must use a new session to pre-check, else the stmt in show processlist will be changed. - newSCtx, err2 := CreateSession(e.userSctx) - if err2 != nil { - return err2 - } - defer CloseSession(newSCtx) - sqlExec := newSCtx.GetSQLExecutor() - if err2 = e.controller.CheckRequirements(ctx, sqlExec); err2 != nil { - return err2 - } - - if err := e.controller.InitTiKVConfigs(ctx, newSCtx); err != nil { - return err - } - - failpoint.InjectCall("cancellableCtx", &ctx) - - jobID, task, err := e.submitTask(ctx) - if err != nil { - return err - } - - if !e.controller.Detached { - if err = e.waitTask(ctx, jobID, task); err != nil { - return err - } - } - return e.fillJobInfo(ctx, jobID, req) -} - -func (e *ImportIntoExec) fillJobInfo(ctx context.Context, jobID int64, req *chunk.Chunk) error { - e.dataFilled = true - // we use taskManager to get job, user might not have the privilege to system tables. - taskManager, err := fstorage.GetTaskManager() - ctx = util.WithInternalSourceType(ctx, kv.InternalDistTask) - if err != nil { - return err - } - var info *importer.JobInfo - if err = taskManager.WithNewSession(func(se sessionctx.Context) error { - sqlExec := se.GetSQLExecutor() - var err2 error - info, err2 = importer.GetJob(ctx, sqlExec, jobID, e.Ctx().GetSessionVars().User.String(), false) - return err2 - }); err != nil { - return err - } - fillOneImportJobInfo(info, req, unknownImportedRowCount) - return nil -} - -func (e *ImportIntoExec) submitTask(ctx context.Context) (int64, *proto.TaskBase, error) { - importFromServer, err := storage.IsLocalPath(e.controller.Path) - if err != nil { - // since we have checked this during creating controller, this should not happen. 
- return 0, nil, exeerrors.ErrLoadDataInvalidURI.FastGenByArgs(plannercore.ImportIntoDataSource, err.Error()) - } - logutil.Logger(ctx).Info("get job importer", zap.Stringer("param", e.controller.Parameters), - zap.Bool("dist-task-enabled", variable.EnableDistTask.Load())) - if importFromServer { - ecp, err2 := e.controller.PopulateChunks(ctx) - if err2 != nil { - return 0, nil, err2 - } - return importinto.SubmitStandaloneTask(ctx, e.controller.Plan, e.stmt, ecp) - } - // if tidb_enable_dist_task=true, we import distributively, otherwise we import on current node. - if variable.EnableDistTask.Load() { - return importinto.SubmitTask(ctx, e.controller.Plan, e.stmt) - } - return importinto.SubmitStandaloneTask(ctx, e.controller.Plan, e.stmt, nil) -} - -// waitTask waits for the task to finish. -// NOTE: WaitTaskDoneOrPaused also return error when task fails. -func (*ImportIntoExec) waitTask(ctx context.Context, jobID int64, task *proto.TaskBase) error { - err := handle.WaitTaskDoneOrPaused(ctx, task.ID) - // when user KILL the connection, the ctx will be canceled, we need to cancel the import job. - if errors.Cause(err) == context.Canceled { - taskManager, err2 := fstorage.GetTaskManager() - if err2 != nil { - return err2 - } - // use background, since ctx is canceled already. - return cancelAndWaitImportJob(context.Background(), taskManager, jobID) - } - return err -} - -func (e *ImportIntoExec) importFromSelect(ctx context.Context) error { - e.dataFilled = true - // must use a new session as: - // - pre-check will execute other sql, the stmt in show processlist will be changed. - // - userSctx might be in stale read, we cannot do write. - newSCtx, err2 := CreateSession(e.userSctx) - if err2 != nil { - return err2 - } - defer CloseSession(newSCtx) - - sqlExec := newSCtx.GetSQLExecutor() - if err2 = e.controller.CheckRequirements(ctx, sqlExec); err2 != nil { - return err2 - } - if err := e.controller.InitTiKVConfigs(ctx, newSCtx); err != nil { - return err - } - - importID := uuid.New().String() - logutil.Logger(ctx).Info("importing data from select statement", - zap.String("import-id", importID), zap.Int("concurrency", e.controller.ThreadCnt), - zap.String("target-table", e.controller.FullTableName()), - zap.Int64("target-table-id", e.controller.TableInfo.ID)) - ti, err2 := importer.NewTableImporter(ctx, e.controller, importID, e.Ctx().GetStore()) - if err2 != nil { - return err2 - } - defer func() { - if err := ti.Close(); err != nil { - logutil.Logger(ctx).Error("close importer failed", zap.Error(err)) - } - }() - selectedRowCh := make(chan importer.QueryRow) - ti.SetSelectedRowCh(selectedRowCh) - - var importResult *importer.JobImportResult - eg, egCtx := errgroup.WithContext(ctx) - eg.Go(func() error { - var err error - importResult, err = ti.ImportSelectedRows(egCtx, newSCtx) - return err - }) - eg.Go(func() error { - defer close(selectedRowCh) - fields := exec.RetTypes(e.selectExec) - var idAllocator int64 - for { - // rows will be consumed concurrently, we cannot use chunk pool in session ctx. 
- chk := exec.NewFirstChunk(e.selectExec) - iter := chunk.NewIterator4Chunk(chk) - err := exec.Next(egCtx, e.selectExec, chk) - if err != nil { - return err - } - if chk.NumRows() == 0 { - break - } - for innerChunkRow := iter.Begin(); innerChunkRow != iter.End(); innerChunkRow = iter.Next() { - idAllocator++ - select { - case selectedRowCh <- importer.QueryRow{ - ID: idAllocator, - Data: innerChunkRow.GetDatumRow(fields), - }: - case <-egCtx.Done(): - return egCtx.Err() - } - } - } - return nil - }) - if err := eg.Wait(); err != nil { - return err - } - - if err2 = importer.FlushTableStats(ctx, newSCtx, e.controller.TableInfo.ID, importResult); err2 != nil { - logutil.Logger(ctx).Error("flush stats failed", zap.Error(err2)) - } - - stmtCtx := e.userSctx.GetSessionVars().StmtCtx - stmtCtx.SetAffectedRows(importResult.Affected) - // TODO: change it after spec is ready. - stmtCtx.SetMessage(fmt.Sprintf("Records: %d, ID: %s", importResult.Affected, importID)) - return nil -} - -// ImportIntoActionExec represents a import into action executor. -type ImportIntoActionExec struct { - exec.BaseExecutor - tp ast.ImportIntoActionTp - jobID int64 -} - -var ( - _ exec.Executor = (*ImportIntoActionExec)(nil) -) - -// Next implements the Executor Next interface. -func (e *ImportIntoActionExec) Next(ctx context.Context, _ *chunk.Chunk) (err error) { - ctx = kv.WithInternalSourceType(ctx, kv.InternalImportInto) - - var hasSuperPriv bool - if pm := privilege.GetPrivilegeManager(e.Ctx()); pm != nil { - hasSuperPriv = pm.RequestVerification(e.Ctx().GetSessionVars().ActiveRoles, "", "", "", mysql.SuperPriv) - } - // we use sessionCtx from GetTaskManager, user ctx might not have enough privileges. - taskManager, err := fstorage.GetTaskManager() - ctx = util.WithInternalSourceType(ctx, kv.InternalDistTask) - if err != nil { - return err - } - if err = e.checkPrivilegeAndStatus(ctx, taskManager, hasSuperPriv); err != nil { - return err - } - - task := log.BeginTask(logutil.Logger(ctx).With(zap.Int64("jobID", e.jobID), - zap.Any("action", e.tp)), "import into action") - defer func() { - task.End(zap.ErrorLevel, err) - }() - return cancelAndWaitImportJob(ctx, taskManager, e.jobID) -} - -func (e *ImportIntoActionExec) checkPrivilegeAndStatus(ctx context.Context, manager *fstorage.TaskManager, hasSuperPriv bool) error { - var info *importer.JobInfo - if err := manager.WithNewSession(func(se sessionctx.Context) error { - exec := se.GetSQLExecutor() - var err2 error - info, err2 = importer.GetJob(ctx, exec, e.jobID, e.Ctx().GetSessionVars().User.String(), hasSuperPriv) - return err2 - }); err != nil { - return err - } - if !info.CanCancel() { - return exeerrors.ErrLoadDataInvalidOperation.FastGenByArgs("CANCEL") - } - return nil -} - -func cancelAndWaitImportJob(ctx context.Context, manager *fstorage.TaskManager, jobID int64) error { - if err := manager.WithNewTxn(ctx, func(se sessionctx.Context) error { - ctx = util.WithInternalSourceType(ctx, kv.InternalDistTask) - return manager.CancelTaskByKeySession(ctx, se, importinto.TaskKey(jobID)) - }); err != nil { - return err - } - return handle.WaitTaskDoneByKey(ctx, importinto.TaskKey(jobID)) -} diff --git a/pkg/executor/importer/binding__failpoint_binding__.go b/pkg/executor/importer/binding__failpoint_binding__.go deleted file mode 100644 index 62be625525776..0000000000000 --- a/pkg/executor/importer/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package importer - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var 
__failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/executor/importer/job.go b/pkg/executor/importer/job.go index b79ca87fb337f..dc0dec7b20d90 100644 --- a/pkg/executor/importer/job.go +++ b/pkg/executor/importer/job.go @@ -221,9 +221,9 @@ func CreateJob( return 0, errors.Errorf("unexpected result length: %d", len(rows)) } - if _, _err_ := failpoint.Eval(_curpkg_("setLastImportJobID")); _err_ == nil { + failpoint.Inject("setLastImportJobID", func() { TestLastImportJobID.Store(rows[0].GetInt64(0)) - } + }) return rows[0].GetInt64(0), nil } diff --git a/pkg/executor/importer/job.go__failpoint_stash__ b/pkg/executor/importer/job.go__failpoint_stash__ deleted file mode 100644 index dc0dec7b20d90..0000000000000 --- a/pkg/executor/importer/job.go__failpoint_stash__ +++ /dev/null @@ -1,370 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package importer - -import ( - "context" - "encoding/json" - "fmt" - "sync/atomic" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" - "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" - "github.com/pingcap/tidb/pkg/util/sqlexec" - "github.com/tikv/client-go/v2/util" -) - -// vars used for test. -var ( - // TestLastImportJobID last created job id, used in unit test. - TestLastImportJobID atomic.Int64 -) - -// constants for job status and step. -const ( - // JobStatus - // ┌───────┐ ┌───────┐ ┌────────┐ - // │pending├────►│running├───►│finished│ - // └────┬──┘ └────┬──┘ └────────┘ - // │ │ ┌──────┐ - // │ ├──────►│failed│ - // │ │ └──────┘ - // │ │ ┌─────────┐ - // └─────────────┴──────►│cancelled│ - // └─────────┘ - jobStatusPending = "pending" - // JobStatusRunning exported since it's used in show import jobs - JobStatusRunning = "running" - jogStatusCancelled = "cancelled" - jobStatusFailed = "failed" - jobStatusFinished = "finished" - - // when the job is finished, step will be set to none. - jobStepNone = "" - // JobStepGlobalSorting is the first step when using global sort, - // step goes from none -> global-sorting -> importing -> validating -> none. - JobStepGlobalSorting = "global-sorting" - // JobStepImporting is the first step when using local sort, - // step goes from none -> importing -> validating -> none. - // when used in global sort, it means importing the sorted data. - // when used in local sort, it means encode&sort data and then importing the data. 
- JobStepImporting = "importing" - JobStepValidating = "validating" - - baseQuerySQL = `SELECT - id, create_time, start_time, end_time, - table_schema, table_name, table_id, created_by, parameters, source_file_size, - status, step, summary, error_message - FROM mysql.tidb_import_jobs` -) - -// ImportParameters is the parameters for import into statement. -// it's a minimal meta info to store in tidb_import_jobs for diagnose. -// for detailed info, see tidb_global_tasks. -type ImportParameters struct { - ColumnsAndVars string `json:"columns-and-vars,omitempty"` - SetClause string `json:"set-clause,omitempty"` - // for s3 URL, AK/SK is redacted for security - FileLocation string `json:"file-location"` - Format string `json:"format"` - // only include what user specified, not include default value. - Options map[string]any `json:"options,omitempty"` -} - -var _ fmt.Stringer = &ImportParameters{} - -// String implements fmt.Stringer interface. -func (ip *ImportParameters) String() string { - b, _ := json.Marshal(ip) - return string(b) -} - -// JobSummary is the summary info of import into job. -type JobSummary struct { - // ImportedRows is the number of rows imported into TiKV. - ImportedRows uint64 `json:"imported-rows,omitempty"` -} - -// JobInfo is the information of import into job. -type JobInfo struct { - ID int64 - CreateTime types.Time - StartTime types.Time - EndTime types.Time - TableSchema string - TableName string - TableID int64 - CreatedBy string - Parameters ImportParameters - SourceFileSize int64 - Status string - // in SHOW IMPORT JOB, we name it as phase. - // here, we use the same name as in distributed framework. - Step string - // the summary info of the job, it's updated only when the job is finished. - // for running job, we should query the progress from the distributed framework. - Summary *JobSummary - ErrorMessage string -} - -// CanCancel returns whether the job can be cancelled. -func (j *JobInfo) CanCancel() bool { - return j.Status == jobStatusPending || j.Status == JobStatusRunning -} - -// GetJob returns the job with the given id if the user has privilege. -// hasSuperPriv: whether the user has super privilege. -// If the user has super privilege, the user can show or operate all jobs, -// else the user can only show or operate his own jobs. -func GetJob(ctx context.Context, conn sqlexec.SQLExecutor, jobID int64, user string, hasSuperPriv bool) (*JobInfo, error) { - ctx = util.WithInternalSourceType(ctx, kv.InternalImportInto) - - sql := baseQuerySQL + ` WHERE id = %?` - rs, err := conn.ExecuteInternal(ctx, sql, jobID) - if err != nil { - return nil, err - } - defer terror.Call(rs.Close) - rows, err := sqlexec.DrainRecordSet(ctx, rs, 1) - if err != nil { - return nil, err - } - if len(rows) != 1 { - return nil, exeerrors.ErrLoadDataJobNotFound.GenWithStackByArgs(jobID) - } - - info, err := convert2JobInfo(rows[0]) - if err != nil { - return nil, err - } - if !hasSuperPriv && info.CreatedBy != user { - return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("SUPER") - } - return info, nil -} - -// GetActiveJobCnt returns the count of active import jobs. -// Active import jobs include pending and running jobs. -func GetActiveJobCnt(ctx context.Context, conn sqlexec.SQLExecutor, tableSchema, tableName string) (int64, error) { - ctx = util.WithInternalSourceType(ctx, kv.InternalImportInto) - - sql := `select count(1) from mysql.tidb_import_jobs - where status in (%?, %?) - and table_schema = %? 
and table_name = %?; - ` - rs, err := conn.ExecuteInternal(ctx, sql, jobStatusPending, JobStatusRunning, - tableSchema, tableName) - if err != nil { - return 0, err - } - defer terror.Call(rs.Close) - rows, err := sqlexec.DrainRecordSet(ctx, rs, 1) - if err != nil { - return 0, err - } - cnt := rows[0].GetInt64(0) - return cnt, nil -} - -// CreateJob creates import into job by insert a record to system table. -// The AUTO_INCREMENT value will be returned as jobID. -func CreateJob( - ctx context.Context, - conn sqlexec.SQLExecutor, - db, table string, - tableID int64, - user string, - parameters *ImportParameters, - sourceFileSize int64, -) (int64, error) { - bytes, err := json.Marshal(parameters) - if err != nil { - return 0, err - } - ctx = util.WithInternalSourceType(ctx, kv.InternalImportInto) - _, err = conn.ExecuteInternal(ctx, `INSERT INTO mysql.tidb_import_jobs - (table_schema, table_name, table_id, created_by, parameters, source_file_size, status, step) - VALUES (%?, %?, %?, %?, %?, %?, %?, %?);`, - db, table, tableID, user, bytes, sourceFileSize, jobStatusPending, jobStepNone) - if err != nil { - return 0, err - } - rs, err := conn.ExecuteInternal(ctx, `SELECT LAST_INSERT_ID();`) - if err != nil { - return 0, err - } - defer terror.Call(rs.Close) - - rows, err := sqlexec.DrainRecordSet(ctx, rs, 1) - if err != nil { - return 0, err - } - if len(rows) != 1 { - return 0, errors.Errorf("unexpected result length: %d", len(rows)) - } - - failpoint.Inject("setLastImportJobID", func() { - TestLastImportJobID.Store(rows[0].GetInt64(0)) - }) - return rows[0].GetInt64(0), nil -} - -// StartJob tries to start a pending job with jobID, change its status/step to running/input step. -// It will not return error when there's no matched job or the job has already started. -func StartJob(ctx context.Context, conn sqlexec.SQLExecutor, jobID int64, step string) error { - ctx = util.WithInternalSourceType(ctx, kv.InternalImportInto) - _, err := conn.ExecuteInternal(ctx, `UPDATE mysql.tidb_import_jobs - SET update_time = CURRENT_TIMESTAMP(6), start_time = CURRENT_TIMESTAMP(6), status = %?, step = %? - WHERE id = %? AND status = %?;`, - JobStatusRunning, step, jobID, jobStatusPending) - - return err -} - -// Job2Step tries to change the step of a running job with jobID. -// It will not return error when there's no matched job. -func Job2Step(ctx context.Context, conn sqlexec.SQLExecutor, jobID int64, step string) error { - ctx = util.WithInternalSourceType(ctx, kv.InternalImportInto) - _, err := conn.ExecuteInternal(ctx, `UPDATE mysql.tidb_import_jobs - SET update_time = CURRENT_TIMESTAMP(6), step = %? - WHERE id = %? AND status = %?;`, - step, jobID, JobStatusRunning) - - return err -} - -// FinishJob tries to finish a running job with jobID, change its status to finished, clear its step. -// It will not return error when there's no matched job. -func FinishJob(ctx context.Context, conn sqlexec.SQLExecutor, jobID int64, summary *JobSummary) error { - bytes, err := json.Marshal(summary) - if err != nil { - return err - } - ctx = util.WithInternalSourceType(ctx, kv.InternalImportInto) - _, err = conn.ExecuteInternal(ctx, `UPDATE mysql.tidb_import_jobs - SET update_time = CURRENT_TIMESTAMP(6), end_time = CURRENT_TIMESTAMP(6), status = %?, step = %?, summary = %? - WHERE id = %? AND status = %?;`, - jobStatusFinished, jobStepNone, bytes, jobID, JobStatusRunning) - return err -} - -// FailJob fails import into job. A job can only be failed once. -// It will not return error when there's no matched job. 
-func FailJob(ctx context.Context, conn sqlexec.SQLExecutor, jobID int64, errorMsg string) error {
-	ctx = util.WithInternalSourceType(ctx, kv.InternalImportInto)
-	_, err := conn.ExecuteInternal(ctx, `UPDATE mysql.tidb_import_jobs
-		SET update_time = CURRENT_TIMESTAMP(6), end_time = CURRENT_TIMESTAMP(6), status = %?, error_message = %?
-		WHERE id = %? AND status = %?;`,
-		jobStatusFailed, errorMsg, jobID, JobStatusRunning)
-	return err
-}
-
-func convert2JobInfo(row chunk.Row) (*JobInfo, error) {
-	// start_time, end_time, summary, error_message can be NULL, need to use row.IsNull() to check.
-	startTime, endTime := types.ZeroTime, types.ZeroTime
-	if !row.IsNull(2) {
-		startTime = row.GetTime(2)
-	}
-	if !row.IsNull(3) {
-		endTime = row.GetTime(3)
-	}
-
-	parameters := ImportParameters{}
-	parametersStr := row.GetString(8)
-	if err := json.Unmarshal([]byte(parametersStr), &parameters); err != nil {
-		return nil, errors.Trace(err)
-	}
-
-	var summary *JobSummary
-	var summaryStr string
-	if !row.IsNull(12) {
-		summaryStr = row.GetString(12)
-	}
-	if len(summaryStr) > 0 {
-		summary = &JobSummary{}
-		if err := json.Unmarshal([]byte(summaryStr), summary); err != nil {
-			return nil, errors.Trace(err)
-		}
-	}
-
-	var errMsg string
-	if !row.IsNull(13) {
-		errMsg = row.GetString(13)
-	}
-	return &JobInfo{
-		ID:             row.GetInt64(0),
-		CreateTime:     row.GetTime(1),
-		StartTime:      startTime,
-		EndTime:        endTime,
-		TableSchema:    row.GetString(4),
-		TableName:      row.GetString(5),
-		TableID:        row.GetInt64(6),
-		CreatedBy:      row.GetString(7),
-		Parameters:     parameters,
-		SourceFileSize: row.GetInt64(9),
-		Status:         row.GetString(10),
-		Step:           row.GetString(11),
-		Summary:        summary,
-		ErrorMessage:   errMsg,
-	}, nil
-}
-
-// GetAllViewableJobs gets all viewable jobs.
-func GetAllViewableJobs(ctx context.Context, conn sqlexec.SQLExecutor, user string, hasSuperPriv bool) ([]*JobInfo, error) {
-	ctx = util.WithInternalSourceType(ctx, kv.InternalImportInto)
-	sql := baseQuerySQL
-	args := []any{}
-	if !hasSuperPriv {
-		sql += " WHERE created_by = %?"
-		args = append(args, user)
-	}
-	rs, err := conn.ExecuteInternal(ctx, sql, args...)
-	if err != nil {
-		return nil, err
-	}
-	defer terror.Call(rs.Close)
-	rows, err := sqlexec.DrainRecordSet(ctx, rs, 1)
-	if err != nil {
-		return nil, err
-	}
-	ret := make([]*JobInfo, 0, len(rows))
-	for _, row := range rows {
-		jobInfo, err2 := convert2JobInfo(row)
-		if err2 != nil {
-			return nil, err2
-		}
-		ret = append(ret, jobInfo)
-	}
-
-	return ret, nil
-}
-
-// CancelJob cancels import into job. Only a running/paused job can be canceled.
-// check privileges using get before calling this method.
-func CancelJob(ctx context.Context, conn sqlexec.SQLExecutor, jobID int64) (err error) {
-	ctx = util.WithInternalSourceType(ctx, kv.InternalImportInto)
-	sql := `UPDATE mysql.tidb_import_jobs
-		SET update_time = CURRENT_TIMESTAMP(6), status = %?, error_message = 'cancelled by user'
-		WHERE id = %? AND status IN (%?, %?);`
-	args := []any{jogStatusCancelled, jobID, jobStatusPending, JobStatusRunning}
-	_, err = conn.ExecuteInternal(ctx, sql, args...)
- return err -} diff --git a/pkg/executor/importer/table_import.go b/pkg/executor/importer/table_import.go index 649d0eff2c512..1e5d1ad2b14e6 100644 --- a/pkg/executor/importer/table_import.go +++ b/pkg/executor/importer/table_import.go @@ -684,9 +684,9 @@ func (ti *TableImporter) ImportSelectedRows(ctx context.Context, se sessionctx.C if err != nil { return nil, err } - if _, _err_ := failpoint.Eval(_curpkg_("mockImportFromSelectErr")); _err_ == nil { - return nil, errors.New("mock import from select error") - } + failpoint.Inject("mockImportFromSelectErr", func() { + failpoint.Return(nil, errors.New("mock import from select error")) + }) if err = closedDataEngine.Import(ctx, ti.regionSplitSize, ti.regionSplitKeys); err != nil { if common.ErrFoundDuplicateKeys.Equal(err) { err = local.ConvertToErrFoundConflictRecords(err, ti.encTable) @@ -834,9 +834,9 @@ func VerifyChecksum(ctx context.Context, plan *Plan, localChecksum verify.KVChec } logger.Info("local checksum", zap.Object("checksum", &localChecksum)) - if _, _err_ := failpoint.Eval(_curpkg_("waitCtxDone")); _err_ == nil { + failpoint.Inject("waitCtxDone", func() { <-ctx.Done() - } + }) remoteChecksum, err := checksumTable(ctx, se, plan, logger) if err != nil { @@ -911,9 +911,9 @@ func checksumTable(ctx context.Context, se sessionctx.Context, plan *Plan, logge return errors.New("empty checksum result") } - if _, _err_ := failpoint.Eval(_curpkg_("errWhenChecksum")); _err_ == nil { - return errors.New("occur an error when checksum, coprocessor task terminated due to exceeding the deadline") - } + failpoint.Inject("errWhenChecksum", func() { + failpoint.Return(errors.New("occur an error when checksum, coprocessor task terminated due to exceeding the deadline")) + }) // ADMIN CHECKSUM TABLE . example. // mysql> admin checksum table test.t; diff --git a/pkg/executor/importer/table_import.go__failpoint_stash__ b/pkg/executor/importer/table_import.go__failpoint_stash__ deleted file mode 100644 index 1e5d1ad2b14e6..0000000000000 --- a/pkg/executor/importer/table_import.go__failpoint_stash__ +++ /dev/null @@ -1,983 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package importer - -import ( - "context" - "io" - "math" - "net" - "os" - "path/filepath" - "strconv" - "strings" - "sync" - "time" - "unicode/utf8" - - "github.com/docker/go-units" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/storage" - tidb "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - "github.com/pingcap/tidb/pkg/keyspace" - tidbkv "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/lightning/backend" - "github.com/pingcap/tidb/pkg/lightning/backend/encode" - "github.com/pingcap/tidb/pkg/lightning/backend/kv" - "github.com/pingcap/tidb/pkg/lightning/backend/local" - "github.com/pingcap/tidb/pkg/lightning/checkpoints" - "github.com/pingcap/tidb/pkg/lightning/common" - "github.com/pingcap/tidb/pkg/lightning/config" - "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/lightning/metric" - "github.com/pingcap/tidb/pkg/lightning/mydump" - verify "github.com/pingcap/tidb/pkg/lightning/verification" - "github.com/pingcap/tidb/pkg/meta/autoid" - tidbmetrics "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/sessiontxn" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/table/tables" - tidbutil "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/etcd" - "github.com/pingcap/tidb/pkg/util/mathutil" - "github.com/pingcap/tidb/pkg/util/promutil" - "github.com/pingcap/tidb/pkg/util/sqlexec" - "github.com/pingcap/tidb/pkg/util/sqlkiller" - "github.com/pingcap/tidb/pkg/util/syncutil" - "github.com/prometheus/client_golang/prometheus" - "github.com/tikv/client-go/v2/util" - clientv3 "go.etcd.io/etcd/client/v3" - "go.uber.org/multierr" - "go.uber.org/zap" -) - -// NewTiKVModeSwitcher make it a var, so we can mock it in tests. -var NewTiKVModeSwitcher = local.NewTiKVModeSwitcher - -var ( - // CheckDiskQuotaInterval is the default time interval to check disk quota. - // TODO: make it dynamically adjusting according to the speed of import and the disk size. - CheckDiskQuotaInterval = 10 * time.Second - - // defaultMaxEngineSize is the default max engine size in bytes. - // we make it 5 times larger than lightning default engine size to reduce range overlap, especially for index, - // since we have an index engine per distributed subtask. - // for 1TiB data, we can divide it into 2 engines that runs on 2 TiDB. it can have a good balance between - // range overlap and sort speed in one of our test of: - // - 10 columns, PK + 6 secondary index 2 of which is mv index - // - 1.05 KiB per row, 527 MiB per file, 1024000000 rows, 1 TiB total - // - // it might not be the optimal value for other cases. - defaultMaxEngineSize = int64(5 * config.DefaultBatchSize) -) - -// prepareSortDir creates a new directory for import, remove previous sort directory if exists. 
-func prepareSortDir(e *LoadDataController, id string, tidbCfg *tidb.Config) (string, error) { - importDir := GetImportRootDir(tidbCfg) - sortDir := filepath.Join(importDir, id) - - if info, err := os.Stat(importDir); err != nil || !info.IsDir() { - if err != nil && !os.IsNotExist(err) { - e.logger.Error("stat import dir failed", zap.String("import_dir", importDir), zap.Error(err)) - return "", errors.Trace(err) - } - if info != nil && !info.IsDir() { - e.logger.Warn("import dir is not a dir, remove it", zap.String("import_dir", importDir)) - if err := os.RemoveAll(importDir); err != nil { - return "", errors.Trace(err) - } - } - e.logger.Info("import dir not exists, create it", zap.String("import_dir", importDir)) - if err := os.MkdirAll(importDir, 0o700); err != nil { - e.logger.Error("failed to make dir", zap.String("import_dir", importDir), zap.Error(err)) - return "", errors.Trace(err) - } - } - - // todo: remove this after we support checkpoint - if _, err := os.Stat(sortDir); err != nil { - if !os.IsNotExist(err) { - e.logger.Error("stat sort dir failed", zap.String("sort_dir", sortDir), zap.Error(err)) - return "", errors.Trace(err) - } - } else { - e.logger.Warn("sort dir already exists, remove it", zap.String("sort_dir", sortDir)) - if err := os.RemoveAll(sortDir); err != nil { - return "", errors.Trace(err) - } - } - return sortDir, nil -} - -// GetRegionSplitSizeKeys gets the region split size and keys from PD. -func GetRegionSplitSizeKeys(ctx context.Context) (regionSplitSize int64, regionSplitKeys int64, err error) { - tidbCfg := tidb.GetGlobalConfig() - tls, err := common.NewTLS( - tidbCfg.Security.ClusterSSLCA, - tidbCfg.Security.ClusterSSLCert, - tidbCfg.Security.ClusterSSLKey, - "", - nil, nil, nil, - ) - if err != nil { - return 0, 0, err - } - tlsOpt := tls.ToPDSecurityOption() - addrs := strings.Split(tidbCfg.Path, ",") - pdCli, err := NewClientWithContext(ctx, addrs, tlsOpt) - if err != nil { - return 0, 0, errors.Trace(err) - } - defer pdCli.Close() - return local.GetRegionSplitSizeKeys(ctx, pdCli, tls) -} - -// NewTableImporter creates a new table importer. 
-func NewTableImporter( - ctx context.Context, - e *LoadDataController, - id string, - kvStore tidbkv.Storage, -) (ti *TableImporter, err error) { - idAlloc := kv.NewPanickingAllocators(e.Table.Meta().SepAutoInc(), 0) - tbl, err := tables.TableFromMeta(idAlloc, e.Table.Meta()) - if err != nil { - return nil, errors.Annotatef(err, "failed to tables.TableFromMeta %s", e.Table.Meta().Name) - } - - tidbCfg := tidb.GetGlobalConfig() - // todo: we only need to prepare this once on each node(we might call it 3 times in distribution framework) - dir, err := prepareSortDir(e, id, tidbCfg) - if err != nil { - return nil, err - } - - hostPort := net.JoinHostPort("127.0.0.1", strconv.Itoa(int(tidbCfg.Status.StatusPort))) - tls, err := common.NewTLS( - tidbCfg.Security.ClusterSSLCA, - tidbCfg.Security.ClusterSSLCert, - tidbCfg.Security.ClusterSSLKey, - hostPort, - nil, nil, nil, - ) - if err != nil { - return nil, err - } - - backendConfig := e.getLocalBackendCfg(tidbCfg.Path, dir) - d := kvStore.(tidbkv.StorageWithPD).GetPDClient().GetServiceDiscovery() - localBackend, err := local.NewBackend(ctx, tls, backendConfig, d) - if err != nil { - return nil, err - } - - return &TableImporter{ - LoadDataController: e, - id: id, - backend: localBackend, - tableInfo: &checkpoints.TidbTableInfo{ - ID: e.Table.Meta().ID, - Name: e.Table.Meta().Name.O, - Core: e.Table.Meta(), - }, - encTable: tbl, - dbID: e.DBID, - keyspace: kvStore.GetCodec().GetKeyspace(), - logger: e.logger.With(zap.String("import-id", id)), - // this is the value we use for 50TiB data parallel import. - // this might not be the optimal value. - // todo: use different default for single-node import and distributed import. - regionSplitSize: 2 * int64(config.SplitRegionSize), - regionSplitKeys: 2 * int64(config.SplitRegionKeys), - diskQuota: adjustDiskQuota(int64(e.DiskQuota), dir, e.logger), - diskQuotaLock: new(syncutil.RWMutex), - }, nil -} - -// TableImporter is a table importer. -type TableImporter struct { - *LoadDataController - // id is the unique id for this importer. - // it's the task id if we are running in distributed framework, else it's an - // uuid. we use this id to create a unique directory for this importer. - id string - backend *local.Backend - tableInfo *checkpoints.TidbTableInfo - // this table has a separate id allocator used to record the max row id allocated. - encTable table.Table - dbID int64 - - keyspace []byte - logger *zap.Logger - regionSplitSize int64 - regionSplitKeys int64 - diskQuota int64 - diskQuotaLock *syncutil.RWMutex - - rowCh chan QueryRow -} - -// NewTableImporterForTest creates a new table importer for test. 
-func NewTableImporterForTest(ctx context.Context, e *LoadDataController, id string, helper local.StoreHelper) (*TableImporter, error) { - idAlloc := kv.NewPanickingAllocators(e.Table.Meta().SepAutoInc(), 0) - tbl, err := tables.TableFromMeta(idAlloc, e.Table.Meta()) - if err != nil { - return nil, errors.Annotatef(err, "failed to tables.TableFromMeta %s", e.Table.Meta().Name) - } - - tidbCfg := tidb.GetGlobalConfig() - dir, err := prepareSortDir(e, id, tidbCfg) - if err != nil { - return nil, err - } - - backendConfig := e.getLocalBackendCfg(tidbCfg.Path, dir) - localBackend, err := local.NewBackendForTest(ctx, backendConfig, helper) - if err != nil { - return nil, err - } - - return &TableImporter{ - LoadDataController: e, - id: id, - backend: localBackend, - tableInfo: &checkpoints.TidbTableInfo{ - ID: e.Table.Meta().ID, - Name: e.Table.Meta().Name.O, - Core: e.Table.Meta(), - }, - encTable: tbl, - dbID: e.DBID, - logger: e.logger.With(zap.String("import-id", id)), - diskQuotaLock: new(syncutil.RWMutex), - }, nil -} - -// GetKeySpace gets the keyspace of the kv store. -func (ti *TableImporter) GetKeySpace() []byte { - return ti.keyspace -} - -func (ti *TableImporter) getParser(ctx context.Context, chunk *checkpoints.ChunkCheckpoint) (mydump.Parser, error) { - info := LoadDataReaderInfo{ - Opener: func(ctx context.Context) (io.ReadSeekCloser, error) { - reader, err := mydump.OpenReader(ctx, &chunk.FileMeta, ti.dataStore, storage.DecompressConfig{ - ZStdDecodeConcurrency: 1, - }) - if err != nil { - return nil, errors.Trace(err) - } - return reader, nil - }, - Remote: &chunk.FileMeta, - } - parser, err := ti.LoadDataController.GetParser(ctx, info) - if err != nil { - return nil, err - } - if chunk.Chunk.Offset == 0 { - // if data file is split, only the first chunk need to do skip. - // see check in initOptions. - if err = ti.LoadDataController.HandleSkipNRows(parser); err != nil { - return nil, err - } - parser.SetRowID(chunk.Chunk.PrevRowIDMax) - } else { - // if we reached here, the file must be an uncompressed CSV file. - if err = parser.SetPos(chunk.Chunk.Offset, chunk.Chunk.PrevRowIDMax); err != nil { - return nil, err - } - } - return parser, nil -} - -func (ti *TableImporter) getKVEncoder(chunk *checkpoints.ChunkCheckpoint) (KVEncoder, error) { - cfg := &encode.EncodingConfig{ - SessionOptions: encode.SessionOptions{ - SQLMode: ti.SQLMode, - Timestamp: chunk.Timestamp, - SysVars: ti.ImportantSysVars, - AutoRandomSeed: chunk.Chunk.PrevRowIDMax, - }, - Path: chunk.FileMeta.Path, - Table: ti.encTable, - Logger: log.Logger{Logger: ti.logger.With(zap.String("path", chunk.FileMeta.Path))}, - } - return NewTableKVEncoder(cfg, ti) -} - -func (e *LoadDataController) calculateSubtaskCnt() int { - // we want to split data files into subtask of size close to MaxEngineSize to reduce range overlap, - // and evenly distribute them to subtasks. - // we calculate subtask count first by round(TotalFileSize / maxEngineSize) - - // AllocateEngineIDs is using ceil() to calculate subtask count, engine size might be too small in some case, - // such as 501G data, maxEngineSize will be about 250G, so we don't relay on it. 
- // see https://github.com/pingcap/tidb/blob/b4183e1dc9bb01fb81d3aa79ca4b5b74387c6c2a/br/pkg/lightning/mydump/region.go#L109 - // - // for default e.MaxEngineSize = 500GiB, we have: - // data size range(G) cnt adjusted-engine-size range(G) - // [0, 750) 1 [0, 750) - // [750, 1250) 2 [375, 625) - // [1250, 1750) 3 [416, 583) - // [1750, 2250) 4 [437, 562) - var ( - subtaskCount float64 - maxEngineSize = int64(e.MaxEngineSize) - ) - if e.TotalFileSize <= maxEngineSize { - subtaskCount = 1 - } else { - subtaskCount = math.Round(float64(e.TotalFileSize) / float64(e.MaxEngineSize)) - } - - // for global sort task, since there is no overlap, - // we make sure subtask count is a multiple of execute nodes count - if e.IsGlobalSort() && e.ExecuteNodesCnt > 0 { - subtaskCount = math.Ceil(subtaskCount/float64(e.ExecuteNodesCnt)) * float64(e.ExecuteNodesCnt) - } - return int(subtaskCount) -} - -func (e *LoadDataController) getAdjustedMaxEngineSize() int64 { - subtaskCount := e.calculateSubtaskCnt() - // we adjust MaxEngineSize to make sure each subtask has a similar amount of data to import. - return int64(math.Ceil(float64(e.TotalFileSize) / float64(subtaskCount))) -} - -// SetExecuteNodeCnt sets the execute node count. -func (e *LoadDataController) SetExecuteNodeCnt(cnt int) { - e.ExecuteNodesCnt = cnt -} - -// PopulateChunks populates chunks from table regions. -// in dist framework, this should be done in the tidb node which is responsible for splitting job into subtasks -// then table-importer handles data belongs to the subtask. -func (e *LoadDataController) PopulateChunks(ctx context.Context) (ecp map[int32]*checkpoints.EngineCheckpoint, err error) { - task := log.BeginTask(e.logger, "populate chunks") - defer func() { - task.End(zap.ErrorLevel, err) - }() - - tableMeta := &mydump.MDTableMeta{ - DB: e.DBName, - Name: e.Table.Meta().Name.O, - DataFiles: e.toMyDumpFiles(), - } - adjustedMaxEngineSize := e.getAdjustedMaxEngineSize() - e.logger.Info("adjust max engine size", zap.Int64("before", int64(e.MaxEngineSize)), - zap.Int64("after", adjustedMaxEngineSize)) - dataDivideCfg := &mydump.DataDivideConfig{ - ColumnCnt: len(e.Table.Meta().Columns), - EngineDataSize: adjustedMaxEngineSize, - MaxChunkSize: int64(config.MaxRegionSize), - Concurrency: e.ThreadCnt, - IOWorkers: nil, - Store: e.dataStore, - TableMeta: tableMeta, - - StrictFormat: e.SplitFile, - DataCharacterSet: *e.Charset, - DataInvalidCharReplace: string(utf8.RuneError), - ReadBlockSize: LoadDataReadBlockSize, - CSV: *e.GenerateCSVConfig(), - } - makeEngineCtx := log.NewContext(ctx, log.Logger{Logger: e.logger}) - tableRegions, err2 := mydump.MakeTableRegions(makeEngineCtx, dataDivideCfg) - - if err2 != nil { - e.logger.Error("populate chunks failed", zap.Error(err2)) - return nil, err2 - } - - var maxRowID int64 - timestamp := time.Now().Unix() - tableCp := &checkpoints.TableCheckpoint{ - Engines: map[int32]*checkpoints.EngineCheckpoint{}, - } - for _, region := range tableRegions { - engine, found := tableCp.Engines[region.EngineID] - if !found { - engine = &checkpoints.EngineCheckpoint{ - Status: checkpoints.CheckpointStatusLoaded, - } - tableCp.Engines[region.EngineID] = engine - } - ccp := &checkpoints.ChunkCheckpoint{ - Key: checkpoints.ChunkCheckpointKey{ - Path: region.FileMeta.Path, - Offset: region.Chunk.Offset, - }, - FileMeta: region.FileMeta, - ColumnPermutation: nil, - Chunk: region.Chunk, - Timestamp: timestamp, - } - engine.Chunks = append(engine.Chunks, ccp) - if region.Chunk.RowIDMax > maxRowID { - maxRowID = 
region.Chunk.RowIDMax - } - } - - // Add index engine checkpoint - tableCp.Engines[common.IndexEngineID] = &checkpoints.EngineCheckpoint{Status: checkpoints.CheckpointStatusLoaded} - return tableCp.Engines, nil -} - -// a simplified version of EstimateCompactionThreshold -func (ti *TableImporter) getTotalRawFileSize(indexCnt int64) int64 { - var totalSize int64 - for _, file := range ti.dataFiles { - size := file.RealSize - if file.Type == mydump.SourceTypeParquet { - // parquet file is compressed, thus estimates with a factor of 2 - size *= 2 - } - totalSize += size - } - return totalSize * indexCnt -} - -// OpenIndexEngine opens an index engine. -func (ti *TableImporter) OpenIndexEngine(ctx context.Context, engineID int32) (*backend.OpenedEngine, error) { - idxEngineCfg := &backend.EngineConfig{ - TableInfo: ti.tableInfo, - } - idxCnt := len(ti.tableInfo.Core.Indices) - if !common.TableHasAutoRowID(ti.tableInfo.Core) { - idxCnt-- - } - // todo: getTotalRawFileSize returns size of all data files, but in distributed framework, - // we create one index engine for each engine, should reflect this in the future. - threshold := local.EstimateCompactionThreshold2(ti.getTotalRawFileSize(int64(idxCnt))) - idxEngineCfg.Local = backend.LocalEngineConfig{ - Compact: threshold > 0, - CompactConcurrency: 4, - CompactThreshold: threshold, - BlockSize: 16 * 1024, - } - fullTableName := ti.FullTableName() - // todo: cleanup all engine data on any error since we don't support checkpoint for now - // some return path, didn't make sure all data engine and index engine are cleaned up. - // maybe we can add this in upper level to clean the whole local-sort directory - mgr := backend.MakeEngineManager(ti.backend) - return mgr.OpenEngine(ctx, idxEngineCfg, fullTableName, engineID) -} - -// OpenDataEngine opens a data engine. -func (ti *TableImporter) OpenDataEngine(ctx context.Context, engineID int32) (*backend.OpenedEngine, error) { - dataEngineCfg := &backend.EngineConfig{ - TableInfo: ti.tableInfo, - } - // todo: support checking IsRowOrdered later. - // also see test result here: https://github.com/pingcap/tidb/pull/47147 - //if ti.tableMeta.IsRowOrdered { - // dataEngineCfg.Local.Compact = true - // dataEngineCfg.Local.CompactConcurrency = 4 - // dataEngineCfg.Local.CompactThreshold = local.CompactionUpperThreshold - //} - mgr := backend.MakeEngineManager(ti.backend) - return mgr.OpenEngine(ctx, dataEngineCfg, ti.FullTableName(), engineID) -} - -// ImportAndCleanup imports the engine and cleanup the engine data. -func (ti *TableImporter) ImportAndCleanup(ctx context.Context, closedEngine *backend.ClosedEngine) (int64, error) { - var kvCount int64 - importErr := closedEngine.Import(ctx, ti.regionSplitSize, ti.regionSplitKeys) - if common.ErrFoundDuplicateKeys.Equal(importErr) { - importErr = local.ConvertToErrFoundConflictRecords(importErr, ti.encTable) - } - if closedEngine.GetID() != common.IndexEngineID { - // todo: change to a finer-grain progress later. - // each row is encoded into 1 data key - kvCount = ti.backend.GetImportedKVCount(closedEngine.GetUUID()) - } - cleanupErr := closedEngine.Cleanup(ctx) - return kvCount, multierr.Combine(importErr, cleanupErr) -} - -// Backend returns the backend of the importer. -func (ti *TableImporter) Backend() *local.Backend { - return ti.backend -} - -// Close implements the io.Closer interface. -func (ti *TableImporter) Close() error { - ti.backend.Close() - return nil -} - -// Allocators returns allocators used to record max used ID, i.e. PanickingAllocators. 
-func (ti *TableImporter) Allocators() autoid.Allocators { - return ti.encTable.Allocators(nil) -} - -// CheckDiskQuota checks disk quota. -func (ti *TableImporter) CheckDiskQuota(ctx context.Context) { - var locker sync.Locker - lockDiskQuota := func() { - if locker == nil { - ti.diskQuotaLock.Lock() - locker = ti.diskQuotaLock - } - } - unlockDiskQuota := func() { - if locker != nil { - locker.Unlock() - locker = nil - } - } - - defer unlockDiskQuota() - ti.logger.Info("start checking disk quota", zap.String("disk-quota", units.BytesSize(float64(ti.diskQuota)))) - for { - select { - case <-ctx.Done(): - return - case <-time.After(CheckDiskQuotaInterval): - } - - largeEngines, inProgressLargeEngines, totalDiskSize, totalMemSize := local.CheckDiskQuota(ti.backend, ti.diskQuota) - if len(largeEngines) == 0 && inProgressLargeEngines == 0 { - unlockDiskQuota() - continue - } - - ti.logger.Warn("disk quota exceeded", - zap.Int64("diskSize", totalDiskSize), - zap.Int64("memSize", totalMemSize), - zap.Int64("quota", ti.diskQuota), - zap.Int("largeEnginesCount", len(largeEngines)), - zap.Int("inProgressLargeEnginesCount", inProgressLargeEngines)) - - lockDiskQuota() - - if len(largeEngines) == 0 { - ti.logger.Warn("all large engines are already importing, keep blocking all writes") - continue - } - - if err := ti.backend.FlushAllEngines(ctx); err != nil { - ti.logger.Error("flush engine for disk quota failed, check again later", log.ShortError(err)) - unlockDiskQuota() - continue - } - - // at this point, all engines are synchronized on disk. - // we then import every large engines one by one and complete. - // if any engine failed to import, we just try again next time, since the data are still intact. - var importErr error - for _, engine := range largeEngines { - // Use a larger split region size to avoid split the same region by many times. - if err := ti.backend.UnsafeImportAndReset( - ctx, - engine, - int64(config.SplitRegionSize)*int64(config.MaxSplitRegionSizeRatio), - int64(config.SplitRegionKeys)*int64(config.MaxSplitRegionSizeRatio), - ); err != nil { - if common.ErrFoundDuplicateKeys.Equal(err) { - err = local.ConvertToErrFoundConflictRecords(err, ti.encTable) - } - importErr = multierr.Append(importErr, err) - } - } - if importErr != nil { - // discuss: should we return the error and cancel the import? - ti.logger.Error("import large engines failed, check again later", log.ShortError(importErr)) - } - unlockDiskQuota() - } -} - -// SetSelectedRowCh sets the channel to receive selected rows. -func (ti *TableImporter) SetSelectedRowCh(ch chan QueryRow) { - ti.rowCh = ch -} - -func (ti *TableImporter) closeAndCleanupEngine(engine *backend.OpenedEngine) { - // outer context might be done, so we create a new context here. - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) - defer cancel() - closedEngine, err := engine.Close(ctx) - if err != nil { - ti.logger.Error("close engine failed", zap.Error(err)) - return - } - if err = closedEngine.Cleanup(ctx); err != nil { - ti.logger.Error("cleanup engine failed", zap.Error(err)) - } -} - -// ImportSelectedRows imports selected rows. 
-func (ti *TableImporter) ImportSelectedRows(ctx context.Context, se sessionctx.Context) (*JobImportResult, error) { - var ( - err error - dataEngine, indexEngine *backend.OpenedEngine - ) - metrics := tidbmetrics.GetRegisteredImportMetrics(promutil.NewDefaultFactory(), - prometheus.Labels{ - proto.TaskIDLabelName: ti.id, - }) - ctx = metric.WithCommonMetric(ctx, metrics) - defer func() { - tidbmetrics.UnregisterImportMetrics(metrics) - if dataEngine != nil { - ti.closeAndCleanupEngine(dataEngine) - } - if indexEngine != nil { - ti.closeAndCleanupEngine(indexEngine) - } - }() - - dataEngine, err = ti.OpenDataEngine(ctx, 1) - if err != nil { - return nil, err - } - indexEngine, err = ti.OpenIndexEngine(ctx, common.IndexEngineID) - if err != nil { - return nil, err - } - - var ( - mu sync.Mutex - checksum = verify.NewKVGroupChecksumWithKeyspace(ti.keyspace) - colSizeMap = make(map[int64]int64) - ) - eg, egCtx := tidbutil.NewErrorGroupWithRecoverWithCtx(ctx) - for i := 0; i < ti.ThreadCnt; i++ { - eg.Go(func() error { - chunkCheckpoint := checkpoints.ChunkCheckpoint{} - chunkChecksum := verify.NewKVGroupChecksumWithKeyspace(ti.keyspace) - progress := NewProgress() - defer func() { - mu.Lock() - defer mu.Unlock() - checksum.Add(chunkChecksum) - for k, v := range progress.GetColSize() { - colSizeMap[k] += v - } - }() - return ProcessChunk(egCtx, &chunkCheckpoint, ti, dataEngine, indexEngine, progress, ti.logger, chunkChecksum) - }) - } - if err = eg.Wait(); err != nil { - return nil, err - } - - closedDataEngine, err := dataEngine.Close(ctx) - if err != nil { - return nil, err - } - failpoint.Inject("mockImportFromSelectErr", func() { - failpoint.Return(nil, errors.New("mock import from select error")) - }) - if err = closedDataEngine.Import(ctx, ti.regionSplitSize, ti.regionSplitKeys); err != nil { - if common.ErrFoundDuplicateKeys.Equal(err) { - err = local.ConvertToErrFoundConflictRecords(err, ti.encTable) - } - return nil, err - } - dataKVCount := ti.backend.GetImportedKVCount(closedDataEngine.GetUUID()) - - closedIndexEngine, err := indexEngine.Close(ctx) - if err != nil { - return nil, err - } - if err = closedIndexEngine.Import(ctx, ti.regionSplitSize, ti.regionSplitKeys); err != nil { - if common.ErrFoundDuplicateKeys.Equal(err) { - err = local.ConvertToErrFoundConflictRecords(err, ti.encTable) - } - return nil, err - } - - allocators := ti.Allocators() - maxIDs := map[autoid.AllocatorType]int64{ - autoid.RowIDAllocType: allocators.Get(autoid.RowIDAllocType).Base(), - autoid.AutoIncrementType: allocators.Get(autoid.AutoIncrementType).Base(), - autoid.AutoRandomType: allocators.Get(autoid.AutoRandomType).Base(), - } - if err = PostProcess(ctx, se, maxIDs, ti.Plan, checksum, ti.logger); err != nil { - return nil, err - } - - return &JobImportResult{ - Affected: uint64(dataKVCount), - ColSizeMap: colSizeMap, - }, nil -} - -func adjustDiskQuota(diskQuota int64, sortDir string, logger *zap.Logger) int64 { - sz, err := common.GetStorageSize(sortDir) - if err != nil { - logger.Warn("failed to get storage size", zap.Error(err)) - if diskQuota != 0 { - return diskQuota - } - logger.Info("use default quota instead", zap.Int64("quota", int64(DefaultDiskQuota))) - return int64(DefaultDiskQuota) - } - - maxDiskQuota := int64(float64(sz.Capacity) * 0.8) - switch { - case diskQuota == 0: - logger.Info("use 0.8 of the storage size as default disk quota", - zap.String("quota", units.HumanSize(float64(maxDiskQuota)))) - return maxDiskQuota - case diskQuota > maxDiskQuota: - logger.Warn("disk quota is 
larger than 0.8 of the storage size, use 0.8 of the storage size instead", - zap.String("quota", units.HumanSize(float64(maxDiskQuota)))) - return maxDiskQuota - default: - return diskQuota - } -} - -// PostProcess does the post-processing for the task. -// exported for testing. -func PostProcess( - ctx context.Context, - se sessionctx.Context, - maxIDs map[autoid.AllocatorType]int64, - plan *Plan, - localChecksum *verify.KVGroupChecksum, - logger *zap.Logger, -) (err error) { - callLog := log.BeginTask(logger.With(zap.Object("checksum", localChecksum)), "post process") - defer func() { - callLog.End(zap.ErrorLevel, err) - }() - - if err = RebaseAllocatorBases(ctx, se.GetStore(), maxIDs, plan, logger); err != nil { - return err - } - - return VerifyChecksum(ctx, plan, localChecksum.MergedChecksum(), se, logger) -} - -type autoIDRequirement struct { - store tidbkv.Storage - autoidCli *autoid.ClientDiscover -} - -func (r *autoIDRequirement) Store() tidbkv.Storage { - return r.store -} - -func (r *autoIDRequirement) AutoIDClient() *autoid.ClientDiscover { - return r.autoidCli -} - -// RebaseAllocatorBases rebase the allocator bases. -func RebaseAllocatorBases(ctx context.Context, kvStore tidbkv.Storage, maxIDs map[autoid.AllocatorType]int64, plan *Plan, logger *zap.Logger) (err error) { - callLog := log.BeginTask(logger, "rebase allocators") - defer func() { - callLog.End(zap.ErrorLevel, err) - }() - - if !common.TableHasAutoID(plan.DesiredTableInfo) { - return nil - } - - tidbCfg := tidb.GetGlobalConfig() - hostPort := net.JoinHostPort("127.0.0.1", strconv.Itoa(int(tidbCfg.Status.StatusPort))) - tls, err2 := common.NewTLS( - tidbCfg.Security.ClusterSSLCA, - tidbCfg.Security.ClusterSSLCert, - tidbCfg.Security.ClusterSSLKey, - hostPort, - nil, nil, nil, - ) - if err2 != nil { - return err2 - } - - addrs := strings.Split(tidbCfg.Path, ",") - etcdCli, err := clientv3.New(clientv3.Config{ - Endpoints: addrs, - AutoSyncInterval: 30 * time.Second, - TLS: tls.TLSConfig(), - }) - if err != nil { - return errors.Trace(err) - } - etcd.SetEtcdCliByNamespace(etcdCli, keyspace.MakeKeyspaceEtcdNamespace(kvStore.GetCodec())) - autoidCli := autoid.NewClientDiscover(etcdCli) - r := autoIDRequirement{store: kvStore, autoidCli: autoidCli} - err = common.RebaseTableAllocators(ctx, maxIDs, &r, plan.DBID, plan.DesiredTableInfo) - if err1 := etcdCli.Close(); err1 != nil { - logger.Info("close etcd client error", zap.Error(err1)) - } - autoidCli.ResetConn(nil) - return errors.Trace(err) -} - -// VerifyChecksum verify the checksum of the table. 
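The quota adjustment in adjustDiskQuota above boils down to one clamping rule: an explicit quota is honoured unless it exceeds 80% of the sort directory's capacity, and a zero quota defaults to that 80% ceiling. A minimal sketch of the same rule follows; the storage-size lookup and its error branch (falling back to DefaultDiskQuota) are omitted here, and the numbers in main are illustrative only.

package main

import "fmt"

func clampDiskQuota(quota, capacity int64) int64 {
	maxQuota := int64(float64(capacity) * 0.8)
	switch {
	case quota == 0:
		return maxQuota // no explicit quota: use 80% of capacity
	case quota > maxQuota:
		return maxQuota // explicit quota too large: clamp it
	default:
		return quota
	}
}

func main() {
	const gb = int64(1) << 30
	fmt.Println(clampDiskQuota(0, 100*gb))     // defaults to 80 GiB
	fmt.Println(clampDiskQuota(90*gb, 100*gb)) // clamped to 80 GiB
	fmt.Println(clampDiskQuota(10*gb, 100*gb)) // kept as-is
}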
-func VerifyChecksum(ctx context.Context, plan *Plan, localChecksum verify.KVChecksum, se sessionctx.Context, logger *zap.Logger) error { - if plan.Checksum == config.OpLevelOff { - return nil - } - logger.Info("local checksum", zap.Object("checksum", &localChecksum)) - - failpoint.Inject("waitCtxDone", func() { - <-ctx.Done() - }) - - remoteChecksum, err := checksumTable(ctx, se, plan, logger) - if err != nil { - if plan.Checksum != config.OpLevelOptional { - return err - } - logger.Warn("checksumTable failed, will skip this error and go on", zap.Error(err)) - } - if remoteChecksum != nil { - if !remoteChecksum.IsEqual(&localChecksum) { - err2 := common.ErrChecksumMismatch.GenWithStackByArgs( - remoteChecksum.Checksum, localChecksum.Sum(), - remoteChecksum.TotalKVs, localChecksum.SumKVS(), - remoteChecksum.TotalBytes, localChecksum.SumSize(), - ) - if plan.Checksum == config.OpLevelOptional { - logger.Warn("verify checksum failed, but checksum is optional, will skip it", zap.Error(err2)) - err2 = nil - } - return err2 - } - logger.Info("checksum pass", zap.Object("local", &localChecksum)) - } - return nil -} - -func checksumTable(ctx context.Context, se sessionctx.Context, plan *Plan, logger *zap.Logger) (*local.RemoteChecksum, error) { - var ( - tableName = common.UniqueTable(plan.DBName, plan.TableInfo.Name.L) - sql = "ADMIN CHECKSUM TABLE " + tableName - maxErrorRetryCount = 3 - distSQLScanConcurrencyFactor = 1 - remoteChecksum *local.RemoteChecksum - txnErr error - doneCh = make(chan struct{}) - ) - checkCtx, cancel := context.WithCancel(ctx) - defer func() { - cancel() - <-doneCh - }() - - go func() { - <-checkCtx.Done() - se.GetSessionVars().SQLKiller.SendKillSignal(sqlkiller.QueryInterrupted) - close(doneCh) - }() - - distSQLScanConcurrencyBak := se.GetSessionVars().DistSQLScanConcurrency() - defer func() { - se.GetSessionVars().SetDistSQLScanConcurrency(distSQLScanConcurrencyBak) - }() - ctx = util.WithInternalSourceType(checkCtx, tidbkv.InternalImportInto) - for i := 0; i < maxErrorRetryCount; i++ { - txnErr = func() error { - // increase backoff weight - if err := setBackoffWeight(se, plan, logger); err != nil { - logger.Warn("set tidb_backoff_weight failed", zap.Error(err)) - } - - newConcurrency := mathutil.Max(plan.DistSQLScanConcurrency/distSQLScanConcurrencyFactor, local.MinDistSQLScanConcurrency) - logger.Info("checksum with adjusted distsql scan concurrency", zap.Int("concurrency", newConcurrency)) - se.GetSessionVars().SetDistSQLScanConcurrency(newConcurrency) - - // TODO: add resource group name - - rs, err := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), sql) - if err != nil { - return err - } - if len(rs) < 1 { - return errors.New("empty checksum result") - } - - failpoint.Inject("errWhenChecksum", func() { - failpoint.Return(errors.New("occur an error when checksum, coprocessor task terminated due to exceeding the deadline")) - }) - - // ADMIN CHECKSUM TABLE .
example. - // mysql> admin checksum table test.t; - // +---------+------------+---------------------+-----------+-------------+ - // | Db_name | Table_name | Checksum_crc64_xor | Total_kvs | Total_bytes | - // +---------+------------+---------------------+-----------+-------------+ - // | test | t | 8520875019404689597 | 7296873 | 357601387 | - // +---------+------------+------------- - remoteChecksum = &local.RemoteChecksum{ - Schema: rs[0].GetString(0), - Table: rs[0].GetString(1), - Checksum: rs[0].GetUint64(2), - TotalKVs: rs[0].GetUint64(3), - TotalBytes: rs[0].GetUint64(4), - } - return nil - }() - if !common.IsRetryableError(txnErr) { - break - } - distSQLScanConcurrencyFactor *= 2 - logger.Warn("retry checksum table", zap.Int("retry count", i+1), zap.Error(txnErr)) - } - return remoteChecksum, txnErr -} - -func setBackoffWeight(se sessionctx.Context, plan *Plan, logger *zap.Logger) error { - backoffWeight := local.DefaultBackoffWeight - if val, ok := plan.ImportantSysVars[variable.TiDBBackOffWeight]; ok { - if weight, err := strconv.Atoi(val); err == nil && weight > backoffWeight { - backoffWeight = weight - } - } - logger.Info("set backoff weight", zap.Int("weight", backoffWeight)) - return se.GetSessionVars().SetSystemVar(variable.TiDBBackOffWeight, strconv.Itoa(backoffWeight)) -} - -// GetImportRootDir returns the root directory for import. -// The directory structure is like: -// -// -> /path/to/tidb-tmpdir -// -> import-4000 -// -> 1 -// -> some-uuid -// -// exported for testing. -func GetImportRootDir(tidbCfg *tidb.Config) string { - sortPathSuffix := "import-" + strconv.Itoa(int(tidbCfg.Port)) - return filepath.Join(tidbCfg.TempDir, sortPathSuffix) -} - -// FlushTableStats flushes the stats of the table. -// stats will be flushed in domain.updateStatsWorker, default interval is [1, 2) minutes, -// see DumpStatsDeltaToKV for more details. then the background analyzer will analyze -// the table. -// the stats stay in memory until the next flush, so it might be lost if the tidb-server restarts. -func FlushTableStats(ctx context.Context, se sessionctx.Context, tableID int64, result *JobImportResult) error { - if err := sessiontxn.NewTxn(ctx, se); err != nil { - return err - } - sessionVars := se.GetSessionVars() - sessionVars.TxnCtxMu.Lock() - defer sessionVars.TxnCtxMu.Unlock() - sessionVars.TxnCtx.UpdateDeltaForTable(tableID, int64(result.Affected), int64(result.Affected), result.ColSizeMap) - se.StmtCommit(ctx) - return se.CommitTxn(ctx) -} diff --git a/pkg/executor/index_merge_reader.go b/pkg/executor/index_merge_reader.go index 2034dac4b2eea..84261164a5925 100644 --- a/pkg/executor/index_merge_reader.go +++ b/pkg/executor/index_merge_reader.go @@ -326,12 +326,12 @@ func (e *IndexMergeReaderExecutor) startIndexMergeProcessWorker(ctx context.Cont } func (e *IndexMergeReaderExecutor) startPartialIndexWorker(ctx context.Context, exitCh <-chan struct{}, fetchCh chan<- *indexMergeTableTask, workID int) error { - if _, _err_ := failpoint.Eval(_curpkg_("testIndexMergeResultChCloseEarly")); _err_ == nil { + failpoint.Inject("testIndexMergeResultChCloseEarly", func(_ failpoint.Value) { // Wait for processWorker to close resultCh. time.Sleep(time.Second * 2) // Should use fetchCh instead of resultCh to send error. 
syncErr(ctx, e.finished, fetchCh, errors.New("testIndexMergeResultChCloseEarly")) - } + }) if e.RuntimeStats() != nil { collExec := true e.dagPBs[workID].CollectExecutionSummaries = &collExec @@ -343,9 +343,9 @@ func (e *IndexMergeReaderExecutor) startPartialIndexWorker(ctx context.Context, } else { keyRanges = [][]kv.KeyRange{e.keyRanges[workID]} } - if _, _err_ := failpoint.Eval(_curpkg_("startPartialIndexWorkerErr")); _err_ == nil { + failpoint.Inject("startPartialIndexWorkerErr", func() error { return errors.New("inject an error before start partialIndexWorker") - } + }) // for union case, the push-downLimit can be utilized to limit index fetched handles. // for intersection case, the push-downLimit can only be conducted after all index path/table finished. @@ -359,7 +359,7 @@ func (e *IndexMergeReaderExecutor) startPartialIndexWorker(ctx context.Context, defer e.idxWorkerWg.Done() util.WithRecovery( func() { - failpoint.Eval(_curpkg_("testIndexMergePanicPartialIndexWorker")) + failpoint.Inject("testIndexMergePanicPartialIndexWorker", nil) is := e.partialPlans[workID][0].(*plannercore.PhysicalIndexScan) worker := &partialIndexWorker{ stats: e.stats, @@ -443,7 +443,7 @@ func (e *IndexMergeReaderExecutor) startPartialIndexWorker(ctx context.Context, return } results = append(results, result) - failpoint.Eval(_curpkg_("testIndexMergePartialIndexWorkerCoprLeak")) + failpoint.Inject("testIndexMergePartialIndexWorkerCoprLeak", nil) } if len(results) > 1 && len(e.byItems) != 0 { // e.Schema() not the output schema for partialIndexReader, and we put byItems related column at first in `buildIndexReq`, so use nil here. @@ -486,7 +486,7 @@ func (e *IndexMergeReaderExecutor) startPartialTableWorker(ctx context.Context, defer e.idxWorkerWg.Done() util.WithRecovery( func() { - failpoint.Eval(_curpkg_("testIndexMergePanicPartialTableWorker")) + failpoint.Inject("testIndexMergePanicPartialTableWorker", nil) var err error partialTableReader := &TableReaderExecutor{ BaseExecutorV2: exec.NewBaseExecutorV2(e.Ctx().GetSessionVars(), ts.Schema(), e.getPartitalPlanID(workID)), @@ -558,7 +558,7 @@ func (e *IndexMergeReaderExecutor) startPartialTableWorker(ctx context.Context, syncErr(ctx, e.finished, fetchCh, err) break } - failpoint.Eval(_curpkg_("testIndexMergePartialTableWorkerCoprLeak")) + failpoint.Inject("testIndexMergePartialTableWorkerCoprLeak", nil) tableReaderClosed = false worker.batchSize = e.MaxChunkSize() if worker.batchSize > worker.maxBatchSize { @@ -705,9 +705,9 @@ func (w *partialTableWorker) extractTaskHandles(ctx context.Context, chk *chunk. w.tableReader.RuntimeStats().Record(time.Since(start), chk.NumRows()) } if chk.NumRows() == 0 { - if v, _err_ := failpoint.Eval(_curpkg_("testIndexMergeErrorPartialTableWorker")); _err_ == nil { - return handles, nil, errors.New(v.(string)) - } + failpoint.Inject("testIndexMergeErrorPartialTableWorker", func(v failpoint.Value) { + failpoint.Return(handles, nil, errors.New(v.(string))) + }) return handles, retChk, nil } memDelta := chk.MemoryUsage() @@ -851,12 +851,12 @@ func (e *IndexMergeReaderExecutor) Next(ctx context.Context, req *chunk.Chunk) e } func (e *IndexMergeReaderExecutor) getResultTask(ctx context.Context) (*indexMergeTableTask, error) { - if _, _err_ := failpoint.Eval(_curpkg_("testIndexMergeMainReturnEarly")); _err_ == nil { + failpoint.Inject("testIndexMergeMainReturnEarly", func(_ failpoint.Value) { // To make sure processWorker make resultCh to be full. 
// When main goroutine close finished, processWorker may be stuck when writing resultCh. time.Sleep(time.Second * 20) - return nil, errors.New("failpoint testIndexMergeMainReturnEarly") - } + failpoint.Return(nil, errors.New("failpoint testIndexMergeMainReturnEarly")) + }) if e.resultCurr != nil && e.resultCurr.cursor < len(e.resultCurr.rows) { return e.resultCurr, nil } @@ -1154,9 +1154,9 @@ func (w *indexMergeProcessWorker) fetchLoopUnionWithOrderBy(ctx context.Context, case <-finished: return case resultCh <- task: - if _, _err_ := failpoint.Eval(_curpkg_("testCancelContext")); _err_ == nil { + failpoint.Inject("testCancelContext", func() { IndexMergeCancelFuncForTest() - } + }) select { case <-ctx.Done(): return @@ -1190,14 +1190,14 @@ func pushedLimitCountingDown(pushedLimit *plannercore.PushedDownLimit, handles [ func (w *indexMergeProcessWorker) fetchLoopUnion(ctx context.Context, fetchCh <-chan *indexMergeTableTask, workCh chan<- *indexMergeTableTask, resultCh chan<- *indexMergeTableTask, finished <-chan struct{}) { - if _, _err_ := failpoint.Eval(_curpkg_("testIndexMergeResultChCloseEarly")); _err_ == nil { - return - } + failpoint.Inject("testIndexMergeResultChCloseEarly", func(_ failpoint.Value) { + failpoint.Return() + }) memTracker := memory.NewTracker(w.indexMerge.ID(), -1) memTracker.AttachTo(w.indexMerge.memTracker) defer memTracker.Detach() defer close(workCh) - failpoint.Eval(_curpkg_("testIndexMergePanicProcessWorkerUnion")) + failpoint.Inject("testIndexMergePanicProcessWorkerUnion", nil) var pushedLimit *plannercore.PushedDownLimit if w.indexMerge.pushedLimit != nil { @@ -1267,23 +1267,23 @@ func (w *indexMergeProcessWorker) fetchLoopUnion(ctx context.Context, fetchCh <- if w.stats != nil { w.stats.IndexMergeProcess += time.Since(start) } - if _, _err_ := failpoint.Eval(_curpkg_("testIndexMergeProcessWorkerUnionHang")); _err_ == nil { + failpoint.Inject("testIndexMergeProcessWorkerUnionHang", func(_ failpoint.Value) { for i := 0; i < cap(resultCh); i++ { select { case resultCh <- &indexMergeTableTask{}: default: } } - } + }) select { case <-ctx.Done(): return case <-finished: return case resultCh <- task: - if _, _err_ := failpoint.Eval(_curpkg_("testCancelContext")); _err_ == nil { + failpoint.Inject("testCancelContext", func() { IndexMergeCancelFuncForTest() - } + }) select { case <-ctx.Done(): return @@ -1377,7 +1377,7 @@ func (w *intersectionProcessWorker) consumeMemDelta() { // doIntersectionPerPartition fetch all the task from workerChannel, and after that, then do the intersection pruning, which // will cause wasting a lot of time waiting for all the fetch task done. 
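The intersection case described just below works by counting: every partial path contributes its handles, each handle's occurrences are tallied in a map, and only handles seen by every path survive. A minimal sketch of that idea follows, assuming each path yields a given handle at most once; intersectHandles and the plain int64 handles are simplifications, whereas the real worker uses kv.MemAwareHandleMap and tracks memory per batch.

package main

import "fmt"

// intersectHandles keeps only handles that appear in every partial path.
func intersectHandles(paths [][]int64) []int64 {
	counts := make(map[int64]int, 16)
	for _, handles := range paths {
		for _, h := range handles {
			counts[h]++ // assumes distinct handles within one path
		}
	}
	var out []int64
	for h, c := range counts {
		if c == len(paths) { // present in every partial path
			out = append(out, h)
		}
	}
	return out
}

func main() {
	paths := [][]int64{{1, 2, 3, 5}, {2, 3, 5, 7}, {3, 5, 9}}
	fmt.Println(intersectHandles(paths)) // 3 and 5, in map-iteration order
}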
func (w *intersectionProcessWorker) doIntersectionPerPartition(ctx context.Context, workCh chan<- *indexMergeTableTask, resultCh chan<- *indexMergeTableTask, finished, limitDone <-chan struct{}) { - failpoint.Eval(_curpkg_("testIndexMergePanicPartitionTableIntersectionWorker")) + failpoint.Inject("testIndexMergePanicPartitionTableIntersectionWorker", nil) defer w.memTracker.Detach() for task := range w.workerCh { @@ -1419,7 +1419,7 @@ func (w *intersectionProcessWorker) doIntersectionPerPartition(ctx context.Conte if w.rowDelta >= int64(w.batchSize) { w.consumeMemDelta() } - failpoint.Eval(_curpkg_("testIndexMergeIntersectionWorkerPanic")) + failpoint.Inject("testIndexMergeIntersectionWorkerPanic", nil) } if w.rowDelta > 0 { w.consumeMemDelta() @@ -1460,7 +1460,7 @@ func (w *intersectionProcessWorker) doIntersectionPerPartition(ctx context.Conte zap.Int("parTblIdx", parTblIdx), zap.Int("task.handles", len(task.handles))) } } - if _, _err_ := failpoint.Eval(_curpkg_("testIndexMergeProcessWorkerIntersectionHang")); _err_ == nil { + failpoint.Inject("testIndexMergeProcessWorkerIntersectionHang", func(_ failpoint.Value) { if resultCh != nil { for i := 0; i < cap(resultCh); i++ { select { @@ -1469,7 +1469,7 @@ func (w *intersectionProcessWorker) doIntersectionPerPartition(ctx context.Conte } } } - } + }) for _, task := range tasks { select { case <-ctx.Done(): @@ -1534,7 +1534,7 @@ func (w *indexMergeProcessWorker) fetchLoopIntersection(ctx context.Context, fet }() } - failpoint.Eval(_curpkg_("testIndexMergePanicProcessWorkerIntersection")) + failpoint.Inject("testIndexMergePanicProcessWorkerIntersection", nil) // One goroutine may handle one or multiple partitions. // Max number of partition number is 8192, we use ExecutorConcurrency to avoid too many goroutines. @@ -1548,12 +1548,12 @@ func (w *indexMergeProcessWorker) fetchLoopIntersection(ctx context.Context, fet partCnt = len(w.indexMerge.prunedPartitions) } workerCnt := min(partCnt, maxWorkerCnt) - if val, _err_ := failpoint.Eval(_curpkg_("testIndexMergeIntersectionConcurrency")); _err_ == nil { + failpoint.Inject("testIndexMergeIntersectionConcurrency", func(val failpoint.Value) { con := val.(int) if con != workerCnt { panic(fmt.Sprintf("unexpected workerCnt, expect %d, got %d", con, workerCnt)) } - } + }) partitionIDMap := make(map[int64]int) if w.indexMerge.hasGlobalIndex { @@ -1803,9 +1803,9 @@ func (w *partialIndexWorker) extractTaskHandles(ctx context.Context, chk *chunk. w.sc.GetSessionVars().StmtCtx.RuntimeStatsColl.GetBasicRuntimeStats(w.idxID).Record(time.Since(start), chk.NumRows()) } if chk.NumRows() == 0 { - if v, _err_ := failpoint.Eval(_curpkg_("testIndexMergeErrorPartialIndexWorker")); _err_ == nil { - return handles, nil, errors.New(v.(string)) - } + failpoint.Inject("testIndexMergeErrorPartialIndexWorker", func(v failpoint.Value) { + failpoint.Return(handles, nil, errors.New(v.(string))) + }) return handles, retChk, nil } memDelta := chk.MemoryUsage() @@ -1891,7 +1891,7 @@ func (w *indexMergeTableScanWorker) pickAndExecTask(ctx context.Context, task ** } // Make sure panic failpoint is after fetch task from workCh. // Otherwise, cannot send error to task.doneCh. 
- failpoint.Eval(_curpkg_("testIndexMergePanicTableScanWorker")) + failpoint.Inject("testIndexMergePanicTableScanWorker", nil) execStart := time.Now() err := w.executeTask(ctx, *task) if w.stats != nil { @@ -1899,7 +1899,7 @@ func (w *indexMergeTableScanWorker) pickAndExecTask(ctx context.Context, task ** atomic.AddInt64(&w.stats.FetchRow, int64(time.Since(execStart))) atomic.AddInt64(&w.stats.TableTaskNum, 1) } - failpoint.Eval(_curpkg_("testIndexMergePickAndExecTaskPanic")) + failpoint.Inject("testIndexMergePickAndExecTaskPanic", nil) select { case <-ctx.Done(): return diff --git a/pkg/executor/index_merge_reader.go__failpoint_stash__ b/pkg/executor/index_merge_reader.go__failpoint_stash__ deleted file mode 100644 index 84261164a5925..0000000000000 --- a/pkg/executor/index_merge_reader.go__failpoint_stash__ +++ /dev/null @@ -1,2056 +0,0 @@ -// Copyright 2019 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package executor - -import ( - "bytes" - "cmp" - "container/heap" - "context" - "fmt" - "runtime/trace" - "slices" - "sort" - "sync" - "sync/atomic" - "time" - "unsafe" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/distsql" - "github.com/pingcap/tidb/pkg/executor/internal/builder" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/planner/core/base" - plannerutil "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/channel" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "github.com/pingcap/tidb/pkg/util/ranger" - "github.com/pingcap/tipb/go-tipb" - "go.uber.org/zap" -) - -var ( - _ exec.Executor = &IndexMergeReaderExecutor{} - - // IndexMergeCancelFuncForTest is used just for test - IndexMergeCancelFuncForTest func() -) - -const ( - partialIndexWorkerType = "IndexMergePartialIndexWorker" - partialTableWorkerType = "IndexMergePartialTableWorker" - processWorkerType = "IndexMergeProcessWorker" - partTblIntersectionWorkerType = "IndexMergePartTblIntersectionWorker" - tableScanWorkerType = "IndexMergeTableScanWorker" -) - -// IndexMergeReaderExecutor accesses a table with multiple index/table scan. -// There are three types of workers: -// 1. partialTableWorker/partialIndexWorker, which are used to fetch the handles -// 2. indexMergeProcessWorker, which is used to do the `Union` operation. -// 3. indexMergeTableScanWorker, which is used to get the table tuples with the given handles. 
-// -// The execution flow is really like IndexLookUpReader. However, it uses multiple index scans -// or table scans to get the handles: -// 1. use the partialTableWorkers and partialIndexWorkers to fetch the handles (a batch per time) -// and send them to the indexMergeProcessWorker. -// 2. indexMergeProcessWorker do the `Union` operation for a batch of handles it have got. -// For every handle in the batch: -// 1. check whether it has been accessed. -// 2. if not, record it and send it to the indexMergeTableScanWorker. -// 3. if accessed, just ignore it. -type IndexMergeReaderExecutor struct { - exec.BaseExecutor - indexUsageReporter *exec.IndexUsageReporter - - table table.Table - indexes []*model.IndexInfo - descs []bool - ranges [][]*ranger.Range - dagPBs []*tipb.DAGRequest - startTS uint64 - tableRequest *tipb.DAGRequest - - keepOrder bool - pushedLimit *plannercore.PushedDownLimit - byItems []*plannerutil.ByItems - - // columns are only required by union scan. - columns []*model.ColumnInfo - // partitionIDMap are only required by union scan with global index. - partitionIDMap map[int64]struct{} - *dataReaderBuilder - - // fields about accessing partition tables - partitionTableMode bool // if this IndexMerge is accessing a partition table - prunedPartitions []table.PhysicalTable // pruned partition tables need to access - partitionKeyRanges [][][]kv.KeyRange // [partialIndex][partitionIdx][ranges] - - // All fields above are immutable. - - tblWorkerWg sync.WaitGroup - idxWorkerWg sync.WaitGroup - processWorkerWg sync.WaitGroup - finished chan struct{} - - workerStarted bool - keyRanges [][]kv.KeyRange - - resultCh chan *indexMergeTableTask - resultCurr *indexMergeTableTask - - // memTracker is used to track the memory usage of this executor. - memTracker *memory.Tracker - - partialPlans [][]base.PhysicalPlan - tblPlans []base.PhysicalPlan - partialNetDataSizes []float64 - dataAvgRowSize float64 - - handleCols plannerutil.HandleCols - stats *IndexMergeRuntimeStat - - // Indicates whether there is correlated column in filter or table/index range. - // We need to refresh dagPBs before send DAGReq to storage. - isCorColInPartialFilters []bool - isCorColInTableFilter bool - isCorColInPartialAccess []bool - - // Whether it's intersection or union. - isIntersection bool - - hasGlobalIndex bool -} - -type indexMergeTableTask struct { - lookupTableTask - - // parTblIdx are only used in indexMergeProcessWorker.fetchLoopIntersection. - parTblIdx int - - // partialPlanID are only used for indexMergeProcessWorker.fetchLoopUnionWithOrderBy. - partialPlanID int -} - -// Table implements the dataSourceExecutor interface. 
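The execution-flow comment above describes a three-stage pipeline for the union case: partial workers fetch handle batches, a single process worker deduplicates them, and deduplicated batches feed the table-scan stage. The sketch below models that flow with plain channels and int64 handles; the names fetchCh/workCh and the fixed input batches are illustrative simplifications of the real task and worker structure.

package main

import (
	"fmt"
	"sync"
)

func main() {
	partialResults := [][]int64{{1, 2, 3}, {2, 3, 4}, {4, 5}}
	fetchCh := make(chan []int64, len(partialResults))
	workCh := make(chan []int64, len(partialResults))

	// Stage 1: partial workers push their handle batches, like the partial
	// index/table workers pushing indexMergeTableTasks.
	var wg sync.WaitGroup
	for _, batch := range partialResults {
		wg.Add(1)
		go func(b []int64) {
			defer wg.Done()
			fetchCh <- b
		}(batch)
	}
	go func() { wg.Wait(); close(fetchCh) }()

	// Stage 2: process worker applies union semantics, keeping the first
	// occurrence of each handle.
	go func() {
		defer close(workCh)
		seen := make(map[int64]struct{})
		for batch := range fetchCh {
			var dedup []int64
			for _, h := range batch {
				if _, ok := seen[h]; !ok {
					seen[h] = struct{}{}
					dedup = append(dedup, h)
				}
			}
			if len(dedup) > 0 {
				workCh <- dedup
			}
		}
	}()

	// Stage 3: in the real executor these handles drive TableRowIDScan.
	for batch := range workCh {
		fmt.Println("lookup rows for handles:", batch)
	}
}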
-func (e *IndexMergeReaderExecutor) Table() table.Table { - return e.table -} - -// Open implements the Executor Open interface -func (e *IndexMergeReaderExecutor) Open(_ context.Context) (err error) { - e.keyRanges = make([][]kv.KeyRange, 0, len(e.partialPlans)) - e.initRuntimeStats() - if e.isCorColInTableFilter { - e.tableRequest.Executors, err = builder.ConstructListBasedDistExec(e.Ctx().GetBuildPBCtx(), e.tblPlans) - if err != nil { - return err - } - } - if err = e.rebuildRangeForCorCol(); err != nil { - return err - } - - if !e.partitionTableMode { - if e.keyRanges, err = e.buildKeyRangesForTable(e.table); err != nil { - return err - } - } else { - e.partitionKeyRanges = make([][][]kv.KeyRange, len(e.indexes)) - tmpPartitionKeyRanges := make([][][]kv.KeyRange, len(e.prunedPartitions)) - for i, p := range e.prunedPartitions { - if tmpPartitionKeyRanges[i], err = e.buildKeyRangesForTable(p); err != nil { - return err - } - } - for i, idx := range e.indexes { - if idx != nil && idx.Global { - keyRange, _ := distsql.IndexRangesToKVRanges(e.ctx.GetDistSQLCtx(), e.table.Meta().ID, idx.ID, e.ranges[i]) - e.partitionKeyRanges[i] = [][]kv.KeyRange{keyRange.FirstPartitionRange()} - } else { - for _, pKeyRanges := range tmpPartitionKeyRanges { - e.partitionKeyRanges[i] = append(e.partitionKeyRanges[i], pKeyRanges[i]) - } - } - } - } - e.finished = make(chan struct{}) - e.resultCh = make(chan *indexMergeTableTask, atomic.LoadInt32(&LookupTableTaskChannelSize)) - if e.memTracker != nil { - e.memTracker.Reset() - } else { - e.memTracker = memory.NewTracker(e.ID(), -1) - } - e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) - return nil -} - -func (e *IndexMergeReaderExecutor) rebuildRangeForCorCol() (err error) { - len1 := len(e.partialPlans) - len2 := len(e.isCorColInPartialAccess) - if len1 != len2 { - return errors.Errorf("unexpect length for partialPlans(%d) and isCorColInPartialAccess(%d)", len1, len2) - } - for i, plan := range e.partialPlans { - if e.isCorColInPartialAccess[i] { - switch x := plan[0].(type) { - case *plannercore.PhysicalIndexScan: - e.ranges[i], err = rebuildIndexRanges(e.Ctx().GetExprCtx(), e.Ctx().GetRangerCtx(), x, x.IdxCols, x.IdxColLens) - case *plannercore.PhysicalTableScan: - e.ranges[i], err = x.ResolveCorrelatedColumns() - default: - err = errors.Errorf("unsupported plan type %T", plan[0]) - } - if err != nil { - return err - } - } - } - return nil -} - -func (e *IndexMergeReaderExecutor) buildKeyRangesForTable(tbl table.Table) (ranges [][]kv.KeyRange, err error) { - dctx := e.Ctx().GetDistSQLCtx() - for i, plan := range e.partialPlans { - _, ok := plan[0].(*plannercore.PhysicalIndexScan) - if !ok { - firstPartRanges, secondPartRanges := distsql.SplitRangesAcrossInt64Boundary(e.ranges[i], false, e.descs[i], tbl.Meta().IsCommonHandle) - firstKeyRanges, err := distsql.TableHandleRangesToKVRanges(dctx, []int64{getPhysicalTableID(tbl)}, tbl.Meta().IsCommonHandle, firstPartRanges) - if err != nil { - return nil, err - } - secondKeyRanges, err := distsql.TableHandleRangesToKVRanges(dctx, []int64{getPhysicalTableID(tbl)}, tbl.Meta().IsCommonHandle, secondPartRanges) - if err != nil { - return nil, err - } - keyRanges := append(firstKeyRanges.FirstPartitionRange(), secondKeyRanges.FirstPartitionRange()...) 
- ranges = append(ranges, keyRanges) - continue - } - keyRange, err := distsql.IndexRangesToKVRanges(dctx, getPhysicalTableID(tbl), e.indexes[i].ID, e.ranges[i]) - if err != nil { - return nil, err - } - ranges = append(ranges, keyRange.FirstPartitionRange()) - } - return ranges, nil -} - -func (e *IndexMergeReaderExecutor) startWorkers(ctx context.Context) error { - exitCh := make(chan struct{}) - workCh := make(chan *indexMergeTableTask, 1) - fetchCh := make(chan *indexMergeTableTask, len(e.keyRanges)) - - e.startIndexMergeProcessWorker(ctx, workCh, fetchCh) - - var err error - for i := 0; i < len(e.partialPlans); i++ { - e.idxWorkerWg.Add(1) - if e.indexes[i] != nil { - err = e.startPartialIndexWorker(ctx, exitCh, fetchCh, i) - } else { - err = e.startPartialTableWorker(ctx, exitCh, fetchCh, i) - } - if err != nil { - e.idxWorkerWg.Done() - break - } - } - go e.waitPartialWorkersAndCloseFetchChan(fetchCh) - if err != nil { - close(exitCh) - return err - } - e.startIndexMergeTableScanWorker(ctx, workCh) - e.workerStarted = true - return nil -} - -func (e *IndexMergeReaderExecutor) waitPartialWorkersAndCloseFetchChan(fetchCh chan *indexMergeTableTask) { - e.idxWorkerWg.Wait() - close(fetchCh) -} - -func (e *IndexMergeReaderExecutor) startIndexMergeProcessWorker(ctx context.Context, workCh chan<- *indexMergeTableTask, fetch <-chan *indexMergeTableTask) { - idxMergeProcessWorker := &indexMergeProcessWorker{ - indexMerge: e, - stats: e.stats, - } - e.processWorkerWg.Add(1) - go func() { - defer trace.StartRegion(ctx, "IndexMergeProcessWorker").End() - util.WithRecovery( - func() { - if e.isIntersection { - if e.keepOrder { - // todo: implementing fetchLoopIntersectionWithOrderBy if necessary. - panic("Not support intersection with keepOrder = true") - } - idxMergeProcessWorker.fetchLoopIntersection(ctx, fetch, workCh, e.resultCh, e.finished) - } else if len(e.byItems) != 0 { - idxMergeProcessWorker.fetchLoopUnionWithOrderBy(ctx, fetch, workCh, e.resultCh, e.finished) - } else { - idxMergeProcessWorker.fetchLoopUnion(ctx, fetch, workCh, e.resultCh, e.finished) - } - }, - handleWorkerPanic(ctx, e.finished, nil, e.resultCh, nil, processWorkerType), - ) - e.processWorkerWg.Done() - }() -} - -func (e *IndexMergeReaderExecutor) startPartialIndexWorker(ctx context.Context, exitCh <-chan struct{}, fetchCh chan<- *indexMergeTableTask, workID int) error { - failpoint.Inject("testIndexMergeResultChCloseEarly", func(_ failpoint.Value) { - // Wait for processWorker to close resultCh. - time.Sleep(time.Second * 2) - // Should use fetchCh instead of resultCh to send error. - syncErr(ctx, e.finished, fetchCh, errors.New("testIndexMergeResultChCloseEarly")) - }) - if e.RuntimeStats() != nil { - collExec := true - e.dagPBs[workID].CollectExecutionSummaries = &collExec - } - - var keyRanges [][]kv.KeyRange - if e.partitionTableMode { - keyRanges = e.partitionKeyRanges[workID] - } else { - keyRanges = [][]kv.KeyRange{e.keyRanges[workID]} - } - failpoint.Inject("startPartialIndexWorkerErr", func() error { - return errors.New("inject an error before start partialIndexWorker") - }) - - // for union case, the push-downLimit can be utilized to limit index fetched handles. - // for intersection case, the push-downLimit can only be conducted after all index path/table finished. 
- pushedIndexLimit := e.pushedLimit - if e.isIntersection { - pushedIndexLimit = nil - } - - go func() { - defer trace.StartRegion(ctx, "IndexMergePartialIndexWorker").End() - defer e.idxWorkerWg.Done() - util.WithRecovery( - func() { - failpoint.Inject("testIndexMergePanicPartialIndexWorker", nil) - is := e.partialPlans[workID][0].(*plannercore.PhysicalIndexScan) - worker := &partialIndexWorker{ - stats: e.stats, - idxID: e.getPartitalPlanID(workID), - sc: e.Ctx(), - dagPB: e.dagPBs[workID], - plan: e.partialPlans[workID], - batchSize: e.MaxChunkSize(), - maxBatchSize: e.Ctx().GetSessionVars().IndexLookupSize, - maxChunkSize: e.MaxChunkSize(), - memTracker: e.memTracker, - partitionTableMode: e.partitionTableMode, - prunedPartitions: e.prunedPartitions, - byItems: is.ByItems, - pushedLimit: pushedIndexLimit, - } - if e.isCorColInPartialFilters[workID] { - // We got correlated column, so need to refresh Selection operator. - var err error - if e.dagPBs[workID].Executors, err = builder.ConstructListBasedDistExec(e.Ctx().GetBuildPBCtx(), e.partialPlans[workID]); err != nil { - syncErr(ctx, e.finished, fetchCh, err) - return - } - } - var builder distsql.RequestBuilder - builder.SetDAGRequest(e.dagPBs[workID]). - SetStartTS(e.startTS). - SetDesc(e.descs[workID]). - SetKeepOrder(e.keepOrder). - SetTxnScope(e.txnScope). - SetReadReplicaScope(e.readReplicaScope). - SetIsStaleness(e.isStaleness). - SetFromSessionVars(e.Ctx().GetDistSQLCtx()). - SetMemTracker(e.memTracker). - SetFromInfoSchema(e.Ctx().GetInfoSchema()). - SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.Ctx().GetDistSQLCtx(), &builder.Request, e.partialNetDataSizes[workID])). - SetConnIDAndConnAlias(e.Ctx().GetSessionVars().ConnectionID, e.Ctx().GetSessionVars().SessionAlias) - - worker.batchSize = CalculateBatchSize(int(is.StatsCount()), e.MaxChunkSize(), worker.maxBatchSize) - if builder.Request.Paging.Enable && builder.Request.Paging.MinPagingSize < uint64(worker.batchSize) { - // when paging enabled and Paging.MinPagingSize less than initBatchSize, change Paging.MinPagingSize to - // initial batchSize to avoid redundant paging RPC, see more detail in https://github.com/pingcap/tidb/issues/54066 - builder.Request.Paging.MinPagingSize = uint64(worker.batchSize) - if builder.Request.Paging.MaxPagingSize < uint64(worker.batchSize) { - builder.Request.Paging.MaxPagingSize = uint64(worker.batchSize) - } - } - tps := worker.getRetTpsForIndexScan(e.handleCols) - results := make([]distsql.SelectResult, 0, len(keyRanges)) - defer func() { - // To make sure SelectResult.Close() is called even got panic in fetchHandles(). - for _, result := range results { - if err := result.Close(); err != nil { - logutil.Logger(ctx).Error("close Select result failed", zap.Error(err)) - } - } - }() - for _, keyRange := range keyRanges { - // check if this executor is closed - select { - case <-ctx.Done(): - return - case <-e.finished: - return - default: - } - - // init kvReq and worker for this partition - // The key ranges should be ordered. 
- slices.SortFunc(keyRange, func(i, j kv.KeyRange) int { - return bytes.Compare(i.StartKey, j.StartKey) - }) - kvReq, err := builder.SetKeyRanges(keyRange).Build() - if err != nil { - syncErr(ctx, e.finished, fetchCh, err) - return - } - result, err := distsql.SelectWithRuntimeStats(ctx, e.Ctx().GetDistSQLCtx(), kvReq, tps, getPhysicalPlanIDs(e.partialPlans[workID]), e.getPartitalPlanID(workID)) - if err != nil { - syncErr(ctx, e.finished, fetchCh, err) - return - } - results = append(results, result) - failpoint.Inject("testIndexMergePartialIndexWorkerCoprLeak", nil) - } - if len(results) > 1 && len(e.byItems) != 0 { - // e.Schema() not the output schema for partialIndexReader, and we put byItems related column at first in `buildIndexReq`, so use nil here. - ssr := distsql.NewSortedSelectResults(e.Ctx().GetExprCtx().GetEvalCtx(), results, nil, e.byItems, e.memTracker) - results = []distsql.SelectResult{ssr} - } - ctx1, cancel := context.WithCancel(ctx) - // this error is reported in fetchHandles(), so ignore it here. - _, _ = worker.fetchHandles(ctx1, results, exitCh, fetchCh, e.finished, e.handleCols, workID) - cancel() - }, - handleWorkerPanic(ctx, e.finished, nil, fetchCh, nil, partialIndexWorkerType), - ) - }() - - return nil -} - -func (e *IndexMergeReaderExecutor) startPartialTableWorker(ctx context.Context, exitCh <-chan struct{}, fetchCh chan<- *indexMergeTableTask, workID int) error { - ts := e.partialPlans[workID][0].(*plannercore.PhysicalTableScan) - - tbls := make([]table.Table, 0, 1) - if e.partitionTableMode && len(e.byItems) == 0 { - for _, p := range e.prunedPartitions { - tbls = append(tbls, p) - } - } else { - tbls = append(tbls, e.table) - } - - // for union case, the push-downLimit can be utilized to limit index fetched handles. - // for intersection case, the push-downLimit can only be conducted after all index/table path finished. 
- pushedTableLimit := e.pushedLimit - if e.isIntersection { - pushedTableLimit = nil - } - - go func() { - defer trace.StartRegion(ctx, "IndexMergePartialTableWorker").End() - defer e.idxWorkerWg.Done() - util.WithRecovery( - func() { - failpoint.Inject("testIndexMergePanicPartialTableWorker", nil) - var err error - partialTableReader := &TableReaderExecutor{ - BaseExecutorV2: exec.NewBaseExecutorV2(e.Ctx().GetSessionVars(), ts.Schema(), e.getPartitalPlanID(workID)), - tableReaderExecutorContext: newTableReaderExecutorContext(e.Ctx()), - dagPB: e.dagPBs[workID], - startTS: e.startTS, - txnScope: e.txnScope, - readReplicaScope: e.readReplicaScope, - isStaleness: e.isStaleness, - plans: e.partialPlans[workID], - ranges: e.ranges[workID], - netDataSize: e.partialNetDataSizes[workID], - keepOrder: ts.KeepOrder, - byItems: ts.ByItems, - } - - worker := &partialTableWorker{ - stats: e.stats, - sc: e.Ctx(), - batchSize: e.MaxChunkSize(), - maxBatchSize: e.Ctx().GetSessionVars().IndexLookupSize, - maxChunkSize: e.MaxChunkSize(), - tableReader: partialTableReader, - memTracker: e.memTracker, - partitionTableMode: e.partitionTableMode, - prunedPartitions: e.prunedPartitions, - byItems: ts.ByItems, - pushedLimit: pushedTableLimit, - } - - if len(e.prunedPartitions) != 0 && len(e.byItems) != 0 { - slices.SortFunc(worker.prunedPartitions, func(i, j table.PhysicalTable) int { - return cmp.Compare(i.GetPhysicalID(), j.GetPhysicalID()) - }) - partialTableReader.kvRangeBuilder = kvRangeBuilderFromRangeAndPartition{ - partitions: worker.prunedPartitions, - } - } - - if e.isCorColInPartialFilters[workID] { - if e.dagPBs[workID].Executors, err = builder.ConstructListBasedDistExec(e.Ctx().GetBuildPBCtx(), e.partialPlans[workID]); err != nil { - syncErr(ctx, e.finished, fetchCh, err) - return - } - partialTableReader.dagPB = e.dagPBs[workID] - } - - var tableReaderClosed bool - defer func() { - // To make sure SelectResult.Close() is called even got panic in fetchHandles(). - if !tableReaderClosed { - terror.Log(exec.Close(worker.tableReader)) - } - }() - for parTblIdx, tbl := range tbls { - // check if this executor is closed - select { - case <-ctx.Done(): - return - case <-e.finished: - return - default: - } - - // init partialTableReader and partialTableWorker again for the next table - partialTableReader.table = tbl - if err = exec.Open(ctx, partialTableReader); err != nil { - logutil.Logger(ctx).Error("open Select result failed:", zap.Error(err)) - syncErr(ctx, e.finished, fetchCh, err) - break - } - failpoint.Inject("testIndexMergePartialTableWorkerCoprLeak", nil) - tableReaderClosed = false - worker.batchSize = e.MaxChunkSize() - if worker.batchSize > worker.maxBatchSize { - worker.batchSize = worker.maxBatchSize - } - - // fetch all handles from this table - ctx1, cancel := context.WithCancel(ctx) - _, fetchErr := worker.fetchHandles(ctx1, exitCh, fetchCh, e.finished, e.handleCols, parTblIdx, workID) - // release related resources - cancel() - tableReaderClosed = true - if err = exec.Close(worker.tableReader); err != nil { - logutil.Logger(ctx).Error("close Select result failed:", zap.Error(err)) - } - // this error is reported in fetchHandles(), so ignore it here. 
- if fetchErr != nil { - break - } - } - }, - handleWorkerPanic(ctx, e.finished, nil, fetchCh, nil, partialTableWorkerType), - ) - }() - return nil -} - -func (e *IndexMergeReaderExecutor) initRuntimeStats() { - if e.RuntimeStats() != nil { - e.stats = &IndexMergeRuntimeStat{ - Concurrency: e.Ctx().GetSessionVars().IndexLookupConcurrency(), - } - } -} - -func (e *IndexMergeReaderExecutor) getPartitalPlanID(workID int) int { - if len(e.partialPlans[workID]) > 0 { - return e.partialPlans[workID][len(e.partialPlans[workID])-1].ID() - } - return 0 -} - -func (e *IndexMergeReaderExecutor) getTablePlanRootID() int { - if len(e.tblPlans) > 0 { - return e.tblPlans[len(e.tblPlans)-1].ID() - } - return e.ID() -} - -type partialTableWorker struct { - stats *IndexMergeRuntimeStat - sc sessionctx.Context - batchSize int - maxBatchSize int - maxChunkSize int - tableReader exec.Executor - memTracker *memory.Tracker - partitionTableMode bool - prunedPartitions []table.PhysicalTable - byItems []*plannerutil.ByItems - scannedKeys uint64 - pushedLimit *plannercore.PushedDownLimit -} - -// needPartitionHandle indicates whether we need create a partitionHandle or not. -// If the schema from planner part contains ExtraPhysTblID, -// we need create a partitionHandle, otherwise create a normal handle. -// In TableRowIDScan, the partitionHandle will be used to create key ranges. -func (w *partialTableWorker) needPartitionHandle() (bool, error) { - cols := w.tableReader.(*TableReaderExecutor).plans[0].Schema().Columns - outputOffsets := w.tableReader.(*TableReaderExecutor).dagPB.OutputOffsets - col := cols[outputOffsets[len(outputOffsets)-1]] - - needPartitionHandle := w.partitionTableMode && len(w.byItems) > 0 - hasExtraCol := col.ID == model.ExtraPhysTblID - - // There will be two needPartitionHandle != hasExtraCol situations. - // Only `needPartitionHandle` == true and `hasExtraCol` == false are not allowed. - // `ExtraPhysTblID` will be used in `SelectLock` when `needPartitionHandle` == false and `hasExtraCol` == true. 
- if needPartitionHandle && !hasExtraCol { - return needPartitionHandle, errors.Errorf("Internal error, needPartitionHandle != ret") - } - return needPartitionHandle, nil -} - -func (w *partialTableWorker) fetchHandles(ctx context.Context, exitCh <-chan struct{}, fetchCh chan<- *indexMergeTableTask, - finished <-chan struct{}, handleCols plannerutil.HandleCols, parTblIdx int, partialPlanIndex int) (count int64, err error) { - chk := w.tableReader.NewChunkWithCapacity(w.getRetTpsForTableScan(), w.maxChunkSize, w.maxBatchSize) - for { - start := time.Now() - handles, retChunk, err := w.extractTaskHandles(ctx, chk, handleCols) - if err != nil { - syncErr(ctx, finished, fetchCh, err) - return count, err - } - if len(handles) == 0 { - return count, nil - } - count += int64(len(handles)) - task := w.buildTableTask(handles, retChunk, parTblIdx, partialPlanIndex) - if w.stats != nil { - atomic.AddInt64(&w.stats.FetchIdxTime, int64(time.Since(start))) - } - select { - case <-ctx.Done(): - return count, ctx.Err() - case <-exitCh: - return count, nil - case <-finished: - return count, nil - case fetchCh <- task: - } - } -} - -func (w *partialTableWorker) getRetTpsForTableScan() []*types.FieldType { - return exec.RetTypes(w.tableReader) -} - -func (w *partialTableWorker) extractTaskHandles(ctx context.Context, chk *chunk.Chunk, handleCols plannerutil.HandleCols) ( - handles []kv.Handle, retChk *chunk.Chunk, err error) { - handles = make([]kv.Handle, 0, w.batchSize) - if len(w.byItems) != 0 { - retChk = chunk.NewChunkWithCapacity(w.getRetTpsForTableScan(), w.batchSize) - } - var memUsage int64 - var chunkRowOffset int - defer w.memTracker.Consume(-memUsage) - for len(handles) < w.batchSize { - requiredRows := w.batchSize - len(handles) - if w.pushedLimit != nil { - if w.pushedLimit.Offset+w.pushedLimit.Count <= w.scannedKeys { - return handles, retChk, nil - } - requiredRows = min(int(w.pushedLimit.Offset+w.pushedLimit.Count-w.scannedKeys), requiredRows) - } - chk.SetRequiredRows(requiredRows, w.maxChunkSize) - start := time.Now() - err = errors.Trace(w.tableReader.Next(ctx, chk)) - if err != nil { - return nil, nil, err - } - if w.tableReader != nil && w.tableReader.RuntimeStats() != nil { - w.tableReader.RuntimeStats().Record(time.Since(start), chk.NumRows()) - } - if chk.NumRows() == 0 { - failpoint.Inject("testIndexMergeErrorPartialTableWorker", func(v failpoint.Value) { - failpoint.Return(handles, nil, errors.New(v.(string))) - }) - return handles, retChk, nil - } - memDelta := chk.MemoryUsage() - memUsage += memDelta - w.memTracker.Consume(memDelta) - for chunkRowOffset = 0; chunkRowOffset < chk.NumRows(); chunkRowOffset++ { - if w.pushedLimit != nil { - w.scannedKeys++ - if w.scannedKeys > (w.pushedLimit.Offset + w.pushedLimit.Count) { - // Skip the handles after Offset+Count. 
- break - } - } - var handle kv.Handle - ok, err1 := w.needPartitionHandle() - if err1 != nil { - return nil, nil, err1 - } - if ok { - handle, err = handleCols.BuildPartitionHandleFromIndexRow(chk.GetRow(chunkRowOffset)) - } else { - handle, err = handleCols.BuildHandleFromIndexRow(chk.GetRow(chunkRowOffset)) - } - if err != nil { - return nil, nil, err - } - handles = append(handles, handle) - } - // used for order by - if len(w.byItems) != 0 { - retChk.Append(chk, 0, chunkRowOffset) - } - } - w.batchSize *= 2 - if w.batchSize > w.maxBatchSize { - w.batchSize = w.maxBatchSize - } - return handles, retChk, nil -} - -func (w *partialTableWorker) buildTableTask(handles []kv.Handle, retChk *chunk.Chunk, parTblIdx int, partialPlanID int) *indexMergeTableTask { - task := &indexMergeTableTask{ - lookupTableTask: lookupTableTask{ - handles: handles, - idxRows: retChk, - }, - parTblIdx: parTblIdx, - partialPlanID: partialPlanID, - } - - if w.prunedPartitions != nil { - task.partitionTable = w.prunedPartitions[parTblIdx] - } - - task.doneCh = make(chan error, 1) - return task -} - -func (e *IndexMergeReaderExecutor) startIndexMergeTableScanWorker(ctx context.Context, workCh <-chan *indexMergeTableTask) { - lookupConcurrencyLimit := e.Ctx().GetSessionVars().IndexLookupConcurrency() - e.tblWorkerWg.Add(lookupConcurrencyLimit) - for i := 0; i < lookupConcurrencyLimit; i++ { - worker := &indexMergeTableScanWorker{ - stats: e.stats, - workCh: workCh, - finished: e.finished, - indexMergeExec: e, - tblPlans: e.tblPlans, - memTracker: e.memTracker, - } - ctx1, cancel := context.WithCancel(ctx) - go func() { - defer trace.StartRegion(ctx, "IndexMergeTableScanWorker").End() - var task *indexMergeTableTask - util.WithRecovery( - // Note we use the address of `task` as the argument of both `pickAndExecTask` and `handleTableScanWorkerPanic` - // because `task` is expected to be assigned in `pickAndExecTask`, and this assignment should also be visible - // in `handleTableScanWorkerPanic` since it will get `doneCh` from `task`. Golang always pass argument by value, - // so if we don't use the address of `task` as the argument, the assignment to `task` in `pickAndExecTask` is - // not visible in `handleTableScanWorkerPanic` - func() { worker.pickAndExecTask(ctx1, &task) }, - worker.handleTableScanWorkerPanic(ctx1, e.finished, &task, tableScanWorkerType), - ) - cancel() - e.tblWorkerWg.Done() - }() - } -} - -func (e *IndexMergeReaderExecutor) buildFinalTableReader(ctx context.Context, tbl table.Table, handles []kv.Handle) (_ exec.Executor, err error) { - tableReaderExec := &TableReaderExecutor{ - BaseExecutorV2: exec.NewBaseExecutorV2(e.Ctx().GetSessionVars(), e.Schema(), e.getTablePlanRootID()), - tableReaderExecutorContext: newTableReaderExecutorContext(e.Ctx()), - table: tbl, - dagPB: e.tableRequest, - startTS: e.startTS, - txnScope: e.txnScope, - readReplicaScope: e.readReplicaScope, - isStaleness: e.isStaleness, - columns: e.columns, - plans: e.tblPlans, - netDataSize: e.dataAvgRowSize * float64(len(handles)), - } - tableReaderExec.buildVirtualColumnInfo() - // Reorder handles because SplitKeyRangesByLocationsWith/WithoutBuckets() requires startKey of kvRanges is ordered. - // Also it's good for performance. 
- tableReader, err := e.dataReaderBuilder.buildTableReaderFromHandles(ctx, tableReaderExec, handles, true) - if err != nil { - logutil.Logger(ctx).Error("build table reader from handles failed", zap.Error(err)) - return nil, err - } - return tableReader, nil -} - -// Next implements Executor Next interface. -func (e *IndexMergeReaderExecutor) Next(ctx context.Context, req *chunk.Chunk) error { - if !e.workerStarted { - if err := e.startWorkers(ctx); err != nil { - return err - } - } - - req.Reset() - for { - resultTask, err := e.getResultTask(ctx) - if err != nil { - return errors.Trace(err) - } - if resultTask == nil { - return nil - } - if resultTask.cursor < len(resultTask.rows) { - numToAppend := min(len(resultTask.rows)-resultTask.cursor, e.MaxChunkSize()-req.NumRows()) - req.AppendRows(resultTask.rows[resultTask.cursor : resultTask.cursor+numToAppend]) - resultTask.cursor += numToAppend - if req.NumRows() >= e.MaxChunkSize() { - return nil - } - } - } -} - -func (e *IndexMergeReaderExecutor) getResultTask(ctx context.Context) (*indexMergeTableTask, error) { - failpoint.Inject("testIndexMergeMainReturnEarly", func(_ failpoint.Value) { - // To make sure processWorker make resultCh to be full. - // When main goroutine close finished, processWorker may be stuck when writing resultCh. - time.Sleep(time.Second * 20) - failpoint.Return(nil, errors.New("failpoint testIndexMergeMainReturnEarly")) - }) - if e.resultCurr != nil && e.resultCurr.cursor < len(e.resultCurr.rows) { - return e.resultCurr, nil - } - task, ok := <-e.resultCh - if !ok { - return nil, nil - } - - select { - case <-ctx.Done(): - return nil, errors.Trace(ctx.Err()) - case err := <-task.doneCh: - if err != nil { - return nil, errors.Trace(err) - } - } - - // Release the memory usage of last task before we handle a new task. - if e.resultCurr != nil { - e.resultCurr.memTracker.Consume(-e.resultCurr.memUsage) - } - e.resultCurr = task - return e.resultCurr, nil -} - -func handleWorkerPanic(ctx context.Context, finished, limitDone <-chan struct{}, ch chan<- *indexMergeTableTask, extraNotifyCh chan bool, worker string) func(r any) { - return func(r any) { - if worker == processWorkerType { - // There is only one processWorker, so it's safe to close here. - // No need to worry about "close on closed channel" error. - defer close(ch) - } - if r == nil { - logutil.BgLogger().Debug("worker finish without panic", zap.Any("worker", worker)) - return - } - - if extraNotifyCh != nil { - extraNotifyCh <- true - } - - err4Panic := util.GetRecoverError(r) - logutil.Logger(ctx).Error(err4Panic.Error()) - doneCh := make(chan error, 1) - doneCh <- err4Panic - task := &indexMergeTableTask{ - lookupTableTask: lookupTableTask{ - doneCh: doneCh, - }, - } - select { - case <-ctx.Done(): - return - case <-finished: - return - case <-limitDone: - // once the process worker recovered from panic, once finding the limitDone signal, actually we can return. - return - case ch <- task: - return - } - } -} - -// Close implements Exec Close interface. 
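handleWorkerPanic above converts a worker panic into an ordinary error that flows through the same task channel the consumer already reads, so the main goroutine fails the query instead of crashing. The sketch below shows that panic-to-error pattern in isolation; runWorker and result are simplified stand-ins for util.WithRecovery and the task's doneCh.

package main

import (
	"errors"
	"fmt"
)

type result struct {
	err error
}

// runWorker runs work in a goroutine and reports either success or a
// recovered panic on resultCh.
func runWorker(work func(), resultCh chan<- result) {
	go func() {
		defer func() {
			if r := recover(); r != nil {
				// Forward the panic as an ordinary error on the result channel.
				resultCh <- result{err: errors.New(fmt.Sprint("worker panic: ", r))}
			}
		}()
		work()
		resultCh <- result{}
	}()
}

func main() {
	resultCh := make(chan result, 1)
	runWorker(func() { panic("injected failure") }, resultCh)
	fmt.Println((<-resultCh).err) // worker panic: injected failure
}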
-func (e *IndexMergeReaderExecutor) Close() error { - if e.stats != nil { - defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), e.stats) - } - if e.indexUsageReporter != nil { - for _, p := range e.partialPlans { - is, ok := p[0].(*plannercore.PhysicalIndexScan) - if !ok { - continue - } - - e.indexUsageReporter.ReportCopIndexUsageForTable(e.table, is.Index.ID, is.ID()) - } - } - if e.finished == nil { - return nil - } - close(e.finished) - e.tblWorkerWg.Wait() - e.idxWorkerWg.Wait() - e.processWorkerWg.Wait() - e.finished = nil - e.workerStarted = false - return nil -} - -type indexMergeProcessWorker struct { - indexMerge *IndexMergeReaderExecutor - stats *IndexMergeRuntimeStat -} - -type rowIdx struct { - partialID int - taskID int - rowID int -} - -type handleHeap struct { - // requiredCnt == 0 means need all handles - requiredCnt uint64 - tracker *memory.Tracker - taskMap map[int][]*indexMergeTableTask - - idx []rowIdx - compareFunc []chunk.CompareFunc - byItems []*plannerutil.ByItems -} - -func (h handleHeap) Len() int { - return len(h.idx) -} - -func (h handleHeap) Less(i, j int) bool { - rowI := h.taskMap[h.idx[i].partialID][h.idx[i].taskID].idxRows.GetRow(h.idx[i].rowID) - rowJ := h.taskMap[h.idx[j].partialID][h.idx[j].taskID].idxRows.GetRow(h.idx[j].rowID) - - for i, compFunc := range h.compareFunc { - cmp := compFunc(rowI, i, rowJ, i) - if !h.byItems[i].Desc { - cmp = -cmp - } - if cmp < 0 { - return true - } else if cmp > 0 { - return false - } - } - return false -} - -func (h handleHeap) Swap(i, j int) { - h.idx[i], h.idx[j] = h.idx[j], h.idx[i] -} - -func (h *handleHeap) Push(x any) { - idx := x.(rowIdx) - h.idx = append(h.idx, idx) - if h.tracker != nil { - h.tracker.Consume(int64(unsafe.Sizeof(h.idx))) - } -} - -func (h *handleHeap) Pop() any { - idxRet := h.idx[len(h.idx)-1] - h.idx = h.idx[:len(h.idx)-1] - if h.tracker != nil { - h.tracker.Consume(-int64(unsafe.Sizeof(h.idx))) - } - return idxRet -} - -func (w *indexMergeProcessWorker) NewHandleHeap(taskMap map[int][]*indexMergeTableTask, memTracker *memory.Tracker) *handleHeap { - compareFuncs := make([]chunk.CompareFunc, 0, len(w.indexMerge.byItems)) - for _, item := range w.indexMerge.byItems { - keyType := item.Expr.GetType(w.indexMerge.Ctx().GetExprCtx().GetEvalCtx()) - compareFuncs = append(compareFuncs, chunk.GetCompareFunc(keyType)) - } - - requiredCnt := uint64(0) - if w.indexMerge.pushedLimit != nil { - // Pre-allocate up to 1024 to avoid oom - requiredCnt = min(1024, w.indexMerge.pushedLimit.Count+w.indexMerge.pushedLimit.Offset) - } - return &handleHeap{ - requiredCnt: requiredCnt, - tracker: memTracker, - taskMap: taskMap, - idx: make([]rowIdx, 0, requiredCnt), - compareFunc: compareFuncs, - byItems: w.indexMerge.byItems, - } -} - -// pruneTableWorkerTaskIdxRows prune idxRows and only keep columns that will be used in byItems. -// e.g. the common handle is (`b`, `c`) and order by with column `c`, we should make column `c` at the first. -func (w *indexMergeProcessWorker) pruneTableWorkerTaskIdxRows(task *indexMergeTableTask) { - if task.idxRows == nil { - return - } - // IndexScan no need to prune retChk, Columns required by byItems are always first. 
- if plan, ok := w.indexMerge.partialPlans[task.partialPlanID][0].(*plannercore.PhysicalTableScan); ok { - prune := make([]int, 0, len(w.indexMerge.byItems)) - for _, item := range plan.ByItems { - c, _ := item.Expr.(*expression.Column) - idx := plan.Schema().ColumnIndex(c) - // couldn't equals to -1 here, if idx == -1, just let it panic - prune = append(prune, idx) - } - task.idxRows = task.idxRows.Prune(prune) - } -} - -func (w *indexMergeProcessWorker) fetchLoopUnionWithOrderBy(ctx context.Context, fetchCh <-chan *indexMergeTableTask, - workCh chan<- *indexMergeTableTask, resultCh chan<- *indexMergeTableTask, finished <-chan struct{}) { - memTracker := memory.NewTracker(w.indexMerge.ID(), -1) - memTracker.AttachTo(w.indexMerge.memTracker) - defer memTracker.Detach() - defer close(workCh) - - if w.stats != nil { - start := time.Now() - defer func() { - w.stats.IndexMergeProcess += time.Since(start) - }() - } - - distinctHandles := kv.NewHandleMap() - taskMap := make(map[int][]*indexMergeTableTask) - uselessMap := make(map[int]struct{}) - taskHeap := w.NewHandleHeap(taskMap, memTracker) - - for task := range fetchCh { - select { - case err := <-task.doneCh: - // If got error from partialIndexWorker/partialTableWorker, stop processing. - if err != nil { - syncErr(ctx, finished, resultCh, err) - return - } - default: - } - if _, ok := uselessMap[task.partialPlanID]; ok { - continue - } - if _, ok := taskMap[task.partialPlanID]; !ok { - taskMap[task.partialPlanID] = make([]*indexMergeTableTask, 0, 1) - } - w.pruneTableWorkerTaskIdxRows(task) - taskMap[task.partialPlanID] = append(taskMap[task.partialPlanID], task) - for i, h := range task.handles { - if _, ok := distinctHandles.Get(h); !ok { - distinctHandles.Set(h, true) - heap.Push(taskHeap, rowIdx{task.partialPlanID, len(taskMap[task.partialPlanID]) - 1, i}) - if int(taskHeap.requiredCnt) != 0 && taskHeap.Len() > int(taskHeap.requiredCnt) { - top := heap.Pop(taskHeap).(rowIdx) - if top.partialID == task.partialPlanID && top.taskID == len(taskMap[task.partialPlanID])-1 && top.rowID == i { - uselessMap[task.partialPlanID] = struct{}{} - task.handles = task.handles[:i] - break - } - } - } - memTracker.Consume(int64(h.MemUsage())) - } - memTracker.Consume(task.idxRows.MemoryUsage()) - if len(uselessMap) == len(w.indexMerge.partialPlans) { - // consume reset tasks - go func() { - channel.Clear(fetchCh) - }() - break - } - } - - needCount := taskHeap.Len() - if w.indexMerge.pushedLimit != nil { - needCount = max(0, taskHeap.Len()-int(w.indexMerge.pushedLimit.Offset)) - } - if needCount == 0 { - return - } - fhs := make([]kv.Handle, needCount) - for i := needCount - 1; i >= 0; i-- { - idx := heap.Pop(taskHeap).(rowIdx) - fhs[i] = taskMap[idx.partialID][idx.taskID].handles[idx.rowID] - } - - batchSize := w.indexMerge.Ctx().GetSessionVars().IndexLookupSize - tasks := make([]*indexMergeTableTask, 0, len(fhs)/batchSize+1) - for len(fhs) > 0 { - l := min(len(fhs), batchSize) - // Save the index order. 
- indexOrder := kv.NewHandleMap() - for i, h := range fhs[:l] { - indexOrder.Set(h, i) - } - tasks = append(tasks, &indexMergeTableTask{ - lookupTableTask: lookupTableTask{ - handles: fhs[:l], - indexOrder: indexOrder, - doneCh: make(chan error, 1), - }, - }) - fhs = fhs[l:] - } - for _, task := range tasks { - select { - case <-ctx.Done(): - return - case <-finished: - return - case resultCh <- task: - failpoint.Inject("testCancelContext", func() { - IndexMergeCancelFuncForTest() - }) - select { - case <-ctx.Done(): - return - case <-finished: - return - case workCh <- task: - continue - } - } - } -} - -func pushedLimitCountingDown(pushedLimit *plannercore.PushedDownLimit, handles []kv.Handle) (next bool, res []kv.Handle) { - fhsLen := uint64(len(handles)) - // The number of handles is less than the offset, discard all handles. - if fhsLen <= pushedLimit.Offset { - pushedLimit.Offset -= fhsLen - return true, nil - } - handles = handles[pushedLimit.Offset:] - pushedLimit.Offset = 0 - - fhsLen = uint64(len(handles)) - // The number of handles is greater than the limit, only keep limit count. - if fhsLen > pushedLimit.Count { - handles = handles[:pushedLimit.Count] - } - pushedLimit.Count -= min(pushedLimit.Count, fhsLen) - return false, handles -} - -func (w *indexMergeProcessWorker) fetchLoopUnion(ctx context.Context, fetchCh <-chan *indexMergeTableTask, - workCh chan<- *indexMergeTableTask, resultCh chan<- *indexMergeTableTask, finished <-chan struct{}) { - failpoint.Inject("testIndexMergeResultChCloseEarly", func(_ failpoint.Value) { - failpoint.Return() - }) - memTracker := memory.NewTracker(w.indexMerge.ID(), -1) - memTracker.AttachTo(w.indexMerge.memTracker) - defer memTracker.Detach() - defer close(workCh) - failpoint.Inject("testIndexMergePanicProcessWorkerUnion", nil) - - var pushedLimit *plannercore.PushedDownLimit - if w.indexMerge.pushedLimit != nil { - pushedLimit = w.indexMerge.pushedLimit.Clone() - } - hMap := kv.NewHandleMap() - for { - var ok bool - var task *indexMergeTableTask - if pushedLimit != nil && pushedLimit.Count == 0 { - return - } - select { - case <-ctx.Done(): - return - case <-finished: - return - case task, ok = <-fetchCh: - if !ok { - return - } - } - - select { - case err := <-task.doneCh: - // If got error from partialIndexWorker/partialTableWorker, stop processing. 
- if err != nil { - syncErr(ctx, finished, resultCh, err) - return - } - default: - } - start := time.Now() - handles := task.handles - fhs := make([]kv.Handle, 0, 8) - - memTracker.Consume(int64(cap(task.handles) * 8)) - for _, h := range handles { - if w.indexMerge.partitionTableMode { - if _, ok := h.(kv.PartitionHandle); !ok { - h = kv.NewPartitionHandle(task.partitionTable.GetPhysicalID(), h) - } - } - if _, ok := hMap.Get(h); !ok { - fhs = append(fhs, h) - hMap.Set(h, true) - } - } - if len(fhs) == 0 { - continue - } - if pushedLimit != nil { - next, res := pushedLimitCountingDown(pushedLimit, fhs) - if next { - continue - } - fhs = res - } - task = &indexMergeTableTask{ - lookupTableTask: lookupTableTask{ - handles: fhs, - doneCh: make(chan error, 1), - - partitionTable: task.partitionTable, - }, - } - if w.stats != nil { - w.stats.IndexMergeProcess += time.Since(start) - } - failpoint.Inject("testIndexMergeProcessWorkerUnionHang", func(_ failpoint.Value) { - for i := 0; i < cap(resultCh); i++ { - select { - case resultCh <- &indexMergeTableTask{}: - default: - } - } - }) - select { - case <-ctx.Done(): - return - case <-finished: - return - case resultCh <- task: - failpoint.Inject("testCancelContext", func() { - IndexMergeCancelFuncForTest() - }) - select { - case <-ctx.Done(): - return - case <-finished: - return - case workCh <- task: - } - } - } -} - -// intersectionCollectWorker is used to dispatch index-merge-table-task to original workCh and resultCh. -// a kind of interceptor to control the pushed down limit restriction. (should be no performance impact) -type intersectionCollectWorker struct { - pushedLimit *plannercore.PushedDownLimit - collectCh chan *indexMergeTableTask - limitDone chan struct{} -} - -func (w *intersectionCollectWorker) doIntersectionLimitAndDispatch(ctx context.Context, workCh chan<- *indexMergeTableTask, - resultCh chan<- *indexMergeTableTask, finished <-chan struct{}) { - var ( - ok bool - task *indexMergeTableTask - ) - for { - select { - case <-ctx.Done(): - return - case <-finished: - return - case task, ok = <-w.collectCh: - if !ok { - return - } - // receive a new intersection task here, adding limit restriction logic - if w.pushedLimit != nil { - if w.pushedLimit.Count == 0 { - // close limitDone channel to notify intersectionProcessWorkers * N to exit. - close(w.limitDone) - return - } - next, handles := pushedLimitCountingDown(w.pushedLimit, task.handles) - if next { - continue - } - task.handles = handles - } - // dispatch the new task to workCh and resultCh. - select { - case <-ctx.Done(): - return - case <-finished: - return - case workCh <- task: - select { - case <-ctx.Done(): - return - case <-finished: - return - case resultCh <- task: - } - } - } - } -} - -type intersectionProcessWorker struct { - // key: parTblIdx, val: HandleMap - // Value of MemAwareHandleMap is *int to avoid extra Get(). 
- handleMapsPerWorker map[int]*kv.MemAwareHandleMap[*int] - workerID int - workerCh chan *indexMergeTableTask - indexMerge *IndexMergeReaderExecutor - memTracker *memory.Tracker - batchSize int - - // When rowDelta == memConsumeBatchSize, Consume(memUsage) - rowDelta int64 - mapUsageDelta int64 - - partitionIDMap map[int64]int -} - -func (w *intersectionProcessWorker) consumeMemDelta() { - w.memTracker.Consume(w.mapUsageDelta + w.rowDelta*int64(unsafe.Sizeof(int(0)))) - w.mapUsageDelta = 0 - w.rowDelta = 0 -} - -// doIntersectionPerPartition fetch all the task from workerChannel, and after that, then do the intersection pruning, which -// will cause wasting a lot of time waiting for all the fetch task done. -func (w *intersectionProcessWorker) doIntersectionPerPartition(ctx context.Context, workCh chan<- *indexMergeTableTask, resultCh chan<- *indexMergeTableTask, finished, limitDone <-chan struct{}) { - failpoint.Inject("testIndexMergePanicPartitionTableIntersectionWorker", nil) - defer w.memTracker.Detach() - - for task := range w.workerCh { - var ok bool - var hMap *kv.MemAwareHandleMap[*int] - if hMap, ok = w.handleMapsPerWorker[task.parTblIdx]; !ok { - hMap = kv.NewMemAwareHandleMap[*int]() - w.handleMapsPerWorker[task.parTblIdx] = hMap - } - var mapDelta, rowDelta int64 - for _, h := range task.handles { - if w.indexMerge.hasGlobalIndex { - if ph, ok := h.(kv.PartitionHandle); ok { - if v, exists := w.partitionIDMap[ph.PartitionID]; exists { - if hMap, ok = w.handleMapsPerWorker[v]; !ok { - hMap = kv.NewMemAwareHandleMap[*int]() - w.handleMapsPerWorker[v] = hMap - } - } - } else { - h = kv.NewPartitionHandle(task.partitionTable.GetPhysicalID(), h) - } - } - // Use *int to avoid Get() again. - if cntPtr, ok := hMap.Get(h); ok { - (*cntPtr)++ - } else { - cnt := 1 - mapDelta += hMap.Set(h, &cnt) + int64(h.ExtraMemSize()) - rowDelta++ - } - } - - logutil.BgLogger().Debug("intersectionProcessWorker handle tasks", zap.Int("workerID", w.workerID), - zap.Int("task.handles", len(task.handles)), zap.Int64("rowDelta", rowDelta)) - - w.mapUsageDelta += mapDelta - w.rowDelta += rowDelta - if w.rowDelta >= int64(w.batchSize) { - w.consumeMemDelta() - } - failpoint.Inject("testIndexMergeIntersectionWorkerPanic", nil) - } - if w.rowDelta > 0 { - w.consumeMemDelta() - } - - // We assume the result of intersection is small, so no need to track memory. - intersectedMap := make(map[int][]kv.Handle, len(w.handleMapsPerWorker)) - for parTblIdx, hMap := range w.handleMapsPerWorker { - hMap.Range(func(h kv.Handle, val *int) bool { - if *(val) == len(w.indexMerge.partialPlans) { - // Means all partial paths have this handle. - intersectedMap[parTblIdx] = append(intersectedMap[parTblIdx], h) - } - return true - }) - } - - tasks := make([]*indexMergeTableTask, 0, len(w.handleMapsPerWorker)) - for parTblIdx, intersected := range intersectedMap { - // Split intersected[parTblIdx] to avoid task is too large. 
- for len(intersected) > 0 { - length := w.batchSize - if length > len(intersected) { - length = len(intersected) - } - task := &indexMergeTableTask{ - lookupTableTask: lookupTableTask{ - handles: intersected[:length], - doneCh: make(chan error, 1), - }, - } - intersected = intersected[length:] - if w.indexMerge.partitionTableMode { - task.partitionTable = w.indexMerge.prunedPartitions[parTblIdx] - } - tasks = append(tasks, task) - logutil.BgLogger().Debug("intersectionProcessWorker build tasks", - zap.Int("parTblIdx", parTblIdx), zap.Int("task.handles", len(task.handles))) - } - } - failpoint.Inject("testIndexMergeProcessWorkerIntersectionHang", func(_ failpoint.Value) { - if resultCh != nil { - for i := 0; i < cap(resultCh); i++ { - select { - case resultCh <- &indexMergeTableTask{}: - default: - } - } - } - }) - for _, task := range tasks { - select { - case <-ctx.Done(): - return - case <-finished: - return - case <-limitDone: - // limitDone has signal means the collectWorker has collected enough results, shutdown process workers quickly here. - return - case workCh <- task: - // resultCh != nil means there is no collectWorker, and we should send task to resultCh too by ourselves here. - if resultCh != nil { - select { - case <-ctx.Done(): - return - case <-finished: - return - case resultCh <- task: - } - } - } - } -} - -// for every index merge process worker, it should be feed on a sortedSelectResult for every partial index plan (constructed on all -// table partition ranges results on that index plan path). Since every partial index path is a sorted select result, we can utilize -// K-way merge to accelerate the intersection process. -// -// partialIndexPlan-1 ---> SSR ---> + -// partialIndexPlan-2 ---> SSR ---> + ---> SSR K-way Merge ---> output IndexMergeTableTask -// partialIndexPlan-3 ---> SSR ---> + -// ... + -// partialIndexPlan-N ---> SSR ---> + -// -// K-way merge detail: for every partial index plan, output one row as current its representative row. Then, comparing the N representative -// rows together: -// -// Loop start: -// -// case 1: they are all the same, intersection succeed. --- Record current handle down (already in index order). -// case 2: distinguish among them, for the minimum value/values corresponded index plan/plans. --- Discard current representative row, fetch next. -// -// goto Loop start: -// -// encapsulate all the recorded handles (already in index order) as index merge table tasks, sending them out. -func (*indexMergeProcessWorker) fetchLoopIntersectionWithOrderBy(_ context.Context, _ <-chan *indexMergeTableTask, - _ chan<- *indexMergeTableTask, _ chan<- *indexMergeTableTask, _ <-chan struct{}) { - // todo: pushed sort property with partial index plan and limit. -} - -// For each partition(dynamic mode), a map is used to do intersection. Key of the map is handle, and value is the number of times it occurs. -// If the value of handle equals the number of partial paths, it should be sent to final_table_scan_worker. -// To avoid too many goroutines, each intersectionProcessWorker can handle multiple partitions. 
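// A simplified illustration of the K-way merge intersection sketched in the
// comment above; the ordered variant is still left as a TODO in
// fetchLoopIntersectionWithOrderBy, so nothing here is the shipped TiDB code.
// Each partial index path is modeled as an already-sorted []int64 of handles
// instead of a sorted SelectResult producing kv.Handle values, and the
// function name is hypothetical.
//
// Keep one cursor per path on its current representative handle. If all
// representatives are equal, the handle belongs to the intersection (and is
// already in index order); otherwise advance every path whose representative
// is smaller than the largest one and compare again.
func kWayIntersect(paths [][]int64) []int64 {
	n := len(paths)
	if n == 0 {
		return nil
	}
	pos := make([]int, n) // cursor into each sorted path
	var out []int64
	for {
		// Pick the largest representative; any smaller one can never be part
		// of the intersection, so its path must be advanced past it.
		var maxVal int64
		for i := 0; i < n; i++ {
			if pos[i] >= len(paths[i]) {
				return out // one path exhausted: no further common handles
			}
			if i == 0 || paths[i][pos[i]] > maxVal {
				maxVal = paths[i][pos[i]]
			}
		}
		allEqual := true
		for i := 0; i < n; i++ {
			for pos[i] < len(paths[i]) && paths[i][pos[i]] < maxVal {
				pos[i]++ // discard representatives below the current maximum
			}
			if pos[i] >= len(paths[i]) || paths[i][pos[i]] != maxVal {
				allEqual = false
			}
		}
		if allEqual {
			out = append(out, maxVal) // every path agrees on this handle
			for i := 0; i < n; i++ {
				pos[i]++
			}
		}
	}
}
// Example: kWayIntersect([][]int64{{1, 3, 7, 9}, {3, 4, 7}, {2, 3, 7, 8}})
// returns [3 7], already in index order.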
-func (w *indexMergeProcessWorker) fetchLoopIntersection(ctx context.Context, fetchCh <-chan *indexMergeTableTask, - workCh chan<- *indexMergeTableTask, resultCh chan<- *indexMergeTableTask, finished <-chan struct{}) { - defer close(workCh) - - if w.stats != nil { - start := time.Now() - defer func() { - w.stats.IndexMergeProcess += time.Since(start) - }() - } - - failpoint.Inject("testIndexMergePanicProcessWorkerIntersection", nil) - - // One goroutine may handle one or multiple partitions. - // Max number of partition number is 8192, we use ExecutorConcurrency to avoid too many goroutines. - maxWorkerCnt := w.indexMerge.Ctx().GetSessionVars().IndexMergeIntersectionConcurrency() - maxChannelSize := atomic.LoadInt32(&LookupTableTaskChannelSize) - batchSize := w.indexMerge.Ctx().GetSessionVars().IndexLookupSize - - partCnt := 1 - // To avoid multi-threaded access the handle map, we only use one worker for indexMerge with global index. - if w.indexMerge.partitionTableMode && !w.indexMerge.hasGlobalIndex { - partCnt = len(w.indexMerge.prunedPartitions) - } - workerCnt := min(partCnt, maxWorkerCnt) - failpoint.Inject("testIndexMergeIntersectionConcurrency", func(val failpoint.Value) { - con := val.(int) - if con != workerCnt { - panic(fmt.Sprintf("unexpected workerCnt, expect %d, got %d", con, workerCnt)) - } - }) - - partitionIDMap := make(map[int64]int) - if w.indexMerge.hasGlobalIndex { - for i, p := range w.indexMerge.prunedPartitions { - partitionIDMap[p.GetPhysicalID()] = i - } - } - - workers := make([]*intersectionProcessWorker, 0, workerCnt) - var collectWorker *intersectionCollectWorker - wg := util.WaitGroupWrapper{} - wg2 := util.WaitGroupWrapper{} - errCh := make(chan bool, workerCnt) - var limitDone chan struct{} - if w.indexMerge.pushedLimit != nil { - // no memory cost for this code logic. - collectWorker = &intersectionCollectWorker{ - // same size of workCh/resultCh - collectCh: make(chan *indexMergeTableTask, atomic.LoadInt32(&LookupTableTaskChannelSize)), - pushedLimit: w.indexMerge.pushedLimit.Clone(), - limitDone: make(chan struct{}), - } - limitDone = collectWorker.limitDone - wg2.RunWithRecover(func() { - defer trace.StartRegion(ctx, "IndexMergeIntersectionProcessWorker").End() - collectWorker.doIntersectionLimitAndDispatch(ctx, workCh, resultCh, finished) - }, handleWorkerPanic(ctx, finished, nil, resultCh, errCh, partTblIntersectionWorkerType)) - } - for i := 0; i < workerCnt; i++ { - tracker := memory.NewTracker(w.indexMerge.ID(), -1) - tracker.AttachTo(w.indexMerge.memTracker) - worker := &intersectionProcessWorker{ - workerID: i, - handleMapsPerWorker: make(map[int]*kv.MemAwareHandleMap[*int]), - workerCh: make(chan *indexMergeTableTask, maxChannelSize), - indexMerge: w.indexMerge, - memTracker: tracker, - batchSize: batchSize, - partitionIDMap: partitionIDMap, - } - wg.RunWithRecover(func() { - defer trace.StartRegion(ctx, "IndexMergeIntersectionProcessWorker").End() - if collectWorker != nil { - // workflow: - // intersectionProcessWorker-1 --+ (limit restriction logic) - // intersectionProcessWorker-2 --+--------- collectCh--> intersectionCollectWorker +--> workCh --> table worker - // ... 
--+ <--- limitDone to shut inputs ------+ +-> resultCh --> upper parent - // intersectionProcessWorker-N --+ - worker.doIntersectionPerPartition(ctx, collectWorker.collectCh, nil, finished, collectWorker.limitDone) - } else { - // workflow: - // intersectionProcessWorker-1 --------------------------+--> workCh --> table worker - // intersectionProcessWorker-2 ---(same as above) +--> resultCh --> upper parent - // ... ---(same as above) - // intersectionProcessWorker-N ---(same as above) - worker.doIntersectionPerPartition(ctx, workCh, resultCh, finished, nil) - } - }, handleWorkerPanic(ctx, finished, limitDone, resultCh, errCh, partTblIntersectionWorkerType)) - workers = append(workers, worker) - } - defer func() { - for _, processWorker := range workers { - close(processWorker.workerCh) - } - wg.Wait() - // only after all the possible writer closed, can we shut down the collectCh. - if collectWorker != nil { - // you don't need to clear the channel before closing it, so discard all the remain tasks. - close(collectWorker.collectCh) - } - wg2.Wait() - }() - for { - var ok bool - var task *indexMergeTableTask - select { - case <-ctx.Done(): - return - case <-finished: - return - case task, ok = <-fetchCh: - if !ok { - return - } - } - - select { - case err := <-task.doneCh: - // If got error from partialIndexWorker/partialTableWorker, stop processing. - if err != nil { - syncErr(ctx, finished, resultCh, err) - return - } - default: - } - - select { - case <-ctx.Done(): - return - case <-finished: - return - case workers[task.parTblIdx%workerCnt].workerCh <- task: - case <-errCh: - // If got error from intersectionProcessWorker, stop processing. - return - } - } -} - -type partialIndexWorker struct { - stats *IndexMergeRuntimeStat - sc sessionctx.Context - idxID int - batchSize int - maxBatchSize int - maxChunkSize int - memTracker *memory.Tracker - partitionTableMode bool - prunedPartitions []table.PhysicalTable - byItems []*plannerutil.ByItems - scannedKeys uint64 - pushedLimit *plannercore.PushedDownLimit - dagPB *tipb.DAGRequest - plan []base.PhysicalPlan -} - -func syncErr(ctx context.Context, finished <-chan struct{}, errCh chan<- *indexMergeTableTask, err error) { - logutil.BgLogger().Error("IndexMergeReaderExecutor.syncErr", zap.Error(err)) - doneCh := make(chan error, 1) - doneCh <- err - task := &indexMergeTableTask{ - lookupTableTask: lookupTableTask{ - doneCh: doneCh, - }, - } - - // ctx.Done and finished is to avoid write channel is stuck. - select { - case <-ctx.Done(): - return - case <-finished: - return - case errCh <- task: - return - } -} - -// needPartitionHandle indicates whether we need create a partitionHandle or not. -// If the schema from planner part contains ExtraPhysTblID, -// we need create a partitionHandle, otherwise create a normal handle. -// In TableRowIDScan, the partitionHandle will be used to create key ranges. -func (w *partialIndexWorker) needPartitionHandle() (bool, error) { - cols := w.plan[0].Schema().Columns - outputOffsets := w.dagPB.OutputOffsets - col := cols[outputOffsets[len(outputOffsets)-1]] - - is := w.plan[0].(*plannercore.PhysicalIndexScan) - needPartitionHandle := w.partitionTableMode && len(w.byItems) > 0 || is.Index.Global - hasExtraCol := col.ID == model.ExtraPhysTblID - - // There will be two needPartitionHandle != hasExtraCol situations. - // Only `needPartitionHandle` == true and `hasExtraCol` == false are not allowed. - // `ExtraPhysTblID` will be used in `SelectLock` when `needPartitionHandle` == false and `hasExtraCol` == true. 
- if needPartitionHandle && !hasExtraCol { - return needPartitionHandle, errors.Errorf("Internal error, needPartitionHandle != ret") - } - return needPartitionHandle, nil -} - -func (w *partialIndexWorker) fetchHandles( - ctx context.Context, - results []distsql.SelectResult, - exitCh <-chan struct{}, - fetchCh chan<- *indexMergeTableTask, - finished <-chan struct{}, - handleCols plannerutil.HandleCols, - partialPlanIndex int) (count int64, err error) { - tps := w.getRetTpsForIndexScan(handleCols) - chk := chunk.NewChunkWithCapacity(tps, w.maxChunkSize) - for i := 0; i < len(results); { - start := time.Now() - handles, retChunk, err := w.extractTaskHandles(ctx, chk, results[i], handleCols) - if err != nil { - syncErr(ctx, finished, fetchCh, err) - return count, err - } - if len(handles) == 0 { - i++ - continue - } - count += int64(len(handles)) - task := w.buildTableTask(handles, retChunk, i, partialPlanIndex) - if w.stats != nil { - atomic.AddInt64(&w.stats.FetchIdxTime, int64(time.Since(start))) - } - select { - case <-ctx.Done(): - return count, ctx.Err() - case <-exitCh: - return count, nil - case <-finished: - return count, nil - case fetchCh <- task: - } - } - return count, nil -} - -func (w *partialIndexWorker) getRetTpsForIndexScan(handleCols plannerutil.HandleCols) []*types.FieldType { - var tps []*types.FieldType - if len(w.byItems) != 0 { - for _, item := range w.byItems { - tps = append(tps, item.Expr.GetType(w.sc.GetExprCtx().GetEvalCtx())) - } - } - tps = append(tps, handleCols.GetFieldsTypes()...) - if ok, _ := w.needPartitionHandle(); ok { - tps = append(tps, types.NewFieldType(mysql.TypeLonglong)) - } - return tps -} - -func (w *partialIndexWorker) extractTaskHandles(ctx context.Context, chk *chunk.Chunk, idxResult distsql.SelectResult, handleCols plannerutil.HandleCols) ( - handles []kv.Handle, retChk *chunk.Chunk, err error) { - handles = make([]kv.Handle, 0, w.batchSize) - if len(w.byItems) != 0 { - retChk = chunk.NewChunkWithCapacity(w.getRetTpsForIndexScan(handleCols), w.batchSize) - } - var memUsage int64 - var chunkRowOffset int - defer w.memTracker.Consume(-memUsage) - for len(handles) < w.batchSize { - requiredRows := w.batchSize - len(handles) - if w.pushedLimit != nil { - if w.pushedLimit.Offset+w.pushedLimit.Count <= w.scannedKeys { - return handles, retChk, nil - } - requiredRows = min(int(w.pushedLimit.Offset+w.pushedLimit.Count-w.scannedKeys), requiredRows) - } - chk.SetRequiredRows(requiredRows, w.maxChunkSize) - start := time.Now() - err = errors.Trace(idxResult.Next(ctx, chk)) - if err != nil { - return nil, nil, err - } - if w.stats != nil && w.idxID != 0 { - w.sc.GetSessionVars().StmtCtx.RuntimeStatsColl.GetBasicRuntimeStats(w.idxID).Record(time.Since(start), chk.NumRows()) - } - if chk.NumRows() == 0 { - failpoint.Inject("testIndexMergeErrorPartialIndexWorker", func(v failpoint.Value) { - failpoint.Return(handles, nil, errors.New(v.(string))) - }) - return handles, retChk, nil - } - memDelta := chk.MemoryUsage() - memUsage += memDelta - w.memTracker.Consume(memDelta) - for chunkRowOffset = 0; chunkRowOffset < chk.NumRows(); chunkRowOffset++ { - if w.pushedLimit != nil { - w.scannedKeys++ - if w.scannedKeys > (w.pushedLimit.Offset + w.pushedLimit.Count) { - // Skip the handles after Offset+Count. 
- break - } - } - var handle kv.Handle - ok, err1 := w.needPartitionHandle() - if err1 != nil { - return nil, nil, err1 - } - if ok { - handle, err = handleCols.BuildPartitionHandleFromIndexRow(chk.GetRow(chunkRowOffset)) - } else { - handle, err = handleCols.BuildHandleFromIndexRow(chk.GetRow(chunkRowOffset)) - } - if err != nil { - return nil, nil, err - } - handles = append(handles, handle) - } - // used for order by - if len(w.byItems) != 0 { - retChk.Append(chk, 0, chunkRowOffset) - } - } - w.batchSize *= 2 - if w.batchSize > w.maxBatchSize { - w.batchSize = w.maxBatchSize - } - return handles, retChk, nil -} - -func (w *partialIndexWorker) buildTableTask(handles []kv.Handle, retChk *chunk.Chunk, parTblIdx int, partialPlanID int) *indexMergeTableTask { - task := &indexMergeTableTask{ - lookupTableTask: lookupTableTask{ - handles: handles, - idxRows: retChk, - }, - parTblIdx: parTblIdx, - partialPlanID: partialPlanID, - } - - if w.prunedPartitions != nil { - task.partitionTable = w.prunedPartitions[parTblIdx] - } - - task.doneCh = make(chan error, 1) - return task -} - -type indexMergeTableScanWorker struct { - stats *IndexMergeRuntimeStat - workCh <-chan *indexMergeTableTask - finished <-chan struct{} - indexMergeExec *IndexMergeReaderExecutor - tblPlans []base.PhysicalPlan - - // memTracker is used to track the memory usage of this executor. - memTracker *memory.Tracker -} - -func (w *indexMergeTableScanWorker) pickAndExecTask(ctx context.Context, task **indexMergeTableTask) { - var ok bool - for { - waitStart := time.Now() - select { - case <-ctx.Done(): - return - case <-w.finished: - return - case *task, ok = <-w.workCh: - if !ok { - return - } - } - // Make sure panic failpoint is after fetch task from workCh. - // Otherwise, cannot send error to task.doneCh. 
- failpoint.Inject("testIndexMergePanicTableScanWorker", nil) - execStart := time.Now() - err := w.executeTask(ctx, *task) - if w.stats != nil { - atomic.AddInt64(&w.stats.WaitTime, int64(execStart.Sub(waitStart))) - atomic.AddInt64(&w.stats.FetchRow, int64(time.Since(execStart))) - atomic.AddInt64(&w.stats.TableTaskNum, 1) - } - failpoint.Inject("testIndexMergePickAndExecTaskPanic", nil) - select { - case <-ctx.Done(): - return - case <-w.finished: - return - case (*task).doneCh <- err: - } - } -} - -func (*indexMergeTableScanWorker) handleTableScanWorkerPanic(ctx context.Context, finished <-chan struct{}, task **indexMergeTableTask, worker string) func(r any) { - return func(r any) { - if r == nil { - logutil.BgLogger().Debug("worker finish without panic", zap.Any("worker", worker)) - return - } - - err4Panic := errors.Errorf("%s: %v", worker, r) - logutil.Logger(ctx).Error(err4Panic.Error()) - if *task != nil { - select { - case <-ctx.Done(): - return - case <-finished: - return - case (*task).doneCh <- err4Panic: - return - } - } - } -} - -func (w *indexMergeTableScanWorker) executeTask(ctx context.Context, task *indexMergeTableTask) error { - tbl := w.indexMergeExec.table - if w.indexMergeExec.partitionTableMode && task.partitionTable != nil { - tbl = task.partitionTable - } - tableReader, err := w.indexMergeExec.buildFinalTableReader(ctx, tbl, task.handles) - if err != nil { - logutil.Logger(ctx).Error("build table reader failed", zap.Error(err)) - return err - } - defer func() { terror.Log(exec.Close(tableReader)) }() - task.memTracker = w.memTracker - memUsage := int64(cap(task.handles) * 8) - task.memUsage = memUsage - task.memTracker.Consume(memUsage) - handleCnt := len(task.handles) - task.rows = make([]chunk.Row, 0, handleCnt) - for { - chk := exec.TryNewCacheChunk(tableReader) - err = exec.Next(ctx, tableReader, chk) - if err != nil { - logutil.Logger(ctx).Error("table reader fetch next chunk failed", zap.Error(err)) - return err - } - if chk.NumRows() == 0 { - break - } - memUsage = chk.MemoryUsage() - task.memUsage += memUsage - task.memTracker.Consume(memUsage) - iter := chunk.NewIterator4Chunk(chk) - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - task.rows = append(task.rows, row) - } - } - - if w.indexMergeExec.keepOrder { - // Because len(outputOffsets) == tableScan.Schema().Len(), - // so we could use row.GetInt64(idx) to get partition ID here. - // TODO: We could add plannercore.PartitionHandleCols to unify them. 
- physicalTableIDIdx := -1 - for i, c := range w.indexMergeExec.Schema().Columns { - if c.ID == model.ExtraPhysTblID { - physicalTableIDIdx = i - break - } - } - task.rowIdx = make([]int, 0, len(task.rows)) - for _, row := range task.rows { - handle, err := w.indexMergeExec.handleCols.BuildHandle(row) - if err != nil { - return err - } - if w.indexMergeExec.partitionTableMode && physicalTableIDIdx != -1 { - handle = kv.NewPartitionHandle(row.GetInt64(physicalTableIDIdx), handle) - } - rowIdx, _ := task.indexOrder.Get(handle) - task.rowIdx = append(task.rowIdx, rowIdx.(int)) - } - sort.Sort(task) - } - - memUsage = int64(cap(task.rows)) * int64(unsafe.Sizeof(chunk.Row{})) - task.memUsage += memUsage - task.memTracker.Consume(memUsage) - if handleCnt != len(task.rows) && len(w.tblPlans) == 1 { - return errors.Errorf("handle count %d isn't equal to value count %d", handleCnt, len(task.rows)) - } - return nil -} - -// IndexMergeRuntimeStat record the indexMerge runtime stat -type IndexMergeRuntimeStat struct { - IndexMergeProcess time.Duration - FetchIdxTime int64 - WaitTime int64 - FetchRow int64 - TableTaskNum int64 - Concurrency int -} - -func (e *IndexMergeRuntimeStat) String() string { - var buf bytes.Buffer - if e.FetchIdxTime != 0 { - buf.WriteString(fmt.Sprintf("index_task:{fetch_handle:%s", time.Duration(e.FetchIdxTime))) - if e.IndexMergeProcess != 0 { - buf.WriteString(fmt.Sprintf(", merge:%s", e.IndexMergeProcess)) - } - buf.WriteByte('}') - } - if e.FetchRow != 0 { - if buf.Len() > 0 { - buf.WriteByte(',') - } - fmt.Fprintf(&buf, " table_task:{num:%d, concurrency:%d, fetch_row:%s, wait_time:%s}", e.TableTaskNum, e.Concurrency, time.Duration(e.FetchRow), time.Duration(e.WaitTime)) - } - return buf.String() -} - -// Clone implements the RuntimeStats interface. -func (e *IndexMergeRuntimeStat) Clone() execdetails.RuntimeStats { - newRs := *e - return &newRs -} - -// Merge implements the RuntimeStats interface. -func (e *IndexMergeRuntimeStat) Merge(other execdetails.RuntimeStats) { - tmp, ok := other.(*IndexMergeRuntimeStat) - if !ok { - return - } - e.IndexMergeProcess += tmp.IndexMergeProcess - e.FetchIdxTime += tmp.FetchIdxTime - e.FetchRow += tmp.FetchRow - e.WaitTime += e.WaitTime - e.TableTaskNum += tmp.TableTaskNum -} - -// Tp implements the RuntimeStats interface. 
-func (*IndexMergeRuntimeStat) Tp() int { - return execdetails.TpIndexMergeRunTimeStats -} diff --git a/pkg/executor/infoschema_reader.go b/pkg/executor/infoschema_reader.go index 6cdcbf9c2dd44..6b8455e281f39 100644 --- a/pkg/executor/infoschema_reader.go +++ b/pkg/executor/infoschema_reader.go @@ -3478,7 +3478,7 @@ func (e *memtableRetriever) setDataForAttributes(ctx context.Context, sctx sessi checker := privilege.GetPrivilegeManager(sctx) rules, err := infosync.GetAllLabelRules(context.TODO()) skipValidateTable := false - if _, _err_ := failpoint.Eval(_curpkg_("mockOutputOfAttributes")); _err_ == nil { + failpoint.Inject("mockOutputOfAttributes", func() { convert := func(i any) []any { return []any{i} } @@ -3513,7 +3513,7 @@ func (e *memtableRetriever) setDataForAttributes(ctx context.Context, sctx sessi } err = nil skipValidateTable = true - } + }) if err != nil { return errors.Wrap(err, "get the label rules failed") diff --git a/pkg/executor/infoschema_reader.go__failpoint_stash__ b/pkg/executor/infoschema_reader.go__failpoint_stash__ deleted file mode 100644 index 6b8455e281f39..0000000000000 --- a/pkg/executor/infoschema_reader.go__failpoint_stash__ +++ /dev/null @@ -1,3878 +0,0 @@ -// Copyright 2020 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package executor - -import ( - "bytes" - "context" - "encoding/hex" - "encoding/json" - "fmt" - "math" - "slices" - "strconv" - "strings" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/deadlock" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - rmpb "github.com/pingcap/kvproto/pkg/resource_manager" - "github.com/pingcap/tidb/pkg/ddl/label" - "github.com/pingcap/tidb/pkg/ddl/placement" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/domain/infosync" - "github.com/pingcap/tidb/pkg/domain/resourcegroup" - "github.com/pingcap/tidb/pkg/errno" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/executor/internal/pdhelper" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta/autoid" - "github.com/pingcap/tidb/pkg/parser" - "github.com/pingcap/tidb/pkg/parser/charset" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/privilege" - "github.com/pingcap/tidb/pkg/privilege/privileges" - "github.com/pingcap/tidb/pkg/session/txninfo" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/sessiontxn" - "github.com/pingcap/tidb/pkg/statistics" - "github.com/pingcap/tidb/pkg/statistics/handle/cache" - "github.com/pingcap/tidb/pkg/store/helper" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/collate" - "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" - "github.com/pingcap/tidb/pkg/util/deadlockhistory" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/hint" - "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tidb/pkg/util/keydecoder" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "github.com/pingcap/tidb/pkg/util/resourcegrouptag" - "github.com/pingcap/tidb/pkg/util/sem" - "github.com/pingcap/tidb/pkg/util/servermemorylimit" - "github.com/pingcap/tidb/pkg/util/set" - "github.com/pingcap/tidb/pkg/util/stringutil" - "github.com/pingcap/tidb/pkg/util/syncutil" - "github.com/tikv/client-go/v2/tikv" - "github.com/tikv/client-go/v2/tikvrpc" - "github.com/tikv/client-go/v2/txnkv/txnlock" - pd "github.com/tikv/pd/client/http" - "go.uber.org/zap" -) - -type memtableRetriever struct { - dummyCloser - table *model.TableInfo - columns []*model.ColumnInfo - rows [][]types.Datum - rowIdx int - retrieved bool - initialized bool - extractor base.MemTablePredicateExtractor - is infoschema.InfoSchema - memTracker *memory.Tracker -} - -// retrieve implements the infoschemaRetriever interface -func (e *memtableRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { - if e.table.Name.O == infoschema.TableClusterInfo && !hasPriv(sctx, mysql.ProcessPriv) { - return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("PROCESS") - } - if e.retrieved { - return nil, nil - } - - // Cache the ret full rows in schemataRetriever - if !e.initialized { - is := sctx.GetInfoSchema().(infoschema.InfoSchema) - 
e.is = is - - var getAllSchemas = func() []model.CIStr { - dbs := is.AllSchemaNames() - slices.SortFunc(dbs, func(a, b model.CIStr) int { - return strings.Compare(a.L, b.L) - }) - return dbs - } - - var err error - switch e.table.Name.O { - case infoschema.TableSchemata: - err = e.setDataFromSchemata(sctx) - case infoschema.TableStatistics: - err = e.setDataForStatistics(ctx, sctx) - case infoschema.TableTables: - err = e.setDataFromTables(ctx, sctx) - case infoschema.TableReferConst: - dbs := getAllSchemas() - err = e.setDataFromReferConst(ctx, sctx, dbs) - case infoschema.TableSequences: - dbs := getAllSchemas() - err = e.setDataFromSequences(ctx, sctx, dbs) - case infoschema.TablePartitions: - err = e.setDataFromPartitions(ctx, sctx) - case infoschema.TableClusterInfo: - err = e.dataForTiDBClusterInfo(sctx) - case infoschema.TableAnalyzeStatus: - err = e.setDataForAnalyzeStatus(ctx, sctx) - case infoschema.TableTiDBIndexes: - dbs := getAllSchemas() - err = e.setDataFromIndexes(ctx, sctx, dbs) - case infoschema.TableViews: - dbs := getAllSchemas() - err = e.setDataFromViews(ctx, sctx, dbs) - case infoschema.TableEngines: - e.setDataFromEngines() - case infoschema.TableCharacterSets: - e.setDataFromCharacterSets() - case infoschema.TableCollations: - e.setDataFromCollations() - case infoschema.TableKeyColumn: - dbs := getAllSchemas() - err = e.setDataFromKeyColumnUsage(ctx, sctx, dbs) - case infoschema.TableMetricTables: - e.setDataForMetricTables() - case infoschema.TableProfiling: - e.setDataForPseudoProfiling(sctx) - case infoschema.TableCollationCharacterSetApplicability: - e.dataForCollationCharacterSetApplicability() - case infoschema.TableProcesslist: - e.setDataForProcessList(sctx) - case infoschema.ClusterTableProcesslist: - err = e.setDataForClusterProcessList(sctx) - case infoschema.TableUserPrivileges: - e.setDataFromUserPrivileges(sctx) - case infoschema.TableTiKVRegionStatus: - err = e.setDataForTiKVRegionStatus(ctx, sctx) - case infoschema.TableTiDBHotRegions: - err = e.setDataForTiDBHotRegions(ctx, sctx) - case infoschema.TableConstraints: - dbs := getAllSchemas() - err = e.setDataFromTableConstraints(ctx, sctx, dbs) - case infoschema.TableSessionVar: - e.rows, err = infoschema.GetDataFromSessionVariables(ctx, sctx) - case infoschema.TableTiDBServersInfo: - err = e.setDataForServersInfo(sctx) - case infoschema.TableTiFlashReplica: - dbs := getAllSchemas() - err = e.dataForTableTiFlashReplica(ctx, sctx, dbs) - case infoschema.TableTiKVStoreStatus: - err = e.dataForTiKVStoreStatus(ctx, sctx) - case infoschema.TableClientErrorsSummaryGlobal, - infoschema.TableClientErrorsSummaryByUser, - infoschema.TableClientErrorsSummaryByHost: - err = e.setDataForClientErrorsSummary(sctx, e.table.Name.O) - case infoschema.TableAttributes: - err = e.setDataForAttributes(ctx, sctx, is) - case infoschema.TablePlacementPolicies: - err = e.setDataFromPlacementPolicies(sctx) - case infoschema.TableTrxSummary: - err = e.setDataForTrxSummary(sctx) - case infoschema.ClusterTableTrxSummary: - err = e.setDataForClusterTrxSummary(sctx) - case infoschema.TableVariablesInfo: - err = e.setDataForVariablesInfo(sctx) - case infoschema.TableUserAttributes: - err = e.setDataForUserAttributes(ctx, sctx) - case infoschema.TableMemoryUsage: - err = e.setDataForMemoryUsage() - case infoschema.ClusterTableMemoryUsage: - err = e.setDataForClusterMemoryUsage(sctx) - case infoschema.TableMemoryUsageOpsHistory: - err = e.setDataForMemoryUsageOpsHistory() - case infoschema.ClusterTableMemoryUsageOpsHistory: - err = 
e.setDataForClusterMemoryUsageOpsHistory(sctx) - case infoschema.TableResourceGroups: - err = e.setDataFromResourceGroups() - case infoschema.TableRunawayWatches: - err = e.setDataFromRunawayWatches(sctx) - case infoschema.TableCheckConstraints: - dbs := getAllSchemas() - err = e.setDataFromCheckConstraints(ctx, sctx, dbs) - case infoschema.TableTiDBCheckConstraints: - dbs := getAllSchemas() - err = e.setDataFromTiDBCheckConstraints(ctx, sctx, dbs) - case infoschema.TableKeywords: - err = e.setDataFromKeywords() - case infoschema.TableTiDBIndexUsage: - dbs := getAllSchemas() - err = e.setDataFromIndexUsage(ctx, sctx, dbs) - case infoschema.ClusterTableTiDBIndexUsage: - dbs := getAllSchemas() - err = e.setDataForClusterIndexUsage(ctx, sctx, dbs) - } - if err != nil { - return nil, err - } - e.initialized = true - if e.memTracker != nil { - e.memTracker.Consume(calculateDatumsSize(e.rows)) - } - } - - // Adjust the amount of each return - maxCount := 1024 - retCount := maxCount - if e.rowIdx+maxCount > len(e.rows) { - retCount = len(e.rows) - e.rowIdx - e.retrieved = true - } - ret := make([][]types.Datum, retCount) - for i := e.rowIdx; i < e.rowIdx+retCount; i++ { - ret[i-e.rowIdx] = e.rows[i] - } - e.rowIdx += retCount - return adjustColumns(ret, e.columns, e.table), nil -} - -func getAutoIncrementID( - is infoschema.InfoSchema, - sctx sessionctx.Context, - tblInfo *model.TableInfo, -) int64 { - tbl, ok := is.TableByID(tblInfo.ID) - if !ok { - return 0 - } - return tbl.Allocators(sctx.GetTableCtx()).Get(autoid.AutoIncrementType).Base() + 1 -} - -func hasPriv(ctx sessionctx.Context, priv mysql.PrivilegeType) bool { - pm := privilege.GetPrivilegeManager(ctx) - if pm == nil { - // internal session created with createSession doesn't has the PrivilegeManager. For most experienced cases before, - // we use it like this: - // ``` - // checker := privilege.GetPrivilegeManager(ctx) - // if checker != nil && !checker.RequestVerification(ctx.GetSessionVars().ActiveRoles, schema.Name.L, table.Name.L, "", mysql.AllPrivMask) { - // continue - // } - // do something. - // ``` - // So once the privilege manager is nil, it's a signature of internal sql, so just passing the checker through. 
- return true - } - return pm.RequestVerification(ctx.GetSessionVars().ActiveRoles, "", "", "", priv) -} - -func (e *memtableRetriever) setDataForVariablesInfo(ctx sessionctx.Context) error { - sysVars := variable.GetSysVars() - rows := make([][]types.Datum, 0, len(sysVars)) - for _, sv := range sysVars { - if infoschema.SysVarHiddenForSem(ctx, sv.Name) { - continue - } - currentVal, err := ctx.GetSessionVars().GetSessionOrGlobalSystemVar(context.Background(), sv.Name) - if err != nil { - currentVal = "" - } - isNoop := "NO" - if sv.IsNoop { - isNoop = "YES" - } - defVal := sv.Value - if sv.HasGlobalScope() { - defVal = variable.GlobalSystemVariableInitialValue(sv.Name, defVal) - } - row := types.MakeDatums( - sv.Name, // VARIABLE_NAME - sv.Scope.String(), // VARIABLE_SCOPE - defVal, // DEFAULT_VALUE - currentVal, // CURRENT_VALUE - sv.MinValue, // MIN_VALUE - sv.MaxValue, // MAX_VALUE - nil, // POSSIBLE_VALUES - isNoop, // IS_NOOP - ) - // min and max value is only supported for numeric types - if !(sv.Type == variable.TypeUnsigned || sv.Type == variable.TypeInt || sv.Type == variable.TypeFloat) { - row[4].SetNull() - row[5].SetNull() - } - if sv.Type == variable.TypeEnum { - possibleValues := strings.Join(sv.PossibleValues, ",") - row[6].SetString(possibleValues, mysql.DefaultCollationName) - } - rows = append(rows, row) - } - e.rows = rows - return nil -} - -func (e *memtableRetriever) setDataForUserAttributes(ctx context.Context, sctx sessionctx.Context) error { - exec := sctx.GetRestrictedSQLExecutor() - chunkRows, _, err := exec.ExecRestrictedSQL(ctx, nil, `SELECT user, host, JSON_UNQUOTE(JSON_EXTRACT(user_attributes, '$.metadata')) FROM mysql.user`) - if err != nil { - return err - } - if len(chunkRows) == 0 { - return nil - } - rows := make([][]types.Datum, 0, len(chunkRows)) - for _, chunkRow := range chunkRows { - if chunkRow.Len() != 3 { - continue - } - user := chunkRow.GetString(0) - host := chunkRow.GetString(1) - // Compatible with results in MySQL - var attribute any - if attribute = chunkRow.GetString(2); attribute == "" { - attribute = nil - } - row := types.MakeDatums(user, host, attribute) - rows = append(rows, row) - } - - e.rows = rows - return nil -} - -func (e *memtableRetriever) setDataFromSchemata(ctx sessionctx.Context) error { - checker := privilege.GetPrivilegeManager(ctx) - ex, ok := e.extractor.(*plannercore.InfoSchemaSchemataExtractor) - if !ok { - return errors.Errorf("wrong extractor type: %T, expected InfoSchemaSchemataExtractor", e.extractor) - } - if ex.SkipRequest { - return nil - } - schemas := ex.ListSchemas(e.is) - rows := make([][]types.Datum, 0, len(schemas)) - - for _, schemaName := range schemas { - schema, _ := e.is.SchemaByName(schemaName) - charset := mysql.DefaultCharset - collation := mysql.DefaultCollationName - - if len(schema.Charset) > 0 { - charset = schema.Charset // Overwrite default - } - - if len(schema.Collate) > 0 { - collation = schema.Collate // Overwrite default - } - var policyName any - if schema.PlacementPolicyRef != nil { - policyName = schema.PlacementPolicyRef.Name.O - } - - if checker != nil && !checker.RequestVerification(ctx.GetSessionVars().ActiveRoles, schema.Name.L, "", "", mysql.AllPrivMask) { - continue - } - record := types.MakeDatums( - infoschema.CatalogVal, // CATALOG_NAME - schema.Name.O, // SCHEMA_NAME - charset, // DEFAULT_CHARACTER_SET_NAME - collation, // DEFAULT_COLLATION_NAME - nil, // SQL_PATH - policyName, // TIDB_PLACEMENT_POLICY_NAME - ) - rows = append(rows, record) - } - e.rows = rows - return 
nil -} - -func (e *memtableRetriever) setDataForStatistics(ctx context.Context, sctx sessionctx.Context) error { - checker := privilege.GetPrivilegeManager(sctx) - ex, ok := e.extractor.(*plannercore.InfoSchemaStatisticsExtractor) - if !ok { - return errors.Errorf("wrong extractor type: %T, expected InfoSchemaStatisticsExtractor", e.extractor) - } - if ex.SkipRequest { - return nil - } - schemas, tables, err := ex.ListSchemasAndTables(ctx, e.is) - if err != nil { - return errors.Trace(err) - } - for i, table := range tables { - schema := schemas[i] - if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, table.Name.L, "", mysql.AllPrivMask) { - continue - } - e.setDataForStatisticsInTable(schema, table, ex) - } - return nil -} - -func (e *memtableRetriever) setDataForStatisticsInTable(schema model.CIStr, table *model.TableInfo, extractor *plannercore.InfoSchemaStatisticsExtractor) { - var rows [][]types.Datum - if table.PKIsHandle { - if !extractor.Filter("index_name", "primary") { - for _, col := range table.Columns { - if mysql.HasPriKeyFlag(col.GetFlag()) { - record := types.MakeDatums( - infoschema.CatalogVal, // TABLE_CATALOG - schema.O, // TABLE_SCHEMA - table.Name.O, // TABLE_NAME - "0", // NON_UNIQUE - schema.O, // INDEX_SCHEMA - "PRIMARY", // INDEX_NAME - 1, // SEQ_IN_INDEX - col.Name.O, // COLUMN_NAME - "A", // COLLATION - 0, // CARDINALITY - nil, // SUB_PART - nil, // PACKED - "", // NULLABLE - "BTREE", // INDEX_TYPE - "", // COMMENT - "", // INDEX_COMMENT - "YES", // IS_VISIBLE - nil, // Expression - ) - rows = append(rows, record) - } - } - } - } - nameToCol := make(map[string]*model.ColumnInfo, len(table.Columns)) - for _, c := range table.Columns { - nameToCol[c.Name.L] = c - } - for _, index := range table.Indices { - if extractor.Filter("index_name", index.Name.L) { - continue - } - nonUnique := "1" - if index.Unique { - nonUnique = "0" - } - for i, key := range index.Columns { - col := nameToCol[key.Name.L] - nullable := "YES" - if mysql.HasNotNullFlag(col.GetFlag()) { - nullable = "" - } - - visible := "YES" - if index.Invisible { - visible = "NO" - } - - colName := col.Name.O - var expression any - expression = nil - tblCol := table.Columns[col.Offset] - if tblCol.Hidden { - colName = "NULL" - expression = tblCol.GeneratedExprString - } - - record := types.MakeDatums( - infoschema.CatalogVal, // TABLE_CATALOG - schema.O, // TABLE_SCHEMA - table.Name.O, // TABLE_NAME - nonUnique, // NON_UNIQUE - schema.O, // INDEX_SCHEMA - index.Name.O, // INDEX_NAME - i+1, // SEQ_IN_INDEX - colName, // COLUMN_NAME - "A", // COLLATION - 0, // CARDINALITY - nil, // SUB_PART - nil, // PACKED - nullable, // NULLABLE - "BTREE", // INDEX_TYPE - "", // COMMENT - index.Comment, // INDEX_COMMENT - visible, // IS_VISIBLE - expression, // Expression - ) - rows = append(rows, record) - } - } - e.rows = append(e.rows, rows...) 
-} - -func (e *memtableRetriever) setDataFromReferConst(ctx context.Context, sctx sessionctx.Context, schemas []model.CIStr) error { - checker := privilege.GetPrivilegeManager(sctx) - var rows [][]types.Datum - extractor, ok := e.extractor.(*plannercore.InfoSchemaBaseExtractor) - if ok && extractor.SkipRequest { - return nil - } - for _, schema := range schemas { - if ok && extractor.Filter("constraint_schema", schema.L) { - continue - } - tables, err := e.is.SchemaTableInfos(ctx, schema) - if err != nil { - return errors.Trace(err) - } - for _, table := range tables { - if ok && extractor.Filter("table_name", table.Name.L) { - continue - } - if !table.IsBaseTable() { - continue - } - if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, table.Name.L, "", mysql.AllPrivMask) { - continue - } - for _, fk := range table.ForeignKeys { - if ok && extractor.Filter("constraint_name", fk.Name.L) { - continue - } - updateRule, deleteRule := "NO ACTION", "NO ACTION" - if model.ReferOptionType(fk.OnUpdate) != 0 { - updateRule = model.ReferOptionType(fk.OnUpdate).String() - } - if model.ReferOptionType(fk.OnDelete) != 0 { - deleteRule = model.ReferOptionType(fk.OnDelete).String() - } - record := types.MakeDatums( - infoschema.CatalogVal, // CONSTRAINT_CATALOG - schema.O, // CONSTRAINT_SCHEMA - fk.Name.O, // CONSTRAINT_NAME - infoschema.CatalogVal, // UNIQUE_CONSTRAINT_CATALOG - schema.O, // UNIQUE_CONSTRAINT_SCHEMA - "PRIMARY", // UNIQUE_CONSTRAINT_NAME - "NONE", // MATCH_OPTION - updateRule, // UPDATE_RULE - deleteRule, // DELETE_RULE - table.Name.O, // TABLE_NAME - fk.RefTable.O, // REFERENCED_TABLE_NAME - ) - rows = append(rows, record) - } - } - } - e.rows = rows - return nil -} - -func (e *memtableRetriever) updateStatsCacheIfNeed() bool { - for _, col := range e.columns { - // only the following columns need stats cache. 
- if col.Name.O == "AVG_ROW_LENGTH" || col.Name.O == "DATA_LENGTH" || col.Name.O == "INDEX_LENGTH" || col.Name.O == "TABLE_ROWS" { - return true - } - } - return false -} - -func (e *memtableRetriever) setDataFromOneTable( - sctx sessionctx.Context, - loc *time.Location, - checker privilege.Manager, - schema model.CIStr, - table *model.TableInfo, - rows [][]types.Datum, - useStatsCache bool, -) ([][]types.Datum, error) { - collation := table.Collate - if collation == "" { - collation = mysql.DefaultCollationName - } - createTime := types.NewTime(types.FromGoTime(table.GetUpdateTime().In(loc)), mysql.TypeDatetime, types.DefaultFsp) - - createOptions := "" - - if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, table.Name.L, "", mysql.AllPrivMask) { - return rows, nil - } - pkType := "NONCLUSTERED" - if !table.IsView() { - if table.GetPartitionInfo() != nil { - createOptions = "partitioned" - } else if table.TableCacheStatusType == model.TableCacheStatusEnable { - createOptions = "cached=on" - } - var autoIncID any - hasAutoIncID, _ := infoschema.HasAutoIncrementColumn(table) - if hasAutoIncID { - autoIncID = getAutoIncrementID(e.is, sctx, table) - } - tableType := "BASE TABLE" - if util.IsSystemView(schema.L) { - tableType = "SYSTEM VIEW" - } - if table.IsSequence() { - tableType = "SEQUENCE" - } - if table.HasClusteredIndex() { - pkType = "CLUSTERED" - } - shardingInfo := infoschema.GetShardingInfo(schema, table) - var policyName any - if table.PlacementPolicyRef != nil { - policyName = table.PlacementPolicyRef.Name.O - } - - var rowCount, avgRowLength, dataLength, indexLength uint64 - if useStatsCache { - if table.GetPartitionInfo() == nil { - err := cache.TableRowStatsCache.UpdateByID(sctx, table.ID) - if err != nil { - return rows, err - } - } else { - // needs to update all partitions for partition table. 
- for _, pi := range table.GetPartitionInfo().Definitions { - err := cache.TableRowStatsCache.UpdateByID(sctx, pi.ID) - if err != nil { - return rows, err - } - } - } - rowCount, avgRowLength, dataLength, indexLength = cache.TableRowStatsCache.EstimateDataLength(table) - } - - record := types.MakeDatums( - infoschema.CatalogVal, // TABLE_CATALOG - schema.O, // TABLE_SCHEMA - table.Name.O, // TABLE_NAME - tableType, // TABLE_TYPE - "InnoDB", // ENGINE - uint64(10), // VERSION - "Compact", // ROW_FORMAT - rowCount, // TABLE_ROWS - avgRowLength, // AVG_ROW_LENGTH - dataLength, // DATA_LENGTH - uint64(0), // MAX_DATA_LENGTH - indexLength, // INDEX_LENGTH - uint64(0), // DATA_FREE - autoIncID, // AUTO_INCREMENT - createTime, // CREATE_TIME - nil, // UPDATE_TIME - nil, // CHECK_TIME - collation, // TABLE_COLLATION - nil, // CHECKSUM - createOptions, // CREATE_OPTIONS - table.Comment, // TABLE_COMMENT - table.ID, // TIDB_TABLE_ID - shardingInfo, // TIDB_ROW_ID_SHARDING_INFO - pkType, // TIDB_PK_TYPE - policyName, // TIDB_PLACEMENT_POLICY_NAME - ) - rows = append(rows, record) - } else { - record := types.MakeDatums( - infoschema.CatalogVal, // TABLE_CATALOG - schema.O, // TABLE_SCHEMA - table.Name.O, // TABLE_NAME - "VIEW", // TABLE_TYPE - nil, // ENGINE - nil, // VERSION - nil, // ROW_FORMAT - nil, // TABLE_ROWS - nil, // AVG_ROW_LENGTH - nil, // DATA_LENGTH - nil, // MAX_DATA_LENGTH - nil, // INDEX_LENGTH - nil, // DATA_FREE - nil, // AUTO_INCREMENT - createTime, // CREATE_TIME - nil, // UPDATE_TIME - nil, // CHECK_TIME - nil, // TABLE_COLLATION - nil, // CHECKSUM - nil, // CREATE_OPTIONS - "VIEW", // TABLE_COMMENT - table.ID, // TIDB_TABLE_ID - nil, // TIDB_ROW_ID_SHARDING_INFO - pkType, // TIDB_PK_TYPE - nil, // TIDB_PLACEMENT_POLICY_NAME - ) - rows = append(rows, record) - } - return rows, nil -} - -func (e *memtableRetriever) setDataFromTables(ctx context.Context, sctx sessionctx.Context) error { - useStatsCache := e.updateStatsCacheIfNeed() - checker := privilege.GetPrivilegeManager(sctx) - - var rows [][]types.Datum - loc := sctx.GetSessionVars().TimeZone - if loc == nil { - loc = time.Local - } - ex, ok := e.extractor.(*plannercore.InfoSchemaTablesExtractor) - if !ok { - return errors.Errorf("wrong extractor type: %T, expected InfoSchemaTablesExtractor", e.extractor) - } - if ex.SkipRequest { - return nil - } - - schemas, tables, err := ex.ListSchemasAndTables(ctx, e.is) - if err != nil { - return errors.Trace(err) - } - for i, table := range tables { - rows, err = e.setDataFromOneTable(sctx, loc, checker, schemas[i], table, rows, useStatsCache) - if err != nil { - return errors.Trace(err) - } - } - e.rows = rows - return nil -} - -// Data for inforation_schema.CHECK_CONSTRAINTS -// This is standards (ISO/IEC 9075-11) compliant and is compatible with the implementation in MySQL as well. 
-func (e *memtableRetriever) setDataFromCheckConstraints(ctx context.Context, sctx sessionctx.Context, schemas []model.CIStr) error { - var rows [][]types.Datum - checker := privilege.GetPrivilegeManager(sctx) - extractor, ok := e.extractor.(*plannercore.InfoSchemaBaseExtractor) - if ok && extractor.SkipRequest { - return nil - } - for _, schema := range schemas { - if ok && extractor.Filter("constraint_schema", schema.L) { - continue - } - tables, err := e.is.SchemaTableInfos(ctx, schema) - if err != nil { - return errors.Trace(err) - } - for _, table := range tables { - if len(table.Constraints) > 0 { - if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, table.Name.L, "", mysql.SelectPriv) { - continue - } - for _, constraint := range table.Constraints { - if constraint.State != model.StatePublic { - continue - } - if ok && extractor.Filter("constraint_name", constraint.Name.L) { - continue - } - record := types.MakeDatums( - infoschema.CatalogVal, // CONSTRAINT_CATALOG - schema.O, // CONSTRAINT_SCHEMA - constraint.Name.O, // CONSTRAINT_NAME - fmt.Sprintf("(%s)", constraint.ExprString), // CHECK_CLAUSE - ) - rows = append(rows, record) - } - } - } - } - e.rows = rows - return nil -} - -// Data for inforation_schema.TIDB_CHECK_CONSTRAINTS -// This has non-standard TiDB specific extensions. -func (e *memtableRetriever) setDataFromTiDBCheckConstraints(ctx context.Context, sctx sessionctx.Context, schemas []model.CIStr) error { - var rows [][]types.Datum - checker := privilege.GetPrivilegeManager(sctx) - extractor, ok := e.extractor.(*plannercore.InfoSchemaBaseExtractor) - if ok && extractor.SkipRequest { - return nil - } - for _, schema := range schemas { - if ok && extractor.Filter("constraint_schema", schema.L) { - continue - } - tables, err := e.is.SchemaTableInfos(ctx, schema) - if err != nil { - return errors.Trace(err) - } - for _, table := range tables { - if len(table.Constraints) > 0 { - if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, table.Name.L, "", mysql.SelectPriv) { - continue - } - for _, constraint := range table.Constraints { - if constraint.State != model.StatePublic { - continue - } - if ok && extractor.Filter("constraint_name", constraint.Name.L) { - continue - } - record := types.MakeDatums( - infoschema.CatalogVal, // CONSTRAINT_CATALOG - schema.O, // CONSTRAINT_SCHEMA - constraint.Name.O, // CONSTRAINT_NAME - fmt.Sprintf("(%s)", constraint.ExprString), // CHECK_CLAUSE - table.Name.O, // TABLE_NAME - table.ID, // TABLE_ID - ) - rows = append(rows, record) - } - } - } - } - e.rows = rows - return nil -} - -type hugeMemTableRetriever struct { - dummyCloser - extractor *plannercore.ColumnsTableExtractor - table *model.TableInfo - columns []*model.ColumnInfo - retrieved bool - initialized bool - rows [][]types.Datum - dbs []*model.DBInfo - curTables []*model.TableInfo - dbsIdx int - tblIdx int - viewMu syncutil.RWMutex - viewSchemaMap map[int64]*expression.Schema // table id to view schema - viewOutputNamesMap map[int64]types.NameSlice // table id to view output names - batch int - is infoschema.InfoSchema -} - -// retrieve implements the infoschemaRetriever interface -func (e *hugeMemTableRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { - if e.retrieved { - return nil, nil - } - - if !e.initialized { - e.is = sessiontxn.GetTxnManager(sctx).GetTxnInfoSchema() - dbs := e.is.AllSchemas() - slices.SortFunc(dbs, model.LessDBInfo) - e.dbs = dbs 
- e.initialized = true - e.rows = make([][]types.Datum, 0, 1024) - e.batch = 1024 - } - - var err error - if e.table.Name.O == infoschema.TableColumns { - err = e.setDataForColumns(ctx, sctx, e.extractor) - } - if err != nil { - return nil, err - } - e.retrieved = len(e.rows) == 0 - - return adjustColumns(e.rows, e.columns, e.table), nil -} - -func (e *hugeMemTableRetriever) setDataForColumns(ctx context.Context, sctx sessionctx.Context, extractor *plannercore.ColumnsTableExtractor) error { - checker := privilege.GetPrivilegeManager(sctx) - e.rows = e.rows[:0] - for ; e.dbsIdx < len(e.dbs); e.dbsIdx++ { - schema := e.dbs[e.dbsIdx] - var table *model.TableInfo - if len(e.curTables) == 0 { - tables, err := e.is.SchemaTableInfos(ctx, schema.Name) - if err != nil { - return errors.Trace(err) - } - e.curTables = tables - } - for e.tblIdx < len(e.curTables) { - table = e.curTables[e.tblIdx] - e.tblIdx++ - if e.setDataForColumnsWithOneTable(ctx, sctx, extractor, schema, table, checker) { - return nil - } - } - e.tblIdx = 0 - e.curTables = e.curTables[:0] - } - return nil -} - -func (e *hugeMemTableRetriever) setDataForColumnsWithOneTable( - ctx context.Context, - sctx sessionctx.Context, - extractor *plannercore.ColumnsTableExtractor, - schema *model.DBInfo, - table *model.TableInfo, - checker privilege.Manager) bool { - hasPrivs := false - var priv mysql.PrivilegeType - if checker != nil { - for _, p := range mysql.AllColumnPrivs { - if checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.Name.L, table.Name.L, "", p) { - hasPrivs = true - priv |= p - } - } - if !hasPrivs { - return false - } - } - - e.dataForColumnsInTable(ctx, sctx, schema, table, priv, extractor) - return len(e.rows) >= e.batch -} - -func (e *hugeMemTableRetriever) dataForColumnsInTable(ctx context.Context, sctx sessionctx.Context, schema *model.DBInfo, tbl *model.TableInfo, priv mysql.PrivilegeType, extractor *plannercore.ColumnsTableExtractor) { - if tbl.IsView() { - e.viewMu.Lock() - _, ok := e.viewSchemaMap[tbl.ID] - if !ok { - var viewLogicalPlan base.Plan - internalCtx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnOthers) - // Build plan is not thread safe, there will be concurrency on sessionctx. 
- if err := runWithSystemSession(internalCtx, sctx, func(s sessionctx.Context) error { - is := sessiontxn.GetTxnManager(s).GetTxnInfoSchema() - planBuilder, _ := plannercore.NewPlanBuilder().Init(s.GetPlanCtx(), is, hint.NewQBHintHandler(nil)) - var err error - viewLogicalPlan, err = planBuilder.BuildDataSourceFromView(ctx, schema.Name, tbl, nil, nil) - return errors.Trace(err) - }); err != nil { - sctx.GetSessionVars().StmtCtx.AppendWarning(err) - e.viewMu.Unlock() - return - } - e.viewSchemaMap[tbl.ID] = viewLogicalPlan.Schema() - e.viewOutputNamesMap[tbl.ID] = viewLogicalPlan.OutputNames() - } - e.viewMu.Unlock() - } - - var tableSchemaRegexp, tableNameRegexp, columnsRegexp []collate.WildcardPattern - var tableSchemaFilterEnable, - tableNameFilterEnable, columnsFilterEnable bool - if !extractor.SkipRequest { - tableSchemaFilterEnable = extractor.TableSchema.Count() > 0 - tableNameFilterEnable = extractor.TableName.Count() > 0 - columnsFilterEnable = extractor.ColumnName.Count() > 0 - if len(extractor.TableSchemaPatterns) > 0 { - tableSchemaRegexp = make([]collate.WildcardPattern, len(extractor.TableSchemaPatterns)) - for i, pattern := range extractor.TableSchemaPatterns { - tableSchemaRegexp[i] = collate.GetCollatorByID(collate.CollationName2ID(mysql.UTF8MB4DefaultCollation)).Pattern() - tableSchemaRegexp[i].Compile(pattern, byte('\\')) - } - } - if len(extractor.TableNamePatterns) > 0 { - tableNameRegexp = make([]collate.WildcardPattern, len(extractor.TableNamePatterns)) - for i, pattern := range extractor.TableNamePatterns { - tableNameRegexp[i] = collate.GetCollatorByID(collate.CollationName2ID(mysql.UTF8MB4DefaultCollation)).Pattern() - tableNameRegexp[i].Compile(pattern, byte('\\')) - } - } - if len(extractor.ColumnNamePatterns) > 0 { - columnsRegexp = make([]collate.WildcardPattern, len(extractor.ColumnNamePatterns)) - for i, pattern := range extractor.ColumnNamePatterns { - columnsRegexp[i] = collate.GetCollatorByID(collate.CollationName2ID(mysql.UTF8MB4DefaultCollation)).Pattern() - columnsRegexp[i].Compile(pattern, byte('\\')) - } - } - } - i := 0 -ForColumnsTag: - for _, col := range tbl.Columns { - if col.Hidden { - continue - } - i++ - ft := &(col.FieldType) - if tbl.IsView() { - e.viewMu.RLock() - if e.viewSchemaMap[tbl.ID] != nil { - // If this is a view, replace the column with the view column. 
- idx := expression.FindFieldNameIdxByColName(e.viewOutputNamesMap[tbl.ID], col.Name.L) - if idx >= 0 { - col1 := e.viewSchemaMap[tbl.ID].Columns[idx] - ft = col1.GetType(sctx.GetExprCtx().GetEvalCtx()) - } - } - e.viewMu.RUnlock() - } - if !extractor.SkipRequest { - if tableSchemaFilterEnable && !extractor.TableSchema.Exist(schema.Name.L) { - continue - } - if tableNameFilterEnable && !extractor.TableName.Exist(tbl.Name.L) { - continue - } - if columnsFilterEnable && !extractor.ColumnName.Exist(col.Name.L) { - continue - } - for _, re := range tableSchemaRegexp { - if !re.DoMatch(schema.Name.L) { - continue ForColumnsTag - } - } - for _, re := range tableNameRegexp { - if !re.DoMatch(tbl.Name.L) { - continue ForColumnsTag - } - } - for _, re := range columnsRegexp { - if !re.DoMatch(col.Name.L) { - continue ForColumnsTag - } - } - } - - var charMaxLen, charOctLen, numericPrecision, numericScale, datetimePrecision any - colLen, decimal := ft.GetFlen(), ft.GetDecimal() - defaultFlen, defaultDecimal := mysql.GetDefaultFieldLengthAndDecimal(ft.GetType()) - if decimal == types.UnspecifiedLength { - decimal = defaultDecimal - } - if colLen == types.UnspecifiedLength { - colLen = defaultFlen - } - if ft.GetType() == mysql.TypeSet { - // Example: In MySQL set('a','bc','def','ghij') has length 13, because - // len('a')+len('bc')+len('def')+len('ghij')+len(ThreeComma)=13 - // Reference link: https://bugs.mysql.com/bug.php?id=22613 - colLen = 0 - for _, ele := range ft.GetElems() { - colLen += len(ele) - } - if len(ft.GetElems()) != 0 { - colLen += (len(ft.GetElems()) - 1) - } - charMaxLen = colLen - charOctLen = calcCharOctLength(colLen, ft.GetCharset()) - } else if ft.GetType() == mysql.TypeEnum { - // Example: In MySQL enum('a', 'ab', 'cdef') has length 4, because - // the longest string in the enum is 'cdef' - // Reference link: https://bugs.mysql.com/bug.php?id=22613 - colLen = 0 - for _, ele := range ft.GetElems() { - if len(ele) > colLen { - colLen = len(ele) - } - } - charMaxLen = colLen - charOctLen = calcCharOctLength(colLen, ft.GetCharset()) - } else if types.IsString(ft.GetType()) { - charMaxLen = colLen - charOctLen = calcCharOctLength(colLen, ft.GetCharset()) - } else if types.IsTypeFractionable(ft.GetType()) { - datetimePrecision = decimal - } else if types.IsTypeNumeric(ft.GetType()) { - numericPrecision = colLen - if ft.GetType() != mysql.TypeFloat && ft.GetType() != mysql.TypeDouble { - numericScale = decimal - } else if decimal != -1 { - numericScale = decimal - } - } else if ft.GetType() == mysql.TypeNull { - charMaxLen, charOctLen = 0, 0 - } - columnType := ft.InfoSchemaStr() - columnDesc := table.NewColDesc(table.ToColumn(col)) - var columnDefault any - if columnDesc.DefaultValue != nil { - columnDefault = fmt.Sprintf("%v", columnDesc.DefaultValue) - switch col.GetDefaultValue() { - case "CURRENT_TIMESTAMP": - default: - if ft.GetType() == mysql.TypeTimestamp && columnDefault != types.ZeroDatetimeStr { - timeValue, err := table.GetColDefaultValue(sctx.GetExprCtx(), col) - if err == nil { - columnDefault = timeValue.GetMysqlTime().String() - } - } - if ft.GetType() == mysql.TypeBit && !col.DefaultIsExpr { - defaultValBinaryLiteral := types.BinaryLiteral(columnDefault.(string)) - columnDefault = defaultValBinaryLiteral.ToBitLiteralString(true) - } - } - } - colType := ft.GetType() - if colType == mysql.TypeVarString { - colType = mysql.TypeVarchar - } - record := types.MakeDatums( - infoschema.CatalogVal, // TABLE_CATALOG - schema.Name.O, // TABLE_SCHEMA - tbl.Name.O, // 
TABLE_NAME - col.Name.O, // COLUMN_NAME - i, // ORDINAL_POSITION - columnDefault, // COLUMN_DEFAULT - columnDesc.Null, // IS_NULLABLE - types.TypeToStr(colType, ft.GetCharset()), // DATA_TYPE - charMaxLen, // CHARACTER_MAXIMUM_LENGTH - charOctLen, // CHARACTER_OCTET_LENGTH - numericPrecision, // NUMERIC_PRECISION - numericScale, // NUMERIC_SCALE - datetimePrecision, // DATETIME_PRECISION - columnDesc.Charset, // CHARACTER_SET_NAME - columnDesc.Collation, // COLLATION_NAME - columnType, // COLUMN_TYPE - columnDesc.Key, // COLUMN_KEY - columnDesc.Extra, // EXTRA - strings.ToLower(privileges.PrivToString(priv, mysql.AllColumnPrivs, mysql.Priv2Str)), // PRIVILEGES - columnDesc.Comment, // COLUMN_COMMENT - col.GeneratedExprString, // GENERATION_EXPRESSION - ) - e.rows = append(e.rows, record) - } -} - -func calcCharOctLength(lenInChar int, cs string) int { - lenInBytes := lenInChar - if desc, err := charset.GetCharsetInfo(cs); err == nil { - lenInBytes = desc.Maxlen * lenInChar - } - return lenInBytes -} - -func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sessionctx.Context) error { - useStatsCache := e.updateStatsCacheIfNeed() - checker := privilege.GetPrivilegeManager(sctx) - var rows [][]types.Datum - createTimeTp := mysql.TypeDatetime - - ex, ok := e.extractor.(*plannercore.InfoSchemaPartitionsExtractor) - if !ok { - return errors.Errorf("wrong extractor type: %T, expected InfoSchemaPartitionsExtractor", e.extractor) - } - if ex.SkipRequest { - return nil - } - schemas, tables, err := ex.ListSchemasAndTables(ctx, e.is) - if err != nil { - return errors.Trace(err) - } - for i, table := range tables { - schema := schemas[i] - if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, table.Name.L, "", mysql.SelectPriv) { - continue - } - createTime := types.NewTime(types.FromGoTime(table.GetUpdateTime()), createTimeTp, types.DefaultFsp) - - var rowCount, dataLength, indexLength uint64 - if useStatsCache { - if table.GetPartitionInfo() == nil { - err := cache.TableRowStatsCache.UpdateByID(sctx, table.ID) - if err != nil { - return err - } - } else { - // needs to update needed partitions for partition table. 
- for _, pi := range table.GetPartitionInfo().Definitions { - if ex.Filter("partition_name", pi.Name.L) { - continue - } - err := cache.TableRowStatsCache.UpdateByID(sctx, pi.ID) - if err != nil { - return err - } - } - } - } - if table.GetPartitionInfo() == nil { - rowCount = cache.TableRowStatsCache.GetTableRows(table.ID) - dataLength, indexLength = cache.TableRowStatsCache.GetDataAndIndexLength(table, table.ID, rowCount) - avgRowLength := uint64(0) - if rowCount != 0 { - avgRowLength = dataLength / rowCount - } - record := types.MakeDatums( - infoschema.CatalogVal, // TABLE_CATALOG - schema.O, // TABLE_SCHEMA - table.Name.O, // TABLE_NAME - nil, // PARTITION_NAME - nil, // SUBPARTITION_NAME - nil, // PARTITION_ORDINAL_POSITION - nil, // SUBPARTITION_ORDINAL_POSITION - nil, // PARTITION_METHOD - nil, // SUBPARTITION_METHOD - nil, // PARTITION_EXPRESSION - nil, // SUBPARTITION_EXPRESSION - nil, // PARTITION_DESCRIPTION - rowCount, // TABLE_ROWS - avgRowLength, // AVG_ROW_LENGTH - dataLength, // DATA_LENGTH - nil, // MAX_DATA_LENGTH - indexLength, // INDEX_LENGTH - nil, // DATA_FREE - createTime, // CREATE_TIME - nil, // UPDATE_TIME - nil, // CHECK_TIME - nil, // CHECKSUM - nil, // PARTITION_COMMENT - nil, // NODEGROUP - nil, // TABLESPACE_NAME - nil, // TIDB_PARTITION_ID - nil, // TIDB_PLACEMENT_POLICY_NAME - ) - rows = append(rows, record) - } else { - for i, pi := range table.GetPartitionInfo().Definitions { - if ex.Filter("partition_name", pi.Name.L) { - continue - } - rowCount = cache.TableRowStatsCache.GetTableRows(pi.ID) - dataLength, indexLength = cache.TableRowStatsCache.GetDataAndIndexLength(table, pi.ID, rowCount) - avgRowLength := uint64(0) - if rowCount != 0 { - avgRowLength = dataLength / rowCount - } - - var partitionDesc string - if table.Partition.Type == model.PartitionTypeRange { - partitionDesc = strings.Join(pi.LessThan, ",") - } else if table.Partition.Type == model.PartitionTypeList { - if len(pi.InValues) > 0 { - buf := bytes.NewBuffer(nil) - for i, vs := range pi.InValues { - if i > 0 { - buf.WriteString(",") - } - if len(vs) != 1 { - buf.WriteString("(") - } - buf.WriteString(strings.Join(vs, ",")) - if len(vs) != 1 { - buf.WriteString(")") - } - } - partitionDesc = buf.String() - } - } - - partitionMethod := table.Partition.Type.String() - partitionExpr := table.Partition.Expr - if len(table.Partition.Columns) > 0 { - switch table.Partition.Type { - case model.PartitionTypeRange: - partitionMethod = "RANGE COLUMNS" - case model.PartitionTypeList: - partitionMethod = "LIST COLUMNS" - case model.PartitionTypeKey: - partitionMethod = "KEY" - default: - return errors.Errorf("Inconsistent partition type, have type %v, but with COLUMNS > 0 (%d)", table.Partition.Type, len(table.Partition.Columns)) - } - buf := bytes.NewBuffer(nil) - for i, col := range table.Partition.Columns { - if i > 0 { - buf.WriteString(",") - } - buf.WriteString("`") - buf.WriteString(col.String()) - buf.WriteString("`") - } - partitionExpr = buf.String() - } - - var policyName any - if pi.PlacementPolicyRef != nil { - policyName = pi.PlacementPolicyRef.Name.O - } - record := types.MakeDatums( - infoschema.CatalogVal, // TABLE_CATALOG - schema.O, // TABLE_SCHEMA - table.Name.O, // TABLE_NAME - pi.Name.O, // PARTITION_NAME - nil, // SUBPARTITION_NAME - i+1, // PARTITION_ORDINAL_POSITION - nil, // SUBPARTITION_ORDINAL_POSITION - partitionMethod, // PARTITION_METHOD - nil, // SUBPARTITION_METHOD - partitionExpr, // PARTITION_EXPRESSION - nil, // SUBPARTITION_EXPRESSION - partitionDesc, // 
PARTITION_DESCRIPTION - rowCount, // TABLE_ROWS - avgRowLength, // AVG_ROW_LENGTH - dataLength, // DATA_LENGTH - uint64(0), // MAX_DATA_LENGTH - indexLength, // INDEX_LENGTH - uint64(0), // DATA_FREE - createTime, // CREATE_TIME - nil, // UPDATE_TIME - nil, // CHECK_TIME - nil, // CHECKSUM - pi.Comment, // PARTITION_COMMENT - nil, // NODEGROUP - nil, // TABLESPACE_NAME - pi.ID, // TIDB_PARTITION_ID - policyName, // TIDB_PLACEMENT_POLICY_NAME - ) - rows = append(rows, record) - } - } - } - e.rows = rows - return nil -} - -func (e *memtableRetriever) setDataFromIndexes(ctx context.Context, sctx sessionctx.Context, schemas []model.CIStr) error { - checker := privilege.GetPrivilegeManager(sctx) - extractor, ok := e.extractor.(*plannercore.InfoSchemaBaseExtractor) - if ok && extractor.SkipRequest { - return nil - } - var rows [][]types.Datum - for _, schema := range schemas { - if ok && extractor.Filter("table_schema", schema.L) { - continue - } - tables, err := e.is.SchemaTableInfos(ctx, schema) - if err != nil { - return errors.Trace(err) - } - for _, tb := range tables { - if ok && extractor.Filter("table_name", tb.Name.L) { - continue - } - if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, tb.Name.L, "", mysql.AllPrivMask) { - continue - } - - if tb.PKIsHandle { - var pkCol *model.ColumnInfo - for _, col := range tb.Cols() { - if mysql.HasPriKeyFlag(col.GetFlag()) { - pkCol = col - break - } - } - record := types.MakeDatums( - schema.O, // TABLE_SCHEMA - tb.Name.O, // TABLE_NAME - 0, // NON_UNIQUE - "PRIMARY", // KEY_NAME - 1, // SEQ_IN_INDEX - pkCol.Name.O, // COLUMN_NAME - nil, // SUB_PART - "", // INDEX_COMMENT - nil, // Expression - 0, // INDEX_ID - "YES", // IS_VISIBLE - "YES", // CLUSTERED - 0, // IS_GLOBAL - ) - rows = append(rows, record) - } - for _, idxInfo := range tb.Indices { - if idxInfo.State != model.StatePublic { - continue - } - isClustered := "NO" - if tb.IsCommonHandle && idxInfo.Primary { - isClustered = "YES" - } - for i, col := range idxInfo.Columns { - nonUniq := 1 - if idxInfo.Unique { - nonUniq = 0 - } - var subPart any - if col.Length != types.UnspecifiedLength { - subPart = col.Length - } - colName := col.Name.O - var expression any - expression = nil - tblCol := tb.Columns[col.Offset] - if tblCol.Hidden { - colName = "NULL" - expression = tblCol.GeneratedExprString - } - visible := "YES" - if idxInfo.Invisible { - visible = "NO" - } - record := types.MakeDatums( - schema.O, // TABLE_SCHEMA - tb.Name.O, // TABLE_NAME - nonUniq, // NON_UNIQUE - idxInfo.Name.O, // KEY_NAME - i+1, // SEQ_IN_INDEX - colName, // COLUMN_NAME - subPart, // SUB_PART - idxInfo.Comment, // INDEX_COMMENT - expression, // Expression - idxInfo.ID, // INDEX_ID - visible, // IS_VISIBLE - isClustered, // CLUSTERED - idxInfo.Global, // IS_GLOBAL - ) - rows = append(rows, record) - } - } - } - } - e.rows = rows - return nil -} - -func (e *memtableRetriever) setDataFromViews(ctx context.Context, sctx sessionctx.Context, schemas []model.CIStr) error { - checker := privilege.GetPrivilegeManager(sctx) - extractor, ok := e.extractor.(*plannercore.InfoSchemaBaseExtractor) - if ok && extractor.SkipRequest { - return nil - } - var rows [][]types.Datum - for _, schema := range schemas { - if ok && extractor.Filter("table_schema", schema.L) { - continue - } - tables, err := e.is.SchemaTableInfos(ctx, schema) - if err != nil { - return errors.Trace(err) - } - for _, table := range tables { - if ok && extractor.Filter("table_name", table.Name.L) { - continue - } - 
if !table.IsView() { - continue - } - collation := table.Collate - charset := table.Charset - if collation == "" { - collation = mysql.DefaultCollationName - } - if charset == "" { - charset = mysql.DefaultCharset - } - if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, table.Name.L, "", mysql.AllPrivMask) { - continue - } - record := types.MakeDatums( - infoschema.CatalogVal, // TABLE_CATALOG - schema.O, // TABLE_SCHEMA - table.Name.O, // TABLE_NAME - table.View.SelectStmt, // VIEW_DEFINITION - table.View.CheckOption.String(), // CHECK_OPTION - "NO", // IS_UPDATABLE - table.View.Definer.String(), // DEFINER - table.View.Security.String(), // SECURITY_TYPE - charset, // CHARACTER_SET_CLIENT - collation, // COLLATION_CONNECTION - ) - rows = append(rows, record) - } - } - e.rows = rows - return nil -} - -func (e *memtableRetriever) dataForTiKVStoreStatus(ctx context.Context, sctx sessionctx.Context) (err error) { - tikvStore, ok := sctx.GetStore().(helper.Storage) - if !ok { - return errors.New("Information about TiKV store status can be gotten only when the storage is TiKV") - } - tikvHelper := &helper.Helper{ - Store: tikvStore, - RegionCache: tikvStore.GetRegionCache(), - } - pdCli, err := tikvHelper.TryGetPDHTTPClient() - if err != nil { - return err - } - storesStat, err := pdCli.GetStores(ctx) - if err != nil { - return err - } - for _, storeStat := range storesStat.Stores { - row := make([]types.Datum, len(infoschema.TableTiKVStoreStatusCols)) - row[0].SetInt64(storeStat.Store.ID) - row[1].SetString(storeStat.Store.Address, mysql.DefaultCollationName) - row[2].SetInt64(storeStat.Store.State) - row[3].SetString(storeStat.Store.StateName, mysql.DefaultCollationName) - data, err := json.Marshal(storeStat.Store.Labels) - if err != nil { - return err - } - bj := types.BinaryJSON{} - if err = bj.UnmarshalJSON(data); err != nil { - return err - } - row[4].SetMysqlJSON(bj) - row[5].SetString(storeStat.Store.Version, mysql.DefaultCollationName) - row[6].SetString(storeStat.Status.Capacity, mysql.DefaultCollationName) - row[7].SetString(storeStat.Status.Available, mysql.DefaultCollationName) - row[8].SetInt64(storeStat.Status.LeaderCount) - row[9].SetFloat64(storeStat.Status.LeaderWeight) - row[10].SetFloat64(storeStat.Status.LeaderScore) - row[11].SetInt64(storeStat.Status.LeaderSize) - row[12].SetInt64(storeStat.Status.RegionCount) - row[13].SetFloat64(storeStat.Status.RegionWeight) - row[14].SetFloat64(storeStat.Status.RegionScore) - row[15].SetInt64(storeStat.Status.RegionSize) - startTs := types.NewTime(types.FromGoTime(storeStat.Status.StartTS), mysql.TypeDatetime, types.DefaultFsp) - row[16].SetMysqlTime(startTs) - lastHeartbeatTs := types.NewTime(types.FromGoTime(storeStat.Status.LastHeartbeatTS), mysql.TypeDatetime, types.DefaultFsp) - row[17].SetMysqlTime(lastHeartbeatTs) - row[18].SetString(storeStat.Status.Uptime, mysql.DefaultCollationName) - if sem.IsEnabled() { - // Patch out IP addresses etc if the user does not have the RESTRICTED_TABLES_ADMIN privilege - checker := privilege.GetPrivilegeManager(sctx) - if checker == nil || !checker.RequestDynamicVerification(sctx.GetSessionVars().ActiveRoles, "RESTRICTED_TABLES_ADMIN", false) { - row[1].SetString(strconv.FormatInt(storeStat.Store.ID, 10), mysql.DefaultCollationName) - row[1].SetNull() - row[6].SetNull() - row[7].SetNull() - row[16].SetNull() - row[18].SetNull() - } - } - e.rows = append(e.rows, row) - } - return nil -} - -// DDLJobsReaderExec executes DDLJobs information retrieving. 
-type DDLJobsReaderExec struct { - exec.BaseExecutor - DDLJobRetriever - - cacheJobs []*model.Job - is infoschema.InfoSchema - sess sessionctx.Context -} - -// Open implements the Executor Next interface. -func (e *DDLJobsReaderExec) Open(ctx context.Context) error { - if err := e.BaseExecutor.Open(ctx); err != nil { - return err - } - e.DDLJobRetriever.is = e.is - e.activeRoles = e.Ctx().GetSessionVars().ActiveRoles - sess, err := e.GetSysSession() - if err != nil { - return err - } - e.sess = sess - err = sessiontxn.NewTxn(context.Background(), sess) - if err != nil { - return err - } - txn, err := sess.Txn(true) - if err != nil { - return err - } - sess.GetSessionVars().SetInTxn(true) - err = e.DDLJobRetriever.initial(txn, sess) - if err != nil { - return err - } - return nil -} - -// Next implements the Executor Next interface. -func (e *DDLJobsReaderExec) Next(_ context.Context, req *chunk.Chunk) error { - req.GrowAndReset(e.MaxChunkSize()) - checker := privilege.GetPrivilegeManager(e.Ctx()) - count := 0 - - // Append running DDL jobs. - if e.cursor < len(e.runningJobs) { - num := min(req.Capacity(), len(e.runningJobs)-e.cursor) - for i := e.cursor; i < e.cursor+num; i++ { - e.appendJobToChunk(req, e.runningJobs[i], checker) - req.AppendString(12, e.runningJobs[i].Query) - if e.runningJobs[i].MultiSchemaInfo != nil { - for range e.runningJobs[i].MultiSchemaInfo.SubJobs { - req.AppendString(12, e.runningJobs[i].Query) - } - } - } - e.cursor += num - count += num - } - var err error - - // Append history DDL jobs. - if count < req.Capacity() { - e.cacheJobs, err = e.historyJobIter.GetLastJobs(req.Capacity()-count, e.cacheJobs) - if err != nil { - return err - } - for _, job := range e.cacheJobs { - e.appendJobToChunk(req, job, checker) - req.AppendString(12, job.Query) - if job.MultiSchemaInfo != nil { - for range job.MultiSchemaInfo.SubJobs { - req.AppendString(12, job.Query) - } - } - } - e.cursor += len(e.cacheJobs) - } - return nil -} - -// Close implements the Executor Close interface. 
-func (e *DDLJobsReaderExec) Close() error { - e.ReleaseSysSession(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), e.sess) - return e.BaseExecutor.Close() -} - -func (e *memtableRetriever) setDataFromEngines() { - var rows [][]types.Datum - rows = append(rows, - types.MakeDatums( - "InnoDB", // Engine - "DEFAULT", // Support - "Supports transactions, row-level locking, and foreign keys", // Comment - "YES", // Transactions - "YES", // XA - "YES", // Savepoints - ), - ) - e.rows = rows -} - -func (e *memtableRetriever) setDataFromCharacterSets() { - charsets := charset.GetSupportedCharsets() - var rows = make([][]types.Datum, 0, len(charsets)) - for _, charset := range charsets { - rows = append(rows, - types.MakeDatums(charset.Name, charset.DefaultCollation, charset.Desc, charset.Maxlen), - ) - } - e.rows = rows -} - -func (e *memtableRetriever) setDataFromCollations() { - collations := collate.GetSupportedCollations() - var rows = make([][]types.Datum, 0, len(collations)) - for _, collation := range collations { - isDefault := "" - if collation.IsDefault { - isDefault = "Yes" - } - rows = append(rows, - types.MakeDatums(collation.Name, collation.CharsetName, collation.ID, - isDefault, "Yes", collation.Sortlen, collation.PadAttribute), - ) - } - e.rows = rows -} - -func (e *memtableRetriever) dataForCollationCharacterSetApplicability() { - collations := collate.GetSupportedCollations() - var rows = make([][]types.Datum, 0, len(collations)) - for _, collation := range collations { - rows = append(rows, - types.MakeDatums(collation.Name, collation.CharsetName), - ) - } - e.rows = rows -} - -func (e *memtableRetriever) dataForTiDBClusterInfo(ctx sessionctx.Context) error { - servers, err := infoschema.GetClusterServerInfo(ctx) - if err != nil { - e.rows = nil - return err - } - rows := make([][]types.Datum, 0, len(servers)) - for _, server := range servers { - upTimeStr := "" - startTimeNative := types.NewTime(types.FromGoTime(time.Now()), mysql.TypeDatetime, 0) - if server.StartTimestamp > 0 { - startTime := time.Unix(server.StartTimestamp, 0) - startTimeNative = types.NewTime(types.FromGoTime(startTime), mysql.TypeDatetime, 0) - upTimeStr = time.Since(startTime).String() - } - serverType := server.ServerType - if server.ServerType == kv.TiFlash.Name() && server.EngineRole == placement.EngineRoleLabelWrite { - serverType = infoschema.TiFlashWrite - } - row := types.MakeDatums( - serverType, - server.Address, - server.StatusAddr, - server.Version, - server.GitHash, - startTimeNative, - upTimeStr, - server.ServerID, - ) - if sem.IsEnabled() { - checker := privilege.GetPrivilegeManager(ctx) - if checker == nil || !checker.RequestDynamicVerification(ctx.GetSessionVars().ActiveRoles, "RESTRICTED_TABLES_ADMIN", false) { - row[1].SetString(strconv.FormatUint(server.ServerID, 10), mysql.DefaultCollationName) - row[2].SetNull() - row[5].SetNull() - row[6].SetNull() - } - } - rows = append(rows, row) - } - e.rows = rows - return nil -} - -func (e *memtableRetriever) setDataFromKeyColumnUsage(ctx context.Context, sctx sessionctx.Context, schemas []model.CIStr) error { - checker := privilege.GetPrivilegeManager(sctx) - rows := make([][]types.Datum, 0, len(schemas)) // The capacity is not accurate, but it is not a big problem. - extractor, ok := e.extractor.(*plannercore.InfoSchemaBaseExtractor) - if ok && extractor.SkipRequest { - return nil - } - for _, schema := range schemas { - // `constraint_schema` and `table_schema` are always the same in MySQL. 
- if ok && extractor.Filter("constraint_schema", schema.L) { - continue - } - if ok && extractor.Filter("table_schema", schema.L) { - continue - } - tables, err := e.is.SchemaTableInfos(ctx, schema) - if err != nil { - return errors.Trace(err) - } - for _, table := range tables { - if ok && extractor.Filter("table_name", table.Name.L) { - continue - } - if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, table.Name.L, "", mysql.AllPrivMask) { - continue - } - rs := keyColumnUsageInTable(schema, table, extractor) - rows = append(rows, rs...) - } - } - e.rows = rows - return nil -} - -func (e *memtableRetriever) setDataForClusterProcessList(ctx sessionctx.Context) error { - e.setDataForProcessList(ctx) - rows, err := infoschema.AppendHostInfoToRows(ctx, e.rows) - if err != nil { - return err - } - e.rows = rows - return nil -} - -func (e *memtableRetriever) setDataForProcessList(ctx sessionctx.Context) { - sm := ctx.GetSessionManager() - if sm == nil { - return - } - - loginUser := ctx.GetSessionVars().User - hasProcessPriv := hasPriv(ctx, mysql.ProcessPriv) - pl := sm.ShowProcessList() - - records := make([][]types.Datum, 0, len(pl)) - for _, pi := range pl { - // If you have the PROCESS privilege, you can see all threads. - // Otherwise, you can see only your own threads. - if !hasProcessPriv && loginUser != nil && pi.User != loginUser.Username { - continue - } - - rows := pi.ToRow(ctx.GetSessionVars().StmtCtx.TimeZone()) - record := types.MakeDatums(rows...) - records = append(records, record) - } - e.rows = records -} - -func (e *memtableRetriever) setDataFromUserPrivileges(ctx sessionctx.Context) { - pm := privilege.GetPrivilegeManager(ctx) - // The results depend on the user querying the information. - e.rows = pm.UserPrivilegesTable(ctx.GetSessionVars().ActiveRoles, ctx.GetSessionVars().User.Username, ctx.GetSessionVars().User.Hostname) -} - -func (e *memtableRetriever) setDataForMetricTables() { - tables := make([]string, 0, len(infoschema.MetricTableMap)) - for name := range infoschema.MetricTableMap { - tables = append(tables, name) - } - slices.Sort(tables) - rows := make([][]types.Datum, 0, len(tables)) - for _, name := range tables { - schema := infoschema.MetricTableMap[name] - record := types.MakeDatums( - name, // METRICS_NAME - schema.PromQL, // PROMQL - strings.Join(schema.Labels, ","), // LABELS - schema.Quantile, // QUANTILE - schema.Comment, // COMMENT - ) - rows = append(rows, record) - } - e.rows = rows -} - -func keyColumnUsageInTable(schema model.CIStr, table *model.TableInfo, extractor *plannercore.InfoSchemaBaseExtractor) [][]types.Datum { - var rows [][]types.Datum - if table.PKIsHandle { - for _, col := range table.Columns { - if mysql.HasPriKeyFlag(col.GetFlag()) { - record := types.MakeDatums( - infoschema.CatalogVal, // CONSTRAINT_CATALOG - schema.O, // CONSTRAINT_SCHEMA - infoschema.PrimaryConstraint, // CONSTRAINT_NAME - infoschema.CatalogVal, // TABLE_CATALOG - schema.O, // TABLE_SCHEMA - table.Name.O, // TABLE_NAME - col.Name.O, // COLUMN_NAME - 1, // ORDINAL_POSITION - 1, // POSITION_IN_UNIQUE_CONSTRAINT - nil, // REFERENCED_TABLE_SCHEMA - nil, // REFERENCED_TABLE_NAME - nil, // REFERENCED_COLUMN_NAME - ) - rows = append(rows, record) - break - } - } - } - nameToCol := make(map[string]*model.ColumnInfo, len(table.Columns)) - for _, c := range table.Columns { - nameToCol[c.Name.L] = c - } - for _, index := range table.Indices { - var idxName string - if index.Primary { - idxName = infoschema.PrimaryConstraint - } 
else if index.Unique { - idxName = index.Name.O - } else { - // Only handle unique/primary key - continue - } - - if extractor != nil && extractor.Filter("constraint_name", idxName) { - continue - } - - for i, key := range index.Columns { - col := nameToCol[key.Name.L] - if col.Hidden { - continue - } - record := types.MakeDatums( - infoschema.CatalogVal, // CONSTRAINT_CATALOG - schema.O, // CONSTRAINT_SCHEMA - idxName, // CONSTRAINT_NAME - infoschema.CatalogVal, // TABLE_CATALOG - schema.O, // TABLE_SCHEMA - table.Name.O, // TABLE_NAME - col.Name.O, // COLUMN_NAME - i+1, // ORDINAL_POSITION, - nil, // POSITION_IN_UNIQUE_CONSTRAINT - nil, // REFERENCED_TABLE_SCHEMA - nil, // REFERENCED_TABLE_NAME - nil, // REFERENCED_COLUMN_NAME - ) - rows = append(rows, record) - } - } - for _, fk := range table.ForeignKeys { - for i, key := range fk.Cols { - fkRefCol := "" - if len(fk.RefCols) > i { - fkRefCol = fk.RefCols[i].O - } - col := nameToCol[key.L] - record := types.MakeDatums( - infoschema.CatalogVal, // CONSTRAINT_CATALOG - schema.O, // CONSTRAINT_SCHEMA - fk.Name.O, // CONSTRAINT_NAME - infoschema.CatalogVal, // TABLE_CATALOG - schema.O, // TABLE_SCHEMA - table.Name.O, // TABLE_NAME - col.Name.O, // COLUMN_NAME - i+1, // ORDINAL_POSITION, - 1, // POSITION_IN_UNIQUE_CONSTRAINT - fk.RefSchema.O, // REFERENCED_TABLE_SCHEMA - fk.RefTable.O, // REFERENCED_TABLE_NAME - fkRefCol, // REFERENCED_COLUMN_NAME - ) - rows = append(rows, record) - } - } - return rows -} - -func (e *memtableRetriever) setDataForTiKVRegionStatus(ctx context.Context, sctx sessionctx.Context) (err error) { - checker := privilege.GetPrivilegeManager(sctx) - var extractorTableIDs []int64 - tikvStore, ok := sctx.GetStore().(helper.Storage) - if !ok { - return errors.New("Information about TiKV region status can be gotten only when the storage is TiKV") - } - tikvHelper := &helper.Helper{ - Store: tikvStore, - RegionCache: tikvStore.GetRegionCache(), - } - requestByTableRange := false - var allRegionsInfo *pd.RegionsInfo - is := sctx.GetDomainInfoSchema().(infoschema.InfoSchema) - if e.extractor != nil { - extractor, ok := e.extractor.(*plannercore.TiKVRegionStatusExtractor) - if ok && len(extractor.GetTablesID()) > 0 { - extractorTableIDs = extractor.GetTablesID() - for _, tableID := range extractorTableIDs { - regionsInfo, err := e.getRegionsInfoForTable(ctx, tikvHelper, is, tableID) - if err != nil { - if errors.ErrorEqual(err, infoschema.ErrTableExists) { - continue - } - return err - } - allRegionsInfo = allRegionsInfo.Merge(regionsInfo) - } - requestByTableRange = true - } - } - if !requestByTableRange { - pdCli, err := tikvHelper.TryGetPDHTTPClient() - if err != nil { - return err - } - allRegionsInfo, err = pdCli.GetRegions(ctx) - if err != nil { - return err - } - } - tableInfos := tikvHelper.GetRegionsTableInfo(allRegionsInfo, is, nil) - for i := range allRegionsInfo.Regions { - regionTableList := tableInfos[allRegionsInfo.Regions[i].ID] - if len(regionTableList) == 0 { - e.setNewTiKVRegionStatusCol(&allRegionsInfo.Regions[i], nil) - } - for j, regionTable := range regionTableList { - if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, regionTable.DB.Name.L, regionTable.Table.Name.L, "", mysql.AllPrivMask) { - continue - } - if len(extractorTableIDs) == 0 { - e.setNewTiKVRegionStatusCol(&allRegionsInfo.Regions[i], ®ionTable) - } - if slices.Contains(extractorTableIDs, regionTableList[j].Table.ID) { - e.setNewTiKVRegionStatusCol(&allRegionsInfo.Regions[i], ®ionTable) - } - } - } - return 
nil -} - -func (e *memtableRetriever) getRegionsInfoForTable(ctx context.Context, h *helper.Helper, is infoschema.InfoSchema, tableID int64) (*pd.RegionsInfo, error) { - tbl, _ := is.TableByID(tableID) - if tbl == nil { - return nil, infoschema.ErrTableExists.GenWithStackByArgs(tableID) - } - - pt := tbl.Meta().GetPartitionInfo() - if pt == nil { - regionsInfo, err := e.getRegionsInfoForSingleTable(ctx, h, tableID) - if err != nil { - return nil, err - } - return regionsInfo, nil - } - - var allRegionsInfo *pd.RegionsInfo - for _, def := range pt.Definitions { - regionsInfo, err := e.getRegionsInfoForSingleTable(ctx, h, def.ID) - if err != nil { - return nil, err - } - allRegionsInfo = allRegionsInfo.Merge(regionsInfo) - } - return allRegionsInfo, nil -} - -func (*memtableRetriever) getRegionsInfoForSingleTable(ctx context.Context, helper *helper.Helper, tableID int64) (*pd.RegionsInfo, error) { - pdCli, err := helper.TryGetPDHTTPClient() - if err != nil { - return nil, err - } - sk, ek := tablecodec.GetTableHandleKeyRange(tableID) - sRegion, err := pdCli.GetRegionByKey(ctx, codec.EncodeBytes(nil, sk)) - if err != nil { - return nil, err - } - eRegion, err := pdCli.GetRegionByKey(ctx, codec.EncodeBytes(nil, ek)) - if err != nil { - return nil, err - } - sk, err = hex.DecodeString(sRegion.StartKey) - if err != nil { - return nil, err - } - ek, err = hex.DecodeString(eRegion.EndKey) - if err != nil { - return nil, err - } - return pdCli.GetRegionsByKeyRange(ctx, pd.NewKeyRange(sk, ek), -1) -} - -func (e *memtableRetriever) setNewTiKVRegionStatusCol(region *pd.RegionInfo, table *helper.TableInfo) { - row := make([]types.Datum, len(infoschema.TableTiKVRegionStatusCols)) - row[0].SetInt64(region.ID) - row[1].SetString(region.StartKey, mysql.DefaultCollationName) - row[2].SetString(region.EndKey, mysql.DefaultCollationName) - if table != nil { - row[3].SetInt64(table.Table.ID) - row[4].SetString(table.DB.Name.O, mysql.DefaultCollationName) - row[5].SetString(table.Table.Name.O, mysql.DefaultCollationName) - if table.IsIndex { - row[6].SetInt64(1) - row[7].SetInt64(table.Index.ID) - row[8].SetString(table.Index.Name.O, mysql.DefaultCollationName) - } else { - row[6].SetInt64(0) - } - if table.IsPartition { - row[9].SetInt64(1) - row[10].SetInt64(table.Partition.ID) - row[11].SetString(table.Partition.Name.O, mysql.DefaultCollationName) - } else { - row[9].SetInt64(0) - } - } else { - row[6].SetInt64(0) - row[9].SetInt64(0) - } - row[12].SetInt64(region.Epoch.ConfVer) - row[13].SetInt64(region.Epoch.Version) - row[14].SetUint64(region.WrittenBytes) - row[15].SetUint64(region.ReadBytes) - row[16].SetInt64(region.ApproximateSize) - row[17].SetInt64(region.ApproximateKeys) - if region.ReplicationStatus != nil { - row[18].SetString(region.ReplicationStatus.State, mysql.DefaultCollationName) - row[19].SetInt64(region.ReplicationStatus.StateID) - } - e.rows = append(e.rows, row) -} - -const ( - normalPeer = "NORMAL" - pendingPeer = "PENDING" - downPeer = "DOWN" -) - -func (e *memtableRetriever) setDataForTiDBHotRegions(ctx context.Context, sctx sessionctx.Context) error { - tikvStore, ok := sctx.GetStore().(helper.Storage) - if !ok { - return errors.New("Information about hot region can be gotten only when the storage is TiKV") - } - tikvHelper := &helper.Helper{ - Store: tikvStore, - RegionCache: tikvStore.GetRegionCache(), - } - is := sessiontxn.GetTxnManager(sctx).GetTxnInfoSchema() - metrics, err := tikvHelper.ScrapeHotInfo(ctx, helper.HotRead, is, tikvHelper.FilterMemDBs) - if err != nil { - return 
err - } - e.setDataForHotRegionByMetrics(metrics, "read") - metrics, err = tikvHelper.ScrapeHotInfo(ctx, helper.HotWrite, is, nil) - if err != nil { - return err - } - e.setDataForHotRegionByMetrics(metrics, "write") - return nil -} - -func (e *memtableRetriever) setDataForHotRegionByMetrics(metrics []helper.HotTableIndex, tp string) { - rows := make([][]types.Datum, 0, len(metrics)) - for _, tblIndex := range metrics { - row := make([]types.Datum, len(infoschema.TableTiDBHotRegionsCols)) - if tblIndex.IndexName != "" { - row[1].SetInt64(tblIndex.IndexID) - row[4].SetString(tblIndex.IndexName, mysql.DefaultCollationName) - } else { - row[1].SetNull() - row[4].SetNull() - } - row[0].SetInt64(tblIndex.TableID) - row[2].SetString(tblIndex.DbName, mysql.DefaultCollationName) - row[3].SetString(tblIndex.TableName, mysql.DefaultCollationName) - row[5].SetUint64(tblIndex.RegionID) - row[6].SetString(tp, mysql.DefaultCollationName) - if tblIndex.RegionMetric == nil { - row[7].SetNull() - row[8].SetNull() - } else { - row[7].SetInt64(int64(tblIndex.RegionMetric.MaxHotDegree)) - row[8].SetInt64(int64(tblIndex.RegionMetric.Count)) - } - row[9].SetUint64(tblIndex.RegionMetric.FlowBytes) - rows = append(rows, row) - } - e.rows = append(e.rows, rows...) -} - -// setDataFromTableConstraints constructs data for table information_schema.constraints.See https://dev.mysql.com/doc/refman/5.7/en/table-constraints-table.html -func (e *memtableRetriever) setDataFromTableConstraints(ctx context.Context, sctx sessionctx.Context, schemas []model.CIStr) error { - checker := privilege.GetPrivilegeManager(sctx) - extractor, ok := e.extractor.(*plannercore.InfoSchemaBaseExtractor) - if ok && extractor.SkipRequest { - return nil - } - var rows [][]types.Datum - for _, schema := range schemas { - if ok && extractor.Filter("constraint_schema", schema.L) { - continue - } - if ok && extractor.Filter("table_schema", schema.L) { - continue - } - tables, err := e.is.SchemaTableInfos(ctx, schema) - if err != nil { - return errors.Trace(err) - } - for _, tbl := range tables { - if ok && extractor.Filter("table_name", tbl.Name.L) { - continue - } - if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, tbl.Name.L, "", mysql.AllPrivMask) { - continue - } - - if tbl.PKIsHandle { - record := types.MakeDatums( - infoschema.CatalogVal, // CONSTRAINT_CATALOG - schema.O, // CONSTRAINT_SCHEMA - mysql.PrimaryKeyName, // CONSTRAINT_NAME - schema.O, // TABLE_SCHEMA - tbl.Name.O, // TABLE_NAME - infoschema.PrimaryKeyType, // CONSTRAINT_TYPE - ) - rows = append(rows, record) - } - - for _, idx := range tbl.Indices { - var cname, ctype string - if idx.Primary { - cname = mysql.PrimaryKeyName - ctype = infoschema.PrimaryKeyType - } else if idx.Unique { - cname = idx.Name.O - ctype = infoschema.UniqueKeyType - } else { - // The index has no constriant. - continue - } - if ok && extractor.Filter("constraint_name", cname) { - continue - } - record := types.MakeDatums( - infoschema.CatalogVal, // CONSTRAINT_CATALOG - schema.O, // CONSTRAINT_SCHEMA - cname, // CONSTRAINT_NAME - schema.O, // TABLE_SCHEMA - tbl.Name.O, // TABLE_NAME - ctype, // CONSTRAINT_TYPE - ) - rows = append(rows, record) - } - // TiDB includes foreign key information for compatibility but foreign keys are not yet enforced. 
- for _, fk := range tbl.ForeignKeys { - record := types.MakeDatums( - infoschema.CatalogVal, // CONSTRAINT_CATALOG - schema.O, // CONSTRAINT_SCHEMA - fk.Name.O, // CONSTRAINT_NAME - schema.O, // TABLE_SCHEMA - tbl.Name.O, // TABLE_NAME - infoschema.ForeignKeyType, // CONSTRAINT_TYPE - ) - rows = append(rows, record) - } - } - } - e.rows = rows - return nil -} - -// tableStorageStatsRetriever is used to read slow log data. -type tableStorageStatsRetriever struct { - dummyCloser - table *model.TableInfo - outputCols []*model.ColumnInfo - retrieved bool - initialized bool - extractor *plannercore.TableStorageStatsExtractor - initialTables []*initialTable - curTable int - helper *helper.Helper - stats *pd.RegionStats -} - -func (e *tableStorageStatsRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { - if e.retrieved { - return nil, nil - } - if !e.initialized { - err := e.initialize(ctx, sctx) - if err != nil { - return nil, err - } - } - if len(e.initialTables) == 0 || e.curTable >= len(e.initialTables) { - e.retrieved = true - return nil, nil - } - - rows, err := e.setDataForTableStorageStats(ctx) - if err != nil { - return nil, err - } - if len(e.outputCols) == len(e.table.Columns) { - return rows, nil - } - retRows := make([][]types.Datum, len(rows)) - for i, fullRow := range rows { - row := make([]types.Datum, len(e.outputCols)) - for j, col := range e.outputCols { - row[j] = fullRow[col.Offset] - } - retRows[i] = row - } - return retRows, nil -} - -type initialTable struct { - db string - *model.TableInfo -} - -func (e *tableStorageStatsRetriever) initialize(ctx context.Context, sctx sessionctx.Context) error { - is := sctx.GetInfoSchema().(infoschema.InfoSchema) - var databases []string - schemas := e.extractor.TableSchema - tables := e.extractor.TableName - - // If not specify the table_schema, return an error to avoid traverse all schemas and their tables. - if len(schemas) == 0 { - return errors.Errorf("Please add where clause to filter the column TABLE_SCHEMA. " + - "For example, where TABLE_SCHEMA = 'xxx' or where TABLE_SCHEMA in ('xxx', 'yyy')") - } - - // Filter the sys or memory schema. - for schema := range schemas { - if !util.IsMemDB(schema) { - databases = append(databases, schema) - } - } - - // Privilege checker. - checker := func(db, table string) bool { - if pm := privilege.GetPrivilegeManager(sctx); pm != nil { - return pm.RequestVerification(sctx.GetSessionVars().ActiveRoles, db, table, "", mysql.AllPrivMask) - } - return true - } - - // Extract the tables to the initialTable. - for _, DB := range databases { - // The user didn't specified the table, extract all tables of this db to initialTable. - if len(tables) == 0 { - tbs, err := is.SchemaTableInfos(ctx, model.NewCIStr(DB)) - if err != nil { - return errors.Trace(err) - } - for _, tb := range tbs { - // For every db.table, check it's privileges. - if checker(DB, tb.Name.L) { - e.initialTables = append(e.initialTables, &initialTable{DB, tb}) - } - } - } else { - // The user specified the table, extract the specified tables of this db to initialTable. - for tb := range tables { - if tb, err := is.TableByName(context.Background(), model.NewCIStr(DB), model.NewCIStr(tb)); err == nil { - // For every db.table, check it's privileges. - if checker(DB, tb.Meta().Name.L) { - e.initialTables = append(e.initialTables, &initialTable{DB, tb.Meta()}) - } - } - } - } - } - - // Cache the helper and return an error if PD unavailable. 
- tikvStore, ok := sctx.GetStore().(helper.Storage) - if !ok { - return errors.Errorf("Information about TiKV region status can be gotten only when the storage is TiKV") - } - e.helper = helper.NewHelper(tikvStore) - _, err := e.helper.GetPDAddr() - if err != nil { - return err - } - e.initialized = true - return nil -} - -func (e *tableStorageStatsRetriever) setDataForTableStorageStats(ctx context.Context) ([][]types.Datum, error) { - rows := make([][]types.Datum, 0, 1024) - count := 0 - for e.curTable < len(e.initialTables) && count < 1024 { - tbl := e.initialTables[e.curTable] - tblIDs := make([]int64, 0, 1) - tblIDs = append(tblIDs, tbl.ID) - if partInfo := tbl.GetPartitionInfo(); partInfo != nil { - for _, partDef := range partInfo.Definitions { - tblIDs = append(tblIDs, partDef.ID) - } - } - var err error - for _, tableID := range tblIDs { - e.stats, err = e.helper.GetPDRegionStats(ctx, tableID, false) - if err != nil { - return nil, err - } - peerCount := 0 - for _, cnt := range e.stats.StorePeerCount { - peerCount += cnt - } - - record := types.MakeDatums( - tbl.db, // TABLE_SCHEMA - tbl.Name.O, // TABLE_NAME - tableID, // TABLE_ID - peerCount, // TABLE_PEER_COUNT - e.stats.Count, // TABLE_REGION_COUNT - e.stats.EmptyCount, // TABLE_EMPTY_REGION_COUNT - e.stats.StorageSize, // TABLE_SIZE - e.stats.StorageKeys, // TABLE_KEYS - ) - rows = append(rows, record) - } - count++ - e.curTable++ - } - return rows, nil -} - -// dataForAnalyzeStatusHelper is a helper function which can be used in show_stats.go -func dataForAnalyzeStatusHelper(ctx context.Context, sctx sessionctx.Context) (rows [][]types.Datum, err error) { - const maxAnalyzeJobs = 30 - const sql = "SELECT table_schema, table_name, partition_name, job_info, processed_rows, CONVERT_TZ(start_time, @@TIME_ZONE, '+00:00'), CONVERT_TZ(end_time, @@TIME_ZONE, '+00:00'), state, fail_reason, instance, process_id FROM mysql.analyze_jobs ORDER BY update_time DESC LIMIT %?" 
- exec := sctx.GetRestrictedSQLExecutor() - kctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) - chunkRows, _, err := exec.ExecRestrictedSQL(kctx, nil, sql, maxAnalyzeJobs) - if err != nil { - return nil, err - } - checker := privilege.GetPrivilegeManager(sctx) - - for _, chunkRow := range chunkRows { - dbName := chunkRow.GetString(0) - tableName := chunkRow.GetString(1) - if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, dbName, tableName, "", mysql.AllPrivMask) { - continue - } - partitionName := chunkRow.GetString(2) - jobInfo := chunkRow.GetString(3) - processedRows := chunkRow.GetInt64(4) - var startTime, endTime any - if !chunkRow.IsNull(5) { - t, err := chunkRow.GetTime(5).GoTime(time.UTC) - if err != nil { - return nil, err - } - startTime = types.NewTime(types.FromGoTime(t.In(sctx.GetSessionVars().TimeZone)), mysql.TypeDatetime, 0) - } - if !chunkRow.IsNull(6) { - t, err := chunkRow.GetTime(6).GoTime(time.UTC) - if err != nil { - return nil, err - } - endTime = types.NewTime(types.FromGoTime(t.In(sctx.GetSessionVars().TimeZone)), mysql.TypeDatetime, 0) - } - - state := chunkRow.GetEnum(7).String() - var failReason any - if !chunkRow.IsNull(8) { - failReason = chunkRow.GetString(8) - } - instance := chunkRow.GetString(9) - var procID any - if !chunkRow.IsNull(10) { - procID = chunkRow.GetUint64(10) - } - - var remainDurationStr, progressDouble, estimatedRowCntStr any - if state == statistics.AnalyzeRunning && !strings.HasPrefix(jobInfo, "merge global stats") { - startTime, ok := startTime.(types.Time) - if !ok { - return nil, errors.New("invalid start time") - } - remainingDuration, progress, estimatedRowCnt, remainDurationErr := - getRemainDurationForAnalyzeStatusHelper(ctx, sctx, &startTime, - dbName, tableName, partitionName, processedRows) - if remainDurationErr != nil { - logutil.BgLogger().Warn("get remaining duration failed", zap.Error(remainDurationErr)) - } - if remainingDuration != nil { - remainDurationStr = execdetails.FormatDuration(*remainingDuration) - } - progressDouble = progress - estimatedRowCntStr = int64(estimatedRowCnt) - } - row := types.MakeDatums( - dbName, // TABLE_SCHEMA - tableName, // TABLE_NAME - partitionName, // PARTITION_NAME - jobInfo, // JOB_INFO - processedRows, // ROW_COUNT - startTime, // START_TIME - endTime, // END_TIME - state, // STATE - failReason, // FAIL_REASON - instance, // INSTANCE - procID, // PROCESS_ID - remainDurationStr, // REMAINING_SECONDS - progressDouble, // PROGRESS - estimatedRowCntStr, // ESTIMATED_TOTAL_ROWS - ) - rows = append(rows, row) - } - return -} - -func getRemainDurationForAnalyzeStatusHelper( - ctx context.Context, - sctx sessionctx.Context, startTime *types.Time, - dbName, tableName, partitionName string, processedRows int64) (*time.Duration, float64, float64, error) { - var remainingDuration = time.Duration(0) - var percentage = 0.0 - var totalCnt = float64(0) - if startTime != nil { - start, err := startTime.GoTime(time.UTC) - if err != nil { - return nil, percentage, totalCnt, err - } - duration := time.Now().UTC().Sub(start) - if intest.InTest { - if val := ctx.Value(AnalyzeProgressTest); val != nil { - remainingDuration, percentage = calRemainInfoForAnalyzeStatus(ctx, int64(totalCnt), processedRows, duration) - return &remainingDuration, percentage, totalCnt, nil - } - } - var tid int64 - is := sessiontxn.GetTxnManager(sctx).GetTxnInfoSchema() - tb, err := is.TableByName(ctx, model.NewCIStr(dbName), model.NewCIStr(tableName)) - if err != nil { - 
return nil, percentage, totalCnt, err - } - statsHandle := domain.GetDomain(sctx).StatsHandle() - if statsHandle != nil { - var statsTbl *statistics.Table - meta := tb.Meta() - if partitionName != "" { - pt := meta.GetPartitionInfo() - tid = pt.GetPartitionIDByName(partitionName) - statsTbl = statsHandle.GetPartitionStats(meta, tid) - } else { - statsTbl = statsHandle.GetTableStats(meta) - tid = meta.ID - } - if statsTbl != nil && statsTbl.RealtimeCount != 0 { - totalCnt = float64(statsTbl.RealtimeCount) - } - } - if (tid > 0 && totalCnt == 0) || float64(processedRows) > totalCnt { - totalCnt, _ = pdhelper.GlobalPDHelper.GetApproximateTableCountFromStorage(ctx, sctx, tid, dbName, tableName, partitionName) - } - remainingDuration, percentage = calRemainInfoForAnalyzeStatus(ctx, int64(totalCnt), processedRows, duration) - } - return &remainingDuration, percentage, totalCnt, nil -} - -func calRemainInfoForAnalyzeStatus(ctx context.Context, totalCnt int64, processedRows int64, duration time.Duration) (time.Duration, float64) { - if intest.InTest { - if val := ctx.Value(AnalyzeProgressTest); val != nil { - totalCnt = 100 // But in final result, it is still 0. - processedRows = 10 - duration = 1 * time.Minute - } - } - if totalCnt == 0 { - return 0, 100.0 - } - remainLine := totalCnt - processedRows - if processedRows == 0 { - processedRows = 1 - } - if duration == 0 { - duration = 1 * time.Second - } - i := float64(remainLine) * duration.Seconds() / float64(processedRows) - persentage := float64(processedRows) / float64(totalCnt) - return time.Duration(i) * time.Second, persentage -} - -// setDataForAnalyzeStatus gets all the analyze jobs. -func (e *memtableRetriever) setDataForAnalyzeStatus(ctx context.Context, sctx sessionctx.Context) (err error) { - e.rows, err = dataForAnalyzeStatusHelper(ctx, sctx) - return -} - -// setDataForPseudoProfiling returns pseudo data for table profiling when system variable `profiling` is set to `ON`. 
-func (e *memtableRetriever) setDataForPseudoProfiling(sctx sessionctx.Context) { - if v, ok := sctx.GetSessionVars().GetSystemVar("profiling"); ok && variable.TiDBOptOn(v) { - row := types.MakeDatums( - 0, // QUERY_ID - 0, // SEQ - "", // STATE - types.NewDecFromInt(0), // DURATION - types.NewDecFromInt(0), // CPU_USER - types.NewDecFromInt(0), // CPU_SYSTEM - 0, // CONTEXT_VOLUNTARY - 0, // CONTEXT_INVOLUNTARY - 0, // BLOCK_OPS_IN - 0, // BLOCK_OPS_OUT - 0, // MESSAGES_SENT - 0, // MESSAGES_RECEIVED - 0, // PAGE_FAULTS_MAJOR - 0, // PAGE_FAULTS_MINOR - 0, // SWAPS - "", // SOURCE_FUNCTION - "", // SOURCE_FILE - 0, // SOURCE_LINE - ) - e.rows = append(e.rows, row) - } -} - -func (e *memtableRetriever) setDataForServersInfo(ctx sessionctx.Context) error { - serversInfo, err := infosync.GetAllServerInfo(context.Background()) - if err != nil { - return err - } - rows := make([][]types.Datum, 0, len(serversInfo)) - for _, info := range serversInfo { - row := types.MakeDatums( - info.ID, // DDL_ID - info.IP, // IP - int(info.Port), // PORT - int(info.StatusPort), // STATUS_PORT - info.Lease, // LEASE - info.Version, // VERSION - info.GitHash, // GIT_HASH - info.BinlogStatus, // BINLOG_STATUS - stringutil.BuildStringFromLabels(info.Labels), // LABELS - ) - if sem.IsEnabled() { - checker := privilege.GetPrivilegeManager(ctx) - if checker == nil || !checker.RequestDynamicVerification(ctx.GetSessionVars().ActiveRoles, "RESTRICTED_TABLES_ADMIN", false) { - row[1].SetNull() // clear IP - } - } - rows = append(rows, row) - } - e.rows = rows - return nil -} - -func (e *memtableRetriever) setDataFromSequences(ctx context.Context, sctx sessionctx.Context, schemas []model.CIStr) error { - checker := privilege.GetPrivilegeManager(sctx) - extractor, ok := e.extractor.(*plannercore.InfoSchemaBaseExtractor) - if ok && extractor.SkipRequest { - return nil - } - var rows [][]types.Datum - for _, schema := range schemas { - if ok && extractor.Filter("sequence_schema", schema.L) { - continue - } - tables, err := e.is.SchemaTableInfos(ctx, schema) - if err != nil { - return errors.Trace(err) - } - for _, table := range tables { - if ok && extractor.Filter("sequence_name", table.Name.L) { - continue - } - if !table.IsSequence() { - continue - } - if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.L, table.Name.L, "", mysql.AllPrivMask) { - continue - } - record := types.MakeDatums( - infoschema.CatalogVal, // TABLE_CATALOG - schema.O, // SEQUENCE_SCHEMA - table.Name.O, // SEQUENCE_NAME - table.Sequence.Cache, // Cache - table.Sequence.CacheValue, // CACHE_VALUE - table.Sequence.Cycle, // CYCLE - table.Sequence.Increment, // INCREMENT - table.Sequence.MaxValue, // MAXVALUE - table.Sequence.MinValue, // MINVALUE - table.Sequence.Start, // START - table.Sequence.Comment, // COMMENT - ) - rows = append(rows, record) - } - } - e.rows = rows - return nil -} - -// dataForTableTiFlashReplica constructs data for table tiflash replica info. 
-func (e *memtableRetriever) dataForTableTiFlashReplica(_ context.Context, sctx sessionctx.Context, _ []model.CIStr) error { - var ( - checker = privilege.GetPrivilegeManager(sctx) - rows [][]types.Datum - tiFlashStores map[int64]pd.StoreInfo - ) - rs := e.is.ListTablesWithSpecialAttribute(infoschema.TiFlashAttribute) - for _, schema := range rs { - for _, tbl := range schema.TableInfos { - if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.DBName, tbl.Name.L, "", mysql.AllPrivMask) { - continue - } - var progress float64 - if pi := tbl.GetPartitionInfo(); pi != nil && len(pi.Definitions) > 0 { - for _, p := range pi.Definitions { - progressOfPartition, err := infosync.MustGetTiFlashProgress(p.ID, tbl.TiFlashReplica.Count, &tiFlashStores) - if err != nil { - logutil.BgLogger().Error("dataForTableTiFlashReplica error", zap.Int64("tableID", tbl.ID), zap.Int64("partitionID", p.ID), zap.Error(err)) - } - progress += progressOfPartition - } - progress = progress / float64(len(pi.Definitions)) - } else { - var err error - progress, err = infosync.MustGetTiFlashProgress(tbl.ID, tbl.TiFlashReplica.Count, &tiFlashStores) - if err != nil { - logutil.BgLogger().Error("dataForTableTiFlashReplica error", zap.Int64("tableID", tbl.ID), zap.Error(err)) - } - } - progressString := types.TruncateFloatToString(progress, 2) - progress, _ = strconv.ParseFloat(progressString, 64) - record := types.MakeDatums( - schema.DBName, // TABLE_SCHEMA - tbl.Name.O, // TABLE_NAME - tbl.ID, // TABLE_ID - int64(tbl.TiFlashReplica.Count), // REPLICA_COUNT - strings.Join(tbl.TiFlashReplica.LocationLabels, ","), // LOCATION_LABELS - tbl.TiFlashReplica.Available, // AVAILABLE - progress, // PROGRESS - ) - rows = append(rows, record) - } - } - e.rows = rows - return nil -} - -func (e *memtableRetriever) setDataForClientErrorsSummary(ctx sessionctx.Context, tableName string) error { - // Seeing client errors should require the PROCESS privilege, with the exception of errors for your own user. - // This is similar to information_schema.processlist, which is the closest comparison. - hasProcessPriv := hasPriv(ctx, mysql.ProcessPriv) - loginUser := ctx.GetSessionVars().User - - var rows [][]types.Datum - switch tableName { - case infoschema.TableClientErrorsSummaryGlobal: - if !hasProcessPriv { - return plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("PROCESS") - } - for code, summary := range errno.GlobalStats() { - firstSeen := types.NewTime(types.FromGoTime(summary.FirstSeen), mysql.TypeTimestamp, types.DefaultFsp) - lastSeen := types.NewTime(types.FromGoTime(summary.LastSeen), mysql.TypeTimestamp, types.DefaultFsp) - row := types.MakeDatums( - int(code), // ERROR_NUMBER - errno.MySQLErrName[code].Raw, // ERROR_MESSAGE - summary.ErrorCount, // ERROR_COUNT - summary.WarningCount, // WARNING_COUNT - firstSeen, // FIRST_SEEN - lastSeen, // LAST_SEEN - ) - rows = append(rows, row) - } - case infoschema.TableClientErrorsSummaryByUser: - for user, agg := range errno.UserStats() { - for code, summary := range agg { - // Allow anyone to see their own errors. 
- if !hasProcessPriv && loginUser != nil && loginUser.Username != user { - continue - } - firstSeen := types.NewTime(types.FromGoTime(summary.FirstSeen), mysql.TypeTimestamp, types.DefaultFsp) - lastSeen := types.NewTime(types.FromGoTime(summary.LastSeen), mysql.TypeTimestamp, types.DefaultFsp) - row := types.MakeDatums( - user, // USER - int(code), // ERROR_NUMBER - errno.MySQLErrName[code].Raw, // ERROR_MESSAGE - summary.ErrorCount, // ERROR_COUNT - summary.WarningCount, // WARNING_COUNT - firstSeen, // FIRST_SEEN - lastSeen, // LAST_SEEN - ) - rows = append(rows, row) - } - } - case infoschema.TableClientErrorsSummaryByHost: - if !hasProcessPriv { - return plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("PROCESS") - } - for host, agg := range errno.HostStats() { - for code, summary := range agg { - firstSeen := types.NewTime(types.FromGoTime(summary.FirstSeen), mysql.TypeTimestamp, types.DefaultFsp) - lastSeen := types.NewTime(types.FromGoTime(summary.LastSeen), mysql.TypeTimestamp, types.DefaultFsp) - row := types.MakeDatums( - host, // HOST - int(code), // ERROR_NUMBER - errno.MySQLErrName[code].Raw, // ERROR_MESSAGE - summary.ErrorCount, // ERROR_COUNT - summary.WarningCount, // WARNING_COUNT - firstSeen, // FIRST_SEEN - lastSeen, // LAST_SEEN - ) - rows = append(rows, row) - } - } - } - e.rows = rows - return nil -} - -func (e *memtableRetriever) setDataForTrxSummary(ctx sessionctx.Context) error { - hasProcessPriv := hasPriv(ctx, mysql.ProcessPriv) - if !hasProcessPriv { - return nil - } - rows := txninfo.Recorder.DumpTrxSummary() - e.rows = rows - return nil -} - -func (e *memtableRetriever) setDataForClusterTrxSummary(ctx sessionctx.Context) error { - err := e.setDataForTrxSummary(ctx) - if err != nil { - return err - } - rows, err := infoschema.AppendHostInfoToRows(ctx, e.rows) - if err != nil { - return err - } - e.rows = rows - return nil -} - -func (e *memtableRetriever) setDataForMemoryUsage() error { - r := memory.ReadMemStats() - currentOps, sessionKillLastDatum := types.NewDatum(nil), types.NewDatum(nil) - if memory.TriggerMemoryLimitGC.Load() || servermemorylimit.IsKilling.Load() { - currentOps.SetString("shrink", mysql.DefaultCollationName) - } - sessionKillLast := servermemorylimit.SessionKillLast.Load() - if !sessionKillLast.IsZero() { - sessionKillLastDatum.SetMysqlTime(types.NewTime(types.FromGoTime(sessionKillLast), mysql.TypeDatetime, 0)) - } - gcLast := types.NewTime(types.FromGoTime(memory.MemoryLimitGCLast.Load()), mysql.TypeDatetime, 0) - - row := []types.Datum{ - types.NewIntDatum(int64(memory.GetMemTotalIgnoreErr())), // MEMORY_TOTAL - types.NewIntDatum(int64(memory.ServerMemoryLimit.Load())), // MEMORY_LIMIT - types.NewIntDatum(int64(r.HeapInuse)), // MEMORY_CURRENT - types.NewIntDatum(int64(servermemorylimit.MemoryMaxUsed.Load())), // MEMORY_MAX_USED - currentOps, // CURRENT_OPS - sessionKillLastDatum, // SESSION_KILL_LAST - types.NewIntDatum(servermemorylimit.SessionKillTotal.Load()), // SESSION_KILL_TOTAL - types.NewTimeDatum(gcLast), // GC_LAST - types.NewIntDatum(memory.MemoryLimitGCTotal.Load()), // GC_TOTAL - types.NewDatum(GlobalDiskUsageTracker.BytesConsumed()), // DISK_USAGE - types.NewDatum(memory.QueryForceDisk.Load()), // QUERY_FORCE_DISK - } - e.rows = append(e.rows, row) - return nil -} - -func (e *memtableRetriever) setDataForClusterMemoryUsage(ctx sessionctx.Context) error { - err := e.setDataForMemoryUsage() - if err != nil { - return err - } - rows, err := infoschema.AppendHostInfoToRows(ctx, e.rows) - if err != nil { - return 
err - } - e.rows = rows - return nil -} - -func (e *memtableRetriever) setDataForMemoryUsageOpsHistory() error { - e.rows = servermemorylimit.GlobalMemoryOpsHistoryManager.GetRows() - return nil -} - -func (e *memtableRetriever) setDataForClusterMemoryUsageOpsHistory(ctx sessionctx.Context) error { - err := e.setDataForMemoryUsageOpsHistory() - if err != nil { - return err - } - rows, err := infoschema.AppendHostInfoToRows(ctx, e.rows) - if err != nil { - return err - } - e.rows = rows - return nil -} - -// tidbTrxTableRetriever is the memtable retriever for the TIDB_TRX and CLUSTER_TIDB_TRX table. -type tidbTrxTableRetriever struct { - dummyCloser - batchRetrieverHelper - table *model.TableInfo - columns []*model.ColumnInfo - txnInfo []*txninfo.TxnInfo - initialized bool -} - -func (e *tidbTrxTableRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { - if e.retrieved { - return nil, nil - } - - if !e.initialized { - e.initialized = true - - sm := sctx.GetSessionManager() - if sm == nil { - e.retrieved = true - return nil, nil - } - - loginUser := sctx.GetSessionVars().User - hasProcessPriv := hasPriv(sctx, mysql.ProcessPriv) - infoList := sm.ShowTxnList() - e.txnInfo = make([]*txninfo.TxnInfo, 0, len(infoList)) - for _, info := range infoList { - // If you have the PROCESS privilege, you can see all running transactions. - // Otherwise, you can see only your own transactions. - if !hasProcessPriv && loginUser != nil && info.Username != loginUser.Username { - continue - } - e.txnInfo = append(e.txnInfo, info) - } - - e.batchRetrieverHelper.totalRows = len(e.txnInfo) - e.batchRetrieverHelper.batchSize = 1024 - } - - sqlExec := sctx.GetRestrictedSQLExecutor() - - var err error - // The current TiDB node's address is needed by the CLUSTER_TIDB_TRX table. - var instanceAddr string - if e.table.Name.O == infoschema.ClusterTableTiDBTrx { - instanceAddr, err = infoschema.GetInstanceAddr(sctx) - if err != nil { - return nil, err - } - } - - var res [][]types.Datum - err = e.nextBatch(func(start, end int) error { - // Before getting rows, collect the SQL digests that needs to be retrieved first. - var sqlRetriever *expression.SQLDigestTextRetriever - for _, c := range e.columns { - if c.Name.O == txninfo.CurrentSQLDigestTextStr { - if sqlRetriever == nil { - sqlRetriever = expression.NewSQLDigestTextRetriever() - } - - for i := start; i < end; i++ { - sqlRetriever.SQLDigestsMap[e.txnInfo[i].CurrentSQLDigest] = "" - } - } - } - // Retrieve the SQL texts if necessary. - if sqlRetriever != nil { - err1 := sqlRetriever.RetrieveLocal(ctx, sqlExec) - if err1 != nil { - return errors.Trace(err1) - } - } - - res = make([][]types.Datum, 0, end-start) - - // Calculate rows. 
- for i := start; i < end; i++ { - row := make([]types.Datum, 0, len(e.columns)) - for _, c := range e.columns { - if c.Name.O == util.ClusterTableInstanceColumnName { - row = append(row, types.NewDatum(instanceAddr)) - } else if c.Name.O == txninfo.CurrentSQLDigestTextStr { - if text, ok := sqlRetriever.SQLDigestsMap[e.txnInfo[i].CurrentSQLDigest]; ok && len(text) != 0 { - row = append(row, types.NewDatum(text)) - } else { - row = append(row, types.NewDatum(nil)) - } - } else { - switch c.Name.O { - case txninfo.MemBufferBytesStr: - memDBFootprint := sctx.GetSessionVars().MemDBFootprint - var bytesConsumed int64 - if memDBFootprint != nil { - bytesConsumed = memDBFootprint.BytesConsumed() - } - row = append(row, types.NewDatum(bytesConsumed)) - default: - row = append(row, e.txnInfo[i].ToDatum(c.Name.O)) - } - } - } - res = append(res, row) - } - - return nil - }) - - if err != nil { - return nil, err - } - - return res, nil -} - -// dataLockWaitsTableRetriever is the memtable retriever for the DATA_LOCK_WAITS table. -type dataLockWaitsTableRetriever struct { - dummyCloser - batchRetrieverHelper - table *model.TableInfo - columns []*model.ColumnInfo - lockWaits []*deadlock.WaitForEntry - resolvingLocks []txnlock.ResolvingLock - initialized bool -} - -func (r *dataLockWaitsTableRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { - if r.retrieved { - return nil, nil - } - - if !r.initialized { - if !hasPriv(sctx, mysql.ProcessPriv) { - return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("PROCESS") - } - - r.initialized = true - var err error - r.lockWaits, err = sctx.GetStore().GetLockWaits() - tikvStore, _ := sctx.GetStore().(helper.Storage) - r.resolvingLocks = tikvStore.GetLockResolver().Resolving() - if err != nil { - r.retrieved = true - return nil, err - } - - r.batchRetrieverHelper.totalRows = len(r.lockWaits) + len(r.resolvingLocks) - r.batchRetrieverHelper.batchSize = 1024 - } - - var res [][]types.Datum - - err := r.nextBatch(func(start, end int) error { - // Before getting rows, collect the SQL digests that needs to be retrieved first. - var needDigest bool - var needSQLText bool - for _, c := range r.columns { - if c.Name.O == infoschema.DataLockWaitsColumnSQLDigestText { - needSQLText = true - } else if c.Name.O == infoschema.DataLockWaitsColumnSQLDigest { - needDigest = true - } - } - - var digests []string - if needDigest || needSQLText { - digests = make([]string, end-start) - for i, lockWait := range r.lockWaits { - digest, err := resourcegrouptag.DecodeResourceGroupTag(lockWait.ResourceGroupTag) - if err != nil { - // Ignore the error if failed to decode the digest from resource_group_tag. We still want to show - // as much information as possible even we can't retrieve some of them. - logutil.Logger(ctx).Warn("failed to decode resource group tag", zap.Error(err)) - } else { - digests[i] = hex.EncodeToString(digest) - } - } - // todo: support resourcegrouptag for resolvingLocks - } - - // Fetch the SQL Texts of the digests above if necessary. - var sqlRetriever *expression.SQLDigestTextRetriever - if needSQLText { - sqlRetriever = expression.NewSQLDigestTextRetriever() - for _, digest := range digests { - if len(digest) > 0 { - sqlRetriever.SQLDigestsMap[digest] = "" - } - } - - err := sqlRetriever.RetrieveGlobal(ctx, sctx.GetRestrictedSQLExecutor()) - if err != nil { - return errors.Trace(err) - } - } - - // Calculate rows. 
- res = make([][]types.Datum, 0, end-start) - // data_lock_waits contains both lockWaits (pessimistic lock waiting) - // and resolving (optimistic lock "waiting") info - // first we'll return the lockWaits, and then resolving, so we need to - // do some index calculation here - lockWaitsStart := min(start, len(r.lockWaits)) - resolvingStart := start - lockWaitsStart - lockWaitsEnd := min(end, len(r.lockWaits)) - resolvingEnd := end - lockWaitsEnd - for rowIdx, lockWait := range r.lockWaits[lockWaitsStart:lockWaitsEnd] { - row := make([]types.Datum, 0, len(r.columns)) - - for _, col := range r.columns { - switch col.Name.O { - case infoschema.DataLockWaitsColumnKey: - row = append(row, types.NewDatum(strings.ToUpper(hex.EncodeToString(lockWait.Key)))) - case infoschema.DataLockWaitsColumnKeyInfo: - infoSchema := sctx.GetInfoSchema().(infoschema.InfoSchema) - var decodedKeyStr any - decodedKey, err := keydecoder.DecodeKey(lockWait.Key, infoSchema) - if err == nil { - decodedKeyBytes, err := json.Marshal(decodedKey) - if err != nil { - logutil.BgLogger().Warn("marshal decoded key info to JSON failed", zap.Error(err)) - } else { - decodedKeyStr = string(decodedKeyBytes) - } - } else { - logutil.Logger(ctx).Warn("decode key failed", zap.Error(err)) - } - row = append(row, types.NewDatum(decodedKeyStr)) - case infoschema.DataLockWaitsColumnTrxID: - row = append(row, types.NewDatum(lockWait.Txn)) - case infoschema.DataLockWaitsColumnCurrentHoldingTrxID: - row = append(row, types.NewDatum(lockWait.WaitForTxn)) - case infoschema.DataLockWaitsColumnSQLDigest: - digest := digests[rowIdx] - if len(digest) == 0 { - row = append(row, types.NewDatum(nil)) - } else { - row = append(row, types.NewDatum(digest)) - } - case infoschema.DataLockWaitsColumnSQLDigestText: - text := sqlRetriever.SQLDigestsMap[digests[rowIdx]] - if len(text) > 0 { - row = append(row, types.NewDatum(text)) - } else { - row = append(row, types.NewDatum(nil)) - } - default: - row = append(row, types.NewDatum(nil)) - } - } - - res = append(res, row) - } - for _, resolving := range r.resolvingLocks[resolvingStart:resolvingEnd] { - row := make([]types.Datum, 0, len(r.columns)) - - for _, col := range r.columns { - switch col.Name.O { - case infoschema.DataLockWaitsColumnKey: - row = append(row, types.NewDatum(strings.ToUpper(hex.EncodeToString(resolving.Key)))) - case infoschema.DataLockWaitsColumnKeyInfo: - infoSchema := domain.GetDomain(sctx).InfoSchema() - var decodedKeyStr any - decodedKey, err := keydecoder.DecodeKey(resolving.Key, infoSchema) - if err == nil { - decodedKeyBytes, err := json.Marshal(decodedKey) - if err != nil { - logutil.Logger(ctx).Warn("marshal decoded key info to JSON failed", zap.Error(err)) - } else { - decodedKeyStr = string(decodedKeyBytes) - } - } else { - logutil.Logger(ctx).Warn("decode key failed", zap.Error(err)) - } - row = append(row, types.NewDatum(decodedKeyStr)) - case infoschema.DataLockWaitsColumnTrxID: - row = append(row, types.NewDatum(resolving.TxnID)) - case infoschema.DataLockWaitsColumnCurrentHoldingTrxID: - row = append(row, types.NewDatum(resolving.LockTxnID)) - case infoschema.DataLockWaitsColumnSQLDigest: - // todo: support resourcegrouptag for resolvingLocks - row = append(row, types.NewDatum(nil)) - case infoschema.DataLockWaitsColumnSQLDigestText: - // todo: support resourcegrouptag for resolvingLocks - row = append(row, types.NewDatum(nil)) - default: - row = append(row, types.NewDatum(nil)) - } - } - - res = append(res, row) - } - return nil - }) - - if err != nil { - return nil, 
err - } - - return res, nil -} - -// deadlocksTableRetriever is the memtable retriever for the DEADLOCKS and CLUSTER_DEADLOCKS table. -type deadlocksTableRetriever struct { - dummyCloser - batchRetrieverHelper - - currentIdx int - currentWaitChainIdx int - - table *model.TableInfo - columns []*model.ColumnInfo - deadlocks []*deadlockhistory.DeadlockRecord - initialized bool -} - -// nextIndexPair advances a index pair (where `idx` is the index of the DeadlockRecord, and `waitChainIdx` is the index -// of the wait chain item in the `idx`-th DeadlockRecord. This function helps iterate over each wait chain item -// in all DeadlockRecords. -func (r *deadlocksTableRetriever) nextIndexPair(idx, waitChainIdx int) (a, b int) { - waitChainIdx++ - if waitChainIdx >= len(r.deadlocks[idx].WaitChain) { - waitChainIdx = 0 - idx++ - for idx < len(r.deadlocks) && len(r.deadlocks[idx].WaitChain) == 0 { - idx++ - } - } - return idx, waitChainIdx -} - -func (r *deadlocksTableRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { - if r.retrieved { - return nil, nil - } - - if !r.initialized { - if !hasPriv(sctx, mysql.ProcessPriv) { - return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("PROCESS") - } - - r.initialized = true - r.deadlocks = deadlockhistory.GlobalDeadlockHistory.GetAll() - - r.batchRetrieverHelper.totalRows = 0 - for _, d := range r.deadlocks { - r.batchRetrieverHelper.totalRows += len(d.WaitChain) - } - r.batchRetrieverHelper.batchSize = 1024 - } - - // The current TiDB node's address is needed by the CLUSTER_DEADLOCKS table. - var err error - var instanceAddr string - if r.table.Name.O == infoschema.ClusterTableDeadlocks { - instanceAddr, err = infoschema.GetInstanceAddr(sctx) - if err != nil { - return nil, err - } - } - - infoSchema := sctx.GetInfoSchema().(infoschema.InfoSchema) - - var res [][]types.Datum - - err = r.nextBatch(func(start, end int) error { - // Before getting rows, collect the SQL digests that needs to be retrieved first. - var sqlRetriever *expression.SQLDigestTextRetriever - for _, c := range r.columns { - if c.Name.O == deadlockhistory.ColCurrentSQLDigestTextStr { - if sqlRetriever == nil { - sqlRetriever = expression.NewSQLDigestTextRetriever() - } - - idx, waitChainIdx := r.currentIdx, r.currentWaitChainIdx - for i := start; i < end; i++ { - if idx >= len(r.deadlocks) { - return errors.New("reading information_schema.(cluster_)deadlocks table meets corrupted index") - } - - sqlRetriever.SQLDigestsMap[r.deadlocks[idx].WaitChain[waitChainIdx].SQLDigest] = "" - // Step to the next entry - idx, waitChainIdx = r.nextIndexPair(idx, waitChainIdx) - } - } - } - // Retrieve the SQL texts if necessary. 
- if sqlRetriever != nil { - err1 := sqlRetriever.RetrieveGlobal(ctx, sctx.GetRestrictedSQLExecutor()) - if err1 != nil { - return errors.Trace(err1) - } - } - - res = make([][]types.Datum, 0, end-start) - - for i := start; i < end; i++ { - if r.currentIdx >= len(r.deadlocks) { - return errors.New("reading information_schema.(cluster_)deadlocks table meets corrupted index") - } - - row := make([]types.Datum, 0, len(r.columns)) - deadlock := r.deadlocks[r.currentIdx] - waitChainItem := deadlock.WaitChain[r.currentWaitChainIdx] - - for _, c := range r.columns { - if c.Name.O == util.ClusterTableInstanceColumnName { - row = append(row, types.NewDatum(instanceAddr)) - } else if c.Name.O == deadlockhistory.ColCurrentSQLDigestTextStr { - if text, ok := sqlRetriever.SQLDigestsMap[waitChainItem.SQLDigest]; ok && len(text) > 0 { - row = append(row, types.NewDatum(text)) - } else { - row = append(row, types.NewDatum(nil)) - } - } else if c.Name.O == deadlockhistory.ColKeyInfoStr { - value := types.NewDatum(nil) - if len(waitChainItem.Key) > 0 { - decodedKey, err := keydecoder.DecodeKey(waitChainItem.Key, infoSchema) - if err == nil { - decodedKeyJSON, err := json.Marshal(decodedKey) - if err != nil { - logutil.BgLogger().Warn("marshal decoded key info to JSON failed", zap.Error(err)) - } else { - value = types.NewDatum(string(decodedKeyJSON)) - } - } else { - logutil.Logger(ctx).Warn("decode key failed", zap.Error(err)) - } - } - row = append(row, value) - } else { - row = append(row, deadlock.ToDatum(r.currentWaitChainIdx, c.Name.O)) - } - } - - res = append(res, row) - // Step to the next entry - r.currentIdx, r.currentWaitChainIdx = r.nextIndexPair(r.currentIdx, r.currentWaitChainIdx) - } - - return nil - }) - - if err != nil { - return nil, err - } - - return res, nil -} - -func adjustColumns(input [][]types.Datum, outColumns []*model.ColumnInfo, table *model.TableInfo) [][]types.Datum { - if len(outColumns) == len(table.Columns) { - return input - } - rows := make([][]types.Datum, len(input)) - for i, fullRow := range input { - row := make([]types.Datum, len(outColumns)) - for j, col := range outColumns { - row[j] = fullRow[col.Offset] - } - rows[i] = row - } - return rows -} - -// TiFlashSystemTableRetriever is used to read system table from tiflash. 
-type TiFlashSystemTableRetriever struct { - dummyCloser - table *model.TableInfo - outputCols []*model.ColumnInfo - instanceCount int - instanceIdx int - instanceIDs []string - rowIdx int - retrieved bool - initialized bool - extractor *plannercore.TiFlashSystemTableExtractor -} - -func (e *TiFlashSystemTableRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { - if e.extractor.SkipRequest || e.retrieved { - return nil, nil - } - if !e.initialized { - err := e.initialize(sctx, e.extractor.TiFlashInstances) - if err != nil { - return nil, err - } - } - if e.instanceCount == 0 || e.instanceIdx >= e.instanceCount { - e.retrieved = true - return nil, nil - } - - for { - rows, err := e.dataForTiFlashSystemTables(ctx, sctx, e.extractor.TiDBDatabases, e.extractor.TiDBTables) - if err != nil { - return nil, err - } - if len(rows) > 0 || e.instanceIdx >= e.instanceCount { - return rows, nil - } - } -} - -func (e *TiFlashSystemTableRetriever) initialize(sctx sessionctx.Context, tiflashInstances set.StringSet) error { - storeInfo, err := infoschema.GetStoreServerInfo(sctx.GetStore()) - if err != nil { - return err - } - - for _, info := range storeInfo { - if info.ServerType != kv.TiFlash.Name() { - continue - } - info.ResolveLoopBackAddr() - if len(tiflashInstances) > 0 && !tiflashInstances.Exist(info.Address) { - continue - } - hostAndStatusPort := strings.Split(info.StatusAddr, ":") - if len(hostAndStatusPort) != 2 { - return errors.Errorf("node status addr: %s format illegal", info.StatusAddr) - } - e.instanceIDs = append(e.instanceIDs, info.Address) - e.instanceCount++ - } - e.initialized = true - return nil -} - -type tiFlashSQLExecuteResponseMetaColumn struct { - Name string `json:"name"` - Type string `json:"type"` -} - -type tiFlashSQLExecuteResponse struct { - Meta []tiFlashSQLExecuteResponseMetaColumn `json:"meta"` - Data [][]any `json:"data"` -} - -func (e *TiFlashSystemTableRetriever) dataForTiFlashSystemTables(ctx context.Context, sctx sessionctx.Context, tidbDatabases string, tidbTables string) ([][]types.Datum, error) { - maxCount := 1024 - targetTable := strings.ToLower(strings.Replace(e.table.Name.O, "TIFLASH", "DT", 1)) - var filters []string - if len(tidbDatabases) > 0 { - filters = append(filters, fmt.Sprintf("tidb_database IN (%s)", strings.ReplaceAll(tidbDatabases, "\"", "'"))) - } - if len(tidbTables) > 0 { - filters = append(filters, fmt.Sprintf("tidb_table IN (%s)", strings.ReplaceAll(tidbTables, "\"", "'"))) - } - sql := fmt.Sprintf("SELECT * FROM system.%s", targetTable) - if len(filters) > 0 { - sql = fmt.Sprintf("%s WHERE %s", sql, strings.Join(filters, " AND ")) - } - sql = fmt.Sprintf("%s LIMIT %d, %d", sql, e.rowIdx, maxCount) - request := tikvrpc.Request{ - Type: tikvrpc.CmdGetTiFlashSystemTable, - StoreTp: tikvrpc.TiFlash, - Req: &kvrpcpb.TiFlashSystemTableRequest{ - Sql: sql, - }, - } - - store := sctx.GetStore() - tikvStore, ok := store.(tikv.Storage) - if !ok { - return nil, errors.New("Get tiflash system tables can only run with tikv compatible storage") - } - // send request to tiflash, timeout is 1s - instanceID := e.instanceIDs[e.instanceIdx] - resp, err := tikvStore.GetTiKVClient().SendRequest(ctx, instanceID, &request, time.Second) - if err != nil { - return nil, errors.Trace(err) - } - var result tiFlashSQLExecuteResponse - tiflashResp, ok := resp.Resp.(*kvrpcpb.TiFlashSystemTableResponse) - if !ok { - return nil, errors.Errorf("Unexpected response type: %T", resp.Resp) - } - err = json.Unmarshal(tiflashResp.Data, 
&result) - if err != nil { - return nil, errors.Wrapf(err, "Failed to decode JSON from TiFlash") - } - - // Map result columns back to our columns. It is possible that some columns cannot be - // recognized and some other columns are missing. This may happen during upgrading. - outputColIndexMap := map[string]int{} // Map from TiDB Column name to Output Column Index - for idx, c := range e.outputCols { - outputColIndexMap[c.Name.L] = idx - } - tiflashColIndexMap := map[int]int{} // Map from TiFlash Column index to Output Column Index - for tiFlashColIdx, col := range result.Meta { - if outputIdx, ok := outputColIndexMap[strings.ToLower(col.Name)]; ok { - tiflashColIndexMap[tiFlashColIdx] = outputIdx - } - } - outputRows := make([][]types.Datum, 0, len(result.Data)) - for _, rowFields := range result.Data { - if len(rowFields) == 0 { - continue - } - outputRow := make([]types.Datum, len(e.outputCols)) - for tiFlashColIdx, fieldValue := range rowFields { - outputIdx, ok := tiflashColIndexMap[tiFlashColIdx] - if !ok { - // Discard this field, we don't know which output column is the destination - continue - } - if fieldValue == nil { - continue - } - valStr := fmt.Sprint(fieldValue) - column := e.outputCols[outputIdx] - if column.GetType() == mysql.TypeVarchar { - outputRow[outputIdx].SetString(valStr, mysql.DefaultCollationName) - } else if column.GetType() == mysql.TypeLonglong { - value, err := strconv.ParseInt(valStr, 10, 64) - if err != nil { - return nil, errors.Trace(err) - } - outputRow[outputIdx].SetInt64(value) - } else if column.GetType() == mysql.TypeDouble { - value, err := strconv.ParseFloat(valStr, 64) - if err != nil { - return nil, errors.Trace(err) - } - outputRow[outputIdx].SetFloat64(value) - } else { - return nil, errors.Errorf("Meet column of unknown type %v", column) - } - } - outputRow[len(e.outputCols)-1].SetString(instanceID, mysql.DefaultCollationName) - outputRows = append(outputRows, outputRow) - } - e.rowIdx += len(outputRows) - if len(outputRows) < maxCount { - e.instanceIdx++ - e.rowIdx = 0 - } - return outputRows, nil -} - -func (e *memtableRetriever) setDataForAttributes(ctx context.Context, sctx sessionctx.Context, is infoschema.InfoSchema) error { - checker := privilege.GetPrivilegeManager(sctx) - rules, err := infosync.GetAllLabelRules(context.TODO()) - skipValidateTable := false - failpoint.Inject("mockOutputOfAttributes", func() { - convert := func(i any) []any { - return []any{i} - } - rules = []*label.Rule{ - { - ID: "schema/test/test_label", - Labels: []pd.RegionLabel{{Key: "merge_option", Value: "allow"}, {Key: "db", Value: "test"}, {Key: "table", Value: "test_label"}}, - RuleType: "key-range", - Data: convert(map[string]any{ - "start_key": "7480000000000000ff395f720000000000fa", - "end_key": "7480000000000000ff3a5f720000000000fa", - }), - }, - { - ID: "invalidIDtest", - Labels: []pd.RegionLabel{{Key: "merge_option", Value: "allow"}, {Key: "db", Value: "test"}, {Key: "table", Value: "test_label"}}, - RuleType: "key-range", - Data: convert(map[string]any{ - "start_key": "7480000000000000ff395f720000000000fa", - "end_key": "7480000000000000ff3a5f720000000000fa", - }), - }, - { - ID: "schema/test/test_label", - Labels: []pd.RegionLabel{{Key: "merge_option", Value: "allow"}, {Key: "db", Value: "test"}, {Key: "table", Value: "test_label"}}, - RuleType: "key-range", - Data: convert(map[string]any{ - "start_key": "aaaaa", - "end_key": "bbbbb", - }), - }, - } - err = nil - skipValidateTable = true - }) - - if err != nil { - return errors.Wrap(err, "get the 
label rules failed") - } - - rows := make([][]types.Datum, 0, len(rules)) - for _, rule := range rules { - skip := true - dbName, tableName, partitionName, err := checkRule(rule) - if err != nil { - logutil.BgLogger().Warn("check table-rule failed", zap.String("ID", rule.ID), zap.Error(err)) - continue - } - tableID, err := decodeTableIDFromRule(rule) - if err != nil { - logutil.BgLogger().Warn("decode table ID from rule failed", zap.String("ID", rule.ID), zap.Error(err)) - continue - } - - if !skipValidateTable && tableOrPartitionNotExist(ctx, dbName, tableName, partitionName, is, tableID) { - continue - } - - if tableName != "" && dbName != "" && (checker == nil || checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, dbName, tableName, "", mysql.SelectPriv)) { - skip = false - } - if skip { - continue - } - - labels := label.RestoreRegionLabels(&rule.Labels) - var ranges []string - for _, data := range rule.Data.([]any) { - if kv, ok := data.(map[string]any); ok { - startKey := kv["start_key"] - endKey := kv["end_key"] - ranges = append(ranges, fmt.Sprintf("[%s, %s]", startKey, endKey)) - } - } - kr := strings.Join(ranges, ", ") - - row := types.MakeDatums( - rule.ID, - rule.RuleType, - labels, - kr, - ) - rows = append(rows, row) - } - e.rows = rows - return nil -} - -func (e *memtableRetriever) setDataFromPlacementPolicies(sctx sessionctx.Context) error { - is := sessiontxn.GetTxnManager(sctx).GetTxnInfoSchema() - placementPolicies := is.AllPlacementPolicies() - rows := make([][]types.Datum, 0, len(placementPolicies)) - // Get global PLACEMENT POLICIES - // Currently no privileges needed for seeing global PLACEMENT POLICIES! - for _, policy := range placementPolicies { - // Currently we skip converting syntactic sugar. We might revisit this decision still in the future - // I.e.: if PrimaryRegion or Regions are set, - // also convert them to LeaderConstraints and FollowerConstraints - // for better user experience searching for particular constraints - - // Followers == 0 means not set, so the default value 2 will be used - followerCnt := policy.PlacementSettings.Followers - if followerCnt == 0 { - followerCnt = 2 - } - - row := types.MakeDatums( - policy.ID, - infoschema.CatalogVal, // CATALOG - policy.Name.O, // Policy Name - policy.PlacementSettings.PrimaryRegion, - policy.PlacementSettings.Regions, - policy.PlacementSettings.Constraints, - policy.PlacementSettings.LeaderConstraints, - policy.PlacementSettings.FollowerConstraints, - policy.PlacementSettings.LearnerConstraints, - policy.PlacementSettings.Schedule, - followerCnt, - policy.PlacementSettings.Learners, - ) - rows = append(rows, row) - } - e.rows = rows - return nil -} - -func (e *memtableRetriever) setDataFromRunawayWatches(sctx sessionctx.Context) error { - do := domain.GetDomain(sctx) - err := do.TryToUpdateRunawayWatch() - if err != nil { - logutil.BgLogger().Warn("read runaway watch list", zap.Error(err)) - } - watches := do.GetRunawayWatchList() - rows := make([][]types.Datum, 0, len(watches)) - for _, watch := range watches { - action := watch.Action - row := types.MakeDatums( - watch.ID, - watch.ResourceGroupName, - watch.StartTime.UTC().Format(time.DateTime), - watch.EndTime.UTC().Format(time.DateTime), - rmpb.RunawayWatchType_name[int32(watch.Watch)], - watch.WatchText, - watch.Source, - rmpb.RunawayAction_name[int32(action)], - ) - if watch.EndTime.Equal(resourcegroup.NullTime) { - row[3].SetString("UNLIMITED", mysql.DefaultCollationName) - } - rows = append(rows, row) - } - e.rows = rows - return nil 
-} - -// used in resource_groups -const ( - burstableStr = "YES" - burstdisableStr = "NO" - unlimitedFillRate = "UNLIMITED" -) - -func (e *memtableRetriever) setDataFromResourceGroups() error { - resourceGroups, err := infosync.ListResourceGroups(context.TODO()) - if err != nil { - return errors.Errorf("failed to access resource group manager, error message is %s", err.Error()) - } - rows := make([][]types.Datum, 0, len(resourceGroups)) - for _, group := range resourceGroups { - //mode := "" - burstable := burstdisableStr - priority := model.PriorityValueToName(uint64(group.Priority)) - fillrate := unlimitedFillRate - // RU_PER_SEC = unlimited like the default group settings. - isDefaultInReservedSetting := group.RUSettings.RU.Settings.FillRate == math.MaxInt32 - if !isDefaultInReservedSetting { - fillrate = strconv.FormatUint(group.RUSettings.RU.Settings.FillRate, 10) - } - // convert runaway settings - limitBuilder := new(strings.Builder) - if setting := group.RunawaySettings; setting != nil { - if setting.Rule == nil { - return errors.Errorf("unexpected runaway config in resource group") - } - dur := time.Duration(setting.Rule.ExecElapsedTimeMs) * time.Millisecond - fmt.Fprintf(limitBuilder, "EXEC_ELAPSED='%s'", dur.String()) - fmt.Fprintf(limitBuilder, ", ACTION=%s", model.RunawayActionType(setting.Action).String()) - if setting.Watch != nil { - if setting.Watch.LastingDurationMs > 0 { - dur := time.Duration(setting.Watch.LastingDurationMs) * time.Millisecond - fmt.Fprintf(limitBuilder, ", WATCH=%s DURATION='%s'", model.RunawayWatchType(setting.Watch.Type).String(), dur.String()) - } else { - fmt.Fprintf(limitBuilder, ", WATCH=%s DURATION=UNLIMITED", model.RunawayWatchType(setting.Watch.Type).String()) - } - } - } - queryLimit := limitBuilder.String() - - // convert background settings - bgBuilder := new(strings.Builder) - if setting := group.BackgroundSettings; setting != nil { - fmt.Fprintf(bgBuilder, "TASK_TYPES='%s'", strings.Join(setting.JobTypes, ",")) - } - background := bgBuilder.String() - - switch group.Mode { - case rmpb.GroupMode_RUMode: - if group.RUSettings.RU.Settings.BurstLimit < 0 { - burstable = burstableStr - } - row := types.MakeDatums( - group.Name, - fillrate, - priority, - burstable, - queryLimit, - background, - ) - if len(queryLimit) == 0 { - row[4].SetNull() - } - if len(background) == 0 { - row[5].SetNull() - } - rows = append(rows, row) - default: - //mode = "UNKNOWN_MODE" - row := types.MakeDatums( - group.Name, - nil, - nil, - nil, - nil, - nil, - ) - rows = append(rows, row) - } - } - e.rows = rows - return nil -} - -func (e *memtableRetriever) setDataFromKeywords() error { - rows := make([][]types.Datum, 0, len(parser.Keywords)) - for _, kw := range parser.Keywords { - row := types.MakeDatums(kw.Word, kw.Reserved) - rows = append(rows, row) - } - e.rows = rows - return nil -} - -func (e *memtableRetriever) setDataFromIndexUsage(ctx context.Context, sctx sessionctx.Context, schemas []model.CIStr) error { - dom := domain.GetDomain(sctx) - rows := make([][]types.Datum, 0, 100) - checker := privilege.GetPrivilegeManager(sctx) - extractor, ok := e.extractor.(*plannercore.InfoSchemaBaseExtractor) - if ok && extractor.SkipRequest { - return nil - } - - for _, schema := range schemas { - if ok && extractor.Filter("table_schema", schema.L) { - continue - } - tables, err := dom.InfoSchema().SchemaTableInfos(ctx, schema) - if err != nil { - return errors.Trace(err) - } - for _, tbl := range tables { - if ok && extractor.Filter("table_name", tbl.Name.L) { - continue 
- } - allowed := checker == nil || checker.RequestVerification( - sctx.GetSessionVars().ActiveRoles, - schema.L, tbl.Name.L, "", mysql.AllPrivMask) - if !allowed { - continue - } - - for _, idx := range tbl.Indices { - if ok && extractor.Filter("index_name", idx.Name.L) { - continue - } - row := make([]types.Datum, 0, 14) - usage := dom.StatsHandle().GetIndexUsage(tbl.ID, idx.ID) - row = append(row, types.NewStringDatum(schema.O)) - row = append(row, types.NewStringDatum(tbl.Name.O)) - row = append(row, types.NewStringDatum(idx.Name.O)) - row = append(row, types.NewIntDatum(int64(usage.QueryTotal))) - row = append(row, types.NewIntDatum(int64(usage.KvReqTotal))) - row = append(row, types.NewIntDatum(int64(usage.RowAccessTotal))) - for _, percentage := range usage.PercentageAccess { - row = append(row, types.NewIntDatum(int64(percentage))) - } - lastUsedAt := types.Datum{} - lastUsedAt.SetNull() - if !usage.LastUsedAt.IsZero() { - t := types.NewTime(types.FromGoTime(usage.LastUsedAt), mysql.TypeTimestamp, 0) - lastUsedAt = types.NewTimeDatum(t) - } - row = append(row, lastUsedAt) - rows = append(rows, row) - } - } - } - - e.rows = rows - return nil -} - -func (e *memtableRetriever) setDataForClusterIndexUsage(ctx context.Context, sctx sessionctx.Context, schemas []model.CIStr) error { - err := e.setDataFromIndexUsage(ctx, sctx, schemas) - if err != nil { - return errors.Trace(err) - } - rows, err := infoschema.AppendHostInfoToRows(sctx, e.rows) - if err != nil { - return err - } - e.rows = rows - return nil -} - -func checkRule(rule *label.Rule) (dbName, tableName string, partitionName string, err error) { - s := strings.Split(rule.ID, "/") - if len(s) < 3 { - err = errors.Errorf("invalid label rule ID: %v", rule.ID) - return - } - if rule.RuleType == "" { - err = errors.New("empty label rule type") - return - } - if rule.Labels == nil || len(rule.Labels) == 0 { - err = errors.New("the label rule has no label") - return - } - if rule.Data == nil { - err = errors.New("the label rule has no data") - return - } - dbName = s[1] - tableName = s[2] - if len(s) > 3 { - partitionName = s[3] - } - return -} - -func decodeTableIDFromRule(rule *label.Rule) (tableID int64, err error) { - datas := rule.Data.([]any) - if len(datas) == 0 { - err = fmt.Errorf("there is no data in rule %s", rule.ID) - return - } - data := datas[0] - dataMap, ok := data.(map[string]any) - if !ok { - err = fmt.Errorf("get the label rules %s failed", rule.ID) - return - } - key, err := hex.DecodeString(fmt.Sprintf("%s", dataMap["start_key"])) - if err != nil { - err = fmt.Errorf("decode key from start_key %s in rule %s failed", dataMap["start_key"], rule.ID) - return - } - _, bs, err := codec.DecodeBytes(key, nil) - if err == nil { - key = bs - } - tableID = tablecodec.DecodeTableID(key) - if tableID == 0 { - err = fmt.Errorf("decode tableID from key %s in rule %s failed", key, rule.ID) - return - } - return -} - -func tableOrPartitionNotExist(ctx context.Context, dbName string, tableName string, partitionName string, is infoschema.InfoSchema, tableID int64) (tableNotExist bool) { - if len(partitionName) == 0 { - curTable, _ := is.TableByName(ctx, model.NewCIStr(dbName), model.NewCIStr(tableName)) - if curTable == nil { - return true - } - curTableID := curTable.Meta().ID - if curTableID != tableID { - return true - } - } else { - _, _, partInfo := is.FindTableByPartitionID(tableID) - if partInfo == nil { - return true - } - } - return false -} diff --git a/pkg/executor/inspection_result.go b/pkg/executor/inspection_result.go 
index 8d8251a77b1e2..8596f562ca720 100644 --- a/pkg/executor/inspection_result.go +++ b/pkg/executor/inspection_result.go @@ -128,7 +128,7 @@ func (e *inspectionResultRetriever) retrieve(ctx context.Context, sctx sessionct sctx.GetSessionVars().InspectionTableCache = map[string]variable.TableSnapshot{} defer func() { sctx.GetSessionVars().InspectionTableCache = nil }() - if _, _err_ := failpoint.EvalContext(ctx, _curpkg_("mockMergeMockInspectionTables")); _err_ == nil { + failpoint.InjectContext(ctx, "mockMergeMockInspectionTables", func() { // Merge mock snapshots injected from failpoint for test purpose mockTables, ok := ctx.Value("__mockInspectionTables").(map[string]variable.TableSnapshot) if ok { @@ -136,7 +136,7 @@ func (e *inspectionResultRetriever) retrieve(ctx context.Context, sctx sessionct sctx.GetSessionVars().InspectionTableCache[strings.ToLower(name)] = snap } } - } + }) if e.instanceToStatusAddress == nil { // Get cluster info. diff --git a/pkg/executor/inspection_result.go__failpoint_stash__ b/pkg/executor/inspection_result.go__failpoint_stash__ deleted file mode 100644 index 8596f562ca720..0000000000000 --- a/pkg/executor/inspection_result.go__failpoint_stash__ +++ /dev/null @@ -1,1248 +0,0 @@ -// Copyright 2019 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package executor - -import ( - "cmp" - "context" - "fmt" - "math" - "slices" - "strconv" - "strings" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - plannerutil "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/set" - "github.com/pingcap/tidb/pkg/util/size" -) - -type ( - // inspectionResult represents a abnormal diagnosis result - inspectionResult struct { - tp string - instance string - statusAddress string - // represents the diagnostics item, e.g: `ddl.lease` `raftstore.cpuusage` - item string - // diagnosis result value base on current cluster status - actual string - expected string - severity string - detail string - // degree only used for sort. 
- degree float64 - } - - inspectionName string - - inspectionFilter struct { - set set.StringSet - timeRange plannerutil.QueryTimeRange - } - - inspectionRule interface { - name() string - inspect(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult - } -) - -func (n inspectionName) name() string { - return string(n) -} - -func (f inspectionFilter) enable(name string) bool { - return len(f.set) == 0 || f.set.Exist(name) -} - -type ( - // configInspection is used to check whether a same configuration item has a - // different value between different instance in the cluster - configInspection struct{ inspectionName } - - // versionInspection is used to check whether the same component has different - // version in the cluster - versionInspection struct{ inspectionName } - - // nodeLoadInspection is used to check the node load of memory/disk/cpu - // have reached a high-level threshold - nodeLoadInspection struct{ inspectionName } - - // criticalErrorInspection is used to check are there some critical errors - // occurred in the past - criticalErrorInspection struct{ inspectionName } - - // thresholdCheckInspection is used to check some threshold value, like CPU usage, leader count change. - thresholdCheckInspection struct{ inspectionName } -) - -var inspectionRules = []inspectionRule{ - &configInspection{inspectionName: "config"}, - &versionInspection{inspectionName: "version"}, - &nodeLoadInspection{inspectionName: "node-load"}, - &criticalErrorInspection{inspectionName: "critical-error"}, - &thresholdCheckInspection{inspectionName: "threshold-check"}, -} - -type inspectionResultRetriever struct { - dummyCloser - retrieved bool - extractor *plannercore.InspectionResultTableExtractor - timeRange plannerutil.QueryTimeRange - instanceToStatusAddress map[string]string - statusToInstanceAddress map[string]string -} - -func (e *inspectionResultRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { - if e.retrieved || e.extractor.SkipInspection { - return nil, nil - } - e.retrieved = true - - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnMeta) - // Some data of cluster-level memory tables will be retrieved many times in different inspection rules, - // and the cost of retrieving some data is expensive. We use the `TableSnapshot` to cache those data - // and obtain them lazily, and provide a consistent view of inspection tables for each inspection rules. - // All cached snapshots should be released at the end of retrieving. - sctx.GetSessionVars().InspectionTableCache = map[string]variable.TableSnapshot{} - defer func() { sctx.GetSessionVars().InspectionTableCache = nil }() - - failpoint.InjectContext(ctx, "mockMergeMockInspectionTables", func() { - // Merge mock snapshots injected from failpoint for test purpose - mockTables, ok := ctx.Value("__mockInspectionTables").(map[string]variable.TableSnapshot) - if ok { - for name, snap := range mockTables { - sctx.GetSessionVars().InspectionTableCache[strings.ToLower(name)] = snap - } - } - }) - - if e.instanceToStatusAddress == nil { - // Get cluster info. 
- e.instanceToStatusAddress = make(map[string]string) - e.statusToInstanceAddress = make(map[string]string) - var rows []chunk.Row - exec := sctx.GetRestrictedSQLExecutor() - rows, _, err := exec.ExecRestrictedSQL(ctx, nil, "select instance,status_address from information_schema.cluster_info;") - if err != nil { - sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("get cluster info failed: %v", err)) - } - for _, row := range rows { - if row.Len() < 2 { - continue - } - e.instanceToStatusAddress[row.GetString(0)] = row.GetString(1) - e.statusToInstanceAddress[row.GetString(1)] = row.GetString(0) - } - } - - rules := inspectionFilter{set: e.extractor.Rules} - items := inspectionFilter{set: e.extractor.Items, timeRange: e.timeRange} - var finalRows [][]types.Datum - for _, r := range inspectionRules { - name := r.name() - if !rules.enable(name) { - continue - } - results := r.inspect(ctx, sctx, items) - if len(results) == 0 { - continue - } - // make result stable - slices.SortFunc(results, func(i, j inspectionResult) int { - if c := cmp.Compare(i.degree, j.degree); c != 0 { - return -c - } - // lhs and rhs - if c := cmp.Compare(i.item, j.item); c != 0 { - return c - } - if c := cmp.Compare(i.actual, j.actual); c != 0 { - return c - } - // lhs and rhs - if c := cmp.Compare(i.tp, j.tp); c != 0 { - return c - } - return cmp.Compare(i.instance, j.instance) - }) - for _, result := range results { - if len(result.instance) == 0 { - result.instance = e.statusToInstanceAddress[result.statusAddress] - } - if len(result.statusAddress) == 0 { - result.statusAddress = e.instanceToStatusAddress[result.instance] - } - finalRows = append(finalRows, types.MakeDatums( - name, - result.item, - result.tp, - result.instance, - result.statusAddress, - result.actual, - result.expected, - result.severity, - result.detail, - )) - } - } - return finalRows, nil -} - -func (c configInspection) inspect(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { - var results []inspectionResult - results = append(results, c.inspectDiffConfig(ctx, sctx, filter)...) - results = append(results, c.inspectCheckConfig(ctx, sctx, filter)...) - return results -} - -func (configInspection) inspectDiffConfig(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { - // check the configuration consistent - ignoreConfigKey := []string{ - // TiDB - "port", - "status.status-port", - "host", - "path", - "advertise-address", - "status.status-port", - "log.file.filename", - "log.slow-query-file", - "tmp-storage-path", - - // PD - "advertise-client-urls", - "advertise-peer-urls", - "client-urls", - "data-dir", - "log-file", - "log.file.filename", - "metric.job", - "name", - "peer-urls", - - // TiKV - "server.addr", - "server.advertise-addr", - "server.advertise-status-addr", - "server.status-addr", - "log-file", - "raftstore.raftdb-path", - "storage.data-dir", - "storage.block-cache.capacity", - } - exec := sctx.GetRestrictedSQLExecutor() - rows, _, err := exec.ExecRestrictedSQL(ctx, nil, "select type, `key`, count(distinct value) as c from information_schema.cluster_config where `key` not in (%?) group by type, `key` having c > 1", ignoreConfigKey) - if err != nil { - sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration consistency failed: %v", err)) - } - - generateDetail := func(tp, item string) string { - rows, _, err := exec.ExecRestrictedSQL(ctx, nil, "select value, instance from information_schema.cluster_config where type=%? 
and `key`=%?;", tp, item) - if err != nil { - sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration consistency failed: %v", err)) - return fmt.Sprintf("the cluster has different config value of %[2]s, execute the sql to see more detail: select * from information_schema.cluster_config where type='%[1]s' and `key`='%[2]s'", - tp, item) - } - m := make(map[string][]string) - for _, row := range rows { - value := row.GetString(0) - instance := row.GetString(1) - m[value] = append(m[value], instance) - } - groups := make([]string, 0, len(m)) - for k, v := range m { - slices.Sort(v) - groups = append(groups, fmt.Sprintf("%s config value is %s", strings.Join(v, ","), k)) - } - slices.Sort(groups) - return strings.Join(groups, "\n") - } - - var results []inspectionResult - for _, row := range rows { - if filter.enable(row.GetString(1)) { - detail := generateDetail(row.GetString(0), row.GetString(1)) - results = append(results, inspectionResult{ - tp: row.GetString(0), - instance: "", - item: row.GetString(1), // key - actual: "inconsistent", - expected: "consistent", - severity: "warning", - detail: detail, - }) - } - } - return results -} - -func (c configInspection) inspectCheckConfig(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { - // check the configuration in reason. - cases := []struct { - table string - tp string - key string - expect string - cond string - detail string - }{ - { - table: "cluster_config", - key: "log.slow-threshold", - expect: "> 0", - cond: "type = 'tidb' and `key` = 'log.slow-threshold' and value = '0'", - detail: "slow-threshold = 0 will record every query to slow log, it may affect performance", - }, - { - - table: "cluster_config", - key: "raftstore.sync-log", - expect: "true", - cond: "type = 'tikv' and `key` = 'raftstore.sync-log' and value = 'false'", - detail: "sync-log should be true to avoid recover region when the machine breaks down", - }, - { - table: "cluster_systeminfo", - key: "transparent_hugepage_enabled", - expect: "always madvise [never]", - cond: "system_name = 'kernel' and name = 'transparent_hugepage_enabled' and value not like '%[never]%'", - detail: "Transparent HugePages can cause memory allocation delays during runtime, TiDB recommends that you disable Transparent HugePages on all TiDB servers", - }, - } - - var results []inspectionResult - var rows []chunk.Row - sql := new(strings.Builder) - exec := sctx.GetRestrictedSQLExecutor() - for _, cas := range cases { - if !filter.enable(cas.key) { - continue - } - sql.Reset() - fmt.Fprintf(sql, "select type,instance,value from information_schema.%s where %s", cas.table, cas.cond) - stmt, err := exec.ParseWithParams(ctx, sql.String()) - if err == nil { - rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) - } - if err != nil { - sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration in reason failed: %v", err)) - } - - for _, row := range rows { - results = append(results, inspectionResult{ - tp: row.GetString(0), - instance: row.GetString(1), - item: cas.key, - actual: row.GetString(2), - expected: cas.expect, - severity: "warning", - detail: cas.detail, - }) - } - } - results = append(results, c.checkTiKVBlockCacheSizeConfig(ctx, sctx, filter)...) 
- return results -} - -func (c configInspection) checkTiKVBlockCacheSizeConfig(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { - item := "storage.block-cache.capacity" - if !filter.enable(item) { - return nil - } - exec := sctx.GetRestrictedSQLExecutor() - rows, _, err := exec.ExecRestrictedSQL(ctx, nil, "select instance,value from information_schema.cluster_config where type='tikv' and `key` = 'storage.block-cache.capacity'") - if err != nil { - sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration in reason failed: %v", err)) - } - extractIP := func(addr string) string { - if idx := strings.Index(addr, ":"); idx > -1 { - return addr[0:idx] - } - return addr - } - - ipToBlockSize := make(map[string]uint64) - ipToCount := make(map[string]int) - for _, row := range rows { - ip := extractIP(row.GetString(0)) - size, err := c.convertReadableSizeToByteSize(row.GetString(1)) - if err != nil { - sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check TiKV block-cache configuration in reason failed: %v", err)) - return nil - } - ipToBlockSize[ip] += size - ipToCount[ip]++ - } - - rows, _, err = exec.ExecRestrictedSQL(ctx, nil, "select instance, value from metrics_schema.node_total_memory where time=now()") - if err != nil { - sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check configuration in reason failed: %v", err)) - } - ipToMemorySize := make(map[string]float64) - for _, row := range rows { - ip := extractIP(row.GetString(0)) - size := row.GetFloat64(1) - ipToMemorySize[ip] += size - } - - var results []inspectionResult - for ip, blockSize := range ipToBlockSize { - if memorySize, ok := ipToMemorySize[ip]; ok { - if float64(blockSize) > memorySize*0.45 { - detail := fmt.Sprintf("There are %v TiKV server in %v node, the total 'storage.block-cache.capacity' of TiKV is more than (0.45 * total node memory)", - ipToCount[ip], ip) - results = append(results, inspectionResult{ - tp: "tikv", - instance: ip, - item: item, - actual: fmt.Sprintf("%v", blockSize), - expected: fmt.Sprintf("< %.0f", memorySize*0.45), - severity: "warning", - detail: detail, - }) - } - } - } - return results -} - -func (configInspection) convertReadableSizeToByteSize(sizeStr string) (uint64, error) { - rate := uint64(1) - if strings.HasSuffix(sizeStr, "KiB") { - rate = size.KB - } else if strings.HasSuffix(sizeStr, "MiB") { - rate = size.MB - } else if strings.HasSuffix(sizeStr, "GiB") { - rate = size.GB - } else if strings.HasSuffix(sizeStr, "TiB") { - rate = size.TB - } else if strings.HasSuffix(sizeStr, "PiB") { - rate = size.PB - } - if rate != 1 && len(sizeStr) > 3 { - sizeStr = sizeStr[:len(sizeStr)-3] - } - size, err := strconv.Atoi(sizeStr) - if err != nil { - return 0, errors.Trace(err) - } - return uint64(size) * rate, nil -} - -func (versionInspection) inspect(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { - exec := sctx.GetRestrictedSQLExecutor() - // check the configuration consistent - rows, _, err := exec.ExecRestrictedSQL(ctx, nil, "select type, count(distinct git_hash) as c from information_schema.cluster_info group by type having c > 1;") - if err != nil { - sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("check version consistency failed: %v", err)) - } - - const name = "git_hash" - var results []inspectionResult - for _, row := range rows { - if filter.enable(name) { - results = append(results, inspectionResult{ - tp: row.GetString(0), - instance: "", - item: name, - 
actual: "inconsistent", - expected: "consistent", - severity: "critical", - detail: fmt.Sprintf("the cluster has %[1]v different %[2]s versions, execute the sql to see more detail: select * from information_schema.cluster_info where type='%[2]s'", row.GetUint64(1), row.GetString(0)), - }) - } - } - return results -} - -func (nodeLoadInspection) inspect(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { - var rules = []ruleChecker{ - inspectCPULoad{item: "load1", tbl: "node_load1"}, - inspectCPULoad{item: "load5", tbl: "node_load5"}, - inspectCPULoad{item: "load15", tbl: "node_load15"}, - inspectVirtualMemUsage{}, - inspectSwapMemoryUsed{}, - inspectDiskUsage{}, - } - return checkRules(ctx, sctx, filter, rules) -} - -type inspectVirtualMemUsage struct{} - -func (inspectVirtualMemUsage) genSQL(timeRange plannerutil.QueryTimeRange) string { - sql := fmt.Sprintf("select instance, max(value) as max_usage from metrics_schema.node_memory_usage %s group by instance having max_usage >= 70", timeRange.Condition()) - return sql -} - -func (i inspectVirtualMemUsage) genResult(_ string, row chunk.Row) inspectionResult { - return inspectionResult{ - tp: "node", - instance: row.GetString(0), - item: i.getItem(), - actual: fmt.Sprintf("%.1f%%", row.GetFloat64(1)), - expected: "< 70%", - severity: "warning", - detail: "the memory-usage is too high", - } -} - -func (inspectVirtualMemUsage) getItem() string { - return "virtual-memory-usage" -} - -type inspectSwapMemoryUsed struct{} - -func (inspectSwapMemoryUsed) genSQL(timeRange plannerutil.QueryTimeRange) string { - sql := fmt.Sprintf("select instance, max(value) as max_used from metrics_schema.node_memory_swap_used %s group by instance having max_used > 0", timeRange.Condition()) - return sql -} - -func (i inspectSwapMemoryUsed) genResult(_ string, row chunk.Row) inspectionResult { - return inspectionResult{ - tp: "node", - instance: row.GetString(0), - item: i.getItem(), - actual: fmt.Sprintf("%.1f", row.GetFloat64(1)), - expected: "0", - severity: "warning", - } -} - -func (inspectSwapMemoryUsed) getItem() string { - return "swap-memory-used" -} - -type inspectDiskUsage struct{} - -func (inspectDiskUsage) genSQL(timeRange plannerutil.QueryTimeRange) string { - sql := fmt.Sprintf("select instance, device, max(value) as max_usage from metrics_schema.node_disk_usage %v and device like '/%%' group by instance, device having max_usage >= 70", timeRange.Condition()) - return sql -} - -func (i inspectDiskUsage) genResult(_ string, row chunk.Row) inspectionResult { - return inspectionResult{ - tp: "node", - instance: row.GetString(0), - item: i.getItem(), - actual: fmt.Sprintf("%.1f%%", row.GetFloat64(2)), - expected: "< 70%", - severity: "warning", - detail: "the disk-usage of " + row.GetString(1) + " is too high", - } -} - -func (inspectDiskUsage) getItem() string { - return "disk-usage" -} - -type inspectCPULoad struct { - item string - tbl string -} - -func (i inspectCPULoad) genSQL(timeRange plannerutil.QueryTimeRange) string { - sql := fmt.Sprintf(`select t1.instance, t1.max_load , 0.7*t2.cpu_count from - (select instance,max(value) as max_load from metrics_schema.%[1]s %[2]s group by instance) as t1 join - (select instance,max(value) as cpu_count from metrics_schema.node_virtual_cpus %[2]s group by instance) as t2 - on t1.instance=t2.instance where t1.max_load>(0.7*t2.cpu_count);`, i.tbl, timeRange.Condition()) - return sql -} - -func (i inspectCPULoad) genResult(_ string, row chunk.Row) inspectionResult { - 
return inspectionResult{ - tp: "node", - instance: row.GetString(0), - item: "cpu-" + i.item, - actual: fmt.Sprintf("%.1f", row.GetFloat64(1)), - expected: fmt.Sprintf("< %.1f", row.GetFloat64(2)), - severity: "warning", - detail: i.getItem() + " should less than (cpu_logical_cores * 0.7)", - } -} - -func (i inspectCPULoad) getItem() string { - return "cpu-" + i.item -} - -func (c criticalErrorInspection) inspect(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { - results := c.inspectError(ctx, sctx, filter) - results = append(results, c.inspectForServerDown(ctx, sctx, filter)...) - return results -} -func (criticalErrorInspection) inspectError(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { - var rules = []struct { - tp string - item string - tbl string - }{ - {tp: "tikv", item: "critical-error", tbl: "tikv_critical_error_total_count"}, - {tp: "tidb", item: "panic-count", tbl: "tidb_panic_count_total_count"}, - {tp: "tidb", item: "binlog-error", tbl: "tidb_binlog_error_total_count"}, - {tp: "tikv", item: "scheduler-is-busy", tbl: "tikv_scheduler_is_busy_total_count"}, - {tp: "tikv", item: "coprocessor-is-busy", tbl: "tikv_coprocessor_is_busy_total_count"}, - {tp: "tikv", item: "channel-is-full", tbl: "tikv_channel_full_total_count"}, - {tp: "tikv", item: "tikv_engine_write_stall", tbl: "tikv_engine_write_stall"}, - } - - condition := filter.timeRange.Condition() - var results []inspectionResult - exec := sctx.GetRestrictedSQLExecutor() - sql := new(strings.Builder) - for _, rule := range rules { - if filter.enable(rule.item) { - def, found := infoschema.MetricTableMap[rule.tbl] - if !found { - sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("metrics table: %s not found", rule.tbl)) - continue - } - sql.Reset() - fmt.Fprintf(sql, "select `%[1]s`,sum(value) as total from `%[2]s`.`%[3]s` %[4]s group by `%[1]s` having total>=1.0", - strings.Join(def.Labels, "`,`"), util.MetricSchemaName.L, rule.tbl, condition) - rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql.String()) - if err != nil { - sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) - continue - } - for _, row := range rows { - var actual, detail string - var degree float64 - if rest := def.Labels[1:]; len(rest) > 0 { - values := make([]string, 0, len(rest)) - // `i+1` and `1+len(rest)` means skip the first field `instance` - for i := range rest { - values = append(values, row.GetString(i+1)) - } - // TODO: find a better way to construct the `actual` field - actual = fmt.Sprintf("%.2f(%s)", row.GetFloat64(1+len(rest)), strings.Join(values, ", ")) - degree = row.GetFloat64(1 + len(rest)) - } else { - actual = fmt.Sprintf("%.2f", row.GetFloat64(1)) - degree = row.GetFloat64(1) - } - detail = fmt.Sprintf("the total number of errors about '%s' is too many", rule.item) - result := inspectionResult{ - tp: rule.tp, - // NOTE: all tables which can be inspected here whose first label must be `instance` - statusAddress: row.GetString(0), - item: rule.item, - actual: actual, - expected: "0", - severity: "critical", - detail: detail, - degree: degree, - } - results = append(results, result) - } - } - } - return results -} - -func (criticalErrorInspection) inspectForServerDown(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { - item := "server-down" - if !filter.enable(item) { - return nil - } - condition := filter.timeRange.Condition() - exec := 
sctx.GetRestrictedSQLExecutor() - sql := new(strings.Builder) - fmt.Fprintf(sql, `select t1.job,t1.instance, t2.min_time from - (select instance,job from metrics_schema.up %[1]s group by instance,job having max(value)-min(value)>0) as t1 join - (select instance,min(time) as min_time from metrics_schema.up %[1]s and value=0 group by instance,job) as t2 on t1.instance=t2.instance order by job`, condition) - rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql.String()) - if err != nil { - sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) - } - results := make([]inspectionResult, 0, len(rows)) - for _, row := range rows { - if row.Len() < 3 { - continue - } - detail := fmt.Sprintf("%s %s disconnect with prometheus around time '%s'", row.GetString(0), row.GetString(1), row.GetTime(2)) - result := inspectionResult{ - tp: row.GetString(0), - statusAddress: row.GetString(1), - item: item, - actual: "", - expected: "", - severity: "critical", - detail: detail, - degree: 10000 + float64(len(results)), - } - results = append(results, result) - } - // Check from log. - sql.Reset() - fmt.Fprintf(sql, "select type,instance,time from information_schema.cluster_log %s and level = 'info' and message like '%%Welcome to'", condition) - rows, _, err = exec.ExecRestrictedSQL(ctx, nil, sql.String()) - if err != nil { - sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) - } - for _, row := range rows { - if row.Len() < 3 { - continue - } - detail := fmt.Sprintf("%s %s restarted at time '%s'", row.GetString(0), row.GetString(1), row.GetString(2)) - result := inspectionResult{ - tp: row.GetString(0), - instance: row.GetString(1), - item: item, - actual: "", - expected: "", - severity: "critical", - detail: detail, - degree: 10000 + float64(len(results)), - } - results = append(results, result) - } - return results -} - -func (c thresholdCheckInspection) inspect(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { - inspects := []func(context.Context, sessionctx.Context, inspectionFilter) []inspectionResult{ - c.inspectThreshold1, - c.inspectThreshold2, - c.inspectThreshold3, - c.inspectForLeaderDrop, - } - //nolint: prealloc - var results []inspectionResult - for _, inspect := range inspects { - re := inspect(ctx, sctx, filter) - results = append(results, re...) 
- } - return results -} - -func (thresholdCheckInspection) inspectThreshold1(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { - var rules = []struct { - item string - component string - configKey string - threshold float64 - }{ - { - item: "coprocessor-normal-cpu", - component: "cop_normal%", - configKey: "readpool.coprocessor.normal-concurrency", - threshold: 0.9}, - { - item: "coprocessor-high-cpu", - component: "cop_high%", - configKey: "readpool.coprocessor.high-concurrency", - threshold: 0.9, - }, - { - item: "coprocessor-low-cpu", - component: "cop_low%", - configKey: "readpool.coprocessor.low-concurrency", - threshold: 0.9, - }, - { - item: "grpc-cpu", - component: "grpc%", - configKey: "server.grpc-concurrency", - threshold: 0.9, - }, - { - item: "raftstore-cpu", - component: "raftstore_%", - configKey: "raftstore.store-pool-size", - threshold: 0.8, - }, - { - item: "apply-cpu", - component: "apply_%", - configKey: "raftstore.apply-pool-size", - threshold: 0.8, - }, - { - item: "storage-readpool-normal-cpu", - component: "store_read_norm%", - configKey: "readpool.storage.normal-concurrency", - threshold: 0.9, - }, - { - item: "storage-readpool-high-cpu", - component: "store_read_high%", - configKey: "readpool.storage.high-concurrency", - threshold: 0.9, - }, - { - item: "storage-readpool-low-cpu", - component: "store_read_low%", - configKey: "readpool.storage.low-concurrency", - threshold: 0.9, - }, - { - item: "scheduler-worker-cpu", - component: "sched_%", - configKey: "storage.scheduler-worker-pool-size", - threshold: 0.85, - }, - { - item: "split-check-cpu", - component: "split_check", - threshold: 0.9, - }, - } - - condition := filter.timeRange.Condition() - var results []inspectionResult - exec := sctx.GetRestrictedSQLExecutor() - sql := new(strings.Builder) - for _, rule := range rules { - if !filter.enable(rule.item) { - continue - } - - sql.Reset() - if len(rule.configKey) > 0 { - fmt.Fprintf(sql, `select t1.status_address, t1.cpu, (t2.value * %[2]f) as threshold, t2.value from - (select status_address, max(sum_value) as cpu from (select instance as status_address, sum(value) as sum_value from metrics_schema.tikv_thread_cpu %[4]s and name like '%[1]s' group by instance, time) as tmp group by tmp.status_address) as t1 join - (select instance, value from information_schema.cluster_config where type='tikv' and %[5]s = '%[3]s') as t2 join - (select instance,status_address from information_schema.cluster_info where type='tikv') as t3 - on t1.status_address=t3.status_address and t2.instance=t3.instance where t1.cpu > (t2.value * %[2]f)`, rule.component, rule.threshold, rule.configKey, condition, "`key`") - } else { - fmt.Fprintf(sql, `select t1.instance, t1.cpu, %[2]f from - (select instance, max(value) as cpu from metrics_schema.tikv_thread_cpu %[3]s and name like '%[1]s' group by instance) as t1 - where t1.cpu > %[2]f;`, rule.component, rule.threshold, condition) - } - rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql.String()) - if err != nil { - sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) - continue - } - for _, row := range rows { - actual := fmt.Sprintf("%.2f", row.GetFloat64(1)) - degree := math.Abs(row.GetFloat64(1)-row.GetFloat64(2)) / math.Max(row.GetFloat64(1), row.GetFloat64(2)) - expected := "" - if len(rule.configKey) > 0 { - expected = fmt.Sprintf("< %.2f, config: %v=%v", row.GetFloat64(2), rule.configKey, row.GetString(3)) - } else { - expected = fmt.Sprintf("< %.2f", 
row.GetFloat64(2)) - } - detail := fmt.Sprintf("the '%s' max cpu-usage of %s tikv is too high", rule.item, row.GetString(0)) - result := inspectionResult{ - tp: "tikv", - statusAddress: row.GetString(0), - item: rule.item, - actual: actual, - expected: expected, - severity: "warning", - detail: detail, - degree: degree, - } - results = append(results, result) - } - } - return results -} - -func (thresholdCheckInspection) inspectThreshold2(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { - var rules = []struct { - tp string - item string - tbl string - condition string - threshold float64 - factor float64 - isMin bool - detail string - }{ - { - tp: "tidb", - item: "tso-duration", - tbl: "pd_tso_wait_duration", - condition: "quantile=0.999", - threshold: 0.05, - }, - { - tp: "tidb", - item: "get-token-duration", - tbl: "tidb_get_token_duration", - condition: "quantile=0.999", - threshold: 0.001, - factor: 10e5, // the unit is microsecond - }, - { - tp: "tidb", - item: "load-schema-duration", - tbl: "tidb_load_schema_duration", - condition: "quantile=0.99", - threshold: 1, - }, - { - tp: "tikv", - item: "scheduler-cmd-duration", - tbl: "tikv_scheduler_command_duration", - condition: "quantile=0.99", - threshold: 0.1, - }, - { - tp: "tikv", - item: "handle-snapshot-duration", - tbl: "tikv_handle_snapshot_duration", - threshold: 30, - }, - { - tp: "tikv", - item: "storage-write-duration", - tbl: "tikv_storage_async_request_duration", - condition: "type='write'", - threshold: 0.1, - }, - { - tp: "tikv", - item: "storage-snapshot-duration", - tbl: "tikv_storage_async_request_duration", - condition: "type='snapshot'", - threshold: 0.05, - }, - { - tp: "tikv", - item: "rocksdb-write-duration", - tbl: "tikv_engine_write_duration", - condition: "type='write_max'", - threshold: 0.1, - factor: 10e5, // the unit is microsecond - }, - { - tp: "tikv", - item: "rocksdb-get-duration", - tbl: "tikv_engine_max_get_duration", - condition: "type='get_max'", - threshold: 0.05, - factor: 10e5, - }, - { - tp: "tikv", - item: "rocksdb-seek-duration", - tbl: "tikv_engine_max_seek_duration", - condition: "type='seek_max'", - threshold: 0.05, - factor: 10e5, // the unit is microsecond - }, - { - tp: "tikv", - item: "scheduler-pending-cmd-count", - tbl: "tikv_scheduler_pending_commands", - threshold: 1000, - detail: " %s tikv scheduler has too many pending commands", - }, - { - tp: "tikv", - item: "index-block-cache-hit", - tbl: "tikv_block_index_cache_hit", - condition: "value > 0", - threshold: 0.95, - isMin: true, - }, - { - tp: "tikv", - item: "filter-block-cache-hit", - tbl: "tikv_block_filter_cache_hit", - condition: "value > 0", - threshold: 0.95, - isMin: true, - }, - { - tp: "tikv", - item: "data-block-cache-hit", - tbl: "tikv_block_data_cache_hit", - condition: "value > 0", - threshold: 0.80, - isMin: true, - }, - } - - condition := filter.timeRange.Condition() - var results []inspectionResult - sql := new(strings.Builder) - exec := sctx.GetRestrictedSQLExecutor() - for _, rule := range rules { - if !filter.enable(rule.item) { - continue - } - cond := condition - if len(rule.condition) > 0 { - cond = fmt.Sprintf("%s and %s", cond, rule.condition) - } - if rule.factor == 0 { - rule.factor = 1 - } - sql.Reset() - if rule.isMin { - fmt.Fprintf(sql, "select instance, min(value)/%.0f as min_value from metrics_schema.%s %s group by instance having min_value < %f;", rule.factor, rule.tbl, cond, rule.threshold) - } else { - fmt.Fprintf(sql, "select instance, max(value)/%.0f 
as max_value from metrics_schema.%s %s group by instance having max_value > %f;", rule.factor, rule.tbl, cond, rule.threshold) - } - rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql.String()) - if err != nil { - sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) - continue - } - for _, row := range rows { - actual := fmt.Sprintf("%.3f", row.GetFloat64(1)) - degree := math.Abs(row.GetFloat64(1)-rule.threshold) / math.Max(row.GetFloat64(1), rule.threshold) - expected := "" - if rule.isMin { - expected = fmt.Sprintf("> %.3f", rule.threshold) - } else { - expected = fmt.Sprintf("< %.3f", rule.threshold) - } - detail := rule.detail - if len(detail) == 0 { - if strings.HasSuffix(rule.item, "duration") { - detail = fmt.Sprintf("max duration of %s %s %s is too slow", row.GetString(0), rule.tp, rule.item) - } else if strings.HasSuffix(rule.item, "hit") { - detail = fmt.Sprintf("min %s rate of %s %s is too low", rule.item, row.GetString(0), rule.tp) - } - } else { - detail = fmt.Sprintf(detail, row.GetString(0)) - } - result := inspectionResult{ - tp: rule.tp, - statusAddress: row.GetString(0), - item: rule.item, - actual: actual, - expected: expected, - severity: "warning", - detail: detail, - degree: degree, - } - results = append(results, result) - } - } - return results -} - -type ruleChecker interface { - genSQL(timeRange plannerutil.QueryTimeRange) string - genResult(sql string, row chunk.Row) inspectionResult - getItem() string -} - -type compareStoreStatus struct { - item string - tp string - threshold float64 -} - -func (c compareStoreStatus) genSQL(timeRange plannerutil.QueryTimeRange) string { - condition := fmt.Sprintf(`where t1.time>='%[1]s' and t1.time<='%[2]s' and - t2.time>='%[1]s' and t2.time<='%[2]s'`, timeRange.From.Format(plannerutil.MetricTableTimeFormat), - timeRange.To.Format(plannerutil.MetricTableTimeFormat)) - return fmt.Sprintf(` - SELECT t1.address, - max(t1.value), - t2.address, - min(t2.value), - max((t1.value-t2.value)/t1.value) AS ratio - FROM metrics_schema.pd_scheduler_store_status t1 - JOIN metrics_schema.pd_scheduler_store_status t2 %s - AND t1.type='%s' - AND t1.time = t2.time - AND t1.type=t2.type - AND t1.address != t2.address - AND (t1.value-t2.value)/t1.value>%v - AND t1.value > 0 - GROUP BY t1.address,t2.address - ORDER BY ratio desc`, condition, c.tp, c.threshold) -} - -func (c compareStoreStatus) genResult(_ string, row chunk.Row) inspectionResult { - addr1 := row.GetString(0) - value1 := row.GetFloat64(1) - addr2 := row.GetString(2) - value2 := row.GetFloat64(3) - ratio := row.GetFloat64(4) - detail := fmt.Sprintf("%v max %s is %.2f, much more than %v min %s %.2f", addr1, c.tp, value1, addr2, c.tp, value2) - return inspectionResult{ - tp: "tikv", - instance: addr2, - item: c.item, - actual: fmt.Sprintf("%.2f%%", ratio*100), - expected: fmt.Sprintf("< %.2f%%", c.threshold*100), - severity: "warning", - detail: detail, - degree: ratio, - } -} - -func (c compareStoreStatus) getItem() string { - return c.item -} - -type checkRegionHealth struct{} - -func (checkRegionHealth) genSQL(timeRange plannerutil.QueryTimeRange) string { - condition := timeRange.Condition() - return fmt.Sprintf(`select instance, sum(value) as sum_value from metrics_schema.pd_region_health %s and - type in ('extra-peer-region-count','learner-peer-region-count','pending-peer-region-count') having sum_value>100`, condition) -} - -func (c checkRegionHealth) genResult(_ string, row chunk.Row) inspectionResult { - detail := fmt.Sprintf("the count 
of extra-perr and learner-peer and pending-peer are %v, it means the scheduling is too frequent or too slow", row.GetFloat64(1)) - actual := fmt.Sprintf("%.2f", row.GetFloat64(1)) - degree := math.Abs(row.GetFloat64(1)-100) / math.Max(row.GetFloat64(1), 100) - return inspectionResult{ - tp: "pd", - instance: row.GetString(0), - item: c.getItem(), - actual: actual, - expected: "< 100", - severity: "warning", - detail: detail, - degree: degree, - } -} - -func (checkRegionHealth) getItem() string { - return "region-health" -} - -type checkStoreRegionTooMuch struct{} - -func (checkStoreRegionTooMuch) genSQL(timeRange plannerutil.QueryTimeRange) string { - condition := timeRange.Condition() - return fmt.Sprintf(`select address, max(value) from metrics_schema.pd_scheduler_store_status %s and type='region_count' and value > 20000 group by address`, condition) -} - -func (c checkStoreRegionTooMuch) genResult(_ string, row chunk.Row) inspectionResult { - actual := fmt.Sprintf("%.2f", row.GetFloat64(1)) - degree := math.Abs(row.GetFloat64(1)-20000) / math.Max(row.GetFloat64(1), 20000) - return inspectionResult{ - tp: "tikv", - instance: row.GetString(0), - item: c.getItem(), - actual: actual, - expected: "<= 20000", - severity: "warning", - detail: fmt.Sprintf("%s tikv has too many regions", row.GetString(0)), - degree: degree, - } -} - -func (checkStoreRegionTooMuch) getItem() string { - return "region-count" -} - -func (thresholdCheckInspection) inspectThreshold3(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { - var rules = []ruleChecker{ - compareStoreStatus{ - item: "leader-score-balance", - tp: "leader_score", - threshold: 0.05, - }, - compareStoreStatus{ - item: "region-score-balance", - tp: "region_score", - threshold: 0.05, - }, - compareStoreStatus{ - item: "store-available-balance", - tp: "store_available", - threshold: 0.2, - }, - checkRegionHealth{}, - checkStoreRegionTooMuch{}, - } - return checkRules(ctx, sctx, filter, rules) -} - -func checkRules(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter, rules []ruleChecker) []inspectionResult { - var results []inspectionResult - exec := sctx.GetRestrictedSQLExecutor() - for _, rule := range rules { - if !filter.enable(rule.getItem()) { - continue - } - sql := rule.genSQL(filter.timeRange) - rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql) - if err != nil { - sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) - continue - } - for _, row := range rows { - results = append(results, rule.genResult(sql, row)) - } - } - return results -} - -func (thresholdCheckInspection) inspectForLeaderDrop(ctx context.Context, sctx sessionctx.Context, filter inspectionFilter) []inspectionResult { - condition := filter.timeRange.Condition() - threshold := 50.0 - sql := new(strings.Builder) - fmt.Fprintf(sql, `select address,min(value) as mi,max(value) as mx from metrics_schema.pd_scheduler_store_status %s and type='leader_count' group by address having mx-mi>%v`, condition, threshold) - exec := sctx.GetRestrictedSQLExecutor() - - rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql.String()) - if err != nil { - sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) - return nil - } - var results []inspectionResult - for _, row := range rows { - address := row.GetString(0) - sql.Reset() - fmt.Fprintf(sql, `select time, value from metrics_schema.pd_scheduler_store_status %s and type='leader_count' and address = '%s' 
order by time`, condition, address) - var subRows []chunk.Row - subRows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql.String()) - if err != nil { - sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("execute '%s' failed: %v", sql, err)) - continue - } - - lastValue := float64(0) - for i, subRows := range subRows { - v := subRows.GetFloat64(1) - if i == 0 { - lastValue = v - continue - } - if lastValue-v > threshold { - level := "warning" - if v == 0 { - level = "critical" - } - results = append(results, inspectionResult{ - tp: "tikv", - instance: address, - item: "leader-drop", - actual: fmt.Sprintf("%.0f", lastValue-v), - expected: fmt.Sprintf("<= %.0f", threshold), - severity: level, - detail: fmt.Sprintf("%s tikv has too many leader-drop around time %s, leader count from %.0f drop to %.0f", address, subRows.GetTime(0), lastValue, v), - degree: lastValue - v, - }) - break - } - lastValue = v - } - } - return results -} diff --git a/pkg/executor/internal/calibrateresource/binding__failpoint_binding__.go b/pkg/executor/internal/calibrateresource/binding__failpoint_binding__.go deleted file mode 100644 index ee9bcd36346bf..0000000000000 --- a/pkg/executor/internal/calibrateresource/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package calibrateresource - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/executor/internal/calibrateresource/calibrate_resource.go b/pkg/executor/internal/calibrateresource/calibrate_resource.go index 75378d7acf087..74cb61fec5846 100644 --- a/pkg/executor/internal/calibrateresource/calibrate_resource.go +++ b/pkg/executor/internal/calibrateresource/calibrate_resource.go @@ -317,7 +317,7 @@ func (e *Executor) getTiDBQuota( return 0, err } - if _, _err_ := failpoint.Eval(_curpkg_("mockMetricsDataFilter")); _err_ == nil { + failpoint.Inject("mockMetricsDataFilter", func() { ret := make([]*timePointValue, 0) for _, point := range tikvCPUs.vals { if point.tp.After(endTs) || point.tp.Before(startTs) { @@ -342,7 +342,7 @@ func (e *Executor) getTiDBQuota( ret = append(ret, point) } rus.vals = ret - } + }) quotas := make([]float64, 0) lowCount := 0 for { @@ -505,11 +505,11 @@ func staticCalibrateTpch10(req *chunk.Chunk, clusterInfo []infoschema.ServerInfo func getTiDBTotalCPUQuota(clusterInfo []infoschema.ServerInfo) (float64, error) { cpuQuota := float64(runtime.GOMAXPROCS(0)) - if val, _err_ := failpoint.Eval(_curpkg_("mockGOMAXPROCS")); _err_ == nil { + failpoint.Inject("mockGOMAXPROCS", func(val failpoint.Value) { if val != nil { cpuQuota = float64(val.(int)) } - } + }) instanceNum := count(clusterInfo, serverTypeTiDB) return cpuQuota * float64(instanceNum), nil } @@ -662,7 +662,7 @@ func fetchStoreMetrics(serversInfo []infoschema.ServerInfo, serverType string, o return err } var resp *http.Response - if val, _err_ := failpoint.Eval(_curpkg_("mockMetricsResponse")); _err_ == nil { + failpoint.Inject("mockMetricsResponse", func(val failpoint.Value) { if val != nil { data, _ := base64.StdEncoding.DecodeString(val.(string)) resp = &http.Response{ @@ -672,7 +672,7 @@ func fetchStoreMetrics(serversInfo []infoschema.ServerInfo, serverType string, o }, } } - } + }) if resp == nil { var err1 error // ignore false positive go line, can't use defer here 
because it's in a loop. diff --git a/pkg/executor/internal/calibrateresource/calibrate_resource.go__failpoint_stash__ b/pkg/executor/internal/calibrateresource/calibrate_resource.go__failpoint_stash__ deleted file mode 100644 index 74cb61fec5846..0000000000000 --- a/pkg/executor/internal/calibrateresource/calibrate_resource.go__failpoint_stash__ +++ /dev/null @@ -1,704 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package calibrateresource - -import ( - "bufio" - "context" - "encoding/base64" - "fmt" - "io" - "math" - "net/http" - "runtime" - "sort" - "strconv" - "strings" - "time" - - "github.com/docker/go-units" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/duration" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/sessiontxn/staleread" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/sqlexec" - "github.com/tikv/client-go/v2/oracle" - resourceControlClient "github.com/tikv/pd/client/resource_group/controller" -) - -var ( - // workloadBaseRUCostMap contains the base resource cost rate per 1 kv cpu within 1 second, - // the data is calculated from benchmark result, these data might not be very accurate, - // but is enough here because the maximum RU capacity is depended on both the cluster and - // the workload. - workloadBaseRUCostMap = map[ast.CalibrateResourceType]*baseResourceCost{ - ast.TPCC: { - tidbToKVCPURatio: 0.6, - kvCPU: 0.15, - readBytes: units.MiB / 2, - writeBytes: units.MiB, - readReqCount: 300, - writeReqCount: 1750, - }, - ast.OLTPREADWRITE: { - tidbToKVCPURatio: 1.25, - kvCPU: 0.35, - readBytes: units.MiB * 4.25, - writeBytes: units.MiB / 3, - readReqCount: 1600, - writeReqCount: 1400, - }, - ast.OLTPREADONLY: { - tidbToKVCPURatio: 2, - kvCPU: 0.52, - readBytes: units.MiB * 28, - writeBytes: 0, - readReqCount: 4500, - writeReqCount: 0, - }, - ast.OLTPWRITEONLY: { - tidbToKVCPURatio: 1, - kvCPU: 0, - readBytes: 0, - writeBytes: units.MiB, - readReqCount: 0, - writeReqCount: 3550, - }, - } -) - -const ( - // serverTypeTiDB is tidb's instance type name - serverTypeTiDB = "tidb" - // serverTypeTiKV is tikv's instance type name - serverTypeTiKV = "tikv" - // serverTypeTiFlash is tiflash's instance type name - serverTypeTiFlash = "tiflash" -) - -// the resource cost rate of a specified workload per 1 tikv cpu. -type baseResourceCost struct { - // represents the average ratio of TiDB CPU time to TiKV CPU time, this is used to calculate whether tikv cpu - // or tidb cpu is the performance bottle neck. 
- tidbToKVCPURatio float64 - // the kv CPU time for calculate RU, it's smaller than the actual cpu usage. The unit is seconds. - kvCPU float64 - // the read bytes rate per 1 tikv cpu. - readBytes uint64 - // the write bytes rate per 1 tikv cpu. - writeBytes uint64 - // the average tikv read request count per 1 tikv cpu. - readReqCount uint64 - // the average tikv write request count per 1 tikv cpu. - writeReqCount uint64 -} - -const ( - // valuableUsageThreshold is the threshold used to determine whether the CPU is high enough. - // The sampling point is available when the CPU utilization of tikv or tidb is higher than the valuableUsageThreshold. - valuableUsageThreshold = 0.2 - // lowUsageThreshold is the threshold used to determine whether the CPU is too low. - // When the CPU utilization of tikv or tidb is lower than lowUsageThreshold, but neither is higher than valuableUsageThreshold, the sampling point is unavailable - lowUsageThreshold = 0.1 - // For quotas computed at each point in time, the maximum and minimum portions are discarded, and discardRate is the percentage discarded - discardRate = 0.1 - - // duration Indicates the supported calibration duration - maxDuration = time.Hour * 24 - minDuration = time.Minute -) - -// Executor is used as executor of calibrate resource. -type Executor struct { - OptionList []*ast.DynamicCalibrateResourceOption - exec.BaseExecutor - WorkloadType ast.CalibrateResourceType - done bool -} - -func (e *Executor) parseTsExpr(ctx context.Context, tsExpr ast.ExprNode) (time.Time, error) { - ts, err := staleread.CalculateAsOfTsExpr(ctx, e.Ctx().GetPlanCtx(), tsExpr) - if err != nil { - return time.Time{}, err - } - return oracle.GetTimeFromTS(ts), nil -} - -func (e *Executor) parseCalibrateDuration(ctx context.Context) (startTime time.Time, endTime time.Time, err error) { - var dur time.Duration - // startTimeExpr and endTimeExpr are used to calc endTime by FuncCallExpr when duration begin with `interval`. - var startTimeExpr ast.ExprNode - var endTimeExpr ast.ExprNode - for _, op := range e.OptionList { - switch op.Tp { - case ast.CalibrateStartTime: - startTimeExpr = op.Ts - startTime, err = e.parseTsExpr(ctx, startTimeExpr) - if err != nil { - return - } - case ast.CalibrateEndTime: - endTimeExpr = op.Ts - endTime, err = e.parseTsExpr(ctx, op.Ts) - if err != nil { - return - } - } - } - for _, op := range e.OptionList { - if op.Tp != ast.CalibrateDuration { - continue - } - // string duration - if len(op.StrValue) > 0 { - dur, err = duration.ParseDuration(op.StrValue) - if err != nil { - return - } - // If startTime is not set, startTime will be now() - duration. - if startTime.IsZero() { - toTime := endTime - if toTime.IsZero() { - toTime = time.Now() - } - startTime = toTime.Add(-dur) - } - // If endTime is set, duration will be ignored. - if endTime.IsZero() { - endTime = startTime.Add(dur) - } - continue - } - // interval duration - // If startTime is not set, startTime will be now() - duration. - if startTimeExpr == nil { - toTimeExpr := endTimeExpr - if endTime.IsZero() { - toTimeExpr = &ast.FuncCallExpr{FnName: model.NewCIStr("CURRENT_TIMESTAMP")} - } - startTimeExpr = &ast.FuncCallExpr{ - FnName: model.NewCIStr("DATE_SUB"), - Args: []ast.ExprNode{ - toTimeExpr, - op.Ts, - &ast.TimeUnitExpr{Unit: op.Unit}}, - } - startTime, err = e.parseTsExpr(ctx, startTimeExpr) - if err != nil { - return - } - } - // If endTime is set, duration will be ignored. 
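A compact sketch of the window-resolution rules implemented by parseCalibrateDuration here: a missing start time is derived from the end time (or now) minus the duration, a missing end time from the start plus the duration, and the resulting window must fall between minDuration and maxDuration, with one minute of slack on the upper bound. Plain time.Time and time.Duration values stand in for the AST expressions used by the real code:

package main

import (
	"errors"
	"fmt"
	"time"
)

const (
	minWindow = time.Minute
	maxWindow = 24 * time.Hour
)

// resolveWindow fills in whichever of start/end is missing from the other
// endpoint and the duration, then validates the window length.
func resolveWindow(start, end time.Time, dur time.Duration) (time.Time, time.Time, error) {
	if start.IsZero() && dur > 0 {
		to := end
		if to.IsZero() {
			to = time.Now()
		}
		start = to.Add(-dur)
	}
	if end.IsZero() {
		if dur > 0 {
			end = start.Add(dur)
		} else {
			end = time.Now()
		}
	}
	if start.IsZero() {
		return time.Time{}, time.Time{}, errors.New("start time should not be 0")
	}
	got := end.Sub(start)
	// One extra minute of tolerance on the upper bound, matching the
	// "buffer duration" adjustment in the original code.
	if got > maxWindow+time.Minute || got < minWindow {
		return time.Time{}, time.Time{}, fmt.Errorf("window of %s must be between %s and %s", got, minWindow, maxWindow)
	}
	return start, end, nil
}

func main() {
	start, end, err := resolveWindow(time.Time{}, time.Time{}, 30*time.Minute)
	fmt.Println(start.Format(time.DateTime), end.Format(time.DateTime), err)
}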
- if endTime.IsZero() { - endTime, err = e.parseTsExpr(ctx, &ast.FuncCallExpr{ - FnName: model.NewCIStr("DATE_ADD"), - Args: []ast.ExprNode{startTimeExpr, - op.Ts, - &ast.TimeUnitExpr{Unit: op.Unit}}, - }) - if err != nil { - return - } - } - } - - if startTime.IsZero() { - err = errors.Errorf("start time should not be 0") - return - } - if endTime.IsZero() { - endTime = time.Now() - } - // check the duration - dur = endTime.Sub(startTime) - // add the buffer duration - if dur > maxDuration+time.Minute { - err = errors.Errorf("the duration of calibration is too long, which could lead to inaccurate output. Please make the duration between %s and %s", minDuration.String(), maxDuration.String()) - return - } - // We only need to consider the case where the duration is slightly enlarged. - if dur < minDuration { - err = errors.Errorf("the duration of calibration is too short, which could lead to inaccurate output. Please make the duration between %s and %s", minDuration.String(), maxDuration.String()) - } - return -} - -// Next implements the interface of Executor. -func (e *Executor) Next(ctx context.Context, req *chunk.Chunk) error { - req.Reset() - if e.done { - return nil - } - e.done = true - if !variable.EnableResourceControl.Load() { - return infoschema.ErrResourceGroupSupportDisabled - } - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnOthers) - if len(e.OptionList) > 0 { - return e.dynamicCalibrate(ctx, req) - } - return e.staticCalibrate(req) -} - -var ( - errLowUsage = errors.Errorf("The workload in selected time window is too low, with which TiDB is unable to reach a capacity estimation; please select another time window with higher workload, or calibrate resource by hardware instead") - errNoCPUQuotaMetrics = errors.Normalize("There is no CPU quota metrics, %v") -) - -func (e *Executor) dynamicCalibrate(ctx context.Context, req *chunk.Chunk) error { - exec := e.Ctx().GetRestrictedSQLExecutor() - startTs, endTs, err := e.parseCalibrateDuration(ctx) - if err != nil { - return err - } - clusterInfo, err := infoschema.GetClusterServerInfo(e.Ctx()) - if err != nil { - return err - } - tidbQuota, err1 := e.getTiDBQuota(ctx, exec, clusterInfo, startTs, endTs) - tiflashQuota, err2 := e.getTiFlashQuota(ctx, exec, clusterInfo, startTs, endTs) - if err1 != nil && err2 != nil { - return err1 - } - - req.AppendUint64(0, uint64(tidbQuota+tiflashQuota)) - return nil -} - -func (e *Executor) getTiDBQuota( - ctx context.Context, - exec sqlexec.RestrictedSQLExecutor, - serverInfos []infoschema.ServerInfo, - startTs, endTs time.Time, -) (float64, error) { - startTime := startTs.In(e.Ctx().GetSessionVars().Location()).Format(time.DateTime) - endTime := endTs.In(e.Ctx().GetSessionVars().Location()).Format(time.DateTime) - - totalKVCPUQuota, err := getTiKVTotalCPUQuota(serverInfos) - if err != nil { - return 0, errNoCPUQuotaMetrics.FastGenByArgs(err.Error()) - } - totalTiDBCPU, err := getTiDBTotalCPUQuota(serverInfos) - if err != nil { - return 0, errNoCPUQuotaMetrics.FastGenByArgs(err.Error()) - } - rus, err := getRUPerSec(ctx, e.Ctx(), exec, startTime, endTime) - if err != nil { - return 0, err - } - tikvCPUs, err := getComponentCPUUsagePerSec(ctx, e.Ctx(), exec, "tikv", startTime, endTime) - if err != nil { - return 0, err - } - tidbCPUs, err := getComponentCPUUsagePerSec(ctx, e.Ctx(), exec, "tidb", startTime, endTime) - if err != nil { - return 0, err - } - - failpoint.Inject("mockMetricsDataFilter", func() { - ret := make([]*timePointValue, 0) - for _, point := range tikvCPUs.vals { - if 
point.tp.After(endTs) || point.tp.Before(startTs) { - continue - } - ret = append(ret, point) - } - tikvCPUs.vals = ret - ret = make([]*timePointValue, 0) - for _, point := range tidbCPUs.vals { - if point.tp.After(endTs) || point.tp.Before(startTs) { - continue - } - ret = append(ret, point) - } - tidbCPUs.vals = ret - ret = make([]*timePointValue, 0) - for _, point := range rus.vals { - if point.tp.After(endTs) || point.tp.Before(startTs) { - continue - } - ret = append(ret, point) - } - rus.vals = ret - }) - quotas := make([]float64, 0) - lowCount := 0 - for { - if rus.isEnd() || tikvCPUs.isEnd() || tidbCPUs.isEnd() { - break - } - // make time point match - maxTime := rus.getTime() - if tikvCPUs.getTime().After(maxTime) { - maxTime = tikvCPUs.getTime() - } - if tidbCPUs.getTime().After(maxTime) { - maxTime = tidbCPUs.getTime() - } - if !rus.advance(maxTime) || !tikvCPUs.advance(maxTime) || !tidbCPUs.advance(maxTime) { - continue - } - tikvQuota, tidbQuota := tikvCPUs.getValue()/totalKVCPUQuota, tidbCPUs.getValue()/totalTiDBCPU - // If one of the two cpu usage is greater than the `valuableUsageThreshold`, we can accept it. - // And if both are greater than the `lowUsageThreshold`, we can also accept it. - if tikvQuota > valuableUsageThreshold || tidbQuota > valuableUsageThreshold { - quotas = append(quotas, rus.getValue()/max(tikvQuota, tidbQuota)) - } else if tikvQuota < lowUsageThreshold || tidbQuota < lowUsageThreshold { - lowCount++ - } else { - quotas = append(quotas, rus.getValue()/max(tikvQuota, tidbQuota)) - } - rus.next() - tidbCPUs.next() - tikvCPUs.next() - } - quota, err := setupQuotas(quotas) - if err != nil { - return 0, err - } - return quota, nil -} - -func setupQuotas(quotas []float64) (float64, error) { - if len(quotas) < 2 { - return 0, errLowUsage - } - sort.Slice(quotas, func(i, j int) bool { - return quotas[i] > quotas[j] - }) - lowerBound := int(math.Round(float64(len(quotas)) * discardRate)) - upperBound := len(quotas) - lowerBound - sum := 0. 
- for i := lowerBound; i < upperBound; i++ { - sum += quotas[i] - } - return sum / float64(upperBound-lowerBound), nil -} - -func (e *Executor) getTiFlashQuota( - ctx context.Context, - exec sqlexec.RestrictedSQLExecutor, - serverInfos []infoschema.ServerInfo, - startTs, endTs time.Time, -) (float64, error) { - startTime := startTs.In(e.Ctx().GetSessionVars().Location()).Format(time.DateTime) - endTime := endTs.In(e.Ctx().GetSessionVars().Location()).Format(time.DateTime) - - quotas := make([]float64, 0) - totalTiFlashLogicalCores, err := getTiFlashLogicalCores(serverInfos) - if err != nil { - return 0, errNoCPUQuotaMetrics.FastGenByArgs(err.Error()) - } - tiflashCPUs, err := getTiFlashCPUUsagePerSec(ctx, e.Ctx(), exec, startTime, endTime) - if err != nil { - return 0, err - } - tiflashRUs, err := getTiFlashRUPerSec(ctx, e.Ctx(), exec, startTime, endTime) - if err != nil { - return 0, err - } - for { - if tiflashRUs.isEnd() || tiflashCPUs.isEnd() { - break - } - // make time point match - maxTime := tiflashRUs.getTime() - if tiflashCPUs.getTime().After(maxTime) { - maxTime = tiflashCPUs.getTime() - } - if !tiflashRUs.advance(maxTime) || !tiflashCPUs.advance(maxTime) { - continue - } - tiflashQuota := tiflashCPUs.getValue() / totalTiFlashLogicalCores - if tiflashQuota > lowUsageThreshold { - quotas = append(quotas, tiflashRUs.getValue()/tiflashQuota) - } - tiflashRUs.next() - tiflashCPUs.next() - } - return setupQuotas(quotas) -} - -func (e *Executor) staticCalibrate(req *chunk.Chunk) error { - resourceGroupCtl := domain.GetDomain(e.Ctx()).ResourceGroupsController() - // first fetch the ru settings config. - if resourceGroupCtl == nil { - return errors.New("resource group controller is not initialized") - } - clusterInfo, err := infoschema.GetClusterServerInfo(e.Ctx()) - if err != nil { - return err - } - ruCfg := resourceGroupCtl.GetConfig() - if e.WorkloadType == ast.TPCH10 { - return staticCalibrateTpch10(req, clusterInfo, ruCfg) - } - - totalKVCPUQuota, err := getTiKVTotalCPUQuota(clusterInfo) - if err != nil { - return errNoCPUQuotaMetrics.FastGenByArgs(err.Error()) - } - totalTiDBCPUQuota, err := getTiDBTotalCPUQuota(clusterInfo) - if err != nil { - return errNoCPUQuotaMetrics.FastGenByArgs(err.Error()) - } - - // The default workload to calculate the RU capacity. - if e.WorkloadType == ast.WorkloadNone { - e.WorkloadType = ast.TPCC - } - baseCost, ok := workloadBaseRUCostMap[e.WorkloadType] - if !ok { - return errors.Errorf("unknown workload '%T'", e.WorkloadType) - } - - if totalTiDBCPUQuota/baseCost.tidbToKVCPURatio < totalKVCPUQuota { - totalKVCPUQuota = totalTiDBCPUQuota / baseCost.tidbToKVCPURatio - } - ruPerKVCPU := float64(ruCfg.ReadBaseCost)*float64(baseCost.readReqCount) + - float64(ruCfg.CPUMsCost)*baseCost.kvCPU*1000 + // convert to ms - float64(ruCfg.ReadBytesCost)*float64(baseCost.readBytes) + - float64(ruCfg.WriteBaseCost)*float64(baseCost.writeReqCount) + - float64(ruCfg.WriteBytesCost)*float64(baseCost.writeBytes) - quota := totalKVCPUQuota * ruPerKVCPU - req.AppendUint64(0, uint64(quota)) - return nil -} - -func staticCalibrateTpch10(req *chunk.Chunk, clusterInfo []infoschema.ServerInfo, ruCfg *resourceControlClient.RUConfig) error { - // TPCH10 only considers the resource usage of the TiFlash including cpu and read bytes. Others are ignored. 
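The setupQuotas helper above is a trimmed mean: the per-sample RU-capacity estimates are sorted, the top and bottom discardRate fraction are dropped, and the remainder is averaged, which keeps a single anomalous sampling point from skewing the estimate. A standalone sketch of that computation:

package main

import (
	"errors"
	"fmt"
	"math"
	"sort"
)

const discardRate = 0.1

// trimmedMean mirrors the shape of setupQuotas: discard the extremes and
// average the middle. At least two samples are required, echoing the
// low-usage guard in the original.
func trimmedMean(samples []float64) (float64, error) {
	if len(samples) < 2 {
		return 0, errors.New("not enough samples in the selected time window")
	}
	sorted := append([]float64(nil), samples...)
	sort.Sort(sort.Reverse(sort.Float64Slice(sorted)))
	lower := int(math.Round(float64(len(sorted)) * discardRate))
	upper := len(sorted) - lower
	sum := 0.0
	for _, v := range sorted[lower:upper] {
		sum += v
	}
	return sum / float64(upper-lower), nil
}

func main() {
	quotas := []float64{9000, 12000, 11800, 12100, 50000, 11950, 300, 12050, 11900, 12000}
	avg, err := trimmedMean(quotas)
	fmt.Println(avg, err) // the 50000 and 300 outliers land in the discarded tails
}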
- // cpu usage: 105494.666484 / 20 / 20 = 263.74 - // read bytes: 401799161689.0 / 20 / 20 = 1004497904.22 - const cpuTimePerCPUPerSec float64 = 263.74 - const readBytesPerCPUPerSec float64 = 1004497904.22 - ruPerCPU := float64(ruCfg.CPUMsCost)*cpuTimePerCPUPerSec + float64(ruCfg.ReadBytesCost)*readBytesPerCPUPerSec - totalTiFlashLogicalCores, err := getTiFlashLogicalCores(clusterInfo) - if err != nil { - return err - } - quota := totalTiFlashLogicalCores * ruPerCPU - req.AppendUint64(0, uint64(quota)) - return nil -} - -func getTiDBTotalCPUQuota(clusterInfo []infoschema.ServerInfo) (float64, error) { - cpuQuota := float64(runtime.GOMAXPROCS(0)) - failpoint.Inject("mockGOMAXPROCS", func(val failpoint.Value) { - if val != nil { - cpuQuota = float64(val.(int)) - } - }) - instanceNum := count(clusterInfo, serverTypeTiDB) - return cpuQuota * float64(instanceNum), nil -} - -func getTiKVTotalCPUQuota(clusterInfo []infoschema.ServerInfo) (float64, error) { - instanceNum := count(clusterInfo, serverTypeTiKV) - if instanceNum == 0 { - return 0.0, errors.New("no server with type 'tikv' is found") - } - cpuQuota, err := fetchServerCPUQuota(clusterInfo, serverTypeTiKV, "tikv_server_cpu_cores_quota") - if err != nil { - return 0.0, err - } - return cpuQuota * float64(instanceNum), nil -} - -func getTiFlashLogicalCores(clusterInfo []infoschema.ServerInfo) (float64, error) { - instanceNum := count(clusterInfo, serverTypeTiFlash) - if instanceNum == 0 { - return 0.0, nil - } - cpuQuota, err := fetchServerCPUQuota(clusterInfo, serverTypeTiFlash, "tiflash_proxy_tikv_server_cpu_cores_quota") - if err != nil { - return 0.0, err - } - return cpuQuota * float64(instanceNum), nil -} - -func getTiFlashRUPerSec(ctx context.Context, sctx sessionctx.Context, exec sqlexec.RestrictedSQLExecutor, startTime, endTime string) (*timeSeriesValues, error) { - query := fmt.Sprintf("SELECT time, value FROM METRICS_SCHEMA.tiflash_resource_manager_resource_unit where time >= '%s' and time <= '%s' ORDER BY time asc", startTime, endTime) - return getValuesFromMetrics(ctx, sctx, exec, query) -} - -func getTiFlashCPUUsagePerSec(ctx context.Context, sctx sessionctx.Context, exec sqlexec.RestrictedSQLExecutor, startTime, endTime string) (*timeSeriesValues, error) { - query := fmt.Sprintf("SELECT time, sum(value) FROM METRICS_SCHEMA.tiflash_process_cpu_usage where time >= '%s' and time <= '%s' and job = 'tiflash' GROUP BY time ORDER BY time asc", startTime, endTime) - return getValuesFromMetrics(ctx, sctx, exec, query) -} - -type timePointValue struct { - tp time.Time - val float64 -} - -type timeSeriesValues struct { - vals []*timePointValue - idx int -} - -func (t *timeSeriesValues) isEnd() bool { - return t.idx >= len(t.vals) -} - -func (t *timeSeriesValues) next() { - t.idx++ -} - -func (t *timeSeriesValues) getTime() time.Time { - return t.vals[t.idx].tp -} - -func (t *timeSeriesValues) getValue() float64 { - return t.vals[t.idx].val -} - -func (t *timeSeriesValues) advance(target time.Time) bool { - for ; t.idx < len(t.vals); t.idx++ { - // `target` is maximal time in other timeSeriesValues, - // so we should find the time which offset is less than 10s. 
- if t.vals[t.idx].tp.Add(time.Second * 10).After(target) { - return t.vals[t.idx].tp.Add(-time.Second * 10).Before(target) - } - } - return false -} - -func getRUPerSec(ctx context.Context, sctx sessionctx.Context, exec sqlexec.RestrictedSQLExecutor, startTime, endTime string) (*timeSeriesValues, error) { - query := fmt.Sprintf("SELECT time, value FROM METRICS_SCHEMA.resource_manager_resource_unit where time >= '%s' and time <= '%s' ORDER BY time asc", startTime, endTime) - return getValuesFromMetrics(ctx, sctx, exec, query) -} - -func getComponentCPUUsagePerSec(ctx context.Context, sctx sessionctx.Context, exec sqlexec.RestrictedSQLExecutor, component, startTime, endTime string) (*timeSeriesValues, error) { - query := fmt.Sprintf("SELECT time, sum(value) FROM METRICS_SCHEMA.process_cpu_usage where time >= '%s' and time <= '%s' and job like '%%%s' GROUP BY time ORDER BY time asc", startTime, endTime, component) - return getValuesFromMetrics(ctx, sctx, exec, query) -} - -func getValuesFromMetrics(ctx context.Context, sctx sessionctx.Context, exec sqlexec.RestrictedSQLExecutor, query string) (*timeSeriesValues, error) { - rows, _, err := exec.ExecRestrictedSQL(ctx, []sqlexec.OptionFuncAlias{sqlexec.ExecOptionUseCurSession}, query) - if err != nil { - return nil, errors.Trace(err) - } - ret := make([]*timePointValue, 0, len(rows)) - for _, row := range rows { - if tp, err := row.GetTime(0).AdjustedGoTime(sctx.GetSessionVars().Location()); err == nil { - ret = append(ret, &timePointValue{ - tp: tp, - val: row.GetFloat64(1), - }) - } - } - return &timeSeriesValues{idx: 0, vals: ret}, nil -} - -func count(clusterInfo []infoschema.ServerInfo, ty string) int { - num := 0 - for _, e := range clusterInfo { - if e.ServerType == ty { - num++ - } - } - return num -} - -func fetchServerCPUQuota(serverInfos []infoschema.ServerInfo, serverType string, metricName string) (float64, error) { - var cpuQuota float64 - err := fetchStoreMetrics(serverInfos, serverType, func(addr string, resp *http.Response) error { - if resp.StatusCode != http.StatusOK { - return errors.Errorf("request %s failed: %s", addr, resp.Status) - } - scanner := bufio.NewScanner(resp.Body) - for scanner.Scan() { - line := scanner.Text() - if !strings.HasPrefix(line, metricName) { - continue - } - // the metrics format is like following: - // tikv_server_cpu_cores_quota 8 - quota, err := strconv.ParseFloat(line[len(metricName)+1:], 64) - if err == nil { - cpuQuota = quota - } - return errors.Trace(err) - } - return errors.Errorf("metrics '%s' not found from server '%s'", metricName, addr) - }) - return cpuQuota, err -} - -func fetchStoreMetrics(serversInfo []infoschema.ServerInfo, serverType string, onResp func(string, *http.Response) error) error { - var firstErr error - for _, srv := range serversInfo { - if srv.ServerType != serverType { - continue - } - if len(srv.StatusAddr) == 0 { - continue - } - url := fmt.Sprintf("%s://%s/metrics", util.InternalHTTPSchema(), srv.StatusAddr) - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return err - } - var resp *http.Response - failpoint.Inject("mockMetricsResponse", func(val failpoint.Value) { - if val != nil { - data, _ := base64.StdEncoding.DecodeString(val.(string)) - resp = &http.Response{ - StatusCode: http.StatusOK, - Body: noopCloserWrapper{ - Reader: strings.NewReader(string(data)), - }, - } - } - }) - if resp == nil { - var err1 error - // ignore false positive go line, can't use defer here because it's in a loop. 
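fetchServerCPUQuota above scrapes a Prometheus-style /metrics payload and pulls out a single gauge such as "tikv_server_cpu_cores_quota 8". A minimal standalone version of that line scan, with an inlined payload instead of an HTTP response body:

package main

import (
	"bufio"
	"fmt"
	"strconv"
	"strings"
)

// parseGauge returns the value of the first sample line that starts with the
// given metric name.
func parseGauge(payload, name string) (float64, error) {
	scanner := bufio.NewScanner(strings.NewReader(payload))
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, name) {
			continue
		}
		// Everything after the metric name is the sample value.
		return strconv.ParseFloat(strings.TrimSpace(line[len(name):]), 64)
	}
	return 0, fmt.Errorf("metric %q not found", name)
}

func main() {
	payload := "# TYPE tikv_server_cpu_cores_quota gauge\ntikv_server_cpu_cores_quota 8\n"
	quota, err := parseGauge(payload, "tikv_server_cpu_cores_quota")
	fmt.Println(quota, err) // 8 <nil>
}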
- //nolint:bodyclose - resp, err1 = util.InternalHTTPClient().Do(req) - if err1 != nil { - if firstErr == nil { - firstErr = err1 - } - continue - } - } - err = onResp(srv.Address, resp) - resp.Body.Close() - return err - } - if firstErr == nil { - firstErr = errors.Errorf("no server with type '%s' is found", serverType) - } - return firstErr -} - -type noopCloserWrapper struct { - io.Reader -} - -func (noopCloserWrapper) Close() error { - return nil -} diff --git a/pkg/executor/internal/exec/binding__failpoint_binding__.go b/pkg/executor/internal/exec/binding__failpoint_binding__.go deleted file mode 100644 index 7d0ade518b1e1..0000000000000 --- a/pkg/executor/internal/exec/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package exec - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/executor/internal/exec/executor.go b/pkg/executor/internal/exec/executor.go index 74f6260ac652c..6feb131b67306 100644 --- a/pkg/executor/internal/exec/executor.go +++ b/pkg/executor/internal/exec/executor.go @@ -97,9 +97,9 @@ func newExecutorChunkAllocator(vars *variable.SessionVars, retFieldTypes []*type // InitCap returns the initial capacity for chunk func (e *executorChunkAllocator) InitCap() int { - if val, _err_ := failpoint.Eval(_curpkg_("initCap")); _err_ == nil { - return val.(int) - } + failpoint.Inject("initCap", func(val failpoint.Value) { + failpoint.Return(val.(int)) + }) return e.initCap } @@ -110,9 +110,9 @@ func (e *executorChunkAllocator) SetInitCap(c int) { // MaxChunkSize returns the max chunk size. func (e *executorChunkAllocator) MaxChunkSize() int { - if val, _err_ := failpoint.Eval(_curpkg_("maxChunkSize")); _err_ == nil { - return val.(int) - } + failpoint.Inject("maxChunkSize", func(val failpoint.Value) { + failpoint.Return(val.(int)) + }) return e.maxChunkSize } diff --git a/pkg/executor/internal/exec/executor.go__failpoint_stash__ b/pkg/executor/internal/exec/executor.go__failpoint_stash__ deleted file mode 100644 index 6feb131b67306..0000000000000 --- a/pkg/executor/internal/exec/executor.go__failpoint_stash__ +++ /dev/null @@ -1,468 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
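The executor.go hunks above restore the failpoint.Inject source form from the expanded failpoint.Eval code that failpoint-ctl generates (the deleted binding and stash files are that generator's output). For orientation, a tiny standalone program exercising the runtime Enable/Eval/Disable API that both forms ultimately rely on; the failpoint path used here is invented for the example, whereas the generated code derives it from the package path via _curpkg_:

package main

import (
	"fmt"

	"github.com/pingcap/failpoint"
)

// fpName is a hypothetical failpoint path, not one defined in TiDB.
const fpName = "example/mockInitCap"

func initCap() int {
	// This is the shape of the generated code: probe the failpoint and use the
	// injected value when it is active.
	if val, err := failpoint.Eval(fpName); err == nil {
		return val.(int)
	}
	return 1024 // normal value when the failpoint is disabled
}

func main() {
	fmt.Println(initCap()) // 1024

	// "return(32)" is the failpoint term syntax for injecting a constant.
	if err := failpoint.Enable(fpName, "return(32)"); err != nil {
		panic(err)
	}
	defer func() { _ = failpoint.Disable(fpName) }()

	fmt.Println(initCap()) // 32
}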
- -package exec - -import ( - "context" - "reflect" - "time" - - "github.com/ngaut/pools" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/parser" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/topsql" - topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" - "github.com/pingcap/tidb/pkg/util/tracing" - "go.uber.org/atomic" -) - -// Executor is the physical implementation of an algebra operator. -// -// In TiDB, all algebra operators are implemented as iterators, i.e., they -// support a simple Open-Next-Close protocol. See this paper for more details: -// -// "Volcano-An Extensible and Parallel Query Evaluation System" -// -// Different from Volcano's execution model, a "Next" function call in TiDB will -// return a batch of rows, other than a single row in Volcano. -// NOTE: Executors must call "chk.Reset()" before appending their results to it. -type Executor interface { - NewChunk() *chunk.Chunk - NewChunkWithCapacity(fields []*types.FieldType, capacity int, maxCachesize int) *chunk.Chunk - - RuntimeStats() *execdetails.BasicRuntimeStats - - HandleSQLKillerSignal() error - RegisterSQLAndPlanInExecForTopSQL() - - AllChildren() []Executor - SetAllChildren([]Executor) - Open(context.Context) error - Next(ctx context.Context, req *chunk.Chunk) error - - // `Close()` may be called at any time after `Open()` and it may be called with `Next()` at the same time - Close() error - Schema() *expression.Schema - RetFieldTypes() []*types.FieldType - InitCap() int - MaxChunkSize() int - - // Detach detaches the current executor from the session context without considering its children. - // - // It has to make sure, no matter whether it returns true or false, both the original executor and the returning executor - // should be able to be used correctly. - Detach() (Executor, bool) -} - -var _ Executor = &BaseExecutor{} - -// executorChunkAllocator is a helper to implement `Chunk` related methods in `Executor` interface -type executorChunkAllocator struct { - AllocPool chunk.Allocator - retFieldTypes []*types.FieldType - initCap int - maxChunkSize int -} - -// newExecutorChunkAllocator creates a new `executorChunkAllocator` -func newExecutorChunkAllocator(vars *variable.SessionVars, retFieldTypes []*types.FieldType) executorChunkAllocator { - return executorChunkAllocator{ - AllocPool: vars.GetChunkAllocator(), - initCap: vars.InitChunkSize, - maxChunkSize: vars.MaxChunkSize, - retFieldTypes: retFieldTypes, - } -} - -// InitCap returns the initial capacity for chunk -func (e *executorChunkAllocator) InitCap() int { - failpoint.Inject("initCap", func(val failpoint.Value) { - failpoint.Return(val.(int)) - }) - return e.initCap -} - -// SetInitCap sets the initial capacity for chunk -func (e *executorChunkAllocator) SetInitCap(c int) { - e.initCap = c -} - -// MaxChunkSize returns the max chunk size. -func (e *executorChunkAllocator) MaxChunkSize() int { - failpoint.Inject("maxChunkSize", func(val failpoint.Value) { - failpoint.Return(val.(int)) - }) - return e.maxChunkSize -} - -// SetMaxChunkSize sets the max chunk size. 
-func (e *executorChunkAllocator) SetMaxChunkSize(size int) { - e.maxChunkSize = size -} - -// NewChunk creates a new chunk according to the executor configuration -func (e *executorChunkAllocator) NewChunk() *chunk.Chunk { - return e.NewChunkWithCapacity(e.retFieldTypes, e.InitCap(), e.MaxChunkSize()) -} - -// NewChunkWithCapacity allows the caller to allocate the chunk with any types, capacity and max size in the pool -func (e *executorChunkAllocator) NewChunkWithCapacity(fields []*types.FieldType, capacity int, maxCachesize int) *chunk.Chunk { - return e.AllocPool.Alloc(fields, capacity, maxCachesize) -} - -// executorMeta is a helper to store metadata for an execturo and implement the getter -type executorMeta struct { - schema *expression.Schema - children []Executor - retFieldTypes []*types.FieldType - id int -} - -// newExecutorMeta creates a new `executorMeta` -func newExecutorMeta(schema *expression.Schema, id int, children ...Executor) executorMeta { - e := executorMeta{ - id: id, - schema: schema, - children: children, - } - if schema != nil { - cols := schema.Columns - e.retFieldTypes = make([]*types.FieldType, len(cols)) - for i := range cols { - e.retFieldTypes[i] = cols[i].RetType - } - } - return e -} - -// NewChunkWithCapacity allows the caller to allocate the chunk with any types, capacity and max size in the pool -func (e *executorMeta) RetFieldTypes() []*types.FieldType { - return e.retFieldTypes -} - -// ID returns the id of an executor. -func (e *executorMeta) ID() int { - return e.id -} - -// AllChildren returns all children. -func (e *executorMeta) AllChildren() []Executor { - return e.children -} - -// SetAllChildren sets the children for an executor. -func (e *executorMeta) SetAllChildren(children []Executor) { - e.children = children -} - -// ChildrenLen returns the length of children. -func (e *executorMeta) ChildrenLen() int { - return len(e.children) -} - -// EmptyChildren judges whether the children is empty. -func (e *executorMeta) EmptyChildren() bool { - return len(e.children) == 0 -} - -// SetChildren sets a child for an executor. -func (e *executorMeta) SetChildren(idx int, ex Executor) { - e.children[idx] = ex -} - -// Children returns the children for an executor. -func (e *executorMeta) Children(idx int) Executor { - return e.children[idx] -} - -// Schema returns the current BaseExecutor's schema. If it is nil, then create and return a new one. -func (e *executorMeta) Schema() *expression.Schema { - if e.schema == nil { - return expression.NewSchema() - } - return e.schema -} - -// GetSchema gets the schema. 
-func (e *executorMeta) GetSchema() *expression.Schema { - return e.schema -} - -// executorStats is a helper to implement the stats related methods for `Executor` -type executorStats struct { - runtimeStats *execdetails.BasicRuntimeStats - isSQLAndPlanRegistered *atomic.Bool - sqlDigest *parser.Digest - planDigest *parser.Digest - normalizedSQL string - normalizedPlan string - inRestrictedSQL bool -} - -// newExecutorStats creates a new `executorStats` -func newExecutorStats(stmtCtx *stmtctx.StatementContext, id int) executorStats { - normalizedSQL, sqlDigest := stmtCtx.SQLDigest() - normalizedPlan, planDigest := stmtCtx.GetPlanDigest() - e := executorStats{ - isSQLAndPlanRegistered: &stmtCtx.IsSQLAndPlanRegistered, - normalizedSQL: normalizedSQL, - sqlDigest: sqlDigest, - normalizedPlan: normalizedPlan, - planDigest: planDigest, - inRestrictedSQL: stmtCtx.InRestrictedSQL, - } - - if stmtCtx.RuntimeStatsColl != nil { - if id > 0 { - e.runtimeStats = stmtCtx.RuntimeStatsColl.GetBasicRuntimeStats(id) - } - } - - return e -} - -// RuntimeStats returns the runtime stats of an executor. -func (e *executorStats) RuntimeStats() *execdetails.BasicRuntimeStats { - return e.runtimeStats -} - -// RegisterSQLAndPlanInExecForTopSQL registers the current SQL and Plan on top sql -func (e *executorStats) RegisterSQLAndPlanInExecForTopSQL() { - if topsqlstate.TopSQLEnabled() && e.isSQLAndPlanRegistered.CompareAndSwap(false, true) { - topsql.RegisterSQL(e.normalizedSQL, e.sqlDigest, e.inRestrictedSQL) - if len(e.normalizedPlan) > 0 { - topsql.RegisterPlan(e.normalizedPlan, e.planDigest) - } - } -} - -type signalHandler interface { - HandleSignal() error -} - -// executorKillerHandler is a helper to implement the killer related methods for `Executor`. -type executorKillerHandler struct { - handler signalHandler -} - -func (e *executorKillerHandler) HandleSQLKillerSignal() error { - return e.handler.HandleSignal() -} - -func newExecutorKillerHandler(handler signalHandler) executorKillerHandler { - return executorKillerHandler{handler} -} - -// BaseExecutorV2 is a simplified version of `BaseExecutor`, which doesn't contain a full session context -type BaseExecutorV2 struct { - executorMeta - executorKillerHandler - executorStats - executorChunkAllocator -} - -// NewBaseExecutorV2 creates a new BaseExecutorV2 instance. -func NewBaseExecutorV2(vars *variable.SessionVars, schema *expression.Schema, id int, children ...Executor) BaseExecutorV2 { - executorMeta := newExecutorMeta(schema, id, children...) - e := BaseExecutorV2{ - executorMeta: executorMeta, - executorStats: newExecutorStats(vars.StmtCtx, id), - executorChunkAllocator: newExecutorChunkAllocator(vars, executorMeta.RetFieldTypes()), - executorKillerHandler: newExecutorKillerHandler(&vars.SQLKiller), - } - return e -} - -// Open initializes children recursively and "childrenResults" according to children's schemas. -func (e *BaseExecutorV2) Open(ctx context.Context) error { - for _, child := range e.children { - err := Open(ctx, child) - if err != nil { - return err - } - } - return nil -} - -// Close closes all executors and release all resources. -func (e *BaseExecutorV2) Close() error { - var firstErr error - for _, src := range e.children { - if err := Close(src); err != nil && firstErr == nil { - firstErr = err - } - } - return firstErr -} - -// Next fills multiple rows into a chunk. -func (*BaseExecutorV2) Next(_ context.Context, _ *chunk.Chunk) error { - return nil -} - -// Detach detaches the current executor from the session context. 
-func (*BaseExecutorV2) Detach() (Executor, bool) { - return nil, false -} - -// BuildNewBaseExecutorV2 builds a new `BaseExecutorV2` based on the configuration of the current base executor. -// It's used to build a new sub-executor from an existing executor. For example, the `IndexLookUpExecutor` will use -// this function to build `TableReaderExecutor` -func (e *BaseExecutorV2) BuildNewBaseExecutorV2(stmtRuntimeStatsColl *execdetails.RuntimeStatsColl, schema *expression.Schema, id int, children ...Executor) BaseExecutorV2 { - newExecutorMeta := newExecutorMeta(schema, id, children...) - - newExecutorStats := e.executorStats - if stmtRuntimeStatsColl != nil { - if id > 0 { - newExecutorStats.runtimeStats = stmtRuntimeStatsColl.GetBasicRuntimeStats(id) - } - } - - newChunkAllocator := e.executorChunkAllocator - newChunkAllocator.retFieldTypes = newExecutorMeta.RetFieldTypes() - newE := BaseExecutorV2{ - executorMeta: newExecutorMeta, - executorStats: newExecutorStats, - executorChunkAllocator: newChunkAllocator, - executorKillerHandler: e.executorKillerHandler, - } - return newE -} - -// BaseExecutor holds common information for executors. -type BaseExecutor struct { - ctx sessionctx.Context - - BaseExecutorV2 -} - -// NewBaseExecutor creates a new BaseExecutor instance. -func NewBaseExecutor(ctx sessionctx.Context, schema *expression.Schema, id int, children ...Executor) BaseExecutor { - return BaseExecutor{ - ctx: ctx, - BaseExecutorV2: NewBaseExecutorV2(ctx.GetSessionVars(), schema, id, children...), - } -} - -// Ctx return ```sessionctx.Context``` of Executor -func (e *BaseExecutor) Ctx() sessionctx.Context { - return e.ctx -} - -// UpdateDeltaForTableID updates the delta info for the table with tableID. -func (e *BaseExecutor) UpdateDeltaForTableID(id int64) { - txnCtx := e.ctx.GetSessionVars().TxnCtx - txnCtx.UpdateDeltaForTable(id, 0, 0, nil) -} - -// GetSysSession gets a system session context from executor. -func (e *BaseExecutor) GetSysSession() (sessionctx.Context, error) { - dom := domain.GetDomain(e.Ctx()) - sysSessionPool := dom.SysSessionPool() - ctx, err := sysSessionPool.Get() - if err != nil { - return nil, err - } - restrictedCtx := ctx.(sessionctx.Context) - restrictedCtx.GetSessionVars().InRestrictedSQL = true - return restrictedCtx, nil -} - -// ReleaseSysSession releases a system session context to executor. -func (e *BaseExecutor) ReleaseSysSession(ctx context.Context, sctx sessionctx.Context) { - if sctx == nil { - return - } - dom := domain.GetDomain(e.Ctx()) - sysSessionPool := dom.SysSessionPool() - if _, err := sctx.GetSQLExecutor().ExecuteInternal(ctx, "rollback"); err != nil { - sctx.(pools.Resource).Close() - return - } - sysSessionPool.Put(sctx.(pools.Resource)) -} - -// TryNewCacheChunk tries to get a cached chunk -func TryNewCacheChunk(e Executor) *chunk.Chunk { - return e.NewChunk() -} - -// RetTypes returns all output column types. -func RetTypes(e Executor) []*types.FieldType { - return e.RetFieldTypes() -} - -// NewFirstChunk creates a new chunk to buffer current executor's result. -func NewFirstChunk(e Executor) *chunk.Chunk { - return chunk.New(e.RetFieldTypes(), e.InitCap(), e.MaxChunkSize()) -} - -// Open is a wrapper function on e.Open(), it handles some common codes. -func Open(ctx context.Context, e Executor) (err error) { - defer func() { - if r := recover(); r != nil { - err = util.GetRecoverError(r) - } - }() - return e.Open(ctx) -} - -// Next is a wrapper function on e.Next(), it handles some common codes. 
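The Open/Next/Close wrapper functions around this point all follow the same pattern: convert panics into errors, record timing, and consult the kill signal both before and after delegating to the executor. A simplified, self-contained sketch of that wrapper, with stand-ins for the SQLKiller and Executor types:

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// killer is a stand-in for the session's SQL killer.
type killer interface{ HandleSignal() error }

type nopKiller struct{}

func (nopKiller) HandleSignal() error { return nil }

// safeNext wraps a single Next-style call: recover panics into an error,
// time the call, and re-check the kill signal after it returns.
func safeNext(ctx context.Context, k killer, next func(context.Context) error) (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("executor panicked: %v", r)
		}
	}()
	start := time.Now()
	defer func() { fmt.Println("next took", time.Since(start)) }()

	if err := k.HandleSignal(); err != nil {
		return err
	}
	if err := next(ctx); err != nil {
		return err
	}
	// The query may have been killed while Next was running.
	return k.HandleSignal()
}

func main() {
	err := safeNext(context.Background(), nopKiller{}, func(context.Context) error {
		return errors.New("boom")
	})
	fmt.Println(err)
}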
-func Next(ctx context.Context, e Executor, req *chunk.Chunk) (err error) { - defer func() { - if r := recover(); r != nil { - err = util.GetRecoverError(r) - } - }() - if e.RuntimeStats() != nil { - start := time.Now() - defer func() { e.RuntimeStats().Record(time.Since(start), req.NumRows()) }() - } - - if err := e.HandleSQLKillerSignal(); err != nil { - return err - } - - r, ctx := tracing.StartRegionEx(ctx, reflect.TypeOf(e).String()+".Next") - defer r.End() - - e.RegisterSQLAndPlanInExecForTopSQL() - err = e.Next(ctx, req) - - if err != nil { - return err - } - // recheck whether the session/query is killed during the Next() - return e.HandleSQLKillerSignal() -} - -// Close is a wrapper function on e.Close(), it handles some common codes. -func Close(e Executor) (err error) { - defer func() { - if r := recover(); r != nil { - err = util.GetRecoverError(r) - } - }() - return e.Close() -} diff --git a/pkg/executor/internal/mpp/binding__failpoint_binding__.go b/pkg/executor/internal/mpp/binding__failpoint_binding__.go deleted file mode 100644 index bb1223ce4ae04..0000000000000 --- a/pkg/executor/internal/mpp/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package mpp - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/executor/internal/mpp/executor_with_retry.go b/pkg/executor/internal/mpp/executor_with_retry.go index 4d3d6d6b6578a..fd2bd527dc48c 100644 --- a/pkg/executor/internal/mpp/executor_with_retry.go +++ b/pkg/executor/internal/mpp/executor_with_retry.go @@ -84,11 +84,11 @@ func NewExecutorWithRetry(ctx context.Context, sctx sessionctx.Context, parentTr // 3. For cached table, will not dispatch tasks to TiFlash, so no need to recovery. enableMPPRecovery := disaggTiFlashWithAutoScaler && !allowTiFlashFallback - if _, _err_ := failpoint.Eval(_curpkg_("mpp_recovery_test_mock_enable")); _err_ == nil { + failpoint.Inject("mpp_recovery_test_mock_enable", func() { if !allowTiFlashFallback { enableMPPRecovery = true } - } + }) recoveryHandler := NewRecoveryHandler(disaggTiFlashWithAutoScaler, uint64(holdCap), enableMPPRecovery, parentTracker) @@ -176,12 +176,12 @@ func (r *ExecutorWithRetry) nextWithRecovery(ctx context.Context) error { resp, mppErr := r.coord.Next(ctx) // Mock recovery n times. - if forceErrCnt, _err_ := failpoint.Eval(_curpkg_("mpp_recovery_test_max_err_times")); _err_ == nil { + failpoint.Inject("mpp_recovery_test_max_err_times", func(forceErrCnt failpoint.Value) { forceErrCntInt := forceErrCnt.(int) if r.mppErrRecovery.RecoveryCnt() < uint32(forceErrCntInt) { mppErr = errors.New("mock mpp error") } - } + }) if mppErr != nil { recoveryErr := r.mppErrRecovery.Recovery(&RecoveryInfo{ @@ -190,14 +190,14 @@ func (r *ExecutorWithRetry) nextWithRecovery(ctx context.Context) error { }) // Mock recovery succeed, ignore no recovery handler err. 
- if _, _err_ := failpoint.Eval(_curpkg_("mpp_recovery_test_ignore_recovery_err")); _err_ == nil { + failpoint.Inject("mpp_recovery_test_ignore_recovery_err", func() { if recoveryErr == nil { panic("mocked mpp err should got recovery err") } if strings.Contains(mppErr.Error(), "mock mpp error") && strings.Contains(recoveryErr.Error(), "no handler to recovery") { recoveryErr = nil } - } + }) if recoveryErr != nil { logutil.BgLogger().Error("recovery mpp error failed", zap.Any("mppErr", mppErr), @@ -224,14 +224,14 @@ func (r *ExecutorWithRetry) nextWithRecovery(ctx context.Context) error { r.mppErrRecovery.HoldResult(resp.(*mppResponse)) } - if num, _err_ := failpoint.Eval(_curpkg_("mpp_recovery_test_hold_size")); _err_ == nil { + failpoint.Inject("mpp_recovery_test_hold_size", func(num failpoint.Value) { // Note: this failpoint only execute once. curRows := r.mppErrRecovery.NumHoldResp() numInt := num.(int) if curRows != numInt { panic(fmt.Sprintf("unexpected holding rows, cur: %d", curRows)) } - } + }) return nil } diff --git a/pkg/executor/internal/mpp/executor_with_retry.go__failpoint_stash__ b/pkg/executor/internal/mpp/executor_with_retry.go__failpoint_stash__ deleted file mode 100644 index fd2bd527dc48c..0000000000000 --- a/pkg/executor/internal/mpp/executor_with_retry.go__failpoint_stash__ +++ /dev/null @@ -1,254 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package mpp - -import ( - "context" - "fmt" - "strings" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/executor/mppcoordmanager" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - plannercore "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "go.uber.org/zap" -) - -// ExecutorWithRetry receive mppResponse from localMppCoordinator, -// and tries to recovery mpp err if necessary. -// The abstraction layer of reading mpp resp: -// 1. MPPGather: As part of the TiDB Volcano model executor, it is equivalent to a TableReader. -// 2. selectResult: Decode select result(mppResponse) into chunk. Also record runtime info. -// 3. ExecutorWithRetry: Recovery mpp err if possible and retry MPP Task. -// 4. localMppCoordinator: Generate MPP fragment and dispatch MPPTask. -// And receive MPP status for better err msg and correct stats for Limit. -// 5. mppIterator: Send or receive MPP RPC. -type ExecutorWithRetry struct { - coord kv.MppCoordinator - sctx sessionctx.Context - is infoschema.InfoSchema - plan plannercore.PhysicalPlan - ctx context.Context - memTracker *memory.Tracker - // mppErrRecovery is designed for the recovery of MPP errors. - // Basic idea: - // 1. It attempts to hold the results of MPP. During the holding process, if an error occurs, it starts error recovery. 
- // If the recovery is successful, it discards held results and reconstructs the respIter, then re-executes the MPP task. - // If the recovery fails, an error is reported directly. - // 2. If the held MPP results exceed the capacity, will starts returning results to caller. - // Once the results start being returned, error recovery cannot be performed anymore. - mppErrRecovery *RecoveryHandler - planIDs []int - // Expose to let MPPGather access. - KVRanges []kv.KeyRange - queryID kv.MPPQueryID - startTS uint64 - gatherID uint64 - nodeCnt int -} - -var _ kv.Response = &ExecutorWithRetry{} - -// NewExecutorWithRetry create ExecutorWithRetry. -func NewExecutorWithRetry(ctx context.Context, sctx sessionctx.Context, parentTracker *memory.Tracker, planIDs []int, - plan plannercore.PhysicalPlan, startTS uint64, queryID kv.MPPQueryID, - is infoschema.InfoSchema) (*ExecutorWithRetry, error) { - // TODO: After add row info in tipb.DataPacket, we can use row count as capacity. - // For now, use the number of tipb.DataPacket as capacity. - const holdCap = 2 - - disaggTiFlashWithAutoScaler := config.GetGlobalConfig().DisaggregatedTiFlash && config.GetGlobalConfig().UseAutoScaler - _, allowTiFlashFallback := sctx.GetSessionVars().AllowFallbackToTiKV[kv.TiFlash] - - // 1. For now, mpp err recovery only support MemLimit, which is only useful when AutoScaler is used. - // 2. When enable fallback to tikv, the returned mpp err will be ErrTiFlashServerTimeout, - // which we cannot handle for now. Also there is no need to recovery because tikv will retry the query. - // 3. For cached table, will not dispatch tasks to TiFlash, so no need to recovery. - enableMPPRecovery := disaggTiFlashWithAutoScaler && !allowTiFlashFallback - - failpoint.Inject("mpp_recovery_test_mock_enable", func() { - if !allowTiFlashFallback { - enableMPPRecovery = true - } - }) - - recoveryHandler := NewRecoveryHandler(disaggTiFlashWithAutoScaler, - uint64(holdCap), enableMPPRecovery, parentTracker) - memTracker := memory.NewTracker(parentTracker.Label(), 0) - memTracker.AttachTo(parentTracker) - retryer := &ExecutorWithRetry{ - ctx: ctx, - sctx: sctx, - memTracker: memTracker, - planIDs: planIDs, - is: is, - plan: plan, - startTS: startTS, - queryID: queryID, - mppErrRecovery: recoveryHandler, - } - - var err error - retryer.KVRanges, err = retryer.setupMPPCoordinator(ctx, false) - return retryer, err -} - -// Next implements kv.Response interface. -func (r *ExecutorWithRetry) Next(ctx context.Context) (resp kv.ResultSubset, err error) { - if err = r.nextWithRecovery(ctx); err != nil { - return nil, err - } - - if r.mppErrRecovery.NumHoldResp() != 0 { - if resp, err = r.mppErrRecovery.PopFrontResp(); err != nil { - return nil, err - } - } else if resp, err = r.coord.Next(ctx); err != nil { - return nil, err - } - return resp, nil -} - -// Close implements kv.Response interface. -func (r *ExecutorWithRetry) Close() error { - r.mppErrRecovery.ResetHolder() - r.memTracker.Detach() - // Need to close coordinator before unregister to avoid coord.Close() takes too long. - err := r.coord.Close() - mppcoordmanager.InstanceMPPCoordinatorManager.Unregister(r.getCoordUniqueID()) - return err -} - -func (r *ExecutorWithRetry) setupMPPCoordinator(ctx context.Context, recoverying bool) ([]kv.KeyRange, error) { - if recoverying { - // Sanity check. - if r.coord == nil { - return nil, errors.New("mpp coordinator should not be nil when recoverying") - } - // Only report runtime stats when there is no error. 
- r.coord.(*localMppCoordinator).closeWithoutReport() - mppcoordmanager.InstanceMPPCoordinatorManager.Unregister(r.getCoordUniqueID()) - } - - // Make sure gatherID is updated before build coord. - r.gatherID = allocMPPGatherID(r.sctx) - - r.coord = r.buildCoordinator() - if err := mppcoordmanager.InstanceMPPCoordinatorManager.Register(r.getCoordUniqueID(), r.coord); err != nil { - return nil, err - } - - _, kvRanges, err := r.coord.Execute(ctx) - if err != nil { - return nil, err - } - - if r.nodeCnt = r.coord.GetNodeCnt(); r.nodeCnt <= 0 { - return nil, errors.Errorf("tiflash node count should be greater than zero: %v", r.nodeCnt) - } - return kvRanges, err -} - -func (r *ExecutorWithRetry) nextWithRecovery(ctx context.Context) error { - if !r.mppErrRecovery.Enabled() { - return nil - } - - for r.mppErrRecovery.CanHoldResult() { - resp, mppErr := r.coord.Next(ctx) - - // Mock recovery n times. - failpoint.Inject("mpp_recovery_test_max_err_times", func(forceErrCnt failpoint.Value) { - forceErrCntInt := forceErrCnt.(int) - if r.mppErrRecovery.RecoveryCnt() < uint32(forceErrCntInt) { - mppErr = errors.New("mock mpp error") - } - }) - - if mppErr != nil { - recoveryErr := r.mppErrRecovery.Recovery(&RecoveryInfo{ - MPPErr: mppErr, - NodeCnt: r.nodeCnt, - }) - - // Mock recovery succeed, ignore no recovery handler err. - failpoint.Inject("mpp_recovery_test_ignore_recovery_err", func() { - if recoveryErr == nil { - panic("mocked mpp err should got recovery err") - } - if strings.Contains(mppErr.Error(), "mock mpp error") && strings.Contains(recoveryErr.Error(), "no handler to recovery") { - recoveryErr = nil - } - }) - - if recoveryErr != nil { - logutil.BgLogger().Error("recovery mpp error failed", zap.Any("mppErr", mppErr), - zap.Any("recoveryErr", recoveryErr)) - return mppErr - } - - logutil.BgLogger().Info("recovery mpp error succeed, begin next retry", - zap.Any("mppErr", mppErr), zap.Any("recoveryCnt", r.mppErrRecovery.RecoveryCnt())) - - if _, err := r.setupMPPCoordinator(r.ctx, true); err != nil { - logutil.BgLogger().Error("setup resp iter when recovery mpp err failed", zap.Any("err", err)) - return mppErr - } - r.mppErrRecovery.ResetHolder() - - continue - } - - if resp == nil { - break - } - - r.mppErrRecovery.HoldResult(resp.(*mppResponse)) - } - - failpoint.Inject("mpp_recovery_test_hold_size", func(num failpoint.Value) { - // Note: this failpoint only execute once. 
- curRows := r.mppErrRecovery.NumHoldResp() - numInt := num.(int) - if curRows != numInt { - panic(fmt.Sprintf("unexpected holding rows, cur: %d", curRows)) - } - }) - return nil -} - -func allocMPPGatherID(ctx sessionctx.Context) uint64 { - mppQueryInfo := &ctx.GetSessionVars().StmtCtx.MPPQueryInfo - return mppQueryInfo.AllocatedMPPGatherID.Add(1) -} - -func (r *ExecutorWithRetry) buildCoordinator() kv.MppCoordinator { - _, serverAddr := mppcoordmanager.InstanceMPPCoordinatorManager.GetServerAddr() - return NewLocalMPPCoordinator(r.ctx, r.sctx, r.is, r.plan, r.planIDs, r.startTS, r.queryID, - r.gatherID, serverAddr, r.memTracker) -} - -func (r *ExecutorWithRetry) getCoordUniqueID() mppcoordmanager.CoordinatorUniqueID { - return mppcoordmanager.CoordinatorUniqueID{ - MPPQueryID: r.queryID, - GatherID: r.gatherID, - } -} diff --git a/pkg/executor/internal/mpp/local_mpp_coordinator.go b/pkg/executor/internal/mpp/local_mpp_coordinator.go index 5299f27ef9a3f..15ef85eaa2da0 100644 --- a/pkg/executor/internal/mpp/local_mpp_coordinator.go +++ b/pkg/executor/internal/mpp/local_mpp_coordinator.go @@ -363,11 +363,11 @@ func (c *localMppCoordinator) dispatchAll(ctx context.Context) { c.mu.Unlock() c.wg.Add(1) boMaxSleep := copr.CopNextMaxBackoff - if value, _err_ := failpoint.Eval(_curpkg_("ReduceCopNextMaxBackoff")); _err_ == nil { + failpoint.Inject("ReduceCopNextMaxBackoff", func(value failpoint.Value) { if value.(bool) { boMaxSleep = 2 } - } + }) bo := backoff.NewBackoffer(ctx, boMaxSleep) go func(mppTask *kv.MPPDispatchRequest) { defer func() { @@ -395,11 +395,11 @@ func (c *localMppCoordinator) sendToRespCh(resp *mppResponse) (exit bool) { }() if c.memTracker != nil { respSize := resp.MemSize() - if val, _err_ := failpoint.Eval(_curpkg_("testMPPOOMPanic")); _err_ == nil { + failpoint.Inject("testMPPOOMPanic", func(val failpoint.Value) { if val.(bool) && respSize != 0 { respSize = 1 << 30 } - } + }) c.memTracker.Consume(respSize) defer c.memTracker.Consume(-respSize) } @@ -451,14 +451,14 @@ func (c *localMppCoordinator) handleDispatchReq(ctx context.Context, bo *backoff c.sendError(errors.New(rpcResp.Error.Msg)) return } - if val, _err_ := failpoint.Eval(_curpkg_("mppNonRootTaskError")); _err_ == nil { + failpoint.Inject("mppNonRootTaskError", func(val failpoint.Value) { if val.(bool) && !req.IsRoot { time.Sleep(1 * time.Second) atomic.CompareAndSwapUint32(&c.dispatchFailed, 0, 1) c.sendError(derr.ErrTiFlashServerTimeout) return } - } + }) if !req.IsRoot { return } @@ -747,11 +747,11 @@ func (c *localMppCoordinator) Execute(ctx context.Context) (kv.Response, []kv.Ke return nil, nil, errors.Trace(err) } } - if val, _err_ := failpoint.Eval(_curpkg_("checkTotalMPPTasks")); _err_ == nil { + failpoint.Inject("checkTotalMPPTasks", func(val failpoint.Value) { if val.(int) != len(c.mppReqs) { - return nil, nil, errors.Errorf("The number of tasks is not right, expect %d tasks but actually there are %d tasks", val.(int), len(c.mppReqs)) + failpoint.Return(nil, nil, errors.Errorf("The number of tasks is not right, expect %d tasks but actually there are %d tasks", val.(int), len(c.mppReqs))) } - } + }) ctx = distsql.WithSQLKvExecCounterInterceptor(ctx, sctx.GetSessionVars().StmtCtx.KvExecCounter) _, allowTiFlashFallback := sctx.GetSessionVars().AllowFallbackToTiKV[kv.TiFlash] diff --git a/pkg/executor/internal/mpp/local_mpp_coordinator.go__failpoint_stash__ b/pkg/executor/internal/mpp/local_mpp_coordinator.go__failpoint_stash__ deleted file mode 100644 index 15ef85eaa2da0..0000000000000 --- 
a/pkg/executor/internal/mpp/local_mpp_coordinator.go__failpoint_stash__ +++ /dev/null @@ -1,772 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package mpp - -import ( - "context" - "fmt" - "io" - "sync" - "sync/atomic" - "time" - "unsafe" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/mpp" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/distsql" - "github.com/pingcap/tidb/pkg/executor/internal/builder" - "github.com/pingcap/tidb/pkg/executor/internal/util" - "github.com/pingcap/tidb/pkg/executor/metrics" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/store/copr" - "github.com/pingcap/tidb/pkg/store/driver/backoff" - derr "github.com/pingcap/tidb/pkg/store/driver/error" - util2 "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "github.com/pingcap/tipb/go-tipb" - clientutil "github.com/tikv/client-go/v2/util" - "go.uber.org/zap" -) - -const ( - receiveReportTimeout = 3 * time.Second -) - -// mppResponse wraps mpp data packet. -type mppResponse struct { - err error - pbResp *mpp.MPPDataPacket - detail *copr.CopRuntimeStats - respTime time.Duration - respSize int64 -} - -// GetData implements the kv.ResultSubset GetData interface. -func (m *mppResponse) GetData() []byte { - return m.pbResp.Data -} - -// GetStartKey implements the kv.ResultSubset GetStartKey interface. -func (*mppResponse) GetStartKey() kv.Key { - return nil -} - -// GetCopRuntimeStats is unavailable currently. 
-func (m *mppResponse) GetCopRuntimeStats() *copr.CopRuntimeStats { - return m.detail -} - -// MemSize returns how many bytes of memory this response use -func (m *mppResponse) MemSize() int64 { - if m.respSize != 0 { - return m.respSize - } - if m.detail != nil { - m.respSize += int64(int(unsafe.Sizeof(execdetails.ExecDetails{}))) - } - if m.pbResp != nil { - m.respSize += int64(m.pbResp.Size()) - } - return m.respSize -} - -func (m *mppResponse) RespTime() time.Duration { - return m.respTime -} - -type mppRequestReport struct { - mppReq *kv.MPPDispatchRequest - errMsg string - executionSummaries []*tipb.ExecutorExecutionSummary - receivedReport bool // if received ReportStatus from mpp task -} - -// localMppCoordinator stands for constructing and dispatching mpp tasks in local tidb server, since these work might be done remotely too -type localMppCoordinator struct { - ctx context.Context - sessionCtx sessionctx.Context - is infoschema.InfoSchema - originalPlan base.PhysicalPlan - reqMap map[int64]*mppRequestReport - - cancelFunc context.CancelFunc - - wgDoneChan chan struct{} - - memTracker *memory.Tracker - - reportStatusCh chan struct{} // used to notify inside coordinator that all reports has been received - - vars *kv.Variables - - respChan chan *mppResponse - - finishCh chan struct{} - - coordinatorAddr string // empty if coordinator service not available - firstErrMsg string - - mppReqs []*kv.MPPDispatchRequest - - planIDs []int - mppQueryID kv.MPPQueryID - - wg sync.WaitGroup - gatherID uint64 - reportedReqCount int - startTS uint64 - mu sync.Mutex - - closed uint32 - - dispatchFailed uint32 - allReportsHandled uint32 - - needTriggerFallback bool - enableCollectExecutionInfo bool - reportExecutionInfo bool // if each mpp task needs to report execution info directly to coordinator through ReportMPPTaskStatus - - // Record node cnt that involved in the mpp computation. 
- nodeCnt int -} - -// NewLocalMPPCoordinator creates a new localMppCoordinator instance -func NewLocalMPPCoordinator(ctx context.Context, sctx sessionctx.Context, is infoschema.InfoSchema, plan base.PhysicalPlan, planIDs []int, startTS uint64, mppQueryID kv.MPPQueryID, gatherID uint64, coordinatorAddr string, memTracker *memory.Tracker) *localMppCoordinator { - if sctx.GetSessionVars().ChooseMppVersion() < kv.MppVersionV2 { - coordinatorAddr = "" - } - coord := &localMppCoordinator{ - ctx: ctx, - sessionCtx: sctx, - is: is, - originalPlan: plan, - planIDs: planIDs, - startTS: startTS, - mppQueryID: mppQueryID, - gatherID: gatherID, - coordinatorAddr: coordinatorAddr, - memTracker: memTracker, - finishCh: make(chan struct{}), - wgDoneChan: make(chan struct{}), - respChan: make(chan *mppResponse), - reportStatusCh: make(chan struct{}), - vars: sctx.GetSessionVars().KVVars, - reqMap: make(map[int64]*mppRequestReport), - } - - if len(coordinatorAddr) > 0 && needReportExecutionSummary(coord.originalPlan) { - coord.reportExecutionInfo = true - } - return coord -} - -func (c *localMppCoordinator) appendMPPDispatchReq(pf *plannercore.Fragment) error { - dagReq, err := builder.ConstructDAGReq(c.sessionCtx, []base.PhysicalPlan{pf.ExchangeSender}, kv.TiFlash) - if err != nil { - return errors.Trace(err) - } - for i := range pf.ExchangeSender.Schema().Columns { - dagReq.OutputOffsets = append(dagReq.OutputOffsets, uint32(i)) - } - if !pf.IsRoot { - dagReq.EncodeType = tipb.EncodeType_TypeCHBlock - } else { - dagReq.EncodeType = tipb.EncodeType_TypeChunk - } - for _, mppTask := range pf.ExchangeSender.Tasks { - if mppTask.PartitionTableIDs != nil { - err = util.UpdateExecutorTableID(context.Background(), dagReq.RootExecutor, true, mppTask.PartitionTableIDs) - } else if !mppTask.TiFlashStaticPrune { - // If isDisaggregatedTiFlashStaticPrune is true, it means this TableScan is under PartitionUnoin, - // tableID in TableScan is already the physical table id of this partition, no need to update again. 
- err = util.UpdateExecutorTableID(context.Background(), dagReq.RootExecutor, true, []int64{mppTask.TableID}) - } - if err != nil { - return errors.Trace(err) - } - err = c.fixTaskForCTEStorageAndReader(dagReq.RootExecutor, mppTask.Meta) - if err != nil { - return err - } - pbData, err := dagReq.Marshal() - if err != nil { - return errors.Trace(err) - } - - rgName := c.sessionCtx.GetSessionVars().StmtCtx.ResourceGroupName - if !variable.EnableResourceControl.Load() { - rgName = "" - } - logutil.BgLogger().Info("Dispatch mpp task", zap.Uint64("timestamp", mppTask.StartTs), - zap.Int64("ID", mppTask.ID), zap.Uint64("QueryTs", mppTask.MppQueryID.QueryTs), zap.Uint64("LocalQueryId", mppTask.MppQueryID.LocalQueryID), - zap.Uint64("ServerID", mppTask.MppQueryID.ServerID), zap.String("address", mppTask.Meta.GetAddress()), - zap.String("plan", plannercore.ToString(pf.ExchangeSender)), - zap.Int64("mpp-version", mppTask.MppVersion.ToInt64()), - zap.String("exchange-compression-mode", pf.ExchangeSender.CompressionMode.Name()), - zap.Uint64("GatherID", c.gatherID), - zap.String("resource_group", rgName), - ) - req := &kv.MPPDispatchRequest{ - Data: pbData, - Meta: mppTask.Meta, - ID: mppTask.ID, - IsRoot: pf.IsRoot, - Timeout: 10, - SchemaVar: c.is.SchemaMetaVersion(), - StartTs: c.startTS, - MppQueryID: mppTask.MppQueryID, - GatherID: c.gatherID, - MppVersion: mppTask.MppVersion, - CoordinatorAddress: c.coordinatorAddr, - ReportExecutionSummary: c.reportExecutionInfo, - State: kv.MppTaskReady, - ResourceGroupName: rgName, - ConnectionID: c.sessionCtx.GetSessionVars().ConnectionID, - ConnectionAlias: c.sessionCtx.ShowProcess().SessionAlias, - } - c.reqMap[req.ID] = &mppRequestReport{mppReq: req, receivedReport: false, errMsg: "", executionSummaries: nil} - c.mppReqs = append(c.mppReqs, req) - } - return nil -} - -// fixTaskForCTEStorageAndReader fixes the upstream/downstream tasks for the producers and consumers. -// After we split the fragments. A CTE producer in the fragment will holds all the task address of the consumers. -// For example, the producer has two task on node_1 and node_2. As we know that each consumer also has two task on the same nodes(node_1 and node_2) -// We need to prune address of node_2 for producer's task on node_1 since we just want the producer task on the node_1 only send to the consumer tasks on the node_1. -// And the same for the task on the node_2. -// And the same for the consumer task. We need to prune the unnecessary task address of its producer tasks(i.e. the downstream tasks). 
-func (c *localMppCoordinator) fixTaskForCTEStorageAndReader(exec *tipb.Executor, meta kv.MPPTaskMeta) error { - children := make([]*tipb.Executor, 0, 2) - switch exec.Tp { - case tipb.ExecType_TypeTableScan, tipb.ExecType_TypePartitionTableScan, tipb.ExecType_TypeIndexScan: - case tipb.ExecType_TypeSelection: - children = append(children, exec.Selection.Child) - case tipb.ExecType_TypeAggregation, tipb.ExecType_TypeStreamAgg: - children = append(children, exec.Aggregation.Child) - case tipb.ExecType_TypeTopN: - children = append(children, exec.TopN.Child) - case tipb.ExecType_TypeLimit: - children = append(children, exec.Limit.Child) - case tipb.ExecType_TypeExchangeSender: - children = append(children, exec.ExchangeSender.Child) - if len(exec.ExchangeSender.UpstreamCteTaskMeta) == 0 { - break - } - actualUpStreamTasks := make([][]byte, 0, len(exec.ExchangeSender.UpstreamCteTaskMeta)) - actualTIDs := make([]int64, 0, len(exec.ExchangeSender.UpstreamCteTaskMeta)) - for _, tasksFromOneConsumer := range exec.ExchangeSender.UpstreamCteTaskMeta { - for _, taskBytes := range tasksFromOneConsumer.EncodedTasks { - taskMeta := &mpp.TaskMeta{} - err := taskMeta.Unmarshal(taskBytes) - if err != nil { - return err - } - if taskMeta.Address != meta.GetAddress() { - continue - } - actualUpStreamTasks = append(actualUpStreamTasks, taskBytes) - actualTIDs = append(actualTIDs, taskMeta.TaskId) - } - } - logutil.BgLogger().Warn("refine tunnel for cte producer task", zap.String("the final tunnel", fmt.Sprintf("up stream consumer tasks: %v", actualTIDs))) - exec.ExchangeSender.EncodedTaskMeta = actualUpStreamTasks - case tipb.ExecType_TypeExchangeReceiver: - if len(exec.ExchangeReceiver.OriginalCtePrdocuerTaskMeta) == 0 { - break - } - exec.ExchangeReceiver.EncodedTaskMeta = [][]byte{} - actualTIDs := make([]int64, 0, 4) - for _, taskBytes := range exec.ExchangeReceiver.OriginalCtePrdocuerTaskMeta { - taskMeta := &mpp.TaskMeta{} - err := taskMeta.Unmarshal(taskBytes) - if err != nil { - return err - } - if taskMeta.Address != meta.GetAddress() { - continue - } - exec.ExchangeReceiver.EncodedTaskMeta = append(exec.ExchangeReceiver.EncodedTaskMeta, taskBytes) - actualTIDs = append(actualTIDs, taskMeta.TaskId) - } - logutil.BgLogger().Warn("refine tunnel for cte consumer task", zap.String("the final tunnel", fmt.Sprintf("down stream producer task: %v", actualTIDs))) - case tipb.ExecType_TypeJoin: - children = append(children, exec.Join.Children...) 
- case tipb.ExecType_TypeProjection: - children = append(children, exec.Projection.Child) - case tipb.ExecType_TypeWindow: - children = append(children, exec.Window.Child) - case tipb.ExecType_TypeSort: - children = append(children, exec.Sort.Child) - case tipb.ExecType_TypeExpand: - children = append(children, exec.Expand.Child) - case tipb.ExecType_TypeExpand2: - children = append(children, exec.Expand2.Child) - default: - return errors.Errorf("unknown new tipb protocol %d", exec.Tp) - } - for _, child := range children { - err := c.fixTaskForCTEStorageAndReader(child, meta) - if err != nil { - return err - } - } - return nil -} - -// DFS to check if plan needs report execution summary through ReportMPPTaskStatus mpp service -// Currently, return true if plan contains limit operator -func needReportExecutionSummary(plan base.PhysicalPlan) bool { - switch x := plan.(type) { - case *plannercore.PhysicalLimit: - return true - default: - for _, child := range x.Children() { - if needReportExecutionSummary(child) { - return true - } - } - } - return false -} - -func (c *localMppCoordinator) dispatchAll(ctx context.Context) { - for _, task := range c.mppReqs { - if atomic.LoadUint32(&c.closed) == 1 { - break - } - c.mu.Lock() - if task.State == kv.MppTaskReady { - task.State = kv.MppTaskRunning - } - c.mu.Unlock() - c.wg.Add(1) - boMaxSleep := copr.CopNextMaxBackoff - failpoint.Inject("ReduceCopNextMaxBackoff", func(value failpoint.Value) { - if value.(bool) { - boMaxSleep = 2 - } - }) - bo := backoff.NewBackoffer(ctx, boMaxSleep) - go func(mppTask *kv.MPPDispatchRequest) { - defer func() { - c.wg.Done() - }() - c.handleDispatchReq(ctx, bo, mppTask) - }(task) - } - c.wg.Wait() - close(c.wgDoneChan) - close(c.respChan) -} - -func (c *localMppCoordinator) sendError(err error) { - c.sendToRespCh(&mppResponse{err: err}) - c.cancelMppTasks() -} - -func (c *localMppCoordinator) sendToRespCh(resp *mppResponse) (exit bool) { - defer func() { - if r := recover(); r != nil { - logutil.BgLogger().Error("localMppCoordinator panic", zap.Stack("stack"), zap.Any("recover", r)) - c.sendError(util2.GetRecoverError(r)) - } - }() - if c.memTracker != nil { - respSize := resp.MemSize() - failpoint.Inject("testMPPOOMPanic", func(val failpoint.Value) { - if val.(bool) && respSize != 0 { - respSize = 1 << 30 - } - }) - c.memTracker.Consume(respSize) - defer c.memTracker.Consume(-respSize) - } - select { - case c.respChan <- resp: - case <-c.finishCh: - exit = true - } - return -} - -// TODO:: Consider that which way is better: -// - dispatch all Tasks at once, and connect Tasks at second. -// - dispatch Tasks and establish connection at the same time. -func (c *localMppCoordinator) handleDispatchReq(ctx context.Context, bo *backoff.Backoffer, req *kv.MPPDispatchRequest) { - var rpcResp *mpp.DispatchTaskResponse - var err error - var retry bool - for { - rpcResp, retry, err = c.sessionCtx.GetMPPClient().DispatchMPPTask( - kv.DispatchMPPTaskParam{ - Ctx: ctx, - Req: req, - EnableCollectExecutionInfo: c.enableCollectExecutionInfo, - Bo: bo.TiKVBackoffer(), - }) - if retry { - // TODO: If we want to retry, we might need to redo the plan fragment cutting and task scheduling. 
https://github.com/pingcap/tidb/issues/31015 - logutil.BgLogger().Warn("mpp dispatch meet error and retrying", zap.Error(err), zap.Uint64("timestamp", c.startTS), zap.Int64("task", req.ID), zap.Int64("mpp-version", req.MppVersion.ToInt64())) - continue - } - break - } - - if err != nil { - logutil.BgLogger().Error("mpp dispatch meet error", zap.String("error", err.Error()), zap.Uint64("timestamp", req.StartTs), zap.Int64("task", req.ID), zap.Int64("mpp-version", req.MppVersion.ToInt64())) - atomic.CompareAndSwapUint32(&c.dispatchFailed, 0, 1) - // if NeedTriggerFallback is true, we return timeout to trigger tikv's fallback - if c.needTriggerFallback { - err = derr.ErrTiFlashServerTimeout - } - c.sendError(err) - return - } - - if rpcResp.Error != nil { - logutil.BgLogger().Error("mpp dispatch response meet error", zap.String("error", rpcResp.Error.Msg), zap.Uint64("timestamp", req.StartTs), zap.Int64("task", req.ID), zap.Int64("task-mpp-version", req.MppVersion.ToInt64()), zap.Int64("error-mpp-version", rpcResp.Error.GetMppVersion())) - atomic.CompareAndSwapUint32(&c.dispatchFailed, 0, 1) - c.sendError(errors.New(rpcResp.Error.Msg)) - return - } - failpoint.Inject("mppNonRootTaskError", func(val failpoint.Value) { - if val.(bool) && !req.IsRoot { - time.Sleep(1 * time.Second) - atomic.CompareAndSwapUint32(&c.dispatchFailed, 0, 1) - c.sendError(derr.ErrTiFlashServerTimeout) - return - } - }) - if !req.IsRoot { - return - } - // only root task should establish a stream conn with tiFlash to receive result. - taskMeta := &mpp.TaskMeta{StartTs: req.StartTs, GatherId: c.gatherID, QueryTs: req.MppQueryID.QueryTs, LocalQueryId: req.MppQueryID.LocalQueryID, TaskId: req.ID, ServerId: req.MppQueryID.ServerID, - Address: req.Meta.GetAddress(), - MppVersion: req.MppVersion.ToInt64(), - ResourceGroupName: req.ResourceGroupName, - } - c.receiveResults(req, taskMeta, bo) -} - -// NOTE: We do not retry here, because retry is helpless when errors result from TiFlash or Network. If errors occur, the execution on TiFlash will finally stop after some minutes. -// This function is exclusively called, and only the first call succeeds sending Tasks and setting all Tasks as cancelled, while others will not work. -func (c *localMppCoordinator) cancelMppTasks() { - if len(c.mppReqs) == 0 { - return - } - usedStoreAddrs := make(map[string]bool) - c.mu.Lock() - // 1. One address will receive one cancel request, since cancel request cancels all mpp tasks within the same mpp gather - // 2. 
Cancel process will set all mpp task requests' states, thus if one request's state is Cancelled already, just return - if c.mppReqs[0].State == kv.MppTaskCancelled { - c.mu.Unlock() - return - } - for _, task := range c.mppReqs { - // get the store address of running tasks, - if task.State == kv.MppTaskRunning && !usedStoreAddrs[task.Meta.GetAddress()] { - usedStoreAddrs[task.Meta.GetAddress()] = true - } - task.State = kv.MppTaskCancelled - } - c.mu.Unlock() - c.sessionCtx.GetMPPClient().CancelMPPTasks(kv.CancelMPPTasksParam{StoreAddr: usedStoreAddrs, Reqs: c.mppReqs}) -} - -func (c *localMppCoordinator) receiveResults(req *kv.MPPDispatchRequest, taskMeta *mpp.TaskMeta, bo *backoff.Backoffer) { - stream, err := c.sessionCtx.GetMPPClient().EstablishMPPConns(kv.EstablishMPPConnsParam{Ctx: bo.GetCtx(), Req: req, TaskMeta: taskMeta}) - if err != nil { - // if NeedTriggerFallback is true, we return timeout to trigger tikv's fallback - if c.needTriggerFallback { - c.sendError(derr.ErrTiFlashServerTimeout) - } else { - c.sendError(err) - } - return - } - - defer stream.Close() - resp := stream.MPPDataPacket - if resp == nil { - return - } - - for { - err := c.handleMPPStreamResponse(bo, resp, req) - if err != nil { - c.sendError(err) - return - } - - resp, err = stream.Recv() - if err != nil { - if errors.Cause(err) == io.EOF { - return - } - - logutil.BgLogger().Info("mpp stream recv got error", zap.Error(err), zap.Uint64("timestamp", taskMeta.StartTs), - zap.Int64("task", taskMeta.TaskId), zap.Int64("mpp-version", taskMeta.MppVersion)) - - // if NeedTriggerFallback is true, we return timeout to trigger tikv's fallback - if c.needTriggerFallback { - c.sendError(derr.ErrTiFlashServerTimeout) - } else { - c.sendError(err) - } - return - } - } -} - -// ReportStatus implements MppCoordinator interface -func (c *localMppCoordinator) ReportStatus(info kv.ReportStatusRequest) error { - taskID := info.Request.Meta.TaskId - var errMsg string - if info.Request.Error != nil { - errMsg = info.Request.Error.Msg - } - executionInfo := new(tipb.TiFlashExecutionInfo) - err := executionInfo.Unmarshal(info.Request.GetData()) - if err != nil { - // since it is very corner case to reach here, and it won't cause forever hang due to not close reportStatusCh - return err - } - - c.mu.Lock() - defer c.mu.Unlock() - req, exists := c.reqMap[taskID] - if !exists { - return errors.Errorf("ReportMPPTaskStatus task not exists taskID: %d", taskID) - } - if req.receivedReport { - return errors.Errorf("ReportMPPTaskStatus task already received taskID: %d", taskID) - } - - req.receivedReport = true - if len(errMsg) > 0 { - req.errMsg = errMsg - if len(c.firstErrMsg) == 0 { - c.firstErrMsg = errMsg - } - } - - c.reportedReqCount++ - req.executionSummaries = executionInfo.GetExecutionSummaries() - if c.reportedReqCount == len(c.mppReqs) { - close(c.reportStatusCh) - } - return nil -} - -func (c *localMppCoordinator) handleAllReports() error { - if c.reportExecutionInfo && atomic.LoadUint32(&c.dispatchFailed) == 0 && atomic.CompareAndSwapUint32(&c.allReportsHandled, 0, 1) { - startTime := time.Now() - select { - case <-c.reportStatusCh: - metrics.MppCoordinatorLatencyRcvReport.Observe(float64(time.Since(startTime).Milliseconds())) - var recordedPlanIDs = make(map[int]int) - for _, report := range c.reqMap { - for _, detail := range report.executionSummaries { - if detail != nil && detail.TimeProcessedNs != nil && - detail.NumProducedRows != nil && detail.NumIterations != nil { - 
recordedPlanIDs[c.sessionCtx.GetSessionVars().StmtCtx.RuntimeStatsColl. - RecordOneCopTask(-1, kv.TiFlash.Name(), report.mppReq.Meta.GetAddress(), detail)] = 0 - } - } - if ruDetailsRaw := c.ctx.Value(clientutil.RUDetailsCtxKey); ruDetailsRaw != nil { - if err := execdetails.MergeTiFlashRUConsumption(report.executionSummaries, ruDetailsRaw.(*clientutil.RUDetails)); err != nil { - return err - } - } - } - distsql.FillDummySummariesForTiFlashTasks(c.sessionCtx.GetSessionVars().StmtCtx.RuntimeStatsColl, "", kv.TiFlash.Name(), c.planIDs, recordedPlanIDs) - case <-time.After(receiveReportTimeout): - metrics.MppCoordinatorStatsReportNotReceived.Inc() - logutil.BgLogger().Warn(fmt.Sprintf("Mpp coordinator not received all reports within %d seconds", int(receiveReportTimeout.Seconds())), - zap.Uint64("txnStartTS", c.startTS), - zap.Uint64("gatherID", c.gatherID), - zap.Int("expectCount", len(c.mppReqs)), - zap.Int("actualCount", c.reportedReqCount)) - } - } - return nil -} - -// IsClosed implements MppCoordinator interface -func (c *localMppCoordinator) IsClosed() bool { - return atomic.LoadUint32(&c.closed) == 1 -} - -// Close implements MppCoordinator interface -// TODO: Test the case that user cancels the query. -func (c *localMppCoordinator) Close() error { - c.closeWithoutReport() - return c.handleAllReports() -} - -func (c *localMppCoordinator) closeWithoutReport() { - if atomic.CompareAndSwapUint32(&c.closed, 0, 1) { - close(c.finishCh) - } - c.cancelFunc() - <-c.wgDoneChan -} - -func (c *localMppCoordinator) handleMPPStreamResponse(bo *backoff.Backoffer, response *mpp.MPPDataPacket, req *kv.MPPDispatchRequest) (err error) { - if response.Error != nil { - c.mu.Lock() - firstErrMsg := c.firstErrMsg - c.mu.Unlock() - // firstErrMsg is only used when already received error response from root tasks, avoid confusing error messages - if len(firstErrMsg) > 0 { - err = errors.Errorf("other error for mpp stream: %s", firstErrMsg) - } else { - err = errors.Errorf("other error for mpp stream: %s", response.Error.Msg) - } - logutil.BgLogger().Warn("other error", - zap.Uint64("txnStartTS", req.StartTs), - zap.String("storeAddr", req.Meta.GetAddress()), - zap.Int64("mpp-version", req.MppVersion.ToInt64()), - zap.Int64("task-id", req.ID), - zap.Error(err)) - return err - } - - resp := &mppResponse{ - pbResp: response, - detail: new(copr.CopRuntimeStats), - } - - backoffTimes := bo.GetBackoffTimes() - resp.detail.BackoffTime = time.Duration(bo.GetTotalSleep()) * time.Millisecond - resp.detail.BackoffSleep = make(map[string]time.Duration, len(backoffTimes)) - resp.detail.BackoffTimes = make(map[string]int, len(backoffTimes)) - for backoff := range backoffTimes { - resp.detail.BackoffTimes[backoff] = backoffTimes[backoff] - resp.detail.BackoffSleep[backoff] = time.Duration(bo.GetBackoffSleepMS()[backoff]) * time.Millisecond - } - resp.detail.CalleeAddress = req.Meta.GetAddress() - c.sendToRespCh(resp) - return -} - -func (c *localMppCoordinator) nextImpl(ctx context.Context) (resp *mppResponse, ok bool, exit bool, err error) { - ticker := time.NewTicker(3 * time.Second) - defer ticker.Stop() - for { - select { - case resp, ok = <-c.respChan: - return - case <-ticker.C: - if c.vars != nil && c.vars.Killed != nil { - killed := atomic.LoadUint32(c.vars.Killed) - if killed != 0 { - logutil.Logger(ctx).Info( - "a killed signal is received", - zap.Uint32("signal", killed), - ) - err = derr.ErrQueryInterrupted - exit = true - return - } - } - case <-c.finishCh: - exit = true - return - case <-ctx.Done(): - if 
atomic.CompareAndSwapUint32(&c.closed, 0, 1) { - close(c.finishCh) - } - exit = true - return - } - } -} - -// Next implements MppCoordinator interface -func (c *localMppCoordinator) Next(ctx context.Context) (kv.ResultSubset, error) { - resp, ok, closed, err := c.nextImpl(ctx) - if err != nil { - return nil, errors.Trace(err) - } - if !ok || closed { - return nil, nil - } - - if resp.err != nil { - return nil, errors.Trace(resp.err) - } - - err = c.sessionCtx.GetMPPClient().CheckVisibility(c.startTS) - if err != nil { - return nil, errors.Trace(derr.ErrQueryInterrupted) - } - return resp, nil -} - -// Execute implements MppCoordinator interface -func (c *localMppCoordinator) Execute(ctx context.Context) (kv.Response, []kv.KeyRange, error) { - // TODO: Move the construct tasks logic to planner, so we can see the explain results. - sender := c.originalPlan.(*plannercore.PhysicalExchangeSender) - sctx := c.sessionCtx - frags, kvRanges, nodeInfo, err := plannercore.GenerateRootMPPTasks(sctx, c.startTS, c.gatherID, c.mppQueryID, sender, c.is) - if err != nil { - return nil, nil, errors.Trace(err) - } - if nodeInfo == nil { - return nil, nil, errors.New("node info should not be nil") - } - c.nodeCnt = len(nodeInfo) - - for _, frag := range frags { - err = c.appendMPPDispatchReq(frag) - if err != nil { - return nil, nil, errors.Trace(err) - } - } - failpoint.Inject("checkTotalMPPTasks", func(val failpoint.Value) { - if val.(int) != len(c.mppReqs) { - failpoint.Return(nil, nil, errors.Errorf("The number of tasks is not right, expect %d tasks but actually there are %d tasks", val.(int), len(c.mppReqs))) - } - }) - - ctx = distsql.WithSQLKvExecCounterInterceptor(ctx, sctx.GetSessionVars().StmtCtx.KvExecCounter) - _, allowTiFlashFallback := sctx.GetSessionVars().AllowFallbackToTiKV[kv.TiFlash] - ctx = distsql.SetTiFlashConfVarsInContext(ctx, sctx.GetDistSQLCtx()) - c.needTriggerFallback = allowTiFlashFallback - c.enableCollectExecutionInfo = config.GetGlobalConfig().Instance.EnableCollectExecutionInfo.Load() - - var ctxChild context.Context - ctxChild, c.cancelFunc = context.WithCancel(ctx) - go c.dispatchAll(ctxChild) - - return c, kvRanges, nil -} - -// GetNodeCnt returns the node count that involved in the mpp computation. 
-func (c *localMppCoordinator) GetNodeCnt() int { - return c.nodeCnt -} diff --git a/pkg/executor/internal/pdhelper/binding__failpoint_binding__.go b/pkg/executor/internal/pdhelper/binding__failpoint_binding__.go deleted file mode 100644 index d67d83e08b252..0000000000000 --- a/pkg/executor/internal/pdhelper/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package pdhelper - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/executor/internal/pdhelper/pd.go b/pkg/executor/internal/pdhelper/pd.go index 7a4f8c099ccbf..4893791d14e3f 100644 --- a/pkg/executor/internal/pdhelper/pd.go +++ b/pkg/executor/internal/pdhelper/pd.go @@ -93,13 +93,13 @@ func getApproximateTableCountFromStorage( return 0, false } regionStats, err := helper.NewHelper(tikvStore).GetPDRegionStats(ctx, tid, true) - if _, _err_ := failpoint.Eval(_curpkg_("calcSampleRateByStorageCount")); _err_ == nil { + failpoint.Inject("calcSampleRateByStorageCount", func() { // Force the TiDB thinking that there's PD and the count of region is small. err = nil regionStats.Count = 1 // Set a very large approximate count. regionStats.StorageKeys = 1000000 - } + }) if err != nil { return 0, false } diff --git a/pkg/executor/internal/pdhelper/pd.go__failpoint_stash__ b/pkg/executor/internal/pdhelper/pd.go__failpoint_stash__ deleted file mode 100644 index 4893791d14e3f..0000000000000 --- a/pkg/executor/internal/pdhelper/pd.go__failpoint_stash__ +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pdhelper - -import ( - "context" - "strconv" - "strings" - "sync" - "time" - - "github.com/jellydator/ttlcache/v3" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/store/helper" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/sqlescape" -) - -// GlobalPDHelper is the global variable for PDHelper. -var GlobalPDHelper = defaultPDHelper() -var globalPDHelperOnce sync.Once - -// PDHelper is used to get some information from PD. 
-type PDHelper struct { - cacheForApproximateTableCountFromStorage *ttlcache.Cache[string, float64] - - getApproximateTableCountFromStorageFunc func(ctx context.Context, sctx sessionctx.Context, tid int64, dbName, tableName, partitionName string) (float64, bool) - wg util.WaitGroupWrapper -} - -func defaultPDHelper() *PDHelper { - cache := ttlcache.New[string, float64]( - ttlcache.WithTTL[string, float64](30*time.Second), - ttlcache.WithCapacity[string, float64](1024*1024), - ) - return &PDHelper{ - cacheForApproximateTableCountFromStorage: cache, - getApproximateTableCountFromStorageFunc: getApproximateTableCountFromStorage, - } -} - -// Start is used to start the background task of PDHelper. Currently, the background task is used to clean up TTL cache. -func (p *PDHelper) Start() { - globalPDHelperOnce.Do(func() { - p.wg.Run(p.cacheForApproximateTableCountFromStorage.Start) - }) -} - -// Stop stops the background task of PDHelper. -func (p *PDHelper) Stop() { - p.cacheForApproximateTableCountFromStorage.Stop() - p.wg.Wait() -} - -func approximateTableCountKey(tid int64, dbName, tableName, partitionName string) string { - return strings.Join([]string{strconv.FormatInt(tid, 10), dbName, tableName, partitionName}, "_") -} - -// GetApproximateTableCountFromStorage gets the approximate count of the table. -func (p *PDHelper) GetApproximateTableCountFromStorage( - ctx context.Context, sctx sessionctx.Context, - tid int64, dbName, tableName, partitionName string, -) (float64, bool) { - key := approximateTableCountKey(tid, dbName, tableName, partitionName) - if item := p.cacheForApproximateTableCountFromStorage.Get(key); item != nil { - return item.Value(), true - } - result, hasPD := p.getApproximateTableCountFromStorageFunc(ctx, sctx, tid, dbName, tableName, partitionName) - p.cacheForApproximateTableCountFromStorage.Set(key, result, ttlcache.DefaultTTL) - return result, hasPD -} - -func getApproximateTableCountFromStorage( - ctx context.Context, sctx sessionctx.Context, - tid int64, dbName, tableName, partitionName string, -) (float64, bool) { - tikvStore, ok := sctx.GetStore().(helper.Storage) - if !ok { - return 0, false - } - regionStats, err := helper.NewHelper(tikvStore).GetPDRegionStats(ctx, tid, true) - failpoint.Inject("calcSampleRateByStorageCount", func() { - // Force the TiDB thinking that there's PD and the count of region is small. - err = nil - regionStats.Count = 1 - // Set a very large approximate count. - regionStats.StorageKeys = 1000000 - }) - if err != nil { - return 0, false - } - // If this table is not small, we directly use the count from PD, - // since for a small table, it's possible that it's data is in the same region with part of another large table. - // Thus, we use the number of the regions of the table's table KV to decide whether the table is small. - if regionStats.Count > 2 { - return float64(regionStats.StorageKeys), true - } - // Otherwise, we use count(*) to calc it's size, since it's very small, the table data can be filled in no more than 2 regions. - sql := new(strings.Builder) - sqlescape.MustFormatSQL(sql, "select count(*) from %n.%n", dbName, tableName) - if partitionName != "" { - sqlescape.MustFormatSQL(sql, " partition(%n)", partitionName) - } - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnStats) - rows, _, err := sctx.GetRestrictedSQLExecutor().ExecRestrictedSQL(ctx, nil, sql.String()) - if err != nil { - return 0, false - } - // If the record set is nil, there's something wrong with the execution. 
The COUNT(*) would always return one row. - if len(rows) == 0 || rows[0].Len() == 0 { - return 0, false - } - return float64(rows[0].GetInt64(0)), true -} diff --git a/pkg/executor/join/base_join_probe.go b/pkg/executor/join/base_join_probe.go index 58face22ee483..64a815bafdc51 100644 --- a/pkg/executor/join/base_join_probe.go +++ b/pkg/executor/join/base_join_probe.go @@ -262,11 +262,11 @@ func (j *baseJoinProbe) finishLookupCurrentProbeRow() { func checkSQLKiller(killer *sqlkiller.SQLKiller, fpName string) error { err := killer.HandleSignal() - if val, _err_ := failpoint.Eval(_curpkg_(fpName)); _err_ == nil { + failpoint.Inject(fpName, func(val failpoint.Value) { if val.(bool) { err = exeerrors.ErrQueryInterrupted } - } + }) return err } diff --git a/pkg/executor/join/base_join_probe.go__failpoint_stash__ b/pkg/executor/join/base_join_probe.go__failpoint_stash__ deleted file mode 100644 index 64a815bafdc51..0000000000000 --- a/pkg/executor/join/base_join_probe.go__failpoint_stash__ +++ /dev/null @@ -1,593 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package join - -import ( - "bytes" - "hash/fnv" - "unsafe" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" - "github.com/pingcap/tidb/pkg/util/hack" - "github.com/pingcap/tidb/pkg/util/sqlkiller" -) - -type keyMode int - -const ( - // OneInt64 mean the key contains only one Int64 - OneInt64 keyMode = iota - // FixedSerializedKey mean the key has fixed length - FixedSerializedKey - // VariableSerializedKey mean the key has variable length - VariableSerializedKey -) - -const batchBuildRowSize = 32 - -func (hCtx *HashJoinCtxV2) hasOtherCondition() bool { - return hCtx.OtherCondition != nil -} - -// ProbeV2 is the interface used to do probe in hash join v2 -type ProbeV2 interface { - // SetChunkForProbe will do some pre-work when start probing a chunk - SetChunkForProbe(chunk *chunk.Chunk) error - // Probe is to probe current chunk, the result chunk is set in result.chk, and Probe need to make sure result.chk.NumRows() <= result.chk.RequiredRows() - Probe(joinResult *hashjoinWorkerResult, sqlKiller *sqlkiller.SQLKiller) (ok bool, result *hashjoinWorkerResult) - // IsCurrentChunkProbeDone returns true if current probe chunk is all probed - IsCurrentChunkProbeDone() bool - // ScanRowTable is called after all the probe chunks are probed. 
It is used in some special joins, like left outer join with left side to build, after all - // the probe side chunks are handled, it needs to scan the row table to return the un-matched rows - ScanRowTable(joinResult *hashjoinWorkerResult, sqlKiller *sqlkiller.SQLKiller) (result *hashjoinWorkerResult) - // IsScanRowTableDone returns true after scan row table is done - IsScanRowTableDone() bool - // NeedScanRowTable returns true if current join need to scan row table after all the probe side chunks are handled - NeedScanRowTable() bool - // InitForScanRowTable do some pre-work before ScanRowTable, it must be called before ScanRowTable - InitForScanRowTable() - // Return probe collsion - GetProbeCollision() uint64 - // Reset probe collsion - ResetProbeCollision() -} - -type offsetAndLength struct { - offset int - length int -} - -type matchedRowInfo struct { - // probeRowIndex mean the probe side index of current matched row - probeRowIndex int - // buildRowStart mean the build row start of the current matched row - buildRowStart uintptr - // buildRowOffset mean the current offset of current BuildRow, used to construct column data from BuildRow - buildRowOffset int -} - -func createMatchRowInfo(probeRowIndex int, buildRowStart unsafe.Pointer) *matchedRowInfo { - ret := &matchedRowInfo{probeRowIndex: probeRowIndex} - *(*unsafe.Pointer)(unsafe.Pointer(&ret.buildRowStart)) = buildRowStart - return ret -} - -type posAndHashValue struct { - hashValue uint64 - pos int -} - -type baseJoinProbe struct { - ctx *HashJoinCtxV2 - workID uint - - currentChunk *chunk.Chunk - // if currentChunk.Sel() == nil, then construct a fake selRows - selRows []int - usedRows []int - // matchedRowsHeaders, serializedKeys is indexed by logical row index - matchedRowsHeaders []uintptr // the start address of each matched rows - serializedKeys [][]byte // used for save serialized keys - // filterVector and nullKeyVector is indexed by physical row index because the return vector of VectorizedFilter is based on physical row index - filterVector []bool // if there is filter before probe, filterVector saves the filter result - nullKeyVector []bool // nullKeyVector[i] = true if any of the key is null - hashValues [][]posAndHashValue // the start address of each matched rows - currentProbeRow int - matchedRowsForCurrentProbeRow int - chunkRows int - cachedBuildRows []*matchedRowInfo - - keyIndex []int - keyTypes []*types.FieldType - hasNullableKey bool - maxChunkSize int - rightAsBuildSide bool - // lUsed/rUsed show which columns are used by father for left child and right child. - // NOTE: - // 1. lUsed/rUsed should never be nil. - // 2. no columns are used if lUsed/rUsed is not nil but the size of lUsed/rUsed is 0. 
- lUsed, rUsed []int - lUsedInOtherCondition, rUsedInOtherCondition []int - // used when construct column from probe side - offsetAndLengthArray []offsetAndLength - // these 3 variables are used for join that has other condition, should be inited when the join has other condition - tmpChk *chunk.Chunk - rowIndexInfos []*matchedRowInfo - selected []bool - - probeCollision uint64 -} - -func (j *baseJoinProbe) GetProbeCollision() uint64 { - return j.probeCollision -} - -func (j *baseJoinProbe) ResetProbeCollision() { - j.probeCollision = 0 -} - -func (j *baseJoinProbe) IsCurrentChunkProbeDone() bool { - return j.currentChunk == nil || j.currentProbeRow >= j.chunkRows -} - -func (j *baseJoinProbe) finishCurrentLookupLoop(joinedChk *chunk.Chunk) { - if len(j.cachedBuildRows) > 0 { - j.batchConstructBuildRows(joinedChk, 0, j.ctx.hasOtherCondition()) - } - j.finishLookupCurrentProbeRow() - j.appendProbeRowToChunk(joinedChk, j.currentChunk) -} - -func (j *baseJoinProbe) SetChunkForProbe(chk *chunk.Chunk) (err error) { - if j.currentChunk != nil { - if j.currentProbeRow < j.chunkRows { - return errors.New("Previous chunk is not probed yet") - } - } - j.currentChunk = chk - logicalRows := chk.NumRows() - // if chk.sel != nil, then physicalRows is different from logicalRows - physicalRows := chk.Column(0).Rows() - j.usedRows = chk.Sel() - if j.usedRows == nil { - if cap(j.selRows) >= logicalRows { - j.selRows = j.selRows[:logicalRows] - } else { - j.selRows = make([]int, 0, logicalRows) - for i := 0; i < logicalRows; i++ { - j.selRows = append(j.selRows, i) - } - } - j.usedRows = j.selRows - } - j.chunkRows = logicalRows - if cap(j.matchedRowsHeaders) >= logicalRows { - j.matchedRowsHeaders = j.matchedRowsHeaders[:logicalRows] - } else { - j.matchedRowsHeaders = make([]uintptr, logicalRows) - } - for i := 0; i < int(j.ctx.partitionNumber); i++ { - j.hashValues[i] = j.hashValues[i][:0] - } - if j.ctx.ProbeFilter != nil { - if cap(j.filterVector) >= physicalRows { - j.filterVector = j.filterVector[:physicalRows] - } else { - j.filterVector = make([]bool, physicalRows) - } - } - if j.hasNullableKey { - if cap(j.nullKeyVector) >= physicalRows { - j.nullKeyVector = j.nullKeyVector[:physicalRows] - } else { - j.nullKeyVector = make([]bool, physicalRows) - } - for i := 0; i < physicalRows; i++ { - j.nullKeyVector[i] = false - } - } - if cap(j.serializedKeys) >= logicalRows { - j.serializedKeys = j.serializedKeys[:logicalRows] - } else { - j.serializedKeys = make([][]byte, logicalRows) - } - for i := 0; i < logicalRows; i++ { - j.serializedKeys[i] = j.serializedKeys[i][:0] - } - if j.ctx.ProbeFilter != nil { - j.filterVector, err = expression.VectorizedFilter(j.ctx.SessCtx.GetExprCtx().GetEvalCtx(), j.ctx.SessCtx.GetSessionVars().EnableVectorizedExpression, j.ctx.ProbeFilter, chunk.NewIterator4Chunk(j.currentChunk), j.filterVector) - if err != nil { - return err - } - } - - // generate serialized key - for i, index := range j.keyIndex { - err = codec.SerializeKeys(j.ctx.SessCtx.GetSessionVars().StmtCtx.TypeCtx(), j.currentChunk, j.keyTypes[i], index, j.usedRows, j.filterVector, j.nullKeyVector, j.ctx.hashTableMeta.serializeModes[i], j.serializedKeys) - if err != nil { - return err - } - } - // generate hash value - hash := fnv.New64() - for logicalRowIndex, physicalRowIndex := range j.usedRows { - if (j.filterVector != nil && !j.filterVector[physicalRowIndex]) || (j.nullKeyVector != nil && j.nullKeyVector[physicalRowIndex]) { - // explicit set the matchedRowsHeaders[logicalRowIndex] to nil to indicate there 
is no matched rows - j.matchedRowsHeaders[logicalRowIndex] = 0 - continue - } - hash.Reset() - // As the golang doc described, `Hash.Write` never returns an error. - // See https://golang.org/pkg/hash/#Hash - _, _ = hash.Write(j.serializedKeys[logicalRowIndex]) - hashValue := hash.Sum64() - partIndex := hashValue >> j.ctx.partitionMaskOffset - j.hashValues[partIndex] = append(j.hashValues[partIndex], posAndHashValue{hashValue: hashValue, pos: logicalRowIndex}) - } - j.currentProbeRow = 0 - for i := 0; i < int(j.ctx.partitionNumber); i++ { - for index := range j.hashValues[i] { - j.matchedRowsHeaders[j.hashValues[i][index].pos] = j.ctx.hashTableContext.hashTable.tables[i].lookup(j.hashValues[i][index].hashValue) - } - } - return -} - -func (j *baseJoinProbe) finishLookupCurrentProbeRow() { - if j.matchedRowsForCurrentProbeRow > 0 { - j.offsetAndLengthArray = append(j.offsetAndLengthArray, offsetAndLength{offset: j.usedRows[j.currentProbeRow], length: j.matchedRowsForCurrentProbeRow}) - } - j.matchedRowsForCurrentProbeRow = 0 -} - -func checkSQLKiller(killer *sqlkiller.SQLKiller, fpName string) error { - err := killer.HandleSignal() - failpoint.Inject(fpName, func(val failpoint.Value) { - if val.(bool) { - err = exeerrors.ErrQueryInterrupted - } - }) - return err -} - -func (j *baseJoinProbe) appendBuildRowToCachedBuildRowsAndConstructBuildRowsIfNeeded(buildRow *matchedRowInfo, chk *chunk.Chunk, currentColumnIndexInRow int, forOtherCondition bool) { - j.cachedBuildRows = append(j.cachedBuildRows, buildRow) - if len(j.cachedBuildRows) >= batchBuildRowSize { - j.batchConstructBuildRows(chk, currentColumnIndexInRow, forOtherCondition) - } -} - -func (j *baseJoinProbe) batchConstructBuildRows(chk *chunk.Chunk, currentColumnIndexInRow int, forOtherCondition bool) { - j.appendBuildRowToChunk(chk, currentColumnIndexInRow, forOtherCondition) - if forOtherCondition { - j.rowIndexInfos = append(j.rowIndexInfos, j.cachedBuildRows...) 
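The probe preparation shown above hashes every serialized join key with FNV-64 and routes the row to a partition using the high bits of the hash (partIndex := hashValue >> partitionMaskOffset). Here is a standalone sketch of that routing; the eight-partition setup and the way the mask offset is derived are assumptions for illustration, not values taken from this patch.

package main

import (
	"fmt"
	"hash/fnv"
	"math/bits"
)

func main() {
	// Assume 8 partitions and that the mask offset keeps only the top log2(partitionNumber)
	// bits of the 64-bit hash, so partIndex always falls in [0, partitionNumber).
	const partitionNumber = 8
	maskOffset := 64 - bits.TrailingZeros64(partitionNumber) // 61 when partitionNumber is 8

	serializedKey := []byte("example-serialized-join-key")

	h := fnv.New64()
	_, _ = h.Write(serializedKey) // Hash.Write never returns an error
	hashValue := h.Sum64()

	partIndex := hashValue >> maskOffset
	fmt.Printf("hash=%#x partition=%d\n", hashValue, partIndex)
}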
- } - j.cachedBuildRows = j.cachedBuildRows[:0] -} - -func (j *baseJoinProbe) prepareForProbe(chk *chunk.Chunk) (joinedChk *chunk.Chunk, remainCap int, err error) { - j.offsetAndLengthArray = j.offsetAndLengthArray[:0] - j.cachedBuildRows = j.cachedBuildRows[:0] - j.matchedRowsForCurrentProbeRow = 0 - joinedChk = chk - if j.ctx.OtherCondition != nil { - j.tmpChk.Reset() - j.rowIndexInfos = j.rowIndexInfos[:0] - j.selected = j.selected[:0] - joinedChk = j.tmpChk - } - return joinedChk, chk.RequiredRows() - chk.NumRows(), nil -} - -func (j *baseJoinProbe) appendBuildRowToChunk(chk *chunk.Chunk, currentColumnIndexInRow int, forOtherCondition bool) { - if j.rightAsBuildSide { - if forOtherCondition { - j.appendBuildRowToChunkInternal(chk, j.rUsedInOtherCondition, true, j.currentChunk.NumCols(), currentColumnIndexInRow) - } else { - j.appendBuildRowToChunkInternal(chk, j.rUsed, false, len(j.lUsed), currentColumnIndexInRow) - } - } else { - if forOtherCondition { - j.appendBuildRowToChunkInternal(chk, j.lUsedInOtherCondition, true, 0, currentColumnIndexInRow) - } else { - j.appendBuildRowToChunkInternal(chk, j.lUsed, false, 0, currentColumnIndexInRow) - } - } -} - -func (j *baseJoinProbe) appendBuildRowToChunkInternal(chk *chunk.Chunk, usedCols []int, forOtherCondition bool, colOffset int, currentColumnInRow int) { - chkRows := chk.NumRows() - needUpdateVirtualRow := currentColumnInRow == 0 - if len(usedCols) == 0 || len(j.cachedBuildRows) == 0 { - if needUpdateVirtualRow { - chk.SetNumVirtualRows(chkRows + len(j.cachedBuildRows)) - } - return - } - for i := 0; i < len(j.cachedBuildRows); i++ { - if j.cachedBuildRows[i].buildRowOffset == 0 { - j.ctx.hashTableMeta.advanceToRowData(j.cachedBuildRows[i]) - } - } - colIndexMap := make(map[int]int) - for index, value := range usedCols { - if forOtherCondition { - colIndexMap[value] = value + colOffset - } else { - colIndexMap[value] = index + colOffset - } - } - meta := j.ctx.hashTableMeta - columnsToAppend := len(meta.rowColumnsOrder) - if forOtherCondition { - columnsToAppend = meta.columnCountNeededForOtherCondition - if j.ctx.RightAsBuildSide { - for _, value := range j.rUsed { - colIndexMap[value] = value + colOffset - } - } else { - for _, value := range j.lUsed { - colIndexMap[value] = value + colOffset - } - } - } - for columnIndex := currentColumnInRow; columnIndex < len(meta.rowColumnsOrder) && columnIndex < columnsToAppend; columnIndex++ { - indexInDstChk, ok := colIndexMap[meta.rowColumnsOrder[columnIndex]] - var currentColumn *chunk.Column - if ok { - currentColumn = chk.Column(indexInDstChk) - for index := range j.cachedBuildRows { - currentColumn.AppendNullBitmap(!meta.isColumnNull(*(*unsafe.Pointer)(unsafe.Pointer(&j.cachedBuildRows[index].buildRowStart)), columnIndex)) - j.cachedBuildRows[index].buildRowOffset = chunk.AppendCellFromRawData(currentColumn, *(*unsafe.Pointer)(unsafe.Pointer(&j.cachedBuildRows[index].buildRowStart)), j.cachedBuildRows[index].buildRowOffset) - } - } else { - // not used so don't need to insert into chk, but still need to advance rowData - if meta.columnsSize[columnIndex] < 0 { - for index := range j.cachedBuildRows { - size := *(*uint64)(unsafe.Add(*(*unsafe.Pointer)(unsafe.Pointer(&j.cachedBuildRows[index].buildRowStart)), j.cachedBuildRows[index].buildRowOffset)) - j.cachedBuildRows[index].buildRowOffset += sizeOfLengthField + int(size) - } - } else { - for index := range j.cachedBuildRows { - j.cachedBuildRows[index].buildRowOffset += meta.columnsSize[columnIndex] - } - } - } - } - if 
needUpdateVirtualRow { - chk.SetNumVirtualRows(chkRows + len(j.cachedBuildRows)) - } -} - -func (j *baseJoinProbe) appendProbeRowToChunk(chk *chunk.Chunk, probeChk *chunk.Chunk) { - if j.rightAsBuildSide { - if j.ctx.hasOtherCondition() { - j.appendProbeRowToChunkInternal(chk, probeChk, j.lUsedInOtherCondition, 0, true) - } else { - j.appendProbeRowToChunkInternal(chk, probeChk, j.lUsed, 0, false) - } - } else { - if j.ctx.hasOtherCondition() { - j.appendProbeRowToChunkInternal(chk, probeChk, j.rUsedInOtherCondition, j.ctx.hashTableMeta.totalColumnNumber, true) - } else { - j.appendProbeRowToChunkInternal(chk, probeChk, j.rUsed, len(j.lUsed), false) - } - } -} - -func (j *baseJoinProbe) appendProbeRowToChunkInternal(chk *chunk.Chunk, probeChk *chunk.Chunk, used []int, collOffset int, forOtherCondition bool) { - if len(used) == 0 || len(j.offsetAndLengthArray) == 0 { - return - } - if forOtherCondition { - usedColumnMap := make(map[int]struct{}) - for _, colIndex := range used { - if _, ok := usedColumnMap[colIndex]; !ok { - srcCol := probeChk.Column(colIndex) - dstCol := chk.Column(colIndex + collOffset) - for _, offsetAndLength := range j.offsetAndLengthArray { - dstCol.AppendCellNTimes(srcCol, offsetAndLength.offset, offsetAndLength.length) - } - usedColumnMap[colIndex] = struct{}{} - } - } - } else { - for index, colIndex := range used { - srcCol := probeChk.Column(colIndex) - dstCol := chk.Column(index + collOffset) - for _, offsetAndLength := range j.offsetAndLengthArray { - dstCol.AppendCellNTimes(srcCol, offsetAndLength.offset, offsetAndLength.length) - } - } - } -} - -func (j *baseJoinProbe) buildResultAfterOtherCondition(chk *chunk.Chunk, joinedChk *chunk.Chunk) (err error) { - // construct the return chunk based on joinedChk and selected, there are 3 kinds of columns - // 1. columns already in joinedChk - // 2. columns from build side, but not in joinedChk - // 3. 
columns from probe side, but not in joinedChk - rowCount := chk.NumRows() - probeUsedColumns, probeColOffset, probeColOffsetInJoinedChk := j.lUsed, 0, 0 - if !j.rightAsBuildSide { - probeUsedColumns, probeColOffset, probeColOffsetInJoinedChk = j.rUsed, len(j.lUsed), j.ctx.hashTableMeta.totalColumnNumber - } - - for index, colIndex := range probeUsedColumns { - dstCol := chk.Column(index + probeColOffset) - if joinedChk.Column(colIndex+probeColOffsetInJoinedChk).Rows() > 0 { - // probe column that is already in joinedChk - srcCol := joinedChk.Column(colIndex + probeColOffsetInJoinedChk) - chunk.CopySelectedRows(dstCol, srcCol, j.selected) - } else { - // probe column that is not in joinedChk - srcCol := j.currentChunk.Column(colIndex) - chunk.CopySelectedRowsWithRowIDFunc(dstCol, srcCol, j.selected, 0, len(j.selected), func(i int) int { - return j.usedRows[j.rowIndexInfos[i].probeRowIndex] - }) - } - } - buildUsedColumns, buildColOffset, buildColOffsetInJoinedChk := j.lUsed, 0, 0 - if j.rightAsBuildSide { - buildUsedColumns, buildColOffset, buildColOffsetInJoinedChk = j.rUsed, len(j.lUsed), j.currentChunk.NumCols() - } - hasRemainCols := false - for index, colIndex := range buildUsedColumns { - dstCol := chk.Column(index + buildColOffset) - srcCol := joinedChk.Column(colIndex + buildColOffsetInJoinedChk) - if srcCol.Rows() > 0 { - // build column that is already in joinedChk - chunk.CopySelectedRows(dstCol, srcCol, j.selected) - } else { - hasRemainCols = true - } - } - if hasRemainCols { - j.cachedBuildRows = j.cachedBuildRows[:0] - // build column that is not in joinedChk - for index, result := range j.selected { - if result { - j.appendBuildRowToCachedBuildRowsAndConstructBuildRowsIfNeeded(j.rowIndexInfos[index], chk, j.ctx.hashTableMeta.columnCountNeededForOtherCondition, false) - } - } - if len(j.cachedBuildRows) > 0 { - j.batchConstructBuildRows(chk, j.ctx.hashTableMeta.columnCountNeededForOtherCondition, false) - } - } - rowsAdded := 0 - for _, result := range j.selected { - if result { - rowsAdded++ - } - } - chk.SetNumVirtualRows(rowCount + rowsAdded) - return -} - -func isKeyMatched(keyMode keyMode, serializedKey []byte, rowStart unsafe.Pointer, meta *TableMeta) bool { - switch keyMode { - case OneInt64: - return *(*int64)(unsafe.Pointer(&serializedKey[0])) == *(*int64)(unsafe.Add(rowStart, meta.nullMapLength+sizeOfNextPtr)) - case FixedSerializedKey: - return bytes.Equal(serializedKey, hack.GetBytesFromPtr(unsafe.Add(rowStart, meta.nullMapLength+sizeOfNextPtr), meta.joinKeysLength)) - case VariableSerializedKey: - return bytes.Equal(serializedKey, hack.GetBytesFromPtr(unsafe.Add(rowStart, meta.nullMapLength+sizeOfNextPtr+sizeOfLengthField), int(meta.getSerializedKeyLength(rowStart)))) - default: - panic("unknown key match type") - } -} - -// NewJoinProbe create a join probe used for hash join v2 -func NewJoinProbe(ctx *HashJoinCtxV2, workID uint, joinType core.JoinType, keyIndex []int, joinedColumnTypes, probeKeyTypes []*types.FieldType, rightAsBuildSide bool) ProbeV2 { - base := baseJoinProbe{ - ctx: ctx, - workID: workID, - keyIndex: keyIndex, - keyTypes: probeKeyTypes, - maxChunkSize: ctx.SessCtx.GetSessionVars().MaxChunkSize, - lUsed: ctx.LUsed, - rUsed: ctx.RUsed, - lUsedInOtherCondition: ctx.LUsedInOtherCondition, - rUsedInOtherCondition: ctx.RUsedInOtherCondition, - rightAsBuildSide: rightAsBuildSide, - } - for i := range keyIndex { - if !mysql.HasNotNullFlag(base.keyTypes[i].GetFlag()) { - base.hasNullableKey = true - } - } - base.cachedBuildRows = make([]*matchedRowInfo, 
0, batchBuildRowSize) - base.matchedRowsHeaders = make([]uintptr, 0, chunk.InitialCapacity) - base.selRows = make([]int, 0, chunk.InitialCapacity) - for i := 0; i < chunk.InitialCapacity; i++ { - base.selRows = append(base.selRows, i) - } - base.hashValues = make([][]posAndHashValue, ctx.partitionNumber) - for i := 0; i < int(ctx.partitionNumber); i++ { - base.hashValues[i] = make([]posAndHashValue, 0, chunk.InitialCapacity) - } - base.serializedKeys = make([][]byte, 0, chunk.InitialCapacity) - if base.ctx.ProbeFilter != nil { - base.filterVector = make([]bool, 0, chunk.InitialCapacity) - } - if base.hasNullableKey { - base.nullKeyVector = make([]bool, 0, chunk.InitialCapacity) - } - if base.ctx.OtherCondition != nil { - base.tmpChk = chunk.NewChunkWithCapacity(joinedColumnTypes, chunk.InitialCapacity) - base.tmpChk.SetInCompleteChunk(true) - base.selected = make([]bool, 0, chunk.InitialCapacity) - base.rowIndexInfos = make([]*matchedRowInfo, 0, chunk.InitialCapacity) - } - switch joinType { - case core.InnerJoin: - return &innerJoinProbe{base} - case core.LeftOuterJoin: - return newOuterJoinProbe(base, !rightAsBuildSide, rightAsBuildSide) - case core.RightOuterJoin: - return newOuterJoinProbe(base, rightAsBuildSide, rightAsBuildSide) - default: - panic("unsupported join type") - } -} - -type mockJoinProbe struct { - baseJoinProbe -} - -func (*mockJoinProbe) SetChunkForProbe(*chunk.Chunk) error { - return errors.New("not supported") -} - -func (*mockJoinProbe) Probe(*hashjoinWorkerResult, *sqlkiller.SQLKiller) (ok bool, result *hashjoinWorkerResult) { - panic("not supported") -} - -func (*mockJoinProbe) ScanRowTable(*hashjoinWorkerResult, *sqlkiller.SQLKiller) (result *hashjoinWorkerResult) { - panic("not supported") -} - -func (*mockJoinProbe) IsScanRowTableDone() bool { - panic("not supported") -} - -func (*mockJoinProbe) NeedScanRowTable() bool { - panic("not supported") -} - -func (*mockJoinProbe) InitForScanRowTable() { - panic("not supported") -} - -// used for test -func newMockJoinProbe(ctx *HashJoinCtxV2) *mockJoinProbe { - base := baseJoinProbe{ - ctx: ctx, - lUsed: ctx.LUsed, - rUsed: ctx.RUsed, - lUsedInOtherCondition: ctx.LUsedInOtherCondition, - rUsedInOtherCondition: ctx.RUsedInOtherCondition, - rightAsBuildSide: false, - } - return &mockJoinProbe{base} -} diff --git a/pkg/executor/join/binding__failpoint_binding__.go b/pkg/executor/join/binding__failpoint_binding__.go deleted file mode 100644 index 6560aaf35d7c5..0000000000000 --- a/pkg/executor/join/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package join - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/executor/join/hash_join_base.go b/pkg/executor/join/hash_join_base.go index ecf5586e34114..7947b96653e6c 100644 --- a/pkg/executor/join/hash_join_base.go +++ b/pkg/executor/join/hash_join_base.go @@ -168,7 +168,7 @@ func (fetcher *probeSideTupleFetcherBase) fetchProbeSideChunks(ctx context.Conte } probeSideResult := probeSideResource.chk err := exec.Next(ctx, fetcher.ProbeSideExec, probeSideResult) - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + failpoint.Inject("ConsumeRandomPanic", nil) if err != nil { hashJoinCtx.joinResultCh <- &hashjoinWorkerResult{ err: err, @@ -176,11 +176,11 @@ func (fetcher 
*probeSideTupleFetcherBase) fetchProbeSideChunks(ctx context.Conte return } if !hasWaitedForBuild { - if val, _err_ := failpoint.Eval(_curpkg_("issue30289")); _err_ == nil { + failpoint.Inject("issue30289", func(val failpoint.Value) { if val.(bool) { probeSideResult.Reset() } - } + }) skipProbe := wait4BuildSide(isBuildEmpty, canSkipIfBuildEmpty, needScanAfterProbeDone, hashJoinCtx) if skipProbe { // there is no need to probe, so just return @@ -223,14 +223,14 @@ type buildWorkerBase struct { func (w *buildWorkerBase) fetchBuildSideRows(ctx context.Context, hashJoinCtx *hashJoinCtxBase, chkCh chan<- *chunk.Chunk, errCh chan<- error, doneCh <-chan struct{}) { defer close(chkCh) var err error - if val, _err_ := failpoint.Eval(_curpkg_("issue30289")); _err_ == nil { + failpoint.Inject("issue30289", func(val failpoint.Value) { if val.(bool) { err = errors.Errorf("issue30289 build return error") errCh <- errors.Trace(err) return } - } - if val, _err_ := failpoint.Eval(_curpkg_("issue42662_1")); _err_ == nil { + }) + failpoint.Inject("issue42662_1", func(val failpoint.Value) { if val.(bool) { if hashJoinCtx.SessCtx.GetSessionVars().ConnectionID != 0 { // consume 170MB memory, this sql should be tracked into MemoryTop1Tracker @@ -238,30 +238,30 @@ func (w *buildWorkerBase) fetchBuildSideRows(ctx context.Context, hashJoinCtx *h } return } - } + }) sessVars := hashJoinCtx.SessCtx.GetSessionVars() - if val, _err_ := failpoint.Eval(_curpkg_("issue51998")); _err_ == nil { + failpoint.Inject("issue51998", func(val failpoint.Value) { if val.(bool) { time.Sleep(2 * time.Second) } - } + }) for { if hashJoinCtx.finished.Load() { return } chk := hashJoinCtx.ChunkAllocPool.Alloc(w.BuildSideExec.RetFieldTypes(), sessVars.MaxChunkSize, sessVars.MaxChunkSize) err = exec.Next(ctx, w.BuildSideExec, chk) - if val, _err_ := failpoint.Eval(_curpkg_("issue51998")); _err_ == nil { + failpoint.Inject("issue51998", func(val failpoint.Value) { if val.(bool) { err = errors.Errorf("issue51998 build return error") } - } + }) if err != nil { errCh <- errors.Trace(err) return } - failpoint.Eval(_curpkg_("errorFetchBuildSideRowsMockOOMPanic")) - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + failpoint.Inject("errorFetchBuildSideRowsMockOOMPanic", nil) + failpoint.Inject("ConsumeRandomPanic", nil) if chk.NumRows() == 0 { return } diff --git a/pkg/executor/join/hash_join_base.go__failpoint_stash__ b/pkg/executor/join/hash_join_base.go__failpoint_stash__ deleted file mode 100644 index 7947b96653e6c..0000000000000 --- a/pkg/executor/join/hash_join_base.go__failpoint_stash__ +++ /dev/null @@ -1,379 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
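// Illustrative sketch, not from the TiDB source: the relationship between the two failpoint
// forms visible in the hunks above. The marker form, failpoint.Inject, is what stays in the
// source tree; the failpoint tool rewrites it into the generated failpoint.Eval(_curpkg_(...))
// form together with the *__failpoint_stash__ and binding files that this patch deletes.
// The package name and failpoint name below are made up for the example.
package example

import (
	"github.com/pingcap/errors"
	"github.com/pingcap/failpoint"
)

func fetchRows() error {
	var err error
	// Marker form as committed to the repository. It does nothing until the failpoint tool
	// rewrites it and a test enables the failpoint, e.g. with
	//   failpoint.Enable("path/to/example/mockFetchError", "return(true)")
	failpoint.Inject("mockFetchError", func(val failpoint.Value) {
		if val.(bool) {
			err = errors.New("mockFetchError injected error")
		}
	})
	return err
}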
- -package join - -import ( - "bytes" - "context" - "fmt" - "strconv" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/disk" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/memory" -) - -// hashjoinWorkerResult stores the result of join workers, -// `src` is for Chunk reuse: the main goroutine will get the join result chunk `chk`, -// and push `chk` into `src` after processing, join worker goroutines get the empty chunk from `src` -// and push new data into this chunk. -type hashjoinWorkerResult struct { - chk *chunk.Chunk - err error - src chan<- *chunk.Chunk -} - -type hashJoinCtxBase struct { - SessCtx sessionctx.Context - ChunkAllocPool chunk.Allocator - // Concurrency is the number of partition, build and join workers. - Concurrency uint - joinResultCh chan *hashjoinWorkerResult - // closeCh add a lock for closing executor. - closeCh chan struct{} - finished atomic.Bool - IsNullEQ []bool - buildFinished chan error - JoinType plannercore.JoinType - IsNullAware bool - memTracker *memory.Tracker // track memory usage. - diskTracker *disk.Tracker // track disk usage. -} - -type probeSideTupleFetcherBase struct { - ProbeSideExec exec.Executor - probeChkResourceCh chan *probeChkResource - probeResultChs []chan *chunk.Chunk - requiredRows int64 - joinResultChannel chan *hashjoinWorkerResult -} - -func (fetcher *probeSideTupleFetcherBase) initializeForProbeBase(concurrency uint, joinResultChannel chan *hashjoinWorkerResult) { - // fetcher.probeResultChs is for transmitting the chunks which store the data of - // ProbeSideExec, it'll be written by probe side worker goroutine, and read by join - // workers. - fetcher.probeResultChs = make([]chan *chunk.Chunk, concurrency) - for i := uint(0); i < concurrency; i++ { - fetcher.probeResultChs[i] = make(chan *chunk.Chunk, 1) - } - // fetcher.probeChkResourceCh is for transmitting the used ProbeSideExec chunks from - // join workers to ProbeSideExec worker. 
- fetcher.probeChkResourceCh = make(chan *probeChkResource, concurrency) - for i := uint(0); i < concurrency; i++ { - fetcher.probeChkResourceCh <- &probeChkResource{ - chk: exec.NewFirstChunk(fetcher.ProbeSideExec), - dest: fetcher.probeResultChs[i], - } - } - fetcher.joinResultChannel = joinResultChannel -} - -func (fetcher *probeSideTupleFetcherBase) handleProbeSideFetcherPanic(r any) { - for i := range fetcher.probeResultChs { - close(fetcher.probeResultChs[i]) - } - if r != nil { - fetcher.joinResultChannel <- &hashjoinWorkerResult{err: util.GetRecoverError(r)} - } -} - -type isBuildSideEmpty func() bool - -func wait4BuildSide(isBuildEmpty isBuildSideEmpty, canSkipIfBuildEmpty, needScanAfterProbeDone bool, hashJoinCtx *hashJoinCtxBase) (skipProbe bool) { - var err error - skipProbe = false - buildFinishes := false - select { - case <-hashJoinCtx.closeCh: - // current executor is closed, no need to probe - skipProbe = true - case err = <-hashJoinCtx.buildFinished: - if err != nil { - // build meet error, no need to probe - skipProbe = true - } else { - buildFinishes = true - } - } - // only check build empty if build finishes - if buildFinishes && isBuildEmpty() && canSkipIfBuildEmpty { - // if build side is empty, can skip probe if canSkipIfBuildEmpty is true(e.g. inner join) - skipProbe = true - } - if err != nil { - // if err is not nil, send out the error - hashJoinCtx.joinResultCh <- &hashjoinWorkerResult{ - err: err, - } - } else if skipProbe { - // if skipProbe is true and there is no need to scan hash table after probe, just the whole hash join is finished - if !needScanAfterProbeDone { - hashJoinCtx.finished.Store(true) - } - } - return skipProbe -} - -func (fetcher *probeSideTupleFetcherBase) getProbeSideResource(shouldLimitProbeFetchSize bool, maxChunkSize int, hashJoinCtx *hashJoinCtxBase) *probeChkResource { - if hashJoinCtx.finished.Load() { - return nil - } - - var probeSideResource *probeChkResource - var ok bool - select { - case <-hashJoinCtx.closeCh: - return nil - case probeSideResource, ok = <-fetcher.probeChkResourceCh: - if !ok { - return nil - } - } - if shouldLimitProbeFetchSize { - required := int(atomic.LoadInt64(&fetcher.requiredRows)) - probeSideResource.chk.SetRequiredRows(required, maxChunkSize) - } - return probeSideResource -} - -// fetchProbeSideChunks get chunks from fetches chunks from the big table in a background goroutine -// and sends the chunks to multiple channels which will be read by multiple join workers. 
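// Illustrative sketch, not from the TiDB source: the shape of the wait4BuildSide decision above,
// reduced to plain channels. A probe goroutine blocks until either the executor is closed or the
// build side reports completion; probing is skipped when the build errored, the executor closed,
// or the build side turned out to be empty for a join type that cannot emit rows from an empty
// build side (e.g. inner join). All names here are hypothetical.
package main

import (
	"errors"
	"fmt"
)

func shouldSkipProbe(closeCh <-chan struct{}, buildFinished <-chan error,
	buildEmpty func() bool, canSkipIfBuildEmpty bool) (skip bool, err error) {
	select {
	case <-closeCh:
		// Executor closed: nothing left to probe.
		return true, nil
	case err = <-buildFinished:
		if err != nil {
			// Build failed: surface the error and skip probing.
			return true, err
		}
	}
	if buildEmpty() && canSkipIfBuildEmpty {
		// Empty build side and a join type that needs at least one build row.
		return true, nil
	}
	return false, nil
}

func main() {
	closeCh := make(chan struct{})
	buildFinished := make(chan error, 1)
	buildFinished <- errors.New("build side failed")
	skip, err := shouldSkipProbe(closeCh, buildFinished, func() bool { return false }, true)
	fmt.Println(skip, err)
}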
-func (fetcher *probeSideTupleFetcherBase) fetchProbeSideChunks(ctx context.Context, maxChunkSize int, isBuildEmpty isBuildSideEmpty, canSkipIfBuildEmpty, needScanAfterProbeDone, shouldLimitProbeFetchSize bool, hashJoinCtx *hashJoinCtxBase) { - hasWaitedForBuild := false - for { - probeSideResource := fetcher.getProbeSideResource(shouldLimitProbeFetchSize, maxChunkSize, hashJoinCtx) - if probeSideResource == nil { - return - } - probeSideResult := probeSideResource.chk - err := exec.Next(ctx, fetcher.ProbeSideExec, probeSideResult) - failpoint.Inject("ConsumeRandomPanic", nil) - if err != nil { - hashJoinCtx.joinResultCh <- &hashjoinWorkerResult{ - err: err, - } - return - } - if !hasWaitedForBuild { - failpoint.Inject("issue30289", func(val failpoint.Value) { - if val.(bool) { - probeSideResult.Reset() - } - }) - skipProbe := wait4BuildSide(isBuildEmpty, canSkipIfBuildEmpty, needScanAfterProbeDone, hashJoinCtx) - if skipProbe { - // there is no need to probe, so just return - return - } - hasWaitedForBuild = true - } - - if probeSideResult.NumRows() == 0 { - return - } - - probeSideResource.dest <- probeSideResult - } -} - -type probeWorkerBase struct { - WorkerID uint - probeChkResourceCh chan *probeChkResource - joinChkResourceCh chan *chunk.Chunk - probeResultCh chan *chunk.Chunk -} - -func (worker *probeWorkerBase) initializeForProbe(probeChkResourceCh chan *probeChkResource, probeResultCh chan *chunk.Chunk, joinExec exec.Executor) { - // worker.joinChkResourceCh is for transmitting the reused join result chunks - // from the main thread to probe worker goroutines. - worker.joinChkResourceCh = make(chan *chunk.Chunk, 1) - worker.joinChkResourceCh <- exec.NewFirstChunk(joinExec) - worker.probeChkResourceCh = probeChkResourceCh - worker.probeResultCh = probeResultCh -} - -type buildWorkerBase struct { - BuildSideExec exec.Executor - BuildKeyColIdx []int -} - -// fetchBuildSideRows fetches all rows from build side executor, and append them -// to e.buildSideResult. 
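// Illustrative sketch, not from the TiDB source: the chunk-recycling pattern behind
// joinChkResourceCh / probeChkResourceCh and the `src` field of hashjoinWorkerResult above,
// reduced to a buffered channel that passes the same buffer back and forth between a worker
// and the consumer, so no per-batch allocation is needed. The batch type and sizes are
// placeholders.
package main

import "fmt"

type resultBatch struct {
	rows []int
	src  chan<- *resultBatch // where to return the batch once it has been consumed
}

func main() {
	recycle := make(chan *resultBatch, 1)
	recycle <- &resultBatch{rows: make([]int, 0, 4), src: recycle} // one pre-allocated batch

	results := make(chan *resultBatch)

	// Worker: take an empty batch from the recycle channel, fill it, publish it.
	go func() {
		for i := 0; i < 3; i++ {
			batch := <-recycle
			batch.rows = append(batch.rows[:0], i, i+1)
			results <- batch
		}
		close(results)
	}()

	// Consumer: use the batch, then hand the same memory back through src.
	for batch := range results {
		fmt.Println(batch.rows)
		batch.src <- batch
	}
}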
-func (w *buildWorkerBase) fetchBuildSideRows(ctx context.Context, hashJoinCtx *hashJoinCtxBase, chkCh chan<- *chunk.Chunk, errCh chan<- error, doneCh <-chan struct{}) { - defer close(chkCh) - var err error - failpoint.Inject("issue30289", func(val failpoint.Value) { - if val.(bool) { - err = errors.Errorf("issue30289 build return error") - errCh <- errors.Trace(err) - return - } - }) - failpoint.Inject("issue42662_1", func(val failpoint.Value) { - if val.(bool) { - if hashJoinCtx.SessCtx.GetSessionVars().ConnectionID != 0 { - // consume 170MB memory, this sql should be tracked into MemoryTop1Tracker - hashJoinCtx.memTracker.Consume(170 * 1024 * 1024) - } - return - } - }) - sessVars := hashJoinCtx.SessCtx.GetSessionVars() - failpoint.Inject("issue51998", func(val failpoint.Value) { - if val.(bool) { - time.Sleep(2 * time.Second) - } - }) - for { - if hashJoinCtx.finished.Load() { - return - } - chk := hashJoinCtx.ChunkAllocPool.Alloc(w.BuildSideExec.RetFieldTypes(), sessVars.MaxChunkSize, sessVars.MaxChunkSize) - err = exec.Next(ctx, w.BuildSideExec, chk) - failpoint.Inject("issue51998", func(val failpoint.Value) { - if val.(bool) { - err = errors.Errorf("issue51998 build return error") - } - }) - if err != nil { - errCh <- errors.Trace(err) - return - } - failpoint.Inject("errorFetchBuildSideRowsMockOOMPanic", nil) - failpoint.Inject("ConsumeRandomPanic", nil) - if chk.NumRows() == 0 { - return - } - select { - case <-doneCh: - return - case <-hashJoinCtx.closeCh: - return - case chkCh <- chk: - } - } -} - -// probeChkResource stores the result of the join probe side fetch worker, -// `dest` is for Chunk reuse: after join workers process the probe side chunk which is read from `dest`, -// they'll store the used chunk as `chk`, and then the probe side fetch worker will put new data into `chk` and write `chk` into dest. -type probeChkResource struct { - chk *chunk.Chunk - dest chan<- *chunk.Chunk -} - -type hashJoinRuntimeStats struct { - fetchAndBuildHashTable time.Duration - hashStat hashStatistic - fetchAndProbe int64 - probe int64 - concurrent int - maxFetchAndProbe int64 -} - -func (e *hashJoinRuntimeStats) setMaxFetchAndProbeTime(t int64) { - for { - value := atomic.LoadInt64(&e.maxFetchAndProbe) - if t <= value { - return - } - if atomic.CompareAndSwapInt64(&e.maxFetchAndProbe, value, t) { - return - } - } -} - -// Tp implements the RuntimeStats interface. 
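// Illustrative sketch, not from the TiDB source: the lock-free "keep the maximum" update used by
// setMaxFetchAndProbeTime above. Several workers report their elapsed time, the shared value only
// ever moves upward, and the CAS loop retries when another goroutine won the race. Names are
// hypothetical.
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// storeMaxInt64 atomically raises *addr to v if v is larger than the current value.
func storeMaxInt64(addr *int64, v int64) {
	for {
		cur := atomic.LoadInt64(addr)
		if v <= cur {
			return
		}
		if atomic.CompareAndSwapInt64(addr, cur, v) {
			return
		}
	}
}

func main() {
	var maxElapsed int64
	var wg sync.WaitGroup
	for _, d := range []int64{120, 75, 340, 260} {
		wg.Add(1)
		go func(d int64) {
			defer wg.Done()
			storeMaxInt64(&maxElapsed, d)
		}(d)
	}
	wg.Wait()
	fmt.Println("max fetch-and-probe:", atomic.LoadInt64(&maxElapsed))
}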
-func (*hashJoinRuntimeStats) Tp() int { - return execdetails.TpHashJoinRuntimeStats -} - -func (e *hashJoinRuntimeStats) String() string { - buf := bytes.NewBuffer(make([]byte, 0, 128)) - if e.fetchAndBuildHashTable > 0 { - buf.WriteString("build_hash_table:{total:") - buf.WriteString(execdetails.FormatDuration(e.fetchAndBuildHashTable)) - buf.WriteString(", fetch:") - buf.WriteString(execdetails.FormatDuration(e.fetchAndBuildHashTable - e.hashStat.buildTableElapse)) - buf.WriteString(", build:") - buf.WriteString(execdetails.FormatDuration(e.hashStat.buildTableElapse)) - buf.WriteString("}") - } - if e.probe > 0 { - buf.WriteString(", probe:{concurrency:") - buf.WriteString(strconv.Itoa(e.concurrent)) - buf.WriteString(", total:") - buf.WriteString(execdetails.FormatDuration(time.Duration(e.fetchAndProbe))) - buf.WriteString(", max:") - buf.WriteString(execdetails.FormatDuration(time.Duration(atomic.LoadInt64(&e.maxFetchAndProbe)))) - buf.WriteString(", probe:") - buf.WriteString(execdetails.FormatDuration(time.Duration(e.probe))) - // fetch time is the time wait fetch result from its child executor, - // wait time is the time wait its parent executor to fetch the joined result - buf.WriteString(", fetch and wait:") - buf.WriteString(execdetails.FormatDuration(time.Duration(e.fetchAndProbe - e.probe))) - if e.hashStat.probeCollision > 0 { - buf.WriteString(", probe_collision:") - buf.WriteString(strconv.FormatInt(e.hashStat.probeCollision, 10)) - } - buf.WriteString("}") - } - return buf.String() -} - -func (e *hashJoinRuntimeStats) Clone() execdetails.RuntimeStats { - return &hashJoinRuntimeStats{ - fetchAndBuildHashTable: e.fetchAndBuildHashTable, - hashStat: e.hashStat, - fetchAndProbe: e.fetchAndProbe, - probe: e.probe, - concurrent: e.concurrent, - maxFetchAndProbe: e.maxFetchAndProbe, - } -} - -func (e *hashJoinRuntimeStats) Merge(rs execdetails.RuntimeStats) { - tmp, ok := rs.(*hashJoinRuntimeStats) - if !ok { - return - } - e.fetchAndBuildHashTable += tmp.fetchAndBuildHashTable - e.hashStat.buildTableElapse += tmp.hashStat.buildTableElapse - e.hashStat.probeCollision += tmp.hashStat.probeCollision - e.fetchAndProbe += tmp.fetchAndProbe - e.probe += tmp.probe - if e.maxFetchAndProbe < tmp.maxFetchAndProbe { - e.maxFetchAndProbe = tmp.maxFetchAndProbe - } -} - -type hashStatistic struct { - // NOTE: probeCollision may be accessed from multiple goroutines concurrently. 
- probeCollision int64 - buildTableElapse time.Duration -} - -func (s *hashStatistic) String() string { - return fmt.Sprintf("probe_collision:%v, build:%v", s.probeCollision, execdetails.FormatDuration(s.buildTableElapse)) -} diff --git a/pkg/executor/join/hash_join_v1.go b/pkg/executor/join/hash_join_v1.go index d414a594f8116..649eac1eb466b 100644 --- a/pkg/executor/join/hash_join_v1.go +++ b/pkg/executor/join/hash_join_v1.go @@ -330,7 +330,7 @@ func (w *ProbeWorkerV1) runJoinWorker() { return case probeSideResult, ok = <-w.probeResultCh: } - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + failpoint.Inject("ConsumeRandomPanic", nil) if !ok { break } @@ -859,11 +859,11 @@ func (w *ProbeWorkerV1) join2Chunk(probeSideChk *chunk.Chunk, hCtx *HashContext, for i := range selected { err := w.HashJoinCtx.SessCtx.GetSessionVars().SQLKiller.HandleSignal() - if val, _err_ := failpoint.Eval(_curpkg_("killedInJoin2Chunk")); _err_ == nil { + failpoint.Inject("killedInJoin2Chunk", func(val failpoint.Value) { if val.(bool) { err = exeerrors.ErrQueryInterrupted } - } + }) if err != nil { joinResult.err = err return false, waitTime, joinResult @@ -938,11 +938,11 @@ func (w *ProbeWorkerV1) join2ChunkForOuterHashJoin(probeSideChk *chunk.Chunk, hC } for i := 0; i < probeSideChk.NumRows(); i++ { err := w.HashJoinCtx.SessCtx.GetSessionVars().SQLKiller.HandleSignal() - if val, _err_ := failpoint.Eval(_curpkg_("killedInJoin2ChunkForOuterHashJoin")); _err_ == nil { + failpoint.Inject("killedInJoin2ChunkForOuterHashJoin", func(val failpoint.Value) { if val.(bool) { err = exeerrors.ErrQueryInterrupted } - } + }) if err != nil { joinResult.err = err return false, waitTime, joinResult @@ -1073,12 +1073,12 @@ func (w *BuildWorkerV1) BuildHashTableForList(buildSideResultCh <-chan *chunk.Ch rowContainer.GetDiskTracker().SetLabel(memory.LabelForBuildSideResult) if variable.EnableTmpStorageOnOOM.Load() { actionSpill := rowContainer.ActionSpill() - if val, _err_ := failpoint.Eval(_curpkg_("testRowContainerSpill")); _err_ == nil { + failpoint.Inject("testRowContainerSpill", func(val failpoint.Value) { if val.(bool) { actionSpill = rowContainer.rowContainer.ActionSpillForTest() defer actionSpill.(*chunk.SpillDiskAction).WaitForTest() } - } + }) w.HashJoinCtx.SessCtx.GetSessionVars().MemTracker.FallbackOldAndSetNewAction(actionSpill) } for chk := range buildSideResultCh { @@ -1101,7 +1101,7 @@ func (w *BuildWorkerV1) BuildHashTableForList(buildSideResultCh <-chan *chunk.Ch err = rowContainer.PutChunkSelected(chk, selected, w.HashJoinCtx.IsNullEQ) } } - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + failpoint.Inject("ConsumeRandomPanic", nil) if err != nil { return err } diff --git a/pkg/executor/join/hash_join_v1.go__failpoint_stash__ b/pkg/executor/join/hash_join_v1.go__failpoint_stash__ deleted file mode 100644 index 649eac1eb466b..0000000000000 --- a/pkg/executor/join/hash_join_v1.go__failpoint_stash__ +++ /dev/null @@ -1,1434 +0,0 @@ -// Copyright 2016 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package join - -import ( - "bytes" - "context" - "fmt" - "runtime/trace" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/executor/aggregate" - "github.com/pingcap/tidb/pkg/executor/internal/applycache" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/executor/unionexec" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/parser/terror" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/bitmap" - "github.com/pingcap/tidb/pkg/util/channel" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" - "github.com/pingcap/tidb/pkg/util/disk" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/memory" -) - -var ( - _ exec.Executor = &HashJoinV1Exec{} - _ exec.Executor = &NestedLoopApplyExec{} -) - -// HashJoinCtxV1 is the context used in hash join -type HashJoinCtxV1 struct { - hashJoinCtxBase - UseOuterToBuild bool - IsOuterJoin bool - RowContainer *hashRowContainer - outerMatchedStatus []*bitmap.ConcurrentBitmap - ProbeTypes []*types.FieldType - BuildTypes []*types.FieldType - OuterFilter expression.CNFExprs - stats *hashJoinRuntimeStats -} - -// ProbeSideTupleFetcherV1 reads tuples from ProbeSideExec and send them to ProbeWorkers. -type ProbeSideTupleFetcherV1 struct { - probeSideTupleFetcherBase - *HashJoinCtxV1 -} - -// ProbeWorkerV1 is the probe side worker in hash join -type ProbeWorkerV1 struct { - probeWorkerBase - HashJoinCtx *HashJoinCtxV1 - ProbeKeyColIdx []int - ProbeNAKeyColIdx []int - // We pre-alloc and reuse the Rows and RowPtrs for each probe goroutine, to avoid allocation frequently - buildSideRows []chunk.Row - buildSideRowPtrs []chunk.RowPtr - - // We build individual joiner for each join worker when use chunk-based - // execution, to avoid the concurrency of joiner.chk and joiner.selected. - Joiner Joiner - rowIters *chunk.Iterator4Slice - rowContainerForProbe *hashRowContainer - // for every naaj probe worker, pre-allocate the int slice for store the join column index to check. - needCheckBuildColPos []int - needCheckProbeColPos []int - needCheckBuildTypes []*types.FieldType - needCheckProbeTypes []*types.FieldType -} - -// BuildWorkerV1 is the build side worker in hash join -type BuildWorkerV1 struct { - buildWorkerBase - HashJoinCtx *HashJoinCtxV1 - BuildNAKeyColIdx []int -} - -// HashJoinV1Exec implements the hash join algorithm. -type HashJoinV1Exec struct { - exec.BaseExecutor - *HashJoinCtxV1 - - ProbeSideTupleFetcher *ProbeSideTupleFetcherV1 - ProbeWorkers []*ProbeWorkerV1 - BuildWorker *BuildWorkerV1 - - workerWg util.WaitGroupWrapper - waiterWg util.WaitGroupWrapper - - Prepared bool -} - -// Close implements the Executor Close interface. 
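// Illustrative sketch, not from the TiDB source: why ProbeWorkerV1 above keeps buildSideRows /
// buildSideRowPtrs (and its own Joiner) per worker. Each probe goroutine owns its scratch buffers
// and resets them with s[:0] between probe rows, so nothing is shared across goroutines and
// nothing is reallocated per call. Types and sizes below are placeholders.
package main

import (
	"fmt"
	"sync"
)

type probeWorker struct {
	id       int
	matchBuf []int // reused across probe calls, never shared with other workers
}

func (w *probeWorker) probe(keys []int) int {
	total := 0
	for _, k := range keys {
		w.matchBuf = w.matchBuf[:0] // reuse the backing array instead of allocating
		for m := 0; m < k%3; m++ {
			w.matchBuf = append(w.matchBuf, k)
		}
		total += len(w.matchBuf)
	}
	return total
}

func main() {
	keys := []int{1, 2, 3, 4, 5, 6, 7, 8}
	var wg sync.WaitGroup
	for i := 0; i < 2; i++ {
		w := &probeWorker{id: i, matchBuf: make([]int, 0, 8)}
		wg.Add(1)
		go func(w *probeWorker) {
			defer wg.Done()
			fmt.Printf("worker %d matched %d rows\n", w.id, w.probe(keys))
		}(w)
	}
	wg.Wait()
}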
-func (e *HashJoinV1Exec) Close() error { - if e.closeCh != nil { - close(e.closeCh) - } - e.finished.Store(true) - if e.Prepared { - if e.buildFinished != nil { - channel.Clear(e.buildFinished) - } - if e.joinResultCh != nil { - channel.Clear(e.joinResultCh) - } - if e.ProbeSideTupleFetcher.probeChkResourceCh != nil { - close(e.ProbeSideTupleFetcher.probeChkResourceCh) - channel.Clear(e.ProbeSideTupleFetcher.probeChkResourceCh) - } - for i := range e.ProbeSideTupleFetcher.probeResultChs { - channel.Clear(e.ProbeSideTupleFetcher.probeResultChs[i]) - } - for i := range e.ProbeWorkers { - close(e.ProbeWorkers[i].joinChkResourceCh) - channel.Clear(e.ProbeWorkers[i].joinChkResourceCh) - } - e.ProbeSideTupleFetcher.probeChkResourceCh = nil - terror.Call(e.RowContainer.Close) - e.HashJoinCtxV1.SessCtx.GetSessionVars().MemTracker.UnbindActionFromHardLimit(e.RowContainer.ActionSpill()) - e.waiterWg.Wait() - } - e.outerMatchedStatus = e.outerMatchedStatus[:0] - for _, w := range e.ProbeWorkers { - w.buildSideRows = nil - w.buildSideRowPtrs = nil - w.needCheckBuildColPos = nil - w.needCheckProbeColPos = nil - w.needCheckBuildTypes = nil - w.needCheckProbeTypes = nil - w.joinChkResourceCh = nil - } - - if e.stats != nil && e.RowContainer != nil { - e.stats.hashStat = *e.RowContainer.stat - } - if e.stats != nil { - defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), e.stats) - } - err := e.BaseExecutor.Close() - return err -} - -// Open implements the Executor Open interface. -func (e *HashJoinV1Exec) Open(ctx context.Context) error { - if err := e.BaseExecutor.Open(ctx); err != nil { - e.closeCh = nil - e.Prepared = false - return err - } - e.Prepared = false - if e.HashJoinCtxV1.memTracker != nil { - e.HashJoinCtxV1.memTracker.Reset() - } else { - e.HashJoinCtxV1.memTracker = memory.NewTracker(e.ID(), -1) - } - e.HashJoinCtxV1.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) - - if e.HashJoinCtxV1.diskTracker != nil { - e.HashJoinCtxV1.diskTracker.Reset() - } else { - e.HashJoinCtxV1.diskTracker = disk.NewTracker(e.ID(), -1) - } - e.HashJoinCtxV1.diskTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.DiskTracker) - - e.workerWg = util.WaitGroupWrapper{} - e.waiterWg = util.WaitGroupWrapper{} - e.closeCh = make(chan struct{}) - e.finished.Store(false) - - if e.RuntimeStats() != nil { - e.stats = &hashJoinRuntimeStats{ - concurrent: int(e.Concurrency), - } - } - return nil -} - -func (e *HashJoinV1Exec) initializeForProbe() { - e.ProbeSideTupleFetcher.HashJoinCtxV1 = e.HashJoinCtxV1 - // e.joinResultCh is for transmitting the join result chunks to the main - // thread. 
- e.joinResultCh = make(chan *hashjoinWorkerResult, e.Concurrency+1) - e.ProbeSideTupleFetcher.initializeForProbeBase(e.Concurrency, e.joinResultCh) - - for i := uint(0); i < e.Concurrency; i++ { - e.ProbeWorkers[i].initializeForProbe(e.ProbeSideTupleFetcher.probeChkResourceCh, e.ProbeSideTupleFetcher.probeResultChs[i], e) - } -} - -func (e *HashJoinV1Exec) fetchAndProbeHashTable(ctx context.Context) { - e.initializeForProbe() - e.workerWg.RunWithRecover(func() { - defer trace.StartRegion(ctx, "HashJoinProbeSideFetcher").End() - e.ProbeSideTupleFetcher.fetchProbeSideChunks(ctx, e.MaxChunkSize(), func() bool { - return e.ProbeSideTupleFetcher.RowContainer.Len() == uint64(0) - }, e.ProbeSideTupleFetcher.JoinType == plannercore.InnerJoin || e.ProbeSideTupleFetcher.JoinType == plannercore.SemiJoin, - false, e.ProbeSideTupleFetcher.IsOuterJoin, &e.ProbeSideTupleFetcher.hashJoinCtxBase) - }, e.ProbeSideTupleFetcher.handleProbeSideFetcherPanic) - - for i := uint(0); i < e.Concurrency; i++ { - workerID := i - e.workerWg.RunWithRecover(func() { - defer trace.StartRegion(ctx, "HashJoinWorker").End() - e.ProbeWorkers[workerID].runJoinWorker() - }, e.ProbeWorkers[workerID].handleProbeWorkerPanic) - } - e.waiterWg.RunWithRecover(e.waitJoinWorkersAndCloseResultChan, nil) -} - -func (w *ProbeWorkerV1) handleProbeWorkerPanic(r any) { - if r != nil { - w.HashJoinCtx.joinResultCh <- &hashjoinWorkerResult{err: util.GetRecoverError(r)} - } -} - -func (e *HashJoinV1Exec) handleJoinWorkerPanic(r any) { - if r != nil { - e.joinResultCh <- &hashjoinWorkerResult{err: util.GetRecoverError(r)} - } -} - -// Concurrently handling unmatched rows from the hash table -func (w *ProbeWorkerV1) handleUnmatchedRowsFromHashTable() { - ok, joinResult := w.getNewJoinResult() - if !ok { - return - } - numChks := w.rowContainerForProbe.NumChunks() - for i := int(w.WorkerID); i < numChks; i += int(w.HashJoinCtx.Concurrency) { - chk, err := w.rowContainerForProbe.GetChunk(i) - if err != nil { - // Catching the error and send it - joinResult.err = err - w.HashJoinCtx.joinResultCh <- joinResult - return - } - for j := 0; j < chk.NumRows(); j++ { - if !w.HashJoinCtx.outerMatchedStatus[i].UnsafeIsSet(j) { // process unmatched Outer rows - w.Joiner.OnMissMatch(false, chk.GetRow(j), joinResult.chk) - } - if joinResult.chk.IsFull() { - w.HashJoinCtx.joinResultCh <- joinResult - ok, joinResult = w.getNewJoinResult() - if !ok { - return - } - } - } - } - - if joinResult == nil { - return - } else if joinResult.err != nil || (joinResult.chk != nil && joinResult.chk.NumRows() > 0) { - w.HashJoinCtx.joinResultCh <- joinResult - } -} - -func (e *HashJoinV1Exec) waitJoinWorkersAndCloseResultChan() { - e.workerWg.Wait() - if e.UseOuterToBuild { - // Concurrently handling unmatched rows from the hash table at the tail - for i := uint(0); i < e.Concurrency; i++ { - var workerID = i - e.workerWg.RunWithRecover(func() { e.ProbeWorkers[workerID].handleUnmatchedRowsFromHashTable() }, e.handleJoinWorkerPanic) - } - e.workerWg.Wait() - } - close(e.joinResultCh) -} - -func (w *ProbeWorkerV1) runJoinWorker() { - probeTime := int64(0) - if w.HashJoinCtx.stats != nil { - start := time.Now() - defer func() { - t := time.Since(start) - atomic.AddInt64(&w.HashJoinCtx.stats.probe, probeTime) - atomic.AddInt64(&w.HashJoinCtx.stats.fetchAndProbe, int64(t)) - w.HashJoinCtx.stats.setMaxFetchAndProbeTime(int64(t)) - }() - } - - var ( - probeSideResult *chunk.Chunk - selected = make([]bool, 0, chunk.InitialCapacity) - ) - ok, joinResult := w.getNewJoinResult() - if 
!ok { - return - } - - // Read and filter probeSideResult, and join the probeSideResult with the build side rows. - emptyProbeSideResult := &probeChkResource{ - dest: w.probeResultCh, - } - hCtx := &HashContext{ - AllTypes: w.HashJoinCtx.ProbeTypes, - KeyColIdx: w.ProbeKeyColIdx, - NaKeyColIdx: w.ProbeNAKeyColIdx, - } - for ok := true; ok; { - if w.HashJoinCtx.finished.Load() { - break - } - select { - case <-w.HashJoinCtx.closeCh: - return - case probeSideResult, ok = <-w.probeResultCh: - } - failpoint.Inject("ConsumeRandomPanic", nil) - if !ok { - break - } - start := time.Now() - // waitTime is the time cost on w.sendingResult(), it should not be added to probe time, because if - // parent executor does not call `e.Next()`, `sendingResult()` will hang, and this hang has nothing to do - // with the probe - waitTime := int64(0) - if w.HashJoinCtx.UseOuterToBuild { - ok, waitTime, joinResult = w.join2ChunkForOuterHashJoin(probeSideResult, hCtx, joinResult) - } else { - ok, waitTime, joinResult = w.join2Chunk(probeSideResult, hCtx, joinResult, selected) - } - probeTime += int64(time.Since(start)) - waitTime - if !ok { - break - } - probeSideResult.Reset() - emptyProbeSideResult.chk = probeSideResult - w.probeChkResourceCh <- emptyProbeSideResult - } - // note joinResult.chk may be nil when getNewJoinResult fails in loops - if joinResult == nil { - return - } else if joinResult.err != nil || (joinResult.chk != nil && joinResult.chk.NumRows() > 0) { - w.HashJoinCtx.joinResultCh <- joinResult - } else if joinResult.chk != nil && joinResult.chk.NumRows() == 0 { - w.joinChkResourceCh <- joinResult.chk - } -} - -func (w *ProbeWorkerV1) joinMatchedProbeSideRow2ChunkForOuterHashJoin(probeKey uint64, probeSideRow chunk.Row, hCtx *HashContext, joinResult *hashjoinWorkerResult) (bool, int64, *hashjoinWorkerResult) { - var err error - waitTime := int64(0) - oneWaitTime := int64(0) - w.buildSideRows, w.buildSideRowPtrs, err = w.rowContainerForProbe.GetMatchedRowsAndPtrs(probeKey, probeSideRow, hCtx, w.buildSideRows, w.buildSideRowPtrs, true) - buildSideRows, rowsPtrs := w.buildSideRows, w.buildSideRowPtrs - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - if len(buildSideRows) == 0 { - return true, waitTime, joinResult - } - - iter := w.rowIters - iter.Reset(buildSideRows) - var outerMatchStatus []outerRowStatusFlag - rowIdx, ok := 0, false - for iter.Begin(); iter.Current() != iter.End(); { - outerMatchStatus, err = w.Joiner.TryToMatchOuters(iter, probeSideRow, joinResult.chk, outerMatchStatus) - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - for i := range outerMatchStatus { - if outerMatchStatus[i] == outerRowMatched { - w.HashJoinCtx.outerMatchedStatus[rowsPtrs[rowIdx+i].ChkIdx].Set(int(rowsPtrs[rowIdx+i].RowIdx)) - } - } - rowIdx += len(outerMatchStatus) - if joinResult.chk.IsFull() { - ok, oneWaitTime, joinResult = w.sendingResult(joinResult) - waitTime += oneWaitTime - if !ok { - return false, waitTime, joinResult - } - } - } - return true, waitTime, joinResult -} - -// joinNAALOSJMatchProbeSideRow2Chunk implement the matching logic for NA-AntiLeftOuterSemiJoin -func (w *ProbeWorkerV1) joinNAALOSJMatchProbeSideRow2Chunk(probeKey uint64, probeKeyNullBits *bitmap.ConcurrentBitmap, probeSideRow chunk.Row, hCtx *HashContext, joinResult *hashjoinWorkerResult) (bool, int64, *hashjoinWorkerResult) { - var ( - err error - ok bool - ) - waitTime := int64(0) - oneWaitTime := int64(0) - if probeKeyNullBits == nil { - // step1: match the 
same key bucket first. - // because AntiLeftOuterSemiJoin cares about the scalar value. If we both have a match from null - // bucket and same key bucket, we should return the result as from same-key bucket - // rather than from null bucket. - w.buildSideRows, err = w.rowContainerForProbe.GetMatchedRows(probeKey, probeSideRow, hCtx, w.buildSideRows) - buildSideRows := w.buildSideRows - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - if len(buildSideRows) != 0 { - iter1 := w.rowIters - iter1.Reset(buildSideRows) - for iter1.Begin(); iter1.Current() != iter1.End(); { - matched, _, err := w.Joiner.TryToMatchInners(probeSideRow, iter1, joinResult.chk, LeftNotNullRightNotNull) - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - // here matched means: there is a valid same-key bucket row from right side. - // as said in the comment, once we meet a same key (NOT IN semantic) in CNF, we can determine the result as . - if matched { - return true, waitTime, joinResult - } - if joinResult.chk.IsFull() { - ok, oneWaitTime, joinResult = w.sendingResult(joinResult) - waitTime += oneWaitTime - if !ok { - return false, waitTime, joinResult - } - } - } - } - // step2: match the null bucket secondly. - w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes) - buildSideRows = w.buildSideRows - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - if len(buildSideRows) == 0 { - // when reach here, it means we couldn't find a valid same key match from same-key bucket yet - // and the null bucket is empty. so the result should be . - w.Joiner.OnMissMatch(false, probeSideRow, joinResult.chk) - return true, waitTime, joinResult - } - iter2 := w.rowIters - iter2.Reset(buildSideRows) - for iter2.Begin(); iter2.Current() != iter2.End(); { - matched, _, err := w.Joiner.TryToMatchInners(probeSideRow, iter2, joinResult.chk, LeftNotNullRightHasNull) - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - // here matched means: there is a valid null bucket row from right side. - // as said in the comment, once we meet a null in CNF, we can determine the result as . - if matched { - return true, waitTime, joinResult - } - if joinResult.chk.IsFull() { - ok, oneWaitTime, joinResult = w.sendingResult(joinResult) - waitTime += oneWaitTime - if !ok { - return false, waitTime, joinResult - } - } - } - // step3: if we couldn't return it quickly in null bucket and same key bucket, here means two cases: - // case1: x NOT IN (empty set): if other key bucket don't have the valid rows yet. - // case2: x NOT IN (l,m,n...): if other key bucket do have the valid rows. - // both cases mean the result should be - w.Joiner.OnMissMatch(false, probeSideRow, joinResult.chk) - return true, waitTime, joinResult - } - // when left side has null values, all we want is to find a valid build side rows (past other condition) - // so we can return it as soon as possible. here means two cases: - // case1: NOT IN (empty set): ----------------------> result is . - // case2: NOT IN (at least a valid inner row) ------------------> result is . 
- // Step1: match null bucket (assumption that null bucket is quite smaller than all hash table bucket rows) - w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes) - buildSideRows := w.buildSideRows - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - if len(buildSideRows) != 0 { - iter1 := w.rowIters - iter1.Reset(buildSideRows) - for iter1.Begin(); iter1.Current() != iter1.End(); { - matched, _, err := w.Joiner.TryToMatchInners(probeSideRow, iter1, joinResult.chk, LeftHasNullRightHasNull) - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - // here matched means: there is a valid null bucket row from right side. (not empty) - // as said in the comment, once we found at least a valid row, we can determine the result as . - if matched { - return true, waitTime, joinResult - } - if joinResult.chk.IsFull() { - ok, oneWaitTime, joinResult = w.sendingResult(joinResult) - waitTime += oneWaitTime - if !ok { - return false, waitTime, joinResult - } - } - } - } - // Step2: match all hash table bucket build rows (use probeKeyNullBits to filter if any). - w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes) - buildSideRows = w.buildSideRows - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - if len(buildSideRows) == 0 { - // when reach here, it means we couldn't return it quickly in null bucket, and same-bucket is empty, - // which means x NOT IN (empty set) or x NOT IN (l,m,n), the result should be - w.Joiner.OnMissMatch(false, probeSideRow, joinResult.chk) - return true, waitTime, joinResult - } - iter2 := w.rowIters - iter2.Reset(buildSideRows) - for iter2.Begin(); iter2.Current() != iter2.End(); { - matched, _, err := w.Joiner.TryToMatchInners(probeSideRow, iter2, joinResult.chk, LeftHasNullRightNotNull) - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - // here matched means: there is a valid same key bucket row from right side. (not empty) - // as said in the comment, once we found at least a valid row, we can determine the result as . - if matched { - return true, waitTime, joinResult - } - if joinResult.chk.IsFull() { - ok, oneWaitTime, joinResult = w.sendingResult(joinResult) - waitTime += oneWaitTime - if !ok { - return false, waitTime, joinResult - } - } - } - // step3: if we couldn't return it quickly in null bucket and all hash bucket, here means only one cases: - // case1: NOT IN (empty set): - // empty set comes from no rows from all bucket can pass other condition. the result should be - w.Joiner.OnMissMatch(false, probeSideRow, joinResult.chk) - return true, waitTime, joinResult -} - -// joinNAASJMatchProbeSideRow2Chunk implement the matching logic for NA-AntiSemiJoin -func (w *ProbeWorkerV1) joinNAASJMatchProbeSideRow2Chunk(probeKey uint64, probeKeyNullBits *bitmap.ConcurrentBitmap, probeSideRow chunk.Row, hCtx *HashContext, joinResult *hashjoinWorkerResult) (bool, int64, *hashjoinWorkerResult) { - var ( - err error - ok bool - ) - waitTime := int64(0) - oneWaitTime := int64(0) - if probeKeyNullBits == nil { - // step1: match null bucket first. - // need fetch the "valid" rows every time. 
(nullBits map check is necessary) - w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes) - buildSideRows := w.buildSideRows - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - if len(buildSideRows) != 0 { - iter1 := w.rowIters - iter1.Reset(buildSideRows) - for iter1.Begin(); iter1.Current() != iter1.End(); { - matched, _, err := w.Joiner.TryToMatchInners(probeSideRow, iter1, joinResult.chk) - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - // here matched means: there is a valid null bucket row from right side. - // as said in the comment, once we meet a rhs null in CNF, we can determine the reject of lhs row. - if matched { - return true, waitTime, joinResult - } - if joinResult.chk.IsFull() { - ok, oneWaitTime, joinResult = w.sendingResult(joinResult) - waitTime += oneWaitTime - if !ok { - return false, waitTime, joinResult - } - } - } - } - // step2: then same key bucket. - w.buildSideRows, err = w.rowContainerForProbe.GetMatchedRows(probeKey, probeSideRow, hCtx, w.buildSideRows) - buildSideRows = w.buildSideRows - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - if len(buildSideRows) == 0 { - // when reach here, it means we couldn't return it quickly in null bucket, and same-bucket is empty, - // which means x NOT IN (empty set), accept the rhs Row. - w.Joiner.OnMissMatch(false, probeSideRow, joinResult.chk) - return true, waitTime, joinResult - } - iter2 := w.rowIters - iter2.Reset(buildSideRows) - for iter2.Begin(); iter2.Current() != iter2.End(); { - matched, _, err := w.Joiner.TryToMatchInners(probeSideRow, iter2, joinResult.chk) - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - // here matched means: there is a valid same key bucket row from right side. - // as said in the comment, once we meet a false in CNF, we can determine the reject of lhs Row. - if matched { - return true, waitTime, joinResult - } - if joinResult.chk.IsFull() { - ok, oneWaitTime, joinResult = w.sendingResult(joinResult) - waitTime += oneWaitTime - if !ok { - return false, waitTime, joinResult - } - } - } - // step3: if we couldn't return it quickly in null bucket and same key bucket, here means two cases: - // case1: x NOT IN (empty set): if other key bucket don't have the valid rows yet. - // case2: x NOT IN (l,m,n...): if other key bucket do have the valid rows. - // both cases should accept the rhs row. - w.Joiner.OnMissMatch(false, probeSideRow, joinResult.chk) - return true, waitTime, joinResult - } - // when left side has null values, all we want is to find a valid build side rows (passed from other condition) - // so we can return it as soon as possible. here means two cases: - // case1: NOT IN (empty set): ----------------------> accept rhs row. - // case2: NOT IN (at least a valid inner row) ------------------> unknown result, refuse rhs row. 
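	// Illustrative sketch, not from the TiDB source: the three-valued `x NOT IN (set)` outcome
	// that drives the null-aware anti join matching above. The probe row survives only when the
	// predicate is definitely true; both "false" (same-key match) and "unknown" (a NULL on either
	// side with a non-empty set) reject it, which is why the code can bail out as soon as it finds
	// any match in the null bucket or the same-key bucket. Nullable values are modelled with
	// *int64 for brevity.
	//
	//	// keepForAntiSemiJoin reports whether the probe value survives `x NOT IN (set)`.
	//	func keepForAntiSemiJoin(x *int64, set []*int64) bool {
	//		if len(set) == 0 {
	//			return true // NOT IN (empty set) is always true
	//		}
	//		if x == nil {
	//			return false // NULL NOT IN (non-empty set) is unknown -> reject
	//		}
	//		sawNull := false
	//		for _, v := range set {
	//			if v == nil {
	//				sawNull = true
	//				continue
	//			}
	//			if *v == *x {
	//				return false // same-key match -> false -> reject
	//			}
	//		}
	//		return !sawNull // a NULL in the set makes the result unknown -> reject
	//	}
	//
	// Examples: 1 NOT IN () keeps the row; 1 NOT IN (2,3) keeps it; 1 NOT IN (1) and
	// 1 NOT IN (NULL,2) reject it; NULL NOT IN (2) rejects it.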
- // Step1: match null bucket (assumption that null bucket is quite smaller than all hash table bucket rows) - w.buildSideRows, err = w.rowContainerForProbe.GetNullBucketRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes) - buildSideRows := w.buildSideRows - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - if len(buildSideRows) != 0 { - iter1 := w.rowIters - iter1.Reset(buildSideRows) - for iter1.Begin(); iter1.Current() != iter1.End(); { - matched, _, err := w.Joiner.TryToMatchInners(probeSideRow, iter1, joinResult.chk) - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - // here matched means: there is a valid null bucket row from right side. (not empty) - // as said in the comment, once we found at least a valid row, we can determine the reject of lhs row. - if matched { - return true, waitTime, joinResult - } - if joinResult.chk.IsFull() { - ok, oneWaitTime, joinResult = w.sendingResult(joinResult) - waitTime += oneWaitTime - if !ok { - return false, waitTime, joinResult - } - } - } - } - // Step2: match all hash table bucket build rows. - w.buildSideRows, err = w.rowContainerForProbe.GetAllMatchedRows(hCtx, probeSideRow, probeKeyNullBits, w.buildSideRows, w.needCheckBuildColPos, w.needCheckProbeColPos, w.needCheckBuildTypes, w.needCheckProbeTypes) - buildSideRows = w.buildSideRows - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - if len(buildSideRows) == 0 { - // when reach here, it means we couldn't return it quickly in null bucket, and same-bucket is empty, - // which means NOT IN (empty set) or NOT IN (no valid rows) accept the rhs row. - w.Joiner.OnMissMatch(false, probeSideRow, joinResult.chk) - return true, waitTime, joinResult - } - iter2 := w.rowIters - iter2.Reset(buildSideRows) - for iter2.Begin(); iter2.Current() != iter2.End(); { - matched, _, err := w.Joiner.TryToMatchInners(probeSideRow, iter2, joinResult.chk) - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - // here matched means: there is a valid key row from right side. (not empty) - // as said in the comment, once we found at least a valid row, we can determine the reject of lhs row. - if matched { - return true, waitTime, joinResult - } - if joinResult.chk.IsFull() { - ok, oneWaitTime, joinResult = w.sendingResult(joinResult) - waitTime += oneWaitTime - if !ok { - return false, waitTime, joinResult - } - } - } - // step3: if we couldn't return it quickly in null bucket and all hash bucket, here means only one cases: - // case1: NOT IN (empty set): - // empty set comes from no rows from all bucket can pass other condition. we should accept the rhs row. - w.Joiner.OnMissMatch(false, probeSideRow, joinResult.chk) - return true, waitTime, joinResult -} - -// joinNAAJMatchProbeSideRow2Chunk implement the matching priority logic for NA-AntiSemiJoin and NA-AntiLeftOuterSemiJoin -// there are some bucket-matching priority difference between them. -// -// Since NA-AntiSemiJoin don't need to append the scalar value with the left side row, there is a quick matching path. -// 1: lhs row has null: -// lhs row has null can't determine its result in advance, we should judge whether the right valid set is empty -// or not. For semantic like x NOT IN(y set), If y set is empty, the scalar result is 1; Otherwise, the result -// is 0. 
Since NA-AntiSemiJoin don't care about the scalar value, we just try to find a valid row from right side, -// once we found it then just return the left side row instantly. (same as NA-AntiLeftOuterSemiJoin) -// -// 2: lhs row without null: -// same-key bucket and null-bucket which should be the first to match? For semantic like x NOT IN(y set), once y -// set has a same key x, the scalar value is 0; else if y set has a null key, then the scalar value is null. Both -// of them lead the refuse of the lhs row without any difference. Since NA-AntiSemiJoin don't care about the scalar -// value, we can just match the null bucket first and refuse the lhs row as quickly as possible, because a null of -// yi in the CNF (x NA-EQ yi) can always determine a negative value (refuse lhs row) in advance here. -// -// For NA-AntiLeftOuterSemiJoin, we couldn't match null-bucket first, because once y set has a same key x and null -// key, we should return the result as left side row appended with a scalar value 0 which is from same key matching failure. -func (w *ProbeWorkerV1) joinNAAJMatchProbeSideRow2Chunk(probeKey uint64, probeKeyNullBits *bitmap.ConcurrentBitmap, probeSideRow chunk.Row, hCtx *HashContext, joinResult *hashjoinWorkerResult) (bool, int64, *hashjoinWorkerResult) { - naAntiSemiJoin := w.HashJoinCtx.JoinType == plannercore.AntiSemiJoin && w.HashJoinCtx.IsNullAware - naAntiLeftOuterSemiJoin := w.HashJoinCtx.JoinType == plannercore.AntiLeftOuterSemiJoin && w.HashJoinCtx.IsNullAware - if naAntiSemiJoin { - return w.joinNAASJMatchProbeSideRow2Chunk(probeKey, probeKeyNullBits, probeSideRow, hCtx, joinResult) - } - if naAntiLeftOuterSemiJoin { - return w.joinNAALOSJMatchProbeSideRow2Chunk(probeKey, probeKeyNullBits, probeSideRow, hCtx, joinResult) - } - // shouldn't be here, not a valid NAAJ. 
- return false, 0, joinResult -} - -func (w *ProbeWorkerV1) joinMatchedProbeSideRow2Chunk(probeKey uint64, probeSideRow chunk.Row, hCtx *HashContext, - joinResult *hashjoinWorkerResult) (bool, int64, *hashjoinWorkerResult) { - var err error - waitTime := int64(0) - oneWaitTime := int64(0) - var buildSideRows []chunk.Row - if w.Joiner.isSemiJoinWithoutCondition() { - var rowPtr *chunk.Row - rowPtr, err = w.rowContainerForProbe.GetOneMatchedRow(probeKey, probeSideRow, hCtx) - if rowPtr != nil { - buildSideRows = append(buildSideRows, *rowPtr) - } - } else { - w.buildSideRows, err = w.rowContainerForProbe.GetMatchedRows(probeKey, probeSideRow, hCtx, w.buildSideRows) - buildSideRows = w.buildSideRows - } - - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - if len(buildSideRows) == 0 { - w.Joiner.OnMissMatch(false, probeSideRow, joinResult.chk) - return true, waitTime, joinResult - } - iter := w.rowIters - iter.Reset(buildSideRows) - hasMatch, hasNull, ok := false, false, false - for iter.Begin(); iter.Current() != iter.End(); { - matched, isNull, err := w.Joiner.TryToMatchInners(probeSideRow, iter, joinResult.chk) - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - hasMatch = hasMatch || matched - hasNull = hasNull || isNull - - if joinResult.chk.IsFull() { - ok, oneWaitTime, joinResult = w.sendingResult(joinResult) - waitTime += oneWaitTime - if !ok { - return false, waitTime, joinResult - } - } - } - if !hasMatch { - w.Joiner.OnMissMatch(hasNull, probeSideRow, joinResult.chk) - } - return true, waitTime, joinResult -} - -func (w *ProbeWorkerV1) getNewJoinResult() (bool, *hashjoinWorkerResult) { - joinResult := &hashjoinWorkerResult{ - src: w.joinChkResourceCh, - } - ok := true - select { - case <-w.HashJoinCtx.closeCh: - ok = false - case joinResult.chk, ok = <-w.joinChkResourceCh: - } - return ok, joinResult -} - -func (w *ProbeWorkerV1) join2Chunk(probeSideChk *chunk.Chunk, hCtx *HashContext, joinResult *hashjoinWorkerResult, - selected []bool) (ok bool, waitTime int64, _ *hashjoinWorkerResult) { - var err error - waitTime = 0 - oneWaitTime := int64(0) - selected, err = expression.VectorizedFilter(w.HashJoinCtx.SessCtx.GetExprCtx().GetEvalCtx(), w.HashJoinCtx.SessCtx.GetSessionVars().EnableVectorizedExpression, w.HashJoinCtx.OuterFilter, chunk.NewIterator4Chunk(probeSideChk), selected) - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - - numRows := probeSideChk.NumRows() - hCtx.InitHash(numRows) - // By now, path 1 and 2 won't be conducted at the same time. - // 1: write the row data of join key to hashVals. (normal EQ key should ignore the null values.) null-EQ for Except statement is an exception. - for keyIdx, i := range hCtx.KeyColIdx { - ignoreNull := len(w.HashJoinCtx.IsNullEQ) > keyIdx && w.HashJoinCtx.IsNullEQ[keyIdx] - err = codec.HashChunkSelected(w.rowContainerForProbe.sc.TypeCtx(), hCtx.HashVals, probeSideChk, hCtx.AllTypes[keyIdx], i, hCtx.Buf, hCtx.HasNull, selected, ignoreNull) - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - } - // 2: write the how data of NA join key to hashVals. (NA EQ key should collect all how including null value, store null value in a special position) - isNAAJ := len(hCtx.NaKeyColIdx) > 0 - for keyIdx, i := range hCtx.NaKeyColIdx { - // NAAJ won't ignore any null values, but collect them up to probe. 
- err = codec.HashChunkSelected(w.rowContainerForProbe.sc.TypeCtx(), hCtx.HashVals, probeSideChk, hCtx.AllTypes[keyIdx], i, hCtx.Buf, hCtx.HasNull, selected, false) - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - // after fetch one NA column, collect the null value to null bitmap for every how. (use hasNull flag to accelerate) - // eg: if a NA Join cols is (a, b, c), for every build row here we maintained a 3-bit map to mark which column is null for them. - for rowIdx := 0; rowIdx < numRows; rowIdx++ { - if hCtx.HasNull[rowIdx] { - hCtx.naColNullBitMap[rowIdx].UnsafeSet(keyIdx) - // clean and try fetch Next NA join col. - hCtx.HasNull[rowIdx] = false - hCtx.naHasNull[rowIdx] = true - } - } - } - - for i := range selected { - err := w.HashJoinCtx.SessCtx.GetSessionVars().SQLKiller.HandleSignal() - failpoint.Inject("killedInJoin2Chunk", func(val failpoint.Value) { - if val.(bool) { - err = exeerrors.ErrQueryInterrupted - } - }) - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - if isNAAJ { - if !selected[i] { - // since this is the case of using inner to build, so for an outer row unselected, we should fill the result when it's outer join. - w.Joiner.OnMissMatch(false, probeSideChk.GetRow(i), joinResult.chk) - } - if hCtx.naHasNull[i] { - // here means the probe join connecting column has null value in it and this is special for matching all the hash buckets - // for it. (probeKey is not necessary here) - probeRow := probeSideChk.GetRow(i) - ok, oneWaitTime, joinResult = w.joinNAAJMatchProbeSideRow2Chunk(0, hCtx.naColNullBitMap[i].Clone(), probeRow, hCtx, joinResult) - waitTime += oneWaitTime - if !ok { - return false, waitTime, joinResult - } - } else { - // here means the probe join connecting column without null values, where we should match same key bucket and null bucket for it at its order. - // step1: process same key matched probe side rows - probeKey, probeRow := hCtx.HashVals[i].Sum64(), probeSideChk.GetRow(i) - ok, oneWaitTime, joinResult = w.joinNAAJMatchProbeSideRow2Chunk(probeKey, nil, probeRow, hCtx, joinResult) - waitTime += oneWaitTime - if !ok { - return false, waitTime, joinResult - } - } - } else { - // since this is the case of using inner to build, so for an outer row unselected, we should fill the result when it's outer join. 
- if !selected[i] || hCtx.HasNull[i] { // process unmatched probe side rows - w.Joiner.OnMissMatch(false, probeSideChk.GetRow(i), joinResult.chk) - } else { // process matched probe side rows - probeKey, probeRow := hCtx.HashVals[i].Sum64(), probeSideChk.GetRow(i) - ok, oneWaitTime, joinResult = w.joinMatchedProbeSideRow2Chunk(probeKey, probeRow, hCtx, joinResult) - waitTime += oneWaitTime - if !ok { - return false, waitTime, joinResult - } - } - } - if joinResult.chk.IsFull() { - ok, oneWaitTime, joinResult = w.sendingResult(joinResult) - waitTime += oneWaitTime - if !ok { - return false, waitTime, joinResult - } - } - } - return true, waitTime, joinResult -} - -func (w *ProbeWorkerV1) sendingResult(joinResult *hashjoinWorkerResult) (ok bool, cost int64, newJoinResult *hashjoinWorkerResult) { - start := time.Now() - w.HashJoinCtx.joinResultCh <- joinResult - ok, newJoinResult = w.getNewJoinResult() - cost = int64(time.Since(start)) - return ok, cost, newJoinResult -} - -// join2ChunkForOuterHashJoin joins chunks when using the outer to build a hash table (refer to outer hash join) -func (w *ProbeWorkerV1) join2ChunkForOuterHashJoin(probeSideChk *chunk.Chunk, hCtx *HashContext, joinResult *hashjoinWorkerResult) (ok bool, waitTime int64, _ *hashjoinWorkerResult) { - waitTime = 0 - oneWaitTime := int64(0) - hCtx.InitHash(probeSideChk.NumRows()) - for keyIdx, i := range hCtx.KeyColIdx { - err := codec.HashChunkColumns(w.rowContainerForProbe.sc.TypeCtx(), hCtx.HashVals, probeSideChk, hCtx.AllTypes[keyIdx], i, hCtx.Buf, hCtx.HasNull) - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - } - for i := 0; i < probeSideChk.NumRows(); i++ { - err := w.HashJoinCtx.SessCtx.GetSessionVars().SQLKiller.HandleSignal() - failpoint.Inject("killedInJoin2ChunkForOuterHashJoin", func(val failpoint.Value) { - if val.(bool) { - err = exeerrors.ErrQueryInterrupted - } - }) - if err != nil { - joinResult.err = err - return false, waitTime, joinResult - } - probeKey, probeRow := hCtx.HashVals[i].Sum64(), probeSideChk.GetRow(i) - ok, oneWaitTime, joinResult = w.joinMatchedProbeSideRow2ChunkForOuterHashJoin(probeKey, probeRow, hCtx, joinResult) - waitTime += oneWaitTime - if !ok { - return false, waitTime, joinResult - } - if joinResult.chk.IsFull() { - ok, oneWaitTime, joinResult = w.sendingResult(joinResult) - waitTime += oneWaitTime - if !ok { - return false, waitTime, joinResult - } - } - } - return true, waitTime, joinResult -} - -// Next implements the Executor Next interface. -// hash join constructs the result following these steps: -// step 1. fetch data from build side child and build a hash table; -// step 2. fetch data from probe child in a background goroutine and probe the hash table in multiple join workers. 
-func (e *HashJoinV1Exec) Next(ctx context.Context, req *chunk.Chunk) (err error) { - if !e.Prepared { - e.buildFinished = make(chan error, 1) - hCtx := &HashContext{ - AllTypes: e.BuildTypes, - KeyColIdx: e.BuildWorker.BuildKeyColIdx, - NaKeyColIdx: e.BuildWorker.BuildNAKeyColIdx, - } - e.RowContainer = newHashRowContainer(e.Ctx(), hCtx, exec.RetTypes(e.BuildWorker.BuildSideExec)) - // we shallow copies RowContainer for each probe worker to avoid lock contention - for i := uint(0); i < e.Concurrency; i++ { - if i == 0 { - e.ProbeWorkers[i].rowContainerForProbe = e.RowContainer - } else { - e.ProbeWorkers[i].rowContainerForProbe = e.RowContainer.ShallowCopy() - } - } - for i := uint(0); i < e.Concurrency; i++ { - e.ProbeWorkers[i].rowIters = chunk.NewIterator4Slice([]chunk.Row{}) - } - e.workerWg.RunWithRecover(func() { - defer trace.StartRegion(ctx, "HashJoinHashTableBuilder").End() - e.fetchAndBuildHashTable(ctx) - }, e.handleFetchAndBuildHashTablePanic) - e.fetchAndProbeHashTable(ctx) - e.Prepared = true - } - if e.IsOuterJoin { - atomic.StoreInt64(&e.ProbeSideTupleFetcher.requiredRows, int64(req.RequiredRows())) - } - req.Reset() - - result, ok := <-e.joinResultCh - if !ok { - return nil - } - if result.err != nil { - e.finished.Store(true) - return result.err - } - req.SwapColumns(result.chk) - result.src <- result.chk - return nil -} - -func (e *HashJoinV1Exec) handleFetchAndBuildHashTablePanic(r any) { - if r != nil { - e.buildFinished <- util.GetRecoverError(r) - } - close(e.buildFinished) -} - -func (e *HashJoinV1Exec) fetchAndBuildHashTable(ctx context.Context) { - if e.stats != nil { - start := time.Now() - defer func() { - e.stats.fetchAndBuildHashTable = time.Since(start) - }() - } - // buildSideResultCh transfers build side chunk from build side fetch to build hash table. - buildSideResultCh := make(chan *chunk.Chunk, 1) - doneCh := make(chan struct{}) - fetchBuildSideRowsOk := make(chan error, 1) - e.workerWg.RunWithRecover( - func() { - defer trace.StartRegion(ctx, "HashJoinBuildSideFetcher").End() - e.BuildWorker.fetchBuildSideRows(ctx, &e.BuildWorker.HashJoinCtx.hashJoinCtxBase, buildSideResultCh, fetchBuildSideRowsOk, doneCh) - }, - func(r any) { - if r != nil { - fetchBuildSideRowsOk <- util.GetRecoverError(r) - } - close(fetchBuildSideRowsOk) - }, - ) - - // TODO: Parallel build hash table. Currently not support because `unsafeHashTable` is not thread-safe. - err := e.BuildWorker.BuildHashTableForList(buildSideResultCh) - if err != nil { - e.buildFinished <- errors.Trace(err) - close(doneCh) - } - // Wait fetchBuildSideRows be Finished. - // 1. if BuildHashTableForList fails - // 2. if probeSideResult.NumRows() == 0, fetchProbeSideChunks will not wait for the build side. - channel.Clear(buildSideResultCh) - // Check whether err is nil to avoid sending redundant error into buildFinished. - if err == nil { - if err = <-fetchBuildSideRowsOk; err != nil { - e.buildFinished <- err - } - } -} - -// BuildHashTableForList builds hash table from `list`. 
-func (w *BuildWorkerV1) BuildHashTableForList(buildSideResultCh <-chan *chunk.Chunk) error { - var err error - var selected []bool - rowContainer := w.HashJoinCtx.RowContainer - rowContainer.GetMemTracker().AttachTo(w.HashJoinCtx.memTracker) - rowContainer.GetMemTracker().SetLabel(memory.LabelForBuildSideResult) - rowContainer.GetDiskTracker().AttachTo(w.HashJoinCtx.diskTracker) - rowContainer.GetDiskTracker().SetLabel(memory.LabelForBuildSideResult) - if variable.EnableTmpStorageOnOOM.Load() { - actionSpill := rowContainer.ActionSpill() - failpoint.Inject("testRowContainerSpill", func(val failpoint.Value) { - if val.(bool) { - actionSpill = rowContainer.rowContainer.ActionSpillForTest() - defer actionSpill.(*chunk.SpillDiskAction).WaitForTest() - } - }) - w.HashJoinCtx.SessCtx.GetSessionVars().MemTracker.FallbackOldAndSetNewAction(actionSpill) - } - for chk := range buildSideResultCh { - if w.HashJoinCtx.finished.Load() { - return nil - } - if !w.HashJoinCtx.UseOuterToBuild { - err = rowContainer.PutChunk(chk, w.HashJoinCtx.IsNullEQ) - } else { - var bitMap = bitmap.NewConcurrentBitmap(chk.NumRows()) - w.HashJoinCtx.outerMatchedStatus = append(w.HashJoinCtx.outerMatchedStatus, bitMap) - w.HashJoinCtx.memTracker.Consume(bitMap.BytesConsumed()) - if len(w.HashJoinCtx.OuterFilter) == 0 { - err = w.HashJoinCtx.RowContainer.PutChunk(chk, w.HashJoinCtx.IsNullEQ) - } else { - selected, err = expression.VectorizedFilter(w.HashJoinCtx.SessCtx.GetExprCtx().GetEvalCtx(), w.HashJoinCtx.SessCtx.GetSessionVars().EnableVectorizedExpression, w.HashJoinCtx.OuterFilter, chunk.NewIterator4Chunk(chk), selected) - if err != nil { - return err - } - err = rowContainer.PutChunkSelected(chk, selected, w.HashJoinCtx.IsNullEQ) - } - } - failpoint.Inject("ConsumeRandomPanic", nil) - if err != nil { - return err - } - } - return nil -} - -// NestedLoopApplyExec is the executor for apply. -type NestedLoopApplyExec struct { - exec.BaseExecutor - - Sctx sessionctx.Context - innerRows []chunk.Row - cursor int - InnerExec exec.Executor - OuterExec exec.Executor - InnerFilter expression.CNFExprs - OuterFilter expression.CNFExprs - - Joiner Joiner - - cache *applycache.ApplyCache - CanUseCache bool - cacheHitCounter int - cacheAccessCounter int - - OuterSchema []*expression.CorrelatedColumn - - OuterChunk *chunk.Chunk - outerChunkCursor int - outerSelected []bool - InnerList *chunk.List - InnerChunk *chunk.Chunk - innerSelected []bool - innerIter chunk.Iterator - outerRow *chunk.Row - hasMatch bool - hasNull bool - - Outer bool - - memTracker *memory.Tracker // track memory usage. -} - -// Close implements the Executor interface. -func (e *NestedLoopApplyExec) Close() error { - e.innerRows = nil - e.memTracker = nil - if e.RuntimeStats() != nil { - runtimeStats := NewJoinRuntimeStats() - if e.CanUseCache { - var hitRatio float64 - if e.cacheAccessCounter > 0 { - hitRatio = float64(e.cacheHitCounter) / float64(e.cacheAccessCounter) - } - runtimeStats.SetCacheInfo(true, hitRatio) - } else { - runtimeStats.SetCacheInfo(false, 0) - } - runtimeStats.SetConcurrencyInfo(execdetails.NewConcurrencyInfo("concurrency", 0)) - defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), runtimeStats) - } - return exec.Close(e.OuterExec) -} - -// Open implements the Executor interface. 
-func (e *NestedLoopApplyExec) Open(ctx context.Context) error { - err := exec.Open(ctx, e.OuterExec) - if err != nil { - return err - } - e.cursor = 0 - e.innerRows = e.innerRows[:0] - e.OuterChunk = exec.TryNewCacheChunk(e.OuterExec) - e.InnerChunk = exec.TryNewCacheChunk(e.InnerExec) - e.InnerList = chunk.NewList(exec.RetTypes(e.InnerExec), e.InitCap(), e.MaxChunkSize()) - - e.memTracker = memory.NewTracker(e.ID(), -1) - e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) - - e.InnerList.GetMemTracker().SetLabel(memory.LabelForInnerList) - e.InnerList.GetMemTracker().AttachTo(e.memTracker) - - if e.CanUseCache { - e.cache, err = applycache.NewApplyCache(e.Sctx) - if err != nil { - return err - } - e.cacheHitCounter = 0 - e.cacheAccessCounter = 0 - e.cache.GetMemTracker().AttachTo(e.memTracker) - } - return nil -} - -// aggExecutorTreeInputEmpty checks whether the executor tree returns empty if without aggregate operators. -// Note that, the prerequisite is that this executor tree has been executed already and it returns one Row. -func aggExecutorTreeInputEmpty(e exec.Executor) bool { - children := e.AllChildren() - if len(children) == 0 { - return false - } - if len(children) > 1 { - _, ok := e.(*unionexec.UnionExec) - if !ok { - // It is a Join executor. - return false - } - for _, child := range children { - if !aggExecutorTreeInputEmpty(child) { - return false - } - } - return true - } - // Single child executors. - if aggExecutorTreeInputEmpty(children[0]) { - return true - } - if hashAgg, ok := e.(*aggregate.HashAggExec); ok { - return hashAgg.IsChildReturnEmpty - } - if streamAgg, ok := e.(*aggregate.StreamAggExec); ok { - return streamAgg.IsChildReturnEmpty - } - return false -} - -func (e *NestedLoopApplyExec) fetchSelectedOuterRow(ctx context.Context, chk *chunk.Chunk) (*chunk.Row, error) { - outerIter := chunk.NewIterator4Chunk(e.OuterChunk) - for { - if e.outerChunkCursor >= e.OuterChunk.NumRows() { - err := exec.Next(ctx, e.OuterExec, e.OuterChunk) - if err != nil { - return nil, err - } - if e.OuterChunk.NumRows() == 0 { - return nil, nil - } - e.outerSelected, err = expression.VectorizedFilter(e.Sctx.GetExprCtx().GetEvalCtx(), e.Sctx.GetSessionVars().EnableVectorizedExpression, e.OuterFilter, outerIter, e.outerSelected) - if err != nil { - return nil, err - } - // For cases like `select count(1), (select count(1) from s where s.a > t.a) as sub from t where t.a = 1`, - // if outer child has no row satisfying `t.a = 1`, `sub` should be `null` instead of `0` theoretically; however, the - // outer `count(1)` produces one row <0, null> over the empty input, we should specially mark this outer row - // as not selected, to trigger the mismatch join procedure. - if e.outerChunkCursor == 0 && e.OuterChunk.NumRows() == 1 && e.outerSelected[0] && aggExecutorTreeInputEmpty(e.OuterExec) { - e.outerSelected[0] = false - } - e.outerChunkCursor = 0 - } - outerRow := e.OuterChunk.GetRow(e.outerChunkCursor) - selected := e.outerSelected[e.outerChunkCursor] - e.outerChunkCursor++ - if selected { - return &outerRow, nil - } else if e.Outer { - e.Joiner.OnMissMatch(false, outerRow, chk) - if chk.IsFull() { - return nil, nil - } - } - } -} - -// fetchAllInners reads all data from the inner table and stores them in a List. 
-func (e *NestedLoopApplyExec) fetchAllInners(ctx context.Context) error { - err := exec.Open(ctx, e.InnerExec) - defer func() { terror.Log(exec.Close(e.InnerExec)) }() - if err != nil { - return err - } - - if e.CanUseCache { - // create a new one since it may be in the cache - e.InnerList = chunk.NewListWithMemTracker(exec.RetTypes(e.InnerExec), e.InitCap(), e.MaxChunkSize(), e.InnerList.GetMemTracker()) - } else { - e.InnerList.Reset() - } - innerIter := chunk.NewIterator4Chunk(e.InnerChunk) - for { - err := exec.Next(ctx, e.InnerExec, e.InnerChunk) - if err != nil { - return err - } - if e.InnerChunk.NumRows() == 0 { - return nil - } - - e.innerSelected, err = expression.VectorizedFilter(e.Sctx.GetExprCtx().GetEvalCtx(), e.Sctx.GetSessionVars().EnableVectorizedExpression, e.InnerFilter, innerIter, e.innerSelected) - if err != nil { - return err - } - for row := innerIter.Begin(); row != innerIter.End(); row = innerIter.Next() { - if e.innerSelected[row.Idx()] { - e.InnerList.AppendRow(row) - } - } - } -} - -// Next implements the Executor interface. -func (e *NestedLoopApplyExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { - req.Reset() - for { - if e.innerIter == nil || e.innerIter.Current() == e.innerIter.End() { - if e.outerRow != nil && !e.hasMatch { - e.Joiner.OnMissMatch(e.hasNull, *e.outerRow, req) - } - e.outerRow, err = e.fetchSelectedOuterRow(ctx, req) - if e.outerRow == nil || err != nil { - return err - } - e.hasMatch = false - e.hasNull = false - - if e.CanUseCache { - var key []byte - for _, col := range e.OuterSchema { - *col.Data = e.outerRow.GetDatum(col.Index, col.RetType) - key, err = codec.EncodeKey(e.Ctx().GetSessionVars().StmtCtx.TimeZone(), key, *col.Data) - err = e.Ctx().GetSessionVars().StmtCtx.HandleError(err) - if err != nil { - return err - } - } - e.cacheAccessCounter++ - value, err := e.cache.Get(key) - if err != nil { - return err - } - if value != nil { - e.InnerList = value - e.cacheHitCounter++ - } else { - err = e.fetchAllInners(ctx) - if err != nil { - return err - } - if _, err := e.cache.Set(key, e.InnerList); err != nil { - return err - } - } - } else { - for _, col := range e.OuterSchema { - *col.Data = e.outerRow.GetDatum(col.Index, col.RetType) - } - err = e.fetchAllInners(ctx) - if err != nil { - return err - } - } - e.innerIter = chunk.NewIterator4List(e.InnerList) - e.innerIter.Begin() - } - - matched, isNull, err := e.Joiner.TryToMatchInners(*e.outerRow, e.innerIter, req) - e.hasMatch = e.hasMatch || matched - e.hasNull = e.hasNull || isNull - - if err != nil || req.IsFull() { - return err - } - } -} - -// cacheInfo is used to save the concurrency information of the executor operator -type cacheInfo struct { - hitRatio float64 - useCache bool -} - -type joinRuntimeStats struct { - *execdetails.RuntimeStatsWithConcurrencyInfo - - applyCache bool - cache cacheInfo - hasHashStat bool - hashStat hashStatistic -} - -// NewJoinRuntimeStats returns a new joinRuntimeStats -func NewJoinRuntimeStats() *joinRuntimeStats { - stats := &joinRuntimeStats{ - RuntimeStatsWithConcurrencyInfo: &execdetails.RuntimeStatsWithConcurrencyInfo{}, - } - return stats -} - -// SetCacheInfo sets the cache information. Only used for apply executor. 
-func (e *joinRuntimeStats) SetCacheInfo(useCache bool, hitRatio float64) { - e.Lock() - e.applyCache = true - e.cache.useCache = useCache - e.cache.hitRatio = hitRatio - e.Unlock() -} - -func (e *joinRuntimeStats) String() string { - buf := bytes.NewBuffer(make([]byte, 0, 16)) - buf.WriteString(e.RuntimeStatsWithConcurrencyInfo.String()) - if e.applyCache { - if e.cache.useCache { - fmt.Fprintf(buf, ", cache:ON, cacheHitRatio:%.3f%%", e.cache.hitRatio*100) - } else { - buf.WriteString(", cache:OFF") - } - } - if e.hasHashStat { - buf.WriteString(", " + e.hashStat.String()) - } - return buf.String() -} - -// Tp implements the RuntimeStats interface. -func (*joinRuntimeStats) Tp() int { - return execdetails.TpJoinRuntimeStats -} - -func (e *joinRuntimeStats) Clone() execdetails.RuntimeStats { - newJRS := &joinRuntimeStats{ - RuntimeStatsWithConcurrencyInfo: e.RuntimeStatsWithConcurrencyInfo, - applyCache: e.applyCache, - cache: e.cache, - hasHashStat: e.hasHashStat, - hashStat: e.hashStat, - } - return newJRS -} diff --git a/pkg/executor/join/hash_join_v2.go b/pkg/executor/join/hash_join_v2.go index d5bcd9ce69ba1..9fbb587af4a3e 100644 --- a/pkg/executor/join/hash_join_v2.go +++ b/pkg/executor/join/hash_join_v2.go @@ -95,7 +95,7 @@ func (htc *hashTableContext) getCurrentRowSegment(workerID, partitionID int, tab func (htc *hashTableContext) finalizeCurrentSeg(workerID, partitionID int, builder *rowTableBuilder) { seg := htc.getCurrentRowSegment(workerID, partitionID, nil, false, 0) builder.rowNumberInCurrentRowTableSeg[partitionID] = 0 - failpoint.Eval(_curpkg_("finalizeCurrentSegPanic")) + failpoint.Inject("finalizeCurrentSegPanic", nil) seg.finalized = true htc.memoryTracker.Consume(seg.totalUsedBytes()) } @@ -360,7 +360,7 @@ func (w *BuildWorkerV2) splitPartitionAndAppendToRowTable(typeCtx types.Context, for chk := range srcChkCh { start := time.Now() err = builder.processOneChunk(chk, typeCtx, w.HashJoinCtx, int(w.WorkerID)) - failpoint.Eval(_curpkg_("splitPartitionPanic")) + failpoint.Inject("splitPartitionPanic", nil) cost += int64(time.Since(start)) if err != nil { return err @@ -495,7 +495,7 @@ func (w *ProbeWorkerV2) processOneProbeChunk(probeChunk *chunk.Chunk, joinResult if !ok || joinResult.err != nil { return ok, waitTime, joinResult } - failpoint.Eval(_curpkg_("processOneProbeChunkPanic")) + failpoint.Inject("processOneProbeChunkPanic", nil) if joinResult.chk.IsFull() { waitStart := time.Now() w.HashJoinCtx.joinResultCh <- joinResult @@ -542,7 +542,7 @@ func (w *ProbeWorkerV2) runJoinWorker() { return case probeSideResult, ok = <-w.probeResultCh: } - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + failpoint.Inject("ConsumeRandomPanic", nil) if !ok { break } @@ -648,7 +648,7 @@ func (e *HashJoinV2Exec) createTasks(buildTaskCh chan<- *buildTask, totalSegment createBuildTask := func(partIdx int, segStartIdx int, segEndIdx int) *buildTask { return &buildTask{partitionIdx: partIdx, segStartIdx: segStartIdx, segEndIdx: segEndIdx} } - failpoint.Eval(_curpkg_("createTasksPanic")) + failpoint.Inject("createTasksPanic", nil) if isBalanced { for partIdx, subTable := range subTables { @@ -831,7 +831,7 @@ func (w *BuildWorkerV2) buildHashTable(taskCh chan *buildTask) error { start := time.Now() partIdx, segStartIdx, segEndIdx := task.partitionIdx, task.segStartIdx, task.segEndIdx w.HashJoinCtx.hashTableContext.hashTable.tables[partIdx].build(segStartIdx, segEndIdx) - failpoint.Eval(_curpkg_("buildHashTablePanic")) + failpoint.Inject("buildHashTablePanic", nil) cost += 
int64(time.Since(start)) } return nil diff --git a/pkg/executor/join/hash_join_v2.go__failpoint_stash__ b/pkg/executor/join/hash_join_v2.go__failpoint_stash__ deleted file mode 100644 index 9fbb587af4a3e..0000000000000 --- a/pkg/executor/join/hash_join_v2.go__failpoint_stash__ +++ /dev/null @@ -1,943 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package join - -import ( - "bytes" - "context" - "math" - "runtime/trace" - "strconv" - "sync" - "sync/atomic" - "time" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/parser/mysql" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/channel" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/disk" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/memory" -) - -var ( - _ exec.Executor = &HashJoinV2Exec{} - // enableHashJoinV2 is a variable used only in test - enableHashJoinV2 = atomic.Bool{} -) - -// IsHashJoinV2Enabled return true if hash join v2 is enabled -func IsHashJoinV2Enabled() bool { - // sizeOfUintptr should always equal to sizeOfUnsafePointer, because according to golang's doc, - // a Pointer can be converted to an uintptr. 
Add this check here in case in the future go runtime - // change this - return !heapObjectsCanMove() && enableHashJoinV2.Load() && sizeOfUintptr >= sizeOfUnsafePointer -} - -// SetEnableHashJoinV2 enable/disable hash join v2 -func SetEnableHashJoinV2(enable bool) { - enableHashJoinV2.Store(enable) -} - -type hashTableContext struct { - // rowTables is used during split partition stage, each buildWorker has - // its own rowTable - rowTables [][]*rowTable - hashTable *hashTableV2 - memoryTracker *memory.Tracker -} - -func (htc *hashTableContext) reset() { - htc.rowTables = nil - htc.hashTable = nil - htc.memoryTracker.Detach() -} - -func (htc *hashTableContext) getCurrentRowSegment(workerID, partitionID int, tableMeta *TableMeta, allowCreate bool, firstSegSizeHint uint) *rowTableSegment { - if htc.rowTables[workerID][partitionID] == nil { - htc.rowTables[workerID][partitionID] = newRowTable(tableMeta) - } - segNum := len(htc.rowTables[workerID][partitionID].segments) - if segNum == 0 || htc.rowTables[workerID][partitionID].segments[segNum-1].finalized { - if !allowCreate { - panic("logical error, should not reach here") - } - // do not pre-allocate too many memory for the first seg because for query that only has a few rows, it may waste memory and may hurt the performance in high concurrency scenarios - rowSizeHint := maxRowTableSegmentSize - if segNum == 0 { - rowSizeHint = int(firstSegSizeHint) - } - seg := newRowTableSegment(uint(rowSizeHint)) - htc.rowTables[workerID][partitionID].segments = append(htc.rowTables[workerID][partitionID].segments, seg) - segNum++ - } - return htc.rowTables[workerID][partitionID].segments[segNum-1] -} - -func (htc *hashTableContext) finalizeCurrentSeg(workerID, partitionID int, builder *rowTableBuilder) { - seg := htc.getCurrentRowSegment(workerID, partitionID, nil, false, 0) - builder.rowNumberInCurrentRowTableSeg[partitionID] = 0 - failpoint.Inject("finalizeCurrentSegPanic", nil) - seg.finalized = true - htc.memoryTracker.Consume(seg.totalUsedBytes()) -} - -func (htc *hashTableContext) mergeRowTablesToHashTable(tableMeta *TableMeta, partitionNumber uint) int { - rowTables := make([]*rowTable, partitionNumber) - for i := 0; i < int(partitionNumber); i++ { - rowTables[i] = newRowTable(tableMeta) - } - totalSegmentCnt := 0 - for _, rowTablesPerWorker := range htc.rowTables { - for partIdx, rt := range rowTablesPerWorker { - if rt == nil { - continue - } - rowTables[partIdx].merge(rt) - totalSegmentCnt += len(rt.segments) - } - } - for i := 0; i < int(partitionNumber); i++ { - htc.hashTable.tables[i] = newSubTable(rowTables[i]) - } - htc.rowTables = nil - return totalSegmentCnt -} - -// HashJoinCtxV2 is the hash join ctx used in hash join v2 -type HashJoinCtxV2 struct { - hashJoinCtxBase - partitionNumber uint - partitionMaskOffset int - ProbeKeyTypes []*types.FieldType - BuildKeyTypes []*types.FieldType - stats *hashJoinRuntimeStatsV2 - - RightAsBuildSide bool - BuildFilter expression.CNFExprs - ProbeFilter expression.CNFExprs - OtherCondition expression.CNFExprs - hashTableContext *hashTableContext - hashTableMeta *TableMeta - needScanRowTableAfterProbeDone bool - - LUsed, RUsed []int - LUsedInOtherCondition, RUsedInOtherCondition []int -} - -// partitionNumber is always power of 2 -func genHashJoinPartitionNumber(partitionHint uint) uint { - prevRet := uint(16) - currentRet := uint(8) - for currentRet != 0 { - if currentRet < partitionHint { - return prevRet - } - prevRet = currentRet - currentRet = currentRet >> 1 - } - return 1 -} - -func 
getPartitionMaskOffset(partitionNumber uint) int { - getMSBPos := func(num uint64) int { - ret := 0 - for num&1 != 1 { - num = num >> 1 - ret++ - } - if num != 1 { - // partitionNumber is always pow of 2 - panic("should not reach here") - } - return ret - } - msbPos := getMSBPos(uint64(partitionNumber)) - // top MSB bits in hash value will be used to partition data - return 64 - msbPos -} - -// SetupPartitionInfo set up partitionNumber and partitionMaskOffset based on concurrency -func (hCtx *HashJoinCtxV2) SetupPartitionInfo() { - hCtx.partitionNumber = genHashJoinPartitionNumber(hCtx.Concurrency) - hCtx.partitionMaskOffset = getPartitionMaskOffset(hCtx.partitionNumber) -} - -// initHashTableContext create hashTableContext for current HashJoinCtxV2 -func (hCtx *HashJoinCtxV2) initHashTableContext() { - hCtx.hashTableContext = &hashTableContext{} - hCtx.hashTableContext.rowTables = make([][]*rowTable, hCtx.Concurrency) - for index := range hCtx.hashTableContext.rowTables { - hCtx.hashTableContext.rowTables[index] = make([]*rowTable, hCtx.partitionNumber) - } - hCtx.hashTableContext.hashTable = &hashTableV2{ - tables: make([]*subTable, hCtx.partitionNumber), - partitionNumber: uint64(hCtx.partitionNumber), - } - hCtx.hashTableContext.memoryTracker = memory.NewTracker(memory.LabelForHashTableInHashJoinV2, -1) -} - -// ProbeSideTupleFetcherV2 reads tuples from ProbeSideExec and send them to ProbeWorkers. -type ProbeSideTupleFetcherV2 struct { - probeSideTupleFetcherBase - *HashJoinCtxV2 - canSkipProbeIfHashTableIsEmpty bool -} - -// ProbeWorkerV2 is the probe worker used in hash join v2 -type ProbeWorkerV2 struct { - probeWorkerBase - HashJoinCtx *HashJoinCtxV2 - // We build individual joinProbe for each join worker when use chunk-based - // execution, to avoid the concurrency of joiner.chk and joiner.selected. - JoinProbe ProbeV2 -} - -// BuildWorkerV2 is the build worker used in hash join v2 -type BuildWorkerV2 struct { - buildWorkerBase - HashJoinCtx *HashJoinCtxV2 - BuildTypes []*types.FieldType - HasNullableKey bool - WorkerID uint -} - -// NewJoinBuildWorkerV2 create a BuildWorkerV2 -func NewJoinBuildWorkerV2(ctx *HashJoinCtxV2, workID uint, buildSideExec exec.Executor, buildKeyColIdx []int, buildTypes []*types.FieldType) *BuildWorkerV2 { - hasNullableKey := false - for _, idx := range buildKeyColIdx { - if !mysql.HasNotNullFlag(buildTypes[idx].GetFlag()) { - hasNullableKey = true - break - } - } - worker := &BuildWorkerV2{ - HashJoinCtx: ctx, - BuildTypes: buildTypes, - WorkerID: workID, - HasNullableKey: hasNullableKey, - } - worker.BuildSideExec = buildSideExec - worker.BuildKeyColIdx = buildKeyColIdx - return worker -} - -// HashJoinV2Exec implements the hash join algorithm. -type HashJoinV2Exec struct { - exec.BaseExecutor - *HashJoinCtxV2 - - ProbeSideTupleFetcher *ProbeSideTupleFetcherV2 - ProbeWorkers []*ProbeWorkerV2 - BuildWorkers []*BuildWorkerV2 - - workerWg util.WaitGroupWrapper - waiterWg util.WaitGroupWrapper - - prepared bool -} - -// Close implements the Executor Close interface. 
-func (e *HashJoinV2Exec) Close() error { - if e.closeCh != nil { - close(e.closeCh) - } - e.finished.Store(true) - if e.prepared { - if e.buildFinished != nil { - channel.Clear(e.buildFinished) - } - if e.joinResultCh != nil { - channel.Clear(e.joinResultCh) - } - if e.ProbeSideTupleFetcher.probeChkResourceCh != nil { - close(e.ProbeSideTupleFetcher.probeChkResourceCh) - channel.Clear(e.ProbeSideTupleFetcher.probeChkResourceCh) - } - for i := range e.ProbeSideTupleFetcher.probeResultChs { - channel.Clear(e.ProbeSideTupleFetcher.probeResultChs[i]) - } - for i := range e.ProbeWorkers { - close(e.ProbeWorkers[i].joinChkResourceCh) - channel.Clear(e.ProbeWorkers[i].joinChkResourceCh) - } - e.ProbeSideTupleFetcher.probeChkResourceCh = nil - e.waiterWg.Wait() - e.hashTableContext.reset() - } - for _, w := range e.ProbeWorkers { - w.joinChkResourceCh = nil - } - - if e.stats != nil { - defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), e.stats) - } - err := e.BaseExecutor.Close() - return err -} - -// Open implements the Executor Open interface. -func (e *HashJoinV2Exec) Open(ctx context.Context) error { - if err := e.BaseExecutor.Open(ctx); err != nil { - e.closeCh = nil - e.prepared = false - return err - } - e.prepared = false - needScanRowTableAfterProbeDone := e.ProbeWorkers[0].JoinProbe.NeedScanRowTable() - e.HashJoinCtxV2.needScanRowTableAfterProbeDone = needScanRowTableAfterProbeDone - if e.RightAsBuildSide { - e.hashTableMeta = newTableMeta(e.BuildWorkers[0].BuildKeyColIdx, e.BuildWorkers[0].BuildTypes, - e.BuildKeyTypes, e.ProbeKeyTypes, e.RUsedInOtherCondition, e.RUsed, needScanRowTableAfterProbeDone) - } else { - e.hashTableMeta = newTableMeta(e.BuildWorkers[0].BuildKeyColIdx, e.BuildWorkers[0].BuildTypes, - e.BuildKeyTypes, e.ProbeKeyTypes, e.LUsedInOtherCondition, e.LUsed, needScanRowTableAfterProbeDone) - } - e.HashJoinCtxV2.ChunkAllocPool = e.AllocPool - if e.memTracker != nil { - e.memTracker.Reset() - } else { - e.memTracker = memory.NewTracker(e.ID(), -1) - } - e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) - - e.diskTracker = disk.NewTracker(e.ID(), -1) - e.diskTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.DiskTracker) - - e.workerWg = util.WaitGroupWrapper{} - e.waiterWg = util.WaitGroupWrapper{} - e.closeCh = make(chan struct{}) - e.finished.Store(false) - - if e.RuntimeStats() != nil { - e.stats = &hashJoinRuntimeStatsV2{} - e.stats.concurrent = int(e.Concurrency) - } - return nil -} - -func (fetcher *ProbeSideTupleFetcherV2) shouldLimitProbeFetchSize() bool { - if fetcher.JoinType == plannercore.LeftOuterJoin && fetcher.RightAsBuildSide { - return true - } - if fetcher.JoinType == plannercore.RightOuterJoin && !fetcher.RightAsBuildSide { - return true - } - return false -} - -func (w *BuildWorkerV2) splitPartitionAndAppendToRowTable(typeCtx types.Context, srcChkCh chan *chunk.Chunk) (err error) { - cost := int64(0) - defer func() { - if w.HashJoinCtx.stats != nil { - atomic.AddInt64(&w.HashJoinCtx.stats.partitionData, cost) - setMaxValue(&w.HashJoinCtx.stats.maxPartitionData, cost) - } - }() - partitionNumber := w.HashJoinCtx.partitionNumber - hashJoinCtx := w.HashJoinCtx - - builder := createRowTableBuilder(w.BuildKeyColIdx, hashJoinCtx.BuildKeyTypes, partitionNumber, w.HasNullableKey, hashJoinCtx.BuildFilter != nil, hashJoinCtx.needScanRowTableAfterProbeDone) - - for chk := range srcChkCh { - start := time.Now() - err = builder.processOneChunk(chk, typeCtx, w.HashJoinCtx, int(w.WorkerID)) - 
failpoint.Inject("splitPartitionPanic", nil) - cost += int64(time.Since(start)) - if err != nil { - return err - } - } - start := time.Now() - builder.appendRemainingRowLocations(int(w.WorkerID), w.HashJoinCtx.hashTableContext) - cost += int64(time.Since(start)) - return nil -} - -func (e *HashJoinV2Exec) canSkipProbeIfHashTableIsEmpty() bool { - switch e.JoinType { - case plannercore.InnerJoin: - return true - case plannercore.LeftOuterJoin: - return !e.RightAsBuildSide - case plannercore.RightOuterJoin: - return e.RightAsBuildSide - case plannercore.SemiJoin: - return e.RightAsBuildSide - default: - return false - } -} - -func (e *HashJoinV2Exec) initializeForProbe() { - e.ProbeSideTupleFetcher.HashJoinCtxV2 = e.HashJoinCtxV2 - // e.joinResultCh is for transmitting the join result chunks to the main - // thread. - e.joinResultCh = make(chan *hashjoinWorkerResult, e.Concurrency+1) - e.ProbeSideTupleFetcher.initializeForProbeBase(e.Concurrency, e.joinResultCh) - e.ProbeSideTupleFetcher.canSkipProbeIfHashTableIsEmpty = e.canSkipProbeIfHashTableIsEmpty() - - for i := uint(0); i < e.Concurrency; i++ { - e.ProbeWorkers[i].initializeForProbe(e.ProbeSideTupleFetcher.probeChkResourceCh, e.ProbeSideTupleFetcher.probeResultChs[i], e) - e.ProbeWorkers[i].JoinProbe.ResetProbeCollision() - } -} - -func (e *HashJoinV2Exec) fetchAndProbeHashTable(ctx context.Context) { - e.initializeForProbe() - fetchProbeSideChunksFunc := func() { - defer trace.StartRegion(ctx, "HashJoinProbeSideFetcher").End() - e.ProbeSideTupleFetcher.fetchProbeSideChunks( - ctx, - e.MaxChunkSize(), - func() bool { return e.ProbeSideTupleFetcher.hashTableContext.hashTable.isHashTableEmpty() }, - e.ProbeSideTupleFetcher.canSkipProbeIfHashTableIsEmpty, - e.ProbeSideTupleFetcher.needScanRowTableAfterProbeDone, - e.ProbeSideTupleFetcher.shouldLimitProbeFetchSize(), - &e.ProbeSideTupleFetcher.hashJoinCtxBase) - } - e.workerWg.RunWithRecover(fetchProbeSideChunksFunc, e.ProbeSideTupleFetcher.handleProbeSideFetcherPanic) - - for i := uint(0); i < e.Concurrency; i++ { - workerID := i - e.workerWg.RunWithRecover(func() { - defer trace.StartRegion(ctx, "HashJoinWorker").End() - e.ProbeWorkers[workerID].runJoinWorker() - }, e.ProbeWorkers[workerID].handleProbeWorkerPanic) - } - e.waiterWg.RunWithRecover(e.waitJoinWorkersAndCloseResultChan, nil) -} - -func (w *ProbeWorkerV2) handleProbeWorkerPanic(r any) { - if r != nil { - w.HashJoinCtx.joinResultCh <- &hashjoinWorkerResult{err: util.GetRecoverError(r)} - } -} - -func (e *HashJoinV2Exec) handleJoinWorkerPanic(r any) { - if r != nil { - e.joinResultCh <- &hashjoinWorkerResult{err: util.GetRecoverError(r)} - } -} - -func (e *HashJoinV2Exec) waitJoinWorkersAndCloseResultChan() { - e.workerWg.Wait() - if e.stats != nil { - for _, prober := range e.ProbeWorkers { - e.stats.hashStat.probeCollision += int64(prober.JoinProbe.GetProbeCollision()) - } - } - if e.ProbeWorkers[0] != nil && e.ProbeWorkers[0].JoinProbe.NeedScanRowTable() { - for i := uint(0); i < e.Concurrency; i++ { - var workerID = i - e.workerWg.RunWithRecover(func() { - e.ProbeWorkers[workerID].scanRowTableAfterProbeDone() - }, e.handleJoinWorkerPanic) - } - e.workerWg.Wait() - } - close(e.joinResultCh) -} - -func (w *ProbeWorkerV2) scanRowTableAfterProbeDone() { - w.JoinProbe.InitForScanRowTable() - ok, joinResult := w.getNewJoinResult() - if !ok { - return - } - for !w.JoinProbe.IsScanRowTableDone() { - joinResult = w.JoinProbe.ScanRowTable(joinResult, &w.HashJoinCtx.SessCtx.GetSessionVars().SQLKiller) - if joinResult.err != nil { - 
w.HashJoinCtx.joinResultCh <- joinResult - return - } - if joinResult.chk.IsFull() { - w.HashJoinCtx.joinResultCh <- joinResult - ok, joinResult = w.getNewJoinResult() - if !ok { - return - } - } - } - if joinResult == nil { - return - } else if joinResult.err != nil || (joinResult.chk != nil && joinResult.chk.NumRows() > 0) { - w.HashJoinCtx.joinResultCh <- joinResult - } -} - -func (w *ProbeWorkerV2) processOneProbeChunk(probeChunk *chunk.Chunk, joinResult *hashjoinWorkerResult) (ok bool, waitTime int64, _ *hashjoinWorkerResult) { - waitTime = 0 - joinResult.err = w.JoinProbe.SetChunkForProbe(probeChunk) - if joinResult.err != nil { - return false, waitTime, joinResult - } - for !w.JoinProbe.IsCurrentChunkProbeDone() { - ok, joinResult = w.JoinProbe.Probe(joinResult, &w.HashJoinCtx.SessCtx.GetSessionVars().SQLKiller) - if !ok || joinResult.err != nil { - return ok, waitTime, joinResult - } - failpoint.Inject("processOneProbeChunkPanic", nil) - if joinResult.chk.IsFull() { - waitStart := time.Now() - w.HashJoinCtx.joinResultCh <- joinResult - ok, joinResult = w.getNewJoinResult() - waitTime += int64(time.Since(waitStart)) - if !ok { - return false, waitTime, joinResult - } - } - } - return true, waitTime, joinResult -} - -func (w *ProbeWorkerV2) runJoinWorker() { - probeTime := int64(0) - if w.HashJoinCtx.stats != nil { - start := time.Now() - defer func() { - t := time.Since(start) - atomic.AddInt64(&w.HashJoinCtx.stats.probe, probeTime) - atomic.AddInt64(&w.HashJoinCtx.stats.fetchAndProbe, int64(t)) - setMaxValue(&w.HashJoinCtx.stats.maxFetchAndProbe, int64(t)) - }() - } - - var ( - probeSideResult *chunk.Chunk - ) - ok, joinResult := w.getNewJoinResult() - if !ok { - return - } - - // Read and filter probeSideResult, and join the probeSideResult with the build side rows. - emptyProbeSideResult := &probeChkResource{ - dest: w.probeResultCh, - } - for ok := true; ok; { - if w.HashJoinCtx.finished.Load() { - break - } - select { - case <-w.HashJoinCtx.closeCh: - return - case probeSideResult, ok = <-w.probeResultCh: - } - failpoint.Inject("ConsumeRandomPanic", nil) - if !ok { - break - } - - start := time.Now() - waitTime := int64(0) - ok, waitTime, joinResult = w.processOneProbeChunk(probeSideResult, joinResult) - probeTime += int64(time.Since(start)) - waitTime - if !ok { - break - } - probeSideResult.Reset() - emptyProbeSideResult.chk = probeSideResult - w.probeChkResourceCh <- emptyProbeSideResult - } - // note joinResult.chk may be nil when getNewJoinResult fails in loops - if joinResult == nil { - return - } else if joinResult.err != nil || (joinResult.chk != nil && joinResult.chk.NumRows() > 0) { - w.HashJoinCtx.joinResultCh <- joinResult - } else if joinResult.chk != nil && joinResult.chk.NumRows() == 0 { - w.joinChkResourceCh <- joinResult.chk - } -} - -func (w *ProbeWorkerV2) getNewJoinResult() (bool, *hashjoinWorkerResult) { - joinResult := &hashjoinWorkerResult{ - src: w.joinChkResourceCh, - } - ok := true - select { - case <-w.HashJoinCtx.closeCh: - ok = false - case joinResult.chk, ok = <-w.joinChkResourceCh: - } - return ok, joinResult -} - -// Next implements the Executor Next interface. -// hash join constructs the result following these steps: -// step 1. fetch data from build side child and build a hash table; -// step 2. fetch data from probe child in a background goroutine and probe the hash table in multiple join workers. 
-func (e *HashJoinV2Exec) Next(ctx context.Context, req *chunk.Chunk) (err error) { - if !e.prepared { - e.initHashTableContext() - e.hashTableContext.memoryTracker.AttachTo(e.memTracker) - e.buildFinished = make(chan error, 1) - e.workerWg.RunWithRecover(func() { - defer trace.StartRegion(ctx, "HashJoinHashTableBuilder").End() - e.fetchAndBuildHashTable(ctx) - }, e.handleFetchAndBuildHashTablePanic) - e.fetchAndProbeHashTable(ctx) - e.prepared = true - } - if e.ProbeSideTupleFetcher.shouldLimitProbeFetchSize() { - atomic.StoreInt64(&e.ProbeSideTupleFetcher.requiredRows, int64(req.RequiredRows())) - } - req.Reset() - - result, ok := <-e.joinResultCh - if !ok { - return nil - } - if result.err != nil { - e.finished.Store(true) - return result.err - } - req.SwapColumns(result.chk) - result.src <- result.chk - return nil -} - -func (e *HashJoinV2Exec) handleFetchAndBuildHashTablePanic(r any) { - if r != nil { - e.buildFinished <- util.GetRecoverError(r) - } - close(e.buildFinished) -} - -// checkBalance checks whether the segment count of each partition is balanced. -func (e *HashJoinV2Exec) checkBalance(totalSegmentCnt int) bool { - isBalanced := e.Concurrency == e.partitionNumber - if !isBalanced { - return false - } - avgSegCnt := totalSegmentCnt / int(e.partitionNumber) - balanceThreshold := int(float64(avgSegCnt) * 0.8) - subTables := e.HashJoinCtxV2.hashTableContext.hashTable.tables - - for _, subTable := range subTables { - if math.Abs(float64(len(subTable.rowData.segments)-avgSegCnt)) > float64(balanceThreshold) { - isBalanced = false - break - } - } - return isBalanced -} - -func (e *HashJoinV2Exec) createTasks(buildTaskCh chan<- *buildTask, totalSegmentCnt int, doneCh chan struct{}) { - isBalanced := e.checkBalance(totalSegmentCnt) - segStep := max(1, totalSegmentCnt/int(e.Concurrency)) - subTables := e.HashJoinCtxV2.hashTableContext.hashTable.tables - createBuildTask := func(partIdx int, segStartIdx int, segEndIdx int) *buildTask { - return &buildTask{partitionIdx: partIdx, segStartIdx: segStartIdx, segEndIdx: segEndIdx} - } - failpoint.Inject("createTasksPanic", nil) - - if isBalanced { - for partIdx, subTable := range subTables { - segmentsLen := len(subTable.rowData.segments) - select { - case <-doneCh: - return - case buildTaskCh <- createBuildTask(partIdx, 0, segmentsLen): - } - } - return - } - - partitionStartIndex := make([]int, len(subTables)) - partitionSegmentLength := make([]int, len(subTables)) - for i := 0; i < len(subTables); i++ { - partitionStartIndex[i] = 0 - partitionSegmentLength[i] = len(subTables[i].rowData.segments) - } - - for { - hasNewTask := false - for partIdx := range subTables { - // create table by round-robin all the partitions so the build thread is likely to build different partition at the same time - if partitionStartIndex[partIdx] < partitionSegmentLength[partIdx] { - startIndex := partitionStartIndex[partIdx] - endIndex := min(startIndex+segStep, partitionSegmentLength[partIdx]) - select { - case <-doneCh: - return - case buildTaskCh <- createBuildTask(partIdx, startIndex, endIndex): - } - partitionStartIndex[partIdx] = endIndex - hasNewTask = true - } - } - if !hasNewTask { - break - } - } -} - -func (e *HashJoinV2Exec) fetchAndBuildHashTable(ctx context.Context) { - if e.stats != nil { - start := time.Now() - defer func() { - e.stats.fetchAndBuildHashTable = time.Since(start) - }() - } - - waitJobDone := func(wg *sync.WaitGroup, errCh chan error) bool { - wg.Wait() - close(errCh) - if err := <-errCh; err != nil { - e.buildFinished <- err - 
return false - } - return true - } - - wg := new(sync.WaitGroup) - errCh := make(chan error, 1+e.Concurrency) - // doneCh is used by the consumer(splitAndAppendToRowTable) to info the producer(fetchBuildSideRows) that the consumer meet error and stop consume data - doneCh := make(chan struct{}, e.Concurrency) - srcChkCh := e.fetchBuildSideRows(ctx, wg, errCh, doneCh) - e.splitAndAppendToRowTable(srcChkCh, wg, errCh, doneCh) - success := waitJobDone(wg, errCh) - if !success { - return - } - - totalSegmentCnt := e.hashTableContext.mergeRowTablesToHashTable(e.hashTableMeta, e.partitionNumber) - - wg = new(sync.WaitGroup) - errCh = make(chan error, 1+e.Concurrency) - // doneCh is used by the consumer(buildHashTable) to info the producer(createBuildTasks) that the consumer meet error and stop consume data - doneCh = make(chan struct{}, e.Concurrency) - buildTaskCh := e.createBuildTasks(totalSegmentCnt, wg, errCh, doneCh) - e.buildHashTable(buildTaskCh, wg, errCh, doneCh) - waitJobDone(wg, errCh) -} - -func (e *HashJoinV2Exec) fetchBuildSideRows(ctx context.Context, wg *sync.WaitGroup, errCh chan error, doneCh chan struct{}) chan *chunk.Chunk { - srcChkCh := make(chan *chunk.Chunk, 1) - wg.Add(1) - e.workerWg.RunWithRecover( - func() { - defer trace.StartRegion(ctx, "HashJoinBuildSideFetcher").End() - fetcher := e.BuildWorkers[0] - fetcher.fetchBuildSideRows(ctx, &fetcher.HashJoinCtx.hashJoinCtxBase, srcChkCh, errCh, doneCh) - }, - func(r any) { - if r != nil { - errCh <- util.GetRecoverError(r) - } - wg.Done() - }, - ) - return srcChkCh -} - -func (e *HashJoinV2Exec) splitAndAppendToRowTable(srcChkCh chan *chunk.Chunk, wg *sync.WaitGroup, errCh chan error, doneCh chan struct{}) { - for i := uint(0); i < e.Concurrency; i++ { - wg.Add(1) - workIndex := i - e.workerWg.RunWithRecover( - func() { - err := e.BuildWorkers[workIndex].splitPartitionAndAppendToRowTable(e.SessCtx.GetSessionVars().StmtCtx.TypeCtx(), srcChkCh) - if err != nil { - errCh <- err - doneCh <- struct{}{} - } - }, - func(r any) { - if r != nil { - errCh <- util.GetRecoverError(r) - doneCh <- struct{}{} - } - wg.Done() - }, - ) - } -} - -func (e *HashJoinV2Exec) createBuildTasks(totalSegmentCnt int, wg *sync.WaitGroup, errCh chan error, doneCh chan struct{}) chan *buildTask { - buildTaskCh := make(chan *buildTask, e.Concurrency) - wg.Add(1) - e.workerWg.RunWithRecover( - func() { e.createTasks(buildTaskCh, totalSegmentCnt, doneCh) }, - func(r any) { - if r != nil { - errCh <- util.GetRecoverError(r) - } - close(buildTaskCh) - wg.Done() - }, - ) - return buildTaskCh -} - -func (e *HashJoinV2Exec) buildHashTable(buildTaskCh chan *buildTask, wg *sync.WaitGroup, errCh chan error, doneCh chan struct{}) { - for i := uint(0); i < e.Concurrency; i++ { - wg.Add(1) - workID := i - e.workerWg.RunWithRecover( - func() { - err := e.BuildWorkers[workID].buildHashTable(buildTaskCh) - if err != nil { - errCh <- err - doneCh <- struct{}{} - } - }, - func(r any) { - if r != nil { - errCh <- util.GetRecoverError(r) - doneCh <- struct{}{} - } - wg.Done() - }, - ) - } -} - -type buildTask struct { - partitionIdx int - segStartIdx int - segEndIdx int -} - -// buildHashTableForList builds hash table from `list`. 
-func (w *BuildWorkerV2) buildHashTable(taskCh chan *buildTask) error { - cost := int64(0) - defer func() { - if w.HashJoinCtx.stats != nil { - atomic.AddInt64(&w.HashJoinCtx.stats.buildHashTable, cost) - setMaxValue(&w.HashJoinCtx.stats.maxBuildHashTable, cost) - } - }() - for task := range taskCh { - start := time.Now() - partIdx, segStartIdx, segEndIdx := task.partitionIdx, task.segStartIdx, task.segEndIdx - w.HashJoinCtx.hashTableContext.hashTable.tables[partIdx].build(segStartIdx, segEndIdx) - failpoint.Inject("buildHashTablePanic", nil) - cost += int64(time.Since(start)) - } - return nil -} - -type hashJoinRuntimeStatsV2 struct { - hashJoinRuntimeStats - partitionData int64 - maxPartitionData int64 - buildHashTable int64 - maxBuildHashTable int64 -} - -func setMaxValue(addr *int64, currentValue int64) { - for { - value := atomic.LoadInt64(addr) - if currentValue <= value { - return - } - if atomic.CompareAndSwapInt64(addr, value, currentValue) { - return - } - } -} - -// Tp implements the RuntimeStats interface. -func (*hashJoinRuntimeStatsV2) Tp() int { - return execdetails.TpHashJoinRuntimeStats -} - -func (e *hashJoinRuntimeStatsV2) String() string { - buf := bytes.NewBuffer(make([]byte, 0, 128)) - if e.fetchAndBuildHashTable > 0 { - buf.WriteString("build_hash_table:{concurrency:") - buf.WriteString(strconv.Itoa(e.concurrent)) - buf.WriteString(", total:") - buf.WriteString(execdetails.FormatDuration(e.fetchAndBuildHashTable)) - buf.WriteString(", fetch:") - buf.WriteString(execdetails.FormatDuration(time.Duration(int64(e.fetchAndBuildHashTable) - e.maxBuildHashTable - e.maxPartitionData))) - buf.WriteString(", partition:") - buf.WriteString(execdetails.FormatDuration(time.Duration(e.partitionData))) - buf.WriteString(", max partition:") - buf.WriteString(execdetails.FormatDuration(time.Duration(e.maxPartitionData))) - buf.WriteString(", build:") - buf.WriteString(execdetails.FormatDuration(time.Duration(e.buildHashTable))) - buf.WriteString(", max build:") - buf.WriteString(execdetails.FormatDuration(time.Duration(e.maxBuildHashTable))) - buf.WriteString("}") - } - if e.probe > 0 { - buf.WriteString(", probe:{concurrency:") - buf.WriteString(strconv.Itoa(e.concurrent)) - buf.WriteString(", total:") - buf.WriteString(execdetails.FormatDuration(time.Duration(e.fetchAndProbe))) - buf.WriteString(", max:") - buf.WriteString(execdetails.FormatDuration(time.Duration(atomic.LoadInt64(&e.maxFetchAndProbe)))) - buf.WriteString(", probe:") - buf.WriteString(execdetails.FormatDuration(time.Duration(e.probe))) - buf.WriteString(", fetch and wait:") - buf.WriteString(execdetails.FormatDuration(time.Duration(e.fetchAndProbe - e.probe))) - if e.hashStat.probeCollision > 0 { - buf.WriteString(", probe_collision:") - buf.WriteString(strconv.FormatInt(e.hashStat.probeCollision, 10)) - } - buf.WriteString("}") - } - return buf.String() -} - -func (e *hashJoinRuntimeStatsV2) Clone() execdetails.RuntimeStats { - stats := hashJoinRuntimeStats{ - fetchAndBuildHashTable: e.fetchAndBuildHashTable, - hashStat: e.hashStat, - fetchAndProbe: e.fetchAndProbe, - probe: e.probe, - concurrent: e.concurrent, - maxFetchAndProbe: e.maxFetchAndProbe, - } - return &hashJoinRuntimeStatsV2{ - hashJoinRuntimeStats: stats, - partitionData: e.partitionData, - maxPartitionData: e.maxPartitionData, - buildHashTable: e.buildHashTable, - maxBuildHashTable: e.maxBuildHashTable, - } -} - -func (e *hashJoinRuntimeStatsV2) Merge(rs execdetails.RuntimeStats) { - tmp, ok := rs.(*hashJoinRuntimeStatsV2) - if !ok { - return - } - 
e.fetchAndBuildHashTable += tmp.fetchAndBuildHashTable - e.buildHashTable += tmp.buildHashTable - if e.maxBuildHashTable < tmp.maxBuildHashTable { - e.maxBuildHashTable = tmp.maxBuildHashTable - } - e.partitionData += tmp.partitionData - if e.maxPartitionData < tmp.maxPartitionData { - e.maxPartitionData = tmp.maxPartitionData - } - e.hashStat.buildTableElapse += tmp.hashStat.buildTableElapse - e.hashStat.probeCollision += tmp.hashStat.probeCollision - e.fetchAndProbe += tmp.fetchAndProbe - e.probe += tmp.probe - if e.maxFetchAndProbe < tmp.maxFetchAndProbe { - e.maxFetchAndProbe = tmp.maxFetchAndProbe - } -} diff --git a/pkg/executor/join/index_lookup_hash_join.go b/pkg/executor/join/index_lookup_hash_join.go index 7d0c7f591a237..fd368ef77911d 100644 --- a/pkg/executor/join/index_lookup_hash_join.go +++ b/pkg/executor/join/index_lookup_hash_join.go @@ -338,11 +338,11 @@ func (ow *indexHashJoinOuterWorker) run(ctx context.Context) { defer trace.StartRegion(ctx, "IndexHashJoinOuterWorker").End() defer close(ow.innerCh) for { - failpoint.Eval(_curpkg_("TestIssue30211")) + failpoint.Inject("TestIssue30211", nil) task, err := ow.buildTask(ctx) - if _, _err_ := failpoint.Eval(_curpkg_("testIndexHashJoinOuterWorkerErr")); _err_ == nil { + failpoint.Inject("testIndexHashJoinOuterWorkerErr", func() { err = errors.New("mockIndexHashJoinOuterWorkerErr") - } + }) if err != nil { task = &indexHashJoinTask{err: err} if ow.keepOuterOrder { @@ -362,9 +362,9 @@ func (ow *indexHashJoinOuterWorker) run(ctx context.Context) { return } if ow.keepOuterOrder { - if _, _err_ := failpoint.Eval(_curpkg_("testIssue20779")); _err_ == nil { + failpoint.Inject("testIssue20779", func() { panic("testIssue20779") - } + }) if finished := ow.pushToChan(ctx, task, ow.taskCh); finished { return } @@ -531,9 +531,9 @@ func (iw *indexHashJoinInnerWorker) run(ctx context.Context, cancelFunc context. } } } - if _, _err_ := failpoint.Eval(_curpkg_("testIndexHashJoinInnerWorkerErr")); _err_ == nil { + failpoint.Inject("testIndexHashJoinInnerWorkerErr", func() { joinResult.err = errors.New("mockIndexHashJoinInnerWorkerErr") - } + }) // When task.KeepOuterOrder is TRUE (resultCh != iw.resultCh): // - the last joinResult will be handled when the task has been processed, // thus we DO NOT need to check it here again. @@ -572,8 +572,8 @@ func (iw *indexHashJoinInnerWorker) getNewJoinResult(ctx context.Context) (*inde } func (iw *indexHashJoinInnerWorker) buildHashTableForOuterResult(task *indexHashJoinTask, h hash.Hash64) { - failpoint.Eval(_curpkg_("IndexHashJoinBuildHashTablePanic")) - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + failpoint.Inject("IndexHashJoinBuildHashTablePanic", nil) + failpoint.Inject("ConsumeRandomPanic", nil) if iw.stats != nil { start := time.Now() defer func() { @@ -602,9 +602,9 @@ func (iw *indexHashJoinInnerWorker) buildHashTableForOuterResult(task *indexHash } h.Reset() err := codec.HashChunkRow(iw.ctx.GetSessionVars().StmtCtx.TypeCtx(), h, row, iw.outerCtx.HashTypes, hashColIdx, buf) - if _, _err_ := failpoint.Eval(_curpkg_("testIndexHashJoinBuildErr")); _err_ == nil { + failpoint.Inject("testIndexHashJoinBuildErr", func() { err = errors.New("mockIndexHashJoinBuildErr") - } + }) if err != nil { // This panic will be recovered by the invoker. 
panic(err.Error()) @@ -680,10 +680,10 @@ func (iw *indexHashJoinInnerWorker) handleTask(ctx context.Context, task *indexH iw.wg.Wait() // check error after wg.Wait to make sure error message can be sent to // resultCh even if panic happen in buildHashTableForOuterResult. - if _, _err_ := failpoint.Eval(_curpkg_("IndexHashJoinFetchInnerResultsErr")); _err_ == nil { + failpoint.Inject("IndexHashJoinFetchInnerResultsErr", func() { err = errors.New("IndexHashJoinFetchInnerResultsErr") - } - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + }) + failpoint.Inject("ConsumeRandomPanic", nil) if err != nil { return err } @@ -789,7 +789,7 @@ func (iw *indexHashJoinInnerWorker) joinMatchedInnerRow2Chunk(ctx context.Contex joinResult.err = ctx.Err() return false, joinResult } - failpoint.Call(_curpkg_("joinMatchedInnerRow2Chunk")) + failpoint.InjectCall("joinMatchedInnerRow2Chunk") joinResult, ok = iw.getNewJoinResult(ctx) if !ok { return false, joinResult @@ -837,9 +837,9 @@ func (iw *indexHashJoinInnerWorker) doJoinInOrder(ctx context.Context, task *ind row := chk.GetRow(j) ptr := chunk.RowPtr{ChkIdx: uint32(i), RowIdx: uint32(j)} err = iw.collectMatchedInnerPtrs4OuterRows(row, ptr, task, h, iw.joinKeyBuf) - if _, _err_ := failpoint.Eval(_curpkg_("TestIssue31129")); _err_ == nil { + failpoint.Inject("TestIssue31129", func() { err = errors.New("TestIssue31129") - } + }) if err != nil { return err } diff --git a/pkg/executor/join/index_lookup_hash_join.go__failpoint_stash__ b/pkg/executor/join/index_lookup_hash_join.go__failpoint_stash__ deleted file mode 100644 index fd368ef77911d..0000000000000 --- a/pkg/executor/join/index_lookup_hash_join.go__failpoint_stash__ +++ /dev/null @@ -1,884 +0,0 @@ -// Copyright 2019 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package join - -import ( - "context" - "fmt" - "hash" - "hash/fnv" - "runtime/trace" - "sync" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/channel" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/memory" - "github.com/pingcap/tidb/pkg/util/ranger" -) - -// numResChkHold indicates the number of resource chunks that an inner worker -// holds at the same time. -// It's used in 2 cases individually: -// 1. IndexMergeJoin -// 2. IndexNestedLoopHashJoin: -// It's used when IndexNestedLoopHashJoin.KeepOuterOrder is true. -// Otherwise, there will be at most `concurrency` resource chunks throughout -// the execution of IndexNestedLoopHashJoin. -const numResChkHold = 4 - -// IndexNestedLoopHashJoin employs one outer worker and N inner workers to -// execute concurrently. The output order is not promised. -// -// The execution flow is very similar to IndexLookUpReader: -// 1. 
The outer worker reads N outer rows, builds a task and sends it to the -// inner worker channel. -// 2. The inner worker receives the tasks and does 3 things for every task: -// 1. builds hash table from the outer rows -// 2. builds key ranges from outer rows and fetches inner rows -// 3. probes the hash table and sends the join result to the main thread channel. -// Note: step 1 and step 2 runs concurrently. -// -// 3. The main thread receives the join results. -type IndexNestedLoopHashJoin struct { - IndexLookUpJoin - resultCh chan *indexHashJoinResult - joinChkResourceCh []chan *chunk.Chunk - // We build individual joiner for each inner worker when using chunk-based - // execution, to avoid the concurrency of joiner.chk and joiner.selected. - Joiners []Joiner - KeepOuterOrder bool - curTask *indexHashJoinTask - // taskCh is only used when `KeepOuterOrder` is true. - taskCh chan *indexHashJoinTask - - stats *indexLookUpJoinRuntimeStats - prepared bool - // panicErr records the error generated by panic recover. This is introduced to - // return the actual error message instead of `context cancelled` to the client. - panicErr error - ctxWithCancel context.Context -} - -type indexHashJoinOuterWorker struct { - outerWorker - innerCh chan *indexHashJoinTask - keepOuterOrder bool - // taskCh is only used when the outer order needs to be promised. - taskCh chan *indexHashJoinTask -} - -type indexHashJoinInnerWorker struct { - innerWorker - joiner Joiner - joinChkResourceCh chan *chunk.Chunk - // resultCh is valid only when indexNestedLoopHashJoin do not need to keep - // order. Otherwise, it will be nil. - resultCh chan *indexHashJoinResult - taskCh <-chan *indexHashJoinTask - wg *sync.WaitGroup - joinKeyBuf []byte - outerRowStatus []outerRowStatusFlag - rowIter *chunk.Iterator4Slice -} - -type indexHashJoinResult struct { - chk *chunk.Chunk - err error - src chan<- *chunk.Chunk -} - -type indexHashJoinTask struct { - *lookUpJoinTask - outerRowStatus [][]outerRowStatusFlag - lookupMap BaseHashTable - err error - keepOuterOrder bool - // resultCh is only used when the outer order needs to be promised. - resultCh chan *indexHashJoinResult - // matchedInnerRowPtrs is only valid when the outer order needs to be - // promised. Otherwise, it will be nil. - // len(matchedInnerRowPtrs) equals to - // lookUpJoinTask.outerResult.NumChunks(), and the elements of every - // matchedInnerRowPtrs[chkIdx][rowIdx] indicates the matched inner row ptrs - // of the corresponding outer row. - matchedInnerRowPtrs [][][]chunk.RowPtr -} - -// Open implements the IndexNestedLoopHashJoin Executor interface. 
-func (e *IndexNestedLoopHashJoin) Open(ctx context.Context) error { - err := exec.Open(ctx, e.Children(0)) - if err != nil { - return err - } - if e.memTracker != nil { - e.memTracker.Reset() - } else { - e.memTracker = memory.NewTracker(e.ID(), -1) - } - e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) - e.cancelFunc = nil - e.innerPtrBytes = make([][]byte, 0, 8) - if e.RuntimeStats() != nil { - e.stats = &indexLookUpJoinRuntimeStats{} - } - e.Finished.Store(false) - return nil -} - -func (e *IndexNestedLoopHashJoin) startWorkers(ctx context.Context, initBatchSize int) { - concurrency := e.Ctx().GetSessionVars().IndexLookupJoinConcurrency() - if e.stats != nil { - e.stats.concurrency = concurrency - } - workerCtx, cancelFunc := context.WithCancel(ctx) - e.ctxWithCancel, e.cancelFunc = workerCtx, cancelFunc - innerCh := make(chan *indexHashJoinTask, concurrency) - if e.KeepOuterOrder { - e.taskCh = make(chan *indexHashJoinTask, concurrency) - // When `KeepOuterOrder` is true, each task holds their own `resultCh` - // individually, thus we do not need a global resultCh. - e.resultCh = nil - } else { - e.resultCh = make(chan *indexHashJoinResult, concurrency) - } - e.joinChkResourceCh = make([]chan *chunk.Chunk, concurrency) - e.WorkerWg.Add(1) - ow := e.newOuterWorker(innerCh, initBatchSize) - go util.WithRecovery(func() { ow.run(e.ctxWithCancel) }, e.finishJoinWorkers) - - for i := 0; i < concurrency; i++ { - if !e.KeepOuterOrder { - e.joinChkResourceCh[i] = make(chan *chunk.Chunk, 1) - e.joinChkResourceCh[i] <- exec.NewFirstChunk(e) - } else { - e.joinChkResourceCh[i] = make(chan *chunk.Chunk, numResChkHold) - for j := 0; j < numResChkHold; j++ { - e.joinChkResourceCh[i] <- exec.NewFirstChunk(e) - } - } - } - - e.WorkerWg.Add(concurrency) - for i := 0; i < concurrency; i++ { - workerID := i - go util.WithRecovery(func() { e.newInnerWorker(innerCh, workerID).run(e.ctxWithCancel, cancelFunc) }, e.finishJoinWorkers) - } - go e.wait4JoinWorkers() -} - -func (e *IndexNestedLoopHashJoin) finishJoinWorkers(r any) { - if r != nil { - e.IndexLookUpJoin.Finished.Store(true) - err := fmt.Errorf("%v", r) - if recoverdErr, ok := r.(error); ok { - err = recoverdErr - } - if !e.KeepOuterOrder { - e.resultCh <- &indexHashJoinResult{err: err} - } else { - task := &indexHashJoinTask{err: err} - e.taskCh <- task - } - e.panicErr = err - if e.cancelFunc != nil { - e.cancelFunc() - } - } - e.WorkerWg.Done() -} - -func (e *IndexNestedLoopHashJoin) wait4JoinWorkers() { - e.WorkerWg.Wait() - if e.resultCh != nil { - close(e.resultCh) - } - if e.taskCh != nil { - close(e.taskCh) - } -} - -// Next implements the IndexNestedLoopHashJoin Executor interface. 
-func (e *IndexNestedLoopHashJoin) Next(ctx context.Context, req *chunk.Chunk) error { - if !e.prepared { - e.startWorkers(ctx, req.RequiredRows()) - e.prepared = true - } - req.Reset() - if e.KeepOuterOrder { - return e.runInOrder(e.ctxWithCancel, req) - } - return e.runUnordered(e.ctxWithCancel, req) -} - -func (e *IndexNestedLoopHashJoin) runInOrder(ctx context.Context, req *chunk.Chunk) error { - for { - if e.isDryUpTasks(ctx) { - return e.panicErr - } - if e.curTask.err != nil { - return e.curTask.err - } - result, err := e.getResultFromChannel(ctx, e.curTask.resultCh) - if err != nil { - return err - } - if result == nil { - e.curTask = nil - continue - } - return e.handleResult(req, result) - } -} - -func (e *IndexNestedLoopHashJoin) runUnordered(ctx context.Context, req *chunk.Chunk) error { - result, err := e.getResultFromChannel(ctx, e.resultCh) - if err != nil { - return err - } - return e.handleResult(req, result) -} - -// isDryUpTasks indicates whether all the tasks have been processed. -func (e *IndexNestedLoopHashJoin) isDryUpTasks(ctx context.Context) bool { - if e.curTask != nil { - return false - } - var ok bool - select { - case e.curTask, ok = <-e.taskCh: - if !ok { - return true - } - case <-ctx.Done(): - return true - } - return false -} - -func (e *IndexNestedLoopHashJoin) getResultFromChannel(ctx context.Context, resultCh <-chan *indexHashJoinResult) (*indexHashJoinResult, error) { - var ( - result *indexHashJoinResult - ok bool - ) - select { - case result, ok = <-resultCh: - if !ok { - return nil, nil - } - if result.err != nil { - return nil, result.err - } - case <-ctx.Done(): - err := e.panicErr - if err == nil { - err = ctx.Err() - } - return nil, err - } - return result, nil -} - -func (*IndexNestedLoopHashJoin) handleResult(req *chunk.Chunk, result *indexHashJoinResult) error { - if result == nil { - return nil - } - req.SwapColumns(result.chk) - result.src <- result.chk - return nil -} - -// Close implements the IndexNestedLoopHashJoin Executor interface. -func (e *IndexNestedLoopHashJoin) Close() error { - if e.stats != nil { - defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), e.stats) - } - if e.cancelFunc != nil { - e.cancelFunc() - } - if e.resultCh != nil { - channel.Clear(e.resultCh) - e.resultCh = nil - } - if e.taskCh != nil { - channel.Clear(e.taskCh) - e.taskCh = nil - } - for i := range e.joinChkResourceCh { - close(e.joinChkResourceCh[i]) - } - e.joinChkResourceCh = nil - e.Finished.Store(false) - e.prepared = false - e.ctxWithCancel = nil - return e.BaseExecutor.Close() -} - -func (ow *indexHashJoinOuterWorker) run(ctx context.Context) { - defer trace.StartRegion(ctx, "IndexHashJoinOuterWorker").End() - defer close(ow.innerCh) - for { - failpoint.Inject("TestIssue30211", nil) - task, err := ow.buildTask(ctx) - failpoint.Inject("testIndexHashJoinOuterWorkerErr", func() { - err = errors.New("mockIndexHashJoinOuterWorkerErr") - }) - if err != nil { - task = &indexHashJoinTask{err: err} - if ow.keepOuterOrder { - // The outerBuilder and innerFetcher run concurrently, we may - // get 2 errors at simultaneously. Thus the capacity of task.resultCh - // needs to be initialized to 2 to avoid waiting. 
- task.keepOuterOrder, task.resultCh = true, make(chan *indexHashJoinResult, 2) - ow.pushToChan(ctx, task, ow.taskCh) - } - ow.pushToChan(ctx, task, ow.innerCh) - return - } - if task == nil { - return - } - if finished := ow.pushToChan(ctx, task, ow.innerCh); finished { - return - } - if ow.keepOuterOrder { - failpoint.Inject("testIssue20779", func() { - panic("testIssue20779") - }) - if finished := ow.pushToChan(ctx, task, ow.taskCh); finished { - return - } - } - } -} - -func (ow *indexHashJoinOuterWorker) buildTask(ctx context.Context) (*indexHashJoinTask, error) { - task, err := ow.outerWorker.buildTask(ctx) - if task == nil || err != nil { - return nil, err - } - var ( - resultCh chan *indexHashJoinResult - matchedInnerRowPtrs [][][]chunk.RowPtr - ) - if ow.keepOuterOrder { - resultCh = make(chan *indexHashJoinResult, numResChkHold) - matchedInnerRowPtrs = make([][][]chunk.RowPtr, task.outerResult.NumChunks()) - for i := range matchedInnerRowPtrs { - matchedInnerRowPtrs[i] = make([][]chunk.RowPtr, task.outerResult.GetChunk(i).NumRows()) - } - } - numChks := task.outerResult.NumChunks() - outerRowStatus := make([][]outerRowStatusFlag, numChks) - for i := 0; i < numChks; i++ { - outerRowStatus[i] = make([]outerRowStatusFlag, task.outerResult.GetChunk(i).NumRows()) - } - return &indexHashJoinTask{ - lookUpJoinTask: task, - outerRowStatus: outerRowStatus, - keepOuterOrder: ow.keepOuterOrder, - resultCh: resultCh, - matchedInnerRowPtrs: matchedInnerRowPtrs, - }, nil -} - -func (*indexHashJoinOuterWorker) pushToChan(ctx context.Context, task *indexHashJoinTask, dst chan<- *indexHashJoinTask) bool { - select { - case <-ctx.Done(): - return true - case dst <- task: - } - return false -} - -func (e *IndexNestedLoopHashJoin) newOuterWorker(innerCh chan *indexHashJoinTask, initBatchSize int) *indexHashJoinOuterWorker { - maxBatchSize := e.Ctx().GetSessionVars().IndexJoinBatchSize - batchSize := min(initBatchSize, maxBatchSize) - ow := &indexHashJoinOuterWorker{ - outerWorker: outerWorker{ - OuterCtx: e.OuterCtx, - ctx: e.Ctx(), - executor: e.Children(0), - batchSize: batchSize, - maxBatchSize: maxBatchSize, - parentMemTracker: e.memTracker, - lookup: &e.IndexLookUpJoin, - }, - innerCh: innerCh, - keepOuterOrder: e.KeepOuterOrder, - taskCh: e.taskCh, - } - return ow -} - -func (e *IndexNestedLoopHashJoin) newInnerWorker(taskCh chan *indexHashJoinTask, workerID int) *indexHashJoinInnerWorker { - // Since multiple inner workers run concurrently, we should copy join's IndexRanges for every worker to avoid data race. 
- copiedRanges := make([]*ranger.Range, 0, len(e.IndexRanges.Range())) - for _, ran := range e.IndexRanges.Range() { - copiedRanges = append(copiedRanges, ran.Clone()) - } - var innerStats *innerWorkerRuntimeStats - if e.stats != nil { - innerStats = &e.stats.innerWorker - } - iw := &indexHashJoinInnerWorker{ - innerWorker: innerWorker{ - InnerCtx: e.InnerCtx, - outerCtx: e.OuterCtx, - ctx: e.Ctx(), - executorChk: e.AllocPool.Alloc(e.InnerCtx.RowTypes, e.MaxChunkSize(), e.MaxChunkSize()), - indexRanges: copiedRanges, - keyOff2IdxOff: e.KeyOff2IdxOff, - stats: innerStats, - lookup: &e.IndexLookUpJoin, - memTracker: memory.NewTracker(memory.LabelForIndexJoinInnerWorker, -1), - }, - taskCh: taskCh, - joiner: e.Joiners[workerID], - joinChkResourceCh: e.joinChkResourceCh[workerID], - resultCh: e.resultCh, - joinKeyBuf: make([]byte, 1), - outerRowStatus: make([]outerRowStatusFlag, 0, e.MaxChunkSize()), - rowIter: chunk.NewIterator4Slice([]chunk.Row{}), - } - iw.memTracker.AttachTo(e.memTracker) - if len(copiedRanges) != 0 { - // We should not consume this memory usage in `iw.memTracker`. The - // memory usage of inner worker will be reset the end of iw.handleTask. - // While the life cycle of this memory consumption exists throughout the - // whole active period of inner worker. - e.Ctx().GetSessionVars().StmtCtx.MemTracker.Consume(2 * types.EstimatedMemUsage(copiedRanges[0].LowVal, len(copiedRanges))) - } - if e.LastColHelper != nil { - // nextCwf.TmpConstant needs to be reset for every individual - // inner worker to avoid data race when the inner workers is running - // concurrently. - nextCwf := *e.LastColHelper - nextCwf.TmpConstant = make([]*expression.Constant, len(e.LastColHelper.TmpConstant)) - for i := range e.LastColHelper.TmpConstant { - nextCwf.TmpConstant[i] = &expression.Constant{RetType: nextCwf.TargetCol.RetType} - } - iw.nextColCompareFilters = &nextCwf - } - return iw -} - -func (iw *indexHashJoinInnerWorker) run(ctx context.Context, cancelFunc context.CancelFunc) { - defer trace.StartRegion(ctx, "IndexHashJoinInnerWorker").End() - var task *indexHashJoinTask - joinResult, ok := iw.getNewJoinResult(ctx) - if !ok { - cancelFunc() - return - } - h, resultCh := fnv.New64(), iw.resultCh - for { - // The previous task has been processed, so release the occupied memory - if task != nil { - task.memTracker.Detach() - } - select { - case <-ctx.Done(): - return - case task, ok = <-iw.taskCh: - } - if !ok { - break - } - // We need to init resultCh before the err is returned. - if task.keepOuterOrder { - resultCh = task.resultCh - } - if task.err != nil { - joinResult.err = task.err - break - } - err := iw.handleTask(ctx, task, joinResult, h, resultCh) - if err != nil && !task.keepOuterOrder { - // Only need check non-keep-outer-order case because the - // `joinResult` had been sent to the `resultCh` when err != nil. - joinResult.err = err - break - } - if task.keepOuterOrder { - // We need to get a new result holder here because the old - // `joinResult` hash been sent to the `resultCh` or to the - // `joinChkResourceCh`. - joinResult, ok = iw.getNewJoinResult(ctx) - if !ok { - cancelFunc() - return - } - } - } - failpoint.Inject("testIndexHashJoinInnerWorkerErr", func() { - joinResult.err = errors.New("mockIndexHashJoinInnerWorkerErr") - }) - // When task.KeepOuterOrder is TRUE (resultCh != iw.resultCh): - // - the last joinResult will be handled when the task has been processed, - // thus we DO NOT need to check it here again. 
- // - we DO NOT check the error here neither, because: - // - if the error is from task.err, the main thread will check the error of each task - // - if the error is from handleTask, the error will be handled in handleTask - // We should not check `task != nil && !task.KeepOuterOrder` here since it's - // possible that `join.chk.NumRows > 0` is true even if task == nil. - if resultCh == iw.resultCh { - if joinResult.err != nil { - resultCh <- joinResult - return - } - if joinResult.chk != nil && joinResult.chk.NumRows() > 0 { - select { - case resultCh <- joinResult: - case <-ctx.Done(): - return - } - } - } -} - -func (iw *indexHashJoinInnerWorker) getNewJoinResult(ctx context.Context) (*indexHashJoinResult, bool) { - joinResult := &indexHashJoinResult{ - src: iw.joinChkResourceCh, - } - ok := true - select { - case joinResult.chk, ok = <-iw.joinChkResourceCh: - case <-ctx.Done(): - joinResult.err = ctx.Err() - return joinResult, false - } - return joinResult, ok -} - -func (iw *indexHashJoinInnerWorker) buildHashTableForOuterResult(task *indexHashJoinTask, h hash.Hash64) { - failpoint.Inject("IndexHashJoinBuildHashTablePanic", nil) - failpoint.Inject("ConsumeRandomPanic", nil) - if iw.stats != nil { - start := time.Now() - defer func() { - atomic.AddInt64(&iw.stats.build, int64(time.Since(start))) - }() - } - buf, numChks := make([]byte, 1), task.outerResult.NumChunks() - task.lookupMap = newUnsafeHashTable(task.outerResult.Len()) - for chkIdx := 0; chkIdx < numChks; chkIdx++ { - chk := task.outerResult.GetChunk(chkIdx) - numRows := chk.NumRows() - if iw.lookup.Finished.Load().(bool) { - return - } - OUTER: - for rowIdx := 0; rowIdx < numRows; rowIdx++ { - if task.outerMatch != nil && !task.outerMatch[chkIdx][rowIdx] { - continue - } - row := chk.GetRow(rowIdx) - hashColIdx := iw.outerCtx.HashCols - for _, i := range hashColIdx { - if row.IsNull(i) { - continue OUTER - } - } - h.Reset() - err := codec.HashChunkRow(iw.ctx.GetSessionVars().StmtCtx.TypeCtx(), h, row, iw.outerCtx.HashTypes, hashColIdx, buf) - failpoint.Inject("testIndexHashJoinBuildErr", func() { - err = errors.New("mockIndexHashJoinBuildErr") - }) - if err != nil { - // This panic will be recovered by the invoker. 
- panic(err.Error()) - } - rowPtr := chunk.RowPtr{ChkIdx: uint32(chkIdx), RowIdx: uint32(rowIdx)} - task.lookupMap.Put(h.Sum64(), rowPtr) - } - } -} - -func (iw *indexHashJoinInnerWorker) fetchInnerResults(ctx context.Context, task *lookUpJoinTask) error { - lookUpContents, err := iw.constructLookupContent(task) - if err != nil { - return err - } - return iw.innerWorker.fetchInnerResults(ctx, task, lookUpContents) -} - -func (iw *indexHashJoinInnerWorker) handleHashJoinInnerWorkerPanic(resultCh chan *indexHashJoinResult, err error) { - defer func() { - iw.wg.Done() - iw.lookup.WorkerWg.Done() - }() - if err != nil { - resultCh <- &indexHashJoinResult{err: err} - } -} - -func (iw *indexHashJoinInnerWorker) handleTask(ctx context.Context, task *indexHashJoinTask, joinResult *indexHashJoinResult, h hash.Hash64, resultCh chan *indexHashJoinResult) (err error) { - defer func() { - iw.memTracker.Consume(-iw.memTracker.BytesConsumed()) - if task.keepOuterOrder { - if err != nil { - joinResult.err = err - select { - case <-ctx.Done(): - case resultCh <- joinResult: - } - } - close(resultCh) - } - }() - var joinStartTime time.Time - if iw.stats != nil { - start := time.Now() - defer func() { - endTime := time.Now() - atomic.AddInt64(&iw.stats.totalTime, int64(endTime.Sub(start))) - if !joinStartTime.IsZero() { - // FetchInnerResults maybe return err and return, so joinStartTime is not initialized. - atomic.AddInt64(&iw.stats.join, int64(endTime.Sub(joinStartTime))) - } - }() - } - - iw.wg = &sync.WaitGroup{} - iw.wg.Add(1) - iw.lookup.WorkerWg.Add(1) - // TODO(XuHuaiyu): we may always use the smaller side to build the hashtable. - go util.WithRecovery( - func() { - iw.buildHashTableForOuterResult(task, h) - }, - func(r any) { - var err error - if r != nil { - err = errors.Errorf("%v", r) - } - iw.handleHashJoinInnerWorkerPanic(resultCh, err) - }, - ) - err = iw.fetchInnerResults(ctx, task.lookUpJoinTask) - iw.wg.Wait() - // check error after wg.Wait to make sure error message can be sent to - // resultCh even if panic happen in buildHashTableForOuterResult. 
- failpoint.Inject("IndexHashJoinFetchInnerResultsErr", func() { - err = errors.New("IndexHashJoinFetchInnerResultsErr") - }) - failpoint.Inject("ConsumeRandomPanic", nil) - if err != nil { - return err - } - - joinStartTime = time.Now() - if !task.keepOuterOrder { - return iw.doJoinUnordered(ctx, task, joinResult, h, resultCh) - } - return iw.doJoinInOrder(ctx, task, joinResult, h, resultCh) -} - -func (iw *indexHashJoinInnerWorker) doJoinUnordered(ctx context.Context, task *indexHashJoinTask, joinResult *indexHashJoinResult, h hash.Hash64, resultCh chan *indexHashJoinResult) error { - var ok bool - iter := chunk.NewIterator4List(task.innerResult) - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - ok, joinResult = iw.joinMatchedInnerRow2Chunk(ctx, row, task, joinResult, h, iw.joinKeyBuf) - if !ok { - return joinResult.err - } - } - for chkIdx, outerRowStatus := range task.outerRowStatus { - chk := task.outerResult.GetChunk(chkIdx) - for rowIdx, val := range outerRowStatus { - if val == outerRowMatched { - continue - } - iw.joiner.OnMissMatch(val == outerRowHasNull, chk.GetRow(rowIdx), joinResult.chk) - if joinResult.chk.IsFull() { - select { - case resultCh <- joinResult: - case <-ctx.Done(): - return ctx.Err() - } - joinResult, ok = iw.getNewJoinResult(ctx) - if !ok { - return errors.New("indexHashJoinInnerWorker.doJoinUnordered failed") - } - } - } - } - return nil -} - -func (iw *indexHashJoinInnerWorker) getMatchedOuterRows(innerRow chunk.Row, task *indexHashJoinTask, h hash.Hash64, buf []byte) (matchedRows []chunk.Row, matchedRowPtr []chunk.RowPtr, err error) { - h.Reset() - err = codec.HashChunkRow(iw.ctx.GetSessionVars().StmtCtx.TypeCtx(), h, innerRow, iw.HashTypes, iw.HashCols, buf) - if err != nil { - return nil, nil, err - } - matchedOuterEntry := task.lookupMap.Get(h.Sum64()) - if matchedOuterEntry == nil { - return nil, nil, nil - } - joinType := JoinerType(iw.joiner) - isSemiJoin := joinType.IsSemiJoin() - for ; matchedOuterEntry != nil; matchedOuterEntry = matchedOuterEntry.Next { - ptr := matchedOuterEntry.Ptr - outerRow := task.outerResult.GetRow(ptr) - ok, err := codec.EqualChunkRow(iw.ctx.GetSessionVars().StmtCtx.TypeCtx(), innerRow, iw.HashTypes, iw.HashCols, outerRow, iw.outerCtx.HashTypes, iw.outerCtx.HashCols) - if err != nil { - return nil, nil, err - } - if !ok || (task.outerRowStatus[ptr.ChkIdx][ptr.RowIdx] == outerRowMatched && isSemiJoin) { - continue - } - matchedRows = append(matchedRows, outerRow) - matchedRowPtr = append(matchedRowPtr, chunk.RowPtr{ChkIdx: ptr.ChkIdx, RowIdx: ptr.RowIdx}) - } - return matchedRows, matchedRowPtr, nil -} - -func (iw *indexHashJoinInnerWorker) joinMatchedInnerRow2Chunk(ctx context.Context, innerRow chunk.Row, task *indexHashJoinTask, - joinResult *indexHashJoinResult, h hash.Hash64, buf []byte) (bool, *indexHashJoinResult) { - matchedOuterRows, matchedOuterRowPtr, err := iw.getMatchedOuterRows(innerRow, task, h, buf) - if err != nil { - joinResult.err = err - return false, joinResult - } - if len(matchedOuterRows) == 0 { - return true, joinResult - } - var ok bool - cursor := 0 - iw.rowIter.Reset(matchedOuterRows) - iter := iw.rowIter - for iw.rowIter.Begin(); iter.Current() != iter.End(); { - iw.outerRowStatus, err = iw.joiner.TryToMatchOuters(iter, innerRow, joinResult.chk, iw.outerRowStatus) - if err != nil { - joinResult.err = err - return false, joinResult - } - for _, status := range iw.outerRowStatus { - chkIdx, rowIdx := matchedOuterRowPtr[cursor].ChkIdx, matchedOuterRowPtr[cursor].RowIdx - if status == 
outerRowMatched || task.outerRowStatus[chkIdx][rowIdx] == outerRowUnmatched { - task.outerRowStatus[chkIdx][rowIdx] = status - } - cursor++ - } - if joinResult.chk.IsFull() { - select { - case iw.resultCh <- joinResult: - case <-ctx.Done(): - joinResult.err = ctx.Err() - return false, joinResult - } - failpoint.InjectCall("joinMatchedInnerRow2Chunk") - joinResult, ok = iw.getNewJoinResult(ctx) - if !ok { - return false, joinResult - } - } - } - return true, joinResult -} - -func (iw *indexHashJoinInnerWorker) collectMatchedInnerPtrs4OuterRows(innerRow chunk.Row, innerRowPtr chunk.RowPtr, - task *indexHashJoinTask, h hash.Hash64, buf []byte) error { - _, matchedOuterRowIdx, err := iw.getMatchedOuterRows(innerRow, task, h, buf) - if err != nil { - return err - } - for _, outerRowPtr := range matchedOuterRowIdx { - chkIdx, rowIdx := outerRowPtr.ChkIdx, outerRowPtr.RowIdx - task.matchedInnerRowPtrs[chkIdx][rowIdx] = append(task.matchedInnerRowPtrs[chkIdx][rowIdx], innerRowPtr) - } - return nil -} - -// doJoinInOrder follows the following steps: -// 1. collect all the matched inner row ptrs for every outer row -// 2. do the join work -// 2.1 collect all the matched inner rows using the collected ptrs for every outer row -// 2.2 call TryToMatchInners for every outer row -// 2.3 call OnMissMatch when no inner rows are matched -func (iw *indexHashJoinInnerWorker) doJoinInOrder(ctx context.Context, task *indexHashJoinTask, joinResult *indexHashJoinResult, h hash.Hash64, resultCh chan *indexHashJoinResult) (err error) { - defer func() { - if err == nil && joinResult.chk != nil { - if joinResult.chk.NumRows() > 0 { - select { - case resultCh <- joinResult: - case <-ctx.Done(): - return - } - } else { - joinResult.src <- joinResult.chk - } - } - }() - for i, numChunks := 0, task.innerResult.NumChunks(); i < numChunks; i++ { - for j, chk := 0, task.innerResult.GetChunk(i); j < chk.NumRows(); j++ { - row := chk.GetRow(j) - ptr := chunk.RowPtr{ChkIdx: uint32(i), RowIdx: uint32(j)} - err = iw.collectMatchedInnerPtrs4OuterRows(row, ptr, task, h, iw.joinKeyBuf) - failpoint.Inject("TestIssue31129", func() { - err = errors.New("TestIssue31129") - }) - if err != nil { - return err - } - } - } - // TODO: matchedInnerRowPtrs and matchedInnerRows can be moved to inner worker. 
- matchedInnerRows := make([]chunk.Row, 0, len(task.matchedInnerRowPtrs)) - var hasMatched, hasNull, ok bool - for chkIdx, innerRowPtrs4Chk := range task.matchedInnerRowPtrs { - for outerRowIdx, innerRowPtrs := range innerRowPtrs4Chk { - matchedInnerRows, hasMatched, hasNull = matchedInnerRows[:0], false, false - outerRow := task.outerResult.GetChunk(chkIdx).GetRow(outerRowIdx) - for _, ptr := range innerRowPtrs { - matchedInnerRows = append(matchedInnerRows, task.innerResult.GetRow(ptr)) - } - iw.rowIter.Reset(matchedInnerRows) - iter := iw.rowIter - for iter.Begin(); iter.Current() != iter.End(); { - matched, isNull, err := iw.joiner.TryToMatchInners(outerRow, iter, joinResult.chk) - if err != nil { - return err - } - hasMatched, hasNull = matched || hasMatched, isNull || hasNull - if joinResult.chk.IsFull() { - select { - case resultCh <- joinResult: - case <-ctx.Done(): - return ctx.Err() - } - joinResult, ok = iw.getNewJoinResult(ctx) - if !ok { - return errors.New("indexHashJoinInnerWorker.doJoinInOrder failed") - } - } - } - if !hasMatched { - iw.joiner.OnMissMatch(hasNull, outerRow, joinResult.chk) - } - } - } - return nil -} diff --git a/pkg/executor/join/index_lookup_join.go b/pkg/executor/join/index_lookup_join.go index 06b8fadfabfeb..e289ae7ab44b3 100644 --- a/pkg/executor/join/index_lookup_join.go +++ b/pkg/executor/join/index_lookup_join.go @@ -243,11 +243,11 @@ func (e *IndexLookUpJoin) newInnerWorker(taskCh chan *lookUpJoinTask) *innerWork lookup: e, memTracker: memory.NewTracker(memory.LabelForIndexJoinInnerWorker, -1), } - if val, _err_ := failpoint.Eval(_curpkg_("inlNewInnerPanic")); _err_ == nil { + failpoint.Inject("inlNewInnerPanic", func(val failpoint.Value) { if val.(bool) { panic("test inlNewInnerPanic") } - } + }) iw.memTracker.AttachTo(e.memTracker) if len(copiedRanges) != 0 { // We should not consume this memory usage in `iw.memTracker`. 
The @@ -389,8 +389,8 @@ func (ow *outerWorker) run(ctx context.Context, wg *sync.WaitGroup) { wg.Done() }() for { - failpoint.Eval(_curpkg_("TestIssue30211")) - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + failpoint.Inject("TestIssue30211", nil) + failpoint.Inject("ConsumeRandomPanic", nil) task, err := ow.buildTask(ctx) if err != nil { task.doneCh <- err @@ -436,7 +436,7 @@ func (ow *outerWorker) buildTask(ctx context.Context) (*lookUpJoinTask, error) { task.memTracker = memory.NewTracker(-1, -1) task.outerResult.GetMemTracker().AttachTo(task.memTracker) task.memTracker.AttachTo(ow.parentMemTracker) - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + failpoint.Inject("ConsumeRandomPanic", nil) ow.increaseBatchSize() requiredRows := ow.batchSize @@ -586,7 +586,7 @@ func (iw *innerWorker) constructLookupContent(task *lookUpJoinTask) ([]*IndexJoi } return nil, err } - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + failpoint.Inject("ConsumeRandomPanic", nil) if rowIdx == 0 { iw.memTracker.Consume(types.EstimatedMemUsage(dLookUpKey, numRows)) } @@ -733,7 +733,7 @@ func (iw *innerWorker) fetchInnerResults(ctx context.Context, task *lookUpJoinTa default: } err := exec.Next(ctx, innerExec, iw.executorChk) - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + failpoint.Inject("ConsumeRandomPanic", nil) if err != nil { return err } diff --git a/pkg/executor/join/index_lookup_join.go__failpoint_stash__ b/pkg/executor/join/index_lookup_join.go__failpoint_stash__ deleted file mode 100644 index e289ae7ab44b3..0000000000000 --- a/pkg/executor/join/index_lookup_join.go__failpoint_stash__ +++ /dev/null @@ -1,882 +0,0 @@ -// Copyright 2017 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package join - -import ( - "bytes" - "context" - "runtime/trace" - "slices" - "strconv" - "sync" - "sync/atomic" - "time" - "unsafe" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/collate" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "github.com/pingcap/tidb/pkg/util/mvmap" - "github.com/pingcap/tidb/pkg/util/ranger" - "go.uber.org/zap" -) - -var _ exec.Executor = &IndexLookUpJoin{} - -// IndexLookUpJoin employs one outer worker and N innerWorkers to execute concurrently. -// It preserves the order of the outer table and support batch lookup. -// -// The execution flow is very similar to IndexLookUpReader: -// 1. 
outerWorker read N outer rows, build a task and send it to result channel and inner worker channel. -// 2. The innerWorker receives the task, builds key ranges from outer rows and fetch inner rows, builds inner row hash map. -// 3. main thread receives the task, waits for inner worker finish handling the task. -// 4. main thread join each outer row by look up the inner rows hash map in the task. -type IndexLookUpJoin struct { - exec.BaseExecutor - - resultCh <-chan *lookUpJoinTask - cancelFunc context.CancelFunc - WorkerWg *sync.WaitGroup - - OuterCtx OuterCtx - InnerCtx InnerCtx - - task *lookUpJoinTask - JoinResult *chunk.Chunk - innerIter *chunk.Iterator4Slice - - Joiner Joiner - IsOuterJoin bool - - requiredRows int64 - - IndexRanges ranger.MutableRanges - KeyOff2IdxOff []int - innerPtrBytes [][]byte - - // LastColHelper store the information for last col if there's complicated filter like col > x_col and col < x_col + 100. - LastColHelper *plannercore.ColWithCmpFuncManager - - memTracker *memory.Tracker // track memory usage. - - stats *indexLookUpJoinRuntimeStats - Finished *atomic.Value - prepared bool -} - -// OuterCtx is the outer ctx used in index lookup join -type OuterCtx struct { - RowTypes []*types.FieldType - KeyCols []int - HashTypes []*types.FieldType - HashCols []int - Filter expression.CNFExprs -} - -// IndexJoinExecutorBuilder is the interface used by index lookup join to build the executor, this interface -// is added to avoid cycle import -type IndexJoinExecutorBuilder interface { - BuildExecutorForIndexJoin(ctx context.Context, lookUpContents []*IndexJoinLookUpContent, - indexRanges []*ranger.Range, keyOff2IdxOff []int, cwc *plannercore.ColWithCmpFuncManager, canReorderHandles bool, memTracker *memory.Tracker, interruptSignal *atomic.Value) (exec.Executor, error) -} - -// InnerCtx is the inner side ctx used in index lookup join -type InnerCtx struct { - ReaderBuilder IndexJoinExecutorBuilder - RowTypes []*types.FieldType - KeyCols []int - KeyColIDs []int64 // the original ID in its table, used by dynamic partition pruning - KeyCollators []collate.Collator - HashTypes []*types.FieldType - HashCols []int - HashCollators []collate.Collator - ColLens []int - HasPrefixCol bool -} - -type lookUpJoinTask struct { - outerResult *chunk.List - outerMatch [][]bool - - innerResult *chunk.List - encodedLookUpKeys []*chunk.Chunk - lookupMap *mvmap.MVMap - matchedInners []chunk.Row - - doneCh chan error - cursor chunk.RowPtr - hasMatch bool - hasNull bool - - memTracker *memory.Tracker // track memory usage. -} - -type outerWorker struct { - OuterCtx - - lookup *IndexLookUpJoin - - ctx sessionctx.Context - executor exec.Executor - - maxBatchSize int - batchSize int - - resultCh chan<- *lookUpJoinTask - innerCh chan<- *lookUpJoinTask - - parentMemTracker *memory.Tracker -} - -type innerWorker struct { - InnerCtx - - taskCh <-chan *lookUpJoinTask - outerCtx OuterCtx - ctx sessionctx.Context - executorChk *chunk.Chunk - lookup *IndexLookUpJoin - - indexRanges []*ranger.Range - nextColCompareFilters *plannercore.ColWithCmpFuncManager - keyOff2IdxOff []int - stats *innerWorkerRuntimeStats - memTracker *memory.Tracker -} - -// Open implements the Executor interface. 
-func (e *IndexLookUpJoin) Open(ctx context.Context) error { - err := exec.Open(ctx, e.Children(0)) - if err != nil { - return err - } - e.memTracker = memory.NewTracker(e.ID(), -1) - e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) - e.innerPtrBytes = make([][]byte, 0, 8) - e.Finished.Store(false) - if e.RuntimeStats() != nil { - e.stats = &indexLookUpJoinRuntimeStats{} - } - e.cancelFunc = nil - return nil -} - -func (e *IndexLookUpJoin) startWorkers(ctx context.Context) { - concurrency := e.Ctx().GetSessionVars().IndexLookupJoinConcurrency() - if e.stats != nil { - e.stats.concurrency = concurrency - } - resultCh := make(chan *lookUpJoinTask, concurrency) - e.resultCh = resultCh - workerCtx, cancelFunc := context.WithCancel(ctx) - e.cancelFunc = cancelFunc - innerCh := make(chan *lookUpJoinTask, concurrency) - e.WorkerWg.Add(1) - go e.newOuterWorker(resultCh, innerCh).run(workerCtx, e.WorkerWg) - for i := 0; i < concurrency; i++ { - innerWorker := e.newInnerWorker(innerCh) - e.WorkerWg.Add(1) - go innerWorker.run(workerCtx, e.WorkerWg) - } -} - -func (e *IndexLookUpJoin) newOuterWorker(resultCh, innerCh chan *lookUpJoinTask) *outerWorker { - ow := &outerWorker{ - OuterCtx: e.OuterCtx, - ctx: e.Ctx(), - executor: e.Children(0), - resultCh: resultCh, - innerCh: innerCh, - batchSize: 32, - maxBatchSize: e.Ctx().GetSessionVars().IndexJoinBatchSize, - parentMemTracker: e.memTracker, - lookup: e, - } - return ow -} - -func (e *IndexLookUpJoin) newInnerWorker(taskCh chan *lookUpJoinTask) *innerWorker { - // Since multiple inner workers run concurrently, we should copy join's IndexRanges for every worker to avoid data race. - copiedRanges := make([]*ranger.Range, 0, len(e.IndexRanges.Range())) - for _, ran := range e.IndexRanges.Range() { - copiedRanges = append(copiedRanges, ran.Clone()) - } - - var innerStats *innerWorkerRuntimeStats - if e.stats != nil { - innerStats = &e.stats.innerWorker - } - iw := &innerWorker{ - InnerCtx: e.InnerCtx, - outerCtx: e.OuterCtx, - taskCh: taskCh, - ctx: e.Ctx(), - executorChk: e.AllocPool.Alloc(e.InnerCtx.RowTypes, e.MaxChunkSize(), e.MaxChunkSize()), - indexRanges: copiedRanges, - keyOff2IdxOff: e.KeyOff2IdxOff, - stats: innerStats, - lookup: e, - memTracker: memory.NewTracker(memory.LabelForIndexJoinInnerWorker, -1), - } - failpoint.Inject("inlNewInnerPanic", func(val failpoint.Value) { - if val.(bool) { - panic("test inlNewInnerPanic") - } - }) - iw.memTracker.AttachTo(e.memTracker) - if len(copiedRanges) != 0 { - // We should not consume this memory usage in `iw.memTracker`. The - // memory usage of inner worker will be reset the end of iw.handleTask. - // While the life cycle of this memory consumption exists throughout the - // whole active period of inner worker. - e.Ctx().GetSessionVars().StmtCtx.MemTracker.Consume(2 * types.EstimatedMemUsage(copiedRanges[0].LowVal, len(copiedRanges))) - } - if e.LastColHelper != nil { - // nextCwf.TmpConstant needs to be reset for every individual - // inner worker to avoid data race when the inner workers is running - // concurrently. - nextCwf := *e.LastColHelper - nextCwf.TmpConstant = make([]*expression.Constant, len(e.LastColHelper.TmpConstant)) - for i := range e.LastColHelper.TmpConstant { - nextCwf.TmpConstant[i] = &expression.Constant{RetType: nextCwf.TargetCol.RetType} - } - iw.nextColCompareFilters = &nextCwf - } - return iw -} - -// Next implements the Executor interface. 
-func (e *IndexLookUpJoin) Next(ctx context.Context, req *chunk.Chunk) error { - if !e.prepared { - e.startWorkers(ctx) - e.prepared = true - } - if e.IsOuterJoin { - atomic.StoreInt64(&e.requiredRows, int64(req.RequiredRows())) - } - req.Reset() - e.JoinResult.Reset() - for { - task, err := e.getFinishedTask(ctx) - if err != nil { - return err - } - if task == nil { - return nil - } - startTime := time.Now() - if e.innerIter == nil || e.innerIter.Current() == e.innerIter.End() { - e.lookUpMatchedInners(task, task.cursor) - if e.innerIter == nil { - e.innerIter = chunk.NewIterator4Slice(task.matchedInners) - } - e.innerIter.Reset(task.matchedInners) - e.innerIter.Begin() - } - - outerRow := task.outerResult.GetRow(task.cursor) - if e.innerIter.Current() != e.innerIter.End() { - matched, isNull, err := e.Joiner.TryToMatchInners(outerRow, e.innerIter, req) - if err != nil { - return err - } - task.hasMatch = task.hasMatch || matched - task.hasNull = task.hasNull || isNull - } - if e.innerIter.Current() == e.innerIter.End() { - if !task.hasMatch { - e.Joiner.OnMissMatch(task.hasNull, outerRow, req) - } - task.cursor.RowIdx++ - if int(task.cursor.RowIdx) == task.outerResult.GetChunk(int(task.cursor.ChkIdx)).NumRows() { - task.cursor.ChkIdx++ - task.cursor.RowIdx = 0 - } - task.hasMatch = false - task.hasNull = false - } - if e.stats != nil { - atomic.AddInt64(&e.stats.probe, int64(time.Since(startTime))) - } - if req.IsFull() { - return nil - } - } -} - -func (e *IndexLookUpJoin) getFinishedTask(ctx context.Context) (*lookUpJoinTask, error) { - task := e.task - if task != nil && int(task.cursor.ChkIdx) < task.outerResult.NumChunks() { - return task, nil - } - - // The previous task has been processed, so release the occupied memory - if task != nil { - task.memTracker.Detach() - } - select { - case task = <-e.resultCh: - case <-ctx.Done(): - return nil, ctx.Err() - } - if task == nil { - return nil, nil - } - - select { - case err := <-task.doneCh: - if err != nil { - return nil, err - } - case <-ctx.Done(): - return nil, ctx.Err() - } - - e.task = task - return task, nil -} - -func (e *IndexLookUpJoin) lookUpMatchedInners(task *lookUpJoinTask, rowPtr chunk.RowPtr) { - outerKey := task.encodedLookUpKeys[rowPtr.ChkIdx].GetRow(int(rowPtr.RowIdx)).GetBytes(0) - e.innerPtrBytes = task.lookupMap.Get(outerKey, e.innerPtrBytes[:0]) - task.matchedInners = task.matchedInners[:0] - - for _, b := range e.innerPtrBytes { - ptr := *(*chunk.RowPtr)(unsafe.Pointer(&b[0])) - matchedInner := task.innerResult.GetRow(ptr) - task.matchedInners = append(task.matchedInners, matchedInner) - } -} - -func (ow *outerWorker) run(ctx context.Context, wg *sync.WaitGroup) { - defer trace.StartRegion(ctx, "IndexLookupJoinOuterWorker").End() - defer func() { - if r := recover(); r != nil { - ow.lookup.Finished.Store(true) - logutil.Logger(ctx).Error("outerWorker panicked", zap.Any("recover", r), zap.Stack("stack")) - task := &lookUpJoinTask{doneCh: make(chan error, 1)} - err := util.GetRecoverError(r) - task.doneCh <- err - ow.pushToChan(ctx, task, ow.resultCh) - } - close(ow.resultCh) - close(ow.innerCh) - wg.Done() - }() - for { - failpoint.Inject("TestIssue30211", nil) - failpoint.Inject("ConsumeRandomPanic", nil) - task, err := ow.buildTask(ctx) - if err != nil { - task.doneCh <- err - ow.pushToChan(ctx, task, ow.resultCh) - return - } - if task == nil { - return - } - - if finished := ow.pushToChan(ctx, task, ow.innerCh); finished { - return - } - - if finished := ow.pushToChan(ctx, task, ow.resultCh); finished { - 
return - } - } -} - -func (*outerWorker) pushToChan(ctx context.Context, task *lookUpJoinTask, dst chan<- *lookUpJoinTask) bool { - select { - case <-ctx.Done(): - return true - case dst <- task: - } - return false -} - -// newList creates a new List to buffer current executor's result. -func newList(e exec.Executor) *chunk.List { - return chunk.NewList(e.RetFieldTypes(), e.InitCap(), e.MaxChunkSize()) -} - -// buildTask builds a lookUpJoinTask and read Outer rows. -// When err is not nil, task must not be nil to send the error to the main thread via task. -func (ow *outerWorker) buildTask(ctx context.Context) (*lookUpJoinTask, error) { - task := &lookUpJoinTask{ - doneCh: make(chan error, 1), - outerResult: newList(ow.executor), - lookupMap: mvmap.NewMVMap(), - } - task.memTracker = memory.NewTracker(-1, -1) - task.outerResult.GetMemTracker().AttachTo(task.memTracker) - task.memTracker.AttachTo(ow.parentMemTracker) - failpoint.Inject("ConsumeRandomPanic", nil) - - ow.increaseBatchSize() - requiredRows := ow.batchSize - if ow.lookup.IsOuterJoin { - // If it is outerJoin, push the requiredRows down. - // Note: buildTask is triggered when `Open` is called, but - // ow.lookup.requiredRows is set when `Next` is called. Thus we check - // whether it's 0 here. - if parentRequired := int(atomic.LoadInt64(&ow.lookup.requiredRows)); parentRequired != 0 { - requiredRows = parentRequired - } - } - maxChunkSize := ow.ctx.GetSessionVars().MaxChunkSize - for requiredRows > task.outerResult.Len() { - chk := ow.executor.NewChunkWithCapacity(ow.OuterCtx.RowTypes, maxChunkSize, maxChunkSize) - chk = chk.SetRequiredRows(requiredRows, maxChunkSize) - err := exec.Next(ctx, ow.executor, chk) - if err != nil { - return task, err - } - if chk.NumRows() == 0 { - break - } - - task.outerResult.Add(chk) - } - if task.outerResult.Len() == 0 { - return nil, nil - } - numChks := task.outerResult.NumChunks() - if ow.Filter != nil { - task.outerMatch = make([][]bool, task.outerResult.NumChunks()) - var err error - exprCtx := ow.ctx.GetExprCtx() - for i := 0; i < numChks; i++ { - chk := task.outerResult.GetChunk(i) - outerMatch := make([]bool, 0, chk.NumRows()) - task.memTracker.Consume(int64(cap(outerMatch))) - task.outerMatch[i], err = expression.VectorizedFilter(exprCtx.GetEvalCtx(), ow.ctx.GetSessionVars().EnableVectorizedExpression, ow.Filter, chunk.NewIterator4Chunk(chk), outerMatch) - if err != nil { - return task, err - } - } - } - task.encodedLookUpKeys = make([]*chunk.Chunk, task.outerResult.NumChunks()) - for i := range task.encodedLookUpKeys { - task.encodedLookUpKeys[i] = ow.executor.NewChunkWithCapacity( - []*types.FieldType{types.NewFieldType(mysql.TypeBlob)}, - task.outerResult.GetChunk(i).NumRows(), - task.outerResult.GetChunk(i).NumRows(), - ) - } - return task, nil -} - -func (ow *outerWorker) increaseBatchSize() { - if ow.batchSize < ow.maxBatchSize { - ow.batchSize *= 2 - } - if ow.batchSize > ow.maxBatchSize { - ow.batchSize = ow.maxBatchSize - } -} - -func (iw *innerWorker) run(ctx context.Context, wg *sync.WaitGroup) { - defer trace.StartRegion(ctx, "IndexLookupJoinInnerWorker").End() - var task *lookUpJoinTask - defer func() { - if r := recover(); r != nil { - iw.lookup.Finished.Store(true) - logutil.Logger(ctx).Error("innerWorker panicked", zap.Any("recover", r), zap.Stack("stack")) - err := util.GetRecoverError(r) - // "task != nil" is guaranteed when panic happened. 
- task.doneCh <- err - } - wg.Done() - }() - - for ok := true; ok; { - select { - case task, ok = <-iw.taskCh: - if !ok { - return - } - case <-ctx.Done(): - return - } - - err := iw.handleTask(ctx, task) - task.doneCh <- err - } -} - -// IndexJoinLookUpContent is the content used in index lookup join -type IndexJoinLookUpContent struct { - Keys []types.Datum - Row chunk.Row - keyCols []int - KeyColIDs []int64 // the original ID in its table, used by dynamic partition pruning -} - -func (iw *innerWorker) handleTask(ctx context.Context, task *lookUpJoinTask) error { - if iw.stats != nil { - start := time.Now() - defer func() { - atomic.AddInt64(&iw.stats.totalTime, int64(time.Since(start))) - }() - } - defer func() { - iw.memTracker.Consume(-iw.memTracker.BytesConsumed()) - }() - lookUpContents, err := iw.constructLookupContent(task) - if err != nil { - return err - } - err = iw.fetchInnerResults(ctx, task, lookUpContents) - if err != nil { - return err - } - err = iw.buildLookUpMap(task) - if err != nil { - return err - } - return nil -} - -func (iw *innerWorker) constructLookupContent(task *lookUpJoinTask) ([]*IndexJoinLookUpContent, error) { - if iw.stats != nil { - start := time.Now() - defer func() { - atomic.AddInt64(&iw.stats.task, 1) - atomic.AddInt64(&iw.stats.construct, int64(time.Since(start))) - }() - } - lookUpContents := make([]*IndexJoinLookUpContent, 0, task.outerResult.Len()) - keyBuf := make([]byte, 0, 64) - for chkIdx := 0; chkIdx < task.outerResult.NumChunks(); chkIdx++ { - chk := task.outerResult.GetChunk(chkIdx) - numRows := chk.NumRows() - for rowIdx := 0; rowIdx < numRows; rowIdx++ { - dLookUpKey, dHashKey, err := iw.constructDatumLookupKey(task, chkIdx, rowIdx) - if err != nil { - if terror.ErrorEqual(err, types.ErrWrongValue) { - // We ignore rows with invalid datetime. - task.encodedLookUpKeys[chkIdx].AppendNull(0) - continue - } - return nil, err - } - failpoint.Inject("ConsumeRandomPanic", nil) - if rowIdx == 0 { - iw.memTracker.Consume(types.EstimatedMemUsage(dLookUpKey, numRows)) - } - if dHashKey == nil { - // Append null to make lookUpKeys the same length as Outer Result. - task.encodedLookUpKeys[chkIdx].AppendNull(0) - continue - } - keyBuf = keyBuf[:0] - keyBuf, err = codec.EncodeKey(iw.ctx.GetSessionVars().StmtCtx.TimeZone(), keyBuf, dHashKey...) - err = iw.ctx.GetSessionVars().StmtCtx.HandleError(err) - if err != nil { - if terror.ErrorEqual(err, types.ErrWrongValue) { - // we ignore rows with invalid datetime - task.encodedLookUpKeys[chkIdx].AppendNull(0) - continue - } - return nil, err - } - // Store the encoded lookup key in chunk, so we can use it to lookup the matched inners directly. - task.encodedLookUpKeys[chkIdx].AppendBytes(0, keyBuf) - if iw.HasPrefixCol { - for i, outerOffset := range iw.keyOff2IdxOff { - // If it's a prefix column. Try to fix it. - joinKeyColPrefixLen := iw.ColLens[outerOffset] - if joinKeyColPrefixLen != types.UnspecifiedLength { - ranger.CutDatumByPrefixLen(&dLookUpKey[i], joinKeyColPrefixLen, iw.RowTypes[iw.KeyCols[i]]) - } - } - // dLookUpKey is sorted and deduplicated at sortAndDedupLookUpContents. - // So we don't need to do it here. 
- } - lookUpContents = append(lookUpContents, &IndexJoinLookUpContent{Keys: dLookUpKey, Row: chk.GetRow(rowIdx), keyCols: iw.KeyCols, KeyColIDs: iw.KeyColIDs}) - } - } - - for i := range task.encodedLookUpKeys { - task.memTracker.Consume(task.encodedLookUpKeys[i].MemoryUsage()) - } - lookUpContents = iw.sortAndDedupLookUpContents(lookUpContents) - return lookUpContents, nil -} - -func (iw *innerWorker) constructDatumLookupKey(task *lookUpJoinTask, chkIdx, rowIdx int) ([]types.Datum, []types.Datum, error) { - if task.outerMatch != nil && !task.outerMatch[chkIdx][rowIdx] { - return nil, nil, nil - } - outerRow := task.outerResult.GetChunk(chkIdx).GetRow(rowIdx) - sc := iw.ctx.GetSessionVars().StmtCtx - keyLen := len(iw.KeyCols) - dLookupKey := make([]types.Datum, 0, keyLen) - dHashKey := make([]types.Datum, 0, len(iw.HashCols)) - for i, hashCol := range iw.outerCtx.HashCols { - outerValue := outerRow.GetDatum(hashCol, iw.outerCtx.RowTypes[hashCol]) - // Join-on-condition can be promised to be equal-condition in - // IndexNestedLoopJoin, thus the Filter will always be false if - // outerValue is null, and we don't need to lookup it. - if outerValue.IsNull() { - return nil, nil, nil - } - innerColType := iw.RowTypes[iw.HashCols[i]] - innerValue, err := outerValue.ConvertTo(sc.TypeCtx(), innerColType) - if err != nil && !(terror.ErrorEqual(err, types.ErrTruncated) && (innerColType.GetType() == mysql.TypeSet || innerColType.GetType() == mysql.TypeEnum)) { - // If the converted outerValue overflows or invalid to innerValue, we don't need to lookup it. - if terror.ErrorEqual(err, types.ErrOverflow) || terror.ErrorEqual(err, types.ErrWarnDataOutOfRange) { - return nil, nil, nil - } - return nil, nil, err - } - cmp, err := outerValue.Compare(sc.TypeCtx(), &innerValue, iw.HashCollators[i]) - if err != nil { - return nil, nil, err - } - if cmp != 0 { - // If the converted outerValue is not equal to the origin outerValue, we don't need to lookup it. - return nil, nil, nil - } - if i < keyLen { - dLookupKey = append(dLookupKey, innerValue) - } - dHashKey = append(dHashKey, innerValue) - } - return dLookupKey, dHashKey, nil -} - -func (iw *innerWorker) sortAndDedupLookUpContents(lookUpContents []*IndexJoinLookUpContent) []*IndexJoinLookUpContent { - if len(lookUpContents) < 2 { - return lookUpContents - } - sc := iw.ctx.GetSessionVars().StmtCtx - slices.SortFunc(lookUpContents, func(i, j *IndexJoinLookUpContent) int { - cmp := compareRow(sc, i.Keys, j.Keys, iw.KeyCollators) - if cmp != 0 || iw.nextColCompareFilters == nil { - return cmp - } - return iw.nextColCompareFilters.CompareRow(i.Row, j.Row) - }) - deDupedLookupKeys := lookUpContents[:1] - for i := 1; i < len(lookUpContents); i++ { - cmp := compareRow(sc, lookUpContents[i].Keys, lookUpContents[i-1].Keys, iw.KeyCollators) - if cmp != 0 || (iw.nextColCompareFilters != nil && iw.nextColCompareFilters.CompareRow(lookUpContents[i].Row, lookUpContents[i-1].Row) != 0) { - deDupedLookupKeys = append(deDupedLookupKeys, lookUpContents[i]) - } - } - return deDupedLookupKeys -} - -func compareRow(sc *stmtctx.StatementContext, left, right []types.Datum, ctors []collate.Collator) int { - for idx := 0; idx < len(left); idx++ { - cmp, err := left[idx].Compare(sc.TypeCtx(), &right[idx], ctors[idx]) - // We only compare rows with the same type, no error to return. 
- terror.Log(err) - if cmp > 0 { - return 1 - } else if cmp < 0 { - return -1 - } - } - return 0 -} - -func (iw *innerWorker) fetchInnerResults(ctx context.Context, task *lookUpJoinTask, lookUpContent []*IndexJoinLookUpContent) error { - if iw.stats != nil { - start := time.Now() - defer func() { - atomic.AddInt64(&iw.stats.fetch, int64(time.Since(start))) - }() - } - innerExec, err := iw.ReaderBuilder.BuildExecutorForIndexJoin(ctx, lookUpContent, iw.indexRanges, iw.keyOff2IdxOff, iw.nextColCompareFilters, true, iw.memTracker, iw.lookup.Finished) - if innerExec != nil { - defer func() { terror.Log(exec.Close(innerExec)) }() - } - if err != nil { - return err - } - - innerResult := chunk.NewList(exec.RetTypes(innerExec), iw.ctx.GetSessionVars().MaxChunkSize, iw.ctx.GetSessionVars().MaxChunkSize) - innerResult.GetMemTracker().SetLabel(memory.LabelForBuildSideResult) - innerResult.GetMemTracker().AttachTo(task.memTracker) - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - err := exec.Next(ctx, innerExec, iw.executorChk) - failpoint.Inject("ConsumeRandomPanic", nil) - if err != nil { - return err - } - if iw.executorChk.NumRows() == 0 { - break - } - innerResult.Add(iw.executorChk) - iw.executorChk = exec.TryNewCacheChunk(innerExec) - } - task.innerResult = innerResult - return nil -} - -func (iw *innerWorker) buildLookUpMap(task *lookUpJoinTask) error { - if iw.stats != nil { - start := time.Now() - defer func() { - atomic.AddInt64(&iw.stats.build, int64(time.Since(start))) - }() - } - keyBuf := make([]byte, 0, 64) - valBuf := make([]byte, 8) - for i := 0; i < task.innerResult.NumChunks(); i++ { - chk := task.innerResult.GetChunk(i) - for j := 0; j < chk.NumRows(); j++ { - innerRow := chk.GetRow(j) - if iw.hasNullInJoinKey(innerRow) { - continue - } - - keyBuf = keyBuf[:0] - for _, keyCol := range iw.HashCols { - d := innerRow.GetDatum(keyCol, iw.RowTypes[keyCol]) - var err error - keyBuf, err = codec.EncodeKey(iw.ctx.GetSessionVars().StmtCtx.TimeZone(), keyBuf, d) - err = iw.ctx.GetSessionVars().StmtCtx.HandleError(err) - if err != nil { - return err - } - } - rowPtr := chunk.RowPtr{ChkIdx: uint32(i), RowIdx: uint32(j)} - *(*chunk.RowPtr)(unsafe.Pointer(&valBuf[0])) = rowPtr - task.lookupMap.Put(keyBuf, valBuf) - } - } - return nil -} - -func (iw *innerWorker) hasNullInJoinKey(row chunk.Row) bool { - for _, ordinal := range iw.HashCols { - if row.IsNull(ordinal) { - return true - } - } - return false -} - -// Close implements the Executor interface. 
-func (e *IndexLookUpJoin) Close() error { - if e.stats != nil { - defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), e.stats) - } - if e.cancelFunc != nil { - e.cancelFunc() - } - e.WorkerWg.Wait() - e.memTracker = nil - e.task = nil - e.Finished.Store(false) - e.prepared = false - return e.BaseExecutor.Close() -} - -type indexLookUpJoinRuntimeStats struct { - concurrency int - probe int64 - innerWorker innerWorkerRuntimeStats -} - -type innerWorkerRuntimeStats struct { - totalTime int64 - task int64 - construct int64 - fetch int64 - build int64 - join int64 -} - -func (e *indexLookUpJoinRuntimeStats) String() string { - buf := bytes.NewBuffer(make([]byte, 0, 16)) - if e.innerWorker.totalTime > 0 { - buf.WriteString("inner:{total:") - buf.WriteString(execdetails.FormatDuration(time.Duration(e.innerWorker.totalTime))) - buf.WriteString(", concurrency:") - if e.concurrency > 0 { - buf.WriteString(strconv.Itoa(e.concurrency)) - } else { - buf.WriteString("OFF") - } - buf.WriteString(", task:") - buf.WriteString(strconv.FormatInt(e.innerWorker.task, 10)) - buf.WriteString(", construct:") - buf.WriteString(execdetails.FormatDuration(time.Duration(e.innerWorker.construct))) - buf.WriteString(", fetch:") - buf.WriteString(execdetails.FormatDuration(time.Duration(e.innerWorker.fetch))) - buf.WriteString(", build:") - buf.WriteString(execdetails.FormatDuration(time.Duration(e.innerWorker.build))) - if e.innerWorker.join > 0 { - buf.WriteString(", join:") - buf.WriteString(execdetails.FormatDuration(time.Duration(e.innerWorker.join))) - } - buf.WriteString("}") - } - if e.probe > 0 { - buf.WriteString(", probe:") - buf.WriteString(execdetails.FormatDuration(time.Duration(e.probe))) - } - return buf.String() -} - -func (e *indexLookUpJoinRuntimeStats) Clone() execdetails.RuntimeStats { - return &indexLookUpJoinRuntimeStats{ - concurrency: e.concurrency, - probe: e.probe, - innerWorker: e.innerWorker, - } -} - -func (e *indexLookUpJoinRuntimeStats) Merge(rs execdetails.RuntimeStats) { - tmp, ok := rs.(*indexLookUpJoinRuntimeStats) - if !ok { - return - } - e.probe += tmp.probe - e.innerWorker.totalTime += tmp.innerWorker.totalTime - e.innerWorker.task += tmp.innerWorker.task - e.innerWorker.construct += tmp.innerWorker.construct - e.innerWorker.fetch += tmp.innerWorker.fetch - e.innerWorker.build += tmp.innerWorker.build - e.innerWorker.join += tmp.innerWorker.join -} - -// Tp implements the RuntimeStats interface. 
-func (*indexLookUpJoinRuntimeStats) Tp() int { - return execdetails.TpIndexLookUpJoinRuntimeStats -} diff --git a/pkg/executor/join/index_lookup_merge_join.go b/pkg/executor/join/index_lookup_merge_join.go index 0a8afa17f8234..c181eb04548d7 100644 --- a/pkg/executor/join/index_lookup_merge_join.go +++ b/pkg/executor/join/index_lookup_merge_join.go @@ -211,9 +211,9 @@ func (e *IndexLookUpMergeJoin) newOuterWorker(resultCh, innerCh chan *lookUpMerg parentMemTracker: e.memTracker, nextColCompareFilters: e.LastColHelper, } - if _, _err_ := failpoint.Eval(_curpkg_("testIssue18068")); _err_ == nil { + failpoint.Inject("testIssue18068", func() { omw.batchSize = 1 - } + }) return omw } @@ -316,7 +316,7 @@ func (omw *outerMergeWorker) run(ctx context.Context, wg *sync.WaitGroup, cancel omw.pushToChan(ctx, task, omw.resultCh) return } - failpoint.Eval(_curpkg_("mockIndexMergeJoinOOMPanic")) + failpoint.Inject("mockIndexMergeJoinOOMPanic", nil) if task == nil { return } diff --git a/pkg/executor/join/index_lookup_merge_join.go__failpoint_stash__ b/pkg/executor/join/index_lookup_merge_join.go__failpoint_stash__ deleted file mode 100644 index c181eb04548d7..0000000000000 --- a/pkg/executor/join/index_lookup_merge_join.go__failpoint_stash__ +++ /dev/null @@ -1,743 +0,0 @@ -// Copyright 2019 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package join - -import ( - "context" - "runtime/trace" - "slices" - "sync" - "sync/atomic" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/channel" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/collate" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "github.com/pingcap/tidb/pkg/util/ranger" - "go.uber.org/zap" -) - -// IndexLookUpMergeJoin realizes IndexLookUpJoin by merge join -// It preserves the order of the outer table and support batch lookup. -// -// The execution flow is very similar to IndexLookUpReader: -// 1. outerWorker read N outer rows, build a task and send it to result channel and inner worker channel. -// 2. The innerWorker receives the task, builds key ranges from outer rows and fetch inner rows, then do merge join. -// 3. main thread receives the task and fetch results from the channel in task one by one. -// 4. If channel has been closed, main thread receives the next task. 
-type IndexLookUpMergeJoin struct { - exec.BaseExecutor - - resultCh <-chan *lookUpMergeJoinTask - cancelFunc context.CancelFunc - WorkerWg *sync.WaitGroup - - OuterMergeCtx OuterMergeCtx - InnerMergeCtx InnerMergeCtx - - Joiners []Joiner - joinChkResourceCh []chan *chunk.Chunk - IsOuterJoin bool - - requiredRows int64 - - task *lookUpMergeJoinTask - - IndexRanges ranger.MutableRanges - KeyOff2IdxOff []int - - // LastColHelper store the information for last col if there's complicated filter like col > x_col and col < x_col + 100. - LastColHelper *plannercore.ColWithCmpFuncManager - - memTracker *memory.Tracker // track memory usage - prepared bool -} - -// OuterMergeCtx is the outer side ctx of merge join -type OuterMergeCtx struct { - RowTypes []*types.FieldType - JoinKeys []*expression.Column - KeyCols []int - Filter expression.CNFExprs - NeedOuterSort bool - CompareFuncs []expression.CompareFunc -} - -// InnerMergeCtx is the inner side ctx of merge join -type InnerMergeCtx struct { - ReaderBuilder IndexJoinExecutorBuilder - RowTypes []*types.FieldType - JoinKeys []*expression.Column - KeyCols []int - KeyCollators []collate.Collator - CompareFuncs []expression.CompareFunc - ColLens []int - Desc bool - KeyOff2KeyOffOrderByIdx []int -} - -type lookUpMergeJoinTask struct { - outerResult *chunk.List - outerMatch [][]bool - outerOrderIdx []chunk.RowPtr - - innerResult *chunk.Chunk - innerIter chunk.Iterator - - sameKeyInnerRows []chunk.Row - sameKeyIter chunk.Iterator - - doneErr error - results chan *indexMergeJoinResult - - memTracker *memory.Tracker -} - -type outerMergeWorker struct { - OuterMergeCtx - - lookup *IndexLookUpMergeJoin - - ctx sessionctx.Context - executor exec.Executor - - maxBatchSize int - batchSize int - - nextColCompareFilters *plannercore.ColWithCmpFuncManager - - resultCh chan<- *lookUpMergeJoinTask - innerCh chan<- *lookUpMergeJoinTask - - parentMemTracker *memory.Tracker -} - -type innerMergeWorker struct { - InnerMergeCtx - - taskCh <-chan *lookUpMergeJoinTask - joinChkResourceCh chan *chunk.Chunk - outerMergeCtx OuterMergeCtx - ctx sessionctx.Context - innerExec exec.Executor - joiner Joiner - retFieldTypes []*types.FieldType - - maxChunkSize int - indexRanges []*ranger.Range - nextColCompareFilters *plannercore.ColWithCmpFuncManager - keyOff2IdxOff []int -} - -type indexMergeJoinResult struct { - chk *chunk.Chunk - src chan<- *chunk.Chunk -} - -// Open implements the Executor interface -func (e *IndexLookUpMergeJoin) Open(ctx context.Context) error { - err := exec.Open(ctx, e.Children(0)) - if err != nil { - return err - } - e.memTracker = memory.NewTracker(e.ID(), -1) - e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) - return nil -} - -func (e *IndexLookUpMergeJoin) startWorkers(ctx context.Context) { - // TODO: consider another session currency variable for index merge join. - // Because its parallelization is not complete. 
- concurrency := e.Ctx().GetSessionVars().IndexLookupJoinConcurrency() - if e.RuntimeStats() != nil { - runtimeStats := &execdetails.RuntimeStatsWithConcurrencyInfo{} - runtimeStats.SetConcurrencyInfo(execdetails.NewConcurrencyInfo("Concurrency", concurrency)) - e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), runtimeStats) - } - - resultCh := make(chan *lookUpMergeJoinTask, concurrency) - e.resultCh = resultCh - e.joinChkResourceCh = make([]chan *chunk.Chunk, concurrency) - for i := 0; i < concurrency; i++ { - e.joinChkResourceCh[i] = make(chan *chunk.Chunk, numResChkHold) - for j := 0; j < numResChkHold; j++ { - e.joinChkResourceCh[i] <- chunk.NewChunkWithCapacity(e.RetFieldTypes(), e.MaxChunkSize()) - } - } - workerCtx, cancelFunc := context.WithCancel(ctx) - e.cancelFunc = cancelFunc - innerCh := make(chan *lookUpMergeJoinTask, concurrency) - e.WorkerWg.Add(1) - go e.newOuterWorker(resultCh, innerCh).run(workerCtx, e.WorkerWg, e.cancelFunc) - e.WorkerWg.Add(concurrency) - for i := 0; i < concurrency; i++ { - go e.newInnerMergeWorker(innerCh, i).run(workerCtx, e.WorkerWg, e.cancelFunc) - } -} - -func (e *IndexLookUpMergeJoin) newOuterWorker(resultCh, innerCh chan *lookUpMergeJoinTask) *outerMergeWorker { - omw := &outerMergeWorker{ - OuterMergeCtx: e.OuterMergeCtx, - ctx: e.Ctx(), - lookup: e, - executor: e.Children(0), - resultCh: resultCh, - innerCh: innerCh, - batchSize: 32, - maxBatchSize: e.Ctx().GetSessionVars().IndexJoinBatchSize, - parentMemTracker: e.memTracker, - nextColCompareFilters: e.LastColHelper, - } - failpoint.Inject("testIssue18068", func() { - omw.batchSize = 1 - }) - return omw -} - -func (e *IndexLookUpMergeJoin) newInnerMergeWorker(taskCh chan *lookUpMergeJoinTask, workID int) *innerMergeWorker { - // Since multiple inner workers run concurrently, we should copy join's IndexRanges for every worker to avoid data race. - copiedRanges := make([]*ranger.Range, 0, len(e.IndexRanges.Range())) - for _, ran := range e.IndexRanges.Range() { - copiedRanges = append(copiedRanges, ran.Clone()) - } - imw := &innerMergeWorker{ - InnerMergeCtx: e.InnerMergeCtx, - outerMergeCtx: e.OuterMergeCtx, - taskCh: taskCh, - ctx: e.Ctx(), - indexRanges: copiedRanges, - keyOff2IdxOff: e.KeyOff2IdxOff, - joiner: e.Joiners[workID], - joinChkResourceCh: e.joinChkResourceCh[workID], - retFieldTypes: e.RetFieldTypes(), - maxChunkSize: e.MaxChunkSize(), - } - if e.LastColHelper != nil { - // nextCwf.TmpConstant needs to be reset for every individual - // inner worker to avoid data race when the inner workers is running - // concurrently. 
- nextCwf := *e.LastColHelper - nextCwf.TmpConstant = make([]*expression.Constant, len(e.LastColHelper.TmpConstant)) - for i := range e.LastColHelper.TmpConstant { - nextCwf.TmpConstant[i] = &expression.Constant{RetType: nextCwf.TargetCol.RetType} - } - imw.nextColCompareFilters = &nextCwf - } - return imw -} - -// Next implements the Executor interface -func (e *IndexLookUpMergeJoin) Next(ctx context.Context, req *chunk.Chunk) error { - if !e.prepared { - e.startWorkers(ctx) - e.prepared = true - } - if e.IsOuterJoin { - atomic.StoreInt64(&e.requiredRows, int64(req.RequiredRows())) - } - req.Reset() - if e.task == nil { - e.loadFinishedTask(ctx) - } - for e.task != nil { - select { - case result, ok := <-e.task.results: - if !ok { - if e.task.doneErr != nil { - return e.task.doneErr - } - e.loadFinishedTask(ctx) - continue - } - req.SwapColumns(result.chk) - result.src <- result.chk - return nil - case <-ctx.Done(): - return ctx.Err() - } - } - - return nil -} - -// TODO: reuse the Finished task memory to build tasks. -func (e *IndexLookUpMergeJoin) loadFinishedTask(ctx context.Context) { - select { - case e.task = <-e.resultCh: - case <-ctx.Done(): - e.task = nil - } -} - -func (omw *outerMergeWorker) run(ctx context.Context, wg *sync.WaitGroup, cancelFunc context.CancelFunc) { - defer trace.StartRegion(ctx, "IndexLookupMergeJoinOuterWorker").End() - defer func() { - if r := recover(); r != nil { - task := &lookUpMergeJoinTask{ - doneErr: util.GetRecoverError(r), - results: make(chan *indexMergeJoinResult, numResChkHold), - } - close(task.results) - omw.resultCh <- task - cancelFunc() - } - close(omw.resultCh) - close(omw.innerCh) - wg.Done() - }() - for { - task, err := omw.buildTask(ctx) - if err != nil { - task.doneErr = err - close(task.results) - omw.pushToChan(ctx, task, omw.resultCh) - return - } - failpoint.Inject("mockIndexMergeJoinOOMPanic", nil) - if task == nil { - return - } - - if finished := omw.pushToChan(ctx, task, omw.innerCh); finished { - return - } - - if finished := omw.pushToChan(ctx, task, omw.resultCh); finished { - return - } - } -} - -func (*outerMergeWorker) pushToChan(ctx context.Context, task *lookUpMergeJoinTask, dst chan<- *lookUpMergeJoinTask) (finished bool) { - select { - case <-ctx.Done(): - return true - case dst <- task: - } - return false -} - -// buildTask builds a lookUpMergeJoinTask and read Outer rows. 
-// When err is not nil, task must not be nil to send the error to the main thread via task -func (omw *outerMergeWorker) buildTask(ctx context.Context) (*lookUpMergeJoinTask, error) { - task := &lookUpMergeJoinTask{ - results: make(chan *indexMergeJoinResult, numResChkHold), - outerResult: chunk.NewList(omw.RowTypes, omw.executor.InitCap(), omw.executor.MaxChunkSize()), - } - task.memTracker = memory.NewTracker(memory.LabelForSimpleTask, -1) - task.memTracker.AttachTo(omw.parentMemTracker) - - omw.increaseBatchSize() - requiredRows := omw.batchSize - if omw.lookup.IsOuterJoin { - requiredRows = int(atomic.LoadInt64(&omw.lookup.requiredRows)) - } - if requiredRows <= 0 || requiredRows > omw.maxBatchSize { - requiredRows = omw.maxBatchSize - } - for requiredRows > 0 { - execChk := exec.TryNewCacheChunk(omw.executor) - err := exec.Next(ctx, omw.executor, execChk) - if err != nil { - return task, err - } - if execChk.NumRows() == 0 { - break - } - - task.outerResult.Add(execChk) - requiredRows -= execChk.NumRows() - task.memTracker.Consume(execChk.MemoryUsage()) - } - - if task.outerResult.Len() == 0 { - return nil, nil - } - - return task, nil -} - -func (omw *outerMergeWorker) increaseBatchSize() { - if omw.batchSize < omw.maxBatchSize { - omw.batchSize *= 2 - } - if omw.batchSize > omw.maxBatchSize { - omw.batchSize = omw.maxBatchSize - } -} - -func (imw *innerMergeWorker) run(ctx context.Context, wg *sync.WaitGroup, cancelFunc context.CancelFunc) { - defer trace.StartRegion(ctx, "IndexLookupMergeJoinInnerWorker").End() - var task *lookUpMergeJoinTask - defer func() { - wg.Done() - if r := recover(); r != nil { - if task != nil { - task.doneErr = util.GetRecoverError(r) - close(task.results) - } - logutil.Logger(ctx).Error("innerMergeWorker panicked", zap.Any("recover", r), zap.Stack("stack")) - cancelFunc() - } - }() - - for ok := true; ok; { - select { - case task, ok = <-imw.taskCh: - if !ok { - return - } - case <-ctx.Done(): - return - } - - err := imw.handleTask(ctx, task) - task.doneErr = err - close(task.results) - } -} - -func (imw *innerMergeWorker) handleTask(ctx context.Context, task *lookUpMergeJoinTask) (err error) { - numOuterChks := task.outerResult.NumChunks() - if imw.outerMergeCtx.Filter != nil { - task.outerMatch = make([][]bool, numOuterChks) - exprCtx := imw.ctx.GetExprCtx() - for i := 0; i < numOuterChks; i++ { - chk := task.outerResult.GetChunk(i) - task.outerMatch[i] = make([]bool, chk.NumRows()) - task.outerMatch[i], err = expression.VectorizedFilter(exprCtx.GetEvalCtx(), imw.ctx.GetSessionVars().EnableVectorizedExpression, imw.outerMergeCtx.Filter, chunk.NewIterator4Chunk(chk), task.outerMatch[i]) - if err != nil { - return err - } - } - } - task.memTracker.Consume(int64(cap(task.outerMatch))) - task.outerOrderIdx = make([]chunk.RowPtr, 0, task.outerResult.Len()) - for i := 0; i < numOuterChks; i++ { - numRow := task.outerResult.GetChunk(i).NumRows() - for j := 0; j < numRow; j++ { - task.outerOrderIdx = append(task.outerOrderIdx, chunk.RowPtr{ChkIdx: uint32(i), RowIdx: uint32(j)}) - } - } - task.memTracker.Consume(int64(cap(task.outerOrderIdx))) - // NeedOuterSort means the outer side property items can't guarantee the order of join keys. - // Because the necessary condition of merge join is both outer and inner keep order of join keys. - // In this case, we need sort the outer side. 
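The comment just above explains why the outer rows must be re-sorted when NeedOuterSort is set, and the slices.SortFunc call that follows walks the join keys in order, stops at the first non-zero comparison, and negates the result for descending scans. A minimal, self-contained sketch of that comparator convention (simplified to plain ints; the real code goes through collators and expression compare functions, which are not shown):

    // Illustrative only: multi-key comparison with slices.SortFunc, returning the
    // first non-zero column comparison and flipping the sign for descending order.
    package demo

    import (
        "cmp"
        "slices"
    )

    type rowKey struct{ a, b int }

    func sortKeys(keys []rowKey, desc bool) {
        slices.SortFunc(keys, func(x, y rowKey) int {
            c := cmp.Compare(x.a, y.a)
            if c == 0 {
                c = cmp.Compare(x.b, y.b)
            }
            if desc {
                return -c
            }
            return c
        })
    }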
- if imw.outerMergeCtx.NeedOuterSort { - exprCtx := imw.ctx.GetExprCtx() - slices.SortFunc(task.outerOrderIdx, func(idxI, idxJ chunk.RowPtr) int { - rowI, rowJ := task.outerResult.GetRow(idxI), task.outerResult.GetRow(idxJ) - var c int64 - var err error - for _, keyOff := range imw.KeyOff2KeyOffOrderByIdx { - joinKey := imw.outerMergeCtx.JoinKeys[keyOff] - c, _, err = imw.outerMergeCtx.CompareFuncs[keyOff](exprCtx.GetEvalCtx(), joinKey, joinKey, rowI, rowJ) - terror.Log(err) - if c != 0 { - break - } - } - if c != 0 || imw.nextColCompareFilters == nil { - if imw.Desc { - return int(-c) - } - return int(c) - } - c = int64(imw.nextColCompareFilters.CompareRow(rowI, rowJ)) - if imw.Desc { - return int(-c) - } - return int(c) - }) - } - dLookUpKeys, err := imw.constructDatumLookupKeys(task) - if err != nil { - return err - } - dLookUpKeys = imw.dedupDatumLookUpKeys(dLookUpKeys) - // If the order requires descending, the deDupedLookUpContents is keep descending order before. - // So at the end, we should generate the ascending deDupedLookUpContents to build the correct range for inner read. - if imw.Desc { - lenKeys := len(dLookUpKeys) - for i := 0; i < lenKeys/2; i++ { - dLookUpKeys[i], dLookUpKeys[lenKeys-i-1] = dLookUpKeys[lenKeys-i-1], dLookUpKeys[i] - } - } - imw.innerExec, err = imw.ReaderBuilder.BuildExecutorForIndexJoin(ctx, dLookUpKeys, imw.indexRanges, imw.keyOff2IdxOff, imw.nextColCompareFilters, false, nil, nil) - if imw.innerExec != nil { - defer func() { terror.Log(exec.Close(imw.innerExec)) }() - } - if err != nil { - return err - } - _, err = imw.fetchNextInnerResult(ctx, task) - if err != nil { - return err - } - err = imw.doMergeJoin(ctx, task) - return err -} - -func (imw *innerMergeWorker) fetchNewChunkWhenFull(ctx context.Context, task *lookUpMergeJoinTask, chk **chunk.Chunk) (continueJoin bool) { - if !(*chk).IsFull() { - return true - } - select { - case task.results <- &indexMergeJoinResult{*chk, imw.joinChkResourceCh}: - case <-ctx.Done(): - return false - } - var ok bool - select { - case *chk, ok = <-imw.joinChkResourceCh: - if !ok { - return false - } - case <-ctx.Done(): - return false - } - (*chk).Reset() - return true -} - -func (imw *innerMergeWorker) doMergeJoin(ctx context.Context, task *lookUpMergeJoinTask) (err error) { - var chk *chunk.Chunk - select { - case chk = <-imw.joinChkResourceCh: - case <-ctx.Done(): - return - } - defer func() { - if chk == nil { - return - } - if chk.NumRows() > 0 { - select { - case task.results <- &indexMergeJoinResult{chk, imw.joinChkResourceCh}: - case <-ctx.Done(): - return - } - } else { - imw.joinChkResourceCh <- chk - } - }() - - initCmpResult := 1 - if imw.InnerMergeCtx.Desc { - initCmpResult = -1 - } - noneInnerRowsRemain := task.innerResult.NumRows() == 0 - - for _, outerIdx := range task.outerOrderIdx { - outerRow := task.outerResult.GetRow(outerIdx) - hasMatch, hasNull, cmpResult := false, false, initCmpResult - if task.outerMatch != nil && !task.outerMatch[outerIdx.ChkIdx][outerIdx.RowIdx] { - goto missMatch - } - // If it has iterated out all inner rows and the inner rows with same key is empty, - // that means the Outer Row needn't match any inner rows. 
- if noneInnerRowsRemain && len(task.sameKeyInnerRows) == 0 { - goto missMatch - } - if len(task.sameKeyInnerRows) > 0 { - cmpResult, err = imw.compare(outerRow, task.sameKeyIter.Begin()) - if err != nil { - return err - } - } - if (cmpResult > 0 && !imw.InnerMergeCtx.Desc) || (cmpResult < 0 && imw.InnerMergeCtx.Desc) { - if noneInnerRowsRemain { - task.sameKeyInnerRows = task.sameKeyInnerRows[:0] - goto missMatch - } - noneInnerRowsRemain, err = imw.fetchInnerRowsWithSameKey(ctx, task, outerRow) - if err != nil { - return err - } - } - - for task.sameKeyIter.Current() != task.sameKeyIter.End() { - matched, isNull, err := imw.joiner.TryToMatchInners(outerRow, task.sameKeyIter, chk) - if err != nil { - return err - } - hasMatch = hasMatch || matched - hasNull = hasNull || isNull - if !imw.fetchNewChunkWhenFull(ctx, task, &chk) { - return nil - } - } - - missMatch: - if !hasMatch { - imw.joiner.OnMissMatch(hasNull, outerRow, chk) - if !imw.fetchNewChunkWhenFull(ctx, task, &chk) { - return nil - } - } - } - - return nil -} - -// fetchInnerRowsWithSameKey collects the inner rows having the same key with one outer row. -func (imw *innerMergeWorker) fetchInnerRowsWithSameKey(ctx context.Context, task *lookUpMergeJoinTask, key chunk.Row) (noneInnerRows bool, err error) { - task.sameKeyInnerRows = task.sameKeyInnerRows[:0] - curRow := task.innerIter.Current() - var cmpRes int - for cmpRes, err = imw.compare(key, curRow); ((cmpRes >= 0 && !imw.Desc) || (cmpRes <= 0 && imw.Desc)) && err == nil; cmpRes, err = imw.compare(key, curRow) { - if cmpRes == 0 { - task.sameKeyInnerRows = append(task.sameKeyInnerRows, curRow) - } - curRow = task.innerIter.Next() - if curRow == task.innerIter.End() { - curRow, err = imw.fetchNextInnerResult(ctx, task) - if err != nil || task.innerResult.NumRows() == 0 { - break - } - } - } - task.sameKeyIter = chunk.NewIterator4Slice(task.sameKeyInnerRows) - task.sameKeyIter.Begin() - noneInnerRows = task.innerResult.NumRows() == 0 - return -} - -func (imw *innerMergeWorker) compare(outerRow, innerRow chunk.Row) (int, error) { - exprCtx := imw.ctx.GetExprCtx() - for _, keyOff := range imw.InnerMergeCtx.KeyOff2KeyOffOrderByIdx { - cmp, _, err := imw.InnerMergeCtx.CompareFuncs[keyOff](exprCtx.GetEvalCtx(), imw.outerMergeCtx.JoinKeys[keyOff], imw.InnerMergeCtx.JoinKeys[keyOff], outerRow, innerRow) - if err != nil || cmp != 0 { - return int(cmp), err - } - } - return 0, nil -} - -func (imw *innerMergeWorker) constructDatumLookupKeys(task *lookUpMergeJoinTask) ([]*IndexJoinLookUpContent, error) { - numRows := len(task.outerOrderIdx) - dLookUpKeys := make([]*IndexJoinLookUpContent, 0, numRows) - for i := 0; i < numRows; i++ { - dLookUpKey, err := imw.constructDatumLookupKey(task, task.outerOrderIdx[i]) - if err != nil { - return nil, err - } - if dLookUpKey == nil { - continue - } - dLookUpKeys = append(dLookUpKeys, dLookUpKey) - } - - return dLookUpKeys, nil -} - -func (imw *innerMergeWorker) constructDatumLookupKey(task *lookUpMergeJoinTask, idx chunk.RowPtr) (*IndexJoinLookUpContent, error) { - if task.outerMatch != nil && !task.outerMatch[idx.ChkIdx][idx.RowIdx] { - return nil, nil - } - outerRow := task.outerResult.GetRow(idx) - sc := imw.ctx.GetSessionVars().StmtCtx - keyLen := len(imw.KeyCols) - dLookupKey := make([]types.Datum, 0, keyLen) - for i, keyCol := range imw.outerMergeCtx.KeyCols { - outerValue := outerRow.GetDatum(keyCol, imw.outerMergeCtx.RowTypes[keyCol]) - // Join-on-condition can be promised to be equal-condition in - // IndexNestedLoopJoin, thus the Filter 
will always be false if - // outerValue is null, and we don't need to lookup it. - if outerValue.IsNull() { - return nil, nil - } - innerColType := imw.RowTypes[imw.KeyCols[i]] - innerValue, err := outerValue.ConvertTo(sc.TypeCtx(), innerColType) - if err != nil { - // If the converted outerValue overflows, we don't need to lookup it. - if terror.ErrorEqual(err, types.ErrOverflow) || terror.ErrorEqual(err, types.ErrWarnDataOutOfRange) { - return nil, nil - } - if terror.ErrorEqual(err, types.ErrTruncated) && (innerColType.GetType() == mysql.TypeSet || innerColType.GetType() == mysql.TypeEnum) { - return nil, nil - } - return nil, err - } - cmp, err := outerValue.Compare(sc.TypeCtx(), &innerValue, imw.KeyCollators[i]) - if err != nil { - return nil, err - } - if cmp != 0 { - // If the converted outerValue is not equal to the origin outerValue, we don't need to lookup it. - return nil, nil - } - dLookupKey = append(dLookupKey, innerValue) - } - return &IndexJoinLookUpContent{Keys: dLookupKey, Row: task.outerResult.GetRow(idx)}, nil -} - -func (imw *innerMergeWorker) dedupDatumLookUpKeys(lookUpContents []*IndexJoinLookUpContent) []*IndexJoinLookUpContent { - if len(lookUpContents) < 2 { - return lookUpContents - } - sc := imw.ctx.GetSessionVars().StmtCtx - deDupedLookUpContents := lookUpContents[:1] - for i := 1; i < len(lookUpContents); i++ { - cmp := compareRow(sc, lookUpContents[i].Keys, lookUpContents[i-1].Keys, imw.KeyCollators) - if cmp != 0 || (imw.nextColCompareFilters != nil && imw.nextColCompareFilters.CompareRow(lookUpContents[i].Row, lookUpContents[i-1].Row) != 0) { - deDupedLookUpContents = append(deDupedLookUpContents, lookUpContents[i]) - } - } - return deDupedLookUpContents -} - -// fetchNextInnerResult collects a chunk of inner results from inner child executor. -func (imw *innerMergeWorker) fetchNextInnerResult(ctx context.Context, task *lookUpMergeJoinTask) (beginRow chunk.Row, err error) { - task.innerResult = imw.innerExec.NewChunkWithCapacity(imw.innerExec.RetFieldTypes(), imw.innerExec.MaxChunkSize(), imw.innerExec.MaxChunkSize()) - err = exec.Next(ctx, imw.innerExec, task.innerResult) - task.innerIter = chunk.NewIterator4Chunk(task.innerResult) - beginRow = task.innerIter.Begin() - return -} - -// Close implements the Executor interface. -func (e *IndexLookUpMergeJoin) Close() error { - if e.RuntimeStats() != nil { - defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), e.RuntimeStats()) - } - if e.cancelFunc != nil { - e.cancelFunc() - e.cancelFunc = nil - } - if e.resultCh != nil { - channel.Clear(e.resultCh) - e.resultCh = nil - } - e.joinChkResourceCh = nil - // joinChkResourceCh is to recycle result chunks, used by inner worker. - // resultCh is the main thread get the results, used by main thread and inner worker. - // cancelFunc control the outer worker and outer worker close the task channel. 
- e.WorkerWg.Wait() - e.memTracker = nil - e.prepared = false - return e.BaseExecutor.Close() -} diff --git a/pkg/executor/join/merge_join.go b/pkg/executor/join/merge_join.go index 0be8ff35fae1e..7dc647ce6c592 100644 --- a/pkg/executor/join/merge_join.go +++ b/pkg/executor/join/merge_join.go @@ -103,11 +103,11 @@ func (t *MergeJoinTable) init(executor *MergeJoinExec) { t.rowContainer.GetDiskTracker().SetLabel(memory.LabelForInnerTable) if variable.EnableTmpStorageOnOOM.Load() { actionSpill := t.rowContainer.ActionSpill() - if val, _err_ := failpoint.Eval(_curpkg_("testMergeJoinRowContainerSpill")); _err_ == nil { + failpoint.Inject("testMergeJoinRowContainerSpill", func(val failpoint.Value) { if val.(bool) { actionSpill = t.rowContainer.ActionSpillForTest() } - } + }) executor.Ctx().GetSessionVars().MemTracker.FallbackOldAndSetNewAction(actionSpill) } t.memTracker = memory.NewTracker(memory.LabelForInnerTable, -1) @@ -128,12 +128,12 @@ func (t *MergeJoinTable) finish() error { t.memTracker.Consume(-t.childChunk.MemoryUsage()) if t.IsInner { - if val, _err_ := failpoint.Eval(_curpkg_("testMergeJoinRowContainerSpill")); _err_ == nil { + failpoint.Inject("testMergeJoinRowContainerSpill", func(val failpoint.Value) { if val.(bool) { actionSpill := t.rowContainer.ActionSpill() actionSpill.WaitForTest() } - } + }) if err := t.rowContainer.Close(); err != nil { return err } @@ -330,7 +330,7 @@ func (e *MergeJoinExec) Next(ctx context.Context, req *chunk.Chunk) (err error) innerIter := e.InnerTable.groupRowsIter outerIter := e.OuterTable.groupRowsIter for !req.IsFull() { - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + failpoint.Inject("ConsumeRandomPanic", nil) if innerIter.Current() == innerIter.End() { if err := e.InnerTable.fetchNextInnerGroup(ctx, e); err != nil { return err diff --git a/pkg/executor/join/merge_join.go__failpoint_stash__ b/pkg/executor/join/merge_join.go__failpoint_stash__ deleted file mode 100644 index 7dc647ce6c592..0000000000000 --- a/pkg/executor/join/merge_join.go__failpoint_stash__ +++ /dev/null @@ -1,420 +0,0 @@ -// Copyright 2017 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package join - -import ( - "context" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/executor/internal/vecgroupchecker" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/disk" - "github.com/pingcap/tidb/pkg/util/memory" -) - -var ( - _ exec.Executor = &MergeJoinExec{} -) - -// MergeJoinExec implements the merge join algorithm. -// This operator assumes that two iterators of both sides -// will provide required order on join condition: -// 1. For equal-join, one of the join key from each side -// matches the order given. -// 2. For other cases its preferred not to use SMJ and operator -// will throw error. 
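The MergeJoinExec doc comment above assumes both children already arrive ordered on the join keys, which is what lets the operator pair matching groups in a single forward pass. A stripped-down sketch of that pairing loop (plain sorted int slices stand in for ordered chunks; duplicate keys on either side produce the cross product, and outer-join and filter handling are left out):

    // Illustrative only: the textbook merge-join walk over two inputs that are
    // already sorted ascending on a single integer join key.
    package demo

    func mergeJoin(outer, inner []int) [][2]int {
        var out [][2]int
        i, j := 0, 0
        for i < len(outer) && j < len(inner) {
            switch {
            case outer[i] < inner[j]:
                i++
            case outer[i] > inner[j]:
                j++
            default:
                // Collect the inner group sharing this key, then pair it with
                // every outer row carrying the same key.
                key := outer[i]
                jEnd := j
                for jEnd < len(inner) && inner[jEnd] == key {
                    jEnd++
                }
                for ; i < len(outer) && outer[i] == key; i++ {
                    for k := j; k < jEnd; k++ {
                        out = append(out, [2]int{outer[i], inner[k]})
                    }
                }
                j = jEnd
            }
        }
        return out
    }

For example, mergeJoin([]int{1, 1, 2}, []int{1, 2, 2}) yields (1,1) twice and (2,2) twice, mirroring how the operator emits one result row per matching outer/inner pair within a key group.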
-type MergeJoinExec struct { - exec.BaseExecutor - - StmtCtx *stmtctx.StatementContext - CompareFuncs []expression.CompareFunc - Joiner Joiner - IsOuterJoin bool - Desc bool - - InnerTable *MergeJoinTable - OuterTable *MergeJoinTable - - hasMatch bool - hasNull bool - - memTracker *memory.Tracker - diskTracker *disk.Tracker -} - -// MergeJoinTable is used for merge join -type MergeJoinTable struct { - inited bool - IsInner bool - ChildIndex int - JoinKeys []*expression.Column - Filters []expression.Expression - - executed bool - childChunk *chunk.Chunk - childChunkIter *chunk.Iterator4Chunk - groupChecker *vecgroupchecker.VecGroupChecker - groupRowsSelected []int - groupRowsIter chunk.Iterator - - // for inner table, an unbroken group may refer many chunks - rowContainer *chunk.RowContainer - - // for outer table, save result of filters - filtersSelected []bool - - memTracker *memory.Tracker -} - -func (t *MergeJoinTable) init(executor *MergeJoinExec) { - child := executor.Children(t.ChildIndex) - t.childChunk = exec.TryNewCacheChunk(child) - t.childChunkIter = chunk.NewIterator4Chunk(t.childChunk) - - items := make([]expression.Expression, 0, len(t.JoinKeys)) - for _, col := range t.JoinKeys { - items = append(items, col) - } - vecEnabled := executor.Ctx().GetSessionVars().EnableVectorizedExpression - t.groupChecker = vecgroupchecker.NewVecGroupChecker(executor.Ctx().GetExprCtx().GetEvalCtx(), vecEnabled, items) - t.groupRowsIter = chunk.NewIterator4Chunk(t.childChunk) - - if t.IsInner { - t.rowContainer = chunk.NewRowContainer(child.RetFieldTypes(), t.childChunk.Capacity()) - t.rowContainer.GetMemTracker().AttachTo(executor.memTracker) - t.rowContainer.GetMemTracker().SetLabel(memory.LabelForInnerTable) - t.rowContainer.GetDiskTracker().AttachTo(executor.diskTracker) - t.rowContainer.GetDiskTracker().SetLabel(memory.LabelForInnerTable) - if variable.EnableTmpStorageOnOOM.Load() { - actionSpill := t.rowContainer.ActionSpill() - failpoint.Inject("testMergeJoinRowContainerSpill", func(val failpoint.Value) { - if val.(bool) { - actionSpill = t.rowContainer.ActionSpillForTest() - } - }) - executor.Ctx().GetSessionVars().MemTracker.FallbackOldAndSetNewAction(actionSpill) - } - t.memTracker = memory.NewTracker(memory.LabelForInnerTable, -1) - } else { - t.filtersSelected = make([]bool, 0, executor.MaxChunkSize()) - t.memTracker = memory.NewTracker(memory.LabelForOuterTable, -1) - } - - t.memTracker.AttachTo(executor.memTracker) - t.inited = true - t.memTracker.Consume(t.childChunk.MemoryUsage()) -} - -func (t *MergeJoinTable) finish() error { - if !t.inited { - return nil - } - t.memTracker.Consume(-t.childChunk.MemoryUsage()) - - if t.IsInner { - failpoint.Inject("testMergeJoinRowContainerSpill", func(val failpoint.Value) { - if val.(bool) { - actionSpill := t.rowContainer.ActionSpill() - actionSpill.WaitForTest() - } - }) - if err := t.rowContainer.Close(); err != nil { - return err - } - } - - t.executed = false - t.childChunk = nil - t.childChunkIter = nil - t.groupChecker = nil - t.groupRowsSelected = nil - t.groupRowsIter = nil - t.rowContainer = nil - t.filtersSelected = nil - t.memTracker = nil - return nil -} - -func (t *MergeJoinTable) selectNextGroup() { - t.groupRowsSelected = t.groupRowsSelected[:0] - begin, end := t.groupChecker.GetNextGroup() - if t.IsInner && t.hasNullInJoinKey(t.childChunk.GetRow(begin)) { - return - } - - for i := begin; i < end; i++ { - t.groupRowsSelected = append(t.groupRowsSelected, i) - } - t.childChunk.SetSel(t.groupRowsSelected) -} - -func (t 
*MergeJoinTable) fetchNextChunk(ctx context.Context, executor *MergeJoinExec) error { - oldMemUsage := t.childChunk.MemoryUsage() - err := exec.Next(ctx, executor.Children(t.ChildIndex), t.childChunk) - t.memTracker.Consume(t.childChunk.MemoryUsage() - oldMemUsage) - if err != nil { - return err - } - t.executed = t.childChunk.NumRows() == 0 - return nil -} - -func (t *MergeJoinTable) fetchNextInnerGroup(ctx context.Context, exec *MergeJoinExec) error { - t.childChunk.SetSel(nil) - if err := t.rowContainer.Reset(); err != nil { - return err - } - -fetchNext: - if t.executed && t.groupChecker.IsExhausted() { - // Ensure iter at the end, since sel of childChunk has been cleared. - t.groupRowsIter.ReachEnd() - return nil - } - - isEmpty := true - // For inner table, rows have null in join keys should be skip by selectNextGroup. - for isEmpty && !t.groupChecker.IsExhausted() { - t.selectNextGroup() - isEmpty = len(t.groupRowsSelected) == 0 - } - - // For inner table, all the rows have the same join keys should be put into one group. - for !t.executed && t.groupChecker.IsExhausted() { - if !isEmpty { - // Group is not empty, hand over the management of childChunk to t.RowContainer. - if err := t.rowContainer.Add(t.childChunk); err != nil { - return err - } - t.memTracker.Consume(-t.childChunk.MemoryUsage()) - t.groupRowsSelected = nil - - t.childChunk = t.rowContainer.AllocChunk() - t.childChunkIter = chunk.NewIterator4Chunk(t.childChunk) - t.memTracker.Consume(t.childChunk.MemoryUsage()) - } - - if err := t.fetchNextChunk(ctx, exec); err != nil { - return err - } - if t.executed { - break - } - - isFirstGroupSameAsPrev, err := t.groupChecker.SplitIntoGroups(t.childChunk) - if err != nil { - return err - } - if isFirstGroupSameAsPrev && !isEmpty { - t.selectNextGroup() - } - } - if isEmpty { - goto fetchNext - } - - // iterate all data in t.RowContainer and t.childChunk - var iter chunk.Iterator - if t.rowContainer.NumChunks() != 0 { - iter = chunk.NewIterator4RowContainer(t.rowContainer) - } - if len(t.groupRowsSelected) != 0 { - if iter != nil { - iter = chunk.NewMultiIterator(iter, t.childChunkIter) - } else { - iter = t.childChunkIter - } - } - t.groupRowsIter = iter - t.groupRowsIter.Begin() - return nil -} - -func (t *MergeJoinTable) fetchNextOuterGroup(ctx context.Context, exec *MergeJoinExec, requiredRows int) error { - if t.executed && t.groupChecker.IsExhausted() { - return nil - } - - if !t.executed && t.groupChecker.IsExhausted() { - // It's hard to calculate selectivity if there is any filter or it's inner join, - // so we just push the requiredRows down when it's outer join and has no filter. - if exec.IsOuterJoin && len(t.Filters) == 0 { - t.childChunk.SetRequiredRows(requiredRows, exec.MaxChunkSize()) - } - err := t.fetchNextChunk(ctx, exec) - if err != nil || t.executed { - return err - } - - t.childChunkIter.Begin() - t.filtersSelected, err = expression.VectorizedFilter(exec.Ctx().GetExprCtx().GetEvalCtx(), exec.Ctx().GetSessionVars().EnableVectorizedExpression, t.Filters, t.childChunkIter, t.filtersSelected) - if err != nil { - return err - } - - _, err = t.groupChecker.SplitIntoGroups(t.childChunk) - if err != nil { - return err - } - } - - t.selectNextGroup() - t.groupRowsIter.Begin() - return nil -} - -func (t *MergeJoinTable) hasNullInJoinKey(row chunk.Row) bool { - for _, col := range t.JoinKeys { - ordinal := col.Index - if row.IsNull(ordinal) { - return true - } - } - return false -} - -// Close implements the Executor Close interface. 
-func (e *MergeJoinExec) Close() error { - if err := e.InnerTable.finish(); err != nil { - return err - } - if err := e.OuterTable.finish(); err != nil { - return err - } - - e.hasMatch = false - e.hasNull = false - e.memTracker = nil - e.diskTracker = nil - return e.BaseExecutor.Close() -} - -// Open implements the Executor Open interface. -func (e *MergeJoinExec) Open(ctx context.Context) error { - if err := e.BaseExecutor.Open(ctx); err != nil { - return err - } - - e.memTracker = memory.NewTracker(e.ID(), -1) - e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) - e.diskTracker = disk.NewTracker(e.ID(), -1) - e.diskTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.DiskTracker) - - e.InnerTable.init(e) - e.OuterTable.init(e) - return nil -} - -// Next implements the Executor Next interface. -// Note the inner group collects all identical keys in a group across multiple chunks, but the outer group just covers -// the identical keys within a chunk, so identical keys may cover more than one chunk. -func (e *MergeJoinExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { - req.Reset() - - innerIter := e.InnerTable.groupRowsIter - outerIter := e.OuterTable.groupRowsIter - for !req.IsFull() { - failpoint.Inject("ConsumeRandomPanic", nil) - if innerIter.Current() == innerIter.End() { - if err := e.InnerTable.fetchNextInnerGroup(ctx, e); err != nil { - return err - } - innerIter = e.InnerTable.groupRowsIter - } - if outerIter.Current() == outerIter.End() { - if err := e.OuterTable.fetchNextOuterGroup(ctx, e, req.RequiredRows()-req.NumRows()); err != nil { - return err - } - outerIter = e.OuterTable.groupRowsIter - if e.OuterTable.executed { - return nil - } - } - - cmpResult := -1 - if e.Desc { - cmpResult = 1 - } - if innerIter.Current() != innerIter.End() { - cmpResult, err = e.compare(outerIter.Current(), innerIter.Current()) - if err != nil { - return err - } - } - // the inner group falls behind - if (cmpResult > 0 && !e.Desc) || (cmpResult < 0 && e.Desc) { - innerIter.ReachEnd() - continue - } - // the Outer group falls behind - if (cmpResult < 0 && !e.Desc) || (cmpResult > 0 && e.Desc) { - for row := outerIter.Current(); row != outerIter.End() && !req.IsFull(); row = outerIter.Next() { - e.Joiner.OnMissMatch(false, row, req) - } - continue - } - - for row := outerIter.Current(); row != outerIter.End() && !req.IsFull(); row = outerIter.Next() { - if !e.OuterTable.filtersSelected[row.Idx()] { - e.Joiner.OnMissMatch(false, row, req) - continue - } - // compare each Outer item with each inner item - // the inner maybe not exhausted at one time - for innerIter.Current() != innerIter.End() { - matched, isNull, err := e.Joiner.TryToMatchInners(row, innerIter, req) - if err != nil { - return err - } - e.hasMatch = e.hasMatch || matched - e.hasNull = e.hasNull || isNull - if req.IsFull() { - if innerIter.Current() == innerIter.End() { - break - } - return nil - } - } - - if !e.hasMatch { - e.Joiner.OnMissMatch(e.hasNull, row, req) - } - e.hasMatch = false - e.hasNull = false - innerIter.Begin() - } - } - return nil -} - -func (e *MergeJoinExec) compare(outerRow, innerRow chunk.Row) (int, error) { - outerJoinKeys := e.OuterTable.JoinKeys - innerJoinKeys := e.InnerTable.JoinKeys - for i := range outerJoinKeys { - cmp, _, err := e.CompareFuncs[i](e.Ctx().GetExprCtx().GetEvalCtx(), outerJoinKeys[i], innerJoinKeys[i], outerRow, innerRow) - if err != nil { - return 0, err - } - - if cmp != 0 { - return int(cmp), nil - } - } - return 0, nil -} diff --git 
a/pkg/executor/load_data.go b/pkg/executor/load_data.go index 91c4fc40ee0a5..e6471546d36a5 100644 --- a/pkg/executor/load_data.go +++ b/pkg/executor/load_data.go @@ -242,7 +242,7 @@ func (e *LoadDataWorker) load(ctx context.Context, readerInfos []importer.LoadDa }) // commitWork goroutines. group.Go(func() error { - failpoint.Eval(_curpkg_("BeforeCommitWork")) + failpoint.Inject("BeforeCommitWork", nil) return committer.commitWork(groupCtx, commitTaskCh) }) @@ -620,9 +620,9 @@ func (w *commitWorker) commitOneTask(ctx context.Context, task commitTask) error logutil.Logger(ctx).Error("commit error CheckAndInsert", zap.Error(err)) return err } - if _, _err_ := failpoint.Eval(_curpkg_("commitOneTaskErr")); _err_ == nil { - return errors.New("mock commit one task error") - } + failpoint.Inject("commitOneTaskErr", func() { + failpoint.Return(errors.New("mock commit one task error")) + }) return nil } diff --git a/pkg/executor/load_data.go__failpoint_stash__ b/pkg/executor/load_data.go__failpoint_stash__ deleted file mode 100644 index e6471546d36a5..0000000000000 --- a/pkg/executor/load_data.go__failpoint_stash__ +++ /dev/null @@ -1,780 +0,0 @@ -// Copyright 2018 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package executor - -import ( - "context" - "fmt" - "io" - "math" - "strings" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/pkg/errctx" - "github.com/pingcap/tidb/pkg/executor/importer" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/lightning/mydump" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - "github.com/pingcap/tidb/pkg/sessiontxn" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - contextutil "github.com/pingcap/tidb/pkg/util/context" - "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/sqlkiller" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" -) - -// LoadDataVarKey is a variable key for load data. -const LoadDataVarKey loadDataVarKeyType = 0 - -// LoadDataReaderBuilderKey stores the reader channel that reads from the connection. -const LoadDataReaderBuilderKey loadDataVarKeyType = 1 - -var ( - taskQueueSize = 16 // the maximum number of pending tasks to commit in queue -) - -// LoadDataReaderBuilder is a function type that builds a reader from a file path. -type LoadDataReaderBuilder func(filepath string) ( - r io.ReadCloser, err error, -) - -// LoadDataExec represents a load data executor. 
-type LoadDataExec struct { - exec.BaseExecutor - - FileLocRef ast.FileLocRefTp - loadDataWorker *LoadDataWorker - - // fields for loading local file - infileReader io.ReadCloser -} - -// Open implements the Executor interface. -func (e *LoadDataExec) Open(_ context.Context) error { - if rb, ok := e.Ctx().Value(LoadDataReaderBuilderKey).(LoadDataReaderBuilder); ok { - var err error - e.infileReader, err = rb(e.loadDataWorker.GetInfilePath()) - if err != nil { - return err - } - } - return nil -} - -// Close implements the Executor interface. -func (e *LoadDataExec) Close() error { - return e.closeLocalReader(nil) -} - -func (e *LoadDataExec) closeLocalReader(originalErr error) error { - err := originalErr - if e.infileReader != nil { - if err2 := e.infileReader.Close(); err2 != nil { - logutil.BgLogger().Error( - "close local reader failed", zap.Error(err2), - zap.NamedError("original error", originalErr), - ) - if err == nil { - err = err2 - } - } - e.infileReader = nil - } - return err -} - -// Next implements the Executor Next interface. -func (e *LoadDataExec) Next(ctx context.Context, _ *chunk.Chunk) (err error) { - switch e.FileLocRef { - case ast.FileLocServerOrRemote: - return e.loadDataWorker.loadRemote(ctx) - case ast.FileLocClient: - // This is for legacy test only - // TODO: adjust tests to remove LoadDataVarKey - sctx := e.loadDataWorker.UserSctx - sctx.SetValue(LoadDataVarKey, e.loadDataWorker) - - err = e.loadDataWorker.LoadLocal(ctx, e.infileReader) - if err != nil { - logutil.Logger(ctx).Error("load local data failed", zap.Error(err)) - err = e.closeLocalReader(err) - return err - } - } - return nil -} - -type planInfo struct { - ID int - Columns []*ast.ColumnName - GenColExprs []expression.Expression -} - -// LoadDataWorker does a LOAD DATA job. -type LoadDataWorker struct { - UserSctx sessionctx.Context - - controller *importer.LoadDataController - planInfo planInfo - - table table.Table -} - -func setNonRestrictiveFlags(stmtCtx *stmtctx.StatementContext) { - // TODO: DupKeyAsWarning represents too many "ignore error" paths, the - // meaning of this flag is not clear. I can only reuse it here. - levels := stmtCtx.ErrLevels() - levels[errctx.ErrGroupDupKey] = errctx.LevelWarn - levels[errctx.ErrGroupBadNull] = errctx.LevelWarn - stmtCtx.SetErrLevels(levels) - stmtCtx.SetTypeFlags(stmtCtx.TypeFlags().WithTruncateAsWarning(true)) -} - -// NewLoadDataWorker creates a new LoadDataWorker that is ready to work. -func NewLoadDataWorker( - userSctx sessionctx.Context, - plan *plannercore.LoadData, - tbl table.Table, -) (w *LoadDataWorker, err error) { - importPlan, err := importer.NewPlanFromLoadDataPlan(userSctx, plan) - if err != nil { - return nil, err - } - astArgs := importer.ASTArgsFromPlan(plan) - controller, err := importer.NewLoadDataController(importPlan, tbl, astArgs) - if err != nil { - return nil, err - } - - if !controller.Restrictive { - setNonRestrictiveFlags(userSctx.GetSessionVars().StmtCtx) - } - - loadDataWorker := &LoadDataWorker{ - UserSctx: userSctx, - table: tbl, - controller: controller, - planInfo: planInfo{ - ID: plan.ID(), - Columns: plan.Columns, - GenColExprs: plan.GenCols.Exprs, - }, - } - return loadDataWorker, nil -} - -func (e *LoadDataWorker) loadRemote(ctx context.Context) error { - if err2 := e.controller.InitDataFiles(ctx); err2 != nil { - return err2 - } - return e.load(ctx, e.controller.GetLoadDataReaderInfos()) -} - -// LoadLocal reads from client connection and do load data job. 
-func (e *LoadDataWorker) LoadLocal(ctx context.Context, r io.ReadCloser) error { - if r == nil { - return errors.New("load local data, reader is nil") - } - - compressTp := mydump.ParseCompressionOnFileExtension(e.GetInfilePath()) - compressTp2, err := mydump.ToStorageCompressType(compressTp) - if err != nil { - return err - } - readers := []importer.LoadDataReaderInfo{{ - Opener: func(_ context.Context) (io.ReadSeekCloser, error) { - addedSeekReader := NewSimpleSeekerOnReadCloser(r) - return storage.InterceptDecompressReader(addedSeekReader, compressTp2, storage.DecompressConfig{ - ZStdDecodeConcurrency: 1, - }) - }}} - return e.load(ctx, readers) -} - -func (e *LoadDataWorker) load(ctx context.Context, readerInfos []importer.LoadDataReaderInfo) error { - group, groupCtx := errgroup.WithContext(ctx) - - encoder, committer, err := initEncodeCommitWorkers(e) - if err != nil { - return err - } - - // main goroutine -> readerInfoCh -> processOneStream goroutines - readerInfoCh := make(chan importer.LoadDataReaderInfo, 1) - // processOneStream goroutines -> commitTaskCh -> commitWork goroutines - commitTaskCh := make(chan commitTask, taskQueueSize) - // commitWork goroutines -> done -> UpdateJobProgress goroutine - - // processOneStream goroutines. - group.Go(func() error { - err2 := encoder.processStream(groupCtx, readerInfoCh, commitTaskCh) - if err2 == nil { - close(commitTaskCh) - } - return err2 - }) - // commitWork goroutines. - group.Go(func() error { - failpoint.Inject("BeforeCommitWork", nil) - return committer.commitWork(groupCtx, commitTaskCh) - }) - -sendReaderInfoLoop: - for _, info := range readerInfos { - select { - case <-groupCtx.Done(): - break sendReaderInfoLoop - case readerInfoCh <- info: - } - } - close(readerInfoCh) - err = group.Wait() - e.setResult(encoder.exprWarnings) - return err -} - -func (e *LoadDataWorker) setResult(colAssignExprWarnings []contextutil.SQLWarn) { - stmtCtx := e.UserSctx.GetSessionVars().StmtCtx - numWarnings := uint64(stmtCtx.WarningCount()) - numRecords := stmtCtx.RecordRows() - numDeletes := stmtCtx.DeletedRows() - numSkipped := stmtCtx.RecordRows() - stmtCtx.CopiedRows() - - // col assign expr warnings is generated during init, it's static - // we need to generate it for each row processed. - numWarnings += numRecords * uint64(len(colAssignExprWarnings)) - - if numWarnings > math.MaxUint16 { - numWarnings = math.MaxUint16 - } - - msg := fmt.Sprintf(mysql.MySQLErrName[mysql.ErrLoadInfo].Raw, numRecords, numDeletes, numSkipped, numWarnings) - warns := make([]contextutil.SQLWarn, numWarnings) - n := copy(warns, stmtCtx.GetWarnings()) - for i := 0; i < int(numRecords) && n < len(warns); i++ { - n += copy(warns[n:], colAssignExprWarnings) - } - - stmtCtx.SetMessage(msg) - stmtCtx.SetWarnings(warns) -} - -func initEncodeCommitWorkers(e *LoadDataWorker) (*encodeWorker, *commitWorker, error) { - insertValues, err2 := createInsertValues(e) - if err2 != nil { - return nil, nil, err2 - } - colAssignExprs, exprWarnings, err2 := e.controller.CreateColAssignExprs(insertValues.Ctx()) - if err2 != nil { - return nil, nil, err2 - } - enc := &encodeWorker{ - InsertValues: insertValues, - controller: e.controller, - colAssignExprs: colAssignExprs, - exprWarnings: exprWarnings, - killer: &e.UserSctx.GetSessionVars().SQLKiller, - } - enc.resetBatch() - com := &commitWorker{ - InsertValues: insertValues, - controller: e.controller, - } - return enc, com, nil -} - -// createInsertValues creates InsertValues from userSctx. 
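The load method above wires one encode goroutine and one commit goroutine through a buffered channel under errgroup.WithContext: the encoder closes commitTaskCh only when it finishes cleanly, and the first error from either side cancels the shared context so the other side can stop. A reduced sketch of that shape (ints instead of commitTask batches; the function and variable names here are invented for the sketch):

    // Illustrative only: producer/consumer over a buffered channel with
    // errgroup-driven cancellation.
    package demo

    import (
        "context"

        "golang.org/x/sync/errgroup"
    )

    func pipeline(ctx context.Context, inputs []int) error {
        group, groupCtx := errgroup.WithContext(ctx)
        taskCh := make(chan int, 16)

        // Producer: the channel is closed only on clean completion; on error the
        // group cancels groupCtx and the consumer exits through its Done branch.
        group.Go(func() error {
            for _, v := range inputs {
                select {
                case <-groupCtx.Done():
                    return groupCtx.Err()
                case taskCh <- v:
                }
            }
            close(taskCh)
            return nil
        })

        // Consumer: drain until the channel closes or the context is cancelled.
        group.Go(func() error {
            for {
                select {
                case <-groupCtx.Done():
                    return groupCtx.Err()
                case v, ok := <-taskCh:
                    if !ok {
                        return nil
                    }
                    _ = v // a real consumer would commit the batch here
                }
            }
        })

        return group.Wait()
    }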
-func createInsertValues(e *LoadDataWorker) (insertVal *InsertValues, err error) { - insertColumns := e.controller.InsertColumns - hasExtraHandle := false - for _, col := range insertColumns { - if col.Name.L == model.ExtraHandleName.L { - if !e.UserSctx.GetSessionVars().AllowWriteRowID { - return nil, errors.Errorf("load data statement for _tidb_rowid are not supported") - } - hasExtraHandle = true - break - } - } - ret := &InsertValues{ - BaseExecutor: exec.NewBaseExecutor(e.UserSctx, nil, e.planInfo.ID), - Table: e.table, - Columns: e.planInfo.Columns, - GenExprs: e.planInfo.GenColExprs, - maxRowsInBatch: 1000, - insertColumns: insertColumns, - rowLen: len(insertColumns), - hasExtraHandle: hasExtraHandle, - } - if len(insertColumns) > 0 { - ret.initEvalBuffer() - } - ret.collectRuntimeStatsEnabled() - return ret, nil -} - -// encodeWorker is a sub-worker of LoadDataWorker that dedicated to encode data. -type encodeWorker struct { - *InsertValues - controller *importer.LoadDataController - colAssignExprs []expression.Expression - // sessionCtx generate warnings when rewrite AST node into expression. - // we should generate such warnings for each row encoded. - exprWarnings []contextutil.SQLWarn - killer *sqlkiller.SQLKiller - rows [][]types.Datum -} - -// commitTask is used for passing data from processStream goroutine to commitWork goroutine. -type commitTask struct { - cnt uint64 - rows [][]types.Datum -} - -// processStream always tries to build a parser from channel and process it. When -// it returns nil, it means all data is read. -func (w *encodeWorker) processStream( - ctx context.Context, - inCh <-chan importer.LoadDataReaderInfo, - outCh chan<- commitTask, -) error { - for { - select { - case <-ctx.Done(): - return ctx.Err() - case readerInfo, ok := <-inCh: - if !ok { - return nil - } - dataParser, err := w.controller.GetParser(ctx, readerInfo) - if err != nil { - return err - } - if err = w.controller.HandleSkipNRows(dataParser); err != nil { - return err - } - err = w.processOneStream(ctx, dataParser, outCh) - terror.Log(dataParser.Close()) - if err != nil { - return err - } - } - } -} - -// processOneStream process input stream from parser. When returns nil, it means -// all data is read. -func (w *encodeWorker) processOneStream( - ctx context.Context, - parser mydump.Parser, - outCh chan<- commitTask, -) (err error) { - defer func() { - r := recover() - if r != nil { - logutil.Logger(ctx).Error("process routine panicked", - zap.Any("r", r), - zap.Stack("stack")) - err = util.GetRecoverError(r) - } - }() - - checkKilled := time.NewTicker(30 * time.Second) - defer checkKilled.Stop() - - for { - // prepare batch and enqueue task - if err = w.readOneBatchRows(ctx, parser); err != nil { - return - } - if w.curBatchCnt == 0 { - return - } - - TrySendTask: - select { - case <-ctx.Done(): - return ctx.Err() - case <-checkKilled.C: - if err := w.killer.HandleSignal(); err != nil { - logutil.Logger(ctx).Info("load data query interrupted quit data processing") - return err - } - goto TrySendTask - case outCh <- commitTask{ - cnt: w.curBatchCnt, - rows: w.rows, - }: - } - // reset rows buffer, will reallocate buffer but NOT reuse - w.resetBatch() - } -} - -func (w *encodeWorker) resetBatch() { - w.rows = make([][]types.Datum, 0, w.maxRowsInBatch) - w.curBatchCnt = 0 -} - -// readOneBatchRows reads rows from parser. When parser's reader meet EOF, it -// will return nil. For other errors it will return directly. When the rows -// batch is full it will also return nil. 
-// The result rows are saved in w.rows and update some members, caller can check -// if curBatchCnt == 0 to know if reached EOF. -func (w *encodeWorker) readOneBatchRows(ctx context.Context, parser mydump.Parser) error { - for { - if err := parser.ReadRow(); err != nil { - if errors.Cause(err) == io.EOF { - return nil - } - return exeerrors.ErrLoadDataCantRead.GenWithStackByArgs( - err.Error(), - "Only the following formats delimited text file (csv, tsv), parquet, sql are supported. Please provide the valid source file(s)", - ) - } - // rowCount will be used in fillRow(), last insert ID will be assigned according to the rowCount = 1. - // So should add first here. - w.rowCount++ - r, err := w.parserData2TableData(ctx, parser.LastRow().Row) - if err != nil { - return err - } - parser.RecycleRow(parser.LastRow()) - w.rows = append(w.rows, r) - w.curBatchCnt++ - if w.maxRowsInBatch != 0 && w.rowCount%w.maxRowsInBatch == 0 { - logutil.Logger(ctx).Info("batch limit hit when inserting rows", zap.Int("maxBatchRows", w.MaxChunkSize()), - zap.Uint64("totalRows", w.rowCount)) - return nil - } - } -} - -// parserData2TableData encodes the data of parser output. -func (w *encodeWorker) parserData2TableData( - ctx context.Context, - parserData []types.Datum, -) ([]types.Datum, error) { - var errColNumMismatch error - switch { - case len(parserData) < w.controller.GetFieldCount(): - errColNumMismatch = exeerrors.ErrWarnTooFewRecords.GenWithStackByArgs(w.rowCount) - case len(parserData) > w.controller.GetFieldCount(): - errColNumMismatch = exeerrors.ErrWarnTooManyRecords.GenWithStackByArgs(w.rowCount) - } - - if errColNumMismatch != nil { - if w.controller.Restrictive { - return nil, errColNumMismatch - } - w.handleWarning(errColNumMismatch) - } - - row := make([]types.Datum, 0, len(w.insertColumns)) - sessionVars := w.Ctx().GetSessionVars() - setVar := func(name string, col *types.Datum) { - // User variable names are not case-sensitive - // https://dev.mysql.com/doc/refman/8.0/en/user-variables.html - name = strings.ToLower(name) - if col == nil || col.IsNull() { - sessionVars.UnsetUserVar(name) - } else { - sessionVars.SetUserVarVal(name, *col) - } - } - - fieldMappings := w.controller.FieldMappings - for i := 0; i < len(fieldMappings); i++ { - if i >= len(parserData) { - if fieldMappings[i].Column == nil { - setVar(fieldMappings[i].UserVar.Name, nil) - continue - } - - // If some columns is missing and their type is time and has not null flag, they should be set as current time. - if types.IsTypeTime(fieldMappings[i].Column.GetType()) && mysql.HasNotNullFlag(fieldMappings[i].Column.GetFlag()) { - row = append(row, types.NewTimeDatum(types.CurrentTime(fieldMappings[i].Column.GetType()))) - continue - } - - row = append(row, types.NewDatum(nil)) - continue - } - - if fieldMappings[i].Column == nil { - setVar(fieldMappings[i].UserVar.Name, &parserData[i]) - continue - } - - // Don't set the value for generated columns. 
- if fieldMappings[i].Column.IsGenerated() { - row = append(row, types.NewDatum(nil)) - continue - } - - row = append(row, parserData[i]) - } - for i := 0; i < len(w.colAssignExprs); i++ { - // eval expression of `SET` clause - d, err := w.colAssignExprs[i].Eval(w.Ctx().GetExprCtx().GetEvalCtx(), chunk.Row{}) - if err != nil { - if w.controller.Restrictive { - return nil, err - } - w.handleWarning(err) - } - row = append(row, d) - } - - // a new row buffer will be allocated in getRow - newRow, err := w.getRow(ctx, row) - if err != nil { - if w.controller.Restrictive { - return nil, err - } - w.handleWarning(err) - logutil.Logger(ctx).Error("failed to get row", zap.Error(err)) - // TODO: should not return nil! caller will panic when lookup index - return nil, nil - } - - return newRow, nil -} - -// commitWorker is a sub-worker of LoadDataWorker that dedicated to commit data. -type commitWorker struct { - *InsertValues - controller *importer.LoadDataController -} - -// commitWork commit batch sequentially. When returns nil, it means the job is -// finished. -func (w *commitWorker) commitWork(ctx context.Context, inCh <-chan commitTask) (err error) { - defer func() { - r := recover() - if r != nil { - logutil.Logger(ctx).Error("commitWork panicked", - zap.Any("r", r), - zap.Stack("stack")) - err = util.GetRecoverError(r) - } - }() - - var ( - taskCnt uint64 - ) - for { - select { - case <-ctx.Done(): - return ctx.Err() - case task, ok := <-inCh: - if !ok { - return nil - } - start := time.Now() - if err = w.commitOneTask(ctx, task); err != nil { - return err - } - taskCnt++ - logutil.Logger(ctx).Info("commit one task success", - zap.Duration("commit time usage", time.Since(start)), - zap.Uint64("keys processed", task.cnt), - zap.Uint64("taskCnt processed", taskCnt), - ) - } - } -} - -// commitOneTask insert Data from LoadDataWorker.rows, then commit the modification -// like a statement. -func (w *commitWorker) commitOneTask(ctx context.Context, task commitTask) error { - err := w.checkAndInsertOneBatch(ctx, task.rows, task.cnt) - if err != nil { - logutil.Logger(ctx).Error("commit error CheckAndInsert", zap.Error(err)) - return err - } - failpoint.Inject("commitOneTaskErr", func() { - failpoint.Return(errors.New("mock commit one task error")) - }) - return nil -} - -func (w *commitWorker) checkAndInsertOneBatch(ctx context.Context, rows [][]types.Datum, cnt uint64) error { - if w.stats != nil && w.stats.BasicRuntimeStats != nil { - // Since this method will not call by executor Next, - // so we need record the basic executor runtime stats by ourselves. 
- start := time.Now() - defer func() { - w.stats.BasicRuntimeStats.Record(time.Since(start), 0) - }() - } - var err error - if cnt == 0 { - return err - } - w.Ctx().GetSessionVars().StmtCtx.AddRecordRows(cnt) - - switch w.controller.OnDuplicate { - case ast.OnDuplicateKeyHandlingReplace: - return w.batchCheckAndInsert(ctx, rows[0:cnt], w.addRecordLD, true) - case ast.OnDuplicateKeyHandlingIgnore: - return w.batchCheckAndInsert(ctx, rows[0:cnt], w.addRecordLD, false) - case ast.OnDuplicateKeyHandlingError: - for i, row := range rows[0:cnt] { - sizeHintStep := int(w.Ctx().GetSessionVars().ShardAllocateStep) - if sizeHintStep > 0 && i%sizeHintStep == 0 { - sizeHint := sizeHintStep - remain := len(rows[0:cnt]) - i - if sizeHint > remain { - sizeHint = remain - } - err = w.addRecordWithAutoIDHint(ctx, row, sizeHint, table.DupKeyCheckDefault) - } else { - err = w.addRecord(ctx, row, table.DupKeyCheckDefault) - } - if err != nil { - return err - } - w.Ctx().GetSessionVars().StmtCtx.AddCopiedRows(1) - } - return nil - default: - return errors.Errorf("unknown on duplicate key handling: %v", w.controller.OnDuplicate) - } -} - -func (w *commitWorker) addRecordLD(ctx context.Context, row []types.Datum, dupKeyCheck table.DupKeyCheckMode) error { - if row == nil { - return nil - } - return w.addRecord(ctx, row, dupKeyCheck) -} - -// GetInfilePath get infile path. -func (e *LoadDataWorker) GetInfilePath() string { - return e.controller.Path -} - -// GetController get load data controller. -// used in unit test. -func (e *LoadDataWorker) GetController() *importer.LoadDataController { - return e.controller -} - -// TestLoadLocal is a helper function for unit test. -func (e *LoadDataWorker) TestLoadLocal(parser mydump.Parser) error { - if err := ResetContextOfStmt(e.UserSctx, &ast.LoadDataStmt{}); err != nil { - return err - } - setNonRestrictiveFlags(e.UserSctx.GetSessionVars().StmtCtx) - encoder, committer, err := initEncodeCommitWorkers(e) - if err != nil { - return err - } - - ctx := context.Background() - err = sessiontxn.NewTxn(ctx, e.UserSctx) - if err != nil { - return err - } - - for i := uint64(0); i < e.controller.IgnoreLines; i++ { - //nolint: errcheck - _ = parser.ReadRow() - } - - err = encoder.readOneBatchRows(ctx, parser) - if err != nil { - return err - } - - err = committer.checkAndInsertOneBatch( - ctx, - encoder.rows, - encoder.curBatchCnt) - if err != nil { - return err - } - encoder.resetBatch() - committer.Ctx().StmtCommit(ctx) - err = committer.Ctx().CommitTxn(ctx) - if err != nil { - return err - } - e.setResult(encoder.exprWarnings) - return nil -} - -var _ io.ReadSeekCloser = (*SimpleSeekerOnReadCloser)(nil) - -// SimpleSeekerOnReadCloser provides Seek(0, SeekCurrent) on ReadCloser. -type SimpleSeekerOnReadCloser struct { - r io.ReadCloser - pos int -} - -// NewSimpleSeekerOnReadCloser creates a SimpleSeekerOnReadCloser. -func NewSimpleSeekerOnReadCloser(r io.ReadCloser) *SimpleSeekerOnReadCloser { - return &SimpleSeekerOnReadCloser{r: r} -} - -// Read implements io.Reader. -func (s *SimpleSeekerOnReadCloser) Read(p []byte) (n int, err error) { - n, err = s.r.Read(p) - s.pos += n - return -} - -// Seek implements io.Seeker. -func (s *SimpleSeekerOnReadCloser) Seek(offset int64, whence int) (int64, error) { - // only support get reader's current offset - if offset == 0 && whence == io.SeekCurrent { - return int64(s.pos), nil - } - return 0, errors.Errorf("unsupported seek on SimpleSeekerOnReadCloser, offset: %d whence: %d", offset, whence) -} - -// Close implements io.Closer. 
-func (s *SimpleSeekerOnReadCloser) Close() error { - return s.r.Close() -} - -// GetFileSize implements storage.ExternalFileReader. -func (*SimpleSeekerOnReadCloser) GetFileSize() (int64, error) { - return 0, errors.Errorf("unsupported GetFileSize on SimpleSeekerOnReadCloser") -} - -// loadDataVarKeyType is a dummy type to avoid naming collision in context. -type loadDataVarKeyType int - -// String defines a Stringer function for debugging and pretty printing. -func (loadDataVarKeyType) String() string { - return "load_data_var" -} diff --git a/pkg/executor/memtable_reader.go b/pkg/executor/memtable_reader.go index 286ac5c9bb6f9..316439a638b99 100644 --- a/pkg/executor/memtable_reader.go +++ b/pkg/executor/memtable_reader.go @@ -170,12 +170,12 @@ func fetchClusterConfig(sctx sessionctx.Context, nodeTypes, nodeAddrs set.String return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("CONFIG") } serversInfo, err := infoschema.GetClusterServerInfo(sctx) - if val, _err_ := failpoint.Eval(_curpkg_("mockClusterConfigServerInfo")); _err_ == nil { + failpoint.Inject("mockClusterConfigServerInfo", func(val failpoint.Value) { if s := val.(string); len(s) > 0 { // erase the error serversInfo, err = parseFailpointServerInfo(s), nil } - } + }) if err != nil { return nil, err } @@ -394,13 +394,13 @@ func (e *clusterLogRetriever) initialize(ctx context.Context, sctx sessionctx.Co return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("PROCESS") } serversInfo, err := infoschema.GetClusterServerInfo(sctx) - if val, _err_ := failpoint.Eval(_curpkg_("mockClusterLogServerInfo")); _err_ == nil { + failpoint.Inject("mockClusterLogServerInfo", func(val failpoint.Value) { // erase the error err = nil if s := val.(string); len(s) > 0 { serversInfo = parseFailpointServerInfo(s) } - } + }) if err != nil { return nil, err } diff --git a/pkg/executor/memtable_reader.go__failpoint_stash__ b/pkg/executor/memtable_reader.go__failpoint_stash__ deleted file mode 100644 index 316439a638b99..0000000000000 --- a/pkg/executor/memtable_reader.go__failpoint_stash__ +++ /dev/null @@ -1,1009 +0,0 @@ -// Copyright 2019 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package executor - -import ( - "bytes" - "cmp" - "container/heap" - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "slices" - "strings" - "sync" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/diagnosticspb" - "github.com/pingcap/sysutil" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/sessiontxn" - "github.com/pingcap/tidb/pkg/store/helper" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/set" - pd "github.com/tikv/pd/client/http" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" -) - -const clusterLogBatchSize = 256 -const hotRegionsHistoryBatchSize = 256 - -type dummyCloser struct{} - -func (dummyCloser) close() error { return nil } - -func (dummyCloser) getRuntimeStats() execdetails.RuntimeStats { return nil } - -type memTableRetriever interface { - retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) - close() error - getRuntimeStats() execdetails.RuntimeStats -} - -// MemTableReaderExec executes memTable information retrieving from the MemTable components -type MemTableReaderExec struct { - exec.BaseExecutor - table *model.TableInfo - retriever memTableRetriever - // cacheRetrieved is used to indicate whether has the parent executor retrieved - // from inspection cache in inspection mode. - cacheRetrieved bool -} - -func (*MemTableReaderExec) isInspectionCacheableTable(tblName string) bool { - switch tblName { - case strings.ToLower(infoschema.TableClusterConfig), - strings.ToLower(infoschema.TableClusterInfo), - strings.ToLower(infoschema.TableClusterSystemInfo), - strings.ToLower(infoschema.TableClusterLoad), - strings.ToLower(infoschema.TableClusterHardware): - return true - default: - return false - } -} - -// Next implements the Executor Next interface. -func (e *MemTableReaderExec) Next(ctx context.Context, req *chunk.Chunk) error { - var ( - rows [][]types.Datum - err error - ) - - // The `InspectionTableCache` will be assigned in the begin of retrieving` and be - // cleaned at the end of retrieving, so nil represents currently in non-inspection mode. - if cache, tbl := e.Ctx().GetSessionVars().InspectionTableCache, e.table.Name.L; cache != nil && - e.isInspectionCacheableTable(tbl) { - // TODO: cached rows will be returned fully, we should refactor this part. - if !e.cacheRetrieved { - // Obtain data from cache first. 
- cached, found := cache[tbl] - if !found { - rows, err := e.retriever.retrieve(ctx, e.Ctx()) - cached = variable.TableSnapshot{Rows: rows, Err: err} - cache[tbl] = cached - } - e.cacheRetrieved = true - rows, err = cached.Rows, cached.Err - } - } else { - rows, err = e.retriever.retrieve(ctx, e.Ctx()) - } - if err != nil { - return err - } - - if len(rows) == 0 { - req.Reset() - return nil - } - - req.GrowAndReset(len(rows)) - mutableRow := chunk.MutRowFromTypes(exec.RetTypes(e)) - for _, row := range rows { - mutableRow.SetDatums(row...) - req.AppendRow(mutableRow.ToRow()) - } - return nil -} - -// Close implements the Executor Close interface. -func (e *MemTableReaderExec) Close() error { - if stats := e.retriever.getRuntimeStats(); stats != nil && e.RuntimeStats() != nil { - defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), stats) - } - return e.retriever.close() -} - -type clusterConfigRetriever struct { - dummyCloser - retrieved bool - extractor *plannercore.ClusterTableExtractor -} - -// retrieve implements the memTableRetriever interface -func (e *clusterConfigRetriever) retrieve(_ context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { - if e.extractor.SkipRequest || e.retrieved { - return nil, nil - } - e.retrieved = true - return fetchClusterConfig(sctx, e.extractor.NodeTypes, e.extractor.Instances) -} - -func fetchClusterConfig(sctx sessionctx.Context, nodeTypes, nodeAddrs set.StringSet) ([][]types.Datum, error) { - type result struct { - idx int - rows [][]types.Datum - err error - } - if !hasPriv(sctx, mysql.ConfigPriv) { - return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("CONFIG") - } - serversInfo, err := infoschema.GetClusterServerInfo(sctx) - failpoint.Inject("mockClusterConfigServerInfo", func(val failpoint.Value) { - if s := val.(string); len(s) > 0 { - // erase the error - serversInfo, err = parseFailpointServerInfo(s), nil - } - }) - if err != nil { - return nil, err - } - serversInfo = infoschema.FilterClusterServerInfo(serversInfo, nodeTypes, nodeAddrs) - //nolint: prealloc - var finalRows [][]types.Datum - wg := sync.WaitGroup{} - ch := make(chan result, len(serversInfo)) - for i, srv := range serversInfo { - typ := srv.ServerType - address := srv.Address - statusAddr := srv.StatusAddr - if len(statusAddr) == 0 { - sctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("%s node %s does not contain status address", typ, address)) - continue - } - wg.Add(1) - go func(index int) { - util.WithRecovery(func() { - defer wg.Done() - var url string - switch typ { - case "pd": - url = fmt.Sprintf("%s://%s%s", util.InternalHTTPSchema(), statusAddr, pd.Config) - case "tikv", "tidb", "tiflash": - url = fmt.Sprintf("%s://%s/config", util.InternalHTTPSchema(), statusAddr) - case "tiproxy": - url = fmt.Sprintf("%s://%s/api/admin/config?format=json", util.InternalHTTPSchema(), statusAddr) - case "ticdc": - url = fmt.Sprintf("%s://%s/config", util.InternalHTTPSchema(), statusAddr) - case "tso": - url = fmt.Sprintf("%s://%s/tso/api/v1/config", util.InternalHTTPSchema(), statusAddr) - case "scheduling": - url = fmt.Sprintf("%s://%s/scheduling/api/v1/config", util.InternalHTTPSchema(), statusAddr) - default: - ch <- result{err: errors.Errorf("currently we do not support get config from node type: %s(%s)", typ, address)} - return - } - - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - ch <- result{err: errors.Trace(err)} - return - } - req.Header.Add("PD-Allow-follower-handle", "true") - 
resp, err := util.InternalHTTPClient().Do(req) - if err != nil { - ch <- result{err: errors.Trace(err)} - return - } - defer func() { - terror.Log(resp.Body.Close()) - }() - if resp.StatusCode != http.StatusOK { - ch <- result{err: errors.Errorf("request %s failed: %s", url, resp.Status)} - return - } - var nested map[string]any - if err = json.NewDecoder(resp.Body).Decode(&nested); err != nil { - ch <- result{err: errors.Trace(err)} - return - } - data := config.FlattenConfigItems(nested) - type item struct { - key string - val string - } - var items []item - for key, val := range data { - if config.ContainHiddenConfig(key) { - continue - } - var str string - switch val := val.(type) { - case string: // remove quotes - str = val - default: - tmp, err := json.Marshal(val) - if err != nil { - ch <- result{err: errors.Trace(err)} - return - } - str = string(tmp) - } - items = append(items, item{key: key, val: str}) - } - slices.SortFunc(items, func(i, j item) int { return cmp.Compare(i.key, j.key) }) - var rows [][]types.Datum - for _, item := range items { - rows = append(rows, types.MakeDatums( - typ, - address, - item.key, - item.val, - )) - } - ch <- result{idx: index, rows: rows} - }, nil) - }(i) - } - - wg.Wait() - close(ch) - - // Keep the original order to make the result more stable - var results []result //nolint: prealloc - for result := range ch { - if result.err != nil { - sctx.GetSessionVars().StmtCtx.AppendWarning(result.err) - continue - } - results = append(results, result) - } - slices.SortFunc(results, func(i, j result) int { return cmp.Compare(i.idx, j.idx) }) - for _, result := range results { - finalRows = append(finalRows, result.rows...) - } - return finalRows, nil -} - -type clusterServerInfoRetriever struct { - dummyCloser - extractor *plannercore.ClusterTableExtractor - serverInfoType diagnosticspb.ServerInfoType - retrieved bool -} - -// retrieve implements the memTableRetriever interface -func (e *clusterServerInfoRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { - switch e.serverInfoType { - case diagnosticspb.ServerInfoType_LoadInfo, - diagnosticspb.ServerInfoType_SystemInfo: - if !hasPriv(sctx, mysql.ProcessPriv) { - return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("PROCESS") - } - case diagnosticspb.ServerInfoType_HardwareInfo: - if !hasPriv(sctx, mysql.ConfigPriv) { - return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("CONFIG") - } - } - if e.extractor.SkipRequest || e.retrieved { - return nil, nil - } - e.retrieved = true - serversInfo, err := infoschema.GetClusterServerInfo(sctx) - if err != nil { - return nil, err - } - serversInfo = infoschema.FilterClusterServerInfo(serversInfo, e.extractor.NodeTypes, e.extractor.Instances) - return infoschema.FetchClusterServerInfoWithoutPrivilegeCheck(ctx, sctx.GetSessionVars(), serversInfo, e.serverInfoType, true) -} - -func parseFailpointServerInfo(s string) []infoschema.ServerInfo { - servers := strings.Split(s, ";") - serversInfo := make([]infoschema.ServerInfo, 0, len(servers)) - for _, server := range servers { - parts := strings.Split(server, ",") - serversInfo = append(serversInfo, infoschema.ServerInfo{ - StatusAddr: parts[2], - Address: parts[1], - ServerType: parts[0], - }) - } - return serversInfo -} - -type clusterLogRetriever struct { - isDrained bool - retrieving bool - heap *logResponseHeap - extractor *plannercore.ClusterLogTableExtractor - cancel context.CancelFunc -} - -type logStreamResult struct { - // Read the next 
stream result while current messages is drained - next chan logStreamResult - - addr string - typ string - messages []*diagnosticspb.LogMessage - err error -} - -type logResponseHeap []logStreamResult - -func (h logResponseHeap) Len() int { - return len(h) -} - -func (h logResponseHeap) Less(i, j int) bool { - if lhs, rhs := h[i].messages[0].Time, h[j].messages[0].Time; lhs != rhs { - return lhs < rhs - } - return h[i].typ < h[j].typ -} - -func (h logResponseHeap) Swap(i, j int) { - h[i], h[j] = h[j], h[i] -} - -func (h *logResponseHeap) Push(x any) { - *h = append(*h, x.(logStreamResult)) -} - -func (h *logResponseHeap) Pop() any { - old := *h - n := len(old) - x := old[n-1] - *h = old[0 : n-1] - return x -} - -func (e *clusterLogRetriever) initialize(ctx context.Context, sctx sessionctx.Context) ([]chan logStreamResult, error) { - if !hasPriv(sctx, mysql.ProcessPriv) { - return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("PROCESS") - } - serversInfo, err := infoschema.GetClusterServerInfo(sctx) - failpoint.Inject("mockClusterLogServerInfo", func(val failpoint.Value) { - // erase the error - err = nil - if s := val.(string); len(s) > 0 { - serversInfo = parseFailpointServerInfo(s) - } - }) - if err != nil { - return nil, err - } - - instances := e.extractor.Instances - nodeTypes := e.extractor.NodeTypes - serversInfo = infoschema.FilterClusterServerInfo(serversInfo, nodeTypes, instances) - - var levels = make([]diagnosticspb.LogLevel, 0, len(e.extractor.LogLevels)) - for l := range e.extractor.LogLevels { - levels = append(levels, sysutil.ParseLogLevel(l)) - } - - // To avoid search log interface overload, the user should specify the time range, and at least one pattern - // in normally SQL. - if e.extractor.StartTime == 0 { - return nil, errors.New("denied to scan logs, please specified the start time, such as `time > '2020-01-01 00:00:00'`") - } - if e.extractor.EndTime == 0 { - return nil, errors.New("denied to scan logs, please specified the end time, such as `time < '2020-01-01 00:00:00'`") - } - patterns := e.extractor.Patterns - if len(patterns) == 0 && len(levels) == 0 && len(instances) == 0 && len(nodeTypes) == 0 { - return nil, errors.New("denied to scan full logs (use `SELECT * FROM cluster_log WHERE message LIKE '%'` explicitly if intentionally)") - } - - req := &diagnosticspb.SearchLogRequest{ - StartTime: e.extractor.StartTime, - EndTime: e.extractor.EndTime, - Levels: levels, - Patterns: patterns, - } - - return e.startRetrieving(ctx, sctx, serversInfo, req) -} - -func (e *clusterLogRetriever) startRetrieving( - ctx context.Context, - sctx sessionctx.Context, - serversInfo []infoschema.ServerInfo, - req *diagnosticspb.SearchLogRequest) ([]chan logStreamResult, error) { - // gRPC options - opt := grpc.WithTransportCredentials(insecure.NewCredentials()) - security := config.GetGlobalConfig().Security - if len(security.ClusterSSLCA) != 0 { - clusterSecurity := security.ClusterSecurity() - tlsConfig, err := clusterSecurity.ToTLSConfig() - if err != nil { - return nil, errors.Trace(err) - } - opt = grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)) - } - - // The retrieve progress may be abort - ctx, e.cancel = context.WithCancel(ctx) - - var results []chan logStreamResult //nolint: prealloc - for _, srv := range serversInfo { - typ := srv.ServerType - address := srv.Address - statusAddr := srv.StatusAddr - if len(statusAddr) == 0 { - sctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("%s node %s does not contain status address", typ, 
address)) - continue - } - ch := make(chan logStreamResult) - results = append(results, ch) - - go func(ch chan logStreamResult, serverType, address, statusAddr string) { - util.WithRecovery(func() { - defer close(ch) - - // TiDB and TiProxy provide diagnostics service via status address - remote := address - if serverType == "tidb" || serverType == "tiproxy" { - remote = statusAddr - } - conn, err := grpc.Dial(remote, opt) - if err != nil { - ch <- logStreamResult{addr: address, typ: serverType, err: err} - return - } - defer terror.Call(conn.Close) - - cli := diagnosticspb.NewDiagnosticsClient(conn) - stream, err := cli.SearchLog(ctx, req) - if err != nil { - ch <- logStreamResult{addr: address, typ: serverType, err: err} - return - } - - for { - res, err := stream.Recv() - if err != nil && err == io.EOF { - return - } - if err != nil { - select { - case ch <- logStreamResult{addr: address, typ: serverType, err: err}: - case <-ctx.Done(): - } - return - } - - result := logStreamResult{next: ch, addr: address, typ: serverType, messages: res.Messages} - select { - case ch <- result: - case <-ctx.Done(): - return - } - } - }, nil) - }(ch, typ, address, statusAddr) - } - - return results, nil -} - -func (e *clusterLogRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { - if e.extractor.SkipRequest || e.isDrained { - return nil, nil - } - - if !e.retrieving { - e.retrieving = true - results, err := e.initialize(ctx, sctx) - if err != nil { - e.isDrained = true - return nil, err - } - - // initialize the heap - e.heap = &logResponseHeap{} - for _, ch := range results { - result := <-ch - if result.err != nil || len(result.messages) == 0 { - if result.err != nil { - sctx.GetSessionVars().StmtCtx.AppendWarning(result.err) - } - continue - } - *e.heap = append(*e.heap, result) - } - heap.Init(e.heap) - } - - // Merge the results - var finalRows [][]types.Datum - for e.heap.Len() > 0 && len(finalRows) < clusterLogBatchSize { - minTimeItem := heap.Pop(e.heap).(logStreamResult) - headMessage := minTimeItem.messages[0] - loggingTime := time.UnixMilli(headMessage.Time) - finalRows = append(finalRows, types.MakeDatums( - loggingTime.Format("2006/01/02 15:04:05.000"), - minTimeItem.typ, - minTimeItem.addr, - strings.ToUpper(headMessage.Level.String()), - headMessage.Message, - )) - minTimeItem.messages = minTimeItem.messages[1:] - // Current streaming result is drained, read the next to supply. 
- if len(minTimeItem.messages) == 0 { - result := <-minTimeItem.next - if result.err != nil { - sctx.GetSessionVars().StmtCtx.AppendWarning(result.err) - continue - } - if len(result.messages) > 0 { - heap.Push(e.heap, result) - } - } else { - heap.Push(e.heap, minTimeItem) - } - } - - // All streams are drained - e.isDrained = e.heap.Len() == 0 - - return finalRows, nil -} - -func (e *clusterLogRetriever) close() error { - if e.cancel != nil { - e.cancel() - } - return nil -} - -func (*clusterLogRetriever) getRuntimeStats() execdetails.RuntimeStats { - return nil -} - -type hotRegionsResult struct { - addr string - messages *HistoryHotRegions - err error -} - -type hotRegionsResponseHeap []hotRegionsResult - -func (h hotRegionsResponseHeap) Len() int { - return len(h) -} - -func (h hotRegionsResponseHeap) Less(i, j int) bool { - lhs, rhs := h[i].messages.HistoryHotRegion[0], h[j].messages.HistoryHotRegion[0] - if lhs.UpdateTime != rhs.UpdateTime { - return lhs.UpdateTime < rhs.UpdateTime - } - return lhs.HotDegree < rhs.HotDegree -} - -func (h hotRegionsResponseHeap) Swap(i, j int) { - h[i], h[j] = h[j], h[i] -} - -func (h *hotRegionsResponseHeap) Push(x any) { - *h = append(*h, x.(hotRegionsResult)) -} - -func (h *hotRegionsResponseHeap) Pop() any { - old := *h - n := len(old) - x := old[n-1] - *h = old[0 : n-1] - return x -} - -type hotRegionsHistoryRetriver struct { - dummyCloser - isDrained bool - retrieving bool - heap *hotRegionsResponseHeap - extractor *plannercore.HotRegionsHistoryTableExtractor -} - -// HistoryHotRegionsRequest wrap conditions push down to PD. -type HistoryHotRegionsRequest struct { - StartTime int64 `json:"start_time,omitempty"` - EndTime int64 `json:"end_time,omitempty"` - RegionIDs []uint64 `json:"region_ids,omitempty"` - StoreIDs []uint64 `json:"store_ids,omitempty"` - PeerIDs []uint64 `json:"peer_ids,omitempty"` - IsLearners []bool `json:"is_learners,omitempty"` - IsLeaders []bool `json:"is_leaders,omitempty"` - HotRegionTypes []string `json:"hot_region_type,omitempty"` -} - -// HistoryHotRegions records filtered hot regions stored in each PD. -// it's the response of PD. -type HistoryHotRegions struct { - HistoryHotRegion []*HistoryHotRegion `json:"history_hot_region"` -} - -// HistoryHotRegion records each hot region's statistics. -// it's the response of PD. -type HistoryHotRegion struct { - UpdateTime int64 `json:"update_time"` - RegionID uint64 `json:"region_id"` - StoreID uint64 `json:"store_id"` - PeerID uint64 `json:"peer_id"` - IsLearner bool `json:"is_learner"` - IsLeader bool `json:"is_leader"` - HotRegionType string `json:"hot_region_type"` - HotDegree int64 `json:"hot_degree"` - FlowBytes float64 `json:"flow_bytes"` - KeyRate float64 `json:"key_rate"` - QueryRate float64 `json:"query_rate"` - StartKey string `json:"start_key"` - EndKey string `json:"end_key"` -} - -func (e *hotRegionsHistoryRetriver) initialize(_ context.Context, sctx sessionctx.Context) ([]chan hotRegionsResult, error) { - if !hasPriv(sctx, mysql.ProcessPriv) { - return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("PROCESS") - } - pdServers, err := infoschema.GetPDServerInfo(sctx) - if err != nil { - return nil, err - } - - // To avoid search hot regions interface overload, the user should specify the time range in normally SQL. 
- if e.extractor.StartTime == 0 { - return nil, errors.New("denied to scan hot regions, please specified the start time, such as `update_time > '2020-01-01 00:00:00'`") - } - if e.extractor.EndTime == 0 { - return nil, errors.New("denied to scan hot regions, please specified the end time, such as `update_time < '2020-01-01 00:00:00'`") - } - - historyHotRegionsRequest := &HistoryHotRegionsRequest{ - StartTime: e.extractor.StartTime, - EndTime: e.extractor.EndTime, - RegionIDs: e.extractor.RegionIDs, - StoreIDs: e.extractor.StoreIDs, - PeerIDs: e.extractor.PeerIDs, - IsLearners: e.extractor.IsLearners, - IsLeaders: e.extractor.IsLeaders, - } - - return e.startRetrieving(pdServers, historyHotRegionsRequest) -} - -func (e *hotRegionsHistoryRetriver) startRetrieving( - pdServers []infoschema.ServerInfo, - req *HistoryHotRegionsRequest, -) ([]chan hotRegionsResult, error) { - var results []chan hotRegionsResult - for _, srv := range pdServers { - for typ := range e.extractor.HotRegionTypes { - req.HotRegionTypes = []string{typ} - jsonBody, err := json.Marshal(req) - if err != nil { - return nil, err - } - body := bytes.NewBuffer(jsonBody) - ch := make(chan hotRegionsResult) - results = append(results, ch) - go func(ch chan hotRegionsResult, address string, body *bytes.Buffer) { - util.WithRecovery(func() { - defer close(ch) - url := fmt.Sprintf("%s://%s%s", util.InternalHTTPSchema(), address, pd.HotHistory) - req, err := http.NewRequest(http.MethodGet, url, body) - if err != nil { - ch <- hotRegionsResult{err: errors.Trace(err)} - return - } - req.Header.Add("PD-Allow-follower-handle", "true") - resp, err := util.InternalHTTPClient().Do(req) - if err != nil { - ch <- hotRegionsResult{err: errors.Trace(err)} - return - } - defer func() { - terror.Log(resp.Body.Close()) - }() - if resp.StatusCode != http.StatusOK { - ch <- hotRegionsResult{err: errors.Errorf("request %s failed: %s", url, resp.Status)} - return - } - var historyHotRegions HistoryHotRegions - if err = json.NewDecoder(resp.Body).Decode(&historyHotRegions); err != nil { - ch <- hotRegionsResult{err: errors.Trace(err)} - return - } - ch <- hotRegionsResult{addr: address, messages: &historyHotRegions} - }, nil) - }(ch, srv.StatusAddr, body) - } - } - return results, nil -} - -func (e *hotRegionsHistoryRetriver) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { - if e.extractor.SkipRequest || e.isDrained { - return nil, nil - } - - if !e.retrieving { - e.retrieving = true - results, err := e.initialize(ctx, sctx) - if err != nil { - e.isDrained = true - return nil, err - } - // Initialize the heap - e.heap = &hotRegionsResponseHeap{} - for _, ch := range results { - result := <-ch - if result.err != nil || len(result.messages.HistoryHotRegion) == 0 { - if result.err != nil { - sctx.GetSessionVars().StmtCtx.AppendWarning(result.err) - } - continue - } - *e.heap = append(*e.heap, result) - } - heap.Init(e.heap) - } - // Merge the results - var finalRows [][]types.Datum - tikvStore, ok := sctx.GetStore().(helper.Storage) - if !ok { - return nil, errors.New("Information about hot region can be gotten only when the storage is TiKV") - } - tikvHelper := &helper.Helper{ - Store: tikvStore, - RegionCache: tikvStore.GetRegionCache(), - } - tz := sctx.GetSessionVars().Location() - is := sessiontxn.GetTxnManager(sctx).GetTxnInfoSchema() - tables := tikvHelper.GetTablesInfoWithKeyRange(is, tikvHelper.FilterMemDBs) - for e.heap.Len() > 0 && len(finalRows) < hotRegionsHistoryBatchSize { - minTimeItem := 
heap.Pop(e.heap).(hotRegionsResult) - rows, err := e.getHotRegionRowWithSchemaInfo(minTimeItem.messages.HistoryHotRegion[0], tikvHelper, tables, tz) - if err != nil { - return nil, err - } - if rows != nil { - finalRows = append(finalRows, rows...) - } - minTimeItem.messages.HistoryHotRegion = minTimeItem.messages.HistoryHotRegion[1:] - // Fetch next message item - if len(minTimeItem.messages.HistoryHotRegion) != 0 { - heap.Push(e.heap, minTimeItem) - } - } - // All streams are drained - e.isDrained = e.heap.Len() == 0 - return finalRows, nil -} - -func (*hotRegionsHistoryRetriver) getHotRegionRowWithSchemaInfo( - hisHotRegion *HistoryHotRegion, - tikvHelper *helper.Helper, - tables []helper.TableInfoWithKeyRange, - tz *time.Location, -) ([][]types.Datum, error) { - regionsInfo := []*pd.RegionInfo{ - { - ID: int64(hisHotRegion.RegionID), - StartKey: hisHotRegion.StartKey, - EndKey: hisHotRegion.EndKey, - }} - regionsTableInfos := tikvHelper.ParseRegionsTableInfos(regionsInfo, tables) - - var rows [][]types.Datum - // Ignore row without corresponding schema. - if tableInfos, ok := regionsTableInfos[int64(hisHotRegion.RegionID)]; ok { - for _, tableInfo := range tableInfos { - updateTimestamp := time.UnixMilli(hisHotRegion.UpdateTime) - if updateTimestamp.Location() != tz { - updateTimestamp.In(tz) - } - updateTime := types.NewTime(types.FromGoTime(updateTimestamp), mysql.TypeTimestamp, types.MinFsp) - row := make([]types.Datum, len(infoschema.GetTableTiDBHotRegionsHistoryCols())) - row[0].SetMysqlTime(updateTime) - row[1].SetString(strings.ToUpper(tableInfo.DB.Name.O), mysql.DefaultCollationName) - row[2].SetString(strings.ToUpper(tableInfo.Table.Name.O), mysql.DefaultCollationName) - row[3].SetInt64(tableInfo.Table.ID) - if tableInfo.IsIndex { - row[4].SetString(strings.ToUpper(tableInfo.Index.Name.O), mysql.DefaultCollationName) - row[5].SetInt64(tableInfo.Index.ID) - } else { - row[4].SetNull() - row[5].SetNull() - } - row[6].SetInt64(int64(hisHotRegion.RegionID)) - row[7].SetInt64(int64(hisHotRegion.StoreID)) - row[8].SetInt64(int64(hisHotRegion.PeerID)) - if hisHotRegion.IsLearner { - row[9].SetInt64(1) - } else { - row[9].SetInt64(0) - } - if hisHotRegion.IsLeader { - row[10].SetInt64(1) - } else { - row[10].SetInt64(0) - } - row[11].SetString(strings.ToUpper(hisHotRegion.HotRegionType), mysql.DefaultCollationName) - row[12].SetInt64(hisHotRegion.HotDegree) - row[13].SetFloat64(hisHotRegion.FlowBytes) - row[14].SetFloat64(hisHotRegion.KeyRate) - row[15].SetFloat64(hisHotRegion.QueryRate) - rows = append(rows, row) - } - } - - return rows, nil -} - -type tikvRegionPeersRetriever struct { - dummyCloser - extractor *plannercore.TikvRegionPeersExtractor - retrieved bool -} - -func (e *tikvRegionPeersRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { - if e.extractor.SkipRequest || e.retrieved { - return nil, nil - } - e.retrieved = true - tikvStore, ok := sctx.GetStore().(helper.Storage) - if !ok { - return nil, errors.New("Information about hot region can be gotten only when the storage is TiKV") - } - tikvHelper := &helper.Helper{ - Store: tikvStore, - RegionCache: tikvStore.GetRegionCache(), - } - pdCli, err := tikvHelper.TryGetPDHTTPClient() - if err != nil { - return nil, err - } - - var regionsInfo, regionsInfoByStoreID []pd.RegionInfo - regionMap := make(map[int64]*pd.RegionInfo) - storeMap := make(map[int64]struct{}) - - if len(e.extractor.StoreIDs) == 0 && len(e.extractor.RegionIDs) == 0 { - regionsInfo, err := pdCli.GetRegions(ctx) - if 
err != nil { - return nil, err - } - return e.packTiKVRegionPeersRows(regionsInfo.Regions, storeMap) - } - - for _, storeID := range e.extractor.StoreIDs { - // if a region_id located in 1, 4, 7 store we will get all of them when request any store_id, - // storeMap is used to filter peers on unexpected stores. - storeMap[int64(storeID)] = struct{}{} - storeRegionsInfo, err := pdCli.GetRegionsByStoreID(ctx, storeID) - if err != nil { - return nil, err - } - for i, regionInfo := range storeRegionsInfo.Regions { - // regionMap is used to remove dup regions and record the region in regionsInfoByStoreID. - if _, ok := regionMap[regionInfo.ID]; !ok { - regionsInfoByStoreID = append(regionsInfoByStoreID, regionInfo) - regionMap[regionInfo.ID] = &storeRegionsInfo.Regions[i] - } - } - } - - if len(e.extractor.RegionIDs) == 0 { - return e.packTiKVRegionPeersRows(regionsInfoByStoreID, storeMap) - } - - for _, regionID := range e.extractor.RegionIDs { - regionInfoByStoreID, ok := regionMap[int64(regionID)] - if !ok { - // if there is storeIDs, target region_id is fetched by storeIDs, - // otherwise we need to fetch it from PD. - if len(e.extractor.StoreIDs) == 0 { - regionInfo, err := pdCli.GetRegionByID(ctx, regionID) - if err != nil { - return nil, err - } - regionsInfo = append(regionsInfo, *regionInfo) - } - } else { - regionsInfo = append(regionsInfo, *regionInfoByStoreID) - } - } - - return e.packTiKVRegionPeersRows(regionsInfo, storeMap) -} - -func (e *tikvRegionPeersRetriever) isUnexpectedStoreID(storeID int64, storeMap map[int64]struct{}) bool { - if len(e.extractor.StoreIDs) == 0 { - return false - } - if _, ok := storeMap[storeID]; ok { - return false - } - return true -} - -func (e *tikvRegionPeersRetriever) packTiKVRegionPeersRows( - regionsInfo []pd.RegionInfo, storeMap map[int64]struct{}) ([][]types.Datum, error) { - //nolint: prealloc - var rows [][]types.Datum - for _, region := range regionsInfo { - records := make([][]types.Datum, 0, len(region.Peers)) - pendingPeerIDSet := set.NewInt64Set() - for _, peer := range region.PendingPeers { - pendingPeerIDSet.Insert(peer.ID) - } - downPeerMap := make(map[int64]int64, len(region.DownPeers)) - for _, peerStat := range region.DownPeers { - downPeerMap[peerStat.Peer.ID] = peerStat.DownSec - } - for _, peer := range region.Peers { - // isUnexpectedStoreID return true if we should filter this peer. - if e.isUnexpectedStoreID(peer.StoreID, storeMap) { - continue - } - - row := make([]types.Datum, len(infoschema.GetTableTiKVRegionPeersCols())) - row[0].SetInt64(region.ID) - row[1].SetInt64(peer.ID) - row[2].SetInt64(peer.StoreID) - if peer.IsLearner { - row[3].SetInt64(1) - } else { - row[3].SetInt64(0) - } - if peer.ID == region.Leader.ID { - row[4].SetInt64(1) - } else { - row[4].SetInt64(0) - } - if downSec, ok := downPeerMap[peer.ID]; ok { - row[5].SetString(downPeer, mysql.DefaultCollationName) - row[6].SetInt64(downSec) - } else if pendingPeerIDSet.Exist(peer.ID) { - row[5].SetString(pendingPeer, mysql.DefaultCollationName) - } else { - row[5].SetString(normalPeer, mysql.DefaultCollationName) - } - records = append(records, row) - } - rows = append(rows, records...) 
- } - return rows, nil -} diff --git a/pkg/executor/metrics_reader.go b/pkg/executor/metrics_reader.go index 3bfaa6907d568..31a0073148584 100644 --- a/pkg/executor/metrics_reader.go +++ b/pkg/executor/metrics_reader.go @@ -57,12 +57,12 @@ func (e *MetricRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) } e.retrieved = true - if _, _err_ := failpoint.EvalContext(ctx, _curpkg_("mockMetricsTableData")); _err_ == nil { + failpoint.InjectContext(ctx, "mockMetricsTableData", func() { m, ok := ctx.Value("__mockMetricsTableData").(map[string][][]types.Datum) if ok && m[e.table.Name.L] != nil { - return m[e.table.Name.L], nil + failpoint.Return(m[e.table.Name.L], nil) } - } + }) tblDef, err := infoschema.GetMetricTableDef(e.table.Name.L) if err != nil { @@ -94,9 +94,9 @@ func (e *MetricRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) type MockMetricsPromDataKey struct{} func (e *MetricRetriever) queryMetric(ctx context.Context, sctx sessionctx.Context, queryRange promv1.Range, quantile float64) (result pmodel.Value, err error) { - if _, _err_ := failpoint.EvalContext(ctx, _curpkg_("mockMetricsPromData")); _err_ == nil { - return ctx.Value(MockMetricsPromDataKey{}).(pmodel.Matrix), nil - } + failpoint.InjectContext(ctx, "mockMetricsPromData", func() { + failpoint.Return(ctx.Value(MockMetricsPromDataKey{}).(pmodel.Matrix), nil) + }) // Add retry to avoid network error. var prometheusAddr string diff --git a/pkg/executor/metrics_reader.go__failpoint_stash__ b/pkg/executor/metrics_reader.go__failpoint_stash__ deleted file mode 100644 index 31a0073148584..0000000000000 --- a/pkg/executor/metrics_reader.go__failpoint_stash__ +++ /dev/null @@ -1,365 +0,0 @@ -// Copyright 2019 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package executor - -import ( - "context" - "fmt" - "math" - "slices" - "strings" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/domain/infosync" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - plannerutil "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" - "github.com/prometheus/client_golang/api" - promv1 "github.com/prometheus/client_golang/api/prometheus/v1" - pmodel "github.com/prometheus/common/model" -) - -const promReadTimeout = time.Second * 10 - -// MetricRetriever uses to read metric data. 
-type MetricRetriever struct { - dummyCloser - table *model.TableInfo - tblDef *infoschema.MetricTableDef - extractor *plannercore.MetricTableExtractor - retrieved bool -} - -func (e *MetricRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { - if e.retrieved || e.extractor.SkipRequest { - return nil, nil - } - e.retrieved = true - - failpoint.InjectContext(ctx, "mockMetricsTableData", func() { - m, ok := ctx.Value("__mockMetricsTableData").(map[string][][]types.Datum) - if ok && m[e.table.Name.L] != nil { - failpoint.Return(m[e.table.Name.L], nil) - } - }) - - tblDef, err := infoschema.GetMetricTableDef(e.table.Name.L) - if err != nil { - return nil, err - } - e.tblDef = tblDef - queryRange := e.getQueryRange(sctx) - totalRows := make([][]types.Datum, 0) - quantiles := e.extractor.Quantiles - if len(quantiles) == 0 { - quantiles = []float64{tblDef.Quantile} - } - for _, quantile := range quantiles { - var queryValue pmodel.Value - queryValue, err = e.queryMetric(ctx, sctx, queryRange, quantile) - if err != nil { - if err1, ok := err.(*promv1.Error); ok { - return nil, errors.Errorf("query metric error, msg: %v, detail: %v", err1.Msg, err1.Detail) - } - return nil, errors.Errorf("query metric error: %v", err.Error()) - } - partRows := e.genRows(queryValue, quantile) - totalRows = append(totalRows, partRows...) - } - return totalRows, nil -} - -// MockMetricsPromDataKey is for test -type MockMetricsPromDataKey struct{} - -func (e *MetricRetriever) queryMetric(ctx context.Context, sctx sessionctx.Context, queryRange promv1.Range, quantile float64) (result pmodel.Value, err error) { - failpoint.InjectContext(ctx, "mockMetricsPromData", func() { - failpoint.Return(ctx.Value(MockMetricsPromDataKey{}).(pmodel.Matrix), nil) - }) - - // Add retry to avoid network error. - var prometheusAddr string - for i := 0; i < 5; i++ { - //TODO: the prometheus will be Integrated into the PD, then we need to query the prometheus in PD directly, which need change the quire API - prometheusAddr, err = infosync.GetPrometheusAddr() - if err == nil || err == infosync.ErrPrometheusAddrIsNotSet { - break - } - time.Sleep(100 * time.Millisecond) - } - if err != nil { - return nil, err - } - promClient, err := api.NewClient(api.Config{ - Address: prometheusAddr, - }) - if err != nil { - return nil, err - } - promQLAPI := promv1.NewAPI(promClient) - ctx, cancel := context.WithTimeout(ctx, promReadTimeout) - defer cancel() - promQL := e.tblDef.GenPromQL(sctx.GetSessionVars().MetricSchemaRangeDuration, e.extractor.LabelConditions, quantile) - - // Add retry to avoid network error. 
- for i := 0; i < 5; i++ { - result, _, err = promQLAPI.QueryRange(ctx, promQL, queryRange) - if err == nil { - break - } - time.Sleep(100 * time.Millisecond) - } - return result, err -} - -type promQLQueryRange = promv1.Range - -func (e *MetricRetriever) getQueryRange(sctx sessionctx.Context) promQLQueryRange { - startTime, endTime := e.extractor.StartTime, e.extractor.EndTime - step := time.Second * time.Duration(sctx.GetSessionVars().MetricSchemaStep) - return promQLQueryRange{Start: startTime, End: endTime, Step: step} -} - -func (e *MetricRetriever) genRows(value pmodel.Value, quantile float64) [][]types.Datum { - var rows [][]types.Datum - if value.Type() == pmodel.ValMatrix { - matrix := value.(pmodel.Matrix) - for _, m := range matrix { - for _, v := range m.Values { - record := e.genRecord(m.Metric, v, quantile) - rows = append(rows, record) - } - } - } - return rows -} - -func (e *MetricRetriever) genRecord(metric pmodel.Metric, pair pmodel.SamplePair, quantile float64) []types.Datum { - record := make([]types.Datum, 0, 2+len(e.tblDef.Labels)+1) - // Record order should keep same with genColumnInfos. - record = append(record, types.NewTimeDatum(types.NewTime( - types.FromGoTime(time.UnixMilli(int64(pair.Timestamp))), - mysql.TypeDatetime, - types.MaxFsp, - ))) - for _, label := range e.tblDef.Labels { - v := "" - if metric != nil { - v = string(metric[pmodel.LabelName(label)]) - } - if len(v) == 0 { - v = infoschema.GenLabelConditionValues(e.extractor.LabelConditions[strings.ToLower(label)]) - } - record = append(record, types.NewStringDatum(v)) - } - if e.tblDef.Quantile > 0 { - record = append(record, types.NewFloat64Datum(quantile)) - } - if math.IsNaN(float64(pair.Value)) { - record = append(record, types.NewDatum(nil)) - } else { - record = append(record, types.NewFloat64Datum(float64(pair.Value))) - } - return record -} - -// MetricsSummaryRetriever uses to read metric data. 
-type MetricsSummaryRetriever struct { - dummyCloser - table *model.TableInfo - extractor *plannercore.MetricSummaryTableExtractor - timeRange plannerutil.QueryTimeRange - retrieved bool -} - -func (e *MetricsSummaryRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { - if !hasPriv(sctx, mysql.ProcessPriv) { - return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("PROCESS") - } - if e.retrieved || e.extractor.SkipRequest { - return nil, nil - } - e.retrieved = true - totalRows := make([][]types.Datum, 0, len(infoschema.MetricTableMap)) - tables := make([]string, 0, len(infoschema.MetricTableMap)) - for name := range infoschema.MetricTableMap { - tables = append(tables, name) - } - slices.Sort(tables) - - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnOthers) - filter := inspectionFilter{set: e.extractor.MetricsNames} - condition := e.timeRange.Condition() - for _, name := range tables { - if !filter.enable(name) { - continue - } - def, found := infoschema.MetricTableMap[name] - if !found { - sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("metrics table: %s not found", name)) - continue - } - var sql string - if def.Quantile > 0 { - var qs []string - if len(e.extractor.Quantiles) > 0 { - for _, q := range e.extractor.Quantiles { - qs = append(qs, fmt.Sprintf("%f", q)) - } - } else { - qs = []string{"0.99"} - } - sql = fmt.Sprintf("select sum(value),avg(value),min(value),max(value),quantile from `%[2]s`.`%[1]s` %[3]s and quantile in (%[4]s) group by quantile order by quantile", - name, util.MetricSchemaName.L, condition, strings.Join(qs, ",")) - } else { - sql = fmt.Sprintf("select sum(value),avg(value),min(value),max(value) from `%[2]s`.`%[1]s` %[3]s", - name, util.MetricSchemaName.L, condition) - } - - exec := sctx.GetRestrictedSQLExecutor() - rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql) - if err != nil { - return nil, errors.Errorf("execute '%s' failed: %v", sql, err) - } - for _, row := range rows { - var quantile any - if def.Quantile > 0 { - quantile = row.GetFloat64(row.Len() - 1) - } - totalRows = append(totalRows, types.MakeDatums( - name, - quantile, - row.GetFloat64(0), - row.GetFloat64(1), - row.GetFloat64(2), - row.GetFloat64(3), - def.Comment, - )) - } - } - return totalRows, nil -} - -// MetricsSummaryByLabelRetriever uses to read metric detail data. 
-type MetricsSummaryByLabelRetriever struct { - dummyCloser - table *model.TableInfo - extractor *plannercore.MetricSummaryTableExtractor - timeRange plannerutil.QueryTimeRange - retrieved bool -} - -func (e *MetricsSummaryByLabelRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { - if !hasPriv(sctx, mysql.ProcessPriv) { - return nil, plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("PROCESS") - } - if e.retrieved || e.extractor.SkipRequest { - return nil, nil - } - e.retrieved = true - totalRows := make([][]types.Datum, 0, len(infoschema.MetricTableMap)) - tables := make([]string, 0, len(infoschema.MetricTableMap)) - for name := range infoschema.MetricTableMap { - tables = append(tables, name) - } - slices.Sort(tables) - - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnOthers) - filter := inspectionFilter{set: e.extractor.MetricsNames} - condition := e.timeRange.Condition() - for _, name := range tables { - if !filter.enable(name) { - continue - } - def, found := infoschema.MetricTableMap[name] - if !found { - sctx.GetSessionVars().StmtCtx.AppendWarning(fmt.Errorf("metrics table: %s not found", name)) - continue - } - cols := def.Labels - cond := condition - if def.Quantile > 0 { - cols = append(cols, "quantile") - if len(e.extractor.Quantiles) > 0 { - qs := make([]string, len(e.extractor.Quantiles)) - for i, q := range e.extractor.Quantiles { - qs[i] = fmt.Sprintf("%f", q) - } - cond += " and quantile in (" + strings.Join(qs, ",") + ")" - } else { - cond += " and quantile=0.99" - } - } - var sql string - if len(cols) > 0 { - sql = fmt.Sprintf("select sum(value),avg(value),min(value),max(value),`%s` from `%s`.`%s` %s group by `%[1]s` order by `%[1]s`", - strings.Join(cols, "`,`"), util.MetricSchemaName.L, name, cond) - } else { - sql = fmt.Sprintf("select sum(value),avg(value),min(value),max(value) from `%s`.`%s` %s", - util.MetricSchemaName.L, name, cond) - } - exec := sctx.GetRestrictedSQLExecutor() - rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql) - if err != nil { - return nil, errors.Errorf("execute '%s' failed: %v", sql, err) - } - nonInstanceLabelIndex := 0 - if len(def.Labels) > 0 && def.Labels[0] == "instance" { - nonInstanceLabelIndex = 1 - } - // skip sum/avg/min/max - const skipCols = 4 - for _, row := range rows { - instance := "" - if nonInstanceLabelIndex > 0 { - instance = row.GetString(skipCols) // sum/avg/min/max - } - var labels []string - for i, label := range def.Labels[nonInstanceLabelIndex:] { - // skip min/max/avg/instance - val := row.GetString(skipCols + nonInstanceLabelIndex + i) - if label == "store" || label == "store_id" { - val = fmt.Sprintf("store_id:%s", val) - } - labels = append(labels, val) - } - var quantile any - if def.Quantile > 0 { - quantile = row.GetFloat64(row.Len() - 1) // quantile will be the last column - } - totalRows = append(totalRows, types.MakeDatums( - instance, - name, - strings.Join(labels, ", "), - quantile, - row.GetFloat64(0), // sum - row.GetFloat64(1), // avg - row.GetFloat64(2), // min - row.GetFloat64(3), // max - def.Comment, - )) - } - } - return totalRows, nil -} diff --git a/pkg/executor/parallel_apply.go b/pkg/executor/parallel_apply.go index f855b991e72a5..03820652df933 100644 --- a/pkg/executor/parallel_apply.go +++ b/pkg/executor/parallel_apply.go @@ -208,7 +208,7 @@ func (e *ParallelNestedLoopApplyExec) outerWorker(ctx context.Context) { var selected []bool var err error for { - failpoint.Eval(_curpkg_("parallelApplyOuterWorkerPanic")) + 
failpoint.Inject("parallelApplyOuterWorkerPanic", nil) chk := exec.TryNewCacheChunk(e.outerExec) if err := exec.Next(ctx, e.outerExec, chk); err != nil { e.putResult(nil, err) @@ -246,7 +246,7 @@ func (e *ParallelNestedLoopApplyExec) innerWorker(ctx context.Context, id int) { case <-e.exit: return } - failpoint.Eval(_curpkg_("parallelApplyInnerWorkerPanic")) + failpoint.Inject("parallelApplyInnerWorkerPanic", nil) err := e.fillInnerChunk(ctx, id, chk) if err == nil && chk.NumRows() == 0 { // no more data, this goroutine can exit return @@ -292,7 +292,7 @@ func (e *ParallelNestedLoopApplyExec) fetchAllInners(ctx context.Context, id int } if e.useCache { // look up the cache atomic.AddInt64(&e.cacheAccessCounter, 1) - failpoint.Eval(_curpkg_("parallelApplyGetCachePanic")) + failpoint.Inject("parallelApplyGetCachePanic", nil) value, err := e.cache.Get(key) if err != nil { return err @@ -339,7 +339,7 @@ func (e *ParallelNestedLoopApplyExec) fetchAllInners(ctx context.Context, id int } if e.useCache { // update the cache - failpoint.Eval(_curpkg_("parallelApplySetCachePanic")) + failpoint.Inject("parallelApplySetCachePanic", nil) if _, err := e.cache.Set(key, e.innerList[id]); err != nil { return err } diff --git a/pkg/executor/parallel_apply.go__failpoint_stash__ b/pkg/executor/parallel_apply.go__failpoint_stash__ deleted file mode 100644 index 03820652df933..0000000000000 --- a/pkg/executor/parallel_apply.go__failpoint_stash__ +++ /dev/null @@ -1,405 +0,0 @@ -// Copyright 2020 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package executor - -import ( - "context" - "runtime/trace" - "sync" - "sync/atomic" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/executor/internal/applycache" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/executor/join" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "go.uber.org/zap" -) - -type result struct { - chk *chunk.Chunk - err error -} - -type outerRow struct { - row *chunk.Row - selected bool // if this row is selected by the outer side -} - -// ParallelNestedLoopApplyExec is the executor for apply. 
-type ParallelNestedLoopApplyExec struct { - exec.BaseExecutor - - // outer-side fields - outerExec exec.Executor - outerFilter expression.CNFExprs - outerList *chunk.List - outer bool - - // inner-side fields - // use slices since the inner side is paralleled - corCols [][]*expression.CorrelatedColumn - innerFilter []expression.CNFExprs - innerExecs []exec.Executor - innerList []*chunk.List - innerChunk []*chunk.Chunk - innerSelected [][]bool - innerIter []chunk.Iterator - outerRow []*chunk.Row - hasMatch []bool - hasNull []bool - joiners []join.Joiner - - // fields about concurrency control - concurrency int - started uint32 - drained uint32 // drained == true indicates there is no more data - freeChkCh chan *chunk.Chunk - resultChkCh chan result - outerRowCh chan outerRow - exit chan struct{} - workerWg sync.WaitGroup - notifyWg sync.WaitGroup - - // fields about cache - cache *applycache.ApplyCache - useCache bool - cacheHitCounter int64 - cacheAccessCounter int64 - - memTracker *memory.Tracker // track memory usage. -} - -// Open implements the Executor interface. -func (e *ParallelNestedLoopApplyExec) Open(ctx context.Context) error { - err := exec.Open(ctx, e.outerExec) - if err != nil { - return err - } - e.memTracker = memory.NewTracker(e.ID(), -1) - e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) - - e.outerList = chunk.NewList(exec.RetTypes(e.outerExec), e.InitCap(), e.MaxChunkSize()) - e.outerList.GetMemTracker().SetLabel(memory.LabelForOuterList) - e.outerList.GetMemTracker().AttachTo(e.memTracker) - - e.innerList = make([]*chunk.List, e.concurrency) - e.innerChunk = make([]*chunk.Chunk, e.concurrency) - e.innerSelected = make([][]bool, e.concurrency) - e.innerIter = make([]chunk.Iterator, e.concurrency) - e.outerRow = make([]*chunk.Row, e.concurrency) - e.hasMatch = make([]bool, e.concurrency) - e.hasNull = make([]bool, e.concurrency) - for i := 0; i < e.concurrency; i++ { - e.innerChunk[i] = exec.TryNewCacheChunk(e.innerExecs[i]) - e.innerList[i] = chunk.NewList(exec.RetTypes(e.innerExecs[i]), e.InitCap(), e.MaxChunkSize()) - e.innerList[i].GetMemTracker().SetLabel(memory.LabelForInnerList) - e.innerList[i].GetMemTracker().AttachTo(e.memTracker) - } - - e.freeChkCh = make(chan *chunk.Chunk, e.concurrency) - e.resultChkCh = make(chan result, e.concurrency+1) // innerWorkers + outerWorker - e.outerRowCh = make(chan outerRow) - e.exit = make(chan struct{}) - for i := 0; i < e.concurrency; i++ { - e.freeChkCh <- exec.NewFirstChunk(e) - } - - if e.useCache { - if e.cache, err = applycache.NewApplyCache(e.Ctx()); err != nil { - return err - } - e.cache.GetMemTracker().AttachTo(e.memTracker) - } - return nil -} - -// Next implements the Executor interface. -func (e *ParallelNestedLoopApplyExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { - if atomic.LoadUint32(&e.drained) == 1 { - req.Reset() - return nil - } - - if atomic.CompareAndSwapUint32(&e.started, 0, 1) { - e.workerWg.Add(1) - go e.outerWorker(ctx) - for i := 0; i < e.concurrency; i++ { - e.workerWg.Add(1) - workID := i - go e.innerWorker(ctx, workID) - } - e.notifyWg.Add(1) - go e.notifyWorker(ctx) - } - result := <-e.resultChkCh - if result.err != nil { - return result.err - } - if result.chk == nil { // no more data - req.Reset() - atomic.StoreUint32(&e.drained, 1) - return nil - } - req.SwapColumns(result.chk) - e.freeChkCh <- result.chk - return nil -} - -// Close implements the Executor interface. 
-func (e *ParallelNestedLoopApplyExec) Close() error { - e.memTracker = nil - if atomic.LoadUint32(&e.started) == 1 { - close(e.exit) - e.notifyWg.Wait() - e.started = 0 - } - // Wait all workers to finish before Close() is called. - // Otherwise we may got data race. - err := exec.Close(e.outerExec) - - if e.RuntimeStats() != nil { - runtimeStats := join.NewJoinRuntimeStats() - if e.useCache { - var hitRatio float64 - if e.cacheAccessCounter > 0 { - hitRatio = float64(e.cacheHitCounter) / float64(e.cacheAccessCounter) - } - runtimeStats.SetCacheInfo(true, hitRatio) - } else { - runtimeStats.SetCacheInfo(false, 0) - } - runtimeStats.SetConcurrencyInfo(execdetails.NewConcurrencyInfo("Concurrency", e.concurrency)) - defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), runtimeStats) - } - return err -} - -// notifyWorker waits for all inner/outer-workers finishing and then put an empty -// chunk into the resultCh to notify the upper executor there is no more data. -func (e *ParallelNestedLoopApplyExec) notifyWorker(ctx context.Context) { - defer e.handleWorkerPanic(ctx, &e.notifyWg) - e.workerWg.Wait() - e.putResult(nil, nil) -} - -func (e *ParallelNestedLoopApplyExec) outerWorker(ctx context.Context) { - defer trace.StartRegion(ctx, "ParallelApplyOuterWorker").End() - defer e.handleWorkerPanic(ctx, &e.workerWg) - var selected []bool - var err error - for { - failpoint.Inject("parallelApplyOuterWorkerPanic", nil) - chk := exec.TryNewCacheChunk(e.outerExec) - if err := exec.Next(ctx, e.outerExec, chk); err != nil { - e.putResult(nil, err) - return - } - if chk.NumRows() == 0 { - close(e.outerRowCh) - return - } - e.outerList.Add(chk) - outerIter := chunk.NewIterator4Chunk(chk) - selected, err = expression.VectorizedFilter(e.Ctx().GetExprCtx().GetEvalCtx(), e.Ctx().GetSessionVars().EnableVectorizedExpression, e.outerFilter, outerIter, selected) - if err != nil { - e.putResult(nil, err) - return - } - for i := 0; i < chk.NumRows(); i++ { - row := chk.GetRow(i) - select { - case e.outerRowCh <- outerRow{&row, selected[i]}: - case <-e.exit: - return - } - } - } -} - -func (e *ParallelNestedLoopApplyExec) innerWorker(ctx context.Context, id int) { - defer trace.StartRegion(ctx, "ParallelApplyInnerWorker").End() - defer e.handleWorkerPanic(ctx, &e.workerWg) - for { - var chk *chunk.Chunk - select { - case chk = <-e.freeChkCh: - case <-e.exit: - return - } - failpoint.Inject("parallelApplyInnerWorkerPanic", nil) - err := e.fillInnerChunk(ctx, id, chk) - if err == nil && chk.NumRows() == 0 { // no more data, this goroutine can exit - return - } - if e.putResult(chk, err) { - return - } - } -} - -func (e *ParallelNestedLoopApplyExec) putResult(chk *chunk.Chunk, err error) (exit bool) { - select { - case e.resultChkCh <- result{chk, err}: - return false - case <-e.exit: - return true - } -} - -func (e *ParallelNestedLoopApplyExec) handleWorkerPanic(ctx context.Context, wg *sync.WaitGroup) { - if r := recover(); r != nil { - err := util.GetRecoverError(r) - logutil.Logger(ctx).Error("parallel nested loop join worker panicked", zap.Error(err), zap.Stack("stack")) - e.resultChkCh <- result{nil, err} - } - if wg != nil { - wg.Done() - } -} - -// fetchAllInners reads all data from the inner table and stores them in a List. 
-func (e *ParallelNestedLoopApplyExec) fetchAllInners(ctx context.Context, id int) (err error) { - var key []byte - for _, col := range e.corCols[id] { - *col.Data = e.outerRow[id].GetDatum(col.Index, col.RetType) - if e.useCache { - key, err = codec.EncodeKey(e.Ctx().GetSessionVars().StmtCtx.TimeZone(), key, *col.Data) - err = e.Ctx().GetSessionVars().StmtCtx.HandleError(err) - if err != nil { - return err - } - } - } - if e.useCache { // look up the cache - atomic.AddInt64(&e.cacheAccessCounter, 1) - failpoint.Inject("parallelApplyGetCachePanic", nil) - value, err := e.cache.Get(key) - if err != nil { - return err - } - if value != nil { - e.innerList[id] = value - atomic.AddInt64(&e.cacheHitCounter, 1) - return nil - } - } - - err = exec.Open(ctx, e.innerExecs[id]) - defer func() { terror.Log(exec.Close(e.innerExecs[id])) }() - if err != nil { - return err - } - - if e.useCache { - // create a new one in this case since it may be in the cache - e.innerList[id] = chunk.NewList(exec.RetTypes(e.innerExecs[id]), e.InitCap(), e.MaxChunkSize()) - } else { - e.innerList[id].Reset() - } - - innerIter := chunk.NewIterator4Chunk(e.innerChunk[id]) - for { - err := exec.Next(ctx, e.innerExecs[id], e.innerChunk[id]) - if err != nil { - return err - } - if e.innerChunk[id].NumRows() == 0 { - break - } - - e.innerSelected[id], err = expression.VectorizedFilter(e.Ctx().GetExprCtx().GetEvalCtx(), e.Ctx().GetSessionVars().EnableVectorizedExpression, e.innerFilter[id], innerIter, e.innerSelected[id]) - if err != nil { - return err - } - for row := innerIter.Begin(); row != innerIter.End(); row = innerIter.Next() { - if e.innerSelected[id][row.Idx()] { - e.innerList[id].AppendRow(row) - } - } - } - - if e.useCache { // update the cache - failpoint.Inject("parallelApplySetCachePanic", nil) - if _, err := e.cache.Set(key, e.innerList[id]); err != nil { - return err - } - } - return nil -} - -func (e *ParallelNestedLoopApplyExec) fetchNextOuterRow(id int, req *chunk.Chunk) (row *chunk.Row, exit bool) { - for { - select { - case outerRow, ok := <-e.outerRowCh: - if !ok { // no more data - return nil, false - } - if !outerRow.selected { - if e.outer { - e.joiners[id].OnMissMatch(false, *outerRow.row, req) - if req.IsFull() { - return nil, false - } - } - continue // try the next outer row - } - return outerRow.row, false - case <-e.exit: - return nil, true - } - } -} - -func (e *ParallelNestedLoopApplyExec) fillInnerChunk(ctx context.Context, id int, req *chunk.Chunk) (err error) { - req.Reset() - for { - if e.innerIter[id] == nil || e.innerIter[id].Current() == e.innerIter[id].End() { - if e.outerRow[id] != nil && !e.hasMatch[id] { - e.joiners[id].OnMissMatch(e.hasNull[id], *e.outerRow[id], req) - } - var exit bool - e.outerRow[id], exit = e.fetchNextOuterRow(id, req) - if exit || req.IsFull() || e.outerRow[id] == nil { - return nil - } - - e.hasMatch[id] = false - e.hasNull[id] = false - - err = e.fetchAllInners(ctx, id) - if err != nil { - return err - } - e.innerIter[id] = chunk.NewIterator4List(e.innerList[id]) - e.innerIter[id].Begin() - } - - matched, isNull, err := e.joiners[id].TryToMatchInners(*e.outerRow[id], e.innerIter[id], req) - e.hasMatch[id] = e.hasMatch[id] || matched - e.hasNull[id] = e.hasNull[id] || isNull - - if err != nil || req.IsFull() { - return err - } - } -} diff --git a/pkg/executor/point_get.go b/pkg/executor/point_get.go index c2af1a0332290..ee3b7047aa89e 100644 --- a/pkg/executor/point_get.go +++ b/pkg/executor/point_get.go @@ -103,12 +103,12 @@ func (b *executorBuilder) 
buildPointGet(p *plannercore.PointGetPlan) exec.Execut sctx.IndexNames = append(sctx.IndexNames, p.TblInfo.Name.O+":"+p.IndexInfo.Name.O) } - if val, _err_ := failpoint.Eval(_curpkg_("assertPointReplicaOption")); _err_ == nil { + failpoint.Inject("assertPointReplicaOption", func(val failpoint.Value) { assertScope := val.(string) if e.Ctx().GetSessionVars().GetReplicaRead().IsClosestRead() && assertScope != e.readReplicaScope { panic("point get replica option fail") } - } + }) snapshotTS, err := b.getSnapshotTS() if err != nil { @@ -340,14 +340,14 @@ func (e *PointGetExecutor) Next(ctx context.Context, req *chunk.Chunk) error { // 2. Session B create an UPDATE query to update the record that will be obtained in step 1 // 3. Then point get retrieve data from backend after step 2 finished // 4. Check the result - if _, _err_ := failpoint.EvalContext(ctx, _curpkg_("pointGetRepeatableReadTest-step1")); _err_ == nil { + failpoint.InjectContext(ctx, "pointGetRepeatableReadTest-step1", func() { if ch, ok := ctx.Value("pointGetRepeatableReadTest").(chan struct{}); ok { // Make `UPDATE` continue close(ch) } // Wait `UPDATE` finished - failpoint.EvalContext(ctx, _curpkg_("pointGetRepeatableReadTest-step2")) - } + failpoint.InjectContext(ctx, "pointGetRepeatableReadTest-step2", nil) + }) if e.idxInfo.Global { _, pid, err := codec.DecodeInt(tablecodec.SplitIndexValue(e.handleVal).PartitionID) if err != nil { diff --git a/pkg/executor/point_get.go__failpoint_stash__ b/pkg/executor/point_get.go__failpoint_stash__ deleted file mode 100644 index ee3b7047aa89e..0000000000000 --- a/pkg/executor/point_get.go__failpoint_stash__ +++ /dev/null @@ -1,824 +0,0 @@ -// Copyright 2018 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package executor - -import ( - "context" - "fmt" - "sort" - "strconv" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/distsql" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/table/tables" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/dbterror" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tidb/pkg/util/logutil/consistency" - "github.com/pingcap/tidb/pkg/util/rowcodec" - "github.com/tikv/client-go/v2/tikvrpc" - "github.com/tikv/client-go/v2/txnkv/txnsnapshot" -) - -func (b *executorBuilder) buildPointGet(p *plannercore.PointGetPlan) exec.Executor { - var err error - if err = b.validCanReadTemporaryOrCacheTable(p.TblInfo); err != nil { - b.err = err - return nil - } - - if p.PrunePartitions(b.ctx) { - // no matching partitions - return &TableDualExec{ - BaseExecutorV2: exec.NewBaseExecutorV2(b.ctx.GetSessionVars(), p.Schema(), p.ID()), - numDualRows: 0, - numReturned: 0, - } - } - - if p.Lock && !b.inSelectLockStmt { - b.inSelectLockStmt = true - defer func() { - b.inSelectLockStmt = false - }() - } - - e := &PointGetExecutor{ - BaseExecutor: exec.NewBaseExecutor(b.ctx, p.Schema(), p.ID()), - indexUsageReporter: b.buildIndexUsageReporter(p), - txnScope: b.txnScope, - readReplicaScope: b.readReplicaScope, - isStaleness: b.isStaleness, - partitionNames: p.PartitionNames, - } - - e.SetInitCap(1) - e.SetMaxChunkSize(1) - e.Init(p) - - e.snapshot, err = b.getSnapshot() - if err != nil { - b.err = err - return nil - } - if b.ctx.GetSessionVars().IsReplicaReadClosestAdaptive() { - e.snapshot.SetOption(kv.ReplicaReadAdjuster, newReplicaReadAdjuster(e.Ctx(), p.GetAvgRowSize())) - } - if e.RuntimeStats() != nil { - snapshotStats := &txnsnapshot.SnapshotRuntimeStats{} - e.stats = &runtimeStatsWithSnapshot{ - SnapshotRuntimeStats: snapshotStats, - } - e.snapshot.SetOption(kv.CollectRuntimeStats, snapshotStats) - } - - if p.IndexInfo != nil { - sctx := b.ctx.GetSessionVars().StmtCtx - sctx.IndexNames = append(sctx.IndexNames, p.TblInfo.Name.O+":"+p.IndexInfo.Name.O) - } - - failpoint.Inject("assertPointReplicaOption", func(val failpoint.Value) { - assertScope := val.(string) - if e.Ctx().GetSessionVars().GetReplicaRead().IsClosestRead() && assertScope != e.readReplicaScope { - panic("point get replica option fail") - } - }) - - snapshotTS, err := b.getSnapshotTS() - if err != nil { - b.err = err - return nil - } - if p.TblInfo.TableCacheStatusType == model.TableCacheStatusEnable { - if cacheTable := b.getCacheTable(p.TblInfo, snapshotTS); cacheTable != nil { - e.snapshot = cacheTableSnapshot{e.snapshot, cacheTable} - } - } - - if e.lock { - b.hasLock = true - } - - return e -} - -// PointGetExecutor executes point select query. 
-type PointGetExecutor struct { - exec.BaseExecutor - indexUsageReporter *exec.IndexUsageReporter - - tblInfo *model.TableInfo - handle kv.Handle - idxInfo *model.IndexInfo - partitionDefIdx *int - partitionNames []model.CIStr - idxKey kv.Key - handleVal []byte - idxVals []types.Datum - txnScope string - readReplicaScope string - isStaleness bool - txn kv.Transaction - snapshot kv.Snapshot - done bool - lock bool - lockWaitTime int64 - rowDecoder *rowcodec.ChunkDecoder - - columns []*model.ColumnInfo - // virtualColumnIndex records all the indices of virtual columns and sort them in definition - // to make sure we can compute the virtual column in right order. - virtualColumnIndex []int - - // virtualColumnRetFieldTypes records the RetFieldTypes of virtual columns. - virtualColumnRetFieldTypes []*types.FieldType - - stats *runtimeStatsWithSnapshot -} - -// GetPhysID returns the physical id used, either the table's id or a partition's ID -func GetPhysID(tblInfo *model.TableInfo, idx *int) int64 { - if idx != nil { - if *idx < 0 { - intest.Assert(false) - } else { - if pi := tblInfo.GetPartitionInfo(); pi != nil { - return pi.Definitions[*idx].ID - } - } - } - return tblInfo.ID -} - -func matchPartitionNames(pid int64, partitionNames []model.CIStr, pi *model.PartitionInfo) bool { - if len(partitionNames) == 0 { - return true - } - defs := pi.Definitions - for i := range defs { - // TODO: create a map from id to partition definition index - if defs[i].ID == pid { - for _, name := range partitionNames { - if defs[i].Name.L == name.L { - return true - } - } - // Only one partition can match pid - return false - } - } - return false -} - -// Init set fields needed for PointGetExecutor reuse, this does NOT change baseExecutor field -func (e *PointGetExecutor) Init(p *plannercore.PointGetPlan) { - decoder := NewRowDecoder(e.Ctx(), p.Schema(), p.TblInfo) - e.tblInfo = p.TblInfo - e.handle = p.Handle - e.idxInfo = p.IndexInfo - e.idxVals = p.IndexValues - e.done = false - if e.tblInfo.TempTableType == model.TempTableNone { - e.lock = p.Lock - e.lockWaitTime = p.LockWaitTime - } else { - // Temporary table should not do any lock operations - e.lock = false - e.lockWaitTime = 0 - } - e.rowDecoder = decoder - e.partitionDefIdx = p.PartitionIdx - e.columns = p.Columns - e.buildVirtualColumnInfo() -} - -// buildVirtualColumnInfo saves virtual column indices and sort them in definition order -func (e *PointGetExecutor) buildVirtualColumnInfo() { - e.virtualColumnIndex = buildVirtualColumnIndex(e.Schema(), e.columns) - if len(e.virtualColumnIndex) > 0 { - e.virtualColumnRetFieldTypes = make([]*types.FieldType, len(e.virtualColumnIndex)) - for i, idx := range e.virtualColumnIndex { - e.virtualColumnRetFieldTypes[i] = e.Schema().Columns[idx].RetType - } - } -} - -// Open implements the Executor interface. -func (e *PointGetExecutor) Open(context.Context) error { - var err error - e.txn, err = e.Ctx().Txn(false) - if err != nil { - return err - } - if err := e.verifyTxnScope(); err != nil { - return err - } - setOptionForTopSQL(e.Ctx().GetSessionVars().StmtCtx, e.snapshot) - return nil -} - -// Close implements the Executor interface. 
-func (e *PointGetExecutor) Close() error { - if e.stats != nil { - defer e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), e.stats) - } - if e.RuntimeStats() != nil && e.snapshot != nil { - e.snapshot.SetOption(kv.CollectRuntimeStats, nil) - } - if e.indexUsageReporter != nil && e.idxInfo != nil { - tableID := e.tblInfo.ID - physicalTableID := GetPhysID(e.tblInfo, e.partitionDefIdx) - kvReqTotal := e.stats.SnapshotRuntimeStats.GetCmdRPCCount(tikvrpc.CmdGet) - e.indexUsageReporter.ReportPointGetIndexUsage(tableID, physicalTableID, e.idxInfo.ID, e.ID(), kvReqTotal) - } - e.done = false - return nil -} - -// Next implements the Executor interface. -func (e *PointGetExecutor) Next(ctx context.Context, req *chunk.Chunk) error { - req.Reset() - if e.done { - return nil - } - e.done = true - - var err error - tblID := GetPhysID(e.tblInfo, e.partitionDefIdx) - if e.lock { - e.UpdateDeltaForTableID(tblID) - } - if e.idxInfo != nil { - if isCommonHandleRead(e.tblInfo, e.idxInfo) { - handleBytes, err := plannercore.EncodeUniqueIndexValuesForKey(e.Ctx(), e.tblInfo, e.idxInfo, e.idxVals) - if err != nil { - if kv.ErrNotExist.Equal(err) { - return nil - } - return err - } - e.handle, err = kv.NewCommonHandle(handleBytes) - if err != nil { - return err - } - } else { - e.idxKey, err = plannercore.EncodeUniqueIndexKey(e.Ctx(), e.tblInfo, e.idxInfo, e.idxVals, tblID) - if err != nil && !kv.ErrNotExist.Equal(err) { - return err - } - - // lockNonExistIdxKey indicates the key will be locked regardless of its existence. - lockNonExistIdxKey := !e.Ctx().GetSessionVars().IsPessimisticReadConsistency() - // Non-exist keys are also locked if the isolation level is not read consistency, - // lock it before read here, then it's able to read from pessimistic lock cache. - if lockNonExistIdxKey { - err = e.lockKeyIfNeeded(ctx, e.idxKey) - if err != nil { - return err - } - e.handleVal, err = e.get(ctx, e.idxKey) - if err != nil { - if !kv.ErrNotExist.Equal(err) { - return err - } - } - } else { - if e.lock { - e.handleVal, err = e.lockKeyIfExists(ctx, e.idxKey) - if err != nil { - return err - } - } else { - e.handleVal, err = e.get(ctx, e.idxKey) - if err != nil { - if !kv.ErrNotExist.Equal(err) { - return err - } - } - } - } - - if len(e.handleVal) == 0 { - return nil - } - - var iv kv.Handle - iv, err = tablecodec.DecodeHandleInIndexValue(e.handleVal) - if err != nil { - return err - } - e.handle = iv - - // The injection is used to simulate following scenario: - // 1. Session A create a point get query but pause before second time `GET` kv from backend - // 2. Session B create an UPDATE query to update the record that will be obtained in step 1 - // 3. Then point get retrieve data from backend after step 2 finished - // 4. 
Check the result - failpoint.InjectContext(ctx, "pointGetRepeatableReadTest-step1", func() { - if ch, ok := ctx.Value("pointGetRepeatableReadTest").(chan struct{}); ok { - // Make `UPDATE` continue - close(ch) - } - // Wait `UPDATE` finished - failpoint.InjectContext(ctx, "pointGetRepeatableReadTest-step2", nil) - }) - if e.idxInfo.Global { - _, pid, err := codec.DecodeInt(tablecodec.SplitIndexValue(e.handleVal).PartitionID) - if err != nil { - return err - } - tblID = pid - if !matchPartitionNames(tblID, e.partitionNames, e.tblInfo.GetPartitionInfo()) { - return nil - } - } - } - } - - key := tablecodec.EncodeRowKeyWithHandle(tblID, e.handle) - val, err := e.getAndLock(ctx, key) - if err != nil { - return err - } - if len(val) == 0 { - if e.idxInfo != nil && !isCommonHandleRead(e.tblInfo, e.idxInfo) && - !e.Ctx().GetSessionVars().StmtCtx.WeakConsistency { - return (&consistency.Reporter{ - HandleEncode: func(kv.Handle) kv.Key { - return key - }, - IndexEncode: func(*consistency.RecordData) kv.Key { - return e.idxKey - }, - Tbl: e.tblInfo, - Idx: e.idxInfo, - EnableRedactLog: e.Ctx().GetSessionVars().EnableRedactLog, - Storage: e.Ctx().GetStore(), - }).ReportLookupInconsistent(ctx, - 1, 0, - []kv.Handle{e.handle}, - []kv.Handle{e.handle}, - []consistency.RecordData{{}}, - ) - } - return nil - } - - sctx := e.BaseExecutor.Ctx() - schema := e.Schema() - err = DecodeRowValToChunk(sctx, schema, e.tblInfo, e.handle, val, req, e.rowDecoder) - if err != nil { - return err - } - - err = fillRowChecksum(sctx, 0, 1, schema, e.tblInfo, [][]byte{val}, []kv.Handle{e.handle}, req, nil) - if err != nil { - return err - } - - err = table.FillVirtualColumnValue(e.virtualColumnRetFieldTypes, e.virtualColumnIndex, - schema.Columns, e.columns, sctx.GetExprCtx(), req) - if err != nil { - return err - } - return nil -} - -func shouldFillRowChecksum(schema *expression.Schema) (int, bool) { - for idx, col := range schema.Columns { - if col.ID == model.ExtraRowChecksumID { - return idx, true - } - } - return 0, false -} - -func fillRowChecksum( - sctx sessionctx.Context, - start, end int, - schema *expression.Schema, tblInfo *model.TableInfo, - values [][]byte, handles []kv.Handle, - req *chunk.Chunk, buf []byte, -) error { - checksumColumnIndex, ok := shouldFillRowChecksum(schema) - if !ok { - return nil - } - - var handleColIDs []int64 - if tblInfo.PKIsHandle { - colInfo := tblInfo.GetPkColInfo() - handleColIDs = []int64{colInfo.ID} - } else if tblInfo.IsCommonHandle { - pkIdx := tables.FindPrimaryIndex(tblInfo) - for _, col := range pkIdx.Columns { - colInfo := tblInfo.Columns[col.Offset] - handleColIDs = append(handleColIDs, colInfo.ID) - } - } - - columnFt := make(map[int64]*types.FieldType) - for idx := range tblInfo.Columns { - col := tblInfo.Columns[idx] - columnFt[col.ID] = &col.FieldType - } - tz := sctx.GetSessionVars().TimeZone - ft := []*types.FieldType{schema.Columns[checksumColumnIndex].GetType(sctx.GetExprCtx().GetEvalCtx())} - checksumCols := chunk.NewChunkWithCapacity(ft, req.Capacity()) - for i := start; i < end; i++ { - handle, val := handles[i], values[i] - if !rowcodec.IsNewFormat(val) { - checksumCols.AppendNull(0) - continue - } - datums, err := tablecodec.DecodeRowWithMapNew(val, columnFt, tz, nil) - if err != nil { - return err - } - datums, err = tablecodec.DecodeHandleToDatumMap(handle, handleColIDs, columnFt, tz, datums) - if err != nil { - return err - } - for _, col := range tblInfo.Columns { - // cannot found from the datums, which means the data is not stored, this - // may happen 
after `add column` executed, filling with the default value. - _, ok := datums[col.ID] - if !ok { - colInfo := getColInfoByID(tblInfo, col.ID) - d, err := table.GetColOriginDefaultValue(sctx.GetExprCtx(), colInfo) - if err != nil { - return err - } - datums[col.ID] = d - } - } - - colData := make([]rowcodec.ColData, len(tblInfo.Columns)) - for idx, col := range tblInfo.Columns { - d := datums[col.ID] - data := rowcodec.ColData{ - ColumnInfo: col, - Datum: &d, - } - colData[idx] = data - } - row := rowcodec.RowData{ - Cols: colData, - Data: buf, - } - if !sort.IsSorted(row) { - sort.Sort(row) - } - checksum, err := row.Checksum(tz) - if err != nil { - return err - } - checksumCols.AppendString(0, strconv.FormatUint(uint64(checksum), 10)) - } - req.SetCol(checksumColumnIndex, checksumCols.Column(0)) - return nil -} - -func (e *PointGetExecutor) getAndLock(ctx context.Context, key kv.Key) (val []byte, err error) { - if e.Ctx().GetSessionVars().IsPessimisticReadConsistency() { - // Only Lock the existing keys in RC isolation. - if e.lock { - val, err = e.lockKeyIfExists(ctx, key) - if err != nil { - return nil, err - } - } else { - val, err = e.get(ctx, key) - if err != nil { - if !kv.ErrNotExist.Equal(err) { - return nil, err - } - return nil, nil - } - } - return val, nil - } - // Lock the key before get in RR isolation, then get will get the value from the cache. - err = e.lockKeyIfNeeded(ctx, key) - if err != nil { - return nil, err - } - val, err = e.get(ctx, key) - if err != nil { - if !kv.ErrNotExist.Equal(err) { - return nil, err - } - return nil, nil - } - return val, nil -} - -func (e *PointGetExecutor) lockKeyIfNeeded(ctx context.Context, key []byte) error { - _, err := e.lockKeyBase(ctx, key, false) - return err -} - -// lockKeyIfExists locks the key if needed, but won't lock the key if it doesn't exis. -// Returns the value of the key if the key exist. -func (e *PointGetExecutor) lockKeyIfExists(ctx context.Context, key []byte) ([]byte, error) { - return e.lockKeyBase(ctx, key, true) -} - -func (e *PointGetExecutor) lockKeyBase(ctx context.Context, - key []byte, - lockOnlyIfExists bool) ([]byte, error) { - if len(key) == 0 { - return nil, nil - } - - if e.lock { - seVars := e.Ctx().GetSessionVars() - lockCtx, err := newLockCtx(e.Ctx(), e.lockWaitTime, 1) - if err != nil { - return nil, err - } - lockCtx.LockOnlyIfExists = lockOnlyIfExists - lockCtx.InitReturnValues(1) - err = doLockKeys(ctx, e.Ctx(), lockCtx, key) - if err != nil { - return nil, err - } - lockCtx.IterateValuesNotLocked(func(k, v []byte) { - seVars.TxnCtx.SetPessimisticLockCache(k, v) - }) - if len(e.handleVal) > 0 { - seVars.TxnCtx.SetPessimisticLockCache(e.idxKey, e.handleVal) - } - if lockOnlyIfExists { - return e.getValueFromLockCtx(ctx, lockCtx, key) - } - } - - return nil, nil -} - -func (e *PointGetExecutor) getValueFromLockCtx(ctx context.Context, - lockCtx *kv.LockCtx, - key []byte) ([]byte, error) { - if val, ok := lockCtx.Values[string(key)]; ok { - if val.Exists { - return val.Value, nil - } else if val.AlreadyLocked { - val, err := e.get(ctx, key) - if err != nil { - if !kv.ErrNotExist.Equal(err) { - return nil, err - } - return nil, nil - } - return val, nil - } - } - - return nil, nil -} - -// get will first try to get from txn buffer, then check the pessimistic lock cache, -// then the store. 
Kv.ErrNotExist will be returned if key is not found -func (e *PointGetExecutor) get(ctx context.Context, key kv.Key) ([]byte, error) { - if len(key) == 0 { - return nil, kv.ErrNotExist - } - - var ( - val []byte - err error - ) - - if e.txn.Valid() && !e.txn.IsReadOnly() { - // We cannot use txn.Get directly here because the snapshot in txn and the snapshot of e.snapshot may be - // different for pessimistic transaction. - val, err = e.txn.GetMemBuffer().Get(ctx, key) - if err == nil { - return val, err - } - if !kv.IsErrNotFound(err) { - return nil, err - } - // key does not exist in mem buffer, check the lock cache - if e.lock { - var ok bool - val, ok = e.Ctx().GetSessionVars().TxnCtx.GetKeyInPessimisticLockCache(key) - if ok { - return val, nil - } - } - // fallthrough to snapshot get. - } - - lock := e.tblInfo.Lock - if lock != nil && (lock.Tp == model.TableLockRead || lock.Tp == model.TableLockReadOnly) { - if e.Ctx().GetSessionVars().EnablePointGetCache { - cacheDB := e.Ctx().GetStore().GetMemCache() - val, err = cacheDB.UnionGet(ctx, e.tblInfo.ID, e.snapshot, key) - if err != nil { - return nil, err - } - return val, nil - } - } - // if not read lock or table was unlock then snapshot get - return e.snapshot.Get(ctx, key) -} - -func (e *PointGetExecutor) verifyTxnScope() error { - if e.txnScope == "" || e.txnScope == kv.GlobalTxnScope { - return nil - } - - var partName string - is := e.Ctx().GetInfoSchema().(infoschema.InfoSchema) - tblInfo, _ := is.TableByID((e.tblInfo.ID)) - tblName := tblInfo.Meta().Name.String() - tblID := GetPhysID(tblInfo.Meta(), e.partitionDefIdx) - if tblID != tblInfo.Meta().ID { - partName = tblInfo.Meta().GetPartitionInfo().Definitions[*e.partitionDefIdx].Name.String() - } - valid := distsql.VerifyTxnScope(e.txnScope, tblID, is) - if valid { - return nil - } - if len(partName) > 0 { - return dbterror.ErrInvalidPlacementPolicyCheck.GenWithStackByArgs( - fmt.Sprintf("table %v's partition %v can not be read by %v txn_scope", tblName, partName, e.txnScope)) - } - return dbterror.ErrInvalidPlacementPolicyCheck.GenWithStackByArgs( - fmt.Sprintf("table %v can not be read by %v txn_scope", tblName, e.txnScope)) -} - -// DecodeRowValToChunk decodes row value into chunk checking row format used. 
-func DecodeRowValToChunk(sctx sessionctx.Context, schema *expression.Schema, tblInfo *model.TableInfo, - handle kv.Handle, rowVal []byte, chk *chunk.Chunk, rd *rowcodec.ChunkDecoder) error { - if rowcodec.IsNewFormat(rowVal) { - return rd.DecodeToChunk(rowVal, handle, chk) - } - return decodeOldRowValToChunk(sctx, schema, tblInfo, handle, rowVal, chk) -} - -func decodeOldRowValToChunk(sctx sessionctx.Context, schema *expression.Schema, tblInfo *model.TableInfo, handle kv.Handle, - rowVal []byte, chk *chunk.Chunk) error { - pkCols := tables.TryGetCommonPkColumnIds(tblInfo) - prefixColIDs := tables.PrimaryPrefixColumnIDs(tblInfo) - colID2CutPos := make(map[int64]int, schema.Len()) - for _, col := range schema.Columns { - if _, ok := colID2CutPos[col.ID]; !ok { - colID2CutPos[col.ID] = len(colID2CutPos) - } - } - cutVals, err := tablecodec.CutRowNew(rowVal, colID2CutPos) - if err != nil { - return err - } - if cutVals == nil { - cutVals = make([][]byte, len(colID2CutPos)) - } - decoder := codec.NewDecoder(chk, sctx.GetSessionVars().Location()) - for i, col := range schema.Columns { - // fill the virtual column value after row calculation - if col.VirtualExpr != nil { - chk.AppendNull(i) - continue - } - ok, err := tryDecodeFromHandle(tblInfo, i, col, handle, chk, decoder, pkCols, prefixColIDs) - if err != nil { - return err - } - if ok { - continue - } - cutPos := colID2CutPos[col.ID] - if len(cutVals[cutPos]) == 0 { - colInfo := getColInfoByID(tblInfo, col.ID) - d, err1 := table.GetColOriginDefaultValue(sctx.GetExprCtx(), colInfo) - if err1 != nil { - return err1 - } - chk.AppendDatum(i, &d) - continue - } - _, err = decoder.DecodeOne(cutVals[cutPos], i, col.RetType) - if err != nil { - return err - } - } - return nil -} - -func tryDecodeFromHandle(tblInfo *model.TableInfo, schemaColIdx int, col *expression.Column, handle kv.Handle, chk *chunk.Chunk, - decoder *codec.Decoder, pkCols []int64, prefixColIDs []int64) (bool, error) { - if tblInfo.PKIsHandle && mysql.HasPriKeyFlag(col.RetType.GetFlag()) { - chk.AppendInt64(schemaColIdx, handle.IntValue()) - return true, nil - } - if col.ID == model.ExtraHandleID { - chk.AppendInt64(schemaColIdx, handle.IntValue()) - return true, nil - } - if types.NeedRestoredData(col.RetType) { - return false, nil - } - // Try to decode common handle. - if mysql.HasPriKeyFlag(col.RetType.GetFlag()) { - for i, hid := range pkCols { - if col.ID == hid && notPKPrefixCol(hid, prefixColIDs) { - _, err := decoder.DecodeOne(handle.EncodedCol(i), schemaColIdx, col.RetType) - if err != nil { - return false, errors.Trace(err) - } - return true, nil - } - } - } - return false, nil -} - -func notPKPrefixCol(colID int64, prefixColIDs []int64) bool { - for _, pCol := range prefixColIDs { - if pCol == colID { - return false - } - } - return true -} - -func getColInfoByID(tbl *model.TableInfo, colID int64) *model.ColumnInfo { - for _, col := range tbl.Columns { - if col.ID == colID { - return col - } - } - return nil -} - -type runtimeStatsWithSnapshot struct { - *txnsnapshot.SnapshotRuntimeStats -} - -func (e *runtimeStatsWithSnapshot) String() string { - if e.SnapshotRuntimeStats != nil { - return e.SnapshotRuntimeStats.String() - } - return "" -} - -// Clone implements the RuntimeStats interface. 
-func (e *runtimeStatsWithSnapshot) Clone() execdetails.RuntimeStats { - newRs := &runtimeStatsWithSnapshot{} - if e.SnapshotRuntimeStats != nil { - snapshotStats := e.SnapshotRuntimeStats.Clone() - newRs.SnapshotRuntimeStats = snapshotStats - } - return newRs -} - -// Merge implements the RuntimeStats interface. -func (e *runtimeStatsWithSnapshot) Merge(other execdetails.RuntimeStats) { - tmp, ok := other.(*runtimeStatsWithSnapshot) - if !ok { - return - } - if tmp.SnapshotRuntimeStats != nil { - if e.SnapshotRuntimeStats == nil { - snapshotStats := tmp.SnapshotRuntimeStats.Clone() - e.SnapshotRuntimeStats = snapshotStats - return - } - e.SnapshotRuntimeStats.Merge(tmp.SnapshotRuntimeStats) - } -} - -// Tp implements the RuntimeStats interface. -func (*runtimeStatsWithSnapshot) Tp() int { - return execdetails.TpRuntimeStatsWithSnapshot -} diff --git a/pkg/executor/projection.go b/pkg/executor/projection.go index 632e8bf586a5e..ce1113dde4776 100644 --- a/pkg/executor/projection.go +++ b/pkg/executor/projection.go @@ -106,11 +106,11 @@ func (e *ProjectionExec) Open(ctx context.Context) error { if err := e.BaseExecutorV2.Open(ctx); err != nil { return err } - if val, _err_ := failpoint.Eval(_curpkg_("mockProjectionExecBaseExecutorOpenReturnedError")); _err_ == nil { + failpoint.Inject("mockProjectionExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { if val.(bool) { - return errors.New("mock ProjectionExec.baseExecutor.Open returned error") + failpoint.Return(errors.New("mock ProjectionExec.baseExecutor.Open returned error")) } - } + }) return e.open(ctx) } @@ -216,7 +216,7 @@ func (e *ProjectionExec) unParallelExecute(ctx context.Context, chk *chunk.Chunk e.childResult.SetRequiredRows(chk.RequiredRows(), e.MaxChunkSize()) mSize := e.childResult.MemoryUsage() err := exec.Next(ctx, e.Children(0), e.childResult) - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + failpoint.Inject("ConsumeRandomPanic", nil) e.memTracker.Consume(e.childResult.MemoryUsage() - mSize) if err != nil { return err @@ -246,7 +246,7 @@ func (e *ProjectionExec) parallelExecute(ctx context.Context, chk *chunk.Chunk) } mSize := output.chk.MemoryUsage() chk.SwapColumns(output.chk) - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + failpoint.Inject("ConsumeRandomPanic", nil) e.memTracker.Consume(output.chk.MemoryUsage() - mSize) e.fetcher.outputCh <- output return nil @@ -280,7 +280,7 @@ func (e *ProjectionExec) prepare(ctx context.Context) { }) inputChk := exec.NewFirstChunk(e.Children(0)) - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + failpoint.Inject("ConsumeRandomPanic", nil) e.memTracker.Consume(inputChk.MemoryUsage()) e.fetcher.inputCh <- &projectionInput{ chk: inputChk, @@ -408,7 +408,7 @@ func (f *projectionInputFetcher) run(ctx context.Context) { input.chk.SetRequiredRows(int(requiredRows), f.proj.MaxChunkSize()) mSize := input.chk.MemoryUsage() err := exec.Next(ctx, f.child, input.chk) - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + failpoint.Inject("ConsumeRandomPanic", nil) f.proj.memTracker.Consume(input.chk.MemoryUsage() - mSize) if err != nil || input.chk.NumRows() == 0 { output.done <- err @@ -469,7 +469,7 @@ func (w *projectionWorker) run(ctx context.Context) { mSize := output.chk.MemoryUsage() + input.chk.MemoryUsage() err := w.evaluatorSuit.Run(w.ctx.evalCtx, w.ctx.enableVectorizedExpression, input.chk, output.chk) - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + failpoint.Inject("ConsumeRandomPanic", nil) w.proj.memTracker.Consume(output.chk.MemoryUsage() + input.chk.MemoryUsage() - 
mSize) output.done <- err diff --git a/pkg/executor/projection.go__failpoint_stash__ b/pkg/executor/projection.go__failpoint_stash__ deleted file mode 100644 index ce1113dde4776..0000000000000 --- a/pkg/executor/projection.go__failpoint_stash__ +++ /dev/null @@ -1,501 +0,0 @@ -// Copyright 2018 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package executor - -import ( - "context" - "fmt" - "runtime/trace" - "sync" - "sync/atomic" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "go.uber.org/zap" -) - -// This file contains the implementation of the physical Projection Operator: -// https://en.wikipedia.org/wiki/Projection_(relational_algebra) -// -// NOTE: -// 1. The number of "projectionWorker" is controlled by the global session -// variable "tidb_projection_concurrency". -// 2. Unparallel version is used when one of the following situations occurs: -// a. "tidb_projection_concurrency" is set to 0. -// b. The estimated input size is smaller than "tidb_max_chunk_size". -// c. This projection can not be executed vectorially. - -type projectionInput struct { - chk *chunk.Chunk - targetWorker *projectionWorker -} - -type projectionOutput struct { - chk *chunk.Chunk - done chan error -} - -// projectionExecutorContext is the execution context for the `ProjectionExec` -type projectionExecutorContext struct { - stmtMemTracker *memory.Tracker - stmtRuntimeStatsColl *execdetails.RuntimeStatsColl - evalCtx expression.EvalContext - enableVectorizedExpression bool -} - -func newProjectionExecutorContext(sctx sessionctx.Context) projectionExecutorContext { - return projectionExecutorContext{ - stmtMemTracker: sctx.GetSessionVars().StmtCtx.MemTracker, - stmtRuntimeStatsColl: sctx.GetSessionVars().StmtCtx.RuntimeStatsColl, - evalCtx: sctx.GetExprCtx().GetEvalCtx(), - enableVectorizedExpression: sctx.GetSessionVars().EnableVectorizedExpression, - } -} - -// ProjectionExec implements the physical Projection Operator: -// https://en.wikipedia.org/wiki/Projection_(relational_algebra) -type ProjectionExec struct { - projectionExecutorContext - exec.BaseExecutorV2 - - evaluatorSuit *expression.EvaluatorSuite - - finishCh chan struct{} - outputCh chan *projectionOutput - fetcher projectionInputFetcher - numWorkers int64 - workers []*projectionWorker - childResult *chunk.Chunk - - // parentReqRows indicates how many rows the parent executor is - // requiring. It is set when parallelExecute() is called and used by the - // concurrent projectionInputFetcher. - // - // NOTE: It should be protected by atomic operations. 
- parentReqRows int64 - - memTracker *memory.Tracker - wg *sync.WaitGroup - - calculateNoDelay bool - prepared bool -} - -// Open implements the Executor Open interface. -func (e *ProjectionExec) Open(ctx context.Context) error { - if err := e.BaseExecutorV2.Open(ctx); err != nil { - return err - } - failpoint.Inject("mockProjectionExecBaseExecutorOpenReturnedError", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(errors.New("mock ProjectionExec.baseExecutor.Open returned error")) - } - }) - return e.open(ctx) -} - -func (e *ProjectionExec) open(_ context.Context) error { - e.prepared = false - e.parentReqRows = int64(e.MaxChunkSize()) - - if e.memTracker != nil { - e.memTracker.Reset() - } else { - e.memTracker = memory.NewTracker(e.ID(), -1) - } - e.memTracker.AttachTo(e.stmtMemTracker) - - // For now a Projection can not be executed vectorially only because it - // contains "SetVar" or "GetVar" functions, in this scenario this - // Projection can not be executed parallelly. - if e.numWorkers > 0 && !e.evaluatorSuit.Vectorizable() { - e.numWorkers = 0 - } - - if e.isUnparallelExec() { - e.childResult = exec.TryNewCacheChunk(e.Children(0)) - e.memTracker.Consume(e.childResult.MemoryUsage()) - } - - e.wg = &sync.WaitGroup{} - - return nil -} - -// Next implements the Executor Next interface. -// -// Here we explain the execution flow of the parallel projection implementation. -// There are 3 main components: -// 1. "projectionInputFetcher": Fetch input "Chunk" from child. -// 2. "projectionWorker": Do the projection work. -// 3. "ProjectionExec.Next": Return result to parent. -// -// 1. "projectionInputFetcher" gets its input and output resources from its -// "inputCh" and "outputCh" channel, once the input and output resources are -// obtained, it fetches child's result into "input.chk" and: -// a. Dispatches this input to the worker specified in "input.targetWorker" -// b. Dispatches this output to the main thread: "ProjectionExec.Next" -// c. Dispatches this output to the worker specified in "input.targetWorker" -// It is finished and exited once: -// a. There is no more input from child. -// b. "ProjectionExec" close the "globalFinishCh" -// -// 2. "projectionWorker" gets its input and output resources from its -// "inputCh" and "outputCh" channel, once the input and output resources are -// abtained, it calculates the projection result use "input.chk" as the input -// and "output.chk" as the output, once the calculation is done, it: -// a. Sends "nil" or error to "output.done" to mark this input is finished. -// b. Returns the "input" resource to "projectionInputFetcher.inputCh" -// They are finished and exited once: -// a. "ProjectionExec" closes the "globalFinishCh" -// -// 3. "ProjectionExec.Next" gets its output resources from its "outputCh" channel. -// After receiving an output from "outputCh", it should wait to receive a "nil" -// or error from "output.done" channel. Once a "nil" or error is received: -// a. Returns this output to its parent -// b. Returns the "output" resource to "projectionInputFetcher.outputCh" -/* - +-----------+----------------------+--------------------------+ - | | | | - | +--------+---------+ +--------+---------+ +--------+---------+ - | | projectionWorker | + projectionWorker | ... 
+ projectionWorker | - | +------------------+ +------------------+ +------------------+ - | ^ ^ ^ ^ ^ ^ - | | | | | | | - | inputCh outputCh inputCh outputCh inputCh outputCh - | ^ ^ ^ ^ ^ ^ - | | | | | | | - | | | - | | +----------------->outputCh - | | | | - | | | v - | +-------+-------+--------+ +---------------------+ - | | projectionInputFetcher | | ProjectionExec.Next | - | +------------------------+ +---------+-----------+ - | ^ ^ | - | | | | - | inputCh outputCh | - | ^ ^ | - | | | | - +------------------------------+ +----------------------+ -*/ -func (e *ProjectionExec) Next(ctx context.Context, req *chunk.Chunk) error { - req.GrowAndReset(e.MaxChunkSize()) - if e.isUnparallelExec() { - return e.unParallelExecute(ctx, req) - } - return e.parallelExecute(ctx, req) -} - -func (e *ProjectionExec) isUnparallelExec() bool { - return e.numWorkers <= 0 -} - -func (e *ProjectionExec) unParallelExecute(ctx context.Context, chk *chunk.Chunk) error { - // transmit the requiredRows - e.childResult.SetRequiredRows(chk.RequiredRows(), e.MaxChunkSize()) - mSize := e.childResult.MemoryUsage() - err := exec.Next(ctx, e.Children(0), e.childResult) - failpoint.Inject("ConsumeRandomPanic", nil) - e.memTracker.Consume(e.childResult.MemoryUsage() - mSize) - if err != nil { - return err - } - if e.childResult.NumRows() == 0 { - return nil - } - err = e.evaluatorSuit.Run(e.evalCtx, e.enableVectorizedExpression, e.childResult, chk) - return err -} - -func (e *ProjectionExec) parallelExecute(ctx context.Context, chk *chunk.Chunk) error { - atomic.StoreInt64(&e.parentReqRows, int64(chk.RequiredRows())) - if !e.prepared { - e.prepare(ctx) - e.prepared = true - } - - output, ok := <-e.outputCh - if !ok { - return nil - } - - err := <-output.done - if err != nil { - return err - } - mSize := output.chk.MemoryUsage() - chk.SwapColumns(output.chk) - failpoint.Inject("ConsumeRandomPanic", nil) - e.memTracker.Consume(output.chk.MemoryUsage() - mSize) - e.fetcher.outputCh <- output - return nil -} - -func (e *ProjectionExec) prepare(ctx context.Context) { - e.finishCh = make(chan struct{}) - e.outputCh = make(chan *projectionOutput, e.numWorkers) - - // Initialize projectionInputFetcher. - e.fetcher = projectionInputFetcher{ - proj: e, - child: e.Children(0), - globalFinishCh: e.finishCh, - globalOutputCh: e.outputCh, - inputCh: make(chan *projectionInput, e.numWorkers), - outputCh: make(chan *projectionOutput, e.numWorkers), - } - - // Initialize projectionWorker. 
- e.workers = make([]*projectionWorker, 0, e.numWorkers) - for i := int64(0); i < e.numWorkers; i++ { - e.workers = append(e.workers, &projectionWorker{ - proj: e, - ctx: e.projectionExecutorContext, - evaluatorSuit: e.evaluatorSuit, - globalFinishCh: e.finishCh, - inputGiveBackCh: e.fetcher.inputCh, - inputCh: make(chan *projectionInput, 1), - outputCh: make(chan *projectionOutput, 1), - }) - - inputChk := exec.NewFirstChunk(e.Children(0)) - failpoint.Inject("ConsumeRandomPanic", nil) - e.memTracker.Consume(inputChk.MemoryUsage()) - e.fetcher.inputCh <- &projectionInput{ - chk: inputChk, - targetWorker: e.workers[i], - } - - outputChk := exec.NewFirstChunk(e) - e.memTracker.Consume(outputChk.MemoryUsage()) - e.fetcher.outputCh <- &projectionOutput{ - chk: outputChk, - done: make(chan error, 1), - } - } - - e.wg.Add(1) - go e.fetcher.run(ctx) - - for i := range e.workers { - e.wg.Add(1) - go e.workers[i].run(ctx) - } -} - -func (e *ProjectionExec) drainInputCh(ch chan *projectionInput) { - close(ch) - for item := range ch { - if item.chk != nil { - e.memTracker.Consume(-item.chk.MemoryUsage()) - } - } -} - -func (e *ProjectionExec) drainOutputCh(ch chan *projectionOutput) { - close(ch) - for item := range ch { - if item.chk != nil { - e.memTracker.Consume(-item.chk.MemoryUsage()) - } - } -} - -// Close implements the Executor Close interface. -func (e *ProjectionExec) Close() error { - // if e.BaseExecutor.Open returns error, e.childResult will be nil, see https://github.com/pingcap/tidb/issues/24210 - // for more information - if e.isUnparallelExec() && e.childResult != nil { - e.memTracker.Consume(-e.childResult.MemoryUsage()) - e.childResult = nil - } - if e.prepared { - close(e.finishCh) - e.wg.Wait() // Wait for fetcher and workers to finish and exit. - - // clear fetcher - e.drainInputCh(e.fetcher.inputCh) - e.drainOutputCh(e.fetcher.outputCh) - - // clear workers - for _, w := range e.workers { - e.drainInputCh(w.inputCh) - e.drainOutputCh(w.outputCh) - } - } - if e.BaseExecutorV2.RuntimeStats() != nil { - runtimeStats := &execdetails.RuntimeStatsWithConcurrencyInfo{} - if e.isUnparallelExec() { - runtimeStats.SetConcurrencyInfo(execdetails.NewConcurrencyInfo("Concurrency", 0)) - } else { - runtimeStats.SetConcurrencyInfo(execdetails.NewConcurrencyInfo("Concurrency", int(e.numWorkers))) - } - e.stmtRuntimeStatsColl.RegisterStats(e.ID(), runtimeStats) - } - return e.BaseExecutorV2.Close() -} - -type projectionInputFetcher struct { - proj *ProjectionExec - child exec.Executor - globalFinishCh <-chan struct{} - globalOutputCh chan<- *projectionOutput - - inputCh chan *projectionInput - outputCh chan *projectionOutput -} - -// run gets projectionInputFetcher's input and output resources from its -// "inputCh" and "outputCh" channel, once the input and output resources are -// abtained, it fetches child's result into "input.chk" and: -// -// a. Dispatches this input to the worker specified in "input.targetWorker" -// b. Dispatches this output to the main thread: "ProjectionExec.Next" -// c. Dispatches this output to the worker specified in "input.targetWorker" -// -// It is finished and exited once: -// -// a. There is no more input from child. -// b. 
"ProjectionExec" close the "globalFinishCh" -func (f *projectionInputFetcher) run(ctx context.Context) { - defer trace.StartRegion(ctx, "ProjectionFetcher").End() - var output *projectionOutput - defer func() { - if r := recover(); r != nil { - recoveryProjection(output, r) - } - close(f.globalOutputCh) - f.proj.wg.Done() - }() - - for { - input, isNil := readProjection[*projectionInput](f.inputCh, f.globalFinishCh) - if isNil { - return - } - targetWorker := input.targetWorker - - output, isNil = readProjection[*projectionOutput](f.outputCh, f.globalFinishCh) - if isNil { - f.proj.memTracker.Consume(-input.chk.MemoryUsage()) - return - } - - f.globalOutputCh <- output - - requiredRows := atomic.LoadInt64(&f.proj.parentReqRows) - input.chk.SetRequiredRows(int(requiredRows), f.proj.MaxChunkSize()) - mSize := input.chk.MemoryUsage() - err := exec.Next(ctx, f.child, input.chk) - failpoint.Inject("ConsumeRandomPanic", nil) - f.proj.memTracker.Consume(input.chk.MemoryUsage() - mSize) - if err != nil || input.chk.NumRows() == 0 { - output.done <- err - f.proj.memTracker.Consume(-input.chk.MemoryUsage()) - return - } - - targetWorker.inputCh <- input - targetWorker.outputCh <- output - } -} - -type projectionWorker struct { - proj *ProjectionExec - ctx projectionExecutorContext - evaluatorSuit *expression.EvaluatorSuite - globalFinishCh <-chan struct{} - inputGiveBackCh chan<- *projectionInput - - // channel "input" and "output" is : - // a. initialized by "ProjectionExec.prepare" - // b. written by "projectionInputFetcher.run" - // c. read by "projectionWorker.run" - inputCh chan *projectionInput - outputCh chan *projectionOutput -} - -// run gets projectionWorker's input and output resources from its -// "inputCh" and "outputCh" channel, once the input and output resources are -// abtained, it calculate the projection result use "input.chk" as the input -// and "output.chk" as the output, once the calculation is done, it: -// -// a. Sends "nil" or error to "output.done" to mark this input is finished. -// b. Returns the "input" resource to "projectionInputFetcher.inputCh". -// -// It is finished and exited once: -// -// a. "ProjectionExec" closes the "globalFinishCh". 
-func (w *projectionWorker) run(ctx context.Context) { - defer trace.StartRegion(ctx, "ProjectionWorker").End() - var output *projectionOutput - defer func() { - if r := recover(); r != nil { - recoveryProjection(output, r) - } - w.proj.wg.Done() - }() - for { - input, isNil := readProjection[*projectionInput](w.inputCh, w.globalFinishCh) - if isNil { - return - } - - output, isNil = readProjection[*projectionOutput](w.outputCh, w.globalFinishCh) - if isNil { - return - } - - mSize := output.chk.MemoryUsage() + input.chk.MemoryUsage() - err := w.evaluatorSuit.Run(w.ctx.evalCtx, w.ctx.enableVectorizedExpression, input.chk, output.chk) - failpoint.Inject("ConsumeRandomPanic", nil) - w.proj.memTracker.Consume(output.chk.MemoryUsage() + input.chk.MemoryUsage() - mSize) - output.done <- err - - if err != nil { - return - } - - w.inputGiveBackCh <- input - } -} - -func recoveryProjection(output *projectionOutput, r any) { - if output != nil { - output.done <- util.GetRecoverError(r) - } - logutil.BgLogger().Error("projection executor panicked", zap.String("error", fmt.Sprintf("%v", r)), zap.Stack("stack")) -} - -func readProjection[T any](ch <-chan T, finishCh <-chan struct{}) (t T, isNil bool) { - select { - case <-finishCh: - return t, true - case t, ok := <-ch: - if !ok { - return t, true - } - return t, false - } -} diff --git a/pkg/executor/shuffle.go b/pkg/executor/shuffle.go index fde7eb34e5d28..9c0ee0f8050b2 100644 --- a/pkg/executor/shuffle.go +++ b/pkg/executor/shuffle.go @@ -231,11 +231,11 @@ func (e *ShuffleExec) Next(ctx context.Context, req *chunk.Chunk) error { e.prepared = true } - if val, _err_ := failpoint.Eval(_curpkg_("shuffleError")); _err_ == nil { + failpoint.Inject("shuffleError", func(val failpoint.Value) { if val.(bool) { - return errors.New("ShuffleExec.Next error") + failpoint.Return(errors.New("ShuffleExec.Next error")) } - } + }) if e.executed { return nil @@ -279,12 +279,12 @@ func (e *ShuffleExec) fetchDataAndSplit(ctx context.Context, dataSourceIndex int waitGroup.Done() }() - if val, _err_ := failpoint.Eval(_curpkg_("shuffleExecFetchDataAndSplit")); _err_ == nil { + failpoint.Inject("shuffleExecFetchDataAndSplit", func(val failpoint.Value) { if val.(bool) { time.Sleep(100 * time.Millisecond) panic("shuffleExecFetchDataAndSplitPanic") } - } + }) for { err = exec.Next(ctx, e.dataSources[dataSourceIndex], chk) @@ -400,7 +400,7 @@ func (e *shuffleWorker) run(ctx context.Context, waitGroup *sync.WaitGroup) { waitGroup.Done() }() - failpoint.Eval(_curpkg_("shuffleWorkerRun")) + failpoint.Inject("shuffleWorkerRun", nil) for { select { case <-e.finishCh: diff --git a/pkg/executor/shuffle.go__failpoint_stash__ b/pkg/executor/shuffle.go__failpoint_stash__ deleted file mode 100644 index 9c0ee0f8050b2..0000000000000 --- a/pkg/executor/shuffle.go__failpoint_stash__ +++ /dev/null @@ -1,492 +0,0 @@ -// Copyright 2019 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package executor - -import ( - "context" - "sync" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/executor/aggregate" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/executor/internal/vecgroupchecker" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/channel" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/twmb/murmur3" - "go.uber.org/zap" -) - -// ShuffleExec is the executor to run other executors in a parallel manner. -// -// 1. It fetches chunks from M `DataSources` (value of M depends on the actual executor, e.g. M = 1 for WindowExec, M = 2 for MergeJoinExec). -// -// 2. It splits tuples from each `DataSource` into N partitions (Only "split by hash" is implemented so far). -// -// 3. It invokes N workers in parallel, each one has M `receiver` to receive partitions from `DataSources` -// -// 4. It assigns partitions received as input to each worker and executes child executors. -// -// 5. It collects outputs from each worker, then sends outputs to its parent. -// -// +-------------+ -// +-------| Main Thread | -// | +------+------+ -// | ^ -// | | -// | + -// v +++ -// outputHolderCh | | outputCh (1 x Concurrency) -// v +++ -// | ^ -// | | -// | +-------+-------+ -// v | | -// +--------------+ +--------------+ -// +----- | worker | ....... | worker | worker (N Concurrency): child executor, eg. WindowExec (+SortExec) -// | +------------+-+ +-+------------+ -// | ^ ^ -// | | | -// | +-+ +-+ ...... +-+ -// | | | | | | | -// | ... ... ... inputCh (Concurrency x 1) -// v | | | | | | -// inputHolderCh +++ +++ +++ -// v ^ ^ ^ -// | | | | -// | +------o----+ | -// | | +-----------------+-----+ -// | | | -// | +---+------------+------------+----+-----------+ -// | | Partition Splitter | -// | +--------------+-+------------+-+--------------+ -// | ^ -// | | -// | +---------------v-----------------+ -// +----------> | fetch data from DataSource | -// +---------------------------------+ -type ShuffleExec struct { - exec.BaseExecutor - concurrency int - workers []*shuffleWorker - - prepared bool - executed bool - - // each dataSource has a corresponding spliter - splitters []partitionSplitter - dataSources []exec.Executor - - finishCh chan struct{} - outputCh chan *shuffleOutput -} - -type shuffleOutput struct { - chk *chunk.Chunk - err error - giveBackCh chan *chunk.Chunk -} - -// Open implements the Executor Open interface. 
-func (e *ShuffleExec) Open(ctx context.Context) error { - for _, s := range e.dataSources { - if err := exec.Open(ctx, s); err != nil { - return err - } - } - if err := e.BaseExecutor.Open(ctx); err != nil { - return err - } - - e.prepared = false - e.finishCh = make(chan struct{}, 1) - e.outputCh = make(chan *shuffleOutput, e.concurrency+len(e.dataSources)) - - for _, w := range e.workers { - w.finishCh = e.finishCh - - for _, r := range w.receivers { - r.inputCh = make(chan *chunk.Chunk, 1) - r.inputHolderCh = make(chan *chunk.Chunk, 1) - } - - w.outputCh = e.outputCh - w.outputHolderCh = make(chan *chunk.Chunk, 1) - - if err := exec.Open(ctx, w.childExec); err != nil { - return err - } - - for i, r := range w.receivers { - r.inputHolderCh <- exec.NewFirstChunk(e.dataSources[i]) - } - w.outputHolderCh <- exec.NewFirstChunk(e) - } - - return nil -} - -// Close implements the Executor Close interface. -func (e *ShuffleExec) Close() error { - var firstErr error - if !e.prepared { - for _, w := range e.workers { - for _, r := range w.receivers { - if r.inputHolderCh != nil { - close(r.inputHolderCh) - } - if r.inputCh != nil { - close(r.inputCh) - } - } - if w.outputHolderCh != nil { - close(w.outputHolderCh) - } - } - if e.outputCh != nil { - close(e.outputCh) - } - } - if e.finishCh != nil { - close(e.finishCh) - } - for _, w := range e.workers { - for _, r := range w.receivers { - if r.inputCh != nil { - channel.Clear(r.inputCh) - } - } - // close child executor of each worker - if err := exec.Close(w.childExec); err != nil && firstErr == nil { - firstErr = err - } - } - if e.outputCh != nil { - channel.Clear(e.outputCh) - } - e.executed = false - - if e.RuntimeStats() != nil { - runtimeStats := &execdetails.RuntimeStatsWithConcurrencyInfo{} - runtimeStats.SetConcurrencyInfo(execdetails.NewConcurrencyInfo("ShuffleConcurrency", e.concurrency)) - e.Ctx().GetSessionVars().StmtCtx.RuntimeStatsColl.RegisterStats(e.ID(), runtimeStats) - } - - // close dataSources - for _, dataSource := range e.dataSources { - if err := exec.Close(dataSource); err != nil && firstErr == nil { - firstErr = err - } - } - // close baseExecutor - if err := e.BaseExecutor.Close(); err != nil && firstErr == nil { - firstErr = err - } - return errors.Trace(firstErr) -} - -func (e *ShuffleExec) prepare4ParallelExec(ctx context.Context) { - waitGroup := &sync.WaitGroup{} - waitGroup.Add(len(e.workers) + len(e.dataSources)) - // create a goroutine for each dataSource to fetch and split data - for i := range e.dataSources { - go e.fetchDataAndSplit(ctx, i, waitGroup) - } - - for _, w := range e.workers { - go w.run(ctx, waitGroup) - } - - go e.waitWorkerAndCloseOutput(waitGroup) -} - -func (e *ShuffleExec) waitWorkerAndCloseOutput(waitGroup *sync.WaitGroup) { - waitGroup.Wait() - close(e.outputCh) -} - -// Next implements the Executor Next interface. -func (e *ShuffleExec) Next(ctx context.Context, req *chunk.Chunk) error { - req.Reset() - if !e.prepared { - e.prepare4ParallelExec(ctx) - e.prepared = true - } - - failpoint.Inject("shuffleError", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(errors.New("ShuffleExec.Next error")) - } - }) - - if e.executed { - return nil - } - - result, ok := <-e.outputCh - if !ok { - e.executed = true - return nil - } - if result.err != nil { - return result.err - } - req.SwapColumns(result.chk) // `shuffleWorker` will not send an empty `result.chk` to `e.outputCh`. 
- result.giveBackCh <- result.chk - - return nil -} - -func recoveryShuffleExec(output chan *shuffleOutput, r any) { - err := util.GetRecoverError(r) - output <- &shuffleOutput{err: util.GetRecoverError(r)} - logutil.BgLogger().Error("shuffle panicked", zap.Error(err), zap.Stack("stack")) -} - -func (e *ShuffleExec) fetchDataAndSplit(ctx context.Context, dataSourceIndex int, waitGroup *sync.WaitGroup) { - var ( - err error - workerIndices []int - ) - results := make([]*chunk.Chunk, len(e.workers)) - chk := exec.TryNewCacheChunk(e.dataSources[dataSourceIndex]) - - defer func() { - if r := recover(); r != nil { - recoveryShuffleExec(e.outputCh, r) - } - for _, w := range e.workers { - close(w.receivers[dataSourceIndex].inputCh) - } - waitGroup.Done() - }() - - failpoint.Inject("shuffleExecFetchDataAndSplit", func(val failpoint.Value) { - if val.(bool) { - time.Sleep(100 * time.Millisecond) - panic("shuffleExecFetchDataAndSplitPanic") - } - }) - - for { - err = exec.Next(ctx, e.dataSources[dataSourceIndex], chk) - if err != nil { - e.outputCh <- &shuffleOutput{err: err} - return - } - if chk.NumRows() == 0 { - break - } - - workerIndices, err = e.splitters[dataSourceIndex].split(e.Ctx(), chk, workerIndices) - if err != nil { - e.outputCh <- &shuffleOutput{err: err} - return - } - numRows := chk.NumRows() - for i := 0; i < numRows; i++ { - workerIdx := workerIndices[i] - w := e.workers[workerIdx] - - if results[workerIdx] == nil { - select { - case <-e.finishCh: - return - case results[workerIdx] = <-w.receivers[dataSourceIndex].inputHolderCh: - //nolint: revive - break - } - } - results[workerIdx].AppendRow(chk.GetRow(i)) - if results[workerIdx].IsFull() { - w.receivers[dataSourceIndex].inputCh <- results[workerIdx] - results[workerIdx] = nil - } - } - } - for i, w := range e.workers { - if results[i] != nil { - w.receivers[dataSourceIndex].inputCh <- results[i] - results[i] = nil - } - } -} - -var _ exec.Executor = &shuffleReceiver{} - -// shuffleReceiver receives chunk from dataSource through inputCh -type shuffleReceiver struct { - exec.BaseExecutor - - finishCh <-chan struct{} - executed bool - - inputCh chan *chunk.Chunk - inputHolderCh chan *chunk.Chunk -} - -// Open implements the Executor Open interface. -func (e *shuffleReceiver) Open(ctx context.Context) error { - if err := e.BaseExecutor.Open(ctx); err != nil { - return err - } - e.executed = false - return nil -} - -// Close implements the Executor Close interface. -func (e *shuffleReceiver) Close() error { - return errors.Trace(e.BaseExecutor.Close()) -} - -// Next implements the Executor Next interface. -// It is called by `Tail` executor within "shuffle", to fetch data from `DataSource` by `inputCh`. -func (e *shuffleReceiver) Next(_ context.Context, req *chunk.Chunk) error { - req.Reset() - if e.executed { - return nil - } - select { - case <-e.finishCh: - e.executed = true - return nil - case result, ok := <-e.inputCh: - if !ok || result.NumRows() == 0 { - e.executed = true - return nil - } - req.SwapColumns(result) - e.inputHolderCh <- result - return nil - } -} - -// shuffleWorker is the multi-thread worker executing child executors within "partition". 
-type shuffleWorker struct { - childExec exec.Executor - - finishCh <-chan struct{} - - // each receiver corresponse to a dataSource - receivers []*shuffleReceiver - - outputCh chan *shuffleOutput - outputHolderCh chan *chunk.Chunk -} - -func (e *shuffleWorker) run(ctx context.Context, waitGroup *sync.WaitGroup) { - defer func() { - if r := recover(); r != nil { - recoveryShuffleExec(e.outputCh, r) - } - waitGroup.Done() - }() - - failpoint.Inject("shuffleWorkerRun", nil) - for { - select { - case <-e.finishCh: - return - case chk := <-e.outputHolderCh: - if err := exec.Next(ctx, e.childExec, chk); err != nil { - e.outputCh <- &shuffleOutput{err: err} - return - } - - // Should not send an empty `chk` to `e.outputCh`. - if chk.NumRows() == 0 { - return - } - e.outputCh <- &shuffleOutput{chk: chk, giveBackCh: e.outputHolderCh} - } - } -} - -var _ partitionSplitter = &partitionHashSplitter{} -var _ partitionSplitter = &partitionRangeSplitter{} - -type partitionSplitter interface { - split(ctx sessionctx.Context, input *chunk.Chunk, workerIndices []int) ([]int, error) -} - -type partitionHashSplitter struct { - byItems []expression.Expression - numWorkers int - hashKeys [][]byte -} - -func (s *partitionHashSplitter) split(ctx sessionctx.Context, input *chunk.Chunk, workerIndices []int) ([]int, error) { - var err error - s.hashKeys, err = aggregate.GetGroupKey(ctx, input, s.hashKeys, s.byItems) - if err != nil { - return workerIndices, err - } - workerIndices = workerIndices[:0] - numRows := input.NumRows() - for i := 0; i < numRows; i++ { - workerIndices = append(workerIndices, int(murmur3.Sum32(s.hashKeys[i]))%s.numWorkers) - } - return workerIndices, nil -} - -func buildPartitionHashSplitter(concurrency int, byItems []expression.Expression) *partitionHashSplitter { - return &partitionHashSplitter{ - byItems: byItems, - numWorkers: concurrency, - } -} - -type partitionRangeSplitter struct { - byItems []expression.Expression - numWorkers int - groupChecker *vecgroupchecker.VecGroupChecker - idx int -} - -func buildPartitionRangeSplitter(ctx sessionctx.Context, concurrency int, byItems []expression.Expression) *partitionRangeSplitter { - return &partitionRangeSplitter{ - byItems: byItems, - numWorkers: concurrency, - groupChecker: vecgroupchecker.NewVecGroupChecker(ctx.GetExprCtx().GetEvalCtx(), ctx.GetSessionVars().EnableVectorizedExpression, byItems), - idx: 0, - } -} - -// This method is supposed to be used for shuffle with sorted `dataSource` -// the caller of this method should guarantee that `input` is grouped, -// which means that rows with the same byItems should be continuous, the order does not matter. 
-func (s *partitionRangeSplitter) split(_ sessionctx.Context, input *chunk.Chunk, workerIndices []int) ([]int, error) { - _, err := s.groupChecker.SplitIntoGroups(input) - if err != nil { - return workerIndices, err - } - - workerIndices = workerIndices[:0] - for !s.groupChecker.IsExhausted() { - begin, end := s.groupChecker.GetNextGroup() - for i := begin; i < end; i++ { - workerIndices = append(workerIndices, s.idx) - } - s.idx = (s.idx + 1) % s.numWorkers - } - - return workerIndices, nil -} diff --git a/pkg/executor/slow_query.go b/pkg/executor/slow_query.go index 87da1791d0b19..203d72d8f0e61 100644 --- a/pkg/executor/slow_query.go +++ b/pkg/executor/slow_query.go @@ -468,13 +468,13 @@ func (e *slowQueryRetriever) parseSlowLog(ctx context.Context, sctx sessionctx.C if e.stats != nil { e.stats.readFile += time.Since(startTime) } - if val, _err_ := failpoint.Eval(_curpkg_("mockReadSlowLogSlow")); _err_ == nil { + failpoint.Inject("mockReadSlowLogSlow", func(val failpoint.Value) { if val.(bool) { signals := ctx.Value(signalsKey{}).([]chan int) signals[0] <- 1 <-signals[1] } - } + }) for i := range logs { log := logs[i] t := slowLogTask{} @@ -636,11 +636,11 @@ func (e *slowQueryRetriever) parseLog(ctx context.Context, sctx sessionctx.Conte } }() e.memConsume(logSize) - if val, _err_ := failpoint.Eval(_curpkg_("errorMockParseSlowLogPanic")); _err_ == nil { + failpoint.Inject("errorMockParseSlowLogPanic", func(val failpoint.Value) { if val.(bool) { panic("panic test") } - } + }) var row []types.Datum user := "" tz := sctx.GetSessionVars().Location() diff --git a/pkg/executor/slow_query.go__failpoint_stash__ b/pkg/executor/slow_query.go__failpoint_stash__ deleted file mode 100644 index 203d72d8f0e61..0000000000000 --- a/pkg/executor/slow_query.go__failpoint_stash__ +++ /dev/null @@ -1,1259 +0,0 @@ -// Copyright 2019 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
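The slow_query.go hunks above revert the generated `if val, _err_ := failpoint.Eval(_curpkg_(...))` form back to the failpoint.Inject marker, which is the form kept in source; failpoint-ctl expands the markers (and writes the *__failpoint_stash__ copies being deleted in this patch) when failpoints are enabled for a build. A hedged sketch of how such a marker is typically written and activated from a test; the failpoint name and package path here are invented for illustration:

package demo

import "github.com/pingcap/failpoint"

func doWork() int {
	// No-op unless the failpoint is activated, e.g. from a test via
	// failpoint.Enable("<pkgpath>/mockSlowWork", "return(true)") and
	// reverted afterwards with failpoint.Disable.
	failpoint.Inject("mockSlowWork", func(val failpoint.Value) {
		if val.(bool) {
			panic("injected failure") // simulate the failing path
		}
	})
	return 42
}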
- -package executor - -import ( - "bufio" - "compress/gzip" - "context" - "fmt" - "io" - "os" - "path/filepath" - "runtime" - "slices" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/parser/auth" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/privilege" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/hack" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "github.com/pingcap/tidb/pkg/util/plancodec" - "go.uber.org/zap" -) - -type signalsKey struct{} - -// ParseSlowLogBatchSize is the batch size of slow-log lines for a worker to parse, exported for testing. -var ParseSlowLogBatchSize = 64 - -// slowQueryRetriever is used to read slow log data. -type slowQueryRetriever struct { - table *model.TableInfo - outputCols []*model.ColumnInfo - initialized bool - extractor *plannercore.SlowQueryExtractor - files []logFile - fileIdx int - fileLine int - checker *slowLogChecker - columnValueFactoryMap map[string]slowQueryColumnValueFactory - instanceFactory func([]types.Datum) - - taskList chan slowLogTask - stats *slowQueryRuntimeStats - memTracker *memory.Tracker - lastFetchSize int64 - cancel context.CancelFunc - wg sync.WaitGroup -} - -func (e *slowQueryRetriever) retrieve(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) { - if !e.initialized { - err := e.initialize(ctx, sctx) - if err != nil { - return nil, err - } - ctx, e.cancel = context.WithCancel(ctx) - e.initializeAsyncParsing(ctx, sctx) - } - return e.dataForSlowLog(ctx) -} - -func (e *slowQueryRetriever) initialize(ctx context.Context, sctx sessionctx.Context) error { - var err error - var hasProcessPriv bool - if pm := privilege.GetPrivilegeManager(sctx); pm != nil { - hasProcessPriv = pm.RequestVerification(sctx.GetSessionVars().ActiveRoles, "", "", "", mysql.ProcessPriv) - } - // initialize column value factories. - e.columnValueFactoryMap = make(map[string]slowQueryColumnValueFactory, len(e.outputCols)) - for idx, col := range e.outputCols { - if col.Name.O == util.ClusterTableInstanceColumnName { - e.instanceFactory, err = getInstanceColumnValueFactory(sctx, idx) - if err != nil { - return err - } - continue - } - factory, err := getColumnValueFactoryByName(col.Name.O, idx) - if err != nil { - return err - } - if factory == nil { - panic(fmt.Sprintf("should never happen, should register new column %v into getColumnValueFactoryByName function", col.Name.O)) - } - e.columnValueFactoryMap[col.Name.O] = factory - } - // initialize checker. 
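initialize above builds columnValueFactoryMap, one closure per output column, so that turning a "field: value" pair into a datum is a single map lookup plus a call with the pre-bound column index. A minimal sketch of the same factory-map pattern with plain Go types (the field names and row layout are illustrative only):

package main

import (
	"fmt"
	"strconv"
)

type rowSetter func(row []any, value string) error

// buildFactories maps a slow-log field name to a closure that already knows
// the column index and how to convert the raw text.
func buildFactories() map[string]rowSetter {
	return map[string]rowSetter{
		"Query_time": func(row []any, v string) error {
			f, err := strconv.ParseFloat(v, 64)
			row[0] = f
			return err
		},
		"Succ": func(row []any, v string) error {
			b, err := strconv.ParseBool(v)
			row[1] = b
			return err
		},
	}
}

func main() {
	row := make([]any, 2)
	fs := buildFactories()
	_ = fs["Query_time"](row, "0.25")
	_ = fs["Succ"](row, "true")
	fmt.Println(row) // [0.25 true]
}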
- e.checker = &slowLogChecker{ - hasProcessPriv: hasProcessPriv, - user: sctx.GetSessionVars().User, - } - e.stats = &slowQueryRuntimeStats{} - if e.extractor != nil { - e.checker.enableTimeCheck = e.extractor.Enable - for _, tr := range e.extractor.TimeRanges { - startTime := types.NewTime(types.FromGoTime(tr.StartTime.In(sctx.GetSessionVars().Location())), mysql.TypeDatetime, types.MaxFsp) - endTime := types.NewTime(types.FromGoTime(tr.EndTime.In(sctx.GetSessionVars().Location())), mysql.TypeDatetime, types.MaxFsp) - timeRange := &timeRange{ - startTime: startTime, - endTime: endTime, - } - e.checker.timeRanges = append(e.checker.timeRanges, timeRange) - } - } else { - e.extractor = &plannercore.SlowQueryExtractor{} - } - e.initialized = true - e.files, err = e.getAllFiles(ctx, sctx, sctx.GetSessionVars().SlowQueryFile) - if e.extractor.Desc { - slices.Reverse(e.files) - } - return err -} - -func (e *slowQueryRetriever) close() error { - for _, f := range e.files { - err := f.file.Close() - if err != nil { - logutil.BgLogger().Error("close slow log file failed.", zap.Error(err)) - } - } - if e.cancel != nil { - e.cancel() - } - e.wg.Wait() - return nil -} - -type parsedSlowLog struct { - rows [][]types.Datum - err error -} - -func (e *slowQueryRetriever) getNextFile() *logFile { - if e.fileIdx >= len(e.files) { - return nil - } - ret := &e.files[e.fileIdx] - file := e.files[e.fileIdx].file - e.fileIdx++ - if e.stats != nil { - stat, err := file.Stat() - if err == nil { - // ignore the err will be ok. - e.stats.readFileSize += stat.Size() - e.stats.readFileNum++ - } - } - return ret -} - -func (e *slowQueryRetriever) getPreviousReader() (*bufio.Reader, error) { - fileIdx := e.fileIdx - // fileIdx refer to the next file which should be read - // so we need to set fileIdx to fileIdx - 2 to get the previous file. 
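The checker built above keeps a list of [startTime, endTime] ranges from the extractor; a parsed entry (or a candidate log file) is kept only if it intersects at least one of them. A small sketch of that membership test using time.Time directly; note the real checker returns !enableTimeCheck when no filter is configured, which this simplification approximates with an empty-slice check:

package main

import (
	"fmt"
	"time"
)

type timeRange struct{ start, end time.Time }

// inAnyRange reports whether t falls inside at least one closed range.
func inAnyRange(t time.Time, ranges []timeRange) bool {
	for _, r := range ranges {
		if !t.Before(r.start) && !t.After(r.end) {
			return true
		}
	}
	return len(ranges) == 0 // no filter configured means everything passes
}

func main() {
	base := time.Date(2024, 8, 7, 0, 0, 0, 0, time.UTC)
	ranges := []timeRange{{start: base, end: base.Add(time.Hour)}}
	fmt.Println(inAnyRange(base.Add(30*time.Minute), ranges)) // true
	fmt.Println(inAnyRange(base.Add(2*time.Hour), ranges))    // false
}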
- fileIdx = fileIdx - 2 - if fileIdx < 0 { - return nil, nil - } - file := e.files[fileIdx] - _, err := file.file.Seek(0, io.SeekStart) - if err != nil { - return nil, err - } - var reader *bufio.Reader - if !file.compressed { - reader = bufio.NewReader(file.file) - } else { - gr, err := gzip.NewReader(file.file) - if err != nil { - return nil, err - } - reader = bufio.NewReader(gr) - } - return reader, nil -} - -func (e *slowQueryRetriever) getNextReader() (*bufio.Reader, error) { - file := e.getNextFile() - if file == nil { - return nil, nil - } - var reader *bufio.Reader - if !file.compressed { - reader = bufio.NewReader(file.file) - } else { - gr, err := gzip.NewReader(file.file) - if err != nil { - return nil, err - } - reader = bufio.NewReader(gr) - } - return reader, nil -} - -func (e *slowQueryRetriever) parseDataForSlowLog(ctx context.Context, sctx sessionctx.Context) { - defer e.wg.Done() - reader, _ := e.getNextReader() - if reader == nil { - close(e.taskList) - return - } - e.parseSlowLog(ctx, sctx, reader, ParseSlowLogBatchSize) -} - -func (e *slowQueryRetriever) dataForSlowLog(ctx context.Context) ([][]types.Datum, error) { - var ( - task slowLogTask - ok bool - ) - e.memConsume(-e.lastFetchSize) - e.lastFetchSize = 0 - for { - select { - case task, ok = <-e.taskList: - case <-ctx.Done(): - return nil, ctx.Err() - } - if !ok { - return nil, nil - } - result := <-task.resultCh - rows, err := result.rows, result.err - if err != nil { - return nil, err - } - if len(rows) == 0 { - continue - } - if e.instanceFactory != nil { - for i := range rows { - e.instanceFactory(rows[i]) - } - } - e.lastFetchSize = calculateDatumsSize(rows) - return rows, nil - } -} - -type slowLogChecker struct { - // Below fields is used to check privilege. - hasProcessPriv bool - user *auth.UserIdentity - // Below fields is used to check slow log time valid. 
- enableTimeCheck bool - timeRanges []*timeRange -} - -type timeRange struct { - startTime types.Time - endTime types.Time -} - -func (sc *slowLogChecker) hasPrivilege(userName string) bool { - return sc.hasProcessPriv || sc.user == nil || userName == sc.user.Username -} - -func (sc *slowLogChecker) isTimeValid(t types.Time) bool { - for _, tr := range sc.timeRanges { - if sc.enableTimeCheck && (t.Compare(tr.startTime) >= 0 && t.Compare(tr.endTime) <= 0) { - return true - } - } - return !sc.enableTimeCheck -} - -func getOneLine(reader *bufio.Reader) ([]byte, error) { - return util.ReadLine(reader, int(variable.MaxOfMaxAllowedPacket)) -} - -type offset struct { - offset int - length int -} - -type slowLogTask struct { - resultCh chan parsedSlowLog -} - -type slowLogBlock []string - -func (e *slowQueryRetriever) getBatchLog(ctx context.Context, reader *bufio.Reader, offset *offset, num int) ([][]string, error) { - var line string - log := make([]string, 0, num) - var err error - for i := 0; i < num; i++ { - for { - if isCtxDone(ctx) { - return nil, ctx.Err() - } - e.fileLine++ - lineByte, err := getOneLine(reader) - if err != nil { - if err == io.EOF { - e.fileLine = 0 - newReader, err := e.getNextReader() - if newReader == nil || err != nil { - return [][]string{log}, err - } - offset.length = len(log) - reader.Reset(newReader) - continue - } - return [][]string{log}, err - } - line = string(hack.String(lineByte)) - log = append(log, line) - if strings.HasSuffix(line, variable.SlowLogSQLSuffixStr) { - if strings.HasPrefix(line, "use") || strings.HasPrefix(line, variable.SlowLogRowPrefixStr) { - continue - } - break - } - } - } - return [][]string{log}, err -} - -func (e *slowQueryRetriever) getBatchLogForReversedScan(ctx context.Context, reader *bufio.Reader, offset *offset, num int) ([][]string, error) { - // reader maybe change when read previous file. - inputReader := reader - defer func() { - newReader, _ := e.getNextReader() - if newReader != nil { - inputReader.Reset(newReader) - } - }() - var line string - var logs []slowLogBlock - var log []string - var err error - hasStartFlag := false - scanPreviousFile := false - for { - if isCtxDone(ctx) { - return nil, ctx.Err() - } - e.fileLine++ - lineByte, err := getOneLine(reader) - if err != nil { - if err == io.EOF { - if len(log) == 0 { - decomposedSlowLogTasks := decomposeToSlowLogTasks(logs, num) - offset.length = len(decomposedSlowLogTasks) - return decomposedSlowLogTasks, nil - } - e.fileLine = 0 - reader, err = e.getPreviousReader() - if reader == nil || err != nil { - return decomposeToSlowLogTasks(logs, num), nil - } - scanPreviousFile = true - continue - } - return nil, err - } - line = string(hack.String(lineByte)) - if !hasStartFlag && strings.HasPrefix(line, variable.SlowLogStartPrefixStr) { - hasStartFlag = true - } - if hasStartFlag { - log = append(log, line) - if strings.HasSuffix(line, variable.SlowLogSQLSuffixStr) { - if strings.HasPrefix(line, "use") || strings.HasPrefix(line, variable.SlowLogRowPrefixStr) { - continue - } - logs = append(logs, log) - if scanPreviousFile { - break - } - log = make([]string, 0, 8) - hasStartFlag = false - } - } - } - return decomposeToSlowLogTasks(logs, num), err -} - -func decomposeToSlowLogTasks(logs []slowLogBlock, num int) [][]string { - if len(logs) == 0 { - return nil - } - - //In reversed scan, We should reverse the blocks. 
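getBatchLog above accumulates raw lines until it reaches a line ending in the SQL suffix, skipping "use ..." statements and "# ..." rows, so each parse task is a self-contained slow-log entry. A simplified sketch of that grouping with the start prefix and suffix hard-coded ("# Time:" and ";" stand in for the variable.SlowLog* constants used in the source):

package main

import (
	"bufio"
	"fmt"
	"strings"
)

// splitEntries groups raw slow-log lines into one slice per query:
// an entry ends at the first line that looks like the query text itself.
func splitEntries(r *bufio.Reader) ([][]string, error) {
	var entries [][]string
	var cur []string
	for {
		line, err := r.ReadString('\n')
		line = strings.TrimRight(line, "\n")
		if line != "" {
			cur = append(cur, line)
			if strings.HasSuffix(line, ";") && !strings.HasPrefix(line, "#") && !strings.HasPrefix(line, "use") {
				entries = append(entries, cur)
				cur = nil
			}
		}
		if err != nil { // io.EOF ends the scan
			return entries, nil
		}
	}
}

func main() {
	log := "# Time: 2024-08-07T00:00:00Z\n# Query_time: 0.2\nselect 1;\n# Time: 2024-08-07T00:01:00Z\nselect 2;\n"
	entries, _ := splitEntries(bufio.NewReader(strings.NewReader(log)))
	fmt.Println(len(entries)) // 2
}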
- last := len(logs) - 1 - for i := 0; i < len(logs)/2; i++ { - logs[i], logs[last-i] = logs[last-i], logs[i] - } - - decomposedSlowLogTasks := make([][]string, 0) - log := make([]string, 0, num*len(logs[0])) - for i := range logs { - log = append(log, logs[i]...) - if i > 0 && i%num == 0 { - decomposedSlowLogTasks = append(decomposedSlowLogTasks, log) - log = make([]string, 0, len(log)) - } - } - if len(log) > 0 { - decomposedSlowLogTasks = append(decomposedSlowLogTasks, log) - } - return decomposedSlowLogTasks -} - -func (e *slowQueryRetriever) parseSlowLog(ctx context.Context, sctx sessionctx.Context, reader *bufio.Reader, logNum int) { - defer close(e.taskList) - offset := offset{offset: 0, length: 0} - // To limit the num of go routine - concurrent := sctx.GetSessionVars().Concurrency.DistSQLScanConcurrency() - ch := make(chan int, concurrent) - if e.stats != nil { - e.stats.concurrent = concurrent - } - defer close(ch) - for { - startTime := time.Now() - var logs [][]string - var err error - if !e.extractor.Desc { - logs, err = e.getBatchLog(ctx, reader, &offset, logNum) - } else { - logs, err = e.getBatchLogForReversedScan(ctx, reader, &offset, logNum) - } - if err != nil { - t := slowLogTask{} - t.resultCh = make(chan parsedSlowLog, 1) - select { - case <-ctx.Done(): - return - case e.taskList <- t: - } - e.sendParsedSlowLogCh(t, parsedSlowLog{nil, err}) - } - if len(logs) == 0 || len(logs[0]) == 0 { - break - } - if e.stats != nil { - e.stats.readFile += time.Since(startTime) - } - failpoint.Inject("mockReadSlowLogSlow", func(val failpoint.Value) { - if val.(bool) { - signals := ctx.Value(signalsKey{}).([]chan int) - signals[0] <- 1 - <-signals[1] - } - }) - for i := range logs { - log := logs[i] - t := slowLogTask{} - t.resultCh = make(chan parsedSlowLog, 1) - start := offset - ch <- 1 - select { - case <-ctx.Done(): - return - case e.taskList <- t: - } - e.wg.Add(1) - go func() { - defer e.wg.Done() - result, err := e.parseLog(ctx, sctx, log, start) - e.sendParsedSlowLogCh(t, parsedSlowLog{result, err}) - <-ch - }() - offset.offset = e.fileLine - offset.length = 0 - select { - case <-ctx.Done(): - return - default: - } - } - } -} - -func (*slowQueryRetriever) sendParsedSlowLogCh(t slowLogTask, re parsedSlowLog) { - select { - case t.resultCh <- re: - default: - return - } -} - -func getLineIndex(offset offset, index int) int { - var fileLine int - if offset.length <= index { - fileLine = index - offset.length + 1 - } else { - fileLine = offset.offset + index + 1 - } - return fileLine -} - -// findMatchedRightBracket returns the rightBracket index which matchs line[leftBracketIdx] -// leftBracketIdx should be valid string index for line -// Returns -1 if invalid inputs are given -func findMatchedRightBracket(line string, leftBracketIdx int) int { - leftBracket := line[leftBracketIdx] - rightBracket := byte('}') - if leftBracket == '[' { - rightBracket = ']' - } else if leftBracket != '{' { - return -1 - } - lineLength := len(line) - current := leftBracketIdx - leftBracketCnt := 0 - for current < lineLength { - b := line[current] - if b == leftBracket { - leftBracketCnt++ - current++ - } else if b == rightBracket { - leftBracketCnt-- - if leftBracketCnt > 0 { - current++ - } else if leftBracketCnt == 0 { - if current+1 < lineLength && line[current+1] != ' ' { - return -1 - } - return current - } else { - return -1 - } - } else { - current++ - } - } - return -1 -} - -func isLetterOrNumeric(b byte) bool { - return ('A' <= b && b <= 'Z') || ('a' <= b && b <= 'z') || ('0' <= b && b <= 
'9') -} - -// splitByColon split a line like "field: value field: value..." -// Note: -// 1. field string's first character can only be ASCII letters or digits, and can't contain ':' -// 2. value string may be surrounded by brackets, allowed brackets includes "[]" and "{}", like {key: value,{key: value}} -// "[]" can only be nested inside "[]"; "{}" can only be nested inside "{}" -// 3. value string can't contain ' ' character unless it is inside brackets -func splitByColon(line string) (fields []string, values []string) { - fields = make([]string, 0, 1) - values = make([]string, 0, 1) - - lineLength := len(line) - parseKey := true - start := 0 - errMsg := "" - for current := 0; current < lineLength; { - if parseKey { - // Find key start - for current < lineLength && !isLetterOrNumeric(line[current]) { - current++ - } - start = current - if current >= lineLength { - break - } - for current < lineLength && line[current] != ':' { - current++ - } - fields = append(fields, line[start:current]) - parseKey = false - current += 2 // bypass ": " - } else { - start = current - if current < lineLength && (line[current] == '{' || line[current] == '[') { - rBraceIdx := findMatchedRightBracket(line, current) - if rBraceIdx == -1 { - errMsg = "Braces matched error" - break - } - current = rBraceIdx + 1 - } else { - for current < lineLength && line[current] != ' ' { - current++ - } - } - values = append(values, line[start:min(current, len(line))]) - parseKey = true - } - } - if len(errMsg) > 0 { - logutil.BgLogger().Warn("slow query parse slow log error", zap.String("Error", errMsg), zap.String("Log", line)) - return nil, nil - } - return fields, values -} - -func (e *slowQueryRetriever) parseLog(ctx context.Context, sctx sessionctx.Context, log []string, offset offset) (data [][]types.Datum, err error) { - start := time.Now() - logSize := calculateLogSize(log) - defer e.memConsume(-logSize) - defer func() { - if r := recover(); r != nil { - err = util.GetRecoverError(r) - buf := make([]byte, 4096) - stackSize := runtime.Stack(buf, false) - buf = buf[:stackSize] - logutil.BgLogger().Warn("slow query parse slow log panic", zap.Error(err), zap.String("stack", string(buf))) - } - if e.stats != nil { - atomic.AddInt64(&e.stats.parseLog, int64(time.Since(start))) - } - }() - e.memConsume(logSize) - failpoint.Inject("errorMockParseSlowLogPanic", func(val failpoint.Value) { - if val.(bool) { - panic("panic test") - } - }) - var row []types.Datum - user := "" - tz := sctx.GetSessionVars().Location() - startFlag := false - for index, line := range log { - if isCtxDone(ctx) { - return nil, ctx.Err() - } - fileLine := getLineIndex(offset, index) - if !startFlag && strings.HasPrefix(line, variable.SlowLogStartPrefixStr) { - row = make([]types.Datum, len(e.outputCols)) - user = "" - valid := e.setColumnValue(sctx, row, tz, variable.SlowLogTimeStr, line[len(variable.SlowLogStartPrefixStr):], e.checker, fileLine) - if valid { - startFlag = true - } - continue - } - if startFlag { - if strings.HasPrefix(line, variable.SlowLogRowPrefixStr) { - line = line[len(variable.SlowLogRowPrefixStr):] - valid := true - if strings.HasPrefix(line, variable.SlowLogPrevStmtPrefix) { - valid = e.setColumnValue(sctx, row, tz, variable.SlowLogPrevStmt, line[len(variable.SlowLogPrevStmtPrefix):], e.checker, fileLine) - } else if strings.HasPrefix(line, variable.SlowLogUserAndHostStr+variable.SlowLogSpaceMarkStr) { - value := line[len(variable.SlowLogUserAndHostStr+variable.SlowLogSpaceMarkStr):] - fields := strings.SplitN(value, "@", 2) - 
if len(fields) < 2 { - continue - } - user = parseUserOrHostValue(fields[0]) - if e.checker != nil && !e.checker.hasPrivilege(user) { - startFlag = false - continue - } - valid = e.setColumnValue(sctx, row, tz, variable.SlowLogUserStr, user, e.checker, fileLine) - if !valid { - startFlag = false - continue - } - host := parseUserOrHostValue(fields[1]) - valid = e.setColumnValue(sctx, row, tz, variable.SlowLogHostStr, host, e.checker, fileLine) - } else if strings.HasPrefix(line, variable.SlowLogCopBackoffPrefix) { - valid = e.setColumnValue(sctx, row, tz, variable.SlowLogBackoffDetail, line, e.checker, fileLine) - } else if strings.HasPrefix(line, variable.SlowLogWarnings) { - line = line[len(variable.SlowLogWarnings+variable.SlowLogSpaceMarkStr):] - valid = e.setColumnValue(sctx, row, tz, variable.SlowLogWarnings, line, e.checker, fileLine) - } else { - fields, values := splitByColon(line) - for i := 0; i < len(fields); i++ { - valid := e.setColumnValue(sctx, row, tz, fields[i], values[i], e.checker, fileLine) - if !valid { - startFlag = false - break - } - } - } - if !valid { - startFlag = false - } - } else if strings.HasSuffix(line, variable.SlowLogSQLSuffixStr) { - if strings.HasPrefix(line, "use") { - // `use DB` statements in the slow log is used to keep it be compatible with MySQL, - // since we already get the current DB from the `# DB` field, we can ignore it here, - // please see https://github.com/pingcap/tidb/issues/17846 for more details. - continue - } - if e.checker != nil && !e.checker.hasPrivilege(user) { - startFlag = false - continue - } - // Get the sql string, and mark the start flag to false. - _ = e.setColumnValue(sctx, row, tz, variable.SlowLogQuerySQLStr, string(hack.Slice(line)), e.checker, fileLine) - e.setDefaultValue(row) - e.memConsume(types.EstimatedMemUsage(row, 1)) - data = append(data, row) - startFlag = false - } else { - startFlag = false - } - } - } - return data, nil -} - -func (e *slowQueryRetriever) setColumnValue(sctx sessionctx.Context, row []types.Datum, tz *time.Location, field, value string, checker *slowLogChecker, lineNum int) bool { - factory := e.columnValueFactoryMap[field] - if factory == nil { - // Fix issue 34320, when slow log time is not in the output columns, the time filter condition is mistakenly discard. 
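Most "# ..." rows reach the splitByColon path above, which walks the line once and pairs each "field:" with its value while honouring nested {} and [] brackets. A stripped-down sketch of the same idea that ignores the bracket handling entirely, so it only works for simple rows and is illustrative, not a drop-in replacement:

package main

import (
	"fmt"
	"strings"
)

// splitSimple parses "Key1: v1 Key2: v2" style rows. Unlike the real parser
// it does not handle values wrapped in {} or [].
func splitSimple(line string) (fields, values []string) {
	parts := strings.Fields(line)
	for i := 0; i < len(parts)-1; i++ {
		if strings.HasSuffix(parts[i], ":") {
			fields = append(fields, strings.TrimSuffix(parts[i], ":"))
			values = append(values, parts[i+1])
			i++
		}
	}
	return fields, values
}

func main() {
	f, v := splitSimple("Query_time: 0.25 Request_count: 3 Succ: true")
	fmt.Println(f) // [Query_time Request_count Succ]
	fmt.Println(v) // [0.25 3 true]
}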
- if field == variable.SlowLogTimeStr && checker != nil { - t, err := ParseTime(value) - if err != nil { - err = fmt.Errorf("Parse slow log at line %v, failed field is %v, failed value is %v, error is %v", lineNum, field, value, err) - sctx.GetSessionVars().StmtCtx.AppendWarning(err) - return false - } - timeValue := types.NewTime(types.FromGoTime(t), mysql.TypeTimestamp, types.MaxFsp) - return checker.isTimeValid(timeValue) - } - return true - } - valid, err := factory(row, value, tz, checker) - if err != nil { - err = fmt.Errorf("Parse slow log at line %v, failed field is %v, failed value is %v, error is %v", lineNum, field, value, err) - sctx.GetSessionVars().StmtCtx.AppendWarning(err) - return true - } - return valid -} - -func (e *slowQueryRetriever) setDefaultValue(row []types.Datum) { - for i := range row { - if !row[i].IsNull() { - continue - } - row[i] = table.GetZeroValue(e.outputCols[i]) - } -} - -type slowQueryColumnValueFactory func(row []types.Datum, value string, _ *time.Location, _ *slowLogChecker) (valid bool, err error) - -func parseUserOrHostValue(value string) string { - // the new User&Host format: root[root] @ localhost [127.0.0.1] - tmp := strings.Split(value, "[") - return strings.TrimSpace(tmp[0]) -} - -func getColumnValueFactoryByName(colName string, columnIdx int) (slowQueryColumnValueFactory, error) { - switch colName { - case variable.SlowLogTimeStr: - return func(row []types.Datum, value string, tz *time.Location, checker *slowLogChecker) (bool, error) { - t, err := ParseTime(value) - if err != nil { - return false, err - } - timeValue := types.NewTime(types.FromGoTime(t.In(tz)), mysql.TypeTimestamp, types.MaxFsp) - if checker != nil { - valid := checker.isTimeValid(timeValue) - if !valid { - return valid, nil - } - } - row[columnIdx] = types.NewTimeDatum(timeValue) - return true, nil - }, nil - case variable.SlowLogBackoffDetail: - return func(row []types.Datum, value string, _ *time.Location, _ *slowLogChecker) (bool, error) { - backoffDetail := row[columnIdx].GetString() - if len(backoffDetail) > 0 { - backoffDetail += " " - } - backoffDetail += value - row[columnIdx] = types.NewStringDatum(backoffDetail) - return true, nil - }, nil - case variable.SlowLogPlan: - return func(row []types.Datum, value string, _ *time.Location, _ *slowLogChecker) (bool, error) { - plan := parsePlan(value) - row[columnIdx] = types.NewStringDatum(plan) - return true, nil - }, nil - case variable.SlowLogBinaryPlan: - return func(row []types.Datum, value string, _ *time.Location, _ *slowLogChecker) (bool, error) { - if strings.HasPrefix(value, variable.SlowLogBinaryPlanPrefix) { - value = value[len(variable.SlowLogBinaryPlanPrefix) : len(value)-len(variable.SlowLogPlanSuffix)] - } - row[columnIdx] = types.NewStringDatum(value) - return true, nil - }, nil - case variable.SlowLogConnIDStr, variable.SlowLogExecRetryCount, variable.SlowLogPreprocSubQueriesStr, - execdetails.WriteKeysStr, execdetails.WriteSizeStr, execdetails.PrewriteRegionStr, execdetails.TxnRetryStr, - execdetails.RequestCountStr, execdetails.TotalKeysStr, execdetails.ProcessKeysStr, - execdetails.RocksdbDeleteSkippedCountStr, execdetails.RocksdbKeySkippedCountStr, - execdetails.RocksdbBlockCacheHitCountStr, execdetails.RocksdbBlockReadCountStr, - variable.SlowLogTxnStartTSStr, execdetails.RocksdbBlockReadByteStr: - return func(row []types.Datum, value string, _ *time.Location, _ *slowLogChecker) (valid bool, err error) { - v, err := strconv.ParseUint(value, 10, 64) - if err != nil { - return false, err - } - 
row[columnIdx] = types.NewUintDatum(v) - return true, nil - }, nil - case variable.SlowLogExecRetryTime, variable.SlowLogQueryTimeStr, variable.SlowLogParseTimeStr, - variable.SlowLogCompileTimeStr, variable.SlowLogRewriteTimeStr, variable.SlowLogPreProcSubQueryTimeStr, - variable.SlowLogOptimizeTimeStr, variable.SlowLogWaitTSTimeStr, execdetails.PreWriteTimeStr, - execdetails.WaitPrewriteBinlogTimeStr, execdetails.CommitTimeStr, execdetails.GetCommitTSTimeStr, - execdetails.CommitBackoffTimeStr, execdetails.ResolveLockTimeStr, execdetails.LocalLatchWaitTimeStr, - execdetails.CopTimeStr, execdetails.ProcessTimeStr, execdetails.WaitTimeStr, execdetails.BackoffTimeStr, - execdetails.LockKeysTimeStr, variable.SlowLogCopProcAvg, variable.SlowLogCopProcP90, variable.SlowLogCopProcMax, - variable.SlowLogCopWaitAvg, variable.SlowLogCopWaitP90, variable.SlowLogCopWaitMax, variable.SlowLogKVTotal, - variable.SlowLogPDTotal, variable.SlowLogBackoffTotal, variable.SlowLogWriteSQLRespTotal, variable.SlowLogRRU, - variable.SlowLogWRU, variable.SlowLogWaitRUDuration: - return func(row []types.Datum, value string, _ *time.Location, _ *slowLogChecker) (valid bool, err error) { - v, err := strconv.ParseFloat(value, 64) - if err != nil { - return false, err - } - row[columnIdx] = types.NewFloat64Datum(v) - return true, nil - }, nil - case variable.SlowLogUserStr, variable.SlowLogHostStr, execdetails.BackoffTypesStr, variable.SlowLogDBStr, variable.SlowLogIndexNamesStr, variable.SlowLogDigestStr, - variable.SlowLogStatsInfoStr, variable.SlowLogCopProcAddr, variable.SlowLogCopWaitAddr, variable.SlowLogPlanDigest, - variable.SlowLogPrevStmt, variable.SlowLogQuerySQLStr, variable.SlowLogWarnings, variable.SlowLogSessAliasStr, - variable.SlowLogResourceGroup: - return func(row []types.Datum, value string, _ *time.Location, _ *slowLogChecker) (valid bool, err error) { - row[columnIdx] = types.NewStringDatum(value) - return true, nil - }, nil - case variable.SlowLogMemMax, variable.SlowLogDiskMax, variable.SlowLogResultRows: - return func(row []types.Datum, value string, _ *time.Location, _ *slowLogChecker) (valid bool, err error) { - v, err := strconv.ParseInt(value, 10, 64) - if err != nil { - return false, err - } - row[columnIdx] = types.NewIntDatum(v) - return true, nil - }, nil - case variable.SlowLogPrepared, variable.SlowLogSucc, variable.SlowLogPlanFromCache, variable.SlowLogPlanFromBinding, - variable.SlowLogIsInternalStr, variable.SlowLogIsExplicitTxn, variable.SlowLogIsWriteCacheTable, variable.SlowLogHasMoreResults: - return func(row []types.Datum, value string, _ *time.Location, _ *slowLogChecker) (valid bool, err error) { - v, err := strconv.ParseBool(value) - if err != nil { - return false, err - } - row[columnIdx] = types.NewDatum(v) - return true, nil - }, nil - } - return nil, nil -} - -func getInstanceColumnValueFactory(sctx sessionctx.Context, columnIdx int) (func(row []types.Datum), error) { - instanceAddr, err := infoschema.GetInstanceAddr(sctx) - if err != nil { - return nil, err - } - return func(row []types.Datum) { - row[columnIdx] = types.NewStringDatum(instanceAddr) - }, nil -} - -func parsePlan(planString string) string { - if len(planString) <= len(variable.SlowLogPlanPrefix)+len(variable.SlowLogPlanSuffix) { - return planString - } - planString = planString[len(variable.SlowLogPlanPrefix) : len(planString)-len(variable.SlowLogPlanSuffix)] - decodePlanString, err := plancodec.DecodePlan(planString) - if err == nil { - planString = decodePlanString - } else { - 
logutil.BgLogger().Error("decode plan in slow log failed", zap.String("plan", planString), zap.Error(err)) - } - return planString -} - -// ParseTime exports for testing. -func ParseTime(s string) (time.Time, error) { - t, err := time.Parse(logutil.SlowLogTimeFormat, s) - if err != nil { - // This is for compatibility. - t, err = time.Parse(logutil.OldSlowLogTimeFormat, s) - if err != nil { - err = errors.Errorf("string \"%v\" doesn't has a prefix that matches format \"%v\", err: %v", s, logutil.SlowLogTimeFormat, err) - } - } - return t, err -} - -type logFile struct { - file *os.File // The opened file handle - start time.Time // The start time of the log file - compressed bool // The file is compressed or not -} - -// getAllFiles is used to get all slow-log needed to parse, it is exported for test. -func (e *slowQueryRetriever) getAllFiles(ctx context.Context, sctx sessionctx.Context, logFilePath string) ([]logFile, error) { - totalFileNum := 0 - if e.stats != nil { - startTime := time.Now() - defer func() { - e.stats.initialize = time.Since(startTime) - e.stats.totalFileNum = totalFileNum - }() - } - if e.extractor == nil || !e.extractor.Enable { - totalFileNum = 1 - //nolint: gosec - file, err := os.Open(logFilePath) - if err != nil { - if os.IsNotExist(err) { - return nil, nil - } - return nil, err - } - return []logFile{{file: file}}, nil - } - var logFiles []logFile - logDir := filepath.Dir(logFilePath) - ext := filepath.Ext(logFilePath) - prefix := logFilePath[:len(logFilePath)-len(ext)] - handleErr := func(err error) error { - // Ignore the error and append warning for usability. - if err != io.EOF { - sctx.GetSessionVars().StmtCtx.AppendWarning(err) - } - return nil - } - files, err := os.ReadDir(logDir) - if err != nil { - return nil, err - } - walkFn := func(path string, info os.DirEntry) error { - if info.IsDir() { - return nil - } - // All rotated log files have the same prefix with the original file. - if !strings.HasPrefix(path, prefix) { - return nil - } - compressed := strings.HasSuffix(path, ".gz") - if isCtxDone(ctx) { - return ctx.Err() - } - totalFileNum++ - file, err := os.OpenFile(path, os.O_RDONLY, os.ModePerm) - if err != nil { - return handleErr(err) - } - skip := false - defer func() { - if !skip { - terror.Log(file.Close()) - } - }() - // Get the file start time. - fileStartTime, err := e.getFileStartTime(ctx, file, compressed) - if err != nil { - return handleErr(err) - } - start := types.NewTime(types.FromGoTime(fileStartTime), mysql.TypeDatetime, types.MaxFsp) - notInAllTimeRanges := true - for _, tr := range e.checker.timeRanges { - if start.Compare(tr.endTime) <= 0 { - notInAllTimeRanges = false - break - } - } - if notInAllTimeRanges { - return nil - } - - // If we want to get the end time from a compressed file, - // we need uncompress the whole file which is very slow and consume a lot of memory. - if !compressed { - // Get the file end time. 
- fileEndTime, err := e.getFileEndTime(ctx, file) - if err != nil { - return handleErr(err) - } - end := types.NewTime(types.FromGoTime(fileEndTime), mysql.TypeDatetime, types.MaxFsp) - inTimeRanges := false - for _, tr := range e.checker.timeRanges { - if !(start.Compare(tr.endTime) > 0 || end.Compare(tr.startTime) < 0) { - inTimeRanges = true - break - } - } - if !inTimeRanges { - return nil - } - } - _, err = file.Seek(0, io.SeekStart) - if err != nil { - return handleErr(err) - } - logFiles = append(logFiles, logFile{ - file: file, - start: fileStartTime, - compressed: compressed, - }) - skip = true - return nil - } - for _, file := range files { - err := walkFn(filepath.Join(logDir, file.Name()), file) - if err != nil { - return nil, err - } - } - // Sort by start time - slices.SortFunc(logFiles, func(i, j logFile) int { - return i.start.Compare(j.start) - }) - // Assume no time range overlap in log files and remove unnecessary log files for compressed files. - var ret []logFile - for i, file := range logFiles { - if i == len(logFiles)-1 || !file.compressed { - ret = append(ret, file) - continue - } - start := types.NewTime(types.FromGoTime(logFiles[i].start), mysql.TypeDatetime, types.MaxFsp) - // use next file.start as endTime - end := types.NewTime(types.FromGoTime(logFiles[i+1].start), mysql.TypeDatetime, types.MaxFsp) - inTimeRanges := false - for _, tr := range e.checker.timeRanges { - if !(start.Compare(tr.endTime) > 0 || end.Compare(tr.startTime) < 0) { - inTimeRanges = true - break - } - } - if inTimeRanges { - ret = append(ret, file) - } - } - return ret, err -} - -func (*slowQueryRetriever) getFileStartTime(ctx context.Context, file *os.File, compressed bool) (time.Time, error) { - var t time.Time - _, err := file.Seek(0, io.SeekStart) - if err != nil { - return t, err - } - var reader *bufio.Reader - if !compressed { - reader = bufio.NewReader(file) - } else { - gr, err := gzip.NewReader(file) - if err != nil { - return t, err - } - reader = bufio.NewReader(gr) - } - maxNum := 128 - for { - lineByte, err := getOneLine(reader) - if err != nil { - return t, err - } - line := string(lineByte) - if strings.HasPrefix(line, variable.SlowLogStartPrefixStr) { - return ParseTime(line[len(variable.SlowLogStartPrefixStr):]) - } - maxNum-- - if maxNum <= 0 { - break - } - if isCtxDone(ctx) { - return t, ctx.Err() - } - } - return t, errors.Errorf("malform slow query file %v", file.Name()) -} - -func (e *slowQueryRetriever) getRuntimeStats() execdetails.RuntimeStats { - return e.stats -} - -type slowQueryRuntimeStats struct { - totalFileNum int - readFileNum int - readFile time.Duration - initialize time.Duration - readFileSize int64 - parseLog int64 - concurrent int -} - -// String implements the RuntimeStats interface. -func (s *slowQueryRuntimeStats) String() string { - return fmt.Sprintf("initialize: %s, read_file: %s, parse_log: {time:%s, concurrency:%v}, total_file: %v, read_file: %v, read_size: %s", - execdetails.FormatDuration(s.initialize), execdetails.FormatDuration(s.readFile), - execdetails.FormatDuration(time.Duration(s.parseLog)), s.concurrent, - s.totalFileNum, s.readFileNum, memory.FormatBytes(s.readFileSize)) -} - -// Merge implements the RuntimeStats interface. 
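getFileStartTime above only peeks at the head of each rotated file (at most 128 lines) for its first "# Time:" entry, wrapping the reader in gzip when the file is compressed, so filtering rotated files stays cheap. A hedged sketch of that probe with error handling trimmed; the path, prefix, and RFC3339-style timestamp format are assumptions for illustration:

package main

import (
	"bufio"
	"compress/gzip"
	"errors"
	"fmt"
	"io"
	"os"
	"strings"
	"time"
)

// firstEntryTime returns the timestamp of the first "# Time:" line, reading
// through gzip when the file name ends in .gz. Only the first maxLines lines
// are inspected, mirroring the bounded scan above.
func firstEntryTime(path string, maxLines int) (time.Time, error) {
	f, err := os.Open(path)
	if err != nil {
		return time.Time{}, err
	}
	defer f.Close()

	var r io.Reader = f
	if strings.HasSuffix(path, ".gz") {
		gr, err := gzip.NewReader(f)
		if err != nil {
			return time.Time{}, err
		}
		defer gr.Close()
		r = gr
	}

	scanner := bufio.NewScanner(r)
	for i := 0; i < maxLines && scanner.Scan(); i++ {
		line := scanner.Text()
		if strings.HasPrefix(line, "# Time: ") {
			return time.Parse(time.RFC3339Nano, strings.TrimPrefix(line, "# Time: "))
		}
	}
	return time.Time{}, errors.New("no # Time: line found near the start of " + path)
}

func main() {
	if t, err := firstEntryTime("tidb-slow.log", 128); err == nil {
		fmt.Println("first entry at", t)
	}
}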
-func (s *slowQueryRuntimeStats) Merge(rs execdetails.RuntimeStats) { - tmp, ok := rs.(*slowQueryRuntimeStats) - if !ok { - return - } - s.totalFileNum += tmp.totalFileNum - s.readFileNum += tmp.readFileNum - s.readFile += tmp.readFile - s.initialize += tmp.initialize - s.readFileSize += tmp.readFileSize - s.parseLog += tmp.parseLog -} - -// Clone implements the RuntimeStats interface. -func (s *slowQueryRuntimeStats) Clone() execdetails.RuntimeStats { - newRs := *s - return &newRs -} - -// Tp implements the RuntimeStats interface. -func (*slowQueryRuntimeStats) Tp() int { - return execdetails.TpSlowQueryRuntimeStat -} - -func (*slowQueryRetriever) getFileEndTime(ctx context.Context, file *os.File) (time.Time, error) { - var t time.Time - var tried int - stat, err := file.Stat() - if err != nil { - return t, err - } - endCursor := stat.Size() - maxLineNum := 128 - for { - lines, readBytes, err := readLastLines(ctx, file, endCursor) - if err != nil { - return t, err - } - // read out the file - if readBytes == 0 { - break - } - endCursor -= int64(readBytes) - for i := len(lines) - 1; i >= 0; i-- { - if strings.HasPrefix(lines[i], variable.SlowLogStartPrefixStr) { - return ParseTime(lines[i][len(variable.SlowLogStartPrefixStr):]) - } - } - tried += len(lines) - if tried >= maxLineNum { - break - } - if isCtxDone(ctx) { - return t, ctx.Err() - } - } - return t, errors.Errorf("invalid slow query file %v", file.Name()) -} - -const maxReadCacheSize = 1024 * 1024 * 64 - -// Read lines from the end of a file -// endCursor initial value should be the filesize -func readLastLines(ctx context.Context, file *os.File, endCursor int64) ([]string, int, error) { - var lines []byte - var firstNonNewlinePos int - var cursor = endCursor - var size int64 = 2048 - for { - // stop if we are at the beginning - // check it in the start to avoid read beyond the size - if cursor <= 0 { - break - } - if size < maxReadCacheSize { - size = size * 2 - } - if cursor < size { - size = cursor - } - cursor -= size - - _, err := file.Seek(cursor, io.SeekStart) - if err != nil { - return nil, 0, err - } - chars := make([]byte, size) - _, err = file.Read(chars) - if err != nil { - return nil, 0, err - } - lines = append(chars, lines...) 
// nozero - - // find first '\n' or '\r' - for i := 0; i < len(chars)-1; i++ { - if (chars[i] == '\n' || chars[i] == '\r') && chars[i+1] != '\n' && chars[i+1] != '\r' { - firstNonNewlinePos = i + 1 - break - } - } - if firstNonNewlinePos > 0 { - break - } - if isCtxDone(ctx) { - return nil, 0, ctx.Err() - } - } - finalStr := string(lines[firstNonNewlinePos:]) - return strings.Split(strings.ReplaceAll(finalStr, "\r\n", "\n"), "\n"), len(finalStr), nil -} - -func (e *slowQueryRetriever) initializeAsyncParsing(ctx context.Context, sctx sessionctx.Context) { - e.taskList = make(chan slowLogTask, 1) - e.wg.Add(1) - go e.parseDataForSlowLog(ctx, sctx) -} - -func calculateLogSize(log []string) int64 { - size := 0 - for _, line := range log { - size += len(line) - } - return int64(size) -} - -func calculateDatumsSize(rows [][]types.Datum) int64 { - size := int64(0) - for _, row := range rows { - size += types.EstimatedMemUsage(row, 1) - } - return size -} - -func (e *slowQueryRetriever) memConsume(bytes int64) { - if e.memTracker != nil { - e.memTracker.Consume(bytes) - } -} diff --git a/pkg/executor/sortexec/binding__failpoint_binding__.go b/pkg/executor/sortexec/binding__failpoint_binding__.go deleted file mode 100644 index 1fe00bd1a09a7..0000000000000 --- a/pkg/executor/sortexec/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package sortexec - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/executor/sortexec/parallel_sort_worker.go b/pkg/executor/sortexec/parallel_sort_worker.go index 967f4754f7575..1d7a4d7aaa7b6 100644 --- a/pkg/executor/sortexec/parallel_sort_worker.go +++ b/pkg/executor/sortexec/parallel_sort_worker.go @@ -87,7 +87,7 @@ func (p *parallelSortWorker) reset() { func (p *parallelSortWorker) injectFailPointForParallelSortWorker(triggerFactor int32) { injectParallelSortRandomFail(triggerFactor) - if val, _err_ := failpoint.Eval(_curpkg_("SlowSomeWorkers")); _err_ == nil { + failpoint.Inject("SlowSomeWorkers", func(val failpoint.Value) { if val.(bool) { if p.workerIDForTest%2 == 0 { randNum := rand.Int31n(10000) @@ -96,7 +96,7 @@ func (p *parallelSortWorker) injectFailPointForParallelSortWorker(triggerFactor } } } - } + }) } func (p *parallelSortWorker) multiWayMergeLocalSortedRows() ([]chunk.Row, error) { @@ -208,11 +208,11 @@ func (p *parallelSortWorker) keyColumnsLess(i, j chunk.Row) int { p.timesOfRowCompare = 0 } - if val, _err_ := failpoint.Eval(_curpkg_("SignalCheckpointForSort")); _err_ == nil { + failpoint.Inject("SignalCheckpointForSort", func(val failpoint.Value) { if val.(bool) { p.timesOfRowCompare += 1024 } - } + }) p.timesOfRowCompare++ return p.lessRowFunc(i, j) diff --git a/pkg/executor/sortexec/parallel_sort_worker.go__failpoint_stash__ b/pkg/executor/sortexec/parallel_sort_worker.go__failpoint_stash__ deleted file mode 100644 index 1d7a4d7aaa7b6..0000000000000 --- a/pkg/executor/sortexec/parallel_sort_worker.go__failpoint_stash__ +++ /dev/null @@ -1,229 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sortexec - -import ( - "math/rand" - "slices" - "sync" - "time" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/memory" -) - -// SignalCheckpointForSort indicates the times of row comparation that a signal detection will be triggered. -const SignalCheckpointForSort uint = 20000 - -type parallelSortWorker struct { - workerIDForTest int - - chunkChannel chan *chunkWithMemoryUsage - fetcherAndWorkerSyncer *sync.WaitGroup - errOutputChan chan rowWithError - finishCh chan struct{} - - lessRowFunc func(chunk.Row, chunk.Row) int - timesOfRowCompare uint - - memTracker *memory.Tracker - totalMemoryUsage int64 - - spillHelper *parallelSortSpillHelper - - localSortedRows []*chunk.Iterator4Slice - sortedRowsIter *chunk.Iterator4Slice - maxSortedRowsLimit int - chunkIters []*chunk.Iterator4Chunk - rowNumInChunkIters int - merger *multiWayMerger -} - -func newParallelSortWorker( - workerIDForTest int, - lessRowFunc func(chunk.Row, chunk.Row) int, - chunkChannel chan *chunkWithMemoryUsage, - fetcherAndWorkerSyncer *sync.WaitGroup, - errOutputChan chan rowWithError, - finishCh chan struct{}, - memTracker *memory.Tracker, - sortedRowsIter *chunk.Iterator4Slice, - maxChunkSize int, - spillHelper *parallelSortSpillHelper) *parallelSortWorker { - return ¶llelSortWorker{ - workerIDForTest: workerIDForTest, - lessRowFunc: lessRowFunc, - chunkChannel: chunkChannel, - fetcherAndWorkerSyncer: fetcherAndWorkerSyncer, - errOutputChan: errOutputChan, - finishCh: finishCh, - timesOfRowCompare: 0, - memTracker: memTracker, - sortedRowsIter: sortedRowsIter, - maxSortedRowsLimit: maxChunkSize * 30, - spillHelper: spillHelper, - } -} - -func (p *parallelSortWorker) reset() { - p.localSortedRows = nil - p.sortedRowsIter = nil - p.merger = nil - p.memTracker.ReplaceBytesUsed(0) -} - -func (p *parallelSortWorker) injectFailPointForParallelSortWorker(triggerFactor int32) { - injectParallelSortRandomFail(triggerFactor) - failpoint.Inject("SlowSomeWorkers", func(val failpoint.Value) { - if val.(bool) { - if p.workerIDForTest%2 == 0 { - randNum := rand.Int31n(10000) - if randNum < 10 { - time.Sleep(1 * time.Millisecond) - } - } - } - }) -} - -func (p *parallelSortWorker) multiWayMergeLocalSortedRows() ([]chunk.Row, error) { - totalRowNum := 0 - for _, rows := range p.localSortedRows { - totalRowNum += rows.Len() - } - resultSortedRows := make([]chunk.Row, 0, totalRowNum) - source := &memorySource{sortedRowsIters: p.localSortedRows} - p.merger = newMultiWayMerger(source, p.lessRowFunc) - err := p.merger.init() - if err != nil { - return nil, err - } - - for { - // It's impossible to return error here as rows are in memory - row, _ := p.merger.next() - if row.IsEmpty() { - break - } - resultSortedRows = append(resultSortedRows, row) - } - p.localSortedRows = nil - return resultSortedRows, nil -} - -func (p *parallelSortWorker) convertChunksToRows() []chunk.Row { - rows := make([]chunk.Row, 0, p.rowNumInChunkIters) - for _, iter := range p.chunkIters { - row := iter.Begin() - for !row.IsEmpty() { - rows = append(rows, row) - row = 
iter.Next() - } - } - p.chunkIters = p.chunkIters[:0] - p.rowNumInChunkIters = 0 - return rows -} - -func (p *parallelSortWorker) sortBatchRows() { - rows := p.convertChunksToRows() - slices.SortFunc(rows, p.keyColumnsLess) - p.localSortedRows = append(p.localSortedRows, chunk.NewIterator4Slice(rows)) -} - -func (p *parallelSortWorker) sortLocalRows() ([]chunk.Row, error) { - // Handle Remaining batchRows whose row number is not over the `maxSortedRowsLimit` - if p.rowNumInChunkIters > 0 { - p.sortBatchRows() - } - - return p.multiWayMergeLocalSortedRows() -} - -func (p *parallelSortWorker) saveChunk(chk *chunk.Chunk) { - chkIter := chunk.NewIterator4Chunk(chk) - p.chunkIters = append(p.chunkIters, chkIter) - p.rowNumInChunkIters += chkIter.Len() -} - -// Fetching a bunch of chunks from chunkChannel and sort them. -// After receiving all chunks, we will get several sorted rows slices and we use k-way merge to sort them. -func (p *parallelSortWorker) fetchChunksAndSort() { - for p.fetchChunksAndSortImpl() { - } -} - -func (p *parallelSortWorker) fetchChunksAndSortImpl() bool { - var ( - chk *chunkWithMemoryUsage - ok bool - ) - select { - case <-p.finishCh: - return false - case chk, ok = <-p.chunkChannel: - // Memory usage of the chunk has been consumed at the chunk fetcher - if !ok { - p.injectFailPointForParallelSortWorker(100) - // Put local sorted rows into this iter who will be read by sort executor - sortedRows, err := p.sortLocalRows() - if err != nil { - p.errOutputChan <- rowWithError{err: err} - return false - } - p.sortedRowsIter.Reset(sortedRows) - return false - } - defer p.fetcherAndWorkerSyncer.Done() - p.totalMemoryUsage += chk.MemoryUsage - } - - p.saveChunk(chk.Chk) - - if p.rowNumInChunkIters >= p.maxSortedRowsLimit { - p.sortBatchRows() - } - - p.injectFailPointForParallelSortWorker(3) - return true -} - -func (p *parallelSortWorker) keyColumnsLess(i, j chunk.Row) int { - if p.timesOfRowCompare >= SignalCheckpointForSort { - // Trigger Consume for checking the NeedKill signal - p.memTracker.Consume(1) - p.timesOfRowCompare = 0 - } - - failpoint.Inject("SignalCheckpointForSort", func(val failpoint.Value) { - if val.(bool) { - p.timesOfRowCompare += 1024 - } - }) - p.timesOfRowCompare++ - - return p.lessRowFunc(i, j) -} - -func (p *parallelSortWorker) run() { - defer func() { - if r := recover(); r != nil { - processPanicAndLog(p.errOutputChan, r) - } - }() - - p.fetchChunksAndSort() -} diff --git a/pkg/executor/sortexec/sort.go b/pkg/executor/sortexec/sort.go index 5d1f79a554040..9f581cb699a9a 100644 --- a/pkg/executor/sortexec/sort.go +++ b/pkg/executor/sortexec/sort.go @@ -615,28 +615,28 @@ func (e *SortExec) fetchChunksUnparallel(ctx context.Context) error { return err } - if val, _err_ := failpoint.Eval(_curpkg_("unholdSyncLock")); _err_ == nil { + failpoint.Inject("unholdSyncLock", func(val failpoint.Value) { if val.(bool) { // Ensure that spill can get `syncLock`. time.Sleep(1 * time.Millisecond) } - } + }) } - if val, _err_ := failpoint.Eval(_curpkg_("waitForSpill")); _err_ == nil { + failpoint.Inject("waitForSpill", func(val failpoint.Value) { if val.(bool) { // Ensure that spill is triggered before returning data. 
time.Sleep(50 * time.Millisecond) } - } + }) - if val, _err_ := failpoint.Eval(_curpkg_("SignalCheckpointForSort")); _err_ == nil { + failpoint.Inject("SignalCheckpointForSort", func(val failpoint.Value) { if val.(bool) { if e.Ctx().GetSessionVars().ConnectionID == 123456 { e.Ctx().GetSessionVars().MemTracker.Killer.SendKillSignal(sqlkiller.QueryMemoryExceeded) } } - } + }) err = e.handleCurrentPartitionBeforeExit() if err != nil { @@ -700,13 +700,13 @@ func (e *SortExec) fetchChunksFromChild(ctx context.Context) { e.Parallel.resultChannel <- rowWithError{err: err} } - if val, _err_ := failpoint.Eval(_curpkg_("SignalCheckpointForSort")); _err_ == nil { + failpoint.Inject("SignalCheckpointForSort", func(val failpoint.Value) { if val.(bool) { if e.Ctx().GetSessionVars().ConnectionID == 123456 { e.Ctx().GetSessionVars().MemTracker.Killer.SendKillSignal(sqlkiller.QueryMemoryExceeded) } } - } + }) // We must place it after the spill as workers will process its received // chunks after channel is closed and this will cause data race. diff --git a/pkg/executor/sortexec/sort.go__failpoint_stash__ b/pkg/executor/sortexec/sort.go__failpoint_stash__ deleted file mode 100644 index 9f581cb699a9a..0000000000000 --- a/pkg/executor/sortexec/sort.go__failpoint_stash__ +++ /dev/null @@ -1,845 +0,0 @@ -// Copyright 2017 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sortexec - -import ( - "context" - "sync" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/expression" - plannerutil "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/channel" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/disk" - "github.com/pingcap/tidb/pkg/util/memory" - "github.com/pingcap/tidb/pkg/util/sqlkiller" -) - -// SortExec represents sorting executor. -type SortExec struct { - exec.BaseExecutor - - ByItems []*plannerutil.ByItems - fetched *atomic.Bool - ExecSchema *expression.Schema - - // keyColumns is the column index of the by items. - keyColumns []int - // keyCmpFuncs is used to compare each ByItem. - keyCmpFuncs []chunk.CompareFunc - - curPartition *sortPartition - - // We can't spill if size of data is lower than the limit - spillLimit int64 - - memTracker *memory.Tracker - diskTracker *disk.Tracker - - // TODO delete this variable in the future and remove the unparallel sort - IsUnparallel bool - - finishCh chan struct{} - - // multiWayMerge uses multi-way merge for spill disk. - // The multi-way merge algorithm can refer to https://en.wikipedia.org/wiki/K-way_merge_algorithm - multiWayMerge *multiWayMerger - - Unparallel struct { - Idx int - - // sortPartitions is the chunks to store row values for partitions. Every partition is a sorted list. 
- sortPartitions []*sortPartition - - spillAction *sortPartitionSpillDiskAction - } - - Parallel struct { - chunkChannel chan *chunkWithMemoryUsage - // It's useful when spill is triggered and the fetcher could know when workers finish their works. - fetcherAndWorkerSyncer *sync.WaitGroup - workers []*parallelSortWorker - - // Each worker will put their results into the given iter - sortedRowsIters []*chunk.Iterator4Slice - merger *multiWayMerger - - resultChannel chan rowWithError - - // Ensure that workers and fetcher have exited - closeSync chan struct{} - - spillHelper *parallelSortSpillHelper - spillAction *parallelSortSpillAction - } - - enableTmpStorageOnOOM bool -} - -// Close implements the Executor Close interface. -func (e *SortExec) Close() error { - // TopN not initializes `e.finishCh` but it will call the Close function - if e.finishCh != nil { - close(e.finishCh) - } - if e.Unparallel.spillAction != nil { - e.Unparallel.spillAction.SetFinished() - } - - if e.IsUnparallel { - for _, partition := range e.Unparallel.sortPartitions { - partition.close() - } - } else if e.finishCh != nil { - if e.fetched.CompareAndSwap(false, true) { - close(e.Parallel.resultChannel) - close(e.Parallel.chunkChannel) - } else { - for range e.Parallel.chunkChannel { - e.Parallel.fetcherAndWorkerSyncer.Done() - } - <-e.Parallel.closeSync - } - - // Ensure that `generateResult()` has exited, - // or data race may happen as `generateResult()` - // will use `e.Parallel.workers` and `e.Parallel.merger`. - channel.Clear(e.Parallel.resultChannel) - for i := range e.Parallel.workers { - if e.Parallel.workers[i] != nil { - e.Parallel.workers[i].reset() - } - } - e.Parallel.merger = nil - if e.Parallel.spillAction != nil { - e.Parallel.spillAction.SetFinished() - } - e.Parallel.spillHelper.close() - } - - if e.memTracker != nil { - e.memTracker.ReplaceBytesUsed(0) - } - - return exec.Close(e.Children(0)) -} - -// Open implements the Executor Open interface. -func (e *SortExec) Open(ctx context.Context) error { - e.fetched = &atomic.Bool{} - e.fetched.Store(false) - e.enableTmpStorageOnOOM = variable.EnableTmpStorageOnOOM.Load() - e.finishCh = make(chan struct{}, 1) - - // To avoid duplicated initialization for TopNExec. 
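Both the spill path and the in-memory path of this executor funnel their per-worker sorted streams through the multi-way (k-way) merge mentioned above. A classic way to implement such a merge is a small min-heap over the current head of each stream; the following is a self-contained sketch of that technique over int slices, not the executor's actual multiWayMerger:

package main

import (
	"container/heap"
	"fmt"
)

// head tracks the current element of one sorted source.
type head struct {
	val    int
	src    int // which source it came from
	offset int // next index to read from that source
}

type minHeap []head

func (h minHeap) Len() int           { return len(h) }
func (h minHeap) Less(i, j int) bool { return h[i].val < h[j].val }
func (h minHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *minHeap) Push(x any)        { *h = append(*h, x.(head)) }
func (h *minHeap) Pop() any {
	old := *h
	x := old[len(old)-1]
	*h = old[:len(old)-1]
	return x
}

// kWayMerge merges already-sorted slices into one sorted slice by always
// popping the smallest head and refilling from the source it came from.
func kWayMerge(sources [][]int) []int {
	h := &minHeap{}
	for i, s := range sources {
		if len(s) > 0 {
			heap.Push(h, head{val: s[0], src: i, offset: 1})
		}
	}
	var out []int
	for h.Len() > 0 {
		top := heap.Pop(h).(head)
		out = append(out, top.val)
		if src := sources[top.src]; top.offset < len(src) {
			heap.Push(h, head{val: src[top.offset], src: top.src, offset: top.offset + 1})
		}
	}
	return out
}

func main() {
	fmt.Println(kWayMerge([][]int{{1, 4, 9}, {2, 3, 8}, {5, 6, 7}}))
	// [1 2 3 4 5 6 7 8 9]
}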
- if e.memTracker == nil { - e.memTracker = memory.NewTracker(e.ID(), -1) - e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) - e.spillLimit = e.Ctx().GetSessionVars().MemTracker.GetBytesLimit() / 10 - e.diskTracker = disk.NewTracker(e.ID(), -1) - e.diskTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.DiskTracker) - } - - e.IsUnparallel = false - if e.IsUnparallel { - e.Unparallel.Idx = 0 - e.Unparallel.sortPartitions = e.Unparallel.sortPartitions[:0] - } else { - e.Parallel.workers = make([]*parallelSortWorker, e.Ctx().GetSessionVars().ExecutorConcurrency) - e.Parallel.chunkChannel = make(chan *chunkWithMemoryUsage, e.Ctx().GetSessionVars().ExecutorConcurrency) - e.Parallel.fetcherAndWorkerSyncer = &sync.WaitGroup{} - e.Parallel.sortedRowsIters = make([]*chunk.Iterator4Slice, len(e.Parallel.workers)) - e.Parallel.resultChannel = make(chan rowWithError, 10) - e.Parallel.closeSync = make(chan struct{}) - e.Parallel.merger = newMultiWayMerger(&memorySource{sortedRowsIters: e.Parallel.sortedRowsIters}, e.lessRow) - e.Parallel.spillHelper = newParallelSortSpillHelper(e, exec.RetTypes(e), e.finishCh, e.lessRow, e.Parallel.resultChannel) - e.Parallel.spillAction = newParallelSortSpillDiskAction(e.Parallel.spillHelper) - for i := range e.Parallel.sortedRowsIters { - e.Parallel.sortedRowsIters[i] = chunk.NewIterator4Slice(nil) - } - if e.enableTmpStorageOnOOM { - e.Ctx().GetSessionVars().MemTracker.FallbackOldAndSetNewAction(e.Parallel.spillAction) - } - } - - return exec.Open(ctx, e.Children(0)) -} - -// InitUnparallelModeForTest is for unit test -func (e *SortExec) InitUnparallelModeForTest() { - e.Unparallel.Idx = 0 - e.Unparallel.sortPartitions = e.Unparallel.sortPartitions[:0] -} - -// Next implements the Executor Next interface. -// Sort constructs the result following these step in unparallel mode: -// 1. Read as mush as rows into memory. -// 2. If memory quota is triggered, sort these rows in memory and put them into disk as partition 1, then reset -// the memory quota trigger and return to step 1 -// 3. If memory quota is not triggered and child is consumed, sort these rows in memory as partition N. -// 4. Merge sort if the count of partitions is larger than 1. If there is only one partition in step 4, it works -// just like in-memory sort before. -// -// Here we explain the execution flow of the parallel sort implementation. -// There are 3 main components: -// 1. Chunks Fetcher: Fetcher is responsible for fetching chunks from child and send them to channel. -// 2. Parallel Sort Worker: Worker receives chunks from channel it will sort these chunks after the -// number of rows in these chunks exceeds limit, we call them as sorted rows after chunks are sorted. -// Then each worker will have several sorted rows, we use multi-way merge to sort them and each worker -// will have only one sorted rows in the end. -// 3. Result Generator: Generator gets n sorted rows from n workers, it will use multi-way merge to sort -// these rows, once it gets the next row, it will send it into `resultChannel` and the goroutine who -// calls `Next()` will fetch result from `resultChannel`. -/* - ┌─────────┐ - │ Child │ - └────▲────┘ - │ - Fetch - │ - ┌───────┴───────┐ - │ Chunk Fetcher │ - └───────┬───────┘ - │ - Push - │ - ▼ - ┌────────────────►Channel◄───────────────────┐ - │ ▲ │ - │ │ │ - Fetch Fetch Fetch - │ │ │ - ┌────┴───┐ ┌───┴────┐ ┌───┴────┐ - │ Worker │ │ Worker │ ...... 
│ Worker │ - └────┬───┘ └───┬────┘ └───┬────┘ - │ │ │ - │ │ │ - Sort Sort Sort - │ │ │ - │ │ │ - ┌──────┴──────┐ ┌──────┴──────┐ ┌──────┴──────┐ - │ Sorted Rows │ │ Sorted Rows │ ...... │ Sorted Rows │ - └──────▲──────┘ └──────▲──────┘ └──────▲──────┘ - │ │ │ - Pull Pull Pull - │ │ │ - └────────────────────┼───────────────────────┘ - │ - Multi-way Merge - │ - ┌──────┴──────┐ - │ Generator │ - └──────┬──────┘ - │ - Push - │ - ▼ - resultChannel -*/ -func (e *SortExec) Next(ctx context.Context, req *chunk.Chunk) error { - if e.fetched.CompareAndSwap(false, true) { - err := e.initCompareFuncs(e.Ctx().GetExprCtx().GetEvalCtx()) - if err != nil { - return err - } - - e.buildKeyColumns() - err = e.fetchChunks(ctx) - if err != nil { - return err - } - } - - req.Reset() - if e.IsUnparallel { - return e.appendResultToChunkInUnparallelMode(req) - } - return e.appendResultToChunkInParallelMode(req) -} - -func (e *SortExec) appendResultToChunkInParallelMode(req *chunk.Chunk) error { - for !req.IsFull() { - row, ok := <-e.Parallel.resultChannel - if row.err != nil { - return row.err - } - if !ok { - return nil - } - req.AppendRow(row.row) - } - return nil -} - -func (e *SortExec) appendResultToChunkInUnparallelMode(req *chunk.Chunk) error { - sortPartitionListLen := len(e.Unparallel.sortPartitions) - if sortPartitionListLen == 0 { - return nil - } - - if sortPartitionListLen == 1 { - if err := e.onePartitionSorting(req); err != nil { - return err - } - } else { - if err := e.externalSorting(req); err != nil { - return err - } - } - return nil -} - -func (e *SortExec) generateResultWithMultiWayMerge() error { - multiWayMerge := newMultiWayMerger(&diskSource{sortedRowsInDisk: e.Parallel.spillHelper.sortedRowsInDisk}, e.lessRow) - - err := multiWayMerge.init() - if err != nil { - return err - } - - for { - row, err := multiWayMerge.next() - if err != nil { - return err - } - - if row.IsEmpty() { - return nil - } - - select { - case <-e.finishCh: - return nil - case e.Parallel.resultChannel <- rowWithError{row: row}: - } - injectParallelSortRandomFail(1) - } -} - -// We call this function when sorted rows are in disk -func (e *SortExec) generateResultFromDisk() error { - inDiskNum := len(e.Parallel.spillHelper.sortedRowsInDisk) - if inDiskNum == 0 { - return nil - } - - // Spill is triggered only once - if inDiskNum == 1 { - inDisk := e.Parallel.spillHelper.sortedRowsInDisk[0] - chunkNum := inDisk.NumChunks() - for i := 0; i < chunkNum; i++ { - chk, err := inDisk.GetChunk(i) - if err != nil { - return err - } - - injectParallelSortRandomFail(1) - - rowNum := chk.NumRows() - for j := 0; j < rowNum; j++ { - select { - case <-e.finishCh: - return nil - case e.Parallel.resultChannel <- rowWithError{row: chk.GetRow(j)}: - } - } - } - return nil - } - return e.generateResultWithMultiWayMerge() -} - -// We call this function to generate result when sorted rows are in memory -// Return true when spill is triggered -func (e *SortExec) generateResultFromMemory() (bool, error) { - if e.Parallel.merger == nil { - // Sort has been closed - return false, nil - } - err := e.Parallel.merger.init() - if err != nil { - return false, err - } - - maxChunkSize := e.MaxChunkSize() - resBuf := make([]rowWithError, 0, 3) - idx := int64(0) - var row chunk.Row - for { - resBuf = resBuf[:0] - for i := 0; i < maxChunkSize; i++ { - // It's impossible to return error here as rows are in memory - row, _ = e.Parallel.merger.next() - if row.IsEmpty() { - break - } - resBuf = append(resBuf, rowWithError{row: row, err: nil}) - } - - if 
len(resBuf) == 0 { - return false, nil - } - - for _, row := range resBuf { - select { - case <-e.finishCh: - return false, nil - case e.Parallel.resultChannel <- row: - } - } - - injectParallelSortRandomFail(3) - - if idx%1000 == 0 && e.Parallel.spillHelper.isSpillNeeded() { - return true, nil - } - } -} - -func (e *SortExec) generateResult(waitGroups ...*util.WaitGroupWrapper) { - for _, waitGroup := range waitGroups { - waitGroup.Wait() - } - close(e.Parallel.closeSync) - - defer func() { - if r := recover(); r != nil { - processPanicAndLog(e.Parallel.resultChannel, r) - } - - for i := range e.Parallel.sortedRowsIters { - e.Parallel.sortedRowsIters[i].Reset(nil) - } - e.Parallel.merger = nil - close(e.Parallel.resultChannel) - }() - - if !e.Parallel.spillHelper.isSpillTriggered() { - spillTriggered, err := e.generateResultFromMemory() - if err != nil { - e.Parallel.resultChannel <- rowWithError{err: err} - return - } - - if !spillTriggered { - return - } - - err = e.spillSortedRowsInMemory() - if err != nil { - e.Parallel.resultChannel <- rowWithError{err: err} - return - } - } - - err := e.generateResultFromDisk() - if err != nil { - e.Parallel.resultChannel <- rowWithError{err: err} - } -} - -// Spill rows that are in memory -func (e *SortExec) spillSortedRowsInMemory() error { - return e.Parallel.spillHelper.spillImpl(e.Parallel.merger) -} - -func (e *SortExec) onePartitionSorting(req *chunk.Chunk) (err error) { - err = e.Unparallel.sortPartitions[0].checkError() - if err != nil { - return err - } - - for !req.IsFull() { - row, err := e.Unparallel.sortPartitions[0].getNextSortedRow() - if err != nil { - return err - } - - if row.IsEmpty() { - return nil - } - - req.AppendRow(row) - } - return nil -} - -func (e *SortExec) externalSorting(req *chunk.Chunk) (err error) { - // We only need to check error for the last partition as previous partitions - // have been checked when we call `switchToNewSortPartition` function. 
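// externalSorting above merges several independently sorted partitions through a
// multi-way merger. The sketch below shows the underlying k-way merge idea with
// container/heap over plain sorted slices; cursor, mergeHeap, and kWayMerge are
// illustrative stand-ins, not the real multiWayMerger API.

package main

import (
	"container/heap"
	"fmt"
)

type cursor struct {
	data []int
	pos  int
}

type mergeHeap []*cursor

func (h mergeHeap) Len() int           { return len(h) }
func (h mergeHeap) Less(i, j int) bool { return h[i].data[h[i].pos] < h[j].data[h[j].pos] }
func (h mergeHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *mergeHeap) Push(x any)        { *h = append(*h, x.(*cursor)) }
func (h *mergeHeap) Pop() any          { old := *h; n := len(old); x := old[n-1]; *h = old[:n-1]; return x }

// kWayMerge streams the rows of all sorted inputs in global order.
func kWayMerge(inputs [][]int) []int {
	h := &mergeHeap{}
	for _, in := range inputs {
		if len(in) > 0 {
			*h = append(*h, &cursor{data: in})
		}
	}
	heap.Init(h)
	var out []int
	for h.Len() > 0 {
		top := (*h)[0]
		out = append(out, top.data[top.pos])
		top.pos++
		if top.pos == len(top.data) {
			heap.Pop(h) // this partition is exhausted
		} else {
			heap.Fix(h, 0) // the head row changed, restore heap order
		}
	}
	return out
}

func main() {
	fmt.Println(kWayMerge([][]int{{1, 4, 9}, {2, 3, 10}, {5, 6}}))
}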
- err = e.Unparallel.sortPartitions[len(e.Unparallel.sortPartitions)-1].checkError() - if err != nil { - return err - } - - if e.multiWayMerge == nil { - e.multiWayMerge = newMultiWayMerger(&sortPartitionSource{sortPartitions: e.Unparallel.sortPartitions}, e.lessRow) - err := e.multiWayMerge.init() - if err != nil { - return err - } - } - - for !req.IsFull() { - row, err := e.multiWayMerge.next() - if err != nil { - return err - } - if row.IsEmpty() { - return nil - } - req.AppendRow(row) - } - return nil -} - -func (e *SortExec) fetchChunks(ctx context.Context) error { - if e.IsUnparallel { - return e.fetchChunksUnparallel(ctx) - } - return e.fetchChunksParallel(ctx) -} - -func (e *SortExec) switchToNewSortPartition(fields []*types.FieldType, byItemsDesc []bool, appendPartition bool) error { - if appendPartition { - // Put the full partition into list - e.Unparallel.sortPartitions = append(e.Unparallel.sortPartitions, e.curPartition) - } - - if e.curPartition != nil { - err := e.curPartition.checkError() - if err != nil { - return err - } - } - - e.curPartition = newSortPartition(fields, byItemsDesc, e.keyColumns, e.keyCmpFuncs, e.spillLimit) - e.curPartition.getMemTracker().AttachTo(e.memTracker) - e.curPartition.getMemTracker().SetLabel(memory.LabelForRowChunks) - e.Unparallel.spillAction = e.curPartition.actionSpill() - if e.enableTmpStorageOnOOM { - e.curPartition.getDiskTracker().AttachTo(e.diskTracker) - e.curPartition.getDiskTracker().SetLabel(memory.LabelForRowChunks) - e.Ctx().GetSessionVars().MemTracker.FallbackOldAndSetNewAction(e.Unparallel.spillAction) - } - return nil -} - -func (e *SortExec) checkError() error { - for _, partition := range e.Unparallel.sortPartitions { - err := partition.checkError() - if err != nil { - return err - } - } - return nil -} - -func (e *SortExec) storeChunk(chk *chunk.Chunk, fields []*types.FieldType, byItemsDesc []bool) error { - err := e.curPartition.checkError() - if err != nil { - return err - } - - if !e.curPartition.add(chk) { - err := e.switchToNewSortPartition(fields, byItemsDesc, true) - if err != nil { - return err - } - - if !e.curPartition.add(chk) { - return errFailToAddChunk - } - } - return nil -} - -func (e *SortExec) handleCurrentPartitionBeforeExit() error { - err := e.checkError() - if err != nil { - return err - } - - err = e.curPartition.sort() - if err != nil { - return err - } - - return nil -} - -func (e *SortExec) fetchChunksUnparallel(ctx context.Context) error { - fields := exec.RetTypes(e) - byItemsDesc := make([]bool, len(e.ByItems)) - for i, byItem := range e.ByItems { - byItemsDesc[i] = byItem.Desc - } - - err := e.switchToNewSortPartition(fields, byItemsDesc, false) - if err != nil { - return err - } - - for { - chk := exec.TryNewCacheChunk(e.Children(0)) - err := exec.Next(ctx, e.Children(0), chk) - if err != nil { - return err - } - if chk.NumRows() == 0 { - break - } - - err = e.storeChunk(chk, fields, byItemsDesc) - if err != nil { - return err - } - - failpoint.Inject("unholdSyncLock", func(val failpoint.Value) { - if val.(bool) { - // Ensure that spill can get `syncLock`. - time.Sleep(1 * time.Millisecond) - } - }) - } - - failpoint.Inject("waitForSpill", func(val failpoint.Value) { - if val.(bool) { - // Ensure that spill is triggered before returning data. 
- time.Sleep(50 * time.Millisecond) - } - }) - - failpoint.Inject("SignalCheckpointForSort", func(val failpoint.Value) { - if val.(bool) { - if e.Ctx().GetSessionVars().ConnectionID == 123456 { - e.Ctx().GetSessionVars().MemTracker.Killer.SendKillSignal(sqlkiller.QueryMemoryExceeded) - } - } - }) - - err = e.handleCurrentPartitionBeforeExit() - if err != nil { - return err - } - - e.Unparallel.sortPartitions = append(e.Unparallel.sortPartitions, e.curPartition) - e.curPartition = nil - return nil -} - -func (e *SortExec) fetchChunksParallel(ctx context.Context) error { - // Wait for the finish of all workers - workersWaiter := util.WaitGroupWrapper{} - // Wait for the finish of chunk fetcher - fetcherWaiter := util.WaitGroupWrapper{} - - for i := range e.Parallel.workers { - e.Parallel.workers[i] = newParallelSortWorker(i, e.lessRow, e.Parallel.chunkChannel, e.Parallel.fetcherAndWorkerSyncer, e.Parallel.resultChannel, e.finishCh, e.memTracker, e.Parallel.sortedRowsIters[i], e.MaxChunkSize(), e.Parallel.spillHelper) - worker := e.Parallel.workers[i] - workersWaiter.Run(func() { - worker.run() - }) - } - - // Fetch chunks from child and put chunks into chunkChannel - fetcherWaiter.Run(func() { - e.fetchChunksFromChild(ctx) - }) - - go e.generateResult(&workersWaiter, &fetcherWaiter) - return nil -} - -func (e *SortExec) spillRemainingRowsWhenNeeded() error { - if e.Parallel.spillHelper.isSpillTriggered() { - return e.Parallel.spillHelper.spill() - } - return nil -} - -func (e *SortExec) checkSpillAndExecute() error { - if e.Parallel.spillHelper.isSpillNeeded() { - // Wait for the stop of all workers - e.Parallel.fetcherAndWorkerSyncer.Wait() - return e.Parallel.spillHelper.spill() - } - return nil -} - -// Fetch chunks from child and put chunks into chunkChannel -func (e *SortExec) fetchChunksFromChild(ctx context.Context) { - defer func() { - if r := recover(); r != nil { - processPanicAndLog(e.Parallel.resultChannel, r) - } - - e.Parallel.fetcherAndWorkerSyncer.Wait() - err := e.spillRemainingRowsWhenNeeded() - if err != nil { - e.Parallel.resultChannel <- rowWithError{err: err} - } - - failpoint.Inject("SignalCheckpointForSort", func(val failpoint.Value) { - if val.(bool) { - if e.Ctx().GetSessionVars().ConnectionID == 123456 { - e.Ctx().GetSessionVars().MemTracker.Killer.SendKillSignal(sqlkiller.QueryMemoryExceeded) - } - } - }) - - // We must place it after the spill as workers will process its received - // chunks after channel is closed and this will cause data race. 
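// The fetcher above bumps fetcherAndWorkerSyncer once per chunk before pushing it,
// and the spill path waits on that WaitGroup so no worker is still holding an
// unprocessed chunk when data is flushed to disk. A minimal sketch of that
// handshake follows; items, inFlight, and flush are illustrative names only.

package main

import (
	"fmt"
	"sync"
)

func main() {
	items := make(chan int, 4)
	inFlight := &sync.WaitGroup{}
	var mu sync.Mutex
	var buffered []int

	var workers sync.WaitGroup
	for w := 0; w < 2; w++ {
		workers.Add(1)
		go func() {
			defer workers.Done()
			for it := range items {
				mu.Lock()
				buffered = append(buffered, it)
				mu.Unlock()
				inFlight.Done() // this item is fully consumed
			}
		}()
	}

	flush := func() {
		inFlight.Wait() // wait until every pushed item has been consumed
		mu.Lock()
		fmt.Println("flushing", len(buffered), "items")
		buffered = buffered[:0]
		mu.Unlock()
	}

	for i := 0; i < 10; i++ {
		inFlight.Add(1) // account for the item before handing it to a worker
		items <- i
		if i == 4 {
			flush() // e.g. a memory-pressure checkpoint
		}
	}
	close(items)
	workers.Wait()
	flush()
}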
- close(e.Parallel.chunkChannel) - }() - - for { - chk := exec.TryNewCacheChunk(e.Children(0)) - err := exec.Next(ctx, e.Children(0), chk) - if err != nil { - e.Parallel.resultChannel <- rowWithError{err: err} - return - } - - rowCount := chk.NumRows() - if rowCount == 0 { - break - } - - chkWithMemoryUsage := &chunkWithMemoryUsage{ - Chk: chk, - MemoryUsage: chk.MemoryUsage() + chunk.RowSize*int64(rowCount), - } - - e.memTracker.Consume(chkWithMemoryUsage.MemoryUsage) - - e.Parallel.fetcherAndWorkerSyncer.Add(1) - - select { - case <-e.finishCh: - e.Parallel.fetcherAndWorkerSyncer.Done() - return - case e.Parallel.chunkChannel <- chkWithMemoryUsage: - } - - err = e.checkSpillAndExecute() - if err != nil { - e.Parallel.resultChannel <- rowWithError{err: err} - return - } - injectParallelSortRandomFail(3) - } -} - -func (e *SortExec) initCompareFuncs(ctx expression.EvalContext) error { - e.keyCmpFuncs = make([]chunk.CompareFunc, len(e.ByItems)) - for i := range e.ByItems { - keyType := e.ByItems[i].Expr.GetType(ctx) - e.keyCmpFuncs[i] = chunk.GetCompareFunc(keyType) - if e.keyCmpFuncs[i] == nil { - return errors.Errorf("Sort executor not supports type %s", types.TypeStr(keyType.GetType())) - } - } - return nil -} - -func (e *SortExec) buildKeyColumns() { - e.keyColumns = make([]int, 0, len(e.ByItems)) - for _, by := range e.ByItems { - col := by.Expr.(*expression.Column) - e.keyColumns = append(e.keyColumns, col.Index) - } -} - -func (e *SortExec) lessRow(rowI, rowJ chunk.Row) int { - for i, colIdx := range e.keyColumns { - cmpFunc := e.keyCmpFuncs[i] - cmp := cmpFunc(rowI, colIdx, rowJ, colIdx) - if e.ByItems[i].Desc { - cmp = -cmp - } - if cmp != 0 { - return cmp - } - } - return 0 -} - -func (e *SortExec) compareRow(rowI, rowJ chunk.Row) int { - for i, colIdx := range e.keyColumns { - cmpFunc := e.keyCmpFuncs[i] - cmp := cmpFunc(rowI, colIdx, rowJ, colIdx) - if e.ByItems[i].Desc { - cmp = -cmp - } - if cmp != 0 { - return cmp - } - } - return 0 -} - -// IsSpillTriggeredInParallelSortForTest tells if spill is triggered in parallel sort. -func (e *SortExec) IsSpillTriggeredInParallelSortForTest() bool { - return e.Parallel.spillHelper.isSpillTriggered() -} - -// GetSpilledRowNumInParallelSortForTest tells if spill is triggered in parallel sort. -func (e *SortExec) GetSpilledRowNumInParallelSortForTest() int64 { - totalSpilledRows := int64(0) - for _, disk := range e.Parallel.spillHelper.sortedRowsInDisk { - totalSpilledRows += disk.NumRows() - } - return totalSpilledRows -} - -// IsSpillTriggeredInOnePartitionForTest tells if spill is triggered in a specific partition, it's only used in test. -func (e *SortExec) IsSpillTriggeredInOnePartitionForTest(idx int) bool { - return e.Unparallel.sortPartitions[idx].isSpillTriggered() -} - -// GetRowNumInOnePartitionDiskForTest returns number of rows a partition holds in disk, it's only used in test. -func (e *SortExec) GetRowNumInOnePartitionDiskForTest(idx int) int64 { - return e.Unparallel.sortPartitions[idx].numRowInDiskForTest() -} - -// GetRowNumInOnePartitionMemoryForTest returns number of rows a partition holds in memory, it's only used in test. -func (e *SortExec) GetRowNumInOnePartitionMemoryForTest(idx int) int64 { - return e.Unparallel.sortPartitions[idx].numRowInMemoryForTest() -} - -// GetSortPartitionListLenForTest returns the number of partitions, it's only used in test. 
-func (e *SortExec) GetSortPartitionListLenForTest() int { - return len(e.Unparallel.sortPartitions) -} - -// GetSortMetaForTest returns some sort meta, it's only used in test. -func (e *SortExec) GetSortMetaForTest() (keyColumns []int, keyCmpFuncs []chunk.CompareFunc, byItemsDesc []bool) { - keyColumns = e.keyColumns - keyCmpFuncs = e.keyCmpFuncs - byItemsDesc = make([]bool, len(e.ByItems)) - for i, byItem := range e.ByItems { - byItemsDesc[i] = byItem.Desc - } - return -} diff --git a/pkg/executor/sortexec/sort_partition.go b/pkg/executor/sortexec/sort_partition.go index 82e1e3ee07822..7c798f6385e95 100644 --- a/pkg/executor/sortexec/sort_partition.go +++ b/pkg/executor/sortexec/sort_partition.go @@ -141,11 +141,11 @@ func (s *sortPartition) sortNoLock() (ret error) { return } - if val, _err_ := failpoint.Eval(_curpkg_("errorDuringSortRowContainer")); _err_ == nil { + failpoint.Inject("errorDuringSortRowContainer", func(val failpoint.Value) { if val.(bool) { panic("sort meet error") } - } + }) sort.Slice(s.savedRows, s.keyColumnsLess) s.isSorted = true @@ -297,11 +297,11 @@ func (s *sortPartition) keyColumnsLess(i, j int) bool { s.timesOfRowCompare = 0 } - if val, _err_ := failpoint.Eval(_curpkg_("SignalCheckpointForSort")); _err_ == nil { + failpoint.Inject("SignalCheckpointForSort", func(val failpoint.Value) { if val.(bool) { s.timesOfRowCompare += 1024 } - } + }) s.timesOfRowCompare++ return s.lessRow(s.savedRows[i], s.savedRows[j]) diff --git a/pkg/executor/sortexec/sort_partition.go__failpoint_stash__ b/pkg/executor/sortexec/sort_partition.go__failpoint_stash__ deleted file mode 100644 index 7c798f6385e95..0000000000000 --- a/pkg/executor/sortexec/sort_partition.go__failpoint_stash__ +++ /dev/null @@ -1,367 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sortexec - -import ( - "sort" - "sync" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/disk" - "github.com/pingcap/tidb/pkg/util/memory" -) - -type sortPartition struct { - // cond is only used for protecting spillStatus - cond *sync.Cond - spillStatus int - - // syncLock is used to protect variables except `spillStatus` - syncLock sync.Mutex - - // Data are stored in savedRows - savedRows []chunk.Row - sliceIter *chunk.Iterator4Slice - isSorted bool - - // cursor iterates the spilled chunks. - cursor *dataCursor - inDisk *chunk.DataInDiskByChunks - - spillError error - closed bool - - fieldTypes []*types.FieldType - - memTracker *memory.Tracker - diskTracker *disk.Tracker - spillAction *sortPartitionSpillDiskAction - - // We can't spill if size of data is lower than the limit - spillLimit int64 - - byItemsDesc []bool - // keyColumns is the column index of the by items. - keyColumns []int - // keyCmpFuncs is used to compare each ByItem. 
- keyCmpFuncs []chunk.CompareFunc - - // Sort is a time-consuming operation, we need to set a checkpoint to detect - // the outside signal periodically. - timesOfRowCompare uint -} - -// Creates a new SortPartition in memory. -func newSortPartition(fieldTypes []*types.FieldType, byItemsDesc []bool, - keyColumns []int, keyCmpFuncs []chunk.CompareFunc, spillLimit int64) *sortPartition { - lock := new(sync.Mutex) - retVal := &sortPartition{ - cond: sync.NewCond(lock), - spillError: nil, - spillStatus: notSpilled, - fieldTypes: fieldTypes, - savedRows: make([]chunk.Row, 0), - isSorted: false, - inDisk: nil, // It's initialized only when spill is triggered - memTracker: memory.NewTracker(memory.LabelForSortPartition, -1), - diskTracker: disk.NewTracker(memory.LabelForSortPartition, -1), - spillAction: nil, // It's set in `actionSpill` function - spillLimit: spillLimit, - byItemsDesc: byItemsDesc, - keyColumns: keyColumns, - keyCmpFuncs: keyCmpFuncs, - cursor: NewDataCursor(), - closed: false, - } - - return retVal -} - -func (s *sortPartition) close() { - s.syncLock.Lock() - defer s.syncLock.Unlock() - s.closed = true - if s.inDisk != nil { - s.inDisk.Close() - } - s.getMemTracker().ReplaceBytesUsed(0) -} - -// Return false if the spill is triggered in this partition. -func (s *sortPartition) add(chk *chunk.Chunk) bool { - rowNum := chk.NumRows() - consumedBytesNum := chunk.RowSize*int64(rowNum) + chk.MemoryUsage() - - s.syncLock.Lock() - defer s.syncLock.Unlock() - if s.isSpillTriggered() { - return false - } - - // Convert chunk to rows - for i := 0; i < rowNum; i++ { - s.savedRows = append(s.savedRows, chk.GetRow(i)) - } - - s.getMemTracker().Consume(consumedBytesNum) - return true -} - -func (s *sortPartition) sort() error { - s.syncLock.Lock() - defer s.syncLock.Unlock() - return s.sortNoLock() -} - -func (s *sortPartition) sortNoLock() (ret error) { - ret = nil - defer func() { - if r := recover(); r != nil { - ret = util.GetRecoverError(r) - } - }() - - if s.isSorted { - return - } - - failpoint.Inject("errorDuringSortRowContainer", func(val failpoint.Value) { - if val.(bool) { - panic("sort meet error") - } - }) - - sort.Slice(s.savedRows, s.keyColumnsLess) - s.isSorted = true - s.sliceIter = chunk.NewIterator4Slice(s.savedRows) - return -} - -func (s *sortPartition) spillToDiskImpl() (err error) { - defer func() { - if r := recover(); r != nil { - err = util.GetRecoverError(r) - } - }() - - if s.closed { - return nil - } - - s.inDisk = chunk.NewDataInDiskByChunks(s.fieldTypes) - s.inDisk.GetDiskTracker().AttachTo(s.diskTracker) - tmpChk := chunk.NewChunkWithCapacity(s.fieldTypes, spillChunkSize) - - rowNum := len(s.savedRows) - if rowNum == 0 { - return errSpillEmptyChunk - } - - for row := s.sliceIter.Next(); !row.IsEmpty(); row = s.sliceIter.Next() { - tmpChk.AppendRow(row) - if tmpChk.IsFull() { - err := s.inDisk.Add(tmpChk) - if err != nil { - return err - } - tmpChk.Reset() - s.getMemTracker().HandleKillSignal() - } - } - - // Spill the remaining data in tmpChk. - // Do not spill when tmpChk is empty as `Add` function requires a non-empty chunk - if tmpChk.NumRows() > 0 { - err := s.inDisk.Add(tmpChk) - if err != nil { - return err - } - } - - // Release memory as all data have been spilled to disk - s.savedRows = nil - s.sliceIter = nil - s.getMemTracker().ReplaceBytesUsed(0) - return nil -} - -// We can only call this function under the protection of `syncLock`. 
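// spillToDisk above drives a small state machine (notSpilled -> inSpilling ->
// spillTriggered) guarded by a sync.Cond and broadcasts once the spill finishes.
// Below is a hedged, self-contained sketch of that shape; waitUntilSpilled is an
// illustrative reader added only to make the Broadcast observable, it is not part
// of the code shown above.

package main

import (
	"fmt"
	"sync"
	"time"
)

const (
	notSpilled = iota
	inSpilling
	spillTriggered
)

type partition struct {
	cond   *sync.Cond
	status int
}

func newPartition() *partition {
	return &partition{cond: sync.NewCond(new(sync.Mutex)), status: notSpilled}
}

func (p *partition) spill() {
	p.cond.L.Lock()
	p.status = inSpilling
	p.cond.L.Unlock()

	time.Sleep(10 * time.Millisecond) // pretend to write sorted chunks to disk

	p.cond.L.Lock()
	p.status = spillTriggered
	p.cond.L.Unlock()
	p.cond.Broadcast() // wake up anyone waiting for the spill to finish
}

func (p *partition) waitUntilSpilled() {
	p.cond.L.Lock()
	for p.status != spillTriggered {
		p.cond.Wait()
	}
	p.cond.L.Unlock()
}

func main() {
	p := newPartition()
	go p.spill()
	p.waitUntilSpilled()
	fmt.Println("spill finished")
}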
-func (s *sortPartition) spillToDisk() error { - s.syncLock.Lock() - defer s.syncLock.Unlock() - if s.isSpillTriggered() { - return nil - } - - err := s.sortNoLock() - if err != nil { - return err - } - - s.setIsSpilling() - defer s.cond.Broadcast() - defer s.setSpillTriggered() - - err = s.spillToDiskImpl() - return err -} - -func (s *sortPartition) getNextSortedRow() (chunk.Row, error) { - s.syncLock.Lock() - defer s.syncLock.Unlock() - if s.isSpillTriggered() { - row := s.cursor.next() - if row.IsEmpty() { - success, err := reloadCursor(s.cursor, s.inDisk) - if err != nil { - return chunk.Row{}, err - } - if !success { - // All data has been consumed - return chunk.Row{}, nil - } - - row = s.cursor.begin() - if row.IsEmpty() { - return chunk.Row{}, errors.New("Get an empty row") - } - } - return row, nil - } - - row := s.sliceIter.Next() - return row, nil -} - -func (s *sortPartition) actionSpill() *sortPartitionSpillDiskAction { - if s.spillAction == nil { - s.spillAction = &sortPartitionSpillDiskAction{ - partition: s, - } - } - return s.spillAction -} - -func (s *sortPartition) getMemTracker() *memory.Tracker { - return s.memTracker -} - -func (s *sortPartition) getDiskTracker() *disk.Tracker { - return s.diskTracker -} - -func (s *sortPartition) hasEnoughDataToSpill() bool { - // Guarantee that each partition size is not too small, to avoid opening too many files. - return s.getMemTracker().BytesConsumed() > s.spillLimit -} - -func (s *sortPartition) lessRow(rowI, rowJ chunk.Row) bool { - for i, colIdx := range s.keyColumns { - cmpFunc := s.keyCmpFuncs[i] - if cmpFunc != nil { - cmp := cmpFunc(rowI, colIdx, rowJ, colIdx) - if s.byItemsDesc[i] { - cmp = -cmp - } - if cmp < 0 { - return true - } else if cmp > 0 { - return false - } - } - } - return false -} - -// keyColumnsLess is the less function for key columns. -func (s *sortPartition) keyColumnsLess(i, j int) bool { - if s.timesOfRowCompare >= signalCheckpointForSort { - // Trigger Consume for checking the NeedKill signal - s.memTracker.HandleKillSignal() - s.timesOfRowCompare = 0 - } - - failpoint.Inject("SignalCheckpointForSort", func(val failpoint.Value) { - if val.(bool) { - s.timesOfRowCompare += 1024 - } - }) - - s.timesOfRowCompare++ - return s.lessRow(s.savedRows[i], s.savedRows[j]) -} - -func (s *sortPartition) isSpillTriggered() bool { - s.cond.L.Lock() - defer s.cond.L.Unlock() - return s.spillStatus == spillTriggered -} - -func (s *sortPartition) isSpillTriggeredNoLock() bool { - return s.spillStatus == spillTriggered -} - -func (s *sortPartition) setSpillTriggered() { - s.cond.L.Lock() - defer s.cond.L.Unlock() - s.spillStatus = spillTriggered -} - -func (s *sortPartition) setIsSpilling() { - s.cond.L.Lock() - defer s.cond.L.Unlock() - s.spillStatus = inSpilling -} - -func (s *sortPartition) getIsSpillingNoLock() bool { - return s.spillStatus == inSpilling -} - -func (s *sortPartition) setError(err error) { - s.syncLock.Lock() - defer s.syncLock.Unlock() - s.spillError = err -} - -func (s *sortPartition) checkError() error { - s.syncLock.Lock() - defer s.syncLock.Unlock() - return s.spillError -} - -func (s *sortPartition) numRowInDiskForTest() int64 { - if s.inDisk != nil { - return s.inDisk.NumRows() - } - return 0 -} - -func (s *sortPartition) numRowInMemoryForTest() int64 { - if s.sliceIter != nil { - if s.sliceIter.Len() != len(s.savedRows) { - panic("length of sliceIter should be equal to savedRows") - } - } - return int64(len(s.savedRows)) -} - -// SetSmallSpillChunkSizeForTest set spill chunk size for test. 
-func SetSmallSpillChunkSizeForTest() { - spillChunkSize = 16 -} diff --git a/pkg/executor/sortexec/sort_util.go b/pkg/executor/sortexec/sort_util.go index 8a4aaee944499..59ef17f90da2c 100644 --- a/pkg/executor/sortexec/sort_util.go +++ b/pkg/executor/sortexec/sort_util.go @@ -66,14 +66,14 @@ type rowWithError struct { } func injectParallelSortRandomFail(triggerFactor int32) { - if val, _err_ := failpoint.Eval(_curpkg_("ParallelSortRandomFail")); _err_ == nil { + failpoint.Inject("ParallelSortRandomFail", func(val failpoint.Value) { if val.(bool) { randNum := rand.Int31n(10000) if randNum < triggerFactor { panic("panic is triggered by random fail") } } - } + }) } // It's used only when spill is triggered diff --git a/pkg/executor/sortexec/sort_util.go__failpoint_stash__ b/pkg/executor/sortexec/sort_util.go__failpoint_stash__ deleted file mode 100644 index 59ef17f90da2c..0000000000000 --- a/pkg/executor/sortexec/sort_util.go__failpoint_stash__ +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sortexec - -import ( - "math/rand" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/logutil" - "go.uber.org/zap" -) - -var errSpillEmptyChunk = errors.New("can not spill empty chunk to disk") -var errFailToAddChunk = errors.New("fail to add chunk") - -// It should be const, but we need to modify it for test. -var spillChunkSize = 1024 - -// signalCheckpointForSort indicates the times of row comparation that a signal detection will be triggered. -const signalCheckpointForSort uint = 10240 - -const ( - notSpilled = iota - needSpill - inSpilling - spillTriggered -) - -type rowWithPartition struct { - row chunk.Row - partitionID int -} - -func processPanicAndLog(errOutputChan chan<- rowWithError, r any) { - err := util.GetRecoverError(r) - errOutputChan <- rowWithError{err: err} - logutil.BgLogger().Error("executor panicked", zap.Error(err), zap.Stack("stack")) -} - -// chunkWithMemoryUsage contains chunk and memory usage. -// However, some of memory usage may also come from other place, -// not only the chunk's memory usage. 
-type chunkWithMemoryUsage struct { - Chk *chunk.Chunk - MemoryUsage int64 -} - -type rowWithError struct { - row chunk.Row - err error -} - -func injectParallelSortRandomFail(triggerFactor int32) { - failpoint.Inject("ParallelSortRandomFail", func(val failpoint.Value) { - if val.(bool) { - randNum := rand.Int31n(10000) - if randNum < triggerFactor { - panic("panic is triggered by random fail") - } - } - }) -} - -// It's used only when spill is triggered -type dataCursor struct { - chkID int - chunkIter *chunk.Iterator4Chunk -} - -// NewDataCursor creates a new dataCursor -func NewDataCursor() *dataCursor { - return &dataCursor{ - chkID: -1, - chunkIter: chunk.NewIterator4Chunk(nil), - } -} - -func (d *dataCursor) getChkID() int { - return d.chkID -} - -func (d *dataCursor) begin() chunk.Row { - return d.chunkIter.Begin() -} - -func (d *dataCursor) next() chunk.Row { - return d.chunkIter.Next() -} - -func (d *dataCursor) setChunk(chk *chunk.Chunk, chkID int) { - d.chkID = chkID - d.chunkIter.ResetChunk(chk) -} - -func reloadCursor(cursor *dataCursor, inDisk *chunk.DataInDiskByChunks) (bool, error) { - spilledChkNum := inDisk.NumChunks() - restoredChkID := cursor.getChkID() + 1 - if restoredChkID >= spilledChkNum { - // All data has been consumed - return false, nil - } - - chk, err := inDisk.GetChunk(restoredChkID) - if err != nil { - return false, err - } - cursor.setChunk(chk, restoredChkID) - return true, nil -} diff --git a/pkg/executor/sortexec/topn.go b/pkg/executor/sortexec/topn.go index a3e48c340e847..5d49cb80703d4 100644 --- a/pkg/executor/sortexec/topn.go +++ b/pkg/executor/sortexec/topn.go @@ -613,14 +613,14 @@ func (e *TopNExec) GetInMemoryThenSpillFlagForTest() bool { } func injectTopNRandomFail(triggerFactor int32) { - if val, _err_ := failpoint.Eval(_curpkg_("TopNRandomFail")); _err_ == nil { + failpoint.Inject("TopNRandomFail", func(val failpoint.Value) { if val.(bool) { randNum := rand.Int31n(10000) if randNum < triggerFactor { panic("panic is triggered by random fail") } } - } + }) } // InitTopNExecForTest initializes TopN executors, only for test. diff --git a/pkg/executor/sortexec/topn.go__failpoint_stash__ b/pkg/executor/sortexec/topn.go__failpoint_stash__ deleted file mode 100644 index 5d49cb80703d4..0000000000000 --- a/pkg/executor/sortexec/topn.go__failpoint_stash__ +++ /dev/null @@ -1,647 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
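// The hunks above revert the generated `failpoint.Eval(_curpkg_(...))` form back to
// the `failpoint.Inject` marker that belongs in the checked-in sources, and the
// *__failpoint_stash__ files being deleted are the backups left behind by that code
// generation. A minimal sketch of the marker idiom follows; the failpoint name and
// the Enable path mentioned in the comment are illustrative placeholders, and the
// marker stays a no-op unless the sources are rewritten by the failpoint tooling.

package main

import (
	"fmt"

	"github.com/pingcap/failpoint"
)

func doWork() int {
	result := 42
	// No-op in normal builds; when the failpoint machinery is activated and a test
	// calls failpoint.Enable("main/mockWorkFailure", "return(true)"), the callback
	// runs with val == true and can perturb the normal path.
	failpoint.Inject("mockWorkFailure", func(val failpoint.Value) {
		if val.(bool) {
			result = -1
		}
	})
	return result
}

func main() {
	fmt.Println(doWork())
}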
- -package sortexec - -import ( - "container/heap" - "context" - "math/rand" - "slices" - "sync" - "sync/atomic" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/channel" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/disk" - "github.com/pingcap/tidb/pkg/util/memory" -) - -// TopNExec implements a Top-N algorithm and it is built from a SELECT statement with ORDER BY and LIMIT. -// Instead of sorting all the rows fetched from the table, it keeps the Top-N elements only in a heap to reduce memory usage. -type TopNExec struct { - SortExec - Limit *plannercore.PhysicalLimit - - // It's useful when spill is triggered and the fetcher could know when workers finish their works. - fetcherAndWorkerSyncer *sync.WaitGroup - resultChannel chan rowWithError - chunkChannel chan *chunk.Chunk - - finishCh chan struct{} - - chkHeap *topNChunkHeap - - spillHelper *topNSpillHelper - spillAction *topNSpillAction - - // Normally, heap will be stored in memory after it has been built. - // However, other executors may trigger topn spill after the heap is built - // and inMemoryThenSpillFlag will be set to true at this time. - inMemoryThenSpillFlag bool - - // Topn executor has two stage: - // 1. Building heap, in this stage all received rows will be inserted into heap. - // 2. Updating heap, in this stage only rows that is smaller than the heap top could be inserted and we will drop the heap top. - // - // This variable is only used for test. - isSpillTriggeredInStage1ForTest bool - isSpillTriggeredInStage2ForTest bool - - Concurrency int -} - -// Open implements the Executor Open interface. 
-func (e *TopNExec) Open(ctx context.Context) error { - e.memTracker = memory.NewTracker(e.ID(), -1) - e.memTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.MemTracker) - - e.fetched = &atomic.Bool{} - e.fetched.Store(false) - e.chkHeap = &topNChunkHeap{memTracker: e.memTracker} - e.chkHeap.idx = 0 - - e.finishCh = make(chan struct{}, 1) - e.resultChannel = make(chan rowWithError, e.MaxChunkSize()) - e.chunkChannel = make(chan *chunk.Chunk, e.Concurrency) - e.inMemoryThenSpillFlag = false - e.isSpillTriggeredInStage1ForTest = false - e.isSpillTriggeredInStage2ForTest = false - - if variable.EnableTmpStorageOnOOM.Load() { - e.diskTracker = disk.NewTracker(e.ID(), -1) - diskTracker := e.Ctx().GetSessionVars().StmtCtx.DiskTracker - if diskTracker != nil { - e.diskTracker.AttachTo(diskTracker) - } - e.fetcherAndWorkerSyncer = &sync.WaitGroup{} - - workers := make([]*topNWorker, e.Concurrency) - for i := range workers { - chkHeap := &topNChunkHeap{} - // Offset of heap in worker should be 0, as we need to spill all data - chkHeap.init(e, e.memTracker, e.Limit.Offset+e.Limit.Count, 0, e.greaterRow, e.RetFieldTypes()) - workers[i] = newTopNWorker(i, e.chunkChannel, e.fetcherAndWorkerSyncer, e.resultChannel, e.finishCh, e, chkHeap, e.memTracker) - } - - e.spillHelper = newTopNSpillerHelper( - e, - e.finishCh, - e.resultChannel, - e.memTracker, - e.diskTracker, - exec.RetTypes(e), - workers, - e.Concurrency, - ) - e.spillAction = &topNSpillAction{spillHelper: e.spillHelper} - e.Ctx().GetSessionVars().MemTracker.FallbackOldAndSetNewAction(e.spillAction) - } else { - e.spillHelper = newTopNSpillerHelper(e, nil, nil, nil, nil, nil, nil, 0) - } - - return exec.Open(ctx, e.Children(0)) -} - -// Close implements the Executor Close interface. -func (e *TopNExec) Close() error { - // `e.finishCh == nil` means that `Open` is not called. - if e.finishCh == nil { - return exec.Close(e.Children(0)) - } - - close(e.finishCh) - if e.fetched.CompareAndSwap(false, true) { - close(e.resultChannel) - return exec.Close(e.Children(0)) - } - - // Wait for the finish of all tasks - channel.Clear(e.resultChannel) - - e.chkHeap = nil - e.spillAction = nil - - if e.spillHelper != nil { - e.spillHelper.close() - e.spillHelper = nil - } - - if e.memTracker != nil { - e.memTracker.ReplaceBytesUsed(0) - } - - return exec.Close(e.Children(0)) -} - -func (e *TopNExec) greaterRow(rowI, rowJ chunk.Row) bool { - for i, colIdx := range e.keyColumns { - cmpFunc := e.keyCmpFuncs[i] - cmp := cmpFunc(rowI, colIdx, rowJ, colIdx) - if e.ByItems[i].Desc { - cmp = -cmp - } - if cmp > 0 { - return true - } else if cmp < 0 { - return false - } - } - return false -} - -// Next implements the Executor Next interface. -// -// The following picture shows the procedure of topn when spill is triggered. -/* -Spill Stage: - ┌─────────┐ - │ Child │ - └────▲────┘ - │ - Fetch - │ - ┌───────┴───────┐ - │ Chunk Fetcher │ - └───────┬───────┘ - │ - │ - ▼ - Check Spill──────►Spill Triggered─────────►Spill - │ │ - ▼ │ - Spill Not Triggered │ - │ │ - ▼ │ - Push Chunk◄─────────────────────────────────┘ - │ - ▼ - ┌────────────────►Channel◄───────────────────┐ - │ ▲ │ - │ │ │ - Fetch Fetch Fetch - │ │ │ - ┌────┴───┐ ┌───┴────┐ ┌───┴────┐ - │ Worker │ │ Worker │ ...... │ Worker │ - └────┬───┘ └───┬────┘ └───┬────┘ - │ │ │ - │ │ │ - │ ▼ │ - └───────────► Multi-way Merge◄───────────────┘ - │ - │ - ▼ - Output - -Restore Stage: - ┌────────┐ ┌────────┐ ┌────────┐ - │ Heap │ │ Heap │ ...... 
│ Heap │ - └────┬───┘ └───┬────┘ └───┬────┘ - │ │ │ - │ │ │ - │ ▼ │ - └───────────► Multi-way Merge◄───────────────┘ - │ - │ - ▼ - Output - -*/ -func (e *TopNExec) Next(ctx context.Context, req *chunk.Chunk) error { - req.Reset() - if e.fetched.CompareAndSwap(false, true) { - err := e.fetchChunks(ctx) - if err != nil { - return err - } - } - - if !req.IsFull() { - numToAppend := req.RequiredRows() - req.NumRows() - for i := 0; i < numToAppend; i++ { - row, ok := <-e.resultChannel - if !ok || row.err != nil { - return row.err - } - req.AppendRow(row.row) - } - } - return nil -} - -func (e *TopNExec) fetchChunks(ctx context.Context) error { - defer func() { - if r := recover(); r != nil { - processPanicAndLog(e.resultChannel, r) - close(e.resultChannel) - } - }() - - err := e.loadChunksUntilTotalLimit(ctx) - if err != nil { - close(e.resultChannel) - return err - } - go e.executeTopN(ctx) - return nil -} - -func (e *TopNExec) loadChunksUntilTotalLimit(ctx context.Context) error { - err := e.initCompareFuncs(e.Ctx().GetExprCtx().GetEvalCtx()) - if err != nil { - return err - } - - e.buildKeyColumns() - e.chkHeap.init(e, e.memTracker, e.Limit.Offset+e.Limit.Count, int(e.Limit.Offset), e.greaterRow, e.RetFieldTypes()) - for uint64(e.chkHeap.rowChunks.Len()) < e.chkHeap.totalLimit { - srcChk := exec.TryNewCacheChunk(e.Children(0)) - // adjust required rows by total limit - srcChk.SetRequiredRows(int(e.chkHeap.totalLimit-uint64(e.chkHeap.rowChunks.Len())), e.MaxChunkSize()) - err := exec.Next(ctx, e.Children(0), srcChk) - if err != nil { - return err - } - if srcChk.NumRows() == 0 { - break - } - e.chkHeap.rowChunks.Add(srcChk) - if e.spillHelper.isSpillNeeded() { - e.isSpillTriggeredInStage1ForTest = true - break - } - - injectTopNRandomFail(1) - } - - e.chkHeap.initPtrs() - return nil -} - -const topNCompactionFactor = 4 - -func (e *TopNExec) executeTopNWhenNoSpillTriggered(ctx context.Context) error { - if e.spillHelper.isSpillNeeded() { - e.isSpillTriggeredInStage2ForTest = true - return nil - } - - childRowChk := exec.TryNewCacheChunk(e.Children(0)) - for { - if e.spillHelper.isSpillNeeded() { - e.isSpillTriggeredInStage2ForTest = true - return nil - } - - err := exec.Next(ctx, e.Children(0), childRowChk) - if err != nil { - return err - } - - if childRowChk.NumRows() == 0 { - break - } - - e.chkHeap.processChk(childRowChk) - - if e.chkHeap.rowChunks.Len() > len(e.chkHeap.rowPtrs)*topNCompactionFactor { - err = e.chkHeap.doCompaction(e) - if err != nil { - return err - } - } - injectTopNRandomFail(10) - } - - slices.SortFunc(e.chkHeap.rowPtrs, e.chkHeap.keyColumnsCompare) - return nil -} - -func (e *TopNExec) spillRemainingRowsWhenNeeded() error { - if e.spillHelper.isSpillTriggered() { - return e.spillHelper.spill() - } - return nil -} - -func (e *TopNExec) checkSpillAndExecute() error { - if e.spillHelper.isSpillNeeded() { - // Wait for the stop of all workers - e.fetcherAndWorkerSyncer.Wait() - return e.spillHelper.spill() - } - return nil -} - -func (e *TopNExec) fetchChunksFromChild(ctx context.Context) { - defer func() { - if r := recover(); r != nil { - processPanicAndLog(e.resultChannel, r) - } - - e.fetcherAndWorkerSyncer.Wait() - err := e.spillRemainingRowsWhenNeeded() - if err != nil { - e.resultChannel <- rowWithError{err: err} - } - - close(e.chunkChannel) - }() - - for { - chk := exec.TryNewCacheChunk(e.Children(0)) - err := exec.Next(ctx, e.Children(0), chk) - if err != nil { - e.resultChannel <- rowWithError{err: err} - return - } - - rowCount := chk.NumRows() - if rowCount 
== 0 { - break - } - - e.fetcherAndWorkerSyncer.Add(1) - select { - case <-e.finishCh: - e.fetcherAndWorkerSyncer.Done() - return - case e.chunkChannel <- chk: - } - - injectTopNRandomFail(10) - - err = e.checkSpillAndExecute() - if err != nil { - e.resultChannel <- rowWithError{err: err} - return - } - } -} - -// Spill the heap which is in TopN executor -func (e *TopNExec) spillTopNExecHeap() error { - e.spillHelper.setInSpilling() - defer e.spillHelper.cond.Broadcast() - defer e.spillHelper.setNotSpilled() - - err := e.spillHelper.spillHeap(e.chkHeap) - if err != nil { - return err - } - return nil -} - -func (e *TopNExec) executeTopNWhenSpillTriggered(ctx context.Context) error { - // idx need to be set to 0 as we need to spill all data - e.chkHeap.idx = 0 - err := e.spillTopNExecHeap() - if err != nil { - return err - } - - // Wait for the finish of chunk fetcher - fetcherWaiter := util.WaitGroupWrapper{} - // Wait for the finish of all workers - workersWaiter := util.WaitGroupWrapper{} - - for i := range e.spillHelper.workers { - worker := e.spillHelper.workers[i] - worker.initWorker() - workersWaiter.Run(func() { - worker.run() - }) - } - - // Fetch chunks from child and put chunks into chunkChannel - fetcherWaiter.Run(func() { - e.fetchChunksFromChild(ctx) - }) - - fetcherWaiter.Wait() - workersWaiter.Wait() - return nil -} - -func (e *TopNExec) executeTopN(ctx context.Context) { - defer func() { - if r := recover(); r != nil { - processPanicAndLog(e.resultChannel, r) - } - - close(e.resultChannel) - }() - - heap.Init(e.chkHeap) - for uint64(len(e.chkHeap.rowPtrs)) > e.chkHeap.totalLimit { - // The number of rows we loaded may exceeds total limit, remove greatest rows by Pop. - heap.Pop(e.chkHeap) - } - - if err := e.executeTopNWhenNoSpillTriggered(ctx); err != nil { - e.resultChannel <- rowWithError{err: err} - return - } - - if e.spillHelper.isSpillNeeded() { - if err := e.executeTopNWhenSpillTriggered(ctx); err != nil { - e.resultChannel <- rowWithError{err: err} - return - } - } - - e.generateTopNResults() -} - -// Return true when spill is triggered -func (e *TopNExec) generateTopNResultsWhenNoSpillTriggered() bool { - rowPtrNum := len(e.chkHeap.rowPtrs) - for ; e.chkHeap.idx < rowPtrNum; e.chkHeap.idx++ { - if e.chkHeap.idx%10 == 0 && e.spillHelper.isSpillNeeded() { - return true - } - e.resultChannel <- rowWithError{row: e.chkHeap.rowChunks.GetRow(e.chkHeap.rowPtrs[e.chkHeap.idx])} - } - return false -} - -func (e *TopNExec) generateResultWithMultiWayMerge(offset int64, limit int64) error { - multiWayMerge := newMultiWayMerger(&diskSource{sortedRowsInDisk: e.spillHelper.sortedRowsInDisk}, e.lessRow) - - err := multiWayMerge.init() - if err != nil { - return err - } - - outputRowNum := int64(0) - for { - if outputRowNum >= limit { - return nil - } - - row, err := multiWayMerge.next() - if err != nil { - return err - } - - if row.IsEmpty() { - return nil - } - - if outputRowNum >= offset { - select { - case <-e.finishCh: - return nil - case e.resultChannel <- rowWithError{row: row}: - } - } - outputRowNum++ - injectParallelSortRandomFail(1) - } -} - -// GenerateTopNResultsWhenSpillOnlyOnce generates results with this function when we trigger spill only once. -// It's a public function as we need to test it in ut. 
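// generateResultWithMultiWayMerge above applies LIMIT/OFFSET while streaming merged
// rows: rows before the offset are counted but dropped, and emission stops once the
// absolute end position (offset + count) is reached. A small sketch of that
// windowing over a plain iterator; next, emit, and emitWindow are illustrative.

package main

import "fmt"

// emitWindow pulls rows from next() and forwards only rows offset..limit-1 to emit,
// where limit is the absolute end position (offset + count), as in the code above.
func emitWindow(next func() (int, bool), emit func(int), offset, limit int64) {
	var seen int64
	for seen < limit {
		row, ok := next()
		if !ok {
			return
		}
		if seen >= offset {
			emit(row)
		}
		seen++
	}
}

func main() {
	rows := []int{1, 2, 3, 4, 5, 6, 7, 8}
	i := 0
	next := func() (int, bool) {
		if i >= len(rows) {
			return 0, false
		}
		v := rows[i]
		i++
		return v, true
	}
	emitWindow(next, func(v int) { fmt.Println(v) }, 2, 6)
	// Prints 3, 4, 5, 6 — i.e. OFFSET 2 LIMIT 4 in SQL terms.
}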
-func (e *TopNExec) GenerateTopNResultsWhenSpillOnlyOnce() error { - inDisk := e.spillHelper.sortedRowsInDisk[0] - chunkNum := inDisk.NumChunks() - skippedRowNum := uint64(0) - offset := e.Limit.Offset - for i := 0; i < chunkNum; i++ { - chk, err := inDisk.GetChunk(i) - if err != nil { - return err - } - - injectTopNRandomFail(10) - - rowNum := chk.NumRows() - j := 0 - if !e.inMemoryThenSpillFlag { - // When e.inMemoryThenSpillFlag == false, we need to manually set j - // because rows that should be ignored before offset have also been - // spilled to disk. - if skippedRowNum < offset { - rowNumNeedSkip := offset - skippedRowNum - if rowNum <= int(rowNumNeedSkip) { - // All rows in this chunk should be skipped - skippedRowNum += uint64(rowNum) - continue - } - j += int(rowNumNeedSkip) - skippedRowNum += rowNumNeedSkip - } - } - - for ; j < rowNum; j++ { - select { - case <-e.finishCh: - return nil - case e.resultChannel <- rowWithError{row: chk.GetRow(j)}: - } - } - } - return nil -} - -func (e *TopNExec) generateTopNResultsWhenSpillTriggered() error { - inDiskNum := len(e.spillHelper.sortedRowsInDisk) - if inDiskNum == 0 { - panic("inDiskNum can't be 0 when we generate result with spill triggered") - } - - if inDiskNum == 1 { - return e.GenerateTopNResultsWhenSpillOnlyOnce() - } - return e.generateResultWithMultiWayMerge(int64(e.Limit.Offset), int64(e.Limit.Offset+e.Limit.Count)) -} - -func (e *TopNExec) generateTopNResults() { - if !e.spillHelper.isSpillTriggered() { - if !e.generateTopNResultsWhenNoSpillTriggered() { - return - } - - err := e.spillTopNExecHeap() - if err != nil { - e.resultChannel <- rowWithError{err: err} - } - - e.inMemoryThenSpillFlag = true - } - - err := e.generateTopNResultsWhenSpillTriggered() - if err != nil { - e.resultChannel <- rowWithError{err: err} - } -} - -// IsSpillTriggeredForTest shows if spill is triggered, used for test. -func (e *TopNExec) IsSpillTriggeredForTest() bool { - return e.spillHelper.isSpillTriggered() -} - -// GetIsSpillTriggeredInStage1ForTest shows if spill is triggered in stage 1, only used for test. -func (e *TopNExec) GetIsSpillTriggeredInStage1ForTest() bool { - return e.isSpillTriggeredInStage1ForTest -} - -// GetIsSpillTriggeredInStage2ForTest shows if spill is triggered in stage 2, only used for test. -func (e *TopNExec) GetIsSpillTriggeredInStage2ForTest() bool { - return e.isSpillTriggeredInStage2ForTest -} - -// GetInMemoryThenSpillFlagForTest shows if results are in memory before they are spilled, only used for test -func (e *TopNExec) GetInMemoryThenSpillFlagForTest() bool { - return e.inMemoryThenSpillFlag -} - -func injectTopNRandomFail(triggerFactor int32) { - failpoint.Inject("TopNRandomFail", func(val failpoint.Value) { - if val.(bool) { - randNum := rand.Int31n(10000) - if randNum < triggerFactor { - panic("panic is triggered by random fail") - } - } - }) -} - -// InitTopNExecForTest initializes TopN executors, only for test. -func InitTopNExecForTest(topnExec *TopNExec, offset uint64, sortedRowsInDisk *chunk.DataInDiskByChunks) { - topnExec.inMemoryThenSpillFlag = false - topnExec.finishCh = make(chan struct{}, 1) - topnExec.resultChannel = make(chan rowWithError, 10000) - topnExec.Limit.Offset = offset - topnExec.spillHelper = &topNSpillHelper{} - topnExec.spillHelper.sortedRowsInDisk = []*chunk.DataInDiskByChunks{sortedRowsInDisk} -} - -// GetResultForTest gets result, only for test. 
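// The heap logic above keeps at most Offset+Count rows: the heap is ordered with
// the greatest kept row on top, so once it is full a new row only replaces the top
// when it is smaller. A compact sketch of that bounded top-N with container/heap;
// maxHeap and topN are illustrative, not the executor's chunk-based heap.

package main

import (
	"container/heap"
	"fmt"
	"sort"
)

type maxHeap []int

func (h maxHeap) Len() int           { return len(h) }
func (h maxHeap) Less(i, j int) bool { return h[i] > h[j] } // greatest element on top
func (h maxHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *maxHeap) Push(x any)        { *h = append(*h, x.(int)) }
func (h *maxHeap) Pop() any          { old := *h; n := len(old); x := old[n-1]; *h = old[:n-1]; return x }

// topN returns the n smallest values of the stream, in ascending order.
func topN(stream []int, n int) []int {
	h := &maxHeap{}
	for _, v := range stream {
		if h.Len() < n {
			heap.Push(h, v)
		} else if v < (*h)[0] {
			(*h)[0] = v    // replace the current greatest kept row
			heap.Fix(h, 0) // restore heap order
		}
	}
	res := append([]int(nil), *h...)
	sort.Ints(res)
	return res
}

func main() {
	fmt.Println(topN([]int{9, 1, 7, 3, 8, 2, 6}, 3)) // [1 2 3]
}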
-func GetResultForTest(topnExec *TopNExec) []int64 { - close(topnExec.resultChannel) - result := make([]int64, 0, 100) - for { - row, ok := <-topnExec.resultChannel - if !ok { - return result - } - result = append(result, row.row.GetInt64(0)) - } -} diff --git a/pkg/executor/sortexec/topn_worker.go b/pkg/executor/sortexec/topn_worker.go index 45b2b75843ff8..527dcd42977e9 100644 --- a/pkg/executor/sortexec/topn_worker.go +++ b/pkg/executor/sortexec/topn_worker.go @@ -117,7 +117,7 @@ func (t *topNWorker) run() { func (t *topNWorker) injectFailPointForTopNWorker(triggerFactor int32) { injectTopNRandomFail(triggerFactor) - if val, _err_ := failpoint.Eval(_curpkg_("SlowSomeWorkers")); _err_ == nil { + failpoint.Inject("SlowSomeWorkers", func(val failpoint.Value) { if val.(bool) { if t.workerIDForTest%2 == 0 { randNum := rand.Int31n(10000) @@ -126,5 +126,5 @@ func (t *topNWorker) injectFailPointForTopNWorker(triggerFactor int32) { } } } - } + }) } diff --git a/pkg/executor/sortexec/topn_worker.go__failpoint_stash__ b/pkg/executor/sortexec/topn_worker.go__failpoint_stash__ deleted file mode 100644 index 527dcd42977e9..0000000000000 --- a/pkg/executor/sortexec/topn_worker.go__failpoint_stash__ +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package sortexec - -import ( - "container/heap" - "math/rand" - "sync" - "time" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/memory" -) - -// topNWorker is used only when topn spill is triggered -type topNWorker struct { - workerIDForTest int - - chunkChannel <-chan *chunk.Chunk - fetcherAndWorkerSyncer *sync.WaitGroup - errOutputChan chan<- rowWithError - finishChan <-chan struct{} - - topn *TopNExec - chkHeap *topNChunkHeap - memTracker *memory.Tracker -} - -func newTopNWorker( - idForTest int, - chunkChannel <-chan *chunk.Chunk, - fetcherAndWorkerSyncer *sync.WaitGroup, - errOutputChan chan<- rowWithError, - finishChan <-chan struct{}, - topn *TopNExec, - chkHeap *topNChunkHeap, - memTracker *memory.Tracker) *topNWorker { - return &topNWorker{ - workerIDForTest: idForTest, - chunkChannel: chunkChannel, - fetcherAndWorkerSyncer: fetcherAndWorkerSyncer, - errOutputChan: errOutputChan, - finishChan: finishChan, - chkHeap: chkHeap, - topn: topn, - memTracker: memTracker, - } -} - -func (t *topNWorker) initWorker() { - // Offset of heap in worker should be 0, as we need to spill all data - t.chkHeap.init(t.topn, t.memTracker, t.topn.Limit.Offset+t.topn.Limit.Count, 0, t.topn.greaterRow, t.topn.RetFieldTypes()) -} - -func (t *topNWorker) fetchChunksAndProcess() { - for t.fetchChunksAndProcessImpl() { - } -} - -func (t *topNWorker) fetchChunksAndProcessImpl() bool { - select { - case <-t.finishChan: - return false - case chk, ok := <-t.chunkChannel: - if !ok { - return false - } - defer func() { - t.fetcherAndWorkerSyncer.Done() - }() - - t.injectFailPointForTopNWorker(3) - - if uint64(t.chkHeap.rowChunks.Len()) < t.chkHeap.totalLimit { - if !t.chkHeap.isInitialized { - t.chkHeap.init(t.topn, t.memTracker, t.topn.Limit.Offset+t.topn.Limit.Count, 0, t.topn.greaterRow, t.topn.RetFieldTypes()) - } - t.chkHeap.rowChunks.Add(chk) - } else { - if !t.chkHeap.isRowPtrsInit { - t.chkHeap.initPtrs() - heap.Init(t.chkHeap) - } - t.chkHeap.processChk(chk) - } - } - return true -} - -func (t *topNWorker) run() { - defer func() { - if r := recover(); r != nil { - processPanicAndLog(t.errOutputChan, r) - } - - // Consume all chunks to avoid hang of fetcher - for range t.chunkChannel { - t.fetcherAndWorkerSyncer.Done() - } - }() - - t.fetchChunksAndProcess() -} - -func (t *topNWorker) injectFailPointForTopNWorker(triggerFactor int32) { - injectTopNRandomFail(triggerFactor) - failpoint.Inject("SlowSomeWorkers", func(val failpoint.Value) { - if val.(bool) { - if t.workerIDForTest%2 == 0 { - randNum := rand.Int31n(10000) - if randNum < 10 { - time.Sleep(1 * time.Millisecond) - } - } - } - }) -} diff --git a/pkg/executor/table_reader.go b/pkg/executor/table_reader.go index f59384d029d81..60c83af717526 100644 --- a/pkg/executor/table_reader.go +++ b/pkg/executor/table_reader.go @@ -220,10 +220,10 @@ func (e *TableReaderExecutor) memUsage() int64 { func (e *TableReaderExecutor) Open(ctx context.Context) error { r, ctx := tracing.StartRegionEx(ctx, "TableReaderExecutor.Open") defer r.End() - if v, _err_ := failpoint.Eval(_curpkg_("mockSleepInTableReaderNext")); _err_ == nil { + failpoint.Inject("mockSleepInTableReaderNext", func(v failpoint.Value) { ms := v.(int) time.Sleep(time.Millisecond * time.Duration(ms)) - } + }) if e.memTracker != nil { e.memTracker.Reset() diff --git a/pkg/executor/table_reader.go__failpoint_stash__ b/pkg/executor/table_reader.go__failpoint_stash__ deleted file mode 100644 index 60c83af717526..0000000000000 --- 
a/pkg/executor/table_reader.go__failpoint_stash__ +++ /dev/null @@ -1,632 +0,0 @@ -// Copyright 2018 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package executor - -import ( - "bytes" - "cmp" - "context" - "slices" - "time" - "unsafe" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/distsql" - distsqlctx "github.com/pingcap/tidb/pkg/distsql/context" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/domain/infosync" - "github.com/pingcap/tidb/pkg/executor/internal/builder" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - internalutil "github.com/pingcap/tidb/pkg/executor/internal/util" - "github.com/pingcap/tidb/pkg/expression" - exprctx "github.com/pingcap/tidb/pkg/expression/context" - "github.com/pingcap/tidb/pkg/infoschema" - isctx "github.com/pingcap/tidb/pkg/infoschema/context" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/model" - planctx "github.com/pingcap/tidb/pkg/planner/context" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "github.com/pingcap/tidb/pkg/util/ranger" - rangerctx "github.com/pingcap/tidb/pkg/util/ranger/context" - "github.com/pingcap/tidb/pkg/util/size" - "github.com/pingcap/tidb/pkg/util/stringutil" - "github.com/pingcap/tidb/pkg/util/tracing" - "github.com/pingcap/tipb/go-tipb" -) - -// make sure `TableReaderExecutor` implements `Executor`. -var _ exec.Executor = &TableReaderExecutor{} - -// selectResultHook is used to hack distsql.SelectWithRuntimeStats safely for testing. 
-type selectResultHook struct { - selectResultFunc func(ctx context.Context, dctx *distsqlctx.DistSQLContext, kvReq *kv.Request, - fieldTypes []*types.FieldType, copPlanIDs []int) (distsql.SelectResult, error) -} - -func (sr selectResultHook) SelectResult(ctx context.Context, dctx *distsqlctx.DistSQLContext, kvReq *kv.Request, - fieldTypes []*types.FieldType, copPlanIDs []int, rootPlanID int) (distsql.SelectResult, error) { - if sr.selectResultFunc == nil { - return distsql.SelectWithRuntimeStats(ctx, dctx, kvReq, fieldTypes, copPlanIDs, rootPlanID) - } - return sr.selectResultFunc(ctx, dctx, kvReq, fieldTypes, copPlanIDs) -} - -type kvRangeBuilder interface { - buildKeyRange(dctx *distsqlctx.DistSQLContext, ranges []*ranger.Range) ([][]kv.KeyRange, error) - buildKeyRangeSeparately(dctx *distsqlctx.DistSQLContext, ranges []*ranger.Range) ([]int64, [][]kv.KeyRange, error) -} - -// tableReaderExecutorContext is the execution context for the `TableReaderExecutor` -type tableReaderExecutorContext struct { - dctx *distsqlctx.DistSQLContext - rctx *rangerctx.RangerContext - buildPBCtx *planctx.BuildPBContext - ectx exprctx.BuildContext - - stmtMemTracker *memory.Tracker - - infoSchema isctx.MetaOnlyInfoSchema - getDDLOwner func(context.Context) (*infosync.ServerInfo, error) -} - -func (treCtx *tableReaderExecutorContext) GetInfoSchema() isctx.MetaOnlyInfoSchema { - return treCtx.infoSchema -} - -func (treCtx *tableReaderExecutorContext) GetDDLOwner(ctx context.Context) (*infosync.ServerInfo, error) { - if treCtx.getDDLOwner != nil { - return treCtx.getDDLOwner(ctx) - } - - return nil, errors.New("GetDDLOwner in a context without DDL") -} - -func newTableReaderExecutorContext(sctx sessionctx.Context) tableReaderExecutorContext { - // Explicitly get `ownerManager` out of the closure to show that the `tableReaderExecutorContext` itself doesn't - // depend on `sctx` directly. - // The context of some tests don't have `DDL`, so make it optional - var getDDLOwner func(ctx context.Context) (*infosync.ServerInfo, error) - ddl := domain.GetDomain(sctx).DDL() - if ddl != nil { - ownerManager := ddl.OwnerManager() - getDDLOwner = func(ctx context.Context) (*infosync.ServerInfo, error) { - ddlOwnerID, err := ownerManager.GetOwnerID(ctx) - if err != nil { - return nil, err - } - return infosync.GetServerInfoByID(ctx, ddlOwnerID) - } - } - - pctx := sctx.GetPlanCtx() - return tableReaderExecutorContext{ - dctx: sctx.GetDistSQLCtx(), - rctx: pctx.GetRangerCtx(), - buildPBCtx: pctx.GetBuildPBCtx(), - ectx: sctx.GetExprCtx(), - stmtMemTracker: sctx.GetSessionVars().StmtCtx.MemTracker, - infoSchema: pctx.GetInfoSchema(), - getDDLOwner: getDDLOwner, - } -} - -// TableReaderExecutor sends DAG request and reads table data from kv layer. -type TableReaderExecutor struct { - tableReaderExecutorContext - exec.BaseExecutorV2 - - table table.Table - - // The source of key ranges varies from case to case. - // It may be calculated from PhysicalPlan by executorBuilder, or calculated from argument by dataBuilder; - // It may be calculated from ranger.Ranger, or calculated from handles. - // The table ID may also change because of the partition table, and causes the key range to change. - // So instead of keeping a `range` struct field, it's better to define a interface. - kvRangeBuilder - // TODO: remove this field, use the kvRangeBuilder interface. - ranges []*ranger.Range - - // kvRanges are only use for union scan. 
- kvRanges []kv.KeyRange - dagPB *tipb.DAGRequest - startTS uint64 - txnScope string - readReplicaScope string - isStaleness bool - // FIXME: in some cases the data size can be more accurate after get the handles count, - // but we keep things simple as it needn't to be that accurate for now. - netDataSize float64 - // columns are only required by union scan and virtual column. - columns []*model.ColumnInfo - - // resultHandler handles the order of the result. Since (MAXInt64, MAXUint64] stores before [0, MaxInt64] physically - // for unsigned int. - resultHandler *tableResultHandler - plans []base.PhysicalPlan - tablePlan base.PhysicalPlan - - memTracker *memory.Tracker - selectResultHook // for testing - - keepOrder bool - desc bool - // byItems only for partition table with orderBy + pushedLimit - byItems []*util.ByItems - paging bool - storeType kv.StoreType - // corColInFilter tells whether there's correlated column in filter (both conditions in PhysicalSelection and LateMaterializationFilterCondition in PhysicalTableScan) - // If true, we will need to revise the dagPB (fill correlated column value in filter) each time call Open(). - corColInFilter bool - // corColInAccess tells whether there's correlated column in access conditions. - corColInAccess bool - // virtualColumnIndex records all the indices of virtual columns and sort them in definition - // to make sure we can compute the virtual column in right order. - virtualColumnIndex []int - // virtualColumnRetFieldTypes records the RetFieldTypes of virtual columns. - virtualColumnRetFieldTypes []*types.FieldType - // batchCop indicates whether use super batch coprocessor request, only works for TiFlash engine. - batchCop bool - - // If dummy flag is set, this is not a real TableReader, it just provides the KV ranges for UnionScan. - // Used by the temporary table, cached table. - dummy bool -} - -// Table implements the dataSourceExecutor interface. -func (e *TableReaderExecutor) Table() table.Table { - return e.table -} - -func (e *TableReaderExecutor) setDummy() { - e.dummy = true -} - -func (e *TableReaderExecutor) memUsage() int64 { - const sizeofTableReaderExecutor = int64(unsafe.Sizeof(*(*TableReaderExecutor)(nil))) - - res := sizeofTableReaderExecutor - res += size.SizeOfPointer * int64(cap(e.ranges)) - for _, v := range e.ranges { - res += v.MemUsage() - } - res += kv.KeyRangeSliceMemUsage(e.kvRanges) - res += int64(e.dagPB.Size()) - // TODO: add more statistics - return res -} - -// Open initializes necessary variables for using this executor. 
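The memUsage method above follows a simple accounting idiom: the fixed struct size via unsafe.Sizeof, plus the pointer slots reserved by each slice, plus whatever each element reports about itself. A self-contained sketch of the same idiom, using illustrative stand-in types that are not part of this patch:

package main

import (
    "fmt"
    "unsafe"
)

type rangeStub struct{ low, high int64 }

type readerStub struct {
    ranges []*rangeStub
}

// memUsage sums the fixed struct size, the pointer slots held by the slice,
// and a per-element cost, mirroring the accounting style above.
func (r *readerStub) memUsage() int64 {
    const sizeofReaderStub = int64(unsafe.Sizeof(*(*readerStub)(nil)))
    res := sizeofReaderStub
    res += int64(unsafe.Sizeof((*rangeStub)(nil))) * int64(cap(r.ranges))
    for range r.ranges {
        res += int64(unsafe.Sizeof(rangeStub{}))
    }
    return res
}

func main() {
    r := &readerStub{ranges: []*rangeStub{{1, 10}, {20, 30}}}
    fmt.Println(r.memUsage())
}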
-func (e *TableReaderExecutor) Open(ctx context.Context) error { - r, ctx := tracing.StartRegionEx(ctx, "TableReaderExecutor.Open") - defer r.End() - failpoint.Inject("mockSleepInTableReaderNext", func(v failpoint.Value) { - ms := v.(int) - time.Sleep(time.Millisecond * time.Duration(ms)) - }) - - if e.memTracker != nil { - e.memTracker.Reset() - } else { - e.memTracker = memory.NewTracker(e.ID(), -1) - } - e.memTracker.AttachTo(e.stmtMemTracker) - - var err error - if e.corColInFilter { - // If there's correlated column in filter, need to rewrite dagPB - if e.storeType == kv.TiFlash { - execs, err := builder.ConstructTreeBasedDistExec(e.buildPBCtx, e.tablePlan) - if err != nil { - return err - } - e.dagPB.RootExecutor = execs[0] - } else { - e.dagPB.Executors, err = builder.ConstructListBasedDistExec(e.buildPBCtx, e.plans) - if err != nil { - return err - } - } - } - if e.dctx.RuntimeStatsColl != nil { - collExec := true - e.dagPB.CollectExecutionSummaries = &collExec - } - if e.corColInAccess { - ts := e.plans[0].(*plannercore.PhysicalTableScan) - e.ranges, err = ts.ResolveCorrelatedColumns() - if err != nil { - return err - } - } - - e.resultHandler = &tableResultHandler{} - - firstPartRanges, secondPartRanges := distsql.SplitRangesAcrossInt64Boundary(e.ranges, e.keepOrder, e.desc, e.table.Meta() != nil && e.table.Meta().IsCommonHandle) - - // Treat temporary table as dummy table, avoid sending distsql request to TiKV. - // Calculate the kv ranges here, UnionScan rely on this kv ranges. - // cached table and temporary table are similar - if e.dummy { - if e.desc && len(secondPartRanges) != 0 { - // TiKV support reverse scan and the `resultHandler` process the range order. - // While in UnionScan, it doesn't use reverse scan and reverse the final result rows manually. - // So things are differ, we need to reverse the kv range here. - // TODO: If we refactor UnionScan to use reverse scan, update the code here. - // [9734095886065816708 9734095886065816709] | [1 3] [65535 9734095886065816707] => before the following change - // [1 3] [65535 9734095886065816707] | [9734095886065816708 9734095886065816709] => ranges part reverse here - // [1 3 65535 9734095886065816707 9734095886065816708 9734095886065816709] => scan (normal order) in UnionScan - // [9734095886065816709 9734095886065816708 9734095886065816707 65535 3 1] => rows reverse in UnionScan - firstPartRanges, secondPartRanges = secondPartRanges, firstPartRanges - } - kvReq, err := e.buildKVReq(ctx, firstPartRanges) - if err != nil { - return err - } - e.kvRanges = kvReq.KeyRanges.AppendSelfTo(e.kvRanges) - if len(secondPartRanges) != 0 { - kvReq, err = e.buildKVReq(ctx, secondPartRanges) - if err != nil { - return err - } - e.kvRanges = kvReq.KeyRanges.AppendSelfTo(e.kvRanges) - } - return nil - } - - firstResult, err := e.buildResp(ctx, firstPartRanges) - if err != nil { - return err - } - if len(secondPartRanges) == 0 { - e.resultHandler.open(nil, firstResult) - return nil - } - var secondResult distsql.SelectResult - secondResult, err = e.buildResp(ctx, secondPartRanges) - if err != nil { - return err - } - e.resultHandler.open(firstResult, secondResult) - return nil -} - -// Next fills data into the chunk passed by its caller. -// The task was actually done by tableReaderHandler. -func (e *TableReaderExecutor) Next(ctx context.Context, req *chunk.Chunk) error { - if e.dummy { - // Treat temporary table as dummy table, avoid sending distsql request to TiKV. 
- req.Reset() - return nil - } - - logutil.Eventf(ctx, "table scan table: %s, range: %v", stringutil.MemoizeStr(func() string { - var tableName string - if meta := e.table.Meta(); meta != nil { - tableName = meta.Name.L - } - return tableName - }), e.ranges) - if err := e.resultHandler.nextChunk(ctx, req); err != nil { - return err - } - - err := table.FillVirtualColumnValue(e.virtualColumnRetFieldTypes, e.virtualColumnIndex, e.Schema().Columns, e.columns, e.ectx, req) - if err != nil { - return err - } - - return nil -} - -// Close implements the Executor Close interface. -func (e *TableReaderExecutor) Close() error { - var err error - if e.resultHandler != nil { - err = e.resultHandler.Close() - } - e.kvRanges = e.kvRanges[:0] - if e.dummy { - return nil - } - return err -} - -// buildResp first builds request and sends it to tikv using distsql.Select. It uses SelectResult returned by the callee -// to fetch all results. -func (e *TableReaderExecutor) buildResp(ctx context.Context, ranges []*ranger.Range) (distsql.SelectResult, error) { - if e.storeType == kv.TiFlash && e.kvRangeBuilder != nil { - if !e.batchCop { - // TiFlash cannot support to access multiple tables/partitions within one KVReq, so we have to build KVReq for each partition separately. - kvReqs, err := e.buildKVReqSeparately(ctx, ranges) - if err != nil { - return nil, err - } - var results []distsql.SelectResult - for _, kvReq := range kvReqs { - result, err := e.SelectResult(ctx, e.dctx, kvReq, exec.RetTypes(e), getPhysicalPlanIDs(e.plans), e.ID()) - if err != nil { - return nil, err - } - results = append(results, result) - } - return distsql.NewSerialSelectResults(results), nil - } - // Use PartitionTable Scan - kvReq, err := e.buildKVReqForPartitionTableScan(ctx, ranges) - if err != nil { - return nil, err - } - result, err := e.SelectResult(ctx, e.dctx, kvReq, exec.RetTypes(e), getPhysicalPlanIDs(e.plans), e.ID()) - if err != nil { - return nil, err - } - return result, nil - } - - // use sortedSelectResults here when pushDown limit for partition table. - if e.kvRangeBuilder != nil && e.byItems != nil { - kvReqs, err := e.buildKVReqSeparately(ctx, ranges) - if err != nil { - return nil, err - } - var results []distsql.SelectResult - for _, kvReq := range kvReqs { - result, err := e.SelectResult(ctx, e.dctx, kvReq, exec.RetTypes(e), getPhysicalPlanIDs(e.plans), e.ID()) - if err != nil { - return nil, err - } - results = append(results, result) - } - if len(results) == 1 { - return results[0], nil - } - return distsql.NewSortedSelectResults(e.ectx.GetEvalCtx(), results, e.Schema(), e.byItems, e.memTracker), nil - } - - kvReq, err := e.buildKVReq(ctx, ranges) - if err != nil { - return nil, err - } - kvReq.KeyRanges.SortByFunc(func(i, j kv.KeyRange) int { - return bytes.Compare(i.StartKey, j.StartKey) - }) - e.kvRanges = kvReq.KeyRanges.AppendSelfTo(e.kvRanges) - - result, err := e.SelectResult(ctx, e.dctx, kvReq, exec.RetTypes(e), getPhysicalPlanIDs(e.plans), e.ID()) - if err != nil { - return nil, err - } - return result, nil -} - -func (e *TableReaderExecutor) buildKVReqSeparately(ctx context.Context, ranges []*ranger.Range) ([]*kv.Request, error) { - pids, kvRanges, err := e.kvRangeBuilder.buildKeyRangeSeparately(e.dctx, ranges) - if err != nil { - return nil, err - } - kvReqs := make([]*kv.Request, 0, len(kvRanges)) - for i, kvRange := range kvRanges { - e.kvRanges = append(e.kvRanges, kvRange...) 
- if err := internalutil.UpdateExecutorTableID(ctx, e.dagPB.RootExecutor, true, []int64{pids[i]}); err != nil { - return nil, err - } - var builder distsql.RequestBuilder - reqBuilder := builder.SetKeyRanges(kvRange) - kvReq, err := reqBuilder. - SetDAGRequest(e.dagPB). - SetStartTS(e.startTS). - SetDesc(e.desc). - SetKeepOrder(e.keepOrder). - SetTxnScope(e.txnScope). - SetReadReplicaScope(e.readReplicaScope). - SetFromSessionVars(e.dctx). - SetFromInfoSchema(e.GetInfoSchema()). - SetMemTracker(e.memTracker). - SetStoreType(e.storeType). - SetPaging(e.paging). - SetAllowBatchCop(e.batchCop). - SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.dctx, &reqBuilder.Request, e.netDataSize)). - SetConnIDAndConnAlias(e.dctx.ConnectionID, e.dctx.SessionAlias). - Build() - if err != nil { - return nil, err - } - kvReqs = append(kvReqs, kvReq) - } - return kvReqs, nil -} - -func (e *TableReaderExecutor) buildKVReqForPartitionTableScan(ctx context.Context, ranges []*ranger.Range) (*kv.Request, error) { - pids, kvRanges, err := e.kvRangeBuilder.buildKeyRangeSeparately(e.dctx, ranges) - if err != nil { - return nil, err - } - partitionIDAndRanges := make([]kv.PartitionIDAndRanges, 0, len(pids)) - for i, kvRange := range kvRanges { - e.kvRanges = append(e.kvRanges, kvRange...) - partitionIDAndRanges = append(partitionIDAndRanges, kv.PartitionIDAndRanges{ - ID: pids[i], - KeyRanges: kvRange, - }) - } - if err := internalutil.UpdateExecutorTableID(ctx, e.dagPB.RootExecutor, true, pids); err != nil { - return nil, err - } - var builder distsql.RequestBuilder - reqBuilder := builder.SetPartitionIDAndRanges(partitionIDAndRanges) - kvReq, err := reqBuilder. - SetDAGRequest(e.dagPB). - SetStartTS(e.startTS). - SetDesc(e.desc). - SetKeepOrder(e.keepOrder). - SetTxnScope(e.txnScope). - SetReadReplicaScope(e.readReplicaScope). - SetFromSessionVars(e.dctx). - SetFromInfoSchema(e.GetInfoSchema()). - SetMemTracker(e.memTracker). - SetStoreType(e.storeType). - SetPaging(e.paging). - SetAllowBatchCop(e.batchCop). - SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.dctx, &reqBuilder.Request, e.netDataSize)). - SetConnIDAndConnAlias(e.dctx.ConnectionID, e.dctx.SessionAlias). - Build() - if err != nil { - return nil, err - } - return kvReq, nil -} - -func (e *TableReaderExecutor) buildKVReq(ctx context.Context, ranges []*ranger.Range) (*kv.Request, error) { - var builder distsql.RequestBuilder - var reqBuilder *distsql.RequestBuilder - if e.kvRangeBuilder != nil { - kvRange, err := e.kvRangeBuilder.buildKeyRange(e.dctx, ranges) - if err != nil { - return nil, err - } - reqBuilder = builder.SetPartitionKeyRanges(kvRange) - } else { - reqBuilder = builder.SetHandleRanges(e.dctx, getPhysicalTableID(e.table), e.table.Meta() != nil && e.table.Meta().IsCommonHandle, ranges) - } - if e.table != nil && e.table.Type().IsClusterTable() { - copDestination := infoschema.GetClusterTableCopDestination(e.table.Meta().Name.L) - if copDestination == infoschema.DDLOwner { - serverInfo, err := e.GetDDLOwner(ctx) - if err != nil { - return nil, err - } - reqBuilder.SetTiDBServerID(serverInfo.ServerIDGetter()) - } - } - reqBuilder. - SetDAGRequest(e.dagPB). - SetStartTS(e.startTS). - SetDesc(e.desc). - SetKeepOrder(e.keepOrder). - SetTxnScope(e.txnScope). - SetReadReplicaScope(e.readReplicaScope). - SetIsStaleness(e.isStaleness). - SetFromSessionVars(e.dctx). - SetFromInfoSchema(e.GetInfoSchema()). - SetMemTracker(e.memTracker). - SetStoreType(e.storeType). - SetAllowBatchCop(e.batchCop). 
- SetClosestReplicaReadAdjuster(newClosestReadAdjuster(e.dctx, &reqBuilder.Request, e.netDataSize)). - SetPaging(e.paging). - SetConnIDAndConnAlias(e.dctx.ConnectionID, e.dctx.SessionAlias) - return reqBuilder.Build() -} - -func buildVirtualColumnIndex(schema *expression.Schema, columns []*model.ColumnInfo) []int { - virtualColumnIndex := make([]int, 0, len(columns)) - for i, col := range schema.Columns { - if col.VirtualExpr != nil { - virtualColumnIndex = append(virtualColumnIndex, i) - } - } - slices.SortFunc(virtualColumnIndex, func(i, j int) int { - return cmp.Compare(plannercore.FindColumnInfoByID(columns, schema.Columns[i].ID).Offset, - plannercore.FindColumnInfoByID(columns, schema.Columns[j].ID).Offset) - }) - return virtualColumnIndex -} - -// buildVirtualColumnInfo saves virtual column indices and sort them in definition order -func (e *TableReaderExecutor) buildVirtualColumnInfo() { - e.virtualColumnIndex, e.virtualColumnRetFieldTypes = buildVirtualColumnInfo(e.Schema(), e.columns) -} - -// buildVirtualColumnInfo saves virtual column indices and sort them in definition order -func buildVirtualColumnInfo(schema *expression.Schema, columns []*model.ColumnInfo) (colIndexs []int, retTypes []*types.FieldType) { - colIndexs = buildVirtualColumnIndex(schema, columns) - if len(colIndexs) > 0 { - retTypes = make([]*types.FieldType, len(colIndexs)) - for i, idx := range colIndexs { - retTypes[i] = schema.Columns[idx].RetType - } - } - return colIndexs, retTypes -} - -type tableResultHandler struct { - // If the pk is unsigned and we have KeepOrder=true and want ascending order, - // `optionalResult` will handles the request whose range is in signed int range, and - // `result` will handle the request whose range is exceed signed int range. - // If we want descending order, `optionalResult` will handles the request whose range is exceed signed, and - // the `result` will handle the request whose range is in signed. - // Otherwise, we just set `optionalFinished` true and the `result` handles the whole ranges. 
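The comment above explains the one subtle ordering rule in the handler defined just below: when the unsigned-handle ranges are split into a part inside the signed int64 range and a part beyond it, one result must be fully drained before the other is touched. A stripped-down sketch of that drain order; the fetcher type merely stands in for distsql.SelectResult and is not part of this patch.

package main

import "fmt"

// fetcher stands in for a distsql.SelectResult; it returns 0 when exhausted.
type fetcher func() int

// twoPartHandler drains the optional part first, then the main part, matching
// the optionalResult/result ordering described in the comment above.
type twoPartHandler struct {
    optional, main   fetcher
    optionalFinished bool
}

func (h *twoPartHandler) next() int {
    if h.optional != nil && !h.optionalFinished {
        if n := h.optional(); n > 0 {
            return n
        }
        h.optionalFinished = true
    }
    return h.main()
}

func main() {
    chunks := []int{3, 2, 0} // optional part yields two chunks, then is exhausted
    h := &twoPartHandler{
        optional: func() int { n := chunks[0]; chunks = chunks[1:]; return n },
        main:     func() int { return 0 },
    }
    fmt.Println(h.next(), h.next(), h.next()) // 3 2 0
}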
- optionalResult distsql.SelectResult - result distsql.SelectResult - - optionalFinished bool -} - -func (tr *tableResultHandler) open(optionalResult, result distsql.SelectResult) { - if optionalResult == nil { - tr.optionalFinished = true - tr.result = result - return - } - tr.optionalResult = optionalResult - tr.result = result - tr.optionalFinished = false -} - -func (tr *tableResultHandler) nextChunk(ctx context.Context, chk *chunk.Chunk) error { - if !tr.optionalFinished { - err := tr.optionalResult.Next(ctx, chk) - if err != nil { - return err - } - if chk.NumRows() > 0 { - return nil - } - tr.optionalFinished = true - } - return tr.result.Next(ctx, chk) -} - -func (tr *tableResultHandler) nextRaw(ctx context.Context) (data []byte, err error) { - if !tr.optionalFinished { - data, err = tr.optionalResult.NextRaw(ctx) - if err != nil { - return nil, err - } - if data != nil { - return data, nil - } - tr.optionalFinished = true - } - data, err = tr.result.NextRaw(ctx) - if err != nil { - return nil, err - } - return data, nil -} - -func (tr *tableResultHandler) Close() error { - err := closeAll(tr.optionalResult, tr.result) - tr.optionalResult, tr.result = nil, nil - return err -} diff --git a/pkg/executor/unionexec/binding__failpoint_binding__.go b/pkg/executor/unionexec/binding__failpoint_binding__.go deleted file mode 100644 index 80646b919b715..0000000000000 --- a/pkg/executor/unionexec/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package unionexec - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/executor/unionexec/union.go b/pkg/executor/unionexec/union.go index d952c4a65418a..0b6326dc0261a 100644 --- a/pkg/executor/unionexec/union.go +++ b/pkg/executor/unionexec/union.go @@ -149,9 +149,9 @@ func (e *UnionExec) resultPuller(ctx context.Context, workerID int) { e.stopFetchData.Store(true) e.resultPool <- result } - if _, _err_ := failpoint.Eval(_curpkg_("issue21441")); _err_ == nil { + failpoint.Inject("issue21441", func() { atomic.AddInt32(&e.childInFlightForTest, 1) - } + }) for { if e.stopFetchData.Load().(bool) { return @@ -166,20 +166,20 @@ func (e *UnionExec) resultPuller(ctx context.Context, workerID int) { e.resourcePools[workerID] <- result.chk break } - if _, _err_ := failpoint.Eval(_curpkg_("issue21441")); _err_ == nil { + failpoint.Inject("issue21441", func() { if int(atomic.LoadInt32(&e.childInFlightForTest)) > e.Concurrency { panic("the count of child in flight is larger than e.concurrency unexpectedly") } - } + }) e.resultPool <- result if result.err != nil { e.stopFetchData.Store(true) return } } - if _, _err_ := failpoint.Eval(_curpkg_("issue21441")); _err_ == nil { + failpoint.Inject("issue21441", func() { atomic.AddInt32(&e.childInFlightForTest, -1) - } + }) } } diff --git a/pkg/executor/unionexec/union.go__failpoint_stash__ b/pkg/executor/unionexec/union.go__failpoint_stash__ deleted file mode 100644 index 0b6326dc0261a..0000000000000 --- a/pkg/executor/unionexec/union.go__failpoint_stash__ +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package unionexec - -import ( - "context" - "sync" - "sync/atomic" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/executor/internal/exec" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/channel" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/syncutil" - "go.uber.org/zap" -) - -var ( - _ exec.Executor = &UnionExec{} -) - -// UnionExec pulls all it's children's result and returns to its parent directly. -// A "resultPuller" is started for every child to pull result from that child and push it to the "resultPool", the used -// "Chunk" is obtained from the corresponding "resourcePool". All resultPullers are running concurrently. -// -// +----------------+ -// +---> resourcePool 1 ---> | resultPuller 1 |-----+ -// | +----------------+ | -// | | -// | +----------------+ v -// +---> resourcePool 2 ---> | resultPuller 2 |-----> resultPool ---+ -// | +----------------+ ^ | -// | ...... | | -// | +----------------+ | | -// +---> resourcePool n ---> | resultPuller n |-----+ | -// | +----------------+ | -// | | -// | +-------------+ | -// |--------------------------| main thread | <---------------------+ -// +-------------+ -type UnionExec struct { - exec.BaseExecutor - Concurrency int - childIDChan chan int - - stopFetchData atomic.Value - - finished chan struct{} - resourcePools []chan *chunk.Chunk - resultPool chan *unionWorkerResult - - results []*chunk.Chunk - wg sync.WaitGroup - initialized bool - mu struct { - *syncutil.Mutex - maxOpenedChildID int - } - - childInFlightForTest int32 -} - -// unionWorkerResult stores the result for a union worker. -// A "resultPuller" is started for every child to pull result from that child, unionWorkerResult is used to store that pulled result. -// "src" is used for Chunk reuse: after pulling result from "resultPool", main-thread must push a valid unused Chunk to "src" to -// enable the corresponding "resultPuller" continue to work. -type unionWorkerResult struct { - chk *chunk.Chunk - err error - src chan<- *chunk.Chunk -} - -func (e *UnionExec) waitAllFinished() { - e.wg.Wait() - close(e.resultPool) -} - -// Open implements the Executor Open interface. 
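The diagram above captures the whole concurrency scheme: every resultPuller owns a one-slot resourcePool of reusable chunks and shares a single resultPool with the main thread, which must hand each chunk back through the puller's own channel before that puller can produce again. A stripped-down sketch of that handshake, using int slices instead of chunks and illustrative names that are not part of this patch:

package main

import (
    "fmt"
    "sync"
)

type result struct {
    buf []int
    src chan []int // hand the buffer back here so the producer can reuse it
}

func main() {
    const workers = 2
    results := make(chan result, workers)
    var wg sync.WaitGroup

    for w := 0; w < workers; w++ {
        resource := make(chan []int, 1)
        resource <- make([]int, 0, 4) // seed the pool with one reusable buffer
        wg.Add(1)
        go func(id int, resource chan []int) {
            defer wg.Done()
            for round := 0; round < 3; round++ {
                buf := <-resource // wait until the consumer returns the buffer
                buf = append(buf[:0], id, round)
                results <- result{buf: buf, src: resource}
            }
        }(w, resource)
    }
    go func() { wg.Wait(); close(results) }()

    for r := range results {
        fmt.Println(r.buf) // consume, then return the buffer to its owner
        r.src <- r.buf
    }
}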
-func (e *UnionExec) Open(context.Context) error { - e.stopFetchData.Store(false) - e.initialized = false - e.finished = make(chan struct{}) - e.mu.Mutex = &syncutil.Mutex{} - e.mu.maxOpenedChildID = -1 - return nil -} - -func (e *UnionExec) initialize(ctx context.Context) { - if e.Concurrency > e.ChildrenLen() { - e.Concurrency = e.ChildrenLen() - } - for i := 0; i < e.Concurrency; i++ { - e.results = append(e.results, exec.NewFirstChunk(e.Children(0))) - } - e.resultPool = make(chan *unionWorkerResult, e.Concurrency) - e.resourcePools = make([]chan *chunk.Chunk, e.Concurrency) - e.childIDChan = make(chan int, e.ChildrenLen()) - for i := 0; i < e.Concurrency; i++ { - e.resourcePools[i] = make(chan *chunk.Chunk, 1) - e.resourcePools[i] <- e.results[i] - e.wg.Add(1) - go e.resultPuller(ctx, i) - } - for i := 0; i < e.ChildrenLen(); i++ { - e.childIDChan <- i - } - close(e.childIDChan) - go e.waitAllFinished() -} - -func (e *UnionExec) resultPuller(ctx context.Context, workerID int) { - result := &unionWorkerResult{ - err: nil, - chk: nil, - src: e.resourcePools[workerID], - } - defer func() { - if r := recover(); r != nil { - logutil.Logger(ctx).Error("resultPuller panicked", zap.Any("recover", r), zap.Stack("stack")) - result.err = util.GetRecoverError(r) - e.resultPool <- result - e.stopFetchData.Store(true) - } - e.wg.Done() - }() - for childID := range e.childIDChan { - e.mu.Lock() - if childID > e.mu.maxOpenedChildID { - e.mu.maxOpenedChildID = childID - } - e.mu.Unlock() - if err := exec.Open(ctx, e.Children(childID)); err != nil { - result.err = err - e.stopFetchData.Store(true) - e.resultPool <- result - } - failpoint.Inject("issue21441", func() { - atomic.AddInt32(&e.childInFlightForTest, 1) - }) - for { - if e.stopFetchData.Load().(bool) { - return - } - select { - case <-e.finished: - return - case result.chk = <-e.resourcePools[workerID]: - } - result.err = exec.Next(ctx, e.Children(childID), result.chk) - if result.err == nil && result.chk.NumRows() == 0 { - e.resourcePools[workerID] <- result.chk - break - } - failpoint.Inject("issue21441", func() { - if int(atomic.LoadInt32(&e.childInFlightForTest)) > e.Concurrency { - panic("the count of child in flight is larger than e.concurrency unexpectedly") - } - }) - e.resultPool <- result - if result.err != nil { - e.stopFetchData.Store(true) - return - } - } - failpoint.Inject("issue21441", func() { - atomic.AddInt32(&e.childInFlightForTest, -1) - }) - } -} - -// Next implements the Executor Next interface. -func (e *UnionExec) Next(ctx context.Context, req *chunk.Chunk) error { - req.GrowAndReset(e.MaxChunkSize()) - if !e.initialized { - e.initialize(ctx) - e.initialized = true - } - result, ok := <-e.resultPool - if !ok { - return nil - } - if result.err != nil { - return errors.Trace(result.err) - } - - if result.chk.NumCols() != req.NumCols() { - return errors.Errorf("Internal error: UnionExec chunk column count mismatch, req: %d, result: %d", - req.NumCols(), result.chk.NumCols()) - } - req.SwapColumns(result.chk) - result.src <- result.chk - return nil -} - -// Close implements the Executor Close interface. -func (e *UnionExec) Close() error { - if e.finished != nil { - close(e.finished) - } - e.results = nil - if e.resultPool != nil { - channel.Clear(e.resultPool) - } - e.resourcePools = nil - if e.childIDChan != nil { - channel.Clear(e.childIDChan) - } - // We do not need to acquire the e.mu.Lock since all the resultPuller can be - // promised to exit when reaching here (e.childIDChan been closed). 
- var firstErr error - for i := 0; i <= e.mu.maxOpenedChildID; i++ { - if err := exec.Close(e.Children(i)); err != nil && firstErr == nil { - firstErr = err - } - } - return firstErr -} diff --git a/pkg/expression/aggregation/binding__failpoint_binding__.go b/pkg/expression/aggregation/binding__failpoint_binding__.go deleted file mode 100644 index df733f383834f..0000000000000 --- a/pkg/expression/aggregation/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package aggregation - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/expression/aggregation/explain.go b/pkg/expression/aggregation/explain.go index d63fd9d4e8d7d..d89fc08d88dc6 100644 --- a/pkg/expression/aggregation/explain.go +++ b/pkg/expression/aggregation/explain.go @@ -27,11 +27,11 @@ import ( func ExplainAggFunc(ctx expression.EvalContext, agg *AggFuncDesc, normalized bool) string { var buffer bytes.Buffer showMode := false - if v, _err_ := failpoint.Eval(_curpkg_("show-agg-mode")); _err_ == nil { + failpoint.Inject("show-agg-mode", func(v failpoint.Value) { if v.(bool) { showMode = true } - } + }) if showMode { fmt.Fprintf(&buffer, "%s(%s,", agg.Name, agg.Mode.ToString()) } else { diff --git a/pkg/expression/aggregation/explain.go__failpoint_stash__ b/pkg/expression/aggregation/explain.go__failpoint_stash__ deleted file mode 100644 index d89fc08d88dc6..0000000000000 --- a/pkg/expression/aggregation/explain.go__failpoint_stash__ +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2017 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package aggregation - -import ( - "bytes" - "fmt" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/parser/ast" -) - -// ExplainAggFunc generates explain information for a aggregation function. 
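The explain.go hunk above, like the union.go one earlier, restores the source-level failpoint.Inject marker that failpoint-ctl had expanded into failpoint.Eval(_curpkg_(...)); the *__failpoint_stash__ backups and binding__failpoint_binding__.go helpers being deleted are artifacts of that tool. For readers unfamiliar with the library, a hedged sketch of how such a marker is typically activated from a test follows; the failpoint path (package import path plus the name passed to Inject) and the test body are assumptions for illustration only.

package aggregation_test

import (
    "testing"

    "github.com/pingcap/failpoint"
    "github.com/stretchr/testify/require"
)

func TestShowAggModeFailpoint(t *testing.T) {
    // Assumed failpoint path: the import path of the package plus the name
    // used in failpoint.Inject("show-agg-mode", ...); verify against the repo.
    const fp = "github.com/pingcap/tidb/pkg/expression/aggregation/show-agg-mode"

    require.NoError(t, failpoint.Enable(fp, "return(true)"))
    defer func() { require.NoError(t, failpoint.Disable(fp)) }()

    // ... run an EXPLAIN here and assert that the aggregate mode is shown ...
}

The same rewriting explains the failpoint.Return(...) call in the builtin_json.go hunk further down: inside an Inject closure it marks an early return from the enclosing function, which is exactly the plain return statement visible in the generated failpoint.Eval form.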
-func ExplainAggFunc(ctx expression.EvalContext, agg *AggFuncDesc, normalized bool) string { - var buffer bytes.Buffer - showMode := false - failpoint.Inject("show-agg-mode", func(v failpoint.Value) { - if v.(bool) { - showMode = true - } - }) - if showMode { - fmt.Fprintf(&buffer, "%s(%s,", agg.Name, agg.Mode.ToString()) - } else { - fmt.Fprintf(&buffer, "%s(", agg.Name) - } - - if agg.HasDistinct { - buffer.WriteString("distinct ") - } - for i, arg := range agg.Args { - if agg.Name == ast.AggFuncGroupConcat && i == len(agg.Args)-1 { - if len(agg.OrderByItems) > 0 { - buffer.WriteString(" order by ") - for i, item := range agg.OrderByItems { - if item.Desc { - if normalized { - fmt.Fprintf(&buffer, "%s desc", item.Expr.ExplainNormalizedInfo()) - } else { - fmt.Fprintf(&buffer, "%s desc", item.Expr.ExplainInfo(ctx)) - } - } else { - if normalized { - fmt.Fprintf(&buffer, "%s", item.Expr.ExplainNormalizedInfo()) - } else { - fmt.Fprintf(&buffer, "%s", item.Expr.ExplainInfo(ctx)) - } - } - - if i+1 < len(agg.OrderByItems) { - buffer.WriteString(", ") - } - } - } - buffer.WriteString(" separator ") - } else if i != 0 { - buffer.WriteString(", ") - } - if normalized { - buffer.WriteString(arg.ExplainNormalizedInfo()) - } else { - buffer.WriteString(arg.ExplainInfo(ctx)) - } - } - buffer.WriteString(")") - return buffer.String() -} diff --git a/pkg/expression/binding__failpoint_binding__.go b/pkg/expression/binding__failpoint_binding__.go deleted file mode 100644 index 7464c1d08001f..0000000000000 --- a/pkg/expression/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package expression - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/expression/builtin_json.go b/pkg/expression/builtin_json.go index 6f62e0a8981ea..a6acf9844cfc5 100644 --- a/pkg/expression/builtin_json.go +++ b/pkg/expression/builtin_json.go @@ -1881,9 +1881,9 @@ func (b *builtinJSONSchemaValidSig) evalInt(ctx EvalContext, row chunk.Row) (res if b.args[0].ConstLevel() >= ConstOnlyInContext { schema, err = b.schemaCache.getOrInitCache(ctx, func() (jsonschema.Schema, error) { - if _, _err_ := failpoint.Eval(_curpkg_("jsonSchemaValidDisableCacheRefresh")); _err_ == nil { - return jsonschema.Schema{}, errors.New("Cache refresh disabled by failpoint") - } + failpoint.Inject("jsonSchemaValidDisableCacheRefresh", func() { + failpoint.Return(jsonschema.Schema{}, errors.New("Cache refresh disabled by failpoint")) + }) dataBin, err := schemaData.MarshalJSON() if err != nil { return jsonschema.Schema{}, err diff --git a/pkg/expression/builtin_json.go__failpoint_stash__ b/pkg/expression/builtin_json.go__failpoint_stash__ deleted file mode 100644 index a6acf9844cfc5..0000000000000 --- a/pkg/expression/builtin_json.go__failpoint_stash__ +++ /dev/null @@ -1,1940 +0,0 @@ -// Copyright 2017 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package expression - -import ( - "bytes" - "context" - goJSON "encoding/json" - "strings" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/charset" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/hack" - "github.com/pingcap/tipb/go-tipb" - "github.com/qri-io/jsonschema" -) - -var ( - _ functionClass = &jsonTypeFunctionClass{} - _ functionClass = &jsonExtractFunctionClass{} - _ functionClass = &jsonUnquoteFunctionClass{} - _ functionClass = &jsonQuoteFunctionClass{} - _ functionClass = &jsonSetFunctionClass{} - _ functionClass = &jsonInsertFunctionClass{} - _ functionClass = &jsonReplaceFunctionClass{} - _ functionClass = &jsonRemoveFunctionClass{} - _ functionClass = &jsonMergeFunctionClass{} - _ functionClass = &jsonObjectFunctionClass{} - _ functionClass = &jsonArrayFunctionClass{} - _ functionClass = &jsonMemberOfFunctionClass{} - _ functionClass = &jsonContainsFunctionClass{} - _ functionClass = &jsonOverlapsFunctionClass{} - _ functionClass = &jsonContainsPathFunctionClass{} - _ functionClass = &jsonValidFunctionClass{} - _ functionClass = &jsonArrayAppendFunctionClass{} - _ functionClass = &jsonArrayInsertFunctionClass{} - _ functionClass = &jsonMergePatchFunctionClass{} - _ functionClass = &jsonMergePreserveFunctionClass{} - _ functionClass = &jsonPrettyFunctionClass{} - _ functionClass = &jsonQuoteFunctionClass{} - _ functionClass = &jsonSchemaValidFunctionClass{} - _ functionClass = &jsonSearchFunctionClass{} - _ functionClass = &jsonStorageSizeFunctionClass{} - _ functionClass = &jsonDepthFunctionClass{} - _ functionClass = &jsonKeysFunctionClass{} - _ functionClass = &jsonLengthFunctionClass{} - - _ builtinFunc = &builtinJSONTypeSig{} - _ builtinFunc = &builtinJSONQuoteSig{} - _ builtinFunc = &builtinJSONUnquoteSig{} - _ builtinFunc = &builtinJSONArraySig{} - _ builtinFunc = &builtinJSONArrayAppendSig{} - _ builtinFunc = &builtinJSONArrayInsertSig{} - _ builtinFunc = &builtinJSONObjectSig{} - _ builtinFunc = &builtinJSONExtractSig{} - _ builtinFunc = &builtinJSONSetSig{} - _ builtinFunc = &builtinJSONInsertSig{} - _ builtinFunc = &builtinJSONReplaceSig{} - _ builtinFunc = &builtinJSONRemoveSig{} - _ builtinFunc = &builtinJSONMergeSig{} - _ builtinFunc = &builtinJSONMemberOfSig{} - _ builtinFunc = &builtinJSONContainsSig{} - _ builtinFunc = &builtinJSONOverlapsSig{} - _ builtinFunc = &builtinJSONStorageSizeSig{} - _ builtinFunc = &builtinJSONDepthSig{} - _ builtinFunc = &builtinJSONSchemaValidSig{} - _ builtinFunc = &builtinJSONSearchSig{} - _ builtinFunc = &builtinJSONKeysSig{} - _ builtinFunc = &builtinJSONKeys2ArgsSig{} - _ builtinFunc = &builtinJSONLengthSig{} - _ builtinFunc = &builtinJSONValidJSONSig{} - _ builtinFunc = &builtinJSONValidStringSig{} - _ builtinFunc = &builtinJSONValidOthersSig{} -) - -type jsonTypeFunctionClass struct { - baseFunctionClass -} - -type builtinJSONTypeSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONTypeSig) 
Clone() builtinFunc { - newSig := &builtinJSONTypeSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (c *jsonTypeFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETJson) - if err != nil { - return nil, err - } - charset, collate := ctx.GetCharsetInfo() - bf.tp.SetCharset(charset) - bf.tp.SetCollate(collate) - bf.tp.SetFlen(51) // flen of JSON_TYPE is length of UNSIGNED INTEGER. - bf.tp.AddFlag(mysql.BinaryFlag) - sig := &builtinJSONTypeSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonTypeSig) - return sig, nil -} - -func (b *builtinJSONTypeSig) evalString(ctx EvalContext, row chunk.Row) (val string, isNull bool, err error) { - var j types.BinaryJSON - j, isNull, err = b.args[0].EvalJSON(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - return j.Type(), false, nil -} - -type jsonExtractFunctionClass struct { - baseFunctionClass -} - -type builtinJSONExtractSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONExtractSig) Clone() builtinFunc { - newSig := &builtinJSONExtractSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (c *jsonExtractFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { - if err := c.baseFunctionClass.verifyArgs(args); err != nil { - return err - } - if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { - return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_extract") - } - return nil -} - -func (c *jsonExtractFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { - return nil, err - } - argTps := make([]types.EvalType, 0, len(args)) - argTps = append(argTps, types.ETJson) - for range args[1:] { - argTps = append(argTps, types.ETString) - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) 
- if err != nil { - return nil, err - } - sig := &builtinJSONExtractSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonExtractSig) - return sig, nil -} - -func (b *builtinJSONExtractSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { - res, isNull, err = b.args[0].EvalJSON(ctx, row) - if isNull || err != nil { - return - } - pathExprs := make([]types.JSONPathExpression, 0, len(b.args)-1) - for _, arg := range b.args[1:] { - var s string - s, isNull, err = arg.EvalString(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - pathExpr, err := types.ParseJSONPathExpr(s) - if err != nil { - return res, true, err - } - pathExprs = append(pathExprs, pathExpr) - } - var found bool - if res, found = res.Extract(pathExprs); !found { - return res, true, nil - } - return res, false, nil -} - -type jsonUnquoteFunctionClass struct { - baseFunctionClass -} - -type builtinJSONUnquoteSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONUnquoteSig) Clone() builtinFunc { - newSig := &builtinJSONUnquoteSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (c *jsonUnquoteFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { - if err := c.baseFunctionClass.verifyArgs(args); err != nil { - return err - } - if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { - return ErrIncorrectType.GenWithStackByArgs("1", "json_unquote") - } - return nil -} - -func (c *jsonUnquoteFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETString) - if err != nil { - return nil, err - } - bf.tp.SetFlen(args[0].GetType(ctx.GetEvalCtx()).GetFlen()) - bf.tp.AddFlag(mysql.BinaryFlag) - DisableParseJSONFlag4Expr(ctx.GetEvalCtx(), args[0]) - sig := &builtinJSONUnquoteSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonUnquoteSig) - return sig, nil -} - -func (b *builtinJSONUnquoteSig) evalString(ctx EvalContext, row chunk.Row) (str string, isNull bool, err error) { - str, isNull, err = b.args[0].EvalString(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - if len(str) >= 2 && str[0] == '"' && str[len(str)-1] == '"' && !goJSON.Valid([]byte(str)) { - return "", false, types.ErrInvalidJSONText.GenWithStackByArgs("The document root must not be followed by other values.") - } - str, err = types.UnquoteString(str) - if err != nil { - return "", false, err - } - return str, false, nil -} - -type jsonSetFunctionClass struct { - baseFunctionClass -} - -type builtinJSONSetSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONSetSig) Clone() builtinFunc { - newSig := &builtinJSONSetSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (c *jsonSetFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - if len(args)&1 != 1 { - return nil, ErrIncorrectParameterCount.GenWithStackByArgs(c.funcName) - } - argTps := make([]types.EvalType, 0, len(args)) - argTps = append(argTps, types.ETJson) - for i := 1; i < len(args)-1; i += 2 { - argTps = append(argTps, types.ETString, types.ETJson) - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) 
- if err != nil { - return nil, err - } - for i := 2; i < len(args); i += 2 { - DisableParseJSONFlag4Expr(ctx.GetEvalCtx(), args[i]) - } - sig := &builtinJSONSetSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonSetSig) - return sig, nil -} - -func (b *builtinJSONSetSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { - res, isNull, err = jsonModify(ctx, b.args, row, types.JSONModifySet) - return res, isNull, err -} - -type jsonInsertFunctionClass struct { - baseFunctionClass -} - -type builtinJSONInsertSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONInsertSig) Clone() builtinFunc { - newSig := &builtinJSONInsertSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (c *jsonInsertFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - if len(args)&1 != 1 { - return nil, ErrIncorrectParameterCount.GenWithStackByArgs(c.funcName) - } - argTps := make([]types.EvalType, 0, len(args)) - argTps = append(argTps, types.ETJson) - for i := 1; i < len(args)-1; i += 2 { - argTps = append(argTps, types.ETString, types.ETJson) - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) - if err != nil { - return nil, err - } - for i := 2; i < len(args); i += 2 { - DisableParseJSONFlag4Expr(ctx.GetEvalCtx(), args[i]) - } - sig := &builtinJSONInsertSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonInsertSig) - return sig, nil -} - -func (b *builtinJSONInsertSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { - res, isNull, err = jsonModify(ctx, b.args, row, types.JSONModifyInsert) - return res, isNull, err -} - -type jsonReplaceFunctionClass struct { - baseFunctionClass -} - -type builtinJSONReplaceSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONReplaceSig) Clone() builtinFunc { - newSig := &builtinJSONReplaceSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (c *jsonReplaceFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - if len(args)&1 != 1 { - return nil, ErrIncorrectParameterCount.GenWithStackByArgs(c.funcName) - } - argTps := make([]types.EvalType, 0, len(args)) - argTps = append(argTps, types.ETJson) - for i := 1; i < len(args)-1; i += 2 { - argTps = append(argTps, types.ETString, types.ETJson) - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) 
- if err != nil { - return nil, err - } - for i := 2; i < len(args); i += 2 { - DisableParseJSONFlag4Expr(ctx.GetEvalCtx(), args[i]) - } - sig := &builtinJSONReplaceSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonReplaceSig) - return sig, nil -} - -func (b *builtinJSONReplaceSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { - res, isNull, err = jsonModify(ctx, b.args, row, types.JSONModifyReplace) - return res, isNull, err -} - -type jsonRemoveFunctionClass struct { - baseFunctionClass -} - -type builtinJSONRemoveSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONRemoveSig) Clone() builtinFunc { - newSig := &builtinJSONRemoveSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (c *jsonRemoveFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - argTps := make([]types.EvalType, 0, len(args)) - argTps = append(argTps, types.ETJson) - for range args[1:] { - argTps = append(argTps, types.ETString) - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) - if err != nil { - return nil, err - } - sig := &builtinJSONRemoveSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonRemoveSig) - return sig, nil -} - -func (b *builtinJSONRemoveSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { - res, isNull, err = b.args[0].EvalJSON(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - pathExprs := make([]types.JSONPathExpression, 0, len(b.args)-1) - for _, arg := range b.args[1:] { - var s string - s, isNull, err = arg.EvalString(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - var pathExpr types.JSONPathExpression - pathExpr, err = types.ParseJSONPathExpr(s) - if err != nil { - return res, true, err - } - pathExprs = append(pathExprs, pathExpr) - } - res, err = res.Remove(pathExprs) - if err != nil { - return res, true, err - } - return res, false, nil -} - -type jsonMergeFunctionClass struct { - baseFunctionClass -} - -func (c *jsonMergeFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { - if err := c.baseFunctionClass.verifyArgs(args); err != nil { - return err - } - for i, arg := range args { - if evalType := arg.GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { - return ErrInvalidTypeForJSON.GenWithStackByArgs(i+1, "json_merge") - } - } - return nil -} - -type builtinJSONMergeSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONMergeSig) Clone() builtinFunc { - newSig := &builtinJSONMergeSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (c *jsonMergeFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { - return nil, err - } - argTps := make([]types.EvalType, 0, len(args)) - for range args { - argTps = append(argTps, types.ETJson) - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) 
- if err != nil { - return nil, err - } - sig := &builtinJSONMergeSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonMergeSig) - return sig, nil -} - -func (b *builtinJSONMergeSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { - values := make([]types.BinaryJSON, 0, len(b.args)) - for _, arg := range b.args { - var value types.BinaryJSON - value, isNull, err = arg.EvalJSON(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - values = append(values, value) - } - res = types.MergeBinaryJSON(values) - // function "JSON_MERGE" is deprecated since MySQL 5.7.22. Synonym for function "JSON_MERGE_PRESERVE". - // See https://dev.mysql.com/doc/refman/5.7/en/json-modification-functions.html#function_json-merge - if b.pbCode == tipb.ScalarFuncSig_JsonMergeSig { - tc := typeCtx(ctx) - tc.AppendWarning(errDeprecatedSyntaxNoReplacement.FastGenByArgs("JSON_MERGE", "")) - } - return res, false, nil -} - -type jsonObjectFunctionClass struct { - baseFunctionClass -} - -type builtinJSONObjectSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONObjectSig) Clone() builtinFunc { - newSig := &builtinJSONObjectSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (c *jsonObjectFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - if len(args)&1 != 0 { - return nil, ErrIncorrectParameterCount.GenWithStackByArgs(c.funcName) - } - argTps := make([]types.EvalType, 0, len(args)) - for i := 0; i < len(args)-1; i += 2 { - if args[i].GetType(ctx.GetEvalCtx()).EvalType() == types.ETString && args[i].GetType(ctx.GetEvalCtx()).GetCharset() == charset.CharsetBin { - return nil, types.ErrInvalidJSONCharset.GenWithStackByArgs(args[i].GetType(ctx.GetEvalCtx()).GetCharset()) - } - argTps = append(argTps, types.ETString, types.ETJson) - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) 
- if err != nil { - return nil, err - } - for i := 1; i < len(args); i += 2 { - DisableParseJSONFlag4Expr(ctx.GetEvalCtx(), args[i]) - } - sig := &builtinJSONObjectSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonObjectSig) - return sig, nil -} - -func (b *builtinJSONObjectSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { - if len(b.args)&1 == 1 { - err = ErrIncorrectParameterCount.GenWithStackByArgs(ast.JSONObject) - return res, true, err - } - jsons := make(map[string]any, len(b.args)>>1) - var key string - var value types.BinaryJSON - for i, arg := range b.args { - if i&1 == 0 { - key, isNull, err = arg.EvalString(ctx, row) - if err != nil { - return res, true, err - } - if isNull { - return res, true, types.ErrJSONDocumentNULLKey - } - } else { - value, isNull, err = arg.EvalJSON(ctx, row) - if err != nil { - return res, true, err - } - if isNull { - value = types.CreateBinaryJSON(nil) - } - jsons[key] = value - } - } - bj, err := types.CreateBinaryJSONWithCheck(jsons) - if err != nil { - return res, true, err - } - return bj, false, nil -} - -type jsonArrayFunctionClass struct { - baseFunctionClass -} - -type builtinJSONArraySig struct { - baseBuiltinFunc -} - -func (b *builtinJSONArraySig) Clone() builtinFunc { - newSig := &builtinJSONArraySig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (c *jsonArrayFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - argTps := make([]types.EvalType, 0, len(args)) - for range args { - argTps = append(argTps, types.ETJson) - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) - if err != nil { - return nil, err - } - for i := range args { - DisableParseJSONFlag4Expr(ctx.GetEvalCtx(), args[i]) - } - sig := &builtinJSONArraySig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonArraySig) - return sig, nil -} - -func (b *builtinJSONArraySig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { - jsons := make([]any, 0, len(b.args)) - for _, arg := range b.args { - j, isNull, err := arg.EvalJSON(ctx, row) - if err != nil { - return res, true, err - } - if isNull { - j = types.CreateBinaryJSON(nil) - } - jsons = append(jsons, j) - } - bj, err := types.CreateBinaryJSONWithCheck(jsons) - if err != nil { - return res, true, err - } - return bj, false, nil -} - -type jsonContainsPathFunctionClass struct { - baseFunctionClass -} - -type builtinJSONContainsPathSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONContainsPathSig) Clone() builtinFunc { - newSig := &builtinJSONContainsPathSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (c *jsonContainsPathFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { - if err := c.baseFunctionClass.verifyArgs(args); err != nil { - return err - } - if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { - return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_contains_path") - } - return nil -} - -func (c *jsonContainsPathFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { - return nil, err - } - argTps := []types.EvalType{types.ETJson, types.ETString} - for i := 3; i <= len(args); i++ { - argTps = append(argTps, types.ETString) - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, 
argTps...) - if err != nil { - return nil, err - } - sig := &builtinJSONContainsPathSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonContainsPathSig) - return sig, nil -} - -func (b *builtinJSONContainsPathSig) evalInt(ctx EvalContext, row chunk.Row) (res int64, isNull bool, err error) { - obj, isNull, err := b.args[0].EvalJSON(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - containType, isNull, err := b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - containType = strings.ToLower(containType) - if containType != types.JSONContainsPathAll && containType != types.JSONContainsPathOne { - return res, true, types.ErrJSONBadOneOrAllArg.GenWithStackByArgs("json_contains_path") - } - var pathExpr types.JSONPathExpression - contains := int64(1) - for i := 2; i < len(b.args); i++ { - path, isNull, err := b.args[i].EvalString(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - if pathExpr, err = types.ParseJSONPathExpr(path); err != nil { - return res, true, err - } - _, exists := obj.Extract([]types.JSONPathExpression{pathExpr}) - switch { - case exists && containType == types.JSONContainsPathOne: - return 1, false, nil - case !exists && containType == types.JSONContainsPathOne: - contains = 0 - case !exists && containType == types.JSONContainsPathAll: - return 0, false, nil - } - } - return contains, false, nil -} - -func jsonModify(ctx EvalContext, args []Expression, row chunk.Row, mt types.JSONModifyType) (res types.BinaryJSON, isNull bool, err error) { - res, isNull, err = args[0].EvalJSON(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - pathExprs := make([]types.JSONPathExpression, 0, (len(args)-1)/2+1) - for i := 1; i < len(args); i += 2 { - // TODO: We can cache pathExprs if args are constants. - var s string - s, isNull, err = args[i].EvalString(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - var pathExpr types.JSONPathExpression - pathExpr, err = types.ParseJSONPathExpr(s) - if err != nil { - return res, true, err - } - pathExprs = append(pathExprs, pathExpr) - } - values := make([]types.BinaryJSON, 0, (len(args)-1)/2+1) - for i := 2; i < len(args); i += 2 { - var value types.BinaryJSON - value, isNull, err = args[i].EvalJSON(ctx, row) - if err != nil { - return res, true, err - } - if isNull { - value = types.CreateBinaryJSON(nil) - } - values = append(values, value) - } - res, err = res.Modify(pathExprs, values, mt) - if err != nil { - return res, true, err - } - return res, false, nil -} - -type jsonMemberOfFunctionClass struct { - baseFunctionClass -} - -type builtinJSONMemberOfSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONMemberOfSig) Clone() builtinFunc { - newSig := &builtinJSONMemberOfSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (c *jsonMemberOfFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { - if err := c.baseFunctionClass.verifyArgs(args); err != nil { - return err - } - if evalType := args[1].GetType(ctx).EvalType(); evalType != types.ETJson && evalType != types.ETString { - return ErrInvalidTypeForJSON.GenWithStackByArgs(2, "member of") - } - return nil -} - -func (c *jsonMemberOfFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { - return nil, err - } - argTps := []types.EvalType{types.ETJson, types.ETJson} - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, argTps...) 
- if err != nil { - return nil, err - } - DisableParseJSONFlag4Expr(ctx.GetEvalCtx(), args[0]) - sig := &builtinJSONMemberOfSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonMemberOfSig) - return sig, nil -} - -func (b *builtinJSONMemberOfSig) evalInt(ctx EvalContext, row chunk.Row) (res int64, isNull bool, err error) { - target, isNull, err := b.args[0].EvalJSON(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - obj, isNull, err := b.args[1].EvalJSON(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - - if obj.TypeCode != types.JSONTypeCodeArray { - return boolToInt64(types.CompareBinaryJSON(obj, target) == 0), false, nil - } - - elemCount := obj.GetElemCount() - for i := 0; i < elemCount; i++ { - if types.CompareBinaryJSON(obj.ArrayGetElem(i), target) == 0 { - return 1, false, nil - } - } - - return 0, false, nil -} - -type jsonContainsFunctionClass struct { - baseFunctionClass -} - -type builtinJSONContainsSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONContainsSig) Clone() builtinFunc { - newSig := &builtinJSONContainsSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (c *jsonContainsFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { - if err := c.baseFunctionClass.verifyArgs(args); err != nil { - return err - } - if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETJson && evalType != types.ETString { - return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_contains") - } - if evalType := args[1].GetType(ctx).EvalType(); evalType != types.ETJson && evalType != types.ETString { - return ErrInvalidTypeForJSON.GenWithStackByArgs(2, "json_contains") - } - return nil -} - -func (c *jsonContainsFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { - return nil, err - } - - argTps := []types.EvalType{types.ETJson, types.ETJson} - if len(args) == 3 { - argTps = append(argTps, types.ETString) - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, argTps...) 
- if err != nil { - return nil, err - } - sig := &builtinJSONContainsSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonContainsSig) - return sig, nil -} - -func (b *builtinJSONContainsSig) evalInt(ctx EvalContext, row chunk.Row) (res int64, isNull bool, err error) { - obj, isNull, err := b.args[0].EvalJSON(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - target, isNull, err := b.args[1].EvalJSON(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - var pathExpr types.JSONPathExpression - if len(b.args) == 3 { - path, isNull, err := b.args[2].EvalString(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - pathExpr, err = types.ParseJSONPathExpr(path) - if err != nil { - return res, true, err - } - if pathExpr.CouldMatchMultipleValues() { - return res, true, types.ErrInvalidJSONPathMultipleSelection - } - var exists bool - obj, exists = obj.Extract([]types.JSONPathExpression{pathExpr}) - if !exists { - return res, true, nil - } - } - - if types.ContainsBinaryJSON(obj, target) { - return 1, false, nil - } - return 0, false, nil -} - -type jsonOverlapsFunctionClass struct { - baseFunctionClass -} - -type builtinJSONOverlapsSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONOverlapsSig) Clone() builtinFunc { - newSig := &builtinJSONOverlapsSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (c *jsonOverlapsFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { - if err := c.baseFunctionClass.verifyArgs(args); err != nil { - return err - } - if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETJson && evalType != types.ETString { - return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_overlaps") - } - if evalType := args[1].GetType(ctx).EvalType(); evalType != types.ETJson && evalType != types.ETString { - return ErrInvalidTypeForJSON.GenWithStackByArgs(2, "json_overlaps") - } - return nil -} - -func (c *jsonOverlapsFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { - return nil, err - } - - argTps := []types.EvalType{types.ETJson, types.ETJson} - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, argTps...) 
- if err != nil { - return nil, err - } - sig := &builtinJSONOverlapsSig{bf} - return sig, nil -} - -func (b *builtinJSONOverlapsSig) evalInt(ctx EvalContext, row chunk.Row) (res int64, isNull bool, err error) { - obj, isNull, err := b.args[0].EvalJSON(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - target, isNull, err := b.args[1].EvalJSON(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - if types.OverlapsBinaryJSON(obj, target) { - return 1, false, nil - } - return 0, false, nil -} - -type jsonValidFunctionClass struct { - baseFunctionClass -} - -func (c *jsonValidFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - - var sig builtinFunc - argType := args[0].GetType(ctx.GetEvalCtx()).EvalType() - switch argType { - case types.ETJson: - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETJson) - if err != nil { - return nil, err - } - sig = &builtinJSONValidJSONSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonValidJsonSig) - case types.ETString: - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETString) - if err != nil { - return nil, err - } - sig = &builtinJSONValidStringSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonValidStringSig) - default: - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, argType) - if err != nil { - return nil, err - } - sig = &builtinJSONValidOthersSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonValidOthersSig) - } - return sig, nil -} - -type builtinJSONValidJSONSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONValidJSONSig) Clone() builtinFunc { - newSig := &builtinJSONValidJSONSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals a builtinJSONValidJSONSig. -// See https://dev.mysql.com/doc/refman/5.7/en/json-attribute-functions.html#function_json-valid -func (b *builtinJSONValidJSONSig) evalInt(ctx EvalContext, row chunk.Row) (val int64, isNull bool, err error) { - _, isNull, err = b.args[0].EvalJSON(ctx, row) - return 1, isNull, err -} - -type builtinJSONValidStringSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONValidStringSig) Clone() builtinFunc { - newSig := &builtinJSONValidStringSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals a builtinJSONValidStringSig. -// See https://dev.mysql.com/doc/refman/5.7/en/json-attribute-functions.html#function_json-valid -func (b *builtinJSONValidStringSig) evalInt(ctx EvalContext, row chunk.Row) (res int64, isNull bool, err error) { - val, isNull, err := b.args[0].EvalString(ctx, row) - if err != nil || isNull { - return 0, isNull, err - } - - data := hack.Slice(val) - if goJSON.Valid(data) { - res = 1 - } - return res, false, nil -} - -type builtinJSONValidOthersSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONValidOthersSig) Clone() builtinFunc { - newSig := &builtinJSONValidOthersSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals a builtinJSONValidOthersSig. 
-// See https://dev.mysql.com/doc/refman/5.7/en/json-attribute-functions.html#function_json-valid -func (b *builtinJSONValidOthersSig) evalInt(ctx EvalContext, row chunk.Row) (val int64, isNull bool, err error) { - return 0, false, nil -} - -type jsonArrayAppendFunctionClass struct { - baseFunctionClass -} - -type builtinJSONArrayAppendSig struct { - baseBuiltinFunc -} - -func (c *jsonArrayAppendFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { - if len(args) < 3 || (len(args)&1 != 1) { - return ErrIncorrectParameterCount.GenWithStackByArgs(c.funcName) - } - return nil -} - -func (c *jsonArrayAppendFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { - return nil, err - } - argTps := make([]types.EvalType, 0, len(args)) - argTps = append(argTps, types.ETJson) - for i := 1; i < len(args)-1; i += 2 { - argTps = append(argTps, types.ETString, types.ETJson) - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) - if err != nil { - return nil, err - } - for i := 2; i < len(args); i += 2 { - DisableParseJSONFlag4Expr(ctx.GetEvalCtx(), args[i]) - } - sig := &builtinJSONArrayAppendSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonArrayAppendSig) - return sig, nil -} - -func (b *builtinJSONArrayAppendSig) Clone() builtinFunc { - newSig := &builtinJSONArrayAppendSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (b *builtinJSONArrayAppendSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { - res, isNull, err = b.args[0].EvalJSON(ctx, row) - if err != nil || isNull { - return res, true, err - } - - for i := 1; i < len(b.args)-1; i += 2 { - // If JSON path is NULL, MySQL breaks and returns NULL. - s, sNull, err := b.args[i].EvalString(ctx, row) - if sNull || err != nil { - return res, true, err - } - value, vNull, err := b.args[i+1].EvalJSON(ctx, row) - if err != nil { - return res, true, err - } - if vNull { - value = types.CreateBinaryJSON(nil) - } - res, isNull, err = b.appendJSONArray(res, s, value) - if isNull || err != nil { - return res, isNull, err - } - } - return res, false, nil -} - -func (b *builtinJSONArrayAppendSig) appendJSONArray(res types.BinaryJSON, p string, v types.BinaryJSON) (types.BinaryJSON, bool, error) { - // We should do the following checks to get correct values in res.Extract - pathExpr, err := types.ParseJSONPathExpr(p) - if err != nil { - return res, true, err - } - if pathExpr.CouldMatchMultipleValues() { - return res, true, types.ErrInvalidJSONPathMultipleSelection - } - - obj, exists := res.Extract([]types.JSONPathExpression{pathExpr}) - if !exists { - // If path not exists, just do nothing and no errors. - return res, false, nil - } - - if obj.TypeCode != types.JSONTypeCodeArray { - // res.Extract will return a json object instead of an array if there is an object at path pathExpr. - // JSON_ARRAY_APPEND({"a": "b"}, "$", {"b": "c"}) => [{"a": "b"}, {"b", "c"}] - // We should wrap them to a single array first. 
- obj, err = types.CreateBinaryJSONWithCheck([]any{obj}) - if err != nil { - return res, true, err - } - } - - obj = types.MergeBinaryJSON([]types.BinaryJSON{obj, v}) - res, err = res.Modify([]types.JSONPathExpression{pathExpr}, []types.BinaryJSON{obj}, types.JSONModifySet) - return res, false, err -} - -type jsonArrayInsertFunctionClass struct { - baseFunctionClass -} - -type builtinJSONArrayInsertSig struct { - baseBuiltinFunc -} - -func (c *jsonArrayInsertFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - if len(args)&1 != 1 { - return nil, ErrIncorrectParameterCount.GenWithStackByArgs(c.funcName) - } - - argTps := make([]types.EvalType, 0, len(args)) - argTps = append(argTps, types.ETJson) - for i := 1; i < len(args)-1; i += 2 { - argTps = append(argTps, types.ETString, types.ETJson) - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) - if err != nil { - return nil, err - } - for i := 2; i < len(args); i += 2 { - DisableParseJSONFlag4Expr(ctx.GetEvalCtx(), args[i]) - } - sig := &builtinJSONArrayInsertSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonArrayInsertSig) - return sig, nil -} - -func (b *builtinJSONArrayInsertSig) Clone() builtinFunc { - newSig := &builtinJSONArrayInsertSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (b *builtinJSONArrayInsertSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { - res, isNull, err = b.args[0].EvalJSON(ctx, row) - if err != nil || isNull { - return res, true, err - } - - for i := 1; i < len(b.args)-1; i += 2 { - // If JSON path is NULL, MySQL breaks and returns NULL. - s, isNull, err := b.args[i].EvalString(ctx, row) - if err != nil || isNull { - return res, true, err - } - - pathExpr, err := types.ParseJSONPathExpr(s) - if err != nil { - return res, true, err - } - if pathExpr.CouldMatchMultipleValues() { - return res, true, types.ErrInvalidJSONPathMultipleSelection - } - - value, isnull, err := b.args[i+1].EvalJSON(ctx, row) - if err != nil { - return res, true, err - } - - if isnull { - value = types.CreateBinaryJSON(nil) - } - - res, err = res.ArrayInsert(pathExpr, value) - if err != nil { - return res, true, err - } - } - return res, false, nil -} - -type jsonMergePatchFunctionClass struct { - baseFunctionClass -} - -func (c *jsonMergePatchFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { - if err := c.baseFunctionClass.verifyArgs(args); err != nil { - return err - } - for i, arg := range args { - if evalType := arg.GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { - return ErrInvalidTypeForJSON.GenWithStackByArgs(i+1, "json_merge_patch") - } - } - return nil -} - -func (c *jsonMergePatchFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { - return nil, err - } - argTps := make([]types.EvalType, 0, len(args)) - for range args { - argTps = append(argTps, types.ETJson) - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) 
- if err != nil { - return nil, err - } - sig := &builtinJSONMergePatchSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonMergePatchSig) - return sig, nil -} - -type builtinJSONMergePatchSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONMergePatchSig) Clone() builtinFunc { - newSig := &builtinJSONMergePatchSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (b *builtinJSONMergePatchSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { - values := make([]*types.BinaryJSON, 0, len(b.args)) - for _, arg := range b.args { - var value types.BinaryJSON - value, isNull, err = arg.EvalJSON(ctx, row) - if err != nil { - return - } - if isNull { - values = append(values, nil) - } else { - values = append(values, &value) - } - } - tmpRes, err := types.MergePatchBinaryJSON(values) - if err != nil { - return - } - if tmpRes != nil { - res = *tmpRes - } else { - isNull = true - } - return res, isNull, nil -} - -type jsonMergePreserveFunctionClass struct { - baseFunctionClass -} - -func (c *jsonMergePreserveFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { - if err := c.baseFunctionClass.verifyArgs(args); err != nil { - return err - } - for i, arg := range args { - if evalType := arg.GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { - return ErrInvalidTypeForJSON.GenWithStackByArgs(i+1, "json_merge_preserve") - } - } - return nil -} - -func (c *jsonMergePreserveFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { - return nil, err - } - argTps := make([]types.EvalType, 0, len(args)) - for range args { - argTps = append(argTps, types.ETJson) - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) 
- if err != nil { - return nil, err - } - sig := &builtinJSONMergeSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonMergePreserveSig) - return sig, nil -} - -type jsonPrettyFunctionClass struct { - baseFunctionClass -} - -type builtinJSONSPrettySig struct { - baseBuiltinFunc -} - -func (b *builtinJSONSPrettySig) Clone() builtinFunc { - newSig := &builtinJSONSPrettySig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (c *jsonPrettyFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETJson) - if err != nil { - return nil, err - } - bf.tp.AddFlag(mysql.BinaryFlag) - bf.tp.SetFlen(mysql.MaxBlobWidth * 4) - sig := &builtinJSONSPrettySig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonPrettySig) - return sig, nil -} - -func (b *builtinJSONSPrettySig) evalString(ctx EvalContext, row chunk.Row) (res string, isNull bool, err error) { - obj, isNull, err := b.args[0].EvalJSON(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - - buf, err := obj.MarshalJSON() - if err != nil { - return res, isNull, err - } - var resBuf bytes.Buffer - if err = goJSON.Indent(&resBuf, buf, "", " "); err != nil { - return res, isNull, err - } - return resBuf.String(), false, nil -} - -type jsonQuoteFunctionClass struct { - baseFunctionClass -} - -type builtinJSONQuoteSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONQuoteSig) Clone() builtinFunc { - newSig := &builtinJSONQuoteSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (c *jsonQuoteFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { - if err := c.baseFunctionClass.verifyArgs(args); err != nil { - return err - } - if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString { - return ErrIncorrectType.GenWithStackByArgs("1", "json_quote") - } - return nil -} - -func (c *jsonQuoteFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETString) - if err != nil { - return nil, err - } - DisableParseJSONFlag4Expr(ctx.GetEvalCtx(), args[0]) - bf.tp.AddFlag(mysql.BinaryFlag) - bf.tp.SetFlen(args[0].GetType(ctx.GetEvalCtx()).GetFlen()*6 + 2) - sig := &builtinJSONQuoteSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonQuoteSig) - return sig, nil -} - -func (b *builtinJSONQuoteSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { - str, isNull, err := b.args[0].EvalString(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - buffer := &bytes.Buffer{} - encoder := goJSON.NewEncoder(buffer) - encoder.SetEscapeHTML(false) - err = encoder.Encode(str) - if err != nil { - return "", isNull, err - } - return string(bytes.TrimSuffix(buffer.Bytes(), []byte("\n"))), false, nil -} - -type jsonSearchFunctionClass struct { - baseFunctionClass -} - -type builtinJSONSearchSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONSearchSig) Clone() builtinFunc { - newSig := &builtinJSONSearchSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (c *jsonSearchFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { - if err := c.baseFunctionClass.verifyArgs(args); err != nil { - return err - } - if evalType := args[0].GetType(ctx).EvalType(); evalType != 
types.ETString && evalType != types.ETJson { - return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_search") - } - return nil -} - -func (c *jsonSearchFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { - return nil, err - } - // json_doc, one_or_all, search_str[, escape_char[, path] ...]) - argTps := make([]types.EvalType, 0, len(args)) - argTps = append(argTps, types.ETJson) - for range args[1:] { - argTps = append(argTps, types.ETString) - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) - if err != nil { - return nil, err - } - sig := &builtinJSONSearchSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonSearchSig) - return sig, nil -} - -func (b *builtinJSONSearchSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { - // json_doc - var obj types.BinaryJSON - obj, isNull, err = b.args[0].EvalJSON(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - - // one_or_all - var containType string - containType, isNull, err = b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - containType = strings.ToLower(containType) - if containType != types.JSONContainsPathAll && containType != types.JSONContainsPathOne { - return res, true, errors.AddStack(types.ErrInvalidJSONContainsPathType) - } - - // search_str & escape_char - var searchStr string - searchStr, isNull, err = b.args[2].EvalString(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - escape := byte('\\') - if len(b.args) >= 4 { - var escapeStr string - escapeStr, isNull, err = b.args[3].EvalString(ctx, row) - if err != nil { - return res, isNull, err - } - if isNull || len(escapeStr) == 0 { - escape = byte('\\') - } else if len(escapeStr) == 1 { - escape = escapeStr[0] - } else { - return res, true, errIncorrectArgs.GenWithStackByArgs("ESCAPE") - } - } - if len(b.args) >= 5 { // path... 
- pathExprs := make([]types.JSONPathExpression, 0, len(b.args)-4) - for i := 4; i < len(b.args); i++ { - var s string - s, isNull, err = b.args[i].EvalString(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - var pathExpr types.JSONPathExpression - pathExpr, err = types.ParseJSONPathExpr(s) - if err != nil { - return res, true, err - } - pathExprs = append(pathExprs, pathExpr) - } - return obj.Search(containType, searchStr, escape, pathExprs) - } - return obj.Search(containType, searchStr, escape, nil) -} - -type jsonStorageFreeFunctionClass struct { - baseFunctionClass -} - -type builtinJSONStorageFreeSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONStorageFreeSig) Clone() builtinFunc { - newSig := &builtinJSONStorageFreeSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (c *jsonStorageFreeFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETJson) - if err != nil { - return nil, err - } - sig := &builtinJSONStorageFreeSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonStorageFreeSig) - return sig, nil -} - -func (b *builtinJSONStorageFreeSig) evalInt(ctx EvalContext, row chunk.Row) (res int64, isNull bool, err error) { - _, isNull, err = b.args[0].EvalJSON(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - - return 0, false, nil -} - -type jsonStorageSizeFunctionClass struct { - baseFunctionClass -} - -type builtinJSONStorageSizeSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONStorageSizeSig) Clone() builtinFunc { - newSig := &builtinJSONStorageSizeSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (c *jsonStorageSizeFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETJson) - if err != nil { - return nil, err - } - sig := &builtinJSONStorageSizeSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonStorageSizeSig) - return sig, nil -} - -func (b *builtinJSONStorageSizeSig) evalInt(ctx EvalContext, row chunk.Row) (res int64, isNull bool, err error) { - obj, isNull, err := b.args[0].EvalJSON(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - - // returns the length of obj value plus 1 (the TypeCode) - return int64(len(obj.Value)) + 1, false, nil -} - -type jsonDepthFunctionClass struct { - baseFunctionClass -} - -type builtinJSONDepthSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONDepthSig) Clone() builtinFunc { - newSig := &builtinJSONDepthSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (c *jsonDepthFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETJson) - if err != nil { - return nil, err - } - sig := &builtinJSONDepthSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonDepthSig) - return sig, nil -} - -func (b *builtinJSONDepthSig) evalInt(ctx EvalContext, row chunk.Row) (res int64, isNull bool, err error) { - // as TiDB doesn't support partial update json value, so only check the - // json format and whether it's NULL. 
For NULL return NULL, for invalid json, return - // an error, otherwise return 0 - - obj, isNull, err := b.args[0].EvalJSON(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - - return int64(obj.GetElemDepth()), false, nil -} - -type jsonKeysFunctionClass struct { - baseFunctionClass -} - -func (c *jsonKeysFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { - if err := c.baseFunctionClass.verifyArgs(args); err != nil { - return err - } - if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { - return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_keys") - } - return nil -} - -func (c *jsonKeysFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { - return nil, err - } - argTps := []types.EvalType{types.ETJson} - if len(args) == 2 { - argTps = append(argTps, types.ETString) - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETJson, argTps...) - if err != nil { - return nil, err - } - var sig builtinFunc - switch len(args) { - case 1: - sig = &builtinJSONKeysSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonKeysSig) - case 2: - sig = &builtinJSONKeys2ArgsSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonKeys2ArgsSig) - } - return sig, nil -} - -type builtinJSONKeysSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONKeysSig) Clone() builtinFunc { - newSig := &builtinJSONKeysSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (b *builtinJSONKeysSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { - res, isNull, err = b.args[0].EvalJSON(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - if res.TypeCode != types.JSONTypeCodeObject { - return res, true, nil - } - return res.GetKeys(), false, nil -} - -type builtinJSONKeys2ArgsSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONKeys2ArgsSig) Clone() builtinFunc { - newSig := &builtinJSONKeys2ArgsSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (b *builtinJSONKeys2ArgsSig) evalJSON(ctx EvalContext, row chunk.Row) (res types.BinaryJSON, isNull bool, err error) { - res, isNull, err = b.args[0].EvalJSON(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - - path, isNull, err := b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - - pathExpr, err := types.ParseJSONPathExpr(path) - if err != nil { - return res, true, err - } - if pathExpr.CouldMatchMultipleValues() { - return res, true, types.ErrInvalidJSONPathMultipleSelection - } - - res, exists := res.Extract([]types.JSONPathExpression{pathExpr}) - if !exists { - return res, true, nil - } - if res.TypeCode != types.JSONTypeCodeObject { - return res, true, nil - } - - return res.GetKeys(), false, nil -} - -type jsonLengthFunctionClass struct { - baseFunctionClass -} - -type builtinJSONLengthSig struct { - baseBuiltinFunc -} - -func (b *builtinJSONLengthSig) Clone() builtinFunc { - newSig := &builtinJSONLengthSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (c *jsonLengthFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - - argTps := make([]types.EvalType, 0, len(args)) - argTps = append(argTps, types.ETJson) - if len(args) == 2 { - argTps = append(argTps, types.ETString) - } - - bf, err := 
newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, argTps...) - if err != nil { - return nil, err - } - sig := &builtinJSONLengthSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_JsonLengthSig) - return sig, nil -} - -func (b *builtinJSONLengthSig) evalInt(ctx EvalContext, row chunk.Row) (res int64, isNull bool, err error) { - obj, isNull, err := b.args[0].EvalJSON(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - - if len(b.args) == 2 { - path, isNull, err := b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - - pathExpr, err := types.ParseJSONPathExpr(path) - if err != nil { - return res, true, err - } - if pathExpr.CouldMatchMultipleValues() { - return res, true, types.ErrInvalidJSONPathMultipleSelection - } - - var exists bool - obj, exists = obj.Extract([]types.JSONPathExpression{pathExpr}) - if !exists { - return res, true, nil - } - } - - if obj.TypeCode != types.JSONTypeCodeObject && obj.TypeCode != types.JSONTypeCodeArray { - return 1, false, nil - } - return int64(obj.GetElemCount()), false, nil -} - -type jsonSchemaValidFunctionClass struct { - baseFunctionClass -} - -func (c *jsonSchemaValidFunctionClass) verifyArgs(ctx EvalContext, args []Expression) error { - if err := c.baseFunctionClass.verifyArgs(args); err != nil { - return err - } - if evalType := args[0].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { - return ErrInvalidTypeForJSON.GenWithStackByArgs(1, "json_schema_valid") - } - if evalType := args[1].GetType(ctx).EvalType(); evalType != types.ETString && evalType != types.ETJson { - return ErrInvalidTypeForJSON.GenWithStackByArgs(2, "json_schema_valid") - } - if c, ok := args[0].(*Constant); ok { - // If args[0] is NULL, then don't check the length of *both* arguments. 
- // JSON_SCHEMA_VALID(NULL,NULL) -> NULL - // JSON_SCHEMA_VALID(NULL,'') -> NULL - // JSON_SCHEMA_VALID('',NULL) -> ErrInvalidJSONTextInParam - if !c.Value.IsNull() { - if len(c.Value.GetBytes()) == 0 { - return types.ErrInvalidJSONTextInParam.GenWithStackByArgs( - 1, "json_schema_valid", "The document is empty.", 0) - } - if c1, ok := args[1].(*Constant); ok { - if !c1.Value.IsNull() && len(c1.Value.GetBytes()) == 0 { - return types.ErrInvalidJSONTextInParam.GenWithStackByArgs( - 2, "json_schema_valid", "The document is empty.", 0) - } - } - } - } - return nil -} - -func (c *jsonSchemaValidFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(ctx.GetEvalCtx(), args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETJson, types.ETJson) - if err != nil { - return nil, err - } - - sig := &builtinJSONSchemaValidSig{baseBuiltinFunc: bf} - return sig, nil -} - -type builtinJSONSchemaValidSig struct { - baseBuiltinFunc - - schemaCache builtinFuncCache[jsonschema.Schema] -} - -func (b *builtinJSONSchemaValidSig) Clone() builtinFunc { - newSig := &builtinJSONSchemaValidSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (b *builtinJSONSchemaValidSig) evalInt(ctx EvalContext, row chunk.Row) (res int64, isNull bool, err error) { - var schema jsonschema.Schema - - // First argument is the schema - schemaData, schemaIsNull, err := b.args[0].EvalJSON(ctx, row) - if err != nil { - return res, false, err - } - if schemaIsNull { - return res, true, err - } - - if b.args[0].ConstLevel() >= ConstOnlyInContext { - schema, err = b.schemaCache.getOrInitCache(ctx, func() (jsonschema.Schema, error) { - failpoint.Inject("jsonSchemaValidDisableCacheRefresh", func() { - failpoint.Return(jsonschema.Schema{}, errors.New("Cache refresh disabled by failpoint")) - }) - dataBin, err := schemaData.MarshalJSON() - if err != nil { - return jsonschema.Schema{}, err - } - if err := goJSON.Unmarshal(dataBin, &schema); err != nil { - if _, ok := err.(*goJSON.UnmarshalTypeError); ok { - return jsonschema.Schema{}, - types.ErrInvalidJSONType.GenWithStackByArgs(1, "json_schema_valid", "object") - } - return jsonschema.Schema{}, - types.ErrInvalidJSONType.GenWithStackByArgs(1, "json_schema_valid", err) - } - return schema, nil - }) - if err != nil { - return res, false, err - } - } else { - dataBin, err := schemaData.MarshalJSON() - if err != nil { - return res, false, err - } - if err := goJSON.Unmarshal(dataBin, &schema); err != nil { - if _, ok := err.(*goJSON.UnmarshalTypeError); ok { - return res, false, - types.ErrInvalidJSONType.GenWithStackByArgs(1, "json_schema_valid", "object") - } - return res, false, - types.ErrInvalidJSONType.GenWithStackByArgs(1, "json_schema_valid", err) - } - } - - // Second argument is the JSON document - docData, docIsNull, err := b.args[1].EvalJSON(ctx, row) - if err != nil { - return res, false, err - } - if docIsNull { - return res, true, err - } - docDataBin, err := docData.MarshalJSON() - if err != nil { - return res, false, err - } - errs, err := schema.ValidateBytes(context.Background(), docDataBin) - if err != nil { - return res, false, err - } - if len(errs) > 0 { - return res, false, nil - } - res = 1 - return res, false, nil -} diff --git a/pkg/expression/builtin_time.go b/pkg/expression/builtin_time.go index 8af6b1796dfce..e2c58fd32699c 100644 --- a/pkg/expression/builtin_time.go +++ b/pkg/expression/builtin_time.go @@ -2507,9 +2507,9 @@ 
func evalNowWithFsp(ctx EvalContext, fsp int) (types.Time, bool, error) {
 		return types.ZeroTime, true, err
 	}
 
-	if val, _err_ := failpoint.Eval(_curpkg_("injectNow")); _err_ == nil {
+	failpoint.Inject("injectNow", func(val failpoint.Value) {
 		nowTs = time.Unix(int64(val.(int)), 0)
-	}
+	})
 
 	// In MySQL's implementation, now() will truncate the result instead of rounding it.
 	// Results below are from MySQL 5.7, which can prove it.
@@ -6723,10 +6723,10 @@ func GetStmtMinSafeTime(sc *stmtctx.StatementContext, store kv.Storage, timeZone
 		minSafeTS = store.GetMinSafeTS(txnScope)
 	}
 	// Inject mocked SafeTS for test.
-	if val, _err_ := failpoint.Eval(_curpkg_("injectSafeTS")); _err_ == nil {
+	failpoint.Inject("injectSafeTS", func(val failpoint.Value) {
 		injectTS := val.(int)
 		minSafeTS = uint64(injectTS)
-	}
+	})
 	// Try to get from the stmt cache to make sure this function is deterministic.
 	minSafeTS = sc.GetOrStoreStmtCache(stmtctx.StmtSafeTSCacheKey, minSafeTS).(uint64)
 	return oracle.GetTimeFromTS(minSafeTS).In(timeZone)
diff --git a/pkg/expression/builtin_time.go__failpoint_stash__ b/pkg/expression/builtin_time.go__failpoint_stash__
deleted file mode 100644
index e2c58fd32699c..0000000000000
--- a/pkg/expression/builtin_time.go__failpoint_stash__
+++ /dev/null
@@ -1,6832 +0,0 @@
-// Copyright 2015 PingCAP, Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Copyright 2013 The ql Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSES/QL-LICENSE file.
-
-package expression
-
-import (
-	"context"
-	"fmt"
-	"math"
-	"regexp"
-	"strconv"
-	"strings"
-	"time"
-
-	"github.com/pingcap/errors"
-	"github.com/pingcap/failpoint"
-	"github.com/pingcap/tidb/pkg/config"
-	"github.com/pingcap/tidb/pkg/errctx"
-	"github.com/pingcap/tidb/pkg/expression/contextopt"
-	"github.com/pingcap/tidb/pkg/kv"
-	"github.com/pingcap/tidb/pkg/parser/ast"
-	"github.com/pingcap/tidb/pkg/parser/mysql"
-	"github.com/pingcap/tidb/pkg/parser/terror"
-	"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
-	"github.com/pingcap/tidb/pkg/types"
-	"github.com/pingcap/tidb/pkg/util/chunk"
-	"github.com/pingcap/tidb/pkg/util/logutil"
-	"github.com/pingcap/tidb/pkg/util/mathutil"
-	"github.com/pingcap/tidb/pkg/util/parser"
-	"github.com/pingcap/tipb/go-tipb"
-	"github.com/tikv/client-go/v2/oracle"
-	"go.uber.org/zap"
-)
-
-const ( // GET_FORMAT first argument.
-	dateFormat      = "DATE"
-	datetimeFormat  = "DATETIME"
-	timestampFormat = "TIMESTAMP"
-	timeFormat      = "TIME"
-)
-
-const ( // GET_FORMAT location.
-	usaLocation      = "USA"
-	jisLocation      = "JIS"
-	isoLocation      = "ISO"
-	eurLocation      = "EUR"
-	internalLocation = "INTERNAL"
-)
-
-var (
-	// durationPattern checks whether a string matches the format of duration.
-	durationPattern = regexp.MustCompile(`^\s*[-]?(((\d{1,2}\s+)?0*\d{0,3}(:0*\d{1,2}){0,2})|(\d{1,7}))?(\.\d*)?\s*$`)
-
-	// timestampPattern checks whether a string matches the format of timestamp.
- timestampPattern = regexp.MustCompile(`^\s*0*\d{1,4}([^\d]0*\d{1,2}){2}\s+(0*\d{0,2}([^\d]0*\d{1,2}){2})?(\.\d*)?\s*$`) - - // datePattern determine whether to match the format of date. - datePattern = regexp.MustCompile(`^\s*((0*\d{1,4}([^\d]0*\d{1,2}){2})|(\d{2,4}(\d{2}){2}))\s*$`) -) - -var ( - _ functionClass = &dateFunctionClass{} - _ functionClass = &dateLiteralFunctionClass{} - _ functionClass = &dateDiffFunctionClass{} - _ functionClass = &timeDiffFunctionClass{} - _ functionClass = &dateFormatFunctionClass{} - _ functionClass = &hourFunctionClass{} - _ functionClass = &minuteFunctionClass{} - _ functionClass = &secondFunctionClass{} - _ functionClass = µSecondFunctionClass{} - _ functionClass = &monthFunctionClass{} - _ functionClass = &monthNameFunctionClass{} - _ functionClass = &nowFunctionClass{} - _ functionClass = &dayNameFunctionClass{} - _ functionClass = &dayOfMonthFunctionClass{} - _ functionClass = &dayOfWeekFunctionClass{} - _ functionClass = &dayOfYearFunctionClass{} - _ functionClass = &weekFunctionClass{} - _ functionClass = &weekDayFunctionClass{} - _ functionClass = &weekOfYearFunctionClass{} - _ functionClass = &yearFunctionClass{} - _ functionClass = &yearWeekFunctionClass{} - _ functionClass = &fromUnixTimeFunctionClass{} - _ functionClass = &getFormatFunctionClass{} - _ functionClass = &strToDateFunctionClass{} - _ functionClass = &sysDateFunctionClass{} - _ functionClass = ¤tDateFunctionClass{} - _ functionClass = ¤tTimeFunctionClass{} - _ functionClass = &timeFunctionClass{} - _ functionClass = &timeLiteralFunctionClass{} - _ functionClass = &utcDateFunctionClass{} - _ functionClass = &utcTimestampFunctionClass{} - _ functionClass = &extractFunctionClass{} - _ functionClass = &unixTimestampFunctionClass{} - _ functionClass = &addTimeFunctionClass{} - _ functionClass = &convertTzFunctionClass{} - _ functionClass = &makeDateFunctionClass{} - _ functionClass = &makeTimeFunctionClass{} - _ functionClass = &periodAddFunctionClass{} - _ functionClass = &periodDiffFunctionClass{} - _ functionClass = &quarterFunctionClass{} - _ functionClass = &secToTimeFunctionClass{} - _ functionClass = &subTimeFunctionClass{} - _ functionClass = &timeFormatFunctionClass{} - _ functionClass = &timeToSecFunctionClass{} - _ functionClass = ×tampAddFunctionClass{} - _ functionClass = &toDaysFunctionClass{} - _ functionClass = &toSecondsFunctionClass{} - _ functionClass = &utcTimeFunctionClass{} - _ functionClass = ×tampFunctionClass{} - _ functionClass = ×tampLiteralFunctionClass{} - _ functionClass = &lastDayFunctionClass{} - _ functionClass = &addSubDateFunctionClass{} -) - -var ( - _ builtinFunc = &builtinDateSig{} - _ builtinFunc = &builtinDateLiteralSig{} - _ builtinFunc = &builtinDateDiffSig{} - _ builtinFunc = &builtinNullTimeDiffSig{} - _ builtinFunc = &builtinTimeStringTimeDiffSig{} - _ builtinFunc = &builtinDurationStringTimeDiffSig{} - _ builtinFunc = &builtinDurationDurationTimeDiffSig{} - _ builtinFunc = &builtinStringTimeTimeDiffSig{} - _ builtinFunc = &builtinStringDurationTimeDiffSig{} - _ builtinFunc = &builtinStringStringTimeDiffSig{} - _ builtinFunc = &builtinTimeTimeTimeDiffSig{} - _ builtinFunc = &builtinDateFormatSig{} - _ builtinFunc = &builtinHourSig{} - _ builtinFunc = &builtinMinuteSig{} - _ builtinFunc = &builtinSecondSig{} - _ builtinFunc = &builtinMicroSecondSig{} - _ builtinFunc = &builtinMonthSig{} - _ builtinFunc = &builtinMonthNameSig{} - _ builtinFunc = &builtinNowWithArgSig{} - _ builtinFunc = &builtinNowWithoutArgSig{} - _ builtinFunc = 
&builtinDayNameSig{} - _ builtinFunc = &builtinDayOfMonthSig{} - _ builtinFunc = &builtinDayOfWeekSig{} - _ builtinFunc = &builtinDayOfYearSig{} - _ builtinFunc = &builtinWeekWithModeSig{} - _ builtinFunc = &builtinWeekWithoutModeSig{} - _ builtinFunc = &builtinWeekDaySig{} - _ builtinFunc = &builtinWeekOfYearSig{} - _ builtinFunc = &builtinYearSig{} - _ builtinFunc = &builtinYearWeekWithModeSig{} - _ builtinFunc = &builtinYearWeekWithoutModeSig{} - _ builtinFunc = &builtinGetFormatSig{} - _ builtinFunc = &builtinSysDateWithFspSig{} - _ builtinFunc = &builtinSysDateWithoutFspSig{} - _ builtinFunc = &builtinCurrentDateSig{} - _ builtinFunc = &builtinCurrentTime0ArgSig{} - _ builtinFunc = &builtinCurrentTime1ArgSig{} - _ builtinFunc = &builtinTimeSig{} - _ builtinFunc = &builtinTimeLiteralSig{} - _ builtinFunc = &builtinUTCDateSig{} - _ builtinFunc = &builtinUTCTimestampWithArgSig{} - _ builtinFunc = &builtinUTCTimestampWithoutArgSig{} - _ builtinFunc = &builtinAddDatetimeAndDurationSig{} - _ builtinFunc = &builtinAddDatetimeAndStringSig{} - _ builtinFunc = &builtinAddTimeDateTimeNullSig{} - _ builtinFunc = &builtinAddStringAndDurationSig{} - _ builtinFunc = &builtinAddStringAndStringSig{} - _ builtinFunc = &builtinAddTimeStringNullSig{} - _ builtinFunc = &builtinAddDurationAndDurationSig{} - _ builtinFunc = &builtinAddDurationAndStringSig{} - _ builtinFunc = &builtinAddTimeDurationNullSig{} - _ builtinFunc = &builtinAddDateAndDurationSig{} - _ builtinFunc = &builtinAddDateAndStringSig{} - _ builtinFunc = &builtinSubDatetimeAndDurationSig{} - _ builtinFunc = &builtinSubDatetimeAndStringSig{} - _ builtinFunc = &builtinSubTimeDateTimeNullSig{} - _ builtinFunc = &builtinSubStringAndDurationSig{} - _ builtinFunc = &builtinSubStringAndStringSig{} - _ builtinFunc = &builtinSubTimeStringNullSig{} - _ builtinFunc = &builtinSubDurationAndDurationSig{} - _ builtinFunc = &builtinSubDurationAndStringSig{} - _ builtinFunc = &builtinSubTimeDurationNullSig{} - _ builtinFunc = &builtinSubDateAndDurationSig{} - _ builtinFunc = &builtinSubDateAndStringSig{} - _ builtinFunc = &builtinUnixTimestampCurrentSig{} - _ builtinFunc = &builtinUnixTimestampIntSig{} - _ builtinFunc = &builtinUnixTimestampDecSig{} - _ builtinFunc = &builtinConvertTzSig{} - _ builtinFunc = &builtinMakeDateSig{} - _ builtinFunc = &builtinMakeTimeSig{} - _ builtinFunc = &builtinPeriodAddSig{} - _ builtinFunc = &builtinPeriodDiffSig{} - _ builtinFunc = &builtinQuarterSig{} - _ builtinFunc = &builtinSecToTimeSig{} - _ builtinFunc = &builtinTimeToSecSig{} - _ builtinFunc = &builtinTimestampAddSig{} - _ builtinFunc = &builtinToDaysSig{} - _ builtinFunc = &builtinToSecondsSig{} - _ builtinFunc = &builtinUTCTimeWithArgSig{} - _ builtinFunc = &builtinUTCTimeWithoutArgSig{} - _ builtinFunc = &builtinTimestamp1ArgSig{} - _ builtinFunc = &builtinTimestamp2ArgsSig{} - _ builtinFunc = &builtinTimestampLiteralSig{} - _ builtinFunc = &builtinLastDaySig{} - _ builtinFunc = &builtinStrToDateDateSig{} - _ builtinFunc = &builtinStrToDateDatetimeSig{} - _ builtinFunc = &builtinStrToDateDurationSig{} - _ builtinFunc = &builtinFromUnixTime1ArgSig{} - _ builtinFunc = &builtinFromUnixTime2ArgSig{} - _ builtinFunc = &builtinExtractDatetimeFromStringSig{} - _ builtinFunc = &builtinExtractDatetimeSig{} - _ builtinFunc = &builtinExtractDurationSig{} - _ builtinFunc = &builtinAddSubDateAsStringSig{} - _ builtinFunc = &builtinAddSubDateDatetimeAnySig{} - _ builtinFunc = &builtinAddSubDateDurationAnySig{} -) - -func convertTimeToMysqlTime(t time.Time, fsp int, roundMode 
types.RoundMode) (types.Time, error) { - var tr time.Time - var err error - if roundMode == types.ModeTruncate { - tr, err = types.TruncateFrac(t, fsp) - } else { - tr, err = types.RoundFrac(t, fsp) - } - if err != nil { - return types.ZeroTime, err - } - - return types.NewTime(types.FromGoTime(tr), mysql.TypeDatetime, fsp), nil -} - -type dateFunctionClass struct { - baseFunctionClass -} - -func (c *dateFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, types.ETDatetime) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForDate() - sig := &builtinDateSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_Date) - return sig, nil -} - -type builtinDateSig struct { - baseBuiltinFunc -} - -func (b *builtinDateSig) Clone() builtinFunc { - newSig := &builtinDateSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals DATE(expr). -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date -func (b *builtinDateSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - expr, isNull, err := b.args[0].EvalTime(ctx, row) - if isNull || err != nil { - return types.ZeroTime, true, handleInvalidTimeError(ctx, err) - } - - if expr.IsZero() && sqlMode(ctx).HasNoZeroDateMode() { - return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, expr.String())) - } - - expr.SetCoreTime(types.FromDate(expr.Year(), expr.Month(), expr.Day(), 0, 0, 0, 0)) - expr.SetType(mysql.TypeDate) - return expr, false, nil -} - -type dateLiteralFunctionClass struct { - baseFunctionClass -} - -func (c *dateLiteralFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - con, ok := args[0].(*Constant) - if !ok { - panic("Unexpected parameter for date literal") - } - dt, err := con.Eval(ctx.GetEvalCtx(), chunk.Row{}) - if err != nil { - return nil, err - } - str := dt.GetString() - if !datePattern.MatchString(str) { - return nil, types.ErrWrongValue.GenWithStackByArgs(types.DateStr, str) - } - tm, err := types.ParseDate(ctx.GetEvalCtx().TypeCtx(), str) - if err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, []Expression{}, types.ETDatetime) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForDate() - sig := &builtinDateLiteralSig{bf, tm} - return sig, nil -} - -type builtinDateLiteralSig struct { - baseBuiltinFunc - literal types.Time -} - -func (b *builtinDateLiteralSig) Clone() builtinFunc { - newSig := &builtinDateLiteralSig{literal: b.literal} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals DATE 'stringLit'. 
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-literals.html -func (b *builtinDateLiteralSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - mode := sqlMode(ctx) - if mode.HasNoZeroDateMode() && b.literal.IsZero() { - return b.literal, true, types.ErrWrongValue.GenWithStackByArgs(types.DateStr, b.literal.String()) - } - if mode.HasNoZeroInDateMode() && (b.literal.InvalidZero() && !b.literal.IsZero()) { - return b.literal, true, types.ErrWrongValue.GenWithStackByArgs(types.DateStr, b.literal.String()) - } - return b.literal, false, nil -} - -type dateDiffFunctionClass struct { - baseFunctionClass -} - -func (c *dateDiffFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDatetime, types.ETDatetime) - if err != nil { - return nil, err - } - sig := &builtinDateDiffSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_DateDiff) - return sig, nil -} - -type builtinDateDiffSig struct { - baseBuiltinFunc -} - -func (b *builtinDateDiffSig) Clone() builtinFunc { - newSig := &builtinDateDiffSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals a builtinDateDiffSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_datediff -func (b *builtinDateDiffSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - lhs, isNull, err := b.args[0].EvalTime(ctx, row) - if isNull || err != nil { - return 0, true, handleInvalidTimeError(ctx, err) - } - rhs, isNull, err := b.args[1].EvalTime(ctx, row) - if isNull || err != nil { - return 0, true, handleInvalidTimeError(ctx, err) - } - if invalidLHS, invalidRHS := lhs.InvalidZero(), rhs.InvalidZero(); invalidLHS || invalidRHS { - if invalidLHS { - err = handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, lhs.String())) - } - if invalidRHS { - err = handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, rhs.String())) - } - return 0, true, err - } - return int64(types.DateDiff(lhs.CoreTime(), rhs.CoreTime())), false, nil -} - -type timeDiffFunctionClass struct { - baseFunctionClass -} - -func (c *timeDiffFunctionClass) getArgEvalTp(fieldTp *types.FieldType) types.EvalType { - argTp := types.ETString - switch tp := fieldTp.EvalType(); tp { - case types.ETDuration, types.ETDatetime, types.ETTimestamp: - argTp = tp - } - return argTp -} - -func (c *timeDiffFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - - arg0FieldTp, arg1FieldTp := args[0].GetType(ctx.GetEvalCtx()), args[1].GetType(ctx.GetEvalCtx()) - arg0Tp, arg1Tp := c.getArgEvalTp(arg0FieldTp), c.getArgEvalTp(arg1FieldTp) - arg0Dec, err := getExpressionFsp(ctx, args[0]) - if err != nil { - return nil, err - } - arg1Dec, err := getExpressionFsp(ctx, args[1]) - if err != nil { - return nil, err - } - fsp := mathutil.Max(arg0Dec, arg1Dec) - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDuration, arg0Tp, arg1Tp) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForTime(fsp) - - var sig builtinFunc - // arg0 and arg1 must be the same time type(compatible), or timediff will return NULL. 
- switch arg0Tp { - case types.ETDuration: - switch arg1Tp { - case types.ETDuration: - sig = &builtinDurationDurationTimeDiffSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_DurationDurationTimeDiff) - case types.ETDatetime, types.ETTimestamp: - sig = &builtinNullTimeDiffSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_NullTimeDiff) - default: - sig = &builtinDurationStringTimeDiffSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_DurationStringTimeDiff) - } - case types.ETDatetime, types.ETTimestamp: - switch arg1Tp { - case types.ETDuration: - sig = &builtinNullTimeDiffSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_NullTimeDiff) - case types.ETDatetime, types.ETTimestamp: - sig = &builtinTimeTimeTimeDiffSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_TimeTimeTimeDiff) - default: - sig = &builtinTimeStringTimeDiffSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_TimeStringTimeDiff) - } - default: - switch arg1Tp { - case types.ETDuration: - sig = &builtinStringDurationTimeDiffSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_StringDurationTimeDiff) - case types.ETDatetime, types.ETTimestamp: - sig = &builtinStringTimeTimeDiffSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_StringTimeTimeDiff) - default: - sig = &builtinStringStringTimeDiffSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_StringStringTimeDiff) - } - } - return sig, nil -} - -type builtinDurationDurationTimeDiffSig struct { - baseBuiltinFunc -} - -func (b *builtinDurationDurationTimeDiffSig) Clone() builtinFunc { - newSig := &builtinDurationDurationTimeDiffSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalDuration evals a builtinDurationDurationTimeDiffSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_timediff -func (b *builtinDurationDurationTimeDiffSig) evalDuration(ctx EvalContext, row chunk.Row) (d types.Duration, isNull bool, err error) { - lhs, isNull, err := b.args[0].EvalDuration(ctx, row) - if isNull || err != nil { - return d, isNull, err - } - - rhs, isNull, err := b.args[1].EvalDuration(ctx, row) - if isNull || err != nil { - return d, isNull, err - } - - d, isNull, err = calculateDurationTimeDiff(ctx, lhs, rhs) - return d, isNull, err -} - -type builtinTimeTimeTimeDiffSig struct { - baseBuiltinFunc -} - -func (b *builtinTimeTimeTimeDiffSig) Clone() builtinFunc { - newSig := &builtinTimeTimeTimeDiffSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalDuration evals a builtinTimeTimeTimeDiffSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_timediff -func (b *builtinTimeTimeTimeDiffSig) evalDuration(ctx EvalContext, row chunk.Row) (d types.Duration, isNull bool, err error) { - lhs, isNull, err := b.args[0].EvalTime(ctx, row) - if isNull || err != nil { - return d, isNull, err - } - - rhs, isNull, err := b.args[1].EvalTime(ctx, row) - if isNull || err != nil { - return d, isNull, err - } - - tc := typeCtx(ctx) - d, isNull, err = calculateTimeDiff(tc, lhs, rhs) - return d, isNull, err -} - -type builtinDurationStringTimeDiffSig struct { - baseBuiltinFunc -} - -func (b *builtinDurationStringTimeDiffSig) Clone() builtinFunc { - newSig := &builtinDurationStringTimeDiffSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalDuration evals a builtinDurationStringTimeDiffSig. 
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_timediff -func (b *builtinDurationStringTimeDiffSig) evalDuration(ctx EvalContext, row chunk.Row) (d types.Duration, isNull bool, err error) { - lhs, isNull, err := b.args[0].EvalDuration(ctx, row) - if isNull || err != nil { - return d, isNull, err - } - - rhsStr, isNull, err := b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return d, isNull, err - } - - tc := typeCtx(ctx) - rhs, _, isDuration, err := convertStringToDuration(tc, rhsStr, b.tp.GetDecimal()) - if err != nil || !isDuration { - return d, true, err - } - - d, isNull, err = calculateDurationTimeDiff(ctx, lhs, rhs) - return d, isNull, err -} - -type builtinStringDurationTimeDiffSig struct { - baseBuiltinFunc -} - -func (b *builtinStringDurationTimeDiffSig) Clone() builtinFunc { - newSig := &builtinStringDurationTimeDiffSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalDuration evals a builtinStringDurationTimeDiffSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_timediff -func (b *builtinStringDurationTimeDiffSig) evalDuration(ctx EvalContext, row chunk.Row) (d types.Duration, isNull bool, err error) { - lhsStr, isNull, err := b.args[0].EvalString(ctx, row) - if isNull || err != nil { - return d, isNull, err - } - - rhs, isNull, err := b.args[1].EvalDuration(ctx, row) - if isNull || err != nil { - return d, isNull, err - } - - tc := typeCtx(ctx) - lhs, _, isDuration, err := convertStringToDuration(tc, lhsStr, b.tp.GetDecimal()) - if err != nil || !isDuration { - return d, true, err - } - - d, isNull, err = calculateDurationTimeDiff(ctx, lhs, rhs) - return d, isNull, err -} - -// calculateTimeDiff calculates interval difference of two types.Time. -func calculateTimeDiff(tc types.Context, lhs, rhs types.Time) (d types.Duration, isNull bool, err error) { - d = lhs.Sub(tc, &rhs) - d.Duration, err = types.TruncateOverflowMySQLTime(d.Duration) - if types.ErrTruncatedWrongVal.Equal(err) { - err = tc.HandleTruncate(err) - } - return d, err != nil, err -} - -// calculateDurationTimeDiff calculates interval difference of two types.Duration. -func calculateDurationTimeDiff(ctx EvalContext, lhs, rhs types.Duration) (d types.Duration, isNull bool, err error) { - d, err = lhs.Sub(rhs) - if err != nil { - return d, true, err - } - - d.Duration, err = types.TruncateOverflowMySQLTime(d.Duration) - if types.ErrTruncatedWrongVal.Equal(err) { - tc := typeCtx(ctx) - err = tc.HandleTruncate(err) - } - return d, err != nil, err -} - -type builtinTimeStringTimeDiffSig struct { - baseBuiltinFunc -} - -func (b *builtinTimeStringTimeDiffSig) Clone() builtinFunc { - newSig := &builtinTimeStringTimeDiffSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalDuration evals a builtinTimeStringTimeDiffSig. 
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_timediff -func (b *builtinTimeStringTimeDiffSig) evalDuration(ctx EvalContext, row chunk.Row) (d types.Duration, isNull bool, err error) { - lhs, isNull, err := b.args[0].EvalTime(ctx, row) - if isNull || err != nil { - return d, isNull, err - } - - rhsStr, isNull, err := b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return d, isNull, err - } - - tc := typeCtx(ctx) - _, rhs, isDuration, err := convertStringToDuration(tc, rhsStr, b.tp.GetDecimal()) - if err != nil || isDuration { - return d, true, err - } - - d, isNull, err = calculateTimeDiff(tc, lhs, rhs) - return d, isNull, err -} - -type builtinStringTimeTimeDiffSig struct { - baseBuiltinFunc -} - -func (b *builtinStringTimeTimeDiffSig) Clone() builtinFunc { - newSig := &builtinStringTimeTimeDiffSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalDuration evals a builtinStringTimeTimeDiffSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_timediff -func (b *builtinStringTimeTimeDiffSig) evalDuration(ctx EvalContext, row chunk.Row) (d types.Duration, isNull bool, err error) { - lhsStr, isNull, err := b.args[0].EvalString(ctx, row) - if isNull || err != nil { - return d, isNull, err - } - - rhs, isNull, err := b.args[1].EvalTime(ctx, row) - if isNull || err != nil { - return d, isNull, err - } - - tc := typeCtx(ctx) - _, lhs, isDuration, err := convertStringToDuration(tc, lhsStr, b.tp.GetDecimal()) - if err != nil || isDuration { - return d, true, err - } - - d, isNull, err = calculateTimeDiff(tc, lhs, rhs) - return d, isNull, err -} - -type builtinStringStringTimeDiffSig struct { - baseBuiltinFunc -} - -func (b *builtinStringStringTimeDiffSig) Clone() builtinFunc { - newSig := &builtinStringStringTimeDiffSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalDuration evals a builtinStringStringTimeDiffSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_timediff -func (b *builtinStringStringTimeDiffSig) evalDuration(ctx EvalContext, row chunk.Row) (d types.Duration, isNull bool, err error) { - lhs, isNull, err := b.args[0].EvalString(ctx, row) - if isNull || err != nil { - return d, isNull, err - } - - rhs, isNull, err := b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return d, isNull, err - } - - tc := typeCtx(ctx) - fsp := b.tp.GetDecimal() - lhsDur, lhsTime, lhsIsDuration, err := convertStringToDuration(tc, lhs, fsp) - if err != nil { - return d, true, err - } - - rhsDur, rhsTime, rhsIsDuration, err := convertStringToDuration(tc, rhs, fsp) - if err != nil { - return d, true, err - } - - if lhsIsDuration != rhsIsDuration { - return d, true, nil - } - - if lhsIsDuration { - d, isNull, err = calculateDurationTimeDiff(ctx, lhsDur, rhsDur) - } else { - d, isNull, err = calculateTimeDiff(tc, lhsTime, rhsTime) - } - - return d, isNull, err -} - -type builtinNullTimeDiffSig struct { - baseBuiltinFunc -} - -func (b *builtinNullTimeDiffSig) Clone() builtinFunc { - newSig := &builtinNullTimeDiffSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalDuration evals a builtinNullTimeDiffSig. 
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_timediff -func (b *builtinNullTimeDiffSig) evalDuration(ctx EvalContext, row chunk.Row) (d types.Duration, isNull bool, err error) { - return d, true, nil -} - -// convertStringToDuration converts string to duration, it return types.Time because in some case -// it will converts string to datetime. -func convertStringToDuration(tc types.Context, str string, fsp int) (d types.Duration, t types.Time, - isDuration bool, err error) { - if n := strings.IndexByte(str, '.'); n >= 0 { - lenStrFsp := len(str[n+1:]) - if lenStrFsp <= types.MaxFsp { - fsp = mathutil.Max(lenStrFsp, fsp) - } - } - return types.StrToDuration(tc, str, fsp) -} - -type dateFormatFunctionClass struct { - baseFunctionClass -} - -func (c *dateFormatFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETDatetime, types.ETString) - if err != nil { - return nil, err - } - // worst case: formatMask=%r%r%r...%r, each %r takes 11 characters - bf.tp.SetFlen((args[1].GetType(ctx.GetEvalCtx()).GetFlen() + 1) / 2 * 11) - sig := &builtinDateFormatSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_DateFormatSig) - return sig, nil -} - -type builtinDateFormatSig struct { - baseBuiltinFunc -} - -func (b *builtinDateFormatSig) Clone() builtinFunc { - newSig := &builtinDateFormatSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalString evals a builtinDateFormatSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-format -func (b *builtinDateFormatSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { - t, isNull, err := b.args[0].EvalTime(ctx, row) - if isNull || err != nil { - return "", isNull, handleInvalidTimeError(ctx, err) - } - formatMask, isNull, err := b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - // MySQL compatibility, #11203 - // If format mask is 0 then return 0 without warnings - if formatMask == "0" { - return "0", false, nil - } - - if t.InvalidZero() { - // MySQL compatibility, #11203 - // 0 | 0.0 should be converted to null without warnings - n, err := t.ToNumber().ToInt() - isOriginalIntOrDecimalZero := err == nil && n == 0 - // Args like "0000-00-00", "0000-00-00 00:00:00" set Fsp to 6 - isOriginalStringZero := t.Fsp() > 0 - if isOriginalIntOrDecimalZero && !isOriginalStringZero { - return "", true, nil - } - return "", true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, t.String())) - } - - res, err := t.DateFormat(formatMask) - return res, isNull, err -} - -type fromDaysFunctionClass struct { - baseFunctionClass -} - -func (c *fromDaysFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, types.ETInt) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForDate() - sig := &builtinFromDaysSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_FromDays) - return sig, nil -} - -type builtinFromDaysSig struct { - baseBuiltinFunc -} - -func (b *builtinFromDaysSig) Clone() builtinFunc { - newSig := &builtinFromDaysSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals FROM_DAYS(N). 
-// See https://dev.mysql.com/doc/refman/8.0/en/date-and-time-functions.html#function_from-days -func (b *builtinFromDaysSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - n, isNull, err := b.args[0].EvalInt(ctx, row) - if isNull || err != nil { - return types.ZeroTime, true, err - } - ret := types.TimeFromDays(n) - // the maximum date value is 9999-12-31 in mysql 5.8. - if ret.Year() > 9999 { - return types.ZeroTime, true, nil - } - return ret, false, nil -} - -type hourFunctionClass struct { - baseFunctionClass -} - -func (c *hourFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDuration) - if err != nil { - return nil, err - } - bf.tp.SetFlen(3) - bf.tp.SetDecimal(0) - sig := &builtinHourSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_Hour) - return sig, nil -} - -type builtinHourSig struct { - baseBuiltinFunc -} - -func (b *builtinHourSig) Clone() builtinFunc { - newSig := &builtinHourSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals HOUR(time). -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_hour -func (b *builtinHourSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - dur, isNull, err := b.args[0].EvalDuration(ctx, row) - // ignore error and return NULL - if isNull || err != nil { - return 0, true, nil - } - return int64(dur.Hour()), false, nil -} - -type minuteFunctionClass struct { - baseFunctionClass -} - -func (c *minuteFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDuration) - if err != nil { - return nil, err - } - bf.tp.SetFlen(2) - bf.tp.SetDecimal(0) - sig := &builtinMinuteSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_Minute) - return sig, nil -} - -type builtinMinuteSig struct { - baseBuiltinFunc -} - -func (b *builtinMinuteSig) Clone() builtinFunc { - newSig := &builtinMinuteSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals MINUTE(time). -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_minute -func (b *builtinMinuteSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - dur, isNull, err := b.args[0].EvalDuration(ctx, row) - // ignore error and return NULL - if isNull || err != nil { - return 0, true, nil - } - return int64(dur.Minute()), false, nil -} - -type secondFunctionClass struct { - baseFunctionClass -} - -func (c *secondFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDuration) - if err != nil { - return nil, err - } - bf.tp.SetFlen(2) - bf.tp.SetDecimal(0) - sig := &builtinSecondSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_Second) - return sig, nil -} - -type builtinSecondSig struct { - baseBuiltinFunc -} - -func (b *builtinSecondSig) Clone() builtinFunc { - newSig := &builtinSecondSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals SECOND(time). 
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_second -func (b *builtinSecondSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - dur, isNull, err := b.args[0].EvalDuration(ctx, row) - // ignore error and return NULL - if isNull || err != nil { - return 0, true, nil - } - return int64(dur.Second()), false, nil -} - -type microSecondFunctionClass struct { - baseFunctionClass -} - -func (c *microSecondFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDuration) - if err != nil { - return nil, err - } - bf.tp.SetFlen(6) - bf.tp.SetDecimal(0) - sig := &builtinMicroSecondSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_MicroSecond) - return sig, nil -} - -type builtinMicroSecondSig struct { - baseBuiltinFunc -} - -func (b *builtinMicroSecondSig) Clone() builtinFunc { - newSig := &builtinMicroSecondSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals MICROSECOND(expr). -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_microsecond -func (b *builtinMicroSecondSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - dur, isNull, err := b.args[0].EvalDuration(ctx, row) - // ignore error and return NULL - if isNull || err != nil { - return 0, true, nil - } - return int64(dur.MicroSecond()), false, nil -} - -type monthFunctionClass struct { - baseFunctionClass -} - -func (c *monthFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDatetime) - if err != nil { - return nil, err - } - bf.tp.SetFlen(2) - bf.tp.SetDecimal(0) - sig := &builtinMonthSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_Month) - return sig, nil -} - -type builtinMonthSig struct { - baseBuiltinFunc -} - -func (b *builtinMonthSig) Clone() builtinFunc { - newSig := &builtinMonthSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals MONTH(date). 
-// see: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_month -func (b *builtinMonthSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - date, isNull, err := b.args[0].EvalTime(ctx, row) - - if isNull || err != nil { - return 0, true, handleInvalidTimeError(ctx, err) - } - - return int64(date.Month()), false, nil -} - -// monthNameFunctionClass see https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_monthname -type monthNameFunctionClass struct { - baseFunctionClass -} - -func (c *monthNameFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETDatetime) - if err != nil { - return nil, err - } - charset, collate := ctx.GetCharsetInfo() - bf.tp.SetCharset(charset) - bf.tp.SetCollate(collate) - bf.tp.SetFlen(10) - sig := &builtinMonthNameSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_MonthName) - return sig, nil -} - -type builtinMonthNameSig struct { - baseBuiltinFunc -} - -func (b *builtinMonthNameSig) Clone() builtinFunc { - newSig := &builtinMonthNameSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (b *builtinMonthNameSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { - arg, isNull, err := b.args[0].EvalTime(ctx, row) - if isNull || err != nil { - return "", true, handleInvalidTimeError(ctx, err) - } - mon := arg.Month() - if (arg.IsZero() && sqlMode(ctx).HasNoZeroDateMode()) || mon < 0 || mon > len(types.MonthNames) { - return "", true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, arg.String())) - } else if mon == 0 || arg.IsZero() { - return "", true, nil - } - return types.MonthNames[mon-1], false, nil -} - -type dayNameFunctionClass struct { - baseFunctionClass -} - -func (c *dayNameFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETDatetime) - if err != nil { - return nil, err - } - charset, collate := ctx.GetCharsetInfo() - bf.tp.SetCharset(charset) - bf.tp.SetCollate(collate) - bf.tp.SetFlen(10) - sig := &builtinDayNameSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_DayName) - return sig, nil -} - -type builtinDayNameSig struct { - baseBuiltinFunc -} - -func (b *builtinDayNameSig) Clone() builtinFunc { - newSig := &builtinDayNameSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (b *builtinDayNameSig) evalIndex(ctx EvalContext, row chunk.Row) (int64, bool, error) { - arg, isNull, err := b.args[0].EvalTime(ctx, row) - if isNull || err != nil { - return 0, isNull, err - } - if arg.InvalidZero() { - return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, arg.String())) - } - // Monday is 0, ... Sunday = 6 in MySQL - // but in go, Sunday is 0, ... Saturday is 6 - // w will do a conversion. - res := (int64(arg.Weekday()) + 6) % 7 - return res, false, nil -} - -// evalString evals a builtinDayNameSig. 
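Go's time.Weekday numbers Sunday as 0 through Saturday as 6, while the MySQL-side index used by DAYNAME()/WEEKDAY() runs Monday = 0 through Sunday = 6, hence the (weekday + 6) % 7 shift above. A quick runnable check of the mapping:

package main

import (
	"fmt"
	"time"
)

func main() {
	// Go: Sunday = 0 ... Saturday = 6.
	// MySQL WEEKDAY()/DAYNAME index: Monday = 0 ... Sunday = 6.
	for d := time.Sunday; d <= time.Saturday; d++ {
		mysqlIdx := (int(d) + 6) % 7
		fmt.Printf("%-9s go=%d mysql=%d\n", d, int(d), mysqlIdx)
	}
}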
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_dayname -func (b *builtinDayNameSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { - idx, isNull, err := b.evalIndex(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - return types.WeekdayNames[idx], false, nil -} - -func (b *builtinDayNameSig) evalReal(ctx EvalContext, row chunk.Row) (float64, bool, error) { - idx, isNull, err := b.evalIndex(ctx, row) - if isNull || err != nil { - return 0, isNull, err - } - return float64(idx), false, nil -} - -func (b *builtinDayNameSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - idx, isNull, err := b.evalIndex(ctx, row) - if isNull || err != nil { - return 0, isNull, err - } - return idx, false, nil -} - -type dayOfMonthFunctionClass struct { - baseFunctionClass -} - -func (c *dayOfMonthFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDatetime) - if err != nil { - return nil, err - } - bf.tp.SetFlen(2) - sig := &builtinDayOfMonthSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_DayOfMonth) - return sig, nil -} - -type builtinDayOfMonthSig struct { - baseBuiltinFunc -} - -func (b *builtinDayOfMonthSig) Clone() builtinFunc { - newSig := &builtinDayOfMonthSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals a builtinDayOfMonthSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_dayofmonth -func (b *builtinDayOfMonthSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - arg, isNull, err := b.args[0].EvalTime(ctx, row) - if isNull || err != nil { - return 0, true, handleInvalidTimeError(ctx, err) - } - return int64(arg.Day()), false, nil -} - -type dayOfWeekFunctionClass struct { - baseFunctionClass -} - -func (c *dayOfWeekFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDatetime) - if err != nil { - return nil, err - } - bf.tp.SetFlen(1) - sig := &builtinDayOfWeekSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_DayOfWeek) - return sig, nil -} - -type builtinDayOfWeekSig struct { - baseBuiltinFunc -} - -func (b *builtinDayOfWeekSig) Clone() builtinFunc { - newSig := &builtinDayOfWeekSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals a builtinDayOfWeekSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_dayofweek -func (b *builtinDayOfWeekSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - arg, isNull, err := b.args[0].EvalTime(ctx, row) - if isNull || err != nil { - return 0, true, handleInvalidTimeError(ctx, err) - } - if arg.InvalidZero() { - return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, arg.String())) - } - // 1 is Sunday, 2 is Monday, .... 
7 is Saturday - return int64(arg.Weekday() + 1), false, nil -} - -type dayOfYearFunctionClass struct { - baseFunctionClass -} - -func (c *dayOfYearFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDatetime) - if err != nil { - return nil, err - } - bf.tp.SetFlen(3) - sig := &builtinDayOfYearSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_DayOfYear) - return sig, nil -} - -type builtinDayOfYearSig struct { - baseBuiltinFunc -} - -func (b *builtinDayOfYearSig) Clone() builtinFunc { - newSig := &builtinDayOfYearSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals a builtinDayOfYearSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_dayofyear -func (b *builtinDayOfYearSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - arg, isNull, err := b.args[0].EvalTime(ctx, row) - if isNull || err != nil { - return 0, isNull, handleInvalidTimeError(ctx, err) - } - if arg.InvalidZero() { - return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, arg.String())) - } - - return int64(arg.YearDay()), false, nil -} - -type weekFunctionClass struct { - baseFunctionClass -} - -func (c *weekFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - - argTps := []types.EvalType{types.ETDatetime} - if len(args) == 2 { - argTps = append(argTps, types.ETInt) - } - - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, argTps...) - if err != nil { - return nil, err - } - bf.tp.SetFlen(2) - bf.tp.SetDecimal(0) - - var sig builtinFunc - if len(args) == 2 { - sig = &builtinWeekWithModeSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_WeekWithMode) - } else { - sig = &builtinWeekWithoutModeSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_WeekWithoutMode) - } - return sig, nil -} - -type builtinWeekWithModeSig struct { - baseBuiltinFunc -} - -func (b *builtinWeekWithModeSig) Clone() builtinFunc { - newSig := &builtinWeekWithModeSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals WEEK(date, mode). -// see: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_week -func (b *builtinWeekWithModeSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - date, isNull, err := b.args[0].EvalTime(ctx, row) - - if isNull || err != nil { - return 0, true, handleInvalidTimeError(ctx, err) - } - - if date.IsZero() || date.InvalidZero() { - return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, date.String())) - } - - mode, isNull, err := b.args[1].EvalInt(ctx, row) - if isNull || err != nil { - return 0, isNull, err - } - - week := date.Week(int(mode)) - return int64(week), false, nil -} - -type builtinWeekWithoutModeSig struct { - baseBuiltinFunc -} - -func (b *builtinWeekWithoutModeSig) Clone() builtinFunc { - newSig := &builtinWeekWithoutModeSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals WEEK(date). 
-// see: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_week -func (b *builtinWeekWithoutModeSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - date, isNull, err := b.args[0].EvalTime(ctx, row) - - if isNull || err != nil { - return 0, true, handleInvalidTimeError(ctx, err) - } - - if date.IsZero() || date.InvalidZero() { - return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, date.String())) - } - - mode := 0 - if modeStr := ctx.GetDefaultWeekFormatMode(); modeStr != "" { - mode, err = strconv.Atoi(modeStr) - if err != nil { - return 0, true, handleInvalidTimeError(ctx, types.ErrInvalidWeekModeFormat.GenWithStackByArgs(modeStr)) - } - } - - week := date.Week(mode) - return int64(week), false, nil -} - -type weekDayFunctionClass struct { - baseFunctionClass -} - -func (c *weekDayFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDatetime) - if err != nil { - return nil, err - } - bf.tp.SetFlen(1) - - sig := &builtinWeekDaySig{bf} - sig.setPbCode(tipb.ScalarFuncSig_WeekDay) - return sig, nil -} - -type builtinWeekDaySig struct { - baseBuiltinFunc -} - -func (b *builtinWeekDaySig) Clone() builtinFunc { - newSig := &builtinWeekDaySig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals WEEKDAY(date). -func (b *builtinWeekDaySig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - date, isNull, err := b.args[0].EvalTime(ctx, row) - if isNull || err != nil { - return 0, true, handleInvalidTimeError(ctx, err) - } - - if date.IsZero() || date.InvalidZero() { - return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, date.String())) - } - - return int64(date.Weekday()+6) % 7, false, nil -} - -type weekOfYearFunctionClass struct { - baseFunctionClass -} - -func (c *weekOfYearFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDatetime) - if err != nil { - return nil, err - } - bf.tp.SetFlen(2) - bf.tp.SetDecimal(0) - sig := &builtinWeekOfYearSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_WeekOfYear) - return sig, nil -} - -type builtinWeekOfYearSig struct { - baseBuiltinFunc -} - -func (b *builtinWeekOfYearSig) Clone() builtinFunc { - newSig := &builtinWeekOfYearSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals WEEKOFYEAR(date). 
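WEEKOFYEAR() above is implemented as WEEK(date, 3), and mode 3 is the ISO-8601 week number; Go's time package exposes the same numbering through ISOWeek(), which makes a quick cross-check easy (illustrative only):

package main

import (
	"fmt"
	"time"
)

func main() {
	// WEEK(..., 3) is the ISO-8601 week: Monday-first, week 1 contains at
	// least four days of the new year. Go's ISOWeek uses the same rule.
	d := time.Date(2021, time.January, 1, 0, 0, 0, 0, time.UTC)
	year, week := d.ISOWeek()
	fmt.Println(year, week) // 2020 53 — Jan 1, 2021 still belongs to ISO week 53 of 2020
}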
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_weekofyear -func (b *builtinWeekOfYearSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - date, isNull, err := b.args[0].EvalTime(ctx, row) - - if isNull || err != nil { - return 0, true, handleInvalidTimeError(ctx, err) - } - - if date.IsZero() || date.InvalidZero() { - return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, date.String())) - } - - week := date.Week(3) - return int64(week), false, nil -} - -type yearFunctionClass struct { - baseFunctionClass -} - -func (c *yearFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDatetime) - if err != nil { - return nil, err - } - bf.tp.SetFlen(4) - bf.tp.SetDecimal(0) - sig := &builtinYearSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_Year) - return sig, nil -} - -type builtinYearSig struct { - baseBuiltinFunc -} - -func (b *builtinYearSig) Clone() builtinFunc { - newSig := &builtinYearSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals YEAR(date). -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_year -func (b *builtinYearSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - date, isNull, err := b.args[0].EvalTime(ctx, row) - - if isNull || err != nil { - return 0, true, handleInvalidTimeError(ctx, err) - } - return int64(date.Year()), false, nil -} - -type yearWeekFunctionClass struct { - baseFunctionClass -} - -func (c *yearWeekFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - argTps := []types.EvalType{types.ETDatetime} - if len(args) == 2 { - argTps = append(argTps, types.ETInt) - } - - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, argTps...) - if err != nil { - return nil, err - } - - bf.tp.SetFlen(6) - bf.tp.SetDecimal(0) - - var sig builtinFunc - if len(args) == 2 { - sig = &builtinYearWeekWithModeSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_YearWeekWithMode) - } else { - sig = &builtinYearWeekWithoutModeSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_YearWeekWithoutMode) - } - return sig, nil -} - -type builtinYearWeekWithModeSig struct { - baseBuiltinFunc -} - -func (b *builtinYearWeekWithModeSig) Clone() builtinFunc { - newSig := &builtinYearWeekWithModeSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals YEARWEEK(date,mode). 
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_yearweek -func (b *builtinYearWeekWithModeSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - date, isNull, err := b.args[0].EvalTime(ctx, row) - if isNull || err != nil { - return 0, isNull, handleInvalidTimeError(ctx, err) - } - if date.IsZero() || date.InvalidZero() { - return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, date.String())) - } - - mode, isNull, err := b.args[1].EvalInt(ctx, row) - if err != nil { - return 0, true, err - } - if isNull { - mode = 0 - } - - year, week := date.YearWeek(int(mode)) - result := int64(week + year*100) - if result < 0 { - return int64(math.MaxUint32), false, nil - } - return result, false, nil -} - -type builtinYearWeekWithoutModeSig struct { - baseBuiltinFunc -} - -func (b *builtinYearWeekWithoutModeSig) Clone() builtinFunc { - newSig := &builtinYearWeekWithoutModeSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals YEARWEEK(date). -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_yearweek -func (b *builtinYearWeekWithoutModeSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - date, isNull, err := b.args[0].EvalTime(ctx, row) - if isNull || err != nil { - return 0, true, handleInvalidTimeError(ctx, err) - } - - if date.InvalidZero() { - return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, date.String())) - } - - year, week := date.YearWeek(0) - result := int64(week + year*100) - if result < 0 { - return int64(math.MaxUint32), false, nil - } - return result, false, nil -} - -type fromUnixTimeFunctionClass struct { - baseFunctionClass -} - -func (c *fromUnixTimeFunctionClass) getFunction(ctx BuildContext, args []Expression) (sig builtinFunc, err error) { - if err = c.verifyArgs(args); err != nil { - return nil, err - } - - retTp, argTps := types.ETDatetime, make([]types.EvalType, 0, len(args)) - argTps = append(argTps, types.ETDecimal) - if len(args) == 2 { - retTp = types.ETString - argTps = append(argTps, types.ETString) - } - - arg0Tp := args[0].GetType(ctx.GetEvalCtx()) - isArg0Str := arg0Tp.EvalType() == types.ETString - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, retTp, argTps...) - if err != nil { - return nil, err - } - - if fieldString(arg0Tp.GetType()) { - //Improve string cast Unix Time precision - x, ok := (bf.getArgs()[0]).(*ScalarFunction) - if ok { - //used to adjust FromUnixTime precision #Fixbug35184 - if x.FuncName.L == ast.Cast { - if x.RetType.GetDecimal() == 0 && (x.RetType.GetType() == mysql.TypeNewDecimal) { - x.RetType.SetDecimal(6) - fieldLen := mathutil.Min(x.RetType.GetFlen()+6, mysql.MaxDecimalWidth) - x.RetType.SetFlen(fieldLen) - } - } - } - } - - if len(args) > 1 { - sig = &builtinFromUnixTime2ArgSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_FromUnixTime2Arg) - return sig, nil - } - - // Calculate the time fsp. 
- fsp := types.MaxFsp - if !isArg0Str { - if arg0Tp.GetDecimal() != types.UnspecifiedLength { - fsp = mathutil.Min(bf.tp.GetDecimal(), arg0Tp.GetDecimal()) - } - } - bf.setDecimalAndFlenForDatetime(fsp) - - sig = &builtinFromUnixTime1ArgSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_FromUnixTime1Arg) - return sig, nil -} - -func evalFromUnixTime(ctx EvalContext, fsp int, unixTimeStamp *types.MyDecimal) (res types.Time, isNull bool, err error) { - // 0 <= unixTimeStamp <= 32536771199.999999 - if unixTimeStamp.IsNegative() { - return res, true, nil - } - integralPart, err := unixTimeStamp.ToInt() - if err != nil && !terror.ErrorEqual(err, types.ErrTruncated) && !terror.ErrorEqual(err, types.ErrOverflow) { - return res, true, err - } - // The max integralPart should not be larger than 32536771199. - // Refer to https://dev.mysql.com/doc/relnotes/mysql/8.0/en/news-8-0-28.html - if integralPart > 32536771199 { - return res, true, nil - } - // Split the integral part and fractional part of a decimal timestamp. - // e.g. for timestamp 12345.678, - // first get the integral part 12345, - // then (12345.678 - 12345) * (10^9) to get the decimal part and convert it to nanosecond precision. - integerDecimalTp := new(types.MyDecimal).FromInt(integralPart) - fracDecimalTp := new(types.MyDecimal) - err = types.DecimalSub(unixTimeStamp, integerDecimalTp, fracDecimalTp) - if err != nil { - return res, true, err - } - nano := new(types.MyDecimal).FromInt(int64(time.Second)) - x := new(types.MyDecimal) - err = types.DecimalMul(fracDecimalTp, nano, x) - if err != nil { - return res, true, err - } - fractionalPart, err := x.ToInt() // here fractionalPart is result multiplying the original fractional part by 10^9. - if err != nil && !terror.ErrorEqual(err, types.ErrTruncated) { - return res, true, err - } - if fsp < 0 { - fsp = types.MaxFsp - } - - tc := typeCtx(ctx) - tmp := time.Unix(integralPart, fractionalPart).In(tc.Location()) - t, err := convertTimeToMysqlTime(tmp, fsp, types.ModeHalfUp) - if err != nil { - return res, true, err - } - return t, false, nil -} - -// fieldString returns true if precision cannot be determined -func fieldString(fieldType byte) bool { - switch fieldType { - case mysql.TypeString, mysql.TypeVarchar, mysql.TypeTinyBlob, - mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeBlob: - return true - default: - return false - } -} - -type builtinFromUnixTime1ArgSig struct { - baseBuiltinFunc -} - -func (b *builtinFromUnixTime1ArgSig) Clone() builtinFunc { - newSig := &builtinFromUnixTime1ArgSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals a builtinFromUnixTime1ArgSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_from-unixtime -func (b *builtinFromUnixTime1ArgSig) evalTime(ctx EvalContext, row chunk.Row) (res types.Time, isNull bool, err error) { - unixTimeStamp, isNull, err := b.args[0].EvalDecimal(ctx, row) - if err != nil || isNull { - return res, isNull, err - } - return evalFromUnixTime(ctx, b.tp.GetDecimal(), unixTimeStamp) -} - -type builtinFromUnixTime2ArgSig struct { - baseBuiltinFunc -} - -func (b *builtinFromUnixTime2ArgSig) Clone() builtinFunc { - newSig := &builtinFromUnixTime2ArgSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalString evals a builtinFromUnixTime2ArgSig. 
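evalFromUnixTime above splits the decimal timestamp into an integral seconds part and a fractional part scaled to nanoseconds before handing both to time.Unix. The same split, sketched with plain float64 arithmetic instead of types.MyDecimal (good enough for illustration, not for the precision the real code needs):

package main

import (
	"fmt"
	"math"
	"time"
)

func main() {
	ts := 12345.678 // a decimal unix timestamp, as in the comment above

	sec := int64(math.Floor(ts))                          // integral part: 12345
	nsec := int64(math.Round((ts - float64(sec)) * 1e9))  // fraction scaled to nanoseconds

	t := time.Unix(sec, nsec).UTC()
	fmt.Println(t.Format("2006-01-02 15:04:05.000000")) // 1970-01-01 03:25:45.678000
}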
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_from-unixtime -func (b *builtinFromUnixTime2ArgSig) evalString(ctx EvalContext, row chunk.Row) (res string, isNull bool, err error) { - format, isNull, err := b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return "", true, err - } - unixTimeStamp, isNull, err := b.args[0].EvalDecimal(ctx, row) - if err != nil || isNull { - return "", isNull, err - } - t, isNull, err := evalFromUnixTime(ctx, b.tp.GetDecimal(), unixTimeStamp) - if isNull || err != nil { - return "", isNull, err - } - res, err = t.DateFormat(format) - return res, err != nil, err -} - -type getFormatFunctionClass struct { - baseFunctionClass -} - -func (c *getFormatFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETString, types.ETString) - if err != nil { - return nil, err - } - bf.tp.SetFlen(17) - sig := &builtinGetFormatSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_GetFormat) - return sig, nil -} - -type builtinGetFormatSig struct { - baseBuiltinFunc -} - -func (b *builtinGetFormatSig) Clone() builtinFunc { - newSig := &builtinGetFormatSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalString evals a builtinGetFormatSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_get-format -func (b *builtinGetFormatSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { - t, isNull, err := b.args[0].EvalString(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - l, isNull, err := b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - - res := b.getFormat(t, l) - return res, false, nil -} - -type strToDateFunctionClass struct { - baseFunctionClass -} - -func (c *strToDateFunctionClass) getRetTp(ctx BuildContext, arg Expression) (tp byte, fsp int) { - tp = mysql.TypeDatetime - if _, ok := arg.(*Constant); !ok { - return tp, types.MaxFsp - } - strArg := WrapWithCastAsString(ctx, arg) - format, isNull, err := strArg.EvalString(ctx.GetEvalCtx(), chunk.Row{}) - if err != nil || isNull { - return - } - - isDuration, isDate := types.GetFormatType(format) - if isDuration && !isDate { - tp = mysql.TypeDuration - } else if !isDuration && isDate { - tp = mysql.TypeDate - } - if strings.Contains(format, "%f") { - fsp = types.MaxFsp - } - return -} - -// getFunction see https://dev.mysql.com/doc/refman/5.5/en/date-and-time-functions.html#function_str-to-date -func (c *strToDateFunctionClass) getFunction(ctx BuildContext, args []Expression) (sig builtinFunc, err error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - retTp, fsp := c.getRetTp(ctx, args[1]) - switch retTp { - case mysql.TypeDate: - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, types.ETString, types.ETString) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForDate() - sig = &builtinStrToDateDateSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_StrToDateDate) - case mysql.TypeDatetime: - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, types.ETString, types.ETString) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForDatetime(fsp) - sig = &builtinStrToDateDatetimeSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_StrToDateDatetime) - case mysql.TypeDuration: - bf, err := 
newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDuration, types.ETString, types.ETString) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForTime(fsp) - sig = &builtinStrToDateDurationSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_StrToDateDuration) - } - return sig, nil -} - -type builtinStrToDateDateSig struct { - baseBuiltinFunc -} - -func (b *builtinStrToDateDateSig) Clone() builtinFunc { - newSig := &builtinStrToDateDateSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (b *builtinStrToDateDateSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - date, isNull, err := b.args[0].EvalString(ctx, row) - if isNull || err != nil { - return types.ZeroTime, isNull, err - } - format, isNull, err := b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return types.ZeroTime, isNull, err - } - var t types.Time - tc := typeCtx(ctx) - succ := t.StrToDate(tc, date, format) - if !succ { - return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, t.String())) - } - if sqlMode(ctx).HasNoZeroDateMode() && (t.Year() == 0 || t.Month() == 0 || t.Day() == 0) { - return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrWrongValueForType.GenWithStackByArgs(types.DateTimeStr, date, ast.StrToDate)) - } - t.SetType(mysql.TypeDate) - t.SetFsp(types.MinFsp) - return t, false, nil -} - -type builtinStrToDateDatetimeSig struct { - baseBuiltinFunc -} - -func (b *builtinStrToDateDatetimeSig) Clone() builtinFunc { - newSig := &builtinStrToDateDatetimeSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (b *builtinStrToDateDatetimeSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - date, isNull, err := b.args[0].EvalString(ctx, row) - if isNull || err != nil { - return types.ZeroTime, isNull, err - } - format, isNull, err := b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return types.ZeroTime, isNull, err - } - var t types.Time - tc := typeCtx(ctx) - succ := t.StrToDate(tc, date, format) - if !succ { - return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, t.String())) - } - if sqlMode(ctx).HasNoZeroDateMode() && (t.Year() == 0 || t.Month() == 0 || t.Day() == 0) { - return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, t.String())) - } - t.SetType(mysql.TypeDatetime) - t.SetFsp(b.tp.GetDecimal()) - return t, false, nil -} - -type builtinStrToDateDurationSig struct { - baseBuiltinFunc -} - -func (b *builtinStrToDateDurationSig) Clone() builtinFunc { - newSig := &builtinStrToDateDurationSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalDuration -// TODO: If the NO_ZERO_DATE or NO_ZERO_IN_DATE SQL mode is enabled, zero dates or part of dates are disallowed. -// In that case, STR_TO_DATE() returns NULL and generates a warning. 
-func (b *builtinStrToDateDurationSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { - date, isNull, err := b.args[0].EvalString(ctx, row) - if isNull || err != nil { - return types.Duration{}, isNull, err - } - format, isNull, err := b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return types.Duration{}, isNull, err - } - var t types.Time - tc := typeCtx(ctx) - succ := t.StrToDate(tc, date, format) - if !succ { - return types.Duration{}, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, t.String())) - } - t.SetFsp(b.tp.GetDecimal()) - dur, err := t.ConvertToDuration() - return dur, err != nil, err -} - -type sysDateFunctionClass struct { - baseFunctionClass -} - -func (c *sysDateFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - fsp, err := getFspByIntArg(ctx, args) - if err != nil { - return nil, err - } - var argTps = make([]types.EvalType, 0) - if len(args) == 1 { - argTps = append(argTps, types.ETInt) - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, argTps...) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForDatetime(fsp) - // Illegal parameters have been filtered out in the parser, so the result is always not null. - bf.tp.SetFlag(bf.tp.GetFlag() | mysql.NotNullFlag) - - var sig builtinFunc - if len(args) == 1 { - sig = &builtinSysDateWithFspSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_SysDateWithFsp) - } else { - sig = &builtinSysDateWithoutFspSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_SysDateWithoutFsp) - } - return sig, nil -} - -type builtinSysDateWithFspSig struct { - baseBuiltinFunc -} - -func (b *builtinSysDateWithFspSig) Clone() builtinFunc { - newSig := &builtinSysDateWithFspSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals SYSDATE(fsp). -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_sysdate -func (b *builtinSysDateWithFspSig) evalTime(ctx EvalContext, row chunk.Row) (val types.Time, isNull bool, err error) { - fsp, isNull, err := b.args[0].EvalInt(ctx, row) - if isNull || err != nil { - return types.ZeroTime, isNull, err - } - - loc := location(ctx) - now := time.Now().In(loc) - result, err := convertTimeToMysqlTime(now, int(fsp), types.ModeHalfUp) - if err != nil { - return types.ZeroTime, true, err - } - return result, false, nil -} - -type builtinSysDateWithoutFspSig struct { - baseBuiltinFunc -} - -func (b *builtinSysDateWithoutFspSig) Clone() builtinFunc { - newSig := &builtinSysDateWithoutFspSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals SYSDATE(). 
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_sysdate -func (b *builtinSysDateWithoutFspSig) evalTime(ctx EvalContext, row chunk.Row) (val types.Time, isNull bool, err error) { - tz := location(ctx) - now := time.Now().In(tz) - result, err := convertTimeToMysqlTime(now, 0, types.ModeHalfUp) - if err != nil { - return types.ZeroTime, true, err - } - return result, false, nil -} - -type currentDateFunctionClass struct { - baseFunctionClass -} - -func (c *currentDateFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForDate() - sig := &builtinCurrentDateSig{bf} - return sig, nil -} - -type builtinCurrentDateSig struct { - baseBuiltinFunc -} - -func (b *builtinCurrentDateSig) Clone() builtinFunc { - newSig := &builtinCurrentDateSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals CURDATE(). -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_curdate -func (b *builtinCurrentDateSig) evalTime(ctx EvalContext, row chunk.Row) (val types.Time, isNull bool, err error) { - tz := location(ctx) - nowTs, err := getStmtTimestamp(ctx) - if err != nil { - return types.ZeroTime, true, err - } - year, month, day := nowTs.In(tz).Date() - result := types.NewTime(types.FromDate(year, int(month), day, 0, 0, 0, 0), mysql.TypeDate, 0) - return result, false, nil -} - -type currentTimeFunctionClass struct { - baseFunctionClass -} - -func (c *currentTimeFunctionClass) getFunction(ctx BuildContext, args []Expression) (sig builtinFunc, err error) { - if err = c.verifyArgs(args); err != nil { - return nil, err - } - - fsp, err := getFspByIntArg(ctx, args) - if err != nil { - return nil, err - } - var argTps = make([]types.EvalType, 0) - if len(args) == 1 { - argTps = append(argTps, types.ETInt) - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDuration, argTps...) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForTime(fsp) - // 1. no sign. - // 2. hour is in the 2-digit range. 
- bf.tp.SetFlen(bf.tp.GetFlen() - 2) - if len(args) == 0 { - sig = &builtinCurrentTime0ArgSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_CurrentTime0Arg) - return sig, nil - } - sig = &builtinCurrentTime1ArgSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_CurrentTime1Arg) - return sig, nil -} - -type builtinCurrentTime0ArgSig struct { - baseBuiltinFunc -} - -func (b *builtinCurrentTime0ArgSig) Clone() builtinFunc { - newSig := &builtinCurrentTime0ArgSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (b *builtinCurrentTime0ArgSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { - tz := location(ctx) - nowTs, err := getStmtTimestamp(ctx) - if err != nil { - return types.Duration{}, true, err - } - dur := nowTs.In(tz).Format(types.TimeFormat) - res, _, err := types.ParseDuration(typeCtx(ctx), dur, types.MinFsp) - if err != nil { - return types.Duration{}, true, err - } - return res, false, nil -} - -type builtinCurrentTime1ArgSig struct { - baseBuiltinFunc -} - -func (b *builtinCurrentTime1ArgSig) Clone() builtinFunc { - newSig := &builtinCurrentTime1ArgSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (b *builtinCurrentTime1ArgSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { - fsp, _, err := b.args[0].EvalInt(ctx, row) - if err != nil { - return types.Duration{}, true, err - } - tz := location(ctx) - nowTs, err := getStmtTimestamp(ctx) - if err != nil { - return types.Duration{}, true, err - } - dur := nowTs.In(tz).Format(types.TimeFSPFormat) - tc := typeCtx(ctx) - res, _, err := types.ParseDuration(tc, dur, int(fsp)) - if err != nil { - return types.Duration{}, true, err - } - return res, false, nil -} - -type timeFunctionClass struct { - baseFunctionClass -} - -func (c *timeFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - err := c.verifyArgs(args) - if err != nil { - return nil, err - } - fsp, err := getExpressionFsp(ctx, args[0]) - if err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDuration, types.ETString) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForTime(fsp) - sig := &builtinTimeSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_Time) - return sig, nil -} - -type builtinTimeSig struct { - baseBuiltinFunc -} - -func (b *builtinTimeSig) Clone() builtinFunc { - newSig := &builtinTimeSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalDuration evals a builtinTimeSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_time. 
-func (b *builtinTimeSig) evalDuration(ctx EvalContext, row chunk.Row) (res types.Duration, isNull bool, err error) { - expr, isNull, err := b.args[0].EvalString(ctx, row) - if isNull || err != nil { - return res, isNull, err - } - - fsp := 0 - if idx := strings.Index(expr, "."); idx != -1 { - fsp = len(expr) - idx - 1 - } - - var tmpFsp int - if tmpFsp, err = types.CheckFsp(fsp); err != nil { - return res, isNull, err - } - fsp = tmpFsp - - tc := typeCtx(ctx) - res, _, err = types.ParseDuration(tc, expr, fsp) - if types.ErrTruncatedWrongVal.Equal(err) { - err = tc.HandleTruncate(err) - } - return res, isNull, err -} - -type timeLiteralFunctionClass struct { - baseFunctionClass -} - -func (c *timeLiteralFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - con, ok := args[0].(*Constant) - if !ok { - panic("Unexpected parameter for time literal") - } - dt, err := con.Eval(ctx.GetEvalCtx(), chunk.Row{}) - if err != nil { - return nil, err - } - str := dt.GetString() - if !isDuration(str) { - return nil, types.ErrWrongValue.GenWithStackByArgs(types.TimeStr, str) - } - duration, _, err := types.ParseDuration(ctx.GetEvalCtx().TypeCtx(), str, types.GetFsp(str)) - if err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, []Expression{}, types.ETDuration) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForTime(duration.Fsp) - sig := &builtinTimeLiteralSig{bf, duration} - return sig, nil -} - -type builtinTimeLiteralSig struct { - baseBuiltinFunc - duration types.Duration -} - -func (b *builtinTimeLiteralSig) Clone() builtinFunc { - newSig := &builtinTimeLiteralSig{duration: b.duration} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalDuration evals TIME 'stringLit'. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-literals.html -func (b *builtinTimeLiteralSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { - return b.duration, false, nil -} - -type utcDateFunctionClass struct { - baseFunctionClass -} - -func (c *utcDateFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForDate() - sig := &builtinUTCDateSig{bf} - return sig, nil -} - -type builtinUTCDateSig struct { - baseBuiltinFunc -} - -func (b *builtinUTCDateSig) Clone() builtinFunc { - newSig := &builtinUTCDateSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals UTC_DATE, UTC_DATE(). 
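TIME(expr) above reads the fractional-second precision off the literal itself — the digit count after the '.' — before parsing the duration. A standalone sketch of that detection; fspOf is a hypothetical name, and the cap at 6 stands in for the validation types.CheckFsp performs:

package main

import (
	"fmt"
	"strings"
)

// fspOf counts the digits after the '.' in a time literal, capped at
// MySQL's maximum fractional-second precision of 6.
func fspOf(expr string) int {
	fsp := 0
	if idx := strings.Index(expr, "."); idx != -1 {
		fsp = len(expr) - idx - 1
	}
	if fsp > 6 {
		fsp = 6
	}
	return fsp
}

func main() {
	fmt.Println(fspOf("10:11:12"))         // 0
	fmt.Println(fspOf("10:11:12.3456"))    // 4
	fmt.Println(fspOf("10:11:12.1234567")) // 6 (clamped)
}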
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_utc-date -func (b *builtinUTCDateSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - nowTs, err := getStmtTimestamp(ctx) - if err != nil { - return types.ZeroTime, true, err - } - year, month, day := nowTs.UTC().Date() - result := types.NewTime(types.FromGoTime(time.Date(year, month, day, 0, 0, 0, 0, time.UTC)), mysql.TypeDate, types.UnspecifiedFsp) - return result, false, nil -} - -type utcTimestampFunctionClass struct { - baseFunctionClass -} - -func (c *utcTimestampFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - argTps := make([]types.EvalType, 0, 1) - if len(args) == 1 { - argTps = append(argTps, types.ETInt) - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, argTps...) - if err != nil { - return nil, err - } - - fsp, err := getFspByIntArg(ctx, args) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForDatetime(fsp) - var sig builtinFunc - if len(args) == 1 { - sig = &builtinUTCTimestampWithArgSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_UTCTimestampWithArg) - } else { - sig = &builtinUTCTimestampWithoutArgSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_UTCTimestampWithoutArg) - } - return sig, nil -} - -func evalUTCTimestampWithFsp(ctx EvalContext, fsp int) (types.Time, bool, error) { - nowTs, err := getStmtTimestamp(ctx) - if err != nil { - return types.ZeroTime, true, err - } - result, err := convertTimeToMysqlTime(nowTs.UTC(), fsp, types.ModeHalfUp) - if err != nil { - return types.ZeroTime, true, err - } - return result, false, nil -} - -type builtinUTCTimestampWithArgSig struct { - baseBuiltinFunc -} - -func (b *builtinUTCTimestampWithArgSig) Clone() builtinFunc { - newSig := &builtinUTCTimestampWithArgSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals UTC_TIMESTAMP(fsp). -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_utc-timestamp -func (b *builtinUTCTimestampWithArgSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - num, isNull, err := b.args[0].EvalInt(ctx, row) - if err != nil { - return types.ZeroTime, true, err - } - - if !isNull && num > int64(types.MaxFsp) { - return types.ZeroTime, true, errors.Errorf("Too-big precision %v specified for 'utc_timestamp'. Maximum is %v", num, types.MaxFsp) - } - if !isNull && num < int64(types.MinFsp) { - return types.ZeroTime, true, errors.Errorf("Invalid negative %d specified, must in [0, 6]", num) - } - - result, isNull, err := evalUTCTimestampWithFsp(ctx, int(num)) - return result, isNull, err -} - -type builtinUTCTimestampWithoutArgSig struct { - baseBuiltinFunc -} - -func (b *builtinUTCTimestampWithoutArgSig) Clone() builtinFunc { - newSig := &builtinUTCTimestampWithoutArgSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals UTC_TIMESTAMP(). 
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_utc-timestamp -func (b *builtinUTCTimestampWithoutArgSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - result, isNull, err := evalUTCTimestampWithFsp(ctx, 0) - return result, isNull, err -} - -type nowFunctionClass struct { - baseFunctionClass -} - -func (c *nowFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - argTps := make([]types.EvalType, 0, 1) - if len(args) == 1 { - argTps = append(argTps, types.ETInt) - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, argTps...) - if err != nil { - return nil, err - } - - fsp, err := getFspByIntArg(ctx, args) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForDatetime(fsp) - - var sig builtinFunc - if len(args) == 1 { - sig = &builtinNowWithArgSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_NowWithArg) - } else { - sig = &builtinNowWithoutArgSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_NowWithoutArg) - } - return sig, nil -} - -// GetStmtTimestamp directly calls getTimeZone with timezone -func GetStmtTimestamp(ctx EvalContext) (time.Time, error) { - tz := getTimeZone(ctx) - tVal, err := getStmtTimestamp(ctx) - if err != nil { - return tVal, err - } - return tVal.In(tz), nil -} - -func evalNowWithFsp(ctx EvalContext, fsp int) (types.Time, bool, error) { - nowTs, err := getStmtTimestamp(ctx) - if err != nil { - return types.ZeroTime, true, err - } - - failpoint.Inject("injectNow", func(val failpoint.Value) { - nowTs = time.Unix(int64(val.(int)), 0) - }) - - // In MySQL's implementation, now() will truncate the result instead of rounding it. - // Results below are from MySQL 5.7, which can prove it. - // mysql> select now(6), now(3), now(); - // +----------------------------+-------------------------+---------------------+ - // | now(6) | now(3) | now() | - // +----------------------------+-------------------------+---------------------+ - // | 2019-03-25 15:57:56.612966 | 2019-03-25 15:57:56.612 | 2019-03-25 15:57:56 | - // +----------------------------+-------------------------+---------------------+ - result, err := convertTimeToMysqlTime(nowTs, fsp, types.ModeTruncate) - if err != nil { - return types.ZeroTime, true, err - } - - err = result.ConvertTimeZone(time.Local, location(ctx)) - if err != nil { - return types.ZeroTime, true, err - } - - return result, false, nil -} - -type builtinNowWithArgSig struct { - baseBuiltinFunc -} - -func (b *builtinNowWithArgSig) Clone() builtinFunc { - newSig := &builtinNowWithArgSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals NOW(fsp) -// see: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_now -func (b *builtinNowWithArgSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - fsp, isNull, err := b.args[0].EvalInt(ctx, row) - - if err != nil { - return types.ZeroTime, true, err - } - - if isNull { - fsp = 0 - } else if fsp > int64(types.MaxFsp) { - return types.ZeroTime, true, errors.Errorf("Too-big precision %v specified for 'now'. 
Maximum is %v", fsp, types.MaxFsp) - } else if fsp < int64(types.MinFsp) { - return types.ZeroTime, true, errors.Errorf("Invalid negative %d specified, must in [0, 6]", fsp) - } - - result, isNull, err := evalNowWithFsp(ctx, int(fsp)) - return result, isNull, err -} - -type builtinNowWithoutArgSig struct { - baseBuiltinFunc -} - -func (b *builtinNowWithoutArgSig) Clone() builtinFunc { - newSig := &builtinNowWithoutArgSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals NOW() -// see: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_now -func (b *builtinNowWithoutArgSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - result, isNull, err := evalNowWithFsp(ctx, 0) - return result, isNull, err -} - -type extractFunctionClass struct { - baseFunctionClass -} - -func (c *extractFunctionClass) getFunction(ctx BuildContext, args []Expression) (sig builtinFunc, err error) { - if err = c.verifyArgs(args); err != nil { - return nil, err - } - - args[0] = WrapWithCastAsString(ctx, args[0]) - unit, _, err := args[0].EvalString(ctx.GetEvalCtx(), chunk.Row{}) - if err != nil { - return nil, err - } - isClockUnit := types.IsClockUnit(unit) - isDateUnit := types.IsDateUnit(unit) - var bf baseBuiltinFunc - if isClockUnit && isDateUnit { - // For unit DAY_MICROSECOND/DAY_SECOND/DAY_MINUTE/DAY_HOUR, the interpretation of the second argument depends on its evaluation type: - // 1. Datetime/timestamp are interpreted as datetime. For example: - // extract(day_second from datetime('2001-01-01 02:03:04')) = 120304 - // Note that MySQL 5.5+ has a bug of no day portion in the result (20304) for this case, see https://bugs.mysql.com/bug.php?id=73240. - // 2. Time is interpreted as is. For example: - // extract(day_second from time('02:03:04')) = 20304 - // Note that time shouldn't be implicitly cast to datetime, or else the date portion will be padded with the current date and this will adjust time portion accordingly. - // 3. Otherwise, string/int/float are interpreted as arbitrarily either datetime or time, depending on which fits. For example: - // extract(day_second from '2001-01-01 02:03:04') = 1020304 // datetime - // extract(day_second from 20010101020304) = 1020304 // datetime - // extract(day_second from '01 02:03:04') = 260304 // time - if args[1].GetType(ctx.GetEvalCtx()).EvalType() == types.ETDatetime || args[1].GetType(ctx.GetEvalCtx()).EvalType() == types.ETTimestamp { - bf, err = newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETString, types.ETDatetime) - if err != nil { - return nil, err - } - sig = &builtinExtractDatetimeSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_ExtractDatetime) - } else if args[1].GetType(ctx.GetEvalCtx()).EvalType() == types.ETDuration { - bf, err = newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETString, types.ETDuration) - if err != nil { - return nil, err - } - sig = &builtinExtractDurationSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_ExtractDuration) - } else { - bf, err = newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETString, types.ETString) - if err != nil { - return nil, err - } - bf.args[1].GetType(ctx.GetEvalCtx()).SetDecimal(int(types.MaxFsp)) - sig = &builtinExtractDatetimeFromStringSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_ExtractDatetimeFromString) - } - } else if isClockUnit { - // Clock units interpret the second argument as time. 
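evalNowWithFsp above deliberately truncates the fractional seconds (types.ModeTruncate) to match MySQL's NOW() behavior rather than rounding. time.Truncate shows the same effect on the sample value from the comment (illustrative, standard library only):

package main

import (
	"fmt"
	"time"
)

func main() {
	t := time.Date(2019, 3, 25, 15, 57, 56, 612966789, time.UTC)

	fmt.Println(t.Truncate(time.Microsecond).Format("2006-01-02 15:04:05.000000")) // ...56.612966
	fmt.Println(t.Truncate(time.Millisecond).Format("2006-01-02 15:04:05.000"))    // ...56.612
	fmt.Println(t.Truncate(time.Second).Format("2006-01-02 15:04:05"))             // ...56
}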
- bf, err = newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETString, types.ETDuration) - if err != nil { - return nil, err - } - sig = &builtinExtractDurationSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_ExtractDuration) - } else { - // Date units interpret the second argument as datetime. - bf, err = newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETString, types.ETDatetime) - if err != nil { - return nil, err - } - sig = &builtinExtractDatetimeSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_ExtractDatetime) - } - return sig, nil -} - -type builtinExtractDatetimeFromStringSig struct { - baseBuiltinFunc -} - -func (b *builtinExtractDatetimeFromStringSig) Clone() builtinFunc { - newSig := &builtinExtractDatetimeFromStringSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals a builtinExtractDatetimeFromStringSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_extract -func (b *builtinExtractDatetimeFromStringSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - unit, isNull, err := b.args[0].EvalString(ctx, row) - if isNull || err != nil { - return 0, isNull, err - } - dtStr, isNull, err := b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return 0, isNull, err - } - tc := typeCtx(ctx) - if types.IsClockUnit(unit) && types.IsDateUnit(unit) { - dur, _, err := types.ParseDuration(tc, dtStr, types.GetFsp(dtStr)) - if err != nil { - return 0, true, err - } - res, err := types.ExtractDurationNum(&dur, unit) - if err != nil { - return 0, true, err - } - dt, err := types.ParseDatetime(tc, dtStr) - if err != nil { - return res, false, nil - } - if dt.Hour() == dur.Hour() && dt.Minute() == dur.Minute() && dt.Second() == dur.Second() && dt.Year() > 0 { - res, err = types.ExtractDatetimeNum(&dt, unit) - } - return res, err != nil, err - } - - panic("Unexpected unit for extract") -} - -type builtinExtractDatetimeSig struct { - baseBuiltinFunc -} - -func (b *builtinExtractDatetimeSig) Clone() builtinFunc { - newSig := &builtinExtractDatetimeSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals a builtinExtractDatetimeSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_extract -func (b *builtinExtractDatetimeSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - unit, isNull, err := b.args[0].EvalString(ctx, row) - if isNull || err != nil { - return 0, isNull, err - } - dt, isNull, err := b.args[1].EvalTime(ctx, row) - if isNull || err != nil { - return 0, isNull, err - } - res, err := types.ExtractDatetimeNum(&dt, unit) - return res, err != nil, err -} - -type builtinExtractDurationSig struct { - baseBuiltinFunc -} - -func (b *builtinExtractDurationSig) Clone() builtinFunc { - newSig := &builtinExtractDurationSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals a builtinExtractDurationSig. 
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_extract -func (b *builtinExtractDurationSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - unit, isNull, err := b.args[0].EvalString(ctx, row) - if isNull || err != nil { - return 0, isNull, err - } - dur, isNull, err := b.args[1].EvalDuration(ctx, row) - if isNull || err != nil { - return 0, isNull, err - } - res, err := types.ExtractDurationNum(&dur, unit) - return res, err != nil, err -} - -// baseDateArithmetical is the base class for all "builtinAddDateXXXSig" and "builtinSubDateXXXSig", -// which provides parameter getter and date arithmetical calculate functions. -type baseDateArithmetical struct { - // intervalRegexp is "*Regexp" used to extract string interval for "DAY" unit. - intervalRegexp *regexp.Regexp -} - -func newDateArithmeticalUtil() baseDateArithmetical { - return baseDateArithmetical{ - intervalRegexp: regexp.MustCompile(`^[+-]?[\d]+`), - } -} - -func (du *baseDateArithmetical) getDateFromString(ctx EvalContext, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) { - dateStr, isNull, err := args[0].EvalString(ctx, row) - if isNull || err != nil { - return types.ZeroTime, true, err - } - - dateTp := mysql.TypeDate - if !types.IsDateFormat(dateStr) || types.IsClockUnit(unit) { - dateTp = mysql.TypeDatetime - } - - tc := typeCtx(ctx) - date, err := types.ParseTime(tc, dateStr, dateTp, types.MaxFsp) - if err != nil { - err = handleInvalidTimeError(ctx, err) - if err != nil { - return types.ZeroTime, true, err - } - return date, true, handleInvalidTimeError(ctx, err) - } else if sqlMode(ctx).HasNoZeroDateMode() && (date.Year() == 0 || date.Month() == 0 || date.Day() == 0) { - return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, dateStr)) - } - return date, false, handleInvalidTimeError(ctx, err) -} - -func (du *baseDateArithmetical) getDateFromInt(ctx EvalContext, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) { - dateInt, isNull, err := args[0].EvalInt(ctx, row) - if isNull || err != nil { - return types.ZeroTime, true, err - } - - tc := typeCtx(ctx) - date, err := types.ParseTimeFromInt64(tc, dateInt) - if err != nil { - return types.ZeroTime, true, handleInvalidTimeError(ctx, err) - } - - // The actual date.Type() might be date or datetime. - // When the unit contains clock, the date part is treated as datetime even though it might be actually a date. - if types.IsClockUnit(unit) { - date.SetType(mysql.TypeDatetime) - } - return date, false, nil -} - -func (du *baseDateArithmetical) getDateFromReal(ctx EvalContext, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) { - dateReal, isNull, err := args[0].EvalReal(ctx, row) - if isNull || err != nil { - return types.ZeroTime, true, err - } - - tc := typeCtx(ctx) - date, err := types.ParseTimeFromFloat64(tc, dateReal) - if err != nil { - return types.ZeroTime, true, handleInvalidTimeError(ctx, err) - } - - // The actual date.Type() might be date or datetime. - // When the unit contains clock, the date part is treated as datetime even though it might be actually a date. 
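The promotion described in the comment above — a bare DATE becomes a DATETIME as soon as the interval unit has a clock component — is easy to see with the standard library (illustrative, not TiDB code):

package main

import (
	"fmt"
	"time"
)

func main() {
	// Adding a clock-level interval to a bare DATE has to yield a DATETIME,
	// which is why the getDateFrom* helpers switch the type to
	// mysql.TypeDatetime when the unit contains a clock part.
	d := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)            // DATE '2021-01-01'
	fmt.Println(d.Add(time.Hour).Format("2006-01-02 15:04:05")) // 2021-01-01 01:00:00
}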
- if types.IsClockUnit(unit) { - date.SetType(mysql.TypeDatetime) - } - return date, false, nil -} - -func (du *baseDateArithmetical) getDateFromDecimal(ctx EvalContext, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) { - dateDec, isNull, err := args[0].EvalDecimal(ctx, row) - if isNull || err != nil { - return types.ZeroTime, true, err - } - - tc := typeCtx(ctx) - date, err := types.ParseTimeFromDecimal(tc, dateDec) - if err != nil { - return types.ZeroTime, true, handleInvalidTimeError(ctx, err) - } - - // The actual date.Type() might be date or datetime. - // When the unit contains clock, the date part is treated as datetime even though it might be actually a date. - if types.IsClockUnit(unit) { - date.SetType(mysql.TypeDatetime) - } - return date, false, nil -} - -func (du *baseDateArithmetical) getDateFromDatetime(ctx EvalContext, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) { - date, isNull, err := args[0].EvalTime(ctx, row) - if isNull || err != nil { - return types.ZeroTime, true, err - } - - // The actual date.Type() might be date, datetime or timestamp. - // Datetime is treated as is. - // Timestamp is treated as datetime, as MySQL manual says: https://dev.mysql.com/doc/refman/8.0/en/date-and-time-functions.html#function_date-add - // When the unit contains clock, the date part is treated as datetime even though it might be actually a date. - if types.IsClockUnit(unit) || date.Type() == mysql.TypeTimestamp { - date.SetType(mysql.TypeDatetime) - } - return date, false, nil -} - -func (du *baseDateArithmetical) getIntervalFromString(ctx EvalContext, args []Expression, row chunk.Row, unit string) (string, bool, error) { - interval, isNull, err := args[1].EvalString(ctx, row) - if isNull || err != nil { - return "", true, err - } - - ec := errCtx(ctx) - interval, err = du.intervalReformatString(ec, interval, unit) - return interval, false, err -} - -func (du *baseDateArithmetical) intervalReformatString(ec errctx.Context, str string, unit string) (interval string, err error) { - switch strings.ToUpper(unit) { - case "MICROSECOND", "MINUTE", "HOUR", "DAY", "WEEK", "MONTH", "QUARTER", "YEAR": - str = strings.TrimSpace(str) - // a single unit value has to be specially handled. 
- interval = du.intervalRegexp.FindString(str) - if interval == "" { - interval = "0" - } - - if interval != str { - err = ec.HandleError(types.ErrTruncatedWrongVal.GenWithStackByArgs("DECIMAL", str)) - } - case "SECOND": - // The unit SECOND is specially handled, for example: - // date + INTERVAL "1e2" SECOND = date + INTERVAL 100 second - // date + INTERVAL "1.6" SECOND = date + INTERVAL 1.6 second - // But: - // date + INTERVAL "1e2" MINUTE = date + INTERVAL 1 MINUTE - // date + INTERVAL "1.6" MINUTE = date + INTERVAL 1 MINUTE - var dec types.MyDecimal - if err = dec.FromString([]byte(str)); err != nil { - truncatedErr := types.ErrTruncatedWrongVal.GenWithStackByArgs("DECIMAL", str) - err = ec.HandleErrorWithAlias(err, truncatedErr, truncatedErr) - } - interval = string(dec.ToString()) - default: - interval = str - } - return interval, err -} - -func (du *baseDateArithmetical) intervalDecimalToString(ec errctx.Context, dec *types.MyDecimal) (string, error) { - var rounded types.MyDecimal - err := dec.Round(&rounded, 0, types.ModeHalfUp) - if err != nil { - return "", err - } - - intVal, err := rounded.ToInt() - if err != nil { - if err = ec.HandleError(types.ErrTruncatedWrongVal.GenWithStackByArgs("DECIMAL", dec.String())); err != nil { - return "", err - } - } - - return strconv.FormatInt(intVal, 10), nil -} - -func (du *baseDateArithmetical) getIntervalFromDecimal(ctx EvalContext, args []Expression, row chunk.Row, unit string) (string, bool, error) { - decimal, isNull, err := args[1].EvalDecimal(ctx, row) - if isNull || err != nil { - return "", true, err - } - interval := decimal.String() - - switch strings.ToUpper(unit) { - case "HOUR_MINUTE", "MINUTE_SECOND", "YEAR_MONTH", "DAY_HOUR", "DAY_MINUTE", - "DAY_SECOND", "DAY_MICROSECOND", "HOUR_MICROSECOND", "HOUR_SECOND", "MINUTE_MICROSECOND", "SECOND_MICROSECOND": - neg := false - if interval != "" && interval[0] == '-' { - neg = true - interval = interval[1:] - } - switch strings.ToUpper(unit) { - case "HOUR_MINUTE", "MINUTE_SECOND": - interval = strings.ReplaceAll(interval, ".", ":") - case "YEAR_MONTH": - interval = strings.ReplaceAll(interval, ".", "-") - case "DAY_HOUR": - interval = strings.ReplaceAll(interval, ".", " ") - case "DAY_MINUTE": - interval = "0 " + strings.ReplaceAll(interval, ".", ":") - case "DAY_SECOND": - interval = "0 00:" + strings.ReplaceAll(interval, ".", ":") - case "DAY_MICROSECOND": - interval = "0 00:00:" + interval - case "HOUR_MICROSECOND": - interval = "00:00:" + interval - case "HOUR_SECOND": - interval = "00:" + strings.ReplaceAll(interval, ".", ":") - case "MINUTE_MICROSECOND": - interval = "00:" + interval - case "SECOND_MICROSECOND": - /* keep interval as original decimal */ - } - if neg { - interval = "-" + interval - } - case "SECOND": - // interval is already like the %f format. 
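For readers of this hunk, a minimal standalone sketch (standard library only, helper name invented for illustration) of the compound-unit rewriting performed above: the decimal point of an interval such as 2.5 is reinterpreted as the separator between the two named fields, and missing leading fields are zero-padded before the string is handed to the duration parser.

package main

import (
	"fmt"
	"strings"
)

// reformatDecimalInterval mirrors a subset of the compound-unit cases above.
// Illustrative only; the real code also handles signs, the SECOND unit, and
// rounding for single-field units.
func reformatDecimalInterval(val, unit string) string {
	switch strings.ToUpper(unit) {
	case "HOUR_MINUTE", "MINUTE_SECOND":
		return strings.ReplaceAll(val, ".", ":")
	case "YEAR_MONTH":
		return strings.ReplaceAll(val, ".", "-")
	case "DAY_HOUR":
		return strings.ReplaceAll(val, ".", " ")
	case "DAY_MINUTE":
		return "0 " + strings.ReplaceAll(val, ".", ":")
	case "DAY_SECOND":
		return "0 00:" + strings.ReplaceAll(val, ".", ":")
	case "HOUR_MICROSECOND":
		return "00:00:" + val
	default:
		return val
	}
}

func main() {
	// For example, INTERVAL 2.5 HOUR_MINUTE is handled as "2:5" (2 hours, 5 minutes).
	for _, unit := range []string{"HOUR_MINUTE", "YEAR_MONTH", "DAY_MINUTE", "HOUR_MICROSECOND"} {
		fmt.Printf("%-17s -> %q\n", unit, reformatDecimalInterval("2.5", unit))
	}
}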
- default: - // YEAR, QUARTER, MONTH, WEEK, DAY, HOUR, MINUTE, MICROSECOND - ec := errCtx(ctx) - interval, err = du.intervalDecimalToString(ec, decimal) - if err != nil { - return "", true, err - } - } - - return interval, false, nil -} - -func (du *baseDateArithmetical) getIntervalFromInt(ctx EvalContext, args []Expression, row chunk.Row, unit string) (string, bool, error) { - interval, isNull, err := args[1].EvalInt(ctx, row) - if isNull || err != nil { - return "", true, err - } - - if mysql.HasUnsignedFlag(args[1].GetType(ctx).GetFlag()) { - return strconv.FormatUint(uint64(interval), 10), false, nil - } - - return strconv.FormatInt(interval, 10), false, nil -} - -func (du *baseDateArithmetical) getIntervalFromReal(ctx EvalContext, args []Expression, row chunk.Row, unit string) (string, bool, error) { - interval, isNull, err := args[1].EvalReal(ctx, row) - if isNull || err != nil { - return "", true, err - } - return strconv.FormatFloat(interval, 'f', args[1].GetType(ctx).GetDecimal(), 64), false, nil -} - -func (du *baseDateArithmetical) add(ctx EvalContext, date types.Time, interval, unit string, resultFsp int) (types.Time, bool, error) { - year, month, day, nano, _, err := types.ParseDurationValue(unit, interval) - if err := handleInvalidTimeError(ctx, err); err != nil { - return types.ZeroTime, true, err - } - return du.addDate(ctx, date, year, month, day, nano, resultFsp) -} - -func (du *baseDateArithmetical) addDate(ctx EvalContext, date types.Time, year, month, day, nano int64, resultFsp int) (types.Time, bool, error) { - goTime, err := date.GoTime(time.UTC) - if err := handleInvalidTimeError(ctx, err); err != nil { - return types.ZeroTime, true, err - } - - goTime = goTime.Add(time.Duration(nano)) - goTime, err = types.AddDate(year, month, day, goTime) - if err != nil { - return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime")) - } - - // Adjust fsp as required by outer - always respect type inference. - date.SetFsp(resultFsp) - - // fix https://github.com/pingcap/tidb/issues/11329 - if goTime.Year() == 0 { - hour, minute, second := goTime.Clock() - date.SetCoreTime(types.FromDate(0, 0, 0, hour, minute, second, goTime.Nanosecond()/1000)) - return date, false, nil - } - - date.SetCoreTime(types.FromGoTime(goTime)) - tc := typeCtx(ctx) - overflow, err := types.DateTimeIsOverflow(tc, date) - if err := handleInvalidTimeError(ctx, err); err != nil { - return types.ZeroTime, true, err - } - if overflow { - return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime")) - } - return date, false, nil -} - -type funcDurationOp func(d, interval types.Duration) (types.Duration, error) - -func (du *baseDateArithmetical) opDuration(ctx EvalContext, op funcDurationOp, d types.Duration, interval string, unit string, resultFsp int) (types.Duration, bool, error) { - dur, err := types.ExtractDurationValue(unit, interval) - if err != nil { - return types.ZeroDuration, true, handleInvalidTimeError(ctx, err) - } - retDur, err := op(d, dur) - if err != nil { - return types.ZeroDuration, true, err - } - // Adjust fsp as required by outer - always respect type inference. 
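A standard-library-only sketch of the add/addDate flow shown above, under the simplifying assumptions that Go's time.AddDate month-end normalization is acceptable (MySQL instead clamps to the last day of the month) and that overflow means a result year outside 1..9999; the function name is made up for illustration.

package main

import (
	"errors"
	"fmt"
	"time"
)

// addIntervalSketch follows the same shape as the addDate helper above: the
// sub-day part of the interval is applied first as a nanosecond offset, then
// the year/month/day part, and any result outside year 1..9999 is treated as
// the MySQL "datetime function overflow" case. Note that time.AddDate
// normalizes month ends (Jan 31 + 1 month = Mar 2/3), so this is only a
// rough stand-in for types.AddDate.
func addIntervalSketch(t time.Time, years, months, days int, nano time.Duration) (time.Time, error) {
	res := t.Add(nano).AddDate(years, months, days)
	if res.Year() < 1 || res.Year() > 9999 {
		return time.Time{}, errors.New("datetime function overflow")
	}
	return res, nil
}

func main() {
	base := time.Date(2000, 1, 15, 23, 30, 0, 0, time.UTC)
	if res, err := addIntervalSketch(base, 0, 1, 0, 90*time.Minute); err == nil {
		fmt.Println(res.Format("2006-01-02 15:04:05")) // 2000-02-16 01:00:00
	}
	if _, err := addIntervalSketch(base, 8000, 0, 0, 0); err != nil {
		fmt.Println("overflow:", err)
	}
}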
- retDur.Fsp = resultFsp - return retDur, false, nil -} - -func (du *baseDateArithmetical) addDuration(ctx EvalContext, d types.Duration, interval string, unit string, resultFsp int) (types.Duration, bool, error) { - add := func(d, interval types.Duration) (types.Duration, error) { - return d.Add(interval) - } - return du.opDuration(ctx, add, d, interval, unit, resultFsp) -} - -func (du *baseDateArithmetical) subDuration(ctx EvalContext, d types.Duration, interval string, unit string, resultFsp int) (types.Duration, bool, error) { - sub := func(d, interval types.Duration) (types.Duration, error) { - return d.Sub(interval) - } - return du.opDuration(ctx, sub, d, interval, unit, resultFsp) -} - -func (du *baseDateArithmetical) sub(ctx EvalContext, date types.Time, interval string, unit string, resultFsp int) (types.Time, bool, error) { - year, month, day, nano, _, err := types.ParseDurationValue(unit, interval) - if err := handleInvalidTimeError(ctx, err); err != nil { - return types.ZeroTime, true, err - } - return du.addDate(ctx, date, -year, -month, -day, -nano, resultFsp) -} - -func (du *baseDateArithmetical) vecGetDateFromInt(b *baseBuiltinFunc, ctx EvalContext, input *chunk.Chunk, unit string, result *chunk.Column) error { - n := input.NumRows() - buf, err := b.bufAllocator.get() - if err != nil { - return err - } - defer b.bufAllocator.put(buf) - if err := b.args[0].VecEvalInt(ctx, input, buf); err != nil { - return err - } - - result.ResizeTime(n, false) - result.MergeNulls(buf) - dates := result.Times() - i64s := buf.Int64s() - tc := typeCtx(ctx) - isClockUnit := types.IsClockUnit(unit) - for i := 0; i < n; i++ { - if result.IsNull(i) { - continue - } - - date, err := types.ParseTimeFromInt64(tc, i64s[i]) - if err != nil { - err = handleInvalidTimeError(ctx, err) - if err != nil { - return err - } - result.SetNull(i, true) - continue - } - - // The actual date.Type() might be date or datetime. - // When the unit contains clock, the date part is treated as datetime even though it might be actually a date. - if isClockUnit { - date.SetType(mysql.TypeDatetime) - } - dates[i] = date - } - return nil -} - -func (du *baseDateArithmetical) vecGetDateFromReal(b *baseBuiltinFunc, ctx EvalContext, input *chunk.Chunk, unit string, result *chunk.Column) error { - n := input.NumRows() - buf, err := b.bufAllocator.get() - if err != nil { - return err - } - defer b.bufAllocator.put(buf) - if err := b.args[0].VecEvalReal(ctx, input, buf); err != nil { - return err - } - - result.ResizeTime(n, false) - result.MergeNulls(buf) - dates := result.Times() - f64s := buf.Float64s() - tc := typeCtx(ctx) - isClockUnit := types.IsClockUnit(unit) - for i := 0; i < n; i++ { - if result.IsNull(i) { - continue - } - - date, err := types.ParseTimeFromFloat64(tc, f64s[i]) - if err != nil { - err = handleInvalidTimeError(ctx, err) - if err != nil { - return err - } - result.SetNull(i, true) - continue - } - - // The actual date.Type() might be date or datetime. - // When the unit contains clock, the date part is treated as datetime even though it might be actually a date. 
- if isClockUnit { - date.SetType(mysql.TypeDatetime) - } - dates[i] = date - } - return nil -} - -func (du *baseDateArithmetical) vecGetDateFromDecimal(b *baseBuiltinFunc, ctx EvalContext, input *chunk.Chunk, unit string, result *chunk.Column) error { - n := input.NumRows() - buf, err := b.bufAllocator.get() - if err != nil { - return err - } - defer b.bufAllocator.put(buf) - if err := b.args[0].VecEvalDecimal(ctx, input, buf); err != nil { - return err - } - - result.ResizeTime(n, false) - result.MergeNulls(buf) - dates := result.Times() - tc := typeCtx(ctx) - isClockUnit := types.IsClockUnit(unit) - for i := 0; i < n; i++ { - if result.IsNull(i) { - continue - } - - dec := buf.GetDecimal(i) - date, err := types.ParseTimeFromDecimal(tc, dec) - if err != nil { - err = handleInvalidTimeError(ctx, err) - if err != nil { - return err - } - result.SetNull(i, true) - continue - } - - // The actual date.Type() might be date or datetime. - // When the unit contains clock, the date part is treated as datetime even though it might be actually a date. - if isClockUnit { - date.SetType(mysql.TypeDatetime) - } - dates[i] = date - } - return nil -} - -func (du *baseDateArithmetical) vecGetDateFromString(b *baseBuiltinFunc, ctx EvalContext, input *chunk.Chunk, unit string, result *chunk.Column) error { - n := input.NumRows() - buf, err := b.bufAllocator.get() - if err != nil { - return err - } - defer b.bufAllocator.put(buf) - if err := b.args[0].VecEvalString(ctx, input, buf); err != nil { - return err - } - - result.ResizeTime(n, false) - result.MergeNulls(buf) - dates := result.Times() - tc := typeCtx(ctx) - isClockUnit := types.IsClockUnit(unit) - for i := 0; i < n; i++ { - if result.IsNull(i) { - continue - } - - dateStr := buf.GetString(i) - dateTp := mysql.TypeDate - if !types.IsDateFormat(dateStr) || isClockUnit { - dateTp = mysql.TypeDatetime - } - - date, err := types.ParseTime(tc, dateStr, dateTp, types.MaxFsp) - if err != nil { - err = handleInvalidTimeError(ctx, err) - if err != nil { - return err - } - result.SetNull(i, true) - } else if sqlMode(ctx).HasNoZeroDateMode() && (date.Year() == 0 || date.Month() == 0 || date.Day() == 0) { - err = handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, dateStr)) - if err != nil { - return err - } - result.SetNull(i, true) - } else { - dates[i] = date - } - } - return nil -} - -func (du *baseDateArithmetical) vecGetDateFromDatetime(b *baseBuiltinFunc, ctx EvalContext, input *chunk.Chunk, unit string, result *chunk.Column) error { - n := input.NumRows() - result.ResizeTime(n, false) - if err := b.args[0].VecEvalTime(ctx, input, result); err != nil { - return err - } - - dates := result.Times() - isClockUnit := types.IsClockUnit(unit) - for i := 0; i < n; i++ { - if result.IsNull(i) { - continue - } - - // The actual date[i].Type() might be date, datetime or timestamp. - // Datetime is treated as is. - // Timestamp is treated as datetime, as MySQL manual says: https://dev.mysql.com/doc/refman/8.0/en/date-and-time-functions.html#function_date-add - // When the unit contains clock, the date part is treated as datetime even though it might be actually a date. 
- if isClockUnit || dates[i].Type() == mysql.TypeTimestamp { - dates[i].SetType(mysql.TypeDatetime) - } - } - return nil -} - -func (du *baseDateArithmetical) vecGetIntervalFromString(b *baseBuiltinFunc, ctx EvalContext, input *chunk.Chunk, unit string, result *chunk.Column) error { - n := input.NumRows() - buf, err := b.bufAllocator.get() - if err != nil { - return err - } - defer b.bufAllocator.put(buf) - if err := b.args[1].VecEvalString(ctx, input, buf); err != nil { - return err - } - - ec := errCtx(ctx) - result.ReserveString(n) - for i := 0; i < n; i++ { - if buf.IsNull(i) { - result.AppendNull() - continue - } - - interval, err := du.intervalReformatString(ec, buf.GetString(i), unit) - if err != nil { - return err - } - result.AppendString(interval) - } - return nil -} - -func (du *baseDateArithmetical) vecGetIntervalFromDecimal(b *baseBuiltinFunc, ctx EvalContext, input *chunk.Chunk, unit string, result *chunk.Column) error { - n := input.NumRows() - buf, err := b.bufAllocator.get() - if err != nil { - return err - } - defer b.bufAllocator.put(buf) - if err := b.args[1].VecEvalDecimal(ctx, input, buf); err != nil { - return err - } - - isCompoundUnit := false - amendInterval := func(val string, row *chunk.Row) (string, bool, error) { - return val, false, nil - } - switch unitUpper := strings.ToUpper(unit); unitUpper { - case "HOUR_MINUTE", "MINUTE_SECOND", "YEAR_MONTH", "DAY_HOUR", "DAY_MINUTE", - "DAY_SECOND", "DAY_MICROSECOND", "HOUR_MICROSECOND", "HOUR_SECOND", "MINUTE_MICROSECOND", "SECOND_MICROSECOND": - isCompoundUnit = true - switch strings.ToUpper(unit) { - case "HOUR_MINUTE", "MINUTE_SECOND": - amendInterval = func(val string, _ *chunk.Row) (string, bool, error) { - return strings.ReplaceAll(val, ".", ":"), false, nil - } - case "YEAR_MONTH": - amendInterval = func(val string, _ *chunk.Row) (string, bool, error) { - return strings.ReplaceAll(val, ".", "-"), false, nil - } - case "DAY_HOUR": - amendInterval = func(val string, _ *chunk.Row) (string, bool, error) { - return strings.ReplaceAll(val, ".", " "), false, nil - } - case "DAY_MINUTE": - amendInterval = func(val string, _ *chunk.Row) (string, bool, error) { - return "0 " + strings.ReplaceAll(val, ".", ":"), false, nil - } - case "DAY_SECOND": - amendInterval = func(val string, _ *chunk.Row) (string, bool, error) { - return "0 00:" + strings.ReplaceAll(val, ".", ":"), false, nil - } - case "DAY_MICROSECOND": - amendInterval = func(val string, _ *chunk.Row) (string, bool, error) { - return "0 00:00:" + val, false, nil - } - case "HOUR_MICROSECOND": - amendInterval = func(val string, _ *chunk.Row) (string, bool, error) { - return "00:00:" + val, false, nil - } - case "HOUR_SECOND": - amendInterval = func(val string, _ *chunk.Row) (string, bool, error) { - return "00:" + strings.ReplaceAll(val, ".", ":"), false, nil - } - case "MINUTE_MICROSECOND": - amendInterval = func(val string, _ *chunk.Row) (string, bool, error) { - return "00:" + val, false, nil - } - case "SECOND_MICROSECOND": - /* keep interval as original decimal */ - } - case "SECOND": - /* keep interval as original decimal */ - default: - // YEAR, QUARTER, MONTH, WEEK, DAY, HOUR, MINUTE, MICROSECOND - amendInterval = func(_ string, row *chunk.Row) (string, bool, error) { - dec, isNull, err := b.args[1].EvalDecimal(ctx, *row) - if isNull || err != nil { - return "", true, err - } - - str, err := du.intervalDecimalToString(errCtx(ctx), dec) - if err != nil { - return "", true, err - } - - return str, false, nil - } - } - - result.ReserveString(n) - decs := 
buf.Decimals() - for i := 0; i < n; i++ { - if buf.IsNull(i) { - result.AppendNull() - continue - } - - interval := decs[i].String() - row := input.GetRow(i) - isNeg := false - if isCompoundUnit && interval != "" && interval[0] == '-' { - isNeg = true - interval = interval[1:] - } - interval, isNull, err := amendInterval(interval, &row) - if err != nil { - return err - } - if isNull { - result.AppendNull() - continue - } - if isCompoundUnit && isNeg { - interval = "-" + interval - } - result.AppendString(interval) - } - return nil -} - -func (du *baseDateArithmetical) vecGetIntervalFromInt(b *baseBuiltinFunc, ctx EvalContext, input *chunk.Chunk, unit string, result *chunk.Column) error { - n := input.NumRows() - buf, err := b.bufAllocator.get() - if err != nil { - return err - } - defer b.bufAllocator.put(buf) - if err := b.args[1].VecEvalInt(ctx, input, buf); err != nil { - return err - } - - result.ReserveString(n) - i64s := buf.Int64s() - unsigned := mysql.HasUnsignedFlag(b.args[1].GetType(ctx).GetFlag()) - for i := 0; i < n; i++ { - if buf.IsNull(i) { - result.AppendNull() - } else if unsigned { - result.AppendString(strconv.FormatUint(uint64(i64s[i]), 10)) - } else { - result.AppendString(strconv.FormatInt(i64s[i], 10)) - } - } - return nil -} - -func (du *baseDateArithmetical) vecGetIntervalFromReal(b *baseBuiltinFunc, ctx EvalContext, input *chunk.Chunk, unit string, result *chunk.Column) error { - n := input.NumRows() - buf, err := b.bufAllocator.get() - if err != nil { - return err - } - defer b.bufAllocator.put(buf) - if err := b.args[1].VecEvalReal(ctx, input, buf); err != nil { - return err - } - - result.ReserveString(n) - f64s := buf.Float64s() - prec := b.args[1].GetType(ctx).GetDecimal() - for i := 0; i < n; i++ { - if buf.IsNull(i) { - result.AppendNull() - } else { - result.AppendString(strconv.FormatFloat(f64s[i], 'f', prec, 64)) - } - } - return nil -} - -type funcTimeOpForDateAddSub func(da *baseDateArithmetical, ctx EvalContext, date types.Time, interval, unit string, resultFsp int) (types.Time, bool, error) - -func addTime(da *baseDateArithmetical, ctx EvalContext, date types.Time, interval, unit string, resultFsp int) (types.Time, bool, error) { - return da.add(ctx, date, interval, unit, resultFsp) -} - -func subTime(da *baseDateArithmetical, ctx EvalContext, date types.Time, interval, unit string, resultFsp int) (types.Time, bool, error) { - return da.sub(ctx, date, interval, unit, resultFsp) -} - -type funcDurationOpForDateAddSub func(da *baseDateArithmetical, ctx EvalContext, d types.Duration, interval, unit string, resultFsp int) (types.Duration, bool, error) - -func addDuration(da *baseDateArithmetical, ctx EvalContext, d types.Duration, interval, unit string, resultFsp int) (types.Duration, bool, error) { - return da.addDuration(ctx, d, interval, unit, resultFsp) -} - -func subDuration(da *baseDateArithmetical, ctx EvalContext, d types.Duration, interval, unit string, resultFsp int) (types.Duration, bool, error) { - return da.subDuration(ctx, d, interval, unit, resultFsp) -} - -type funcSetPbCodeOp func(b builtinFunc, add, sub tipb.ScalarFuncSig) - -func setAdd(b builtinFunc, add, sub tipb.ScalarFuncSig) { - b.setPbCode(add) -} - -func setSub(b builtinFunc, add, sub tipb.ScalarFuncSig) { - b.setPbCode(sub) -} - -type funcGetDateForDateAddSub func(da *baseDateArithmetical, ctx EvalContext, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) - -func getDateFromString(da *baseDateArithmetical, ctx EvalContext, args []Expression, row 
chunk.Row, unit string) (types.Time, bool, error) { - return da.getDateFromString(ctx, args, row, unit) -} - -func getDateFromInt(da *baseDateArithmetical, ctx EvalContext, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) { - return da.getDateFromInt(ctx, args, row, unit) -} - -func getDateFromReal(da *baseDateArithmetical, ctx EvalContext, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) { - return da.getDateFromReal(ctx, args, row, unit) -} - -func getDateFromDecimal(da *baseDateArithmetical, ctx EvalContext, args []Expression, row chunk.Row, unit string) (types.Time, bool, error) { - return da.getDateFromDecimal(ctx, args, row, unit) -} - -type funcVecGetDateForDateAddSub func(da *baseDateArithmetical, ctx EvalContext, b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error - -func vecGetDateFromString(da *baseDateArithmetical, ctx EvalContext, b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { - return da.vecGetDateFromString(b, ctx, input, unit, result) -} - -func vecGetDateFromInt(da *baseDateArithmetical, ctx EvalContext, b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { - return da.vecGetDateFromInt(b, ctx, input, unit, result) -} - -func vecGetDateFromReal(da *baseDateArithmetical, ctx EvalContext, b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { - return da.vecGetDateFromReal(b, ctx, input, unit, result) -} - -func vecGetDateFromDecimal(da *baseDateArithmetical, ctx EvalContext, b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { - return da.vecGetDateFromDecimal(b, ctx, input, unit, result) -} - -type funcGetIntervalForDateAddSub func(da *baseDateArithmetical, ctx EvalContext, args []Expression, row chunk.Row, unit string) (string, bool, error) - -func getIntervalFromString(da *baseDateArithmetical, ctx EvalContext, args []Expression, row chunk.Row, unit string) (string, bool, error) { - return da.getIntervalFromString(ctx, args, row, unit) -} - -func getIntervalFromInt(da *baseDateArithmetical, ctx EvalContext, args []Expression, row chunk.Row, unit string) (string, bool, error) { - return da.getIntervalFromInt(ctx, args, row, unit) -} - -func getIntervalFromReal(da *baseDateArithmetical, ctx EvalContext, args []Expression, row chunk.Row, unit string) (string, bool, error) { - return da.getIntervalFromReal(ctx, args, row, unit) -} - -func getIntervalFromDecimal(da *baseDateArithmetical, ctx EvalContext, args []Expression, row chunk.Row, unit string) (string, bool, error) { - return da.getIntervalFromDecimal(ctx, args, row, unit) -} - -type funcVecGetIntervalForDateAddSub func(da *baseDateArithmetical, ctx EvalContext, b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error - -func vecGetIntervalFromString(da *baseDateArithmetical, ctx EvalContext, b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { - return da.vecGetIntervalFromString(b, ctx, input, unit, result) -} - -func vecGetIntervalFromInt(da *baseDateArithmetical, ctx EvalContext, b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { - return da.vecGetIntervalFromInt(b, ctx, input, unit, result) -} - -func vecGetIntervalFromReal(da *baseDateArithmetical, ctx EvalContext, b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { - return da.vecGetIntervalFromReal(b, ctx, input, unit, result) -} - -func 
vecGetIntervalFromDecimal(da *baseDateArithmetical, ctx EvalContext, b *baseBuiltinFunc, input *chunk.Chunk, unit string, result *chunk.Column) error { - return da.vecGetIntervalFromDecimal(b, ctx, input, unit, result) -} - -type addSubDateFunctionClass struct { - baseFunctionClass - timeOp funcTimeOpForDateAddSub - durationOp funcDurationOpForDateAddSub - setPbCodeOp funcSetPbCodeOp -} - -func (c *addSubDateFunctionClass) getFunction(ctx BuildContext, args []Expression) (sig builtinFunc, err error) { - if err = c.verifyArgs(args); err != nil { - return nil, err - } - - dateEvalTp := args[0].GetType(ctx.GetEvalCtx()).EvalType() - // Some special evaluation type treatment. - // Note that it could be more elegant if we always evaluate datetime for int, real, decimal and string, by leveraging existing implicit casts. - // However, MySQL has a weird behavior for date_add(string, ...), whose result depends on the content of the first argument. - // E.g., date_add('2000-01-02 00:00:00', interval 1 day) evaluates to '2021-01-03 00:00:00' (which is normal), - // whereas date_add('2000-01-02', interval 1 day) evaluates to '2000-01-03' instead of '2021-01-03 00:00:00'. - // This requires a customized parsing of the content of the first argument, by recognizing if it is a pure date format or contains HMS part. - // So implicit casts are not viable here. - if dateEvalTp == types.ETTimestamp { - dateEvalTp = types.ETDatetime - } else if dateEvalTp == types.ETJson { - dateEvalTp = types.ETString - } - - intervalEvalTp := args[1].GetType(ctx.GetEvalCtx()).EvalType() - if intervalEvalTp == types.ETJson { - intervalEvalTp = types.ETString - } else if intervalEvalTp != types.ETString && intervalEvalTp != types.ETDecimal && intervalEvalTp != types.ETReal { - intervalEvalTp = types.ETInt - } - - unit, _, err := args[2].EvalString(ctx.GetEvalCtx(), chunk.Row{}) - if err != nil { - return nil, err - } - - resultTp := mysql.TypeVarString - resultEvalTp := types.ETString - if args[0].GetType(ctx.GetEvalCtx()).GetType() == mysql.TypeDate { - if !types.IsClockUnit(unit) { - // First arg is date and unit contains no HMS, return date. - resultTp = mysql.TypeDate - resultEvalTp = types.ETDatetime - } else { - // First arg is date and unit contains HMS, return datetime. - resultTp = mysql.TypeDatetime - resultEvalTp = types.ETDatetime - } - } else if dateEvalTp == types.ETDuration { - if types.IsDateUnit(unit) && unit != "DAY_MICROSECOND" { - // First arg is time and unit contains YMD (except DAY_MICROSECOND), return datetime. - resultTp = mysql.TypeDatetime - resultEvalTp = types.ETDatetime - } else { - // First arg is time and unit contains no YMD or is DAY_MICROSECOND, return time. - resultTp = mysql.TypeDuration - resultEvalTp = types.ETDuration - } - } else if dateEvalTp == types.ETDatetime { - // First arg is datetime or timestamp, return datetime. - resultTp = mysql.TypeDatetime - resultEvalTp = types.ETDatetime - } - - argTps := []types.EvalType{dateEvalTp, intervalEvalTp, types.ETString} - var bf baseBuiltinFunc - bf, err = newBaseBuiltinFuncWithTp(ctx, c.funcName, args, resultEvalTp, argTps...) 
- if err != nil { - return nil, err - } - bf.tp.SetType(resultTp) - - var resultFsp int - if types.IsMicrosecondUnit(unit) { - resultFsp = types.MaxFsp - } else { - intervalFsp := types.MinFsp - if unit == "SECOND" { - if intervalEvalTp == types.ETString || intervalEvalTp == types.ETReal { - intervalFsp = types.MaxFsp - } else { - intervalFsp = mathutil.Min(types.MaxFsp, args[1].GetType(ctx.GetEvalCtx()).GetDecimal()) - } - } - resultFsp = mathutil.Min(types.MaxFsp, mathutil.Max(args[0].GetType(ctx.GetEvalCtx()).GetDecimal(), intervalFsp)) - } - switch resultTp { - case mysql.TypeDate: - bf.setDecimalAndFlenForDate() - case mysql.TypeDuration: - bf.setDecimalAndFlenForTime(resultFsp) - case mysql.TypeDatetime: - bf.setDecimalAndFlenForDatetime(resultFsp) - case mysql.TypeVarString: - bf.tp.SetFlen(mysql.MaxDatetimeFullWidth) - bf.tp.SetDecimal(types.MinFsp) - } - - switch { - case dateEvalTp == types.ETString && intervalEvalTp == types.ETString: - sig = &builtinAddSubDateAsStringSig{ - baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getDate: getDateFromString, - vecGetDate: vecGetDateFromString, - getInterval: getIntervalFromString, - vecGetInterval: vecGetIntervalFromString, - timeOp: c.timeOp, - } - c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateStringString, tipb.ScalarFuncSig_SubDateStringString) - case dateEvalTp == types.ETString && intervalEvalTp == types.ETInt: - sig = &builtinAddSubDateAsStringSig{ - baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getDate: getDateFromString, - vecGetDate: vecGetDateFromString, - getInterval: getIntervalFromInt, - vecGetInterval: vecGetIntervalFromInt, - timeOp: c.timeOp, - } - c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateStringInt, tipb.ScalarFuncSig_SubDateStringInt) - case dateEvalTp == types.ETString && intervalEvalTp == types.ETReal: - sig = &builtinAddSubDateAsStringSig{ - baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getDate: getDateFromString, - vecGetDate: vecGetDateFromString, - getInterval: getIntervalFromReal, - vecGetInterval: vecGetIntervalFromReal, - timeOp: c.timeOp, - } - c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateStringReal, tipb.ScalarFuncSig_SubDateStringReal) - case dateEvalTp == types.ETString && intervalEvalTp == types.ETDecimal: - sig = &builtinAddSubDateAsStringSig{ - baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getDate: getDateFromString, - vecGetDate: vecGetDateFromString, - getInterval: getIntervalFromDecimal, - vecGetInterval: vecGetIntervalFromDecimal, - timeOp: c.timeOp, - } - c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateStringDecimal, tipb.ScalarFuncSig_SubDateStringDecimal) - case dateEvalTp == types.ETInt && intervalEvalTp == types.ETString: - sig = &builtinAddSubDateAsStringSig{ - baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getDate: getDateFromInt, - vecGetDate: vecGetDateFromInt, - getInterval: getIntervalFromString, - vecGetInterval: vecGetIntervalFromString, - timeOp: c.timeOp, - } - c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateIntString, tipb.ScalarFuncSig_SubDateIntString) - case dateEvalTp == types.ETInt && intervalEvalTp == types.ETInt: - sig = &builtinAddSubDateAsStringSig{ - baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getDate: getDateFromInt, - vecGetDate: vecGetDateFromInt, - getInterval: getIntervalFromInt, - vecGetInterval: vecGetIntervalFromInt, - timeOp: c.timeOp, - } - c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateIntInt, 
tipb.ScalarFuncSig_SubDateIntInt) - case dateEvalTp == types.ETInt && intervalEvalTp == types.ETReal: - sig = &builtinAddSubDateAsStringSig{ - baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getDate: getDateFromInt, - vecGetDate: vecGetDateFromInt, - getInterval: getIntervalFromReal, - vecGetInterval: vecGetIntervalFromReal, - timeOp: c.timeOp, - } - c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateIntReal, tipb.ScalarFuncSig_SubDateIntReal) - case dateEvalTp == types.ETInt && intervalEvalTp == types.ETDecimal: - sig = &builtinAddSubDateAsStringSig{ - baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getDate: getDateFromInt, - vecGetDate: vecGetDateFromInt, - getInterval: getIntervalFromDecimal, - vecGetInterval: vecGetIntervalFromDecimal, - timeOp: c.timeOp, - } - c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateIntDecimal, tipb.ScalarFuncSig_SubDateIntDecimal) - case dateEvalTp == types.ETReal && intervalEvalTp == types.ETString: - sig = &builtinAddSubDateAsStringSig{ - baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getDate: getDateFromReal, - vecGetDate: vecGetDateFromReal, - getInterval: getIntervalFromString, - vecGetInterval: vecGetIntervalFromString, - timeOp: c.timeOp, - } - c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateRealString, tipb.ScalarFuncSig_SubDateRealString) - case dateEvalTp == types.ETReal && intervalEvalTp == types.ETInt: - sig = &builtinAddSubDateAsStringSig{ - baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getDate: getDateFromReal, - vecGetDate: vecGetDateFromReal, - getInterval: getIntervalFromInt, - vecGetInterval: vecGetIntervalFromInt, - timeOp: c.timeOp, - } - c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateRealInt, tipb.ScalarFuncSig_SubDateRealInt) - case dateEvalTp == types.ETReal && intervalEvalTp == types.ETReal: - sig = &builtinAddSubDateAsStringSig{ - baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getDate: getDateFromReal, - vecGetDate: vecGetDateFromReal, - getInterval: getIntervalFromReal, - vecGetInterval: vecGetIntervalFromReal, - timeOp: c.timeOp, - } - c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateRealReal, tipb.ScalarFuncSig_SubDateRealReal) - case dateEvalTp == types.ETReal && intervalEvalTp == types.ETDecimal: - sig = &builtinAddSubDateAsStringSig{ - baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getDate: getDateFromReal, - vecGetDate: vecGetDateFromReal, - getInterval: getIntervalFromDecimal, - vecGetInterval: vecGetIntervalFromDecimal, - timeOp: c.timeOp, - } - c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateRealDecimal, tipb.ScalarFuncSig_SubDateRealDecimal) - case dateEvalTp == types.ETDecimal && intervalEvalTp == types.ETString: - sig = &builtinAddSubDateAsStringSig{ - baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getDate: getDateFromDecimal, - vecGetDate: vecGetDateFromDecimal, - getInterval: getIntervalFromString, - vecGetInterval: vecGetIntervalFromString, - timeOp: c.timeOp, - } - c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateDecimalString, tipb.ScalarFuncSig_SubDateDecimalString) - case dateEvalTp == types.ETDecimal && intervalEvalTp == types.ETInt: - sig = &builtinAddSubDateAsStringSig{ - baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getDate: getDateFromDecimal, - vecGetDate: vecGetDateFromDecimal, - getInterval: getIntervalFromInt, - vecGetInterval: vecGetIntervalFromInt, - timeOp: c.timeOp, - } - c.setPbCodeOp(sig, 
tipb.ScalarFuncSig_AddDateDecimalInt, tipb.ScalarFuncSig_SubDateDecimalInt) - case dateEvalTp == types.ETDecimal && intervalEvalTp == types.ETReal: - sig = &builtinAddSubDateAsStringSig{ - baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getDate: getDateFromDecimal, - vecGetDate: vecGetDateFromDecimal, - getInterval: getIntervalFromReal, - vecGetInterval: vecGetIntervalFromReal, - timeOp: c.timeOp, - } - c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateDecimalReal, tipb.ScalarFuncSig_SubDateDecimalReal) - case dateEvalTp == types.ETDecimal && intervalEvalTp == types.ETDecimal: - sig = &builtinAddSubDateAsStringSig{ - baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getDate: getDateFromDecimal, - vecGetDate: vecGetDateFromDecimal, - getInterval: getIntervalFromDecimal, - vecGetInterval: vecGetIntervalFromDecimal, - timeOp: c.timeOp, - } - c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateDecimalDecimal, tipb.ScalarFuncSig_SubDateDecimalDecimal) - case dateEvalTp == types.ETDatetime && intervalEvalTp == types.ETString: - sig = &builtinAddSubDateDatetimeAnySig{ - baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getInterval: getIntervalFromString, - vecGetInterval: vecGetIntervalFromString, - timeOp: c.timeOp, - } - c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateDatetimeString, tipb.ScalarFuncSig_SubDateDatetimeString) - case dateEvalTp == types.ETDatetime && intervalEvalTp == types.ETInt: - sig = &builtinAddSubDateDatetimeAnySig{ - baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getInterval: getIntervalFromInt, - vecGetInterval: vecGetIntervalFromInt, - timeOp: c.timeOp, - } - c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateDatetimeInt, tipb.ScalarFuncSig_SubDateDatetimeInt) - case dateEvalTp == types.ETDatetime && intervalEvalTp == types.ETReal: - sig = &builtinAddSubDateDatetimeAnySig{ - baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getInterval: getIntervalFromReal, - vecGetInterval: vecGetIntervalFromReal, - timeOp: c.timeOp, - } - c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateDatetimeReal, tipb.ScalarFuncSig_SubDateDatetimeReal) - case dateEvalTp == types.ETDatetime && intervalEvalTp == types.ETDecimal: - sig = &builtinAddSubDateDatetimeAnySig{ - baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getInterval: getIntervalFromDecimal, - vecGetInterval: vecGetIntervalFromDecimal, - timeOp: c.timeOp, - } - c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateDatetimeDecimal, tipb.ScalarFuncSig_SubDateDatetimeDecimal) - case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETString: - sig = &builtinAddSubDateDurationAnySig{ - baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getInterval: getIntervalFromString, - vecGetInterval: vecGetIntervalFromString, - timeOp: c.timeOp, - durationOp: c.durationOp, - } - c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateDurationString, tipb.ScalarFuncSig_SubDateDurationString) - case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETInt: - sig = &builtinAddSubDateDurationAnySig{ - baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getInterval: getIntervalFromInt, - vecGetInterval: vecGetIntervalFromInt, - timeOp: c.timeOp, - durationOp: c.durationOp, - } - c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateDurationInt, tipb.ScalarFuncSig_SubDateDurationInt) - case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETReal: - sig = &builtinAddSubDateDurationAnySig{ - 
baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getInterval: getIntervalFromReal, - vecGetInterval: vecGetIntervalFromReal, - timeOp: c.timeOp, - durationOp: c.durationOp, - } - c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateDurationReal, tipb.ScalarFuncSig_SubDateDurationReal) - case dateEvalTp == types.ETDuration && intervalEvalTp == types.ETDecimal: - sig = &builtinAddSubDateDurationAnySig{ - baseBuiltinFunc: bf, - baseDateArithmetical: newDateArithmeticalUtil(), - getInterval: getIntervalFromDecimal, - vecGetInterval: vecGetIntervalFromDecimal, - timeOp: c.timeOp, - durationOp: c.durationOp, - } - c.setPbCodeOp(sig, tipb.ScalarFuncSig_AddDateDurationDecimal, tipb.ScalarFuncSig_SubDateDurationDecimal) - } - return sig, nil -} - -type builtinAddSubDateAsStringSig struct { - baseBuiltinFunc - baseDateArithmetical - getDate funcGetDateForDateAddSub - vecGetDate funcVecGetDateForDateAddSub - getInterval funcGetIntervalForDateAddSub - vecGetInterval funcVecGetIntervalForDateAddSub - timeOp funcTimeOpForDateAddSub -} - -func (b *builtinAddSubDateAsStringSig) Clone() builtinFunc { - newSig := &builtinAddSubDateAsStringSig{ - baseDateArithmetical: b.baseDateArithmetical, - getDate: b.getDate, - vecGetDate: b.vecGetDate, - getInterval: b.getInterval, - vecGetInterval: b.vecGetInterval, - timeOp: b.timeOp, - } - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (b *builtinAddSubDateAsStringSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { - unit, isNull, err := b.args[2].EvalString(ctx, row) - if isNull || err != nil { - return types.ZeroTime.String(), true, err - } - - date, isNull, err := b.getDate(&b.baseDateArithmetical, ctx, b.args, row, unit) - if isNull || err != nil { - return types.ZeroTime.String(), true, err - } - if date.InvalidZero() { - return types.ZeroTime.String(), true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, date.String())) - } - - interval, isNull, err := b.getInterval(&b.baseDateArithmetical, ctx, b.args, row, unit) - if isNull || err != nil { - return types.ZeroTime.String(), true, err - } - - result, isNull, err := b.timeOp(&b.baseDateArithmetical, ctx, date, interval, unit, b.tp.GetDecimal()) - if result.Microsecond() == 0 { - result.SetFsp(types.MinFsp) - } else { - result.SetFsp(types.MaxFsp) - } - - return result.String(), isNull, err -} - -type builtinAddSubDateDatetimeAnySig struct { - baseBuiltinFunc - baseDateArithmetical - getInterval funcGetIntervalForDateAddSub - vecGetInterval funcVecGetIntervalForDateAddSub - timeOp funcTimeOpForDateAddSub -} - -func (b *builtinAddSubDateDatetimeAnySig) Clone() builtinFunc { - newSig := &builtinAddSubDateDatetimeAnySig{ - baseDateArithmetical: b.baseDateArithmetical, - getInterval: b.getInterval, - vecGetInterval: b.vecGetInterval, - timeOp: b.timeOp, - } - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (b *builtinAddSubDateDatetimeAnySig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - unit, isNull, err := b.args[2].EvalString(ctx, row) - if isNull || err != nil { - return types.ZeroTime, true, err - } - - date, isNull, err := b.getDateFromDatetime(ctx, b.args, row, unit) - if isNull || err != nil { - return types.ZeroTime, true, err - } - - interval, isNull, err := b.getInterval(&b.baseDateArithmetical, ctx, b.args, row, unit) - if isNull || err != nil { - return types.ZeroTime, true, err - } - - result, isNull, err := b.timeOp(&b.baseDateArithmetical, ctx, date, interval, unit, 
b.tp.GetDecimal()) - return result, isNull || err != nil, err -} - -type builtinAddSubDateDurationAnySig struct { - baseBuiltinFunc - baseDateArithmetical - getInterval funcGetIntervalForDateAddSub - vecGetInterval funcVecGetIntervalForDateAddSub - timeOp funcTimeOpForDateAddSub - durationOp funcDurationOpForDateAddSub -} - -func (b *builtinAddSubDateDurationAnySig) Clone() builtinFunc { - newSig := &builtinAddSubDateDurationAnySig{ - baseDateArithmetical: b.baseDateArithmetical, - getInterval: b.getInterval, - vecGetInterval: b.vecGetInterval, - timeOp: b.timeOp, - durationOp: b.durationOp, - } - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (b *builtinAddSubDateDurationAnySig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - unit, isNull, err := b.args[2].EvalString(ctx, row) - if isNull || err != nil { - return types.ZeroTime, true, err - } - - d, isNull, err := b.args[0].EvalDuration(ctx, row) - if isNull || err != nil { - return types.ZeroTime, true, err - } - - interval, isNull, err := b.getInterval(&b.baseDateArithmetical, ctx, b.args, row, unit) - if isNull || err != nil { - return types.ZeroTime, true, err - } - - tc := typeCtx(ctx) - t, err := d.ConvertToTime(tc, mysql.TypeDatetime) - if err != nil { - return types.ZeroTime, true, err - } - result, isNull, err := b.timeOp(&b.baseDateArithmetical, ctx, t, interval, unit, b.tp.GetDecimal()) - return result, isNull || err != nil, err -} - -func (b *builtinAddSubDateDurationAnySig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { - unit, isNull, err := b.args[2].EvalString(ctx, row) - if isNull || err != nil { - return types.ZeroDuration, true, err - } - - dur, isNull, err := b.args[0].EvalDuration(ctx, row) - if isNull || err != nil { - return types.ZeroDuration, true, err - } - - interval, isNull, err := b.getInterval(&b.baseDateArithmetical, ctx, b.args, row, unit) - if isNull || err != nil { - return types.ZeroDuration, true, err - } - - result, isNull, err := b.durationOp(&b.baseDateArithmetical, ctx, dur, interval, unit, b.tp.GetDecimal()) - return result, isNull || err != nil, err -} - -type timestampDiffFunctionClass struct { - baseFunctionClass -} - -func (c *timestampDiffFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETString, types.ETDatetime, types.ETDatetime) - if err != nil { - return nil, err - } - sig := &builtinTimestampDiffSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_TimestampDiff) - return sig, nil -} - -type builtinTimestampDiffSig struct { - baseBuiltinFunc -} - -func (b *builtinTimestampDiffSig) Clone() builtinFunc { - newSig := &builtinTimestampDiffSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals a builtinTimestampDiffSig. 
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_timestampdiff -func (b *builtinTimestampDiffSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - unit, isNull, err := b.args[0].EvalString(ctx, row) - if isNull || err != nil { - return 0, isNull, err - } - lhs, isNull, err := b.args[1].EvalTime(ctx, row) - if isNull || err != nil { - return 0, isNull, handleInvalidTimeError(ctx, err) - } - rhs, isNull, err := b.args[2].EvalTime(ctx, row) - if isNull || err != nil { - return 0, isNull, handleInvalidTimeError(ctx, err) - } - if invalidLHS, invalidRHS := lhs.InvalidZero(), rhs.InvalidZero(); invalidLHS || invalidRHS { - if invalidLHS { - err = handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, lhs.String())) - } - if invalidRHS { - err = handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, rhs.String())) - } - return 0, true, err - } - return types.TimestampDiff(unit, lhs, rhs), false, nil -} - -type unixTimestampFunctionClass struct { - baseFunctionClass -} - -func (c *unixTimestampFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - var ( - argTps []types.EvalType - retTp types.EvalType - retFLen, retDecimal int - ) - - if len(args) == 0 { - retTp, retDecimal = types.ETInt, 0 - } else { - argTps = []types.EvalType{types.ETDatetime} - argType := args[0].GetType(ctx.GetEvalCtx()) - argEvaltp := argType.EvalType() - if argEvaltp == types.ETString { - // Treat types.ETString as unspecified decimal. - retDecimal = types.UnspecifiedLength - if cnst, ok := args[0].(*Constant); ok { - tmpStr, _, err := cnst.EvalString(ctx.GetEvalCtx(), chunk.Row{}) - if err != nil { - return nil, err - } - retDecimal = 0 - if dotIdx := strings.LastIndex(tmpStr, "."); dotIdx >= 0 { - retDecimal = len(tmpStr) - dotIdx - 1 - } - } - } else { - retDecimal = argType.GetDecimal() - } - if retDecimal > 6 || retDecimal == types.UnspecifiedLength { - retDecimal = 6 - } - if retDecimal == 0 { - retTp = types.ETInt - } else { - retTp = types.ETDecimal - } - } - if retTp == types.ETInt { - retFLen = 11 - } else if retTp == types.ETDecimal { - retFLen = 12 + retDecimal - } else { - panic("Unexpected retTp") - } - - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, retTp, argTps...) - if err != nil { - return nil, err - } - bf.tp.SetFlenUnderLimit(retFLen) - bf.tp.SetDecimalUnderLimit(retDecimal) - - var sig builtinFunc - if len(args) == 0 { - sig = &builtinUnixTimestampCurrentSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_UnixTimestampCurrent) - } else if retTp == types.ETInt { - sig = &builtinUnixTimestampIntSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_UnixTimestampInt) - } else if retTp == types.ETDecimal { - sig = &builtinUnixTimestampDecSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_UnixTimestampDec) - } - return sig, nil -} - -// goTimeToMysqlUnixTimestamp converts go time into MySQL's Unix timestamp. -// MySQL's Unix timestamp ranges from '1970-01-01 00:00:01.000000' UTC to '3001-01-18 23:59:59.999999' UTC. Values out of range should be rewritten to 0. 
-// https://dev.mysql.com/doc/refman/8.0/en/date-and-time-functions.html#function_unix-timestamp -func goTimeToMysqlUnixTimestamp(t time.Time, decimal int) (*types.MyDecimal, error) { - microSeconds := t.UnixMicro() - // Prior to MySQL 8.0.28 (or any 32-bit platform), the valid range of argument values is the same as for the TIMESTAMP data type: - // '1970-01-01 00:00:01.000000' UTC to '2038-01-19 03:14:07.999999' UTC. - // After 8.0.28, the range has been extended to '1970-01-01 00:00:01.000000' UTC to '3001-01-18 23:59:59.999999' UTC - // The magic value of '3001-01-18 23:59:59.999999' comes from the maximum supported timestamp on windows. Though TiDB - // doesn't support windows, this value is used here to keep the compatibility with MySQL - if microSeconds < 1e6 || microSeconds > 32536771199999999 { - return new(types.MyDecimal), nil - } - dec := new(types.MyDecimal) - // Here we don't use float to prevent precision lose. - dec.FromUint(uint64(microSeconds)) - err := dec.Shift(-6) - if err != nil { - return nil, err - } - - // In MySQL's implementation, unix_timestamp() will truncate the result instead of rounding it. - // Results below are from MySQL 5.7, which can prove it. - // mysql> select unix_timestamp(), unix_timestamp(now(0)), now(0), unix_timestamp(now(3)), now(3), now(6); - // +------------------+------------------------+---------------------+------------------------+-------------------------+----------------------------+ - // | unix_timestamp() | unix_timestamp(now(0)) | now(0) | unix_timestamp(now(3)) | now(3) | now(6) | - // +------------------+------------------------+---------------------+------------------------+-------------------------+----------------------------+ - // | 1553503194 | 1553503194 | 2019-03-25 16:39:54 | 1553503194.992 | 2019-03-25 16:39:54.992 | 2019-03-25 16:39:54.992969 | - // +------------------+------------------------+---------------------+------------------------+-------------------------+----------------------------+ - err = dec.Round(dec, decimal, types.ModeTruncate) - return dec, err -} - -type builtinUnixTimestampCurrentSig struct { - baseBuiltinFunc -} - -func (b *builtinUnixTimestampCurrentSig) Clone() builtinFunc { - newSig := &builtinUnixTimestampCurrentSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals a UNIX_TIMESTAMP(). -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_unix-timestamp -func (b *builtinUnixTimestampCurrentSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - nowTs, err := getStmtTimestamp(ctx) - if err != nil { - return 0, true, err - } - dec, err := goTimeToMysqlUnixTimestamp(nowTs, 1) - if err != nil { - return 0, true, err - } - intVal, err := dec.ToInt() - if !terror.ErrorEqual(err, types.ErrTruncated) { - terror.Log(err) - } - return intVal, false, nil -} - -type builtinUnixTimestampIntSig struct { - baseBuiltinFunc -} - -func (b *builtinUnixTimestampIntSig) Clone() builtinFunc { - newSig := &builtinUnixTimestampIntSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals a UNIX_TIMESTAMP(time). -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_unix-timestamp -func (b *builtinUnixTimestampIntSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - val, isNull, err := b.args[0].EvalTime(ctx, row) - if err != nil && terror.ErrorEqual(types.ErrWrongValue.GenWithStackByArgs(types.TimeStr, val), err) { - // Return 0 for invalid date time. 
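The conversion above can be illustrated with a standard-library-only sketch that uses plain integer math instead of types.MyDecimal; the range bound and the truncate-not-round behaviour are taken from the comments above, and the function name is invented for the sketch.

package main

import (
	"fmt"
	"time"
)

// unixTimestampSketch mimics goTimeToMysqlUnixTimestamp above: values outside
// the supported range collapse to "0", and the fractional part is truncated
// (never rounded) to the requested number of digits, capped at 6.
func unixTimestampSketch(t time.Time, decimals int) string {
	micros := t.UnixMicro()
	if micros < 1_000_000 || micros > 32536771199999999 {
		return "0"
	}
	if decimals <= 0 {
		return fmt.Sprintf("%d", micros/1_000_000)
	}
	if decimals > 6 {
		decimals = 6
	}
	frac := fmt.Sprintf("%06d", micros%1_000_000)[:decimals] // truncate, do not round
	return fmt.Sprintf("%d.%s", micros/1_000_000, frac)
}

func main() {
	ts := time.Date(2019, 3, 25, 16, 39, 54, 992969000, time.UTC)
	fmt.Println(unixTimestampSketch(ts, 0))              // 1553531994
	fmt.Println(unixTimestampSketch(ts, 3))              // 1553531994.992
	fmt.Println(unixTimestampSketch(time.Unix(0, 0), 3)) // before the valid range -> 0
}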
- return 0, false, nil - } - if isNull { - return 0, true, nil - } - - tz := location(ctx) - t, err := val.AdjustedGoTime(tz) - if err != nil { - return 0, false, nil - } - dec, err := goTimeToMysqlUnixTimestamp(t, 1) - if err != nil { - return 0, true, err - } - intVal, err := dec.ToInt() - if !terror.ErrorEqual(err, types.ErrTruncated) { - terror.Log(err) - } - return intVal, false, nil -} - -type builtinUnixTimestampDecSig struct { - baseBuiltinFunc -} - -func (b *builtinUnixTimestampDecSig) Clone() builtinFunc { - newSig := &builtinUnixTimestampDecSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalDecimal evals a UNIX_TIMESTAMP(time). -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_unix-timestamp -func (b *builtinUnixTimestampDecSig) evalDecimal(ctx EvalContext, row chunk.Row) (*types.MyDecimal, bool, error) { - val, isNull, err := b.args[0].EvalTime(ctx, row) - if isNull || err != nil { - // Return 0 for invalid date time. - return new(types.MyDecimal), isNull, nil - } - t, err := val.GoTime(getTimeZone(ctx)) - if err != nil { - return new(types.MyDecimal), false, nil - } - result, err := goTimeToMysqlUnixTimestamp(t, b.tp.GetDecimal()) - return result, err != nil, err -} - -type timestampFunctionClass struct { - baseFunctionClass -} - -func (c *timestampFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - evalTps, argLen := []types.EvalType{types.ETString}, len(args) - if argLen == 2 { - evalTps = append(evalTps, types.ETString) - } - fsp, err := getExpressionFsp(ctx, args[0]) - if err != nil { - return nil, err - } - if argLen == 2 { - fsp2, err := getExpressionFsp(ctx, args[1]) - if err != nil { - return nil, err - } - if fsp2 > fsp { - fsp = fsp2 - } - } - isFloat := false - switch args[0].GetType(ctx.GetEvalCtx()).GetType() { - case mysql.TypeFloat, mysql.TypeDouble, mysql.TypeNewDecimal, mysql.TypeLonglong: - isFloat = true - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, evalTps...) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForDatetime(fsp) - var sig builtinFunc - if argLen == 2 { - sig = &builtinTimestamp2ArgsSig{bf, isFloat} - sig.setPbCode(tipb.ScalarFuncSig_Timestamp2Args) - } else { - sig = &builtinTimestamp1ArgSig{bf, isFloat} - sig.setPbCode(tipb.ScalarFuncSig_Timestamp1Arg) - } - return sig, nil -} - -type builtinTimestamp1ArgSig struct { - baseBuiltinFunc - - isFloat bool -} - -func (b *builtinTimestamp1ArgSig) Clone() builtinFunc { - newSig := &builtinTimestamp1ArgSig{isFloat: b.isFloat} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals a builtinTimestamp1ArgSig. 
-// See https://dev.mysql.com/doc/refman/5.5/en/date-and-time-functions.html#function_timestamp -func (b *builtinTimestamp1ArgSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - s, isNull, err := b.args[0].EvalString(ctx, row) - if isNull || err != nil { - return types.ZeroTime, isNull, err - } - var tm types.Time - tc := typeCtx(ctx) - if b.isFloat { - tm, err = types.ParseTimeFromFloatString(tc, s, mysql.TypeDatetime, types.GetFsp(s)) - } else { - tm, err = types.ParseTime(tc, s, mysql.TypeDatetime, types.GetFsp(s)) - } - if err != nil { - return types.ZeroTime, true, handleInvalidTimeError(ctx, err) - } - return tm, false, nil -} - -type builtinTimestamp2ArgsSig struct { - baseBuiltinFunc - - isFloat bool -} - -func (b *builtinTimestamp2ArgsSig) Clone() builtinFunc { - newSig := &builtinTimestamp2ArgsSig{isFloat: b.isFloat} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals a builtinTimestamp2ArgsSig. -// See https://dev.mysql.com/doc/refman/5.5/en/date-and-time-functions.html#function_timestamp -func (b *builtinTimestamp2ArgsSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - arg0, isNull, err := b.args[0].EvalString(ctx, row) - if isNull || err != nil { - return types.ZeroTime, isNull, err - } - var tm types.Time - tc := typeCtx(ctx) - if b.isFloat { - tm, err = types.ParseTimeFromFloatString(tc, arg0, mysql.TypeDatetime, types.GetFsp(arg0)) - } else { - tm, err = types.ParseTime(tc, arg0, mysql.TypeDatetime, types.GetFsp(arg0)) - } - if err != nil { - return types.ZeroTime, true, handleInvalidTimeError(ctx, err) - } - if tm.Year() == 0 { - // MySQL won't evaluate add for date with zero year. - // See https://github.com/mysql/mysql-server/blob/5.7/sql/item_timefunc.cc#L2805 - return types.ZeroTime, true, nil - } - arg1, isNull, err := b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return types.ZeroTime, isNull, err - } - if !isDuration(arg1) { - return types.ZeroTime, true, nil - } - duration, _, err := types.ParseDuration(tc, arg1, types.GetFsp(arg1)) - if err != nil { - return types.ZeroTime, true, handleInvalidTimeError(ctx, err) - } - tmp, err := tm.Add(tc, duration) - if err != nil { - return types.ZeroTime, true, err - } - return tmp, false, nil -} - -type timestampLiteralFunctionClass struct { - baseFunctionClass -} - -func (c *timestampLiteralFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - con, ok := args[0].(*Constant) - if !ok { - panic("Unexpected parameter for timestamp literal") - } - dt, err := con.Eval(ctx.GetEvalCtx(), chunk.Row{}) - if err != nil { - return nil, err - } - str, err := dt.ToString() - if err != nil { - return nil, err - } - if !timestampPattern.MatchString(str) { - return nil, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, str) - } - tm, err := types.ParseTime(ctx.GetEvalCtx().TypeCtx(), str, mysql.TypeDatetime, types.GetFsp(str)) - if err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, []Expression{}, types.ETDatetime) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForDatetime(tm.Fsp()) - sig := &builtinTimestampLiteralSig{bf, tm} - return sig, nil -} - -type builtinTimestampLiteralSig struct { - baseBuiltinFunc - tm types.Time -} - -func (b *builtinTimestampLiteralSig) Clone() builtinFunc { - newSig := &builtinTimestampLiteralSig{tm: b.tm} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig 
-} - -// evalTime evals TIMESTAMP 'stringLit'. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-literals.html -func (b *builtinTimestampLiteralSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - return b.tm, false, nil -} - -// getFsp4TimeAddSub is used to in function 'ADDTIME' and 'SUBTIME' to evaluate `fsp` for the -// second parameter. It's used only if the second parameter is of string type. It's different -// from getFsp in that the result of getFsp4TimeAddSub is either 6 or 0. -func getFsp4TimeAddSub(s string) int { - if len(s)-strings.Index(s, ".")-1 == len(s) { - return types.MinFsp - } - for _, c := range s[strings.Index(s, ".")+1:] { - if c != '0' { - return types.MaxFsp - } - } - return types.MinFsp -} - -// getBf4TimeAddSub parses input types, generates baseBuiltinFunc and set related attributes for -// builtin function 'ADDTIME' and 'SUBTIME' -func getBf4TimeAddSub(ctx BuildContext, funcName string, args []Expression) (tp1, tp2 *types.FieldType, bf baseBuiltinFunc, err error) { - tp1, tp2 = args[0].GetType(ctx.GetEvalCtx()), args[1].GetType(ctx.GetEvalCtx()) - var argTp1, argTp2, retTp types.EvalType - switch tp1.GetType() { - case mysql.TypeDatetime, mysql.TypeTimestamp: - argTp1, retTp = types.ETDatetime, types.ETDatetime - case mysql.TypeDuration: - argTp1, retTp = types.ETDuration, types.ETDuration - case mysql.TypeDate: - argTp1, retTp = types.ETDuration, types.ETString - default: - argTp1, retTp = types.ETString, types.ETString - } - switch tp2.GetType() { - case mysql.TypeDatetime, mysql.TypeDuration: - argTp2 = types.ETDuration - default: - argTp2 = types.ETString - } - arg0Dec, err := getExpressionFsp(ctx, args[0]) - if err != nil { - return - } - arg1Dec, err := getExpressionFsp(ctx, args[1]) - if err != nil { - return - } - - bf, err = newBaseBuiltinFuncWithTp(ctx, funcName, args, retTp, argTp1, argTp2) - if err != nil { - return - } - switch retTp { - case types.ETDatetime: - bf.setDecimalAndFlenForDatetime(mathutil.Min(mathutil.Max(arg0Dec, arg1Dec), types.MaxFsp)) - case types.ETDuration: - bf.setDecimalAndFlenForTime(mathutil.Min(mathutil.Max(arg0Dec, arg1Dec), types.MaxFsp)) - case types.ETString: - bf.tp.SetType(mysql.TypeString) - bf.tp.SetFlen(mysql.MaxDatetimeWidthWithFsp) - bf.tp.SetDecimal(types.UnspecifiedLength) - } - return -} - -func getTimeZone(ctx EvalContext) *time.Location { - ret := location(ctx) - if ret == nil { - ret = time.Local - } - return ret -} - -// isDuration returns a boolean indicating whether the str matches the format of duration. -// See https://dev.mysql.com/doc/refman/5.7/en/time.html -func isDuration(str string) bool { - return durationPattern.MatchString(str) -} - -// strDatetimeAddDuration adds duration to datetime string, returns a string value. -func strDatetimeAddDuration(tc types.Context, d string, arg1 types.Duration) (result string, isNull bool, err error) { - arg0, err := types.ParseTime(tc, d, mysql.TypeDatetime, types.MaxFsp) - if err != nil { - // Return a warning regardless of the sql_mode, this is compatible with MySQL. - tc.AppendWarning(err) - return "", true, nil - } - ret, err := arg0.Add(tc, arg1) - if err != nil { - return "", false, err - } - fsp := types.MaxFsp - if ret.Microsecond() == 0 { - fsp = types.MinFsp - } - ret.SetFsp(fsp) - return ret.String(), false, nil -} - -// strDurationAddDuration adds duration to duration string, returns a string value. 
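Before the ADDTIME/SUBTIME machinery that follows, the fsp rule in getFsp4TimeAddSub above is worth seeing in isolation: the second string argument contributes either 0 or 6 fractional digits, never anything in between. The sketch below is a stand-alone approximation using only the standard library; fsp4 and the sample inputs are illustrative names, not part of the patch.

package main

import (
	"fmt"
	"strings"
)

// fsp4 mirrors the rule above: no '.' at all, or an all-zero fractional
// part, yields 0 fractional digits; anything else yields the maximum of 6.
func fsp4(s string) int {
	dot := strings.Index(s, ".")
	if dot < 0 {
		return 0 // types.MinFsp
	}
	for _, c := range s[dot+1:] {
		if c != '0' {
			return 6 // types.MaxFsp
		}
	}
	return 0
}

func main() {
	fmt.Println(fsp4("01:30:00"))     // 0
	fmt.Println(fsp4("01:30:00.000")) // 0
	fmt.Println(fsp4("01:30:00.250")) // 6
}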
-func strDurationAddDuration(tc types.Context, d string, arg1 types.Duration) (string, error) { - arg0, _, err := types.ParseDuration(tc, d, types.MaxFsp) - if err != nil { - return "", err - } - tmpDuration, err := arg0.Add(arg1) - if err != nil { - return "", err - } - tmpDuration.Fsp = types.MaxFsp - if tmpDuration.MicroSecond() == 0 { - tmpDuration.Fsp = types.MinFsp - } - return tmpDuration.String(), nil -} - -// strDatetimeSubDuration subtracts duration from datetime string, returns a string value. -func strDatetimeSubDuration(tc types.Context, d string, arg1 types.Duration) (result string, isNull bool, err error) { - arg0, err := types.ParseTime(tc, d, mysql.TypeDatetime, types.MaxFsp) - if err != nil { - // Return a warning regardless of the sql_mode, this is compatible with MySQL. - tc.AppendWarning(err) - return "", true, nil - } - resultTime, err := arg0.Add(tc, arg1.Neg()) - if err != nil { - return "", false, err - } - fsp := types.MaxFsp - if resultTime.Microsecond() == 0 { - fsp = types.MinFsp - } - resultTime.SetFsp(fsp) - return resultTime.String(), false, nil -} - -// strDurationSubDuration subtracts duration from duration string, returns a string value. -func strDurationSubDuration(tc types.Context, d string, arg1 types.Duration) (string, error) { - arg0, _, err := types.ParseDuration(tc, d, types.MaxFsp) - if err != nil { - return "", err - } - tmpDuration, err := arg0.Sub(arg1) - if err != nil { - return "", err - } - tmpDuration.Fsp = types.MaxFsp - if tmpDuration.MicroSecond() == 0 { - tmpDuration.Fsp = types.MinFsp - } - return tmpDuration.String(), nil -} - -type addTimeFunctionClass struct { - baseFunctionClass -} - -func (c *addTimeFunctionClass) getFunction(ctx BuildContext, args []Expression) (sig builtinFunc, err error) { - if err = c.verifyArgs(args); err != nil { - return nil, err - } - tp1, tp2, bf, err := getBf4TimeAddSub(ctx, c.funcName, args) - if err != nil { - return nil, err - } - switch tp1.GetType() { - case mysql.TypeDatetime, mysql.TypeTimestamp: - switch tp2.GetType() { - case mysql.TypeDuration: - sig = &builtinAddDatetimeAndDurationSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_AddDatetimeAndDuration) - case mysql.TypeDatetime, mysql.TypeTimestamp: - sig = &builtinAddTimeDateTimeNullSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_AddTimeDateTimeNull) - default: - sig = &builtinAddDatetimeAndStringSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_AddDatetimeAndString) - } - case mysql.TypeDate: - charset, collate := ctx.GetCharsetInfo() - bf.tp.SetCharset(charset) - bf.tp.SetCollate(collate) - switch tp2.GetType() { - case mysql.TypeDuration: - sig = &builtinAddDateAndDurationSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_AddDateAndDuration) - case mysql.TypeDatetime, mysql.TypeTimestamp: - sig = &builtinAddTimeStringNullSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_AddTimeStringNull) - default: - sig = &builtinAddDateAndStringSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_AddDateAndString) - } - case mysql.TypeDuration: - switch tp2.GetType() { - case mysql.TypeDuration: - sig = &builtinAddDurationAndDurationSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_AddDurationAndDuration) - case mysql.TypeDatetime, mysql.TypeTimestamp: - sig = &builtinAddTimeDurationNullSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_AddTimeDurationNull) - default: - sig = &builtinAddDurationAndStringSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_AddDurationAndString) - } - default: - switch tp2.GetType() { - case mysql.TypeDuration: - sig = &builtinAddStringAndDurationSig{bf} - 
sig.setPbCode(tipb.ScalarFuncSig_AddStringAndDuration) - case mysql.TypeDatetime, mysql.TypeTimestamp: - sig = &builtinAddTimeStringNullSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_AddTimeStringNull) - default: - sig = &builtinAddStringAndStringSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_AddStringAndString) - } - } - return sig, nil -} - -type builtinAddTimeDateTimeNullSig struct { - baseBuiltinFunc -} - -func (b *builtinAddTimeDateTimeNullSig) Clone() builtinFunc { - newSig := &builtinAddTimeDateTimeNullSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals a builtinAddTimeDateTimeNullSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_addtime -func (b *builtinAddTimeDateTimeNullSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - return types.ZeroDatetime, true, nil -} - -type builtinAddDatetimeAndDurationSig struct { - baseBuiltinFunc -} - -func (b *builtinAddDatetimeAndDurationSig) Clone() builtinFunc { - newSig := &builtinAddDatetimeAndDurationSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals a builtinAddDatetimeAndDurationSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_addtime -func (b *builtinAddDatetimeAndDurationSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - arg0, isNull, err := b.args[0].EvalTime(ctx, row) - if isNull || err != nil { - return types.ZeroDatetime, isNull, err - } - arg1, isNull, err := b.args[1].EvalDuration(ctx, row) - if isNull || err != nil { - return types.ZeroDatetime, isNull, err - } - result, err := arg0.Add(typeCtx(ctx), arg1) - return result, err != nil, err -} - -type builtinAddDatetimeAndStringSig struct { - baseBuiltinFunc -} - -func (b *builtinAddDatetimeAndStringSig) Clone() builtinFunc { - newSig := &builtinAddDatetimeAndStringSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals a builtinAddDatetimeAndStringSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_addtime -func (b *builtinAddDatetimeAndStringSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - arg0, isNull, err := b.args[0].EvalTime(ctx, row) - if isNull || err != nil { - return types.ZeroDatetime, isNull, err - } - s, isNull, err := b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return types.ZeroDatetime, isNull, err - } - if !isDuration(s) { - return types.ZeroDatetime, true, nil - } - tc := typeCtx(ctx) - arg1, _, err := types.ParseDuration(tc, s, types.GetFsp(s)) - if err != nil { - if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { - tc.AppendWarning(err) - return types.ZeroDatetime, true, nil - } - return types.ZeroDatetime, true, err - } - result, err := arg0.Add(tc, arg1) - return result, err != nil, err -} - -type builtinAddTimeDurationNullSig struct { - baseBuiltinFunc -} - -func (b *builtinAddTimeDurationNullSig) Clone() builtinFunc { - newSig := &builtinAddTimeDurationNullSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalDuration evals a builtinAddTimeDurationNullSig. 
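The DATETIME + DURATION paths above all reduce to parsing both operands and adding the interval to the time value; fsp propagation, warnings, and time-zone handling aside, a rough stand-alone analogue with the standard library (inputs are hypothetical) looks like this.

package main

import (
	"fmt"
	"time"
)

func main() {
	const layout = "2006-01-02 15:04:05"
	// Hypothetical stand-ins for the evaluated arg0 (datetime) and arg1 (duration).
	dt, err := time.Parse(layout, "2024-08-07 10:00:00")
	if err != nil {
		panic(err)
	}
	dur, err := time.ParseDuration("1h30m")
	if err != nil {
		panic(err)
	}
	// ADDTIME('2024-08-07 10:00:00', '01:30:00')-style result.
	fmt.Println(dt.Add(dur).Format(layout)) // 2024-08-07 11:30:00
}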
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_addtime -func (b *builtinAddTimeDurationNullSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { - return types.ZeroDuration, true, nil -} - -type builtinAddDurationAndDurationSig struct { - baseBuiltinFunc -} - -func (b *builtinAddDurationAndDurationSig) Clone() builtinFunc { - newSig := &builtinAddDurationAndDurationSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalDuration evals a builtinAddDurationAndDurationSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_addtime -func (b *builtinAddDurationAndDurationSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { - arg0, isNull, err := b.args[0].EvalDuration(ctx, row) - if isNull || err != nil { - return types.ZeroDuration, isNull, err - } - arg1, isNull, err := b.args[1].EvalDuration(ctx, row) - if isNull || err != nil { - return types.ZeroDuration, isNull, err - } - result, err := arg0.Add(arg1) - if err != nil { - return types.ZeroDuration, true, err - } - return result, false, nil -} - -type builtinAddDurationAndStringSig struct { - baseBuiltinFunc -} - -func (b *builtinAddDurationAndStringSig) Clone() builtinFunc { - newSig := &builtinAddDurationAndStringSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalDuration evals a builtinAddDurationAndStringSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_addtime -func (b *builtinAddDurationAndStringSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { - arg0, isNull, err := b.args[0].EvalDuration(ctx, row) - if isNull || err != nil { - return types.ZeroDuration, isNull, err - } - s, isNull, err := b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return types.ZeroDuration, isNull, err - } - if !isDuration(s) { - return types.ZeroDuration, true, nil - } - tc := typeCtx(ctx) - arg1, _, err := types.ParseDuration(tc, s, types.GetFsp(s)) - if err != nil { - if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { - tc.AppendWarning(err) - return types.ZeroDuration, true, nil - } - return types.ZeroDuration, true, err - } - result, err := arg0.Add(arg1) - if err != nil { - return types.ZeroDuration, true, err - } - return result, false, nil -} - -type builtinAddTimeStringNullSig struct { - baseBuiltinFunc -} - -func (b *builtinAddTimeStringNullSig) Clone() builtinFunc { - newSig := &builtinAddTimeStringNullSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalString evals a builtinAddDurationAndDurationSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_addtime -func (b *builtinAddTimeStringNullSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { - return "", true, nil -} - -type builtinAddStringAndDurationSig struct { - baseBuiltinFunc -} - -func (b *builtinAddStringAndDurationSig) Clone() builtinFunc { - newSig := &builtinAddStringAndDurationSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalString evals a builtinAddStringAndDurationSig. 
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_addtime -func (b *builtinAddStringAndDurationSig) evalString(ctx EvalContext, row chunk.Row) (result string, isNull bool, err error) { - var ( - arg0 string - arg1 types.Duration - ) - arg0, isNull, err = b.args[0].EvalString(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - arg1, isNull, err = b.args[1].EvalDuration(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - tc := typeCtx(ctx) - if isDuration(arg0) { - result, err = strDurationAddDuration(tc, arg0, arg1) - if err != nil { - if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { - tc.AppendWarning(err) - return "", true, nil - } - return "", true, err - } - return result, false, nil - } - result, isNull, err = strDatetimeAddDuration(tc, arg0, arg1) - return result, isNull, err -} - -type builtinAddStringAndStringSig struct { - baseBuiltinFunc -} - -func (b *builtinAddStringAndStringSig) Clone() builtinFunc { - newSig := &builtinAddStringAndStringSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalString evals a builtinAddStringAndStringSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_addtime -func (b *builtinAddStringAndStringSig) evalString(ctx EvalContext, row chunk.Row) (result string, isNull bool, err error) { - var ( - arg0, arg1Str string - arg1 types.Duration - ) - arg0, isNull, err = b.args[0].EvalString(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - arg1Type := b.args[1].GetType(ctx) - if mysql.HasBinaryFlag(arg1Type.GetFlag()) { - return "", true, nil - } - arg1Str, isNull, err = b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - tc := typeCtx(ctx) - arg1, _, err = types.ParseDuration(tc, arg1Str, getFsp4TimeAddSub(arg1Str)) - if err != nil { - if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { - tc.AppendWarning(err) - return "", true, nil - } - return "", true, err - } - - check := arg1Str - _, check, err = parser.Number(parser.Space0(check)) - if err == nil { - check, err = parser.Char(check, '-') - if strings.Compare(check, "") != 0 && err == nil { - return "", true, nil - } - } - - if isDuration(arg0) { - result, err = strDurationAddDuration(tc, arg0, arg1) - if err != nil { - if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { - tc.AppendWarning(err) - return "", true, nil - } - return "", true, err - } - return result, false, nil - } - result, isNull, err = strDatetimeAddDuration(tc, arg0, arg1) - return result, isNull, err -} - -type builtinAddDateAndDurationSig struct { - baseBuiltinFunc -} - -func (b *builtinAddDateAndDurationSig) Clone() builtinFunc { - newSig := &builtinAddDateAndDurationSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalString evals a builtinAddDurationAndDurationSig. 
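On the string paths above, the only real branching is whether the first argument looks like a bare TIME or like a DATETIME: the former stays on duration arithmetic, the latter goes through datetime arithmetic. Below is a much-simplified sketch of that dispatch; looksLikeDuration is invented for illustration and is far cruder than the real durationPattern, and unlike MySQL TIME this toy version wraps at 24 hours.

package main

import (
	"fmt"
	"strings"
	"time"
)

// looksLikeDuration is a deliberately naive stand-in: a bare TIME literal
// has no '-'-separated date prefix, just colon-separated components.
func looksLikeDuration(s string) bool {
	return !strings.Contains(s, "-") && strings.Count(s, ":") >= 1
}

func addTimeString(arg0 string, d time.Duration) (string, error) {
	if looksLikeDuration(arg0) {
		// duration + duration (wraps at midnight here, unlike MySQL TIME)
		base, err := time.Parse("15:04:05", arg0)
		if err != nil {
			return "", err
		}
		return base.Add(d).Format("15:04:05"), nil
	}
	// datetime + duration
	base, err := time.Parse("2006-01-02 15:04:05", arg0)
	if err != nil {
		return "", err
	}
	return base.Add(d).Format("2006-01-02 15:04:05"), nil
}

func main() {
	s1, _ := addTimeString("10:00:00", 90*time.Minute)
	s2, _ := addTimeString("2024-08-07 10:00:00", 90*time.Minute)
	fmt.Println(s1) // 11:30:00
	fmt.Println(s2) // 2024-08-07 11:30:00
}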
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_addtime -func (b *builtinAddDateAndDurationSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { - arg0, isNull, err := b.args[0].EvalDuration(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - arg1, isNull, err := b.args[1].EvalDuration(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - result, err := arg0.Add(arg1) - return result.String(), err != nil, err -} - -type builtinAddDateAndStringSig struct { - baseBuiltinFunc -} - -func (b *builtinAddDateAndStringSig) Clone() builtinFunc { - newSig := &builtinAddDateAndStringSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalString evals a builtinAddDateAndStringSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_addtime -func (b *builtinAddDateAndStringSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { - arg0, isNull, err := b.args[0].EvalDuration(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - s, isNull, err := b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - if !isDuration(s) { - return "", true, nil - } - tc := typeCtx(ctx) - arg1, _, err := types.ParseDuration(tc, s, getFsp4TimeAddSub(s)) - if err != nil { - if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { - tc.AppendWarning(err) - return "", true, nil - } - return "", true, err - } - result, err := arg0.Add(arg1) - return result.String(), err != nil, err -} - -type convertTzFunctionClass struct { - baseFunctionClass -} - -func (c *convertTzFunctionClass) getDecimal(ctx BuildContext, arg Expression) int { - decimal := types.MaxFsp - if dt, isConstant := arg.(*Constant); isConstant { - switch arg.GetType(ctx.GetEvalCtx()).EvalType() { - case types.ETInt: - decimal = 0 - case types.ETReal, types.ETDecimal: - decimal = arg.GetType(ctx.GetEvalCtx()).GetDecimal() - case types.ETString: - str, isNull, err := dt.EvalString(ctx.GetEvalCtx(), chunk.Row{}) - if err == nil && !isNull { - decimal = types.DateFSP(str) - } - } - } - if decimal > types.MaxFsp { - return types.MaxFsp - } - if decimal < types.MinFsp { - return types.MinFsp - } - return decimal -} - -func (c *convertTzFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - // tzRegex holds the regex to check whether a string is a time zone. - tzRegex, err := regexp.Compile(`(^[-+](0?[0-9]|1[0-3]):[0-5]?\d$)|(^\+14:00?$)`) - if err != nil { - return nil, err - } - - decimal := c.getDecimal(ctx, args[0]) - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, types.ETDatetime, types.ETString, types.ETString) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForDatetime(decimal) - sig := &builtinConvertTzSig{ - baseBuiltinFunc: bf, - timezoneRegex: tzRegex, - } - sig.setPbCode(tipb.ScalarFuncSig_ConvertTz) - return sig, nil -} - -type builtinConvertTzSig struct { - baseBuiltinFunc - timezoneRegex *regexp.Regexp -} - -func (b *builtinConvertTzSig) Clone() builtinFunc { - newSig := &builtinConvertTzSig{timezoneRegex: b.timezoneRegex} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals CONVERT_TZ(dt,from_tz,to_tz). -// `CONVERT_TZ` function returns NULL if the arguments are invalid. 
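The conversion performed by convertTz below is the usual two-step: interpret the wall-clock value in from_tz, then re-render it in to_tz (TiDB additionally accepts numeric offsets such as '+08:00' via the regex above). Restricted to named zones, the standard-library shape is roughly the following sketch.

package main

import (
	"fmt"
	"time"
)

func main() {
	const layout = "2006-01-02 15:04:05"
	from, err := time.LoadLocation("UTC")
	if err != nil {
		panic(err)
	}
	to, err := time.LoadLocation("Asia/Shanghai")
	if err != nil {
		panic(err)
	}
	// Interpret the literal in from_tz, then express it in to_tz.
	dt, err := time.ParseInLocation(layout, "2024-08-07 10:00:00", from)
	if err != nil {
		panic(err)
	}
	fmt.Println(dt.In(to).Format(layout)) // 2024-08-07 18:00:00
}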
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_convert-tz -func (b *builtinConvertTzSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - dt, isNull, err := b.args[0].EvalTime(ctx, row) - if isNull || err != nil { - return types.ZeroTime, true, nil - } - if dt.InvalidZero() { - return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, dt.String())) - } - fromTzStr, isNull, err := b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return types.ZeroTime, true, nil - } - - toTzStr, isNull, err := b.args[2].EvalString(ctx, row) - if isNull || err != nil { - return types.ZeroTime, true, nil - } - - return b.convertTz(dt, fromTzStr, toTzStr) -} - -func (b *builtinConvertTzSig) convertTz(dt types.Time, fromTzStr, toTzStr string) (types.Time, bool, error) { - if fromTzStr == "" || toTzStr == "" { - return types.ZeroTime, true, nil - } - fromTzMatched := b.timezoneRegex.MatchString(fromTzStr) - toTzMatched := b.timezoneRegex.MatchString(toTzStr) - - var fromTz, toTz *time.Location - var err error - - if fromTzMatched { - fromTz = time.FixedZone(fromTzStr, timeZone2int(fromTzStr)) - } else { - if strings.EqualFold(fromTzStr, "SYSTEM") { - fromTzStr = "Local" - } - fromTz, err = time.LoadLocation(fromTzStr) - if err != nil { - return types.ZeroTime, true, nil - } - } - - t, err := dt.AdjustedGoTime(fromTz) - if err != nil { - return types.ZeroTime, true, nil - } - t = t.In(time.UTC) - - if toTzMatched { - toTz = time.FixedZone(toTzStr, timeZone2int(toTzStr)) - } else { - if strings.EqualFold(toTzStr, "SYSTEM") { - toTzStr = "Local" - } - toTz, err = time.LoadLocation(toTzStr) - if err != nil { - return types.ZeroTime, true, nil - } - } - - return types.NewTime(types.FromGoTime(t.In(toTz)), mysql.TypeDatetime, b.tp.GetDecimal()), false, nil -} - -type makeDateFunctionClass struct { - baseFunctionClass -} - -func (c *makeDateFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, types.ETInt, types.ETInt) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForDate() - sig := &builtinMakeDateSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_MakeDate) - return sig, nil -} - -type builtinMakeDateSig struct { - baseBuiltinFunc -} - -func (b *builtinMakeDateSig) Clone() builtinFunc { - newSig := &builtinMakeDateSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evaluates a builtinMakeDateSig. 
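MAKEDATE(year, dayofyear), implemented just below, is essentially "January 1st of the (possibly 2-digit-adjusted) year, plus dayofyear-1 days". The stand-alone sketch below ignores the zero/overflow edge cases; makeDate is an illustrative helper, not the builtin itself.

package main

import (
	"fmt"
	"time"
)

// makeDate applies the 2-digit-year rule (<70 => 20xx, 70..99 => 19xx)
// and offsets from January 1st of that year.
func makeDate(year, dayOfYear int) time.Time {
	if year < 70 {
		year += 2000
	} else if year < 100 {
		year += 1900
	}
	jan1 := time.Date(year, time.January, 1, 0, 0, 0, 0, time.UTC)
	return jan1.AddDate(0, 0, dayOfYear-1)
}

func main() {
	fmt.Println(makeDate(2011, 31).Format("2006-01-02")) // 2011-01-31
	fmt.Println(makeDate(2011, 32).Format("2006-01-02")) // 2011-02-01
	fmt.Println(makeDate(11, 32).Format("2006-01-02"))   // 2011-02-01
}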
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_makedate -func (b *builtinMakeDateSig) evalTime(ctx EvalContext, row chunk.Row) (d types.Time, isNull bool, err error) { - args := b.getArgs() - var year, dayOfYear int64 - year, isNull, err = args[0].EvalInt(ctx, row) - if isNull || err != nil { - return d, true, err - } - dayOfYear, isNull, err = args[1].EvalInt(ctx, row) - if isNull || err != nil { - return d, true, err - } - if dayOfYear <= 0 || year < 0 || year > 9999 { - return d, true, nil - } - if year < 70 { - year += 2000 - } else if year < 100 { - year += 1900 - } - startTime := types.NewTime(types.FromDate(int(year), 1, 1, 0, 0, 0, 0), mysql.TypeDate, 0) - retTimestamp := types.TimestampDiff("DAY", types.ZeroDate, startTime) - if retTimestamp == 0 { - return d, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, startTime.String())) - } - ret := types.TimeFromDays(retTimestamp + dayOfYear - 1) - if ret.IsZero() || ret.Year() > 9999 { - return d, true, nil - } - return ret, false, nil -} - -type makeTimeFunctionClass struct { - baseFunctionClass -} - -func (c *makeTimeFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - tp, decimal := args[2].GetType(ctx.GetEvalCtx()).EvalType(), 0 - switch tp { - case types.ETInt: - case types.ETReal, types.ETDecimal: - decimal = args[2].GetType(ctx.GetEvalCtx()).GetDecimal() - if decimal > 6 || decimal == types.UnspecifiedLength { - decimal = 6 - } - default: - decimal = 6 - } - // MySQL will cast the first and second arguments to INT, and the third argument to DECIMAL. - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDuration, types.ETInt, types.ETInt, types.ETReal) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForTime(decimal) - sig := &builtinMakeTimeSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_MakeTime) - return sig, nil -} - -type builtinMakeTimeSig struct { - baseBuiltinFunc -} - -func (b *builtinMakeTimeSig) Clone() builtinFunc { - newSig := &builtinMakeTimeSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (b *builtinMakeTimeSig) makeTime(ctx types.Context, hour int64, minute int64, second float64, hourUnsignedFlag bool) (types.Duration, error) { - var overflow bool - // MySQL TIME datatype: https://dev.mysql.com/doc/refman/5.7/en/time.html - // ranges from '-838:59:59.000000' to '838:59:59.000000' - if hour < 0 && hourUnsignedFlag { - hour = 838 - overflow = true - } - if hour < -838 { - hour = -838 - overflow = true - } else if hour > 838 { - hour = 838 - overflow = true - } - if (hour == -838 || hour == 838) && minute == 59 && second > 59 { - overflow = true - } - if overflow { - minute = 59 - second = 59 - } - fsp := b.tp.GetDecimal() - d, _, err := types.ParseDuration(ctx, fmt.Sprintf("%02d:%02d:%v", hour, minute, second), fsp) - return d, err -} - -// evalDuration evals a builtinMakeTimeIntSig. 
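The clamping in makeTime above reflects the MySQL TIME range of -838:59:59..838:59:59. The sketch below isolates just the clamp (it omits the unsigned-hour case and the fsp/second handling); clampHour is an illustrative name.

package main

import "fmt"

// clampHour limits the hour to [-838, 838] and, on overflow, pins the
// minute and second to 59, as the real makeTime does.
func clampHour(hour, minute, second int) (int, int, int) {
	overflow := false
	if hour > 838 {
		hour = 838
		overflow = true
	} else if hour < -838 {
		hour = -838
		overflow = true
	}
	if overflow {
		minute, second = 59, 59
	}
	return hour, minute, second
}

func main() {
	fmt.Println(clampHour(900, 10, 10))  // 838 59 59
	fmt.Println(clampHour(-900, 10, 10)) // -838 59 59
	fmt.Println(clampHour(10, 10, 10))   // 10 10 10
}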
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_maketime -func (b *builtinMakeTimeSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { - dur := types.ZeroDuration - dur.Fsp = types.MaxFsp - hour, isNull, err := b.args[0].EvalInt(ctx, row) - if isNull || err != nil { - return dur, isNull, err - } - minute, isNull, err := b.args[1].EvalInt(ctx, row) - if isNull || err != nil { - return dur, isNull, err - } - if minute < 0 || minute >= 60 { - return dur, true, nil - } - second, isNull, err := b.args[2].EvalReal(ctx, row) - if isNull || err != nil { - return dur, isNull, err - } - if second < 0 || second >= 60 { - return dur, true, nil - } - dur, err = b.makeTime(typeCtx(ctx), hour, minute, second, mysql.HasUnsignedFlag(b.args[0].GetType(ctx).GetFlag())) - if err != nil { - return dur, true, err - } - return dur, false, nil -} - -type periodAddFunctionClass struct { - baseFunctionClass -} - -func (c *periodAddFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETInt, types.ETInt) - if err != nil { - return nil, err - } - bf.tp.SetFlen(6) - sig := &builtinPeriodAddSig{bf} - return sig, nil -} - -// validPeriod checks if this period is valid, it comes from MySQL 8.0+. -func validPeriod(p int64) bool { - return !(p < 0 || p%100 == 0 || p%100 > 12) -} - -// period2Month converts a period to months, in which period is represented in the format of YYMM or YYYYMM. -// Note that the period argument is not a date value. -func period2Month(period uint64) uint64 { - if period == 0 { - return 0 - } - - year, month := period/100, period%100 - if year < 70 { - year += 2000 - } else if year < 100 { - year += 1900 - } - - return year*12 + month - 1 -} - -// month2Period converts a month to a period. -func month2Period(month uint64) uint64 { - if month == 0 { - return 0 - } - - year := month / 12 - if year < 70 { - year += 2000 - } else if year < 100 { - year += 1900 - } - - return year*100 + month%12 + 1 -} - -type builtinPeriodAddSig struct { - baseBuiltinFunc -} - -func (b *builtinPeriodAddSig) Clone() builtinFunc { - newSig := &builtinPeriodAddSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals PERIOD_ADD(P,N). -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_period-add -func (b *builtinPeriodAddSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - p, isNull, err := b.args[0].EvalInt(ctx, row) - if isNull || err != nil { - return 0, true, err - } - - n, isNull, err := b.args[1].EvalInt(ctx, row) - if isNull || err != nil { - return 0, true, err - } - - // in MySQL, if p is invalid but n is NULL, the result is NULL, so we have to check if n is NULL first. 
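As a worked example of the period helpers above: PERIOD_ADD(200812, 3) first maps 200812 to 2008*12 + 12 - 1 = 24107 months, adds 3, and maps back to 200903. The snippet below simply re-states the two helpers outside the planner so the round trip can be checked by hand.

package main

import "fmt"

// period2Month / month2Period reproduce the arithmetic shown above:
// a YYMM or YYYYMM period mapped to a linear month count and back.
func period2Month(period uint64) uint64 {
	if period == 0 {
		return 0
	}
	year, month := period/100, period%100
	if year < 70 {
		year += 2000
	} else if year < 100 {
		year += 1900
	}
	return year*12 + month - 1
}

func month2Period(month uint64) uint64 {
	if month == 0 {
		return 0
	}
	year := month / 12
	if year < 70 {
		year += 2000
	} else if year < 100 {
		year += 1900
	}
	return year*100 + month%12 + 1
}

func main() {
	// PERIOD_ADD(200812, 3) => 200903
	fmt.Println(month2Period(period2Month(200812) + 3))
}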
- if !validPeriod(p) { - return 0, false, errIncorrectArgs.GenWithStackByArgs("period_add") - } - - sumMonth := int64(period2Month(uint64(p))) + n - return int64(month2Period(uint64(sumMonth))), false, nil -} - -type periodDiffFunctionClass struct { - baseFunctionClass -} - -func (c *periodDiffFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETInt, types.ETInt) - if err != nil { - return nil, err - } - bf.tp.SetFlen(6) - sig := &builtinPeriodDiffSig{bf} - return sig, nil -} - -type builtinPeriodDiffSig struct { - baseBuiltinFunc -} - -func (b *builtinPeriodDiffSig) Clone() builtinFunc { - newSig := &builtinPeriodDiffSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals PERIOD_DIFF(P1,P2). -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_period-diff -func (b *builtinPeriodDiffSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - p1, isNull, err := b.args[0].EvalInt(ctx, row) - if isNull || err != nil { - return 0, isNull, err - } - - p2, isNull, err := b.args[1].EvalInt(ctx, row) - if isNull || err != nil { - return 0, isNull, err - } - - if !validPeriod(p1) { - return 0, false, errIncorrectArgs.GenWithStackByArgs("period_diff") - } - - if !validPeriod(p2) { - return 0, false, errIncorrectArgs.GenWithStackByArgs("period_diff") - } - - return int64(period2Month(uint64(p1)) - period2Month(uint64(p2))), false, nil -} - -type quarterFunctionClass struct { - baseFunctionClass -} - -func (c *quarterFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDatetime) - if err != nil { - return nil, err - } - bf.tp.SetFlen(1) - - sig := &builtinQuarterSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_Quarter) - return sig, nil -} - -type builtinQuarterSig struct { - baseBuiltinFunc -} - -func (b *builtinQuarterSig) Clone() builtinFunc { - newSig := &builtinQuarterSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals QUARTER(date). 
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_quarter -func (b *builtinQuarterSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - date, isNull, err := b.args[0].EvalTime(ctx, row) - if isNull || err != nil { - return 0, true, handleInvalidTimeError(ctx, err) - } - - return int64((date.Month() + 2) / 3), false, nil -} - -type secToTimeFunctionClass struct { - baseFunctionClass -} - -func (c *secToTimeFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - var retFsp int - argType := args[0].GetType(ctx.GetEvalCtx()) - argEvalTp := argType.EvalType() - if argEvalTp == types.ETString { - retFsp = types.UnspecifiedLength - } else { - retFsp = argType.GetDecimal() - } - if retFsp > types.MaxFsp || retFsp == types.UnspecifiedFsp { - retFsp = types.MaxFsp - } else if retFsp < types.MinFsp { - retFsp = types.MinFsp - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDuration, types.ETReal) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForTime(retFsp) - sig := &builtinSecToTimeSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_SecToTime) - return sig, nil -} - -type builtinSecToTimeSig struct { - baseBuiltinFunc -} - -func (b *builtinSecToTimeSig) Clone() builtinFunc { - newSig := &builtinSecToTimeSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalDuration evals SEC_TO_TIME(seconds). -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_sec-to-time -func (b *builtinSecToTimeSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { - secondsFloat, isNull, err := b.args[0].EvalReal(ctx, row) - if isNull || err != nil { - return types.Duration{}, isNull, err - } - var ( - hour uint64 - minute uint64 - second uint64 - demical float64 - secondDemical float64 - negative string - ) - - if secondsFloat < 0 { - negative = "-" - secondsFloat = math.Abs(secondsFloat) - } - seconds := uint64(secondsFloat) - demical = secondsFloat - float64(seconds) - - hour = seconds / 3600 - if hour > 838 { - hour = 838 - minute = 59 - second = 59 - demical = 0 - tc := typeCtx(ctx) - err = tc.HandleTruncate(errTruncatedWrongValue.GenWithStackByArgs("time", strconv.FormatFloat(secondsFloat, 'f', -1, 64))) - if err != nil { - return types.Duration{}, err != nil, err - } - } else { - minute = seconds % 3600 / 60 - second = seconds % 60 - } - secondDemical = float64(second) + demical - - var dur types.Duration - dur, _, err = types.ParseDuration(typeCtx(ctx), fmt.Sprintf("%s%02d:%02d:%s", negative, hour, minute, strconv.FormatFloat(secondDemical, 'f', -1, 64)), b.tp.GetDecimal()) - if err != nil { - return types.Duration{}, err != nil, err - } - return dur, false, nil -} - -type subTimeFunctionClass struct { - baseFunctionClass -} - -func (c *subTimeFunctionClass) getFunction(ctx BuildContext, args []Expression) (sig builtinFunc, err error) { - if err = c.verifyArgs(args); err != nil { - return nil, err - } - tp1, tp2, bf, err := getBf4TimeAddSub(ctx, c.funcName, args) - if err != nil { - return nil, err - } - switch tp1.GetType() { - case mysql.TypeDatetime, mysql.TypeTimestamp: - switch tp2.GetType() { - case mysql.TypeDuration: - sig = &builtinSubDatetimeAndDurationSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_SubDatetimeAndDuration) - case mysql.TypeDatetime, mysql.TypeTimestamp: - sig = &builtinSubTimeDateTimeNullSig{bf} - 
sig.setPbCode(tipb.ScalarFuncSig_SubTimeDateTimeNull) - default: - sig = &builtinSubDatetimeAndStringSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_SubDatetimeAndString) - } - case mysql.TypeDate: - charset, collate := ctx.GetCharsetInfo() - bf.tp.SetCharset(charset) - bf.tp.SetCollate(collate) - switch tp2.GetType() { - case mysql.TypeDuration: - sig = &builtinSubDateAndDurationSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_SubDateAndDuration) - case mysql.TypeDatetime, mysql.TypeTimestamp: - sig = &builtinSubTimeStringNullSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_SubTimeStringNull) - default: - sig = &builtinSubDateAndStringSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_SubDateAndString) - } - case mysql.TypeDuration: - switch tp2.GetType() { - case mysql.TypeDuration: - sig = &builtinSubDurationAndDurationSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_SubDurationAndDuration) - case mysql.TypeDatetime, mysql.TypeTimestamp: - sig = &builtinSubTimeDurationNullSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_SubTimeDurationNull) - default: - sig = &builtinSubDurationAndStringSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_SubDurationAndString) - } - default: - switch tp2.GetType() { - case mysql.TypeDuration: - sig = &builtinSubStringAndDurationSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_SubStringAndDuration) - case mysql.TypeDatetime, mysql.TypeTimestamp: - sig = &builtinSubTimeStringNullSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_SubTimeStringNull) - default: - sig = &builtinSubStringAndStringSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_SubStringAndString) - } - } - return sig, nil -} - -type builtinSubDatetimeAndDurationSig struct { - baseBuiltinFunc -} - -func (b *builtinSubDatetimeAndDurationSig) Clone() builtinFunc { - newSig := &builtinSubDatetimeAndDurationSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals a builtinSubDatetimeAndDurationSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subtime -func (b *builtinSubDatetimeAndDurationSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - arg0, isNull, err := b.args[0].EvalTime(ctx, row) - if isNull || err != nil { - return types.ZeroDatetime, isNull, err - } - arg1, isNull, err := b.args[1].EvalDuration(ctx, row) - if isNull || err != nil { - return types.ZeroDatetime, isNull, err - } - tc := typeCtx(ctx) - result, err := arg0.Add(tc, arg1.Neg()) - return result, err != nil, err -} - -type builtinSubDatetimeAndStringSig struct { - baseBuiltinFunc -} - -func (b *builtinSubDatetimeAndStringSig) Clone() builtinFunc { - newSig := &builtinSubDatetimeAndStringSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals a builtinSubDatetimeAndStringSig. 
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subtime -func (b *builtinSubDatetimeAndStringSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - arg0, isNull, err := b.args[0].EvalTime(ctx, row) - if isNull || err != nil { - return types.ZeroDatetime, isNull, err - } - s, isNull, err := b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return types.ZeroDatetime, isNull, err - } - if !isDuration(s) { - return types.ZeroDatetime, true, nil - } - tc := typeCtx(ctx) - arg1, _, err := types.ParseDuration(tc, s, types.GetFsp(s)) - if err != nil { - if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { - tc.AppendWarning(err) - return types.ZeroDatetime, true, nil - } - return types.ZeroDatetime, true, err - } - result, err := arg0.Add(tc, arg1.Neg()) - return result, err != nil, err -} - -type builtinSubTimeDateTimeNullSig struct { - baseBuiltinFunc -} - -func (b *builtinSubTimeDateTimeNullSig) Clone() builtinFunc { - newSig := &builtinSubTimeDateTimeNullSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals a builtinSubTimeDateTimeNullSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subtime -func (b *builtinSubTimeDateTimeNullSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - return types.ZeroDatetime, true, nil -} - -type builtinSubStringAndDurationSig struct { - baseBuiltinFunc -} - -func (b *builtinSubStringAndDurationSig) Clone() builtinFunc { - newSig := &builtinSubStringAndDurationSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalString evals a builtinSubStringAndDurationSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subtime -func (b *builtinSubStringAndDurationSig) evalString(ctx EvalContext, row chunk.Row) (result string, isNull bool, err error) { - var ( - arg0 string - arg1 types.Duration - ) - arg0, isNull, err = b.args[0].EvalString(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - arg1, isNull, err = b.args[1].EvalDuration(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - tc := typeCtx(ctx) - if isDuration(arg0) { - result, err = strDurationSubDuration(tc, arg0, arg1) - if err != nil { - if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { - tc.AppendWarning(err) - return "", true, nil - } - return "", true, err - } - return result, false, nil - } - result, isNull, err = strDatetimeSubDuration(tc, arg0, arg1) - return result, isNull, err -} - -type builtinSubStringAndStringSig struct { - baseBuiltinFunc -} - -func (b *builtinSubStringAndStringSig) Clone() builtinFunc { - newSig := &builtinSubStringAndStringSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalString evals a builtinSubStringAndStringSig. 
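As the arg1.Neg() calls above suggest, the whole SUBTIME family is ADDTIME with the interval negated. A minimal stand-alone sketch of that idea (inputs are hypothetical):

package main

import (
	"fmt"
	"time"
)

func main() {
	const layout = "2006-01-02 15:04:05"
	dt, err := time.Parse(layout, "2024-08-07 10:00:00")
	if err != nil {
		panic(err)
	}
	dur := 90 * time.Minute
	// Subtracting is implemented as adding the negated duration.
	fmt.Println(dt.Add(-dur).Format(layout)) // 2024-08-07 08:30:00
}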
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subtime -func (b *builtinSubStringAndStringSig) evalString(ctx EvalContext, row chunk.Row) (result string, isNull bool, err error) { - var ( - s, arg0 string - arg1 types.Duration - ) - arg0, isNull, err = b.args[0].EvalString(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - arg1Type := b.args[1].GetType(ctx) - if mysql.HasBinaryFlag(arg1Type.GetFlag()) { - return "", true, nil - } - s, isNull, err = b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - tc := typeCtx(ctx) - arg1, _, err = types.ParseDuration(tc, s, getFsp4TimeAddSub(s)) - if err != nil { - if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { - tc.AppendWarning(err) - return "", true, nil - } - return "", true, err - } - if isDuration(arg0) { - result, err = strDurationSubDuration(tc, arg0, arg1) - if err != nil { - if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { - tc.AppendWarning(err) - return "", true, nil - } - return "", true, err - } - return result, false, nil - } - result, isNull, err = strDatetimeSubDuration(tc, arg0, arg1) - return result, isNull, err -} - -type builtinSubTimeStringNullSig struct { - baseBuiltinFunc -} - -func (b *builtinSubTimeStringNullSig) Clone() builtinFunc { - newSig := &builtinSubTimeStringNullSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalString evals a builtinSubTimeStringNullSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subtime -func (b *builtinSubTimeStringNullSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { - return "", true, nil -} - -type builtinSubDurationAndDurationSig struct { - baseBuiltinFunc -} - -func (b *builtinSubDurationAndDurationSig) Clone() builtinFunc { - newSig := &builtinSubDurationAndDurationSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalDuration evals a builtinSubDurationAndDurationSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subtime -func (b *builtinSubDurationAndDurationSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { - arg0, isNull, err := b.args[0].EvalDuration(ctx, row) - if isNull || err != nil { - return types.ZeroDuration, isNull, err - } - arg1, isNull, err := b.args[1].EvalDuration(ctx, row) - if isNull || err != nil { - return types.ZeroDuration, isNull, err - } - result, err := arg0.Sub(arg1) - if err != nil { - return types.ZeroDuration, true, err - } - return result, false, nil -} - -type builtinSubDurationAndStringSig struct { - baseBuiltinFunc -} - -func (b *builtinSubDurationAndStringSig) Clone() builtinFunc { - newSig := &builtinSubDurationAndStringSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalDuration evals a builtinSubDurationAndStringSig. 
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subtime -func (b *builtinSubDurationAndStringSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { - arg0, isNull, err := b.args[0].EvalDuration(ctx, row) - if isNull || err != nil { - return types.ZeroDuration, isNull, err - } - s, isNull, err := b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return types.ZeroDuration, isNull, err - } - if !isDuration(s) { - return types.ZeroDuration, true, nil - } - tc := typeCtx(ctx) - arg1, _, err := types.ParseDuration(tc, s, types.GetFsp(s)) - if err != nil { - if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { - tc.AppendWarning(err) - return types.ZeroDuration, true, nil - } - return types.ZeroDuration, true, err - } - result, err := arg0.Sub(arg1) - return result, err != nil, err -} - -type builtinSubTimeDurationNullSig struct { - baseBuiltinFunc -} - -func (b *builtinSubTimeDurationNullSig) Clone() builtinFunc { - newSig := &builtinSubTimeDurationNullSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalDuration evals a builtinSubTimeDurationNullSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subtime -func (b *builtinSubTimeDurationNullSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { - return types.ZeroDuration, true, nil -} - -type builtinSubDateAndDurationSig struct { - baseBuiltinFunc -} - -func (b *builtinSubDateAndDurationSig) Clone() builtinFunc { - newSig := &builtinSubDateAndDurationSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalString evals a builtinSubDateAndDurationSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subtime -func (b *builtinSubDateAndDurationSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { - arg0, isNull, err := b.args[0].EvalDuration(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - arg1, isNull, err := b.args[1].EvalDuration(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - result, err := arg0.Sub(arg1) - return result.String(), err != nil, err -} - -type builtinSubDateAndStringSig struct { - baseBuiltinFunc -} - -func (b *builtinSubDateAndStringSig) Clone() builtinFunc { - newSig := &builtinSubDateAndStringSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalString evals a builtinSubDateAndStringSig. 
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subtime -func (b *builtinSubDateAndStringSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { - arg0, isNull, err := b.args[0].EvalDuration(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - s, isNull, err := b.args[1].EvalString(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - if !isDuration(s) { - return "", true, nil - } - tc := typeCtx(ctx) - arg1, _, err := types.ParseDuration(tc, s, getFsp4TimeAddSub(s)) - if err != nil { - if terror.ErrorEqual(err, types.ErrTruncatedWrongVal) { - tc.AppendWarning(err) - return "", true, nil - } - return "", true, err - } - result, err := arg0.Sub(arg1) - if err != nil { - return "", true, err - } - return result.String(), false, nil -} - -type timeFormatFunctionClass struct { - baseFunctionClass -} - -func (c *timeFormatFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETDuration, types.ETString) - if err != nil { - return nil, err - } - // worst case: formatMask=%r%r%r...%r, each %r takes 11 characters - bf.tp.SetFlen((args[1].GetType(ctx.GetEvalCtx()).GetFlen() + 1) / 2 * 11) - sig := &builtinTimeFormatSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_TimeFormat) - return sig, nil -} - -type builtinTimeFormatSig struct { - baseBuiltinFunc -} - -func (b *builtinTimeFormatSig) Clone() builtinFunc { - newSig := &builtinTimeFormatSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalString evals a builtinTimeFormatSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_time-format -func (b *builtinTimeFormatSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { - dur, isNull, err := b.args[0].EvalDuration(ctx, row) - // if err != nil, then dur is ZeroDuration, outputs 00:00:00 in this case which follows the behavior of mysql. - if err != nil { - logutil.BgLogger().Warn("time_format.args[0].EvalDuration failed", zap.Error(err)) - } - if isNull { - return "", isNull, err - } - formatMask, isNull, err := b.args[1].EvalString(ctx, row) - if err != nil || isNull { - return "", isNull, err - } - res, err := b.formatTime(dur, formatMask) - return res, isNull, err -} - -// formatTime see https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_time-format -func (b *builtinTimeFormatSig) formatTime(t types.Duration, formatMask string) (res string, err error) { - return t.DurationFormat(formatMask) -} - -type timeToSecFunctionClass struct { - baseFunctionClass -} - -func (c *timeToSecFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDuration) - if err != nil { - return nil, err - } - bf.tp.SetFlen(10) - sig := &builtinTimeToSecSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_TimeToSec) - return sig, nil -} - -type builtinTimeToSecSig struct { - baseBuiltinFunc -} - -func (b *builtinTimeToSecSig) Clone() builtinFunc { - newSig := &builtinTimeToSecSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals TIME_TO_SEC(time). 
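TIME_TO_SEC, implemented just below, is a straight sign * (h*3600 + m*60 + s) conversion; for example TIME_TO_SEC('-01:02:03') is -(3600 + 120 + 3) = -3723. A stand-alone check (timeToSec is an illustrative helper):

package main

import "fmt"

// timeToSec mirrors the conversion below: the sign is applied to the
// absolute hour/minute/second components of the duration.
func timeToSec(negative bool, hour, minute, second int) int64 {
	sign := int64(1)
	if negative {
		sign = -1
	}
	return sign * int64(hour*3600+minute*60+second)
}

func main() {
	fmt.Println(timeToSec(false, 22, 23, 0)) // 80580
	fmt.Println(timeToSec(true, 1, 2, 3))    // -3723
}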
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_time-to-sec -func (b *builtinTimeToSecSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - duration, isNull, err := b.args[0].EvalDuration(ctx, row) - if isNull || err != nil { - return 0, isNull, err - } - var sign int - if duration.Duration >= 0 { - sign = 1 - } else { - sign = -1 - } - return int64(sign * (duration.Hour()*3600 + duration.Minute()*60 + duration.Second())), false, nil -} - -type timestampAddFunctionClass struct { - baseFunctionClass -} - -func (c *timestampAddFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETString, types.ETReal, types.ETDatetime) - if err != nil { - return nil, err - } - flen := mysql.MaxDatetimeWidthNoFsp - con, ok := args[0].(*Constant) - if !ok { - return nil, errors.New("should not happened") - } - unit, null, err := con.EvalString(ctx.GetEvalCtx(), chunk.Row{}) - if null || err != nil { - return nil, errors.New("should not happened") - } - if unit == ast.TimeUnitMicrosecond.String() { - flen = mysql.MaxDatetimeWidthWithFsp - } - - bf.tp.SetFlen(flen) - sig := &builtinTimestampAddSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_TimestampAdd) - return sig, nil -} - -type builtinTimestampAddSig struct { - baseBuiltinFunc -} - -func (b *builtinTimestampAddSig) Clone() builtinFunc { - newSig := &builtinTimestampAddSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -var ( - minDatetimeInGoTime, _ = types.MinDatetime.GoTime(time.Local) - minDatetimeNanos = float64(minDatetimeInGoTime.Unix())*1e9 + float64(minDatetimeInGoTime.Nanosecond()) - maxDatetimeInGoTime, _ = types.MaxDatetime.GoTime(time.Local) - maxDatetimeNanos = float64(maxDatetimeInGoTime.Unix())*1e9 + float64(maxDatetimeInGoTime.Nanosecond()) - minDatetimeMonths = float64(types.MinDatetime.Year()*12 + types.MinDatetime.Month() - 1) // 0001-01-01 00:00:00 - maxDatetimeMonths = float64(types.MaxDatetime.Year()*12 + types.MaxDatetime.Month() - 1) // 9999-12-31 00:00:00 -) - -func validAddTime(nano1 float64, nano2 float64) bool { - return nano1+nano2 >= minDatetimeNanos && nano1+nano2 <= maxDatetimeNanos -} - -func validAddMonth(month1 float64, year, month int) bool { - tmp := month1 + float64(year)*12 + float64(month-1) - return tmp >= minDatetimeMonths && tmp <= maxDatetimeMonths -} - -func addUnitToTime(unit string, t time.Time, v float64) (time.Time, bool, error) { - s := math.Trunc(v * 1000000) - // round to the nearest int - v = math.Round(v) - var tb time.Time - nano := float64(t.Unix())*1e9 + float64(t.Nanosecond()) - switch unit { - case "MICROSECOND": - if !validAddTime(v*float64(time.Microsecond), nano) { - return tb, true, nil - } - tb = t.Add(time.Duration(v) * time.Microsecond) - case "SECOND": - if !validAddTime(s*float64(time.Microsecond), nano) { - return tb, true, nil - } - tb = t.Add(time.Duration(s) * time.Microsecond) - case "MINUTE": - if !validAddTime(v*float64(time.Minute), nano) { - return tb, true, nil - } - tb = t.Add(time.Duration(v) * time.Minute) - case "HOUR": - if !validAddTime(v*float64(time.Hour), nano) { - return tb, true, nil - } - tb = t.Add(time.Duration(v) * time.Hour) - case "DAY": - if !validAddTime(v*24*float64(time.Hour), nano) { - return tb, true, nil - } - tb = t.AddDate(0, 0, int(v)) - case "WEEK": - if !validAddTime(v*24*7*float64(time.Hour), nano) { - 
return tb, true, nil - } - tb = t.AddDate(0, 0, 7*int(v)) - case "MONTH": - if !validAddMonth(v, t.Year(), int(t.Month())) { - return tb, true, nil - } - - var err error - tb, err = types.AddDate(0, int64(v), 0, t) - if err != nil { - return tb, false, err - } - case "QUARTER": - if !validAddMonth(v*3, t.Year(), int(t.Month())) { - return tb, true, nil - } - tb = t.AddDate(0, 3*int(v), 0) - case "YEAR": - if !validAddMonth(v*12, t.Year(), int(t.Month())) { - return tb, true, nil - } - tb = t.AddDate(int(v), 0, 0) - default: - return tb, false, types.ErrWrongValue.GenWithStackByArgs(types.TimeStr, unit) - } - return tb, false, nil -} - -// evalString evals a builtinTimestampAddSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_timestampadd -func (b *builtinTimestampAddSig) evalString(ctx EvalContext, row chunk.Row) (string, bool, error) { - unit, isNull, err := b.args[0].EvalString(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - v, isNull, err := b.args[1].EvalReal(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - arg, isNull, err := b.args[2].EvalTime(ctx, row) - if isNull || err != nil { - return "", isNull, err - } - tm1, err := arg.GoTime(time.Local) - if err != nil { - tc := typeCtx(ctx) - tc.AppendWarning(err) - return "", true, nil - } - tb, overflow, err := addUnitToTime(unit, tm1, v) - if err != nil { - return "", true, err - } - if overflow { - return "", true, handleInvalidTimeError(ctx, types.ErrDatetimeFunctionOverflow.GenWithStackByArgs("datetime")) - } - fsp := types.DefaultFsp - // use MaxFsp when microsecond is not zero - if tb.Nanosecond()/1000 != 0 { - fsp = types.MaxFsp - } - r := types.NewTime(types.FromGoTime(tb), b.resolveType(arg.Type(), unit), fsp) - if err = r.Check(typeCtx(ctx)); err != nil { - return "", true, handleInvalidTimeError(ctx, err) - } - return r.String(), false, nil -} - -func (b *builtinTimestampAddSig) resolveType(typ uint8, unit string) uint8 { - // The approach below is from MySQL. - // The field type for the result of an Item_date function is defined as - // follows: - // - //- If first arg is a MYSQL_TYPE_DATETIME result is MYSQL_TYPE_DATETIME - //- If first arg is a MYSQL_TYPE_DATE and the interval type uses hours, - // minutes, seconds or microsecond then type is MYSQL_TYPE_DATETIME. - //- Otherwise the result is MYSQL_TYPE_STRING - // (This is because you can't know if the string contains a DATE, MYSQL_TIME - // or DATETIME argument) - if typ == mysql.TypeDate && (unit == "HOUR" || unit == "MINUTE" || unit == "SECOND" || unit == "MICROSECOND") { - return mysql.TypeDatetime - } - return typ -} - -type toDaysFunctionClass struct { - baseFunctionClass -} - -func (c *toDaysFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDatetime) - if err != nil { - return nil, err - } - sig := &builtinToDaysSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_ToDays) - return sig, nil -} - -type builtinToDaysSig struct { - baseBuiltinFunc -} - -func (b *builtinToDaysSig) Clone() builtinFunc { - newSig := &builtinToDaysSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals a builtinToDaysSig. 
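In addUnitToTime above, sub-day units turn into a fixed time.Duration while DAY and larger go through calendar-aware AddDate. The trimmed-down sketch below shows that split only; the MIN/MAX datetime range checks, fractional-value rounding, and the MONTH overflow handling are deliberately omitted, and addUnit is an illustrative name.

package main

import (
	"fmt"
	"time"
)

// addUnit is a reduction of addUnitToTime: fixed-length units use Add,
// calendar units use AddDate.
func addUnit(t time.Time, unit string, v int) (time.Time, error) {
	switch unit {
	case "MICROSECOND":
		return t.Add(time.Duration(v) * time.Microsecond), nil
	case "SECOND":
		return t.Add(time.Duration(v) * time.Second), nil
	case "MINUTE":
		return t.Add(time.Duration(v) * time.Minute), nil
	case "HOUR":
		return t.Add(time.Duration(v) * time.Hour), nil
	case "DAY":
		return t.AddDate(0, 0, v), nil
	case "WEEK":
		return t.AddDate(0, 0, 7*v), nil
	case "MONTH":
		return t.AddDate(0, v, 0), nil
	case "QUARTER":
		return t.AddDate(0, 3*v, 0), nil
	case "YEAR":
		return t.AddDate(v, 0, 0), nil
	default:
		return t, fmt.Errorf("unknown unit %q", unit)
	}
}

func main() {
	base := time.Date(2024, 8, 7, 10, 0, 0, 0, time.UTC)
	if got, err := addUnit(base, "QUARTER", 1); err == nil {
		fmt.Println(got.Format("2006-01-02 15:04:05")) // 2024-11-07 10:00:00
	}
}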
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_to-days -func (b *builtinToDaysSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - arg, isNull, err := b.args[0].EvalTime(ctx, row) - - if isNull || err != nil { - return 0, true, handleInvalidTimeError(ctx, err) - } - if arg.InvalidZero() { - return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, arg.String())) - } - ret := types.TimestampDiff("DAY", types.ZeroDate, arg) - if ret == 0 { - return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, arg.String())) - } - return ret, false, nil -} - -type toSecondsFunctionClass struct { - baseFunctionClass -} - -func (c *toSecondsFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETDatetime) - if err != nil { - return nil, err - } - sig := &builtinToSecondsSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_ToSeconds) - return sig, nil -} - -type builtinToSecondsSig struct { - baseBuiltinFunc -} - -func (b *builtinToSecondsSig) Clone() builtinFunc { - newSig := &builtinToSecondsSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalInt evals a builtinToSecondsSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_to-seconds -func (b *builtinToSecondsSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - arg, isNull, err := b.args[0].EvalTime(ctx, row) - if isNull || err != nil { - return 0, true, handleInvalidTimeError(ctx, err) - } - if arg.InvalidZero() { - return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, arg.String())) - } - ret := types.TimestampDiff("SECOND", types.ZeroDate, arg) - if ret == 0 { - return 0, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, arg.String())) - } - return ret, false, nil -} - -type utcTimeFunctionClass struct { - baseFunctionClass -} - -func (c *utcTimeFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - argTps := make([]types.EvalType, 0, 1) - if len(args) == 1 { - argTps = append(argTps, types.ETInt) - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDuration, argTps...) - if err != nil { - return nil, err - } - fsp, err := getFspByIntArg(ctx, args) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForTime(fsp) - // 1. no sign. - // 2. hour is in the 2-digit range. - bf.tp.SetFlen(bf.tp.GetFlen() - 2) - - var sig builtinFunc - if len(args) == 1 { - sig = &builtinUTCTimeWithArgSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_UTCTimeWithArg) - } else { - sig = &builtinUTCTimeWithoutArgSig{bf} - sig.setPbCode(tipb.ScalarFuncSig_UTCTimeWithoutArg) - } - return sig, nil -} - -type builtinUTCTimeWithoutArgSig struct { - baseBuiltinFunc -} - -func (b *builtinUTCTimeWithoutArgSig) Clone() builtinFunc { - newSig := &builtinUTCTimeWithoutArgSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalDuration evals a builtinUTCTimeWithoutArgSig. 
-// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_utc-time -func (b *builtinUTCTimeWithoutArgSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { - nowTs, err := getStmtTimestamp(ctx) - if err != nil { - return types.Duration{}, true, err - } - v, _, err := types.ParseDuration(typeCtx(ctx), nowTs.UTC().Format(types.TimeFormat), 0) - return v, false, err -} - -type builtinUTCTimeWithArgSig struct { - baseBuiltinFunc -} - -func (b *builtinUTCTimeWithArgSig) Clone() builtinFunc { - newSig := &builtinUTCTimeWithArgSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalDuration evals a builtinUTCTimeWithArgSig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_utc-time -func (b *builtinUTCTimeWithArgSig) evalDuration(ctx EvalContext, row chunk.Row) (types.Duration, bool, error) { - fsp, isNull, err := b.args[0].EvalInt(ctx, row) - if isNull || err != nil { - return types.Duration{}, isNull, err - } - if fsp > int64(types.MaxFsp) { - return types.Duration{}, true, errors.Errorf("Too-big precision %v specified for 'utc_time'. Maximum is %v", fsp, types.MaxFsp) - } - if fsp < int64(types.MinFsp) { - return types.Duration{}, true, errors.Errorf("Invalid negative %d specified, must in [0, 6]", fsp) - } - nowTs, err := getStmtTimestamp(ctx) - if err != nil { - return types.Duration{}, true, err - } - v, _, err := types.ParseDuration(typeCtx(ctx), nowTs.UTC().Format(types.TimeFSPFormat), int(fsp)) - return v, false, err -} - -type lastDayFunctionClass struct { - baseFunctionClass -} - -func (c *lastDayFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, types.ETDatetime) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForDate() - sig := &builtinLastDaySig{bf} - sig.setPbCode(tipb.ScalarFuncSig_LastDay) - return sig, nil -} - -type builtinLastDaySig struct { - baseBuiltinFunc -} - -func (b *builtinLastDaySig) Clone() builtinFunc { - newSig := &builtinLastDaySig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals a builtinLastDaySig. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_last-day -func (b *builtinLastDaySig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - arg, isNull, err := b.args[0].EvalTime(ctx, row) - if isNull || err != nil { - return types.ZeroTime, true, handleInvalidTimeError(ctx, err) - } - tm := arg - year, month := tm.Year(), tm.Month() - if tm.Month() == 0 || (tm.Day() == 0 && sqlMode(ctx).HasNoZeroDateMode()) { - return types.ZeroTime, true, handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, arg.String())) - } - lastDay := types.GetLastDay(year, month) - ret := types.NewTime(types.FromDate(year, month, lastDay, 0, 0, 0, 0), mysql.TypeDate, types.DefaultFsp) - return ret, false, nil -} - -// getExpressionFsp calculates the fsp from given expression. -// This function must by called before calling newBaseBuiltinFuncWithTp. 
-func getExpressionFsp(ctx BuildContext, expression Expression) (int, error) { - constExp, isConstant := expression.(*Constant) - if isConstant { - str, isNil, err := constExp.EvalString(ctx.GetEvalCtx(), chunk.Row{}) - if isNil || err != nil { - return 0, err - } - return types.GetFsp(str), nil - } - warpExpr := WrapWithCastAsTime(ctx, expression, types.NewFieldType(mysql.TypeDatetime)) - return mathutil.Min(warpExpr.GetType(ctx.GetEvalCtx()).GetDecimal(), types.MaxFsp), nil -} - -// tidbParseTsoFunctionClass extracts physical time from a tso -type tidbParseTsoFunctionClass struct { - baseFunctionClass -} - -func (c *tidbParseTsoFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - argTp := args[0].GetType(ctx.GetEvalCtx()).EvalType() - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, argTp, types.ETInt) - if err != nil { - return nil, err - } - - bf.tp.SetType(mysql.TypeDatetime) - bf.tp.SetFlen(mysql.MaxDateWidth) - bf.tp.SetDecimal(types.DefaultFsp) - sig := &builtinTidbParseTsoSig{bf} - return sig, nil -} - -type builtinTidbParseTsoSig struct { - baseBuiltinFunc -} - -func (b *builtinTidbParseTsoSig) Clone() builtinFunc { - newSig := &builtinTidbParseTsoSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals a builtinTidbParseTsoSig. -func (b *builtinTidbParseTsoSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - arg, isNull, err := b.args[0].EvalInt(ctx, row) - if isNull || err != nil || arg <= 0 { - return types.ZeroTime, true, handleInvalidTimeError(ctx, err) - } - - t := oracle.GetTimeFromTS(uint64(arg)) - result := types.NewTime(types.FromGoTime(t), mysql.TypeDatetime, types.MaxFsp) - err = result.ConvertTimeZone(time.Local, location(ctx)) - if err != nil { - return types.ZeroTime, true, err - } - return result, false, nil -} - -// tidbParseTsoFunctionClass extracts logical time from a tso -type tidbParseTsoLogicalFunctionClass struct { - baseFunctionClass -} - -func (c *tidbParseTsoLogicalFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt, types.ETInt) - if err != nil { - return nil, err - } - - sig := &builtinTidbParseTsoLogicalSig{bf} - return sig, nil -} - -type builtinTidbParseTsoLogicalSig struct { - baseBuiltinFunc -} - -func (b *builtinTidbParseTsoLogicalSig) Clone() builtinFunc { - newSig := &builtinTidbParseTsoLogicalSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -// evalTime evals a builtinTidbParseTsoLogicalSig. -func (b *builtinTidbParseTsoLogicalSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, error) { - arg, isNull, err := b.args[0].EvalInt(ctx, row) - if isNull || err != nil || arg <= 0 { - return 0, true, err - } - - t := oracle.ExtractLogical(uint64(arg)) - return t, false, nil -} - -// tidbBoundedStalenessFunctionClass reads a time window [a, b] and compares it with the latest SafeTS -// to determine which TS to use in a read only transaction. 
-type tidbBoundedStalenessFunctionClass struct { - baseFunctionClass -} - -func (c *tidbBoundedStalenessFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETDatetime, types.ETDatetime, types.ETDatetime) - if err != nil { - return nil, err - } - bf.setDecimalAndFlenForDatetime(3) - sig := &builtinTiDBBoundedStalenessSig{baseBuiltinFunc: bf} - return sig, nil -} - -type builtinTiDBBoundedStalenessSig struct { - baseBuiltinFunc - contextopt.SessionVarsPropReader - contextopt.KVStorePropReader -} - -// RequiredOptionalEvalProps implements the RequireOptionalEvalProps interface. -func (b *builtinTiDBBoundedStalenessSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet { - return b.SessionVarsPropReader.RequiredOptionalEvalProps() | - b.KVStorePropReader.RequiredOptionalEvalProps() -} - -func (b *builtinTiDBBoundedStalenessSig) Clone() builtinFunc { - newSig := &builtinTidbParseTsoSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (b *builtinTiDBBoundedStalenessSig) evalTime(ctx EvalContext, row chunk.Row) (types.Time, bool, error) { - store, err := b.GetKVStore(ctx) - if err != nil { - return types.ZeroTime, true, err - } - - vars, err := b.GetSessionVars(ctx) - if err != nil { - return types.ZeroTime, true, err - } - - leftTime, isNull, err := b.args[0].EvalTime(ctx, row) - if isNull || err != nil { - return types.ZeroTime, true, handleInvalidTimeError(ctx, err) - } - rightTime, isNull, err := b.args[1].EvalTime(ctx, row) - if isNull || err != nil { - return types.ZeroTime, true, handleInvalidTimeError(ctx, err) - } - if invalidLeftTime, invalidRightTime := leftTime.InvalidZero(), rightTime.InvalidZero(); invalidLeftTime || invalidRightTime { - if invalidLeftTime { - err = handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, leftTime.String())) - } - if invalidRightTime { - err = handleInvalidTimeError(ctx, types.ErrWrongValue.GenWithStackByArgs(types.DateTimeStr, rightTime.String())) - } - return types.ZeroTime, true, err - } - timeZone := getTimeZone(ctx) - minTime, err := leftTime.GoTime(timeZone) - if err != nil { - return types.ZeroTime, true, err - } - maxTime, err := rightTime.GoTime(timeZone) - if err != nil { - return types.ZeroTime, true, err - } - if minTime.After(maxTime) { - return types.ZeroTime, true, nil - } - // Because the minimum unit of a TSO is millisecond, so we only need fsp to be 3. - return types.NewTime(types.FromGoTime(calAppropriateTime(minTime, maxTime, GetStmtMinSafeTime(vars.StmtCtx, store, timeZone))), mysql.TypeDatetime, 3), false, nil -} - -// GetStmtMinSafeTime get minSafeTime -func GetStmtMinSafeTime(sc *stmtctx.StatementContext, store kv.Storage, timeZone *time.Location) time.Time { - var minSafeTS uint64 - txnScope := config.GetTxnScopeFromConfig() - if store != nil { - minSafeTS = store.GetMinSafeTS(txnScope) - } - // Inject mocked SafeTS for test. - failpoint.Inject("injectSafeTS", func(val failpoint.Value) { - injectTS := val.(int) - minSafeTS = uint64(injectTS) - }) - // Try to get from the stmt cache to make sure this function is deterministic. 
- minSafeTS = sc.GetOrStoreStmtCache(stmtctx.StmtSafeTSCacheKey, minSafeTS).(uint64) - return oracle.GetTimeFromTS(minSafeTS).In(timeZone) -} - -// CalAppropriateTime directly calls calAppropriateTime -func CalAppropriateTime(minTime, maxTime, minSafeTime time.Time) time.Time { - return calAppropriateTime(minTime, maxTime, minSafeTime) -} - -// For a SafeTS t and a time range [t1, t2]: -// 1. If t < t1, we will use t1 as the result, -// and with it, a read request may fail because it's an unreached SafeTS. -// 2. If t1 <= t <= t2, we will use t as the result, and with it, -// a read request won't fail. -// 2. If t2 < t, we will use t2 as the result, -// and with it, a read request won't fail because it's bigger than the latest SafeTS. -func calAppropriateTime(minTime, maxTime, minSafeTime time.Time) time.Time { - if minSafeTime.Before(minTime) || minSafeTime.After(maxTime) { - logutil.BgLogger().Debug("calAppropriateTime", - zap.Time("minTime", minTime), - zap.Time("maxTime", maxTime), - zap.Time("minSafeTime", minSafeTime)) - if minSafeTime.Before(minTime) { - return minTime - } else if minSafeTime.After(maxTime) { - return maxTime - } - } - logutil.BgLogger().Debug("calAppropriateTime", - zap.Time("minTime", minTime), - zap.Time("maxTime", maxTime), - zap.Time("minSafeTime", minSafeTime)) - return minSafeTime -} - -// getFspByIntArg is used by some time functions to get the result fsp. If len(expr) == 0, then the fsp is not explicit set, use 0 as default. -func getFspByIntArg(ctx BuildContext, exps []Expression) (int, error) { - if len(exps) == 0 { - return 0, nil - } - if len(exps) != 1 { - return 0, errors.Errorf("Should not happen, the num of argument should be 1, but got %d", len(exps)) - } - _, ok := exps[0].(*Constant) - if ok { - fsp, isNuLL, err := exps[0].EvalInt(ctx.GetEvalCtx(), chunk.Row{}) - if err != nil || isNuLL { - // If isNULL, it may be a bug of parser. Return 0 to be compatible with old version. - return 0, err - } - if fsp > int64(types.MaxFsp) { - return 0, errors.Errorf("Too-big precision %v specified for 'curtime'. Maximum is %v", fsp, types.MaxFsp) - } else if fsp < int64(types.MinFsp) { - return 0, errors.Errorf("Invalid negative %d specified, must in [0, 6]", fsp) - } - return int(fsp), nil - } - // Should no happen. But our tests may generate non-constant input. - return 0, nil -} - -type tidbCurrentTsoFunctionClass struct { - baseFunctionClass -} - -func (c *tidbCurrentTsoFunctionClass) getFunction(ctx BuildContext, args []Expression) (builtinFunc, error) { - if err := c.verifyArgs(args); err != nil { - return nil, err - } - bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETInt) - if err != nil { - return nil, err - } - sig := &builtinTiDBCurrentTsoSig{baseBuiltinFunc: bf} - return sig, nil -} - -type builtinTiDBCurrentTsoSig struct { - baseBuiltinFunc - contextopt.SessionVarsPropReader -} - -func (b *builtinTiDBCurrentTsoSig) Clone() builtinFunc { - newSig := &builtinTiDBCurrentTsoSig{} - newSig.cloneFrom(&b.baseBuiltinFunc) - return newSig -} - -func (b *builtinTiDBCurrentTsoSig) RequiredOptionalEvalProps() OptionalEvalPropKeySet { - return b.SessionVarsPropReader.RequiredOptionalEvalProps() -} - -// evalInt evals currentTSO(). 
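// Illustration (not part of the original file): how the clamping in calAppropriateTime
// behaves for a [10:00, 10:05] read window, using the exported CalAppropriateTime wrapper
// defined above. The helper name below is hypothetical.
func exampleCalAppropriateTime() {
    low := time.Date(2024, 8, 7, 10, 0, 0, 0, time.UTC)
    high := low.Add(5 * time.Minute)
    _ = CalAppropriateTime(low, high, low.Add(-2*time.Minute)) // SafeTS below the window: clamped up to 10:00 (the read may still fail)
    _ = CalAppropriateTime(low, high, low.Add(3*time.Minute))  // SafeTS inside the window: 10:03 is used as is
    _ = CalAppropriateTime(low, high, low.Add(9*time.Minute))  // SafeTS above the window: clamped down to 10:05
}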
-func (b *builtinTiDBCurrentTsoSig) evalInt(ctx EvalContext, row chunk.Row) (val int64, isNull bool, err error) { - sessionVars, err := b.GetSessionVars(ctx) - if err != nil { - return 0, true, err - } - tso, _ := sessionVars.GetSessionOrGlobalSystemVar(context.Background(), "tidb_current_ts") - itso, _ := strconv.ParseInt(tso, 10, 64) - return itso, false, nil -} diff --git a/pkg/expression/expr_to_pb.go b/pkg/expression/expr_to_pb.go index c4b853f85919e..4c74eedb5e441 100644 --- a/pkg/expression/expr_to_pb.go +++ b/pkg/expression/expr_to_pb.go @@ -250,9 +250,9 @@ func (pc PbConverter) scalarFuncToPBExpr(expr *ScalarFunction) *tipb.Expr { // Check whether this function has ProtoBuf signature. pbCode := expr.Function.PbCode() if pbCode <= tipb.ScalarFuncSig_Unspecified { - if _, _err_ := failpoint.Eval(_curpkg_("PanicIfPbCodeUnspecified")); _err_ == nil { + failpoint.Inject("PanicIfPbCodeUnspecified", func() { panic(errors.Errorf("unspecified PbCode: %T", expr.Function)) - } + }) return nil } diff --git a/pkg/expression/expr_to_pb.go__failpoint_stash__ b/pkg/expression/expr_to_pb.go__failpoint_stash__ deleted file mode 100644 index 4c74eedb5e441..0000000000000 --- a/pkg/expression/expr_to_pb.go__failpoint_stash__ +++ /dev/null @@ -1,319 +0,0 @@ -// Copyright 2016 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package expression - -import ( - "strconv" - - "github.com/gogo/protobuf/proto" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/mysql" - ast "github.com/pingcap/tidb/pkg/parser/types" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/collate" - "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tipb/go-tipb" - "go.uber.org/zap" -) - -// ExpressionsToPBList converts expressions to tipb.Expr list for new plan. -func ExpressionsToPBList(ctx EvalContext, exprs []Expression, client kv.Client) (pbExpr []*tipb.Expr, err error) { - pc := PbConverter{client: client, ctx: ctx} - for _, expr := range exprs { - v := pc.ExprToPB(expr) - if v == nil { - return nil, plannererrors.ErrInternal.GenWithStack("expression %v cannot be pushed down", expr.StringWithCtx(ctx, errors.RedactLogDisable)) - } - pbExpr = append(pbExpr, v) - } - return -} - -// ProjectionExpressionsToPBList converts PhysicalProjection's expressions to tipb.Expr list for new plan. 
-// It doesn't check type for top level column expression, since top level column expression doesn't imply any calculations -func ProjectionExpressionsToPBList(ctx EvalContext, exprs []Expression, client kv.Client) (pbExpr []*tipb.Expr, err error) { - pc := PbConverter{client: client, ctx: ctx} - for _, expr := range exprs { - var v *tipb.Expr - if column, ok := expr.(*Column); ok { - v = pc.columnToPBExpr(column, false) - } else { - v = pc.ExprToPB(expr) - } - if v == nil { - return nil, plannererrors.ErrInternal.GenWithStack("expression %v cannot be pushed down", expr.StringWithCtx(ctx, errors.RedactLogDisable)) - } - pbExpr = append(pbExpr, v) - } - return -} - -// PbConverter supplies methods to convert TiDB expressions to TiPB. -type PbConverter struct { - client kv.Client - ctx EvalContext -} - -// NewPBConverter creates a PbConverter. -func NewPBConverter(client kv.Client, ctx EvalContext) PbConverter { - return PbConverter{client: client, ctx: ctx} -} - -// ExprToPB converts Expression to TiPB. -func (pc PbConverter) ExprToPB(expr Expression) *tipb.Expr { - switch x := expr.(type) { - case *Constant: - pbExpr := pc.conOrCorColToPBExpr(expr) - if pbExpr == nil { - return nil - } - return pbExpr - case *CorrelatedColumn: - return pc.conOrCorColToPBExpr(expr) - case *Column: - return pc.columnToPBExpr(x, true) - case *ScalarFunction: - return pc.scalarFuncToPBExpr(x) - } - return nil -} - -func (pc PbConverter) conOrCorColToPBExpr(expr Expression) *tipb.Expr { - ft := expr.GetType(pc.ctx) - d, err := expr.Eval(pc.ctx, chunk.Row{}) - if err != nil { - logutil.BgLogger().Error("eval constant or correlated column", zap.String("expression", expr.ExplainInfo(pc.ctx)), zap.Error(err)) - return nil - } - tp, val, ok := pc.encodeDatum(ft, d) - if !ok { - return nil - } - - if !pc.client.IsRequestTypeSupported(kv.ReqTypeSelect, int64(tp)) { - return nil - } - return &tipb.Expr{Tp: tp, Val: val, FieldType: ToPBFieldType(ft)} -} - -func (pc *PbConverter) encodeDatum(ft *types.FieldType, d types.Datum) (tipb.ExprType, []byte, bool) { - var ( - tp tipb.ExprType - val []byte - ) - switch d.Kind() { - case types.KindNull: - tp = tipb.ExprType_Null - case types.KindInt64: - tp = tipb.ExprType_Int64 - val = codec.EncodeInt(nil, d.GetInt64()) - case types.KindUint64: - tp = tipb.ExprType_Uint64 - val = codec.EncodeUint(nil, d.GetUint64()) - case types.KindString, types.KindBinaryLiteral: - tp = tipb.ExprType_String - val = d.GetBytes() - case types.KindMysqlBit: - tp = tipb.ExprType_MysqlBit - val = d.GetBytes() - case types.KindBytes: - tp = tipb.ExprType_Bytes - val = d.GetBytes() - case types.KindFloat32: - tp = tipb.ExprType_Float32 - val = codec.EncodeFloat(nil, d.GetFloat64()) - case types.KindFloat64: - tp = tipb.ExprType_Float64 - val = codec.EncodeFloat(nil, d.GetFloat64()) - case types.KindMysqlDuration: - tp = tipb.ExprType_MysqlDuration - val = codec.EncodeInt(nil, int64(d.GetMysqlDuration().Duration)) - case types.KindMysqlDecimal: - tp = tipb.ExprType_MysqlDecimal - var err error - // Use precision and fraction from MyDecimal instead of the ones in datum itself. - // These two set of parameters are not the same. MyDecimal is compatible with MySQL - // so the precision and fraction from MyDecimal are consistent with MySQL. The other - // ones come from the column type which belongs to the output schema. Here the datum - // are encoded into protobuf and will be used to do calculation so it should use the - // MyDecimal precision and fraction otherwise there may be a loss of accuracy. 
- val, err = codec.EncodeDecimal(nil, d.GetMysqlDecimal(), 0, 0) - if err != nil { - logutil.BgLogger().Error("encode decimal", zap.Error(err)) - return tp, nil, false - } - case types.KindMysqlTime: - if pc.client.IsRequestTypeSupported(kv.ReqTypeDAG, int64(tipb.ExprType_MysqlTime)) { - tp = tipb.ExprType_MysqlTime - tc, ec := typeCtx(pc.ctx), errCtx(pc.ctx) - val, err := codec.EncodeMySQLTime(tc.Location(), d.GetMysqlTime(), ft.GetType(), nil) - err = ec.HandleError(err) - if err != nil { - logutil.BgLogger().Error("encode mysql time", zap.Error(err)) - return tp, nil, false - } - return tp, val, true - } - return tp, nil, false - case types.KindMysqlEnum: - tp = tipb.ExprType_MysqlEnum - val = codec.EncodeUint(nil, d.GetUint64()) - default: - return tp, nil, false - } - return tp, val, true -} - -// ToPBFieldType converts *types.FieldType to *tipb.FieldType. -func ToPBFieldType(ft *types.FieldType) *tipb.FieldType { - return &tipb.FieldType{ - Tp: int32(ft.GetType()), - Flag: uint32(ft.GetFlag()), - Flen: int32(ft.GetFlen()), - Decimal: int32(ft.GetDecimal()), - Charset: ft.GetCharset(), - Collate: collate.CollationToProto(ft.GetCollate()), - Elems: ft.GetElems(), - } -} - -// ToPBFieldTypeWithCheck converts *types.FieldType to *tipb.FieldType with checking the valid decimal for TiFlash -func ToPBFieldTypeWithCheck(ft *types.FieldType, storeType kv.StoreType) (*tipb.FieldType, error) { - if storeType == kv.TiFlash && !ft.IsDecimalValid() { - return nil, errors.New(ft.String() + " can not be pushed to TiFlash because it contains invalid decimal('" + strconv.Itoa(ft.GetFlen()) + "','" + strconv.Itoa(ft.GetDecimal()) + "').") - } - return ToPBFieldType(ft), nil -} - -// FieldTypeFromPB converts *tipb.FieldType to *types.FieldType. -func FieldTypeFromPB(ft *tipb.FieldType) *types.FieldType { - ft1 := types.NewFieldTypeBuilder().SetType(byte(ft.Tp)).SetFlag(uint(ft.Flag)).SetFlen(int(ft.Flen)).SetDecimal(int(ft.Decimal)).SetCharset(ft.Charset).SetCollate(collate.ProtoToCollation(ft.Collate)).BuildP() - ft1.SetElems(ft.Elems) - return ft1 -} - -func (pc PbConverter) columnToPBExpr(column *Column, checkType bool) *tipb.Expr { - if !pc.client.IsRequestTypeSupported(kv.ReqTypeSelect, int64(tipb.ExprType_ColumnRef)) { - return nil - } - if checkType { - switch column.GetType(pc.ctx).GetType() { - case mysql.TypeBit: - if !IsPushDownEnabled(ast.TypeStr(mysql.TypeBit), kv.TiKV) { - return nil - } - case mysql.TypeSet, mysql.TypeGeometry, mysql.TypeUnspecified: - return nil - case mysql.TypeEnum: - if !IsPushDownEnabled("enum", kv.UnSpecified) { - return nil - } - } - } - - if pc.client.IsRequestTypeSupported(kv.ReqTypeDAG, kv.ReqSubTypeBasic) { - return &tipb.Expr{ - Tp: tipb.ExprType_ColumnRef, - Val: codec.EncodeInt(nil, int64(column.Index)), - FieldType: ToPBFieldType(column.RetType), - } - } - id := column.ID - // Zero Column ID is not a column from table, can not support for now. - if id == 0 || id == -1 { - return nil - } - - return &tipb.Expr{ - Tp: tipb.ExprType_ColumnRef, - Val: codec.EncodeInt(nil, id)} -} - -func (pc PbConverter) scalarFuncToPBExpr(expr *ScalarFunction) *tipb.Expr { - // Check whether this function has ProtoBuf signature. - pbCode := expr.Function.PbCode() - if pbCode <= tipb.ScalarFuncSig_Unspecified { - failpoint.Inject("PanicIfPbCodeUnspecified", func() { - panic(errors.Errorf("unspecified PbCode: %T", expr.Function)) - }) - return nil - } - - // Check whether this function can be pushed. 
- if !canFuncBePushed(pc.ctx, expr, kv.UnSpecified) { - return nil - } - - // Check whether all of its parameters can be pushed. - children := make([]*tipb.Expr, 0, len(expr.GetArgs())) - for _, arg := range expr.GetArgs() { - pbArg := pc.ExprToPB(arg) - if pbArg == nil { - return nil - } - children = append(children, pbArg) - } - - var encoded []byte - if metadata := expr.Function.metadata(); metadata != nil { - var err error - encoded, err = proto.Marshal(metadata) - if err != nil { - logutil.BgLogger().Error("encode metadata", zap.Any("metadata", metadata), zap.Error(err)) - return nil - } - } - - // put collation information into the RetType enforcedly and push it down to TiKV/MockTiKV - tp := *expr.RetType - if collate.NewCollationEnabled() { - _, str1 := expr.CharsetAndCollation() - tp.SetCollate(str1) - } - - // Construct expression ProtoBuf. - return &tipb.Expr{ - Tp: tipb.ExprType_ScalarFunc, - Val: encoded, - Sig: pbCode, - Children: children, - FieldType: ToPBFieldType(&tp), - } -} - -// GroupByItemToPB converts group by items to pb. -func GroupByItemToPB(ctx EvalContext, client kv.Client, expr Expression) *tipb.ByItem { - pc := PbConverter{client: client, ctx: ctx} - e := pc.ExprToPB(expr) - if e == nil { - return nil - } - return &tipb.ByItem{Expr: e} -} - -// SortByItemToPB converts order by items to pb. -func SortByItemToPB(ctx EvalContext, client kv.Client, expr Expression, desc bool) *tipb.ByItem { - pc := PbConverter{client: client, ctx: ctx} - e := pc.ExprToPB(expr) - if e == nil { - return nil - } - return &tipb.ByItem{Expr: e, Desc: desc} -} diff --git a/pkg/expression/helper.go b/pkg/expression/helper.go index 26e4c11443098..154d599142d95 100644 --- a/pkg/expression/helper.go +++ b/pkg/expression/helper.go @@ -162,9 +162,9 @@ func GetTimeValue(ctx BuildContext, v any, tp byte, fsp int, explicitTz *time.Lo // if timestamp session variable set, use session variable as current time, otherwise use cached time // during one sql statement, the "current_time" should be the same func getStmtTimestamp(ctx EvalContext) (time.Time, error) { - if val, _err_ := failpoint.Eval(_curpkg_("injectNow")); _err_ == nil { + failpoint.Inject("injectNow", func(val failpoint.Value) { v := time.Unix(int64(val.(int)), 0) - return v, nil - } + failpoint.Return(v, nil) + }) return ctx.CurrentTime() } diff --git a/pkg/expression/helper.go__failpoint_stash__ b/pkg/expression/helper.go__failpoint_stash__ deleted file mode 100644 index 154d599142d95..0000000000000 --- a/pkg/expression/helper.go__failpoint_stash__ +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright 2016 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package expression - -import ( - "math" - "strings" - "time" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/types" - driver "github.com/pingcap/tidb/pkg/types/parser_driver" -) - -func boolToInt64(v bool) int64 { - if v { - return 1 - } - return 0 -} - -// IsValidCurrentTimestampExpr returns true if exprNode is a valid CurrentTimestamp expression. -// Here `valid` means it is consistent with the given fieldType's decimal. -func IsValidCurrentTimestampExpr(exprNode ast.ExprNode, fieldType *types.FieldType) bool { - fn, isFuncCall := exprNode.(*ast.FuncCallExpr) - if !isFuncCall || fn.FnName.L != ast.CurrentTimestamp { - return false - } - - containsArg := len(fn.Args) > 0 - // Fsp represents fractional seconds precision. - containsFsp := fieldType != nil && fieldType.GetDecimal() > 0 - var isConsistent bool - if containsArg { - v, ok := fn.Args[0].(*driver.ValueExpr) - isConsistent = ok && fieldType != nil && v.Datum.GetInt64() == int64(fieldType.GetDecimal()) - } - - return (containsArg && isConsistent) || (!containsArg && !containsFsp) -} - -// GetTimeCurrentTimestamp is used for generating a timestamp for some special cases: cast null value to timestamp type with not null flag. -func GetTimeCurrentTimestamp(ctx EvalContext, tp byte, fsp int) (d types.Datum, err error) { - var t types.Time - t, err = getTimeCurrentTimeStamp(ctx, tp, fsp) - if err != nil { - return d, err - } - d.SetMysqlTime(t) - return d, nil -} - -func getTimeCurrentTimeStamp(ctx EvalContext, tp byte, fsp int) (t types.Time, err error) { - value := types.NewTime(types.ZeroCoreTime, tp, fsp) - defaultTime, err := getStmtTimestamp(ctx) - if err != nil { - return value, err - } - value.SetCoreTime(types.FromGoTime(defaultTime.Truncate(time.Duration(math.Pow10(9-fsp)) * time.Nanosecond))) - if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime || tp == mysql.TypeDate { - err = value.ConvertTimeZone(time.Local, ctx.Location()) - if err != nil { - return value, err - } - } - return value, nil -} - -// GetTimeValue gets the time value with type tp. 
-func GetTimeValue(ctx BuildContext, v any, tp byte, fsp int, explicitTz *time.Location) (d types.Datum, err error) { - var value types.Time - tc := ctx.GetEvalCtx().TypeCtx() - if explicitTz != nil { - tc = tc.WithLocation(explicitTz) - } - - switch x := v.(type) { - case string: - lowerX := strings.ToLower(x) - switch lowerX { - case ast.CurrentTimestamp: - if value, err = getTimeCurrentTimeStamp(ctx.GetEvalCtx(), tp, fsp); err != nil { - return d, err - } - case ast.CurrentDate: - if value, err = getTimeCurrentTimeStamp(ctx.GetEvalCtx(), tp, fsp); err != nil { - return d, err - } - yy, mm, dd := value.Year(), value.Month(), value.Day() - truncated := types.FromDate(yy, mm, dd, 0, 0, 0, 0) - value.SetCoreTime(truncated) - case types.ZeroDatetimeStr: - value, err = types.ParseTimeFromNum(tc, 0, tp, fsp) - terror.Log(err) - default: - value, err = types.ParseTime(tc, x, tp, fsp) - if err != nil { - return d, err - } - } - case *driver.ValueExpr: - switch x.Kind() { - case types.KindString: - value, err = types.ParseTime(tc, x.GetString(), tp, fsp) - if err != nil { - return d, err - } - case types.KindInt64: - value, err = types.ParseTimeFromNum(tc, x.GetInt64(), tp, fsp) - if err != nil { - return d, err - } - case types.KindNull: - return d, nil - default: - return d, errDefaultValue - } - case *ast.FuncCallExpr: - if x.FnName.L == ast.CurrentTimestamp || x.FnName.L == ast.CurrentDate { - d.SetString(strings.ToUpper(x.FnName.L), mysql.DefaultCollationName) - return d, nil - } - return d, errDefaultValue - case *ast.UnaryOperationExpr: - // support some expression, like `-1` - v, err := EvalSimpleAst(ctx, x) - if err != nil { - return d, err - } - ft := types.NewFieldType(mysql.TypeLonglong) - xval, err := v.ConvertTo(tc, ft) - if err != nil { - return d, err - } - - value, err = types.ParseTimeFromNum(tc, xval.GetInt64(), tp, fsp) - if err != nil { - return d, err - } - default: - return d, nil - } - d.SetMysqlTime(value) - return d, nil -} - -// if timestamp session variable set, use session variable as current time, otherwise use cached time -// during one sql statement, the "current_time" should be the same -func getStmtTimestamp(ctx EvalContext) (time.Time, error) { - failpoint.Inject("injectNow", func(val failpoint.Value) { - v := time.Unix(int64(val.(int)), 0) - failpoint.Return(v, nil) - }) - return ctx.CurrentTime() -} diff --git a/pkg/expression/infer_pushdown.go b/pkg/expression/infer_pushdown.go index 9ab3672594ad1..4af5c0e912a8d 100644 --- a/pkg/expression/infer_pushdown.go +++ b/pkg/expression/infer_pushdown.go @@ -47,19 +47,19 @@ func canFuncBePushed(ctx EvalContext, sf *ScalarFunction, storeType kv.StoreType // Push down all expression if the `failpoint expression` is `all`, otherwise, check // whether scalar function's name is contained in the enabled expression list (e.g.`ne,eq,lt`). // If neither of the above is true, switch to original logic. - if val, _err_ := failpoint.Eval(_curpkg_("PushDownTestSwitcher")); _err_ == nil { + failpoint.Inject("PushDownTestSwitcher", func(val failpoint.Value) { enabled := val.(string) if enabled == "all" { - return true + failpoint.Return(true) } exprs := strings.Split(enabled, ",") for _, expr := range exprs { if strings.ToLower(strings.TrimSpace(expr)) == sf.FuncName.L { - return true + failpoint.Return(true) } } - return false - } + failpoint.Return(false) + }) ret := false @@ -87,9 +87,9 @@ func canScalarFuncPushDown(ctx PushDownContext, scalarFunc *ScalarFunction, stor // Check whether this function can be pushed. 
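// Sketch (assumption, not part of this patch): enabling the failpoints shown above from a
// test, so getStmtTimestamp returns a fixed time and every scalar function is treated as
// pushable. The helper name and the "<package path>/<failpoint name>" naming are assumptions.
func enableExpressionFailpoints() (cleanup func(), err error) {
    const prefix = "github.com/pingcap/tidb/pkg/expression/"
    if err = failpoint.Enable(prefix+"injectNow", "return(400)"); err != nil {
        return nil, err
    }
    if err = failpoint.Enable(prefix+"PushDownTestSwitcher", `return("all")`); err != nil {
        _ = failpoint.Disable(prefix + "injectNow")
        return nil, err
    }
    return func() {
        _ = failpoint.Disable(prefix + "injectNow")
        _ = failpoint.Disable(prefix + "PushDownTestSwitcher")
    }, nil
}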
if unspecified := pbCode <= tipb.ScalarFuncSig_Unspecified; unspecified || !canFuncBePushed(ctx.EvalCtx(), scalarFunc, storeType) { if unspecified { - if _, _err_ := failpoint.Eval(_curpkg_("PanicIfPbCodeUnspecified")); _err_ == nil { + failpoint.Inject("PanicIfPbCodeUnspecified", func() { panic(errors.Errorf("unspecified PbCode: %T", scalarFunc.Function)) - } + }) } storageName := storeType.Name() if storeType == kv.UnSpecified { diff --git a/pkg/expression/infer_pushdown.go__failpoint_stash__ b/pkg/expression/infer_pushdown.go__failpoint_stash__ deleted file mode 100644 index 4af5c0e912a8d..0000000000000 --- a/pkg/expression/infer_pushdown.go__failpoint_stash__ +++ /dev/null @@ -1,536 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package expression - -import ( - "fmt" - "strconv" - "strings" - "sync/atomic" - - "github.com/gogo/protobuf/proto" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/charset" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/types" - contextutil "github.com/pingcap/tidb/pkg/util/context" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tipb/go-tipb" - "go.uber.org/zap" -) - -// DefaultExprPushDownBlacklist indicates the expressions which can not be pushed down to TiKV. -var DefaultExprPushDownBlacklist *atomic.Value - -// ExprPushDownBlackListReloadTimeStamp is used to record the last time when the push-down black list is reloaded. -// This is for plan cache, when the push-down black list is updated, we invalid all cached plans to avoid error. -var ExprPushDownBlackListReloadTimeStamp *atomic.Int64 - -func canFuncBePushed(ctx EvalContext, sf *ScalarFunction, storeType kv.StoreType) bool { - // Use the failpoint to control whether to push down an expression in the integration test. - // Push down all expression if the `failpoint expression` is `all`, otherwise, check - // whether scalar function's name is contained in the enabled expression list (e.g.`ne,eq,lt`). - // If neither of the above is true, switch to original logic. 
- failpoint.Inject("PushDownTestSwitcher", func(val failpoint.Value) { - enabled := val.(string) - if enabled == "all" { - failpoint.Return(true) - } - exprs := strings.Split(enabled, ",") - for _, expr := range exprs { - if strings.ToLower(strings.TrimSpace(expr)) == sf.FuncName.L { - failpoint.Return(true) - } - } - failpoint.Return(false) - }) - - ret := false - - switch storeType { - case kv.TiFlash: - ret = scalarExprSupportedByFlash(ctx, sf) - case kv.TiKV: - ret = scalarExprSupportedByTiKV(ctx, sf) - case kv.TiDB: - ret = scalarExprSupportedByTiDB(ctx, sf) - case kv.UnSpecified: - ret = scalarExprSupportedByTiDB(ctx, sf) || scalarExprSupportedByTiKV(ctx, sf) || scalarExprSupportedByFlash(ctx, sf) - } - - if ret { - funcFullName := fmt.Sprintf("%s.%s", sf.FuncName.L, strings.ToLower(sf.Function.PbCode().String())) - // Aside from checking function name, also check the pb name in case only the specific push down is disabled. - ret = IsPushDownEnabled(sf.FuncName.L, storeType) && IsPushDownEnabled(funcFullName, storeType) - } - return ret -} - -func canScalarFuncPushDown(ctx PushDownContext, scalarFunc *ScalarFunction, storeType kv.StoreType) bool { - pbCode := scalarFunc.Function.PbCode() - // Check whether this function can be pushed. - if unspecified := pbCode <= tipb.ScalarFuncSig_Unspecified; unspecified || !canFuncBePushed(ctx.EvalCtx(), scalarFunc, storeType) { - if unspecified { - failpoint.Inject("PanicIfPbCodeUnspecified", func() { - panic(errors.Errorf("unspecified PbCode: %T", scalarFunc.Function)) - }) - } - storageName := storeType.Name() - if storeType == kv.UnSpecified { - storageName = "storage layer" - } - warnErr := errors.NewNoStackError("Scalar function '" + scalarFunc.FuncName.L + "'(signature: " + scalarFunc.Function.PbCode().String() + ", return type: " + scalarFunc.RetType.CompactStr() + ") is not supported to push down to " + storageName + " now.") - - ctx.AppendWarning(warnErr) - return false - } - canEnumPush := canEnumPushdownPreliminarily(scalarFunc) - // Check whether all of its parameters can be pushed. 
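// Illustration of the dual check in canFuncBePushed above: a CAST from INT to DECIMAL has
// FuncName.L "cast" and PbCode "CastIntAsDecimal", so funcFullName is
// "cast.castintasdecimal"; both "cast" and "cast.castintasdecimal" must pass
// IsPushDownEnabled for that specific signature to be pushed down.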
- for _, arg := range scalarFunc.GetArgs() { - if !canExprPushDown(ctx, arg, storeType, canEnumPush) { - return false - } - } - - if metadata := scalarFunc.Function.metadata(); metadata != nil { - var err error - _, err = proto.Marshal(metadata) - if err != nil { - logutil.BgLogger().Error("encode metadata", zap.Any("metadata", metadata), zap.Error(err)) - return false - } - } - return true -} - -func canExprPushDown(ctx PushDownContext, expr Expression, storeType kv.StoreType, canEnumPush bool) bool { - pc := ctx.PbConverter() - if storeType == kv.TiFlash { - switch expr.GetType(ctx.EvalCtx()).GetType() { - case mysql.TypeEnum, mysql.TypeBit, mysql.TypeSet, mysql.TypeGeometry, mysql.TypeUnspecified: - if expr.GetType(ctx.EvalCtx()).GetType() == mysql.TypeEnum && canEnumPush { - break - } - warnErr := errors.NewNoStackError("Expression about '" + expr.StringWithCtx(ctx.EvalCtx(), errors.RedactLogDisable) + "' can not be pushed to TiFlash because it contains unsupported calculation of type '" + types.TypeStr(expr.GetType(ctx.EvalCtx()).GetType()) + "'.") - ctx.AppendWarning(warnErr) - return false - case mysql.TypeNewDecimal: - if !expr.GetType(ctx.EvalCtx()).IsDecimalValid() { - warnErr := errors.NewNoStackError("Expression about '" + expr.StringWithCtx(ctx.EvalCtx(), errors.RedactLogDisable) + "' can not be pushed to TiFlash because it contains invalid decimal('" + strconv.Itoa(expr.GetType(ctx.EvalCtx()).GetFlen()) + "','" + strconv.Itoa(expr.GetType(ctx.EvalCtx()).GetDecimal()) + "').") - ctx.AppendWarning(warnErr) - return false - } - } - } - switch x := expr.(type) { - case *CorrelatedColumn: - return pc.conOrCorColToPBExpr(expr) != nil && pc.columnToPBExpr(&x.Column, true) != nil - case *Constant: - return pc.conOrCorColToPBExpr(expr) != nil - case *Column: - return pc.columnToPBExpr(x, true) != nil - case *ScalarFunction: - return canScalarFuncPushDown(ctx, x, storeType) - } - return false -} - -func scalarExprSupportedByTiDB(ctx EvalContext, function *ScalarFunction) bool { - // TiDB can support all functions, but TiPB may not include some functions. - return scalarExprSupportedByTiKV(ctx, function) || scalarExprSupportedByFlash(ctx, function) -} - -// supported functions tracked by https://github.com/tikv/tikv/issues/5751 -func scalarExprSupportedByTiKV(ctx EvalContext, sf *ScalarFunction) bool { - switch sf.FuncName.L { - case - // op functions. - ast.LogicAnd, ast.LogicOr, ast.LogicXor, ast.UnaryNot, ast.And, ast.Or, ast.Xor, ast.BitNeg, ast.LeftShift, ast.RightShift, ast.UnaryMinus, - - // compare functions. - ast.LT, ast.LE, ast.EQ, ast.NE, ast.GE, ast.GT, ast.NullEQ, ast.In, ast.IsNull, ast.Like, ast.IsTruthWithoutNull, ast.IsTruthWithNull, ast.IsFalsity, - // ast.Greatest, ast.Least, ast.Interval - - // arithmetical functions. - ast.PI, /* ast.Truncate */ - ast.Plus, ast.Minus, ast.Mul, ast.Div, ast.Abs, ast.Mod, ast.IntDiv, - - // math functions. - ast.Ceil, ast.Ceiling, ast.Floor, ast.Sqrt, ast.Sign, ast.Ln, ast.Log, ast.Log2, ast.Log10, ast.Exp, ast.Pow, ast.Power, - - // Rust use the llvm math functions, which have different precision with Golang/MySQL(cmath) - // open the following switchers if we implement them in coprocessor via `cmath` - ast.Sin, ast.Asin, ast.Cos, ast.Acos /* ast.Tan */, ast.Atan, ast.Atan2, ast.Cot, - ast.Radians, ast.Degrees, ast.CRC32, - - // control flow functions. - ast.Case, ast.If, ast.Ifnull, ast.Coalesce, - - // string functions. 
- // ast.Bin, ast.Unhex, ast.Locate, ast.Ord, ast.Lpad, ast.Rpad, - // ast.Trim, ast.FromBase64, ast.ToBase64, ast.InsertFunc, - // ast.MakeSet, ast.SubstringIndex, ast.Instr, ast.Quote, ast.Oct, - // ast.FindInSet, ast.Repeat, - ast.Upper, ast.Lower, - ast.Length, ast.BitLength, ast.Concat, ast.ConcatWS, ast.Replace, ast.ASCII, ast.Hex, - ast.Reverse, ast.LTrim, ast.RTrim, ast.Strcmp, ast.Space, ast.Elt, ast.Field, - InternalFuncFromBinary, InternalFuncToBinary, ast.Mid, ast.Substring, ast.Substr, ast.CharLength, - ast.Right, /* ast.Left */ - - // json functions. - ast.JSONType, ast.JSONExtract, ast.JSONObject, ast.JSONArray, ast.JSONMerge, ast.JSONSet, - ast.JSONInsert, ast.JSONReplace, ast.JSONRemove, ast.JSONLength, ast.JSONMergePatch, - ast.JSONUnquote, ast.JSONContains, ast.JSONValid, ast.JSONMemberOf, ast.JSONArrayAppend, - - // date functions. - ast.Date, ast.Week /* ast.YearWeek, ast.ToSeconds */, ast.DateDiff, - /* ast.TimeDiff, ast.AddTime, ast.SubTime, */ - ast.MonthName, ast.MakeDate, ast.TimeToSec, ast.MakeTime, - ast.DateFormat, - ast.Hour, ast.Minute, ast.Second, ast.MicroSecond, ast.Month, - /* ast.DayName */ ast.DayOfMonth, ast.DayOfWeek, ast.DayOfYear, - /* ast.Weekday */ ast.WeekOfYear, ast.Year, - ast.FromDays, /* ast.ToDays */ - ast.PeriodAdd, ast.PeriodDiff, /*ast.TimestampDiff, ast.DateAdd, ast.FromUnixTime,*/ - /* ast.LastDay */ - ast.Sysdate, - - // encryption functions. - ast.MD5, ast.SHA1, ast.UncompressedLength, - - ast.Cast, - - // misc functions. - // TODO(#26942): enable functions below after them are fully tested in TiKV. - /*ast.InetNtoa, ast.InetAton, ast.Inet6Ntoa, ast.Inet6Aton, ast.IsIPv4, ast.IsIPv4Compat, ast.IsIPv4Mapped, ast.IsIPv6,*/ - ast.UUID: - - return true - // Rust use the llvm math functions, which have different precision with Golang/MySQL(cmath) - // open the following switchers if we implement them in coprocessor via `cmath` - case ast.Conv: - arg0 := sf.GetArgs()[0] - // To be aligned with MySQL, tidb handles hybrid type argument and binary literal specially, tikv can't be consistent with tidb now. 
- if f, ok := arg0.(*ScalarFunction); ok { - if f.FuncName.L == ast.Cast && (f.GetArgs()[0].GetType(ctx).Hybrid() || IsBinaryLiteral(f.GetArgs()[0])) { - return false - } - } - return true - case ast.Round: - switch sf.Function.PbCode() { - case tipb.ScalarFuncSig_RoundReal, tipb.ScalarFuncSig_RoundInt, tipb.ScalarFuncSig_RoundDec: - // We don't push round with frac due to mysql's round with frac has its special behavior: - // https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_round - return true - } - case ast.Rand: - switch sf.Function.PbCode() { - case tipb.ScalarFuncSig_RandWithSeedFirstGen: - return true - } - case ast.Regexp, ast.RegexpLike, ast.RegexpSubstr, ast.RegexpInStr, ast.RegexpReplace: - funcCharset, funcCollation := sf.Function.CharsetAndCollation() - if funcCharset == charset.CharsetBin && funcCollation == charset.CollationBin { - return false - } - return true - } - return false -} - -func scalarExprSupportedByFlash(ctx EvalContext, function *ScalarFunction) bool { - switch function.FuncName.L { - case ast.Floor, ast.Ceil, ast.Ceiling: - switch function.Function.PbCode() { - case tipb.ScalarFuncSig_FloorIntToDec, tipb.ScalarFuncSig_CeilIntToDec: - return false - default: - return true - } - case - ast.LogicOr, ast.LogicAnd, ast.UnaryNot, ast.BitNeg, ast.Xor, ast.And, ast.Or, ast.RightShift, ast.LeftShift, - ast.GE, ast.LE, ast.EQ, ast.NE, ast.LT, ast.GT, ast.In, ast.IsNull, ast.Like, ast.Ilike, ast.Strcmp, - ast.Plus, ast.Minus, ast.Div, ast.Mul, ast.Abs, ast.Mod, - ast.If, ast.Ifnull, ast.Case, - ast.Concat, ast.ConcatWS, - ast.Date, ast.Year, ast.Month, ast.Day, ast.Quarter, ast.DayName, ast.MonthName, - ast.DateDiff, ast.TimestampDiff, ast.DateFormat, ast.FromUnixTime, - ast.DayOfWeek, ast.DayOfMonth, ast.DayOfYear, ast.LastDay, ast.WeekOfYear, ast.ToSeconds, - ast.FromDays, ast.ToDays, - - ast.Sqrt, ast.Log, ast.Log2, ast.Log10, ast.Ln, ast.Exp, ast.Pow, ast.Power, ast.Sign, - ast.Radians, ast.Degrees, ast.Conv, ast.CRC32, - ast.JSONLength, ast.JSONDepth, ast.JSONExtract, ast.JSONUnquote, ast.JSONArray, ast.JSONContainsPath, ast.JSONValid, ast.JSONKeys, - ast.Repeat, ast.InetNtoa, ast.InetAton, ast.Inet6Ntoa, ast.Inet6Aton, - ast.Coalesce, ast.ASCII, ast.Length, ast.Trim, ast.Position, ast.Format, ast.Elt, - ast.LTrim, ast.RTrim, ast.Lpad, ast.Rpad, - ast.Hour, ast.Minute, ast.Second, ast.MicroSecond, - ast.TimeToSec: - switch function.Function.PbCode() { - case tipb.ScalarFuncSig_InDuration, - tipb.ScalarFuncSig_CoalesceDuration, - tipb.ScalarFuncSig_IfNullDuration, - tipb.ScalarFuncSig_IfDuration, - tipb.ScalarFuncSig_CaseWhenDuration: - return false - } - return true - case ast.Regexp, ast.RegexpLike, ast.RegexpInStr, ast.RegexpSubstr, ast.RegexpReplace: - funcCharset, funcCollation := function.Function.CharsetAndCollation() - if funcCharset == charset.CharsetBin && funcCollation == charset.CollationBin { - return false - } - return true - case ast.Substr, ast.Substring, ast.Left, ast.Right, ast.CharLength, ast.SubstringIndex, ast.Reverse: - switch function.Function.PbCode() { - case - tipb.ScalarFuncSig_LeftUTF8, - tipb.ScalarFuncSig_RightUTF8, - tipb.ScalarFuncSig_CharLengthUTF8, - tipb.ScalarFuncSig_Substring2ArgsUTF8, - tipb.ScalarFuncSig_Substring3ArgsUTF8, - tipb.ScalarFuncSig_SubstringIndex, - tipb.ScalarFuncSig_ReverseUTF8, - tipb.ScalarFuncSig_Reverse: - return true - } - case ast.Cast: - sourceType := function.GetArgs()[0].GetType(ctx) - retType := function.RetType - switch function.Function.PbCode() { - case 
tipb.ScalarFuncSig_CastDecimalAsInt, tipb.ScalarFuncSig_CastIntAsInt, tipb.ScalarFuncSig_CastRealAsInt, tipb.ScalarFuncSig_CastTimeAsInt, - tipb.ScalarFuncSig_CastStringAsInt /*, tipb.ScalarFuncSig_CastDurationAsInt, tipb.ScalarFuncSig_CastJsonAsInt*/ : - // TiFlash cast only support cast to Int64 or the source type is the same as the target type - return (sourceType.GetType() == retType.GetType() && mysql.HasUnsignedFlag(sourceType.GetFlag()) == mysql.HasUnsignedFlag(retType.GetFlag())) || retType.GetType() == mysql.TypeLonglong - case tipb.ScalarFuncSig_CastIntAsReal, tipb.ScalarFuncSig_CastRealAsReal, tipb.ScalarFuncSig_CastStringAsReal, tipb.ScalarFuncSig_CastTimeAsReal, tipb.ScalarFuncSig_CastDecimalAsReal: /* - tipb.ScalarFuncSig_CastDurationAsReal, tipb.ScalarFuncSig_CastJsonAsReal*/ - // TiFlash cast only support cast to Float64 or the source type is the same as the target type - return sourceType.GetType() == retType.GetType() || retType.GetType() == mysql.TypeDouble - case tipb.ScalarFuncSig_CastDecimalAsDecimal, tipb.ScalarFuncSig_CastIntAsDecimal, tipb.ScalarFuncSig_CastRealAsDecimal, tipb.ScalarFuncSig_CastTimeAsDecimal, - tipb.ScalarFuncSig_CastStringAsDecimal /*, tipb.ScalarFuncSig_CastDurationAsDecimal, tipb.ScalarFuncSig_CastJsonAsDecimal*/ : - return function.RetType.IsDecimalValid() - case tipb.ScalarFuncSig_CastDecimalAsString, tipb.ScalarFuncSig_CastIntAsString, tipb.ScalarFuncSig_CastRealAsString, tipb.ScalarFuncSig_CastTimeAsString, - tipb.ScalarFuncSig_CastStringAsString, tipb.ScalarFuncSig_CastJsonAsString /*, tipb.ScalarFuncSig_CastDurationAsString*/ : - return true - case tipb.ScalarFuncSig_CastDecimalAsTime, tipb.ScalarFuncSig_CastIntAsTime, tipb.ScalarFuncSig_CastRealAsTime, tipb.ScalarFuncSig_CastTimeAsTime, - tipb.ScalarFuncSig_CastStringAsTime /*, tipb.ScalarFuncSig_CastDurationAsTime, tipb.ScalarFuncSig_CastJsonAsTime*/ : - // ban the function of casting year type as time type pushing down to tiflash because of https://github.com/pingcap/tidb/issues/26215 - return function.GetArgs()[0].GetType(ctx).GetType() != mysql.TypeYear - case tipb.ScalarFuncSig_CastTimeAsDuration: - return retType.GetType() == mysql.TypeDuration - case tipb.ScalarFuncSig_CastIntAsJson, tipb.ScalarFuncSig_CastRealAsJson, tipb.ScalarFuncSig_CastDecimalAsJson, tipb.ScalarFuncSig_CastStringAsJson, - tipb.ScalarFuncSig_CastTimeAsJson, tipb.ScalarFuncSig_CastDurationAsJson, tipb.ScalarFuncSig_CastJsonAsJson: - return true - } - case ast.DateAdd, ast.AddDate: - switch function.Function.PbCode() { - case tipb.ScalarFuncSig_AddDateDatetimeInt, tipb.ScalarFuncSig_AddDateStringInt, tipb.ScalarFuncSig_AddDateStringReal: - return true - } - case ast.DateSub, ast.SubDate: - switch function.Function.PbCode() { - case tipb.ScalarFuncSig_SubDateDatetimeInt, tipb.ScalarFuncSig_SubDateStringInt, tipb.ScalarFuncSig_SubDateStringReal: - return true - } - case ast.UnixTimestamp: - switch function.Function.PbCode() { - case tipb.ScalarFuncSig_UnixTimestampInt, tipb.ScalarFuncSig_UnixTimestampDec: - return true - } - case ast.Round: - switch function.Function.PbCode() { - case tipb.ScalarFuncSig_RoundInt, tipb.ScalarFuncSig_RoundReal, tipb.ScalarFuncSig_RoundDec, - tipb.ScalarFuncSig_RoundWithFracInt, tipb.ScalarFuncSig_RoundWithFracReal, tipb.ScalarFuncSig_RoundWithFracDec: - return true - } - case ast.Extract: - switch function.Function.PbCode() { - case tipb.ScalarFuncSig_ExtractDatetime, tipb.ScalarFuncSig_ExtractDuration: - return true - } - case ast.Replace: - switch function.Function.PbCode() { - 
case tipb.ScalarFuncSig_Replace: - return true - } - case ast.StrToDate: - switch function.Function.PbCode() { - case - tipb.ScalarFuncSig_StrToDateDate, - tipb.ScalarFuncSig_StrToDateDatetime: - return true - default: - return false - } - case ast.Upper, ast.Ucase, ast.Lower, ast.Lcase, ast.Space: - return true - case ast.Sysdate: - return true - case ast.Least, ast.Greatest: - switch function.Function.PbCode() { - case tipb.ScalarFuncSig_GreatestInt, tipb.ScalarFuncSig_GreatestReal, - tipb.ScalarFuncSig_LeastInt, tipb.ScalarFuncSig_LeastReal, tipb.ScalarFuncSig_LeastString, tipb.ScalarFuncSig_GreatestString: - return true - } - case ast.IsTruthWithNull, ast.IsTruthWithoutNull, ast.IsFalsity: - return true - case ast.Hex, ast.Unhex, ast.Bin: - return true - case ast.GetFormat: - return true - case ast.IsIPv4, ast.IsIPv6: - return true - case ast.Grouping: // grouping function for grouping sets identification. - return true - } - return false -} - -func canEnumPushdownPreliminarily(scalarFunc *ScalarFunction) bool { - switch scalarFunc.FuncName.L { - case ast.Cast: - return scalarFunc.RetType.EvalType() == types.ETInt || scalarFunc.RetType.EvalType() == types.ETReal || scalarFunc.RetType.EvalType() == types.ETDecimal - default: - return false - } -} - -// IsPushDownEnabled returns true if the input expr is not in the expr_pushdown_blacklist -func IsPushDownEnabled(name string, storeType kv.StoreType) bool { - value, exists := DefaultExprPushDownBlacklist.Load().(map[string]uint32)[name] - if exists { - mask := storeTypeMask(storeType) - return !(value&mask == mask) - } - - if storeType != kv.TiFlash && name == ast.AggFuncApproxCountDistinct { - // Can not push down approx_count_distinct to other store except tiflash by now. - return false - } - - return true -} - -// PushDownContext is the context used for push down expressions -type PushDownContext struct { - evalCtx EvalContext - client kv.Client - warnHandler contextutil.WarnAppender - groupConcatMaxLen uint64 -} - -// NewPushDownContext returns a new PushDownContext -func NewPushDownContext(evalCtx EvalContext, client kv.Client, inExplainStmt bool, - warnHandler contextutil.WarnAppender, extraWarnHandler contextutil.WarnAppender, groupConcatMaxLen uint64) PushDownContext { - var newWarnHandler contextutil.WarnAppender - if warnHandler != nil && extraWarnHandler != nil { - if inExplainStmt { - newWarnHandler = warnHandler - } else { - newWarnHandler = extraWarnHandler - } - } - - return PushDownContext{ - evalCtx: evalCtx, - client: client, - warnHandler: newWarnHandler, - groupConcatMaxLen: groupConcatMaxLen, - } -} - -// NewPushDownContextFromSessionVars builds a new PushDownContext from session vars. 
-func NewPushDownContextFromSessionVars(evalCtx EvalContext, sessVars *variable.SessionVars, client kv.Client) PushDownContext { - return NewPushDownContext( - evalCtx, - client, - sessVars.StmtCtx.InExplainStmt, - sessVars.StmtCtx.WarnHandler, - sessVars.StmtCtx.ExtraWarnHandler, - sessVars.GroupConcatMaxLen) -} - -// EvalCtx returns the eval context -func (ctx PushDownContext) EvalCtx() EvalContext { - return ctx.evalCtx -} - -// PbConverter returns a new PbConverter -func (ctx PushDownContext) PbConverter() PbConverter { - return NewPBConverter(ctx.client, ctx.evalCtx) -} - -// Client returns the kv client -func (ctx PushDownContext) Client() kv.Client { - return ctx.client -} - -// GetGroupConcatMaxLen returns the max length of group_concat -func (ctx PushDownContext) GetGroupConcatMaxLen() uint64 { - return ctx.groupConcatMaxLen -} - -// AppendWarning appends a warning to be handled by the internal handler -func (ctx PushDownContext) AppendWarning(err error) { - if ctx.warnHandler != nil { - ctx.warnHandler.AppendWarning(err) - } -} - -// PushDownExprsWithExtraInfo split the input exprs into pushed and remained, pushed include all the exprs that can be pushed down -func PushDownExprsWithExtraInfo(ctx PushDownContext, exprs []Expression, storeType kv.StoreType, canEnumPush bool) (pushed []Expression, remained []Expression) { - for _, expr := range exprs { - if canExprPushDown(ctx, expr, storeType, canEnumPush) { - pushed = append(pushed, expr) - } else { - remained = append(remained, expr) - } - } - return -} - -// PushDownExprs split the input exprs into pushed and remained, pushed include all the exprs that can be pushed down -func PushDownExprs(ctx PushDownContext, exprs []Expression, storeType kv.StoreType) (pushed []Expression, remained []Expression) { - return PushDownExprsWithExtraInfo(ctx, exprs, storeType, false) -} - -// CanExprsPushDownWithExtraInfo return true if all the expr in exprs can be pushed down -func CanExprsPushDownWithExtraInfo(ctx PushDownContext, exprs []Expression, storeType kv.StoreType, canEnumPush bool) bool { - _, remained := PushDownExprsWithExtraInfo(ctx, exprs, storeType, canEnumPush) - return len(remained) == 0 -} - -// CanExprsPushDown return true if all the expr in exprs can be pushed down -func CanExprsPushDown(ctx PushDownContext, exprs []Expression, storeType kv.StoreType) bool { - return CanExprsPushDownWithExtraInfo(ctx, exprs, storeType, false) -} - -func storeTypeMask(storeType kv.StoreType) uint32 { - if storeType == kv.UnSpecified { - return 1<= 0; i-- { - if filter(input[i]) { - filteredOut = append(filteredOut, input[i]) - input = append(input[:i], input[i+1:]...) - } - } - return input, filteredOut -} - -// ExtractDependentColumns extracts all dependent columns from a virtual column. -func ExtractDependentColumns(expr Expression) []*Column { - // Pre-allocate a slice to reduce allocation, 8 doesn't have special meaning. - result := make([]*Column, 0, 8) - return extractDependentColumns(result, expr) -} - -func extractDependentColumns(result []*Column, expr Expression) []*Column { - switch v := expr.(type) { - case *Column: - result = append(result, v) - if v.VirtualExpr != nil { - result = extractDependentColumns(result, v.VirtualExpr) - } - case *ScalarFunction: - for _, arg := range v.GetArgs() { - result = extractDependentColumns(result, arg) - } - } - return result -} - -// ExtractColumns extracts all columns from an expression. 
-func ExtractColumns(expr Expression) []*Column { - // Pre-allocate a slice to reduce allocation, 8 doesn't have special meaning. - result := make([]*Column, 0, 8) - return extractColumns(result, expr, nil) -} - -// ExtractCorColumns extracts correlated column from given expression. -func ExtractCorColumns(expr Expression) (cols []*CorrelatedColumn) { - switch v := expr.(type) { - case *CorrelatedColumn: - return []*CorrelatedColumn{v} - case *ScalarFunction: - for _, arg := range v.GetArgs() { - cols = append(cols, ExtractCorColumns(arg)...) - } - } - return -} - -// ExtractColumnsFromExpressions is a more efficient version of ExtractColumns for batch operation. -// filter can be nil, or a function to filter the result column. -// It's often observed that the pattern of the caller like this: -// -// cols := ExtractColumns(...) -// -// for _, col := range cols { -// if xxx(col) {...} -// } -// -// Provide an additional filter argument, this can be done in one step. -// To avoid allocation for cols that not need. -func ExtractColumnsFromExpressions(result []*Column, exprs []Expression, filter func(*Column) bool) []*Column { - for _, expr := range exprs { - result = extractColumns(result, expr, filter) - } - return result -} - -func extractColumns(result []*Column, expr Expression, filter func(*Column) bool) []*Column { - switch v := expr.(type) { - case *Column: - if filter == nil || filter(v) { - result = append(result, v) - } - case *ScalarFunction: - for _, arg := range v.GetArgs() { - result = extractColumns(result, arg, filter) - } - } - return result -} - -// ExtractEquivalenceColumns detects the equivalence from CNF exprs. -func ExtractEquivalenceColumns(result [][]Expression, exprs []Expression) [][]Expression { - // exprs are CNF expressions, EQ condition only make sense in the top level of every expr. - for _, expr := range exprs { - result = extractEquivalenceColumns(result, expr) - } - return result -} - -// FindUpperBound looks for column < constant or column <= constant and returns both the column -// and constant. It return nil, 0 if the expression is not of this form. -// It is used by derived Top N pattern and it is put here since it looks like -// a general purpose routine. Similar routines can be added to find lower bound as well. -func FindUpperBound(expr Expression) (*Column, int64) { - scalarFunction, scalarFunctionOk := expr.(*ScalarFunction) - if scalarFunctionOk { - args := scalarFunction.GetArgs() - if len(args) == 2 { - col, colOk := args[0].(*Column) - constant, constantOk := args[1].(*Constant) - if colOk && constantOk && (scalarFunction.FuncName.L == ast.LT || scalarFunction.FuncName.L == ast.LE) { - value, valueOk := constant.Value.GetValue().(int64) - if valueOk { - if scalarFunction.FuncName.L == ast.LT { - return col, value - 1 - } - return col, value - } - } - } - } - return nil, 0 -} - -func extractEquivalenceColumns(result [][]Expression, expr Expression) [][]Expression { - switch v := expr.(type) { - case *ScalarFunction: - // a==b, a<=>b, the latter one is evaluated to true when a,b are both null. 
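// Sketch (illustration only; the helper name and predicate are hypothetical): using the
// filter variant above to collect columns in a single pass, here keeping only columns
// backed by a real table column (ID > 0, consistent with the "zero column ID" note
// earlier in this patch).
func collectTableColumns(exprs []Expression) []*Column {
    return ExtractColumnsFromExpressions(make([]*Column, 0, 8), exprs, func(c *Column) bool {
        return c.ID > 0
    })
}
// For FindUpperBound above: "a < 10" yields (a, 9) and "a <= 10" yields (a, 10); any other
// shape of expression yields (nil, 0).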
- if v.FuncName.L == ast.EQ || v.FuncName.L == ast.NullEQ { - args := v.GetArgs() - if len(args) == 2 { - col1, ok1 := args[0].(*Column) - col2, ok2 := args[1].(*Column) - if ok1 && ok2 { - result = append(result, []Expression{col1, col2}) - } - col, ok1 := args[0].(*Column) - scl, ok2 := args[1].(*ScalarFunction) - if ok1 && ok2 { - result = append(result, []Expression{col, scl}) - } - col, ok1 = args[1].(*Column) - scl, ok2 = args[0].(*ScalarFunction) - if ok1 && ok2 { - result = append(result, []Expression{col, scl}) - } - } - return result - } - if v.FuncName.L == ast.In { - args := v.GetArgs() - // only `col in (only 1 element)`, can we build an equivalence here. - if len(args[1:]) == 1 { - col1, ok1 := args[0].(*Column) - col2, ok2 := args[1].(*Column) - if ok1 && ok2 { - result = append(result, []Expression{col1, col2}) - } - col, ok1 := args[0].(*Column) - scl, ok2 := args[1].(*ScalarFunction) - if ok1 && ok2 { - result = append(result, []Expression{col, scl}) - } - col, ok1 = args[1].(*Column) - scl, ok2 = args[0].(*ScalarFunction) - if ok1 && ok2 { - result = append(result, []Expression{col, scl}) - } - } - return result - } - // For Non-EQ function, we don't have to traverse down. - // eg: (a=b or c=d) doesn't make any definitely equivalence assertion. - } - return result -} - -// extractColumnsAndCorColumns extracts columns and correlated columns from `expr` and append them to `result`. -func extractColumnsAndCorColumns(result []*Column, expr Expression) []*Column { - switch v := expr.(type) { - case *Column: - result = append(result, v) - case *CorrelatedColumn: - result = append(result, &v.Column) - case *ScalarFunction: - for _, arg := range v.GetArgs() { - result = extractColumnsAndCorColumns(result, arg) - } - } - return result -} - -// ExtractConstantEqColumnsOrScalar detects the constant equal relationship from CNF exprs. -func ExtractConstantEqColumnsOrScalar(ctx BuildContext, result []Expression, exprs []Expression) []Expression { - // exprs are CNF expressions, EQ condition only make sense in the top level of every expr. - for _, expr := range exprs { - result = extractConstantEqColumnsOrScalar(ctx, result, expr) - } - return result -} - -func extractConstantEqColumnsOrScalar(ctx BuildContext, result []Expression, expr Expression) []Expression { - switch v := expr.(type) { - case *ScalarFunction: - if v.FuncName.L == ast.EQ || v.FuncName.L == ast.NullEQ { - args := v.GetArgs() - if len(args) == 2 { - col, ok1 := args[0].(*Column) - _, ok2 := args[1].(*Constant) - if ok1 && ok2 { - result = append(result, col) - } - col, ok1 = args[1].(*Column) - _, ok2 = args[0].(*Constant) - if ok1 && ok2 { - result = append(result, col) - } - // take the correlated column as constant here. - col, ok1 = args[0].(*Column) - _, ok2 = args[1].(*CorrelatedColumn) - if ok1 && ok2 { - result = append(result, col) - } - col, ok1 = args[1].(*Column) - _, ok2 = args[0].(*CorrelatedColumn) - if ok1 && ok2 { - result = append(result, col) - } - scl, ok1 := args[0].(*ScalarFunction) - _, ok2 = args[1].(*Constant) - if ok1 && ok2 { - result = append(result, scl) - } - scl, ok1 = args[1].(*ScalarFunction) - _, ok2 = args[0].(*Constant) - if ok1 && ok2 { - result = append(result, scl) - } - // take the correlated column as constant here. 
- scl, ok1 = args[0].(*ScalarFunction) - _, ok2 = args[1].(*CorrelatedColumn) - if ok1 && ok2 { - result = append(result, scl) - } - scl, ok1 = args[1].(*ScalarFunction) - _, ok2 = args[0].(*CorrelatedColumn) - if ok1 && ok2 { - result = append(result, scl) - } - } - return result - } - if v.FuncName.L == ast.In { - args := v.GetArgs() - allArgsIsConst := true - // only `col in (all same const)`, can col be the constant column. - // eg: a in (1, "1") does, while a in (1, '2') doesn't. - guard := args[1] - for i, v := range args[1:] { - if _, ok := v.(*Constant); !ok { - allArgsIsConst = false - break - } - if i == 0 { - continue - } - if !guard.Equal(ctx.GetEvalCtx(), v) { - allArgsIsConst = false - break - } - } - if allArgsIsConst { - if col, ok := args[0].(*Column); ok { - result = append(result, col) - } else if scl, ok := args[0].(*ScalarFunction); ok { - result = append(result, scl) - } - } - return result - } - // For Non-EQ function, we don't have to traverse down. - } - return result -} - -// ExtractColumnsAndCorColumnsFromExpressions extracts columns and correlated columns from expressions and append them to `result`. -func ExtractColumnsAndCorColumnsFromExpressions(result []*Column, list []Expression) []*Column { - for _, expr := range list { - result = extractColumnsAndCorColumns(result, expr) - } - return result -} - -// ExtractColumnSet extracts the different values of `UniqueId` for columns in expressions. -func ExtractColumnSet(exprs ...Expression) intset.FastIntSet { - set := intset.NewFastIntSet() - for _, expr := range exprs { - extractColumnSet(expr, &set) - } - return set -} - -func extractColumnSet(expr Expression, set *intset.FastIntSet) { - switch v := expr.(type) { - case *Column: - set.Insert(int(v.UniqueID)) - case *ScalarFunction: - for _, arg := range v.GetArgs() { - extractColumnSet(arg, set) - } - } -} - -// SetExprColumnInOperand is used to set columns in expr as InOperand. -func SetExprColumnInOperand(expr Expression) Expression { - switch v := expr.(type) { - case *Column: - col := v.Clone().(*Column) - col.InOperand = true - return col - case *ScalarFunction: - args := v.GetArgs() - for i, arg := range args { - args[i] = SetExprColumnInOperand(arg) - } - } - return expr -} - -// ColumnSubstitute substitutes the columns in filter to expressions in select fields. -// e.g. select * from (select b as a from t) k where a < 10 => select * from (select b as a from t where b < 10) k. -// TODO: remove this function and only use ColumnSubstituteImpl since this function swallows the error, which seems unsafe. -func ColumnSubstitute(ctx BuildContext, expr Expression, schema *Schema, newExprs []Expression) Expression { - _, _, resExpr := ColumnSubstituteImpl(ctx, expr, schema, newExprs, false) - return resExpr -} - -// ColumnSubstituteAll substitutes the columns just like ColumnSubstitute, but we don't accept partial substitution. -// Only accept: -// -// 1: substitute them all once find col in schema. -// 2: nothing in expr can be substituted. -func ColumnSubstituteAll(ctx BuildContext, expr Expression, schema *Schema, newExprs []Expression) (bool, Expression) { - _, hasFail, resExpr := ColumnSubstituteImpl(ctx, expr, schema, newExprs, true) - return hasFail, resExpr -} - -// ColumnSubstituteImpl tries to substitute column expr using newExprs, -// the newFunctionInternal is only called if its child is substituted -// @return bool means whether the expr has changed. 
-// @return bool means whether the expr should change (has the dependency in schema, while the corresponding expr has some compatibility), but finally fallback. -// @return Expression, the original expr or the changed expr, it depends on the first @return bool. -func ColumnSubstituteImpl(ctx BuildContext, expr Expression, schema *Schema, newExprs []Expression, fail1Return bool) (bool, bool, Expression) { - switch v := expr.(type) { - case *Column: - id := schema.ColumnIndex(v) - if id == -1 { - return false, false, v - } - newExpr := newExprs[id] - if v.InOperand { - newExpr = SetExprColumnInOperand(newExpr) - } - return true, false, newExpr - case *ScalarFunction: - substituted := false - hasFail := false - if v.FuncName.L == ast.Cast || v.FuncName.L == ast.Grouping { - var newArg Expression - substituted, hasFail, newArg = ColumnSubstituteImpl(ctx, v.GetArgs()[0], schema, newExprs, fail1Return) - if fail1Return && hasFail { - return substituted, hasFail, v - } - if substituted { - flag := v.RetType.GetFlag() - var e Expression - if v.FuncName.L == ast.Cast { - e = BuildCastFunction(ctx, newArg, v.RetType) - } else { - // for grouping function recreation, use clone (meta included) instead of newFunction - e = v.Clone() - e.(*ScalarFunction).Function.getArgs()[0] = newArg - } - e.SetCoercibility(v.Coercibility()) - e.GetType(ctx.GetEvalCtx()).SetFlag(flag) - return true, false, e - } - return false, false, v - } - // If the collation of the column is PAD SPACE, - // we can't propagate the constant to the length function. - // For example, schema = ['name'], newExprs = ['a'], v = length(name). - // We can't substitute name with 'a' in length(name) because the collation of name is PAD SPACE. - // TODO: We will fix it here temporarily, and redesign the logic if we encounter more similar functions or situations later. - // Fixed issue #53730 - if ctx.IsConstantPropagateCheck() && v.FuncName.L == ast.Length { - arg0, isColumn := v.GetArgs()[0].(*Column) - if isColumn { - id := schema.ColumnIndex(arg0) - if id != -1 { - _, isConstant := newExprs[id].(*Constant) - if isConstant { - mappedNewColumnCollate := schema.Columns[id].GetStaticType().GetCollate() - if mappedNewColumnCollate == charset.CollationUTF8MB4 || - mappedNewColumnCollate == charset.CollationUTF8 { - return false, false, v - } - } - } - } - } - // cowExprRef is a copy-on-write util, args array allocation happens only - // when expr in args is changed - refExprArr := cowExprRef{v.GetArgs(), nil} - oldCollEt, err := CheckAndDeriveCollationFromExprs(ctx, v.FuncName.L, v.RetType.EvalType(), v.GetArgs()...) - if err != nil { - logutil.BgLogger().Error("Unexpected error happened during ColumnSubstitution", zap.Stack("stack")) - return false, false, v - } - var tmpArgForCollCheck []Expression - if collate.NewCollationEnabled() { - tmpArgForCollCheck = make([]Expression, len(v.GetArgs())) - } - for idx, arg := range v.GetArgs() { - changed, failed, newFuncExpr := ColumnSubstituteImpl(ctx, arg, schema, newExprs, fail1Return) - if fail1Return && failed { - return changed, failed, v - } - oldChanged := changed - if collate.NewCollationEnabled() && changed { - // Make sure the collation used by the ScalarFunction isn't changed and its result collation is not weaker than the collation used by the ScalarFunction. - changed = false - copy(tmpArgForCollCheck, refExprArr.Result()) - tmpArgForCollCheck[idx] = newFuncExpr - newCollEt, err := CheckAndDeriveCollationFromExprs(ctx, v.FuncName.L, v.RetType.EvalType(), tmpArgForCollCheck...) 
- if err != nil { - logutil.BgLogger().Error("Unexpected error happened during ColumnSubstitution", zap.Stack("stack")) - return false, failed, v - } - if oldCollEt.Collation == newCollEt.Collation { - if newFuncExpr.GetType(ctx.GetEvalCtx()).GetCollate() == arg.GetType(ctx.GetEvalCtx()).GetCollate() && newFuncExpr.Coercibility() == arg.Coercibility() { - // It's safe to use the new expression, otherwise some cases in projection push-down will be wrong. - changed = true - } else { - changed = checkCollationStrictness(oldCollEt.Collation, newFuncExpr.GetType(ctx.GetEvalCtx()).GetCollate()) - } - } - } - hasFail = hasFail || failed || oldChanged != changed - if fail1Return && oldChanged != changed { - // Only when the oldChanged is true and changed is false, we will get here. - // And this means there some dependency in this arg can be substituted with - // given expressions, while it has some collation compatibility, finally we - // fall back to use the origin args. (commonly used in projection elimination - // in which fallback usage is unacceptable) - return changed, true, v - } - refExprArr.Set(idx, changed, newFuncExpr) - if changed { - substituted = true - } - } - if substituted { - newFunc, err := NewFunction(ctx, v.FuncName.L, v.RetType, refExprArr.Result()...) - if err != nil { - return true, true, v - } - return true, hasFail, newFunc - } - } - return false, false, expr -} - -// checkCollationStrictness check collation strictness-ship between `coll` and `newFuncColl` -// return true iff `newFuncColl` is not weaker than `coll` -func checkCollationStrictness(coll, newFuncColl string) bool { - collGroupID, ok1 := CollationStrictnessGroup[coll] - newFuncCollGroupID, ok2 := CollationStrictnessGroup[newFuncColl] - - if ok1 && ok2 { - if collGroupID == newFuncCollGroupID { - return true - } - - for _, id := range CollationStrictness[collGroupID] { - if newFuncCollGroupID == id { - return true - } - } - } - - return false -} - -// getValidPrefix gets a prefix of string which can parsed to a number with base. the minimum base is 2 and the maximum is 36. -func getValidPrefix(s string, base int64) string { - var ( - validLen int - upper rune - ) - switch { - case base >= 2 && base <= 9: - upper = rune('0' + base) - case base <= 36: - upper = rune('A' + base - 10) - default: - return "" - } -Loop: - for i := 0; i < len(s); i++ { - c := rune(s[i]) - switch { - case unicode.IsDigit(c) || unicode.IsLower(c) || unicode.IsUpper(c): - c = unicode.ToUpper(c) - if c >= upper { - break Loop - } - validLen = i + 1 - case c == '+' || c == '-': - if i != 0 { - break Loop - } - default: - break Loop - } - } - if validLen > 1 && s[0] == '+' { - return s[1:validLen] - } - return s[:validLen] -} - -// SubstituteCorCol2Constant will substitute correlated column to constant value which it contains. -// If the args of one scalar function are all constant, we will substitute it to constant. 
-func SubstituteCorCol2Constant(ctx BuildContext, expr Expression) (Expression, error) { - switch x := expr.(type) { - case *ScalarFunction: - allConstant := true - newArgs := make([]Expression, 0, len(x.GetArgs())) - for _, arg := range x.GetArgs() { - newArg, err := SubstituteCorCol2Constant(ctx, arg) - if err != nil { - return nil, err - } - _, ok := newArg.(*Constant) - newArgs = append(newArgs, newArg) - allConstant = allConstant && ok - } - if allConstant { - val, err := x.Eval(ctx.GetEvalCtx(), chunk.Row{}) - if err != nil { - return nil, err - } - return &Constant{Value: val, RetType: x.GetType(ctx.GetEvalCtx())}, nil - } - var ( - err error - newSf Expression - ) - if x.FuncName.L == ast.Cast { - newSf = BuildCastFunction(ctx, newArgs[0], x.RetType) - } else if x.FuncName.L == ast.Grouping { - newSf = x.Clone() - newSf.(*ScalarFunction).GetArgs()[0] = newArgs[0] - } else { - newSf, err = NewFunction(ctx, x.FuncName.L, x.GetType(ctx.GetEvalCtx()), newArgs...) - } - return newSf, err - case *CorrelatedColumn: - return &Constant{Value: *x.Data, RetType: x.GetType(ctx.GetEvalCtx())}, nil - case *Constant: - if x.DeferredExpr != nil { - newExpr := FoldConstant(ctx, x) - return &Constant{Value: newExpr.(*Constant).Value, RetType: x.GetType(ctx.GetEvalCtx())}, nil - } - } - return expr, nil -} - -func locateStringWithCollation(str, substr, coll string) int64 { - collator := collate.GetCollator(coll) - strKey := collator.KeyWithoutTrimRightSpace(str) - subStrKey := collator.KeyWithoutTrimRightSpace(substr) - - index := bytes.Index(strKey, subStrKey) - if index == -1 || index == 0 { - return int64(index + 1) - } - - // todo: we can use binary search to make it faster. - count := int64(0) - for { - r, size := utf8.DecodeRuneInString(str) - count++ - index -= len(collator.KeyWithoutTrimRightSpace(string(r))) - if index <= 0 { - return count + 1 - } - str = str[size:] - } -} - -// timeZone2Duration converts timezone whose format should satisfy the regular condition -// `(^(+|-)(0?[0-9]|1[0-2]):[0-5]?\d$)|(^+13:00$)` to int for use by time.FixedZone(). -func timeZone2int(tz string) int { - sign := 1 - if strings.HasPrefix(tz, "-") { - sign = -1 - } - - i := strings.Index(tz, ":") - h, err := strconv.Atoi(tz[1:i]) - terror.Log(err) - m, err := strconv.Atoi(tz[i+1:]) - terror.Log(err) - return sign * ((h * 3600) + (m * 60)) -} - -var logicalOps = map[string]struct{}{ - ast.LT: {}, - ast.GE: {}, - ast.GT: {}, - ast.LE: {}, - ast.EQ: {}, - ast.NE: {}, - ast.UnaryNot: {}, - ast.LogicAnd: {}, - ast.LogicOr: {}, - ast.LogicXor: {}, - ast.In: {}, - ast.IsNull: {}, - ast.IsTruthWithoutNull: {}, - ast.IsFalsity: {}, - ast.Like: {}, -} - -var oppositeOp = map[string]string{ - ast.LT: ast.GE, - ast.GE: ast.LT, - ast.GT: ast.LE, - ast.LE: ast.GT, - ast.EQ: ast.NE, - ast.NE: ast.EQ, - ast.LogicOr: ast.LogicAnd, - ast.LogicAnd: ast.LogicOr, -} - -// a op b is equal to b symmetricOp a -var symmetricOp = map[opcode.Op]opcode.Op{ - opcode.LT: opcode.GT, - opcode.GE: opcode.LE, - opcode.GT: opcode.LT, - opcode.LE: opcode.GE, - opcode.EQ: opcode.EQ, - opcode.NE: opcode.NE, - opcode.NullEQ: opcode.NullEQ, -} - -func pushNotAcrossArgs(ctx BuildContext, exprs []Expression, not bool) ([]Expression, bool) { - newExprs := make([]Expression, 0, len(exprs)) - flag := false - for _, expr := range exprs { - newExpr, changed := pushNotAcrossExpr(ctx, expr, not) - flag = changed || flag - newExprs = append(newExprs, newExpr) - } - return newExprs, flag -} - -// todo: consider more no precision-loss downcast cases. 
-func noPrecisionLossCastCompatible(cast, argCol *types.FieldType) bool { - // now only consider varchar type and integer. - if !(types.IsTypeVarchar(cast.GetType()) && types.IsTypeVarchar(argCol.GetType())) && - !(mysql.IsIntegerType(cast.GetType()) && mysql.IsIntegerType(argCol.GetType())) { - // varchar type and integer on the storage layer is quite same, while the char type has its padding suffix. - return false - } - if types.IsTypeVarchar(cast.GetType()) { - // cast varchar function only bear the flen extension. - if cast.GetFlen() < argCol.GetFlen() { - return false - } - if !collate.CompatibleCollate(cast.GetCollate(), argCol.GetCollate()) { - return false - } - } else { - // For integers, we should ignore the potential display length represented by flen, using the default flen of the type. - castFlen, _ := mysql.GetDefaultFieldLengthAndDecimal(cast.GetType()) - originFlen, _ := mysql.GetDefaultFieldLengthAndDecimal(argCol.GetType()) - // cast integer function only bear the flen extension and signed symbol unchanged. - if castFlen < originFlen { - return false - } - if mysql.HasUnsignedFlag(cast.GetFlag()) != mysql.HasUnsignedFlag(argCol.GetFlag()) { - return false - } - } - return true -} - -func unwrapCast(sctx BuildContext, parentF *ScalarFunction, castOffset int) (Expression, bool) { - _, collation := parentF.CharsetAndCollation() - cast, ok := parentF.GetArgs()[castOffset].(*ScalarFunction) - if !ok || cast.FuncName.L != ast.Cast { - return parentF, false - } - // eg: if (cast(A) EQ const) with incompatible collation, even if cast is eliminated, the condition still can not be used to build range. - if cast.RetType.EvalType() == types.ETString && !collate.CompatibleCollate(cast.RetType.GetCollate(), collation) { - return parentF, false - } - // 1-castOffset should be constant - if _, ok := parentF.GetArgs()[1-castOffset].(*Constant); !ok { - return parentF, false - } - - // the direct args of cast function should be column. - c, ok := cast.GetArgs()[0].(*Column) - if !ok { - return parentF, false - } - - // current only consider varchar and integer - if !noPrecisionLossCastCompatible(cast.RetType, c.RetType) { - return parentF, false - } - - // the column is covered by indexes, deconstructing it out. - if castOffset == 0 { - return NewFunctionInternal(sctx, parentF.FuncName.L, parentF.RetType, c, parentF.GetArgs()[1]), true - } - return NewFunctionInternal(sctx, parentF.FuncName.L, parentF.RetType, parentF.GetArgs()[0], c), true -} - -// eliminateCastFunction will detect the original arg before and the cast type after, once upon -// there is no precision loss between them, current cast wrapper can be eliminated. For string -// type, collation is also taken into consideration. (mainly used to build range or point) -func eliminateCastFunction(sctx BuildContext, expr Expression) (_ Expression, changed bool) { - f, ok := expr.(*ScalarFunction) - if !ok { - return expr, false - } - _, collation := expr.CharsetAndCollation() - switch f.FuncName.L { - case ast.LogicOr: - dnfItems := FlattenDNFConditions(f) - rmCast := false - rmCastItems := make([]Expression, len(dnfItems)) - for i, dnfItem := range dnfItems { - newExpr, curDowncast := eliminateCastFunction(sctx, dnfItem) - rmCastItems[i] = newExpr - if curDowncast { - rmCast = true - } - } - if rmCast { - // compose the new DNF expression. 
- return ComposeDNFCondition(sctx, rmCastItems...), true - } - return expr, false - case ast.LogicAnd: - cnfItems := FlattenCNFConditions(f) - rmCast := false - rmCastItems := make([]Expression, len(cnfItems)) - for i, cnfItem := range cnfItems { - newExpr, curDowncast := eliminateCastFunction(sctx, cnfItem) - rmCastItems[i] = newExpr - if curDowncast { - rmCast = true - } - } - if rmCast { - // compose the new CNF expression. - return ComposeCNFCondition(sctx, rmCastItems...), true - } - return expr, false - case ast.EQ, ast.NullEQ, ast.LE, ast.GE, ast.LT, ast.GT: - // for case: eq(cast(test.t2.a, varchar(100), "aaaaa"), once t2.a is covered by index or pk, try deconstructing it out. - if newF, ok := unwrapCast(sctx, f, 0); ok { - return newF, true - } - // for case: eq("aaaaa", cast(test.t2.a, varchar(100)), once t2.a is covered by index or pk, try deconstructing it out. - if newF, ok := unwrapCast(sctx, f, 1); ok { - return newF, true - } - case ast.In: - // case for: cast(a as bigint) in (1,2,3), we could deconstruct column 'a out directly. - cast, ok := f.GetArgs()[0].(*ScalarFunction) - if !ok || cast.FuncName.L != ast.Cast { - return expr, false - } - // eg: if (cast(A) IN {const}) with incompatible collation, even if cast is eliminated, the condition still can not be used to build range. - if cast.RetType.EvalType() == types.ETString && !collate.CompatibleCollate(cast.RetType.GetCollate(), collation) { - return expr, false - } - for _, arg := range f.GetArgs()[1:] { - if _, ok := arg.(*Constant); !ok { - return expr, false - } - } - // the direct args of cast function should be column. - c, ok := cast.GetArgs()[0].(*Column) - if !ok { - return expr, false - } - // current only consider varchar and integer - if !noPrecisionLossCastCompatible(cast.RetType, c.RetType) { - return expr, false - } - newArgs := []Expression{c} - newArgs = append(newArgs, f.GetArgs()[1:]...) - return NewFunctionInternal(sctx, f.FuncName.L, f.RetType, newArgs...), true - } - return expr, false -} - -// pushNotAcrossExpr try to eliminate the NOT expr in expression tree. -// Input `not` indicates whether there's a `NOT` be pushed down. -// Output `changed` indicates whether the output expression differs from the -// input `expr` because of the pushed-down-not. 
-func pushNotAcrossExpr(ctx BuildContext, expr Expression, not bool) (_ Expression, changed bool) { - if f, ok := expr.(*ScalarFunction); ok { - switch f.FuncName.L { - case ast.UnaryNot: - child, err := wrapWithIsTrue(ctx, true, f.GetArgs()[0], true) - if err != nil { - return expr, false - } - var childExpr Expression - childExpr, changed = pushNotAcrossExpr(ctx, child, !not) - if !changed && !not { - return expr, false - } - return childExpr, true - case ast.LT, ast.GE, ast.GT, ast.LE, ast.EQ, ast.NE: - if not { - return NewFunctionInternal(ctx, oppositeOp[f.FuncName.L], f.GetType(ctx.GetEvalCtx()), f.GetArgs()...), true - } - newArgs, changed := pushNotAcrossArgs(ctx, f.GetArgs(), false) - if !changed { - return f, false - } - return NewFunctionInternal(ctx, f.FuncName.L, f.GetType(ctx.GetEvalCtx()), newArgs...), true - case ast.LogicAnd, ast.LogicOr: - var ( - newArgs []Expression - changed bool - ) - funcName := f.FuncName.L - if not { - newArgs, _ = pushNotAcrossArgs(ctx, f.GetArgs(), true) - funcName = oppositeOp[f.FuncName.L] - changed = true - } else { - newArgs, changed = pushNotAcrossArgs(ctx, f.GetArgs(), false) - } - if !changed { - return f, false - } - return NewFunctionInternal(ctx, funcName, f.GetType(ctx.GetEvalCtx()), newArgs...), true - } - } - if not { - expr = NewFunctionInternal(ctx, ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), expr) - } - return expr, not -} - -// GetExprInsideIsTruth get the expression inside the `istrue_with_null` and `istrue`. -// This is useful when handling expressions from "not" or "!", because we might wrap `istrue_with_null` or `istrue` -// when handling them. See pushNotAcrossExpr() and wrapWithIsTrue() for details. -func GetExprInsideIsTruth(expr Expression) Expression { - if f, ok := expr.(*ScalarFunction); ok { - switch f.FuncName.L { - case ast.IsTruthWithNull, ast.IsTruthWithoutNull: - return GetExprInsideIsTruth(f.GetArgs()[0]) - default: - return expr - } - } - return expr -} - -// PushDownNot pushes the `not` function down to the expression's arguments. -func PushDownNot(ctx BuildContext, expr Expression) Expression { - newExpr, _ := pushNotAcrossExpr(ctx, expr, false) - return newExpr -} - -// EliminateNoPrecisionLossCast remove the redundant cast function for range build convenience. -// 1: deeper cast embedded in other complicated function will not be considered. -// 2: cast args should be one for original base column and one for constant. -// 3: some collation compatibility and precision loss will be considered when remove this cast func. -func EliminateNoPrecisionLossCast(sctx BuildContext, expr Expression) Expression { - newExpr, _ := eliminateCastFunction(sctx, expr) - return newExpr -} - -// ContainOuterNot checks if there is an outer `not`. -func ContainOuterNot(expr Expression) bool { - return containOuterNot(expr, false) -} - -// containOuterNot checks if there is an outer `not`. -// Input `not` means whether there is `not` outside `expr` -// -// eg. 
-// -// not(0+(t.a == 1 and t.b == 2)) returns true -// not(t.a) and not(t.b) returns false -func containOuterNot(expr Expression, not bool) bool { - if f, ok := expr.(*ScalarFunction); ok { - switch f.FuncName.L { - case ast.UnaryNot: - return containOuterNot(f.GetArgs()[0], true) - case ast.IsTruthWithNull, ast.IsNull: - return containOuterNot(f.GetArgs()[0], not) - default: - if not { - return true - } - hasNot := false - for _, expr := range f.GetArgs() { - hasNot = hasNot || containOuterNot(expr, not) - if hasNot { - return hasNot - } - } - return hasNot - } - } - return false -} - -// Contains tests if `exprs` contains `e`. -func Contains(ectx EvalContext, exprs []Expression, e Expression) bool { - for _, expr := range exprs { - // Check string equivalence if one of the expressions is a clone. - sameString := false - if e != nil && expr != nil { - sameString = (e.StringWithCtx(ectx, errors.RedactLogDisable) == expr.StringWithCtx(ectx, errors.RedactLogDisable)) - } - if e == expr || sameString { - return true - } - } - return false -} - -// ExtractFiltersFromDNFs checks whether the cond is DNF. If so, it will get the extracted part and the remained part. -// The original DNF will be replaced by the remained part or just be deleted if remained part is nil. -// And the extracted part will be appended to the end of the original slice. -func ExtractFiltersFromDNFs(ctx BuildContext, conditions []Expression) []Expression { - var allExtracted []Expression - for i := len(conditions) - 1; i >= 0; i-- { - if sf, ok := conditions[i].(*ScalarFunction); ok && sf.FuncName.L == ast.LogicOr { - extracted, remained := extractFiltersFromDNF(ctx, sf) - allExtracted = append(allExtracted, extracted...) - if remained == nil { - conditions = append(conditions[:i], conditions[i+1:]...) - } else { - conditions[i] = remained - } - } - } - return append(conditions, allExtracted...) -} - -// extractFiltersFromDNF extracts the same condition that occurs in every DNF item and remove them from dnf leaves. -func extractFiltersFromDNF(ctx BuildContext, dnfFunc *ScalarFunction) ([]Expression, Expression) { - dnfItems := FlattenDNFConditions(dnfFunc) - codeMap := make(map[string]int) - hashcode2Expr := make(map[string]Expression) - for i, dnfItem := range dnfItems { - innerMap := make(map[string]struct{}) - cnfItems := SplitCNFItems(dnfItem) - for _, cnfItem := range cnfItems { - code := cnfItem.HashCode() - if i == 0 { - codeMap[string(code)] = 1 - hashcode2Expr[string(code)] = cnfItem - } else if _, ok := codeMap[string(code)]; ok { - // We need this check because there may be the case like `select * from t, t1 where (t.a=t1.a and t.a=t1.a) or (something). - // We should make sure that the two `t.a=t1.a` contributes only once. - // TODO: do this out of this function. - if _, ok = innerMap[string(code)]; !ok { - codeMap[string(code)]++ - innerMap[string(code)] = struct{}{} - } - } - } - } - // We should make sure that this item occurs in every DNF item. 
- for hashcode, cnt := range codeMap { - if cnt < len(dnfItems) { - delete(hashcode2Expr, hashcode) - } - } - if len(hashcode2Expr) == 0 { - return nil, dnfFunc - } - newDNFItems := make([]Expression, 0, len(dnfItems)) - onlyNeedExtracted := false - for _, dnfItem := range dnfItems { - cnfItems := SplitCNFItems(dnfItem) - newCNFItems := make([]Expression, 0, len(cnfItems)) - for _, cnfItem := range cnfItems { - code := cnfItem.HashCode() - _, ok := hashcode2Expr[string(code)] - if !ok { - newCNFItems = append(newCNFItems, cnfItem) - } - } - // If the extracted part is just one leaf of the DNF expression. Then the value of the total DNF expression is - // always the same with the value of the extracted part. - if len(newCNFItems) == 0 { - onlyNeedExtracted = true - break - } - newDNFItems = append(newDNFItems, ComposeCNFCondition(ctx, newCNFItems...)) - } - extractedExpr := make([]Expression, 0, len(hashcode2Expr)) - for _, expr := range hashcode2Expr { - extractedExpr = append(extractedExpr, expr) - } - if onlyNeedExtracted { - return extractedExpr, nil - } - return extractedExpr, ComposeDNFCondition(ctx, newDNFItems...) -} - -// DeriveRelaxedFiltersFromDNF given a DNF expression, derive a relaxed DNF expression which only contains columns -// in specified schema; the derived expression is a superset of original expression, i.e, any tuple satisfying -// the original expression must satisfy the derived expression. Return nil when the derived expression is universal set. -// A running example is: for schema of t1, `(t1.a=1 and t2.a=1) or (t1.a=2 and t2.a=2)` would be derived as -// `t1.a=1 or t1.a=2`, while `t1.a=1 or t2.a=1` would get nil. -func DeriveRelaxedFiltersFromDNF(ctx BuildContext, expr Expression, schema *Schema) Expression { - sf, ok := expr.(*ScalarFunction) - if !ok || sf.FuncName.L != ast.LogicOr { - return nil - } - dnfItems := FlattenDNFConditions(sf) - newDNFItems := make([]Expression, 0, len(dnfItems)) - for _, dnfItem := range dnfItems { - cnfItems := SplitCNFItems(dnfItem) - newCNFItems := make([]Expression, 0, len(cnfItems)) - for _, cnfItem := range cnfItems { - if itemSF, ok := cnfItem.(*ScalarFunction); ok && itemSF.FuncName.L == ast.LogicOr { - relaxedCNFItem := DeriveRelaxedFiltersFromDNF(ctx, cnfItem, schema) - if relaxedCNFItem != nil { - newCNFItems = append(newCNFItems, relaxedCNFItem) - } - // If relaxed expression for embedded DNF is universal set, just drop this CNF item - continue - } - // This cnfItem must be simple expression now - // If it cannot be fully covered by schema, just drop this CNF item - if ExprFromSchema(cnfItem, schema) { - newCNFItems = append(newCNFItems, cnfItem) - } - } - // If this DNF item involves no column of specified schema, the relaxed expression must be universal set - if len(newCNFItems) == 0 { - return nil - } - newDNFItems = append(newDNFItems, ComposeCNFCondition(ctx, newCNFItems...)) - } - return ComposeDNFCondition(ctx, newDNFItems...) -} - -// GetRowLen gets the length if the func is row, returns 1 if not row. -func GetRowLen(e Expression) int { - if f, ok := e.(*ScalarFunction); ok && f.FuncName.L == ast.RowFunc { - return len(f.GetArgs()) - } - return 1 -} - -// CheckArgsNotMultiColumnRow checks the args are not multi-column row. -func CheckArgsNotMultiColumnRow(args ...Expression) error { - for _, arg := range args { - if GetRowLen(arg) != 1 { - return ErrOperandColumns.GenWithStackByArgs(1) - } - } - return nil -} - -// GetFuncArg gets the argument of the function at idx. 
-func GetFuncArg(e Expression, idx int) Expression { - if f, ok := e.(*ScalarFunction); ok { - return f.GetArgs()[idx] - } - return nil -} - -// PopRowFirstArg pops the first element and returns the rest of row. -// e.g. After this function (1, 2, 3) becomes (2, 3). -func PopRowFirstArg(ctx BuildContext, e Expression) (ret Expression, err error) { - if f, ok := e.(*ScalarFunction); ok && f.FuncName.L == ast.RowFunc { - args := f.GetArgs() - if len(args) == 2 { - return args[1], nil - } - ret, err = NewFunction(ctx, ast.RowFunc, f.GetType(ctx.GetEvalCtx()), args[1:]...) - return ret, err - } - return -} - -// DatumToConstant generates a Constant expression from a Datum. -func DatumToConstant(d types.Datum, tp byte, flag uint) *Constant { - t := types.NewFieldType(tp) - t.AddFlag(flag) - return &Constant{Value: d, RetType: t} -} - -// ParamMarkerExpression generate a getparam function expression. -func ParamMarkerExpression(ctx variable.SessionVarsProvider, v *driver.ParamMarkerExpr, needParam bool) (*Constant, error) { - useCache := ctx.GetSessionVars().StmtCtx.UseCache() - tp := types.NewFieldType(mysql.TypeUnspecified) - types.InferParamTypeFromDatum(&v.Datum, tp) - value := &Constant{Value: v.Datum, RetType: tp} - if useCache || needParam { - value.ParamMarker = &ParamMarker{ - order: v.Order, - } - } - return value, nil -} - -// ParamMarkerInPrepareChecker checks whether the given ast tree has paramMarker and is in prepare statement. -type ParamMarkerInPrepareChecker struct { - InPrepareStmt bool -} - -// Enter implements Visitor Interface. -func (pc *ParamMarkerInPrepareChecker) Enter(in ast.Node) (out ast.Node, skipChildren bool) { - switch v := in.(type) { - case *driver.ParamMarkerExpr: - pc.InPrepareStmt = !v.InExecute - return v, true - } - return in, false -} - -// Leave implements Visitor Interface. -func (pc *ParamMarkerInPrepareChecker) Leave(in ast.Node) (out ast.Node, ok bool) { - return in, true -} - -// DisableParseJSONFlag4Expr disables ParseToJSONFlag for `expr` except Column. -// We should not *PARSE* a string as JSON under some scenarios. ParseToJSONFlag -// is 0 for JSON column yet(as well as JSON correlated column), so we can skip -// it. Moreover, Column.RetType refers to the infoschema, if we modify it, data -// race may happen if another goroutine read from the infoschema at the same -// time. -func DisableParseJSONFlag4Expr(ctx EvalContext, expr Expression) { - if _, isColumn := expr.(*Column); isColumn { - return - } - if _, isCorCol := expr.(*CorrelatedColumn); isCorCol { - return - } - expr.GetType(ctx).SetFlag(expr.GetType(ctx).GetFlag() & ^mysql.ParseToJSONFlag) -} - -// ConstructPositionExpr constructs PositionExpr with the given ParamMarkerExpr. -func ConstructPositionExpr(p *driver.ParamMarkerExpr) *ast.PositionExpr { - return &ast.PositionExpr{P: p} -} - -// PosFromPositionExpr generates a position value from PositionExpr. -func PosFromPositionExpr(ctx BuildContext, vars variable.SessionVarsProvider, v *ast.PositionExpr) (int, bool, error) { - if v.P == nil { - return v.N, false, nil - } - value, err := ParamMarkerExpression(vars, v.P.(*driver.ParamMarkerExpr), false) - if err != nil { - return 0, true, err - } - pos, isNull, err := GetIntFromConstant(ctx.GetEvalCtx(), value) - if err != nil || isNull { - return 0, true, err - } - return pos, false, nil -} - -// GetStringFromConstant gets a string value from the Constant expression. 
-func GetStringFromConstant(ctx EvalContext, value Expression) (string, bool, error) { - con, ok := value.(*Constant) - if !ok { - err := errors.Errorf("Not a Constant expression %+v", value) - return "", true, err - } - str, isNull, err := con.EvalString(ctx, chunk.Row{}) - if err != nil || isNull { - return "", true, err - } - return str, false, nil -} - -// GetIntFromConstant gets an integer value from the Constant expression. -func GetIntFromConstant(ctx EvalContext, value Expression) (int, bool, error) { - str, isNull, err := GetStringFromConstant(ctx, value) - if err != nil || isNull { - return 0, true, err - } - intNum, err := strconv.Atoi(str) - if err != nil { - return 0, true, nil - } - return intNum, false, nil -} - -// BuildNotNullExpr wraps up `not(isnull())` for given expression. -func BuildNotNullExpr(ctx BuildContext, expr Expression) Expression { - isNull := NewFunctionInternal(ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), expr) - notNull := NewFunctionInternal(ctx, ast.UnaryNot, types.NewFieldType(mysql.TypeTiny), isNull) - return notNull -} - -// IsRuntimeConstExpr checks if a expr can be treated as a constant in **executor**. -func IsRuntimeConstExpr(expr Expression) bool { - switch x := expr.(type) { - case *ScalarFunction: - if _, ok := unFoldableFunctions[x.FuncName.L]; ok { - return false - } - for _, arg := range x.GetArgs() { - if !IsRuntimeConstExpr(arg) { - return false - } - } - return true - case *Column: - return false - case *Constant, *CorrelatedColumn: - return true - } - return false -} - -// CheckNonDeterministic checks whether the current expression contains a non-deterministic func. -func CheckNonDeterministic(e Expression) bool { - switch x := e.(type) { - case *Constant, *Column, *CorrelatedColumn: - return false - case *ScalarFunction: - if _, ok := unFoldableFunctions[x.FuncName.L]; ok { - return true - } - for _, arg := range x.GetArgs() { - if CheckNonDeterministic(arg) { - return true - } - } - } - return false -} - -// CheckFuncInExpr checks whether there's a given function in the expression. -func CheckFuncInExpr(e Expression, funcName string) bool { - switch x := e.(type) { - case *Constant, *Column, *CorrelatedColumn: - return false - case *ScalarFunction: - if x.FuncName.L == funcName { - return true - } - for _, arg := range x.GetArgs() { - if CheckFuncInExpr(arg, funcName) { - return true - } - } - } - return false -} - -// IsMutableEffectsExpr checks if expr contains function which is mutable or has side effects. -func IsMutableEffectsExpr(expr Expression) bool { - switch x := expr.(type) { - case *ScalarFunction: - if _, ok := mutableEffectsFunctions[x.FuncName.L]; ok { - return true - } - for _, arg := range x.GetArgs() { - if IsMutableEffectsExpr(arg) { - return true - } - } - case *Column: - case *Constant: - if x.DeferredExpr != nil { - return IsMutableEffectsExpr(x.DeferredExpr) - } - } - return false -} - -// IsImmutableFunc checks whether this expression only consists of foldable functions. -// This expression can be evaluated by using `expr.Eval(chunk.Row{})` directly and the result won't change if it's immutable. 
-func IsImmutableFunc(expr Expression) bool { - switch x := expr.(type) { - case *ScalarFunction: - if _, ok := unFoldableFunctions[x.FuncName.L]; ok { - return false - } - if _, ok := mutableEffectsFunctions[x.FuncName.L]; ok { - return false - } - for _, arg := range x.GetArgs() { - if !IsImmutableFunc(arg) { - return false - } - } - return true - default: - return true - } -} - -// RemoveDupExprs removes identical exprs. Not that if expr contains functions which -// are mutable or have side effects, we cannot remove it even if it has duplicates; -// if the plan is going to be cached, we cannot remove expressions containing `?` neither. -func RemoveDupExprs(exprs []Expression) []Expression { - res := make([]Expression, 0, len(exprs)) - exists := make(map[string]struct{}, len(exprs)) - for _, expr := range exprs { - key := string(expr.HashCode()) - if _, ok := exists[key]; !ok || IsMutableEffectsExpr(expr) { - res = append(res, expr) - exists[key] = struct{}{} - } - } - return res -} - -// GetUint64FromConstant gets a uint64 from constant expression. -func GetUint64FromConstant(ctx EvalContext, expr Expression) (uint64, bool, bool) { - con, ok := expr.(*Constant) - if !ok { - logutil.BgLogger().Warn("not a constant expression", zap.String("expression", expr.ExplainInfo(ctx))) - return 0, false, false - } - dt := con.Value - if con.ParamMarker != nil { - var err error - dt, err = con.ParamMarker.GetUserVar(ctx) - if err != nil { - logutil.BgLogger().Warn("get param failed", zap.Error(err)) - return 0, false, false - } - } else if con.DeferredExpr != nil { - var err error - dt, err = con.DeferredExpr.Eval(ctx, chunk.Row{}) - if err != nil { - logutil.BgLogger().Warn("eval deferred expr failed", zap.Error(err)) - return 0, false, false - } - } - switch dt.Kind() { - case types.KindNull: - return 0, true, true - case types.KindInt64: - val := dt.GetInt64() - if val < 0 { - return 0, false, false - } - return uint64(val), false, true - case types.KindUint64: - return dt.GetUint64(), false, true - } - return 0, false, false -} - -// ContainVirtualColumn checks if the expressions contain a virtual column -func ContainVirtualColumn(exprs []Expression) bool { - for _, expr := range exprs { - switch v := expr.(type) { - case *Column: - if v.VirtualExpr != nil { - return true - } - case *ScalarFunction: - if ContainVirtualColumn(v.GetArgs()) { - return true - } - } - } - return false -} - -// ContainCorrelatedColumn checks if the expressions contain a correlated column -func ContainCorrelatedColumn(exprs []Expression) bool { - for _, expr := range exprs { - switch v := expr.(type) { - case *CorrelatedColumn: - return true - case *ScalarFunction: - if ContainCorrelatedColumn(v.GetArgs()) { - return true - } - } - } - return false -} - -func jsonUnquoteFunctionBenefitsFromPushedDown(sf *ScalarFunction) bool { - arg0 := sf.GetArgs()[0] - // Only `->>` which parsed to JSONUnquote(CAST(JSONExtract() AS string)) can be pushed down to tikv - if fChild, ok := arg0.(*ScalarFunction); ok { - if fChild.FuncName.L == ast.Cast { - if fGrand, ok := fChild.GetArgs()[0].(*ScalarFunction); ok { - if fGrand.FuncName.L == ast.JSONExtract { - return true - } - } - } - } - return false -} - -// ProjectionBenefitsFromPushedDown evaluates if the expressions can improve performance when pushed down to TiKV -// Projections are not pushed down to tikv by default, thus we need to check strictly here to avoid potential performance degradation. 
-// Note: virtual column is not considered here, since this function cares performance instead of functionality -func ProjectionBenefitsFromPushedDown(exprs []Expression, inputSchemaLen int) bool { - allColRef := true - colRefCount := 0 - for _, expr := range exprs { - switch v := expr.(type) { - case *Column: - colRefCount = colRefCount + 1 - continue - case *ScalarFunction: - allColRef = false - switch v.FuncName.L { - case ast.JSONDepth, ast.JSONLength, ast.JSONType, ast.JSONValid, ast.JSONContains, ast.JSONContainsPath, - ast.JSONExtract, ast.JSONKeys, ast.JSONSearch, ast.JSONMemberOf, ast.JSONOverlaps: - continue - case ast.JSONUnquote: - if jsonUnquoteFunctionBenefitsFromPushedDown(v) { - continue - } - return false - default: - return false - } - default: - return false - } - } - // For all col refs, only push down column pruning projections - if allColRef { - return colRefCount < inputSchemaLen - } - return true -} - -// MaybeOverOptimized4PlanCache used to check whether an optimization can work -// for the statement when we enable the plan cache. -// In some situations, some optimizations maybe over-optimize and cache an -// overOptimized plan. The cached plan may not get the correct result when we -// reuse the plan for other statements. -// For example, `pk>=$a and pk<=$b` can be optimized to a PointGet when -// `$a==$b`, but it will cause wrong results when `$a!=$b`. -// So we need to do the check here. The check includes the following aspects: -// 1. Whether the plan cache switch is enable. -// 2. Whether the statement can be cached. -// 3. Whether the expressions contain a lazy constant. -// TODO: Do more careful check here. -func MaybeOverOptimized4PlanCache(ctx BuildContext, exprs []Expression) bool { - // If we do not enable plan cache, all the optimization can work correctly. - if !ctx.IsUseCache() { - return false - } - return containMutableConst(ctx.GetEvalCtx(), exprs) -} - -// containMutableConst checks if the expressions contain a lazy constant. -func containMutableConst(ctx EvalContext, exprs []Expression) bool { - for _, expr := range exprs { - switch v := expr.(type) { - case *Constant: - if v.ParamMarker != nil || v.DeferredExpr != nil { - return true - } - case *ScalarFunction: - if containMutableConst(ctx, v.GetArgs()) { - return true - } - } - } - return false -} - -// RemoveMutableConst used to remove the `ParamMarker` and `DeferredExpr` in the `Constant` expr. -func RemoveMutableConst(ctx BuildContext, exprs []Expression) (err error) { - for _, expr := range exprs { - switch v := expr.(type) { - case *Constant: - v.ParamMarker = nil - if v.DeferredExpr != nil { // evaluate and update v.Value to convert v to a complete immutable constant. - // TODO: remove or hide DeferredExpr since it's too dangerous (hard to be consistent with v.Value all the time). - v.Value, err = v.DeferredExpr.Eval(ctx.GetEvalCtx(), chunk.Row{}) - if err != nil { - return err - } - v.DeferredExpr = nil - } - v.DeferredExpr = nil // do nothing since v.Value has already been evaluated in this case. - case *ScalarFunction: - return RemoveMutableConst(ctx, v.GetArgs()) - } - } - return nil -} - -const ( - _ = iota - kib = 1 << (10 * iota) - mib = 1 << (10 * iota) - gib = 1 << (10 * iota) - tib = 1 << (10 * iota) - pib = 1 << (10 * iota) - eib = 1 << (10 * iota) -) - -const ( - nano = 1 - micro = 1000 * nano - milli = 1000 * micro - sec = 1000 * milli - min = 60 * sec - hour = 60 * min - dayTime = 24 * hour -) - -// GetFormatBytes convert byte count to value with units. 
-func GetFormatBytes(bytes float64) string { - var divisor float64 - var unit string - - bytesAbs := math.Abs(bytes) - if bytesAbs >= eib { - divisor = eib - unit = "EiB" - } else if bytesAbs >= pib { - divisor = pib - unit = "PiB" - } else if bytesAbs >= tib { - divisor = tib - unit = "TiB" - } else if bytesAbs >= gib { - divisor = gib - unit = "GiB" - } else if bytesAbs >= mib { - divisor = mib - unit = "MiB" - } else if bytesAbs >= kib { - divisor = kib - unit = "KiB" - } else { - divisor = 1 - unit = "bytes" - } - - if divisor == 1 { - return strconv.FormatFloat(bytes, 'f', 0, 64) + " " + unit - } - value := bytes / divisor - if math.Abs(value) >= 100000.0 { - return strconv.FormatFloat(value, 'e', 2, 64) + " " + unit - } - return strconv.FormatFloat(value, 'f', 2, 64) + " " + unit -} - -// GetFormatNanoTime convert time in nanoseconds to value with units. -func GetFormatNanoTime(time float64) string { - var divisor float64 - var unit string - - timeAbs := math.Abs(time) - if timeAbs >= dayTime { - divisor = dayTime - unit = "d" - } else if timeAbs >= hour { - divisor = hour - unit = "h" - } else if timeAbs >= min { - divisor = min - unit = "min" - } else if timeAbs >= sec { - divisor = sec - unit = "s" - } else if timeAbs >= milli { - divisor = milli - unit = "ms" - } else if timeAbs >= micro { - divisor = micro - unit = "us" - } else { - divisor = 1 - unit = "ns" - } - - if divisor == 1 { - return strconv.FormatFloat(time, 'f', 0, 64) + " " + unit - } - value := time / divisor - if math.Abs(value) >= 100000.0 { - return strconv.FormatFloat(value, 'e', 2, 64) + " " + unit - } - return strconv.FormatFloat(value, 'f', 2, 64) + " " + unit -} - -// SQLDigestTextRetriever is used to find the normalized SQL statement text by SQL digests in statements_summary table. -// It's exported for test purposes. It's used by the `tidb_decode_sql_digests` builtin function, but also exposed to -// be used in other modules. -type SQLDigestTextRetriever struct { - // SQLDigestsMap is the place to put the digests that's requested for getting SQL text and also the place to put - // the query result. - SQLDigestsMap map[string]string - - // Replace querying for test purposes. - mockLocalData map[string]string - mockGlobalData map[string]string - // There are two ways for querying information: 1) query specified digests by WHERE IN query, or 2) query all - // information to avoid the too long WHERE IN clause. If there are more than `fetchAllLimit` digests needs to be - // queried, the second way will be chosen; otherwise, the first way will be chosen. - fetchAllLimit int -} - -// NewSQLDigestTextRetriever creates a new SQLDigestTextRetriever. -func NewSQLDigestTextRetriever() *SQLDigestTextRetriever { - return &SQLDigestTextRetriever{ - SQLDigestsMap: make(map[string]string), - fetchAllLimit: 512, - } -} - -func (r *SQLDigestTextRetriever) runMockQuery(data map[string]string, inValues []any) (map[string]string, error) { - if len(inValues) == 0 { - return data, nil - } - res := make(map[string]string, len(inValues)) - for _, digest := range inValues { - if text, ok := data[digest.(string)]; ok { - res[digest.(string)] = text - } - } - return res, nil -} - -// runFetchDigestQuery runs query to the system tables to fetch the kv mapping of SQL digests and normalized SQL texts -// of the given SQL digests, if `inValues` is given, or all these mappings otherwise. 
If `queryGlobal` is false, it -// queries information_schema.statements_summary and information_schema.statements_summary_history; otherwise, it -// queries the cluster version of these two tables. -func (r *SQLDigestTextRetriever) runFetchDigestQuery(ctx context.Context, exec contextopt.SQLExecutor, queryGlobal bool, inValues []any) (map[string]string, error) { - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnOthers) - // If mock data is set, query the mock data instead of the real statements_summary tables. - if !queryGlobal && r.mockLocalData != nil { - return r.runMockQuery(r.mockLocalData, inValues) - } else if queryGlobal && r.mockGlobalData != nil { - return r.runMockQuery(r.mockGlobalData, inValues) - } - - // Information in statements_summary will be periodically moved to statements_summary_history. Union them together - // to avoid missing information when statements_summary is just cleared. - stmt := "select digest, digest_text from information_schema.statements_summary union distinct " + - "select digest, digest_text from information_schema.statements_summary_history" - if queryGlobal { - stmt = "select digest, digest_text from information_schema.cluster_statements_summary union distinct " + - "select digest, digest_text from information_schema.cluster_statements_summary_history" - } - // Add the where clause if `inValues` is specified. - if len(inValues) > 0 { - stmt += " where digest in (" + strings.Repeat("%?,", len(inValues)-1) + "%?)" - } - - rows, _, err := exec.ExecRestrictedSQL(ctx, nil, stmt, inValues...) - if err != nil { - return nil, err - } - - res := make(map[string]string, len(rows)) - for _, row := range rows { - res[row.GetString(0)] = row.GetString(1) - } - return res, nil -} - -func (r *SQLDigestTextRetriever) updateDigestInfo(queryResult map[string]string) { - for digest, text := range r.SQLDigestsMap { - if len(text) > 0 { - // The text of this digest is already known - continue - } - sqlText, ok := queryResult[digest] - if ok { - r.SQLDigestsMap[digest] = sqlText - } - } -} - -// RetrieveLocal tries to retrieve the SQL text of the SQL digests from local information. -func (r *SQLDigestTextRetriever) RetrieveLocal(ctx context.Context, exec contextopt.SQLExecutor) error { - if len(r.SQLDigestsMap) == 0 { - return nil - } - - var queryResult map[string]string - if len(r.SQLDigestsMap) <= r.fetchAllLimit { - inValues := make([]any, 0, len(r.SQLDigestsMap)) - for key := range r.SQLDigestsMap { - inValues = append(inValues, key) - } - var err error - queryResult, err = r.runFetchDigestQuery(ctx, exec, false, inValues) - if err != nil { - return errors.Trace(err) - } - - if len(queryResult) == len(r.SQLDigestsMap) { - r.SQLDigestsMap = queryResult - return nil - } - } else { - var err error - queryResult, err = r.runFetchDigestQuery(ctx, exec, false, nil) - if err != nil { - return errors.Trace(err) - } - } - - r.updateDigestInfo(queryResult) - return nil -} - -// RetrieveGlobal tries to retrieve the SQL text of the SQL digests from the information of the whole cluster. -func (r *SQLDigestTextRetriever) RetrieveGlobal(ctx context.Context, exec contextopt.SQLExecutor) error { - err := r.RetrieveLocal(ctx, exec) - if err != nil { - return errors.Trace(err) - } - - // In some unit test environments it's unable to retrieve global info, and this function blocks it for tens of - // seconds, which wastes much time during unit test. In this case, enable this failpoint to bypass retrieving - // globally. 
- failpoint.Inject("sqlDigestRetrieverSkipRetrieveGlobal", func() { - failpoint.Return(nil) - }) - - var unknownDigests []any - for k, v := range r.SQLDigestsMap { - if len(v) == 0 { - unknownDigests = append(unknownDigests, k) - } - } - - if len(unknownDigests) == 0 { - return nil - } - - var queryResult map[string]string - if len(r.SQLDigestsMap) <= r.fetchAllLimit { - queryResult, err = r.runFetchDigestQuery(ctx, exec, true, unknownDigests) - if err != nil { - return errors.Trace(err) - } - } else { - queryResult, err = r.runFetchDigestQuery(ctx, exec, true, nil) - if err != nil { - return errors.Trace(err) - } - } - - r.updateDigestInfo(queryResult) - return nil -} - -// ExprsToStringsForDisplay convert a slice of Expression to a slice of string using Expression.String(), and -// to make it better for display and debug, it also escapes the string to corresponding golang string literal, -// which means using \t, \n, \x??, \u????, ... to represent newline, control character, non-printable character, -// invalid utf-8 bytes and so on. -func ExprsToStringsForDisplay(ctx EvalContext, exprs []Expression) []string { - strs := make([]string, len(exprs)) - for i, cond := range exprs { - quote := `"` - // We only need the escape functionality of strconv.Quote, the quoting is not needed, - // so we trim the \" prefix and suffix here. - strs[i] = strings.TrimSuffix( - strings.TrimPrefix( - strconv.Quote(cond.StringWithCtx(ctx, errors.RedactLogDisable)), - quote), - quote) - } - return strs -} - -// ConstExprConsiderPlanCache indicates whether the expression can be considered as a constant expression considering planCache. -// If the expression is in plan cache, it should have a const level `ConstStrict` because it can be shared across statements. -// If the expression is not in plan cache, `ConstOnlyInContext` is enough because it is only used in one statement. -// Please notice that if the expression may be cached in other ways except plan cache, we should not use this function. -func ConstExprConsiderPlanCache(expr Expression, inPlanCache bool) bool { - switch expr.ConstLevel() { - case ConstStrict: - return true - case ConstOnlyInContext: - return !inPlanCache - default: - return false - } -} - -// ExprsHasSideEffects checks if any of the expressions has side effects. -func ExprsHasSideEffects(exprs []Expression) bool { - for _, expr := range exprs { - if ExprHasSetVarOrSleep(expr) { - return true - } - } - return false -} - -// ExprHasSetVarOrSleep checks if the expression has SetVar function or Sleep function. 
-func ExprHasSetVarOrSleep(expr Expression) bool {
-	scalaFunc, isScalaFunc := expr.(*ScalarFunction)
-	if !isScalaFunc {
-		return false
-	}
-	if scalaFunc.FuncName.L == ast.SetVar || scalaFunc.FuncName.L == ast.Sleep {
-		return true
-	}
-	for _, arg := range scalaFunc.GetArgs() {
-		if ExprHasSetVarOrSleep(arg) {
-			return true
-		}
-	}
-	return false
-}
diff --git a/pkg/infoschema/binding__failpoint_binding__.go b/pkg/infoschema/binding__failpoint_binding__.go
deleted file mode 100644
index 5fad7f8a11a71..0000000000000
--- a/pkg/infoschema/binding__failpoint_binding__.go
+++ /dev/null
@@ -1,14 +0,0 @@
-
-package infoschema
-
-import "reflect"
-
-type __failpointBindingType struct {pkgpath string}
-var __failpointBindingCache = &__failpointBindingType{}
-
-func init() {
-	__failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath()
-}
-func _curpkg_(name string) string {
-	return __failpointBindingCache.pkgpath + "/" + name
-}
diff --git a/pkg/infoschema/builder.go b/pkg/infoschema/builder.go
index 81e30ea733338..6e489d29b4721 100644
--- a/pkg/infoschema/builder.go
+++ b/pkg/infoschema/builder.go
@@ -670,13 +670,13 @@ func applyCreateTable(b *Builder, m *meta.Meta, dbInfo *model.DBInfo, tableID in
 
 	// Failpoint check whether tableInfo should be added to repairInfo.
 	// Typically used in repair table test to load mock `bad` tableInfo into repairInfo.
-	if val, _err_ := failpoint.Eval(_curpkg_("repairFetchCreateTable")); _err_ == nil {
+	failpoint.Inject("repairFetchCreateTable", func(val failpoint.Value) {
 		if val.(bool) {
 			if domainutil.RepairInfo.InRepairMode() && tp != model.ActionRepairTable && domainutil.RepairInfo.CheckAndFetchRepairedTable(dbInfo, tblInfo) {
-				return nil, nil
+				failpoint.Return(nil, nil)
 			}
 		}
-	}
+	})
 	ConvertCharsetCollateToLowerCaseIfNeed(tblInfo)
 	ConvertOldVersionUTF8ToUTF8MB4IfNeed(tblInfo)
 
diff --git a/pkg/infoschema/builder.go__failpoint_stash__ b/pkg/infoschema/builder.go__failpoint_stash__
deleted file mode 100644
index 6e489d29b4721..0000000000000
--- a/pkg/infoschema/builder.go__failpoint_stash__
+++ /dev/null
@@ -1,1040 +0,0 @@
-// Copyright 2016 PingCAP, Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package infoschema
-
-import (
-	"cmp"
-	"context"
-	"fmt"
-	"maps"
-	"slices"
-	"strings"
-
-	"github.com/ngaut/pools"
-	"github.com/pingcap/errors"
-	"github.com/pingcap/failpoint"
-	"github.com/pingcap/tidb/pkg/config"
-	"github.com/pingcap/tidb/pkg/kv"
-	"github.com/pingcap/tidb/pkg/meta"
-	"github.com/pingcap/tidb/pkg/meta/autoid"
-	"github.com/pingcap/tidb/pkg/parser/charset"
-	"github.com/pingcap/tidb/pkg/parser/model"
-	"github.com/pingcap/tidb/pkg/sessionctx"
-	"github.com/pingcap/tidb/pkg/sessionctx/variable"
-	"github.com/pingcap/tidb/pkg/table"
-	"github.com/pingcap/tidb/pkg/table/tables"
-	"github.com/pingcap/tidb/pkg/util/domainutil"
-	"github.com/pingcap/tidb/pkg/util/intest"
-)
-
-// Builder builds a new InfoSchema.
-type Builder struct { - enableV2 bool - infoschemaV2 - // dbInfos do not need to be copied everytime applying a diff, instead, - // they can be copied only once over the whole lifespan of Builder. - // This map will indicate which DB has been copied, so that they - // don't need to be copied again. - dirtyDB map[string]bool - - // Used by autoid allocators - autoid.Requirement - - factory func() (pools.Resource, error) - bundleInfoBuilder - infoData *Data - store kv.Storage -} - -// ApplyDiff applies SchemaDiff to the new InfoSchema. -// Return the detail updated table IDs that are produced from SchemaDiff and an error. -func (b *Builder) ApplyDiff(m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { - b.schemaMetaVersion = diff.Version - switch diff.Type { - case model.ActionCreateSchema: - return nil, applyCreateSchema(b, m, diff) - case model.ActionDropSchema: - return applyDropSchema(b, diff), nil - case model.ActionRecoverSchema: - return applyRecoverSchema(b, m, diff) - case model.ActionModifySchemaCharsetAndCollate: - return nil, applyModifySchemaCharsetAndCollate(b, m, diff) - case model.ActionModifySchemaDefaultPlacement: - return nil, applyModifySchemaDefaultPlacement(b, m, diff) - case model.ActionCreatePlacementPolicy: - return nil, applyCreatePolicy(b, m, diff) - case model.ActionDropPlacementPolicy: - return applyDropPolicy(b, diff.SchemaID), nil - case model.ActionAlterPlacementPolicy: - return applyAlterPolicy(b, m, diff) - case model.ActionCreateResourceGroup: - return nil, applyCreateOrAlterResourceGroup(b, m, diff) - case model.ActionAlterResourceGroup: - return nil, applyCreateOrAlterResourceGroup(b, m, diff) - case model.ActionDropResourceGroup: - return applyDropResourceGroup(b, m, diff), nil - case model.ActionTruncateTablePartition, model.ActionTruncateTable: - return applyTruncateTableOrPartition(b, m, diff) - case model.ActionDropTable, model.ActionDropTablePartition: - return applyDropTableOrPartition(b, m, diff) - case model.ActionRecoverTable: - return applyRecoverTable(b, m, diff) - case model.ActionCreateTables: - return applyCreateTables(b, m, diff) - case model.ActionReorganizePartition, model.ActionRemovePartitioning, - model.ActionAlterTablePartitioning: - return applyReorganizePartition(b, m, diff) - case model.ActionExchangeTablePartition: - return applyExchangeTablePartition(b, m, diff) - case model.ActionFlashbackCluster: - return []int64{-1}, nil - default: - return applyDefaultAction(b, m, diff) - } -} - -func (b *Builder) applyCreateTables(m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { - return b.applyAffectedOpts(m, make([]int64, 0, len(diff.AffectedOpts)), diff, model.ActionCreateTable) -} - -func applyTruncateTableOrPartition(b *Builder, m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { - tblIDs, err := applyTableUpdate(b, m, diff) - if err != nil { - return nil, errors.Trace(err) - } - - // bundle ops - if diff.Type == model.ActionTruncateTable { - b.deleteBundle(b.infoSchema, diff.OldTableID) - b.markTableBundleShouldUpdate(diff.TableID) - } - - for _, opt := range diff.AffectedOpts { - if diff.Type == model.ActionTruncateTablePartition { - // Reduce the impact on DML when executing partition DDL. eg. - // While session 1 performs the DML operation associated with partition 1, - // the TRUNCATE operation of session 2 on partition 2 does not cause the operation of session 1 to fail. 
- tblIDs = append(tblIDs, opt.OldTableID) - b.markPartitionBundleShouldUpdate(opt.TableID) - } - b.deleteBundle(b.infoSchema, opt.OldTableID) - } - return tblIDs, nil -} - -func applyDropTableOrPartition(b *Builder, m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { - tblIDs, err := applyTableUpdate(b, m, diff) - if err != nil { - return nil, errors.Trace(err) - } - - // bundle ops - b.markTableBundleShouldUpdate(diff.TableID) - for _, opt := range diff.AffectedOpts { - b.deleteBundle(b.infoSchema, opt.OldTableID) - } - return tblIDs, nil -} - -func applyReorganizePartition(b *Builder, m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { - tblIDs, err := applyTableUpdate(b, m, diff) - if err != nil { - return nil, errors.Trace(err) - } - - // bundle ops - for _, opt := range diff.AffectedOpts { - if opt.OldTableID != 0 { - b.deleteBundle(b.infoSchema, opt.OldTableID) - } - if opt.TableID != 0 { - b.markTableBundleShouldUpdate(opt.TableID) - } - // TODO: Should we also check markPartitionBundleShouldUpdate?!? - } - return tblIDs, nil -} - -func applyExchangeTablePartition(b *Builder, m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { - // It is not in StatePublic. - if diff.OldTableID == diff.TableID && diff.OldSchemaID == diff.SchemaID { - ntIDs, err := applyTableUpdate(b, m, diff) - if err != nil { - return nil, errors.Trace(err) - } - if diff.AffectedOpts == nil || diff.AffectedOpts[0].OldSchemaID == 0 { - return ntIDs, err - } - // Reload parition tabe. - ptSchemaID := diff.AffectedOpts[0].OldSchemaID - ptID := diff.AffectedOpts[0].TableID - ptDiff := &model.SchemaDiff{ - Type: diff.Type, - Version: diff.Version, - TableID: ptID, - SchemaID: ptSchemaID, - OldTableID: ptID, - OldSchemaID: ptSchemaID, - } - ptIDs, err := applyTableUpdate(b, m, ptDiff) - if err != nil { - return nil, errors.Trace(err) - } - return append(ptIDs, ntIDs...), nil - } - ntSchemaID := diff.OldSchemaID - ntID := diff.OldTableID - ptSchemaID := diff.SchemaID - ptID := diff.TableID - partID := diff.TableID - if len(diff.AffectedOpts) > 0 { - ptID = diff.AffectedOpts[0].TableID - if diff.AffectedOpts[0].SchemaID != 0 { - ptSchemaID = diff.AffectedOpts[0].SchemaID - } - } - // The normal table needs to be updated first: - // Just update the tables separately - currDiff := &model.SchemaDiff{ - // This is only for the case since https://github.com/pingcap/tidb/pull/45877 - // Fixed now, by adding back the AffectedOpts - // to carry the partitioned Table ID. - Type: diff.Type, - Version: diff.Version, - TableID: ntID, - SchemaID: ntSchemaID, - } - if ptID != partID { - currDiff.TableID = partID - currDiff.OldTableID = ntID - currDiff.OldSchemaID = ntSchemaID - } - ntIDs, err := applyTableUpdate(b, m, currDiff) - if err != nil { - return nil, errors.Trace(err) - } - // partID is the new id for the non-partitioned table! - b.markTableBundleShouldUpdate(partID) - // Then the partitioned table, will re-read the whole table, including all partitions! - currDiff.TableID = ptID - currDiff.SchemaID = ptSchemaID - currDiff.OldTableID = ptID - currDiff.OldSchemaID = ptSchemaID - ptIDs, err := applyTableUpdate(b, m, currDiff) - if err != nil { - return nil, errors.Trace(err) - } - // ntID is the new id for the partition! 
- b.markPartitionBundleShouldUpdate(ntID) - err = updateAutoIDForExchangePartition(b.Requirement.Store(), ptSchemaID, ptID, ntSchemaID, ntID) - if err != nil { - return nil, errors.Trace(err) - } - return append(ptIDs, ntIDs...), nil -} - -func applyRecoverTable(b *Builder, m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { - tblIDs, err := applyTableUpdate(b, m, diff) - if err != nil { - return nil, errors.Trace(err) - } - - // bundle ops - for _, opt := range diff.AffectedOpts { - b.markTableBundleShouldUpdate(opt.TableID) - } - return tblIDs, nil -} - -func updateAutoIDForExchangePartition(store kv.Storage, ptSchemaID, ptID, ntSchemaID, ntID int64) error { - err := kv.RunInNewTxn(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), store, true, func(ctx context.Context, txn kv.Transaction) error { - t := meta.NewMeta(txn) - ptAutoIDs, err := t.GetAutoIDAccessors(ptSchemaID, ptID).Get() - if err != nil { - return err - } - - // non-partition table auto IDs. - ntAutoIDs, err := t.GetAutoIDAccessors(ntSchemaID, ntID).Get() - if err != nil { - return err - } - - // Set both tables to the maximum auto IDs between normal table and partitioned table. - newAutoIDs := meta.AutoIDGroup{ - RowID: max(ptAutoIDs.RowID, ntAutoIDs.RowID), - IncrementID: max(ptAutoIDs.IncrementID, ntAutoIDs.IncrementID), - RandomID: max(ptAutoIDs.RandomID, ntAutoIDs.RandomID), - } - err = t.GetAutoIDAccessors(ptSchemaID, ptID).Put(newAutoIDs) - if err != nil { - return err - } - err = t.GetAutoIDAccessors(ntSchemaID, ntID).Put(newAutoIDs) - if err != nil { - return err - } - return nil - }) - - return err -} - -func (b *Builder) applyAffectedOpts(m *meta.Meta, tblIDs []int64, diff *model.SchemaDiff, tp model.ActionType) ([]int64, error) { - if diff.AffectedOpts != nil { - for _, opt := range diff.AffectedOpts { - affectedDiff := &model.SchemaDiff{ - Version: diff.Version, - Type: tp, - SchemaID: opt.SchemaID, - TableID: opt.TableID, - OldSchemaID: opt.OldSchemaID, - OldTableID: opt.OldTableID, - } - affectedIDs, err := b.ApplyDiff(m, affectedDiff) - if err != nil { - return nil, errors.Trace(err) - } - tblIDs = append(tblIDs, affectedIDs...) - } - } - return tblIDs, nil -} - -func applyDefaultAction(b *Builder, m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { - tblIDs, err := applyTableUpdate(b, m, diff) - if err != nil { - return nil, errors.Trace(err) - } - - return b.applyAffectedOpts(m, tblIDs, diff, diff.Type) -} - -func (b *Builder) getTableIDs(diff *model.SchemaDiff) (oldTableID, newTableID int64) { - switch diff.Type { - case model.ActionCreateSequence, model.ActionRecoverTable: - newTableID = diff.TableID - case model.ActionCreateTable: - // WARN: when support create table with foreign key in https://github.com/pingcap/tidb/pull/37148, - // create table with foreign key requires a multi-step state change(none -> write-only -> public), - // when the table's state changes from write-only to public, infoSchema need to drop the old table - // which state is write-only, otherwise, infoSchema.sortedTablesBuckets will contain 2 table both - // have the same ID, but one state is write-only, another table's state is public, it's unexpected. - // - // WARN: this change will break the compatibility if execute create table with foreign key DDL when upgrading TiDB, - // since old-version TiDB doesn't know to delete the old table. - // Since the cluster-index feature also has similar problem, we chose to prevent DDL execution during the upgrade process to avoid this issue. 
- oldTableID = diff.OldTableID - newTableID = diff.TableID - case model.ActionDropTable, model.ActionDropView, model.ActionDropSequence: - oldTableID = diff.TableID - case model.ActionTruncateTable, model.ActionCreateView, - model.ActionExchangeTablePartition, model.ActionAlterTablePartitioning, - model.ActionRemovePartitioning: - oldTableID = diff.OldTableID - newTableID = diff.TableID - default: - oldTableID = diff.TableID - newTableID = diff.TableID - } - return -} - -func (b *Builder) updateBundleForTableUpdate(diff *model.SchemaDiff, newTableID, oldTableID int64) { - // handle placement rule cache - switch diff.Type { - case model.ActionCreateTable: - b.markTableBundleShouldUpdate(newTableID) - case model.ActionDropTable: - b.deleteBundle(b.infoSchema, oldTableID) - case model.ActionTruncateTable: - b.deleteBundle(b.infoSchema, oldTableID) - b.markTableBundleShouldUpdate(newTableID) - case model.ActionRecoverTable: - b.markTableBundleShouldUpdate(newTableID) - case model.ActionAlterTablePlacement: - b.markTableBundleShouldUpdate(newTableID) - } -} - -func dropTableForUpdate(b *Builder, newTableID, oldTableID int64, dbInfo *model.DBInfo, diff *model.SchemaDiff) ([]int64, autoid.Allocators, error) { - tblIDs := make([]int64, 0, 2) - var keptAllocs autoid.Allocators - // We try to reuse the old allocator, so the cached auto ID can be reused. - if tableIDIsValid(oldTableID) { - if oldTableID == newTableID && - // For rename table, keep the old alloc. - - // For repairing table in TiDB cluster, given 2 normal node and 1 repair node. - // For normal node's information schema, repaired table is existed. - // For repair node's information schema, repaired table is filtered (couldn't find it in `is`). - // So here skip to reserve the allocators when repairing table. - diff.Type != model.ActionRepairTable && - // Alter sequence will change the sequence info in the allocator, so the old allocator is not valid any more. - diff.Type != model.ActionAlterSequence { - // TODO: Check how this would work with ADD/REMOVE Partitioning, - // which may have AutoID not connected to tableID - // TODO: can there be _tidb_rowid AutoID per partition? - oldAllocs, _ := allocByID(b, oldTableID) - keptAllocs = getKeptAllocators(diff, oldAllocs) - } - - tmpIDs := tblIDs - if (diff.Type == model.ActionRenameTable || diff.Type == model.ActionRenameTables) && diff.OldSchemaID != diff.SchemaID { - oldDBInfo, ok := oldSchemaInfo(b, diff) - if !ok { - return nil, keptAllocs, ErrDatabaseNotExists.GenWithStackByArgs( - fmt.Sprintf("(Schema ID %d)", diff.OldSchemaID), - ) - } - tmpIDs = applyDropTable(b, diff, oldDBInfo, oldTableID, tmpIDs) - } else { - tmpIDs = applyDropTable(b, diff, dbInfo, oldTableID, tmpIDs) - } - - if oldTableID != newTableID { - // Update tblIDs only when oldTableID != newTableID because applyCreateTable() also updates tblIDs. 
- tblIDs = tmpIDs - } - } - return tblIDs, keptAllocs, nil -} - -func (b *Builder) applyTableUpdate(m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { - roDBInfo, ok := b.infoSchema.SchemaByID(diff.SchemaID) - if !ok { - return nil, ErrDatabaseNotExists.GenWithStackByArgs( - fmt.Sprintf("(Schema ID %d)", diff.SchemaID), - ) - } - dbInfo := b.getSchemaAndCopyIfNecessary(roDBInfo.Name.L) - oldTableID, newTableID := b.getTableIDs(diff) - b.updateBundleForTableUpdate(diff, newTableID, oldTableID) - b.copySortedTables(oldTableID, newTableID) - - tblIDs, allocs, err := dropTableForUpdate(b, newTableID, oldTableID, dbInfo, diff) - if err != nil { - return nil, err - } - - if tableIDIsValid(newTableID) { - // All types except DropTableOrView. - var err error - tblIDs, err = applyCreateTable(b, m, dbInfo, newTableID, allocs, diff.Type, tblIDs, diff.Version) - if err != nil { - return nil, errors.Trace(err) - } - } - return tblIDs, nil -} - -// getKeptAllocators get allocators that is not changed by the DDL. -func getKeptAllocators(diff *model.SchemaDiff, oldAllocs autoid.Allocators) autoid.Allocators { - var autoIDChanged, autoRandomChanged bool - switch diff.Type { - case model.ActionRebaseAutoID, model.ActionModifyTableAutoIdCache: - autoIDChanged = true - case model.ActionRebaseAutoRandomBase: - autoRandomChanged = true - case model.ActionMultiSchemaChange: - for _, t := range diff.SubActionTypes { - switch t { - case model.ActionRebaseAutoID, model.ActionModifyTableAutoIdCache: - autoIDChanged = true - case model.ActionRebaseAutoRandomBase: - autoRandomChanged = true - } - } - } - var newAllocs autoid.Allocators - switch { - case autoIDChanged: - // Only drop auto-increment allocator. - newAllocs = oldAllocs.Filter(func(a autoid.Allocator) bool { - tp := a.GetType() - return tp != autoid.RowIDAllocType && tp != autoid.AutoIncrementType - }) - case autoRandomChanged: - // Only drop auto-random allocator. - newAllocs = oldAllocs.Filter(func(a autoid.Allocator) bool { - tp := a.GetType() - return tp != autoid.AutoRandomType - }) - default: - // Keep all allocators. - newAllocs = oldAllocs - } - return newAllocs -} - -func appendAffectedIDs(affected []int64, tblInfo *model.TableInfo) []int64 { - affected = append(affected, tblInfo.ID) - if pi := tblInfo.GetPartitionInfo(); pi != nil { - for _, def := range pi.Definitions { - affected = append(affected, def.ID) - } - } - return affected -} - -func (b *Builder) applyCreateSchema(m *meta.Meta, diff *model.SchemaDiff) error { - di, err := m.GetDatabase(diff.SchemaID) - if err != nil { - return errors.Trace(err) - } - if di == nil { - // When we apply an old schema diff, the database may has been dropped already, so we need to fall back to - // full load. - return ErrDatabaseNotExists.GenWithStackByArgs( - fmt.Sprintf("(Schema ID %d)", diff.SchemaID), - ) - } - b.addDB(diff.Version, di, &schemaTables{dbInfo: di, tables: make(map[string]table.Table)}) - return nil -} - -func (b *Builder) applyModifySchemaCharsetAndCollate(m *meta.Meta, diff *model.SchemaDiff) error { - di, err := m.GetDatabase(diff.SchemaID) - if err != nil { - return errors.Trace(err) - } - if di == nil { - // This should never happen. 
- return ErrDatabaseNotExists.GenWithStackByArgs( - fmt.Sprintf("(Schema ID %d)", diff.SchemaID), - ) - } - newDbInfo := b.getSchemaAndCopyIfNecessary(di.Name.L) - newDbInfo.Charset = di.Charset - newDbInfo.Collate = di.Collate - return nil -} - -func (b *Builder) applyModifySchemaDefaultPlacement(m *meta.Meta, diff *model.SchemaDiff) error { - di, err := m.GetDatabase(diff.SchemaID) - if err != nil { - return errors.Trace(err) - } - if di == nil { - // This should never happen. - return ErrDatabaseNotExists.GenWithStackByArgs( - fmt.Sprintf("(Schema ID %d)", diff.SchemaID), - ) - } - newDbInfo := b.getSchemaAndCopyIfNecessary(di.Name.L) - newDbInfo.PlacementPolicyRef = di.PlacementPolicyRef - return nil -} - -func (b *Builder) applyDropSchema(diff *model.SchemaDiff) []int64 { - di, ok := b.infoSchema.SchemaByID(diff.SchemaID) - if !ok { - return nil - } - b.infoSchema.delSchema(di) - - // Copy the sortedTables that contain the table we are going to drop. - tableIDs := make([]int64, 0, len(di.Deprecated.Tables)) - bucketIdxMap := make(map[int]struct{}, len(di.Deprecated.Tables)) - for _, tbl := range di.Deprecated.Tables { - bucketIdxMap[tableBucketIdx(tbl.ID)] = struct{}{} - // TODO: If the table ID doesn't exist. - tableIDs = appendAffectedIDs(tableIDs, tbl) - } - for bucketIdx := range bucketIdxMap { - b.copySortedTablesBucket(bucketIdx) - } - - di = di.Clone() - for _, id := range tableIDs { - b.deleteBundle(b.infoSchema, id) - b.applyDropTable(diff, di, id, nil) - } - return tableIDs -} - -func (b *Builder) applyRecoverSchema(m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { - if di, ok := b.infoSchema.SchemaByID(diff.SchemaID); ok { - return nil, ErrDatabaseExists.GenWithStackByArgs( - fmt.Sprintf("(Schema ID %d)", di.ID), - ) - } - di, err := m.GetDatabase(diff.SchemaID) - if err != nil { - return nil, errors.Trace(err) - } - b.infoSchema.addSchema(&schemaTables{ - dbInfo: di, - tables: make(map[string]table.Table, len(diff.AffectedOpts)), - }) - return applyCreateTables(b, m, diff) -} - -// copySortedTables copies sortedTables for old table and new table for later modification. 
-func (b *Builder) copySortedTables(oldTableID, newTableID int64) { - if tableIDIsValid(oldTableID) { - b.copySortedTablesBucket(tableBucketIdx(oldTableID)) - } - if tableIDIsValid(newTableID) && newTableID != oldTableID { - b.copySortedTablesBucket(tableBucketIdx(newTableID)) - } -} - -func (b *Builder) copySortedTablesBucket(bucketIdx int) { - oldSortedTables := b.infoSchema.sortedTablesBuckets[bucketIdx] - newSortedTables := make(sortedTables, len(oldSortedTables)) - copy(newSortedTables, oldSortedTables) - b.infoSchema.sortedTablesBuckets[bucketIdx] = newSortedTables -} - -func (b *Builder) updateBundleForCreateTable(tblInfo *model.TableInfo, tp model.ActionType) { - switch tp { - case model.ActionDropTablePartition: - case model.ActionTruncateTablePartition: - // ReorganizePartition handle the bundles in applyReorganizePartition - case model.ActionReorganizePartition, model.ActionRemovePartitioning, - model.ActionAlterTablePartitioning: - default: - pi := tblInfo.GetPartitionInfo() - if pi != nil { - for _, partition := range pi.Definitions { - b.markPartitionBundleShouldUpdate(partition.ID) - } - } - } -} - -func (b *Builder) buildAllocsForCreateTable(tp model.ActionType, dbInfo *model.DBInfo, tblInfo *model.TableInfo, allocs autoid.Allocators) autoid.Allocators { - if len(allocs.Allocs) != 0 { - tblVer := autoid.AllocOptionTableInfoVersion(tblInfo.Version) - switch tp { - case model.ActionRebaseAutoID, model.ActionModifyTableAutoIdCache: - idCacheOpt := autoid.CustomAutoIncCacheOption(tblInfo.AutoIdCache) - // If the allocator type might be AutoIncrementType, create both AutoIncrementType - // and RowIDAllocType allocator for it. Because auto id and row id could share the same allocator. - // Allocate auto id may route to allocate row id, if row id allocator is nil, the program panic! - for _, tp := range [2]autoid.AllocatorType{autoid.AutoIncrementType, autoid.RowIDAllocType} { - newAlloc := autoid.NewAllocator(b.Requirement, dbInfo.ID, tblInfo.ID, tblInfo.IsAutoIncColUnsigned(), tp, tblVer, idCacheOpt) - allocs = allocs.Append(newAlloc) - } - case model.ActionRebaseAutoRandomBase: - newAlloc := autoid.NewAllocator(b.Requirement, dbInfo.ID, tblInfo.ID, tblInfo.IsAutoRandomBitColUnsigned(), autoid.AutoRandomType, tblVer) - allocs = allocs.Append(newAlloc) - case model.ActionModifyColumn: - // Change column attribute from auto_increment to auto_random. - if tblInfo.ContainsAutoRandomBits() && allocs.Get(autoid.AutoRandomType) == nil { - // Remove auto_increment allocator. - allocs = allocs.Filter(func(a autoid.Allocator) bool { - return a.GetType() != autoid.AutoIncrementType && a.GetType() != autoid.RowIDAllocType - }) - newAlloc := autoid.NewAllocator(b.Requirement, dbInfo.ID, tblInfo.ID, tblInfo.IsAutoRandomBitColUnsigned(), autoid.AutoRandomType, tblVer) - allocs = allocs.Append(newAlloc) - } - } - return allocs - } - return autoid.NewAllocatorsFromTblInfo(b.Requirement, dbInfo.ID, tblInfo) -} - -func applyCreateTable(b *Builder, m *meta.Meta, dbInfo *model.DBInfo, tableID int64, allocs autoid.Allocators, tp model.ActionType, affected []int64, schemaVersion int64) ([]int64, error) { - tblInfo, err := m.GetTable(dbInfo.ID, tableID) - if err != nil { - return nil, errors.Trace(err) - } - if tblInfo == nil { - // When we apply an old schema diff, the table may has been dropped already, so we need to fall back to - // full load. 
- return nil, ErrTableNotExists.FastGenByArgs( - fmt.Sprintf("(Schema ID %d)", dbInfo.ID), - fmt.Sprintf("(Table ID %d)", tableID), - ) - } - - b.updateBundleForCreateTable(tblInfo, tp) - - if tp != model.ActionTruncateTablePartition { - affected = appendAffectedIDs(affected, tblInfo) - } - - // Failpoint check whether tableInfo should be added to repairInfo. - // Typically used in repair table test to load mock `bad` tableInfo into repairInfo. - failpoint.Inject("repairFetchCreateTable", func(val failpoint.Value) { - if val.(bool) { - if domainutil.RepairInfo.InRepairMode() && tp != model.ActionRepairTable && domainutil.RepairInfo.CheckAndFetchRepairedTable(dbInfo, tblInfo) { - failpoint.Return(nil, nil) - } - } - }) - - ConvertCharsetCollateToLowerCaseIfNeed(tblInfo) - ConvertOldVersionUTF8ToUTF8MB4IfNeed(tblInfo) - - allocs = b.buildAllocsForCreateTable(tp, dbInfo, tblInfo, allocs) - - tbl, err := tableFromMeta(allocs, b.factory, tblInfo) - if err != nil { - return nil, errors.Trace(err) - } - - b.infoSchema.addReferredForeignKeys(dbInfo.Name, tblInfo) - - if !b.enableV2 { - tableNames := b.infoSchema.schemaMap[dbInfo.Name.L] - tableNames.tables[tblInfo.Name.L] = tbl - } - b.addTable(schemaVersion, dbInfo, tblInfo, tbl) - - bucketIdx := tableBucketIdx(tableID) - slices.SortFunc(b.infoSchema.sortedTablesBuckets[bucketIdx], func(i, j table.Table) int { - return cmp.Compare(i.Meta().ID, j.Meta().ID) - }) - - if tblInfo.TempTableType != model.TempTableNone { - b.addTemporaryTable(tableID) - } - - newTbl, ok := b.infoSchema.TableByID(tableID) - if ok { - dbInfo.Deprecated.Tables = append(dbInfo.Deprecated.Tables, newTbl.Meta()) - } - return affected, nil -} - -// ConvertCharsetCollateToLowerCaseIfNeed convert the charset / collation of table and its columns to lower case, -// if the table's version is prior to TableInfoVersion3. -func ConvertCharsetCollateToLowerCaseIfNeed(tbInfo *model.TableInfo) { - if tbInfo.Version >= model.TableInfoVersion3 { - return - } - tbInfo.Charset = strings.ToLower(tbInfo.Charset) - tbInfo.Collate = strings.ToLower(tbInfo.Collate) - for _, col := range tbInfo.Columns { - col.SetCharset(strings.ToLower(col.GetCharset())) - col.SetCollate(strings.ToLower(col.GetCollate())) - } -} - -// ConvertOldVersionUTF8ToUTF8MB4IfNeed convert old version UTF8 to UTF8MB4 if config.TreatOldVersionUTF8AsUTF8MB4 is enable. -func ConvertOldVersionUTF8ToUTF8MB4IfNeed(tbInfo *model.TableInfo) { - if tbInfo.Version >= model.TableInfoVersion2 || !config.GetGlobalConfig().TreatOldVersionUTF8AsUTF8MB4 { - return - } - if tbInfo.Charset == charset.CharsetUTF8 { - tbInfo.Charset = charset.CharsetUTF8MB4 - tbInfo.Collate = charset.CollationUTF8MB4 - } - for _, col := range tbInfo.Columns { - if col.Version < model.ColumnInfoVersion2 && col.GetCharset() == charset.CharsetUTF8 { - col.SetCharset(charset.CharsetUTF8MB4) - col.SetCollate(charset.CollationUTF8MB4) - } - } -} - -func (b *Builder) applyDropTable(diff *model.SchemaDiff, dbInfo *model.DBInfo, tableID int64, affected []int64) []int64 { - bucketIdx := tableBucketIdx(tableID) - sortedTbls := b.infoSchema.sortedTablesBuckets[bucketIdx] - idx := sortedTbls.searchTable(tableID) - if idx == -1 { - return affected - } - if tableNames, ok := b.infoSchema.schemaMap[dbInfo.Name.L]; ok { - tblInfo := sortedTbls[idx].Meta() - delete(tableNames.tables, tblInfo.Name.L) - affected = appendAffectedIDs(affected, tblInfo) - } - // Remove the table in sorted table slice. 
- b.infoSchema.sortedTablesBuckets[bucketIdx] = append(sortedTbls[0:idx], sortedTbls[idx+1:]...) - - // Remove the table in temporaryTables - if b.infoSchema.temporaryTableIDs != nil { - delete(b.infoSchema.temporaryTableIDs, tableID) - } - // The old DBInfo still holds a reference to old table info, we need to remove it. - b.deleteReferredForeignKeys(dbInfo, tableID) - return affected -} - -func (b *Builder) deleteReferredForeignKeys(dbInfo *model.DBInfo, tableID int64) { - tables := dbInfo.Deprecated.Tables - for i, tblInfo := range tables { - if tblInfo.ID == tableID { - if i == len(tables)-1 { - tables = tables[:i] - } else { - tables = append(tables[:i], tables[i+1:]...) - } - b.infoSchema.deleteReferredForeignKeys(dbInfo.Name, tblInfo) - break - } - } - dbInfo.Deprecated.Tables = tables -} - -// Build builds and returns the built infoschema. -func (b *Builder) Build(schemaTS uint64) InfoSchema { - if b.enableV2 { - b.infoschemaV2.ts = schemaTS - updateInfoSchemaBundles(b) - return &b.infoschemaV2 - } - updateInfoSchemaBundles(b) - return b.infoSchema -} - -// InitWithOldInfoSchema initializes an empty new InfoSchema by copies all the data from old InfoSchema. -func (b *Builder) InitWithOldInfoSchema(oldSchema InfoSchema) error { - // Do not mix infoschema v1 and infoschema v2 building, this can simplify the logic. - // If we want to build infoschema v2, but the old infoschema is v1, just return error to trigger a full load. - isV2, _ := IsV2(oldSchema) - if b.enableV2 != isV2 { - return errors.Errorf("builder's (v2=%v) infoschema mismatch, return error to trigger full reload", b.enableV2) - } - - if schemaV2, ok := oldSchema.(*infoschemaV2); ok { - b.infoschemaV2.ts = schemaV2.ts - } - oldIS := oldSchema.base() - b.initBundleInfoBuilder() - b.infoSchema.schemaMetaVersion = oldIS.schemaMetaVersion - b.infoSchema.schemaMap = maps.Clone(oldIS.schemaMap) - b.infoSchema.schemaID2Name = maps.Clone(oldIS.schemaID2Name) - b.infoSchema.ruleBundleMap = maps.Clone(oldIS.ruleBundleMap) - b.infoSchema.policyMap = oldIS.ClonePlacementPolicies() - b.infoSchema.resourceGroupMap = oldIS.CloneResourceGroups() - b.infoSchema.temporaryTableIDs = maps.Clone(oldIS.temporaryTableIDs) - b.infoSchema.referredForeignKeyMap = maps.Clone(oldIS.referredForeignKeyMap) - - copy(b.infoSchema.sortedTablesBuckets, oldIS.sortedTablesBuckets) - return nil -} - -// getSchemaAndCopyIfNecessary creates a new schemaTables instance when a table in the database has changed. -// It also does modifications on the new one because old schemaTables must be read-only. -// And it will only copy the changed database once in the lifespan of the Builder. -// NOTE: please make sure the dbName is in lowercase. -func (b *Builder) getSchemaAndCopyIfNecessary(dbName string) *model.DBInfo { - if !b.dirtyDB[dbName] { - b.dirtyDB[dbName] = true - oldSchemaTables := b.infoSchema.schemaMap[dbName] - newSchemaTables := &schemaTables{ - dbInfo: oldSchemaTables.dbInfo.Copy(), - tables: maps.Clone(oldSchemaTables.tables), - } - b.infoSchema.addSchema(newSchemaTables) - return newSchemaTables.dbInfo - } - return b.infoSchema.schemaMap[dbName].dbInfo -} - -func (b *Builder) initVirtualTables(schemaVersion int64) error { - // Initialize virtual tables. 
- for _, driver := range drivers { - err := b.createSchemaTablesForDB(driver.DBInfo, driver.TableFromMeta, schemaVersion) - if err != nil { - return errors.Trace(err) - } - } - return nil -} - -func (b *Builder) sortAllTablesByID() { - // Sort all tables by `ID` - for _, v := range b.infoSchema.sortedTablesBuckets { - slices.SortFunc(v, func(a, b table.Table) int { - return cmp.Compare(a.Meta().ID, b.Meta().ID) - }) - } -} - -// InitWithDBInfos initializes an empty new InfoSchema with a slice of DBInfo, all placement rules, and schema version. -func (b *Builder) InitWithDBInfos(dbInfos []*model.DBInfo, policies []*model.PolicyInfo, resourceGroups []*model.ResourceGroupInfo, schemaVersion int64) error { - info := b.infoSchema - info.schemaMetaVersion = schemaVersion - - b.initBundleInfoBuilder() - - b.initMisc(dbInfos, policies, resourceGroups) - - if b.enableV2 { - // We must not clear the historial versions like b.infoData = NewData(), because losing - // the historial versions would cause applyDiff get db not exist error and fail, then - // infoschema reloading retries with full load every time. - // See https://github.com/pingcap/tidb/issues/53442 - // - // We must reset it, otherwise the stale tables remain and cause bugs later. - // For example, schema version 59: - // 107: t1 - // 112: t2 (partitions p0=113, p1=114, p2=115) - // operation: alter table t2 exchange partition p0 with table t1 - // schema version 60 if we do not reset: - // 107: t1 <- stale - // 112: t2 (partition p0=107, p1=114, p2=115) - // 113: t1 - // See https://github.com/pingcap/tidb/issues/54796 - b.infoData.resetBeforeFullLoad(schemaVersion) - } - - for _, di := range dbInfos { - err := b.createSchemaTablesForDB(di, tableFromMeta, schemaVersion) - if err != nil { - return errors.Trace(err) - } - } - - err := b.initVirtualTables(schemaVersion) - if err != nil { - return err - } - - b.sortAllTablesByID() - - return nil -} - -func tableFromMeta(alloc autoid.Allocators, factory func() (pools.Resource, error), tblInfo *model.TableInfo) (table.Table, error) { - ret, err := tables.TableFromMeta(alloc, tblInfo) - if err != nil { - return nil, errors.Trace(err) - } - if t, ok := ret.(table.CachedTable); ok { - var tmp pools.Resource - tmp, err = factory() - if err != nil { - return nil, errors.Trace(err) - } - - err = t.Init(tmp.(sessionctx.Context).GetSQLExecutor()) - if err != nil { - return nil, errors.Trace(err) - } - } - return ret, nil -} - -type tableFromMetaFunc func(alloc autoid.Allocators, factory func() (pools.Resource, error), tblInfo *model.TableInfo) (table.Table, error) - -func (b *Builder) createSchemaTablesForDB(di *model.DBInfo, tableFromMeta tableFromMetaFunc, schemaVersion int64) error { - schTbls := &schemaTables{ - dbInfo: di, - tables: make(map[string]table.Table, len(di.Deprecated.Tables)), - } - for _, t := range di.Deprecated.Tables { - allocs := autoid.NewAllocatorsFromTblInfo(b.Requirement, di.ID, t) - var tbl table.Table - tbl, err := tableFromMeta(allocs, b.factory, t) - if err != nil { - return errors.Wrap(err, fmt.Sprintf("Build table `%s`.`%s` schema failed", di.Name.O, t.Name.O)) - } - - schTbls.tables[t.Name.L] = tbl - b.addTable(schemaVersion, di, t, tbl) - if len(di.TableName2ID) > 0 { - delete(di.TableName2ID, t.Name.L) - } - - if tblInfo := tbl.Meta(); tblInfo.TempTableType != model.TempTableNone { - b.addTemporaryTable(tblInfo.ID) - } - } - // Add the rest name to ID mappings. 
- if b.enableV2 { - for name, id := range di.TableName2ID { - item := tableItem{ - dbName: di.Name.L, - dbID: di.ID, - tableName: name, - tableID: id, - schemaVersion: schemaVersion, - } - b.infoData.byID.Set(item) - b.infoData.byName.Set(item) - } - } - b.addDB(schemaVersion, di, schTbls) - - return nil -} - -func (b *Builder) addDB(schemaVersion int64, di *model.DBInfo, schTbls *schemaTables) { - if b.enableV2 { - if IsSpecialDB(di.Name.L) { - b.infoData.addSpecialDB(di, schTbls) - } else { - b.infoData.addDB(schemaVersion, di) - } - } else { - b.infoSchema.addSchema(schTbls) - } -} - -func (b *Builder) addTable(schemaVersion int64, di *model.DBInfo, tblInfo *model.TableInfo, tbl table.Table) { - if b.enableV2 { - b.infoData.add(tableItem{ - dbName: di.Name.L, - dbID: di.ID, - tableName: tblInfo.Name.L, - tableID: tblInfo.ID, - schemaVersion: schemaVersion, - }, tbl) - } else { - sortedTbls := b.infoSchema.sortedTablesBuckets[tableBucketIdx(tblInfo.ID)] - b.infoSchema.sortedTablesBuckets[tableBucketIdx(tblInfo.ID)] = append(sortedTbls, tbl) - } -} - -type virtualTableDriver struct { - *model.DBInfo - TableFromMeta tableFromMetaFunc -} - -var drivers []*virtualTableDriver - -// RegisterVirtualTable register virtual tables to the builder. -func RegisterVirtualTable(dbInfo *model.DBInfo, tableFromMeta tableFromMetaFunc) { - drivers = append(drivers, &virtualTableDriver{dbInfo, tableFromMeta}) -} - -// NewBuilder creates a new Builder with a Handle. -func NewBuilder(r autoid.Requirement, factory func() (pools.Resource, error), infoData *Data, useV2 bool) *Builder { - builder := &Builder{ - Requirement: r, - infoschemaV2: NewInfoSchemaV2(r, factory, infoData), - dirtyDB: make(map[string]bool), - factory: factory, - infoData: infoData, - enableV2: useV2, - } - schemaCacheSize := variable.SchemaCacheSize.Load() - if schemaCacheSize > 0 { - infoData.tableCache.SetCapacity(schemaCacheSize) - } - return builder -} - -// WithStore attaches the given store to builder. -func (b *Builder) WithStore(s kv.Storage) *Builder { - b.store = s - return b -} - -func tableBucketIdx(tableID int64) int { - intest.Assert(tableID > 0) - return int(tableID % bucketCount) -} - -func tableIDIsValid(tableID int64) bool { - return tableID > 0 -} diff --git a/pkg/infoschema/infoschema_v2.go b/pkg/infoschema/infoschema_v2.go index 87f9725634901..c1ffe26d2197d 100644 --- a/pkg/infoschema/infoschema_v2.go +++ b/pkg/infoschema/infoschema_v2.go @@ -1012,9 +1012,9 @@ func (is *infoschemaV2) SchemaByID(id int64) (*model.DBInfo, bool) { func (is *infoschemaV2) loadTableInfo(ctx context.Context, tblID, dbID int64, ts uint64, schemaVersion int64) (table.Table, error) { defer tracing.StartRegion(ctx, "infoschema.loadTableInfo").End() - if _, _err_ := failpoint.Eval(_curpkg_("mockLoadTableInfoError")); _err_ == nil { - return nil, errors.New("mockLoadTableInfoError") - } + failpoint.Inject("mockLoadTableInfoError", func(_ failpoint.Value) { + failpoint.Return(nil, errors.New("mockLoadTableInfoError")) + }) // Try to avoid repeated concurrency loading. res, err, _ := loadTableSF.Do(fmt.Sprintf("%d-%d-%d", dbID, tblID, schemaVersion), func() (any, error) { retry: diff --git a/pkg/infoschema/infoschema_v2.go__failpoint_stash__ b/pkg/infoschema/infoschema_v2.go__failpoint_stash__ deleted file mode 100644 index c1ffe26d2197d..0000000000000 --- a/pkg/infoschema/infoschema_v2.go__failpoint_stash__ +++ /dev/null @@ -1,1456 +0,0 @@ -// Copyright 2024 PingCAP, Inc. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package infoschema - -import ( - "context" - "fmt" - "math" - "strings" - "sync" - "time" - - "github.com/ngaut/pools" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/ddl/placement" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/meta/autoid" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/size" - "github.com/pingcap/tidb/pkg/util/tracing" - "github.com/tidwall/btree" - "golang.org/x/sync/singleflight" -) - -// tableItem is the btree item sorted by name or by id. -type tableItem struct { - dbName string - dbID int64 - tableName string - tableID int64 - schemaVersion int64 - tomb bool -} - -type schemaItem struct { - schemaVersion int64 - dbInfo *model.DBInfo - tomb bool -} - -type schemaIDName struct { - schemaVersion int64 - id int64 - name string - tomb bool -} - -func (si *schemaItem) Name() string { - return si.dbInfo.Name.L -} - -// versionAndTimestamp is the tuple of schema version and timestamp. -type versionAndTimestamp struct { - schemaVersion int64 - timestamp uint64 -} - -// Data is the core data struct of infoschema V2. -type Data struct { - // For the TableByName API, sorted by {dbName, tableName, schemaVersion} => tableID - // - // If the schema version +1 but a specific table does not change, the old record is - // kept and no new {dbName, tableName, schemaVersion+1} => tableID record been added. - // - // It means as long as we can find an item in it, the item is available, even through the - // schema version maybe smaller than required. - byName *btree.BTreeG[tableItem] - - // For the TableByID API, sorted by {tableID, schemaVersion} => dbID - // To reload model.TableInfo, we need both table ID and database ID for meta kv API. - // It provides the tableID => databaseID mapping. - // This mapping MUST be synced with byName. - byID *btree.BTreeG[tableItem] - - // For the SchemaByName API, sorted by {dbName, schemaVersion} => model.DBInfo - // Stores the full data in memory. - schemaMap *btree.BTreeG[schemaItem] - - // For the SchemaByID API, sorted by {id, schemaVersion} - // Stores only id, name and schemaVersion in memory. - schemaID2Name *btree.BTreeG[schemaIDName] - - tableCache *Sieve[tableCacheKey, table.Table] - - // sorted by both SchemaVersion and timestamp in descending order, assume they have same order - mu struct { - sync.RWMutex - versionTimestamps []versionAndTimestamp - } - - // For information_schema/metrics_schema/performance_schema etc - specials sync.Map - - // pid2tid is used by FindTableInfoByPartitionID, it stores {partitionID, schemaVersion} => table ID - // Need full data in memory! 
- pid2tid *btree.BTreeG[partitionItem] - - // tableInfoResident stores {dbName, tableID, schemaVersion} => model.TableInfo - // It is part of the model.TableInfo data kept in memory to accelerate the list tables API. - // We observe the pattern that list table API always come with filter. - // All model.TableInfo with special attributes are here, currently the special attributes including: - // TTLInfo, TiFlashReplica - // PlacementPolicyRef, Partition might be added later, and also ForeignKeys, TableLock etc - tableInfoResident *btree.BTreeG[tableInfoItem] -} - -type tableInfoItem struct { - dbName string - tableID int64 - schemaVersion int64 - tableInfo *model.TableInfo - tomb bool -} - -type partitionItem struct { - partitionID int64 - schemaVersion int64 - tableID int64 - tomb bool -} - -func (isd *Data) getVersionByTS(ts uint64) (int64, bool) { - isd.mu.RLock() - defer isd.mu.RUnlock() - return isd.getVersionByTSNoLock(ts) -} - -func (isd *Data) getVersionByTSNoLock(ts uint64) (int64, bool) { - // search one by one instead of binary search, because the timestamp of a schema could be 0 - // this is ok because the size of h.tableCache is small (currently set to 16) - // moreover, the most likely hit element in the array is the first one in steady mode - // thus it may have better performance than binary search - for i, vt := range isd.mu.versionTimestamps { - if vt.timestamp == 0 || ts < vt.timestamp { - // is.timestamp == 0 means the schema ts is unknown, so we can't use it, then just skip it. - // ts < is.timestamp means the schema is newer than ts, so we can't use it too, just skip it to find the older one. - continue - } - // ts >= is.timestamp must be true after the above condition. - if i == 0 { - // the first element is the latest schema, so we can return it directly. - return vt.schemaVersion, true - } - if isd.mu.versionTimestamps[i-1].schemaVersion == vt.schemaVersion+1 && isd.mu.versionTimestamps[i-1].timestamp > ts { - // This first condition is to make sure the schema version is continuous. If last(cache[i-1]) schema-version is 10, - // but current(cache[i]) schema-version is not 9, then current schema is not suitable for ts. - // The second condition is to make sure the cache[i-1].timestamp > ts >= cache[i].timestamp, then the current schema is suitable for ts. - return vt.schemaVersion, true - } - // current schema is not suitable for ts, then break the loop to avoid the unnecessary search. - break - } - - return 0, false -} - -type tableCacheKey struct { - tableID int64 - schemaVersion int64 -} - -// NewData creates an infoschema V2 data struct. -func NewData() *Data { - ret := &Data{ - byID: btree.NewBTreeG[tableItem](compareByID), - byName: btree.NewBTreeG[tableItem](compareByName), - schemaMap: btree.NewBTreeG[schemaItem](compareSchemaItem), - schemaID2Name: btree.NewBTreeG[schemaIDName](compareSchemaByID), - tableCache: newSieve[tableCacheKey, table.Table](1024 * 1024 * size.MB), - pid2tid: btree.NewBTreeG[partitionItem](comparePartitionItem), - tableInfoResident: btree.NewBTreeG[tableInfoItem](compareTableInfoItem), - } - ret.tableCache.SetStatusHook(newSieveStatusHookImpl()) - return ret -} - -// CacheCapacity is exported for testing. -func (isd *Data) CacheCapacity() uint64 { - return isd.tableCache.Capacity() -} - -// SetCacheCapacity sets the cache capacity size in bytes. 
-func (isd *Data) SetCacheCapacity(capacity uint64) { - isd.tableCache.SetCapacityAndWaitEvict(capacity) -} - -func (isd *Data) add(item tableItem, tbl table.Table) { - isd.byID.Set(item) - isd.byName.Set(item) - isd.tableCache.Set(tableCacheKey{item.tableID, item.schemaVersion}, tbl) - ti := tbl.Meta() - if pi := ti.GetPartitionInfo(); pi != nil { - for _, def := range pi.Definitions { - isd.pid2tid.Set(partitionItem{def.ID, item.schemaVersion, tbl.Meta().ID, false}) - } - } - if hasSpecialAttributes(ti) { - isd.tableInfoResident.Set(tableInfoItem{ - dbName: item.dbName, - tableID: item.tableID, - schemaVersion: item.schemaVersion, - tableInfo: ti, - tomb: false}) - } -} - -func (isd *Data) addSpecialDB(di *model.DBInfo, tables *schemaTables) { - isd.specials.LoadOrStore(di.Name.L, tables) -} - -func (isd *Data) addDB(schemaVersion int64, dbInfo *model.DBInfo) { - dbInfo.Deprecated.Tables = nil - isd.schemaID2Name.Set(schemaIDName{schemaVersion: schemaVersion, id: dbInfo.ID, name: dbInfo.Name.O}) - isd.schemaMap.Set(schemaItem{schemaVersion: schemaVersion, dbInfo: dbInfo}) -} - -func (isd *Data) remove(item tableItem) { - item.tomb = true - isd.byID.Set(item) - isd.byName.Set(item) - isd.tableInfoResident.Set(tableInfoItem{ - dbName: item.dbName, - tableID: item.tableID, - schemaVersion: item.schemaVersion, - tableInfo: nil, - tomb: true}) - isd.tableCache.Remove(tableCacheKey{item.tableID, item.schemaVersion}) -} - -func (isd *Data) deleteDB(dbInfo *model.DBInfo, schemaVersion int64) { - item := schemaItem{schemaVersion: schemaVersion, dbInfo: dbInfo, tomb: true} - isd.schemaMap.Set(item) - isd.schemaID2Name.Set(schemaIDName{schemaVersion: schemaVersion, id: dbInfo.ID, name: dbInfo.Name.O, tomb: true}) -} - -// resetBeforeFullLoad is called before a full recreate operation within builder.InitWithDBInfos(). -// TODO: write a generics version to avoid repeated code. 
-func (isd *Data) resetBeforeFullLoad(schemaVersion int64) { - resetTableInfoResidentBeforeFullLoad(isd.tableInfoResident, schemaVersion) - - resetByIDBeforeFullLoad(isd.byID, schemaVersion) - resetByNameBeforeFullLoad(isd.byName, schemaVersion) - - resetSchemaMapBeforeFullLoad(isd.schemaMap, schemaVersion) - resetSchemaID2NameBeforeFullLoad(isd.schemaID2Name, schemaVersion) - - resetPID2TIDBeforeFullLoad(isd.pid2tid, schemaVersion) -} - -func resetByIDBeforeFullLoad(bt *btree.BTreeG[tableItem], schemaVersion int64) { - pivot, ok := bt.Max() - if !ok { - return - } - - batchSize := 1000 - if bt.Len() < batchSize { - batchSize = bt.Len() - } - items := make([]tableItem, 0, batchSize) - items = append(items, pivot) - for { - bt.Descend(pivot, func(item tableItem) bool { - if pivot.tableID == item.tableID { - return true // skip MVCC version - } - pivot = item - items = append(items, pivot) - return len(items) < cap(items) - }) - if len(items) == 0 { - break - } - for _, item := range items { - bt.Set(tableItem{ - dbName: item.dbName, - dbID: item.dbID, - tableName: item.tableName, - tableID: item.tableID, - schemaVersion: schemaVersion, - tomb: true, - }) - } - items = items[:0] - } -} - -func resetByNameBeforeFullLoad(bt *btree.BTreeG[tableItem], schemaVersion int64) { - pivot, ok := bt.Max() - if !ok { - return - } - - batchSize := 1000 - if bt.Len() < batchSize { - batchSize = bt.Len() - } - items := make([]tableItem, 0, batchSize) - items = append(items, pivot) - for { - bt.Descend(pivot, func(item tableItem) bool { - if pivot.dbName == item.dbName && pivot.tableName == item.tableName { - return true // skip MVCC version - } - pivot = item - items = append(items, pivot) - return len(items) < cap(items) - }) - if len(items) == 0 { - break - } - for _, item := range items { - bt.Set(tableItem{ - dbName: item.dbName, - dbID: item.dbID, - tableName: item.tableName, - tableID: item.tableID, - schemaVersion: schemaVersion, - tomb: true, - }) - } - items = items[:0] - } -} - -func resetTableInfoResidentBeforeFullLoad(bt *btree.BTreeG[tableInfoItem], schemaVersion int64) { - pivot, ok := bt.Max() - if !ok { - return - } - items := make([]tableInfoItem, 0, bt.Len()) - items = append(items, pivot) - bt.Descend(pivot, func(item tableInfoItem) bool { - if pivot.dbName == item.dbName && pivot.tableID == item.tableID { - return true // skip MVCC version - } - pivot = item - items = append(items, pivot) - return true - }) - for _, item := range items { - bt.Set(tableInfoItem{ - dbName: item.dbName, - tableID: item.tableID, - schemaVersion: schemaVersion, - tomb: true, - }) - } -} - -func resetSchemaMapBeforeFullLoad(bt *btree.BTreeG[schemaItem], schemaVersion int64) { - pivot, ok := bt.Max() - if !ok { - return - } - items := make([]schemaItem, 0, bt.Len()) - items = append(items, pivot) - bt.Descend(pivot, func(item schemaItem) bool { - if pivot.Name() == item.Name() { - return true // skip MVCC version - } - pivot = item - items = append(items, pivot) - return true - }) - for _, item := range items { - bt.Set(schemaItem{ - dbInfo: item.dbInfo, - schemaVersion: schemaVersion, - tomb: true, - }) - } -} - -func resetSchemaID2NameBeforeFullLoad(bt *btree.BTreeG[schemaIDName], schemaVersion int64) { - pivot, ok := bt.Max() - if !ok { - return - } - items := make([]schemaIDName, 0, bt.Len()) - items = append(items, pivot) - bt.Descend(pivot, func(item schemaIDName) bool { - if pivot.id == item.id { - return true // skip MVCC version - } - pivot = item - items = append(items, pivot) - return true - }) - for 
_, item := range items { - bt.Set(schemaIDName{ - id: item.id, - name: item.name, - schemaVersion: schemaVersion, - tomb: true, - }) - } -} - -func resetPID2TIDBeforeFullLoad(bt *btree.BTreeG[partitionItem], schemaVersion int64) { - pivot, ok := bt.Max() - if !ok { - return - } - - batchSize := 1000 - if bt.Len() < batchSize { - batchSize = bt.Len() - } - items := make([]partitionItem, 0, batchSize) - items = append(items, pivot) - for { - bt.Descend(pivot, func(item partitionItem) bool { - if pivot.partitionID == item.partitionID { - return true // skip MVCC version - } - pivot = item - items = append(items, pivot) - return len(items) < cap(items) - }) - if len(items) == 0 { - break - } - for _, item := range items { - bt.Set(partitionItem{ - partitionID: item.partitionID, - tableID: item.tableID, - schemaVersion: schemaVersion, - tomb: true, - }) - } - items = items[:0] - } -} - -func compareByID(a, b tableItem) bool { - if a.tableID < b.tableID { - return true - } - if a.tableID > b.tableID { - return false - } - - return a.schemaVersion < b.schemaVersion -} - -func compareByName(a, b tableItem) bool { - if a.dbName < b.dbName { - return true - } - if a.dbName > b.dbName { - return false - } - - if a.tableName < b.tableName { - return true - } - if a.tableName > b.tableName { - return false - } - - return a.schemaVersion < b.schemaVersion -} - -func compareTableInfoItem(a, b tableInfoItem) bool { - if a.dbName < b.dbName { - return true - } - if a.dbName > b.dbName { - return false - } - - if a.tableID < b.tableID { - return true - } - if a.tableID > b.tableID { - return false - } - return a.schemaVersion < b.schemaVersion -} - -func comparePartitionItem(a, b partitionItem) bool { - if a.partitionID < b.partitionID { - return true - } - if a.partitionID > b.partitionID { - return false - } - return a.schemaVersion < b.schemaVersion -} - -func compareSchemaItem(a, b schemaItem) bool { - if a.Name() < b.Name() { - return true - } - if a.Name() > b.Name() { - return false - } - return a.schemaVersion < b.schemaVersion -} - -func compareSchemaByID(a, b schemaIDName) bool { - if a.id < b.id { - return true - } - if a.id > b.id { - return false - } - return a.schemaVersion < b.schemaVersion -} - -var _ InfoSchema = &infoschemaV2{} - -type infoschemaV2 struct { - *infoSchema // in fact, we only need the infoSchemaMisc inside it, but the builder rely on it. - r autoid.Requirement - factory func() (pools.Resource, error) - ts uint64 - *Data -} - -// NewInfoSchemaV2 create infoschemaV2. -func NewInfoSchemaV2(r autoid.Requirement, factory func() (pools.Resource, error), infoData *Data) infoschemaV2 { - return infoschemaV2{ - infoSchema: newInfoSchema(), - Data: infoData, - r: r, - factory: factory, - } -} - -func search(bt *btree.BTreeG[tableItem], schemaVersion int64, end tableItem, matchFn func(a, b *tableItem) bool) (tableItem, bool) { - var ok bool - var target tableItem - // Iterate through the btree, find the query item whose schema version is the largest one (latest). - bt.Descend(end, func(item tableItem) bool { - if !matchFn(&end, &item) { - return false - } - if item.schemaVersion > schemaVersion { - // We're seaching historical snapshot, and this record is newer than us, we can't use it. - // Skip the record. - return true - } - // schema version of the items should <= query's schema version. - if !ok { // The first one found. 
- ok = true - target = item - } else { // The latest one - if item.schemaVersion > target.schemaVersion { - target = item - } - } - return true - }) - if ok && target.tomb { - // If the item is a tomb record, the table is dropped. - ok = false - } - return target, ok -} - -func (is *infoschemaV2) base() *infoSchema { - return is.infoSchema -} - -func (is *infoschemaV2) CloneAndUpdateTS(startTS uint64) *infoschemaV2 { - tmp := *is - tmp.ts = startTS - return &tmp -} - -func (is *infoschemaV2) TableByID(id int64) (val table.Table, ok bool) { - return is.tableByID(id, true) -} - -func (is *infoschemaV2) tableByID(id int64, noRefill bool) (val table.Table, ok bool) { - if !tableIDIsValid(id) { - return - } - - // Get from the cache. - key := tableCacheKey{id, is.infoSchema.schemaMetaVersion} - tbl, found := is.tableCache.Get(key) - if found && tbl != nil { - return tbl, true - } - - eq := func(a, b *tableItem) bool { return a.tableID == b.tableID } - itm, ok := search(is.byID, is.infoSchema.schemaMetaVersion, tableItem{tableID: id, schemaVersion: math.MaxInt64}, eq) - if !ok { - return nil, false - } - - if isTableVirtual(id) { - if raw, exist := is.Data.specials.Load(itm.dbName); exist { - schTbls := raw.(*schemaTables) - val, ok = schTbls.tables[itm.tableName] - return - } - return nil, false - } - // get cache with old key - oldKey := tableCacheKey{itm.tableID, itm.schemaVersion} - tbl, found = is.tableCache.Get(oldKey) - if found && tbl != nil { - if !noRefill { - is.tableCache.Set(key, tbl) - } - return tbl, true - } - - // Maybe the table is evicted? need to reload. - ret, err := is.loadTableInfo(context.Background(), id, itm.dbID, is.ts, is.infoSchema.schemaMetaVersion) - if err != nil || ret == nil { - return nil, false - } - - if !noRefill { - is.tableCache.Set(oldKey, ret) - } - return ret, true -} - -// IsSpecialDB tells whether the database is a special database. -func IsSpecialDB(dbName string) bool { - return dbName == util.InformationSchemaName.L || - dbName == util.PerformanceSchemaName.L || - dbName == util.MetricSchemaName.L -} - -// EvictTable is exported for testing only. -func (is *infoschemaV2) EvictTable(schema, tbl string) { - eq := func(a, b *tableItem) bool { return a.dbName == b.dbName && a.tableName == b.tableName } - itm, ok := search(is.byName, is.infoSchema.schemaMetaVersion, tableItem{dbName: schema, tableName: tbl, schemaVersion: math.MaxInt64}, eq) - if !ok { - return - } - is.tableCache.Remove(tableCacheKey{itm.tableID, is.infoSchema.schemaMetaVersion}) - is.tableCache.Remove(tableCacheKey{itm.tableID, itm.schemaVersion}) -} - -type tableByNameHelper struct { - end tableItem - schemaVersion int64 - found bool - res tableItem -} - -func (h *tableByNameHelper) onItem(item tableItem) bool { - if item.dbName != h.end.dbName || item.tableName != h.end.tableName { - h.found = false - return false - } - if item.schemaVersion <= h.schemaVersion { - if !item.tomb { // If the item is a tomb record, the database is dropped. 
- h.found = true - h.res = item - } - return false - } - return true -} - -func (is *infoschemaV2) TableByName(ctx context.Context, schema, tbl model.CIStr) (t table.Table, err error) { - if IsSpecialDB(schema.L) { - if raw, ok := is.specials.Load(schema.L); ok { - tbNames := raw.(*schemaTables) - if t, ok = tbNames.tables[tbl.L]; ok { - return - } - } - return nil, ErrTableNotExists.FastGenByArgs(schema, tbl) - } - - start := time.Now() - - var h tableByNameHelper - h.end = tableItem{dbName: schema.L, tableName: tbl.L, schemaVersion: math.MaxInt64} - h.schemaVersion = is.infoSchema.schemaMetaVersion - is.byName.Descend(h.end, h.onItem) - - if !h.found { - return nil, ErrTableNotExists.FastGenByArgs(schema, tbl) - } - itm := h.res - - // Get from the cache with old key - oldKey := tableCacheKey{itm.tableID, itm.schemaVersion} - res, found := is.tableCache.Get(oldKey) - if found && res != nil { - metrics.TableByNameHitDuration.Observe(float64(time.Since(start))) - return res, nil - } - - // Maybe the table is evicted? need to reload. - ret, err := is.loadTableInfo(ctx, itm.tableID, itm.dbID, is.ts, is.infoSchema.schemaMetaVersion) - if err != nil { - return nil, errors.Trace(err) - } - is.tableCache.Set(oldKey, ret) - metrics.TableByNameMissDuration.Observe(float64(time.Since(start))) - return ret, nil -} - -// TableInfoByName implements InfoSchema.TableInfoByName -func (is *infoschemaV2) TableInfoByName(schema, table model.CIStr) (*model.TableInfo, error) { - tbl, err := is.TableByName(context.Background(), schema, table) - return getTableInfo(tbl), err -} - -// TableInfoByID implements InfoSchema.TableInfoByID -func (is *infoschemaV2) TableInfoByID(id int64) (*model.TableInfo, bool) { - tbl, ok := is.TableByID(id) - return getTableInfo(tbl), ok -} - -// SchemaTableInfos implements MetaOnlyInfoSchema. -func (is *infoschemaV2) SchemaTableInfos(ctx context.Context, schema model.CIStr) ([]*model.TableInfo, error) { - if IsSpecialDB(schema.L) { - raw, ok := is.Data.specials.Load(schema.L) - if ok { - schTbls := raw.(*schemaTables) - tables := make([]table.Table, 0, len(schTbls.tables)) - for _, tbl := range schTbls.tables { - tables = append(tables, tbl) - } - return getTableInfoList(tables), nil - } - return nil, nil // something wrong? - } - -retry: - dbInfo, ok := is.SchemaByName(schema) - if !ok { - return nil, nil - } - snapshot := is.r.Store().GetSnapshot(kv.NewVersion(is.ts)) - // Using the KV timeout read feature to address the issue of potential DDL lease expiration when - // the meta region leader is slow. - snapshot.SetOption(kv.TiKVClientReadTimeout, uint64(3000)) // 3000ms. - m := meta.NewSnapshotMeta(snapshot) - tblInfos, err := m.ListTables(dbInfo.ID) - if err != nil { - if meta.ErrDBNotExists.Equal(err) { - return nil, nil - } - // Flashback statement could cause such kind of error. - // In theory that error should be handled in the lower layer, like client-go. - // But it's not done, so we retry here. - if strings.Contains(err.Error(), "in flashback progress") { - select { - case <-time.After(200 * time.Millisecond): - case <-ctx.Done(): - return nil, ctx.Err() - } - goto retry - } - return nil, errors.Trace(err) - } - return tblInfos, nil -} - -// SchemaSimpleTableInfos implements MetaOnlyInfoSchema. 
-func (is *infoschemaV2) SchemaSimpleTableInfos(ctx context.Context, schema model.CIStr) ([]*model.TableNameInfo, error) { - if IsSpecialDB(schema.L) { - raw, ok := is.Data.specials.Load(schema.L) - if ok { - schTbls := raw.(*schemaTables) - ret := make([]*model.TableNameInfo, 0, len(schTbls.tables)) - for _, tbl := range schTbls.tables { - ret = append(ret, &model.TableNameInfo{ - ID: tbl.Meta().ID, - Name: tbl.Meta().Name, - }) - } - return ret, nil - } - return nil, nil // something wrong? - } - - // Ascend is much more difficult than Descend. - // So the data is taken out first and then dedup in Descend order. - var tableItems []tableItem - is.byName.Ascend(tableItem{dbName: schema.L}, func(item tableItem) bool { - if item.dbName != schema.L { - return false - } - if is.infoSchema.schemaMetaVersion >= item.schemaVersion { - tableItems = append(tableItems, item) - } - return true - }) - if len(tableItems) == 0 { - return nil, nil - } - tblInfos := make([]*model.TableNameInfo, 0, len(tableItems)) - var curr *tableItem - for i := len(tableItems) - 1; i >= 0; i-- { - item := &tableItems[i] - if curr == nil || curr.tableName != tableItems[i].tableName { - curr = item - if !item.tomb { - tblInfos = append(tblInfos, &model.TableNameInfo{ - ID: item.tableID, - Name: model.NewCIStr(item.tableName), - }) - } - } - } - return tblInfos, nil -} - -// FindTableInfoByPartitionID implements InfoSchema.FindTableInfoByPartitionID -func (is *infoschemaV2) FindTableInfoByPartitionID( - partitionID int64, -) (*model.TableInfo, *model.DBInfo, *model.PartitionDefinition) { - tbl, db, partDef := is.FindTableByPartitionID(partitionID) - return getTableInfo(tbl), db, partDef -} - -func (is *infoschemaV2) SchemaByName(schema model.CIStr) (val *model.DBInfo, ok bool) { - if IsSpecialDB(schema.L) { - raw, ok := is.Data.specials.Load(schema.L) - if !ok { - return nil, false - } - schTbls, ok := raw.(*schemaTables) - return schTbls.dbInfo, ok - } - - var dbInfo model.DBInfo - dbInfo.Name = schema - is.Data.schemaMap.Descend(schemaItem{ - dbInfo: &dbInfo, - schemaVersion: math.MaxInt64, - }, func(item schemaItem) bool { - if item.Name() != schema.L { - ok = false - return false - } - if item.schemaVersion <= is.infoSchema.schemaMetaVersion { - if !item.tomb { // If the item is a tomb record, the database is dropped. - ok = true - val = item.dbInfo - } - return false - } - return true - }) - return -} - -func (is *infoschemaV2) allSchemas(visit func(*model.DBInfo)) { - var last *model.DBInfo - is.Data.schemaMap.Reverse(func(item schemaItem) bool { - if item.schemaVersion > is.infoSchema.schemaMetaVersion { - // Skip the versions that we are not looking for. - return true - } - - // Dedup the same db record of different versions. 
- if last != nil && last.Name == item.dbInfo.Name { - return true - } - last = item.dbInfo - - if !item.tomb { - visit(item.dbInfo) - } - return true - }) - is.Data.specials.Range(func(key, value any) bool { - sc := value.(*schemaTables) - visit(sc.dbInfo) - return true - }) -} - -func (is *infoschemaV2) AllSchemas() (schemas []*model.DBInfo) { - is.allSchemas(func(di *model.DBInfo) { - schemas = append(schemas, di) - }) - return -} - -func (is *infoschemaV2) AllSchemaNames() []model.CIStr { - rs := make([]model.CIStr, 0, is.Data.schemaMap.Len()) - is.allSchemas(func(di *model.DBInfo) { - rs = append(rs, di.Name) - }) - return rs -} - -func (is *infoschemaV2) SchemaExists(schema model.CIStr) bool { - _, ok := is.SchemaByName(schema) - return ok -} - -func (is *infoschemaV2) FindTableByPartitionID(partitionID int64) (table.Table, *model.DBInfo, *model.PartitionDefinition) { - var ok bool - var pi partitionItem - is.pid2tid.Descend(partitionItem{partitionID: partitionID, schemaVersion: math.MaxInt64}, - func(item partitionItem) bool { - if item.partitionID != partitionID { - return false - } - if item.schemaVersion > is.infoSchema.schemaMetaVersion { - // Skip the record. - return true - } - if item.schemaVersion <= is.infoSchema.schemaMetaVersion { - ok = !item.tomb - pi = item - return false - } - return true - }) - if !ok { - return nil, nil, nil - } - - tbl, ok := is.TableByID(pi.tableID) - if !ok { - // something wrong? - return nil, nil, nil - } - - dbID := tbl.Meta().DBID - dbInfo, ok := is.SchemaByID(dbID) - if !ok { - // something wrong? - return nil, nil, nil - } - - partInfo := tbl.Meta().GetPartitionInfo() - var def *model.PartitionDefinition - for i := 0; i < len(partInfo.Definitions); i++ { - pdef := &partInfo.Definitions[i] - if pdef.ID == partitionID { - def = pdef - break - } - } - - return tbl, dbInfo, def -} - -func (is *infoschemaV2) TableExists(schema, table model.CIStr) bool { - _, err := is.TableByName(context.Background(), schema, table) - return err == nil -} - -func (is *infoschemaV2) SchemaByID(id int64) (*model.DBInfo, bool) { - if isTableVirtual(id) { - var st *schemaTables - is.Data.specials.Range(func(key, value any) bool { - tmp := value.(*schemaTables) - if tmp.dbInfo.ID == id { - st = tmp - return false - } - return true - }) - if st == nil { - return nil, false - } - return st.dbInfo, true - } - var ok bool - var name string - is.Data.schemaID2Name.Descend(schemaIDName{ - id: id, - schemaVersion: math.MaxInt64, - }, func(item schemaIDName) bool { - if item.id != id { - ok = false - return false - } - if item.schemaVersion <= is.infoSchema.schemaMetaVersion { - if !item.tomb { // If the item is a tomb record, the database is dropped. - ok = true - name = item.name - } - return false - } - return true - }) - if !ok { - return nil, false - } - return is.SchemaByName(model.NewCIStr(name)) -} - -func (is *infoschemaV2) loadTableInfo(ctx context.Context, tblID, dbID int64, ts uint64, schemaVersion int64) (table.Table, error) { - defer tracing.StartRegion(ctx, "infoschema.loadTableInfo").End() - failpoint.Inject("mockLoadTableInfoError", func(_ failpoint.Value) { - failpoint.Return(nil, errors.New("mockLoadTableInfoError")) - }) - // Try to avoid repeated concurrency loading. 
- res, err, _ := loadTableSF.Do(fmt.Sprintf("%d-%d-%d", dbID, tblID, schemaVersion), func() (any, error) { - retry: - snapshot := is.r.Store().GetSnapshot(kv.NewVersion(ts)) - // Using the KV timeout read feature to address the issue of potential DDL lease expiration when - // the meta region leader is slow. - snapshot.SetOption(kv.TiKVClientReadTimeout, uint64(3000)) // 3000ms. - m := meta.NewSnapshotMeta(snapshot) - - tblInfo, err := m.GetTable(dbID, tblID) - if err != nil { - // Flashback statement could cause such kind of error. - // In theory that error should be handled in the lower layer, like client-go. - // But it's not done, so we retry here. - if strings.Contains(err.Error(), "in flashback progress") { - time.Sleep(200 * time.Millisecond) - goto retry - } - - // TODO load table panic!!! - panic(err) - } - - // table removed. - if tblInfo == nil { - return nil, errors.Trace(ErrTableNotExists.FastGenByArgs( - fmt.Sprintf("(Schema ID %d)", dbID), - fmt.Sprintf("(Table ID %d)", tblID), - )) - } - - ConvertCharsetCollateToLowerCaseIfNeed(tblInfo) - ConvertOldVersionUTF8ToUTF8MB4IfNeed(tblInfo) - allocs := autoid.NewAllocatorsFromTblInfo(is.r, dbID, tblInfo) - ret, err := tableFromMeta(allocs, is.factory, tblInfo) - if err != nil { - return nil, errors.Trace(err) - } - return ret, err - }) - - if err != nil { - return nil, errors.Trace(err) - } - if res == nil { - return nil, errors.Trace(ErrTableNotExists.FastGenByArgs( - fmt.Sprintf("(Schema ID %d)", dbID), - fmt.Sprintf("(Table ID %d)", tblID), - )) - } - return res.(table.Table), nil -} - -var loadTableSF = &singleflight.Group{} - -func isTableVirtual(id int64) bool { - // some kind of magic number... - // we use special ids for tables in INFORMATION_SCHEMA/PERFORMANCE_SCHEMA/METRICS_SCHEMA - // See meta/autoid/autoid.go for those definitions. - return (id & autoid.SystemSchemaIDFlag) > 0 -} - -// IsV2 tells whether an InfoSchema is v2 or not. 
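loadTableInfo above funnels concurrent cache misses for the same (schema, table, version) key through loadTableSF, so only one goroutine reads the meta snapshot while the others wait for and share its result. A minimal sketch of that singleflight pattern, with a made-up key and loader standing in for the real ones:

    package main

    import (
    	"fmt"
    	"sync"
    	"sync/atomic"

    	"golang.org/x/sync/singleflight"
    )

    var group singleflight.Group

    // loadOnce collapses concurrent calls that share the same key: one caller
    // runs load, the rest block on the in-flight call and receive its result.
    func loadOnce(key string, load func() (string, error)) (string, bool, error) {
    	v, err, shared := group.Do(key, func() (any, error) {
    		return load()
    	})
    	if err != nil {
    		return "", shared, err
    	}
    	return v.(string), shared, nil
    }

    func main() {
    	var loaderCalls atomic.Int32
    	var wg sync.WaitGroup
    	for i := 0; i < 4; i++ {
    		wg.Add(1)
    		go func() {
    			defer wg.Done()
    			val, shared, _ := loadOnce("42-7-100", func() (string, error) {
    				loaderCalls.Add(1)
    				return "table info", nil
    			})
    			fmt.Println(val, "shared:", shared)
    		}()
    	}
    	wg.Wait()
    	// Depending on timing, several of the four calls share a single load.
    	fmt.Println("loader calls:", loaderCalls.Load())
    }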
-func IsV2(is InfoSchema) (bool, *infoschemaV2) { - ret, ok := is.(*infoschemaV2) - return ok, ret -} - -func applyTableUpdate(b *Builder, m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { - if b.enableV2 { - return b.applyTableUpdateV2(m, diff) - } - return b.applyTableUpdate(m, diff) -} - -func applyCreateSchema(b *Builder, m *meta.Meta, diff *model.SchemaDiff) error { - return b.applyCreateSchema(m, diff) -} - -func applyDropSchema(b *Builder, diff *model.SchemaDiff) []int64 { - if b.enableV2 { - return b.applyDropSchemaV2(diff) - } - return b.applyDropSchema(diff) -} - -func applyRecoverSchema(b *Builder, m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { - if diff.ReadTableFromMeta { - // recover tables under the database and set them to diff.AffectedOpts - s := b.store.GetSnapshot(kv.MaxVersion) - recoverMeta := meta.NewSnapshotMeta(s) - tables, err := recoverMeta.ListSimpleTables(diff.SchemaID) - if err != nil { - return nil, err - } - diff.AffectedOpts = make([]*model.AffectedOption, 0, len(tables)) - for _, t := range tables { - diff.AffectedOpts = append(diff.AffectedOpts, &model.AffectedOption{ - SchemaID: diff.SchemaID, - OldSchemaID: diff.SchemaID, - TableID: t.ID, - OldTableID: t.ID, - }) - } - } - - if b.enableV2 { - return b.applyRecoverSchemaV2(m, diff) - } - return b.applyRecoverSchema(m, diff) -} - -func (b *Builder) applyRecoverSchemaV2(m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { - if di, ok := b.infoschemaV2.SchemaByID(diff.SchemaID); ok { - return nil, ErrDatabaseExists.GenWithStackByArgs( - fmt.Sprintf("(Schema ID %d)", di.ID), - ) - } - di, err := m.GetDatabase(diff.SchemaID) - if err != nil { - return nil, errors.Trace(err) - } - b.infoschemaV2.addDB(diff.Version, di) - return applyCreateTables(b, m, diff) -} - -func applyModifySchemaCharsetAndCollate(b *Builder, m *meta.Meta, diff *model.SchemaDiff) error { - if b.enableV2 { - return b.applyModifySchemaCharsetAndCollateV2(m, diff) - } - return b.applyModifySchemaCharsetAndCollate(m, diff) -} - -func applyModifySchemaDefaultPlacement(b *Builder, m *meta.Meta, diff *model.SchemaDiff) error { - if b.enableV2 { - return b.applyModifySchemaDefaultPlacementV2(m, diff) - } - return b.applyModifySchemaDefaultPlacement(m, diff) -} - -func applyDropTable(b *Builder, diff *model.SchemaDiff, dbInfo *model.DBInfo, tableID int64, affected []int64) []int64 { - if b.enableV2 { - return b.applyDropTableV2(diff, dbInfo, tableID, affected) - } - return b.applyDropTable(diff, dbInfo, tableID, affected) -} - -func applyCreateTables(b *Builder, m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { - return b.applyCreateTables(m, diff) -} - -func updateInfoSchemaBundles(b *Builder) { - if b.enableV2 { - b.updateInfoSchemaBundlesV2(&b.infoschemaV2) - } else { - b.updateInfoSchemaBundles(b.infoSchema) - } -} - -func oldSchemaInfo(b *Builder, diff *model.SchemaDiff) (*model.DBInfo, bool) { - if b.enableV2 { - return b.infoschemaV2.SchemaByID(diff.OldSchemaID) - } - - oldRoDBInfo, ok := b.infoSchema.SchemaByID(diff.OldSchemaID) - if ok { - oldRoDBInfo = b.getSchemaAndCopyIfNecessary(oldRoDBInfo.Name.L) - } - return oldRoDBInfo, ok -} - -// allocByID returns the Allocators of a table. -func allocByID(b *Builder, id int64) (autoid.Allocators, bool) { - var is InfoSchema - if b.enableV2 { - is = &b.infoschemaV2 - } else { - is = b.infoSchema - } - tbl, ok := is.TableByID(id) - if !ok { - return autoid.Allocators{}, false - } - return tbl.Allocators(nil), true -} - -// TODO: more UT to check the correctness. 
-func (b *Builder) applyTableUpdateV2(m *meta.Meta, diff *model.SchemaDiff) ([]int64, error) { - oldDBInfo, ok := b.infoschemaV2.SchemaByID(diff.SchemaID) - if !ok { - return nil, ErrDatabaseNotExists.GenWithStackByArgs( - fmt.Sprintf("(Schema ID %d)", diff.SchemaID), - ) - } - - oldTableID, newTableID := b.getTableIDs(diff) - b.updateBundleForTableUpdate(diff, newTableID, oldTableID) - - tblIDs, allocs, err := dropTableForUpdate(b, newTableID, oldTableID, oldDBInfo, diff) - if err != nil { - return nil, err - } - - if tableIDIsValid(newTableID) { - // All types except DropTableOrView. - var err error - tblIDs, err = applyCreateTable(b, m, oldDBInfo, newTableID, allocs, diff.Type, tblIDs, diff.Version) - if err != nil { - return nil, errors.Trace(err) - } - } - return tblIDs, nil -} - -func (b *Builder) applyDropSchemaV2(diff *model.SchemaDiff) []int64 { - di, ok := b.infoschemaV2.SchemaByID(diff.SchemaID) - if !ok { - return nil - } - - tableIDs := make([]int64, 0, len(di.Deprecated.Tables)) - tables, err := b.infoschemaV2.SchemaTableInfos(context.Background(), di.Name) - terror.Log(err) - for _, tbl := range tables { - tableIDs = appendAffectedIDs(tableIDs, tbl) - } - - for _, id := range tableIDs { - b.deleteBundle(b.infoSchema, id) - b.applyDropTableV2(diff, di, id, nil) - } - b.infoData.deleteDB(di, diff.Version) - return tableIDs -} - -func (b *Builder) applyDropTableV2(diff *model.SchemaDiff, dbInfo *model.DBInfo, tableID int64, affected []int64) []int64 { - // Remove the table in temporaryTables - if b.infoSchemaMisc.temporaryTableIDs != nil { - delete(b.infoSchemaMisc.temporaryTableIDs, tableID) - } - - table, ok := b.infoschemaV2.TableByID(tableID) - if !ok { - return nil - } - - // The old DBInfo still holds a reference to old table info, we need to remove it. - b.infoSchema.deleteReferredForeignKeys(dbInfo.Name, table.Meta()) - - if pi := table.Meta().GetPartitionInfo(); pi != nil { - for _, def := range pi.Definitions { - b.infoData.pid2tid.Set(partitionItem{def.ID, diff.Version, table.Meta().ID, true}) - } - } - - b.infoData.remove(tableItem{ - dbName: dbInfo.Name.L, - dbID: dbInfo.ID, - tableName: table.Meta().Name.L, - tableID: table.Meta().ID, - schemaVersion: diff.Version, - }) - - return affected -} - -func (b *Builder) applyModifySchemaCharsetAndCollateV2(m *meta.Meta, diff *model.SchemaDiff) error { - di, err := m.GetDatabase(diff.SchemaID) - if err != nil { - return errors.Trace(err) - } - if di == nil { - // This should never happen. - return ErrDatabaseNotExists.GenWithStackByArgs( - fmt.Sprintf("(Schema ID %d)", diff.SchemaID), - ) - } - newDBInfo, _ := b.infoschemaV2.SchemaByID(diff.SchemaID) - newDBInfo.Charset = di.Charset - newDBInfo.Collate = di.Collate - b.infoschemaV2.deleteDB(di, diff.Version) - b.infoschemaV2.addDB(diff.Version, newDBInfo) - return nil -} - -func (b *Builder) applyModifySchemaDefaultPlacementV2(m *meta.Meta, diff *model.SchemaDiff) error { - di, err := m.GetDatabase(diff.SchemaID) - if err != nil { - return errors.Trace(err) - } - if di == nil { - // This should never happen. 
- return ErrDatabaseNotExists.GenWithStackByArgs( - fmt.Sprintf("(Schema ID %d)", diff.SchemaID), - ) - } - newDBInfo, _ := b.infoschemaV2.SchemaByID(diff.SchemaID) - newDBInfo.PlacementPolicyRef = di.PlacementPolicyRef - b.infoschemaV2.deleteDB(di, diff.Version) - b.infoschemaV2.addDB(diff.Version, newDBInfo) - return nil -} - -func (b *bundleInfoBuilder) updateInfoSchemaBundlesV2(is *infoschemaV2) { - if b.deltaUpdate { - b.completeUpdateTablesV2(is) - for tblID := range b.updateTables { - b.updateTableBundles(is, tblID) - } - return - } - - // do full update bundles - is.ruleBundleMap = make(map[int64]*placement.Bundle) - tmp := is.ListTablesWithSpecialAttribute(PlacementPolicyAttribute) - for _, v := range tmp { - for _, tbl := range v.TableInfos { - b.updateTableBundles(is, tbl.ID) - } - } -} - -func (b *bundleInfoBuilder) completeUpdateTablesV2(is *infoschemaV2) { - if len(b.updatePolicies) == 0 && len(b.updatePartitions) == 0 { - return - } - - dbs := is.ListTablesWithSpecialAttribute(AllSpecialAttribute) - for _, db := range dbs { - for _, tbl := range db.TableInfos { - tblInfo := tbl - if tblInfo.PlacementPolicyRef != nil { - if _, ok := b.updatePolicies[tblInfo.PlacementPolicyRef.ID]; ok { - b.markTableBundleShouldUpdate(tblInfo.ID) - } - } - - if tblInfo.Partition != nil { - for _, par := range tblInfo.Partition.Definitions { - if _, ok := b.updatePartitions[par.ID]; ok { - b.markTableBundleShouldUpdate(tblInfo.ID) - } - } - } - } - } -} - -type specialAttributeFilter func(*model.TableInfo) bool - -// TTLAttribute is the TTL attribute filter used by ListTablesWithSpecialAttribute. -var TTLAttribute specialAttributeFilter = func(t *model.TableInfo) bool { - return t.State == model.StatePublic && t.TTLInfo != nil -} - -// TiFlashAttribute is the TiFlashReplica attribute filter used by ListTablesWithSpecialAttribute. -var TiFlashAttribute specialAttributeFilter = func(t *model.TableInfo) bool { - return t.TiFlashReplica != nil -} - -// PlacementPolicyAttribute is the Placement Policy attribute filter used by ListTablesWithSpecialAttribute. -var PlacementPolicyAttribute specialAttributeFilter = func(t *model.TableInfo) bool { - if t.PlacementPolicyRef != nil { - return true - } - if parInfo := t.GetPartitionInfo(); parInfo != nil { - for _, def := range parInfo.Definitions { - if def.PlacementPolicyRef != nil { - return true - } - } - } - return false -} - -// TableLockAttribute is the Table Lock attribute filter used by ListTablesWithSpecialAttribute. -var TableLockAttribute specialAttributeFilter = func(t *model.TableInfo) bool { - return t.Lock != nil -} - -// ForeignKeysAttribute is the ForeignKeys attribute filter used by ListTablesWithSpecialAttribute. -var ForeignKeysAttribute specialAttributeFilter = func(t *model.TableInfo) bool { - return len(t.ForeignKeys) > 0 -} - -// PartitionAttribute is the Partition attribute filter used by ListTablesWithSpecialAttribute. -var PartitionAttribute specialAttributeFilter = func(t *model.TableInfo) bool { - return t.GetPartitionInfo() != nil -} - -func hasSpecialAttributes(t *model.TableInfo) bool { - return TTLAttribute(t) || TiFlashAttribute(t) || PlacementPolicyAttribute(t) || PartitionAttribute(t) || TableLockAttribute(t) || ForeignKeysAttribute(t) -} - -// AllSpecialAttribute marks a model.TableInfo with any special attributes. 
-var AllSpecialAttribute specialAttributeFilter = hasSpecialAttributes - -func (is *infoschemaV2) ListTablesWithSpecialAttribute(filter specialAttributeFilter) []tableInfoResult { - ret := make([]tableInfoResult, 0, 10) - var currDB string - var lastTableID int64 - var res tableInfoResult - is.Data.tableInfoResident.Reverse(func(item tableInfoItem) bool { - if item.schemaVersion > is.infoSchema.schemaMetaVersion { - // Skip the versions that we are not looking for. - return true - } - // Dedup the same record of different versions. - if lastTableID != 0 && lastTableID == item.tableID { - return true - } - lastTableID = item.tableID - - if item.tomb { - return true - } - - if !filter(item.tableInfo) { - return true - } - - if currDB == "" { - currDB = item.dbName - res = tableInfoResult{DBName: item.dbName} - res.TableInfos = append(res.TableInfos, item.tableInfo) - } else if currDB == item.dbName { - res.TableInfos = append(res.TableInfos, item.tableInfo) - } else { - ret = append(ret, res) - res = tableInfoResult{DBName: item.dbName} - res.TableInfos = append(res.TableInfos, item.tableInfo) - } - return true - }) - if len(res.TableInfos) > 0 { - ret = append(ret, res) - } - return ret -} diff --git a/pkg/infoschema/perfschema/binding__failpoint_binding__.go b/pkg/infoschema/perfschema/binding__failpoint_binding__.go deleted file mode 100644 index fe39572d084c5..0000000000000 --- a/pkg/infoschema/perfschema/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package perfschema - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/infoschema/perfschema/tables.go b/pkg/infoschema/perfschema/tables.go index d4d85b0a26f81..d242bd7560a71 100644 --- a/pkg/infoschema/perfschema/tables.go +++ b/pkg/infoschema/perfschema/tables.go @@ -309,7 +309,7 @@ func dataForRemoteProfile(ctx sessionctx.Context, nodeType, uri string, isGorout default: return nil, errors.Errorf("%s does not support profile remote component", nodeType) } - if val, _err_ := failpoint.Eval(_curpkg_("mockRemoteNodeStatusAddress")); _err_ == nil { + failpoint.Inject("mockRemoteNodeStatusAddress", func(val failpoint.Value) { // The cluster topology is injected by `failpoint` expression and // there is no extra checks for it. (let the test fail if the expression invalid) if s := val.(string); len(s) > 0 { @@ -328,7 +328,7 @@ func dataForRemoteProfile(ctx sessionctx.Context, nodeType, uri string, isGorout // erase error err = nil } - } + }) if err != nil { return nil, errors.Trace(err) } diff --git a/pkg/infoschema/perfschema/tables.go__failpoint_stash__ b/pkg/infoschema/perfschema/tables.go__failpoint_stash__ deleted file mode 100644 index d242bd7560a71..0000000000000 --- a/pkg/infoschema/perfschema/tables.go__failpoint_stash__ +++ /dev/null @@ -1,415 +0,0 @@ -// Copyright 2017 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
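The tables.go hunk above folds the generated `failpoint.Eval(_curpkg_(...))` form back into the `failpoint.Inject` marker that the failpoint rewriter expands at build time. For reference, the general shape of such an injection point, using a hypothetical failpoint name and function (the callback runs only when the failpoint is enabled, and `failpoint.Return` is a marker that failpoint-ctl rewrites into a real return):

    package demo

    import "github.com/pingcap/failpoint"

    // storeCount is a hypothetical function showing the failpoint.Inject pattern:
    // a test enables "mockStoreCount" (e.g. failpoint.Enable(".../mockStoreCount",
    // `return(10)`)) to force a fixed result without touching real stores.
    func storeCount() (uint64, error) {
    	failpoint.Inject("mockStoreCount", func(val failpoint.Value) {
    		if n, ok := val.(int); ok {
    			// Marker call: only takes effect in failpoint-ctl rewritten builds.
    			failpoint.Return(uint64(n), nil)
    		}
    	})
    	return realStoreCount()
    }

    func realStoreCount() (uint64, error) { return 3, nil }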
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package perfschema - -import ( - "cmp" - "context" - "fmt" - "net/http" - "slices" - "strings" - "sync" - "time" - - "github.com/ngaut/pools" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta/autoid" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/table/tables" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/profile" - pd "github.com/tikv/pd/client/http" -) - -const ( - tableNameGlobalStatus = "global_status" - tableNameSessionStatus = "session_status" - tableNameSetupActors = "setup_actors" - tableNameSetupObjects = "setup_objects" - tableNameSetupInstruments = "setup_instruments" - tableNameSetupConsumers = "setup_consumers" - tableNameEventsStatementsCurrent = "events_statements_current" - tableNameEventsStatementsHistory = "events_statements_history" - tableNameEventsStatementsHistoryLong = "events_statements_history_long" - tableNamePreparedStatementsInstances = "prepared_statements_instances" - tableNameEventsTransactionsCurrent = "events_transactions_current" - tableNameEventsTransactionsHistory = "events_transactions_history" - tableNameEventsTransactionsHistoryLong = "events_transactions_history_long" - tableNameEventsStagesCurrent = "events_stages_current" - tableNameEventsStagesHistory = "events_stages_history" - tableNameEventsStagesHistoryLong = "events_stages_history_long" - tableNameEventsStatementsSummaryByDigest = "events_statements_summary_by_digest" - tableNameTiDBProfileCPU = "tidb_profile_cpu" - tableNameTiDBProfileMemory = "tidb_profile_memory" - tableNameTiDBProfileMutex = "tidb_profile_mutex" - tableNameTiDBProfileAllocs = "tidb_profile_allocs" - tableNameTiDBProfileBlock = "tidb_profile_block" - tableNameTiDBProfileGoroutines = "tidb_profile_goroutines" - tableNameTiKVProfileCPU = "tikv_profile_cpu" - tableNamePDProfileCPU = "pd_profile_cpu" - tableNamePDProfileMemory = "pd_profile_memory" - tableNamePDProfileMutex = "pd_profile_mutex" - tableNamePDProfileAllocs = "pd_profile_allocs" - tableNamePDProfileBlock = "pd_profile_block" - tableNamePDProfileGoroutines = "pd_profile_goroutines" - tableNameSessionAccountConnectAttrs = "session_account_connect_attrs" - tableNameSessionConnectAttrs = "session_connect_attrs" - tableNameSessionVariables = "session_variables" -) - -var tableIDMap = map[string]int64{ - tableNameGlobalStatus: autoid.PerformanceSchemaDBID + 1, - tableNameSessionStatus: autoid.PerformanceSchemaDBID + 2, - tableNameSetupActors: autoid.PerformanceSchemaDBID + 3, - tableNameSetupObjects: autoid.PerformanceSchemaDBID + 4, - tableNameSetupInstruments: autoid.PerformanceSchemaDBID + 5, - tableNameSetupConsumers: autoid.PerformanceSchemaDBID + 6, - tableNameEventsStatementsCurrent: autoid.PerformanceSchemaDBID + 7, - tableNameEventsStatementsHistory: autoid.PerformanceSchemaDBID + 8, - 
tableNameEventsStatementsHistoryLong: autoid.PerformanceSchemaDBID + 9, - tableNamePreparedStatementsInstances: autoid.PerformanceSchemaDBID + 10, - tableNameEventsTransactionsCurrent: autoid.PerformanceSchemaDBID + 11, - tableNameEventsTransactionsHistory: autoid.PerformanceSchemaDBID + 12, - tableNameEventsTransactionsHistoryLong: autoid.PerformanceSchemaDBID + 13, - tableNameEventsStagesCurrent: autoid.PerformanceSchemaDBID + 14, - tableNameEventsStagesHistory: autoid.PerformanceSchemaDBID + 15, - tableNameEventsStagesHistoryLong: autoid.PerformanceSchemaDBID + 16, - tableNameEventsStatementsSummaryByDigest: autoid.PerformanceSchemaDBID + 17, - tableNameTiDBProfileCPU: autoid.PerformanceSchemaDBID + 18, - tableNameTiDBProfileMemory: autoid.PerformanceSchemaDBID + 19, - tableNameTiDBProfileMutex: autoid.PerformanceSchemaDBID + 20, - tableNameTiDBProfileAllocs: autoid.PerformanceSchemaDBID + 21, - tableNameTiDBProfileBlock: autoid.PerformanceSchemaDBID + 22, - tableNameTiDBProfileGoroutines: autoid.PerformanceSchemaDBID + 23, - tableNameTiKVProfileCPU: autoid.PerformanceSchemaDBID + 24, - tableNamePDProfileCPU: autoid.PerformanceSchemaDBID + 25, - tableNamePDProfileMemory: autoid.PerformanceSchemaDBID + 26, - tableNamePDProfileMutex: autoid.PerformanceSchemaDBID + 27, - tableNamePDProfileAllocs: autoid.PerformanceSchemaDBID + 28, - tableNamePDProfileBlock: autoid.PerformanceSchemaDBID + 29, - tableNamePDProfileGoroutines: autoid.PerformanceSchemaDBID + 30, - tableNameSessionVariables: autoid.PerformanceSchemaDBID + 31, - tableNameSessionConnectAttrs: autoid.PerformanceSchemaDBID + 32, - tableNameSessionAccountConnectAttrs: autoid.PerformanceSchemaDBID + 33, -} - -// perfSchemaTable stands for the fake table all its data is in the memory. -type perfSchemaTable struct { - infoschema.VirtualTable - meta *model.TableInfo - cols []*table.Column - tp table.Type - indices []table.Index -} - -var pluginTable = make(map[string]func(autoid.Allocators, *model.TableInfo) (table.Table, error)) - -// IsPredefinedTable judges whether this table is predefined. -func IsPredefinedTable(tableName string) bool { - _, ok := tableIDMap[strings.ToLower(tableName)] - return ok -} - -func tableFromMeta(allocs autoid.Allocators, _ func() (pools.Resource, error), meta *model.TableInfo) (table.Table, error) { - if f, ok := pluginTable[meta.Name.L]; ok { - ret, err := f(allocs, meta) - return ret, err - } - return createPerfSchemaTable(meta) -} - -// createPerfSchemaTable creates all perfSchemaTables -func createPerfSchemaTable(meta *model.TableInfo) (*perfSchemaTable, error) { - columns := make([]*table.Column, 0, len(meta.Columns)) - for _, colInfo := range meta.Columns { - col := table.ToColumn(colInfo) - columns = append(columns, col) - } - tp := table.VirtualTable - t := &perfSchemaTable{ - meta: meta, - cols: columns, - tp: tp, - } - if err := initTableIndices(t); err != nil { - return nil, err - } - return t, nil -} - -// Cols implements table.Table Type interface. -func (vt *perfSchemaTable) Cols() []*table.Column { - return vt.cols -} - -// VisibleCols implements table.Table VisibleCols interface. -func (vt *perfSchemaTable) VisibleCols() []*table.Column { - return vt.cols -} - -// HiddenCols implements table.Table HiddenCols interface. -func (vt *perfSchemaTable) HiddenCols() []*table.Column { - return nil -} - -// WritableCols implements table.Table Type interface. -func (vt *perfSchemaTable) WritableCols() []*table.Column { - return vt.cols -} - -// DeletableCols implements table.Table Type interface. 
-func (vt *perfSchemaTable) DeletableCols() []*table.Column { - return vt.cols -} - -// FullHiddenColsAndVisibleCols implements table FullHiddenColsAndVisibleCols interface. -func (vt *perfSchemaTable) FullHiddenColsAndVisibleCols() []*table.Column { - return vt.cols -} - -// GetPhysicalID implements table.Table GetID interface. -func (vt *perfSchemaTable) GetPhysicalID() int64 { - return vt.meta.ID -} - -// Meta implements table.Table Type interface. -func (vt *perfSchemaTable) Meta() *model.TableInfo { - return vt.meta -} - -// Type implements table.Table Type interface. -func (vt *perfSchemaTable) Type() table.Type { - return vt.tp -} - -// Indices implements table.Table Indices interface. -func (vt *perfSchemaTable) Indices() []table.Index { - return vt.indices -} - -// GetPartitionedTable implements table.Table GetPartitionedTable interface. -func (vt *perfSchemaTable) GetPartitionedTable() table.PartitionedTable { - return nil -} - -// initTableIndices initializes the indices of the perfSchemaTable. -func initTableIndices(t *perfSchemaTable) error { - tblInfo := t.meta - for _, idxInfo := range tblInfo.Indices { - if idxInfo.State == model.StateNone { - return table.ErrIndexStateCantNone.GenWithStackByArgs(idxInfo.Name) - } - idx := tables.NewIndex(t.meta.ID, tblInfo, idxInfo) - t.indices = append(t.indices, idx) - } - return nil -} - -func (vt *perfSchemaTable) getRows(ctx context.Context, sctx sessionctx.Context, cols []*table.Column) (fullRows [][]types.Datum, err error) { - switch vt.meta.Name.O { - case tableNameTiDBProfileCPU: - fullRows, err = (&profile.Collector{}).ProfileGraph("cpu") - case tableNameTiDBProfileMemory: - fullRows, err = (&profile.Collector{}).ProfileGraph("heap") - case tableNameTiDBProfileMutex: - fullRows, err = (&profile.Collector{}).ProfileGraph("mutex") - case tableNameTiDBProfileAllocs: - fullRows, err = (&profile.Collector{}).ProfileGraph("allocs") - case tableNameTiDBProfileBlock: - fullRows, err = (&profile.Collector{}).ProfileGraph("block") - case tableNameTiDBProfileGoroutines: - fullRows, err = (&profile.Collector{}).ProfileGraph("goroutine") - case tableNameTiKVProfileCPU: - interval := fmt.Sprintf("%d", profile.CPUProfileInterval/time.Second) - fullRows, err = dataForRemoteProfile(sctx, "tikv", "/debug/pprof/profile?seconds="+interval, false) - case tableNamePDProfileCPU: - fullRows, err = dataForRemoteProfile(sctx, "pd", pd.PProfProfileAPIWithInterval(profile.CPUProfileInterval), false) - case tableNamePDProfileMemory: - fullRows, err = dataForRemoteProfile(sctx, "pd", pd.PProfHeap, false) - case tableNamePDProfileMutex: - fullRows, err = dataForRemoteProfile(sctx, "pd", pd.PProfMutex, false) - case tableNamePDProfileAllocs: - fullRows, err = dataForRemoteProfile(sctx, "pd", pd.PProfAllocs, false) - case tableNamePDProfileBlock: - fullRows, err = dataForRemoteProfile(sctx, "pd", pd.PProfBlock, false) - case tableNamePDProfileGoroutines: - fullRows, err = dataForRemoteProfile(sctx, "pd", pd.PProfGoroutineWithDebugLevel(2), true) - case tableNameSessionVariables: - fullRows, err = infoschema.GetDataFromSessionVariables(ctx, sctx) - case tableNameSessionConnectAttrs: - fullRows, err = infoschema.GetDataFromSessionConnectAttrs(sctx, false) - case tableNameSessionAccountConnectAttrs: - fullRows, err = infoschema.GetDataFromSessionConnectAttrs(sctx, true) - } - if err != nil { - return - } - if len(cols) == len(vt.cols) { - return - } - rows := make([][]types.Datum, len(fullRows)) - for i, fullRow := range fullRows { - row := make([]types.Datum, 
len(cols)) - for j, col := range cols { - row[j] = fullRow[col.Offset] - } - rows[i] = row - } - return rows, nil -} - -// IterRecords implements table.Table IterRecords interface. -func (vt *perfSchemaTable) IterRecords(ctx context.Context, sctx sessionctx.Context, cols []*table.Column, fn table.RecordIterFunc) error { - rows, err := vt.getRows(ctx, sctx, cols) - if err != nil { - return err - } - for i, row := range rows { - more, err := fn(kv.IntHandle(i), row, cols) - if err != nil { - return err - } - if !more { - break - } - } - return nil -} - -func dataForRemoteProfile(ctx sessionctx.Context, nodeType, uri string, isGoroutine bool) ([][]types.Datum, error) { - var ( - servers []infoschema.ServerInfo - err error - ) - switch nodeType { - case "tikv": - servers, err = infoschema.GetStoreServerInfo(ctx.GetStore()) - case "pd": - servers, err = infoschema.GetPDServerInfo(ctx) - default: - return nil, errors.Errorf("%s does not support profile remote component", nodeType) - } - failpoint.Inject("mockRemoteNodeStatusAddress", func(val failpoint.Value) { - // The cluster topology is injected by `failpoint` expression and - // there is no extra checks for it. (let the test fail if the expression invalid) - if s := val.(string); len(s) > 0 { - servers = servers[:0] - for _, server := range strings.Split(s, ";") { - parts := strings.Split(server, ",") - if parts[0] != nodeType { - continue - } - servers = append(servers, infoschema.ServerInfo{ - ServerType: parts[0], - Address: parts[1], - StatusAddr: parts[2], - }) - } - // erase error - err = nil - } - }) - if err != nil { - return nil, errors.Trace(err) - } - - type result struct { - addr string - rows [][]types.Datum - err error - } - - wg := sync.WaitGroup{} - ch := make(chan result, len(servers)) - for _, server := range servers { - statusAddr := server.StatusAddr - if len(statusAddr) == 0 { - ctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("TiKV node %s does not contain status address", server.Address)) - continue - } - - wg.Add(1) - go func(address string) { - util.WithRecovery(func() { - defer wg.Done() - url := fmt.Sprintf("%s://%s%s", util.InternalHTTPSchema(), statusAddr, uri) - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - ch <- result{err: errors.Trace(err)} - return - } - // Forbidden PD follower proxy - req.Header.Add("PD-Allow-follower-handle", "true") - // TiKV output svg format in default - req.Header.Add("Content-Type", "application/protobuf") - resp, err := util.InternalHTTPClient().Do(req) - if err != nil { - ch <- result{err: errors.Trace(err)} - return - } - defer func() { - terror.Log(resp.Body.Close()) - }() - if resp.StatusCode != http.StatusOK { - ch <- result{err: errors.Errorf("request %s failed: %s", url, resp.Status)} - return - } - collector := profile.Collector{} - var rows [][]types.Datum - if isGoroutine { - rows, err = collector.ParseGoroutines(resp.Body) - } else { - rows, err = collector.ProfileReaderToDatums(resp.Body) - } - if err != nil { - ch <- result{err: errors.Trace(err)} - return - } - ch <- result{addr: address, rows: rows} - }, nil) - }(statusAddr) - } - - wg.Wait() - close(ch) - - // Keep the original order to make the result more stable - var results []result //nolint: prealloc - for result := range ch { - if result.err != nil { - ctx.GetSessionVars().StmtCtx.AppendWarning(result.err) - continue - } - results = append(results, result) - } - slices.SortFunc(results, func(i, j result) int { return cmp.Compare(i.addr, j.addr) }) - var finalRows 
[][]types.Datum - for _, result := range results { - addr := types.NewStringDatum(result.addr) - for _, row := range result.rows { - // Insert the node address in front of rows - finalRows = append(finalRows, append([]types.Datum{addr}, row...)) - } - } - return finalRows, nil -} diff --git a/pkg/infoschema/sieve.go b/pkg/infoschema/sieve.go index 2041c908432e7..01100ebf359d6 100644 --- a/pkg/infoschema/sieve.go +++ b/pkg/infoschema/sieve.go @@ -151,10 +151,10 @@ func (s *Sieve[K, V]) Set(key K, value V) { } func (s *Sieve[K, V]) Get(key K) (value V, ok bool) { - if _, _err_ := failpoint.Eval(_curpkg_("skipGet")); _err_ == nil { + failpoint.Inject("skipGet", func() { var v V - return v, false - } + failpoint.Return(v, false) + }) s.mu.Lock() defer s.mu.Unlock() if e, ok := s.items[key]; ok { diff --git a/pkg/infoschema/sieve.go__failpoint_stash__ b/pkg/infoschema/sieve.go__failpoint_stash__ deleted file mode 100644 index 01100ebf359d6..0000000000000 --- a/pkg/infoschema/sieve.go__failpoint_stash__ +++ /dev/null @@ -1,272 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package infoschema - -import ( - "container/list" - "context" - "sync" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/infoschema/internal" -) - -// entry holds the key and value of a cache entry. -type entry[K comparable, V any] struct { - key K - value V - visited bool - element *list.Element - size uint64 -} - -func (t *entry[K, V]) Size() uint64 { - if t.size == 0 { - size := internal.Sizeof(t) - if size > 0 { - t.size = uint64(size) - } - } - return t.size -} - -// Sieve is an efficient turn-Key eviction algorithm for web caches. 
-// See blog post https://cachemon.github.io/SIEVE-website/blog/2023/12/17/sieve-is-simpler-than-lru/ -// and also the academic paper "SIEVE is simpler than LRU" -type Sieve[K comparable, V any] struct { - ctx context.Context - cancel context.CancelFunc - mu sync.Mutex - size uint64 - capacity uint64 - items map[K]*entry[K, V] - ll *list.List - hand *list.Element - - hook sieveStatusHook -} - -type sieveStatusHook interface { - onHit() - onMiss() - onEvict() - onUpdateSize(size uint64) - onUpdateLimit(limit uint64) -} - -type emptySieveStatusHook struct{} - -func (e *emptySieveStatusHook) onHit() {} - -func (e *emptySieveStatusHook) onMiss() {} - -func (e *emptySieveStatusHook) onEvict() {} - -func (e *emptySieveStatusHook) onUpdateSize(_ uint64) {} - -func (e *emptySieveStatusHook) onUpdateLimit(_ uint64) {} - -func newSieve[K comparable, V any](capacity uint64) *Sieve[K, V] { - ctx, cancel := context.WithCancel(context.Background()) - - cache := &Sieve[K, V]{ - ctx: ctx, - cancel: cancel, - capacity: capacity, - items: make(map[K]*entry[K, V]), - ll: list.New(), - hook: &emptySieveStatusHook{}, - } - - return cache -} - -func (s *Sieve[K, V]) SetStatusHook(hook sieveStatusHook) { - s.hook = hook -} - -func (s *Sieve[K, V]) SetCapacity(capacity uint64) { - s.mu.Lock() - defer s.mu.Unlock() - s.capacity = capacity - s.hook.onUpdateLimit(capacity) -} - -func (s *Sieve[K, V]) SetCapacityAndWaitEvict(capacity uint64) { - s.SetCapacity(capacity) - for { - s.mu.Lock() - if s.size <= s.capacity { - s.mu.Unlock() - break - } - for i := 0; s.size > s.capacity && i < 10; i++ { - s.evict() - } - s.mu.Unlock() - } -} - -func (s *Sieve[K, V]) Capacity() uint64 { - s.mu.Lock() - defer s.mu.Unlock() - return s.capacity -} - -func (s *Sieve[K, V]) Set(key K, value V) { - s.mu.Lock() - defer s.mu.Unlock() - - if e, ok := s.items[key]; ok { - e.value = value - e.visited = true - return - } - - for i := 0; s.size > s.capacity && i < 10; i++ { - s.evict() - } - - e := &entry[K, V]{ - key: key, - value: value, - } - s.size += e.Size() // calculate the size first without putting to the list. - s.hook.onUpdateSize(s.size) - e.element = s.ll.PushFront(key) - - s.items[key] = e -} - -func (s *Sieve[K, V]) Get(key K) (value V, ok bool) { - failpoint.Inject("skipGet", func() { - var v V - failpoint.Return(v, false) - }) - s.mu.Lock() - defer s.mu.Unlock() - if e, ok := s.items[key]; ok { - e.visited = true - s.hook.onHit() - return e.value, true - } - s.hook.onMiss() - return -} - -func (s *Sieve[K, V]) Remove(key K) (ok bool) { - s.mu.Lock() - defer s.mu.Unlock() - - if e, ok := s.items[key]; ok { - // if the element to be removed is the hand, - // then move the hand to the previous one. 
- if e.element == s.hand { - s.hand = s.hand.Prev() - } - - s.removeEntry(e) - return true - } - - return false -} - -func (s *Sieve[K, V]) Contains(key K) (ok bool) { - s.mu.Lock() - defer s.mu.Unlock() - _, ok = s.items[key] - return -} - -func (s *Sieve[K, V]) Peek(key K) (value V, ok bool) { - s.mu.Lock() - defer s.mu.Unlock() - - if e, ok := s.items[key]; ok { - return e.value, true - } - - return -} - -func (s *Sieve[K, V]) Size() uint64 { - s.mu.Lock() - defer s.mu.Unlock() - - return s.size -} - -func (s *Sieve[K, V]) Len() int { - s.mu.Lock() - defer s.mu.Unlock() - - return s.ll.Len() -} - -func (s *Sieve[K, V]) Purge() { - s.mu.Lock() - defer s.mu.Unlock() - - for _, e := range s.items { - s.removeEntry(e) - } - - s.ll.Init() -} - -func (s *Sieve[K, V]) Close() { - s.Purge() - s.mu.Lock() - s.cancel() - s.mu.Unlock() -} - -func (s *Sieve[K, V]) removeEntry(e *entry[K, V]) { - s.ll.Remove(e.element) - delete(s.items, e.key) - s.size -= e.Size() - s.hook.onUpdateSize(s.size) -} - -func (s *Sieve[K, V]) evict() { - o := s.hand - // if o is nil, then assign it to the tail element in the list - if o == nil { - o = s.ll.Back() - } - - el, ok := s.items[o.Value.(K)] - if !ok { - panic("sieve: evicting non-existent element") - } - - for el.visited { - el.visited = false - o = o.Prev() - if o == nil { - o = s.ll.Back() - } - - el, ok = s.items[o.Value.(K)] - if !ok { - panic("sieve: evicting non-existent element") - } - } - - s.hand = o.Prev() - s.removeEntry(el) - s.hook.onEvict() -} diff --git a/pkg/infoschema/tables.go b/pkg/infoschema/tables.go index b032c1c4c201b..6168001a813e7 100644 --- a/pkg/infoschema/tables.go +++ b/pkg/infoschema/tables.go @@ -1791,7 +1791,7 @@ func (s *ServerInfo) ResolveLoopBackAddr() { // GetClusterServerInfo returns all components information of cluster func GetClusterServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) { - if val, _err_ := failpoint.Eval(_curpkg_("mockClusterInfo")); _err_ == nil { + failpoint.Inject("mockClusterInfo", func(val failpoint.Value) { // The cluster topology is injected by `failpoint` expression and // there is no extra checks for it. (let the test fail if the expression invalid) if s := val.(string); len(s) > 0 { @@ -1811,9 +1811,9 @@ func GetClusterServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) { ServerID: serverID, }) } - return servers, nil + failpoint.Return(servers, nil) } - } + }) type retriever func(ctx sessionctx.Context) ([]ServerInfo, error) retrievers := []retriever{GetTiDBServerInfo, GetPDServerInfo, func(ctx sessionctx.Context) ([]ServerInfo, error) { @@ -2069,7 +2069,7 @@ func isTiFlashWriteNode(store *metapb.Store) bool { // GetStoreServerInfo returns all store nodes(TiKV or TiFlash) cluster information func GetStoreServerInfo(store kv.Storage) ([]ServerInfo, error) { - if val, _err_ := failpoint.Eval(_curpkg_("mockStoreServerInfo")); _err_ == nil { + failpoint.Inject("mockStoreServerInfo", func(val failpoint.Value) { if s := val.(string); len(s) > 0 { var servers []ServerInfo for _, server := range strings.Split(s, ";") { @@ -2083,9 +2083,9 @@ func GetStoreServerInfo(store kv.Storage) ([]ServerInfo, error) { StartTimestamp: 0, }) } - return servers, nil + failpoint.Return(servers, nil) } - } + }) // Get TiKV servers info. 
tikvStore, ok := store.(tikv.Storage) @@ -2102,11 +2102,11 @@ func GetStoreServerInfo(store kv.Storage) ([]ServerInfo, error) { } servers := make([]ServerInfo, 0, len(stores)) for _, store := range stores { - if val, _err_ := failpoint.Eval(_curpkg_("mockStoreTombstone")); _err_ == nil { + failpoint.Inject("mockStoreTombstone", func(val failpoint.Value) { if val.(bool) { store.State = metapb.StoreState_Tombstone } - } + }) if store.GetState() == metapb.StoreState_Tombstone { continue @@ -2144,11 +2144,11 @@ func FormatStoreServerVersion(version string) string { // GetTiFlashStoreCount returns the count of tiflash server. func GetTiFlashStoreCount(ctx sessionctx.Context) (cnt uint64, err error) { - if val, _err_ := failpoint.Eval(_curpkg_("mockTiFlashStoreCount")); _err_ == nil { + failpoint.Inject("mockTiFlashStoreCount", func(val failpoint.Value) { if val.(bool) { - return uint64(10), nil + failpoint.Return(uint64(10), nil) } - } + }) stores, err := GetStoreServerInfo(ctx.GetStore()) if err != nil { diff --git a/pkg/infoschema/tables.go__failpoint_stash__ b/pkg/infoschema/tables.go__failpoint_stash__ deleted file mode 100644 index 6168001a813e7..0000000000000 --- a/pkg/infoschema/tables.go__failpoint_stash__ +++ /dev/null @@ -1,2694 +0,0 @@ -// Copyright 2016 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package infoschema - -import ( - "cmp" - "context" - "encoding/json" - "fmt" - "net" - "net/http" - "slices" - "sort" - "strconv" - "strings" - "sync" - "time" - - "github.com/ngaut/pools" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/diagnosticspb" - "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/log" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/ddl/placement" - "github.com/pingcap/tidb/pkg/ddl/resourcegroup" - "github.com/pingcap/tidb/pkg/domain/infosync" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta/autoid" - "github.com/pingcap/tidb/pkg/parser/auth" - "github.com/pingcap/tidb/pkg/parser/charset" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/privilege" - "github.com/pingcap/tidb/pkg/session/txninfo" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/deadlockhistory" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/sem" - "github.com/pingcap/tidb/pkg/util/set" - "github.com/pingcap/tidb/pkg/util/stmtsummary" - "github.com/tikv/client-go/v2/tikv" - pd "github.com/tikv/pd/client/http" - "go.uber.org/zap" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" -) - -const ( - // TableSchemata is the string constant of infoschema table. - TableSchemata = "SCHEMATA" - // TableTables is the string constant of infoschema table. - TableTables = "TABLES" - // TableColumns is the string constant of infoschema table - TableColumns = "COLUMNS" - tableColumnStatistics = "COLUMN_STATISTICS" - // TableStatistics is the string constant of infoschema table - TableStatistics = "STATISTICS" - // TableCharacterSets is the string constant of infoschema charactersets memory table - TableCharacterSets = "CHARACTER_SETS" - // TableCollations is the string constant of infoschema collations memory table. - TableCollations = "COLLATIONS" - tableFiles = "FILES" - // CatalogVal is the string constant of TABLE_CATALOG. - CatalogVal = "def" - // TableProfiling is the string constant of infoschema table. - TableProfiling = "PROFILING" - // TablePartitions is the string constant of infoschema table. - TablePartitions = "PARTITIONS" - // TableKeyColumn is the string constant of KEY_COLUMN_USAGE. - TableKeyColumn = "KEY_COLUMN_USAGE" - // TableReferConst is the string constant of REFERENTIAL_CONSTRAINTS. - TableReferConst = "REFERENTIAL_CONSTRAINTS" - // TableSessionVar is the string constant of SESSION_VARIABLES. - TableSessionVar = "SESSION_VARIABLES" - tablePlugins = "PLUGINS" - // TableConstraints is the string constant of TABLE_CONSTRAINTS. - TableConstraints = "TABLE_CONSTRAINTS" - tableTriggers = "TRIGGERS" - // TableUserPrivileges is the string constant of infoschema user privilege table. - TableUserPrivileges = "USER_PRIVILEGES" - tableSchemaPrivileges = "SCHEMA_PRIVILEGES" - tableTablePrivileges = "TABLE_PRIVILEGES" - tableColumnPrivileges = "COLUMN_PRIVILEGES" - // TableEngines is the string constant of infoschema table. - TableEngines = "ENGINES" - // TableViews is the string constant of infoschema table. 
- TableViews = "VIEWS" - tableRoutines = "ROUTINES" - tableParameters = "PARAMETERS" - tableEvents = "EVENTS" - tableGlobalStatus = "GLOBAL_STATUS" - tableGlobalVariables = "GLOBAL_VARIABLES" - tableSessionStatus = "SESSION_STATUS" - tableOptimizerTrace = "OPTIMIZER_TRACE" - tableTableSpaces = "TABLESPACES" - // TableCollationCharacterSetApplicability is the string constant of infoschema memory table. - TableCollationCharacterSetApplicability = "COLLATION_CHARACTER_SET_APPLICABILITY" - // TableProcesslist is the string constant of infoschema table. - TableProcesslist = "PROCESSLIST" - // TableTiDBIndexes is the string constant of infoschema table - TableTiDBIndexes = "TIDB_INDEXES" - // TableTiDBHotRegions is the string constant of infoschema table - TableTiDBHotRegions = "TIDB_HOT_REGIONS" - // TableTiDBHotRegionsHistory is the string constant of infoschema table - TableTiDBHotRegionsHistory = "TIDB_HOT_REGIONS_HISTORY" - // TableTiKVStoreStatus is the string constant of infoschema table - TableTiKVStoreStatus = "TIKV_STORE_STATUS" - // TableAnalyzeStatus is the string constant of Analyze Status - TableAnalyzeStatus = "ANALYZE_STATUS" - // TableTiKVRegionStatus is the string constant of infoschema table - TableTiKVRegionStatus = "TIKV_REGION_STATUS" - // TableTiKVRegionPeers is the string constant of infoschema table - TableTiKVRegionPeers = "TIKV_REGION_PEERS" - // TableTiDBServersInfo is the string constant of TiDB server information table. - TableTiDBServersInfo = "TIDB_SERVERS_INFO" - // TableSlowQuery is the string constant of slow query memory table. - TableSlowQuery = "SLOW_QUERY" - // TableClusterInfo is the string constant of cluster info memory table. - TableClusterInfo = "CLUSTER_INFO" - // TableClusterConfig is the string constant of cluster configuration memory table. - TableClusterConfig = "CLUSTER_CONFIG" - // TableClusterLog is the string constant of cluster log memory table. - TableClusterLog = "CLUSTER_LOG" - // TableClusterLoad is the string constant of cluster load memory table. - TableClusterLoad = "CLUSTER_LOAD" - // TableClusterHardware is the string constant of cluster hardware table. - TableClusterHardware = "CLUSTER_HARDWARE" - // TableClusterSystemInfo is the string constant of cluster system info table. - TableClusterSystemInfo = "CLUSTER_SYSTEMINFO" - // TableTiFlashReplica is the string constant of tiflash replica table. - TableTiFlashReplica = "TIFLASH_REPLICA" - // TableInspectionResult is the string constant of inspection result table. - TableInspectionResult = "INSPECTION_RESULT" - // TableMetricTables is a table that contains all metrics table definition. - TableMetricTables = "METRICS_TABLES" - // TableMetricSummary is a summary table that contains all metrics. - TableMetricSummary = "METRICS_SUMMARY" - // TableMetricSummaryByLabel is a metric table that contains all metrics that group by label info. - TableMetricSummaryByLabel = "METRICS_SUMMARY_BY_LABEL" - // TableInspectionSummary is the string constant of inspection summary table. - TableInspectionSummary = "INSPECTION_SUMMARY" - // TableInspectionRules is the string constant of currently implemented inspection and summary rules. - TableInspectionRules = "INSPECTION_RULES" - // TableDDLJobs is the string constant of DDL job table. - TableDDLJobs = "DDL_JOBS" - // TableSequences is the string constant of all sequences created by user. - TableSequences = "SEQUENCES" - // TableStatementsSummary is the string constant of statement summary table. 
- TableStatementsSummary = "STATEMENTS_SUMMARY" - // TableStatementsSummaryHistory is the string constant of statements summary history table. - TableStatementsSummaryHistory = "STATEMENTS_SUMMARY_HISTORY" - // TableStatementsSummaryEvicted is the string constant of statements summary evicted table. - TableStatementsSummaryEvicted = "STATEMENTS_SUMMARY_EVICTED" - // TableStorageStats is a table that contains all tables disk usage - TableStorageStats = "TABLE_STORAGE_STATS" - // TableTiFlashTables is the string constant of tiflash tables table. - TableTiFlashTables = "TIFLASH_TABLES" - // TableTiFlashSegments is the string constant of tiflash segments table. - TableTiFlashSegments = "TIFLASH_SEGMENTS" - // TableClientErrorsSummaryGlobal is the string constant of client errors table. - TableClientErrorsSummaryGlobal = "CLIENT_ERRORS_SUMMARY_GLOBAL" - // TableClientErrorsSummaryByUser is the string constant of client errors table. - TableClientErrorsSummaryByUser = "CLIENT_ERRORS_SUMMARY_BY_USER" - // TableClientErrorsSummaryByHost is the string constant of client errors table. - TableClientErrorsSummaryByHost = "CLIENT_ERRORS_SUMMARY_BY_HOST" - // TableTiDBTrx is current running transaction status table. - TableTiDBTrx = "TIDB_TRX" - // TableDeadlocks is the string constant of deadlock table. - TableDeadlocks = "DEADLOCKS" - // TableDataLockWaits is current lock waiting status table. - TableDataLockWaits = "DATA_LOCK_WAITS" - // TableAttributes is the string constant of attributes table. - TableAttributes = "ATTRIBUTES" - // TablePlacementPolicies is the string constant of placement policies table. - TablePlacementPolicies = "PLACEMENT_POLICIES" - // TableTrxSummary is the string constant of transaction summary table. - TableTrxSummary = "TRX_SUMMARY" - // TableVariablesInfo is the string constant of variables_info table. - TableVariablesInfo = "VARIABLES_INFO" - // TableUserAttributes is the string constant of user_attributes view. - TableUserAttributes = "USER_ATTRIBUTES" - // TableMemoryUsage is the memory usage status of tidb instance. - TableMemoryUsage = "MEMORY_USAGE" - // TableMemoryUsageOpsHistory is the memory control operators history. - TableMemoryUsageOpsHistory = "MEMORY_USAGE_OPS_HISTORY" - // TableResourceGroups is the metadata of resource groups. - TableResourceGroups = "RESOURCE_GROUPS" - // TableRunawayWatches is the query list of runaway watch. - TableRunawayWatches = "RUNAWAY_WATCHES" - // TableCheckConstraints is the list of CHECK constraints. - TableCheckConstraints = "CHECK_CONSTRAINTS" - // TableTiDBCheckConstraints is the list of CHECK constraints, with non-standard TiDB extensions. - TableTiDBCheckConstraints = "TIDB_CHECK_CONSTRAINTS" - // TableKeywords is the list of keywords. - TableKeywords = "KEYWORDS" - // TableTiDBIndexUsage is a table to show the usage stats of indexes in the current instance. - TableTiDBIndexUsage = "TIDB_INDEX_USAGE" -) - -const ( - // DataLockWaitsColumnKey is the name of the KEY column of the DATA_LOCK_WAITS table. - DataLockWaitsColumnKey = "KEY" - // DataLockWaitsColumnKeyInfo is the name of the KEY_INFO column of the DATA_LOCK_WAITS table. - DataLockWaitsColumnKeyInfo = "KEY_INFO" - // DataLockWaitsColumnTrxID is the name of the TRX_ID column of the DATA_LOCK_WAITS table. - DataLockWaitsColumnTrxID = "TRX_ID" - // DataLockWaitsColumnCurrentHoldingTrxID is the name of the CURRENT_HOLDING_TRX_ID column of the DATA_LOCK_WAITS table. 
- DataLockWaitsColumnCurrentHoldingTrxID = "CURRENT_HOLDING_TRX_ID" - // DataLockWaitsColumnSQLDigest is the name of the SQL_DIGEST column of the DATA_LOCK_WAITS table. - DataLockWaitsColumnSQLDigest = "SQL_DIGEST" - // DataLockWaitsColumnSQLDigestText is the name of the SQL_DIGEST_TEXT column of the DATA_LOCK_WAITS table. - DataLockWaitsColumnSQLDigestText = "SQL_DIGEST_TEXT" -) - -// The following variables will only be used when PD in the microservice mode. -const ( - // tsoServiceName is the name of TSO service. - tsoServiceName = "tso" - // schedulingServiceName is the name of scheduling service. - schedulingServiceName = "scheduling" -) - -var tableIDMap = map[string]int64{ - TableSchemata: autoid.InformationSchemaDBID + 1, - TableTables: autoid.InformationSchemaDBID + 2, - TableColumns: autoid.InformationSchemaDBID + 3, - tableColumnStatistics: autoid.InformationSchemaDBID + 4, - TableStatistics: autoid.InformationSchemaDBID + 5, - TableCharacterSets: autoid.InformationSchemaDBID + 6, - TableCollations: autoid.InformationSchemaDBID + 7, - tableFiles: autoid.InformationSchemaDBID + 8, - CatalogVal: autoid.InformationSchemaDBID + 9, - TableProfiling: autoid.InformationSchemaDBID + 10, - TablePartitions: autoid.InformationSchemaDBID + 11, - TableKeyColumn: autoid.InformationSchemaDBID + 12, - TableReferConst: autoid.InformationSchemaDBID + 13, - TableSessionVar: autoid.InformationSchemaDBID + 14, - tablePlugins: autoid.InformationSchemaDBID + 15, - TableConstraints: autoid.InformationSchemaDBID + 16, - tableTriggers: autoid.InformationSchemaDBID + 17, - TableUserPrivileges: autoid.InformationSchemaDBID + 18, - tableSchemaPrivileges: autoid.InformationSchemaDBID + 19, - tableTablePrivileges: autoid.InformationSchemaDBID + 20, - tableColumnPrivileges: autoid.InformationSchemaDBID + 21, - TableEngines: autoid.InformationSchemaDBID + 22, - TableViews: autoid.InformationSchemaDBID + 23, - tableRoutines: autoid.InformationSchemaDBID + 24, - tableParameters: autoid.InformationSchemaDBID + 25, - tableEvents: autoid.InformationSchemaDBID + 26, - tableGlobalStatus: autoid.InformationSchemaDBID + 27, - tableGlobalVariables: autoid.InformationSchemaDBID + 28, - tableSessionStatus: autoid.InformationSchemaDBID + 29, - tableOptimizerTrace: autoid.InformationSchemaDBID + 30, - tableTableSpaces: autoid.InformationSchemaDBID + 31, - TableCollationCharacterSetApplicability: autoid.InformationSchemaDBID + 32, - TableProcesslist: autoid.InformationSchemaDBID + 33, - TableTiDBIndexes: autoid.InformationSchemaDBID + 34, - TableSlowQuery: autoid.InformationSchemaDBID + 35, - TableTiDBHotRegions: autoid.InformationSchemaDBID + 36, - TableTiKVStoreStatus: autoid.InformationSchemaDBID + 37, - TableAnalyzeStatus: autoid.InformationSchemaDBID + 38, - TableTiKVRegionStatus: autoid.InformationSchemaDBID + 39, - TableTiKVRegionPeers: autoid.InformationSchemaDBID + 40, - TableTiDBServersInfo: autoid.InformationSchemaDBID + 41, - TableClusterInfo: autoid.InformationSchemaDBID + 42, - TableClusterConfig: autoid.InformationSchemaDBID + 43, - TableClusterLoad: autoid.InformationSchemaDBID + 44, - TableTiFlashReplica: autoid.InformationSchemaDBID + 45, - ClusterTableSlowLog: autoid.InformationSchemaDBID + 46, - ClusterTableProcesslist: autoid.InformationSchemaDBID + 47, - TableClusterLog: autoid.InformationSchemaDBID + 48, - TableClusterHardware: autoid.InformationSchemaDBID + 49, - TableClusterSystemInfo: autoid.InformationSchemaDBID + 50, - TableInspectionResult: autoid.InformationSchemaDBID + 51, - TableMetricSummary: 
autoid.InformationSchemaDBID + 52, - TableMetricSummaryByLabel: autoid.InformationSchemaDBID + 53, - TableMetricTables: autoid.InformationSchemaDBID + 54, - TableInspectionSummary: autoid.InformationSchemaDBID + 55, - TableInspectionRules: autoid.InformationSchemaDBID + 56, - TableDDLJobs: autoid.InformationSchemaDBID + 57, - TableSequences: autoid.InformationSchemaDBID + 58, - TableStatementsSummary: autoid.InformationSchemaDBID + 59, - TableStatementsSummaryHistory: autoid.InformationSchemaDBID + 60, - ClusterTableStatementsSummary: autoid.InformationSchemaDBID + 61, - ClusterTableStatementsSummaryHistory: autoid.InformationSchemaDBID + 62, - TableStorageStats: autoid.InformationSchemaDBID + 63, - TableTiFlashTables: autoid.InformationSchemaDBID + 64, - TableTiFlashSegments: autoid.InformationSchemaDBID + 65, - // Removed, see https://github.com/pingcap/tidb/issues/28890 - //TablePlacementPolicy: autoid.InformationSchemaDBID + 66, - TableClientErrorsSummaryGlobal: autoid.InformationSchemaDBID + 67, - TableClientErrorsSummaryByUser: autoid.InformationSchemaDBID + 68, - TableClientErrorsSummaryByHost: autoid.InformationSchemaDBID + 69, - TableTiDBTrx: autoid.InformationSchemaDBID + 70, - ClusterTableTiDBTrx: autoid.InformationSchemaDBID + 71, - TableDeadlocks: autoid.InformationSchemaDBID + 72, - ClusterTableDeadlocks: autoid.InformationSchemaDBID + 73, - TableDataLockWaits: autoid.InformationSchemaDBID + 74, - TableStatementsSummaryEvicted: autoid.InformationSchemaDBID + 75, - ClusterTableStatementsSummaryEvicted: autoid.InformationSchemaDBID + 76, - TableAttributes: autoid.InformationSchemaDBID + 77, - TableTiDBHotRegionsHistory: autoid.InformationSchemaDBID + 78, - TablePlacementPolicies: autoid.InformationSchemaDBID + 79, - TableTrxSummary: autoid.InformationSchemaDBID + 80, - ClusterTableTrxSummary: autoid.InformationSchemaDBID + 81, - TableVariablesInfo: autoid.InformationSchemaDBID + 82, - TableUserAttributes: autoid.InformationSchemaDBID + 83, - TableMemoryUsage: autoid.InformationSchemaDBID + 84, - TableMemoryUsageOpsHistory: autoid.InformationSchemaDBID + 85, - ClusterTableMemoryUsage: autoid.InformationSchemaDBID + 86, - ClusterTableMemoryUsageOpsHistory: autoid.InformationSchemaDBID + 87, - TableResourceGroups: autoid.InformationSchemaDBID + 88, - TableRunawayWatches: autoid.InformationSchemaDBID + 89, - TableCheckConstraints: autoid.InformationSchemaDBID + 90, - TableTiDBCheckConstraints: autoid.InformationSchemaDBID + 91, - TableKeywords: autoid.InformationSchemaDBID + 92, - TableTiDBIndexUsage: autoid.InformationSchemaDBID + 93, - ClusterTableTiDBIndexUsage: autoid.InformationSchemaDBID + 94, -} - -// columnInfo represents the basic column information of all kinds of INFORMATION_SCHEMA tables -type columnInfo struct { - // name of column - name string - // tp is column type - tp byte - // represent size of bytes of the column - size int - // represent decimal length of the column - decimal int - // flag represent NotNull, Unsigned, PriKey flags etc. 
- flag uint - // deflt is default value - deflt any - // comment for the column - comment string - // enumElems represent all possible literal string values of an enum column - enumElems []string -} - -func buildColumnInfo(col columnInfo) *model.ColumnInfo { - mCharset := charset.CharsetBin - mCollation := charset.CharsetBin - if col.tp == mysql.TypeVarchar || col.tp == mysql.TypeBlob || col.tp == mysql.TypeLongBlob || col.tp == mysql.TypeEnum { - mCharset = charset.CharsetUTF8MB4 - mCollation = charset.CollationUTF8MB4 - } - fieldType := types.FieldType{} - fieldType.SetType(col.tp) - fieldType.SetCharset(mCharset) - fieldType.SetCollate(mCollation) - fieldType.SetFlen(col.size) - fieldType.SetDecimal(col.decimal) - fieldType.SetFlag(col.flag) - fieldType.SetElems(col.enumElems) - return &model.ColumnInfo{ - Name: model.NewCIStr(col.name), - FieldType: fieldType, - State: model.StatePublic, - DefaultValue: col.deflt, - Comment: col.comment, - } -} - -func buildTableMeta(tableName string, cs []columnInfo) *model.TableInfo { - cols := make([]*model.ColumnInfo, 0, len(cs)) - primaryIndices := make([]*model.IndexInfo, 0, 1) - tblInfo := &model.TableInfo{ - Name: model.NewCIStr(tableName), - State: model.StatePublic, - Charset: mysql.DefaultCharset, - Collate: mysql.DefaultCollationName, - } - for offset, c := range cs { - if tblInfo.Name.O == ClusterTableSlowLog && mysql.HasPriKeyFlag(c.flag) { - switch c.tp { - case mysql.TypeLong, mysql.TypeLonglong, - mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24: - tblInfo.PKIsHandle = true - default: - tblInfo.IsCommonHandle = true - tblInfo.CommonHandleVersion = 1 - index := &model.IndexInfo{ - Name: model.NewCIStr("primary"), - State: model.StatePublic, - Primary: true, - Unique: true, - Columns: []*model.IndexColumn{ - {Name: model.NewCIStr(c.name), Offset: offset, Length: types.UnspecifiedLength}}, - } - primaryIndices = append(primaryIndices, index) - tblInfo.Indices = primaryIndices - } - } - cols = append(cols, buildColumnInfo(c)) - } - for i, col := range cols { - col.Offset = i - } - tblInfo.Columns = cols - return tblInfo -} - -var schemataCols = []columnInfo{ - {name: "CATALOG_NAME", tp: mysql.TypeVarchar, size: 512}, - {name: "SCHEMA_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "DEFAULT_CHARACTER_SET_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "DEFAULT_COLLATION_NAME", tp: mysql.TypeVarchar, size: 32}, - {name: "SQL_PATH", tp: mysql.TypeVarchar, size: 512}, - {name: "TIDB_PLACEMENT_POLICY_NAME", tp: mysql.TypeVarchar, size: 64}, -} - -var tablesCols = []columnInfo{ - {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 512}, - {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64}, - {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "TABLE_TYPE", tp: mysql.TypeVarchar, size: 64}, - {name: "ENGINE", tp: mysql.TypeVarchar, size: 64}, - {name: "VERSION", tp: mysql.TypeLonglong, size: 21}, - {name: "ROW_FORMAT", tp: mysql.TypeVarchar, size: 10}, - {name: "TABLE_ROWS", tp: mysql.TypeLonglong, size: 21}, - {name: "AVG_ROW_LENGTH", tp: mysql.TypeLonglong, size: 21}, - {name: "DATA_LENGTH", tp: mysql.TypeLonglong, size: 21}, - {name: "MAX_DATA_LENGTH", tp: mysql.TypeLonglong, size: 21}, - {name: "INDEX_LENGTH", tp: mysql.TypeLonglong, size: 21}, - {name: "DATA_FREE", tp: mysql.TypeLonglong, size: 21}, - {name: "AUTO_INCREMENT", tp: mysql.TypeLonglong, size: 21}, - {name: "CREATE_TIME", tp: mysql.TypeDatetime, size: 19}, - {name: "UPDATE_TIME", tp: mysql.TypeDatetime, size: 19}, - {name: "CHECK_TIME", tp: 
mysql.TypeDatetime, size: 19}, - {name: "TABLE_COLLATION", tp: mysql.TypeVarchar, size: 32, deflt: mysql.DefaultCollationName}, - {name: "CHECKSUM", tp: mysql.TypeLonglong, size: 21}, - {name: "CREATE_OPTIONS", tp: mysql.TypeVarchar, size: 255}, - {name: "TABLE_COMMENT", tp: mysql.TypeVarchar, size: 2048}, - {name: "TIDB_TABLE_ID", tp: mysql.TypeLonglong, size: 21}, - {name: "TIDB_ROW_ID_SHARDING_INFO", tp: mysql.TypeVarchar, size: 255}, - {name: "TIDB_PK_TYPE", tp: mysql.TypeVarchar, size: 64}, - {name: "TIDB_PLACEMENT_POLICY_NAME", tp: mysql.TypeVarchar, size: 64}, -} - -// See: http://dev.mysql.com/doc/refman/5.7/en/information-schema-columns-table.html -var columnsCols = []columnInfo{ - {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 512}, - {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64}, - {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "COLUMN_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "ORDINAL_POSITION", tp: mysql.TypeLonglong, size: 64}, - {name: "COLUMN_DEFAULT", tp: mysql.TypeBlob, size: 196606}, - {name: "IS_NULLABLE", tp: mysql.TypeVarchar, size: 3}, - {name: "DATA_TYPE", tp: mysql.TypeVarchar, size: 64}, - {name: "CHARACTER_MAXIMUM_LENGTH", tp: mysql.TypeLonglong, size: 21}, - {name: "CHARACTER_OCTET_LENGTH", tp: mysql.TypeLonglong, size: 21}, - {name: "NUMERIC_PRECISION", tp: mysql.TypeLonglong, size: 21}, - {name: "NUMERIC_SCALE", tp: mysql.TypeLonglong, size: 21}, - {name: "DATETIME_PRECISION", tp: mysql.TypeLonglong, size: 21}, - {name: "CHARACTER_SET_NAME", tp: mysql.TypeVarchar, size: 32}, - {name: "COLLATION_NAME", tp: mysql.TypeVarchar, size: 32}, - {name: "COLUMN_TYPE", tp: mysql.TypeBlob, size: 196606}, - {name: "COLUMN_KEY", tp: mysql.TypeVarchar, size: 3}, - {name: "EXTRA", tp: mysql.TypeVarchar, size: 45}, - {name: "PRIVILEGES", tp: mysql.TypeVarchar, size: 80}, - {name: "COLUMN_COMMENT", tp: mysql.TypeVarchar, size: 1024}, - {name: "GENERATION_EXPRESSION", tp: mysql.TypeBlob, size: 589779, flag: mysql.NotNullFlag}, -} - -var columnStatisticsCols = []columnInfo{ - {name: "SCHEMA_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "COLUMN_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "HISTOGRAM", tp: mysql.TypeJSON, size: 51}, -} - -var statisticsCols = []columnInfo{ - {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 512}, - {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64}, - {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "NON_UNIQUE", tp: mysql.TypeVarchar, size: 1}, - {name: "INDEX_SCHEMA", tp: mysql.TypeVarchar, size: 64}, - {name: "INDEX_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "SEQ_IN_INDEX", tp: mysql.TypeLonglong, size: 2}, - {name: "COLUMN_NAME", tp: mysql.TypeVarchar, size: 21}, - {name: "COLLATION", tp: mysql.TypeVarchar, size: 1}, - {name: "CARDINALITY", tp: mysql.TypeLonglong, size: 21}, - {name: "SUB_PART", tp: mysql.TypeLonglong, size: 3}, - {name: "PACKED", tp: mysql.TypeVarchar, size: 10}, - {name: "NULLABLE", tp: mysql.TypeVarchar, size: 3}, - {name: "INDEX_TYPE", tp: mysql.TypeVarchar, size: 16}, - {name: "COMMENT", tp: mysql.TypeVarchar, size: 16}, - {name: "INDEX_COMMENT", tp: mysql.TypeVarchar, size: 1024}, - {name: "IS_VISIBLE", tp: mysql.TypeVarchar, size: 3}, - {name: "Expression", tp: mysql.TypeVarchar, size: 64}, -} - -var profilingCols = []columnInfo{ - {name: "QUERY_ID", tp: mysql.TypeLong, size: 20}, - {name: "SEQ", tp: 
mysql.TypeLong, size: 20}, - {name: "STATE", tp: mysql.TypeVarchar, size: 30}, - {name: "DURATION", tp: mysql.TypeNewDecimal, size: 9}, - {name: "CPU_USER", tp: mysql.TypeNewDecimal, size: 9}, - {name: "CPU_SYSTEM", tp: mysql.TypeNewDecimal, size: 9}, - {name: "CONTEXT_VOLUNTARY", tp: mysql.TypeLong, size: 20}, - {name: "CONTEXT_INVOLUNTARY", tp: mysql.TypeLong, size: 20}, - {name: "BLOCK_OPS_IN", tp: mysql.TypeLong, size: 20}, - {name: "BLOCK_OPS_OUT", tp: mysql.TypeLong, size: 20}, - {name: "MESSAGES_SENT", tp: mysql.TypeLong, size: 20}, - {name: "MESSAGES_RECEIVED", tp: mysql.TypeLong, size: 20}, - {name: "PAGE_FAULTS_MAJOR", tp: mysql.TypeLong, size: 20}, - {name: "PAGE_FAULTS_MINOR", tp: mysql.TypeLong, size: 20}, - {name: "SWAPS", tp: mysql.TypeLong, size: 20}, - {name: "SOURCE_FUNCTION", tp: mysql.TypeVarchar, size: 30}, - {name: "SOURCE_FILE", tp: mysql.TypeVarchar, size: 20}, - {name: "SOURCE_LINE", tp: mysql.TypeLong, size: 20}, -} - -var charsetCols = []columnInfo{ - {name: "CHARACTER_SET_NAME", tp: mysql.TypeVarchar, size: 32}, - {name: "DEFAULT_COLLATE_NAME", tp: mysql.TypeVarchar, size: 32}, - {name: "DESCRIPTION", tp: mysql.TypeVarchar, size: 60}, - {name: "MAXLEN", tp: mysql.TypeLonglong, size: 3}, -} - -var collationsCols = []columnInfo{ - {name: "COLLATION_NAME", tp: mysql.TypeVarchar, size: 32}, - {name: "CHARACTER_SET_NAME", tp: mysql.TypeVarchar, size: 32}, - {name: "ID", tp: mysql.TypeLonglong, size: 11}, - {name: "IS_DEFAULT", tp: mysql.TypeVarchar, size: 3}, - {name: "IS_COMPILED", tp: mysql.TypeVarchar, size: 3}, - {name: "SORTLEN", tp: mysql.TypeLonglong, size: 3}, - {name: "PAD_ATTRIBUTE", tp: mysql.TypeVarchar, size: 9}, -} - -var keyColumnUsageCols = []columnInfo{ - {name: "CONSTRAINT_CATALOG", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag}, - {name: "CONSTRAINT_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "CONSTRAINT_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag}, - {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "COLUMN_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "ORDINAL_POSITION", tp: mysql.TypeLonglong, size: 10, flag: mysql.NotNullFlag}, - {name: "POSITION_IN_UNIQUE_CONSTRAINT", tp: mysql.TypeLonglong, size: 10}, - {name: "REFERENCED_TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64}, - {name: "REFERENCED_TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "REFERENCED_COLUMN_NAME", tp: mysql.TypeVarchar, size: 64}, -} - -// See http://dev.mysql.com/doc/refman/5.7/en/information-schema-referential-constraints-table.html -var referConstCols = []columnInfo{ - {name: "CONSTRAINT_CATALOG", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag}, - {name: "CONSTRAINT_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "CONSTRAINT_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "UNIQUE_CONSTRAINT_CATALOG", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag}, - {name: "UNIQUE_CONSTRAINT_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "UNIQUE_CONSTRAINT_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "MATCH_OPTION", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "UPDATE_RULE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: 
"DELETE_RULE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "REFERENCED_TABLE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, -} - -// See http://dev.mysql.com/doc/refman/5.7/en/information-schema-variables-table.html -var sessionVarCols = []columnInfo{ - {name: "VARIABLE_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "VARIABLE_VALUE", tp: mysql.TypeVarchar, size: 1024}, -} - -// See https://dev.mysql.com/doc/refman/5.7/en/information-schema-plugins-table.html -var pluginsCols = []columnInfo{ - {name: "PLUGIN_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "PLUGIN_VERSION", tp: mysql.TypeVarchar, size: 20}, - {name: "PLUGIN_STATUS", tp: mysql.TypeVarchar, size: 10}, - {name: "PLUGIN_TYPE", tp: mysql.TypeVarchar, size: 80}, - {name: "PLUGIN_TYPE_VERSION", tp: mysql.TypeVarchar, size: 20}, - {name: "PLUGIN_LIBRARY", tp: mysql.TypeVarchar, size: 64}, - {name: "PLUGIN_LIBRARY_VERSION", tp: mysql.TypeVarchar, size: 20}, - {name: "PLUGIN_AUTHOR", tp: mysql.TypeVarchar, size: 64}, - {name: "PLUGIN_DESCRIPTION", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, - {name: "PLUGIN_LICENSE", tp: mysql.TypeVarchar, size: 80}, - {name: "LOAD_OPTION", tp: mysql.TypeVarchar, size: 64}, -} - -// See https://dev.mysql.com/doc/refman/5.7/en/information-schema-partitions-table.html -var partitionsCols = []columnInfo{ - {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 512}, - {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64}, - {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "PARTITION_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "SUBPARTITION_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "PARTITION_ORDINAL_POSITION", tp: mysql.TypeLonglong, size: 21}, - {name: "SUBPARTITION_ORDINAL_POSITION", tp: mysql.TypeLonglong, size: 21}, - {name: "PARTITION_METHOD", tp: mysql.TypeVarchar, size: 18}, - {name: "SUBPARTITION_METHOD", tp: mysql.TypeVarchar, size: 12}, - {name: "PARTITION_EXPRESSION", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, - {name: "SUBPARTITION_EXPRESSION", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, - {name: "PARTITION_DESCRIPTION", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, - {name: "TABLE_ROWS", tp: mysql.TypeLonglong, size: 21}, - {name: "AVG_ROW_LENGTH", tp: mysql.TypeLonglong, size: 21}, - {name: "DATA_LENGTH", tp: mysql.TypeLonglong, size: 21}, - {name: "MAX_DATA_LENGTH", tp: mysql.TypeLonglong, size: 21}, - {name: "INDEX_LENGTH", tp: mysql.TypeLonglong, size: 21}, - {name: "DATA_FREE", tp: mysql.TypeLonglong, size: 21}, - {name: "CREATE_TIME", tp: mysql.TypeDatetime}, - {name: "UPDATE_TIME", tp: mysql.TypeDatetime}, - {name: "CHECK_TIME", tp: mysql.TypeDatetime}, - {name: "CHECKSUM", tp: mysql.TypeLonglong, size: 21}, - {name: "PARTITION_COMMENT", tp: mysql.TypeVarchar, size: 80}, - {name: "NODEGROUP", tp: mysql.TypeVarchar, size: 12}, - {name: "TABLESPACE_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "TIDB_PARTITION_ID", tp: mysql.TypeLonglong, size: 21}, - {name: "TIDB_PLACEMENT_POLICY_NAME", tp: mysql.TypeVarchar, size: 64}, -} - -var tableConstraintsCols = []columnInfo{ - {name: "CONSTRAINT_CATALOG", tp: mysql.TypeVarchar, size: 512}, - {name: "CONSTRAINT_SCHEMA", tp: mysql.TypeVarchar, size: 64}, - {name: "CONSTRAINT_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64}, - {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: 
"CONSTRAINT_TYPE", tp: mysql.TypeVarchar, size: 64}, -} - -var tableTriggersCols = []columnInfo{ - {name: "TRIGGER_CATALOG", tp: mysql.TypeVarchar, size: 512}, - {name: "TRIGGER_SCHEMA", tp: mysql.TypeVarchar, size: 64}, - {name: "TRIGGER_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "EVENT_MANIPULATION", tp: mysql.TypeVarchar, size: 6}, - {name: "EVENT_OBJECT_CATALOG", tp: mysql.TypeVarchar, size: 512}, - {name: "EVENT_OBJECT_SCHEMA", tp: mysql.TypeVarchar, size: 64}, - {name: "EVENT_OBJECT_TABLE", tp: mysql.TypeVarchar, size: 64}, - {name: "ACTION_ORDER", tp: mysql.TypeLonglong, size: 4}, - {name: "ACTION_CONDITION", tp: mysql.TypeBlob, size: -1}, - {name: "ACTION_STATEMENT", tp: mysql.TypeBlob, size: -1}, - {name: "ACTION_ORIENTATION", tp: mysql.TypeVarchar, size: 9}, - {name: "ACTION_TIMING", tp: mysql.TypeVarchar, size: 6}, - {name: "ACTION_REFERENCE_OLD_TABLE", tp: mysql.TypeVarchar, size: 64}, - {name: "ACTION_REFERENCE_NEW_TABLE", tp: mysql.TypeVarchar, size: 64}, - {name: "ACTION_REFERENCE_OLD_ROW", tp: mysql.TypeVarchar, size: 3}, - {name: "ACTION_REFERENCE_NEW_ROW", tp: mysql.TypeVarchar, size: 3}, - {name: "CREATED", tp: mysql.TypeDatetime, size: 2}, - {name: "SQL_MODE", tp: mysql.TypeVarchar, size: 8192}, - {name: "DEFINER", tp: mysql.TypeVarchar, size: 77}, - {name: "CHARACTER_SET_CLIENT", tp: mysql.TypeVarchar, size: 32}, - {name: "COLLATION_CONNECTION", tp: mysql.TypeVarchar, size: 32}, - {name: "DATABASE_COLLATION", tp: mysql.TypeVarchar, size: 32}, -} - -var tableUserPrivilegesCols = []columnInfo{ - {name: "GRANTEE", tp: mysql.TypeVarchar, size: 81}, - {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 512}, - {name: "PRIVILEGE_TYPE", tp: mysql.TypeVarchar, size: 64}, - {name: "IS_GRANTABLE", tp: mysql.TypeVarchar, size: 3}, -} - -var tableSchemaPrivilegesCols = []columnInfo{ - {name: "GRANTEE", tp: mysql.TypeVarchar, size: 81, flag: mysql.NotNullFlag}, - {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag}, - {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "PRIVILEGE_TYPE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "IS_GRANTABLE", tp: mysql.TypeVarchar, size: 3, flag: mysql.NotNullFlag}, -} - -var tableTablePrivilegesCols = []columnInfo{ - {name: "GRANTEE", tp: mysql.TypeVarchar, size: 81, flag: mysql.NotNullFlag}, - {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag}, - {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "PRIVILEGE_TYPE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "IS_GRANTABLE", tp: mysql.TypeVarchar, size: 3, flag: mysql.NotNullFlag}, -} - -var tableColumnPrivilegesCols = []columnInfo{ - {name: "GRANTEE", tp: mysql.TypeVarchar, size: 81, flag: mysql.NotNullFlag}, - {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag}, - {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "COLUMN_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "PRIVILEGE_TYPE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "IS_GRANTABLE", tp: mysql.TypeVarchar, size: 3, flag: mysql.NotNullFlag}, -} - -var tableEnginesCols = []columnInfo{ - {name: "ENGINE", tp: mysql.TypeVarchar, size: 64}, - {name: "SUPPORT", tp: 
mysql.TypeVarchar, size: 8}, - {name: "COMMENT", tp: mysql.TypeVarchar, size: 80}, - {name: "TRANSACTIONS", tp: mysql.TypeVarchar, size: 3}, - {name: "XA", tp: mysql.TypeVarchar, size: 3}, - {name: "SAVEPOINTS", tp: mysql.TypeVarchar, size: 3}, -} - -var tableViewsCols = []columnInfo{ - {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag}, - {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "VIEW_DEFINITION", tp: mysql.TypeLongBlob, flag: mysql.NotNullFlag}, - {name: "CHECK_OPTION", tp: mysql.TypeVarchar, size: 8, flag: mysql.NotNullFlag}, - {name: "IS_UPDATABLE", tp: mysql.TypeVarchar, size: 3, flag: mysql.NotNullFlag}, - {name: "DEFINER", tp: mysql.TypeVarchar, size: 77, flag: mysql.NotNullFlag}, - {name: "SECURITY_TYPE", tp: mysql.TypeVarchar, size: 7, flag: mysql.NotNullFlag}, - {name: "CHARACTER_SET_CLIENT", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, - {name: "COLLATION_CONNECTION", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, -} - -var tableRoutinesCols = []columnInfo{ - {name: "SPECIFIC_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "ROUTINE_CATALOG", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag}, - {name: "ROUTINE_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "ROUTINE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "ROUTINE_TYPE", tp: mysql.TypeVarchar, size: 9, flag: mysql.NotNullFlag}, - {name: "DATA_TYPE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "CHARACTER_MAXIMUM_LENGTH", tp: mysql.TypeLong, size: 21}, - {name: "CHARACTER_OCTET_LENGTH", tp: mysql.TypeLong, size: 21}, - {name: "NUMERIC_PRECISION", tp: mysql.TypeLonglong, size: 21}, - {name: "NUMERIC_SCALE", tp: mysql.TypeLong, size: 21}, - {name: "DATETIME_PRECISION", tp: mysql.TypeLonglong, size: 21}, - {name: "CHARACTER_SET_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "COLLATION_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "DTD_IDENTIFIER", tp: mysql.TypeLongBlob}, - {name: "ROUTINE_BODY", tp: mysql.TypeVarchar, size: 8, flag: mysql.NotNullFlag}, - {name: "ROUTINE_DEFINITION", tp: mysql.TypeLongBlob}, - {name: "EXTERNAL_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "EXTERNAL_LANGUAGE", tp: mysql.TypeVarchar, size: 64}, - {name: "PARAMETER_STYLE", tp: mysql.TypeVarchar, size: 8, flag: mysql.NotNullFlag}, - {name: "IS_DETERMINISTIC", tp: mysql.TypeVarchar, size: 3, flag: mysql.NotNullFlag}, - {name: "SQL_DATA_ACCESS", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "SQL_PATH", tp: mysql.TypeVarchar, size: 64}, - {name: "SECURITY_TYPE", tp: mysql.TypeVarchar, size: 7, flag: mysql.NotNullFlag}, - {name: "CREATED", tp: mysql.TypeDatetime, flag: mysql.NotNullFlag, deflt: "0000-00-00 00:00:00"}, - {name: "LAST_ALTERED", tp: mysql.TypeDatetime, flag: mysql.NotNullFlag, deflt: "0000-00-00 00:00:00"}, - {name: "SQL_MODE", tp: mysql.TypeVarchar, size: 8192, flag: mysql.NotNullFlag}, - {name: "ROUTINE_COMMENT", tp: mysql.TypeLongBlob}, - {name: "DEFINER", tp: mysql.TypeVarchar, size: 77, flag: mysql.NotNullFlag}, - {name: "CHARACTER_SET_CLIENT", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, - {name: "COLLATION_CONNECTION", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, - {name: "DATABASE_COLLATION", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, -} - -var 
tableParametersCols = []columnInfo{ - {name: "SPECIFIC_CATALOG", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag}, - {name: "SPECIFIC_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "SPECIFIC_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "ORDINAL_POSITION", tp: mysql.TypeVarchar, size: 21, flag: mysql.NotNullFlag}, - {name: "PARAMETER_MODE", tp: mysql.TypeVarchar, size: 5}, - {name: "PARAMETER_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "DATA_TYPE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "CHARACTER_MAXIMUM_LENGTH", tp: mysql.TypeVarchar, size: 21}, - {name: "CHARACTER_OCTET_LENGTH", tp: mysql.TypeVarchar, size: 21}, - {name: "NUMERIC_PRECISION", tp: mysql.TypeVarchar, size: 21}, - {name: "NUMERIC_SCALE", tp: mysql.TypeVarchar, size: 21}, - {name: "DATETIME_PRECISION", tp: mysql.TypeVarchar, size: 21}, - {name: "CHARACTER_SET_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "COLLATION_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "DTD_IDENTIFIER", tp: mysql.TypeLongBlob, flag: mysql.NotNullFlag}, - {name: "ROUTINE_TYPE", tp: mysql.TypeVarchar, size: 9, flag: mysql.NotNullFlag}, -} - -var tableEventsCols = []columnInfo{ - {name: "EVENT_CATALOG", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "EVENT_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "EVENT_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "DEFINER", tp: mysql.TypeVarchar, size: 77, flag: mysql.NotNullFlag}, - {name: "TIME_ZONE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "EVENT_BODY", tp: mysql.TypeVarchar, size: 8, flag: mysql.NotNullFlag}, - {name: "EVENT_DEFINITION", tp: mysql.TypeLongBlob}, - {name: "EVENT_TYPE", tp: mysql.TypeVarchar, size: 9, flag: mysql.NotNullFlag}, - {name: "EXECUTE_AT", tp: mysql.TypeDatetime}, - {name: "INTERVAL_VALUE", tp: mysql.TypeVarchar, size: 256}, - {name: "INTERVAL_FIELD", tp: mysql.TypeVarchar, size: 18}, - {name: "SQL_MODE", tp: mysql.TypeVarchar, size: 8192, flag: mysql.NotNullFlag}, - {name: "STARTS", tp: mysql.TypeDatetime}, - {name: "ENDS", tp: mysql.TypeDatetime}, - {name: "STATUS", tp: mysql.TypeVarchar, size: 18, flag: mysql.NotNullFlag}, - {name: "ON_COMPLETION", tp: mysql.TypeVarchar, size: 12, flag: mysql.NotNullFlag}, - {name: "CREATED", tp: mysql.TypeDatetime, flag: mysql.NotNullFlag, deflt: "0000-00-00 00:00:00"}, - {name: "LAST_ALTERED", tp: mysql.TypeDatetime, flag: mysql.NotNullFlag, deflt: "0000-00-00 00:00:00"}, - {name: "LAST_EXECUTED", tp: mysql.TypeDatetime}, - {name: "EVENT_COMMENT", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "ORIGINATOR", tp: mysql.TypeLong, size: 10, flag: mysql.NotNullFlag, deflt: 0}, - {name: "CHARACTER_SET_CLIENT", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, - {name: "COLLATION_CONNECTION", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, - {name: "DATABASE_COLLATION", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, -} - -var tableGlobalStatusCols = []columnInfo{ - {name: "VARIABLE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "VARIABLE_VALUE", tp: mysql.TypeVarchar, size: 1024}, -} - -var tableGlobalVariablesCols = []columnInfo{ - {name: "VARIABLE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "VARIABLE_VALUE", tp: mysql.TypeVarchar, size: 1024}, -} - -var tableSessionStatusCols = []columnInfo{ - {name: "VARIABLE_NAME", tp: 
mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "VARIABLE_VALUE", tp: mysql.TypeVarchar, size: 1024}, -} - -var tableOptimizerTraceCols = []columnInfo{ - {name: "QUERY", tp: mysql.TypeLongBlob, flag: mysql.NotNullFlag, deflt: ""}, - {name: "TRACE", tp: mysql.TypeLongBlob, flag: mysql.NotNullFlag, deflt: ""}, - {name: "MISSING_BYTES_BEYOND_MAX_MEM_SIZE", tp: mysql.TypeShort, size: 20, flag: mysql.NotNullFlag, deflt: 0}, - {name: "INSUFFICIENT_PRIVILEGES", tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, deflt: 0}, -} - -var tableTableSpacesCols = []columnInfo{ - {name: "TABLESPACE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag, deflt: ""}, - {name: "ENGINE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag, deflt: ""}, - {name: "TABLESPACE_TYPE", tp: mysql.TypeVarchar, size: 64}, - {name: "LOGFILE_GROUP_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "EXTENT_SIZE", tp: mysql.TypeLonglong, size: 21}, - {name: "AUTOEXTEND_SIZE", tp: mysql.TypeLonglong, size: 21}, - {name: "MAXIMUM_SIZE", tp: mysql.TypeLonglong, size: 21}, - {name: "NODEGROUP_ID", tp: mysql.TypeLonglong, size: 21}, - {name: "TABLESPACE_COMMENT", tp: mysql.TypeVarchar, size: 2048}, -} - -var tableCollationCharacterSetApplicabilityCols = []columnInfo{ - {name: "COLLATION_NAME", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, - {name: "CHARACTER_SET_NAME", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, -} - -var tableProcesslistCols = []columnInfo{ - {name: "ID", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag | mysql.UnsignedFlag, deflt: 0}, - {name: "USER", tp: mysql.TypeVarchar, size: 16, flag: mysql.NotNullFlag, deflt: ""}, - {name: "HOST", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag, deflt: ""}, - {name: "DB", tp: mysql.TypeVarchar, size: 64}, - {name: "COMMAND", tp: mysql.TypeVarchar, size: 16, flag: mysql.NotNullFlag, deflt: ""}, - {name: "TIME", tp: mysql.TypeLong, size: 7, flag: mysql.NotNullFlag, deflt: 0}, - {name: "STATE", tp: mysql.TypeVarchar, size: 7}, - {name: "INFO", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, - {name: "DIGEST", tp: mysql.TypeVarchar, size: 64, deflt: ""}, - {name: "MEM", tp: mysql.TypeLonglong, size: 21, flag: mysql.UnsignedFlag}, - {name: "DISK", tp: mysql.TypeLonglong, size: 21, flag: mysql.UnsignedFlag}, - {name: "TxnStart", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag, deflt: ""}, - {name: "RESOURCE_GROUP", tp: mysql.TypeVarchar, size: resourcegroup.MaxGroupNameLength, flag: mysql.NotNullFlag, deflt: ""}, - {name: "SESSION_ALIAS", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag, deflt: ""}, - {name: "CURRENT_AFFECTED_ROWS", tp: mysql.TypeLonglong, size: 21, flag: mysql.UnsignedFlag}, -} - -var tableTiDBIndexesCols = []columnInfo{ - {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64}, - {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "NON_UNIQUE", tp: mysql.TypeLonglong, size: 21}, - {name: "KEY_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "SEQ_IN_INDEX", tp: mysql.TypeLonglong, size: 21}, - {name: "COLUMN_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "SUB_PART", tp: mysql.TypeLonglong, size: 21}, - {name: "INDEX_COMMENT", tp: mysql.TypeVarchar, size: 1024}, - {name: "Expression", tp: mysql.TypeVarchar, size: 64}, - {name: "INDEX_ID", tp: mysql.TypeLonglong, size: 21}, - {name: "IS_VISIBLE", tp: mysql.TypeVarchar, size: 64}, - {name: "CLUSTERED", tp: mysql.TypeVarchar, size: 64}, - {name: "IS_GLOBAL", tp: mysql.TypeLonglong, size: 21}, -} 
- -var slowQueryCols = []columnInfo{ - {name: variable.SlowLogTimeStr, tp: mysql.TypeTimestamp, size: 26, decimal: 6, flag: mysql.PriKeyFlag | mysql.NotNullFlag | mysql.BinaryFlag}, - {name: variable.SlowLogTxnStartTSStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, - {name: variable.SlowLogUserStr, tp: mysql.TypeVarchar, size: 64}, - {name: variable.SlowLogHostStr, tp: mysql.TypeVarchar, size: 64}, - {name: variable.SlowLogConnIDStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, - {name: variable.SlowLogSessAliasStr, tp: mysql.TypeVarchar, size: 64}, - {name: variable.SlowLogExecRetryCount, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, - {name: variable.SlowLogExecRetryTime, tp: mysql.TypeDouble, size: 22}, - {name: variable.SlowLogQueryTimeStr, tp: mysql.TypeDouble, size: 22}, - {name: variable.SlowLogParseTimeStr, tp: mysql.TypeDouble, size: 22}, - {name: variable.SlowLogCompileTimeStr, tp: mysql.TypeDouble, size: 22}, - {name: variable.SlowLogRewriteTimeStr, tp: mysql.TypeDouble, size: 22}, - {name: variable.SlowLogPreprocSubQueriesStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, - {name: variable.SlowLogPreProcSubQueryTimeStr, tp: mysql.TypeDouble, size: 22}, - {name: variable.SlowLogOptimizeTimeStr, tp: mysql.TypeDouble, size: 22}, - {name: variable.SlowLogWaitTSTimeStr, tp: mysql.TypeDouble, size: 22}, - {name: execdetails.PreWriteTimeStr, tp: mysql.TypeDouble, size: 22}, - {name: execdetails.WaitPrewriteBinlogTimeStr, tp: mysql.TypeDouble, size: 22}, - {name: execdetails.CommitTimeStr, tp: mysql.TypeDouble, size: 22}, - {name: execdetails.GetCommitTSTimeStr, tp: mysql.TypeDouble, size: 22}, - {name: execdetails.CommitBackoffTimeStr, tp: mysql.TypeDouble, size: 22}, - {name: execdetails.BackoffTypesStr, tp: mysql.TypeVarchar, size: 64}, - {name: execdetails.ResolveLockTimeStr, tp: mysql.TypeDouble, size: 22}, - {name: execdetails.LocalLatchWaitTimeStr, tp: mysql.TypeDouble, size: 22}, - {name: execdetails.WriteKeysStr, tp: mysql.TypeLonglong, size: 22}, - {name: execdetails.WriteSizeStr, tp: mysql.TypeLonglong, size: 22}, - {name: execdetails.PrewriteRegionStr, tp: mysql.TypeLonglong, size: 22}, - {name: execdetails.TxnRetryStr, tp: mysql.TypeLonglong, size: 22}, - {name: execdetails.CopTimeStr, tp: mysql.TypeDouble, size: 22}, - {name: execdetails.ProcessTimeStr, tp: mysql.TypeDouble, size: 22}, - {name: execdetails.WaitTimeStr, tp: mysql.TypeDouble, size: 22}, - {name: execdetails.BackoffTimeStr, tp: mysql.TypeDouble, size: 22}, - {name: execdetails.LockKeysTimeStr, tp: mysql.TypeDouble, size: 22}, - {name: execdetails.RequestCountStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, - {name: execdetails.TotalKeysStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, - {name: execdetails.ProcessKeysStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, - {name: execdetails.RocksdbDeleteSkippedCountStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, - {name: execdetails.RocksdbKeySkippedCountStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, - {name: execdetails.RocksdbBlockCacheHitCountStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, - {name: execdetails.RocksdbBlockReadCountStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, - {name: execdetails.RocksdbBlockReadByteStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.UnsignedFlag}, - {name: variable.SlowLogDBStr, tp: mysql.TypeVarchar, size: 64}, - {name: 
variable.SlowLogIndexNamesStr, tp: mysql.TypeVarchar, size: 100}, - {name: variable.SlowLogIsInternalStr, tp: mysql.TypeTiny, size: 1}, - {name: variable.SlowLogDigestStr, tp: mysql.TypeVarchar, size: 64}, - {name: variable.SlowLogStatsInfoStr, tp: mysql.TypeVarchar, size: 512}, - {name: variable.SlowLogCopProcAvg, tp: mysql.TypeDouble, size: 22}, - {name: variable.SlowLogCopProcP90, tp: mysql.TypeDouble, size: 22}, - {name: variable.SlowLogCopProcMax, tp: mysql.TypeDouble, size: 22}, - {name: variable.SlowLogCopProcAddr, tp: mysql.TypeVarchar, size: 64}, - {name: variable.SlowLogCopWaitAvg, tp: mysql.TypeDouble, size: 22}, - {name: variable.SlowLogCopWaitP90, tp: mysql.TypeDouble, size: 22}, - {name: variable.SlowLogCopWaitMax, tp: mysql.TypeDouble, size: 22}, - {name: variable.SlowLogCopWaitAddr, tp: mysql.TypeVarchar, size: 64}, - {name: variable.SlowLogMemMax, tp: mysql.TypeLonglong, size: 20}, - {name: variable.SlowLogDiskMax, tp: mysql.TypeLonglong, size: 20}, - {name: variable.SlowLogKVTotal, tp: mysql.TypeDouble, size: 22}, - {name: variable.SlowLogPDTotal, tp: mysql.TypeDouble, size: 22}, - {name: variable.SlowLogBackoffTotal, tp: mysql.TypeDouble, size: 22}, - {name: variable.SlowLogWriteSQLRespTotal, tp: mysql.TypeDouble, size: 22}, - {name: variable.SlowLogResultRows, tp: mysql.TypeLonglong, size: 22}, - {name: variable.SlowLogWarnings, tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, - {name: variable.SlowLogBackoffDetail, tp: mysql.TypeVarchar, size: 4096}, - {name: variable.SlowLogPrepared, tp: mysql.TypeTiny, size: 1}, - {name: variable.SlowLogSucc, tp: mysql.TypeTiny, size: 1}, - {name: variable.SlowLogIsExplicitTxn, tp: mysql.TypeTiny, size: 1}, - {name: variable.SlowLogIsWriteCacheTable, tp: mysql.TypeTiny, size: 1}, - {name: variable.SlowLogPlanFromCache, tp: mysql.TypeTiny, size: 1}, - {name: variable.SlowLogPlanFromBinding, tp: mysql.TypeTiny, size: 1}, - {name: variable.SlowLogHasMoreResults, tp: mysql.TypeTiny, size: 1}, - {name: variable.SlowLogResourceGroup, tp: mysql.TypeVarchar, size: 64}, - {name: variable.SlowLogRRU, tp: mysql.TypeDouble, size: 22}, - {name: variable.SlowLogWRU, tp: mysql.TypeDouble, size: 22}, - {name: variable.SlowLogWaitRUDuration, tp: mysql.TypeDouble, size: 22}, - {name: variable.SlowLogPlan, tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, - {name: variable.SlowLogPlanDigest, tp: mysql.TypeVarchar, size: 128}, - {name: variable.SlowLogBinaryPlan, tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, - {name: variable.SlowLogPrevStmt, tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, - {name: variable.SlowLogQuerySQLStr, tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, -} - -// TableTiDBHotRegionsCols is TiDB hot region mem table columns. -var TableTiDBHotRegionsCols = []columnInfo{ - {name: "TABLE_ID", tp: mysql.TypeLonglong, size: 21}, - {name: "INDEX_ID", tp: mysql.TypeLonglong, size: 21}, - {name: "DB_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "INDEX_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "REGION_ID", tp: mysql.TypeLonglong, size: 21}, - {name: "TYPE", tp: mysql.TypeVarchar, size: 64}, - {name: "MAX_HOT_DEGREE", tp: mysql.TypeLonglong, size: 21}, - {name: "REGION_COUNT", tp: mysql.TypeLonglong, size: 21}, - {name: "FLOW_BYTES", tp: mysql.TypeLonglong, size: 21}, -} - -// TableTiDBHotRegionsHistoryCols is TiDB hot region history mem table columns. 
-var TableTiDBHotRegionsHistoryCols = []columnInfo{
- {name: "UPDATE_TIME", tp: mysql.TypeTimestamp, size: 26, decimal: 6},
- {name: "DB_NAME", tp: mysql.TypeVarchar, size: 64},
- {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64},
- {name: "TABLE_ID", tp: mysql.TypeLonglong, size: 21},
- {name: "INDEX_NAME", tp: mysql.TypeVarchar, size: 64},
- {name: "INDEX_ID", tp: mysql.TypeLonglong, size: 21},
- {name: "REGION_ID", tp: mysql.TypeLonglong, size: 21},
- {name: "STORE_ID", tp: mysql.TypeLonglong, size: 21},
- {name: "PEER_ID", tp: mysql.TypeLonglong, size: 21},
- {name: "IS_LEARNER", tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, deflt: 0},
- {name: "IS_LEADER", tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, deflt: 0},
- {name: "TYPE", tp: mysql.TypeVarchar, size: 64},
- {name: "HOT_DEGREE", tp: mysql.TypeLonglong, size: 21},
- {name: "FLOW_BYTES", tp: mysql.TypeDouble, size: 22},
- {name: "KEY_RATE", tp: mysql.TypeDouble, size: 22},
- {name: "QUERY_RATE", tp: mysql.TypeDouble, size: 22},
-}
-
-// GetTableTiDBHotRegionsHistoryCols is to get TableTiDBHotRegionsHistoryCols.
-// It is an optimization because Go doesn't support const arrays. The solution is to use initialization functions.
-// It is useful in the BCE optimization.
-// https://go101.org/article/bounds-check-elimination.html
-func GetTableTiDBHotRegionsHistoryCols() []columnInfo {
- return TableTiDBHotRegionsHistoryCols
-}
-
-// TableTiKVStoreStatusCols is TiDB kv store status columns.
-var TableTiKVStoreStatusCols = []columnInfo{
- {name: "STORE_ID", tp: mysql.TypeLonglong, size: 21},
- {name: "ADDRESS", tp: mysql.TypeVarchar, size: 64},
- {name: "STORE_STATE", tp: mysql.TypeLonglong, size: 21},
- {name: "STORE_STATE_NAME", tp: mysql.TypeVarchar, size: 64},
- {name: "LABEL", tp: mysql.TypeJSON, size: 51},
- {name: "VERSION", tp: mysql.TypeVarchar, size: 64},
- {name: "CAPACITY", tp: mysql.TypeVarchar, size: 64},
- {name: "AVAILABLE", tp: mysql.TypeVarchar, size: 64},
- {name: "LEADER_COUNT", tp: mysql.TypeLonglong, size: 21},
- {name: "LEADER_WEIGHT", tp: mysql.TypeDouble, size: 22},
- {name: "LEADER_SCORE", tp: mysql.TypeDouble, size: 22},
- {name: "LEADER_SIZE", tp: mysql.TypeLonglong, size: 21},
- {name: "REGION_COUNT", tp: mysql.TypeLonglong, size: 21},
- {name: "REGION_WEIGHT", tp: mysql.TypeDouble, size: 22},
- {name: "REGION_SCORE", tp: mysql.TypeDouble, size: 22},
- {name: "REGION_SIZE", tp: mysql.TypeLonglong, size: 21},
- {name: "START_TS", tp: mysql.TypeDatetime},
- {name: "LAST_HEARTBEAT_TS", tp: mysql.TypeDatetime},
- {name: "UPTIME", tp: mysql.TypeVarchar, size: 64},
-}
-
-var tableAnalyzeStatusCols = []columnInfo{
- {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64},
- {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64},
- {name: "PARTITION_NAME", tp: mysql.TypeVarchar, size: 64},
- {name: "JOB_INFO", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength},
- {name: "PROCESSED_ROWS", tp: mysql.TypeLonglong, size: 64, flag: mysql.UnsignedFlag},
- {name: "START_TIME", tp: mysql.TypeDatetime},
- {name: "END_TIME", tp: mysql.TypeDatetime},
- {name: "STATE", tp: mysql.TypeVarchar, size: 64},
- {name: "FAIL_REASON", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength},
- {name: "INSTANCE", tp: mysql.TypeVarchar, size: 512},
- {name: "PROCESS_ID", tp: mysql.TypeLonglong, size: 64, flag: mysql.UnsignedFlag},
- {name: "REMAINING_SECONDS", tp: mysql.TypeVarchar, size: 512},
- {name: "PROGRESS", tp: mysql.TypeDouble, size: 22, decimal: 6},
- {name: "ESTIMATED_TOTAL_ROWS", tp: mysql.TypeLonglong, size: 64, flag: mysql.UnsignedFlag},
-}
-
-// TableTiKVRegionStatusCols is TiKV region status mem table columns.
-var TableTiKVRegionStatusCols = []columnInfo{
- {name: "REGION_ID", tp: mysql.TypeLonglong, size: 21},
- {name: "START_KEY", tp: mysql.TypeBlob, size: types.UnspecifiedLength},
- {name: "END_KEY", tp: mysql.TypeBlob, size: types.UnspecifiedLength},
- {name: "TABLE_ID", tp: mysql.TypeLonglong, size: 21},
- {name: "DB_NAME", tp: mysql.TypeVarchar, size: 64},
- {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64},
- {name: "IS_INDEX", tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, deflt: 0},
- {name: "INDEX_ID", tp: mysql.TypeLonglong, size: 21},
- {name: "INDEX_NAME", tp: mysql.TypeVarchar, size: 64},
- {name: "IS_PARTITION", tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, deflt: 0},
- {name: "PARTITION_ID", tp: mysql.TypeLonglong, size: 21},
- {name: "PARTITION_NAME", tp: mysql.TypeVarchar, size: 64},
- {name: "EPOCH_CONF_VER", tp: mysql.TypeLonglong, size: 21},
- {name: "EPOCH_VERSION", tp: mysql.TypeLonglong, size: 21},
- {name: "WRITTEN_BYTES", tp: mysql.TypeLonglong, size: 21},
- {name: "READ_BYTES", tp: mysql.TypeLonglong, size: 21},
- {name: "APPROXIMATE_SIZE", tp: mysql.TypeLonglong, size: 21},
- {name: "APPROXIMATE_KEYS", tp: mysql.TypeLonglong, size: 21},
- {name: "REPLICATIONSTATUS_STATE", tp: mysql.TypeVarchar, size: 64},
- {name: "REPLICATIONSTATUS_STATEID", tp: mysql.TypeLonglong, size: 21},
-}
-
-// TableTiKVRegionPeersCols is TiKV region peers mem table columns.
-var TableTiKVRegionPeersCols = []columnInfo{
- {name: "REGION_ID", tp: mysql.TypeLonglong, size: 21},
- {name: "PEER_ID", tp: mysql.TypeLonglong, size: 21},
- {name: "STORE_ID", tp: mysql.TypeLonglong, size: 21},
- {name: "IS_LEARNER", tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, deflt: 0},
- {name: "IS_LEADER", tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, deflt: 0},
- {name: "STATUS", tp: mysql.TypeVarchar, size: 10, deflt: 0},
- {name: "DOWN_SECONDS", tp: mysql.TypeLonglong, size: 21, deflt: 0},
-}
-
-// GetTableTiKVRegionPeersCols is to get TableTiKVRegionPeersCols.
-// It is an optimization because Go doesn't support const arrays. The solution is to use initialization functions.
-// It is useful in the BCE optimization.
-// https://go101.org/article/bounds-check-elimination.html
-func GetTableTiKVRegionPeersCols() []columnInfo {
- return TableTiKVRegionPeersCols
-}
-
-var tableTiDBServersInfoCols = []columnInfo{
- {name: "DDL_ID", tp: mysql.TypeVarchar, size: 64},
- {name: "IP", tp: mysql.TypeVarchar, size: 64},
- {name: "PORT", tp: mysql.TypeLonglong, size: 21},
- {name: "STATUS_PORT", tp: mysql.TypeLonglong, size: 21},
- {name: "LEASE", tp: mysql.TypeVarchar, size: 64},
- {name: "VERSION", tp: mysql.TypeVarchar, size: 64},
- {name: "GIT_HASH", tp: mysql.TypeVarchar, size: 64},
- {name: "BINLOG_STATUS", tp: mysql.TypeVarchar, size: 64},
- {name: "LABELS", tp: mysql.TypeVarchar, size: 128},
-}
-
-var tableClusterConfigCols = []columnInfo{
- {name: "TYPE", tp: mysql.TypeVarchar, size: 64},
- {name: "INSTANCE", tp: mysql.TypeVarchar, size: 64},
- {name: "KEY", tp: mysql.TypeVarchar, size: 256},
- {name: "VALUE", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength},
-}
-
-var tableClusterLogCols = []columnInfo{
- {name: "TIME", tp: mysql.TypeVarchar, size: 32},
- {name: "TYPE", tp: mysql.TypeVarchar, size: 64},
- {name: "INSTANCE", tp: mysql.TypeVarchar, size: 64},
- {name: "LEVEL", tp: mysql.TypeVarchar, size: 8},
- {name: "MESSAGE", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength},
-}
-
-var tableClusterLoadCols = []columnInfo{
- {name: "TYPE", tp: mysql.TypeVarchar, size: 64},
- {name: "INSTANCE", tp: mysql.TypeVarchar, size: 64},
- {name: "DEVICE_TYPE", tp: mysql.TypeVarchar, size: 64},
- {name: "DEVICE_NAME", tp: mysql.TypeVarchar, size: 64},
- {name: "NAME", tp: mysql.TypeVarchar, size: 256},
- {name: "VALUE", tp: mysql.TypeVarchar, size: 128},
-}
-
-var tableClusterHardwareCols = []columnInfo{
- {name: "TYPE", tp: mysql.TypeVarchar, size: 64},
- {name: "INSTANCE", tp: mysql.TypeVarchar, size: 64},
- {name: "DEVICE_TYPE", tp: mysql.TypeVarchar, size: 64},
- {name: "DEVICE_NAME", tp: mysql.TypeVarchar, size: 64},
- {name: "NAME", tp: mysql.TypeVarchar, size: 256},
- {name: "VALUE", tp: mysql.TypeVarchar, size: 128},
-}
-
-var tableClusterSystemInfoCols = []columnInfo{
- {name: "TYPE", tp: mysql.TypeVarchar, size: 64},
- {name: "INSTANCE", tp: mysql.TypeVarchar, size: 64},
- {name: "SYSTEM_TYPE", tp: mysql.TypeVarchar, size: 64},
- {name: "SYSTEM_NAME", tp: mysql.TypeVarchar, size: 64},
- {name: "NAME", tp: mysql.TypeVarchar, size: 256},
- {name: "VALUE", tp: mysql.TypeVarchar, size: 128},
-}
-
-var filesCols = []columnInfo{
- {name: "FILE_ID", tp: mysql.TypeLonglong, size: 4},
- {name: "FILE_NAME", tp: mysql.TypeVarchar, size: 4000},
- {name: "FILE_TYPE", tp: mysql.TypeVarchar, size: 20},
- {name: "TABLESPACE_NAME", tp: mysql.TypeVarchar, size: 64},
- {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 64},
- {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64},
- {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64},
- {name: "LOGFILE_GROUP_NAME", tp: mysql.TypeVarchar, size: 64},
- {name: "LOGFILE_GROUP_NUMBER", tp: mysql.TypeLonglong, size: 32},
- {name: "ENGINE", tp: mysql.TypeVarchar, size: 64},
- {name: "FULLTEXT_KEYS", tp: mysql.TypeVarchar, size: 64},
- {name: "DELETED_ROWS", tp: mysql.TypeLonglong, size: 4},
- {name: "UPDATE_COUNT", tp: mysql.TypeLonglong, size: 4},
- {name: "FREE_EXTENTS", tp: mysql.TypeLonglong, size: 4},
- {name: "TOTAL_EXTENTS", tp: mysql.TypeLonglong, size: 4},
- {name: "EXTENT_SIZE", tp: mysql.TypeLonglong, size: 4},
- {name: "INITIAL_SIZE", tp: mysql.TypeLonglong, size: 21},
- {name: "MAXIMUM_SIZE", tp: mysql.TypeLonglong, size: 21},
- {name:
"AUTOEXTEND_SIZE", tp: mysql.TypeLonglong, size: 21}, - {name: "CREATION_TIME", tp: mysql.TypeDatetime, size: -1}, - {name: "LAST_UPDATE_TIME", tp: mysql.TypeDatetime, size: -1}, - {name: "LAST_ACCESS_TIME", tp: mysql.TypeDatetime, size: -1}, - {name: "RECOVER_TIME", tp: mysql.TypeLonglong, size: 4}, - {name: "TRANSACTION_COUNTER", tp: mysql.TypeLonglong, size: 4}, - {name: "VERSION", tp: mysql.TypeLonglong, size: 21}, - {name: "ROW_FORMAT", tp: mysql.TypeVarchar, size: 10}, - {name: "TABLE_ROWS", tp: mysql.TypeLonglong, size: 21}, - {name: "AVG_ROW_LENGTH", tp: mysql.TypeLonglong, size: 21}, - {name: "DATA_LENGTH", tp: mysql.TypeLonglong, size: 21}, - {name: "MAX_DATA_LENGTH", tp: mysql.TypeLonglong, size: 21}, - {name: "INDEX_LENGTH", tp: mysql.TypeLonglong, size: 21}, - {name: "DATA_FREE", tp: mysql.TypeLonglong, size: 21}, - {name: "CREATE_TIME", tp: mysql.TypeDatetime, size: -1}, - {name: "UPDATE_TIME", tp: mysql.TypeDatetime, size: -1}, - {name: "CHECK_TIME", tp: mysql.TypeDatetime, size: -1}, - {name: "CHECKSUM", tp: mysql.TypeLonglong, size: 21}, - {name: "STATUS", tp: mysql.TypeVarchar, size: 20}, - {name: "EXTRA", tp: mysql.TypeVarchar, size: 255}, -} - -var tableClusterInfoCols = []columnInfo{ - {name: "TYPE", tp: mysql.TypeVarchar, size: 64}, - {name: "INSTANCE", tp: mysql.TypeVarchar, size: 64}, - {name: "STATUS_ADDRESS", tp: mysql.TypeVarchar, size: 64}, - {name: "VERSION", tp: mysql.TypeVarchar, size: 64}, - {name: "GIT_HASH", tp: mysql.TypeVarchar, size: 64}, - {name: "START_TIME", tp: mysql.TypeDatetime, size: 19}, - {name: "UPTIME", tp: mysql.TypeVarchar, size: 32}, - {name: "SERVER_ID", tp: mysql.TypeLonglong, size: 21, comment: "invalid if the configuration item `enable-global-kill` is set to FALSE"}, -} - -var tableTableTiFlashReplicaCols = []columnInfo{ - {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64}, - {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "TABLE_ID", tp: mysql.TypeLonglong, size: 21}, - {name: "REPLICA_COUNT", tp: mysql.TypeLonglong, size: 64}, - {name: "LOCATION_LABELS", tp: mysql.TypeVarchar, size: 64}, - {name: "AVAILABLE", tp: mysql.TypeTiny, size: 1}, - {name: "PROGRESS", tp: mysql.TypeDouble, size: 22}, -} - -var tableInspectionResultCols = []columnInfo{ - {name: "RULE", tp: mysql.TypeVarchar, size: 64}, - {name: "ITEM", tp: mysql.TypeVarchar, size: 64}, - {name: "TYPE", tp: mysql.TypeVarchar, size: 64}, - {name: "INSTANCE", tp: mysql.TypeVarchar, size: 64}, - {name: "STATUS_ADDRESS", tp: mysql.TypeVarchar, size: 64}, - {name: "VALUE", tp: mysql.TypeVarchar, size: 64}, - {name: "REFERENCE", tp: mysql.TypeVarchar, size: 64}, - {name: "SEVERITY", tp: mysql.TypeVarchar, size: 64}, - {name: "DETAILS", tp: mysql.TypeVarchar, size: 256}, -} - -var tableInspectionSummaryCols = []columnInfo{ - {name: "RULE", tp: mysql.TypeVarchar, size: 64}, - {name: "INSTANCE", tp: mysql.TypeVarchar, size: 64}, - {name: "METRICS_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "LABEL", tp: mysql.TypeVarchar, size: 64}, - {name: "QUANTILE", tp: mysql.TypeDouble, size: 22}, - {name: "AVG_VALUE", tp: mysql.TypeDouble, size: 22, decimal: 6}, - {name: "MIN_VALUE", tp: mysql.TypeDouble, size: 22, decimal: 6}, - {name: "MAX_VALUE", tp: mysql.TypeDouble, size: 22, decimal: 6}, - {name: "COMMENT", tp: mysql.TypeVarchar, size: 256}, -} - -var tableInspectionRulesCols = []columnInfo{ - {name: "NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "TYPE", tp: mysql.TypeVarchar, size: 64}, - {name: "COMMENT", tp: mysql.TypeVarchar, size: 256}, -} - -var 
tableMetricTablesCols = []columnInfo{ - {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "PROMQL", tp: mysql.TypeVarchar, size: 64}, - {name: "LABELS", tp: mysql.TypeVarchar, size: 64}, - {name: "QUANTILE", tp: mysql.TypeDouble, size: 22}, - {name: "COMMENT", tp: mysql.TypeVarchar, size: 256}, -} - -var tableMetricSummaryCols = []columnInfo{ - {name: "METRICS_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "QUANTILE", tp: mysql.TypeDouble, size: 22}, - {name: "SUM_VALUE", tp: mysql.TypeDouble, size: 22, decimal: 6}, - {name: "AVG_VALUE", tp: mysql.TypeDouble, size: 22, decimal: 6}, - {name: "MIN_VALUE", tp: mysql.TypeDouble, size: 22, decimal: 6}, - {name: "MAX_VALUE", tp: mysql.TypeDouble, size: 22, decimal: 6}, - {name: "COMMENT", tp: mysql.TypeVarchar, size: 256}, -} - -var tableMetricSummaryByLabelCols = []columnInfo{ - {name: "INSTANCE", tp: mysql.TypeVarchar, size: 64}, - {name: "METRICS_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "LABEL", tp: mysql.TypeVarchar, size: 64}, - {name: "QUANTILE", tp: mysql.TypeDouble, size: 22}, - {name: "SUM_VALUE", tp: mysql.TypeDouble, size: 22, decimal: 6}, - {name: "AVG_VALUE", tp: mysql.TypeDouble, size: 22, decimal: 6}, - {name: "MIN_VALUE", tp: mysql.TypeDouble, size: 22, decimal: 6}, - {name: "MAX_VALUE", tp: mysql.TypeDouble, size: 22, decimal: 6}, - {name: "COMMENT", tp: mysql.TypeVarchar, size: 256}, -} - -var tableDDLJobsCols = []columnInfo{ - {name: "JOB_ID", tp: mysql.TypeLonglong, size: 21}, - {name: "DB_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "JOB_TYPE", tp: mysql.TypeVarchar, size: 64}, - {name: "SCHEMA_STATE", tp: mysql.TypeVarchar, size: 64}, - {name: "SCHEMA_ID", tp: mysql.TypeLonglong, size: 21}, - {name: "TABLE_ID", tp: mysql.TypeLonglong, size: 21}, - {name: "ROW_COUNT", tp: mysql.TypeLonglong, size: 21}, - {name: "CREATE_TIME", tp: mysql.TypeDatetime, size: 26, decimal: 6}, - {name: "START_TIME", tp: mysql.TypeDatetime, size: 26, decimal: 6}, - {name: "END_TIME", tp: mysql.TypeDatetime, size: 26, decimal: 6}, - {name: "STATE", tp: mysql.TypeVarchar, size: 64}, - {name: "QUERY", tp: mysql.TypeBlob, size: types.UnspecifiedLength}, -} - -var tableSequencesCols = []columnInfo{ - {name: "TABLE_CATALOG", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag}, - {name: "SEQUENCE_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "SEQUENCE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "CACHE", tp: mysql.TypeTiny, flag: mysql.NotNullFlag}, - {name: "CACHE_VALUE", tp: mysql.TypeLonglong, size: 21}, - {name: "CYCLE", tp: mysql.TypeTiny, flag: mysql.NotNullFlag}, - {name: "INCREMENT", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag}, - {name: "MAX_VALUE", tp: mysql.TypeLonglong, size: 21}, - {name: "MIN_VALUE", tp: mysql.TypeLonglong, size: 21}, - {name: "START", tp: mysql.TypeLonglong, size: 21}, - {name: "COMMENT", tp: mysql.TypeVarchar, size: 64}, -} - -var tableStatementsSummaryCols = []columnInfo{ - {name: stmtsummary.SummaryBeginTimeStr, tp: mysql.TypeTimestamp, size: 26, flag: mysql.NotNullFlag, comment: "Begin time of this summary"}, - {name: stmtsummary.SummaryEndTimeStr, tp: mysql.TypeTimestamp, size: 26, flag: mysql.NotNullFlag, comment: "End time of this summary"}, - {name: stmtsummary.StmtTypeStr, tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag, comment: "Statement type"}, - {name: stmtsummary.SchemaNameStr, tp: mysql.TypeVarchar, size: 64, comment: "Current 
schema"}, - {name: stmtsummary.DigestStr, tp: mysql.TypeVarchar, size: 64}, - {name: stmtsummary.DigestTextStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, flag: mysql.NotNullFlag, comment: "Normalized statement"}, - {name: stmtsummary.TableNamesStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "Involved tables"}, - {name: stmtsummary.IndexNamesStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "Used indices"}, - {name: stmtsummary.SampleUserStr, tp: mysql.TypeVarchar, size: 64, comment: "Sampled user who executed these statements"}, - {name: stmtsummary.ExecCountStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Count of executions"}, - {name: stmtsummary.SumErrorsStr, tp: mysql.TypeLong, size: 11, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Sum of errors"}, - {name: stmtsummary.SumWarningsStr, tp: mysql.TypeLong, size: 11, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Sum of warnings"}, - {name: stmtsummary.SumLatencyStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Sum latency of these statements"}, - {name: stmtsummary.MaxLatencyStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max latency of these statements"}, - {name: stmtsummary.MinLatencyStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Min latency of these statements"}, - {name: stmtsummary.AvgLatencyStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average latency of these statements"}, - {name: stmtsummary.AvgParseLatencyStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average latency of parsing"}, - {name: stmtsummary.MaxParseLatencyStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max latency of parsing"}, - {name: stmtsummary.AvgCompileLatencyStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average latency of compiling"}, - {name: stmtsummary.MaxCompileLatencyStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max latency of compiling"}, - {name: stmtsummary.SumCopTaskNumStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Total number of CopTasks"}, - {name: stmtsummary.MaxCopProcessTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max processing time of CopTasks"}, - {name: stmtsummary.MaxCopProcessAddressStr, tp: mysql.TypeVarchar, size: 256, comment: "Address of the CopTask with max processing time"}, - {name: stmtsummary.MaxCopWaitTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max waiting time of CopTasks"}, - {name: stmtsummary.MaxCopWaitAddressStr, tp: mysql.TypeVarchar, size: 256, comment: "Address of the CopTask with max waiting time"}, - {name: stmtsummary.AvgProcessTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average processing time in TiKV"}, - {name: stmtsummary.MaxProcessTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max processing time in TiKV"}, - {name: stmtsummary.AvgWaitTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average waiting time in TiKV"}, - {name: 
stmtsummary.MaxWaitTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max waiting time in TiKV"}, - {name: stmtsummary.AvgBackoffTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average waiting time before retry"}, - {name: stmtsummary.MaxBackoffTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max waiting time before retry"}, - {name: stmtsummary.AvgTotalKeysStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average number of scanned keys"}, - {name: stmtsummary.MaxTotalKeysStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max number of scanned keys"}, - {name: stmtsummary.AvgProcessedKeysStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average number of processed keys"}, - {name: stmtsummary.MaxProcessedKeysStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max number of processed keys"}, - {name: stmtsummary.AvgRocksdbDeleteSkippedCountStr, tp: mysql.TypeDouble, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average number of rocksdb delete skipped count"}, - {name: stmtsummary.MaxRocksdbDeleteSkippedCountStr, tp: mysql.TypeLong, size: 11, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max number of rocksdb delete skipped count"}, - {name: stmtsummary.AvgRocksdbKeySkippedCountStr, tp: mysql.TypeDouble, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average number of rocksdb key skipped count"}, - {name: stmtsummary.MaxRocksdbKeySkippedCountStr, tp: mysql.TypeLong, size: 11, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max number of rocksdb key skipped count"}, - {name: stmtsummary.AvgRocksdbBlockCacheHitCountStr, tp: mysql.TypeDouble, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average number of rocksdb block cache hit count"}, - {name: stmtsummary.MaxRocksdbBlockCacheHitCountStr, tp: mysql.TypeLong, size: 11, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max number of rocksdb block cache hit count"}, - {name: stmtsummary.AvgRocksdbBlockReadCountStr, tp: mysql.TypeDouble, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average number of rocksdb block read count"}, - {name: stmtsummary.MaxRocksdbBlockReadCountStr, tp: mysql.TypeLong, size: 11, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max number of rocksdb block read count"}, - {name: stmtsummary.AvgRocksdbBlockReadByteStr, tp: mysql.TypeDouble, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average number of rocksdb block read byte"}, - {name: stmtsummary.MaxRocksdbBlockReadByteStr, tp: mysql.TypeLong, size: 11, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max number of rocksdb block read byte"}, - {name: stmtsummary.AvgPrewriteTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average time of prewrite phase"}, - {name: stmtsummary.MaxPrewriteTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max time of prewrite phase"}, - {name: stmtsummary.AvgCommitTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average time of commit phase"}, - {name: stmtsummary.MaxCommitTimeStr, tp: mysql.TypeLonglong, size: 20, flag: 
mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max time of commit phase"}, - {name: stmtsummary.AvgGetCommitTsTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average time of getting commit_ts"}, - {name: stmtsummary.MaxGetCommitTsTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max time of getting commit_ts"}, - {name: stmtsummary.AvgCommitBackoffTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average time before retry during commit phase"}, - {name: stmtsummary.MaxCommitBackoffTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max time before retry during commit phase"}, - {name: stmtsummary.AvgResolveLockTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average time for resolving locks"}, - {name: stmtsummary.MaxResolveLockTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max time for resolving locks"}, - {name: stmtsummary.AvgLocalLatchWaitTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average waiting time of local transaction"}, - {name: stmtsummary.MaxLocalLatchWaitTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max waiting time of local transaction"}, - {name: stmtsummary.AvgWriteKeysStr, tp: mysql.TypeDouble, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average count of written keys"}, - {name: stmtsummary.MaxWriteKeysStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max count of written keys"}, - {name: stmtsummary.AvgWriteSizeStr, tp: mysql.TypeDouble, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average amount of written bytes"}, - {name: stmtsummary.MaxWriteSizeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max amount of written bytes"}, - {name: stmtsummary.AvgPrewriteRegionsStr, tp: mysql.TypeDouble, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average number of involved regions in prewrite phase"}, - {name: stmtsummary.MaxPrewriteRegionsStr, tp: mysql.TypeLong, size: 11, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max number of involved regions in prewrite phase"}, - {name: stmtsummary.AvgTxnRetryStr, tp: mysql.TypeDouble, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average number of transaction retries"}, - {name: stmtsummary.MaxTxnRetryStr, tp: mysql.TypeLong, size: 11, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max number of transaction retries"}, - {name: stmtsummary.SumExecRetryStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Sum number of execution retries in pessimistic transactions"}, - {name: stmtsummary.SumExecRetryTimeStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Sum time of execution retries in pessimistic transactions"}, - {name: stmtsummary.SumBackoffTimesStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Sum of retries"}, - {name: stmtsummary.BackoffTypesStr, tp: mysql.TypeVarchar, size: 1024, comment: "Types of errors and the number of retries for each type"}, - {name: stmtsummary.AvgMemStr, tp: mysql.TypeLonglong, size: 20, 
flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average memory(byte) used"}, - {name: stmtsummary.MaxMemStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max memory(byte) used"}, - {name: stmtsummary.AvgDiskStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average disk space(byte) used"}, - {name: stmtsummary.MaxDiskStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max disk space(byte) used"}, - {name: stmtsummary.AvgKvTimeStr, tp: mysql.TypeLonglong, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average time of TiKV used"}, - {name: stmtsummary.AvgPdTimeStr, tp: mysql.TypeLonglong, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average time of PD used"}, - {name: stmtsummary.AvgBackoffTotalTimeStr, tp: mysql.TypeLonglong, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average time of Backoff used"}, - {name: stmtsummary.AvgWriteSQLRespTimeStr, tp: mysql.TypeLonglong, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average time of write sql resp used"}, - {name: stmtsummary.MaxResultRowsStr, tp: mysql.TypeLonglong, size: 22, flag: mysql.NotNullFlag, comment: "Max count of sql result rows"}, - {name: stmtsummary.MinResultRowsStr, tp: mysql.TypeLonglong, size: 22, flag: mysql.NotNullFlag, comment: "Min count of sql result rows"}, - {name: stmtsummary.AvgResultRowsStr, tp: mysql.TypeLonglong, size: 22, flag: mysql.NotNullFlag, comment: "Average count of sql result rows"}, - {name: stmtsummary.PreparedStr, tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, comment: "Whether prepared"}, - {name: stmtsummary.AvgAffectedRowsStr, tp: mysql.TypeDouble, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Average number of rows affected"}, - {name: stmtsummary.FirstSeenStr, tp: mysql.TypeTimestamp, size: 26, flag: mysql.NotNullFlag, comment: "The time these statements are seen for the first time"}, - {name: stmtsummary.LastSeenStr, tp: mysql.TypeTimestamp, size: 26, flag: mysql.NotNullFlag, comment: "The time these statements are seen for the last time"}, - {name: stmtsummary.PlanInCacheStr, tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, comment: "Whether the last statement hit plan cache"}, - {name: stmtsummary.PlanCacheHitsStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag, comment: "The number of times these statements hit plan cache"}, - {name: stmtsummary.PlanInBindingStr, tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, comment: "Whether the last statement is matched with the hints in the binding"}, - {name: stmtsummary.QuerySampleTextStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "Sampled original statement"}, - {name: stmtsummary.PrevSampleTextStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "The previous statement before commit"}, - {name: stmtsummary.PlanDigestStr, tp: mysql.TypeVarchar, size: 64, comment: "Digest of its execution plan"}, - {name: stmtsummary.PlanStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "Sampled execution plan"}, - {name: stmtsummary.BinaryPlan, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "Sampled binary plan"}, - {name: stmtsummary.Charset, tp: mysql.TypeVarchar, size: 64, comment: "Sampled charset"}, - {name: stmtsummary.Collation, tp: mysql.TypeVarchar, size: 64, comment: "Sampled collation"}, - {name: stmtsummary.PlanHint, tp: 
mysql.TypeVarchar, size: 64, comment: "Sampled plan hint"}, - {name: stmtsummary.MaxRequestUnitReadStr, tp: mysql.TypeDouble, flag: mysql.NotNullFlag | mysql.UnsignedFlag, size: 22, comment: "Max read request-unit cost of these statements"}, - {name: stmtsummary.AvgRequestUnitReadStr, tp: mysql.TypeDouble, flag: mysql.NotNullFlag | mysql.UnsignedFlag, size: 22, comment: "Average read request-unit cost of these statements"}, - {name: stmtsummary.MaxRequestUnitWriteStr, tp: mysql.TypeDouble, flag: mysql.NotNullFlag | mysql.UnsignedFlag, size: 22, comment: "Max write request-unit cost of these statements"}, - {name: stmtsummary.AvgRequestUnitWriteStr, tp: mysql.TypeDouble, flag: mysql.NotNullFlag | mysql.UnsignedFlag, size: 22, comment: "Average write request-unit cost of these statements"}, - {name: stmtsummary.MaxQueuedRcTimeStr, tp: mysql.TypeLonglong, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max time of waiting for available request-units"}, - {name: stmtsummary.AvgQueuedRcTimeStr, tp: mysql.TypeLonglong, size: 22, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Max time of waiting for available request-units"}, - {name: stmtsummary.ResourceGroupName, tp: mysql.TypeVarchar, size: 64, comment: "Bind resource group name"}, - {name: stmtsummary.PlanCacheUnqualifiedStr, tp: mysql.TypeLonglong, size: 20, flag: mysql.NotNullFlag, comment: "The number of times that these statements are not supported by the plan cache"}, - {name: stmtsummary.PlanCacheUnqualifiedLastReasonStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "The last reason why the statement is not supported by the plan cache"}, -} - -var tableStorageStatsCols = []columnInfo{ - {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64}, - {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "TABLE_ID", tp: mysql.TypeLonglong, size: 21}, - {name: "PEER_COUNT", tp: mysql.TypeLonglong, size: 21}, - {name: "REGION_COUNT", tp: mysql.TypeLonglong, size: 21, comment: "The region count of single replica of the table"}, - {name: "EMPTY_REGION_COUNT", tp: mysql.TypeLonglong, size: 21, comment: "The region count of single replica of the table"}, - {name: "TABLE_SIZE", tp: mysql.TypeLonglong, size: 64, comment: "The disk usage(MB) of single replica of the table, if the table size is empty or less than 1MB, it would show 1MB "}, - {name: "TABLE_KEYS", tp: mysql.TypeLonglong, size: 64, comment: "The count of keys of single replica of the table"}, -} - -var tableTableTiFlashTablesCols = []columnInfo{ - {name: "DATABASE", tp: mysql.TypeVarchar, size: 64}, - {name: "TABLE", tp: mysql.TypeVarchar, size: 64}, - {name: "TIDB_DATABASE", tp: mysql.TypeVarchar, size: 64}, - {name: "TIDB_TABLE", tp: mysql.TypeVarchar, size: 64}, - {name: "TABLE_ID", tp: mysql.TypeLonglong, size: 21}, - {name: "IS_TOMBSTONE", tp: mysql.TypeLonglong, size: 64}, - {name: "SEGMENT_COUNT", tp: mysql.TypeLonglong, size: 64}, - {name: "TOTAL_ROWS", tp: mysql.TypeLonglong, size: 64}, - {name: "TOTAL_SIZE", tp: mysql.TypeLonglong, size: 64}, - {name: "TOTAL_DELETE_RANGES", tp: mysql.TypeLonglong, size: 64}, - {name: "DELTA_RATE_ROWS", tp: mysql.TypeDouble, size: 64}, - {name: "DELTA_RATE_SEGMENTS", tp: mysql.TypeDouble, size: 64}, - {name: "DELTA_PLACED_RATE", tp: mysql.TypeDouble, size: 64}, - {name: "DELTA_CACHE_SIZE", tp: mysql.TypeLonglong, size: 64}, - {name: "DELTA_CACHE_RATE", tp: mysql.TypeDouble, size: 64}, - {name: "DELTA_CACHE_WASTED_RATE", tp: mysql.TypeDouble, size: 64}, - {name: "DELTA_INDEX_SIZE", tp: 
mysql.TypeLonglong, size: 64}, - {name: "AVG_SEGMENT_ROWS", tp: mysql.TypeDouble, size: 64}, - {name: "AVG_SEGMENT_SIZE", tp: mysql.TypeDouble, size: 64}, - {name: "DELTA_COUNT", tp: mysql.TypeLonglong, size: 64}, - {name: "TOTAL_DELTA_ROWS", tp: mysql.TypeLonglong, size: 64}, - {name: "TOTAL_DELTA_SIZE", tp: mysql.TypeLonglong, size: 64}, - {name: "AVG_DELTA_ROWS", tp: mysql.TypeDouble, size: 64}, - {name: "AVG_DELTA_SIZE", tp: mysql.TypeDouble, size: 64}, - {name: "AVG_DELTA_DELETE_RANGES", tp: mysql.TypeDouble, size: 64}, - {name: "STABLE_COUNT", tp: mysql.TypeLonglong, size: 64}, - {name: "TOTAL_STABLE_ROWS", tp: mysql.TypeLonglong, size: 64}, - {name: "TOTAL_STABLE_SIZE", tp: mysql.TypeLonglong, size: 64}, - {name: "TOTAL_STABLE_SIZE_ON_DISK", tp: mysql.TypeLonglong, size: 64}, - {name: "AVG_STABLE_ROWS", tp: mysql.TypeDouble, size: 64}, - {name: "AVG_STABLE_SIZE", tp: mysql.TypeDouble, size: 64}, - {name: "TOTAL_PACK_COUNT_IN_DELTA", tp: mysql.TypeLonglong, size: 64}, - {name: "MAX_PACK_COUNT_IN_DELTA", tp: mysql.TypeLonglong, size: 64}, - {name: "AVG_PACK_COUNT_IN_DELTA", tp: mysql.TypeDouble, size: 64}, - {name: "AVG_PACK_ROWS_IN_DELTA", tp: mysql.TypeDouble, size: 64}, - {name: "AVG_PACK_SIZE_IN_DELTA", tp: mysql.TypeDouble, size: 64}, - {name: "TOTAL_PACK_COUNT_IN_STABLE", tp: mysql.TypeLonglong, size: 64}, - {name: "AVG_PACK_COUNT_IN_STABLE", tp: mysql.TypeDouble, size: 64}, - {name: "AVG_PACK_ROWS_IN_STABLE", tp: mysql.TypeDouble, size: 64}, - {name: "AVG_PACK_SIZE_IN_STABLE", tp: mysql.TypeDouble, size: 64}, - {name: "STORAGE_STABLE_NUM_SNAPSHOTS", tp: mysql.TypeLonglong, size: 64}, - {name: "STORAGE_STABLE_OLDEST_SNAPSHOT_LIFETIME", tp: mysql.TypeDouble, size: 64}, - {name: "STORAGE_STABLE_OLDEST_SNAPSHOT_THREAD_ID", tp: mysql.TypeLonglong, size: 64}, - {name: "STORAGE_STABLE_OLDEST_SNAPSHOT_TRACING_ID", tp: mysql.TypeVarchar, size: 128}, - {name: "STORAGE_DELTA_NUM_SNAPSHOTS", tp: mysql.TypeLonglong, size: 64}, - {name: "STORAGE_DELTA_OLDEST_SNAPSHOT_LIFETIME", tp: mysql.TypeDouble, size: 64}, - {name: "STORAGE_DELTA_OLDEST_SNAPSHOT_THREAD_ID", tp: mysql.TypeLonglong, size: 64}, - {name: "STORAGE_DELTA_OLDEST_SNAPSHOT_TRACING_ID", tp: mysql.TypeVarchar, size: 128}, - {name: "STORAGE_META_NUM_SNAPSHOTS", tp: mysql.TypeLonglong, size: 64}, - {name: "STORAGE_META_OLDEST_SNAPSHOT_LIFETIME", tp: mysql.TypeDouble, size: 64}, - {name: "STORAGE_META_OLDEST_SNAPSHOT_THREAD_ID", tp: mysql.TypeLonglong, size: 64}, - {name: "STORAGE_META_OLDEST_SNAPSHOT_TRACING_ID", tp: mysql.TypeVarchar, size: 128}, - {name: "BACKGROUND_TASKS_LENGTH", tp: mysql.TypeLonglong, size: 64}, - {name: "TIFLASH_INSTANCE", tp: mysql.TypeVarchar, size: 64}, -} - -var tableTableTiFlashSegmentsCols = []columnInfo{ - {name: "DATABASE", tp: mysql.TypeVarchar, size: 64}, - {name: "TABLE", tp: mysql.TypeVarchar, size: 64}, - {name: "TIDB_DATABASE", tp: mysql.TypeVarchar, size: 64}, - {name: "TIDB_TABLE", tp: mysql.TypeVarchar, size: 64}, - {name: "TABLE_ID", tp: mysql.TypeLonglong, size: 21}, - {name: "IS_TOMBSTONE", tp: mysql.TypeLonglong, size: 64}, - {name: "SEGMENT_ID", tp: mysql.TypeLonglong, size: 64}, - {name: "RANGE", tp: mysql.TypeVarchar, size: 64}, - {name: "EPOCH", tp: mysql.TypeLonglong, size: 64}, - {name: "ROWS", tp: mysql.TypeLonglong, size: 64}, - {name: "SIZE", tp: mysql.TypeLonglong, size: 64}, - {name: "DELTA_RATE", tp: mysql.TypeDouble, size: 64}, - {name: "DELTA_MEMTABLE_ROWS", tp: mysql.TypeLonglong, size: 64}, - {name: "DELTA_MEMTABLE_SIZE", tp: mysql.TypeLonglong, size: 64}, - {name: 
"DELTA_MEMTABLE_COLUMN_FILES", tp: mysql.TypeLonglong, size: 64}, - {name: "DELTA_MEMTABLE_DELETE_RANGES", tp: mysql.TypeLonglong, size: 64}, - {name: "DELTA_PERSISTED_PAGE_ID", tp: mysql.TypeLonglong, size: 64}, - {name: "DELTA_PERSISTED_ROWS", tp: mysql.TypeLonglong, size: 64}, - {name: "DELTA_PERSISTED_SIZE", tp: mysql.TypeLonglong, size: 64}, - {name: "DELTA_PERSISTED_COLUMN_FILES", tp: mysql.TypeLonglong, size: 64}, - {name: "DELTA_PERSISTED_DELETE_RANGES", tp: mysql.TypeLonglong, size: 64}, - {name: "DELTA_CACHE_SIZE", tp: mysql.TypeLonglong, size: 64}, - {name: "DELTA_INDEX_SIZE", tp: mysql.TypeLonglong, size: 64}, - {name: "STABLE_PAGE_ID", tp: mysql.TypeLonglong, size: 64}, - {name: "STABLE_ROWS", tp: mysql.TypeLonglong, size: 64}, - {name: "STABLE_SIZE", tp: mysql.TypeLonglong, size: 64}, - {name: "STABLE_DMFILES", tp: mysql.TypeLonglong, size: 64}, - {name: "STABLE_DMFILES_ID_0", tp: mysql.TypeLonglong, size: 64}, - {name: "STABLE_DMFILES_ROWS", tp: mysql.TypeLonglong, size: 64}, - {name: "STABLE_DMFILES_SIZE", tp: mysql.TypeLonglong, size: 64}, - {name: "STABLE_DMFILES_SIZE_ON_DISK", tp: mysql.TypeLonglong, size: 64}, - {name: "STABLE_DMFILES_PACKS", tp: mysql.TypeLonglong, size: 64}, - {name: "TIFLASH_INSTANCE", tp: mysql.TypeVarchar, size: 64}, -} - -var tableClientErrorsSummaryGlobalCols = []columnInfo{ - {name: "ERROR_NUMBER", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag}, - {name: "ERROR_MESSAGE", tp: mysql.TypeVarchar, size: 1024, flag: mysql.NotNullFlag}, - {name: "ERROR_COUNT", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag}, - {name: "WARNING_COUNT", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag}, - {name: "FIRST_SEEN", tp: mysql.TypeTimestamp, size: 26}, - {name: "LAST_SEEN", tp: mysql.TypeTimestamp, size: 26}, -} - -var tableClientErrorsSummaryByUserCols = []columnInfo{ - {name: "USER", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "ERROR_NUMBER", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag}, - {name: "ERROR_MESSAGE", tp: mysql.TypeVarchar, size: 1024, flag: mysql.NotNullFlag}, - {name: "ERROR_COUNT", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag}, - {name: "WARNING_COUNT", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag}, - {name: "FIRST_SEEN", tp: mysql.TypeTimestamp, size: 26}, - {name: "LAST_SEEN", tp: mysql.TypeTimestamp, size: 26}, -} - -var tableClientErrorsSummaryByHostCols = []columnInfo{ - {name: "HOST", tp: mysql.TypeVarchar, size: 255, flag: mysql.NotNullFlag}, - {name: "ERROR_NUMBER", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag}, - {name: "ERROR_MESSAGE", tp: mysql.TypeVarchar, size: 1024, flag: mysql.NotNullFlag}, - {name: "ERROR_COUNT", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag}, - {name: "WARNING_COUNT", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag}, - {name: "FIRST_SEEN", tp: mysql.TypeTimestamp, size: 26}, - {name: "LAST_SEEN", tp: mysql.TypeTimestamp, size: 26}, -} - -var tableTiDBTrxCols = []columnInfo{ - {name: txninfo.IDStr, tp: mysql.TypeLonglong, size: 21, flag: mysql.PriKeyFlag | mysql.NotNullFlag | mysql.UnsignedFlag}, - {name: txninfo.StartTimeStr, tp: mysql.TypeTimestamp, decimal: 6, size: 26, comment: "Start time of the transaction"}, - {name: txninfo.CurrentSQLDigestStr, tp: mysql.TypeVarchar, size: 64, comment: "Digest of the sql the transaction are currently running"}, - {name: txninfo.CurrentSQLDigestTextStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "The normalized sql the 
transaction are currently running"},
-	{name: txninfo.StateStr, tp: mysql.TypeEnum, size: 16, enumElems: txninfo.TxnRunningStateStrs, comment: "Current running state of the transaction"},
-	{name: txninfo.WaitingStartTimeStr, tp: mysql.TypeTimestamp, decimal: 6, size: 26, comment: "Current lock waiting's start time"},
-	{name: txninfo.MemBufferKeysStr, tp: mysql.TypeLonglong, size: 64, comment: "How many entries are in MemDB"},
-	{name: txninfo.MemBufferBytesStr, tp: mysql.TypeLonglong, size: 64, comment: "MemDB used memory"},
-	{name: txninfo.SessionIDStr, tp: mysql.TypeLonglong, size: 21, flag: mysql.UnsignedFlag, comment: "Which session this transaction belongs to"},
-	{name: txninfo.UserStr, tp: mysql.TypeVarchar, size: 16, comment: "The user who open this session"},
-	{name: txninfo.DBStr, tp: mysql.TypeVarchar, size: 64, comment: "The schema this transaction works on"},
-	{name: txninfo.AllSQLDigestsStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "A list of the digests of SQL statements that the transaction has executed"},
-	{name: txninfo.RelatedTableIDsStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "A list of the table IDs that the transaction has accessed"},
-	{name: txninfo.WaitingTimeStr, tp: mysql.TypeDouble, size: 22, comment: "Current lock waiting time"},
-}
-
-var tableDeadlocksCols = []columnInfo{
-	{name: deadlockhistory.ColDeadlockIDStr, tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag, comment: "The ID to distinguish different deadlock events"},
-	{name: deadlockhistory.ColOccurTimeStr, tp: mysql.TypeTimestamp, decimal: 6, size: 26, comment: "The physical time when the deadlock occurs"},
-	{name: deadlockhistory.ColRetryableStr, tp: mysql.TypeTiny, size: 1, flag: mysql.NotNullFlag, comment: "Whether the deadlock is retryable. Retryable deadlocks are usually not reported to the client"},
-	{name: deadlockhistory.ColTryLockTrxIDStr, tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "The transaction ID (start ts) of the transaction that's trying to acquire the lock"},
-	{name: deadlockhistory.ColCurrentSQLDigestStr, tp: mysql.TypeVarchar, size: 64, comment: "The digest of the SQL that's being blocked"},
-	{name: deadlockhistory.ColCurrentSQLDigestTextStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "The normalized SQL that's being blocked"},
-	{name: deadlockhistory.ColKeyStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "The key on which a transaction is waiting for another"},
-	{name: deadlockhistory.ColKeyInfoStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "Information of the key"},
-	{name: deadlockhistory.ColTrxHoldingLockStr, tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "The transaction ID (start ts) of the transaction that's currently holding the lock"},
-}
-
-var tableDataLockWaitsCols = []columnInfo{
-	{name: DataLockWaitsColumnKey, tp: mysql.TypeBlob, size: types.UnspecifiedLength, flag: mysql.NotNullFlag, comment: "The key that's being waiting on"},
-	{name: DataLockWaitsColumnKeyInfo, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "Information of the key"},
-	{name: DataLockWaitsColumnTrxID, tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "Current transaction that's waiting for the lock"},
-	{name: DataLockWaitsColumnCurrentHoldingTrxID, tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag | mysql.UnsignedFlag, comment: "The transaction that's holding the lock and blocks the current transaction"},
-	{name: DataLockWaitsColumnSQLDigest, tp: mysql.TypeVarchar, size: 64, comment: "Digest of the SQL that's trying to acquire the lock"},
-	{name: DataLockWaitsColumnSQLDigestText, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "Digest of the SQL that's trying to acquire the lock"},
-}
-
-var tableStatementsSummaryEvictedCols = []columnInfo{
-	{name: "BEGIN_TIME", tp: mysql.TypeTimestamp, size: 26},
-	{name: "END_TIME", tp: mysql.TypeTimestamp, size: 26},
-	{name: "EVICTED_COUNT", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag},
-}
-
-var tableAttributesCols = []columnInfo{
-	{name: "ID", tp: mysql.TypeVarchar, size: types.UnspecifiedLength, flag: mysql.NotNullFlag},
-	{name: "TYPE", tp: mysql.TypeVarchar, size: 16, flag: mysql.NotNullFlag},
-	{name: "ATTRIBUTES", tp: mysql.TypeVarchar, size: types.UnspecifiedLength},
-	{name: "RANGES", tp: mysql.TypeBlob, size: types.UnspecifiedLength},
-}
-
-var tableTrxSummaryCols = []columnInfo{
-	{name: "DIGEST", tp: mysql.TypeVarchar, size: 16, flag: mysql.NotNullFlag, comment: "Digest of a transaction"},
-	{name: txninfo.AllSQLDigestsStr, tp: mysql.TypeBlob, size: types.UnspecifiedLength, comment: "A list of the digests of SQL statements that the transaction has executed"},
-}
-
-var tablePlacementPoliciesCols = []columnInfo{
-	{name: "POLICY_ID", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag},
-	{name: "CATALOG_NAME", tp: mysql.TypeVarchar, size: 512, flag: mysql.NotNullFlag},
-	{name: "POLICY_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, // Catalog wide policy
-	{name: "PRIMARY_REGION", tp: mysql.TypeVarchar, size: 1024},
-	{name: "REGIONS", tp: mysql.TypeVarchar, size: 1024},
-	{name: "CONSTRAINTS", tp: mysql.TypeVarchar,
size: 1024}, - {name: "LEADER_CONSTRAINTS", tp: mysql.TypeVarchar, size: 1024}, - {name: "FOLLOWER_CONSTRAINTS", tp: mysql.TypeVarchar, size: 1024}, - {name: "LEARNER_CONSTRAINTS", tp: mysql.TypeVarchar, size: 1024}, - {name: "SCHEDULE", tp: mysql.TypeVarchar, size: 20}, // EVEN or MAJORITY_IN_PRIMARY - {name: "FOLLOWERS", tp: mysql.TypeLonglong, size: 64}, - {name: "LEARNERS", tp: mysql.TypeLonglong, size: 64}, -} - -var tableVariablesInfoCols = []columnInfo{ - {name: "VARIABLE_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "VARIABLE_SCOPE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "DEFAULT_VALUE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "CURRENT_VALUE", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "MIN_VALUE", tp: mysql.TypeLonglong, size: 64}, - {name: "MAX_VALUE", tp: mysql.TypeLonglong, size: 64, flag: mysql.UnsignedFlag}, - {name: "POSSIBLE_VALUES", tp: mysql.TypeVarchar, size: 256}, - {name: "IS_NOOP", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, -} - -var tableUserAttributesCols = []columnInfo{ - {name: "USER", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, - {name: "HOST", tp: mysql.TypeVarchar, size: 255, flag: mysql.NotNullFlag}, - {name: "ATTRIBUTE", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength}, -} - -var tableMemoryUsageCols = []columnInfo{ - {name: "MEMORY_TOTAL", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag}, - {name: "MEMORY_LIMIT", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag}, - {name: "MEMORY_CURRENT", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag}, - {name: "MEMORY_MAX_USED", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag}, - {name: "CURRENT_OPS", tp: mysql.TypeVarchar, size: 50}, - {name: "SESSION_KILL_LAST", tp: mysql.TypeDatetime}, - {name: "SESSION_KILL_TOTAL", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag}, - {name: "GC_LAST", tp: mysql.TypeDatetime}, - {name: "GC_TOTAL", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag}, - {name: "DISK_USAGE", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag}, - {name: "QUERY_FORCE_DISK", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag}, -} - -var tableMemoryUsageOpsHistoryCols = []columnInfo{ - {name: "TIME", tp: mysql.TypeDatetime, size: 64, flag: mysql.NotNullFlag}, - {name: "OPS", tp: mysql.TypeVarchar, size: 20, flag: mysql.NotNullFlag}, - {name: "MEMORY_LIMIT", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag}, - {name: "MEMORY_CURRENT", tp: mysql.TypeLonglong, size: 21, flag: mysql.NotNullFlag}, - {name: "PROCESSID", tp: mysql.TypeLonglong, size: 21, flag: mysql.UnsignedFlag}, - {name: "MEM", tp: mysql.TypeLonglong, size: 21, flag: mysql.UnsignedFlag}, - {name: "DISK", tp: mysql.TypeLonglong, size: 21, flag: mysql.UnsignedFlag}, - {name: "CLIENT", tp: mysql.TypeVarchar, size: 64}, - {name: "DB", tp: mysql.TypeVarchar, size: 64}, - {name: "USER", tp: mysql.TypeVarchar, size: 16}, - {name: "SQL_DIGEST", tp: mysql.TypeVarchar, size: 64}, - {name: "SQL_TEXT", tp: mysql.TypeVarchar, size: 256}, -} - -var tableResourceGroupsCols = []columnInfo{ - {name: "NAME", tp: mysql.TypeVarchar, size: resourcegroup.MaxGroupNameLength, flag: mysql.NotNullFlag}, - {name: "RU_PER_SEC", tp: mysql.TypeVarchar, size: 21}, - {name: "PRIORITY", tp: mysql.TypeVarchar, size: 6}, - {name: "BURSTABLE", tp: mysql.TypeVarchar, size: 3}, - {name: "QUERY_LIMIT", tp: mysql.TypeVarchar, size: 256}, - {name: "BACKGROUND", 
tp: mysql.TypeVarchar, size: 256}, -} - -var tableRunawayWatchListCols = []columnInfo{ - {name: "ID", tp: mysql.TypeLonglong, size: 64, flag: mysql.NotNullFlag}, - {name: "RESOURCE_GROUP_NAME", tp: mysql.TypeVarchar, size: resourcegroup.MaxGroupNameLength, flag: mysql.NotNullFlag}, - {name: "START_TIME", tp: mysql.TypeVarchar, size: 32, flag: mysql.NotNullFlag}, - {name: "END_TIME", tp: mysql.TypeVarchar, size: 32}, - {name: "WATCH", tp: mysql.TypeVarchar, size: 12, flag: mysql.NotNullFlag}, - {name: "WATCH_TEXT", tp: mysql.TypeBlob, size: types.UnspecifiedLength, flag: mysql.NotNullFlag}, - {name: "SOURCE", tp: mysql.TypeVarchar, size: 128, flag: mysql.NotNullFlag}, - {name: "ACTION", tp: mysql.TypeVarchar, size: 12, flag: mysql.NotNullFlag}, -} - -// information_schema.CHECK_CONSTRAINTS -var tableCheckConstraintsCols = []columnInfo{ - {name: "CONSTRAINT_CATALOG", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "CONSTRAINT_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "CONSTRAINT_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "CHECK_CLAUSE", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength, flag: mysql.NotNullFlag}, -} - -// information_schema.TIDB_CHECK_CONSTRAINTS -var tableTiDBCheckConstraintsCols = []columnInfo{ - {name: "CONSTRAINT_CATALOG", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "CONSTRAINT_SCHEMA", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "CONSTRAINT_NAME", tp: mysql.TypeVarchar, size: 64, flag: mysql.NotNullFlag}, - {name: "CHECK_CLAUSE", tp: mysql.TypeLongBlob, size: types.UnspecifiedLength, flag: mysql.NotNullFlag}, - {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "TABLE_ID", tp: mysql.TypeLonglong, size: 21}, -} - -var tableKeywords = []columnInfo{ - {name: "WORD", tp: mysql.TypeVarchar, size: 128}, - {name: "RESERVED", tp: mysql.TypeLong, size: 11}, -} - -var tableTiDBIndexUsage = []columnInfo{ - {name: "TABLE_SCHEMA", tp: mysql.TypeVarchar, size: 64}, - {name: "TABLE_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "INDEX_NAME", tp: mysql.TypeVarchar, size: 64}, - {name: "QUERY_TOTAL", tp: mysql.TypeLonglong, size: 21}, - {name: "KV_REQ_TOTAL", tp: mysql.TypeLonglong, size: 21}, - {name: "ROWS_ACCESS_TOTAL", tp: mysql.TypeLonglong, size: 21}, - {name: "PERCENTAGE_ACCESS_0", tp: mysql.TypeLonglong, size: 21}, - {name: "PERCENTAGE_ACCESS_0_1", tp: mysql.TypeLonglong, size: 21}, - {name: "PERCENTAGE_ACCESS_1_10", tp: mysql.TypeLonglong, size: 21}, - {name: "PERCENTAGE_ACCESS_10_20", tp: mysql.TypeLonglong, size: 21}, - {name: "PERCENTAGE_ACCESS_20_50", tp: mysql.TypeLonglong, size: 21}, - {name: "PERCENTAGE_ACCESS_50_100", tp: mysql.TypeLonglong, size: 21}, - {name: "PERCENTAGE_ACCESS_100", tp: mysql.TypeLonglong, size: 21}, - {name: "LAST_ACCESS_TIME", tp: mysql.TypeDatetime, size: 21}, -} - -// GetShardingInfo returns a nil or description string for the sharding information of given TableInfo. -// The returned description string may be: -// - "NOT_SHARDED": for tables that SHARD_ROW_ID_BITS is not specified. -// - "NOT_SHARDED(PK_IS_HANDLE)": for tables of which primary key is row id. -// - "PK_AUTO_RANDOM_BITS={bit_number}, RANGE BITS={bit_number}": for tables of which primary key is sharded row id. -// - "SHARD_BITS={bit_number}": for tables that with SHARD_ROW_ID_BITS. -// -// The returned nil indicates that sharding information is not suitable for the table(for example, when the table is a View). 
-// This function is exported for unit test.
-func GetShardingInfo(dbInfo model.CIStr, tableInfo *model.TableInfo) any {
-	if tableInfo == nil || tableInfo.IsView() || util.IsMemOrSysDB(dbInfo.L) {
-		return nil
-	}
-	shardingInfo := "NOT_SHARDED"
-	if tableInfo.ContainsAutoRandomBits() {
-		shardingInfo = "PK_AUTO_RANDOM_BITS=" + strconv.Itoa(int(tableInfo.AutoRandomBits))
-		rangeBits := tableInfo.AutoRandomRangeBits
-		if rangeBits != 0 && rangeBits != autoid.AutoRandomRangeBitsDefault {
-			shardingInfo = fmt.Sprintf("%s, RANGE BITS=%d", shardingInfo, rangeBits)
-		}
-	} else if tableInfo.ShardRowIDBits > 0 {
-		shardingInfo = "SHARD_BITS=" + strconv.Itoa(int(tableInfo.ShardRowIDBits))
-	} else if tableInfo.PKIsHandle {
-		shardingInfo = "NOT_SHARDED(PK_IS_HANDLE)"
-	}
-	return shardingInfo
-}
-
-const (
-	// PrimaryKeyType is the string constant of PRIMARY KEY.
-	PrimaryKeyType = "PRIMARY KEY"
-	// PrimaryConstraint is the string constant of PRIMARY.
-	PrimaryConstraint = "PRIMARY"
-	// UniqueKeyType is the string constant of UNIQUE.
-	UniqueKeyType = "UNIQUE"
-	// ForeignKeyType is the string constant of Foreign Key.
-	ForeignKeyType = "FOREIGN KEY"
-)
-
-const (
-	// TiFlashWrite is the TiFlash write node in disaggregated mode.
-	TiFlashWrite = "tiflash_write"
-)
-
-// ServerInfo represents the basic server information of single cluster component
-type ServerInfo struct {
-	ServerType     string
-	Address        string
-	StatusAddr     string
-	Version        string
-	GitHash        string
-	StartTimestamp int64
-	ServerID       uint64
-	EngineRole     string
-}
-
-func (s *ServerInfo) isLoopBackOrUnspecifiedAddr(addr string) bool {
-	tcpAddr, err := net.ResolveTCPAddr("", addr)
-	if err != nil {
-		return false
-	}
-	ip := net.ParseIP(tcpAddr.IP.String())
-	return ip != nil && (ip.IsUnspecified() || ip.IsLoopback())
-}
-
-// ResolveLoopBackAddr exports for testing.
-func (s *ServerInfo) ResolveLoopBackAddr() {
-	if s.isLoopBackOrUnspecifiedAddr(s.Address) && !s.isLoopBackOrUnspecifiedAddr(s.StatusAddr) {
-		addr, err1 := net.ResolveTCPAddr("", s.Address)
-		statusAddr, err2 := net.ResolveTCPAddr("", s.StatusAddr)
-		if err1 == nil && err2 == nil {
-			addr.IP = statusAddr.IP
-			s.Address = addr.String()
-		}
-	} else if !s.isLoopBackOrUnspecifiedAddr(s.Address) && s.isLoopBackOrUnspecifiedAddr(s.StatusAddr) {
-		addr, err1 := net.ResolveTCPAddr("", s.Address)
-		statusAddr, err2 := net.ResolveTCPAddr("", s.StatusAddr)
-		if err1 == nil && err2 == nil {
-			statusAddr.IP = addr.IP
-			s.StatusAddr = statusAddr.String()
-		}
-	}
-}
-
-// GetClusterServerInfo returns all components information of cluster
-func GetClusterServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) {
-	failpoint.Inject("mockClusterInfo", func(val failpoint.Value) {
-		// The cluster topology is injected by `failpoint` expression and
-		// there is no extra checks for it.
(let the test fail if the expression invalid) - if s := val.(string); len(s) > 0 { - var servers []ServerInfo - for _, server := range strings.Split(s, ";") { - parts := strings.Split(server, ",") - serverID, err := strconv.ParseUint(parts[5], 10, 64) - if err != nil { - panic("convert parts[5] to uint64 failed") - } - servers = append(servers, ServerInfo{ - ServerType: parts[0], - Address: parts[1], - StatusAddr: parts[2], - Version: parts[3], - GitHash: parts[4], - ServerID: serverID, - }) - } - failpoint.Return(servers, nil) - } - }) - - type retriever func(ctx sessionctx.Context) ([]ServerInfo, error) - retrievers := []retriever{GetTiDBServerInfo, GetPDServerInfo, func(ctx sessionctx.Context) ([]ServerInfo, error) { - return GetStoreServerInfo(ctx.GetStore()) - }, GetTiProxyServerInfo, GetTiCDCServerInfo, GetTSOServerInfo, GetSchedulingServerInfo} - //nolint: prealloc - var servers []ServerInfo - for _, r := range retrievers { - nodes, err := r(ctx) - if err != nil { - return nil, err - } - for i := range nodes { - nodes[i].ResolveLoopBackAddr() - } - servers = append(servers, nodes...) - } - return servers, nil -} - -// GetTiDBServerInfo returns all TiDB nodes information of cluster -func GetTiDBServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) { - // Get TiDB servers info. - tidbNodes, err := infosync.GetAllServerInfo(context.Background()) - if err != nil { - return nil, errors.Trace(err) - } - var isDefaultVersion bool - if len(config.GetGlobalConfig().ServerVersion) == 0 { - isDefaultVersion = true - } - var servers = make([]ServerInfo, 0, len(tidbNodes)) - for _, node := range tidbNodes { - servers = append(servers, ServerInfo{ - ServerType: "tidb", - Address: net.JoinHostPort(node.IP, strconv.Itoa(int(node.Port))), - StatusAddr: net.JoinHostPort(node.IP, strconv.Itoa(int(node.StatusPort))), - Version: FormatTiDBVersion(node.Version, isDefaultVersion), - GitHash: node.GitHash, - StartTimestamp: node.StartTimestamp, - ServerID: node.ServerIDGetter(), - }) - } - return servers, nil -} - -// FormatTiDBVersion make TiDBVersion consistent to TiKV and PD. -// The default TiDBVersion is 5.7.25-TiDB-${TiDBReleaseVersion}. -func FormatTiDBVersion(TiDBVersion string, isDefaultVersion bool) string { - var version, nodeVersion string - - // The user hasn't set the config 'ServerVersion'. - if isDefaultVersion { - nodeVersion = TiDBVersion[strings.Index(TiDBVersion, "TiDB-")+len("TiDB-"):] - if len(nodeVersion) > 0 && nodeVersion[0] == 'v' { - nodeVersion = nodeVersion[1:] - } - nodeVersions := strings.SplitN(nodeVersion, "-", 2) - if len(nodeVersions) == 1 { - version = nodeVersions[0] - } else if len(nodeVersions) >= 2 { - version = fmt.Sprintf("%s-%s", nodeVersions[0], nodeVersions[1]) - } - } else { // The user has already set the config 'ServerVersion',it would be a complex scene, so just use the 'ServerVersion' as version. - version = TiDBVersion - } - - return version -} - -// GetPDServerInfo returns all PD nodes information of cluster -func GetPDServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) { - // Get PD servers info. - members, err := getEtcdMembers(ctx) - if err != nil { - return nil, err - } - // TODO: maybe we should unify the PD API request interface. - var ( - memberNum = len(members) - servers = make([]ServerInfo, 0, memberNum) - errs = make([]error, 0, memberNum) - ) - if memberNum == 0 { - return servers, nil - } - // Try on each member until one succeeds or all fail. 
- for _, addr := range members { - // Get PD version, git_hash - url := fmt.Sprintf("%s://%s%s", util.InternalHTTPSchema(), addr, pd.Status) - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - ctx.GetSessionVars().StmtCtx.AppendWarning(err) - logutil.BgLogger().Warn("create pd server info request error", zap.String("url", url), zap.Error(err)) - errs = append(errs, err) - continue - } - req.Header.Add("PD-Allow-follower-handle", "true") - resp, err := util.InternalHTTPClient().Do(req) - if err != nil { - ctx.GetSessionVars().StmtCtx.AppendWarning(err) - logutil.BgLogger().Warn("request pd server info error", zap.String("url", url), zap.Error(err)) - errs = append(errs, err) - continue - } - var content = struct { - Version string `json:"version"` - GitHash string `json:"git_hash"` - StartTimestamp int64 `json:"start_timestamp"` - }{} - err = json.NewDecoder(resp.Body).Decode(&content) - terror.Log(resp.Body.Close()) - if err != nil { - ctx.GetSessionVars().StmtCtx.AppendWarning(err) - logutil.BgLogger().Warn("close pd server info request error", zap.String("url", url), zap.Error(err)) - errs = append(errs, err) - continue - } - if len(content.Version) > 0 && content.Version[0] == 'v' { - content.Version = content.Version[1:] - } - - servers = append(servers, ServerInfo{ - ServerType: "pd", - Address: addr, - StatusAddr: addr, - Version: content.Version, - GitHash: content.GitHash, - StartTimestamp: content.StartTimestamp, - }) - } - // Return the errors if all members' requests fail. - if len(errs) == memberNum { - errorMsg := "" - for idx, err := range errs { - errorMsg += err.Error() - if idx < memberNum-1 { - errorMsg += "; " - } - } - return nil, errors.Trace(fmt.Errorf("%s", errorMsg)) - } - return servers, nil -} - -// GetTSOServerInfo returns all TSO nodes information of cluster -func GetTSOServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) { - return getMicroServiceServerInfo(ctx, tsoServiceName) -} - -// GetSchedulingServerInfo returns all scheduling nodes information of cluster -func GetSchedulingServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) { - return getMicroServiceServerInfo(ctx, schedulingServiceName) -} - -func getMicroServiceServerInfo(ctx sessionctx.Context, serviceName string) ([]ServerInfo, error) { - members, err := getEtcdMembers(ctx) - if err != nil { - return nil, err - } - // TODO: maybe we should unify the PD API request interface. - var servers []ServerInfo - - if len(members) == 0 { - return servers, nil - } - // Try on each member until one succeeds or all fail. 
- for _, addr := range members { - // Get members - url := fmt.Sprintf("%s://%s%s/%s", util.InternalHTTPSchema(), addr, "/pd/api/v2/ms/members", serviceName) - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - ctx.GetSessionVars().StmtCtx.AppendWarning(err) - logutil.BgLogger().Warn("create microservice server info request error", zap.String("service", serviceName), zap.String("url", url), zap.Error(err)) - continue - } - req.Header.Add("PD-Allow-follower-handle", "true") - resp, err := util.InternalHTTPClient().Do(req) - if err != nil { - ctx.GetSessionVars().StmtCtx.AppendWarning(err) - logutil.BgLogger().Warn("request microservice server info error", zap.String("service", serviceName), zap.String("url", url), zap.Error(err)) - continue - } - if resp.StatusCode != http.StatusOK { - terror.Log(resp.Body.Close()) - continue - } - var content = []struct { - ServiceAddr string `json:"service-addr"` - Version string `json:"version"` - GitHash string `json:"git-hash"` - DeployPath string `json:"deploy-path"` - StartTimestamp int64 `json:"start-timestamp"` - }{} - err = json.NewDecoder(resp.Body).Decode(&content) - terror.Log(resp.Body.Close()) - if err != nil { - ctx.GetSessionVars().StmtCtx.AppendWarning(err) - logutil.BgLogger().Warn("close microservice server info request error", zap.String("service", serviceName), zap.String("url", url), zap.Error(err)) - continue - } - - for _, c := range content { - addr := strings.TrimPrefix(c.ServiceAddr, "http://") - addr = strings.TrimPrefix(addr, "https://") - if len(c.Version) > 0 && c.Version[0] == 'v' { - c.Version = c.Version[1:] - } - servers = append(servers, ServerInfo{ - ServerType: serviceName, - Address: addr, - StatusAddr: addr, - Version: c.Version, - GitHash: c.GitHash, - StartTimestamp: c.StartTimestamp, - }) - } - return servers, nil - } - return servers, nil -} - -func getEtcdMembers(ctx sessionctx.Context) ([]string, error) { - store := ctx.GetStore() - etcd, ok := store.(kv.EtcdBackend) - if !ok { - return nil, errors.Errorf("%T not an etcd backend", store) - } - members, err := etcd.EtcdAddrs() - if err != nil { - return nil, errors.Trace(err) - } - return members, nil -} - -func isTiFlashStore(store *metapb.Store) bool { - for _, label := range store.Labels { - if label.GetKey() == placement.EngineLabelKey && label.GetValue() == placement.EngineLabelTiFlash { - return true - } - } - return false -} - -func isTiFlashWriteNode(store *metapb.Store) bool { - for _, label := range store.Labels { - if label.GetKey() == placement.EngineRoleLabelKey && label.GetValue() == placement.EngineRoleLabelWrite { - return true - } - } - return false -} - -// GetStoreServerInfo returns all store nodes(TiKV or TiFlash) cluster information -func GetStoreServerInfo(store kv.Storage) ([]ServerInfo, error) { - failpoint.Inject("mockStoreServerInfo", func(val failpoint.Value) { - if s := val.(string); len(s) > 0 { - var servers []ServerInfo - for _, server := range strings.Split(s, ";") { - parts := strings.Split(server, ",") - servers = append(servers, ServerInfo{ - ServerType: parts[0], - Address: parts[1], - StatusAddr: parts[2], - Version: parts[3], - GitHash: parts[4], - StartTimestamp: 0, - }) - } - failpoint.Return(servers, nil) - } - }) - - // Get TiKV servers info. 
- tikvStore, ok := store.(tikv.Storage) - if !ok { - return nil, errors.Errorf("%T is not an TiKV or TiFlash store instance", store) - } - pdClient := tikvStore.GetRegionCache().PDClient() - if pdClient == nil { - return nil, errors.New("pd unavailable") - } - stores, err := pdClient.GetAllStores(context.Background()) - if err != nil { - return nil, errors.Trace(err) - } - servers := make([]ServerInfo, 0, len(stores)) - for _, store := range stores { - failpoint.Inject("mockStoreTombstone", func(val failpoint.Value) { - if val.(bool) { - store.State = metapb.StoreState_Tombstone - } - }) - - if store.GetState() == metapb.StoreState_Tombstone { - continue - } - var tp string - if isTiFlashStore(store) { - tp = kv.TiFlash.Name() - } else { - tp = tikv.GetStoreTypeByMeta(store).Name() - } - var engineRole string - if isTiFlashWriteNode(store) { - engineRole = placement.EngineRoleLabelWrite - } - servers = append(servers, ServerInfo{ - ServerType: tp, - Address: store.Address, - StatusAddr: store.StatusAddress, - Version: FormatStoreServerVersion(store.Version), - GitHash: store.GitHash, - StartTimestamp: store.StartTimestamp, - EngineRole: engineRole, - }) - } - return servers, nil -} - -// FormatStoreServerVersion format version of store servers(Tikv or TiFlash) -func FormatStoreServerVersion(version string) string { - if len(version) >= 1 && version[0] == 'v' { - version = version[1:] - } - return version -} - -// GetTiFlashStoreCount returns the count of tiflash server. -func GetTiFlashStoreCount(ctx sessionctx.Context) (cnt uint64, err error) { - failpoint.Inject("mockTiFlashStoreCount", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(uint64(10), nil) - } - }) - - stores, err := GetStoreServerInfo(ctx.GetStore()) - if err != nil { - return cnt, err - } - for _, store := range stores { - if store.ServerType == kv.TiFlash.Name() { - cnt++ - } - } - return cnt, nil -} - -// GetTiProxyServerInfo gets server info of TiProxy from PD. -func GetTiProxyServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) { - tiproxyNodes, err := infosync.GetTiProxyServerInfo(context.Background()) - if err != nil { - return nil, errors.Trace(err) - } - var servers = make([]ServerInfo, 0, len(tiproxyNodes)) - for _, node := range tiproxyNodes { - servers = append(servers, ServerInfo{ - ServerType: "tiproxy", - Address: net.JoinHostPort(node.IP, node.Port), - StatusAddr: net.JoinHostPort(node.IP, node.StatusPort), - Version: node.Version, - GitHash: node.GitHash, - StartTimestamp: node.StartTimestamp, - }) - } - return servers, nil -} - -// GetTiCDCServerInfo gets server info of TiCDC from PD. -func GetTiCDCServerInfo(ctx sessionctx.Context) ([]ServerInfo, error) { - ticdcNodes, err := infosync.GetTiCDCServerInfo(context.Background()) - if err != nil { - return nil, errors.Trace(err) - } - var servers = make([]ServerInfo, 0, len(ticdcNodes)) - for _, node := range ticdcNodes { - servers = append(servers, ServerInfo{ - ServerType: "ticdc", - Address: node.Address, - StatusAddr: node.Address, - Version: node.Version, - GitHash: node.GitHash, - StartTimestamp: node.StartTimestamp, - }) - } - return servers, nil -} - -// SysVarHiddenForSem checks if a given sysvar is hidden according to SEM and privileges. 
-func SysVarHiddenForSem(ctx sessionctx.Context, sysVarNameInLower string) bool {
-	if !sem.IsEnabled() || !sem.IsInvisibleSysVar(sysVarNameInLower) {
-		return false
-	}
-	checker := privilege.GetPrivilegeManager(ctx)
-	if checker == nil || checker.RequestDynamicVerification(ctx.GetSessionVars().ActiveRoles, "RESTRICTED_VARIABLES_ADMIN", false) {
-		return false
-	}
-	return true
-}
-
-// GetDataFromSessionVariables return the [name, value] of all session variables
-func GetDataFromSessionVariables(ctx context.Context, sctx sessionctx.Context) ([][]types.Datum, error) {
-	sessionVars := sctx.GetSessionVars()
-	sysVars := variable.GetSysVars()
-	rows := make([][]types.Datum, 0, len(sysVars))
-	for _, v := range sysVars {
-		if SysVarHiddenForSem(sctx, v.Name) {
-			continue
-		}
-		var value string
-		value, err := sessionVars.GetSessionOrGlobalSystemVar(ctx, v.Name)
-		if err != nil {
-			return nil, err
-		}
-		row := types.MakeDatums(v.Name, value)
-		rows = append(rows, row)
-	}
-	return rows, nil
-}
-
-// GetDataFromSessionConnectAttrs produces the rows for the session_connect_attrs table.
-func GetDataFromSessionConnectAttrs(sctx sessionctx.Context, sameAccount bool) ([][]types.Datum, error) {
-	sm := sctx.GetSessionManager()
-	if sm == nil {
-		return nil, nil
-	}
-	var user *auth.UserIdentity
-	if sameAccount {
-		user = sctx.GetSessionVars().User
-	}
-	allAttrs := sm.GetConAttrs(user)
-	rows := make([][]types.Datum, 0, len(allAttrs)*10) // 10 Attributes per connection
-	for pid, attrs := range allAttrs { // Note: PID is not ordered.
-		// Sorts the attributes by key and gives ORDINAL_POSITION based on this. This is needed as we didn't store the
-		// ORDINAL_POSITION and a map doesn't have a guaranteed sort order. This is needed to keep the ORDINAL_POSITION
-		// stable over multiple queries.
- attrnames := make([]string, 0, len(attrs)) - for attrname := range attrs { - attrnames = append(attrnames, attrname) - } - sort.Strings(attrnames) - - for ord, attrkey := range attrnames { - row := types.MakeDatums( - pid, - attrkey, - attrs[attrkey], - ord, - ) - rows = append(rows, row) - } - } - return rows, nil -} - -var tableNameToColumns = map[string][]columnInfo{ - TableSchemata: schemataCols, - TableTables: tablesCols, - TableColumns: columnsCols, - tableColumnStatistics: columnStatisticsCols, - TableStatistics: statisticsCols, - TableCharacterSets: charsetCols, - TableCollations: collationsCols, - tableFiles: filesCols, - TableProfiling: profilingCols, - TablePartitions: partitionsCols, - TableKeyColumn: keyColumnUsageCols, - TableReferConst: referConstCols, - TableSessionVar: sessionVarCols, - tablePlugins: pluginsCols, - TableConstraints: tableConstraintsCols, - tableTriggers: tableTriggersCols, - TableUserPrivileges: tableUserPrivilegesCols, - tableSchemaPrivileges: tableSchemaPrivilegesCols, - tableTablePrivileges: tableTablePrivilegesCols, - tableColumnPrivileges: tableColumnPrivilegesCols, - TableEngines: tableEnginesCols, - TableViews: tableViewsCols, - tableRoutines: tableRoutinesCols, - tableParameters: tableParametersCols, - tableEvents: tableEventsCols, - tableGlobalStatus: tableGlobalStatusCols, - tableGlobalVariables: tableGlobalVariablesCols, - tableSessionStatus: tableSessionStatusCols, - tableOptimizerTrace: tableOptimizerTraceCols, - tableTableSpaces: tableTableSpacesCols, - TableCollationCharacterSetApplicability: tableCollationCharacterSetApplicabilityCols, - TableProcesslist: tableProcesslistCols, - TableTiDBIndexes: tableTiDBIndexesCols, - TableSlowQuery: slowQueryCols, - TableTiDBHotRegions: TableTiDBHotRegionsCols, - TableTiDBHotRegionsHistory: TableTiDBHotRegionsHistoryCols, - TableTiKVStoreStatus: TableTiKVStoreStatusCols, - TableAnalyzeStatus: tableAnalyzeStatusCols, - TableTiKVRegionStatus: TableTiKVRegionStatusCols, - TableTiKVRegionPeers: TableTiKVRegionPeersCols, - TableTiDBServersInfo: tableTiDBServersInfoCols, - TableClusterInfo: tableClusterInfoCols, - TableClusterConfig: tableClusterConfigCols, - TableClusterLog: tableClusterLogCols, - TableClusterLoad: tableClusterLoadCols, - TableTiFlashReplica: tableTableTiFlashReplicaCols, - TableClusterHardware: tableClusterHardwareCols, - TableClusterSystemInfo: tableClusterSystemInfoCols, - TableInspectionResult: tableInspectionResultCols, - TableMetricSummary: tableMetricSummaryCols, - TableMetricSummaryByLabel: tableMetricSummaryByLabelCols, - TableMetricTables: tableMetricTablesCols, - TableInspectionSummary: tableInspectionSummaryCols, - TableInspectionRules: tableInspectionRulesCols, - TableDDLJobs: tableDDLJobsCols, - TableSequences: tableSequencesCols, - TableStatementsSummary: tableStatementsSummaryCols, - TableStatementsSummaryHistory: tableStatementsSummaryCols, - TableStatementsSummaryEvicted: tableStatementsSummaryEvictedCols, - TableStorageStats: tableStorageStatsCols, - TableTiFlashTables: tableTableTiFlashTablesCols, - TableTiFlashSegments: tableTableTiFlashSegmentsCols, - TableClientErrorsSummaryGlobal: tableClientErrorsSummaryGlobalCols, - TableClientErrorsSummaryByUser: tableClientErrorsSummaryByUserCols, - TableClientErrorsSummaryByHost: tableClientErrorsSummaryByHostCols, - TableTiDBTrx: tableTiDBTrxCols, - TableDeadlocks: tableDeadlocksCols, - TableDataLockWaits: tableDataLockWaitsCols, - TableAttributes: tableAttributesCols, - TablePlacementPolicies: tablePlacementPoliciesCols, - 
TableTrxSummary: tableTrxSummaryCols, - TableVariablesInfo: tableVariablesInfoCols, - TableUserAttributes: tableUserAttributesCols, - TableMemoryUsage: tableMemoryUsageCols, - TableMemoryUsageOpsHistory: tableMemoryUsageOpsHistoryCols, - TableResourceGroups: tableResourceGroupsCols, - TableRunawayWatches: tableRunawayWatchListCols, - TableCheckConstraints: tableCheckConstraintsCols, - TableTiDBCheckConstraints: tableTiDBCheckConstraintsCols, - TableKeywords: tableKeywords, - TableTiDBIndexUsage: tableTiDBIndexUsage, -} - -func createInfoSchemaTable(_ autoid.Allocators, _ func() (pools.Resource, error), meta *model.TableInfo) (table.Table, error) { - columns := make([]*table.Column, len(meta.Columns)) - for i, col := range meta.Columns { - columns[i] = table.ToColumn(col) - } - tp := table.VirtualTable - if isClusterTableByName(util.InformationSchemaName.O, meta.Name.O) { - tp = table.ClusterTable - } - return &infoschemaTable{meta: meta, cols: columns, tp: tp}, nil -} - -type infoschemaTable struct { - meta *model.TableInfo - cols []*table.Column - tp table.Type -} - -// IterRecords implements table.Table IterRecords interface. -func (*infoschemaTable) IterRecords(ctx context.Context, sctx sessionctx.Context, cols []*table.Column, fn table.RecordIterFunc) error { - return nil -} - -// Cols implements table.Table Cols interface. -func (it *infoschemaTable) Cols() []*table.Column { - return it.cols -} - -// VisibleCols implements table.Table VisibleCols interface. -func (it *infoschemaTable) VisibleCols() []*table.Column { - return it.cols -} - -// HiddenCols implements table.Table HiddenCols interface. -func (it *infoschemaTable) HiddenCols() []*table.Column { - return nil -} - -// WritableCols implements table.Table WritableCols interface. -func (it *infoschemaTable) WritableCols() []*table.Column { - return it.cols -} - -// DeletableCols implements table.Table WritableCols interface. -func (it *infoschemaTable) DeletableCols() []*table.Column { - return it.cols -} - -// FullHiddenColsAndVisibleCols implements table FullHiddenColsAndVisibleCols interface. -func (it *infoschemaTable) FullHiddenColsAndVisibleCols() []*table.Column { - return it.cols -} - -// Indices implements table.Table Indices interface. -func (it *infoschemaTable) Indices() []table.Index { - return nil -} - -// WritableConstraint implements table.Table WritableConstraint interface. -func (it *infoschemaTable) WritableConstraint() []*table.Constraint { - return nil -} - -// RecordPrefix implements table.Table RecordPrefix interface. -func (it *infoschemaTable) RecordPrefix() kv.Key { - return nil -} - -// IndexPrefix implements table.Table IndexPrefix interface. -func (it *infoschemaTable) IndexPrefix() kv.Key { - return nil -} - -// AddRecord implements table.Table AddRecord interface. -func (it *infoschemaTable) AddRecord(ctx table.MutateContext, r []types.Datum, opts ...table.AddRecordOption) (recordID kv.Handle, err error) { - return nil, table.ErrUnsupportedOp -} - -// RemoveRecord implements table.Table RemoveRecord interface. -func (it *infoschemaTable) RemoveRecord(ctx table.MutateContext, h kv.Handle, r []types.Datum) error { - return table.ErrUnsupportedOp -} - -// UpdateRecord implements table.Table UpdateRecord interface. -func (it *infoschemaTable) UpdateRecord(ctx table.MutateContext, h kv.Handle, oldData, newData []types.Datum, touched []bool, opts ...table.UpdateRecordOption) error { - return table.ErrUnsupportedOp -} - -// Allocators implements table.Table Allocators interface. 
-func (it *infoschemaTable) Allocators(_ table.AllocatorContext) autoid.Allocators { - return autoid.Allocators{} -} - -// Meta implements table.Table Meta interface. -func (it *infoschemaTable) Meta() *model.TableInfo { - return it.meta -} - -// GetPhysicalID implements table.Table GetPhysicalID interface. -func (it *infoschemaTable) GetPhysicalID() int64 { - return it.meta.ID -} - -// Type implements table.Table Type interface. -func (it *infoschemaTable) Type() table.Type { - return it.tp -} - -// GetPartitionedTable implements table.Table GetPartitionedTable interface. -func (it *infoschemaTable) GetPartitionedTable() table.PartitionedTable { - return nil -} - -// VirtualTable is a dummy table.Table implementation. -type VirtualTable struct{} - -// Cols implements table.Table Cols interface. -func (vt *VirtualTable) Cols() []*table.Column { - return nil -} - -// VisibleCols implements table.Table VisibleCols interface. -func (vt *VirtualTable) VisibleCols() []*table.Column { - return nil -} - -// HiddenCols implements table.Table HiddenCols interface. -func (vt *VirtualTable) HiddenCols() []*table.Column { - return nil -} - -// WritableCols implements table.Table WritableCols interface. -func (vt *VirtualTable) WritableCols() []*table.Column { - return nil -} - -// DeletableCols implements table.Table WritableCols interface. -func (vt *VirtualTable) DeletableCols() []*table.Column { - return nil -} - -// FullHiddenColsAndVisibleCols implements table FullHiddenColsAndVisibleCols interface. -func (vt *VirtualTable) FullHiddenColsAndVisibleCols() []*table.Column { - return nil -} - -// Indices implements table.Table Indices interface. -func (vt *VirtualTable) Indices() []table.Index { - return nil -} - -// WritableConstraint implements table.Table WritableConstraint interface. -func (vt *VirtualTable) WritableConstraint() []*table.Constraint { - return nil -} - -// RecordPrefix implements table.Table RecordPrefix interface. -func (vt *VirtualTable) RecordPrefix() kv.Key { - return nil -} - -// IndexPrefix implements table.Table IndexPrefix interface. -func (vt *VirtualTable) IndexPrefix() kv.Key { - return nil -} - -// AddRecord implements table.Table AddRecord interface. -func (vt *VirtualTable) AddRecord(ctx table.MutateContext, r []types.Datum, opts ...table.AddRecordOption) (recordID kv.Handle, err error) { - return nil, table.ErrUnsupportedOp -} - -// RemoveRecord implements table.Table RemoveRecord interface. -func (vt *VirtualTable) RemoveRecord(ctx table.MutateContext, h kv.Handle, r []types.Datum) error { - return table.ErrUnsupportedOp -} - -// UpdateRecord implements table.Table UpdateRecord interface. -func (vt *VirtualTable) UpdateRecord(ctx table.MutateContext, h kv.Handle, oldData, newData []types.Datum, touched []bool, opts ...table.UpdateRecordOption) error { - return table.ErrUnsupportedOp -} - -// Allocators implements table.Table Allocators interface. -func (vt *VirtualTable) Allocators(_ table.AllocatorContext) autoid.Allocators { - return autoid.Allocators{} -} - -// Meta implements table.Table Meta interface. -func (vt *VirtualTable) Meta() *model.TableInfo { - return nil -} - -// GetPhysicalID implements table.Table GetPhysicalID interface. -func (vt *VirtualTable) GetPhysicalID() int64 { - return 0 -} - -// Type implements table.Table Type interface. 
-func (vt *VirtualTable) Type() table.Type { - return table.VirtualTable -} - -// GetTiFlashServerInfo returns all TiFlash server infos -func GetTiFlashServerInfo(store kv.Storage) ([]ServerInfo, error) { - if config.GetGlobalConfig().DisaggregatedTiFlash { - return nil, table.ErrUnsupportedOp - } - serversInfo, err := GetStoreServerInfo(store) - if err != nil { - return nil, err - } - serversInfo = FilterClusterServerInfo(serversInfo, set.NewStringSet(kv.TiFlash.Name()), set.NewStringSet()) - return serversInfo, nil -} - -// FetchClusterServerInfoWithoutPrivilegeCheck fetches cluster server information -func FetchClusterServerInfoWithoutPrivilegeCheck(ctx context.Context, vars *variable.SessionVars, serversInfo []ServerInfo, serverInfoType diagnosticspb.ServerInfoType, recordWarningInStmtCtx bool) ([][]types.Datum, error) { - type result struct { - idx int - rows [][]types.Datum - err error - } - wg := sync.WaitGroup{} - ch := make(chan result, len(serversInfo)) - infoTp := serverInfoType - finalRows := make([][]types.Datum, 0, len(serversInfo)*10) - for i, srv := range serversInfo { - address := srv.Address - remote := address - if srv.ServerType == "tidb" || srv.ServerType == "tiproxy" { - remote = srv.StatusAddr - } - wg.Add(1) - go func(index int, remote, address, serverTP string) { - util.WithRecovery(func() { - defer wg.Done() - items, err := getServerInfoByGRPC(ctx, remote, infoTp) - if err != nil { - ch <- result{idx: index, err: err} - return - } - partRows := serverInfoItemToRows(items, serverTP, address) - ch <- result{idx: index, rows: partRows} - }, nil) - }(i, remote, address, srv.ServerType) - } - wg.Wait() - close(ch) - // Keep the original order to make the result more stable - var results []result //nolint: prealloc - for result := range ch { - if result.err != nil { - if recordWarningInStmtCtx { - vars.StmtCtx.AppendWarning(result.err) - } else { - log.Warn(result.err.Error()) - } - continue - } - results = append(results, result) - } - slices.SortFunc(results, func(i, j result) int { return cmp.Compare(i.idx, j.idx) }) - for _, result := range results { - finalRows = append(finalRows, result.rows...) 
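		// Editor's note (comment added for clarity, not in the original): results
		// was sorted by idx above, so this loop appends rows in the same order as
		// the serversInfo argument even though the per-server goroutines finish in
		// arbitrary order.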
- } - return finalRows, nil -} - -func serverInfoItemToRows(items []*diagnosticspb.ServerInfoItem, tp, addr string) [][]types.Datum { - rows := make([][]types.Datum, 0, len(items)) - for _, v := range items { - for _, item := range v.Pairs { - row := types.MakeDatums( - tp, - addr, - v.Tp, - v.Name, - item.Key, - item.Value, - ) - rows = append(rows, row) - } - } - return rows -} - -func getServerInfoByGRPC(ctx context.Context, address string, tp diagnosticspb.ServerInfoType) ([]*diagnosticspb.ServerInfoItem, error) { - opt := grpc.WithTransportCredentials(insecure.NewCredentials()) - security := config.GetGlobalConfig().Security - if len(security.ClusterSSLCA) != 0 { - clusterSecurity := security.ClusterSecurity() - tlsConfig, err := clusterSecurity.ToTLSConfig() - if err != nil { - return nil, errors.Trace(err) - } - opt = grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)) - } - conn, err := grpc.Dial(address, opt) - if err != nil { - return nil, err - } - defer func() { - err := conn.Close() - if err != nil { - log.Error("close grpc connection error", zap.Error(err)) - } - }() - - cli := diagnosticspb.NewDiagnosticsClient(conn) - ctx, cancel := context.WithTimeout(ctx, time.Second*10) - defer cancel() - r, err := cli.ServerInfo(ctx, &diagnosticspb.ServerInfoRequest{Tp: tp}) - if err != nil { - return nil, err - } - return r.Items, nil -} - -// FilterClusterServerInfo filters serversInfo by nodeTypes and addresses -func FilterClusterServerInfo(serversInfo []ServerInfo, nodeTypes, addresses set.StringSet) []ServerInfo { - if len(nodeTypes) == 0 && len(addresses) == 0 { - return serversInfo - } - - filterServers := make([]ServerInfo, 0, len(serversInfo)) - for _, srv := range serversInfo { - // Skip some node type which has been filtered in WHERE clause - // e.g: SELECT * FROM cluster_config WHERE type='tikv' - if len(nodeTypes) > 0 && !nodeTypes.Exist(srv.ServerType) { - continue - } - // Skip some node address which has been filtered in WHERE clause - // e.g: SELECT * FROM cluster_config WHERE address='192.16.8.12:2379' - if len(addresses) > 0 && !addresses.Exist(srv.Address) { - continue - } - filterServers = append(filterServers, srv) - } - return filterServers -} diff --git a/pkg/kv/binding__failpoint_binding__.go b/pkg/kv/binding__failpoint_binding__.go deleted file mode 100644 index 91ba6650c6d47..0000000000000 --- a/pkg/kv/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package kv - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/kv/txn.go b/pkg/kv/txn.go index 64acfa199de55..5e0d3399f9d1e 100644 --- a/pkg/kv/txn.go +++ b/pkg/kv/txn.go @@ -143,7 +143,7 @@ func RunInNewTxn(ctx context.Context, store Storage, retryable bool, f func(ctx return err } - if val, _err_ := failpoint.Eval(_curpkg_("mockCommitErrorInNewTxn")); _err_ == nil { + failpoint.Inject("mockCommitErrorInNewTxn", func(val failpoint.Value) { if v := val.(string); len(v) > 0 { switch v { case "retry_once": @@ -152,10 +152,10 @@ func RunInNewTxn(ctx context.Context, store Storage, retryable bool, f func(ctx err = ErrTxnRetryable } case "no_retry": - return errors.New("mock commit error") + failpoint.Return(errors.New("mock commit error")) } } - } + }) if err == nil { err = txn.Commit(ctx) @@ -223,7 
+223,7 @@ func setRequestSourceForInnerTxn(ctx context.Context, txn Transaction) { // SetTxnResourceGroup update the resource group name of target txn. func SetTxnResourceGroup(txn Transaction, name string) { txn.SetOption(ResourceGroupName, name) - if val, _err_ := failpoint.Eval(_curpkg_("TxnResourceGroupChecker")); _err_ == nil { + failpoint.Inject("TxnResourceGroupChecker", func(val failpoint.Value) { expectedRgName := val.(string) validateRNameInterceptor := func(next interceptor.RPCInterceptorFunc) interceptor.RPCInterceptorFunc { return func(target string, req *tikvrpc.Request) (*tikvrpc.Response, error) { @@ -243,5 +243,5 @@ func SetTxnResourceGroup(txn Transaction, name string) { } } txn.SetOption(RPCInterceptor, interceptor.NewRPCInterceptor("test-validate-rg-name", validateRNameInterceptor)) - } + }) } diff --git a/pkg/kv/txn.go__failpoint_stash__ b/pkg/kv/txn.go__failpoint_stash__ deleted file mode 100644 index 5e0d3399f9d1e..0000000000000 --- a/pkg/kv/txn.go__failpoint_stash__ +++ /dev/null @@ -1,247 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kv - -import ( - "context" - "errors" - "fmt" - "math" - "math/rand" - "sync" - "time" - - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/tikv/client-go/v2/oracle" - "github.com/tikv/client-go/v2/tikvrpc" - "github.com/tikv/client-go/v2/tikvrpc/interceptor" - "go.uber.org/zap" -) - -const ( - // TimeToPrintLongTimeInternalTxn is the duration if the internal transaction lasts more than it, - // TiDB prints a log message. - TimeToPrintLongTimeInternalTxn = time.Minute * 5 -) - -var globalInnerTxnTsBox = innerTxnStartTsBox{ - innerTSLock: sync.Mutex{}, - innerTxnStartTsMap: make(map[uint64]struct{}, 256), -} - -type innerTxnStartTsBox struct { - innerTSLock sync.Mutex - innerTxnStartTsMap map[uint64]struct{} -} - -func (ib *innerTxnStartTsBox) storeInnerTxnTS(startTS uint64) { - ib.innerTSLock.Lock() - ib.innerTxnStartTsMap[startTS] = struct{}{} - ib.innerTSLock.Unlock() -} - -func (ib *innerTxnStartTsBox) deleteInnerTxnTS(startTS uint64) { - ib.innerTSLock.Lock() - delete(ib.innerTxnStartTsMap, startTS) - ib.innerTSLock.Unlock() -} - -// GetMinInnerTxnStartTS get the min StartTS between startTSLowerLimit and curMinStartTS in globalInnerTxnTsBox. 
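// Editor's sketch (not part of this patch): how the "mockCommitErrorInNewTxn"
// marker restored in the pkg/kv/txn.go hunk above is typically armed. The full
// failpoint path is the package import path plus the marker name; the exact
// path and the return values below are assumptions, not taken from this series.
func enableMockCommitErrorSketch() (cleanup func(), err error) {
	const fp = "github.com/pingcap/tidb/pkg/kv/mockCommitErrorInNewTxn"
	// `return("no_retry")` makes RunInNewTxn fail with the mocked commit error,
	// while `return("retry_once")` marks only the first attempt as retryable.
	if err = failpoint.Enable(fp, `return("no_retry")`); err != nil {
		return nil, err
	}
	return func() { _ = failpoint.Disable(fp) }, nil
}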
-func GetMinInnerTxnStartTS(now time.Time, startTSLowerLimit uint64, - curMinStartTS uint64) uint64 { - return globalInnerTxnTsBox.getMinStartTS(now, startTSLowerLimit, curMinStartTS) -} - -func (ib *innerTxnStartTsBox) getMinStartTS(now time.Time, startTSLowerLimit uint64, - curMinStartTS uint64) uint64 { - minStartTS := curMinStartTS - ib.innerTSLock.Lock() - for innerTS := range ib.innerTxnStartTsMap { - PrintLongTimeInternalTxn(now, innerTS, true) - if innerTS > startTSLowerLimit && innerTS < minStartTS { - minStartTS = innerTS - } - } - ib.innerTSLock.Unlock() - return minStartTS -} - -// PrintLongTimeInternalTxn print the internal transaction information. -// runByFunction true means the transaction is run by `RunInNewTxn`, -// -// false means the transaction is run by internal session. -func PrintLongTimeInternalTxn(now time.Time, startTS uint64, runByFunction bool) { - if startTS > 0 { - innerTxnStartTime := oracle.GetTimeFromTS(startTS) - if now.Sub(innerTxnStartTime) > TimeToPrintLongTimeInternalTxn { - callerName := "internal session" - if runByFunction { - callerName = "RunInNewTxn" - } - infoHeader := fmt.Sprintf("An internal transaction running by %s lasts long time", callerName) - - logutil.BgLogger().Info(infoHeader, - zap.Duration("time", now.Sub(innerTxnStartTime)), zap.Uint64("startTS", startTS), - zap.Time("start time", innerTxnStartTime)) - } - } -} - -// RunInNewTxn will run the f in a new transaction environment, should be used by inner txn only. -func RunInNewTxn(ctx context.Context, store Storage, retryable bool, f func(ctx context.Context, txn Transaction) error) error { - var ( - err error - originalTxnTS uint64 - txn Transaction - ) - - defer func() { - globalInnerTxnTsBox.deleteInnerTxnTS(originalTxnTS) - }() - - for i := uint(0); i < MaxRetryCnt; i++ { - txn, err = store.Begin() - if err != nil { - logutil.BgLogger().Error("RunInNewTxn", zap.Error(err)) - return err - } - setRequestSourceForInnerTxn(ctx, txn) - - // originalTxnTS is used to trace the original transaction when the function is retryable. - if i == 0 { - originalTxnTS = txn.StartTS() - globalInnerTxnTsBox.storeInnerTxnTS(originalTxnTS) - } - - err = f(ctx, txn) - if err != nil { - err1 := txn.Rollback() - terror.Log(err1) - if retryable && IsTxnRetryableError(err) { - logutil.BgLogger().Warn("RunInNewTxn", - zap.Uint64("retry txn", txn.StartTS()), - zap.Uint64("original txn", originalTxnTS), - zap.Error(err)) - continue - } - return err - } - - failpoint.Inject("mockCommitErrorInNewTxn", func(val failpoint.Value) { - if v := val.(string); len(v) > 0 { - switch v { - case "retry_once": - //nolint:noloopclosure - if i == 0 { - err = ErrTxnRetryable - } - case "no_retry": - failpoint.Return(errors.New("mock commit error")) - } - } - }) - - if err == nil { - err = txn.Commit(ctx) - if err == nil { - break - } - } - if retryable && IsTxnRetryableError(err) { - logutil.BgLogger().Warn("RunInNewTxn", - zap.Uint64("retry txn", txn.StartTS()), - zap.Uint64("original txn", originalTxnTS), - zap.Error(err)) - BackOff(i) - continue - } - return err - } - return err -} - -var ( - // MaxRetryCnt represents maximum retry times. - MaxRetryCnt uint = 100 - // retryBackOffBase is the initial duration, in microsecond, a failed transaction stays dormancy before it retries - retryBackOffBase = 1 - // retryBackOffCap is the max amount of duration, in microsecond, a failed transaction stays dormancy before it retries - retryBackOffCap = 100 -) - -// BackOff Implements exponential backoff with full jitter. 
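// Editor's note (added example, not in the original): with retryBackOffBase=1
// and retryBackOffCap=100, attempt 5 gives upper = min(100, 1*2^5) = 32, so the
// function sleeps rand.Intn(32) milliseconds, a uniformly random delay in
// [0ms, 32ms), which is the "full jitter" part of the strategy.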
-// Returns real back off time in microsecond. -// See http://www.awsarchitectureblog.com/2015/03/backoff.html. -func BackOff(attempts uint) int { - upper := int(math.Min(float64(retryBackOffCap), float64(retryBackOffBase)*math.Pow(2.0, float64(attempts)))) - sleep := time.Duration(rand.Intn(upper)) * time.Millisecond // #nosec G404 - time.Sleep(sleep) - return int(sleep) -} - -func setRequestSourceForInnerTxn(ctx context.Context, txn Transaction) { - if source := ctx.Value(RequestSourceKey); source != nil { - requestSource := source.(RequestSource) - if requestSource.RequestSourceType != "" { - if !requestSource.RequestSourceInternal { - logutil.Logger(ctx).Warn("`RunInNewTxn` should be used by inner txn only") - } - txn.SetOption(RequestSourceInternal, requestSource.RequestSourceInternal) - txn.SetOption(RequestSourceType, requestSource.RequestSourceType) - if requestSource.ExplicitRequestSourceType != "" { - txn.SetOption(ExplicitRequestSourceType, requestSource.ExplicitRequestSourceType) - } - return - } - } - // panic in test mode in case there are requests without source in the future. - // log warnings in production mode. - if intest.InTest { - panic("unexpected no source type context, if you see this error, " + - "the `RequestSourceTypeKey` is missing in your context") - } - logutil.Logger(ctx).Warn("unexpected no source type context, if you see this warning, " + - "the `RequestSourceTypeKey` is missing in the context") -} - -// SetTxnResourceGroup update the resource group name of target txn. -func SetTxnResourceGroup(txn Transaction, name string) { - txn.SetOption(ResourceGroupName, name) - failpoint.Inject("TxnResourceGroupChecker", func(val failpoint.Value) { - expectedRgName := val.(string) - validateRNameInterceptor := func(next interceptor.RPCInterceptorFunc) interceptor.RPCInterceptorFunc { - return func(target string, req *tikvrpc.Request) (*tikvrpc.Response, error) { - var rgName *string - switch r := req.Req.(type) { - case *kvrpcpb.PrewriteRequest: - rgName = &r.Context.ResourceControlContext.ResourceGroupName - case *kvrpcpb.CommitRequest: - rgName = &r.Context.ResourceControlContext.ResourceGroupName - case *kvrpcpb.PessimisticLockRequest: - rgName = &r.Context.ResourceControlContext.ResourceGroupName - } - if rgName != nil && *rgName != expectedRgName { - panic(fmt.Sprintf("resource group name not match, expected: %s, actual: %s", expectedRgName, *rgName)) - } - return next(target, req) - } - } - txn.SetOption(RPCInterceptor, interceptor.NewRPCInterceptor("test-validate-rg-name", validateRNameInterceptor)) - }) -} diff --git a/pkg/lightning/backend/backend.go b/pkg/lightning/backend/backend.go index ddc56c94665c1..878b556c7e460 100644 --- a/pkg/lightning/backend/backend.go +++ b/pkg/lightning/backend/backend.go @@ -267,7 +267,7 @@ func (be EngineManager) OpenEngine( logger.Info("open engine") - if val, _err_ := failpoint.Eval(_curpkg_("FailIfEngineCountExceeds")); _err_ == nil { + failpoint.Inject("FailIfEngineCountExceeds", func(val failpoint.Value) { if m, ok := metric.FromContext(ctx); ok { closedCounter := m.ImporterEngineCounter.WithLabelValues("closed") openCounter := m.ImporterEngineCounter.WithLabelValues("open") @@ -280,7 +280,7 @@ func (be EngineManager) OpenEngine( openCount, closedCount, injectValue)) } } - } + }) return &OpenedEngine{ engine: engine{ diff --git a/pkg/lightning/backend/backend.go__failpoint_stash__ b/pkg/lightning/backend/backend.go__failpoint_stash__ deleted file mode 100644 index 878b556c7e460..0000000000000 --- 
a/pkg/lightning/backend/backend.go__failpoint_stash__ +++ /dev/null @@ -1,439 +0,0 @@ -// Copyright 2019 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package backend - -import ( - "context" - "fmt" - "time" - - "github.com/google/uuid" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/lightning/backend/encode" - "github.com/pingcap/tidb/pkg/lightning/checkpoints" - "github.com/pingcap/tidb/pkg/lightning/common" - "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/lightning/metric" - "github.com/pingcap/tidb/pkg/lightning/mydump" - "github.com/pingcap/tidb/pkg/parser/model" - "go.uber.org/zap" -) - -const ( - importMaxRetryTimes = 3 // tikv-importer has done retry internally. so we don't retry many times. -) - -func makeTag(tableName string, engineID int64) string { - return fmt.Sprintf("%s:%d", tableName, engineID) -} - -func makeLogger(logger log.Logger, tag string, engineUUID uuid.UUID) log.Logger { - return logger.With( - zap.String("engineTag", tag), - zap.Stringer("engineUUID", engineUUID), - ) -} - -// MakeUUID generates a UUID for the engine and a tag for the engine. -func MakeUUID(tableName string, engineID int64) (string, uuid.UUID) { - tag := makeTag(tableName, engineID) - engineUUID := uuid.NewSHA1(engineNamespace, []byte(tag)) - return tag, engineUUID -} - -var engineNamespace = uuid.MustParse("d68d6abe-c59e-45d6-ade8-e2b0ceb7bedf") - -// EngineFileSize represents the size of an engine on disk and in memory. -type EngineFileSize struct { - // UUID is the engine's UUID. - UUID uuid.UUID - // DiskSize is the estimated total file size on disk right now. - DiskSize int64 - // MemSize is the total memory size used by the engine. This is the - // estimated additional size saved onto disk after calling Flush(). - MemSize int64 - // IsImporting indicates whether the engine performing Import(). - IsImporting bool -} - -// LocalWriterConfig defines the configuration to open a LocalWriter -type LocalWriterConfig struct { - // Local backend specified configuration - Local struct { - // is the chunk KV written to this LocalWriter sent in order - IsKVSorted bool - // MemCacheSize specifies the estimated memory cache limit used by this local - // writer. It has higher priority than BackendConfig.LocalWriterMemCacheSize if - // set. - MemCacheSize int64 - } - // TiDB backend specified configuration - TiDB struct { - TableName string - } -} - -// EngineConfig defines configuration used for open engine -type EngineConfig struct { - // TableInfo is the corresponding tidb table info - TableInfo *checkpoints.TidbTableInfo - // local backend specified configuration - Local LocalEngineConfig - // local backend external engine specified configuration - External *ExternalEngineConfig - // KeepSortDir indicates whether to keep the temporary sort directory - // when opening the engine, instead of removing it. - KeepSortDir bool - // TS is the preset timestamp of data in the engine. 
When it's 0, the used TS - // will be set lazily. - TS uint64 -} - -// LocalEngineConfig is the configuration used for local backend in OpenEngine. -type LocalEngineConfig struct { - // compact small SSTs before ingest into pebble - Compact bool - // raw kvs size threshold to trigger compact - CompactThreshold int64 - // compact routine concurrency - CompactConcurrency int - - // blocksize - BlockSize int -} - -// ExternalEngineConfig is the configuration used for local backend external engine. -type ExternalEngineConfig struct { - StorageURI string - DataFiles []string - StatFiles []string - StartKey []byte - EndKey []byte - SplitKeys [][]byte - RegionSplitSize int64 - // TotalFileSize can be an estimated value. - TotalFileSize int64 - // TotalKVCount can be an estimated value. - TotalKVCount int64 - CheckHotspot bool -} - -// CheckCtx contains all parameters used in CheckRequirements -type CheckCtx struct { - DBMetas []*mydump.MDDatabaseMeta -} - -// TargetInfoGetter defines the interfaces to get target information. -type TargetInfoGetter interface { - // FetchRemoteDBModels obtains the models of all databases. Currently, only - // the database name is filled. - FetchRemoteDBModels(ctx context.Context) ([]*model.DBInfo, error) - - // FetchRemoteTableModels obtains the models of all tables given the schema - // name. The returned table info does not need to be precise if the encoder, - // is not requiring them, but must at least fill in the following fields for - // TablesFromMeta to succeed: - // - Name - // - State (must be model.StatePublic) - // - ID - // - Columns - // * Name - // * State (must be model.StatePublic) - // * Offset (must be 0, 1, 2, ...) - // - PKIsHandle (true = do not generate _tidb_rowid) - FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) - - // CheckRequirements performs the check whether the backend satisfies the version requirements - CheckRequirements(ctx context.Context, checkCtx *CheckCtx) error -} - -// Backend defines the interface for a backend. -// Implementations of this interface must be goroutine safe: you can share an -// instance and execute any method anywhere. -// Usual workflow: -// 1. Create a `Backend` for the whole process. -// 2. For each table, -// i. Split into multiple "batches" consisting of data files with roughly equal total size. -// ii. For each batch, -// a. Create an `OpenedEngine` via `backend.OpenEngine()` -// b. For each chunk, deliver data into the engine via `engine.WriteRows()` -// c. When all chunks are written, obtain a `ClosedEngine` via `engine.Close()` -// d. Import data via `engine.Import()` -// e. Cleanup via `engine.Cleanup()` -// 3. Close the connection via `backend.Close()` -type Backend interface { - // Close the connection to the backend. - Close() - - // RetryImportDelay returns the duration to sleep when retrying an import - RetryImportDelay() time.Duration - - // ShouldPostProcess returns whether KV-specific post-processing should be - // performed for this backend. Post-processing includes checksum and analyze. - ShouldPostProcess() bool - - OpenEngine(ctx context.Context, config *EngineConfig, engineUUID uuid.UUID) error - - CloseEngine(ctx context.Context, config *EngineConfig, engineUUID uuid.UUID) error - - // ImportEngine imports engine data to the backend. If it returns ErrDuplicateDetected, - // it means there is duplicate detected. For this situation, all data in the engine must be imported. - // It's safe to reset or cleanup this engine. 
- ImportEngine(ctx context.Context, engineUUID uuid.UUID, regionSplitSize, regionSplitKeys int64) error - - CleanupEngine(ctx context.Context, engineUUID uuid.UUID) error - - // FlushEngine ensures all KV pairs written to an open engine has been - // synchronized, such that kill-9'ing Lightning afterwards and resuming from - // checkpoint can recover the exact same content. - // - // This method is only relevant for local backend, and is no-op for all - // other backends. - FlushEngine(ctx context.Context, engineUUID uuid.UUID) error - - // FlushAllEngines performs FlushEngine on all opened engines. This is a - // very expensive operation and should only be used in some rare situation - // (e.g. preparing to resolve a disk quota violation). - FlushAllEngines(ctx context.Context) error - - // ResetEngine clears all written KV pairs in this opened engine. - ResetEngine(ctx context.Context, engineUUID uuid.UUID) error - - // LocalWriter obtains a thread-local EngineWriter for writing rows into the given engine. - LocalWriter(ctx context.Context, cfg *LocalWriterConfig, engineUUID uuid.UUID) (EngineWriter, error) -} - -// EngineManager is the manager of engines. -// this is a wrapper of Backend, which provides some common methods for managing engines. -// and it has no states, can be created on demand -type EngineManager struct { - backend Backend -} - -type engine struct { - backend Backend - logger log.Logger - uuid uuid.UUID - // id of the engine, used to generate uuid and stored in checkpoint - // for index engine it's -1 - id int32 -} - -// OpenedEngine is an opened engine, allowing data to be written via WriteRows. -// This type is goroutine safe: you can share an instance and execute any method -// anywhere. -type OpenedEngine struct { - engine - tableName string - config *EngineConfig -} - -// MakeEngineManager creates a new Backend from an Backend. -func MakeEngineManager(ab Backend) EngineManager { - return EngineManager{backend: ab} -} - -// OpenEngine opens an engine with the given table name and engine ID. -func (be EngineManager) OpenEngine( - ctx context.Context, - config *EngineConfig, - tableName string, - engineID int32, -) (*OpenedEngine, error) { - tag, engineUUID := MakeUUID(tableName, int64(engineID)) - logger := makeLogger(log.FromContext(ctx), tag, engineUUID) - - if err := be.backend.OpenEngine(ctx, config, engineUUID); err != nil { - return nil, err - } - - if m, ok := metric.FromContext(ctx); ok { - openCounter := m.ImporterEngineCounter.WithLabelValues("open") - openCounter.Inc() - } - - logger.Info("open engine") - - failpoint.Inject("FailIfEngineCountExceeds", func(val failpoint.Value) { - if m, ok := metric.FromContext(ctx); ok { - closedCounter := m.ImporterEngineCounter.WithLabelValues("closed") - openCounter := m.ImporterEngineCounter.WithLabelValues("open") - openCount := metric.ReadCounter(openCounter) - - closedCount := metric.ReadCounter(closedCounter) - if injectValue := val.(int); openCount-closedCount > float64(injectValue) { - panic(fmt.Sprintf( - "forcing failure due to FailIfEngineCountExceeds: %v - %v >= %d", - openCount, closedCount, injectValue)) - } - } - }) - - return &OpenedEngine{ - engine: engine{ - backend: be.backend, - logger: logger, - uuid: engineUUID, - id: engineID, - }, - tableName: tableName, - config: config, - }, nil -} - -// Close the opened engine to prepare it for importing. 
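// Editor's sketch (not part of this patch): the "usual workflow" documented on
// the Backend interface above, written out against this file's API. All input
// names (b, cfg, columnNames, rows, region split sizes) are placeholders.
func exampleEngineWorkflow(ctx context.Context, b Backend, cfg *EngineConfig,
	tableName string, columnNames []string, rows encode.Rows,
	regionSplitSize, regionSplitKeys int64) error {
	mgr := MakeEngineManager(b)
	opened, err := mgr.OpenEngine(ctx, cfg, tableName, 0) // step 2.ii.a
	if err != nil {
		return err
	}
	w, err := opened.LocalWriter(ctx, &LocalWriterConfig{}) // step 2.ii.b: one writer per chunk
	if err != nil {
		return err
	}
	if err := w.AppendRows(ctx, columnNames, rows); err != nil {
		return err
	}
	if _, err := w.Close(ctx); err != nil {
		return err
	}
	closed, err := opened.Close(ctx) // step 2.ii.c
	if err != nil {
		return err
	}
	if err := closed.Import(ctx, regionSplitSize, regionSplitKeys); err != nil { // step 2.ii.d
		return err
	}
	return closed.Cleanup(ctx) // step 2.ii.e
}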
-func (engine *OpenedEngine) Close(ctx context.Context) (*ClosedEngine, error) { - closedEngine, err := engine.unsafeClose(ctx, engine.config) - if err == nil { - if m, ok := metric.FromContext(ctx); ok { - m.ImporterEngineCounter.WithLabelValues("closed").Inc() - } - } - return closedEngine, err -} - -// Flush current written data for local backend -func (engine *OpenedEngine) Flush(ctx context.Context) error { - return engine.backend.FlushEngine(ctx, engine.uuid) -} - -// LocalWriter returns a writer that writes to the local backend. -func (engine *OpenedEngine) LocalWriter(ctx context.Context, cfg *LocalWriterConfig) (EngineWriter, error) { - return engine.backend.LocalWriter(ctx, cfg, engine.uuid) -} - -// SetTS sets the TS of the engine. In most cases if the caller wants to specify -// TS it should use the TS field in EngineConfig. This method is only used after -// a ResetEngine. -func (engine *OpenedEngine) SetTS(ts uint64) { - engine.config.TS = ts -} - -// UnsafeCloseEngine closes the engine without first opening it. -// This method is "unsafe" as it does not follow the normal operation sequence -// (Open -> Write -> Close -> Import). This method should only be used when one -// knows via other ways that the engine has already been opened, e.g. when -// resuming from a checkpoint. -func (be EngineManager) UnsafeCloseEngine(ctx context.Context, cfg *EngineConfig, - tableName string, engineID int32) (*ClosedEngine, error) { - tag, engineUUID := MakeUUID(tableName, int64(engineID)) - return be.UnsafeCloseEngineWithUUID(ctx, cfg, tag, engineUUID, engineID) -} - -// UnsafeCloseEngineWithUUID closes the engine without first opening it. -// This method is "unsafe" as it does not follow the normal operation sequence -// (Open -> Write -> Close -> Import). This method should only be used when one -// knows via other ways that the engine has already been opened, e.g. when -// resuming from a checkpoint. -func (be EngineManager) UnsafeCloseEngineWithUUID(ctx context.Context, cfg *EngineConfig, tag string, - engineUUID uuid.UUID, id int32) (*ClosedEngine, error) { - return engine{ - backend: be.backend, - logger: makeLogger(log.FromContext(ctx), tag, engineUUID), - uuid: engineUUID, - id: id, - }.unsafeClose(ctx, cfg) -} - -func (en engine) unsafeClose(ctx context.Context, cfg *EngineConfig) (*ClosedEngine, error) { - task := en.logger.Begin(zap.InfoLevel, "engine close") - err := en.backend.CloseEngine(ctx, cfg, en.uuid) - task.End(zap.ErrorLevel, err) - if err != nil { - return nil, err - } - return &ClosedEngine{engine: en}, nil -} - -// GetID get engine id. -func (en engine) GetID() int32 { - return en.id -} - -func (en engine) GetUUID() uuid.UUID { - return en.uuid -} - -// ClosedEngine represents a closed engine, allowing ingestion into the target. -// This type is goroutine safe: you can share an instance and execute any method -// anywhere. -type ClosedEngine struct { - engine -} - -// NewClosedEngine creates a new ClosedEngine. -func NewClosedEngine(backend Backend, logger log.Logger, uuid uuid.UUID, id int32) *ClosedEngine { - return &ClosedEngine{ - engine: engine{ - backend: backend, - logger: logger, - uuid: uuid, - id: id, - }, - } -} - -// Import the data written to the engine into the target. 
-func (engine *ClosedEngine) Import(ctx context.Context, regionSplitSize, regionSplitKeys int64) error { - var err error - - for i := 0; i < importMaxRetryTimes; i++ { - task := engine.logger.With(zap.Int("retryCnt", i)).Begin(zap.InfoLevel, "import") - err = engine.backend.ImportEngine(ctx, engine.uuid, regionSplitSize, regionSplitKeys) - if !common.IsRetryableError(err) { - if common.ErrFoundDuplicateKeys.Equal(err) { - task.End(zap.WarnLevel, err) - } else { - task.End(zap.ErrorLevel, err) - } - return err - } - task.Warn("import spuriously failed, going to retry again", log.ShortError(err)) - time.Sleep(engine.backend.RetryImportDelay()) - } - - return errors.Annotatef(err, "[%s] import reach max retry %d and still failed", engine.uuid, importMaxRetryTimes) -} - -// Cleanup deletes the intermediate data from target. -func (engine *ClosedEngine) Cleanup(ctx context.Context) error { - task := engine.logger.Begin(zap.InfoLevel, "cleanup") - err := engine.backend.CleanupEngine(ctx, engine.uuid) - task.End(zap.WarnLevel, err) - return err -} - -// Logger returns the logger for the engine. -func (engine *ClosedEngine) Logger() log.Logger { - return engine.logger -} - -// ChunkFlushStatus is the status of a chunk flush. -type ChunkFlushStatus interface { - Flushed() bool -} - -// EngineWriter is the interface for writing data to an engine. -type EngineWriter interface { - AppendRows(ctx context.Context, columnNames []string, rows encode.Rows) error - IsSynced() bool - Close(ctx context.Context) (ChunkFlushStatus, error) -} - -// GetEngineUUID returns the engine UUID. -func (engine *OpenedEngine) GetEngineUUID() uuid.UUID { - return engine.uuid -} diff --git a/pkg/lightning/backend/binding__failpoint_binding__.go b/pkg/lightning/backend/binding__failpoint_binding__.go deleted file mode 100644 index 6a726164e38ec..0000000000000 --- a/pkg/lightning/backend/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package backend - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/lightning/backend/external/binding__failpoint_binding__.go b/pkg/lightning/backend/external/binding__failpoint_binding__.go deleted file mode 100644 index f0c178da70e4b..0000000000000 --- a/pkg/lightning/backend/external/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package external - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/lightning/backend/external/byte_reader.go b/pkg/lightning/backend/external/byte_reader.go index 473b51a7ccfe0..d2d46939f6aab 100644 --- a/pkg/lightning/backend/external/byte_reader.go +++ b/pkg/lightning/backend/external/byte_reader.go @@ -327,11 +327,11 @@ func (r *byteReader) closeConcurrentReader() (reloadCnt, offsetInOldBuffer int) zap.Int("dropBytes", r.concurrentReader.bufSizePerConc*(len(r.curBuf)-r.curBufIdx)-r.curBufOffset), zap.Int("curBufIdx", r.curBufIdx), ) - if _, _err_ := failpoint.Eval(_curpkg_("assertReloadAtMostOnce")); _err_ == nil { + 
failpoint.Inject("assertReloadAtMostOnce", func() { if r.concurrentReader.reloadCnt > 1 { panic(fmt.Sprintf("reloadCnt is %d", r.concurrentReader.reloadCnt)) } - } + }) r.concurrentReader.largeBufferPool.Destroy() r.concurrentReader.largeBuf = nil r.concurrentReader.now = false diff --git a/pkg/lightning/backend/external/byte_reader.go__failpoint_stash__ b/pkg/lightning/backend/external/byte_reader.go__failpoint_stash__ deleted file mode 100644 index d2d46939f6aab..0000000000000 --- a/pkg/lightning/backend/external/byte_reader.go__failpoint_stash__ +++ /dev/null @@ -1,351 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package external - -import ( - "context" - "fmt" - "io" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/membuf" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/size" - "go.uber.org/zap" -) - -var ( - // ConcurrentReaderBufferSizePerConc is the buffer size for concurrent reader per - // concurrency. - ConcurrentReaderBufferSizePerConc = int(8 * size.MB) - // in readAllData, expected concurrency less than this value will not use - // concurrent reader. - readAllDataConcThreshold = uint64(4) -) - -// byteReader provides structured reading on a byte stream of external storage. -// It can also switch to concurrent reading mode and fetch a larger amount of -// data to improve throughput. -type byteReader struct { - ctx context.Context - storageReader storage.ExternalFileReader - - // curBuf is either smallBuf or concurrentReader.largeBuf. - curBuf [][]byte - curBufIdx int // invariant: 0 <= curBufIdx < len(curBuf) when curBuf contains unread data - curBufOffset int // invariant: 0 <= curBufOffset < len(curBuf[curBufIdx]) if curBufIdx < len(curBuf) - smallBuf []byte - - concurrentReader struct { - largeBufferPool *membuf.Buffer - store storage.ExternalStorage - filename string - concurrency int - bufSizePerConc int - - now bool - expected bool - largeBuf [][]byte - reader *concurrentFileReader - reloadCnt int - } - - logger *zap.Logger -} - -func openStoreReaderAndSeek( - ctx context.Context, - store storage.ExternalStorage, - name string, - initFileOffset uint64, - prefetchSize int, -) (storage.ExternalFileReader, error) { - storageReader, err := store.Open(ctx, name, &storage.ReaderOption{PrefetchSize: prefetchSize}) - if err != nil { - return nil, err - } - _, err = storageReader.Seek(int64(initFileOffset), io.SeekStart) - if err != nil { - return nil, err - } - return storageReader, nil -} - -// newByteReader wraps readNBytes functionality to storageReader. If store and -// filename are also given, this reader can use switchConcurrentMode to switch to -// concurrent reading mode. 
-func newByteReader( - ctx context.Context, - storageReader storage.ExternalFileReader, - bufSize int, -) (r *byteReader, err error) { - defer func() { - if err != nil && r != nil { - _ = r.Close() - } - }() - r = &byteReader{ - ctx: ctx, - storageReader: storageReader, - smallBuf: make([]byte, bufSize), - curBufOffset: 0, - } - r.curBuf = [][]byte{r.smallBuf} - r.logger = logutil.Logger(r.ctx) - return r, r.reload() -} - -func (r *byteReader) enableConcurrentRead( - store storage.ExternalStorage, - filename string, - concurrency int, - bufSizePerConc int, - bufferPool *membuf.Buffer, -) { - r.concurrentReader.store = store - r.concurrentReader.filename = filename - r.concurrentReader.concurrency = concurrency - r.concurrentReader.bufSizePerConc = bufSizePerConc - r.concurrentReader.largeBufferPool = bufferPool -} - -// switchConcurrentMode is used to help implement sortedReader.switchConcurrentMode. -// See the comment of the interface. -func (r *byteReader) switchConcurrentMode(useConcurrent bool) error { - readerFields := &r.concurrentReader - if readerFields.store == nil { - r.logger.Warn("concurrent reader is not enabled, skip switching") - // caller don't need to care about it. - return nil - } - // need to set it before reload() - readerFields.expected = useConcurrent - // concurrent reader will be lazily initialized when reload() - if useConcurrent { - return nil - } - - // no change - if !readerFields.now { - return nil - } - - // rest cases is caller want to turn off concurrent reader. We should turn off - // immediately to release memory. - reloadCnt, offsetInOldBuf := r.closeConcurrentReader() - // here we can assume largeBuf is always fully loaded, because the only exception - // is it's the end of file. When it's the end of the file, caller will see EOF - // and no further switchConcurrentMode should be called. - largeBufSize := readerFields.bufSizePerConc * readerFields.concurrency - delta := int64(offsetInOldBuf + (reloadCnt-1)*largeBufSize) - - if _, err := r.storageReader.Seek(delta, io.SeekCurrent); err != nil { - return err - } - err := r.reload() - if err != nil && err == io.EOF { - // ignore EOF error, let readNBytes handle it - return nil - } - return err -} - -func (r *byteReader) switchToConcurrentReader() error { - // because it will be called only when buffered data of storageReader is used - // up, we can use seek(0, io.SeekCurrent) to get the offset for concurrent - // reader - currOffset, err := r.storageReader.Seek(0, io.SeekCurrent) - if err != nil { - return err - } - fileSize, err := r.storageReader.GetFileSize() - if err != nil { - return err - } - readerFields := &r.concurrentReader - readerFields.reader, err = newConcurrentFileReader( - r.ctx, - readerFields.store, - readerFields.filename, - currOffset, - fileSize, - readerFields.concurrency, - readerFields.bufSizePerConc, - ) - if err != nil { - return err - } - - readerFields.largeBuf = make([][]byte, readerFields.concurrency) - for i := range readerFields.largeBuf { - readerFields.largeBuf[i] = readerFields.largeBufferPool.AllocBytes(readerFields.bufSizePerConc) - if readerFields.largeBuf[i] == nil { - return errors.Errorf("alloc large buffer failed, size %d", readerFields.bufSizePerConc) - } - } - - r.curBuf = readerFields.largeBuf - r.curBufOffset = 0 - readerFields.now = true - return nil -} - -// readNBytes reads the next n bytes from the reader and returns a buffer slice -// containing those bytes. The content of returned slice may be changed after -// next call. 
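// Editor's note (added example, not in the original): readNBytes(8) returns an
// 8-byte slice when at least 8 bytes remain in the stream; if the stream ends
// after only part of the requested bytes was read it returns io.ErrUnexpectedEOF,
// and if it ends before anything was read it returns io.EOF.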
-func (r *byteReader) readNBytes(n int) ([]byte, error) { - if n <= 0 { - return nil, errors.Errorf("illegal n (%d) when reading from external storage", n) - } - if n > int(size.GB) { - return nil, errors.Errorf("read %d bytes from external storage, exceed max limit %d", n, size.GB) - } - - readLen, bs := r.next(n) - if readLen == n && len(bs) == 1 { - return bs[0], nil - } - // need to flatten bs - auxBuf := make([]byte, n) - for _, b := range bs { - copy(auxBuf[len(auxBuf)-n:], b) - n -= len(b) - } - hasRead := readLen > 0 - for n > 0 { - err := r.reload() - switch err { - case nil: - case io.EOF: - // EOF is only allowed when we have not read any data - if hasRead { - return nil, io.ErrUnexpectedEOF - } - return nil, err - default: - return nil, err - } - readLen, bs = r.next(n) - hasRead = hasRead || readLen > 0 - for _, b := range bs { - copy(auxBuf[len(auxBuf)-n:], b) - n -= len(b) - } - } - return auxBuf, nil -} - -func (r *byteReader) next(n int) (int, [][]byte) { - retCnt := 0 - // TODO(lance6716): heap escape performance? - ret := make([][]byte, 0, len(r.curBuf)-r.curBufIdx+1) - for r.curBufIdx < len(r.curBuf) && n > 0 { - cur := r.curBuf[r.curBufIdx] - if r.curBufOffset+n <= len(cur) { - ret = append(ret, cur[r.curBufOffset:r.curBufOffset+n]) - retCnt += n - r.curBufOffset += n - if r.curBufOffset == len(cur) { - r.curBufIdx++ - r.curBufOffset = 0 - } - break - } - ret = append(ret, cur[r.curBufOffset:]) - retCnt += len(cur) - r.curBufOffset - n -= len(cur) - r.curBufOffset - r.curBufIdx++ - r.curBufOffset = 0 - } - - return retCnt, ret -} - -func (r *byteReader) reload() error { - to := r.concurrentReader.expected - now := r.concurrentReader.now - // in read only false -> true is possible - if !now && to { - r.logger.Info("switch reader mode", zap.Bool("use concurrent mode", true)) - err := r.switchToConcurrentReader() - if err != nil { - return err - } - } - - if r.concurrentReader.now { - r.concurrentReader.reloadCnt++ - buffers, err := r.concurrentReader.reader.read(r.concurrentReader.largeBuf) - if err != nil { - return err - } - r.curBuf = buffers - r.curBufIdx = 0 - r.curBufOffset = 0 - return nil - } - // when not using concurrentReader, len(curBuf) == 1 - n, err := io.ReadFull(r.storageReader, r.curBuf[0][0:]) - if err != nil { - switch err { - case io.EOF: - // move curBufIdx so following read will also find EOF - r.curBufIdx = len(r.curBuf) - return err - case io.ErrUnexpectedEOF: - // The last batch. 
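			// Editor's note (comment added for clarity, not in the original):
			// io.ReadFull reports io.ErrUnexpectedEOF when it read some bytes but
			// fewer than len(buf), so the buffer is truncated to the n bytes that
			// were actually read and this partial batch is served to the caller.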
- r.curBuf[0] = r.curBuf[0][:n] - case context.Canceled: - return err - default: - r.logger.Warn("other error during read", zap.Error(err)) - return err - } - } - r.curBufIdx = 0 - r.curBufOffset = 0 - return nil -} - -func (r *byteReader) closeConcurrentReader() (reloadCnt, offsetInOldBuffer int) { - r.logger.Info("drop data in closeConcurrentReader", - zap.Int("reloadCnt", r.concurrentReader.reloadCnt), - zap.Int("dropBytes", r.concurrentReader.bufSizePerConc*(len(r.curBuf)-r.curBufIdx)-r.curBufOffset), - zap.Int("curBufIdx", r.curBufIdx), - ) - failpoint.Inject("assertReloadAtMostOnce", func() { - if r.concurrentReader.reloadCnt > 1 { - panic(fmt.Sprintf("reloadCnt is %d", r.concurrentReader.reloadCnt)) - } - }) - r.concurrentReader.largeBufferPool.Destroy() - r.concurrentReader.largeBuf = nil - r.concurrentReader.now = false - reloadCnt = r.concurrentReader.reloadCnt - r.concurrentReader.reloadCnt = 0 - r.curBuf = [][]byte{r.smallBuf} - offsetInOldBuffer = r.curBufOffset + r.curBufIdx*r.concurrentReader.bufSizePerConc - r.curBufOffset = 0 - return -} - -func (r *byteReader) Close() error { - if r.concurrentReader.now { - r.closeConcurrentReader() - } - return r.storageReader.Close() -} diff --git a/pkg/lightning/backend/external/engine.go b/pkg/lightning/backend/external/engine.go index 149068f52b76e..7a42354e20acc 100644 --- a/pkg/lightning/backend/external/engine.go +++ b/pkg/lightning/backend/external/engine.go @@ -357,9 +357,9 @@ func (e *Engine) LoadIngestData( ) error { // try to make every worker busy for each batch regionBatchSize := e.workerConcurrency - if val, _err_ := failpoint.Eval(_curpkg_("LoadIngestDataBatchSize")); _err_ == nil { + failpoint.Inject("LoadIngestDataBatchSize", func(val failpoint.Value) { regionBatchSize = val.(int) - } + }) for i := 0; i < len(regionRanges); i += regionBatchSize { err := e.loadBatchRegionData(ctx, regionRanges[i].Start, regionRanges[min(i+regionBatchSize, len(regionRanges))-1].End, outCh) if err != nil { diff --git a/pkg/lightning/backend/external/engine.go__failpoint_stash__ b/pkg/lightning/backend/external/engine.go__failpoint_stash__ deleted file mode 100644 index 7a42354e20acc..0000000000000 --- a/pkg/lightning/backend/external/engine.go__failpoint_stash__ +++ /dev/null @@ -1,732 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package external - -import ( - "bytes" - "context" - "sort" - "sync" - "time" - - "github.com/cockroachdb/pebble" - "github.com/docker/go-units" - "github.com/jfcg/sorty/v2" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/membuf" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/lightning/common" - "github.com/pingcap/tidb/pkg/lightning/config" - "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/util/logutil" - "go.uber.org/atomic" - "go.uber.org/zap" -) - -// during test on ks3, we found that we can open about 8000 connections to ks3, -// bigger than that, we might receive "connection reset by peer" error, and -// the read speed will be very slow, still investigating the reason. -// Also open too many connections will take many memory in kernel, and the -// test is based on k8s pod, not sure how it will behave on EC2. -// but, ks3 supporter says there's no such limit on connections. -// And our target for global sort is AWS s3, this default value might not fit well. -// TODO: adjust it according to cloud storage. -const maxCloudStorageConnections = 1000 - -type memKVsAndBuffers struct { - mu sync.Mutex - keys [][]byte - values [][]byte - // memKVBuffers contains two types of buffer, first half are used for small block - // buffer, second half are used for large one. - memKVBuffers []*membuf.Buffer - size int - droppedSize int - - // temporary fields to store KVs to reduce slice allocations. - keysPerFile [][][]byte - valuesPerFile [][][]byte - droppedSizePerFile []int -} - -func (b *memKVsAndBuffers) build(ctx context.Context) { - sumKVCnt := 0 - for _, keys := range b.keysPerFile { - sumKVCnt += len(keys) - } - b.droppedSize = 0 - for _, size := range b.droppedSizePerFile { - b.droppedSize += size - } - b.droppedSizePerFile = nil - - logutil.Logger(ctx).Info("building memKVsAndBuffers", - zap.Int("sumKVCnt", sumKVCnt), - zap.Int("droppedSize", b.droppedSize)) - - b.keys = make([][]byte, 0, sumKVCnt) - b.values = make([][]byte, 0, sumKVCnt) - for i := range b.keysPerFile { - b.keys = append(b.keys, b.keysPerFile[i]...) - b.keysPerFile[i] = nil - b.values = append(b.values, b.valuesPerFile[i]...) - b.valuesPerFile[i] = nil - } - b.keysPerFile = nil - b.valuesPerFile = nil -} - -// Engine stored sorted key/value pairs in an external storage. -type Engine struct { - storage storage.ExternalStorage - dataFiles []string - statsFiles []string - startKey []byte - endKey []byte - splitKeys [][]byte - regionSplitSize int64 - smallBlockBufPool *membuf.Pool - largeBlockBufPool *membuf.Pool - - memKVsAndBuffers memKVsAndBuffers - - // checkHotspot is true means we will check hotspot file when using MergeKVIter. - // if hotspot file is detected, we will use multiple readers to read data. - // if it's false, MergeKVIter will read each file using 1 reader. 
- // this flag also affects the strategy of loading data, either: - // less load routine + check and read hotspot file concurrently (add-index uses this one) - // more load routine + read each file using 1 reader (import-into uses this one) - checkHotspot bool - mergerIterConcurrency int - - keyAdapter common.KeyAdapter - duplicateDetection bool - duplicateDB *pebble.DB - dupDetectOpt common.DupDetectOpt - workerConcurrency int - ts uint64 - - totalKVSize int64 - totalKVCount int64 - - importedKVSize *atomic.Int64 - importedKVCount *atomic.Int64 -} - -const ( - memLimit = 12 * units.GiB - smallBlockSize = units.MiB -) - -// NewExternalEngine creates an (external) engine. -func NewExternalEngine( - storage storage.ExternalStorage, - dataFiles []string, - statsFiles []string, - startKey []byte, - endKey []byte, - splitKeys [][]byte, - regionSplitSize int64, - keyAdapter common.KeyAdapter, - duplicateDetection bool, - duplicateDB *pebble.DB, - dupDetectOpt common.DupDetectOpt, - workerConcurrency int, - ts uint64, - totalKVSize int64, - totalKVCount int64, - checkHotspot bool, -) common.Engine { - memLimiter := membuf.NewLimiter(memLimit) - return &Engine{ - storage: storage, - dataFiles: dataFiles, - statsFiles: statsFiles, - startKey: startKey, - endKey: endKey, - splitKeys: splitKeys, - regionSplitSize: regionSplitSize, - smallBlockBufPool: membuf.NewPool( - membuf.WithBlockNum(0), - membuf.WithPoolMemoryLimiter(memLimiter), - membuf.WithBlockSize(smallBlockSize), - ), - largeBlockBufPool: membuf.NewPool( - membuf.WithBlockNum(0), - membuf.WithPoolMemoryLimiter(memLimiter), - membuf.WithBlockSize(ConcurrentReaderBufferSizePerConc), - ), - checkHotspot: checkHotspot, - keyAdapter: keyAdapter, - duplicateDetection: duplicateDetection, - duplicateDB: duplicateDB, - dupDetectOpt: dupDetectOpt, - workerConcurrency: workerConcurrency, - ts: ts, - totalKVSize: totalKVSize, - totalKVCount: totalKVCount, - importedKVSize: atomic.NewInt64(0), - importedKVCount: atomic.NewInt64(0), - } -} - -func split[T any](in []T, groupNum int) [][]T { - if len(in) == 0 { - return nil - } - if groupNum <= 0 { - groupNum = 1 - } - ceil := (len(in) + groupNum - 1) / groupNum - ret := make([][]T, 0, groupNum) - l := len(in) - for i := 0; i < l; i += ceil { - if i+ceil > l { - ret = append(ret, in[i:]) - } else { - ret = append(ret, in[i:i+ceil]) - } - } - return ret -} - -func (e *Engine) getAdjustedConcurrency() int { - if e.checkHotspot { - // estimate we will open at most 8000 files, so if e.dataFiles is small we can - // try to concurrently process ranges. - adjusted := maxCloudStorageConnections / len(e.dataFiles) - if adjusted == 0 { - return 1 - } - return min(adjusted, 8) - } - adjusted := min(e.workerConcurrency, maxCloudStorageConnections/len(e.dataFiles)) - return max(adjusted, 1) -} - -func getFilesReadConcurrency( - ctx context.Context, - storage storage.ExternalStorage, - statsFiles []string, - startKey, endKey []byte, -) ([]uint64, []uint64, error) { - result := make([]uint64, len(statsFiles)) - offsets, err := seekPropsOffsets(ctx, []kv.Key{startKey, endKey}, statsFiles, storage) - if err != nil { - return nil, nil, err - } - startOffs, endOffs := offsets[0], offsets[1] - for i := range statsFiles { - expectedConc := (endOffs[i] - startOffs[i]) / uint64(ConcurrentReaderBufferSizePerConc) - // let the stat internals cover the [startKey, endKey) since seekPropsOffsets - // always return an offset that is less than or equal to the key. 
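		// Editor's note (added example, not in the original): with the default
		// ConcurrentReaderBufferSizePerConc of 8*size.MB, a stat range covering
		// 100*size.MB gives expectedConc = 100/8 + 1 = 13, which is at least
		// readAllDataConcThreshold (4), so that file is read with 13 concurrent
		// readers instead of a single one.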
- expectedConc += 1 - // readAllData will enable concurrent read and use large buffer if result[i] > 1 - // when expectedConc < readAllDataConcThreshold, we don't use concurrent read to - // reduce overhead - if expectedConc >= readAllDataConcThreshold { - result[i] = expectedConc - } else { - result[i] = 1 - } - // only log for files with expected concurrency > 1, to avoid too many logs - if expectedConc > 1 { - logutil.Logger(ctx).Info("found hotspot file in getFilesReadConcurrency", - zap.String("filename", statsFiles[i]), - zap.Uint64("startOffset", startOffs[i]), - zap.Uint64("endOffset", endOffs[i]), - zap.Uint64("expectedConc", expectedConc), - zap.Uint64("concurrency", result[i]), - ) - } - } - return result, startOffs, nil -} - -func (e *Engine) loadBatchRegionData(ctx context.Context, startKey, endKey []byte, outCh chan<- common.DataAndRange) error { - readAndSortRateHist := metrics.GlobalSortReadFromCloudStorageRate.WithLabelValues("read_and_sort") - readAndSortDurHist := metrics.GlobalSortReadFromCloudStorageDuration.WithLabelValues("read_and_sort") - readRateHist := metrics.GlobalSortReadFromCloudStorageRate.WithLabelValues("read") - readDurHist := metrics.GlobalSortReadFromCloudStorageDuration.WithLabelValues("read") - sortRateHist := metrics.GlobalSortReadFromCloudStorageRate.WithLabelValues("sort") - sortDurHist := metrics.GlobalSortReadFromCloudStorageDuration.WithLabelValues("sort") - - readStart := time.Now() - readDtStartKey := e.keyAdapter.Encode(nil, startKey, common.MinRowID) - readDtEndKey := e.keyAdapter.Encode(nil, endKey, common.MinRowID) - err := readAllData( - ctx, - e.storage, - e.dataFiles, - e.statsFiles, - readDtStartKey, - readDtEndKey, - e.smallBlockBufPool, - e.largeBlockBufPool, - &e.memKVsAndBuffers, - ) - if err != nil { - return err - } - e.memKVsAndBuffers.build(ctx) - - readSecond := time.Since(readStart).Seconds() - readDurHist.Observe(readSecond) - logutil.Logger(ctx).Info("reading external storage in loadBatchRegionData", - zap.Duration("cost time", time.Since(readStart)), - zap.Int("droppedSize", e.memKVsAndBuffers.droppedSize)) - - sortStart := time.Now() - oldSortyGor := sorty.MaxGor - sorty.MaxGor = uint64(e.workerConcurrency * 2) - sorty.Sort(len(e.memKVsAndBuffers.keys), func(i, k, r, s int) bool { - if bytes.Compare(e.memKVsAndBuffers.keys[i], e.memKVsAndBuffers.keys[k]) < 0 { // strict comparator like < or > - if r != s { - e.memKVsAndBuffers.keys[r], e.memKVsAndBuffers.keys[s] = e.memKVsAndBuffers.keys[s], e.memKVsAndBuffers.keys[r] - e.memKVsAndBuffers.values[r], e.memKVsAndBuffers.values[s] = e.memKVsAndBuffers.values[s], e.memKVsAndBuffers.values[r] - } - return true - } - return false - }) - sorty.MaxGor = oldSortyGor - sortSecond := time.Since(sortStart).Seconds() - sortDurHist.Observe(sortSecond) - logutil.Logger(ctx).Info("sorting in loadBatchRegionData", - zap.Duration("cost time", time.Since(sortStart))) - - readAndSortSecond := time.Since(readStart).Seconds() - readAndSortDurHist.Observe(readAndSortSecond) - - size := e.memKVsAndBuffers.size - readAndSortRateHist.Observe(float64(size) / 1024.0 / 1024.0 / readAndSortSecond) - readRateHist.Observe(float64(size) / 1024.0 / 1024.0 / readSecond) - sortRateHist.Observe(float64(size) / 1024.0 / 1024.0 / sortSecond) - - data := e.buildIngestData( - e.memKVsAndBuffers.keys, - e.memKVsAndBuffers.values, - e.memKVsAndBuffers.memKVBuffers, - ) - - // release the reference of e.memKVsAndBuffers - e.memKVsAndBuffers.keys = nil - e.memKVsAndBuffers.values = nil - 
e.memKVsAndBuffers.memKVBuffers = nil - e.memKVsAndBuffers.size = 0 - - sendFn := func(dr common.DataAndRange) error { - select { - case <-ctx.Done(): - return ctx.Err() - case outCh <- dr: - } - return nil - } - return sendFn(common.DataAndRange{ - Data: data, - Range: common.Range{ - Start: startKey, - End: endKey, - }, - }) -} - -// LoadIngestData loads the data from the external storage to memory in [start, -// end) range, so local backend can ingest it. The used byte slice of ingest data -// are allocated from Engine.bufPool and must be released by -// MemoryIngestData.DecRef(). -func (e *Engine) LoadIngestData( - ctx context.Context, - regionRanges []common.Range, - outCh chan<- common.DataAndRange, -) error { - // try to make every worker busy for each batch - regionBatchSize := e.workerConcurrency - failpoint.Inject("LoadIngestDataBatchSize", func(val failpoint.Value) { - regionBatchSize = val.(int) - }) - for i := 0; i < len(regionRanges); i += regionBatchSize { - err := e.loadBatchRegionData(ctx, regionRanges[i].Start, regionRanges[min(i+regionBatchSize, len(regionRanges))-1].End, outCh) - if err != nil { - return err - } - } - return nil -} - -func (e *Engine) buildIngestData(keys, values [][]byte, buf []*membuf.Buffer) *MemoryIngestData { - return &MemoryIngestData{ - keyAdapter: e.keyAdapter, - duplicateDetection: e.duplicateDetection, - duplicateDB: e.duplicateDB, - dupDetectOpt: e.dupDetectOpt, - keys: keys, - values: values, - ts: e.ts, - memBuf: buf, - refCnt: atomic.NewInt64(0), - importedKVSize: e.importedKVSize, - importedKVCount: e.importedKVCount, - } -} - -// LargeRegionSplitDataThreshold is exposed for test. -var LargeRegionSplitDataThreshold = int(config.SplitRegionSize) - -// KVStatistics returns the total kv size and total kv count. -func (e *Engine) KVStatistics() (totalKVSize int64, totalKVCount int64) { - return e.totalKVSize, e.totalKVCount -} - -// ImportedStatistics returns the imported kv size and imported kv count. -func (e *Engine) ImportedStatistics() (importedSize int64, importedKVCount int64) { - return e.importedKVSize.Load(), e.importedKVCount.Load() -} - -// ID is the identifier of an engine. -func (e *Engine) ID() string { - return "external" -} - -// GetKeyRange implements common.Engine. -func (e *Engine) GetKeyRange() (startKey []byte, endKey []byte, err error) { - if _, ok := e.keyAdapter.(common.NoopKeyAdapter); ok { - return e.startKey, e.endKey, nil - } - - // when duplicate detection feature is enabled, the end key comes from - // DupDetectKeyAdapter.Encode or Key.Next(). We try to decode it and check the - // error. - - start, err := e.keyAdapter.Decode(nil, e.startKey) - if err != nil { - return nil, nil, err - } - end, err := e.keyAdapter.Decode(nil, e.endKey) - if err == nil { - return start, end, nil - } - // handle the case that end key is from Key.Next() - if e.endKey[len(e.endKey)-1] != 0 { - return nil, nil, err - } - endEncoded := e.endKey[:len(e.endKey)-1] - end, err = e.keyAdapter.Decode(nil, endEncoded) - if err != nil { - return nil, nil, err - } - return start, kv.Key(end).Next(), nil -} - -// SplitRanges split the ranges by split keys provided by external engine. 
-func (e *Engine) SplitRanges( - startKey, endKey []byte, - _, _ int64, - _ log.Logger, -) ([]common.Range, error) { - splitKeys := e.splitKeys - for i, k := range e.splitKeys { - var err error - splitKeys[i], err = e.keyAdapter.Decode(nil, k) - if err != nil { - return nil, err - } - } - ranges := make([]common.Range, 0, len(splitKeys)+1) - ranges = append(ranges, common.Range{Start: startKey}) - for i := 0; i < len(splitKeys); i++ { - ranges[len(ranges)-1].End = splitKeys[i] - var endK []byte - if i < len(splitKeys)-1 { - endK = splitKeys[i+1] - } - ranges = append(ranges, common.Range{Start: splitKeys[i], End: endK}) - } - ranges[len(ranges)-1].End = endKey - return ranges, nil -} - -// Close implements common.Engine. -func (e *Engine) Close() error { - if e.smallBlockBufPool != nil { - e.smallBlockBufPool.Destroy() - e.smallBlockBufPool = nil - } - if e.largeBlockBufPool != nil { - e.largeBlockBufPool.Destroy() - e.largeBlockBufPool = nil - } - e.storage.Close() - return nil -} - -// Reset resets the memory buffer pool. -func (e *Engine) Reset() error { - memLimiter := membuf.NewLimiter(memLimit) - if e.smallBlockBufPool != nil { - e.smallBlockBufPool.Destroy() - e.smallBlockBufPool = membuf.NewPool( - membuf.WithBlockNum(0), - membuf.WithPoolMemoryLimiter(memLimiter), - membuf.WithBlockSize(smallBlockSize), - ) - } - if e.largeBlockBufPool != nil { - e.largeBlockBufPool.Destroy() - e.largeBlockBufPool = membuf.NewPool( - membuf.WithBlockNum(0), - membuf.WithPoolMemoryLimiter(memLimiter), - membuf.WithBlockSize(ConcurrentReaderBufferSizePerConc), - ) - } - return nil -} - -// MemoryIngestData is the in-memory implementation of IngestData. -type MemoryIngestData struct { - keyAdapter common.KeyAdapter - duplicateDetection bool - duplicateDB *pebble.DB - dupDetectOpt common.DupDetectOpt - - keys [][]byte - values [][]byte - ts uint64 - - memBuf []*membuf.Buffer - refCnt *atomic.Int64 - importedKVSize *atomic.Int64 - importedKVCount *atomic.Int64 -} - -var _ common.IngestData = (*MemoryIngestData)(nil) - -func (m *MemoryIngestData) firstAndLastKeyIndex(lowerBound, upperBound []byte) (int, int) { - firstKeyIdx := 0 - if len(lowerBound) > 0 { - lowerBound = m.keyAdapter.Encode(nil, lowerBound, common.MinRowID) - firstKeyIdx = sort.Search(len(m.keys), func(i int) bool { - return bytes.Compare(lowerBound, m.keys[i]) <= 0 - }) - if firstKeyIdx == len(m.keys) { - return -1, -1 - } - } - - lastKeyIdx := len(m.keys) - 1 - if len(upperBound) > 0 { - upperBound = m.keyAdapter.Encode(nil, upperBound, common.MinRowID) - i := sort.Search(len(m.keys), func(i int) bool { - reverseIdx := len(m.keys) - 1 - i - return bytes.Compare(upperBound, m.keys[reverseIdx]) > 0 - }) - if i == len(m.keys) { - // should not happen - return -1, -1 - } - lastKeyIdx = len(m.keys) - 1 - i - } - return firstKeyIdx, lastKeyIdx -} - -// GetFirstAndLastKey implements IngestData.GetFirstAndLastKey. 
-func (m *MemoryIngestData) GetFirstAndLastKey(lowerBound, upperBound []byte) ([]byte, []byte, error) { - firstKeyIdx, lastKeyIdx := m.firstAndLastKeyIndex(lowerBound, upperBound) - if firstKeyIdx < 0 || firstKeyIdx > lastKeyIdx { - return nil, nil, nil - } - firstKey, err := m.keyAdapter.Decode(nil, m.keys[firstKeyIdx]) - if err != nil { - return nil, nil, err - } - lastKey, err := m.keyAdapter.Decode(nil, m.keys[lastKeyIdx]) - if err != nil { - return nil, nil, err - } - return firstKey, lastKey, nil -} - -type memoryDataIter struct { - keys [][]byte - values [][]byte - - firstKeyIdx int - lastKeyIdx int - curIdx int -} - -// First implements ForwardIter. -func (m *memoryDataIter) First() bool { - if m.firstKeyIdx < 0 { - return false - } - m.curIdx = m.firstKeyIdx - return true -} - -// Valid implements ForwardIter. -func (m *memoryDataIter) Valid() bool { - return m.firstKeyIdx <= m.curIdx && m.curIdx <= m.lastKeyIdx -} - -// Next implements ForwardIter. -func (m *memoryDataIter) Next() bool { - m.curIdx++ - return m.Valid() -} - -// Key implements ForwardIter. -func (m *memoryDataIter) Key() []byte { - return m.keys[m.curIdx] -} - -// Value implements ForwardIter. -func (m *memoryDataIter) Value() []byte { - return m.values[m.curIdx] -} - -// Close implements ForwardIter. -func (m *memoryDataIter) Close() error { - return nil -} - -// Error implements ForwardIter. -func (m *memoryDataIter) Error() error { - return nil -} - -// ReleaseBuf implements ForwardIter. -func (m *memoryDataIter) ReleaseBuf() {} - -type memoryDataDupDetectIter struct { - iter *memoryDataIter - dupDetector *common.DupDetector - err error - curKey, curVal []byte - buf *membuf.Buffer -} - -// First implements ForwardIter. -func (m *memoryDataDupDetectIter) First() bool { - if m.err != nil || !m.iter.First() { - return false - } - m.curKey, m.curVal, m.err = m.dupDetector.Init(m.iter) - return m.Valid() -} - -// Valid implements ForwardIter. -func (m *memoryDataDupDetectIter) Valid() bool { - return m.err == nil && m.iter.Valid() -} - -// Next implements ForwardIter. -func (m *memoryDataDupDetectIter) Next() bool { - if m.err != nil { - return false - } - key, val, ok, err := m.dupDetector.Next(m.iter) - if err != nil { - m.err = err - return false - } - if !ok { - return false - } - m.curKey, m.curVal = key, val - return true -} - -// Key implements ForwardIter. -func (m *memoryDataDupDetectIter) Key() []byte { - return m.buf.AddBytes(m.curKey) -} - -// Value implements ForwardIter. -func (m *memoryDataDupDetectIter) Value() []byte { - return m.buf.AddBytes(m.curVal) -} - -// Close implements ForwardIter. -func (m *memoryDataDupDetectIter) Close() error { - m.buf.Destroy() - return m.dupDetector.Close() -} - -// Error implements ForwardIter. -func (m *memoryDataDupDetectIter) Error() error { - return m.err -} - -// ReleaseBuf implements ForwardIter. -func (m *memoryDataDupDetectIter) ReleaseBuf() { - m.buf.Reset() -} - -// NewIter implements IngestData.NewIter. 
-func (m *MemoryIngestData) NewIter( - ctx context.Context, - lowerBound, upperBound []byte, - bufPool *membuf.Pool, -) common.ForwardIter { - firstKeyIdx, lastKeyIdx := m.firstAndLastKeyIndex(lowerBound, upperBound) - iter := &memoryDataIter{ - keys: m.keys, - values: m.values, - firstKeyIdx: firstKeyIdx, - lastKeyIdx: lastKeyIdx, - } - if !m.duplicateDetection { - return iter - } - logger := log.FromContext(ctx) - detector := common.NewDupDetector(m.keyAdapter, m.duplicateDB.NewBatch(), logger, m.dupDetectOpt) - return &memoryDataDupDetectIter{ - iter: iter, - dupDetector: detector, - buf: bufPool.NewBuffer(), - } -} - -// GetTS implements IngestData.GetTS. -func (m *MemoryIngestData) GetTS() uint64 { - return m.ts -} - -// IncRef implements IngestData.IncRef. -func (m *MemoryIngestData) IncRef() { - m.refCnt.Inc() -} - -// DecRef implements IngestData.DecRef. -func (m *MemoryIngestData) DecRef() { - if m.refCnt.Dec() == 0 { - m.keys = nil - m.values = nil - for _, b := range m.memBuf { - b.Destroy() - } - } -} - -// Finish implements IngestData.Finish. -func (m *MemoryIngestData) Finish(totalBytes, totalCount int64) { - m.importedKVSize.Add(totalBytes) - m.importedKVCount.Add(totalCount) - -} diff --git a/pkg/lightning/backend/external/merge_v2.go b/pkg/lightning/backend/external/merge_v2.go index 49de736b26ade..af32569d6745d 100644 --- a/pkg/lightning/backend/external/merge_v2.go +++ b/pkg/lightning/backend/external/merge_v2.go @@ -67,9 +67,9 @@ func MergeOverlappingFilesV2( }() rangesGroupSize := 4 * size.GB - if val, _err_ := failpoint.Eval(_curpkg_("mockRangesGroupSize")); _err_ == nil { + failpoint.Inject("mockRangesGroupSize", func(val failpoint.Value) { rangesGroupSize = uint64(val.(int)) - } + }) splitter, err := NewRangeSplitter( ctx, diff --git a/pkg/lightning/backend/external/merge_v2.go__failpoint_stash__ b/pkg/lightning/backend/external/merge_v2.go__failpoint_stash__ deleted file mode 100644 index af32569d6745d..0000000000000 --- a/pkg/lightning/backend/external/merge_v2.go__failpoint_stash__ +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package external - -import ( - "bytes" - "context" - "math" - "time" - - "github.com/jfcg/sorty/v2" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/membuf" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/size" - "go.uber.org/zap" -) - -// MergeOverlappingFilesV2 reads from given files whose key range may overlap -// and writes to new sorted, nonoverlapping files. -// Using 1 readAllData and 1 writer. 
-func MergeOverlappingFilesV2( - ctx context.Context, - multiFileStat []MultipleFilesStat, - store storage.ExternalStorage, - startKey []byte, - endKey []byte, - partSize int64, - newFilePrefix string, - writerID string, - blockSize int, - writeBatchCount uint64, - propSizeDist uint64, - propKeysDist uint64, - onClose OnCloseFunc, - concurrency int, - checkHotspot bool, -) (err error) { - fileCnt := 0 - for _, m := range multiFileStat { - fileCnt += len(m.Filenames) - } - task := log.BeginTask(logutil.Logger(ctx).With( - zap.Int("file-count", fileCnt), - zap.Binary("start-key", startKey), - zap.Binary("end-key", endKey), - zap.String("new-file-prefix", newFilePrefix), - zap.Int("concurrency", concurrency), - ), "merge overlapping files") - defer func() { - task.End(zap.ErrorLevel, err) - }() - - rangesGroupSize := 4 * size.GB - failpoint.Inject("mockRangesGroupSize", func(val failpoint.Value) { - rangesGroupSize = uint64(val.(int)) - }) - - splitter, err := NewRangeSplitter( - ctx, - multiFileStat, - store, - int64(rangesGroupSize), - math.MaxInt64, - int64(4*size.GB), - math.MaxInt64, - ) - if err != nil { - return err - } - - writer := NewWriterBuilder(). - SetMemorySizeLimit(DefaultMemSizeLimit). - SetBlockSize(blockSize). - SetPropKeysDistance(propKeysDist). - SetPropSizeDistance(propSizeDist). - SetOnCloseFunc(onClose). - BuildOneFile(store, newFilePrefix, writerID) - defer func() { - err = splitter.Close() - if err != nil { - logutil.Logger(ctx).Warn("close range splitter failed", zap.Error(err)) - } - err = writer.Close(ctx) - if err != nil { - logutil.Logger(ctx).Warn("close writer failed", zap.Error(err)) - } - }() - - err = writer.Init(ctx, partSize) - if err != nil { - logutil.Logger(ctx).Warn("init writer failed", zap.Error(err)) - return - } - - bufPool := membuf.NewPool() - loaded := &memKVsAndBuffers{} - curStart := kv.Key(startKey).Clone() - var curEnd kv.Key - - for { - endKeyOfGroup, dataFilesOfGroup, statFilesOfGroup, _, err1 := splitter.SplitOneRangesGroup() - if err1 != nil { - logutil.Logger(ctx).Warn("split one ranges group failed", zap.Error(err1)) - return - } - curEnd = kv.Key(endKeyOfGroup).Clone() - if len(endKeyOfGroup) == 0 { - curEnd = kv.Key(endKey).Clone() - } - now := time.Now() - err1 = readAllData( - ctx, - store, - dataFilesOfGroup, - statFilesOfGroup, - curStart, - curEnd, - bufPool, - bufPool, - loaded, - ) - if err1 != nil { - logutil.Logger(ctx).Warn("read all data failed", zap.Error(err1)) - return - } - loaded.build(ctx) - readTime := time.Since(now) - now = time.Now() - sorty.MaxGor = uint64(concurrency) - sorty.Sort(len(loaded.keys), func(i, k, r, s int) bool { - if bytes.Compare(loaded.keys[i], loaded.keys[k]) < 0 { // strict comparator like < or > - if r != s { - loaded.keys[r], loaded.keys[s] = loaded.keys[s], loaded.keys[r] - loaded.values[r], loaded.values[s] = loaded.values[s], loaded.values[r] - } - return true - } - return false - }) - sortTime := time.Since(now) - now = time.Now() - for i, key := range loaded.keys { - err1 = writer.WriteRow(ctx, key, loaded.values[i]) - if err1 != nil { - logutil.Logger(ctx).Warn("write one row to writer failed", zap.Error(err1)) - return - } - } - writeTime := time.Since(now) - logutil.Logger(ctx).Info("sort one group in MergeOverlappingFiles", - zap.Duration("read time", readTime), - zap.Duration("sort time", sortTime), - zap.Duration("write time", writeTime), - zap.Int("key len", len(loaded.keys))) - - curStart = curEnd.Clone() - loaded.keys = nil - loaded.values = nil - loaded.memKVBuffers = nil - 
loaded.size = 0 - - if len(endKeyOfGroup) == 0 { - break - } - } - return -} diff --git a/pkg/lightning/backend/local/binding__failpoint_binding__.go b/pkg/lightning/backend/local/binding__failpoint_binding__.go deleted file mode 100644 index 1ae5c5095df43..0000000000000 --- a/pkg/lightning/backend/local/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package local - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/lightning/backend/local/checksum.go b/pkg/lightning/backend/local/checksum.go index 206717d957393..967e32d37ea46 100644 --- a/pkg/lightning/backend/local/checksum.go +++ b/pkg/lightning/backend/local/checksum.go @@ -241,7 +241,7 @@ func increaseGCLifeTime(ctx context.Context, manager *gcLifeTimeManager, db *sql } } - failpoint.Eval(_curpkg_("IncreaseGCUpdateDuration")) + failpoint.Inject("IncreaseGCUpdateDuration", nil) return nil } diff --git a/pkg/lightning/backend/local/checksum.go__failpoint_stash__ b/pkg/lightning/backend/local/checksum.go__failpoint_stash__ deleted file mode 100644 index 967e32d37ea46..0000000000000 --- a/pkg/lightning/backend/local/checksum.go__failpoint_stash__ +++ /dev/null @@ -1,517 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package local - -import ( - "container/heap" - "context" - "database/sql" - "fmt" - "sync" - "time" - - "github.com/google/uuid" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/checksum" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/lightning/checkpoints" - "github.com/pingcap/tidb/pkg/lightning/common" - "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/lightning/metric" - "github.com/pingcap/tidb/pkg/lightning/verification" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tipb/go-tipb" - tikvstore "github.com/tikv/client-go/v2/kv" - "github.com/tikv/client-go/v2/oracle" - pd "github.com/tikv/pd/client" - pderrs "github.com/tikv/pd/client/errs" - "go.uber.org/atomic" - "go.uber.org/zap" -) - -const ( - preUpdateServiceSafePointFactor = 3 - maxErrorRetryCount = 3 - defaultGCLifeTime = 100 * time.Hour -) - -var ( - serviceSafePointTTL int64 = 10 * 60 // 10 min in seconds - - // MinDistSQLScanConcurrency is the minimum value of tidb_distsql_scan_concurrency. - MinDistSQLScanConcurrency = 4 - - // DefaultBackoffWeight is the default value of tidb_backoff_weight for checksum. - // RegionRequestSender will retry within a maxSleep time, default is 2 * 20 = 40 seconds. - // When TiKV client encounters an error of "region not leader", it will keep - // retrying every 500 ms, if it still fails after maxSleep, it will return "region unavailable". 
- // When there are many pending compaction bytes, TiKV might not respond within 1m, - // and report "rpcError:wait recvLoop timeout,timeout:1m0s", and retry might - // time out again. - // so we enlarge it to 30 * 20 = 10 minutes. - DefaultBackoffWeight = 15 * tikvstore.DefBackOffWeight -) - -// RemoteChecksum represents a checksum result got from tidb. -type RemoteChecksum struct { - Schema string - Table string - Checksum uint64 - TotalKVs uint64 - TotalBytes uint64 -} - -// IsEqual checks whether the checksum is equal to the other. -func (rc *RemoteChecksum) IsEqual(other *verification.KVChecksum) bool { - return rc.Checksum == other.Sum() && - rc.TotalKVs == other.SumKVS() && - rc.TotalBytes == other.SumSize() -} - -// ChecksumManager is a manager that manages checksums. -type ChecksumManager interface { - Checksum(ctx context.Context, tableInfo *checkpoints.TidbTableInfo) (*RemoteChecksum, error) -} - -// fetch checksum for tidb sql client -type tidbChecksumExecutor struct { - db *sql.DB - manager *gcLifeTimeManager -} - -var _ ChecksumManager = (*tidbChecksumExecutor)(nil) - -// NewTiDBChecksumExecutor creates a new tidb checksum executor. -func NewTiDBChecksumExecutor(db *sql.DB) ChecksumManager { - return &tidbChecksumExecutor{ - db: db, - manager: newGCLifeTimeManager(), - } -} - -func (e *tidbChecksumExecutor) Checksum(ctx context.Context, tableInfo *checkpoints.TidbTableInfo) (*RemoteChecksum, error) { - var err error - if err = e.manager.addOneJob(ctx, e.db); err != nil { - return nil, err - } - - // set it back finally - defer e.manager.removeOneJob(ctx, e.db) - - tableName := common.UniqueTable(tableInfo.DB, tableInfo.Name) - - task := log.FromContext(ctx).With(zap.String("table", tableName)).Begin(zap.InfoLevel, "remote checksum") - - conn, err := e.db.Conn(ctx) - if err != nil { - return nil, errors.Trace(err) - } - defer func() { - if err := conn.Close(); err != nil { - task.Warn("close connection failed", zap.Error(err)) - } - }() - // ADMIN CHECKSUM TABLE
<table>,<table>
example. - // mysql> admin checksum table test.t; - // +---------+------------+---------------------+-----------+-------------+ - // | Db_name | Table_name | Checksum_crc64_xor | Total_kvs | Total_bytes | - // +---------+------------+---------------------+-----------+-------------+ - // | test | t | 8520875019404689597 | 7296873 | 357601387 | - // +---------+------------+---------------------+-----------+-------------+ - backoffWeight, err := common.GetBackoffWeightFromDB(ctx, e.db) - if err == nil && backoffWeight < DefaultBackoffWeight { - task.Info("increase tidb_backoff_weight", zap.Int("original", backoffWeight), zap.Int("new", DefaultBackoffWeight)) - // increase backoff weight - if _, err := conn.ExecContext(ctx, fmt.Sprintf("SET SESSION %s = '%d';", variable.TiDBBackOffWeight, DefaultBackoffWeight)); err != nil { - task.Warn("set tidb_backoff_weight failed", zap.Error(err)) - } else { - defer func() { - if _, err := conn.ExecContext(ctx, fmt.Sprintf("SET SESSION %s = '%d';", variable.TiDBBackOffWeight, backoffWeight)); err != nil { - task.Warn("recover tidb_backoff_weight failed", zap.Error(err)) - } - }() - } - } - - cs := RemoteChecksum{} - err = common.SQLWithRetry{DB: conn, Logger: task.Logger}.QueryRow(ctx, "compute remote checksum", - "ADMIN CHECKSUM TABLE "+tableName, &cs.Schema, &cs.Table, &cs.Checksum, &cs.TotalKVs, &cs.TotalBytes, - ) - dur := task.End(zap.ErrorLevel, err) - if m, ok := metric.FromContext(ctx); ok { - m.ChecksumSecondsHistogram.Observe(dur.Seconds()) - } - if err != nil { - return nil, errors.Trace(err) - } - return &cs, nil -} - -type gcLifeTimeManager struct { - runningJobsLock sync.Mutex - runningJobs int - oriGCLifeTime string -} - -func newGCLifeTimeManager() *gcLifeTimeManager { - // Default values of three member are enough to initialize this struct - return &gcLifeTimeManager{} -} - -// Pre- and post-condition: -// if m.runningJobs == 0, GC life time has not been increased. -// if m.runningJobs > 0, GC life time has been increased. -// m.runningJobs won't be negative(overflow) since index concurrency is relatively small -func (m *gcLifeTimeManager) addOneJob(ctx context.Context, db *sql.DB) error { - m.runningJobsLock.Lock() - defer m.runningJobsLock.Unlock() - - if m.runningJobs == 0 { - oriGCLifeTime, err := obtainGCLifeTime(ctx, db) - if err != nil { - return err - } - m.oriGCLifeTime = oriGCLifeTime - err = increaseGCLifeTime(ctx, m, db) - if err != nil { - return err - } - } - m.runningJobs++ - return nil -} - -// Pre- and post-condition: -// if m.runningJobs == 0, GC life time has been tried to recovered. If this try fails, a warning will be printed. -// if m.runningJobs > 0, GC life time has not been recovered. -// m.runningJobs won't minus to negative since removeOneJob follows a successful addOneJob. 
-func (m *gcLifeTimeManager) removeOneJob(ctx context.Context, db *sql.DB) { - m.runningJobsLock.Lock() - defer m.runningJobsLock.Unlock() - - m.runningJobs-- - if m.runningJobs == 0 { - err := updateGCLifeTime(ctx, db, m.oriGCLifeTime) - if err != nil { - query := fmt.Sprintf( - "UPDATE mysql.tidb SET VARIABLE_VALUE = '%s' WHERE VARIABLE_NAME = 'tikv_gc_life_time'", - m.oriGCLifeTime, - ) - log.FromContext(ctx).Warn("revert GC lifetime failed, please reset the GC lifetime manually after Lightning completed", - zap.String("query", query), - log.ShortError(err), - ) - } - } -} - -func increaseGCLifeTime(ctx context.Context, manager *gcLifeTimeManager, db *sql.DB) (err error) { - // checksum command usually takes a long time to execute, - // so here need to increase the gcLifeTime for single transaction. - var increaseGCLifeTime bool - if manager.oriGCLifeTime != "" { - ori, err := time.ParseDuration(manager.oriGCLifeTime) - if err != nil { - return errors.Trace(err) - } - if ori < defaultGCLifeTime { - increaseGCLifeTime = true - } - } else { - increaseGCLifeTime = true - } - - if increaseGCLifeTime { - err = updateGCLifeTime(ctx, db, defaultGCLifeTime.String()) - if err != nil { - return err - } - } - - failpoint.Inject("IncreaseGCUpdateDuration", nil) - - return nil -} - -// obtainGCLifeTime obtains the current GC lifetime. -func obtainGCLifeTime(ctx context.Context, db *sql.DB) (string, error) { - var gcLifeTime string - err := common.SQLWithRetry{DB: db, Logger: log.FromContext(ctx)}.QueryRow( - ctx, - "obtain GC lifetime", - "SELECT VARIABLE_VALUE FROM mysql.tidb WHERE VARIABLE_NAME = 'tikv_gc_life_time'", - &gcLifeTime, - ) - return gcLifeTime, err -} - -// updateGCLifeTime updates the current GC lifetime. -func updateGCLifeTime(ctx context.Context, db *sql.DB, gcLifeTime string) error { - sql := common.SQLWithRetry{ - DB: db, - Logger: log.FromContext(ctx).With(zap.String("gcLifeTime", gcLifeTime)), - } - return sql.Exec(ctx, "update GC lifetime", - "UPDATE mysql.tidb SET VARIABLE_VALUE = ? WHERE VARIABLE_NAME = 'tikv_gc_life_time'", - gcLifeTime, - ) -} - -// TiKVChecksumManager is a manager that can compute checksum of a table using TiKV. -type TiKVChecksumManager struct { - client kv.Client - manager gcTTLManager - distSQLScanConcurrency uint - backoffWeight int - resourceGroupName string - explicitRequestSourceType string -} - -var _ ChecksumManager = &TiKVChecksumManager{} - -// NewTiKVChecksumManager return a new tikv checksum manager -func NewTiKVChecksumManager(client kv.Client, pdClient pd.Client, distSQLScanConcurrency uint, backoffWeight int, resourceGroupName, explicitRequestSourceType string) *TiKVChecksumManager { - return &TiKVChecksumManager{ - client: client, - manager: newGCTTLManager(pdClient), - distSQLScanConcurrency: distSQLScanConcurrency, - backoffWeight: backoffWeight, - resourceGroupName: resourceGroupName, - explicitRequestSourceType: explicitRequestSourceType, - } -} - -func (e *TiKVChecksumManager) checksumDB(ctx context.Context, tableInfo *checkpoints.TidbTableInfo, ts uint64) (*RemoteChecksum, error) { - executor, err := checksum.NewExecutorBuilder(tableInfo.Core, ts). - SetConcurrency(e.distSQLScanConcurrency). - SetBackoffWeight(e.backoffWeight). - SetResourceGroupName(e.resourceGroupName). - SetExplicitRequestSourceType(e.explicitRequestSourceType). 
- Build() - if err != nil { - return nil, errors.Trace(err) - } - - distSQLScanConcurrency := int(e.distSQLScanConcurrency) - for i := 0; i < maxErrorRetryCount; i++ { - _ = executor.Each(func(request *kv.Request) error { - request.Concurrency = distSQLScanConcurrency - return nil - }) - var execRes *tipb.ChecksumResponse - execRes, err = executor.Execute(ctx, e.client, func() {}) - if err == nil { - return &RemoteChecksum{ - Schema: tableInfo.DB, - Table: tableInfo.Name, - Checksum: execRes.Checksum, - TotalBytes: execRes.TotalBytes, - TotalKVs: execRes.TotalKvs, - }, nil - } - - log.FromContext(ctx).Warn("remote checksum failed", zap.String("db", tableInfo.DB), - zap.String("table", tableInfo.Name), zap.Error(err), - zap.Int("concurrency", distSQLScanConcurrency), zap.Int("retry", i)) - - // do not retry context.Canceled error - if !common.IsRetryableError(err) { - break - } - if distSQLScanConcurrency > MinDistSQLScanConcurrency { - distSQLScanConcurrency = max(distSQLScanConcurrency/2, MinDistSQLScanConcurrency) - } - } - - return nil, err -} - -var retryGetTSInterval = time.Second - -// Checksum implements the ChecksumManager interface. -func (e *TiKVChecksumManager) Checksum(ctx context.Context, tableInfo *checkpoints.TidbTableInfo) (*RemoteChecksum, error) { - tbl := common.UniqueTable(tableInfo.DB, tableInfo.Name) - var ( - physicalTS, logicalTS int64 - err error - retryTime int - ) - physicalTS, logicalTS, err = e.manager.pdClient.GetTS(ctx) - for err != nil { - if !pderrs.IsLeaderChange(errors.Cause(err)) { - return nil, errors.Annotate(err, "fetch tso from pd failed") - } - retryTime++ - if retryTime%60 == 0 { - log.FromContext(ctx).Warn("fetch tso from pd failed and retrying", - zap.Int("retryTime", retryTime), - zap.Error(err)) - } - select { - case <-ctx.Done(): - err = ctx.Err() - case <-time.After(retryGetTSInterval): - physicalTS, logicalTS, err = e.manager.pdClient.GetTS(ctx) - } - } - ts := oracle.ComposeTS(physicalTS, logicalTS) - if err := e.manager.addOneJob(ctx, tbl, ts); err != nil { - return nil, errors.Trace(err) - } - defer e.manager.removeOneJob(tbl) - - return e.checksumDB(ctx, tableInfo, ts) -} - -type tableChecksumTS struct { - table string - gcSafeTS uint64 -} - -// following function are for implement `heap.Interface` - -func (m *gcTTLManager) Len() int { - return len(m.tableGCSafeTS) -} - -func (m *gcTTLManager) Less(i, j int) bool { - return m.tableGCSafeTS[i].gcSafeTS < m.tableGCSafeTS[j].gcSafeTS -} - -func (m *gcTTLManager) Swap(i, j int) { - m.tableGCSafeTS[i], m.tableGCSafeTS[j] = m.tableGCSafeTS[j], m.tableGCSafeTS[i] -} - -func (m *gcTTLManager) Push(x any) { - m.tableGCSafeTS = append(m.tableGCSafeTS, x.(*tableChecksumTS)) -} - -func (m *gcTTLManager) Pop() any { - i := m.tableGCSafeTS[len(m.tableGCSafeTS)-1] - m.tableGCSafeTS = m.tableGCSafeTS[:len(m.tableGCSafeTS)-1] - return i -} - -type gcTTLManager struct { - lock sync.Mutex - pdClient pd.Client - // tableGCSafeTS is a binary heap that stored active checksum jobs GC safe point ts - tableGCSafeTS []*tableChecksumTS - currentTS uint64 - serviceID string - // 0 for not start, otherwise started - started atomic.Bool -} - -func newGCTTLManager(pdClient pd.Client) gcTTLManager { - return gcTTLManager{ - pdClient: pdClient, - serviceID: fmt.Sprintf("lightning-%s", uuid.New()), - } -} - -func (m *gcTTLManager) addOneJob(ctx context.Context, table string, ts uint64) error { - // start gc ttl loop if not started yet. 
- if m.started.CompareAndSwap(false, true) { - m.start(ctx) - } - m.lock.Lock() - defer m.lock.Unlock() - var curTS uint64 - if len(m.tableGCSafeTS) > 0 { - curTS = m.tableGCSafeTS[0].gcSafeTS - } - m.Push(&tableChecksumTS{table: table, gcSafeTS: ts}) - heap.Fix(m, len(m.tableGCSafeTS)-1) - m.currentTS = m.tableGCSafeTS[0].gcSafeTS - if curTS == 0 || m.currentTS < curTS { - return m.doUpdateGCTTL(ctx, m.currentTS) - } - return nil -} - -func (m *gcTTLManager) removeOneJob(table string) { - m.lock.Lock() - defer m.lock.Unlock() - idx := -1 - for i := 0; i < len(m.tableGCSafeTS); i++ { - if m.tableGCSafeTS[i].table == table { - idx = i - break - } - } - - if idx >= 0 { - l := len(m.tableGCSafeTS) - m.tableGCSafeTS[idx] = m.tableGCSafeTS[l-1] - m.tableGCSafeTS = m.tableGCSafeTS[:l-1] - if l > 1 && idx < l-1 { - heap.Fix(m, idx) - } - } - - var newTS uint64 - if len(m.tableGCSafeTS) > 0 { - newTS = m.tableGCSafeTS[0].gcSafeTS - } - m.currentTS = newTS -} - -func (m *gcTTLManager) updateGCTTL(ctx context.Context) error { - m.lock.Lock() - currentTS := m.currentTS - m.lock.Unlock() - return m.doUpdateGCTTL(ctx, currentTS) -} - -func (m *gcTTLManager) doUpdateGCTTL(ctx context.Context, ts uint64) error { - log.FromContext(ctx).Debug("update PD safePoint limit with TTL", - zap.Uint64("currnet_ts", ts)) - var err error - if ts > 0 { - _, err = m.pdClient.UpdateServiceGCSafePoint(ctx, - m.serviceID, serviceSafePointTTL, ts) - } - return err -} - -func (m *gcTTLManager) start(ctx context.Context) { - // It would be OK since TTL won't be zero, so gapTime should > `0. - updateGapTime := time.Duration(serviceSafePointTTL) * time.Second / preUpdateServiceSafePointFactor - - updateTick := time.NewTicker(updateGapTime) - - updateGCTTL := func() { - if err := m.updateGCTTL(ctx); err != nil { - log.FromContext(ctx).Warn("failed to update service safe point, checksum may fail if gc triggered", zap.Error(err)) - } - } - - // trigger a service gc ttl at start - updateGCTTL() - go func() { - defer updateTick.Stop() - for { - select { - case <-ctx.Done(): - log.FromContext(ctx).Info("service safe point keeper exited") - return - case <-updateTick.C: - updateGCTTL() - } - } - }() -} diff --git a/pkg/lightning/backend/local/engine.go b/pkg/lightning/backend/local/engine.go index a6bb4c8c7c887..9374b2ade74ff 100644 --- a/pkg/lightning/backend/local/engine.go +++ b/pkg/lightning/backend/local/engine.go @@ -1026,9 +1026,9 @@ func (e *Engine) GetFirstAndLastKey(lowerBound, upperBound []byte) ([]byte, []by LowerBound: lowerBound, UpperBound: upperBound, } - if _, _err_ := failpoint.Eval(_curpkg_("mockGetFirstAndLastKey")); _err_ == nil { - return lowerBound, upperBound, nil - } + failpoint.Inject("mockGetFirstAndLastKey", func() { + failpoint.Return(lowerBound, upperBound, nil) + }) iter := e.newKVIter(context.Background(), opt, nil) //nolint: errcheck @@ -1332,13 +1332,13 @@ func (w *Writer) flushKVs(ctx context.Context) error { return errors.Trace(err) } - if _, _err_ := failpoint.Eval(_curpkg_("orphanWriterGoRoutine")); _err_ == nil { + failpoint.Inject("orphanWriterGoRoutine", func() { _ = common.KillMySelf() // mimic we meet context cancel error when `addSST` <-ctx.Done() time.Sleep(5 * time.Second) - return errors.Trace(ctx.Err()) - } + failpoint.Return(errors.Trace(ctx.Err())) + }) err = w.addSST(ctx, meta) if err != nil { diff --git a/pkg/lightning/backend/local/engine.go__failpoint_stash__ b/pkg/lightning/backend/local/engine.go__failpoint_stash__ deleted file mode 100644 index 9374b2ade74ff..0000000000000 --- 
a/pkg/lightning/backend/local/engine.go__failpoint_stash__ +++ /dev/null @@ -1,1682 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package local - -import ( - "bytes" - "container/heap" - "context" - "encoding/binary" - "encoding/json" - "fmt" - "io" - "os" - "path/filepath" - "slices" - "sync" - "time" - "unsafe" - - "github.com/cockroachdb/pebble" - "github.com/cockroachdb/pebble/objstorage/objstorageprovider" - "github.com/cockroachdb/pebble/sstable" - "github.com/cockroachdb/pebble/vfs" - "github.com/google/btree" - "github.com/google/uuid" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/logutil" - "github.com/pingcap/tidb/br/pkg/membuf" - "github.com/pingcap/tidb/pkg/lightning/backend" - "github.com/pingcap/tidb/pkg/lightning/backend/encode" - "github.com/pingcap/tidb/pkg/lightning/backend/kv" - "github.com/pingcap/tidb/pkg/lightning/checkpoints" - "github.com/pingcap/tidb/pkg/lightning/common" - "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/util/hack" - "github.com/tikv/client-go/v2/tikv" - "go.uber.org/atomic" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" -) - -var ( - engineMetaKey = []byte{0, 'm', 'e', 't', 'a'} - normalIterStartKey = []byte{1} -) - -type importMutexState uint32 - -const ( - importMutexStateImport importMutexState = 1 << iota - importMutexStateClose - // importMutexStateReadLock is a special state because in this state we lock engine with read lock - // and add isImportingAtomic with this value. In other state, we directly store with the state value. - // so this must always the last value of this enum. - importMutexStateReadLock - // we need to lock the engine when it's open as we do when it's close, otherwise GetEngienSize may race with OpenEngine - importMutexStateOpen -) - -const ( - // DupDetectDirSuffix is used by pre-deduplication to store the encoded index KV. - DupDetectDirSuffix = ".dupdetect" - // DupResultDirSuffix is used by pre-deduplication to store the duplicated row ID. - DupResultDirSuffix = ".dupresult" -) - -// engineMeta contains some field that is necessary to continue the engine restore/import process. -// These field should be written to disk when we update chunk checkpoint -type engineMeta struct { - TS uint64 `json:"ts"` - // Length is the number of KV pairs stored by the engine. - Length atomic.Int64 `json:"length"` - // TotalSize is the total pre-compressed KV byte size stored by engine. - TotalSize atomic.Int64 `json:"total_size"` -} - -type syncedRanges struct { - sync.Mutex - ranges []common.Range -} - -func (r *syncedRanges) add(g common.Range) { - r.Lock() - r.ranges = append(r.ranges, g) - r.Unlock() -} - -func (r *syncedRanges) reset() { - r.Lock() - r.ranges = r.ranges[:0] - r.Unlock() -} - -// Engine is a local engine. 
-type Engine struct { - engineMeta - closed atomic.Bool - db atomic.Pointer[pebble.DB] - UUID uuid.UUID - localWriters sync.Map - - // isImportingAtomic is an atomic variable indicating whether this engine is importing. - // This should not be used as a "spin lock" indicator. - isImportingAtomic atomic.Uint32 - // flush and ingest sst hold the rlock, other operation hold the wlock. - mutex sync.RWMutex - - ctx context.Context - cancel context.CancelFunc - sstDir string - sstMetasChan chan metaOrFlush - ingestErr common.OnceError - wg sync.WaitGroup - sstIngester sstIngester - - // sst seq lock - seqLock sync.Mutex - // seq number for incoming sst meta - nextSeq int32 - // max seq of sst metas ingested into pebble - finishedMetaSeq atomic.Int32 - - config backend.LocalEngineConfig - tableInfo *checkpoints.TidbTableInfo - - dupDetectOpt common.DupDetectOpt - - // total size of SST files waiting to be ingested - pendingFileSize atomic.Int64 - - // statistics for pebble kv iter. - importedKVSize atomic.Int64 - importedKVCount atomic.Int64 - - keyAdapter common.KeyAdapter - duplicateDetection bool - duplicateDB *pebble.DB - - logger log.Logger -} - -func (e *Engine) setError(err error) { - if err != nil { - e.ingestErr.Set(err) - e.cancel() - } -} - -func (e *Engine) getDB() *pebble.DB { - return e.db.Load() -} - -// Close closes the engine and release all resources. -func (e *Engine) Close() error { - e.logger.Debug("closing local engine", zap.Stringer("engine", e.UUID), zap.Stack("stack")) - db := e.getDB() - if db == nil { - return nil - } - err := errors.Trace(db.Close()) - e.db.Store(nil) - return err -} - -// Cleanup remove meta, db and duplicate detection files -func (e *Engine) Cleanup(dataDir string) error { - if err := os.RemoveAll(e.sstDir); err != nil { - return errors.Trace(err) - } - uuid := e.UUID.String() - if err := os.RemoveAll(filepath.Join(dataDir, uuid+DupDetectDirSuffix)); err != nil { - return errors.Trace(err) - } - if err := os.RemoveAll(filepath.Join(dataDir, uuid+DupResultDirSuffix)); err != nil { - return errors.Trace(err) - } - - dbPath := filepath.Join(dataDir, uuid) - return errors.Trace(os.RemoveAll(dbPath)) -} - -// Exist checks if db folder existing (meta sometimes won't flush before lightning exit) -func (e *Engine) Exist(dataDir string) error { - dbPath := filepath.Join(dataDir, e.UUID.String()) - if _, err := os.Stat(dbPath); err != nil { - return err - } - return nil -} - -func isStateLocked(state importMutexState) bool { - return state&(importMutexStateClose|importMutexStateImport) != 0 -} - -func (e *Engine) isLocked() bool { - // the engine is locked only in import or close state. - return isStateLocked(importMutexState(e.isImportingAtomic.Load())) -} - -// rLock locks the local file with shard read state. Only used for flush and ingest SST files. -func (e *Engine) rLock() { - e.mutex.RLock() - e.isImportingAtomic.Add(uint32(importMutexStateReadLock)) -} - -func (e *Engine) rUnlock() { - if e == nil { - return - } - - e.isImportingAtomic.Sub(uint32(importMutexStateReadLock)) - e.mutex.RUnlock() -} - -// lock locks the local file for importing. -func (e *Engine) lock(state importMutexState) { - e.mutex.Lock() - e.isImportingAtomic.Store(uint32(state)) -} - -// lockUnless tries to lock the local file unless it is already locked into the state given by -// ignoreStateMask. Returns whether the lock is successful. 
-func (e *Engine) lockUnless(newState, ignoreStateMask importMutexState) bool { - curState := e.isImportingAtomic.Load() - if curState&uint32(ignoreStateMask) != 0 { - return false - } - e.lock(newState) - return true -} - -// tryRLock tries to read-lock the local file unless it is already write locked. -// Returns whether the lock is successful. -func (e *Engine) tryRLock() bool { - curState := e.isImportingAtomic.Load() - // engine is in import/close state. - if isStateLocked(importMutexState(curState)) { - return false - } - e.rLock() - return true -} - -func (e *Engine) unlock() { - if e == nil { - return - } - e.isImportingAtomic.Store(0) - e.mutex.Unlock() -} - -var sizeOfKVPair = int64(unsafe.Sizeof(common.KvPair{})) - -// TotalMemorySize returns the total memory size of the engine. -func (e *Engine) TotalMemorySize() int64 { - var memSize int64 - e.localWriters.Range(func(k, _ any) bool { - w := k.(*Writer) - if w.kvBuffer != nil { - w.Lock() - memSize += w.kvBuffer.TotalSize() - w.Unlock() - } - w.Lock() - memSize += sizeOfKVPair * int64(cap(w.writeBatch)) - w.Unlock() - return true - }) - return memSize -} - -// KVStatistics returns the total kv size and total kv count. -func (e *Engine) KVStatistics() (totalSize int64, totalKVCount int64) { - return e.TotalSize.Load(), e.Length.Load() -} - -// ImportedStatistics returns the imported kv size and imported kv count. -func (e *Engine) ImportedStatistics() (importedSize int64, importedKVCount int64) { - return e.importedKVSize.Load(), e.importedKVCount.Load() -} - -// ID is the identifier of an engine. -func (e *Engine) ID() string { - return e.UUID.String() -} - -// GetKeyRange implements common.Engine. -func (e *Engine) GetKeyRange() (startKey []byte, endKey []byte, err error) { - firstLey, lastKey, err := e.GetFirstAndLastKey(nil, nil) - if err != nil { - return nil, nil, errors.Trace(err) - } - return firstLey, nextKey(lastKey), nil -} - -// SplitRanges gets size properties from pebble and split ranges according to size/keys limit. -func (e *Engine) SplitRanges( - startKey, endKey []byte, - sizeLimit, keysLimit int64, - logger log.Logger, -) ([]common.Range, error) { - sizeProps, err := getSizePropertiesFn(logger, e.getDB(), e.keyAdapter) - if err != nil { - return nil, errors.Trace(err) - } - - ranges := splitRangeBySizeProps( - common.Range{Start: startKey, End: endKey}, - sizeProps, - sizeLimit, - keysLimit, - ) - return ranges, nil -} - -type rangeOffsets struct { - Size uint64 - Keys uint64 -} - -type rangeProperty struct { - Key []byte - rangeOffsets -} - -// Less implements btree.Item interface. -func (r *rangeProperty) Less(than btree.Item) bool { - ta := than.(*rangeProperty) - return bytes.Compare(r.Key, ta.Key) < 0 -} - -var _ btree.Item = &rangeProperty{} - -type rangeProperties []rangeProperty - -// Encode encodes the range properties into a byte slice. -func (r rangeProperties) Encode() []byte { - b := make([]byte, 0, 1024) - idx := 0 - for _, p := range r { - b = append(b, 0, 0, 0, 0) - binary.BigEndian.PutUint32(b[idx:], uint32(len(p.Key))) - idx += 4 - b = append(b, p.Key...) - idx += len(p.Key) - - b = append(b, 0, 0, 0, 0, 0, 0, 0, 0) - binary.BigEndian.PutUint64(b[idx:], p.Size) - idx += 8 - - b = append(b, 0, 0, 0, 0, 0, 0, 0, 0) - binary.BigEndian.PutUint64(b[idx:], p.Keys) - idx += 8 - } - return b -} - -// RangePropertiesCollector collects range properties for each range. 
-type RangePropertiesCollector struct { - props rangeProperties - lastOffsets rangeOffsets - lastKey []byte - currentOffsets rangeOffsets - propSizeIdxDistance uint64 - propKeysIdxDistance uint64 -} - -func newRangePropertiesCollector() pebble.TablePropertyCollector { - return &RangePropertiesCollector{ - props: make([]rangeProperty, 0, 1024), - propSizeIdxDistance: defaultPropSizeIndexDistance, - propKeysIdxDistance: defaultPropKeysIndexDistance, - } -} - -func (c *RangePropertiesCollector) sizeInLastRange() uint64 { - return c.currentOffsets.Size - c.lastOffsets.Size -} - -func (c *RangePropertiesCollector) keysInLastRange() uint64 { - return c.currentOffsets.Keys - c.lastOffsets.Keys -} - -func (c *RangePropertiesCollector) insertNewPoint(key []byte) { - c.lastOffsets = c.currentOffsets - c.props = append(c.props, rangeProperty{Key: append([]byte{}, key...), rangeOffsets: c.currentOffsets}) -} - -// Add implements `pebble.TablePropertyCollector`. -// Add implements `TablePropertyCollector.Add`. -func (c *RangePropertiesCollector) Add(key pebble.InternalKey, value []byte) error { - if key.Kind() != pebble.InternalKeyKindSet || bytes.Equal(key.UserKey, engineMetaKey) { - return nil - } - c.currentOffsets.Size += uint64(len(value)) + uint64(len(key.UserKey)) - c.currentOffsets.Keys++ - if len(c.lastKey) == 0 || c.sizeInLastRange() >= c.propSizeIdxDistance || - c.keysInLastRange() >= c.propKeysIdxDistance { - c.insertNewPoint(key.UserKey) - } - c.lastKey = append(c.lastKey[:0], key.UserKey...) - return nil -} - -// Finish implements `pebble.TablePropertyCollector`. -func (c *RangePropertiesCollector) Finish(userProps map[string]string) error { - if c.sizeInLastRange() > 0 || c.keysInLastRange() > 0 { - c.insertNewPoint(c.lastKey) - } - - userProps[propRangeIndex] = string(c.props.Encode()) - return nil -} - -// Name implements `pebble.TablePropertyCollector`. 
-func (*RangePropertiesCollector) Name() string { - return propRangeIndex -} - -type sizeProperties struct { - totalSize uint64 - indexHandles *btree.BTree -} - -func newSizeProperties() *sizeProperties { - return &sizeProperties{indexHandles: btree.New(32)} -} - -func (s *sizeProperties) add(item *rangeProperty) { - if old := s.indexHandles.ReplaceOrInsert(item); old != nil { - o := old.(*rangeProperty) - item.Keys += o.Keys - item.Size += o.Size - } -} - -func (s *sizeProperties) addAll(props rangeProperties) { - prevRange := rangeOffsets{} - for _, r := range props { - s.add(&rangeProperty{ - Key: r.Key, - rangeOffsets: rangeOffsets{Keys: r.Keys - prevRange.Keys, Size: r.Size - prevRange.Size}, - }) - prevRange = r.rangeOffsets - } - if len(props) > 0 { - s.totalSize += props[len(props)-1].Size - } -} - -// iter the tree until f return false -func (s *sizeProperties) iter(f func(p *rangeProperty) bool) { - s.indexHandles.Ascend(func(i btree.Item) bool { - prop := i.(*rangeProperty) - return f(prop) - }) -} - -func decodeRangeProperties(data []byte, keyAdapter common.KeyAdapter) (rangeProperties, error) { - r := make(rangeProperties, 0, 16) - for len(data) > 0 { - if len(data) < 4 { - return nil, io.ErrUnexpectedEOF - } - keyLen := int(binary.BigEndian.Uint32(data[:4])) - data = data[4:] - if len(data) < keyLen+8*2 { - return nil, io.ErrUnexpectedEOF - } - key := data[:keyLen] - data = data[keyLen:] - size := binary.BigEndian.Uint64(data[:8]) - keys := binary.BigEndian.Uint64(data[8:]) - data = data[16:] - if !bytes.Equal(key, engineMetaKey) { - userKey, err := keyAdapter.Decode(nil, key) - if err != nil { - return nil, errors.Annotate(err, "failed to decode key with keyAdapter") - } - r = append(r, rangeProperty{Key: userKey, rangeOffsets: rangeOffsets{Size: size, Keys: keys}}) - } - } - - return r, nil -} - -// getSizePropertiesFn is used to let unit test replace the real function. 
-var getSizePropertiesFn = getSizeProperties - -func getSizeProperties(logger log.Logger, db *pebble.DB, keyAdapter common.KeyAdapter) (*sizeProperties, error) { - sstables, err := db.SSTables(pebble.WithProperties()) - if err != nil { - logger.Warn("get sst table properties failed", log.ShortError(err)) - return nil, errors.Trace(err) - } - - sizeProps := newSizeProperties() - for _, level := range sstables { - for _, info := range level { - if prop, ok := info.Properties.UserProperties[propRangeIndex]; ok { - data := hack.Slice(prop) - rangeProps, err := decodeRangeProperties(data, keyAdapter) - if err != nil { - logger.Warn("decodeRangeProperties failed", - zap.Stringer("fileNum", info.FileNum), log.ShortError(err)) - return nil, errors.Trace(err) - } - sizeProps.addAll(rangeProps) - } - } - } - - return sizeProps, nil -} - -func (e *Engine) getEngineFileSize() backend.EngineFileSize { - db := e.getDB() - - var total pebble.LevelMetrics - if db != nil { - metrics := db.Metrics() - total = metrics.Total() - } - var memSize int64 - e.localWriters.Range(func(k, _ any) bool { - w := k.(*Writer) - memSize += int64(w.EstimatedSize()) - return true - }) - - pendingSize := e.pendingFileSize.Load() - // TODO: should also add the in-processing compaction sst writer size into MemSize - return backend.EngineFileSize{ - UUID: e.UUID, - DiskSize: total.Size + pendingSize, - MemSize: memSize, - IsImporting: e.isLocked(), - } -} - -// either a sstMeta or a flush message -type metaOrFlush struct { - meta *sstMeta - flushCh chan struct{} -} - -type metaSeq struct { - // the sequence for this flush message, a flush call can return only if - // all the other flush will lower `flushSeq` are done - flushSeq int32 - // the max sstMeta sequence number in this flush, after the flush is done (all SSTs are ingested), - // we can save chunks will a lower meta sequence number safely. - metaSeq int32 -} - -type metaSeqHeap struct { - arr []metaSeq -} - -// Len returns the number of items in the priority queue. -func (h *metaSeqHeap) Len() int { - return len(h.arr) -} - -// Less reports whether the item in the priority queue with -func (h *metaSeqHeap) Less(i, j int) bool { - return h.arr[i].flushSeq < h.arr[j].flushSeq -} - -// Swap swaps the items at the passed indices. -func (h *metaSeqHeap) Swap(i, j int) { - h.arr[i], h.arr[j] = h.arr[j], h.arr[i] -} - -// Push pushes the item onto the priority queue. -func (h *metaSeqHeap) Push(x any) { - h.arr = append(h.arr, x.(metaSeq)) -} - -// Pop removes the minimum item (according to Less) from the priority queue -func (h *metaSeqHeap) Pop() any { - item := h.arr[len(h.arr)-1] - h.arr = h.arr[:len(h.arr)-1] - return item -} - -func (e *Engine) ingestSSTLoop() { - defer e.wg.Done() - - type flushSeq struct { - seq int32 - ch chan struct{} - } - - seq := atomic.NewInt32(0) - finishedSeq := atomic.NewInt32(0) - var seqLock sync.Mutex - // a flush is finished iff all the compaction&ingest tasks with a lower seq number are finished. - flushQueue := make([]flushSeq, 0) - // inSyncSeqs is a heap that stores all the finished compaction tasks whose seq is bigger than `finishedSeq + 1` - // this mean there are still at lease one compaction task with a lower seq unfinished. 
- inSyncSeqs := &metaSeqHeap{arr: make([]metaSeq, 0)} - - type metaAndSeq struct { - metas []*sstMeta - seq int32 - } - - concurrency := e.config.CompactConcurrency - // when compaction is disabled, ingest is an serial action, so 1 routine is enough - if !e.config.Compact { - concurrency = 1 - } - metaChan := make(chan metaAndSeq, concurrency) - for i := 0; i < concurrency; i++ { - e.wg.Add(1) - go func() { - defer func() { - if e.ingestErr.Get() != nil { - seqLock.Lock() - for _, f := range flushQueue { - f.ch <- struct{}{} - } - flushQueue = flushQueue[:0] - seqLock.Unlock() - } - e.wg.Done() - }() - for { - select { - case <-e.ctx.Done(): - return - case metas, ok := <-metaChan: - if !ok { - return - } - ingestMetas := metas.metas - if e.config.Compact { - newMeta, err := e.sstIngester.mergeSSTs(metas.metas, e.sstDir, e.config.BlockSize) - if err != nil { - e.setError(err) - return - } - ingestMetas = []*sstMeta{newMeta} - } - // batchIngestSSTs will change ingestMetas' order, so we record the max seq here - metasMaxSeq := ingestMetas[len(ingestMetas)-1].seq - - if err := e.batchIngestSSTs(ingestMetas); err != nil { - e.setError(err) - return - } - seqLock.Lock() - finSeq := finishedSeq.Load() - if metas.seq == finSeq+1 { - finSeq = metas.seq - finMetaSeq := metasMaxSeq - for len(inSyncSeqs.arr) > 0 { - if inSyncSeqs.arr[0].flushSeq != finSeq+1 { - break - } - finSeq++ - finMetaSeq = inSyncSeqs.arr[0].metaSeq - heap.Remove(inSyncSeqs, 0) - } - - var flushChans []chan struct{} - for _, seq := range flushQueue { - if seq.seq > finSeq { - break - } - flushChans = append(flushChans, seq.ch) - } - flushQueue = flushQueue[len(flushChans):] - finishedSeq.Store(finSeq) - e.finishedMetaSeq.Store(finMetaSeq) - seqLock.Unlock() - for _, c := range flushChans { - c <- struct{}{} - } - } else { - heap.Push(inSyncSeqs, metaSeq{flushSeq: metas.seq, metaSeq: metasMaxSeq}) - seqLock.Unlock() - } - } - } - }() - } - - compactAndIngestSSTs := func(metas []*sstMeta) { - if len(metas) > 0 { - seqLock.Lock() - metaSeq := seq.Add(1) - seqLock.Unlock() - select { - case <-e.ctx.Done(): - case metaChan <- metaAndSeq{metas: metas, seq: metaSeq}: - } - } - } - - pendingMetas := make([]*sstMeta, 0, 16) - totalSize := int64(0) - metasTmp := make([]*sstMeta, 0) - addMetas := func() { - if len(metasTmp) == 0 { - return - } - metas := metasTmp - metasTmp = make([]*sstMeta, 0, len(metas)) - if !e.config.Compact { - compactAndIngestSSTs(metas) - return - } - for _, m := range metas { - if m.totalCount > 0 { - pendingMetas = append(pendingMetas, m) - totalSize += m.totalSize - if totalSize >= e.config.CompactThreshold { - compactMetas := pendingMetas - pendingMetas = make([]*sstMeta, 0, len(pendingMetas)) - totalSize = 0 - compactAndIngestSSTs(compactMetas) - } - } - } - } -readMetaLoop: - for { - closed := false - select { - case <-e.ctx.Done(): - close(metaChan) - return - case m, ok := <-e.sstMetasChan: - if !ok { - closed = true - break - } - if m.flushCh != nil { - // meet a flush event, we should trigger a ingest task if there are pending metas, - // and then waiting for all the running flush tasks to be done. 
- if len(metasTmp) > 0 { - addMetas() - } - if len(pendingMetas) > 0 { - seqLock.Lock() - metaSeq := seq.Add(1) - flushQueue = append(flushQueue, flushSeq{ch: m.flushCh, seq: metaSeq}) - seqLock.Unlock() - select { - case metaChan <- metaAndSeq{metas: pendingMetas, seq: metaSeq}: - case <-e.ctx.Done(): - close(metaChan) - return - } - - pendingMetas = make([]*sstMeta, 0, len(pendingMetas)) - totalSize = 0 - } else { - // none remaining metas needed to be ingested - seqLock.Lock() - curSeq := seq.Load() - finSeq := finishedSeq.Load() - // if all pending SST files are written, directly do a db.Flush - if curSeq == finSeq { - seqLock.Unlock() - m.flushCh <- struct{}{} - } else { - // waiting for pending compaction tasks - flushQueue = append(flushQueue, flushSeq{ch: m.flushCh, seq: curSeq}) - seqLock.Unlock() - } - } - continue readMetaLoop - } - metasTmp = append(metasTmp, m.meta) - // try to drain all the sst meta from the chan to make sure all the SSTs are processed before handle a flush msg. - if len(e.sstMetasChan) > 0 { - continue readMetaLoop - } - - addMetas() - } - if closed { - compactAndIngestSSTs(pendingMetas) - close(metaChan) - return - } - } -} - -func (e *Engine) addSST(ctx context.Context, m *sstMeta) (int32, error) { - // set pending size after SST file is generated - e.pendingFileSize.Add(m.fileSize) - // make sure sstMeta is sent into the chan in order - e.seqLock.Lock() - defer e.seqLock.Unlock() - e.nextSeq++ - seq := e.nextSeq - m.seq = seq - select { - case e.sstMetasChan <- metaOrFlush{meta: m}: - case <-ctx.Done(): - return 0, ctx.Err() - case <-e.ctx.Done(): - } - return seq, e.ingestErr.Get() -} - -func (e *Engine) batchIngestSSTs(metas []*sstMeta) error { - if len(metas) == 0 { - return nil - } - slices.SortFunc(metas, func(i, j *sstMeta) int { - return bytes.Compare(i.minKey, j.minKey) - }) - - // non overlapping sst is grouped, and ingested in that order - metaLevels := make([][]*sstMeta, 0) - for _, meta := range metas { - inserted := false - for i, l := range metaLevels { - if bytes.Compare(l[len(l)-1].maxKey, meta.minKey) >= 0 { - continue - } - metaLevels[i] = append(l, meta) - inserted = true - break - } - if !inserted { - metaLevels = append(metaLevels, []*sstMeta{meta}) - } - } - - for _, l := range metaLevels { - if err := e.ingestSSTs(l); err != nil { - return err - } - } - return nil -} - -func (e *Engine) ingestSSTs(metas []*sstMeta) error { - // use raw RLock to avoid change the lock state during flushing. 
- e.mutex.RLock() - defer e.mutex.RUnlock() - if e.closed.Load() { - return errorEngineClosed - } - totalSize := int64(0) - totalCount := int64(0) - fileSize := int64(0) - for _, m := range metas { - totalSize += m.totalSize - totalCount += m.totalCount - fileSize += m.fileSize - } - e.logger.Info("write data to local DB", - zap.Int64("size", totalSize), - zap.Int64("kvs", totalCount), - zap.Int("files", len(metas)), - zap.Int64("sstFileSize", fileSize), - zap.String("file", metas[0].path), - logutil.Key("firstKey", metas[0].minKey), - logutil.Key("lastKey", metas[len(metas)-1].maxKey)) - if err := e.sstIngester.ingest(metas); err != nil { - return errors.Trace(err) - } - count := int64(0) - size := int64(0) - for _, m := range metas { - count += m.totalCount - size += m.totalSize - } - e.Length.Add(count) - e.TotalSize.Add(size) - return nil -} - -func (e *Engine) flushLocalWriters(parentCtx context.Context) error { - eg, ctx := errgroup.WithContext(parentCtx) - e.localWriters.Range(func(k, _ any) bool { - eg.Go(func() error { - w := k.(*Writer) - return w.flush(ctx) - }) - return true - }) - return eg.Wait() -} - -func (e *Engine) flushEngineWithoutLock(ctx context.Context) error { - if err := e.flushLocalWriters(ctx); err != nil { - return err - } - flushChan := make(chan struct{}, 1) - select { - case e.sstMetasChan <- metaOrFlush{flushCh: flushChan}: - case <-ctx.Done(): - return ctx.Err() - case <-e.ctx.Done(): - return e.ctx.Err() - } - - select { - case <-flushChan: - case <-ctx.Done(): - return ctx.Err() - case <-e.ctx.Done(): - return e.ctx.Err() - } - if err := e.ingestErr.Get(); err != nil { - return errors.Trace(err) - } - if err := e.saveEngineMeta(); err != nil { - return err - } - - flushFinishedCh, err := e.getDB().AsyncFlush() - if err != nil { - return errors.Trace(err) - } - select { - case <-flushFinishedCh: - return nil - case <-ctx.Done(): - return ctx.Err() - case <-e.ctx.Done(): - return e.ctx.Err() - } -} - -func saveEngineMetaToDB(meta *engineMeta, db *pebble.DB) error { - jsonBytes, err := json.Marshal(meta) - if err != nil { - return errors.Trace(err) - } - // note: we can't set Sync to true since we disabled WAL. - return db.Set(engineMetaKey, jsonBytes, &pebble.WriteOptions{Sync: false}) -} - -// saveEngineMeta saves the metadata about the DB into the DB itself. 
-// This method should be followed by a Flush to ensure the data is actually synchronized -func (e *Engine) saveEngineMeta() error { - e.logger.Debug("save engine meta", zap.Stringer("uuid", e.UUID), zap.Int64("count", e.Length.Load()), - zap.Int64("size", e.TotalSize.Load())) - return errors.Trace(saveEngineMetaToDB(&e.engineMeta, e.getDB())) -} - -func (e *Engine) loadEngineMeta() error { - jsonBytes, closer, err := e.getDB().Get(engineMetaKey) - if err != nil { - if err == pebble.ErrNotFound { - e.logger.Debug("local db missing engine meta", zap.Stringer("uuid", e.UUID), log.ShortError(err)) - return nil - } - return err - } - //nolint: errcheck - defer closer.Close() - - if err = json.Unmarshal(jsonBytes, &e.engineMeta); err != nil { - e.logger.Warn("local db failed to deserialize meta", zap.Stringer("uuid", e.UUID), zap.ByteString("content", jsonBytes), zap.Error(err)) - return err - } - e.logger.Debug("load engine meta", zap.Stringer("uuid", e.UUID), zap.Int64("count", e.Length.Load()), - zap.Int64("size", e.TotalSize.Load())) - return nil -} - -func (e *Engine) newKVIter(ctx context.Context, opts *pebble.IterOptions, buf *membuf.Buffer) IngestLocalEngineIter { - if bytes.Compare(opts.LowerBound, normalIterStartKey) < 0 { - newOpts := *opts - newOpts.LowerBound = normalIterStartKey - opts = &newOpts - } - if !e.duplicateDetection { - iter, err := e.getDB().NewIter(opts) - if err != nil { - e.logger.Panic("fail to create iterator") - return nil - } - return &pebbleIter{Iterator: iter, buf: buf} - } - logger := log.FromContext(ctx).With( - zap.String("table", common.UniqueTable(e.tableInfo.DB, e.tableInfo.Name)), - zap.Int64("tableID", e.tableInfo.ID), - zap.Stringer("engineUUID", e.UUID)) - return newDupDetectIter( - e.getDB(), - e.keyAdapter, - opts, - e.duplicateDB, - logger, - e.dupDetectOpt, - buf, - ) -} - -var _ common.IngestData = (*Engine)(nil) - -// GetFirstAndLastKey reads the first and last key in range [lowerBound, upperBound) -// in the engine. Empty upperBound means unbounded. -func (e *Engine) GetFirstAndLastKey(lowerBound, upperBound []byte) ([]byte, []byte, error) { - if len(upperBound) == 0 { - // we use empty slice for unbounded upper bound, but it means max value in pebble - // so reset to nil - upperBound = nil - } - opt := &pebble.IterOptions{ - LowerBound: lowerBound, - UpperBound: upperBound, - } - failpoint.Inject("mockGetFirstAndLastKey", func() { - failpoint.Return(lowerBound, upperBound, nil) - }) - - iter := e.newKVIter(context.Background(), opt, nil) - //nolint: errcheck - defer iter.Close() - // Needs seek to first because NewIter returns an iterator that is unpositioned - hasKey := iter.First() - if iter.Error() != nil { - return nil, nil, errors.Annotate(iter.Error(), "failed to read the first key") - } - if !hasKey { - return nil, nil, nil - } - firstKey := append([]byte{}, iter.Key()...) - iter.Last() - if iter.Error() != nil { - return nil, nil, errors.Annotate(iter.Error(), "failed to seek to the last key") - } - lastKey := append([]byte{}, iter.Key()...) - return firstKey, lastKey, nil -} - -// NewIter implements IngestData interface. -func (e *Engine) NewIter( - ctx context.Context, - lowerBound, upperBound []byte, - bufPool *membuf.Pool, -) common.ForwardIter { - return e.newKVIter( - ctx, - &pebble.IterOptions{LowerBound: lowerBound, UpperBound: upperBound}, - bufPool.NewBuffer(), - ) -} - -// GetTS implements IngestData interface. -func (e *Engine) GetTS() uint64 { - return e.TS -} - -// IncRef implements IngestData interface. 
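GetFirstAndLastKey above boils down to a bounded pebble iterator positioned with First and then Last. A rough standalone sketch of that pattern, assuming the two-return-value NewIter signature used elsewhere in this file and an in-memory VFS so the example has no side effects (keys and bounds are made up):

package main

import (
	"fmt"
	"log"

	"github.com/cockroachdb/pebble"
	"github.com/cockroachdb/pebble/vfs"
)

func main() {
	// Throwaway in-memory DB for the sketch.
	db, err := pebble.Open("", &pebble.Options{FS: vfs.NewMem()})
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	for _, k := range []string{"a", "c", "m", "z"} {
		if err := db.Set([]byte(k), []byte("v"), pebble.Sync); err != nil {
			log.Fatal(err)
		}
	}

	// Same shape as GetFirstAndLastKey: bounded iterator, First then Last.
	iter, err := db.NewIter(&pebble.IterOptions{
		LowerBound: []byte("b"),
		UpperBound: []byte("y"), // exclusive, like the engine's upperBound
	})
	if err != nil {
		log.Fatal(err)
	}
	defer iter.Close()

	if iter.First() {
		first := append([]byte{}, iter.Key()...)
		iter.Last()
		last := append([]byte{}, iter.Key()...)
		fmt.Printf("first=%s last=%s\n", first, last) // first=c last=m
	}
}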
-func (*Engine) IncRef() {} - -// DecRef implements IngestData interface. -func (*Engine) DecRef() {} - -// Finish implements IngestData interface. -func (e *Engine) Finish(totalBytes, totalCount int64) { - e.importedKVSize.Add(totalBytes) - e.importedKVCount.Add(totalCount) -} - -// LoadIngestData return (local) Engine itself because Engine has implemented -// IngestData interface. -func (e *Engine) LoadIngestData( - ctx context.Context, - regionRanges []common.Range, - outCh chan<- common.DataAndRange, -) error { - for _, r := range regionRanges { - select { - case <-ctx.Done(): - return ctx.Err() - case outCh <- common.DataAndRange{Data: e, Range: r}: - } - } - return nil -} - -type sstMeta struct { - path string - minKey []byte - maxKey []byte - totalSize int64 - totalCount int64 - // used for calculate disk-quota - fileSize int64 - seq int32 -} - -// Writer is used to write data into a SST file. -type Writer struct { - sync.Mutex - engine *Engine - memtableSizeLimit int64 - - // if the KVs are append in order, we can directly write the into SST file, - // else we must first store them in writeBatch and then batch flush into SST file. - isKVSorted bool - writer atomic.Pointer[sstWriter] - writerSize atomic.Uint64 - - // bytes buffer for writeBatch - kvBuffer *membuf.Buffer - writeBatch []common.KvPair - // if the kvs in writeBatch are in order, we can avoid doing a `sort.Slice` which - // is quite slow. in our bench, the sort operation eats about 5% of total CPU - isWriteBatchSorted bool - sortedKeyBuf []byte - - batchCount int - batchSize atomic.Int64 - - lastMetaSeq int32 - - tikvCodec tikv.Codec -} - -func (w *Writer) appendRowsSorted(kvs []common.KvPair) (err error) { - writer := w.writer.Load() - if writer == nil { - writer, err = w.createSSTWriter() - if err != nil { - return errors.Trace(err) - } - w.writer.Store(writer) - } - - keyAdapter := w.engine.keyAdapter - totalKeySize := 0 - for i := 0; i < len(kvs); i++ { - keySize := keyAdapter.EncodedLen(kvs[i].Key, kvs[i].RowID) - w.batchSize.Add(int64(keySize + len(kvs[i].Val))) - totalKeySize += keySize - } - w.batchCount += len(kvs) - // NoopKeyAdapter doesn't really change the key, - // skipping the encoding to avoid unnecessary alloc and copy. 
- if _, ok := keyAdapter.(common.NoopKeyAdapter); !ok { - if cap(w.sortedKeyBuf) < totalKeySize { - w.sortedKeyBuf = make([]byte, totalKeySize) - } - buf := w.sortedKeyBuf[:0] - newKvs := make([]common.KvPair, len(kvs)) - for i := 0; i < len(kvs); i++ { - buf = keyAdapter.Encode(buf, kvs[i].Key, kvs[i].RowID) - newKvs[i] = common.KvPair{Key: buf, Val: kvs[i].Val} - buf = buf[len(buf):] - } - kvs = newKvs - } - if err := writer.writeKVs(kvs); err != nil { - return err - } - w.writerSize.Store(writer.writer.EstimatedSize()) - return nil -} - -func (w *Writer) appendRowsUnsorted(ctx context.Context, kvs []common.KvPair) error { - l := len(w.writeBatch) - cnt := w.batchCount - var lastKey []byte - if cnt > 0 { - lastKey = w.writeBatch[cnt-1].Key - } - keyAdapter := w.engine.keyAdapter - for _, pair := range kvs { - if w.isWriteBatchSorted && bytes.Compare(lastKey, pair.Key) > 0 { - w.isWriteBatchSorted = false - } - lastKey = pair.Key - w.batchSize.Add(int64(len(pair.Key) + len(pair.Val))) - buf := w.kvBuffer.AllocBytes(keyAdapter.EncodedLen(pair.Key, pair.RowID)) - key := keyAdapter.Encode(buf[:0], pair.Key, pair.RowID) - val := w.kvBuffer.AddBytes(pair.Val) - if cnt < l { - w.writeBatch[cnt].Key = key - w.writeBatch[cnt].Val = val - } else { - w.writeBatch = append(w.writeBatch, common.KvPair{Key: key, Val: val}) - } - cnt++ - } - w.batchCount = cnt - - if w.batchSize.Load() > w.memtableSizeLimit { - if err := w.flushKVs(ctx); err != nil { - return err - } - } - return nil -} - -// AppendRows appends rows to the SST file. -func (w *Writer) AppendRows(ctx context.Context, columnNames []string, rows encode.Rows) error { - kvs := kv.Rows2KvPairs(rows) - if len(kvs) == 0 { - return nil - } - - if w.engine.closed.Load() { - return errorEngineClosed - } - - for i := range kvs { - kvs[i].Key = w.tikvCodec.EncodeKey(kvs[i].Key) - } - - w.Lock() - defer w.Unlock() - - // if chunk has _tidb_rowid field, we can't ensure that the rows are sorted. - if w.isKVSorted && w.writer.Load() == nil { - for _, c := range columnNames { - if c == model.ExtraHandleName.L { - w.isKVSorted = false - } - } - } - - if w.isKVSorted { - return w.appendRowsSorted(kvs) - } - return w.appendRowsUnsorted(ctx, kvs) -} - -func (w *Writer) flush(ctx context.Context) error { - w.Lock() - defer w.Unlock() - if w.batchCount == 0 { - return nil - } - - if len(w.writeBatch) > 0 { - if err := w.flushKVs(ctx); err != nil { - return errors.Trace(err) - } - } - - writer := w.writer.Load() - if writer != nil { - meta, err := writer.close() - if err != nil { - return errors.Trace(err) - } - w.writer.Store(nil) - w.writerSize.Store(0) - w.batchCount = 0 - if meta != nil && meta.totalSize > 0 { - return w.addSST(ctx, meta) - } - } - - return nil -} - -// EstimatedSize returns the estimated size of the SST file. -func (w *Writer) EstimatedSize() uint64 { - if size := w.writerSize.Load(); size > 0 { - return size - } - // if kvs are still in memory, only calculate half of the total size - // in our tests, SST file size is about 50% of the raw kv size - return uint64(w.batchSize.Load()) / 2 -} - -type flushStatus struct { - local *Engine - seq int32 -} - -// Flushed implements backend.ChunkFlushStatus. -func (f flushStatus) Flushed() bool { - return f.seq <= f.local.finishedMetaSeq.Load() -} - -// Close implements backend.ChunkFlushStatus. 
-func (w *Writer) Close(ctx context.Context) (backend.ChunkFlushStatus, error) { - defer w.kvBuffer.Destroy() - defer w.engine.localWriters.Delete(w) - err := w.flush(ctx) - // FIXME: in theory this line is useless, but In our benchmark with go1.15 - // this can resolve the memory consistently increasing issue. - // maybe this is a bug related to go GC mechanism. - w.writeBatch = nil - return flushStatus{local: w.engine, seq: w.lastMetaSeq}, err -} - -// IsSynced implements backend.ChunkFlushStatus. -func (w *Writer) IsSynced() bool { - return w.batchCount == 0 && w.lastMetaSeq <= w.engine.finishedMetaSeq.Load() -} - -func (w *Writer) flushKVs(ctx context.Context) error { - writer, err := w.createSSTWriter() - if err != nil { - return errors.Trace(err) - } - if !w.isWriteBatchSorted { - slices.SortFunc(w.writeBatch[:w.batchCount], func(i, j common.KvPair) int { - return bytes.Compare(i.Key, j.Key) - }) - w.isWriteBatchSorted = true - } - - err = writer.writeKVs(w.writeBatch[:w.batchCount]) - if err != nil { - return errors.Trace(err) - } - meta, err := writer.close() - if err != nil { - return errors.Trace(err) - } - - failpoint.Inject("orphanWriterGoRoutine", func() { - _ = common.KillMySelf() - // mimic we meet context cancel error when `addSST` - <-ctx.Done() - time.Sleep(5 * time.Second) - failpoint.Return(errors.Trace(ctx.Err())) - }) - - err = w.addSST(ctx, meta) - if err != nil { - return errors.Trace(err) - } - - w.batchSize.Store(0) - w.batchCount = 0 - w.kvBuffer.Reset() - return nil -} - -func (w *Writer) addSST(ctx context.Context, meta *sstMeta) error { - seq, err := w.engine.addSST(ctx, meta) - if err != nil { - return err - } - w.lastMetaSeq = seq - return nil -} - -func (w *Writer) createSSTWriter() (*sstWriter, error) { - path := filepath.Join(w.engine.sstDir, uuid.New().String()+".sst") - writer, err := newSSTWriter(path, w.engine.config.BlockSize) - if err != nil { - return nil, err - } - sw := &sstWriter{sstMeta: &sstMeta{path: path}, writer: writer, logger: w.engine.logger} - return sw, nil -} - -var errorUnorderedSSTInsertion = errors.New("inserting KVs into SST without order") - -type sstWriter struct { - *sstMeta - writer *sstable.Writer - - // To dedup keys before write them into the SST file. - // NOTE: keys should be sorted and deduped when construct one SST file. - lastKey []byte - - logger log.Logger -} - -func newSSTWriter(path string, blockSize int) (*sstable.Writer, error) { - f, err := vfs.Default.Create(path) - if err != nil { - return nil, errors.Trace(err) - } - writable := objstorageprovider.NewFileWritable(f) - writer := sstable.NewWriter(writable, sstable.WriterOptions{ - TablePropertyCollectors: []func() pebble.TablePropertyCollector{ - newRangePropertiesCollector, - }, - BlockSize: blockSize, - }) - return writer, nil -} - -func (sw *sstWriter) writeKVs(kvs []common.KvPair) error { - if len(kvs) == 0 { - return nil - } - if len(sw.minKey) == 0 { - sw.minKey = append([]byte{}, kvs[0].Key...) 
- } - if bytes.Compare(kvs[0].Key, sw.maxKey) <= 0 { - return errorUnorderedSSTInsertion - } - - internalKey := sstable.InternalKey{ - Trailer: uint64(sstable.InternalKeyKindSet), - } - for _, p := range kvs { - if sw.lastKey != nil && bytes.Equal(p.Key, sw.lastKey) { - sw.logger.Warn("duplicated key found, skip write", logutil.Key("key", p.Key)) - continue - } - internalKey.UserKey = p.Key - if err := sw.writer.Add(internalKey, p.Val); err != nil { - return errors.Trace(err) - } - sw.totalSize += int64(len(p.Key)) + int64(len(p.Val)) - sw.lastKey = p.Key - } - sw.totalCount += int64(len(kvs)) - sw.maxKey = append(sw.maxKey[:0], sw.lastKey...) - return nil -} - -func (sw *sstWriter) close() (*sstMeta, error) { - if err := sw.writer.Close(); err != nil { - return nil, errors.Trace(err) - } - meta, err := sw.writer.Metadata() - if err != nil { - return nil, errors.Trace(err) - } - sw.fileSize = int64(meta.Size) - return sw.sstMeta, nil -} - -type sstIter struct { - name string - key []byte - val []byte - iter sstable.Iterator - reader *sstable.Reader - valid bool -} - -// Close implements common.Iterator. -func (i *sstIter) Close() error { - if err := i.iter.Close(); err != nil { - return errors.Trace(err) - } - err := i.reader.Close() - return errors.Trace(err) -} - -type sstIterHeap struct { - iters []*sstIter -} - -// Len implements heap.Interface. -func (h *sstIterHeap) Len() int { - return len(h.iters) -} - -// Less implements heap.Interface. -func (h *sstIterHeap) Less(i, j int) bool { - return bytes.Compare(h.iters[i].key, h.iters[j].key) < 0 -} - -// Swap implements heap.Interface. -func (h *sstIterHeap) Swap(i, j int) { - h.iters[i], h.iters[j] = h.iters[j], h.iters[i] -} - -// Push implements heap.Interface. -func (h *sstIterHeap) Push(x any) { - h.iters = append(h.iters, x.(*sstIter)) -} - -// Pop implements heap.Interface. -func (h *sstIterHeap) Pop() any { - item := h.iters[len(h.iters)-1] - h.iters = h.iters[:len(h.iters)-1] - return item -} - -// Next implements common.Iterator. -func (h *sstIterHeap) Next() ([]byte, []byte, error) { - for { - if len(h.iters) == 0 { - return nil, nil, nil - } - - iter := h.iters[0] - if iter.valid { - iter.valid = false - return iter.key, iter.val, iter.iter.Error() - } - - var k *pebble.InternalKey - var v pebble.LazyValue - k, v = iter.iter.Next() - - if k != nil { - vBytes, _, err := v.Value(nil) - if err != nil { - return nil, nil, errors.Trace(err) - } - iter.key = k.UserKey - iter.val = vBytes - iter.valid = true - heap.Fix(h, 0) - } else { - err := iter.Close() - heap.Remove(h, 0) - if err != nil { - return nil, nil, errors.Trace(err) - } - } - } -} - -// sstIngester is a interface used to merge and ingest SST files. 
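sstIterHeap above is the classic container/heap k-way merge: the smallest head key sits at the heap root, and after consuming it the advanced iterator is either re-fixed or removed. The same pattern on plain sorted slices, as a compact sketch (the cursor type and the data are made up for the example):

package main

import (
	"container/heap"
	"fmt"
)

type cursor struct {
	keys []string
	pos  int
}

type mergeHeap []*cursor

func (h mergeHeap) Len() int           { return len(h) }
func (h mergeHeap) Less(i, j int) bool { return h[i].keys[h[i].pos] < h[j].keys[h[j].pos] }
func (h mergeHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *mergeHeap) Push(x any)        { *h = append(*h, x.(*cursor)) }
func (h *mergeHeap) Pop() any {
	old := *h
	c := old[len(old)-1]
	*h = old[:len(old)-1]
	return c
}

func main() {
	h := &mergeHeap{
		{keys: []string{"a", "d", "g"}},
		{keys: []string{"b", "e"}},
		{keys: []string{"c", "f"}},
	}
	heap.Init(h)
	for h.Len() > 0 {
		top := (*h)[0]
		fmt.Println(top.keys[top.pos]) // prints a..g in order
		top.pos++
		if top.pos == len(top.keys) {
			heap.Remove(h, 0) // this source is exhausted
		} else {
			heap.Fix(h, 0) // re-position the advanced cursor, as sstIterHeap.Next does
		}
	}
}

mergeSSTs below drives essentially this loop, except that each cursor is an SST iterator and duplicate keys are skipped while writing the merged file.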
-// it's a interface mainly used for test convenience -type sstIngester interface { - mergeSSTs(metas []*sstMeta, dir string, blockSize int) (*sstMeta, error) - ingest([]*sstMeta) error -} - -type dbSSTIngester struct { - e *Engine -} - -func (i dbSSTIngester) mergeSSTs(metas []*sstMeta, dir string, blockSize int) (*sstMeta, error) { - if len(metas) == 0 { - return nil, errors.New("sst metas is empty") - } else if len(metas) == 1 { - return metas[0], nil - } - - start := time.Now() - newMeta := &sstMeta{ - seq: metas[len(metas)-1].seq, - } - mergeIter := &sstIterHeap{ - iters: make([]*sstIter, 0, len(metas)), - } - - for _, p := range metas { - f, err := vfs.Default.Open(p.path) - if err != nil { - return nil, errors.Trace(err) - } - readable, err := sstable.NewSimpleReadable(f) - if err != nil { - return nil, errors.Trace(err) - } - reader, err := sstable.NewReader(readable, sstable.ReaderOptions{}) - if err != nil { - return nil, errors.Trace(err) - } - iter, err := reader.NewIter(nil, nil) - if err != nil { - return nil, errors.Trace(err) - } - key, val := iter.Next() - if key == nil { - continue - } - valBytes, _, err := val.Value(nil) - if err != nil { - return nil, errors.Trace(err) - } - if iter.Error() != nil { - return nil, errors.Trace(iter.Error()) - } - mergeIter.iters = append(mergeIter.iters, &sstIter{ - name: p.path, - iter: iter, - key: key.UserKey, - val: valBytes, - reader: reader, - valid: true, - }) - newMeta.totalSize += p.totalSize - newMeta.totalCount += p.totalCount - } - heap.Init(mergeIter) - - name := filepath.Join(dir, fmt.Sprintf("%s.sst", uuid.New())) - writer, err := newSSTWriter(name, blockSize) - if err != nil { - return nil, errors.Trace(err) - } - newMeta.path = name - - internalKey := sstable.InternalKey{ - Trailer: uint64(sstable.InternalKeyKindSet), - } - key, val, err := mergeIter.Next() - if err != nil { - return nil, err - } - if key == nil { - return nil, errors.New("all ssts are empty") - } - newMeta.minKey = append(newMeta.minKey[:0], key...) - lastKey := make([]byte, 0) - for { - if bytes.Equal(lastKey, key) { - i.e.logger.Warn("duplicated key found, skipped", zap.Binary("key", lastKey)) - newMeta.totalCount-- - newMeta.totalSize -= int64(len(key) + len(val)) - - goto nextKey - } - internalKey.UserKey = key - err = writer.Add(internalKey, val) - if err != nil { - return nil, err - } - lastKey = append(lastKey[:0], key...) - nextKey: - key, val, err = mergeIter.Next() - if err != nil { - return nil, err - } - if key == nil { - break - } - } - err = writer.Close() - if err != nil { - return nil, errors.Trace(err) - } - meta, err := writer.Metadata() - if err != nil { - return nil, errors.Trace(err) - } - newMeta.maxKey = lastKey - newMeta.fileSize = int64(meta.Size) - - dur := time.Since(start) - i.e.logger.Info("compact sst", zap.Int("fileCount", len(metas)), zap.Int64("size", newMeta.totalSize), - zap.Int64("count", newMeta.totalCount), zap.Duration("cost", dur), zap.String("file", name)) - - // async clean raw SSTs. 
- go func() { - totalSize := int64(0) - for _, m := range metas { - totalSize += m.fileSize - if err := os.Remove(m.path); err != nil { - i.e.logger.Warn("async cleanup sst file failed", zap.Error(err)) - } - } - // decrease the pending size after clean up - i.e.pendingFileSize.Sub(totalSize) - }() - - return newMeta, err -} - -func (i dbSSTIngester) ingest(metas []*sstMeta) error { - if len(metas) == 0 { - return nil - } - paths := make([]string, 0, len(metas)) - for _, m := range metas { - paths = append(paths, m.path) - } - db := i.e.getDB() - if db == nil { - return errorEngineClosed - } - return db.Ingest(paths) -} diff --git a/pkg/lightning/backend/local/engine_mgr.go b/pkg/lightning/backend/local/engine_mgr.go index b796fa736ee65..28b6107a3d3e5 100644 --- a/pkg/lightning/backend/local/engine_mgr.go +++ b/pkg/lightning/backend/local/engine_mgr.go @@ -630,9 +630,9 @@ func openDuplicateDB(storeDir string) (*pebble.DB, error) { newRangePropertiesCollector, }, } - if _, _err_ := failpoint.Eval(_curpkg_("slowCreateFS")); _err_ == nil { + failpoint.Inject("slowCreateFS", func() { opts.FS = slowCreateFS{vfs.Default} - } + }) return pebble.Open(dbPath, opts) } diff --git a/pkg/lightning/backend/local/engine_mgr.go__failpoint_stash__ b/pkg/lightning/backend/local/engine_mgr.go__failpoint_stash__ deleted file mode 100644 index 28b6107a3d3e5..0000000000000 --- a/pkg/lightning/backend/local/engine_mgr.go__failpoint_stash__ +++ /dev/null @@ -1,658 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package local - -import ( - "context" - "math" - "os" - "path/filepath" - "strings" - "sync" - "time" - - "github.com/cockroachdb/pebble" - "github.com/cockroachdb/pebble/vfs" - "github.com/docker/go-units" - "github.com/google/uuid" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/membuf" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/pkg/lightning/backend" - "github.com/pingcap/tidb/pkg/lightning/backend/external" - "github.com/pingcap/tidb/pkg/lightning/common" - "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/lightning/manual" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/tikv/client-go/v2/oracle" - tikvclient "github.com/tikv/client-go/v2/tikv" - "go.uber.org/atomic" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" -) - -var ( - // RunInTest indicates whether the current process is running in test. - RunInTest bool - // LastAlloc is the last ID allocator. - LastAlloc manual.Allocator -) - -// StoreHelper have some api to help encode or store KV data -type StoreHelper interface { - GetTS(ctx context.Context) (physical, logical int64, err error) - GetTiKVCodec() tikvclient.Codec -} - -// engineManager manages all engines, either local or external. 
-type engineManager struct { - BackendConfig - StoreHelper - engines sync.Map // sync version of map[uuid.UUID]*Engine - externalEngine map[uuid.UUID]common.Engine - bufferPool *membuf.Pool - duplicateDB *pebble.DB - keyAdapter common.KeyAdapter - logger log.Logger -} - -var inMemTest = false - -func newEngineManager(config BackendConfig, storeHelper StoreHelper, logger log.Logger) (_ *engineManager, err error) { - var duplicateDB *pebble.DB - defer func() { - if err != nil && duplicateDB != nil { - _ = duplicateDB.Close() - } - }() - - if err = prepareSortDir(config); err != nil { - return nil, err - } - - keyAdapter := common.KeyAdapter(common.NoopKeyAdapter{}) - if config.DupeDetectEnabled { - duplicateDB, err = openDuplicateDB(config.LocalStoreDir) - if err != nil { - return nil, common.ErrOpenDuplicateDB.Wrap(err).GenWithStackByArgs() - } - keyAdapter = common.DupDetectKeyAdapter{} - } - alloc := manual.Allocator{} - if RunInTest { - alloc.RefCnt = new(atomic.Int64) - LastAlloc = alloc - } - var opts = make([]membuf.Option, 0, 1) - if !inMemTest { - // otherwise, we use the default allocator that can be tracked by golang runtime. - opts = append(opts, membuf.WithAllocator(alloc)) - } - return &engineManager{ - BackendConfig: config, - StoreHelper: storeHelper, - engines: sync.Map{}, - externalEngine: map[uuid.UUID]common.Engine{}, - bufferPool: membuf.NewPool(opts...), - duplicateDB: duplicateDB, - keyAdapter: keyAdapter, - logger: logger, - }, nil -} - -// rlock read locks a local file and returns the Engine instance if it exists. -func (em *engineManager) rLockEngine(engineID uuid.UUID) *Engine { - if e, ok := em.engines.Load(engineID); ok { - engine := e.(*Engine) - engine.rLock() - return engine - } - return nil -} - -// lock locks a local file and returns the Engine instance if it exists. -func (em *engineManager) lockEngine(engineID uuid.UUID, state importMutexState) *Engine { - if e, ok := em.engines.Load(engineID); ok { - engine := e.(*Engine) - engine.lock(state) - return engine - } - return nil -} - -// tryRLockAllEngines tries to read lock all engines, return all `Engine`s that are successfully locked. -func (em *engineManager) tryRLockAllEngines() []*Engine { - var allEngines []*Engine - em.engines.Range(func(_, v any) bool { - engine := v.(*Engine) - // skip closed engine - if engine.tryRLock() { - if !engine.closed.Load() { - allEngines = append(allEngines, engine) - } else { - engine.rUnlock() - } - } - return true - }) - return allEngines -} - -// lockAllEnginesUnless tries to lock all engines, unless those which are already locked in the -// state given by ignoreStateMask. Returns the list of locked engines. -func (em *engineManager) lockAllEnginesUnless(newState, ignoreStateMask importMutexState) []*Engine { - var allEngines []*Engine - em.engines.Range(func(_, v any) bool { - engine := v.(*Engine) - if engine.lockUnless(newState, ignoreStateMask) { - allEngines = append(allEngines, engine) - } - return true - }) - return allEngines -} - -// flushEngine ensure the written data is saved successfully, to make sure no data lose after restart -func (em *engineManager) flushEngine(ctx context.Context, engineID uuid.UUID) error { - engine := em.rLockEngine(engineID) - - // the engine cannot be deleted after while we've acquired the lock identified by UUID. 
- if engine == nil { - return errors.Errorf("engine '%s' not found", engineID) - } - defer engine.rUnlock() - if engine.closed.Load() { - return nil - } - return engine.flushEngineWithoutLock(ctx) -} - -// flushAllEngines flush all engines. -func (em *engineManager) flushAllEngines(parentCtx context.Context) (err error) { - allEngines := em.tryRLockAllEngines() - defer func() { - for _, engine := range allEngines { - engine.rUnlock() - } - }() - - eg, ctx := errgroup.WithContext(parentCtx) - for _, engine := range allEngines { - e := engine - eg.Go(func() error { - return e.flushEngineWithoutLock(ctx) - }) - } - return eg.Wait() -} - -func (em *engineManager) openEngineDB(engineUUID uuid.UUID, readOnly bool) (*pebble.DB, error) { - opt := &pebble.Options{ - MemTableSize: uint64(em.MemTableSize), - // the default threshold value may cause write stall. - MemTableStopWritesThreshold: 8, - MaxConcurrentCompactions: func() int { return 16 }, - // set threshold to half of the max open files to avoid trigger compaction - L0CompactionThreshold: math.MaxInt32, - L0StopWritesThreshold: math.MaxInt32, - LBaseMaxBytes: 16 * units.TiB, - MaxOpenFiles: em.MaxOpenFiles, - DisableWAL: true, - ReadOnly: readOnly, - TablePropertyCollectors: []func() pebble.TablePropertyCollector{ - newRangePropertiesCollector, - }, - DisableAutomaticCompactions: em.DisableAutomaticCompactions, - } - // set level target file size to avoid pebble auto triggering compaction that split ingest SST files into small SST. - opt.Levels = []pebble.LevelOptions{ - { - TargetFileSize: 16 * units.GiB, - BlockSize: em.BlockSize, - }, - } - - dbPath := filepath.Join(em.LocalStoreDir, engineUUID.String()) - db, err := pebble.Open(dbPath, opt) - return db, errors.Trace(err) -} - -// openEngine must be called with holding mutex of Engine. -func (em *engineManager) openEngine(ctx context.Context, cfg *backend.EngineConfig, engineUUID uuid.UUID) error { - db, err := em.openEngineDB(engineUUID, false) - if err != nil { - return err - } - - sstDir := engineSSTDir(em.LocalStoreDir, engineUUID) - if !cfg.KeepSortDir { - if err := os.RemoveAll(sstDir); err != nil { - return errors.Trace(err) - } - } - if !common.IsDirExists(sstDir) { - if err := os.Mkdir(sstDir, 0o750); err != nil { - return errors.Trace(err) - } - } - engineCtx, cancel := context.WithCancel(ctx) - - e, _ := em.engines.LoadOrStore(engineUUID, &Engine{ - UUID: engineUUID, - sstDir: sstDir, - sstMetasChan: make(chan metaOrFlush, 64), - ctx: engineCtx, - cancel: cancel, - config: cfg.Local, - tableInfo: cfg.TableInfo, - duplicateDetection: em.DupeDetectEnabled, - dupDetectOpt: em.DuplicateDetectOpt, - duplicateDB: em.duplicateDB, - keyAdapter: em.keyAdapter, - logger: log.FromContext(ctx), - }) - engine := e.(*Engine) - engine.lock(importMutexStateOpen) - defer engine.unlock() - engine.db.Store(db) - engine.sstIngester = dbSSTIngester{e: engine} - if err = engine.loadEngineMeta(); err != nil { - return errors.Trace(err) - } - if engine.TS == 0 && cfg.TS > 0 { - engine.TS = cfg.TS - // we don't saveEngineMeta here, we can rely on the caller use the same TS to - // open the engine again. - } - if err = em.allocateTSIfNotExists(ctx, engine); err != nil { - return errors.Trace(err) - } - engine.wg.Add(1) - go engine.ingestSSTLoop() - return nil -} - -// closeEngine closes backend engine by uuid. 
-func (em *engineManager) closeEngine( - ctx context.Context, - cfg *backend.EngineConfig, - engineUUID uuid.UUID, -) (errRet error) { - if externalCfg := cfg.External; externalCfg != nil { - storeBackend, err := storage.ParseBackend(externalCfg.StorageURI, nil) - if err != nil { - return err - } - store, err := storage.NewWithDefaultOpt(ctx, storeBackend) - if err != nil { - return err - } - defer func() { - if errRet != nil { - store.Close() - } - }() - ts := cfg.TS - if ts == 0 { - physical, logical, err := em.GetTS(ctx) - if err != nil { - return err - } - ts = oracle.ComposeTS(physical, logical) - } - externalEngine := external.NewExternalEngine( - store, - externalCfg.DataFiles, - externalCfg.StatFiles, - externalCfg.StartKey, - externalCfg.EndKey, - externalCfg.SplitKeys, - externalCfg.RegionSplitSize, - em.keyAdapter, - em.DupeDetectEnabled, - em.duplicateDB, - em.DuplicateDetectOpt, - em.WorkerConcurrency, - ts, - externalCfg.TotalFileSize, - externalCfg.TotalKVCount, - externalCfg.CheckHotspot, - ) - em.externalEngine[engineUUID] = externalEngine - return nil - } - - // flush mem table to storage, to free memory, - // ask others' advise, looks like unnecessary, but with this we can control memory precisely. - engineI, ok := em.engines.Load(engineUUID) - if !ok { - // recovery mode, we should reopen this engine file - db, err := em.openEngineDB(engineUUID, true) - if err != nil { - return err - } - engine := &Engine{ - UUID: engineUUID, - sstMetasChan: make(chan metaOrFlush), - tableInfo: cfg.TableInfo, - keyAdapter: em.keyAdapter, - duplicateDetection: em.DupeDetectEnabled, - dupDetectOpt: em.DuplicateDetectOpt, - duplicateDB: em.duplicateDB, - logger: log.FromContext(ctx), - } - engine.db.Store(db) - engine.sstIngester = dbSSTIngester{e: engine} - if err = engine.loadEngineMeta(); err != nil { - return errors.Trace(err) - } - em.engines.Store(engineUUID, engine) - return nil - } - - engine := engineI.(*Engine) - engine.rLock() - if engine.closed.Load() { - engine.rUnlock() - return nil - } - - err := engine.flushEngineWithoutLock(ctx) - engine.rUnlock() - - // use mutex to make sure we won't close sstMetasChan while other routines - // trying to do flush. - engine.lock(importMutexStateClose) - engine.closed.Store(true) - close(engine.sstMetasChan) - engine.unlock() - if err != nil { - return errors.Trace(err) - } - engine.wg.Wait() - return engine.ingestErr.Get() -} - -// getImportedKVCount returns the number of imported KV pairs of some engine. -func (em *engineManager) getImportedKVCount(engineUUID uuid.UUID) int64 { - v, ok := em.engines.Load(engineUUID) - if !ok { - // we get it after import, but before clean up, so this should not happen - // todo: return error - return 0 - } - e := v.(*Engine) - return e.importedKVCount.Load() -} - -// getExternalEngineKVStatistics returns kv statistics of some engine. -func (em *engineManager) getExternalEngineKVStatistics(engineUUID uuid.UUID) ( - totalKVSize int64, totalKVCount int64) { - v, ok := em.externalEngine[engineUUID] - if !ok { - return 0, 0 - } - return v.ImportedStatistics() -} - -// resetEngine reset the engine and reclaim the space. 
-func (em *engineManager) resetEngine( - ctx context.Context, - engineUUID uuid.UUID, - skipAllocTS bool, -) error { - // the only way to reset the engine + reclaim the space is to delete and reopen it 🤷 - localEngine := em.lockEngine(engineUUID, importMutexStateClose) - if localEngine == nil { - if engineI, ok := em.externalEngine[engineUUID]; ok { - extEngine := engineI.(*external.Engine) - return extEngine.Reset() - } - - log.FromContext(ctx).Warn("could not find engine in cleanupEngine", zap.Stringer("uuid", engineUUID)) - return nil - } - defer localEngine.unlock() - if err := localEngine.Close(); err != nil { - return err - } - if err := localEngine.Cleanup(em.LocalStoreDir); err != nil { - return err - } - db, err := em.openEngineDB(engineUUID, false) - if err == nil { - localEngine.db.Store(db) - localEngine.engineMeta = engineMeta{} - if !common.IsDirExists(localEngine.sstDir) { - if err := os.Mkdir(localEngine.sstDir, 0o750); err != nil { - return errors.Trace(err) - } - } - if !skipAllocTS { - if err = em.allocateTSIfNotExists(ctx, localEngine); err != nil { - return errors.Trace(err) - } - } - } - localEngine.pendingFileSize.Store(0) - - return err -} - -func (em *engineManager) allocateTSIfNotExists(ctx context.Context, engine *Engine) error { - if engine.TS > 0 { - return nil - } - physical, logical, err := em.GetTS(ctx) - if err != nil { - return err - } - ts := oracle.ComposeTS(physical, logical) - engine.TS = ts - return engine.saveEngineMeta() -} - -// cleanupEngine cleanup the engine and reclaim the space. -func (em *engineManager) cleanupEngine(ctx context.Context, engineUUID uuid.UUID) error { - localEngine := em.lockEngine(engineUUID, importMutexStateClose) - // release this engine after import success - if localEngine == nil { - if extEngine, ok := em.externalEngine[engineUUID]; ok { - retErr := extEngine.Close() - delete(em.externalEngine, engineUUID) - return retErr - } - log.FromContext(ctx).Warn("could not find engine in cleanupEngine", zap.Stringer("uuid", engineUUID)) - return nil - } - defer localEngine.unlock() - - // since closing the engine causes all subsequent operations on it panic, - // we make sure to delete it from the engine map before calling Close(). - // (note that Close() returning error does _not_ mean the pebble DB - // remains open/usable.) - em.engines.Delete(engineUUID) - err := localEngine.Close() - if err != nil { - return err - } - err = localEngine.Cleanup(em.LocalStoreDir) - if err != nil { - return err - } - localEngine.TotalSize.Store(0) - localEngine.Length.Store(0) - return nil -} - -// LocalWriter returns a new local writer. 
-func (em *engineManager) localWriter(_ context.Context, cfg *backend.LocalWriterConfig, engineUUID uuid.UUID) (backend.EngineWriter, error) { - e, ok := em.engines.Load(engineUUID) - if !ok { - return nil, errors.Errorf("could not find engine for %s", engineUUID.String()) - } - engine := e.(*Engine) - memCacheSize := em.LocalWriterMemCacheSize - if cfg.Local.MemCacheSize > 0 { - memCacheSize = cfg.Local.MemCacheSize - } - return openLocalWriter(cfg, engine, em.GetTiKVCodec(), memCacheSize, em.bufferPool.NewBuffer()) -} - -func (em *engineManager) engineFileSizes() (res []backend.EngineFileSize) { - em.engines.Range(func(_, v any) bool { - engine := v.(*Engine) - res = append(res, engine.getEngineFileSize()) - return true - }) - return -} - -func (em *engineManager) close() { - for _, e := range em.externalEngine { - _ = e.Close() - } - em.externalEngine = map[uuid.UUID]common.Engine{} - allLocalEngines := em.lockAllEnginesUnless(importMutexStateClose, 0) - for _, e := range allLocalEngines { - _ = e.Close() - e.unlock() - } - em.engines = sync.Map{} - em.bufferPool.Destroy() - - if em.duplicateDB != nil { - // Check if there are duplicates that are not collected. - iter, err := em.duplicateDB.NewIter(&pebble.IterOptions{}) - if err != nil { - em.logger.Panic("fail to create iterator") - } - hasDuplicates := iter.First() - allIsWell := true - if err := iter.Error(); err != nil { - em.logger.Warn("iterate duplicate db failed", zap.Error(err)) - allIsWell = false - } - if err := iter.Close(); err != nil { - em.logger.Warn("close duplicate db iter failed", zap.Error(err)) - allIsWell = false - } - if err := em.duplicateDB.Close(); err != nil { - em.logger.Warn("close duplicate db failed", zap.Error(err)) - allIsWell = false - } - // If checkpoint is disabled, or we don't detect any duplicate, then this duplicate - // db dir will be useless, so we clean up this dir. - if allIsWell && (!em.CheckpointEnabled || !hasDuplicates) { - if err := os.RemoveAll(filepath.Join(em.LocalStoreDir, duplicateDBName)); err != nil { - em.logger.Warn("remove duplicate db file failed", zap.Error(err)) - } - } - em.duplicateDB = nil - } - - // if checkpoint is disabled, or we finish load all data successfully, then files in this - // dir will be useless, so we clean up this dir and all files in it. - if !em.CheckpointEnabled || common.IsEmptyDir(em.LocalStoreDir) { - err := os.RemoveAll(em.LocalStoreDir) - if err != nil { - em.logger.Warn("remove local db file failed", zap.Error(err)) - } - } -} - -func (em *engineManager) getExternalEngine(uuid uuid.UUID) (common.Engine, bool) { - e, ok := em.externalEngine[uuid] - return e, ok -} - -func (em *engineManager) totalMemoryConsume() int64 { - var memConsume int64 - em.engines.Range(func(_, v any) bool { - e := v.(*Engine) - if e != nil { - memConsume += e.TotalMemorySize() - } - return true - }) - return memConsume + em.bufferPool.TotalSize() -} - -func (em *engineManager) getDuplicateDB() *pebble.DB { - return em.duplicateDB -} - -func (em *engineManager) getKeyAdapter() common.KeyAdapter { - return em.keyAdapter -} - -func (em *engineManager) getBufferPool() *membuf.Pool { - return em.bufferPool -} - -// only used in tests -type slowCreateFS struct { - vfs.FS -} - -// WaitRMFolderChForTest is a channel for testing. 
-var WaitRMFolderChForTest = make(chan struct{}) - -func (s slowCreateFS) Create(name string) (vfs.File, error) { - if strings.Contains(name, "temporary") { - select { - case <-WaitRMFolderChForTest: - case <-time.After(1 * time.Second): - logutil.BgLogger().Info("no one removes folder") - } - } - return s.FS.Create(name) -} - -func openDuplicateDB(storeDir string) (*pebble.DB, error) { - dbPath := filepath.Join(storeDir, duplicateDBName) - // TODO: Optimize the opts for better write. - opts := &pebble.Options{ - TablePropertyCollectors: []func() pebble.TablePropertyCollector{ - newRangePropertiesCollector, - }, - } - failpoint.Inject("slowCreateFS", func() { - opts.FS = slowCreateFS{vfs.Default} - }) - return pebble.Open(dbPath, opts) -} - -func prepareSortDir(config BackendConfig) error { - shouldCreate := true - if config.CheckpointEnabled { - if info, err := os.Stat(config.LocalStoreDir); err != nil { - if !os.IsNotExist(err) { - return err - } - } else if info.IsDir() { - shouldCreate = false - } - } - - if shouldCreate { - err := os.Mkdir(config.LocalStoreDir, 0o700) - if err != nil { - return common.ErrInvalidSortedKVDir.Wrap(err).GenWithStackByArgs(config.LocalStoreDir) - } - } - return nil -} diff --git a/pkg/lightning/backend/local/local.go b/pkg/lightning/backend/local/local.go index 6b2678b71ab1c..c38a01936e504 100644 --- a/pkg/lightning/backend/local/local.go +++ b/pkg/lightning/backend/local/local.go @@ -183,7 +183,7 @@ func (f *importClientFactoryImpl) makeConn(ctx context.Context, storeID uint64) return nil, common.ErrInvalidConfig.GenWithStack("unsupported compression type %s", f.compressionType) } - if _, _err_ := failpoint.Eval(_curpkg_("LoggingImportBytes")); _err_ == nil { + failpoint.Inject("LoggingImportBytes", func() { opts = append(opts, grpc.WithContextDialer(func(ctx context.Context, target string) (net.Conn, error) { conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", target) if err != nil { @@ -191,7 +191,7 @@ func (f *importClientFactoryImpl) makeConn(ctx context.Context, storeID uint64) } return &loggingConn{Conn: conn}, nil })) - } + }) conn, err := grpc.DialContext(ctx, addr, opts...) if err != nil { @@ -891,18 +891,18 @@ func (local *Backend) prepareAndSendJob( // the table when table is created. 
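The hunks in this patch swap the failpoint-ctl rewritten form (failpoint.Eval with _curpkg_) back to the failpoint.Inject marker form. For context, a hypothetical test sketch of how a failpoint such as slowCreateFS is typically switched on; the full failpoint path follows the usual package-import-path convention and is an assumption here, not taken from the patch:

package local_test

import (
	"testing"

	"github.com/pingcap/failpoint"
	"github.com/stretchr/testify/require"
)

func TestSlowCreateFSFailpoint(t *testing.T) {
	// Assumed path: package import path + "/" + failpoint name.
	fp := "github.com/pingcap/tidb/pkg/lightning/backend/local/slowCreateFS"
	require.NoError(t, failpoint.Enable(fp, "return(true)"))
	defer func() {
		require.NoError(t, failpoint.Disable(fp))
	}()
	// ... exercise openDuplicateDB / engine setup here so the injected
	// slowCreateFS hook delays temporary file creation ...
}

Enabling only has an effect in binaries where the Inject markers have been rewritten (for example via failpoint-ctl enable); otherwise failpoint.Inject is a no-op.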
needSplit := len(initialSplitRanges) > 1 || lfTotalSize > regionSplitSize || lfLength > regionSplitKeys // split region by given ranges - if _, _err_ := failpoint.Eval(_curpkg_("failToSplit")); _err_ == nil { + failpoint.Inject("failToSplit", func(_ failpoint.Value) { needSplit = true - } + }) if needSplit { var err error logger := log.FromContext(ctx).With(zap.String("uuid", engine.ID())).Begin(zap.InfoLevel, "split and scatter ranges") backOffTime := 10 * time.Second maxbackoffTime := 120 * time.Second for i := 0; i < maxRetryTimes; i++ { - if _, _err_ := failpoint.Eval(_curpkg_("skipSplitAndScatter")); _err_ == nil { - break - } + failpoint.Inject("skipSplitAndScatter", func() { + failpoint.Break() + }) err = local.SplitAndScatterRegionInBatches(ctx, initialSplitRanges, maxBatchSplitRanges) if err == nil || common.IsContextCanceledError(err) { @@ -987,13 +987,13 @@ func (local *Backend) generateAndSendJob( return nil } - failpoint.Eval(_curpkg_("beforeGenerateJob")) - if _, _err_ := failpoint.Eval(_curpkg_("sendDummyJob")); _err_ == nil { + failpoint.Inject("beforeGenerateJob", nil) + failpoint.Inject("sendDummyJob", func(_ failpoint.Value) { // this is used to trigger worker failure, used together // with WriteToTiKVNotEnoughDiskSpace jobToWorkerCh <- ®ionJob{} time.Sleep(5 * time.Second) - } + }) jobs, err := local.generateJobForRange(egCtx, p.Data, p.Range, regionSplitSize, regionSplitKeys) if err != nil { if common.IsContextCanceledError(err) { @@ -1045,9 +1045,9 @@ func (local *Backend) generateJobForRange( keyRange common.Range, regionSplitSize, regionSplitKeys int64, ) ([]*regionJob, error) { - if _, _err_ := failpoint.Eval(_curpkg_("fakeRegionJobs")); _err_ == nil { + failpoint.Inject("fakeRegionJobs", func() { if ctx.Err() != nil { - return nil, ctx.Err() + failpoint.Return(nil, ctx.Err()) } key := [2]string{string(keyRange.Start), string(keyRange.End)} injected := fakeRegionJobs[key] @@ -1056,8 +1056,8 @@ func (local *Backend) generateJobForRange( for _, job := range injected.jobs { job.stage = regionScanned } - return injected.jobs, injected.err - } + failpoint.Return(injected.jobs, injected.err) + }) start, end := keyRange.Start, keyRange.End pairStart, pairEnd, err := data.GetFirstAndLastKey(start, end) @@ -1226,9 +1226,10 @@ func (local *Backend) executeJob( ctx context.Context, job *regionJob, ) error { - if _, _err_ := failpoint.Eval(_curpkg_("WriteToTiKVNotEnoughDiskSpace")); _err_ == nil { - return errors.New("the remaining storage capacity of TiKV is less than 10%%; please increase the storage capacity of TiKV and try again") - } + failpoint.Inject("WriteToTiKVNotEnoughDiskSpace", func(_ failpoint.Value) { + failpoint.Return( + errors.New("the remaining storage capacity of TiKV is less than 10%%; please increase the storage capacity of TiKV and try again")) + }) if local.ShouldCheckTiKV { for _, peer := range job.region.Region.GetPeers() { store, err := local.pdHTTPCli.GetStore(ctx, peer.StoreId) @@ -1370,7 +1371,7 @@ func (local *Backend) ImportEngine( zap.Int64("count", lfLength), zap.Int64("size", lfTotalSize)) - failpoint.Eval(_curpkg_("ReadyForImportEngine")) + failpoint.Inject("ReadyForImportEngine", func() {}) err = local.doImport(ctx, e, regionRanges, regionSplitSize, regionSplitKeys) if err == nil { @@ -1420,10 +1421,10 @@ func (local *Backend) doImport(ctx context.Context, engine common.Engine, region ) defer workerCancel() - if _, _err_ := failpoint.Eval(_curpkg_("injectVariables")); _err_ == nil { + failpoint.Inject("injectVariables", func() { 
jobToWorkerCh = testJobToWorkerCh testJobWg = &jobWg - } + }) retryer := startRegionJobRetryer(workerCtx, jobToWorkerCh, &jobWg) @@ -1475,16 +1476,17 @@ func (local *Backend) doImport(ctx context.Context, engine common.Engine, region } }() - if _, _err_ := failpoint.Eval(_curpkg_("skipStartWorker")); _err_ == nil { - goto afterStartWorker - } + failpoint.Inject("skipStartWorker", func() { + failpoint.Goto("afterStartWorker") + }) for i := 0; i < local.WorkerConcurrency; i++ { workGroup.Go(func() error { return local.startWorker(workerCtx, jobToWorkerCh, jobFromWorkerCh, &jobWg) }) } -afterStartWorker: + + failpoint.Label("afterStartWorker") workGroup.Go(func() error { err := local.prepareAndSendJob( diff --git a/pkg/lightning/backend/local/local.go__failpoint_stash__ b/pkg/lightning/backend/local/local.go__failpoint_stash__ deleted file mode 100644 index c38a01936e504..0000000000000 --- a/pkg/lightning/backend/local/local.go__failpoint_stash__ +++ /dev/null @@ -1,1754 +0,0 @@ -// Copyright 2020 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package local - -import ( - "bytes" - "context" - "database/sql" - "encoding/hex" - "io" - "math" - "net" - "path/filepath" - "strings" - "sync" - "time" - - "github.com/coreos/go-semver/semver" - "github.com/docker/go-units" - "github.com/google/uuid" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - sst "github.com/pingcap/kvproto/pkg/import_sstpb" - "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/tidb/br/pkg/logutil" - "github.com/pingcap/tidb/br/pkg/membuf" - "github.com/pingcap/tidb/br/pkg/pdutil" - "github.com/pingcap/tidb/br/pkg/restore/split" - "github.com/pingcap/tidb/br/pkg/version" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/lightning/backend" - "github.com/pingcap/tidb/pkg/lightning/backend/encode" - "github.com/pingcap/tidb/pkg/lightning/backend/external" - "github.com/pingcap/tidb/pkg/lightning/backend/kv" - "github.com/pingcap/tidb/pkg/lightning/common" - "github.com/pingcap/tidb/pkg/lightning/config" - "github.com/pingcap/tidb/pkg/lightning/errormanager" - "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/lightning/metric" - "github.com/pingcap/tidb/pkg/lightning/tikv" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/engine" - tikvclient "github.com/tikv/client-go/v2/tikv" - pd "github.com/tikv/pd/client" - pdhttp "github.com/tikv/pd/client/http" - "github.com/tikv/pd/client/retry" - "go.uber.org/zap" - "google.golang.org/grpc" - "google.golang.org/grpc/backoff" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/keepalive" - "google.golang.org/grpc/status" -) - -const ( - dialTimeout = 5 * time.Minute - maxRetryTimes = 20 - 
defaultRetryBackoffTime = 3 * time.Second - // maxWriteAndIngestRetryTimes is the max retry times for write and ingest. - // A large retry times is for tolerating tikv cluster failures. - maxWriteAndIngestRetryTimes = 30 - - gRPCKeepAliveTime = 10 * time.Minute - gRPCKeepAliveTimeout = 5 * time.Minute - gRPCBackOffMaxDelay = 10 * time.Minute - - // The max ranges count in a batch to split and scatter. - maxBatchSplitRanges = 4096 - - propRangeIndex = "tikv.range_index" - - defaultPropSizeIndexDistance = 4 * units.MiB - defaultPropKeysIndexDistance = 40 * 1024 - - // the lower threshold of max open files for pebble db. - openFilesLowerThreshold = 128 - - duplicateDBName = "duplicates" - scanRegionLimit = 128 -) - -var ( - // Local backend is compatible with TiDB [4.0.0, NextMajorVersion). - localMinTiDBVersion = *semver.New("4.0.0") - localMinTiKVVersion = *semver.New("4.0.0") - localMinPDVersion = *semver.New("4.0.0") - localMaxTiDBVersion = version.NextMajorVersion() - localMaxTiKVVersion = version.NextMajorVersion() - localMaxPDVersion = version.NextMajorVersion() - tiFlashMinVersion = *semver.New("4.0.5") - tikvSideFreeSpaceCheck = *semver.New("8.0.0") - - errorEngineClosed = errors.New("engine is closed") - maxRetryBackoffSecond = 30 -) - -// ImportClientFactory is factory to create new import client for specific store. -type ImportClientFactory interface { - Create(ctx context.Context, storeID uint64) (sst.ImportSSTClient, error) - Close() -} - -type importClientFactoryImpl struct { - conns *common.GRPCConns - splitCli split.SplitClient - tls *common.TLS - tcpConcurrency int - compressionType config.CompressionType -} - -func newImportClientFactoryImpl( - splitCli split.SplitClient, - tls *common.TLS, - tcpConcurrency int, - compressionType config.CompressionType, -) *importClientFactoryImpl { - return &importClientFactoryImpl{ - conns: common.NewGRPCConns(), - splitCli: splitCli, - tls: tls, - tcpConcurrency: tcpConcurrency, - compressionType: compressionType, - } -} - -func (f *importClientFactoryImpl) makeConn(ctx context.Context, storeID uint64) (*grpc.ClientConn, error) { - store, err := f.splitCli.GetStore(ctx, storeID) - if err != nil { - return nil, errors.Trace(err) - } - var opts []grpc.DialOption - if f.tls.TLSConfig() != nil { - opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(f.tls.TLSConfig()))) - } else { - opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) - } - ctx, cancel := context.WithTimeout(ctx, dialTimeout) - defer cancel() - - bfConf := backoff.DefaultConfig - bfConf.MaxDelay = gRPCBackOffMaxDelay - // we should use peer address for tiflash. for tikv, peer address is empty - addr := store.GetPeerAddress() - if addr == "" { - addr = store.GetAddress() - } - opts = append(opts, - grpc.WithConnectParams(grpc.ConnectParams{Backoff: bfConf}), - grpc.WithKeepaliveParams(keepalive.ClientParameters{ - Time: gRPCKeepAliveTime, - Timeout: gRPCKeepAliveTimeout, - PermitWithoutStream: true, - }), - ) - switch f.compressionType { - case config.CompressionNone: - // do nothing - case config.CompressionGzip: - // Use custom compressor/decompressor to speed up compression/decompression. - // Note that here we don't use grpc.UseCompressor option although it's the recommended way. - // Because gprc-go uses a global registry to store compressor/decompressor, we can't make sure - // the compressor/decompressor is not registered by other components. 
- opts = append(opts, grpc.WithCompressor(&gzipCompressor{}), grpc.WithDecompressor(&gzipDecompressor{})) - default: - return nil, common.ErrInvalidConfig.GenWithStack("unsupported compression type %s", f.compressionType) - } - - failpoint.Inject("LoggingImportBytes", func() { - opts = append(opts, grpc.WithContextDialer(func(ctx context.Context, target string) (net.Conn, error) { - conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", target) - if err != nil { - return nil, err - } - return &loggingConn{Conn: conn}, nil - })) - }) - - conn, err := grpc.DialContext(ctx, addr, opts...) - if err != nil { - return nil, errors.Trace(err) - } - return conn, nil -} - -func (f *importClientFactoryImpl) getGrpcConn(ctx context.Context, storeID uint64) (*grpc.ClientConn, error) { - return f.conns.GetGrpcConn(ctx, storeID, f.tcpConcurrency, - func(ctx context.Context) (*grpc.ClientConn, error) { - return f.makeConn(ctx, storeID) - }) -} - -// Create creates a new import client for specific store. -func (f *importClientFactoryImpl) Create(ctx context.Context, storeID uint64) (sst.ImportSSTClient, error) { - conn, err := f.getGrpcConn(ctx, storeID) - if err != nil { - return nil, err - } - return sst.NewImportSSTClient(conn), nil -} - -// Close closes the factory. -func (f *importClientFactoryImpl) Close() { - f.conns.Close() -} - -type loggingConn struct { - net.Conn -} - -// Write implements net.Conn.Write -func (c loggingConn) Write(b []byte) (int, error) { - log.L().Debug("import write", zap.Int("bytes", len(b))) - return c.Conn.Write(b) -} - -type encodingBuilder struct { - metrics *metric.Metrics -} - -// NewEncodingBuilder creates an KVEncodingBuilder with local backend implementation. -func NewEncodingBuilder(ctx context.Context) encode.EncodingBuilder { - result := new(encodingBuilder) - if m, ok := metric.FromContext(ctx); ok { - result.metrics = m - } - return result -} - -// NewEncoder creates a KV encoder. -// It implements the `backend.EncodingBuilder` interface. -func (b *encodingBuilder) NewEncoder(_ context.Context, config *encode.EncodingConfig) (encode.Encoder, error) { - return kv.NewTableKVEncoder(config, b.metrics) -} - -// MakeEmptyRows creates an empty KV rows. -// It implements the `backend.EncodingBuilder` interface. -func (*encodingBuilder) MakeEmptyRows() encode.Rows { - return kv.MakeRowsFromKvPairs(nil) -} - -type targetInfoGetter struct { - tls *common.TLS - targetDB *sql.DB - pdHTTPCli pdhttp.Client -} - -// NewTargetInfoGetter creates an TargetInfoGetter with local backend -// implementation. `pdHTTPCli` should not be nil when need to check component -// versions in CheckRequirements. -func NewTargetInfoGetter( - tls *common.TLS, - db *sql.DB, - pdHTTPCli pdhttp.Client, -) backend.TargetInfoGetter { - return &targetInfoGetter{ - tls: tls, - targetDB: db, - pdHTTPCli: pdHTTPCli, - } -} - -// FetchRemoteDBModels implements the `backend.TargetInfoGetter` interface. -func (g *targetInfoGetter) FetchRemoteDBModels(ctx context.Context) ([]*model.DBInfo, error) { - return tikv.FetchRemoteDBModelsFromTLS(ctx, g.tls) -} - -// FetchRemoteTableModels obtains the models of all tables given the schema name. -// It implements the `TargetInfoGetter` interface. -func (g *targetInfoGetter) FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) { - return tikv.FetchRemoteTableModelsFromTLS(ctx, g.tls, schemaName) -} - -// CheckRequirements performs the check whether the backend satisfies the version requirements. 
-// It implements the `TargetInfoGetter` interface. -func (g *targetInfoGetter) CheckRequirements(ctx context.Context, checkCtx *backend.CheckCtx) error { - // TODO: support lightning via SQL - versionStr, err := version.FetchVersion(ctx, g.targetDB) - if err != nil { - return errors.Trace(err) - } - if err := checkTiDBVersion(ctx, versionStr, localMinTiDBVersion, localMaxTiDBVersion); err != nil { - return err - } - if g.pdHTTPCli == nil { - return common.ErrUnknown.GenWithStack("pd HTTP client is required for component version check in local backend") - } - if err := tikv.CheckPDVersion(ctx, g.pdHTTPCli, localMinPDVersion, localMaxPDVersion); err != nil { - return err - } - if err := tikv.CheckTiKVVersion(ctx, g.pdHTTPCli, localMinTiKVVersion, localMaxTiKVVersion); err != nil { - return err - } - - serverInfo := version.ParseServerInfo(versionStr) - return checkTiFlashVersion(ctx, g.targetDB, checkCtx, *serverInfo.ServerVersion) -} - -func checkTiDBVersion(_ context.Context, versionStr string, requiredMinVersion, requiredMaxVersion semver.Version) error { - return version.CheckTiDBVersion(versionStr, requiredMinVersion, requiredMaxVersion) -} - -var tiFlashReplicaQuery = "SELECT TABLE_SCHEMA, TABLE_NAME FROM information_schema.TIFLASH_REPLICA WHERE REPLICA_COUNT > 0;" - -// TiFlashReplicaQueryForTest is only used for tests. -var TiFlashReplicaQueryForTest = tiFlashReplicaQuery - -type tblName struct { - schema string - name string -} - -type tblNames []tblName - -// String implements fmt.Stringer -func (t tblNames) String() string { - var b strings.Builder - b.WriteByte('[') - for i, n := range t { - if i > 0 { - b.WriteString(", ") - } - b.WriteString(common.UniqueTable(n.schema, n.name)) - } - b.WriteByte(']') - return b.String() -} - -// CheckTiFlashVersionForTest is only used for tests. -var CheckTiFlashVersionForTest = checkTiFlashVersion - -// check TiFlash replicas. -// local backend doesn't support TiFlash before tidb v4.0.5 -func checkTiFlashVersion(ctx context.Context, db *sql.DB, checkCtx *backend.CheckCtx, tidbVersion semver.Version) error { - if tidbVersion.Compare(tiFlashMinVersion) >= 0 { - return nil - } - - exec := common.SQLWithRetry{ - DB: db, - Logger: log.FromContext(ctx), - } - - res, err := exec.QueryStringRows(ctx, "fetch tiflash replica info", tiFlashReplicaQuery) - if err != nil { - return errors.Annotate(err, "fetch tiflash replica info failed") - } - - tiFlashTablesMap := make(map[tblName]struct{}, len(res)) - for _, tblInfo := range res { - name := tblName{schema: tblInfo[0], name: tblInfo[1]} - tiFlashTablesMap[name] = struct{}{} - } - - tiFlashTables := make(tblNames, 0) - for _, dbMeta := range checkCtx.DBMetas { - for _, tblMeta := range dbMeta.Tables { - if len(tblMeta.DataFiles) == 0 { - continue - } - name := tblName{schema: tblMeta.DB, name: tblMeta.Name} - if _, ok := tiFlashTablesMap[name]; ok { - tiFlashTables = append(tiFlashTables, name) - } - } - } - - if len(tiFlashTables) > 0 { - helpInfo := "Please either upgrade TiDB to version >= 4.0.5 or add TiFlash replica after load data." - return errors.Errorf("lightning local backend doesn't support TiFlash in this TiDB version. conflict tables: %s. "+helpInfo, tiFlashTables) - } - return nil -} - -// BackendConfig is the config for local backend. -type BackendConfig struct { - // comma separated list of PD endpoints. - PDAddr string - LocalStoreDir string - // max number of cached grpc.ClientConn to a store. 
- // note: this is not the limit of actual connections, each grpc.ClientConn can have one or more of it. - MaxConnPerStore int - // compress type when write or ingest into tikv - ConnCompressType config.CompressionType - // concurrency of generateJobForRange and import(write & ingest) workers - WorkerConcurrency int - // batch kv size when writing to TiKV - KVWriteBatchSize int64 - RegionSplitBatchSize int - RegionSplitConcurrency int - CheckpointEnabled bool - // memory table size of pebble. since pebble can have multiple mem tables, the max memory used is - // MemTableSize * MemTableStopWritesThreshold, see pebble.Options for more details. - MemTableSize int - // LocalWriterMemCacheSize is the memory threshold for one local writer of - // engines. If the KV payload size exceeds LocalWriterMemCacheSize, local writer - // will flush them into the engine. - // - // It has lower priority than LocalWriterConfig.Local.MemCacheSize. - LocalWriterMemCacheSize int64 - // whether check TiKV capacity before write & ingest. - ShouldCheckTiKV bool - DupeDetectEnabled bool - DuplicateDetectOpt common.DupDetectOpt - // max write speed in bytes per second to each store(burst is allowed), 0 means no limit - StoreWriteBWLimit int - // When TiKV is in normal mode, ingesting too many SSTs will cause TiKV write stall. - // To avoid this, we should check write stall before ingesting SSTs. Note that, we - // must check both leader node and followers in client side, because followers will - // not check write stall as long as ingest command is accepted by leader. - ShouldCheckWriteStall bool - // soft limit on the number of open files that can be used by pebble DB. - // the minimum value is 128. - MaxOpenFiles int - KeyspaceName string - // the scope when pause PD schedulers. - PausePDSchedulerScope config.PausePDSchedulerScope - ResourceGroupName string - TaskType string - RaftKV2SwitchModeDuration time.Duration - // whether disable automatic compactions of pebble db of engine. - // deduplicate pebble db is not affected by this option. - // see DisableAutomaticCompactions of pebble.Options for more details. - // default true. - DisableAutomaticCompactions bool - BlockSize int -} - -// NewBackendConfig creates a new BackendConfig. 
-func NewBackendConfig(cfg *config.Config, maxOpenFiles int, keyspaceName, resourceGroupName, taskType string, raftKV2SwitchModeDuration time.Duration) BackendConfig { - return BackendConfig{ - PDAddr: cfg.TiDB.PdAddr, - LocalStoreDir: cfg.TikvImporter.SortedKVDir, - MaxConnPerStore: cfg.TikvImporter.RangeConcurrency, - ConnCompressType: cfg.TikvImporter.CompressKVPairs, - WorkerConcurrency: cfg.TikvImporter.RangeConcurrency * 2, - BlockSize: int(cfg.TikvImporter.BlockSize), - KVWriteBatchSize: int64(cfg.TikvImporter.SendKVSize), - RegionSplitBatchSize: cfg.TikvImporter.RegionSplitBatchSize, - RegionSplitConcurrency: cfg.TikvImporter.RegionSplitConcurrency, - CheckpointEnabled: cfg.Checkpoint.Enable, - MemTableSize: int(cfg.TikvImporter.EngineMemCacheSize), - LocalWriterMemCacheSize: int64(cfg.TikvImporter.LocalWriterMemCacheSize), - ShouldCheckTiKV: cfg.App.CheckRequirements, - DupeDetectEnabled: cfg.Conflict.Strategy != config.NoneOnDup, - DuplicateDetectOpt: common.DupDetectOpt{ReportErrOnDup: cfg.Conflict.Strategy == config.ErrorOnDup}, - StoreWriteBWLimit: int(cfg.TikvImporter.StoreWriteBWLimit), - ShouldCheckWriteStall: cfg.Cron.SwitchMode.Duration == 0, - MaxOpenFiles: maxOpenFiles, - KeyspaceName: keyspaceName, - PausePDSchedulerScope: cfg.TikvImporter.PausePDSchedulerScope, - ResourceGroupName: resourceGroupName, - TaskType: taskType, - RaftKV2SwitchModeDuration: raftKV2SwitchModeDuration, - DisableAutomaticCompactions: true, - } -} - -func (c *BackendConfig) adjust() { - c.MaxOpenFiles = max(c.MaxOpenFiles, openFilesLowerThreshold) -} - -// Backend is a local backend. -type Backend struct { - pdCli pd.Client - pdHTTPCli pdhttp.Client - splitCli split.SplitClient - tikvCli *tikvclient.KVStore - tls *common.TLS - tikvCodec tikvclient.Codec - - BackendConfig - engineMgr *engineManager - - supportMultiIngest bool - importClientFactory ImportClientFactory - - metrics *metric.Common - writeLimiter StoreWriteLimiter - logger log.Logger - // This mutex is used to do some mutual exclusion work in the backend, flushKVs() in writer for now. - mu sync.Mutex -} - -var _ DiskUsage = (*Backend)(nil) -var _ StoreHelper = (*Backend)(nil) -var _ backend.Backend = (*Backend)(nil) - -const ( - pdCliMaxMsgSize = int(128 * units.MiB) // pd.ScanRegion may return a large response -) - -var ( - maxCallMsgSize = []grpc.DialOption{ - grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(pdCliMaxMsgSize)), - grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(pdCliMaxMsgSize)), - } -) - -// NewBackend creates new connections to tikv. -func NewBackend( - ctx context.Context, - tls *common.TLS, - config BackendConfig, - pdSvcDiscovery pd.ServiceDiscovery, -) (b *Backend, err error) { - var ( - pdCli pd.Client - spkv *tikvclient.EtcdSafePointKV - pdCliForTiKV *tikvclient.CodecPDClient - rpcCli tikvclient.Client - tikvCli *tikvclient.KVStore - pdHTTPCli pdhttp.Client - importClientFactory *importClientFactoryImpl - multiIngestSupported bool - ) - defer func() { - if err == nil { - return - } - if importClientFactory != nil { - importClientFactory.Close() - } - if pdHTTPCli != nil { - pdHTTPCli.Close() - } - if tikvCli != nil { - // tikvCli uses pdCliForTiKV(which wraps pdCli) , spkv and rpcCli, so - // close tikvCli will close all of them. 
- _ = tikvCli.Close() - } else { - if rpcCli != nil { - _ = rpcCli.Close() - } - if spkv != nil { - _ = spkv.Close() - } - // pdCliForTiKV wraps pdCli, so we only need close pdCli - if pdCli != nil { - pdCli.Close() - } - } - }() - config.adjust() - var pdAddrs []string - if pdSvcDiscovery != nil { - pdAddrs = pdSvcDiscovery.GetServiceURLs() - // TODO(lance6716): if PD client can support creating a client with external - // service discovery, we can directly pass pdSvcDiscovery. - } else { - pdAddrs = strings.Split(config.PDAddr, ",") - } - pdCli, err = pd.NewClientWithContext( - ctx, pdAddrs, tls.ToPDSecurityOption(), - pd.WithGRPCDialOptions(maxCallMsgSize...), - // If the time too short, we may scatter a region many times, because - // the interface `ScatterRegions` may time out. - pd.WithCustomTimeoutOption(60*time.Second), - ) - if err != nil { - return nil, common.NormalizeOrWrapErr(common.ErrCreatePDClient, err) - } - - // The following copies tikv.NewTxnClient without creating yet another pdClient. - spkv, err = tikvclient.NewEtcdSafePointKV(strings.Split(config.PDAddr, ","), tls.TLSConfig()) - if err != nil { - return nil, common.ErrCreateKVClient.Wrap(err).GenWithStackByArgs() - } - - if config.KeyspaceName == "" { - pdCliForTiKV = tikvclient.NewCodecPDClient(tikvclient.ModeTxn, pdCli) - } else { - pdCliForTiKV, err = tikvclient.NewCodecPDClientWithKeyspace(tikvclient.ModeTxn, pdCli, config.KeyspaceName) - if err != nil { - return nil, common.ErrCreatePDClient.Wrap(err).GenWithStackByArgs() - } - } - - tikvCodec := pdCliForTiKV.GetCodec() - rpcCli = tikvclient.NewRPCClient(tikvclient.WithSecurity(tls.ToTiKVSecurityConfig()), tikvclient.WithCodec(tikvCodec)) - tikvCli, err = tikvclient.NewKVStore("lightning-local-backend", pdCliForTiKV, spkv, rpcCli) - if err != nil { - return nil, common.ErrCreateKVClient.Wrap(err).GenWithStackByArgs() - } - pdHTTPCli = pdhttp.NewClientWithServiceDiscovery( - "lightning", - pdCli.GetServiceDiscovery(), - pdhttp.WithTLSConfig(tls.TLSConfig()), - ).WithBackoffer(retry.InitialBackoffer(time.Second, time.Second, pdutil.PDRequestRetryTime*time.Second)) - splitCli := split.NewClient(pdCli, pdHTTPCli, tls.TLSConfig(), config.RegionSplitBatchSize, config.RegionSplitConcurrency) - importClientFactory = newImportClientFactoryImpl(splitCli, tls, config.MaxConnPerStore, config.ConnCompressType) - - multiIngestSupported, err = checkMultiIngestSupport(ctx, pdCli, importClientFactory) - if err != nil { - return nil, common.ErrCheckMultiIngest.Wrap(err).GenWithStackByArgs() - } - - var writeLimiter StoreWriteLimiter - if config.StoreWriteBWLimit > 0 { - writeLimiter = newStoreWriteLimiter(config.StoreWriteBWLimit) - } else { - writeLimiter = noopStoreWriteLimiter{} - } - local := &Backend{ - pdCli: pdCli, - pdHTTPCli: pdHTTPCli, - splitCli: splitCli, - tikvCli: tikvCli, - tls: tls, - tikvCodec: tikvCodec, - - BackendConfig: config, - - supportMultiIngest: multiIngestSupported, - importClientFactory: importClientFactory, - writeLimiter: writeLimiter, - logger: log.FromContext(ctx), - } - local.engineMgr, err = newEngineManager(config, local, local.logger) - if err != nil { - return nil, err - } - if m, ok := metric.GetCommonMetric(ctx); ok { - local.metrics = m - } - local.tikvSideCheckFreeSpace(ctx) - - return local, nil -} - -// NewBackendForTest creates a new Backend for test. 
-func NewBackendForTest(ctx context.Context, config BackendConfig, storeHelper StoreHelper) (*Backend, error) { - config.adjust() - - logger := log.FromContext(ctx) - engineMgr, err := newEngineManager(config, storeHelper, logger) - if err != nil { - return nil, err - } - local := &Backend{ - BackendConfig: config, - logger: logger, - engineMgr: engineMgr, - } - if m, ok := metric.GetCommonMetric(ctx); ok { - local.metrics = m - } - - return local, nil -} - -// TotalMemoryConsume returns the total memory usage of the local backend. -func (local *Backend) TotalMemoryConsume() int64 { - return local.engineMgr.totalMemoryConsume() -} - -func checkMultiIngestSupport(ctx context.Context, pdCli pd.Client, importClientFactory ImportClientFactory) (bool, error) { - stores, err := pdCli.GetAllStores(ctx, pd.WithExcludeTombstone()) - if err != nil { - return false, errors.Trace(err) - } - - hasTiFlash := false - for _, s := range stores { - if s.State == metapb.StoreState_Up && engine.IsTiFlash(s) { - hasTiFlash = true - break - } - } - - for _, s := range stores { - // skip stores that are not online - if s.State != metapb.StoreState_Up || engine.IsTiFlash(s) { - continue - } - var err error - for i := 0; i < maxRetryTimes; i++ { - if i > 0 { - select { - case <-time.After(100 * time.Millisecond): - case <-ctx.Done(): - return false, ctx.Err() - } - } - client, err1 := importClientFactory.Create(ctx, s.Id) - if err1 != nil { - err = err1 - log.FromContext(ctx).Warn("get import client failed", zap.Error(err), zap.String("store", s.Address)) - continue - } - _, err = client.MultiIngest(ctx, &sst.MultiIngestRequest{}) - if err == nil { - break - } - if st, ok := status.FromError(err); ok { - if st.Code() == codes.Unimplemented { - log.FromContext(ctx).Info("multi ingest not support", zap.Any("unsupported store", s)) - return false, nil - } - } - log.FromContext(ctx).Warn("check multi ingest support failed", zap.Error(err), zap.String("store", s.Address), - zap.Int("retry", i)) - } - if err != nil { - // if the cluster contains no TiFlash store, we don't need the multi-ingest feature, - // so in this condition, downgrade the logic instead of return an error. - if hasTiFlash { - return false, errors.Trace(err) - } - log.FromContext(ctx).Warn("check multi failed all retry, fallback to false", log.ShortError(err)) - return false, nil - } - } - - log.FromContext(ctx).Info("multi ingest support") - return true, nil -} - -func (local *Backend) tikvSideCheckFreeSpace(ctx context.Context) { - if !local.ShouldCheckTiKV { - return - } - err := tikv.ForTiKVVersions( - ctx, - local.pdHTTPCli, - func(version *semver.Version, addrMsg string) error { - if version.Compare(tikvSideFreeSpaceCheck) < 0 { - return errors.Errorf( - "%s has version %s, it does not support server side free space check", - addrMsg, version, - ) - } - return nil - }, - ) - if err == nil { - local.logger.Info("TiKV server side free space check is enabled, so lightning will turn it off") - local.ShouldCheckTiKV = false - } else { - local.logger.Info("", zap.Error(err)) - } -} - -// Close the local backend. 
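[Editor's sketch] The capability probe in checkMultiIngestSupport above treats a gRPC Unimplemented status as a definitive "no MultiIngest" answer and everything else as a transient failure worth retrying. A minimal, self-contained illustration of that classification, assuming only google.golang.org/grpc; the helper name classifyMultiIngestProbe is illustrative, not part of the patch:

package main

import (
	"errors"
	"fmt"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// classifyMultiIngestProbe mirrors the decision above: codes.Unimplemented
// means the store definitely lacks MultiIngest, any other error is
// inconclusive and may be retried or downgraded.
func classifyMultiIngestProbe(err error) (supported, conclusive bool) {
	if err == nil {
		return true, true
	}
	if st, ok := status.FromError(err); ok && st.Code() == codes.Unimplemented {
		return false, true
	}
	return false, false
}

func main() {
	probeErr := status.Error(codes.Unimplemented, "unknown service sst.ImportSST")
	fmt.Println(classifyMultiIngestProbe(probeErr))                  // false true
	fmt.Println(classifyMultiIngestProbe(errors.New("conn refused"))) // false false
}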
-func (local *Backend) Close() { - local.engineMgr.close() - local.importClientFactory.Close() - - _ = local.tikvCli.Close() - local.pdHTTPCli.Close() - local.pdCli.Close() -} - -// FlushEngine ensure the written data is saved successfully, to make sure no data lose after restart -func (local *Backend) FlushEngine(ctx context.Context, engineID uuid.UUID) error { - return local.engineMgr.flushEngine(ctx, engineID) -} - -// FlushAllEngines flush all engines. -func (local *Backend) FlushAllEngines(parentCtx context.Context) (err error) { - return local.engineMgr.flushAllEngines(parentCtx) -} - -// RetryImportDelay returns the delay time before retrying to import a file. -func (*Backend) RetryImportDelay() time.Duration { - return defaultRetryBackoffTime -} - -// ShouldPostProcess returns true if the backend should post process the data. -func (*Backend) ShouldPostProcess() bool { - return true -} - -// OpenEngine must be called with holding mutex of Engine. -func (local *Backend) OpenEngine(ctx context.Context, cfg *backend.EngineConfig, engineUUID uuid.UUID) error { - return local.engineMgr.openEngine(ctx, cfg, engineUUID) -} - -// CloseEngine closes backend engine by uuid. -func (local *Backend) CloseEngine(ctx context.Context, cfg *backend.EngineConfig, engineUUID uuid.UUID) error { - return local.engineMgr.closeEngine(ctx, cfg, engineUUID) -} - -func (local *Backend) getImportClient(ctx context.Context, storeID uint64) (sst.ImportSSTClient, error) { - return local.importClientFactory.Create(ctx, storeID) -} - -func splitRangeBySizeProps(fullRange common.Range, sizeProps *sizeProperties, sizeLimit int64, keysLimit int64) []common.Range { - ranges := make([]common.Range, 0, sizeProps.totalSize/uint64(sizeLimit)) - curSize := uint64(0) - curKeys := uint64(0) - curKey := fullRange.Start - - sizeProps.iter(func(p *rangeProperty) bool { - if bytes.Compare(p.Key, curKey) <= 0 { - return true - } - if bytes.Compare(p.Key, fullRange.End) > 0 { - return false - } - curSize += p.Size - curKeys += p.Keys - if int64(curSize) >= sizeLimit || int64(curKeys) >= keysLimit { - ranges = append(ranges, common.Range{Start: curKey, End: p.Key}) - curKey = p.Key - curSize = 0 - curKeys = 0 - } - return true - }) - - if bytes.Compare(curKey, fullRange.End) < 0 { - // If the remaining range is too small, append it to last range. 
- if len(ranges) > 0 && curKeys == 0 { - ranges[len(ranges)-1].End = fullRange.End - } else { - ranges = append(ranges, common.Range{Start: curKey, End: fullRange.End}) - } - } - return ranges -} - -func readAndSplitIntoRange( - ctx context.Context, - engine common.Engine, - sizeLimit int64, - keysLimit int64, -) ([]common.Range, error) { - startKey, endKey, err := engine.GetKeyRange() - if err != nil { - return nil, err - } - if startKey == nil { - return nil, errors.New("could not find first pair") - } - - engineFileTotalSize, engineFileLength := engine.KVStatistics() - - if engineFileTotalSize <= sizeLimit && engineFileLength <= keysLimit { - ranges := []common.Range{{Start: startKey, End: endKey}} - return ranges, nil - } - - logger := log.FromContext(ctx).With(zap.String("engine", engine.ID())) - ranges, err := engine.SplitRanges(startKey, endKey, sizeLimit, keysLimit, logger) - logger.Info("split engine key ranges", - zap.Int64("totalSize", engineFileTotalSize), zap.Int64("totalCount", engineFileLength), - logutil.Key("startKey", startKey), logutil.Key("endKey", endKey), - zap.Int("ranges", len(ranges)), zap.Error(err)) - return ranges, err -} - -// prepareAndSendJob will read the engine to get estimated key range, -// then split and scatter regions for these range and send region jobs to jobToWorkerCh. -// NOTE when ctx is Done, this function will NOT return error even if it hasn't sent -// all the jobs to jobToWorkerCh. This is because the "first error" can only be -// found by checking the work group LATER, we don't want to return an error to -// seize the "first" error. -func (local *Backend) prepareAndSendJob( - ctx context.Context, - engine common.Engine, - initialSplitRanges []common.Range, - regionSplitSize, regionSplitKeys int64, - jobToWorkerCh chan<- *regionJob, - jobWg *sync.WaitGroup, -) error { - lfTotalSize, lfLength := engine.KVStatistics() - log.FromContext(ctx).Info("import engine ranges", zap.Int("count", len(initialSplitRanges))) - if len(initialSplitRanges) == 0 { - return nil - } - - // if all the kv can fit in one region, skip split regions. TiDB will split one region for - // the table when table is created. - needSplit := len(initialSplitRanges) > 1 || lfTotalSize > regionSplitSize || lfLength > regionSplitKeys - // split region by given ranges - failpoint.Inject("failToSplit", func(_ failpoint.Value) { - needSplit = true - }) - if needSplit { - var err error - logger := log.FromContext(ctx).With(zap.String("uuid", engine.ID())).Begin(zap.InfoLevel, "split and scatter ranges") - backOffTime := 10 * time.Second - maxbackoffTime := 120 * time.Second - for i := 0; i < maxRetryTimes; i++ { - failpoint.Inject("skipSplitAndScatter", func() { - failpoint.Break() - }) - - err = local.SplitAndScatterRegionInBatches(ctx, initialSplitRanges, maxBatchSplitRanges) - if err == nil || common.IsContextCanceledError(err) { - break - } - - log.FromContext(ctx).Warn("split and scatter failed in retry", zap.String("engine ID", engine.ID()), - log.ShortError(err), zap.Int("retry", i)) - select { - case <-time.After(backOffTime): - case <-ctx.Done(): - return ctx.Err() - } - backOffTime *= 2 - if backOffTime > maxbackoffTime { - backOffTime = maxbackoffTime - } - } - logger.End(zap.ErrorLevel, err) - if err != nil { - return err - } - } - - return local.generateAndSendJob( - ctx, - engine, - initialSplitRanges, - regionSplitSize, - regionSplitKeys, - jobToWorkerCh, - jobWg, - ) -} - -// generateAndSendJob scans the region in ranges and send region jobs to jobToWorkerCh. 
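[Editor's sketch] prepareAndSendJob above retries split-and-scatter with a doubling backoff that starts at 10s and is capped at 120s. A standalone sketch of that policy using only the standard library; the attempt count passed in main is a hypothetical stand-in for the package's maxRetryTimes constant:

package main

import (
	"fmt"
	"time"
)

// backoffSchedule reproduces the waits of the split-and-scatter retry loop:
// start at 10s, double after each failure, never exceed 120s.
func backoffSchedule(attempts int) []time.Duration {
	const (
		initial = 10 * time.Second
		maxWait = 120 * time.Second
	)
	waits := make([]time.Duration, 0, attempts)
	cur := initial
	for i := 0; i < attempts; i++ {
		waits = append(waits, cur)
		cur *= 2
		if cur > maxWait {
			cur = maxWait
		}
	}
	return waits
}

func main() {
	// With a hypothetical budget of 6 attempts: 10s 20s 40s 80s 2m0s 2m0s.
	fmt.Println(backoffSchedule(6))
}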
-func (local *Backend) generateAndSendJob( - ctx context.Context, - engine common.Engine, - jobRanges []common.Range, - regionSplitSize, regionSplitKeys int64, - jobToWorkerCh chan<- *regionJob, - jobWg *sync.WaitGroup, -) error { - logger := log.FromContext(ctx) - // for external engine, it will split into smaller data inside LoadIngestData - if localEngine, ok := engine.(*Engine); ok { - // when use dynamic region feature, the region may be very big, we need - // to split to smaller ranges to increase the concurrency. - if regionSplitSize > 2*int64(config.SplitRegionSize) { - start := jobRanges[0].Start - end := jobRanges[len(jobRanges)-1].End - sizeLimit := int64(config.SplitRegionSize) - keysLimit := int64(config.SplitRegionKeys) - jrs, err := localEngine.SplitRanges(start, end, sizeLimit, keysLimit, logger) - if err != nil { - return errors.Trace(err) - } - jobRanges = jrs - } - } - - logger.Debug("the ranges length write to tikv", zap.Int("length", len(jobRanges))) - - eg, egCtx := util.NewErrorGroupWithRecoverWithCtx(ctx) - - dataAndRangeCh := make(chan common.DataAndRange) - conn := local.WorkerConcurrency - if _, ok := engine.(*external.Engine); ok { - // currently external engine will generate a large IngestData, se we lower the - // concurrency to pass backpressure to the LoadIngestData goroutine to avoid OOM - conn = 1 - } - for i := 0; i < conn; i++ { - eg.Go(func() error { - for { - select { - case <-egCtx.Done(): - return nil - case p, ok := <-dataAndRangeCh: - if !ok { - return nil - } - - failpoint.Inject("beforeGenerateJob", nil) - failpoint.Inject("sendDummyJob", func(_ failpoint.Value) { - // this is used to trigger worker failure, used together - // with WriteToTiKVNotEnoughDiskSpace - jobToWorkerCh <- ®ionJob{} - time.Sleep(5 * time.Second) - }) - jobs, err := local.generateJobForRange(egCtx, p.Data, p.Range, regionSplitSize, regionSplitKeys) - if err != nil { - if common.IsContextCanceledError(err) { - return nil - } - return err - } - for _, job := range jobs { - job.ref(jobWg) - select { - case <-egCtx.Done(): - // this job is not put into jobToWorkerCh - job.done(jobWg) - // if the context is canceled, it means worker has error, the first error can be - // found by worker's error group LATER. if this function returns an error it will - // seize the "first error". - return nil - case jobToWorkerCh <- job: - } - } - } - } - }) - } - - eg.Go(func() error { - err := engine.LoadIngestData(egCtx, jobRanges, dataAndRangeCh) - if err != nil { - return errors.Trace(err) - } - close(dataAndRangeCh) - return nil - }) - - return eg.Wait() -} - -// fakeRegionJobs is used in test, the injected job can be found by (startKey, endKey). -var fakeRegionJobs map[[2]string]struct { - jobs []*regionJob - err error -} - -// generateJobForRange will scan the region in `keyRange` and generate region jobs. -// It will retry internally when scan region meet error. -func (local *Backend) generateJobForRange( - ctx context.Context, - data common.IngestData, - keyRange common.Range, - regionSplitSize, regionSplitKeys int64, -) ([]*regionJob, error) { - failpoint.Inject("fakeRegionJobs", func() { - if ctx.Err() != nil { - failpoint.Return(nil, ctx.Err()) - } - key := [2]string{string(keyRange.Start), string(keyRange.End)} - injected := fakeRegionJobs[key] - // overwrite the stage to regionScanned, because some time same keyRange - // will be generated more than once. 
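[Editor's sketch] generateAndSendJob above fans a stream of data-and-range items out to a bounded set of workers; because the channel is unbuffered, a slow consumer applies backpressure to the producer, which is how the external-engine path avoids OOM. A reduced sketch of that shape using golang.org/x/sync/errgroup (the patch itself uses a TiDB error-group wrapper with panic recovery, so this is only an approximation):

package main

import (
	"context"
	"fmt"

	"golang.org/x/sync/errgroup"
)

func main() {
	eg, ctx := errgroup.WithContext(context.Background())

	// Unbuffered: the producer blocks until a worker is free, which is the
	// backpressure the real code relies on.
	work := make(chan int)

	const workers = 4
	for i := 0; i < workers; i++ {
		eg.Go(func() error {
			for {
				select {
				case <-ctx.Done():
					return nil
				case item, ok := <-work:
					if !ok {
						return nil
					}
					fmt.Println("processed", item)
				}
			}
		})
	}

	// Single producer closes the channel when done, mirroring the
	// LoadIngestData goroutine.
	eg.Go(func() error {
		defer close(work)
		for i := 0; i < 10; i++ {
			select {
			case <-ctx.Done():
				return ctx.Err()
			case work <- i:
			}
		}
		return nil
	})

	if err := eg.Wait(); err != nil {
		fmt.Println("error:", err)
	}
}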
- for _, job := range injected.jobs { - job.stage = regionScanned - } - failpoint.Return(injected.jobs, injected.err) - }) - - start, end := keyRange.Start, keyRange.End - pairStart, pairEnd, err := data.GetFirstAndLastKey(start, end) - if err != nil { - return nil, err - } - if pairStart == nil { - logFn := log.FromContext(ctx).Info - if _, ok := data.(*external.MemoryIngestData); ok { - logFn = log.FromContext(ctx).Warn - } - logFn("There is no pairs in range", - logutil.Key("start", start), - logutil.Key("end", end)) - // trigger cleanup - data.IncRef() - data.DecRef() - return nil, nil - } - - startKey := codec.EncodeBytes([]byte{}, pairStart) - endKey := codec.EncodeBytes([]byte{}, nextKey(pairEnd)) - regions, err := split.PaginateScanRegion(ctx, local.splitCli, startKey, endKey, scanRegionLimit) - if err != nil { - log.FromContext(ctx).Error("scan region failed", - log.ShortError(err), zap.Int("region_len", len(regions)), - logutil.Key("startKey", startKey), - logutil.Key("endKey", endKey)) - return nil, err - } - - jobs := make([]*regionJob, 0, len(regions)) - for _, region := range regions { - log.FromContext(ctx).Debug("get region", - zap.Binary("startKey", startKey), - zap.Binary("endKey", endKey), - zap.Uint64("id", region.Region.GetId()), - zap.Stringer("epoch", region.Region.GetRegionEpoch()), - zap.Binary("start", region.Region.GetStartKey()), - zap.Binary("end", region.Region.GetEndKey()), - zap.Reflect("peers", region.Region.GetPeers())) - - jobs = append(jobs, ®ionJob{ - keyRange: intersectRange(region.Region, common.Range{Start: start, End: end}), - region: region, - stage: regionScanned, - ingestData: data, - regionSplitSize: regionSplitSize, - regionSplitKeys: regionSplitKeys, - metrics: local.metrics, - }) - } - return jobs, nil -} - -// startWorker creates a worker that reads from the job channel and processes. -// startWorker will return nil if it's expected to stop, where the only case is -// the context canceled. It will return not nil error when it actively stops. -// startWorker must Done the jobWg if it does not put the job into jobOutCh. -func (local *Backend) startWorker( - ctx context.Context, - jobInCh, jobOutCh chan *regionJob, - jobWg *sync.WaitGroup, -) error { - metrics.GlobalSortIngestWorkerCnt.WithLabelValues("execute job").Set(0) - for { - select { - case <-ctx.Done(): - return nil - case job, ok := <-jobInCh: - if !ok { - // In fact we don't use close input channel to notify worker to - // exit, because there's a cycle in workflow. - return nil - } - - metrics.GlobalSortIngestWorkerCnt.WithLabelValues("execute job").Inc() - err := local.executeJob(ctx, job) - metrics.GlobalSortIngestWorkerCnt.WithLabelValues("execute job").Dec() - switch job.stage { - case regionScanned, wrote, ingested: - jobOutCh <- job - case needRescan: - jobs, err2 := local.generateJobForRange( - ctx, - job.ingestData, - job.keyRange, - job.regionSplitSize, - job.regionSplitKeys, - ) - if err2 != nil { - // Don't need to put the job back to retry, because generateJobForRange - // has done the retry internally. Here just done for the "needRescan" - // job and exit directly. - job.done(jobWg) - return err2 - } - // 1 "needRescan" job becomes len(jobs) "regionScanned" jobs. 
- newJobCnt := len(jobs) - 1 - for newJobCnt > 0 { - job.ref(jobWg) - newJobCnt-- - } - for _, j := range jobs { - j.lastRetryableErr = job.lastRetryableErr - jobOutCh <- j - } - } - - if err != nil { - return err - } - } - } -} - -func (*Backend) isRetryableImportTiKVError(err error) bool { - err = errors.Cause(err) - // io.EOF is not retryable in normal case - // but on TiKV restart, if we're writing to TiKV(through GRPC) - // it might return io.EOF(it's GRPC Unavailable in most case), - // we need to retry on this error. - // see SendMsg in https://pkg.go.dev/google.golang.org/grpc#ClientStream - if err == io.EOF { - return true - } - return common.IsRetryableError(err) -} - -func checkDiskAvail(ctx context.Context, store *pdhttp.StoreInfo) error { - logger := log.FromContext(ctx) - capacity, err := units.RAMInBytes(store.Status.Capacity) - if err != nil { - logger.Warn("failed to parse capacity", - zap.String("capacity", store.Status.Capacity), zap.Error(err)) - return nil - } - if capacity <= 0 { - // PD will return a zero value StoreInfo if heartbeat is not received after - // startup, skip temporarily. - return nil - } - available, err := units.RAMInBytes(store.Status.Available) - if err != nil { - logger.Warn("failed to parse available", - zap.String("available", store.Status.Available), zap.Error(err)) - return nil - } - ratio := available * 100 / capacity - if ratio < 10 { - storeType := "TiKV" - if engine.IsTiFlashHTTPResp(&store.Store) { - storeType = "TiFlash" - } - return errors.Errorf("the remaining storage capacity of %s(%s) is less than 10%%; please increase the storage capacity of %s and try again", - storeType, store.Store.Address, storeType) - } - return nil -} - -// executeJob handles a regionJob and tries to convert it to ingested stage. -// If non-retryable error occurs, it will return the error. -// If retryable error occurs, it will return nil and caller should check the stage -// of the regionJob to determine what to do with it. -func (local *Backend) executeJob( - ctx context.Context, - job *regionJob, -) error { - failpoint.Inject("WriteToTiKVNotEnoughDiskSpace", func(_ failpoint.Value) { - failpoint.Return( - errors.New("the remaining storage capacity of TiKV is less than 10%%; please increase the storage capacity of TiKV and try again")) - }) - if local.ShouldCheckTiKV { - for _, peer := range job.region.Region.GetPeers() { - store, err := local.pdHTTPCli.GetStore(ctx, peer.StoreId) - if err != nil { - log.FromContext(ctx).Warn("failed to get StoreInfo from pd http api", zap.Error(err)) - continue - } - err = checkDiskAvail(ctx, store) - if err != nil { - return err - } - } - } - - for { - err := local.writeToTiKV(ctx, job) - if err != nil { - if !local.isRetryableImportTiKVError(err) { - return err - } - // if it's retryable error, we retry from scanning region - log.FromContext(ctx).Warn("meet retryable error when writing to TiKV", - log.ShortError(err), zap.Stringer("job stage", job.stage)) - job.lastRetryableErr = err - return nil - } - - err = local.ingest(ctx, job) - if err != nil { - if !local.isRetryableImportTiKVError(err) { - return err - } - log.FromContext(ctx).Warn("meet retryable error when ingesting", - log.ShortError(err), zap.Stringer("job stage", job.stage)) - job.lastRetryableErr = err - return nil - } - // if the job.stage successfully converted into "ingested", it means - // these data are ingested into TiKV so we handle remaining data. - // For other job.stage, the job should be sent back to caller to retry - // later. 
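[Editor's sketch] checkDiskAvail above parses the capacity strings reported by PD and refuses to ingest once available/capacity drops below 10%. A worked example of that arithmetic, assuming only github.com/docker/go-units; the sample sizes are made up:

package main

import (
	"fmt"

	units "github.com/docker/go-units"
)

func main() {
	// Hypothetical values as PD would report them in StoreInfo.Status.
	capacityStr := "500GiB"
	availableStr := "45GiB"

	capacity, err := units.RAMInBytes(capacityStr)
	if err != nil {
		panic(err)
	}
	available, err := units.RAMInBytes(availableStr)
	if err != nil {
		panic(err)
	}

	// Integer percentage, exactly as checkDiskAvail computes it.
	ratio := available * 100 / capacity
	fmt.Println(ratio, ratio < 10) // 9 true -> the import would be refused
}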
- if job.stage != ingested { - return nil - } - - if job.writeResult == nil || job.writeResult.remainingStartKey == nil { - return nil - } - job.keyRange.Start = job.writeResult.remainingStartKey - job.convertStageTo(regionScanned) - } -} - -// ImportEngine imports an engine to TiKV. -func (local *Backend) ImportEngine( - ctx context.Context, - engineUUID uuid.UUID, - regionSplitSize, regionSplitKeys int64, -) error { - var e common.Engine - if externalEngine, ok := local.engineMgr.getExternalEngine(engineUUID); ok { - e = externalEngine - } else { - localEngine := local.engineMgr.lockEngine(engineUUID, importMutexStateImport) - if localEngine == nil { - // skip if engine not exist. See the comment of `CloseEngine` for more detail. - return nil - } - defer localEngine.unlock() - e = localEngine - } - - lfTotalSize, lfLength := e.KVStatistics() - if lfTotalSize == 0 { - // engine is empty, this is likes because it's a index engine but the table contains no index - log.FromContext(ctx).Info("engine contains no kv, skip import", zap.Stringer("engine", engineUUID)) - return nil - } - kvRegionSplitSize, kvRegionSplitKeys, err := GetRegionSplitSizeKeys(ctx, local.pdCli, local.tls) - if err == nil { - if kvRegionSplitSize > regionSplitSize { - regionSplitSize = kvRegionSplitSize - } - if kvRegionSplitKeys > regionSplitKeys { - regionSplitKeys = kvRegionSplitKeys - } - } else { - log.FromContext(ctx).Warn("fail to get region split keys and size", zap.Error(err)) - } - - // split sorted file into range about regionSplitSize per file - regionRanges, err := readAndSplitIntoRange(ctx, e, regionSplitSize, regionSplitKeys) - if err != nil { - return err - } - - if len(regionRanges) > 0 && local.PausePDSchedulerScope == config.PausePDSchedulerScopeTable { - log.FromContext(ctx).Info("pause pd scheduler of table scope") - subCtx, cancel := context.WithCancel(ctx) - defer cancel() - - var startKey, endKey []byte - if len(regionRanges[0].Start) > 0 { - startKey = codec.EncodeBytes(nil, regionRanges[0].Start) - } - if len(regionRanges[len(regionRanges)-1].End) > 0 { - endKey = codec.EncodeBytes(nil, regionRanges[len(regionRanges)-1].End) - } - done, err := pdutil.PauseSchedulersByKeyRange(subCtx, local.pdHTTPCli, startKey, endKey) - if err != nil { - return errors.Trace(err) - } - defer func() { - cancel() - <-done - }() - } - - if len(regionRanges) > 0 && local.BackendConfig.RaftKV2SwitchModeDuration > 0 { - log.FromContext(ctx).Info("switch import mode of ranges", - zap.String("startKey", hex.EncodeToString(regionRanges[0].Start)), - zap.String("endKey", hex.EncodeToString(regionRanges[len(regionRanges)-1].End))) - subCtx, cancel := context.WithCancel(ctx) - defer cancel() - - done, err := local.SwitchModeByKeyRanges(subCtx, regionRanges) - if err != nil { - return errors.Trace(err) - } - defer func() { - cancel() - <-done - }() - } - - log.FromContext(ctx).Info("start import engine", - zap.Stringer("uuid", engineUUID), - zap.Int("region ranges", len(regionRanges)), - zap.Int64("count", lfLength), - zap.Int64("size", lfTotalSize)) - - failpoint.Inject("ReadyForImportEngine", func() {}) - - err = local.doImport(ctx, e, regionRanges, regionSplitSize, regionSplitKeys) - if err == nil { - importedSize, importedLength := e.ImportedStatistics() - log.FromContext(ctx).Info("import engine success", - zap.Stringer("uuid", engineUUID), - zap.Int64("size", lfTotalSize), - zap.Int64("kvs", lfLength), - zap.Int64("importedSize", importedSize), - zap.Int64("importedCount", importedLength)) - } - return err -} - -// 
expose these variables to unit test. -var ( - testJobToWorkerCh = make(chan *regionJob) - testJobWg *sync.WaitGroup -) - -func (local *Backend) doImport(ctx context.Context, engine common.Engine, regionRanges []common.Range, regionSplitSize, regionSplitKeys int64) error { - /* - [prepareAndSendJob]-----jobToWorkerCh--->[workers] - ^ | - | jobFromWorkerCh - | | - | v - [regionJobRetryer]<--[dispatchJobGoroutine]-->done - */ - - var ( - ctx2, workerCancel = context.WithCancel(ctx) - // workerCtx.Done() means workflow is canceled by error. It may be caused - // by calling workerCancel() or workers in workGroup meets error. - workGroup, workerCtx = util.NewErrorGroupWithRecoverWithCtx(ctx2) - firstErr common.OnceError - // jobToWorkerCh and jobFromWorkerCh are unbuffered so jobs will not be - // owned by them. - jobToWorkerCh = make(chan *regionJob) - jobFromWorkerCh = make(chan *regionJob) - // jobWg tracks the number of jobs in this workflow. - // prepareAndSendJob, workers and regionJobRetryer can own jobs. - // When cancel on error, the goroutine of above three components have - // responsibility to Done jobWg of their owning jobs. - jobWg sync.WaitGroup - dispatchJobGoroutine = make(chan struct{}) - ) - defer workerCancel() - - failpoint.Inject("injectVariables", func() { - jobToWorkerCh = testJobToWorkerCh - testJobWg = &jobWg - }) - - retryer := startRegionJobRetryer(workerCtx, jobToWorkerCh, &jobWg) - - // dispatchJobGoroutine handles processed job from worker, it will only exit - // when jobFromWorkerCh is closed to avoid worker is blocked on sending to - // jobFromWorkerCh. - defer func() { - // use defer to close jobFromWorkerCh after all workers are exited - close(jobFromWorkerCh) - <-dispatchJobGoroutine - }() - go func() { - defer close(dispatchJobGoroutine) - for { - job, ok := <-jobFromWorkerCh - if !ok { - return - } - switch job.stage { - case regionScanned, wrote: - job.retryCount++ - if job.retryCount > maxWriteAndIngestRetryTimes { - firstErr.Set(job.lastRetryableErr) - workerCancel() - job.done(&jobWg) - continue - } - // max retry backoff time: 2+4+8+16+30*26=810s - sleepSecond := math.Pow(2, float64(job.retryCount)) - if sleepSecond > float64(maxRetryBackoffSecond) { - sleepSecond = float64(maxRetryBackoffSecond) - } - job.waitUntil = time.Now().Add(time.Second * time.Duration(sleepSecond)) - log.FromContext(ctx).Info("put job back to jobCh to retry later", - logutil.Key("startKey", job.keyRange.Start), - logutil.Key("endKey", job.keyRange.End), - zap.Stringer("stage", job.stage), - zap.Int("retryCount", job.retryCount), - zap.Time("waitUntil", job.waitUntil)) - if !retryer.push(job) { - // retryer is closed by worker error - job.done(&jobWg) - } - case ingested: - job.done(&jobWg) - case needRescan: - panic("should not reach here") - } - } - }() - - failpoint.Inject("skipStartWorker", func() { - failpoint.Goto("afterStartWorker") - }) - - for i := 0; i < local.WorkerConcurrency; i++ { - workGroup.Go(func() error { - return local.startWorker(workerCtx, jobToWorkerCh, jobFromWorkerCh, &jobWg) - }) - } - - failpoint.Label("afterStartWorker") - - workGroup.Go(func() error { - err := local.prepareAndSendJob( - workerCtx, - engine, - regionRanges, - regionSplitSize, - regionSplitKeys, - jobToWorkerCh, - &jobWg, - ) - if err != nil { - return err - } - - jobWg.Wait() - workerCancel() - return nil - }) - if err := workGroup.Wait(); err != nil { - if !common.IsContextCanceledError(err) { - log.FromContext(ctx).Error("do import meets error", zap.Error(err)) - } - 
firstErr.Set(err) - } - return firstErr.Get() -} - -// GetImportedKVCount returns the number of imported KV pairs of some engine. -func (local *Backend) GetImportedKVCount(engineUUID uuid.UUID) int64 { - return local.engineMgr.getImportedKVCount(engineUUID) -} - -// GetExternalEngineKVStatistics returns kv statistics of some engine. -func (local *Backend) GetExternalEngineKVStatistics(engineUUID uuid.UUID) ( - totalKVSize int64, totalKVCount int64) { - return local.engineMgr.getExternalEngineKVStatistics(engineUUID) -} - -// ResetEngine reset the engine and reclaim the space. -func (local *Backend) ResetEngine(ctx context.Context, engineUUID uuid.UUID) error { - return local.engineMgr.resetEngine(ctx, engineUUID, false) -} - -// ResetEngineSkipAllocTS is like ResetEngine but the inner TS of the engine is -// invalid. Caller must use OpenedEngine.SetTS to set a valid TS before import -// the engine. -func (local *Backend) ResetEngineSkipAllocTS(ctx context.Context, engineUUID uuid.UUID) error { - return local.engineMgr.resetEngine(ctx, engineUUID, true) -} - -// CleanupEngine cleanup the engine and reclaim the space. -func (local *Backend) CleanupEngine(ctx context.Context, engineUUID uuid.UUID) error { - return local.engineMgr.cleanupEngine(ctx, engineUUID) -} - -// GetDupeController returns a new dupe controller. -func (local *Backend) GetDupeController(dupeConcurrency int, errorMgr *errormanager.ErrorManager) *DupeController { - return &DupeController{ - splitCli: local.splitCli, - tikvCli: local.tikvCli, - tikvCodec: local.tikvCodec, - errorMgr: errorMgr, - dupeConcurrency: dupeConcurrency, - duplicateDB: local.engineMgr.getDuplicateDB(), - keyAdapter: local.engineMgr.getKeyAdapter(), - importClientFactory: local.importClientFactory, - resourceGroupName: local.ResourceGroupName, - taskType: local.TaskType, - } -} - -// UnsafeImportAndReset forces the backend to import the content of an engine -// into the target and then reset the engine to empty. This method will not -// close the engine. Make sure the engine is flushed manually before calling -// this method. -func (local *Backend) UnsafeImportAndReset(ctx context.Context, engineUUID uuid.UUID, regionSplitSize, regionSplitKeys int64) error { - // DO NOT call be.abstract.CloseEngine()! The engine should still be writable after - // calling UnsafeImportAndReset(). - logger := log.FromContext(ctx).With( - zap.String("engineTag", ""), - zap.Stringer("engineUUID", engineUUID), - ) - closedEngine := backend.NewClosedEngine(local, logger, engineUUID, 0) - if err := closedEngine.Import(ctx, regionSplitSize, regionSplitKeys); err != nil { - return err - } - return local.engineMgr.resetEngine(ctx, engineUUID, false) -} - -func engineSSTDir(storeDir string, engineUUID uuid.UUID) string { - return filepath.Join(storeDir, engineUUID.String()+".sst") -} - -// LocalWriter returns a new local writer. -func (local *Backend) LocalWriter(ctx context.Context, cfg *backend.LocalWriterConfig, engineUUID uuid.UUID) (backend.EngineWriter, error) { - return local.engineMgr.localWriter(ctx, cfg, engineUUID) -} - -// SwitchModeByKeyRanges will switch tikv mode for regions in the specific key range for multirocksdb. -// This function will spawn a goroutine to keep switch mode periodically until the context is done. -// The return done channel is used to notify the caller that the background goroutine is exited. 
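[Editor's sketch] The retry dispatcher in doImport above waits 2^retryCount seconds before re-queuing a failed job, capped at maxRetryBackoffSecond, which is where its "2+4+8+16+30*26=810s" comment comes from. A small sketch of that computation; the two constants below are inferred from that comment and are not the package's actual definitions:

package main

import (
	"fmt"
	"math"
	"time"
)

// Assumed from the 2+4+8+16+30*26=810s comment; the real constants are
// defined elsewhere in the local backend package.
const (
	maxWriteAndIngestRetryTimes = 30
	maxRetryBackoffSecond       = 30
)

func retryWait(retryCount int) time.Duration {
	sleepSecond := math.Pow(2, float64(retryCount))
	if sleepSecond > float64(maxRetryBackoffSecond) {
		sleepSecond = float64(maxRetryBackoffSecond)
	}
	return time.Duration(sleepSecond) * time.Second
}

func main() {
	var total time.Duration
	for retry := 1; retry <= maxWriteAndIngestRetryTimes; retry++ {
		total += retryWait(retry)
	}
	fmt.Println(total) // 2+4+8+16+30*26 seconds = 13m30s
}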
-func (local *Backend) SwitchModeByKeyRanges(ctx context.Context, ranges []common.Range) (<-chan struct{}, error) { - switcher := NewTiKVModeSwitcher(local.tls.TLSConfig(), local.pdHTTPCli, log.FromContext(ctx).Logger) - done := make(chan struct{}) - - keyRanges := make([]*sst.Range, 0, len(ranges)) - for _, r := range ranges { - startKey := r.Start - if len(r.Start) > 0 { - startKey = codec.EncodeBytes(nil, r.Start) - } - endKey := r.End - if len(r.End) > 0 { - endKey = codec.EncodeBytes(nil, r.End) - } - keyRanges = append(keyRanges, &sst.Range{ - Start: startKey, - End: endKey, - }) - } - - go func() { - defer close(done) - ticker := time.NewTicker(local.BackendConfig.RaftKV2SwitchModeDuration) - defer ticker.Stop() - switcher.ToImportMode(ctx, keyRanges...) - loop: - for { - select { - case <-ctx.Done(): - break loop - case <-ticker.C: - switcher.ToImportMode(ctx, keyRanges...) - } - } - // Use a new context to avoid the context is canceled by the caller. - recoverCtx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - switcher.ToNormalMode(recoverCtx, keyRanges...) - }() - return done, nil -} - -func openLocalWriter(cfg *backend.LocalWriterConfig, engine *Engine, tikvCodec tikvclient.Codec, cacheSize int64, kvBuffer *membuf.Buffer) (*Writer, error) { - // pre-allocate a long enough buffer to avoid a lot of runtime.growslice - // this can help save about 3% of CPU. - var preAllocWriteBatch []common.KvPair - if !cfg.Local.IsKVSorted { - preAllocWriteBatch = make([]common.KvPair, units.MiB) - // we want to keep the cacheSize as the whole limit of this local writer, but the - // main memory usage comes from two member: kvBuffer and writeBatch, so we split - // ~10% to writeBatch for !IsKVSorted, which means we estimate the average length - // of KV pairs are 9 times than the size of common.KvPair (9*72B = 648B). - cacheSize = cacheSize * 9 / 10 - } - w := &Writer{ - engine: engine, - memtableSizeLimit: cacheSize, - kvBuffer: kvBuffer, - isKVSorted: cfg.Local.IsKVSorted, - isWriteBatchSorted: true, - tikvCodec: tikvCodec, - writeBatch: preAllocWriteBatch, - } - engine.localWriters.Store(w, nil) - return w, nil -} - -// return the smallest []byte that is bigger than current bytes. -// special case when key is empty, empty bytes means infinity in our context, so directly return itself. -func nextKey(key []byte) []byte { - if len(key) == 0 { - return []byte{} - } - - // in tikv <= 4.x, tikv will truncate the row key, so we should fetch the next valid row key - // See: https://github.com/tikv/tikv/blob/f7f22f70e1585d7ca38a59ea30e774949160c3e8/components/raftstore/src/coprocessor/split_observer.rs#L36-L41 - // we only do this for IntHandle, which is checked by length - if tablecodec.IsRecordKey(key) && len(key) == tablecodec.RecordRowKeyLen { - tableID, handle, _ := tablecodec.DecodeRecordKey(key) - nextHandle := handle.Next() - // int handle overflow, use the next table prefix as nextKey - if nextHandle.Compare(handle) <= 0 { - return tablecodec.EncodeTablePrefix(tableID + 1) - } - return tablecodec.EncodeRowKeyWithHandle(tableID, nextHandle) - } - - // for index key and CommonHandle, directly append a 0x00 to the key. - res := make([]byte, 0, len(key)+1) - res = append(res, key...) - res = append(res, 0) - return res -} - -// EngineFileSizes implements DiskUsage interface. -func (local *Backend) EngineFileSizes() (res []backend.EngineFileSize) { - return local.engineMgr.engineFileSizes() -} - -// GetTS implements StoreHelper interface. 
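[Editor's sketch] nextKey above relies on the fact that appending a single 0x00 byte yields the smallest byte string that sorts strictly after the original (the index-key and common-handle branch). A quick self-contained check of that property; it deliberately skips the IntHandle special case handled in the function:

package main

import (
	"bytes"
	"fmt"
)

// nextIndexKey is the simple branch of nextKey: append a zero byte.
func nextIndexKey(key []byte) []byte {
	res := make([]byte, 0, len(key)+1)
	res = append(res, key...)
	return append(res, 0)
}

func main() {
	key := []byte("t\x80\x00\x00\x00\x00\x00\x00\x01_i")
	next := nextIndexKey(key)

	// next sorts strictly after key...
	fmt.Println(bytes.Compare(next, key) > 0) // true
	// ...and nothing sorts between them: any b > key either has key as a
	// proper prefix (so b >= next) or already differs within key's bytes.
	fmt.Println(bytes.HasPrefix(next, key)) // true
}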
-func (local *Backend) GetTS(ctx context.Context) (physical, logical int64, err error) { - return local.pdCli.GetTS(ctx) -} - -// GetTiKVCodec implements StoreHelper interface. -func (local *Backend) GetTiKVCodec() tikvclient.Codec { - return local.tikvCodec -} - -// CloseEngineMgr close the engine manager. -// This function is used for test. -func (local *Backend) CloseEngineMgr() { - local.engineMgr.close() -} - -var getSplitConfFromStoreFunc = getSplitConfFromStore - -// return region split size, region split keys, error -func getSplitConfFromStore(ctx context.Context, host string, tls *common.TLS) ( - splitSize int64, regionSplitKeys int64, err error) { - var ( - nested struct { - Coprocessor struct { - RegionSplitSize string `json:"region-split-size"` - RegionSplitKeys int64 `json:"region-split-keys"` - } `json:"coprocessor"` - } - ) - if err := tls.WithHost(host).GetJSON(ctx, "/config", &nested); err != nil { - return 0, 0, errors.Trace(err) - } - splitSize, err = units.FromHumanSize(nested.Coprocessor.RegionSplitSize) - if err != nil { - return 0, 0, errors.Trace(err) - } - - return splitSize, nested.Coprocessor.RegionSplitKeys, nil -} - -// GetRegionSplitSizeKeys return region split size, region split keys, error -func GetRegionSplitSizeKeys(ctx context.Context, cli pd.Client, tls *common.TLS) ( - regionSplitSize int64, regionSplitKeys int64, err error) { - stores, err := cli.GetAllStores(ctx, pd.WithExcludeTombstone()) - if err != nil { - return 0, 0, err - } - for _, store := range stores { - if store.StatusAddress == "" || engine.IsTiFlash(store) { - continue - } - serverInfo := infoschema.ServerInfo{ - Address: store.Address, - StatusAddr: store.StatusAddress, - } - serverInfo.ResolveLoopBackAddr() - regionSplitSize, regionSplitKeys, err := getSplitConfFromStoreFunc(ctx, serverInfo.StatusAddr, tls) - if err == nil { - return regionSplitSize, regionSplitKeys, nil - } - log.FromContext(ctx).Warn("get region split size and keys failed", zap.Error(err), zap.String("store", serverInfo.StatusAddr)) - } - return 0, 0, errors.New("get region split size and keys failed") -} diff --git a/pkg/lightning/backend/local/local_unix.go b/pkg/lightning/backend/local/local_unix.go index ec213e3664581..20695c6ebd6ab 100644 --- a/pkg/lightning/backend/local/local_unix.go +++ b/pkg/lightning/backend/local/local_unix.go @@ -45,12 +45,12 @@ func VerifyRLimit(estimateMaxFiles RlimT) error { } var rLimit syscall.Rlimit err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit) - if v, _err_ := failpoint.Eval(_curpkg_("GetRlimitValue")); _err_ == nil { + failpoint.Inject("GetRlimitValue", func(v failpoint.Value) { limit := RlimT(v.(int)) rLimit.Cur = limit rLimit.Max = limit err = nil - } + }) if err != nil { return errors.Trace(err) } @@ -63,11 +63,11 @@ func VerifyRLimit(estimateMaxFiles RlimT) error { } prevLimit := rLimit.Cur rLimit.Cur = estimateMaxFiles - if v, _err_ := failpoint.Eval(_curpkg_("SetRlimitError")); _err_ == nil { + failpoint.Inject("SetRlimitError", func(v failpoint.Value) { if v.(bool) { err = errors.New("Setrlimit Injected Error") } - } + }) if err == nil { err = syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rLimit) } diff --git a/pkg/lightning/backend/local/local_unix.go__failpoint_stash__ b/pkg/lightning/backend/local/local_unix.go__failpoint_stash__ deleted file mode 100644 index 20695c6ebd6ab..0000000000000 --- a/pkg/lightning/backend/local/local_unix.go__failpoint_stash__ +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2020 PingCAP, Inc. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build !windows - -package local - -import ( - "syscall" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/lightning/log" -) - -const ( - // maximum max open files value - maxRLimit = 1000000 -) - -// GetSystemRLimit returns the current open-file limit. -func GetSystemRLimit() (RlimT, error) { - var rLimit syscall.Rlimit - err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit) - return rLimit.Cur, err -} - -// VerifyRLimit checks whether the open-file limit is large enough. -// In Local-backend, we need to read and write a lot of L0 SST files, so we need -// to check system max open files limit. -func VerifyRLimit(estimateMaxFiles RlimT) error { - if estimateMaxFiles > maxRLimit { - estimateMaxFiles = maxRLimit - } - var rLimit syscall.Rlimit - err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit) - failpoint.Inject("GetRlimitValue", func(v failpoint.Value) { - limit := RlimT(v.(int)) - rLimit.Cur = limit - rLimit.Max = limit - err = nil - }) - if err != nil { - return errors.Trace(err) - } - if rLimit.Cur >= estimateMaxFiles { - return nil - } - if rLimit.Max < estimateMaxFiles { - // If the process is not started by privileged user, this will fail. - rLimit.Max = estimateMaxFiles - } - prevLimit := rLimit.Cur - rLimit.Cur = estimateMaxFiles - failpoint.Inject("SetRlimitError", func(v failpoint.Value) { - if v.(bool) { - err = errors.New("Setrlimit Injected Error") - } - }) - if err == nil { - err = syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rLimit) - } - if err != nil { - return errors.Annotatef(err, "the maximum number of open file descriptors is too small, got %d, expect greater or equal to %d", prevLimit, estimateMaxFiles) - } - - // fetch the rlimit again to make sure our setting has taken effect - err = syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit) - if err != nil { - return errors.Trace(err) - } - if rLimit.Cur < estimateMaxFiles { - helper := "Please manually execute `ulimit -n %d` to increase the open files limit." - return errors.Errorf("cannot update the maximum number of open file descriptors, expected: %d, got: %d. 
%s", - estimateMaxFiles, rLimit.Cur, helper) - } - - log.L().Info("Set the maximum number of open file descriptors(rlimit)", - zapRlimT("old", prevLimit), zapRlimT("new", estimateMaxFiles)) - return nil -} diff --git a/pkg/lightning/backend/local/region_job.go b/pkg/lightning/backend/local/region_job.go index d18c33970a4ec..7bc812e4b9bb6 100644 --- a/pkg/lightning/backend/local/region_job.go +++ b/pkg/lightning/backend/local/region_job.go @@ -208,7 +208,7 @@ func (local *Backend) doWrite(ctx context.Context, j *regionJob) error { return nil } - if _, _err_ := failpoint.Eval(_curpkg_("fakeRegionJobs")); _err_ == nil { + failpoint.Inject("fakeRegionJobs", func() { front := j.injected[0] j.injected = j.injected[1:] j.writeResult = front.write.result @@ -216,8 +216,8 @@ func (local *Backend) doWrite(ctx context.Context, j *regionJob) error { if err == nil { j.convertStageTo(wrote) } - return err - } + failpoint.Return(err) + }) var cancel context.CancelFunc ctx, cancel = context.WithTimeoutCause(ctx, 15*time.Minute, common.ErrWriteTooSlow) @@ -261,7 +261,7 @@ func (local *Backend) doWrite(ctx context.Context, j *regionJob) error { ApiVersion: apiVersion, } - if val, _err_ := failpoint.Eval(_curpkg_("changeEpochVersion")); _err_ == nil { + failpoint.Inject("changeEpochVersion", func(val failpoint.Value) { cloned := *meta.RegionEpoch meta.RegionEpoch = &cloned i := val.(int) @@ -270,7 +270,7 @@ func (local *Backend) doWrite(ctx context.Context, j *regionJob) error { } else { meta.RegionEpoch.ConfVer -= uint64(-i) } - } + }) annotateErr := func(in error, peer *metapb.Peer, msg string) error { // annotate the error with peer/store/region info to help debug. @@ -307,10 +307,10 @@ func (local *Backend) doWrite(ctx context.Context, j *regionJob) error { return annotateErr(err, peer, "when open write stream") } - if _, _err_ := failpoint.Eval(_curpkg_("mockWritePeerErr")); _err_ == nil { + failpoint.Inject("mockWritePeerErr", func() { err = errors.Errorf("mock write peer error") - return annotateErr(err, peer, "when open write stream") - } + failpoint.Return(annotateErr(err, peer, "when open write stream")) + }) // Bind uuid for this write request if err = wstream.Send(req); err != nil { @@ -360,9 +360,9 @@ func (local *Backend) doWrite(ctx context.Context, j *regionJob) error { return annotateErr(err, allPeers[i], "when send data") } } - if _, _err_ := failpoint.Eval(_curpkg_("afterFlushKVs")); _err_ == nil { + failpoint.Inject("afterFlushKVs", func() { log.FromContext(ctx).Info(fmt.Sprintf("afterFlushKVs count=%d,size=%d", count, size)) - } + }) return nil } @@ -444,10 +444,10 @@ func (local *Backend) doWrite(ctx context.Context, j *regionJob) error { } } - if _, _err_ := failpoint.Eval(_curpkg_("NoLeader")); _err_ == nil { + failpoint.Inject("NoLeader", func() { log.FromContext(ctx).Warn("enter failpoint NoLeader") leaderPeerMetas = nil - } + }) // if there is not leader currently, we don't forward the stage to wrote and let caller // handle the retry. 
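[Editor's sketch] The region_job.go hunks above restore the source form of the failpoints: a named failpoint.Inject callback instead of the expanded failpoint.Eval(_curpkg_(...)) code that failpoint-ctl generates. A minimal illustration of the idiom, assuming the github.com/pingcap/failpoint package; note that Inject is only a marker and fires solely in builds where failpoint-ctl has rewritten it into the Eval form shown on the removed lines, after which failpoint.Enable activates it at run time. The failpoint name and path here are hypothetical:

package demo

import (
	"github.com/pingcap/failpoint"
)

// writeBatchSize normally returns the configured value, but a test can force
// a different one through the "mockWriteBatchSize" failpoint.
func writeBatchSize(configured int) int {
	size := configured
	failpoint.Inject("mockWriteBatchSize", func(val failpoint.Value) {
		size = val.(int)
	})
	return size
}

/*
In a test (the path prefix is illustrative):

	failpoint.Enable("example.com/demo/mockWriteBatchSize", "return(16)")
	defer failpoint.Disable("example.com/demo/mockWriteBatchSize")

	got := writeBatchSize(1024) // 16 while the failpoint is enabled
*/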
@@ -488,12 +488,12 @@ func (local *Backend) ingest(ctx context.Context, j *regionJob) (err error) { return nil } - if _, _err_ := failpoint.Eval(_curpkg_("fakeRegionJobs")); _err_ == nil { + failpoint.Inject("fakeRegionJobs", func() { front := j.injected[0] j.injected = j.injected[1:] j.convertStageTo(front.ingest.nextStage) - return front.ingest.err - } + failpoint.Return(front.ingest.err) + }) if len(j.writeResult.sstMeta) == 0 { j.convertStageTo(ingested) @@ -597,7 +597,7 @@ func (local *Backend) doIngest(ctx context.Context, j *regionJob) (*sst.IngestRe log.FromContext(ctx).Debug("ingest meta", zap.Reflect("meta", ingestMetas)) - if val, _err_ := failpoint.Eval(_curpkg_("FailIngestMeta")); _err_ == nil { + failpoint.Inject("FailIngestMeta", func(val failpoint.Value) { // only inject the error once var resp *sst.IngestResponse @@ -620,8 +620,8 @@ func (local *Backend) doIngest(ctx context.Context, j *regionJob) (*sst.IngestRe }, } } - return resp, nil - } + failpoint.Return(resp, nil) + }) leader := j.region.Leader if leader == nil { diff --git a/pkg/lightning/backend/local/region_job.go__failpoint_stash__ b/pkg/lightning/backend/local/region_job.go__failpoint_stash__ deleted file mode 100644 index 7bc812e4b9bb6..0000000000000 --- a/pkg/lightning/backend/local/region_job.go__failpoint_stash__ +++ /dev/null @@ -1,907 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package local - -import ( - "container/heap" - "context" - "fmt" - "io" - "strings" - "sync" - "time" - - "github.com/google/uuid" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/errorpb" - sst "github.com/pingcap/kvproto/pkg/import_sstpb" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/kvproto/pkg/metapb" - berrors "github.com/pingcap/tidb/br/pkg/errors" - "github.com/pingcap/tidb/br/pkg/logutil" - "github.com/pingcap/tidb/br/pkg/restore/split" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/lightning/common" - "github.com/pingcap/tidb/pkg/lightning/config" - "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/lightning/metric" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/tikv/client-go/v2/util" - "go.uber.org/zap" - "google.golang.org/grpc" -) - -type jobStageTp string - -/* - + - v - +------+------+ - +->+regionScanned+<------+ - | +------+------+ | - | | | - | | | - | v | - | +--+--+ +-----+----+ - | |wrote+---->+needRescan| - | +--+--+ +-----+----+ - | | ^ - | | | - | v | - | +---+----+ | - +-----+ingested+---------+ - +---+----+ - | - v - -above diagram shows the state transition of a region job, here are some special -cases: - - regionScanned can directly jump to ingested if the keyRange has no data - - regionScanned can only transit to wrote. 
TODO: check if it should be transited - to needRescan - - if a job only partially writes the data, after it becomes ingested, it will - update its keyRange and transits to regionScanned to continue the remaining - data - - needRescan may output multiple regionScanned jobs when the old region is split -*/ -const ( - regionScanned jobStageTp = "regionScanned" - wrote jobStageTp = "wrote" - ingested jobStageTp = "ingested" - needRescan jobStageTp = "needRescan" - - // suppose each KV is about 32 bytes, 16 * units.KiB / 32 = 512 - defaultKVBatchCount = 512 -) - -func (j jobStageTp) String() string { - return string(j) -} - -// regionJob is dedicated to import the data in [keyRange.start, keyRange.end) -// to a region. The keyRange may be changed when processing because of writing -// partial data to TiKV or region split. -type regionJob struct { - keyRange common.Range - // TODO: check the keyRange so that it's always included in region - region *split.RegionInfo - // stage should be updated only by convertStageTo - stage jobStageTp - // writeResult is available only in wrote and ingested stage - writeResult *tikvWriteResult - - ingestData common.IngestData - regionSplitSize int64 - regionSplitKeys int64 - metrics *metric.Common - - retryCount int - waitUntil time.Time - lastRetryableErr error - - // injected is used in test to set the behaviour - injected []injectedBehaviour -} - -type tikvWriteResult struct { - sstMeta []*sst.SSTMeta - count int64 - totalBytes int64 - remainingStartKey []byte -} - -type injectedBehaviour struct { - write injectedWriteBehaviour - ingest injectedIngestBehaviour -} - -type injectedWriteBehaviour struct { - result *tikvWriteResult - err error -} - -type injectedIngestBehaviour struct { - nextStage jobStageTp - err error -} - -func (j *regionJob) convertStageTo(stage jobStageTp) { - j.stage = stage - switch stage { - case regionScanned: - j.writeResult = nil - case ingested: - // when writing is skipped because key range is empty - if j.writeResult == nil { - return - } - - j.ingestData.Finish(j.writeResult.totalBytes, j.writeResult.count) - if j.metrics != nil { - j.metrics.BytesCounter.WithLabelValues(metric.StateImported). - Add(float64(j.writeResult.totalBytes)) - } - case needRescan: - j.region = nil - } -} - -// ref means that the ingestData of job will be accessed soon. -func (j *regionJob) ref(wg *sync.WaitGroup) { - if wg != nil { - wg.Add(1) - } - if j.ingestData != nil { - j.ingestData.IncRef() - } -} - -// done promises that the ingestData of job will not be accessed. Same amount of -// done should be called to release the ingestData. -func (j *regionJob) done(wg *sync.WaitGroup) { - if j.ingestData != nil { - j.ingestData.DecRef() - } - if wg != nil { - wg.Done() - } -} - -// writeToTiKV writes the data to TiKV and mark this job as wrote stage. -// if any write logic has error, writeToTiKV will set job to a proper stage and return nil. -// if any underlying logic has error, writeToTiKV will return an error. -// we don't need to do cleanup for the pairs written to tikv if encounters an error, -// tikv will take the responsibility to do so. -// TODO: let client-go provide a high-level write interface. 
-func (local *Backend) writeToTiKV(ctx context.Context, j *regionJob) error { - err := local.doWrite(ctx, j) - if err == nil { - return nil - } - if !common.IsRetryableError(err) { - return err - } - // currently only one case will restart write - if strings.Contains(err.Error(), "RequestTooNew") { - j.convertStageTo(regionScanned) - return err - } - j.convertStageTo(needRescan) - return err -} - -func (local *Backend) doWrite(ctx context.Context, j *regionJob) error { - if j.stage != regionScanned { - return nil - } - - failpoint.Inject("fakeRegionJobs", func() { - front := j.injected[0] - j.injected = j.injected[1:] - j.writeResult = front.write.result - err := front.write.err - if err == nil { - j.convertStageTo(wrote) - } - failpoint.Return(err) - }) - - var cancel context.CancelFunc - ctx, cancel = context.WithTimeoutCause(ctx, 15*time.Minute, common.ErrWriteTooSlow) - defer cancel() - - apiVersion := local.tikvCodec.GetAPIVersion() - clientFactory := local.importClientFactory - kvBatchSize := local.KVWriteBatchSize - bufferPool := local.engineMgr.getBufferPool() - writeLimiter := local.writeLimiter - - begin := time.Now() - region := j.region.Region - - firstKey, lastKey, err := j.ingestData.GetFirstAndLastKey(j.keyRange.Start, j.keyRange.End) - if err != nil { - return errors.Trace(err) - } - if firstKey == nil { - j.convertStageTo(ingested) - log.FromContext(ctx).Debug("keys within region is empty, skip doIngest", - logutil.Key("start", j.keyRange.Start), - logutil.Key("regionStart", region.StartKey), - logutil.Key("end", j.keyRange.End), - logutil.Key("regionEnd", region.EndKey)) - return nil - } - - firstKey = codec.EncodeBytes([]byte{}, firstKey) - lastKey = codec.EncodeBytes([]byte{}, lastKey) - - u := uuid.New() - meta := &sst.SSTMeta{ - Uuid: u[:], - RegionId: region.GetId(), - RegionEpoch: region.GetRegionEpoch(), - Range: &sst.Range{ - Start: firstKey, - End: lastKey, - }, - ApiVersion: apiVersion, - } - - failpoint.Inject("changeEpochVersion", func(val failpoint.Value) { - cloned := *meta.RegionEpoch - meta.RegionEpoch = &cloned - i := val.(int) - if i >= 0 { - meta.RegionEpoch.Version += uint64(i) - } else { - meta.RegionEpoch.ConfVer -= uint64(-i) - } - }) - - annotateErr := func(in error, peer *metapb.Peer, msg string) error { - // annotate the error with peer/store/region info to help debug. 
- return errors.Annotatef( - in, - "peer %d, store %d, region %d, epoch %s, %s", - peer.Id, peer.StoreId, region.Id, region.RegionEpoch.String(), - msg, - ) - } - - leaderID := j.region.Leader.GetId() - clients := make([]sst.ImportSST_WriteClient, 0, len(region.GetPeers())) - allPeers := make([]*metapb.Peer, 0, len(region.GetPeers())) - req := &sst.WriteRequest{ - Chunk: &sst.WriteRequest_Meta{ - Meta: meta, - }, - Context: &kvrpcpb.Context{ - ResourceControlContext: &kvrpcpb.ResourceControlContext{ - ResourceGroupName: local.ResourceGroupName, - }, - RequestSource: util.BuildRequestSource(true, kv.InternalTxnLightning, local.TaskType), - }, - } - for _, peer := range region.GetPeers() { - cli, err := clientFactory.Create(ctx, peer.StoreId) - if err != nil { - return annotateErr(err, peer, "when create client") - } - - wstream, err := cli.Write(ctx) - if err != nil { - return annotateErr(err, peer, "when open write stream") - } - - failpoint.Inject("mockWritePeerErr", func() { - err = errors.Errorf("mock write peer error") - failpoint.Return(annotateErr(err, peer, "when open write stream")) - }) - - // Bind uuid for this write request - if err = wstream.Send(req); err != nil { - return annotateErr(err, peer, "when send meta") - } - clients = append(clients, wstream) - allPeers = append(allPeers, peer) - } - dataCommitTS := j.ingestData.GetTS() - req.Chunk = &sst.WriteRequest_Batch{ - Batch: &sst.WriteBatch{ - CommitTs: dataCommitTS, - }, - } - - pairs := make([]*sst.Pair, 0, defaultKVBatchCount) - count := 0 - size := int64(0) - totalSize := int64(0) - totalCount := int64(0) - // if region-split-size <= 96MiB, we bump the threshold a bit to avoid too many retry split - // because the range-properties is not 100% accurate - regionMaxSize := j.regionSplitSize - if j.regionSplitSize <= int64(config.SplitRegionSize) { - regionMaxSize = j.regionSplitSize * 4 / 3 - } - - flushKVs := func() error { - req.Chunk.(*sst.WriteRequest_Batch).Batch.Pairs = pairs[:count] - preparedMsg := &grpc.PreparedMsg{} - // by reading the source code, Encode need to find codec and compression from the stream - // because all stream has the same codec and compression, we can use any one of them - if err := preparedMsg.Encode(clients[0], req); err != nil { - return err - } - - for i := range clients { - if err := writeLimiter.WaitN(ctx, allPeers[i].StoreId, int(size)); err != nil { - return errors.Trace(err) - } - if err := clients[i].SendMsg(preparedMsg); err != nil { - if err == io.EOF { - // if it's EOF, need RecvMsg to get the error - dummy := &sst.WriteResponse{} - err = clients[i].RecvMsg(dummy) - } - return annotateErr(err, allPeers[i], "when send data") - } - } - failpoint.Inject("afterFlushKVs", func() { - log.FromContext(ctx).Info(fmt.Sprintf("afterFlushKVs count=%d,size=%d", count, size)) - }) - return nil - } - - iter := j.ingestData.NewIter(ctx, j.keyRange.Start, j.keyRange.End, bufferPool) - //nolint: errcheck - defer iter.Close() - - var remainingStartKey []byte - for iter.First(); iter.Valid(); iter.Next() { - k, v := iter.Key(), iter.Value() - kvSize := int64(len(k) + len(v)) - // here we reuse the `*sst.Pair`s to optimize object allocation - if count < len(pairs) { - pairs[count].Key = k - pairs[count].Value = v - } else { - pair := &sst.Pair{ - Key: k, - Value: v, - } - pairs = append(pairs, pair) - } - count++ - totalCount++ - size += kvSize - totalSize += kvSize - - if size >= kvBatchSize { - if err := flushKVs(); err != nil { - return errors.Trace(err) - } - count = 0 - size = 0 - 
iter.ReleaseBuf() - } - if totalSize >= regionMaxSize || totalCount >= j.regionSplitKeys { - // we will shrink the key range of this job to real written range - if iter.Next() { - remainingStartKey = append([]byte{}, iter.Key()...) - log.FromContext(ctx).Info("write to tikv partial finish", - zap.Int64("count", totalCount), - zap.Int64("size", totalSize), - logutil.Key("startKey", j.keyRange.Start), - logutil.Key("endKey", j.keyRange.End), - logutil.Key("remainStart", remainingStartKey), - logutil.Region(region), - logutil.Leader(j.region.Leader), - zap.Uint64("commitTS", dataCommitTS)) - } - break - } - } - - if iter.Error() != nil { - return errors.Trace(iter.Error()) - } - - if count > 0 { - if err := flushKVs(); err != nil { - return errors.Trace(err) - } - count = 0 - size = 0 - iter.ReleaseBuf() - } - - var leaderPeerMetas []*sst.SSTMeta - for i, wStream := range clients { - resp, closeErr := wStream.CloseAndRecv() - if closeErr != nil { - return annotateErr(closeErr, allPeers[i], "when close write stream") - } - if resp.Error != nil { - return annotateErr(errors.New("resp error: "+resp.Error.Message), allPeers[i], "when close write stream") - } - if leaderID == region.Peers[i].GetId() { - leaderPeerMetas = resp.Metas - log.FromContext(ctx).Debug("get metas after write kv stream to tikv", zap.Reflect("metas", leaderPeerMetas)) - } - } - - failpoint.Inject("NoLeader", func() { - log.FromContext(ctx).Warn("enter failpoint NoLeader") - leaderPeerMetas = nil - }) - - // if there is not leader currently, we don't forward the stage to wrote and let caller - // handle the retry. - if len(leaderPeerMetas) == 0 { - log.FromContext(ctx).Warn("write to tikv no leader", - logutil.Region(region), logutil.Leader(j.region.Leader), - zap.Uint64("leader_id", leaderID), logutil.SSTMeta(meta), - zap.Int64("kv_pairs", totalCount), zap.Int64("total_bytes", totalSize)) - return common.ErrNoLeader.GenWithStackByArgs(region.Id, leaderID) - } - - takeTime := time.Since(begin) - log.FromContext(ctx).Debug("write to kv", zap.Reflect("region", j.region), zap.Uint64("leader", leaderID), - zap.Reflect("meta", meta), zap.Reflect("return metas", leaderPeerMetas), - zap.Int64("kv_pairs", totalCount), zap.Int64("total_bytes", totalSize), - zap.Stringer("takeTime", takeTime)) - if m, ok := metric.FromContext(ctx); ok { - m.SSTSecondsHistogram.WithLabelValues(metric.SSTProcessWrite).Observe(takeTime.Seconds()) - } - - j.writeResult = &tikvWriteResult{ - sstMeta: leaderPeerMetas, - count: totalCount, - totalBytes: totalSize, - remainingStartKey: remainingStartKey, - } - j.convertStageTo(wrote) - return nil -} - -// ingest tries to finish the regionJob. -// if any ingest logic has error, ingest may retry sometimes to resolve it and finally -// set job to a proper stage with nil error returned. -// if any underlying logic has error, ingest will return an error to let caller -// handle it. 
-func (local *Backend) ingest(ctx context.Context, j *regionJob) (err error) { - if j.stage != wrote { - return nil - } - - failpoint.Inject("fakeRegionJobs", func() { - front := j.injected[0] - j.injected = j.injected[1:] - j.convertStageTo(front.ingest.nextStage) - failpoint.Return(front.ingest.err) - }) - - if len(j.writeResult.sstMeta) == 0 { - j.convertStageTo(ingested) - return nil - } - - if m, ok := metric.FromContext(ctx); ok { - begin := time.Now() - defer func() { - if err == nil { - m.SSTSecondsHistogram.WithLabelValues(metric.SSTProcessIngest).Observe(time.Since(begin).Seconds()) - } - }() - } - - for retry := 0; retry < maxRetryTimes; retry++ { - resp, err := local.doIngest(ctx, j) - if err == nil && resp.GetError() == nil { - j.convertStageTo(ingested) - return nil - } - if err != nil { - if common.IsContextCanceledError(err) { - return err - } - log.FromContext(ctx).Warn("meet underlying error, will retry ingest", - log.ShortError(err), logutil.SSTMetas(j.writeResult.sstMeta), - logutil.Region(j.region.Region), logutil.Leader(j.region.Leader)) - continue - } - canContinue, err := j.convertStageOnIngestError(resp) - if common.IsContextCanceledError(err) { - return err - } - if !canContinue { - log.FromContext(ctx).Warn("meet error and handle the job later", - zap.Stringer("job stage", j.stage), - logutil.ShortError(j.lastRetryableErr), - j.region.ToZapFields(), - logutil.Key("start", j.keyRange.Start), - logutil.Key("end", j.keyRange.End)) - return nil - } - log.FromContext(ctx).Warn("meet error and will doIngest region again", - logutil.ShortError(j.lastRetryableErr), - j.region.ToZapFields(), - logutil.Key("start", j.keyRange.Start), - logutil.Key("end", j.keyRange.End)) - } - return nil -} - -func (local *Backend) checkWriteStall( - ctx context.Context, - region *split.RegionInfo, -) (bool, *sst.IngestResponse, error) { - clientFactory := local.importClientFactory - for _, peer := range region.Region.GetPeers() { - cli, err := clientFactory.Create(ctx, peer.StoreId) - if err != nil { - return false, nil, errors.Trace(err) - } - // currently we use empty MultiIngestRequest to check if TiKV is busy. - // If in future the rate limit feature contains more metrics we can switch to use it. - resp, err := cli.MultiIngest(ctx, &sst.MultiIngestRequest{}) - if err != nil { - return false, nil, errors.Trace(err) - } - if resp.Error != nil && resp.Error.ServerIsBusy != nil { - return true, resp, nil - } - } - return false, nil, nil -} - -// doIngest send ingest commands to TiKV based on regionJob.writeResult.sstMeta. -// When meet error, it will remove finished sstMetas before return. 
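ingest() above folds every outcome of one doIngest attempt into three cases. A compact sketch of that decision table, using a hypothetical helper name with the same (canContinue, err) shape as convertStageOnIngestError, whose contract is documented below:

    // nextAction is a made-up helper; the wording of each case follows the
    // convertStageOnIngestError contract.
    func nextAction(canContinue bool, err error) string {
        switch {
        case err != nil:
            return "error"   // "another error occurred"
        case canContinue:
            return "retry"   // the job can retry ingesting immediately
        default:
            return "requeue" // the job should be put back to the queue
        }
    }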
-func (local *Backend) doIngest(ctx context.Context, j *regionJob) (*sst.IngestResponse, error) { - clientFactory := local.importClientFactory - supportMultiIngest := local.supportMultiIngest - shouldCheckWriteStall := local.ShouldCheckWriteStall - if shouldCheckWriteStall { - writeStall, resp, err := local.checkWriteStall(ctx, j.region) - if err != nil { - return nil, errors.Trace(err) - } - if writeStall { - return resp, nil - } - } - - batch := 1 - if supportMultiIngest { - batch = len(j.writeResult.sstMeta) - } - - var resp *sst.IngestResponse - for start := 0; start < len(j.writeResult.sstMeta); start += batch { - end := min(start+batch, len(j.writeResult.sstMeta)) - ingestMetas := j.writeResult.sstMeta[start:end] - - log.FromContext(ctx).Debug("ingest meta", zap.Reflect("meta", ingestMetas)) - - failpoint.Inject("FailIngestMeta", func(val failpoint.Value) { - // only inject the error once - var resp *sst.IngestResponse - - switch val.(string) { - case "notleader": - resp = &sst.IngestResponse{ - Error: &errorpb.Error{ - NotLeader: &errorpb.NotLeader{ - RegionId: j.region.Region.Id, - Leader: j.region.Leader, - }, - }, - } - case "epochnotmatch": - resp = &sst.IngestResponse{ - Error: &errorpb.Error{ - EpochNotMatch: &errorpb.EpochNotMatch{ - CurrentRegions: []*metapb.Region{j.region.Region}, - }, - }, - } - } - failpoint.Return(resp, nil) - }) - - leader := j.region.Leader - if leader == nil { - return nil, errors.Annotatef(berrors.ErrPDLeaderNotFound, - "region id %d has no leader", j.region.Region.Id) - } - - cli, err := clientFactory.Create(ctx, leader.StoreId) - if err != nil { - return nil, errors.Trace(err) - } - reqCtx := &kvrpcpb.Context{ - RegionId: j.region.Region.GetId(), - RegionEpoch: j.region.Region.GetRegionEpoch(), - Peer: leader, - ResourceControlContext: &kvrpcpb.ResourceControlContext{ - ResourceGroupName: local.ResourceGroupName, - }, - RequestSource: util.BuildRequestSource(true, kv.InternalTxnLightning, local.TaskType), - } - - if supportMultiIngest { - req := &sst.MultiIngestRequest{ - Context: reqCtx, - Ssts: ingestMetas, - } - resp, err = cli.MultiIngest(ctx, req) - } else { - req := &sst.IngestRequest{ - Context: reqCtx, - Sst: ingestMetas[0], - } - resp, err = cli.Ingest(ctx, req) - } - if resp.GetError() != nil || err != nil { - // remove finished sstMetas - j.writeResult.sstMeta = j.writeResult.sstMeta[start:] - return resp, errors.Trace(err) - } - } - return resp, nil -} - -// convertStageOnIngestError will try to fix the error contained in ingest response. -// Return (_, error) when another error occurred. -// Return (true, nil) when the job can retry ingesting immediately. -// Return (false, nil) when the job should be put back to queue. -func (j *regionJob) convertStageOnIngestError( - resp *sst.IngestResponse, -) (bool, error) { - if resp.GetError() == nil { - return true, nil - } - - var newRegion *split.RegionInfo - switch errPb := resp.GetError(); { - case errPb.NotLeader != nil: - j.lastRetryableErr = common.ErrKVNotLeader.GenWithStack(errPb.GetMessage()) - - // meet a problem that the region leader+peer are all updated but the return - // error is only "NotLeader", we should update the whole region info. 
- j.convertStageTo(needRescan) - return false, nil - case errPb.EpochNotMatch != nil: - j.lastRetryableErr = common.ErrKVEpochNotMatch.GenWithStack(errPb.GetMessage()) - - if currentRegions := errPb.GetEpochNotMatch().GetCurrentRegions(); currentRegions != nil { - var currentRegion *metapb.Region - for _, r := range currentRegions { - if insideRegion(r, j.writeResult.sstMeta) { - currentRegion = r - break - } - } - if currentRegion != nil { - var newLeader *metapb.Peer - for _, p := range currentRegion.Peers { - if p.GetStoreId() == j.region.Leader.GetStoreId() { - newLeader = p - break - } - } - if newLeader != nil { - newRegion = &split.RegionInfo{ - Leader: newLeader, - Region: currentRegion, - } - } - } - } - if newRegion != nil { - j.region = newRegion - j.convertStageTo(regionScanned) - return false, nil - } - j.convertStageTo(needRescan) - return false, nil - case strings.Contains(errPb.Message, "raft: proposal dropped"): - j.lastRetryableErr = common.ErrKVRaftProposalDropped.GenWithStack(errPb.GetMessage()) - - j.convertStageTo(needRescan) - return false, nil - case errPb.ServerIsBusy != nil: - j.lastRetryableErr = common.ErrKVServerIsBusy.GenWithStack(errPb.GetMessage()) - - return false, nil - case errPb.RegionNotFound != nil: - j.lastRetryableErr = common.ErrKVRegionNotFound.GenWithStack(errPb.GetMessage()) - - j.convertStageTo(needRescan) - return false, nil - case errPb.ReadIndexNotReady != nil: - j.lastRetryableErr = common.ErrKVReadIndexNotReady.GenWithStack(errPb.GetMessage()) - - // this error happens when this region is splitting, the error might be: - // read index not ready, reason can not read index due to split, region 64037 - // we have paused schedule, but it's temporary, - // if next request takes a long time, there's chance schedule is enabled again - // or on key range border, another engine sharing this region tries to split this - // region may cause this error too. - j.convertStageTo(needRescan) - return false, nil - case errPb.DiskFull != nil: - j.lastRetryableErr = common.ErrKVIngestFailed.GenWithStack(errPb.GetMessage()) - - return false, errors.Errorf("non-retryable error: %s", resp.GetError().GetMessage()) - } - // all others doIngest error, such as stale command, etc. we'll retry it again from writeAndIngestByRange - j.lastRetryableErr = common.ErrKVIngestFailed.GenWithStack(resp.GetError().GetMessage()) - j.convertStageTo(regionScanned) - return false, nil -} - -type regionJobRetryHeap []*regionJob - -var _ heap.Interface = (*regionJobRetryHeap)(nil) - -func (h *regionJobRetryHeap) Len() int { - return len(*h) -} - -func (h *regionJobRetryHeap) Less(i, j int) bool { - v := *h - return v[i].waitUntil.Before(v[j].waitUntil) -} - -func (h *regionJobRetryHeap) Swap(i, j int) { - v := *h - v[i], v[j] = v[j], v[i] -} - -func (h *regionJobRetryHeap) Push(x any) { - *h = append(*h, x.(*regionJob)) -} - -func (h *regionJobRetryHeap) Pop() any { - old := *h - n := len(old) - x := old[n-1] - *h = old[0 : n-1] - return x -} - -// regionJobRetryer is a concurrent-safe queue holding jobs that need to put -// back later, and put back when the regionJob.waitUntil is reached. It maintains -// a heap of jobs internally based on the regionJob.waitUntil field. 
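regionJobRetryHeap above only supplies the heap.Interface methods; the ordering guarantees come from container/heap, which has to be driven through heap.Push and heap.Pop rather than the methods directly. A small self-contained sketch of the same pattern with a hypothetical item type:

    package example

    import (
        "container/heap"
        "time"
    )

    type item struct{ waitUntil time.Time }

    type byWaitUntil []*item

    func (h byWaitUntil) Len() int           { return len(h) }
    func (h byWaitUntil) Less(i, j int) bool { return h[i].waitUntil.Before(h[j].waitUntil) }
    func (h byWaitUntil) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
    func (h *byWaitUntil) Push(x any)        { *h = append(*h, x.(*item)) }
    func (h *byWaitUntil) Pop() any {
        old := *h
        n := len(old)
        x := old[n-1]
        *h = old[:n-1]
        return x
    }

    // soonest returns the item whose waitUntil comes first, which is what the
    // retryer's run loop peeks at before sleeping.
    func soonest(items []*item) *item {
        h := make(byWaitUntil, 0, len(items))
        for _, it := range items {
            heap.Push(&h, it) // heap.Push keeps the min-heap invariant; h.Push alone would not
        }
        return heap.Pop(&h).(*item)
    }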
-type regionJobRetryer struct { - // lock acquiring order: protectedClosed > protectedQueue > protectedToPutBack - protectedClosed struct { - mu sync.Mutex - closed bool - } - protectedQueue struct { - mu sync.Mutex - q regionJobRetryHeap - } - protectedToPutBack struct { - mu sync.Mutex - toPutBack *regionJob - } - putBackCh chan<- *regionJob - reload chan struct{} - jobWg *sync.WaitGroup -} - -// startRegionJobRetryer starts a new regionJobRetryer and it will run in -// background to put the job back to `putBackCh` when job's waitUntil is reached. -// Cancel the `ctx` will stop retryer and `jobWg.Done` will be trigger for jobs -// that are not put back yet. -func startRegionJobRetryer( - ctx context.Context, - putBackCh chan<- *regionJob, - jobWg *sync.WaitGroup, -) *regionJobRetryer { - ret := ®ionJobRetryer{ - putBackCh: putBackCh, - reload: make(chan struct{}, 1), - jobWg: jobWg, - } - ret.protectedQueue.q = make(regionJobRetryHeap, 0, 16) - go ret.run(ctx) - return ret -} - -// run is only internally used, caller should not use it. -func (q *regionJobRetryer) run(ctx context.Context) { - defer q.close() - - for { - var front *regionJob - q.protectedQueue.mu.Lock() - if len(q.protectedQueue.q) > 0 { - front = q.protectedQueue.q[0] - } - q.protectedQueue.mu.Unlock() - - switch { - case front != nil: - select { - case <-ctx.Done(): - return - case <-q.reload: - case <-time.After(time.Until(front.waitUntil)): - q.protectedQueue.mu.Lock() - q.protectedToPutBack.mu.Lock() - q.protectedToPutBack.toPutBack = heap.Pop(&q.protectedQueue.q).(*regionJob) - // release the lock of queue to avoid blocking regionJobRetryer.push - q.protectedQueue.mu.Unlock() - - // hold the lock of toPutBack to make sending to putBackCh and - // resetting toPutBack atomic w.r.t. regionJobRetryer.close - select { - case <-ctx.Done(): - q.protectedToPutBack.mu.Unlock() - return - case q.putBackCh <- q.protectedToPutBack.toPutBack: - q.protectedToPutBack.toPutBack = nil - q.protectedToPutBack.mu.Unlock() - } - } - default: - // len(q.q) == 0 - select { - case <-ctx.Done(): - return - case <-q.reload: - } - } - } -} - -// close is only internally used, caller should not use it. -func (q *regionJobRetryer) close() { - q.protectedClosed.mu.Lock() - defer q.protectedClosed.mu.Unlock() - q.protectedClosed.closed = true - - if q.protectedToPutBack.toPutBack != nil { - q.protectedToPutBack.toPutBack.done(q.jobWg) - } - for _, job := range q.protectedQueue.q { - job.done(q.jobWg) - } -} - -// push should not be blocked for long time in any cases. 
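push wakes the run loop through the reload channel without ever blocking: the channel has capacity 1, so concurrent pushes coalesce into a single pending wake-up. A minimal sketch of the idiom, independent of the retryer types:

    // notify never blocks: if a wake-up is already pending the new one is
    // dropped, which is safe because the receiver re-reads shared state anyway.
    func notify(reload chan struct{}) {
        select {
        case reload <- struct{}{}:
        default:
        }
    }

    // The consumer side waits for either a wake-up or cancellation:
    //
    //     select {
    //     case <-ctx.Done():
    //         return
    //     case <-reload:
    //         // re-check the queue head
    //     }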
-func (q *regionJobRetryer) push(job *regionJob) bool { - q.protectedClosed.mu.Lock() - defer q.protectedClosed.mu.Unlock() - if q.protectedClosed.closed { - return false - } - - q.protectedQueue.mu.Lock() - heap.Push(&q.protectedQueue.q, job) - q.protectedQueue.mu.Unlock() - - select { - case q.reload <- struct{}{}: - default: - } - return true -} diff --git a/pkg/lightning/backend/tidb/binding__failpoint_binding__.go b/pkg/lightning/backend/tidb/binding__failpoint_binding__.go deleted file mode 100644 index 7cf3faffb040f..0000000000000 --- a/pkg/lightning/backend/tidb/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package tidb - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/lightning/backend/tidb/tidb.go b/pkg/lightning/backend/tidb/tidb.go index 2d0ae231b4fbc..186e17d9d6d06 100644 --- a/pkg/lightning/backend/tidb/tidb.go +++ b/pkg/lightning/backend/tidb/tidb.go @@ -252,11 +252,12 @@ func (b *targetInfoGetter) FetchRemoteTableModels(ctx context.Context, schemaNam return nil } - if _, _err_ := failpoint.Eval(_curpkg_("FetchRemoteTableModels_BeforeFetchTableAutoIDInfos")); _err_ == nil { - - fmt.Println("failpoint: FetchRemoteTableModels_BeforeFetchTableAutoIDInfos") - - } + failpoint.Inject( + "FetchRemoteTableModels_BeforeFetchTableAutoIDInfos", + func() { + fmt.Println("failpoint: FetchRemoteTableModels_BeforeFetchTableAutoIDInfos") + }, + ) // init auto id column for each table for _, tbl := range tables { @@ -837,9 +838,9 @@ stmtLoop: } // max-error not yet reached (error consumed by errorMgr), proceed to next stmtTask. } - if _, _err_ := failpoint.Eval(_curpkg_("FailIfImportedSomeRows")); _err_ == nil { + failpoint.Inject("FailIfImportedSomeRows", func() { panic("forcing failure due to FailIfImportedSomeRows, before saving checkpoint") - } + }) return nil } diff --git a/pkg/lightning/backend/tidb/tidb.go__failpoint_stash__ b/pkg/lightning/backend/tidb/tidb.go__failpoint_stash__ deleted file mode 100644 index 186e17d9d6d06..0000000000000 --- a/pkg/lightning/backend/tidb/tidb.go__failpoint_stash__ +++ /dev/null @@ -1,956 +0,0 @@ -// Copyright 2019 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package tidb - -import ( - "context" - "database/sql" - "encoding/hex" - "fmt" - "strconv" - "strings" - "time" - - gmysql "github.com/go-sql-driver/mysql" - "github.com/google/uuid" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/version" - "github.com/pingcap/tidb/pkg/errno" - "github.com/pingcap/tidb/pkg/lightning/backend" - "github.com/pingcap/tidb/pkg/lightning/backend/encode" - "github.com/pingcap/tidb/pkg/lightning/backend/kv" - "github.com/pingcap/tidb/pkg/lightning/common" - "github.com/pingcap/tidb/pkg/lightning/config" - "github.com/pingcap/tidb/pkg/lightning/errormanager" - "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/lightning/verification" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/dbutil" - "github.com/pingcap/tidb/pkg/util/redact" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" -) - -var extraHandleTableColumn = &table.Column{ - ColumnInfo: kv.ExtraHandleColumnInfo, - GeneratedExpr: nil, - DefaultExpr: nil, -} - -const ( - writeRowsMaxRetryTimes = 3 -) - -type tidbRow struct { - insertStmt string - path string - offset int64 -} - -var emptyTiDBRow = tidbRow{ - insertStmt: "", - path: "", - offset: 0, -} - -type tidbRows []tidbRow - -// MarshalLogArray implements the zapcore.ArrayMarshaler interface -func (rows tidbRows) MarshalLogArray(encoder zapcore.ArrayEncoder) error { - for _, r := range rows { - encoder.AppendString(redact.Value(r.insertStmt)) - } - return nil -} - -type tidbEncoder struct { - mode mysql.SQLMode - tbl table.Table - se sessionctx.Context - // the index of table columns for each data field. - // index == len(table.columns) means this field is `_tidb_rowid` - columnIdx []int - // the max index used in this chunk, due to the ignore-columns config, we can't - // directly check the total column count, so we fall back to only check that - // the there are enough columns. - columnCnt int - // data file path - path string - logger log.Logger -} - -type encodingBuilder struct{} - -// NewEncodingBuilder creates an EncodingBuilder with TiDB backend implementation. -func NewEncodingBuilder() encode.EncodingBuilder { - return new(encodingBuilder) -} - -// NewEncoder creates a KV encoder. -// It implements the `backend.EncodingBuilder` interface. -func (*encodingBuilder) NewEncoder(ctx context.Context, config *encode.EncodingConfig) (encode.Encoder, error) { - se := kv.NewSessionCtx(&config.SessionOptions, log.FromContext(ctx)) - if config.SQLMode.HasStrictMode() { - se.GetSessionVars().SkipUTF8Check = false - se.GetSessionVars().SkipASCIICheck = false - } - - return &tidbEncoder{ - mode: config.SQLMode, - tbl: config.Table, - se: se, - path: config.Path, - logger: config.Logger, - }, nil -} - -// MakeEmptyRows creates an empty KV rows. -// It implements the `backend.EncodingBuilder` interface. -func (*encodingBuilder) MakeEmptyRows() encode.Rows { - return tidbRows(nil) -} - -type targetInfoGetter struct { - db *sql.DB -} - -// NewTargetInfoGetter creates an TargetInfoGetter with TiDB backend implementation. -func NewTargetInfoGetter(db *sql.DB) backend.TargetInfoGetter { - return &targetInfoGetter{ - db: db, - } -} - -// FetchRemoteDBModels implements the `backend.TargetInfoGetter` interface. 
-func (b *targetInfoGetter) FetchRemoteDBModels(ctx context.Context) ([]*model.DBInfo, error) { - results := []*model.DBInfo{} - logger := log.FromContext(ctx) - s := common.SQLWithRetry{ - DB: b.db, - Logger: logger, - } - err := s.Transact(ctx, "fetch db models", func(_ context.Context, tx *sql.Tx) error { - results = results[:0] - - rows, e := tx.Query("SHOW DATABASES") - if e != nil { - return e - } - defer rows.Close() - - for rows.Next() { - var dbName string - if e := rows.Scan(&dbName); e != nil { - return e - } - dbInfo := &model.DBInfo{ - Name: model.NewCIStr(dbName), - } - results = append(results, dbInfo) - } - return rows.Err() - }) - return results, err -} - -// FetchRemoteTableModels obtains the models of all tables given the schema name. -// It implements the `backend.TargetInfoGetter` interface. -// TODO: refactor -func (b *targetInfoGetter) FetchRemoteTableModels(ctx context.Context, schemaName string) ([]*model.TableInfo, error) { - var err error - results := []*model.TableInfo{} - logger := log.FromContext(ctx) - s := common.SQLWithRetry{ - DB: b.db, - Logger: logger, - } - - err = s.Transact(ctx, "fetch table columns", func(_ context.Context, tx *sql.Tx) error { - var versionStr string - if versionStr, err = version.FetchVersion(ctx, tx); err != nil { - return err - } - serverInfo := version.ParseServerInfo(versionStr) - - rows, e := tx.Query(` - SELECT table_name, column_name, column_type, generation_expression, extra - FROM information_schema.columns - WHERE table_schema = ? - ORDER BY table_name, ordinal_position; - `, schemaName) - if e != nil { - return e - } - defer rows.Close() - - var ( - curTableName string - curColOffset int - curTable *model.TableInfo - ) - tables := []*model.TableInfo{} - for rows.Next() { - var tableName, columnName, columnType, generationExpr, columnExtra string - if e := rows.Scan(&tableName, &columnName, &columnType, &generationExpr, &columnExtra); e != nil { - return e - } - if tableName != curTableName { - curTable = &model.TableInfo{ - Name: model.NewCIStr(tableName), - State: model.StatePublic, - PKIsHandle: true, - } - tables = append(tables, curTable) - curTableName = tableName - curColOffset = 0 - } - - // see: https://github.com/pingcap/parser/blob/3b2fb4b41d73710bc6c4e1f4e8679d8be6a4863e/types/field_type.go#L185-L191 - var flag uint - if strings.HasSuffix(columnType, "unsigned") { - flag |= mysql.UnsignedFlag - } - if strings.Contains(columnExtra, "auto_increment") { - flag |= mysql.AutoIncrementFlag - } - - ft := types.FieldType{} - ft.SetFlag(flag) - curTable.Columns = append(curTable.Columns, &model.ColumnInfo{ - Name: model.NewCIStr(columnName), - Offset: curColOffset, - State: model.StatePublic, - FieldType: ft, - GeneratedExprString: generationExpr, - }) - curColOffset++ - } - if err := rows.Err(); err != nil { - return err - } - // shard_row_id/auto random is only available after tidb v4.0.0 - // `show table next_row_id` is also not available before tidb v4.0.0 - if serverInfo.ServerType != version.ServerTypeTiDB || serverInfo.ServerVersion.Major < 4 { - results = tables - return nil - } - - failpoint.Inject( - "FetchRemoteTableModels_BeforeFetchTableAutoIDInfos", - func() { - fmt.Println("failpoint: FetchRemoteTableModels_BeforeFetchTableAutoIDInfos") - }, - ) - - // init auto id column for each table - for _, tbl := range tables { - tblName := common.UniqueTable(schemaName, tbl.Name.O) - autoIDInfos, err := FetchTableAutoIDInfos(ctx, tx, tblName) - if err != nil { - logger.Warn("fetch table auto ID infos error. 
Ignore this table and continue.", zap.String("table_name", tblName), zap.Error(err)) - continue - } - for _, info := range autoIDInfos { - for _, col := range tbl.Columns { - if col.Name.O == info.Column { - switch info.Type { - case "AUTO_INCREMENT": - col.AddFlag(mysql.AutoIncrementFlag) - case "AUTO_RANDOM": - col.AddFlag(mysql.PriKeyFlag) - tbl.PKIsHandle = true - // set a stub here, since we don't really need the real value - tbl.AutoRandomBits = 1 - } - } - } - } - results = append(results, tbl) - } - return nil - }) - return results, err -} - -// CheckRequirements performs the check whether the backend satisfies the version requirements. -// It implements the `backend.TargetInfoGetter` interface. -func (*targetInfoGetter) CheckRequirements(ctx context.Context, _ *backend.CheckCtx) error { - log.FromContext(ctx).Info("skipping check requirements for tidb backend") - return nil -} - -type tidbBackend struct { - db *sql.DB - conflictCfg config.Conflict - // onDuplicate is the type of INSERT SQL. It may be different with - // conflictCfg.Strategy to implement other feature, but the behaviour in caller's - // view should be the same. - onDuplicate config.DuplicateResolutionAlgorithm - errorMgr *errormanager.ErrorManager - // maxChunkSize and maxChunkRows are the target size and number of rows of each INSERT SQL - // statement to be sent to downstream. Sometimes we want to reduce the txn size to avoid - // affecting the cluster too much. - maxChunkSize uint64 - maxChunkRows int -} - -var _ backend.Backend = (*tidbBackend)(nil) - -// NewTiDBBackend creates a new TiDB backend using the given database. -// -// The backend does not take ownership of `db`. Caller should close `db` -// manually after the backend expired. -func NewTiDBBackend( - ctx context.Context, - db *sql.DB, - cfg *config.Config, - errorMgr *errormanager.ErrorManager, -) backend.Backend { - conflict := cfg.Conflict - var onDuplicate config.DuplicateResolutionAlgorithm - switch conflict.Strategy { - case config.ErrorOnDup: - onDuplicate = config.ErrorOnDup - case config.ReplaceOnDup: - onDuplicate = config.ReplaceOnDup - case config.IgnoreOnDup: - if conflict.MaxRecordRows == 0 { - onDuplicate = config.IgnoreOnDup - } else { - // need to stop batch insert on error and fall back to row by row insert - // to record the row - onDuplicate = config.ErrorOnDup - } - default: - log.FromContext(ctx).Warn("unsupported conflict strategy for TiDB backend, overwrite with `error`") - onDuplicate = config.ErrorOnDup - } - return &tidbBackend{ - db: db, - conflictCfg: conflict, - onDuplicate: onDuplicate, - errorMgr: errorMgr, - maxChunkSize: uint64(cfg.TikvImporter.LogicalImportBatchSize), - maxChunkRows: cfg.TikvImporter.LogicalImportBatchRows, - } -} - -func (row tidbRow) Size() uint64 { - return uint64(len(row.insertStmt)) -} - -func (row tidbRow) String() string { - return row.insertStmt -} - -func (row tidbRow) ClassifyAndAppend(data *encode.Rows, checksum *verification.KVChecksum, _ *encode.Rows, _ *verification.KVChecksum) { - rows := (*data).(tidbRows) - // Cannot do `rows := data.(*tidbRows); *rows = append(*rows, row)`. 
- //nolint:gocritic - *data = append(rows, row) - cs := verification.MakeKVChecksum(row.Size(), 1, 0) - checksum.Add(&cs) -} - -func (rows tidbRows) splitIntoChunks(splitSize uint64, splitRows int) []tidbRows { - if len(rows) == 0 { - return nil - } - - res := make([]tidbRows, 0, 1) - i := 0 - cumSize := uint64(0) - - for j, row := range rows { - if i < j && (cumSize+row.Size() > splitSize || j-i >= splitRows) { - res = append(res, rows[i:j]) - i = j - cumSize = 0 - } - cumSize += row.Size() - } - - return append(res, rows[i:]) -} - -func (rows tidbRows) Clear() encode.Rows { - return rows[:0] -} - -func (enc *tidbEncoder) appendSQLBytes(sb *strings.Builder, value []byte) { - sb.Grow(2 + len(value)) - sb.WriteByte('\'') - if enc.mode.HasNoBackslashEscapesMode() { - for _, b := range value { - if b == '\'' { - sb.WriteString(`''`) - } else { - sb.WriteByte(b) - } - } - } else { - for _, b := range value { - switch b { - case 0: - sb.WriteString(`\0`) - case '\b': - sb.WriteString(`\b`) - case '\n': - sb.WriteString(`\n`) - case '\r': - sb.WriteString(`\r`) - case '\t': - sb.WriteString(`\t`) - case 26: - sb.WriteString(`\Z`) - case '\'': - sb.WriteString(`''`) - case '\\': - sb.WriteString(`\\`) - default: - sb.WriteByte(b) - } - } - } - sb.WriteByte('\'') -} - -// appendSQL appends the SQL representation of the Datum into the string builder. -// Note that we cannot use Datum.ToString since it doesn't perform SQL escaping. -func (enc *tidbEncoder) appendSQL(sb *strings.Builder, datum *types.Datum, _ *table.Column) error { - switch datum.Kind() { - case types.KindNull: - sb.WriteString("NULL") - - case types.KindMinNotNull: - sb.WriteString("MINVALUE") - - case types.KindMaxValue: - sb.WriteString("MAXVALUE") - - case types.KindInt64: - // longest int64 = -9223372036854775808 which has 20 characters - var buffer [20]byte - value := strconv.AppendInt(buffer[:0], datum.GetInt64(), 10) - sb.Write(value) - - case types.KindUint64, types.KindMysqlEnum, types.KindMysqlSet: - // longest uint64 = 18446744073709551615 which has 20 characters - var buffer [20]byte - value := strconv.AppendUint(buffer[:0], datum.GetUint64(), 10) - sb.Write(value) - - case types.KindFloat32, types.KindFloat64: - // float64 has 16 digits of precision, so a buffer size of 32 is more than enough... 
- var buffer [32]byte - value := strconv.AppendFloat(buffer[:0], datum.GetFloat64(), 'g', -1, 64) - sb.Write(value) - case types.KindString: - // See: https://github.com/pingcap/tidb-lightning/issues/550 - // if enc.mode.HasStrictMode() { - // d, err := table.CastValue(enc.se, *datum, col.ToInfo(), false, false) - // if err != nil { - // return errors.Trace(err) - // } - // datum = &d - // } - - enc.appendSQLBytes(sb, datum.GetBytes()) - case types.KindBytes: - enc.appendSQLBytes(sb, datum.GetBytes()) - - case types.KindMysqlJSON: - value, err := datum.GetMysqlJSON().MarshalJSON() - if err != nil { - return err - } - enc.appendSQLBytes(sb, value) - - case types.KindBinaryLiteral: - value := datum.GetBinaryLiteral() - sb.Grow(3 + 2*len(value)) - sb.WriteString("x'") - if _, err := hex.NewEncoder(sb).Write(value); err != nil { - return errors.Trace(err) - } - sb.WriteByte('\'') - - case types.KindMysqlBit: - var buffer [20]byte - intValue, err := datum.GetBinaryLiteral().ToInt(types.DefaultStmtNoWarningContext) - if err != nil { - return err - } - value := strconv.AppendUint(buffer[:0], intValue, 10) - sb.Write(value) - - // time, duration, decimal - default: - value, err := datum.ToString() - if err != nil { - return err - } - sb.WriteByte('\'') - sb.WriteString(value) - sb.WriteByte('\'') - } - - return nil -} - -func (*tidbEncoder) Close() {} - -func getColumnByIndex(cols []*table.Column, index int) *table.Column { - if index == len(cols) { - return extraHandleTableColumn - } - return cols[index] -} - -func (enc *tidbEncoder) Encode(row []types.Datum, _ int64, columnPermutation []int, offset int64) (encode.Row, error) { - cols := enc.tbl.Cols() - - if len(enc.columnIdx) == 0 { - columnMaxIdx := -1 - columnIdx := make([]int, len(columnPermutation)) - for i := 0; i < len(columnPermutation); i++ { - columnIdx[i] = -1 - } - for i, idx := range columnPermutation { - if idx >= 0 { - columnIdx[idx] = i - if idx > columnMaxIdx { - columnMaxIdx = idx - } - } - } - enc.columnIdx = columnIdx - enc.columnCnt = columnMaxIdx + 1 - } - - // TODO: since the column count doesn't exactly reflect the real column names, we only check the upper bound currently. - // See: tests/generated_columns/data/gencol.various_types.0.sql this sql has no columns, so encodeLoop will fill the - // column permutation with default, thus enc.columnCnt > len(row). - if len(row) < enc.columnCnt { - // 1. if len(row) < enc.columnCnt: data in row cannot populate the insert statement, because - // there are enc.columnCnt elements to insert but fewer columns in row - enc.logger.Error("column count mismatch", zap.Ints("column_permutation", columnPermutation), - zap.Array("data", kv.RowArrayMarshaller(row))) - return emptyTiDBRow, errors.Errorf("column count mismatch, expected %d, got %d", enc.columnCnt, len(row)) - } - - if len(row) > len(enc.columnIdx) { - // 2. 
if len(row) > len(columnIdx): raw row data has more columns than those - // in the table - enc.logger.Error("column count mismatch", zap.Ints("column_count", enc.columnIdx), - zap.Array("data", kv.RowArrayMarshaller(row))) - return emptyTiDBRow, errors.Errorf("column count mismatch, at most %d but got %d", len(enc.columnIdx), len(row)) - } - - var encoded strings.Builder - encoded.Grow(8 * len(row)) - encoded.WriteByte('(') - cnt := 0 - for i, field := range row { - if enc.columnIdx[i] < 0 { - continue - } - if cnt > 0 { - encoded.WriteByte(',') - } - datum := field - if err := enc.appendSQL(&encoded, &datum, getColumnByIndex(cols, enc.columnIdx[i])); err != nil { - enc.logger.Error("tidb encode failed", - zap.Array("original", kv.RowArrayMarshaller(row)), - zap.Int("originalCol", i), - log.ShortError(err), - ) - return nil, err - } - cnt++ - } - encoded.WriteByte(')') - return tidbRow{ - insertStmt: encoded.String(), - path: enc.path, - offset: offset, - }, nil -} - -// EncodeRowForRecord encodes a row to a string compatible with INSERT statements. -func EncodeRowForRecord(ctx context.Context, encTable table.Table, sqlMode mysql.SQLMode, row []types.Datum, columnPermutation []int) string { - enc := tidbEncoder{ - tbl: encTable, - mode: sqlMode, - logger: log.FromContext(ctx), - } - resRow, err := enc.Encode(row, 0, columnPermutation, 0) - if err != nil { - // if encode can't succeed, fallback to record the raw input strings - // ignore the error since it can only happen if the datum type is unknown, this can't happen here. - datumStr, _ := types.DatumsToString(row, true) - return datumStr - } - return resRow.(tidbRow).insertStmt -} - -func (*tidbBackend) Close() { - // *Not* going to close `be.db`. The db object is normally borrowed from a - // TidbManager, so we let the manager to close it. -} - -func (*tidbBackend) RetryImportDelay() time.Duration { - return 0 -} - -func (*tidbBackend) ShouldPostProcess() bool { - return true -} - -func (*tidbBackend) OpenEngine(context.Context, *backend.EngineConfig, uuid.UUID) error { - return nil -} - -func (*tidbBackend) CloseEngine(context.Context, *backend.EngineConfig, uuid.UUID) error { - return nil -} - -func (*tidbBackend) CleanupEngine(context.Context, uuid.UUID) error { - return nil -} - -func (*tidbBackend) ImportEngine(context.Context, uuid.UUID, int64, int64) error { - return nil -} - -func (be *tidbBackend) WriteRows(ctx context.Context, tableName string, columnNames []string, rows encode.Rows) error { - var err error -rowLoop: - for _, r := range rows.(tidbRows).splitIntoChunks(be.maxChunkSize, be.maxChunkRows) { - for i := 0; i < writeRowsMaxRetryTimes; i++ { - // Write in the batch mode first. - err = be.WriteBatchRowsToDB(ctx, tableName, columnNames, r) - switch { - case err == nil: - continue rowLoop - case common.IsRetryableError(err): - // retry next loop - case be.errorMgr.TypeErrorsRemain() > 0 || - be.errorMgr.ConflictErrorsRemain() > 0 || - (be.conflictCfg.Strategy == config.ErrorOnDup && !be.errorMgr.RecordErrorOnce()): - // WriteBatchRowsToDB failed in the batch mode and can not be retried, - // we need to redo the writing row-by-row to find where the error locates (and skip it correctly in future). - if err = be.WriteRowsToDB(ctx, tableName, columnNames, r); err != nil { - // If the error is not nil, it means we reach the max error count in the - // non-batch mode or this is "error" conflict strategy. 
- return errors.Annotatef(err, "[%s] write rows exceed conflict threshold", tableName) - } - continue rowLoop - default: - return err - } - } - return errors.Annotatef(err, "[%s] batch write rows reach max retry %d and still failed", tableName, writeRowsMaxRetryTimes) - } - return nil -} - -type stmtTask struct { - rows tidbRows - stmt string -} - -// WriteBatchRowsToDB write rows in batch mode, which will insert multiple rows like this: -// -// insert into t1 values (111), (222), (333), (444); -func (be *tidbBackend) WriteBatchRowsToDB(ctx context.Context, tableName string, columnNames []string, rows tidbRows) error { - insertStmt := be.checkAndBuildStmt(rows, tableName, columnNames) - if insertStmt == nil { - return nil - } - // Note: we are not going to do interpolation (prepared statements) to avoid - // complication arise from data length overflow of BIT and BINARY columns - stmtTasks := make([]stmtTask, 1) - for i, row := range rows { - if i != 0 { - insertStmt.WriteByte(',') - } - insertStmt.WriteString(row.insertStmt) - } - stmtTasks[0] = stmtTask{rows, insertStmt.String()} - return be.execStmts(ctx, stmtTasks, tableName, true) -} - -func (be *tidbBackend) checkAndBuildStmt(rows tidbRows, tableName string, columnNames []string) *strings.Builder { - if len(rows) == 0 { - return nil - } - return be.buildStmt(tableName, columnNames) -} - -// WriteRowsToDB write rows in row-by-row mode, which will insert multiple rows like this: -// -// insert into t1 values (111); -// insert into t1 values (222); -// insert into t1 values (333); -// insert into t1 values (444); -// -// See more details in br#1366: https://github.com/pingcap/br/issues/1366 -func (be *tidbBackend) WriteRowsToDB(ctx context.Context, tableName string, columnNames []string, rows tidbRows) error { - insertStmt := be.checkAndBuildStmt(rows, tableName, columnNames) - if insertStmt == nil { - return nil - } - is := insertStmt.String() - stmtTasks := make([]stmtTask, 0, len(rows)) - for _, row := range rows { - var finalInsertStmt strings.Builder - finalInsertStmt.WriteString(is) - finalInsertStmt.WriteString(row.insertStmt) - stmtTasks = append(stmtTasks, stmtTask{[]tidbRow{row}, finalInsertStmt.String()}) - } - return be.execStmts(ctx, stmtTasks, tableName, false) -} - -func (be *tidbBackend) buildStmt(tableName string, columnNames []string) *strings.Builder { - var insertStmt strings.Builder - switch be.onDuplicate { - case config.ReplaceOnDup: - insertStmt.WriteString("REPLACE INTO ") - case config.IgnoreOnDup: - insertStmt.WriteString("INSERT IGNORE INTO ") - case config.ErrorOnDup: - insertStmt.WriteString("INSERT INTO ") - } - insertStmt.WriteString(tableName) - if len(columnNames) > 0 { - insertStmt.WriteByte('(') - for i, colName := range columnNames { - if i != 0 { - insertStmt.WriteByte(',') - } - common.WriteMySQLIdentifier(&insertStmt, colName) - } - insertStmt.WriteByte(')') - } - insertStmt.WriteString(" VALUES") - return &insertStmt -} - -func (be *tidbBackend) execStmts(ctx context.Context, stmtTasks []stmtTask, tableName string, batch bool) error { -stmtLoop: - for _, stmtTask := range stmtTasks { - var ( - result sql.Result - err error - ) - for i := 0; i < writeRowsMaxRetryTimes; i++ { - stmt := stmtTask.stmt - result, err = be.db.ExecContext(ctx, stmt) - if err == nil { - affected, err2 := result.RowsAffected() - if err2 != nil { - // should not happen - return errors.Trace(err2) - } - diff := int64(len(stmtTask.rows)) - affected - if diff < 0 { - diff = -diff - } - if diff > 0 { - if err2 = 
be.errorMgr.RecordDuplicateCount(diff); err2 != nil { - return err2 - } - } - continue stmtLoop - } - - if !common.IsContextCanceledError(err) { - log.FromContext(ctx).Error("execute statement failed", - zap.Array("rows", stmtTask.rows), zap.String("stmt", redact.Value(stmt)), zap.Error(err)) - } - // It's batch mode, just return the error. Caller will fall back to row-by-row mode. - if batch { - return errors.Trace(err) - } - if !common.IsRetryableError(err) { - break - } - } - - firstRow := stmtTask.rows[0] - - if isDupEntryError(err) { - // rowID is ignored in tidb backend - if be.conflictCfg.Strategy == config.ErrorOnDup { - be.errorMgr.RecordDuplicateOnce( - ctx, - log.FromContext(ctx), - tableName, - firstRow.path, - firstRow.offset, - err.Error(), - 0, - firstRow.insertStmt, - ) - return err - } - err = be.errorMgr.RecordDuplicate( - ctx, - log.FromContext(ctx), - tableName, - firstRow.path, - firstRow.offset, - err.Error(), - 0, - firstRow.insertStmt, - ) - } else { - err = be.errorMgr.RecordTypeError( - ctx, - log.FromContext(ctx), - tableName, - firstRow.path, - firstRow.offset, - firstRow.insertStmt, - err, - ) - } - if err != nil { - return errors.Trace(err) - } - // max-error not yet reached (error consumed by errorMgr), proceed to next stmtTask. - } - failpoint.Inject("FailIfImportedSomeRows", func() { - panic("forcing failure due to FailIfImportedSomeRows, before saving checkpoint") - }) - return nil -} - -func isDupEntryError(err error) bool { - merr, ok := errors.Cause(err).(*gmysql.MySQLError) - if !ok { - return false - } - return merr.Number == errno.ErrDupEntry -} - -// FlushEngine flushes the data in the engine to the underlying storage. -func (*tidbBackend) FlushEngine(context.Context, uuid.UUID) error { - return nil -} - -// FlushAllEngines flushes all the data in the engines to the underlying storage. -func (*tidbBackend) FlushAllEngines(context.Context) error { - return nil -} - -// ResetEngine resets the engine. -func (*tidbBackend) ResetEngine(context.Context, uuid.UUID) error { - return errors.New("cannot reset an engine in TiDB backend") -} - -// LocalWriter returns a writer that writes data to local storage. -func (be *tidbBackend) LocalWriter( - _ context.Context, - cfg *backend.LocalWriterConfig, - _ uuid.UUID, -) (backend.EngineWriter, error) { - return &Writer{be: be, tableName: cfg.TiDB.TableName}, nil -} - -// Writer is a writer that writes data to local storage. -type Writer struct { - be *tidbBackend - tableName string -} - -// Close implements the EngineWriter interface. -func (*Writer) Close(_ context.Context) (backend.ChunkFlushStatus, error) { - return nil, nil -} - -// AppendRows implements the EngineWriter interface. -func (w *Writer) AppendRows(ctx context.Context, columnNames []string, rows encode.Rows) error { - return w.be.WriteRows(ctx, w.tableName, columnNames, rows) -} - -// IsSynced implements the EngineWriter interface. -func (*Writer) IsSynced() bool { - return true -} - -// TableAutoIDInfo is the auto id information of a table. -type TableAutoIDInfo struct { - Column string - NextID uint64 - Type string -} - -// FetchTableAutoIDInfos fetches the auto id information of a table. 
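FetchTableAutoIDInfos below parses `SHOW TABLE ... NEXT_ROW_ID`, tolerating the four-column output of early v4.0 releases. A hedged usage sketch, assuming it sits in the same package and that the handle satisfies dbutil.QueryExecutor (e.g. *sql.DB or *sql.Tx); the schema and table names are placeholders:

    // listAutoIDs is a hypothetical caller of FetchTableAutoIDInfos.
    func listAutoIDs(ctx context.Context, db *sql.DB, schema, table string) error {
        infos, err := FetchTableAutoIDInfos(ctx, db, common.UniqueTable(schema, table))
        if err != nil {
            return err
        }
        for _, info := range infos {
            // e.g. "_tidb_rowid 1 AUTO_INCREMENT", matching the sample output above
            fmt.Println(info.Column, info.NextID, info.Type)
        }
        return nil
    }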
-func FetchTableAutoIDInfos(ctx context.Context, exec dbutil.QueryExecutor, tableName string) ([]*TableAutoIDInfo, error) { - rows, e := exec.QueryContext(ctx, fmt.Sprintf("SHOW TABLE %s NEXT_ROW_ID", tableName)) - if e != nil { - return nil, errors.Trace(e) - } - var autoIDInfos []*TableAutoIDInfo - for rows.Next() { - var ( - dbName, tblName, columnName, idType string - nextID uint64 - ) - columns, err := rows.Columns() - if err != nil { - return nil, errors.Trace(err) - } - - //+--------------+------------+-------------+--------------------+----------------+ - //| DB_NAME | TABLE_NAME | COLUMN_NAME | NEXT_GLOBAL_ROW_ID | ID_TYPE | - //+--------------+------------+-------------+--------------------+----------------+ - //| testsysbench | t | _tidb_rowid | 1 | AUTO_INCREMENT | - //+--------------+------------+-------------+--------------------+----------------+ - - // if columns length is 4, it doesn't contain the last column `ID_TYPE`, and it will always be 'AUTO_INCREMENT' - // for v4.0.0~v4.0.2 show table t next_row_id only returns 4 columns. - if len(columns) == 4 { - err = rows.Scan(&dbName, &tblName, &columnName, &nextID) - idType = "AUTO_INCREMENT" - } else { - err = rows.Scan(&dbName, &tblName, &columnName, &nextID, &idType) - } - if err != nil { - return nil, errors.Trace(err) - } - autoIDInfos = append(autoIDInfos, &TableAutoIDInfo{ - Column: columnName, - NextID: nextID, - Type: idType, - }) - } - // Defer in for-loop would be costly, anyway, we don't need those rows after this turn of iteration. - //nolint:sqlclosecheck - if err := rows.Close(); err != nil { - return nil, errors.Trace(err) - } - if err := rows.Err(); err != nil { - return nil, errors.Trace(err) - } - return autoIDInfos, nil -} diff --git a/pkg/lightning/common/binding__failpoint_binding__.go b/pkg/lightning/common/binding__failpoint_binding__.go deleted file mode 100644 index a9dff357d9a9e..0000000000000 --- a/pkg/lightning/common/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package common - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/lightning/common/storage_unix.go b/pkg/lightning/common/storage_unix.go index cbcbc8f0f3a1a..a19dfb7c0ee98 100644 --- a/pkg/lightning/common/storage_unix.go +++ b/pkg/lightning/common/storage_unix.go @@ -29,10 +29,10 @@ import ( // GetStorageSize gets storage's capacity and available size func GetStorageSize(dir string) (size StorageSize, err error) { - if val, _err_ := failpoint.Eval(_curpkg_("GetStorageSize")); _err_ == nil { + failpoint.Inject("GetStorageSize", func(val failpoint.Value) { injectedSize := val.(int) - return StorageSize{Capacity: uint64(injectedSize), Available: uint64(injectedSize)}, nil - } + failpoint.Return(StorageSize{Capacity: uint64(injectedSize), Available: uint64(injectedSize)}, nil) + }) var stat unix.Statfs_t err = unix.Statfs(dir, &stat) diff --git a/pkg/lightning/common/storage_unix.go__failpoint_stash__ b/pkg/lightning/common/storage_unix.go__failpoint_stash__ deleted file mode 100644 index a19dfb7c0ee98..0000000000000 --- a/pkg/lightning/common/storage_unix.go__failpoint_stash__ +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright 2019 PingCAP, Inc. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build !windows - -// TODO: Deduplicate this implementation with DM! - -package common - -import ( - "reflect" - "syscall" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "golang.org/x/sys/unix" -) - -// GetStorageSize gets storage's capacity and available size -func GetStorageSize(dir string) (size StorageSize, err error) { - failpoint.Inject("GetStorageSize", func(val failpoint.Value) { - injectedSize := val.(int) - failpoint.Return(StorageSize{Capacity: uint64(injectedSize), Available: uint64(injectedSize)}, nil) - }) - - var stat unix.Statfs_t - err = unix.Statfs(dir, &stat) - if err != nil { - return size, errors.Annotatef(err, "cannot get disk capacity at %s", dir) - } - - // When container is run in MacOS, `bsize` obtained by `statfs` syscall is not the fundamental block size, - // but the `iosize` (optimal transfer block size) instead, it's usually 1024 times larger than the `bsize`. - // for example `4096 * 1024`. To get the correct block size, we should use `frsize`. But `frsize` isn't - // guaranteed to be supported everywhere, so we need to check whether it's supported before use it. - // For more details, please refer to: https://github.com/docker/for-mac/issues/2136 - bSize := uint64(stat.Bsize) - field := reflect.ValueOf(&stat).Elem().FieldByName("Frsize") - if field.IsValid() { - if field.Kind() == reflect.Uint64 { - bSize = field.Uint() - } else { - bSize = uint64(field.Int()) - } - } - - // Available blocks * size per block = available space in bytes - size.Available = uint64(stat.Bavail) * bSize - size.Capacity = stat.Blocks * bSize - - return -} - -// SameDisk is used to check dir1 and dir2 in the same disk. 
-func SameDisk(dir1 string, dir2 string) (bool, error) { - st1 := syscall.Stat_t{} - st2 := syscall.Stat_t{} - - if err := syscall.Stat(dir1, &st1); err != nil { - return false, err - } - - if err := syscall.Stat(dir2, &st2); err != nil { - return false, err - } - - return st1.Dev == st2.Dev, nil -} diff --git a/pkg/lightning/common/storage_windows.go b/pkg/lightning/common/storage_windows.go index 352636893c78f..89b9483592f94 100644 --- a/pkg/lightning/common/storage_windows.go +++ b/pkg/lightning/common/storage_windows.go @@ -33,10 +33,10 @@ var ( // GetStorageSize gets storage's capacity and available size func GetStorageSize(dir string) (size StorageSize, err error) { - if val, _err_ := failpoint.Eval(_curpkg_("GetStorageSize")); _err_ == nil { + failpoint.Inject("GetStorageSize", func(val failpoint.Value) { injectedSize := val.(int) - return StorageSize{Capacity: uint64(injectedSize), Available: uint64(injectedSize)}, nil - } + failpoint.Return(StorageSize{Capacity: uint64(injectedSize), Available: uint64(injectedSize)}, nil) + }) r, _, e := getDiskFreeSpaceExW.Call( uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(dir))), uintptr(unsafe.Pointer(&size.Available)), diff --git a/pkg/lightning/common/storage_windows.go__failpoint_stash__ b/pkg/lightning/common/storage_windows.go__failpoint_stash__ deleted file mode 100644 index 89b9483592f94..0000000000000 --- a/pkg/lightning/common/storage_windows.go__failpoint_stash__ +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2019 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build windows - -// TODO: Deduplicate this implementation with DM! - -package common - -import ( - "syscall" - "unsafe" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" -) - -var ( - kernel32 = syscall.MustLoadDLL("kernel32.dll") - getDiskFreeSpaceExW = kernel32.MustFindProc("GetDiskFreeSpaceExW") -) - -// GetStorageSize gets storage's capacity and available size -func GetStorageSize(dir string) (size StorageSize, err error) { - failpoint.Inject("GetStorageSize", func(val failpoint.Value) { - injectedSize := val.(int) - failpoint.Return(StorageSize{Capacity: uint64(injectedSize), Available: uint64(injectedSize)}, nil) - }) - r, _, e := getDiskFreeSpaceExW.Call( - uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(dir))), - uintptr(unsafe.Pointer(&size.Available)), - uintptr(unsafe.Pointer(&size.Capacity)), - 0, - ) - if r == 0 { - err = errors.Annotatef(e, "cannot get disk capacity at %s", dir) - } - return -} - -// SameDisk is used to check dir1 and dir2 in the same disk. 
-func SameDisk(dir1 string, dir2 string) (bool, error) { - // FIXME - return false, nil -} diff --git a/pkg/lightning/common/util.go b/pkg/lightning/common/util.go index 1cca4d45d5f28..ac3ba7d1efaa5 100644 --- a/pkg/lightning/common/util.go +++ b/pkg/lightning/common/util.go @@ -93,13 +93,13 @@ func (param *MySQLConnectParam) ToDriverConfig() *mysql.Config { } func tryConnectMySQL(cfg *mysql.Config) (*sql.DB, error) { - if val, _err_ := failpoint.Eval(_curpkg_("MustMySQLPassword")); _err_ == nil { + failpoint.Inject("MustMySQLPassword", func(val failpoint.Value) { pwd := val.(string) if cfg.Passwd != pwd { - return nil, &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"} + failpoint.Return(nil, &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"}) } - return nil, nil - } + failpoint.Return(nil, nil) + }) c, err := mysql.NewConnector(cfg) if err != nil { return nil, errors.Trace(err) diff --git a/pkg/lightning/common/util.go__failpoint_stash__ b/pkg/lightning/common/util.go__failpoint_stash__ deleted file mode 100644 index ac3ba7d1efaa5..0000000000000 --- a/pkg/lightning/common/util.go__failpoint_stash__ +++ /dev/null @@ -1,704 +0,0 @@ -// Copyright 2019 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package common - -import ( - "bytes" - "context" - "crypto/tls" - "database/sql" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net" - "net/http" - "os" - "strconv" - "strings" - "syscall" - "time" - - "github.com/go-sql-driver/mysql" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/errno" - "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/parser/model" - tmysql "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/table/tables" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/dbutil" - "github.com/pingcap/tidb/pkg/util/format" - "go.uber.org/zap" -) - -const ( - retryTimeout = 3 * time.Second - - defaultMaxRetry = 3 -) - -// MySQLConnectParam records the parameters needed to connect to a MySQL database. -type MySQLConnectParam struct { - Host string - Port int - User string - Password string - SQLMode string - MaxAllowedPacket uint64 - TLSConfig *tls.Config - AllowFallbackToPlaintext bool - Net string - Vars map[string]string -} - -// ToDriverConfig converts the MySQLConnectParam to a mysql.Config. 
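// A short sketch of the failpoint marker pattern that this patch restores:
// the checked-in source keeps the failpoint.Inject/failpoint.Return markers,
// and `failpoint-ctl enable` rewrites them into the failpoint.Eval/_curpkg_
// form seen on the removed lines. The failpoint name "demo-error" and the
// function mightFail below are hypothetical.
package demo

import (
	"errors"

	"github.com/pingcap/failpoint"
)

func mightFail() error {
	// A no-op unless the failpoint is activated, e.g. in a test via
	// failpoint.Enable("<package path>/demo-error", `return(true)`).
	failpoint.Inject("demo-error", func(val failpoint.Value) {
		if ok, _ := val.(bool); ok {
			failpoint.Return(errors.New("injected error"))
		}
	})
	return nil
}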
-func (param *MySQLConnectParam) ToDriverConfig() *mysql.Config { - cfg := mysql.NewConfig() - cfg.Params = make(map[string]string) - - cfg.User = param.User - cfg.Passwd = param.Password - cfg.Net = "tcp" - if param.Net != "" { - cfg.Net = param.Net - } - cfg.Addr = net.JoinHostPort(param.Host, strconv.Itoa(param.Port)) - cfg.Params["charset"] = "utf8mb4" - cfg.Params["sql_mode"] = fmt.Sprintf("'%s'", param.SQLMode) - cfg.MaxAllowedPacket = int(param.MaxAllowedPacket) - - cfg.TLS = param.TLSConfig - cfg.AllowFallbackToPlaintext = param.AllowFallbackToPlaintext - - for k, v := range param.Vars { - cfg.Params[k] = fmt.Sprintf("'%s'", v) - } - return cfg -} - -func tryConnectMySQL(cfg *mysql.Config) (*sql.DB, error) { - failpoint.Inject("MustMySQLPassword", func(val failpoint.Value) { - pwd := val.(string) - if cfg.Passwd != pwd { - failpoint.Return(nil, &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"}) - } - failpoint.Return(nil, nil) - }) - c, err := mysql.NewConnector(cfg) - if err != nil { - return nil, errors.Trace(err) - } - db := sql.OpenDB(c) - if err = db.Ping(); err != nil { - _ = db.Close() - return nil, errors.Trace(err) - } - return db, nil -} - -// ConnectMySQL connects MySQL with the dsn. If access is denied and the password is a valid base64 encoding, -// we will try to connect MySQL with the base64 decoding of the password. -func ConnectMySQL(cfg *mysql.Config) (*sql.DB, error) { - // Try plain password first. - db, firstErr := tryConnectMySQL(cfg) - if firstErr == nil { - return db, nil - } - // If access is denied and password is encoded by base64, try the decoded string as well. - if mysqlErr, ok := errors.Cause(firstErr).(*mysql.MySQLError); ok && mysqlErr.Number == tmysql.ErrAccessDenied { - // If password is encoded by base64, try the decoded string as well. - password, decodeErr := base64.StdEncoding.DecodeString(cfg.Passwd) - if decodeErr == nil && string(password) != cfg.Passwd { - cfg.Passwd = string(password) - db2, err := tryConnectMySQL(cfg) - if err == nil { - return db2, nil - } - } - } - // If we can't connect successfully, return the first error. - return nil, errors.Trace(firstErr) -} - -// Connect creates a new connection to the database. -func (param *MySQLConnectParam) Connect() (*sql.DB, error) { - db, err := ConnectMySQL(param.ToDriverConfig()) - if err != nil { - return nil, errors.Trace(err) - } - return db, nil -} - -// IsDirExists checks if dir exists. -func IsDirExists(name string) bool { - f, err := os.Stat(name) - if err != nil { - return false - } - return f != nil && f.IsDir() -} - -// IsEmptyDir checks if dir is empty. -func IsEmptyDir(name string) bool { - entries, err := os.ReadDir(name) - if err != nil { - return false - } - return len(entries) == 0 -} - -// SQLWithRetry constructs a retryable transaction. 
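// A simplified sketch of the retry-with-decoded-password idea implemented by
// ConnectMySQL above, with the driver-specific parts stripped out. The connect
// callback and its error classification are assumptions for illustration only.
package demo

import (
	"encoding/base64"
)

// connectWithPasswordFallback tries the password as given first; if that is
// rejected as access denied and the password is valid base64, it retries once
// with the decoded value. The first error is reported if both attempts fail.
func connectWithPasswordFallback(
	passwd string,
	connect func(passwd string) error,
	isAccessDenied func(error) bool,
) error {
	firstErr := connect(passwd)
	if firstErr == nil {
		return nil
	}
	if isAccessDenied(firstErr) {
		if decoded, err := base64.StdEncoding.DecodeString(passwd); err == nil && string(decoded) != passwd {
			if connect(string(decoded)) == nil {
				return nil
			}
		}
	}
	// Report the first failure, matching the behavior of the original helper.
	return firstErr
}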
-type SQLWithRetry struct { - // either *sql.DB or *sql.Conn - DB dbutil.DBExecutor - Logger log.Logger - HideQueryLog bool -} - -func (SQLWithRetry) perform(_ context.Context, parentLogger log.Logger, purpose string, action func() error) error { - return Retry(purpose, parentLogger, action) -} - -// Retry is shared by SQLWithRetry.perform, implementation of GlueCheckpointsDB and TiDB's glue implementation -func Retry(purpose string, parentLogger log.Logger, action func() error) error { - var err error -outside: - for i := 0; i < defaultMaxRetry; i++ { - logger := parentLogger.With(zap.Int("retryCnt", i)) - - if i > 0 { - logger.Warn(purpose + " retry start") - time.Sleep(retryTimeout) - } - - err = action() - switch { - case err == nil: - return nil - // do not retry NotFound error - case errors.IsNotFound(err): - break outside - case IsRetryableError(err): - logger.Warn(purpose+" failed but going to try again", log.ShortError(err)) - continue - default: - logger.Warn(purpose+" failed with no retry", log.ShortError(err)) - break outside - } - } - - return errors.Annotatef(err, "%s failed", purpose) -} - -// QueryRow executes a query that is expected to return at most one row. -func (t SQLWithRetry) QueryRow(ctx context.Context, purpose string, query string, dest ...any) error { - logger := t.Logger - if !t.HideQueryLog { - logger = logger.With(zap.String("query", query)) - } - return t.perform(ctx, logger, purpose, func() error { - return t.DB.QueryRowContext(ctx, query).Scan(dest...) - }) -} - -// QueryStringRows executes a query that is expected to return multiple rows -// whose every column is string. -func (t SQLWithRetry) QueryStringRows(ctx context.Context, purpose string, query string) ([][]string, error) { - var res [][]string - logger := t.Logger - if !t.HideQueryLog { - logger = logger.With(zap.String("query", query)) - } - - err := t.perform(ctx, logger, purpose, func() error { - rows, err := t.DB.QueryContext(ctx, query) - if err != nil { - return err - } - defer rows.Close() - - colNames, err := rows.Columns() - if err != nil { - return err - } - for rows.Next() { - row := make([]string, len(colNames)) - refs := make([]any, 0, len(row)) - for i := range row { - refs = append(refs, &row[i]) - } - if err := rows.Scan(refs...); err != nil { - return err - } - res = append(res, row) - } - return rows.Err() - }) - - return res, err -} - -// Transact executes an action in a transaction, and retry if the -// action failed with a retryable error. -func (t SQLWithRetry) Transact(ctx context.Context, purpose string, action func(context.Context, *sql.Tx) error) error { - return t.perform(ctx, t.Logger, purpose, func() error { - txn, err := t.DB.BeginTx(ctx, nil) - if err != nil { - return errors.Annotate(err, "begin transaction failed") - } - - err = action(ctx, txn) - if err != nil { - rerr := txn.Rollback() - if rerr != nil { - t.Logger.Error(purpose+" rollback transaction failed", log.ShortError(rerr)) - } - // we should return the exec err, instead of the rollback rerr. - // no need to errors.Trace() it, as the error comes from user code anyway. - return err - } - - err = txn.Commit() - if err != nil { - return errors.Annotate(err, "commit transaction failed") - } - - return nil - }) -} - -// Exec executes a single SQL with optional retry. 
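// A condensed sketch of the retry policy implemented by SQLWithRetry/Retry
// above: a fixed number of attempts (defaultMaxRetry = 3) with a fixed pause
// (retryTimeout = 3s) between them, giving up immediately on errors that are
// not retryable. The isRetryable predicate is an assumption standing in for
// common.IsRetryableError.
package demo

import (
	"fmt"
	"time"
)

const (
	maxRetry     = 3
	retryTimeout = 3 * time.Second
)

func retry(purpose string, action func() error, isRetryable func(error) bool) error {
	var err error
	for i := 0; i < maxRetry; i++ {
		if i > 0 {
			time.Sleep(retryTimeout) // pause before every retry, not before the first attempt
		}
		if err = action(); err == nil {
			return nil
		}
		if !isRetryable(err) {
			break // permanent failure: stop retrying
		}
	}
	return fmt.Errorf("%s failed: %w", purpose, err)
}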
-func (t SQLWithRetry) Exec(ctx context.Context, purpose string, query string, args ...any) error { - logger := t.Logger - if !t.HideQueryLog { - logger = logger.With(zap.String("query", query), zap.Reflect("args", args)) - } - return t.perform(ctx, logger, purpose, func() error { - _, err := t.DB.ExecContext(ctx, query, args...) - return errors.Trace(err) - }) -} - -// IsContextCanceledError returns whether the error is caused by context -// cancellation. This function should only be used when the code logic is -// affected by whether the error is canceling or not. -// -// This function returns `false` (not a context-canceled error) if `err == nil`. -func IsContextCanceledError(err error) bool { - return log.IsContextCanceledError(err) -} - -// UniqueTable returns an unique table name. -func UniqueTable(schema string, table string) string { - var builder strings.Builder - WriteMySQLIdentifier(&builder, schema) - builder.WriteByte('.') - WriteMySQLIdentifier(&builder, table) - return builder.String() -} - -func escapeIdentifiers(identifier []string) []any { - escaped := make([]any, len(identifier)) - for i, id := range identifier { - escaped[i] = EscapeIdentifier(id) - } - return escaped -} - -// SprintfWithIdentifiers escapes the identifiers and sprintf them. The input -// identifiers must not be escaped. -func SprintfWithIdentifiers(format string, identifiers ...string) string { - return fmt.Sprintf(format, escapeIdentifiers(identifiers)...) -} - -// FprintfWithIdentifiers escapes the identifiers and fprintf them. The input -// identifiers must not be escaped. -func FprintfWithIdentifiers(w io.Writer, format string, identifiers ...string) (int, error) { - return fmt.Fprintf(w, format, escapeIdentifiers(identifiers)...) -} - -// EscapeIdentifier quote and escape an sql identifier -func EscapeIdentifier(identifier string) string { - var builder strings.Builder - WriteMySQLIdentifier(&builder, identifier) - return builder.String() -} - -// WriteMySQLIdentifier writes a MySQL identifier into the string builder. -// Writes a MySQL identifier into the string builder. -// The identifier is always escaped into the form "`foo`". -func WriteMySQLIdentifier(builder *strings.Builder, identifier string) { - builder.Grow(len(identifier) + 2) - builder.WriteByte('`') - - // use a C-style loop instead of range loop to avoid UTF-8 decoding - for i := 0; i < len(identifier); i++ { - b := identifier[i] - if b == '`' { - builder.WriteString("``") - } else { - builder.WriteByte(b) - } - } - - builder.WriteByte('`') -} - -// InterpolateMySQLString interpolates a string into a MySQL string literal. -func InterpolateMySQLString(s string) string { - var builder strings.Builder - builder.Grow(len(s) + 2) - builder.WriteByte('\'') - for i := 0; i < len(s); i++ { - b := s[i] - if b == '\'' { - builder.WriteString("''") - } else { - builder.WriteByte(b) - } - } - builder.WriteByte('\'') - return builder.String() -} - -// TableExists return whether table with specified name exists in target db -func TableExists(ctx context.Context, db dbutil.QueryExecutor, schema, table string) (bool, error) { - query := "SELECT 1 from INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?" - var exist string - err := db.QueryRowContext(ctx, query, schema, table).Scan(&exist) - switch err { - case nil: - return true, nil - case sql.ErrNoRows: - return false, nil - default: - return false, errors.Annotatef(err, "check table exists failed") - } -} - -// SchemaExists return whether schema with specified name exists. 
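// A compact sketch of the backquote-doubling rule used by EscapeIdentifier and
// WriteMySQLIdentifier above: wrap the identifier in backquotes and double any
// embedded backquote.
package main

import (
	"fmt"
	"strings"
)

func escapeIdentifier(id string) string {
	return "`" + strings.ReplaceAll(id, "`", "``") + "`"
}

func main() {
	fmt.Println(escapeIdentifier("normal_table")) // `normal_table`
	fmt.Println(escapeIdentifier("weird`table"))  // `weird``table`
}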
-func SchemaExists(ctx context.Context, db dbutil.QueryExecutor, schema string) (bool, error) { - query := "SELECT 1 from INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = ?" - var exist string - err := db.QueryRowContext(ctx, query, schema).Scan(&exist) - switch err { - case nil: - return true, nil - case sql.ErrNoRows: - return false, nil - default: - return false, errors.Annotatef(err, "check schema exists failed") - } -} - -// GetJSON fetches a page and parses it as JSON. The parsed result will be -// stored into the `v`. The variable `v` must be a pointer to a type that can be -// unmarshalled from JSON. -// -// Example: -// -// client := &http.Client{} -// var resp struct { IP string } -// if err := util.GetJSON(client, "http://api.ipify.org/?format=json", &resp); err != nil { -// return errors.Trace(err) -// } -// fmt.Println(resp.IP) -func GetJSON(ctx context.Context, client *http.Client, url string, v any) error { - req, err := http.NewRequestWithContext(ctx, "GET", url, nil) - if err != nil { - return errors.Trace(err) - } - - resp, err := client.Do(req) - if err != nil { - return errors.Trace(err) - } - - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, err := io.ReadAll(resp.Body) - if err != nil { - return errors.Trace(err) - } - return errors.Errorf("get %s http status code != 200, message %s", url, string(body)) - } - - return errors.Trace(json.NewDecoder(resp.Body).Decode(v)) -} - -// KillMySelf sends sigint to current process, used in integration test only -// -// Only works on Unix. Signaling on Windows is not supported. -func KillMySelf() error { - proc, err := os.FindProcess(os.Getpid()) - if err == nil { - err = proc.Signal(syscall.SIGINT) - } - return errors.Trace(err) -} - -// KvPair contains a key-value pair and other fields that can be used to ingest -// KV pairs into TiKV. -type KvPair struct { - // Key is the key of the KV pair - Key []byte - // Val is the value of the KV pair - Val []byte - // RowID identifies a KvPair in case two KvPairs are equal in Key and Val. It has - // two sources: - // - // When the KvPair is generated from ADD INDEX, the RowID is the encoded handle. - // - // Otherwise, the RowID is related to the row number in the source files, and - // encode the number with `codec.EncodeComparableVarint`. - RowID []byte -} - -// EncodeIntRowID encodes an int64 row id. -func EncodeIntRowID(rowID int64) []byte { - return codec.EncodeComparableVarint(nil, rowID) -} - -// TableHasAutoRowID return whether table has auto generated row id -func TableHasAutoRowID(info *model.TableInfo) bool { - return !info.PKIsHandle && !info.IsCommonHandle -} - -// TableHasAutoID return whether table has auto generated id. -func TableHasAutoID(info *model.TableInfo) bool { - return TableHasAutoRowID(info) || info.GetAutoIncrementColInfo() != nil || info.ContainsAutoRandomBits() -} - -// GetAutoRandomColumn return the column with auto_random, return nil if the table doesn't have it. 
-// todo: better put in ddl package, but this will cause import cycle since ddl package import lightning -func GetAutoRandomColumn(tblInfo *model.TableInfo) *model.ColumnInfo { - if !tblInfo.ContainsAutoRandomBits() { - return nil - } - if tblInfo.PKIsHandle { - return tblInfo.GetPkColInfo() - } else if tblInfo.IsCommonHandle { - pk := tables.FindPrimaryIndex(tblInfo) - if pk == nil { - return nil - } - offset := pk.Columns[0].Offset - return tblInfo.Columns[offset] - } - return nil -} - -// GetDropIndexInfos returns the index infos that need to be dropped and the remain indexes. -func GetDropIndexInfos( - tblInfo *model.TableInfo, -) (remainIndexes []*model.IndexInfo, dropIndexes []*model.IndexInfo) { - cols := tblInfo.Columns -loop: - for _, idxInfo := range tblInfo.Indices { - if idxInfo.State != model.StatePublic { - remainIndexes = append(remainIndexes, idxInfo) - continue - } - // Primary key is a cluster index. - if idxInfo.Primary && tblInfo.HasClusteredIndex() { - remainIndexes = append(remainIndexes, idxInfo) - continue - } - // Skip index that contains auto-increment column. - // Because auto column must be defined as a key. - for _, idxCol := range idxInfo.Columns { - flag := cols[idxCol.Offset].GetFlag() - if tmysql.HasAutoIncrementFlag(flag) { - remainIndexes = append(remainIndexes, idxInfo) - continue loop - } - } - dropIndexes = append(dropIndexes, idxInfo) - } - return remainIndexes, dropIndexes -} - -// BuildDropIndexSQL builds the SQL statement to drop index. -func BuildDropIndexSQL(dbName, tableName string, idxInfo *model.IndexInfo) string { - if idxInfo.Primary { - return SprintfWithIdentifiers("ALTER TABLE %s.%s DROP PRIMARY KEY", dbName, tableName) - } - return SprintfWithIdentifiers("ALTER TABLE %s.%s DROP INDEX %s", dbName, tableName, idxInfo.Name.O) -} - -// BuildAddIndexSQL builds the SQL statement to create missing indexes. -// It returns both a single SQL statement that creates all indexes at once, -// and a list of SQL statements that creates each index individually. -func BuildAddIndexSQL( - tableName string, - curTblInfo, - desiredTblInfo *model.TableInfo, -) (singleSQL string, multiSQLs []string) { - addIndexSpecs := make([]string, 0, len(desiredTblInfo.Indices)) -loop: - for _, desiredIdxInfo := range desiredTblInfo.Indices { - for _, curIdxInfo := range curTblInfo.Indices { - if curIdxInfo.Name.L == desiredIdxInfo.Name.L { - continue loop - } - } - - var buf bytes.Buffer - if desiredIdxInfo.Primary { - buf.WriteString("ADD PRIMARY KEY ") - } else if desiredIdxInfo.Unique { - buf.WriteString("ADD UNIQUE KEY ") - } else { - buf.WriteString("ADD KEY ") - } - // "primary" is a special name for primary key, we should not use it as index name. 
- if desiredIdxInfo.Name.L != "primary" { - buf.WriteString(EscapeIdentifier(desiredIdxInfo.Name.O)) - } - - colStrs := make([]string, 0, len(desiredIdxInfo.Columns)) - for _, col := range desiredIdxInfo.Columns { - var colStr string - if desiredTblInfo.Columns[col.Offset].Hidden { - colStr = fmt.Sprintf("(%s)", desiredTblInfo.Columns[col.Offset].GeneratedExprString) - } else { - colStr = EscapeIdentifier(col.Name.O) - if col.Length != types.UnspecifiedLength { - colStr = fmt.Sprintf("%s(%s)", colStr, strconv.Itoa(col.Length)) - } - } - colStrs = append(colStrs, colStr) - } - fmt.Fprintf(&buf, "(%s)", strings.Join(colStrs, ",")) - - if desiredIdxInfo.Invisible { - fmt.Fprint(&buf, " INVISIBLE") - } - if desiredIdxInfo.Comment != "" { - fmt.Fprintf(&buf, ` COMMENT '%s'`, format.OutputFormat(desiredIdxInfo.Comment)) - } - addIndexSpecs = append(addIndexSpecs, buf.String()) - } - if len(addIndexSpecs) == 0 { - return "", nil - } - - singleSQL = fmt.Sprintf("ALTER TABLE %s %s", tableName, strings.Join(addIndexSpecs, ", ")) - for _, spec := range addIndexSpecs { - multiSQLs = append(multiSQLs, fmt.Sprintf("ALTER TABLE %s %s", tableName, spec)) - } - return singleSQL, multiSQLs -} - -// IsDupKeyError checks if err is a duplicate index error. -func IsDupKeyError(err error) bool { - if merr, ok := errors.Cause(err).(*mysql.MySQLError); ok { - switch merr.Number { - case errno.ErrDupKeyName, errno.ErrMultiplePriKey, errno.ErrDupUnique: - return true - } - } - return false -} - -// GetBackoffWeightFromDB gets the backoff weight from database. -func GetBackoffWeightFromDB(ctx context.Context, db *sql.DB) (int, error) { - val, err := getSessionVariable(ctx, db, variable.TiDBBackOffWeight) - if err != nil { - return 0, err - } - return strconv.Atoi(val) -} - -// GetExplicitRequestSourceTypeFromDB gets the explicit request source type from database. -func GetExplicitRequestSourceTypeFromDB(ctx context.Context, db *sql.DB) (string, error) { - return getSessionVariable(ctx, db, variable.TiDBExplicitRequestSourceType) -} - -// copy from dbutil to avoid import cycle -func getSessionVariable(ctx context.Context, db *sql.DB, variable string) (value string, err error) { - query := fmt.Sprintf("SHOW VARIABLES LIKE '%s'", variable) - rows, err := db.QueryContext(ctx, query) - - if err != nil { - return "", errors.Trace(err) - } - defer rows.Close() - - // Show an example. - /* - mysql> SHOW VARIABLES LIKE "binlog_format"; - +---------------+-------+ - | Variable_name | Value | - +---------------+-------+ - | binlog_format | ROW | - +---------------+-------+ - */ - - for rows.Next() { - if err = rows.Scan(&variable, &value); err != nil { - return "", errors.Trace(err) - } - } - - if err := rows.Err(); err != nil { - return "", errors.Trace(err) - } - - return value, nil -} - -// IsFunctionNotExistErr checks if err is a function not exist error. 
-func IsFunctionNotExistErr(err error, functionName string) bool { - return err != nil && - (strings.Contains(err.Error(), "No database selected") || - strings.Contains(err.Error(), fmt.Sprintf("%s does not exist", functionName))) -} - -// IsRaftKV2 checks whether the raft-kv2 is enabled -func IsRaftKV2(ctx context.Context, db *sql.DB) (bool, error) { - var ( - getRaftKvVersionSQL = "show config where type = 'tikv' and name = 'storage.engine'" - raftKv2 = "raft-kv2" - tp, instance, name, value string - ) - - rows, err := db.QueryContext(ctx, getRaftKvVersionSQL) - if err != nil { - return false, errors.Trace(err) - } - defer rows.Close() - - for rows.Next() { - if err = rows.Scan(&tp, &instance, &name, &value); err != nil { - return false, errors.Trace(err) - } - if value == raftKv2 { - return true, nil - } - } - return false, rows.Err() -} - -// IsAccessDeniedNeedConfigPrivilegeError checks if err is generated from a query to TiDB which failed due to missing CONFIG privilege. -func IsAccessDeniedNeedConfigPrivilegeError(err error) bool { - e, ok := err.(*mysql.MySQLError) - return ok && e.Number == errno.ErrSpecificAccessDenied && strings.Contains(e.Message, "CONFIG") -} diff --git a/pkg/lightning/mydump/binding__failpoint_binding__.go b/pkg/lightning/mydump/binding__failpoint_binding__.go deleted file mode 100644 index ce8a21b37f034..0000000000000 --- a/pkg/lightning/mydump/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package mydump - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/lightning/mydump/loader.go b/pkg/lightning/mydump/loader.go index b0d22084400f6..d2fc407ddcdeb 100644 --- a/pkg/lightning/mydump/loader.go +++ b/pkg/lightning/mydump/loader.go @@ -789,14 +789,14 @@ func calculateFileBytes(ctx context.Context, // SampleFileCompressRatio samples the compress ratio of the compressed file. func SampleFileCompressRatio(ctx context.Context, fileMeta SourceFileMeta, store storage.ExternalStorage) (float64, error) { - if val, _err_ := failpoint.Eval(_curpkg_("SampleFileCompressPercentage")); _err_ == nil { + failpoint.Inject("SampleFileCompressPercentage", func(val failpoint.Value) { switch v := val.(type) { case string: - return 1.0, errors.New(v) + failpoint.Return(1.0, errors.New(v)) case int: - return float64(v) / 100, nil + failpoint.Return(float64(v)/100, nil) } - } + }) if fileMeta.Compression == CompressionNone { return 1, nil } diff --git a/pkg/lightning/mydump/loader.go__failpoint_stash__ b/pkg/lightning/mydump/loader.go__failpoint_stash__ deleted file mode 100644 index d2fc407ddcdeb..0000000000000 --- a/pkg/lightning/mydump/loader.go__failpoint_stash__ +++ /dev/null @@ -1,868 +0,0 @@ -// Copyright 2019 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
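// A minimal sketch of the SHOW VARIABLES scan performed by getSessionVariable
// above, using database/sql directly; the variable name passed in is a
// placeholder and the caller is assumed to hold an open *sql.DB.
package demo

import (
	"context"
	"database/sql"
	"fmt"
)

func showVariable(ctx context.Context, db *sql.DB, name string) (string, error) {
	rows, err := db.QueryContext(ctx, fmt.Sprintf("SHOW VARIABLES LIKE '%s'", name))
	if err != nil {
		return "", err
	}
	defer rows.Close()

	var variable, value string
	for rows.Next() {
		// Result columns are (Variable_name, Value).
		if err := rows.Scan(&variable, &value); err != nil {
			return "", err
		}
	}
	return value, rows.Err()
}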
-// See the License for the specific language governing permissions and -// limitations under the License. - -package mydump - -import ( - "context" - "io" - "path/filepath" - "sort" - "strings" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/br/pkg/storage" - "github.com/pingcap/tidb/pkg/lightning/common" - "github.com/pingcap/tidb/pkg/lightning/config" - "github.com/pingcap/tidb/pkg/lightning/log" - regexprrouter "github.com/pingcap/tidb/pkg/util/regexpr-router" - filter "github.com/pingcap/tidb/pkg/util/table-filter" - "go.uber.org/zap" -) - -// sampleCompressedFileSize represents how many bytes need to be sampled for compressed files -const ( - sampleCompressedFileSize = 4 * 1024 - maxSampleParquetDataSize = 8 * 1024 - maxSampleParquetRowCount = 500 -) - -// MDDatabaseMeta contains some parsed metadata for a database in the source by MyDumper Loader. -type MDDatabaseMeta struct { - Name string - SchemaFile FileInfo - Tables []*MDTableMeta - Views []*MDTableMeta - charSet string -} - -// NewMDDatabaseMeta creates an Mydumper database meta with specified character set. -func NewMDDatabaseMeta(charSet string) *MDDatabaseMeta { - return &MDDatabaseMeta{ - charSet: charSet, - } -} - -// GetSchema gets the schema SQL for a source database. -func (m *MDDatabaseMeta) GetSchema(ctx context.Context, store storage.ExternalStorage) string { - if m.SchemaFile.FileMeta.Path != "" { - schema, err := ExportStatement(ctx, store, m.SchemaFile, m.charSet) - if err != nil { - log.FromContext(ctx).Warn("failed to extract table schema", - zap.String("Path", m.SchemaFile.FileMeta.Path), - log.ShortError(err), - ) - } else if schemaStr := strings.TrimSpace(string(schema)); schemaStr != "" { - return schemaStr - } - } - // set default if schema sql is empty or failed to extract. - return common.SprintfWithIdentifiers("CREATE DATABASE IF NOT EXISTS %s", m.Name) -} - -// MDTableMeta contains some parsed metadata for a table in the source by MyDumper Loader. -type MDTableMeta struct { - DB string - Name string - SchemaFile FileInfo - DataFiles []FileInfo - charSet string - TotalSize int64 - IndexRatio float64 - // default to true, and if we do precheck, this var is updated using data sampling result, so it's not accurate. - IsRowOrdered bool -} - -// SourceFileMeta contains some analyzed metadata for a source file by MyDumper Loader. -type SourceFileMeta struct { - Path string - Type SourceType - Compression Compression - SortKey string - // FileSize is the size of the file in the storage. - FileSize int64 - // WARNING: variables below are not persistent - ExtendData ExtendColumnData - // RealSize is same as FileSize if the file is not compressed and not parquet. - // If the file is compressed, RealSize is the estimated uncompressed size. - // If the file is parquet, RealSize is the estimated data size after convert - // to row oriented storage. - RealSize int64 - Rows int64 // only for parquet -} - -// NewMDTableMeta creates an Mydumper table meta with specified character set. -func NewMDTableMeta(charSet string) *MDTableMeta { - return &MDTableMeta{ - charSet: charSet, - } -} - -// GetSchema gets the table-creating SQL for a source table. 
-func (m *MDTableMeta) GetSchema(ctx context.Context, store storage.ExternalStorage) (string, error) { - schemaFilePath := m.SchemaFile.FileMeta.Path - if len(schemaFilePath) <= 0 { - return "", errors.Errorf("schema file is missing for the table '%s.%s'", m.DB, m.Name) - } - fileExists, err := store.FileExists(ctx, schemaFilePath) - if err != nil { - return "", errors.Annotate(err, "check table schema file exists error") - } - if !fileExists { - return "", errors.Errorf("the provided schema file (%s) for the table '%s.%s' doesn't exist", - schemaFilePath, m.DB, m.Name) - } - schema, err := ExportStatement(ctx, store, m.SchemaFile, m.charSet) - if err != nil { - log.FromContext(ctx).Error("failed to extract table schema", - zap.String("Path", m.SchemaFile.FileMeta.Path), - log.ShortError(err), - ) - return "", errors.Trace(err) - } - return string(schema), nil -} - -// MDLoaderSetupConfig stores the configs when setting up a MDLoader. -// This can control the behavior when constructing an MDLoader. -type MDLoaderSetupConfig struct { - // MaxScanFiles specifies the maximum number of files to scan. - // If the value is <= 0, it means the number of data source files will be scanned as many as possible. - MaxScanFiles int - // ReturnPartialResultOnError specifies whether the currently scanned files are analyzed, - // and return the partial result. - ReturnPartialResultOnError bool - // FileIter controls the file iteration policy when constructing a MDLoader. - FileIter FileIterator -} - -// DefaultMDLoaderSetupConfig generates a default MDLoaderSetupConfig. -func DefaultMDLoaderSetupConfig() *MDLoaderSetupConfig { - return &MDLoaderSetupConfig{ - MaxScanFiles: 0, // By default, the loader will scan all the files. - ReturnPartialResultOnError: false, - FileIter: nil, - } -} - -// MDLoaderSetupOption is the option type for setting up a MDLoaderSetupConfig. -type MDLoaderSetupOption func(cfg *MDLoaderSetupConfig) - -// WithMaxScanFiles generates an option that limits the max scan files when setting up a MDLoader. -func WithMaxScanFiles(maxScanFiles int) MDLoaderSetupOption { - return func(cfg *MDLoaderSetupConfig) { - if maxScanFiles > 0 { - cfg.MaxScanFiles = maxScanFiles - cfg.ReturnPartialResultOnError = true - } - } -} - -// ReturnPartialResultOnError generates an option that controls -// whether return the partial scanned result on error when setting up a MDLoader. -func ReturnPartialResultOnError(supportPartialResult bool) MDLoaderSetupOption { - return func(cfg *MDLoaderSetupConfig) { - cfg.ReturnPartialResultOnError = supportPartialResult - } -} - -// WithFileIterator generates an option that specifies the file iteration policy. -func WithFileIterator(fileIter FileIterator) MDLoaderSetupOption { - return func(cfg *MDLoaderSetupConfig) { - cfg.FileIter = fileIter - } -} - -// LoaderConfig is the configuration for constructing a MDLoader. -type LoaderConfig struct { - // SourceID is the unique identifier for the data source, it's used in DM only. - // must be used together with Routes. - SourceID string - // SourceURL is the URL of the data source. - SourceURL string - // Routes is the routing rules for the tables, exclusive with FileRouters. - // it's deprecated in lightning, but still used in DM. - // when used this, DefaultFileRules must be true. - Routes config.Routes - // CharacterSet is the character set of the schema sql files. - CharacterSet string - // Filter is the filter for the tables, files related to filtered-out tables are not loaded. 
- // must be specified, else all tables are filtered out, see config.GetDefaultFilter. - Filter []string - FileRouters []*config.FileRouteRule - // CaseSensitive indicates whether Routes and Filter are case-sensitive. - CaseSensitive bool - // DefaultFileRules indicates whether to use the default file routing rules. - // If it's true, the default file routing rules will be appended to the FileRouters. - // a little confusing, but it's true only when FileRouters is empty. - DefaultFileRules bool -} - -// NewLoaderCfg creates loader config from lightning config. -func NewLoaderCfg(cfg *config.Config) LoaderConfig { - return LoaderConfig{ - SourceID: cfg.Mydumper.SourceID, - SourceURL: cfg.Mydumper.SourceDir, - Routes: cfg.Routes, - CharacterSet: cfg.Mydumper.CharacterSet, - Filter: cfg.Mydumper.Filter, - FileRouters: cfg.Mydumper.FileRouters, - CaseSensitive: cfg.Mydumper.CaseSensitive, - DefaultFileRules: cfg.Mydumper.DefaultFileRules, - } -} - -// MDLoader is for 'Mydumper File Loader', which loads the files in the data source and generates a set of metadata. -type MDLoader struct { - store storage.ExternalStorage - dbs []*MDDatabaseMeta - filter filter.Filter - router *regexprrouter.RouteTable - fileRouter FileRouter - charSet string -} - -type mdLoaderSetup struct { - sourceID string - loader *MDLoader - dbSchemas []FileInfo - tableSchemas []FileInfo - viewSchemas []FileInfo - tableDatas []FileInfo - dbIndexMap map[string]int - tableIndexMap map[filter.Table]int - setupCfg *MDLoaderSetupConfig -} - -// NewLoader constructs a MyDumper loader that scanns the data source and constructs a set of metadatas. -func NewLoader(ctx context.Context, cfg LoaderConfig, opts ...MDLoaderSetupOption) (*MDLoader, error) { - u, err := storage.ParseBackend(cfg.SourceURL, nil) - if err != nil { - return nil, common.NormalizeError(err) - } - s, err := storage.New(ctx, u, &storage.ExternalStorageOptions{}) - if err != nil { - return nil, common.NormalizeError(err) - } - - return NewLoaderWithStore(ctx, cfg, s, opts...) -} - -// NewLoaderWithStore constructs a MyDumper loader with the provided external storage that scanns the data source and constructs a set of metadatas. -func NewLoaderWithStore(ctx context.Context, cfg LoaderConfig, - store storage.ExternalStorage, opts ...MDLoaderSetupOption) (*MDLoader, error) { - var r *regexprrouter.RouteTable - var err error - - mdLoaderSetupCfg := DefaultMDLoaderSetupConfig() - for _, o := range opts { - o(mdLoaderSetupCfg) - } - if mdLoaderSetupCfg.FileIter == nil { - mdLoaderSetupCfg.FileIter = &allFileIterator{ - store: store, - maxScanFiles: mdLoaderSetupCfg.MaxScanFiles, - } - } - - if len(cfg.Routes) > 0 && len(cfg.FileRouters) > 0 { - return nil, common.ErrInvalidConfig.GenWithStack("table route is deprecated, can't config both [routes] and [mydumper.files]") - } - - if len(cfg.Routes) > 0 { - r, err = regexprrouter.NewRegExprRouter(cfg.CaseSensitive, cfg.Routes) - if err != nil { - return nil, common.ErrInvalidConfig.Wrap(err).GenWithStack("invalid table route rule") - } - } - - f, err := filter.Parse(cfg.Filter) - if err != nil { - return nil, common.ErrInvalidConfig.Wrap(err).GenWithStack("parse filter failed") - } - if !cfg.CaseSensitive { - f = filter.CaseInsensitive(f) - } - - fileRouteRules := cfg.FileRouters - if cfg.DefaultFileRules { - fileRouteRules = append(fileRouteRules, defaultFileRouteRules...) 
- } - - fileRouter, err := NewFileRouter(fileRouteRules, log.FromContext(ctx)) - if err != nil { - return nil, common.ErrInvalidConfig.Wrap(err).GenWithStack("parse file routing rule failed") - } - - mdl := &MDLoader{ - store: store, - filter: f, - router: r, - charSet: cfg.CharacterSet, - fileRouter: fileRouter, - } - - setup := mdLoaderSetup{ - sourceID: cfg.SourceID, - loader: mdl, - dbIndexMap: make(map[string]int), - tableIndexMap: make(map[filter.Table]int), - setupCfg: mdLoaderSetupCfg, - } - - if err := setup.setup(ctx); err != nil { - if mdLoaderSetupCfg.ReturnPartialResultOnError { - return mdl, errors.Trace(err) - } - return nil, errors.Trace(err) - } - - return mdl, nil -} - -type fileType int - -const ( - fileTypeDatabaseSchema fileType = iota - fileTypeTableSchema - fileTypeTableData -) - -// String implements the Stringer interface. -func (ftype fileType) String() string { - switch ftype { - case fileTypeDatabaseSchema: - return "database schema" - case fileTypeTableSchema: - return "table schema" - case fileTypeTableData: - return "table data" - default: - return "(unknown)" - } -} - -// FileInfo contains the information for a data file in a table. -type FileInfo struct { - TableName filter.Table - FileMeta SourceFileMeta -} - -// ExtendColumnData contains the extended column names and values information for a table. -type ExtendColumnData struct { - Columns []string - Values []string -} - -// setup the `s.loader.dbs` slice by scanning all *.sql files inside `dir`. -// -// The database and tables are inserted in a consistent order, so creating an -// MDLoader twice with the same data source is going to produce the same array, -// even after killing Lightning. -// -// This is achieved by using `filepath.Walk` internally which guarantees the -// files are visited in lexicographical order (note that this does not mean the -// databases and tables in the end are ordered lexicographically since they may -// be stored in different subdirectories). -// -// Will sort tables by table size, this means that the big table is imported -// at the latest, which to avoid large table take a long time to import and block -// small table to release index worker. 
-func (s *mdLoaderSetup) setup(ctx context.Context) error { - /* - Mydumper file names format - db —— {db}-schema-create.sql - table —— {db}.{table}-schema.sql - sql —— {db}.{table}.{part}.sql / {db}.{table}.sql - */ - var gerr error - fileIter := s.setupCfg.FileIter - if fileIter == nil { - return errors.New("file iterator is not defined") - } - if err := fileIter.IterateFiles(ctx, s.constructFileInfo); err != nil { - if !s.setupCfg.ReturnPartialResultOnError { - return common.ErrStorageUnknown.Wrap(err).GenWithStack("list file failed") - } - gerr = err - } - if err := s.route(); err != nil { - return common.ErrTableRoute.Wrap(err).GenWithStackByArgs() - } - - // setup database schema - if len(s.dbSchemas) != 0 { - for _, fileInfo := range s.dbSchemas { - if _, dbExists := s.insertDB(fileInfo); dbExists && s.loader.router == nil { - return common.ErrInvalidSchemaFile.GenWithStack("invalid database schema file, duplicated item - %s", fileInfo.FileMeta.Path) - } - } - } - - if len(s.tableSchemas) != 0 { - // setup table schema - for _, fileInfo := range s.tableSchemas { - if _, _, tableExists := s.insertTable(fileInfo); tableExists && s.loader.router == nil { - return common.ErrInvalidSchemaFile.GenWithStack("invalid table schema file, duplicated item - %s", fileInfo.FileMeta.Path) - } - } - } - - if len(s.viewSchemas) != 0 { - // setup view schema - for _, fileInfo := range s.viewSchemas { - _, tableExists := s.insertView(fileInfo) - if !tableExists { - // we are not expect the user only has view schema without table schema when user use dumpling to get view. - // remove the last `-view.sql` from path as the relate table schema file path - return common.ErrInvalidSchemaFile.GenWithStack("invalid view schema file, miss host table schema for view '%s'", fileInfo.TableName.Name) - } - } - } - - // Sql file for restore data - for _, fileInfo := range s.tableDatas { - // set a dummy `FileInfo` here without file meta because we needn't restore the table schema - tableMeta, _, _ := s.insertTable(FileInfo{TableName: fileInfo.TableName}) - tableMeta.DataFiles = append(tableMeta.DataFiles, fileInfo) - tableMeta.TotalSize += fileInfo.FileMeta.RealSize - } - - for _, dbMeta := range s.loader.dbs { - // Put the small table in the front of the slice which can avoid large table - // take a long time to import and block small table to release index worker. - meta := dbMeta - sort.SliceStable(meta.Tables, func(i, j int) bool { - return meta.Tables[i].TotalSize < meta.Tables[j].TotalSize - }) - - // sort each table source files by sort-key - for _, tbMeta := range meta.Tables { - dataFiles := tbMeta.DataFiles - sort.SliceStable(dataFiles, func(i, j int) bool { - return dataFiles[i].FileMeta.SortKey < dataFiles[j].FileMeta.SortKey - }) - } - } - - return gerr -} - -// FileHandler is the interface to handle the file give the path and size. -// It is mainly used in the `FileIterator` as parameters. -type FileHandler func(ctx context.Context, path string, size int64) error - -// FileIterator is the interface to iterate files in a data source. -// Use this interface to customize the file iteration policy. 
-type FileIterator interface { - IterateFiles(ctx context.Context, hdl FileHandler) error -} - -type allFileIterator struct { - store storage.ExternalStorage - maxScanFiles int -} - -func (iter *allFileIterator) IterateFiles(ctx context.Context, hdl FileHandler) error { - // `filepath.Walk` yields the paths in a deterministic (lexicographical) order, - // meaning the file and chunk orders will be the same everytime it is called - // (as long as the source is immutable). - totalScannedFileCount := 0 - err := iter.store.WalkDir(ctx, &storage.WalkOption{}, func(path string, size int64) error { - totalScannedFileCount++ - if iter.maxScanFiles > 0 && totalScannedFileCount > iter.maxScanFiles { - return common.ErrTooManySourceFiles - } - return hdl(ctx, path, size) - }) - - return errors.Trace(err) -} - -func (s *mdLoaderSetup) constructFileInfo(ctx context.Context, path string, size int64) error { - logger := log.FromContext(ctx).With(zap.String("path", path)) - res, err := s.loader.fileRouter.Route(filepath.ToSlash(path)) - if err != nil { - return errors.Annotatef(err, "apply file routing on file '%s' failed", path) - } - if res == nil { - logger.Info("file is filtered by file router", zap.String("category", "loader")) - return nil - } - - info := FileInfo{ - TableName: filter.Table{Schema: res.Schema, Name: res.Name}, - FileMeta: SourceFileMeta{Path: path, Type: res.Type, Compression: res.Compression, SortKey: res.Key, FileSize: size, RealSize: size}, - } - - if s.loader.shouldSkip(&info.TableName) { - logger.Debug("ignoring table file", zap.String("category", "filter")) - - return nil - } - - switch res.Type { - case SourceTypeSchemaSchema: - s.dbSchemas = append(s.dbSchemas, info) - case SourceTypeTableSchema: - s.tableSchemas = append(s.tableSchemas, info) - case SourceTypeViewSchema: - s.viewSchemas = append(s.viewSchemas, info) - case SourceTypeSQL, SourceTypeCSV: - if info.FileMeta.Compression != CompressionNone { - compressRatio, err2 := SampleFileCompressRatio(ctx, info.FileMeta, s.loader.GetStore()) - if err2 != nil { - logger.Error("fail to calculate data file compress ratio", zap.String("category", "loader"), - zap.String("schema", res.Schema), zap.String("table", res.Name), zap.Stringer("type", res.Type)) - } else { - info.FileMeta.RealSize = int64(compressRatio * float64(info.FileMeta.FileSize)) - } - } - s.tableDatas = append(s.tableDatas, info) - case SourceTypeParquet: - parquestDataSize, err2 := SampleParquetDataSize(ctx, info.FileMeta, s.loader.GetStore()) - if err2 != nil { - logger.Error("fail to sample parquet data size", zap.String("category", "loader"), - zap.String("schema", res.Schema), zap.String("table", res.Name), zap.Stringer("type", res.Type), zap.Error(err2)) - } else { - info.FileMeta.RealSize = parquestDataSize - } - s.tableDatas = append(s.tableDatas, info) - } - - logger.Debug("file route result", zap.String("schema", res.Schema), - zap.String("table", res.Name), zap.Stringer("type", res.Type)) - - return nil -} - -func (l *MDLoader) shouldSkip(table *filter.Table) bool { - if len(table.Name) == 0 { - return !l.filter.MatchSchema(table.Schema) - } - return !l.filter.MatchTable(table.Schema, table.Name) -} - -func (s *mdLoaderSetup) route() error { - r := s.loader.router - if r == nil { - return nil - } - - type dbInfo struct { - fileMeta SourceFileMeta - count int // means file count(db/table/view schema and table data) - } - - knownDBNames := make(map[string]*dbInfo) - for _, info := range s.dbSchemas { - knownDBNames[info.TableName.Schema] = &dbInfo{ - 
fileMeta: info.FileMeta, - count: 1, - } - } - for _, info := range s.tableSchemas { - if _, ok := knownDBNames[info.TableName.Schema]; !ok { - knownDBNames[info.TableName.Schema] = &dbInfo{ - fileMeta: info.FileMeta, - count: 1, - } - } - knownDBNames[info.TableName.Schema].count++ - } - for _, info := range s.viewSchemas { - if _, ok := knownDBNames[info.TableName.Schema]; !ok { - knownDBNames[info.TableName.Schema] = &dbInfo{ - fileMeta: info.FileMeta, - count: 1, - } - } - knownDBNames[info.TableName.Schema].count++ - } - for _, info := range s.tableDatas { - if _, ok := knownDBNames[info.TableName.Schema]; !ok { - knownDBNames[info.TableName.Schema] = &dbInfo{ - fileMeta: info.FileMeta, - count: 1, - } - } - knownDBNames[info.TableName.Schema].count++ - } - - runRoute := func(arr []FileInfo) error { - for i, info := range arr { - rawDB, rawTable := info.TableName.Schema, info.TableName.Name - targetDB, targetTable, err := r.Route(rawDB, rawTable) - if err != nil { - return errors.Trace(err) - } - if targetDB != rawDB { - oldInfo := knownDBNames[rawDB] - oldInfo.count-- - newInfo, ok := knownDBNames[targetDB] - if !ok { - newInfo = &dbInfo{fileMeta: oldInfo.fileMeta, count: 1} - s.dbSchemas = append(s.dbSchemas, FileInfo{ - TableName: filter.Table{Schema: targetDB}, - FileMeta: oldInfo.fileMeta, - }) - } - newInfo.count++ - knownDBNames[targetDB] = newInfo - } - arr[i].TableName = filter.Table{Schema: targetDB, Name: targetTable} - extendCols, extendVals := r.FetchExtendColumn(rawDB, rawTable, s.sourceID) - if len(extendCols) > 0 { - arr[i].FileMeta.ExtendData = ExtendColumnData{ - Columns: extendCols, - Values: extendVals, - } - } - } - return nil - } - - // route for schema table and view - if err := runRoute(s.dbSchemas); err != nil { - return errors.Trace(err) - } - if err := runRoute(s.tableSchemas); err != nil { - return errors.Trace(err) - } - if err := runRoute(s.viewSchemas); err != nil { - return errors.Trace(err) - } - if err := runRoute(s.tableDatas); err != nil { - return errors.Trace(err) - } - // remove all schemas which has been entirely routed away(file count > 0) - // https://github.com/golang/go/wiki/SliceTricks#filtering-without-allocating - remainingSchemas := s.dbSchemas[:0] - for _, info := range s.dbSchemas { - if dbInfo := knownDBNames[info.TableName.Schema]; dbInfo.count > 0 { - remainingSchemas = append(remainingSchemas, info) - } else if dbInfo.count < 0 { - // this should not happen if there are no bugs in the code - return common.ErrTableRoute.GenWithStack("something wrong happened when route %s", info.TableName.String()) - } - } - s.dbSchemas = remainingSchemas - return nil -} - -func (s *mdLoaderSetup) insertDB(f FileInfo) (*MDDatabaseMeta, bool) { - dbIndex, ok := s.dbIndexMap[f.TableName.Schema] - if ok { - return s.loader.dbs[dbIndex], true - } - s.dbIndexMap[f.TableName.Schema] = len(s.loader.dbs) - ptr := &MDDatabaseMeta{ - Name: f.TableName.Schema, - SchemaFile: f, - charSet: s.loader.charSet, - } - s.loader.dbs = append(s.loader.dbs, ptr) - return ptr, false -} - -func (s *mdLoaderSetup) insertTable(fileInfo FileInfo) (tblMeta *MDTableMeta, dbExists bool, tableExists bool) { - dbFileInfo := FileInfo{ - TableName: filter.Table{ - Schema: fileInfo.TableName.Schema, - }, - FileMeta: SourceFileMeta{Type: SourceTypeSchemaSchema}, - } - dbMeta, dbExists := s.insertDB(dbFileInfo) - tableIndex, ok := s.tableIndexMap[fileInfo.TableName] - if ok { - return dbMeta.Tables[tableIndex], dbExists, true - } - s.tableIndexMap[fileInfo.TableName] = 
len(dbMeta.Tables) - ptr := &MDTableMeta{ - DB: fileInfo.TableName.Schema, - Name: fileInfo.TableName.Name, - SchemaFile: fileInfo, - DataFiles: make([]FileInfo, 0, 16), - charSet: s.loader.charSet, - IndexRatio: 0.0, - IsRowOrdered: true, - } - dbMeta.Tables = append(dbMeta.Tables, ptr) - return ptr, dbExists, false -} - -func (s *mdLoaderSetup) insertView(fileInfo FileInfo) (dbExists bool, tableExists bool) { - dbFileInfo := FileInfo{ - TableName: filter.Table{ - Schema: fileInfo.TableName.Schema, - }, - FileMeta: SourceFileMeta{Type: SourceTypeSchemaSchema}, - } - dbMeta, dbExists := s.insertDB(dbFileInfo) - _, ok := s.tableIndexMap[fileInfo.TableName] - if ok { - meta := &MDTableMeta{ - DB: fileInfo.TableName.Schema, - Name: fileInfo.TableName.Name, - SchemaFile: fileInfo, - charSet: s.loader.charSet, - IndexRatio: 0.0, - IsRowOrdered: true, - } - dbMeta.Views = append(dbMeta.Views, meta) - } - return dbExists, ok -} - -// GetDatabases gets the list of scanned MDDatabaseMeta for the loader. -func (l *MDLoader) GetDatabases() []*MDDatabaseMeta { - return l.dbs -} - -// GetStore gets the external storage used by the loader. -func (l *MDLoader) GetStore() storage.ExternalStorage { - return l.store -} - -func calculateFileBytes(ctx context.Context, - dataFile string, - compressType storage.CompressType, - store storage.ExternalStorage, - offset int64) (tot int, pos int64, err error) { - bytes := make([]byte, sampleCompressedFileSize) - reader, err := store.Open(ctx, dataFile, nil) - if err != nil { - return 0, 0, errors.Trace(err) - } - defer reader.Close() - - decompressConfig := storage.DecompressConfig{ZStdDecodeConcurrency: 1} - compressReader, err := storage.NewLimitedInterceptReader(reader, compressType, decompressConfig, offset) - if err != nil { - return 0, 0, errors.Trace(err) - } - - readBytes := func() error { - n, err2 := compressReader.Read(bytes) - if err2 != nil && errors.Cause(err2) != io.EOF && errors.Cause(err) != io.ErrUnexpectedEOF { - return err2 - } - tot += n - return err2 - } - - if offset == 0 { - err = readBytes() - if err != nil && errors.Cause(err) != io.EOF && errors.Cause(err) != io.ErrUnexpectedEOF { - return 0, 0, err - } - pos, err = compressReader.Seek(0, io.SeekCurrent) - if err != nil { - return 0, 0, errors.Trace(err) - } - return tot, pos, nil - } - - for { - err = readBytes() - if err != nil { - break - } - } - if err != nil && errors.Cause(err) != io.EOF && errors.Cause(err) != io.ErrUnexpectedEOF { - return 0, 0, errors.Trace(err) - } - return tot, offset, nil -} - -// SampleFileCompressRatio samples the compress ratio of the compressed file. -func SampleFileCompressRatio(ctx context.Context, fileMeta SourceFileMeta, store storage.ExternalStorage) (float64, error) { - failpoint.Inject("SampleFileCompressPercentage", func(val failpoint.Value) { - switch v := val.(type) { - case string: - failpoint.Return(1.0, errors.New(v)) - case int: - failpoint.Return(float64(v)/100, nil) - } - }) - if fileMeta.Compression == CompressionNone { - return 1, nil - } - compressType, err := ToStorageCompressType(fileMeta.Compression) - if err != nil { - return 0, err - } - // We use the following method to sample the compress ratio of the first few bytes of the file. - // 1. read first time aiming to find a valid compressed file offset. If we continue read now, the compress reader will - // request more data from file reader buffer them in its memory. We can't compute an accurate compress ratio. - // 2. 
we use a second reading and limit the file reader only read n bytes(n is the valid position we find in the first reading). - // Then we read all the data out from the compress reader. The data length m we read out is the uncompressed data length. - // Use m/n to compute the compress ratio. - // read first time, aims to find a valid end pos in compressed file - _, pos, err := calculateFileBytes(ctx, fileMeta.Path, compressType, store, 0) - if err != nil { - return 0, err - } - // read second time, original reader ends at first time's valid pos, compute sample data compress ratio - tot, pos, err := calculateFileBytes(ctx, fileMeta.Path, compressType, store, pos) - if err != nil { - return 0, err - } - return float64(tot) / float64(pos), nil -} - -// SampleParquetDataSize samples the data size of the parquet file. -func SampleParquetDataSize(ctx context.Context, fileMeta SourceFileMeta, store storage.ExternalStorage) (int64, error) { - totalRowCount, err := ReadParquetFileRowCountByFile(ctx, store, fileMeta) - if totalRowCount == 0 || err != nil { - return 0, err - } - - reader, err := store.Open(ctx, fileMeta.Path, nil) - if err != nil { - return 0, err - } - parser, err := NewParquetParser(ctx, store, reader, fileMeta.Path) - if err != nil { - //nolint: errcheck - reader.Close() - return 0, err - } - //nolint: errcheck - defer parser.Close() - - var ( - rowSize int64 - rowCount int64 - ) - for { - err = parser.ReadRow() - if err != nil { - if errors.Cause(err) == io.EOF { - break - } - return 0, err - } - lastRow := parser.LastRow() - rowCount++ - rowSize += int64(lastRow.Length) - parser.RecycleRow(lastRow) - if rowSize > maxSampleParquetDataSize || rowCount > maxSampleParquetRowCount { - break - } - } - size := int64(float64(totalRowCount) / float64(rowCount) * float64(rowSize)) - return size, nil -} diff --git a/pkg/meta/autoid/autoid.go b/pkg/meta/autoid/autoid.go index d13b3f7fc2b03..237981497fcf9 100644 --- a/pkg/meta/autoid/autoid.go +++ b/pkg/meta/autoid/autoid.go @@ -539,16 +539,16 @@ func (alloc *allocator) GetType() AllocatorType { // NextStep return new auto id step according to previous step and consuming time. func NextStep(curStep int64, consumeDur time.Duration) int64 { - if val, _err_ := failpoint.Eval(_curpkg_("mockAutoIDCustomize")); _err_ == nil { + failpoint.Inject("mockAutoIDCustomize", func(val failpoint.Value) { if val.(bool) { - return 3 + failpoint.Return(3) } - } - if val, _err_ := failpoint.Eval(_curpkg_("mockAutoIDChange")); _err_ == nil { + }) + failpoint.Inject("mockAutoIDChange", func(val failpoint.Value) { if val.(bool) { - return step + failpoint.Return(step) } - } + }) consumeRate := defaultConsumeTime.Seconds() / consumeDur.Seconds() res := int64(float64(curStep) * consumeRate) @@ -582,11 +582,11 @@ func newSinglePointAlloc(r Requirement, dbID, tblID int64, isUnsigned bool) *sin } // mockAutoIDChange failpoint is not implemented in this allocator, so fallback to use the default one. - if val, _err_ := failpoint.Eval(_curpkg_("mockAutoIDChange")); _err_ == nil { + failpoint.Inject("mockAutoIDChange", func(val failpoint.Value) { if val.(bool) { spa = nil } - } + }) return spa } diff --git a/pkg/meta/autoid/autoid.go__failpoint_stash__ b/pkg/meta/autoid/autoid.go__failpoint_stash__ deleted file mode 100644 index 237981497fcf9..0000000000000 --- a/pkg/meta/autoid/autoid.go__failpoint_stash__ +++ /dev/null @@ -1,1351 +0,0 @@ -// Copyright 2015 PingCAP, Inc. 
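// A rough, self-contained illustration of the ratio estimate that
// SampleFileCompressRatio computes above, using compress/gzip and an in-memory
// buffer instead of the storage layer. Because the gzip reader buffers
// read-ahead from the underlying source, the compressed-byte count here is
// only approximate, which is exactly why the original code performs a second,
// length-limited read.
package demo

import (
	"bytes"
	"compress/gzip"
	"io"
)

// countingReader counts how many compressed bytes the decompressor pulled
// from the underlying reader.
type countingReader struct {
	r io.Reader
	n int64
}

func (c *countingReader) Read(p []byte) (int, error) {
	n, err := c.r.Read(p)
	c.n += int64(n)
	return n, err
}

// sampleGzipRatio returns an estimated uncompressed/compressed size ratio
// from at most sampleSize decompressed bytes.
func sampleGzipRatio(compressed []byte, sampleSize int64) (float64, error) {
	cr := &countingReader{r: bytes.NewReader(compressed)}
	zr, err := gzip.NewReader(cr)
	if err != nil {
		return 0, err
	}
	defer zr.Close()

	produced, err := io.Copy(io.Discard, io.LimitReader(zr, sampleSize))
	if err != nil {
		return 0, err
	}
	if cr.n == 0 {
		return 1, nil
	}
	return float64(produced) / float64(cr.n), nil
}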
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package autoid - -import ( - "bytes" - "context" - "fmt" - "math" - "strconv" - "sync" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/autoid" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/dbterror" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/tracing" - "github.com/tikv/client-go/v2/txnkv/txnsnapshot" - tikvutil "github.com/tikv/client-go/v2/util" - "go.uber.org/zap" -) - -// Attention: -// For reading cluster TiDB memory tables, the system schema/table should be same. -// Once the system schema/table id been allocated, it can't be changed any more. -// Change the system schema/table id may have the compatibility problem. -const ( - // SystemSchemaIDFlag is the system schema/table id flag, uses the highest bit position as system schema ID flag, it's exports for test. - SystemSchemaIDFlag = 1 << 62 - // InformationSchemaDBID is the information_schema schema id, it's exports for test. - InformationSchemaDBID int64 = SystemSchemaIDFlag | 1 - // PerformanceSchemaDBID is the performance_schema schema id, it's exports for test. - PerformanceSchemaDBID int64 = SystemSchemaIDFlag | 10000 - // MetricSchemaDBID is the metrics_schema schema id, it's exported for test. - MetricSchemaDBID int64 = SystemSchemaIDFlag | 20000 -) - -const ( - minStep = 30000 - maxStep = 2000000 - defaultConsumeTime = 10 * time.Second - minIncrement = 1 - maxIncrement = 65535 -) - -// RowIDBitLength is the bit number of a row id in TiDB. -const RowIDBitLength = 64 - -const ( - // AutoRandomShardBitsDefault is the default number of shard bits. - AutoRandomShardBitsDefault = 5 - // AutoRandomRangeBitsDefault is the default number of range bits. - AutoRandomRangeBitsDefault = 64 - // AutoRandomShardBitsMax is the max number of shard bits. - AutoRandomShardBitsMax = 15 - // AutoRandomRangeBitsMax is the max number of range bits. - AutoRandomRangeBitsMax = 64 - // AutoRandomRangeBitsMin is the min number of range bits. - AutoRandomRangeBitsMin = 32 - // AutoRandomIncBitsMin is the min number of auto random incremental bits. - AutoRandomIncBitsMin = 27 -) - -// AutoRandomShardBitsNormalize normalizes the auto random shard bits. 
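// A small numeric sketch of the step-growth rule at the start of NextStep
// above: the next batch size is the previous step scaled by how fast the last
// batch was consumed relative to the 10s target, and (as suggested by the
// minStep/maxStep constants) the result is assumed to be clamped into
// [minStep, maxStep].
package main

import (
	"fmt"
	"time"
)

const (
	minStep            = 30000
	maxStep            = 2000000
	defaultConsumeTime = 10 * time.Second
)

func nextStep(curStep int64, consumeDur time.Duration) int64 {
	consumeRate := defaultConsumeTime.Seconds() / consumeDur.Seconds()
	res := int64(float64(curStep) * consumeRate)
	if res < minStep {
		return minStep
	}
	if res > maxStep {
		return maxStep
	}
	return res
}

func main() {
	// A batch of 30000 IDs used up in 2s means IDs are consumed 5x faster than
	// the target, so the next step grows to 150000.
	fmt.Println(nextStep(30000, 2*time.Second)) // 150000
	// A batch that lasted 60s shrinks the step, but never below minStep.
	fmt.Println(nextStep(30000, 60*time.Second)) // 30000
}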
-func AutoRandomShardBitsNormalize(shard int, colName string) (ret uint64, err error) { - if shard == types.UnspecifiedLength { - return AutoRandomShardBitsDefault, nil - } - if shard <= 0 { - return 0, dbterror.ErrInvalidAutoRandom.FastGenByArgs(AutoRandomNonPositive) - } - if shard > AutoRandomShardBitsMax { - errMsg := fmt.Sprintf(AutoRandomOverflowErrMsg, AutoRandomShardBitsMax, shard, colName) - return 0, dbterror.ErrInvalidAutoRandom.FastGenByArgs(errMsg) - } - return uint64(shard), nil -} - -// AutoRandomRangeBitsNormalize normalizes the auto random range bits. -func AutoRandomRangeBitsNormalize(rangeBits int) (ret uint64, err error) { - if rangeBits == types.UnspecifiedLength { - return AutoRandomRangeBitsDefault, nil - } - if rangeBits < AutoRandomRangeBitsMin || rangeBits > AutoRandomRangeBitsMax { - errMsg := fmt.Sprintf(AutoRandomInvalidRangeBits, AutoRandomRangeBitsMin, AutoRandomRangeBitsMax, rangeBits) - return 0, dbterror.ErrInvalidAutoRandom.FastGenByArgs(errMsg) - } - return uint64(rangeBits), nil -} - -// Test needs to change it, so it's a variable. -var step = int64(30000) - -// AllocatorType is the type of allocator for generating auto-id. Different type of allocators use different key-value pairs. -type AllocatorType uint8 - -const ( - // RowIDAllocType indicates the allocator is used to allocate row id. - RowIDAllocType AllocatorType = iota - // AutoIncrementType indicates the allocator is used to allocate auto increment value. - AutoIncrementType - // AutoRandomType indicates the allocator is used to allocate auto-shard id. - AutoRandomType - // SequenceType indicates the allocator is used to allocate sequence value. - SequenceType -) - -func (a AllocatorType) String() string { - switch a { - case RowIDAllocType: - return "_tidb_rowid" - case AutoIncrementType: - return "auto_increment" - case AutoRandomType: - return "auto_random" - case SequenceType: - return "sequence" - } - return "unknown" -} - -// CustomAutoIncCacheOption is one kind of AllocOption to customize the allocator step length. -type CustomAutoIncCacheOption int64 - -// ApplyOn implements the AllocOption interface. -func (step CustomAutoIncCacheOption) ApplyOn(alloc *allocator) { - if step == 0 { - return - } - alloc.step = int64(step) - alloc.customStep = true -} - -// AllocOptionTableInfoVersion is used to pass the TableInfo.Version to the allocator. -type AllocOptionTableInfoVersion uint16 - -// ApplyOn implements the AllocOption interface. -func (v AllocOptionTableInfoVersion) ApplyOn(alloc *allocator) { - alloc.tbVersion = uint16(v) -} - -// AllocOption is a interface to define allocator custom options coming in future. -type AllocOption interface { - ApplyOn(*allocator) -} - -// Allocator is an auto increment id generator. -// Just keep id unique actually. -type Allocator interface { - // Alloc allocs N consecutive autoID for table with tableID, returning (min, max] of the allocated autoID batch. - // It gets a batch of autoIDs at a time. So it does not need to access storage for each call. - // The consecutive feature is used to insert multiple rows in a statement. - // increment & offset is used to validate the start position (the allocator's base is not always the last allocated id). - // The returned range is (min, max]: - // case increment=1 & offset=1: you can derive the ids like min+1, min+2... max. - // case increment=x & offset=y: you firstly need to seek to firstID by `SeekToFirstAutoIDXXX`, then derive the IDs like firstID, firstID + increment * 2... in the caller. 
- Alloc(ctx context.Context, n uint64, increment, offset int64) (int64, int64, error) - - // AllocSeqCache allocs sequence batch value cached in table level(rather than in alloc), the returned range covering - // the size of sequence cache with it's increment. The returned round indicates the sequence cycle times if it is with - // cycle option. - AllocSeqCache() (min int64, max int64, round int64, err error) - - // Rebase rebases the autoID base for table with tableID and the new base value. - // If allocIDs is true, it will allocate some IDs and save to the cache. - // If allocIDs is false, it will not allocate IDs. - Rebase(ctx context.Context, newBase int64, allocIDs bool) error - - // ForceRebase set the next global auto ID to newBase. - ForceRebase(newBase int64) error - - // RebaseSeq rebases the sequence value in number axis with tableID and the new base value. - RebaseSeq(newBase int64) (int64, bool, error) - - // Base return the current base of Allocator. - Base() int64 - // End is only used for test. - End() int64 - // NextGlobalAutoID returns the next global autoID. - NextGlobalAutoID() (int64, error) - GetType() AllocatorType -} - -// Allocators represents a set of `Allocator`s. -type Allocators struct { - SepAutoInc bool - Allocs []Allocator -} - -// NewAllocators packs multiple `Allocator`s into Allocators. -func NewAllocators(sepAutoInc bool, allocators ...Allocator) Allocators { - return Allocators{ - SepAutoInc: sepAutoInc, - Allocs: allocators, - } -} - -// Append add an allocator to the allocators. -func (all Allocators) Append(a Allocator) Allocators { - return Allocators{ - SepAutoInc: all.SepAutoInc, - Allocs: append(all.Allocs, a), - } -} - -// Get returns the Allocator according to the AllocatorType. -func (all Allocators) Get(allocType AllocatorType) Allocator { - if !all.SepAutoInc { - if allocType == AutoIncrementType { - allocType = RowIDAllocType - } - } - - for _, a := range all.Allocs { - if a.GetType() == allocType { - return a - } - } - return nil -} - -// Filter filters all the allocators that match pred. -func (all Allocators) Filter(pred func(Allocator) bool) Allocators { - var ret []Allocator - for _, a := range all.Allocs { - if pred(a) { - ret = append(ret, a) - } - } - return Allocators{ - SepAutoInc: all.SepAutoInc, - Allocs: ret, - } -} - -type allocator struct { - mu sync.Mutex - base int64 - end int64 - store kv.Storage - // dbID is database ID where it was created. - dbID int64 - tbID int64 - tbVersion uint16 - isUnsigned bool - lastAllocTime time.Time - step int64 - customStep bool - allocType AllocatorType - sequence *model.SequenceInfo -} - -// GetStep is only used by tests -func GetStep() int64 { - return step -} - -// SetStep is only used by tests -func SetStep(s int64) { - step = s -} - -// Base implements autoid.Allocator Base interface. -func (alloc *allocator) Base() int64 { - alloc.mu.Lock() - defer alloc.mu.Unlock() - return alloc.base -} - -// End implements autoid.Allocator End interface. -func (alloc *allocator) End() int64 { - alloc.mu.Lock() - defer alloc.mu.Unlock() - return alloc.end -} - -// NextGlobalAutoID implements autoid.Allocator NextGlobalAutoID interface. 
-func (alloc *allocator) NextGlobalAutoID() (int64, error) { - var autoID int64 - startTime := time.Now() - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnMeta) - err := kv.RunInNewTxn(ctx, alloc.store, true, func(_ context.Context, txn kv.Transaction) error { - var err1 error - autoID, err1 = alloc.getIDAccessor(txn).Get() - if err1 != nil { - return errors.Trace(err1) - } - return nil - }) - metrics.AutoIDHistogram.WithLabelValues(metrics.GlobalAutoID, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) - if alloc.isUnsigned { - return int64(uint64(autoID) + 1), err - } - return autoID + 1, err -} - -func (alloc *allocator) rebase4Unsigned(ctx context.Context, requiredBase uint64, allocIDs bool) error { - // Satisfied by alloc.base, nothing to do. - if requiredBase <= uint64(alloc.base) { - return nil - } - // Satisfied by alloc.end, need to update alloc.base. - if requiredBase <= uint64(alloc.end) { - alloc.base = int64(requiredBase) - return nil - } - - ctx, allocatorStats, commitDetail := getAllocatorStatsFromCtx(ctx) - if allocatorStats != nil { - allocatorStats.rebaseCount++ - defer func() { - if commitDetail != nil { - allocatorStats.mergeCommitDetail(*commitDetail) - } - }() - } - var newBase, newEnd uint64 - startTime := time.Now() - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnMeta) - err := kv.RunInNewTxn(ctx, alloc.store, true, func(_ context.Context, txn kv.Transaction) error { - if allocatorStats != nil { - txn.SetOption(kv.CollectRuntimeStats, allocatorStats.SnapshotRuntimeStats) - } - idAcc := alloc.getIDAccessor(txn) - currentEnd, err1 := idAcc.Get() - if err1 != nil { - return err1 - } - uCurrentEnd := uint64(currentEnd) - if allocIDs { - newBase = max(uCurrentEnd, requiredBase) - newEnd = min(math.MaxUint64-uint64(alloc.step), newBase) + uint64(alloc.step) - } else { - if uCurrentEnd >= requiredBase { - newBase = uCurrentEnd - newEnd = uCurrentEnd - // Required base satisfied, we don't need to update KV. - return nil - } - // If we don't want to allocate IDs, for example when creating a table with a given base value, - // We need to make sure when other TiDB server allocates ID for the first time, requiredBase + 1 - // will be allocated, so we need to increase the end to exactly the requiredBase. - newBase = requiredBase - newEnd = requiredBase - } - _, err1 = idAcc.Inc(int64(newEnd - uCurrentEnd)) - return err1 - }) - metrics.AutoIDHistogram.WithLabelValues(metrics.TableAutoIDRebase, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) - if err != nil { - return err - } - alloc.base, alloc.end = int64(newBase), int64(newEnd) - return nil -} - -func (alloc *allocator) rebase4Signed(ctx context.Context, requiredBase int64, allocIDs bool) error { - // Satisfied by alloc.base, nothing to do. - if requiredBase <= alloc.base { - return nil - } - // Satisfied by alloc.end, need to update alloc.base. 
- if requiredBase <= alloc.end { - alloc.base = requiredBase - return nil - } - - ctx, allocatorStats, commitDetail := getAllocatorStatsFromCtx(ctx) - if allocatorStats != nil { - allocatorStats.rebaseCount++ - defer func() { - if commitDetail != nil { - allocatorStats.mergeCommitDetail(*commitDetail) - } - }() - } - var newBase, newEnd int64 - startTime := time.Now() - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnMeta) - err := kv.RunInNewTxn(ctx, alloc.store, true, func(_ context.Context, txn kv.Transaction) error { - if allocatorStats != nil { - txn.SetOption(kv.CollectRuntimeStats, allocatorStats.SnapshotRuntimeStats) - } - idAcc := alloc.getIDAccessor(txn) - currentEnd, err1 := idAcc.Get() - if err1 != nil { - return err1 - } - if allocIDs { - newBase = max(currentEnd, requiredBase) - newEnd = min(math.MaxInt64-alloc.step, newBase) + alloc.step - } else { - if currentEnd >= requiredBase { - newBase = currentEnd - newEnd = currentEnd - // Required base satisfied, we don't need to update KV. - return nil - } - // If we don't want to allocate IDs, for example when creating a table with a given base value, - // We need to make sure when other TiDB server allocates ID for the first time, requiredBase + 1 - // will be allocated, so we need to increase the end to exactly the requiredBase. - newBase = requiredBase - newEnd = requiredBase - } - _, err1 = idAcc.Inc(newEnd - currentEnd) - return err1 - }) - metrics.AutoIDHistogram.WithLabelValues(metrics.TableAutoIDRebase, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) - if err != nil { - return err - } - alloc.base, alloc.end = newBase, newEnd - return nil -} - -// rebase4Sequence won't alloc batch immediately, cause it won't cache value in allocator. -func (alloc *allocator) rebase4Sequence(requiredBase int64) (int64, bool, error) { - startTime := time.Now() - alreadySatisfied := false - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnMeta) - err := kv.RunInNewTxn(ctx, alloc.store, true, func(_ context.Context, txn kv.Transaction) error { - acc := meta.NewMeta(txn).GetAutoIDAccessors(alloc.dbID, alloc.tbID) - currentEnd, err := acc.SequenceValue().Get() - if err != nil { - return err - } - if alloc.sequence.Increment > 0 { - if currentEnd >= requiredBase { - // Required base satisfied, we don't need to update KV. - alreadySatisfied = true - return nil - } - } else { - if currentEnd <= requiredBase { - // Required base satisfied, we don't need to update KV. - alreadySatisfied = true - return nil - } - } - - // If we don't want to allocate IDs, for example when creating a table with a given base value, - // We need to make sure when other TiDB server allocates ID for the first time, requiredBase + 1 - // will be allocated, so we need to increase the end to exactly the requiredBase. - _, err = acc.SequenceValue().Inc(requiredBase - currentEnd) - return err - }) - // TODO: sequence metrics - metrics.AutoIDHistogram.WithLabelValues(metrics.TableAutoIDRebase, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) - if err != nil { - return 0, false, err - } - if alreadySatisfied { - return 0, true, nil - } - return requiredBase, false, err -} - -// Rebase implements autoid.Allocator Rebase interface. -// The requiredBase is the minimum base value after Rebase. -// The real base may be greater than the required base. 
-func (alloc *allocator) Rebase(ctx context.Context, requiredBase int64, allocIDs bool) error { - alloc.mu.Lock() - defer alloc.mu.Unlock() - if alloc.isUnsigned { - return alloc.rebase4Unsigned(ctx, uint64(requiredBase), allocIDs) - } - return alloc.rebase4Signed(ctx, requiredBase, allocIDs) -} - -// ForceRebase implements autoid.Allocator ForceRebase interface. -func (alloc *allocator) ForceRebase(requiredBase int64) error { - if requiredBase == -1 { - return ErrAutoincReadFailed.GenWithStack("Cannot force rebase the next global ID to '0'") - } - alloc.mu.Lock() - defer alloc.mu.Unlock() - startTime := time.Now() - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnMeta) - err := kv.RunInNewTxn(ctx, alloc.store, true, func(_ context.Context, txn kv.Transaction) error { - idAcc := alloc.getIDAccessor(txn) - currentEnd, err1 := idAcc.Get() - if err1 != nil { - return err1 - } - var step int64 - if !alloc.isUnsigned { - step = requiredBase - currentEnd - } else { - uRequiredBase, uCurrentEnd := uint64(requiredBase), uint64(currentEnd) - step = int64(uRequiredBase - uCurrentEnd) - } - _, err1 = idAcc.Inc(step) - return err1 - }) - metrics.AutoIDHistogram.WithLabelValues(metrics.TableAutoIDRebase, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) - if err != nil { - return err - } - alloc.base, alloc.end = requiredBase, requiredBase - return nil -} - -// Rebase implements autoid.Allocator RebaseSeq interface. -// The return value is quite same as expression function, bool means whether it should be NULL, -// here it will be used in setval expression function (true meaning the set value has been satisfied, return NULL). -// case1:When requiredBase is satisfied with current value, it will return (0, true, nil), -// case2:When requiredBase is successfully set in, it will return (requiredBase, false, nil). -// If some error occurs in the process, return it immediately. -func (alloc *allocator) RebaseSeq(requiredBase int64) (int64, bool, error) { - alloc.mu.Lock() - defer alloc.mu.Unlock() - return alloc.rebase4Sequence(requiredBase) -} - -func (alloc *allocator) GetType() AllocatorType { - return alloc.allocType -} - -// NextStep return new auto id step according to previous step and consuming time. -func NextStep(curStep int64, consumeDur time.Duration) int64 { - failpoint.Inject("mockAutoIDCustomize", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(3) - } - }) - failpoint.Inject("mockAutoIDChange", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(step) - } - }) - - consumeRate := defaultConsumeTime.Seconds() / consumeDur.Seconds() - res := int64(float64(curStep) * consumeRate) - if res < minStep { - return minStep - } else if res > maxStep { - return maxStep - } - return res -} - -// MockForTest is exported for testing. -// The actual implementation is in github.com/pingcap/tidb/pkg/autoid_service because of the -// package circle depending issue. 
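Illustrative aside (not part of the patch): NextStep, shown both in the hunk above and in this deleted stash copy, scales the next auto-ID batch size by how quickly the previous batch was consumed and clamps the result to [minStep, maxStep]. A standalone re-implementation with two worked inputs:

package main

import (
	"fmt"
	"time"
)

func nextStep(curStep int64, consumeDur time.Duration) int64 {
	const (
		minStep            = 30000
		maxStep            = 2000000
		defaultConsumeTime = 10 * time.Second
	)
	consumeRate := defaultConsumeTime.Seconds() / consumeDur.Seconds()
	res := int64(float64(curStep) * consumeRate)
	if res < minStep {
		return minStep
	} else if res > maxStep {
		return maxStep
	}
	return res
}

func main() {
	// A 30000-ID batch burned in 2.5s means a 4x consume rate, so the next batch grows to 120000.
	fmt.Println(nextStep(30000, 2500*time.Millisecond)) // 120000
	// A batch that lasted a whole minute shrinks the step, floored at minStep.
	fmt.Println(nextStep(30000, time.Minute)) // 30000
}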
-var MockForTest func(kv.Storage) autoid.AutoIDAllocClient - -func newSinglePointAlloc(r Requirement, dbID, tblID int64, isUnsigned bool) *singlePointAlloc { - keyspaceID := uint32(r.Store().GetCodec().GetKeyspaceID()) - spa := &singlePointAlloc{ - dbID: dbID, - tblID: tblID, - isUnsigned: isUnsigned, - keyspaceID: keyspaceID, - } - if r.AutoIDClient() == nil { - // Only for test in mockstore - spa.ClientDiscover = &ClientDiscover{} - spa.mu.AutoIDAllocClient = MockForTest(r.Store()) - } else { - spa.ClientDiscover = r.AutoIDClient() - } - - // mockAutoIDChange failpoint is not implemented in this allocator, so fallback to use the default one. - failpoint.Inject("mockAutoIDChange", func(val failpoint.Value) { - if val.(bool) { - spa = nil - } - }) - return spa -} - -// Requirement is the parameter required by NewAllocator -type Requirement interface { - Store() kv.Storage - AutoIDClient() *ClientDiscover -} - -// NewAllocator returns a new auto increment id generator on the store. -func NewAllocator(r Requirement, dbID, tbID int64, isUnsigned bool, - allocType AllocatorType, opts ...AllocOption) Allocator { - var store kv.Storage - if r != nil { - store = r.Store() - } - alloc := &allocator{ - store: store, - dbID: dbID, - tbID: tbID, - isUnsigned: isUnsigned, - step: step, - lastAllocTime: time.Now(), - allocType: allocType, - } - for _, fn := range opts { - fn.ApplyOn(alloc) - } - - // Use the MySQL compatible AUTO_INCREMENT mode. - if alloc.customStep && alloc.step == 1 && alloc.tbVersion >= model.TableInfoVersion5 { - if allocType == AutoIncrementType { - alloc1 := newSinglePointAlloc(r, dbID, tbID, isUnsigned) - if alloc1 != nil { - return alloc1 - } - } else if allocType == RowIDAllocType { - // Now that the autoid and rowid allocator are separated, the AUTO_ID_CACHE 1 setting should not make - // the rowid allocator do not use cache. - alloc.customStep = false - alloc.step = step - } - } - - return alloc -} - -// NewSequenceAllocator returns a new sequence value generator on the store. -func NewSequenceAllocator(store kv.Storage, dbID, tbID int64, info *model.SequenceInfo) Allocator { - return &allocator{ - store: store, - dbID: dbID, - tbID: tbID, - // Sequence allocator is always signed. - isUnsigned: false, - lastAllocTime: time.Now(), - allocType: SequenceType, - sequence: info, - } -} - -// TODO: Handle allocators when changing Table ID during ALTER TABLE t PARTITION BY ... - -// NewAllocatorsFromTblInfo creates an array of allocators of different types with the information of model.TableInfo. 
-func NewAllocatorsFromTblInfo(r Requirement, schemaID int64, tblInfo *model.TableInfo) Allocators { - var allocs []Allocator - dbID := tblInfo.GetAutoIDSchemaID(schemaID) - idCacheOpt := CustomAutoIncCacheOption(tblInfo.AutoIdCache) - tblVer := AllocOptionTableInfoVersion(tblInfo.Version) - - hasRowID := !tblInfo.PKIsHandle && !tblInfo.IsCommonHandle - hasAutoIncID := tblInfo.GetAutoIncrementColInfo() != nil - if hasRowID || hasAutoIncID { - alloc := NewAllocator(r, dbID, tblInfo.ID, tblInfo.IsAutoIncColUnsigned(), RowIDAllocType, idCacheOpt, tblVer) - allocs = append(allocs, alloc) - } - if hasAutoIncID { - alloc := NewAllocator(r, dbID, tblInfo.ID, tblInfo.IsAutoIncColUnsigned(), AutoIncrementType, idCacheOpt, tblVer) - allocs = append(allocs, alloc) - } - hasAutoRandID := tblInfo.ContainsAutoRandomBits() - if hasAutoRandID { - alloc := NewAllocator(r, dbID, tblInfo.ID, tblInfo.IsAutoRandomBitColUnsigned(), AutoRandomType, idCacheOpt, tblVer) - allocs = append(allocs, alloc) - } - if tblInfo.IsSequence() { - allocs = append(allocs, NewSequenceAllocator(r.Store(), dbID, tblInfo.ID, tblInfo.Sequence)) - } - return NewAllocators(tblInfo.SepAutoInc(), allocs...) -} - -// Alloc implements autoid.Allocator Alloc interface. -// For autoIncrement allocator, the increment and offset should always be positive in [1, 65535]. -// Attention: -// When increment and offset is not the default value(1), the return range (min, max] need to -// calculate the correct start position rather than simply the add 1 to min. Then you can derive -// the successive autoID by adding increment * cnt to firstID for (n-1) times. -// -// Example: -// (6, 13] is returned, increment = 4, offset = 1, n = 2. -// 6 is the last allocated value for other autoID or handle, maybe with different increment and step, -// but actually we don't care about it, all we need is to calculate the new autoID corresponding to the -// increment and offset at this time now. To simplify the rule is like (ID - offset) % increment = 0, -// so the first autoID should be 9, then add increment to it to get 13. -func (alloc *allocator) Alloc(ctx context.Context, n uint64, increment, offset int64) (min int64, max int64, err error) { - if alloc.tbID == 0 { - return 0, 0, errInvalidTableID.GenWithStackByArgs("Invalid tableID") - } - if n == 0 { - return 0, 0, nil - } - if alloc.allocType == AutoIncrementType || alloc.allocType == RowIDAllocType { - if !validIncrementAndOffset(increment, offset) { - return 0, 0, errInvalidIncrementAndOffset.GenWithStackByArgs(increment, offset) - } - } - alloc.mu.Lock() - defer alloc.mu.Unlock() - if alloc.isUnsigned { - return alloc.alloc4Unsigned(ctx, n, increment, offset) - } - return alloc.alloc4Signed(ctx, n, increment, offset) -} - -func (alloc *allocator) AllocSeqCache() (min int64, max int64, round int64, err error) { - alloc.mu.Lock() - defer alloc.mu.Unlock() - return alloc.alloc4Sequence() -} - -func validIncrementAndOffset(increment, offset int64) bool { - return (increment >= minIncrement && increment <= maxIncrement) && (offset >= minIncrement && offset <= maxIncrement) -} - -// CalcNeededBatchSize is used to calculate batch size for autoID allocation. -// It firstly seeks to the first valid position based on increment and offset, -// then plus the length remained, which could be (n-1) * increment. 
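Illustrative aside (not part of the patch): the "(6, 13]" example in the Alloc doc comment above can be reproduced with the deleted seek and batch-size formulas. This standalone copy keeps only the signed branch:

package main

import "fmt"

// seekToFirstAutoIDSigned mirrors the deleted SeekToFirstAutoIDSigned helper:
// the first ID after base that satisfies (ID - offset) % increment == 0.
func seekToFirstAutoIDSigned(base, increment, offset int64) int64 {
	nr := (base + increment - offset) / increment
	return nr*increment + offset
}

// calcNeededBatchSize mirrors the signed branch of the deleted CalcNeededBatchSize helper.
func calcNeededBatchSize(base, n, increment, offset int64) int64 {
	if increment == 1 {
		return n
	}
	nr := seekToFirstAutoIDSigned(base, increment, offset)
	nr += (n - 1) * increment
	return nr - base
}

func main() {
	// base=6, n=2, increment=4, offset=1: the first valid ID is 9 and the second is 13,
	// so the allocator needs a batch of 7 and hands back the range (6, 13].
	fmt.Println(seekToFirstAutoIDSigned(6, 4, 1)) // 9
	fmt.Println(calcNeededBatchSize(6, 2, 4, 1))  // 7
}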
-func CalcNeededBatchSize(base, n, increment, offset int64, isUnsigned bool) int64 { - if increment == 1 { - return n - } - if isUnsigned { - // SeekToFirstAutoIDUnSigned seeks to the next unsigned valid position. - nr := SeekToFirstAutoIDUnSigned(uint64(base), uint64(increment), uint64(offset)) - // Calculate the total batch size needed. - nr += (uint64(n) - 1) * uint64(increment) - return int64(nr - uint64(base)) - } - nr := SeekToFirstAutoIDSigned(base, increment, offset) - // Calculate the total batch size needed. - nr += (n - 1) * increment - return nr - base -} - -// CalcSequenceBatchSize calculate the next sequence batch size. -func CalcSequenceBatchSize(base, size, increment, offset, min, max int64) (int64, error) { - // The sequence is positive growth. - if increment > 0 { - if increment == 1 { - // Sequence is already allocated to the end. - if base >= max { - return 0, ErrAutoincReadFailed - } - // The rest of sequence < cache size, return the rest. - if max-base < size { - return max - base, nil - } - // The rest of sequence is adequate. - return size, nil - } - nr, ok := SeekToFirstSequenceValue(base, increment, offset, min, max) - if !ok { - return 0, ErrAutoincReadFailed - } - // The rest of sequence < cache size, return the rest. - if max-nr < (size-1)*increment { - return max - base, nil - } - return (nr - base) + (size-1)*increment, nil - } - // The sequence is negative growth. - if increment == -1 { - if base <= min { - return 0, ErrAutoincReadFailed - } - if base-min < size { - return base - min, nil - } - return size, nil - } - nr, ok := SeekToFirstSequenceValue(base, increment, offset, min, max) - if !ok { - return 0, ErrAutoincReadFailed - } - // The rest of sequence < cache size, return the rest. - if nr-min < (size-1)*(-increment) { - return base - min, nil - } - return (base - nr) + (size-1)*(-increment), nil -} - -// SeekToFirstSequenceValue seeks to the next valid value (must be in range of [MIN, max]), -// the bool indicates whether the first value is got. -// The seeking formula is describe as below: -// -// nr := (base + increment - offset) / increment -// -// first := nr*increment + offset -// Because formula computation will overflow Int64, so we transfer it to uint64 for distance computation. -func SeekToFirstSequenceValue(base, increment, offset, min, max int64) (int64, bool) { - if increment > 0 { - // Sequence is already allocated to the end. - if base >= max { - return 0, false - } - uMax := EncodeIntToCmpUint(max) - uBase := EncodeIntToCmpUint(base) - uOffset := EncodeIntToCmpUint(offset) - uIncrement := uint64(increment) - if uMax-uBase < uIncrement { - // Enum the possible first value. - for i := uBase + 1; i <= uMax; i++ { - if (i-uOffset)%uIncrement == 0 { - return DecodeCmpUintToInt(i), true - } - } - return 0, false - } - nr := (uBase + uIncrement - uOffset) / uIncrement - nr = nr*uIncrement + uOffset - first := DecodeCmpUintToInt(nr) - return first, true - } - // Sequence is already allocated to the end. - if base <= min { - return 0, false - } - uMin := EncodeIntToCmpUint(min) - uBase := EncodeIntToCmpUint(base) - uOffset := EncodeIntToCmpUint(offset) - uIncrement := uint64(-increment) - if uBase-uMin < uIncrement { - // Enum the possible first value. 
- for i := uBase - 1; i >= uMin; i-- { - if (uOffset-i)%uIncrement == 0 { - return DecodeCmpUintToInt(i), true - } - } - return 0, false - } - nr := (uOffset - uBase + uIncrement) / uIncrement - nr = uOffset - nr*uIncrement - first := DecodeCmpUintToInt(nr) - return first, true -} - -// SeekToFirstAutoIDSigned seeks to the next valid signed position. -func SeekToFirstAutoIDSigned(base, increment, offset int64) int64 { - nr := (base + increment - offset) / increment - nr = nr*increment + offset - return nr -} - -// SeekToFirstAutoIDUnSigned seeks to the next valid unsigned position. -func SeekToFirstAutoIDUnSigned(base, increment, offset uint64) uint64 { - nr := (base + increment - offset) / increment - nr = nr*increment + offset - return nr -} - -func (alloc *allocator) alloc4Signed(ctx context.Context, n uint64, increment, offset int64) (mini int64, max int64, err error) { - // Check offset rebase if necessary. - if offset-1 > alloc.base { - if err := alloc.rebase4Signed(ctx, offset-1, true); err != nil { - return 0, 0, err - } - } - // CalcNeededBatchSize calculates the total batch size needed. - n1 := CalcNeededBatchSize(alloc.base, int64(n), increment, offset, alloc.isUnsigned) - - // Condition alloc.base+N1 > alloc.end will overflow when alloc.base + N1 > MaxInt64. So need this. - if math.MaxInt64-alloc.base <= n1 { - return 0, 0, ErrAutoincReadFailed - } - // The local rest is not enough for allocN, skip it. - if alloc.base+n1 > alloc.end { - var newBase, newEnd int64 - startTime := time.Now() - nextStep := alloc.step - if !alloc.customStep && alloc.end > 0 { - // Although it may skip a segment here, we still think it is consumed. - consumeDur := startTime.Sub(alloc.lastAllocTime) - nextStep = NextStep(alloc.step, consumeDur) - } - - ctx, allocatorStats, commitDetail := getAllocatorStatsFromCtx(ctx) - if allocatorStats != nil { - allocatorStats.allocCount++ - defer func() { - if commitDetail != nil { - allocatorStats.mergeCommitDetail(*commitDetail) - } - }() - } - - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnMeta) - err := kv.RunInNewTxn(ctx, alloc.store, true, func(ctx context.Context, txn kv.Transaction) error { - defer tracing.StartRegion(ctx, "alloc.alloc4Signed").End() - if allocatorStats != nil { - txn.SetOption(kv.CollectRuntimeStats, allocatorStats.SnapshotRuntimeStats) - } - - idAcc := alloc.getIDAccessor(txn) - var err1 error - newBase, err1 = idAcc.Get() - if err1 != nil { - return err1 - } - // CalcNeededBatchSize calculates the total batch size needed on global base. - n1 = CalcNeededBatchSize(newBase, int64(n), increment, offset, alloc.isUnsigned) - // Although the step is customized by user, we still need to make sure nextStep is big enough for insert batch. - if nextStep < n1 { - nextStep = n1 - } - tmpStep := min(math.MaxInt64-newBase, nextStep) - // The global rest is not enough for alloc. - if tmpStep < n1 { - return ErrAutoincReadFailed - } - newEnd, err1 = idAcc.Inc(tmpStep) - return err1 - }) - metrics.AutoIDHistogram.WithLabelValues(metrics.TableAutoIDAlloc, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) - if err != nil { - return 0, 0, err - } - // Store the step for non-customized-step allocator to calculate next dynamic step. 
- if !alloc.customStep { - alloc.step = nextStep - } - alloc.lastAllocTime = time.Now() - if newBase == math.MaxInt64 { - return 0, 0, ErrAutoincReadFailed - } - alloc.base, alloc.end = newBase, newEnd - } - if logutil.BgLogger().Core().Enabled(zap.DebugLevel) { - logutil.BgLogger().Debug("alloc N signed ID", - zap.Uint64("from ID", uint64(alloc.base)), - zap.Uint64("to ID", uint64(alloc.base+n1)), - zap.Int64("table ID", alloc.tbID), - zap.Int64("database ID", alloc.dbID)) - } - mini = alloc.base - alloc.base += n1 - return mini, alloc.base, nil -} - -func (alloc *allocator) alloc4Unsigned(ctx context.Context, n uint64, increment, offset int64) (mini int64, max int64, err error) { - // Check offset rebase if necessary. - if uint64(offset-1) > uint64(alloc.base) { - if err := alloc.rebase4Unsigned(ctx, uint64(offset-1), true); err != nil { - return 0, 0, err - } - } - // CalcNeededBatchSize calculates the total batch size needed. - n1 := CalcNeededBatchSize(alloc.base, int64(n), increment, offset, alloc.isUnsigned) - - // Condition alloc.base+n1 > alloc.end will overflow when alloc.base + n1 > MaxInt64. So need this. - if math.MaxUint64-uint64(alloc.base) <= uint64(n1) { - return 0, 0, ErrAutoincReadFailed - } - // The local rest is not enough for alloc, skip it. - if uint64(alloc.base)+uint64(n1) > uint64(alloc.end) { - var newBase, newEnd int64 - startTime := time.Now() - nextStep := alloc.step - if !alloc.customStep { - // Although it may skip a segment here, we still treat it as consumed. - consumeDur := startTime.Sub(alloc.lastAllocTime) - nextStep = NextStep(alloc.step, consumeDur) - } - - ctx, allocatorStats, commitDetail := getAllocatorStatsFromCtx(ctx) - if allocatorStats != nil { - allocatorStats.allocCount++ - defer func() { - if commitDetail != nil { - allocatorStats.mergeCommitDetail(*commitDetail) - } - }() - } - - if codeRun := ctx.Value("testIssue39528"); codeRun != nil { - *(codeRun.(*bool)) = true - return 0, 0, errors.New("mock error for test") - } - - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnMeta) - err := kv.RunInNewTxn(ctx, alloc.store, true, func(ctx context.Context, txn kv.Transaction) error { - defer tracing.StartRegion(ctx, "alloc.alloc4Unsigned").End() - if allocatorStats != nil { - txn.SetOption(kv.CollectRuntimeStats, allocatorStats.SnapshotRuntimeStats) - } - - idAcc := alloc.getIDAccessor(txn) - var err1 error - newBase, err1 = idAcc.Get() - if err1 != nil { - return err1 - } - // CalcNeededBatchSize calculates the total batch size needed on new base. - n1 = CalcNeededBatchSize(newBase, int64(n), increment, offset, alloc.isUnsigned) - // Although the step is customized by user, we still need to make sure nextStep is big enough for insert batch. - if nextStep < n1 { - nextStep = n1 - } - tmpStep := int64(min(math.MaxUint64-uint64(newBase), uint64(nextStep))) - // The global rest is not enough for alloc. - if tmpStep < n1 { - return ErrAutoincReadFailed - } - newEnd, err1 = idAcc.Inc(tmpStep) - return err1 - }) - metrics.AutoIDHistogram.WithLabelValues(metrics.TableAutoIDAlloc, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) - if err != nil { - return 0, 0, err - } - // Store the step for non-customized-step allocator to calculate next dynamic step. 
- if !alloc.customStep { - alloc.step = nextStep - } - alloc.lastAllocTime = time.Now() - if uint64(newBase) == math.MaxUint64 { - return 0, 0, ErrAutoincReadFailed - } - alloc.base, alloc.end = newBase, newEnd - } - logutil.Logger(context.TODO()).Debug("alloc unsigned ID", - zap.Uint64(" from ID", uint64(alloc.base)), - zap.Uint64("to ID", uint64(alloc.base+n1)), - zap.Int64("table ID", alloc.tbID), - zap.Int64("database ID", alloc.dbID)) - mini = alloc.base - // Use uint64 n directly. - alloc.base = int64(uint64(alloc.base) + uint64(n1)) - return mini, alloc.base, nil -} - -func getAllocatorStatsFromCtx(ctx context.Context) (context.Context, *AllocatorRuntimeStats, **tikvutil.CommitDetails) { - var allocatorStats *AllocatorRuntimeStats - var commitDetail *tikvutil.CommitDetails - ctxValue := ctx.Value(AllocatorRuntimeStatsCtxKey) - if ctxValue != nil { - allocatorStats = ctxValue.(*AllocatorRuntimeStats) - ctx = context.WithValue(ctx, tikvutil.CommitDetailCtxKey, &commitDetail) - } - return ctx, allocatorStats, &commitDetail -} - -// alloc4Sequence is used to alloc value for sequence, there are several aspects different from autoid logic. -// 1: sequence allocation don't need check rebase. -// 2: sequence allocation don't need auto step. -// 3: sequence allocation may have negative growth. -// 4: sequence allocation batch length can be dissatisfied. -// 5: sequence batch allocation will be consumed immediately. -func (alloc *allocator) alloc4Sequence() (min int64, max int64, round int64, err error) { - increment := alloc.sequence.Increment - offset := alloc.sequence.Start - minValue := alloc.sequence.MinValue - maxValue := alloc.sequence.MaxValue - cacheSize := alloc.sequence.CacheValue - if !alloc.sequence.Cache { - cacheSize = 1 - } - - var newBase, newEnd int64 - startTime := time.Now() - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnMeta) - err = kv.RunInNewTxn(ctx, alloc.store, true, func(_ context.Context, txn kv.Transaction) error { - acc := meta.NewMeta(txn).GetAutoIDAccessors(alloc.dbID, alloc.tbID) - var ( - err1 error - seqStep int64 - ) - // Get the real offset if the sequence is in cycle. - // round is used to count cycle times in sequence with cycle option. - if alloc.sequence.Cycle { - // GetSequenceCycle is used to get the flag `round`, which indicates whether the sequence is already in cycle. - round, err1 = acc.SequenceCycle().Get() - if err1 != nil { - return err1 - } - if round > 0 { - if increment > 0 { - offset = alloc.sequence.MinValue - } else { - offset = alloc.sequence.MaxValue - } - } - } - - // Get the global new base. - newBase, err1 = acc.SequenceValue().Get() - if err1 != nil { - return err1 - } - - // CalcNeededBatchSize calculates the total batch size needed. - seqStep, err1 = CalcSequenceBatchSize(newBase, cacheSize, increment, offset, minValue, maxValue) - - if err1 != nil && err1 == ErrAutoincReadFailed { - if !alloc.sequence.Cycle { - return err1 - } - // Reset the sequence base and offset. - if alloc.sequence.Increment > 0 { - newBase = alloc.sequence.MinValue - 1 - offset = alloc.sequence.MinValue - } else { - newBase = alloc.sequence.MaxValue + 1 - offset = alloc.sequence.MaxValue - } - err1 = acc.SequenceValue().Put(newBase) - if err1 != nil { - return err1 - } - - // Reset sequence round state value. - round++ - // SetSequenceCycle is used to store the flag `round` which indicates whether the sequence is already in cycle. 
- // round > 0 means the sequence is already in cycle, so the offset should be minvalue / maxvalue rather than sequence.start. - // TiDB is a stateless node, it should know whether the sequence is already in cycle when restart. - err1 = acc.SequenceCycle().Put(round) - if err1 != nil { - return err1 - } - - // Recompute the sequence next batch size. - seqStep, err1 = CalcSequenceBatchSize(newBase, cacheSize, increment, offset, minValue, maxValue) - if err1 != nil { - return err1 - } - } - var delta int64 - if alloc.sequence.Increment > 0 { - delta = seqStep - } else { - delta = -seqStep - } - newEnd, err1 = acc.SequenceValue().Inc(delta) - return err1 - }) - - // TODO: sequence metrics - metrics.AutoIDHistogram.WithLabelValues(metrics.TableAutoIDAlloc, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) - if err != nil { - return 0, 0, 0, err - } - logutil.Logger(context.TODO()).Debug("alloc sequence value", - zap.Uint64(" from value", uint64(newBase)), - zap.Uint64("to value", uint64(newEnd)), - zap.Int64("table ID", alloc.tbID), - zap.Int64("database ID", alloc.dbID)) - return newBase, newEnd, round, nil -} - -func (alloc *allocator) getIDAccessor(txn kv.Transaction) meta.AutoIDAccessor { - acc := meta.NewMeta(txn).GetAutoIDAccessors(alloc.dbID, alloc.tbID) - switch alloc.allocType { - case RowIDAllocType: - return acc.RowID() - case AutoIncrementType: - return acc.IncrementID(alloc.tbVersion) - case AutoRandomType: - return acc.RandomID() - case SequenceType: - return acc.SequenceValue() - } - return nil -} - -const signMask uint64 = 0x8000000000000000 - -// EncodeIntToCmpUint make int v to comparable uint type -func EncodeIntToCmpUint(v int64) uint64 { - return uint64(v) ^ signMask -} - -// DecodeCmpUintToInt decodes the u that encoded by EncodeIntToCmpUint -func DecodeCmpUintToInt(u uint64) int64 { - return int64(u ^ signMask) -} - -// TestModifyBaseAndEndInjection exported for testing modifying the base and end. -func TestModifyBaseAndEndInjection(alloc Allocator, base, end int64) { - alloc.(*allocator).mu.Lock() - alloc.(*allocator).base = base - alloc.(*allocator).end = end - alloc.(*allocator).mu.Unlock() -} - -// ShardIDFormat is used to calculate the bit length of different segments in auto id. -// Generally, an auto id is consist of 4 segments: sign bit, reserved bits, shard bits and incremental bits. -// Take "a BIGINT AUTO_INCREMENT PRIMARY KEY" as an example, assume that the `shard_row_id_bits` = 5, -// the layout is like -// -// | [sign_bit] (1 bit) | [reserved bits] (0 bits) | [shard_bits] (5 bits) | [incremental_bits] (64-1-5=58 bits) | -// -// Please always use NewShardIDFormat() to instantiate. -type ShardIDFormat struct { - FieldType *types.FieldType - ShardBits uint64 - // Derived fields. - IncrementalBits uint64 -} - -// NewShardIDFormat create an instance of ShardIDFormat. -// RangeBits means the bit length of the sign bit + shard bits + incremental bits. -// If RangeBits is 0, it will be calculated according to field type automatically. -func NewShardIDFormat(fieldType *types.FieldType, shardBits, rangeBits uint64) ShardIDFormat { - var incrementalBits uint64 - if rangeBits == 0 { - // Zero means that the range bits is not specified. We interpret it as the length of BIGINT. 
- incrementalBits = RowIDBitLength - shardBits - } else { - incrementalBits = rangeBits - shardBits - } - hasSignBit := !mysql.HasUnsignedFlag(fieldType.GetFlag()) - if hasSignBit { - incrementalBits-- - } - return ShardIDFormat{ - FieldType: fieldType, - ShardBits: shardBits, - IncrementalBits: incrementalBits, - } -} - -// IncrementalBitsCapacity returns the max capacity of incremental section of the current format. -func (s *ShardIDFormat) IncrementalBitsCapacity() uint64 { - return uint64(s.IncrementalMask()) -} - -// IncrementalMask returns 00..0[11..1], where [11..1] is the incremental part of the current format. -func (s *ShardIDFormat) IncrementalMask() int64 { - return (1 << s.IncrementalBits) - 1 -} - -// Compose generates an auto ID based on the given shard and an incremental ID. -func (s *ShardIDFormat) Compose(shard int64, id int64) int64 { - return ((shard & ((1 << s.ShardBits) - 1)) << s.IncrementalBits) | id -} - -type allocatorRuntimeStatsCtxKeyType struct{} - -// AllocatorRuntimeStatsCtxKey is the context key of allocator runtime stats. -var AllocatorRuntimeStatsCtxKey = allocatorRuntimeStatsCtxKeyType{} - -// AllocatorRuntimeStats is the execution stats of auto id allocator. -type AllocatorRuntimeStats struct { - *txnsnapshot.SnapshotRuntimeStats - *execdetails.RuntimeStatsWithCommit - allocCount int - rebaseCount int -} - -// NewAllocatorRuntimeStats return a new AllocatorRuntimeStats. -func NewAllocatorRuntimeStats() *AllocatorRuntimeStats { - return &AllocatorRuntimeStats{ - SnapshotRuntimeStats: &txnsnapshot.SnapshotRuntimeStats{}, - } -} - -func (e *AllocatorRuntimeStats) mergeCommitDetail(detail *tikvutil.CommitDetails) { - if detail == nil { - return - } - if e.RuntimeStatsWithCommit == nil { - e.RuntimeStatsWithCommit = &execdetails.RuntimeStatsWithCommit{} - } - e.RuntimeStatsWithCommit.MergeCommitDetails(detail) -} - -// String implements the RuntimeStats interface. -func (e *AllocatorRuntimeStats) String() string { - if e.allocCount == 0 && e.rebaseCount == 0 { - return "" - } - var buf bytes.Buffer - buf.WriteString("auto_id_allocator: {") - initialSize := buf.Len() - if e.allocCount > 0 { - buf.WriteString("alloc_cnt: ") - buf.WriteString(strconv.FormatInt(int64(e.allocCount), 10)) - } - if e.rebaseCount > 0 { - if buf.Len() > initialSize { - buf.WriteString(", ") - } - buf.WriteString("rebase_cnt: ") - buf.WriteString(strconv.FormatInt(int64(e.rebaseCount), 10)) - } - if e.SnapshotRuntimeStats != nil { - stats := e.SnapshotRuntimeStats.String() - if stats != "" { - if buf.Len() > initialSize { - buf.WriteString(", ") - } - buf.WriteString(e.SnapshotRuntimeStats.String()) - } - } - if e.RuntimeStatsWithCommit != nil { - stats := e.RuntimeStatsWithCommit.String() - if stats != "" { - if buf.Len() > initialSize { - buf.WriteString(", ") - } - buf.WriteString(stats) - } - } - buf.WriteString("}") - return buf.String() -} - -// Clone implements the RuntimeStats interface. -func (e *AllocatorRuntimeStats) Clone() *AllocatorRuntimeStats { - newRs := &AllocatorRuntimeStats{ - allocCount: e.allocCount, - rebaseCount: e.rebaseCount, - } - if e.SnapshotRuntimeStats != nil { - snapshotStats := e.SnapshotRuntimeStats.Clone() - newRs.SnapshotRuntimeStats = snapshotStats - } - if e.RuntimeStatsWithCommit != nil { - newRs.RuntimeStatsWithCommit = e.RuntimeStatsWithCommit.Clone().(*execdetails.RuntimeStatsWithCommit) - } - return newRs -} - -// Merge implements the RuntimeStats interface. 
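Illustrative aside (not part of the patch): concrete values for the sign-flip encoding and the shard/incremental composition defined by the deleted ShardIDFormat helpers above.

package main

import "fmt"

const signMask uint64 = 0x8000000000000000

// encodeIntToCmpUint / decodeCmpUintToInt mirror the deleted helpers: flipping the sign bit
// makes signed values order-preserving when they are compared as uint64.
func encodeIntToCmpUint(v int64) uint64 { return uint64(v) ^ signMask }
func decodeCmpUintToInt(u uint64) int64 { return int64(u ^ signMask) }

func main() {
	fmt.Println(encodeIntToCmpUint(-1) < encodeIntToCmpUint(0)) // true
	fmt.Println(decodeCmpUintToInt(encodeIntToCmpUint(-42)))    // -42

	// Compose for the "5 shard bits, 58 incremental bits" BIGINT layout described above:
	// the shard sits in the high bits just under the sign bit, the counter in the low bits.
	shard, id := int64(3), int64(7)
	composed := ((shard & ((1 << 5) - 1)) << 58) | id
	fmt.Println(composed) // 864691128455135239
}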
-func (e *AllocatorRuntimeStats) Merge(other *AllocatorRuntimeStats) { - if other == nil { - return - } - if other.SnapshotRuntimeStats != nil { - if e.SnapshotRuntimeStats == nil { - e.SnapshotRuntimeStats = other.SnapshotRuntimeStats.Clone() - } else { - e.SnapshotRuntimeStats.Merge(other.SnapshotRuntimeStats) - } - } - if other.RuntimeStatsWithCommit != nil { - if e.RuntimeStatsWithCommit == nil { - e.RuntimeStatsWithCommit = other.RuntimeStatsWithCommit.Clone().(*execdetails.RuntimeStatsWithCommit) - } else { - e.RuntimeStatsWithCommit.Merge(other.RuntimeStatsWithCommit) - } - } -} diff --git a/pkg/meta/autoid/binding__failpoint_binding__.go b/pkg/meta/autoid/binding__failpoint_binding__.go deleted file mode 100644 index 2c1025c7f434f..0000000000000 --- a/pkg/meta/autoid/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package autoid - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/owner/binding__failpoint_binding__.go b/pkg/owner/binding__failpoint_binding__.go deleted file mode 100644 index 6f8eac02d8e5b..0000000000000 --- a/pkg/owner/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package owner - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/owner/manager.go b/pkg/owner/manager.go index 4453f88bca07b..cfe227e580699 100644 --- a/pkg/owner/manager.go +++ b/pkg/owner/manager.go @@ -273,12 +273,12 @@ func (m *ownerManager) campaignLoop(etcdSession *concurrency.Session) { } m.sessionLease.Store(int64(etcdSession.Lease())) case <-campaignContext.Done(): - if v, _err_ := failpoint.Eval(_curpkg_("MockDelOwnerKey")); _err_ == nil { + failpoint.Inject("MockDelOwnerKey", func(v failpoint.Value) { if v.(string) == "delOwnerKeyAndNotOwner" { logutil.Logger(logCtx).Info("mock break campaign and don't clear related info") return } - } + }) logutil.Logger(logCtx).Info("break campaign loop, context is done") m.revokeSession(logPrefix, etcdSession.Lease()) return @@ -408,13 +408,13 @@ func (m *ownerManager) SetOwnerOpValue(ctx context.Context, op OpType) error { } newOwnerVal := joinOwnerValues(ownerID, []byte{byte(op)}) - if v, _err_ := failpoint.Eval(_curpkg_("MockDelOwnerKey")); _err_ == nil { + failpoint.Inject("MockDelOwnerKey", func(v failpoint.Value) { if valStr, ok := v.(string); ok { if err := mockDelOwnerKey(valStr, ownerKey, m); err != nil { - return err + failpoint.Return(err) } } - } + }) leaseOp := clientv3.WithLease(clientv3.LeaseID(m.sessionLease.Load())) resp, err := m.etcdCli.Txn(ctx). diff --git a/pkg/owner/manager.go__failpoint_stash__ b/pkg/owner/manager.go__failpoint_stash__ deleted file mode 100644 index cfe227e580699..0000000000000 --- a/pkg/owner/manager.go__failpoint_stash__ +++ /dev/null @@ -1,486 +0,0 @@ -// Copyright 2017 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
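Illustrative aside (not part of the patch): a minimal standalone sketch of the failpoint runtime that both sides of these hunks target. The failpoint.Inject marker form kept by this patch is what the failpoint tool rewrites into the failpoint.Eval form (with _curpkg_ prefixing the package path, as in the deleted binding files); either way, the failpoint only fires after it has been enabled. The failpoint name below is made up.

package main

import (
	"fmt"

	"github.com/pingcap/failpoint"
)

func main() {
	const fp = "example.com/demo/mockDelOwnerKey" // hypothetical name, not a real TiDB failpoint path

	// Before Enable, Eval reports that the failpoint is not active.
	if _, err := failpoint.Eval(fp); err != nil {
		fmt.Println("not enabled yet:", err)
	}

	if err := failpoint.Enable(fp, `return("delOwnerKeyAndNotOwner")`); err != nil {
		panic(err)
	}
	defer func() { _ = failpoint.Disable(fp) }()

	// Once enabled, Eval returns the term's value, mirroring what the generated
	// `if v, _err_ := failpoint.Eval(...)` branches in the hunks above check for.
	if v, err := failpoint.Eval(fp); err == nil {
		fmt.Println("failpoint fired with:", v)
	}
}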
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package owner - -import ( - "bytes" - "context" - "fmt" - "os" - "strconv" - "sync" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/ddl/util" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/terror" - util2 "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/logutil" - "go.etcd.io/etcd/api/v3/mvccpb" - "go.etcd.io/etcd/api/v3/v3rpc/rpctypes" - clientv3 "go.etcd.io/etcd/client/v3" - "go.etcd.io/etcd/client/v3/concurrency" - atomicutil "go.uber.org/atomic" - "go.uber.org/zap" -) - -// Listener is used to listen the ownerManager's owner state. -type Listener interface { - OnBecomeOwner() - OnRetireOwner() -} - -// Manager is used to campaign the owner and manage the owner information. -type Manager interface { - // ID returns the ID of the manager. - ID() string - // IsOwner returns whether the ownerManager is the owner. - IsOwner() bool - // RetireOwner make the manager to be a not owner. It's exported for testing. - RetireOwner() - // GetOwnerID gets the owner ID. - GetOwnerID(ctx context.Context) (string, error) - // SetOwnerOpValue updates the owner op value. - SetOwnerOpValue(ctx context.Context, op OpType) error - // CampaignOwner campaigns the owner. - CampaignOwner(...int) error - // ResignOwner lets the owner start a new election. - ResignOwner(ctx context.Context) error - // Cancel cancels this etcd ownerManager. - Cancel() - // RequireOwner requires the ownerManager is owner. - RequireOwner(ctx context.Context) error - // CampaignCancel cancels one etcd campaign - CampaignCancel() - // SetListener sets the listener, set before CampaignOwner. - SetListener(listener Listener) -} - -const ( - keyOpDefaultTimeout = 5 * time.Second -) - -// OpType is the owner key value operation type. -type OpType byte - -// List operation of types. -const ( - OpNone OpType = 0 - OpSyncUpgradingState OpType = 1 -) - -// String implements fmt.Stringer interface. -func (ot OpType) String() string { - switch ot { - case OpSyncUpgradingState: - return "sync upgrading state" - default: - return "none" - } -} - -// IsSyncedUpgradingState represents whether the upgrading state is synchronized. -func (ot OpType) IsSyncedUpgradingState() bool { - return ot == OpSyncUpgradingState -} - -// DDLOwnerChecker is used to check whether tidb is owner. -type DDLOwnerChecker interface { - // IsOwner returns whether the ownerManager is the owner. - IsOwner() bool -} - -// ownerManager represents the structure which is used for electing owner. -type ownerManager struct { - id string // id is the ID of the manager. - key string - ctx context.Context - prompt string - logPrefix string - logCtx context.Context - etcdCli *clientv3.Client - cancel context.CancelFunc - elec atomic.Pointer[concurrency.Election] - sessionLease *atomicutil.Int64 - wg sync.WaitGroup - campaignCancel context.CancelFunc - - listener Listener -} - -// NewOwnerManager creates a new Manager. 
-func NewOwnerManager(ctx context.Context, etcdCli *clientv3.Client, prompt, id, key string) Manager { - logPrefix := fmt.Sprintf("[%s] %s ownerManager %s", prompt, key, id) - ctx, cancelFunc := context.WithCancel(ctx) - return &ownerManager{ - etcdCli: etcdCli, - id: id, - key: key, - ctx: ctx, - prompt: prompt, - cancel: cancelFunc, - logPrefix: logPrefix, - logCtx: logutil.WithKeyValue(context.Background(), "owner info", logPrefix), - sessionLease: atomicutil.NewInt64(0), - } -} - -// ID implements Manager.ID interface. -func (m *ownerManager) ID() string { - return m.id -} - -// IsOwner implements Manager.IsOwner interface. -func (m *ownerManager) IsOwner() bool { - return m.elec.Load() != nil -} - -// Cancel implements Manager.Cancel interface. -func (m *ownerManager) Cancel() { - m.cancel() - m.wg.Wait() -} - -// RequireOwner implements Manager.RequireOwner interface. -func (*ownerManager) RequireOwner(_ context.Context) error { - return nil -} - -func (m *ownerManager) SetListener(listener Listener) { - m.listener = listener -} - -// ManagerSessionTTL is the etcd session's TTL in seconds. It's exported for testing. -var ManagerSessionTTL = 60 - -// setManagerSessionTTL sets the ManagerSessionTTL value, it's used for testing. -func setManagerSessionTTL() error { - ttlStr := os.Getenv("tidb_manager_ttl") - if ttlStr == "" { - return nil - } - ttl, err := strconv.Atoi(ttlStr) - if err != nil { - return errors.Trace(err) - } - ManagerSessionTTL = ttl - return nil -} - -// CampaignOwner implements Manager.CampaignOwner interface. -func (m *ownerManager) CampaignOwner(withTTL ...int) error { - ttl := ManagerSessionTTL - if len(withTTL) == 1 { - ttl = withTTL[0] - } - logPrefix := fmt.Sprintf("[%s] %s", m.prompt, m.key) - logutil.BgLogger().Info("start campaign owner", zap.String("ownerInfo", logPrefix)) - session, err := util2.NewSession(m.ctx, logPrefix, m.etcdCli, util2.NewSessionDefaultRetryCnt, ttl) - if err != nil { - return errors.Trace(err) - } - m.sessionLease.Store(int64(session.Lease())) - m.wg.Add(1) - go m.campaignLoop(session) - return nil -} - -// ResignOwner lets the owner start a new election. -func (m *ownerManager) ResignOwner(ctx context.Context) error { - elec := m.elec.Load() - if elec == nil { - return errors.Errorf("This node is not a owner, can't be resigned") - } - - childCtx, cancel := context.WithTimeout(ctx, keyOpDefaultTimeout) - err := elec.Resign(childCtx) - cancel() - if err != nil { - return errors.Trace(err) - } - - logutil.Logger(m.logCtx).Warn("resign owner success") - return nil -} - -func (m *ownerManager) toBeOwner(elec *concurrency.Election) { - m.elec.Store(elec) - logutil.Logger(m.logCtx).Info("become owner") - if m.listener != nil { - m.listener.OnBecomeOwner() - } -} - -// RetireOwner make the manager to be a not owner. -func (m *ownerManager) RetireOwner() { - m.elec.Store(nil) - logutil.Logger(m.logCtx).Info("retire owner") - if m.listener != nil { - m.listener.OnRetireOwner() - } -} - -// CampaignCancel implements Manager.CampaignCancel interface. 
-func (m *ownerManager) CampaignCancel() { - m.campaignCancel() - m.wg.Wait() -} - -func (m *ownerManager) campaignLoop(etcdSession *concurrency.Session) { - var campaignContext context.Context - campaignContext, m.campaignCancel = context.WithCancel(m.ctx) - defer func() { - m.campaignCancel() - if r := recover(); r != nil { - logutil.BgLogger().Error("recover panic", zap.String("prompt", m.prompt), zap.Any("error", r), zap.Stack("buffer")) - metrics.PanicCounter.WithLabelValues(metrics.LabelDDLOwner).Inc() - } - m.wg.Done() - }() - - logPrefix := m.logPrefix - logCtx := m.logCtx - var err error - for { - if err != nil { - metrics.CampaignOwnerCounter.WithLabelValues(m.prompt, err.Error()).Inc() - } - - select { - case <-etcdSession.Done(): - logutil.Logger(logCtx).Info("etcd session is done, creates a new one") - leaseID := etcdSession.Lease() - etcdSession, err = util2.NewSession(campaignContext, logPrefix, m.etcdCli, util2.NewSessionRetryUnlimited, ManagerSessionTTL) - if err != nil { - logutil.Logger(logCtx).Info("break campaign loop, NewSession failed", zap.Error(err)) - m.revokeSession(logPrefix, leaseID) - return - } - m.sessionLease.Store(int64(etcdSession.Lease())) - case <-campaignContext.Done(): - failpoint.Inject("MockDelOwnerKey", func(v failpoint.Value) { - if v.(string) == "delOwnerKeyAndNotOwner" { - logutil.Logger(logCtx).Info("mock break campaign and don't clear related info") - return - } - }) - logutil.Logger(logCtx).Info("break campaign loop, context is done") - m.revokeSession(logPrefix, etcdSession.Lease()) - return - default: - } - // If the etcd server turns clocks forward,the following case may occur. - // The etcd server deletes this session's lease ID, but etcd session doesn't find it. - // In this time if we do the campaign operation, the etcd server will return ErrLeaseNotFound. - if terror.ErrorEqual(err, rpctypes.ErrLeaseNotFound) { - if etcdSession != nil { - err = etcdSession.Close() - logutil.Logger(logCtx).Info("etcd session encounters the error of lease not found, closes it", zap.Error(err)) - } - continue - } - - elec := concurrency.NewElection(etcdSession, m.key) - err = elec.Campaign(campaignContext, m.id) - if err != nil { - logutil.Logger(logCtx).Info("failed to campaign", zap.Error(err)) - continue - } - - ownerKey, err := GetOwnerKey(campaignContext, logCtx, m.etcdCli, m.key, m.id) - if err != nil { - continue - } - - m.toBeOwner(elec) - m.watchOwner(campaignContext, etcdSession, ownerKey) - m.RetireOwner() - - metrics.CampaignOwnerCounter.WithLabelValues(m.prompt, metrics.NoLongerOwner).Inc() - logutil.Logger(logCtx).Warn("is not the owner") - } -} - -func (m *ownerManager) revokeSession(_ string, leaseID clientv3.LeaseID) { - // Revoke the session lease. - // If revoke takes longer than the ttl, lease is expired anyway. - cancelCtx, cancel := context.WithTimeout(context.Background(), - time.Duration(ManagerSessionTTL)*time.Second) - _, err := m.etcdCli.Revoke(cancelCtx, leaseID) - cancel() - logutil.Logger(m.logCtx).Info("revoke session", zap.Error(err)) -} - -// GetOwnerID implements Manager.GetOwnerID interface. 
-func (m *ownerManager) GetOwnerID(ctx context.Context) (string, error) { - _, ownerID, _, _, err := getOwnerInfo(ctx, m.logCtx, m.etcdCli, m.key) - return string(ownerID), errors.Trace(err) -} - -func getOwnerInfo(ctx, logCtx context.Context, etcdCli *clientv3.Client, ownerPath string) (string, []byte, OpType, int64, error) { - var op OpType - var resp *clientv3.GetResponse - var err error - for i := 0; i < 3; i++ { - if err = ctx.Err(); err != nil { - return "", nil, op, 0, errors.Trace(err) - } - - childCtx, cancel := context.WithTimeout(ctx, util.KeyOpDefaultTimeout) - resp, err = etcdCli.Get(childCtx, ownerPath, clientv3.WithFirstCreate()...) - cancel() - if err == nil { - break - } - logutil.Logger(logCtx).Info("etcd-cli get owner info failed", zap.String("key", ownerPath), zap.Int("retryCnt", i), zap.Error(err)) - time.Sleep(util.KeyOpRetryInterval) - } - if err != nil { - logutil.Logger(logCtx).Warn("etcd-cli get owner info failed", zap.Error(err)) - return "", nil, op, 0, errors.Trace(err) - } - if len(resp.Kvs) == 0 { - return "", nil, op, 0, concurrency.ErrElectionNoLeader - } - - var ownerID []byte - ownerID, op = splitOwnerValues(resp.Kvs[0].Value) - logutil.Logger(logCtx).Info("get owner", zap.ByteString("owner key", resp.Kvs[0].Key), - zap.ByteString("ownerID", ownerID), zap.Stringer("op", op)) - return string(resp.Kvs[0].Key), ownerID, op, resp.Kvs[0].ModRevision, nil -} - -// GetOwnerKey gets the owner key information. -func GetOwnerKey(ctx, logCtx context.Context, etcdCli *clientv3.Client, etcdKey, id string) (string, error) { - ownerKey, ownerID, _, _, err := getOwnerInfo(ctx, logCtx, etcdCli, etcdKey) - if err != nil { - return "", errors.Trace(err) - } - if string(ownerID) != id { - logutil.Logger(logCtx).Warn("is not the owner") - return "", errors.New("ownerInfoNotMatch") - } - - return ownerKey, nil -} - -func splitOwnerValues(val []byte) ([]byte, OpType) { - vals := bytes.Split(val, []byte("_")) - var op OpType - if len(vals) == 2 { - op = OpType(vals[1][0]) - } - return vals[0], op -} - -func joinOwnerValues(vals ...[]byte) []byte { - return bytes.Join(vals, []byte("_")) -} - -// SetOwnerOpValue implements Manager.SetOwnerOpValue interface. -func (m *ownerManager) SetOwnerOpValue(ctx context.Context, op OpType) error { - // owner don't change. - ownerKey, ownerID, currOp, modRevision, err := getOwnerInfo(ctx, m.logCtx, m.etcdCli, m.key) - if err != nil { - return errors.Trace(err) - } - if currOp == op { - logutil.Logger(m.logCtx).Info("set owner op is the same as the original, so do nothing.", zap.Stringer("op", op)) - return nil - } - if string(ownerID) != m.id { - return errors.New("ownerInfoNotMatch") - } - newOwnerVal := joinOwnerValues(ownerID, []byte{byte(op)}) - - failpoint.Inject("MockDelOwnerKey", func(v failpoint.Value) { - if valStr, ok := v.(string); ok { - if err := mockDelOwnerKey(valStr, ownerKey, m); err != nil { - failpoint.Return(err) - } - } - }) - - leaseOp := clientv3.WithLease(clientv3.LeaseID(m.sessionLease.Load())) - resp, err := m.etcdCli.Txn(ctx). - If(clientv3.Compare(clientv3.ModRevision(ownerKey), "=", modRevision)). - Then(clientv3.OpPut(ownerKey, string(newOwnerVal), leaseOp)). 
- Commit() - if err == nil && !resp.Succeeded { - err = errors.New("put owner key failed, cmp is false") - } - logutil.BgLogger().Info("set owner op value", zap.String("owner key", ownerKey), zap.ByteString("ownerID", ownerID), - zap.Stringer("old Op", currOp), zap.Stringer("op", op), zap.Error(err)) - metrics.WatchOwnerCounter.WithLabelValues(m.prompt, metrics.PutValue+"_"+metrics.RetLabel(err)).Inc() - return errors.Trace(err) -} - -// GetOwnerOpValue gets the owner op value. -func GetOwnerOpValue(ctx context.Context, etcdCli *clientv3.Client, ownerPath, logPrefix string) (OpType, error) { - // It's using for testing. - if etcdCli == nil { - return *mockOwnerOpValue.Load(), nil - } - - logCtx := logutil.WithKeyValue(context.Background(), "owner info", logPrefix) - _, _, op, _, err := getOwnerInfo(ctx, logCtx, etcdCli, ownerPath) - return op, errors.Trace(err) -} - -func (m *ownerManager) watchOwner(ctx context.Context, etcdSession *concurrency.Session, key string) { - logPrefix := fmt.Sprintf("[%s] ownerManager %s watch owner key %v", m.prompt, m.id, key) - logCtx := logutil.WithKeyValue(context.Background(), "owner info", logPrefix) - logutil.BgLogger().Debug(logPrefix) - watchCh := m.etcdCli.Watch(ctx, key) - for { - select { - case resp, ok := <-watchCh: - if !ok { - metrics.WatchOwnerCounter.WithLabelValues(m.prompt, metrics.WatcherClosed).Inc() - logutil.Logger(logCtx).Info("watcher is closed, no owner") - return - } - if resp.Canceled { - metrics.WatchOwnerCounter.WithLabelValues(m.prompt, metrics.Cancelled).Inc() - logutil.Logger(logCtx).Info("watch canceled, no owner") - return - } - - for _, ev := range resp.Events { - if ev.Type == mvccpb.DELETE { - metrics.WatchOwnerCounter.WithLabelValues(m.prompt, metrics.Deleted).Inc() - logutil.Logger(logCtx).Info("watch failed, owner is deleted") - return - } - } - case <-etcdSession.Done(): - metrics.WatchOwnerCounter.WithLabelValues(m.prompt, metrics.SessionDone).Inc() - return - case <-ctx.Done(): - metrics.WatchOwnerCounter.WithLabelValues(m.prompt, metrics.CtxDone).Inc() - return - } - } -} - -func init() { - err := setManagerSessionTTL() - if err != nil { - logutil.BgLogger().Warn("set manager session TTL failed", zap.Error(err)) - } -} diff --git a/pkg/owner/mock.go b/pkg/owner/mock.go index 372f587e385f0..75a934307b41e 100644 --- a/pkg/owner/mock.go +++ b/pkg/owner/mock.go @@ -123,11 +123,11 @@ func (m *mockManager) GetOwnerID(_ context.Context) (string, error) { } func (*mockManager) SetOwnerOpValue(_ context.Context, op OpType) error { - if val, _err_ := failpoint.Eval(_curpkg_("MockNotSetOwnerOp")); _err_ == nil { + failpoint.Inject("MockNotSetOwnerOp", func(val failpoint.Value) { if val.(bool) { - return nil + failpoint.Return(nil) } - } + }) mockOwnerOpValue.Store(&op) return nil } diff --git a/pkg/owner/mock.go__failpoint_stash__ b/pkg/owner/mock.go__failpoint_stash__ deleted file mode 100644 index 75a934307b41e..0000000000000 --- a/pkg/owner/mock.go__failpoint_stash__ +++ /dev/null @@ -1,230 +0,0 @@ -// Copyright 2017 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package owner - -import ( - "context" - "sync" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/ddl/util" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/timeutil" - "go.uber.org/zap" -) - -var _ Manager = &mockManager{} - -// mockManager represents the structure which is used for electing owner. -// It's used for local store and testing. -// So this worker will always be the owner. -type mockManager struct { - id string // id is the ID of manager. - storeID string - key string - ctx context.Context - wg sync.WaitGroup - cancel context.CancelFunc - listener Listener - retireHook func() - campaignDone chan struct{} - resignDone chan struct{} -} - -var mockOwnerOpValue atomic.Pointer[OpType] - -// NewMockManager creates a new mock Manager. -func NewMockManager(ctx context.Context, id string, store kv.Storage, ownerKey string) Manager { - cancelCtx, cancelFunc := context.WithCancel(ctx) - storeID := "mock_store_id" - if store != nil { - storeID = store.UUID() - } - - // Make sure the mockOwnerOpValue is initialized before GetOwnerOpValue in bootstrap. - op := OpNone - mockOwnerOpValue.Store(&op) - return &mockManager{ - id: id, - storeID: storeID, - key: ownerKey, - ctx: cancelCtx, - cancel: cancelFunc, - campaignDone: make(chan struct{}), - resignDone: make(chan struct{}), - } -} - -// ID implements Manager.ID interface. -func (m *mockManager) ID() string { - return m.id -} - -// IsOwner implements Manager.IsOwner interface. -func (m *mockManager) IsOwner() bool { - logutil.BgLogger().Debug("owner manager checks owner", - zap.String("ownerKey", m.key), zap.String("ID", m.id)) - return util.MockGlobalStateEntry.OwnerKey(m.storeID, m.key).IsOwner(m.id) -} - -func (m *mockManager) toBeOwner() { - ok := util.MockGlobalStateEntry.OwnerKey(m.storeID, m.key).SetOwner(m.id) - if ok { - logutil.BgLogger().Info("owner manager gets owner", - zap.String("ownerKey", m.key), zap.String("ID", m.id)) - if m.listener != nil { - m.listener.OnBecomeOwner() - } - } -} - -// RetireOwner implements Manager.RetireOwner interface. -func (m *mockManager) RetireOwner() { - ok := util.MockGlobalStateEntry.OwnerKey(m.storeID, m.key).UnsetOwner(m.id) - if ok { - logutil.BgLogger().Info("owner manager retire owner", - zap.String("ownerKey", m.key), zap.String("ID", m.id)) - if m.listener != nil { - m.listener.OnRetireOwner() - } - } -} - -// Cancel implements Manager.Cancel interface. -func (m *mockManager) Cancel() { - m.cancel() - m.wg.Wait() - logutil.BgLogger().Info("owner manager is canceled", - zap.String("ownerKey", m.key), zap.String("ID", m.id)) -} - -// GetOwnerID implements Manager.GetOwnerID interface. -func (m *mockManager) GetOwnerID(_ context.Context) (string, error) { - if m.IsOwner() { - return m.ID(), nil - } - return "", errors.New("no owner") -} - -func (*mockManager) SetOwnerOpValue(_ context.Context, op OpType) error { - failpoint.Inject("MockNotSetOwnerOp", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(nil) - } - }) - mockOwnerOpValue.Store(&op) - return nil -} - -// CampaignOwner implements Manager.CampaignOwner interface. 
-func (m *mockManager) CampaignOwner(_ ...int) error { - m.wg.Add(1) - go func() { - logutil.BgLogger().Debug("owner manager campaign owner", - zap.String("ownerKey", m.key), zap.String("ID", m.id)) - defer m.wg.Done() - for { - select { - case <-m.campaignDone: - m.RetireOwner() - logutil.BgLogger().Debug("owner manager campaign done", zap.String("ID", m.id)) - return - case <-m.ctx.Done(): - m.RetireOwner() - logutil.BgLogger().Debug("owner manager is cancelled", zap.String("ID", m.id)) - return - case <-m.resignDone: - m.RetireOwner() - //nolint: errcheck - timeutil.Sleep(m.ctx, 1*time.Second) // Give a chance to the other owner managers to get owner. - default: - m.toBeOwner() - //nolint: errcheck - timeutil.Sleep(m.ctx, 1*time.Second) // Speed up domain.Close() - logutil.BgLogger().Debug("owner manager tick", zap.String("ID", m.id), - zap.String("ownerKey", m.key), zap.String("currentOwner", util.MockGlobalStateEntry.OwnerKey(m.storeID, m.key).GetOwner())) - } - } - }() - return nil -} - -// ResignOwner lets the owner start a new election. -func (m *mockManager) ResignOwner(_ context.Context) error { - m.resignDone <- struct{}{} - return nil -} - -// RequireOwner implements Manager.RequireOwner interface. -func (*mockManager) RequireOwner(context.Context) error { - return nil -} - -// SetListener implements Manager.SetListener interface. -func (m *mockManager) SetListener(listener Listener) { - m.listener = listener -} - -// CampaignCancel implements Manager.CampaignCancel interface -func (m *mockManager) CampaignCancel() { - m.campaignDone <- struct{}{} -} - -func mockDelOwnerKey(mockCal, ownerKey string, m *ownerManager) error { - checkIsOwner := func(m *ownerManager, checkTrue bool) error { - // 5s - for i := 0; i < 100; i++ { - if m.IsOwner() == checkTrue { - break - } - time.Sleep(50 * time.Millisecond) - } - if m.IsOwner() != checkTrue { - return errors.Errorf("expect manager state:%v", checkTrue) - } - return nil - } - - needCheckOwner := false - switch mockCal { - case "delOwnerKeyAndNotOwner": - m.CampaignCancel() - // Make sure the manager is not owner. And it will exit campaignLoop. - err := checkIsOwner(m, false) - if err != nil { - return err - } - case "onlyDelOwnerKey": - needCheckOwner = true - } - - err := util.DeleteKeyFromEtcd(ownerKey, m.etcdCli, 1, keyOpDefaultTimeout) - if err != nil { - return errors.Trace(err) - } - if needCheckOwner { - // Mock the manager become not owner because the owner is deleted(like TTL is timeout). - // And then the manager campaigns the owner again, and become the owner. 
- err = checkIsOwner(m, true) - if err != nil { - return err - } - } - return nil -} diff --git a/pkg/parser/ast/binding__failpoint_binding__.go b/pkg/parser/ast/binding__failpoint_binding__.go deleted file mode 100644 index 88f2a0560be99..0000000000000 --- a/pkg/parser/ast/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package ast - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/parser/ast/misc.go b/pkg/parser/ast/misc.go index 13d83ea4d8da3..fa07d309953d0 100644 --- a/pkg/parser/ast/misc.go +++ b/pkg/parser/ast/misc.go @@ -3582,9 +3582,9 @@ func RedactURL(str string) string { return str } scheme := u.Scheme - if _, _err_ := failpoint.Eval(_curpkg_("forceRedactURL")); _err_ == nil { + failpoint.Inject("forceRedactURL", func() { scheme = "s3" - } + }) switch strings.ToLower(scheme) { case "s3", "ks3": values := u.Query() diff --git a/pkg/parser/ast/misc.go__failpoint_stash__ b/pkg/parser/ast/misc.go__failpoint_stash__ deleted file mode 100644 index fa07d309953d0..0000000000000 --- a/pkg/parser/ast/misc.go__failpoint_stash__ +++ /dev/null @@ -1,4209 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package ast - -import ( - "bytes" - "fmt" - "net/url" - "strconv" - "strings" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/parser/auth" - "github.com/pingcap/tidb/pkg/parser/format" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" -) - -var ( - _ StmtNode = &AdminStmt{} - _ StmtNode = &AlterUserStmt{} - _ StmtNode = &AlterRangeStmt{} - _ StmtNode = &BeginStmt{} - _ StmtNode = &BinlogStmt{} - _ StmtNode = &CommitStmt{} - _ StmtNode = &CreateUserStmt{} - _ StmtNode = &DeallocateStmt{} - _ StmtNode = &DoStmt{} - _ StmtNode = &ExecuteStmt{} - _ StmtNode = &ExplainStmt{} - _ StmtNode = &GrantStmt{} - _ StmtNode = &PrepareStmt{} - _ StmtNode = &RollbackStmt{} - _ StmtNode = &SetPwdStmt{} - _ StmtNode = &SetRoleStmt{} - _ StmtNode = &SetDefaultRoleStmt{} - _ StmtNode = &SetStmt{} - _ StmtNode = &SetSessionStatesStmt{} - _ StmtNode = &UseStmt{} - _ StmtNode = &FlushStmt{} - _ StmtNode = &KillStmt{} - _ StmtNode = &CreateBindingStmt{} - _ StmtNode = &DropBindingStmt{} - _ StmtNode = &SetBindingStmt{} - _ StmtNode = &ShutdownStmt{} - _ StmtNode = &RestartStmt{} - _ StmtNode = &RenameUserStmt{} - _ StmtNode = &HelpStmt{} - _ StmtNode = &PlanReplayerStmt{} - _ StmtNode = &CompactTableStmt{} - _ StmtNode = &SetResourceGroupStmt{} - - _ Node = &PrivElem{} - _ Node = &VariableAssignment{} -) - -// Isolation level constants. 
-const ( - ReadCommitted = "READ-COMMITTED" - ReadUncommitted = "READ-UNCOMMITTED" - Serializable = "SERIALIZABLE" - RepeatableRead = "REPEATABLE-READ" - - PumpType = "PUMP" - DrainerType = "DRAINER" -) - -// Transaction mode constants. -const ( - Optimistic = "OPTIMISTIC" - Pessimistic = "PESSIMISTIC" -) - -// TypeOpt is used for parsing data type option from SQL. -type TypeOpt struct { - IsUnsigned bool - IsZerofill bool -} - -// FloatOpt is used for parsing floating-point type option from SQL. -// See http://dev.mysql.com/doc/refman/5.7/en/floating-point-types.html -type FloatOpt struct { - Flen int - Decimal int -} - -// AuthOption is used for parsing create use statement. -type AuthOption struct { - // ByAuthString set as true, if AuthString is used for authorization. Otherwise, authorization is done by HashString. - ByAuthString bool - AuthString string - ByHashString bool - HashString string - AuthPlugin string -} - -// Restore implements Node interface. -func (n *AuthOption) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("IDENTIFIED") - if n.AuthPlugin != "" { - ctx.WriteKeyWord(" WITH ") - ctx.WriteString(n.AuthPlugin) - } - if n.ByAuthString { - ctx.WriteKeyWord(" BY ") - ctx.WriteString(n.AuthString) - } else if n.ByHashString { - ctx.WriteKeyWord(" AS ") - ctx.WriteString(n.HashString) - } - return nil -} - -// TraceStmt is a statement to trace what sql actually does at background. -type TraceStmt struct { - stmtNode - - Stmt StmtNode - Format string - - TracePlan bool - TracePlanTarget string -} - -// Restore implements Node interface. -func (n *TraceStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("TRACE ") - if n.TracePlan { - ctx.WriteKeyWord("PLAN ") - if n.TracePlanTarget != "" { - ctx.WriteKeyWord("TARGET") - ctx.WritePlain(" = ") - ctx.WriteString(n.TracePlanTarget) - ctx.WritePlain(" ") - } - } else if n.Format != "row" { - ctx.WriteKeyWord("FORMAT") - ctx.WritePlain(" = ") - ctx.WriteString(n.Format) - ctx.WritePlain(" ") - } - if err := n.Stmt.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore TraceStmt.Stmt") - } - return nil -} - -// Accept implements Node Accept interface. -func (n *TraceStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*TraceStmt) - node, ok := n.Stmt.Accept(v) - if !ok { - return n, false - } - n.Stmt = node.(StmtNode) - return v.Leave(n) -} - -// ExplainForStmt is a statement to provite information about how is SQL statement executeing -// in connection #ConnectionID -// See https://dev.mysql.com/doc/refman/5.7/en/explain.html -type ExplainForStmt struct { - stmtNode - - Format string - ConnectionID uint64 -} - -// Restore implements Node interface. -func (n *ExplainForStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("EXPLAIN ") - ctx.WriteKeyWord("FORMAT ") - ctx.WritePlain("= ") - ctx.WriteString(n.Format) - ctx.WritePlain(" ") - ctx.WriteKeyWord("FOR ") - ctx.WriteKeyWord("CONNECTION ") - ctx.WritePlain(strconv.FormatUint(n.ConnectionID, 10)) - return nil -} - -// Accept implements Node Accept interface. -func (n *ExplainForStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*ExplainForStmt) - return v.Leave(n) -} - -// ExplainStmt is a statement to provide information about how is SQL statement executed -// or get columns information in a table. 
-// See https://dev.mysql.com/doc/refman/5.7/en/explain.html -type ExplainStmt struct { - stmtNode - - Stmt StmtNode - Format string - Analyze bool -} - -// Restore implements Node interface. -func (n *ExplainStmt) Restore(ctx *format.RestoreCtx) error { - if showStmt, ok := n.Stmt.(*ShowStmt); ok { - ctx.WriteKeyWord("DESC ") - if err := showStmt.Table.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore ExplainStmt.ShowStmt.Table") - } - if showStmt.Column != nil { - ctx.WritePlain(" ") - if err := showStmt.Column.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore ExplainStmt.ShowStmt.Column") - } - } - return nil - } - ctx.WriteKeyWord("EXPLAIN ") - if n.Analyze { - ctx.WriteKeyWord("ANALYZE ") - } - if !n.Analyze || strings.ToLower(n.Format) != "row" { - ctx.WriteKeyWord("FORMAT ") - ctx.WritePlain("= ") - ctx.WriteString(n.Format) - ctx.WritePlain(" ") - } - if err := n.Stmt.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore ExplainStmt.Stmt") - } - return nil -} - -// Accept implements Node Accept interface. -func (n *ExplainStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*ExplainStmt) - node, ok := n.Stmt.Accept(v) - if !ok { - return n, false - } - n.Stmt = node.(StmtNode) - return v.Leave(n) -} - -// PlanReplayerStmt is a statement to dump or load information for recreating plans -type PlanReplayerStmt struct { - stmtNode - - Stmt StmtNode - Analyze bool - Load bool - HistoricalStatsInfo *AsOfClause - - // Capture indicates 'plan replayer capture ' - Capture bool - // Remove indicates `plan replayer capture remove - Remove bool - - SQLDigest string - PlanDigest string - - // File is used to store 2 cases: - // 1. plan replayer load 'file'; - // 2. plan replayer dump explain 'file' - File string - - // Fields below are currently useless. - - // Where is the where clause in select statement. - Where ExprNode - // OrderBy is the ordering expression list. - OrderBy *OrderByClause - // Limit is the limit clause. - Limit *Limit -} - -// Restore implements Node interface. 
-func (n *PlanReplayerStmt) Restore(ctx *format.RestoreCtx) error { - if n.Load { - ctx.WriteKeyWord("PLAN REPLAYER LOAD ") - ctx.WriteString(n.File) - return nil - } - if n.Capture { - ctx.WriteKeyWord("PLAN REPLAYER CAPTURE ") - ctx.WriteString(n.SQLDigest) - ctx.WriteKeyWord(" ") - ctx.WriteString(n.PlanDigest) - return nil - } - if n.Remove { - ctx.WriteKeyWord("PLAN REPLAYER CAPTURE REMOVE ") - ctx.WriteString(n.SQLDigest) - ctx.WriteKeyWord(" ") - ctx.WriteString(n.PlanDigest) - return nil - } - - ctx.WriteKeyWord("PLAN REPLAYER DUMP ") - - if n.HistoricalStatsInfo != nil { - ctx.WriteKeyWord("WITH STATS ") - if err := n.HistoricalStatsInfo.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore PlanReplayerStmt.HistoricalStatsInfo") - } - ctx.WriteKeyWord(" ") - } - if n.Analyze { - ctx.WriteKeyWord("EXPLAIN ANALYZE ") - } else { - ctx.WriteKeyWord("EXPLAIN ") - } - if n.Stmt == nil { - if len(n.File) > 0 { - ctx.WriteString(n.File) - return nil - } - ctx.WriteKeyWord("SLOW QUERY") - if n.Where != nil { - ctx.WriteKeyWord(" WHERE ") - if err := n.Where.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore PlanReplayerStmt.Where") - } - } - if n.OrderBy != nil { - ctx.WriteKeyWord(" ") - if err := n.OrderBy.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore PlanReplayerStmt.OrderBy") - } - } - if n.Limit != nil { - ctx.WriteKeyWord(" ") - if err := n.Limit.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore PlanReplayerStmt.Limit") - } - } - return nil - } - if err := n.Stmt.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore PlanReplayerStmt.Stmt") - } - return nil -} - -// Accept implements Node Accept interface. -func (n *PlanReplayerStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - - n = newNode.(*PlanReplayerStmt) - - if n.Load { - return v.Leave(n) - } - - if n.HistoricalStatsInfo != nil { - info, ok := n.HistoricalStatsInfo.Accept(v) - if !ok { - return n, false - } - n.HistoricalStatsInfo = info.(*AsOfClause) - } - - if n.Stmt == nil { - if n.Where != nil { - node, ok := n.Where.Accept(v) - if !ok { - return n, false - } - n.Where = node.(ExprNode) - } - - if n.OrderBy != nil { - node, ok := n.OrderBy.Accept(v) - if !ok { - return n, false - } - n.OrderBy = node.(*OrderByClause) - } - - if n.Limit != nil { - node, ok := n.Limit.Accept(v) - if !ok { - return n, false - } - n.Limit = node.(*Limit) - } - return v.Leave(n) - } - - node, ok := n.Stmt.Accept(v) - if !ok { - return n, false - } - n.Stmt = node.(StmtNode) - return v.Leave(n) -} - -type CompactReplicaKind string - -const ( - // CompactReplicaKindAll means compacting both TiKV and TiFlash replicas. - CompactReplicaKindAll = "ALL" - - // CompactReplicaKindTiFlash means compacting TiFlash replicas. - CompactReplicaKindTiFlash = "TIFLASH" - - // CompactReplicaKindTiKV means compacting TiKV replicas. - CompactReplicaKindTiKV = "TIKV" -) - -// CompactTableStmt is a statement to manually compact a table. -type CompactTableStmt struct { - stmtNode - - Table *TableName - PartitionNames []model.CIStr - ReplicaKind CompactReplicaKind -} - -// Restore implements Node interface. 
-func (n *CompactTableStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("ALTER TABLE ") - n.Table.restoreName(ctx) - - ctx.WriteKeyWord(" COMPACT") - if len(n.PartitionNames) != 0 { - ctx.WriteKeyWord(" PARTITION ") - for i, partition := range n.PartitionNames { - if i != 0 { - ctx.WritePlain(",") - } - ctx.WriteName(partition.O) - } - } - if n.ReplicaKind != CompactReplicaKindAll { - ctx.WriteKeyWord(" ") - // Note: There is only TiFlash replica available now. TiKV will be added later. - ctx.WriteKeyWord(string(n.ReplicaKind)) - ctx.WriteKeyWord(" REPLICA") - } - return nil -} - -// Accept implements Node Accept interface. -func (n *CompactTableStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*CompactTableStmt) - node, ok := n.Table.Accept(v) - if !ok { - return n, false - } - n.Table = node.(*TableName) - return v.Leave(n) -} - -// PrepareStmt is a statement to prepares a SQL statement which contains placeholders, -// and it is executed with ExecuteStmt and released with DeallocateStmt. -// See https://dev.mysql.com/doc/refman/5.7/en/prepare.html -type PrepareStmt struct { - stmtNode - - Name string - SQLText string - SQLVar *VariableExpr -} - -// Restore implements Node interface. -func (n *PrepareStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("PREPARE ") - ctx.WriteName(n.Name) - ctx.WriteKeyWord(" FROM ") - if n.SQLText != "" { - ctx.WriteString(n.SQLText) - return nil - } - if n.SQLVar != nil { - if err := n.SQLVar.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore PrepareStmt.SQLVar") - } - return nil - } - return errors.New("An error occurred while restore PrepareStmt") -} - -// Accept implements Node Accept interface. -func (n *PrepareStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*PrepareStmt) - if n.SQLVar != nil { - node, ok := n.SQLVar.Accept(v) - if !ok { - return n, false - } - n.SQLVar = node.(*VariableExpr) - } - return v.Leave(n) -} - -// DeallocateStmt is a statement to release PreparedStmt. -// See https://dev.mysql.com/doc/refman/5.7/en/deallocate-prepare.html -type DeallocateStmt struct { - stmtNode - - Name string -} - -// Restore implements Node interface. -func (n *DeallocateStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("DEALLOCATE PREPARE ") - ctx.WriteName(n.Name) - return nil -} - -// Accept implements Node Accept interface. -func (n *DeallocateStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*DeallocateStmt) - return v.Leave(n) -} - -// Prepared represents a prepared statement. -type Prepared struct { - Stmt StmtNode - StmtType string -} - -// ExecuteStmt is a statement to execute PreparedStmt. -// See https://dev.mysql.com/doc/refman/5.7/en/execute.html -type ExecuteStmt struct { - stmtNode - - Name string - UsingVars []ExprNode - BinaryArgs interface{} - PrepStmt interface{} // the corresponding prepared statement - IdxInMulti int - - // FromGeneralStmt indicates whether this execute-stmt is converted from a general query. - // e.g. select * from t where a>2 --> execute 'select * from t where a>?' using 2 - FromGeneralStmt bool -} - -// Restore implements Node interface. 
-func (n *ExecuteStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("EXECUTE ") - ctx.WriteName(n.Name) - if len(n.UsingVars) > 0 { - ctx.WriteKeyWord(" USING ") - for i, val := range n.UsingVars { - if i != 0 { - ctx.WritePlain(",") - } - if err := val.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore ExecuteStmt.UsingVars index %d", i) - } - } - } - return nil -} - -// Accept implements Node Accept interface. -func (n *ExecuteStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*ExecuteStmt) - for i, val := range n.UsingVars { - node, ok := val.Accept(v) - if !ok { - return n, false - } - n.UsingVars[i] = node.(ExprNode) - } - return v.Leave(n) -} - -// BeginStmt is a statement to start a new transaction. -// See https://dev.mysql.com/doc/refman/5.7/en/commit.html -type BeginStmt struct { - stmtNode - Mode string - CausalConsistencyOnly bool - ReadOnly bool - // AS OF is used to read the data at a specific point of time. - // Should only be used when ReadOnly is true. - AsOf *AsOfClause -} - -// Restore implements Node interface. -func (n *BeginStmt) Restore(ctx *format.RestoreCtx) error { - if n.Mode == "" { - if n.ReadOnly { - ctx.WriteKeyWord("START TRANSACTION READ ONLY") - if n.AsOf != nil { - ctx.WriteKeyWord(" ") - return n.AsOf.Restore(ctx) - } - } else if n.CausalConsistencyOnly { - ctx.WriteKeyWord("START TRANSACTION WITH CAUSAL CONSISTENCY ONLY") - } else { - ctx.WriteKeyWord("START TRANSACTION") - } - } else { - ctx.WriteKeyWord("BEGIN ") - ctx.WriteKeyWord(n.Mode) - } - return nil -} - -// Accept implements Node Accept interface. -func (n *BeginStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - - if n.AsOf != nil { - node, ok := n.AsOf.Accept(v) - if !ok { - return n, false - } - n.AsOf = node.(*AsOfClause) - } - - n = newNode.(*BeginStmt) - return v.Leave(n) -} - -// BinlogStmt is an internal-use statement. -// We just parse and ignore it. -// See http://dev.mysql.com/doc/refman/5.7/en/binlog.html -type BinlogStmt struct { - stmtNode - Str string -} - -// Restore implements Node interface. -func (n *BinlogStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("BINLOG ") - ctx.WriteString(n.Str) - return nil -} - -// Accept implements Node Accept interface. -func (n *BinlogStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*BinlogStmt) - return v.Leave(n) -} - -// CompletionType defines completion_type used in COMMIT and ROLLBACK statements -type CompletionType int8 - -const ( - // CompletionTypeDefault refers to NO_CHAIN - CompletionTypeDefault CompletionType = iota - CompletionTypeChain - CompletionTypeRelease -) - -func (n CompletionType) Restore(ctx *format.RestoreCtx) error { - switch n { - case CompletionTypeDefault: - case CompletionTypeChain: - ctx.WriteKeyWord(" AND CHAIN") - case CompletionTypeRelease: - ctx.WriteKeyWord(" RELEASE") - } - return nil -} - -// CommitStmt is a statement to commit the current transaction. -// See https://dev.mysql.com/doc/refman/5.7/en/commit.html -type CommitStmt struct { - stmtNode - // CompletionType overwrites system variable `completion_type` within transaction - CompletionType CompletionType -} - -// Restore implements Node interface. 
-func (n *CommitStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("COMMIT") - if err := n.CompletionType.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore CommitStmt.CompletionType") - } - return nil -} - -// Accept implements Node Accept interface. -func (n *CommitStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*CommitStmt) - return v.Leave(n) -} - -// RollbackStmt is a statement to roll back the current transaction. -// See https://dev.mysql.com/doc/refman/5.7/en/commit.html -type RollbackStmt struct { - stmtNode - // CompletionType overwrites system variable `completion_type` within transaction - CompletionType CompletionType - // SavepointName is the savepoint name. - SavepointName string -} - -// Restore implements Node interface. -func (n *RollbackStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("ROLLBACK") - if n.SavepointName != "" { - ctx.WritePlain(" TO ") - ctx.WritePlain(n.SavepointName) - } - if err := n.CompletionType.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore RollbackStmt.CompletionType") - } - return nil -} - -// Accept implements Node Accept interface. -func (n *RollbackStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*RollbackStmt) - return v.Leave(n) -} - -// UseStmt is a statement to use the DBName database as the current database. -// See https://dev.mysql.com/doc/refman/5.7/en/use.html -type UseStmt struct { - stmtNode - - DBName string -} - -// Restore implements Node interface. -func (n *UseStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("USE ") - ctx.WriteName(n.DBName) - return nil -} - -// Accept implements Node Accept interface. -func (n *UseStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*UseStmt) - return v.Leave(n) -} - -const ( - // SetNames is the const for set names stmt. - // If VariableAssignment.Name == Names, it should be set names stmt. - SetNames = "SetNAMES" - // SetCharset is the const for set charset stmt. - SetCharset = "SetCharset" - // TiDBCloudStorageURI is the const for set tidb_cloud_storage_uri stmt. - TiDBCloudStorageURI = "tidb_cloud_storage_uri" -) - -// VariableAssignment is a variable assignment struct. -type VariableAssignment struct { - node - Name string - Value ExprNode - IsGlobal bool - IsSystem bool - - // ExtendValue is a way to store extended info. - // VariableAssignment should be able to store information for SetCharset/SetPWD Stmt. - // For SetCharsetStmt, Value is charset, ExtendValue is collation. - // TODO: Use SetStmt to implement set password statement. - ExtendValue ValueExpr -} - -// Restore implements Node interface. 
-func (n *VariableAssignment) Restore(ctx *format.RestoreCtx) error { - if n.IsSystem { - ctx.WritePlain("@@") - if n.IsGlobal { - ctx.WriteKeyWord("GLOBAL") - } else { - ctx.WriteKeyWord("SESSION") - } - ctx.WritePlain(".") - } else if n.Name != SetNames && n.Name != SetCharset { - ctx.WriteKeyWord("@") - } - if n.Name == SetNames { - ctx.WriteKeyWord("NAMES ") - } else if n.Name == SetCharset { - ctx.WriteKeyWord("CHARSET ") - } else { - ctx.WriteName(n.Name) - ctx.WritePlain("=") - } - if n.Name == TiDBCloudStorageURI { - // need to redact the url for safety when `show processlist;` - ctx.WritePlain(RedactURL(n.Value.(ValueExpr).GetString())) - } else if err := n.Value.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore VariableAssignment.Value") - } - if n.ExtendValue != nil { - ctx.WriteKeyWord(" COLLATE ") - if err := n.ExtendValue.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore VariableAssignment.ExtendValue") - } - } - return nil -} - -// Accept implements Node interface. -func (n *VariableAssignment) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*VariableAssignment) - node, ok := n.Value.Accept(v) - if !ok { - return n, false - } - n.Value = node.(ExprNode) - return v.Leave(n) -} - -// FlushStmtType is the type for FLUSH statement. -type FlushStmtType int - -// Flush statement types. -const ( - FlushNone FlushStmtType = iota - FlushTables - FlushPrivileges - FlushStatus - FlushTiDBPlugin - FlushHosts - FlushLogs - FlushClientErrorsSummary -) - -// LogType is the log type used in FLUSH statement. -type LogType int8 - -const ( - LogTypeDefault LogType = iota - LogTypeBinary - LogTypeEngine - LogTypeError - LogTypeGeneral - LogTypeSlow -) - -// FlushStmt is a statement to flush tables/privileges/optimizer costs and so on. -type FlushStmt struct { - stmtNode - - Tp FlushStmtType // Privileges/Tables/... - NoWriteToBinLog bool - LogType LogType - Tables []*TableName // For FlushTableStmt, if Tables is empty, it means flush all tables. - ReadLock bool - Plugins []string -} - -// Restore implements Node interface. 
-func (n *FlushStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("FLUSH ") - if n.NoWriteToBinLog { - ctx.WriteKeyWord("NO_WRITE_TO_BINLOG ") - } - switch n.Tp { - case FlushTables: - ctx.WriteKeyWord("TABLES") - for i, v := range n.Tables { - if i == 0 { - ctx.WritePlain(" ") - } else { - ctx.WritePlain(", ") - } - if err := v.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore FlushStmt.Tables[%d]", i) - } - } - if n.ReadLock { - ctx.WriteKeyWord(" WITH READ LOCK") - } - case FlushPrivileges: - ctx.WriteKeyWord("PRIVILEGES") - case FlushStatus: - ctx.WriteKeyWord("STATUS") - case FlushTiDBPlugin: - ctx.WriteKeyWord("TIDB PLUGINS") - for i, v := range n.Plugins { - if i == 0 { - ctx.WritePlain(" ") - } else { - ctx.WritePlain(", ") - } - ctx.WritePlain(v) - } - case FlushHosts: - ctx.WriteKeyWord("HOSTS") - case FlushLogs: - var logType string - switch n.LogType { - case LogTypeDefault: - logType = "LOGS" - case LogTypeBinary: - logType = "BINARY LOGS" - case LogTypeEngine: - logType = "ENGINE LOGS" - case LogTypeError: - logType = "ERROR LOGS" - case LogTypeGeneral: - logType = "GENERAL LOGS" - case LogTypeSlow: - logType = "SLOW LOGS" - } - ctx.WriteKeyWord(logType) - case FlushClientErrorsSummary: - ctx.WriteKeyWord("CLIENT_ERRORS_SUMMARY") - default: - return errors.New("Unsupported type of FlushStmt") - } - return nil -} - -// Accept implements Node Accept interface. -func (n *FlushStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*FlushStmt) - for i, t := range n.Tables { - node, ok := t.Accept(v) - if !ok { - return n, false - } - n.Tables[i] = node.(*TableName) - } - return v.Leave(n) -} - -// KillStmt is a statement to kill a query or connection. -type KillStmt struct { - stmtNode - - // Query indicates whether terminate a single query on this connection or the whole connection. - // If Query is true, terminates the statement the connection is currently executing, but leaves the connection itself intact. - // If Query is false, terminates the connection associated with the given ConnectionID, after terminating any statement the connection is executing. - Query bool - ConnectionID uint64 - // TiDBExtension is used to indicate whether the user knows he is sending kill statement to the right tidb-server. - // When the SQL grammar is "KILL TIDB [CONNECTION | QUERY] connectionID", TiDBExtension will be set. - // It's a special grammar extension in TiDB. This extension exists because, when the connection is: - // client -> LVS proxy -> TiDB, and type Ctrl+C in client, the following action will be executed: - // new a connection; kill xxx; - // kill command may send to the wrong TiDB, because the exists of LVS proxy, and kill the wrong session. - // So, "KILL TIDB" grammar is introduced, and it REQUIRES DIRECT client -> TiDB TOPOLOGY. - // TODO: The standard KILL grammar will be supported once we have global connectionID. - TiDBExtension bool - - Expr ExprNode -} - -// Restore implements Node interface. -func (n *KillStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("KILL") - if n.TiDBExtension { - ctx.WriteKeyWord(" TIDB") - } - if n.Query { - ctx.WriteKeyWord(" QUERY") - } - if n.Expr != nil { - ctx.WriteKeyWord(" ") - if err := n.Expr.Restore(ctx); err != nil { - return errors.Trace(err) - } - } else { - ctx.WritePlainf(" %d", n.ConnectionID) - } - return nil -} - -// Accept implements Node Accept interface. 
-func (n *KillStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*KillStmt) - return v.Leave(n) -} - -// SavepointStmt is the statement of SAVEPOINT. -type SavepointStmt struct { - stmtNode - // Name is the savepoint name. - Name string -} - -// Restore implements Node interface. -func (n *SavepointStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("SAVEPOINT ") - ctx.WritePlain(n.Name) - return nil -} - -// Accept implements Node Accept interface. -func (n *SavepointStmt) Accept(v Visitor) (Node, bool) { - newNode, _ := v.Enter(n) - n = newNode.(*SavepointStmt) - return v.Leave(n) -} - -// ReleaseSavepointStmt is the statement of RELEASE SAVEPOINT. -type ReleaseSavepointStmt struct { - stmtNode - // Name is the savepoint name. - Name string -} - -// Restore implements Node interface. -func (n *ReleaseSavepointStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("RELEASE SAVEPOINT ") - ctx.WritePlain(n.Name) - return nil -} - -// Accept implements Node Accept interface. -func (n *ReleaseSavepointStmt) Accept(v Visitor) (Node, bool) { - newNode, _ := v.Enter(n) - n = newNode.(*ReleaseSavepointStmt) - return v.Leave(n) -} - -// SetStmt is the statement to set variables. -type SetStmt struct { - stmtNode - // Variables is the list of variable assignment. - Variables []*VariableAssignment -} - -// Restore implements Node interface. -func (n *SetStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("SET ") - for i, v := range n.Variables { - if i != 0 { - ctx.WritePlain(", ") - } - if err := v.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore SetStmt.Variables[%d]", i) - } - } - return nil -} - -// Accept implements Node Accept interface. -func (n *SetStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*SetStmt) - for i, val := range n.Variables { - node, ok := val.Accept(v) - if !ok { - return n, false - } - n.Variables[i] = node.(*VariableAssignment) - } - return v.Leave(n) -} - -// SecureText implements SensitiveStatement interface. -// need to redact the tidb_cloud_storage_url for safety when `show processlist;` -func (n *SetStmt) SecureText() string { - redactedStmt := *n - var sb strings.Builder - _ = redactedStmt.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)) - return sb.String() -} - -// SetConfigStmt is the statement to set cluster configs. -type SetConfigStmt struct { - stmtNode - - Type string // TiDB, TiKV, PD - Instance string // '127.0.0.1:3306' - Name string // the variable name - Value ExprNode -} - -func (n *SetConfigStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("SET CONFIG ") - if n.Type != "" { - ctx.WriteKeyWord(n.Type) - } else { - ctx.WriteString(n.Instance) - } - ctx.WritePlain(" ") - ctx.WriteKeyWord(n.Name) - ctx.WritePlain(" = ") - return n.Value.Restore(ctx) -} - -func (n *SetConfigStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*SetConfigStmt) - node, ok := n.Value.Accept(v) - if !ok { - return n, false - } - n.Value = node.(ExprNode) - return v.Leave(n) -} - -// SetSessionStatesStmt is a statement to restore session states. 
-type SetSessionStatesStmt struct { - stmtNode - - SessionStates string -} - -func (n *SetSessionStatesStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("SET SESSION_STATES ") - ctx.WriteString(n.SessionStates) - return nil -} - -func (n *SetSessionStatesStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*SetSessionStatesStmt) - return v.Leave(n) -} - -/* -// SetCharsetStmt is a statement to assign values to character and collation variables. -// See https://dev.mysql.com/doc/refman/5.7/en/set-statement.html -type SetCharsetStmt struct { - stmtNode - - Charset string - Collate string -} - -// Accept implements Node Accept interface. -func (n *SetCharsetStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*SetCharsetStmt) - return v.Leave(n) -} -*/ - -// SetPwdStmt is a statement to assign a password to user account. -// See https://dev.mysql.com/doc/refman/5.7/en/set-password.html -type SetPwdStmt struct { - stmtNode - - User *auth.UserIdentity - Password string -} - -// Restore implements Node interface. -func (n *SetPwdStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("SET PASSWORD") - if n.User != nil { - ctx.WriteKeyWord(" FOR ") - if err := n.User.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore SetPwdStmt.User") - } - } - ctx.WritePlain("=") - ctx.WriteString(n.Password) - return nil -} - -// SecureText implements SensitiveStatement interface. -func (n *SetPwdStmt) SecureText() string { - return fmt.Sprintf("set password for user %s", n.User) -} - -// Accept implements Node Accept interface. -func (n *SetPwdStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*SetPwdStmt) - return v.Leave(n) -} - -type ChangeStmt struct { - stmtNode - - NodeType string - State string - NodeID string -} - -// Restore implements Node interface. -func (n *ChangeStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("CHANGE ") - ctx.WriteKeyWord(n.NodeType) - ctx.WriteKeyWord(" TO NODE_STATE ") - ctx.WritePlain("=") - ctx.WriteString(n.State) - ctx.WriteKeyWord(" FOR NODE_ID ") - ctx.WriteString(n.NodeID) - return nil -} - -// SecureText implements SensitiveStatement interface. -func (n *ChangeStmt) SecureText() string { - return fmt.Sprintf("change %s to node_state='%s' for node_id '%s'", strings.ToLower(n.NodeType), n.State, n.NodeID) -} - -// Accept implements Node Accept interface. -func (n *ChangeStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*ChangeStmt) - return v.Leave(n) -} - -// SetRoleStmtType is the type for FLUSH statement. -type SetRoleStmtType int - -// SetRole statement types. 
-const ( - SetRoleDefault SetRoleStmtType = iota - SetRoleNone - SetRoleAll - SetRoleAllExcept - SetRoleRegular -) - -type SetRoleStmt struct { - stmtNode - - SetRoleOpt SetRoleStmtType - RoleList []*auth.RoleIdentity -} - -func (n *SetRoleStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("SET ROLE") - switch n.SetRoleOpt { - case SetRoleDefault: - ctx.WriteKeyWord(" DEFAULT") - case SetRoleNone: - ctx.WriteKeyWord(" NONE") - case SetRoleAll: - ctx.WriteKeyWord(" ALL") - case SetRoleAllExcept: - ctx.WriteKeyWord(" ALL EXCEPT") - } - for i, role := range n.RoleList { - ctx.WritePlain(" ") - err := role.Restore(ctx) - if err != nil { - return errors.Annotate(err, "An error occurred while restore SetRoleStmt.RoleList") - } - if i != len(n.RoleList)-1 { - ctx.WritePlain(",") - } - } - return nil -} - -// Accept implements Node Accept interface. -func (n *SetRoleStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*SetRoleStmt) - return v.Leave(n) -} - -type SetDefaultRoleStmt struct { - stmtNode - - SetRoleOpt SetRoleStmtType - RoleList []*auth.RoleIdentity - UserList []*auth.UserIdentity -} - -func (n *SetDefaultRoleStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("SET DEFAULT ROLE") - switch n.SetRoleOpt { - case SetRoleNone: - ctx.WriteKeyWord(" NONE") - case SetRoleAll: - ctx.WriteKeyWord(" ALL") - default: - } - for i, role := range n.RoleList { - ctx.WritePlain(" ") - err := role.Restore(ctx) - if err != nil { - return errors.Annotate(err, "An error occurred while restore SetDefaultRoleStmt.RoleList") - } - if i != len(n.RoleList)-1 { - ctx.WritePlain(",") - } - } - ctx.WritePlain(" TO") - for i, user := range n.UserList { - ctx.WritePlain(" ") - err := user.Restore(ctx) - if err != nil { - return errors.Annotate(err, "An error occurred while restore SetDefaultRoleStmt.UserList") - } - if i != len(n.UserList)-1 { - ctx.WritePlain(",") - } - } - return nil -} - -// Accept implements Node Accept interface. -func (n *SetDefaultRoleStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*SetDefaultRoleStmt) - return v.Leave(n) -} - -// UserSpec is used for parsing create user statement. -type UserSpec struct { - User *auth.UserIdentity - AuthOpt *AuthOption - IsRole bool -} - -// Restore implements Node interface. -func (n *UserSpec) Restore(ctx *format.RestoreCtx) error { - if err := n.User.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore UserSpec.User") - } - if n.AuthOpt != nil { - ctx.WritePlain(" ") - if err := n.AuthOpt.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore UserSpec.AuthOpt") - } - } - return nil -} - -// SecurityString formats the UserSpec without password information. -func (n *UserSpec) SecurityString() string { - withPassword := false - if opt := n.AuthOpt; opt != nil { - if len(opt.AuthString) > 0 || len(opt.HashString) > 0 { - withPassword = true - } - } - if withPassword { - return fmt.Sprintf("{%s password = ***}", n.User) - } - return n.User.String() -} - -// EncodedPassword returns the encoded password (which is the real data mysql.user). -// The boolean value indicates input's password format is legal or not. 
-func (n *UserSpec) EncodedPassword() (string, bool) { - if n.AuthOpt == nil { - return "", true - } - - opt := n.AuthOpt - if opt.ByAuthString { - switch opt.AuthPlugin { - case mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password: - return auth.NewHashPassword(opt.AuthString, opt.AuthPlugin), true - case mysql.AuthSocket: - return "", true - default: - return auth.EncodePassword(opt.AuthString), true - } - } - - // store the LDAP dn directly in the password field - switch opt.AuthPlugin { - case mysql.AuthLDAPSimple, mysql.AuthLDAPSASL: - // TODO: validate the HashString to be a `dn` for LDAP - // It seems fine to not validate here, and LDAP server will give an error when the client'll try to login this user. - // The percona server implementation doesn't have a validation for this HashString. - // However, returning an error for obvious wrong format is more friendly. - return opt.HashString, true - } - - // In case we have 'IDENTIFIED WITH ' but no 'BY ' to set an empty password. - if opt.HashString == "" { - return opt.HashString, true - } - - // Not a legal password string. - switch opt.AuthPlugin { - case mysql.AuthCachingSha2Password: - if len(opt.HashString) != mysql.SHAPWDHashLen { - return "", false - } - case mysql.AuthTiDBSM3Password: - if len(opt.HashString) != mysql.SM3PWDHashLen { - return "", false - } - case "", mysql.AuthNativePassword: - if len(opt.HashString) != (mysql.PWDHashLen+1) || !strings.HasPrefix(opt.HashString, "*") { - return "", false - } - case mysql.AuthSocket: - default: - return "", false - } - return opt.HashString, true -} - -type AuthTokenOrTLSOption struct { - Type AuthTokenOrTLSOptionType - Value string -} - -func (t *AuthTokenOrTLSOption) Restore(ctx *format.RestoreCtx) error { - switch t.Type { - case TlsNone: - ctx.WriteKeyWord("NONE") - case Ssl: - ctx.WriteKeyWord("SSL") - case X509: - ctx.WriteKeyWord("X509") - case Cipher: - ctx.WriteKeyWord("CIPHER ") - ctx.WriteString(t.Value) - case Issuer: - ctx.WriteKeyWord("ISSUER ") - ctx.WriteString(t.Value) - case Subject: - ctx.WriteKeyWord("SUBJECT ") - ctx.WriteString(t.Value) - case SAN: - ctx.WriteKeyWord("SAN ") - ctx.WriteString(t.Value) - case TokenIssuer: - ctx.WriteKeyWord("TOKEN_ISSUER ") - ctx.WriteString(t.Value) - default: - return errors.Errorf("Unsupported AuthTokenOrTLSOption.Type %d", t.Type) - } - return nil -} - -type AuthTokenOrTLSOptionType int - -const ( - TlsNone AuthTokenOrTLSOptionType = iota - Ssl - X509 - Cipher - Issuer - Subject - SAN - TokenIssuer -) - -func (t AuthTokenOrTLSOptionType) String() string { - switch t { - case TlsNone: - return "NONE" - case Ssl: - return "SSL" - case X509: - return "X509" - case Cipher: - return "CIPHER" - case Issuer: - return "ISSUER" - case Subject: - return "SUBJECT" - case SAN: - return "SAN" - case TokenIssuer: - return "TOKEN_ISSUER" - default: - return "UNKNOWN" - } -} - -const ( - MaxQueriesPerHour = iota + 1 - MaxUpdatesPerHour - MaxConnectionsPerHour - MaxUserConnections -) - -type ResourceOption struct { - Type int - Count int64 -} - -func (r *ResourceOption) Restore(ctx *format.RestoreCtx) error { - switch r.Type { - case MaxQueriesPerHour: - ctx.WriteKeyWord("MAX_QUERIES_PER_HOUR ") - case MaxUpdatesPerHour: - ctx.WriteKeyWord("MAX_UPDATES_PER_HOUR ") - case MaxConnectionsPerHour: - ctx.WriteKeyWord("MAX_CONNECTIONS_PER_HOUR ") - case MaxUserConnections: - ctx.WriteKeyWord("MAX_USER_CONNECTIONS ") - default: - return errors.Errorf("Unsupported ResourceOption.Type %d", r.Type) - } - ctx.WritePlainf("%d", r.Count) - return 
nil -} - -const ( - PasswordExpire = iota + 1 - PasswordExpireDefault - PasswordExpireNever - PasswordExpireInterval - PasswordHistory - PasswordHistoryDefault - PasswordReuseInterval - PasswordReuseDefault - Lock - Unlock - FailedLoginAttempts - PasswordLockTime - PasswordLockTimeUnbounded - UserCommentType - UserAttributeType - PasswordRequireCurrentDefault - - UserResourceGroupName -) - -type PasswordOrLockOption struct { - Type int - Count int64 -} - -func (p *PasswordOrLockOption) Restore(ctx *format.RestoreCtx) error { - switch p.Type { - case PasswordExpire: - ctx.WriteKeyWord("PASSWORD EXPIRE") - case PasswordExpireDefault: - ctx.WriteKeyWord("PASSWORD EXPIRE DEFAULT") - case PasswordExpireNever: - ctx.WriteKeyWord("PASSWORD EXPIRE NEVER") - case PasswordExpireInterval: - ctx.WriteKeyWord("PASSWORD EXPIRE INTERVAL") - ctx.WritePlainf(" %d", p.Count) - ctx.WriteKeyWord(" DAY") - case Lock: - ctx.WriteKeyWord("ACCOUNT LOCK") - case Unlock: - ctx.WriteKeyWord("ACCOUNT UNLOCK") - case FailedLoginAttempts: - ctx.WriteKeyWord("FAILED_LOGIN_ATTEMPTS") - ctx.WritePlainf(" %d", p.Count) - case PasswordLockTime: - ctx.WriteKeyWord("PASSWORD_LOCK_TIME") - ctx.WritePlainf(" %d", p.Count) - case PasswordLockTimeUnbounded: - ctx.WriteKeyWord("PASSWORD_LOCK_TIME UNBOUNDED") - case PasswordHistory: - ctx.WriteKeyWord("PASSWORD HISTORY") - ctx.WritePlainf(" %d", p.Count) - case PasswordHistoryDefault: - ctx.WriteKeyWord("PASSWORD HISTORY DEFAULT") - case PasswordReuseInterval: - ctx.WriteKeyWord("PASSWORD REUSE INTERVAL") - ctx.WritePlainf(" %d", p.Count) - ctx.WriteKeyWord(" DAY") - case PasswordReuseDefault: - ctx.WriteKeyWord("PASSWORD REUSE INTERVAL DEFAULT") - default: - return errors.Errorf("Unsupported PasswordOrLockOption.Type %d", p.Type) - } - return nil -} - -type CommentOrAttributeOption struct { - Type int - Value string -} - -func (c *CommentOrAttributeOption) Restore(ctx *format.RestoreCtx) error { - if c.Type == UserCommentType { - ctx.WriteKeyWord(" COMMENT ") - ctx.WriteString(c.Value) - } else if c.Type == UserAttributeType { - ctx.WriteKeyWord(" ATTRIBUTE ") - ctx.WriteString(c.Value) - } - return nil -} - -type ResourceGroupNameOption struct { - Value string -} - -func (c *ResourceGroupNameOption) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord(" RESOURCE GROUP ") - ctx.WriteName(c.Value) - return nil -} - -// CreateUserStmt creates user account. -// See https://dev.mysql.com/doc/refman/8.0/en/create-user.html -type CreateUserStmt struct { - stmtNode - - IsCreateRole bool - IfNotExists bool - Specs []*UserSpec - AuthTokenOrTLSOptions []*AuthTokenOrTLSOption - ResourceOptions []*ResourceOption - PasswordOrLockOptions []*PasswordOrLockOption - CommentOrAttributeOption *CommentOrAttributeOption - ResourceGroupNameOption *ResourceGroupNameOption -} - -// Restore implements Node interface. 
-func (n *CreateUserStmt) Restore(ctx *format.RestoreCtx) error { - if n.IsCreateRole { - ctx.WriteKeyWord("CREATE ROLE ") - } else { - ctx.WriteKeyWord("CREATE USER ") - } - if n.IfNotExists { - ctx.WriteKeyWord("IF NOT EXISTS ") - } - for i, v := range n.Specs { - if i != 0 { - ctx.WritePlain(", ") - } - if err := v.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore CreateUserStmt.Specs[%d]", i) - } - } - - if len(n.AuthTokenOrTLSOptions) != 0 { - ctx.WriteKeyWord(" REQUIRE ") - } - - for i, option := range n.AuthTokenOrTLSOptions { - if i != 0 { - ctx.WriteKeyWord(" AND ") - } - if err := option.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore CreateUserStmt.AuthTokenOrTLSOptions[%d]", i) - } - } - - if len(n.ResourceOptions) != 0 { - ctx.WriteKeyWord(" WITH") - } - - for i, v := range n.ResourceOptions { - ctx.WritePlain(" ") - if err := v.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore CreateUserStmt.ResourceOptions[%d]", i) - } - } - - for i, v := range n.PasswordOrLockOptions { - ctx.WritePlain(" ") - if err := v.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore CreateUserStmt.PasswordOrLockOptions[%d]", i) - } - } - - if n.CommentOrAttributeOption != nil { - if err := n.CommentOrAttributeOption.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore CreateUserStmt.CommentOrAttributeOption") - } - } - - if n.ResourceGroupNameOption != nil { - if err := n.ResourceGroupNameOption.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore CreateUserStmt.ResourceGroupNameOption") - } - } - - return nil -} - -// Accept implements Node Accept interface. -func (n *CreateUserStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*CreateUserStmt) - return v.Leave(n) -} - -// SecureText implements SensitiveStatement interface. -func (n *CreateUserStmt) SecureText() string { - var buf bytes.Buffer - buf.WriteString("create user") - for _, user := range n.Specs { - buf.WriteString(" ") - buf.WriteString(user.SecurityString()) - } - return buf.String() -} - -// AlterUserStmt modifies user account. -// See https://dev.mysql.com/doc/refman/8.0/en/alter-user.html -type AlterUserStmt struct { - stmtNode - - IfExists bool - CurrentAuth *AuthOption - Specs []*UserSpec - AuthTokenOrTLSOptions []*AuthTokenOrTLSOption - ResourceOptions []*ResourceOption - PasswordOrLockOptions []*PasswordOrLockOption - CommentOrAttributeOption *CommentOrAttributeOption - ResourceGroupNameOption *ResourceGroupNameOption -} - -// Restore implements Node interface. 
-func (n *AlterUserStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("ALTER USER ") - if n.IfExists { - ctx.WriteKeyWord("IF EXISTS ") - } - if n.CurrentAuth != nil { - ctx.WriteKeyWord("USER") - ctx.WritePlain("() ") - if err := n.CurrentAuth.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore AlterUserStmt.CurrentAuth") - } - } - for i, v := range n.Specs { - if i != 0 { - ctx.WritePlain(", ") - } - if err := v.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore AlterUserStmt.Specs[%d]", i) - } - } - - if len(n.AuthTokenOrTLSOptions) != 0 { - ctx.WriteKeyWord(" REQUIRE ") - } - - for i, option := range n.AuthTokenOrTLSOptions { - if i != 0 { - ctx.WriteKeyWord(" AND ") - } - if err := option.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore AlterUserStmt.AuthTokenOrTLSOptions[%d]", i) - } - } - - if len(n.ResourceOptions) != 0 { - ctx.WriteKeyWord(" WITH") - } - - for i, v := range n.ResourceOptions { - ctx.WritePlain(" ") - if err := v.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore AlterUserStmt.ResourceOptions[%d]", i) - } - } - - for i, v := range n.PasswordOrLockOptions { - ctx.WritePlain(" ") - if err := v.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore AlterUserStmt.PasswordOrLockOptions[%d]", i) - } - } - - if n.CommentOrAttributeOption != nil { - if err := n.CommentOrAttributeOption.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore AlterUserStmt.CommentOrAttributeOption") - } - } - - if n.ResourceGroupNameOption != nil { - if err := n.ResourceGroupNameOption.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore AlterUserStmt.ResourceGroupNameOption") - } - } - - return nil -} - -// SecureText implements SensitiveStatement interface. -func (n *AlterUserStmt) SecureText() string { - var buf bytes.Buffer - buf.WriteString("alter user") - for _, user := range n.Specs { - buf.WriteString(" ") - buf.WriteString(user.SecurityString()) - } - return buf.String() -} - -// Accept implements Node Accept interface. -func (n *AlterUserStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*AlterUserStmt) - return v.Leave(n) -} - -// AlterInstanceStmt modifies instance. -// See https://dev.mysql.com/doc/refman/8.0/en/alter-instance.html -type AlterInstanceStmt struct { - stmtNode - - ReloadTLS bool - NoRollbackOnError bool -} - -// Restore implements Node interface. -func (n *AlterInstanceStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("ALTER INSTANCE") - if n.ReloadTLS { - ctx.WriteKeyWord(" RELOAD TLS") - } - if n.NoRollbackOnError { - ctx.WriteKeyWord(" NO ROLLBACK ON ERROR") - } - return nil -} - -// Accept implements Node Accept interface. -func (n *AlterInstanceStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*AlterInstanceStmt) - return v.Leave(n) -} - -// AlterRangeStmt modifies range configuration. -type AlterRangeStmt struct { - stmtNode - RangeName model.CIStr - PlacementOption *PlacementOption -} - -// Restore implements Node interface. 
-func (n *AlterRangeStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("ALTER RANGE ") - ctx.WriteName(n.RangeName.O) - ctx.WritePlain(" ") - if err := n.PlacementOption.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore AlterRangeStmt.PlacementOption") - } - return nil -} - -// Accept implements Node Accept interface. -func (n *AlterRangeStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*AlterRangeStmt) - return v.Leave(n) -} - -// DropUserStmt creates user account. -// See http://dev.mysql.com/doc/refman/5.7/en/drop-user.html -type DropUserStmt struct { - stmtNode - - IfExists bool - IsDropRole bool - UserList []*auth.UserIdentity -} - -// Restore implements Node interface. -func (n *DropUserStmt) Restore(ctx *format.RestoreCtx) error { - if n.IsDropRole { - ctx.WriteKeyWord("DROP ROLE ") - } else { - ctx.WriteKeyWord("DROP USER ") - } - if n.IfExists { - ctx.WriteKeyWord("IF EXISTS ") - } - for i, v := range n.UserList { - if i != 0 { - ctx.WritePlain(", ") - } - if err := v.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore DropUserStmt.UserList[%d]", i) - } - } - return nil -} - -// Accept implements Node Accept interface. -func (n *DropUserStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*DropUserStmt) - return v.Leave(n) -} - -// CreateBindingStmt creates sql binding hint. -type CreateBindingStmt struct { - stmtNode - - GlobalScope bool - OriginNode StmtNode - HintedNode StmtNode - PlanDigest string -} - -func (n *CreateBindingStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("CREATE ") - if n.GlobalScope { - ctx.WriteKeyWord("GLOBAL ") - } else { - ctx.WriteKeyWord("SESSION ") - } - if n.OriginNode == nil { - ctx.WriteKeyWord("BINDING FROM HISTORY USING PLAN DIGEST ") - ctx.WriteString(n.PlanDigest) - } else { - ctx.WriteKeyWord("BINDING FOR ") - if err := n.OriginNode.Restore(ctx); err != nil { - return errors.Trace(err) - } - ctx.WriteKeyWord(" USING ") - if err := n.HintedNode.Restore(ctx); err != nil { - return errors.Trace(err) - } - } - return nil -} - -func (n *CreateBindingStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*CreateBindingStmt) - if n.OriginNode != nil { - origNode, ok := n.OriginNode.Accept(v) - if !ok { - return n, false - } - n.OriginNode = origNode.(StmtNode) - hintedNode, ok := n.HintedNode.Accept(v) - if !ok { - return n, false - } - n.HintedNode = hintedNode.(StmtNode) - } - return v.Leave(n) -} - -// DropBindingStmt deletes sql binding hint. 
-type DropBindingStmt struct { - stmtNode - - GlobalScope bool - OriginNode StmtNode - HintedNode StmtNode - SQLDigest string -} - -func (n *DropBindingStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("DROP ") - if n.GlobalScope { - ctx.WriteKeyWord("GLOBAL ") - } else { - ctx.WriteKeyWord("SESSION ") - } - ctx.WriteKeyWord("BINDING FOR ") - if n.OriginNode == nil { - ctx.WriteKeyWord("SQL DIGEST ") - ctx.WriteString(n.SQLDigest) - } else { - if err := n.OriginNode.Restore(ctx); err != nil { - return errors.Trace(err) - } - if n.HintedNode != nil { - ctx.WriteKeyWord(" USING ") - if err := n.HintedNode.Restore(ctx); err != nil { - return errors.Trace(err) - } - } - } - return nil -} - -func (n *DropBindingStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*DropBindingStmt) - if n.OriginNode != nil { - // OriginNode is nil means we build drop binding by sql digest - origNode, ok := n.OriginNode.Accept(v) - if !ok { - return n, false - } - n.OriginNode = origNode.(StmtNode) - if n.HintedNode != nil { - hintedNode, ok := n.HintedNode.Accept(v) - if !ok { - return n, false - } - n.HintedNode = hintedNode.(StmtNode) - } - } - return v.Leave(n) -} - -// BindingStatusType defines the status type for the binding -type BindingStatusType int8 - -// Binding status types. -const ( - BindingStatusTypeEnabled BindingStatusType = iota - BindingStatusTypeDisabled -) - -// SetBindingStmt sets sql binding status. -type SetBindingStmt struct { - stmtNode - - BindingStatusType BindingStatusType - OriginNode StmtNode - HintedNode StmtNode - SQLDigest string -} - -func (n *SetBindingStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("SET ") - ctx.WriteKeyWord("BINDING ") - switch n.BindingStatusType { - case BindingStatusTypeEnabled: - ctx.WriteKeyWord("ENABLED ") - case BindingStatusTypeDisabled: - ctx.WriteKeyWord("DISABLED ") - } - ctx.WriteKeyWord("FOR ") - if n.OriginNode == nil { - ctx.WriteKeyWord("SQL DIGEST ") - ctx.WriteString(n.SQLDigest) - } else { - if err := n.OriginNode.Restore(ctx); err != nil { - return errors.Trace(err) - } - if n.HintedNode != nil { - ctx.WriteKeyWord(" USING ") - if err := n.HintedNode.Restore(ctx); err != nil { - return errors.Trace(err) - } - } - } - return nil -} - -func (n *SetBindingStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*SetBindingStmt) - if n.OriginNode != nil { - // OriginNode is nil means we set binding stmt by sql digest - origNode, ok := n.OriginNode.Accept(v) - if !ok { - return n, false - } - n.OriginNode = origNode.(StmtNode) - if n.HintedNode != nil { - hintedNode, ok := n.HintedNode.Accept(v) - if !ok { - return n, false - } - n.HintedNode = hintedNode.(StmtNode) - } - } - return v.Leave(n) -} - -// Extended statistics types. -const ( - StatsTypeCardinality uint8 = iota - StatsTypeDependency - StatsTypeCorrelation -) - -// StatisticsSpec is the specification for ADD /DROP STATISTICS. -type StatisticsSpec struct { - StatsName string - StatsType uint8 - Columns []*ColumnName -} - -// CreateStatisticsStmt is a statement to create extended statistics. 
-// Examples: -// -// CREATE STATISTICS stats1 (cardinality) ON t(a, b, c); -// CREATE STATISTICS stats2 (dependency) ON t(a, b); -// CREATE STATISTICS stats3 (correlation) ON t(a, b); -type CreateStatisticsStmt struct { - stmtNode - - IfNotExists bool - StatsName string - StatsType uint8 - Table *TableName - Columns []*ColumnName -} - -// Restore implements Node interface. -func (n *CreateStatisticsStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("CREATE STATISTICS ") - if n.IfNotExists { - ctx.WriteKeyWord("IF NOT EXISTS ") - } - ctx.WriteName(n.StatsName) - switch n.StatsType { - case StatsTypeCardinality: - ctx.WriteKeyWord(" (cardinality) ") - case StatsTypeDependency: - ctx.WriteKeyWord(" (dependency) ") - case StatsTypeCorrelation: - ctx.WriteKeyWord(" (correlation) ") - } - ctx.WriteKeyWord("ON ") - if err := n.Table.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore CreateStatisticsStmt.Table") - } - - ctx.WritePlain("(") - for i, col := range n.Columns { - if i != 0 { - ctx.WritePlain(", ") - } - if err := col.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore CreateStatisticsStmt.Columns: [%v]", i) - } - } - ctx.WritePlain(")") - return nil -} - -// Accept implements Node Accept interface. -func (n *CreateStatisticsStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*CreateStatisticsStmt) - node, ok := n.Table.Accept(v) - if !ok { - return n, false - } - n.Table = node.(*TableName) - for i, col := range n.Columns { - node, ok = col.Accept(v) - if !ok { - return n, false - } - n.Columns[i] = node.(*ColumnName) - } - return v.Leave(n) -} - -// DropStatisticsStmt is a statement to drop extended statistics. -// Examples: -// -// DROP STATISTICS stats1; -type DropStatisticsStmt struct { - stmtNode - - StatsName string -} - -// Restore implements Node interface. -func (n *DropStatisticsStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("DROP STATISTICS ") - ctx.WriteName(n.StatsName) - return nil -} - -// Accept implements Node Accept interface. -func (n *DropStatisticsStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*DropStatisticsStmt) - return v.Leave(n) -} - -// DoStmt is the struct for DO statement. -type DoStmt struct { - stmtNode - - Exprs []ExprNode -} - -// Restore implements Node interface. -func (n *DoStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("DO ") - for i, v := range n.Exprs { - if i != 0 { - ctx.WritePlain(", ") - } - if err := v.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore DoStmt.Exprs[%d]", i) - } - } - return nil -} - -// Accept implements Node Accept interface. -func (n *DoStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*DoStmt) - for i, val := range n.Exprs { - node, ok := val.Accept(v) - if !ok { - return n, false - } - n.Exprs[i] = node.(ExprNode) - } - return v.Leave(n) -} - -// AdminStmtType is the type for admin statement. -type AdminStmtType int - -// Admin statement types. 
-const ( - AdminShowDDL AdminStmtType = iota + 1 - AdminCheckTable - AdminShowDDLJobs - AdminCancelDDLJobs - AdminPauseDDLJobs - AdminResumeDDLJobs - AdminCheckIndex - AdminRecoverIndex - AdminCleanupIndex - AdminCheckIndexRange - AdminShowDDLJobQueries - AdminShowDDLJobQueriesWithRange - AdminChecksumTable - AdminShowSlow - AdminShowNextRowID - AdminReloadExprPushdownBlacklist - AdminReloadOptRuleBlacklist - AdminPluginDisable - AdminPluginEnable - AdminFlushBindings - AdminCaptureBindings - AdminEvolveBindings - AdminReloadBindings - AdminReloadStatistics - AdminFlushPlanCache - AdminSetBDRRole - AdminShowBDRRole - AdminUnsetBDRRole -) - -// HandleRange represents a range where handle value >= Begin and < End. -type HandleRange struct { - Begin int64 - End int64 -} - -// BDRRole represents the role of the cluster in BDR mode. -type BDRRole string - -const ( - BDRRolePrimary BDRRole = "primary" - BDRRoleSecondary BDRRole = "secondary" - BDRRoleNone BDRRole = "" -) - -// DeniedByBDR checks whether the DDL is denied by BDR. -func DeniedByBDR(role BDRRole, action model.ActionType, job *model.Job) (denied bool) { - ddlType, ok := model.ActionBDRMap[action] - switch role { - case BDRRolePrimary: - if !ok { - return true - } - - // Can't add unique index on primary role. - if job != nil && (action == model.ActionAddIndex || action == model.ActionAddPrimaryKey) && - len(job.Args) >= 1 && job.Args[0].(bool) { - // job.Args[0] is unique when job.Type is ActionAddIndex or ActionAddPrimaryKey. - return true - } - - if ddlType == model.SafeDDL || ddlType == model.UnmanagementDDL { - return false - } - case BDRRoleSecondary: - if !ok { - return true - } - if ddlType == model.UnmanagementDDL { - return false - } - default: - // if user do not set bdr role, we will not deny any ddl as `none` - return false - } - - return true -} - -type StatementScope int - -const ( - StatementScopeNone StatementScope = iota - StatementScopeSession - StatementScopeInstance - StatementScopeGlobal -) - -// ShowSlowType defines the type for SlowSlow statement. -type ShowSlowType int - -const ( - // ShowSlowTop is a ShowSlowType constant. - ShowSlowTop ShowSlowType = iota - // ShowSlowRecent is a ShowSlowType constant. - ShowSlowRecent -) - -// ShowSlowKind defines the kind for SlowSlow statement when the type is ShowSlowTop. -type ShowSlowKind int - -const ( - // ShowSlowKindDefault is a ShowSlowKind constant. - ShowSlowKindDefault ShowSlowKind = iota - // ShowSlowKindInternal is a ShowSlowKind constant. - ShowSlowKindInternal - // ShowSlowKindAll is a ShowSlowKind constant. - ShowSlowKindAll -) - -// ShowSlow is used for the following command: -// -// admin show slow top [ internal | all] N -// admin show slow recent N -type ShowSlow struct { - Tp ShowSlowType - Count uint64 - Kind ShowSlowKind -} - -// Restore implements Node interface. -func (n *ShowSlow) Restore(ctx *format.RestoreCtx) error { - switch n.Tp { - case ShowSlowRecent: - ctx.WriteKeyWord("RECENT ") - case ShowSlowTop: - ctx.WriteKeyWord("TOP ") - switch n.Kind { - case ShowSlowKindDefault: - // do nothing - case ShowSlowKindInternal: - ctx.WriteKeyWord("INTERNAL ") - case ShowSlowKindAll: - ctx.WriteKeyWord("ALL ") - default: - return errors.New("Unsupported kind of ShowSlowTop") - } - default: - return errors.New("Unsupported type of ShowSlow") - } - ctx.WritePlainf("%d", n.Count) - return nil -} - -// LimitSimple is the struct for Admin statement limit option. 
-type LimitSimple struct { - Count uint64 - Offset uint64 -} - -// AdminStmt is the struct for Admin statement. -type AdminStmt struct { - stmtNode - - Tp AdminStmtType - Index string - Tables []*TableName - JobIDs []int64 - JobNumber int64 - - HandleRanges []HandleRange - ShowSlow *ShowSlow - Plugins []string - Where ExprNode - StatementScope StatementScope - LimitSimple LimitSimple - BDRRole BDRRole -} - -// Restore implements Node interface. -func (n *AdminStmt) Restore(ctx *format.RestoreCtx) error { - restoreTables := func() error { - for i, v := range n.Tables { - if i != 0 { - ctx.WritePlain(", ") - } - if err := v.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore AdminStmt.Tables[%d]", i) - } - } - return nil - } - restoreJobIDs := func() { - for i, v := range n.JobIDs { - if i != 0 { - ctx.WritePlain(", ") - } - ctx.WritePlainf("%d", v) - } - } - - ctx.WriteKeyWord("ADMIN ") - switch n.Tp { - case AdminShowDDL: - ctx.WriteKeyWord("SHOW DDL") - case AdminShowDDLJobs: - ctx.WriteKeyWord("SHOW DDL JOBS") - if n.JobNumber != 0 { - ctx.WritePlainf(" %d", n.JobNumber) - } - if n.Where != nil { - ctx.WriteKeyWord(" WHERE ") - if err := n.Where.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore ShowStmt.Where") - } - } - case AdminShowNextRowID: - ctx.WriteKeyWord("SHOW ") - if err := restoreTables(); err != nil { - return err - } - ctx.WriteKeyWord(" NEXT_ROW_ID") - case AdminCheckTable: - ctx.WriteKeyWord("CHECK TABLE ") - if err := restoreTables(); err != nil { - return err - } - case AdminCheckIndex: - ctx.WriteKeyWord("CHECK INDEX ") - if err := restoreTables(); err != nil { - return err - } - ctx.WritePlainf(" %s", n.Index) - case AdminRecoverIndex: - ctx.WriteKeyWord("RECOVER INDEX ") - if err := restoreTables(); err != nil { - return err - } - ctx.WritePlainf(" %s", n.Index) - case AdminCleanupIndex: - ctx.WriteKeyWord("CLEANUP INDEX ") - if err := restoreTables(); err != nil { - return err - } - ctx.WritePlainf(" %s", n.Index) - case AdminCheckIndexRange: - ctx.WriteKeyWord("CHECK INDEX ") - if err := restoreTables(); err != nil { - return err - } - ctx.WritePlainf(" %s", n.Index) - if n.HandleRanges != nil { - ctx.WritePlain(" ") - for i, v := range n.HandleRanges { - if i != 0 { - ctx.WritePlain(", ") - } - ctx.WritePlainf("(%d,%d)", v.Begin, v.End) - } - } - case AdminChecksumTable: - ctx.WriteKeyWord("CHECKSUM TABLE ") - if err := restoreTables(); err != nil { - return err - } - case AdminCancelDDLJobs: - ctx.WriteKeyWord("CANCEL DDL JOBS ") - restoreJobIDs() - case AdminPauseDDLJobs: - ctx.WriteKeyWord("PAUSE DDL JOBS ") - restoreJobIDs() - case AdminResumeDDLJobs: - ctx.WriteKeyWord("RESUME DDL JOBS ") - restoreJobIDs() - case AdminShowDDLJobQueries: - ctx.WriteKeyWord("SHOW DDL JOB QUERIES ") - restoreJobIDs() - case AdminShowDDLJobQueriesWithRange: - ctx.WriteKeyWord("SHOW DDL JOB QUERIES LIMIT ") - ctx.WritePlainf("%d, %d", n.LimitSimple.Offset, n.LimitSimple.Count) - case AdminShowSlow: - ctx.WriteKeyWord("SHOW SLOW ") - if err := n.ShowSlow.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore AdminStmt.ShowSlow") - } - case AdminReloadExprPushdownBlacklist: - ctx.WriteKeyWord("RELOAD EXPR_PUSHDOWN_BLACKLIST") - case AdminReloadOptRuleBlacklist: - ctx.WriteKeyWord("RELOAD OPT_RULE_BLACKLIST") - case AdminPluginEnable: - ctx.WriteKeyWord("PLUGINS ENABLE") - for i, v := range n.Plugins { - if i == 0 { - ctx.WritePlain(" ") - } else { - ctx.WritePlain(", ") 
- } - ctx.WritePlain(v) - } - case AdminPluginDisable: - ctx.WriteKeyWord("PLUGINS DISABLE") - for i, v := range n.Plugins { - if i == 0 { - ctx.WritePlain(" ") - } else { - ctx.WritePlain(", ") - } - ctx.WritePlain(v) - } - case AdminFlushBindings: - ctx.WriteKeyWord("FLUSH BINDINGS") - case AdminCaptureBindings: - ctx.WriteKeyWord("CAPTURE BINDINGS") - case AdminEvolveBindings: - ctx.WriteKeyWord("EVOLVE BINDINGS") - case AdminReloadBindings: - ctx.WriteKeyWord("RELOAD BINDINGS") - case AdminReloadStatistics: - ctx.WriteKeyWord("RELOAD STATS_EXTENDED") - case AdminFlushPlanCache: - if n.StatementScope == StatementScopeSession { - ctx.WriteKeyWord("FLUSH SESSION PLAN_CACHE") - } else if n.StatementScope == StatementScopeInstance { - ctx.WriteKeyWord("FLUSH INSTANCE PLAN_CACHE") - } else if n.StatementScope == StatementScopeGlobal { - ctx.WriteKeyWord("FLUSH GLOBAL PLAN_CACHE") - } - case AdminSetBDRRole: - switch n.BDRRole { - case BDRRolePrimary: - ctx.WriteKeyWord("SET BDR ROLE PRIMARY") - case BDRRoleSecondary: - ctx.WriteKeyWord("SET BDR ROLE SECONDARY") - default: - return errors.New("Unsupported BDR role") - } - case AdminShowBDRRole: - ctx.WriteKeyWord("SHOW BDR ROLE") - case AdminUnsetBDRRole: - ctx.WriteKeyWord("UNSET BDR ROLE") - default: - return errors.New("Unsupported AdminStmt type") - } - return nil -} - -// Accept implements Node Accept interface. -func (n *AdminStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - - n = newNode.(*AdminStmt) - for i, val := range n.Tables { - node, ok := val.Accept(v) - if !ok { - return n, false - } - n.Tables[i] = node.(*TableName) - } - - if n.Where != nil { - node, ok := n.Where.Accept(v) - if !ok { - return n, false - } - n.Where = node.(ExprNode) - } - - return v.Leave(n) -} - -// RoleOrPriv is a temporary structure to be further processed into auth.RoleIdentity or PrivElem -type RoleOrPriv struct { - Symbols string // hold undecided symbols - Node interface{} // hold auth.RoleIdentity or PrivElem that can be sure when parsing -} - -func (n *RoleOrPriv) ToRole() (*auth.RoleIdentity, error) { - if n.Node != nil { - if r, ok := n.Node.(*auth.RoleIdentity); ok { - return r, nil - } - return nil, errors.Errorf("can't convert to RoleIdentity, type %T", n.Node) - } - return &auth.RoleIdentity{Username: n.Symbols, Hostname: "%"}, nil -} - -func (n *RoleOrPriv) ToPriv() (*PrivElem, error) { - if n.Node != nil { - if p, ok := n.Node.(*PrivElem); ok { - return p, nil - } - return nil, errors.Errorf("can't convert to PrivElem, type %T", n.Node) - } - if len(n.Symbols) == 0 { - return nil, errors.New("symbols should not be length 0") - } - return &PrivElem{Priv: mysql.ExtendedPriv, Name: n.Symbols}, nil -} - -// PrivElem is the privilege type and optional column list. -type PrivElem struct { - node - - Priv mysql.PrivilegeType - Cols []*ColumnName - Name string -} - -// Restore implements Node interface. 
-func (n *PrivElem) Restore(ctx *format.RestoreCtx) error { - if n.Priv == mysql.AllPriv { - ctx.WriteKeyWord("ALL") - } else if n.Priv == mysql.ExtendedPriv { - ctx.WriteKeyWord(n.Name) - } else { - str, ok := mysql.Priv2Str[n.Priv] - if !ok { - return errors.New("Undefined privilege type") - } - ctx.WriteKeyWord(str) - } - if n.Cols != nil { - ctx.WritePlain(" (") - for i, v := range n.Cols { - if i != 0 { - ctx.WritePlain(",") - } - if err := v.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore PrivElem.Cols[%d]", i) - } - } - ctx.WritePlain(")") - } - return nil -} - -// Accept implements Node Accept interface. -func (n *PrivElem) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*PrivElem) - for i, val := range n.Cols { - node, ok := val.Accept(v) - if !ok { - return n, false - } - n.Cols[i] = node.(*ColumnName) - } - return v.Leave(n) -} - -// ObjectTypeType is the type for object type. -type ObjectTypeType int - -const ( - // ObjectTypeNone is for empty object type. - ObjectTypeNone ObjectTypeType = iota + 1 - // ObjectTypeTable means the following object is a table. - ObjectTypeTable - // ObjectTypeFunction means the following object is a stored function. - ObjectTypeFunction - // ObjectTypeProcedure means the following object is a stored procedure. - ObjectTypeProcedure -) - -// Restore implements Node interface. -func (n ObjectTypeType) Restore(ctx *format.RestoreCtx) error { - switch n { - case ObjectTypeNone: - // do nothing - case ObjectTypeTable: - ctx.WriteKeyWord("TABLE") - case ObjectTypeFunction: - ctx.WriteKeyWord("FUNCTION") - case ObjectTypeProcedure: - ctx.WriteKeyWord("PROCEDURE") - default: - return errors.New("Unsupported object type") - } - return nil -} - -// GrantLevelType is the type for grant level. -type GrantLevelType int - -const ( - // GrantLevelNone is the dummy const for default value. - GrantLevelNone GrantLevelType = iota + 1 - // GrantLevelGlobal means the privileges are administrative or apply to all databases on a given server. - GrantLevelGlobal - // GrantLevelDB means the privileges apply to all objects in a given database. - GrantLevelDB - // GrantLevelTable means the privileges apply to all columns in a given table. - GrantLevelTable -) - -// GrantLevel is used for store the privilege scope. -type GrantLevel struct { - Level GrantLevelType - DBName string - TableName string -} - -// Restore implements Node interface. -func (n *GrantLevel) Restore(ctx *format.RestoreCtx) error { - switch n.Level { - case GrantLevelDB: - if n.DBName == "" { - ctx.WritePlain("*") - } else { - ctx.WriteName(n.DBName) - ctx.WritePlain(".*") - } - case GrantLevelGlobal: - ctx.WritePlain("*.*") - case GrantLevelTable: - if n.DBName != "" { - ctx.WriteName(n.DBName) - ctx.WritePlain(".") - } - ctx.WriteName(n.TableName) - } - return nil -} - -// RevokeStmt is the struct for REVOKE statement. -type RevokeStmt struct { - stmtNode - - Privs []*PrivElem - ObjectType ObjectTypeType - Level *GrantLevel - Users []*UserSpec -} - -// Restore implements Node interface. 
-func (n *RevokeStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("REVOKE ") - for i, v := range n.Privs { - if i != 0 { - ctx.WritePlain(", ") - } - if err := v.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore RevokeStmt.Privs[%d]", i) - } - } - ctx.WriteKeyWord(" ON ") - if n.ObjectType != ObjectTypeNone { - if err := n.ObjectType.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore RevokeStmt.ObjectType") - } - ctx.WritePlain(" ") - } - if err := n.Level.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore RevokeStmt.Level") - } - ctx.WriteKeyWord(" FROM ") - for i, v := range n.Users { - if i != 0 { - ctx.WritePlain(", ") - } - if err := v.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore RevokeStmt.Users[%d]", i) - } - } - return nil -} - -// Accept implements Node Accept interface. -func (n *RevokeStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*RevokeStmt) - for i, val := range n.Privs { - node, ok := val.Accept(v) - if !ok { - return n, false - } - n.Privs[i] = node.(*PrivElem) - } - return v.Leave(n) -} - -// RevokeStmt is the struct for REVOKE statement. -type RevokeRoleStmt struct { - stmtNode - - Roles []*auth.RoleIdentity - Users []*auth.UserIdentity -} - -// Restore implements Node interface. -func (n *RevokeRoleStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("REVOKE ") - for i, role := range n.Roles { - if i != 0 { - ctx.WritePlain(", ") - } - if err := role.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore RevokeRoleStmt.Roles[%d]", i) - } - } - ctx.WriteKeyWord(" FROM ") - for i, v := range n.Users { - if i != 0 { - ctx.WritePlain(", ") - } - if err := v.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore RevokeRoleStmt.Users[%d]", i) - } - } - return nil -} - -// Accept implements Node Accept interface. -func (n *RevokeRoleStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*RevokeRoleStmt) - return v.Leave(n) -} - -// GrantStmt is the struct for GRANT statement. -type GrantStmt struct { - stmtNode - - Privs []*PrivElem - ObjectType ObjectTypeType - Level *GrantLevel - Users []*UserSpec - AuthTokenOrTLSOptions []*AuthTokenOrTLSOption - WithGrant bool -} - -// Restore implements Node interface. 
-func (n *GrantStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("GRANT ") - for i, v := range n.Privs { - if i != 0 && v.Priv != 0 { - ctx.WritePlain(", ") - } else if v.Priv == 0 { - ctx.WritePlain(" ") - } - if err := v.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore GrantStmt.Privs[%d]", i) - } - } - ctx.WriteKeyWord(" ON ") - if n.ObjectType != ObjectTypeNone { - if err := n.ObjectType.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore GrantStmt.ObjectType") - } - ctx.WritePlain(" ") - } - if err := n.Level.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore GrantStmt.Level") - } - ctx.WriteKeyWord(" TO ") - for i, v := range n.Users { - if i != 0 { - ctx.WritePlain(", ") - } - if err := v.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore GrantStmt.Users[%d]", i) - } - } - if n.AuthTokenOrTLSOptions != nil { - if len(n.AuthTokenOrTLSOptions) != 0 { - ctx.WriteKeyWord(" REQUIRE ") - } - for i, option := range n.AuthTokenOrTLSOptions { - if i != 0 { - ctx.WriteKeyWord(" AND ") - } - if err := option.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore GrantStmt.AuthTokenOrTLSOptions[%d]", i) - } - } - } - if n.WithGrant { - ctx.WriteKeyWord(" WITH GRANT OPTION") - } - return nil -} - -// SecureText implements SensitiveStatement interface. -func (n *GrantStmt) SecureText() string { - text := n.text - // Filter "identified by xxx" because it would expose password information. - idx := strings.Index(strings.ToLower(text), "identified") - if idx > 0 { - text = text[:idx] - } - return text -} - -// Accept implements Node Accept interface. -func (n *GrantStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*GrantStmt) - for i, val := range n.Privs { - node, ok := val.Accept(v) - if !ok { - return n, false - } - n.Privs[i] = node.(*PrivElem) - } - return v.Leave(n) -} - -// GrantProxyStmt is the struct for GRANT PROXY statement. -type GrantProxyStmt struct { - stmtNode - - LocalUser *auth.UserIdentity - ExternalUsers []*auth.UserIdentity - WithGrant bool -} - -// Accept implements Node Accept interface. -func (n *GrantProxyStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*GrantProxyStmt) - return v.Leave(n) -} - -// Restore implements Node interface. -func (n *GrantProxyStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("GRANT PROXY ON ") - if err := n.LocalUser.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore GrantProxyStmt.LocalUser") - } - ctx.WriteKeyWord(" TO ") - for i, v := range n.ExternalUsers { - if i != 0 { - ctx.WritePlain(", ") - } - if err := v.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore GrantProxyStmt.ExternalUsers[%d]", i) - } - } - if n.WithGrant { - ctx.WriteKeyWord(" WITH GRANT OPTION") - } - return nil -} - -// GrantRoleStmt is the struct for GRANT TO statement. -type GrantRoleStmt struct { - stmtNode - - Roles []*auth.RoleIdentity - Users []*auth.UserIdentity -} - -// Accept implements Node Accept interface. 
-func (n *GrantRoleStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*GrantRoleStmt) - return v.Leave(n) -} - -// Restore implements Node interface. -func (n *GrantRoleStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("GRANT ") - if len(n.Roles) > 0 { - for i, role := range n.Roles { - if i != 0 { - ctx.WritePlain(", ") - } - if err := role.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore GrantRoleStmt.Roles[%d]", i) - } - } - } - ctx.WriteKeyWord(" TO ") - for i, v := range n.Users { - if i != 0 { - ctx.WritePlain(", ") - } - if err := v.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore GrantStmt.Users[%d]", i) - } - } - return nil -} - -// SecureText implements SensitiveStatement interface. -func (n *GrantRoleStmt) SecureText() string { - text := n.text - // Filter "identified by xxx" because it would expose password information. - idx := strings.Index(strings.ToLower(text), "identified") - if idx > 0 { - text = text[:idx] - } - return text -} - -// ShutdownStmt is a statement to stop the TiDB server. -// See https://dev.mysql.com/doc/refman/5.7/en/shutdown.html -type ShutdownStmt struct { - stmtNode -} - -// Restore implements Node interface. -func (n *ShutdownStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("SHUTDOWN") - return nil -} - -// Accept implements Node Accept interface. -func (n *ShutdownStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*ShutdownStmt) - return v.Leave(n) -} - -// RestartStmt is a statement to restart the TiDB server. -// See https://dev.mysql.com/doc/refman/8.0/en/restart.html -type RestartStmt struct { - stmtNode -} - -// Restore implements Node interface. -func (n *RestartStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("RESTART") - return nil -} - -// Accept implements Node Accept interface. -func (n *RestartStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*RestartStmt) - return v.Leave(n) -} - -// HelpStmt is a statement for server side help -// See https://dev.mysql.com/doc/refman/8.0/en/help.html -type HelpStmt struct { - stmtNode - - Topic string -} - -// Restore implements Node interface. -func (n *HelpStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("HELP ") - ctx.WriteString(n.Topic) - return nil -} - -// Accept implements Node Accept interface. -func (n *HelpStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*HelpStmt) - return v.Leave(n) -} - -// RenameUserStmt is a statement to rename a user. -// See http://dev.mysql.com/doc/refman/5.7/en/rename-user.html -type RenameUserStmt struct { - stmtNode - - UserToUsers []*UserToUser -} - -// Restore implements Node interface. -func (n *RenameUserStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("RENAME USER ") - for index, user2user := range n.UserToUsers { - if index != 0 { - ctx.WritePlain(", ") - } - if err := user2user.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore RenameUserStmt.UserToUsers") - } - } - return nil -} - -// Accept implements Node Accept interface. 
-func (n *RenameUserStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*RenameUserStmt) - - for i, t := range n.UserToUsers { - node, ok := t.Accept(v) - if !ok { - return n, false - } - n.UserToUsers[i] = node.(*UserToUser) - } - return v.Leave(n) -} - -// UserToUser represents renaming old user to new user used in RenameUserStmt. -type UserToUser struct { - node - OldUser *auth.UserIdentity - NewUser *auth.UserIdentity -} - -// Restore implements Node interface. -func (n *UserToUser) Restore(ctx *format.RestoreCtx) error { - if err := n.OldUser.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore UserToUser.OldUser") - } - ctx.WriteKeyWord(" TO ") - if err := n.NewUser.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore UserToUser.NewUser") - } - return nil -} - -// Accept implements Node Accept interface. -func (n *UserToUser) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*UserToUser) - return v.Leave(n) -} - -type BRIEKind uint8 -type BRIEOptionType uint16 - -const ( - BRIEKindBackup BRIEKind = iota - BRIEKindCancelJob - BRIEKindStreamStart - BRIEKindStreamMetaData - BRIEKindStreamStatus - BRIEKindStreamPause - BRIEKindStreamResume - BRIEKindStreamStop - BRIEKindStreamPurge - BRIEKindRestore - BRIEKindRestorePIT - BRIEKindShowJob - BRIEKindShowQuery - BRIEKindShowBackupMeta - // common BRIE options - BRIEOptionRateLimit BRIEOptionType = iota + 1 - BRIEOptionConcurrency - BRIEOptionChecksum - BRIEOptionSendCreds - BRIEOptionCheckpoint - BRIEOptionStartTS - BRIEOptionUntilTS - BRIEOptionChecksumConcurrency - BRIEOptionEncryptionMethod - BRIEOptionEncryptionKeyFile - // backup options - BRIEOptionBackupTimeAgo - BRIEOptionBackupTS - BRIEOptionBackupTSO - BRIEOptionLastBackupTS - BRIEOptionLastBackupTSO - BRIEOptionGCTTL - BRIEOptionCompressionLevel - BRIEOptionCompression - BRIEOptionIgnoreStats - BRIEOptionLoadStats - // restore options - BRIEOptionOnline - BRIEOptionFullBackupStorage - BRIEOptionRestoredTS - BRIEOptionWaitTiflashReady - BRIEOptionWithSysTable - // import options - BRIEOptionAnalyze - BRIEOptionBackend - BRIEOptionOnDuplicate - BRIEOptionSkipSchemaFiles - BRIEOptionStrictFormat - BRIEOptionTiKVImporter - BRIEOptionResume - // CSV options - BRIEOptionCSVBackslashEscape - BRIEOptionCSVDelimiter - BRIEOptionCSVHeader - BRIEOptionCSVNotNull - BRIEOptionCSVNull - BRIEOptionCSVSeparator - BRIEOptionCSVTrimLastSeparators - - BRIECSVHeaderIsColumns = ^uint64(0) -) - -type BRIEOptionLevel uint64 - -const ( - BRIEOptionLevelOff BRIEOptionLevel = iota // equals FALSE - BRIEOptionLevelRequired // equals TRUE - BRIEOptionLevelOptional -) - -func (kind BRIEKind) String() string { - switch kind { - case BRIEKindBackup: - return "BACKUP" - case BRIEKindRestore: - return "RESTORE" - case BRIEKindStreamStart: - return "BACKUP LOGS" - case BRIEKindStreamStop: - return "STOP BACKUP LOGS" - case BRIEKindStreamPause: - return "PAUSE BACKUP LOGS" - case BRIEKindStreamResume: - return "RESUME BACKUP LOGS" - case BRIEKindStreamStatus: - return "SHOW BACKUP LOGS STATUS" - case BRIEKindStreamMetaData: - return "SHOW BACKUP LOGS METADATA" - case BRIEKindStreamPurge: - return "PURGE BACKUP LOGS" - case BRIEKindRestorePIT: - return "RESTORE POINT" - case BRIEKindShowJob: - return "SHOW BR JOB" - case BRIEKindShowQuery: - return "SHOW BR JOB QUERY" - 
case BRIEKindCancelJob: - return "CANCEL BR JOB" - case BRIEKindShowBackupMeta: - return "SHOW BACKUP METADATA" - default: - return "" - } -} - -func (kind BRIEOptionType) String() string { - switch kind { - case BRIEOptionRateLimit: - return "RATE_LIMIT" - case BRIEOptionConcurrency: - return "CONCURRENCY" - case BRIEOptionChecksum: - return "CHECKSUM" - case BRIEOptionSendCreds: - return "SEND_CREDENTIALS_TO_TIKV" - case BRIEOptionBackupTimeAgo, BRIEOptionBackupTS, BRIEOptionBackupTSO: - return "SNAPSHOT" - case BRIEOptionLastBackupTS, BRIEOptionLastBackupTSO: - return "LAST_BACKUP" - case BRIEOptionOnline: - return "ONLINE" - case BRIEOptionCheckpoint: - return "CHECKPOINT" - case BRIEOptionAnalyze: - return "ANALYZE" - case BRIEOptionBackend: - return "BACKEND" - case BRIEOptionOnDuplicate: - return "ON_DUPLICATE" - case BRIEOptionSkipSchemaFiles: - return "SKIP_SCHEMA_FILES" - case BRIEOptionStrictFormat: - return "STRICT_FORMAT" - case BRIEOptionTiKVImporter: - return "TIKV_IMPORTER" - case BRIEOptionResume: - return "RESUME" - case BRIEOptionCSVBackslashEscape: - return "CSV_BACKSLASH_ESCAPE" - case BRIEOptionCSVDelimiter: - return "CSV_DELIMITER" - case BRIEOptionCSVHeader: - return "CSV_HEADER" - case BRIEOptionCSVNotNull: - return "CSV_NOT_NULL" - case BRIEOptionCSVNull: - return "CSV_NULL" - case BRIEOptionCSVSeparator: - return "CSV_SEPARATOR" - case BRIEOptionCSVTrimLastSeparators: - return "CSV_TRIM_LAST_SEPARATORS" - case BRIEOptionFullBackupStorage: - return "FULL_BACKUP_STORAGE" - case BRIEOptionRestoredTS: - return "RESTORED_TS" - case BRIEOptionStartTS: - return "START_TS" - case BRIEOptionUntilTS: - return "UNTIL_TS" - case BRIEOptionGCTTL: - return "GC_TTL" - case BRIEOptionWaitTiflashReady: - return "WAIT_TIFLASH_READY" - case BRIEOptionWithSysTable: - return "WITH_SYS_TABLE" - case BRIEOptionIgnoreStats: - return "IGNORE_STATS" - case BRIEOptionLoadStats: - return "LOAD_STATS" - case BRIEOptionChecksumConcurrency: - return "CHECKSUM_CONCURRENCY" - case BRIEOptionCompressionLevel: - return "COMPRESSION_LEVEL" - case BRIEOptionCompression: - return "COMPRESSION_TYPE" - case BRIEOptionEncryptionMethod: - return "ENCRYPTION_METHOD" - case BRIEOptionEncryptionKeyFile: - return "ENCRYPTION_KEY_FILE" - default: - return "" - } -} - -func (level BRIEOptionLevel) String() string { - switch level { - case BRIEOptionLevelOff: - return "OFF" - case BRIEOptionLevelOptional: - return "OPTIONAL" - case BRIEOptionLevelRequired: - return "REQUIRED" - default: - return "" - } -} - -type BRIEOption struct { - Tp BRIEOptionType - StrValue string - UintValue uint64 -} - -func (opt *BRIEOption) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord(opt.Tp.String()) - ctx.WritePlain(" = ") - switch opt.Tp { - case BRIEOptionBackupTS, BRIEOptionLastBackupTS, BRIEOptionBackend, BRIEOptionOnDuplicate, BRIEOptionTiKVImporter, BRIEOptionCSVDelimiter, BRIEOptionCSVNull, BRIEOptionCSVSeparator, BRIEOptionFullBackupStorage, BRIEOptionRestoredTS, BRIEOptionStartTS, BRIEOptionUntilTS, BRIEOptionGCTTL, BRIEOptionCompression, BRIEOptionEncryptionMethod, BRIEOptionEncryptionKeyFile: - ctx.WriteString(opt.StrValue) - case BRIEOptionBackupTimeAgo: - ctx.WritePlainf("%d ", opt.UintValue/1000) - ctx.WriteKeyWord("MICROSECOND AGO") - case BRIEOptionRateLimit: - ctx.WritePlainf("%d ", opt.UintValue/1048576) - ctx.WriteKeyWord("MB") - ctx.WritePlain("/") - ctx.WriteKeyWord("SECOND") - case BRIEOptionCSVHeader: - if opt.UintValue == BRIECSVHeaderIsColumns { - ctx.WriteKeyWord("COLUMNS") - } else { - 
ctx.WritePlainf("%d", opt.UintValue) - } - case BRIEOptionChecksum, BRIEOptionAnalyze: - // BACKUP/RESTORE doesn't support OPTIONAL value for now, should warn at executor - ctx.WriteKeyWord(BRIEOptionLevel(opt.UintValue).String()) - default: - ctx.WritePlainf("%d", opt.UintValue) - } - return nil -} - -// BRIEStmt is a statement for backup, restore, import and export. -type BRIEStmt struct { - stmtNode - - Kind BRIEKind - Schemas []string - Tables []*TableName - Storage string - JobID int64 - Options []*BRIEOption -} - -func (n *BRIEStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*BRIEStmt) - for i, val := range n.Tables { - node, ok := val.Accept(v) - if !ok { - return n, false - } - n.Tables[i] = node.(*TableName) - } - return v.Leave(n) -} - -func (n *BRIEStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord(n.Kind.String()) - - switch n.Kind { - case BRIEKindRestore, BRIEKindBackup: - switch { - case len(n.Tables) != 0: - ctx.WriteKeyWord(" TABLE ") - for index, table := range n.Tables { - if index != 0 { - ctx.WritePlain(", ") - } - if err := table.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while restore BRIEStmt.Tables[%d]", index) - } - } - case len(n.Schemas) != 0: - ctx.WriteKeyWord(" DATABASE ") - for index, schema := range n.Schemas { - if index != 0 { - ctx.WritePlain(", ") - } - ctx.WriteName(schema) - } - default: - ctx.WriteKeyWord(" DATABASE") - ctx.WritePlain(" *") - } - - if n.Kind == BRIEKindBackup { - ctx.WriteKeyWord(" TO ") - ctx.WriteString(n.Storage) - } else { - ctx.WriteKeyWord(" FROM ") - ctx.WriteString(n.Storage) - } - case BRIEKindCancelJob, BRIEKindShowJob, BRIEKindShowQuery: - ctx.WritePlainf(" %d", n.JobID) - case BRIEKindStreamStart: - ctx.WriteKeyWord(" TO ") - ctx.WriteString(n.Storage) - case BRIEKindRestorePIT, BRIEKindStreamMetaData, BRIEKindShowBackupMeta, BRIEKindStreamPurge: - ctx.WriteKeyWord(" FROM ") - ctx.WriteString(n.Storage) - } - - for _, opt := range n.Options { - ctx.WritePlain(" ") - if err := opt.Restore(ctx); err != nil { - return err - } - } - - return nil -} - -// RedactURL redacts the secret tokens in the URL. only S3 url need redaction for now. -// if the url is not a valid url, return the original string. -func RedactURL(str string) string { - // FIXME: this solution is not scalable, and duplicates some logic from BR. 
- u, err := url.Parse(str) - if err != nil { - return str - } - scheme := u.Scheme - failpoint.Inject("forceRedactURL", func() { - scheme = "s3" - }) - switch strings.ToLower(scheme) { - case "s3", "ks3": - values := u.Query() - for k := range values { - // see below on why we normalize key - // https://github.com/pingcap/tidb/blob/a7c0d95f16ea2582bb569278c3f829403e6c3a7e/br/pkg/storage/parse.go#L163 - normalizedKey := strings.ToLower(strings.ReplaceAll(k, "_", "-")) - if normalizedKey == "access-key" || normalizedKey == "secret-access-key" || normalizedKey == "session-token" { - values[k] = []string{"xxxxxx"} - } - } - u.RawQuery = values.Encode() - } - return u.String() -} - -// SecureText implements SensitiveStmtNode -func (n *BRIEStmt) SecureText() string { - redactedStmt := &BRIEStmt{ - Kind: n.Kind, - Schemas: n.Schemas, - Tables: n.Tables, - Storage: RedactURL(n.Storage), - Options: n.Options, - } - - var sb strings.Builder - _ = redactedStmt.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb)) - return sb.String() -} - -type ImportIntoActionTp string - -const ( - ImportIntoCancel ImportIntoActionTp = "cancel" -) - -// ImportIntoActionStmt represent CANCEL IMPORT INTO JOB statement. -// will support pause/resume/drop later. -type ImportIntoActionStmt struct { - stmtNode - - Tp ImportIntoActionTp - JobID int64 -} - -func (n *ImportIntoActionStmt) Accept(v Visitor) (Node, bool) { - newNode, _ := v.Enter(n) - return v.Leave(newNode) -} - -func (n *ImportIntoActionStmt) Restore(ctx *format.RestoreCtx) error { - if n.Tp != ImportIntoCancel { - return errors.Errorf("invalid IMPORT INTO action type: %s", n.Tp) - } - ctx.WriteKeyWord("CANCEL IMPORT JOB ") - ctx.WritePlainf("%d", n.JobID) - return nil -} - -// Ident is the table identifier composed of schema name and table name. -type Ident struct { - Schema model.CIStr - Name model.CIStr -} - -// String implements fmt.Stringer interface. -func (i Ident) String() string { - if i.Schema.O == "" { - return i.Name.O - } - return fmt.Sprintf("%s.%s", i.Schema, i.Name) -} - -// SelectStmtOpts wrap around select hints and switches -type SelectStmtOpts struct { - Distinct bool - SQLBigResult bool - SQLBufferResult bool - SQLCache bool - SQLSmallResult bool - CalcFoundRows bool - StraightJoin bool - Priority mysql.PriorityEnum - TableHints []*TableOptimizerHint - ExplicitAll bool -} - -// TableOptimizerHint is Table level optimizer hint -type TableOptimizerHint struct { - node - // HintName is the name or alias of the table(s) which the hint will affect. - // Table hints has no schema info - // It allows only table name or alias (if table has an alias) - HintName model.CIStr - // HintData is the payload of the hint. The actual type of this field - // is defined differently as according `HintName`. Define as following: - // - // Statement Execution Time Optimizer Hints - // See https://dev.mysql.com/doc/refman/5.7/en/optimizer-hints.html#optimizer-hints-execution-time - // - MAX_EXECUTION_TIME => uint64 - // - MEMORY_QUOTA => int64 - // - QUERY_TYPE => model.CIStr - // - // Time Range is used to hint the time range of inspection tables - // e.g: select /*+ time_range('','') */ * from information_schema.inspection_result. - // - TIME_RANGE => ast.HintTimeRange - // - READ_FROM_STORAGE => model.CIStr - // - USE_TOJA => bool - // - NTH_PLAN => int64 - HintData interface{} - // QBName is the default effective query block of this hint. 
- QBName model.CIStr - Tables []HintTable - Indexes []model.CIStr -} - -// HintTimeRange is the payload of `TIME_RANGE` hint -type HintTimeRange struct { - From string - To string -} - -// HintSetVar is the payload of `SET_VAR` hint -type HintSetVar struct { - VarName string - Value string -} - -// HintTable is table in the hint. It may have query block info. -type HintTable struct { - DBName model.CIStr - TableName model.CIStr - QBName model.CIStr - PartitionList []model.CIStr -} - -func (ht *HintTable) Restore(ctx *format.RestoreCtx) { - if !ctx.Flags.HasWithoutSchemaNameFlag() { - if ht.DBName.L != "" { - ctx.WriteName(ht.DBName.String()) - ctx.WriteKeyWord(".") - } - } - ctx.WriteName(ht.TableName.String()) - if ht.QBName.L != "" { - ctx.WriteKeyWord("@") - ctx.WriteName(ht.QBName.String()) - } - if len(ht.PartitionList) > 0 { - ctx.WriteKeyWord(" PARTITION") - ctx.WritePlain("(") - for i, p := range ht.PartitionList { - if i > 0 { - ctx.WritePlain(", ") - } - ctx.WriteName(p.String()) - } - ctx.WritePlain(")") - } -} - -// Restore implements Node interface. -func (n *TableOptimizerHint) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord(n.HintName.String()) - ctx.WritePlain("(") - if n.QBName.L != "" { - if n.HintName.L != "qb_name" { - ctx.WriteKeyWord("@") - } - ctx.WriteName(n.QBName.String()) - } - if n.HintName.L == "qb_name" && len(n.Tables) == 0 { - ctx.WritePlain(")") - return nil - } - // Hints without args except query block. - switch n.HintName.L { - case "mpp_1phase_agg", "mpp_2phase_agg", "hash_agg", "stream_agg", "agg_to_cop", "read_consistent_replica", "no_index_merge", "ignore_plan_cache", "limit_to_cop", "straight_join", "merge", "no_decorrelate": - ctx.WritePlain(")") - return nil - } - if n.QBName.L != "" { - ctx.WritePlain(" ") - } - // Hints with args except query block. - switch n.HintName.L { - case "max_execution_time": - ctx.WritePlainf("%d", n.HintData.(uint64)) - case "resource_group": - ctx.WriteName(n.HintData.(string)) - case "nth_plan": - ctx.WritePlainf("%d", n.HintData.(int64)) - case "tidb_hj", "tidb_smj", "tidb_inlj", "hash_join", "hash_join_build", "hash_join_probe", "merge_join", "inl_join", - "broadcast_join", "shuffle_join", "inl_hash_join", "inl_merge_join", "leading", "no_hash_join", "no_merge_join", - "no_index_join", "no_index_hash_join", "no_index_merge_join": - for i, table := range n.Tables { - if i != 0 { - ctx.WritePlain(", ") - } - table.Restore(ctx) - } - case "use_index", "ignore_index", "use_index_merge", "force_index", "order_index", "no_order_index": - n.Tables[0].Restore(ctx) - ctx.WritePlain(" ") - for i, index := range n.Indexes { - if i != 0 { - ctx.WritePlain(", ") - } - ctx.WriteName(index.String()) - } - case "qb_name": - if len(n.Tables) > 0 { - ctx.WritePlain(", ") - for i, table := range n.Tables { - if i != 0 { - ctx.WritePlain(". 
") - } - table.Restore(ctx) - } - } - case "use_toja", "use_cascades": - if n.HintData.(bool) { - ctx.WritePlain("TRUE") - } else { - ctx.WritePlain("FALSE") - } - case "query_type": - ctx.WriteKeyWord(n.HintData.(model.CIStr).String()) - case "memory_quota": - ctx.WritePlainf("%d MB", n.HintData.(int64)/1024/1024) - case "read_from_storage": - ctx.WriteKeyWord(n.HintData.(model.CIStr).String()) - for i, table := range n.Tables { - if i == 0 { - ctx.WritePlain("[") - } - table.Restore(ctx) - if i == len(n.Tables)-1 { - ctx.WritePlain("]") - } else { - ctx.WritePlain(", ") - } - } - case "time_range": - hintData := n.HintData.(HintTimeRange) - ctx.WriteString(hintData.From) - ctx.WritePlain(", ") - ctx.WriteString(hintData.To) - case "set_var": - hintData := n.HintData.(HintSetVar) - ctx.WritePlain(hintData.VarName) - ctx.WritePlain(" = ") - ctx.WriteString(hintData.Value) - } - ctx.WritePlain(")") - return nil -} - -// Accept implements Node Accept interface. -func (n *TableOptimizerHint) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*TableOptimizerHint) - return v.Leave(n) -} - -// TextString represent a string, it can be a binary literal. -type TextString struct { - Value string - IsBinaryLiteral bool -} - -type BinaryLiteral interface { - ToString() string -} - -// NewDecimal creates a types.Decimal value, it's provided by parser driver. -var NewDecimal func(string) (interface{}, error) - -// NewHexLiteral creates a types.HexLiteral value, it's provided by parser driver. -var NewHexLiteral func(string) (interface{}, error) - -// NewBitLiteral creates a types.BitLiteral value, it's provided by parser driver. -var NewBitLiteral func(string) (interface{}, error) - -// SetResourceGroupStmt is a statement to set the resource group name for current session. -type SetResourceGroupStmt struct { - stmtNode - Name model.CIStr -} - -func (n *SetResourceGroupStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("SET RESOURCE GROUP ") - ctx.WriteName(n.Name.O) - return nil -} - -// Accept implements Node Accept interface. -func (n *SetResourceGroupStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*SetResourceGroupStmt) - return v.Leave(n) -} - -// CalibrateResourceType is the type for CalibrateResource statement. -type CalibrateResourceType int - -// calibrate resource [ workload < TPCC | OLTP_READ_WRITE | OLTP_READ_ONLY | OLTP_WRITE_ONLY | TPCH_10> ] -const ( - WorkloadNone CalibrateResourceType = iota - TPCC - OLTPREADWRITE - OLTPREADONLY - OLTPWRITEONLY - TPCH10 -) - -func (n CalibrateResourceType) Restore(ctx *format.RestoreCtx) error { - switch n { - case TPCC: - ctx.WriteKeyWord(" WORKLOAD TPCC") - case OLTPREADWRITE: - ctx.WriteKeyWord(" WORKLOAD OLTP_READ_WRITE") - case OLTPREADONLY: - ctx.WriteKeyWord(" WORKLOAD OLTP_READ_ONLY") - case OLTPWRITEONLY: - ctx.WriteKeyWord(" WORKLOAD OLTP_WRITE_ONLY") - case TPCH10: - ctx.WriteKeyWord(" WORKLOAD TPCH_10") - } - return nil -} - -// CalibrateResourceStmt is a statement to fetch the cluster RU capacity -type CalibrateResourceStmt struct { - stmtNode - DynamicCalibrateResourceOptionList []*DynamicCalibrateResourceOption - Tp CalibrateResourceType -} - -// Restore implements Node interface. 
-func (n *CalibrateResourceStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("CALIBRATE RESOURCE") - if err := n.Tp.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore CalibrateResourceStmt.CalibrateResourceType") - } - for i, option := range n.DynamicCalibrateResourceOptionList { - ctx.WritePlain(" ") - if err := option.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while splicing DynamicCalibrateResourceOption: [%v]", i) - } - } - return nil -} - -// Accept implements Node Accept interface. -func (n *CalibrateResourceStmt) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*CalibrateResourceStmt) - for _, val := range n.DynamicCalibrateResourceOptionList { - _, ok := val.Accept(v) - if !ok { - return n, false - } - } - return v.Leave(n) -} - -type DynamicCalibrateType int - -const ( - // specific time - CalibrateStartTime = iota - CalibrateEndTime - CalibrateDuration -) - -type DynamicCalibrateResourceOption struct { - stmtNode - Tp DynamicCalibrateType - StrValue string - Ts ExprNode - Unit TimeUnitType -} - -func (n *DynamicCalibrateResourceOption) Restore(ctx *format.RestoreCtx) error { - switch n.Tp { - case CalibrateStartTime: - ctx.WriteKeyWord("START_TIME ") - if err := n.Ts.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while splicing DynamicCalibrateResourceOption StartTime") - } - case CalibrateEndTime: - ctx.WriteKeyWord("END_TIME ") - if err := n.Ts.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while splicing DynamicCalibrateResourceOption EndTime") - } - case CalibrateDuration: - ctx.WriteKeyWord("DURATION ") - if len(n.StrValue) > 0 { - ctx.WriteString(n.StrValue) - } else { - ctx.WriteKeyWord("INTERVAL ") - if err := n.Ts.Restore(ctx); err != nil { - return errors.Annotate(err, "An error occurred while restore DynamicCalibrateResourceOption DURATION TS") - } - ctx.WritePlain(" ") - ctx.WriteKeyWord(n.Unit.String()) - } - default: - return errors.Errorf("invalid DynamicCalibrateResourceOption: %d", n.Tp) - } - return nil -} - -// Accept implements Node Accept interface. -func (n *DynamicCalibrateResourceOption) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*DynamicCalibrateResourceOption) - if n.Ts != nil { - node, ok := n.Ts.Accept(v) - if !ok { - return n, false - } - n.Ts = node.(ExprNode) - } - return v.Leave(n) -} - -// DropQueryWatchStmt is a statement to drop a runaway watch item. -type DropQueryWatchStmt struct { - stmtNode - IntValue int64 -} - -func (n *DropQueryWatchStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("QUERY WATCH REMOVE ") - ctx.WritePlainf("%d", n.IntValue) - return nil -} - -// Accept implements Node Accept interface. -func (n *DropQueryWatchStmt) Accept(v Visitor) (Node, bool) { - newNode, _ := v.Enter(n) - n = newNode.(*DropQueryWatchStmt) - return v.Leave(n) -} - -// AddQueryWatchStmt is a statement to add a runaway watch item. 
-type AddQueryWatchStmt struct { - stmtNode - QueryWatchOptionList []*QueryWatchOption -} - -func (n *AddQueryWatchStmt) Restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("QUERY WATCH ADD") - for i, option := range n.QueryWatchOptionList { - ctx.WritePlain(" ") - if err := option.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while splicing QueryWatchOptionList: [%v]", i) - } - } - return nil -} - -// Accept implements Node Accept interface. -func (n *AddQueryWatchStmt) Accept(v Visitor) (Node, bool) { - newNode, _ := v.Enter(n) - n = newNode.(*AddQueryWatchStmt) - for _, val := range n.QueryWatchOptionList { - _, ok := val.Accept(v) - if !ok { - return n, false - } - } - return v.Leave(n) -} - -type QueryWatchOptionType int - -const ( - QueryWatchResourceGroup QueryWatchOptionType = iota - QueryWatchAction - QueryWatchType -) - -// QueryWatchOption is used for parsing manual management of watching runaway queries option. -type QueryWatchOption struct { - stmtNode - Tp QueryWatchOptionType - ResourceGroupOption *QueryWatchResourceGroupOption - ActionOption *ResourceGroupRunawayActionOption - TextOption *QueryWatchTextOption -} - -// Restore implements Node interface. -func (n *QueryWatchOption) Restore(ctx *format.RestoreCtx) error { - switch n.Tp { - case QueryWatchResourceGroup: - return n.ResourceGroupOption.restore(ctx) - case QueryWatchAction: - return n.ActionOption.Restore(ctx) - case QueryWatchType: - return n.TextOption.Restore(ctx) - } - return nil -} - -// Accept implements Node Accept interface. -func (n *QueryWatchOption) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*QueryWatchOption) - if n.ResourceGroupOption != nil && n.ResourceGroupOption.GroupNameExpr != nil { - node, ok := n.ResourceGroupOption.GroupNameExpr.Accept(v) - if !ok { - return n, false - } - n.ResourceGroupOption.GroupNameExpr = node.(ExprNode) - } - if n.ActionOption != nil { - node, ok := n.ActionOption.Accept(v) - if !ok { - return n, false - } - n.ActionOption = node.(*ResourceGroupRunawayActionOption) - } - if n.TextOption != nil { - node, ok := n.TextOption.Accept(v) - if !ok { - return n, false - } - n.TextOption = node.(*QueryWatchTextOption) - } - return v.Leave(n) -} - -func CheckQueryWatchAppend(ops []*QueryWatchOption, newOp *QueryWatchOption) bool { - for _, op := range ops { - if op.Tp == newOp.Tp { - return false - } - } - return true -} - -// QueryWatchResourceGroupOption is used for parsing the query watch resource group name. -type QueryWatchResourceGroupOption struct { - GroupNameStr model.CIStr - GroupNameExpr ExprNode -} - -func (n *QueryWatchResourceGroupOption) restore(ctx *format.RestoreCtx) error { - ctx.WriteKeyWord("RESOURCE GROUP ") - if n.GroupNameExpr != nil { - if err := n.GroupNameExpr.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while splicing ExprValue: [%v]", n.GroupNameExpr) - } - } else { - ctx.WriteName(n.GroupNameStr.String()) - } - return nil -} - -// QueryWatchTextOption is used for parsing the query watch text option. -type QueryWatchTextOption struct { - node - Type model.RunawayWatchType - PatternExpr ExprNode - TypeSpecified bool -} - -// Restore implements Node interface. 
-func (n *QueryWatchTextOption) Restore(ctx *format.RestoreCtx) error { - if n.TypeSpecified { - ctx.WriteKeyWord("SQL TEXT ") - ctx.WriteKeyWord(n.Type.String()) - ctx.WriteKeyWord(" TO ") - } else { - switch n.Type { - case model.WatchSimilar: - ctx.WriteKeyWord("SQL DIGEST ") - case model.WatchPlan: - ctx.WriteKeyWord("PLAN DIGEST ") - } - } - if err := n.PatternExpr.Restore(ctx); err != nil { - return errors.Annotatef(err, "An error occurred while splicing ExprValue: [%v]", n.PatternExpr) - } - return nil -} - -// Accept implements Node Accept interface. -func (n *QueryWatchTextOption) Accept(v Visitor) (Node, bool) { - newNode, skipChildren := v.Enter(n) - if skipChildren { - return v.Leave(newNode) - } - n = newNode.(*QueryWatchTextOption) - if n.PatternExpr != nil { - node, ok := n.PatternExpr.Accept(v) - if !ok { - return n, false - } - n.PatternExpr = node.(ExprNode) - } - return v.Leave(n) -} diff --git a/pkg/planner/binding__failpoint_binding__.go b/pkg/planner/binding__failpoint_binding__.go deleted file mode 100644 index fc6da4ff0bb13..0000000000000 --- a/pkg/planner/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package planner - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/planner/cardinality/binding__failpoint_binding__.go b/pkg/planner/cardinality/binding__failpoint_binding__.go deleted file mode 100644 index f8cd42267a934..0000000000000 --- a/pkg/planner/cardinality/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package cardinality - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/planner/cardinality/row_count_index.go b/pkg/planner/cardinality/row_count_index.go index 72dd3014fc89a..bc0f3c3226090 100644 --- a/pkg/planner/cardinality/row_count_index.go +++ b/pkg/planner/cardinality/row_count_index.go @@ -486,10 +486,10 @@ func expBackoffEstimation(sctx context.PlanContext, idx *statistics.Index, coll // Sort them. slices.Sort(singleColumnEstResults) l := len(singleColumnEstResults) - if _, _err_ := failpoint.Eval(_curpkg_("cleanEstResults")); _err_ == nil { + failpoint.Inject("cleanEstResults", func() { singleColumnEstResults = singleColumnEstResults[:0] l = 0 - } + }) if l == 1 { return singleColumnEstResults[0], true, nil } else if l == 0 { diff --git a/pkg/planner/cardinality/row_count_index.go__failpoint_stash__ b/pkg/planner/cardinality/row_count_index.go__failpoint_stash__ deleted file mode 100644 index bc0f3c3226090..0000000000000 --- a/pkg/planner/cardinality/row_count_index.go__failpoint_stash__ +++ /dev/null @@ -1,568 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cardinality - -import ( - "bytes" - "math" - "slices" - "strings" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/planner/context" - "github.com/pingcap/tidb/pkg/planner/util/debugtrace" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - "github.com/pingcap/tidb/pkg/statistics" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/collate" - "github.com/pingcap/tidb/pkg/util/mathutil" - "github.com/pingcap/tidb/pkg/util/ranger" -) - -// GetRowCountByIndexRanges estimates the row count by a slice of Range. -func GetRowCountByIndexRanges(sctx context.PlanContext, coll *statistics.HistColl, idxID int64, indexRanges []*ranger.Range) (result float64, err error) { - var name string - if sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { - debugtrace.EnterContextCommon(sctx) - debugTraceGetRowCountInput(sctx, idxID, indexRanges) - defer func() { - debugtrace.RecordAnyValuesWithNames(sctx, "Name", name, "Result", result) - debugtrace.LeaveContextCommon(sctx) - }() - } - sc := sctx.GetSessionVars().StmtCtx - idx := coll.GetIdx(idxID) - colNames := make([]string, 0, 8) - if idx != nil { - if idx.Info != nil { - name = idx.Info.Name.O - for _, col := range idx.Info.Columns { - colNames = append(colNames, col.Name.O) - } - } - } - recordUsedItemStatsStatus(sctx, idx, coll.PhysicalID, idxID) - if statistics.IndexStatsIsInvalid(sctx, idx, coll, idxID) { - colsLen := -1 - if idx != nil && idx.Info.Unique { - colsLen = len(idx.Info.Columns) - } - result, err = getPseudoRowCountByIndexRanges(sc.TypeCtx(), indexRanges, float64(coll.RealtimeCount), colsLen) - if err == nil && sc.EnableOptimizerCETrace && idx != nil { - ceTraceRange(sctx, coll.PhysicalID, colNames, indexRanges, "Index Stats-Pseudo", uint64(result)) - } - return result, err - } - realtimeCnt, modifyCount := coll.GetScaledRealtimeAndModifyCnt(idx) - if sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { - debugtrace.RecordAnyValuesWithNames(sctx, - "Histogram NotNull Count", idx.Histogram.NotNullCount(), - "TopN total count", idx.TopN.TotalCount(), - "Increase Factor", idx.GetIncreaseFactor(realtimeCnt), - ) - } - if idx.CMSketch != nil && idx.StatsVer == statistics.Version1 { - result, err = getIndexRowCountForStatsV1(sctx, coll, idxID, indexRanges) - } else { - result, err = getIndexRowCountForStatsV2(sctx, idx, coll, indexRanges, realtimeCnt, modifyCount) - } - if sc.EnableOptimizerCETrace { - ceTraceRange(sctx, coll.PhysicalID, colNames, indexRanges, "Index Stats", uint64(result)) - } - return result, errors.Trace(err) -} - -func getIndexRowCountForStatsV1(sctx context.PlanContext, coll *statistics.HistColl, idxID int64, indexRanges []*ranger.Range) (float64, error) { - sc := sctx.GetSessionVars().StmtCtx - debugTrace := sc.EnableOptimizerDebugTrace - if debugTrace { - debugtrace.EnterContextCommon(sctx) - defer debugtrace.LeaveContextCommon(sctx) - } - idx := coll.GetIdx(idxID) - totalCount := 
float64(0) - for _, ran := range indexRanges { - if debugTrace { - debugTraceStartEstimateRange(sctx, ran, nil, nil, totalCount) - } - rangePosition := getOrdinalOfRangeCond(sc, ran) - var rangeVals []types.Datum - // Try to enum the last range values. - if rangePosition != len(ran.LowVal) { - rangeVals = statistics.EnumRangeValues(ran.LowVal[rangePosition], ran.HighVal[rangePosition], ran.LowExclude, ran.HighExclude) - if rangeVals != nil { - rangePosition++ - } - } - // If first one is range, just use the previous way to estimate; if it is [NULL, NULL] range - // on single-column index, use previous way as well, because CMSketch does not contain null - // values in this case. - if rangePosition == 0 || isSingleColIdxNullRange(idx, ran) { - realtimeCnt, modifyCount := coll.GetScaledRealtimeAndModifyCnt(idx) - count, err := getIndexRowCountForStatsV2(sctx, idx, nil, []*ranger.Range{ran}, realtimeCnt, modifyCount) - if err != nil { - return 0, errors.Trace(err) - } - if debugTrace { - debugTraceEndEstimateRange(sctx, count, debugTraceRange) - } - totalCount += count - continue - } - var selectivity float64 - // use CM Sketch to estimate the equal conditions - if rangeVals == nil { - bytes, err := codec.EncodeKey(sc.TimeZone(), nil, ran.LowVal[:rangePosition]...) - err = sc.HandleError(err) - if err != nil { - return 0, errors.Trace(err) - } - selectivity, err = getEqualCondSelectivity(sctx, coll, idx, bytes, rangePosition, ran) - if err != nil { - return 0, errors.Trace(err) - } - } else { - bytes, err := codec.EncodeKey(sc.TimeZone(), nil, ran.LowVal[:rangePosition-1]...) - err = sc.HandleError(err) - if err != nil { - return 0, errors.Trace(err) - } - prefixLen := len(bytes) - for _, val := range rangeVals { - bytes = bytes[:prefixLen] - bytes, err = codec.EncodeKey(sc.TimeZone(), bytes, val) - err = sc.HandleError(err) - if err != nil { - return 0, err - } - res, err := getEqualCondSelectivity(sctx, coll, idx, bytes, rangePosition, ran) - if err != nil { - return 0, errors.Trace(err) - } - selectivity += res - } - } - // use histogram to estimate the range condition - if rangePosition != len(ran.LowVal) { - rang := ranger.Range{ - LowVal: []types.Datum{ran.LowVal[rangePosition]}, - LowExclude: ran.LowExclude, - HighVal: []types.Datum{ran.HighVal[rangePosition]}, - HighExclude: ran.HighExclude, - Collators: []collate.Collator{ran.Collators[rangePosition]}, - } - var count float64 - var err error - colUniqueIDs := coll.Idx2ColUniqueIDs[idxID] - var colUniqueID int64 - if rangePosition >= len(colUniqueIDs) { - colUniqueID = -1 - } else { - colUniqueID = colUniqueIDs[rangePosition] - } - // prefer index stats over column stats - if idxIDs, ok := coll.ColUniqueID2IdxIDs[colUniqueID]; ok && len(idxIDs) > 0 { - idxID := idxIDs[0] - count, err = GetRowCountByIndexRanges(sctx, coll, idxID, []*ranger.Range{&rang}) - } else { - count, err = GetRowCountByColumnRanges(sctx, coll, colUniqueID, []*ranger.Range{&rang}) - } - if err != nil { - return 0, errors.Trace(err) - } - selectivity = selectivity * count / idx.TotalRowCount() - } - count := selectivity * idx.TotalRowCount() - if debugTrace { - debugTraceEndEstimateRange(sctx, count, debugTraceRange) - } - totalCount += count - } - if totalCount > idx.TotalRowCount() { - totalCount = idx.TotalRowCount() - } - return totalCount, nil -} - -// isSingleColIdxNullRange checks if a range is [NULL, NULL] on a single-column index. 
-func isSingleColIdxNullRange(idx *statistics.Index, ran *ranger.Range) bool { - if len(idx.Info.Columns) > 1 { - return false - } - l, h := ran.LowVal[0], ran.HighVal[0] - if l.IsNull() && h.IsNull() { - return true - } - return false -} - -// It uses the modifyCount to adjust the influence of modifications on the table. -func getIndexRowCountForStatsV2(sctx context.PlanContext, idx *statistics.Index, coll *statistics.HistColl, indexRanges []*ranger.Range, realtimeRowCount, modifyCount int64) (float64, error) { - sc := sctx.GetSessionVars().StmtCtx - debugTrace := sc.EnableOptimizerDebugTrace - if debugTrace { - debugtrace.EnterContextCommon(sctx) - defer debugtrace.LeaveContextCommon(sctx) - } - totalCount := float64(0) - isSingleColIdx := len(idx.Info.Columns) == 1 - for _, indexRange := range indexRanges { - var count float64 - lb, err := codec.EncodeKey(sc.TimeZone(), nil, indexRange.LowVal...) - err = sc.HandleError(err) - if err != nil { - return 0, err - } - rb, err := codec.EncodeKey(sc.TimeZone(), nil, indexRange.HighVal...) - err = sc.HandleError(err) - if err != nil { - return 0, err - } - if debugTrace { - debugTraceStartEstimateRange(sctx, indexRange, lb, rb, totalCount) - } - fullLen := len(indexRange.LowVal) == len(indexRange.HighVal) && len(indexRange.LowVal) == len(idx.Info.Columns) - if bytes.Equal(lb, rb) { - // case 1: it's a point - if indexRange.LowExclude || indexRange.HighExclude { - if debugTrace { - debugTraceEndEstimateRange(sctx, 0, debugTraceImpossible) - } - continue - } - if fullLen { - // At most 1 in this case. - if idx.Info.Unique { - totalCount++ - if debugTrace { - debugTraceEndEstimateRange(sctx, 1, debugTraceUniquePoint) - } - continue - } - count = equalRowCountOnIndex(sctx, idx, lb, realtimeRowCount, modifyCount) - // If the current table row count has changed, we should scale the row count accordingly. - count *= idx.GetIncreaseFactor(realtimeRowCount) - if debugTrace { - debugTraceEndEstimateRange(sctx, count, debugTracePoint) - } - totalCount += count - continue - } - } - - // case 2: it's an interval - // The final interval is [low, high) - if indexRange.LowExclude { - lb = kv.Key(lb).PrefixNext() - } - if !indexRange.HighExclude { - rb = kv.Key(rb).PrefixNext() - } - l := types.NewBytesDatum(lb) - r := types.NewBytesDatum(rb) - lowIsNull := bytes.Equal(lb, nullKeyBytes) - if isSingleColIdx && lowIsNull { - count += float64(idx.Histogram.NullCount) - } - expBackoffSuccess := false - // Due to the limitation of calcFraction and convertDatumToScalar, the histogram actually won't estimate anything. - // If the first column's range is point. - if rangePosition := getOrdinalOfRangeCond(sc, indexRange); rangePosition > 0 && idx.StatsVer >= statistics.Version2 && coll != nil { - var expBackoffSel float64 - expBackoffSel, expBackoffSuccess, err = expBackoffEstimation(sctx, idx, coll, indexRange) - if err != nil { - return 0, err - } - if expBackoffSuccess { - expBackoffCnt := expBackoffSel * idx.TotalRowCount() - - upperLimit := expBackoffCnt - // Use the multi-column stats to calculate the max possible row count of [l, r) - if idx.Histogram.Len() > 0 { - _, lowerBkt, _, _ := idx.Histogram.LocateBucket(sctx, l) - _, upperBkt, _, _ := idx.Histogram.LocateBucket(sctx, r) - if debugTrace { - statistics.DebugTraceBuckets(sctx, &idx.Histogram, []int{lowerBkt - 1, upperBkt}) - } - // Use Count of the Bucket before l as the lower bound. 
- preCount := float64(0) - if lowerBkt > 0 { - preCount = float64(idx.Histogram.Buckets[lowerBkt-1].Count) - } - // Use Count of the Bucket where r exists as the upper bound. - upperCnt := float64(idx.Histogram.Buckets[upperBkt].Count) - upperLimit = upperCnt - preCount - upperLimit += float64(idx.TopN.BetweenCount(sctx, lb, rb)) - } - - // If the result of exponential backoff strategy is larger than the result from multi-column stats, - // use the upper limit from multi-column histogram instead. - if expBackoffCnt > upperLimit { - expBackoffCnt = upperLimit - } - count += expBackoffCnt - } - } - if !expBackoffSuccess { - count += betweenRowCountOnIndex(sctx, idx, l, r) - } - - // If the current table row count has changed, we should scale the row count accordingly. - increaseFactor := idx.GetIncreaseFactor(realtimeRowCount) - count *= increaseFactor - - // handling the out-of-range part - if (outOfRangeOnIndex(idx, l) && !(isSingleColIdx && lowIsNull)) || outOfRangeOnIndex(idx, r) { - histNDV := idx.NDV - // Exclude the TopN in Stats Version 2 - if idx.StatsVer == statistics.Version2 { - c := coll.GetCol(idx.Histogram.ID) - // If this is single column of a multi-column index - use the column's NDV rather than index NDV - isSingleColRange := len(indexRange.LowVal) == len(indexRange.HighVal) && len(indexRange.LowVal) == 1 - if isSingleColRange && !isSingleColIdx && c != nil && c.Histogram.NDV > 0 { - histNDV = c.Histogram.NDV - int64(c.TopN.Num()) - } else { - histNDV -= int64(idx.TopN.Num()) - } - } - count += idx.Histogram.OutOfRangeRowCount(sctx, &l, &r, modifyCount, histNDV, increaseFactor) - } - - if debugTrace { - debugTraceEndEstimateRange(sctx, count, debugTraceRange) - } - totalCount += count - } - // Don't allow the final result to go below 1 row - totalCount = mathutil.Clamp(totalCount, 1, float64(realtimeRowCount)) - return totalCount, nil -} - -var nullKeyBytes, _ = codec.EncodeKey(time.UTC, nil, types.NewDatum(nil)) - -func equalRowCountOnIndex(sctx context.PlanContext, idx *statistics.Index, b []byte, realtimeRowCount, modifyCount int64) (result float64) { - if sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { - debugtrace.EnterContextCommon(sctx) - debugtrace.RecordAnyValuesWithNames(sctx, "Encoded Value", b) - defer func() { - debugtrace.RecordAnyValuesWithNames(sctx, "Result", result) - debugtrace.LeaveContextCommon(sctx) - }() - } - if len(idx.Info.Columns) == 1 { - if bytes.Equal(b, nullKeyBytes) { - return float64(idx.Histogram.NullCount) - } - } - val := types.NewBytesDatum(b) - if idx.StatsVer < statistics.Version2 { - if idx.Histogram.NDV > 0 && outOfRangeOnIndex(idx, val) { - return outOfRangeEQSelectivity(sctx, idx.Histogram.NDV, realtimeRowCount, int64(idx.TotalRowCount())) * idx.TotalRowCount() - } - if idx.CMSketch != nil { - return float64(idx.QueryBytes(sctx, b)) - } - histRowCount, _ := idx.Histogram.EqualRowCount(sctx, val, false) - return histRowCount - } - // stats version == 2 - // 1. try to find this value in TopN - if idx.TopN != nil { - count, found := idx.TopN.QueryTopN(sctx, b) - if found { - return float64(count) - } - } - // 2. try to find this value in bucket.Repeat(the last value in every bucket) - histCnt, matched := idx.Histogram.EqualRowCount(sctx, val, true) - if matched { - return histCnt - } - // 3. 
use uniform distribution assumption for the rest (even when this value is not covered by the range of stats) - histNDV := float64(idx.Histogram.NDV - int64(idx.TopN.Num())) - if histNDV <= 0 { - // If the table hasn't been modified, it's safe to return 0. Otherwise, the TopN could be stale - return 1. - if modifyCount == 0 { - return 0 - } - return 1 - } - return idx.Histogram.NotNullCount() / histNDV -} - -// expBackoffEstimation estimate the multi-col cases following the Exponential Backoff. See comment below for details. -func expBackoffEstimation(sctx context.PlanContext, idx *statistics.Index, coll *statistics.HistColl, indexRange *ranger.Range) (sel float64, success bool, err error) { - if sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { - debugtrace.EnterContextCommon(sctx) - defer func() { - debugtrace.RecordAnyValuesWithNames(sctx, - "Result", sel, - "Success", success, - "error", err, - ) - debugtrace.LeaveContextCommon(sctx) - }() - } - tmpRan := []*ranger.Range{ - { - LowVal: make([]types.Datum, 1), - HighVal: make([]types.Datum, 1), - Collators: make([]collate.Collator, 1), - }, - } - colsIDs := coll.Idx2ColUniqueIDs[idx.Histogram.ID] - singleColumnEstResults := make([]float64, 0, len(indexRange.LowVal)) - // The following codes uses Exponential Backoff to reduce the impact of independent assumption. It works like: - // 1. Calc the selectivity of each column. - // 2. Sort them and choose the first 4 most selective filter and the corresponding selectivity is sel_1, sel_2, sel_3, sel_4 where i < j => sel_i < sel_j. - // 3. The final selectivity would be sel_1 * sel_2^{1/2} * sel_3^{1/4} * sel_4^{1/8}. - // This calculation reduced the independence assumption and can work well better than it. - for i := 0; i < len(indexRange.LowVal); i++ { - tmpRan[0].LowVal[0] = indexRange.LowVal[i] - tmpRan[0].HighVal[0] = indexRange.HighVal[i] - tmpRan[0].Collators[0] = indexRange.Collators[0] - if i == len(indexRange.LowVal)-1 { - tmpRan[0].LowExclude = indexRange.LowExclude - tmpRan[0].HighExclude = indexRange.HighExclude - } - colID := colsIDs[i] - var ( - count float64 - selectivity float64 - err error - foundStats bool - ) - if !statistics.ColumnStatsIsInvalid(coll.GetCol(colID), sctx, coll, colID) { - foundStats = true - count, err = GetRowCountByColumnRanges(sctx, coll, colID, tmpRan) - selectivity = count / float64(coll.RealtimeCount) - } - if idxIDs, ok := coll.ColUniqueID2IdxIDs[colID]; ok && !foundStats && len(indexRange.LowVal) > 1 { - // Note the `len(indexRange.LowVal) > 1` condition here, it means we only recursively call - // `GetRowCountByIndexRanges()` when the input `indexRange` is a multi-column range. This - // check avoids infinite recursion. - for _, idxID := range idxIDs { - if idxID == idx.Histogram.ID { - continue - } - idxStats := coll.GetIdx(idxID) - if idxStats == nil || statistics.IndexStatsIsInvalid(sctx, idxStats, coll, idxID) { - continue - } - foundStats = true - count, err = GetRowCountByIndexRanges(sctx, coll, idxID, tmpRan) - if err == nil { - break - } - realtimeCnt, _ := coll.GetScaledRealtimeAndModifyCnt(idxStats) - selectivity = count / float64(realtimeCnt) - } - } - if !foundStats { - continue - } - if err != nil { - return 0, false, err - } - singleColumnEstResults = append(singleColumnEstResults, selectivity) - } - // Sort them. 
- slices.Sort(singleColumnEstResults) - l := len(singleColumnEstResults) - failpoint.Inject("cleanEstResults", func() { - singleColumnEstResults = singleColumnEstResults[:0] - l = 0 - }) - if l == 1 { - return singleColumnEstResults[0], true, nil - } else if l == 0 { - return 0, false, nil - } - // Do not allow the exponential backoff to go below the available index bound. If the number of predicates - // is less than the number of index columns - use 90% of the bound to differentiate a subset from full index match. - // If there is an individual column selectivity that goes below this bound, use that selectivity only. - histNDV := coll.RealtimeCount - if idx.NDV > 0 { - histNDV = idx.NDV - } - idxLowBound := 1 / float64(min(histNDV, coll.RealtimeCount)) - if l < len(idx.Info.Columns) { - idxLowBound /= 0.9 - } - minTwoCol := min(singleColumnEstResults[0], singleColumnEstResults[1], idxLowBound) - multTwoCol := singleColumnEstResults[0] * math.Sqrt(singleColumnEstResults[1]) - if l == 2 { - return max(minTwoCol, multTwoCol), true, nil - } - minThreeCol := min(minTwoCol, singleColumnEstResults[2]) - multThreeCol := multTwoCol * math.Sqrt(math.Sqrt(singleColumnEstResults[2])) - if l == 3 { - return max(minThreeCol, multThreeCol), true, nil - } - minFourCol := min(minThreeCol, singleColumnEstResults[3]) - multFourCol := multThreeCol * math.Sqrt(math.Sqrt(math.Sqrt(singleColumnEstResults[3]))) - return max(minFourCol, multFourCol), true, nil -} - -// outOfRangeOnIndex checks if the datum is out of the range. -func outOfRangeOnIndex(idx *statistics.Index, val types.Datum) bool { - if !idx.Histogram.OutOfRange(val) { - return false - } - if idx.Histogram.Len() > 0 && matchPrefix(idx.Histogram.Bounds.GetRow(0), 0, &val) { - return false - } - return true -} - -// matchPrefix checks whether ad is the prefix of value -func matchPrefix(row chunk.Row, colIdx int, ad *types.Datum) bool { - switch ad.Kind() { - case types.KindString, types.KindBytes, types.KindBinaryLiteral, types.KindMysqlBit: - return strings.HasPrefix(row.GetString(colIdx), ad.GetString()) - } - return false -} - -// betweenRowCountOnIndex estimates the row count for interval [l, r). -// The input sctx is just for debug trace, you can pass nil safely if that's not needed. -func betweenRowCountOnIndex(sctx context.PlanContext, idx *statistics.Index, l, r types.Datum) float64 { - histBetweenCnt := idx.Histogram.BetweenRowCount(sctx, l, r) - if idx.StatsVer == statistics.Version1 { - return histBetweenCnt - } - return float64(idx.TopN.BetweenCount(sctx, l.GetBytes(), r.GetBytes())) + histBetweenCnt -} - -// getOrdinalOfRangeCond gets the ordinal of the position range condition, -// if not exist, it returns the end position. 
-func getOrdinalOfRangeCond(sc *stmtctx.StatementContext, ran *ranger.Range) int { - for i := range ran.LowVal { - a, b := ran.LowVal[i], ran.HighVal[i] - cmp, err := a.Compare(sc.TypeCtx(), &b, ran.Collators[0]) - if err != nil { - return 0 - } - if cmp != 0 { - return i - } - } - return len(ran.LowVal) -} diff --git a/pkg/planner/core/binding__failpoint_binding__.go b/pkg/planner/core/binding__failpoint_binding__.go deleted file mode 100644 index fd84c40441f21..0000000000000 --- a/pkg/planner/core/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package core - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/planner/core/collect_column_stats_usage.go b/pkg/planner/core/collect_column_stats_usage.go index 060a764789d42..54c427c82289f 100644 --- a/pkg/planner/core/collect_column_stats_usage.go +++ b/pkg/planner/core/collect_column_stats_usage.go @@ -190,9 +190,9 @@ func (c *columnStatsUsageCollector) addHistNeededColumns(ds *DataSource) { stats := domain.GetDomain(ds.SCtx()).StatsHandle() tblStats := stats.GetPartitionStats(ds.TableInfo, ds.PhysicalTableID) skipPseudoCheckForTest := false - if _, _err_ := failpoint.Eval(_curpkg_("disablePseudoCheck")); _err_ == nil { + failpoint.Inject("disablePseudoCheck", func() { skipPseudoCheckForTest = true - } + }) // Since we can not get the stats tbl, this table is not analyzed. So we don't need to consider load stats. if tblStats.Pseudo && !skipPseudoCheckForTest { return diff --git a/pkg/planner/core/collect_column_stats_usage.go__failpoint_stash__ b/pkg/planner/core/collect_column_stats_usage.go__failpoint_stash__ deleted file mode 100644 index 54c427c82289f..0000000000000 --- a/pkg/planner/core/collect_column_stats_usage.go__failpoint_stash__ +++ /dev/null @@ -1,456 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package core - -import ( - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/statistics/asyncload" - "github.com/pingcap/tidb/pkg/util/filter" - "github.com/pingcap/tidb/pkg/util/intset" - "golang.org/x/exp/maps" -) - -const ( - collectPredicateColumns uint64 = 1 << iota - collectHistNeededColumns -) - -// columnStatsUsageCollector collects predicate columns and/or histogram-needed columns from logical plan. 
-// Predicate columns are the columns whose statistics are utilized when making query plans, which usually occur in where conditions, join conditions and so on. -// Histogram-needed columns are the columns whose histograms are utilized when making query plans, which usually occur in the conditions pushed down to DataSource. -// The set of histogram-needed columns is the subset of that of predicate columns. -type columnStatsUsageCollector struct { - // collectMode indicates whether to collect predicate columns and/or histogram-needed columns - collectMode uint64 - // predicateCols records predicate columns. - predicateCols map[model.TableItemID]struct{} - // colMap maps expression.Column.UniqueID to the table columns whose statistics may be utilized to calculate statistics of the column. - // It is used for collecting predicate columns. - // For example, in `select count(distinct a, b) as e from t`, the count of column `e` is calculated as `max(ndv(t.a), ndv(t.b))` if - // we don't know `ndv(t.a, t.b)`(see (*LogicalAggregation).DeriveStats and getColsNDV for details). So when calculating the statistics - // of column `e`, we may use the statistics of column `t.a` and `t.b`. - colMap map[int64]map[model.TableItemID]struct{} - // histNeededCols records histogram-needed columns. The value field of the map indicates that whether we need to load the full stats of the time or not. - histNeededCols map[model.TableItemID]bool - // cols is used to store columns collected from expressions and saves some allocation. - cols []*expression.Column - - // visitedPhysTblIDs all ds.PhysicalTableID that have been visited. - // It's always collected, even collectHistNeededColumns is not set. - visitedPhysTblIDs *intset.FastIntSet - - // collectVisitedTable indicates whether to collect visited table - collectVisitedTable bool - // visitedtbls indicates the visited table - visitedtbls map[int64]struct{} -} - -func newColumnStatsUsageCollector(collectMode uint64, enabledPlanCapture bool) *columnStatsUsageCollector { - set := intset.NewFastIntSet() - collector := &columnStatsUsageCollector{ - collectMode: collectMode, - // Pre-allocate a slice to reduce allocation, 8 doesn't have special meaning. - cols: make([]*expression.Column, 0, 8), - visitedPhysTblIDs: &set, - } - if collectMode&collectPredicateColumns != 0 { - collector.predicateCols = make(map[model.TableItemID]struct{}) - collector.colMap = make(map[int64]map[model.TableItemID]struct{}) - } - if collectMode&collectHistNeededColumns != 0 { - collector.histNeededCols = make(map[model.TableItemID]bool) - } - if enabledPlanCapture { - collector.collectVisitedTable = true - collector.visitedtbls = map[int64]struct{}{} - } - return collector -} - -func (c *columnStatsUsageCollector) addPredicateColumn(col *expression.Column) { - tblColIDs, ok := c.colMap[col.UniqueID] - if !ok { - // It may happen if some leaf of logical plan is LogicalMemTable/LogicalShow/LogicalShowDDLJobs. 
- return - } - for tblColID := range tblColIDs { - c.predicateCols[tblColID] = struct{}{} - } -} - -func (c *columnStatsUsageCollector) addPredicateColumnsFromExpressions(list []expression.Expression) { - cols := expression.ExtractColumnsAndCorColumnsFromExpressions(c.cols[:0], list) - for _, col := range cols { - c.addPredicateColumn(col) - } -} - -func (c *columnStatsUsageCollector) updateColMap(col *expression.Column, relatedCols []*expression.Column) { - if _, ok := c.colMap[col.UniqueID]; !ok { - c.colMap[col.UniqueID] = map[model.TableItemID]struct{}{} - } - for _, relatedCol := range relatedCols { - tblColIDs, ok := c.colMap[relatedCol.UniqueID] - if !ok { - // It may happen if some leaf of logical plan is LogicalMemTable/LogicalShow/LogicalShowDDLJobs. - continue - } - for tblColID := range tblColIDs { - c.colMap[col.UniqueID][tblColID] = struct{}{} - } - } -} - -func (c *columnStatsUsageCollector) updateColMapFromExpressions(col *expression.Column, list []expression.Expression) { - c.updateColMap(col, expression.ExtractColumnsAndCorColumnsFromExpressions(c.cols[:0], list)) -} - -func (c *columnStatsUsageCollector) collectPredicateColumnsForDataSource(ds *DataSource) { - // Skip all system tables. - if filter.IsSystemSchema(ds.DBName.L) { - return - } - // For partition tables, no matter whether it is static or dynamic pruning mode, we use table ID rather than partition ID to - // set TableColumnID.TableID. In this way, we keep the set of predicate columns consistent between different partitions and global table. - tblID := ds.TableInfo.ID - if c.collectVisitedTable { - c.visitedtbls[tblID] = struct{}{} - } - for _, col := range ds.Schema().Columns { - tblColID := model.TableItemID{TableID: tblID, ID: col.ID, IsIndex: false} - c.colMap[col.UniqueID] = map[model.TableItemID]struct{}{tblColID: {}} - } - // We should use `PushedDownConds` here. `AllConds` is used for partition pruning, which doesn't need stats. - c.addPredicateColumnsFromExpressions(ds.PushedDownConds) -} - -func (c *columnStatsUsageCollector) collectPredicateColumnsForJoin(p *LogicalJoin) { - // The only schema change is merging two schemas so there is no new column. - // Assume statistics of all the columns in EqualConditions/LeftConditions/RightConditions/OtherConditions are needed. - exprs := make([]expression.Expression, 0, len(p.EqualConditions)+len(p.LeftConditions)+len(p.RightConditions)+len(p.OtherConditions)) - for _, cond := range p.EqualConditions { - exprs = append(exprs, cond) - } - for _, cond := range p.LeftConditions { - exprs = append(exprs, cond) - } - for _, cond := range p.RightConditions { - exprs = append(exprs, cond) - } - for _, cond := range p.OtherConditions { - exprs = append(exprs, cond) - } - c.addPredicateColumnsFromExpressions(exprs) -} - -func (c *columnStatsUsageCollector) collectPredicateColumnsForUnionAll(p *LogicalUnionAll) { - // statistics of the ith column of UnionAll come from statistics of the ith column of each child. 
- schemas := make([]*expression.Schema, 0, len(p.Children())) - relatedCols := make([]*expression.Column, 0, len(p.Children())) - for _, child := range p.Children() { - schemas = append(schemas, child.Schema()) - } - for i, col := range p.Schema().Columns { - relatedCols = relatedCols[:0] - for j := range p.Children() { - relatedCols = append(relatedCols, schemas[j].Columns[i]) - } - c.updateColMap(col, relatedCols) - } -} - -func (c *columnStatsUsageCollector) addHistNeededColumns(ds *DataSource) { - c.visitedPhysTblIDs.Insert(int(ds.PhysicalTableID)) - if c.collectMode&collectHistNeededColumns == 0 { - return - } - if c.collectVisitedTable { - tblID := ds.TableInfo.ID - c.visitedtbls[tblID] = struct{}{} - } - stats := domain.GetDomain(ds.SCtx()).StatsHandle() - tblStats := stats.GetPartitionStats(ds.TableInfo, ds.PhysicalTableID) - skipPseudoCheckForTest := false - failpoint.Inject("disablePseudoCheck", func() { - skipPseudoCheckForTest = true - }) - // Since we can not get the stats tbl, this table is not analyzed. So we don't need to consider load stats. - if tblStats.Pseudo && !skipPseudoCheckForTest { - return - } - columns := expression.ExtractColumnsFromExpressions(c.cols[:0], ds.PushedDownConds, nil) - - colIDSet := intset.NewFastIntSet() - - for _, col := range columns { - // If the column is plan-generated one, Skip it. - // TODO: we may need to consider the ExtraHandle. - if col.ID < 0 { - continue - } - tblColID := model.TableItemID{TableID: ds.PhysicalTableID, ID: col.ID, IsIndex: false} - colIDSet.Insert(int(col.ID)) - c.histNeededCols[tblColID] = true - } - for _, column := range ds.TableInfo.Columns { - // If the column is plan-generated one, Skip it. - // TODO: we may need to consider the ExtraHandle. - if column.ID < 0 { - continue - } - if !column.Hidden { - tblColID := model.TableItemID{TableID: ds.PhysicalTableID, ID: column.ID, IsIndex: false} - if _, ok := c.histNeededCols[tblColID]; !ok { - c.histNeededCols[tblColID] = false - } - } - } -} - -func (c *columnStatsUsageCollector) collectFromPlan(lp base.LogicalPlan) { - for _, child := range lp.Children() { - c.collectFromPlan(child) - } - if c.collectMode&collectPredicateColumns != 0 { - switch x := lp.(type) { - case *DataSource: - c.collectPredicateColumnsForDataSource(x) - case *LogicalIndexScan: - c.collectPredicateColumnsForDataSource(x.Source) - c.addPredicateColumnsFromExpressions(x.AccessConds) - case *LogicalTableScan: - c.collectPredicateColumnsForDataSource(x.Source) - c.addPredicateColumnsFromExpressions(x.AccessConds) - case *logicalop.LogicalProjection: - // Schema change from children to self. - schema := x.Schema() - for i, expr := range x.Exprs { - c.updateColMapFromExpressions(schema.Columns[i], []expression.Expression{expr}) - } - case *LogicalSelection: - // Though the conditions in LogicalSelection are complex conditions which cannot be pushed down to DataSource, we still - // regard statistics of the columns in the conditions as needed. - c.addPredicateColumnsFromExpressions(x.Conditions) - case *LogicalAggregation: - // Just assume statistics of all the columns in GroupByItems are needed. - c.addPredicateColumnsFromExpressions(x.GroupByItems) - // Schema change from children to self. - schema := x.Schema() - for i, aggFunc := range x.AggFuncs { - c.updateColMapFromExpressions(schema.Columns[i], aggFunc.Args) - } - case *logicalop.LogicalWindow: - // Statistics of the columns in LogicalWindow.PartitionBy are used in optimizeByShuffle4Window. 
- // We don't use statistics of the columns in LogicalWindow.OrderBy currently. - for _, item := range x.PartitionBy { - c.addPredicateColumn(item.Col) - } - // Schema change from children to self. - windowColumns := x.GetWindowResultColumns() - for i, col := range windowColumns { - c.updateColMapFromExpressions(col, x.WindowFuncDescs[i].Args) - } - case *LogicalJoin: - c.collectPredicateColumnsForJoin(x) - case *LogicalApply: - c.collectPredicateColumnsForJoin(&x.LogicalJoin) - // Assume statistics of correlated columns are needed. - // Correlated columns can be found in LogicalApply.Children()[0].Schema(). Since we already visit LogicalApply.Children()[0], - // correlated columns must have existed in columnStatsUsageCollector.colMap. - for _, corCols := range x.CorCols { - c.addPredicateColumn(&corCols.Column) - } - case *logicalop.LogicalSort: - // Assume statistics of all the columns in ByItems are needed. - for _, item := range x.ByItems { - c.addPredicateColumnsFromExpressions([]expression.Expression{item.Expr}) - } - case *logicalop.LogicalTopN: - // Assume statistics of all the columns in ByItems are needed. - for _, item := range x.ByItems { - c.addPredicateColumnsFromExpressions([]expression.Expression{item.Expr}) - } - case *LogicalUnionAll: - c.collectPredicateColumnsForUnionAll(x) - case *LogicalPartitionUnionAll: - c.collectPredicateColumnsForUnionAll(&x.LogicalUnionAll) - case *LogicalCTE: - // Visit seedPartLogicalPlan and recursivePartLogicalPlan first. - c.collectFromPlan(x.Cte.seedPartLogicalPlan) - if x.Cte.recursivePartLogicalPlan != nil { - c.collectFromPlan(x.Cte.recursivePartLogicalPlan) - } - // Schema change from seedPlan/recursivePlan to self. - columns := x.Schema().Columns - seedColumns := x.Cte.seedPartLogicalPlan.Schema().Columns - var recursiveColumns []*expression.Column - if x.Cte.recursivePartLogicalPlan != nil { - recursiveColumns = x.Cte.recursivePartLogicalPlan.Schema().Columns - } - relatedCols := make([]*expression.Column, 0, 2) - for i, col := range columns { - relatedCols = append(relatedCols[:0], seedColumns[i]) - if recursiveColumns != nil { - relatedCols = append(relatedCols, recursiveColumns[i]) - } - c.updateColMap(col, relatedCols) - } - // If IsDistinct is true, then we use getColsNDV to calculate row count(see (*LogicalCTE).DeriveStat). In this case - // statistics of all the columns are needed. - if x.Cte.IsDistinct { - for _, col := range columns { - c.addPredicateColumn(col) - } - } - case *logicalop.LogicalCTETable: - // Schema change from seedPlan to self. - for i, col := range x.Schema().Columns { - c.updateColMap(col, []*expression.Column{x.SeedSchema.Columns[i]}) - } - } - } - // Histogram-needed columns are the columns which occur in the conditions pushed down to DataSource. - // We don't consider LogicalCTE because seedLogicalPlan and recursiveLogicalPlan haven't got logical optimization - // yet(seedLogicalPlan and recursiveLogicalPlan are optimized in DeriveStats phase). Without logical optimization, - // there is no condition pushed down to DataSource so no histogram-needed column can be collected. - // - // Since c.visitedPhysTblIDs is also collected here and needs to be collected even collectHistNeededColumns is not set, - // so we do the c.collectMode check in addHistNeededColumns() after collecting c.visitedPhysTblIDs. 
- switch x := lp.(type) { - case *DataSource: - c.addHistNeededColumns(x) - case *LogicalIndexScan: - c.addHistNeededColumns(x.Source) - case *LogicalTableScan: - c.addHistNeededColumns(x.Source) - } -} - -// CollectColumnStatsUsage collects column stats usage from logical plan. -// predicate indicates whether to collect predicate columns and histNeeded indicates whether to collect histogram-needed columns. -// First return value: predicate columns -// Second return value: histogram-needed columns (nil if histNeeded is false) -// Third return value: ds.PhysicalTableID from all DataSource (always collected) -func CollectColumnStatsUsage(lp base.LogicalPlan, histNeeded bool) ( - []model.TableItemID, - []model.StatsLoadItem, - *intset.FastIntSet, -) { - var mode uint64 - // Always collect predicate columns. - mode |= collectPredicateColumns - if histNeeded { - mode |= collectHistNeededColumns - } - collector := newColumnStatsUsageCollector(mode, lp.SCtx().GetSessionVars().IsPlanReplayerCaptureEnabled()) - collector.collectFromPlan(lp) - if collector.collectVisitedTable { - recordTableRuntimeStats(lp.SCtx(), collector.visitedtbls) - } - itemSet2slice := func(set map[model.TableItemID]bool) []model.StatsLoadItem { - ret := make([]model.StatsLoadItem, 0, len(set)) - for item, fullLoad := range set { - ret = append(ret, model.StatsLoadItem{TableItemID: item, FullLoad: fullLoad}) - } - return ret - } - is := lp.SCtx().GetInfoSchema().(infoschema.InfoSchema) - statsHandle := domain.GetDomain(lp.SCtx()).StatsHandle() - physTblIDsWithNeededCols := intset.NewFastIntSet() - for neededCol, fullLoad := range collector.histNeededCols { - if !fullLoad { - continue - } - physTblIDsWithNeededCols.Insert(int(neededCol.TableID)) - } - collector.visitedPhysTblIDs.ForEach(func(physicalTblID int) { - // 1. collect table metadata - tbl, _ := infoschema.FindTableByTblOrPartID(is, int64(physicalTblID)) - if tbl == nil { - return - } - - // 2. handle extra sync/async stats loading for the determinate mode - - // If we visited a table without getting any columns need stats (likely because there are no pushed down - // predicates), and we are in the determinate mode, we need to make sure we are able to get the "analyze row - // count" in getStatsTable(), which means any column/index stats are available. - if lp.SCtx().GetSessionVars().GetOptObjective() != variable.OptObjectiveDeterminate || - // If we already collected some columns that need trigger sync laoding on this table, we don't need to - // additionally do anything for determinate mode. - physTblIDsWithNeededCols.Has(physicalTblID) || - statsHandle == nil { - return - } - tblStats := statsHandle.GetTableStats(tbl.Meta()) - if tblStats == nil || tblStats.Pseudo { - return - } - var colToTriggerLoad *model.TableItemID - for _, col := range tbl.Cols() { - if col.State != model.StatePublic || (col.IsGenerated() && !col.GeneratedStored) || !tblStats.ColAndIdxExistenceMap.HasAnalyzed(col.ID, false) { - continue - } - if colStats := tblStats.GetCol(col.ID); colStats != nil { - // If any stats are already full loaded, we don't need to trigger stats loading on this table. - if colStats.IsFullLoad() { - colToTriggerLoad = nil - break - } - } - // Choose the first column we meet to trigger stats loading. 
- if colToTriggerLoad == nil { - colToTriggerLoad = &model.TableItemID{TableID: int64(physicalTblID), ID: col.ID, IsIndex: false} - } - } - if colToTriggerLoad == nil { - return - } - for _, idx := range tbl.Indices() { - if idx.Meta().State != model.StatePublic || idx.Meta().MVIndex { - continue - } - // If any stats are already full loaded, we don't need to trigger stats loading on this table. - if idxStats := tblStats.GetIdx(idx.Meta().ID); idxStats != nil && idxStats.IsFullLoad() { - colToTriggerLoad = nil - break - } - } - if colToTriggerLoad == nil { - return - } - if histNeeded { - collector.histNeededCols[*colToTriggerLoad] = true - } else { - asyncload.AsyncLoadHistogramNeededItems.Insert(*colToTriggerLoad, true) - } - }) - var ( - predicateCols []model.TableItemID - histNeededCols []model.StatsLoadItem - ) - predicateCols = maps.Keys(collector.predicateCols) - if histNeeded { - histNeededCols = itemSet2slice(collector.histNeededCols) - } - return predicateCols, histNeededCols, collector.visitedPhysTblIDs -} diff --git a/pkg/planner/core/debugtrace.go b/pkg/planner/core/debugtrace.go index 1f1dfdd68f196..babda2b1d4551 100644 --- a/pkg/planner/core/debugtrace.go +++ b/pkg/planner/core/debugtrace.go @@ -196,11 +196,11 @@ func debugTraceGetStatsTbl( Outdated: outdated, StatsTblInfo: statistics.TraceStatsTbl(statsTbl), } - if val, _err_ := failpoint.Eval(_curpkg_("DebugTraceStableStatsTbl")); _err_ == nil { + failpoint.Inject("DebugTraceStableStatsTbl", func(val failpoint.Value) { if val.(bool) { stabilizeGetStatsTblInfo(traceInfo) } - } + }) root.AppendStepToCurrentContext(traceInfo) } diff --git a/pkg/planner/core/debugtrace.go__failpoint_stash__ b/pkg/planner/core/debugtrace.go__failpoint_stash__ deleted file mode 100644 index babda2b1d4551..0000000000000 --- a/pkg/planner/core/debugtrace.go__failpoint_stash__ +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package core - -import ( - "strconv" - "strings" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/planner/context" - "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/planner/util/debugtrace" - "github.com/pingcap/tidb/pkg/statistics" - "github.com/pingcap/tidb/pkg/util/hint" -) - -/* - Below is debug trace for the received command from the client. - It records the input to the optimizer at the very beginning of query optimization. 
-*/ - -type receivedCmdInfo struct { - Command string - ExecutedASTText string - ExecuteStmtInfo *executeInfo -} - -type executeInfo struct { - PreparedSQL string - BinaryParamsInfo []binaryParamInfo - UseCursor bool -} - -type binaryParamInfo struct { - Type string - Value string -} - -func (info *binaryParamInfo) MarshalJSON() ([]byte, error) { - type binaryParamInfoForMarshal binaryParamInfo - infoForMarshal := new(binaryParamInfoForMarshal) - quote := `"` - // We only need the escape functionality of strconv.Quote, the quoting is not needed, - // so we trim the \" prefix and suffix here. - infoForMarshal.Type = strings.TrimSuffix( - strings.TrimPrefix( - strconv.Quote(info.Type), - quote), - quote) - infoForMarshal.Value = strings.TrimSuffix( - strings.TrimPrefix( - strconv.Quote(info.Value), - quote), - quote) - return debugtrace.EncodeJSONCommon(infoForMarshal) -} - -// DebugTraceReceivedCommand records the received command from the client to the debug trace. -func DebugTraceReceivedCommand(s base.PlanContext, cmd byte, stmtNode ast.StmtNode) { - sessionVars := s.GetSessionVars() - trace := debugtrace.GetOrInitDebugTraceRoot(s) - traceInfo := new(receivedCmdInfo) - trace.AppendStepWithNameToCurrentContext(traceInfo, "Received Command") - traceInfo.Command = mysql.Command2Str[cmd] - traceInfo.ExecutedASTText = stmtNode.Text() - - // Collect information for execute stmt, and record it in executeInfo. - var binaryParams []expression.Expression - var planCacheStmt *PlanCacheStmt - if execStmt, ok := stmtNode.(*ast.ExecuteStmt); ok { - if execStmt.PrepStmt != nil { - planCacheStmt, _ = execStmt.PrepStmt.(*PlanCacheStmt) - } - if execStmt.BinaryArgs != nil { - binaryParams, _ = execStmt.BinaryArgs.([]expression.Expression) - } - } - useCursor := sessionVars.HasStatusFlag(mysql.ServerStatusCursorExists) - // If none of them needs record, we don't need a executeInfo. - if binaryParams == nil && planCacheStmt == nil && !useCursor { - return - } - execInfo := &executeInfo{} - traceInfo.ExecuteStmtInfo = execInfo - execInfo.UseCursor = useCursor - if planCacheStmt != nil { - execInfo.PreparedSQL = planCacheStmt.StmtText - } - if len(binaryParams) > 0 { - execInfo.BinaryParamsInfo = make([]binaryParamInfo, len(binaryParams)) - for i, param := range binaryParams { - execInfo.BinaryParamsInfo[i].Type = param.GetType(s.GetExprCtx().GetEvalCtx()).String() - execInfo.BinaryParamsInfo[i].Value = param.StringWithCtx(s.GetExprCtx().GetEvalCtx(), errors.RedactLogDisable) - } - } -} - -/* - Below is debug trace for the hint that matches the current query. -*/ - -type bindingHint struct { - Hint *hint.HintsSet - trying bool -} - -func (b *bindingHint) MarshalJSON() ([]byte, error) { - tmp := make(map[string]string, 1) - hintStr, err := b.Hint.Restore() - if err != nil { - return debugtrace.EncodeJSONCommon(err) - } - if b.trying { - tmp["Trying Hint"] = hintStr - } else { - tmp["Best Hint"] = hintStr - } - return debugtrace.EncodeJSONCommon(tmp) -} - -// DebugTraceTryBinding records the hint that might be chosen to the debug trace. -func DebugTraceTryBinding(s context.PlanContext, binding *hint.HintsSet) { - root := debugtrace.GetOrInitDebugTraceRoot(s) - traceInfo := &bindingHint{ - Hint: binding, - trying: true, - } - root.AppendStepToCurrentContext(traceInfo) -} - -// DebugTraceBestBinding records the chosen hint to the debug trace. 
-func DebugTraceBestBinding(s context.PlanContext, binding *hint.HintsSet) { - root := debugtrace.GetOrInitDebugTraceRoot(s) - traceInfo := &bindingHint{ - Hint: binding, - trying: false, - } - root.AppendStepToCurrentContext(traceInfo) -} - -/* - Below is debug trace for getStatsTable(). - Part of the logic for collecting information is in statistics/debug_trace.go. -*/ - -type getStatsTblInfo struct { - TableName string - TblInfoID int64 - InputPhysicalID int64 - HandleIsNil bool - UsePartitionStats bool - CountIsZero bool - Uninitialized bool - Outdated bool - StatsTblInfo *statistics.StatsTblTraceInfo -} - -func debugTraceGetStatsTbl( - s base.PlanContext, - tblInfo *model.TableInfo, - pid int64, - handleIsNil, - usePartitionStats, - countIsZero, - uninitialized, - outdated bool, - statsTbl *statistics.Table, -) { - root := debugtrace.GetOrInitDebugTraceRoot(s) - traceInfo := &getStatsTblInfo{ - TableName: tblInfo.Name.O, - TblInfoID: tblInfo.ID, - InputPhysicalID: pid, - HandleIsNil: handleIsNil, - UsePartitionStats: usePartitionStats, - CountIsZero: countIsZero, - Uninitialized: uninitialized, - Outdated: outdated, - StatsTblInfo: statistics.TraceStatsTbl(statsTbl), - } - failpoint.Inject("DebugTraceStableStatsTbl", func(val failpoint.Value) { - if val.(bool) { - stabilizeGetStatsTblInfo(traceInfo) - } - }) - root.AppendStepToCurrentContext(traceInfo) -} - -// Only for test. -func stabilizeGetStatsTblInfo(info *getStatsTblInfo) { - info.TblInfoID = 100 - info.InputPhysicalID = 100 - tbl := info.StatsTblInfo - if tbl == nil { - return - } - tbl.PhysicalID = 100 - tbl.Version = 440930000000000000 - for _, col := range tbl.Columns { - col.LastUpdateVersion = 440930000000000000 - } - for _, idx := range tbl.Indexes { - idx.LastUpdateVersion = 440930000000000000 - } -} - -/* - Below is debug trace for AccessPath. 
-*/ - -type accessPathForDebugTrace struct { - IndexName string `json:",omitempty"` - AccessConditions []string - IndexFilters []string - TableFilters []string - PartialPaths []accessPathForDebugTrace `json:",omitempty"` - CountAfterAccess float64 - CountAfterIndex float64 -} - -func convertAccessPathForDebugTrace(ctx expression.EvalContext, path *util.AccessPath, out *accessPathForDebugTrace) { - if path.Index != nil { - out.IndexName = path.Index.Name.O - } - out.AccessConditions = expression.ExprsToStringsForDisplay(ctx, path.AccessConds) - out.IndexFilters = expression.ExprsToStringsForDisplay(ctx, path.IndexFilters) - out.TableFilters = expression.ExprsToStringsForDisplay(ctx, path.TableFilters) - out.CountAfterAccess = path.CountAfterAccess - out.CountAfterIndex = path.CountAfterIndex - out.PartialPaths = make([]accessPathForDebugTrace, len(path.PartialIndexPaths)) - for i, partialPath := range path.PartialIndexPaths { - convertAccessPathForDebugTrace(ctx, partialPath, &out.PartialPaths[i]) - } -} - -func debugTraceAccessPaths(s base.PlanContext, paths []*util.AccessPath) { - root := debugtrace.GetOrInitDebugTraceRoot(s) - traceInfo := make([]accessPathForDebugTrace, len(paths)) - for i, partialPath := range paths { - convertAccessPathForDebugTrace(s.GetExprCtx().GetEvalCtx(), partialPath, &traceInfo[i]) - } - root.AppendStepWithNameToCurrentContext(traceInfo, "Access paths") -} diff --git a/pkg/planner/core/encode.go b/pkg/planner/core/encode.go index 0ae4093bd5504..af67053afa477 100644 --- a/pkg/planner/core/encode.go +++ b/pkg/planner/core/encode.go @@ -41,12 +41,12 @@ func EncodeFlatPlan(flat *FlatPhysicalPlan) string { if flat.InExecute { return "" } - if val, _err_ := failpoint.Eval(_curpkg_("mockPlanRowCount")); _err_ == nil { + failpoint.Inject("mockPlanRowCount", func(val failpoint.Value) { selectPlan, _ := flat.Main.GetSelectPlan() for _, op := range selectPlan { op.Origin.StatsInfo().RowCount = float64(val.(int)) } - } + }) pn := encoderPool.Get().(*planEncoder) defer func() { pn.buf.Reset() @@ -164,9 +164,9 @@ func EncodePlan(p base.Plan) string { defer encoderPool.Put(pn) selectPlan := getSelectPlan(p) if selectPlan != nil { - if val, _err_ := failpoint.Eval(_curpkg_("mockPlanRowCount")); _err_ == nil { + failpoint.Inject("mockPlanRowCount", func(val failpoint.Value) { selectPlan.StatsInfo().RowCount = float64(val.(int)) - } + }) } return pn.encodePlanTree(p) } diff --git a/pkg/planner/core/encode.go__failpoint_stash__ b/pkg/planner/core/encode.go__failpoint_stash__ deleted file mode 100644 index af67053afa477..0000000000000 --- a/pkg/planner/core/encode.go__failpoint_stash__ +++ /dev/null @@ -1,386 +0,0 @@ -// Copyright 2019 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package core - -import ( - "bytes" - "crypto/sha256" - "hash" - "strconv" - "sync" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser" - "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/util/plancodec" -) - -// EncodeFlatPlan encodes a FlatPhysicalPlan with compression. -func EncodeFlatPlan(flat *FlatPhysicalPlan) string { - if len(flat.Main) == 0 { - return "" - } - // We won't collect the plan when we're in "EXPLAIN FOR" statement and the plan is from EXECUTE statement (please - // read comments of InExecute for details about the meaning of InExecute) because we are unable to get some - // necessary information when the execution of the plan is finished and some states in the session such as - // PreparedParams are cleaned. - // The behavior in BinaryPlanStrFromFlatPlan() is also the same. - if flat.InExecute { - return "" - } - failpoint.Inject("mockPlanRowCount", func(val failpoint.Value) { - selectPlan, _ := flat.Main.GetSelectPlan() - for _, op := range selectPlan { - op.Origin.StatsInfo().RowCount = float64(val.(int)) - } - }) - pn := encoderPool.Get().(*planEncoder) - defer func() { - pn.buf.Reset() - encoderPool.Put(pn) - }() - buf := pn.buf - buf.Reset() - opCount := len(flat.Main) - for _, cte := range flat.CTEs { - opCount += len(cte) - } - // assume an operator costs around 80 bytes, preallocate space for them - buf.Grow(80 * opCount) - encodeFlatPlanTree(flat.Main, 0, &buf) - for _, cte := range flat.CTEs { - fop := cte[0] - cteDef := cte[0].Origin.(*CTEDefinition) - id := cteDef.CTE.IDForStorage - tp := plancodec.TypeCTEDefinition - taskTypeInfo := plancodec.EncodeTaskType(fop.IsRoot, fop.StoreType) - p := fop.Origin - actRows, analyzeInfo, memoryInfo, diskInfo := getRuntimeInfoStr(p.SCtx(), p, nil) - var estRows float64 - if fop.IsPhysicalPlan { - estRows = fop.Origin.(base.PhysicalPlan).GetEstRowCountForDisplay() - } else if statsInfo := p.StatsInfo(); statsInfo != nil { - estRows = statsInfo.RowCount - } - plancodec.EncodePlanNode( - int(fop.Depth), - strconv.Itoa(id)+fop.Label.String(), - tp, - estRows, - taskTypeInfo, - fop.Origin.ExplainInfo(), - actRows, - analyzeInfo, - memoryInfo, - diskInfo, - &buf, - ) - if len(cte) > 1 { - encodeFlatPlanTree(cte[1:], 1, &buf) - } - } - return plancodec.Compress(buf.Bytes()) -} - -func encodeFlatPlanTree(flatTree FlatPlanTree, offset int, buf *bytes.Buffer) { - for i := 0; i < len(flatTree); { - fop := flatTree[i] - taskTypeInfo := plancodec.EncodeTaskType(fop.IsRoot, fop.StoreType) - p := fop.Origin - actRows, analyzeInfo, memoryInfo, diskInfo := getRuntimeInfoStr(p.SCtx(), p, nil) - var estRows float64 - if fop.IsPhysicalPlan { - estRows = fop.Origin.(base.PhysicalPlan).GetEstRowCountForDisplay() - } else if statsInfo := p.StatsInfo(); statsInfo != nil { - estRows = statsInfo.RowCount - } - plancodec.EncodePlanNode( - int(fop.Depth), - strconv.Itoa(fop.Origin.ID())+fop.Label.String(), - fop.Origin.TP(), - estRows, - taskTypeInfo, - fop.Origin.ExplainInfo(), - actRows, - analyzeInfo, - memoryInfo, - diskInfo, - buf, - ) - - if fop.NeedReverseDriverSide { - // If NeedReverseDriverSide is true, we don't rely on the order of flatTree. - // Instead, we manually slice the build and probe side children from flatTree and recursively call - // encodeFlatPlanTree to keep build side before probe side. 
- buildSide := flatTree[fop.ChildrenIdx[1]-offset : fop.ChildrenEndIdx+1-offset] - probeSide := flatTree[fop.ChildrenIdx[0]-offset : fop.ChildrenIdx[1]-offset] - encodeFlatPlanTree(buildSide, fop.ChildrenIdx[1], buf) - encodeFlatPlanTree(probeSide, fop.ChildrenIdx[0], buf) - // Skip the children plan tree of the current operator. - i = fop.ChildrenEndIdx + 1 - offset - } else { - // Normally, we just go to the next element in the slice. - i++ - } - } -} - -var encoderPool = sync.Pool{ - New: func() any { - return &planEncoder{} - }, -} - -type planEncoder struct { - buf bytes.Buffer - encodedPlans map[int]bool - - ctes []*PhysicalCTE -} - -// EncodePlan is used to encodePlan the plan to the plan tree with compressing. -// Deprecated: FlattenPhysicalPlan() + EncodeFlatPlan() is preferred. -func EncodePlan(p base.Plan) string { - if explain, ok := p.(*Explain); ok { - p = explain.TargetPlan - } - if p == nil || p.SCtx() == nil { - return "" - } - pn := encoderPool.Get().(*planEncoder) - defer encoderPool.Put(pn) - selectPlan := getSelectPlan(p) - if selectPlan != nil { - failpoint.Inject("mockPlanRowCount", func(val failpoint.Value) { - selectPlan.StatsInfo().RowCount = float64(val.(int)) - }) - } - return pn.encodePlanTree(p) -} - -func (pn *planEncoder) encodePlanTree(p base.Plan) string { - pn.encodedPlans = make(map[int]bool) - pn.buf.Reset() - pn.ctes = pn.ctes[:0] - pn.encodePlan(p, true, kv.TiKV, 0) - pn.encodeCTEPlan() - return plancodec.Compress(pn.buf.Bytes()) -} - -func (pn *planEncoder) encodeCTEPlan() { - if len(pn.ctes) <= 0 { - return - } - explainedCTEPlan := make(map[int]struct{}) - for i := 0; i < len(pn.ctes); i++ { - x := (*CTEDefinition)(pn.ctes[i]) - // skip if the CTE has been explained, the same CTE has same IDForStorage - if _, ok := explainedCTEPlan[x.CTE.IDForStorage]; ok { - continue - } - taskTypeInfo := plancodec.EncodeTaskType(true, kv.TiKV) - actRows, analyzeInfo, memoryInfo, diskInfo := getRuntimeInfoStr(x.SCtx(), x, nil) - rowCount := 0.0 - if statsInfo := x.StatsInfo(); statsInfo != nil { - rowCount = x.StatsInfo().RowCount - } - plancodec.EncodePlanNode(0, strconv.Itoa(x.CTE.IDForStorage), plancodec.TypeCTEDefinition, rowCount, taskTypeInfo, x.ExplainInfo(), actRows, analyzeInfo, memoryInfo, diskInfo, &pn.buf) - pn.encodePlan(x.SeedPlan, true, kv.TiKV, 1) - if x.RecurPlan != nil { - pn.encodePlan(x.RecurPlan, true, kv.TiKV, 1) - } - explainedCTEPlan[x.CTE.IDForStorage] = struct{}{} - } -} - -func (pn *planEncoder) encodePlan(p base.Plan, isRoot bool, store kv.StoreType, depth int) { - taskTypeInfo := plancodec.EncodeTaskType(isRoot, store) - actRows, analyzeInfo, memoryInfo, diskInfo := getRuntimeInfoStr(p.SCtx(), p, nil) - rowCount := 0.0 - if pp, ok := p.(base.PhysicalPlan); ok { - rowCount = pp.GetEstRowCountForDisplay() - } else if statsInfo := p.StatsInfo(); statsInfo != nil { - rowCount = statsInfo.RowCount - } - plancodec.EncodePlanNode(depth, strconv.Itoa(p.ID()), p.TP(), rowCount, taskTypeInfo, p.ExplainInfo(), actRows, analyzeInfo, memoryInfo, diskInfo, &pn.buf) - pn.encodedPlans[p.ID()] = true - depth++ - - selectPlan := getSelectPlan(p) - if selectPlan == nil { - return - } - if !pn.encodedPlans[selectPlan.ID()] { - pn.encodePlan(selectPlan, isRoot, store, depth) - return - } - for _, child := range selectPlan.Children() { - if pn.encodedPlans[child.ID()] { - continue - } - pn.encodePlan(child, isRoot, store, depth) - } - switch copPlan := selectPlan.(type) { - case *PhysicalTableReader: - pn.encodePlan(copPlan.tablePlan, false, 
copPlan.StoreType, depth) - case *PhysicalIndexReader: - pn.encodePlan(copPlan.indexPlan, false, store, depth) - case *PhysicalIndexLookUpReader: - pn.encodePlan(copPlan.indexPlan, false, store, depth) - pn.encodePlan(copPlan.tablePlan, false, store, depth) - case *PhysicalIndexMergeReader: - for _, p := range copPlan.partialPlans { - pn.encodePlan(p, false, store, depth) - } - if copPlan.tablePlan != nil { - pn.encodePlan(copPlan.tablePlan, false, store, depth) - } - case *PhysicalCTE: - pn.ctes = append(pn.ctes, copPlan) - } -} - -var digesterPool = sync.Pool{ - New: func() any { - return &planDigester{ - hasher: sha256.New(), - } - }, -} - -type planDigester struct { - buf bytes.Buffer - encodedPlans map[int]bool - hasher hash.Hash -} - -// NormalizeFlatPlan normalizes a FlatPhysicalPlan and generates plan digest. -func NormalizeFlatPlan(flat *FlatPhysicalPlan) (normalized string, digest *parser.Digest) { - if flat == nil { - return "", parser.NewDigest(nil) - } - selectPlan, selectPlanOffset := flat.Main.GetSelectPlan() - if len(selectPlan) == 0 || !selectPlan[0].IsPhysicalPlan { - return "", parser.NewDigest(nil) - } - d := digesterPool.Get().(*planDigester) - defer func() { - d.buf.Reset() - d.hasher.Reset() - digesterPool.Put(d) - }() - // assume an operator costs around 30 bytes, preallocate space for them - d.buf.Grow(30 * len(selectPlan)) - for _, fop := range selectPlan { - taskTypeInfo := plancodec.EncodeTaskTypeForNormalize(fop.IsRoot, fop.StoreType) - p := fop.Origin.(base.PhysicalPlan) - plancodec.NormalizePlanNode( - int(fop.Depth-uint32(selectPlanOffset)), - fop.Origin.TP(), - taskTypeInfo, - p.ExplainNormalizedInfo(), - &d.buf, - ) - } - normalized = d.buf.String() - if len(normalized) == 0 { - return "", parser.NewDigest(nil) - } - _, err := d.hasher.Write(d.buf.Bytes()) - if err != nil { - panic(err) - } - digest = parser.NewDigest(d.hasher.Sum(nil)) - return -} - -// NormalizePlan is used to normalize the plan and generate plan digest. -// Deprecated: FlattenPhysicalPlan() + NormalizeFlatPlan() is preferred. 
-func NormalizePlan(p base.Plan) (normalized string, digest *parser.Digest) { - selectPlan := getSelectPlan(p) - if selectPlan == nil { - return "", parser.NewDigest(nil) - } - d := digesterPool.Get().(*planDigester) - defer func() { - d.buf.Reset() - d.hasher.Reset() - digesterPool.Put(d) - }() - d.normalizePlanTree(selectPlan) - normalized = d.buf.String() - _, err := d.hasher.Write(d.buf.Bytes()) - if err != nil { - panic(err) - } - digest = parser.NewDigest(d.hasher.Sum(nil)) - return -} - -func (d *planDigester) normalizePlanTree(p base.PhysicalPlan) { - d.encodedPlans = make(map[int]bool) - d.buf.Reset() - d.normalizePlan(p, true, kv.TiKV, 0) -} - -func (d *planDigester) normalizePlan(p base.PhysicalPlan, isRoot bool, store kv.StoreType, depth int) { - taskTypeInfo := plancodec.EncodeTaskTypeForNormalize(isRoot, store) - plancodec.NormalizePlanNode(depth, p.TP(), taskTypeInfo, p.ExplainNormalizedInfo(), &d.buf) - d.encodedPlans[p.ID()] = true - - depth++ - for _, child := range p.Children() { - if d.encodedPlans[child.ID()] { - continue - } - d.normalizePlan(child, isRoot, store, depth) - } - switch x := p.(type) { - case *PhysicalTableReader: - d.normalizePlan(x.tablePlan, false, x.StoreType, depth) - case *PhysicalIndexReader: - d.normalizePlan(x.indexPlan, false, store, depth) - case *PhysicalIndexLookUpReader: - d.normalizePlan(x.indexPlan, false, store, depth) - d.normalizePlan(x.tablePlan, false, store, depth) - case *PhysicalIndexMergeReader: - for _, p := range x.partialPlans { - d.normalizePlan(p, false, store, depth) - } - if x.tablePlan != nil { - d.normalizePlan(x.tablePlan, false, store, depth) - } - } -} - -func getSelectPlan(p base.Plan) base.PhysicalPlan { - var selectPlan base.PhysicalPlan - if physicalPlan, ok := p.(base.PhysicalPlan); ok { - selectPlan = physicalPlan - } else { - switch x := p.(type) { - case *Delete: - selectPlan = x.SelectPlan - case *Update: - selectPlan = x.SelectPlan - case *Insert: - selectPlan = x.SelectPlan - case *Explain: - selectPlan = getSelectPlan(x.TargetPlan) - } - } - return selectPlan -} diff --git a/pkg/planner/core/exhaust_physical_plans.go b/pkg/planner/core/exhaust_physical_plans.go index 21c1043b84848..ad3398b1f5fad 100644 --- a/pkg/planner/core/exhaust_physical_plans.go +++ b/pkg/planner/core/exhaust_physical_plans.go @@ -873,11 +873,11 @@ func buildIndexJoinInner2TableScan( lastColMng = indexJoinResult.lastColManager } joins = make([]base.PhysicalPlan, 0, 3) - if val, _err_ := failpoint.Eval(_curpkg_("MockOnlyEnableIndexHashJoin")); _err_ == nil { + failpoint.Inject("MockOnlyEnableIndexHashJoin", func(val failpoint.Value) { if val.(bool) && !p.SCtx().GetSessionVars().InRestrictedSQL { - return constructIndexHashJoin(p, prop, outerIdx, innerTask, nil, keyOff2IdxOff, path, lastColMng) + failpoint.Return(constructIndexHashJoin(p, prop, outerIdx, innerTask, nil, keyOff2IdxOff, path, lastColMng)) } - } + }) joins = append(joins, constructIndexJoin(p, prop, outerIdx, innerTask, ranges, keyOff2IdxOff, path, lastColMng, true)...) // We can reuse the `innerTask` here since index nested loop hash join // do not need the inner child to promise the order. 
@@ -924,11 +924,11 @@ func buildIndexJoinInner2IndexScan( } } innerTask := constructInnerIndexScanTask(p, prop, wrapper, indexJoinResult.chosenPath, indexJoinResult.chosenRanges.Range(), indexJoinResult.chosenRemained, innerJoinKeys, indexJoinResult.idxOff2KeyOff, rangeInfo, false, false, avgInnerRowCnt, maxOneRow) - if val, _err_ := failpoint.Eval(_curpkg_("MockOnlyEnableIndexHashJoin")); _err_ == nil { + failpoint.Inject("MockOnlyEnableIndexHashJoin", func(val failpoint.Value) { if val.(bool) && !p.SCtx().GetSessionVars().InRestrictedSQL && innerTask != nil { - return constructIndexHashJoin(p, prop, outerIdx, innerTask, indexJoinResult.chosenRanges, keyOff2IdxOff, indexJoinResult.chosenPath, indexJoinResult.lastColManager) + failpoint.Return(constructIndexHashJoin(p, prop, outerIdx, innerTask, indexJoinResult.chosenRanges, keyOff2IdxOff, indexJoinResult.chosenPath, indexJoinResult.lastColManager)) } - } + }) if innerTask != nil { joins = append(joins, constructIndexJoin(p, prop, outerIdx, innerTask, indexJoinResult.chosenRanges, keyOff2IdxOff, indexJoinResult.chosenPath, indexJoinResult.lastColManager, true)...) // We can reuse the `innerTask` here since index nested loop hash join @@ -1867,11 +1867,11 @@ func tryToGetMppHashJoin(p *LogicalJoin, prop *property.PhysicalProperty, useBCJ } // set preferredBuildIndex for test - if val, _err_ := failpoint.Eval(_curpkg_("mockPreferredBuildIndex")); _err_ == nil { + failpoint.Inject("mockPreferredBuildIndex", func(val failpoint.Value) { if !p.SCtx().GetSessionVars().InRestrictedSQL { preferredBuildIndex = val.(int) } - } + }) baseJoin.InnerChildIdx = preferredBuildIndex childrenProps := make([]*property.PhysicalProperty, 2) diff --git a/pkg/planner/core/exhaust_physical_plans.go__failpoint_stash__ b/pkg/planner/core/exhaust_physical_plans.go__failpoint_stash__ deleted file mode 100644 index ad3398b1f5fad..0000000000000 --- a/pkg/planner/core/exhaust_physical_plans.go__failpoint_stash__ +++ /dev/null @@ -1,3004 +0,0 @@ -// Copyright 2017 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package core - -import ( - "fmt" - "math" - "slices" - "strings" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/expression/aggregation" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/planner/cardinality" - "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/planner/core/cost" - "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" - "github.com/pingcap/tidb/pkg/planner/property" - "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/planner/util/fixcontrol" - "github.com/pingcap/tidb/pkg/statistics" - "github.com/pingcap/tidb/pkg/types" - h "github.com/pingcap/tidb/pkg/util/hint" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/plancodec" - "github.com/pingcap/tidb/pkg/util/ranger" - "github.com/pingcap/tidb/pkg/util/set" - "github.com/pingcap/tipb/go-tipb" - "go.uber.org/zap" -) - -func exhaustPhysicalPlans4LogicalUnionScan(lp base.LogicalPlan, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { - p := lp.(*logicalop.LogicalUnionScan) - if prop.IsFlashProp() { - p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced( - "MPP mode may be blocked because operator `UnionScan` is not supported now.") - return nil, true, nil - } - childProp := prop.CloneEssentialFields() - us := PhysicalUnionScan{ - Conditions: p.Conditions, - HandleCols: p.HandleCols, - }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset(), childProp) - return []base.PhysicalPlan{us}, true, nil -} - -func findMaxPrefixLen(candidates [][]*expression.Column, keys []*expression.Column) int { - maxLen := 0 - for _, candidateKeys := range candidates { - matchedLen := 0 - for i := range keys { - if !(i < len(candidateKeys) && keys[i].EqualColumn(candidateKeys[i])) { - break - } - matchedLen++ - } - if matchedLen > maxLen { - maxLen = matchedLen - } - } - return maxLen -} - -func moveEqualToOtherConditions(p *LogicalJoin, offsets []int) []expression.Expression { - // Construct used equal condition set based on the equal condition offsets. - usedEqConds := set.NewIntSet() - for _, eqCondIdx := range offsets { - usedEqConds.Insert(eqCondIdx) - } - - // Construct otherConds, which is composed of the original other conditions - // and the remained unused equal conditions. - numOtherConds := len(p.OtherConditions) + len(p.EqualConditions) - len(usedEqConds) - otherConds := make([]expression.Expression, len(p.OtherConditions), numOtherConds) - copy(otherConds, p.OtherConditions) - for eqCondIdx := range p.EqualConditions { - if !usedEqConds.Exist(eqCondIdx) { - otherConds = append(otherConds, p.EqualConditions[eqCondIdx]) - } - } - - return otherConds -} - -// Only if the input required prop is the prefix fo join keys, we can pass through this property. 
-func (p *PhysicalMergeJoin) tryToGetChildReqProp(prop *property.PhysicalProperty) ([]*property.PhysicalProperty, bool) { - all, desc := prop.AllSameOrder() - lProp := property.NewPhysicalProperty(property.RootTaskType, p.LeftJoinKeys, desc, math.MaxFloat64, false) - rProp := property.NewPhysicalProperty(property.RootTaskType, p.RightJoinKeys, desc, math.MaxFloat64, false) - lProp.CTEProducerStatus = prop.CTEProducerStatus - rProp.CTEProducerStatus = prop.CTEProducerStatus - if !prop.IsSortItemEmpty() { - // sort merge join fits the cases of massive ordered data, so desc scan is always expensive. - if !all { - return nil, false - } - if !prop.IsPrefix(lProp) && !prop.IsPrefix(rProp) { - return nil, false - } - if prop.IsPrefix(rProp) && p.JoinType == LeftOuterJoin { - return nil, false - } - if prop.IsPrefix(lProp) && p.JoinType == RightOuterJoin { - return nil, false - } - } - - return []*property.PhysicalProperty{lProp, rProp}, true -} - -func checkJoinKeyCollation(leftKeys, rightKeys []*expression.Column) bool { - // if a left key and its corresponding right key have different collation, don't use MergeJoin since - // the their children may sort their records in different ways - for i := range leftKeys { - lt := leftKeys[i].RetType - rt := rightKeys[i].RetType - if (lt.EvalType() == types.ETString && rt.EvalType() == types.ETString) && - (leftKeys[i].RetType.GetCharset() != rightKeys[i].RetType.GetCharset() || - leftKeys[i].RetType.GetCollate() != rightKeys[i].RetType.GetCollate()) { - return false - } - } - return true -} - -// GetMergeJoin convert the logical join to physical merge join based on the physical property. -func GetMergeJoin(p *LogicalJoin, prop *property.PhysicalProperty, schema *expression.Schema, statsInfo *property.StatsInfo, leftStatsInfo *property.StatsInfo, rightStatsInfo *property.StatsInfo) []base.PhysicalPlan { - joins := make([]base.PhysicalPlan, 0, len(p.LeftProperties)+1) - // The LeftProperties caches all the possible properties that are provided by its children. - leftJoinKeys, rightJoinKeys, isNullEQ, hasNullEQ := p.GetJoinKeys() - - // EnumType/SetType Unsupported: merge join conflicts with index order. - // ref: https://github.com/pingcap/tidb/issues/24473, https://github.com/pingcap/tidb/issues/25669 - for _, leftKey := range leftJoinKeys { - if leftKey.RetType.GetType() == mysql.TypeEnum || leftKey.RetType.GetType() == mysql.TypeSet { - return nil - } - } - for _, rightKey := range rightJoinKeys { - if rightKey.RetType.GetType() == mysql.TypeEnum || rightKey.RetType.GetType() == mysql.TypeSet { - return nil - } - } - - // TODO: support null equal join keys for merge join - if hasNullEQ { - return nil - } - for _, lhsChildProperty := range p.LeftProperties { - offsets := util.GetMaxSortPrefix(lhsChildProperty, leftJoinKeys) - // If not all equal conditions hit properties. We ban merge join heuristically. Because in this case, merge join - // may get a very low performance. In executor, executes join results before other conditions filter it. 
- if len(offsets) < len(leftJoinKeys) { - continue - } - - leftKeys := lhsChildProperty[:len(offsets)] - rightKeys := expression.NewSchema(rightJoinKeys...).ColumnsByIndices(offsets) - newIsNullEQ := make([]bool, 0, len(offsets)) - for _, offset := range offsets { - newIsNullEQ = append(newIsNullEQ, isNullEQ[offset]) - } - - prefixLen := findMaxPrefixLen(p.RightProperties, rightKeys) - if prefixLen == 0 { - continue - } - - leftKeys = leftKeys[:prefixLen] - rightKeys = rightKeys[:prefixLen] - newIsNullEQ = newIsNullEQ[:prefixLen] - if !checkJoinKeyCollation(leftKeys, rightKeys) { - continue - } - offsets = offsets[:prefixLen] - baseJoin := basePhysicalJoin{ - JoinType: p.JoinType, - LeftConditions: p.LeftConditions, - RightConditions: p.RightConditions, - DefaultValues: p.DefaultValues, - LeftJoinKeys: leftKeys, - RightJoinKeys: rightKeys, - IsNullEQ: newIsNullEQ, - } - mergeJoin := PhysicalMergeJoin{basePhysicalJoin: baseJoin}.Init(p.SCtx(), statsInfo.ScaleByExpectCnt(prop.ExpectedCnt), p.QueryBlockOffset()) - mergeJoin.SetSchema(schema) - mergeJoin.OtherConditions = moveEqualToOtherConditions(p, offsets) - mergeJoin.initCompareFuncs() - if reqProps, ok := mergeJoin.tryToGetChildReqProp(prop); ok { - // Adjust expected count for children nodes. - if prop.ExpectedCnt < statsInfo.RowCount { - expCntScale := prop.ExpectedCnt / statsInfo.RowCount - reqProps[0].ExpectedCnt = leftStatsInfo.RowCount * expCntScale - reqProps[1].ExpectedCnt = rightStatsInfo.RowCount * expCntScale - } - mergeJoin.childrenReqProps = reqProps - _, desc := prop.AllSameOrder() - mergeJoin.Desc = desc - joins = append(joins, mergeJoin) - } - } - - if p.PreferJoinType&h.PreferNoMergeJoin > 0 { - if p.PreferJoinType&h.PreferMergeJoin == 0 { - return nil - } - p.SCtx().GetSessionVars().StmtCtx.SetHintWarning( - "Some MERGE_JOIN and NO_MERGE_JOIN hints conflict, NO_MERGE_JOIN is ignored") - } - - // If TiDB_SMJ hint is existed, it should consider enforce merge join, - // because we can't trust lhsChildProperty completely. - if (p.PreferJoinType&h.PreferMergeJoin) > 0 || - shouldSkipHashJoin(p) { // if hash join is not allowed, generate as many other types of join as possible to avoid 'cant-find-plan' error. - joins = append(joins, getEnforcedMergeJoin(p, prop, schema, statsInfo)...) 
- } - - return joins -} - -// Change JoinKeys order, by offsets array -// offsets array is generate by prop check -func getNewJoinKeysByOffsets(oldJoinKeys []*expression.Column, offsets []int) []*expression.Column { - newKeys := make([]*expression.Column, 0, len(oldJoinKeys)) - for _, offset := range offsets { - newKeys = append(newKeys, oldJoinKeys[offset]) - } - for pos, key := range oldJoinKeys { - isExist := false - for _, p := range offsets { - if p == pos { - isExist = true - break - } - } - if !isExist { - newKeys = append(newKeys, key) - } - } - return newKeys -} - -func getNewNullEQByOffsets(oldNullEQ []bool, offsets []int) []bool { - newNullEQ := make([]bool, 0, len(oldNullEQ)) - for _, offset := range offsets { - newNullEQ = append(newNullEQ, oldNullEQ[offset]) - } - for pos, key := range oldNullEQ { - isExist := false - for _, p := range offsets { - if p == pos { - isExist = true - break - } - } - if !isExist { - newNullEQ = append(newNullEQ, key) - } - } - return newNullEQ -} - -func getEnforcedMergeJoin(p *LogicalJoin, prop *property.PhysicalProperty, schema *expression.Schema, statsInfo *property.StatsInfo) []base.PhysicalPlan { - // Check whether SMJ can satisfy the required property - leftJoinKeys, rightJoinKeys, isNullEQ, hasNullEQ := p.GetJoinKeys() - // TODO: support null equal join keys for merge join - if hasNullEQ { - return nil - } - offsets := make([]int, 0, len(leftJoinKeys)) - all, desc := prop.AllSameOrder() - if !all { - return nil - } - evalCtx := p.SCtx().GetExprCtx().GetEvalCtx() - for _, item := range prop.SortItems { - isExist, hasLeftColInProp, hasRightColInProp := false, false, false - for joinKeyPos := 0; joinKeyPos < len(leftJoinKeys); joinKeyPos++ { - var key *expression.Column - if item.Col.Equal(evalCtx, leftJoinKeys[joinKeyPos]) { - key = leftJoinKeys[joinKeyPos] - hasLeftColInProp = true - } - if item.Col.Equal(evalCtx, rightJoinKeys[joinKeyPos]) { - key = rightJoinKeys[joinKeyPos] - hasRightColInProp = true - } - if key == nil { - continue - } - for i := 0; i < len(offsets); i++ { - if offsets[i] == joinKeyPos { - isExist = true - break - } - } - if !isExist { - offsets = append(offsets, joinKeyPos) - } - isExist = true - break - } - if !isExist { - return nil - } - // If the output wants the order of the inner side. We should reject it since we might add null-extend rows of that side. - if p.JoinType == LeftOuterJoin && hasRightColInProp { - return nil - } - if p.JoinType == RightOuterJoin && hasLeftColInProp { - return nil - } - } - // Generate the enforced sort merge join - leftKeys := getNewJoinKeysByOffsets(leftJoinKeys, offsets) - rightKeys := getNewJoinKeysByOffsets(rightJoinKeys, offsets) - newNullEQ := getNewNullEQByOffsets(isNullEQ, offsets) - otherConditions := make([]expression.Expression, len(p.OtherConditions), len(p.OtherConditions)+len(p.EqualConditions)) - copy(otherConditions, p.OtherConditions) - if !checkJoinKeyCollation(leftKeys, rightKeys) { - // if the join keys' collation are conflicted, we use the empty join key - // and move EqualConditions to OtherConditions. - leftKeys = nil - rightKeys = nil - newNullEQ = nil - otherConditions = append(otherConditions, expression.ScalarFuncs2Exprs(p.EqualConditions)...) 
- } - lProp := property.NewPhysicalProperty(property.RootTaskType, leftKeys, desc, math.MaxFloat64, true) - rProp := property.NewPhysicalProperty(property.RootTaskType, rightKeys, desc, math.MaxFloat64, true) - baseJoin := basePhysicalJoin{ - JoinType: p.JoinType, - LeftConditions: p.LeftConditions, - RightConditions: p.RightConditions, - DefaultValues: p.DefaultValues, - LeftJoinKeys: leftKeys, - RightJoinKeys: rightKeys, - IsNullEQ: newNullEQ, - OtherConditions: otherConditions, - } - enforcedPhysicalMergeJoin := PhysicalMergeJoin{basePhysicalJoin: baseJoin, Desc: desc}.Init(p.SCtx(), statsInfo.ScaleByExpectCnt(prop.ExpectedCnt), p.QueryBlockOffset()) - enforcedPhysicalMergeJoin.SetSchema(schema) - enforcedPhysicalMergeJoin.childrenReqProps = []*property.PhysicalProperty{lProp, rProp} - enforcedPhysicalMergeJoin.initCompareFuncs() - return []base.PhysicalPlan{enforcedPhysicalMergeJoin} -} - -func (p *PhysicalMergeJoin) initCompareFuncs() { - p.CompareFuncs = make([]expression.CompareFunc, 0, len(p.LeftJoinKeys)) - for i := range p.LeftJoinKeys { - p.CompareFuncs = append(p.CompareFuncs, expression.GetCmpFunction(p.SCtx().GetExprCtx(), p.LeftJoinKeys[i], p.RightJoinKeys[i])) - } -} - -func shouldSkipHashJoin(p *LogicalJoin) bool { - return (p.PreferJoinType&h.PreferNoHashJoin) > 0 || (p.SCtx().GetSessionVars().DisableHashJoin) -} - -func getHashJoins(p *LogicalJoin, prop *property.PhysicalProperty) (joins []base.PhysicalPlan, forced bool) { - if !prop.IsSortItemEmpty() { // hash join doesn't promise any orders - return - } - - forceLeftToBuild := ((p.PreferJoinType & h.PreferLeftAsHJBuild) > 0) || ((p.PreferJoinType & h.PreferRightAsHJProbe) > 0) - forceRightToBuild := ((p.PreferJoinType & h.PreferRightAsHJBuild) > 0) || ((p.PreferJoinType & h.PreferLeftAsHJProbe) > 0) - if forceLeftToBuild && forceRightToBuild { - p.SCtx().GetSessionVars().StmtCtx.SetHintWarning("Some HASH_JOIN_BUILD and HASH_JOIN_PROBE hints are conflicts, please check the hints") - forceLeftToBuild = false - forceRightToBuild = false - } - - joins = make([]base.PhysicalPlan, 0, 2) - switch p.JoinType { - case SemiJoin, AntiSemiJoin, LeftOuterSemiJoin, AntiLeftOuterSemiJoin: - joins = append(joins, getHashJoin(p, prop, 1, false)) - if forceLeftToBuild || forceRightToBuild { - // Do not support specifying the build and probe side for semi join. 
- p.SCtx().GetSessionVars().StmtCtx.SetHintWarning(fmt.Sprintf("We can't use the HASH_JOIN_BUILD or HASH_JOIN_PROBE hint for %s, please check the hint", p.JoinType)) - forceLeftToBuild = false - forceRightToBuild = false - } - case LeftOuterJoin: - if !forceLeftToBuild { - joins = append(joins, getHashJoin(p, prop, 1, false)) - } - if !forceRightToBuild { - joins = append(joins, getHashJoin(p, prop, 1, true)) - } - case RightOuterJoin: - if !forceLeftToBuild { - joins = append(joins, getHashJoin(p, prop, 0, true)) - } - if !forceRightToBuild { - joins = append(joins, getHashJoin(p, prop, 0, false)) - } - case InnerJoin: - if forceLeftToBuild { - joins = append(joins, getHashJoin(p, prop, 0, false)) - } else if forceRightToBuild { - joins = append(joins, getHashJoin(p, prop, 1, false)) - } else { - joins = append(joins, getHashJoin(p, prop, 1, false)) - joins = append(joins, getHashJoin(p, prop, 0, false)) - } - } - - forced = (p.PreferJoinType&h.PreferHashJoin > 0) || forceLeftToBuild || forceRightToBuild - shouldSkipHashJoin := shouldSkipHashJoin(p) - if !forced && shouldSkipHashJoin { - return nil, false - } else if forced && shouldSkipHashJoin { - p.SCtx().GetSessionVars().StmtCtx.SetHintWarning( - "A conflict between the HASH_JOIN hint and the NO_HASH_JOIN hint, " + - "or the tidb_opt_enable_hash_join system variable, the HASH_JOIN hint will take precedence.") - } - return -} - -func getHashJoin(p *LogicalJoin, prop *property.PhysicalProperty, innerIdx int, useOuterToBuild bool) *PhysicalHashJoin { - chReqProps := make([]*property.PhysicalProperty, 2) - chReqProps[innerIdx] = &property.PhysicalProperty{ExpectedCnt: math.MaxFloat64, CTEProducerStatus: prop.CTEProducerStatus} - chReqProps[1-innerIdx] = &property.PhysicalProperty{ExpectedCnt: math.MaxFloat64, CTEProducerStatus: prop.CTEProducerStatus} - if prop.ExpectedCnt < p.StatsInfo().RowCount { - expCntScale := prop.ExpectedCnt / p.StatsInfo().RowCount - chReqProps[1-innerIdx].ExpectedCnt = p.Children()[1-innerIdx].StatsInfo().RowCount * expCntScale - } - hashJoin := NewPhysicalHashJoin(p, innerIdx, useOuterToBuild, p.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), chReqProps...) - hashJoin.SetSchema(p.Schema()) - return hashJoin -} - -// When inner plan is TableReader, the parameter `ranges` will be nil. Because pk only have one column. So all of its range -// is generated during execution time. 
-func constructIndexJoin( - p *LogicalJoin, - prop *property.PhysicalProperty, - outerIdx int, - innerTask base.Task, - ranges ranger.MutableRanges, - keyOff2IdxOff []int, - path *util.AccessPath, - compareFilters *ColWithCmpFuncManager, - extractOtherEQ bool, -) []base.PhysicalPlan { - if ranges == nil { - ranges = ranger.Ranges{} // empty range - } - - joinType := p.JoinType - var ( - innerJoinKeys []*expression.Column - outerJoinKeys []*expression.Column - isNullEQ []bool - hasNullEQ bool - ) - if outerIdx == 0 { - outerJoinKeys, innerJoinKeys, isNullEQ, hasNullEQ = p.GetJoinKeys() - } else { - innerJoinKeys, outerJoinKeys, isNullEQ, hasNullEQ = p.GetJoinKeys() - } - // TODO: support null equal join keys for index join - if hasNullEQ { - return nil - } - chReqProps := make([]*property.PhysicalProperty, 2) - chReqProps[outerIdx] = &property.PhysicalProperty{TaskTp: property.RootTaskType, ExpectedCnt: math.MaxFloat64, SortItems: prop.SortItems, CTEProducerStatus: prop.CTEProducerStatus} - if prop.ExpectedCnt < p.StatsInfo().RowCount { - expCntScale := prop.ExpectedCnt / p.StatsInfo().RowCount - chReqProps[outerIdx].ExpectedCnt = p.Children()[outerIdx].StatsInfo().RowCount * expCntScale - } - newInnerKeys := make([]*expression.Column, 0, len(innerJoinKeys)) - newOuterKeys := make([]*expression.Column, 0, len(outerJoinKeys)) - newIsNullEQ := make([]bool, 0, len(isNullEQ)) - newKeyOff := make([]int, 0, len(keyOff2IdxOff)) - newOtherConds := make([]expression.Expression, len(p.OtherConditions), len(p.OtherConditions)+len(p.EqualConditions)) - copy(newOtherConds, p.OtherConditions) - for keyOff, idxOff := range keyOff2IdxOff { - if keyOff2IdxOff[keyOff] < 0 { - newOtherConds = append(newOtherConds, p.EqualConditions[keyOff]) - continue - } - newInnerKeys = append(newInnerKeys, innerJoinKeys[keyOff]) - newOuterKeys = append(newOuterKeys, outerJoinKeys[keyOff]) - newIsNullEQ = append(newIsNullEQ, isNullEQ[keyOff]) - newKeyOff = append(newKeyOff, idxOff) - } - - var outerHashKeys, innerHashKeys []*expression.Column - outerHashKeys, innerHashKeys = make([]*expression.Column, len(newOuterKeys)), make([]*expression.Column, len(newInnerKeys)) - copy(outerHashKeys, newOuterKeys) - copy(innerHashKeys, newInnerKeys) - // we can use the `col col` in `OtherCondition` to build the hashtable to avoid the unnecessary calculating. - for i := len(newOtherConds) - 1; extractOtherEQ && i >= 0; i = i - 1 { - switch c := newOtherConds[i].(type) { - case *expression.ScalarFunction: - if c.FuncName.L == ast.EQ { - lhs, ok1 := c.GetArgs()[0].(*expression.Column) - rhs, ok2 := c.GetArgs()[1].(*expression.Column) - if ok1 && ok2 { - if lhs.InOperand || rhs.InOperand { - // if this other-cond is from a `[not] in` sub-query, do not convert it into eq-cond since - // IndexJoin cannot deal with NULL correctly in this case; please see #25799 for more details. - continue - } - outerSchema, innerSchema := p.Children()[outerIdx].Schema(), p.Children()[1-outerIdx].Schema() - if outerSchema.Contains(lhs) && innerSchema.Contains(rhs) { - outerHashKeys = append(outerHashKeys, lhs) // nozero - innerHashKeys = append(innerHashKeys, rhs) // nozero - } else if innerSchema.Contains(lhs) && outerSchema.Contains(rhs) { - outerHashKeys = append(outerHashKeys, rhs) // nozero - innerHashKeys = append(innerHashKeys, lhs) // nozero - } - newOtherConds = append(newOtherConds[:i], newOtherConds[i+1:]...) 
- } - } - default: - continue - } - } - - baseJoin := basePhysicalJoin{ - InnerChildIdx: 1 - outerIdx, - LeftConditions: p.LeftConditions, - RightConditions: p.RightConditions, - OtherConditions: newOtherConds, - JoinType: joinType, - OuterJoinKeys: newOuterKeys, - InnerJoinKeys: newInnerKeys, - IsNullEQ: newIsNullEQ, - DefaultValues: p.DefaultValues, - } - - join := PhysicalIndexJoin{ - basePhysicalJoin: baseJoin, - innerPlan: innerTask.Plan(), - KeyOff2IdxOff: newKeyOff, - Ranges: ranges, - CompareFilters: compareFilters, - OuterHashKeys: outerHashKeys, - InnerHashKeys: innerHashKeys, - }.Init(p.SCtx(), p.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), p.QueryBlockOffset(), chReqProps...) - if path != nil { - join.IdxColLens = path.IdxColLens - } - join.SetSchema(p.Schema()) - return []base.PhysicalPlan{join} -} - -func constructIndexMergeJoin( - p *LogicalJoin, - prop *property.PhysicalProperty, - outerIdx int, - innerTask base.Task, - ranges ranger.MutableRanges, - keyOff2IdxOff []int, - path *util.AccessPath, - compareFilters *ColWithCmpFuncManager, -) []base.PhysicalPlan { - hintExists := false - if (outerIdx == 1 && (p.PreferJoinType&h.PreferLeftAsINLMJInner) > 0) || (outerIdx == 0 && (p.PreferJoinType&h.PreferRightAsINLMJInner) > 0) { - hintExists = true - } - indexJoins := constructIndexJoin(p, prop, outerIdx, innerTask, ranges, keyOff2IdxOff, path, compareFilters, !hintExists) - indexMergeJoins := make([]base.PhysicalPlan, 0, len(indexJoins)) - for _, plan := range indexJoins { - join := plan.(*PhysicalIndexJoin) - // Index merge join can't handle hash keys. So we ban it heuristically. - if len(join.InnerHashKeys) > len(join.InnerJoinKeys) { - return nil - } - - // EnumType/SetType Unsupported: merge join conflicts with index order. - // ref: https://github.com/pingcap/tidb/issues/24473, https://github.com/pingcap/tidb/issues/25669 - for _, innerKey := range join.InnerJoinKeys { - if innerKey.RetType.GetType() == mysql.TypeEnum || innerKey.RetType.GetType() == mysql.TypeSet { - return nil - } - } - for _, outerKey := range join.OuterJoinKeys { - if outerKey.RetType.GetType() == mysql.TypeEnum || outerKey.RetType.GetType() == mysql.TypeSet { - return nil - } - } - - hasPrefixCol := false - for _, l := range join.IdxColLens { - if l != types.UnspecifiedLength { - hasPrefixCol = true - break - } - } - // If index column has prefix length, the merge join can not guarantee the relevance - // between index and join keys. So we should skip this case. - // For more details, please check the following code and comments. - if hasPrefixCol { - continue - } - - // keyOff2KeyOffOrderByIdx is map the join keys offsets to [0, len(joinKeys)) ordered by the - // join key position in inner index. - keyOff2KeyOffOrderByIdx := make([]int, len(join.OuterJoinKeys)) - keyOffMapList := make([]int, len(join.KeyOff2IdxOff)) - copy(keyOffMapList, join.KeyOff2IdxOff) - keyOffMap := make(map[int]int, len(keyOffMapList)) - for i, idxOff := range keyOffMapList { - keyOffMap[idxOff] = i - } - slices.Sort(keyOffMapList) - keyIsIndexPrefix := true - for keyOff, idxOff := range keyOffMapList { - if keyOff != idxOff { - keyIsIndexPrefix = false - break - } - keyOff2KeyOffOrderByIdx[keyOffMap[idxOff]] = keyOff - } - if !keyIsIndexPrefix { - continue - } - // isOuterKeysPrefix means whether the outer join keys are the prefix of the prop items. 
- isOuterKeysPrefix := len(join.OuterJoinKeys) <= len(prop.SortItems) - compareFuncs := make([]expression.CompareFunc, 0, len(join.OuterJoinKeys)) - outerCompareFuncs := make([]expression.CompareFunc, 0, len(join.OuterJoinKeys)) - - for i := range join.KeyOff2IdxOff { - if isOuterKeysPrefix && !prop.SortItems[i].Col.EqualColumn(join.OuterJoinKeys[keyOff2KeyOffOrderByIdx[i]]) { - isOuterKeysPrefix = false - } - compareFuncs = append(compareFuncs, expression.GetCmpFunction(p.SCtx().GetExprCtx(), join.OuterJoinKeys[i], join.InnerJoinKeys[i])) - outerCompareFuncs = append(outerCompareFuncs, expression.GetCmpFunction(p.SCtx().GetExprCtx(), join.OuterJoinKeys[i], join.OuterJoinKeys[i])) - } - // canKeepOuterOrder means whether the prop items are the prefix of the outer join keys. - canKeepOuterOrder := len(prop.SortItems) <= len(join.OuterJoinKeys) - for i := 0; canKeepOuterOrder && i < len(prop.SortItems); i++ { - if !prop.SortItems[i].Col.EqualColumn(join.OuterJoinKeys[keyOff2KeyOffOrderByIdx[i]]) { - canKeepOuterOrder = false - } - } - // Since index merge join requires prop items the prefix of outer join keys - // or outer join keys the prefix of the prop items. So we need `canKeepOuterOrder` or - // `isOuterKeysPrefix` to be true. - if canKeepOuterOrder || isOuterKeysPrefix { - indexMergeJoin := PhysicalIndexMergeJoin{ - PhysicalIndexJoin: *join, - KeyOff2KeyOffOrderByIdx: keyOff2KeyOffOrderByIdx, - NeedOuterSort: !isOuterKeysPrefix, - CompareFuncs: compareFuncs, - OuterCompareFuncs: outerCompareFuncs, - Desc: !prop.IsSortItemEmpty() && prop.SortItems[0].Desc, - }.Init(p.SCtx()) - indexMergeJoins = append(indexMergeJoins, indexMergeJoin) - } - } - return indexMergeJoins -} - -func constructIndexHashJoin( - p *LogicalJoin, - prop *property.PhysicalProperty, - outerIdx int, - innerTask base.Task, - ranges ranger.MutableRanges, - keyOff2IdxOff []int, - path *util.AccessPath, - compareFilters *ColWithCmpFuncManager, -) []base.PhysicalPlan { - indexJoins := constructIndexJoin(p, prop, outerIdx, innerTask, ranges, keyOff2IdxOff, path, compareFilters, true) - indexHashJoins := make([]base.PhysicalPlan, 0, len(indexJoins)) - for _, plan := range indexJoins { - join := plan.(*PhysicalIndexJoin) - indexHashJoin := PhysicalIndexHashJoin{ - PhysicalIndexJoin: *join, - // Prop is empty means that the parent operator does not need the - // join operator to provide any promise of the output order. - KeepOuterOrder: !prop.IsSortItemEmpty(), - }.Init(p.SCtx()) - indexHashJoins = append(indexHashJoins, indexHashJoin) - } - return indexHashJoins -} - -// getIndexJoinByOuterIdx will generate index join by outerIndex. OuterIdx points out the outer child. -// First of all, we'll check whether the inner child is DataSource. -// Then, we will extract the join keys of p's equal conditions. Then check whether all of them are just the primary key -// or match some part of on index. If so we will choose the best one and construct a index join. -func getIndexJoinByOuterIdx(p *LogicalJoin, prop *property.PhysicalProperty, outerIdx int) (joins []base.PhysicalPlan) { - outerChild, innerChild := p.Children()[outerIdx], p.Children()[1-outerIdx] - all, _ := prop.AllSameOrder() - // If the order by columns are not all from outer child, index join cannot promise the order. 
- if !prop.AllColsFromSchema(outerChild.Schema()) || !all { - return nil - } - var ( - innerJoinKeys []*expression.Column - outerJoinKeys []*expression.Column - ) - if outerIdx == 0 { - outerJoinKeys, innerJoinKeys, _, _ = p.GetJoinKeys() - } else { - innerJoinKeys, outerJoinKeys, _, _ = p.GetJoinKeys() - } - innerChildWrapper := extractIndexJoinInnerChildPattern(p, innerChild) - if innerChildWrapper == nil { - return nil - } - - var avgInnerRowCnt float64 - if outerChild.StatsInfo().RowCount > 0 { - avgInnerRowCnt = p.EqualCondOutCnt / outerChild.StatsInfo().RowCount - } - joins = buildIndexJoinInner2TableScan(p, prop, innerChildWrapper, innerJoinKeys, outerJoinKeys, outerIdx, avgInnerRowCnt) - if joins != nil { - return - } - return buildIndexJoinInner2IndexScan(p, prop, innerChildWrapper, innerJoinKeys, outerJoinKeys, outerIdx, avgInnerRowCnt) -} - -// indexJoinInnerChildWrapper is a wrapper for the inner child of an index join. -// It contains the lowest DataSource operator and other inner child operator -// which is flattened into a list structure from tree structure . -// For example, the inner child of an index join is a tree structure like: -// -// Projection -// Aggregation -// Selection -// DataSource -// -// The inner child wrapper will be: -// DataSource: the lowest DataSource operator. -// hasDitryWrite: whether the inner child contains dirty data. -// zippedChildren: [Projection, Aggregation, Selection] -type indexJoinInnerChildWrapper struct { - ds *DataSource - hasDitryWrite bool - zippedChildren []base.LogicalPlan -} - -func extractIndexJoinInnerChildPattern(p *LogicalJoin, innerChild base.LogicalPlan) *indexJoinInnerChildWrapper { - wrapper := &indexJoinInnerChildWrapper{} - nextChild := func(pp base.LogicalPlan) base.LogicalPlan { - if len(pp.Children()) != 1 { - return nil - } - return pp.Children()[0] - } -childLoop: - for curChild := innerChild; curChild != nil; curChild = nextChild(curChild) { - switch child := curChild.(type) { - case *DataSource: - wrapper.ds = child - break childLoop - case *logicalop.LogicalProjection, *LogicalSelection, *LogicalAggregation: - if !p.SCtx().GetSessionVars().EnableINLJoinInnerMultiPattern { - return nil - } - wrapper.zippedChildren = append(wrapper.zippedChildren, child) - case *logicalop.LogicalUnionScan: - wrapper.hasDitryWrite = true - wrapper.zippedChildren = append(wrapper.zippedChildren, child) - default: - return nil - } - } - if wrapper.ds == nil || wrapper.ds.PreferStoreType&h.PreferTiFlash != 0 { - return nil - } - return wrapper -} - -// buildIndexJoinInner2TableScan builds a TableScan as the inner child for an -// IndexJoin if possible. -// If the inner side of a index join is a TableScan, only one tuple will be -// fetched from the inner side for every tuple from the outer side. This will be -// promised to be no worse than building IndexScan as the inner child. 
-func buildIndexJoinInner2TableScan( - p *LogicalJoin, - prop *property.PhysicalProperty, wrapper *indexJoinInnerChildWrapper, - innerJoinKeys, outerJoinKeys []*expression.Column, - outerIdx int, avgInnerRowCnt float64) (joins []base.PhysicalPlan) { - ds := wrapper.ds - var tblPath *util.AccessPath - for _, path := range ds.PossibleAccessPaths { - if path.IsTablePath() && path.StoreType == kv.TiKV { - tblPath = path - break - } - } - if tblPath == nil { - return nil - } - keyOff2IdxOff := make([]int, len(innerJoinKeys)) - newOuterJoinKeys := make([]*expression.Column, 0) - var ranges ranger.MutableRanges = ranger.Ranges{} - var innerTask, innerTask2 base.Task - var indexJoinResult *indexJoinPathResult - if ds.TableInfo.IsCommonHandle { - indexJoinResult, keyOff2IdxOff = getBestIndexJoinPathResult(p, ds, innerJoinKeys, outerJoinKeys, func(path *util.AccessPath) bool { return path.IsCommonHandlePath }) - if indexJoinResult == nil { - return nil - } - rangeInfo := indexJoinPathRangeInfo(p.SCtx(), outerJoinKeys, indexJoinResult) - innerTask = constructInnerTableScanTask(p, prop, wrapper, indexJoinResult.chosenRanges.Range(), outerJoinKeys, rangeInfo, false, false, avgInnerRowCnt) - // The index merge join's inner plan is different from index join, so we - // should construct another inner plan for it. - // Because we can't keep order for union scan, if there is a union scan in inner task, - // we can't construct index merge join. - if !wrapper.hasDitryWrite { - innerTask2 = constructInnerTableScanTask(p, prop, wrapper, indexJoinResult.chosenRanges.Range(), outerJoinKeys, rangeInfo, true, !prop.IsSortItemEmpty() && prop.SortItems[0].Desc, avgInnerRowCnt) - } - ranges = indexJoinResult.chosenRanges - } else { - pkMatched := false - pkCol := ds.getPKIsHandleCol() - if pkCol == nil { - return nil - } - for i, key := range innerJoinKeys { - if !key.EqualColumn(pkCol) { - keyOff2IdxOff[i] = -1 - continue - } - pkMatched = true - keyOff2IdxOff[i] = 0 - // Add to newOuterJoinKeys only if conditions contain inner primary key. For issue #14822. - newOuterJoinKeys = append(newOuterJoinKeys, outerJoinKeys[i]) - } - outerJoinKeys = newOuterJoinKeys - if !pkMatched { - return nil - } - ranges := ranger.FullIntRange(mysql.HasUnsignedFlag(pkCol.RetType.GetFlag())) - var buffer strings.Builder - buffer.WriteString("[") - for i, key := range outerJoinKeys { - if i != 0 { - buffer.WriteString(" ") - } - buffer.WriteString(key.StringWithCtx(p.SCtx().GetExprCtx().GetEvalCtx(), errors.RedactLogDisable)) - } - buffer.WriteString("]") - rangeInfo := buffer.String() - innerTask = constructInnerTableScanTask(p, prop, wrapper, ranges, outerJoinKeys, rangeInfo, false, false, avgInnerRowCnt) - // The index merge join's inner plan is different from index join, so we - // should construct another inner plan for it. - // Because we can't keep order for union scan, if there is a union scan in inner task, - // we can't construct index merge join. 
- if !wrapper.hasDitryWrite { - innerTask2 = constructInnerTableScanTask(p, prop, wrapper, ranges, outerJoinKeys, rangeInfo, true, !prop.IsSortItemEmpty() && prop.SortItems[0].Desc, avgInnerRowCnt) - } - } - var ( - path *util.AccessPath - lastColMng *ColWithCmpFuncManager - ) - if indexJoinResult != nil { - path = indexJoinResult.chosenPath - lastColMng = indexJoinResult.lastColManager - } - joins = make([]base.PhysicalPlan, 0, 3) - failpoint.Inject("MockOnlyEnableIndexHashJoin", func(val failpoint.Value) { - if val.(bool) && !p.SCtx().GetSessionVars().InRestrictedSQL { - failpoint.Return(constructIndexHashJoin(p, prop, outerIdx, innerTask, nil, keyOff2IdxOff, path, lastColMng)) - } - }) - joins = append(joins, constructIndexJoin(p, prop, outerIdx, innerTask, ranges, keyOff2IdxOff, path, lastColMng, true)...) - // We can reuse the `innerTask` here since index nested loop hash join - // do not need the inner child to promise the order. - joins = append(joins, constructIndexHashJoin(p, prop, outerIdx, innerTask, ranges, keyOff2IdxOff, path, lastColMng)...) - if innerTask2 != nil { - joins = append(joins, constructIndexMergeJoin(p, prop, outerIdx, innerTask2, ranges, keyOff2IdxOff, path, lastColMng)...) - } - return joins -} - -func buildIndexJoinInner2IndexScan( - p *LogicalJoin, - prop *property.PhysicalProperty, wrapper *indexJoinInnerChildWrapper, innerJoinKeys, outerJoinKeys []*expression.Column, - outerIdx int, avgInnerRowCnt float64) (joins []base.PhysicalPlan) { - ds := wrapper.ds - indexValid := func(path *util.AccessPath) bool { - if path.IsTablePath() { - return false - } - // if path is index path. index path currently include two kind of, one is normal, and the other is mv index. - // for mv index like mvi(a, json, b), if driving condition is a=1, and we build a prefix scan with range [1,1] - // on mvi, it will return many index rows which breaks handle-unique attribute here. - // - // the basic rule is that: mv index can be and can only be accessed by indexMerge operator. (embedded handle duplication) - if !isMVIndexPath(path) { - return true // not a MVIndex path, it can successfully be index join probe side. 
- } - return false - } - indexJoinResult, keyOff2IdxOff := getBestIndexJoinPathResult(p, ds, innerJoinKeys, outerJoinKeys, indexValid) - if indexJoinResult == nil { - return nil - } - joins = make([]base.PhysicalPlan, 0, 3) - rangeInfo := indexJoinPathRangeInfo(p.SCtx(), outerJoinKeys, indexJoinResult) - maxOneRow := false - if indexJoinResult.chosenPath.Index.Unique && indexJoinResult.usedColsLen == len(indexJoinResult.chosenPath.FullIdxCols) { - l := len(indexJoinResult.chosenAccess) - if l == 0 { - maxOneRow = true - } else { - sf, ok := indexJoinResult.chosenAccess[l-1].(*expression.ScalarFunction) - maxOneRow = ok && (sf.FuncName.L == ast.EQ) - } - } - innerTask := constructInnerIndexScanTask(p, prop, wrapper, indexJoinResult.chosenPath, indexJoinResult.chosenRanges.Range(), indexJoinResult.chosenRemained, innerJoinKeys, indexJoinResult.idxOff2KeyOff, rangeInfo, false, false, avgInnerRowCnt, maxOneRow) - failpoint.Inject("MockOnlyEnableIndexHashJoin", func(val failpoint.Value) { - if val.(bool) && !p.SCtx().GetSessionVars().InRestrictedSQL && innerTask != nil { - failpoint.Return(constructIndexHashJoin(p, prop, outerIdx, innerTask, indexJoinResult.chosenRanges, keyOff2IdxOff, indexJoinResult.chosenPath, indexJoinResult.lastColManager)) - } - }) - if innerTask != nil { - joins = append(joins, constructIndexJoin(p, prop, outerIdx, innerTask, indexJoinResult.chosenRanges, keyOff2IdxOff, indexJoinResult.chosenPath, indexJoinResult.lastColManager, true)...) - // We can reuse the `innerTask` here since index nested loop hash join - // do not need the inner child to promise the order. - joins = append(joins, constructIndexHashJoin(p, prop, outerIdx, innerTask, indexJoinResult.chosenRanges, keyOff2IdxOff, indexJoinResult.chosenPath, indexJoinResult.lastColManager)...) - } - // The index merge join's inner plan is different from index join, so we - // should construct another inner plan for it. - // Because we can't keep order for union scan, if there is a union scan in inner task, - // we can't construct index merge join. - if !wrapper.hasDitryWrite { - innerTask2 := constructInnerIndexScanTask(p, prop, wrapper, indexJoinResult.chosenPath, indexJoinResult.chosenRanges.Range(), indexJoinResult.chosenRemained, innerJoinKeys, indexJoinResult.idxOff2KeyOff, rangeInfo, true, !prop.IsSortItemEmpty() && prop.SortItems[0].Desc, avgInnerRowCnt, maxOneRow) - if innerTask2 != nil { - joins = append(joins, constructIndexMergeJoin(p, prop, outerIdx, innerTask2, indexJoinResult.chosenRanges, keyOff2IdxOff, indexJoinResult.chosenPath, indexJoinResult.lastColManager)...) - } - } - return joins -} - -// constructInnerTableScanTask is specially used to construct the inner plan for PhysicalIndexJoin. -func constructInnerTableScanTask( - p *LogicalJoin, - prop *property.PhysicalProperty, - wrapper *indexJoinInnerChildWrapper, - ranges ranger.Ranges, - _ []*expression.Column, - rangeInfo string, - keepOrder bool, - desc bool, - rowCount float64, -) base.Task { - ds := wrapper.ds - // If `ds.TableInfo.GetPartitionInfo() != nil`, - // it means the data source is a partition table reader. - // If the inner task need to keep order, the partition table reader can't satisfy it. 
- if keepOrder && ds.TableInfo.GetPartitionInfo() != nil { - return nil - } - ts := PhysicalTableScan{ - Table: ds.TableInfo, - Columns: ds.Columns, - TableAsName: ds.TableAsName, - DBName: ds.DBName, - filterCondition: ds.PushedDownConds, - Ranges: ranges, - rangeInfo: rangeInfo, - KeepOrder: keepOrder, - Desc: desc, - physicalTableID: ds.PhysicalTableID, - isPartition: ds.PartitionDefIdx != nil, - tblCols: ds.TblCols, - tblColHists: ds.TblColHists, - }.Init(ds.SCtx(), ds.QueryBlockOffset()) - ts.SetSchema(ds.Schema().Clone()) - if rowCount <= 0 { - rowCount = float64(1) - } - selectivity := float64(1) - countAfterAccess := rowCount - if len(ts.filterCondition) > 0 { - var err error - selectivity, _, err = cardinality.Selectivity(ds.SCtx(), ds.TableStats.HistColl, ts.filterCondition, ds.PossibleAccessPaths) - if err != nil || selectivity <= 0 { - logutil.BgLogger().Debug("unexpected selectivity, use selection factor", zap.Float64("selectivity", selectivity), zap.String("table", ts.TableAsName.L)) - selectivity = cost.SelectionFactor - } - // rowCount is computed from result row count of join, which has already accounted the filters on DataSource, - // i.e, rowCount equals to `countAfterAccess * selectivity`. - countAfterAccess = rowCount / selectivity - } - ts.SetStats(&property.StatsInfo{ - // TableScan as inner child of IndexJoin can return at most 1 tuple for each outer row. - RowCount: math.Min(1.0, countAfterAccess), - StatsVersion: ds.StatsInfo().StatsVersion, - // NDV would not be used in cost computation of IndexJoin, set leave it as default nil. - }) - usedStats := p.SCtx().GetSessionVars().StmtCtx.GetUsedStatsInfo(false) - if usedStats != nil && usedStats.GetUsedInfo(ts.physicalTableID) != nil { - ts.usedStatsInfo = usedStats.GetUsedInfo(ts.physicalTableID) - } - copTask := &CopTask{ - tablePlan: ts, - indexPlanFinished: true, - tblColHists: ds.TblColHists, - keepOrder: ts.KeepOrder, - } - copTask.physPlanPartInfo = &PhysPlanPartInfo{ - PruningConds: ds.AllConds, - PartitionNames: ds.PartitionNames, - Columns: ds.TblCols, - ColumnNames: ds.OutputNames(), - } - ts.PlanPartInfo = copTask.physPlanPartInfo - selStats := ts.StatsInfo().Scale(selectivity) - ts.addPushedDownSelection(copTask, selStats) - return constructIndexJoinInnerSideTask(p, prop, copTask, ds, nil, wrapper) -} - -func constructInnerByZippedChildren(prop *property.PhysicalProperty, zippedChildren []base.LogicalPlan, child base.PhysicalPlan) base.PhysicalPlan { - for i := len(zippedChildren) - 1; i >= 0; i-- { - switch x := zippedChildren[i].(type) { - case *logicalop.LogicalUnionScan: - child = constructInnerUnionScan(prop, x, child) - case *logicalop.LogicalProjection: - child = constructInnerProj(prop, x, child) - case *LogicalSelection: - child = constructInnerSel(prop, x, child) - case *LogicalAggregation: - child = constructInnerAgg(prop, x, child) - } - } - return child -} - -func constructInnerAgg(prop *property.PhysicalProperty, logicalAgg *LogicalAggregation, child base.PhysicalPlan) base.PhysicalPlan { - if logicalAgg == nil { - return child - } - physicalHashAgg := NewPhysicalHashAgg(logicalAgg, logicalAgg.StatsInfo(), prop) - physicalHashAgg.SetSchema(logicalAgg.Schema().Clone()) - physicalHashAgg.SetChildren(child) - return physicalHashAgg -} - -func constructInnerSel(prop *property.PhysicalProperty, sel *LogicalSelection, child base.PhysicalPlan) base.PhysicalPlan { - if sel == nil { - return child - } - physicalSel := PhysicalSelection{ - Conditions: sel.Conditions, - }.Init(sel.SCtx(), 
sel.StatsInfo(), sel.QueryBlockOffset(), prop) - physicalSel.SetChildren(child) - return physicalSel -} - -func constructInnerProj(prop *property.PhysicalProperty, proj *logicalop.LogicalProjection, child base.PhysicalPlan) base.PhysicalPlan { - if proj == nil { - return child - } - physicalProj := PhysicalProjection{ - Exprs: proj.Exprs, - CalculateNoDelay: proj.CalculateNoDelay, - AvoidColumnEvaluator: proj.AvoidColumnEvaluator, - }.Init(proj.SCtx(), proj.StatsInfo(), proj.QueryBlockOffset(), prop) - physicalProj.SetChildren(child) - physicalProj.SetSchema(proj.Schema()) - return physicalProj -} - -func constructInnerUnionScan(prop *property.PhysicalProperty, us *logicalop.LogicalUnionScan, reader base.PhysicalPlan) base.PhysicalPlan { - if us == nil { - return reader - } - // Use `reader.StatsInfo()` instead of `us.StatsInfo()` because it should be more accurate. No need to specify - // childrenReqProps now since we have got reader already. - physicalUnionScan := PhysicalUnionScan{ - Conditions: us.Conditions, - HandleCols: us.HandleCols, - }.Init(us.SCtx(), reader.StatsInfo(), us.QueryBlockOffset(), prop) - physicalUnionScan.SetChildren(reader) - return physicalUnionScan -} - -// getColsNDVLowerBoundFromHistColl tries to get a lower bound of the NDV of columns (whose uniqueIDs are colUIDs). -func getColsNDVLowerBoundFromHistColl(colUIDs []int64, histColl *statistics.HistColl) int64 { - if len(colUIDs) == 0 || histColl == nil { - return -1 - } - - // 1. Try to get NDV from column stats if it's a single column. - if len(colUIDs) == 1 && histColl.ColNum() > 0 { - uid := colUIDs[0] - if colStats := histColl.GetCol(uid); colStats != nil && colStats.IsStatsInitialized() { - return colStats.NDV - } - } - - slices.Sort(colUIDs) - - // 2. Try to get NDV from index stats. - // Note that we don't need to specially handle prefix index here, because the NDV of a prefix index is - // equal or less than the corresponding normal index, and that's safe here since we want a lower bound. - for idxID, idxCols := range histColl.Idx2ColUniqueIDs { - if len(idxCols) != len(colUIDs) { - continue - } - orderedIdxCols := make([]int64, len(idxCols)) - copy(orderedIdxCols, idxCols) - slices.Sort(orderedIdxCols) - if !slices.Equal(orderedIdxCols, colUIDs) { - continue - } - if idxStats := histColl.GetIdx(idxID); idxStats != nil && idxStats.IsStatsInitialized() { - return idxStats.NDV - } - } - - // TODO: if there's an index that contains the expected columns, we can also make use of its NDV. - // For example, NDV(a,b,c) / NDV(c) is a safe lower bound of NDV(a,b). - - // 3. If we still haven't got an NDV, we use the maximum NDV in the column stats as a lower bound. - maxNDV := int64(-1) - for _, uid := range colUIDs { - colStats := histColl.GetCol(uid) - if colStats == nil || !colStats.IsStatsInitialized() { - continue - } - maxNDV = max(maxNDV, colStats.NDV) - } - return maxNDV -} - -// constructInnerIndexScanTask is specially used to construct the inner plan for PhysicalIndexJoin. -func constructInnerIndexScanTask( - p *LogicalJoin, - prop *property.PhysicalProperty, - wrapper *indexJoinInnerChildWrapper, - path *util.AccessPath, - ranges ranger.Ranges, - filterConds []expression.Expression, - _ []*expression.Column, - idxOffset2joinKeyOffset []int, - rangeInfo string, - keepOrder bool, - desc bool, - rowCount float64, - maxOneRow bool, -) base.Task { - ds := wrapper.ds - // If `ds.TableInfo.GetPartitionInfo() != nil`, - // it means the data source is a partition table reader. 
- // If the inner task need to keep order, the partition table reader can't satisfy it. - if keepOrder && ds.TableInfo.GetPartitionInfo() != nil { - return nil - } - is := PhysicalIndexScan{ - Table: ds.TableInfo, - TableAsName: ds.TableAsName, - DBName: ds.DBName, - Columns: ds.Columns, - Index: path.Index, - IdxCols: path.IdxCols, - IdxColLens: path.IdxColLens, - dataSourceSchema: ds.Schema(), - KeepOrder: keepOrder, - Ranges: ranges, - rangeInfo: rangeInfo, - Desc: desc, - isPartition: ds.PartitionDefIdx != nil, - physicalTableID: ds.PhysicalTableID, - tblColHists: ds.TblColHists, - pkIsHandleCol: ds.getPKIsHandleCol(), - }.Init(ds.SCtx(), ds.QueryBlockOffset()) - cop := &CopTask{ - indexPlan: is, - tblColHists: ds.TblColHists, - tblCols: ds.TblCols, - keepOrder: is.KeepOrder, - } - cop.physPlanPartInfo = &PhysPlanPartInfo{ - PruningConds: ds.AllConds, - PartitionNames: ds.PartitionNames, - Columns: ds.TblCols, - ColumnNames: ds.OutputNames(), - } - if !path.IsSingleScan { - // On this way, it's double read case. - ts := PhysicalTableScan{ - Columns: ds.Columns, - Table: is.Table, - TableAsName: ds.TableAsName, - DBName: ds.DBName, - isPartition: ds.PartitionDefIdx != nil, - physicalTableID: ds.PhysicalTableID, - tblCols: ds.TblCols, - tblColHists: ds.TblColHists, - }.Init(ds.SCtx(), ds.QueryBlockOffset()) - ts.schema = is.dataSourceSchema.Clone() - if ds.TableInfo.IsCommonHandle { - commonHandle := ds.HandleCols.(*util.CommonHandleCols) - for _, col := range commonHandle.GetColumns() { - if ts.schema.ColumnIndex(col) == -1 { - ts.Schema().Append(col) - ts.Columns = append(ts.Columns, col.ToInfo()) - cop.needExtraProj = true - } - } - } - // We set `StatsVersion` here and fill other fields in `(*copTask).finishIndexPlan`. Since `copTask.indexPlan` may - // change before calling `(*copTask).finishIndexPlan`, we don't know the stats information of `ts` currently and on - // the other hand, it may be hard to identify `StatsVersion` of `ts` in `(*copTask).finishIndexPlan`. - ts.SetStats(&property.StatsInfo{StatsVersion: ds.TableStats.StatsVersion}) - usedStats := p.SCtx().GetSessionVars().StmtCtx.GetUsedStatsInfo(false) - if usedStats != nil && usedStats.GetUsedInfo(ts.physicalTableID) != nil { - ts.usedStatsInfo = usedStats.GetUsedInfo(ts.physicalTableID) - } - // If inner cop task need keep order, the extraHandleCol should be set. - if cop.keepOrder && !ds.TableInfo.IsCommonHandle { - var needExtraProj bool - cop.extraHandleCol, needExtraProj = ts.appendExtraHandleCol(ds) - cop.needExtraProj = cop.needExtraProj || needExtraProj - } - if cop.needExtraProj { - cop.originSchema = ds.Schema() - } - cop.tablePlan = ts - } - if cop.tablePlan != nil && ds.TableInfo.IsCommonHandle { - cop.commonHandleCols = ds.CommonHandleCols - } - is.initSchema(append(path.FullIdxCols, ds.CommonHandleCols...), cop.tablePlan != nil) - indexConds, tblConds := ds.splitIndexFilterConditions(filterConds, path.FullIdxCols, path.FullIdxColLens) - - // Note: due to a regression in JOB workload, we use the optimizer fix control to enable this for now. - // - // Because we are estimating an average row count of the inner side corresponding to each row from the outer side, - // the estimated row count of the IndexScan should be no larger than (total row count / NDV of join key columns). - // We can calculate the lower bound of the NDV therefore we can get an upper bound of the row count here. 
- rowCountUpperBound := -1.0 - fixControlOK := fixcontrol.GetBoolWithDefault(ds.SCtx().GetSessionVars().GetOptimizerFixControlMap(), fixcontrol.Fix44855, false) - if fixControlOK && ds.TableStats != nil { - usedColIDs := make([]int64, 0) - // We only consider columns in this index that (1) are used to probe as join key, - // and (2) are not prefix column in the index (for which we can't easily get a lower bound) - for idxOffset, joinKeyOffset := range idxOffset2joinKeyOffset { - if joinKeyOffset < 0 || - path.FullIdxColLens[idxOffset] != types.UnspecifiedLength || - path.FullIdxCols[idxOffset] == nil { - continue - } - usedColIDs = append(usedColIDs, path.FullIdxCols[idxOffset].UniqueID) - } - joinKeyNDV := getColsNDVLowerBoundFromHistColl(usedColIDs, ds.TableStats.HistColl) - if joinKeyNDV > 0 { - rowCountUpperBound = ds.TableStats.RowCount / float64(joinKeyNDV) - } - } - - if rowCountUpperBound > 0 { - rowCount = math.Min(rowCount, rowCountUpperBound) - } - if maxOneRow { - // Theoretically, this line is unnecessary because row count estimation of join should guarantee rowCount is not larger - // than 1.0; however, there may be rowCount larger than 1.0 in reality, e.g, pseudo statistics cases, which does not reflect - // unique constraint in NDV. - rowCount = math.Min(rowCount, 1.0) - } - tmpPath := &util.AccessPath{ - IndexFilters: indexConds, - TableFilters: tblConds, - CountAfterIndex: rowCount, - CountAfterAccess: rowCount, - } - // Assume equal conditions used by index join and other conditions are independent. - if len(tblConds) > 0 { - selectivity, _, err := cardinality.Selectivity(ds.SCtx(), ds.TableStats.HistColl, tblConds, ds.PossibleAccessPaths) - if err != nil || selectivity <= 0 { - logutil.BgLogger().Debug("unexpected selectivity, use selection factor", zap.Float64("selectivity", selectivity), zap.String("table", ds.TableAsName.L)) - selectivity = cost.SelectionFactor - } - // rowCount is computed from result row count of join, which has already accounted the filters on DataSource, - // i.e, rowCount equals to `countAfterIndex * selectivity`. 
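[Editor's note] To make the comment above concrete, the same arithmetic in isolation with made-up numbers: the joined row count already includes the table filters, so dividing by their estimated selectivity recovers CountAfterIndex, and both caps are re-applied afterwards.

package main

import (
	"fmt"
	"math"
)

func main() {
	rowCount := 4.0            // per-outer-row estimate after the join and table filters
	selectivity := 0.2         // estimated selectivity of the pushed-down table filters
	rowCountUpperBound := 15.0 // NDV-based cap from the previous step
	maxOneRow := false         // true when a unique key guarantees at most one match

	cnt := rowCount / selectivity // 20 rows expected before the table filters run
	if rowCountUpperBound > 0 {
		cnt = math.Min(cnt, rowCountUpperBound) // the NDV-based cap still applies: 15
	}
	if maxOneRow {
		cnt = math.Min(cnt, 1.0)
	}
	fmt.Println(cnt) // 15
}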
- cnt := rowCount / selectivity - if rowCountUpperBound > 0 { - cnt = math.Min(cnt, rowCountUpperBound) - } - if maxOneRow { - cnt = math.Min(cnt, 1.0) - } - tmpPath.CountAfterIndex = cnt - tmpPath.CountAfterAccess = cnt - } - if len(indexConds) > 0 { - selectivity, _, err := cardinality.Selectivity(ds.SCtx(), ds.TableStats.HistColl, indexConds, ds.PossibleAccessPaths) - if err != nil || selectivity <= 0 { - logutil.BgLogger().Debug("unexpected selectivity, use selection factor", zap.Float64("selectivity", selectivity), zap.String("table", ds.TableAsName.L)) - selectivity = cost.SelectionFactor - } - cnt := tmpPath.CountAfterIndex / selectivity - if rowCountUpperBound > 0 { - cnt = math.Min(cnt, rowCountUpperBound) - } - if maxOneRow { - cnt = math.Min(cnt, 1.0) - } - tmpPath.CountAfterAccess = cnt - } - is.SetStats(ds.TableStats.ScaleByExpectCnt(tmpPath.CountAfterAccess)) - usedStats := ds.SCtx().GetSessionVars().StmtCtx.GetUsedStatsInfo(false) - if usedStats != nil && usedStats.GetUsedInfo(is.physicalTableID) != nil { - is.usedStatsInfo = usedStats.GetUsedInfo(is.physicalTableID) - } - finalStats := ds.TableStats.ScaleByExpectCnt(rowCount) - if err := is.addPushedDownSelection(cop, ds, tmpPath, finalStats); err != nil { - logutil.BgLogger().Warn("unexpected error happened during addPushedDownSelection function", zap.Error(err)) - return nil - } - return constructIndexJoinInnerSideTask(p, prop, cop, ds, path, wrapper) -} - -// construct the inner join task by inner child plan tree -// The Logical include two parts: logicalplan->physicalplan, physicalplan->task -// Step1: whether agg can be pushed down to coprocessor -// -// Step1.1: If the agg can be pushded down to coprocessor, we will build a copTask and attach the agg to the copTask -// There are two kinds of agg: stream agg and hash agg. Stream agg depends on some conditions, such as the group by cols -// -// Step2: build other inner plan node to task -func constructIndexJoinInnerSideTask(p *LogicalJoin, prop *property.PhysicalProperty, dsCopTask *CopTask, ds *DataSource, path *util.AccessPath, wrapper *indexJoinInnerChildWrapper) base.Task { - var la *LogicalAggregation - var canPushAggToCop bool - if len(wrapper.zippedChildren) > 0 { - la, canPushAggToCop = wrapper.zippedChildren[len(wrapper.zippedChildren)-1].(*LogicalAggregation) - if la != nil && la.HasDistinct() { - // TODO: remove AllowDistinctAggPushDown after the cost estimation of distinct pushdown is implemented. - // If AllowDistinctAggPushDown is set to true, we should not consider RootTask. - if !la.SCtx().GetSessionVars().AllowDistinctAggPushDown { - canPushAggToCop = false - } - } - } - - // If the bottom plan is not aggregation or the aggregation can't be pushed to coprocessor, we will construct a root task directly. - if !canPushAggToCop { - result := dsCopTask.ConvertToRootTask(ds.SCtx()).(*RootTask) - result.SetPlan(constructInnerByZippedChildren(prop, wrapper.zippedChildren, result.GetPlan())) - return result - } - - // Try stream aggregation first. - // We will choose the stream aggregation if the following conditions are met: - // 1. Force hint stream agg by /*+ stream_agg() */ - // 2. Other conditions copy from getStreamAggs() in exhaust_physical_plans.go - _, preferStream := la.ResetHintIfConflicted() - for _, aggFunc := range la.AggFuncs { - if aggFunc.Mode == aggregation.FinalMode { - preferStream = false - break - } - } - // group by a + b is not interested in any order. 
- groupByCols := la.GetGroupByCols() - if len(groupByCols) != len(la.GroupByItems) { - preferStream = false - } - if la.HasDistinct() && !la.DistinctArgsMeetsProperty() { - preferStream = false - } - // sort items must be the super set of group by items - if path != nil && path.Index != nil && !path.Index.MVIndex && - ds.TableInfo.GetPartitionInfo() == nil { - if len(path.IdxCols) < len(groupByCols) { - preferStream = false - } else { - sctx := p.SCtx() - for i, groupbyCol := range groupByCols { - if path.IdxColLens[i] != types.UnspecifiedLength || - !groupbyCol.EqualByExprAndID(sctx.GetExprCtx().GetEvalCtx(), path.IdxCols[i]) { - preferStream = false - } - } - } - } else { - preferStream = false - } - - // build physical agg and attach to task - var aggTask base.Task - // build stream agg and change ds keep order to true - if preferStream { - newGbyItems := make([]expression.Expression, len(la.GroupByItems)) - copy(newGbyItems, la.GroupByItems) - newAggFuncs := make([]*aggregation.AggFuncDesc, len(la.AggFuncs)) - copy(newAggFuncs, la.AggFuncs) - streamAgg := basePhysicalAgg{ - GroupByItems: newGbyItems, - AggFuncs: newAggFuncs, - }.initForStream(la.SCtx(), la.StatsInfo(), la.QueryBlockOffset(), prop) - streamAgg.SetSchema(la.Schema().Clone()) - // change to keep order for index scan and dsCopTask - if dsCopTask.indexPlan != nil { - // get the index scan from dsCopTask.indexPlan - physicalIndexScan, _ := dsCopTask.indexPlan.(*PhysicalIndexScan) - if physicalIndexScan == nil && len(dsCopTask.indexPlan.Children()) == 1 { - physicalIndexScan, _ = dsCopTask.indexPlan.Children()[0].(*PhysicalIndexScan) - } - if physicalIndexScan != nil { - physicalIndexScan.KeepOrder = true - dsCopTask.keepOrder = true - aggTask = streamAgg.Attach2Task(dsCopTask) - } - } - } - - // build hash agg, when the stream agg is illegal such as the order by prop is not matched - if aggTask == nil { - physicalHashAgg := NewPhysicalHashAgg(la, la.StatsInfo(), prop) - physicalHashAgg.SetSchema(la.Schema().Clone()) - aggTask = physicalHashAgg.Attach2Task(dsCopTask) - } - - // build other inner plan node to task - result, ok := aggTask.(*RootTask) - if !ok { - return nil - } - result.SetPlan(constructInnerByZippedChildren(prop, wrapper.zippedChildren[0:len(wrapper.zippedChildren)-1], result.p)) - return result -} - -func filterIndexJoinBySessionVars(sc base.PlanContext, indexJoins []base.PhysicalPlan) []base.PhysicalPlan { - if sc.GetSessionVars().EnableIndexMergeJoin { - return indexJoins - } - for i := len(indexJoins) - 1; i >= 0; i-- { - if _, ok := indexJoins[i].(*PhysicalIndexMergeJoin); ok { - indexJoins = append(indexJoins[:i], indexJoins[i+1:]...) - } - } - return indexJoins -} - -const ( - joinLeft = 0 - joinRight = 1 - indexJoinMethod = 0 - indexHashJoinMethod = 1 - indexMergeJoinMethod = 2 -) - -func getIndexJoinSideAndMethod(join base.PhysicalPlan) (innerSide, joinMethod int, ok bool) { - var innerIdx int - switch ij := join.(type) { - case *PhysicalIndexJoin: - innerIdx = ij.getInnerChildIdx() - joinMethod = indexJoinMethod - case *PhysicalIndexHashJoin: - innerIdx = ij.getInnerChildIdx() - joinMethod = indexHashJoinMethod - case *PhysicalIndexMergeJoin: - innerIdx = ij.getInnerChildIdx() - joinMethod = indexMergeJoinMethod - default: - return 0, 0, false - } - ok = true - innerSide = joinLeft - if innerIdx == 1 { - innerSide = joinRight - } - return -} - -// tryToGetIndexJoin returns all available index join plans, and the second returned value indicates whether this plan is enforced by hints. 
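[Editor's note] Before moving on to the index-join enumeration below, a toy version of the ordering check used when attaching a stream aggregation to the inner index scan above; the column names and the placeholder length constant are assumptions, not code from the patch. The idea is that every group-by column must be covered, in order and at full length, by the index columns, otherwise only hash aggregation is considered.

package main

import "fmt"

const unspecifiedLength = -1 // stands in for types.UnspecifiedLength: a full-length index column

// indexServesGroupBy reports whether an index can feed a stream aggregation:
// the group-by columns must match a full-length prefix of the index columns.
func indexServesGroupBy(idxCols []string, idxColLens []int, groupByCols []string) bool {
	if len(idxCols) < len(groupByCols) {
		return false
	}
	for i, col := range groupByCols {
		if idxColLens[i] != unspecifiedLength || idxCols[i] != col {
			return false
		}
	}
	return true
}

func main() {
	// idx(a, b, c) with full-length columns can keep order for GROUP BY a, b.
	fmt.Println(indexServesGroupBy([]string{"a", "b", "c"}, []int{-1, -1, -1}, []string{"a", "b"})) // true
	// A prefix column such as a(10) breaks the ordering guarantee.
	fmt.Println(indexServesGroupBy([]string{"a", "b"}, []int{10, -1}, []string{"a", "b"})) // false
}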
-func tryToGetIndexJoin(p *LogicalJoin, prop *property.PhysicalProperty) (indexJoins []base.PhysicalPlan, canForced bool) { - // supportLeftOuter and supportRightOuter indicates whether this type of join - // supports the left side or right side to be the outer side. - var supportLeftOuter, supportRightOuter bool - switch p.JoinType { - case SemiJoin, AntiSemiJoin, LeftOuterSemiJoin, AntiLeftOuterSemiJoin, LeftOuterJoin: - supportLeftOuter = true - case RightOuterJoin: - supportRightOuter = true - case InnerJoin: - supportLeftOuter, supportRightOuter = true, true - } - candidates := make([]base.PhysicalPlan, 0, 2) - if supportLeftOuter { - candidates = append(candidates, getIndexJoinByOuterIdx(p, prop, 0)...) - } - if supportRightOuter { - candidates = append(candidates, getIndexJoinByOuterIdx(p, prop, 1)...) - } - - // Handle hints and variables about index join. - // The priority is: force hints like TIDB_INLJ > filter hints like NO_INDEX_JOIN > variables. - // Handle hints conflict first. - stmtCtx := p.SCtx().GetSessionVars().StmtCtx - if p.PreferAny(h.PreferLeftAsINLJInner, h.PreferRightAsINLJInner) && p.PreferAny(h.PreferNoIndexJoin) { - stmtCtx.SetHintWarning("Some INL_JOIN and NO_INDEX_JOIN hints conflict, NO_INDEX_JOIN may be ignored") - } - if p.PreferAny(h.PreferLeftAsINLHJInner, h.PreferRightAsINLHJInner) && p.PreferAny(h.PreferNoIndexHashJoin) { - stmtCtx.SetHintWarning("Some INL_HASH_JOIN and NO_INDEX_HASH_JOIN hints conflict, NO_INDEX_HASH_JOIN may be ignored") - } - if p.PreferAny(h.PreferLeftAsINLMJInner, h.PreferRightAsINLMJInner) && p.PreferAny(h.PreferNoIndexMergeJoin) { - stmtCtx.SetHintWarning("Some INL_MERGE_JOIN and NO_INDEX_MERGE_JOIN hints conflict, NO_INDEX_MERGE_JOIN may be ignored") - } - - candidates, canForced = handleForceIndexJoinHints(p, prop, candidates) - if canForced { - return candidates, canForced - } - candidates = handleFilterIndexJoinHints(p, candidates) - return filterIndexJoinBySessionVars(p.SCtx(), candidates), false -} - -func handleFilterIndexJoinHints(p *LogicalJoin, candidates []base.PhysicalPlan) []base.PhysicalPlan { - if !p.PreferAny(h.PreferNoIndexJoin, h.PreferNoIndexHashJoin, h.PreferNoIndexMergeJoin) { - return candidates // no filter index join hints - } - filtered := make([]base.PhysicalPlan, 0, len(candidates)) - for _, candidate := range candidates { - _, joinMethod, ok := getIndexJoinSideAndMethod(candidate) - if !ok { - continue - } - if (p.PreferAny(h.PreferNoIndexJoin) && joinMethod == indexJoinMethod) || - (p.PreferAny(h.PreferNoIndexHashJoin) && joinMethod == indexHashJoinMethod) || - (p.PreferAny(h.PreferNoIndexMergeJoin) && joinMethod == indexMergeJoinMethod) { - continue - } - filtered = append(filtered, candidate) - } - return filtered -} - -// handleForceIndexJoinHints handles the force index join hints and returns all plans that can satisfy the hints. 
-func handleForceIndexJoinHints(p *LogicalJoin, prop *property.PhysicalProperty, candidates []base.PhysicalPlan) (indexJoins []base.PhysicalPlan, canForced bool) { - if !p.PreferAny(h.PreferRightAsINLJInner, h.PreferRightAsINLHJInner, h.PreferRightAsINLMJInner, - h.PreferLeftAsINLJInner, h.PreferLeftAsINLHJInner, h.PreferLeftAsINLMJInner) { - return candidates, false // no force index join hints - } - forced := make([]base.PhysicalPlan, 0, len(candidates)) - for _, candidate := range candidates { - innerSide, joinMethod, ok := getIndexJoinSideAndMethod(candidate) - if !ok { - continue - } - if (p.PreferAny(h.PreferLeftAsINLJInner) && innerSide == joinLeft && joinMethod == indexJoinMethod) || - (p.PreferAny(h.PreferRightAsINLJInner) && innerSide == joinRight && joinMethod == indexJoinMethod) || - (p.PreferAny(h.PreferLeftAsINLHJInner) && innerSide == joinLeft && joinMethod == indexHashJoinMethod) || - (p.PreferAny(h.PreferRightAsINLHJInner) && innerSide == joinRight && joinMethod == indexHashJoinMethod) || - (p.PreferAny(h.PreferLeftAsINLMJInner) && innerSide == joinLeft && joinMethod == indexMergeJoinMethod) || - (p.PreferAny(h.PreferRightAsINLMJInner) && innerSide == joinRight && joinMethod == indexMergeJoinMethod) { - forced = append(forced, candidate) - } - } - - if len(forced) > 0 { - return forced, true - } - // Cannot find any valid index join plan with these force hints. - // Print warning message if any hints cannot work. - // If the required property is not empty, we will enforce it and try the hint again. - // So we only need to generate warning message when the property is empty. - if prop.IsSortItemEmpty() { - var indexJoinTables, indexHashJoinTables, indexMergeJoinTables []h.HintedTable - if p.HintInfo != nil { - t := p.HintInfo.IndexJoin - indexJoinTables, indexHashJoinTables, indexMergeJoinTables = t.INLJTables, t.INLHJTables, t.INLMJTables - } - var errMsg string - switch { - case p.PreferAny(h.PreferLeftAsINLJInner, h.PreferRightAsINLJInner): // prefer index join - errMsg = fmt.Sprintf("Optimizer Hint %s or %s is inapplicable", h.Restore2JoinHint(h.HintINLJ, indexJoinTables), h.Restore2JoinHint(h.TiDBIndexNestedLoopJoin, indexJoinTables)) - case p.PreferAny(h.PreferLeftAsINLHJInner, h.PreferRightAsINLHJInner): // prefer index hash join - errMsg = fmt.Sprintf("Optimizer Hint %s is inapplicable", h.Restore2JoinHint(h.HintINLHJ, indexHashJoinTables)) - case p.PreferAny(h.PreferLeftAsINLMJInner, h.PreferRightAsINLMJInner): // prefer index merge join - errMsg = fmt.Sprintf("Optimizer Hint %s is inapplicable", h.Restore2JoinHint(h.HintINLMJ, indexMergeJoinTables)) - } - // Append inapplicable reason. - if len(p.EqualConditions) == 0 { - errMsg += " without column equal ON condition" - } - // Generate warning message to client. 
- p.SCtx().GetSessionVars().StmtCtx.SetHintWarning(errMsg) - } - return candidates, false -} - -func checkChildFitBC(p base.Plan) bool { - if p.StatsInfo().HistColl == nil { - return p.SCtx().GetSessionVars().BroadcastJoinThresholdCount == -1 || p.StatsInfo().Count() < p.SCtx().GetSessionVars().BroadcastJoinThresholdCount - } - avg := cardinality.GetAvgRowSize(p.SCtx(), p.StatsInfo().HistColl, p.Schema().Columns, false, false) - sz := avg * float64(p.StatsInfo().Count()) - return p.SCtx().GetSessionVars().BroadcastJoinThresholdSize == -1 || sz < float64(p.SCtx().GetSessionVars().BroadcastJoinThresholdSize) -} - -func calcBroadcastExchangeSize(p base.Plan, mppStoreCnt int) (row float64, size float64, hasSize bool) { - s := p.StatsInfo() - row = float64(s.Count()) * float64(mppStoreCnt-1) - if s.HistColl == nil { - return row, 0, false - } - avg := cardinality.GetAvgRowSize(p.SCtx(), s.HistColl, p.Schema().Columns, false, false) - size = avg * row - return row, size, true -} - -func calcBroadcastExchangeSizeByChild(p1 base.Plan, p2 base.Plan, mppStoreCnt int) (row float64, size float64, hasSize bool) { - row1, size1, hasSize1 := calcBroadcastExchangeSize(p1, mppStoreCnt) - row2, size2, hasSize2 := calcBroadcastExchangeSize(p2, mppStoreCnt) - - // broadcast exchange size: - // Build: (mppStoreCnt - 1) * sizeof(BuildTable) - // Probe: 0 - // choose the child plan with the maximum approximate value as Probe - - if hasSize1 && hasSize2 { - return math.Min(row1, row2), math.Min(size1, size2), true - } - - return math.Min(row1, row2), 0, false -} - -func calcHashExchangeSize(p base.Plan, mppStoreCnt int) (row float64, sz float64, hasSize bool) { - s := p.StatsInfo() - row = float64(s.Count()) * float64(mppStoreCnt-1) / float64(mppStoreCnt) - if s.HistColl == nil { - return row, 0, false - } - avg := cardinality.GetAvgRowSize(p.SCtx(), s.HistColl, p.Schema().Columns, false, false) - sz = avg * row - return row, sz, true -} - -func calcHashExchangeSizeByChild(p1 base.Plan, p2 base.Plan, mppStoreCnt int) (row float64, size float64, hasSize bool) { - row1, size1, hasSize1 := calcHashExchangeSize(p1, mppStoreCnt) - row2, size2, hasSize2 := calcHashExchangeSize(p2, mppStoreCnt) - - // hash exchange size: - // Build: sizeof(BuildTable) * (mppStoreCnt - 1) / mppStoreCnt - // Probe: sizeof(ProbeTable) * (mppStoreCnt - 1) / mppStoreCnt - - if hasSize1 && hasSize2 { - return row1 + row2, size1 + size2, true - } - return row1 + row2, 0, false -} - -// The size of `Build` hash table when using broadcast join is about `X`. -// The size of `Build` hash table when using shuffle join is about `X / (mppStoreCnt)`. -// It will cost more time to construct `Build` hash table and search `Probe` while using broadcast join. -// Set a scale factor (`mppStoreCnt^*`) when estimating broadcast join in `isJoinFitMPPBCJ` and `isJoinChildFitMPPBCJ` (based on TPCH benchmark, it has been verified in Q9). 
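[Editor's note] The comment block above compresses into a couple of formulas; the sketch below runs them on invented numbers (four MPP stores, a 10k-row build side, a 1M-row probe side) to show why the broadcast estimate is additionally multiplied by the store count before being compared with the shuffle estimate.

package main

import "fmt"

func main() {
	n := 4.0 // MPP store count
	buildRows, probeRows, avgRowSize := 10_000.0, 1_000_000.0, 64.0

	// Broadcast: the build side is shipped to every other store; the probe side stays put.
	szBroadcast := buildRows * (n - 1) * avgRowSize

	// Hash shuffle: on average each side sends (n-1)/n of its rows across the network.
	szShuffle := (buildRows + probeRows) * (n - 1) / n * avgRowSize

	// The broadcast hash table is built n times as large per store, hence the scale factor.
	preferBroadcast := szBroadcast*n <= szShuffle
	fmt.Printf("broadcast=%.0f shuffle=%.0f preferBroadcast=%v\n", szBroadcast, szShuffle, preferBroadcast)
}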
-
-func isJoinFitMPPBCJ(p *LogicalJoin, mppStoreCnt int) bool {
-	rowBC, szBC, hasSizeBC := calcBroadcastExchangeSizeByChild(p.Children()[0], p.Children()[1], mppStoreCnt)
-	rowHash, szHash, hasSizeHash := calcHashExchangeSizeByChild(p.Children()[0], p.Children()[1], mppStoreCnt)
-	if hasSizeBC && hasSizeHash {
-		return szBC*float64(mppStoreCnt) <= szHash
-	}
-	return rowBC*float64(mppStoreCnt) <= rowHash
-}
-
-func isJoinChildFitMPPBCJ(p *LogicalJoin, childIndexToBC int, mppStoreCnt int) bool {
-	rowBC, szBC, hasSizeBC := calcBroadcastExchangeSize(p.Children()[childIndexToBC], mppStoreCnt)
-	rowHash, szHash, hasSizeHash := calcHashExchangeSizeByChild(p.Children()[0], p.Children()[1], mppStoreCnt)
-
-	if hasSizeBC && hasSizeHash {
-		return szBC*float64(mppStoreCnt) <= szHash
-	}
-	return rowBC*float64(mppStoreCnt) <= rowHash
-}
-
-// If we can use mpp broadcast join, that's our first choice.
-func preferMppBCJ(p *LogicalJoin) bool {
-	if len(p.EqualConditions) == 0 && p.SCtx().GetSessionVars().AllowCartesianBCJ == 2 {
-		return true
-	}
-
-	onlyCheckChild1 := p.JoinType == LeftOuterJoin || p.JoinType == SemiJoin || p.JoinType == AntiSemiJoin
-	onlyCheckChild0 := p.JoinType == RightOuterJoin
-
-	if p.SCtx().GetSessionVars().PreferBCJByExchangeDataSize {
-		mppStoreCnt, err := p.SCtx().GetMPPClient().GetMPPStoreCount()
-
-		// No need to exchange data if there is only ONE mpp store. But the behavior of optimizer is unexpected if use broadcast way forcibly, such as tpch q4.
-		// TODO: always use broadcast way to exchange data if there is only ONE mpp store.
-
-		if err == nil && mppStoreCnt > 0 {
-			if !(onlyCheckChild1 || onlyCheckChild0) {
-				return isJoinFitMPPBCJ(p, mppStoreCnt)
-			}
-			if mppStoreCnt > 1 {
-				if onlyCheckChild1 {
-					return isJoinChildFitMPPBCJ(p, 1, mppStoreCnt)
-				} else if onlyCheckChild0 {
-					return isJoinChildFitMPPBCJ(p, 0, mppStoreCnt)
-				}
-			}
-			// If mppStoreCnt is ONE and only need to check one child plan, rollback to original way.
-			// Otherwise, the plan of tpch q4 may be unexpected.
- } - } - - if onlyCheckChild1 { - return checkChildFitBC(p.Children()[1]) - } else if onlyCheckChild0 { - return checkChildFitBC(p.Children()[0]) - } - return checkChildFitBC(p.Children()[0]) || checkChildFitBC(p.Children()[1]) -} - -func canExprsInJoinPushdown(p *LogicalJoin, storeType kv.StoreType) bool { - equalExprs := make([]expression.Expression, 0, len(p.EqualConditions)) - for _, eqCondition := range p.EqualConditions { - if eqCondition.FuncName.L == ast.NullEQ { - return false - } - equalExprs = append(equalExprs, eqCondition) - } - pushDownCtx := GetPushDownCtx(p.SCtx()) - if !expression.CanExprsPushDown(pushDownCtx, equalExprs, storeType) { - return false - } - if !expression.CanExprsPushDown(pushDownCtx, p.LeftConditions, storeType) { - return false - } - if !expression.CanExprsPushDown(pushDownCtx, p.RightConditions, storeType) { - return false - } - if !expression.CanExprsPushDown(pushDownCtx, p.OtherConditions, storeType) { - return false - } - return true -} - -func tryToGetMppHashJoin(p *LogicalJoin, prop *property.PhysicalProperty, useBCJ bool) []base.PhysicalPlan { - if !prop.IsSortItemEmpty() { - return nil - } - if prop.TaskTp != property.RootTaskType && prop.TaskTp != property.MppTaskType { - return nil - } - - if !expression.IsPushDownEnabled(p.JoinType.String(), kv.TiFlash) { - p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because join type `" + p.JoinType.String() + "` is blocked by blacklist, check `table mysql.expr_pushdown_blacklist;` for more information.") - return nil - } - - if p.JoinType != InnerJoin && p.JoinType != LeftOuterJoin && p.JoinType != RightOuterJoin && p.JoinType != SemiJoin && p.JoinType != AntiSemiJoin && p.JoinType != LeftOuterSemiJoin && p.JoinType != AntiLeftOuterSemiJoin { - p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because join type `" + p.JoinType.String() + "` is not supported now.") - return nil - } - - if len(p.EqualConditions) == 0 { - if !useBCJ { - p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because `Cartesian Product` is only supported by broadcast join, check value and documents of variables `tidb_broadcast_join_threshold_size` and `tidb_broadcast_join_threshold_count`.") - return nil - } - if p.SCtx().GetSessionVars().AllowCartesianBCJ == 0 { - p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because `Cartesian Product` is only supported by broadcast join, check value and documents of variable `tidb_opt_broadcast_cartesian_join`.") - return nil - } - } - if len(p.LeftConditions) != 0 && p.JoinType != LeftOuterJoin { - p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because there is a join that is not `left join` but has left conditions, which is not supported by mpp now, see github.com/pingcap/tidb/issues/26090 for more information.") - return nil - } - if len(p.RightConditions) != 0 && p.JoinType != RightOuterJoin { - p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because there is a join that is not `right join` but has right conditions, which is not supported by mpp now.") - return nil - } - - if prop.MPPPartitionTp == property.BroadcastType { - return nil - } - if !canExprsInJoinPushdown(p, kv.TiFlash) { - return nil - } - lkeys, rkeys, _, _ := p.GetJoinKeys() - lNAkeys, rNAKeys := p.GetNAJoinKeys() - // check match property - baseJoin := basePhysicalJoin{ - JoinType: p.JoinType, - LeftConditions: p.LeftConditions, - 
RightConditions: p.RightConditions, - OtherConditions: p.OtherConditions, - DefaultValues: p.DefaultValues, - LeftJoinKeys: lkeys, - RightJoinKeys: rkeys, - LeftNAJoinKeys: lNAkeys, - RightNAJoinKeys: rNAKeys, - } - // It indicates which side is the build side. - forceLeftToBuild := ((p.PreferJoinType & h.PreferLeftAsHJBuild) > 0) || ((p.PreferJoinType & h.PreferRightAsHJProbe) > 0) - forceRightToBuild := ((p.PreferJoinType & h.PreferRightAsHJBuild) > 0) || ((p.PreferJoinType & h.PreferLeftAsHJProbe) > 0) - if forceLeftToBuild && forceRightToBuild { - p.SCtx().GetSessionVars().StmtCtx.SetHintWarning( - "Some HASH_JOIN_BUILD and HASH_JOIN_PROBE hints are conflicts, please check the hints") - forceLeftToBuild = false - forceRightToBuild = false - } - preferredBuildIndex := 0 - fixedBuildSide := false // Used to indicate whether the build side for the MPP join is fixed or not. - if p.JoinType == InnerJoin { - if p.Children()[0].StatsInfo().Count() > p.Children()[1].StatsInfo().Count() { - preferredBuildIndex = 1 - } - } else if p.JoinType.IsSemiJoin() { - if !useBCJ && !p.IsNAAJ() && len(p.EqualConditions) > 0 && (p.JoinType == SemiJoin || p.JoinType == AntiSemiJoin) { - // TiFlash only supports Non-null_aware non-cross semi/anti_semi join to use both sides as build side - preferredBuildIndex = 1 - // MPPOuterJoinFixedBuildSide default value is false - // use MPPOuterJoinFixedBuildSide here as a way to disable using left table as build side - if !p.SCtx().GetSessionVars().MPPOuterJoinFixedBuildSide && p.Children()[1].StatsInfo().Count() > p.Children()[0].StatsInfo().Count() { - preferredBuildIndex = 0 - } - } else { - preferredBuildIndex = 1 - fixedBuildSide = true - } - } - if p.JoinType == LeftOuterJoin || p.JoinType == RightOuterJoin { - // TiFlash does not require that the build side must be the inner table for outer join. - // so we can choose the build side based on the row count, except that: - // 1. it is a broadcast join(for broadcast join, it makes sense to use the broadcast side as the build side) - // 2. or session variable MPPOuterJoinFixedBuildSide is set to true - // 3. or nullAware/cross joins - if useBCJ || p.IsNAAJ() || len(p.EqualConditions) == 0 || p.SCtx().GetSessionVars().MPPOuterJoinFixedBuildSide { - if !p.SCtx().GetSessionVars().MPPOuterJoinFixedBuildSide { - // The hint has higher priority than variable. - fixedBuildSide = true - } - if p.JoinType == LeftOuterJoin { - preferredBuildIndex = 1 - } - } else if p.Children()[0].StatsInfo().Count() > p.Children()[1].StatsInfo().Count() { - preferredBuildIndex = 1 - } - } - - if forceLeftToBuild || forceRightToBuild { - match := (forceLeftToBuild && preferredBuildIndex == 0) || (forceRightToBuild && preferredBuildIndex == 1) - if !match { - if fixedBuildSide { - // A warning will be generated if the build side is fixed, but we attempt to change it using the hint. - p.SCtx().GetSessionVars().StmtCtx.SetHintWarning( - "Some HASH_JOIN_BUILD and HASH_JOIN_PROBE hints cannot be utilized for MPP joins, please check the hints") - } else { - // The HASH_JOIN_BUILD OR HASH_JOIN_PROBE hints can take effective. 
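[Editor's note] A reduced model of the build-side decision described above and of the flip applied on the next removed line, using plain ints instead of planner types (the function name is invented): for an inner join the smaller child becomes the build side, and a HASH_JOIN_BUILD-style hint can flip that choice only when the build side is not fixed by the join type.

package main

import "fmt"

// chooseBuildSide mimics the preferredBuildIndex logic for an inner join.
// It returns 0 for the left child and 1 for the right child.
func chooseBuildSide(leftRows, rightRows int, forceLeftToBuild, fixedBuildSide bool) int {
	preferred := 0
	if leftRows > rightRows {
		preferred = 1 // build the smaller side by default
	}
	if forceLeftToBuild && preferred != 0 {
		if fixedBuildSide {
			fmt.Println("warning: hint cannot be applied, the build side is fixed")
		} else {
			preferred = 1 - preferred // the hint overrides the row-count heuristic
		}
	}
	return preferred
}

func main() {
	fmt.Println(chooseBuildSide(2_000_000, 50_000, false, false)) // 1: right side is smaller
	fmt.Println(chooseBuildSide(2_000_000, 50_000, true, false))  // 0: the hint wins
	fmt.Println(chooseBuildSide(2_000_000, 50_000, true, true))   // 1: fixed side, hint only warns
}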
- preferredBuildIndex = 1 - preferredBuildIndex - } - } - } - - // set preferredBuildIndex for test - failpoint.Inject("mockPreferredBuildIndex", func(val failpoint.Value) { - if !p.SCtx().GetSessionVars().InRestrictedSQL { - preferredBuildIndex = val.(int) - } - }) - - baseJoin.InnerChildIdx = preferredBuildIndex - childrenProps := make([]*property.PhysicalProperty, 2) - if useBCJ { - childrenProps[preferredBuildIndex] = &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.BroadcastType, CanAddEnforcer: true, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus} - expCnt := math.MaxFloat64 - if prop.ExpectedCnt < p.StatsInfo().RowCount { - expCntScale := prop.ExpectedCnt / p.StatsInfo().RowCount - expCnt = p.Children()[1-preferredBuildIndex].StatsInfo().RowCount * expCntScale - } - if prop.MPPPartitionTp == property.HashType { - lPartitionKeys, rPartitionKeys := p.GetPotentialPartitionKeys() - hashKeys := rPartitionKeys - if preferredBuildIndex == 1 { - hashKeys = lPartitionKeys - } - matches := prop.IsSubsetOf(hashKeys) - if len(matches) == 0 { - return nil - } - childrenProps[1-preferredBuildIndex] = &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: expCnt, MPPPartitionTp: property.HashType, MPPPartitionCols: prop.MPPPartitionCols, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus} - } else { - childrenProps[1-preferredBuildIndex] = &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: expCnt, MPPPartitionTp: property.AnyType, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus} - } - } else { - lPartitionKeys, rPartitionKeys := p.GetPotentialPartitionKeys() - if prop.MPPPartitionTp == property.HashType { - var matches []int - if p.JoinType == InnerJoin { - if matches = prop.IsSubsetOf(lPartitionKeys); len(matches) == 0 { - matches = prop.IsSubsetOf(rPartitionKeys) - } - } else if p.JoinType == RightOuterJoin { - // for right out join, only the right partition keys can possibly matches the prop, because - // the left partition keys will generate NULL values randomly - // todo maybe we can add a null-sensitive flag in the MPPPartitionColumn to indicate whether the partition column is - // null-sensitive(used in aggregation) or null-insensitive(used in join) - matches = prop.IsSubsetOf(rPartitionKeys) - } else { - // for left out join, only the left partition keys can possibly matches the prop, because - // the right partition keys will generate NULL values randomly - // for semi/anti semi/left out semi/anti left out semi join, only left partition keys are returned, - // so just check the left partition keys - matches = prop.IsSubsetOf(lPartitionKeys) - } - if len(matches) == 0 { - return nil - } - lPartitionKeys = choosePartitionKeys(lPartitionKeys, matches) - rPartitionKeys = choosePartitionKeys(rPartitionKeys, matches) - } - childrenProps[0] = &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.HashType, MPPPartitionCols: lPartitionKeys, CanAddEnforcer: true, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus} - childrenProps[1] = &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.HashType, MPPPartitionCols: rPartitionKeys, CanAddEnforcer: true, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus} - } - join := PhysicalHashJoin{ - basePhysicalJoin: baseJoin, - Concurrency: uint(p.SCtx().GetSessionVars().CopTiFlashConcurrencyFactor), 
- EqualConditions: p.EqualConditions, - NAEqualConditions: p.NAEQConditions, - storeTp: kv.TiFlash, - mppShuffleJoin: !useBCJ, - // Mpp Join has quite heavy cost. Even limit might not suspend it in time, so we don't scale the count. - }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset(), childrenProps...) - join.SetSchema(p.Schema()) - return []base.PhysicalPlan{join} -} - -func choosePartitionKeys(keys []*property.MPPPartitionColumn, matches []int) []*property.MPPPartitionColumn { - newKeys := make([]*property.MPPPartitionColumn, 0, len(matches)) - for _, id := range matches { - newKeys = append(newKeys, keys[id]) - } - return newKeys -} - -func exhaustPhysicalPlans4LogicalExpand(p *LogicalExpand, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { - // under the mpp task type, if the sort item is not empty, refuse it, cause expanded data doesn't support any sort items. - if !prop.IsSortItemEmpty() { - // false, meaning we can add a sort enforcer. - return nil, false, nil - } - // when TiDB Expand execution is introduced: we can deal with two kind of physical plans. - // RootTaskType means expand should be run at TiDB node. - // (RootTaskType is the default option, we can also generate a mpp candidate for it) - // MPPTaskType means expand should be run at TiFlash node. - if prop.TaskTp != property.RootTaskType && prop.TaskTp != property.MppTaskType { - return nil, true, nil - } - // now Expand mode can only be executed on TiFlash node. - // Upper layer shouldn't expect any mpp partition from an Expand operator. - // todo: data output from Expand operator should keep the origin data mpp partition. - if prop.TaskTp == property.MppTaskType && prop.MPPPartitionTp != property.AnyType { - return nil, true, nil - } - var physicalExpands []base.PhysicalPlan - // for property.RootTaskType and property.MppTaskType with no partition option, we can give an MPP Expand. - canPushToTiFlash := p.CanPushToCop(kv.TiFlash) - if p.SCtx().GetSessionVars().IsMPPAllowed() && canPushToTiFlash { - mppProp := prop.CloneEssentialFields() - mppProp.TaskTp = property.MppTaskType - expand := PhysicalExpand{ - GroupingSets: p.RollupGroupingSets, - LevelExprs: p.LevelExprs, - ExtraGroupingColNames: p.ExtraGroupingColNames, - }.Init(p.SCtx(), p.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), p.QueryBlockOffset(), mppProp) - expand.SetSchema(p.Schema()) - physicalExpands = append(physicalExpands, expand) - // when the MppTaskType is required, we can return the physical plan directly. - if prop.TaskTp == property.MppTaskType { - return physicalExpands, true, nil - } - } - // for property.RootTaskType, we can give a TiDB Expand. 
- { - taskTypes := []property.TaskType{property.CopSingleReadTaskType, property.CopMultiReadTaskType, property.MppTaskType, property.RootTaskType} - for _, taskType := range taskTypes { - // require cop task type for children.F - tidbProp := prop.CloneEssentialFields() - tidbProp.TaskTp = taskType - expand := PhysicalExpand{ - GroupingSets: p.RollupGroupingSets, - LevelExprs: p.LevelExprs, - ExtraGroupingColNames: p.ExtraGroupingColNames, - }.Init(p.SCtx(), p.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), p.QueryBlockOffset(), tidbProp) - expand.SetSchema(p.Schema()) - physicalExpands = append(physicalExpands, expand) - } - } - return physicalExpands, true, nil -} - -func exhaustPhysicalPlans4LogicalProjection(lp base.LogicalPlan, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { - p := lp.(*logicalop.LogicalProjection) - newProp, ok := p.TryToGetChildProp(prop) - if !ok { - return nil, true, nil - } - newProps := []*property.PhysicalProperty{newProp} - // generate a mpp task candidate if mpp mode is allowed - ctx := p.SCtx() - pushDownCtx := GetPushDownCtx(ctx) - if newProp.TaskTp != property.MppTaskType && ctx.GetSessionVars().IsMPPAllowed() && p.CanPushToCop(kv.TiFlash) && - expression.CanExprsPushDown(pushDownCtx, p.Exprs, kv.TiFlash) { - mppProp := newProp.CloneEssentialFields() - mppProp.TaskTp = property.MppTaskType - newProps = append(newProps, mppProp) - } - if newProp.TaskTp != property.CopSingleReadTaskType && ctx.GetSessionVars().AllowProjectionPushDown && p.CanPushToCop(kv.TiKV) && - expression.CanExprsPushDown(pushDownCtx, p.Exprs, kv.TiKV) && !expression.ContainVirtualColumn(p.Exprs) && - expression.ProjectionBenefitsFromPushedDown(p.Exprs, p.Children()[0].Schema().Len()) { - copProp := newProp.CloneEssentialFields() - copProp.TaskTp = property.CopSingleReadTaskType - newProps = append(newProps, copProp) - } - - ret := make([]base.PhysicalPlan, 0, len(newProps)) - for _, newProp := range newProps { - proj := PhysicalProjection{ - Exprs: p.Exprs, - CalculateNoDelay: p.CalculateNoDelay, - AvoidColumnEvaluator: p.AvoidColumnEvaluator, - }.Init(ctx, p.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), p.QueryBlockOffset(), newProp) - proj.SetSchema(p.Schema()) - ret = append(ret, proj) - } - return ret, true, nil -} - -func pushLimitOrTopNForcibly(p base.LogicalPlan) bool { - var meetThreshold bool - var preferPushDown *bool - switch lp := p.(type) { - case *logicalop.LogicalTopN: - preferPushDown = &lp.PreferLimitToCop - meetThreshold = lp.Count+lp.Offset <= uint64(lp.SCtx().GetSessionVars().LimitPushDownThreshold) - case *logicalop.LogicalLimit: - preferPushDown = &lp.PreferLimitToCop - meetThreshold = true // always push Limit down in this case since it has no side effect - default: - return false - } - - if *preferPushDown || meetThreshold { - if p.CanPushToCop(kv.TiKV) { - return true - } - if *preferPushDown { - p.SCtx().GetSessionVars().StmtCtx.SetHintWarning("Optimizer Hint LIMIT_TO_COP is inapplicable") - *preferPushDown = false - } - } - - return false -} - -func getPhysTopN(lt *logicalop.LogicalTopN, prop *property.PhysicalProperty) []base.PhysicalPlan { - allTaskTypes := []property.TaskType{property.CopSingleReadTaskType, property.CopMultiReadTaskType} - if !pushLimitOrTopNForcibly(lt) { - allTaskTypes = append(allTaskTypes, property.RootTaskType) - } - if lt.SCtx().GetSessionVars().IsMPPAllowed() { - allTaskTypes = append(allTaskTypes, property.MppTaskType) - } - ret := make([]base.PhysicalPlan, 0, len(allTaskTypes)) - for _, tp := range 
allTaskTypes { - resultProp := &property.PhysicalProperty{TaskTp: tp, ExpectedCnt: math.MaxFloat64, CTEProducerStatus: prop.CTEProducerStatus} - topN := PhysicalTopN{ - ByItems: lt.ByItems, - PartitionBy: lt.PartitionBy, - Count: lt.Count, - Offset: lt.Offset, - }.Init(lt.SCtx(), lt.StatsInfo(), lt.QueryBlockOffset(), resultProp) - ret = append(ret, topN) - } - return ret -} - -func getPhysLimits(lt *logicalop.LogicalTopN, prop *property.PhysicalProperty) []base.PhysicalPlan { - p, canPass := GetPropByOrderByItems(lt.ByItems) - if !canPass { - return nil - } - - allTaskTypes := []property.TaskType{property.CopSingleReadTaskType, property.CopMultiReadTaskType} - if !pushLimitOrTopNForcibly(lt) { - allTaskTypes = append(allTaskTypes, property.RootTaskType) - } - ret := make([]base.PhysicalPlan, 0, len(allTaskTypes)) - for _, tp := range allTaskTypes { - resultProp := &property.PhysicalProperty{TaskTp: tp, ExpectedCnt: float64(lt.Count + lt.Offset), SortItems: p.SortItems, CTEProducerStatus: prop.CTEProducerStatus} - limit := PhysicalLimit{ - Count: lt.Count, - Offset: lt.Offset, - PartitionBy: lt.GetPartitionBy(), - }.Init(lt.SCtx(), lt.StatsInfo(), lt.QueryBlockOffset(), resultProp) - limit.SetSchema(lt.Schema()) - ret = append(ret, limit) - } - return ret -} - -// MatchItems checks if this prop's columns can match by items totally. -func MatchItems(p *property.PhysicalProperty, items []*util.ByItems) bool { - if len(items) < len(p.SortItems) { - return false - } - for i, col := range p.SortItems { - sortItem := items[i] - if sortItem.Desc != col.Desc || !col.Col.EqualColumn(sortItem.Expr) { - return false - } - } - return true -} - -// GetHashJoin is public for cascades planner. -func GetHashJoin(la *LogicalApply, prop *property.PhysicalProperty) *PhysicalHashJoin { - return getHashJoin(&la.LogicalJoin, prop, 1, false) -} - -// ExhaustPhysicalPlans4LogicalApply generates the physical plan for a logical apply. 
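[Editor's note] Before the Apply branch below, a compact model of the push-down gate implemented by pushLimitOrTopNForcibly above; the function name and the threshold value here are illustrative. A LIMIT_TO_COP-style hint forces the attempt, otherwise Count+Offset must stay within the session's limit push-down threshold, and in both cases the child must actually be readable from TiKV.

package main

import "fmt"

// shouldPushTopN is a simplified stand-in for pushLimitOrTopNForcibly, restricted
// to the TopN case (a bare Limit is always considered safe to push).
func shouldPushTopN(count, offset, threshold uint64, hinted, canPushToTiKV bool) bool {
	meetThreshold := count+offset <= threshold
	return (hinted || meetThreshold) && canPushToTiKV
}

func main() {
	fmt.Println(shouldPushTopN(100, 0, 100_000, false, true))       // true: small enough
	fmt.Println(shouldPushTopN(1_000_000, 0, 100_000, false, true)) // false: above the threshold, no hint
	fmt.Println(shouldPushTopN(1_000_000, 0, 100_000, true, true))  // true: the hint forces the attempt
}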
-func ExhaustPhysicalPlans4LogicalApply(la *LogicalApply, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { - if !prop.AllColsFromSchema(la.Children()[0].Schema()) || prop.IsFlashProp() { // for convenient, we don't pass through any prop - la.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced( - "MPP mode may be blocked because operator `Apply` is not supported now.") - return nil, true, nil - } - if !prop.IsSortItemEmpty() && la.SCtx().GetSessionVars().EnableParallelApply { - la.SCtx().GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("Parallel Apply rejects the possible order properties of its outer child currently")) - return nil, true, nil - } - disableAggPushDownToCop(la.Children()[0]) - join := GetHashJoin(la, prop) - var columns = make([]*expression.Column, 0, len(la.CorCols)) - for _, colColumn := range la.CorCols { - columns = append(columns, &colColumn.Column) - } - cacheHitRatio := 0.0 - if la.StatsInfo().RowCount != 0 { - ndv, _ := cardinality.EstimateColsNDVWithMatchedLen(columns, la.Schema(), la.StatsInfo()) - // for example, if there are 100 rows and the number of distinct values of these correlated columns - // are 70, then we can assume 30 rows can hit the cache so the cache hit ratio is 1 - (70/100) = 0.3 - cacheHitRatio = 1 - (ndv / la.StatsInfo().RowCount) - } - - var canUseCache bool - if cacheHitRatio > 0.1 && la.SCtx().GetSessionVars().MemQuotaApplyCache > 0 { - canUseCache = true - } else { - canUseCache = false - } - - apply := PhysicalApply{ - PhysicalHashJoin: *join, - OuterSchema: la.CorCols, - CanUseCache: canUseCache, - }.Init(la.SCtx(), - la.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), - la.QueryBlockOffset(), - &property.PhysicalProperty{ExpectedCnt: math.MaxFloat64, SortItems: prop.SortItems, CTEProducerStatus: prop.CTEProducerStatus}, - &property.PhysicalProperty{ExpectedCnt: math.MaxFloat64, CTEProducerStatus: prop.CTEProducerStatus}) - apply.SetSchema(la.Schema()) - return []base.PhysicalPlan{apply}, true, nil -} - -func disableAggPushDownToCop(p base.LogicalPlan) { - for _, child := range p.Children() { - disableAggPushDownToCop(child) - } - if agg, ok := p.(*LogicalAggregation); ok { - agg.NoCopPushDown = true - } -} - -func tryToGetMppWindows(lw *logicalop.LogicalWindow, prop *property.PhysicalProperty) []base.PhysicalPlan { - if !prop.IsSortItemAllForPartition() { - return nil - } - if prop.TaskTp != property.RootTaskType && prop.TaskTp != property.MppTaskType { - return nil - } - if prop.MPPPartitionTp == property.BroadcastType { - return nil - } - - { - allSupported := true - sctx := lw.SCtx() - for _, windowFunc := range lw.WindowFuncDescs { - if !windowFunc.CanPushDownToTiFlash(GetPushDownCtx(sctx)) { - lw.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced( - "MPP mode may be blocked because window function `" + windowFunc.Name + "` or its arguments are not supported now.") - allSupported = false - } else if !expression.IsPushDownEnabled(windowFunc.Name, kv.TiFlash) { - lw.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because window function `" + windowFunc.Name + "` is blocked by blacklist, check `table mysql.expr_pushdown_blacklist;` for more information.") - return nil - } - } - if !allSupported { - return nil - } - - if lw.Frame != nil && lw.Frame.Type == ast.Ranges { - ctx := lw.SCtx().GetExprCtx() - if _, err := expression.ExpressionsToPBList(ctx.GetEvalCtx(), lw.Frame.Start.CalcFuncs, lw.SCtx().GetClient()); err != nil { - 
lw.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced( - "MPP mode may be blocked because window function frame can't be pushed down, because " + err.Error()) - return nil - } - if _, err := expression.ExpressionsToPBList(ctx.GetEvalCtx(), lw.Frame.End.CalcFuncs, lw.SCtx().GetClient()); err != nil { - lw.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced( - "MPP mode may be blocked because window function frame can't be pushed down, because " + err.Error()) - return nil - } - - if !lw.CheckComparisonForTiFlash(lw.Frame.Start) || !lw.CheckComparisonForTiFlash(lw.Frame.End) { - lw.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced( - "MPP mode may be blocked because window function frame can't be pushed down, because Duration vs Datetime is invalid comparison as TiFlash can't handle it so far.") - return nil - } - } - } - - var byItems []property.SortItem - byItems = append(byItems, lw.PartitionBy...) - byItems = append(byItems, lw.OrderBy...) - childProperty := &property.PhysicalProperty{ - ExpectedCnt: math.MaxFloat64, - CanAddEnforcer: true, - SortItems: byItems, - TaskTp: property.MppTaskType, - SortItemsForPartition: byItems, - CTEProducerStatus: prop.CTEProducerStatus, - } - if !prop.IsPrefix(childProperty) { - return nil - } - - if len(lw.PartitionBy) > 0 { - partitionCols := lw.GetPartitionKeys() - // trying to match the required partitions. - if prop.MPPPartitionTp == property.HashType { - matches := prop.IsSubsetOf(partitionCols) - if len(matches) == 0 { - // do not satisfy the property of its parent, so return empty - return nil - } - partitionCols = choosePartitionKeys(partitionCols, matches) - } - childProperty.MPPPartitionTp = property.HashType - childProperty.MPPPartitionCols = partitionCols - } else { - childProperty.MPPPartitionTp = property.SinglePartitionType - } - - if prop.MPPPartitionTp == property.SinglePartitionType && childProperty.MPPPartitionTp != property.SinglePartitionType { - return nil - } - - window := PhysicalWindow{ - WindowFuncDescs: lw.WindowFuncDescs, - PartitionBy: lw.PartitionBy, - OrderBy: lw.OrderBy, - Frame: lw.Frame, - storeTp: kv.TiFlash, - }.Init(lw.SCtx(), lw.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), lw.QueryBlockOffset(), childProperty) - window.SetSchema(lw.Schema()) - - return []base.PhysicalPlan{window} -} - -func exhaustPhysicalPlans4LogicalWindow(lp base.LogicalPlan, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { - lw := lp.(*logicalop.LogicalWindow) - windows := make([]base.PhysicalPlan, 0, 2) - - canPushToTiFlash := lw.CanPushToCop(kv.TiFlash) - if lw.SCtx().GetSessionVars().IsMPPAllowed() && canPushToTiFlash { - mppWindows := tryToGetMppWindows(lw, prop) - windows = append(windows, mppWindows...) - } - - // if there needs a mpp task, we don't generate tidb window function. - if prop.TaskTp == property.MppTaskType { - return windows, true, nil - } - var byItems []property.SortItem - byItems = append(byItems, lw.PartitionBy...) - byItems = append(byItems, lw.OrderBy...) 
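[Editor's note] The two append calls just above are easy to misread, so here is the same construction in isolation with a small string-based sort item type (illustrative only): the child of a window operator must be sorted by the PARTITION BY columns first and the ORDER BY columns second, and in the MPP variant the PARTITION BY columns additionally become the hash-partition keys.

package main

import "fmt"

type sortItem struct {
	col  string
	desc bool
}

func main() {
	partitionBy := []sortItem{{"dept", false}}
	orderBy := []sortItem{{"salary", true}}

	// Required sort order for the window's child: PARTITION BY keys, then ORDER BY keys.
	var byItems []sortItem
	byItems = append(byItems, partitionBy...)
	byItems = append(byItems, orderBy...)
	fmt.Println(byItems) // [{dept false} {salary true}]

	// In the MPP case the partition columns also drive the data distribution.
	hashKeys := make([]string, 0, len(partitionBy))
	for _, it := range partitionBy {
		hashKeys = append(hashKeys, it.col)
	}
	fmt.Println(hashKeys) // [dept]
}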
- childProperty := &property.PhysicalProperty{ExpectedCnt: math.MaxFloat64, SortItems: byItems, CanAddEnforcer: true, CTEProducerStatus: prop.CTEProducerStatus} - if !prop.IsPrefix(childProperty) { - return nil, true, nil - } - window := PhysicalWindow{ - WindowFuncDescs: lw.WindowFuncDescs, - PartitionBy: lw.PartitionBy, - OrderBy: lw.OrderBy, - Frame: lw.Frame, - }.Init(lw.SCtx(), lw.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), lw.QueryBlockOffset(), childProperty) - window.SetSchema(lw.Schema()) - - windows = append(windows, window) - return windows, true, nil -} - -func canPushToCopImpl(lp base.LogicalPlan, storeTp kv.StoreType, considerDual bool) bool { - p := lp.GetBaseLogicalPlan().(*logicalop.BaseLogicalPlan) - ret := true - for _, ch := range p.Children() { - switch c := ch.(type) { - case *DataSource: - validDs := false - indexMergeIsIntersection := false - for _, path := range c.PossibleAccessPaths { - if path.StoreType == storeTp { - validDs = true - } - if len(path.PartialIndexPaths) > 0 && path.IndexMergeIsIntersection { - indexMergeIsIntersection = true - } - } - ret = ret && validDs - - _, isTopN := p.Self().(*logicalop.LogicalTopN) - _, isLimit := p.Self().(*logicalop.LogicalLimit) - if (isTopN || isLimit) && indexMergeIsIntersection { - return false // TopN and Limit cannot be pushed down to the intersection type IndexMerge - } - - if c.TableInfo.TableCacheStatusType != model.TableCacheStatusDisable { - // Don't push to cop for cached table, it brings more harm than good: - // 1. Those tables are small enough, push to cop can't utilize several TiKV to accelerate computation. - // 2. Cached table use UnionScan to read the cache data, and push to cop is not supported when an UnionScan exists. - // Once aggregation is pushed to cop, the cache data can't be use any more. - return false - } - case *LogicalUnionAll: - if storeTp != kv.TiFlash { - return false - } - ret = ret && canPushToCopImpl(&c.BaseLogicalPlan, storeTp, true) - case *logicalop.LogicalSort: - if storeTp != kv.TiFlash { - return false - } - ret = ret && canPushToCopImpl(&c.BaseLogicalPlan, storeTp, true) - case *logicalop.LogicalProjection: - if storeTp != kv.TiFlash { - return false - } - ret = ret && canPushToCopImpl(&c.BaseLogicalPlan, storeTp, considerDual) - case *LogicalExpand: - // Expand itself only contains simple col ref and literal projection. (always ok, check its child) - if storeTp != kv.TiFlash { - return false - } - ret = ret && canPushToCopImpl(&c.BaseLogicalPlan, storeTp, considerDual) - case *logicalop.LogicalTableDual: - return storeTp == kv.TiFlash && considerDual - case *LogicalAggregation, *LogicalSelection, *LogicalJoin, *logicalop.LogicalWindow: - if storeTp != kv.TiFlash { - return false - } - ret = ret && c.CanPushToCop(storeTp) - // These operators can be partially push down to TiFlash, so we don't raise warning for them. 
- case *logicalop.LogicalLimit, *logicalop.LogicalTopN: - return false - case *logicalop.LogicalSequence: - return storeTp == kv.TiFlash - case *LogicalCTE: - if storeTp != kv.TiFlash { - return false - } - if c.Cte.recursivePartLogicalPlan != nil || !c.Cte.seedPartLogicalPlan.CanPushToCop(storeTp) { - return false - } - return true - default: - p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced( - "MPP mode may be blocked because operator `" + c.TP() + "` is not supported now.") - return false - } - } - return ret -} - -func getEnforcedStreamAggs(la *LogicalAggregation, prop *property.PhysicalProperty) []base.PhysicalPlan { - if prop.IsFlashProp() { - return nil - } - _, desc := prop.AllSameOrder() - allTaskTypes := prop.GetAllPossibleChildTaskTypes() - enforcedAggs := make([]base.PhysicalPlan, 0, len(allTaskTypes)) - childProp := &property.PhysicalProperty{ - ExpectedCnt: math.Max(prop.ExpectedCnt*la.InputCount/la.StatsInfo().RowCount, prop.ExpectedCnt), - CanAddEnforcer: true, - SortItems: property.SortItemsFromCols(la.GetGroupByCols(), desc), - } - if !prop.IsPrefix(childProp) { - return enforcedAggs - } - taskTypes := []property.TaskType{property.CopSingleReadTaskType, property.CopMultiReadTaskType} - if la.HasDistinct() { - // TODO: remove AllowDistinctAggPushDown after the cost estimation of distinct pushdown is implemented. - // If AllowDistinctAggPushDown is set to true, we should not consider RootTask. - if !la.CanPushToCop(kv.TiKV) || !la.SCtx().GetSessionVars().AllowDistinctAggPushDown { - taskTypes = []property.TaskType{property.RootTaskType} - } - } else if !la.PreferAggToCop { - taskTypes = append(taskTypes, property.RootTaskType) - } - for _, taskTp := range taskTypes { - copiedChildProperty := new(property.PhysicalProperty) - *copiedChildProperty = *childProp // It's ok to not deep copy the "cols" field. - copiedChildProperty.TaskTp = taskTp - - newGbyItems := make([]expression.Expression, len(la.GroupByItems)) - copy(newGbyItems, la.GroupByItems) - newAggFuncs := make([]*aggregation.AggFuncDesc, len(la.AggFuncs)) - copy(newAggFuncs, la.AggFuncs) - - agg := basePhysicalAgg{ - GroupByItems: newGbyItems, - AggFuncs: newAggFuncs, - }.initForStream(la.SCtx(), la.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), la.QueryBlockOffset(), copiedChildProperty) - agg.SetSchema(la.Schema().Clone()) - enforcedAggs = append(enforcedAggs, agg) - } - return enforcedAggs -} - -func (la *LogicalAggregation) distinctArgsMeetsProperty() bool { - for _, aggFunc := range la.AggFuncs { - if aggFunc.HasDistinct { - for _, distinctArg := range aggFunc.Args { - if !expression.Contains(la.SCtx().GetExprCtx().GetEvalCtx(), la.GroupByItems, distinctArg) { - return false - } - } - } - } - return true -} - -func getStreamAggs(lp base.LogicalPlan, prop *property.PhysicalProperty) []base.PhysicalPlan { - la := lp.(*LogicalAggregation) - // TODO: support CopTiFlash task type in stream agg - if prop.IsFlashProp() { - return nil - } - all, desc := prop.AllSameOrder() - if !all { - return nil - } - - for _, aggFunc := range la.AggFuncs { - if aggFunc.Mode == aggregation.FinalMode { - return nil - } - } - // group by a + b is not interested in any order. 
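[Editor's note] The distinctArgsMeetsProperty check above boils down to a containment test; a minimal sketch with string expressions (names invented): every argument of a DISTINCT aggregate must itself be one of the GROUP BY items, which is what lets the streaming plan handle DISTINCT without extra buffering.

package main

import (
	"fmt"
	"slices"
)

// distinctArgsMeetProperty is a simplified version of the check above: every
// DISTINCT argument must appear among the GROUP BY expressions.
func distinctArgsMeetProperty(groupByItems, distinctArgs []string) bool {
	for _, arg := range distinctArgs {
		if !slices.Contains(groupByItems, arg) {
			return false
		}
	}
	return true
}

func main() {
	// SELECT COUNT(DISTINCT a) ... GROUP BY a, b: acceptable for a stream plan.
	fmt.Println(distinctArgsMeetProperty([]string{"a", "b"}, []string{"a"})) // true
	// SELECT COUNT(DISTINCT c) ... GROUP BY a: rejected.
	fmt.Println(distinctArgsMeetProperty([]string{"a"}, []string{"c"})) // false
}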
- groupByCols := la.GetGroupByCols() - if len(groupByCols) != len(la.GroupByItems) { - return nil - } - - allTaskTypes := prop.GetAllPossibleChildTaskTypes() - streamAggs := make([]base.PhysicalPlan, 0, len(la.PossibleProperties)*(len(allTaskTypes)-1)+len(allTaskTypes)) - childProp := &property.PhysicalProperty{ - ExpectedCnt: math.Max(prop.ExpectedCnt*la.InputCount/la.StatsInfo().RowCount, prop.ExpectedCnt), - } - - for _, possibleChildProperty := range la.PossibleProperties { - childProp.SortItems = property.SortItemsFromCols(possibleChildProperty[:len(groupByCols)], desc) - if !prop.IsPrefix(childProp) { - continue - } - // The table read of "CopDoubleReadTaskType" can't promises the sort - // property that the stream aggregation required, no need to consider. - taskTypes := []property.TaskType{property.CopSingleReadTaskType} - if la.HasDistinct() { - // TODO: remove AllowDistinctAggPushDown after the cost estimation of distinct pushdown is implemented. - // If AllowDistinctAggPushDown is set to true, we should not consider RootTask. - if !la.SCtx().GetSessionVars().AllowDistinctAggPushDown || !la.CanPushToCop(kv.TiKV) { - // if variable doesn't allow DistinctAggPushDown, just produce root task type. - // if variable does allow DistinctAggPushDown, but OP itself can't be pushed down to tikv, just produce root task type. - taskTypes = []property.TaskType{property.RootTaskType} - } else if !la.DistinctArgsMeetsProperty() { - continue - } - } else if !la.PreferAggToCop { - taskTypes = append(taskTypes, property.RootTaskType) - } - if !la.CanPushToCop(kv.TiKV) && !la.CanPushToCop(kv.TiFlash) { - taskTypes = []property.TaskType{property.RootTaskType} - } - for _, taskTp := range taskTypes { - copiedChildProperty := new(property.PhysicalProperty) - *copiedChildProperty = *childProp // It's ok to not deep copy the "cols" field. - copiedChildProperty.TaskTp = taskTp - - newGbyItems := make([]expression.Expression, len(la.GroupByItems)) - copy(newGbyItems, la.GroupByItems) - newAggFuncs := make([]*aggregation.AggFuncDesc, len(la.AggFuncs)) - copy(newAggFuncs, la.AggFuncs) - - agg := basePhysicalAgg{ - GroupByItems: newGbyItems, - AggFuncs: newAggFuncs, - }.initForStream(la.SCtx(), la.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), la.QueryBlockOffset(), copiedChildProperty) - agg.SetSchema(la.Schema().Clone()) - streamAggs = append(streamAggs, agg) - } - } - // If STREAM_AGG hint is existed, it should consider enforce stream aggregation, - // because we can't trust possibleChildProperty completely. - if (la.PreferAggType & h.PreferStreamAgg) > 0 { - streamAggs = append(streamAggs, getEnforcedStreamAggs(la, prop)...) 
- } - return streamAggs -} - -// TODO: support more operators and distinct later -func checkCanPushDownToMPP(la *LogicalAggregation) bool { - hasUnsupportedDistinct := false - for _, agg := range la.AggFuncs { - // MPP does not support distinct except count distinct now - if agg.HasDistinct { - if agg.Name != ast.AggFuncCount && agg.Name != ast.AggFuncGroupConcat { - hasUnsupportedDistinct = true - } - } - // MPP does not support AggFuncApproxCountDistinct now - if agg.Name == ast.AggFuncApproxCountDistinct { - hasUnsupportedDistinct = true - } - } - if hasUnsupportedDistinct { - warnErr := errors.NewNoStackError("Aggregation can not be pushed to storage layer in mpp mode because it contains agg function with distinct") - if la.SCtx().GetSessionVars().StmtCtx.InExplainStmt { - la.SCtx().GetSessionVars().StmtCtx.AppendWarning(warnErr) - } else { - la.SCtx().GetSessionVars().StmtCtx.AppendExtraWarning(warnErr) - } - return false - } - return CheckAggCanPushCop(la.SCtx(), la.AggFuncs, la.GroupByItems, kv.TiFlash) -} - -func tryToGetMppHashAggs(la *LogicalAggregation, prop *property.PhysicalProperty) (hashAggs []base.PhysicalPlan) { - if !prop.IsSortItemEmpty() { - return nil - } - if prop.TaskTp != property.RootTaskType && prop.TaskTp != property.MppTaskType { - return nil - } - if prop.MPPPartitionTp == property.BroadcastType { - return nil - } - - // Is this aggregate a final stage aggregate? - // Final agg can't be split into multi-stage aggregate - hasFinalAgg := len(la.AggFuncs) > 0 && la.AggFuncs[0].Mode == aggregation.FinalMode - // count final agg should become sum for MPP execution path. - // In the traditional case, TiDB take up the final agg role and push partial agg to TiKV, - // while TiDB can tell the partialMode and do the sum computation rather than counting but MPP doesn't - finalAggAdjust := func(aggFuncs []*aggregation.AggFuncDesc) { - for i, agg := range aggFuncs { - if agg.Mode == aggregation.FinalMode && agg.Name == ast.AggFuncCount { - oldFT := agg.RetTp - aggFuncs[i], _ = aggregation.NewAggFuncDesc(la.SCtx().GetExprCtx(), ast.AggFuncSum, agg.Args, false) - aggFuncs[i].TypeInfer4FinalCount(oldFT) - } - } - } - // ref: https://github.com/pingcap/tiflash/blob/3ebb102fba17dce3d990d824a9df93d93f1ab - // 766/dbms/src/Flash/Coprocessor/AggregationInterpreterHelper.cpp#L26 - validMppAgg := func(mppAgg *PhysicalHashAgg) bool { - isFinalAgg := true - if mppAgg.AggFuncs[0].Mode != aggregation.FinalMode && mppAgg.AggFuncs[0].Mode != aggregation.CompleteMode { - isFinalAgg = false - } - for _, one := range mppAgg.AggFuncs[1:] { - otherIsFinalAgg := one.Mode == aggregation.FinalMode || one.Mode == aggregation.CompleteMode - if isFinalAgg != otherIsFinalAgg { - // different agg mode detected in mpp side. - return false - } - } - return true - } - - if len(la.GroupByItems) > 0 { - partitionCols := la.GetPotentialPartitionKeys() - // trying to match the required partitions. - if prop.MPPPartitionTp == property.HashType { - // partition key required by upper layer is subset of current layout. 
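[Editor's note] A minimal sketch of the finalAggAdjust step described above, with a tiny struct instead of aggregation.AggFuncDesc: when the partial COUNT has already run below the exchange, the final MPP stage must SUM the partial counts rather than count rows again (the return-type bookkeeping done by TypeInfer4FinalCount is omitted here).

package main

import "fmt"

type aggFunc struct {
	name string // "count", "sum", "max", ...
	mode string // "partial1", "final", "complete", ...
}

// finalAggAdjust rewrites a final-mode COUNT into a SUM over the partial counts.
func finalAggAdjust(funcs []aggFunc) {
	for i, f := range funcs {
		if f.mode == "final" && f.name == "count" {
			funcs[i].name = "sum"
		}
	}
}

func main() {
	funcs := []aggFunc{{"count", "final"}, {"max", "final"}}
	finalAggAdjust(funcs)
	fmt.Println(funcs) // [{sum final} {max final}]
}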
- matches := prop.IsSubsetOf(partitionCols) - if len(matches) == 0 { - // do not satisfy the property of its parent, so return empty - return nil - } - partitionCols = choosePartitionKeys(partitionCols, matches) - } else if prop.MPPPartitionTp != property.AnyType { - return nil - } - // TODO: permute various partition columns from group-by columns - // 1-phase agg - // If there are no available partition cols, but still have group by items, that means group by items are all expressions or constants. - // To avoid mess, we don't do any one-phase aggregation in this case. - // If this is a skew distinct group agg, skip generating 1-phase agg, because skew data will cause performance issue - // - // Rollup can't be 1-phase agg: cause it will append grouping_id to the schema, and expand each row as multi rows with different grouping_id. - // In a general, group items should also append grouping_id as its group layout, let's say 1-phase agg has grouping items as , and - // lower OP can supply as original partition layout, when we insert Expand logic between them: - // --> after fill null in Expand --> and this shown two rows should be shuffled to the same node (the underlying partition is not satisfied yet) - // <1,1> in node A <1,null,gid=1> in node A - // <1,2> in node B <1,null,gid=1> in node B - if len(partitionCols) != 0 && !la.SCtx().GetSessionVars().EnableSkewDistinctAgg { - childProp := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.HashType, MPPPartitionCols: partitionCols, CanAddEnforcer: true, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus} - agg := NewPhysicalHashAgg(la, la.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), childProp) - agg.SetSchema(la.Schema().Clone()) - agg.MppRunMode = Mpp1Phase - finalAggAdjust(agg.AggFuncs) - if validMppAgg(agg) { - hashAggs = append(hashAggs, agg) - } - } - - // Final agg can't be split into multi-stage aggregate, so exit early - if hasFinalAgg { - return - } - - // 2-phase agg - // no partition property down,record partition cols inside agg itself, enforce shuffler latter. 
- childProp := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.AnyType, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus} - agg := NewPhysicalHashAgg(la, la.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), childProp) - agg.SetSchema(la.Schema().Clone()) - agg.MppRunMode = Mpp2Phase - agg.MppPartitionCols = partitionCols - if validMppAgg(agg) { - hashAggs = append(hashAggs, agg) - } - - // agg runs on TiDB with a partial agg on TiFlash if possible - if prop.TaskTp == property.RootTaskType { - childProp := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus} - agg := NewPhysicalHashAgg(la, la.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), childProp) - agg.SetSchema(la.Schema().Clone()) - agg.MppRunMode = MppTiDB - hashAggs = append(hashAggs, agg) - } - } else if !hasFinalAgg { - // TODO: support scalar agg in MPP, merge the final result to one node - childProp := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus} - agg := NewPhysicalHashAgg(la, la.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), childProp) - agg.SetSchema(la.Schema().Clone()) - if la.HasDistinct() || la.HasOrderBy() { - // mpp scalar mode means the data will be pass through to only one tiFlash node at last. - agg.MppRunMode = MppScalar - } else { - agg.MppRunMode = MppTiDB - } - hashAggs = append(hashAggs, agg) - } - - // handle MPP Agg hints - var preferMode AggMppRunMode - var prefer bool - if la.PreferAggType&h.PreferMPP1PhaseAgg > 0 { - preferMode, prefer = Mpp1Phase, true - } else if la.PreferAggType&h.PreferMPP2PhaseAgg > 0 { - preferMode, prefer = Mpp2Phase, true - } - if prefer { - var preferPlans []base.PhysicalPlan - for _, agg := range hashAggs { - if hg, ok := agg.(*PhysicalHashAgg); ok && hg.MppRunMode == preferMode { - preferPlans = append(preferPlans, hg) - } - } - hashAggs = preferPlans - } - return -} - -// getHashAggs will generate some kinds of taskType here, which finally converted to different task plan. -// when deciding whether to add a kind of taskType, there is a rule here. [Not is Not, Yes is not Sure] -// eg: which means -// -// 1: when you find something here that block hashAgg to be pushed down to XXX, just skip adding the XXXTaskType. -// 2: when you find nothing here to block hashAgg to be pushed down to XXX, just add the XXXTaskType here. -// for 2, the final result for this physical operator enumeration is chosen or rejected is according to more factors later (hint/variable/partition/virtual-col/cost) -// -// That is to say, the non-complete positive judgement of canPushDownToMPP/canPushDownToTiFlash/canPushDownToTiKV is not that for sure here. 
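The "Not is Not, Yes is not Sure" rule in the comment above means the enumeration only drops task types that are provably impossible and leaves every surviving type to be priced later. A minimal sketch of that pattern follows; it is not part of the patch and all names in it are hypothetical.

// Illustrative sketch only: keep every task type that is not provably
// impossible; the final choice is made later by hints, variables and cost.
package main

import "fmt"

type taskType int

const (
	rootTask taskType = iota
	copTask
	mppTask
)

// enumerateTaskTypes removes only the certainly-invalid task types; it never
// tries to prove that a surviving type will actually win.
func enumerateTaskTypes(canPushToTiKV, canPushToMPP bool) []taskType {
	types := []taskType{rootTask}
	if canPushToTiKV {
		types = append(types, copTask)
	}
	if canPushToMPP {
		types = append(types, mppTask)
	}
	return types
}

func main() {
	fmt.Println(enumerateTaskTypes(true, false)) // [0 1]
}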
-func getHashAggs(lp base.LogicalPlan, prop *property.PhysicalProperty) []base.PhysicalPlan { - la := lp.(*LogicalAggregation) - if !prop.IsSortItemEmpty() { - return nil - } - if prop.TaskTp == property.MppTaskType && !checkCanPushDownToMPP(la) { - return nil - } - hashAggs := make([]base.PhysicalPlan, 0, len(prop.GetAllPossibleChildTaskTypes())) - taskTypes := []property.TaskType{property.CopSingleReadTaskType, property.CopMultiReadTaskType} - canPushDownToTiFlash := la.CanPushToCop(kv.TiFlash) - canPushDownToMPP := canPushDownToTiFlash && la.SCtx().GetSessionVars().IsMPPAllowed() && checkCanPushDownToMPP(la) - if la.HasDistinct() { - // TODO: remove after the cost estimation of distinct pushdown is implemented. - if !la.SCtx().GetSessionVars().AllowDistinctAggPushDown || !la.CanPushToCop(kv.TiKV) { - // if variable doesn't allow DistinctAggPushDown, just produce root task type. - // if variable does allow DistinctAggPushDown, but OP itself can't be pushed down to tikv, just produce root task type. - taskTypes = []property.TaskType{property.RootTaskType} - } - } else if !la.PreferAggToCop { - taskTypes = append(taskTypes, property.RootTaskType) - } - if !la.CanPushToCop(kv.TiKV) && !canPushDownToTiFlash { - taskTypes = []property.TaskType{property.RootTaskType} - } - if canPushDownToMPP { - taskTypes = append(taskTypes, property.MppTaskType) - } else { - hasMppHints := false - var errMsg string - if la.PreferAggType&h.PreferMPP1PhaseAgg > 0 { - errMsg = "The agg can not push down to the MPP side, the MPP_1PHASE_AGG() hint is invalid" - hasMppHints = true - } - if la.PreferAggType&h.PreferMPP2PhaseAgg > 0 { - errMsg = "The agg can not push down to the MPP side, the MPP_2PHASE_AGG() hint is invalid" - hasMppHints = true - } - if hasMppHints { - la.SCtx().GetSessionVars().StmtCtx.SetHintWarning(errMsg) - } - } - if prop.IsFlashProp() { - taskTypes = []property.TaskType{prop.TaskTp} - } - - for _, taskTp := range taskTypes { - if taskTp == property.MppTaskType { - mppAggs := tryToGetMppHashAggs(la, prop) - if len(mppAggs) > 0 { - hashAggs = append(hashAggs, mppAggs...) 
- } - } else { - agg := NewPhysicalHashAgg(la, la.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), &property.PhysicalProperty{ExpectedCnt: math.MaxFloat64, TaskTp: taskTp, CTEProducerStatus: prop.CTEProducerStatus}) - agg.SetSchema(la.Schema().Clone()) - hashAggs = append(hashAggs, agg) - } - } - return hashAggs -} - -func exhaustPhysicalPlans4LogicalSelection(p *LogicalSelection, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { - newProps := make([]*property.PhysicalProperty, 0, 2) - childProp := prop.CloneEssentialFields() - newProps = append(newProps, childProp) - - if prop.TaskTp != property.MppTaskType && - p.SCtx().GetSessionVars().IsMPPAllowed() && - p.canPushDown(kv.TiFlash) { - childPropMpp := prop.CloneEssentialFields() - childPropMpp.TaskTp = property.MppTaskType - newProps = append(newProps, childPropMpp) - } - - ret := make([]base.PhysicalPlan, 0, len(newProps)) - for _, newProp := range newProps { - sel := PhysicalSelection{ - Conditions: p.Conditions, - }.Init(p.SCtx(), p.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), p.QueryBlockOffset(), newProp) - ret = append(ret, sel) - } - return ret, true, nil -} - -func exhaustPhysicalPlans4LogicalLimit(lp base.LogicalPlan, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { - p := lp.(*logicalop.LogicalLimit) - return getLimitPhysicalPlans(p, prop) -} - -func getLimitPhysicalPlans(p *logicalop.LogicalLimit, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { - if !prop.IsSortItemEmpty() { - return nil, true, nil - } - - allTaskTypes := []property.TaskType{property.CopSingleReadTaskType, property.CopMultiReadTaskType} - if !pushLimitOrTopNForcibly(p) { - allTaskTypes = append(allTaskTypes, property.RootTaskType) - } - if p.CanPushToCop(kv.TiFlash) && p.SCtx().GetSessionVars().IsMPPAllowed() { - allTaskTypes = append(allTaskTypes, property.MppTaskType) - } - ret := make([]base.PhysicalPlan, 0, len(allTaskTypes)) - for _, tp := range allTaskTypes { - resultProp := &property.PhysicalProperty{TaskTp: tp, ExpectedCnt: float64(p.Count + p.Offset), CTEProducerStatus: prop.CTEProducerStatus} - limit := PhysicalLimit{ - Offset: p.Offset, - Count: p.Count, - PartitionBy: p.GetPartitionBy(), - }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset(), resultProp) - limit.SetSchema(p.Schema()) - ret = append(ret, limit) - } - return ret, true, nil -} - -func exhaustPhysicalPlans4LogicalLock(lp base.LogicalPlan, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { - p := lp.(*logicalop.LogicalLock) - if prop.IsFlashProp() { - p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced( - "MPP mode may be blocked because operator `Lock` is not supported now.") - return nil, true, nil - } - childProp := prop.CloneEssentialFields() - lock := PhysicalLock{ - Lock: p.Lock, - TblID2Handle: p.TblID2Handle, - TblID2PhysTblIDCol: p.TblID2PhysTblIDCol, - }.Init(p.SCtx(), p.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), childProp) - return []base.PhysicalPlan{lock}, true, nil -} - -func exhaustUnionAllPhysicalPlans(p *LogicalUnionAll, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { - // TODO: UnionAll can not pass any order, but we can change it to sort merge to keep order. - if !prop.IsSortItemEmpty() || (prop.IsFlashProp() && prop.TaskTp != property.MppTaskType) { - return nil, true, nil - } - // TODO: UnionAll can pass partition info, but for briefness, we prevent it from pushing down. 
- if prop.TaskTp == property.MppTaskType && prop.MPPPartitionTp != property.AnyType { - return nil, true, nil - } - canUseMpp := p.SCtx().GetSessionVars().IsMPPAllowed() && canPushToCopImpl(&p.BaseLogicalPlan, kv.TiFlash, true) - chReqProps := make([]*property.PhysicalProperty, 0, p.ChildLen()) - for range p.Children() { - if canUseMpp && prop.TaskTp == property.MppTaskType { - chReqProps = append(chReqProps, &property.PhysicalProperty{ - ExpectedCnt: prop.ExpectedCnt, - TaskTp: property.MppTaskType, - RejectSort: true, - CTEProducerStatus: prop.CTEProducerStatus, - }) - } else { - chReqProps = append(chReqProps, &property.PhysicalProperty{ExpectedCnt: prop.ExpectedCnt, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus}) - } - } - ua := PhysicalUnionAll{ - mpp: canUseMpp && prop.TaskTp == property.MppTaskType, - }.Init(p.SCtx(), p.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), p.QueryBlockOffset(), chReqProps...) - ua.SetSchema(p.Schema()) - if canUseMpp && prop.TaskTp == property.RootTaskType { - chReqProps = make([]*property.PhysicalProperty, 0, p.ChildLen()) - for range p.Children() { - chReqProps = append(chReqProps, &property.PhysicalProperty{ - ExpectedCnt: prop.ExpectedCnt, - TaskTp: property.MppTaskType, - RejectSort: true, - CTEProducerStatus: prop.CTEProducerStatus, - }) - } - mppUA := PhysicalUnionAll{mpp: true}.Init(p.SCtx(), p.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), p.QueryBlockOffset(), chReqProps...) - mppUA.SetSchema(p.Schema()) - return []base.PhysicalPlan{ua, mppUA}, true, nil - } - return []base.PhysicalPlan{ua}, true, nil -} - -func exhaustPartitionUnionAllPhysicalPlans(p *LogicalPartitionUnionAll, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { - uas, flagHint, err := p.LogicalUnionAll.ExhaustPhysicalPlans(prop) - if err != nil { - return nil, false, err - } - for _, ua := range uas { - ua.(*PhysicalUnionAll).SetTP(plancodec.TypePartitionUnion) - } - return uas, flagHint, nil -} - -func exhaustPhysicalPlans4LogicalTopN(lp base.LogicalPlan, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { - lt := lp.(*logicalop.LogicalTopN) - if MatchItems(prop, lt.ByItems) { - return append(getPhysTopN(lt, prop), getPhysLimits(lt, prop)...), true, nil - } - return nil, true, nil -} - -func exhaustPhysicalPlans4LogicalSort(lp base.LogicalPlan, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { - ls := lp.(*logicalop.LogicalSort) - if prop.TaskTp == property.RootTaskType { - if MatchItems(prop, ls.ByItems) { - ret := make([]base.PhysicalPlan, 0, 2) - ret = append(ret, getPhysicalSort(ls, prop)) - ns := getNominalSort(ls, prop) - if ns != nil { - ret = append(ret, ns) - } - return ret, true, nil - } - } else if prop.TaskTp == property.MppTaskType && prop.RejectSort { - if canPushToCopImpl(&ls.BaseLogicalPlan, kv.TiFlash, true) { - ps := getNominalSortSimple(ls, prop) - return []base.PhysicalPlan{ps}, true, nil - } - } - return nil, true, nil -} - -func getPhysicalSort(ls *logicalop.LogicalSort, prop *property.PhysicalProperty) base.PhysicalPlan { - ps := PhysicalSort{ByItems: ls.ByItems}.Init(ls.SCtx(), ls.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), ls.QueryBlockOffset(), &property.PhysicalProperty{TaskTp: prop.TaskTp, ExpectedCnt: math.MaxFloat64, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus}) - return ps -} - -func getNominalSort(ls *logicalop.LogicalSort, reqProp *property.PhysicalProperty) *NominalSort { - prop, canPass, onlyColumn := GetPropByOrderByItemsContainScalarFunc(ls.ByItems) - 
if !canPass { - return nil - } - prop.RejectSort = true - prop.ExpectedCnt = reqProp.ExpectedCnt - ps := NominalSort{OnlyColumn: onlyColumn, ByItems: ls.ByItems}.Init( - ls.SCtx(), ls.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), ls.QueryBlockOffset(), prop) - return ps -} - -func getNominalSortSimple(ls *logicalop.LogicalSort, reqProp *property.PhysicalProperty) *NominalSort { - newProp := reqProp.CloneEssentialFields() - newProp.RejectSort = true - ps := NominalSort{OnlyColumn: true, ByItems: ls.ByItems}.Init( - ls.SCtx(), ls.StatsInfo().ScaleByExpectCnt(reqProp.ExpectedCnt), ls.QueryBlockOffset(), newProp) - return ps -} - -func exhaustPhysicalPlans4LogicalMaxOneRow(lp base.LogicalPlan, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { - p := lp.(*logicalop.LogicalMaxOneRow) - if !prop.IsSortItemEmpty() || prop.IsFlashProp() { - p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because operator `MaxOneRow` is not supported now.") - return nil, true, nil - } - mor := PhysicalMaxOneRow{}.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset(), &property.PhysicalProperty{ExpectedCnt: 2, CTEProducerStatus: prop.CTEProducerStatus}) - return []base.PhysicalPlan{mor}, true, nil -} - -func exhaustPhysicalPlans4LogicalCTE(p *LogicalCTE, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { - pcte := PhysicalCTE{CTE: p.Cte}.Init(p.SCtx(), p.StatsInfo()) - if prop.IsFlashProp() { - pcte.storageSender = PhysicalExchangeSender{ - ExchangeType: tipb.ExchangeType_Broadcast, - }.Init(p.SCtx(), p.StatsInfo()) - } - pcte.SetSchema(p.Schema()) - pcte.childrenReqProps = []*property.PhysicalProperty{prop.CloneEssentialFields()} - return []base.PhysicalPlan{(*PhysicalCTEStorage)(pcte)}, true, nil -} - -func exhaustPhysicalPlans4LogicalSequence(lp base.LogicalPlan, prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { - p := lp.(*logicalop.LogicalSequence) - possibleChildrenProps := make([][]*property.PhysicalProperty, 0, 2) - anyType := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.AnyType, CanAddEnforcer: true, RejectSort: true, CTEProducerStatus: prop.CTEProducerStatus} - if prop.TaskTp == property.MppTaskType { - if prop.CTEProducerStatus == property.SomeCTEFailedMpp { - return nil, true, nil - } - anyType.CTEProducerStatus = property.AllCTECanMpp - possibleChildrenProps = append(possibleChildrenProps, []*property.PhysicalProperty{anyType, prop.CloneEssentialFields()}) - } else { - copied := prop.CloneEssentialFields() - copied.CTEProducerStatus = property.SomeCTEFailedMpp - possibleChildrenProps = append(possibleChildrenProps, []*property.PhysicalProperty{{TaskTp: property.RootTaskType, ExpectedCnt: math.MaxFloat64, CTEProducerStatus: property.SomeCTEFailedMpp}, copied}) - } - - if prop.TaskTp != property.MppTaskType && prop.CTEProducerStatus != property.SomeCTEFailedMpp && - p.SCtx().GetSessionVars().IsMPPAllowed() && prop.IsSortItemEmpty() { - possibleChildrenProps = append(possibleChildrenProps, []*property.PhysicalProperty{anyType, anyType.CloneEssentialFields()}) - } - seqs := make([]base.PhysicalPlan, 0, 2) - for _, propChoice := range possibleChildrenProps { - childReqs := make([]*property.PhysicalProperty, 0, p.ChildLen()) - for i := 0; i < p.ChildLen()-1; i++ { - childReqs = append(childReqs, propChoice[0].CloneEssentialFields()) - } - childReqs = append(childReqs, propChoice[1]) - seq := PhysicalSequence{}.Init(p.SCtx(), p.StatsInfo(), 
p.QueryBlockOffset(), childReqs...) - seq.SetSchema(p.Children()[p.ChildLen()-1].Schema()) - seqs = append(seqs, seq) - } - return seqs, true, nil -} diff --git a/pkg/planner/core/find_best_task.go b/pkg/planner/core/find_best_task.go index 340e5f5a77c6c..8e819dee9c6e8 100644 --- a/pkg/planner/core/find_best_task.go +++ b/pkg/planner/core/find_best_task.go @@ -1558,13 +1558,13 @@ func (ds *DataSource) convertToIndexMergeScan(prop *property.PhysicalProperty, c if !prop.IsSortItemEmpty() && candidate.path.IndexMergeIsIntersection { return base.InvalidTask, nil } - if _, _err_ := failpoint.Eval(_curpkg_("forceIndexMergeKeepOrder")); _err_ == nil { + failpoint.Inject("forceIndexMergeKeepOrder", func(_ failpoint.Value) { if len(candidate.path.PartialIndexPaths) > 0 && !candidate.path.IndexMergeIsIntersection { if prop.IsSortItemEmpty() { - return base.InvalidTask, nil + failpoint.Return(base.InvalidTask, nil) } } - } + }) path := candidate.path scans := make([]base.PhysicalPlan, 0, len(path.PartialIndexPaths)) cop := &CopTask{ diff --git a/pkg/planner/core/find_best_task.go__failpoint_stash__ b/pkg/planner/core/find_best_task.go__failpoint_stash__ deleted file mode 100644 index 8e819dee9c6e8..0000000000000 --- a/pkg/planner/core/find_best_task.go__failpoint_stash__ +++ /dev/null @@ -1,2982 +0,0 @@ -// Copyright 2017 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package core - -import ( - "cmp" - "fmt" - "math" - "slices" - "strings" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/planner/cardinality" - "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/planner/core/cost" - "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" - "github.com/pingcap/tidb/pkg/planner/property" - "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/planner/util/fixcontrol" - "github.com/pingcap/tidb/pkg/planner/util/optimizetrace" - "github.com/pingcap/tidb/pkg/planner/util/utilfuncp" - "github.com/pingcap/tidb/pkg/statistics" - "github.com/pingcap/tidb/pkg/types" - tidbutil "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/collate" - h "github.com/pingcap/tidb/pkg/util/hint" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/ranger" - "github.com/pingcap/tidb/pkg/util/tracing" - "go.uber.org/zap" -) - -// PlanCounterDisabled is the default value of PlanCounterTp, indicating that optimizer needn't force a plan. -var PlanCounterDisabled base.PlanCounterTp = -1 - -// GetPropByOrderByItems will check if this sort property can be pushed or not. In order to simplify the problem, we only -// consider the case that all expression are columns. 
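The forceIndexMergeKeepOrder hunk above restores the failpoint.Inject marker form that failpoint-ctl expands into the failpoint.Eval call shown on the removed lines. A minimal sketch of that marker form follows; it is not part of the patch, and the failpoint name "forceSlowPath" is hypothetical.

// Illustrative sketch only: a failpoint in its checked-in source form.
package main

import (
	"fmt"

	"github.com/pingcap/failpoint"
)

func pickPath() string {
	// failpoint.Inject is a no-op marker in the plain library build;
	// failpoint-ctl rewrites it into the failpoint.Eval form, and only a
	// rewritten binary runs this closure when the failpoint is enabled.
	failpoint.Inject("forceSlowPath", func(_ failpoint.Value) {
		fmt.Println("forceSlowPath is active")
	})
	return "fast"
}

func main() {
	fmt.Println(pickPath())
}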
-func GetPropByOrderByItems(items []*util.ByItems) (*property.PhysicalProperty, bool) { - propItems := make([]property.SortItem, 0, len(items)) - for _, item := range items { - col, ok := item.Expr.(*expression.Column) - if !ok { - return nil, false - } - propItems = append(propItems, property.SortItem{Col: col, Desc: item.Desc}) - } - return &property.PhysicalProperty{SortItems: propItems}, true -} - -// GetPropByOrderByItemsContainScalarFunc will check if this sort property can be pushed or not. In order to simplify the -// problem, we only consider the case that all expression are columns or some special scalar functions. -func GetPropByOrderByItemsContainScalarFunc(items []*util.ByItems) (*property.PhysicalProperty, bool, bool) { - propItems := make([]property.SortItem, 0, len(items)) - onlyColumn := true - for _, item := range items { - switch expr := item.Expr.(type) { - case *expression.Column: - propItems = append(propItems, property.SortItem{Col: expr, Desc: item.Desc}) - case *expression.ScalarFunction: - col, desc := expr.GetSingleColumn(item.Desc) - if col == nil { - return nil, false, false - } - propItems = append(propItems, property.SortItem{Col: col, Desc: desc}) - onlyColumn = false - default: - return nil, false, false - } - } - return &property.PhysicalProperty{SortItems: propItems}, true, onlyColumn -} - -func findBestTask4LogicalTableDual(lp base.LogicalPlan, prop *property.PhysicalProperty, planCounter *base.PlanCounterTp, opt *optimizetrace.PhysicalOptimizeOp) (base.Task, int64, error) { - p := lp.(*logicalop.LogicalTableDual) - // If the required property is not empty and the row count > 1, - // we cannot ensure this required property. - // But if the row count is 0 or 1, we don't need to care about the property. - if (!prop.IsSortItemEmpty() && p.RowCount > 1) || planCounter.Empty() { - return base.InvalidTask, 0, nil - } - dual := PhysicalTableDual{ - RowCount: p.RowCount, - }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset()) - dual.SetSchema(p.Schema()) - planCounter.Dec(1) - utilfuncp.AppendCandidate4PhysicalOptimizeOp(opt, p, dual, prop) - rt := &RootTask{} - rt.SetPlan(dual) - rt.SetEmpty(p.RowCount == 0) - return rt, 1, nil -} - -func findBestTask4LogicalShow(lp base.LogicalPlan, prop *property.PhysicalProperty, planCounter *base.PlanCounterTp, _ *optimizetrace.PhysicalOptimizeOp) (base.Task, int64, error) { - p := lp.(*logicalop.LogicalShow) - if !prop.IsSortItemEmpty() || planCounter.Empty() { - return base.InvalidTask, 0, nil - } - pShow := PhysicalShow{ShowContents: p.ShowContents, Extractor: p.Extractor}.Init(p.SCtx()) - pShow.SetSchema(p.Schema()) - planCounter.Dec(1) - rt := &RootTask{} - rt.SetPlan(pShow) - return rt, 1, nil -} - -func findBestTask4LogicalShowDDLJobs(lp base.LogicalPlan, prop *property.PhysicalProperty, planCounter *base.PlanCounterTp, _ *optimizetrace.PhysicalOptimizeOp) (base.Task, int64, error) { - p := lp.(*logicalop.LogicalShowDDLJobs) - if !prop.IsSortItemEmpty() || planCounter.Empty() { - return base.InvalidTask, 0, nil - } - pShow := PhysicalShowDDLJobs{JobNumber: p.JobNumber}.Init(p.SCtx()) - pShow.SetSchema(p.Schema()) - planCounter.Dec(1) - rt := &RootTask{} - rt.SetPlan(pShow) - return rt, 1, nil -} - -// rebuildChildTasks rebuilds the childTasks to make the clock_th combination. 
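The "clock_th combination" mentioned above refers to decoding a single 1-based plan counter into one plan choice per child, mixed-radix style, which is what the division and modulo in rebuildChildTasks below compute. A minimal sketch, not part of the patch and with hypothetical names, of that decoding:

// Illustrative sketch only: decode a 1-based counter into per-child choices.
package main

import "fmt"

// decodeCounter maps counter (1-based) to one 1-based choice per child,
// where childCnts[i] is the number of plans child i can produce.
func decodeCounter(counter int64, childCnts []int64) []int64 {
	multAll := int64(1)
	for _, c := range childCnts {
		multAll *= c
	}
	choices := make([]int64, len(childCnts))
	for i, c := range childCnts {
		multAll /= c
		choices[i] = (counter-1)/multAll + 1
		counter = (counter-1)%multAll + 1
	}
	return choices
}

func main() {
	// With children offering 2 and 3 plans, counter 5 selects plan 2 of the
	// first child and plan 2 of the second child.
	fmt.Println(decodeCounter(5, []int64{2, 3})) // [2 2]
}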
-func rebuildChildTasks(p *logicalop.BaseLogicalPlan, childTasks *[]base.Task, pp base.PhysicalPlan, childCnts []int64, planCounter int64, ts uint64, opt *optimizetrace.PhysicalOptimizeOp) error { - // The taskMap of children nodes should be rolled back first. - for _, child := range p.Children() { - child.RollBackTaskMap(ts) - } - - multAll := int64(1) - var curClock base.PlanCounterTp - for _, x := range childCnts { - multAll *= x - } - *childTasks = (*childTasks)[:0] - for j, child := range p.Children() { - multAll /= childCnts[j] - curClock = base.PlanCounterTp((planCounter-1)/multAll + 1) - childTask, _, err := child.FindBestTask(pp.GetChildReqProps(j), &curClock, opt) - planCounter = (planCounter-1)%multAll + 1 - if err != nil { - return err - } - if curClock != 0 { - return errors.Errorf("PlanCounterTp planCounter is not handled") - } - if childTask != nil && childTask.Invalid() { - return errors.Errorf("The current plan is invalid, please skip this plan") - } - *childTasks = append(*childTasks, childTask) - } - return nil -} - -func enumeratePhysicalPlans4Task( - p *logicalop.BaseLogicalPlan, - physicalPlans []base.PhysicalPlan, - prop *property.PhysicalProperty, - addEnforcer bool, - planCounter *base.PlanCounterTp, - opt *optimizetrace.PhysicalOptimizeOp, -) (base.Task, int64, error) { - var bestTask base.Task = base.InvalidTask - var curCntPlan, cntPlan int64 - var err error - childTasks := make([]base.Task, 0, p.ChildLen()) - childCnts := make([]int64, p.ChildLen()) - cntPlan = 0 - iteration := iteratePhysicalPlan4BaseLogical - if _, ok := p.Self().(*logicalop.LogicalSequence); ok { - iteration = iterateChildPlan4LogicalSequence - } - - for _, pp := range physicalPlans { - timeStampNow := p.GetLogicalTS4TaskMap() - savedPlanID := p.SCtx().GetSessionVars().PlanID.Load() - - childTasks, curCntPlan, childCnts, err = iteration(p, pp, childTasks, childCnts, prop, opt) - if err != nil { - return nil, 0, err - } - - // This check makes sure that there is no invalid child task. - if len(childTasks) != p.ChildLen() { - continue - } - - // If the target plan can be found in this physicalPlan(pp), rebuild childTasks to build the corresponding combination. - if planCounter.IsForce() && int64(*planCounter) <= curCntPlan { - p.SCtx().GetSessionVars().PlanID.Store(savedPlanID) - curCntPlan = int64(*planCounter) - err := rebuildChildTasks(p, &childTasks, pp, childCnts, int64(*planCounter), timeStampNow, opt) - if err != nil { - return nil, 0, err - } - } - - // Combine the best child tasks with parent physical plan. - curTask := pp.Attach2Task(childTasks...) - if curTask.Invalid() { - continue - } - - // An optimal task could not satisfy the property, so it should be converted here. - if _, ok := curTask.(*RootTask); !ok && prop.TaskTp == property.RootTaskType { - curTask = curTask.ConvertToRootTask(p.SCtx()) - } - - // Enforce curTask property - if addEnforcer { - curTask = enforceProperty(prop, curTask, p.Plan.SCtx()) - } - - // Optimize by shuffle executor to running in parallel manner. - if _, isMpp := curTask.(*MppTask); !isMpp && prop.IsSortItemEmpty() { - // Currently, we do not regard shuffled plan as a new plan. - curTask = optimizeByShuffle(curTask, p.Plan.SCtx()) - } - - cntPlan += curCntPlan - planCounter.Dec(curCntPlan) - - if planCounter.Empty() { - bestTask = curTask - break - } - utilfuncp.AppendCandidate4PhysicalOptimizeOp(opt, p, curTask.Plan(), prop) - // Get the most efficient one. 
- if curIsBetter, err := compareTaskCost(curTask, bestTask, opt); err != nil { - return nil, 0, err - } else if curIsBetter { - bestTask = curTask - } - } - return bestTask, cntPlan, nil -} - -// iteratePhysicalPlan4BaseLogical is used to iterate the physical plan and get all child tasks. -func iteratePhysicalPlan4BaseLogical( - p *logicalop.BaseLogicalPlan, - selfPhysicalPlan base.PhysicalPlan, - childTasks []base.Task, - childCnts []int64, - _ *property.PhysicalProperty, - opt *optimizetrace.PhysicalOptimizeOp, -) ([]base.Task, int64, []int64, error) { - // Find best child tasks firstly. - childTasks = childTasks[:0] - // The curCntPlan records the number of possible plans for pp - curCntPlan := int64(1) - for j, child := range p.Children() { - childProp := selfPhysicalPlan.GetChildReqProps(j) - childTask, cnt, err := child.FindBestTask(childProp, &PlanCounterDisabled, opt) - childCnts[j] = cnt - if err != nil { - return nil, 0, childCnts, err - } - curCntPlan = curCntPlan * cnt - if childTask != nil && childTask.Invalid() { - return nil, 0, childCnts, nil - } - childTasks = append(childTasks, childTask) - } - - // This check makes sure that there is no invalid child task. - if len(childTasks) != p.ChildLen() { - return nil, 0, childCnts, nil - } - return childTasks, curCntPlan, childCnts, nil -} - -// iterateChildPlan4LogicalSequence does the special part for sequence. We need to iterate its child one by one to check whether the former child is a valid plan and then go to the nex -func iterateChildPlan4LogicalSequence( - p *logicalop.BaseLogicalPlan, - selfPhysicalPlan base.PhysicalPlan, - childTasks []base.Task, - childCnts []int64, - prop *property.PhysicalProperty, - opt *optimizetrace.PhysicalOptimizeOp, -) ([]base.Task, int64, []int64, error) { - // Find best child tasks firstly. - childTasks = childTasks[:0] - // The curCntPlan records the number of possible plans for pp - curCntPlan := int64(1) - lastIdx := p.ChildLen() - 1 - for j := 0; j < lastIdx; j++ { - child := p.Children()[j] - childProp := selfPhysicalPlan.GetChildReqProps(j) - childTask, cnt, err := child.FindBestTask(childProp, &PlanCounterDisabled, opt) - childCnts[j] = cnt - if err != nil { - return nil, 0, nil, err - } - curCntPlan = curCntPlan * cnt - if childTask != nil && childTask.Invalid() { - return nil, 0, nil, nil - } - _, isMpp := childTask.(*MppTask) - if !isMpp && prop.IsFlashProp() { - break - } - childTasks = append(childTasks, childTask) - } - // This check makes sure that there is no invalid child task. - if len(childTasks) != p.ChildLen()-1 { - return nil, 0, nil, nil - } - - lastChildProp := selfPhysicalPlan.GetChildReqProps(lastIdx).CloneEssentialFields() - if lastChildProp.IsFlashProp() { - lastChildProp.CTEProducerStatus = property.AllCTECanMpp - } - lastChildTask, cnt, err := p.Children()[lastIdx].FindBestTask(lastChildProp, &PlanCounterDisabled, opt) - childCnts[lastIdx] = cnt - if err != nil { - return nil, 0, nil, err - } - curCntPlan = curCntPlan * cnt - if lastChildTask != nil && lastChildTask.Invalid() { - return nil, 0, nil, nil - } - - if _, ok := lastChildTask.(*MppTask); !ok && lastChildProp.CTEProducerStatus == property.AllCTECanMpp { - return nil, 0, nil, nil - } - - childTasks = append(childTasks, lastChildTask) - return childTasks, curCntPlan, childCnts, nil -} - -// compareTaskCost compares cost of curTask and bestTask and returns whether curTask's cost is smaller than bestTask's. 
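The comparison described above has to cope with candidates whose cost cannot be computed: an invalid current task never wins and an invalid best task always loses. A minimal sketch of that ordering, mirroring the invalid-task handling visible in the function body below (names hypothetical, not part of the patch):

// Illustrative sketch only: ordering rule for possibly-uncostable candidates.
package main

import "fmt"

// curIsBetter returns whether the current candidate should replace the best
// one: an invalid current task never wins, an invalid best task always loses,
// otherwise the cheaper plan wins.
func curIsBetter(curCost, bestCost float64, curInvalid, bestInvalid bool) bool {
	if curInvalid {
		return false
	}
	if bestInvalid {
		return true
	}
	return curCost < bestCost
}

func main() {
	fmt.Println(curIsBetter(10, 12, false, false)) // true
	fmt.Println(curIsBetter(1, 100, true, false))  // false: invalid loses
}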
-func compareTaskCost(curTask, bestTask base.Task, op *optimizetrace.PhysicalOptimizeOp) (curIsBetter bool, err error) { - curCost, curInvalid, err := utilfuncp.GetTaskPlanCost(curTask, op) - if err != nil { - return false, err - } - bestCost, bestInvalid, err := utilfuncp.GetTaskPlanCost(bestTask, op) - if err != nil { - return false, err - } - if curInvalid { - return false, nil - } - if bestInvalid { - return true, nil - } - return curCost < bestCost, nil -} - -// getTaskPlanCost returns the cost of this task. -// The new cost interface will be used if EnableNewCostInterface is true. -// The second returned value indicates whether this task is valid. -func getTaskPlanCost(t base.Task, pop *optimizetrace.PhysicalOptimizeOp) (float64, bool, error) { - if t.Invalid() { - return math.MaxFloat64, true, nil - } - - // use the new cost interface - var ( - taskType property.TaskType - indexPartialCost float64 - ) - switch t.(type) { - case *RootTask: - taskType = property.RootTaskType - case *CopTask: // no need to know whether the task is single-read or double-read, so both CopSingleReadTaskType and CopDoubleReadTaskType are OK - cop := t.(*CopTask) - if cop.indexPlan != nil && cop.tablePlan != nil { // handle IndexLookup specially - taskType = property.CopMultiReadTaskType - // keep compatible with the old cost interface, for CopMultiReadTask, the cost is idxCost + tblCost. - if !cop.indexPlanFinished { // only consider index cost in this case - idxCost, err := getPlanCost(cop.indexPlan, taskType, optimizetrace.NewDefaultPlanCostOption().WithOptimizeTracer(pop)) - return idxCost, false, err - } - // consider both sides - idxCost, err := getPlanCost(cop.indexPlan, taskType, optimizetrace.NewDefaultPlanCostOption().WithOptimizeTracer(pop)) - if err != nil { - return 0, false, err - } - tblCost, err := getPlanCost(cop.tablePlan, taskType, optimizetrace.NewDefaultPlanCostOption().WithOptimizeTracer(pop)) - if err != nil { - return 0, false, err - } - return idxCost + tblCost, false, nil - } - - taskType = property.CopSingleReadTaskType - - // TiFlash can run cop task as well, check whether this cop task will run on TiKV or TiFlash. - if cop.tablePlan != nil { - leafNode := cop.tablePlan - for len(leafNode.Children()) > 0 { - leafNode = leafNode.Children()[0] - } - if tblScan, isScan := leafNode.(*PhysicalTableScan); isScan && tblScan.StoreType == kv.TiFlash { - taskType = property.MppTaskType - } - } - - // Detail reason ref about comment in function `convertToIndexMergeScan` - // for cop task with {indexPlan=nil, tablePlan=xxx, idxMergePartPlans=[x,x,x], indexPlanFinished=true} we should - // plus the partial index plan cost into the final cost. Because t.plan() the below code used only calculate the - // cost about table plan. - if cop.indexPlanFinished && len(cop.idxMergePartPlans) != 0 { - for _, partialScan := range cop.idxMergePartPlans { - partialCost, err := getPlanCost(partialScan, taskType, optimizetrace.NewDefaultPlanCostOption().WithOptimizeTracer(pop)) - if err != nil { - return 0, false, err - } - indexPartialCost += partialCost - } - } - case *MppTask: - taskType = property.MppTaskType - default: - return 0, false, errors.New("unknown task type") - } - if t.Plan() == nil { - // It's a very special case for index merge case. - // t.plan() == nil in index merge COP case, it means indexPlanFinished is false in other words. 
- cost := 0.0 - copTsk := t.(*CopTask) - for _, partialScan := range copTsk.idxMergePartPlans { - partialCost, err := getPlanCost(partialScan, taskType, optimizetrace.NewDefaultPlanCostOption().WithOptimizeTracer(pop)) - if err != nil { - return 0, false, err - } - cost += partialCost - } - return cost, false, nil - } - cost, err := getPlanCost(t.Plan(), taskType, optimizetrace.NewDefaultPlanCostOption().WithOptimizeTracer(pop)) - return cost + indexPartialCost, false, err -} - -func appendCandidate4PhysicalOptimizeOp(pop *optimizetrace.PhysicalOptimizeOp, lp base.LogicalPlan, pp base.PhysicalPlan, prop *property.PhysicalProperty) { - if pop == nil || pop.GetTracer() == nil || pp == nil { - return - } - candidate := &tracing.CandidatePlanTrace{ - PlanTrace: &tracing.PlanTrace{TP: pp.TP(), ID: pp.ID(), - ExplainInfo: pp.ExplainInfo(), ProperType: prop.String()}, - MappingLogicalPlan: tracing.CodecPlanName(lp.TP(), lp.ID())} - pop.GetTracer().AppendCandidate(candidate) - - // for PhysicalIndexMergeJoin/PhysicalIndexHashJoin/PhysicalIndexJoin, it will use innerTask as a child instead of calling findBestTask, - // and innerTask.plan() will be appended to planTree in appendChildCandidate using empty MappingLogicalPlan field, so it won't mapping with the logic plan, - // that will cause no physical plan when the logic plan got selected. - // the fix to add innerTask.plan() to planTree and mapping correct logic plan - index := -1 - var plan base.PhysicalPlan - switch join := pp.(type) { - case *PhysicalIndexMergeJoin: - index = join.InnerChildIdx - plan = join.innerPlan - case *PhysicalIndexHashJoin: - index = join.InnerChildIdx - plan = join.innerPlan - case *PhysicalIndexJoin: - index = join.InnerChildIdx - plan = join.innerPlan - } - if index != -1 { - child := lp.(*logicalop.BaseLogicalPlan).Children()[index] - candidate := &tracing.CandidatePlanTrace{ - PlanTrace: &tracing.PlanTrace{TP: plan.TP(), ID: plan.ID(), - ExplainInfo: plan.ExplainInfo(), ProperType: prop.String()}, - MappingLogicalPlan: tracing.CodecPlanName(child.TP(), child.ID())} - pop.GetTracer().AppendCandidate(candidate) - } - pp.AppendChildCandidate(pop) -} - -func appendPlanCostDetail4PhysicalOptimizeOp(pop *optimizetrace.PhysicalOptimizeOp, detail *tracing.PhysicalPlanCostDetail) { - if pop == nil || pop.GetTracer() == nil { - return - } - pop.GetTracer().PhysicalPlanCostDetails[fmt.Sprintf("%v_%v", detail.GetPlanType(), detail.GetPlanID())] = detail -} - -// findBestTask is key workflow that drive logic plan tree to generate optimal physical ones. -// The logic inside it is mainly about physical plan numeration and task encapsulation, it should -// be defined in core pkg, and be called by logic plan in their logic interface implementation. -func findBestTask(lp base.LogicalPlan, prop *property.PhysicalProperty, planCounter *base.PlanCounterTp, - opt *optimizetrace.PhysicalOptimizeOp) (bestTask base.Task, cntPlan int64, err error) { - p := lp.GetBaseLogicalPlan().(*logicalop.BaseLogicalPlan) - // If p is an inner plan in an IndexJoin, the IndexJoin will generate an inner plan by itself, - // and set inner child prop nil, so here we do nothing. - if prop == nil { - return nil, 1, nil - } - // Look up the task with this prop in the task map. - // It's used to reduce double counting. 
- bestTask = p.GetTask(prop) - if bestTask != nil { - planCounter.Dec(1) - return bestTask, 1, nil - } - - canAddEnforcer := prop.CanAddEnforcer - - if prop.TaskTp != property.RootTaskType && !prop.IsFlashProp() { - // Currently all plan cannot totally push down to TiKV. - p.StoreTask(prop, base.InvalidTask) - return base.InvalidTask, 0, nil - } - - cntPlan = 0 - // prop should be read only because its cached hashcode might be not consistent - // when it is changed. So we clone a new one for the temporary changes. - newProp := prop.CloneEssentialFields() - var plansFitsProp, plansNeedEnforce []base.PhysicalPlan - var hintWorksWithProp bool - // Maybe the plan can satisfy the required property, - // so we try to get the task without the enforced sort first. - plansFitsProp, hintWorksWithProp, err = p.Self().ExhaustPhysicalPlans(newProp) - if err != nil { - return nil, 0, err - } - if !hintWorksWithProp && !newProp.IsSortItemEmpty() { - // If there is a hint in the plan and the hint cannot satisfy the property, - // we enforce this property and try to generate the PhysicalPlan again to - // make sure the hint can work. - canAddEnforcer = true - } - - if canAddEnforcer { - // Then, we use the empty property to get physicalPlans and - // try to get the task with an enforced sort. - newProp.SortItems = []property.SortItem{} - newProp.SortItemsForPartition = []property.SortItem{} - newProp.ExpectedCnt = math.MaxFloat64 - newProp.MPPPartitionCols = nil - newProp.MPPPartitionTp = property.AnyType - var hintCanWork bool - plansNeedEnforce, hintCanWork, err = p.Self().ExhaustPhysicalPlans(newProp) - if err != nil { - return nil, 0, err - } - if hintCanWork && !hintWorksWithProp { - // If the hint can work with the empty property, but cannot work with - // the required property, we give up `plansFitProp` to make sure the hint - // can work. - plansFitsProp = nil - } - if !hintCanWork && !hintWorksWithProp && !prop.CanAddEnforcer { - // If the original property is not enforced and hint cannot - // work anyway, we give up `plansNeedEnforce` for efficiency, - plansNeedEnforce = nil - } - newProp = prop - } - - var cnt int64 - var curTask base.Task - if bestTask, cnt, err = enumeratePhysicalPlans4Task(p, plansFitsProp, newProp, false, planCounter, opt); err != nil { - return nil, 0, err - } - cntPlan += cnt - if planCounter.Empty() { - goto END - } - - curTask, cnt, err = enumeratePhysicalPlans4Task(p, plansNeedEnforce, newProp, true, planCounter, opt) - if err != nil { - return nil, 0, err - } - cntPlan += cnt - if planCounter.Empty() { - bestTask = curTask - goto END - } - utilfuncp.AppendCandidate4PhysicalOptimizeOp(opt, p, curTask.Plan(), prop) - if curIsBetter, err := compareTaskCost(curTask, bestTask, opt); err != nil { - return nil, 0, err - } else if curIsBetter { - bestTask = curTask - } - -END: - p.StoreTask(prop, bestTask) - return bestTask, cntPlan, nil -} - -func findBestTask4LogicalMemTable(lp base.LogicalPlan, prop *property.PhysicalProperty, planCounter *base.PlanCounterTp, opt *optimizetrace.PhysicalOptimizeOp) (t base.Task, cntPlan int64, err error) { - p := lp.(*logicalop.LogicalMemTable) - if prop.MPPPartitionTp != property.AnyType { - return base.InvalidTask, 0, nil - } - - // If prop.CanAddEnforcer is true, the prop.SortItems need to be set nil for p.findBestTask. - // Before function return, reset it for enforcing task prop. 
- oldProp := prop.CloneEssentialFields() - if prop.CanAddEnforcer { - // First, get the bestTask without enforced prop - prop.CanAddEnforcer = false - cnt := int64(0) - t, cnt, err = p.FindBestTask(prop, planCounter, opt) - if err != nil { - return nil, 0, err - } - prop.CanAddEnforcer = true - if t != base.InvalidTask { - cntPlan = cnt - return - } - // Next, get the bestTask with enforced prop - prop.SortItems = []property.SortItem{} - } - defer func() { - if err != nil { - return - } - if prop.CanAddEnforcer { - *prop = *oldProp - t = enforceProperty(prop, t, p.Plan.SCtx()) - prop.CanAddEnforcer = true - } - }() - - if !prop.IsSortItemEmpty() || planCounter.Empty() { - return base.InvalidTask, 0, nil - } - memTable := PhysicalMemTable{ - DBName: p.DBName, - Table: p.TableInfo, - Columns: p.Columns, - Extractor: p.Extractor, - QueryTimeRange: p.QueryTimeRange, - }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset()) - memTable.SetSchema(p.Schema()) - planCounter.Dec(1) - utilfuncp.AppendCandidate4PhysicalOptimizeOp(opt, p, memTable, prop) - rt := &RootTask{} - rt.SetPlan(memTable) - return rt, 1, nil -} - -// tryToGetDualTask will check if the push down predicate has false constant. If so, it will return table dual. -func (ds *DataSource) tryToGetDualTask() (base.Task, error) { - for _, cond := range ds.PushedDownConds { - if con, ok := cond.(*expression.Constant); ok && con.DeferredExpr == nil && con.ParamMarker == nil { - result, _, err := expression.EvalBool(ds.SCtx().GetExprCtx().GetEvalCtx(), []expression.Expression{cond}, chunk.Row{}) - if err != nil { - return nil, err - } - if !result { - dual := PhysicalTableDual{}.Init(ds.SCtx(), ds.StatsInfo(), ds.QueryBlockOffset()) - dual.SetSchema(ds.Schema()) - rt := &RootTask{} - rt.SetPlan(dual) - return rt, nil - } - } - } - return nil, nil -} - -// candidatePath is used to maintain required info for skyline pruning. -type candidatePath struct { - path *util.AccessPath - accessCondsColMap util.Col2Len // accessCondsColMap maps Column.UniqueID to column length for the columns in AccessConds. - indexCondsColMap util.Col2Len // indexCondsColMap maps Column.UniqueID to column length for the columns in AccessConds and indexFilters. - isMatchProp bool -} - -func compareBool(l, r bool) int { - if l == r { - return 0 - } - if !l { - return -1 - } - return 1 -} - -func compareIndexBack(lhs, rhs *candidatePath) (int, bool) { - result := compareBool(lhs.path.IsSingleScan, rhs.path.IsSingleScan) - if result == 0 && !lhs.path.IsSingleScan { - // if both lhs and rhs need to access table after IndexScan, we utilize the set of columns that occurred in AccessConds and IndexFilters - // to compare how many table rows will be accessed. - return util.CompareCol2Len(lhs.indexCondsColMap, rhs.indexCondsColMap) - } - return result, true -} - -// compareCandidates is the core of skyline pruning, which is used to decide which candidate path is better. -// The return value is 1 if lhs is better, -1 if rhs is better, 0 if they are equivalent or not comparable. -func compareCandidates(sctx base.PlanContext, prop *property.PhysicalProperty, lhs, rhs *candidatePath) int { - // Due to #50125, full scan on MVIndex has been disabled, so MVIndex path might lead to 'can't find a proper plan' error at the end. - // Avoid MVIndex path to exclude all other paths and leading to 'can't find a proper plan' error, see #49438 for an example. - if isMVIndexPath(lhs.path) || isMVIndexPath(rhs.path) { - return 0 - } - - // This rule is empirical but not always correct. 
- // If x's range row count is significantly lower than y's, for example, 1000 times, we think x is better. - if lhs.path.CountAfterAccess > 100 && rhs.path.CountAfterAccess > 100 && // to prevent some extreme cases, e.g. 0.01 : 10 - len(lhs.path.PartialIndexPaths) == 0 && len(rhs.path.PartialIndexPaths) == 0 && // not IndexMerge since its row count estimation is not accurate enough - prop.ExpectedCnt == math.MaxFloat64 { // Limit may affect access row count - threshold := float64(fixcontrol.GetIntWithDefault(sctx.GetSessionVars().OptimizerFixControl, fixcontrol.Fix45132, 1000)) - if threshold > 0 { // set it to 0 to disable this rule - if lhs.path.CountAfterAccess/rhs.path.CountAfterAccess > threshold { - return -1 - } - if rhs.path.CountAfterAccess/lhs.path.CountAfterAccess > threshold { - return 1 - } - } - } - - // Below compares the two candidate paths on three dimensions: - // (1): the set of columns that occurred in the access condition, - // (2): does it require a double scan, - // (3): whether or not it matches the physical property. - // If `x` is not worse than `y` at all factors, - // and there exists one factor that `x` is better than `y`, then `x` is better than `y`. - accessResult, comparable1 := util.CompareCol2Len(lhs.accessCondsColMap, rhs.accessCondsColMap) - if !comparable1 { - return 0 - } - scanResult, comparable2 := compareIndexBack(lhs, rhs) - if !comparable2 { - return 0 - } - matchResult := compareBool(lhs.isMatchProp, rhs.isMatchProp) - sum := accessResult + scanResult + matchResult - if accessResult >= 0 && scanResult >= 0 && matchResult >= 0 && sum > 0 { - return 1 - } - if accessResult <= 0 && scanResult <= 0 && matchResult <= 0 && sum < 0 { - return -1 - } - return 0 -} - -func (ds *DataSource) isMatchProp(path *util.AccessPath, prop *property.PhysicalProperty) bool { - var isMatchProp bool - if path.IsIntHandlePath { - pkCol := ds.getPKIsHandleCol() - if len(prop.SortItems) == 1 && pkCol != nil { - isMatchProp = prop.SortItems[0].Col.EqualColumn(pkCol) - if path.StoreType == kv.TiFlash { - isMatchProp = isMatchProp && !prop.SortItems[0].Desc - } - } - return isMatchProp - } - all, _ := prop.AllSameOrder() - // When the prop is empty or `all` is false, `isMatchProp` is better to be `false` because - // it needs not to keep order for index scan. - - // Basically, if `prop.SortItems` is the prefix of `path.IdxCols`, then `isMatchProp` is true. However, we need to consider - // the situations when some columns of `path.IdxCols` are evaluated as constant. For example: - // ``` - // create table t(a int, b int, c int, d int, index idx_a_b_c(a, b, c), index idx_d_c_b_a(d, c, b, a)); - // select * from t where a = 1 order by b, c; - // select * from t where b = 1 order by a, c; - // select * from t where d = 1 and b = 2 order by c, a; - // select * from t where d = 1 and b = 2 order by c, b, a; - // ``` - // In the first two `SELECT` statements, `idx_a_b_c` matches the sort order. In the last two `SELECT` statements, `idx_d_c_b_a` - // matches the sort order. Hence, we use `path.ConstCols` to deal with the above situations. 
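The comment above describes prefix matching that may skip index columns already pinned to constants, as with idx_d_c_b_a and d = 1. A minimal sketch of that idea follows; it is not part of the patch and all names are hypothetical.

// Illustrative sketch only: prefix match over index columns, skipping
// columns known to be constant.
package main

import "fmt"

// matchesOrder reports whether sortCols is a prefix of idxCols once columns
// with isConst[i] == true are allowed to be skipped.
func matchesOrder(idxCols, sortCols []string, isConst []bool) bool {
	i := 0
	for _, s := range sortCols {
		found := false
		for ; i < len(idxCols); i++ {
			if idxCols[i] == s {
				found = true
				i++
				break
			}
			if !isConst[i] {
				break
			}
		}
		if !found {
			return false
		}
	}
	return true
}

func main() {
	idx := []string{"d", "c", "b", "a"}
	isConst := []bool{true, false, false, false} // only d = 1 in the predicate
	fmt.Println(matchesOrder(idx, []string{"c", "a"}, isConst)) // false
	fmt.Println(matchesOrder(idx, []string{"c", "b"}, isConst)) // true
}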
- if !prop.IsSortItemEmpty() && all && len(path.IdxCols) >= len(prop.SortItems) { - isMatchProp = true - i := 0 - for _, sortItem := range prop.SortItems { - found := false - for ; i < len(path.IdxCols); i++ { - if path.IdxColLens[i] == types.UnspecifiedLength && sortItem.Col.EqualColumn(path.IdxCols[i]) { - found = true - i++ - break - } - if path.ConstCols == nil || i >= len(path.ConstCols) || !path.ConstCols[i] { - break - } - } - if !found { - isMatchProp = false - break - } - } - } - return isMatchProp -} - -// matchPropForIndexMergeAlternatives will match the prop with inside PartialAlternativeIndexPaths, and choose -// 1 matched alternative to be a determined index merge partial path for each dimension in PartialAlternativeIndexPaths. -// finally, after we collected the all decided index merge partial paths, we will output a concrete index merge path -// with field PartialIndexPaths is fulfilled here. -// -// as we mentioned before, after deriveStats is done, the normal index OR path will be generated like below: -// -// `create table t (a int, b int, c int, key a(a), key b(b), key ac(a, c), key bc(b, c))` -// `explain format='verbose' select * from t where a=1 or b=1 order by c` -// -// like the case here: -// normal index merge OR path should be: -// for a=1, it has two partial alternative paths: [a, ac] -// for b=1, it has two partial alternative paths: [b, bc] -// and the index merge path: -// -// indexMergePath: { -// PartialIndexPaths: empty // 1D array here, currently is not decided yet. -// PartialAlternativeIndexPaths: [[a, ac], [b, bc]] // 2D array here, each for one DNF item choices. -// } -// -// let's say we have a prop requirement like sort by [c] here, we will choose the better one [ac] (because it can keep -// order) for the first batch [a, ac] from PartialAlternativeIndexPaths; and choose the better one [bc] (because it can -// keep order too) for the second batch [b, bc] from PartialAlternativeIndexPaths. Finally we output a concrete index -// merge path as -// -// indexMergePath: { -// PartialIndexPaths: [ac, bc] // just collected since they match the prop. -// ... -// } -// -// how about the prop is empty? that means the choice to be decided from [a, ac] and [b, bc] is quite random just according -// to their countAfterAccess. That's why we use a slices.SortFunc(matchIdxes, func(a, b int){}) inside there. After sort, -// the ASC order of matchIdxes of matched paths are ordered by their countAfterAccess, choosing the first one is straight forward. -// -// there is another case shown below, just the pick the first one after matchIdxes is ordered is not always right, as shown: -// special logic for alternative paths: -// -// index merge: -// matched paths-1: {pk, index1} -// matched paths-2: {pk} -// -// if we choose first one as we talked above, says pk here in the first matched paths, then path2 has no choice(avoiding all same -// index logic inside) but pk, this will result in all single index failure. so we need to sort the matchIdxes again according to -// their matched paths length, here mean: -// -// index merge: -// matched paths-1: {pk, index1} -// matched paths-2: {pk} -// -// and let matched paths-2 to be the first to make their determination --- choosing pk here, then next turn is matched paths-1 to -// make their choice, since pk is occupied, avoiding-all-same-index-logic inside will try to pick index1 here, so work can be done. 
-// -// at last, according to determinedIndexPartialPaths to rewrite their real countAfterAccess, this part is move from deriveStats to -// here. -func (ds *DataSource) matchPropForIndexMergeAlternatives(path *util.AccessPath, prop *property.PhysicalProperty) (*util.AccessPath, bool) { - // target: - // 1: index merge case, try to match the every alternative partial path to the order property as long as - // possible, and generate that property-matched index merge path out if any. - // 2: If the prop is empty (means no sort requirement), we will generate a random index partial combination - // path from all alternatives in case that no index merge path comes out. - - // Execution part doesn't support the merge operation for intersection case yet. - if path.IndexMergeIsIntersection { - return nil, false - } - - noSortItem := prop.IsSortItemEmpty() - allSame, _ := prop.AllSameOrder() - if !allSame { - return nil, false - } - // step1: match the property from all the index partial alternative paths. - determinedIndexPartialPaths := make([]*util.AccessPath, 0, len(path.PartialAlternativeIndexPaths)) - usedIndexMap := make(map[int64]struct{}, 1) - type idxWrapper struct { - // matchIdx is those match alternative paths from one alternative paths set. - // like we said above, for a=1, it has two partial alternative paths: [a, ac] - // if we met an empty property here, matchIdx from [a, ac] for a=1 will be both. = [0,1] - // if we met an sort[c] property here, matchIdx from [a, ac] for a=1 will be both. = [1] - matchIdx []int - // pathIdx actually is original position offset indicates where current matchIdx is - // computed from. eg: [[a, ac], [b, bc]] for sort[c] property: - // idxWrapper{[ac], 0}, 0 is the offset in first dimension of PartialAlternativeIndexPaths - // idxWrapper{[bc], 1}, 1 is the offset in first dimension of PartialAlternativeIndexPaths - pathIdx int - } - allMatchIdxes := make([]idxWrapper, 0, len(path.PartialAlternativeIndexPaths)) - // special logic for alternative paths: - // index merge: - // path1: {pk, index1} - // path2: {pk} - // if we choose pk in the first path, then path2 has no choice but pk, this will result in all single index failure. - // so we should collect all match prop paths down, stored as matchIdxes here. - for pathIdx, oneItemAlternatives := range path.PartialAlternativeIndexPaths { - matchIdxes := make([]int, 0, 1) - for i, oneIndexAlternativePath := range oneItemAlternatives { - // if there is some sort items and this path doesn't match this prop, continue. - if !noSortItem && !ds.isMatchProp(oneIndexAlternativePath, prop) { - continue - } - // two possibility here: - // 1. no sort items requirement. - // 2. matched with sorted items. - matchIdxes = append(matchIdxes, i) - } - if len(matchIdxes) == 0 { - // if all index alternative of one of the cnf item's couldn't match the sort property, - // the entire index merge union path can be ignored for this sort property, return false. - return nil, false - } - if len(matchIdxes) > 1 { - // if matchIdxes greater than 1, we should sort this match alternative path by its CountAfterAccess. 
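The selection strategy described in the comment above matchPropForIndexMergeAlternatives can be pictured in miniature: DNF items with fewer usable alternatives choose first, and each item prefers an index no earlier item has taken, so {pk} grabs pk and {pk, index1} can still fall back to index1. A minimal sketch, not part of the patch and with hypothetical names:

// Illustrative sketch only: greedy choice of distinct partial paths.
package main

import (
	"fmt"
	"sort"
)

func pickPartialPaths(alternatives [][]string) []string {
	idx := make([]int, len(alternatives))
	for i := range idx {
		idx[i] = i
	}
	// The most constrained DNF item decides first.
	sort.SliceStable(idx, func(a, b int) bool {
		return len(alternatives[idx[a]]) < len(alternatives[idx[b]])
	})
	used := map[string]bool{}
	picked := make([]string, len(alternatives))
	for _, i := range idx {
		choice := alternatives[i][0] // fall back to the first alternative
		for _, cand := range alternatives[i] {
			if !used[cand] {
				choice = cand
				break
			}
		}
		used[choice] = true
		picked[i] = choice
	}
	return picked
}

func main() {
	fmt.Println(pickPartialPaths([][]string{{"pk", "index1"}, {"pk"}}))
	// Output: [index1 pk]
}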
- tmpOneItemAlternatives := oneItemAlternatives - slices.SortStableFunc(matchIdxes, func(a, b int) int { - lhsCountAfter := tmpOneItemAlternatives[a].CountAfterAccess - if len(tmpOneItemAlternatives[a].IndexFilters) > 0 { - lhsCountAfter = tmpOneItemAlternatives[a].CountAfterIndex - } - rhsCountAfter := tmpOneItemAlternatives[b].CountAfterAccess - if len(tmpOneItemAlternatives[b].IndexFilters) > 0 { - rhsCountAfter = tmpOneItemAlternatives[b].CountAfterIndex - } - return cmp.Compare(lhsCountAfter, rhsCountAfter) - }) - } - allMatchIdxes = append(allMatchIdxes, idxWrapper{matchIdxes, pathIdx}) - } - // sort allMatchIdxes by its element length. - // index merge: index merge: - // path1: {pk, index1} ==> path2: {pk} - // path2: {pk} path1: {pk, index1} - // here for the fixed choice pk of path2, let it be the first one to choose, left choice of index1 to path1. - slices.SortStableFunc(allMatchIdxes, func(a, b idxWrapper) int { - lhsLen := len(a.matchIdx) - rhsLen := len(b.matchIdx) - return cmp.Compare(lhsLen, rhsLen) - }) - for _, matchIdxes := range allMatchIdxes { - // since matchIdxes are ordered by matchIdxes's length, - // we should use matchIdxes.pathIdx to locate where it comes from. - alternatives := path.PartialAlternativeIndexPaths[matchIdxes.pathIdx] - found := false - // pick a most suitable index partial alternative from all matched alternative paths according to asc CountAfterAccess, - // By this way, a distinguished one is better. - for _, oneIdx := range matchIdxes.matchIdx { - var indexID int64 - if alternatives[oneIdx].IsTablePath() { - indexID = -1 - } else { - indexID = alternatives[oneIdx].Index.ID - } - if _, ok := usedIndexMap[indexID]; !ok { - // try to avoid all index partial paths are all about a single index. - determinedIndexPartialPaths = append(determinedIndexPartialPaths, alternatives[oneIdx].Clone()) - usedIndexMap[indexID] = struct{}{} - found = true - break - } - } - if !found { - // just pick the same name index (just using the first one is ok), in case that there may be some other - // picked distinctive index path for other partial paths latter. - determinedIndexPartialPaths = append(determinedIndexPartialPaths, alternatives[matchIdxes.matchIdx[0]].Clone()) - // uedIndexMap[oneItemAlternatives[oneIdx].Index.ID] = struct{}{} must already be colored. - } - } - if len(usedIndexMap) == 1 { - // if all partial path are using a same index, meaningless and fail over. - return nil, false - } - // step2: gen a new **concrete** index merge path. - indexMergePath := &util.AccessPath{ - PartialIndexPaths: determinedIndexPartialPaths, - IndexMergeIsIntersection: false, - // inherit those determined can't pushed-down table filters. - TableFilters: path.TableFilters, - } - // path.ShouldBeKeptCurrentFilter record that whether there are some part of the cnf item couldn't be pushed down to tikv already. - shouldKeepCurrentFilter := path.KeepIndexMergeORSourceFilter - pushDownCtx := GetPushDownCtx(ds.SCtx()) - for _, path := range determinedIndexPartialPaths { - // If any partial path contains table filters, we need to keep the whole DNF filter in the Selection. - if len(path.TableFilters) > 0 { - if !expression.CanExprsPushDown(pushDownCtx, path.TableFilters, kv.TiKV) { - // if this table filters can't be pushed down, all of them should be kept in the table side, cleaning the lookup side here. - path.TableFilters = nil - } - shouldKeepCurrentFilter = true - } - // If any partial path's index filter cannot be pushed to TiKV, we should keep the whole DNF filter. 
- if len(path.IndexFilters) != 0 && !expression.CanExprsPushDown(pushDownCtx, path.IndexFilters, kv.TiKV) { - shouldKeepCurrentFilter = true - // Clear IndexFilter, the whole filter will be put in indexMergePath.TableFilters. - path.IndexFilters = nil - } - } - // Keep this filter as a part of table filters for safety if it has any parameter. - if expression.MaybeOverOptimized4PlanCache(ds.SCtx().GetExprCtx(), []expression.Expression{path.IndexMergeORSourceFilter}) { - shouldKeepCurrentFilter = true - } - if shouldKeepCurrentFilter { - // add the cnf expression back as table filer. - indexMergePath.TableFilters = append(indexMergePath.TableFilters, path.IndexMergeORSourceFilter) - } - - // step3: after the index merge path is determined, compute the countAfterAccess as usual. - accessConds := make([]expression.Expression, 0, len(determinedIndexPartialPaths)) - for _, p := range determinedIndexPartialPaths { - indexCondsForP := p.AccessConds[:] - indexCondsForP = append(indexCondsForP, p.IndexFilters...) - if len(indexCondsForP) > 0 { - accessConds = append(accessConds, expression.ComposeCNFCondition(ds.SCtx().GetExprCtx(), indexCondsForP...)) - } - } - accessDNF := expression.ComposeDNFCondition(ds.SCtx().GetExprCtx(), accessConds...) - sel, _, err := cardinality.Selectivity(ds.SCtx(), ds.TableStats.HistColl, []expression.Expression{accessDNF}, nil) - if err != nil { - logutil.BgLogger().Debug("something wrong happened, use the default selectivity", zap.Error(err)) - sel = cost.SelectionFactor - } - indexMergePath.CountAfterAccess = sel * ds.TableStats.RowCount - if noSortItem { - // since there is no sort property, index merge case is generated by random combination, each alternative with the lower/lowest - // countAfterAccess, here the returned matchProperty should be false. - return indexMergePath, false - } - return indexMergePath, true -} - -func (ds *DataSource) isMatchPropForIndexMerge(path *util.AccessPath, prop *property.PhysicalProperty) bool { - // Execution part doesn't support the merge operation for intersection case yet. - if path.IndexMergeIsIntersection { - return false - } - allSame, _ := prop.AllSameOrder() - if !allSame { - return false - } - for _, partialPath := range path.PartialIndexPaths { - if !ds.isMatchProp(partialPath, prop) { - return false - } - } - return true -} - -func (ds *DataSource) getTableCandidate(path *util.AccessPath, prop *property.PhysicalProperty) *candidatePath { - candidate := &candidatePath{path: path} - candidate.isMatchProp = ds.isMatchProp(path, prop) - candidate.accessCondsColMap = util.ExtractCol2Len(ds.SCtx().GetExprCtx().GetEvalCtx(), path.AccessConds, nil, nil) - return candidate -} - -func (ds *DataSource) getIndexCandidate(path *util.AccessPath, prop *property.PhysicalProperty) *candidatePath { - candidate := &candidatePath{path: path} - candidate.isMatchProp = ds.isMatchProp(path, prop) - candidate.accessCondsColMap = util.ExtractCol2Len(ds.SCtx().GetExprCtx().GetEvalCtx(), path.AccessConds, path.IdxCols, path.IdxColLens) - candidate.indexCondsColMap = util.ExtractCol2Len(ds.SCtx().GetExprCtx().GetEvalCtx(), append(path.AccessConds, path.IndexFilters...), path.FullIdxCols, path.FullIdxColLens) - return candidate -} - -func (ds *DataSource) convergeIndexMergeCandidate(path *util.AccessPath, prop *property.PhysicalProperty) *candidatePath { - // since the all index path alternative paths is collected and undetermined, and we should determine a possible and concrete path for this prop. 
- possiblePath, match := ds.matchPropForIndexMergeAlternatives(path, prop) - if possiblePath == nil { - return nil - } - candidate := &candidatePath{path: possiblePath, isMatchProp: match} - return candidate -} - -func (ds *DataSource) getIndexMergeCandidate(path *util.AccessPath, prop *property.PhysicalProperty) *candidatePath { - candidate := &candidatePath{path: path} - candidate.isMatchProp = ds.isMatchPropForIndexMerge(path, prop) - return candidate -} - -// skylinePruning prunes access paths according to different factors. An access path can be pruned only if -// there exists a path that is not worse than it at all factors and there is at least one better factor. -func (ds *DataSource) skylinePruning(prop *property.PhysicalProperty) []*candidatePath { - candidates := make([]*candidatePath, 0, 4) - for _, path := range ds.PossibleAccessPaths { - // We should check whether the possible access path is valid first. - if path.StoreType != kv.TiFlash && prop.IsFlashProp() { - continue - } - if len(path.PartialAlternativeIndexPaths) > 0 { - // OR normal index merge path, try to determine every index partial path for this property. - candidate := ds.convergeIndexMergeCandidate(path, prop) - if candidate != nil { - candidates = append(candidates, candidate) - } - continue - } - if path.PartialIndexPaths != nil { - candidates = append(candidates, ds.getIndexMergeCandidate(path, prop)) - continue - } - // if we already know the range of the scan is empty, just return a TableDual - if len(path.Ranges) == 0 { - return []*candidatePath{{path: path}} - } - var currentCandidate *candidatePath - if path.IsTablePath() { - currentCandidate = ds.getTableCandidate(path, prop) - } else { - if !(len(path.AccessConds) > 0 || !prop.IsSortItemEmpty() || path.Forced || path.IsSingleScan) { - continue - } - // We will use index to generate physical plan if any of the following conditions is satisfied: - // 1. This path's access cond is not nil. - // 2. We have a non-empty prop to match. - // 3. This index is forced to choose. - // 4. The needed columns are all covered by index columns(and handleCol). - currentCandidate = ds.getIndexCandidate(path, prop) - } - pruned := false - for i := len(candidates) - 1; i >= 0; i-- { - if candidates[i].path.StoreType == kv.TiFlash { - continue - } - result := compareCandidates(ds.SCtx(), prop, candidates[i], currentCandidate) - if result == 1 { - pruned = true - // We can break here because the current candidate cannot prune others anymore. - break - } else if result == -1 { - candidates = append(candidates[:i], candidates[i+1:]...) - } - } - if !pruned { - candidates = append(candidates, currentCandidate) - } - } - - if ds.SCtx().GetSessionVars().GetAllowPreferRangeScan() && len(candidates) > 1 { - // If a candidate path is TiFlash-path or forced-path, we just keep them. For other candidate paths, if there exists - // any range scan path, we remove full scan paths and keep range scan paths. 
- preferredPaths := make([]*candidatePath, 0, len(candidates)) - var hasRangeScanPath bool - for _, c := range candidates { - if c.path.Forced || c.path.StoreType == kv.TiFlash { - preferredPaths = append(preferredPaths, c) - continue - } - var unsignedIntHandle bool - if c.path.IsIntHandlePath && ds.TableInfo.PKIsHandle { - if pkColInfo := ds.TableInfo.GetPkColInfo(); pkColInfo != nil { - unsignedIntHandle = mysql.HasUnsignedFlag(pkColInfo.GetFlag()) - } - } - if !ranger.HasFullRange(c.path.Ranges, unsignedIntHandle) { - preferredPaths = append(preferredPaths, c) - hasRangeScanPath = true - } - } - if hasRangeScanPath { - return preferredPaths - } - } - - return candidates -} - -func (ds *DataSource) getPruningInfo(candidates []*candidatePath, prop *property.PhysicalProperty) string { - if len(candidates) == len(ds.PossibleAccessPaths) { - return "" - } - if len(candidates) == 1 && len(candidates[0].path.Ranges) == 0 { - // For TableDual, we don't need to output pruning info. - return "" - } - names := make([]string, 0, len(candidates)) - var tableName string - if ds.TableAsName.O == "" { - tableName = ds.TableInfo.Name.O - } else { - tableName = ds.TableAsName.O - } - getSimplePathName := func(path *util.AccessPath) string { - if path.IsTablePath() { - if path.StoreType == kv.TiFlash { - return tableName + "(tiflash)" - } - return tableName - } - return path.Index.Name.O - } - for _, cand := range candidates { - if cand.path.PartialIndexPaths != nil { - partialNames := make([]string, 0, len(cand.path.PartialIndexPaths)) - for _, partialPath := range cand.path.PartialIndexPaths { - partialNames = append(partialNames, getSimplePathName(partialPath)) - } - names = append(names, fmt.Sprintf("IndexMerge{%s}", strings.Join(partialNames, ","))) - } else { - names = append(names, getSimplePathName(cand.path)) - } - } - items := make([]string, 0, len(prop.SortItems)) - for _, item := range prop.SortItems { - items = append(items, item.String()) - } - return fmt.Sprintf("[%s] remain after pruning paths for %s given Prop{SortItems: [%s], TaskTp: %s}", - strings.Join(names, ","), tableName, strings.Join(items, " "), prop.TaskTp) -} - -func (ds *DataSource) isPointGetConvertableSchema() bool { - for _, col := range ds.Columns { - if col.Name.L == model.ExtraHandleName.L { - continue - } - - // Only handle tables that all columns are public. - if col.State != model.StatePublic { - return false - } - } - return true -} - -// exploreEnforcedPlan determines whether to explore enforced plans for this DataSource if it has already found an unenforced plan. -// See #46177 for more information. -func (ds *DataSource) exploreEnforcedPlan() bool { - // default value is false to keep it compatible with previous versions. - return fixcontrol.GetBoolWithDefault(ds.SCtx().GetSessionVars().GetOptimizerFixControlMap(), fixcontrol.Fix46177, false) -} - -func findBestTask4DS(ds *DataSource, prop *property.PhysicalProperty, planCounter *base.PlanCounterTp, opt *optimizetrace.PhysicalOptimizeOp) (t base.Task, cntPlan int64, err error) { - // If ds is an inner plan in an IndexJoin, the IndexJoin will generate an inner plan by itself, - // and set inner child prop nil, so here we do nothing. 
- if prop == nil { - planCounter.Dec(1) - return nil, 1, nil - } - if ds.IsForUpdateRead && ds.SCtx().GetSessionVars().TxnCtx.IsExplicit { - hasPointGetPath := false - for _, path := range ds.PossibleAccessPaths { - if ds.isPointGetPath(path) { - hasPointGetPath = true - break - } - } - tblName := ds.TableInfo.Name - ds.PossibleAccessPaths, err = filterPathByIsolationRead(ds.SCtx(), ds.PossibleAccessPaths, tblName, ds.DBName) - if err != nil { - return nil, 1, err - } - if hasPointGetPath { - newPaths := make([]*util.AccessPath, 0) - for _, path := range ds.PossibleAccessPaths { - // if the path is the point get range path with for update lock, we should forbid tiflash as it's store path (#39543) - if path.StoreType != kv.TiFlash { - newPaths = append(newPaths, path) - } - } - ds.PossibleAccessPaths = newPaths - } - } - t = ds.GetTask(prop) - if t != nil { - cntPlan = 1 - planCounter.Dec(1) - return - } - var cnt int64 - var unenforcedTask base.Task - // If prop.CanAddEnforcer is true, the prop.SortItems need to be set nil for ds.findBestTask. - // Before function return, reset it for enforcing task prop and storing map. - oldProp := prop.CloneEssentialFields() - if prop.CanAddEnforcer { - // First, get the bestTask without enforced prop - prop.CanAddEnforcer = false - unenforcedTask, cnt, err = ds.FindBestTask(prop, planCounter, opt) - if err != nil { - return nil, 0, err - } - if !unenforcedTask.Invalid() && !ds.exploreEnforcedPlan() { - ds.StoreTask(prop, unenforcedTask) - return unenforcedTask, cnt, nil - } - - // Then, explore the bestTask with enforced prop - prop.CanAddEnforcer = true - cntPlan += cnt - prop.SortItems = []property.SortItem{} - prop.MPPPartitionTp = property.AnyType - } else if prop.MPPPartitionTp != property.AnyType { - return base.InvalidTask, 0, nil - } - defer func() { - if err != nil { - return - } - if prop.CanAddEnforcer { - *prop = *oldProp - t = enforceProperty(prop, t, ds.Plan.SCtx()) - prop.CanAddEnforcer = true - } - - if unenforcedTask != nil && !unenforcedTask.Invalid() { - curIsBest, cerr := compareTaskCost(unenforcedTask, t, opt) - if cerr != nil { - err = cerr - return - } - if curIsBest { - t = unenforcedTask - } - } - - ds.StoreTask(prop, t) - err = validateTableSamplePlan(ds, t, err) - }() - - t, err = ds.tryToGetDualTask() - if err != nil || t != nil { - planCounter.Dec(1) - if t != nil { - appendCandidate(ds, t, prop, opt) - } - return t, 1, err - } - - t = base.InvalidTask - candidates := ds.skylinePruning(prop) - pruningInfo := ds.getPruningInfo(candidates, prop) - defer func() { - if err == nil && t != nil && !t.Invalid() && pruningInfo != "" { - warnErr := errors.NewNoStackError(pruningInfo) - if ds.SCtx().GetSessionVars().StmtCtx.InVerboseExplain { - ds.SCtx().GetSessionVars().StmtCtx.AppendNote(warnErr) - } else { - ds.SCtx().GetSessionVars().StmtCtx.AppendExtraNote(warnErr) - } - } - }() - - cntPlan = 0 - for _, candidate := range candidates { - path := candidate.path - if path.PartialIndexPaths != nil { - idxMergeTask, err := ds.convertToIndexMergeScan(prop, candidate, opt) - if err != nil { - return nil, 0, err - } - if !idxMergeTask.Invalid() { - cntPlan++ - planCounter.Dec(1) - } - appendCandidate(ds, idxMergeTask, prop, opt) - - curIsBetter, err := compareTaskCost(idxMergeTask, t, opt) - if err != nil { - return nil, 0, err - } - if curIsBetter || planCounter.Empty() { - t = idxMergeTask - } - if planCounter.Empty() { - return t, cntPlan, nil - } - continue - } - // if we already know the range of the scan is empty, just return a 
TableDual - if len(path.Ranges) == 0 { - // We should uncache the tableDual plan. - if expression.MaybeOverOptimized4PlanCache(ds.SCtx().GetExprCtx(), path.AccessConds) { - ds.SCtx().GetSessionVars().StmtCtx.SetSkipPlanCache("get a TableDual plan") - } - dual := PhysicalTableDual{}.Init(ds.SCtx(), ds.StatsInfo(), ds.QueryBlockOffset()) - dual.SetSchema(ds.Schema()) - cntPlan++ - planCounter.Dec(1) - t := &RootTask{} - t.SetPlan(dual) - appendCandidate(ds, t, prop, opt) - return t, cntPlan, nil - } - - canConvertPointGet := len(path.Ranges) > 0 && path.StoreType == kv.TiKV && ds.isPointGetConvertableSchema() - - if canConvertPointGet && path.Index != nil && path.Index.MVIndex { - canConvertPointGet = false // cannot use PointGet upon MVIndex - } - - if canConvertPointGet && !path.IsIntHandlePath { - // We simply do not build [batch] point get for prefix indexes. This can be optimized. - canConvertPointGet = path.Index.Unique && !path.Index.HasPrefixIndex() - // If any range cannot cover all columns of the index, we cannot build [batch] point get. - idxColsLen := len(path.Index.Columns) - for _, ran := range path.Ranges { - if len(ran.LowVal) != idxColsLen { - canConvertPointGet = false - break - } - } - } - if canConvertPointGet && ds.table.Meta().GetPartitionInfo() != nil { - // partition table with dynamic prune not support batchPointGet - // Due to sorting? - // Please make sure handle `where _tidb_rowid in (xx, xx)` correctly when delete this if statements. - if canConvertPointGet && len(path.Ranges) > 1 && ds.SCtx().GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { - canConvertPointGet = false - } - if canConvertPointGet && len(path.Ranges) > 1 { - // TODO: This is now implemented, but to decrease - // the impact of supporting plan cache for patitioning, - // this is not yet enabled. - // TODO: just remove this if block and update/add tests... - // We can only build batch point get for hash partitions on a simple column now. This is - // decided by the current implementation of `BatchPointGetExec::initialize()`, specifically, - // the `getPhysID()` function. Once we optimize that part, we can come back and enable - // BatchPointGet plan for more cases. - hashPartColName := getHashOrKeyPartitionColumnName(ds.SCtx(), ds.table.Meta()) - if hashPartColName == nil { - canConvertPointGet = false - } - } - // Partition table can't use `_tidb_rowid` to generate PointGet Plan unless one partition is explicitly specified. 
- if canConvertPointGet && path.IsIntHandlePath && !ds.table.Meta().PKIsHandle && len(ds.PartitionNames) != 1 { - canConvertPointGet = false - } - if canConvertPointGet { - if path != nil && path.Index != nil && path.Index.Global { - // Don't convert to point get during ddl - // TODO: Revisit truncate partition and global index - if len(ds.TableInfo.GetPartitionInfo().DroppingDefinitions) > 0 || - len(ds.TableInfo.GetPartitionInfo().AddingDefinitions) > 0 { - canConvertPointGet = false - } - } - } - } - if canConvertPointGet { - allRangeIsPoint := true - tc := ds.SCtx().GetSessionVars().StmtCtx.TypeCtx() - for _, ran := range path.Ranges { - if !ran.IsPointNonNullable(tc) { - // unique indexes can have duplicated NULL rows so we cannot use PointGet if there is NULL - allRangeIsPoint = false - break - } - } - if allRangeIsPoint { - var pointGetTask base.Task - if len(path.Ranges) == 1 { - pointGetTask = ds.convertToPointGet(prop, candidate) - } else { - pointGetTask = ds.convertToBatchPointGet(prop, candidate) - } - - // Batch/PointGet plans may be over-optimized, like `a>=1(?) and a<=1(?)` --> `a=1` --> PointGet(a=1). - // For safety, prevent these plans from the plan cache here. - if !pointGetTask.Invalid() && expression.MaybeOverOptimized4PlanCache(ds.SCtx().GetExprCtx(), candidate.path.AccessConds) && !isSafePointGetPath4PlanCache(ds.SCtx(), candidate.path) { - ds.SCtx().GetSessionVars().StmtCtx.SetSkipPlanCache("Batch/PointGet plans may be over-optimized") - } - - appendCandidate(ds, pointGetTask, prop, opt) - if !pointGetTask.Invalid() { - cntPlan++ - planCounter.Dec(1) - } - curIsBetter, cerr := compareTaskCost(pointGetTask, t, opt) - if cerr != nil { - return nil, 0, cerr - } - if curIsBetter || planCounter.Empty() { - t = pointGetTask - if planCounter.Empty() { - return - } - continue - } - } - } - if path.IsTablePath() { - if ds.PreferStoreType&h.PreferTiFlash != 0 && path.StoreType == kv.TiKV { - continue - } - if ds.PreferStoreType&h.PreferTiKV != 0 && path.StoreType == kv.TiFlash { - continue - } - var tblTask base.Task - if ds.SampleInfo != nil { - tblTask, err = ds.convertToSampleTable(prop, candidate, opt) - } else { - tblTask, err = ds.convertToTableScan(prop, candidate, opt) - } - if err != nil { - return nil, 0, err - } - if !tblTask.Invalid() { - cntPlan++ - planCounter.Dec(1) - } - appendCandidate(ds, tblTask, prop, opt) - curIsBetter, err := compareTaskCost(tblTask, t, opt) - if err != nil { - return nil, 0, err - } - if curIsBetter || planCounter.Empty() { - t = tblTask - } - if planCounter.Empty() { - return t, cntPlan, nil - } - continue - } - // TiFlash storage do not support index scan. - if ds.PreferStoreType&h.PreferTiFlash != 0 { - continue - } - // TableSample do not support index scan. - if ds.SampleInfo != nil { - continue - } - idxTask, err := ds.convertToIndexScan(prop, candidate, opt) - if err != nil { - return nil, 0, err - } - if !idxTask.Invalid() { - cntPlan++ - planCounter.Dec(1) - } - appendCandidate(ds, idxTask, prop, opt) - curIsBetter, err := compareTaskCost(idxTask, t, opt) - if err != nil { - return nil, 0, err - } - if curIsBetter || planCounter.Empty() { - t = idxTask - } - if planCounter.Empty() { - return t, cntPlan, nil - } - } - - return -} - -// convertToIndexMergeScan builds the index merge scan for intersection or union cases. 
-func (ds *DataSource) convertToIndexMergeScan(prop *property.PhysicalProperty, candidate *candidatePath, _ *optimizetrace.PhysicalOptimizeOp) (task base.Task, err error) { - if prop.IsFlashProp() || prop.TaskTp == property.CopSingleReadTaskType { - return base.InvalidTask, nil - } - // lift the limitation of that double read can not build index merge **COP** task with intersection. - // that means we can output a cop task here without encapsulating it as root task, for the convenience of attaching limit to its table side. - - if !prop.IsSortItemEmpty() && !candidate.isMatchProp { - return base.InvalidTask, nil - } - // while for now, we still can not push the sort prop to the intersection index plan side, temporarily banned here. - if !prop.IsSortItemEmpty() && candidate.path.IndexMergeIsIntersection { - return base.InvalidTask, nil - } - failpoint.Inject("forceIndexMergeKeepOrder", func(_ failpoint.Value) { - if len(candidate.path.PartialIndexPaths) > 0 && !candidate.path.IndexMergeIsIntersection { - if prop.IsSortItemEmpty() { - failpoint.Return(base.InvalidTask, nil) - } - } - }) - path := candidate.path - scans := make([]base.PhysicalPlan, 0, len(path.PartialIndexPaths)) - cop := &CopTask{ - indexPlanFinished: false, - tblColHists: ds.TblColHists, - } - cop.physPlanPartInfo = &PhysPlanPartInfo{ - PruningConds: pushDownNot(ds.SCtx().GetExprCtx(), ds.AllConds), - PartitionNames: ds.PartitionNames, - Columns: ds.TblCols, - ColumnNames: ds.OutputNames(), - } - // Add sort items for index scan for merge-sort operation between partitions. - byItems := make([]*util.ByItems, 0, len(prop.SortItems)) - for _, si := range prop.SortItems { - byItems = append(byItems, &util.ByItems{ - Expr: si.Col, - Desc: si.Desc, - }) - } - globalRemainingFilters := make([]expression.Expression, 0, 3) - for _, partPath := range path.PartialIndexPaths { - var scan base.PhysicalPlan - if partPath.IsTablePath() { - scan = ds.convertToPartialTableScan(prop, partPath, candidate.isMatchProp, byItems) - } else { - var remainingFilters []expression.Expression - scan, remainingFilters, err = ds.convertToPartialIndexScan(cop.physPlanPartInfo, prop, partPath, candidate.isMatchProp, byItems) - if err != nil { - return base.InvalidTask, err - } - if prop.TaskTp != property.RootTaskType && len(remainingFilters) > 0 { - return base.InvalidTask, nil - } - globalRemainingFilters = append(globalRemainingFilters, remainingFilters...) - } - scans = append(scans, scan) - } - totalRowCount := path.CountAfterAccess - if prop.ExpectedCnt < ds.StatsInfo().RowCount { - totalRowCount *= prop.ExpectedCnt / ds.StatsInfo().RowCount - } - ts, remainingFilters2, moreColumn, err := ds.buildIndexMergeTableScan(path.TableFilters, totalRowCount, candidate.isMatchProp) - if err != nil { - return base.InvalidTask, err - } - if prop.TaskTp != property.RootTaskType && len(remainingFilters2) > 0 { - return base.InvalidTask, nil - } - globalRemainingFilters = append(globalRemainingFilters, remainingFilters2...) 
- cop.keepOrder = candidate.isMatchProp - cop.tablePlan = ts - cop.idxMergePartPlans = scans - cop.idxMergeIsIntersection = path.IndexMergeIsIntersection - cop.idxMergeAccessMVIndex = path.IndexMergeAccessMVIndex - if moreColumn { - cop.needExtraProj = true - cop.originSchema = ds.Schema() - } - if len(globalRemainingFilters) != 0 { - cop.rootTaskConds = globalRemainingFilters - } - // after we lift the limitation of intersection and cop-type task in the code in this - // function above, we could set its index plan finished as true once we found its table - // plan is pure table scan below. - // And this will cause cost underestimation when we estimate the cost of the entire cop - // task plan in function `getTaskPlanCost`. - if prop.TaskTp == property.RootTaskType { - cop.indexPlanFinished = true - task = cop.ConvertToRootTask(ds.SCtx()) - } else { - _, pureTableScan := ts.(*PhysicalTableScan) - if !pureTableScan { - cop.indexPlanFinished = true - } - task = cop - } - return task, nil -} - -func (ds *DataSource) convertToPartialIndexScan(physPlanPartInfo *PhysPlanPartInfo, prop *property.PhysicalProperty, path *util.AccessPath, matchProp bool, byItems []*util.ByItems) (base.PhysicalPlan, []expression.Expression, error) { - is := ds.getOriginalPhysicalIndexScan(prop, path, matchProp, false) - // TODO: Consider using isIndexCoveringColumns() to avoid another TableRead - indexConds := path.IndexFilters - if matchProp { - if is.Table.GetPartitionInfo() != nil && !is.Index.Global && is.SCtx().GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { - is.Columns, is.schema, _ = AddExtraPhysTblIDColumn(is.SCtx(), is.Columns, is.schema) - } - // Add sort items for index scan for merge-sort operation between partitions. - is.ByItems = byItems - } - - // Add a `Selection` for `IndexScan` with global index. - // It should pushdown to TiKV, DataSource schema doesn't contain partition id column. 
- indexConds, err := is.addSelectionConditionForGlobalIndex(ds, physPlanPartInfo, indexConds) - if err != nil { - return nil, nil, err - } - - if len(indexConds) > 0 { - pushedFilters, remainingFilter := extractFiltersForIndexMerge(GetPushDownCtx(ds.SCtx()), indexConds) - var selectivity float64 - if path.CountAfterAccess > 0 { - selectivity = path.CountAfterIndex / path.CountAfterAccess - } - rowCount := is.StatsInfo().RowCount * selectivity - stats := &property.StatsInfo{RowCount: rowCount} - stats.StatsVersion = ds.StatisticTable.Version - if ds.StatisticTable.Pseudo { - stats.StatsVersion = statistics.PseudoVersion - } - indexPlan := PhysicalSelection{Conditions: pushedFilters}.Init(is.SCtx(), stats, ds.QueryBlockOffset()) - indexPlan.SetChildren(is) - return indexPlan, remainingFilter, nil - } - return is, nil, nil -} - -func checkColinSchema(cols []*expression.Column, schema *expression.Schema) bool { - for _, col := range cols { - if schema.ColumnIndex(col) == -1 { - return false - } - } - return true -} - -func (ds *DataSource) convertToPartialTableScan(prop *property.PhysicalProperty, path *util.AccessPath, matchProp bool, byItems []*util.ByItems) (tablePlan base.PhysicalPlan) { - ts, rowCount := ds.getOriginalPhysicalTableScan(prop, path, matchProp) - overwritePartialTableScanSchema(ds, ts) - // remove ineffetive filter condition after overwriting physicalscan schema - newFilterConds := make([]expression.Expression, 0, len(path.TableFilters)) - for _, cond := range ts.filterCondition { - cols := expression.ExtractColumns(cond) - if checkColinSchema(cols, ts.schema) { - newFilterConds = append(newFilterConds, cond) - } - } - ts.filterCondition = newFilterConds - if matchProp { - if ts.Table.GetPartitionInfo() != nil && ts.SCtx().GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { - ts.Columns, ts.schema, _ = AddExtraPhysTblIDColumn(ts.SCtx(), ts.Columns, ts.schema) - } - ts.ByItems = byItems - } - if len(ts.filterCondition) > 0 { - selectivity, _, err := cardinality.Selectivity(ds.SCtx(), ds.TableStats.HistColl, ts.filterCondition, nil) - if err != nil { - logutil.BgLogger().Debug("calculate selectivity failed, use selection factor", zap.Error(err)) - selectivity = cost.SelectionFactor - } - tablePlan = PhysicalSelection{Conditions: ts.filterCondition}.Init(ts.SCtx(), ts.StatsInfo().ScaleByExpectCnt(selectivity*rowCount), ds.QueryBlockOffset()) - tablePlan.SetChildren(ts) - return tablePlan - } - tablePlan = ts - return tablePlan -} - -// overwritePartialTableScanSchema change the schema of partial table scan to handle columns. -func overwritePartialTableScanSchema(ds *DataSource, ts *PhysicalTableScan) { - handleCols := ds.HandleCols - if handleCols == nil { - handleCols = util.NewIntHandleCols(ds.newExtraHandleSchemaCol()) - } - hdColNum := handleCols.NumCols() - exprCols := make([]*expression.Column, 0, hdColNum) - infoCols := make([]*model.ColumnInfo, 0, hdColNum) - for i := 0; i < hdColNum; i++ { - col := handleCols.GetCol(i) - exprCols = append(exprCols, col) - if c := model.FindColumnInfoByID(ds.TableInfo.Columns, col.ID); c != nil { - infoCols = append(infoCols, c) - } else { - infoCols = append(infoCols, col.ToInfo()) - } - } - ts.schema = expression.NewSchema(exprCols...) - ts.Columns = infoCols -} - -// setIndexMergeTableScanHandleCols set the handle columns of the table scan. 
-func setIndexMergeTableScanHandleCols(ds *DataSource, ts *PhysicalTableScan) (err error) { - handleCols := ds.HandleCols - if handleCols == nil { - handleCols = util.NewIntHandleCols(ds.newExtraHandleSchemaCol()) - } - hdColNum := handleCols.NumCols() - exprCols := make([]*expression.Column, 0, hdColNum) - for i := 0; i < hdColNum; i++ { - col := handleCols.GetCol(i) - exprCols = append(exprCols, col) - } - ts.HandleCols, err = handleCols.ResolveIndices(expression.NewSchema(exprCols...)) - return -} - -// buildIndexMergeTableScan() returns Selection that will be pushed to TiKV. -// Filters that cannot be pushed to TiKV are also returned, and an extra Selection above IndexMergeReader will be constructed later. -func (ds *DataSource) buildIndexMergeTableScan(tableFilters []expression.Expression, - totalRowCount float64, matchProp bool) (base.PhysicalPlan, []expression.Expression, bool, error) { - ts := PhysicalTableScan{ - Table: ds.TableInfo, - Columns: slices.Clone(ds.Columns), - TableAsName: ds.TableAsName, - DBName: ds.DBName, - isPartition: ds.PartitionDefIdx != nil, - physicalTableID: ds.PhysicalTableID, - HandleCols: ds.HandleCols, - tblCols: ds.TblCols, - tblColHists: ds.TblColHists, - }.Init(ds.SCtx(), ds.QueryBlockOffset()) - ts.SetSchema(ds.Schema().Clone()) - err := setIndexMergeTableScanHandleCols(ds, ts) - if err != nil { - return nil, nil, false, err - } - ts.SetStats(ds.TableStats.ScaleByExpectCnt(totalRowCount)) - usedStats := ds.SCtx().GetSessionVars().StmtCtx.GetUsedStatsInfo(false) - if usedStats != nil && usedStats.GetUsedInfo(ts.physicalTableID) != nil { - ts.usedStatsInfo = usedStats.GetUsedInfo(ts.physicalTableID) - } - if ds.StatisticTable.Pseudo { - ts.StatsInfo().StatsVersion = statistics.PseudoVersion - } - var currentTopPlan base.PhysicalPlan = ts - if len(tableFilters) > 0 { - pushedFilters, remainingFilters := extractFiltersForIndexMerge(GetPushDownCtx(ds.SCtx()), tableFilters) - pushedFilters1, remainingFilters1 := SplitSelCondsWithVirtualColumn(pushedFilters) - pushedFilters = pushedFilters1 - remainingFilters = append(remainingFilters, remainingFilters1...) - if len(pushedFilters) != 0 { - selectivity, _, err := cardinality.Selectivity(ds.SCtx(), ds.TableStats.HistColl, pushedFilters, nil) - if err != nil { - logutil.BgLogger().Debug("calculate selectivity failed, use selection factor", zap.Error(err)) - selectivity = cost.SelectionFactor - } - sel := PhysicalSelection{Conditions: pushedFilters}.Init(ts.SCtx(), ts.StatsInfo().ScaleByExpectCnt(selectivity*totalRowCount), ts.QueryBlockOffset()) - sel.SetChildren(ts) - currentTopPlan = sel - } - if len(remainingFilters) > 0 { - return currentTopPlan, remainingFilters, false, nil - } - } - // If we don't need to use ordered scan, we don't need do the following codes for adding new columns. - if !matchProp { - return currentTopPlan, nil, false, nil - } - - // Add the row handle into the schema. 
- columnAdded := false - if ts.Table.PKIsHandle { - pk := ts.Table.GetPkColInfo() - pkCol := expression.ColInfo2Col(ts.tblCols, pk) - if !ts.schema.Contains(pkCol) { - ts.schema.Append(pkCol) - ts.Columns = append(ts.Columns, pk) - columnAdded = true - } - } else if ts.Table.IsCommonHandle { - idxInfo := ts.Table.GetPrimaryKey() - for _, idxCol := range idxInfo.Columns { - col := ts.tblCols[idxCol.Offset] - if !ts.schema.Contains(col) { - columnAdded = true - ts.schema.Append(col) - ts.Columns = append(ts.Columns, col.ToInfo()) - } - } - } else if !ts.schema.Contains(ts.HandleCols.GetCol(0)) { - ts.schema.Append(ts.HandleCols.GetCol(0)) - ts.Columns = append(ts.Columns, model.NewExtraHandleColInfo()) - columnAdded = true - } - - // For the global index of the partitioned table, we also need the PhysicalTblID to identify the rows from each partition. - if ts.Table.GetPartitionInfo() != nil && ts.SCtx().GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { - var newColAdded bool - ts.Columns, ts.schema, newColAdded = AddExtraPhysTblIDColumn(ts.SCtx(), ts.Columns, ts.schema) - columnAdded = columnAdded || newColAdded - } - return currentTopPlan, nil, columnAdded, nil -} - -// extractFiltersForIndexMerge returns: -// `pushed`: exprs that can be pushed to TiKV. -// `remaining`: exprs that can NOT be pushed to TiKV but can be pushed to other storage engines. -// Why do we need this func? -// IndexMerge only works on TiKV, so we need to find all exprs that cannot be pushed to TiKV, and add a new Selection above IndexMergeReader. -// -// But the new Selection should exclude the exprs that can NOT be pushed to ALL the storage engines. -// Because these exprs have already been put in another Selection(check rule_predicate_push_down). -func extractFiltersForIndexMerge(ctx expression.PushDownContext, filters []expression.Expression) (pushed []expression.Expression, remaining []expression.Expression) { - for _, expr := range filters { - if expression.CanExprsPushDown(ctx, []expression.Expression{expr}, kv.TiKV) { - pushed = append(pushed, expr) - continue - } - if expression.CanExprsPushDown(ctx, []expression.Expression{expr}, kv.UnSpecified) { - remaining = append(remaining, expr) - } - } - return -} - -func isIndexColsCoveringCol(sctx expression.EvalContext, col *expression.Column, indexCols []*expression.Column, idxColLens []int, ignoreLen bool) bool { - for i, indexCol := range indexCols { - if indexCol == nil || !col.EqualByExprAndID(sctx, indexCol) { - continue - } - if ignoreLen || idxColLens[i] == types.UnspecifiedLength || idxColLens[i] == col.RetType.GetFlen() { - return true - } - } - return false -} - -func (ds *DataSource) indexCoveringColumn(column *expression.Column, indexColumns []*expression.Column, idxColLens []int, ignoreLen bool) bool { - if ds.TableInfo.PKIsHandle && mysql.HasPriKeyFlag(column.RetType.GetFlag()) { - return true - } - if column.ID == model.ExtraHandleID || column.ID == model.ExtraPhysTblID { - return true - } - evalCtx := ds.SCtx().GetExprCtx().GetEvalCtx() - coveredByPlainIndex := isIndexColsCoveringCol(evalCtx, column, indexColumns, idxColLens, ignoreLen) - coveredByClusteredIndex := isIndexColsCoveringCol(evalCtx, column, ds.CommonHandleCols, ds.CommonHandleLens, ignoreLen) - if !coveredByPlainIndex && !coveredByClusteredIndex { - return false - } - isClusteredNewCollationIdx := collate.NewCollationEnabled() && - column.GetType(evalCtx).EvalType() == types.ETString && - !mysql.HasBinaryFlag(column.GetType(evalCtx).GetFlag()) - if !coveredByPlainIndex && 
coveredByClusteredIndex && isClusteredNewCollationIdx && ds.table.Meta().CommonHandleVersion == 0 { - return false - } - return true -} - -func (ds *DataSource) isIndexCoveringColumns(columns, indexColumns []*expression.Column, idxColLens []int) bool { - for _, col := range columns { - if !ds.indexCoveringColumn(col, indexColumns, idxColLens, false) { - return false - } - } - return true -} - -func (ds *DataSource) isIndexCoveringCondition(condition expression.Expression, indexColumns []*expression.Column, idxColLens []int) bool { - switch v := condition.(type) { - case *expression.Column: - return ds.indexCoveringColumn(v, indexColumns, idxColLens, false) - case *expression.ScalarFunction: - // Even if the index only contains prefix `col`, the index can cover `col is null`. - if v.FuncName.L == ast.IsNull { - if col, ok := v.GetArgs()[0].(*expression.Column); ok { - return ds.indexCoveringColumn(col, indexColumns, idxColLens, true) - } - } - for _, arg := range v.GetArgs() { - if !ds.isIndexCoveringCondition(arg, indexColumns, idxColLens) { - return false - } - } - return true - } - return true -} - -func (ds *DataSource) isSingleScan(indexColumns []*expression.Column, idxColLens []int) bool { - if !ds.SCtx().GetSessionVars().OptPrefixIndexSingleScan || ds.ColsRequiringFullLen == nil { - // ds.ColsRequiringFullLen is set at (*DataSource).PruneColumns. In some cases we don't reach (*DataSource).PruneColumns - // and ds.ColsRequiringFullLen is nil, so we fall back to ds.isIndexCoveringColumns(ds.schema.Columns, indexColumns, idxColLens). - return ds.isIndexCoveringColumns(ds.Schema().Columns, indexColumns, idxColLens) - } - if !ds.isIndexCoveringColumns(ds.ColsRequiringFullLen, indexColumns, idxColLens) { - return false - } - for _, cond := range ds.AllConds { - if !ds.isIndexCoveringCondition(cond, indexColumns, idxColLens) { - return false - } - } - return true -} - -// If there is a table reader which needs to keep order, we should append a pk to table scan. -func (ts *PhysicalTableScan) appendExtraHandleCol(ds *DataSource) (*expression.Column, bool) { - handleCols := ds.HandleCols - if handleCols != nil { - return handleCols.GetCol(0), false - } - handleCol := ds.newExtraHandleSchemaCol() - ts.schema.Append(handleCol) - ts.Columns = append(ts.Columns, model.NewExtraHandleColInfo()) - return handleCol, true -} - -// convertToIndexScan converts the DataSource to index scan with idx. -func (ds *DataSource) convertToIndexScan(prop *property.PhysicalProperty, - candidate *candidatePath, _ *optimizetrace.PhysicalOptimizeOp) (task base.Task, err error) { - if candidate.path.Index.MVIndex { - // MVIndex is special since different index rows may return the same _row_id and this can break some assumptions of IndexReader. - // Currently only support using IndexMerge to access MVIndex instead of IndexReader. - // TODO: make IndexReader support accessing MVIndex directly. - return base.InvalidTask, nil - } - if !candidate.path.IsSingleScan { - // If it's parent requires single read task, return max cost. - if prop.TaskTp == property.CopSingleReadTaskType { - return base.InvalidTask, nil - } - } else if prop.TaskTp == property.CopMultiReadTaskType { - // If it's parent requires double read task, return max cost. - return base.InvalidTask, nil - } - if !prop.IsSortItemEmpty() && !candidate.isMatchProp { - return base.InvalidTask, nil - } - // If we need to keep order for the index scan, we should forbid the non-keep-order index scan when we try to generate the path. 
- if prop.IsSortItemEmpty() && candidate.path.ForceKeepOrder { - return base.InvalidTask, nil - } - // If we don't need to keep order for the index scan, we should forbid the non-keep-order index scan when we try to generate the path. - if !prop.IsSortItemEmpty() && candidate.path.ForceNoKeepOrder { - return base.InvalidTask, nil - } - path := candidate.path - is := ds.getOriginalPhysicalIndexScan(prop, path, candidate.isMatchProp, candidate.path.IsSingleScan) - cop := &CopTask{ - indexPlan: is, - tblColHists: ds.TblColHists, - tblCols: ds.TblCols, - expectCnt: uint64(prop.ExpectedCnt), - } - cop.physPlanPartInfo = &PhysPlanPartInfo{ - PruningConds: pushDownNot(ds.SCtx().GetExprCtx(), ds.AllConds), - PartitionNames: ds.PartitionNames, - Columns: ds.TblCols, - ColumnNames: ds.OutputNames(), - } - if !candidate.path.IsSingleScan { - // On this way, it's double read case. - ts := PhysicalTableScan{ - Columns: util.CloneColInfos(ds.Columns), - Table: is.Table, - TableAsName: ds.TableAsName, - DBName: ds.DBName, - isPartition: ds.PartitionDefIdx != nil, - physicalTableID: ds.PhysicalTableID, - tblCols: ds.TblCols, - tblColHists: ds.TblColHists, - }.Init(ds.SCtx(), is.QueryBlockOffset()) - ts.SetSchema(ds.Schema().Clone()) - // We set `StatsVersion` here and fill other fields in `(*copTask).finishIndexPlan`. Since `copTask.indexPlan` may - // change before calling `(*copTask).finishIndexPlan`, we don't know the stats information of `ts` currently and on - // the other hand, it may be hard to identify `StatsVersion` of `ts` in `(*copTask).finishIndexPlan`. - ts.SetStats(&property.StatsInfo{StatsVersion: ds.TableStats.StatsVersion}) - usedStats := ds.SCtx().GetSessionVars().StmtCtx.GetUsedStatsInfo(false) - if usedStats != nil && usedStats.GetUsedInfo(ts.physicalTableID) != nil { - ts.usedStatsInfo = usedStats.GetUsedInfo(ts.physicalTableID) - } - cop.tablePlan = ts - } - task = cop - if cop.tablePlan != nil && ds.TableInfo.IsCommonHandle { - cop.commonHandleCols = ds.CommonHandleCols - commonHandle := ds.HandleCols.(*util.CommonHandleCols) - for _, col := range commonHandle.GetColumns() { - if ds.Schema().ColumnIndex(col) == -1 { - ts := cop.tablePlan.(*PhysicalTableScan) - ts.Schema().Append(col) - ts.Columns = append(ts.Columns, col.ToInfo()) - cop.needExtraProj = true - } - } - } - if candidate.isMatchProp { - cop.keepOrder = true - if cop.tablePlan != nil && !ds.TableInfo.IsCommonHandle { - col, isNew := cop.tablePlan.(*PhysicalTableScan).appendExtraHandleCol(ds) - cop.extraHandleCol = col - cop.needExtraProj = cop.needExtraProj || isNew - } - - if ds.TableInfo.GetPartitionInfo() != nil { - // Add sort items for index scan for merge-sort operation between partitions, only required for local index. 
- if !is.Index.Global { - byItems := make([]*util.ByItems, 0, len(prop.SortItems)) - for _, si := range prop.SortItems { - byItems = append(byItems, &util.ByItems{ - Expr: si.Col, - Desc: si.Desc, - }) - } - cop.indexPlan.(*PhysicalIndexScan).ByItems = byItems - } - if cop.tablePlan != nil && ds.SCtx().GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { - if !is.Index.Global { - is.Columns, is.schema, _ = AddExtraPhysTblIDColumn(is.SCtx(), is.Columns, is.Schema()) - } - var succ bool - // global index for tableScan with keepOrder also need PhysicalTblID - ts := cop.tablePlan.(*PhysicalTableScan) - ts.Columns, ts.schema, succ = AddExtraPhysTblIDColumn(ts.SCtx(), ts.Columns, ts.Schema()) - cop.needExtraProj = cop.needExtraProj || succ - } - } - } - if cop.needExtraProj { - cop.originSchema = ds.Schema() - } - // prop.IsSortItemEmpty() would always return true when coming to here, - // so we can just use prop.ExpectedCnt as parameter of addPushedDownSelection. - finalStats := ds.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt) - if err = is.addPushedDownSelection(cop, ds, path, finalStats); err != nil { - return base.InvalidTask, err - } - if prop.TaskTp == property.RootTaskType { - task = task.ConvertToRootTask(ds.SCtx()) - } else if _, ok := task.(*RootTask); ok { - return base.InvalidTask, nil - } - return task, nil -} - -func (is *PhysicalIndexScan) getScanRowSize() float64 { - idx := is.Index - scanCols := make([]*expression.Column, 0, len(idx.Columns)+1) - // If `initSchema` has already appended the handle column in schema, just use schema columns, otherwise, add extra handle column. - if len(idx.Columns) == len(is.schema.Columns) { - scanCols = append(scanCols, is.schema.Columns...) - handleCol := is.pkIsHandleCol - if handleCol != nil { - scanCols = append(scanCols, handleCol) - } - } else { - scanCols = is.schema.Columns - } - return cardinality.GetIndexAvgRowSize(is.SCtx(), is.tblColHists, scanCols, is.Index.Unique) -} - -// initSchema is used to set the schema of PhysicalIndexScan. Before calling this, -// make sure the following field of PhysicalIndexScan are initialized: -// -// PhysicalIndexScan.Table *model.TableInfo -// PhysicalIndexScan.Index *model.IndexInfo -// PhysicalIndexScan.Index.Columns []*IndexColumn -// PhysicalIndexScan.IdxCols []*expression.Column -// PhysicalIndexScan.Columns []*model.ColumnInfo -func (is *PhysicalIndexScan) initSchema(idxExprCols []*expression.Column, isDoubleRead bool) { - indexCols := make([]*expression.Column, len(is.IdxCols), len(is.Index.Columns)+1) - copy(indexCols, is.IdxCols) - - for i := len(is.IdxCols); i < len(is.Index.Columns); i++ { - if idxExprCols[i] != nil { - indexCols = append(indexCols, idxExprCols[i]) - } else { - // TODO: try to reuse the col generated when building the DataSource. 
- indexCols = append(indexCols, &expression.Column{ - ID: is.Table.Columns[is.Index.Columns[i].Offset].ID, - RetType: &is.Table.Columns[is.Index.Columns[i].Offset].FieldType, - UniqueID: is.SCtx().GetSessionVars().AllocPlanColumnID(), - }) - } - } - is.NeedCommonHandle = is.Table.IsCommonHandle - - if is.NeedCommonHandle { - for i := len(is.Index.Columns); i < len(idxExprCols); i++ { - indexCols = append(indexCols, idxExprCols[i]) - } - } - setHandle := len(indexCols) > len(is.Index.Columns) - if !setHandle { - for i, col := range is.Columns { - if (mysql.HasPriKeyFlag(col.GetFlag()) && is.Table.PKIsHandle) || col.ID == model.ExtraHandleID { - indexCols = append(indexCols, is.dataSourceSchema.Columns[i]) - setHandle = true - break - } - } - } - - var extraPhysTblCol *expression.Column - // If `dataSouceSchema` contains `model.ExtraPhysTblID`, we should add it into `indexScan.schema` - for _, col := range is.dataSourceSchema.Columns { - if col.ID == model.ExtraPhysTblID { - extraPhysTblCol = col.Clone().(*expression.Column) - break - } - } - - if isDoubleRead || is.Index.Global { - // If it's double read case, the first index must return handle. So we should add extra handle column - // if there isn't a handle column. - if !setHandle { - if !is.Table.IsCommonHandle { - indexCols = append(indexCols, &expression.Column{ - RetType: types.NewFieldType(mysql.TypeLonglong), - ID: model.ExtraHandleID, - UniqueID: is.SCtx().GetSessionVars().AllocPlanColumnID(), - OrigName: model.ExtraHandleName.O, - }) - } - } - // If it's global index, handle and PhysTblID columns has to be added, so that needed pids can be filtered. - if is.Index.Global && extraPhysTblCol == nil { - indexCols = append(indexCols, &expression.Column{ - RetType: types.NewFieldType(mysql.TypeLonglong), - ID: model.ExtraPhysTblID, - UniqueID: is.SCtx().GetSessionVars().AllocPlanColumnID(), - OrigName: model.ExtraPhysTblIdName.O, - }) - } - } - - if extraPhysTblCol != nil { - indexCols = append(indexCols, extraPhysTblCol) - } - - is.SetSchema(expression.NewSchema(indexCols...)) -} - -func (is *PhysicalIndexScan) addSelectionConditionForGlobalIndex(p *DataSource, physPlanPartInfo *PhysPlanPartInfo, conditions []expression.Expression) ([]expression.Expression, error) { - if !is.Index.Global { - return conditions, nil - } - args := make([]expression.Expression, 0, len(p.PartitionNames)+1) - for _, col := range is.schema.Columns { - if col.ID == model.ExtraPhysTblID { - args = append(args, col.Clone()) - break - } - } - - if len(args) != 1 { - return nil, errors.Errorf("Can't find column %s in schema %s", model.ExtraPhysTblIdName.O, is.schema) - } - - // For SQL like 'select x from t partition(p0, p1) use index(idx)', - // we will add a `Selection` like `in(t._tidb_pid, p0, p1)` into the plan. - // For truncate/drop partitions, we should only return indexes where partitions still in public state. - idxArr, err := PartitionPruning(p.SCtx(), p.table.GetPartitionedTable(), - physPlanPartInfo.PruningConds, - physPlanPartInfo.PartitionNames, - physPlanPartInfo.Columns, - physPlanPartInfo.ColumnNames) - if err != nil { - return nil, err - } - needNot := false - pInfo := p.TableInfo.GetPartitionInfo() - if len(idxArr) == 1 && idxArr[0] == FullRange { - // Only filter adding and dropping partitions. 
- if len(pInfo.AddingDefinitions) == 0 && len(pInfo.DroppingDefinitions) == 0 { - return conditions, nil - } - needNot = true - for _, p := range pInfo.AddingDefinitions { - args = append(args, expression.NewInt64Const(p.ID)) - } - for _, p := range pInfo.DroppingDefinitions { - args = append(args, expression.NewInt64Const(p.ID)) - } - } else if len(idxArr) == 0 { - // add an invalid pid as param for `IN` function - args = append(args, expression.NewInt64Const(-1)) - } else { - // `PartitionPruning`` func does not return adding and dropping partitions - for _, idx := range idxArr { - args = append(args, expression.NewInt64Const(pInfo.Definitions[idx].ID)) - } - } - condition, err := expression.NewFunction(p.SCtx().GetExprCtx(), ast.In, types.NewFieldType(mysql.TypeLonglong), args...) - if err != nil { - return nil, err - } - if needNot { - condition, err = expression.NewFunction(p.SCtx().GetExprCtx(), ast.UnaryNot, types.NewFieldType(mysql.TypeLonglong), condition) - if err != nil { - return nil, err - } - } - return append(conditions, condition), nil -} - -func (is *PhysicalIndexScan) addPushedDownSelection(copTask *CopTask, p *DataSource, path *util.AccessPath, finalStats *property.StatsInfo) error { - // Add filter condition to table plan now. - indexConds, tableConds := path.IndexFilters, path.TableFilters - tableConds, copTask.rootTaskConds = SplitSelCondsWithVirtualColumn(tableConds) - - var newRootConds []expression.Expression - pctx := GetPushDownCtx(is.SCtx()) - indexConds, newRootConds = expression.PushDownExprs(pctx, indexConds, kv.TiKV) - copTask.rootTaskConds = append(copTask.rootTaskConds, newRootConds...) - - tableConds, newRootConds = expression.PushDownExprs(pctx, tableConds, kv.TiKV) - copTask.rootTaskConds = append(copTask.rootTaskConds, newRootConds...) - - // Add a `Selection` for `IndexScan` with global index. - // It should pushdown to TiKV, DataSource schema doesn't contain partition id column. - indexConds, err := is.addSelectionConditionForGlobalIndex(p, copTask.physPlanPartInfo, indexConds) - if err != nil { - return err - } - - if indexConds != nil { - var selectivity float64 - if path.CountAfterAccess > 0 { - selectivity = path.CountAfterIndex / path.CountAfterAccess - } - count := is.StatsInfo().RowCount * selectivity - stats := p.TableStats.ScaleByExpectCnt(count) - indexSel := PhysicalSelection{Conditions: indexConds}.Init(is.SCtx(), stats, is.QueryBlockOffset()) - indexSel.SetChildren(is) - copTask.indexPlan = indexSel - } - if len(tableConds) > 0 { - copTask.finishIndexPlan() - tableSel := PhysicalSelection{Conditions: tableConds}.Init(is.SCtx(), finalStats, is.QueryBlockOffset()) - if len(copTask.rootTaskConds) != 0 { - selectivity, _, err := cardinality.Selectivity(is.SCtx(), copTask.tblColHists, tableConds, nil) - if err != nil { - logutil.BgLogger().Debug("calculate selectivity failed, use selection factor", zap.Error(err)) - selectivity = cost.SelectionFactor - } - tableSel.SetStats(copTask.Plan().StatsInfo().Scale(selectivity)) - } - tableSel.SetChildren(copTask.tablePlan) - copTask.tablePlan = tableSel - } - return nil -} - -// NeedExtraOutputCol is designed for check whether need an extra column for -// pid or physical table id when build indexReq. 
-func (is *PhysicalIndexScan) NeedExtraOutputCol() bool { - if is.Table.Partition == nil { - return false - } - // has global index, should return pid - if is.Index.Global { - return true - } - // has embedded limit, should return physical table id - if len(is.ByItems) != 0 && is.SCtx().GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { - return true - } - return false -} - -// SplitSelCondsWithVirtualColumn filter the select conditions which contain virtual column -func SplitSelCondsWithVirtualColumn(conds []expression.Expression) (withoutVirt []expression.Expression, withVirt []expression.Expression) { - for i := range conds { - if expression.ContainVirtualColumn(conds[i : i+1]) { - withVirt = append(withVirt, conds[i]) - } else { - withoutVirt = append(withoutVirt, conds[i]) - } - } - return withoutVirt, withVirt -} - -func matchIndicesProp(sctx base.PlanContext, idxCols []*expression.Column, colLens []int, propItems []property.SortItem) bool { - if len(idxCols) < len(propItems) { - return false - } - for i, item := range propItems { - if colLens[i] != types.UnspecifiedLength || !item.Col.EqualByExprAndID(sctx.GetExprCtx().GetEvalCtx(), idxCols[i]) { - return false - } - } - return true -} - -func (ds *DataSource) splitIndexFilterConditions(conditions []expression.Expression, indexColumns []*expression.Column, - idxColLens []int) (indexConds, tableConds []expression.Expression) { - var indexConditions, tableConditions []expression.Expression - for _, cond := range conditions { - var covered bool - if ds.SCtx().GetSessionVars().OptPrefixIndexSingleScan { - covered = ds.isIndexCoveringCondition(cond, indexColumns, idxColLens) - } else { - covered = ds.isIndexCoveringColumns(expression.ExtractColumns(cond), indexColumns, idxColLens) - } - if covered { - indexConditions = append(indexConditions, cond) - } else { - tableConditions = append(tableConditions, cond) - } - } - return indexConditions, tableConditions -} - -// GetPhysicalScan4LogicalTableScan returns PhysicalTableScan for the LogicalTableScan. -func GetPhysicalScan4LogicalTableScan(s *LogicalTableScan, schema *expression.Schema, stats *property.StatsInfo) *PhysicalTableScan { - ds := s.Source - ts := PhysicalTableScan{ - Table: ds.TableInfo, - Columns: ds.Columns, - TableAsName: ds.TableAsName, - DBName: ds.DBName, - isPartition: ds.PartitionDefIdx != nil, - physicalTableID: ds.PhysicalTableID, - Ranges: s.Ranges, - AccessCondition: s.AccessConds, - tblCols: ds.TblCols, - tblColHists: ds.TblColHists, - }.Init(s.SCtx(), s.QueryBlockOffset()) - ts.SetStats(stats) - ts.SetSchema(schema.Clone()) - return ts -} - -// GetPhysicalIndexScan4LogicalIndexScan returns PhysicalIndexScan for the logical IndexScan. -func GetPhysicalIndexScan4LogicalIndexScan(s *LogicalIndexScan, _ *expression.Schema, stats *property.StatsInfo) *PhysicalIndexScan { - ds := s.Source - is := PhysicalIndexScan{ - Table: ds.TableInfo, - TableAsName: ds.TableAsName, - DBName: ds.DBName, - Columns: s.Columns, - Index: s.Index, - IdxCols: s.IdxCols, - IdxColLens: s.IdxColLens, - AccessCondition: s.AccessConds, - Ranges: s.Ranges, - dataSourceSchema: ds.Schema(), - isPartition: ds.PartitionDefIdx != nil, - physicalTableID: ds.PhysicalTableID, - tblColHists: ds.TblColHists, - pkIsHandleCol: ds.getPKIsHandleCol(), - }.Init(ds.SCtx(), ds.QueryBlockOffset()) - is.SetStats(stats) - is.initSchema(s.FullIdxCols, s.IsDoubleRead) - return is -} - -// isPointGetPath indicates whether the conditions are point-get-able. 
-// eg: create table t(a int, b int,c int unique, primary (a,b)) -// select * from t where a = 1 and b = 1 and c =1; -// the datasource can access by primary key(a,b) or unique key c which are both point-get-able -func (ds *DataSource) isPointGetPath(path *util.AccessPath) bool { - if len(path.Ranges) < 1 { - return false - } - if !path.IsIntHandlePath { - if path.Index == nil { - return false - } - if !path.Index.Unique || path.Index.HasPrefixIndex() { - return false - } - idxColsLen := len(path.Index.Columns) - for _, ran := range path.Ranges { - if len(ran.LowVal) != idxColsLen { - return false - } - } - } - tc := ds.SCtx().GetSessionVars().StmtCtx.TypeCtx() - for _, ran := range path.Ranges { - if !ran.IsPointNonNullable(tc) { - return false - } - } - return true -} - -// convertToTableScan converts the DataSource to table scan. -func (ds *DataSource) convertToTableScan(prop *property.PhysicalProperty, candidate *candidatePath, _ *optimizetrace.PhysicalOptimizeOp) (base.Task, error) { - // It will be handled in convertToIndexScan. - if prop.TaskTp == property.CopMultiReadTaskType { - return base.InvalidTask, nil - } - if !prop.IsSortItemEmpty() && !candidate.isMatchProp { - return base.InvalidTask, nil - } - // If we need to keep order for the index scan, we should forbid the non-keep-order index scan when we try to generate the path. - if prop.IsSortItemEmpty() && candidate.path.ForceKeepOrder { - return base.InvalidTask, nil - } - // If we don't need to keep order for the index scan, we should forbid the non-keep-order index scan when we try to generate the path. - if !prop.IsSortItemEmpty() && candidate.path.ForceNoKeepOrder { - return base.InvalidTask, nil - } - ts, _ := ds.getOriginalPhysicalTableScan(prop, candidate.path, candidate.isMatchProp) - if ts.KeepOrder && ts.StoreType == kv.TiFlash && (ts.Desc || ds.SCtx().GetSessionVars().TiFlashFastScan) { - // TiFlash fast mode(https://github.com/pingcap/tidb/pull/35851) does not keep order in TableScan - return base.InvalidTask, nil - } - if ts.StoreType == kv.TiFlash { - for _, col := range ts.Columns { - if col.IsVirtualGenerated() { - col.AddFlag(mysql.GeneratedColumnFlag) - } - } - } - // In disaggregated tiflash mode, only MPP is allowed, cop and batchCop is deprecated. - // So if prop.TaskTp is RootTaskType, have to use mppTask then convert to rootTask. 
- isTiFlashPath := ts.StoreType == kv.TiFlash - canMppConvertToRoot := prop.TaskTp == property.RootTaskType && ds.SCtx().GetSessionVars().IsMPPAllowed() && isTiFlashPath - canMppConvertToRootForDisaggregatedTiFlash := config.GetGlobalConfig().DisaggregatedTiFlash && canMppConvertToRoot - canMppConvertToRootForWhenTiFlashCopIsBanned := ds.SCtx().GetSessionVars().IsTiFlashCopBanned() && canMppConvertToRoot - if prop.TaskTp == property.MppTaskType || canMppConvertToRootForDisaggregatedTiFlash || canMppConvertToRootForWhenTiFlashCopIsBanned { - if ts.KeepOrder { - return base.InvalidTask, nil - } - if prop.MPPPartitionTp != property.AnyType { - return base.InvalidTask, nil - } - // ********************************** future deprecated start **************************/ - var hasVirtualColumn bool - for _, col := range ts.schema.Columns { - if col.VirtualExpr != nil { - ds.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because column `" + col.OrigName + "` is a virtual column which is not supported now.") - hasVirtualColumn = true - break - } - } - // in general, since MPP has supported the Gather operator to fill the virtual column, we should full lift restrictions here. - // we left them here, because cases like: - // parent-----+ - // V (when parent require a root task type here, we need convert mpp task to root task) - // projection [mpp task] [a] - // table-scan [mpp task] [a(virtual col as: b+1), b] - // in the process of converting mpp task to root task, the encapsulated table reader will use its first children schema [a] - // as its schema, so when we resolve indices later, the virtual column 'a' itself couldn't resolve itself anymore. - // - if hasVirtualColumn && !canMppConvertToRootForDisaggregatedTiFlash && !canMppConvertToRootForWhenTiFlashCopIsBanned { - return base.InvalidTask, nil - } - // ********************************** future deprecated end **************************/ - mppTask := &MppTask{ - p: ts, - partTp: property.AnyType, - tblColHists: ds.TblColHists, - } - ts.PlanPartInfo = &PhysPlanPartInfo{ - PruningConds: pushDownNot(ds.SCtx().GetExprCtx(), ds.AllConds), - PartitionNames: ds.PartitionNames, - Columns: ds.TblCols, - ColumnNames: ds.OutputNames(), - } - mppTask = ts.addPushedDownSelectionToMppTask(mppTask, ds.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt)) - var task base.Task = mppTask - if !mppTask.Invalid() { - if prop.TaskTp == property.MppTaskType && len(mppTask.rootTaskConds) > 0 { - // If got filters cannot be pushed down to tiflash, we have to make sure it will be executed in TiDB, - // So have to return a rootTask, but prop requires mppTask, cannot meet this requirement. - task = base.InvalidTask - } else if prop.TaskTp == property.RootTaskType { - // When got here, canMppConvertToRootX is true. - // This is for situations like cannot generate mppTask for some operators. - // Such as when the build side of HashJoin is Projection, - // which cannot pushdown to tiflash(because TiFlash doesn't support some expr in Proj) - // So HashJoin cannot pushdown to tiflash. But we still want TableScan to run on tiflash. - task = mppTask - task = task.ConvertToRootTask(ds.SCtx()) - } - } - return task, nil - } - if isTiFlashPath && config.GetGlobalConfig().DisaggregatedTiFlash || isTiFlashPath && ds.SCtx().GetSessionVars().IsTiFlashCopBanned() { - // prop.TaskTp is cop related, just return base.InvalidTask. 
- return base.InvalidTask, nil - } - copTask := &CopTask{ - tablePlan: ts, - indexPlanFinished: true, - tblColHists: ds.TblColHists, - } - copTask.physPlanPartInfo = &PhysPlanPartInfo{ - PruningConds: pushDownNot(ds.SCtx().GetExprCtx(), ds.AllConds), - PartitionNames: ds.PartitionNames, - Columns: ds.TblCols, - ColumnNames: ds.OutputNames(), - } - ts.PlanPartInfo = copTask.physPlanPartInfo - var task base.Task = copTask - if candidate.isMatchProp { - copTask.keepOrder = true - if ds.TableInfo.GetPartitionInfo() != nil { - // TableScan on partition table on TiFlash can't keep order. - if ts.StoreType == kv.TiFlash { - return base.InvalidTask, nil - } - // Add sort items for table scan for merge-sort operation between partitions. - byItems := make([]*util.ByItems, 0, len(prop.SortItems)) - for _, si := range prop.SortItems { - byItems = append(byItems, &util.ByItems{ - Expr: si.Col, - Desc: si.Desc, - }) - } - ts.ByItems = byItems - } - } - ts.addPushedDownSelection(copTask, ds.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt)) - if prop.IsFlashProp() && len(copTask.rootTaskConds) != 0 { - return base.InvalidTask, nil - } - if prop.TaskTp == property.RootTaskType { - task = task.ConvertToRootTask(ds.SCtx()) - } else if _, ok := task.(*RootTask); ok { - return base.InvalidTask, nil - } - return task, nil -} - -func (ds *DataSource) convertToSampleTable(prop *property.PhysicalProperty, - candidate *candidatePath, _ *optimizetrace.PhysicalOptimizeOp) (base.Task, error) { - if prop.TaskTp == property.CopMultiReadTaskType { - return base.InvalidTask, nil - } - if !prop.IsSortItemEmpty() && !candidate.isMatchProp { - return base.InvalidTask, nil - } - if candidate.isMatchProp { - // Disable keep order property for sample table path. - return base.InvalidTask, nil - } - p := PhysicalTableSample{ - TableSampleInfo: ds.SampleInfo, - TableInfo: ds.table, - PhysicalTableID: ds.PhysicalTableID, - Desc: candidate.isMatchProp && prop.SortItems[0].Desc, - }.Init(ds.SCtx(), ds.QueryBlockOffset()) - p.schema = ds.Schema() - rt := &RootTask{} - rt.SetPlan(p) - return rt, nil -} - -func (ds *DataSource) convertToPointGet(prop *property.PhysicalProperty, candidate *candidatePath) base.Task { - if !prop.IsSortItemEmpty() && !candidate.isMatchProp { - return base.InvalidTask - } - if prop.TaskTp == property.CopMultiReadTaskType && candidate.path.IsSingleScan || - prop.TaskTp == property.CopSingleReadTaskType && !candidate.path.IsSingleScan { - return base.InvalidTask - } - - if tidbutil.IsMemDB(ds.DBName.L) { - return base.InvalidTask - } - - accessCnt := math.Min(candidate.path.CountAfterAccess, float64(1)) - pointGetPlan := PointGetPlan{ - ctx: ds.SCtx(), - AccessConditions: candidate.path.AccessConds, - schema: ds.Schema().Clone(), - dbName: ds.DBName.L, - TblInfo: ds.TableInfo, - outputNames: ds.OutputNames(), - LockWaitTime: ds.SCtx().GetSessionVars().LockWaitTimeout, - Columns: ds.Columns, - }.Init(ds.SCtx(), ds.TableStats.ScaleByExpectCnt(accessCnt), ds.QueryBlockOffset()) - if ds.PartitionDefIdx != nil { - pointGetPlan.PartitionIdx = ds.PartitionDefIdx - } - pointGetPlan.PartitionNames = ds.PartitionNames - rTsk := &RootTask{} - rTsk.SetPlan(pointGetPlan) - if candidate.path.IsIntHandlePath { - pointGetPlan.Handle = kv.IntHandle(candidate.path.Ranges[0].LowVal[0].GetInt64()) - pointGetPlan.UnsignedHandle = mysql.HasUnsignedFlag(ds.HandleCols.GetCol(0).RetType.GetFlag()) - pointGetPlan.accessCols = ds.TblCols - found := false - for i := range ds.Columns { - if ds.Columns[i].ID == ds.HandleCols.GetCol(0).ID { 
- pointGetPlan.HandleColOffset = ds.Columns[i].Offset - found = true - break - } - } - if !found { - return base.InvalidTask - } - // Add filter condition to table plan now. - if len(candidate.path.TableFilters) > 0 { - sel := PhysicalSelection{ - Conditions: candidate.path.TableFilters, - }.Init(ds.SCtx(), ds.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), ds.QueryBlockOffset()) - sel.SetChildren(pointGetPlan) - rTsk.SetPlan(sel) - } - } else { - pointGetPlan.IndexInfo = candidate.path.Index - pointGetPlan.IdxCols = candidate.path.IdxCols - pointGetPlan.IdxColLens = candidate.path.IdxColLens - pointGetPlan.IndexValues = candidate.path.Ranges[0].LowVal - if candidate.path.IsSingleScan { - pointGetPlan.accessCols = candidate.path.IdxCols - } else { - pointGetPlan.accessCols = ds.TblCols - } - // Add index condition to table plan now. - if len(candidate.path.IndexFilters)+len(candidate.path.TableFilters) > 0 { - sel := PhysicalSelection{ - Conditions: append(candidate.path.IndexFilters, candidate.path.TableFilters...), - }.Init(ds.SCtx(), ds.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), ds.QueryBlockOffset()) - sel.SetChildren(pointGetPlan) - rTsk.SetPlan(sel) - } - } - - return rTsk -} - -func (ds *DataSource) convertToBatchPointGet(prop *property.PhysicalProperty, candidate *candidatePath) base.Task { - if !prop.IsSortItemEmpty() && !candidate.isMatchProp { - return base.InvalidTask - } - if prop.TaskTp == property.CopMultiReadTaskType && candidate.path.IsSingleScan || - prop.TaskTp == property.CopSingleReadTaskType && !candidate.path.IsSingleScan { - return base.InvalidTask - } - - accessCnt := math.Min(candidate.path.CountAfterAccess, float64(len(candidate.path.Ranges))) - batchPointGetPlan := &BatchPointGetPlan{ - ctx: ds.SCtx(), - dbName: ds.DBName.L, - AccessConditions: candidate.path.AccessConds, - TblInfo: ds.TableInfo, - KeepOrder: !prop.IsSortItemEmpty(), - Columns: ds.Columns, - PartitionNames: ds.PartitionNames, - } - if ds.PartitionDefIdx != nil { - batchPointGetPlan.SinglePartition = true - batchPointGetPlan.PartitionIdxs = []int{*ds.PartitionDefIdx} - } - if batchPointGetPlan.KeepOrder { - batchPointGetPlan.Desc = prop.SortItems[0].Desc - } - rTsk := &RootTask{} - if candidate.path.IsIntHandlePath { - for _, ran := range candidate.path.Ranges { - batchPointGetPlan.Handles = append(batchPointGetPlan.Handles, kv.IntHandle(ran.LowVal[0].GetInt64())) - } - batchPointGetPlan.accessCols = ds.TblCols - found := false - for i := range ds.Columns { - if ds.Columns[i].ID == ds.HandleCols.GetCol(0).ID { - batchPointGetPlan.HandleColOffset = ds.Columns[i].Offset - found = true - break - } - } - if !found { - return base.InvalidTask - } - - // Add filter condition to table plan now. 
- if len(candidate.path.TableFilters) > 0 { - batchPointGetPlan.Init(ds.SCtx(), ds.TableStats.ScaleByExpectCnt(accessCnt), ds.Schema().Clone(), ds.OutputNames(), ds.QueryBlockOffset()) - sel := PhysicalSelection{ - Conditions: candidate.path.TableFilters, - }.Init(ds.SCtx(), ds.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), ds.QueryBlockOffset()) - sel.SetChildren(batchPointGetPlan) - rTsk.SetPlan(sel) - } - } else { - batchPointGetPlan.IndexInfo = candidate.path.Index - batchPointGetPlan.IdxCols = candidate.path.IdxCols - batchPointGetPlan.IdxColLens = candidate.path.IdxColLens - for _, ran := range candidate.path.Ranges { - batchPointGetPlan.IndexValues = append(batchPointGetPlan.IndexValues, ran.LowVal) - } - if !prop.IsSortItemEmpty() { - batchPointGetPlan.KeepOrder = true - batchPointGetPlan.Desc = prop.SortItems[0].Desc - } - if candidate.path.IsSingleScan { - batchPointGetPlan.accessCols = candidate.path.IdxCols - } else { - batchPointGetPlan.accessCols = ds.TblCols - } - // Add index condition to table plan now. - if len(candidate.path.IndexFilters)+len(candidate.path.TableFilters) > 0 { - batchPointGetPlan.Init(ds.SCtx(), ds.TableStats.ScaleByExpectCnt(accessCnt), ds.Schema().Clone(), ds.OutputNames(), ds.QueryBlockOffset()) - sel := PhysicalSelection{ - Conditions: append(candidate.path.IndexFilters, candidate.path.TableFilters...), - }.Init(ds.SCtx(), ds.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), ds.QueryBlockOffset()) - sel.SetChildren(batchPointGetPlan) - rTsk.SetPlan(sel) - } - } - if rTsk.GetPlan() == nil { - tmpP := batchPointGetPlan.Init(ds.SCtx(), ds.TableStats.ScaleByExpectCnt(accessCnt), ds.Schema().Clone(), ds.OutputNames(), ds.QueryBlockOffset()) - rTsk.SetPlan(tmpP) - } - - return rTsk -} - -func (ts *PhysicalTableScan) addPushedDownSelectionToMppTask(mpp *MppTask, stats *property.StatsInfo) *MppTask { - filterCondition, rootTaskConds := SplitSelCondsWithVirtualColumn(ts.filterCondition) - var newRootConds []expression.Expression - filterCondition, newRootConds = expression.PushDownExprs(GetPushDownCtx(ts.SCtx()), filterCondition, ts.StoreType) - mpp.rootTaskConds = append(rootTaskConds, newRootConds...) - - ts.filterCondition = filterCondition - // Add filter condition to table plan now. - if len(ts.filterCondition) > 0 { - sel := PhysicalSelection{Conditions: ts.filterCondition}.Init(ts.SCtx(), stats, ts.QueryBlockOffset()) - sel.SetChildren(ts) - mpp.p = sel - } - return mpp -} - -func (ts *PhysicalTableScan) addPushedDownSelection(copTask *CopTask, stats *property.StatsInfo) { - ts.filterCondition, copTask.rootTaskConds = SplitSelCondsWithVirtualColumn(ts.filterCondition) - var newRootConds []expression.Expression - ts.filterCondition, newRootConds = expression.PushDownExprs(GetPushDownCtx(ts.SCtx()), ts.filterCondition, ts.StoreType) - copTask.rootTaskConds = append(copTask.rootTaskConds, newRootConds...) - - // Add filter condition to table plan now. 
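// A minimal sketch of the split performed above by SplitSelCondsWithVirtualColumn and
// expression.PushDownExprs: predicates the storage layer cannot evaluate (for example,
// ones on virtual columns) stay behind as root-task conditions, the rest are pushed
// below the scan. Types and the pushability check here are hypothetical, not the
// planner's real API.
package main

import "fmt"

type pred struct {
	text     string
	pushable bool // false for virtual-column or otherwise unsupported expressions
}

func splitConds(conds []pred) (pushed, root []string) {
	for _, c := range conds {
		if c.pushable {
			pushed = append(pushed, c.text)
		} else {
			root = append(root, c.text)
		}
	}
	return pushed, root
}

func main() {
	pushed, root := splitConds([]pred{
		{"a > 10", true},
		{"vcol = 3", false}, // virtual column: must be evaluated above the scan
	})
	fmt.Println(pushed, root) // [a > 10] [vcol = 3]
}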
- if len(ts.filterCondition) > 0 { - sel := PhysicalSelection{Conditions: ts.filterCondition}.Init(ts.SCtx(), stats, ts.QueryBlockOffset()) - if len(copTask.rootTaskConds) != 0 { - selectivity, _, err := cardinality.Selectivity(ts.SCtx(), copTask.tblColHists, ts.filterCondition, nil) - if err != nil { - logutil.BgLogger().Debug("calculate selectivity failed, use selection factor", zap.Error(err)) - selectivity = cost.SelectionFactor - } - sel.SetStats(ts.StatsInfo().Scale(selectivity)) - } - sel.SetChildren(ts) - copTask.tablePlan = sel - } -} - -func (ts *PhysicalTableScan) getScanRowSize() float64 { - if ts.StoreType == kv.TiKV { - return cardinality.GetTableAvgRowSize(ts.SCtx(), ts.tblColHists, ts.tblCols, ts.StoreType, true) - } - // If `ts.handleCol` is nil, then the schema of tableScan doesn't have handle column. - // This logic can be ensured in column pruning. - return cardinality.GetTableAvgRowSize(ts.SCtx(), ts.tblColHists, ts.Schema().Columns, ts.StoreType, ts.HandleCols != nil) -} - -func (ds *DataSource) getOriginalPhysicalTableScan(prop *property.PhysicalProperty, path *util.AccessPath, isMatchProp bool) (*PhysicalTableScan, float64) { - ts := PhysicalTableScan{ - Table: ds.TableInfo, - Columns: slices.Clone(ds.Columns), - TableAsName: ds.TableAsName, - DBName: ds.DBName, - isPartition: ds.PartitionDefIdx != nil, - physicalTableID: ds.PhysicalTableID, - Ranges: path.Ranges, - AccessCondition: path.AccessConds, - StoreType: path.StoreType, - HandleCols: ds.HandleCols, - tblCols: ds.TblCols, - tblColHists: ds.TblColHists, - constColsByCond: path.ConstCols, - prop: prop, - filterCondition: slices.Clone(path.TableFilters), - }.Init(ds.SCtx(), ds.QueryBlockOffset()) - ts.SetSchema(ds.Schema().Clone()) - rowCount := path.CountAfterAccess - if prop.ExpectedCnt < ds.StatsInfo().RowCount { - rowCount = cardinality.AdjustRowCountForTableScanByLimit(ds.SCtx(), - ds.StatsInfo(), ds.TableStats, ds.StatisticTable, - path, prop.ExpectedCnt, isMatchProp && prop.SortItems[0].Desc) - } - // We need NDV of columns since it may be used in cost estimation of join. Precisely speaking, - // we should track NDV of each histogram bucket, and sum up the NDV of buckets we actually need - // to scan, but this would only help improve accuracy of NDV for one column, for other columns, - // we still need to assume values are uniformly distributed. For simplicity, we use uniform-assumption - // for all columns now, as we do in `deriveStatsByFilter`. 
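// For illustration only: the selectivity fallback above scales the pushed-down
// selection's row estimate by the computed selectivity, and substitutes a fixed
// factor when selectivity estimation fails. The 0.8 constant is an assumption
// standing in for cost.SelectionFactor; this is not the planner's implementation.
package main

import "fmt"

const assumedSelectionFactor = 0.8

func scaleSelectionRows(childRows, selectivity float64, estimationFailed bool) float64 {
	if estimationFailed {
		selectivity = assumedSelectionFactor
	}
	return childRows * selectivity
}

func main() {
	fmt.Println(scaleSelectionRows(1000, 0.05, false)) // 50
	fmt.Println(scaleSelectionRows(1000, 0, true))     // 800
}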
- ts.SetStats(ds.TableStats.ScaleByExpectCnt(rowCount)) - usedStats := ds.SCtx().GetSessionVars().StmtCtx.GetUsedStatsInfo(false) - if usedStats != nil && usedStats.GetUsedInfo(ts.physicalTableID) != nil { - ts.usedStatsInfo = usedStats.GetUsedInfo(ts.physicalTableID) - } - if isMatchProp { - ts.Desc = prop.SortItems[0].Desc - ts.KeepOrder = true - } - return ts, rowCount -} - -func (ds *DataSource) getOriginalPhysicalIndexScan(prop *property.PhysicalProperty, path *util.AccessPath, isMatchProp bool, isSingleScan bool) *PhysicalIndexScan { - idx := path.Index - is := PhysicalIndexScan{ - Table: ds.TableInfo, - TableAsName: ds.TableAsName, - DBName: ds.DBName, - Columns: util.CloneColInfos(ds.Columns), - Index: idx, - IdxCols: path.IdxCols, - IdxColLens: path.IdxColLens, - AccessCondition: path.AccessConds, - Ranges: path.Ranges, - dataSourceSchema: ds.Schema(), - isPartition: ds.PartitionDefIdx != nil, - physicalTableID: ds.PhysicalTableID, - tblColHists: ds.TblColHists, - pkIsHandleCol: ds.getPKIsHandleCol(), - constColsByCond: path.ConstCols, - prop: prop, - }.Init(ds.SCtx(), ds.QueryBlockOffset()) - rowCount := path.CountAfterAccess - is.initSchema(append(path.FullIdxCols, ds.CommonHandleCols...), !isSingleScan) - - // If (1) there exists an index whose selectivity is smaller than the threshold, - // and (2) there is Selection on the IndexScan, we don't use the ExpectedCnt to - // adjust the estimated row count of the IndexScan. - ignoreExpectedCnt := ds.AccessPathMinSelectivity < ds.SCtx().GetSessionVars().OptOrderingIdxSelThresh && - len(path.IndexFilters)+len(path.TableFilters) > 0 - - if (isMatchProp || prop.IsSortItemEmpty()) && prop.ExpectedCnt < ds.StatsInfo().RowCount && !ignoreExpectedCnt { - rowCount = cardinality.AdjustRowCountForIndexScanByLimit(ds.SCtx(), - ds.StatsInfo(), ds.TableStats, ds.StatisticTable, - path, prop.ExpectedCnt, isMatchProp && prop.SortItems[0].Desc) - } - // ScaleByExpectCnt only allows to scale the row count smaller than the table total row count. - // But for MV index, it's possible that the IndexRangeScan row count is larger than the table total row count. - // Please see the Case 2 in CalcTotalSelectivityForMVIdxPath for an example. - if idx.MVIndex && rowCount > ds.TableStats.RowCount { - is.SetStats(ds.TableStats.Scale(rowCount / ds.TableStats.RowCount)) - } else { - is.SetStats(ds.TableStats.ScaleByExpectCnt(rowCount)) - } - usedStats := ds.SCtx().GetSessionVars().StmtCtx.GetUsedStatsInfo(false) - if usedStats != nil && usedStats.GetUsedInfo(is.physicalTableID) != nil { - is.usedStatsInfo = usedStats.GetUsedInfo(is.physicalTableID) - } - if isMatchProp { - is.Desc = prop.SortItems[0].Desc - is.KeepOrder = true - } - return is -} - -func findBestTask4LogicalCTE(p *LogicalCTE, prop *property.PhysicalProperty, counter *base.PlanCounterTp, pop *optimizetrace.PhysicalOptimizeOp) (t base.Task, cntPlan int64, err error) { - if p.ChildLen() > 0 { - return p.BaseLogicalPlan.FindBestTask(prop, counter, pop) - } - if !prop.IsSortItemEmpty() && !prop.CanAddEnforcer { - return base.InvalidTask, 1, nil - } - // The physical plan has been build when derive stats. 
- pcte := PhysicalCTE{SeedPlan: p.Cte.seedPartPhysicalPlan, RecurPlan: p.Cte.recursivePartPhysicalPlan, CTE: p.Cte, cteAsName: p.CteAsName, cteName: p.CteName}.Init(p.SCtx(), p.StatsInfo()) - pcte.SetSchema(p.Schema()) - if prop.IsFlashProp() && prop.CTEProducerStatus == property.AllCTECanMpp { - pcte.readerReceiver = PhysicalExchangeReceiver{IsCTEReader: true}.Init(p.SCtx(), p.StatsInfo()) - if prop.MPPPartitionTp != property.AnyType { - return base.InvalidTask, 1, nil - } - t = &MppTask{ - p: pcte, - partTp: prop.MPPPartitionTp, - hashCols: prop.MPPPartitionCols, - tblColHists: p.StatsInfo().HistColl, - } - } else { - rt := &RootTask{} - rt.SetPlan(pcte) - rt.SetEmpty(false) - t = rt - } - if prop.CanAddEnforcer { - t = enforceProperty(prop, t, p.Plan.SCtx()) - } - return t, 1, nil -} - -func findBestTask4LogicalCTETable(lp base.LogicalPlan, prop *property.PhysicalProperty, _ *base.PlanCounterTp, _ *optimizetrace.PhysicalOptimizeOp) (t base.Task, cntPlan int64, err error) { - p := lp.(*logicalop.LogicalCTETable) - if !prop.IsSortItemEmpty() { - return base.InvalidTask, 0, nil - } - - pcteTable := PhysicalCTETable{IDForStorage: p.IDForStorage}.Init(p.SCtx(), p.StatsInfo()) - pcteTable.SetSchema(p.Schema()) - rt := &RootTask{} - rt.SetPlan(pcteTable) - t = rt - return t, 1, nil -} - -func appendCandidate(lp base.LogicalPlan, task base.Task, prop *property.PhysicalProperty, opt *optimizetrace.PhysicalOptimizeOp) { - if task == nil || task.Invalid() { - return - } - utilfuncp.AppendCandidate4PhysicalOptimizeOp(opt, lp, task.Plan(), prop) -} - -// PushDownNot here can convert condition 'not (a != 1)' to 'a = 1'. When we build range from conds, the condition like -// 'not (a != 1)' would not be handled so we need to convert it to 'a = 1', which can be handled when building range. -func pushDownNot(ctx expression.BuildContext, conds []expression.Expression) []expression.Expression { - for i, cond := range conds { - conds[i] = expression.PushDownNot(ctx, cond) - } - return conds -} - -func validateTableSamplePlan(ds *DataSource, t base.Task, err error) error { - if err != nil { - return err - } - if ds.SampleInfo != nil && !t.Invalid() { - if _, ok := t.Plan().(*PhysicalTableSample); !ok { - return expression.ErrInvalidTableSample.GenWithStackByArgs("plan not supported") - } - } - return nil -} diff --git a/pkg/planner/core/logical_join.go b/pkg/planner/core/logical_join.go index 2e37caf53bf0c..99dd7b4b97aa5 100644 --- a/pkg/planner/core/logical_join.go +++ b/pkg/planner/core/logical_join.go @@ -600,12 +600,12 @@ func (p *LogicalJoin) PreparePossibleProperties(_ *expression.Schema, childrenPr // If the hint is not matched, it will get other candidates. // If the hint is not figured, we will pick all candidates. 
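// The change just below restores the failpoint.Inject marker form of this hook.
// For reference, a test could activate the failpoint roughly as sketched here; the
// registration path is an assumption derived from the package import path, and the
// test body is only a placeholder.
package core_test

import (
	"testing"

	"github.com/pingcap/failpoint"
	"github.com/stretchr/testify/require"
)

func TestForceIndexHashJoin(t *testing.T) {
	fp := "github.com/pingcap/tidb/pkg/planner/core/MockOnlyEnableIndexHashJoin"
	require.NoError(t, failpoint.Enable(fp, `return(true)`))
	defer func() { require.NoError(t, failpoint.Disable(fp)) }()
	// ... build and explain a join query here; only index hash join candidates
	// would be generated while the failpoint is enabled.
}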
func (p *LogicalJoin) ExhaustPhysicalPlans(prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { - if val, _err_ := failpoint.Eval(_curpkg_("MockOnlyEnableIndexHashJoin")); _err_ == nil { + failpoint.Inject("MockOnlyEnableIndexHashJoin", func(val failpoint.Value) { if val.(bool) && !p.SCtx().GetSessionVars().InRestrictedSQL { indexJoins, _ := tryToGetIndexJoin(p, prop) - return indexJoins, true, nil + failpoint.Return(indexJoins, true, nil) } - } + }) if !isJoinHintSupportedInMPPMode(p.PreferJoinType) { if hasMPPJoinHints(p.PreferJoinType) { diff --git a/pkg/planner/core/logical_join.go__failpoint_stash__ b/pkg/planner/core/logical_join.go__failpoint_stash__ deleted file mode 100644 index 99dd7b4b97aa5..0000000000000 --- a/pkg/planner/core/logical_join.go__failpoint_stash__ +++ /dev/null @@ -1,1672 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package core - -import ( - "bytes" - "fmt" - "math" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/planner/cardinality" - "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/planner/core/cost" - "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" - ruleutil "github.com/pingcap/tidb/pkg/planner/core/rule/util" - "github.com/pingcap/tidb/pkg/planner/funcdep" - "github.com/pingcap/tidb/pkg/planner/property" - "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/planner/util/optimizetrace" - "github.com/pingcap/tidb/pkg/planner/util/utilfuncp" - "github.com/pingcap/tidb/pkg/types" - utilhint "github.com/pingcap/tidb/pkg/util/hint" - "github.com/pingcap/tidb/pkg/util/intset" - "github.com/pingcap/tidb/pkg/util/plancodec" -) - -// JoinType contains CrossJoin, InnerJoin, LeftOuterJoin, RightOuterJoin, SemiJoin, AntiJoin. -type JoinType int - -const ( - // InnerJoin means inner join. - InnerJoin JoinType = iota - // LeftOuterJoin means left join. - LeftOuterJoin - // RightOuterJoin means right join. - RightOuterJoin - // SemiJoin means if row a in table A matches some rows in B, just output a. - SemiJoin - // AntiSemiJoin means if row a in table A does not match any row in B, then output a. - AntiSemiJoin - // LeftOuterSemiJoin means if row a in table A matches some rows in B, output (a, true), otherwise, output (a, false). - LeftOuterSemiJoin - // AntiLeftOuterSemiJoin means if row a in table A matches some rows in B, output (a, false), otherwise, output (a, true). 
- AntiLeftOuterSemiJoin -) - -// IsOuterJoin returns if this joiner is an outer joiner -func (tp JoinType) IsOuterJoin() bool { - return tp == LeftOuterJoin || tp == RightOuterJoin || - tp == LeftOuterSemiJoin || tp == AntiLeftOuterSemiJoin -} - -// IsSemiJoin returns if this joiner is a semi/anti-semi joiner -func (tp JoinType) IsSemiJoin() bool { - return tp == SemiJoin || tp == AntiSemiJoin || - tp == LeftOuterSemiJoin || tp == AntiLeftOuterSemiJoin -} - -func (tp JoinType) String() string { - switch tp { - case InnerJoin: - return "inner join" - case LeftOuterJoin: - return "left outer join" - case RightOuterJoin: - return "right outer join" - case SemiJoin: - return "semi join" - case AntiSemiJoin: - return "anti semi join" - case LeftOuterSemiJoin: - return "left outer semi join" - case AntiLeftOuterSemiJoin: - return "anti left outer semi join" - } - return "unsupported join type" -} - -// LogicalJoin is the logical join plan. -type LogicalJoin struct { - logicalop.LogicalSchemaProducer - - JoinType JoinType - Reordered bool - CartesianJoin bool - StraightJoin bool - - // HintInfo stores the join algorithm hint information specified by client. - HintInfo *utilhint.PlanHints - PreferJoinType uint - PreferJoinOrder bool - LeftPreferJoinType uint - RightPreferJoinType uint - - EqualConditions []*expression.ScalarFunction - // NAEQConditions means null aware equal conditions, which is used for null aware semi joins. - NAEQConditions []*expression.ScalarFunction - LeftConditions expression.CNFExprs - RightConditions expression.CNFExprs - OtherConditions expression.CNFExprs - - LeftProperties [][]*expression.Column - RightProperties [][]*expression.Column - - // DefaultValues is only used for left/right outer join, which is values the inner row's should be when the outer table - // doesn't match any inner table's row. - // That it's nil just means the default values is a slice of NULL. - // Currently, only `aggregation push down` phase will set this. - DefaultValues []types.Datum - - // FullSchema contains all the columns that the Join can output. It's ordered as [outer schema..., inner schema...]. - // This is useful for natural joins and "using" joins. In these cases, the join key columns from the - // inner side (or the right side when it's an inner join) will not be in the schema of Join. - // But upper operators should be able to find those "redundant" columns, and the user also can specifically select - // those columns, so we put the "redundant" columns here to make them be able to be found. - // - // For example: - // create table t1(a int, b int); create table t2(a int, b int); - // select * from t1 join t2 using (b); - // schema of the Join will be [t1.b, t1.a, t2.a]; FullSchema will be [t1.a, t1.b, t2.a, t2.b]. - // - // We record all columns and keep them ordered is for correctly handling SQLs like - // select t1.*, t2.* from t1 join t2 using (b); - // (*PlanBuilder).unfoldWildStar() handles the schema for such case. - FullSchema *expression.Schema - FullNames types.NameSlice - - // EqualCondOutCnt indicates the estimated count of joined rows after evaluating `EqualConditions`. - EqualCondOutCnt float64 -} - -// Init initializes LogicalJoin. -func (p LogicalJoin) Init(ctx base.PlanContext, offset int) *LogicalJoin { - p.BaseLogicalPlan = logicalop.NewBaseLogicalPlan(ctx, plancodec.TypeJoin, &p, offset) - return &p -} - -// *************************** start implementation of Plan interface *************************** - -// ExplainInfo implements Plan interface. 
-func (p *LogicalJoin) ExplainInfo() string { - evalCtx := p.SCtx().GetExprCtx().GetEvalCtx() - buffer := bytes.NewBufferString(p.JoinType.String()) - if len(p.EqualConditions) > 0 { - fmt.Fprintf(buffer, ", equal:%v", p.EqualConditions) - } - if len(p.LeftConditions) > 0 { - fmt.Fprintf(buffer, ", left cond:%s", - expression.SortedExplainExpressionList(evalCtx, p.LeftConditions)) - } - if len(p.RightConditions) > 0 { - fmt.Fprintf(buffer, ", right cond:%s", - expression.SortedExplainExpressionList(evalCtx, p.RightConditions)) - } - if len(p.OtherConditions) > 0 { - fmt.Fprintf(buffer, ", other cond:%s", - expression.SortedExplainExpressionList(evalCtx, p.OtherConditions)) - } - return buffer.String() -} - -// ReplaceExprColumns implements base.LogicalPlan interface. -func (p *LogicalJoin) ReplaceExprColumns(replace map[string]*expression.Column) { - for _, equalExpr := range p.EqualConditions { - ruleutil.ResolveExprAndReplace(equalExpr, replace) - } - for _, leftExpr := range p.LeftConditions { - ruleutil.ResolveExprAndReplace(leftExpr, replace) - } - for _, rightExpr := range p.RightConditions { - ruleutil.ResolveExprAndReplace(rightExpr, replace) - } - for _, otherExpr := range p.OtherConditions { - ruleutil.ResolveExprAndReplace(otherExpr, replace) - } -} - -// *************************** end implementation of Plan interface *************************** - -// *************************** start implementation of logicalPlan interface *************************** - -// HashCode inherits the BaseLogicalPlan.LogicalPlan.<0th> implementation. - -// PredicatePushDown implements the base.LogicalPlan.<1st> interface. -func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression, opt *optimizetrace.LogicalOptimizeOp) (ret []expression.Expression, retPlan base.LogicalPlan) { - var equalCond []*expression.ScalarFunction - var leftPushCond, rightPushCond, otherCond, leftCond, rightCond []expression.Expression - switch p.JoinType { - case LeftOuterJoin, LeftOuterSemiJoin, AntiLeftOuterSemiJoin: - predicates = p.outerJoinPropConst(predicates) - dual := Conds2TableDual(p, predicates) - if dual != nil { - appendTableDualTraceStep(p, dual, predicates, opt) - return ret, dual - } - // Handle where conditions - predicates = expression.ExtractFiltersFromDNFs(p.SCtx().GetExprCtx(), predicates) - // Only derive left where condition, because right where condition cannot be pushed down - equalCond, leftPushCond, rightPushCond, otherCond = p.extractOnCondition(predicates, true, false) - leftCond = leftPushCond - // Handle join conditions, only derive right join condition, because left join condition cannot be pushed down - _, derivedRightJoinCond := DeriveOtherConditions( - p, p.Children()[0].Schema(), p.Children()[1].Schema(), false, true) - rightCond = append(p.RightConditions, derivedRightJoinCond...) - p.RightConditions = nil - ret = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...) - ret = append(ret, rightPushCond...) 
- case RightOuterJoin: - predicates = p.outerJoinPropConst(predicates) - dual := Conds2TableDual(p, predicates) - if dual != nil { - appendTableDualTraceStep(p, dual, predicates, opt) - return ret, dual - } - // Handle where conditions - predicates = expression.ExtractFiltersFromDNFs(p.SCtx().GetExprCtx(), predicates) - // Only derive right where condition, because left where condition cannot be pushed down - equalCond, leftPushCond, rightPushCond, otherCond = p.extractOnCondition(predicates, false, true) - rightCond = rightPushCond - // Handle join conditions, only derive left join condition, because right join condition cannot be pushed down - derivedLeftJoinCond, _ := DeriveOtherConditions( - p, p.Children()[0].Schema(), p.Children()[1].Schema(), true, false) - leftCond = append(p.LeftConditions, derivedLeftJoinCond...) - p.LeftConditions = nil - ret = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...) - ret = append(ret, leftPushCond...) - case SemiJoin, InnerJoin: - tempCond := make([]expression.Expression, 0, len(p.LeftConditions)+len(p.RightConditions)+len(p.EqualConditions)+len(p.OtherConditions)+len(predicates)) - tempCond = append(tempCond, p.LeftConditions...) - tempCond = append(tempCond, p.RightConditions...) - tempCond = append(tempCond, expression.ScalarFuncs2Exprs(p.EqualConditions)...) - tempCond = append(tempCond, p.OtherConditions...) - tempCond = append(tempCond, predicates...) - tempCond = expression.ExtractFiltersFromDNFs(p.SCtx().GetExprCtx(), tempCond) - tempCond = expression.PropagateConstant(p.SCtx().GetExprCtx(), tempCond) - // Return table dual when filter is constant false or null. - dual := Conds2TableDual(p, tempCond) - if dual != nil { - appendTableDualTraceStep(p, dual, tempCond, opt) - return ret, dual - } - equalCond, leftPushCond, rightPushCond, otherCond = p.extractOnCondition(tempCond, true, true) - p.LeftConditions = nil - p.RightConditions = nil - p.EqualConditions = equalCond - p.OtherConditions = otherCond - leftCond = leftPushCond - rightCond = rightPushCond - case AntiSemiJoin: - predicates = expression.PropagateConstant(p.SCtx().GetExprCtx(), predicates) - // Return table dual when filter is constant false or null. - dual := Conds2TableDual(p, predicates) - if dual != nil { - appendTableDualTraceStep(p, dual, predicates, opt) - return ret, dual - } - // `predicates` should only contain left conditions or constant filters. - _, leftPushCond, rightPushCond, _ = p.extractOnCondition(predicates, true, true) - // Do not derive `is not null` for anti join, since it may cause wrong results. - // For example: - // `select * from t t1 where t1.a not in (select b from t t2)` does not imply `t2.b is not null`, - // `select * from t t1 where t1.a not in (select a from t t2 where t1.b = t2.b` does not imply `t1.b is not null`, - // `select * from t t1 where not exists (select * from t t2 where t2.a = t1.a)` does not imply `t1.a is not null`, - leftCond = leftPushCond - rightCond = append(p.RightConditions, rightPushCond...) - p.RightConditions = nil - } - leftCond = expression.RemoveDupExprs(leftCond) - rightCond = expression.RemoveDupExprs(rightCond) - leftRet, lCh := p.Children()[0].PredicatePushDown(leftCond, opt) - rightRet, rCh := p.Children()[1].PredicatePushDown(rightCond, opt) - utilfuncp.AddSelection(p, lCh, leftRet, 0, opt) - utilfuncp.AddSelection(p, rCh, rightRet, 1, opt) - p.updateEQCond() - ruleutil.BuildKeyInfoPortal(p) - return ret, p.Self() -} - -// PruneColumns implements the base.LogicalPlan.<2nd> interface. 
-func (p *LogicalJoin) PruneColumns(parentUsedCols []*expression.Column, opt *optimizetrace.LogicalOptimizeOp) (base.LogicalPlan, error) { - leftCols, rightCols := p.extractUsedCols(parentUsedCols) - - var err error - p.Children()[0], err = p.Children()[0].PruneColumns(leftCols, opt) - if err != nil { - return nil, err - } - - p.Children()[1], err = p.Children()[1].PruneColumns(rightCols, opt) - if err != nil { - return nil, err - } - - p.mergeSchema() - if p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin { - joinCol := p.Schema().Columns[len(p.Schema().Columns)-1] - parentUsedCols = append(parentUsedCols, joinCol) - } - p.InlineProjection(parentUsedCols, opt) - return p, nil -} - -// FindBestTask inherits the BaseLogicalPlan.LogicalPlan.<3rd> implementation. - -// BuildKeyInfo implements the base.LogicalPlan.<4th> interface. -func (p *LogicalJoin) BuildKeyInfo(selfSchema *expression.Schema, childSchema []*expression.Schema) { - p.LogicalSchemaProducer.BuildKeyInfo(selfSchema, childSchema) - switch p.JoinType { - case SemiJoin, LeftOuterSemiJoin, AntiSemiJoin, AntiLeftOuterSemiJoin: - selfSchema.Keys = childSchema[0].Clone().Keys - case InnerJoin, LeftOuterJoin, RightOuterJoin: - // If there is no equal conditions, then cartesian product can't be prevented and unique key information will destroy. - if len(p.EqualConditions) == 0 { - return - } - lOk := false - rOk := false - // Such as 'select * from t1 join t2 where t1.a = t2.a and t1.b = t2.b'. - // If one sides (a, b) is a unique key, then the unique key information is remained. - // But we don't consider this situation currently. - // Only key made by one column is considered now. - evalCtx := p.SCtx().GetExprCtx().GetEvalCtx() - for _, expr := range p.EqualConditions { - ln := expr.GetArgs()[0].(*expression.Column) - rn := expr.GetArgs()[1].(*expression.Column) - for _, key := range childSchema[0].Keys { - if len(key) == 1 && key[0].Equal(evalCtx, ln) { - lOk = true - break - } - } - for _, key := range childSchema[1].Keys { - if len(key) == 1 && key[0].Equal(evalCtx, rn) { - rOk = true - break - } - } - } - // For inner join, if one side of one equal condition is unique key, - // another side's unique key information will all be reserved. - // If it's an outer join, NULL value will fill some position, which will destroy the unique key information. - if lOk && p.JoinType != LeftOuterJoin { - selfSchema.Keys = append(selfSchema.Keys, childSchema[1].Keys...) - } - if rOk && p.JoinType != RightOuterJoin { - selfSchema.Keys = append(selfSchema.Keys, childSchema[0].Keys...) - } - } -} - -// PushDownTopN implements the base.LogicalPlan.<5th> interface. -func (p *LogicalJoin) PushDownTopN(topNLogicalPlan base.LogicalPlan, opt *optimizetrace.LogicalOptimizeOp) base.LogicalPlan { - var topN *logicalop.LogicalTopN - if topNLogicalPlan != nil { - topN = topNLogicalPlan.(*logicalop.LogicalTopN) - } - switch p.JoinType { - case LeftOuterJoin, LeftOuterSemiJoin, AntiLeftOuterSemiJoin: - p.Children()[0] = p.pushDownTopNToChild(topN, 0, opt) - p.Children()[1] = p.Children()[1].PushDownTopN(nil, opt) - case RightOuterJoin: - p.Children()[1] = p.pushDownTopNToChild(topN, 1, opt) - p.Children()[0] = p.Children()[0].PushDownTopN(nil, opt) - default: - return p.BaseLogicalPlan.PushDownTopN(topN, opt) - } - - // The LogicalJoin may be also a LogicalApply. So we must use self to set parents. 
- if topN != nil { - return topN.AttachChild(p.Self(), opt) - } - return p.Self() -} - -// DeriveTopN inherits the BaseLogicalPlan.LogicalPlan.<6th> implementation. - -// PredicateSimplification inherits the BaseLogicalPlan.LogicalPlan.<7th> implementation. - -// ConstantPropagation implements the base.LogicalPlan.<8th> interface. -// about the logic of constant propagation in From List. -// Query: select * from t, (select a, b from s where s.a>1) tmp where tmp.a=t.a -// Origin logical plan: -/* - +----------------+ - | LogicalJoin | - +-------^--------+ - | - +-------------+--------------+ - | | -+-----+------+ +------+------+ -| Projection | | TableScan | -+-----^------+ +-------------+ - | - | -+-----+------+ -| Selection | -| s.a>1 | -+------------+ -*/ -// 1. 'PullUpConstantPredicates': Call this function until find selection and pull up the constant predicate layer by layer -// LogicalSelection: find the s.a>1 -// LogicalProjection: get the s.a>1 and pull up it, changed to tmp.a>1 -// 2. 'addCandidateSelection': Add selection above of LogicalJoin, -// put all predicates pulled up from the lower layer into the current new selection. -// LogicalSelection: tmp.a >1 -// -// Optimized plan: -/* - +----------------+ - | Selection | - | tmp.a>1 | - +-------^--------+ - | - +-------+--------+ - | LogicalJoin | - +-------^--------+ - | - +-------------+--------------+ - | | -+-----+------+ +------+------+ -| Projection | | TableScan | -+-----^------+ +-------------+ - | - | -+-----+------+ -| Selection | -| s.a>1 | -+------------+ -*/ -// Return nil if the root of plan has not been changed -// Return new root if the root of plan is changed to selection -func (p *LogicalJoin) ConstantPropagation(parentPlan base.LogicalPlan, currentChildIdx int, opt *optimizetrace.LogicalOptimizeOp) (newRoot base.LogicalPlan) { - // step1: get constant predicate from left or right according to the JoinType - var getConstantPredicateFromLeft bool - var getConstantPredicateFromRight bool - switch p.JoinType { - case LeftOuterJoin: - getConstantPredicateFromLeft = true - case RightOuterJoin: - getConstantPredicateFromRight = true - case InnerJoin: - getConstantPredicateFromLeft = true - getConstantPredicateFromRight = true - default: - return - } - var candidateConstantPredicates []expression.Expression - if getConstantPredicateFromLeft { - candidateConstantPredicates = p.Children()[0].PullUpConstantPredicates() - } - if getConstantPredicateFromRight { - candidateConstantPredicates = append(candidateConstantPredicates, p.Children()[1].PullUpConstantPredicates()...) - } - if len(candidateConstantPredicates) == 0 { - return - } - - // step2: add selection above of LogicalJoin - return addCandidateSelection(p, currentChildIdx, parentPlan, candidateConstantPredicates, opt) -} - -// PullUpConstantPredicates inherits the BaseLogicalPlan.LogicalPlan.<9th> implementation. - -// RecursiveDeriveStats inherits the BaseLogicalPlan.LogicalPlan.<10th> implementation. - -// DeriveStats implements the base.LogicalPlan.<11th> interface. -// If the type of join is SemiJoin, the selectivity of it will be same as selection's. -// If the type of join is LeftOuterSemiJoin, it will not add or remove any row. The last column is a boolean value, whose NDV should be two. -// If the type of join is inner/outer join, the output of join(s, t) should be N(s) * N(t) / (V(s.key) * V(t.key)) * Min(s.key, t.key). -// N(s) stands for the number of rows in relation s. V(s.key) means the NDV of join key in s. 
-// This is a quite simple strategy: We assume every bucket of relation which will participate join has the same number of rows, and apply cross join for -// every matched bucket. -func (p *LogicalJoin) DeriveStats(childStats []*property.StatsInfo, selfSchema *expression.Schema, childSchema []*expression.Schema, colGroups [][]*expression.Column) (*property.StatsInfo, error) { - if p.StatsInfo() != nil { - // Reload GroupNDVs since colGroups may have changed. - p.StatsInfo().GroupNDVs = p.getGroupNDVs(colGroups, childStats) - return p.StatsInfo(), nil - } - leftProfile, rightProfile := childStats[0], childStats[1] - leftJoinKeys, rightJoinKeys, _, _ := p.GetJoinKeys() - p.EqualCondOutCnt = cardinality.EstimateFullJoinRowCount(p.SCtx(), - 0 == len(p.EqualConditions), - leftProfile, rightProfile, - leftJoinKeys, rightJoinKeys, - childSchema[0], childSchema[1], - nil, nil) - if p.JoinType == SemiJoin || p.JoinType == AntiSemiJoin { - p.SetStats(&property.StatsInfo{ - RowCount: leftProfile.RowCount * cost.SelectionFactor, - ColNDVs: make(map[int64]float64, len(leftProfile.ColNDVs)), - }) - for id, c := range leftProfile.ColNDVs { - p.StatsInfo().ColNDVs[id] = c * cost.SelectionFactor - } - return p.StatsInfo(), nil - } - if p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin { - p.SetStats(&property.StatsInfo{ - RowCount: leftProfile.RowCount, - ColNDVs: make(map[int64]float64, selfSchema.Len()), - }) - for id, c := range leftProfile.ColNDVs { - p.StatsInfo().ColNDVs[id] = c - } - p.StatsInfo().ColNDVs[selfSchema.Columns[selfSchema.Len()-1].UniqueID] = 2.0 - p.StatsInfo().GroupNDVs = p.getGroupNDVs(colGroups, childStats) - return p.StatsInfo(), nil - } - count := p.EqualCondOutCnt - if p.JoinType == LeftOuterJoin { - count = math.Max(count, leftProfile.RowCount) - } else if p.JoinType == RightOuterJoin { - count = math.Max(count, rightProfile.RowCount) - } - colNDVs := make(map[int64]float64, selfSchema.Len()) - for id, c := range leftProfile.ColNDVs { - colNDVs[id] = math.Min(c, count) - } - for id, c := range rightProfile.ColNDVs { - colNDVs[id] = math.Min(c, count) - } - p.SetStats(&property.StatsInfo{ - RowCount: count, - ColNDVs: colNDVs, - }) - p.StatsInfo().GroupNDVs = p.getGroupNDVs(colGroups, childStats) - return p.StatsInfo(), nil -} - -// ExtractColGroups implements the base.LogicalPlan.<12th> interface. -func (p *LogicalJoin) ExtractColGroups(colGroups [][]*expression.Column) [][]*expression.Column { - leftJoinKeys, rightJoinKeys, _, _ := p.GetJoinKeys() - extracted := make([][]*expression.Column, 0, 2+len(colGroups)) - if len(leftJoinKeys) > 1 && (p.JoinType == InnerJoin || p.JoinType == LeftOuterJoin || p.JoinType == RightOuterJoin) { - extracted = append(extracted, expression.SortColumns(leftJoinKeys), expression.SortColumns(rightJoinKeys)) - } - var outerSchema *expression.Schema - if p.JoinType == LeftOuterJoin || p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin { - outerSchema = p.Children()[0].Schema() - } else if p.JoinType == RightOuterJoin { - outerSchema = p.Children()[1].Schema() - } - if len(colGroups) == 0 || outerSchema == nil { - return extracted - } - _, offsets := outerSchema.ExtractColGroups(colGroups) - if len(offsets) == 0 { - return extracted - } - for _, offset := range offsets { - extracted = append(extracted, colGroups[offset]) - } - return extracted -} - -// PreparePossibleProperties implements base.LogicalPlan.<13th> interface. 
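// A worked sketch of the inner-join row count formula described above:
// rows ≈ N(s) * N(t) / (V(s.key) * V(t.key)) * min(V(s.key), V(t.key)),
// which simplifies to N(s) * N(t) / max(V(s.key), V(t.key)).
// This is a simplified stand-in, not cardinality.EstimateFullJoinRowCount.
package main

import (
	"fmt"
	"math"
)

func estimateInnerJoinRows(leftRows, rightRows, leftNDV, rightNDV float64) float64 {
	if leftNDV <= 0 || rightNDV <= 0 {
		return leftRows * rightRows // degenerate to a cross-join estimate when NDV is unknown
	}
	return leftRows * rightRows / math.Max(leftNDV, rightNDV)
}

func main() {
	// 10k rows joining 1k rows on keys with 100 and 50 distinct values respectively.
	fmt.Println(estimateInnerJoinRows(10000, 1000, 100, 50)) // 100000
}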
-func (p *LogicalJoin) PreparePossibleProperties(_ *expression.Schema, childrenProperties ...[][]*expression.Column) [][]*expression.Column { - leftProperties := childrenProperties[0] - rightProperties := childrenProperties[1] - // TODO: We should consider properties propagation. - p.LeftProperties = leftProperties - p.RightProperties = rightProperties - if p.JoinType == LeftOuterJoin || p.JoinType == LeftOuterSemiJoin { - rightProperties = nil - } else if p.JoinType == RightOuterJoin { - leftProperties = nil - } - resultProperties := make([][]*expression.Column, len(leftProperties)+len(rightProperties)) - for i, cols := range leftProperties { - resultProperties[i] = make([]*expression.Column, len(cols)) - copy(resultProperties[i], cols) - } - leftLen := len(leftProperties) - for i, cols := range rightProperties { - resultProperties[leftLen+i] = make([]*expression.Column, len(cols)) - copy(resultProperties[leftLen+i], cols) - } - return resultProperties -} - -// ExhaustPhysicalPlans implements the base.LogicalPlan.<14th> interface. -// it can generates hash join, index join and sort merge join. -// Firstly we check the hint, if hint is figured by user, we force to choose the corresponding physical plan. -// If the hint is not matched, it will get other candidates. -// If the hint is not figured, we will pick all candidates. -func (p *LogicalJoin) ExhaustPhysicalPlans(prop *property.PhysicalProperty) ([]base.PhysicalPlan, bool, error) { - failpoint.Inject("MockOnlyEnableIndexHashJoin", func(val failpoint.Value) { - if val.(bool) && !p.SCtx().GetSessionVars().InRestrictedSQL { - indexJoins, _ := tryToGetIndexJoin(p, prop) - failpoint.Return(indexJoins, true, nil) - } - }) - - if !isJoinHintSupportedInMPPMode(p.PreferJoinType) { - if hasMPPJoinHints(p.PreferJoinType) { - // If there are MPP hints but has some conflicts join method hints, all the join hints are invalid. - p.SCtx().GetSessionVars().StmtCtx.SetHintWarning("The MPP join hints are in conflict, and you can only specify join method hints that are currently supported by MPP mode now") - p.PreferJoinType = 0 - } else { - // If there are no MPP hints but has some conflicts join method hints, the MPP mode will be blocked. - p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because you have used hint to specify a join algorithm which is not supported by mpp now.") - if prop.IsFlashProp() { - return nil, false, nil - } - } - } - if prop.MPPPartitionTp == property.BroadcastType { - return nil, false, nil - } - joins := make([]base.PhysicalPlan, 0, 8) - canPushToTiFlash := p.CanPushToCop(kv.TiFlash) - if p.SCtx().GetSessionVars().IsMPPAllowed() && canPushToTiFlash { - if (p.PreferJoinType & utilhint.PreferShuffleJoin) > 0 { - if shuffleJoins := tryToGetMppHashJoin(p, prop, false); len(shuffleJoins) > 0 { - return shuffleJoins, true, nil - } - } - if (p.PreferJoinType & utilhint.PreferBCJoin) > 0 { - if bcastJoins := tryToGetMppHashJoin(p, prop, true); len(bcastJoins) > 0 { - return bcastJoins, true, nil - } - } - if preferMppBCJ(p) { - mppJoins := tryToGetMppHashJoin(p, prop, true) - joins = append(joins, mppJoins...) - } else { - mppJoins := tryToGetMppHashJoin(p, prop, false) - joins = append(joins, mppJoins...) 
- } - } else { - hasMppHints := false - var errMsg string - if (p.PreferJoinType & utilhint.PreferShuffleJoin) > 0 { - errMsg = "The join can not push down to the MPP side, the shuffle_join() hint is invalid" - hasMppHints = true - } - if (p.PreferJoinType & utilhint.PreferBCJoin) > 0 { - errMsg = "The join can not push down to the MPP side, the broadcast_join() hint is invalid" - hasMppHints = true - } - if hasMppHints { - p.SCtx().GetSessionVars().StmtCtx.SetHintWarning(errMsg) - } - } - if prop.IsFlashProp() { - return joins, true, nil - } - - if !p.IsNAAJ() { - // naaj refuse merge join and index join. - mergeJoins := GetMergeJoin(p, prop, p.Schema(), p.StatsInfo(), p.Children()[0].StatsInfo(), p.Children()[1].StatsInfo()) - if (p.PreferJoinType&utilhint.PreferMergeJoin) > 0 && len(mergeJoins) > 0 { - return mergeJoins, true, nil - } - joins = append(joins, mergeJoins...) - - indexJoins, forced := tryToGetIndexJoin(p, prop) - if forced { - return indexJoins, true, nil - } - joins = append(joins, indexJoins...) - } - - hashJoins, forced := getHashJoins(p, prop) - if forced && len(hashJoins) > 0 { - return hashJoins, true, nil - } - joins = append(joins, hashJoins...) - - if p.PreferJoinType > 0 { - // If we reach here, it means we have a hint that doesn't work. - // It might be affected by the required property, so we enforce - // this property and try the hint again. - return joins, false, nil - } - return joins, true, nil -} - -// ExtractCorrelatedCols implements the base.LogicalPlan.<15th> interface. -func (p *LogicalJoin) ExtractCorrelatedCols() []*expression.CorrelatedColumn { - corCols := make([]*expression.CorrelatedColumn, 0, len(p.EqualConditions)+len(p.LeftConditions)+len(p.RightConditions)+len(p.OtherConditions)) - for _, fun := range p.EqualConditions { - corCols = append(corCols, expression.ExtractCorColumns(fun)...) - } - for _, fun := range p.LeftConditions { - corCols = append(corCols, expression.ExtractCorColumns(fun)...) - } - for _, fun := range p.RightConditions { - corCols = append(corCols, expression.ExtractCorColumns(fun)...) - } - for _, fun := range p.OtherConditions { - corCols = append(corCols, expression.ExtractCorColumns(fun)...) - } - return corCols -} - -// MaxOneRow inherits the BaseLogicalPlan.LogicalPlan.<16th> implementation. - -// Children inherits the BaseLogicalPlan.LogicalPlan.<17th> implementation. - -// SetChildren inherits the BaseLogicalPlan.LogicalPlan.<18th> implementation. - -// SetChild inherits the BaseLogicalPlan.LogicalPlan.<19th> implementation. - -// RollBackTaskMap inherits the BaseLogicalPlan.LogicalPlan.<20th> implementation. - -// CanPushToCop inherits the BaseLogicalPlan.LogicalPlan.<21st> implementation. - -// ExtractFD implements the base.LogicalPlan.<22th> interface. -func (p *LogicalJoin) ExtractFD() *funcdep.FDSet { - switch p.JoinType { - case InnerJoin: - return p.extractFDForInnerJoin(nil) - case LeftOuterJoin, RightOuterJoin: - return p.extractFDForOuterJoin(nil) - case SemiJoin: - return p.extractFDForSemiJoin(nil) - default: - return &funcdep.FDSet{HashCodeToUniqueID: make(map[string]int)} - } -} - -// GetBaseLogicalPlan inherits the BaseLogicalPlan.LogicalPlan.<23th> implementation. - -// ConvertOuterToInnerJoin implements base.LogicalPlan.<24th> interface. 
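// A hypothetical sketch of the bit-flag hint checks above (for example,
// p.PreferJoinType & utilhint.PreferMergeJoin): each join-method hint sets one bit,
// and when a matching hint produced at least one candidate, those candidates are
// returned as forced. Flag names and types here are illustrative only.
package main

import "fmt"

const (
	preferMergeJoin = 1 << iota
	preferHashJoin
	preferIndexJoin
)

func pickByHint(preferFlags uint, candidates map[uint][]string) ([]string, bool) {
	for flag, plans := range candidates {
		if preferFlags&flag > 0 && len(plans) > 0 {
			return plans, true // hint matched: force these candidates
		}
	}
	return nil, false // no usable hint: the caller falls back to all candidates
}

func main() {
	candidates := map[uint][]string{
		preferMergeJoin: {"MergeJoin"},
		preferHashJoin:  {"HashJoin"},
	}
	plans, forced := pickByHint(preferHashJoin, candidates)
	fmt.Println(plans, forced) // [HashJoin] true
}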
-func (p *LogicalJoin) ConvertOuterToInnerJoin(predicates []expression.Expression) base.LogicalPlan { - innerTable := p.Children()[0] - outerTable := p.Children()[1] - switchChild := false - - if p.JoinType == LeftOuterJoin { - innerTable, outerTable = outerTable, innerTable - switchChild = true - } - - // First, simplify this join - if p.JoinType == LeftOuterJoin || p.JoinType == RightOuterJoin { - canBeSimplified := false - for _, expr := range predicates { - isOk := util.IsNullRejected(p.SCtx(), innerTable.Schema(), expr) - if isOk { - canBeSimplified = true - break - } - } - if canBeSimplified { - p.JoinType = InnerJoin - } - } - - // Next simplify join children - - combinedCond := mergeOnClausePredicates(p, predicates) - if p.JoinType == LeftOuterJoin || p.JoinType == RightOuterJoin { - innerTable = innerTable.ConvertOuterToInnerJoin(combinedCond) - outerTable = outerTable.ConvertOuterToInnerJoin(predicates) - } else if p.JoinType == InnerJoin || p.JoinType == SemiJoin { - innerTable = innerTable.ConvertOuterToInnerJoin(combinedCond) - outerTable = outerTable.ConvertOuterToInnerJoin(combinedCond) - } else if p.JoinType == AntiSemiJoin { - innerTable = innerTable.ConvertOuterToInnerJoin(predicates) - outerTable = outerTable.ConvertOuterToInnerJoin(combinedCond) - } else { - innerTable = innerTable.ConvertOuterToInnerJoin(predicates) - outerTable = outerTable.ConvertOuterToInnerJoin(predicates) - } - - if switchChild { - p.SetChild(0, outerTable) - p.SetChild(1, innerTable) - } else { - p.SetChild(0, innerTable) - p.SetChild(1, outerTable) - } - - return p -} - -// *************************** end implementation of logicalPlan interface *************************** - -// IsNAAJ checks if the join is a non-adjacent-join. -func (p *LogicalJoin) IsNAAJ() bool { - return len(p.NAEQConditions) > 0 -} - -// Shallow copies a LogicalJoin struct. -func (p *LogicalJoin) Shallow() *LogicalJoin { - join := *p - return join.Init(p.SCtx(), p.QueryBlockOffset()) -} - -func (p *LogicalJoin) extractFDForSemiJoin(filtersFromApply []expression.Expression) *funcdep.FDSet { - // 1: since semi join will keep the part or all rows of the outer table, it's outer FD can be saved. - // 2: the un-projected column will be left for the upper layer projection or already be pruned from bottom up. - outerFD, _ := p.Children()[0].ExtractFD(), p.Children()[1].ExtractFD() - fds := outerFD - - eqCondSlice := expression.ScalarFuncs2Exprs(p.EqualConditions) - allConds := append(eqCondSlice, p.OtherConditions...) - allConds = append(allConds, filtersFromApply...) - notNullColsFromFilters := ExtractNotNullFromConds(allConds, p) - - constUniqueIDs := ExtractConstantCols(p.LeftConditions, p.SCtx(), fds) - - fds.MakeNotNull(notNullColsFromFilters) - fds.AddConstants(constUniqueIDs) - p.SetFDs(fds) - return fds -} - -func (p *LogicalJoin) extractFDForInnerJoin(filtersFromApply []expression.Expression) *funcdep.FDSet { - leftFD, rightFD := p.Children()[0].ExtractFD(), p.Children()[1].ExtractFD() - fds := leftFD - fds.MakeCartesianProduct(rightFD) - - eqCondSlice := expression.ScalarFuncs2Exprs(p.EqualConditions) - // some join eq conditions are stored in the OtherConditions. - allConds := append(eqCondSlice, p.OtherConditions...) - allConds = append(allConds, filtersFromApply...) 
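// A toy model of the null-rejection test behind ConvertOuterToInnerJoin above: if a
// WHERE predicate can never be TRUE when the inner side's columns are all NULL, the
// NULL-padded rows produced by the outer join are filtered out anyway, so the join
// can be treated as an inner join. This mini-evaluator is hypothetical and is not
// util.IsNullRejected.
package main

import "fmt"

// evalGreater models "col > bound" under SQL three-valued logic: NULL input yields NULL.
func evalGreater(col *int64, bound int64) (result bool, isNull bool) {
	if col == nil {
		return false, true
	}
	return *col > bound, false
}

func isNullRejected(bound int64) bool {
	// Evaluate the predicate with the inner-side column set to NULL.
	result, isNull := evalGreater(nil, bound)
	return isNull || !result // NULL or FALSE: padded rows cannot pass the filter
}

func main() {
	fmt.Println(isNullRejected(5)) // true: a filter like "t2.b > 5" rejects NULL-padded rows
}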
- notNullColsFromFilters := ExtractNotNullFromConds(allConds, p) - - constUniqueIDs := ExtractConstantCols(allConds, p.SCtx(), fds) - - equivUniqueIDs := ExtractEquivalenceCols(allConds, p.SCtx(), fds) - - fds.MakeNotNull(notNullColsFromFilters) - fds.AddConstants(constUniqueIDs) - for _, equiv := range equivUniqueIDs { - fds.AddEquivalence(equiv[0], equiv[1]) - } - // merge the not-null-cols/registered-map from both side together. - fds.NotNullCols.UnionWith(rightFD.NotNullCols) - if fds.HashCodeToUniqueID == nil { - fds.HashCodeToUniqueID = rightFD.HashCodeToUniqueID - } else { - for k, v := range rightFD.HashCodeToUniqueID { - // If there's same constant in the different subquery, we might go into this IF branch. - if _, ok := fds.HashCodeToUniqueID[k]; ok { - continue - } - fds.HashCodeToUniqueID[k] = v - } - } - for i, ok := rightFD.GroupByCols.Next(0); ok; i, ok = rightFD.GroupByCols.Next(i + 1) { - fds.GroupByCols.Insert(i) - } - fds.HasAggBuilt = fds.HasAggBuilt || rightFD.HasAggBuilt - p.SetFDs(fds) - return fds -} - -func (p *LogicalJoin) extractFDForOuterJoin(filtersFromApply []expression.Expression) *funcdep.FDSet { - outerFD, innerFD := p.Children()[0].ExtractFD(), p.Children()[1].ExtractFD() - innerCondition := p.RightConditions - outerCondition := p.LeftConditions - outerCols, innerCols := intset.NewFastIntSet(), intset.NewFastIntSet() - for _, col := range p.Children()[0].Schema().Columns { - outerCols.Insert(int(col.UniqueID)) - } - for _, col := range p.Children()[1].Schema().Columns { - innerCols.Insert(int(col.UniqueID)) - } - if p.JoinType == RightOuterJoin { - innerFD, outerFD = outerFD, innerFD - innerCondition = p.LeftConditions - outerCondition = p.RightConditions - innerCols, outerCols = outerCols, innerCols - } - - eqCondSlice := expression.ScalarFuncs2Exprs(p.EqualConditions) - allConds := append(eqCondSlice, p.OtherConditions...) - allConds = append(allConds, innerCondition...) - allConds = append(allConds, outerCondition...) - allConds = append(allConds, filtersFromApply...) - notNullColsFromFilters := ExtractNotNullFromConds(allConds, p) - - filterFD := &funcdep.FDSet{HashCodeToUniqueID: make(map[string]int)} - - constUniqueIDs := ExtractConstantCols(allConds, p.SCtx(), filterFD) - - equivUniqueIDs := ExtractEquivalenceCols(allConds, p.SCtx(), filterFD) - - filterFD.AddConstants(constUniqueIDs) - equivOuterUniqueIDs := intset.NewFastIntSet() - equivAcrossNum := 0 - for _, equiv := range equivUniqueIDs { - filterFD.AddEquivalence(equiv[0], equiv[1]) - if equiv[0].SubsetOf(outerCols) && equiv[1].SubsetOf(innerCols) { - equivOuterUniqueIDs.UnionWith(equiv[0]) - equivAcrossNum++ - continue - } - if equiv[0].SubsetOf(innerCols) && equiv[1].SubsetOf(outerCols) { - equivOuterUniqueIDs.UnionWith(equiv[1]) - equivAcrossNum++ - } - } - filterFD.MakeNotNull(notNullColsFromFilters) - - // pre-perceive the filters for the convenience judgement of 3.3.1. - var opt funcdep.ArgOpts - if equivAcrossNum > 0 { - // find the equivalence FD across left and right cols. - var outConditionCols []*expression.Column - if len(outerCondition) != 0 { - outConditionCols = append(outConditionCols, expression.ExtractColumnsFromExpressions(nil, outerCondition, nil)...) - } - if len(p.OtherConditions) != 0 { - // other condition may contain right side cols, it doesn't affect the judgement of intersection of non-left-equiv cols. - outConditionCols = append(outConditionCols, expression.ExtractColumnsFromExpressions(nil, p.OtherConditions, nil)...) 
- } - outerConditionUniqueIDs := intset.NewFastIntSet() - for _, col := range outConditionCols { - outerConditionUniqueIDs.Insert(int(col.UniqueID)) - } - // judge whether left filters is on non-left-equiv cols. - if outerConditionUniqueIDs.Intersects(outerCols.Difference(equivOuterUniqueIDs)) { - opt.SkipFDRule331 = true - } - } else { - // if there is none across equivalence condition, skip rule 3.3.1. - opt.SkipFDRule331 = true - } - - opt.OnlyInnerFilter = len(eqCondSlice) == 0 && len(outerCondition) == 0 && len(p.OtherConditions) == 0 - if opt.OnlyInnerFilter { - // if one of the inner condition is constant false, the inner side are all null, left make constant all of that. - for _, one := range innerCondition { - if c, ok := one.(*expression.Constant); ok && c.DeferredExpr == nil && c.ParamMarker == nil { - if isTrue, err := c.Value.ToBool(p.SCtx().GetSessionVars().StmtCtx.TypeCtx()); err == nil { - if isTrue == 0 { - // c is false - opt.InnerIsFalse = true - } - } - } - } - } - - fds := outerFD - fds.MakeOuterJoin(innerFD, filterFD, outerCols, innerCols, &opt) - p.SetFDs(fds) - return fds -} - -// GetJoinKeys extracts join keys(columns) from EqualConditions. It returns left join keys, right -// join keys and an `isNullEQ` array which means the `joinKey[i]` is a `NullEQ` function. The `hasNullEQ` -// means whether there is a `NullEQ` of a join key. -func (p *LogicalJoin) GetJoinKeys() (leftKeys, rightKeys []*expression.Column, isNullEQ []bool, hasNullEQ bool) { - for _, expr := range p.EqualConditions { - leftKeys = append(leftKeys, expr.GetArgs()[0].(*expression.Column)) - rightKeys = append(rightKeys, expr.GetArgs()[1].(*expression.Column)) - isNullEQ = append(isNullEQ, expr.FuncName.L == ast.NullEQ) - hasNullEQ = hasNullEQ || expr.FuncName.L == ast.NullEQ - } - return -} - -// GetNAJoinKeys extracts join keys(columns) from NAEqualCondition. -func (p *LogicalJoin) GetNAJoinKeys() (leftKeys, rightKeys []*expression.Column) { - for _, expr := range p.NAEQConditions { - leftKeys = append(leftKeys, expr.GetArgs()[0].(*expression.Column)) - rightKeys = append(rightKeys, expr.GetArgs()[1].(*expression.Column)) - } - return -} - -// GetPotentialPartitionKeys return potential partition keys for join, the potential partition keys are -// the join keys of EqualConditions -func (p *LogicalJoin) GetPotentialPartitionKeys() (leftKeys, rightKeys []*property.MPPPartitionColumn) { - for _, expr := range p.EqualConditions { - _, coll := expr.CharsetAndCollation() - collateID := property.GetCollateIDByNameForPartition(coll) - leftKeys = append(leftKeys, &property.MPPPartitionColumn{Col: expr.GetArgs()[0].(*expression.Column), CollateID: collateID}) - rightKeys = append(rightKeys, &property.MPPPartitionColumn{Col: expr.GetArgs()[1].(*expression.Column), CollateID: collateID}) - } - return -} - -// Decorrelate eliminate the correlated column with if the col is in schema. -func (p *LogicalJoin) Decorrelate(schema *expression.Schema) { - for i, cond := range p.LeftConditions { - p.LeftConditions[i] = cond.Decorrelate(schema) - } - for i, cond := range p.RightConditions { - p.RightConditions[i] = cond.Decorrelate(schema) - } - for i, cond := range p.OtherConditions { - p.OtherConditions[i] = cond.Decorrelate(schema) - } - for i, cond := range p.EqualConditions { - p.EqualConditions[i] = cond.Decorrelate(schema).(*expression.ScalarFunction) - } -} - -// ColumnSubstituteAll is used in projection elimination in apply de-correlation. 
-// Substitutions for all conditions should be successful, otherwise, we should keep all conditions unchanged. -func (p *LogicalJoin) ColumnSubstituteAll(schema *expression.Schema, exprs []expression.Expression) (hasFail bool) { - // make a copy of exprs for convenience of substitution (may change/partially change the expr tree) - cpLeftConditions := make(expression.CNFExprs, len(p.LeftConditions)) - cpRightConditions := make(expression.CNFExprs, len(p.RightConditions)) - cpOtherConditions := make(expression.CNFExprs, len(p.OtherConditions)) - cpEqualConditions := make([]*expression.ScalarFunction, len(p.EqualConditions)) - copy(cpLeftConditions, p.LeftConditions) - copy(cpRightConditions, p.RightConditions) - copy(cpOtherConditions, p.OtherConditions) - copy(cpEqualConditions, p.EqualConditions) - - exprCtx := p.SCtx().GetExprCtx() - // try to substitute columns in these condition. - for i, cond := range cpLeftConditions { - if hasFail, cpLeftConditions[i] = expression.ColumnSubstituteAll(exprCtx, cond, schema, exprs); hasFail { - return - } - } - - for i, cond := range cpRightConditions { - if hasFail, cpRightConditions[i] = expression.ColumnSubstituteAll(exprCtx, cond, schema, exprs); hasFail { - return - } - } - - for i, cond := range cpOtherConditions { - if hasFail, cpOtherConditions[i] = expression.ColumnSubstituteAll(exprCtx, cond, schema, exprs); hasFail { - return - } - } - - for i, cond := range cpEqualConditions { - var tmp expression.Expression - if hasFail, tmp = expression.ColumnSubstituteAll(exprCtx, cond, schema, exprs); hasFail { - return - } - cpEqualConditions[i] = tmp.(*expression.ScalarFunction) - } - - // if all substituted, change them atomically here. - p.LeftConditions = cpLeftConditions - p.RightConditions = cpRightConditions - p.OtherConditions = cpOtherConditions - p.EqualConditions = cpEqualConditions - - for i := len(p.EqualConditions) - 1; i >= 0; i-- { - newCond := p.EqualConditions[i] - - // If the columns used in the new filter all come from the left child, - // we can push this filter to it. - if expression.ExprFromSchema(newCond, p.Children()[0].Schema()) { - p.LeftConditions = append(p.LeftConditions, newCond) - p.EqualConditions = append(p.EqualConditions[:i], p.EqualConditions[i+1:]...) - continue - } - - // If the columns used in the new filter all come from the right - // child, we can push this filter to it. - if expression.ExprFromSchema(newCond, p.Children()[1].Schema()) { - p.RightConditions = append(p.RightConditions, newCond) - p.EqualConditions = append(p.EqualConditions[:i], p.EqualConditions[i+1:]...) - continue - } - - _, lhsIsCol := newCond.GetArgs()[0].(*expression.Column) - _, rhsIsCol := newCond.GetArgs()[1].(*expression.Column) - - // If the columns used in the new filter are not all expression.Column, - // we can not use it as join's equal condition. - if !(lhsIsCol && rhsIsCol) { - p.OtherConditions = append(p.OtherConditions, newCond) - p.EqualConditions = append(p.EqualConditions[:i], p.EqualConditions[i+1:]...) - continue - } - - p.EqualConditions[i] = newCond - } - return false -} - -// AttachOnConds extracts on conditions for join and set the `EqualConditions`, `LeftConditions`, `RightConditions` and -// `OtherConditions` by the result of extract. -func (p *LogicalJoin) AttachOnConds(onConds []expression.Expression) { - eq, left, right, other := p.extractOnCondition(onConds, false, false) - p.AppendJoinConds(eq, left, right, other) -} - -// AppendJoinConds appends new join conditions. 
-func (p *LogicalJoin) AppendJoinConds(eq []*expression.ScalarFunction, left, right, other []expression.Expression) { - p.EqualConditions = append(eq, p.EqualConditions...) - p.LeftConditions = append(left, p.LeftConditions...) - p.RightConditions = append(right, p.RightConditions...) - p.OtherConditions = append(other, p.OtherConditions...) -} - -// ExtractJoinKeys extract join keys as a schema for child with childIdx. -func (p *LogicalJoin) ExtractJoinKeys(childIdx int) *expression.Schema { - joinKeys := make([]*expression.Column, 0, len(p.EqualConditions)) - for _, eqCond := range p.EqualConditions { - joinKeys = append(joinKeys, eqCond.GetArgs()[childIdx].(*expression.Column)) - } - return expression.NewSchema(joinKeys...) -} - -// extractUsedCols extracts all the needed columns. -func (p *LogicalJoin) extractUsedCols(parentUsedCols []*expression.Column) (leftCols []*expression.Column, rightCols []*expression.Column) { - for _, eqCond := range p.EqualConditions { - parentUsedCols = append(parentUsedCols, expression.ExtractColumns(eqCond)...) - } - for _, leftCond := range p.LeftConditions { - parentUsedCols = append(parentUsedCols, expression.ExtractColumns(leftCond)...) - } - for _, rightCond := range p.RightConditions { - parentUsedCols = append(parentUsedCols, expression.ExtractColumns(rightCond)...) - } - for _, otherCond := range p.OtherConditions { - parentUsedCols = append(parentUsedCols, expression.ExtractColumns(otherCond)...) - } - for _, naeqCond := range p.NAEQConditions { - parentUsedCols = append(parentUsedCols, expression.ExtractColumns(naeqCond)...) - } - lChild := p.Children()[0] - rChild := p.Children()[1] - for _, col := range parentUsedCols { - if lChild.Schema().Contains(col) { - leftCols = append(leftCols, col) - } else if rChild.Schema().Contains(col) { - rightCols = append(rightCols, col) - } - } - return leftCols, rightCols -} - -// MergeSchema merge the schema of left and right child of join. -func (p *LogicalJoin) mergeSchema() { - p.SetSchema(buildLogicalJoinSchema(p.JoinType, p)) -} - -// pushDownTopNToChild will push a topN to one child of join. The idx stands for join child index. 0 is for left child. 
-func (p *LogicalJoin) pushDownTopNToChild(topN *logicalop.LogicalTopN, idx int, opt *optimizetrace.LogicalOptimizeOp) base.LogicalPlan { - if topN == nil { - return p.Children()[idx].PushDownTopN(nil, opt) - } - - for _, by := range topN.ByItems { - cols := expression.ExtractColumns(by.Expr) - for _, col := range cols { - if !p.Children()[idx].Schema().Contains(col) { - return p.Children()[idx].PushDownTopN(nil, opt) - } - } - } - - newTopN := logicalop.LogicalTopN{ - Count: topN.Count + topN.Offset, - ByItems: make([]*util.ByItems, len(topN.ByItems)), - PreferLimitToCop: topN.PreferLimitToCop, - }.Init(topN.SCtx(), topN.QueryBlockOffset()) - for i := range topN.ByItems { - newTopN.ByItems[i] = topN.ByItems[i].Clone() - } - appendTopNPushDownJoinTraceStep(p, newTopN, idx, opt) - return p.Children()[idx].PushDownTopN(newTopN, opt) -} - -// Add a new selection between parent plan and current plan with candidate predicates -/* -+-------------+ +-------------+ -| parentPlan | | parentPlan | -+-----^-------+ +-----^-------+ - | --addCandidateSelection---> | -+-----+-------+ +-----------+--------------+ -| currentPlan | | selection | -+-------------+ | candidate predicate | - +-----------^--------------+ - | - | - +----+--------+ - | currentPlan | - +-------------+ -*/ -// If the currentPlan at the top of query plan, return new root plan (selection) -// Else return nil -func addCandidateSelection(currentPlan base.LogicalPlan, currentChildIdx int, parentPlan base.LogicalPlan, - candidatePredicates []expression.Expression, opt *optimizetrace.LogicalOptimizeOp) (newRoot base.LogicalPlan) { - // generate a new selection for candidatePredicates - selection := LogicalSelection{Conditions: candidatePredicates}.Init(currentPlan.SCtx(), currentPlan.QueryBlockOffset()) - // add selection above of p - if parentPlan == nil { - newRoot = selection - } else { - parentPlan.SetChild(currentChildIdx, selection) - } - selection.SetChildren(currentPlan) - appendAddSelectionTraceStep(parentPlan, currentPlan, selection, opt) - if parentPlan == nil { - return newRoot - } - return nil -} - -func (p *LogicalJoin) getGroupNDVs(colGroups [][]*expression.Column, childStats []*property.StatsInfo) []property.GroupNDV { - outerIdx := int(-1) - if p.JoinType == LeftOuterJoin || p.JoinType == LeftOuterSemiJoin || p.JoinType == AntiLeftOuterSemiJoin { - outerIdx = 0 - } else if p.JoinType == RightOuterJoin { - outerIdx = 1 - } - if outerIdx >= 0 && len(colGroups) > 0 { - return childStats[outerIdx].GroupNDVs - } - return nil -} - -// PreferAny checks whether the join type is in the joinFlags. -func (p *LogicalJoin) PreferAny(joinFlags ...uint) bool { - for _, flag := range joinFlags { - if p.PreferJoinType&flag > 0 { - return true - } - } - return false -} - -// ExtractOnCondition divide conditions in CNF of join node into 4 groups. -// These conditions can be where conditions, join conditions, or collection of both. -// If deriveLeft/deriveRight is set, we would try to derive more conditions for left/right plan. 
-func (p *LogicalJoin) ExtractOnCondition( - conditions []expression.Expression, - leftSchema *expression.Schema, - rightSchema *expression.Schema, - deriveLeft bool, - deriveRight bool) (eqCond []*expression.ScalarFunction, leftCond []expression.Expression, - rightCond []expression.Expression, otherCond []expression.Expression) { - ctx := p.SCtx() - for _, expr := range conditions { - // For queries like `select a in (select a from s where s.b = t.b) from t`, - // if subquery is empty caused by `s.b = t.b`, the result should always be - // false even if t.a is null or s.a is null. To make this join "empty aware", - // we should differentiate `t.a = s.a` from other column equal conditions, so - // we put it into OtherConditions instead of EqualConditions of join. - if expression.IsEQCondFromIn(expr) { - otherCond = append(otherCond, expr) - continue - } - binop, ok := expr.(*expression.ScalarFunction) - if ok && len(binop.GetArgs()) == 2 { - arg0, lOK := binop.GetArgs()[0].(*expression.Column) - arg1, rOK := binop.GetArgs()[1].(*expression.Column) - if lOK && rOK { - leftCol := leftSchema.RetrieveColumn(arg0) - rightCol := rightSchema.RetrieveColumn(arg1) - if leftCol == nil || rightCol == nil { - leftCol = leftSchema.RetrieveColumn(arg1) - rightCol = rightSchema.RetrieveColumn(arg0) - arg0, arg1 = arg1, arg0 - } - if leftCol != nil && rightCol != nil { - if deriveLeft { - if util.IsNullRejected(ctx, leftSchema, expr) && !mysql.HasNotNullFlag(leftCol.RetType.GetFlag()) { - notNullExpr := expression.BuildNotNullExpr(ctx.GetExprCtx(), leftCol) - leftCond = append(leftCond, notNullExpr) - } - } - if deriveRight { - if util.IsNullRejected(ctx, rightSchema, expr) && !mysql.HasNotNullFlag(rightCol.RetType.GetFlag()) { - notNullExpr := expression.BuildNotNullExpr(ctx.GetExprCtx(), rightCol) - rightCond = append(rightCond, notNullExpr) - } - } - if binop.FuncName.L == ast.EQ { - cond := expression.NewFunctionInternal(ctx.GetExprCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), arg0, arg1) - eqCond = append(eqCond, cond.(*expression.ScalarFunction)) - continue - } - } - } - } - columns := expression.ExtractColumns(expr) - // `columns` may be empty, if the condition is like `correlated_column op constant`, or `constant`, - // push this kind of constant condition down according to join type. - if len(columns) == 0 { - leftCond, rightCond = p.pushDownConstExpr(expr, leftCond, rightCond, deriveLeft || deriveRight) - continue - } - allFromLeft, allFromRight := true, true - for _, col := range columns { - if !leftSchema.Contains(col) { - allFromLeft = false - } - if !rightSchema.Contains(col) { - allFromRight = false - } - } - if allFromRight { - rightCond = append(rightCond, expr) - } else if allFromLeft { - leftCond = append(leftCond, expr) - } else { - // Relax expr to two supersets: leftRelaxedCond and rightRelaxedCond, the expression now is - // `expr AND leftRelaxedCond AND rightRelaxedCond`. Motivation is to push filters down to - // children as much as possible. 
- if deriveLeft { - leftRelaxedCond := expression.DeriveRelaxedFiltersFromDNF(ctx.GetExprCtx(), expr, leftSchema) - if leftRelaxedCond != nil { - leftCond = append(leftCond, leftRelaxedCond) - } - } - if deriveRight { - rightRelaxedCond := expression.DeriveRelaxedFiltersFromDNF(ctx.GetExprCtx(), expr, rightSchema) - if rightRelaxedCond != nil { - rightCond = append(rightCond, rightRelaxedCond) - } - } - otherCond = append(otherCond, expr) - } - } - return -} - -// pushDownConstExpr checks if the condition is from filter condition, if true, push it down to both -// children of join, whatever the join type is; if false, push it down to inner child of outer join, -// and both children of non-outer-join. -func (p *LogicalJoin) pushDownConstExpr(expr expression.Expression, leftCond []expression.Expression, - rightCond []expression.Expression, filterCond bool) ([]expression.Expression, []expression.Expression) { - switch p.JoinType { - case LeftOuterJoin, LeftOuterSemiJoin, AntiLeftOuterSemiJoin: - if filterCond { - leftCond = append(leftCond, expr) - // Append the expr to right join condition instead of `rightCond`, to make it able to be - // pushed down to children of join. - p.RightConditions = append(p.RightConditions, expr) - } else { - rightCond = append(rightCond, expr) - } - case RightOuterJoin: - if filterCond { - rightCond = append(rightCond, expr) - p.LeftConditions = append(p.LeftConditions, expr) - } else { - leftCond = append(leftCond, expr) - } - case SemiJoin, InnerJoin: - leftCond = append(leftCond, expr) - rightCond = append(rightCond, expr) - case AntiSemiJoin: - if filterCond { - leftCond = append(leftCond, expr) - } - rightCond = append(rightCond, expr) - } - return leftCond, rightCond -} - -func (p *LogicalJoin) extractOnCondition(conditions []expression.Expression, deriveLeft bool, - deriveRight bool) (eqCond []*expression.ScalarFunction, leftCond []expression.Expression, - rightCond []expression.Expression, otherCond []expression.Expression) { - return p.ExtractOnCondition(conditions, p.Children()[0].Schema(), p.Children()[1].Schema(), deriveLeft, deriveRight) -} - -// SetPreferredJoinTypeAndOrder sets the preferred join type and order for the LogicalJoin. 
-func (p *LogicalJoin) SetPreferredJoinTypeAndOrder(hintInfo *utilhint.PlanHints) { - if hintInfo == nil { - return - } - - lhsAlias := extractTableAlias(p.Children()[0], p.QueryBlockOffset()) - rhsAlias := extractTableAlias(p.Children()[1], p.QueryBlockOffset()) - if hintInfo.IfPreferMergeJoin(lhsAlias) { - p.PreferJoinType |= utilhint.PreferMergeJoin - p.LeftPreferJoinType |= utilhint.PreferMergeJoin - } - if hintInfo.IfPreferMergeJoin(rhsAlias) { - p.PreferJoinType |= utilhint.PreferMergeJoin - p.RightPreferJoinType |= utilhint.PreferMergeJoin - } - if hintInfo.IfPreferNoMergeJoin(lhsAlias) { - p.PreferJoinType |= utilhint.PreferNoMergeJoin - p.LeftPreferJoinType |= utilhint.PreferNoMergeJoin - } - if hintInfo.IfPreferNoMergeJoin(rhsAlias) { - p.PreferJoinType |= utilhint.PreferNoMergeJoin - p.RightPreferJoinType |= utilhint.PreferNoMergeJoin - } - if hintInfo.IfPreferBroadcastJoin(lhsAlias) { - p.PreferJoinType |= utilhint.PreferBCJoin - p.LeftPreferJoinType |= utilhint.PreferBCJoin - } - if hintInfo.IfPreferBroadcastJoin(rhsAlias) { - p.PreferJoinType |= utilhint.PreferBCJoin - p.RightPreferJoinType |= utilhint.PreferBCJoin - } - if hintInfo.IfPreferShuffleJoin(lhsAlias) { - p.PreferJoinType |= utilhint.PreferShuffleJoin - p.LeftPreferJoinType |= utilhint.PreferShuffleJoin - } - if hintInfo.IfPreferShuffleJoin(rhsAlias) { - p.PreferJoinType |= utilhint.PreferShuffleJoin - p.RightPreferJoinType |= utilhint.PreferShuffleJoin - } - if hintInfo.IfPreferHashJoin(lhsAlias) { - p.PreferJoinType |= utilhint.PreferHashJoin - p.LeftPreferJoinType |= utilhint.PreferHashJoin - } - if hintInfo.IfPreferHashJoin(rhsAlias) { - p.PreferJoinType |= utilhint.PreferHashJoin - p.RightPreferJoinType |= utilhint.PreferHashJoin - } - if hintInfo.IfPreferNoHashJoin(lhsAlias) { - p.PreferJoinType |= utilhint.PreferNoHashJoin - p.LeftPreferJoinType |= utilhint.PreferNoHashJoin - } - if hintInfo.IfPreferNoHashJoin(rhsAlias) { - p.PreferJoinType |= utilhint.PreferNoHashJoin - p.RightPreferJoinType |= utilhint.PreferNoHashJoin - } - if hintInfo.IfPreferINLJ(lhsAlias) { - p.PreferJoinType |= utilhint.PreferLeftAsINLJInner - p.LeftPreferJoinType |= utilhint.PreferINLJ - } - if hintInfo.IfPreferINLJ(rhsAlias) { - p.PreferJoinType |= utilhint.PreferRightAsINLJInner - p.RightPreferJoinType |= utilhint.PreferINLJ - } - if hintInfo.IfPreferINLHJ(lhsAlias) { - p.PreferJoinType |= utilhint.PreferLeftAsINLHJInner - p.LeftPreferJoinType |= utilhint.PreferINLHJ - } - if hintInfo.IfPreferINLHJ(rhsAlias) { - p.PreferJoinType |= utilhint.PreferRightAsINLHJInner - p.RightPreferJoinType |= utilhint.PreferINLHJ - } - if hintInfo.IfPreferINLMJ(lhsAlias) { - p.PreferJoinType |= utilhint.PreferLeftAsINLMJInner - p.LeftPreferJoinType |= utilhint.PreferINLMJ - } - if hintInfo.IfPreferINLMJ(rhsAlias) { - p.PreferJoinType |= utilhint.PreferRightAsINLMJInner - p.RightPreferJoinType |= utilhint.PreferINLMJ - } - if hintInfo.IfPreferNoIndexJoin(lhsAlias) { - p.PreferJoinType |= utilhint.PreferNoIndexJoin - p.LeftPreferJoinType |= utilhint.PreferNoIndexJoin - } - if hintInfo.IfPreferNoIndexJoin(rhsAlias) { - p.PreferJoinType |= utilhint.PreferNoIndexJoin - p.RightPreferJoinType |= utilhint.PreferNoIndexJoin - } - if hintInfo.IfPreferNoIndexHashJoin(lhsAlias) { - p.PreferJoinType |= utilhint.PreferNoIndexHashJoin - p.LeftPreferJoinType |= utilhint.PreferNoIndexHashJoin - } - if hintInfo.IfPreferNoIndexHashJoin(rhsAlias) { - p.PreferJoinType |= utilhint.PreferNoIndexHashJoin - p.RightPreferJoinType |= utilhint.PreferNoIndexHashJoin - } - if 
hintInfo.IfPreferNoIndexMergeJoin(lhsAlias) { - p.PreferJoinType |= utilhint.PreferNoIndexMergeJoin - p.LeftPreferJoinType |= utilhint.PreferNoIndexMergeJoin - } - if hintInfo.IfPreferNoIndexMergeJoin(rhsAlias) { - p.PreferJoinType |= utilhint.PreferNoIndexMergeJoin - p.RightPreferJoinType |= utilhint.PreferNoIndexMergeJoin - } - if hintInfo.IfPreferHJBuild(lhsAlias) { - p.PreferJoinType |= utilhint.PreferLeftAsHJBuild - p.LeftPreferJoinType |= utilhint.PreferHJBuild - } - if hintInfo.IfPreferHJBuild(rhsAlias) { - p.PreferJoinType |= utilhint.PreferRightAsHJBuild - p.RightPreferJoinType |= utilhint.PreferHJBuild - } - if hintInfo.IfPreferHJProbe(lhsAlias) { - p.PreferJoinType |= utilhint.PreferLeftAsHJProbe - p.LeftPreferJoinType |= utilhint.PreferHJProbe - } - if hintInfo.IfPreferHJProbe(rhsAlias) { - p.PreferJoinType |= utilhint.PreferRightAsHJProbe - p.RightPreferJoinType |= utilhint.PreferHJProbe - } - hasConflict := false - if !p.SCtx().GetSessionVars().EnableAdvancedJoinHint || p.SCtx().GetSessionVars().StmtCtx.StraightJoinOrder { - if containDifferentJoinTypes(p.PreferJoinType) { - hasConflict = true - } - } else if p.SCtx().GetSessionVars().EnableAdvancedJoinHint { - if containDifferentJoinTypes(p.LeftPreferJoinType) || containDifferentJoinTypes(p.RightPreferJoinType) { - hasConflict = true - } - } - if hasConflict { - p.SCtx().GetSessionVars().StmtCtx.SetHintWarning( - "Join hints are conflict, you can only specify one type of join") - p.PreferJoinType = 0 - } - // set the join order - if hintInfo.LeadingJoinOrder != nil { - p.PreferJoinOrder = hintInfo.MatchTableName([]*utilhint.HintedTable{lhsAlias, rhsAlias}, hintInfo.LeadingJoinOrder) - } - // set hintInfo for further usage if this hint info can be used. - if p.PreferJoinType != 0 || p.PreferJoinOrder { - p.HintInfo = hintInfo - } -} - -// SetPreferredJoinType generates hint information for the logicalJoin based on the hint information of its left and right children. -func (p *LogicalJoin) SetPreferredJoinType() { - if p.LeftPreferJoinType == 0 && p.RightPreferJoinType == 0 { - return - } - p.PreferJoinType = setPreferredJoinTypeFromOneSide(p.LeftPreferJoinType, true) | setPreferredJoinTypeFromOneSide(p.RightPreferJoinType, false) - if containDifferentJoinTypes(p.PreferJoinType) { - p.SCtx().GetSessionVars().StmtCtx.SetHintWarning( - "Join hints conflict after join reorder phase, you can only specify one type of join") - p.PreferJoinType = 0 - } -} - -// updateEQCond will extract the arguments of a equal condition that connect two expressions. -func (p *LogicalJoin) updateEQCond() { - lChild, rChild := p.Children()[0], p.Children()[1] - var lKeys, rKeys []expression.Expression - var lNAKeys, rNAKeys []expression.Expression - // We need two steps here: - // step1: try best to extract normal EQ condition from OtherCondition to join EqualConditions. - for i := len(p.OtherConditions) - 1; i >= 0; i-- { - need2Remove := false - if eqCond, ok := p.OtherConditions[i].(*expression.ScalarFunction); ok && eqCond.FuncName.L == ast.EQ { - // If it is a column equal condition converted from `[not] in (subq)`, do not move it - // to EqualConditions, and keep it in OtherConditions. Reference comments in `extractOnCondition` - // for detailed reasons. 
- if expression.IsEQCondFromIn(eqCond) { - continue - } - lExpr, rExpr := eqCond.GetArgs()[0], eqCond.GetArgs()[1] - if expression.ExprFromSchema(lExpr, lChild.Schema()) && expression.ExprFromSchema(rExpr, rChild.Schema()) { - lKeys = append(lKeys, lExpr) - rKeys = append(rKeys, rExpr) - need2Remove = true - } else if expression.ExprFromSchema(lExpr, rChild.Schema()) && expression.ExprFromSchema(rExpr, lChild.Schema()) { - lKeys = append(lKeys, rExpr) - rKeys = append(rKeys, lExpr) - need2Remove = true - } - } - if need2Remove { - p.OtherConditions = append(p.OtherConditions[:i], p.OtherConditions[i+1:]...) - } - } - // eg: explain select * from t1, t3 where t1.a+1 = t3.a; - // tidb only accept the join key in EqualCondition as a normal column (join OP take granted for that) - // so once we found the left and right children's schema can supply the all columns in complicated EQ condition that used by left/right key. - // we will add a layer of projection here to convert the complicated expression of EQ's left or right side to be a normal column. - adjustKeyForm := func(leftKeys, rightKeys []expression.Expression, isNA bool) { - if len(leftKeys) > 0 { - needLProj, needRProj := false, false - for i := range leftKeys { - _, lOk := leftKeys[i].(*expression.Column) - _, rOk := rightKeys[i].(*expression.Column) - needLProj = needLProj || !lOk - needRProj = needRProj || !rOk - } - - var lProj, rProj *logicalop.LogicalProjection - if needLProj { - lProj = p.getProj(0) - } - if needRProj { - rProj = p.getProj(1) - } - for i := range leftKeys { - lKey, rKey := leftKeys[i], rightKeys[i] - if lProj != nil { - lKey = lProj.AppendExpr(lKey) - } - if rProj != nil { - rKey = rProj.AppendExpr(rKey) - } - eqCond := expression.NewFunctionInternal(p.SCtx().GetExprCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), lKey, rKey) - if isNA { - p.NAEQConditions = append(p.NAEQConditions, eqCond.(*expression.ScalarFunction)) - } else { - p.EqualConditions = append(p.EqualConditions, eqCond.(*expression.ScalarFunction)) - } - } - } - } - adjustKeyForm(lKeys, rKeys, false) - - // Step2: when step1 is finished, then we can determine whether we need to extract NA-EQ from OtherCondition to NAEQConditions. - // when there are still no EqualConditions, let's try to be a NAAJ. - // todo: by now, when there is already a normal EQ condition, just keep NA-EQ as other-condition filters above it. - // eg: select * from stu where stu.name not in (select name from exam where exam.stu_id = stu.id); - // combination of and for join key is little complicated for now. - canBeNAAJ := (p.JoinType == AntiSemiJoin || p.JoinType == AntiLeftOuterSemiJoin) && len(p.EqualConditions) == 0 - if canBeNAAJ && p.SCtx().GetSessionVars().OptimizerEnableNAAJ { - var otherCond expression.CNFExprs - for i := 0; i < len(p.OtherConditions); i++ { - eqCond, ok := p.OtherConditions[i].(*expression.ScalarFunction) - if ok && eqCond.FuncName.L == ast.EQ && expression.IsEQCondFromIn(eqCond) { - // here must be a EQCondFromIn. 
- lExpr, rExpr := eqCond.GetArgs()[0], eqCond.GetArgs()[1] - if expression.ExprFromSchema(lExpr, lChild.Schema()) && expression.ExprFromSchema(rExpr, rChild.Schema()) { - lNAKeys = append(lNAKeys, lExpr) - rNAKeys = append(rNAKeys, rExpr) - } else if expression.ExprFromSchema(lExpr, rChild.Schema()) && expression.ExprFromSchema(rExpr, lChild.Schema()) { - lNAKeys = append(lNAKeys, rExpr) - rNAKeys = append(rNAKeys, lExpr) - } - continue - } - otherCond = append(otherCond, p.OtherConditions[i]) - } - p.OtherConditions = otherCond - // here is for cases like: select (a+1, b*3) not in (select a,b from t2) from t1. - adjustKeyForm(lNAKeys, rNAKeys, true) - } -} - -func (p *LogicalJoin) getProj(idx int) *logicalop.LogicalProjection { - child := p.Children()[idx] - proj, ok := child.(*logicalop.LogicalProjection) - if ok { - return proj - } - proj = logicalop.LogicalProjection{Exprs: make([]expression.Expression, 0, child.Schema().Len())}.Init(p.SCtx(), child.QueryBlockOffset()) - for _, col := range child.Schema().Columns { - proj.Exprs = append(proj.Exprs, col) - } - proj.SetSchema(child.Schema().Clone()) - proj.SetChildren(child) - p.Children()[idx] = proj - return proj -} - -// outerJoinPropConst propagates constant equal and column equal conditions over outer join. -func (p *LogicalJoin) outerJoinPropConst(predicates []expression.Expression) []expression.Expression { - outerTable := p.Children()[0] - innerTable := p.Children()[1] - if p.JoinType == RightOuterJoin { - innerTable, outerTable = outerTable, innerTable - } - lenJoinConds := len(p.EqualConditions) + len(p.LeftConditions) + len(p.RightConditions) + len(p.OtherConditions) - joinConds := make([]expression.Expression, 0, lenJoinConds) - for _, equalCond := range p.EqualConditions { - joinConds = append(joinConds, equalCond) - } - joinConds = append(joinConds, p.LeftConditions...) - joinConds = append(joinConds, p.RightConditions...) - joinConds = append(joinConds, p.OtherConditions...) 
- p.EqualConditions = nil - p.LeftConditions = nil - p.RightConditions = nil - p.OtherConditions = nil - nullSensitive := p.JoinType == AntiLeftOuterSemiJoin || p.JoinType == LeftOuterSemiJoin - joinConds, predicates = expression.PropConstOverOuterJoin(p.SCtx().GetExprCtx(), joinConds, predicates, outerTable.Schema(), innerTable.Schema(), nullSensitive) - p.AttachOnConds(joinConds) - return predicates -} diff --git a/pkg/planner/core/logical_plan_builder.go b/pkg/planner/core/logical_plan_builder.go index ec4640ba1cdd2..6d818d8397819 100644 --- a/pkg/planner/core/logical_plan_builder.go +++ b/pkg/planner/core/logical_plan_builder.go @@ -4476,13 +4476,13 @@ func (b *PlanBuilder) buildDataSource(ctx context.Context, tn *ast.TableName, as // If dynamic partition prune isn't enabled or global stats is not ready, we won't enable dynamic prune mode in query usePartitionProcessor := !isDynamicEnabled || (!globalStatsReady && !allowDynamicWithoutStats) - if val, _err_ := failpoint.Eval(_curpkg_("forceDynamicPrune")); _err_ == nil { + failpoint.Inject("forceDynamicPrune", func(val failpoint.Value) { if val.(bool) { if isDynamicEnabled { usePartitionProcessor = false } } - } + }) if usePartitionProcessor { b.optFlag = b.optFlag | flagPartitionProcessor diff --git a/pkg/planner/core/logical_plan_builder.go__failpoint_stash__ b/pkg/planner/core/logical_plan_builder.go__failpoint_stash__ deleted file mode 100644 index 6d818d8397819..0000000000000 --- a/pkg/planner/core/logical_plan_builder.go__failpoint_stash__ +++ /dev/null @@ -1,7284 +0,0 @@ -// Copyright 2016 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
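Aside: the buildDataSource hunk above restores the pingcap/failpoint marker form, replacing the expanded `if val, _err_ := failpoint.Eval(_curpkg_("forceDynamicPrune")); _err_ == nil { ... }` rewrite (the form that failpoint-ctl leaves behind) with `failpoint.Inject("forceDynamicPrune", func(val failpoint.Value) { ... })`, and deletes the accompanying `__failpoint_stash__` backup file. The sketch below is not part of this patch; it only illustrates, under the usual TiDB failpoint path convention (full package path plus failpoint name, which is what `_curpkg_` expands to), how a test would typically activate such a marker. The test name and the exact failpoint path are assumptions for illustration.

package example_test

import (
	"testing"

	"github.com/pingcap/failpoint"
)

func TestForceDynamicPruneSketch(t *testing.T) {
	// "return(true)" makes failpoint.Inject invoke its callback with val == true,
	// so the val.(bool) check inside buildDataSource takes the injected branch.
	// The path below follows the package-path + failpoint-name convention and is illustrative.
	if err := failpoint.Enable("github.com/pingcap/tidb/pkg/planner/core/forceDynamicPrune", "return(true)"); err != nil {
		t.Fatal(err)
	}
	defer func() {
		if err := failpoint.Disable("github.com/pingcap/tidb/pkg/planner/core/forceDynamicPrune"); err != nil {
			t.Fatal(err)
		}
	}()
	// ... run the plan-building code that reaches buildDataSource here ...
}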
- -package core - -import ( - "context" - "fmt" - "math" - "math/bits" - "sort" - "strconv" - "strings" - "time" - "unicode" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/errctx" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/expression/aggregation" - exprctx "github.com/pingcap/tidb/pkg/expression/context" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/charset" - "github.com/pingcap/tidb/pkg/parser/format" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/opcode" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/planner/core/base" - core_metrics "github.com/pingcap/tidb/pkg/planner/core/metrics" - "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" - "github.com/pingcap/tidb/pkg/planner/property" - "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/planner/util/coreusage" - "github.com/pingcap/tidb/pkg/planner/util/debugtrace" - "github.com/pingcap/tidb/pkg/planner/util/fixcontrol" - "github.com/pingcap/tidb/pkg/planner/util/tablesampler" - "github.com/pingcap/tidb/pkg/privilege" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/statistics" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/table/tables" - "github.com/pingcap/tidb/pkg/table/temptable" - "github.com/pingcap/tidb/pkg/types" - driver "github.com/pingcap/tidb/pkg/types/parser_driver" - util2 "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/collate" - "github.com/pingcap/tidb/pkg/util/dbterror" - "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" - "github.com/pingcap/tidb/pkg/util/hack" - h "github.com/pingcap/tidb/pkg/util/hint" - "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tidb/pkg/util/intset" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/plancodec" - "github.com/pingcap/tidb/pkg/util/set" - "github.com/pingcap/tidb/pkg/util/size" - "github.com/pingcap/tipb/go-tipb" - "go.uber.org/zap" -) - -const ( - // ErrExprInSelect is in select fields for the error of ErrFieldNotInGroupBy - ErrExprInSelect = "SELECT list" - // ErrExprInOrderBy is in order by items for the error of ErrFieldNotInGroupBy - ErrExprInOrderBy = "ORDER BY" -) - -// aggOrderByResolver is currently resolving expressions of order by clause -// in aggregate function GROUP_CONCAT. -type aggOrderByResolver struct { - ctx base.PlanContext - err error - args []ast.ExprNode - exprDepth int // exprDepth is the depth of current expression in expression tree. -} - -func (a *aggOrderByResolver) Enter(inNode ast.Node) (ast.Node, bool) { - a.exprDepth++ - if n, ok := inNode.(*driver.ParamMarkerExpr); ok { - if a.exprDepth == 1 { - _, isNull, isExpectedType := getUintFromNode(a.ctx, n, false) - // For constant uint expression in top level, it should be treated as position expression. 
- if !isNull && isExpectedType { - return expression.ConstructPositionExpr(n), true - } - } - } - return inNode, false -} - -func (a *aggOrderByResolver) Leave(inNode ast.Node) (ast.Node, bool) { - if v, ok := inNode.(*ast.PositionExpr); ok { - pos, isNull, err := expression.PosFromPositionExpr(a.ctx.GetExprCtx(), a.ctx, v) - if err != nil { - a.err = err - } - if err != nil || isNull { - return inNode, false - } - if pos < 1 || pos > len(a.args) { - errPos := strconv.Itoa(pos) - if v.P != nil { - errPos = "?" - } - a.err = plannererrors.ErrUnknownColumn.FastGenByArgs(errPos, "order clause") - return inNode, false - } - ret := a.args[pos-1] - return ret, true - } - return inNode, true -} - -func (b *PlanBuilder) buildExpand(p base.LogicalPlan, gbyItems []expression.Expression) (base.LogicalPlan, []expression.Expression, error) { - ectx := p.SCtx().GetExprCtx().GetEvalCtx() - b.optFlag |= flagResolveExpand - - // Rollup syntax require expand OP to do the data expansion, different data replica supply the different grouping layout. - distinctGbyExprs, gbyExprsRefPos := expression.DeduplicateGbyExpression(gbyItems) - // build another projection below. - proj := logicalop.LogicalProjection{Exprs: make([]expression.Expression, 0, p.Schema().Len()+len(distinctGbyExprs))}.Init(b.ctx, b.getSelectOffset()) - // project: child's output and distinct GbyExprs in advance. (make every group-by item to be a column) - projSchema := p.Schema().Clone() - names := p.OutputNames() - for _, col := range projSchema.Columns { - proj.Exprs = append(proj.Exprs, col) - } - distinctGbyColNames := make(types.NameSlice, 0, len(distinctGbyExprs)) - distinctGbyCols := make([]*expression.Column, 0, len(distinctGbyExprs)) - for _, expr := range distinctGbyExprs { - // distinct group expr has been resolved in resolveGby. - proj.Exprs = append(proj.Exprs, expr) - - // add the newly appended names. - var name *types.FieldName - if c, ok := expr.(*expression.Column); ok { - name = buildExpandFieldName(ectx, c, names[p.Schema().ColumnIndex(c)], "") - } else { - name = buildExpandFieldName(ectx, expr, nil, "") - } - names = append(names, name) - distinctGbyColNames = append(distinctGbyColNames, name) - - // since we will change the nullability of source col, proj it with a new col id. - col := &expression.Column{ - UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), - // clone it rather than using it directly, - RetType: expr.GetType(b.ctx.GetExprCtx().GetEvalCtx()).Clone(), - } - - projSchema.Append(col) - distinctGbyCols = append(distinctGbyCols, col) - } - proj.SetSchema(projSchema) - proj.SetChildren(p) - // since expand will ref original col and make some change, do the copy in executor rather than ref the same chunk.column. - proj.AvoidColumnEvaluator = true - proj.Proj4Expand = true - newGbyItems := expression.RestoreGbyExpression(distinctGbyCols, gbyExprsRefPos) - - // build expand. - rollupGroupingSets := expression.RollupGroupingSets(newGbyItems) - // eg: with rollup => {},{a},{a,b},{a,b,c} - // for every grouping set above, we should individually set those not-needed grouping-set col as null value. - // eg: let's say base schema is , d is unrelated col, keep it real in every grouping set projection. 
- // for grouping set {a,b,c}, project it as: [a, b, c, d, gid] - // for grouping set {a,b}, project it as: [a, b, null, d, gid] - // for grouping set {a}, project it as: [a, null, null, d, gid] - // for grouping set {}, project it as: [null, null, null, d, gid] - expandSchema := proj.Schema().Clone() - expression.AdjustNullabilityFromGroupingSets(rollupGroupingSets, expandSchema) - expand := LogicalExpand{ - RollupGroupingSets: rollupGroupingSets, - DistinctGroupByCol: distinctGbyCols, - DistinctGbyColNames: distinctGbyColNames, - // for resolving grouping function args. - DistinctGbyExprs: distinctGbyExprs, - - // fill the gen col names when building level projections. - }.Init(b.ctx, b.getSelectOffset()) - - // if we want to use bitAnd for the quick computation of grouping function, then the maximum capacity of num of grouping is about 64. - expand.GroupingMode = tipb.GroupingMode_ModeBitAnd - if len(expand.RollupGroupingSets) > 64 { - expand.GroupingMode = tipb.GroupingMode_ModeNumericSet - } - - expand.DistinctSize, expand.RollupGroupingIDs, expand.RollupID2GIDS = expand.RollupGroupingSets.DistinctSize() - hasDuplicateGroupingSet := len(expand.RollupGroupingSets) != expand.DistinctSize - // append the generated column for logical Expand. - tp := types.NewFieldType(mysql.TypeLonglong) - tp.SetFlag(mysql.UnsignedFlag | mysql.NotNullFlag) - gid := &expression.Column{ - UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), - RetType: tp, - OrigName: "gid", - } - expand.GID = gid - expandSchema.Append(gid) - expand.ExtraGroupingColNames = append(expand.ExtraGroupingColNames, gid.OrigName) - names = append(names, buildExpandFieldName(ectx, gid, nil, "gid_")) - expand.GIDName = names[len(names)-1] - if hasDuplicateGroupingSet { - // the last two col of the schema should be gid & gpos - gpos := &expression.Column{ - UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), - RetType: tp.Clone(), - OrigName: "gpos", - } - expand.GPos = gpos - expandSchema.Append(gpos) - expand.ExtraGroupingColNames = append(expand.ExtraGroupingColNames, gpos.OrigName) - names = append(names, buildExpandFieldName(ectx, gpos, nil, "gpos_")) - expand.GPosName = names[len(names)-1] - } - expand.SetChildren(proj) - expand.SetSchema(expandSchema) - expand.SetOutputNames(names) - - // register current rollup Expand operator in current select block. - b.currentBlockExpand = expand - - // defer generating level-projection as last logical optimization rule. - return expand, newGbyItems, nil -} - -func (b *PlanBuilder) buildAggregation(ctx context.Context, p base.LogicalPlan, aggFuncList []*ast.AggregateFuncExpr, gbyItems []expression.Expression, - correlatedAggMap map[*ast.AggregateFuncExpr]int) (base.LogicalPlan, map[int]int, error) { - b.optFlag |= flagBuildKeyInfo - b.optFlag |= flagPushDownAgg - // We may apply aggregation eliminate optimization. - // So we add the flagMaxMinEliminate to try to convert max/min to topn and flagPushDownTopN to handle the newly added topn operator. - b.optFlag |= flagMaxMinEliminate - b.optFlag |= flagPushDownTopN - // when we eliminate the max and min we may add `is not null` filter. 
- b.optFlag |= flagPredicatePushDown - b.optFlag |= flagEliminateAgg - b.optFlag |= flagEliminateProjection - - if b.ctx.GetSessionVars().EnableSkewDistinctAgg { - b.optFlag |= flagSkewDistinctAgg - } - // flag it if cte contain aggregation - if b.buildingCTE { - b.outerCTEs[len(b.outerCTEs)-1].containAggOrWindow = true - } - var rollupExpand *LogicalExpand - if expand, ok := p.(*LogicalExpand); ok { - rollupExpand = expand - } - - plan4Agg := LogicalAggregation{AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(aggFuncList))}.Init(b.ctx, b.getSelectOffset()) - if hintinfo := b.TableHints(); hintinfo != nil { - plan4Agg.PreferAggType = hintinfo.PreferAggType - plan4Agg.PreferAggToCop = hintinfo.PreferAggToCop - } - schema4Agg := expression.NewSchema(make([]*expression.Column, 0, len(aggFuncList)+p.Schema().Len())...) - names := make(types.NameSlice, 0, len(aggFuncList)+p.Schema().Len()) - // aggIdxMap maps the old index to new index after applying common aggregation functions elimination. - aggIndexMap := make(map[int]int) - - allAggsFirstRow := true - for i, aggFunc := range aggFuncList { - newArgList := make([]expression.Expression, 0, len(aggFunc.Args)) - for _, arg := range aggFunc.Args { - newArg, np, err := b.rewrite(ctx, arg, p, nil, true) - if err != nil { - return nil, nil, err - } - p = np - newArgList = append(newArgList, newArg) - } - newFunc, err := aggregation.NewAggFuncDesc(b.ctx.GetExprCtx(), aggFunc.F, newArgList, aggFunc.Distinct) - if err != nil { - return nil, nil, err - } - if newFunc.Name != ast.AggFuncFirstRow { - allAggsFirstRow = false - } - if aggFunc.Order != nil { - trueArgs := aggFunc.Args[:len(aggFunc.Args)-1] // the last argument is SEPARATOR, remote it. - resolver := &aggOrderByResolver{ - ctx: b.ctx, - args: trueArgs, - } - for _, byItem := range aggFunc.Order.Items { - resolver.exprDepth = 0 - resolver.err = nil - retExpr, _ := byItem.Expr.Accept(resolver) - if resolver.err != nil { - return nil, nil, errors.Trace(resolver.err) - } - newByItem, np, err := b.rewrite(ctx, retExpr.(ast.ExprNode), p, nil, true) - if err != nil { - return nil, nil, err - } - p = np - newFunc.OrderByItems = append(newFunc.OrderByItems, &util.ByItems{Expr: newByItem, Desc: byItem.Desc}) - } - } - // combine identical aggregate functions - combined := false - for j := 0; j < i; j++ { - oldFunc := plan4Agg.AggFuncs[aggIndexMap[j]] - if oldFunc.Equal(b.ctx.GetExprCtx().GetEvalCtx(), newFunc) { - aggIndexMap[i] = aggIndexMap[j] - combined = true - if _, ok := correlatedAggMap[aggFunc]; ok { - if _, ok = b.correlatedAggMapper[aggFuncList[j]]; !ok { - b.correlatedAggMapper[aggFuncList[j]] = &expression.CorrelatedColumn{ - Column: *schema4Agg.Columns[aggIndexMap[j]], - Data: new(types.Datum), - } - } - b.correlatedAggMapper[aggFunc] = b.correlatedAggMapper[aggFuncList[j]] - } - break - } - } - // create new columns for aggregate functions which show up first - if !combined { - position := len(plan4Agg.AggFuncs) - aggIndexMap[i] = position - plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, newFunc) - column := expression.Column{ - UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), - RetType: newFunc.RetTp, - } - schema4Agg.Append(&column) - names = append(names, types.EmptyName) - if _, ok := correlatedAggMap[aggFunc]; ok { - b.correlatedAggMapper[aggFunc] = &expression.CorrelatedColumn{ - Column: column, - Data: new(types.Datum), - } - } - } - } - for i, col := range p.Schema().Columns { - newFunc, err := aggregation.NewAggFuncDesc(b.ctx.GetExprCtx(), ast.AggFuncFirstRow, 
[]expression.Expression{col}, false) - if err != nil { - return nil, nil, err - } - plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, newFunc) - newCol, _ := col.Clone().(*expression.Column) - newCol.RetType = newFunc.RetTp - schema4Agg.Append(newCol) - names = append(names, p.OutputNames()[i]) - } - var ( - join *LogicalJoin - isJoin bool - isSelectionJoin bool - ) - join, isJoin = p.(*LogicalJoin) - selection, isSelection := p.(*LogicalSelection) - if isSelection { - join, isSelectionJoin = selection.Children()[0].(*LogicalJoin) - } - if (isJoin && join.FullSchema != nil) || (isSelectionJoin && join.FullSchema != nil) { - for i, col := range join.FullSchema.Columns { - if p.Schema().Contains(col) { - continue - } - newFunc, err := aggregation.NewAggFuncDesc(b.ctx.GetExprCtx(), ast.AggFuncFirstRow, []expression.Expression{col}, false) - if err != nil { - return nil, nil, err - } - plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, newFunc) - newCol, _ := col.Clone().(*expression.Column) - newCol.RetType = newFunc.RetTp - schema4Agg.Append(newCol) - names = append(names, join.FullNames[i]) - } - } - hasGroupBy := len(gbyItems) > 0 - for i, aggFunc := range plan4Agg.AggFuncs { - err := aggFunc.UpdateNotNullFlag4RetType(hasGroupBy, allAggsFirstRow) - if err != nil { - return nil, nil, err - } - schema4Agg.Columns[i].RetType = aggFunc.RetTp - } - plan4Agg.SetOutputNames(names) - plan4Agg.SetChildren(p) - if rollupExpand != nil { - // append gid and gpos as the group keys if any. - plan4Agg.GroupByItems = append(gbyItems, rollupExpand.GID) - if rollupExpand.GPos != nil { - plan4Agg.GroupByItems = append(plan4Agg.GroupByItems, rollupExpand.GPos) - } - } else { - plan4Agg.GroupByItems = gbyItems - } - plan4Agg.SetSchema(schema4Agg) - return plan4Agg, aggIndexMap, nil -} - -func (b *PlanBuilder) buildTableRefs(ctx context.Context, from *ast.TableRefsClause) (p base.LogicalPlan, err error) { - if from == nil { - p = b.buildTableDual() - return - } - defer func() { - // After build the resultSetNode, need to reset it so that it can be referenced by outer level. - for _, cte := range b.outerCTEs { - cte.recursiveRef = false - } - }() - return b.buildResultSetNode(ctx, from.TableRefs, false) -} - -func (b *PlanBuilder) buildResultSetNode(ctx context.Context, node ast.ResultSetNode, isCTE bool) (p base.LogicalPlan, err error) { - //If it is building the CTE queries, we will mark them. - b.isCTE = isCTE - switch x := node.(type) { - case *ast.Join: - return b.buildJoin(ctx, x) - case *ast.TableSource: - var isTableName bool - switch v := x.Source.(type) { - case *ast.SelectStmt: - ci := b.prepareCTECheckForSubQuery() - defer resetCTECheckForSubQuery(ci) - b.optFlag = b.optFlag | flagConstantPropagation - p, err = b.buildSelect(ctx, v) - case *ast.SetOprStmt: - ci := b.prepareCTECheckForSubQuery() - defer resetCTECheckForSubQuery(ci) - p, err = b.buildSetOpr(ctx, v) - case *ast.TableName: - p, err = b.buildDataSource(ctx, v, &x.AsName) - isTableName = true - default: - err = plannererrors.ErrUnsupportedType.GenWithStackByArgs(v) - } - if err != nil { - return nil, err - } - - for _, name := range p.OutputNames() { - if name.Hidden { - continue - } - if x.AsName.L != "" { - name.TblName = x.AsName - } - } - // `TableName` is not a select block, so we do not need to handle it. 
- var plannerSelectBlockAsName []ast.HintTable - if p := b.ctx.GetSessionVars().PlannerSelectBlockAsName.Load(); p != nil { - plannerSelectBlockAsName = *p - } - if len(plannerSelectBlockAsName) > 0 && !isTableName { - plannerSelectBlockAsName[p.QueryBlockOffset()] = ast.HintTable{DBName: p.OutputNames()[0].DBName, TableName: p.OutputNames()[0].TblName} - } - // Duplicate column name in one table is not allowed. - // "select * from (select 1, 1) as a;" is duplicate - dupNames := make(map[string]struct{}, len(p.Schema().Columns)) - for _, name := range p.OutputNames() { - colName := name.ColName.O - if _, ok := dupNames[colName]; ok { - return nil, plannererrors.ErrDupFieldName.GenWithStackByArgs(colName) - } - dupNames[colName] = struct{}{} - } - return p, nil - case *ast.SelectStmt: - return b.buildSelect(ctx, x) - case *ast.SetOprStmt: - return b.buildSetOpr(ctx, x) - default: - return nil, plannererrors.ErrUnsupportedType.GenWithStack("Unsupported ast.ResultSetNode(%T) for buildResultSetNode()", x) - } -} - -// extractTableAlias returns table alias of the base.LogicalPlan's columns. -// It will return nil when there are multiple table alias, because the alias is only used to check if -// the base.LogicalPlan Match some optimizer hints, and hints are not expected to take effect in this case. -func extractTableAlias(p base.Plan, parentOffset int) *h.HintedTable { - if len(p.OutputNames()) > 0 && p.OutputNames()[0].TblName.L != "" { - firstName := p.OutputNames()[0] - for _, name := range p.OutputNames() { - if name.TblName.L != firstName.TblName.L || - (name.DBName.L != "" && firstName.DBName.L != "" && name.DBName.L != firstName.DBName.L) { // DBName can be nil, see #46160 - return nil - } - } - qbOffset := p.QueryBlockOffset() - var blockAsNames []ast.HintTable - if p := p.SCtx().GetSessionVars().PlannerSelectBlockAsName.Load(); p != nil { - blockAsNames = *p - } - // For sub-queries like `(select * from t) t1`, t1 should belong to its surrounding select block. 
- if qbOffset != parentOffset && blockAsNames != nil && blockAsNames[qbOffset].TableName.L != "" { - qbOffset = parentOffset - } - dbName := firstName.DBName - if dbName.L == "" { - dbName = model.NewCIStr(p.SCtx().GetSessionVars().CurrentDB) - } - return &h.HintedTable{DBName: dbName, TblName: firstName.TblName, SelectOffset: qbOffset} - } - return nil -} - -func setPreferredJoinTypeFromOneSide(preferJoinType uint, isLeft bool) (resJoinType uint) { - if preferJoinType == 0 { - return - } - if preferJoinType&h.PreferINLJ > 0 { - preferJoinType &= ^h.PreferINLJ - if isLeft { - resJoinType |= h.PreferLeftAsINLJInner - } else { - resJoinType |= h.PreferRightAsINLJInner - } - } - if preferJoinType&h.PreferINLHJ > 0 { - preferJoinType &= ^h.PreferINLHJ - if isLeft { - resJoinType |= h.PreferLeftAsINLHJInner - } else { - resJoinType |= h.PreferRightAsINLHJInner - } - } - if preferJoinType&h.PreferINLMJ > 0 { - preferJoinType &= ^h.PreferINLMJ - if isLeft { - resJoinType |= h.PreferLeftAsINLMJInner - } else { - resJoinType |= h.PreferRightAsINLMJInner - } - } - if preferJoinType&h.PreferHJBuild > 0 { - preferJoinType &= ^h.PreferHJBuild - if isLeft { - resJoinType |= h.PreferLeftAsHJBuild - } else { - resJoinType |= h.PreferRightAsHJBuild - } - } - if preferJoinType&h.PreferHJProbe > 0 { - preferJoinType &= ^h.PreferHJProbe - if isLeft { - resJoinType |= h.PreferLeftAsHJProbe - } else { - resJoinType |= h.PreferRightAsHJProbe - } - } - resJoinType |= preferJoinType - return -} - -func (ds *DataSource) setPreferredStoreType(hintInfo *h.PlanHints) { - if hintInfo == nil { - return - } - - var alias *h.HintedTable - if len(ds.TableAsName.L) != 0 { - alias = &h.HintedTable{DBName: ds.DBName, TblName: *ds.TableAsName, SelectOffset: ds.QueryBlockOffset()} - } else { - alias = &h.HintedTable{DBName: ds.DBName, TblName: ds.TableInfo.Name, SelectOffset: ds.QueryBlockOffset()} - } - if hintTbl := hintInfo.IfPreferTiKV(alias); hintTbl != nil { - for _, path := range ds.PossibleAccessPaths { - if path.StoreType == kv.TiKV { - ds.PreferStoreType |= h.PreferTiKV - ds.PreferPartitions[h.PreferTiKV] = hintTbl.Partitions - break - } - } - if ds.PreferStoreType&h.PreferTiKV == 0 { - errMsg := fmt.Sprintf("No available path for table %s.%s with the store type %s of the hint /*+ read_from_storage */, "+ - "please check the status of the table replica and variable value of tidb_isolation_read_engines(%v)", - ds.DBName.O, ds.table.Meta().Name.O, kv.TiKV.Name(), ds.SCtx().GetSessionVars().GetIsolationReadEngines()) - ds.SCtx().GetSessionVars().StmtCtx.SetHintWarning(errMsg) - } else { - ds.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because you have set a hint to read table `" + hintTbl.TblName.O + "` from TiKV.") - } - } - if hintTbl := hintInfo.IfPreferTiFlash(alias); hintTbl != nil { - // `ds.PreferStoreType != 0`, which means there's a hint hit the both TiKV value and TiFlash value for table. - // We can't support read a table from two different storages, even partition table. 
- if ds.PreferStoreType != 0 { - ds.SCtx().GetSessionVars().StmtCtx.SetHintWarning( - fmt.Sprintf("Storage hints are conflict, you can only specify one storage type of table %s.%s", - alias.DBName.L, alias.TblName.L)) - ds.PreferStoreType = 0 - return - } - for _, path := range ds.PossibleAccessPaths { - if path.StoreType == kv.TiFlash { - ds.PreferStoreType |= h.PreferTiFlash - ds.PreferPartitions[h.PreferTiFlash] = hintTbl.Partitions - break - } - } - if ds.PreferStoreType&h.PreferTiFlash == 0 { - errMsg := fmt.Sprintf("No available path for table %s.%s with the store type %s of the hint /*+ read_from_storage */, "+ - "please check the status of the table replica and variable value of tidb_isolation_read_engines(%v)", - ds.DBName.O, ds.table.Meta().Name.O, kv.TiFlash.Name(), ds.SCtx().GetSessionVars().GetIsolationReadEngines()) - ds.SCtx().GetSessionVars().StmtCtx.SetHintWarning(errMsg) - } - } -} - -func (b *PlanBuilder) buildJoin(ctx context.Context, joinNode *ast.Join) (base.LogicalPlan, error) { - // We will construct a "Join" node for some statements like "INSERT", - // "DELETE", "UPDATE", "REPLACE". For this scenario "joinNode.Right" is nil - // and we only build the left "ResultSetNode". - if joinNode.Right == nil { - return b.buildResultSetNode(ctx, joinNode.Left, false) - } - - b.optFlag = b.optFlag | flagPredicatePushDown - // Add join reorder flag regardless of inner join or outer join. - b.optFlag = b.optFlag | flagJoinReOrder - b.optFlag |= flagPredicateSimplification - b.optFlag |= flagConvertOuterToInnerJoin - - leftPlan, err := b.buildResultSetNode(ctx, joinNode.Left, false) - if err != nil { - return nil, err - } - - rightPlan, err := b.buildResultSetNode(ctx, joinNode.Right, false) - if err != nil { - return nil, err - } - - // The recursive part in CTE must not be on the right side of a LEFT JOIN. - if lc, ok := rightPlan.(*logicalop.LogicalCTETable); ok && joinNode.Tp == ast.LeftJoin { - return nil, plannererrors.ErrCTERecursiveForbiddenJoinOrder.GenWithStackByArgs(lc.Name) - } - - handleMap1 := b.handleHelper.popMap() - handleMap2 := b.handleHelper.popMap() - b.handleHelper.mergeAndPush(handleMap1, handleMap2) - - joinPlan := LogicalJoin{StraightJoin: joinNode.StraightJoin || b.inStraightJoin}.Init(b.ctx, b.getSelectOffset()) - joinPlan.SetChildren(leftPlan, rightPlan) - joinPlan.SetSchema(expression.MergeSchema(leftPlan.Schema(), rightPlan.Schema())) - joinPlan.SetOutputNames(make([]*types.FieldName, leftPlan.Schema().Len()+rightPlan.Schema().Len())) - copy(joinPlan.OutputNames(), leftPlan.OutputNames()) - copy(joinPlan.OutputNames()[leftPlan.Schema().Len():], rightPlan.OutputNames()) - - // Set join type. - switch joinNode.Tp { - case ast.LeftJoin: - // left outer join need to be checked elimination - b.optFlag = b.optFlag | flagEliminateOuterJoin - joinPlan.JoinType = LeftOuterJoin - util.ResetNotNullFlag(joinPlan.Schema(), leftPlan.Schema().Len(), joinPlan.Schema().Len()) - case ast.RightJoin: - // right outer join need to be checked elimination - b.optFlag = b.optFlag | flagEliminateOuterJoin - joinPlan.JoinType = RightOuterJoin - util.ResetNotNullFlag(joinPlan.Schema(), 0, leftPlan.Schema().Len()) - default: - joinPlan.JoinType = InnerJoin - } - - // Merge sub-plan's FullSchema into this join plan. - // Please read the comment of LogicalJoin.FullSchema for the details. 
- var ( - lFullSchema, rFullSchema *expression.Schema - lFullNames, rFullNames types.NameSlice - ) - if left, ok := leftPlan.(*LogicalJoin); ok && left.FullSchema != nil { - lFullSchema = left.FullSchema - lFullNames = left.FullNames - } else { - lFullSchema = leftPlan.Schema() - lFullNames = leftPlan.OutputNames() - } - if right, ok := rightPlan.(*LogicalJoin); ok && right.FullSchema != nil { - rFullSchema = right.FullSchema - rFullNames = right.FullNames - } else { - rFullSchema = rightPlan.Schema() - rFullNames = rightPlan.OutputNames() - } - if joinNode.Tp == ast.RightJoin { - // Make sure lFullSchema means outer full schema and rFullSchema means inner full schema. - lFullSchema, rFullSchema = rFullSchema, lFullSchema - lFullNames, rFullNames = rFullNames, lFullNames - } - joinPlan.FullSchema = expression.MergeSchema(lFullSchema, rFullSchema) - - // Clear NotNull flag for the inner side schema if it's an outer join. - if joinNode.Tp == ast.LeftJoin || joinNode.Tp == ast.RightJoin { - util.ResetNotNullFlag(joinPlan.FullSchema, lFullSchema.Len(), joinPlan.FullSchema.Len()) - } - - // Merge sub-plan's FullNames into this join plan, similar to the FullSchema logic above. - joinPlan.FullNames = make([]*types.FieldName, 0, len(lFullNames)+len(rFullNames)) - for _, lName := range lFullNames { - name := *lName - joinPlan.FullNames = append(joinPlan.FullNames, &name) - } - for _, rName := range rFullNames { - name := *rName - joinPlan.FullNames = append(joinPlan.FullNames, &name) - } - - // Set preferred join algorithm if some join hints is specified by user. - joinPlan.SetPreferredJoinTypeAndOrder(b.TableHints()) - - // "NATURAL JOIN" doesn't have "ON" or "USING" conditions. - // - // The "NATURAL [LEFT] JOIN" of two tables is defined to be semantically - // equivalent to an "INNER JOIN" or a "LEFT JOIN" with a "USING" clause - // that names all columns that exist in both tables. - // - // See https://dev.mysql.com/doc/refman/5.7/en/join.html for more detail. - if joinNode.NaturalJoin { - err = b.buildNaturalJoin(joinPlan, leftPlan, rightPlan, joinNode) - if err != nil { - return nil, err - } - } else if joinNode.Using != nil { - err = b.buildUsingClause(joinPlan, leftPlan, rightPlan, joinNode) - if err != nil { - return nil, err - } - } else if joinNode.On != nil { - b.curClause = onClause - onExpr, newPlan, err := b.rewrite(ctx, joinNode.On.Expr, joinPlan, nil, false) - if err != nil { - return nil, err - } - if newPlan != joinPlan { - return nil, errors.New("ON condition doesn't support subqueries yet") - } - onCondition := expression.SplitCNFItems(onExpr) - // Keep these expressions as a LogicalSelection upon the inner join, in order to apply - // possible decorrelate optimizations. The ON clause is actually treated as a WHERE clause now. - if joinPlan.JoinType == InnerJoin { - sel := LogicalSelection{Conditions: onCondition}.Init(b.ctx, b.getSelectOffset()) - sel.SetChildren(joinPlan) - return sel, nil - } - joinPlan.AttachOnConds(onCondition) - } else if joinPlan.JoinType == InnerJoin { - // If a inner join without "ON" or "USING" clause, it's a cartesian - // product over the join tables. - joinPlan.CartesianJoin = true - } - - return joinPlan, nil -} - -// buildUsingClause eliminate the redundant columns and ordering columns based -// on the "USING" clause. -// -// According to the standard SQL, columns are ordered in the following way: -// 1. coalesced common columns of "leftPlan" and "rightPlan", in the order they -// appears in "leftPlan". -// 2. 
the rest columns in "leftPlan", in the order they appears in "leftPlan". -// 3. the rest columns in "rightPlan", in the order they appears in "rightPlan". -func (b *PlanBuilder) buildUsingClause(p *LogicalJoin, leftPlan, rightPlan base.LogicalPlan, join *ast.Join) error { - filter := make(map[string]bool, len(join.Using)) - for _, col := range join.Using { - filter[col.Name.L] = true - } - err := b.coalesceCommonColumns(p, leftPlan, rightPlan, join.Tp, filter) - if err != nil { - return err - } - // We do not need to coalesce columns for update and delete. - if b.inUpdateStmt || b.inDeleteStmt { - p.SetSchemaAndNames(expression.MergeSchema(p.Children()[0].Schema(), p.Children()[1].Schema()), - append(p.Children()[0].OutputNames(), p.Children()[1].OutputNames()...)) - } - return nil -} - -// buildNaturalJoin builds natural join output schema. It finds out all the common columns -// then using the same mechanism as buildUsingClause to eliminate redundant columns and build join conditions. -// According to standard SQL, producing this display order: -// -// All the common columns -// Every column in the first (left) table that is not a common column -// Every column in the second (right) table that is not a common column -func (b *PlanBuilder) buildNaturalJoin(p *LogicalJoin, leftPlan, rightPlan base.LogicalPlan, join *ast.Join) error { - err := b.coalesceCommonColumns(p, leftPlan, rightPlan, join.Tp, nil) - if err != nil { - return err - } - // We do not need to coalesce columns for update and delete. - if b.inUpdateStmt || b.inDeleteStmt { - p.SetSchemaAndNames(expression.MergeSchema(p.Children()[0].Schema(), p.Children()[1].Schema()), - append(p.Children()[0].OutputNames(), p.Children()[1].OutputNames()...)) - } - return nil -} - -// coalesceCommonColumns is used by buildUsingClause and buildNaturalJoin. The filter is used by buildUsingClause. -func (b *PlanBuilder) coalesceCommonColumns(p *LogicalJoin, leftPlan, rightPlan base.LogicalPlan, joinTp ast.JoinType, filter map[string]bool) error { - lsc := leftPlan.Schema().Clone() - rsc := rightPlan.Schema().Clone() - if joinTp == ast.LeftJoin { - util.ResetNotNullFlag(rsc, 0, rsc.Len()) - } else if joinTp == ast.RightJoin { - util.ResetNotNullFlag(lsc, 0, lsc.Len()) - } - lColumns, rColumns := lsc.Columns, rsc.Columns - lNames, rNames := leftPlan.OutputNames().Shallow(), rightPlan.OutputNames().Shallow() - if joinTp == ast.RightJoin { - leftPlan, rightPlan = rightPlan, leftPlan - lNames, rNames = rNames, lNames - lColumns, rColumns = rsc.Columns, lsc.Columns - } - - // Check using clause with ambiguous columns. - if filter != nil { - checkAmbiguous := func(names types.NameSlice) error { - columnNameInFilter := set.StringSet{} - for _, name := range names { - if _, ok := filter[name.ColName.L]; !ok { - continue - } - if columnNameInFilter.Exist(name.ColName.L) { - return plannererrors.ErrAmbiguous.GenWithStackByArgs(name.ColName.L, "from clause") - } - columnNameInFilter.Insert(name.ColName.L) - } - return nil - } - err := checkAmbiguous(lNames) - if err != nil { - return err - } - err = checkAmbiguous(rNames) - if err != nil { - return err - } - } else { - // Even with no using filter, we still should check the checkAmbiguous name before we try to find the common column from both side. - // (t3 cross join t4) natural join t1 - // t1 natural join (t3 cross join t4) - // t3 and t4 may generate the same name column from cross join. - // for every common column of natural join, the name from right or left should be exactly one. 
- commonNames := make([]string, 0, len(lNames)) - lNameMap := make(map[string]int, len(lNames)) - rNameMap := make(map[string]int, len(rNames)) - for _, name := range lNames { - // Natural join should ignore _tidb_rowid - if name.ColName.L == "_tidb_rowid" { - continue - } - // record left map - if cnt, ok := lNameMap[name.ColName.L]; ok { - lNameMap[name.ColName.L] = cnt + 1 - } else { - lNameMap[name.ColName.L] = 1 - } - } - for _, name := range rNames { - // Natural join should ignore _tidb_rowid - if name.ColName.L == "_tidb_rowid" { - continue - } - // record right map - if cnt, ok := rNameMap[name.ColName.L]; ok { - rNameMap[name.ColName.L] = cnt + 1 - } else { - rNameMap[name.ColName.L] = 1 - } - // check left map - if cnt, ok := lNameMap[name.ColName.L]; ok { - if cnt > 1 { - return plannererrors.ErrAmbiguous.GenWithStackByArgs(name.ColName.L, "from clause") - } - commonNames = append(commonNames, name.ColName.L) - } - } - // check right map - for _, commonName := range commonNames { - if rNameMap[commonName] > 1 { - return plannererrors.ErrAmbiguous.GenWithStackByArgs(commonName, "from clause") - } - } - } - - // Find out all the common columns and put them ahead. - commonLen := 0 - for i, lName := range lNames { - // Natural join should ignore _tidb_rowid - if lName.ColName.L == "_tidb_rowid" { - continue - } - for j := commonLen; j < len(rNames); j++ { - if lName.ColName.L != rNames[j].ColName.L { - continue - } - - if len(filter) > 0 { - if !filter[lName.ColName.L] { - break - } - // Mark this column exist. - filter[lName.ColName.L] = false - } - - col := lColumns[i] - copy(lColumns[commonLen+1:i+1], lColumns[commonLen:i]) - lColumns[commonLen] = col - - name := lNames[i] - copy(lNames[commonLen+1:i+1], lNames[commonLen:i]) - lNames[commonLen] = name - - col = rColumns[j] - copy(rColumns[commonLen+1:j+1], rColumns[commonLen:j]) - rColumns[commonLen] = col - - name = rNames[j] - copy(rNames[commonLen+1:j+1], rNames[commonLen:j]) - rNames[commonLen] = name - - commonLen++ - break - } - } - - if len(filter) > 0 && len(filter) != commonLen { - for col, notExist := range filter { - if notExist { - return plannererrors.ErrUnknownColumn.GenWithStackByArgs(col, "from clause") - } - } - } - - schemaCols := make([]*expression.Column, len(lColumns)+len(rColumns)-commonLen) - copy(schemaCols[:len(lColumns)], lColumns) - copy(schemaCols[len(lColumns):], rColumns[commonLen:]) - names := make(types.NameSlice, len(schemaCols)) - copy(names, lNames) - copy(names[len(lNames):], rNames[commonLen:]) - - conds := make([]expression.Expression, 0, commonLen) - for i := 0; i < commonLen; i++ { - lc, rc := lsc.Columns[i], rsc.Columns[i] - cond, err := expression.NewFunction(b.ctx.GetExprCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), lc, rc) - if err != nil { - return err - } - conds = append(conds, cond) - if p.FullSchema != nil { - // since FullSchema is derived from left and right schema in upper layer, so rc/lc must be in FullSchema. - if joinTp == ast.RightJoin { - p.FullNames[p.FullSchema.ColumnIndex(lc)].Redundant = true - } else { - p.FullNames[p.FullSchema.ColumnIndex(rc)].Redundant = true - } - } - } - - p.SetSchema(expression.NewSchema(schemaCols...)) - p.SetOutputNames(names) - - p.OtherConditions = append(conds, p.OtherConditions...) 
- - return nil -} - -func (b *PlanBuilder) buildSelection(ctx context.Context, p base.LogicalPlan, where ast.ExprNode, aggMapper map[*ast.AggregateFuncExpr]int) (base.LogicalPlan, error) { - b.optFlag |= flagPredicatePushDown - b.optFlag |= flagDeriveTopNFromWindow - b.optFlag |= flagPredicateSimplification - if b.curClause != havingClause { - b.curClause = whereClause - } - - conditions := splitWhere(where) - expressions := make([]expression.Expression, 0, len(conditions)) - selection := LogicalSelection{}.Init(b.ctx, b.getSelectOffset()) - for _, cond := range conditions { - expr, np, err := b.rewrite(ctx, cond, p, aggMapper, false) - if err != nil { - return nil, err - } - // for case: explain SELECT year+2 as y, SUM(profit) AS profit FROM sales GROUP BY year+2, year+profit WITH ROLLUP having y > 2002; - // currently, we succeed to resolve y to (year+2), but fail to resolve (year+2) to grouping col, and to base column function: plus(year, 2) instead. - // which will cause this selection being pushed down through Expand OP itself. - // - // In expand, we will additionally project (year+2) out as a new column, let's say grouping_col here, and we wanna it can substitute any upper layer's (year+2) - expr = b.replaceGroupingFunc(expr) - - p = np - if expr == nil { - continue - } - expressions = append(expressions, expr) - } - cnfExpres := make([]expression.Expression, 0) - useCache := b.ctx.GetSessionVars().StmtCtx.UseCache() - for _, expr := range expressions { - cnfItems := expression.SplitCNFItems(expr) - for _, item := range cnfItems { - if con, ok := item.(*expression.Constant); ok && expression.ConstExprConsiderPlanCache(con, useCache) { - ret, _, err := expression.EvalBool(b.ctx.GetExprCtx().GetEvalCtx(), expression.CNFExprs{con}, chunk.Row{}) - if err != nil { - return nil, errors.Trace(err) - } - if ret { - continue - } - // If there is condition which is always false, return dual plan directly. - dual := logicalop.LogicalTableDual{}.Init(b.ctx, b.getSelectOffset()) - dual.SetOutputNames(p.OutputNames()) - dual.SetSchema(p.Schema()) - return dual, nil - } - cnfExpres = append(cnfExpres, item) - } - } - if len(cnfExpres) == 0 { - return p, nil - } - evalCtx := b.ctx.GetExprCtx().GetEvalCtx() - // check expr field types. - for i, expr := range cnfExpres { - if expr.GetType(evalCtx).EvalType() == types.ETString { - tp := &types.FieldType{} - tp.SetType(mysql.TypeDouble) - tp.SetFlag(expr.GetType(evalCtx).GetFlag()) - tp.SetFlen(mysql.MaxRealWidth) - tp.SetDecimal(types.UnspecifiedLength) - types.SetBinChsClnFlag(tp) - cnfExpres[i] = expression.TryPushCastIntoControlFunctionForHybridType(b.ctx.GetExprCtx(), expr, tp) - } - } - selection.Conditions = cnfExpres - selection.SetChildren(p) - return selection, nil -} - -// buildProjectionFieldNameFromColumns builds the field name, table name and database name when field expression is a column reference. -func (*PlanBuilder) buildProjectionFieldNameFromColumns(origField *ast.SelectField, colNameField *ast.ColumnNameExpr, name *types.FieldName) (colName, origColName, tblName, origTblName, dbName model.CIStr) { - origTblName, origColName, dbName = name.OrigTblName, name.OrigColName, name.DBName - if origField.AsName.L == "" { - colName = colNameField.Name.Name - } else { - colName = origField.AsName - } - if tblName.L == "" { - tblName = name.TblName - } else { - tblName = colNameField.Name.Table - } - return -} - -// buildProjectionFieldNameFromExpressions builds the field name when field expression is a normal expression. 
-func (b *PlanBuilder) buildProjectionFieldNameFromExpressions(_ context.Context, field *ast.SelectField) (model.CIStr, error) { - if agg, ok := field.Expr.(*ast.AggregateFuncExpr); ok && agg.F == ast.AggFuncFirstRow { - // When the query is select t.a from t group by a; The Column Name should be a but not t.a; - return agg.Args[0].(*ast.ColumnNameExpr).Name.Name, nil - } - - innerExpr := getInnerFromParenthesesAndUnaryPlus(field.Expr) - funcCall, isFuncCall := innerExpr.(*ast.FuncCallExpr) - // When used to produce a result set column, NAME_CONST() causes the column to have the given name. - // See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_name-const for details - if isFuncCall && funcCall.FnName.L == ast.NameConst { - if v, err := evalAstExpr(b.ctx.GetExprCtx(), funcCall.Args[0]); err == nil { - if s, err := v.ToString(); err == nil { - return model.NewCIStr(s), nil - } - } - return model.NewCIStr(""), plannererrors.ErrWrongArguments.GenWithStackByArgs("NAME_CONST") - } - valueExpr, isValueExpr := innerExpr.(*driver.ValueExpr) - - // Non-literal: Output as inputed, except that comments need to be removed. - if !isValueExpr { - return model.NewCIStr(parser.SpecFieldPattern.ReplaceAllStringFunc(field.Text(), parser.TrimComment)), nil - } - - // Literal: Need special processing - switch valueExpr.Kind() { - case types.KindString: - projName := valueExpr.GetString() - projOffset := valueExpr.GetProjectionOffset() - if projOffset >= 0 { - projName = projName[:projOffset] - } - // See #3686, #3994: - // For string literals, string content is used as column name. Non-graph initial characters are trimmed. - fieldName := strings.TrimLeftFunc(projName, func(r rune) bool { - return !unicode.IsOneOf(mysql.RangeGraph, r) - }) - return model.NewCIStr(fieldName), nil - case types.KindNull: - // See #4053, #3685 - return model.NewCIStr("NULL"), nil - case types.KindBinaryLiteral: - // Don't rewrite BIT literal or HEX literals - return model.NewCIStr(field.Text()), nil - case types.KindInt64: - // See #9683 - // TRUE or FALSE can be a int64 - if mysql.HasIsBooleanFlag(valueExpr.Type.GetFlag()) { - if i := valueExpr.GetValue().(int64); i == 0 { - return model.NewCIStr("FALSE"), nil - } - return model.NewCIStr("TRUE"), nil - } - fallthrough - - default: - fieldName := field.Text() - fieldName = strings.TrimLeft(fieldName, "\t\n +(") - fieldName = strings.TrimRight(fieldName, "\t\n )") - return model.NewCIStr(fieldName), nil - } -} - -func buildExpandFieldName(ctx expression.EvalContext, expr expression.Expression, name *types.FieldName, genName string) *types.FieldName { - _, isCol := expr.(*expression.Column) - var origTblName, origColName, dbName, colName, tblName model.CIStr - if genName != "" { - // for case like: gid_, gpos_ - colName = model.NewCIStr(expr.StringWithCtx(ctx, errors.RedactLogDisable)) - } else if isCol { - // col ref to original col, while its nullability may be changed. - origTblName, origColName, dbName = name.OrigTblName, name.OrigColName, name.DBName - colName = model.NewCIStr("ex_" + name.ColName.O) - tblName = model.NewCIStr("ex_" + name.TblName.O) - } else { - // Other: complicated expression. - colName = model.NewCIStr("ex_" + expr.StringWithCtx(ctx, errors.RedactLogDisable)) - } - newName := &types.FieldName{ - TblName: tblName, - OrigTblName: origTblName, - ColName: colName, - OrigColName: origColName, - DBName: dbName, - } - return newName -} - -// buildProjectionField builds the field object according to SelectField in projection. 
-func (b *PlanBuilder) buildProjectionField(ctx context.Context, p base.LogicalPlan, field *ast.SelectField, expr expression.Expression) (*expression.Column, *types.FieldName, error) { - var origTblName, tblName, origColName, colName, dbName model.CIStr - innerNode := getInnerFromParenthesesAndUnaryPlus(field.Expr) - col, isCol := expr.(*expression.Column) - // Correlated column won't affect the final output names. So we can put it in any of the three logic block. - // Don't put it into the first block just for simplifying the codes. - if colNameField, ok := innerNode.(*ast.ColumnNameExpr); ok && isCol { - // Field is a column reference. - idx := p.Schema().ColumnIndex(col) - var name *types.FieldName - // The column maybe the one from join's redundant part. - if idx == -1 { - name = findColFromNaturalUsingJoin(p, col) - } else { - name = p.OutputNames()[idx] - } - colName, origColName, tblName, origTblName, dbName = b.buildProjectionFieldNameFromColumns(field, colNameField, name) - } else if field.AsName.L != "" { - // Field has alias. - colName = field.AsName - } else { - // Other: field is an expression. - var err error - if colName, err = b.buildProjectionFieldNameFromExpressions(ctx, field); err != nil { - return nil, nil, err - } - } - name := &types.FieldName{ - TblName: tblName, - OrigTblName: origTblName, - ColName: colName, - OrigColName: origColName, - DBName: dbName, - } - if isCol { - return col, name, nil - } - if expr == nil { - return nil, name, nil - } - // invalid unique id - correlatedColUniqueID := int64(0) - if cc, ok := expr.(*expression.CorrelatedColumn); ok { - correlatedColUniqueID = cc.UniqueID - } - // for expr projection, we should record the map relationship down. - newCol := &expression.Column{ - UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), - RetType: expr.GetType(b.ctx.GetExprCtx().GetEvalCtx()), - CorrelatedColUniqueID: correlatedColUniqueID, - } - if b.ctx.GetSessionVars().OptimizerEnableNewOnlyFullGroupByCheck { - if b.ctx.GetSessionVars().MapHashCode2UniqueID4ExtendedCol == nil { - b.ctx.GetSessionVars().MapHashCode2UniqueID4ExtendedCol = make(map[string]int, 1) - } - b.ctx.GetSessionVars().MapHashCode2UniqueID4ExtendedCol[string(expr.HashCode())] = int(newCol.UniqueID) - } - newCol.SetCoercibility(expr.Coercibility()) - return newCol, name, nil -} - -type userVarTypeProcessor struct { - ctx context.Context - plan base.LogicalPlan - builder *PlanBuilder - mapper map[*ast.AggregateFuncExpr]int - err error -} - -func (p *userVarTypeProcessor) Enter(in ast.Node) (ast.Node, bool) { - v, ok := in.(*ast.VariableExpr) - if !ok { - return in, false - } - if v.IsSystem || v.Value == nil { - return in, true - } - _, p.plan, p.err = p.builder.rewrite(p.ctx, v, p.plan, p.mapper, true) - return in, true -} - -func (p *userVarTypeProcessor) Leave(in ast.Node) (ast.Node, bool) { - return in, p.err == nil -} - -func (b *PlanBuilder) preprocessUserVarTypes(ctx context.Context, p base.LogicalPlan, fields []*ast.SelectField, mapper map[*ast.AggregateFuncExpr]int) error { - aggMapper := make(map[*ast.AggregateFuncExpr]int) - for agg, i := range mapper { - aggMapper[agg] = i - } - processor := userVarTypeProcessor{ - ctx: ctx, - plan: p, - builder: b, - mapper: aggMapper, - } - for _, field := range fields { - field.Expr.Accept(&processor) - if processor.err != nil { - return processor.err - } - } - return nil -} - -// findColFromNaturalUsingJoin is used to recursively find the column from the -// underlying natural-using-join. -// e.g. 
For SQL like `select t2.a from t1 join t2 using(a) where t2.a > 0`, the -// plan will be `join->selection->projection`. The schema of the `selection` -// will be `[t1.a]`, thus we need to recursively retrieve the `t2.a` from the -// underlying join. -func findColFromNaturalUsingJoin(p base.LogicalPlan, col *expression.Column) (name *types.FieldName) { - switch x := p.(type) { - case *logicalop.LogicalLimit, *LogicalSelection, *logicalop.LogicalTopN, *logicalop.LogicalSort, *logicalop.LogicalMaxOneRow: - return findColFromNaturalUsingJoin(p.Children()[0], col) - case *LogicalJoin: - if x.FullSchema != nil { - idx := x.FullSchema.ColumnIndex(col) - return x.FullNames[idx] - } - } - return nil -} - -type resolveGroupingTraverseAction struct { - CurrentBlockExpand *LogicalExpand -} - -func (r resolveGroupingTraverseAction) Transform(expr expression.Expression) (res expression.Expression) { - switch x := expr.(type) { - case *expression.Column: - // when meeting a column, judge whether it's a relate grouping set col. - // eg: select a, b from t group by a, c with rollup, here a is, while b is not. - // in underlying Expand schema (a,b,c,a',c'), a select list should be resolved to a'. - res, _ = r.CurrentBlockExpand.trySubstituteExprWithGroupingSetCol(x) - case *expression.CorrelatedColumn: - // select 1 in (select t2.a from t group by t2.a, b with rollup) from t2; - // in this case: group by item has correlated column t2.a, and it's select list contains t2.a as well. - res, _ = r.CurrentBlockExpand.trySubstituteExprWithGroupingSetCol(x) - case *expression.Constant: - // constant just keep it real: select 1 from t group by a, b with rollup. - res = x - case *expression.ScalarFunction: - // scalar function just try to resolve itself first, then if not changed, trying resolve its children. - var substituted bool - res, substituted = r.CurrentBlockExpand.trySubstituteExprWithGroupingSetCol(x) - if !substituted { - // if not changed, try to resolve it children. - // select a+1, grouping(b) from t group by a+1 (projected as c), b with rollup: in this case, a+1 is resolved as c as a whole. - // select a+1, grouping(b) from t group by a(projected as a'), b with rollup : in this case, a+1 is resolved as a'+ 1. - newArgs := x.GetArgs() - for i, arg := range newArgs { - newArgs[i] = r.Transform(arg) - } - res = x - } - default: - res = expr - } - return res -} - -func (b *PlanBuilder) replaceGroupingFunc(expr expression.Expression) expression.Expression { - // current block doesn't have an expand OP, just return it. - if b.currentBlockExpand == nil { - return expr - } - // curExpand can supply the DistinctGbyExprs and gid col. - traverseAction := resolveGroupingTraverseAction{CurrentBlockExpand: b.currentBlockExpand} - return expr.Traverse(traverseAction) -} - -func (b *PlanBuilder) implicitProjectGroupingSetCols(projSchema *expression.Schema, projNames []*types.FieldName, projExprs []expression.Expression) (*expression.Schema, []*types.FieldName, []expression.Expression) { - if b.currentBlockExpand == nil { - return projSchema, projNames, projExprs - } - m := make(map[int64]struct{}, len(b.currentBlockExpand.DistinctGroupByCol)) - for _, col := range projSchema.Columns { - m[col.UniqueID] = struct{}{} - } - for idx, gCol := range b.currentBlockExpand.DistinctGroupByCol { - if _, ok := m[gCol.UniqueID]; ok { - // grouping col has been explicitly projected, not need to reserve it here for later order-by item (a+1) - // like: select a+1, b from t group by a+1 order by a+1. 
- continue - } - // project the grouping col out implicitly here. If it's not used by later OP, it will be cleaned in column pruner. - projSchema.Append(gCol) - projExprs = append(projExprs, gCol) - projNames = append(projNames, b.currentBlockExpand.DistinctGbyColNames[idx]) - } - // project GID. - projSchema.Append(b.currentBlockExpand.GID) - projExprs = append(projExprs, b.currentBlockExpand.GID) - projNames = append(projNames, b.currentBlockExpand.GIDName) - // project GPos if any. - if b.currentBlockExpand.GPos != nil { - projSchema.Append(b.currentBlockExpand.GPos) - projExprs = append(projExprs, b.currentBlockExpand.GPos) - projNames = append(projNames, b.currentBlockExpand.GPosName) - } - return projSchema, projNames, projExprs -} - -// buildProjection returns a Projection plan and non-aux columns length. -func (b *PlanBuilder) buildProjection(ctx context.Context, p base.LogicalPlan, fields []*ast.SelectField, mapper map[*ast.AggregateFuncExpr]int, - windowMapper map[*ast.WindowFuncExpr]int, considerWindow bool, expandGenerateColumn bool) (base.LogicalPlan, []expression.Expression, int, error) { - err := b.preprocessUserVarTypes(ctx, p, fields, mapper) - if err != nil { - return nil, nil, 0, err - } - b.optFlag |= flagEliminateProjection - b.curClause = fieldList - proj := logicalop.LogicalProjection{Exprs: make([]expression.Expression, 0, len(fields))}.Init(b.ctx, b.getSelectOffset()) - schema := expression.NewSchema(make([]*expression.Column, 0, len(fields))...) - oldLen := 0 - newNames := make([]*types.FieldName, 0, len(fields)) - for i, field := range fields { - if !field.Auxiliary { - oldLen++ - } - - isWindowFuncField := ast.HasWindowFlag(field.Expr) - // Although window functions occurs in the select fields, but it has to be processed after having clause. - // So when we build the projection for select fields, we need to skip the window function. - // When `considerWindow` is false, we will only build fields for non-window functions, so we add fake placeholders. - // for window functions. These fake placeholders will be erased in column pruning. - // When `considerWindow` is true, all the non-window fields have been built, so we just use the schema columns. - if considerWindow && !isWindowFuncField { - col := p.Schema().Columns[i] - proj.Exprs = append(proj.Exprs, col) - schema.Append(col) - newNames = append(newNames, p.OutputNames()[i]) - continue - } else if !considerWindow && isWindowFuncField { - expr := expression.NewZero() - proj.Exprs = append(proj.Exprs, expr) - col, name, err := b.buildProjectionField(ctx, p, field, expr) - if err != nil { - return nil, nil, 0, err - } - schema.Append(col) - newNames = append(newNames, name) - continue - } - newExpr, np, err := b.rewriteWithPreprocess(ctx, field.Expr, p, mapper, windowMapper, true, nil) - if err != nil { - return nil, nil, 0, err - } - - // for case: select a+1, b, sum(b), grouping(a) from t group by a, b with rollup. - // the column inside aggregate (only sum(b) here) should be resolved to original source column, - // while for others, just use expanded columns if exists: a'+ 1, b', group(gid) - newExpr = b.replaceGroupingFunc(newExpr) - - // For window functions in the order by clause, we will append an field for it. - // We need rewrite the window mapper here so order by clause could find the added field. 
- if considerWindow && isWindowFuncField && field.Auxiliary { - if windowExpr, ok := field.Expr.(*ast.WindowFuncExpr); ok { - windowMapper[windowExpr] = i - } - } - - p = np - proj.Exprs = append(proj.Exprs, newExpr) - - col, name, err := b.buildProjectionField(ctx, p, field, newExpr) - if err != nil { - return nil, nil, 0, err - } - schema.Append(col) - newNames = append(newNames, name) - } - // implicitly project expand grouping set cols, if not used later, it will being pruned out in logical column pruner. - schema, newNames, proj.Exprs = b.implicitProjectGroupingSetCols(schema, newNames, proj.Exprs) - - proj.SetSchema(schema) - proj.SetOutputNames(newNames) - if expandGenerateColumn { - // Sometimes we need to add some fields to the projection so that we can use generate column substitute - // optimization. For example: select a+1 from t order by a+1, with a virtual generate column c as (a+1) and - // an index on c. We need to add c into the projection so that we can replace a+1 with c. - exprToColumn := make(ExprColumnMap) - collectGenerateColumn(p, exprToColumn) - for expr, col := range exprToColumn { - idx := p.Schema().ColumnIndex(col) - if idx == -1 { - continue - } - if proj.Schema().Contains(col) { - continue - } - proj.Schema().Columns = append(proj.Schema().Columns, col) - proj.Exprs = append(proj.Exprs, expr) - proj.SetOutputNames(append(proj.OutputNames(), p.OutputNames()[idx])) - } - } - proj.SetChildren(p) - // delay the only-full-group-by-check in create view statement to later query. - if !b.isCreateView && b.ctx.GetSessionVars().OptimizerEnableNewOnlyFullGroupByCheck && b.ctx.GetSessionVars().SQLMode.HasOnlyFullGroupBy() { - fds := proj.ExtractFD() - // Projection -> Children -> ... - // Let the projection itself to evaluate the whole FD, which will build the connection - // 1: from select-expr to registered-expr - // 2: from base-column to select-expr - // After that - if fds.HasAggBuilt { - for offset, expr := range proj.Exprs[:len(fields)] { - // skip the auxiliary column in agg appended to select fields, which mainly comes from two kind of cases: - // 1: having agg(t.a), this will append t.a to the select fields, if it isn't here. - // 2: order by agg(t.a), this will append t.a to the select fields, if it isn't here. - if fields[offset].AuxiliaryColInAgg { - continue - } - item := intset.NewFastIntSet() - switch x := expr.(type) { - case *expression.Column: - item.Insert(int(x.UniqueID)) - case *expression.ScalarFunction: - if expression.CheckFuncInExpr(x, ast.AnyValue) { - continue - } - scalarUniqueID, ok := fds.IsHashCodeRegistered(string(hack.String(x.HashCode()))) - if !ok { - logutil.BgLogger().Warn("Error occurred while maintaining the functional dependency") - continue - } - item.Insert(scalarUniqueID) - default: - } - // Rule #1, if there are no group cols, the col in the order by shouldn't be limited. - if fds.GroupByCols.Only1Zero() && fields[offset].AuxiliaryColInOrderBy { - continue - } - - // Rule #2, if select fields are constant, it's ok. - if item.SubsetOf(fds.ConstantCols()) { - continue - } - - // Rule #3, if select fields are subset of group by items, it's ok. - if item.SubsetOf(fds.GroupByCols) { - continue - } - - // Rule #4, if select fields are dependencies of Strict FD with determinants in group-by items, it's ok. - // lax FD couldn't be done here, eg: for unique key (b), index key NULL & NULL are different rows with - // uncertain other column values. 
- strictClosure := fds.ClosureOfStrict(fds.GroupByCols) - if item.SubsetOf(strictClosure) { - continue - } - // locate the base col that are not in (constant list / group by list / strict fd closure) for error show. - baseCols := expression.ExtractColumns(expr) - errShowCol := baseCols[0] - for _, col := range baseCols { - colSet := intset.NewFastIntSet(int(col.UniqueID)) - if !colSet.SubsetOf(strictClosure) { - errShowCol = col - break - } - } - // better use the schema alias name firstly if any. - name := "" - for idx, schemaCol := range proj.Schema().Columns { - if schemaCol.UniqueID == errShowCol.UniqueID { - name = proj.OutputNames()[idx].String() - break - } - } - if name == "" { - name = errShowCol.OrigName - } - // Only1Zero is to judge whether it's no-group-by-items case. - if !fds.GroupByCols.Only1Zero() { - return nil, nil, 0, plannererrors.ErrFieldNotInGroupBy.GenWithStackByArgs(offset+1, ErrExprInSelect, name) - } - return nil, nil, 0, plannererrors.ErrMixOfGroupFuncAndFields.GenWithStackByArgs(offset+1, name) - } - if fds.GroupByCols.Only1Zero() { - // maxOneRow is delayed from agg's ExtractFD logic since some details listed in it. - projectionUniqueIDs := intset.NewFastIntSet() - for _, expr := range proj.Exprs { - switch x := expr.(type) { - case *expression.Column: - projectionUniqueIDs.Insert(int(x.UniqueID)) - case *expression.ScalarFunction: - scalarUniqueID, ok := fds.IsHashCodeRegistered(string(hack.String(x.HashCode()))) - if !ok { - logutil.BgLogger().Warn("Error occurred while maintaining the functional dependency") - continue - } - projectionUniqueIDs.Insert(scalarUniqueID) - } - } - fds.MaxOneRow(projectionUniqueIDs) - } - // for select * from view (include agg), outer projection don't have to check select list with the inner group-by flag. - fds.HasAggBuilt = false - } - } - return proj, proj.Exprs, oldLen, nil -} - -func (b *PlanBuilder) buildDistinct(child base.LogicalPlan, length int) (*LogicalAggregation, error) { - b.optFlag = b.optFlag | flagBuildKeyInfo - b.optFlag = b.optFlag | flagPushDownAgg - plan4Agg := LogicalAggregation{ - AggFuncs: make([]*aggregation.AggFuncDesc, 0, child.Schema().Len()), - GroupByItems: expression.Column2Exprs(child.Schema().Clone().Columns[:length]), - }.Init(b.ctx, child.QueryBlockOffset()) - if hintinfo := b.TableHints(); hintinfo != nil { - plan4Agg.PreferAggType = hintinfo.PreferAggType - plan4Agg.PreferAggToCop = hintinfo.PreferAggToCop - } - for _, col := range child.Schema().Columns { - aggDesc, err := aggregation.NewAggFuncDesc(b.ctx.GetExprCtx(), ast.AggFuncFirstRow, []expression.Expression{col}, false) - if err != nil { - return nil, err - } - plan4Agg.AggFuncs = append(plan4Agg.AggFuncs, aggDesc) - } - plan4Agg.SetChildren(child) - plan4Agg.SetSchema(child.Schema().Clone()) - plan4Agg.SetOutputNames(child.OutputNames()) - // Distinct will be rewritten as first_row, we reset the type here since the return type - // of first_row is not always the same as the column arg of first_row. - for i, col := range plan4Agg.Schema().Columns { - col.RetType = plan4Agg.AggFuncs[i].RetTp - } - return plan4Agg, nil -} - -// unionJoinFieldType finds the type which can carry the given types in Union. -// Note that unionJoinFieldType doesn't handle charset and collation, caller need to handle it by itself. -func unionJoinFieldType(a, b *types.FieldType) *types.FieldType { - // We ignore the pure NULL type. 
- if a.GetType() == mysql.TypeNull { - return b - } else if b.GetType() == mysql.TypeNull { - return a - } - resultTp := types.AggFieldType([]*types.FieldType{a, b}) - // This logic will be intelligible when it is associated with the buildProjection4Union logic. - if resultTp.GetType() == mysql.TypeNewDecimal { - // The decimal result type will be unsigned only when all the decimals to be united are unsigned. - resultTp.AndFlag(b.GetFlag() & mysql.UnsignedFlag) - } else { - // Non-decimal results will be unsigned when a,b both unsigned. - // ref1: https://dev.mysql.com/doc/refman/5.7/en/union.html#union-result-set - // ref2: https://github.com/pingcap/tidb/issues/24953 - resultTp.AddFlag((a.GetFlag() & mysql.UnsignedFlag) & (b.GetFlag() & mysql.UnsignedFlag)) - } - resultTp.SetDecimalUnderLimit(max(a.GetDecimal(), b.GetDecimal())) - // `flen - decimal` is the fraction before '.' - if a.GetFlen() == -1 || b.GetFlen() == -1 { - resultTp.SetFlenUnderLimit(-1) - } else { - resultTp.SetFlenUnderLimit(max(a.GetFlen()-a.GetDecimal(), b.GetFlen()-b.GetDecimal()) + resultTp.GetDecimal()) - } - types.TryToFixFlenOfDatetime(resultTp) - if resultTp.EvalType() != types.ETInt && (a.EvalType() == types.ETInt || b.EvalType() == types.ETInt) && resultTp.GetFlen() < mysql.MaxIntWidth { - resultTp.SetFlen(mysql.MaxIntWidth) - } - expression.SetBinFlagOrBinStr(b, resultTp) - return resultTp -} - -// Set the flen of the union column using the max flen in children. -func (b *PlanBuilder) setUnionFlen(resultTp *types.FieldType, cols []expression.Expression) { - if resultTp.GetFlen() == -1 { - return - } - isBinary := resultTp.GetCharset() == charset.CharsetBin - for i := 0; i < len(cols); i++ { - childTp := cols[i].GetType(b.ctx.GetExprCtx().GetEvalCtx()) - childTpCharLen := 1 - if isBinary { - if charsetInfo, ok := charset.CharacterSetInfos[childTp.GetCharset()]; ok { - childTpCharLen = charsetInfo.Maxlen - } - } - resultTp.SetFlen(max(resultTp.GetFlen(), childTpCharLen*childTp.GetFlen())) - } -} - -func (b *PlanBuilder) buildProjection4Union(_ context.Context, u *LogicalUnionAll) error { - unionCols := make([]*expression.Column, 0, u.Children()[0].Schema().Len()) - names := make([]*types.FieldName, 0, u.Children()[0].Schema().Len()) - - // Infer union result types by its children's schema. - for i, col := range u.Children()[0].Schema().Columns { - tmpExprs := make([]expression.Expression, 0, len(u.Children())) - tmpExprs = append(tmpExprs, col) - resultTp := col.RetType - for j := 1; j < len(u.Children()); j++ { - tmpExprs = append(tmpExprs, u.Children()[j].Schema().Columns[i]) - childTp := u.Children()[j].Schema().Columns[i].RetType - resultTp = unionJoinFieldType(resultTp, childTp) - } - collation, err := expression.CheckAndDeriveCollationFromExprs(b.ctx.GetExprCtx(), "UNION", resultTp.EvalType(), tmpExprs...) - if err != nil || collation.Coer == expression.CoercibilityNone { - return collate.ErrIllegalMixCollation.GenWithStackByArgs("UNION") - } - resultTp.SetCharset(collation.Charset) - resultTp.SetCollate(collation.Collation) - b.setUnionFlen(resultTp, tmpExprs) - names = append(names, &types.FieldName{ColName: u.Children()[0].OutputNames()[i].ColName}) - unionCols = append(unionCols, &expression.Column{ - RetType: resultTp, - UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), - }) - } - u.SetSchema(expression.NewSchema(unionCols...)) - u.SetOutputNames(names) - // Process each child and add a projection above original child. - // So the schema of `UnionAll` can be the same with its children's. 
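// Illustration only (not part of this patch): a simplified, self-contained sketch of
// the UNION type-widening rules documented for unionJoinFieldType above, using a toy
// fieldType instead of types.FieldType. The result keeps the larger fraction, keeps
// enough integer digits for both sides, and stays unsigned only when both inputs are
// unsigned. Charset, collation and the int-width adjustment are deliberately omitted.
package main

import "fmt"

type fieldType struct {
	flen     int // total display width
	decimal  int // digits after the decimal point
	unsigned bool
}

func unionFieldType(a, b fieldType) fieldType {
	res := fieldType{
		decimal:  maxInt(a.decimal, b.decimal),
		unsigned: a.unsigned && b.unsigned,
	}
	// `flen - decimal` is the integer part; keep the wider one and re-attach the fraction.
	res.flen = maxInt(a.flen-a.decimal, b.flen-b.decimal) + res.decimal
	return res
}

func maxInt(x, y int) int {
	if x > y {
		return x
	}
	return y
}

func main() {
	// DECIMAL(10,2) UNION DECIMAL(8,4) => DECIMAL(12,4)
	fmt.Println(unionFieldType(fieldType{10, 2, false}, fieldType{8, 4, false}))
}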
- for childID, child := range u.Children() { - exprs := make([]expression.Expression, len(child.Schema().Columns)) - for i, srcCol := range child.Schema().Columns { - dstType := unionCols[i].RetType - srcType := srcCol.RetType - if !srcType.Equal(dstType) { - exprs[i] = expression.BuildCastFunction4Union(b.ctx.GetExprCtx(), srcCol, dstType) - } else { - exprs[i] = srcCol - } - } - b.optFlag |= flagEliminateProjection - proj := logicalop.LogicalProjection{Exprs: exprs, AvoidColumnEvaluator: true}.Init(b.ctx, b.getSelectOffset()) - proj.SetSchema(u.Schema().Clone()) - // reset the schema type to make the "not null" flag right. - for i, expr := range exprs { - proj.Schema().Columns[i].RetType = expr.GetType(b.ctx.GetExprCtx().GetEvalCtx()) - } - proj.SetChildren(child) - u.Children()[childID] = proj - } - return nil -} - -func (b *PlanBuilder) buildSetOpr(ctx context.Context, setOpr *ast.SetOprStmt) (base.LogicalPlan, error) { - if setOpr.With != nil { - l := len(b.outerCTEs) - defer func() { - b.outerCTEs = b.outerCTEs[:l] - }() - _, err := b.buildWith(ctx, setOpr.With) - if err != nil { - return nil, err - } - } - - // Because INTERSECT has higher precedence than UNION and EXCEPT. We build it first. - selectPlans := make([]base.LogicalPlan, 0, len(setOpr.SelectList.Selects)) - afterSetOprs := make([]*ast.SetOprType, 0, len(setOpr.SelectList.Selects)) - selects := setOpr.SelectList.Selects - for i := 0; i < len(selects); i++ { - intersects := []ast.Node{selects[i]} - for i+1 < len(selects) { - breakIteration := false - switch x := selects[i+1].(type) { - case *ast.SelectStmt: - if *x.AfterSetOperator != ast.Intersect && *x.AfterSetOperator != ast.IntersectAll { - breakIteration = true - } - case *ast.SetOprSelectList: - if *x.AfterSetOperator != ast.Intersect && *x.AfterSetOperator != ast.IntersectAll { - breakIteration = true - } - if x.Limit != nil || x.OrderBy != nil { - // when SetOprSelectList's limit and order-by is not nil, it means itself is converted from - // an independent ast.SetOprStmt in parser, its data should be evaluated first, and ordered - // by given items and conduct a limit on it, then it can only be integrated with other brothers. - breakIteration = true - } - } - if breakIteration { - break - } - intersects = append(intersects, selects[i+1]) - i++ - } - selectPlan, afterSetOpr, err := b.buildIntersect(ctx, intersects) - if err != nil { - return nil, err - } - selectPlans = append(selectPlans, selectPlan) - afterSetOprs = append(afterSetOprs, afterSetOpr) - } - setOprPlan, err := b.buildExcept(ctx, selectPlans, afterSetOprs) - if err != nil { - return nil, err - } - - oldLen := setOprPlan.Schema().Len() - - for i := 0; i < len(setOpr.SelectList.Selects); i++ { - b.handleHelper.popMap() - } - b.handleHelper.pushMap(nil) - - if setOpr.OrderBy != nil { - setOprPlan, err = b.buildSort(ctx, setOprPlan, setOpr.OrderBy.Items, nil, nil) - if err != nil { - return nil, err - } - } - - if setOpr.Limit != nil { - setOprPlan, err = b.buildLimit(setOprPlan, setOpr.Limit) - if err != nil { - return nil, err - } - } - - // Fix issue #8189 (https://github.com/pingcap/tidb/issues/8189). - // If there are extra expressions generated from `ORDER BY` clause, generate a `Projection` to remove them. - if oldLen != setOprPlan.Schema().Len() { - proj := logicalop.LogicalProjection{Exprs: expression.Column2Exprs(setOprPlan.Schema().Columns[:oldLen])}.Init(b.ctx, b.getSelectOffset()) - proj.SetChildren(setOprPlan) - schema := expression.NewSchema(setOprPlan.Schema().Clone().Columns[:oldLen]...) 
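// Illustration only (not part of this patch): a minimal sketch of the precedence
// handling in buildSetOpr above -- consecutive INTERSECT branches are grouped and
// built first, because INTERSECT binds tighter than UNION and EXCEPT. Branch names
// and the oprType enum below are stand-ins, not the parser's ast types.
package main

import "fmt"

type oprType int

const (
	opUnion oprType = iota
	opExcept
	opIntersect
)

// groupIntersects splits the branch list into runs, where each run is one branch
// optionally followed by the branches joined to it with INTERSECT.
func groupIntersects(branches []string, ops []oprType) [][]string {
	var groups [][]string
	for i := 0; i < len(branches); i++ {
		group := []string{branches[i]}
		for i+1 < len(branches) && ops[i+1] == opIntersect {
			group = append(group, branches[i+1])
			i++
		}
		groups = append(groups, group)
	}
	return groups
}

func main() {
	// s1 UNION s2 INTERSECT s3 EXCEPT s4  =>  [s1] [s2 s3] [s4]
	branches := []string{"s1", "s2", "s3", "s4"}
	ops := []oprType{opUnion /* unused for s1 */, opUnion, opIntersect, opExcept}
	fmt.Println(groupIntersects(branches, ops))
}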
- for _, col := range schema.Columns { - col.UniqueID = b.ctx.GetSessionVars().AllocPlanColumnID() - } - proj.SetOutputNames(setOprPlan.OutputNames()[:oldLen]) - proj.SetSchema(schema) - return proj, nil - } - return setOprPlan, nil -} - -func (b *PlanBuilder) buildSemiJoinForSetOperator( - leftOriginPlan base.LogicalPlan, - rightPlan base.LogicalPlan, - joinType JoinType) (leftPlan base.LogicalPlan, err error) { - leftPlan, err = b.buildDistinct(leftOriginPlan, leftOriginPlan.Schema().Len()) - if err != nil { - return nil, err - } - b.optFlag |= flagConvertOuterToInnerJoin - - joinPlan := LogicalJoin{JoinType: joinType}.Init(b.ctx, b.getSelectOffset()) - joinPlan.SetChildren(leftPlan, rightPlan) - joinPlan.SetSchema(leftPlan.Schema()) - joinPlan.SetOutputNames(make([]*types.FieldName, leftPlan.Schema().Len())) - copy(joinPlan.OutputNames(), leftPlan.OutputNames()) - for j := 0; j < len(rightPlan.Schema().Columns); j++ { - leftCol, rightCol := leftPlan.Schema().Columns[j], rightPlan.Schema().Columns[j] - eqCond, err := expression.NewFunction(b.ctx.GetExprCtx(), ast.NullEQ, types.NewFieldType(mysql.TypeTiny), leftCol, rightCol) - if err != nil { - return nil, err - } - _, leftArgIsColumn := eqCond.(*expression.ScalarFunction).GetArgs()[0].(*expression.Column) - _, rightArgIsColumn := eqCond.(*expression.ScalarFunction).GetArgs()[1].(*expression.Column) - if leftCol.RetType.GetType() != rightCol.RetType.GetType() || !leftArgIsColumn || !rightArgIsColumn { - joinPlan.OtherConditions = append(joinPlan.OtherConditions, eqCond) - } else { - joinPlan.EqualConditions = append(joinPlan.EqualConditions, eqCond.(*expression.ScalarFunction)) - } - } - return joinPlan, nil -} - -// buildIntersect build the set operator for 'intersect'. It is called before buildExcept and buildUnion because of its -// higher precedence. 
-func (b *PlanBuilder) buildIntersect(ctx context.Context, selects []ast.Node) (base.LogicalPlan, *ast.SetOprType, error) { - var leftPlan base.LogicalPlan - var err error - var afterSetOperator *ast.SetOprType - switch x := selects[0].(type) { - case *ast.SelectStmt: - afterSetOperator = x.AfterSetOperator - leftPlan, err = b.buildSelect(ctx, x) - case *ast.SetOprSelectList: - afterSetOperator = x.AfterSetOperator - leftPlan, err = b.buildSetOpr(ctx, &ast.SetOprStmt{SelectList: x, With: x.With, Limit: x.Limit, OrderBy: x.OrderBy}) - } - if err != nil { - return nil, nil, err - } - if len(selects) == 1 { - return leftPlan, afterSetOperator, nil - } - - columnNums := leftPlan.Schema().Len() - for i := 1; i < len(selects); i++ { - var rightPlan base.LogicalPlan - switch x := selects[i].(type) { - case *ast.SelectStmt: - if *x.AfterSetOperator == ast.IntersectAll { - // TODO: support intersect all - return nil, nil, errors.Errorf("TiDB do not support intersect all") - } - rightPlan, err = b.buildSelect(ctx, x) - case *ast.SetOprSelectList: - if *x.AfterSetOperator == ast.IntersectAll { - // TODO: support intersect all - return nil, nil, errors.Errorf("TiDB do not support intersect all") - } - rightPlan, err = b.buildSetOpr(ctx, &ast.SetOprStmt{SelectList: x, With: x.With, Limit: x.Limit, OrderBy: x.OrderBy}) - } - if err != nil { - return nil, nil, err - } - if rightPlan.Schema().Len() != columnNums { - return nil, nil, plannererrors.ErrWrongNumberOfColumnsInSelect.GenWithStackByArgs() - } - leftPlan, err = b.buildSemiJoinForSetOperator(leftPlan, rightPlan, SemiJoin) - if err != nil { - return nil, nil, err - } - } - return leftPlan, afterSetOperator, nil -} - -// buildExcept build the set operators for 'except', and in this function, it calls buildUnion at the same time. Because -// Union and except has the same precedence. -func (b *PlanBuilder) buildExcept(ctx context.Context, selects []base.LogicalPlan, afterSetOpts []*ast.SetOprType) (base.LogicalPlan, error) { - unionPlans := []base.LogicalPlan{selects[0]} - tmpAfterSetOpts := []*ast.SetOprType{nil} - columnNums := selects[0].Schema().Len() - for i := 1; i < len(selects); i++ { - rightPlan := selects[i] - if rightPlan.Schema().Len() != columnNums { - return nil, plannererrors.ErrWrongNumberOfColumnsInSelect.GenWithStackByArgs() - } - if *afterSetOpts[i] == ast.Except { - leftPlan, err := b.buildUnion(ctx, unionPlans, tmpAfterSetOpts) - if err != nil { - return nil, err - } - leftPlan, err = b.buildSemiJoinForSetOperator(leftPlan, rightPlan, AntiSemiJoin) - if err != nil { - return nil, err - } - unionPlans = []base.LogicalPlan{leftPlan} - tmpAfterSetOpts = []*ast.SetOprType{nil} - } else if *afterSetOpts[i] == ast.ExceptAll { - // TODO: support except all. 
- return nil, errors.Errorf("TiDB do not support except all") - } else { - unionPlans = append(unionPlans, rightPlan) - tmpAfterSetOpts = append(tmpAfterSetOpts, afterSetOpts[i]) - } - } - return b.buildUnion(ctx, unionPlans, tmpAfterSetOpts) -} - -func (b *PlanBuilder) buildUnion(ctx context.Context, selects []base.LogicalPlan, afterSetOpts []*ast.SetOprType) (base.LogicalPlan, error) { - if len(selects) == 1 { - return selects[0], nil - } - distinctSelectPlans, allSelectPlans, err := b.divideUnionSelectPlans(ctx, selects, afterSetOpts) - if err != nil { - return nil, err - } - unionDistinctPlan, err := b.buildUnionAll(ctx, distinctSelectPlans) - if err != nil { - return nil, err - } - if unionDistinctPlan != nil { - unionDistinctPlan, err = b.buildDistinct(unionDistinctPlan, unionDistinctPlan.Schema().Len()) - if err != nil { - return nil, err - } - if len(allSelectPlans) > 0 { - // Can't change the statements order in order to get the correct column info. - allSelectPlans = append([]base.LogicalPlan{unionDistinctPlan}, allSelectPlans...) - } - } - - unionAllPlan, err := b.buildUnionAll(ctx, allSelectPlans) - if err != nil { - return nil, err - } - unionPlan := unionDistinctPlan - if unionAllPlan != nil { - unionPlan = unionAllPlan - } - - return unionPlan, nil -} - -// divideUnionSelectPlans resolves union's select stmts to logical plans. -// and divide result plans into "union-distinct" and "union-all" parts. -// divide rule ref: -// -// https://dev.mysql.com/doc/refman/5.7/en/union.html -// -// "Mixed UNION types are treated such that a DISTINCT union overrides any ALL union to its left." -func (*PlanBuilder) divideUnionSelectPlans(_ context.Context, selects []base.LogicalPlan, setOprTypes []*ast.SetOprType) (distinctSelects []base.LogicalPlan, allSelects []base.LogicalPlan, err error) { - firstUnionAllIdx := 0 - columnNums := selects[0].Schema().Len() - for i := len(selects) - 1; i > 0; i-- { - if firstUnionAllIdx == 0 && *setOprTypes[i] != ast.UnionAll { - firstUnionAllIdx = i + 1 - } - if selects[i].Schema().Len() != columnNums { - return nil, nil, plannererrors.ErrWrongNumberOfColumnsInSelect.GenWithStackByArgs() - } - } - return selects[:firstUnionAllIdx], selects[firstUnionAllIdx:], nil -} - -func (b *PlanBuilder) buildUnionAll(ctx context.Context, subPlan []base.LogicalPlan) (base.LogicalPlan, error) { - if len(subPlan) == 0 { - return nil, nil - } - u := LogicalUnionAll{}.Init(b.ctx, b.getSelectOffset()) - u.SetChildren(subPlan...) 
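// Illustration only (not part of this patch): a minimal sketch of the split rule used
// by divideUnionSelectPlans above -- "a DISTINCT union overrides any ALL union to its
// left". ops[i] is the operator written before branch i (ops[0] is unused). Every
// branch up to and including the last one joined by UNION DISTINCT is deduplicated;
// only the trailing run of UNION ALL branches skips deduplication.
package main

import "fmt"

type setOp int

const (
	unionDistinct setOp = iota
	unionAll
)

func splitPoint(ops []setOp) int {
	firstUnionAllIdx := 0
	for i := len(ops) - 1; i > 0; i-- {
		if ops[i] != unionAll {
			firstUnionAllIdx = i + 1
			break
		}
	}
	// branches [0, firstUnionAllIdx) are distinct-merged, the rest are plain UNION ALL.
	return firstUnionAllIdx
}

func main() {
	// SELECT ... UNION ALL SELECT ... UNION SELECT ... UNION ALL SELECT ...
	ops := []setOp{unionDistinct /* unused */, unionAll, unionDistinct, unionAll}
	fmt.Println(splitPoint(ops)) // 3: only the last branch bypasses deduplication
}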
- err := b.buildProjection4Union(ctx, u) - return u, err -} - -// itemTransformer transforms ParamMarkerExpr to PositionExpr in the context of ByItem -type itemTransformer struct{} - -func (*itemTransformer) Enter(inNode ast.Node) (ast.Node, bool) { - if n, ok := inNode.(*driver.ParamMarkerExpr); ok { - newNode := expression.ConstructPositionExpr(n) - return newNode, true - } - return inNode, false -} - -func (*itemTransformer) Leave(inNode ast.Node) (ast.Node, bool) { - return inNode, false -} - -func (b *PlanBuilder) buildSort(ctx context.Context, p base.LogicalPlan, byItems []*ast.ByItem, aggMapper map[*ast.AggregateFuncExpr]int, windowMapper map[*ast.WindowFuncExpr]int) (*logicalop.LogicalSort, error) { - return b.buildSortWithCheck(ctx, p, byItems, aggMapper, windowMapper, nil, 0, false) -} - -func (b *PlanBuilder) buildSortWithCheck(ctx context.Context, p base.LogicalPlan, byItems []*ast.ByItem, aggMapper map[*ast.AggregateFuncExpr]int, windowMapper map[*ast.WindowFuncExpr]int, - projExprs []expression.Expression, oldLen int, hasDistinct bool) (*logicalop.LogicalSort, error) { - if _, isUnion := p.(*LogicalUnionAll); isUnion { - b.curClause = globalOrderByClause - } else { - b.curClause = orderByClause - } - sort := logicalop.LogicalSort{}.Init(b.ctx, b.getSelectOffset()) - exprs := make([]*util.ByItems, 0, len(byItems)) - transformer := &itemTransformer{} - for i, item := range byItems { - newExpr, _ := item.Expr.Accept(transformer) - item.Expr = newExpr.(ast.ExprNode) - it, np, err := b.rewriteWithPreprocess(ctx, item.Expr, p, aggMapper, windowMapper, true, nil) - if err != nil { - return nil, err - } - // for case: select a+1, b, sum(b) from t group by a+1, b with rollup order by a+1. - // currently, we fail to resolve (a+1) in order-by to projection item (a+1), and adding - // another a' in the select fields instead, leading finally resolved expr is a'+1 here. - // - // Anyway, a and a' has the same column unique id, so we can do the replacement work like - // we did in build projection phase. - it = b.replaceGroupingFunc(it) - - // check whether ORDER BY items show up in SELECT DISTINCT fields, see #12442 - if hasDistinct && projExprs != nil { - err = b.checkOrderByInDistinct(item, i, it, p, projExprs, oldLen) - if err != nil { - return nil, err - } - } - - p = np - exprs = append(exprs, &util.ByItems{Expr: it, Desc: item.Desc}) - } - sort.ByItems = exprs - sort.SetChildren(p) - return sort, nil -} - -// checkOrderByInDistinct checks whether ORDER BY has conflicts with DISTINCT, see #12442 -func (b *PlanBuilder) checkOrderByInDistinct(byItem *ast.ByItem, idx int, expr expression.Expression, p base.LogicalPlan, originalExprs []expression.Expression, length int) error { - // Check if expressions in ORDER BY whole match some fields in DISTINCT. - // e.g. - // select distinct count(a) from t group by b order by count(a); ✔ - // select distinct a+1 from t order by a+1; ✔ - // select distinct a+1 from t order by a+2; ✗ - evalCtx := b.ctx.GetExprCtx().GetEvalCtx() - for j := 0; j < length; j++ { - // both check original expression & as name - if expr.Equal(evalCtx, originalExprs[j]) || expr.Equal(evalCtx, p.Schema().Columns[j]) { - return nil - } - } - - // Check if referenced columns of expressions in ORDER BY whole match some fields in DISTINCT, - // both original expression and alias can be referenced. - // e.g. 
- // select distinct a from t order by sin(a); ✔ - // select distinct a, b from t order by a+b; ✔ - // select distinct count(a), sum(a) from t group by b order by sum(a); ✔ - cols := expression.ExtractColumns(expr) -CheckReferenced: - for _, col := range cols { - for j := 0; j < length; j++ { - if col.Equal(evalCtx, originalExprs[j]) || col.Equal(evalCtx, p.Schema().Columns[j]) { - continue CheckReferenced - } - } - - // Failed cases - // e.g. - // select distinct sin(a) from t order by a; ✗ - // select distinct a from t order by a+b; ✗ - if _, ok := byItem.Expr.(*ast.AggregateFuncExpr); ok { - return plannererrors.ErrAggregateInOrderNotSelect.GenWithStackByArgs(idx+1, "DISTINCT") - } - // select distinct count(a) from t group by b order by sum(a); ✗ - return plannererrors.ErrFieldInOrderNotSelect.GenWithStackByArgs(idx+1, col.OrigName, "DISTINCT") - } - return nil -} - -// getUintFromNode gets uint64 value from ast.Node. -// For ordinary statement, node should be uint64 constant value. -// For prepared statement, node is string. We should convert it to uint64. -func getUintFromNode(ctx base.PlanContext, n ast.Node, mustInt64orUint64 bool) (uVal uint64, isNull bool, isExpectedType bool) { - var val any - switch v := n.(type) { - case *driver.ValueExpr: - val = v.GetValue() - case *driver.ParamMarkerExpr: - if !v.InExecute { - return 0, false, true - } - if mustInt64orUint64 { - if expected, _ := CheckParamTypeInt64orUint64(v); !expected { - return 0, false, false - } - } - param, err := expression.ParamMarkerExpression(ctx, v, false) - if err != nil { - return 0, false, false - } - str, isNull, err := expression.GetStringFromConstant(ctx.GetExprCtx().GetEvalCtx(), param) - if err != nil { - return 0, false, false - } - if isNull { - return 0, true, true - } - val = str - default: - return 0, false, false - } - switch v := val.(type) { - case uint64: - return v, false, true - case int64: - if v >= 0 { - return uint64(v), false, true - } - case string: - ctx := ctx.GetSessionVars().StmtCtx.TypeCtx() - uVal, err := types.StrToUint(ctx, v, false) - if err != nil { - return 0, false, false - } - return uVal, false, true - } - return 0, false, false -} - -// CheckParamTypeInt64orUint64 check param type for plan cache limit, only allow int64 and uint64 now -// eg: set @a = 1; -func CheckParamTypeInt64orUint64(param *driver.ParamMarkerExpr) (bool, uint64) { - val := param.GetValue() - switch v := val.(type) { - case int64: - if v >= 0 { - return true, uint64(v) - } - case uint64: - return true, v - } - return false, 0 -} - -func extractLimitCountOffset(ctx base.PlanContext, limit *ast.Limit) (count uint64, - offset uint64, err error) { - var isExpectedType bool - if limit.Count != nil { - count, _, isExpectedType = getUintFromNode(ctx, limit.Count, true) - if !isExpectedType { - return 0, 0, plannererrors.ErrWrongArguments.GenWithStackByArgs("LIMIT") - } - } - if limit.Offset != nil { - offset, _, isExpectedType = getUintFromNode(ctx, limit.Offset, true) - if !isExpectedType { - return 0, 0, plannererrors.ErrWrongArguments.GenWithStackByArgs("LIMIT") - } - } - return count, offset, nil -} - -func (b *PlanBuilder) buildLimit(src base.LogicalPlan, limit *ast.Limit) (base.LogicalPlan, error) { - b.optFlag = b.optFlag | flagPushDownTopN - var ( - offset, count uint64 - err error - ) - if count, offset, err = extractLimitCountOffset(b.ctx, limit); err != nil { - return nil, err - } - - if count > math.MaxUint64-offset { - count = math.MaxUint64 - offset - } - if offset+count == 0 { - tableDual := 
logicalop.LogicalTableDual{RowCount: 0}.Init(b.ctx, b.getSelectOffset()) - tableDual.SetSchema(src.Schema()) - tableDual.SetOutputNames(src.OutputNames()) - return tableDual, nil - } - li := logicalop.LogicalLimit{ - Offset: offset, - Count: count, - }.Init(b.ctx, b.getSelectOffset()) - if hint := b.TableHints(); hint != nil { - li.PreferLimitToCop = hint.PreferLimitToCop - } - li.SetChildren(src) - return li, nil -} - -func resolveFromSelectFields(v *ast.ColumnNameExpr, fields []*ast.SelectField, ignoreAsName bool) (index int, err error) { - var matchedExpr ast.ExprNode - index = -1 - for i, field := range fields { - if field.Auxiliary { - continue - } - if field.Match(v, ignoreAsName) { - curCol, isCol := field.Expr.(*ast.ColumnNameExpr) - if !isCol { - return i, nil - } - if matchedExpr == nil { - matchedExpr = curCol - index = i - } else if !matchedExpr.(*ast.ColumnNameExpr).Name.Match(curCol.Name) && - !curCol.Name.Match(matchedExpr.(*ast.ColumnNameExpr).Name) { - return -1, plannererrors.ErrAmbiguous.GenWithStackByArgs(curCol.Name.Name.L, clauseMsg[fieldList]) - } - } - } - return -} - -// havingWindowAndOrderbyExprResolver visits Expr tree. -// It converts ColumnNameExpr to AggregateFuncExpr and collects AggregateFuncExpr. -type havingWindowAndOrderbyExprResolver struct { - inAggFunc bool - inWindowFunc bool - inWindowSpec bool - inExpr bool - err error - p base.LogicalPlan - selectFields []*ast.SelectField - aggMapper map[*ast.AggregateFuncExpr]int - colMapper map[*ast.ColumnNameExpr]int - gbyItems []*ast.ByItem - outerSchemas []*expression.Schema - outerNames [][]*types.FieldName - curClause clauseCode - prevClause []clauseCode -} - -func (a *havingWindowAndOrderbyExprResolver) pushCurClause(newClause clauseCode) { - a.prevClause = append(a.prevClause, a.curClause) - a.curClause = newClause -} - -func (a *havingWindowAndOrderbyExprResolver) popCurClause() { - a.curClause = a.prevClause[len(a.prevClause)-1] - a.prevClause = a.prevClause[:len(a.prevClause)-1] -} - -// Enter implements Visitor interface. -func (a *havingWindowAndOrderbyExprResolver) Enter(n ast.Node) (node ast.Node, skipChildren bool) { - switch n.(type) { - case *ast.AggregateFuncExpr: - a.inAggFunc = true - case *ast.WindowFuncExpr: - a.inWindowFunc = true - case *ast.WindowSpec: - a.inWindowSpec = true - case *driver.ParamMarkerExpr, *ast.ColumnNameExpr, *ast.ColumnName: - case *ast.SubqueryExpr, *ast.ExistsSubqueryExpr: - // Enter a new context, skip it. - // For example: select sum(c) + c + exists(select c from t) from t; - return n, true - case *ast.PartitionByClause: - a.pushCurClause(partitionByClause) - case *ast.OrderByClause: - if a.inWindowSpec { - a.pushCurClause(windowOrderByClause) - } - default: - a.inExpr = true - } - return n, false -} - -func (a *havingWindowAndOrderbyExprResolver) resolveFromPlan(v *ast.ColumnNameExpr, p base.LogicalPlan, resolveFieldsFirst bool) (int, error) { - idx, err := expression.FindFieldName(p.OutputNames(), v.Name) - if err != nil { - return -1, err - } - schemaCols, outputNames := p.Schema().Columns, p.OutputNames() - if idx < 0 { - // For SQL like `select t2.a from t1 join t2 using(a) where t2.a > 0 - // order by t2.a`, the query plan will be `join->selection->sort`. The - // schema of selection will be `[t1.a]`, thus we need to recursively - // retrieve the `t2.a` from the underlying join. 
- switch x := p.(type) { - case *logicalop.LogicalLimit, *LogicalSelection, *logicalop.LogicalTopN, *logicalop.LogicalSort, *logicalop.LogicalMaxOneRow: - return a.resolveFromPlan(v, p.Children()[0], resolveFieldsFirst) - case *LogicalJoin: - if len(x.FullNames) != 0 { - idx, err = expression.FindFieldName(x.FullNames, v.Name) - schemaCols, outputNames = x.FullSchema.Columns, x.FullNames - } - } - if err != nil || idx < 0 { - // nowhere to be found. - return -1, err - } - } - col := schemaCols[idx] - if col.IsHidden { - return -1, plannererrors.ErrUnknownColumn.GenWithStackByArgs(v.Name, clauseMsg[a.curClause]) - } - name := outputNames[idx] - newColName := &ast.ColumnName{ - Schema: name.DBName, - Table: name.TblName, - Name: name.ColName, - } - for i, field := range a.selectFields { - if c, ok := field.Expr.(*ast.ColumnNameExpr); ok && c.Name.Match(newColName) { - return i, nil - } - } - // From https://github.com/pingcap/tidb/issues/51107 - // You should make the column in the having clause as the correlated column - // which is not relation with select's fields and GroupBy's fields. - // For SQLs like: - // SELECT * FROM `t1` WHERE NOT (`t1`.`col_1`>= ( - // SELECT `t2`.`col_7` - // FROM (`t1`) - // JOIN `t2` - // WHERE ISNULL(`t2`.`col_3`) HAVING `t1`.`col_6`>1951988) - // ) ; - // - // if resolveFieldsFirst is false, the groupby is not nil. - if resolveFieldsFirst && a.curClause == havingClause { - return -1, nil - } - sf := &ast.SelectField{ - Expr: &ast.ColumnNameExpr{Name: newColName}, - Auxiliary: true, - } - // appended with new select fields. set them with flag. - if a.inAggFunc { - // should skip check in FD for only full group by. - sf.AuxiliaryColInAgg = true - } else if a.curClause == orderByClause { - // should skip check in FD for only full group by only when group by item are empty. - sf.AuxiliaryColInOrderBy = true - } - sf.Expr.SetType(col.GetStaticType()) - a.selectFields = append(a.selectFields, sf) - return len(a.selectFields) - 1, nil -} - -// Leave implements Visitor interface. 
-func (a *havingWindowAndOrderbyExprResolver) Leave(n ast.Node) (node ast.Node, ok bool) { - switch v := n.(type) { - case *ast.AggregateFuncExpr: - a.inAggFunc = false - a.aggMapper[v] = len(a.selectFields) - a.selectFields = append(a.selectFields, &ast.SelectField{ - Auxiliary: true, - Expr: v, - AsName: model.NewCIStr(fmt.Sprintf("sel_agg_%d", len(a.selectFields))), - }) - case *ast.WindowFuncExpr: - a.inWindowFunc = false - if a.curClause == havingClause { - a.err = plannererrors.ErrWindowInvalidWindowFuncUse.GenWithStackByArgs(strings.ToLower(v.Name)) - return node, false - } - if a.curClause == orderByClause { - a.selectFields = append(a.selectFields, &ast.SelectField{ - Auxiliary: true, - Expr: v, - AsName: model.NewCIStr(fmt.Sprintf("sel_window_%d", len(a.selectFields))), - }) - } - case *ast.WindowSpec: - a.inWindowSpec = false - case *ast.PartitionByClause: - a.popCurClause() - case *ast.OrderByClause: - if a.inWindowSpec { - a.popCurClause() - } - case *ast.ColumnNameExpr: - resolveFieldsFirst := true - if a.inAggFunc || a.inWindowFunc || a.inWindowSpec || (a.curClause == orderByClause && a.inExpr) || a.curClause == fieldList { - resolveFieldsFirst = false - } - if !a.inAggFunc && a.curClause != orderByClause { - for _, item := range a.gbyItems { - if col, ok := item.Expr.(*ast.ColumnNameExpr); ok && - (v.Name.Match(col.Name) || col.Name.Match(v.Name)) { - resolveFieldsFirst = false - break - } - } - } - var index int - if resolveFieldsFirst { - index, a.err = resolveFromSelectFields(v, a.selectFields, false) - if a.err != nil { - return node, false - } - if index != -1 && a.curClause == havingClause && ast.HasWindowFlag(a.selectFields[index].Expr) { - a.err = plannererrors.ErrWindowInvalidWindowFuncAliasUse.GenWithStackByArgs(v.Name.Name.O) - return node, false - } - if index == -1 { - if a.curClause == orderByClause { - index, a.err = a.resolveFromPlan(v, a.p, resolveFieldsFirst) - } else if a.curClause == havingClause && v.Name.Table.L != "" { - // For SQLs like: - // select a from t b having b.a; - index, a.err = a.resolveFromPlan(v, a.p, resolveFieldsFirst) - if a.err != nil { - return node, false - } - if index != -1 { - // For SQLs like: - // select a+1 from t having t.a; - field := a.selectFields[index] - if field.Auxiliary { // having can't use auxiliary field - index = -1 - } - } - } else { - index, a.err = resolveFromSelectFields(v, a.selectFields, true) - } - } - } else { - // We should ignore the err when resolving from schema. Because we could resolve successfully - // when considering select fields. - var err error - index, err = a.resolveFromPlan(v, a.p, resolveFieldsFirst) - _ = err - if index == -1 && a.curClause != fieldList && - a.curClause != windowOrderByClause && a.curClause != partitionByClause { - index, a.err = resolveFromSelectFields(v, a.selectFields, false) - if index != -1 && a.curClause == havingClause && ast.HasWindowFlag(a.selectFields[index].Expr) { - a.err = plannererrors.ErrWindowInvalidWindowFuncAliasUse.GenWithStackByArgs(v.Name.Name.O) - return node, false - } - } - } - if a.err != nil { - return node, false - } - if index == -1 { - // If we can't find it any where, it may be a correlated columns. 
- for _, names := range a.outerNames { - idx, err1 := expression.FindFieldName(names, v.Name) - if err1 != nil { - a.err = err1 - return node, false - } - if idx >= 0 { - return n, true - } - } - a.err = plannererrors.ErrUnknownColumn.GenWithStackByArgs(v.Name.OrigColName(), clauseMsg[a.curClause]) - return node, false - } - if a.inAggFunc { - return a.selectFields[index].Expr, true - } - a.colMapper[v] = index - } - return n, true -} - -// resolveHavingAndOrderBy will process aggregate functions and resolve the columns that don't exist in select fields. -// If we found some columns that are not in select fields, we will append it to select fields and update the colMapper. -// When we rewrite the order by / having expression, we will find column in map at first. -func (b *PlanBuilder) resolveHavingAndOrderBy(ctx context.Context, sel *ast.SelectStmt, p base.LogicalPlan) ( - map[*ast.AggregateFuncExpr]int, map[*ast.AggregateFuncExpr]int, error) { - extractor := &havingWindowAndOrderbyExprResolver{ - p: p, - selectFields: sel.Fields.Fields, - aggMapper: make(map[*ast.AggregateFuncExpr]int), - colMapper: b.colMapper, - outerSchemas: b.outerSchemas, - outerNames: b.outerNames, - } - if sel.GroupBy != nil { - extractor.gbyItems = sel.GroupBy.Items - } - // Extract agg funcs from having clause. - if sel.Having != nil { - extractor.curClause = havingClause - n, ok := sel.Having.Expr.Accept(extractor) - if !ok { - return nil, nil, errors.Trace(extractor.err) - } - sel.Having.Expr = n.(ast.ExprNode) - } - havingAggMapper := extractor.aggMapper - extractor.aggMapper = make(map[*ast.AggregateFuncExpr]int) - // Extract agg funcs from order by clause. - if sel.OrderBy != nil { - extractor.curClause = orderByClause - for _, item := range sel.OrderBy.Items { - extractor.inExpr = false - if ast.HasWindowFlag(item.Expr) { - continue - } - n, ok := item.Expr.Accept(extractor) - if !ok { - return nil, nil, errors.Trace(extractor.err) - } - item.Expr = n.(ast.ExprNode) - } - } - sel.Fields.Fields = extractor.selectFields - // this part is used to fetch correlated column from sub-query item in order-by clause, and append the origin - // auxiliary select filed in select list, otherwise, sub-query itself won't get the name resolved in outer schema. - if sel.OrderBy != nil { - for _, byItem := range sel.OrderBy.Items { - if _, ok := byItem.Expr.(*ast.SubqueryExpr); ok { - // correlated agg will be extracted completely latter. - _, np, err := b.rewrite(ctx, byItem.Expr, p, nil, true) - if err != nil { - return nil, nil, errors.Trace(err) - } - correlatedCols := coreusage.ExtractCorrelatedCols4LogicalPlan(np) - for _, cone := range correlatedCols { - var colName *ast.ColumnName - for idx, pone := range p.Schema().Columns { - if cone.UniqueID == pone.UniqueID { - pname := p.OutputNames()[idx] - colName = &ast.ColumnName{ - Schema: pname.DBName, - Table: pname.TblName, - Name: pname.ColName, - } - break - } - } - if colName != nil { - columnNameExpr := &ast.ColumnNameExpr{Name: colName} - for _, field := range sel.Fields.Fields { - if c, ok := field.Expr.(*ast.ColumnNameExpr); ok && c.Name.Match(columnNameExpr.Name) && field.AsName.L == "" { - // deduplicate select fields: don't append it once it already has one. - // TODO: we add the field if it has alias, but actually they are the same column. We should not have two duplicate one. 
- columnNameExpr = nil - break - } - } - if columnNameExpr != nil { - sel.Fields.Fields = append(sel.Fields.Fields, &ast.SelectField{ - Auxiliary: true, - Expr: columnNameExpr, - }) - } - } - } - } - } - } - return havingAggMapper, extractor.aggMapper, nil -} - -func (b *PlanBuilder) extractAggFuncsInExprs(exprs []ast.ExprNode) ([]*ast.AggregateFuncExpr, map[*ast.AggregateFuncExpr]int) { - extractor := &AggregateFuncExtractor{skipAggMap: b.correlatedAggMapper} - for _, expr := range exprs { - expr.Accept(extractor) - } - aggList := extractor.AggFuncs - totalAggMapper := make(map[*ast.AggregateFuncExpr]int, len(aggList)) - - for i, agg := range aggList { - totalAggMapper[agg] = i - } - return aggList, totalAggMapper -} - -func (b *PlanBuilder) extractAggFuncsInSelectFields(fields []*ast.SelectField) ([]*ast.AggregateFuncExpr, map[*ast.AggregateFuncExpr]int) { - extractor := &AggregateFuncExtractor{skipAggMap: b.correlatedAggMapper} - for _, f := range fields { - n, _ := f.Expr.Accept(extractor) - f.Expr = n.(ast.ExprNode) - } - aggList := extractor.AggFuncs - totalAggMapper := make(map[*ast.AggregateFuncExpr]int, len(aggList)) - - for i, agg := range aggList { - totalAggMapper[agg] = i - } - return aggList, totalAggMapper -} - -func (b *PlanBuilder) extractAggFuncsInByItems(byItems []*ast.ByItem) []*ast.AggregateFuncExpr { - extractor := &AggregateFuncExtractor{skipAggMap: b.correlatedAggMapper} - for _, f := range byItems { - n, _ := f.Expr.Accept(extractor) - f.Expr = n.(ast.ExprNode) - } - return extractor.AggFuncs -} - -// extractCorrelatedAggFuncs extracts correlated aggregates which belong to outer query from aggregate function list. -func (b *PlanBuilder) extractCorrelatedAggFuncs(ctx context.Context, p base.LogicalPlan, aggFuncs []*ast.AggregateFuncExpr) (outer []*ast.AggregateFuncExpr, err error) { - corCols := make([]*expression.CorrelatedColumn, 0, len(aggFuncs)) - cols := make([]*expression.Column, 0, len(aggFuncs)) - aggMapper := make(map[*ast.AggregateFuncExpr]int) - for _, agg := range aggFuncs { - for _, arg := range agg.Args { - expr, _, err := b.rewrite(ctx, arg, p, aggMapper, true) - if err != nil { - return nil, err - } - corCols = append(corCols, expression.ExtractCorColumns(expr)...) - cols = append(cols, expression.ExtractColumns(expr)...) - } - if len(corCols) > 0 && len(cols) == 0 { - outer = append(outer, agg) - } - aggMapper[agg] = -1 - corCols, cols = corCols[:0], cols[:0] - } - return -} - -// resolveWindowFunction will process window functions and resolve the columns that don't exist in select fields. 
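// Illustrative sketch, not part of the patched source: the resolvers in this file (the having/order-by
// resolver above, the aggregate extractors, and the group-by and correlated-aggregate resolvers below)
// all implement the parser's ast.Visitor contract, where Enter runs pre-order and may skip children,
// and Leave runs post-order and may rewrite or reject the node. The snippet below shows that contract
// in isolation; it assumes the in-repo parser packages under pkg/parser, and the visitor, SQL text,
// and counters are examples of mine, not taken from this patch.
package main

import (
	"fmt"

	"github.com/pingcap/tidb/pkg/parser"
	"github.com/pingcap/tidb/pkg/parser/ast"
	_ "github.com/pingcap/tidb/pkg/parser/test_driver" // value-expression driver required by the parser
)

// countingVisitor walks a statement tree and tallies node kinds, mirroring the
// Enter/Leave shape used by the resolvers in this file.
type countingVisitor struct {
	aggFuncs int
	colRefs  int
}

func (v *countingVisitor) Enter(n ast.Node) (ast.Node, bool) {
	switch n.(type) {
	case *ast.AggregateFuncExpr:
		v.aggFuncs++
	case *ast.ColumnNameExpr:
		v.colRefs++
	}
	return n, false // false: keep walking into children
}

func (v *countingVisitor) Leave(n ast.Node) (ast.Node, bool) {
	return n, true // true: no error, keep the (possibly rewritten) node
}

func main() {
	stmt, err := parser.New().ParseOneStmt("SELECT a, SUM(b) FROM t GROUP BY a HAVING SUM(b) > 10", "", "")
	if err != nil {
		panic(err)
	}
	v := &countingVisitor{}
	stmt.Accept(v)
	fmt.Printf("aggregates=%d column refs=%d\n", v.aggFuncs, v.colRefs)
}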
-func (b *PlanBuilder) resolveWindowFunction(sel *ast.SelectStmt, p base.LogicalPlan) ( - map[*ast.AggregateFuncExpr]int, error) { - extractor := &havingWindowAndOrderbyExprResolver{ - p: p, - selectFields: sel.Fields.Fields, - aggMapper: make(map[*ast.AggregateFuncExpr]int), - colMapper: b.colMapper, - outerSchemas: b.outerSchemas, - outerNames: b.outerNames, - } - extractor.curClause = fieldList - for _, field := range sel.Fields.Fields { - if !ast.HasWindowFlag(field.Expr) { - continue - } - n, ok := field.Expr.Accept(extractor) - if !ok { - return nil, extractor.err - } - field.Expr = n.(ast.ExprNode) - } - for _, spec := range sel.WindowSpecs { - _, ok := spec.Accept(extractor) - if !ok { - return nil, extractor.err - } - } - if sel.OrderBy != nil { - extractor.curClause = orderByClause - for _, item := range sel.OrderBy.Items { - if !ast.HasWindowFlag(item.Expr) { - continue - } - n, ok := item.Expr.Accept(extractor) - if !ok { - return nil, extractor.err - } - item.Expr = n.(ast.ExprNode) - } - } - sel.Fields.Fields = extractor.selectFields - return extractor.aggMapper, nil -} - -// correlatedAggregateResolver visits Expr tree. -// It finds and collects all correlated aggregates which should be evaluated in the outer query. -type correlatedAggregateResolver struct { - ctx context.Context - err error - b *PlanBuilder - outerPlan base.LogicalPlan - - // correlatedAggFuncs stores aggregate functions which belong to outer query - correlatedAggFuncs []*ast.AggregateFuncExpr -} - -// Enter implements Visitor interface. -func (r *correlatedAggregateResolver) Enter(n ast.Node) (ast.Node, bool) { - if v, ok := n.(*ast.SelectStmt); ok { - if r.outerPlan != nil { - outerSchema := r.outerPlan.Schema() - r.b.outerSchemas = append(r.b.outerSchemas, outerSchema) - r.b.outerNames = append(r.b.outerNames, r.outerPlan.OutputNames()) - r.b.outerBlockExpand = append(r.b.outerBlockExpand, r.b.currentBlockExpand) - } - r.err = r.resolveSelect(v) - return n, true - } - return n, false -} - -// resolveSelect finds and collects correlated aggregates within the SELECT stmt. -// It resolves and builds FROM clause first to get a source plan, from which we can decide -// whether a column is correlated or not. -// Then it collects correlated aggregate from SELECT fields (including sub-queries), HAVING, -// ORDER BY, WHERE & GROUP BY. -// Finally it restore the original SELECT stmt. -func (r *correlatedAggregateResolver) resolveSelect(sel *ast.SelectStmt) (err error) { - if sel.With != nil { - l := len(r.b.outerCTEs) - defer func() { - r.b.outerCTEs = r.b.outerCTEs[:l] - }() - _, err := r.b.buildWith(r.ctx, sel.With) - if err != nil { - return err - } - } - // collect correlated aggregate from sub-queries inside FROM clause. 
- if err := r.collectFromTableRefs(sel.From); err != nil { - return err - } - p, err := r.b.buildTableRefs(r.ctx, sel.From) - if err != nil { - return err - } - - // similar to process in PlanBuilder.buildSelect - originalFields := sel.Fields.Fields - sel.Fields.Fields, err = r.b.unfoldWildStar(p, sel.Fields.Fields) - if err != nil { - return err - } - if r.b.capFlag&canExpandAST != 0 { - originalFields = sel.Fields.Fields - } - - hasWindowFuncField := r.b.detectSelectWindow(sel) - if hasWindowFuncField { - _, err = r.b.resolveWindowFunction(sel, p) - if err != nil { - return err - } - } - - _, _, err = r.b.resolveHavingAndOrderBy(r.ctx, sel, p) - if err != nil { - return err - } - - // find and collect correlated aggregates recursively in sub-queries - _, err = r.b.resolveCorrelatedAggregates(r.ctx, sel, p) - if err != nil { - return err - } - - // collect from SELECT fields, HAVING, ORDER BY and window functions - if r.b.detectSelectAgg(sel) { - err = r.collectFromSelectFields(p, sel.Fields.Fields) - if err != nil { - return err - } - } - - // collect from WHERE - err = r.collectFromWhere(p, sel.Where) - if err != nil { - return err - } - - // collect from GROUP BY - err = r.collectFromGroupBy(p, sel.GroupBy) - if err != nil { - return err - } - - // restore the sub-query - sel.Fields.Fields = originalFields - r.b.handleHelper.popMap() - return nil -} - -func (r *correlatedAggregateResolver) collectFromTableRefs(from *ast.TableRefsClause) error { - if from == nil { - return nil - } - subResolver := &correlatedAggregateResolver{ - ctx: r.ctx, - b: r.b, - } - _, ok := from.TableRefs.Accept(subResolver) - if !ok { - return subResolver.err - } - if len(subResolver.correlatedAggFuncs) == 0 { - return nil - } - r.correlatedAggFuncs = append(r.correlatedAggFuncs, subResolver.correlatedAggFuncs...) - return nil -} - -func (r *correlatedAggregateResolver) collectFromSelectFields(p base.LogicalPlan, fields []*ast.SelectField) error { - aggList, _ := r.b.extractAggFuncsInSelectFields(fields) - r.b.curClause = fieldList - outerAggFuncs, err := r.b.extractCorrelatedAggFuncs(r.ctx, p, aggList) - if err != nil { - return nil - } - r.correlatedAggFuncs = append(r.correlatedAggFuncs, outerAggFuncs...) - return nil -} - -func (r *correlatedAggregateResolver) collectFromGroupBy(p base.LogicalPlan, groupBy *ast.GroupByClause) error { - if groupBy == nil { - return nil - } - aggList := r.b.extractAggFuncsInByItems(groupBy.Items) - r.b.curClause = groupByClause - outerAggFuncs, err := r.b.extractCorrelatedAggFuncs(r.ctx, p, aggList) - if err != nil { - return nil - } - r.correlatedAggFuncs = append(r.correlatedAggFuncs, outerAggFuncs...) - return nil -} - -func (r *correlatedAggregateResolver) collectFromWhere(p base.LogicalPlan, where ast.ExprNode) error { - if where == nil { - return nil - } - extractor := &AggregateFuncExtractor{skipAggMap: r.b.correlatedAggMapper} - _, _ = where.Accept(extractor) - r.b.curClause = whereClause - outerAggFuncs, err := r.b.extractCorrelatedAggFuncs(r.ctx, p, extractor.AggFuncs) - if err != nil { - return err - } - r.correlatedAggFuncs = append(r.correlatedAggFuncs, outerAggFuncs...) - return nil -} - -// Leave implements Visitor interface. 
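// Illustrative sketch, not part of the patched source: extractCorrelatedAggFuncs above classifies an
// aggregate as belonging to the outer query only when its arguments reference outer (correlated)
// columns and no columns of the current block, e.g. count(a) in `select (select count(a)) from t`.
// The toy classifier below restates that rule over plain string sets; every name in it is hypothetical.
package main

import "fmt"

// isOuterAggregate reports whether an aggregate whose arguments reference argCols
// should be evaluated in the outer query: at least one referenced column must come
// from the outer block and none from the local block.
func isOuterAggregate(argCols []string, localCols, outerCols map[string]bool) bool {
	usesOuter, usesLocal := false, false
	for _, c := range argCols {
		if localCols[c] {
			usesLocal = true
		} else if outerCols[c] {
			usesOuter = true
		}
	}
	return usesOuter && !usesLocal
}

func main() {
	outer := map[string]bool{"t.a": true}
	local := map[string]bool{"s.b": true}
	fmt.Println(isOuterAggregate([]string{"t.a"}, local, outer))        // true: count(t.a) belongs to the outer query
	fmt.Println(isOuterAggregate([]string{"t.a", "s.b"}, local, outer)) // false: mixes outer and local columns
}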
-func (r *correlatedAggregateResolver) Leave(n ast.Node) (ast.Node, bool) { - if _, ok := n.(*ast.SelectStmt); ok { - if r.outerPlan != nil { - r.b.outerSchemas = r.b.outerSchemas[0 : len(r.b.outerSchemas)-1] - r.b.outerNames = r.b.outerNames[0 : len(r.b.outerNames)-1] - r.b.currentBlockExpand = r.b.outerBlockExpand[len(r.b.outerBlockExpand)-1] - r.b.outerBlockExpand = r.b.outerBlockExpand[0 : len(r.b.outerBlockExpand)-1] - } - } - return n, r.err == nil -} - -// resolveCorrelatedAggregates finds and collects all correlated aggregates which should be evaluated -// in the outer query from all the sub-queries inside SELECT fields. -func (b *PlanBuilder) resolveCorrelatedAggregates(ctx context.Context, sel *ast.SelectStmt, p base.LogicalPlan) (map[*ast.AggregateFuncExpr]int, error) { - resolver := &correlatedAggregateResolver{ - ctx: ctx, - b: b, - outerPlan: p, - } - correlatedAggList := make([]*ast.AggregateFuncExpr, 0) - for _, field := range sel.Fields.Fields { - _, ok := field.Expr.Accept(resolver) - if !ok { - return nil, resolver.err - } - correlatedAggList = append(correlatedAggList, resolver.correlatedAggFuncs...) - } - if sel.Having != nil { - _, ok := sel.Having.Expr.Accept(resolver) - if !ok { - return nil, resolver.err - } - correlatedAggList = append(correlatedAggList, resolver.correlatedAggFuncs...) - } - if sel.OrderBy != nil { - for _, item := range sel.OrderBy.Items { - _, ok := item.Expr.Accept(resolver) - if !ok { - return nil, resolver.err - } - correlatedAggList = append(correlatedAggList, resolver.correlatedAggFuncs...) - } - } - correlatedAggMap := make(map[*ast.AggregateFuncExpr]int) - for _, aggFunc := range correlatedAggList { - colMap := make(map[*types.FieldName]struct{}, len(p.Schema().Columns)) - allColFromAggExprNode(p, aggFunc, colMap) - for k := range colMap { - colName := &ast.ColumnName{ - Schema: k.DBName, - Table: k.TblName, - Name: k.ColName, - } - // Add the column referred in the agg func into the select list. So that we can resolve the agg func correctly. - // And we need set the AuxiliaryColInAgg to true to help our only_full_group_by checker work correctly. - sel.Fields.Fields = append(sel.Fields.Fields, &ast.SelectField{ - Auxiliary: true, - AuxiliaryColInAgg: true, - Expr: &ast.ColumnNameExpr{Name: colName}, - }) - } - correlatedAggMap[aggFunc] = len(sel.Fields.Fields) - sel.Fields.Fields = append(sel.Fields.Fields, &ast.SelectField{ - Auxiliary: true, - Expr: aggFunc, - AsName: model.NewCIStr(fmt.Sprintf("sel_subq_agg_%d", len(sel.Fields.Fields))), - }) - } - return correlatedAggMap, nil -} - -// gbyResolver resolves group by items from select fields. -type gbyResolver struct { - ctx base.PlanContext - fields []*ast.SelectField - schema *expression.Schema - names []*types.FieldName - err error - inExpr bool - isParam bool - skipAggMap map[*ast.AggregateFuncExpr]*expression.CorrelatedColumn - - exprDepth int // exprDepth is the depth of current expression in expression tree. -} - -func (g *gbyResolver) Enter(inNode ast.Node) (ast.Node, bool) { - g.exprDepth++ - switch n := inNode.(type) { - case *ast.SubqueryExpr, *ast.CompareSubqueryExpr, *ast.ExistsSubqueryExpr: - return inNode, true - case *driver.ParamMarkerExpr: - g.isParam = true - if g.exprDepth == 1 && !n.UseAsValueInGbyByClause { - _, isNull, isExpectedType := getUintFromNode(g.ctx, n, false) - // For constant uint expression in top level, it should be treated as position expression. 
- if !isNull && isExpectedType { - return expression.ConstructPositionExpr(n), true - } - } - return n, true - case *driver.ValueExpr, *ast.ColumnNameExpr, *ast.ParenthesesExpr, *ast.ColumnName: - default: - g.inExpr = true - } - return inNode, false -} - -func (g *gbyResolver) Leave(inNode ast.Node) (ast.Node, bool) { - extractor := &AggregateFuncExtractor{skipAggMap: g.skipAggMap} - switch v := inNode.(type) { - case *ast.ColumnNameExpr: - idx, err := expression.FindFieldName(g.names, v.Name) - if idx < 0 || !g.inExpr { - var index int - index, g.err = resolveFromSelectFields(v, g.fields, false) - if g.err != nil { - g.err = plannererrors.ErrAmbiguous.GenWithStackByArgs(v.Name.Name.L, clauseMsg[groupByClause]) - return inNode, false - } - if idx >= 0 { - return inNode, true - } - if index != -1 { - ret := g.fields[index].Expr - ret.Accept(extractor) - if len(extractor.AggFuncs) != 0 { - err = plannererrors.ErrIllegalReference.GenWithStackByArgs(v.Name.OrigColName(), "reference to group function") - } else if ast.HasWindowFlag(ret) { - err = plannererrors.ErrIllegalReference.GenWithStackByArgs(v.Name.OrigColName(), "reference to window function") - } else { - if isParam, ok := ret.(*driver.ParamMarkerExpr); ok { - isParam.UseAsValueInGbyByClause = true - } - return ret, true - } - } - g.err = err - return inNode, false - } - case *ast.PositionExpr: - pos, isNull, err := expression.PosFromPositionExpr(g.ctx.GetExprCtx(), g.ctx, v) - if err != nil { - g.err = plannererrors.ErrUnknown.GenWithStackByArgs() - } - if err != nil || isNull { - return inNode, false - } - if pos < 1 || pos > len(g.fields) { - g.err = errors.Errorf("Unknown column '%d' in 'group statement'", pos) - return inNode, false - } - ret := g.fields[pos-1].Expr - ret.Accept(extractor) - if len(extractor.AggFuncs) != 0 || ast.HasWindowFlag(ret) { - fieldName := g.fields[pos-1].AsName.String() - if fieldName == "" { - fieldName = g.fields[pos-1].Text() - } - g.err = plannererrors.ErrWrongGroupField.GenWithStackByArgs(fieldName) - return inNode, false - } - return ret, true - case *ast.ValuesExpr: - if v.Column == nil { - g.err = plannererrors.ErrUnknownColumn.GenWithStackByArgs("", "VALUES() function") - } - } - return inNode, true -} - -func tblInfoFromCol(from ast.ResultSetNode, name *types.FieldName) *model.TableInfo { - tableList := ExtractTableList(from, true) - for _, field := range tableList { - if field.Name.L == name.TblName.L { - return field.TableInfo - } - } - return nil -} - -func buildFuncDependCol(p base.LogicalPlan, cond ast.ExprNode) (*types.FieldName, *types.FieldName, error) { - binOpExpr, ok := cond.(*ast.BinaryOperationExpr) - if !ok { - return nil, nil, nil - } - if binOpExpr.Op != opcode.EQ { - return nil, nil, nil - } - lColExpr, ok := binOpExpr.L.(*ast.ColumnNameExpr) - if !ok { - return nil, nil, nil - } - rColExpr, ok := binOpExpr.R.(*ast.ColumnNameExpr) - if !ok { - return nil, nil, nil - } - lIdx, err := expression.FindFieldName(p.OutputNames(), lColExpr.Name) - if err != nil { - return nil, nil, err - } - rIdx, err := expression.FindFieldName(p.OutputNames(), rColExpr.Name) - if err != nil { - return nil, nil, err - } - if lIdx == -1 { - return nil, nil, plannererrors.ErrUnknownColumn.GenWithStackByArgs(lColExpr.Name, "where clause") - } - if rIdx == -1 { - return nil, nil, plannererrors.ErrUnknownColumn.GenWithStackByArgs(rColExpr.Name, "where clause") - } - return p.OutputNames()[lIdx], p.OutputNames()[rIdx], nil -} - -func buildWhereFuncDepend(p base.LogicalPlan, where ast.ExprNode) 
(map[*types.FieldName]*types.FieldName, error) { - whereConditions := splitWhere(where) - colDependMap := make(map[*types.FieldName]*types.FieldName, 2*len(whereConditions)) - for _, cond := range whereConditions { - lCol, rCol, err := buildFuncDependCol(p, cond) - if err != nil { - return nil, err - } - if lCol == nil || rCol == nil { - continue - } - colDependMap[lCol] = rCol - colDependMap[rCol] = lCol - } - return colDependMap, nil -} - -func buildJoinFuncDepend(p base.LogicalPlan, from ast.ResultSetNode) (map[*types.FieldName]*types.FieldName, error) { - switch x := from.(type) { - case *ast.Join: - if x.On == nil { - return nil, nil - } - onConditions := splitWhere(x.On.Expr) - colDependMap := make(map[*types.FieldName]*types.FieldName, len(onConditions)) - for _, cond := range onConditions { - lCol, rCol, err := buildFuncDependCol(p, cond) - if err != nil { - return nil, err - } - if lCol == nil || rCol == nil { - continue - } - lTbl := tblInfoFromCol(x.Left, lCol) - if lTbl == nil { - lCol, rCol = rCol, lCol - } - switch x.Tp { - case ast.CrossJoin: - colDependMap[lCol] = rCol - colDependMap[rCol] = lCol - case ast.LeftJoin: - colDependMap[rCol] = lCol - case ast.RightJoin: - colDependMap[lCol] = rCol - } - } - return colDependMap, nil - default: - return nil, nil - } -} - -func checkColFuncDepend( - p base.LogicalPlan, - name *types.FieldName, - tblInfo *model.TableInfo, - gbyOrSingleValueColNames map[*types.FieldName]struct{}, - whereDependNames, joinDependNames map[*types.FieldName]*types.FieldName, -) bool { - for _, index := range tblInfo.Indices { - if !index.Unique { - continue - } - funcDepend := true - // if all columns of some unique/pri indexes are determined, all columns left are check-passed. - for _, indexCol := range index.Columns { - iColInfo := tblInfo.Columns[indexCol.Offset] - if !mysql.HasNotNullFlag(iColInfo.GetFlag()) { - funcDepend = false - break - } - cn := &ast.ColumnName{ - Schema: name.DBName, - Table: name.TblName, - Name: iColInfo.Name, - } - iIdx, err := expression.FindFieldName(p.OutputNames(), cn) - if err != nil || iIdx < 0 { - funcDepend = false - break - } - iName := p.OutputNames()[iIdx] - if _, ok := gbyOrSingleValueColNames[iName]; ok { - continue - } - if wCol, ok := whereDependNames[iName]; ok { - if _, ok = gbyOrSingleValueColNames[wCol]; ok { - continue - } - } - if jCol, ok := joinDependNames[iName]; ok { - if _, ok = gbyOrSingleValueColNames[jCol]; ok { - continue - } - } - funcDepend = false - break - } - if funcDepend { - return true - } - } - primaryFuncDepend := true - hasPrimaryField := false - for _, colInfo := range tblInfo.Columns { - if !mysql.HasPriKeyFlag(colInfo.GetFlag()) { - continue - } - hasPrimaryField = true - pkName := &ast.ColumnName{ - Schema: name.DBName, - Table: name.TblName, - Name: colInfo.Name, - } - pIdx, err := expression.FindFieldName(p.OutputNames(), pkName) - // It is possible that `pIdx < 0` and here is a case. - // ``` - // CREATE TABLE `BB` ( - // `pk` int(11) NOT NULL AUTO_INCREMENT, - // `col_int_not_null` int NOT NULL, - // PRIMARY KEY (`pk`) - // ); - // - // SELECT OUTR . col2 AS X - // FROM - // BB AS OUTR2 - // INNER JOIN - // (SELECT col_int_not_null AS col1, - // pk AS col2 - // FROM BB) AS OUTR ON OUTR2.col_int_not_null = OUTR.col1 - // GROUP BY OUTR2.col_int_not_null; - // ``` - // When we enter `checkColFuncDepend`, `pkName.Table` is `OUTR` which is an alias, while `pkName.Name` is `pk` - // which is a original name. Hence `expression.FindFieldName` will fail and `pIdx` will be less than 0. 
- // Currently, when we meet `pIdx < 0`, we directly regard `primaryFuncDepend` as false and jump out. This way is - // easy to implement but makes only-full-group-by checker not smart enough. Later we will refactor only-full-group-by - // checker and resolve the inconsistency between the alias table name and the original column name. - if err != nil || pIdx < 0 { - primaryFuncDepend = false - break - } - pCol := p.OutputNames()[pIdx] - if _, ok := gbyOrSingleValueColNames[pCol]; ok { - continue - } - if wCol, ok := whereDependNames[pCol]; ok { - if _, ok = gbyOrSingleValueColNames[wCol]; ok { - continue - } - } - if jCol, ok := joinDependNames[pCol]; ok { - if _, ok = gbyOrSingleValueColNames[jCol]; ok { - continue - } - } - primaryFuncDepend = false - break - } - return primaryFuncDepend && hasPrimaryField -} - -// ErrExprLoc is for generate the ErrFieldNotInGroupBy error info -type ErrExprLoc struct { - Offset int - Loc string -} - -func checkExprInGroupByOrIsSingleValue( - p base.LogicalPlan, - expr ast.ExprNode, - offset int, - loc string, - gbyOrSingleValueColNames map[*types.FieldName]struct{}, - gbyExprs []ast.ExprNode, - notInGbyOrSingleValueColNames map[*types.FieldName]ErrExprLoc, -) { - if _, ok := expr.(*ast.AggregateFuncExpr); ok { - return - } - if f, ok := expr.(*ast.FuncCallExpr); ok { - if f.FnName.L == ast.Grouping { - // just skip grouping function check here, because later in building plan phase, we - // will do the grouping function valid check. - return - } - } - if _, ok := expr.(*ast.ColumnNameExpr); !ok { - for _, gbyExpr := range gbyExprs { - if ast.ExpressionDeepEqual(gbyExpr, expr) { - return - } - } - } - // Function `any_value` can be used in aggregation, even `ONLY_FULL_GROUP_BY` is set. - // See https://dev.mysql.com/doc/refman/5.7/en/miscellaneous-functions.html#function_any-value for details - if f, ok := expr.(*ast.FuncCallExpr); ok { - if f.FnName.L == ast.AnyValue { - return - } - } - colMap := make(map[*types.FieldName]struct{}, len(p.Schema().Columns)) - allColFromExprNode(p, expr, colMap) - for col := range colMap { - if _, ok := gbyOrSingleValueColNames[col]; !ok { - notInGbyOrSingleValueColNames[col] = ErrExprLoc{Offset: offset, Loc: loc} - } - } -} - -func (b *PlanBuilder) checkOnlyFullGroupBy(p base.LogicalPlan, sel *ast.SelectStmt) (err error) { - if sel.GroupBy != nil { - err = b.checkOnlyFullGroupByWithGroupClause(p, sel) - } else { - err = b.checkOnlyFullGroupByWithOutGroupClause(p, sel) - } - return err -} - -func addGbyOrSingleValueColName(p base.LogicalPlan, colName *ast.ColumnName, gbyOrSingleValueColNames map[*types.FieldName]struct{}) { - idx, err := expression.FindFieldName(p.OutputNames(), colName) - if err != nil || idx < 0 { - return - } - gbyOrSingleValueColNames[p.OutputNames()[idx]] = struct{}{} -} - -func extractSingeValueColNamesFromWhere(p base.LogicalPlan, where ast.ExprNode, gbyOrSingleValueColNames map[*types.FieldName]struct{}) { - whereConditions := splitWhere(where) - for _, cond := range whereConditions { - binOpExpr, ok := cond.(*ast.BinaryOperationExpr) - if !ok || binOpExpr.Op != opcode.EQ { - continue - } - if colExpr, ok := binOpExpr.L.(*ast.ColumnNameExpr); ok { - if _, ok := binOpExpr.R.(ast.ValueExpr); ok { - addGbyOrSingleValueColName(p, colExpr.Name, gbyOrSingleValueColNames) - } - } else if colExpr, ok := binOpExpr.R.(*ast.ColumnNameExpr); ok { - if _, ok := binOpExpr.L.(ast.ValueExpr); ok { - addGbyOrSingleValueColName(p, colExpr.Name, gbyOrSingleValueColNames) - } - } - } -} - -func (*PlanBuilder) 
checkOnlyFullGroupByWithGroupClause(p base.LogicalPlan, sel *ast.SelectStmt) error { - gbyOrSingleValueColNames := make(map[*types.FieldName]struct{}, len(sel.Fields.Fields)) - gbyExprs := make([]ast.ExprNode, 0, len(sel.Fields.Fields)) - for _, byItem := range sel.GroupBy.Items { - expr := getInnerFromParenthesesAndUnaryPlus(byItem.Expr) - if colExpr, ok := expr.(*ast.ColumnNameExpr); ok { - addGbyOrSingleValueColName(p, colExpr.Name, gbyOrSingleValueColNames) - } else { - gbyExprs = append(gbyExprs, expr) - } - } - // MySQL permits a nonaggregate column not named in a GROUP BY clause when ONLY_FULL_GROUP_BY SQL mode is enabled, - // provided that this column is limited to a single value. - // See https://dev.mysql.com/doc/refman/5.7/en/group-by-handling.html for details. - extractSingeValueColNamesFromWhere(p, sel.Where, gbyOrSingleValueColNames) - - notInGbyOrSingleValueColNames := make(map[*types.FieldName]ErrExprLoc, len(sel.Fields.Fields)) - for offset, field := range sel.Fields.Fields { - if field.Auxiliary { - continue - } - checkExprInGroupByOrIsSingleValue(p, getInnerFromParenthesesAndUnaryPlus(field.Expr), offset, ErrExprInSelect, gbyOrSingleValueColNames, gbyExprs, notInGbyOrSingleValueColNames) - } - - if sel.OrderBy != nil { - for offset, item := range sel.OrderBy.Items { - if colName, ok := item.Expr.(*ast.ColumnNameExpr); ok { - index, err := resolveFromSelectFields(colName, sel.Fields.Fields, false) - if err != nil { - return err - } - // If the ByItem is in fields list, it has been checked already in above. - if index >= 0 { - continue - } - } - checkExprInGroupByOrIsSingleValue(p, getInnerFromParenthesesAndUnaryPlus(item.Expr), offset, ErrExprInOrderBy, gbyOrSingleValueColNames, gbyExprs, notInGbyOrSingleValueColNames) - } - } - if len(notInGbyOrSingleValueColNames) == 0 { - return nil - } - - whereDepends, err := buildWhereFuncDepend(p, sel.Where) - if err != nil { - return err - } - joinDepends, err := buildJoinFuncDepend(p, sel.From.TableRefs) - if err != nil { - return err - } - tblMap := make(map[*model.TableInfo]struct{}, len(notInGbyOrSingleValueColNames)) - for name, errExprLoc := range notInGbyOrSingleValueColNames { - tblInfo := tblInfoFromCol(sel.From.TableRefs, name) - if tblInfo == nil { - continue - } - if _, ok := tblMap[tblInfo]; ok { - continue - } - if checkColFuncDepend(p, name, tblInfo, gbyOrSingleValueColNames, whereDepends, joinDepends) { - tblMap[tblInfo] = struct{}{} - continue - } - switch errExprLoc.Loc { - case ErrExprInSelect: - if sel.GroupBy.Rollup { - return plannererrors.ErrFieldInGroupingNotGroupBy.GenWithStackByArgs(strconv.Itoa(errExprLoc.Offset + 1)) - } - return plannererrors.ErrFieldNotInGroupBy.GenWithStackByArgs(errExprLoc.Offset+1, errExprLoc.Loc, name.DBName.O+"."+name.TblName.O+"."+name.OrigColName.O) - case ErrExprInOrderBy: - return plannererrors.ErrFieldNotInGroupBy.GenWithStackByArgs(errExprLoc.Offset+1, errExprLoc.Loc, sel.OrderBy.Items[errExprLoc.Offset].Expr.Text()) - } - return nil - } - return nil -} - -func (*PlanBuilder) checkOnlyFullGroupByWithOutGroupClause(p base.LogicalPlan, sel *ast.SelectStmt) error { - resolver := colResolverForOnlyFullGroupBy{ - firstOrderByAggColIdx: -1, - } - resolver.curClause = fieldList - for idx, field := range sel.Fields.Fields { - resolver.exprIdx = idx - field.Accept(&resolver) - } - if len(resolver.nonAggCols) > 0 { - if sel.Having != nil { - sel.Having.Expr.Accept(&resolver) - } - if sel.OrderBy != nil { - resolver.curClause = orderByClause - for idx, byItem := range 
sel.OrderBy.Items { - resolver.exprIdx = idx - byItem.Expr.Accept(&resolver) - } - } - } - if resolver.firstOrderByAggColIdx != -1 && len(resolver.nonAggCols) > 0 { - // SQL like `select a from t where a = 1 order by count(b)` is illegal. - return plannererrors.ErrAggregateOrderNonAggQuery.GenWithStackByArgs(resolver.firstOrderByAggColIdx + 1) - } - if !resolver.hasAggFuncOrAnyValue || len(resolver.nonAggCols) == 0 { - return nil - } - singleValueColNames := make(map[*types.FieldName]struct{}, len(sel.Fields.Fields)) - extractSingeValueColNamesFromWhere(p, sel.Where, singleValueColNames) - whereDepends, err := buildWhereFuncDepend(p, sel.Where) - if err != nil { - return err - } - - joinDepends, err := buildJoinFuncDepend(p, sel.From.TableRefs) - if err != nil { - return err - } - tblMap := make(map[*model.TableInfo]struct{}, len(resolver.nonAggCols)) - for i, colName := range resolver.nonAggCols { - idx, err := expression.FindFieldName(p.OutputNames(), colName) - if err != nil || idx < 0 { - return plannererrors.ErrMixOfGroupFuncAndFields.GenWithStackByArgs(resolver.nonAggColIdxs[i]+1, colName.Name.O) - } - fieldName := p.OutputNames()[idx] - if _, ok := singleValueColNames[fieldName]; ok { - continue - } - tblInfo := tblInfoFromCol(sel.From.TableRefs, fieldName) - if tblInfo == nil { - continue - } - if _, ok := tblMap[tblInfo]; ok { - continue - } - if checkColFuncDepend(p, fieldName, tblInfo, singleValueColNames, whereDepends, joinDepends) { - tblMap[tblInfo] = struct{}{} - continue - } - return plannererrors.ErrMixOfGroupFuncAndFields.GenWithStackByArgs(resolver.nonAggColIdxs[i]+1, colName.Name.O) - } - return nil -} - -// colResolverForOnlyFullGroupBy visits Expr tree to find out if an Expr tree is an aggregation function. -// If so, find out the first column name that not in an aggregation function. 
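// Illustrative sketch, not part of the patched source: the ONLY_FULL_GROUP_BY checks above accept a
// non-grouped column when it is functionally determined by grouped or single-valued columns, possibly
// through a col = col equality collected from WHERE or JOIN ... ON. The toy lookup below restates that
// chain with plain string sets; the unique-key and primary-key part of checkColFuncDepend is left out.
package main

import "fmt"

// determined reports whether col is covered: either grouped/single-valued itself,
// or equal (via a WHERE or ON condition) to a column that is.
func determined(col string, grouped map[string]bool, whereEq, joinEq map[string]string) bool {
	if grouped[col] {
		return true
	}
	if peer, ok := whereEq[col]; ok && grouped[peer] {
		return true
	}
	if peer, ok := joinEq[col]; ok && grouped[peer] {
		return true
	}
	return false
}

func main() {
	grouped := map[string]bool{"t1.a": true}
	whereEq := map[string]string{"t2.a": "t1.a", "t1.a": "t2.a"} // from WHERE t1.a = t2.a
	joinEq := map[string]string{}
	// SELECT t2.a ... GROUP BY t1.a with WHERE t1.a = t2.a passes the check:
	fmt.Println(determined("t2.a", grouped, whereEq, joinEq)) // true
	fmt.Println(determined("t2.b", grouped, whereEq, joinEq)) // false, would raise ErrFieldNotInGroupBy
}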
-type colResolverForOnlyFullGroupBy struct { - nonAggCols []*ast.ColumnName - exprIdx int - nonAggColIdxs []int - hasAggFuncOrAnyValue bool - firstOrderByAggColIdx int - curClause clauseCode -} - -func (c *colResolverForOnlyFullGroupBy) Enter(node ast.Node) (ast.Node, bool) { - switch t := node.(type) { - case *ast.AggregateFuncExpr: - c.hasAggFuncOrAnyValue = true - if c.curClause == orderByClause { - c.firstOrderByAggColIdx = c.exprIdx - } - return node, true - case *ast.FuncCallExpr: - // enable function `any_value` in aggregation even `ONLY_FULL_GROUP_BY` is set - if t.FnName.L == ast.AnyValue { - c.hasAggFuncOrAnyValue = true - return node, true - } - case *ast.ColumnNameExpr: - c.nonAggCols = append(c.nonAggCols, t.Name) - c.nonAggColIdxs = append(c.nonAggColIdxs, c.exprIdx) - return node, true - case *ast.SubqueryExpr: - return node, true - } - return node, false -} - -func (*colResolverForOnlyFullGroupBy) Leave(node ast.Node) (ast.Node, bool) { - return node, true -} - -type aggColNameResolver struct { - colNameResolver -} - -func (*aggColNameResolver) Enter(inNode ast.Node) (ast.Node, bool) { - if _, ok := inNode.(*ast.ColumnNameExpr); ok { - return inNode, true - } - return inNode, false -} - -func allColFromAggExprNode(p base.LogicalPlan, n ast.Node, names map[*types.FieldName]struct{}) { - extractor := &aggColNameResolver{ - colNameResolver: colNameResolver{ - p: p, - names: names, - }, - } - n.Accept(extractor) -} - -type colNameResolver struct { - p base.LogicalPlan - names map[*types.FieldName]struct{} -} - -func (*colNameResolver) Enter(inNode ast.Node) (ast.Node, bool) { - switch inNode.(type) { - case *ast.ColumnNameExpr, *ast.SubqueryExpr, *ast.AggregateFuncExpr: - return inNode, true - } - return inNode, false -} - -func (c *colNameResolver) Leave(inNode ast.Node) (ast.Node, bool) { - if v, ok := inNode.(*ast.ColumnNameExpr); ok { - idx, err := expression.FindFieldName(c.p.OutputNames(), v.Name) - if err == nil && idx >= 0 { - c.names[c.p.OutputNames()[idx]] = struct{}{} - } - } - return inNode, true -} - -func allColFromExprNode(p base.LogicalPlan, n ast.Node, names map[*types.FieldName]struct{}) { - extractor := &colNameResolver{ - p: p, - names: names, - } - n.Accept(extractor) -} - -func (b *PlanBuilder) resolveGbyExprs(ctx context.Context, p base.LogicalPlan, gby *ast.GroupByClause, fields []*ast.SelectField) (base.LogicalPlan, []expression.Expression, bool, error) { - b.curClause = groupByClause - exprs := make([]expression.Expression, 0, len(gby.Items)) - resolver := &gbyResolver{ - ctx: b.ctx, - fields: fields, - schema: p.Schema(), - names: p.OutputNames(), - skipAggMap: b.correlatedAggMapper, - } - for _, item := range gby.Items { - resolver.inExpr = false - resolver.exprDepth = 0 - resolver.isParam = false - retExpr, _ := item.Expr.Accept(resolver) - if resolver.err != nil { - return nil, nil, false, errors.Trace(resolver.err) - } - if !resolver.isParam { - item.Expr = retExpr.(ast.ExprNode) - } - - itemExpr := retExpr.(ast.ExprNode) - expr, np, err := b.rewrite(ctx, itemExpr, p, nil, true) - if err != nil { - return nil, nil, false, err - } - - exprs = append(exprs, expr) - p = np - } - return p, exprs, gby.Rollup, nil -} - -func (*PlanBuilder) unfoldWildStar(p base.LogicalPlan, selectFields []*ast.SelectField) (resultList []*ast.SelectField, err error) { - join, isJoin := p.(*LogicalJoin) - for i, field := range selectFields { - if field.WildCard == nil { - resultList = append(resultList, field) - continue - } - if field.WildCard.Table.L == "" && i > 0 { - 
return nil, plannererrors.ErrInvalidWildCard - } - list := unfoldWildStar(field, p.OutputNames(), p.Schema().Columns) - // For sql like `select t1.*, t2.* from t1 join t2 using(a)` or `select t1.*, t2.* from t1 natual join t2`, - // the schema of the Join doesn't contain enough columns because the join keys are coalesced in this schema. - // We should collect the columns from the FullSchema. - if isJoin && join.FullSchema != nil && field.WildCard.Table.L != "" { - list = unfoldWildStar(field, join.FullNames, join.FullSchema.Columns) - } - if len(list) == 0 { - return nil, plannererrors.ErrBadTable.GenWithStackByArgs(field.WildCard.Table) - } - resultList = append(resultList, list...) - } - return resultList, nil -} - -func unfoldWildStar(field *ast.SelectField, outputName types.NameSlice, column []*expression.Column) (resultList []*ast.SelectField) { - dbName := field.WildCard.Schema - tblName := field.WildCard.Table - for i, name := range outputName { - col := column[i] - if col.IsHidden { - continue - } - if (dbName.L == "" || dbName.L == name.DBName.L) && - (tblName.L == "" || tblName.L == name.TblName.L) && - col.ID != model.ExtraHandleID && col.ID != model.ExtraPhysTblID { - colName := &ast.ColumnNameExpr{ - Name: &ast.ColumnName{ - Schema: name.DBName, - Table: name.TblName, - Name: name.ColName, - }} - colName.SetType(col.GetStaticType()) - field := &ast.SelectField{Expr: colName} - field.SetText(nil, name.ColName.O) - resultList = append(resultList, field) - } - } - return resultList -} - -func (b *PlanBuilder) addAliasName(ctx context.Context, selectStmt *ast.SelectStmt, p base.LogicalPlan) (resultList []*ast.SelectField, err error) { - selectFields := selectStmt.Fields.Fields - projOutNames := make([]*types.FieldName, 0, len(selectFields)) - for _, field := range selectFields { - colNameField, isColumnNameExpr := field.Expr.(*ast.ColumnNameExpr) - if isColumnNameExpr { - colName := colNameField.Name.Name - if field.AsName.L != "" { - colName = field.AsName - } - projOutNames = append(projOutNames, &types.FieldName{ - TblName: colNameField.Name.Table, - OrigTblName: colNameField.Name.Table, - ColName: colName, - OrigColName: colNameField.Name.Name, - DBName: colNameField.Name.Schema, - }) - } else { - // create view v as select name_const('col', 100); - // The column in v should be 'col', so we call `buildProjectionField` to handle this. 
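// Illustrative sketch, not part of the patched source: unfoldWildStar above expands `*` and `tbl.*`
// into concrete select fields by matching the optional schema/table qualifier against each output name
// and skipping hidden and extra (handle / physical table id) columns. The simplified model below folds
// all skipped columns into a single hidden flag; the types and data are hypothetical.
package main

import "fmt"

type outCol struct {
	db, tbl, col string
	hidden       bool
}

// expandWildcard returns the visible column names matched by an optional db/tbl qualifier;
// empty qualifiers match everything, as with a bare `*`.
func expandWildcard(db, tbl string, cols []outCol) []string {
	var out []string
	for _, c := range cols {
		if c.hidden {
			continue
		}
		if (db == "" || db == c.db) && (tbl == "" || tbl == c.tbl) {
			out = append(out, c.tbl+"."+c.col)
		}
	}
	return out
}

func main() {
	cols := []outCol{
		{db: "test", tbl: "t1", col: "a"},
		{db: "test", tbl: "t2", col: "b"},
		{db: "test", tbl: "t2", col: "_tidb_rowid", hidden: true},
	}
	fmt.Println(expandWildcard("", "", cols))   // [t1.a t2.b]  as for SELECT *
	fmt.Println(expandWildcard("", "t2", cols)) // [t2.b]       as for SELECT t2.*
}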
- _, name, err := b.buildProjectionField(ctx, p, field, nil) - if err != nil { - return nil, err - } - projOutNames = append(projOutNames, name) - } - } - - // dedupMap is used for renaming a duplicated anonymous column - dedupMap := make(map[string]int) - anonymousFields := make([]bool, len(selectFields)) - - for i, field := range selectFields { - newField := *field - if newField.AsName.L == "" { - newField.AsName = projOutNames[i].ColName - } - - if _, ok := field.Expr.(*ast.ColumnNameExpr); !ok && field.AsName.L == "" { - anonymousFields[i] = true - } else { - anonymousFields[i] = false - // dedupMap should be inited with all non-anonymous fields before renaming other duplicated anonymous fields - dedupMap[newField.AsName.L] = 0 - } - - resultList = append(resultList, &newField) - } - - // We should rename duplicated anonymous fields in the first SelectStmt of CreateViewStmt - // See: https://github.com/pingcap/tidb/issues/29326 - if selectStmt.AsViewSchema { - for i, field := range resultList { - if !anonymousFields[i] { - continue - } - - oldName := field.AsName - if dup, ok := dedupMap[field.AsName.L]; ok { - if dup == 0 { - field.AsName = model.NewCIStr(fmt.Sprintf("Name_exp_%s", field.AsName.O)) - } else { - field.AsName = model.NewCIStr(fmt.Sprintf("Name_exp_%d_%s", dup, field.AsName.O)) - } - dedupMap[oldName.L] = dup + 1 - } else { - dedupMap[oldName.L] = 0 - } - } - } - - return resultList, nil -} - -func (b *PlanBuilder) pushHintWithoutTableWarning(hint *ast.TableOptimizerHint) { - var sb strings.Builder - ctx := format.NewRestoreCtx(0, &sb) - if err := hint.Restore(ctx); err != nil { - return - } - b.ctx.GetSessionVars().StmtCtx.SetHintWarning( - fmt.Sprintf("Hint %s is inapplicable. Please specify the table names in the arguments.", sb.String())) -} - -func (b *PlanBuilder) pushTableHints(hints []*ast.TableOptimizerHint, currentLevel int) { - hints = b.hintProcessor.GetCurrentStmtHints(hints, currentLevel) - currentDB := b.ctx.GetSessionVars().CurrentDB - warnHandler := b.ctx.GetSessionVars().StmtCtx - planHints, subQueryHintFlags, err := h.ParsePlanHints(hints, currentLevel, currentDB, - b.hintProcessor, b.ctx.GetSessionVars().StmtCtx.StraightJoinOrder, - b.subQueryCtx == handlingExistsSubquery, b.subQueryCtx == notHandlingSubquery, warnHandler) - if err != nil { - return - } - b.tableHintInfo = append(b.tableHintInfo, planHints) - b.subQueryHintFlags |= subQueryHintFlags -} - -func (b *PlanBuilder) popVisitInfo() { - if len(b.visitInfo) == 0 { - return - } - b.visitInfo = b.visitInfo[:len(b.visitInfo)-1] -} - -func (b *PlanBuilder) popTableHints() { - hintInfo := b.tableHintInfo[len(b.tableHintInfo)-1] - for _, warning := range h.CollectUnmatchedHintWarnings(hintInfo) { - b.ctx.GetSessionVars().StmtCtx.SetHintWarning(warning) - } - b.tableHintInfo = b.tableHintInfo[:len(b.tableHintInfo)-1] -} - -// TableHints returns the *TableHintInfo of PlanBuilder. -func (b *PlanBuilder) TableHints() *h.PlanHints { - if len(b.tableHintInfo) == 0 { - return nil - } - return b.tableHintInfo[len(b.tableHintInfo)-1] -} - -func (b *PlanBuilder) buildSelect(ctx context.Context, sel *ast.SelectStmt) (p base.LogicalPlan, err error) { - b.pushSelectOffset(sel.QueryBlockOffset) - b.pushTableHints(sel.TableHints, sel.QueryBlockOffset) - defer func() { - b.popSelectOffset() - // table hints are only visible in the current SELECT statement. 
- b.popTableHints() - }() - if b.buildingRecursivePartForCTE { - if sel.Distinct || sel.OrderBy != nil || sel.Limit != nil { - return nil, plannererrors.ErrNotSupportedYet.GenWithStackByArgs("ORDER BY / LIMIT / SELECT DISTINCT in recursive query block of Common Table Expression") - } - if sel.GroupBy != nil { - return nil, plannererrors.ErrCTERecursiveForbidsAggregation.FastGenByArgs(b.genCTETableNameForError()) - } - } - if sel.SelectStmtOpts != nil { - origin := b.inStraightJoin - b.inStraightJoin = sel.SelectStmtOpts.StraightJoin - defer func() { b.inStraightJoin = origin }() - } - - var ( - aggFuncs []*ast.AggregateFuncExpr - havingMap, orderMap, totalMap map[*ast.AggregateFuncExpr]int - windowAggMap map[*ast.AggregateFuncExpr]int - correlatedAggMap map[*ast.AggregateFuncExpr]int - gbyCols []expression.Expression - projExprs []expression.Expression - rollup bool - ) - - // set for update read to true before building result set node - if isForUpdateReadSelectLock(sel.LockInfo) { - b.isForUpdateRead = true - } - - if hints := b.TableHints(); hints != nil && hints.CTEMerge { - // Verify Merge hints in the current query, - // we will update parameters for those that meet the rules, and warn those that do not. - // If the current query uses Merge Hint and the query is a CTE, - // we update the HINT information for the current query. - // If the current query is not a CTE query (it may be a subquery within a CTE query - // or an external non-CTE query), we will give a warning. - // In particular, recursive CTE have separate warnings, so they are no longer called. - if b.buildingCTE { - if b.isCTE { - b.outerCTEs[len(b.outerCTEs)-1].forceInlineByHintOrVar = true - } else if !b.buildingRecursivePartForCTE { - // If there has subquery which is not CTE and using `MERGE()` hint, we will show this warning; - b.ctx.GetSessionVars().StmtCtx.SetHintWarning( - "Hint merge() is inapplicable. " + - "Please check whether the hint is used in the right place, " + - "you should use this hint inside the CTE.") - } - } else if !b.buildingCTE && !b.isCTE { - b.ctx.GetSessionVars().StmtCtx.SetHintWarning( - "Hint merge() is inapplicable. " + - "Please check whether the hint is used in the right place, " + - "you should use this hint inside the CTE.") - } - } - - var currentLayerCTEs []*cteInfo - if sel.With != nil { - l := len(b.outerCTEs) - defer func() { - b.outerCTEs = b.outerCTEs[:l] - }() - currentLayerCTEs, err = b.buildWith(ctx, sel.With) - if err != nil { - return nil, err - } - } - - p, err = b.buildTableRefs(ctx, sel.From) - if err != nil { - return nil, err - } - - originalFields := sel.Fields.Fields - sel.Fields.Fields, err = b.unfoldWildStar(p, sel.Fields.Fields) - if err != nil { - return nil, err - } - if b.capFlag&canExpandAST != 0 { - // To be compatible with MySQL, we add alias name for each select field when creating view. - sel.Fields.Fields, err = b.addAliasName(ctx, sel, p) - if err != nil { - return nil, err - } - originalFields = sel.Fields.Fields - } - - if sel.GroupBy != nil { - p, gbyCols, rollup, err = b.resolveGbyExprs(ctx, p, sel.GroupBy, sel.Fields.Fields) - if err != nil { - return nil, err - } - } - - if b.ctx.GetSessionVars().SQLMode.HasOnlyFullGroupBy() && sel.From != nil && !b.ctx.GetSessionVars().OptimizerEnableNewOnlyFullGroupByCheck { - err = b.checkOnlyFullGroupBy(p, sel) - if err != nil { - return nil, err - } - } - - hasWindowFuncField := b.detectSelectWindow(sel) - // Some SQL statements define WINDOW but do not use them. 
But we also need to check the window specification list. - // For example: select id from t group by id WINDOW w AS (ORDER BY uids DESC) ORDER BY id; - // We don't use the WINDOW w, but if the 'uids' column is not in the table t, we still need to report an error. - if hasWindowFuncField || sel.WindowSpecs != nil { - if b.buildingRecursivePartForCTE { - return nil, plannererrors.ErrCTERecursiveForbidsAggregation.FastGenByArgs(b.genCTETableNameForError()) - } - - windowAggMap, err = b.resolveWindowFunction(sel, p) - if err != nil { - return nil, err - } - } - // We must resolve having and order by clause before build projection, - // because when the query is "select a+1 as b from t having sum(b) < 0", we must replace sum(b) to sum(a+1), - // which only can be done before building projection and extracting Agg functions. - havingMap, orderMap, err = b.resolveHavingAndOrderBy(ctx, sel, p) - if err != nil { - return nil, err - } - - // We have to resolve correlated aggregate inside sub-queries before building aggregation and building projection, - // for instance, count(a) inside the sub-query of "select (select count(a)) from t" should be evaluated within - // the context of the outer query. So we have to extract such aggregates from sub-queries and put them into - // SELECT field list. - correlatedAggMap, err = b.resolveCorrelatedAggregates(ctx, sel, p) - if err != nil { - return nil, err - } - - // b.allNames will be used in evalDefaultExpr(). Default function is special because it needs to find the - // corresponding column name, but does not need the value in the column. - // For example, select a from t order by default(b), the column b will not be in select fields. Also because - // buildSort is after buildProjection, so we need get OutputNames before BuildProjection and store in allNames. - // Otherwise, we will get select fields instead of all OutputNames, so that we can't find the column b in the - // above example. - b.allNames = append(b.allNames, p.OutputNames()) - defer func() { b.allNames = b.allNames[:len(b.allNames)-1] }() - - if sel.Where != nil { - p, err = b.buildSelection(ctx, p, sel.Where, nil) - if err != nil { - return nil, err - } - } - l := sel.LockInfo - if l != nil && l.LockType != ast.SelectLockNone { - for _, tName := range l.Tables { - // CTE has no *model.HintedTable, we need to skip it. - if tName.TableInfo == nil { - continue - } - b.ctx.GetSessionVars().StmtCtx.LockTableIDs[tName.TableInfo.ID] = struct{}{} - } - p, err = b.buildSelectLock(p, l) - if err != nil { - return nil, err - } - } - b.handleHelper.popMap() - b.handleHelper.pushMap(nil) - - hasAgg := b.detectSelectAgg(sel) - needBuildAgg := hasAgg - if hasAgg { - if b.buildingRecursivePartForCTE { - return nil, plannererrors.ErrCTERecursiveForbidsAggregation.GenWithStackByArgs(b.genCTETableNameForError()) - } - - aggFuncs, totalMap = b.extractAggFuncsInSelectFields(sel.Fields.Fields) - // len(aggFuncs) == 0 and sel.GroupBy == nil indicates that all the aggregate functions inside the SELECT fields - // are actually correlated aggregates from the outer query, which have already been built in the outer query. - // The only thing we need to do is to find them from b.correlatedAggMap in buildProjection. - if len(aggFuncs) == 0 && sel.GroupBy == nil { - needBuildAgg = false - } - } - if needBuildAgg { - // if rollup syntax is specified, Expand OP is required to replicate the data to feed different grouping layout. 
- if rollup { - p, gbyCols, err = b.buildExpand(p, gbyCols) - if err != nil { - return nil, err - } - } - var aggIndexMap map[int]int - p, aggIndexMap, err = b.buildAggregation(ctx, p, aggFuncs, gbyCols, correlatedAggMap) - if err != nil { - return nil, err - } - for agg, idx := range totalMap { - totalMap[agg] = aggIndexMap[idx] - } - } - - var oldLen int - // According to https://dev.mysql.com/doc/refman/8.0/en/window-functions-usage.html, - // we can only process window functions after having clause, so `considerWindow` is false now. - p, projExprs, oldLen, err = b.buildProjection(ctx, p, sel.Fields.Fields, totalMap, nil, false, sel.OrderBy != nil) - if err != nil { - return nil, err - } - - if sel.Having != nil { - b.curClause = havingClause - p, err = b.buildSelection(ctx, p, sel.Having.Expr, havingMap) - if err != nil { - return nil, err - } - } - - b.windowSpecs, err = buildWindowSpecs(sel.WindowSpecs) - if err != nil { - return nil, err - } - - var windowMapper map[*ast.WindowFuncExpr]int - if hasWindowFuncField || sel.WindowSpecs != nil { - windowFuncs := extractWindowFuncs(sel.Fields.Fields) - // we need to check the func args first before we check the window spec - err := b.checkWindowFuncArgs(ctx, p, windowFuncs, windowAggMap) - if err != nil { - return nil, err - } - groupedFuncs, orderedSpec, err := b.groupWindowFuncs(windowFuncs) - if err != nil { - return nil, err - } - p, windowMapper, err = b.buildWindowFunctions(ctx, p, groupedFuncs, orderedSpec, windowAggMap) - if err != nil { - return nil, err - } - // `hasWindowFuncField == false` means there's only unused named window specs without window functions. - // In such case plan `p` is not changed, so we don't have to build another projection. - if hasWindowFuncField { - // Now we build the window function fields. - p, projExprs, oldLen, err = b.buildProjection(ctx, p, sel.Fields.Fields, windowAggMap, windowMapper, true, false) - if err != nil { - return nil, err - } - } - } - - if sel.Distinct { - p, err = b.buildDistinct(p, oldLen) - if err != nil { - return nil, err - } - } - - if sel.OrderBy != nil { - // We need to keep the ORDER BY clause for the following cases: - // 1. The select is top level query, order should be honored - // 2. The query has LIMIT clause - // 3. The control flag requires keeping ORDER BY explicitly - if len(b.qbOffset) == 1 || sel.Limit != nil || !b.ctx.GetSessionVars().RemoveOrderbyInSubquery { - if b.ctx.GetSessionVars().SQLMode.HasOnlyFullGroupBy() { - p, err = b.buildSortWithCheck(ctx, p, sel.OrderBy.Items, orderMap, windowMapper, projExprs, oldLen, sel.Distinct) - } else { - p, err = b.buildSort(ctx, p, sel.OrderBy.Items, orderMap, windowMapper) - } - if err != nil { - return nil, err - } - } - } - - if sel.Limit != nil { - p, err = b.buildLimit(p, sel.Limit) - if err != nil { - return nil, err - } - } - - sel.Fields.Fields = originalFields - if oldLen != p.Schema().Len() { - proj := logicalop.LogicalProjection{Exprs: expression.Column2Exprs(p.Schema().Columns[:oldLen])}.Init(b.ctx, b.getSelectOffset()) - proj.SetChildren(p) - schema := expression.NewSchema(p.Schema().Clone().Columns[:oldLen]...) 
- for _, col := range schema.Columns { - col.UniqueID = b.ctx.GetSessionVars().AllocPlanColumnID() - } - proj.SetOutputNames(p.OutputNames()[:oldLen]) - proj.SetSchema(schema) - return b.tryToBuildSequence(currentLayerCTEs, proj), nil - } - - return b.tryToBuildSequence(currentLayerCTEs, p), nil -} - -func (b *PlanBuilder) tryToBuildSequence(ctes []*cteInfo, p base.LogicalPlan) base.LogicalPlan { - if !b.ctx.GetSessionVars().EnableMPPSharedCTEExecution { - return p - } - for i := len(ctes) - 1; i >= 0; i-- { - if !ctes[i].nonRecursive { - return p - } - if ctes[i].isInline || ctes[i].cteClass == nil { - ctes = append(ctes[:i], ctes[i+1:]...) - } - } - if len(ctes) == 0 { - return p - } - lctes := make([]base.LogicalPlan, 0, len(ctes)+1) - for _, cte := range ctes { - lcte := LogicalCTE{ - Cte: cte.cteClass, - CteAsName: cte.def.Name, - CteName: cte.def.Name, - SeedStat: cte.seedStat, - OnlyUsedAsStorage: true, - }.Init(b.ctx, b.getSelectOffset()) - lcte.SetSchema(getResultCTESchema(cte.seedLP.Schema(), b.ctx.GetSessionVars())) - lctes = append(lctes, lcte) - } - b.optFlag |= flagPushDownSequence - seq := logicalop.LogicalSequence{}.Init(b.ctx, b.getSelectOffset()) - seq.SetChildren(append(lctes, p)...) - seq.SetOutputNames(p.OutputNames().Shallow()) - return seq -} - -func (b *PlanBuilder) buildTableDual() *logicalop.LogicalTableDual { - b.handleHelper.pushMap(nil) - return logicalop.LogicalTableDual{RowCount: 1}.Init(b.ctx, b.getSelectOffset()) -} - -func (ds *DataSource) newExtraHandleSchemaCol() *expression.Column { - tp := types.NewFieldType(mysql.TypeLonglong) - tp.SetFlag(mysql.NotNullFlag | mysql.PriKeyFlag) - return &expression.Column{ - RetType: tp, - UniqueID: ds.SCtx().GetSessionVars().AllocPlanColumnID(), - ID: model.ExtraHandleID, - OrigName: fmt.Sprintf("%v.%v.%v", ds.DBName, ds.TableInfo.Name, model.ExtraHandleName), - } -} - -// AddExtraPhysTblIDColumn for partition table. -// 'select ... for update' on a partition table need to know the partition ID -// to construct the lock key, so this column is added to the chunk row. -// Also needed for checking against the sessions transaction buffer -func (ds *DataSource) AddExtraPhysTblIDColumn() *expression.Column { - // Avoid adding multiple times (should never happen!) - cols := ds.TblCols - for i := len(cols) - 1; i >= 0; i-- { - if cols[i].ID == model.ExtraPhysTblID { - return cols[i] - } - } - pidCol := &expression.Column{ - RetType: types.NewFieldType(mysql.TypeLonglong), - UniqueID: ds.SCtx().GetSessionVars().AllocPlanColumnID(), - ID: model.ExtraPhysTblID, - OrigName: fmt.Sprintf("%v.%v.%v", ds.DBName, ds.TableInfo.Name, model.ExtraPhysTblIdName), - } - - ds.Columns = append(ds.Columns, model.NewExtraPhysTblIDColInfo()) - schema := ds.Schema() - schema.Append(pidCol) - ds.SetOutputNames(append(ds.OutputNames(), &types.FieldName{ - DBName: ds.DBName, - TblName: ds.TableInfo.Name, - ColName: model.ExtraPhysTblIdName, - OrigColName: model.ExtraPhysTblIdName, - })) - ds.TblCols = append(ds.TblCols, pidCol) - return pidCol -} - -// getStatsTable gets statistics information for a table specified by "tableID". -// A pseudo statistics table is returned in any of the following scenario: -// 1. tidb-server started and statistics handle has not been initialized. -// 2. table row count from statistics is zero. -// 3. statistics is outdated. -// Note: please also update getLatestVersionFromStatsTable() when logic in this function changes. 
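// Illustrative sketch, not part of the patched source: the comment above lists the cases in which
// getStatsTable falls back to pseudo statistics: the statistics handle is not initialized yet, the row
// count from statistics is zero, or the statistics are uninitialized or outdated. The boolean model
// below restates that decision; the parameter names are mine, not the planner's.
package main

import "fmt"

// usePseudoStats mirrors the three fallback rules described above.
func usePseudoStats(handleReady bool, realtimeRowCount int64, initialized, outdated, pseudoForOutdatedEnabled bool) bool {
	if !handleReady { // 1. statistics handle not initialized yet
		return true
	}
	if realtimeRowCount == 0 { // 2. row count from statistics is zero
		return true
	}
	// 3. statistics uninitialized, or outdated while the outdated-to-pseudo switch is on
	return !initialized || (pseudoForOutdatedEnabled && outdated)
}

func main() {
	fmt.Println(usePseudoStats(true, 0, true, false, false))   // true: empty stats
	fmt.Println(usePseudoStats(true, 1000, true, true, false)) // false: outdated but the switch is off
	fmt.Println(usePseudoStats(true, 1000, true, true, true))  // true: outdated and the switch is on
}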
-func getStatsTable(ctx base.PlanContext, tblInfo *model.TableInfo, pid int64) *statistics.Table { - statsHandle := domain.GetDomain(ctx).StatsHandle() - var usePartitionStats, countIs0, pseudoStatsForUninitialized, pseudoStatsForOutdated bool - var statsTbl *statistics.Table - if ctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { - debugtrace.EnterContextCommon(ctx) - defer func() { - debugTraceGetStatsTbl(ctx, - tblInfo, - pid, - statsHandle == nil, - usePartitionStats, - countIs0, - pseudoStatsForUninitialized, - pseudoStatsForOutdated, - statsTbl, - ) - debugtrace.LeaveContextCommon(ctx) - }() - } - // 1. tidb-server started and statistics handle has not been initialized. - if statsHandle == nil { - return statistics.PseudoTable(tblInfo, false, true) - } - - if pid == tblInfo.ID || ctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { - statsTbl = statsHandle.GetTableStats(tblInfo) - } else { - usePartitionStats = true - statsTbl = statsHandle.GetPartitionStats(tblInfo, pid) - } - intest.Assert(statsTbl.ColAndIdxExistenceMap != nil, "The existence checking map must not be nil.") - - allowPseudoTblTriggerLoading := false - // In OptObjectiveDeterminate mode, we need to ignore the real-time stats. - // To achieve this, we copy the statsTbl and reset the real-time stats fields (set ModifyCount to 0 and set - // RealtimeCount to the row count from the ANALYZE, which is fetched from loaded stats in GetAnalyzeRowCount()). - if ctx.GetSessionVars().GetOptObjective() == variable.OptObjectiveDeterminate { - analyzeCount := max(int64(statsTbl.GetAnalyzeRowCount()), 0) - // If the two fields are already the values we want, we don't need to modify it, and also we don't need to copy. - if statsTbl.RealtimeCount != analyzeCount || statsTbl.ModifyCount != 0 { - // Here is a case that we need specially care about: - // The original stats table from the stats cache is not a pseudo table, but the analyze row count is 0 (probably - // because of no col/idx stats are loaded), which will makes it a pseudo table according to the rule 2 below. - // Normally, a pseudo table won't trigger stats loading since we assume it means "no stats available", but - // in such case, we need it able to trigger stats loading. - // That's why we use the special allowPseudoTblTriggerLoading flag here. - if !statsTbl.Pseudo && statsTbl.RealtimeCount > 0 && analyzeCount == 0 { - allowPseudoTblTriggerLoading = true - } - // Copy it so we can modify the ModifyCount and the RealtimeCount safely. - statsTbl = statsTbl.ShallowCopy() - statsTbl.RealtimeCount = analyzeCount - statsTbl.ModifyCount = 0 - } - } - - // 2. table row count from statistics is zero. - if statsTbl.RealtimeCount == 0 { - countIs0 = true - core_metrics.PseudoEstimationNotAvailable.Inc() - return statistics.PseudoTable(tblInfo, allowPseudoTblTriggerLoading, true) - } - - // 3. statistics is uninitialized or outdated. - pseudoStatsForUninitialized = !statsTbl.IsInitialized() - pseudoStatsForOutdated = ctx.GetSessionVars().GetEnablePseudoForOutdatedStats() && statsTbl.IsOutdated() - if pseudoStatsForUninitialized || pseudoStatsForOutdated { - tbl := *statsTbl - tbl.Pseudo = true - statsTbl = &tbl - if pseudoStatsForUninitialized { - core_metrics.PseudoEstimationNotAvailable.Inc() - } else { - core_metrics.PseudoEstimationOutdate.Inc() - } - } - - return statsTbl -} - -// getLatestVersionFromStatsTable gets statistics information for a table specified by "tableID", and get the max -// LastUpdateVersion among all Columns and Indices in it. 
-// Its overall logic is quite similar to getStatsTable(). During plan cache matching, only the latest version is needed. -// In such case, compared to getStatsTable(), this function can save some copies, memory allocations and unnecessary -// checks. Also, this function won't trigger metrics changes. -func getLatestVersionFromStatsTable(ctx sessionctx.Context, tblInfo *model.TableInfo, pid int64) (version uint64) { - statsHandle := domain.GetDomain(ctx).StatsHandle() - // 1. tidb-server started and statistics handle has not been initialized. Pseudo stats table. - if statsHandle == nil { - return 0 - } - - var statsTbl *statistics.Table - if pid == tblInfo.ID || ctx.GetSessionVars().StmtCtx.UseDynamicPartitionPrune() { - statsTbl = statsHandle.GetTableStats(tblInfo) - } else { - statsTbl = statsHandle.GetPartitionStats(tblInfo, pid) - } - - // 2. Table row count from statistics is zero. Pseudo stats table. - realtimeRowCount := statsTbl.RealtimeCount - if ctx.GetSessionVars().GetOptObjective() == variable.OptObjectiveDeterminate { - realtimeRowCount = max(int64(statsTbl.GetAnalyzeRowCount()), 0) - } - if realtimeRowCount == 0 { - return 0 - } - - // 3. Not pseudo stats table. Return the max LastUpdateVersion among all Columns and Indices - // return statsTbl.LastAnalyzeVersion - statsTbl.ForEachColumnImmutable(func(_ int64, col *statistics.Column) bool { - version = max(version, col.LastUpdateVersion) - return false - }) - statsTbl.ForEachIndexImmutable(func(_ int64, idx *statistics.Index) bool { - version = max(version, idx.LastUpdateVersion) - return false - }) - return version -} - -func (b *PlanBuilder) tryBuildCTE(ctx context.Context, tn *ast.TableName, asName *model.CIStr) (base.LogicalPlan, error) { - for i := len(b.outerCTEs) - 1; i >= 0; i-- { - cte := b.outerCTEs[i] - if cte.def.Name.L == tn.Name.L { - if cte.isBuilding { - if cte.nonRecursive { - // Can't see this CTE, try outer definition. - continue - } - - // Building the recursive part. - cte.useRecursive = true - if cte.seedLP == nil { - return nil, plannererrors.ErrCTERecursiveRequiresNonRecursiveFirst.FastGenByArgs(tn.Name.String()) - } - - if cte.enterSubquery || cte.recursiveRef { - return nil, plannererrors.ErrInvalidRequiresSingleReference.FastGenByArgs(tn.Name.String()) - } - - cte.recursiveRef = true - p := logicalop.LogicalCTETable{Name: cte.def.Name.String(), IDForStorage: cte.storageID, SeedStat: cte.seedStat, SeedSchema: cte.seedLP.Schema()}.Init(b.ctx, b.getSelectOffset()) - p.SetSchema(getResultCTESchema(cte.seedLP.Schema(), b.ctx.GetSessionVars())) - p.SetOutputNames(cte.seedLP.OutputNames()) - return p, nil - } - - b.handleHelper.pushMap(nil) - - hasLimit := false - limitBeg := uint64(0) - limitEnd := uint64(0) - if cte.limitLP != nil { - hasLimit = true - switch x := cte.limitLP.(type) { - case *logicalop.LogicalLimit: - limitBeg = x.Offset - limitEnd = x.Offset + x.Count - case *logicalop.LogicalTableDual: - // Beg and End will both be 0. 
- default: - return nil, errors.Errorf("invalid type for limit plan: %v", cte.limitLP) - } - } - - if cte.cteClass == nil { - cte.cteClass = &CTEClass{ - IsDistinct: cte.isDistinct, - seedPartLogicalPlan: cte.seedLP, - recursivePartLogicalPlan: cte.recurLP, - IDForStorage: cte.storageID, - optFlag: cte.optFlag, - HasLimit: hasLimit, - LimitBeg: limitBeg, - LimitEnd: limitEnd, - pushDownPredicates: make([]expression.Expression, 0), - ColumnMap: make(map[string]*expression.Column), - } - } - var p base.LogicalPlan - lp := LogicalCTE{CteAsName: tn.Name, CteName: tn.Name, Cte: cte.cteClass, SeedStat: cte.seedStat}.Init(b.ctx, b.getSelectOffset()) - prevSchema := cte.seedLP.Schema().Clone() - lp.SetSchema(getResultCTESchema(cte.seedLP.Schema(), b.ctx.GetSessionVars())) - - // If current CTE query contain another CTE which 'containAggOrWindow' is true, current CTE 'containAggOrWindow' will be true - if b.buildingCTE { - b.outerCTEs[len(b.outerCTEs)-1].containAggOrWindow = cte.containAggOrWindow || b.outerCTEs[len(b.outerCTEs)-1].containAggOrWindow - } - // Compute cte inline - b.computeCTEInlineFlag(cte) - - if cte.recurLP == nil && cte.isInline { - saveCte := make([]*cteInfo, len(b.outerCTEs[i:])) - copy(saveCte, b.outerCTEs[i:]) - b.outerCTEs = b.outerCTEs[:i] - o := b.buildingCTE - b.buildingCTE = false - //nolint:all_revive,revive - defer func() { - b.outerCTEs = append(b.outerCTEs, saveCte...) - b.buildingCTE = o - }() - return b.buildDataSourceFromCTEMerge(ctx, cte.def) - } - - for i, col := range lp.Schema().Columns { - lp.Cte.ColumnMap[string(col.HashCode())] = prevSchema.Columns[i] - } - p = lp - p.SetOutputNames(cte.seedLP.OutputNames()) - if len(asName.String()) > 0 { - lp.CteAsName = *asName - var on types.NameSlice - for _, name := range p.OutputNames() { - cpOn := *name - cpOn.TblName = *asName - on = append(on, &cpOn) - } - p.SetOutputNames(on) - } - return p, nil - } - } - - return nil, nil -} - -// computeCTEInlineFlag, Combine the declaration of CTE and the use of CTE to jointly determine **whether a CTE can be inlined** -/* - There are some cases that CTE must be not inlined. - 1. CTE is recursive CTE. - 2. CTE contains agg or window and it is referenced by recursive part of CTE. - 3. Consumer count of CTE is more than one. - If 1 or 2 conditions are met, CTE cannot be inlined. - But if query is hint by 'merge()' or session variable "tidb_opt_force_inline_cte", - CTE will still not be inlined but a warning will be recorded "Hint or session variables are invalid" - If 3 condition is met, CTE can be inlined by hint and session variables. 
-*/ -func (b *PlanBuilder) computeCTEInlineFlag(cte *cteInfo) { - if cte.recurLP != nil { - if cte.forceInlineByHintOrVar { - b.ctx.GetSessionVars().StmtCtx.SetHintWarning( - fmt.Sprintf("Recursive CTE %s can not be inlined by merge() or tidb_opt_force_inline_cte.", cte.def.Name)) - } - } else if cte.containAggOrWindow && b.buildingRecursivePartForCTE { - if cte.forceInlineByHintOrVar { - b.ctx.GetSessionVars().StmtCtx.AppendWarning(plannererrors.ErrCTERecursiveForbidsAggregation.FastGenByArgs(cte.def.Name)) - } - } else if cte.consumerCount > 1 { - if cte.forceInlineByHintOrVar { - cte.isInline = true - } - } else { - cte.isInline = true - } -} - -func (b *PlanBuilder) buildDataSourceFromCTEMerge(ctx context.Context, cte *ast.CommonTableExpression) (base.LogicalPlan, error) { - p, err := b.buildResultSetNode(ctx, cte.Query.Query, true) - if err != nil { - return nil, err - } - b.handleHelper.popMap() - outPutNames := p.OutputNames() - for _, name := range outPutNames { - name.TblName = cte.Name - name.DBName = model.NewCIStr(b.ctx.GetSessionVars().CurrentDB) - } - - if len(cte.ColNameList) > 0 { - if len(cte.ColNameList) != len(p.OutputNames()) { - return nil, errors.New("CTE columns length is not consistent") - } - for i, n := range cte.ColNameList { - outPutNames[i].ColName = n - } - } - p.SetOutputNames(outPutNames) - return p, nil -} - -func (b *PlanBuilder) buildDataSource(ctx context.Context, tn *ast.TableName, asName *model.CIStr) (base.LogicalPlan, error) { - b.optFlag |= flagPredicateSimplification - dbName := tn.Schema - sessionVars := b.ctx.GetSessionVars() - - if dbName.L == "" { - // Try CTE. - p, err := b.tryBuildCTE(ctx, tn, asName) - if err != nil || p != nil { - return p, err - } - dbName = model.NewCIStr(sessionVars.CurrentDB) - } - - is := b.is - if len(b.buildingViewStack) > 0 { - // For tables in view, always ignore local temporary table, considering the below case: - // If a user created a normal table `t1` and a view `v1` referring `t1`, and then a local temporary table with a same name `t1` is created. - // At this time, executing 'select * from v1' should still return all records from normal table `t1` instead of temporary table `t1`. 
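// Illustrative sketch, not part of the original change: the decision table that
// computeCTEInlineFlag above implements, reduced to a pure function. The function
// name and parameters here are hypothetical simplifications; the real method also
// records warnings when merge() or tidb_opt_force_inline_cte cannot be honored.
package sketch

// canInlineCTE reports whether a non-recursive CTE may be expanded inline.
// forceInline stands for a merge() hint or tidb_opt_force_inline_cte = ON.
func canInlineCTE(isRecursive, containsAggOrWindow, buildingRecursivePart bool,
	consumerCount int, forceInline bool) bool {
	switch {
	case isRecursive:
		// Rule 1: recursive CTEs are never inlined.
		return false
	case containsAggOrWindow && buildingRecursivePart:
		// Rule 2: an agg/window CTE referenced by the recursive part is not inlined.
		return false
	case consumerCount > 1:
		// Rule 3: multiple consumers inline only when forced by hint or variable.
		return forceInline
	default:
		return true
	}
}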
- is = temptable.DetachLocalTemporaryTableInfoSchema(is) - } - - tbl, err := is.TableByName(ctx, dbName, tn.Name) - if err != nil { - return nil, err - } - - tbl, err = tryLockMDLAndUpdateSchemaIfNecessary(ctx, b.ctx, dbName, tbl, b.is) - if err != nil { - return nil, err - } - tableInfo := tbl.Meta() - - if b.isCreateView && tableInfo.TempTableType == model.TempTableLocal { - return nil, plannererrors.ErrViewSelectTemporaryTable.GenWithStackByArgs(tn.Name) - } - - var authErr error - if sessionVars.User != nil { - authErr = plannererrors.ErrTableaccessDenied.FastGenByArgs("SELECT", sessionVars.User.AuthUsername, sessionVars.User.AuthHostname, tableInfo.Name.L) - } - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SelectPriv, dbName.L, tableInfo.Name.L, "", authErr) - - if tbl.Type().IsVirtualTable() { - if tn.TableSample != nil { - return nil, expression.ErrInvalidTableSample.GenWithStackByArgs("Unsupported TABLESAMPLE in virtual tables") - } - return b.buildMemTable(ctx, dbName, tableInfo) - } - - tblName := *asName - if tblName.L == "" { - tblName = tn.Name - } - - if tableInfo.GetPartitionInfo() != nil { - // If `UseDynamicPruneMode` already been false, then we don't need to check whether execute `flagPartitionProcessor` - // otherwise we need to check global stats initialized for each partition table - if !b.ctx.GetSessionVars().IsDynamicPartitionPruneEnabled() { - b.optFlag = b.optFlag | flagPartitionProcessor - } else { - if !b.ctx.GetSessionVars().StmtCtx.UseDynamicPruneMode { - b.optFlag = b.optFlag | flagPartitionProcessor - } else { - h := domain.GetDomain(b.ctx).StatsHandle() - tblStats := h.GetTableStats(tableInfo) - isDynamicEnabled := b.ctx.GetSessionVars().IsDynamicPartitionPruneEnabled() - globalStatsReady := tblStats.IsAnalyzed() - skipMissingPartition := b.ctx.GetSessionVars().SkipMissingPartitionStats - // If we already enabled the tidb_skip_missing_partition_stats, the global stats can be treated as exist. - allowDynamicWithoutStats := fixcontrol.GetBoolWithDefault(b.ctx.GetSessionVars().GetOptimizerFixControlMap(), fixcontrol.Fix44262, skipMissingPartition) - - // If dynamic partition prune isn't enabled or global stats is not ready, we won't enable dynamic prune mode in query - usePartitionProcessor := !isDynamicEnabled || (!globalStatsReady && !allowDynamicWithoutStats) - - failpoint.Inject("forceDynamicPrune", func(val failpoint.Value) { - if val.(bool) { - if isDynamicEnabled { - usePartitionProcessor = false - } - } - }) - - if usePartitionProcessor { - b.optFlag = b.optFlag | flagPartitionProcessor - b.ctx.GetSessionVars().StmtCtx.UseDynamicPruneMode = false - if isDynamicEnabled { - b.ctx.GetSessionVars().StmtCtx.AppendWarning( - fmt.Errorf("disable dynamic pruning due to %s has no global stats", tableInfo.Name.String())) - } - } - } - } - pt := tbl.(table.PartitionedTable) - // check partition by name. 
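// Illustrative sketch, not part of the original change: the fallback condition
// computed above for partitioned tables, isolated as a pure function with
// hypothetical parameter names. The real code additionally consults a failpoint
// and appends a "no global stats" warning when it falls back.
package sketch

// useStaticPrune reports whether the planner should add flagPartitionProcessor,
// i.e. fall back to static pruning: dynamic pruning is disabled, or global stats
// are not ready and neither fix control 44262 nor
// tidb_skip_missing_partition_stats allows running without them.
func useStaticPrune(dynamicEnabled, globalStatsReady, allowDynamicWithoutStats bool) bool {
	return !dynamicEnabled || (!globalStatsReady && !allowDynamicWithoutStats)
}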
- if len(tn.PartitionNames) > 0 { - pids := make(map[int64]struct{}, len(tn.PartitionNames)) - for _, name := range tn.PartitionNames { - pid, err := tables.FindPartitionByName(tableInfo, name.L) - if err != nil { - return nil, err - } - pids[pid] = struct{}{} - } - pt = tables.NewPartitionTableWithGivenSets(pt, pids) - } - b.partitionedTable = append(b.partitionedTable, pt) - } else if len(tn.PartitionNames) != 0 { - return nil, plannererrors.ErrPartitionClauseOnNonpartitioned - } - - possiblePaths, err := getPossibleAccessPaths(b.ctx, b.TableHints(), tn.IndexHints, tbl, dbName, tblName, b.isForUpdateRead, b.optFlag&flagPartitionProcessor > 0) - if err != nil { - return nil, err - } - - if tableInfo.IsView() { - if tn.TableSample != nil { - return nil, expression.ErrInvalidTableSample.GenWithStackByArgs("Unsupported TABLESAMPLE in views") - } - - // Get the hints belong to the current view. - currentQBNameMap4View := make(map[string][]ast.HintTable) - currentViewHints := make(map[string][]*ast.TableOptimizerHint) - for qbName, viewQBNameHintTable := range b.hintProcessor.ViewQBNameToTable { - if len(viewQBNameHintTable) == 0 { - continue - } - viewSelectOffset := b.getSelectOffset() - - var viewHintSelectOffset int - if viewQBNameHintTable[0].QBName.L == "" { - // If we do not explicit set the qbName, we will set the empty qb name to @sel_1. - viewHintSelectOffset = 1 - } else { - viewHintSelectOffset = b.hintProcessor.GetHintOffset(viewQBNameHintTable[0].QBName, viewSelectOffset) - } - - // Check whether the current view can match the view name in the hint. - if viewQBNameHintTable[0].TableName.L == tblName.L && viewHintSelectOffset == viewSelectOffset { - // If the view hint can match the current view, we pop the first view table in the query block hint's table list. - // It means the hint belong the current view, the first view name in hint is matched. - // Because of the nested views, so we should check the left table list in hint when build the data source from the view inside the current view. - currentQBNameMap4View[qbName] = viewQBNameHintTable[1:] - currentViewHints[qbName] = b.hintProcessor.ViewQBNameToHints[qbName] - b.hintProcessor.ViewQBNameUsed[qbName] = struct{}{} - } - } - return b.BuildDataSourceFromView(ctx, dbName, tableInfo, currentQBNameMap4View, currentViewHints) - } - - if tableInfo.IsSequence() { - if tn.TableSample != nil { - return nil, expression.ErrInvalidTableSample.GenWithStackByArgs("Unsupported TABLESAMPLE in sequences") - } - // When the source is a Sequence, we convert it to a TableDual, as what most databases do. - return b.buildTableDual(), nil - } - - // remain tikv access path to generate point get acceess path if existed - // see detail in issue: https://github.com/pingcap/tidb/issues/39543 - if !(b.isForUpdateRead && b.ctx.GetSessionVars().TxnCtx.IsExplicit) { - // Skip storage engine check for CreateView. - if b.capFlag&canExpandAST == 0 { - possiblePaths, err = filterPathByIsolationRead(b.ctx, possiblePaths, tblName, dbName) - if err != nil { - return nil, err - } - } - } - - // Try to substitute generate column only if there is an index on generate column. - for _, index := range tableInfo.Indices { - if index.State != model.StatePublic { - continue - } - for _, indexCol := range index.Columns { - colInfo := tbl.Cols()[indexCol.Offset] - if colInfo.IsGenerated() && !colInfo.GeneratedStored { - b.optFlag |= flagGcSubstitute - break - } - } - } - - var columns []*table.Column - if b.inUpdateStmt { - // create table t(a int, b int). 
- // Imagine that, There are 2 TiDB instances in the cluster, name A, B. We add a column `c` to table t in the TiDB cluster. - // One of the TiDB, A, the column type in its infoschema is changed to public. And in the other TiDB, the column type is - // still StateWriteReorganization. - // TiDB A: insert into t values(1, 2, 3); - // TiDB B: update t set a = 2 where b = 2; - // If we use tbl.Cols() here, the update statement, will ignore the col `c`, and the data `3` will lost. - columns = tbl.WritableCols() - } else if b.inDeleteStmt { - // DeletableCols returns all columns of the table in deletable states. - columns = tbl.DeletableCols() - } else { - columns = tbl.Cols() - } - // extract the IndexMergeHint - var indexMergeHints []h.HintedIndex - if hints := b.TableHints(); hints != nil { - for i, hint := range hints.IndexMergeHintList { - if hint.Match(dbName, tblName) { - hints.IndexMergeHintList[i].Matched = true - // check whether the index names in IndexMergeHint are valid. - invalidIdxNames := make([]string, 0, len(hint.IndexHint.IndexNames)) - for _, idxName := range hint.IndexHint.IndexNames { - hasIdxName := false - for _, path := range possiblePaths { - if path.IsTablePath() { - if idxName.L == "primary" { - hasIdxName = true - break - } - continue - } - if idxName.L == path.Index.Name.L { - hasIdxName = true - break - } - } - if !hasIdxName { - invalidIdxNames = append(invalidIdxNames, idxName.String()) - } - } - if len(invalidIdxNames) == 0 { - indexMergeHints = append(indexMergeHints, hint) - } else { - // Append warning if there are invalid index names. - errMsg := fmt.Sprintf("use_index_merge(%s) is inapplicable, check whether the indexes (%s) "+ - "exist, or the indexes are conflicted with use_index/ignore_index/force_index hints.", - hint.IndexString(), strings.Join(invalidIdxNames, ", ")) - b.ctx.GetSessionVars().StmtCtx.SetHintWarning(errMsg) - } - } - } - } - ds := DataSource{ - DBName: dbName, - TableAsName: asName, - table: tbl, - TableInfo: tableInfo, - PhysicalTableID: tableInfo.ID, - AstIndexHints: tn.IndexHints, - IndexHints: b.TableHints().IndexHintList, - IndexMergeHints: indexMergeHints, - PossibleAccessPaths: possiblePaths, - Columns: make([]*model.ColumnInfo, 0, len(columns)), - PartitionNames: tn.PartitionNames, - TblCols: make([]*expression.Column, 0, len(columns)), - PreferPartitions: make(map[int][]model.CIStr), - IS: b.is, - IsForUpdateRead: b.isForUpdateRead, - }.Init(b.ctx, b.getSelectOffset()) - var handleCols util.HandleCols - schema := expression.NewSchema(make([]*expression.Column, 0, len(columns))...) - names := make([]*types.FieldName, 0, len(columns)) - for i, col := range columns { - ds.Columns = append(ds.Columns, col.ToInfo()) - names = append(names, &types.FieldName{ - DBName: dbName, - TblName: tableInfo.Name, - ColName: col.Name, - OrigTblName: tableInfo.Name, - OrigColName: col.Name, - // For update statement and delete statement, internal version should see the special middle state column, while user doesn't. - NotExplicitUsable: col.State != model.StatePublic, - }) - newCol := &expression.Column{ - UniqueID: sessionVars.AllocPlanColumnID(), - ID: col.ID, - RetType: col.FieldType.Clone(), - OrigName: names[i].String(), - IsHidden: col.Hidden, - } - if col.IsPKHandleColumn(tableInfo) { - handleCols = util.NewIntHandleCols(newCol) - } - schema.Append(newCol) - ds.TblCols = append(ds.TblCols, newCol) - } - // We append an extra handle column to the schema when the handle - // column is not the primary key of "ds". 
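// Illustrative sketch, not part of the original change: the name validation done
// for use_index_merge above, over simplified types (the real loop walks
// util.AccessPath values and emits a hint warning instead of returning a slice).
package sketch

// invalidIndexMergeNames returns the hinted names that match neither the
// integer-handle table path (spelled "primary") nor any candidate index.
// The hint is kept only when this result is empty.
func invalidIndexMergeNames(hinted, indexNames []string, hasTablePath bool) []string {
	valid := make(map[string]struct{}, len(indexNames)+1)
	for _, name := range indexNames {
		valid[name] = struct{}{}
	}
	if hasTablePath {
		valid["primary"] = struct{}{}
	}
	var invalid []string
	for _, name := range hinted {
		if _, ok := valid[name]; !ok {
			invalid = append(invalid, name)
		}
	}
	return invalid
}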
- if handleCols == nil { - if tableInfo.IsCommonHandle { - primaryIdx := tables.FindPrimaryIndex(tableInfo) - handleCols = util.NewCommonHandleCols(b.ctx.GetSessionVars().StmtCtx, tableInfo, primaryIdx, ds.TblCols) - } else { - extraCol := ds.newExtraHandleSchemaCol() - handleCols = util.NewIntHandleCols(extraCol) - ds.Columns = append(ds.Columns, model.NewExtraHandleColInfo()) - schema.Append(extraCol) - names = append(names, &types.FieldName{ - DBName: dbName, - TblName: tableInfo.Name, - ColName: model.ExtraHandleName, - OrigColName: model.ExtraHandleName, - }) - ds.TblCols = append(ds.TblCols, extraCol) - } - } - ds.HandleCols = handleCols - ds.UnMutableHandleCols = handleCols - handleMap := make(map[int64][]util.HandleCols) - handleMap[tableInfo.ID] = []util.HandleCols{handleCols} - b.handleHelper.pushMap(handleMap) - ds.SetSchema(schema) - ds.SetOutputNames(names) - ds.setPreferredStoreType(b.TableHints()) - ds.SampleInfo = tablesampler.NewTableSampleInfo(tn.TableSample, schema, b.partitionedTable) - b.isSampling = ds.SampleInfo != nil - - for i, colExpr := range ds.Schema().Columns { - var expr expression.Expression - if i < len(columns) { - if columns[i].IsGenerated() && !columns[i].GeneratedStored { - var err error - originVal := b.allowBuildCastArray - b.allowBuildCastArray = true - expr, _, err = b.rewrite(ctx, columns[i].GeneratedExpr.Clone(), ds, nil, true) - b.allowBuildCastArray = originVal - if err != nil { - return nil, err - } - colExpr.VirtualExpr = expr.Clone() - } - } - } - - // Init CommonHandleCols and CommonHandleLens for data source. - if tableInfo.IsCommonHandle { - ds.CommonHandleCols, ds.CommonHandleLens = expression.IndexInfo2Cols(ds.Columns, ds.Schema().Columns, tables.FindPrimaryIndex(tableInfo)) - } - // Init FullIdxCols, FullIdxColLens for accessPaths. - for _, path := range ds.PossibleAccessPaths { - if !path.IsIntHandlePath { - path.FullIdxCols, path.FullIdxColLens = expression.IndexInfo2Cols(ds.Columns, ds.Schema().Columns, path.Index) - - // check whether the path's index has a tidb_shard() prefix and the index column count - // more than 1. e.g. 
index(tidb_shard(a), a) - // set UkShardIndexPath only for unique secondary index - if !path.IsCommonHandlePath { - // tidb_shard expression must be first column of index - col := path.FullIdxCols[0] - if col != nil && - expression.GcColumnExprIsTidbShard(col.VirtualExpr) && - len(path.Index.Columns) > 1 && - path.Index.Unique { - path.IsUkShardIndexPath = true - ds.ContainExprPrefixUk = true - } - } - } - } - - var result base.LogicalPlan = ds - dirty := tableHasDirtyContent(b.ctx, tableInfo) - if dirty || tableInfo.TempTableType == model.TempTableLocal || tableInfo.TableCacheStatusType == model.TableCacheStatusEnable { - us := logicalop.LogicalUnionScan{HandleCols: handleCols}.Init(b.ctx, b.getSelectOffset()) - us.SetChildren(ds) - if tableInfo.Partition != nil && b.optFlag&flagPartitionProcessor == 0 { - // Adding ExtraPhysTblIDCol for UnionScan (transaction buffer handling) - // Not using old static prune mode - // Single TableReader for all partitions, needs the PhysTblID from storage - _ = ds.AddExtraPhysTblIDColumn() - } - result = us - } - - // Adding ExtraPhysTblIDCol for SelectLock (SELECT FOR UPDATE) is done when building SelectLock - - if sessionVars.StmtCtx.TblInfo2UnionScan == nil { - sessionVars.StmtCtx.TblInfo2UnionScan = make(map[*model.TableInfo]bool) - } - sessionVars.StmtCtx.TblInfo2UnionScan[tableInfo] = dirty - - return result, nil -} - -func (b *PlanBuilder) timeRangeForSummaryTable() util.QueryTimeRange { - const defaultSummaryDuration = 30 * time.Minute - hints := b.TableHints() - // User doesn't use TIME_RANGE hint - if hints == nil || (hints.TimeRangeHint.From == "" && hints.TimeRangeHint.To == "") { - to := time.Now() - from := to.Add(-defaultSummaryDuration) - return util.QueryTimeRange{From: from, To: to} - } - - // Parse time specified by user via TIM_RANGE hint - parse := func(s string) (time.Time, bool) { - t, err := time.ParseInLocation(util.MetricTableTimeFormat, s, time.Local) - if err != nil { - b.ctx.GetSessionVars().StmtCtx.AppendWarning(err) - } - return t, err == nil - } - from, fromValid := parse(hints.TimeRangeHint.From) - to, toValid := parse(hints.TimeRangeHint.To) - switch { - case !fromValid && !toValid: - to = time.Now() - from = to.Add(-defaultSummaryDuration) - case fromValid && !toValid: - to = from.Add(defaultSummaryDuration) - case !fromValid && toValid: - from = to.Add(-defaultSummaryDuration) - } - - return util.QueryTimeRange{From: from, To: to} -} - -func (b *PlanBuilder) buildMemTable(_ context.Context, dbName model.CIStr, tableInfo *model.TableInfo) (base.LogicalPlan, error) { - // We can use the `TableInfo.Columns` directly because the memory table has - // a stable schema and there is no online DDL on the memory table. - schema := expression.NewSchema(make([]*expression.Column, 0, len(tableInfo.Columns))...) 
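// Illustrative sketch, not part of the original change: the windowing rule used by
// timeRangeForSummaryTable above. Parameter names and the layout argument are
// hypothetical; the real code uses util.MetricTableTimeFormat and records a warning
// for each bound that fails to parse.
package sketch

import "time"

const defaultSummarySpan = 30 * time.Minute

// summaryTimeRange resolves the TIME_RANGE hint bounds. If neither bound parses,
// the window is [now-30m, now]; if only one parses, the other is derived from it
// with the same 30-minute default span.
func summaryTimeRange(fromStr, toStr, layout string, now time.Time) (time.Time, time.Time) {
	from, fromErr := time.ParseInLocation(layout, fromStr, time.Local)
	to, toErr := time.ParseInLocation(layout, toStr, time.Local)
	switch {
	case fromErr != nil && toErr != nil:
		to = now
		from = to.Add(-defaultSummarySpan)
	case fromErr == nil && toErr != nil:
		to = from.Add(defaultSummarySpan)
	case fromErr != nil && toErr == nil:
		from = to.Add(-defaultSummarySpan)
	}
	return from, to
}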
- names := make([]*types.FieldName, 0, len(tableInfo.Columns)) - var handleCols util.HandleCols - for _, col := range tableInfo.Columns { - names = append(names, &types.FieldName{ - DBName: dbName, - TblName: tableInfo.Name, - ColName: col.Name, - OrigTblName: tableInfo.Name, - OrigColName: col.Name, - }) - // NOTE: Rewrite the expression if memory table supports generated columns in the future - newCol := &expression.Column{ - UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), - ID: col.ID, - RetType: &col.FieldType, - } - if tableInfo.PKIsHandle && mysql.HasPriKeyFlag(col.GetFlag()) { - handleCols = util.NewIntHandleCols(newCol) - } - schema.Append(newCol) - } - - if handleCols != nil { - handleMap := make(map[int64][]util.HandleCols) - handleMap[tableInfo.ID] = []util.HandleCols{handleCols} - b.handleHelper.pushMap(handleMap) - } else { - b.handleHelper.pushMap(nil) - } - - // NOTE: Add a `LogicalUnionScan` if we support update memory table in the future - p := logicalop.LogicalMemTable{ - DBName: dbName, - TableInfo: tableInfo, - Columns: make([]*model.ColumnInfo, len(tableInfo.Columns)), - }.Init(b.ctx, b.getSelectOffset()) - p.SetSchema(schema) - p.SetOutputNames(names) - copy(p.Columns, tableInfo.Columns) - - // Some memory tables can receive some predicates - switch dbName.L { - case util2.MetricSchemaName.L: - p.Extractor = newMetricTableExtractor() - case util2.InformationSchemaName.L: - switch upTbl := strings.ToUpper(tableInfo.Name.O); upTbl { - case infoschema.TableClusterConfig, infoschema.TableClusterLoad, infoschema.TableClusterHardware, infoschema.TableClusterSystemInfo: - p.Extractor = &ClusterTableExtractor{} - case infoschema.TableClusterLog: - p.Extractor = &ClusterLogTableExtractor{} - case infoschema.TableTiDBHotRegionsHistory: - p.Extractor = &HotRegionsHistoryTableExtractor{} - case infoschema.TableInspectionResult: - p.Extractor = &InspectionResultTableExtractor{} - p.QueryTimeRange = b.timeRangeForSummaryTable() - case infoschema.TableInspectionSummary: - p.Extractor = &InspectionSummaryTableExtractor{} - p.QueryTimeRange = b.timeRangeForSummaryTable() - case infoschema.TableInspectionRules: - p.Extractor = &InspectionRuleTableExtractor{} - case infoschema.TableMetricSummary, infoschema.TableMetricSummaryByLabel: - p.Extractor = &MetricSummaryTableExtractor{} - p.QueryTimeRange = b.timeRangeForSummaryTable() - case infoschema.TableSlowQuery: - p.Extractor = &SlowQueryExtractor{} - case infoschema.TableStorageStats: - p.Extractor = &TableStorageStatsExtractor{} - case infoschema.TableTiFlashTables, infoschema.TableTiFlashSegments: - p.Extractor = &TiFlashSystemTableExtractor{} - case infoschema.TableStatementsSummary, infoschema.TableStatementsSummaryHistory: - p.Extractor = &StatementsSummaryExtractor{} - case infoschema.TableTiKVRegionPeers: - p.Extractor = &TikvRegionPeersExtractor{} - case infoschema.TableColumns: - p.Extractor = &ColumnsTableExtractor{} - case infoschema.TableTables: - ex := &InfoSchemaTablesExtractor{} - ex.initExtractableColNames(upTbl) - p.Extractor = ex - case infoschema.TablePartitions: - ex := &InfoSchemaPartitionsExtractor{} - ex.initExtractableColNames(upTbl) - p.Extractor = ex - case infoschema.TableStatistics: - ex := &InfoSchemaStatisticsExtractor{} - ex.initExtractableColNames(upTbl) - p.Extractor = ex - case infoschema.TableSchemata: - ex := &InfoSchemaSchemataExtractor{} - ex.initExtractableColNames(upTbl) - p.Extractor = ex - case infoschema.TableReferConst, - infoschema.TableKeyColumn, - infoschema.TableSequences, - 
infoschema.TableCheckConstraints, - infoschema.TableTiDBCheckConstraints, - infoschema.TableTiDBIndexUsage, - infoschema.TableTiDBIndexes, - infoschema.TableViews, - infoschema.TableConstraints: - ex := &InfoSchemaBaseExtractor{} - ex.initExtractableColNames(upTbl) - p.Extractor = ex - case infoschema.TableTiKVRegionStatus: - p.Extractor = &TiKVRegionStatusExtractor{tablesID: make([]int64, 0)} - } - } - return p, nil -} - -// checkRecursiveView checks whether this view is recursively defined. -func (b *PlanBuilder) checkRecursiveView(dbName model.CIStr, tableName model.CIStr) (func(), error) { - viewFullName := dbName.L + "." + tableName.L - if b.buildingViewStack == nil { - b.buildingViewStack = set.NewStringSet() - } - // If this view has already been on the building stack, it means - // this view contains a recursive definition. - if b.buildingViewStack.Exist(viewFullName) { - return nil, plannererrors.ErrViewRecursive.GenWithStackByArgs(dbName.O, tableName.O) - } - // If the view is being renamed, we return the mysql compatible error message. - if b.capFlag&renameView != 0 && viewFullName == b.renamingViewName { - return nil, plannererrors.ErrNoSuchTable.GenWithStackByArgs(dbName.O, tableName.O) - } - b.buildingViewStack.Insert(viewFullName) - return func() { delete(b.buildingViewStack, viewFullName) }, nil -} - -// BuildDataSourceFromView is used to build base.LogicalPlan from view -// qbNameMap4View and viewHints are used for the view's hint. -// qbNameMap4View maps the query block name to the view table lists. -// viewHints group the view hints based on the view's query block name. -func (b *PlanBuilder) BuildDataSourceFromView(ctx context.Context, dbName model.CIStr, tableInfo *model.TableInfo, qbNameMap4View map[string][]ast.HintTable, viewHints map[string][]*ast.TableOptimizerHint) (base.LogicalPlan, error) { - viewDepth := b.ctx.GetSessionVars().StmtCtx.ViewDepth - b.ctx.GetSessionVars().StmtCtx.ViewDepth++ - deferFunc, err := b.checkRecursiveView(dbName, tableInfo.Name) - if err != nil { - return nil, err - } - defer deferFunc() - - charset, collation := b.ctx.GetSessionVars().GetCharsetInfo() - viewParser := parser.New() - viewParser.SetParserConfig(b.ctx.GetSessionVars().BuildParserConfig()) - selectNode, err := viewParser.ParseOneStmt(tableInfo.View.SelectStmt, charset, collation) - if err != nil { - return nil, err - } - originalVisitInfo := b.visitInfo - b.visitInfo = make([]visitInfo, 0) - - // For the case that views appear in CTE queries, - // we need to save the CTEs after the views are established. - var saveCte []*cteInfo - if len(b.outerCTEs) > 0 { - saveCte = make([]*cteInfo, len(b.outerCTEs)) - copy(saveCte, b.outerCTEs) - } else { - saveCte = nil - } - o := b.buildingCTE - b.buildingCTE = false - defer func() { - b.outerCTEs = saveCte - b.buildingCTE = o - }() - - hintProcessor := h.NewQBHintHandler(b.ctx.GetSessionVars().StmtCtx) - selectNode.Accept(hintProcessor) - currentQbNameMap4View := make(map[string][]ast.HintTable) - currentQbHints4View := make(map[string][]*ast.TableOptimizerHint) - currentQbHints := make(map[int][]*ast.TableOptimizerHint) - currentQbNameMap := make(map[string]int) - - for qbName, viewQbNameHint := range qbNameMap4View { - // Check whether the view hint belong the current view or its nested views. 
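// Illustrative sketch, not part of the original change: the re-entrancy guard that
// checkRecursiveView above builds on a string set, reduced to a plain map. The real
// code distinguishes ErrViewRecursive from the rename-view case; the error text and
// type names here are hypothetical.
package sketch

import "fmt"

type viewGuard map[string]struct{}

// enter pushes "db.view" onto the building stack and fails if it is already there,
// which means the view definition references itself (directly or indirectly).
// The returned func pops the entry and is meant to be deferred by the caller.
func (g viewGuard) enter(db, view string) (func(), error) {
	key := db + "." + view
	if _, building := g[key]; building {
		return nil, fmt.Errorf("view %s is recursively defined", key)
	}
	g[key] = struct{}{}
	return func() { delete(g, key) }, nil
}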
- qbOffset := -1 - if len(viewQbNameHint) == 0 { - qbOffset = 1 - } else if len(viewQbNameHint) == 1 && viewQbNameHint[0].TableName.L == "" { - qbOffset = hintProcessor.GetHintOffset(viewQbNameHint[0].QBName, -1) - } else { - currentQbNameMap4View[qbName] = viewQbNameHint - currentQbHints4View[qbName] = viewHints[qbName] - } - - if qbOffset != -1 { - // If the hint belongs to the current view and not belongs to it's nested views, we should convert the view hint to the normal hint. - // After we convert the view hint to the normal hint, it can be reused the origin hint's infrastructure. - currentQbHints[qbOffset] = viewHints[qbName] - currentQbNameMap[qbName] = qbOffset - - delete(qbNameMap4View, qbName) - delete(viewHints, qbName) - } - } - - hintProcessor.ViewQBNameToTable = qbNameMap4View - hintProcessor.ViewQBNameToHints = viewHints - hintProcessor.ViewQBNameUsed = make(map[string]struct{}) - hintProcessor.QBOffsetToHints = currentQbHints - hintProcessor.QBNameToSelOffset = currentQbNameMap - - originHintProcessor := b.hintProcessor - originPlannerSelectBlockAsName := b.ctx.GetSessionVars().PlannerSelectBlockAsName.Load() - b.hintProcessor = hintProcessor - newPlannerSelectBlockAsName := make([]ast.HintTable, hintProcessor.MaxSelectStmtOffset()+1) - b.ctx.GetSessionVars().PlannerSelectBlockAsName.Store(&newPlannerSelectBlockAsName) - defer func() { - b.hintProcessor.HandleUnusedViewHints() - b.hintProcessor = originHintProcessor - b.ctx.GetSessionVars().PlannerSelectBlockAsName.Store(originPlannerSelectBlockAsName) - }() - selectLogicalPlan, err := b.Build(ctx, selectNode) - if err != nil { - logutil.BgLogger().Error("build plan for view failed", zap.Error(err)) - if terror.ErrorNotEqual(err, plannererrors.ErrViewRecursive) && - terror.ErrorNotEqual(err, plannererrors.ErrNoSuchTable) && - terror.ErrorNotEqual(err, plannererrors.ErrInternal) && - terror.ErrorNotEqual(err, plannererrors.ErrFieldNotInGroupBy) && - terror.ErrorNotEqual(err, plannererrors.ErrMixOfGroupFuncAndFields) && - terror.ErrorNotEqual(err, plannererrors.ErrViewNoExplain) && - terror.ErrorNotEqual(err, plannererrors.ErrNotSupportedYet) { - err = plannererrors.ErrViewInvalid.GenWithStackByArgs(dbName.O, tableInfo.Name.O) - } - return nil, err - } - pm := privilege.GetPrivilegeManager(b.ctx) - if viewDepth != 0 && - b.ctx.GetSessionVars().StmtCtx.InExplainStmt && - pm != nil && - !pm.RequestVerification(b.ctx.GetSessionVars().ActiveRoles, dbName.L, tableInfo.Name.L, "", mysql.SelectPriv) { - return nil, plannererrors.ErrViewNoExplain - } - if tableInfo.View.Security == model.SecurityDefiner { - if pm != nil { - for _, v := range b.visitInfo { - if !pm.RequestVerificationWithUser(v.db, v.table, v.column, v.privilege, tableInfo.View.Definer) { - return nil, plannererrors.ErrViewInvalid.GenWithStackByArgs(dbName.O, tableInfo.Name.O) - } - } - } - b.visitInfo = b.visitInfo[:0] - } - b.visitInfo = append(originalVisitInfo, b.visitInfo...) 
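// Illustrative sketch, not part of the original change: the error-rewriting rule
// applied above when building a view's underlying SELECT fails. A small allow-list
// of planner errors is passed through unchanged; everything else is reported as a
// generic "view invalid" error. The sentinel values below are hypothetical
// stand-ins for the plannererrors values compared with terror.ErrorNotEqual.
package sketch

import "errors"

var errViewInvalid = errors.New("view is invalid")

// rewriteViewBuildError hides internal build failures behind errViewInvalid unless
// the error is one the caller is allowed to see as-is.
func rewriteViewBuildError(err error, passThrough []error) error {
	if err == nil {
		return nil
	}
	for _, sentinel := range passThrough {
		if errors.Is(err, sentinel) {
			return err
		}
	}
	return errViewInvalid
}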
- - if b.ctx.GetSessionVars().StmtCtx.InExplainStmt { - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.ShowViewPriv, dbName.L, tableInfo.Name.L, "", plannererrors.ErrViewNoExplain) - } - - if len(tableInfo.Columns) != selectLogicalPlan.Schema().Len() { - return nil, plannererrors.ErrViewInvalid.GenWithStackByArgs(dbName.O, tableInfo.Name.O) - } - - return b.buildProjUponView(ctx, dbName, tableInfo, selectLogicalPlan) -} - -func (b *PlanBuilder) buildProjUponView(_ context.Context, dbName model.CIStr, tableInfo *model.TableInfo, selectLogicalPlan base.Plan) (base.LogicalPlan, error) { - columnInfo := tableInfo.Cols() - cols := selectLogicalPlan.Schema().Clone().Columns - outputNamesOfUnderlyingSelect := selectLogicalPlan.OutputNames().Shallow() - // In the old version of VIEW implementation, TableInfo.View.Cols is used to - // store the origin columns' names of the underlying SelectStmt used when - // creating the view. - if tableInfo.View.Cols != nil { - cols = cols[:0] - outputNamesOfUnderlyingSelect = outputNamesOfUnderlyingSelect[:0] - for _, info := range columnInfo { - idx := expression.FindFieldNameIdxByColName(selectLogicalPlan.OutputNames(), info.Name.L) - if idx == -1 { - return nil, plannererrors.ErrViewInvalid.GenWithStackByArgs(dbName.O, tableInfo.Name.O) - } - cols = append(cols, selectLogicalPlan.Schema().Columns[idx]) - outputNamesOfUnderlyingSelect = append(outputNamesOfUnderlyingSelect, selectLogicalPlan.OutputNames()[idx]) - } - } - - projSchema := expression.NewSchema(make([]*expression.Column, 0, len(tableInfo.Columns))...) - projExprs := make([]expression.Expression, 0, len(tableInfo.Columns)) - projNames := make(types.NameSlice, 0, len(tableInfo.Columns)) - for i, name := range outputNamesOfUnderlyingSelect { - origColName := name.ColName - if tableInfo.View.Cols != nil { - origColName = tableInfo.View.Cols[i] - } - projNames = append(projNames, &types.FieldName{ - // TblName is the of view instead of the name of the underlying table. - TblName: tableInfo.Name, - OrigTblName: name.OrigTblName, - ColName: columnInfo[i].Name, - OrigColName: origColName, - DBName: dbName, - }) - projSchema.Append(&expression.Column{ - UniqueID: cols[i].UniqueID, - RetType: cols[i].GetStaticType(), - }) - projExprs = append(projExprs, cols[i]) - } - projUponView := logicalop.LogicalProjection{Exprs: projExprs}.Init(b.ctx, b.getSelectOffset()) - projUponView.SetOutputNames(projNames) - projUponView.SetChildren(selectLogicalPlan.(base.LogicalPlan)) - projUponView.SetSchema(projSchema) - return projUponView, nil -} - -// buildApplyWithJoinType builds apply plan with outerPlan and innerPlan, which apply join with particular join type for -// every row from outerPlan and the whole innerPlan. -func (b *PlanBuilder) buildApplyWithJoinType(outerPlan, innerPlan base.LogicalPlan, tp JoinType, markNoDecorrelate bool) base.LogicalPlan { - b.optFlag = b.optFlag | flagPredicatePushDown | flagBuildKeyInfo | flagDecorrelate | flagConvertOuterToInnerJoin - ap := LogicalApply{LogicalJoin: LogicalJoin{JoinType: tp}, NoDecorrelate: markNoDecorrelate}.Init(b.ctx, b.getSelectOffset()) - ap.SetChildren(outerPlan, innerPlan) - ap.SetOutputNames(make([]*types.FieldName, outerPlan.Schema().Len()+innerPlan.Schema().Len())) - copy(ap.OutputNames(), outerPlan.OutputNames()) - ap.SetSchema(expression.MergeSchema(outerPlan.Schema(), innerPlan.Schema())) - setIsInApplyForCTE(innerPlan, ap.Schema()) - // Note that, tp can only be LeftOuterJoin or InnerJoin, so we don't consider other outer joins. 
- if tp == LeftOuterJoin { - b.optFlag = b.optFlag | flagEliminateOuterJoin - util.ResetNotNullFlag(ap.Schema(), outerPlan.Schema().Len(), ap.Schema().Len()) - } - for i := outerPlan.Schema().Len(); i < ap.Schema().Len(); i++ { - ap.OutputNames()[i] = types.EmptyName - } - ap.LogicalJoin.SetPreferredJoinTypeAndOrder(b.TableHints()) - return ap -} - -// buildSemiApply builds apply plan with outerPlan and innerPlan, which apply semi-join for every row from outerPlan and the whole innerPlan. -func (b *PlanBuilder) buildSemiApply(outerPlan, innerPlan base.LogicalPlan, condition []expression.Expression, - asScalar, not, considerRewrite, markNoDecorrelate bool) (base.LogicalPlan, error) { - b.optFlag = b.optFlag | flagPredicatePushDown | flagBuildKeyInfo | flagDecorrelate - - join, err := b.buildSemiJoin(outerPlan, innerPlan, condition, asScalar, not, considerRewrite) - if err != nil { - return nil, err - } - - setIsInApplyForCTE(innerPlan, join.Schema()) - ap := &LogicalApply{LogicalJoin: *join, NoDecorrelate: markNoDecorrelate} - ap.SetTP(plancodec.TypeApply) - ap.SetSelf(ap) - return ap, nil -} - -// setIsInApplyForCTE indicates CTE is the in inner side of Apply and correlate. -// the storage of cte needs to be reset for each outer row. -// It's better to handle this in CTEExec.Close(), but cte storage is closed when SQL is finished. -func setIsInApplyForCTE(p base.LogicalPlan, apSchema *expression.Schema) { - switch x := p.(type) { - case *LogicalCTE: - if len(coreusage.ExtractCorColumnsBySchema4LogicalPlan(p, apSchema)) > 0 { - x.Cte.IsInApply = true - } - setIsInApplyForCTE(x.Cte.seedPartLogicalPlan, apSchema) - if x.Cte.recursivePartLogicalPlan != nil { - setIsInApplyForCTE(x.Cte.recursivePartLogicalPlan, apSchema) - } - default: - for _, child := range p.Children() { - setIsInApplyForCTE(child, apSchema) - } - } -} - -func (b *PlanBuilder) buildMaxOneRow(p base.LogicalPlan) base.LogicalPlan { - // The query block of the MaxOneRow operator should be the same as that of its child. - maxOneRow := logicalop.LogicalMaxOneRow{}.Init(b.ctx, p.QueryBlockOffset()) - maxOneRow.SetChildren(p) - return maxOneRow -} - -func (b *PlanBuilder) buildSemiJoin(outerPlan, innerPlan base.LogicalPlan, onCondition []expression.Expression, asScalar, not, forceRewrite bool) (*LogicalJoin, error) { - b.optFlag |= flagConvertOuterToInnerJoin - joinPlan := LogicalJoin{}.Init(b.ctx, b.getSelectOffset()) - for i, expr := range onCondition { - onCondition[i] = expr.Decorrelate(outerPlan.Schema()) - } - joinPlan.SetChildren(outerPlan, innerPlan) - joinPlan.AttachOnConds(onCondition) - joinPlan.SetOutputNames(make([]*types.FieldName, outerPlan.Schema().Len(), outerPlan.Schema().Len()+innerPlan.Schema().Len()+1)) - copy(joinPlan.OutputNames(), outerPlan.OutputNames()) - if asScalar { - newSchema := outerPlan.Schema().Clone() - newSchema.Append(&expression.Column{ - RetType: types.NewFieldType(mysql.TypeTiny), - UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), - }) - joinPlan.SetOutputNames(append(joinPlan.OutputNames(), types.EmptyName)) - joinPlan.SetSchema(newSchema) - if not { - joinPlan.JoinType = AntiLeftOuterSemiJoin - } else { - joinPlan.JoinType = LeftOuterSemiJoin - } - } else { - joinPlan.SetSchema(outerPlan.Schema().Clone()) - if not { - joinPlan.JoinType = AntiSemiJoin - } else { - joinPlan.JoinType = SemiJoin - } - } - // Apply forces to choose hash join currently, so don't worry the hints will take effect if the semi join is in one apply. 
- joinPlan.SetPreferredJoinTypeAndOrder(b.TableHints()) - if forceRewrite { - joinPlan.PreferJoinType |= h.PreferRewriteSemiJoin - b.optFlag |= flagSemiJoinRewrite - } - return joinPlan, nil -} - -func getTableOffset(names []*types.FieldName, handleName *types.FieldName) (int, error) { - for i, name := range names { - if name.DBName.L == handleName.DBName.L && name.TblName.L == handleName.TblName.L { - return i, nil - } - } - return -1, errors.Errorf("Couldn't get column information when do update/delete") -} - -// TblColPosInfo represents an mapper from column index to handle index. -type TblColPosInfo struct { - TblID int64 - // Start and End represent the ordinal range [Start, End) of the consecutive columns. - Start, End int - // HandleOrdinal represents the ordinal of the handle column. - HandleCols util.HandleCols -} - -// MemoryUsage return the memory usage of TblColPosInfo -func (t *TblColPosInfo) MemoryUsage() (sum int64) { - if t == nil { - return - } - - sum = size.SizeOfInt64 + size.SizeOfInt*2 - if t.HandleCols != nil { - sum += t.HandleCols.MemoryUsage() - } - return -} - -// TblColPosInfoSlice attaches the methods of sort.Interface to []TblColPosInfos sorting in increasing order. -type TblColPosInfoSlice []TblColPosInfo - -// Len implements sort.Interface#Len. -func (c TblColPosInfoSlice) Len() int { - return len(c) -} - -// Swap implements sort.Interface#Swap. -func (c TblColPosInfoSlice) Swap(i, j int) { - c[i], c[j] = c[j], c[i] -} - -// Less implements sort.Interface#Less. -func (c TblColPosInfoSlice) Less(i, j int) bool { - return c[i].Start < c[j].Start -} - -// FindTblIdx finds the ordinal of the corresponding access column. -func (c TblColPosInfoSlice) FindTblIdx(colOrdinal int) (int, bool) { - if len(c) == 0 { - return 0, false - } - // find the smallest index of the range that its start great than colOrdinal. - // @see https://godoc.org/sort#Search - rangeBehindOrdinal := sort.Search(len(c), func(i int) bool { return c[i].Start > colOrdinal }) - if rangeBehindOrdinal == 0 { - return 0, false - } - return rangeBehindOrdinal - 1, true -} - -// buildColumns2Handle builds columns to handle mapping. -func buildColumns2Handle( - names []*types.FieldName, - tblID2Handle map[int64][]util.HandleCols, - tblID2Table map[int64]table.Table, - onlyWritableCol bool, -) (TblColPosInfoSlice, error) { - var cols2Handles TblColPosInfoSlice - for tblID, handleCols := range tblID2Handle { - tbl := tblID2Table[tblID] - var tblLen int - if onlyWritableCol { - tblLen = len(tbl.WritableCols()) - } else { - tblLen = len(tbl.Cols()) - } - for _, handleCol := range handleCols { - offset, err := getTableOffset(names, names[handleCol.GetCol(0).Index]) - if err != nil { - return nil, err - } - end := offset + tblLen - cols2Handles = append(cols2Handles, TblColPosInfo{tblID, offset, end, handleCol}) - } - } - sort.Sort(cols2Handles) - return cols2Handles, nil -} - -func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) (base.Plan, error) { - b.pushSelectOffset(0) - b.pushTableHints(update.TableHints, 0) - defer func() { - b.popSelectOffset() - // table hints are only visible in the current UPDATE statement. 
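// Illustrative sketch, not part of the original change: the binary search used by
// TblColPosInfoSlice.FindTblIdx above, over a simplified range type. It assumes the
// ranges are already sorted by Start, which the planner guarantees via sort.Sort.
package sketch

import "sort"

type colRange struct{ Start, End int }

// findRangeIdx returns the index of the last range whose Start is at or before
// colOrdinal, i.e. the only [Start, End) interval that could contain it.
func findRangeIdx(ranges []colRange, colOrdinal int) (int, bool) {
	if len(ranges) == 0 {
		return 0, false
	}
	// First range whose Start is strictly greater than colOrdinal.
	behind := sort.Search(len(ranges), func(i int) bool { return ranges[i].Start > colOrdinal })
	if behind == 0 {
		return 0, false
	}
	return behind - 1, true
}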
- b.popTableHints() - }() - - b.inUpdateStmt = true - b.isForUpdateRead = true - - if update.With != nil { - l := len(b.outerCTEs) - defer func() { - b.outerCTEs = b.outerCTEs[:l] - }() - _, err := b.buildWith(ctx, update.With) - if err != nil { - return nil, err - } - } - - p, err := b.buildResultSetNode(ctx, update.TableRefs.TableRefs, false) - if err != nil { - return nil, err - } - - tableList := ExtractTableList(update.TableRefs.TableRefs, false) - for _, t := range tableList { - dbName := t.Schema.L - if dbName == "" { - dbName = b.ctx.GetSessionVars().CurrentDB - } - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SelectPriv, dbName, t.Name.L, "", nil) - } - - oldSchemaLen := p.Schema().Len() - if update.Where != nil { - p, err = b.buildSelection(ctx, p, update.Where, nil) - if err != nil { - return nil, err - } - } - if b.ctx.GetSessionVars().TxnCtx.IsPessimistic { - if update.TableRefs.TableRefs.Right == nil { - // buildSelectLock is an optimization that can reduce RPC call. - // We only need do this optimization for single table update which is the most common case. - // When TableRefs.Right is nil, it is single table update. - p, err = b.buildSelectLock(p, &ast.SelectLockInfo{ - LockType: ast.SelectLockForUpdate, - }) - if err != nil { - return nil, err - } - } - } - - if update.Order != nil { - p, err = b.buildSort(ctx, p, update.Order.Items, nil, nil) - if err != nil { - return nil, err - } - } - if update.Limit != nil { - p, err = b.buildLimit(p, update.Limit) - if err != nil { - return nil, err - } - } - - // Add project to freeze the order of output columns. - proj := logicalop.LogicalProjection{Exprs: expression.Column2Exprs(p.Schema().Columns[:oldSchemaLen])}.Init(b.ctx, b.getSelectOffset()) - proj.SetSchema(expression.NewSchema(make([]*expression.Column, oldSchemaLen)...)) - proj.SetOutputNames(make(types.NameSlice, len(p.OutputNames()))) - copy(proj.OutputNames(), p.OutputNames()) - copy(proj.Schema().Columns, p.Schema().Columns[:oldSchemaLen]) - proj.SetChildren(p) - p = proj - - utlr := &updatableTableListResolver{} - update.Accept(utlr) - orderedList, np, allAssignmentsAreConstant, err := b.buildUpdateLists(ctx, utlr.updatableTableList, update.List, p) - if err != nil { - return nil, err - } - p = np - - updt := Update{ - OrderedList: orderedList, - AllAssignmentsAreConstant: allAssignmentsAreConstant, - VirtualAssignmentsOffset: len(update.List), - }.Init(b.ctx) - updt.names = p.OutputNames() - // We cannot apply projection elimination when building the subplan, because - // columns in orderedList cannot be resolved. (^flagEliminateProjection should also be applied in postOptimize) - updt.SelectPlan, _, err = DoOptimize(ctx, b.ctx, b.optFlag&^flagEliminateProjection, p) - if err != nil { - return nil, err - } - err = updt.ResolveIndices() - if err != nil { - return nil, err - } - tblID2Handle, err := resolveIndicesForTblID2Handle(b.handleHelper.tailMap(), updt.SelectPlan.Schema()) - if err != nil { - return nil, err - } - tblID2table := make(map[int64]table.Table, len(tblID2Handle)) - for id := range tblID2Handle { - tblID2table[id], _ = b.is.TableByID(id) - } - updt.TblColPosInfos, err = buildColumns2Handle(updt.OutputNames(), tblID2Handle, tblID2table, true) - if err != nil { - return nil, err - } - updt.PartitionedTable = b.partitionedTable - updt.tblID2Table = tblID2table - err = updt.buildOnUpdateFKTriggers(b.ctx, b.is, tblID2table) - return updt, err -} - -// GetUpdateColumnsInfo get the update columns info. 
-func GetUpdateColumnsInfo(tblID2Table map[int64]table.Table, tblColPosInfos TblColPosInfoSlice, size int) []*table.Column { - colsInfo := make([]*table.Column, size) - for _, content := range tblColPosInfos { - tbl := tblID2Table[content.TblID] - for i, c := range tbl.WritableCols() { - colsInfo[content.Start+i] = c - } - } - return colsInfo -} - -type tblUpdateInfo struct { - name string - pkUpdated bool - partitionColUpdated bool -} - -// CheckUpdateList checks all related columns in updatable state. -func CheckUpdateList(assignFlags []int, updt *Update, newTblID2Table map[int64]table.Table) error { - updateFromOtherAlias := make(map[int64]tblUpdateInfo) - for _, content := range updt.TblColPosInfos { - tbl := newTblID2Table[content.TblID] - flags := assignFlags[content.Start:content.End] - var update, updatePK, updatePartitionCol bool - var partitionColumnNames []model.CIStr - if pt, ok := tbl.(table.PartitionedTable); ok && pt != nil { - partitionColumnNames = pt.GetPartitionColumnNames() - } - - for i, col := range tbl.WritableCols() { - // schema may be changed between building plan and building executor - // If i >= len(flags), it means the target table has been added columns, then we directly skip the check - if i >= len(flags) { - continue - } - if flags[i] < 0 { - continue - } - - if col.State != model.StatePublic { - return plannererrors.ErrUnknownColumn.GenWithStackByArgs(col.Name, clauseMsg[fieldList]) - } - - update = true - if mysql.HasPriKeyFlag(col.GetFlag()) { - updatePK = true - } - for _, partColName := range partitionColumnNames { - if col.Name.L == partColName.L { - updatePartitionCol = true - } - } - } - if update { - // Check for multi-updates on primary key, - // see https://dev.mysql.com/doc/mysql-errors/5.7/en/server-error-reference.html#error_er_multi_update_key_conflict - if otherTable, ok := updateFromOtherAlias[tbl.Meta().ID]; ok { - if otherTable.pkUpdated || updatePK || otherTable.partitionColUpdated || updatePartitionCol { - return plannererrors.ErrMultiUpdateKeyConflict.GenWithStackByArgs(otherTable.name, updt.names[content.Start].TblName.O) - } - } else { - updateFromOtherAlias[tbl.Meta().ID] = tblUpdateInfo{ - name: updt.names[content.Start].TblName.O, - pkUpdated: updatePK, - partitionColUpdated: updatePartitionCol, - } - } - } - } - return nil -} - -// If tl is CTE, its HintedTable will be nil. -// Only used in build plan from AST after preprocess. 
-func isCTE(tl *ast.TableName) bool { - return tl.TableInfo == nil -} - -func (b *PlanBuilder) buildUpdateLists(ctx context.Context, tableList []*ast.TableName, list []*ast.Assignment, p base.LogicalPlan) (newList []*expression.Assignment, po base.LogicalPlan, allAssignmentsAreConstant bool, e error) { - b.curClause = fieldList - // modifyColumns indicates which columns are in set list, - // and if it is set to `DEFAULT` - modifyColumns := make(map[string]bool, p.Schema().Len()) - var columnsIdx map[*ast.ColumnName]int - cacheColumnsIdx := false - if len(p.OutputNames()) > 16 { - cacheColumnsIdx = true - columnsIdx = make(map[*ast.ColumnName]int, len(list)) - } - for _, assign := range list { - idx, err := expression.FindFieldName(p.OutputNames(), assign.Column) - if err != nil { - return nil, nil, false, err - } - if idx < 0 { - return nil, nil, false, plannererrors.ErrUnknownColumn.GenWithStackByArgs(assign.Column.Name, "field list") - } - if cacheColumnsIdx { - columnsIdx[assign.Column] = idx - } - name := p.OutputNames()[idx] - foundListItem := false - for _, tl := range tableList { - if (tl.Schema.L == "" || tl.Schema.L == name.DBName.L) && (tl.Name.L == name.TblName.L) { - if isCTE(tl) || tl.TableInfo.IsView() || tl.TableInfo.IsSequence() { - return nil, nil, false, plannererrors.ErrNonUpdatableTable.GenWithStackByArgs(name.TblName.O, "UPDATE") - } - foundListItem = true - } - } - if !foundListItem { - // For case like: - // 1: update (select * from t1) t1 set b = 1111111 ----- (no updatable table here) - // 2: update (select 1 as a) as t, t1 set a=1 ----- (updatable t1 don't have column a) - // --- subQuery is not counted as updatable table. - return nil, nil, false, plannererrors.ErrNonUpdatableTable.GenWithStackByArgs(name.TblName.O, "UPDATE") - } - columnFullName := fmt.Sprintf("%s.%s.%s", name.DBName.L, name.TblName.L, name.ColName.L) - // We save a flag for the column in map `modifyColumns` - // This flag indicated if assign keyword `DEFAULT` to the column - modifyColumns[columnFullName] = IsDefaultExprSameColumn(p.OutputNames()[idx:idx+1], assign.Expr) - } - - // If columns in set list contains generated columns, raise error. - // And, fill virtualAssignments here; that's for generated columns. - virtualAssignments := make([]*ast.Assignment, 0) - for _, tn := range tableList { - if isCTE(tn) || tn.TableInfo.IsView() || tn.TableInfo.IsSequence() { - continue - } - - tableInfo := tn.TableInfo - tableVal, found := b.is.TableByID(tableInfo.ID) - if !found { - return nil, nil, false, infoschema.ErrTableNotExists.FastGenByArgs(tn.DBInfo.Name.O, tableInfo.Name.O) - } - for i, colInfo := range tableVal.Cols() { - if !colInfo.IsGenerated() { - continue - } - columnFullName := fmt.Sprintf("%s.%s.%s", tn.DBInfo.Name.L, tn.Name.L, colInfo.Name.L) - isDefault, ok := modifyColumns[columnFullName] - if ok && colInfo.Hidden { - return nil, nil, false, plannererrors.ErrUnknownColumn.GenWithStackByArgs(colInfo.Name, clauseMsg[fieldList]) - } - // Note: For INSERT, REPLACE, and UPDATE, if a generated column is inserted into, replaced, or updated explicitly, the only permitted value is DEFAULT. 
- // see https://dev.mysql.com/doc/refman/8.0/en/create-table-generated-columns.html - if ok && !isDefault { - return nil, nil, false, plannererrors.ErrBadGeneratedColumn.GenWithStackByArgs(colInfo.Name.O, tableInfo.Name.O) - } - virtualAssignments = append(virtualAssignments, &ast.Assignment{ - Column: &ast.ColumnName{Schema: tn.Schema, Table: tn.Name, Name: colInfo.Name}, - Expr: tableVal.Cols()[i].GeneratedExpr.Clone(), - }) - } - } - - allAssignmentsAreConstant = true - newList = make([]*expression.Assignment, 0, p.Schema().Len()) - tblDbMap := make(map[string]string, len(tableList)) - for _, tbl := range tableList { - if isCTE(tbl) { - continue - } - tblDbMap[tbl.Name.L] = tbl.DBInfo.Name.L - } - - allAssignments := append(list, virtualAssignments...) - dependentColumnsModified := make(map[int64]bool) - for i, assign := range allAssignments { - var idx int - var err error - if cacheColumnsIdx { - if i, ok := columnsIdx[assign.Column]; ok { - idx = i - } else { - idx, err = expression.FindFieldName(p.OutputNames(), assign.Column) - } - } else { - idx, err = expression.FindFieldName(p.OutputNames(), assign.Column) - } - if err != nil { - return nil, nil, false, err - } - col := p.Schema().Columns[idx] - name := p.OutputNames()[idx] - var newExpr expression.Expression - var np base.LogicalPlan - if i < len(list) { - // If assign `DEFAULT` to column, fill the `defaultExpr.Name` before rewrite expression - if expr := extractDefaultExpr(assign.Expr); expr != nil { - expr.Name = assign.Column - } - newExpr, np, err = b.rewrite(ctx, assign.Expr, p, nil, true) - if err != nil { - return nil, nil, false, err - } - dependentColumnsModified[col.UniqueID] = true - } else { - // rewrite with generation expression - rewritePreprocess := func(assign *ast.Assignment) func(expr ast.Node) ast.Node { - return func(expr ast.Node) ast.Node { - switch x := expr.(type) { - case *ast.ColumnName: - return &ast.ColumnName{ - Schema: assign.Column.Schema, - Table: assign.Column.Table, - Name: x.Name, - } - default: - return expr - } - } - } - - o := b.allowBuildCastArray - b.allowBuildCastArray = true - newExpr, np, err = b.rewriteWithPreprocess(ctx, assign.Expr, p, nil, nil, true, rewritePreprocess(assign)) - b.allowBuildCastArray = o - if err != nil { - return nil, nil, false, err - } - // check if the column is modified - dependentColumns := expression.ExtractDependentColumns(newExpr) - var isModified bool - for _, col := range dependentColumns { - if dependentColumnsModified[col.UniqueID] { - isModified = true - break - } - } - // skip unmodified generated columns - if !isModified { - continue - } - } - if _, isConst := newExpr.(*expression.Constant); !isConst { - allAssignmentsAreConstant = false - } - p = np - if cols := expression.ExtractColumnSet(newExpr); cols.Len() > 0 { - b.ctx.GetSessionVars().StmtCtx.ColRefFromUpdatePlan.UnionWith(cols) - } - newList = append(newList, &expression.Assignment{Col: col, ColName: name.ColName, Expr: newExpr}) - dbName := name.DBName.L - // To solve issue#10028, we need to get database name by the table alias name. - if dbNameTmp, ok := tblDbMap[name.TblName.L]; ok { - dbName = dbNameTmp - } - if dbName == "" { - dbName = b.ctx.GetSessionVars().CurrentDB - } - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.UpdatePriv, dbName, name.OrigTblName.L, "", nil) - } - return newList, p, allAssignmentsAreConstant, nil -} - -// extractDefaultExpr extract a `DefaultExpr` from `ExprNode`, -// If it is a `DEFAULT` function like `DEFAULT(a)`, return nil. 
-// Only if it is `DEFAULT` keyword, it will return the `DefaultExpr`. -func extractDefaultExpr(node ast.ExprNode) *ast.DefaultExpr { - if expr, ok := node.(*ast.DefaultExpr); ok && expr.Name == nil { - return expr - } - return nil -} - -// IsDefaultExprSameColumn - DEFAULT or col = DEFAULT(col) -func IsDefaultExprSameColumn(names types.NameSlice, node ast.ExprNode) bool { - if expr, ok := node.(*ast.DefaultExpr); ok { - if expr.Name == nil { - // col = DEFAULT - return true - } - refIdx, err := expression.FindFieldName(names, expr.Name) - if refIdx == 0 && err == nil { - // col = DEFAULT(col) - return true - } - } - return false -} - -func (b *PlanBuilder) buildDelete(ctx context.Context, ds *ast.DeleteStmt) (base.Plan, error) { - b.pushSelectOffset(0) - b.pushTableHints(ds.TableHints, 0) - defer func() { - b.popSelectOffset() - // table hints are only visible in the current DELETE statement. - b.popTableHints() - }() - - b.inDeleteStmt = true - b.isForUpdateRead = true - - if ds.With != nil { - l := len(b.outerCTEs) - defer func() { - b.outerCTEs = b.outerCTEs[:l] - }() - _, err := b.buildWith(ctx, ds.With) - if err != nil { - return nil, err - } - } - - p, err := b.buildResultSetNode(ctx, ds.TableRefs.TableRefs, false) - if err != nil { - return nil, err - } - oldSchema := p.Schema() - oldLen := oldSchema.Len() - - // For explicit column usage, should use the all-public columns. - if ds.Where != nil { - p, err = b.buildSelection(ctx, p, ds.Where, nil) - if err != nil { - return nil, err - } - } - if b.ctx.GetSessionVars().TxnCtx.IsPessimistic { - if !ds.IsMultiTable { - p, err = b.buildSelectLock(p, &ast.SelectLockInfo{ - LockType: ast.SelectLockForUpdate, - }) - if err != nil { - return nil, err - } - } - } - - if ds.Order != nil { - p, err = b.buildSort(ctx, p, ds.Order.Items, nil, nil) - if err != nil { - return nil, err - } - } - - if ds.Limit != nil { - p, err = b.buildLimit(p, ds.Limit) - if err != nil { - return nil, err - } - } - - // If the delete is non-qualified it does not require Select Priv - if ds.Where == nil && ds.Order == nil { - b.popVisitInfo() - } - var authErr error - sessionVars := b.ctx.GetSessionVars() - - proj := logicalop.LogicalProjection{Exprs: expression.Column2Exprs(p.Schema().Columns[:oldLen])}.Init(b.ctx, b.getSelectOffset()) - proj.SetChildren(p) - proj.SetSchema(oldSchema.Clone()) - proj.SetOutputNames(p.OutputNames()[:oldLen]) - p = proj - - del := Delete{ - IsMultiTable: ds.IsMultiTable, - }.Init(b.ctx) - - del.names = p.OutputNames() - // Collect visitInfo. - if ds.Tables != nil { - // Delete a, b from a, b, c, d... add a and b. - updatableList := make(map[string]bool) - tbInfoList := make(map[string]*ast.TableName) - collectTableName(ds.TableRefs.TableRefs, &updatableList, &tbInfoList) - for _, tn := range ds.Tables.Tables { - var canUpdate, foundMatch = false, false - name := tn.Name.L - if tn.Schema.L == "" { - canUpdate, foundMatch = updatableList[name] - } - - if !foundMatch { - if tn.Schema.L == "" { - name = model.NewCIStr(b.ctx.GetSessionVars().CurrentDB).L + "." + tn.Name.L - } else { - name = tn.Schema.L + "." 
+ tn.Name.L - } - canUpdate, foundMatch = updatableList[name] - } - // check sql like: `delete b from (select * from t) as a, t` - if !foundMatch { - return nil, plannererrors.ErrUnknownTable.GenWithStackByArgs(tn.Name.O, "MULTI DELETE") - } - // check sql like: `delete a from (select * from t) as a, t` - if !canUpdate { - return nil, plannererrors.ErrNonUpdatableTable.GenWithStackByArgs(tn.Name.O, "DELETE") - } - tb := tbInfoList[name] - tn.DBInfo = tb.DBInfo - tn.TableInfo = tb.TableInfo - if tn.TableInfo.IsView() { - return nil, errors.Errorf("delete view %s is not supported now", tn.Name.O) - } - if tn.TableInfo.IsSequence() { - return nil, errors.Errorf("delete sequence %s is not supported now", tn.Name.O) - } - if sessionVars.User != nil { - authErr = plannererrors.ErrTableaccessDenied.FastGenByArgs("DELETE", sessionVars.User.AuthUsername, sessionVars.User.AuthHostname, tb.Name.L) - } - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.DeletePriv, tb.DBInfo.Name.L, tb.Name.L, "", authErr) - } - } else { - // Delete from a, b, c, d. - tableList := ExtractTableList(ds.TableRefs.TableRefs, false) - for _, v := range tableList { - if isCTE(v) { - return nil, plannererrors.ErrNonUpdatableTable.GenWithStackByArgs(v.Name.O, "DELETE") - } - if v.TableInfo.IsView() { - return nil, errors.Errorf("delete view %s is not supported now", v.Name.O) - } - if v.TableInfo.IsSequence() { - return nil, errors.Errorf("delete sequence %s is not supported now", v.Name.O) - } - dbName := v.Schema.L - if dbName == "" { - dbName = b.ctx.GetSessionVars().CurrentDB - } - if sessionVars.User != nil { - authErr = plannererrors.ErrTableaccessDenied.FastGenByArgs("DELETE", sessionVars.User.AuthUsername, sessionVars.User.AuthHostname, v.Name.L) - } - b.visitInfo = appendVisitInfo(b.visitInfo, mysql.DeletePriv, dbName, v.Name.L, "", authErr) - } - } - handleColsMap := b.handleHelper.tailMap() - tblID2Handle, err := resolveIndicesForTblID2Handle(handleColsMap, p.Schema()) - if err != nil { - return nil, err - } - if del.IsMultiTable { - // tblID2TableName is the table map value is an array which contains table aliases. - // Table ID may not be unique for deleting multiple tables, for statements like - // `delete from t as t1, t as t2`, the same table has two alias, we have to identify a table - // by its alias instead of ID. 
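// Illustrative sketch, not part of the original change: the two-step lookup used
// above to match multi-table DELETE targets against the collected updatable-table
// map, with simplified types. An unqualified name is tried as-is first and then
// retried qualified with the current database.
package sketch

// lookupDeleteTarget reports whether the named table was collected from the FROM
// clause (found) and whether it is a real, updatable base table (canUpdate).
func lookupDeleteTarget(updatable map[string]bool, schema, name, currentDB string) (canUpdate, found bool) {
	if schema == "" {
		if canUpdate, found = updatable[name]; found {
			return canUpdate, true
		}
		schema = currentDB
	}
	canUpdate, found = updatable[schema+"."+name]
	return canUpdate, found
}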
- tblID2TableName := make(map[int64][]*ast.TableName, len(ds.Tables.Tables)) - for _, tn := range ds.Tables.Tables { - tblID2TableName[tn.TableInfo.ID] = append(tblID2TableName[tn.TableInfo.ID], tn) - } - tblID2Handle = del.cleanTblID2HandleMap(tblID2TableName, tblID2Handle, del.names) - } - tblID2table := make(map[int64]table.Table, len(tblID2Handle)) - for id := range tblID2Handle { - tblID2table[id], _ = b.is.TableByID(id) - } - del.TblColPosInfos, err = buildColumns2Handle(del.names, tblID2Handle, tblID2table, false) - if err != nil { - return nil, err - } - - del.SelectPlan, _, err = DoOptimize(ctx, b.ctx, b.optFlag, p) - if err != nil { - return nil, err - } - - err = del.buildOnDeleteFKTriggers(b.ctx, b.is, tblID2table) - return del, err -} - -func resolveIndicesForTblID2Handle(tblID2Handle map[int64][]util.HandleCols, schema *expression.Schema) (map[int64][]util.HandleCols, error) { - newMap := make(map[int64][]util.HandleCols, len(tblID2Handle)) - for i, cols := range tblID2Handle { - for _, col := range cols { - resolvedCol, err := col.ResolveIndices(schema) - if err != nil { - return nil, err - } - newMap[i] = append(newMap[i], resolvedCol) - } - } - return newMap, nil -} - -func (p *Delete) cleanTblID2HandleMap( - tablesToDelete map[int64][]*ast.TableName, - tblID2Handle map[int64][]util.HandleCols, - outputNames []*types.FieldName, -) map[int64][]util.HandleCols { - for id, cols := range tblID2Handle { - names, ok := tablesToDelete[id] - if !ok { - delete(tblID2Handle, id) - continue - } - for i := len(cols) - 1; i >= 0; i-- { - hCols := cols[i] - var hasMatch bool - for j := 0; j < hCols.NumCols(); j++ { - if p.matchingDeletingTable(names, outputNames[hCols.GetCol(j).Index]) { - hasMatch = true - break - } - } - if !hasMatch { - cols = append(cols[:i], cols[i+1:]...) - } - } - if len(cols) == 0 { - delete(tblID2Handle, id) - continue - } - tblID2Handle[id] = cols - } - return tblID2Handle -} - -// matchingDeletingTable checks whether this column is from the table which is in the deleting list. -func (*Delete) matchingDeletingTable(names []*ast.TableName, name *types.FieldName) bool { - for _, n := range names { - if (name.DBName.L == "" || name.DBName.L == n.DBInfo.Name.L) && name.TblName.L == n.Name.L { - return true - } - } - return false -} - -func getWindowName(name string) string { - if name == "" { - return "" - } - return name -} - -// buildProjectionForWindow builds the projection for expressions in the window specification that is not an column, -// so after the projection, window functions only needs to deal with columns. 
-func (b *PlanBuilder) buildProjectionForWindow(ctx context.Context, p base.LogicalPlan, spec *ast.WindowSpec, args []ast.ExprNode, aggMap map[*ast.AggregateFuncExpr]int) (base.LogicalPlan, []property.SortItem, []property.SortItem, []expression.Expression, error) { - b.optFlag |= flagEliminateProjection - - var partitionItems, orderItems []*ast.ByItem - if spec.PartitionBy != nil { - partitionItems = spec.PartitionBy.Items - } - if spec.OrderBy != nil { - orderItems = spec.OrderBy.Items - } - - projLen := len(p.Schema().Columns) + len(partitionItems) + len(orderItems) + len(args) - proj := logicalop.LogicalProjection{Exprs: make([]expression.Expression, 0, projLen)}.Init(b.ctx, b.getSelectOffset()) - proj.SetSchema(expression.NewSchema(make([]*expression.Column, 0, projLen)...)) - proj.SetOutputNames(make([]*types.FieldName, p.Schema().Len(), projLen)) - for _, col := range p.Schema().Columns { - proj.Exprs = append(proj.Exprs, col) - proj.Schema().Append(col) - } - copy(proj.OutputNames(), p.OutputNames()) - - propertyItems := make([]property.SortItem, 0, len(partitionItems)+len(orderItems)) - var err error - p, propertyItems, err = b.buildByItemsForWindow(ctx, p, proj, partitionItems, propertyItems, aggMap) - if err != nil { - return nil, nil, nil, nil, err - } - lenPartition := len(propertyItems) - p, propertyItems, err = b.buildByItemsForWindow(ctx, p, proj, orderItems, propertyItems, aggMap) - if err != nil { - return nil, nil, nil, nil, err - } - - newArgList := make([]expression.Expression, 0, len(args)) - for _, arg := range args { - newArg, np, err := b.rewrite(ctx, arg, p, aggMap, true) - if err != nil { - return nil, nil, nil, nil, err - } - p = np - switch newArg.(type) { - case *expression.Column, *expression.Constant: - newArgList = append(newArgList, newArg.Clone()) - continue - } - proj.Exprs = append(proj.Exprs, newArg) - proj.SetOutputNames(append(proj.OutputNames(), types.EmptyName)) - col := &expression.Column{ - UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), - RetType: newArg.GetType(b.ctx.GetExprCtx().GetEvalCtx()), - } - proj.Schema().Append(col) - newArgList = append(newArgList, col) - } - - proj.SetChildren(p) - return proj, propertyItems[:lenPartition], propertyItems[lenPartition:], newArgList, nil -} - -func (b *PlanBuilder) buildArgs4WindowFunc(ctx context.Context, p base.LogicalPlan, args []ast.ExprNode, aggMap map[*ast.AggregateFuncExpr]int) ([]expression.Expression, error) { - b.optFlag |= flagEliminateProjection - - newArgList := make([]expression.Expression, 0, len(args)) - // use below index for created a new col definition - // it's okay here because we only want to return the args used in window function - newColIndex := 0 - for _, arg := range args { - newArg, np, err := b.rewrite(ctx, arg, p, aggMap, true) - if err != nil { - return nil, err - } - p = np - switch newArg.(type) { - case *expression.Column, *expression.Constant: - newArgList = append(newArgList, newArg.Clone()) - continue - } - col := &expression.Column{ - UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), - RetType: newArg.GetType(b.ctx.GetExprCtx().GetEvalCtx()), - } - newColIndex++ - newArgList = append(newArgList, col) - } - return newArgList, nil -} - -func (b *PlanBuilder) buildByItemsForWindow( - ctx context.Context, - p base.LogicalPlan, - proj *logicalop.LogicalProjection, - items []*ast.ByItem, - retItems []property.SortItem, - aggMap map[*ast.AggregateFuncExpr]int, -) (base.LogicalPlan, []property.SortItem, error) { - transformer := &itemTransformer{} - for _, item 
:= range items { - newExpr, _ := item.Expr.Accept(transformer) - item.Expr = newExpr.(ast.ExprNode) - it, np, err := b.rewrite(ctx, item.Expr, p, aggMap, true) - if err != nil { - return nil, nil, err - } - p = np - if it.GetType(b.ctx.GetExprCtx().GetEvalCtx()).GetType() == mysql.TypeNull { - continue - } - if col, ok := it.(*expression.Column); ok { - retItems = append(retItems, property.SortItem{Col: col, Desc: item.Desc}) - // We need to attempt to add this column because a subquery may be created during the expression rewrite process. - // Therefore, we need to ensure that the column from the newly created query plan is added. - // If the column is already in the schema, we don't need to add it again. - if !proj.Schema().Contains(col) { - proj.Exprs = append(proj.Exprs, col) - proj.SetOutputNames(append(proj.OutputNames(), types.EmptyName)) - proj.Schema().Append(col) - } - continue - } - proj.Exprs = append(proj.Exprs, it) - proj.SetOutputNames(append(proj.OutputNames(), types.EmptyName)) - col := &expression.Column{ - UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), - RetType: it.GetType(b.ctx.GetExprCtx().GetEvalCtx()), - } - proj.Schema().Append(col) - retItems = append(retItems, property.SortItem{Col: col, Desc: item.Desc}) - } - return p, retItems, nil -} - -// buildWindowFunctionFrameBound builds the bounds of window function frames. -// For type `Rows`, the bound expr must be an unsigned integer. -// For type `Range`, the bound expr must be temporal or numeric types. -func (b *PlanBuilder) buildWindowFunctionFrameBound(_ context.Context, spec *ast.WindowSpec, orderByItems []property.SortItem, boundClause *ast.FrameBound) (*logicalop.FrameBound, error) { - frameType := spec.Frame.Type - bound := &logicalop.FrameBound{Type: boundClause.Type, UnBounded: boundClause.UnBounded, IsExplicitRange: false} - if bound.UnBounded { - return bound, nil - } - - if frameType == ast.Rows { - if bound.Type == ast.CurrentRow { - return bound, nil - } - numRows, _, _ := getUintFromNode(b.ctx, boundClause.Expr, false) - bound.Num = numRows - return bound, nil - } - - bound.CalcFuncs = make([]expression.Expression, len(orderByItems)) - bound.CmpFuncs = make([]expression.CompareFunc, len(orderByItems)) - if bound.Type == ast.CurrentRow { - for i, item := range orderByItems { - col := item.Col - bound.CalcFuncs[i] = col - bound.CmpFuncs[i] = expression.GetCmpFunction(b.ctx.GetExprCtx(), col, col) - } - return bound, nil - } - - col := orderByItems[0].Col - // TODO: We also need to raise error for non-deterministic expressions, like rand(). - val, err := evalAstExprWithPlanCtx(b.ctx, boundClause.Expr) - if err != nil { - return nil, plannererrors.ErrWindowRangeBoundNotConstant.GenWithStackByArgs(getWindowName(spec.Name.O)) - } - expr := expression.Constant{Value: val, RetType: boundClause.Expr.GetType()} - - checker := &expression.ParamMarkerInPrepareChecker{} - boundClause.Expr.Accept(checker) - - // If it has paramMarker and is in prepare stmt. We don't need to eval it since its value is not decided yet. - if !checker.InPrepareStmt { - // Do not raise warnings for truncate. 
- exprCtx := exprctx.CtxWithHandleTruncateErrLevel(b.ctx.GetExprCtx(), errctx.LevelIgnore) - uVal, isNull, err := expr.EvalInt(exprCtx.GetEvalCtx(), chunk.Row{}) - if uVal < 0 || isNull || err != nil { - return nil, plannererrors.ErrWindowFrameIllegal.GenWithStackByArgs(getWindowName(spec.Name.O)) - } - } - - bound.IsExplicitRange = true - desc := orderByItems[0].Desc - var funcName string - if boundClause.Unit != ast.TimeUnitInvalid { - // TODO: Perhaps we don't need to transcode this back to generic string - unitVal := boundClause.Unit.String() - unit := expression.Constant{ - Value: types.NewStringDatum(unitVal), - RetType: types.NewFieldType(mysql.TypeVarchar), - } - - // When the order is asc: - // `+` for following, and `-` for the preceding - // When the order is desc, `+` becomes `-` and vice-versa. - funcName = ast.DateAdd - if (!desc && bound.Type == ast.Preceding) || (desc && bound.Type == ast.Following) { - funcName = ast.DateSub - } - - bound.CalcFuncs[0], err = expression.NewFunctionBase(b.ctx.GetExprCtx(), funcName, col.RetType, col, &expr, &unit) - if err != nil { - return nil, err - } - } else { - // When the order is asc: - // `+` for following, and `-` for the preceding - // When the order is desc, `+` becomes `-` and vice-versa. - funcName = ast.Plus - if (!desc && bound.Type == ast.Preceding) || (desc && bound.Type == ast.Following) { - funcName = ast.Minus - } - - bound.CalcFuncs[0], err = expression.NewFunctionBase(b.ctx.GetExprCtx(), funcName, col.RetType, col, &expr) - if err != nil { - return nil, err - } - } - - cmpDataType := expression.GetAccurateCmpType(b.ctx.GetExprCtx().GetEvalCtx(), col, bound.CalcFuncs[0]) - bound.UpdateCmpFuncsAndCmpDataType(cmpDataType) - return bound, nil -} - -// buildWindowFunctionFrame builds the window function frames. 
-// See https://dev.mysql.com/doc/refman/8.0/en/window-functions-frames.html -func (b *PlanBuilder) buildWindowFunctionFrame(ctx context.Context, spec *ast.WindowSpec, orderByItems []property.SortItem) (*logicalop.WindowFrame, error) { - frameClause := spec.Frame - if frameClause == nil { - return nil, nil - } - frame := &logicalop.WindowFrame{Type: frameClause.Type} - var err error - frame.Start, err = b.buildWindowFunctionFrameBound(ctx, spec, orderByItems, &frameClause.Extent.Start) - if err != nil { - return nil, err - } - frame.End, err = b.buildWindowFunctionFrameBound(ctx, spec, orderByItems, &frameClause.Extent.End) - return frame, err -} - -func (b *PlanBuilder) checkWindowFuncArgs(ctx context.Context, p base.LogicalPlan, windowFuncExprs []*ast.WindowFuncExpr, windowAggMap map[*ast.AggregateFuncExpr]int) error { - checker := &expression.ParamMarkerInPrepareChecker{} - for _, windowFuncExpr := range windowFuncExprs { - if strings.ToLower(windowFuncExpr.Name) == ast.AggFuncGroupConcat { - return plannererrors.ErrNotSupportedYet.GenWithStackByArgs("group_concat as window function") - } - args, err := b.buildArgs4WindowFunc(ctx, p, windowFuncExpr.Args, windowAggMap) - if err != nil { - return err - } - checker.InPrepareStmt = false - for _, expr := range windowFuncExpr.Args { - expr.Accept(checker) - } - desc, err := aggregation.NewWindowFuncDesc(b.ctx.GetExprCtx(), windowFuncExpr.Name, args, checker.InPrepareStmt) - if err != nil { - return err - } - if desc == nil { - return plannererrors.ErrWrongArguments.GenWithStackByArgs(strings.ToLower(windowFuncExpr.Name)) - } - } - return nil -} - -func getAllByItems(itemsBuf []*ast.ByItem, spec *ast.WindowSpec) []*ast.ByItem { - itemsBuf = itemsBuf[:0] - if spec.PartitionBy != nil { - itemsBuf = append(itemsBuf, spec.PartitionBy.Items...) - } - if spec.OrderBy != nil { - itemsBuf = append(itemsBuf, spec.OrderBy.Items...) - } - return itemsBuf -} - -func restoreByItemText(item *ast.ByItem) string { - var sb strings.Builder - ctx := format.NewRestoreCtx(0, &sb) - err := item.Expr.Restore(ctx) - if err != nil { - return "" - } - return sb.String() -} - -func compareItems(lItems []*ast.ByItem, rItems []*ast.ByItem) bool { - minLen := min(len(lItems), len(rItems)) - for i := 0; i < minLen; i++ { - res := strings.Compare(restoreByItemText(lItems[i]), restoreByItemText(rItems[i])) - if res != 0 { - return res < 0 - } - res = compareBool(lItems[i].Desc, rItems[i].Desc) - if res != 0 { - return res < 0 - } - } - return len(lItems) < len(rItems) -} - -type windowFuncs struct { - spec *ast.WindowSpec - funcs []*ast.WindowFuncExpr -} - -// sortWindowSpecs sorts the window specifications by reversed alphabetical order, then we could add less `Sort` operator -// in physical plan because the window functions with the same partition by and order by clause will be at near places. 
-func sortWindowSpecs(groupedFuncs map[*ast.WindowSpec][]*ast.WindowFuncExpr, orderedSpec []*ast.WindowSpec) []windowFuncs { - windows := make([]windowFuncs, 0, len(groupedFuncs)) - for _, spec := range orderedSpec { - windows = append(windows, windowFuncs{spec, groupedFuncs[spec]}) - } - lItemsBuf := make([]*ast.ByItem, 0, 4) - rItemsBuf := make([]*ast.ByItem, 0, 4) - sort.SliceStable(windows, func(i, j int) bool { - lItemsBuf = getAllByItems(lItemsBuf, windows[i].spec) - rItemsBuf = getAllByItems(rItemsBuf, windows[j].spec) - return !compareItems(lItemsBuf, rItemsBuf) - }) - return windows -} - -func (b *PlanBuilder) buildWindowFunctions(ctx context.Context, p base.LogicalPlan, groupedFuncs map[*ast.WindowSpec][]*ast.WindowFuncExpr, orderedSpec []*ast.WindowSpec, aggMap map[*ast.AggregateFuncExpr]int) (base.LogicalPlan, map[*ast.WindowFuncExpr]int, error) { - if b.buildingCTE { - b.outerCTEs[len(b.outerCTEs)-1].containAggOrWindow = true - } - args := make([]ast.ExprNode, 0, 4) - windowMap := make(map[*ast.WindowFuncExpr]int) - for _, window := range sortWindowSpecs(groupedFuncs, orderedSpec) { - args = args[:0] - spec, funcs := window.spec, window.funcs - for _, windowFunc := range funcs { - args = append(args, windowFunc.Args...) - } - np, partitionBy, orderBy, args, err := b.buildProjectionForWindow(ctx, p, spec, args, aggMap) - if err != nil { - return nil, nil, err - } - if len(funcs) == 0 { - // len(funcs) == 0 indicates this an unused named window spec, - // so we just check for its validity and don't have to build plan for it. - err := b.checkOriginWindowSpec(spec, orderBy) - if err != nil { - return nil, nil, err - } - continue - } - err = b.checkOriginWindowFuncs(funcs, orderBy) - if err != nil { - return nil, nil, err - } - frame, err := b.buildWindowFunctionFrame(ctx, spec, orderBy) - if err != nil { - return nil, nil, err - } - - window := logicalop.LogicalWindow{ - PartitionBy: partitionBy, - OrderBy: orderBy, - Frame: frame, - }.Init(b.ctx, b.getSelectOffset()) - window.SetOutputNames(make([]*types.FieldName, np.Schema().Len())) - copy(window.OutputNames(), np.OutputNames()) - schema := np.Schema().Clone() - descs := make([]*aggregation.WindowFuncDesc, 0, len(funcs)) - preArgs := 0 - checker := &expression.ParamMarkerInPrepareChecker{} - for _, windowFunc := range funcs { - checker.InPrepareStmt = false - for _, expr := range windowFunc.Args { - expr.Accept(checker) - } - desc, err := aggregation.NewWindowFuncDesc(b.ctx.GetExprCtx(), windowFunc.Name, args[preArgs:preArgs+len(windowFunc.Args)], checker.InPrepareStmt) - if err != nil { - return nil, nil, err - } - if desc == nil { - return nil, nil, plannererrors.ErrWrongArguments.GenWithStackByArgs(strings.ToLower(windowFunc.Name)) - } - preArgs += len(windowFunc.Args) - desc.WrapCastForAggArgs(b.ctx.GetExprCtx()) - descs = append(descs, desc) - windowMap[windowFunc] = schema.Len() - schema.Append(&expression.Column{ - UniqueID: b.ctx.GetSessionVars().AllocPlanColumnID(), - RetType: desc.RetTp, - }) - window.SetOutputNames(append(window.OutputNames(), types.EmptyName)) - } - window.WindowFuncDescs = descs - window.SetChildren(np) - window.SetSchema(schema) - p = window - } - return p, windowMap, nil -} - -// checkOriginWindowFuncs checks the validity for original window specifications for a group of functions. -// Because the grouped specification is different from them, we should especially check them before build window frame. 
-func (b *PlanBuilder) checkOriginWindowFuncs(funcs []*ast.WindowFuncExpr, orderByItems []property.SortItem) error { - for _, f := range funcs { - if f.IgnoreNull { - return plannererrors.ErrNotSupportedYet.GenWithStackByArgs("IGNORE NULLS") - } - if f.Distinct { - return plannererrors.ErrNotSupportedYet.GenWithStackByArgs("(DISTINCT ..)") - } - if f.FromLast { - return plannererrors.ErrNotSupportedYet.GenWithStackByArgs("FROM LAST") - } - spec := &f.Spec - if f.Spec.Name.L != "" { - spec = b.windowSpecs[f.Spec.Name.L] - } - if err := b.checkOriginWindowSpec(spec, orderByItems); err != nil { - return err - } - } - return nil -} - -// checkOriginWindowSpec checks the validity for given window specification. -func (b *PlanBuilder) checkOriginWindowSpec(spec *ast.WindowSpec, orderByItems []property.SortItem) error { - if spec.Frame == nil { - return nil - } - if spec.Frame.Type == ast.Groups { - return plannererrors.ErrNotSupportedYet.GenWithStackByArgs("GROUPS") - } - start, end := spec.Frame.Extent.Start, spec.Frame.Extent.End - if start.Type == ast.Following && start.UnBounded { - return plannererrors.ErrWindowFrameStartIllegal.GenWithStackByArgs(getWindowName(spec.Name.O)) - } - if end.Type == ast.Preceding && end.UnBounded { - return plannererrors.ErrWindowFrameEndIllegal.GenWithStackByArgs(getWindowName(spec.Name.O)) - } - if start.Type == ast.Following && (end.Type == ast.Preceding || end.Type == ast.CurrentRow) { - return plannererrors.ErrWindowFrameIllegal.GenWithStackByArgs(getWindowName(spec.Name.O)) - } - if (start.Type == ast.Following || start.Type == ast.CurrentRow) && end.Type == ast.Preceding { - return plannererrors.ErrWindowFrameIllegal.GenWithStackByArgs(getWindowName(spec.Name.O)) - } - - err := b.checkOriginWindowFrameBound(&start, spec, orderByItems) - if err != nil { - return err - } - err = b.checkOriginWindowFrameBound(&end, spec, orderByItems) - if err != nil { - return err - } - return nil -} - -func (b *PlanBuilder) checkOriginWindowFrameBound(bound *ast.FrameBound, spec *ast.WindowSpec, orderByItems []property.SortItem) error { - if bound.Type == ast.CurrentRow || bound.UnBounded { - return nil - } - - frameType := spec.Frame.Type - if frameType == ast.Rows { - if bound.Unit != ast.TimeUnitInvalid { - return plannererrors.ErrWindowRowsIntervalUse.GenWithStackByArgs(getWindowName(spec.Name.O)) - } - _, isNull, isExpectedType := getUintFromNode(b.ctx, bound.Expr, false) - if isNull || !isExpectedType { - return plannererrors.ErrWindowFrameIllegal.GenWithStackByArgs(getWindowName(spec.Name.O)) - } - return nil - } - - if len(orderByItems) != 1 { - return plannererrors.ErrWindowRangeFrameOrderType.GenWithStackByArgs(getWindowName(spec.Name.O)) - } - orderItemType := orderByItems[0].Col.RetType.GetType() - isNumeric, isTemporal := types.IsTypeNumeric(orderItemType), types.IsTypeTemporal(orderItemType) - if !isNumeric && !isTemporal { - return plannererrors.ErrWindowRangeFrameOrderType.GenWithStackByArgs(getWindowName(spec.Name.O)) - } - if bound.Unit != ast.TimeUnitInvalid && !isTemporal { - return plannererrors.ErrWindowRangeFrameNumericType.GenWithStackByArgs(getWindowName(spec.Name.O)) - } - if bound.Unit == ast.TimeUnitInvalid && !isNumeric { - return plannererrors.ErrWindowRangeFrameTemporalType.GenWithStackByArgs(getWindowName(spec.Name.O)) - } - return nil -} - -func extractWindowFuncs(fields []*ast.SelectField) []*ast.WindowFuncExpr { - extractor := &WindowFuncExtractor{} - for _, f := range fields { - n, _ := f.Expr.Accept(extractor) - f.Expr = 
n.(ast.ExprNode) - } - return extractor.windowFuncs -} - -func (b *PlanBuilder) handleDefaultFrame(spec *ast.WindowSpec, windowFuncName string) (*ast.WindowSpec, bool) { - needFrame := aggregation.NeedFrame(windowFuncName) - // According to MySQL, In the absence of a frame clause, the default frame depends on whether an ORDER BY clause is present: - // (1) With order by, the default frame is equivalent to "RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"; - // (2) Without order by, the default frame is includes all partition rows, equivalent to "RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", - // or "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", which is the same as an empty frame. - // https://dev.mysql.com/doc/refman/8.0/en/window-functions-frames.html - if needFrame && spec.Frame == nil && spec.OrderBy != nil { - newSpec := *spec - newSpec.Frame = &ast.FrameClause{ - Type: ast.Ranges, - Extent: ast.FrameExtent{ - Start: ast.FrameBound{Type: ast.Preceding, UnBounded: true}, - End: ast.FrameBound{Type: ast.CurrentRow}, - }, - } - return &newSpec, true - } - // "RANGE/ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" is equivalent to empty frame. - if needFrame && spec.Frame != nil && - spec.Frame.Extent.Start.UnBounded && spec.Frame.Extent.End.UnBounded { - newSpec := *spec - newSpec.Frame = nil - return &newSpec, true - } - if !needFrame { - var updated bool - newSpec := *spec - - // For functions that operate on the entire partition, the frame clause will be ignored. - if spec.Frame != nil { - specName := spec.Name.O - b.ctx.GetSessionVars().StmtCtx.AppendNote(plannererrors.ErrWindowFunctionIgnoresFrame.FastGenByArgs(windowFuncName, getWindowName(specName))) - newSpec.Frame = nil - updated = true - } - if b.ctx.GetSessionVars().EnablePipelinedWindowExec { - useDefaultFrame, defaultFrame := aggregation.UseDefaultFrame(windowFuncName) - if useDefaultFrame { - newSpec.Frame = &defaultFrame - updated = true - } - } - if updated { - return &newSpec, true - } - } - return spec, false -} - -// append ast.WindowSpec to []*ast.WindowSpec if absent -func appendIfAbsentWindowSpec(specs []*ast.WindowSpec, ns *ast.WindowSpec) []*ast.WindowSpec { - for _, spec := range specs { - if spec == ns { - return specs - } - } - return append(specs, ns) -} - -func specEqual(s1, s2 *ast.WindowSpec) (equal bool, err error) { - if (s1 == nil && s2 != nil) || (s1 != nil && s2 == nil) { - return false, nil - } - var sb1, sb2 strings.Builder - ctx1 := format.NewRestoreCtx(0, &sb1) - ctx2 := format.NewRestoreCtx(0, &sb2) - if err = s1.Restore(ctx1); err != nil { - return - } - if err = s2.Restore(ctx2); err != nil { - return - } - return sb1.String() == sb2.String(), nil -} - -// groupWindowFuncs groups the window functions according to the window specification name. -// TODO: We can group the window function by the definition of window specification. -func (b *PlanBuilder) groupWindowFuncs(windowFuncs []*ast.WindowFuncExpr) (map[*ast.WindowSpec][]*ast.WindowFuncExpr, []*ast.WindowSpec, error) { - // updatedSpecMap is used to handle the specifications that have frame clause changed. 
- updatedSpecMap := make(map[string][]*ast.WindowSpec) - groupedWindow := make(map[*ast.WindowSpec][]*ast.WindowFuncExpr) - orderedSpec := make([]*ast.WindowSpec, 0, len(windowFuncs)) - for _, windowFunc := range windowFuncs { - if windowFunc.Spec.Name.L == "" { - spec := &windowFunc.Spec - if spec.Ref.L != "" { - ref, ok := b.windowSpecs[spec.Ref.L] - if !ok { - return nil, nil, plannererrors.ErrWindowNoSuchWindow.GenWithStackByArgs(getWindowName(spec.Ref.O)) - } - err := mergeWindowSpec(spec, ref) - if err != nil { - return nil, nil, err - } - } - spec, _ = b.handleDefaultFrame(spec, windowFunc.Name) - groupedWindow[spec] = append(groupedWindow[spec], windowFunc) - orderedSpec = appendIfAbsentWindowSpec(orderedSpec, spec) - continue - } - - name := windowFunc.Spec.Name.L - spec, ok := b.windowSpecs[name] - if !ok { - return nil, nil, plannererrors.ErrWindowNoSuchWindow.GenWithStackByArgs(windowFunc.Spec.Name.O) - } - newSpec, updated := b.handleDefaultFrame(spec, windowFunc.Name) - if !updated { - groupedWindow[spec] = append(groupedWindow[spec], windowFunc) - orderedSpec = appendIfAbsentWindowSpec(orderedSpec, spec) - } else { - var updatedSpec *ast.WindowSpec - if _, ok := updatedSpecMap[name]; !ok { - updatedSpecMap[name] = []*ast.WindowSpec{newSpec} - updatedSpec = newSpec - } else { - for _, spec := range updatedSpecMap[name] { - eq, err := specEqual(spec, newSpec) - if err != nil { - return nil, nil, err - } - if eq { - updatedSpec = spec - break - } - } - if updatedSpec == nil { - updatedSpec = newSpec - updatedSpecMap[name] = append(updatedSpecMap[name], newSpec) - } - } - groupedWindow[updatedSpec] = append(groupedWindow[updatedSpec], windowFunc) - orderedSpec = appendIfAbsentWindowSpec(orderedSpec, updatedSpec) - } - } - // Unused window specs should also be checked in b.buildWindowFunctions, - // so we add them to `groupedWindow` with empty window functions. - for _, spec := range b.windowSpecs { - if _, ok := groupedWindow[spec]; !ok { - if _, ok = updatedSpecMap[spec.Name.L]; !ok { - groupedWindow[spec] = nil - orderedSpec = appendIfAbsentWindowSpec(orderedSpec, spec) - } - } - } - return groupedWindow, orderedSpec, nil -} - -// resolveWindowSpec resolve window specifications for sql like `select ... from t window w1 as (w2), w2 as (partition by a)`. -// We need to resolve the referenced window to get the definition of current window spec. 
-func resolveWindowSpec(spec *ast.WindowSpec, specs map[string]*ast.WindowSpec, inStack map[string]bool) error { - if inStack[spec.Name.L] { - return errors.Trace(plannererrors.ErrWindowCircularityInWindowGraph) - } - if spec.Ref.L == "" { - return nil - } - ref, ok := specs[spec.Ref.L] - if !ok { - return plannererrors.ErrWindowNoSuchWindow.GenWithStackByArgs(spec.Ref.O) - } - inStack[spec.Name.L] = true - err := resolveWindowSpec(ref, specs, inStack) - if err != nil { - return err - } - inStack[spec.Name.L] = false - return mergeWindowSpec(spec, ref) -} - -func mergeWindowSpec(spec, ref *ast.WindowSpec) error { - if ref.Frame != nil { - return plannererrors.ErrWindowNoInherentFrame.GenWithStackByArgs(ref.Name.O) - } - if spec.PartitionBy != nil { - return errors.Trace(plannererrors.ErrWindowNoChildPartitioning) - } - if ref.OrderBy != nil { - if spec.OrderBy != nil { - return plannererrors.ErrWindowNoRedefineOrderBy.GenWithStackByArgs(getWindowName(spec.Name.O), ref.Name.O) - } - spec.OrderBy = ref.OrderBy - } - spec.PartitionBy = ref.PartitionBy - spec.Ref = model.NewCIStr("") - return nil -} - -func buildWindowSpecs(specs []ast.WindowSpec) (map[string]*ast.WindowSpec, error) { - specsMap := make(map[string]*ast.WindowSpec, len(specs)) - for _, spec := range specs { - if _, ok := specsMap[spec.Name.L]; ok { - return nil, plannererrors.ErrWindowDuplicateName.GenWithStackByArgs(spec.Name.O) - } - newSpec := spec - specsMap[spec.Name.L] = &newSpec - } - inStack := make(map[string]bool, len(specs)) - for _, spec := range specsMap { - err := resolveWindowSpec(spec, specsMap, inStack) - if err != nil { - return nil, err - } - } - return specsMap, nil -} - -type updatableTableListResolver struct { - updatableTableList []*ast.TableName -} - -func (*updatableTableListResolver) Enter(inNode ast.Node) (ast.Node, bool) { - switch v := inNode.(type) { - case *ast.UpdateStmt, *ast.TableRefsClause, *ast.Join, *ast.TableSource, *ast.TableName: - return v, false - } - return inNode, true -} - -func (u *updatableTableListResolver) Leave(inNode ast.Node) (ast.Node, bool) { - if v, ok := inNode.(*ast.TableSource); ok { - if s, ok := v.Source.(*ast.TableName); ok { - if v.AsName.L != "" { - newTableName := *s - newTableName.Name = v.AsName - newTableName.Schema = model.NewCIStr("") - u.updatableTableList = append(u.updatableTableList, &newTableName) - } else { - u.updatableTableList = append(u.updatableTableList, s) - } - } - } - return inNode, true -} - -// ExtractTableList is a wrapper for tableListExtractor and removes duplicate TableName -// If asName is true, extract AsName prior to OrigName. -// Privilege check should use OrigName, while expression may use AsName. -func ExtractTableList(node ast.Node, asName bool) []*ast.TableName { - if node == nil { - return []*ast.TableName{} - } - e := &tableListExtractor{ - asName: asName, - tableNames: []*ast.TableName{}, - } - node.Accept(e) - tableNames := e.tableNames - m := make(map[string]map[string]*ast.TableName) // k1: schemaName, k2: tableName, v: ast.TableName - for _, x := range tableNames { - k1, k2 := x.Schema.L, x.Name.L - // allow empty schema name OR empty table name - if k1 != "" || k2 != "" { - if _, ok := m[k1]; !ok { - m[k1] = make(map[string]*ast.TableName) - } - m[k1][k2] = x - } - } - tableNames = tableNames[:0] - for _, x := range m { - for _, v := range x { - tableNames = append(tableNames, v) - } - } - return tableNames -} - -// tableListExtractor extracts all the TableNames from node. 
-type tableListExtractor struct { - asName bool - tableNames []*ast.TableName -} - -func (e *tableListExtractor) Enter(n ast.Node) (_ ast.Node, skipChildren bool) { - innerExtract := func(inner ast.Node) []*ast.TableName { - if inner == nil { - return nil - } - innerExtractor := &tableListExtractor{ - asName: e.asName, - tableNames: []*ast.TableName{}, - } - inner.Accept(innerExtractor) - return innerExtractor.tableNames - } - - switch x := n.(type) { - case *ast.TableName: - e.tableNames = append(e.tableNames, x) - case *ast.TableSource: - if s, ok := x.Source.(*ast.TableName); ok { - if x.AsName.L != "" && e.asName { - newTableName := *s - newTableName.Name = x.AsName - newTableName.Schema = model.NewCIStr("") - e.tableNames = append(e.tableNames, &newTableName) - } else { - e.tableNames = append(e.tableNames, s) - } - } else if s, ok := x.Source.(*ast.SelectStmt); ok { - if s.From != nil { - innerList := innerExtract(s.From.TableRefs) - if len(innerList) > 0 { - innerTableName := innerList[0] - if x.AsName.L != "" && e.asName { - newTableName := *innerList[0] - newTableName.Name = x.AsName - newTableName.Schema = model.NewCIStr("") - innerTableName = &newTableName - } - e.tableNames = append(e.tableNames, innerTableName) - } - } - } - return n, true - - case *ast.ShowStmt: - if x.DBName != "" { - e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(x.DBName)}) - } - case *ast.CreateDatabaseStmt: - e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.Name}) - case *ast.AlterDatabaseStmt: - e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.Name}) - case *ast.DropDatabaseStmt: - e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.Name}) - - case *ast.FlashBackDatabaseStmt: - e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.DBName}) - e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(x.NewName)}) - case *ast.FlashBackToTimestampStmt: - if x.DBName.L != "" { - e.tableNames = append(e.tableNames, &ast.TableName{Schema: x.DBName}) - } - case *ast.FlashBackTableStmt: - if newName := x.NewName; newName != "" { - e.tableNames = append(e.tableNames, &ast.TableName{ - Schema: x.Table.Schema, - Name: model.NewCIStr(newName)}) - } - - case *ast.GrantStmt: - if x.ObjectType == ast.ObjectTypeTable || x.ObjectType == ast.ObjectTypeNone { - if x.Level.Level == ast.GrantLevelDB || x.Level.Level == ast.GrantLevelTable { - e.tableNames = append(e.tableNames, &ast.TableName{ - Schema: model.NewCIStr(x.Level.DBName), - Name: model.NewCIStr(x.Level.TableName), - }) - } - } - case *ast.RevokeStmt: - if x.ObjectType == ast.ObjectTypeTable || x.ObjectType == ast.ObjectTypeNone { - if x.Level.Level == ast.GrantLevelDB || x.Level.Level == ast.GrantLevelTable { - e.tableNames = append(e.tableNames, &ast.TableName{ - Schema: model.NewCIStr(x.Level.DBName), - Name: model.NewCIStr(x.Level.TableName), - }) - } - } - case *ast.BRIEStmt: - if x.Kind == ast.BRIEKindBackup || x.Kind == ast.BRIEKindRestore { - for _, v := range x.Schemas { - e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(v)}) - } - } - case *ast.UseStmt: - e.tableNames = append(e.tableNames, &ast.TableName{Schema: model.NewCIStr(x.DBName)}) - case *ast.ExecuteStmt: - if v, ok := x.PrepStmt.(*PlanCacheStmt); ok { - e.tableNames = append(e.tableNames, innerExtract(v.PreparedAst.Stmt)...) 
- } - } - return n, false -} - -func (*tableListExtractor) Leave(n ast.Node) (ast.Node, bool) { - return n, true -} - -func collectTableName(node ast.ResultSetNode, updatableName *map[string]bool, info *map[string]*ast.TableName) { - switch x := node.(type) { - case *ast.Join: - collectTableName(x.Left, updatableName, info) - collectTableName(x.Right, updatableName, info) - case *ast.TableSource: - name := x.AsName.L - var canUpdate bool - var s *ast.TableName - if s, canUpdate = x.Source.(*ast.TableName); canUpdate { - if name == "" { - name = s.Schema.L + "." + s.Name.L - // it may be a CTE - if s.Schema.L == "" { - name = s.Name.L - } - } - (*info)[name] = s - } - (*updatableName)[name] = canUpdate && s.Schema.L != "" - } -} - -func appendDynamicVisitInfo(vi []visitInfo, privs []string, withGrant bool, err error) []visitInfo { - return append(vi, visitInfo{ - privilege: mysql.ExtendedPriv, - dynamicPrivs: privs, - dynamicWithGrant: withGrant, - err: err, - }) -} - -func appendVisitInfo(vi []visitInfo, priv mysql.PrivilegeType, db, tbl, col string, err error) []visitInfo { - return append(vi, visitInfo{ - privilege: priv, - db: db, - table: tbl, - column: col, - err: err, - }) -} - -func getInnerFromParenthesesAndUnaryPlus(expr ast.ExprNode) ast.ExprNode { - if pexpr, ok := expr.(*ast.ParenthesesExpr); ok { - return getInnerFromParenthesesAndUnaryPlus(pexpr.Expr) - } - if uexpr, ok := expr.(*ast.UnaryOperationExpr); ok && uexpr.Op == opcode.Plus { - return getInnerFromParenthesesAndUnaryPlus(uexpr.V) - } - return expr -} - -// containDifferentJoinTypes checks whether `PreferJoinType` contains different -// join types. -func containDifferentJoinTypes(preferJoinType uint) bool { - preferJoinType &= ^h.PreferNoHashJoin - preferJoinType &= ^h.PreferNoMergeJoin - preferJoinType &= ^h.PreferNoIndexJoin - preferJoinType &= ^h.PreferNoIndexHashJoin - preferJoinType &= ^h.PreferNoIndexMergeJoin - - inlMask := h.PreferRightAsINLJInner ^ h.PreferLeftAsINLJInner - inlhjMask := h.PreferRightAsINLHJInner ^ h.PreferLeftAsINLHJInner - inlmjMask := h.PreferRightAsINLMJInner ^ h.PreferLeftAsINLMJInner - hjRightBuildMask := h.PreferRightAsHJBuild ^ h.PreferLeftAsHJProbe - hjLeftBuildMask := h.PreferLeftAsHJBuild ^ h.PreferRightAsHJProbe - - mppMask := h.PreferShuffleJoin ^ h.PreferBCJoin - mask := inlMask ^ inlhjMask ^ inlmjMask ^ hjRightBuildMask ^ hjLeftBuildMask - onesCount := bits.OnesCount(preferJoinType & ^mask & ^mppMask) - if onesCount > 1 || onesCount == 1 && preferJoinType&mask > 0 { - return true - } - - cnt := 0 - if preferJoinType&inlMask > 0 { - cnt++ - } - if preferJoinType&inlhjMask > 0 { - cnt++ - } - if preferJoinType&inlmjMask > 0 { - cnt++ - } - if preferJoinType&hjLeftBuildMask > 0 { - cnt++ - } - if preferJoinType&hjRightBuildMask > 0 { - cnt++ - } - return cnt > 1 -} - -func hasMPPJoinHints(preferJoinType uint) bool { - return (preferJoinType&h.PreferBCJoin > 0) || (preferJoinType&h.PreferShuffleJoin > 0) -} - -// isJoinHintSupportedInMPPMode is used to check if the specified join hint is available under MPP mode. -func isJoinHintSupportedInMPPMode(preferJoinType uint) bool { - if preferJoinType == 0 { - return true - } - mppMask := h.PreferShuffleJoin ^ h.PreferBCJoin - // Currently, TiFlash only supports HASH JOIN, so the hint for HASH JOIN is available while other join method hints are forbidden. 
- joinMethodHintSupportedByTiflash := h.PreferHashJoin ^ h.PreferLeftAsHJBuild ^ h.PreferRightAsHJBuild ^ h.PreferLeftAsHJProbe ^ h.PreferRightAsHJProbe - onesCount := bits.OnesCount(preferJoinType & ^joinMethodHintSupportedByTiflash & ^mppMask) - return onesCount < 1 -} - -func (b *PlanBuilder) buildCte(ctx context.Context, cte *ast.CommonTableExpression, isRecursive bool) (p base.LogicalPlan, err error) { - saveBuildingCTE := b.buildingCTE - b.buildingCTE = true - defer func() { - b.buildingCTE = saveBuildingCTE - }() - - if isRecursive { - // buildingRecursivePartForCTE likes a stack. We save it before building a recursive CTE and restore it after building. - // We need a stack because we need to handle the nested recursive CTE. And buildingRecursivePartForCTE indicates the innermost CTE. - saveCheck := b.buildingRecursivePartForCTE - b.buildingRecursivePartForCTE = false - err = b.buildRecursiveCTE(ctx, cte.Query.Query) - if err != nil { - return nil, err - } - b.buildingRecursivePartForCTE = saveCheck - } else { - p, err = b.buildResultSetNode(ctx, cte.Query.Query, true) - if err != nil { - return nil, err - } - - p, err = b.adjustCTEPlanOutputName(p, cte) - if err != nil { - return nil, err - } - - cInfo := b.outerCTEs[len(b.outerCTEs)-1] - cInfo.seedLP = p - } - return nil, nil -} - -// buildRecursiveCTE handles the with clause `with recursive xxx as xx`. -func (b *PlanBuilder) buildRecursiveCTE(ctx context.Context, cte ast.ResultSetNode) error { - b.isCTE = true - cInfo := b.outerCTEs[len(b.outerCTEs)-1] - switch x := (cte).(type) { - case *ast.SetOprStmt: - // 1. Handle the WITH clause if exists. - if x.With != nil { - l := len(b.outerCTEs) - sw := x.With - defer func() { - b.outerCTEs = b.outerCTEs[:l] - x.With = sw - }() - _, err := b.buildWith(ctx, x.With) - if err != nil { - return err - } - } - // Set it to nil, so that when builds the seed part, it won't build again. Reset it in defer so that the AST doesn't change after this function. - x.With = nil - - // 2. Build plans for each part of SetOprStmt. - recursive := make([]base.LogicalPlan, 0) - tmpAfterSetOptsForRecur := []*ast.SetOprType{nil} - - expectSeed := true - for i := 0; i < len(x.SelectList.Selects); i++ { - var p base.LogicalPlan - var err error - - var afterOpr *ast.SetOprType - switch y := x.SelectList.Selects[i].(type) { - case *ast.SelectStmt: - p, err = b.buildSelect(ctx, y) - afterOpr = y.AfterSetOperator - case *ast.SetOprSelectList: - p, err = b.buildSetOpr(ctx, &ast.SetOprStmt{SelectList: y, With: y.With}) - afterOpr = y.AfterSetOperator - } - - if expectSeed { - if cInfo.useRecursive { - // 3. If it fail to build a plan, it may be the recursive part. Then we build the seed part plan, and rebuild it. - if i == 0 { - return plannererrors.ErrCTERecursiveRequiresNonRecursiveFirst.GenWithStackByArgs(cInfo.def.Name.String()) - } - - // It's the recursive part. Build the seed part, and build this recursive part again. - // Before we build the seed part, do some checks. - if x.OrderBy != nil { - return plannererrors.ErrNotSupportedYet.GenWithStackByArgs("ORDER BY over UNION in recursive Common Table Expression") - } - // Limit clause is for the whole CTE instead of only for the seed part. - oriLimit := x.Limit - x.Limit = nil - - // Check union type. 
- if afterOpr != nil { - if *afterOpr != ast.Union && *afterOpr != ast.UnionAll { - return plannererrors.ErrNotSupportedYet.GenWithStackByArgs(fmt.Sprintf("%s between seed part and recursive part, hint: The operator between seed part and recursive part must bu UNION[DISTINCT] or UNION ALL", afterOpr.String())) - } - cInfo.isDistinct = *afterOpr == ast.Union - } - - expectSeed = false - cInfo.useRecursive = false - - // Build seed part plan. - saveSelect := x.SelectList.Selects - x.SelectList.Selects = x.SelectList.Selects[:i] - // We're rebuilding the seed part, so we pop the result we built previously. - for _i := 0; _i < i; _i++ { - b.handleHelper.popMap() - } - p, err = b.buildSetOpr(ctx, x) - if err != nil { - return err - } - x.SelectList.Selects = saveSelect - p, err = b.adjustCTEPlanOutputName(p, cInfo.def) - if err != nil { - return err - } - cInfo.seedLP = p - - // Rebuild the plan. - i-- - b.buildingRecursivePartForCTE = true - x.Limit = oriLimit - continue - } - if err != nil { - return err - } - } else { - if err != nil { - return err - } - if afterOpr != nil { - if *afterOpr != ast.Union && *afterOpr != ast.UnionAll { - return plannererrors.ErrNotSupportedYet.GenWithStackByArgs(fmt.Sprintf("%s between recursive part's selects, hint: The operator between recursive part's selects must bu UNION[DISTINCT] or UNION ALL", afterOpr.String())) - } - } - if !cInfo.useRecursive { - return plannererrors.ErrCTERecursiveRequiresNonRecursiveFirst.GenWithStackByArgs(cInfo.def.Name.String()) - } - cInfo.useRecursive = false - recursive = append(recursive, p) - tmpAfterSetOptsForRecur = append(tmpAfterSetOptsForRecur, afterOpr) - } - } - - if len(recursive) == 0 { - // In this case, even if SQL specifies "WITH RECURSIVE", the CTE is non-recursive. - p, err := b.buildSetOpr(ctx, x) - if err != nil { - return err - } - p, err = b.adjustCTEPlanOutputName(p, cInfo.def) - if err != nil { - return err - } - cInfo.seedLP = p - return nil - } - - // Build the recursive part's logical plan. - recurPart, err := b.buildUnion(ctx, recursive, tmpAfterSetOptsForRecur) - if err != nil { - return err - } - recurPart, err = b.buildProjection4CTEUnion(ctx, cInfo.seedLP, recurPart) - if err != nil { - return err - } - // 4. Finally, we get the seed part plan and recursive part plan. - cInfo.recurLP = recurPart - // Only need to handle limit if x is SetOprStmt. - if x.Limit != nil { - limit, err := b.buildLimit(cInfo.seedLP, x.Limit) - if err != nil { - return err - } - limit.SetChildren(limit.Children()[:0]...) - cInfo.limitLP = limit - } - return nil - default: - p, err := b.buildResultSetNode(ctx, x, true) - if err != nil { - // Refine the error message. 
- if errors.ErrorEqual(err, plannererrors.ErrCTERecursiveRequiresNonRecursiveFirst) { - err = plannererrors.ErrCTERecursiveRequiresUnion.GenWithStackByArgs(cInfo.def.Name.String()) - } - return err - } - p, err = b.adjustCTEPlanOutputName(p, cInfo.def) - if err != nil { - return err - } - cInfo.seedLP = p - return nil - } -} - -func (b *PlanBuilder) adjustCTEPlanOutputName(p base.LogicalPlan, def *ast.CommonTableExpression) (base.LogicalPlan, error) { - outPutNames := p.OutputNames() - for _, name := range outPutNames { - name.TblName = def.Name - name.DBName = model.NewCIStr(b.ctx.GetSessionVars().CurrentDB) - } - if len(def.ColNameList) > 0 { - if len(def.ColNameList) != len(p.OutputNames()) { - return nil, dbterror.ErrViewWrongList - } - for i, n := range def.ColNameList { - outPutNames[i].ColName = n - } - } - p.SetOutputNames(outPutNames) - return p, nil -} - -// prepareCTECheckForSubQuery prepares the check that the recursive CTE can't be referenced in subQuery. It's used before building a subQuery. -// For example: with recursive cte(n) as (select 1 union select * from (select * from cte) c1) select * from cte; -func (b *PlanBuilder) prepareCTECheckForSubQuery() []*cteInfo { - modifiedCTE := make([]*cteInfo, 0) - for _, cte := range b.outerCTEs { - if cte.isBuilding && !cte.enterSubquery { - cte.enterSubquery = true - modifiedCTE = append(modifiedCTE, cte) - } - } - return modifiedCTE -} - -// resetCTECheckForSubQuery resets the related variable. It's used after leaving a subQuery. -func resetCTECheckForSubQuery(ci []*cteInfo) { - for _, cte := range ci { - cte.enterSubquery = false - } -} - -// genCTETableNameForError find the nearest CTE name. -func (b *PlanBuilder) genCTETableNameForError() string { - name := "" - for i := len(b.outerCTEs) - 1; i >= 0; i-- { - if b.outerCTEs[i].isBuilding { - name = b.outerCTEs[i].def.Name.String() - break - } - } - return name -} - -func (b *PlanBuilder) buildWith(ctx context.Context, w *ast.WithClause) ([]*cteInfo, error) { - // Check CTE name must be unique. - nameMap := make(map[string]struct{}) - for _, cte := range w.CTEs { - if _, ok := nameMap[cte.Name.L]; ok { - return nil, plannererrors.ErrNonUniqTable - } - nameMap[cte.Name.L] = struct{}{} - } - ctes := make([]*cteInfo, 0, len(w.CTEs)) - for _, cte := range w.CTEs { - b.outerCTEs = append(b.outerCTEs, &cteInfo{def: cte, nonRecursive: !w.IsRecursive, isBuilding: true, storageID: b.allocIDForCTEStorage, seedStat: &property.StatsInfo{}, consumerCount: cte.ConsumerCount}) - b.allocIDForCTEStorage++ - saveFlag := b.optFlag - // Init the flag to flagPrunColumns, otherwise it's missing. - b.optFlag = flagPrunColumns - if b.ctx.GetSessionVars().EnableForceInlineCTE() { - b.outerCTEs[len(b.outerCTEs)-1].forceInlineByHintOrVar = true - } - _, err := b.buildCte(ctx, cte, w.IsRecursive) - if err != nil { - return nil, err - } - b.outerCTEs[len(b.outerCTEs)-1].optFlag = b.optFlag - b.outerCTEs[len(b.outerCTEs)-1].isBuilding = false - b.optFlag = saveFlag - // each cte (select statement) will generate a handle map, pop it out here. 
- b.handleHelper.popMap() - ctes = append(ctes, b.outerCTEs[len(b.outerCTEs)-1]) - } - return ctes, nil -} - -func (b *PlanBuilder) buildProjection4CTEUnion(_ context.Context, seed base.LogicalPlan, recur base.LogicalPlan) (base.LogicalPlan, error) { - if seed.Schema().Len() != recur.Schema().Len() { - return nil, plannererrors.ErrWrongNumberOfColumnsInSelect.GenWithStackByArgs() - } - exprs := make([]expression.Expression, len(seed.Schema().Columns)) - resSchema := getResultCTESchema(seed.Schema(), b.ctx.GetSessionVars()) - for i, col := range recur.Schema().Columns { - if !resSchema.Columns[i].RetType.Equal(col.RetType) { - exprs[i] = expression.BuildCastFunction4Union(b.ctx.GetExprCtx(), col, resSchema.Columns[i].RetType) - } else { - exprs[i] = col - } - } - b.optFlag |= flagEliminateProjection - proj := logicalop.LogicalProjection{Exprs: exprs, AvoidColumnEvaluator: true}.Init(b.ctx, b.getSelectOffset()) - proj.SetSchema(resSchema) - proj.SetChildren(recur) - return proj, nil -} - -// The recursive part/CTE's schema is nullable, and the UID should be unique. -func getResultCTESchema(seedSchema *expression.Schema, svar *variable.SessionVars) *expression.Schema { - res := seedSchema.Clone() - for _, col := range res.Columns { - col.RetType = col.RetType.Clone() - col.UniqueID = svar.AllocPlanColumnID() - col.RetType.DelFlag(mysql.NotNullFlag) - // Since you have reallocated unique id here, the old-cloned-cached hash code is not valid anymore. - col.CleanHashCode() - } - return res -} diff --git a/pkg/planner/core/optimizer.go b/pkg/planner/core/optimizer.go index 1ae8bdc33dbfc..ebdd4706dcd18 100644 --- a/pkg/planner/core/optimizer.go +++ b/pkg/planner/core/optimizer.go @@ -652,14 +652,14 @@ func (h *fineGrainedShuffleHelper) updateTarget(t shuffleTarget, p *basePhysical // calculateTiFlashStreamCountUsingMinLogicalCores uses minimal logical cpu cores among tiflash servers, and divide by 2 // return false, 0 if any err happens func calculateTiFlashStreamCountUsingMinLogicalCores(ctx context.Context, sctx base.PlanContext, serversInfo []infoschema.ServerInfo) (bool, uint64) { - if val, _err_ := failpoint.Eval(_curpkg_("mockTiFlashStreamCountUsingMinLogicalCores")); _err_ == nil { + failpoint.Inject("mockTiFlashStreamCountUsingMinLogicalCores", func(val failpoint.Value) { intVal, err := strconv.Atoi(val.(string)) if err == nil { - return true, uint64(intVal) + failpoint.Return(true, uint64(intVal)) } else { - return false, 0 + failpoint.Return(false, 0) } - } + }) rows, err := infoschema.FetchClusterServerInfoWithoutPrivilegeCheck(ctx, sctx.GetSessionVars(), serversInfo, diagnosticspb.ServerInfoType_HardwareInfo, false) if err != nil { return false, 0 diff --git a/pkg/planner/core/optimizer.go__failpoint_stash__ b/pkg/planner/core/optimizer.go__failpoint_stash__ deleted file mode 100644 index ebdd4706dcd18..0000000000000 --- a/pkg/planner/core/optimizer.go__failpoint_stash__ +++ /dev/null @@ -1,1222 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package core - -import ( - "cmp" - "context" - "fmt" - "math" - "runtime" - "slices" - "strconv" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/diagnosticspb" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/expression/aggregation" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/lock" - tablelock "github.com/pingcap/tidb/pkg/lock/context" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/auth" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/planner/core/rule" - "github.com/pingcap/tidb/pkg/planner/property" - "github.com/pingcap/tidb/pkg/planner/util/debugtrace" - "github.com/pingcap/tidb/pkg/planner/util/optimizetrace" - "github.com/pingcap/tidb/pkg/privilege" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" - utilhint "github.com/pingcap/tidb/pkg/util/hint" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/set" - "github.com/pingcap/tidb/pkg/util/tracing" - "github.com/pingcap/tipb/go-tipb" - "go.uber.org/atomic" - "go.uber.org/zap" -) - -// OptimizeAstNode optimizes the query to a physical plan directly. -var OptimizeAstNode func(ctx context.Context, sctx sessionctx.Context, node ast.Node, is infoschema.InfoSchema) (base.Plan, types.NameSlice, error) - -// AllowCartesianProduct means whether tidb allows cartesian join without equal conditions. -var AllowCartesianProduct = atomic.NewBool(true) - -// IsReadOnly check whether the ast.Node is a read only statement. -var IsReadOnly func(node ast.Node, vars *variable.SessionVars) bool - -// Note: The order of flags is same as the order of optRule in the list. -// Do not mess up the order. 
-const ( - flagGcSubstitute uint64 = 1 << iota - flagPrunColumns - flagStabilizeResults - flagBuildKeyInfo - flagDecorrelate - flagSemiJoinRewrite - flagEliminateAgg - flagSkewDistinctAgg - flagEliminateProjection - flagMaxMinEliminate - flagConstantPropagation - flagConvertOuterToInnerJoin - flagPredicatePushDown - flagEliminateOuterJoin - flagPartitionProcessor - flagCollectPredicateColumnsPoint - flagPushDownAgg - flagDeriveTopNFromWindow - flagPredicateSimplification - flagPushDownTopN - flagSyncWaitStatsLoadPoint - flagJoinReOrder - flagPrunColumnsAgain - flagPushDownSequence - flagResolveExpand -) - -var optRuleList = []base.LogicalOptRule{ - &GcSubstituter{}, - &ColumnPruner{}, - &ResultReorder{}, - &rule.BuildKeySolver{}, - &DecorrelateSolver{}, - &SemiJoinRewriter{}, - &AggregationEliminator{}, - &SkewDistinctAggRewriter{}, - &ProjectionEliminator{}, - &MaxMinEliminator{}, - &ConstantPropagationSolver{}, - &ConvertOuterToInnerJoin{}, - &PPDSolver{}, - &OuterJoinEliminator{}, - &PartitionProcessor{}, - &CollectPredicateColumnsPoint{}, - &AggregationPushDownSolver{}, - &DeriveTopNFromWindow{}, - &PredicateSimplification{}, - &PushDownTopNOptimizer{}, - &SyncWaitStatsLoadPoint{}, - &JoinReOrderSolver{}, - &ColumnPruner{}, // column pruning again at last, note it will mess up the results of buildKeySolver - &PushDownSequenceSolver{}, - &ResolveExpand{}, -} - -// Interaction Rule List -/* The interaction rule will be trigger when it satisfies following conditions: -1. The related rule has been trigger and changed the plan -2. The interaction rule is enabled -*/ -var optInteractionRuleList = map[base.LogicalOptRule]base.LogicalOptRule{} - -// BuildLogicalPlanForTest builds a logical plan for testing purpose from ast.Node. -func BuildLogicalPlanForTest(ctx context.Context, sctx sessionctx.Context, node ast.Node, infoSchema infoschema.InfoSchema) (base.Plan, error) { - sctx.GetSessionVars().PlanID.Store(0) - sctx.GetSessionVars().PlanColumnID.Store(0) - builder, _ := NewPlanBuilder().Init(sctx.GetPlanCtx(), infoSchema, utilhint.NewQBHintHandler(nil)) - p, err := builder.Build(ctx, node) - if err != nil { - return nil, err - } - if logic, ok := p.(base.LogicalPlan); ok { - RecheckCTE(logic) - } - return p, err -} - -// CheckPrivilege checks the privilege for a user. -func CheckPrivilege(activeRoles []*auth.RoleIdentity, pm privilege.Manager, vs []visitInfo) error { - for _, v := range vs { - if v.privilege == mysql.ExtendedPriv { - hasPriv := false - for _, priv := range v.dynamicPrivs { - hasPriv = hasPriv || pm.RequestDynamicVerification(activeRoles, priv, v.dynamicWithGrant) - if hasPriv { - break - } - } - if !hasPriv { - if v.err == nil { - return plannererrors.ErrPrivilegeCheckFail.GenWithStackByArgs(v.dynamicPrivs) - } - return v.err - } - } else if !pm.RequestVerification(activeRoles, v.db, v.table, v.column, v.privilege) { - if v.err == nil { - return plannererrors.ErrPrivilegeCheckFail.GenWithStackByArgs(v.privilege.String()) - } - return v.err - } - } - return nil -} - -// VisitInfo4PrivCheck generates privilege check infos because privilege check of local temporary tables is different -// with normal tables. `CREATE` statement needs `CREATE TEMPORARY TABLE` privilege from the database, and subsequent -// statements do not need any privileges. 
-func VisitInfo4PrivCheck(ctx context.Context, is infoschema.InfoSchema, node ast.Node, vs []visitInfo) (privVisitInfo []visitInfo) { - if node == nil { - return vs - } - - switch stmt := node.(type) { - case *ast.CreateTableStmt: - privVisitInfo = make([]visitInfo, 0, len(vs)) - for _, v := range vs { - if v.privilege == mysql.CreatePriv { - if stmt.TemporaryKeyword == ast.TemporaryLocal { - // `CREATE TEMPORARY TABLE` privilege is required from the database, not the table. - newVisitInfo := v - newVisitInfo.privilege = mysql.CreateTMPTablePriv - newVisitInfo.table = "" - privVisitInfo = append(privVisitInfo, newVisitInfo) - } else { - // If both the normal table and temporary table already exist, we need to check the privilege. - privVisitInfo = append(privVisitInfo, v) - } - } else { - // `CREATE TABLE LIKE tmp` or `CREATE TABLE FROM SELECT tmp` in the future. - if needCheckTmpTablePriv(ctx, is, v) { - privVisitInfo = append(privVisitInfo, v) - } - } - } - case *ast.DropTableStmt: - // Dropping a local temporary table doesn't need any privileges. - if stmt.IsView { - privVisitInfo = vs - } else { - privVisitInfo = make([]visitInfo, 0, len(vs)) - if stmt.TemporaryKeyword != ast.TemporaryLocal { - for _, v := range vs { - if needCheckTmpTablePriv(ctx, is, v) { - privVisitInfo = append(privVisitInfo, v) - } - } - } - } - case *ast.GrantStmt, *ast.DropSequenceStmt, *ast.DropPlacementPolicyStmt: - // Some statements ignore local temporary tables, so they should check the privileges on normal tables. - privVisitInfo = vs - default: - privVisitInfo = make([]visitInfo, 0, len(vs)) - for _, v := range vs { - if needCheckTmpTablePriv(ctx, is, v) { - privVisitInfo = append(privVisitInfo, v) - } - } - } - return -} - -func needCheckTmpTablePriv(ctx context.Context, is infoschema.InfoSchema, v visitInfo) bool { - if v.db != "" && v.table != "" { - // Other statements on local temporary tables except `CREATE` do not check any privileges. - tb, err := is.TableByName(ctx, model.NewCIStr(v.db), model.NewCIStr(v.table)) - // If the table doesn't exist, we do not report errors to avoid leaking the existence of the table. - if err == nil && tb.Meta().TempTableType == model.TempTableLocal { - return false - } - } - return true -} - -// CheckTableLock checks the table lock. -func CheckTableLock(ctx tablelock.TableLockReadContext, is infoschema.InfoSchema, vs []visitInfo) error { - if !config.TableLockEnabled() { - return nil - } - - checker := lock.NewChecker(ctx, is) - for i := range vs { - err := checker.CheckTableLock(vs[i].db, vs[i].table, vs[i].privilege, vs[i].alterWritable) - // if table with lock-write table dropped, we can access other table, such as `rename` operation - if err == lock.ErrLockedTableDropped { - break - } - if err != nil { - return err - } - } - return nil -} - -func checkStableResultMode(sctx base.PlanContext) bool { - s := sctx.GetSessionVars() - st := s.StmtCtx - return s.EnableStableResultMode && (!st.InInsertStmt && !st.InUpdateStmt && !st.InDeleteStmt && !st.InLoadDataStmt) -} - -// doOptimize optimizes a logical plan into a physical plan, -// while also returning the optimized logical plan, the final physical plan, and the cost of the final plan. -// The returned logical plan is necessary for generating plans for Common Table Expressions (CTEs). 
-func doOptimize( - ctx context.Context, - sctx base.PlanContext, - flag uint64, - logic base.LogicalPlan, -) (base.LogicalPlan, base.PhysicalPlan, float64, error) { - sessVars := sctx.GetSessionVars() - flag = adjustOptimizationFlags(flag, logic) - logic, err := logicalOptimize(ctx, flag, logic) - if err != nil { - return nil, nil, 0, err - } - - if !AllowCartesianProduct.Load() && existsCartesianProduct(logic) { - return nil, nil, 0, errors.Trace(plannererrors.ErrCartesianProductUnsupported) - } - planCounter := base.PlanCounterTp(sessVars.StmtCtx.StmtHints.ForceNthPlan) - if planCounter == 0 { - planCounter = -1 - } - physical, cost, err := physicalOptimize(logic, &planCounter) - if err != nil { - return nil, nil, 0, err - } - finalPlan := postOptimize(ctx, sctx, physical) - - if sessVars.StmtCtx.EnableOptimizerCETrace { - refineCETrace(sctx) - } - if sessVars.StmtCtx.EnableOptimizeTrace { - sessVars.StmtCtx.OptimizeTracer.RecordFinalPlan(finalPlan.BuildPlanTrace()) - } - return logic, finalPlan, cost, nil -} - -func adjustOptimizationFlags(flag uint64, logic base.LogicalPlan) uint64 { - // If there is something after flagPrunColumns, do flagPrunColumnsAgain. - if flag&flagPrunColumns > 0 && flag-flagPrunColumns > flagPrunColumns { - flag |= flagPrunColumnsAgain - } - if checkStableResultMode(logic.SCtx()) { - flag |= flagStabilizeResults - } - if logic.SCtx().GetSessionVars().StmtCtx.StraightJoinOrder { - // When we use the straight Join Order hint, we should disable the join reorder optimization. - flag &= ^flagJoinReOrder - } - flag |= flagCollectPredicateColumnsPoint - flag |= flagSyncWaitStatsLoadPoint - if !logic.SCtx().GetSessionVars().StmtCtx.UseDynamicPruneMode { - flag |= flagPartitionProcessor // apply partition pruning under static mode - } - return flag -} - -// DoOptimize optimizes a logical plan to a physical plan. -func DoOptimize( - ctx context.Context, - sctx base.PlanContext, - flag uint64, - logic base.LogicalPlan, -) (base.PhysicalPlan, float64, error) { - sessVars := sctx.GetSessionVars() - if sessVars.StmtCtx.EnableOptimizerDebugTrace { - debugtrace.EnterContextCommon(sctx) - defer debugtrace.LeaveContextCommon(sctx) - } - _, finalPlan, cost, err := doOptimize(ctx, sctx, flag, logic) - return finalPlan, cost, err -} - -// refineCETrace will adjust the content of CETrace. -// Currently, it will (1) deduplicate trace records, (2) sort the trace records (to make it easier in the tests) and (3) fill in the table name. 
-func refineCETrace(sctx base.PlanContext) { - stmtCtx := sctx.GetSessionVars().StmtCtx - stmtCtx.OptimizerCETrace = tracing.DedupCETrace(stmtCtx.OptimizerCETrace) - slices.SortFunc(stmtCtx.OptimizerCETrace, func(i, j *tracing.CETraceRecord) int { - if i == nil && j != nil { - return -1 - } - if i == nil || j == nil { - return 1 - } - - if c := cmp.Compare(i.TableID, j.TableID); c != 0 { - return c - } - if c := cmp.Compare(i.Type, j.Type); c != 0 { - return c - } - if c := cmp.Compare(i.Expr, j.Expr); c != 0 { - return c - } - return cmp.Compare(i.RowCount, j.RowCount) - }) - traceRecords := stmtCtx.OptimizerCETrace - is := sctx.GetDomainInfoSchema().(infoschema.InfoSchema) - for _, rec := range traceRecords { - tbl, _ := infoschema.FindTableByTblOrPartID(is, rec.TableID) - if tbl != nil { - rec.TableName = tbl.Meta().Name.O - continue - } - logutil.BgLogger().Warn("Failed to find table in infoschema", zap.String("category", "OptimizerTrace"), - zap.Int64("table id", rec.TableID)) - } -} - -// mergeContinuousSelections merge continuous selections which may occur after changing plans. -func mergeContinuousSelections(p base.PhysicalPlan) { - if sel, ok := p.(*PhysicalSelection); ok { - for { - childSel := sel.children[0] - tmp, ok := childSel.(*PhysicalSelection) - if !ok { - break - } - sel.Conditions = append(sel.Conditions, tmp.Conditions...) - sel.SetChild(0, tmp.children[0]) - } - } - for _, child := range p.Children() { - mergeContinuousSelections(child) - } - // merge continuous selections in a coprocessor task of tiflash - tableReader, isTableReader := p.(*PhysicalTableReader) - if isTableReader && tableReader.StoreType == kv.TiFlash { - mergeContinuousSelections(tableReader.tablePlan) - tableReader.TablePlans = flattenPushDownPlan(tableReader.tablePlan) - } -} - -func postOptimize(ctx context.Context, sctx base.PlanContext, plan base.PhysicalPlan) base.PhysicalPlan { - // some cases from update optimize will require avoiding projection elimination. - // see comments ahead of call of DoOptimize in function of buildUpdate(). - plan = eliminatePhysicalProjection(plan) - plan = InjectExtraProjection(plan) - mergeContinuousSelections(plan) - plan = eliminateUnionScanAndLock(sctx, plan) - plan = enableParallelApply(sctx, plan) - handleFineGrainedShuffle(ctx, sctx, plan) - propagateProbeParents(plan, nil) - countStarRewrite(plan) - disableReuseChunkIfNeeded(sctx, plan) - tryEnableLateMaterialization(sctx, plan) - generateRuntimeFilter(sctx, plan) - return plan -} - -func generateRuntimeFilter(sctx base.PlanContext, plan base.PhysicalPlan) { - if !sctx.GetSessionVars().IsRuntimeFilterEnabled() || sctx.GetSessionVars().InRestrictedSQL { - return - } - logutil.BgLogger().Debug("Start runtime filter generator") - rfGenerator := &RuntimeFilterGenerator{ - rfIDGenerator: &util.IDGenerator{}, - columnUniqueIDToRF: map[int64][]*RuntimeFilter{}, - parentPhysicalPlan: plan, - } - startRFGenerator := time.Now() - rfGenerator.GenerateRuntimeFilter(plan) - logutil.BgLogger().Debug("Finish runtime filter generator", - zap.Duration("Cost", time.Since(startRFGenerator))) -} - -// tryEnableLateMaterialization tries to push down some filter conditions to the table scan operator -// @brief: push down some filter conditions to the table scan operator -// @param: sctx: session context -// @param: plan: the physical plan to be pruned -// @note: this optimization is only applied when the TiFlash is used. 
-// @note: the following conditions should be satisfied: -// - Only the filter conditions with high selectivity should be pushed down. -// - The filter conditions which contain heavy cost functions should not be pushed down. -// - Filter conditions that apply to the same column are either pushed down or not pushed down at all. -func tryEnableLateMaterialization(sctx base.PlanContext, plan base.PhysicalPlan) { - // check if EnableLateMaterialization is set - if sctx.GetSessionVars().EnableLateMaterialization && !sctx.GetSessionVars().TiFlashFastScan { - predicatePushDownToTableScan(sctx, plan) - } - if sctx.GetSessionVars().EnableLateMaterialization && sctx.GetSessionVars().TiFlashFastScan { - sc := sctx.GetSessionVars().StmtCtx - sc.AppendWarning(errors.NewNoStackError("FastScan is not compatible with late materialization, late materialization is disabled")) - } -} - -/* -* -The countStarRewriter is used to rewrite - - count(*) -> count(not null column) - -**Only for TiFlash** -Attention: -Since count(*) is directly translated into count(1) during grammar parsing, -the rewritten pattern actually matches count(constant) - -Pattern: -PhysicalAggregation: count(constant) - - | - TableFullScan: TiFlash - -Optimize: -Table - - - -Query: select count(*) from table -ColumnPruningRule: datasource pick row_id -countStarRewrite: datasource pick k1 instead of row_id - - rewrite count(*) -> count(k1) - -Rewritten Query: select count(k1) from table -*/ -func countStarRewrite(plan base.PhysicalPlan) { - countStarRewriteInternal(plan) - if tableReader, ok := plan.(*PhysicalTableReader); ok { - countStarRewrite(tableReader.tablePlan) - } else { - for _, child := range plan.Children() { - countStarRewrite(child) - } - } -} - -func countStarRewriteInternal(plan base.PhysicalPlan) { - // match pattern any agg(count(constant)) -> tablefullscan(tiflash) - var physicalAgg *basePhysicalAgg - switch x := plan.(type) { - case *PhysicalHashAgg: - physicalAgg = x.getPointer() - case *PhysicalStreamAgg: - physicalAgg = x.getPointer() - default: - return - } - if len(physicalAgg.GroupByItems) > 0 || len(physicalAgg.children) != 1 { - return - } - for _, aggFunc := range physicalAgg.AggFuncs { - if aggFunc.Name != "count" || len(aggFunc.Args) != 1 || aggFunc.HasDistinct { - return - } - if _, ok := aggFunc.Args[0].(*expression.Constant); !ok { - return - } - } - physicalTableScan, ok := physicalAgg.Children()[0].(*PhysicalTableScan) - if !ok || !physicalTableScan.isFullScan() || physicalTableScan.StoreType != kv.TiFlash || len(physicalTableScan.schema.Columns) != 1 { - return - } - // rewrite datasource and agg args - rewriteTableScanAndAggArgs(physicalTableScan, physicalAgg.AggFuncs) -} - -// rewriteTableScanAndAggArgs Pick the narrowest and not null column from table -// If there is no not null column in Data Source, the row_id or pk column will be retained -func rewriteTableScanAndAggArgs(physicalTableScan *PhysicalTableScan, aggFuncs []*aggregation.AggFuncDesc) { - var resultColumnInfo *model.ColumnInfo - var resultColumn *expression.Column - - resultColumnInfo = physicalTableScan.Columns[0] - resultColumn = physicalTableScan.schema.Columns[0] - // prefer not null column from table - for _, columnInfo := range physicalTableScan.Table.Columns { - if columnInfo.FieldType.IsVarLengthType() { - continue - } - if mysql.HasNotNullFlag(columnInfo.GetFlag()) { - if columnInfo.GetFlen() < resultColumnInfo.GetFlen() { - resultColumnInfo = columnInfo - resultColumn = &expression.Column{ - UniqueID: 
physicalTableScan.SCtx().GetSessionVars().AllocPlanColumnID(), - ID: resultColumnInfo.ID, - RetType: resultColumnInfo.FieldType.Clone(), - OrigName: fmt.Sprintf("%s.%s.%s", physicalTableScan.DBName.L, physicalTableScan.Table.Name.L, resultColumnInfo.Name), - } - } - } - } - // table scan (row_id) -> (not null column) - physicalTableScan.Columns[0] = resultColumnInfo - physicalTableScan.schema.Columns[0] = resultColumn - // agg arg count(1) -> count(not null column) - arg := resultColumn.Clone() - for _, aggFunc := range aggFuncs { - constExpr, ok := aggFunc.Args[0].(*expression.Constant) - if !ok { - return - } - // count(null) shouldn't be rewritten - if constExpr.Value.IsNull() { - continue - } - aggFunc.Args[0] = arg - } -} - -// Only for MPP(Window<-[Sort]<-ExchangeReceiver<-ExchangeSender). -// TiFlashFineGrainedShuffleStreamCount: -// < 0: fine grained shuffle is disabled. -// > 0: use TiFlashFineGrainedShuffleStreamCount as stream count. -// == 0: use TiFlashMaxThreads as stream count when it's greater than 0. Otherwise set status as uninitialized. -func handleFineGrainedShuffle(ctx context.Context, sctx base.PlanContext, plan base.PhysicalPlan) { - streamCount := sctx.GetSessionVars().TiFlashFineGrainedShuffleStreamCount - if streamCount < 0 { - return - } - if streamCount == 0 { - if sctx.GetSessionVars().TiFlashMaxThreads > 0 { - streamCount = sctx.GetSessionVars().TiFlashMaxThreads - } - } - // use two separate cluster info to avoid grpc calls cost - tiflashServerCountInfo := tiflashClusterInfo{unInitialized, 0} - streamCountInfo := tiflashClusterInfo{unInitialized, 0} - if streamCount != 0 { - streamCountInfo.itemStatus = initialized - streamCountInfo.itemValue = uint64(streamCount) - } - setupFineGrainedShuffle(ctx, sctx, &streamCountInfo, &tiflashServerCountInfo, plan) -} - -func setupFineGrainedShuffle(ctx context.Context, sctx base.PlanContext, streamCountInfo *tiflashClusterInfo, tiflashServerCountInfo *tiflashClusterInfo, plan base.PhysicalPlan) { - if tableReader, ok := plan.(*PhysicalTableReader); ok { - if _, isExchangeSender := tableReader.tablePlan.(*PhysicalExchangeSender); isExchangeSender { - helper := fineGrainedShuffleHelper{shuffleTarget: unknown, plans: make([]*basePhysicalPlan, 1)} - setupFineGrainedShuffleInternal(ctx, sctx, tableReader.tablePlan, &helper, streamCountInfo, tiflashServerCountInfo) - } - } else { - for _, child := range plan.Children() { - setupFineGrainedShuffle(ctx, sctx, streamCountInfo, tiflashServerCountInfo, child) - } - } -} - -type shuffleTarget uint8 - -const ( - unknown shuffleTarget = iota - window - joinBuild - hashAgg -) - -type fineGrainedShuffleHelper struct { - shuffleTarget shuffleTarget - plans []*basePhysicalPlan - joinKeysCount int -} - -type tiflashClusterInfoStatus uint8 - -const ( - unInitialized tiflashClusterInfoStatus = iota - initialized - failed -) - -type tiflashClusterInfo struct { - itemStatus tiflashClusterInfoStatus - itemValue uint64 -} - -func (h *fineGrainedShuffleHelper) clear() { - h.shuffleTarget = unknown - h.plans = h.plans[:0] - h.joinKeysCount = 0 -} - -func (h *fineGrainedShuffleHelper) updateTarget(t shuffleTarget, p *basePhysicalPlan) { - h.shuffleTarget = t - h.plans = append(h.plans, p) -} - -// calculateTiFlashStreamCountUsingMinLogicalCores uses minimal logical cpu cores among tiflash servers, and divide by 2 -// return false, 0 if any err happens -func calculateTiFlashStreamCountUsingMinLogicalCores(ctx context.Context, sctx base.PlanContext, serversInfo []infoschema.ServerInfo) (bool, uint64) 
{ - failpoint.Inject("mockTiFlashStreamCountUsingMinLogicalCores", func(val failpoint.Value) { - intVal, err := strconv.Atoi(val.(string)) - if err == nil { - failpoint.Return(true, uint64(intVal)) - } else { - failpoint.Return(false, 0) - } - }) - rows, err := infoschema.FetchClusterServerInfoWithoutPrivilegeCheck(ctx, sctx.GetSessionVars(), serversInfo, diagnosticspb.ServerInfoType_HardwareInfo, false) - if err != nil { - return false, 0 - } - var initialMaxCores uint64 = 10000 - var minLogicalCores = initialMaxCores // set to a large enough value here - for _, row := range rows { - if row[4].GetString() == "cpu-logical-cores" { - logicalCpus, err := strconv.Atoi(row[5].GetString()) - if err == nil && logicalCpus > 0 { - minLogicalCores = min(minLogicalCores, uint64(logicalCpus)) - } - } - } - // No need to check len(serersInfo) == serverCount here, since missing some servers' info won't affect the correctness - if minLogicalCores > 1 && minLogicalCores != initialMaxCores { - if runtime.GOARCH == "amd64" { - // In most x86-64 platforms, `Thread(s) per core` is 2 - return true, minLogicalCores / 2 - } - // ARM cpus don't implement Hyper-threading. - return true, minLogicalCores - // Other platforms are too rare to consider - } - - return false, 0 -} - -func checkFineGrainedShuffleForJoinAgg(ctx context.Context, sctx base.PlanContext, streamCountInfo *tiflashClusterInfo, tiflashServerCountInfo *tiflashClusterInfo, exchangeColCount int, splitLimit uint64) (applyFlag bool, streamCount uint64) { - switch (*streamCountInfo).itemStatus { - case unInitialized: - streamCount = 4 // assume 8c node in cluster as minimal, stream count is 8 / 2 = 4 - case initialized: - streamCount = (*streamCountInfo).itemValue - case failed: - return false, 0 // probably won't reach this path - } - - var tiflashServerCount uint64 - switch (*tiflashServerCountInfo).itemStatus { - case unInitialized: - serversInfo, err := infoschema.GetTiFlashServerInfo(sctx.GetStore()) - if err != nil { - (*tiflashServerCountInfo).itemStatus = failed - (*tiflashServerCountInfo).itemValue = 0 - if (*streamCountInfo).itemStatus == unInitialized { - setDefaultStreamCount(streamCountInfo) - } - return false, 0 - } - tiflashServerCount = uint64(len(serversInfo)) - (*tiflashServerCountInfo).itemStatus = initialized - (*tiflashServerCountInfo).itemValue = tiflashServerCount - case initialized: - tiflashServerCount = (*tiflashServerCountInfo).itemValue - case failed: - return false, 0 - } - - // if already exceeds splitLimit, no need to fetch actual logical cores - if tiflashServerCount*uint64(exchangeColCount)*streamCount > splitLimit { - return false, 0 - } - - // if streamCount already initialized, and can pass splitLimit check - if (*streamCountInfo).itemStatus == initialized { - return true, streamCount - } - - serversInfo, err := infoschema.GetTiFlashServerInfo(sctx.GetStore()) - if err != nil { - (*tiflashServerCountInfo).itemStatus = failed - (*tiflashServerCountInfo).itemValue = 0 - return false, 0 - } - flag, temStreamCount := calculateTiFlashStreamCountUsingMinLogicalCores(ctx, sctx, serversInfo) - if !flag { - setDefaultStreamCount(streamCountInfo) - (*tiflashServerCountInfo).itemStatus = failed - return false, 0 - } - streamCount = temStreamCount - (*streamCountInfo).itemStatus = initialized - (*streamCountInfo).itemValue = streamCount - applyFlag = tiflashServerCount*uint64(exchangeColCount)*streamCount <= splitLimit - return applyFlag, streamCount -} - -func inferFineGrainedShuffleStreamCountForWindow(ctx context.Context, 
sctx base.PlanContext, streamCountInfo *tiflashClusterInfo, tiflashServerCountInfo *tiflashClusterInfo) (streamCount uint64) { - switch (*streamCountInfo).itemStatus { - case unInitialized: - if (*tiflashServerCountInfo).itemStatus == failed { - setDefaultStreamCount(streamCountInfo) - streamCount = (*streamCountInfo).itemValue - break - } - - serversInfo, err := infoschema.GetTiFlashServerInfo(sctx.GetStore()) - if err != nil { - setDefaultStreamCount(streamCountInfo) - streamCount = (*streamCountInfo).itemValue - (*tiflashServerCountInfo).itemStatus = failed - break - } - - if (*tiflashServerCountInfo).itemStatus == unInitialized { - (*tiflashServerCountInfo).itemStatus = initialized - (*tiflashServerCountInfo).itemValue = uint64(len(serversInfo)) - } - - flag, temStreamCount := calculateTiFlashStreamCountUsingMinLogicalCores(ctx, sctx, serversInfo) - if !flag { - setDefaultStreamCount(streamCountInfo) - streamCount = (*streamCountInfo).itemValue - (*tiflashServerCountInfo).itemStatus = failed - break - } - streamCount = temStreamCount - (*streamCountInfo).itemStatus = initialized - (*streamCountInfo).itemValue = streamCount - case initialized: - streamCount = (*streamCountInfo).itemValue - case failed: - setDefaultStreamCount(streamCountInfo) - streamCount = (*streamCountInfo).itemValue - } - return streamCount -} - -func setDefaultStreamCount(streamCountInfo *tiflashClusterInfo) { - (*streamCountInfo).itemStatus = initialized - (*streamCountInfo).itemValue = variable.DefStreamCountWhenMaxThreadsNotSet -} - -func setupFineGrainedShuffleInternal(ctx context.Context, sctx base.PlanContext, plan base.PhysicalPlan, helper *fineGrainedShuffleHelper, streamCountInfo *tiflashClusterInfo, tiflashServerCountInfo *tiflashClusterInfo) { - switch x := plan.(type) { - case *PhysicalWindow: - // Do not clear the plans because window executor will keep the data partition. - // For non hash partition window function, there will be a passthrough ExchangeSender to collect data, - // which will break data partition. - helper.updateTarget(window, &x.basePhysicalPlan) - setupFineGrainedShuffleInternal(ctx, sctx, x.children[0], helper, streamCountInfo, tiflashServerCountInfo) - case *PhysicalSort: - if x.IsPartialSort { - // Partial sort will keep the data partition. - helper.plans = append(helper.plans, &x.basePhysicalPlan) - } else { - // Global sort will break the data partition. 
- helper.clear() - } - setupFineGrainedShuffleInternal(ctx, sctx, x.children[0], helper, streamCountInfo, tiflashServerCountInfo) - case *PhysicalSelection: - helper.plans = append(helper.plans, &x.basePhysicalPlan) - setupFineGrainedShuffleInternal(ctx, sctx, x.children[0], helper, streamCountInfo, tiflashServerCountInfo) - case *PhysicalProjection: - helper.plans = append(helper.plans, &x.basePhysicalPlan) - setupFineGrainedShuffleInternal(ctx, sctx, x.children[0], helper, streamCountInfo, tiflashServerCountInfo) - case *PhysicalExchangeReceiver: - helper.plans = append(helper.plans, &x.basePhysicalPlan) - setupFineGrainedShuffleInternal(ctx, sctx, x.children[0], helper, streamCountInfo, tiflashServerCountInfo) - case *PhysicalHashAgg: - // Todo: allow hash aggregation's output still benefits from fine grained shuffle - aggHelper := fineGrainedShuffleHelper{shuffleTarget: hashAgg, plans: []*basePhysicalPlan{}} - aggHelper.plans = append(aggHelper.plans, &x.basePhysicalPlan) - setupFineGrainedShuffleInternal(ctx, sctx, x.children[0], &aggHelper, streamCountInfo, tiflashServerCountInfo) - case *PhysicalHashJoin: - child0 := x.children[0] - child1 := x.children[1] - buildChild := child0 - probChild := child1 - joinKeys := x.LeftJoinKeys - if x.InnerChildIdx != 0 { - // Child1 is build side. - buildChild = child1 - joinKeys = x.RightJoinKeys - probChild = child0 - } - if len(joinKeys) > 0 { // Not cross join - buildHelper := fineGrainedShuffleHelper{shuffleTarget: joinBuild, plans: []*basePhysicalPlan{}} - buildHelper.plans = append(buildHelper.plans, &x.basePhysicalPlan) - buildHelper.joinKeysCount = len(joinKeys) - setupFineGrainedShuffleInternal(ctx, sctx, buildChild, &buildHelper, streamCountInfo, tiflashServerCountInfo) - } else { - buildHelper := fineGrainedShuffleHelper{shuffleTarget: unknown, plans: []*basePhysicalPlan{}} - setupFineGrainedShuffleInternal(ctx, sctx, buildChild, &buildHelper, streamCountInfo, tiflashServerCountInfo) - } - // don't apply fine grained shuffle for probe side - helper.clear() - setupFineGrainedShuffleInternal(ctx, sctx, probChild, helper, streamCountInfo, tiflashServerCountInfo) - case *PhysicalExchangeSender: - if x.ExchangeType == tipb.ExchangeType_Hash { - // Set up stream count for all plans based on shuffle target type. 
- var exchangeColCount = x.Schema().Len() - switch helper.shuffleTarget { - case window: - streamCount := inferFineGrainedShuffleStreamCountForWindow(ctx, sctx, streamCountInfo, tiflashServerCountInfo) - x.TiFlashFineGrainedShuffleStreamCount = streamCount - for _, p := range helper.plans { - p.TiFlashFineGrainedShuffleStreamCount = streamCount - } - case hashAgg: - applyFlag, streamCount := checkFineGrainedShuffleForJoinAgg(ctx, sctx, streamCountInfo, tiflashServerCountInfo, exchangeColCount, 1200) // 1200: performance test result - if applyFlag { - x.TiFlashFineGrainedShuffleStreamCount = streamCount - for _, p := range helper.plans { - p.TiFlashFineGrainedShuffleStreamCount = streamCount - } - } - case joinBuild: - // Support hashJoin only when shuffle hash keys equals to join keys due to tiflash implementations - if len(x.HashCols) != helper.joinKeysCount { - break - } - applyFlag, streamCount := checkFineGrainedShuffleForJoinAgg(ctx, sctx, streamCountInfo, tiflashServerCountInfo, exchangeColCount, 600) // 600: performance test result - if applyFlag { - x.TiFlashFineGrainedShuffleStreamCount = streamCount - for _, p := range helper.plans { - p.TiFlashFineGrainedShuffleStreamCount = streamCount - } - } - } - } - // exchange sender will break the data partition. - helper.clear() - setupFineGrainedShuffleInternal(ctx, sctx, x.children[0], helper, streamCountInfo, tiflashServerCountInfo) - default: - for _, child := range x.Children() { - childHelper := fineGrainedShuffleHelper{shuffleTarget: unknown, plans: []*basePhysicalPlan{}} - setupFineGrainedShuffleInternal(ctx, sctx, child, &childHelper, streamCountInfo, tiflashServerCountInfo) - } - } -} - -// propagateProbeParents doesn't affect the execution plan, it only sets the probeParents field of a PhysicalPlan. -// It's for handling the inconsistency between row count in the statsInfo and the recorded actual row count. Please -// see comments in PhysicalPlan for details. -func propagateProbeParents(plan base.PhysicalPlan, probeParents []base.PhysicalPlan) { - plan.SetProbeParents(probeParents) - switch x := plan.(type) { - case *PhysicalApply, *PhysicalIndexJoin, *PhysicalIndexHashJoin, *PhysicalIndexMergeJoin: - if join, ok := plan.(interface{ getInnerChildIdx() int }); ok { - propagateProbeParents(plan.Children()[1-join.getInnerChildIdx()], probeParents) - - // The core logic of this method: - // Record every Apply and Index Join we met, record it in a slice, and set it in their inner children. - newParents := make([]base.PhysicalPlan, len(probeParents), len(probeParents)+1) - copy(newParents, probeParents) - newParents = append(newParents, plan) - propagateProbeParents(plan.Children()[join.getInnerChildIdx()], newParents) - } - case *PhysicalTableReader: - propagateProbeParents(x.tablePlan, probeParents) - case *PhysicalIndexReader: - propagateProbeParents(x.indexPlan, probeParents) - case *PhysicalIndexLookUpReader: - propagateProbeParents(x.indexPlan, probeParents) - propagateProbeParents(x.tablePlan, probeParents) - case *PhysicalIndexMergeReader: - for _, pchild := range x.partialPlans { - propagateProbeParents(pchild, probeParents) - } - propagateProbeParents(x.tablePlan, probeParents) - default: - for _, child := range plan.Children() { - propagateProbeParents(child, probeParents) - } - } -} - -func enableParallelApply(sctx base.PlanContext, plan base.PhysicalPlan) base.PhysicalPlan { - if !sctx.GetSessionVars().EnableParallelApply { - return plan - } - // the parallel apply has three limitation: - // 1. 
the parallel implementation now cannot keep order; - // 2. the inner child has to support clone; - // 3. if one Apply is in the inner side of another Apply, it cannot be parallel, for example: - // The topology of 3 Apply operators are A1(A2, A3), which means A2 is the outer child of A1 - // while A3 is the inner child. Then A1 and A2 can be parallel and A3 cannot. - if apply, ok := plan.(*PhysicalApply); ok { - outerIdx := 1 - apply.InnerChildIdx - noOrder := len(apply.GetChildReqProps(outerIdx).SortItems) == 0 // limitation 1 - _, err := SafeClone(sctx, apply.Children()[apply.InnerChildIdx]) - supportClone := err == nil // limitation 2 - if noOrder && supportClone { - apply.Concurrency = sctx.GetSessionVars().ExecutorConcurrency - } else { - if err != nil { - sctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("Some apply operators can not be executed in parallel: %v", err)) - } else { - sctx.GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackError("Some apply operators can not be executed in parallel")) - } - } - // because of the limitation 3, we cannot parallelize Apply operators in this Apply's inner size, - // so we only invoke recursively for its outer child. - apply.SetChild(outerIdx, enableParallelApply(sctx, apply.Children()[outerIdx])) - return apply - } - for i, child := range plan.Children() { - plan.SetChild(i, enableParallelApply(sctx, child)) - } - return plan -} - -// LogicalOptimizeTest is just exported for test. -func LogicalOptimizeTest(ctx context.Context, flag uint64, logic base.LogicalPlan) (base.LogicalPlan, error) { - return logicalOptimize(ctx, flag, logic) -} - -func logicalOptimize(ctx context.Context, flag uint64, logic base.LogicalPlan) (base.LogicalPlan, error) { - if logic.SCtx().GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { - debugtrace.EnterContextCommon(logic.SCtx()) - defer debugtrace.LeaveContextCommon(logic.SCtx()) - } - opt := optimizetrace.DefaultLogicalOptimizeOption() - vars := logic.SCtx().GetSessionVars() - if vars.StmtCtx.EnableOptimizeTrace { - vars.StmtCtx.OptimizeTracer = &tracing.OptimizeTracer{} - tracer := &tracing.LogicalOptimizeTracer{ - Steps: make([]*tracing.LogicalRuleOptimizeTracer, 0), - } - opt = opt.WithEnableOptimizeTracer(tracer) - defer func() { - vars.StmtCtx.OptimizeTracer.Logical = tracer - }() - } - var err error - var againRuleList []base.LogicalOptRule - for i, rule := range optRuleList { - // The order of flags is same as the order of optRule in the list. - // We use a bitmask to record which opt rules should be used. If the i-th bit is 1, it means we should - // apply i-th optimizing rule. - if flag&(1< 0 { - logic.SCtx().GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("The parameter of nth_plan() is out of range")) - } - if t.Invalid() { - errMsg := "Can't find a proper physical plan for this query" - if config.GetGlobalConfig().DisaggregatedTiFlash && !logic.SCtx().GetSessionVars().IsMPPAllowed() { - errMsg += ": cop and batchCop are not allowed in disaggregated tiflash mode, you should turn on tidb_allow_mpp switch" - } - return nil, 0, plannererrors.ErrInternal.GenWithStackByArgs(errMsg) - } - - if err = t.Plan().ResolveIndices(); err != nil { - return nil, 0, err - } - cost, err = getPlanCost(t.Plan(), property.RootTaskType, optimizetrace.NewDefaultPlanCostOption()) - return t.Plan(), cost, err -} - -// eliminateUnionScanAndLock set lock property for PointGet and BatchPointGet and eliminates UnionScan and Lock. 
-func eliminateUnionScanAndLock(sctx base.PlanContext, p base.PhysicalPlan) base.PhysicalPlan { - var pointGet *PointGetPlan - var batchPointGet *BatchPointGetPlan - var physLock *PhysicalLock - var unionScan *PhysicalUnionScan - iteratePhysicalPlan(p, func(p base.PhysicalPlan) bool { - if len(p.Children()) > 1 { - return false - } - switch x := p.(type) { - case *PointGetPlan: - pointGet = x - case *BatchPointGetPlan: - batchPointGet = x - case *PhysicalLock: - physLock = x - case *PhysicalUnionScan: - unionScan = x - } - return true - }) - if pointGet == nil && batchPointGet == nil { - return p - } - if physLock == nil && unionScan == nil { - return p - } - if physLock != nil { - lock, waitTime := getLockWaitTime(sctx, physLock.Lock) - if !lock { - return p - } - if pointGet != nil { - pointGet.Lock = lock - pointGet.LockWaitTime = waitTime - } else { - batchPointGet.Lock = lock - batchPointGet.LockWaitTime = waitTime - } - } - return transformPhysicalPlan(p, func(p base.PhysicalPlan) base.PhysicalPlan { - if p == physLock { - return p.Children()[0] - } - if p == unionScan { - return p.Children()[0] - } - return p - }) -} - -func iteratePhysicalPlan(p base.PhysicalPlan, f func(p base.PhysicalPlan) bool) { - if !f(p) { - return - } - for _, child := range p.Children() { - iteratePhysicalPlan(child, f) - } -} - -func transformPhysicalPlan(p base.PhysicalPlan, f func(p base.PhysicalPlan) base.PhysicalPlan) base.PhysicalPlan { - for i, child := range p.Children() { - p.Children()[i] = transformPhysicalPlan(child, f) - } - return f(p) -} - -func existsCartesianProduct(p base.LogicalPlan) bool { - if join, ok := p.(*LogicalJoin); ok && len(join.EqualConditions) == 0 { - return join.JoinType == InnerJoin || join.JoinType == LeftOuterJoin || join.JoinType == RightOuterJoin - } - for _, child := range p.Children() { - if existsCartesianProduct(child) { - return true - } - } - return false -} - -// DefaultDisabledLogicalRulesList indicates the logical rules which should be banned. -var DefaultDisabledLogicalRulesList *atomic.Value - -func disableReuseChunkIfNeeded(sctx base.PlanContext, plan base.PhysicalPlan) { - if !sctx.GetSessionVars().IsAllocValid() { - return - } - - if checkOverlongColType(sctx, plan) { - return - } - - for _, child := range plan.Children() { - disableReuseChunkIfNeeded(sctx, child) - } -} - -// checkOverlongColType Check if read field type is long field. -func checkOverlongColType(sctx base.PlanContext, plan base.PhysicalPlan) bool { - if plan == nil { - return false - } - switch plan.(type) { - case *PhysicalTableReader, *PhysicalIndexReader, - *PhysicalIndexLookUpReader, *PhysicalIndexMergeReader, *PointGetPlan: - if existsOverlongType(plan.Schema()) { - sctx.GetSessionVars().ClearAlloc(nil, false) - return true - } - } - return false -} - -// existsOverlongType Check if exists long type column. -func existsOverlongType(schema *expression.Schema) bool { - if schema == nil { - return false - } - for _, column := range schema.Columns { - switch column.RetType.GetType() { - case mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, - mysql.TypeBlob, mysql.TypeJSON: - return true - case mysql.TypeVarString, mysql.TypeVarchar: - // if the column is varchar and the length of - // the column is defined to be more than 1000, - // the column is considered a large type and - // disable chunk_reuse. 
- if column.RetType.GetFlen() > 1000 { - return true - } - } - } - return false -} diff --git a/pkg/planner/core/rule_collect_plan_stats.go b/pkg/planner/core/rule_collect_plan_stats.go index 33c07f664f0f9..e9047782c5095 100644 --- a/pkg/planner/core/rule_collect_plan_stats.go +++ b/pkg/planner/core/rule_collect_plan_stats.go @@ -106,13 +106,13 @@ func RequestLoadStats(ctx base.PlanContext, neededHistItems []model.StatsLoadIte if maxExecutionTime > 0 && maxExecutionTime < uint64(syncWait) { syncWait = int64(maxExecutionTime) } - if val, _err_ := failpoint.Eval(_curpkg_("assertSyncWaitFailed")); _err_ == nil { + failpoint.Inject("assertSyncWaitFailed", func(val failpoint.Value) { if val.(bool) { if syncWait != 1 { panic("syncWait should be 1(ms)") } } - } + }) var timeout = time.Duration(syncWait * time.Millisecond.Nanoseconds()) stmtCtx := ctx.GetSessionVars().StmtCtx err := domain.GetDomain(ctx).StatsHandle().SendLoadRequests(stmtCtx, neededHistItems, timeout) diff --git a/pkg/planner/core/rule_collect_plan_stats.go__failpoint_stash__ b/pkg/planner/core/rule_collect_plan_stats.go__failpoint_stash__ deleted file mode 100644 index e9047782c5095..0000000000000 --- a/pkg/planner/core/rule_collect_plan_stats.go__failpoint_stash__ +++ /dev/null @@ -1,316 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package core - -import ( - "context" - "time" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/planner/util/optimizetrace" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/statistics" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/util/logutil" - "go.uber.org/zap" -) - -// CollectPredicateColumnsPoint collects the columns that are used in the predicates. -type CollectPredicateColumnsPoint struct{} - -// Optimize implements LogicalOptRule.<0th> interface. -func (CollectPredicateColumnsPoint) Optimize(_ context.Context, plan base.LogicalPlan, _ *optimizetrace.LogicalOptimizeOp) (base.LogicalPlan, bool, error) { - planChanged := false - if plan.SCtx().GetSessionVars().InRestrictedSQL { - return plan, planChanged, nil - } - syncWait := plan.SCtx().GetSessionVars().StatsLoadSyncWait.Load() - histNeeded := syncWait > 0 - predicateColumns, histNeededColumns, visitedPhysTblIDs := CollectColumnStatsUsage(plan, histNeeded) - if len(predicateColumns) > 0 { - plan.SCtx().UpdateColStatsUsage(predicateColumns) - } - - // Prepare the table metadata to avoid repeatedly fetching from the infoSchema below, and trigger extra sync/async - // stats loading for the determinate mode. 
- is := plan.SCtx().GetInfoSchema().(infoschema.InfoSchema) - tblID2Tbl := make(map[int64]table.Table) - visitedPhysTblIDs.ForEach(func(physicalTblID int) { - tbl, _ := infoschema.FindTableByTblOrPartID(is, int64(physicalTblID)) - if tbl == nil { - return - } - tblID2Tbl[int64(physicalTblID)] = tbl - }) - - // collect needed virtual columns from already needed columns - // Note that we use the dependingVirtualCols only to collect needed index stats, but not to trigger stats loading on - // the virtual columns themselves. It's because virtual columns themselves don't have statistics, while expression - // indexes, which are indexes on virtual columns, have statistics. We don't waste the resource here now. - dependingVirtualCols := CollectDependingVirtualCols(tblID2Tbl, histNeededColumns) - - histNeededIndices := collectSyncIndices(plan.SCtx(), append(histNeededColumns, dependingVirtualCols...), tblID2Tbl) - histNeededItems := collectHistNeededItems(histNeededColumns, histNeededIndices) - if histNeeded && len(histNeededItems) > 0 { - err := RequestLoadStats(plan.SCtx(), histNeededItems, syncWait) - return plan, planChanged, err - } - return plan, planChanged, nil -} - -// Name implements the base.LogicalOptRule.<1st> interface. -func (CollectPredicateColumnsPoint) Name() string { - return "collect_predicate_columns_point" -} - -// SyncWaitStatsLoadPoint sync-wait for stats load point. -type SyncWaitStatsLoadPoint struct{} - -// Optimize implements the base.LogicalOptRule.<0th> interface. -func (SyncWaitStatsLoadPoint) Optimize(_ context.Context, plan base.LogicalPlan, _ *optimizetrace.LogicalOptimizeOp) (base.LogicalPlan, bool, error) { - planChanged := false - if plan.SCtx().GetSessionVars().InRestrictedSQL { - return plan, planChanged, nil - } - if plan.SCtx().GetSessionVars().StmtCtx.IsSyncStatsFailed { - return plan, planChanged, nil - } - err := SyncWaitStatsLoad(plan) - return plan, planChanged, err -} - -// Name implements the base.LogicalOptRule.<1st> interface. 
-func (SyncWaitStatsLoadPoint) Name() string { - return "sync_wait_stats_load_point" -} - -// RequestLoadStats send load column/index stats requests to stats handle -func RequestLoadStats(ctx base.PlanContext, neededHistItems []model.StatsLoadItem, syncWait int64) error { - maxExecutionTime := ctx.GetSessionVars().GetMaxExecutionTime() - if maxExecutionTime > 0 && maxExecutionTime < uint64(syncWait) { - syncWait = int64(maxExecutionTime) - } - failpoint.Inject("assertSyncWaitFailed", func(val failpoint.Value) { - if val.(bool) { - if syncWait != 1 { - panic("syncWait should be 1(ms)") - } - } - }) - var timeout = time.Duration(syncWait * time.Millisecond.Nanoseconds()) - stmtCtx := ctx.GetSessionVars().StmtCtx - err := domain.GetDomain(ctx).StatsHandle().SendLoadRequests(stmtCtx, neededHistItems, timeout) - if err != nil { - stmtCtx.IsSyncStatsFailed = true - if variable.StatsLoadPseudoTimeout.Load() { - logutil.BgLogger().Warn("RequestLoadStats failed", zap.Error(err)) - stmtCtx.AppendWarning(err) - return nil - } - logutil.BgLogger().Warn("RequestLoadStats failed", zap.Error(err)) - return err - } - return nil -} - -// SyncWaitStatsLoad sync-wait for stats load until timeout -func SyncWaitStatsLoad(plan base.LogicalPlan) error { - stmtCtx := plan.SCtx().GetSessionVars().StmtCtx - if len(stmtCtx.StatsLoad.NeededItems) <= 0 { - return nil - } - err := domain.GetDomain(plan.SCtx()).StatsHandle().SyncWaitStatsLoad(stmtCtx) - if err != nil { - stmtCtx.IsSyncStatsFailed = true - if variable.StatsLoadPseudoTimeout.Load() { - logutil.BgLogger().Warn("SyncWaitStatsLoad failed", zap.Error(err)) - stmtCtx.AppendWarning(err) - return nil - } - logutil.BgLogger().Error("SyncWaitStatsLoad failed", zap.Error(err)) - return err - } - return nil -} - -// CollectDependingVirtualCols collects the virtual columns that depend on the needed columns, and returns them in a new slice. -// -// Why do we need this? -// It's mainly for stats sync loading. -// Currently, virtual columns themselves don't have statistics. But expression indexes, which are indexes on virtual -// columns, have statistics. We need to collect needed virtual columns, then needed expression index stats can be -// collected for sync loading. -// In normal cases, if a virtual column can be used, which means related statistics may be needed, the corresponding -// expressions in the query must have already been replaced with the virtual column before here. So we just need to treat -// them like normal columns in stats sync loading, which means we just extract the Column from the expressions, the -// virtual columns we want will be there. -// However, in some cases (the mv index case now), the expressions are not replaced with the virtual columns before here. -// Instead, we match the expression in the query against the expression behind the virtual columns after here when -// building the access paths. This means we are unable to known what virtual columns will be needed by just extracting -// the Column from the expressions here. So we need to manually collect the virtual columns that may be needed. -// -// Note 1: As long as a virtual column depends on the needed columns, it will be collected. This could collect some virtual -// columns that are not actually needed. -// It's OK because that's how sync loading is expected. Sync loading only needs to ensure all actually needed stats are -// triggered to be loaded. Other logic of sync loading also works like this. 
-// If we want to collect only the virtual columns that are actually needed, we need to make the checking logic here exactly -// the same as the logic for generating the access paths, which will make the logic here very complicated. -// -// Note 2: Only direct dependencies are considered here. -// If a virtual column depends on another virtual column, and the latter depends on the needed columns, then the former -// will not be collected. -// For example: create table t(a int, b int, c int as (a+b), d int as (c+1)); If a is needed, then c will be collected, -// but d will not be collected. -// It's because currently it's impossible that statistics related to indirectly depending columns are actually needed. -// If we need to check indirect dependency some day, we can easily extend the logic here. -func CollectDependingVirtualCols(tblID2Tbl map[int64]table.Table, neededItems []model.StatsLoadItem) []model.StatsLoadItem { - generatedCols := make([]model.StatsLoadItem, 0) - - // group the neededItems by table id - tblID2neededColIDs := make(map[int64][]int64, len(tblID2Tbl)) - for _, item := range neededItems { - if item.IsIndex { - continue - } - tblID2neededColIDs[item.TableID] = append(tblID2neededColIDs[item.TableID], item.ID) - } - - // process them by table id - for tblID, colIDs := range tblID2neededColIDs { - tbl := tblID2Tbl[tblID] - if tbl == nil { - continue - } - // collect the needed columns on this table into a set for faster lookup - colNameSet := make(map[string]struct{}, len(colIDs)) - for _, colID := range colIDs { - name := tbl.Meta().FindColumnNameByID(colID) - if name == "" { - continue - } - colNameSet[name] = struct{}{} - } - // iterate columns in this table, and collect the virtual columns that depend on the needed columns - for _, col := range tbl.Cols() { - // only handles virtual columns - if !col.IsVirtualGenerated() { - continue - } - // If this column is already needed, then skip it. - if _, ok := colNameSet[col.Name.L]; ok { - continue - } - // If there exists a needed column that is depended on by this virtual column, - // then we think this virtual column is needed. - for depCol := range col.Dependences { - if _, ok := colNameSet[depCol]; ok { - generatedCols = append(generatedCols, model.StatsLoadItem{TableItemID: model.TableItemID{TableID: tblID, ID: col.ID, IsIndex: false}, FullLoad: true}) - break - } - } - } - } - return generatedCols -} - -// collectSyncIndices will collect the indices which includes following conditions: -// 1. the indices contained the any one of histNeededColumns, eg: histNeededColumns contained A,B columns, and idx_a is -// composed up by A column, then we thought the idx_a should be collected -// 2. 
The stats condition of idx_a can't meet IsFullLoad, which means its stats was evicted previously -func collectSyncIndices(ctx base.PlanContext, - histNeededColumns []model.StatsLoadItem, - tblID2Tbl map[int64]table.Table, -) map[model.TableItemID]struct{} { - histNeededIndices := make(map[model.TableItemID]struct{}) - stats := domain.GetDomain(ctx).StatsHandle() - for _, column := range histNeededColumns { - if column.IsIndex { - continue - } - tbl := tblID2Tbl[column.TableID] - if tbl == nil { - continue - } - colName := tbl.Meta().FindColumnNameByID(column.ID) - if colName == "" { - continue - } - for _, idx := range tbl.Indices() { - if idx.Meta().State != model.StatePublic { - continue - } - idxCol := idx.Meta().FindColumnByName(colName) - idxID := idx.Meta().ID - if idxCol != nil { - tblStats := stats.GetTableStats(tbl.Meta()) - if tblStats == nil || tblStats.Pseudo { - continue - } - _, loadNeeded := tblStats.IndexIsLoadNeeded(idxID) - if !loadNeeded { - continue - } - histNeededIndices[model.TableItemID{TableID: column.TableID, ID: idxID, IsIndex: true}] = struct{}{} - } - } - } - return histNeededIndices -} - -func collectHistNeededItems(histNeededColumns []model.StatsLoadItem, histNeededIndices map[model.TableItemID]struct{}) (histNeededItems []model.StatsLoadItem) { - histNeededItems = make([]model.StatsLoadItem, 0, len(histNeededColumns)+len(histNeededIndices)) - for idx := range histNeededIndices { - histNeededItems = append(histNeededItems, model.StatsLoadItem{TableItemID: idx, FullLoad: true}) - } - histNeededItems = append(histNeededItems, histNeededColumns...) - return -} - -func recordTableRuntimeStats(sctx base.PlanContext, tbls map[int64]struct{}) { - tblStats := sctx.GetSessionVars().StmtCtx.TableStats - if tblStats == nil { - tblStats = map[int64]any{} - } - for tblID := range tbls { - tblJSONStats, skip, err := recordSingleTableRuntimeStats(sctx, tblID) - if err != nil { - logutil.BgLogger().Warn("record table json stats failed", zap.Int64("tblID", tblID), zap.Error(err)) - } - if tblJSONStats == nil && !skip { - logutil.BgLogger().Warn("record table json stats failed due to empty", zap.Int64("tblID", tblID)) - } - tblStats[tblID] = tblJSONStats - } - sctx.GetSessionVars().StmtCtx.TableStats = tblStats -} - -func recordSingleTableRuntimeStats(sctx base.PlanContext, tblID int64) (stats *statistics.Table, skip bool, err error) { - dom := domain.GetDomain(sctx) - statsHandle := dom.StatsHandle() - is := sctx.GetDomainInfoSchema().(infoschema.InfoSchema) - tbl, ok := is.TableByID(tblID) - if !ok { - return nil, false, nil - } - tableInfo := tbl.Meta() - stats = statsHandle.GetTableStats(tableInfo) - // Skip the warning if the table is a temporary table because the temporary table doesn't have stats. - skip = tableInfo.TempTableType != model.TempTableNone - return stats, skip, nil -} diff --git a/pkg/planner/core/rule_eliminate_projection.go b/pkg/planner/core/rule_eliminate_projection.go index 40f2977caa27d..0a082bc106ac5 100644 --- a/pkg/planner/core/rule_eliminate_projection.go +++ b/pkg/planner/core/rule_eliminate_projection.go @@ -135,11 +135,11 @@ func doPhysicalProjectionElimination(p base.PhysicalPlan) base.PhysicalPlan { // eliminatePhysicalProjection should be called after physical optimization to // eliminate the redundant projection left after logical projection elimination. 
func eliminatePhysicalProjection(p base.PhysicalPlan) base.PhysicalPlan { - if val, _err_ := failpoint.Eval(_curpkg_("DisableProjectionPostOptimization")); _err_ == nil { + failpoint.Inject("DisableProjectionPostOptimization", func(val failpoint.Value) { if val.(bool) { - return p + failpoint.Return(p) } - } + }) newRoot := doPhysicalProjectionElimination(p) return newRoot diff --git a/pkg/planner/core/rule_eliminate_projection.go__failpoint_stash__ b/pkg/planner/core/rule_eliminate_projection.go__failpoint_stash__ deleted file mode 100644 index 0a082bc106ac5..0000000000000 --- a/pkg/planner/core/rule_eliminate_projection.go__failpoint_stash__ +++ /dev/null @@ -1,274 +0,0 @@ -// Copyright 2016 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package core - -import ( - "bytes" - "context" - "fmt" - - perrors "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/planner/core/operator/logicalop" - ruleutil "github.com/pingcap/tidb/pkg/planner/core/rule/util" - "github.com/pingcap/tidb/pkg/planner/util/optimizetrace" -) - -// canProjectionBeEliminatedLoose checks whether a projection can be eliminated, -// returns true if every expression is a single column. -func canProjectionBeEliminatedLoose(p *logicalop.LogicalProjection) bool { - // project for expand will assign a new col id for col ref, because these column should be - // data cloned in the execution time and may be filled with null value at the same time. - // so it's not a REAL column reference. Detect the column ref in projection here and do - // the elimination here will restore the Expand's grouping sets column back to use the - // original column ref again. (which is not right) - if p.Proj4Expand { - return false - } - for _, expr := range p.Exprs { - _, ok := expr.(*expression.Column) - if !ok { - return false - } - } - return true -} - -// canProjectionBeEliminatedStrict checks whether a projection can be -// eliminated, returns true if the projection just copy its child's output. -func canProjectionBeEliminatedStrict(p *PhysicalProjection) bool { - // This is due to the in-compatibility between TiFlash and TiDB: - // For TiDB, the output schema of final agg is all the aggregated functions and for - // TiFlash, the output schema of agg(TiFlash not aware of the aggregation mode) is - // aggregated functions + group by columns, so to make the things work, for final - // mode aggregation that need to be running in TiFlash, always add an extra Project - // the align the output schema. In the future, we can solve this in-compatibility by - // passing down the aggregation mode to TiFlash. 
- if physicalAgg, ok := p.Children()[0].(*PhysicalHashAgg); ok { - if physicalAgg.MppRunMode == Mpp1Phase || physicalAgg.MppRunMode == Mpp2Phase || physicalAgg.MppRunMode == MppScalar { - if physicalAgg.IsFinalAgg() { - return false - } - } - } - if physicalAgg, ok := p.Children()[0].(*PhysicalStreamAgg); ok { - if physicalAgg.MppRunMode == Mpp1Phase || physicalAgg.MppRunMode == Mpp2Phase || physicalAgg.MppRunMode == MppScalar { - if physicalAgg.IsFinalAgg() { - return false - } - } - } - // If this projection is specially added for `DO`, we keep it. - if p.CalculateNoDelay { - return false - } - if p.Schema().Len() == 0 { - return true - } - child := p.Children()[0] - if p.Schema().Len() != child.Schema().Len() { - return false - } - for i, expr := range p.Exprs { - col, ok := expr.(*expression.Column) - if !ok || !col.EqualColumn(child.Schema().Columns[i]) { - return false - } - } - return true -} - -func doPhysicalProjectionElimination(p base.PhysicalPlan) base.PhysicalPlan { - for i, child := range p.Children() { - p.Children()[i] = doPhysicalProjectionElimination(child) - } - - // eliminate projection in a coprocessor task - tableReader, isTableReader := p.(*PhysicalTableReader) - if isTableReader && tableReader.StoreType == kv.TiFlash { - tableReader.tablePlan = eliminatePhysicalProjection(tableReader.tablePlan) - tableReader.TablePlans = flattenPushDownPlan(tableReader.tablePlan) - return p - } - - proj, isProj := p.(*PhysicalProjection) - if !isProj || !canProjectionBeEliminatedStrict(proj) { - return p - } - child := p.Children()[0] - if childProj, ok := child.(*PhysicalProjection); ok { - // when current projection is an empty projection(schema pruned by column pruner), no need to reset child's schema - // TODO: avoid producing empty projection in column pruner. - if p.Schema().Len() != 0 { - childProj.SetSchema(p.Schema()) - } - // If any of the consecutive projection operators has the AvoidColumnEvaluator set to true, - // we need to set the AvoidColumnEvaluator of the remaining projection to true. - if proj.AvoidColumnEvaluator { - childProj.AvoidColumnEvaluator = true - } - } - for i, col := range p.Schema().Columns { - if p.SCtx().GetSessionVars().StmtCtx.ColRefFromUpdatePlan.Has(int(col.UniqueID)) && !child.Schema().Columns[i].Equal(nil, col) { - return p - } - } - return child -} - -// eliminatePhysicalProjection should be called after physical optimization to -// eliminate the redundant projection left after logical projection elimination. -func eliminatePhysicalProjection(p base.PhysicalPlan) base.PhysicalPlan { - failpoint.Inject("DisableProjectionPostOptimization", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(p) - } - }) - - newRoot := doPhysicalProjectionElimination(p) - return newRoot -} - -// For select, insert, delete list -// The projection eliminate in logical optimize will optimize the projection under the projection, window, agg -// The projection eliminate in post optimize will optimize other projection - -// ProjectionEliminator is for update stmt -// The projection eliminate in logical optimize has been forbidden. -// The projection eliminate in post optimize will optimize the projection under the projection, window, agg (the condition is same as logical optimize) -type ProjectionEliminator struct { -} - -// Optimize implements the logicalOptRule interface. 
-func (pe *ProjectionEliminator) Optimize(_ context.Context, lp base.LogicalPlan, opt *optimizetrace.LogicalOptimizeOp) (base.LogicalPlan, bool, error) { - planChanged := false - root := pe.eliminate(lp, make(map[string]*expression.Column), false, opt) - return root, planChanged, nil -} - -// eliminate eliminates the redundant projection in a logical plan. -func (pe *ProjectionEliminator) eliminate(p base.LogicalPlan, replace map[string]*expression.Column, canEliminate bool, opt *optimizetrace.LogicalOptimizeOp) base.LogicalPlan { - // LogicalCTE's logical optimization is independent. - if _, ok := p.(*LogicalCTE); ok { - return p - } - proj, isProj := p.(*logicalop.LogicalProjection) - childFlag := canEliminate - if _, isUnion := p.(*LogicalUnionAll); isUnion { - childFlag = false - } else if _, isAgg := p.(*LogicalAggregation); isAgg || isProj { - childFlag = true - } else if _, isWindow := p.(*logicalop.LogicalWindow); isWindow { - childFlag = true - } - for i, child := range p.Children() { - p.Children()[i] = pe.eliminate(child, replace, childFlag, opt) - } - - // replace logical plan schema - switch x := p.(type) { - case *LogicalJoin: - x.SetSchema(buildLogicalJoinSchema(x.JoinType, x)) - case *LogicalApply: - x.SetSchema(buildLogicalJoinSchema(x.JoinType, x)) - default: - for _, dst := range p.Schema().Columns { - ruleutil.ResolveColumnAndReplace(dst, replace) - } - } - // replace all of exprs in logical plan - p.ReplaceExprColumns(replace) - - // eliminate duplicate projection: projection with child projection - if isProj { - if child, ok := p.Children()[0].(*logicalop.LogicalProjection); ok && !expression.ExprsHasSideEffects(child.Exprs) { - ctx := p.SCtx() - for i := range proj.Exprs { - proj.Exprs[i] = ReplaceColumnOfExpr(proj.Exprs[i], child, child.Schema()) - foldedExpr := expression.FoldConstant(ctx.GetExprCtx(), proj.Exprs[i]) - // the folded expr should have the same null flag with the original expr, especially for the projection under union, so forcing it here. - foldedExpr.GetType(ctx.GetExprCtx().GetEvalCtx()).SetFlag((foldedExpr.GetType(ctx.GetExprCtx().GetEvalCtx()).GetFlag() & ^mysql.NotNullFlag) | (proj.Exprs[i].GetType(ctx.GetExprCtx().GetEvalCtx()).GetFlag() & mysql.NotNullFlag)) - proj.Exprs[i] = foldedExpr - } - p.Children()[0] = child.Children()[0] - appendDupProjEliminateTraceStep(proj, child, opt) - } - } - - if !(isProj && canEliminate && canProjectionBeEliminatedLoose(proj)) { - return p - } - exprs := proj.Exprs - for i, col := range proj.Schema().Columns { - replace[string(col.HashCode())] = exprs[i].(*expression.Column) - } - appendProjEliminateTraceStep(proj, opt) - return p.Children()[0] -} - -// ReplaceColumnOfExpr replaces column of expression by another LogicalProjection. -func ReplaceColumnOfExpr(expr expression.Expression, proj *logicalop.LogicalProjection, schema *expression.Schema) expression.Expression { - switch v := expr.(type) { - case *expression.Column: - idx := schema.ColumnIndex(v) - if idx != -1 && idx < len(proj.Exprs) { - return proj.Exprs[idx] - } - case *expression.ScalarFunction: - for i := range v.GetArgs() { - v.GetArgs()[i] = ReplaceColumnOfExpr(v.GetArgs()[i], proj, schema) - } - } - return expr -} - -// Name implements the logicalOptRule.<1st> interface. 
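[Editor's illustrative sketch, not part of this patch: the hunks in this patch restore the `failpoint.Inject` form shown above in `eliminatePhysicalProjection`, which the failpoint rewriter had expanded into the `failpoint.Eval`/`_curpkg_` form and stashed in the deleted `*__failpoint_stash__` copies. A test would typically activate such a failpoint with `failpoint.Enable`; the test name below is made up, and the failpoint path assumes TiDB's usual convention of package import path plus the name passed to `failpoint.Inject`.]

package core_test

import (
	"testing"

	"github.com/pingcap/failpoint"
)

func TestDisableProjectionPostOptimizationFailpoint(t *testing.T) {
	// Assumed path: package import path + failpoint name used in Inject above.
	fp := "github.com/pingcap/tidb/pkg/planner/core/DisableProjectionPostOptimization"
	// "return(true)" makes val.(bool) evaluate to true inside failpoint.Inject,
	// so eliminatePhysicalProjection returns its input plan unchanged.
	if err := failpoint.Enable(fp, `return(true)`); err != nil {
		t.Fatal(err)
	}
	defer func() {
		if err := failpoint.Disable(fp); err != nil {
			t.Error(err)
		}
	}()
	// ... build and optimize a plan here; redundant projections would be kept ...
}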
-func (*ProjectionEliminator) Name() string { - return "projection_eliminate" -} - -func appendDupProjEliminateTraceStep(parent, child *logicalop.LogicalProjection, opt *optimizetrace.LogicalOptimizeOp) { - ectx := parent.SCtx().GetExprCtx().GetEvalCtx() - action := func() string { - buffer := bytes.NewBufferString( - fmt.Sprintf("%v_%v is eliminated, %v_%v's expressions changed into[", child.TP(), child.ID(), parent.TP(), parent.ID())) - for i, expr := range parent.Exprs { - if i > 0 { - buffer.WriteString(",") - } - buffer.WriteString(expr.StringWithCtx(ectx, perrors.RedactLogDisable)) - } - buffer.WriteString("]") - return buffer.String() - } - reason := func() string { - return fmt.Sprintf("%v_%v's child %v_%v is redundant", parent.TP(), parent.ID(), child.TP(), child.ID()) - } - opt.AppendStepToCurrent(child.ID(), child.TP(), reason, action) -} - -func appendProjEliminateTraceStep(proj *logicalop.LogicalProjection, opt *optimizetrace.LogicalOptimizeOp) { - reason := func() string { - return fmt.Sprintf("%v_%v's Exprs are all Columns", proj.TP(), proj.ID()) - } - action := func() string { - return fmt.Sprintf("%v_%v is eliminated", proj.TP(), proj.ID()) - } - opt.AppendStepToCurrent(proj.ID(), proj.TP(), reason, action) -} diff --git a/pkg/planner/core/rule_inject_extra_projection.go b/pkg/planner/core/rule_inject_extra_projection.go index 8335516747c85..e86c7db13fdd9 100644 --- a/pkg/planner/core/rule_inject_extra_projection.go +++ b/pkg/planner/core/rule_inject_extra_projection.go @@ -36,11 +36,11 @@ import ( // 2. TiDB can be used as a coprocessor, when a plan tree been pushed down to // TiDB, we need to inject extra projections for the plan tree as well. func InjectExtraProjection(plan base.PhysicalPlan) base.PhysicalPlan { - if val, _err_ := failpoint.Eval(_curpkg_("DisableProjectionPostOptimization")); _err_ == nil { + failpoint.Inject("DisableProjectionPostOptimization", func(val failpoint.Value) { if val.(bool) { - return plan + failpoint.Return(plan) } - } + }) return NewProjInjector().inject(plan) } diff --git a/pkg/planner/core/rule_inject_extra_projection.go__failpoint_stash__ b/pkg/planner/core/rule_inject_extra_projection.go__failpoint_stash__ deleted file mode 100644 index e86c7db13fdd9..0000000000000 --- a/pkg/planner/core/rule_inject_extra_projection.go__failpoint_stash__ +++ /dev/null @@ -1,352 +0,0 @@ -// Copyright 2018 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package core - -import ( - "slices" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/expression/aggregation" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/planner/util/coreusage" -) - -// InjectExtraProjection is used to extract the expressions of specific -// operators into a physical Projection operator and inject the Projection below -// the operators. 
Thus we can accelerate the expression evaluation by eager -// evaluation. -// This function will be called in two situations: -// 1. In postOptimize. -// 2. TiDB can be used as a coprocessor, when a plan tree been pushed down to -// TiDB, we need to inject extra projections for the plan tree as well. -func InjectExtraProjection(plan base.PhysicalPlan) base.PhysicalPlan { - failpoint.Inject("DisableProjectionPostOptimization", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(plan) - } - }) - - return NewProjInjector().inject(plan) -} - -type projInjector struct { -} - -// NewProjInjector builds a projInjector. -func NewProjInjector() *projInjector { - return &projInjector{} -} - -func (pe *projInjector) inject(plan base.PhysicalPlan) base.PhysicalPlan { - for i, child := range plan.Children() { - plan.Children()[i] = pe.inject(child) - } - - if tr, ok := plan.(*PhysicalTableReader); ok && tr.StoreType == kv.TiFlash { - tr.tablePlan = pe.inject(tr.tablePlan) - tr.TablePlans = flattenPushDownPlan(tr.tablePlan) - } - - switch p := plan.(type) { - case *PhysicalHashAgg: - plan = InjectProjBelowAgg(plan, p.AggFuncs, p.GroupByItems) - case *PhysicalStreamAgg: - plan = InjectProjBelowAgg(plan, p.AggFuncs, p.GroupByItems) - case *PhysicalSort: - plan = InjectProjBelowSort(p, p.ByItems) - case *PhysicalTopN: - plan = InjectProjBelowSort(p, p.ByItems) - case *NominalSort: - plan = TurnNominalSortIntoProj(p, p.OnlyColumn, p.ByItems) - case *PhysicalUnionAll: - plan = injectProjBelowUnion(p) - } - return plan -} - -func injectProjBelowUnion(un *PhysicalUnionAll) *PhysicalUnionAll { - if !un.mpp { - return un - } - for i, ch := range un.children { - exprs := make([]expression.Expression, len(ch.Schema().Columns)) - needChange := false - for i, dstCol := range un.schema.Columns { - dstType := dstCol.RetType - srcCol := ch.Schema().Columns[i] - srcCol.Index = i - srcType := srcCol.RetType - if !srcType.Equal(dstType) || !(mysql.HasNotNullFlag(dstType.GetFlag()) == mysql.HasNotNullFlag(srcType.GetFlag())) { - exprs[i] = expression.BuildCastFunction4Union(un.SCtx().GetExprCtx(), srcCol, dstType) - needChange = true - } else { - exprs[i] = srcCol - } - } - if needChange { - proj := PhysicalProjection{ - Exprs: exprs, - }.Init(un.SCtx(), ch.StatsInfo(), 0) - proj.SetSchema(un.schema.Clone()) - proj.SetChildren(ch) - un.children[i] = proj - } - } - return un -} - -// InjectProjBelowAgg injects a ProjOperator below AggOperator. So that All -// scalar functions in aggregation may speed up by vectorized evaluation in -// the `proj`. If all the args of `aggFuncs`, and all the item of `groupByItems` -// are columns or constants, we do not need to build the `proj`. 
-func InjectProjBelowAgg(aggPlan base.PhysicalPlan, aggFuncs []*aggregation.AggFuncDesc, groupByItems []expression.Expression) base.PhysicalPlan { - hasScalarFunc := false - exprCtx := aggPlan.SCtx().GetExprCtx() - coreusage.WrapCastForAggFuncs(exprCtx, aggFuncs) - for i := 0; !hasScalarFunc && i < len(aggFuncs); i++ { - for _, arg := range aggFuncs[i].Args { - _, isScalarFunc := arg.(*expression.ScalarFunction) - hasScalarFunc = hasScalarFunc || isScalarFunc - } - for _, byItem := range aggFuncs[i].OrderByItems { - _, isScalarFunc := byItem.Expr.(*expression.ScalarFunction) - hasScalarFunc = hasScalarFunc || isScalarFunc - } - } - for i := 0; !hasScalarFunc && i < len(groupByItems); i++ { - _, isScalarFunc := groupByItems[i].(*expression.ScalarFunction) - hasScalarFunc = hasScalarFunc || isScalarFunc - } - if !hasScalarFunc { - return aggPlan - } - - projSchemaCols := make([]*expression.Column, 0, len(aggFuncs)+len(groupByItems)) - projExprs := make([]expression.Expression, 0, cap(projSchemaCols)) - cursor := 0 - - ectx := exprCtx.GetEvalCtx() - for _, f := range aggFuncs { - for i, arg := range f.Args { - if _, isCnst := arg.(*expression.Constant); isCnst { - continue - } - projExprs = append(projExprs, arg) - newArg := &expression.Column{ - UniqueID: aggPlan.SCtx().GetSessionVars().AllocPlanColumnID(), - RetType: arg.GetType(ectx), - Index: cursor, - } - projSchemaCols = append(projSchemaCols, newArg) - f.Args[i] = newArg - cursor++ - } - for _, byItem := range f.OrderByItems { - bi := byItem.Expr - if _, isCnst := bi.(*expression.Constant); isCnst { - continue - } - idx := slices.IndexFunc(projExprs, func(a expression.Expression) bool { - return a.Equal(ectx, bi) - }) - if idx < 0 { - projExprs = append(projExprs, bi) - newArg := &expression.Column{ - UniqueID: aggPlan.SCtx().GetSessionVars().AllocPlanColumnID(), - RetType: bi.GetType(ectx), - Index: cursor, - } - projSchemaCols = append(projSchemaCols, newArg) - byItem.Expr = newArg - cursor++ - } else { - byItem.Expr = projSchemaCols[idx] - } - } - } - - for i, item := range groupByItems { - it := item - if _, isCnst := it.(*expression.Constant); isCnst { - continue - } - idx := slices.IndexFunc(projExprs, func(a expression.Expression) bool { - return a.Equal(ectx, it) - }) - if idx < 0 { - projExprs = append(projExprs, it) - newArg := &expression.Column{ - UniqueID: aggPlan.SCtx().GetSessionVars().AllocPlanColumnID(), - RetType: item.GetType(ectx), - Index: cursor, - } - projSchemaCols = append(projSchemaCols, newArg) - groupByItems[i] = newArg - cursor++ - } else { - groupByItems[i] = projSchemaCols[idx] - } - } - - child := aggPlan.Children()[0] - prop := aggPlan.GetChildReqProps(0).CloneEssentialFields() - proj := PhysicalProjection{ - Exprs: projExprs, - AvoidColumnEvaluator: false, - }.Init(aggPlan.SCtx(), child.StatsInfo().ScaleByExpectCnt(prop.ExpectedCnt), aggPlan.QueryBlockOffset(), prop) - proj.SetSchema(expression.NewSchema(projSchemaCols...)) - proj.SetChildren(child) - - aggPlan.SetChildren(proj) - return aggPlan -} - -// InjectProjBelowSort extracts the ScalarFunctions of `orderByItems` into a -// PhysicalProjection and injects it below PhysicalTopN/PhysicalSort. The schema -// of PhysicalSort and PhysicalTopN are the same as the schema of their -// children. When a projection is injected as the child of PhysicalSort and -// PhysicalTopN, some extra columns will be added into the schema of the -// Projection, thus we need to add another Projection upon them to prune the -// redundant columns. 
-func InjectProjBelowSort(p base.PhysicalPlan, orderByItems []*util.ByItems) base.PhysicalPlan { - hasScalarFunc, numOrderByItems := false, len(orderByItems) - for i := 0; !hasScalarFunc && i < numOrderByItems; i++ { - _, isScalarFunc := orderByItems[i].Expr.(*expression.ScalarFunction) - hasScalarFunc = hasScalarFunc || isScalarFunc - } - if !hasScalarFunc { - return p - } - - topProjExprs := make([]expression.Expression, 0, p.Schema().Len()) - for i := range p.Schema().Columns { - col := p.Schema().Columns[i].Clone().(*expression.Column) - col.Index = i - topProjExprs = append(topProjExprs, col) - } - topProj := PhysicalProjection{ - Exprs: topProjExprs, - AvoidColumnEvaluator: false, - }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset(), nil) - topProj.SetSchema(p.Schema().Clone()) - topProj.SetChildren(p) - - childPlan := p.Children()[0] - bottomProjSchemaCols := make([]*expression.Column, 0, len(childPlan.Schema().Columns)+numOrderByItems) - bottomProjExprs := make([]expression.Expression, 0, len(childPlan.Schema().Columns)+numOrderByItems) - for _, col := range childPlan.Schema().Columns { - newCol := col.Clone().(*expression.Column) - newCol.Index = childPlan.Schema().ColumnIndex(newCol) - bottomProjSchemaCols = append(bottomProjSchemaCols, newCol) - bottomProjExprs = append(bottomProjExprs, newCol) - } - - for _, item := range orderByItems { - itemExpr := item.Expr - if _, isScalarFunc := itemExpr.(*expression.ScalarFunction); !isScalarFunc { - continue - } - bottomProjExprs = append(bottomProjExprs, itemExpr) - newArg := &expression.Column{ - UniqueID: p.SCtx().GetSessionVars().AllocPlanColumnID(), - RetType: itemExpr.GetType(p.SCtx().GetExprCtx().GetEvalCtx()), - Index: len(bottomProjSchemaCols), - } - bottomProjSchemaCols = append(bottomProjSchemaCols, newArg) - item.Expr = newArg - } - - childProp := p.GetChildReqProps(0).CloneEssentialFields() - bottomProj := PhysicalProjection{ - Exprs: bottomProjExprs, - AvoidColumnEvaluator: false, - }.Init(p.SCtx(), childPlan.StatsInfo().ScaleByExpectCnt(childProp.ExpectedCnt), p.QueryBlockOffset(), childProp) - bottomProj.SetSchema(expression.NewSchema(bottomProjSchemaCols...)) - bottomProj.SetChildren(childPlan) - p.SetChildren(bottomProj) - - if origChildProj, isChildProj := childPlan.(*PhysicalProjection); isChildProj { - refine4NeighbourProj(bottomProj, origChildProj) - } - refine4NeighbourProj(topProj, bottomProj) - - return topProj -} - -// TurnNominalSortIntoProj will turn nominal sort into two projections. This is to check if the scalar functions will -// overflow. 
-func TurnNominalSortIntoProj(p base.PhysicalPlan, onlyColumn bool, orderByItems []*util.ByItems) base.PhysicalPlan { - if onlyColumn { - return p.Children()[0] - } - - numOrderByItems := len(orderByItems) - childPlan := p.Children()[0] - - bottomProjSchemaCols := make([]*expression.Column, 0, len(childPlan.Schema().Columns)+numOrderByItems) - bottomProjExprs := make([]expression.Expression, 0, len(childPlan.Schema().Columns)+numOrderByItems) - for _, col := range childPlan.Schema().Columns { - newCol := col.Clone().(*expression.Column) - newCol.Index = childPlan.Schema().ColumnIndex(newCol) - bottomProjSchemaCols = append(bottomProjSchemaCols, newCol) - bottomProjExprs = append(bottomProjExprs, newCol) - } - - for _, item := range orderByItems { - itemExpr := item.Expr - if _, isScalarFunc := itemExpr.(*expression.ScalarFunction); !isScalarFunc { - continue - } - bottomProjExprs = append(bottomProjExprs, itemExpr) - newArg := &expression.Column{ - UniqueID: p.SCtx().GetSessionVars().AllocPlanColumnID(), - RetType: itemExpr.GetType(p.SCtx().GetExprCtx().GetEvalCtx()), - Index: len(bottomProjSchemaCols), - } - bottomProjSchemaCols = append(bottomProjSchemaCols, newArg) - } - - childProp := p.GetChildReqProps(0).CloneEssentialFields() - bottomProj := PhysicalProjection{ - Exprs: bottomProjExprs, - AvoidColumnEvaluator: false, - }.Init(p.SCtx(), childPlan.StatsInfo().ScaleByExpectCnt(childProp.ExpectedCnt), p.QueryBlockOffset(), childProp) - bottomProj.SetSchema(expression.NewSchema(bottomProjSchemaCols...)) - bottomProj.SetChildren(childPlan) - - topProjExprs := make([]expression.Expression, 0, childPlan.Schema().Len()) - for i := range childPlan.Schema().Columns { - col := childPlan.Schema().Columns[i].Clone().(*expression.Column) - col.Index = i - topProjExprs = append(topProjExprs, col) - } - topProj := PhysicalProjection{ - Exprs: topProjExprs, - AvoidColumnEvaluator: false, - }.Init(p.SCtx(), childPlan.StatsInfo().ScaleByExpectCnt(childProp.ExpectedCnt), p.QueryBlockOffset(), childProp) - topProj.SetSchema(childPlan.Schema().Clone()) - topProj.SetChildren(bottomProj) - - if origChildProj, isChildProj := childPlan.(*PhysicalProjection); isChildProj { - refine4NeighbourProj(bottomProj, origChildProj) - } - refine4NeighbourProj(topProj, bottomProj) - - return topProj -} diff --git a/pkg/planner/core/task.go b/pkg/planner/core/task.go index 493c4e6b68145..b926f5d9bb9fc 100644 --- a/pkg/planner/core/task.go +++ b/pkg/planner/core/task.go @@ -2307,11 +2307,11 @@ func (p *PhysicalWindow) attach2TaskForMPP(mpp *MppTask) base.Task { columns := p.Schema().Clone().Columns[len(p.Schema().Columns)-len(p.WindowFuncDescs):] p.schema = expression.MergeSchema(mpp.Plan().Schema(), expression.NewSchema(columns...)) - if _, _err_ := failpoint.Eval(_curpkg_("CheckMPPWindowSchemaLength")); _err_ == nil { + failpoint.Inject("CheckMPPWindowSchemaLength", func() { if len(p.Schema().Columns) != len(mpp.Plan().Schema().Columns)+len(p.WindowFuncDescs) { panic("mpp physical window has incorrect schema length") } - } + }) return attachPlan2Task(p, mpp) } diff --git a/pkg/planner/core/task.go__failpoint_stash__ b/pkg/planner/core/task.go__failpoint_stash__ deleted file mode 100644 index b926f5d9bb9fc..0000000000000 --- a/pkg/planner/core/task.go__failpoint_stash__ +++ /dev/null @@ -1,2473 +0,0 @@ -// Copyright 2017 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package core - -import ( - "math" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/expression/aggregation" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/charset" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/planner/cardinality" - "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/planner/core/cost" - "github.com/pingcap/tidb/pkg/planner/core/operator/baseimpl" - "github.com/pingcap/tidb/pkg/planner/property" - "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/collate" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/paging" - "github.com/pingcap/tidb/pkg/util/plancodec" - "go.uber.org/zap" -) - -func attachPlan2Task(p base.PhysicalPlan, t base.Task) base.Task { - switch v := t.(type) { - case *CopTask: - if v.indexPlanFinished { - p.SetChildren(v.tablePlan) - v.tablePlan = p - } else { - p.SetChildren(v.indexPlan) - v.indexPlan = p - } - case *RootTask: - p.SetChildren(v.GetPlan()) - v.SetPlan(p) - case *MppTask: - p.SetChildren(v.p) - v.p = p - } - return t -} - -// finishIndexPlan means we no longer add plan to index plan, and compute the network cost for it. -func (t *CopTask) finishIndexPlan() { - if t.indexPlanFinished { - return - } - t.indexPlanFinished = true - // index merge case is specially handled for now. - // We need a elegant way to solve the stats of index merge in this case. - if t.tablePlan != nil && t.indexPlan != nil { - ts := t.tablePlan.(*PhysicalTableScan) - originStats := ts.StatsInfo() - ts.SetStats(t.indexPlan.StatsInfo()) - if originStats != nil { - // keep the original stats version - ts.StatsInfo().StatsVersion = originStats.StatsVersion - } - } -} - -func (t *CopTask) getStoreType() kv.StoreType { - if t.tablePlan == nil { - return kv.TiKV - } - tp := t.tablePlan - for len(tp.Children()) > 0 { - if len(tp.Children()) > 1 { - return kv.TiFlash - } - tp = tp.Children()[0] - } - if ts, ok := tp.(*PhysicalTableScan); ok { - return ts.StoreType - } - return kv.TiKV -} - -// Attach2Task implements PhysicalPlan interface. -func (p *basePhysicalPlan) Attach2Task(tasks ...base.Task) base.Task { - t := tasks[0].ConvertToRootTask(p.SCtx()) - return attachPlan2Task(p.self, t) -} - -// Attach2Task implements PhysicalPlan interface. -func (p *PhysicalUnionScan) Attach2Task(tasks ...base.Task) base.Task { - // We need to pull the projection under unionScan upon unionScan. - // Since the projection only prunes columns, it's ok the put it upon unionScan. - if sel, ok := tasks[0].Plan().(*PhysicalSelection); ok { - if pj, ok := sel.children[0].(*PhysicalProjection); ok { - // Convert unionScan->selection->projection to projection->unionScan->selection. - sel.SetChildren(pj.children...) 
- p.SetChildren(sel) - p.SetStats(tasks[0].Plan().StatsInfo()) - rt, _ := tasks[0].(*RootTask) - rt.SetPlan(p) - pj.SetChildren(p) - return pj.Attach2Task(tasks...) - } - } - if pj, ok := tasks[0].Plan().(*PhysicalProjection); ok { - // Convert unionScan->projection to projection->unionScan, because unionScan can't handle projection as its children. - p.SetChildren(pj.children...) - p.SetStats(tasks[0].Plan().StatsInfo()) - rt, _ := tasks[0].(*RootTask) - rt.SetPlan(pj.children[0]) - pj.SetChildren(p) - return pj.Attach2Task(p.basePhysicalPlan.Attach2Task(tasks...)) - } - p.SetStats(tasks[0].Plan().StatsInfo()) - return p.basePhysicalPlan.Attach2Task(tasks...) -} - -// Attach2Task implements PhysicalPlan interface. -func (p *PhysicalApply) Attach2Task(tasks ...base.Task) base.Task { - lTask := tasks[0].ConvertToRootTask(p.SCtx()) - rTask := tasks[1].ConvertToRootTask(p.SCtx()) - p.SetChildren(lTask.Plan(), rTask.Plan()) - p.schema = BuildPhysicalJoinSchema(p.JoinType, p) - t := &RootTask{} - t.SetPlan(p) - return t -} - -// Attach2Task implements PhysicalPlan interface. -func (p *PhysicalIndexMergeJoin) Attach2Task(tasks ...base.Task) base.Task { - outerTask := tasks[1-p.InnerChildIdx].ConvertToRootTask(p.SCtx()) - if p.InnerChildIdx == 1 { - p.SetChildren(outerTask.Plan(), p.innerPlan) - } else { - p.SetChildren(p.innerPlan, outerTask.Plan()) - } - t := &RootTask{} - t.SetPlan(p) - return t -} - -// Attach2Task implements PhysicalPlan interface. -func (p *PhysicalIndexHashJoin) Attach2Task(tasks ...base.Task) base.Task { - outerTask := tasks[1-p.InnerChildIdx].ConvertToRootTask(p.SCtx()) - if p.InnerChildIdx == 1 { - p.SetChildren(outerTask.Plan(), p.innerPlan) - } else { - p.SetChildren(p.innerPlan, outerTask.Plan()) - } - t := &RootTask{} - t.SetPlan(p) - return t -} - -// Attach2Task implements PhysicalPlan interface. -func (p *PhysicalIndexJoin) Attach2Task(tasks ...base.Task) base.Task { - outerTask := tasks[1-p.InnerChildIdx].ConvertToRootTask(p.SCtx()) - if p.InnerChildIdx == 1 { - p.SetChildren(outerTask.Plan(), p.innerPlan) - } else { - p.SetChildren(p.innerPlan, outerTask.Plan()) - } - t := &RootTask{} - t.SetPlan(p) - return t -} - -// RowSize for cost model ver2 is simplified, always use this function to calculate row size. -func getAvgRowSize(stats *property.StatsInfo, cols []*expression.Column) (size float64) { - if stats.HistColl != nil { - size = cardinality.GetAvgRowSizeDataInDiskByRows(stats.HistColl, cols) - } else { - // Estimate using just the type info. - for _, col := range cols { - size += float64(chunk.EstimateTypeWidth(col.GetStaticType())) - } - } - return -} - -// Attach2Task implements PhysicalPlan interface. -func (p *PhysicalHashJoin) Attach2Task(tasks ...base.Task) base.Task { - if p.storeTp == kv.TiFlash { - return p.attach2TaskForTiFlash(tasks...) 
- } - lTask := tasks[0].ConvertToRootTask(p.SCtx()) - rTask := tasks[1].ConvertToRootTask(p.SCtx()) - p.SetChildren(lTask.Plan(), rTask.Plan()) - task := &RootTask{} - task.SetPlan(p) - return task -} - -// TiDB only require that the types fall into the same catalog but TiFlash require the type to be exactly the same, so -// need to check if the conversion is a must -func needConvert(tp *types.FieldType, rtp *types.FieldType) bool { - // all the string type are mapped to the same type in TiFlash, so - // do not need convert for string types - if types.IsString(tp.GetType()) && types.IsString(rtp.GetType()) { - return false - } - if tp.GetType() != rtp.GetType() { - return true - } - if tp.GetType() != mysql.TypeNewDecimal { - return false - } - if tp.GetDecimal() != rtp.GetDecimal() { - return true - } - // for decimal type, TiFlash have 4 different impl based on the required precision - if tp.GetFlen() >= 0 && tp.GetFlen() <= 9 && rtp.GetFlen() >= 0 && rtp.GetFlen() <= 9 { - return false - } - if tp.GetFlen() > 9 && tp.GetFlen() <= 18 && rtp.GetFlen() > 9 && rtp.GetFlen() <= 18 { - return false - } - if tp.GetFlen() > 18 && tp.GetFlen() <= 38 && rtp.GetFlen() > 18 && rtp.GetFlen() <= 38 { - return false - } - if tp.GetFlen() > 38 && tp.GetFlen() <= 65 && rtp.GetFlen() > 38 && rtp.GetFlen() <= 65 { - return false - } - return true -} - -func negotiateCommonType(lType, rType *types.FieldType) (*types.FieldType, bool, bool) { - commonType := types.AggFieldType([]*types.FieldType{lType, rType}) - if commonType.GetType() == mysql.TypeNewDecimal { - lExtend := 0 - rExtend := 0 - cDec := rType.GetDecimal() - if lType.GetDecimal() < rType.GetDecimal() { - lExtend = rType.GetDecimal() - lType.GetDecimal() - } else if lType.GetDecimal() > rType.GetDecimal() { - rExtend = lType.GetDecimal() - rType.GetDecimal() - cDec = lType.GetDecimal() - } - lLen, rLen := lType.GetFlen()+lExtend, rType.GetFlen()+rExtend - cLen := max(lLen, rLen) - commonType.SetDecimalUnderLimit(cDec) - commonType.SetFlenUnderLimit(cLen) - } else if needConvert(lType, commonType) || needConvert(rType, commonType) { - if mysql.IsIntegerType(commonType.GetType()) { - // If the target type is int, both TiFlash and Mysql only support cast to Int64 - // so we need to promote the type to Int64 - commonType.SetType(mysql.TypeLonglong) - commonType.SetFlen(mysql.MaxIntWidth) - } - } - return commonType, needConvert(lType, commonType), needConvert(rType, commonType) -} - -func getProj(ctx base.PlanContext, p base.PhysicalPlan) *PhysicalProjection { - proj := PhysicalProjection{ - Exprs: make([]expression.Expression, 0, len(p.Schema().Columns)), - }.Init(ctx, p.StatsInfo(), p.QueryBlockOffset()) - for _, col := range p.Schema().Columns { - proj.Exprs = append(proj.Exprs, col) - } - proj.SetSchema(p.Schema().Clone()) - proj.SetChildren(p) - return proj -} - -func appendExpr(p *PhysicalProjection, expr expression.Expression) *expression.Column { - p.Exprs = append(p.Exprs, expr) - - col := &expression.Column{ - UniqueID: p.SCtx().GetSessionVars().AllocPlanColumnID(), - RetType: expr.GetType(p.SCtx().GetExprCtx().GetEvalCtx()), - } - col.SetCoercibility(expr.Coercibility()) - p.schema.Append(col) - return col -} - -// TiFlash join require that partition key has exactly the same type, while TiDB only guarantee the partition key is the same catalog, -// so if the partition key type is not exactly the same, we need add a projection below the join or exchanger if exists. 
-func (p *PhysicalHashJoin) convertPartitionKeysIfNeed(lTask, rTask *MppTask) (*MppTask, *MppTask) { - lp := lTask.p - if _, ok := lp.(*PhysicalExchangeReceiver); ok { - lp = lp.Children()[0].Children()[0] - } - rp := rTask.p - if _, ok := rp.(*PhysicalExchangeReceiver); ok { - rp = rp.Children()[0].Children()[0] - } - // to mark if any partition key needs to convert - lMask := make([]bool, len(lTask.hashCols)) - rMask := make([]bool, len(rTask.hashCols)) - cTypes := make([]*types.FieldType, len(lTask.hashCols)) - lChanged := false - rChanged := false - for i := range lTask.hashCols { - lKey := lTask.hashCols[i] - rKey := rTask.hashCols[i] - cType, lConvert, rConvert := negotiateCommonType(lKey.Col.RetType, rKey.Col.RetType) - if lConvert { - lMask[i] = true - cTypes[i] = cType - lChanged = true - } - if rConvert { - rMask[i] = true - cTypes[i] = cType - rChanged = true - } - } - if !lChanged && !rChanged { - return lTask, rTask - } - var lProj, rProj *PhysicalProjection - if lChanged { - lProj = getProj(p.SCtx(), lp) - lp = lProj - } - if rChanged { - rProj = getProj(p.SCtx(), rp) - rp = rProj - } - - lPartKeys := make([]*property.MPPPartitionColumn, 0, len(rTask.hashCols)) - rPartKeys := make([]*property.MPPPartitionColumn, 0, len(lTask.hashCols)) - for i := range lTask.hashCols { - lKey := lTask.hashCols[i] - rKey := rTask.hashCols[i] - if lMask[i] { - cType := cTypes[i].Clone() - cType.SetFlag(lKey.Col.RetType.GetFlag()) - lCast := expression.BuildCastFunction(p.SCtx().GetExprCtx(), lKey.Col, cType) - lKey = &property.MPPPartitionColumn{Col: appendExpr(lProj, lCast), CollateID: lKey.CollateID} - } - if rMask[i] { - cType := cTypes[i].Clone() - cType.SetFlag(rKey.Col.RetType.GetFlag()) - rCast := expression.BuildCastFunction(p.SCtx().GetExprCtx(), rKey.Col, cType) - rKey = &property.MPPPartitionColumn{Col: appendExpr(rProj, rCast), CollateID: rKey.CollateID} - } - lPartKeys = append(lPartKeys, lKey) - rPartKeys = append(rPartKeys, rKey) - } - // if left or right child changes, we need to add enforcer. - if lChanged { - nlTask := lTask.Copy().(*MppTask) - nlTask.p = lProj - nlTask = nlTask.enforceExchanger(&property.PhysicalProperty{ - TaskTp: property.MppTaskType, - MPPPartitionTp: property.HashType, - MPPPartitionCols: lPartKeys, - }) - lTask = nlTask - } - if rChanged { - nrTask := rTask.Copy().(*MppTask) - nrTask.p = rProj - nrTask = nrTask.enforceExchanger(&property.PhysicalProperty{ - TaskTp: property.MppTaskType, - MPPPartitionTp: property.HashType, - MPPPartitionCols: rPartKeys, - }) - rTask = nrTask - } - return lTask, rTask -} - -func (p *PhysicalHashJoin) attach2TaskForMpp(tasks ...base.Task) base.Task { - lTask, lok := tasks[0].(*MppTask) - rTask, rok := tasks[1].(*MppTask) - if !lok || !rok { - return base.InvalidTask - } - if p.mppShuffleJoin { - // protection check is case of some bugs - if len(lTask.hashCols) != len(rTask.hashCols) || len(lTask.hashCols) == 0 { - return base.InvalidTask - } - lTask, rTask = p.convertPartitionKeysIfNeed(lTask, rTask) - } - p.SetChildren(lTask.Plan(), rTask.Plan()) - // outer task is the task that will pass its MPPPartitionType to the join result - // for broadcast inner join, it should be the non-broadcast side, since broadcast side is always the build side, so - // just use the probe side is ok. 
- // for hash inner join, both side is ok, by default, we use the probe side - // for outer join, it should always be the outer side of the join - // for semi join, it should be the left side(the same as left out join) - outerTaskIndex := 1 - p.InnerChildIdx - if p.JoinType != InnerJoin { - if p.JoinType == RightOuterJoin { - outerTaskIndex = 1 - } else { - outerTaskIndex = 0 - } - } - // can not use the task from tasks because it maybe updated. - outerTask := lTask - if outerTaskIndex == 1 { - outerTask = rTask - } - task := &MppTask{ - p: p, - partTp: outerTask.partTp, - hashCols: outerTask.hashCols, - } - // Current TiFlash doesn't support receive Join executors' schema info directly from TiDB. - // Instead, it calculates Join executors' output schema using algorithm like BuildPhysicalJoinSchema which - // produces full semantic schema. - // Thus, the column prune optimization achievements will be abandoned here. - // To avoid the performance issue, add a projection here above the Join operator to prune useless columns explicitly. - // TODO(hyb): transfer Join executors' schema to TiFlash through DagRequest, and use it directly in TiFlash. - defaultSchema := BuildPhysicalJoinSchema(p.JoinType, p) - hashColArray := make([]*expression.Column, 0, len(task.hashCols)) - // For task.hashCols, these columns may not be contained in pruned columns: - // select A.id from A join B on A.id = B.id; Suppose B is probe side, and it's hash inner join. - // After column prune, the output schema of A join B will be A.id only; while the task's hashCols will be B.id. - // To make matters worse, the hashCols may be used to check if extra cast projection needs to be added, then the newly - // added projection will expect B.id as input schema. So make sure hashCols are included in task.p's schema. 
- // TODO: planner should takes the hashCols attribute into consideration when perform column pruning; Or provide mechanism - // to constraint hashCols are always chosen inside Join's pruned schema - for _, hashCol := range task.hashCols { - hashColArray = append(hashColArray, hashCol.Col) - } - if p.schema.Len() < defaultSchema.Len() { - if p.schema.Len() > 0 { - proj := PhysicalProjection{ - Exprs: expression.Column2Exprs(p.schema.Columns), - }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset()) - - proj.SetSchema(p.Schema().Clone()) - for _, hashCol := range hashColArray { - if !proj.Schema().Contains(hashCol) && defaultSchema.Contains(hashCol) { - joinCol := defaultSchema.Columns[defaultSchema.ColumnIndex(hashCol)] - proj.Exprs = append(proj.Exprs, joinCol) - proj.Schema().Append(joinCol.Clone().(*expression.Column)) - } - } - attachPlan2Task(proj, task) - } else { - if len(hashColArray) == 0 { - constOne := expression.NewOne() - expr := make([]expression.Expression, 0, 1) - expr = append(expr, constOne) - proj := PhysicalProjection{ - Exprs: expr, - }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset()) - - proj.schema = expression.NewSchema(&expression.Column{ - UniqueID: proj.SCtx().GetSessionVars().AllocPlanColumnID(), - RetType: constOne.GetType(p.SCtx().GetExprCtx().GetEvalCtx()), - }) - attachPlan2Task(proj, task) - } else { - proj := PhysicalProjection{ - Exprs: make([]expression.Expression, 0, len(hashColArray)), - }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset()) - - clonedHashColArray := make([]*expression.Column, 0, len(task.hashCols)) - for _, hashCol := range hashColArray { - if defaultSchema.Contains(hashCol) { - joinCol := defaultSchema.Columns[defaultSchema.ColumnIndex(hashCol)] - proj.Exprs = append(proj.Exprs, joinCol) - clonedHashColArray = append(clonedHashColArray, joinCol.Clone().(*expression.Column)) - } - } - - proj.SetSchema(expression.NewSchema(clonedHashColArray...)) - attachPlan2Task(proj, task) - } - } - } - p.schema = defaultSchema - return task -} - -func (p *PhysicalHashJoin) attach2TaskForTiFlash(tasks ...base.Task) base.Task { - lTask, lok := tasks[0].(*CopTask) - rTask, rok := tasks[1].(*CopTask) - if !lok || !rok { - return p.attach2TaskForMpp(tasks...) - } - p.SetChildren(lTask.Plan(), rTask.Plan()) - p.schema = BuildPhysicalJoinSchema(p.JoinType, p) - if !lTask.indexPlanFinished { - lTask.finishIndexPlan() - } - if !rTask.indexPlanFinished { - rTask.finishIndexPlan() - } - - task := &CopTask{ - tblColHists: rTask.tblColHists, - indexPlanFinished: true, - tablePlan: p, - } - return task -} - -// Attach2Task implements PhysicalPlan interface. -func (p *PhysicalMergeJoin) Attach2Task(tasks ...base.Task) base.Task { - lTask := tasks[0].ConvertToRootTask(p.SCtx()) - rTask := tasks[1].ConvertToRootTask(p.SCtx()) - p.SetChildren(lTask.Plan(), rTask.Plan()) - t := &RootTask{} - t.SetPlan(p) - return t -} - -func buildIndexLookUpTask(ctx base.PlanContext, t *CopTask) *RootTask { - newTask := &RootTask{} - p := PhysicalIndexLookUpReader{ - tablePlan: t.tablePlan, - indexPlan: t.indexPlan, - ExtraHandleCol: t.extraHandleCol, - CommonHandleCols: t.commonHandleCols, - expectedCnt: t.expectCnt, - keepOrder: t.keepOrder, - }.Init(ctx, t.tablePlan.QueryBlockOffset()) - p.PlanPartInfo = t.physPlanPartInfo - setTableScanToTableRowIDScan(p.tablePlan) - p.SetStats(t.tablePlan.StatsInfo()) - // Do not inject the extra Projection even if t.needExtraProj is set, or the schema between the phase-1 agg and - // the final agg would be broken. 
Please reference comments for the similar logic in - // (*copTask).convertToRootTaskImpl() for the PhysicalTableReader case. - // We need to refactor these logics. - aggPushedDown := false - switch p.tablePlan.(type) { - case *PhysicalHashAgg, *PhysicalStreamAgg: - aggPushedDown = true - } - - if t.needExtraProj && !aggPushedDown { - schema := t.originSchema - proj := PhysicalProjection{Exprs: expression.Column2Exprs(schema.Columns)}.Init(ctx, p.StatsInfo(), t.tablePlan.QueryBlockOffset(), nil) - proj.SetSchema(schema) - proj.SetChildren(p) - newTask.SetPlan(proj) - } else { - newTask.SetPlan(p) - } - return newTask -} - -func extractRows(p base.PhysicalPlan) float64 { - f := float64(0) - for _, c := range p.Children() { - if len(c.Children()) != 0 { - f += extractRows(c) - } else { - f += c.StatsInfo().RowCount - } - } - return f -} - -// calcPagingCost calculates the cost for paging processing which may increase the seekCnt and reduce scanned rows. -func calcPagingCost(ctx base.PlanContext, indexPlan base.PhysicalPlan, expectCnt uint64) float64 { - sessVars := ctx.GetSessionVars() - indexRows := indexPlan.StatsCount() - sourceRows := extractRows(indexPlan) - // with paging, the scanned rows is always less than or equal to source rows. - if uint64(sourceRows) < expectCnt { - expectCnt = uint64(sourceRows) - } - seekCnt := paging.CalculateSeekCnt(expectCnt) - indexSelectivity := float64(1) - if sourceRows > indexRows { - indexSelectivity = indexRows / sourceRows - } - pagingCst := seekCnt*sessVars.GetSeekFactor(nil) + float64(expectCnt)*sessVars.GetCPUFactor() - pagingCst *= indexSelectivity - - // we want the diff between idxCst and pagingCst here, - // however, the idxCst does not contain seekFactor, so a seekFactor needs to be removed - return math.Max(pagingCst-sessVars.GetSeekFactor(nil), 0) -} - -func (t *CopTask) handleRootTaskConds(ctx base.PlanContext, newTask *RootTask) { - if len(t.rootTaskConds) > 0 { - selectivity, _, err := cardinality.Selectivity(ctx, t.tblColHists, t.rootTaskConds, nil) - if err != nil { - logutil.BgLogger().Debug("calculate selectivity failed, use selection factor", zap.Error(err)) - selectivity = cost.SelectionFactor - } - sel := PhysicalSelection{Conditions: t.rootTaskConds}.Init(ctx, newTask.GetPlan().StatsInfo().Scale(selectivity), newTask.GetPlan().QueryBlockOffset()) - sel.fromDataSource = true - sel.SetChildren(newTask.GetPlan()) - newTask.SetPlan(sel) - } -} - -// setTableScanToTableRowIDScan is to update the isChildOfIndexLookUp attribute of PhysicalTableScan child -func setTableScanToTableRowIDScan(p base.PhysicalPlan) { - if ts, ok := p.(*PhysicalTableScan); ok { - ts.SetIsChildOfIndexLookUp(true) - } else { - for _, child := range p.Children() { - setTableScanToTableRowIDScan(child) - } - } -} - -// Attach2Task attach limit to different cases. -// For Normal Index Lookup -// 1: attach the limit to table side or index side of normal index lookup cop task. (normal case, old code, no more -// explanation here) -// -// For Index Merge: -// 2: attach the limit to **table** side for index merge intersection case, cause intersection will invalidate the -// fetched limit+offset rows from each partial index plan, you can not decide how many you want in advance for partial -// index path, actually. After we sink limit to table side, we still need an upper root limit to control the real limit -// count admission. 
-// -// 3: attach the limit to **index** side for index merge union case, because each index plan will output the fetched -// limit+offset (* N path) rows, you still need an embedded pushedLimit inside index merge reader to cut it down. -// -// 4: attach the limit to the TOP of root index merge operator if there is some root condition exists for index merge -// intersection/union case. -func (p *PhysicalLimit) Attach2Task(tasks ...base.Task) base.Task { - t := tasks[0].Copy() - newPartitionBy := make([]property.SortItem, 0, len(p.GetPartitionBy())) - for _, expr := range p.GetPartitionBy() { - newPartitionBy = append(newPartitionBy, expr.Clone()) - } - - sunk := false - if cop, ok := t.(*CopTask); ok { - suspendLimitAboveTablePlan := func() { - newCount := p.Offset + p.Count - childProfile := cop.tablePlan.StatsInfo() - // but "regionNum" is unknown since the copTask can be a double read, so we ignore it now. - stats := util.DeriveLimitStats(childProfile, float64(newCount)) - pushedDownLimit := PhysicalLimit{PartitionBy: newPartitionBy, Count: newCount}.Init(p.SCtx(), stats, p.QueryBlockOffset()) - pushedDownLimit.SetChildren(cop.tablePlan) - cop.tablePlan = pushedDownLimit - // Don't use clone() so that Limit and its children share the same schema. Otherwise, the virtual generated column may not be resolved right. - pushedDownLimit.SetSchema(pushedDownLimit.children[0].Schema()) - t = cop.ConvertToRootTask(p.SCtx()) - } - if len(cop.idxMergePartPlans) == 0 { - // For double read which requires order being kept, the limit cannot be pushed down to the table side, - // because handles would be reordered before being sent to table scan. - if (!cop.keepOrder || !cop.indexPlanFinished || cop.indexPlan == nil) && len(cop.rootTaskConds) == 0 { - // When limit is pushed down, we should remove its offset. - newCount := p.Offset + p.Count - childProfile := cop.Plan().StatsInfo() - // Strictly speaking, for the row count of stats, we should multiply newCount with "regionNum", - // but "regionNum" is unknown since the copTask can be a double read, so we ignore it now. - stats := util.DeriveLimitStats(childProfile, float64(newCount)) - pushedDownLimit := PhysicalLimit{PartitionBy: newPartitionBy, Count: newCount}.Init(p.SCtx(), stats, p.QueryBlockOffset()) - cop = attachPlan2Task(pushedDownLimit, cop).(*CopTask) - // Don't use clone() so that Limit and its children share the same schema. Otherwise the virtual generated column may not be resolved right. - pushedDownLimit.SetSchema(pushedDownLimit.children[0].Schema()) - } - t = cop.ConvertToRootTask(p.SCtx()) - sunk = p.sinkIntoIndexLookUp(t) - } else if !cop.idxMergeIsIntersection { - // We only support push part of the order prop down to index merge build case. - if len(cop.rootTaskConds) == 0 { - // For double read which requires order being kept, the limit cannot be pushed down to the table side, - // because handles would be reordered before being sent to table scan. - if cop.indexPlanFinished && !cop.keepOrder { - // when the index plan is finished and index plan is not ordered, sink the limit to the index merge table side. - suspendLimitAboveTablePlan() - } else if !cop.indexPlanFinished { - // cop.indexPlanFinished = false indicates the table side is a pure table-scan, sink the limit to the index merge index side. 
- newCount := p.Offset + p.Count - limitChildren := make([]base.PhysicalPlan, 0, len(cop.idxMergePartPlans)) - for _, partialScan := range cop.idxMergePartPlans { - childProfile := partialScan.StatsInfo() - stats := util.DeriveLimitStats(childProfile, float64(newCount)) - pushedDownLimit := PhysicalLimit{PartitionBy: newPartitionBy, Count: newCount}.Init(p.SCtx(), stats, p.QueryBlockOffset()) - pushedDownLimit.SetChildren(partialScan) - pushedDownLimit.SetSchema(pushedDownLimit.children[0].Schema()) - limitChildren = append(limitChildren, pushedDownLimit) - } - cop.idxMergePartPlans = limitChildren - t = cop.ConvertToRootTask(p.SCtx()) - sunk = p.sinkIntoIndexMerge(t) - } else { - // when there are some limitations, just sink the limit upon the index merge reader. - t = cop.ConvertToRootTask(p.SCtx()) - sunk = p.sinkIntoIndexMerge(t) - } - } else { - // when there are some root conditions, just sink the limit upon the index merge reader. - t = cop.ConvertToRootTask(p.SCtx()) - sunk = p.sinkIntoIndexMerge(t) - } - } else if cop.idxMergeIsIntersection { - // In the index merge with intersection case, only the limit can be pushed down to the index merge table side. - // Note Difference: - // IndexMerge.PushedLimit is applied before table scan fetching, limiting the indexPartialPlan rows returned (it maybe ordered if orderBy items not empty) - // TableProbeSide sink limit is applied on the top of table plan, which will quickly shut down the both fetch-back and read-back process. - if len(cop.rootTaskConds) == 0 { - if cop.indexPlanFinished { - // indicates the table side is not a pure table-scan, so we could only append the limit upon the table plan. - suspendLimitAboveTablePlan() - } else { - t = cop.ConvertToRootTask(p.SCtx()) - sunk = p.sinkIntoIndexMerge(t) - } - } else { - // Otherwise, suspend the limit out of index merge reader. - t = cop.ConvertToRootTask(p.SCtx()) - sunk = p.sinkIntoIndexMerge(t) - } - } else { - // Whatever the remained case is, we directly convert to it to root task. - t = cop.ConvertToRootTask(p.SCtx()) - } - } else if mpp, ok := t.(*MppTask); ok { - newCount := p.Offset + p.Count - childProfile := mpp.Plan().StatsInfo() - stats := util.DeriveLimitStats(childProfile, float64(newCount)) - pushedDownLimit := PhysicalLimit{Count: newCount, PartitionBy: newPartitionBy}.Init(p.SCtx(), stats, p.QueryBlockOffset()) - mpp = attachPlan2Task(pushedDownLimit, mpp).(*MppTask) - pushedDownLimit.SetSchema(pushedDownLimit.children[0].Schema()) - t = mpp.ConvertToRootTask(p.SCtx()) - } - if sunk { - return t - } - // Skip limit with partition on the root. This is a derived topN and window function - // will take care of the filter. - if len(p.GetPartitionBy()) > 0 { - return t - } - return attachPlan2Task(p, t) -} - -func (p *PhysicalLimit) sinkIntoIndexLookUp(t base.Task) bool { - root := t.(*RootTask) - reader, isDoubleRead := root.GetPlan().(*PhysicalIndexLookUpReader) - proj, isProj := root.GetPlan().(*PhysicalProjection) - if !isDoubleRead && !isProj { - return false - } - if isProj { - reader, isDoubleRead = proj.Children()[0].(*PhysicalIndexLookUpReader) - if !isDoubleRead { - return false - } - } - - // We can sink Limit into IndexLookUpReader only if tablePlan contains no Selection. - ts, isTableScan := reader.tablePlan.(*PhysicalTableScan) - if !isTableScan { - return false - } - - // If this happens, some Projection Operator must be inlined into this Limit. 
(issues/14428) - // For example, if the original plan is `IndexLookUp(col1, col2) -> Limit(col1, col2) -> Project(col1)`, - // then after inlining the Project, it will be `IndexLookUp(col1, col2) -> Limit(col1)` here. - // If the Limit is sunk into the IndexLookUp, the IndexLookUp's schema needs to be updated as well, - // So we add an extra projection to solve the problem. - if p.Schema().Len() != reader.Schema().Len() { - extraProj := PhysicalProjection{ - Exprs: expression.Column2Exprs(p.schema.Columns), - }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset(), nil) - extraProj.SetSchema(p.schema) - // If the root.p is already a Projection. We left the optimization for the later Projection Elimination. - extraProj.SetChildren(root.GetPlan()) - root.SetPlan(extraProj) - } - - reader.PushedLimit = &PushedDownLimit{ - Offset: p.Offset, - Count: p.Count, - } - originStats := ts.StatsInfo() - ts.SetStats(p.StatsInfo()) - if originStats != nil { - // keep the original stats version - ts.StatsInfo().StatsVersion = originStats.StatsVersion - } - reader.SetStats(p.StatsInfo()) - if isProj { - proj.SetStats(p.StatsInfo()) - } - return true -} - -func (p *PhysicalLimit) sinkIntoIndexMerge(t base.Task) bool { - root := t.(*RootTask) - imReader, isIm := root.GetPlan().(*PhysicalIndexMergeReader) - proj, isProj := root.GetPlan().(*PhysicalProjection) - if !isIm && !isProj { - return false - } - if isProj { - imReader, isIm = proj.Children()[0].(*PhysicalIndexMergeReader) - if !isIm { - return false - } - } - ts, ok := imReader.tablePlan.(*PhysicalTableScan) - if !ok { - return false - } - imReader.PushedLimit = &PushedDownLimit{ - Count: p.Count, - Offset: p.Offset, - } - // since ts.statsInfo.rowcount may dramatically smaller than limit.statsInfo. - // like limit: rowcount=1 - // ts: rowcount=0.0025 - originStats := ts.StatsInfo() - if originStats != nil { - // keep the original stats version - ts.StatsInfo().StatsVersion = originStats.StatsVersion - if originStats.RowCount < p.StatsInfo().RowCount { - ts.StatsInfo().RowCount = originStats.RowCount - } - } - needProj := p.schema.Len() != root.GetPlan().Schema().Len() - if !needProj { - for i := 0; i < p.schema.Len(); i++ { - if !p.schema.Columns[i].EqualColumn(root.GetPlan().Schema().Columns[i]) { - needProj = true - break - } - } - } - if needProj { - extraProj := PhysicalProjection{ - Exprs: expression.Column2Exprs(p.schema.Columns), - }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset(), nil) - extraProj.SetSchema(p.schema) - // If the root.p is already a Projection. We left the optimization for the later Projection Elimination. - extraProj.SetChildren(root.GetPlan()) - root.SetPlan(extraProj) - } - return true -} - -// Attach2Task implements PhysicalPlan interface. -func (p *PhysicalSort) Attach2Task(tasks ...base.Task) base.Task { - t := tasks[0].Copy() - t = attachPlan2Task(p, t) - return t -} - -// Attach2Task implements PhysicalPlan interface. 
-func (p *NominalSort) Attach2Task(tasks ...base.Task) base.Task { - if p.OnlyColumn { - return tasks[0] - } - t := tasks[0].Copy() - t = attachPlan2Task(p, t) - return t -} - -func (p *PhysicalTopN) getPushedDownTopN(childPlan base.PhysicalPlan) *PhysicalTopN { - newByItems := make([]*util.ByItems, 0, len(p.ByItems)) - for _, expr := range p.ByItems { - newByItems = append(newByItems, expr.Clone()) - } - newPartitionBy := make([]property.SortItem, 0, len(p.GetPartitionBy())) - for _, expr := range p.GetPartitionBy() { - newPartitionBy = append(newPartitionBy, expr.Clone()) - } - newCount := p.Offset + p.Count - childProfile := childPlan.StatsInfo() - // Strictly speaking, for the row count of pushed down TopN, we should multiply newCount with "regionNum", - // but "regionNum" is unknown since the copTask can be a double read, so we ignore it now. - stats := util.DeriveLimitStats(childProfile, float64(newCount)) - topN := PhysicalTopN{ - ByItems: newByItems, - PartitionBy: newPartitionBy, - Count: newCount, - }.Init(p.SCtx(), stats, p.QueryBlockOffset(), p.GetChildReqProps(0)) - topN.SetChildren(childPlan) - return topN -} - -// canPushToIndexPlan checks if this TopN can be pushed to the index side of copTask. -// It can be pushed to the index side when all columns used by ByItems are available from the index side and there's no prefix index column. -func (*PhysicalTopN) canPushToIndexPlan(indexPlan base.PhysicalPlan, byItemCols []*expression.Column) bool { - // If we call canPushToIndexPlan and there's no index plan, we should go into the index merge case. - // Index merge case is specially handled for now. So we directly return false here. - // So we directly return false. - if indexPlan == nil { - return false - } - schema := indexPlan.Schema() - for _, col := range byItemCols { - pos := schema.ColumnIndex(col) - if pos == -1 { - return false - } - if schema.Columns[pos].IsPrefix { - return false - } - } - return true -} - -// canExpressionConvertedToPB checks whether each of the the expression in TopN can be converted to pb. -func (p *PhysicalTopN) canExpressionConvertedToPB(storeTp kv.StoreType) bool { - exprs := make([]expression.Expression, 0, len(p.ByItems)) - for _, item := range p.ByItems { - exprs = append(exprs, item.Expr) - } - return expression.CanExprsPushDown(GetPushDownCtx(p.SCtx()), exprs, storeTp) -} - -// containVirtualColumn checks whether TopN.ByItems contains virtual generated columns. -func (p *PhysicalTopN) containVirtualColumn(tCols []*expression.Column) bool { - tColSet := make(map[int64]struct{}, len(tCols)) - for _, tCol := range tCols { - if tCol.ID > 0 && tCol.VirtualExpr != nil { - tColSet[tCol.ID] = struct{}{} - } - } - for _, by := range p.ByItems { - cols := expression.ExtractColumns(by.Expr) - for _, col := range cols { - if _, ok := tColSet[col.ID]; ok { - // A column with ID > 0 indicates that the column can be resolved by data source. - return true - } - } - } - return false -} - -// canPushDownToTiKV checks whether this topN can be pushed down to TiKV. 
-func (p *PhysicalTopN) canPushDownToTiKV(copTask *CopTask) bool { - if !p.canExpressionConvertedToPB(kv.TiKV) { - return false - } - if len(copTask.rootTaskConds) != 0 { - return false - } - if !copTask.indexPlanFinished && len(copTask.idxMergePartPlans) > 0 { - for _, partialPlan := range copTask.idxMergePartPlans { - if p.containVirtualColumn(partialPlan.Schema().Columns) { - return false - } - } - } else if p.containVirtualColumn(copTask.Plan().Schema().Columns) { - return false - } - return true -} - -// canPushDownToTiFlash checks whether this topN can be pushed down to TiFlash. -func (p *PhysicalTopN) canPushDownToTiFlash(mppTask *MppTask) bool { - if !p.canExpressionConvertedToPB(kv.TiFlash) { - return false - } - if p.containVirtualColumn(mppTask.Plan().Schema().Columns) { - return false - } - return true -} - -// Attach2Task implements physical plan -func (p *PhysicalTopN) Attach2Task(tasks ...base.Task) base.Task { - t := tasks[0].Copy() - cols := make([]*expression.Column, 0, len(p.ByItems)) - for _, item := range p.ByItems { - cols = append(cols, expression.ExtractColumns(item.Expr)...) - } - needPushDown := len(cols) > 0 - if copTask, ok := t.(*CopTask); ok && needPushDown && p.canPushDownToTiKV(copTask) && len(copTask.rootTaskConds) == 0 { - // If all columns in topN are from index plan, we push it to index plan, otherwise we finish the index plan and - // push it to table plan. - var pushedDownTopN *PhysicalTopN - if !copTask.indexPlanFinished && p.canPushToIndexPlan(copTask.indexPlan, cols) { - pushedDownTopN = p.getPushedDownTopN(copTask.indexPlan) - copTask.indexPlan = pushedDownTopN - } else { - // It works for both normal index scan and index merge scan. - copTask.finishIndexPlan() - pushedDownTopN = p.getPushedDownTopN(copTask.tablePlan) - copTask.tablePlan = pushedDownTopN - } - } else if mppTask, ok := t.(*MppTask); ok && needPushDown && p.canPushDownToTiFlash(mppTask) { - pushedDownTopN := p.getPushedDownTopN(mppTask.p) - mppTask.p = pushedDownTopN - } - rootTask := t.ConvertToRootTask(p.SCtx()) - // Skip TopN with partition on the root. This is a derived topN and window function - // will take care of the filter. - if len(p.GetPartitionBy()) > 0 { - return t - } - return attachPlan2Task(p, rootTask) -} - -// Attach2Task implements the PhysicalPlan interface. -func (p *PhysicalExpand) Attach2Task(tasks ...base.Task) base.Task { - t := tasks[0].Copy() - // current expand can only be run in MPP TiFlash mode or Root Tidb mode. - // if expr inside could not be pushed down to tiFlash, it will error in converting to pb side. - if mpp, ok := t.(*MppTask); ok { - p.SetChildren(mpp.p) - mpp.p = p - return mpp - } - // For root task - // since expand should be in root side accordingly, convert to root task now. - root := t.ConvertToRootTask(p.SCtx()) - t = attachPlan2Task(p, root) - if root, ok := tasks[0].(*RootTask); ok && root.IsEmpty() { - t.(*RootTask).SetEmpty(true) - } - return t -} - -// Attach2Task implements PhysicalPlan interface. 
-func (p *PhysicalProjection) Attach2Task(tasks ...base.Task) base.Task { - t := tasks[0].Copy() - if cop, ok := t.(*CopTask); ok { - if (len(cop.rootTaskConds) == 0 && len(cop.idxMergePartPlans) == 0) && expression.CanExprsPushDown(GetPushDownCtx(p.SCtx()), p.Exprs, cop.getStoreType()) { - copTask := attachPlan2Task(p, cop) - return copTask - } - } else if mpp, ok := t.(*MppTask); ok { - if expression.CanExprsPushDown(GetPushDownCtx(p.SCtx()), p.Exprs, kv.TiFlash) { - p.SetChildren(mpp.p) - mpp.p = p - return mpp - } - } - t = t.ConvertToRootTask(p.SCtx()) - t = attachPlan2Task(p, t) - if root, ok := tasks[0].(*RootTask); ok && root.IsEmpty() { - t.(*RootTask).SetEmpty(true) - } - return t -} - -func (p *PhysicalUnionAll) attach2MppTasks(tasks ...base.Task) base.Task { - t := &MppTask{p: p} - childPlans := make([]base.PhysicalPlan, 0, len(tasks)) - for _, tk := range tasks { - if mpp, ok := tk.(*MppTask); ok && !tk.Invalid() { - childPlans = append(childPlans, mpp.Plan()) - } else if root, ok := tk.(*RootTask); ok && root.IsEmpty() { - continue - } else { - return base.InvalidTask - } - } - if len(childPlans) == 0 { - return base.InvalidTask - } - p.SetChildren(childPlans...) - return t -} - -// Attach2Task implements PhysicalPlan interface. -func (p *PhysicalUnionAll) Attach2Task(tasks ...base.Task) base.Task { - for _, t := range tasks { - if _, ok := t.(*MppTask); ok { - if p.TP() == plancodec.TypePartitionUnion { - // In attach2MppTasks(), will attach PhysicalUnion to mppTask directly. - // But PartitionUnion cannot pushdown to tiflash, so here disable PartitionUnion pushdown to tiflash explicitly. - // For now, return base.InvalidTask immediately, we can refine this by letting childTask of PartitionUnion convert to rootTask. - return base.InvalidTask - } - return p.attach2MppTasks(tasks...) - } - } - t := &RootTask{} - t.SetPlan(p) - childPlans := make([]base.PhysicalPlan, 0, len(tasks)) - for _, task := range tasks { - task = task.ConvertToRootTask(p.SCtx()) - childPlans = append(childPlans, task.Plan()) - } - p.SetChildren(childPlans...) - return t -} - -// Attach2Task implements PhysicalPlan interface. -func (sel *PhysicalSelection) Attach2Task(tasks ...base.Task) base.Task { - if mppTask, _ := tasks[0].(*MppTask); mppTask != nil { // always push to mpp task. - if expression.CanExprsPushDown(GetPushDownCtx(sel.SCtx()), sel.Conditions, kv.TiFlash) { - return attachPlan2Task(sel, mppTask.Copy()) - } - } - t := tasks[0].ConvertToRootTask(sel.SCtx()) - return attachPlan2Task(sel, t) -} - -// CheckAggCanPushCop checks whether the aggFuncs and groupByItems can -// be pushed down to coprocessor. -func CheckAggCanPushCop(sctx base.PlanContext, aggFuncs []*aggregation.AggFuncDesc, groupByItems []expression.Expression, storeType kv.StoreType) bool { - sc := sctx.GetSessionVars().StmtCtx - ret := true - reason := "" - pushDownCtx := GetPushDownCtx(sctx) - for _, aggFunc := range aggFuncs { - // if the aggFunc contain VirtualColumn or CorrelatedColumn, it can not be pushed down. 
- if expression.ContainVirtualColumn(aggFunc.Args) || expression.ContainCorrelatedColumn(aggFunc.Args) { - reason = "expressions of AggFunc `" + aggFunc.Name + "` contain virtual column or correlated column, which is not supported now" - ret = false - break - } - if !aggregation.CheckAggPushDown(sctx.GetExprCtx().GetEvalCtx(), aggFunc, storeType) { - reason = "AggFunc `" + aggFunc.Name + "` is not supported now" - ret = false - break - } - if !expression.CanExprsPushDownWithExtraInfo(GetPushDownCtx(sctx), aggFunc.Args, storeType, aggFunc.Name == ast.AggFuncSum) { - reason = "arguments of AggFunc `" + aggFunc.Name + "` contains unsupported exprs" - ret = false - break - } - orderBySize := len(aggFunc.OrderByItems) - if orderBySize > 0 { - exprs := make([]expression.Expression, 0, orderBySize) - for _, item := range aggFunc.OrderByItems { - exprs = append(exprs, item.Expr) - } - if !expression.CanExprsPushDownWithExtraInfo(GetPushDownCtx(sctx), exprs, storeType, false) { - reason = "arguments of AggFunc `" + aggFunc.Name + "` contains unsupported exprs in order-by clause" - ret = false - break - } - } - pb, _ := aggregation.AggFuncToPBExpr(pushDownCtx, aggFunc, storeType) - if pb == nil { - reason = "AggFunc `" + aggFunc.Name + "` can not be converted to pb expr" - ret = false - break - } - } - if ret && expression.ContainVirtualColumn(groupByItems) { - reason = "groupByItems contain virtual columns, which is not supported now" - ret = false - } - if ret && !expression.CanExprsPushDown(GetPushDownCtx(sctx), groupByItems, storeType) { - reason = "groupByItems contain unsupported exprs" - ret = false - } - - if !ret { - storageName := storeType.Name() - if storeType == kv.UnSpecified { - storageName = "storage layer" - } - warnErr := errors.NewNoStackError("Aggregation can not be pushed to " + storageName + " because " + reason) - if sc.InExplainStmt { - sc.AppendWarning(warnErr) - } else { - sc.AppendExtraWarning(warnErr) - } - } - return ret -} - -// AggInfo stores the information of an Aggregation. -type AggInfo struct { - AggFuncs []*aggregation.AggFuncDesc - GroupByItems []expression.Expression - Schema *expression.Schema -} - -// BuildFinalModeAggregation splits either LogicalAggregation or PhysicalAggregation to finalAgg and partial1Agg, -// returns the information of partial and final agg. -// partialIsCop means whether partial agg is a cop task. When partialIsCop is false, -// we do not set the AggMode for partialAgg cause it may be split further when -// building the aggregate executor(e.g. buildHashAgg will split the AggDesc further for parallel executing). 
-// firstRowFuncMap is a map between partial first_row to final first_row, will be used in RemoveUnnecessaryFirstRow -func BuildFinalModeAggregation( - sctx base.PlanContext, original *AggInfo, partialIsCop bool, isMPPTask bool) (partial, final *AggInfo, firstRowFuncMap map[*aggregation.AggFuncDesc]*aggregation.AggFuncDesc) { - ectx := sctx.GetExprCtx().GetEvalCtx() - - firstRowFuncMap = make(map[*aggregation.AggFuncDesc]*aggregation.AggFuncDesc, len(original.AggFuncs)) - partial = &AggInfo{ - AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(original.AggFuncs)), - GroupByItems: original.GroupByItems, - Schema: expression.NewSchema(), - } - partialCursor := 0 - final = &AggInfo{ - AggFuncs: make([]*aggregation.AggFuncDesc, len(original.AggFuncs)), - GroupByItems: make([]expression.Expression, 0, len(original.GroupByItems)), - Schema: original.Schema, - } - - partialGbySchema := expression.NewSchema() - // add group by columns - for _, gbyExpr := range partial.GroupByItems { - var gbyCol *expression.Column - if col, ok := gbyExpr.(*expression.Column); ok { - gbyCol = col - } else { - gbyCol = &expression.Column{ - UniqueID: sctx.GetSessionVars().AllocPlanColumnID(), - RetType: gbyExpr.GetType(ectx), - } - } - partialGbySchema.Append(gbyCol) - final.GroupByItems = append(final.GroupByItems, gbyCol) - } - - // TODO: Refactor the way of constructing aggregation functions. - // This for loop is ugly, but I do not find a proper way to reconstruct - // it right away. - - // group_concat is special when pushing down, it cannot take the two phase execution if no distinct but with orderBy, and other cases are also different: - // for example: group_concat([distinct] expr0, expr1[, order by expr2] separator ‘,’) - // no distinct, no orderBy: can two phase - // [final agg] group_concat(col#1,’,’) - // [part agg] group_concat(expr0, expr1,’,’) -> col#1 - // no distinct, orderBy: only one phase - // distinct, no orderBy: can two phase - // [final agg] group_concat(distinct col#0, col#1,’,’) - // [part agg] group by expr0 ->col#0, expr1 -> col#1 - // distinct, orderBy: can two phase - // [final agg] group_concat(distinct col#0, col#1, order by col#2,’,’) - // [part agg] group by expr0 ->col#0, expr1 -> col#1; agg function: firstrow(expr2)-> col#2 - - for i, aggFunc := range original.AggFuncs { - finalAggFunc := &aggregation.AggFuncDesc{HasDistinct: false} - finalAggFunc.Name = aggFunc.Name - finalAggFunc.OrderByItems = aggFunc.OrderByItems - args := make([]expression.Expression, 0, len(aggFunc.Args)) - if aggFunc.HasDistinct { - /* - eg: SELECT COUNT(DISTINCT a), SUM(b) FROM t GROUP BY c - - change from - [root] group by: c, funcs:count(distinct a), funcs:sum(b) - to - [root] group by: c, funcs:count(distinct a), funcs:sum(b) - [cop]: group by: c, a - */ - // onlyAddFirstRow means if the distinctArg does not occur in group by items, - // it should be replaced with a firstrow() agg function, needed for the order by items of group_concat() - getDistinctExpr := func(distinctArg expression.Expression, onlyAddFirstRow bool) (ret expression.Expression) { - // 1. add all args to partial.GroupByItems - foundInGroupBy := false - for j, gbyExpr := range partial.GroupByItems { - if gbyExpr.Equal(ectx, distinctArg) && gbyExpr.GetType(ectx).Equal(distinctArg.GetType(ectx)) { - // if the two expressions exactly the same in terms of data types and collation, then can avoid it. 
- foundInGroupBy = true - ret = partialGbySchema.Columns[j] - break - } - } - if !foundInGroupBy { - var gbyCol *expression.Column - if col, ok := distinctArg.(*expression.Column); ok { - gbyCol = col - } else { - gbyCol = &expression.Column{ - UniqueID: sctx.GetSessionVars().AllocPlanColumnID(), - RetType: distinctArg.GetType(ectx), - } - } - // 2. add group by items if needed - if !onlyAddFirstRow { - partial.GroupByItems = append(partial.GroupByItems, distinctArg) - partialGbySchema.Append(gbyCol) - ret = gbyCol - } - // 3. add firstrow() if needed - if !partialIsCop || onlyAddFirstRow { - // if partial is a cop task, firstrow function is redundant since group by items are outputted - // by group by schema, and final functions use group by schema as their arguments. - // if partial agg is not cop, we must append firstrow function & schema, to output the group by - // items. - // maybe we can unify them sometime. - // only add firstrow for order by items of group_concat() - firstRow, err := aggregation.NewAggFuncDesc(sctx.GetExprCtx(), ast.AggFuncFirstRow, []expression.Expression{distinctArg}, false) - if err != nil { - panic("NewAggFuncDesc FirstRow meets error: " + err.Error()) - } - partial.AggFuncs = append(partial.AggFuncs, firstRow) - newCol, _ := gbyCol.Clone().(*expression.Column) - newCol.RetType = firstRow.RetTp - partial.Schema.Append(newCol) - if onlyAddFirstRow { - ret = newCol - } - partialCursor++ - } - } - return ret - } - - for j, distinctArg := range aggFunc.Args { - // the last arg of ast.AggFuncGroupConcat is the separator, so just put it into the final agg - if aggFunc.Name == ast.AggFuncGroupConcat && j+1 == len(aggFunc.Args) { - args = append(args, distinctArg) - continue - } - args = append(args, getDistinctExpr(distinctArg, false)) - } - - byItems := make([]*util.ByItems, 0, len(aggFunc.OrderByItems)) - for _, byItem := range aggFunc.OrderByItems { - byItems = append(byItems, &util.ByItems{Expr: getDistinctExpr(byItem.Expr, true), Desc: byItem.Desc}) - } - - if aggFunc.HasDistinct && isMPPTask && aggFunc.GroupingID > 0 { - // keep the groupingID as it was, otherwise the new split final aggregate's ganna lost its groupingID info. - finalAggFunc.GroupingID = aggFunc.GroupingID - } - - finalAggFunc.OrderByItems = byItems - finalAggFunc.HasDistinct = aggFunc.HasDistinct - // In logical optimize phase, the Agg->PartitionUnion->TableReader may become - // Agg1->PartitionUnion->Agg2->TableReader, and the Agg2 is a partial aggregation. - // So in the push down here, we need to add a new if-condition check: - // If the original agg mode is partial already, the finalAggFunc's mode become Partial2. - if aggFunc.Mode == aggregation.CompleteMode { - finalAggFunc.Mode = aggregation.CompleteMode - } else if aggFunc.Mode == aggregation.Partial1Mode || aggFunc.Mode == aggregation.Partial2Mode { - finalAggFunc.Mode = aggregation.Partial2Mode - } - } else { - if aggFunc.Name == ast.AggFuncGroupConcat && len(aggFunc.OrderByItems) > 0 { - // group_concat can only run in one phase if it has order by items but without distinct property - partial = nil - final = original - return - } - if aggregation.NeedCount(finalAggFunc.Name) { - // only Avg and Count need count - if isMPPTask && finalAggFunc.Name == ast.AggFuncCount { - // For MPP base.Task, the final count() is changed to sum(). - // Note: MPP mode does not run avg() directly, instead, avg() -> sum()/(case when count() = 0 then 1 else count() end), - // so we do not process it here. 
- finalAggFunc.Name = ast.AggFuncSum - } else { - // avg branch - ft := types.NewFieldType(mysql.TypeLonglong) - ft.SetFlen(21) - ft.SetCharset(charset.CharsetBin) - ft.SetCollate(charset.CollationBin) - partial.Schema.Append(&expression.Column{ - UniqueID: sctx.GetSessionVars().AllocPlanColumnID(), - RetType: ft, - }) - args = append(args, partial.Schema.Columns[partialCursor]) - partialCursor++ - } - } - if finalAggFunc.Name == ast.AggFuncApproxCountDistinct { - ft := types.NewFieldType(mysql.TypeString) - ft.SetCharset(charset.CharsetBin) - ft.SetCollate(charset.CollationBin) - ft.AddFlag(mysql.NotNullFlag) - partial.Schema.Append(&expression.Column{ - UniqueID: sctx.GetSessionVars().AllocPlanColumnID(), - RetType: ft, - }) - args = append(args, partial.Schema.Columns[partialCursor]) - partialCursor++ - } - if aggregation.NeedValue(finalAggFunc.Name) { - partial.Schema.Append(&expression.Column{ - UniqueID: sctx.GetSessionVars().AllocPlanColumnID(), - RetType: original.Schema.Columns[i].GetType(ectx), - }) - args = append(args, partial.Schema.Columns[partialCursor]) - partialCursor++ - } - if aggFunc.Name == ast.AggFuncAvg { - cntAgg := aggFunc.Clone() - cntAgg.Name = ast.AggFuncCount - err := cntAgg.TypeInfer(sctx.GetExprCtx()) - if err != nil { // must not happen - partial = nil - final = original - return - } - partial.Schema.Columns[partialCursor-2].RetType = cntAgg.RetTp - // we must call deep clone in this case, to avoid sharing the arguments. - sumAgg := aggFunc.Clone() - sumAgg.Name = ast.AggFuncSum - sumAgg.TypeInfer4AvgSum(sumAgg.RetTp) - partial.Schema.Columns[partialCursor-1].RetType = sumAgg.RetTp - partial.AggFuncs = append(partial.AggFuncs, cntAgg, sumAgg) - } else if aggFunc.Name == ast.AggFuncApproxCountDistinct || aggFunc.Name == ast.AggFuncGroupConcat { - newAggFunc := aggFunc.Clone() - newAggFunc.Name = aggFunc.Name - newAggFunc.RetTp = partial.Schema.Columns[partialCursor-1].GetType(ectx) - partial.AggFuncs = append(partial.AggFuncs, newAggFunc) - if aggFunc.Name == ast.AggFuncGroupConcat { - // append the last separator arg - args = append(args, aggFunc.Args[len(aggFunc.Args)-1]) - } - } else { - // other agg desc just split into two parts - partialFuncDesc := aggFunc.Clone() - partial.AggFuncs = append(partial.AggFuncs, partialFuncDesc) - if aggFunc.Name == ast.AggFuncFirstRow { - firstRowFuncMap[partialFuncDesc] = finalAggFunc - } - } - - // In logical optimize phase, the Agg->PartitionUnion->TableReader may become - // Agg1->PartitionUnion->Agg2->TableReader, and the Agg2 is a partial aggregation. - // So in the push down here, we need to add a new if-condition check: - // If the original agg mode is partial already, the finalAggFunc's mode become Partial2. - if aggFunc.Mode == aggregation.CompleteMode { - finalAggFunc.Mode = aggregation.FinalMode - } else if aggFunc.Mode == aggregation.Partial1Mode || aggFunc.Mode == aggregation.Partial2Mode { - finalAggFunc.Mode = aggregation.Partial2Mode - } - } - - finalAggFunc.Args = args - finalAggFunc.RetTp = aggFunc.RetTp - final.AggFuncs[i] = finalAggFunc - } - partial.Schema.Append(partialGbySchema.Columns...) - if partialIsCop { - for _, f := range partial.AggFuncs { - f.Mode = aggregation.Partial1Mode - } - } - return -} - -// convertAvgForMPP converts avg(arg) to sum(arg)/(case when count(arg)=0 then 1 else count(arg) end), in detail: -// 1.rewrite avg() in the final aggregation to count() and sum(), and reconstruct its schema. 
-// 2.replace avg() with sum(arg)/(case when count(arg)=0 then 1 else count(arg) end) and reuse the original schema of the final aggregation. -// If there is no avg, nothing is changed and return nil. -func (p *basePhysicalAgg) convertAvgForMPP() *PhysicalProjection { - newSchema := expression.NewSchema() - newSchema.Keys = p.schema.Keys - newSchema.UniqueKeys = p.schema.UniqueKeys - newAggFuncs := make([]*aggregation.AggFuncDesc, 0, 2*len(p.AggFuncs)) - exprs := make([]expression.Expression, 0, 2*len(p.schema.Columns)) - // add agg functions schema - for i, aggFunc := range p.AggFuncs { - if aggFunc.Name == ast.AggFuncAvg { - // inset a count(column) - avgCount := aggFunc.Clone() - avgCount.Name = ast.AggFuncCount - err := avgCount.TypeInfer(p.SCtx().GetExprCtx()) - if err != nil { // must not happen - return nil - } - newAggFuncs = append(newAggFuncs, avgCount) - avgCountCol := &expression.Column{ - UniqueID: p.SCtx().GetSessionVars().AllocPlanColumnID(), - RetType: avgCount.RetTp, - } - newSchema.Append(avgCountCol) - // insert a sum(column) - avgSum := aggFunc.Clone() - avgSum.Name = ast.AggFuncSum - avgSum.TypeInfer4AvgSum(avgSum.RetTp) - newAggFuncs = append(newAggFuncs, avgSum) - avgSumCol := &expression.Column{ - UniqueID: p.schema.Columns[i].UniqueID, - RetType: avgSum.RetTp, - } - newSchema.Append(avgSumCol) - // avgSumCol/(case when avgCountCol=0 then 1 else avgCountCol end) - eq := expression.NewFunctionInternal(p.SCtx().GetExprCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), avgCountCol, expression.NewZero()) - caseWhen := expression.NewFunctionInternal(p.SCtx().GetExprCtx(), ast.Case, avgCountCol.RetType, eq, expression.NewOne(), avgCountCol) - divide := expression.NewFunctionInternal(p.SCtx().GetExprCtx(), ast.Div, avgSumCol.RetType, avgSumCol, caseWhen) - divide.(*expression.ScalarFunction).RetType = p.schema.Columns[i].RetType - exprs = append(exprs, divide) - } else { - // other non-avg agg use the old schema as it did. - newAggFuncs = append(newAggFuncs, aggFunc) - newSchema.Append(p.schema.Columns[i]) - exprs = append(exprs, p.schema.Columns[i]) - } - } - // no avgs - // for final agg, always add project due to in-compatibility between TiDB and TiFlash - if len(p.schema.Columns) == len(newSchema.Columns) && !p.IsFinalAgg() { - return nil - } - // add remaining columns to exprs - for i := len(p.AggFuncs); i < len(p.schema.Columns); i++ { - exprs = append(exprs, p.schema.Columns[i]) - } - proj := PhysicalProjection{ - Exprs: exprs, - CalculateNoDelay: false, - AvoidColumnEvaluator: false, - }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset(), p.GetChildReqProps(0).CloneEssentialFields()) - proj.SetSchema(p.schema) - - p.AggFuncs = newAggFuncs - p.schema = newSchema - - return proj -} - -func (p *basePhysicalAgg) newPartialAggregate(copTaskType kv.StoreType, isMPPTask bool) (partial, final base.PhysicalPlan) { - // Check if this aggregation can push down. - if !CheckAggCanPushCop(p.SCtx(), p.AggFuncs, p.GroupByItems, copTaskType) { - return nil, p.self - } - partialPref, finalPref, firstRowFuncMap := BuildFinalModeAggregation(p.SCtx(), &AggInfo{ - AggFuncs: p.AggFuncs, - GroupByItems: p.GroupByItems, - Schema: p.Schema().Clone(), - }, true, isMPPTask) - if partialPref == nil { - return nil, p.self - } - if p.TP() == plancodec.TypeStreamAgg && len(partialPref.GroupByItems) != len(finalPref.GroupByItems) { - return nil, p.self - } - // Remove unnecessary FirstRow. 
- partialPref.AggFuncs = RemoveUnnecessaryFirstRow(p.SCtx(), - finalPref.GroupByItems, partialPref.AggFuncs, partialPref.GroupByItems, partialPref.Schema, firstRowFuncMap) - if copTaskType == kv.TiDB { - // For partial agg of TiDB cop task, since TiDB coprocessor reuse the TiDB executor, - // and TiDB aggregation executor won't output the group by value, - // so we need add `firstrow` aggregation function to output the group by value. - aggFuncs, err := genFirstRowAggForGroupBy(p.SCtx(), partialPref.GroupByItems) - if err != nil { - return nil, p.self - } - partialPref.AggFuncs = append(partialPref.AggFuncs, aggFuncs...) - } - p.AggFuncs = partialPref.AggFuncs - p.GroupByItems = partialPref.GroupByItems - p.schema = partialPref.Schema - partialAgg := p.self - // Create physical "final" aggregation. - prop := &property.PhysicalProperty{ExpectedCnt: math.MaxFloat64} - if p.TP() == plancodec.TypeStreamAgg { - finalAgg := basePhysicalAgg{ - AggFuncs: finalPref.AggFuncs, - GroupByItems: finalPref.GroupByItems, - MppRunMode: p.MppRunMode, - }.initForStream(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset(), prop) - finalAgg.schema = finalPref.Schema - return partialAgg, finalAgg - } - - finalAgg := basePhysicalAgg{ - AggFuncs: finalPref.AggFuncs, - GroupByItems: finalPref.GroupByItems, - MppRunMode: p.MppRunMode, - }.initForHash(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset(), prop) - finalAgg.schema = finalPref.Schema - // partialAgg and finalAgg use the same ref of stats - return partialAgg, finalAgg -} - -func (p *basePhysicalAgg) scale3StageForDistinctAgg() (bool, expression.GroupingSets) { - if p.canUse3Stage4SingleDistinctAgg() { - return true, nil - } - return p.canUse3Stage4MultiDistinctAgg() -} - -// canUse3Stage4MultiDistinctAgg returns true if this agg can use 3 stage for multi distinct aggregation -func (p *basePhysicalAgg) canUse3Stage4MultiDistinctAgg() (can bool, gss expression.GroupingSets) { - if !p.SCtx().GetSessionVars().Enable3StageDistinctAgg || !p.SCtx().GetSessionVars().Enable3StageMultiDistinctAgg || len(p.GroupByItems) > 0 { - return false, nil - } - defer func() { - // some clean work. - if !can { - for _, fun := range p.AggFuncs { - fun.GroupingID = 0 - } - } - }() - // groupingSets is alias of []GroupingSet, the below equal to = make([]GroupingSet, 0, 2) - groupingSets := make(expression.GroupingSets, 0, 2) - for _, fun := range p.AggFuncs { - if fun.HasDistinct { - if fun.Name != ast.AggFuncCount { - // now only for multi count(distinct x) - return false, nil - } - for _, arg := range fun.Args { - // bail out when args are not simple column, see GitHub issue #35417 - if _, ok := arg.(*expression.Column); !ok { - return false, nil - } - } - // here it's a valid count distinct agg with normal column args, collecting its distinct expr. - groupingSets = append(groupingSets, expression.GroupingSet{fun.Args}) - // groupingID now is the offset of target grouping in GroupingSets. - // todo: it may be changed after grouping set merge in the future. - fun.GroupingID = len(groupingSets) - } else if len(fun.Args) > 1 { - return false, nil - } - // banned group_concat(x order by y) - if len(fun.OrderByItems) > 0 || fun.Mode != aggregation.CompleteMode { - return false, nil - } - } - compressed := groupingSets.Merge() - if len(compressed) != len(groupingSets) { - p.SCtx().GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("Some grouping sets should be merged")) - // todo arenatlx: some grouping set should be merged which is not supported by now temporarily. 
- return false, nil - } - if groupingSets.NeedCloneColumn() { - // todo: column clone haven't implemented. - return false, nil - } - if len(groupingSets) > 1 { - // fill the grouping ID for normal agg. - for _, fun := range p.AggFuncs { - if fun.GroupingID == 0 { - // the grouping ID hasn't set. find the targeting grouping set. - groupingSetOffset := groupingSets.TargetOne(fun.Args) - if groupingSetOffset == -1 { - // todo: if we couldn't find a existed current valid group layout, we need to copy the column out from being filled with null value. - p.SCtx().GetSessionVars().StmtCtx.AppendWarning(errors.NewNoStackErrorf("couldn't find a proper group set for normal agg")) - return false, nil - } - // starting with 1 - fun.GroupingID = groupingSetOffset + 1 - } - } - return true, groupingSets - } - return false, nil -} - -// canUse3Stage4SingleDistinctAgg returns true if this agg can use 3 stage for distinct aggregation -func (p *basePhysicalAgg) canUse3Stage4SingleDistinctAgg() bool { - num := 0 - if !p.SCtx().GetSessionVars().Enable3StageDistinctAgg || len(p.GroupByItems) > 0 { - return false - } - for _, fun := range p.AggFuncs { - if fun.HasDistinct { - num++ - if num > 1 || fun.Name != ast.AggFuncCount { - return false - } - for _, arg := range fun.Args { - // bail out when args are not simple column, see GitHub issue #35417 - if _, ok := arg.(*expression.Column); !ok { - return false - } - } - } else if len(fun.Args) > 1 { - return false - } - - if len(fun.OrderByItems) > 0 || fun.Mode != aggregation.CompleteMode { - return false - } - } - return num == 1 -} - -func genFirstRowAggForGroupBy(ctx base.PlanContext, groupByItems []expression.Expression) ([]*aggregation.AggFuncDesc, error) { - aggFuncs := make([]*aggregation.AggFuncDesc, 0, len(groupByItems)) - for _, groupBy := range groupByItems { - agg, err := aggregation.NewAggFuncDesc(ctx.GetExprCtx(), ast.AggFuncFirstRow, []expression.Expression{groupBy}, false) - if err != nil { - return nil, err - } - aggFuncs = append(aggFuncs, agg) - } - return aggFuncs, nil -} - -// RemoveUnnecessaryFirstRow removes unnecessary FirstRow of the aggregation. This function can be -// used for both LogicalAggregation and PhysicalAggregation. -// When the select column is same with the group by key, the column can be removed and gets value from the group by key. -// e.g -// select a, count(b) from t group by a; -// The schema is [firstrow(a), count(b), a]. The column firstrow(a) is unnecessary. -// Can optimize the schema to [count(b), a] , and change the index to get value. -func RemoveUnnecessaryFirstRow( - sctx base.PlanContext, - finalGbyItems []expression.Expression, - partialAggFuncs []*aggregation.AggFuncDesc, - partialGbyItems []expression.Expression, - partialSchema *expression.Schema, - firstRowFuncMap map[*aggregation.AggFuncDesc]*aggregation.AggFuncDesc) []*aggregation.AggFuncDesc { - partialCursor := 0 - newAggFuncs := make([]*aggregation.AggFuncDesc, 0, len(partialAggFuncs)) - for _, aggFunc := range partialAggFuncs { - if aggFunc.Name == ast.AggFuncFirstRow { - canOptimize := false - for j, gbyExpr := range partialGbyItems { - if j >= len(finalGbyItems) { - // after distinct push, len(partialGbyItems) may larger than len(finalGbyItems) - // for example, - // select /*+ HASH_AGG() */ a, count(distinct a) from t; - // will generate to, - // HashAgg root funcs:count(distinct a), funcs:firstrow(a)" - // HashAgg cop group by:a, funcs:firstrow(a)->Column#6" - // the firstrow in root task can not be removed. 
- break - } - // Skip if it's a constant. - // For SELECT DISTINCT SQRT(1) FROM t. - // We shouldn't remove the firstrow(SQRT(1)). - if _, ok := gbyExpr.(*expression.Constant); ok { - continue - } - if gbyExpr.Equal(sctx.GetExprCtx().GetEvalCtx(), aggFunc.Args[0]) { - canOptimize = true - firstRowFuncMap[aggFunc].Args[0] = finalGbyItems[j] - break - } - } - if canOptimize { - partialSchema.Columns = append(partialSchema.Columns[:partialCursor], partialSchema.Columns[partialCursor+1:]...) - continue - } - } - partialCursor += computePartialCursorOffset(aggFunc.Name) - newAggFuncs = append(newAggFuncs, aggFunc) - } - return newAggFuncs -} - -func computePartialCursorOffset(name string) int { - offset := 0 - if aggregation.NeedCount(name) { - offset++ - } - if aggregation.NeedValue(name) { - offset++ - } - if name == ast.AggFuncApproxCountDistinct { - offset++ - } - return offset -} - -// Attach2Task implements PhysicalPlan interface. -func (p *PhysicalStreamAgg) Attach2Task(tasks ...base.Task) base.Task { - t := tasks[0].Copy() - if cop, ok := t.(*CopTask); ok { - // We should not push agg down across - // 1. double read, since the data of second read is ordered by handle instead of index. The `extraHandleCol` is added - // if the double read needs to keep order. So we just use it to decided - // whether the following plan is double read with order reserved. - // 2. the case that there's filters should be calculated on TiDB side. - // 3. the case of index merge - if (cop.indexPlan != nil && cop.tablePlan != nil && cop.keepOrder) || len(cop.rootTaskConds) > 0 || len(cop.idxMergePartPlans) > 0 { - t = cop.ConvertToRootTask(p.SCtx()) - attachPlan2Task(p, t) - } else { - storeType := cop.getStoreType() - // TiFlash doesn't support Stream Aggregation - if storeType == kv.TiFlash && len(p.GroupByItems) > 0 { - return base.InvalidTask - } - partialAgg, finalAgg := p.newPartialAggregate(storeType, false) - if partialAgg != nil { - if cop.tablePlan != nil { - cop.finishIndexPlan() - partialAgg.SetChildren(cop.tablePlan) - cop.tablePlan = partialAgg - // If needExtraProj is true, a projection will be created above the PhysicalIndexLookUpReader to make sure - // the schema is the same as the original DataSource schema. - // However, we pushed down the agg here, the partial agg was placed on the top of tablePlan, and the final - // agg will be placed above the PhysicalIndexLookUpReader, and the schema will be set correctly for them. - // If we add the projection again, the projection will be between the PhysicalIndexLookUpReader and - // the partial agg, and the schema will be broken. - cop.needExtraProj = false - } else { - partialAgg.SetChildren(cop.indexPlan) - cop.indexPlan = partialAgg - } - } - t = cop.ConvertToRootTask(p.SCtx()) - attachPlan2Task(finalAgg, t) - } - } else if mpp, ok := t.(*MppTask); ok { - t = mpp.ConvertToRootTask(p.SCtx()) - attachPlan2Task(p, t) - } else { - attachPlan2Task(p, t) - } - return t -} - -// cpuCostDivisor computes the concurrency to which we would amortize CPU cost -// for hash aggregation. -func (p *PhysicalHashAgg) cpuCostDivisor(hasDistinct bool) (divisor, con float64) { - if hasDistinct { - return 0, 0 - } - sessionVars := p.SCtx().GetSessionVars() - finalCon, partialCon := sessionVars.HashAggFinalConcurrency(), sessionVars.HashAggPartialConcurrency() - // According to `ValidateSetSystemVar`, `finalCon` and `partialCon` cannot be less than or equal to 0. 
- if finalCon == 1 && partialCon == 1 { - return 0, 0 - } - // It is tricky to decide which concurrency we should use to amortize CPU cost. Since cost of hash - // aggregation is tend to be under-estimated as explained in `attach2Task`, we choose the smaller - // concurrecy to make some compensation. - return math.Min(float64(finalCon), float64(partialCon)), float64(finalCon + partialCon) -} - -func (p *PhysicalHashAgg) attach2TaskForMpp1Phase(mpp *MppTask) base.Task { - // 1-phase agg: when the partition columns can be satisfied, where the plan does not need to enforce Exchange - // only push down the original agg - proj := p.convertAvgForMPP() - attachPlan2Task(p.self, mpp) - if proj != nil { - attachPlan2Task(proj, mpp) - } - return mpp -} - -// scaleStats4GroupingSets scale the derived stats because the lower source has been expanded. -// -// parent OP <- logicalAgg <- children OP (derived stats) -// | -// v -// parent OP <- physicalAgg <- children OP (stats used) -// | -// +----------+----------+----------+ -// Final Mid Partial Expand -// -// physical agg stats is reasonable from the whole, because expand operator is designed to facilitate -// the Mid and Partial Agg, which means when leaving the Final, its output rowcount could be exactly -// the same as what it derived(estimated) before entering physical optimization phase. -// -// From the cost model correctness, for these inserted sub-agg and even expand operator, we should -// recompute the stats for them particularly. -// -// for example: grouping sets {},{}, group by items {a,b,c,groupingID} -// after expand: -// -// a, b, c, groupingID -// ... null c 1 ---+ -// ... null c 1 +------- replica group 1 -// ... null c 1 ---+ -// null ... c 2 ---+ -// null ... c 2 +------- replica group 2 -// null ... c 2 ---+ -// -// since null value is seen the same when grouping data (groupingID in one replica is always the same): -// - so the num of group in replica 1 is equal to NDV(a,c) -// - so the num of group in replica 2 is equal to NDV(b,c) -// -// in a summary, the total num of group of all replica is equal to = Σ:NDV(each-grouping-set-cols, normal-group-cols) -func (p *PhysicalHashAgg) scaleStats4GroupingSets(groupingSets expression.GroupingSets, groupingIDCol *expression.Column, - childSchema *expression.Schema, childStats *property.StatsInfo) { - idSets := groupingSets.AllSetsColIDs() - normalGbyCols := make([]*expression.Column, 0, len(p.GroupByItems)) - for _, gbyExpr := range p.GroupByItems { - cols := expression.ExtractColumns(gbyExpr) - for _, col := range cols { - if !idSets.Has(int(col.UniqueID)) && col.UniqueID != groupingIDCol.UniqueID { - normalGbyCols = append(normalGbyCols, col) - } - } - } - sumNDV := float64(0) - for _, groupingSet := range groupingSets { - // for every grouping set, pick its cols out, and combine with normal group cols to get the ndv. - groupingSetCols := groupingSet.ExtractCols() - groupingSetCols = append(groupingSetCols, normalGbyCols...) - ndv, _ := cardinality.EstimateColsNDVWithMatchedLen(groupingSetCols, childSchema, childStats) - sumNDV += ndv - } - // After group operator, all same rows are grouped into one row, that means all - // change the sub-agg's stats - if p.StatsInfo() != nil { - // equivalence to a new cloned one. (cause finalAgg and partialAgg may share a same copy of stats) - cpStats := p.StatsInfo().Scale(1) - cpStats.RowCount = sumNDV - // We cannot estimate the ColNDVs for every output, so we use a conservative strategy. 
- for k := range cpStats.ColNDVs { - cpStats.ColNDVs[k] = sumNDV - } - // for old groupNDV, if it's containing one more grouping set cols, just plus the NDV where the col is excluded. - // for example: old grouping NDV(b,c), where b is in grouping sets {},{}. so when countering the new NDV: - // cases: - // new grouping NDV(b,c) := old NDV(b,c) + NDV(null, c) = old NDV(b,c) + DNV(c). - // new grouping NDV(a,b,c) := old NDV(a,b,c) + NDV(null,b,c) + NDV(a,null,c) = old NDV(a,b,c) + NDV(b,c) + NDV(a,c) - allGroupingSetsIDs := groupingSets.AllSetsColIDs() - for _, oneGNDV := range cpStats.GroupNDVs { - newGNDV := oneGNDV.NDV - intersectionIDs := make([]int64, 0, len(oneGNDV.Cols)) - for i, id := range oneGNDV.Cols { - if allGroupingSetsIDs.Has(int(id)) { - // when meet an id in grouping sets, skip it (cause its null) and append the rest ids to count the incrementNDV. - beforeLen := len(intersectionIDs) - intersectionIDs = append(intersectionIDs, oneGNDV.Cols[i:]...) - incrementNDV, _ := cardinality.EstimateColsDNVWithMatchedLenFromUniqueIDs(intersectionIDs, childSchema, childStats) - newGNDV += incrementNDV - // restore the before intersectionIDs slice. - intersectionIDs = intersectionIDs[:beforeLen] - } - // insert ids one by one. - intersectionIDs = append(intersectionIDs, id) - } - oneGNDV.NDV = newGNDV - } - p.SetStats(cpStats) - } -} - -// adjust3StagePhaseAgg generate 3 stage aggregation for single/multi count distinct if applicable. -// -// select count(distinct a), count(b) from foo -// -// will generate plan: -// -// HashAgg sum(#1), sum(#2) -> final agg -// +- Exchange Passthrough -// +- HashAgg count(distinct a) #1, sum(#3) #2 -> middle agg -// +- Exchange HashPartition by a -// +- HashAgg count(b) #3, group by a -> partial agg -// +- TableScan foo -// -// select count(distinct a), count(distinct b), count(c) from foo -// -// will generate plan: -// -// HashAgg sum(#1), sum(#2), sum(#3) -> final agg -// +- Exchange Passthrough -// +- HashAgg count(distinct a) #1, count(distinct b) #2, sum(#4) #3 -> middle agg -// +- Exchange HashPartition by a,b,groupingID -// +- HashAgg count(c) #4, group by a,b,groupingID -> partial agg -// +- Expand {}, {} -> expand -// +- TableScan foo -func (p *PhysicalHashAgg) adjust3StagePhaseAgg(partialAgg, finalAgg base.PhysicalPlan, canUse3StageAgg bool, - groupingSets expression.GroupingSets, mpp *MppTask) (final, mid, part, proj4Part base.PhysicalPlan, _ error) { - ectx := p.SCtx().GetExprCtx().GetEvalCtx() - - if !(partialAgg != nil && canUse3StageAgg) { - // quick path: return the original finalAgg and partiAgg. - return finalAgg, nil, partialAgg, nil, nil - } - if len(groupingSets) == 0 { - // single distinct agg mode. - clonedAgg, err := finalAgg.Clone(p.SCtx()) - if err != nil { - return nil, nil, nil, nil, err - } - - // step1: adjust middle agg. - middleHashAgg := clonedAgg.(*PhysicalHashAgg) - distinctPos := 0 - middleSchema := expression.NewSchema() - schemaMap := make(map[int64]*expression.Column, len(middleHashAgg.AggFuncs)) - for i, fun := range middleHashAgg.AggFuncs { - col := &expression.Column{ - UniqueID: p.SCtx().GetSessionVars().AllocPlanColumnID(), - RetType: fun.RetTp, - } - if fun.HasDistinct { - distinctPos = i - fun.Mode = aggregation.Partial1Mode - } else { - fun.Mode = aggregation.Partial2Mode - originalCol := fun.Args[0].(*expression.Column) - // mapping the current partial output column with the agg origin arg column. 
(final agg arg should use this one) - schemaMap[originalCol.UniqueID] = col - } - middleSchema.Append(col) - } - middleHashAgg.schema = middleSchema - - // step2: adjust final agg. - finalHashAgg := finalAgg.(*PhysicalHashAgg) - finalAggDescs := make([]*aggregation.AggFuncDesc, 0, len(finalHashAgg.AggFuncs)) - for i, fun := range finalHashAgg.AggFuncs { - newArgs := make([]expression.Expression, 0, 1) - if distinctPos == i { - // change count(distinct) to sum() - fun.Name = ast.AggFuncSum - fun.HasDistinct = false - newArgs = append(newArgs, middleSchema.Columns[i]) - } else { - for _, arg := range fun.Args { - newCol, err := arg.RemapColumn(schemaMap) - if err != nil { - return nil, nil, nil, nil, err - } - newArgs = append(newArgs, newCol) - } - } - fun.Mode = aggregation.FinalMode - fun.Args = newArgs - finalAggDescs = append(finalAggDescs, fun) - } - finalHashAgg.AggFuncs = finalAggDescs - // partialAgg is im-mutated from args. - return finalHashAgg, middleHashAgg, partialAgg, nil, nil - } - // multi distinct agg mode, having grouping sets. - // set the default expression to constant 1 for the convenience to choose default group set data. - var groupingIDCol expression.Expression - // enforce Expand operator above the children. - // physical plan is enumerated without children from itself, use mpp subtree instead p.children. - // scale(len(groupingSets)) will change the NDV, while Expand doesn't change the NDV and groupNDV. - stats := mpp.p.StatsInfo().Scale(float64(1)) - stats.RowCount = stats.RowCount * float64(len(groupingSets)) - physicalExpand := PhysicalExpand{ - GroupingSets: groupingSets, - }.Init(p.SCtx(), stats, mpp.p.QueryBlockOffset()) - // generate a new column as groupingID to identify which this row is targeting for. - tp := types.NewFieldType(mysql.TypeLonglong) - tp.SetFlag(mysql.UnsignedFlag | mysql.NotNullFlag) - groupingIDCol = &expression.Column{ - UniqueID: p.SCtx().GetSessionVars().AllocPlanColumnID(), - RetType: tp, - } - // append the physical expand op with groupingID column. - physicalExpand.SetSchema(mpp.p.Schema().Clone()) - physicalExpand.schema.Append(groupingIDCol.(*expression.Column)) - physicalExpand.GroupingIDCol = groupingIDCol.(*expression.Column) - // attach PhysicalExpand to mpp - attachPlan2Task(physicalExpand, mpp) - - // having group sets - clonedAgg, err := finalAgg.Clone(p.SCtx()) - if err != nil { - return nil, nil, nil, nil, err - } - cloneHashAgg := clonedAgg.(*PhysicalHashAgg) - // Clone(), it will share same base-plan elements from the finalAgg, including id,tp,stats. Make a new one here. - cloneHashAgg.Plan = baseimpl.NewBasePlan(cloneHashAgg.SCtx(), cloneHashAgg.TP(), cloneHashAgg.QueryBlockOffset()) - cloneHashAgg.SetStats(finalAgg.StatsInfo()) // reuse the final agg stats here. - - // step1: adjust partial agg, for normal agg here, adjust it to target for specified group data. - // Since we may substitute the first arg of normal agg with case-when expression here, append a - // customized proj here rather than depending on postOptimize to insert a blunt one for us. - // - // proj4Partial output all the base col from lower op + caseWhen proj cols. 
- proj4Partial := new(PhysicalProjection).Init(p.SCtx(), mpp.p.StatsInfo(), mpp.p.QueryBlockOffset()) - for _, col := range mpp.p.Schema().Columns { - proj4Partial.Exprs = append(proj4Partial.Exprs, col) - } - proj4Partial.SetSchema(mpp.p.Schema().Clone()) - - partialHashAgg := partialAgg.(*PhysicalHashAgg) - partialHashAgg.GroupByItems = append(partialHashAgg.GroupByItems, groupingIDCol) - partialHashAgg.schema.Append(groupingIDCol.(*expression.Column)) - // it will create a new stats for partial agg. - partialHashAgg.scaleStats4GroupingSets(groupingSets, groupingIDCol.(*expression.Column), proj4Partial.Schema(), proj4Partial.StatsInfo()) - for _, fun := range partialHashAgg.AggFuncs { - if !fun.HasDistinct { - // for normal agg phase1, we should also modify them to target for specified group data. - // Expr = (case when groupingID = targeted_groupingID then arg else null end) - eqExpr := expression.NewFunctionInternal(p.SCtx().GetExprCtx(), ast.EQ, types.NewFieldType(mysql.TypeTiny), groupingIDCol, expression.NewUInt64Const(fun.GroupingID)) - caseWhen := expression.NewFunctionInternal(p.SCtx().GetExprCtx(), ast.Case, fun.Args[0].GetType(ectx), eqExpr, fun.Args[0], expression.NewNull()) - caseWhenProjCol := &expression.Column{ - UniqueID: p.SCtx().GetSessionVars().AllocPlanColumnID(), - RetType: fun.Args[0].GetType(ectx), - } - proj4Partial.Exprs = append(proj4Partial.Exprs, caseWhen) - proj4Partial.Schema().Append(caseWhenProjCol) - fun.Args[0] = caseWhenProjCol - } - } - - // step2: adjust middle agg - // middleHashAgg shared the same stats with the final agg does. - middleHashAgg := cloneHashAgg - middleSchema := expression.NewSchema() - schemaMap := make(map[int64]*expression.Column, len(middleHashAgg.AggFuncs)) - for _, fun := range middleHashAgg.AggFuncs { - col := &expression.Column{ - UniqueID: p.SCtx().GetSessionVars().AllocPlanColumnID(), - RetType: fun.RetTp, - } - if fun.HasDistinct { - // let count distinct agg aggregate on whole-scope data rather using case-when expr to target on specified group. (agg null strict attribute) - fun.Mode = aggregation.Partial1Mode - } else { - fun.Mode = aggregation.Partial2Mode - originalCol := fun.Args[0].(*expression.Column) - // record the origin column unique id down before change it to be case when expr. - // mapping the current partial output column with the agg origin arg column. (final agg arg should use this one) - schemaMap[originalCol.UniqueID] = col - } - middleSchema.Append(col) - } - middleHashAgg.schema = middleSchema - - // step3: adjust final agg - finalHashAgg := finalAgg.(*PhysicalHashAgg) - finalAggDescs := make([]*aggregation.AggFuncDesc, 0, len(finalHashAgg.AggFuncs)) - for i, fun := range finalHashAgg.AggFuncs { - newArgs := make([]expression.Expression, 0, 1) - if fun.HasDistinct { - // change count(distinct) agg to sum() - fun.Name = ast.AggFuncSum - fun.HasDistinct = false - // count(distinct a,b) -> become a single partial result col. - newArgs = append(newArgs, middleSchema.Columns[i]) - } else { - // remap final normal agg args to be output schema of middle normal agg. 
- for _, arg := range fun.Args { - newCol, err := arg.RemapColumn(schemaMap) - if err != nil { - return nil, nil, nil, nil, err - } - newArgs = append(newArgs, newCol) - } - } - fun.Mode = aggregation.FinalMode - fun.Args = newArgs - fun.GroupingID = 0 - finalAggDescs = append(finalAggDescs, fun) - } - finalHashAgg.AggFuncs = finalAggDescs - return finalHashAgg, middleHashAgg, partialHashAgg, proj4Partial, nil -} - -func (p *PhysicalHashAgg) attach2TaskForMpp(tasks ...base.Task) base.Task { - ectx := p.SCtx().GetExprCtx().GetEvalCtx() - - t := tasks[0].Copy() - mpp, ok := t.(*MppTask) - if !ok { - return base.InvalidTask - } - switch p.MppRunMode { - case Mpp1Phase: - // 1-phase agg: when the partition columns can be satisfied, where the plan does not need to enforce Exchange - // only push down the original agg - proj := p.convertAvgForMPP() - attachPlan2Task(p, mpp) - if proj != nil { - attachPlan2Task(proj, mpp) - } - return mpp - case Mpp2Phase: - // TODO: when partition property is matched by sub-plan, we actually needn't do extra an exchange and final agg. - proj := p.convertAvgForMPP() - partialAgg, finalAgg := p.newPartialAggregate(kv.TiFlash, true) - if partialAgg == nil { - return base.InvalidTask - } - attachPlan2Task(partialAgg, mpp) - partitionCols := p.MppPartitionCols - if len(partitionCols) == 0 { - items := finalAgg.(*PhysicalHashAgg).GroupByItems - partitionCols = make([]*property.MPPPartitionColumn, 0, len(items)) - for _, expr := range items { - col, ok := expr.(*expression.Column) - if !ok { - return base.InvalidTask - } - partitionCols = append(partitionCols, &property.MPPPartitionColumn{ - Col: col, - CollateID: property.GetCollateIDByNameForPartition(col.GetType(ectx).GetCollate()), - }) - } - } - prop := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.HashType, MPPPartitionCols: partitionCols} - newMpp := mpp.enforceExchangerImpl(prop) - if newMpp.Invalid() { - return newMpp - } - attachPlan2Task(finalAgg, newMpp) - // TODO: how to set 2-phase cost? - if proj != nil { - attachPlan2Task(proj, newMpp) - } - return newMpp - case MppTiDB: - partialAgg, finalAgg := p.newPartialAggregate(kv.TiFlash, false) - if partialAgg != nil { - attachPlan2Task(partialAgg, mpp) - } - t = mpp.ConvertToRootTask(p.SCtx()) - attachPlan2Task(finalAgg, t) - return t - case MppScalar: - prop := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.SinglePartitionType} - if !mpp.needEnforceExchanger(prop) { - // On the one hand: when the low layer already satisfied the single partition layout, just do the all agg computation in the single node. - return p.attach2TaskForMpp1Phase(mpp) - } - // On the other hand: try to split the mppScalar agg into multi phases agg **down** to multi nodes since data already distributed across nodes. 
- // we have to check it before the content of p has been modified - canUse3StageAgg, groupingSets := p.scale3StageForDistinctAgg() - proj := p.convertAvgForMPP() - partialAgg, finalAgg := p.newPartialAggregate(kv.TiFlash, true) - if finalAgg == nil { - return base.InvalidTask - } - - final, middle, partial, proj4Partial, err := p.adjust3StagePhaseAgg(partialAgg, finalAgg, canUse3StageAgg, groupingSets, mpp) - if err != nil { - return base.InvalidTask - } - - // partial agg proj would be null if one scalar agg cannot run in two-phase mode - if proj4Partial != nil { - attachPlan2Task(proj4Partial, mpp) - } - - // partial agg would be null if one scalar agg cannot run in two-phase mode - if partial != nil { - attachPlan2Task(partial, mpp) - } - - if middle != nil && canUse3StageAgg { - items := partial.(*PhysicalHashAgg).GroupByItems - partitionCols := make([]*property.MPPPartitionColumn, 0, len(items)) - for _, expr := range items { - col, ok := expr.(*expression.Column) - if !ok { - continue - } - partitionCols = append(partitionCols, &property.MPPPartitionColumn{ - Col: col, - CollateID: property.GetCollateIDByNameForPartition(col.GetType(ectx).GetCollate()), - }) - } - - exProp := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.HashType, MPPPartitionCols: partitionCols} - newMpp := mpp.enforceExchanger(exProp) - attachPlan2Task(middle, newMpp) - mpp = newMpp - } - - // prop here still be the first generated single-partition requirement. - newMpp := mpp.enforceExchanger(prop) - attachPlan2Task(final, newMpp) - if proj == nil { - proj = PhysicalProjection{ - Exprs: make([]expression.Expression, 0, len(p.Schema().Columns)), - }.Init(p.SCtx(), p.StatsInfo(), p.QueryBlockOffset()) - for _, col := range p.Schema().Columns { - proj.Exprs = append(proj.Exprs, col) - } - proj.SetSchema(p.schema) - } - attachPlan2Task(proj, newMpp) - return newMpp - default: - return base.InvalidTask - } -} - -// Attach2Task implements the PhysicalPlan interface. -func (p *PhysicalHashAgg) Attach2Task(tasks ...base.Task) base.Task { - t := tasks[0].Copy() - if cop, ok := t.(*CopTask); ok { - if len(cop.rootTaskConds) == 0 && len(cop.idxMergePartPlans) == 0 { - copTaskType := cop.getStoreType() - partialAgg, finalAgg := p.newPartialAggregate(copTaskType, false) - if partialAgg != nil { - if cop.tablePlan != nil { - cop.finishIndexPlan() - partialAgg.SetChildren(cop.tablePlan) - cop.tablePlan = partialAgg - // If needExtraProj is true, a projection will be created above the PhysicalIndexLookUpReader to make sure - // the schema is the same as the original DataSource schema. - // However, we pushed down the agg here, the partial agg was placed on the top of tablePlan, and the final - // agg will be placed above the PhysicalIndexLookUpReader, and the schema will be set correctly for them. - // If we add the projection again, the projection will be between the PhysicalIndexLookUpReader and - // the partial agg, and the schema will be broken. 
- cop.needExtraProj = false - } else { - partialAgg.SetChildren(cop.indexPlan) - cop.indexPlan = partialAgg - } - } - // In `newPartialAggregate`, we are using stats of final aggregation as stats - // of `partialAgg`, so the network cost of transferring result rows of `partialAgg` - // to TiDB is normally under-estimated for hash aggregation, since the group-by - // column may be independent of the column used for region distribution, so a closer - // estimation of network cost for hash aggregation may multiply the number of - // regions involved in the `partialAgg`, which is unknown however. - t = cop.ConvertToRootTask(p.SCtx()) - attachPlan2Task(finalAgg, t) - } else { - t = cop.ConvertToRootTask(p.SCtx()) - attachPlan2Task(p, t) - } - } else if _, ok := t.(*MppTask); ok { - return p.attach2TaskForMpp(tasks...) - } else { - attachPlan2Task(p, t) - } - return t -} - -func (p *PhysicalWindow) attach2TaskForMPP(mpp *MppTask) base.Task { - // FIXME: currently, tiflash's join has different schema with TiDB, - // so we have to rebuild the schema of join and operators which may inherit schema from join. - // for window, we take the sub-plan's schema, and the schema generated by windowDescs. - columns := p.Schema().Clone().Columns[len(p.Schema().Columns)-len(p.WindowFuncDescs):] - p.schema = expression.MergeSchema(mpp.Plan().Schema(), expression.NewSchema(columns...)) - - failpoint.Inject("CheckMPPWindowSchemaLength", func() { - if len(p.Schema().Columns) != len(mpp.Plan().Schema().Columns)+len(p.WindowFuncDescs) { - panic("mpp physical window has incorrect schema length") - } - }) - - return attachPlan2Task(p, mpp) -} - -// Attach2Task implements the PhysicalPlan interface. -func (p *PhysicalWindow) Attach2Task(tasks ...base.Task) base.Task { - if mpp, ok := tasks[0].Copy().(*MppTask); ok && p.storeTp == kv.TiFlash { - return p.attach2TaskForMPP(mpp) - } - t := tasks[0].ConvertToRootTask(p.SCtx()) - return attachPlan2Task(p.self, t) -} - -// Attach2Task implements the PhysicalPlan interface. -func (p *PhysicalCTEStorage) Attach2Task(tasks ...base.Task) base.Task { - t := tasks[0].Copy() - if mpp, ok := t.(*MppTask); ok { - p.SetChildren(t.Plan()) - return &MppTask{ - p: p, - partTp: mpp.partTp, - hashCols: mpp.hashCols, - tblColHists: mpp.tblColHists, - } - } - t.ConvertToRootTask(p.SCtx()) - p.SetChildren(t.Plan()) - ta := &RootTask{} - ta.SetPlan(p) - return ta -} - -// Attach2Task implements the PhysicalPlan interface. -func (p *PhysicalSequence) Attach2Task(tasks ...base.Task) base.Task { - for _, t := range tasks { - _, isMpp := t.(*MppTask) - if !isMpp { - return tasks[len(tasks)-1] - } - } - - lastTask := tasks[len(tasks)-1].(*MppTask) - - children := make([]base.PhysicalPlan, 0, len(tasks)) - for _, t := range tasks { - children = append(children, t.Plan()) - } - - p.SetChildren(children...) 
- - mppTask := &MppTask{ - p: p, - partTp: lastTask.partTp, - hashCols: lastTask.hashCols, - tblColHists: lastTask.tblColHists, - } - return mppTask -} - -func collectPartitionInfosFromMPPPlan(p *PhysicalTableReader, mppPlan base.PhysicalPlan) { - switch x := mppPlan.(type) { - case *PhysicalTableScan: - p.TableScanAndPartitionInfos = append(p.TableScanAndPartitionInfos, tableScanAndPartitionInfo{x, x.PlanPartInfo}) - default: - for _, ch := range mppPlan.Children() { - collectPartitionInfosFromMPPPlan(p, ch) - } - } -} - -func collectRowSizeFromMPPPlan(mppPlan base.PhysicalPlan) (rowSize float64) { - if mppPlan != nil && mppPlan.StatsInfo() != nil && mppPlan.StatsInfo().HistColl != nil { - return cardinality.GetAvgRowSize(mppPlan.SCtx(), mppPlan.StatsInfo().HistColl, mppPlan.Schema().Columns, false, false) - } - return 1 // use 1 as lower-bound for safety -} - -func accumulateNetSeekCost4MPP(p base.PhysicalPlan) (cost float64) { - if ts, ok := p.(*PhysicalTableScan); ok { - return float64(len(ts.Ranges)) * float64(len(ts.Columns)) * ts.SCtx().GetSessionVars().GetSeekFactor(ts.Table) - } - for _, c := range p.Children() { - cost += accumulateNetSeekCost4MPP(c) - } - return -} - -func tryExpandVirtualColumn(p base.PhysicalPlan) { - if ts, ok := p.(*PhysicalTableScan); ok { - ts.Columns = ExpandVirtualColumn(ts.Columns, ts.schema, ts.Table.Columns) - return - } - for _, child := range p.Children() { - tryExpandVirtualColumn(child) - } -} - -func (t *MppTask) needEnforceExchanger(prop *property.PhysicalProperty) bool { - switch prop.MPPPartitionTp { - case property.AnyType: - return false - case property.BroadcastType: - return true - case property.SinglePartitionType: - return t.partTp != property.SinglePartitionType - default: - if t.partTp != property.HashType { - return true - } - // TODO: consider equalivant class - // TODO: `prop.IsSubsetOf` is enough, instead of equal. - // for example, if already partitioned by hash(B,C), then same (A,B,C) must distribute on a same node. 
- if len(prop.MPPPartitionCols) != len(t.hashCols) { - return true - } - for i, col := range prop.MPPPartitionCols { - if !col.Equal(t.hashCols[i]) { - return true - } - } - return false - } -} - -func (t *MppTask) enforceExchanger(prop *property.PhysicalProperty) *MppTask { - if !t.needEnforceExchanger(prop) { - return t - } - return t.Copy().(*MppTask).enforceExchangerImpl(prop) -} - -func (t *MppTask) enforceExchangerImpl(prop *property.PhysicalProperty) *MppTask { - if collate.NewCollationEnabled() && !t.p.SCtx().GetSessionVars().HashExchangeWithNewCollation && prop.MPPPartitionTp == property.HashType { - for _, col := range prop.MPPPartitionCols { - if types.IsString(col.Col.RetType.GetType()) { - t.p.SCtx().GetSessionVars().RaiseWarningWhenMPPEnforced("MPP mode may be blocked because when `new_collation_enabled` is true, HashJoin or HashAgg with string key is not supported now.") - return &MppTask{} - } - } - } - ctx := t.p.SCtx() - sender := PhysicalExchangeSender{ - ExchangeType: prop.MPPPartitionTp.ToExchangeType(), - HashCols: prop.MPPPartitionCols, - }.Init(ctx, t.p.StatsInfo()) - - if ctx.GetSessionVars().ChooseMppVersion() >= kv.MppVersionV1 { - sender.CompressionMode = ctx.GetSessionVars().ChooseMppExchangeCompressionMode() - } - - sender.SetChildren(t.p) - receiver := PhysicalExchangeReceiver{}.Init(ctx, t.p.StatsInfo()) - receiver.SetChildren(sender) - return &MppTask{ - p: receiver, - partTp: prop.MPPPartitionTp, - hashCols: prop.MPPPartitionCols, - } -} diff --git a/pkg/planner/optimize.go b/pkg/planner/optimize.go index b5796eca459e0..cd7757176821e 100644 --- a/pkg/planner/optimize.go +++ b/pkg/planner/optimize.go @@ -246,7 +246,7 @@ func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in useBinding := enableUseBinding && isStmtNode && match if sessVars.StmtCtx.EnableOptimizerDebugTrace { - if val, _err_ := failpoint.Eval(_curpkg_("SetBindingTimeToZero")); _err_ == nil { + failpoint.Inject("SetBindingTimeToZero", func(val failpoint.Value) { if val.(bool) && bindings != nil { bindings = bindings.Copy() for i := range bindings { @@ -254,7 +254,7 @@ func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in bindings[i].UpdateTime = types.ZeroTime } } - } + }) debugtrace.RecordAnyValuesWithNames(pctx, "Used binding", useBinding, "Enable binding", enableUseBinding, @@ -459,19 +459,19 @@ var planBuilderPool = sync.Pool{ var optimizeCnt int func optimize(ctx context.Context, sctx pctx.PlanContext, node ast.Node, is infoschema.InfoSchema) (base.Plan, types.NameSlice, float64, error) { - if val, _err_ := failpoint.Eval(_curpkg_("checkOptimizeCountOne")); _err_ == nil { + failpoint.Inject("checkOptimizeCountOne", func(val failpoint.Value) { // only count the optimization for SQL with specified text if testSQL, ok := val.(string); ok && testSQL == node.OriginalText() { optimizeCnt++ if optimizeCnt > 1 { - return nil, nil, 0, errors.New("gofail wrong optimizerCnt error") + failpoint.Return(nil, nil, 0, errors.New("gofail wrong optimizerCnt error")) } } - } - if _, _err_ := failpoint.Eval(_curpkg_("mockHighLoadForOptimize")); _err_ == nil { + }) + failpoint.Inject("mockHighLoadForOptimize", func() { sqlPrefixes := []string{"select"} topsql.MockHighCPULoad(sctx.GetSessionVars().StmtCtx.OriginalSQL, sqlPrefixes, 10) - } + }) sessVars := sctx.GetSessionVars() if sessVars.StmtCtx.EnableOptimizerDebugTrace { debugtrace.EnterContextCommon(sctx) @@ -561,9 +561,9 @@ func buildLogicalPlan(ctx context.Context, sctx pctx.PlanContext, node ast.Node, 
sctx.GetSessionVars().MapScalarSubQ = nil sctx.GetSessionVars().MapHashCode2UniqueID4ExtendedCol = nil - if _, _err_ := failpoint.Eval(_curpkg_("mockRandomPlanID")); _err_ == nil { + failpoint.Inject("mockRandomPlanID", func() { sctx.GetSessionVars().PlanID.Store(rand.Int31n(1000)) // nolint:gosec - } + }) // reset fields about rewrite sctx.GetSessionVars().RewritePhaseInfo.Reset() diff --git a/pkg/planner/optimize.go__failpoint_stash__ b/pkg/planner/optimize.go__failpoint_stash__ deleted file mode 100644 index cd7757176821e..0000000000000 --- a/pkg/planner/optimize.go__failpoint_stash__ +++ /dev/null @@ -1,631 +0,0 @@ -// Copyright 2018 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package planner - -import ( - "context" - "math" - "math/rand" - "sync" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/bindinfo" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/planner/cascades" - pctx "github.com/pingcap/tidb/pkg/planner/context" - "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/planner/util/debugtrace" - "github.com/pingcap/tidb/pkg/privilege" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" - "github.com/pingcap/tidb/pkg/util/hint" - "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/topsql" - "github.com/pingcap/tidb/pkg/util/tracing" - "go.uber.org/zap" -) - -// IsReadOnly check whether the ast.Node is a read only statement. -func IsReadOnly(node ast.Node, vars *variable.SessionVars) bool { - if execStmt, isExecStmt := node.(*ast.ExecuteStmt); isExecStmt { - prepareStmt, err := core.GetPreparedStmt(execStmt, vars) - if err != nil { - logutil.BgLogger().Warn("GetPreparedStmt failed", zap.Error(err)) - return false - } - return ast.IsReadOnly(prepareStmt.PreparedAst.Stmt) - } - return ast.IsReadOnly(node) -} - -// getPlanFromNonPreparedPlanCache tries to get an available cached plan from the NonPrepared Plan Cache for this stmt. 
-func getPlanFromNonPreparedPlanCache(ctx context.Context, sctx sessionctx.Context, stmt ast.StmtNode, is infoschema.InfoSchema) (p base.Plan, ns types.NameSlice, ok bool, err error) { - stmtCtx := sctx.GetSessionVars().StmtCtx - _, isExplain := stmt.(*ast.ExplainStmt) - if !sctx.GetSessionVars().EnableNonPreparedPlanCache || // disabled - stmtCtx.InPreparedPlanBuilding || // already in cached plan rebuilding phase - stmtCtx.EnableOptimizerCETrace || stmtCtx.EnableOptimizeTrace || // in trace - stmtCtx.InRestrictedSQL || // is internal SQL - isExplain || // explain external - !sctx.GetSessionVars().DisableTxnAutoRetry || // txn-auto-retry - sctx.GetSessionVars().InMultiStmts || // in multi-stmt - (stmtCtx.InExplainStmt && stmtCtx.ExplainFormat != types.ExplainFormatPlanCache) { // in explain internal - return nil, nil, false, nil - } - ok, reason := core.NonPreparedPlanCacheableWithCtx(sctx.GetPlanCtx(), stmt, is) - if !ok { - if !isExplain && stmtCtx.InExplainStmt && stmtCtx.ExplainFormat == types.ExplainFormatPlanCache { - stmtCtx.AppendWarning(errors.NewNoStackErrorf("skip non-prepared plan-cache: %s", reason)) - } - return nil, nil, false, nil - } - - paramSQL, paramsVals, err := core.GetParamSQLFromAST(stmt) - if err != nil { - return nil, nil, false, err - } - if intest.InTest && ctx.Value(core.PlanCacheKeyTestIssue43667{}) != nil { // update the AST in the middle of the process - ctx.Value(core.PlanCacheKeyTestIssue43667{}).(func(stmt ast.StmtNode))(stmt) - } - val := sctx.GetSessionVars().GetNonPreparedPlanCacheStmt(paramSQL) - paramExprs := core.Params2Expressions(paramsVals) - - if val == nil { - // Create a new AST upon this parameterized SQL instead of using the original AST. - // Keep the original AST unchanged to avoid any side effect. - paramStmt, err := core.ParseParameterizedSQL(sctx, paramSQL) - if err != nil { - // This can happen rarely, cannot parse the parameterized(restored) SQL successfully, skip the plan cache in this case. - sctx.GetSessionVars().StmtCtx.AppendWarning(err) - return nil, nil, false, nil - } - // GeneratePlanCacheStmtWithAST may evaluate these parameters so set their values into SCtx in advance. - if err := core.SetParameterValuesIntoSCtx(sctx.GetPlanCtx(), true, nil, paramExprs); err != nil { - return nil, nil, false, err - } - cachedStmt, _, _, err := core.GeneratePlanCacheStmtWithAST(ctx, sctx, false, paramSQL, paramStmt, is) - if err != nil { - return nil, nil, false, err - } - sctx.GetSessionVars().AddNonPreparedPlanCacheStmt(paramSQL, cachedStmt) - val = cachedStmt - } - cachedStmt := val.(*core.PlanCacheStmt) - - cachedPlan, names, err := core.GetPlanFromPlanCache(ctx, sctx, true, is, cachedStmt, paramExprs) - if err != nil { - return nil, nil, false, err - } - - if intest.InTest && ctx.Value(core.PlanCacheKeyTestIssue47133{}) != nil { - ctx.Value(core.PlanCacheKeyTestIssue47133{}).(func(names []*types.FieldName))(names) - } - - return cachedPlan, names, true, nil -} - -// Optimize does optimization and creates a Plan. 
-func Optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is infoschema.InfoSchema) (plan base.Plan, slice types.NameSlice, retErr error) { - defer tracing.StartRegion(ctx, "planner.Optimize").End() - sessVars := sctx.GetSessionVars() - pctx := sctx.GetPlanCtx() - if sessVars.StmtCtx.EnableOptimizerDebugTrace { - debugtrace.EnterContextCommon(pctx) - defer debugtrace.LeaveContextCommon(pctx) - } - - if !sessVars.InRestrictedSQL && (variable.RestrictedReadOnly.Load() || variable.VarTiDBSuperReadOnly.Load()) { - allowed, err := allowInReadOnlyMode(pctx, node) - if err != nil { - return nil, nil, err - } - if !allowed { - return nil, nil, errors.Trace(plannererrors.ErrSQLInReadOnlyMode) - } - } - - if sessVars.SQLMode.HasStrictMode() && !IsReadOnly(node, sessVars) { - sessVars.StmtCtx.TiFlashEngineRemovedDueToStrictSQLMode = true - _, hasTiFlashAccess := sessVars.IsolationReadEngines[kv.TiFlash] - if hasTiFlashAccess { - delete(sessVars.IsolationReadEngines, kv.TiFlash) - } - defer func() { - sessVars.StmtCtx.TiFlashEngineRemovedDueToStrictSQLMode = false - if hasTiFlashAccess { - sessVars.IsolationReadEngines[kv.TiFlash] = struct{}{} - } - }() - } - - // handle the execute statement - if execAST, ok := node.(*ast.ExecuteStmt); ok { - p, names, err := OptimizeExecStmt(ctx, sctx, execAST, is) - return p, names, err - } - - tableHints := hint.ExtractTableHintsFromStmtNode(node, sessVars.StmtCtx) - originStmtHints, _, warns := hint.ParseStmtHints(tableHints, - setVarHintChecker, hypoIndexChecker(ctx, is), - sessVars.CurrentDB, byte(kv.ReplicaReadFollower)) - sessVars.StmtCtx.StmtHints = originStmtHints - for _, warn := range warns { - sessVars.StmtCtx.AppendWarning(warn) - } - - defer func() { - // Override the resource group if the hint is set. - if retErr == nil && sessVars.StmtCtx.StmtHints.HasResourceGroup { - if variable.EnableResourceControl.Load() { - hasPriv := true - // only check dynamic privilege when strict-mode is enabled. - if variable.EnableResourceControlStrictMode.Load() { - checker := privilege.GetPrivilegeManager(sctx) - if checker != nil { - hasRgAdminPriv := checker.RequestDynamicVerification(sctx.GetSessionVars().ActiveRoles, "RESOURCE_GROUP_ADMIN", false) - hasRgUserPriv := checker.RequestDynamicVerification(sctx.GetSessionVars().ActiveRoles, "RESOURCE_GROUP_USER", false) - hasPriv = hasRgAdminPriv || hasRgUserPriv - } - } - if hasPriv { - sessVars.StmtCtx.ResourceGroupName = sessVars.StmtCtx.StmtHints.ResourceGroup - // if we are in a txn, should update the txn resource name to let the txn - // commit with the hint resource group. 
- if txn, err := sctx.Txn(false); err == nil && txn != nil && txn.Valid() { - kv.SetTxnResourceGroup(txn, sessVars.StmtCtx.ResourceGroupName) - } - } else { - err := plannererrors.ErrSpecificAccessDenied.GenWithStackByArgs("SUPER or RESOURCE_GROUP_ADMIN or RESOURCE_GROUP_USER") - sessVars.StmtCtx.AppendWarning(err) - } - } else { - err := infoschema.ErrResourceGroupSupportDisabled - sessVars.StmtCtx.AppendWarning(err) - } - } - }() - - warns = warns[:0] - for name, val := range sessVars.StmtCtx.StmtHints.SetVars { - oldV, err := sessVars.SetSystemVarWithOldValAsRet(name, val) - if err != nil { - sessVars.StmtCtx.AppendWarning(err) - } - sessVars.StmtCtx.AddSetVarHintRestore(name, oldV) - } - if len(sessVars.StmtCtx.StmtHints.SetVars) > 0 { - sessVars.StmtCtx.SetSkipPlanCache("SET_VAR is used in the SQL") - } - - if _, isolationReadContainTiKV := sessVars.IsolationReadEngines[kv.TiKV]; isolationReadContainTiKV { - var fp base.Plan - if fpv, ok := sctx.Value(core.PointPlanKey).(core.PointPlanVal); ok { - // point plan is already tried in a multi-statement query. - fp = fpv.Plan - } else { - fp = core.TryFastPlan(pctx, node) - } - if fp != nil { - return fp, fp.OutputNames(), nil - } - } - if err := pctx.AdviseTxnWarmup(); err != nil { - return nil, nil, err - } - - enableUseBinding := sessVars.UsePlanBaselines - stmtNode, isStmtNode := node.(ast.StmtNode) - binding, match, scope := bindinfo.MatchSQLBinding(sctx, stmtNode) - var bindings bindinfo.Bindings - if match { - bindings = []bindinfo.Binding{binding} - } - - useBinding := enableUseBinding && isStmtNode && match - if sessVars.StmtCtx.EnableOptimizerDebugTrace { - failpoint.Inject("SetBindingTimeToZero", func(val failpoint.Value) { - if val.(bool) && bindings != nil { - bindings = bindings.Copy() - for i := range bindings { - bindings[i].CreateTime = types.ZeroTime - bindings[i].UpdateTime = types.ZeroTime - } - } - }) - debugtrace.RecordAnyValuesWithNames(pctx, - "Used binding", useBinding, - "Enable binding", enableUseBinding, - "IsStmtNode", isStmtNode, - "Matched", match, - "Scope", scope, - "Matched bindings", bindings, - ) - } - if isStmtNode { - // add the extra Limit after matching the bind record - stmtNode = core.TryAddExtraLimit(sctx, stmtNode) - node = stmtNode - } - - // try to get Plan from the NonPrepared Plan Cache - if sessVars.EnableNonPreparedPlanCache && - isStmtNode && - !useBinding { // TODO: support binding - cachedPlan, names, ok, err := getPlanFromNonPreparedPlanCache(ctx, sctx, stmtNode, is) - if err != nil { - return nil, nil, err - } - if ok { - return cachedPlan, names, nil - } - } - - var ( - names types.NameSlice - bestPlan, bestPlanFromBind base.Plan - chosenBinding bindinfo.Binding - err error - ) - if useBinding { - minCost := math.MaxFloat64 - var bindStmtHints hint.StmtHints - originHints := hint.CollectHint(stmtNode) - // bindings must be not nil when coming here, try to find the best binding. 
- for _, binding := range bindings { - if !binding.IsBindingEnabled() { - continue - } - if sessVars.StmtCtx.EnableOptimizerDebugTrace { - core.DebugTraceTryBinding(pctx, binding.Hint) - } - hint.BindHint(stmtNode, binding.Hint) - curStmtHints, _, curWarns := hint.ParseStmtHints(binding.Hint.GetStmtHints(), - setVarHintChecker, hypoIndexChecker(ctx, is), - sessVars.CurrentDB, byte(kv.ReplicaReadFollower)) - sessVars.StmtCtx.StmtHints = curStmtHints - // update session var by hint /set_var/ - for name, val := range sessVars.StmtCtx.StmtHints.SetVars { - oldV, err := sessVars.SetSystemVarWithOldValAsRet(name, val) - if err != nil { - sessVars.StmtCtx.AppendWarning(err) - } - sessVars.StmtCtx.AddSetVarHintRestore(name, oldV) - } - plan, curNames, cost, err := optimize(ctx, pctx, node, is) - if err != nil { - binding.Status = bindinfo.Invalid - handleInvalidBinding(ctx, pctx, scope, binding) - continue - } - if cost < minCost { - bindStmtHints, warns, minCost, names, bestPlanFromBind, chosenBinding = curStmtHints, curWarns, cost, curNames, plan, binding - } - } - if bestPlanFromBind == nil { - sessVars.StmtCtx.AppendWarning(errors.NewNoStackError("no plan generated from bindings")) - } else { - bestPlan = bestPlanFromBind - sessVars.StmtCtx.StmtHints = bindStmtHints - for _, warn := range warns { - sessVars.StmtCtx.AppendWarning(warn) - } - sessVars.StmtCtx.BindSQL = chosenBinding.BindSQL - sessVars.FoundInBinding = true - if sessVars.StmtCtx.InVerboseExplain { - sessVars.StmtCtx.AppendNote(errors.NewNoStackErrorf("Using the bindSQL: %v", chosenBinding.BindSQL)) - } else { - sessVars.StmtCtx.AppendExtraNote(errors.NewNoStackErrorf("Using the bindSQL: %v", chosenBinding.BindSQL)) - } - if len(tableHints) > 0 { - sessVars.StmtCtx.AppendWarning(errors.NewNoStackErrorf("The system ignores the hints in the current query and uses the hints specified in the bindSQL: %v", chosenBinding.BindSQL)) - } - } - // Restore the hint to avoid changing the stmt node. - hint.BindHint(stmtNode, originHints) - } - - if sessVars.StmtCtx.EnableOptimizerDebugTrace && bestPlanFromBind != nil { - core.DebugTraceBestBinding(pctx, chosenBinding.Hint) - } - // No plan found from the bindings, or the bindings are ignored. - if bestPlan == nil { - sessVars.StmtCtx.StmtHints = originStmtHints - bestPlan, names, _, err = optimize(ctx, pctx, node, is) - if err != nil { - return nil, nil, err - } - } - - // Add a baseline evolution task if: - // 1. the returned plan is from bindings; - // 2. the query is a select statement; - // 3. the original binding contains no read_from_storage hint; - // 4. the plan when ignoring bindings contains no tiflash hint; - // 5. the pending verified binding has not been added already; - savedStmtHints := sessVars.StmtCtx.StmtHints - defer func() { - sessVars.StmtCtx.StmtHints = savedStmtHints - }() - if sessVars.EvolvePlanBaselines && bestPlanFromBind != nil && - sessVars.SelectLimit == math.MaxUint64 { // do not evolve this query if sql_select_limit is enabled - // Check bestPlanFromBind firstly to avoid nil stmtNode. - if _, ok := stmtNode.(*ast.SelectStmt); ok && !bindings[0].Hint.ContainTableHint(hint.HintReadFromStorage) { - sessVars.StmtCtx.StmtHints = originStmtHints - defPlan, _, _, err := optimize(ctx, pctx, node, is) - if err != nil { - // Ignore this evolution task. 
- return bestPlan, names, nil - } - defPlanHints := core.GenHintsFromPhysicalPlan(defPlan) - for _, h := range defPlanHints { - if h.HintName.String() == hint.HintReadFromStorage { - return bestPlan, names, nil - } - } - } - } - - return bestPlan, names, nil -} - -// OptimizeForForeignKeyCascade does optimization and creates a Plan for foreign key cascade. -// Compare to Optimize, OptimizeForForeignKeyCascade only build plan by StmtNode, -// doesn't consider plan cache and plan binding, also doesn't do privilege check. -func OptimizeForForeignKeyCascade(ctx context.Context, sctx pctx.PlanContext, node ast.StmtNode, is infoschema.InfoSchema) (base.Plan, error) { - builder := planBuilderPool.Get().(*core.PlanBuilder) - defer planBuilderPool.Put(builder.ResetForReuse()) - hintProcessor := hint.NewQBHintHandler(sctx.GetSessionVars().StmtCtx) - builder.Init(sctx, is, hintProcessor) - p, err := builder.Build(ctx, node) - if err != nil { - return nil, err - } - if err := core.CheckTableLock(sctx, is, builder.GetVisitInfo()); err != nil { - return nil, err - } - return p, nil -} - -func allowInReadOnlyMode(sctx pctx.PlanContext, node ast.Node) (bool, error) { - pm := privilege.GetPrivilegeManager(sctx) - if pm == nil { - return true, nil - } - roles := sctx.GetSessionVars().ActiveRoles - // allow replication thread - // NOTE: it is required, whether SEM is enabled or not, only user with explicit RESTRICTED_REPLICA_WRITER_ADMIN granted can ignore the restriction, so we need to surpass the case that if SEM is not enabled, SUPER will has all privileges - if pm.HasExplicitlyGrantedDynamicPrivilege(roles, "RESTRICTED_REPLICA_WRITER_ADMIN", false) { - return true, nil - } - - switch node.(type) { - // allow change variables (otherwise can't unset read-only mode) - case *ast.SetStmt, - // allow analyze table - *ast.AnalyzeTableStmt, - *ast.UseStmt, - *ast.ShowStmt, - *ast.CreateBindingStmt, - *ast.DropBindingStmt, - *ast.PrepareStmt, - *ast.BeginStmt, - *ast.RollbackStmt: - return true, nil - case *ast.CommitStmt: - txn, err := sctx.Txn(true) - if err != nil { - return false, err - } - if !txn.IsReadOnly() { - return false, txn.Rollback() - } - return true, nil - } - - vars := sctx.GetSessionVars() - return IsReadOnly(node, vars), nil -} - -var planBuilderPool = sync.Pool{ - New: func() any { - return core.NewPlanBuilder() - }, -} - -// optimizeCnt is a global variable only used for test. 
-var optimizeCnt int - -func optimize(ctx context.Context, sctx pctx.PlanContext, node ast.Node, is infoschema.InfoSchema) (base.Plan, types.NameSlice, float64, error) { - failpoint.Inject("checkOptimizeCountOne", func(val failpoint.Value) { - // only count the optimization for SQL with specified text - if testSQL, ok := val.(string); ok && testSQL == node.OriginalText() { - optimizeCnt++ - if optimizeCnt > 1 { - failpoint.Return(nil, nil, 0, errors.New("gofail wrong optimizerCnt error")) - } - } - }) - failpoint.Inject("mockHighLoadForOptimize", func() { - sqlPrefixes := []string{"select"} - topsql.MockHighCPULoad(sctx.GetSessionVars().StmtCtx.OriginalSQL, sqlPrefixes, 10) - }) - sessVars := sctx.GetSessionVars() - if sessVars.StmtCtx.EnableOptimizerDebugTrace { - debugtrace.EnterContextCommon(sctx) - defer debugtrace.LeaveContextCommon(sctx) - } - - // build logical plan - hintProcessor := hint.NewQBHintHandler(sctx.GetSessionVars().StmtCtx) - node.Accept(hintProcessor) - defer hintProcessor.HandleUnusedViewHints() - builder := planBuilderPool.Get().(*core.PlanBuilder) - defer planBuilderPool.Put(builder.ResetForReuse()) - builder.Init(sctx, is, hintProcessor) - p, err := buildLogicalPlan(ctx, sctx, node, builder) - if err != nil { - return nil, nil, 0, err - } - - activeRoles := sessVars.ActiveRoles - // Check privilege. Maybe it's better to move this to the Preprocess, but - // we need the table information to check privilege, which is collected - // into the visitInfo in the logical plan builder. - if pm := privilege.GetPrivilegeManager(sctx); pm != nil { - visitInfo := core.VisitInfo4PrivCheck(ctx, is, node, builder.GetVisitInfo()) - if err := core.CheckPrivilege(activeRoles, pm, visitInfo); err != nil { - return nil, nil, 0, err - } - } - - if err := core.CheckTableLock(sctx, is, builder.GetVisitInfo()); err != nil { - return nil, nil, 0, err - } - - names := p.OutputNames() - - // Handle the non-logical plan statement. - logic, isLogicalPlan := p.(base.LogicalPlan) - if !isLogicalPlan { - return p, names, 0, nil - } - - core.RecheckCTE(logic) - - // Handle the logical plan statement, use cascades planner if enabled. 
- if sessVars.GetEnableCascadesPlanner() { - finalPlan, cost, err := cascades.DefaultOptimizer.FindBestPlan(sctx, logic) - return finalPlan, names, cost, err - } - - beginOpt := time.Now() - finalPlan, cost, err := core.DoOptimize(ctx, sctx, builder.GetOptFlag(), logic) - // TODO: capture plan replayer here if it matches sql and plan digest - - sessVars.DurationOptimization = time.Since(beginOpt) - return finalPlan, names, cost, err -} - -// OptimizeExecStmt to handle the "execute" statement -func OptimizeExecStmt(ctx context.Context, sctx sessionctx.Context, - execAst *ast.ExecuteStmt, is infoschema.InfoSchema) (base.Plan, types.NameSlice, error) { - builder := planBuilderPool.Get().(*core.PlanBuilder) - defer planBuilderPool.Put(builder.ResetForReuse()) - pctx := sctx.GetPlanCtx() - builder.Init(pctx, is, nil) - - p, err := buildLogicalPlan(ctx, pctx, execAst, builder) - if err != nil { - return nil, nil, err - } - exec, ok := p.(*core.Execute) - if !ok { - return nil, nil, errors.Errorf("invalid result plan type, should be Execute") - } - plan, names, err := core.GetPlanFromPlanCache(ctx, sctx, false, is, exec.PrepStmt, exec.Params) - if err != nil { - return nil, nil, err - } - exec.Plan = plan - exec.SetOutputNames(names) - exec.Stmt = exec.PrepStmt.PreparedAst.Stmt - return exec, names, nil -} - -func buildLogicalPlan(ctx context.Context, sctx pctx.PlanContext, node ast.Node, builder *core.PlanBuilder) (base.Plan, error) { - sctx.GetSessionVars().PlanID.Store(0) - sctx.GetSessionVars().PlanColumnID.Store(0) - sctx.GetSessionVars().MapScalarSubQ = nil - sctx.GetSessionVars().MapHashCode2UniqueID4ExtendedCol = nil - - failpoint.Inject("mockRandomPlanID", func() { - sctx.GetSessionVars().PlanID.Store(rand.Int31n(1000)) // nolint:gosec - }) - - // reset fields about rewrite - sctx.GetSessionVars().RewritePhaseInfo.Reset() - beginRewrite := time.Now() - p, err := builder.Build(ctx, node) - if err != nil { - return nil, err - } - sctx.GetSessionVars().RewritePhaseInfo.DurationRewrite = time.Since(beginRewrite) - if exec, ok := p.(*core.Execute); ok && exec.PrepStmt != nil { - sctx.GetSessionVars().StmtCtx.Tables = core.GetDBTableInfo(exec.PrepStmt.VisitInfos) - } else { - sctx.GetSessionVars().StmtCtx.Tables = core.GetDBTableInfo(builder.GetVisitInfo()) - } - return p, nil -} - -func handleInvalidBinding(ctx context.Context, sctx pctx.PlanContext, level string, binding bindinfo.Binding) { - sessionHandle := sctx.Value(bindinfo.SessionBindInfoKeyType).(bindinfo.SessionBindingHandle) - err := sessionHandle.DropSessionBinding(binding.SQLDigest) - if err != nil { - logutil.Logger(ctx).Info("drop session bindings failed") - } - if level == metrics.ScopeSession { - return - } - - globalHandle := domain.GetDomain(sctx).BindHandle() - globalHandle.AddInvalidGlobalBinding(binding) -} - -// setVarHintChecker checks whether the variable name in set_var hint is valid. 
-func setVarHintChecker(varName, hint string) (ok bool, warning error) { - sysVar := variable.GetSysVar(varName) - if sysVar == nil { // no such a variable - return false, plannererrors.ErrUnresolvedHintName.FastGenByArgs(varName, hint) - } - if !sysVar.IsHintUpdatableVerified { - warning = plannererrors.ErrNotHintUpdatable.FastGenByArgs(varName) - } - return true, warning -} - -func hypoIndexChecker(ctx context.Context, is infoschema.InfoSchema) func(db, tbl, col model.CIStr) (colOffset int, err error) { - return func(db, tbl, col model.CIStr) (colOffset int, err error) { - t, err := is.TableByName(ctx, db, tbl) - if err != nil { - return 0, errors.NewNoStackErrorf("table '%v.%v' doesn't exist", db, tbl) - } - for i, tblCol := range t.Cols() { - if tblCol.Name.L == col.L { - return i, nil - } - } - return 0, errors.NewNoStackErrorf("can't find column %v in table %v.%v", col, db, tbl) - } -} - -func init() { - core.OptimizeAstNode = Optimize - core.IsReadOnly = IsReadOnly - bindinfo.GetGlobalBindingHandle = func(sctx sessionctx.Context) bindinfo.GlobalBindingHandle { - return domain.GetDomain(sctx).BindHandle() - } -} diff --git a/pkg/server/binding__failpoint_binding__.go b/pkg/server/binding__failpoint_binding__.go deleted file mode 100644 index 884841332390a..0000000000000 --- a/pkg/server/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package server - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/server/conn.go b/pkg/server/conn.go index 241f5b8c79ff5..3cd8b5c052637 100644 --- a/pkg/server/conn.go +++ b/pkg/server/conn.go @@ -261,9 +261,9 @@ func (cc *clientConn) authSwitchRequest(ctx context.Context, plugin string) ([]b clientPlugin = authPluginImpl.Name } } - if _, _err_ := failpoint.Eval(_curpkg_("FakeAuthSwitch")); _err_ == nil { - return []byte(clientPlugin), nil - } + failpoint.Inject("FakeAuthSwitch", func() { + failpoint.Return([]byte(clientPlugin), nil) + }) enclen := 1 + len(clientPlugin) + 1 + len(cc.salt) + 1 data := cc.alloc.AllocWithLen(4, enclen) data = append(data, mysql.AuthSwitchRequest) // switch request @@ -488,11 +488,11 @@ func (cc *clientConn) readPacket() ([]byte, error) { } func (cc *clientConn) writePacket(data []byte) error { - if _, _err_ := failpoint.Eval(_curpkg_("FakeClientConn")); _err_ == nil { + failpoint.Inject("FakeClientConn", func() { if cc.pkt == nil { - return nil + failpoint.Return(nil) } - } + }) return cc.pkt.WritePacket(data) } @@ -858,10 +858,10 @@ func (cc *clientConn) checkAuthPlugin(ctx context.Context, resp *handshake.Respo logutil.Logger(ctx).Warn("Failed to get authentication method for user", zap.String("user", cc.user), zap.String("host", host)) } - if val, _err_ := failpoint.Eval(_curpkg_("FakeUser")); _err_ == nil { + failpoint.Inject("FakeUser", func(val failpoint.Value) { //nolint:forcetypeassert userplugin = val.(string) - } + }) if userplugin == mysql.AuthSocket { if !cc.isUnixSocket { return nil, servererr.ErrAccessDenied.FastGenByArgs(cc.user, host, hasPassword) @@ -1486,11 +1486,11 @@ func (cc *clientConn) flush(ctx context.Context) error { } } }() - if _, _err_ := failpoint.Eval(_curpkg_("FakeClientConn")); _err_ == nil { + failpoint.Inject("FakeClientConn", func() { if cc.pkt == nil { - return nil + 
failpoint.Return(nil) } - } + }) return cc.pkt.Flush() } @@ -2335,21 +2335,21 @@ func (cc *clientConn) writeChunks(ctx context.Context, rs resultset.ResultSet, b stmtDetail = stmtDetailRaw.(*execdetails.StmtExecDetails) } for { - if value, _err_ := failpoint.Eval(_curpkg_("fetchNextErr")); _err_ == nil { + failpoint.Inject("fetchNextErr", func(value failpoint.Value) { //nolint:forcetypeassert switch value.(string) { case "firstNext": - return firstNext, storeerr.ErrTiFlashServerTimeout + failpoint.Return(firstNext, storeerr.ErrTiFlashServerTimeout) case "secondNext": if !firstNext { - return firstNext, storeerr.ErrTiFlashServerTimeout + failpoint.Return(firstNext, storeerr.ErrTiFlashServerTimeout) } case "secondNextAndRetConflict": if !firstNext && validNextCount > 1 { - return firstNext, kv.ErrWriteConflict + failpoint.Return(firstNext, kv.ErrWriteConflict) } } - } + }) // Here server.tidbResultSet implements Next method. err := rs.Next(ctx, req) if err != nil { @@ -2549,9 +2549,9 @@ func (cc *clientConn) handleChangeUser(ctx context.Context, data []byte) error { Capability: cc.capability, } if fakeResp.AuthPlugin != "" { - if val, _err_ := failpoint.Eval(_curpkg_("ChangeUserAuthSwitch")); _err_ == nil { - return errors.Errorf("%v", val) - } + failpoint.Inject("ChangeUserAuthSwitch", func(val failpoint.Value) { + failpoint.Return(errors.Errorf("%v", val)) + }) newpass, err := cc.checkAuthPlugin(ctx, fakeResp) if err != nil { return err diff --git a/pkg/server/conn.go__failpoint_stash__ b/pkg/server/conn.go__failpoint_stash__ deleted file mode 100644 index 3cd8b5c052637..0000000000000 --- a/pkg/server/conn.go__failpoint_stash__ +++ /dev/null @@ -1,2748 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - -// The MIT License (MIT) -// -// Copyright (c) 2014 wandoulabs -// Copyright (c) 2014 siddontang -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of -// this software and associated documentation files (the "Software"), to deal in -// the Software without restriction, including without limitation the rights to -// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of -// the Software, and to permit persons to whom the Software is furnished to do so, -// subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. 
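The conn.go hunks above replace the expanded failpoint.Eval(_curpkg_(...)) form with the original failpoint.Inject markers. As a minimal sketch of how such a marker is typically driven from a test, assuming the pingcap/failpoint toolchain: the package name, failpoint name, and function below are illustrative only and are not part of this patch.

package demo

import (
	"errors"

	"github.com/pingcap/failpoint"
)

// fetchRows normally succeeds; the failpoint lets a test force an error path,
// mirroring the value-carrying markers (FakeUser, fetchNextErr) in the hunks above.
func fetchRows(limit int) (int, error) {
	// Marker form: this is a no-op until failpoint-ctl rewrites the markers
	// (producing binding/stash files like the ones this patch removes) and the
	// failpoint is switched on at run time.
	failpoint.Inject("mockFetchErr", func(val failpoint.Value) {
		if val.(string) == "firstNext" {
			failpoint.Return(0, errors.New("injected fetch error"))
		}
	})
	return limit, nil
}

A test would then activate it with an expression such as
failpoint.Enable("example.com/pkg/demo/mockFetchErr", `return("firstNext")`)
and defer failpoint.Disable("example.com/pkg/demo/mockFetchErr") to restore normal behavior.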
- -package server - -import ( - "bytes" - "context" - "crypto/tls" - "encoding/binary" - goerr "errors" - "fmt" - "io" - "net" - "os/user" - "runtime" - "runtime/pprof" - "runtime/trace" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - "unsafe" - - "github.com/klauspost/compress/zstd" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/domain/infosync" - "github.com/pingcap/tidb/pkg/domain/resourcegroup" - "github.com/pingcap/tidb/pkg/errno" - "github.com/pingcap/tidb/pkg/executor" - "github.com/pingcap/tidb/pkg/extension" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/auth" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/plugin" - "github.com/pingcap/tidb/pkg/privilege" - "github.com/pingcap/tidb/pkg/privilege/conn" - "github.com/pingcap/tidb/pkg/privilege/privileges/ldap" - servererr "github.com/pingcap/tidb/pkg/server/err" - "github.com/pingcap/tidb/pkg/server/handler/tikvhandler" - "github.com/pingcap/tidb/pkg/server/internal" - "github.com/pingcap/tidb/pkg/server/internal/column" - "github.com/pingcap/tidb/pkg/server/internal/dump" - "github.com/pingcap/tidb/pkg/server/internal/handshake" - "github.com/pingcap/tidb/pkg/server/internal/parse" - "github.com/pingcap/tidb/pkg/server/internal/resultset" - util2 "github.com/pingcap/tidb/pkg/server/internal/util" - server_metrics "github.com/pingcap/tidb/pkg/server/metrics" - "github.com/pingcap/tidb/pkg/session" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/sessiontxn" - storeerr "github.com/pingcap/tidb/pkg/store/driver/error" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/util/arena" - "github.com/pingcap/tidb/pkg/util/chunk" - contextutil "github.com/pingcap/tidb/pkg/util/context" - "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/hack" - "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/resourcegrouptag" - tlsutil "github.com/pingcap/tidb/pkg/util/tls" - "github.com/pingcap/tidb/pkg/util/topsql" - topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" - "github.com/pingcap/tidb/pkg/util/tracing" - "github.com/prometheus/client_golang/prometheus" - "github.com/tikv/client-go/v2/tikvrpc" - "github.com/tikv/client-go/v2/util" - "go.uber.org/zap" -) - -const ( - connStatusDispatching int32 = iota - connStatusReading - connStatusShutdown = variable.ConnStatusShutdown // Closed by server. - connStatusWaitShutdown = 3 // Notified by server to close. -) - -var ( - statusCompression = "Compression" - statusCompressionAlgorithm = "Compression_algorithm" - statusCompressionLevel = "Compression_level" -) - -var ( - // ConnectionInMemCounterForTest is a variable to count live connection object - ConnectionInMemCounterForTest = atomic.Int64{} -) - -// newClientConn creates a *clientConn object. 
-func newClientConn(s *Server) *clientConn { - cc := &clientConn{ - server: s, - connectionID: s.dom.NextConnID(), - collation: mysql.DefaultCollationID, - alloc: arena.NewAllocator(32 * 1024), - chunkAlloc: chunk.NewAllocator(), - status: connStatusDispatching, - lastActive: time.Now(), - authPlugin: mysql.AuthNativePassword, - quit: make(chan struct{}), - ppEnabled: s.cfg.ProxyProtocol.Networks != "", - } - - if intest.InTest { - ConnectionInMemCounterForTest.Add(1) - runtime.SetFinalizer(cc, func(*clientConn) { - ConnectionInMemCounterForTest.Add(-1) - }) - } - return cc -} - -// clientConn represents a connection between server and client, it maintains connection specific state, -// handles client query. -type clientConn struct { - pkt *internal.PacketIO // a helper to read and write data in packet format. - bufReadConn *util2.BufferedReadConn // a buffered-read net.Conn or buffered-read tls.Conn. - tlsConn *tls.Conn // TLS connection, nil if not TLS. - server *Server // a reference of server instance. - capability uint32 // client capability affects the way server handles client request. - connectionID uint64 // atomically allocated by a global variable, unique in process scope. - user string // user of the client. - dbname string // default database name. - salt []byte // random bytes used for authentication. - alloc arena.Allocator // an memory allocator for reducing memory allocation. - chunkAlloc chunk.Allocator - lastPacket []byte // latest sql query string, currently used for logging error. - // ShowProcess() and mysql.ComChangeUser both visit this field, ShowProcess() read information through - // the TiDBContext and mysql.ComChangeUser re-create it, so a lock is required here. - ctx struct { - sync.RWMutex - *TiDBContext // an interface to execute sql statements. - } - attrs map[string]string // attributes parsed from client handshake response. - serverHost string // server host - peerHost string // peer host - peerPort string // peer port - status int32 // dispatching/reading/shutdown/waitshutdown - lastCode uint16 // last error code - collation uint8 // collation used by client, may be different from the collation used by database. - lastActive time.Time // last active time - authPlugin string // default authentication plugin - isUnixSocket bool // connection is Unix Socket file - closeOnce sync.Once // closeOnce is used to make sure clientConn closes only once - rsEncoder *column.ResultEncoder // rsEncoder is used to encode the string result to different charsets - inputDecoder *util2.InputDecoder // inputDecoder is used to decode the different charsets of incoming strings to utf-8 - socketCredUID uint32 // UID from the other end of the Unix Socket - // mu is used for cancelling the execution of current transaction. - mu struct { - sync.RWMutex - cancelFunc context.CancelFunc - } - // quit is close once clientConn quit Run(). - quit chan struct{} - extensions *extension.SessionExtensions - - // Proxy Protocol Enabled - ppEnabled bool -} - -func (cc *clientConn) getCtx() *TiDBContext { - cc.ctx.RLock() - defer cc.ctx.RUnlock() - return cc.ctx.TiDBContext -} - -func (cc *clientConn) SetCtx(ctx *TiDBContext) { - cc.ctx.Lock() - cc.ctx.TiDBContext = ctx - cc.ctx.Unlock() -} - -func (cc *clientConn) String() string { - // MySQL converts a collation from u32 to char in the protocol, so the value could be wrong. 
It works fine for the - // default parameters (and libmysql seems not to provide any way to specify the collation other than the default - // one), so it's not a big problem. - collationStr := mysql.Collations[uint16(cc.collation)] - return fmt.Sprintf("id:%d, addr:%s status:%b, collation:%s, user:%s", - cc.connectionID, cc.bufReadConn.RemoteAddr(), cc.ctx.Status(), collationStr, cc.user, - ) -} - -func (cc *clientConn) setStatus(status int32) { - atomic.StoreInt32(&cc.status, status) - if ctx := cc.getCtx(); ctx != nil { - atomic.StoreInt32(&ctx.GetSessionVars().ConnectionStatus, status) - } -} - -func (cc *clientConn) getStatus() int32 { - return atomic.LoadInt32(&cc.status) -} - -func (cc *clientConn) CompareAndSwapStatus(oldStatus, newStatus int32) bool { - return atomic.CompareAndSwapInt32(&cc.status, oldStatus, newStatus) -} - -// authSwitchRequest is used by the server to ask the client to switch to a different authentication -// plugin. MySQL 8.0 libmysqlclient based clients by default always try `caching_sha2_password`, even -// when the server advertises the its default to be `mysql_native_password`. In addition to this switching -// may be needed on a per user basis as the authentication method is set per user. -// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_switch_request.html -// https://bugs.mysql.com/bug.php?id=93044 -func (cc *clientConn) authSwitchRequest(ctx context.Context, plugin string) ([]byte, error) { - clientPlugin := plugin - if plugin == mysql.AuthLDAPSASL { - clientPlugin += "_client" - } else if plugin == mysql.AuthLDAPSimple { - clientPlugin = mysql.AuthMySQLClearPassword - } else if authPluginImpl, ok := cc.extensions.GetAuthPlugin(plugin); ok { - if authPluginImpl.RequiredClientSidePlugin != "" { - clientPlugin = authPluginImpl.RequiredClientSidePlugin - } else { - // If RequiredClientSidePlugin is empty, use the plugin name as the client plugin. - clientPlugin = authPluginImpl.Name - } - } - failpoint.Inject("FakeAuthSwitch", func() { - failpoint.Return([]byte(clientPlugin), nil) - }) - enclen := 1 + len(clientPlugin) + 1 + len(cc.salt) + 1 - data := cc.alloc.AllocWithLen(4, enclen) - data = append(data, mysql.AuthSwitchRequest) // switch request - data = append(data, []byte(clientPlugin)...) - data = append(data, byte(0x00)) // requires null - if plugin == mysql.AuthLDAPSASL { - // append sasl auth method name - data = append(data, []byte(ldap.LDAPSASLAuthImpl.GetSASLAuthMethod())...) - data = append(data, byte(0x00)) - } else { - data = append(data, cc.salt...) - data = append(data, 0) - } - err := cc.writePacket(data) - if err != nil { - logutil.Logger(ctx).Debug("write response to client failed", zap.Error(err)) - return nil, err - } - err = cc.flush(ctx) - if err != nil { - logutil.Logger(ctx).Debug("flush response to client failed", zap.Error(err)) - return nil, err - } - resp, err := cc.readPacket() - if err != nil { - err = errors.SuspendStack(err) - if errors.Cause(err) == io.EOF { - logutil.Logger(ctx).Warn("authSwitchRequest response fail due to connection has be closed by client-side") - } else { - logutil.Logger(ctx).Warn("authSwitchRequest response fail", zap.Error(err)) - } - return nil, err - } - cc.authPlugin = plugin - return resp, nil -} - -// handshake works like TCP handshake, but in a higher level, it first writes initial packet to client, -// during handshake, client and server negotiate compatible features and do authentication. 
-// After handshake, client can send sql query to server. -func (cc *clientConn) handshake(ctx context.Context) error { - if err := cc.writeInitialHandshake(ctx); err != nil { - if errors.Cause(err) == io.EOF { - logutil.Logger(ctx).Debug("Could not send handshake due to connection has be closed by client-side") - } else { - logutil.Logger(ctx).Debug("Write init handshake to client fail", zap.Error(errors.SuspendStack(err))) - } - return err - } - if err := cc.readOptionalSSLRequestAndHandshakeResponse(ctx); err != nil { - err1 := cc.writeError(ctx, err) - if err1 != nil { - logutil.Logger(ctx).Debug("writeError failed", zap.Error(err1)) - } - return err - } - - // MySQL supports an "init_connect" query, which can be run on initial connection. - // The query must return a non-error or the client is disconnected. - if err := cc.initConnect(ctx); err != nil { - logutil.Logger(ctx).Warn("init_connect failed", zap.Error(err)) - initErr := servererr.ErrNewAbortingConnection.FastGenByArgs(cc.connectionID, "unconnected", cc.user, cc.peerHost, "init_connect command failed") - if err1 := cc.writeError(ctx, initErr); err1 != nil { - terror.Log(err1) - } - return initErr - } - - data := cc.alloc.AllocWithLen(4, 32) - data = append(data, mysql.OKHeader) - data = append(data, 0, 0) - if cc.capability&mysql.ClientProtocol41 > 0 { - data = dump.Uint16(data, mysql.ServerStatusAutocommit) - data = append(data, 0, 0) - } - - err := cc.writePacket(data) - cc.pkt.SetSequence(0) - if err != nil { - err = errors.SuspendStack(err) - logutil.Logger(ctx).Debug("write response to client failed", zap.Error(err)) - return err - } - - err = cc.flush(ctx) - if err != nil { - err = errors.SuspendStack(err) - logutil.Logger(ctx).Debug("flush response to client failed", zap.Error(err)) - return err - } - - // With mysql --compression-algorithms=zlib,zstd both flags are set, the result is Zlib - if cc.capability&mysql.ClientCompress > 0 { - cc.pkt.SetCompressionAlgorithm(mysql.CompressionZlib) - cc.ctx.SetCompressionAlgorithm(mysql.CompressionZlib) - } else if cc.capability&mysql.ClientZstdCompressionAlgorithm > 0 { - cc.pkt.SetCompressionAlgorithm(mysql.CompressionZstd) - cc.ctx.SetCompressionAlgorithm(mysql.CompressionZstd) - } - - return err -} - -func (cc *clientConn) Close() error { - // Be careful, this function should be re-entrant. It might be called more than once for a single connection. - // Any logic which is not idempotent should be in closeConn() and wrapped with `cc.closeOnce.Do`, like decresing - // metrics, releasing resources, etc. - // - // TODO: avoid calling this function multiple times. It's not intuitive that a connection can be closed multiple - // times. - cc.server.rwlock.Lock() - delete(cc.server.clients, cc.connectionID) - cc.server.rwlock.Unlock() - return closeConn(cc) -} - -// closeConn is idempotent and thread-safe. -// It will be called on the same `clientConn` more than once to avoid connection leak. -func closeConn(cc *clientConn) error { - var err error - cc.closeOnce.Do(func() { - if cc.connectionID > 0 { - cc.server.dom.ReleaseConnID(cc.connectionID) - cc.connectionID = 0 - } - if cc.bufReadConn != nil { - err := cc.bufReadConn.Close() - if err != nil { - // We need to expect connection might have already disconnected. - // This is because closeConn() might be called after a connection read-timeout. 
- logutil.Logger(context.Background()).Debug("could not close connection", zap.Error(err)) - } - } - - // Close statements and session - // At first, it'll decrese the count of connections in the resource group, update the corresponding gauge. - // Then it'll close the statements and session, which release advisory locks, row locks, etc. - if ctx := cc.getCtx(); ctx != nil { - resourceGroupName := ctx.GetSessionVars().ResourceGroupName - metrics.ConnGauge.WithLabelValues(resourceGroupName).Dec() - - err = ctx.Close() - } else { - metrics.ConnGauge.WithLabelValues(resourcegroup.DefaultResourceGroupName).Dec() - } - }) - return err -} - -func (cc *clientConn) closeWithoutLock() error { - delete(cc.server.clients, cc.connectionID) - return closeConn(cc) -} - -// writeInitialHandshake sends server version, connection ID, server capability, collation, server status -// and auth salt to the client. -func (cc *clientConn) writeInitialHandshake(ctx context.Context) error { - data := make([]byte, 4, 128) - - // min version 10 - data = append(data, 10) - // server version[00] - data = append(data, mysql.ServerVersion...) - data = append(data, 0) - // connection id - data = append(data, byte(cc.connectionID), byte(cc.connectionID>>8), byte(cc.connectionID>>16), byte(cc.connectionID>>24)) - // auth-plugin-data-part-1 - data = append(data, cc.salt[0:8]...) - // filler [00] - data = append(data, 0) - // capability flag lower 2 bytes, using default capability here - data = append(data, byte(cc.server.capability), byte(cc.server.capability>>8)) - // charset - if cc.collation == 0 { - cc.collation = uint8(mysql.DefaultCollationID) - } - data = append(data, cc.collation) - // status - data = dump.Uint16(data, mysql.ServerStatusAutocommit) - // below 13 byte may not be used - // capability flag upper 2 bytes, using default capability here - data = append(data, byte(cc.server.capability>>16), byte(cc.server.capability>>24)) - // length of auth-plugin-data - data = append(data, byte(len(cc.salt)+1)) - // reserved 10 [00] - data = append(data, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) - // auth-plugin-data-part-2 - data = append(data, cc.salt[8:]...) - data = append(data, 0) - // auth-plugin name - if ctx := cc.getCtx(); ctx == nil { - if err := cc.openSession(); err != nil { - return err - } - } - defAuthPlugin, err := cc.ctx.GetSessionVars().GetGlobalSystemVar(context.Background(), variable.DefaultAuthPlugin) - if err != nil { - return err - } - cc.authPlugin = defAuthPlugin - data = append(data, []byte(defAuthPlugin)...) - - // Close the session to force this to be re-opened after we parse the response. This is needed - // to ensure we use the collation and client flags from the response for the session. 
- if err = cc.ctx.Close(); err != nil { - return err - } - cc.SetCtx(nil) - - data = append(data, 0) - if err = cc.writePacket(data); err != nil { - return err - } - return cc.flush(ctx) -} - -func (cc *clientConn) readPacket() ([]byte, error) { - if cc.getCtx() != nil { - cc.pkt.SetMaxAllowedPacket(cc.ctx.GetSessionVars().MaxAllowedPacket) - } - return cc.pkt.ReadPacket() -} - -func (cc *clientConn) writePacket(data []byte) error { - failpoint.Inject("FakeClientConn", func() { - if cc.pkt == nil { - failpoint.Return(nil) - } - }) - return cc.pkt.WritePacket(data) -} - -func (cc *clientConn) getWaitTimeout(ctx context.Context) uint64 { - sessVars := cc.ctx.GetSessionVars() - if sessVars.InTxn() && sessVars.IdleTransactionTimeout > 0 { - return uint64(sessVars.IdleTransactionTimeout) - } - return cc.getSessionVarsWaitTimeout(ctx) -} - -// getSessionVarsWaitTimeout get session variable wait_timeout -func (cc *clientConn) getSessionVarsWaitTimeout(ctx context.Context) uint64 { - valStr, exists := cc.ctx.GetSessionVars().GetSystemVar(variable.WaitTimeout) - if !exists { - return variable.DefWaitTimeout - } - waitTimeout, err := strconv.ParseUint(valStr, 10, 64) - if err != nil { - logutil.Logger(ctx).Warn("get sysval wait_timeout failed, use default value", zap.Error(err)) - // if get waitTimeout error, use default value - return variable.DefWaitTimeout - } - return waitTimeout -} - -func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Context) error { - // Read a packet. It may be a SSLRequest or HandshakeResponse. - data, err := cc.readPacket() - if err != nil { - err = errors.SuspendStack(err) - if errors.Cause(err) == io.EOF { - logutil.Logger(ctx).Debug("wait handshake response fail due to connection has be closed by client-side") - } else { - logutil.Logger(ctx).Debug("wait handshake response fail", zap.Error(err)) - } - return err - } - - var resp handshake.Response41 - var pos int - - if len(data) < 2 { - logutil.Logger(ctx).Error("got malformed handshake response", zap.ByteString("packetData", data)) - return mysql.ErrMalformPacket - } - - capability := uint32(binary.LittleEndian.Uint16(data[:2])) - if capability&mysql.ClientProtocol41 <= 0 { - logutil.Logger(ctx).Error("ClientProtocol41 flag is not set, please upgrade client") - return servererr.ErrNotSupportedAuthMode - } - pos, err = parse.HandshakeResponseHeader(ctx, &resp, data) - if err != nil { - terror.Log(err) - return err - } - - // After read packets we should update the client's host and port to grab - // real client's IP and port from PROXY Protocol header if PROXY Protocol is enabled. - _, _, err = cc.PeerHost("", true) - if err != nil { - terror.Log(err) - return err - } - // If enable proxy protocol check audit plugins after update real IP - if cc.ppEnabled { - err = cc.server.checkAuditPlugin(cc) - if err != nil { - return err - } - } - - if resp.Capability&mysql.ClientSSL > 0 { - tlsConfig := (*tls.Config)(atomic.LoadPointer(&cc.server.tlsConfig)) - if tlsConfig != nil { - // The packet is a SSLRequest, let's switch to TLS. - if err = cc.upgradeToTLS(tlsConfig); err != nil { - return err - } - // Read the following HandshakeResponse packet. 
- data, err = cc.readPacket() - if err != nil { - logutil.Logger(ctx).Warn("read handshake response failure after upgrade to TLS", zap.Error(err)) - return err - } - pos, err = parse.HandshakeResponseHeader(ctx, &resp, data) - if err != nil { - terror.Log(err) - return err - } - } - } else if tlsutil.RequireSecureTransport.Load() && !cc.isUnixSocket { - // If it's not a socket connection, we should reject the connection - // because TLS is required. - err := servererr.ErrSecureTransportRequired.FastGenByArgs() - terror.Log(err) - return err - } - - // Read the remaining part of the packet. - err = parse.HandshakeResponseBody(ctx, &resp, data, pos) - if err != nil { - terror.Log(err) - return err - } - - cc.capability = resp.Capability & cc.server.capability - cc.user = resp.User - cc.dbname = resp.DBName - cc.collation = resp.Collation - cc.attrs = resp.Attrs - cc.pkt.SetZstdLevel(zstd.EncoderLevelFromZstd(resp.ZstdLevel)) - - err = cc.handleAuthPlugin(ctx, &resp) - if err != nil { - return err - } - - switch resp.AuthPlugin { - case mysql.AuthCachingSha2Password: - resp.Auth, err = cc.authSha(ctx, resp) - if err != nil { - return err - } - case mysql.AuthTiDBSM3Password: - resp.Auth, err = cc.authSM3(ctx, resp) - if err != nil { - return err - } - case mysql.AuthNativePassword: - case mysql.AuthSocket: - case mysql.AuthTiDBSessionToken: - case mysql.AuthTiDBAuthToken: - case mysql.AuthMySQLClearPassword: - case mysql.AuthLDAPSASL: - case mysql.AuthLDAPSimple: - default: - if _, ok := cc.extensions.GetAuthPlugin(resp.AuthPlugin); !ok { - return errors.New("Unknown auth plugin") - } - } - - err = cc.openSessionAndDoAuth(resp.Auth, resp.AuthPlugin, resp.ZstdLevel) - if err != nil { - logutil.Logger(ctx).Warn("open new session or authentication failure", zap.Error(err)) - } - return err -} - -func (cc *clientConn) handleAuthPlugin(ctx context.Context, resp *handshake.Response41) error { - if resp.Capability&mysql.ClientPluginAuth > 0 { - newAuth, err := cc.checkAuthPlugin(ctx, resp) - if err != nil { - logutil.Logger(ctx).Warn("failed to check the user authplugin", zap.Error(err)) - return err - } - if len(newAuth) > 0 { - resp.Auth = newAuth - } - - if _, ok := cc.extensions.GetAuthPlugin(resp.AuthPlugin); ok { - // The auth plugin has been registered, skip other checks. - return nil - } - switch resp.AuthPlugin { - case mysql.AuthCachingSha2Password: - case mysql.AuthTiDBSM3Password: - case mysql.AuthNativePassword: - case mysql.AuthSocket: - case mysql.AuthTiDBSessionToken: - case mysql.AuthMySQLClearPassword: - case mysql.AuthLDAPSASL: - case mysql.AuthLDAPSimple: - default: - logutil.Logger(ctx).Warn("Unknown Auth Plugin", zap.String("plugin", resp.AuthPlugin)) - } - } else { - // MySQL 5.1 and older clients don't support authentication plugins. - logutil.Logger(ctx).Warn("Client without Auth Plugin support; Please upgrade client") - _, err := cc.checkAuthPlugin(ctx, resp) - if err != nil { - return err - } - resp.AuthPlugin = mysql.AuthNativePassword - } - return nil -} - -// authSha implements the caching_sha2_password specific part of the protocol. -func (cc *clientConn) authSha(ctx context.Context, resp handshake.Response41) ([]byte, error) { - const ( - shaCommand = 1 - requestRsaPubKey = 2 // Not supported yet, only TLS is supported as secure channel. - fastAuthOk = 3 - fastAuthFail = 4 - ) - - // If no password is specified, we don't send the FastAuthFail to do the full authentication - // as that doesn't make sense without a password and confuses the client. 
- // https://github.com/pingcap/tidb/issues/40831 - if len(resp.Auth) == 0 { - return []byte{}, nil - } - - // Currently we always send a "FastAuthFail" as the cached part of the protocol isn't implemented yet. - // This triggers the client to send the full response. - err := cc.writePacket([]byte{0, 0, 0, 0, shaCommand, fastAuthFail}) - if err != nil { - logutil.Logger(ctx).Error("authSha packet write failed", zap.Error(err)) - return nil, err - } - err = cc.flush(ctx) - if err != nil { - logutil.Logger(ctx).Error("authSha packet flush failed", zap.Error(err)) - return nil, err - } - - data, err := cc.readPacket() - if err != nil { - logutil.Logger(ctx).Error("authSha packet read failed", zap.Error(err)) - return nil, err - } - return bytes.Trim(data, "\x00"), nil -} - -// authSM3 implements the tidb_sm3_password specific part of the protocol. -// tidb_sm3_password is very similar to caching_sha2_password. -func (cc *clientConn) authSM3(ctx context.Context, resp handshake.Response41) ([]byte, error) { - // If no password is specified, we don't send the FastAuthFail to do the full authentication - // as that doesn't make sense without a password and confuses the client. - // https://github.com/pingcap/tidb/issues/40831 - if len(resp.Auth) == 0 { - return []byte{}, nil - } - - err := cc.writePacket([]byte{0, 0, 0, 0, 1, 4}) // fastAuthFail - if err != nil { - logutil.Logger(ctx).Error("authSM3 packet write failed", zap.Error(err)) - return nil, err - } - err = cc.flush(ctx) - if err != nil { - logutil.Logger(ctx).Error("authSM3 packet flush failed", zap.Error(err)) - return nil, err - } - - data, err := cc.readPacket() - if err != nil { - logutil.Logger(ctx).Error("authSM3 packet read failed", zap.Error(err)) - return nil, err - } - return bytes.Trim(data, "\x00"), nil -} - -func (cc *clientConn) SessionStatusToString() string { - status := cc.ctx.Status() - inTxn, autoCommit := 0, 0 - if status&mysql.ServerStatusInTrans > 0 { - inTxn = 1 - } - if status&mysql.ServerStatusAutocommit > 0 { - autoCommit = 1 - } - return fmt.Sprintf("inTxn:%d, autocommit:%d", - inTxn, autoCommit, - ) -} - -func (cc *clientConn) openSession() error { - var tlsStatePtr *tls.ConnectionState - if cc.tlsConn != nil { - tlsState := cc.tlsConn.ConnectionState() - tlsStatePtr = &tlsState - } - ctx, err := cc.server.driver.OpenCtx(cc.connectionID, cc.capability, cc.collation, cc.dbname, tlsStatePtr, cc.extensions) - if err != nil { - return err - } - cc.SetCtx(ctx) - - err = cc.server.checkConnectionCount() - if err != nil { - return err - } - return nil -} - -func (cc *clientConn) openSessionAndDoAuth(authData []byte, authPlugin string, zstdLevel int) error { - // Open a context unless this was done before. 
- if ctx := cc.getCtx(); ctx == nil { - err := cc.openSession() - if err != nil { - return err - } - } - - hasPassword := "YES" - if len(authData) == 0 { - hasPassword = "NO" - } - - host, port, err := cc.PeerHost(hasPassword, false) - if err != nil { - return err - } - - if !cc.isUnixSocket && authPlugin == mysql.AuthSocket { - return servererr.ErrAccessDeniedNoPassword.FastGenByArgs(cc.user, host) - } - - userIdentity := &auth.UserIdentity{Username: cc.user, Hostname: host, AuthPlugin: authPlugin} - if err = cc.ctx.Auth(userIdentity, authData, cc.salt, cc); err != nil { - return err - } - cc.ctx.SetPort(port) - cc.ctx.SetCompressionLevel(zstdLevel) - if cc.dbname != "" { - _, err = cc.useDB(context.Background(), cc.dbname) - if err != nil { - return err - } - } - cc.ctx.SetSessionManager(cc.server) - return nil -} - -// mockOSUserForAuthSocketTest should only be used in test -var mockOSUserForAuthSocketTest atomic.Pointer[string] - -// Check if the Authentication Plugin of the server, client and user configuration matches -func (cc *clientConn) checkAuthPlugin(ctx context.Context, resp *handshake.Response41) ([]byte, error) { - // Open a context unless this was done before. - if ctx := cc.getCtx(); ctx == nil { - err := cc.openSession() - if err != nil { - return nil, err - } - } - - authData := resp.Auth - // tidb_session_token is always permitted and skips stored user plugin. - if resp.AuthPlugin == mysql.AuthTiDBSessionToken { - return authData, nil - } - hasPassword := "YES" - if len(authData) == 0 { - hasPassword = "NO" - } - - host, _, err := cc.PeerHost(hasPassword, false) - if err != nil { - return nil, err - } - // Find the identity of the user based on username and peer host. - identity, err := cc.ctx.MatchIdentity(cc.user, host) - if err != nil { - return nil, servererr.ErrAccessDenied.FastGenByArgs(cc.user, host, hasPassword) - } - // Get the plugin for the identity. - userplugin, err := cc.ctx.AuthPluginForUser(identity) - if err != nil { - logutil.Logger(ctx).Warn("Failed to get authentication method for user", - zap.String("user", cc.user), zap.String("host", host)) - } - failpoint.Inject("FakeUser", func(val failpoint.Value) { - //nolint:forcetypeassert - userplugin = val.(string) - }) - if userplugin == mysql.AuthSocket { - if !cc.isUnixSocket { - return nil, servererr.ErrAccessDenied.FastGenByArgs(cc.user, host, hasPassword) - } - resp.AuthPlugin = mysql.AuthSocket - user, err := user.LookupId(fmt.Sprint(cc.socketCredUID)) - if err != nil { - return nil, err - } - uname := user.Username - - if intest.InTest { - if p := mockOSUserForAuthSocketTest.Load(); p != nil { - uname = *p - } - } - - return []byte(uname), nil - } - if len(userplugin) == 0 { - // No user plugin set, assuming MySQL Native Password - // This happens if the account doesn't exist or if the account doesn't have - // a password set. 
- if resp.AuthPlugin != mysql.AuthNativePassword { - if resp.Capability&mysql.ClientPluginAuth > 0 { - resp.AuthPlugin = mysql.AuthNativePassword - authData, err := cc.authSwitchRequest(ctx, mysql.AuthNativePassword) - if err != nil { - return nil, err - } - return authData, nil - } - } - return nil, nil - } - - // If the authentication method send by the server (cc.authPlugin) doesn't match - // the plugin configured for the user account in the mysql.user.plugin column - // or if the authentication method send by the server doesn't match the authentication - // method send by the client (*authPlugin) then we need to switch the authentication - // method to match the one configured for that specific user. - if (cc.authPlugin != userplugin) || (cc.authPlugin != resp.AuthPlugin) { - if userplugin == mysql.AuthTiDBAuthToken { - userplugin = mysql.AuthMySQLClearPassword - } - if resp.Capability&mysql.ClientPluginAuth > 0 { - authData, err := cc.authSwitchRequest(ctx, userplugin) - if err != nil { - return nil, err - } - resp.AuthPlugin = userplugin - return authData, nil - } else if userplugin != mysql.AuthNativePassword { - // MySQL 5.1 and older don't support authentication plugins yet - return nil, servererr.ErrNotSupportedAuthMode - } - } - - return nil, nil -} - -func (cc *clientConn) PeerHost(hasPassword string, update bool) (host, port string, err error) { - // already get peer host - if len(cc.peerHost) > 0 { - // Proxy protocol enabled and not update - if cc.ppEnabled && !update { - return cc.peerHost, cc.peerPort, nil - } - // Proxy protocol not enabled - if !cc.ppEnabled { - return cc.peerHost, cc.peerPort, nil - } - } - host = variable.DefHostname - if cc.isUnixSocket { - cc.peerHost = host - cc.serverHost = host - return - } - addr := cc.bufReadConn.RemoteAddr().String() - host, port, err = net.SplitHostPort(addr) - if err != nil { - err = servererr.ErrAccessDenied.GenWithStackByArgs(cc.user, addr, hasPassword) - return - } - cc.peerHost = host - cc.peerPort = port - - serverAddr := cc.bufReadConn.LocalAddr().String() - serverHost, _, err := net.SplitHostPort(serverAddr) - if err != nil { - err = servererr.ErrAccessDenied.GenWithStackByArgs(cc.user, addr, hasPassword) - return - } - cc.serverHost = serverHost - - return -} - -// skipInitConnect follows MySQL's rules of when init-connect should be skipped. -// In 5.7 it is any user with SUPER privilege, but in 8.0 it is: -// - SUPER or the CONNECTION_ADMIN dynamic privilege. -// - (additional exception) users with expired passwords (not yet supported) -// In TiDB CONNECTION_ADMIN is satisfied by SUPER, so we only need to check once. -func (cc *clientConn) skipInitConnect() bool { - checker := privilege.GetPrivilegeManager(cc.ctx.Session) - activeRoles := cc.ctx.GetSessionVars().ActiveRoles - return checker != nil && checker.RequestDynamicVerification(activeRoles, "CONNECTION_ADMIN", false) -} - -// initResultEncoder initialize the result encoder for current connection. 
-func (cc *clientConn) initResultEncoder(ctx context.Context) { - chs, err := cc.ctx.GetSessionVars().GetSessionOrGlobalSystemVar(context.Background(), variable.CharacterSetResults) - if err != nil { - chs = "" - logutil.Logger(ctx).Warn("get character_set_results system variable failed", zap.Error(err)) - } - cc.rsEncoder = column.NewResultEncoder(chs) -} - -func (cc *clientConn) initInputEncoder(ctx context.Context) { - chs, err := cc.ctx.GetSessionVars().GetSessionOrGlobalSystemVar(context.Background(), variable.CharacterSetClient) - if err != nil { - chs = "" - logutil.Logger(ctx).Warn("get character_set_client system variable failed", zap.Error(err)) - } - cc.inputDecoder = util2.NewInputDecoder(chs) -} - -// initConnect runs the initConnect SQL statement if it has been specified. -// The semantics are MySQL compatible. -func (cc *clientConn) initConnect(ctx context.Context) error { - val, err := cc.ctx.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.InitConnect) - if err != nil { - return err - } - if val == "" || cc.skipInitConnect() { - return nil - } - logutil.Logger(ctx).Debug("init_connect starting") - stmts, err := cc.ctx.Parse(ctx, val) - if err != nil { - return err - } - for _, stmt := range stmts { - rs, err := cc.ctx.ExecuteStmt(ctx, stmt) - if err != nil { - return err - } - // init_connect does not care about the results, - // but they need to be drained because of lazy loading. - if rs != nil { - req := rs.NewChunk(nil) - for { - if err = rs.Next(ctx, req); err != nil { - return err - } - if req.NumRows() == 0 { - break - } - } - rs.Close() - } - } - logutil.Logger(ctx).Debug("init_connect complete") - return nil -} - -// Run reads client query and writes query result to client in for loop, if there is a panic during query handling, -// it will be recovered and log the panic error. -// This function returns and the connection is closed if there is an IO error or there is a panic. -func (cc *clientConn) Run(ctx context.Context) { - defer func() { - r := recover() - if r != nil { - logutil.Logger(ctx).Error("connection running loop panic", - zap.Stringer("lastSQL", getLastStmtInConn{cc}), - zap.String("err", fmt.Sprintf("%v", r)), - zap.Stack("stack"), - ) - err := cc.writeError(ctx, fmt.Errorf("%v", r)) - terror.Log(err) - metrics.PanicCounter.WithLabelValues(metrics.LabelSession).Inc() - } - if cc.getStatus() != connStatusShutdown { - err := cc.Close() - terror.Log(err) - } - - close(cc.quit) - }() - - parentCtx := ctx - var traceInfo *model.TraceInfo - // Usually, client connection status changes between [dispatching] <=> [reading]. - // When some event happens, server may notify this client connection by setting - // the status to special values, for example: kill or graceful shutdown. - // The client connection would detect the events when it fails to change status - // by CAS operation, it would then take some actions accordingly. - for { - sessVars := cc.ctx.GetSessionVars() - if alias := sessVars.SessionAlias; traceInfo == nil || traceInfo.SessionAlias != alias { - // We should reset the context trace info when traceInfo not inited or session alias changed. - traceInfo = &model.TraceInfo{ - ConnectionID: cc.connectionID, - SessionAlias: alias, - } - ctx = logutil.WithSessionAlias(parentCtx, sessVars.SessionAlias) - ctx = tracing.ContextWithTraceInfo(ctx, traceInfo) - } - - // Close connection between txn when we are going to shutdown server. 
- // Note the current implementation when shutting down, for an idle connection, the connection may block at readPacket() - // consider provider a way to close the connection directly after sometime if we can not read any data. - if cc.server.inShutdownMode.Load() { - if !sessVars.InTxn() { - return - } - } - - if !cc.CompareAndSwapStatus(connStatusDispatching, connStatusReading) || - // The judge below will not be hit by all means, - // But keep it stayed as a reminder and for the code reference for connStatusWaitShutdown. - cc.getStatus() == connStatusWaitShutdown { - return - } - - cc.alloc.Reset() - // close connection when idle time is more than wait_timeout - // default 28800(8h), FIXME: should not block at here when we kill the connection. - waitTimeout := cc.getWaitTimeout(ctx) - cc.pkt.SetReadTimeout(time.Duration(waitTimeout) * time.Second) - start := time.Now() - data, err := cc.readPacket() - if err != nil { - if terror.ErrorNotEqual(err, io.EOF) { - if netErr, isNetErr := errors.Cause(err).(net.Error); isNetErr && netErr.Timeout() { - if cc.getStatus() == connStatusWaitShutdown { - logutil.Logger(ctx).Info("read packet timeout because of killed connection") - } else { - idleTime := time.Since(start) - logutil.Logger(ctx).Info("read packet timeout, close this connection", - zap.Duration("idle", idleTime), - zap.Uint64("waitTimeout", waitTimeout), - zap.Error(err), - ) - } - } else if errors.ErrorEqual(err, servererr.ErrNetPacketTooLarge) { - err := cc.writeError(ctx, err) - if err != nil { - terror.Log(err) - } - } else { - errStack := errors.ErrorStack(err) - if !strings.Contains(errStack, "use of closed network connection") { - logutil.Logger(ctx).Warn("read packet failed, close this connection", - zap.Error(errors.SuspendStack(err))) - } - } - } - server_metrics.DisconnectByClientWithError.Inc() - return - } - - // Should check InTxn() to avoid execute `begin` stmt. 
- if cc.server.inShutdownMode.Load() { - if !cc.ctx.GetSessionVars().InTxn() { - return - } - } - - if !cc.CompareAndSwapStatus(connStatusReading, connStatusDispatching) { - return - } - - startTime := time.Now() - err = cc.dispatch(ctx, data) - cc.ctx.GetSessionVars().ClearAlloc(&cc.chunkAlloc, err != nil) - cc.chunkAlloc.Reset() - if err != nil { - cc.audit(plugin.Error) // tell the plugin API there was a dispatch error - if terror.ErrorEqual(err, io.EOF) { - cc.addMetrics(data[0], startTime, nil) - server_metrics.DisconnectNormal.Inc() - return - } else if terror.ErrResultUndetermined.Equal(err) { - logutil.Logger(ctx).Error("result undetermined, close this connection", zap.Error(err)) - server_metrics.DisconnectErrorUndetermined.Inc() - return - } else if terror.ErrCritical.Equal(err) { - metrics.CriticalErrorCounter.Add(1) - logutil.Logger(ctx).Fatal("critical error, stop the server", zap.Error(err)) - } - var txnMode string - if ctx := cc.getCtx(); ctx != nil { - txnMode = ctx.GetSessionVars().GetReadableTxnMode() - } - vars := cc.getCtx().GetSessionVars() - for _, dbName := range session.GetDBNames(vars) { - metrics.ExecuteErrorCounter.WithLabelValues(metrics.ExecuteErrorToLabel(err), dbName, vars.ResourceGroupName).Inc() - } - - if storeerr.ErrLockAcquireFailAndNoWaitSet.Equal(err) { - logutil.Logger(ctx).Debug("Expected error for FOR UPDATE NOWAIT", zap.Error(err)) - } else { - var timestamp uint64 - if ctx := cc.getCtx(); ctx != nil && ctx.GetSessionVars() != nil && ctx.GetSessionVars().TxnCtx != nil { - timestamp = ctx.GetSessionVars().TxnCtx.StartTS - if timestamp == 0 && ctx.GetSessionVars().TxnCtx.StaleReadTs > 0 { - // for state-read query. - timestamp = ctx.GetSessionVars().TxnCtx.StaleReadTs - } - } - logutil.Logger(ctx).Info("command dispatched failed", - zap.String("connInfo", cc.String()), - zap.String("command", mysql.Command2Str[data[0]]), - zap.String("status", cc.SessionStatusToString()), - zap.Stringer("sql", getLastStmtInConn{cc}), - zap.String("txn_mode", txnMode), - zap.Uint64("timestamp", timestamp), - zap.String("err", errStrForLog(err, cc.ctx.GetSessionVars().EnableRedactLog)), - ) - } - err1 := cc.writeError(ctx, err) - terror.Log(err1) - } - cc.addMetrics(data[0], startTime, err) - cc.pkt.SetSequence(0) - cc.pkt.SetCompressedSequence(0) - } -} - -func errStrForLog(err error, redactMode string) string { - if redactMode != errors.RedactLogDisable { - // currently, only ErrParse is considered when enableRedactLog because it may contain sensitive information like - // password or accesskey - if parser.ErrParse.Equal(err) { - return "fail to parse SQL, and must redact the whole error when enable log redaction" - } - } - var ret string - if kv.ErrKeyExists.Equal(err) || parser.ErrParse.Equal(err) || infoschema.ErrTableNotExists.Equal(err) { - // Do not log stack for duplicated entry error. - ret = err.Error() - } else { - ret = errors.ErrorStack(err) - } - return ret -} - -func (cc *clientConn) addMetrics(cmd byte, startTime time.Time, err error) { - if cmd == mysql.ComQuery && cc.ctx.Value(sessionctx.LastExecuteDDL) != nil { - // Don't take DDL execute time into account. - // It's already recorded by other metrics in ddl package. 
- return - } - - vars := cc.getCtx().GetSessionVars() - resourceGroupName := vars.ResourceGroupName - var counter prometheus.Counter - if len(resourceGroupName) == 0 || resourceGroupName == resourcegroup.DefaultResourceGroupName { - if err != nil && int(cmd) < len(server_metrics.QueryTotalCountErr) { - counter = server_metrics.QueryTotalCountErr[cmd] - } else if err == nil && int(cmd) < len(server_metrics.QueryTotalCountOk) { - counter = server_metrics.QueryTotalCountOk[cmd] - } - } - - if counter != nil { - counter.Inc() - } else { - label := server_metrics.CmdToString(cmd) - if err != nil { - metrics.QueryTotalCounter.WithLabelValues(label, "Error", resourceGroupName).Inc() - } else { - metrics.QueryTotalCounter.WithLabelValues(label, "OK", resourceGroupName).Inc() - } - } - - cost := time.Since(startTime) - sessionVar := cc.ctx.GetSessionVars() - affectedRows := cc.ctx.AffectedRows() - cc.ctx.GetTxnWriteThroughputSLI().FinishExecuteStmt(cost, affectedRows, sessionVar.InTxn()) - - stmtType := sessionVar.StmtCtx.StmtType - sqlType := metrics.LblGeneral - if stmtType != "" { - sqlType = stmtType - } - - switch sqlType { - case "Insert": - server_metrics.AffectedRowsCounterInsert.Add(float64(affectedRows)) - case "Replace": - server_metrics.AffectedRowsCounterReplace.Add(float64(affectedRows)) - case "Delete": - server_metrics.AffectedRowsCounterDelete.Add(float64(affectedRows)) - case "Update": - server_metrics.AffectedRowsCounterUpdate.Add(float64(affectedRows)) - } - - for _, dbName := range session.GetDBNames(vars) { - metrics.QueryDurationHistogram.WithLabelValues(sqlType, dbName, vars.StmtCtx.ResourceGroupName).Observe(cost.Seconds()) - metrics.QueryRPCHistogram.WithLabelValues(sqlType, dbName).Observe(float64(vars.StmtCtx.GetExecDetails().RequestCount)) - if vars.StmtCtx.GetExecDetails().ScanDetail != nil { - metrics.QueryProcessedKeyHistogram.WithLabelValues(sqlType, dbName).Observe(float64(vars.StmtCtx.GetExecDetails().ScanDetail.ProcessedKeys)) - } - } -} - -// dispatch handles client request based on command which is the first byte of the data. -// It also gets a token from server which is used to limit the concurrently handling clients. -// The most frequently used command is ComQuery. 
-func (cc *clientConn) dispatch(ctx context.Context, data []byte) error { - defer func() { - // reset killed for each request - cc.ctx.GetSessionVars().SQLKiller.Reset() - }() - t := time.Now() - if (cc.ctx.Status() & mysql.ServerStatusInTrans) > 0 { - server_metrics.ConnIdleDurationHistogramInTxn.Observe(t.Sub(cc.lastActive).Seconds()) - } else { - server_metrics.ConnIdleDurationHistogramNotInTxn.Observe(t.Sub(cc.lastActive).Seconds()) - } - - cfg := config.GetGlobalConfig() - if cfg.OpenTracing.Enable { - var r tracing.Region - r, ctx = tracing.StartRegionWithNewRootSpan(ctx, "server.dispatch") - defer r.End() - } - - var cancelFunc context.CancelFunc - ctx, cancelFunc = context.WithCancel(ctx) - cc.mu.Lock() - cc.mu.cancelFunc = cancelFunc - cc.mu.Unlock() - - cc.lastPacket = data - cmd := data[0] - data = data[1:] - if topsqlstate.TopSQLEnabled() { - defer pprof.SetGoroutineLabels(ctx) - } - if variable.EnablePProfSQLCPU.Load() { - label := getLastStmtInConn{cc}.PProfLabel() - if len(label) > 0 { - defer pprof.SetGoroutineLabels(ctx) - ctx = pprof.WithLabels(ctx, pprof.Labels("sql", label)) - pprof.SetGoroutineLabels(ctx) - } - } - if trace.IsEnabled() { - lc := getLastStmtInConn{cc} - sqlType := lc.PProfLabel() - if len(sqlType) > 0 { - var task *trace.Task - ctx, task = trace.NewTask(ctx, sqlType) - defer task.End() - - trace.Log(ctx, "sql", lc.String()) - ctx = logutil.WithTraceLogger(ctx, tracing.TraceInfoFromContext(ctx)) - - taskID := *(*uint64)(unsafe.Pointer(task)) - ctx = pprof.WithLabels(ctx, pprof.Labels("trace", strconv.FormatUint(taskID, 10))) - pprof.SetGoroutineLabels(ctx) - } - } - token := cc.server.getToken() - defer func() { - // if handleChangeUser failed, cc.ctx may be nil - if ctx := cc.getCtx(); ctx != nil { - ctx.SetProcessInfo("", t, mysql.ComSleep, 0) - } - - cc.server.releaseToken(token) - cc.lastActive = time.Now() - }() - - vars := cc.ctx.GetSessionVars() - // reset killed for each request - vars.SQLKiller.Reset() - if cmd < mysql.ComEnd { - cc.ctx.SetCommandValue(cmd) - } - - dataStr := string(hack.String(data)) - switch cmd { - case mysql.ComPing, mysql.ComStmtClose, mysql.ComStmtSendLongData, mysql.ComStmtReset, - mysql.ComSetOption, mysql.ComChangeUser: - cc.ctx.SetProcessInfo("", t, cmd, 0) - case mysql.ComInitDB: - cc.ctx.SetProcessInfo("use "+dataStr, t, cmd, 0) - } - - switch cmd { - case mysql.ComQuit: - return io.EOF - case mysql.ComInitDB: - node, err := cc.useDB(ctx, dataStr) - cc.onExtensionStmtEnd(node, false, err) - if err != nil { - return err - } - return cc.writeOK(ctx) - case mysql.ComQuery: // Most frequently used command. - // For issue 1989 - // Input payload may end with byte '\0', we didn't find related mysql document about it, but mysql - // implementation accept that case. So trim the last '\0' here as if the payload an EOF string. 
- // See http://dev.mysql.com/doc/internals/en/com-query.html - if len(data) > 0 && data[len(data)-1] == 0 { - data = data[:len(data)-1] - dataStr = string(hack.String(data)) - } - return cc.handleQuery(ctx, dataStr) - case mysql.ComFieldList: - return cc.handleFieldList(ctx, dataStr) - // ComCreateDB, ComDropDB - case mysql.ComRefresh: - return cc.handleRefresh(ctx, data[0]) - case mysql.ComShutdown: // redirect to SQL - if err := cc.handleQuery(ctx, "SHUTDOWN"); err != nil { - return err - } - return cc.writeOK(ctx) - case mysql.ComStatistics: - return cc.writeStats(ctx) - // ComProcessInfo, ComConnect, ComProcessKill, ComDebug - case mysql.ComPing: - return cc.writeOK(ctx) - case mysql.ComChangeUser: - return cc.handleChangeUser(ctx, data) - // ComBinlogDump, ComTableDump, ComConnectOut, ComRegisterSlave - case mysql.ComStmtPrepare: - // For issue 39132, same as ComQuery - if len(data) > 0 && data[len(data)-1] == 0 { - data = data[:len(data)-1] - dataStr = string(hack.String(data)) - } - return cc.HandleStmtPrepare(ctx, dataStr) - case mysql.ComStmtExecute: - return cc.handleStmtExecute(ctx, data) - case mysql.ComStmtSendLongData: - return cc.handleStmtSendLongData(data) - case mysql.ComStmtClose: - return cc.handleStmtClose(data) - case mysql.ComStmtReset: - return cc.handleStmtReset(ctx, data) - case mysql.ComSetOption: - return cc.handleSetOption(ctx, data) - case mysql.ComStmtFetch: - return cc.handleStmtFetch(ctx, data) - // ComDaemon, ComBinlogDumpGtid - case mysql.ComResetConnection: - return cc.handleResetConnection(ctx) - // ComEnd - default: - return mysql.NewErrf(mysql.ErrUnknown, "command %d not supported now", nil, cmd) - } -} - -func (cc *clientConn) writeStats(ctx context.Context) error { - var err error - var uptime int64 - info := tikvhandler.ServerInfo{} - info.ServerInfo, err = infosync.GetServerInfo() - if err != nil { - logutil.BgLogger().Error("Failed to get ServerInfo for uptime status", zap.Error(err)) - } else { - uptime = int64(time.Since(time.Unix(info.ServerInfo.StartTimestamp, 0)).Seconds()) - } - msg := []byte(fmt.Sprintf("Uptime: %d Threads: 0 Questions: 0 Slow queries: 0 Opens: 0 Flush tables: 0 Open tables: 0 Queries per second avg: 0.000", - uptime)) - data := cc.alloc.AllocWithLen(4, len(msg)) - data = append(data, msg...) - - err = cc.writePacket(data) - if err != nil { - return err - } - - return cc.flush(ctx) -} - -func (cc *clientConn) useDB(ctx context.Context, db string) (node ast.StmtNode, err error) { - // if input is "use `SELECT`", mysql client just send "SELECT" - // so we add `` around db. 
- stmts, err := cc.ctx.Parse(ctx, "use `"+db+"`") - if err != nil { - return nil, err - } - _, err = cc.ctx.ExecuteStmt(ctx, stmts[0]) - if err != nil { - return stmts[0], err - } - cc.dbname = db - return stmts[0], err -} - -func (cc *clientConn) flush(ctx context.Context) error { - var ( - stmtDetail *execdetails.StmtExecDetails - startTime time.Time - ) - if stmtDetailRaw := ctx.Value(execdetails.StmtExecDetailKey); stmtDetailRaw != nil { - //nolint:forcetypeassert - stmtDetail = stmtDetailRaw.(*execdetails.StmtExecDetails) - startTime = time.Now() - } - defer func() { - if stmtDetail != nil { - stmtDetail.WriteSQLRespDuration += time.Since(startTime) - } - trace.StartRegion(ctx, "FlushClientConn").End() - if ctx := cc.getCtx(); ctx != nil && ctx.WarningCount() > 0 { - for _, err := range ctx.GetWarnings() { - var warn *errors.Error - if ok := goerr.As(err.Err, &warn); ok { - code := uint16(warn.Code()) - errno.IncrementWarning(code, cc.user, cc.peerHost) - } - } - } - }() - failpoint.Inject("FakeClientConn", func() { - if cc.pkt == nil { - failpoint.Return(nil) - } - }) - return cc.pkt.Flush() -} - -func (cc *clientConn) writeOK(ctx context.Context) error { - return cc.writeOkWith(ctx, mysql.OKHeader, true, cc.ctx.Status()) -} - -func (cc *clientConn) writeOkWith(ctx context.Context, header byte, flush bool, status uint16) error { - msg := cc.ctx.LastMessage() - affectedRows := cc.ctx.AffectedRows() - lastInsertID := cc.ctx.LastInsertID() - warnCnt := cc.ctx.WarningCount() - - enclen := 0 - if len(msg) > 0 { - enclen = util2.LengthEncodedIntSize(uint64(len(msg))) + len(msg) - } - - data := cc.alloc.AllocWithLen(4, 32+enclen) - data = append(data, header) - data = dump.LengthEncodedInt(data, affectedRows) - data = dump.LengthEncodedInt(data, lastInsertID) - if cc.capability&mysql.ClientProtocol41 > 0 { - data = dump.Uint16(data, status) - data = dump.Uint16(data, warnCnt) - } - if enclen > 0 { - // although MySQL manual says the info message is string(https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html), - // it is actually string - data = dump.LengthEncodedString(data, []byte(msg)) - } - - err := cc.writePacket(data) - if err != nil { - return err - } - - if flush { - return cc.flush(ctx) - } - - return nil -} - -func (cc *clientConn) writeError(ctx context.Context, e error) error { - var ( - m *mysql.SQLError - te *terror.Error - ok bool - ) - originErr := errors.Cause(e) - if te, ok = originErr.(*terror.Error); ok { - m = terror.ToSQLError(te) - } else { - e := errors.Cause(originErr) - switch y := e.(type) { - case *terror.Error: - m = terror.ToSQLError(y) - default: - m = mysql.NewErrf(mysql.ErrUnknown, "%s", nil, e.Error()) - } - } - - cc.lastCode = m.Code - defer errno.IncrementError(m.Code, cc.user, cc.peerHost) - data := cc.alloc.AllocWithLen(4, 16+len(m.Message)) - data = append(data, mysql.ErrHeader) - data = append(data, byte(m.Code), byte(m.Code>>8)) - if cc.capability&mysql.ClientProtocol41 > 0 { - data = append(data, '#') - data = append(data, m.State...) - } - - data = append(data, m.Message...) - - err := cc.writePacket(data) - if err != nil { - return err - } - return cc.flush(ctx) -} - -// writeEOF writes an EOF packet or if ClientDeprecateEOF is set it -// writes an OK packet with EOF indicator. -// Note this function won't flush the stream because maybe there are more -// packets following it. -// serverStatus, a flag bit represents server information in the packet. -// Note: it is callers' responsibility to ensure correctness of serverStatus. 
-func (cc *clientConn) writeEOF(ctx context.Context, serverStatus uint16) error { - if cc.capability&mysql.ClientDeprecateEOF > 0 { - return cc.writeOkWith(ctx, mysql.EOFHeader, false, serverStatus) - } - - data := cc.alloc.AllocWithLen(4, 9) - - data = append(data, mysql.EOFHeader) - if cc.capability&mysql.ClientProtocol41 > 0 { - data = dump.Uint16(data, cc.ctx.WarningCount()) - data = dump.Uint16(data, serverStatus) - } - - err := cc.writePacket(data) - return err -} - -func (cc *clientConn) writeReq(ctx context.Context, filePath string) error { - data := cc.alloc.AllocWithLen(4, 5+len(filePath)) - data = append(data, mysql.LocalInFileHeader) - data = append(data, filePath...) - - err := cc.writePacket(data) - if err != nil { - return err - } - - return cc.flush(ctx) -} - -// getDataFromPath gets file contents from file path. -func (cc *clientConn) getDataFromPath(ctx context.Context, path string) ([]byte, error) { - err := cc.writeReq(ctx, path) - if err != nil { - return nil, err - } - var prevData, curData []byte - for { - curData, err = cc.readPacket() - if err != nil && terror.ErrorNotEqual(err, io.EOF) { - return nil, err - } - if len(curData) == 0 { - break - } - prevData = append(prevData, curData...) - } - return prevData, nil -} - -// handleLoadStats does the additional work after processing the 'load stats' query. -// It sends client a file path, then reads the file content from client, loads it into the storage. -func (cc *clientConn) handleLoadStats(ctx context.Context, loadStatsInfo *executor.LoadStatsInfo) error { - // If the server handles the load data request, the client has to set the ClientLocalFiles capability. - if cc.capability&mysql.ClientLocalFiles == 0 { - return servererr.ErrNotAllowedCommand - } - if loadStatsInfo == nil { - return errors.New("load stats: info is empty") - } - data, err := cc.getDataFromPath(ctx, loadStatsInfo.Path) - if err != nil { - return err - } - if len(data) == 0 { - return nil - } - return loadStatsInfo.Update(data) -} - -// handleIndexAdvise does the index advise work and returns the advise result for index. -func (cc *clientConn) handleIndexAdvise(ctx context.Context, indexAdviseInfo *executor.IndexAdviseInfo) error { - if cc.capability&mysql.ClientLocalFiles == 0 { - return servererr.ErrNotAllowedCommand - } - if indexAdviseInfo == nil { - return errors.New("Index Advise: info is empty") - } - - data, err := cc.getDataFromPath(ctx, indexAdviseInfo.Path) - if err != nil { - return err - } - if len(data) == 0 { - return errors.New("Index Advise: infile is empty") - } - - if err := indexAdviseInfo.GetIndexAdvice(data); err != nil { - return err - } - - // TODO: Write the rss []ResultSet. It will be done in another PR. 
- return nil -} - -func (cc *clientConn) handlePlanReplayerLoad(ctx context.Context, planReplayerLoadInfo *executor.PlanReplayerLoadInfo) error { - if cc.capability&mysql.ClientLocalFiles == 0 { - return servererr.ErrNotAllowedCommand - } - if planReplayerLoadInfo == nil { - return errors.New("plan replayer load: info is empty") - } - data, err := cc.getDataFromPath(ctx, planReplayerLoadInfo.Path) - if err != nil { - return err - } - if len(data) == 0 { - return nil - } - return planReplayerLoadInfo.Update(data) -} - -func (cc *clientConn) handlePlanReplayerDump(ctx context.Context, e *executor.PlanReplayerDumpInfo) error { - if cc.capability&mysql.ClientLocalFiles == 0 { - return servererr.ErrNotAllowedCommand - } - if e == nil { - return errors.New("plan replayer dump: executor is empty") - } - data, err := cc.getDataFromPath(ctx, e.Path) - if err != nil { - logutil.BgLogger().Error(err.Error()) - return err - } - if len(data) == 0 { - return nil - } - return e.DumpSQLsFromFile(ctx, data) -} - -func (cc *clientConn) audit(eventType plugin.GeneralEvent) { - err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { - audit := plugin.DeclareAuditManifest(p.Manifest) - if audit.OnGeneralEvent != nil { - cmd := mysql.Command2Str[byte(atomic.LoadUint32(&cc.ctx.GetSessionVars().CommandValue))] - ctx := context.WithValue(context.Background(), plugin.ExecStartTimeCtxKey, cc.ctx.GetSessionVars().StartTime) - audit.OnGeneralEvent(ctx, cc.ctx.GetSessionVars(), eventType, cmd) - } - return nil - }) - if err != nil { - terror.Log(err) - } -} - -// handleQuery executes the sql query string and writes result set or result ok to the client. -// As the execution time of this function represents the performance of TiDB, we do time log and metrics here. -// Some special queries like `load data` that does not return result, which is handled in handleFileTransInConn. -func (cc *clientConn) handleQuery(ctx context.Context, sql string) (err error) { - defer trace.StartRegion(ctx, "handleQuery").End() - sessVars := cc.ctx.GetSessionVars() - sc := sessVars.StmtCtx - prevWarns := sc.GetWarnings() - var stmts []ast.StmtNode - cc.ctx.GetSessionVars().SetAlloc(cc.chunkAlloc) - if stmts, err = cc.ctx.Parse(ctx, sql); err != nil { - cc.onExtensionSQLParseFailed(sql, err) - return err - } - - if len(stmts) == 0 { - return cc.writeOK(ctx) - } - - warns := sc.GetWarnings() - parserWarns := warns[len(prevWarns):] - - var pointPlans []base.Plan - cc.ctx.GetSessionVars().InMultiStmts = false - if len(stmts) > 1 { - // The client gets to choose if it allows multi-statements, and - // probably defaults OFF. This helps prevent against SQL injection attacks - // by early terminating the first statement, and then running an entirely - // new statement. - - capabilities := cc.ctx.GetSessionVars().ClientCapability - if capabilities&mysql.ClientMultiStatements < 1 { - // The client does not have multi-statement enabled. We now need to determine - // how to handle an unsafe situation based on the multiStmt sysvar. 
- switch cc.ctx.GetSessionVars().MultiStatementMode { - case variable.OffInt: - err = servererr.ErrMultiStatementDisabled - return err - case variable.OnInt: - // multi statement is fully permitted, do nothing - default: - warn := contextutil.SQLWarn{Level: contextutil.WarnLevelWarning, Err: servererr.ErrMultiStatementDisabled} - parserWarns = append(parserWarns, warn) - } - } - cc.ctx.GetSessionVars().InMultiStmts = true - - // Only pre-build point plans for multi-statement query - pointPlans, err = cc.prefetchPointPlanKeys(ctx, stmts, sql) - if err != nil { - for _, stmt := range stmts { - cc.onExtensionStmtEnd(stmt, false, err) - } - return err - } - metrics.NumOfMultiQueryHistogram.Observe(float64(len(stmts))) - } - if len(pointPlans) > 0 { - defer cc.ctx.ClearValue(plannercore.PointPlanKey) - } - var retryable bool - var lastStmt ast.StmtNode - var expiredStmtTaskID uint64 - for i, stmt := range stmts { - if lastStmt != nil { - cc.onExtensionStmtEnd(lastStmt, true, nil) - } - lastStmt = stmt - - // expiredTaskID is the task ID of the previous statement. When executing a stmt, - // the StmtCtx will be reinit and the TaskID will change. We can compare the StmtCtx.TaskID - // with the previous one to determine whether StmtCtx has been inited for the current stmt. - expiredStmtTaskID = sessVars.StmtCtx.TaskID - - if len(pointPlans) > 0 { - // Save the point plan in Session, so we don't need to build the point plan again. - cc.ctx.SetValue(plannercore.PointPlanKey, plannercore.PointPlanVal{Plan: pointPlans[i]}) - } - retryable, err = cc.handleStmt(ctx, stmt, parserWarns, i == len(stmts)-1) - if err != nil { - action, txnErr := sessiontxn.GetTxnManager(&cc.ctx).OnStmtErrorForNextAction(ctx, sessiontxn.StmtErrAfterQuery, err) - if txnErr != nil { - err = txnErr - break - } - - if retryable && action == sessiontxn.StmtActionRetryReady { - cc.ctx.GetSessionVars().RetryInfo.Retrying = true - _, err = cc.handleStmt(ctx, stmt, parserWarns, i == len(stmts)-1) - cc.ctx.GetSessionVars().RetryInfo.Retrying = false - if err != nil { - break - } - continue - } - if !retryable || !errors.ErrorEqual(err, storeerr.ErrTiFlashServerTimeout) { - break - } - _, allowTiFlashFallback := cc.ctx.GetSessionVars().AllowFallbackToTiKV[kv.TiFlash] - if !allowTiFlashFallback { - break - } - // When the TiFlash server seems down, we append a warning to remind the user to check the status of the TiFlash - // server and fallback to TiKV. - warns := append(parserWarns, contextutil.SQLWarn{Level: contextutil.WarnLevelError, Err: err}) - delete(cc.ctx.GetSessionVars().IsolationReadEngines, kv.TiFlash) - _, err = cc.handleStmt(ctx, stmt, warns, i == len(stmts)-1) - cc.ctx.GetSessionVars().IsolationReadEngines[kv.TiFlash] = struct{}{} - if err != nil { - break - } - } - } - - if lastStmt != nil { - cc.onExtensionStmtEnd(lastStmt, sessVars.StmtCtx.TaskID != expiredStmtTaskID, err) - } - - return err -} - -// prefetchPointPlanKeys extracts the point keys in multi-statement query, -// use BatchGet to get the keys, so the values will be cached in the snapshot cache, save RPC call cost. -// For pessimistic transaction, the keys will be batch locked. -func (cc *clientConn) prefetchPointPlanKeys(ctx context.Context, stmts []ast.StmtNode, sqls string) ([]base.Plan, error) { - txn, err := cc.ctx.Txn(false) - if err != nil { - return nil, err - } - if !txn.Valid() { - // Only prefetch in-transaction query for simplicity. - // Later we can support out-transaction multi-statement query. 
- return nil, nil - } - vars := cc.ctx.GetSessionVars() - if vars.TxnCtx.IsPessimistic { - if vars.IsIsolation(ast.ReadCommitted) { - // TODO: to support READ-COMMITTED, we need to avoid getting new TS for each statement in the query. - return nil, nil - } - if vars.TxnCtx.GetForUpdateTS() != vars.TxnCtx.StartTS { - // Do not handle the case that ForUpdateTS is changed for simplicity. - return nil, nil - } - } - pointPlans := make([]base.Plan, len(stmts)) - var idxKeys []kv.Key //nolint: prealloc - var rowKeys []kv.Key //nolint: prealloc - isCommonHandle := make(map[string]bool, 0) - - handlePlan := func(sctx sessionctx.Context, p base.PhysicalPlan, resetStmtCtxFn func()) error { - var tableID int64 - switch v := p.(type) { - case *plannercore.PointGetPlan: - v.PrunePartitions(sctx) - tableID = executor.GetPhysID(v.TblInfo, v.PartitionIdx) - if v.IndexInfo != nil { - resetStmtCtxFn() - idxKey, err1 := plannercore.EncodeUniqueIndexKey(cc.getCtx(), v.TblInfo, v.IndexInfo, v.IndexValues, tableID) - if err1 != nil { - return err1 - } - idxKeys = append(idxKeys, idxKey) - isCommonHandle[string(hack.String(idxKey))] = v.TblInfo.IsCommonHandle - } else { - rowKeys = append(rowKeys, tablecodec.EncodeRowKeyWithHandle(tableID, v.Handle)) - } - case *plannercore.BatchPointGetPlan: - _, isTableDual := v.PrunePartitionsAndValues(sctx) - if isTableDual { - return nil - } - pi := v.TblInfo.GetPartitionInfo() - getPhysID := func(i int) int64 { - if pi == nil || i >= len(v.PartitionIdxs) { - return v.TblInfo.ID - } - return executor.GetPhysID(v.TblInfo, &v.PartitionIdxs[i]) - } - if v.IndexInfo != nil { - resetStmtCtxFn() - for i, idxVals := range v.IndexValues { - idxKey, err1 := plannercore.EncodeUniqueIndexKey(cc.getCtx(), v.TblInfo, v.IndexInfo, idxVals, getPhysID(i)) - if err1 != nil { - return err1 - } - idxKeys = append(idxKeys, idxKey) - isCommonHandle[string(hack.String(idxKey))] = v.TblInfo.IsCommonHandle - } - } else { - for i, handle := range v.Handles { - rowKeys = append(rowKeys, tablecodec.EncodeRowKeyWithHandle(getPhysID(i), handle)) - } - } - } - return nil - } - - sc := vars.StmtCtx - for i, stmt := range stmts { - if _, ok := stmt.(*ast.UseStmt); ok { - // If there is a "use db" statement, we shouldn't cache even if it's possible. - // Consider the scenario where there are statements that could execute on multiple - // schemas, but the schema is actually different. - return nil, nil - } - // TODO: the preprocess is run twice, we should find some way to avoid do it again. - if err = plannercore.Preprocess(ctx, cc.getCtx(), stmt); err != nil { - // error might happen, see https://github.com/pingcap/tidb/issues/39664 - return nil, nil - } - p := plannercore.TryFastPlan(cc.ctx.Session.GetPlanCtx(), stmt) - pointPlans[i] = p - if p == nil { - continue - } - // Only support Update and Delete for now. - // TODO: support other point plans. 
- switch x := p.(type) { - case *plannercore.Update: - //nolint:forcetypeassert - updateStmt, ok := stmt.(*ast.UpdateStmt) - if !ok { - logutil.BgLogger().Warn("unexpected statement type for Update plan", - zap.String("type", fmt.Sprintf("%T", stmt))) - continue - } - err = handlePlan(cc.ctx.Session, x.SelectPlan, func() { - executor.ResetUpdateStmtCtx(sc, updateStmt, vars) - }) - if err != nil { - return nil, err - } - case *plannercore.Delete: - deleteStmt, ok := stmt.(*ast.DeleteStmt) - if !ok { - logutil.BgLogger().Warn("unexpected statement type for Delete plan", - zap.String("type", fmt.Sprintf("%T", stmt))) - continue - } - err = handlePlan(cc.ctx.Session, x.SelectPlan, func() { - executor.ResetDeleteStmtCtx(sc, deleteStmt, vars) - }) - if err != nil { - return nil, err - } - } - } - if len(idxKeys) == 0 && len(rowKeys) == 0 { - return pointPlans, nil - } - snapshot := txn.GetSnapshot() - setResourceGroupTaggerForMultiStmtPrefetch(snapshot, sqls) - idxVals, err1 := snapshot.BatchGet(ctx, idxKeys) - if err1 != nil { - return nil, err1 - } - for idxKey, idxVal := range idxVals { - h, err2 := tablecodec.DecodeHandleInIndexValue(idxVal) - if err2 != nil { - return nil, err2 - } - tblID := tablecodec.DecodeTableID(hack.Slice(idxKey)) - rowKeys = append(rowKeys, tablecodec.EncodeRowKeyWithHandle(tblID, h)) - } - if vars.TxnCtx.IsPessimistic { - allKeys := append(rowKeys, idxKeys...) - err = executor.LockKeys(ctx, cc.getCtx(), vars.LockWaitTimeout, allKeys...) - if err != nil { - // suppress the lock error, we are not going to handle it here for simplicity. - err = nil - logutil.BgLogger().Warn("lock keys error on prefetch", zap.Error(err)) - } - } else { - _, err = snapshot.BatchGet(ctx, rowKeys) - if err != nil { - return nil, err - } - } - return pointPlans, nil -} - -func setResourceGroupTaggerForMultiStmtPrefetch(snapshot kv.Snapshot, sqls string) { - if !topsqlstate.TopSQLEnabled() { - return - } - normalized, digest := parser.NormalizeDigest(sqls) - topsql.AttachAndRegisterSQLInfo(context.Background(), normalized, digest, false) - snapshot.SetOption(kv.ResourceGroupTagger, tikvrpc.ResourceGroupTagger(func(req *tikvrpc.Request) { - if req == nil { - return - } - if len(normalized) == 0 { - return - } - req.ResourceGroupTag = resourcegrouptag.EncodeResourceGroupTag(digest, nil, - resourcegrouptag.GetResourceGroupLabelByKey(resourcegrouptag.GetFirstKeyFromRequest(req))) - })) -} - -// The first return value indicates whether the call of handleStmt has no side effect and can be retried. -// Currently, the first return value is used to fall back to TiKV when TiFlash is down. 
-func (cc *clientConn) handleStmt( - ctx context.Context, stmt ast.StmtNode, - warns []contextutil.SQLWarn, lastStmt bool, -) (bool, error) { - ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) - ctx = context.WithValue(ctx, util.ExecDetailsKey, &util.ExecDetails{}) - ctx = context.WithValue(ctx, util.RUDetailsCtxKey, util.NewRUDetails()) - reg := trace.StartRegion(ctx, "ExecuteStmt") - cc.audit(plugin.Starting) - - // if stmt is load data stmt, store the channel that reads from the conn - // into the ctx for executor to use - if s, ok := stmt.(*ast.LoadDataStmt); ok { - if s.FileLocRef == ast.FileLocClient { - err := cc.preprocessLoadDataLocal(ctx) - defer cc.postprocessLoadDataLocal() - if err != nil { - return false, err - } - } - } - - rs, err := cc.ctx.ExecuteStmt(ctx, stmt) - reg.End() - // - If rs is not nil, the statement tracker detachment from session tracker - // is done in the `rs.Close` in most cases. - // - If the rs is nil and err is not nil, the detachment will be done in - // the `handleNoDelay`. - if rs != nil { - defer rs.Close() - } - - if err != nil { - // If error is returned during the planner phase or the executor.Open - // phase, the rs will be nil, and StmtCtx.MemTracker StmtCtx.DiskTracker - // will not be detached. We need to detach them manually. - if sv := cc.ctx.GetSessionVars(); sv != nil && sv.StmtCtx != nil { - sv.StmtCtx.DetachMemDiskTracker() - } - return true, err - } - - status := cc.ctx.Status() - if lastStmt { - cc.ctx.GetSessionVars().StmtCtx.AppendWarnings(warns) - } else { - status |= mysql.ServerMoreResultsExists - } - - if rs != nil { - if cc.getStatus() == connStatusShutdown { - return false, exeerrors.ErrQueryInterrupted - } - cc.ctx.GetSessionVars().SQLKiller.SetFinishFunc( - func() { - //nolint: errcheck - rs.Finish() - }) - cc.ctx.GetSessionVars().SQLKiller.InWriteResultSet.Store(true) - defer cc.ctx.GetSessionVars().SQLKiller.InWriteResultSet.Store(false) - defer cc.ctx.GetSessionVars().SQLKiller.ClearFinishFunc() - if retryable, err := cc.writeResultSet(ctx, rs, false, status, 0); err != nil { - return retryable, err - } - return false, nil - } - - handled, err := cc.handleFileTransInConn(ctx, status) - if handled { - if execStmt := cc.ctx.Value(session.ExecStmtVarKey); execStmt != nil { - //nolint:forcetypeassert - execStmt.(*executor.ExecStmt).FinishExecuteStmt(0, err, false) - } - } - return false, err -} - -// Preprocess LOAD DATA. Load data from a local file requires reading from the connection. -// The function pass a builder to build the connection reader to the context, -// which will be used in LoadDataExec. 
-func (cc *clientConn) preprocessLoadDataLocal(ctx context.Context) error { - if cc.capability&mysql.ClientLocalFiles == 0 { - return servererr.ErrNotAllowedCommand - } - - var readerBuilder executor.LoadDataReaderBuilder = func(filepath string) ( - io.ReadCloser, error, - ) { - err := cc.writeReq(ctx, filepath) - if err != nil { - return nil, err - } - - drained := false - r, w := io.Pipe() - - go func() { - var errOccurred error - - defer func() { - if errOccurred != nil { - // Continue reading packets to drain the connection - for !drained { - data, err := cc.readPacket() - if err != nil { - logutil.Logger(ctx).Error( - "drain connection failed in load data", - zap.Error(err), - ) - break - } - if len(data) == 0 { - drained = true - } - } - } - err := w.CloseWithError(errOccurred) - if err != nil { - logutil.Logger(ctx).Error( - "close pipe failed in `load data`", - zap.Error(err), - ) - } - }() - - for { - data, err := cc.readPacket() - if err != nil { - errOccurred = err - return - } - - if len(data) == 0 { - drained = true - return - } - - // Write all content in `data` - for len(data) > 0 { - n, err := w.Write(data) - if err != nil { - errOccurred = err - return - } - data = data[n:] - } - } - }() - - return r, nil - } - - cc.ctx.SetValue(executor.LoadDataReaderBuilderKey, readerBuilder) - - return nil -} - -func (cc *clientConn) postprocessLoadDataLocal() { - cc.ctx.ClearValue(executor.LoadDataReaderBuilderKey) -} - -func (cc *clientConn) handleFileTransInConn(ctx context.Context, status uint16) (bool, error) { - handled := false - - loadStats := cc.ctx.Value(executor.LoadStatsVarKey) - if loadStats != nil { - handled = true - defer cc.ctx.SetValue(executor.LoadStatsVarKey, nil) - //nolint:forcetypeassert - if err := cc.handleLoadStats(ctx, loadStats.(*executor.LoadStatsInfo)); err != nil { - return handled, err - } - } - - indexAdvise := cc.ctx.Value(executor.IndexAdviseVarKey) - if indexAdvise != nil { - handled = true - defer cc.ctx.SetValue(executor.IndexAdviseVarKey, nil) - //nolint:forcetypeassert - if err := cc.handleIndexAdvise(ctx, indexAdvise.(*executor.IndexAdviseInfo)); err != nil { - return handled, err - } - } - - planReplayerLoad := cc.ctx.Value(executor.PlanReplayerLoadVarKey) - if planReplayerLoad != nil { - handled = true - defer cc.ctx.SetValue(executor.PlanReplayerLoadVarKey, nil) - //nolint:forcetypeassert - if err := cc.handlePlanReplayerLoad(ctx, planReplayerLoad.(*executor.PlanReplayerLoadInfo)); err != nil { - return handled, err - } - } - - planReplayerDump := cc.ctx.Value(executor.PlanReplayerDumpVarKey) - if planReplayerDump != nil { - handled = true - defer cc.ctx.SetValue(executor.PlanReplayerDumpVarKey, nil) - //nolint:forcetypeassert - if err := cc.handlePlanReplayerDump(ctx, planReplayerDump.(*executor.PlanReplayerDumpInfo)); err != nil { - return handled, err - } - } - return handled, cc.writeOkWith(ctx, mysql.OKHeader, true, status) -} - -// handleFieldList returns the field list for a table. -// The sql string is composed of a table name and a terminating character \x00. 
-func (cc *clientConn) handleFieldList(ctx context.Context, sql string) (err error) { - parts := strings.Split(sql, "\x00") - columns, err := cc.ctx.FieldList(parts[0]) - if err != nil { - return err - } - data := cc.alloc.AllocWithLen(4, 1024) - cc.initResultEncoder(ctx) - defer cc.rsEncoder.Clean() - for _, column := range columns { - data = data[0:4] - data = column.DumpWithDefault(data, cc.rsEncoder) - if err := cc.writePacket(data); err != nil { - return err - } - } - if err := cc.writeEOF(ctx, cc.ctx.Status()); err != nil { - return err - } - return cc.flush(ctx) -} - -// writeResultSet writes data into a result set and uses rs.Next to get row data back. -// If binary is true, the data would be encoded in BINARY format. -// serverStatus, a flag bit represents server information. -// fetchSize, the desired number of rows to be fetched each time when client uses cursor. -// retryable indicates whether the call of writeResultSet has no side effect and can be retried to correct error. The call -// has side effect in cursor mode or once data has been sent to client. Currently retryable is used to fallback to TiKV when -// TiFlash is down. -func (cc *clientConn) writeResultSet(ctx context.Context, rs resultset.ResultSet, binary bool, serverStatus uint16, fetchSize int) (retryable bool, runErr error) { - defer func() { - // close ResultSet when cursor doesn't exist - r := recover() - if r == nil { - return - } - recoverdErr, ok := r.(error) - if !ok || !(exeerrors.ErrMemoryExceedForQuery.Equal(recoverdErr) || - exeerrors.ErrMemoryExceedForInstance.Equal(recoverdErr) || - exeerrors.ErrQueryInterrupted.Equal(recoverdErr) || - exeerrors.ErrMaxExecTimeExceeded.Equal(recoverdErr)) { - panic(r) - } - runErr = recoverdErr - // TODO(jianzhang.zj: add metrics here) - logutil.Logger(ctx).Error("write query result panic", zap.Stringer("lastSQL", getLastStmtInConn{cc}), zap.Stack("stack"), zap.Any("recover", r)) - }() - cc.initResultEncoder(ctx) - defer cc.rsEncoder.Clean() - if mysql.HasCursorExistsFlag(serverStatus) { - crs, ok := rs.(resultset.CursorResultSet) - if !ok { - // this branch is actually unreachable - return false, errors.New("this cursor is not a resultSet") - } - if err := cc.writeChunksWithFetchSize(ctx, crs, serverStatus, fetchSize); err != nil { - return false, err - } - return false, cc.flush(ctx) - } - if retryable, err := cc.writeChunks(ctx, rs, binary, serverStatus); err != nil { - return retryable, err - } - - return false, cc.flush(ctx) -} - -func (cc *clientConn) writeColumnInfo(columns []*column.Info) error { - data := cc.alloc.AllocWithLen(4, 1024) - data = dump.LengthEncodedInt(data, uint64(len(columns))) - if err := cc.writePacket(data); err != nil { - return err - } - for _, v := range columns { - data = data[0:4] - data = v.Dump(data, cc.rsEncoder) - if err := cc.writePacket(data); err != nil { - return err - } - } - return nil -} - -// writeChunks writes data from a Chunk, which filled data by a ResultSet, into a connection. -// binary specifies the way to dump data. It throws any error while dumping data. -// serverStatus, a flag bit represents server information -// The first return value indicates whether error occurs at the first call of ResultSet.Next. 
-func (cc *clientConn) writeChunks(ctx context.Context, rs resultset.ResultSet, binary bool, serverStatus uint16) (bool, error) { - data := cc.alloc.AllocWithLen(4, 1024) - req := rs.NewChunk(cc.chunkAlloc) - gotColumnInfo := false - firstNext := true - validNextCount := 0 - var start time.Time - var stmtDetail *execdetails.StmtExecDetails - stmtDetailRaw := ctx.Value(execdetails.StmtExecDetailKey) - if stmtDetailRaw != nil { - //nolint:forcetypeassert - stmtDetail = stmtDetailRaw.(*execdetails.StmtExecDetails) - } - for { - failpoint.Inject("fetchNextErr", func(value failpoint.Value) { - //nolint:forcetypeassert - switch value.(string) { - case "firstNext": - failpoint.Return(firstNext, storeerr.ErrTiFlashServerTimeout) - case "secondNext": - if !firstNext { - failpoint.Return(firstNext, storeerr.ErrTiFlashServerTimeout) - } - case "secondNextAndRetConflict": - if !firstNext && validNextCount > 1 { - failpoint.Return(firstNext, kv.ErrWriteConflict) - } - } - }) - // Here server.tidbResultSet implements Next method. - err := rs.Next(ctx, req) - if err != nil { - return firstNext, err - } - if !gotColumnInfo { - // We need to call Next before we get columns. - // Otherwise, we will get incorrect columns info. - columns := rs.Columns() - if stmtDetail != nil { - start = time.Now() - } - if err = cc.writeColumnInfo(columns); err != nil { - return false, err - } - if cc.capability&mysql.ClientDeprecateEOF == 0 { - // metadata only needs EOF marker for old clients without ClientDeprecateEOF - if err = cc.writeEOF(ctx, serverStatus); err != nil { - return false, err - } - } - if stmtDetail != nil { - stmtDetail.WriteSQLRespDuration += time.Since(start) - } - gotColumnInfo = true - } - rowCount := req.NumRows() - if rowCount == 0 { - break - } - validNextCount++ - firstNext = false - reg := trace.StartRegion(ctx, "WriteClientConn") - if stmtDetail != nil { - start = time.Now() - } - for i := 0; i < rowCount; i++ { - data = data[0:4] - if binary { - data, err = column.DumpBinaryRow(data, rs.Columns(), req.GetRow(i), cc.rsEncoder) - } else { - data, err = column.DumpTextRow(data, rs.Columns(), req.GetRow(i), cc.rsEncoder) - } - if err != nil { - reg.End() - return false, err - } - if err = cc.writePacket(data); err != nil { - reg.End() - return false, err - } - } - reg.End() - if stmtDetail != nil { - stmtDetail.WriteSQLRespDuration += time.Since(start) - } - } - if err := rs.Finish(); err != nil { - return false, err - } - - if stmtDetail != nil { - start = time.Now() - } - - err := cc.writeEOF(ctx, serverStatus) - if stmtDetail != nil { - stmtDetail.WriteSQLRespDuration += time.Since(start) - } - return false, err -} - -// writeChunksWithFetchSize writes data from a Chunk, which filled data by a ResultSet, into a connection. -// binary specifies the way to dump data. It throws any error while dumping data. -// serverStatus, a flag bit represents server information. -// fetchSize, the desired number of rows to be fetched each time when client uses cursor. 
-func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs resultset.CursorResultSet, serverStatus uint16, fetchSize int) error { - var ( - stmtDetail *execdetails.StmtExecDetails - err error - start time.Time - ) - data := cc.alloc.AllocWithLen(4, 1024) - stmtDetailRaw := ctx.Value(execdetails.StmtExecDetailKey) - if stmtDetailRaw != nil { - //nolint:forcetypeassert - stmtDetail = stmtDetailRaw.(*execdetails.StmtExecDetails) - } - if stmtDetail != nil { - start = time.Now() - } - - iter := rs.GetRowIterator() - // send the rows to the client according to fetchSize. - for i := 0; i < fetchSize && iter.Current(ctx) != iter.End(); i++ { - row := iter.Current(ctx) - - data = data[0:4] - data, err = column.DumpBinaryRow(data, rs.Columns(), row, cc.rsEncoder) - if err != nil { - return err - } - if err = cc.writePacket(data); err != nil { - return err - } - - iter.Next(ctx) - } - if iter.Error() != nil { - return iter.Error() - } - - // tell the client COM_STMT_FETCH has finished by setting proper serverStatus, - // and close ResultSet. - if iter.Current(ctx) == iter.End() { - serverStatus &^= mysql.ServerStatusCursorExists - serverStatus |= mysql.ServerStatusLastRowSend - } - - // don't include the time consumed by `cl.OnFetchReturned()` in the `WriteSQLRespDuration` - if stmtDetail != nil { - stmtDetail.WriteSQLRespDuration += time.Since(start) - } - - if cl, ok := rs.(resultset.FetchNotifier); ok { - cl.OnFetchReturned() - } - - start = time.Now() - err = cc.writeEOF(ctx, serverStatus) - if stmtDetail != nil { - stmtDetail.WriteSQLRespDuration += time.Since(start) - } - return err -} - -func (cc *clientConn) setConn(conn net.Conn) { - cc.bufReadConn = util2.NewBufferedReadConn(conn) - if cc.pkt == nil { - cc.pkt = internal.NewPacketIO(cc.bufReadConn) - } else { - // Preserve current sequence number. - cc.pkt.SetBufferedReadConn(cc.bufReadConn) - } -} - -func (cc *clientConn) upgradeToTLS(tlsConfig *tls.Config) error { - // Important: read from buffered reader instead of the original net.Conn because it may contain data we need. - tlsConn := tls.Server(cc.bufReadConn, tlsConfig) - if err := tlsConn.Handshake(); err != nil { - return err - } - cc.setConn(tlsConn) - cc.tlsConn = tlsConn - return nil -} - -func (cc *clientConn) handleChangeUser(ctx context.Context, data []byte) error { - user, data := util2.ParseNullTermString(data) - cc.user = string(hack.String(user)) - if len(data) < 1 { - return mysql.ErrMalformPacket - } - passLen := int(data[0]) - data = data[1:] - if passLen > len(data) { - return mysql.ErrMalformPacket - } - pass := data[:passLen] - data = data[passLen:] - dbName, data := util2.ParseNullTermString(data) - cc.dbname = string(hack.String(dbName)) - pluginName := "" - if len(data) > 0 { - // skip character set - if cc.capability&mysql.ClientProtocol41 > 0 && len(data) >= 2 { - data = data[2:] - } - if cc.capability&mysql.ClientPluginAuth > 0 && len(data) > 0 { - pluginNameB, _ := util2.ParseNullTermString(data) - pluginName = string(hack.String(pluginNameB)) - } - } - - if err := cc.ctx.Close(); err != nil { - logutil.Logger(ctx).Debug("close old context failed", zap.Error(err)) - } - // session was closed by `ctx.Close` and should `openSession` explicitly to renew session. - // `openSession` won't run again in `openSessionAndDoAuth` because ctx is not nil. 
- err := cc.openSession() - if err != nil { - return err - } - fakeResp := &handshake.Response41{ - Auth: pass, - AuthPlugin: pluginName, - Capability: cc.capability, - } - if fakeResp.AuthPlugin != "" { - failpoint.Inject("ChangeUserAuthSwitch", func(val failpoint.Value) { - failpoint.Return(errors.Errorf("%v", val)) - }) - newpass, err := cc.checkAuthPlugin(ctx, fakeResp) - if err != nil { - return err - } - if len(newpass) > 0 { - fakeResp.Auth = newpass - } - } - if err := cc.openSessionAndDoAuth(fakeResp.Auth, fakeResp.AuthPlugin, fakeResp.ZstdLevel); err != nil { - return err - } - return cc.handleCommonConnectionReset(ctx) -} - -func (cc *clientConn) handleResetConnection(ctx context.Context) error { - user := cc.ctx.GetSessionVars().User - err := cc.ctx.Close() - if err != nil { - logutil.Logger(ctx).Debug("close old context failed", zap.Error(err)) - } - var tlsStatePtr *tls.ConnectionState - if cc.tlsConn != nil { - tlsState := cc.tlsConn.ConnectionState() - tlsStatePtr = &tlsState - } - tidbCtx, err := cc.server.driver.OpenCtx(cc.connectionID, cc.capability, cc.collation, cc.dbname, tlsStatePtr, cc.extensions) - if err != nil { - return err - } - cc.SetCtx(tidbCtx) - if !cc.ctx.AuthWithoutVerification(user) { - return errors.New("Could not reset connection") - } - if cc.dbname != "" { // Restore the current DB - _, err = cc.useDB(context.Background(), cc.dbname) - if err != nil { - return err - } - } - cc.ctx.SetSessionManager(cc.server) - - return cc.handleCommonConnectionReset(ctx) -} - -func (cc *clientConn) handleCommonConnectionReset(ctx context.Context) error { - connectionInfo := cc.connectInfo() - cc.ctx.GetSessionVars().ConnectionInfo = connectionInfo - - cc.onExtensionConnEvent(extension.ConnReset, nil) - err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error { - authPlugin := plugin.DeclareAuditManifest(p.Manifest) - if authPlugin.OnConnectionEvent != nil { - connInfo := cc.ctx.GetSessionVars().ConnectionInfo - err := authPlugin.OnConnectionEvent(context.Background(), plugin.ChangeUser, connInfo) - if err != nil { - return err - } - } - return nil - }) - if err != nil { - return err - } - return cc.writeOK(ctx) -} - -// safe to noop except 0x01 "FLUSH PRIVILEGES" -func (cc *clientConn) handleRefresh(ctx context.Context, subCommand byte) error { - if subCommand == 0x01 { - if err := cc.handleQuery(ctx, "FLUSH PRIVILEGES"); err != nil { - return err - } - } - return cc.writeOK(ctx) -} - -var _ fmt.Stringer = getLastStmtInConn{} - -type getLastStmtInConn struct { - *clientConn -} - -func (cc getLastStmtInConn) String() string { - if len(cc.lastPacket) == 0 { - return "" - } - cmd, data := cc.lastPacket[0], cc.lastPacket[1:] - switch cmd { - case mysql.ComInitDB: - return "Use " + string(data) - case mysql.ComFieldList: - return "ListFields " + string(data) - case mysql.ComQuery, mysql.ComStmtPrepare: - sql := string(hack.String(data)) - sql = parser.Normalize(sql, cc.ctx.GetSessionVars().EnableRedactLog) - return executor.FormatSQL(sql).String() - case mysql.ComStmtExecute, mysql.ComStmtFetch: - stmtID := binary.LittleEndian.Uint32(data[0:4]) - return executor.FormatSQL(cc.preparedStmt2String(stmtID)).String() - case mysql.ComStmtClose, mysql.ComStmtReset: - stmtID := binary.LittleEndian.Uint32(data[0:4]) - return mysql.Command2Str[cmd] + " " + strconv.Itoa(int(stmtID)) - default: - if cmdStr, ok := mysql.Command2Str[cmd]; ok { - return cmdStr - } - return string(hack.String(data)) - } -} - -// PProfLabel return sql label used to tag pprof. 
-func (cc getLastStmtInConn) PProfLabel() string { - if len(cc.lastPacket) == 0 { - return "" - } - cmd, data := cc.lastPacket[0], cc.lastPacket[1:] - switch cmd { - case mysql.ComInitDB: - return "UseDB" - case mysql.ComFieldList: - return "ListFields" - case mysql.ComStmtClose: - return "CloseStmt" - case mysql.ComStmtReset: - return "ResetStmt" - case mysql.ComQuery, mysql.ComStmtPrepare: - return parser.Normalize(executor.FormatSQL(string(hack.String(data))).String(), errors.RedactLogEnable) - case mysql.ComStmtExecute, mysql.ComStmtFetch: - stmtID := binary.LittleEndian.Uint32(data[0:4]) - return executor.FormatSQL(cc.preparedStmt2StringNoArgs(stmtID)).String() - default: - return "" - } -} - -var _ conn.AuthConn = &clientConn{} - -// WriteAuthMoreData implements `conn.AuthConn` interface -func (cc *clientConn) WriteAuthMoreData(data []byte) error { - // See https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_more_data.html - // the `AuthMoreData` packet is just an arbitrary binary slice with a byte 0x1 as prefix. - return cc.writePacket(append([]byte{0, 0, 0, 0, 1}, data...)) -} - -// ReadPacket implements `conn.AuthConn` interface -func (cc *clientConn) ReadPacket() ([]byte, error) { - return cc.readPacket() -} - -// Flush implements `conn.AuthConn` interface -func (cc *clientConn) Flush(ctx context.Context) error { - return cc.flush(ctx) -} - -type compressionStats struct{} - -// Stats returns the connection statistics. -func (*compressionStats) Stats(vars *variable.SessionVars) (map[string]any, error) { - m := make(map[string]any, 3) - - switch vars.CompressionAlgorithm { - case mysql.CompressionNone: - m[statusCompression] = "OFF" - m[statusCompressionAlgorithm] = "" - m[statusCompressionLevel] = 0 - case mysql.CompressionZlib: - m[statusCompression] = "ON" - m[statusCompressionAlgorithm] = "zlib" - m[statusCompressionLevel] = mysql.ZlibCompressDefaultLevel - case mysql.CompressionZstd: - m[statusCompression] = "ON" - m[statusCompressionAlgorithm] = "zstd" - m[statusCompressionLevel] = vars.CompressionLevel - default: - logutil.BgLogger().Debug( - "unexpected compression algorithm value", - zap.Int("algorithm", vars.CompressionAlgorithm), - ) - m[statusCompression] = "OFF" - m[statusCompressionAlgorithm] = "" - m[statusCompressionLevel] = 0 - } - - return m, nil -} - -// GetScope gets the status variables scope. 
-func (*compressionStats) GetScope(_ string) variable.ScopeFlag { - return variable.ScopeSession -} - -func init() { - variable.RegisterStatistics(&compressionStats{}) -} diff --git a/pkg/server/conn_stmt.go b/pkg/server/conn_stmt.go index 8a3bedd411280..19ac430500944 100644 --- a/pkg/server/conn_stmt.go +++ b/pkg/server/conn_stmt.go @@ -358,9 +358,9 @@ func (cc *clientConn) executeWithCursor(ctx context.Context, stmt PreparedStatem } } - if _, _err_ := failpoint.Eval(_curpkg_("avoidEagerCursorFetch")); _err_ == nil { - return false, errors.New("failpoint avoids eager cursor fetch") - } + failpoint.Inject("avoidEagerCursorFetch", func() { + failpoint.Return(false, errors.New("failpoint avoids eager cursor fetch")) + }) cc.initResultEncoder(ctx) defer cc.rsEncoder.Clean() // fetch all results of the resultSet, and stored them locally, so that the future `FETCH` command can read @@ -376,12 +376,12 @@ func (cc *clientConn) executeWithCursor(ctx context.Context, stmt PreparedStatem rowContainer.GetDiskTracker().AttachTo(vars.DiskTracker) rowContainer.GetDiskTracker().SetLabel(memory.LabelForCursorFetch) if variable.EnableTmpStorageOnOOM.Load() { - if val, _err_ := failpoint.Eval(_curpkg_("testCursorFetchSpill")); _err_ == nil { + failpoint.Inject("testCursorFetchSpill", func(val failpoint.Value) { if val, ok := val.(bool); val && ok { actionSpill := rowContainer.ActionSpillForTest() defer actionSpill.WaitForTest() } - } + }) action := memory.NewActionWithPriority(rowContainer.ActionSpill(), memory.DefCursorFetchSpillPriority) vars.MemTracker.FallbackOldAndSetNewAction(action) } diff --git a/pkg/server/conn_stmt.go__failpoint_stash__ b/pkg/server/conn_stmt.go__failpoint_stash__ deleted file mode 100644 index 19ac430500944..0000000000000 --- a/pkg/server/conn_stmt.go__failpoint_stash__ +++ /dev/null @@ -1,673 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - -// The MIT License (MIT) -// -// Copyright (c) 2014 wandoulabs -// Copyright (c) 2014 siddontang -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of -// this software and associated documentation files (the "Software"), to deal in -// the Software without restriction, including without limitation the rights to -// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of -// the Software, and to permit persons to whom the Software is furnished to do so, -// subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. 
- -package server - -import ( - "context" - "encoding/binary" - "runtime/trace" - "strconv" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/param" - "github.com/pingcap/tidb/pkg/parser" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/charset" - "github.com/pingcap/tidb/pkg/parser/mysql" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/server/internal/dump" - "github.com/pingcap/tidb/pkg/server/internal/parse" - "github.com/pingcap/tidb/pkg/server/internal/resultset" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/sessiontxn" - storeerr "github.com/pingcap/tidb/pkg/store/driver/error" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "github.com/pingcap/tidb/pkg/util/redact" - "github.com/pingcap/tidb/pkg/util/topsql" - topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" - "github.com/tikv/client-go/v2/util" - "go.uber.org/zap" -) - -func (cc *clientConn) HandleStmtPrepare(ctx context.Context, sql string) error { - stmt, columns, params, err := cc.ctx.Prepare(sql) - if err != nil { - return err - } - data := make([]byte, 4, 128) - - // status ok - data = append(data, 0) - // stmt id - data = dump.Uint32(data, uint32(stmt.ID())) - // number columns - data = dump.Uint16(data, uint16(len(columns))) - // number params - data = dump.Uint16(data, uint16(len(params))) - // filter [00] - data = append(data, 0) - // warning count - data = append(data, 0, 0) // TODO support warning count - - if err := cc.writePacket(data); err != nil { - return err - } - - cc.initResultEncoder(ctx) - defer cc.rsEncoder.Clean() - if len(params) > 0 { - for i := 0; i < len(params); i++ { - data = data[0:4] - data = params[i].Dump(data, cc.rsEncoder) - - if err := cc.writePacket(data); err != nil { - return err - } - } - - if cc.capability&mysql.ClientDeprecateEOF == 0 { - // metadata only needs EOF marker for old clients without ClientDeprecateEOF - if err := cc.writeEOF(ctx, cc.ctx.Status()); err != nil { - return err - } - } - } - - if len(columns) > 0 { - for i := 0; i < len(columns); i++ { - data = data[0:4] - data = columns[i].Dump(data, cc.rsEncoder) - - if err := cc.writePacket(data); err != nil { - return err - } - } - - if cc.capability&mysql.ClientDeprecateEOF == 0 { - // metadata only needs EOF marker for old clients without ClientDeprecateEOF - if err := cc.writeEOF(ctx, cc.ctx.Status()); err != nil { - return err - } - } - } - return cc.flush(ctx) -} - -func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err error) { - defer trace.StartRegion(ctx, "HandleStmtExecute").End() - if len(data) < 9 { - return mysql.ErrMalformPacket - } - pos := 0 - stmtID := binary.LittleEndian.Uint32(data[0:4]) - pos += 4 - - stmt := cc.ctx.GetStatement(int(stmtID)) - if stmt == nil { - return mysql.NewErr(mysql.ErrUnknownStmtHandler, - strconv.FormatUint(uint64(stmtID), 10), "stmt_execute") - } - - flag := data[pos] - pos++ - // Please refer to https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html - // The client indicates that it wants to use cursor by setting this flag. - // Now we only support forward-only, read-only cursor. 
- useCursor := false - if flag&mysql.CursorTypeReadOnly > 0 { - useCursor = true - } - if flag&mysql.CursorTypeForUpdate > 0 { - return mysql.NewErrf(mysql.ErrUnknown, "unsupported flag: CursorTypeForUpdate", nil) - } - if flag&mysql.CursorTypeScrollable > 0 { - return mysql.NewErrf(mysql.ErrUnknown, "unsupported flag: CursorTypeScrollable", nil) - } - - if useCursor { - cc.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusCursorExists, true) - defer cc.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusCursorExists, false) - } else { - // not using streaming ,can reuse chunk - cc.ctx.GetSessionVars().SetAlloc(cc.chunkAlloc) - } - // skip iteration-count, always 1 - pos += 4 - - var ( - nullBitmaps []byte - paramTypes []byte - paramValues []byte - ) - cc.initInputEncoder(ctx) - numParams := stmt.NumParams() - args := make([]param.BinaryParam, numParams) - if numParams > 0 { - nullBitmapLen := (numParams + 7) >> 3 - if len(data) < (pos + nullBitmapLen + 1) { - return mysql.ErrMalformPacket - } - nullBitmaps = data[pos : pos+nullBitmapLen] - pos += nullBitmapLen - - // new param bound flag - if data[pos] == 1 { - pos++ - if len(data) < (pos + (numParams << 1)) { - return mysql.ErrMalformPacket - } - - paramTypes = data[pos : pos+(numParams<<1)] - pos += numParams << 1 - paramValues = data[pos:] - // Just the first StmtExecute packet contain parameters type, - // we need save it for further use. - stmt.SetParamsType(paramTypes) - } else { - paramValues = data[pos+1:] - } - - err = parseBinaryParams(args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues, cc.inputDecoder) - // This `.Reset` resets the arguments, so it's fine to just ignore the error (and the it'll be reset again in the following routine) - errReset := stmt.Reset() - if errReset != nil { - logutil.Logger(ctx).Warn("fail to reset statement in EXECUTE command", zap.Error(errReset)) - } - if err != nil { - return errors.Annotate(err, cc.preparedStmt2String(stmtID)) - } - } - - sessVars := cc.ctx.GetSessionVars() - // expiredTaskID is the task ID of the previous statement. When executing a stmt, - // the StmtCtx will be reinit and the TaskID will change. We can compare the StmtCtx.TaskID - // with the previous one to determine whether StmtCtx has been inited for the current stmt. 
- expiredTaskID := sessVars.StmtCtx.TaskID - err = cc.executePlanCacheStmt(ctx, stmt, args, useCursor) - cc.onExtensionBinaryExecuteEnd(stmt, args, sessVars.StmtCtx.TaskID != expiredTaskID, err) - return err -} - -func (cc *clientConn) executePlanCacheStmt(ctx context.Context, stmt any, args []param.BinaryParam, useCursor bool) (err error) { - ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) - ctx = context.WithValue(ctx, util.ExecDetailsKey, &util.ExecDetails{}) - ctx = context.WithValue(ctx, util.RUDetailsCtxKey, util.NewRUDetails()) - retryable, err := cc.executePreparedStmtAndWriteResult(ctx, stmt.(PreparedStatement), args, useCursor) - if err != nil { - action, txnErr := sessiontxn.GetTxnManager(&cc.ctx).OnStmtErrorForNextAction(ctx, sessiontxn.StmtErrAfterQuery, err) - if txnErr != nil { - return txnErr - } - - if retryable && action == sessiontxn.StmtActionRetryReady { - cc.ctx.GetSessionVars().RetryInfo.Retrying = true - _, err = cc.executePreparedStmtAndWriteResult(ctx, stmt.(PreparedStatement), args, useCursor) - cc.ctx.GetSessionVars().RetryInfo.Retrying = false - return err - } - } - _, allowTiFlashFallback := cc.ctx.GetSessionVars().AllowFallbackToTiKV[kv.TiFlash] - if allowTiFlashFallback && err != nil && errors.ErrorEqual(err, storeerr.ErrTiFlashServerTimeout) && retryable { - // When the TiFlash server seems down, we append a warning to remind the user to check the status of the TiFlash - // server and fallback to TiKV. - prevErr := err - delete(cc.ctx.GetSessionVars().IsolationReadEngines, kv.TiFlash) - defer func() { - cc.ctx.GetSessionVars().IsolationReadEngines[kv.TiFlash] = struct{}{} - }() - _, err = cc.executePreparedStmtAndWriteResult(ctx, stmt.(PreparedStatement), args, useCursor) - // We append warning after the retry because `ResetContextOfStmt` may be called during the retry, which clears warnings. - cc.ctx.GetSessionVars().StmtCtx.AppendError(prevErr) - } - return err -} - -// The first return value indicates whether the call of executePreparedStmtAndWriteResult has no side effect and can be retried. -// Currently the first return value is used to fallback to TiKV when TiFlash is down. -func (cc *clientConn) executePreparedStmtAndWriteResult(ctx context.Context, stmt PreparedStatement, args []param.BinaryParam, useCursor bool) (bool, error) { - vars := (&cc.ctx).GetSessionVars() - prepStmt, err := vars.GetPreparedStmtByID(uint32(stmt.ID())) - if err != nil { - return true, errors.Annotate(err, cc.preparedStmt2String(uint32(stmt.ID()))) - } - execStmt := &ast.ExecuteStmt{ - BinaryArgs: args, - PrepStmt: prepStmt, - } - - // first, try to clear the left cursor if there is one - if useCursor && stmt.GetCursorActive() { - if stmt.GetResultSet() != nil && stmt.GetResultSet().GetRowIterator() != nil { - stmt.GetResultSet().GetRowIterator().Close() - } - if stmt.GetRowContainer() != nil { - stmt.GetRowContainer().GetMemTracker().Detach() - stmt.GetRowContainer().GetDiskTracker().Detach() - err := stmt.GetRowContainer().Close() - if err != nil { - logutil.Logger(ctx).Error( - "Fail to close rowContainer before executing statement. May cause resource leak", - zap.Error(err)) - } - stmt.StoreRowContainer(nil) - } - stmt.StoreResultSet(nil) - stmt.SetCursorActive(false) - } - - // For the combination of `ComPrepare` and `ComExecute`, the statement name is stored in the client side, and the - // TiDB only has the ID, so don't try to construct an `EXECUTE SOMETHING`. Use the original prepared statement here - // instead. 
- sql := "" - planCacheStmt, ok := prepStmt.(*plannercore.PlanCacheStmt) - if ok { - sql = planCacheStmt.StmtText - } - execStmt.SetText(charset.EncodingUTF8Impl, sql) - rs, err := (&cc.ctx).ExecuteStmt(ctx, execStmt) - var lazy bool - if rs != nil { - defer func() { - if !lazy { - rs.Close() - } - }() - } - if err != nil { - // If error is returned during the planner phase or the executor.Open - // phase, the rs will be nil, and StmtCtx.MemTracker StmtCtx.DiskTracker - // will not be detached. We need to detach them manually. - if sv := cc.ctx.GetSessionVars(); sv != nil && sv.StmtCtx != nil { - sv.StmtCtx.DetachMemDiskTracker() - } - return true, errors.Annotate(err, cc.preparedStmt2String(uint32(stmt.ID()))) - } - - if rs == nil { - if useCursor { - vars.SetStatusFlag(mysql.ServerStatusCursorExists, false) - } - return false, cc.writeOK(ctx) - } - if planCacheStmt, ok := prepStmt.(*plannercore.PlanCacheStmt); ok { - rs.SetPreparedStmt(planCacheStmt) - } - - // if the client wants to use cursor - // we should hold the ResultSet in PreparedStatement for next stmt_fetch, and only send back ColumnInfo. - // Tell the client cursor exists in server by setting proper serverStatus. - if useCursor { - lazy, err = cc.executeWithCursor(ctx, stmt, rs) - return false, err - } - retryable, err := cc.writeResultSet(ctx, rs, true, cc.ctx.Status(), 0) - if err != nil { - return retryable, errors.Annotate(err, cc.preparedStmt2String(uint32(stmt.ID()))) - } - return false, nil -} - -func (cc *clientConn) executeWithCursor(ctx context.Context, stmt PreparedStatement, rs resultset.ResultSet) (lazy bool, err error) { - vars := (&cc.ctx).GetSessionVars() - if vars.EnableLazyCursorFetch { - // try to execute with lazy cursor fetch - ok, err := cc.executeWithLazyCursor(ctx, stmt, rs) - - // if `ok` is false, should try to execute without lazy cursor fetch - if ok { - return true, err - } - } - - failpoint.Inject("avoidEagerCursorFetch", func() { - failpoint.Return(false, errors.New("failpoint avoids eager cursor fetch")) - }) - cc.initResultEncoder(ctx) - defer cc.rsEncoder.Clean() - // fetch all results of the resultSet, and stored them locally, so that the future `FETCH` command can read - // the rows directly to avoid running executor and accessing shared params/variables in the session - // NOTE: chunk should not be allocated from the connection allocator, which will reset after executing this command - // but the rows are still needed in the following FETCH command. - - // create the row container to manage spill - // this `rowContainer` will be released when the statement (or the connection) is closed. 
- rowContainer := chunk.NewRowContainer(rs.FieldTypes(), vars.MaxChunkSize) - rowContainer.GetMemTracker().AttachTo(vars.MemTracker) - rowContainer.GetMemTracker().SetLabel(memory.LabelForCursorFetch) - rowContainer.GetDiskTracker().AttachTo(vars.DiskTracker) - rowContainer.GetDiskTracker().SetLabel(memory.LabelForCursorFetch) - if variable.EnableTmpStorageOnOOM.Load() { - failpoint.Inject("testCursorFetchSpill", func(val failpoint.Value) { - if val, ok := val.(bool); val && ok { - actionSpill := rowContainer.ActionSpillForTest() - defer actionSpill.WaitForTest() - } - }) - action := memory.NewActionWithPriority(rowContainer.ActionSpill(), memory.DefCursorFetchSpillPriority) - vars.MemTracker.FallbackOldAndSetNewAction(action) - } - defer func() { - if err != nil { - rowContainer.GetMemTracker().Detach() - rowContainer.GetDiskTracker().Detach() - errCloseRowContainer := rowContainer.Close() - if errCloseRowContainer != nil { - logutil.Logger(ctx).Error("Fail to close rowContainer in error handler. May cause resource leak", - zap.NamedError("original-error", err), zap.NamedError("close-error", errCloseRowContainer)) - } - } - }() - - for { - chk := rs.NewChunk(nil) - - if err = rs.Next(ctx, chk); err != nil { - return false, err - } - rowCount := chk.NumRows() - if rowCount == 0 { - break - } - - err = rowContainer.Add(chk) - if err != nil { - return false, err - } - } - - reader := chunk.NewRowContainerReader(rowContainer) - defer func() { - if err != nil { - reader.Close() - } - }() - crs := resultset.WrapWithRowContainerCursor(rs, reader) - if cl, ok := crs.(resultset.FetchNotifier); ok { - cl.OnFetchReturned() - } - stmt.StoreRowContainer(rowContainer) - - err = cc.writeExecuteResultWithCursor(ctx, stmt, crs) - return false, err -} - -// executeWithLazyCursor tries to detach the `ResultSet` and make it suitable to execute lazily. -// Be careful that the return value `(bool, error)` has different meaning with other similar functions. The first `bool` represent whether -// the `ResultSet` is suitable for lazy execution. If the return value is `(false, _)`, the `rs` in argument can still be used. If the -// first return value is `true` and `err` is not nil, the `rs` cannot be used anymore and should return the error to the upper layer. -func (cc *clientConn) executeWithLazyCursor(ctx context.Context, stmt PreparedStatement, rs resultset.ResultSet) (ok bool, err error) { - drs, ok, err := rs.TryDetach() - if !ok || err != nil { - return false, err - } - - vars := (&cc.ctx).GetSessionVars() - crs := resultset.WrapWithLazyCursor(drs, vars.InitChunkSize, vars.MaxChunkSize) - err = cc.writeExecuteResultWithCursor(ctx, stmt, crs) - return true, err -} - -// writeExecuteResultWithCursor will store the `ResultSet` in `stmt` and send the column info to the client. The logic is shared between -// lazy cursor fetch and normal(eager) cursor fetch. -func (cc *clientConn) writeExecuteResultWithCursor(ctx context.Context, stmt PreparedStatement, rs resultset.CursorResultSet) error { - var err error - - stmt.StoreResultSet(rs) - stmt.SetCursorActive(true) - defer func() { - if err != nil { - // the resultSet and rowContainer have been closed in former "defer" statement. - stmt.StoreResultSet(nil) - stmt.StoreRowContainer(nil) - stmt.SetCursorActive(false) - } - }() - - if err = cc.writeColumnInfo(rs.Columns()); err != nil { - return err - } - - // explicitly flush columnInfo to client. 
- err = cc.writeEOF(ctx, cc.ctx.Status()) - if err != nil { - return err - } - - return cc.flush(ctx) -} - -func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err error) { - cc.ctx.GetSessionVars().StartTime = time.Now() - cc.ctx.GetSessionVars().ClearAlloc(nil, false) - cc.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusCursorExists, true) - defer cc.ctx.GetSessionVars().SetStatusFlag(mysql.ServerStatusCursorExists, false) - // Reset the warn count. TODO: consider whether it's better to reset the whole session context/statement context. - if cc.ctx.GetSessionVars().StmtCtx != nil { - cc.ctx.GetSessionVars().StmtCtx.SetWarnings(nil) - } - cc.ctx.GetSessionVars().SysErrorCount = 0 - cc.ctx.GetSessionVars().SysWarningCount = 0 - - stmtID, fetchSize, err := parse.StmtFetchCmd(data) - if err != nil { - return err - } - - stmt := cc.ctx.GetStatement(int(stmtID)) - if stmt == nil { - return errors.Annotate(mysql.NewErr(mysql.ErrUnknownStmtHandler, - strconv.FormatUint(uint64(stmtID), 10), "stmt_fetch"), cc.preparedStmt2String(stmtID)) - } - if !stmt.GetCursorActive() { - return errors.Annotate(mysql.NewErr(mysql.ErrSpCursorNotOpen), cc.preparedStmt2String(stmtID)) - } - // from now on, we have made sure: the statement has an active cursor - // then if facing any error, this cursor should be reset - defer func() { - if err != nil { - errReset := stmt.Reset() - if errReset != nil { - logutil.Logger(ctx).Error("Fail to reset statement in error handler. May cause resource leak.", - zap.NamedError("original-error", err), zap.NamedError("reset-error", errReset)) - } - } - }() - - if topsqlstate.TopSQLEnabled() { - prepareObj, _ := cc.preparedStmtID2CachePreparedStmt(stmtID) - if prepareObj != nil && prepareObj.SQLDigest != nil { - ctx = topsql.AttachAndRegisterSQLInfo(ctx, prepareObj.NormalizedSQL, prepareObj.SQLDigest, false) - } - } - sql := "" - if prepared, ok := cc.ctx.GetStatement(int(stmtID)).(*TiDBStatement); ok { - sql = prepared.sql - } - cc.ctx.SetProcessInfo(sql, time.Now(), mysql.ComStmtExecute, 0) - rs := stmt.GetResultSet() - - _, err = cc.writeResultSet(ctx, rs, true, cc.ctx.Status(), int(fetchSize)) - // if the iterator reached the end before writing result, we could say the `FETCH` command will send EOF - if rs.GetRowIterator().Current(ctx) == rs.GetRowIterator().End() { - // also reset the statement when the cursor reaches the end - // don't overwrite the `err` in outer scope, to avoid redundant `Reset()` in `defer` statement (though, it's not - // a big problem, as the `Reset()` function call is idempotent.) - err := stmt.Reset() - if err != nil { - logutil.Logger(ctx).Error("Fail to reset statement when FETCH command reaches the end. 
May cause resource leak", - zap.NamedError("error", err)) - } - } - if err != nil { - return errors.Annotate(err, cc.preparedStmt2String(stmtID)) - } - - return nil -} - -func (cc *clientConn) handleStmtClose(data []byte) (err error) { - if len(data) < 4 { - return - } - - stmtID := int(binary.LittleEndian.Uint32(data[0:4])) - stmt := cc.ctx.GetStatement(stmtID) - if stmt != nil { - return stmt.Close() - } - - return -} - -func (cc *clientConn) handleStmtSendLongData(data []byte) (err error) { - if len(data) < 6 { - return mysql.ErrMalformPacket - } - - stmtID := int(binary.LittleEndian.Uint32(data[0:4])) - - stmt := cc.ctx.GetStatement(stmtID) - if stmt == nil { - return mysql.NewErr(mysql.ErrUnknownStmtHandler, - strconv.Itoa(stmtID), "stmt_send_longdata") - } - - paramID := int(binary.LittleEndian.Uint16(data[4:6])) - return stmt.AppendParam(paramID, data[6:]) -} - -func (cc *clientConn) handleStmtReset(ctx context.Context, data []byte) (err error) { - // A reset command should reset the statement to the state when it was right after prepare - // Then the following state should be cleared: - // 1.The opened cursor, including the rowContainer (and its cursor/memTracker). - // 2.The argument sent through `SEND_LONG_DATA`. - if len(data) < 4 { - return mysql.ErrMalformPacket - } - - stmtID := int(binary.LittleEndian.Uint32(data[0:4])) - stmt := cc.ctx.GetStatement(stmtID) - if stmt == nil { - return mysql.NewErr(mysql.ErrUnknownStmtHandler, - strconv.Itoa(stmtID), "stmt_reset") - } - err = stmt.Reset() - if err != nil { - // Both server and client cannot handle the error case well, so just left an error and return OK. - // It's fine to receive further `EXECUTE` command even the `Reset` function call failed. - logutil.Logger(ctx).Error("Fail to close statement in error handler of RESET command. 
May cause resource leak", - zap.NamedError("original-error", err), zap.NamedError("close-error", err)) - - return cc.writeOK(ctx) - } - - return cc.writeOK(ctx) -} - -// handleSetOption refer to https://dev.mysql.com/doc/internals/en/com-set-option.html -func (cc *clientConn) handleSetOption(ctx context.Context, data []byte) (err error) { - if len(data) < 2 { - return mysql.ErrMalformPacket - } - - switch binary.LittleEndian.Uint16(data[:2]) { - case 0: - cc.capability |= mysql.ClientMultiStatements - cc.ctx.SetClientCapability(cc.capability) - case 1: - cc.capability &^= mysql.ClientMultiStatements - cc.ctx.SetClientCapability(cc.capability) - default: - return mysql.ErrMalformPacket - } - - if err = cc.writeEOF(ctx, cc.ctx.Status()); err != nil { - return err - } - - return cc.flush(ctx) -} - -func (cc *clientConn) preparedStmt2String(stmtID uint32) string { - sv := cc.ctx.GetSessionVars() - if sv == nil { - return "" - } - sql := parser.Normalize(cc.preparedStmt2StringNoArgs(stmtID), sv.EnableRedactLog) - if m := sv.EnableRedactLog; m != errors.RedactLogEnable { - sql += redact.String(sv.EnableRedactLog, sv.PlanCacheParams.String()) - } - return sql -} - -func (cc *clientConn) preparedStmt2StringNoArgs(stmtID uint32) string { - sv := cc.ctx.GetSessionVars() - if sv == nil { - return "" - } - preparedObj, invalid := cc.preparedStmtID2CachePreparedStmt(stmtID) - if invalid { - return "invalidate PlanCacheStmt type, ID: " + strconv.FormatUint(uint64(stmtID), 10) - } - if preparedObj == nil { - return "prepared statement not found, ID: " + strconv.FormatUint(uint64(stmtID), 10) - } - return preparedObj.PreparedAst.Stmt.Text() -} - -func (cc *clientConn) preparedStmtID2CachePreparedStmt(stmtID uint32) (_ *plannercore.PlanCacheStmt, invalid bool) { - sv := cc.ctx.GetSessionVars() - if sv == nil { - return nil, false - } - preparedPointer, ok := sv.PreparedStmts[stmtID] - if !ok { - // not found - return nil, false - } - preparedObj, ok := preparedPointer.(*plannercore.PlanCacheStmt) - if !ok { - // invalid cache. should never happen. 
- return nil, true - } - return preparedObj, false -} diff --git a/pkg/server/handler/binding__failpoint_binding__.go b/pkg/server/handler/binding__failpoint_binding__.go deleted file mode 100644 index 12717ba2a9228..0000000000000 --- a/pkg/server/handler/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package handler - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/server/handler/extractorhandler/binding__failpoint_binding__.go b/pkg/server/handler/extractorhandler/binding__failpoint_binding__.go deleted file mode 100644 index 7301cdac20fb4..0000000000000 --- a/pkg/server/handler/extractorhandler/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package extractorhandler - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/server/handler/extractorhandler/extractor.go b/pkg/server/handler/extractorhandler/extractor.go index ba5831a6095a2..4d925e6e5f190 100644 --- a/pkg/server/handler/extractorhandler/extractor.go +++ b/pkg/server/handler/extractorhandler/extractor.go @@ -56,16 +56,16 @@ func (eh ExtractTaskServeHandler) ServeHTTP(w http.ResponseWriter, req *http.Req handler.WriteError(w, err) return } - if val, _err_ := failpoint.Eval(_curpkg_("extractTaskServeHandler")); _err_ == nil { + failpoint.Inject("extractTaskServeHandler", func(val failpoint.Value) { if val.(bool) { w.WriteHeader(http.StatusOK) _, err = w.Write([]byte("mock")) if err != nil { handler.WriteError(w, err) } - return + failpoint.Return() } - } + }) name, err := eh.ExtractHandler.ExtractTask(context.Background(), task) if err != nil { diff --git a/pkg/server/handler/extractorhandler/extractor.go__failpoint_stash__ b/pkg/server/handler/extractorhandler/extractor.go__failpoint_stash__ deleted file mode 100644 index 4d925e6e5f190..0000000000000 --- a/pkg/server/handler/extractorhandler/extractor.go__failpoint_stash__ +++ /dev/null @@ -1,169 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package extractorhandler - -import ( - "context" - "fmt" - "io" - "net/http" - "os" - "path/filepath" - "strconv" - "strings" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/server/handler" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/logutil" - "go.uber.org/zap" -) - -const ( - extractPlanTaskType = "plan" -) - -// ExtractTaskServeHandler is the http serve handler for extract task handler -type ExtractTaskServeHandler struct { - ExtractHandler *domain.ExtractHandle -} - -// NewExtractTaskServeHandler creates a new extract task serve handler -func NewExtractTaskServeHandler(extractHandler *domain.ExtractHandle) *ExtractTaskServeHandler { - return &ExtractTaskServeHandler{ExtractHandler: extractHandler} -} - -// ServeHTTP serves http -func (eh ExtractTaskServeHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { - task, isDump, err := buildExtractTask(req) - if err != nil { - logutil.BgLogger().Error("build extract task failed", zap.Error(err)) - handler.WriteError(w, err) - return - } - failpoint.Inject("extractTaskServeHandler", func(val failpoint.Value) { - if val.(bool) { - w.WriteHeader(http.StatusOK) - _, err = w.Write([]byte("mock")) - if err != nil { - handler.WriteError(w, err) - } - failpoint.Return() - } - }) - - name, err := eh.ExtractHandler.ExtractTask(context.Background(), task) - if err != nil { - logutil.BgLogger().Error("extract task failed", zap.Error(err)) - handler.WriteError(w, err) - return - } - w.WriteHeader(http.StatusOK) - if !isDump { - _, err = w.Write([]byte(name)) - if err != nil { - logutil.BgLogger().Error("extract handler failed", zap.Error(err)) - } - return - } - content, err := loadExtractResponse(name) - if err != nil { - logutil.BgLogger().Error("load extract task failed", zap.Error(err)) - handler.WriteError(w, err) - return - } - _, err = w.Write(content) - if err != nil { - handler.WriteError(w, err) - return - } - w.Header().Set("Content-Type", "application/zip") - w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s.zip\"", name)) -} - -func loadExtractResponse(name string) ([]byte, error) { - path := filepath.Join(domain.GetExtractTaskDirName(), name) - //nolint: gosec - file, err := os.Open(path) - if err != nil { - return nil, err - } - defer file.Close() - content, err := io.ReadAll(file) - if err != nil { - return nil, err - } - return content, nil -} - -func buildExtractTask(req *http.Request) (*domain.ExtractTask, bool, error) { - extractTaskType := req.URL.Query().Get(handler.Type) - if strings.ToLower(extractTaskType) == extractPlanTaskType { - return buildExtractPlanTask(req) - } - logutil.BgLogger().Error("unknown extract task type") - return nil, false, errors.New("unknown extract task type") -} - -func buildExtractPlanTask(req *http.Request) (*domain.ExtractTask, bool, error) { - beginStr := req.URL.Query().Get(handler.Begin) - endStr := req.URL.Query().Get(handler.End) - var begin time.Time - var err error - if len(beginStr) < 1 { - begin = time.Now().Add(30 * time.Minute) - } else { - begin, err = time.Parse(types.TimeFormat, beginStr) - if err != nil { - logutil.BgLogger().Error("extract task begin time failed", zap.Error(err), zap.String("begin", beginStr)) - return nil, false, err - } - } - var end time.Time - if len(endStr) < 1 { - end = time.Now() - } else { - end, err = time.Parse(types.TimeFormat, endStr) - if err != nil { - logutil.BgLogger().Error("extract task end 
time failed", zap.Error(err), zap.String("end", endStr)) - return nil, false, err - } - } - isDump := extractBoolParam(handler.IsDump, false, req) - - return &domain.ExtractTask{ - ExtractType: domain.ExtractPlanType, - IsBackgroundJob: false, - Begin: begin, - End: end, - SkipStats: extractBoolParam(handler.IsSkipStats, false, req), - UseHistoryView: extractBoolParam(handler.IsHistoryView, true, req), - }, isDump, nil -} - -func extractBoolParam(param string, defaultValue bool, req *http.Request) bool { - str := req.URL.Query().Get(param) - if len(str) < 1 { - return defaultValue - } - v, err := strconv.ParseBool(str) - if err != nil { - return defaultValue - } - return v -} diff --git a/pkg/server/handler/tikv_handler.go b/pkg/server/handler/tikv_handler.go index 9c9081ce51270..1a6339e05be6d 100644 --- a/pkg/server/handler/tikv_handler.go +++ b/pkg/server/handler/tikv_handler.go @@ -259,11 +259,11 @@ func (t *TikvHandlerTool) GetRegionsMeta(regionIDs []uint64) ([]RegionMeta, erro return nil, errors.Trace(err) } - if val, _err_ := failpoint.Eval(_curpkg_("errGetRegionByIDEmpty")); _err_ == nil { + failpoint.Inject("errGetRegionByIDEmpty", func(val failpoint.Value) { if val.(bool) { region.Meta = nil } - } + }) if region.Meta == nil { return nil, errors.Errorf("region not found for regionID %q", regionID) diff --git a/pkg/server/handler/tikv_handler.go__failpoint_stash__ b/pkg/server/handler/tikv_handler.go__failpoint_stash__ deleted file mode 100644 index 1a6339e05be6d..0000000000000 --- a/pkg/server/handler/tikv_handler.go__failpoint_stash__ +++ /dev/null @@ -1,279 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package handler - -import ( - "context" - "encoding/hex" - "fmt" - "net/url" - "strconv" - "strings" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/session" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - derr "github.com/pingcap/tidb/pkg/store/driver/error" - "github.com/pingcap/tidb/pkg/store/helper" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/table/tables" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/tikv/client-go/v2/tikv" -) - -// TikvHandlerTool is a tool to handle TiKV data. -type TikvHandlerTool struct { - helper.Helper -} - -// NewTikvHandlerTool creates a new TikvHandlerTool. -func NewTikvHandlerTool(store helper.Storage) *TikvHandlerTool { - return &TikvHandlerTool{Helper: *helper.NewHelper(store)} -} - -type mvccKV struct { - Key string `json:"key"` - RegionID uint64 `json:"region_id"` - Value *kvrpcpb.MvccGetByKeyResponse `json:"value"` -} - -// GetRegionIDByKey gets the region id by the key. 
-func (t *TikvHandlerTool) GetRegionIDByKey(encodedKey []byte) (uint64, error) { - keyLocation, err := t.RegionCache.LocateKey(tikv.NewBackofferWithVars(context.Background(), 500, nil), encodedKey) - if err != nil { - return 0, derr.ToTiDBErr(err) - } - return keyLocation.Region.GetID(), nil -} - -// GetHandle gets the handle of the record. -func (t *TikvHandlerTool) GetHandle(tb table.PhysicalTable, params map[string]string, values url.Values) (kv.Handle, error) { - var handle kv.Handle - if intHandleStr, ok := params[Handle]; ok { - if tb.Meta().IsCommonHandle { - return nil, errors.BadRequestf("For clustered index tables, please use query strings to specify the column values.") - } - intHandle, err := strconv.ParseInt(intHandleStr, 0, 64) - if err != nil { - return nil, errors.Trace(err) - } - handle = kv.IntHandle(intHandle) - } else { - tblInfo := tb.Meta() - pkIdx := tables.FindPrimaryIndex(tblInfo) - if pkIdx == nil || !tblInfo.IsCommonHandle { - return nil, errors.BadRequestf("Clustered common handle not found.") - } - cols := tblInfo.Cols() - pkCols := make([]*model.ColumnInfo, 0, len(pkIdx.Columns)) - for _, idxCol := range pkIdx.Columns { - pkCols = append(pkCols, cols[idxCol.Offset]) - } - sc := stmtctx.NewStmtCtx() - sc.SetTimeZone(time.UTC) - pkDts, err := t.formValue2DatumRow(sc, values, pkCols) - if err != nil { - return nil, errors.Trace(err) - } - tablecodec.TruncateIndexValues(tblInfo, pkIdx, pkDts) - var handleBytes []byte - handleBytes, err = codec.EncodeKey(sc.TimeZone(), nil, pkDts...) - err = sc.HandleError(err) - if err != nil { - return nil, errors.Trace(err) - } - handle, err = kv.NewCommonHandle(handleBytes) - if err != nil { - return nil, errors.Trace(err) - } - } - return handle, nil -} - -// GetMvccByIdxValue gets the mvcc by the index value. -func (t *TikvHandlerTool) GetMvccByIdxValue(idx table.Index, values url.Values, idxCols []*model.ColumnInfo, handle kv.Handle) ([]*helper.MvccKV, error) { - // HTTP request is not a database session, set timezone to UTC directly here. - // See https://github.com/pingcap/tidb/blob/master/docs/tidb_http_api.md for more details. - sc := stmtctx.NewStmtCtxWithTimeZone(time.UTC) - idxRow, err := t.formValue2DatumRow(sc, values, idxCols) - if err != nil { - return nil, errors.Trace(err) - } - encodedKey, _, err := idx.GenIndexKey(sc.ErrCtx(), sc.TimeZone(), idxRow, handle, nil) - if err != nil { - return nil, errors.Trace(err) - } - data, err := t.GetMvccByEncodedKey(encodedKey) - if err != nil { - return nil, err - } - regionID, err := t.GetRegionIDByKey(encodedKey) - if err != nil { - return nil, err - } - idxData := &helper.MvccKV{Key: strings.ToUpper(hex.EncodeToString(encodedKey)), RegionID: regionID, Value: data} - tablecodec.IndexKey2TempIndexKey(encodedKey) - data, err = t.GetMvccByEncodedKey(encodedKey) - if err != nil { - return nil, err - } - regionID, err = t.GetRegionIDByKey(encodedKey) - if err != nil { - return nil, err - } - tempIdxData := &helper.MvccKV{Key: strings.ToUpper(hex.EncodeToString(encodedKey)), RegionID: regionID, Value: data} - return append([]*helper.MvccKV{}, idxData, tempIdxData), err -} - -// formValue2DatumRow converts URL query string to a Datum Row. 
-func (*TikvHandlerTool) formValue2DatumRow(sc *stmtctx.StatementContext, values url.Values, idxCols []*model.ColumnInfo) ([]types.Datum, error) { - data := make([]types.Datum, len(idxCols)) - for i, col := range idxCols { - colName := col.Name.String() - vals, ok := values[colName] - if !ok { - return nil, errors.BadRequestf("Missing value for index column %s.", colName) - } - - switch len(vals) { - case 0: - data[i].SetNull() - case 1: - bDatum := types.NewStringDatum(vals[0]) - cDatum, err := bDatum.ConvertTo(sc.TypeCtx(), &col.FieldType) - if err != nil { - return nil, errors.Trace(err) - } - data[i] = cDatum - default: - return nil, errors.BadRequestf("Invalid query form for column '%s', it's values are %v."+ - " Column value should be unique for one index record.", colName, vals) - } - } - return data, nil -} - -// GetTableID gets the table ID by the database name and table name. -func (t *TikvHandlerTool) GetTableID(dbName, tableName string) (int64, error) { - tbl, err := t.GetTable(dbName, tableName) - if err != nil { - return 0, errors.Trace(err) - } - return tbl.GetPhysicalID(), nil -} - -// GetTable gets the table by the database name and table name. -func (t *TikvHandlerTool) GetTable(dbName, tableName string) (table.PhysicalTable, error) { - schema, err := t.Schema() - if err != nil { - return nil, errors.Trace(err) - } - tableName, partitionName := ExtractTableAndPartitionName(tableName) - tableVal, err := schema.TableByName(context.Background(), model.NewCIStr(dbName), model.NewCIStr(tableName)) - if err != nil { - return nil, errors.Trace(err) - } - return t.GetPartition(tableVal, partitionName) -} - -// GetPartition gets the partition by the table and partition name. -func (*TikvHandlerTool) GetPartition(tableVal table.Table, partitionName string) (table.PhysicalTable, error) { - if pt, ok := tableVal.(table.PartitionedTable); ok { - if partitionName == "" { - return tableVal.(table.PhysicalTable), errors.New("work on partitioned table, please specify the table name like this: table(partition)") - } - tblInfo := pt.Meta() - pid, err := tables.FindPartitionByName(tblInfo, partitionName) - if err != nil { - return nil, errors.Trace(err) - } - return pt.GetPartition(pid), nil - } - if partitionName != "" { - return nil, fmt.Errorf("%s is not a partitionted table", tableVal.Meta().Name) - } - return tableVal.(table.PhysicalTable), nil -} - -// Schema gets the schema. -func (t *TikvHandlerTool) Schema() (infoschema.InfoSchema, error) { - dom, err := session.GetDomain(t.Store) - if err != nil { - return nil, err - } - return dom.InfoSchema(), nil -} - -// HandleMvccGetByHex handles the request of getting mvcc by hex encoded key. 
-func (t *TikvHandlerTool) HandleMvccGetByHex(params map[string]string) (*mvccKV, error) { - encodedKey, err := hex.DecodeString(params[HexKey]) - if err != nil { - return nil, errors.Trace(err) - } - data, err := t.GetMvccByEncodedKey(encodedKey) - if err != nil { - return nil, errors.Trace(err) - } - regionID, err := t.GetRegionIDByKey(encodedKey) - if err != nil { - return nil, err - } - return &mvccKV{Key: strings.ToUpper(params[HexKey]), Value: data, RegionID: regionID}, nil -} - -// RegionMeta contains a region's peer detail -type RegionMeta struct { - ID uint64 `json:"region_id"` - Leader *metapb.Peer `json:"leader"` - Peers []*metapb.Peer `json:"peers"` - RegionEpoch *metapb.RegionEpoch `json:"region_epoch"` -} - -// GetRegionsMeta gets regions meta by regionIDs -func (t *TikvHandlerTool) GetRegionsMeta(regionIDs []uint64) ([]RegionMeta, error) { - regions := make([]RegionMeta, len(regionIDs)) - for i, regionID := range regionIDs { - region, err := t.RegionCache.PDClient().GetRegionByID(context.TODO(), regionID) - if err != nil { - return nil, errors.Trace(err) - } - - failpoint.Inject("errGetRegionByIDEmpty", func(val failpoint.Value) { - if val.(bool) { - region.Meta = nil - } - }) - - if region.Meta == nil { - return nil, errors.Errorf("region not found for regionID %q", regionID) - } - regions[i] = RegionMeta{ - ID: regionID, - Leader: region.Leader, - Peers: region.Meta.Peers, - RegionEpoch: region.Meta.RegionEpoch, - } - } - return regions, nil -} diff --git a/pkg/server/http_status.go b/pkg/server/http_status.go index e77fe87e0bbad..335b3855dc68c 100644 --- a/pkg/server/http_status.go +++ b/pkg/server/http_status.go @@ -420,14 +420,14 @@ func (s *Server) startHTTPServer() { }) // failpoint is enabled only for tests so we can add some http APIs here for tests. - if _, _err_ := failpoint.Eval(_curpkg_("enableTestAPI")); _err_ == nil { + failpoint.Inject("enableTestAPI", func() { router.PathPrefix("/fail/").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r.URL.Path = strings.TrimPrefix(r.URL.Path, "/fail") new(failpoint.HttpHandler).ServeHTTP(w, r) }) router.Handle("/test/{mod}/{op}", tikvhandler.NewTestHandler(tikvHandlerTool, 0)) - } + }) // ddlHook is enabled only for tests so we can substitute the callback in the DDL. router.Handle("/test/ddl/hook", tikvhandler.DDLHookHandler{}) diff --git a/pkg/server/http_status.go__failpoint_stash__ b/pkg/server/http_status.go__failpoint_stash__ deleted file mode 100644 index 335b3855dc68c..0000000000000 --- a/pkg/server/http_status.go__failpoint_stash__ +++ /dev/null @@ -1,613 +0,0 @@ -// Copyright 2017 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package server - -import ( - "archive/zip" - "bytes" - "context" - "crypto/tls" - "crypto/x509" - "encoding/json" - "fmt" - "io" - "net" - "net/http" - "net/http/pprof" - "net/url" - "runtime" - rpprof "runtime/pprof" - "strconv" - "strings" - "sync" - "time" - - "github.com/gorilla/mux" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/fn" - pb "github.com/pingcap/kvproto/pkg/autoid" - autoid "github.com/pingcap/tidb/pkg/autoid_service" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/server/handler" - "github.com/pingcap/tidb/pkg/server/handler/optimizor" - "github.com/pingcap/tidb/pkg/server/handler/tikvhandler" - "github.com/pingcap/tidb/pkg/server/handler/ttlhandler" - util2 "github.com/pingcap/tidb/pkg/server/internal/util" - "github.com/pingcap/tidb/pkg/session" - "github.com/pingcap/tidb/pkg/store" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/cpuprofile" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "github.com/pingcap/tidb/pkg/util/printer" - "github.com/pingcap/tidb/pkg/util/versioninfo" - "github.com/prometheus/client_golang/prometheus/promhttp" - "github.com/soheilhy/cmux" - "github.com/tiancaiamao/appdash/traceapp" - "go.uber.org/zap" - "google.golang.org/grpc/channelz/service" - static "sourcegraph.com/sourcegraph/appdash-data" -) - -const defaultStatusPort = 10080 - -func (s *Server) startStatusHTTP() error { - err := s.initHTTPListener() - if err != nil { - return err - } - go s.startHTTPServer() - return nil -} - -func serveError(w http.ResponseWriter, status int, txt string) { - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.Header().Set("X-Go-Pprof", "1") - w.Header().Del("Content-Disposition") - w.WriteHeader(status) - _, err := fmt.Fprintln(w, txt) - terror.Log(err) -} - -func sleepWithCtx(ctx context.Context, d time.Duration) { - select { - case <-time.After(d): - case <-ctx.Done(): - } -} - -func (s *Server) listenStatusHTTPServer() error { - s.statusAddr = net.JoinHostPort(s.cfg.Status.StatusHost, strconv.Itoa(int(s.cfg.Status.StatusPort))) - if s.cfg.Status.StatusPort == 0 && !RunInGoTest { - s.statusAddr = net.JoinHostPort(s.cfg.Status.StatusHost, strconv.Itoa(defaultStatusPort)) - } - - logutil.BgLogger().Info("for status and metrics report", zap.String("listening on addr", s.statusAddr)) - clusterSecurity := s.cfg.Security.ClusterSecurity() - tlsConfig, err := clusterSecurity.ToTLSConfig() - if err != nil { - logutil.BgLogger().Error("invalid TLS config", zap.Error(err)) - return errors.Trace(err) - } - tlsConfig = s.SetCNChecker(tlsConfig) - - if tlsConfig != nil { - // we need to manage TLS here for cmux to distinguish between HTTP and gRPC. 
- s.statusListener, err = tls.Listen("tcp", s.statusAddr, tlsConfig) - } else { - s.statusListener, err = net.Listen("tcp", s.statusAddr) - } - if err != nil { - logutil.BgLogger().Info("listen failed", zap.Error(err)) - return errors.Trace(err) - } else if RunInGoTest && s.cfg.Status.StatusPort == 0 { - s.statusAddr = s.statusListener.Addr().String() - s.cfg.Status.StatusPort = uint(s.statusListener.Addr().(*net.TCPAddr).Port) - } - return nil -} - -// Ballast try to reduce the GC frequency by using Ballast Object -type Ballast struct { - ballast []byte - ballastLock sync.Mutex - - maxSize int -} - -func newBallast(maxSize int) *Ballast { - var b Ballast - b.maxSize = 1024 * 1024 * 1024 * 2 - if maxSize > 0 { - b.maxSize = maxSize - } else { - // we try to use the total amount of ram as a reference to set the default ballastMaxSz - // since the fatal throw "runtime: out of memory" would never yield to `recover` - totalRAMSz, err := memory.MemTotal() - if err != nil { - logutil.BgLogger().Error("failed to get the total amount of RAM on this system", zap.Error(err)) - } else { - maxSzAdvice := totalRAMSz >> 2 - if uint64(b.maxSize) > maxSzAdvice { - b.maxSize = int(maxSzAdvice) - } - } - } - return &b -} - -// GetSize get the size of ballast object -func (b *Ballast) GetSize() int { - var sz int - b.ballastLock.Lock() - sz = len(b.ballast) - b.ballastLock.Unlock() - return sz -} - -// SetSize set the size of ballast object -func (b *Ballast) SetSize(newSz int) error { - if newSz < 0 { - return fmt.Errorf("newSz cannot be negative: %d", newSz) - } - if newSz > b.maxSize { - return fmt.Errorf("newSz cannot be bigger than %d but it has value %d", b.maxSize, newSz) - } - b.ballastLock.Lock() - b.ballast = make([]byte, newSz) - b.ballastLock.Unlock() - return nil -} - -// GenHTTPHandler generate a HTTP handler to get/set the size of this ballast object -func (b *Ballast) GenHTTPHandler() func(w http.ResponseWriter, r *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - _, err := w.Write([]byte(strconv.Itoa(b.GetSize()))) - terror.Log(err) - case http.MethodPost: - body, err := io.ReadAll(r.Body) - if err != nil { - terror.Log(err) - return - } - newSz, err := strconv.Atoi(string(body)) - if err == nil { - err = b.SetSize(newSz) - } - if err != nil { - w.WriteHeader(http.StatusBadRequest) - errStr := err.Error() - if _, err := w.Write([]byte(errStr)); err != nil { - terror.Log(err) - } - return - } - } - } -} - -func (s *Server) startHTTPServer() { - router := mux.NewRouter() - - router.HandleFunc("/status", s.handleStatus).Name("Status") - // HTTP path for prometheus. - router.Handle("/metrics", promhttp.Handler()).Name("Metrics") - - // HTTP path for dump statistics. - router.Handle("/stats/dump/{db}/{table}", s.newStatsHandler()). - Name("StatsDump") - router.Handle("/stats/dump/{db}/{table}/{snapshot}", s.newStatsHistoryHandler()). 
- Name("StatsHistoryDump") - - router.Handle("/plan_replayer/dump/{filename}", s.newPlanReplayerHandler()).Name("PlanReplayerDump") - router.Handle("/extract_task/dump", s.newExtractServeHandler()).Name("ExtractTaskDump") - - router.Handle("/optimize_trace/dump/{filename}", s.newOptimizeTraceHandler()).Name("OptimizeTraceDump") - - tikvHandlerTool := s.NewTikvHandlerTool() - router.Handle("/settings", tikvhandler.NewSettingsHandler(tikvHandlerTool)).Name("Settings") - router.Handle("/binlog/recover", tikvhandler.BinlogRecover{}).Name("BinlogRecover") - - router.Handle("/schema", tikvhandler.NewSchemaHandler(tikvHandlerTool)).Name("Schema") - router.Handle("/schema/{db}", tikvhandler.NewSchemaHandler(tikvHandlerTool)) - router.Handle("/schema/{db}/{table}", tikvhandler.NewSchemaHandler(tikvHandlerTool)) - router.Handle("/tables/{colID}/{colTp}/{colFlag}/{colLen}", tikvhandler.ValueHandler{}) - - router.Handle("/schema_storage", tikvhandler.NewSchemaStorageHandler(tikvHandlerTool)).Name("Schema Storage") - router.Handle("/schema_storage/{db}", tikvhandler.NewSchemaStorageHandler(tikvHandlerTool)) - router.Handle("/schema_storage/{db}/{table}", tikvhandler.NewSchemaStorageHandler(tikvHandlerTool)) - - router.Handle("/ddl/history", tikvhandler.NewDDLHistoryJobHandler(tikvHandlerTool)).Name("DDL_History") - router.Handle("/ddl/owner/resign", tikvhandler.NewDDLResignOwnerHandler(tikvHandlerTool.Store.(kv.Storage))).Name("DDL_Owner_Resign") - - // HTTP path for get the TiDB config - router.Handle("/config", fn.Wrap(func() (*config.Config, error) { - return config.GetGlobalConfig(), nil - })) - router.Handle("/labels", tikvhandler.LabelHandler{}).Name("Labels") - - // HTTP path for get server info. - router.Handle("/info", tikvhandler.NewServerInfoHandler(tikvHandlerTool)).Name("Info") - router.Handle("/info/all", tikvhandler.NewAllServerInfoHandler(tikvHandlerTool)).Name("InfoALL") - // HTTP path for get db and table info that is related to the tableID. - router.Handle("/db-table/{tableID}", tikvhandler.NewDBTableHandler(tikvHandlerTool)) - // HTTP path for get table tiflash replica info. - router.Handle("/tiflash/replica-deprecated", tikvhandler.NewFlashReplicaHandler(tikvHandlerTool)) - - // HTTP path for upgrade operations. - router.Handle("/upgrade/{op}", handler.NewClusterUpgradeHandler(tikvHandlerTool.Store.(kv.Storage))).Name("upgrade operations") - - if s.cfg.Store == "tikv" { - // HTTP path for tikv. 
- router.Handle("/tables/{db}/{table}/regions", tikvhandler.NewTableHandler(tikvHandlerTool, tikvhandler.OpTableRegions)) - router.Handle("/tables/{db}/{table}/ranges", tikvhandler.NewTableHandler(tikvHandlerTool, tikvhandler.OpTableRanges)) - router.Handle("/tables/{db}/{table}/scatter", tikvhandler.NewTableHandler(tikvHandlerTool, tikvhandler.OpTableScatter)) - router.Handle("/tables/{db}/{table}/stop-scatter", tikvhandler.NewTableHandler(tikvHandlerTool, tikvhandler.OpStopTableScatter)) - router.Handle("/tables/{db}/{table}/disk-usage", tikvhandler.NewTableHandler(tikvHandlerTool, tikvhandler.OpTableDiskUsage)) - router.Handle("/regions/meta", tikvhandler.NewRegionHandler(tikvHandlerTool)).Name("RegionsMeta") - router.Handle("/regions/hot", tikvhandler.NewRegionHandler(tikvHandlerTool)).Name("RegionHot") - router.Handle("/regions/{regionID}", tikvhandler.NewRegionHandler(tikvHandlerTool)) - } - - // HTTP path for get MVCC info - router.Handle("/mvcc/key/{db}/{table}", tikvhandler.NewMvccTxnHandler(tikvHandlerTool, tikvhandler.OpMvccGetByKey)) - router.Handle("/mvcc/key/{db}/{table}/{handle}", tikvhandler.NewMvccTxnHandler(tikvHandlerTool, tikvhandler.OpMvccGetByKey)) - router.Handle("/mvcc/txn/{startTS}/{db}/{table}", tikvhandler.NewMvccTxnHandler(tikvHandlerTool, tikvhandler.OpMvccGetByTxn)) - router.Handle("/mvcc/hex/{hexKey}", tikvhandler.NewMvccTxnHandler(tikvHandlerTool, tikvhandler.OpMvccGetByHex)) - router.Handle("/mvcc/index/{db}/{table}/{index}", tikvhandler.NewMvccTxnHandler(tikvHandlerTool, tikvhandler.OpMvccGetByIdx)) - router.Handle("/mvcc/index/{db}/{table}/{index}/{handle}", tikvhandler.NewMvccTxnHandler(tikvHandlerTool, tikvhandler.OpMvccGetByIdx)) - - // HTTP path for generate metric profile. - router.Handle("/metrics/profile", tikvhandler.NewProfileHandler(tikvHandlerTool)) - // HTTP path for web UI. - if host, port, err := net.SplitHostPort(s.statusAddr); err == nil { - if host == "" { - host = "localhost" - } - baseURL := &url.URL{ - Scheme: util.InternalHTTPSchema(), - Host: fmt.Sprintf("%s:%s", host, port), - } - router.HandleFunc("/web/trace", traceapp.HandleTiDB).Name("Trace Viewer") - sr := router.PathPrefix("/web/trace/").Subrouter() - if _, err := traceapp.New(traceapp.NewRouter(sr), baseURL); err != nil { - logutil.BgLogger().Error("new failed", zap.Error(err)) - } - router.PathPrefix("/static/").Handler(http.StripPrefix("/static", http.FileServer(static.Data))) - } - - router.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) - router.HandleFunc("/debug/pprof/profile", cpuprofile.ProfileHTTPHandler) - router.HandleFunc("/debug/pprof/symbol", pprof.Symbol) - router.HandleFunc("/debug/pprof/trace", pprof.Trace) - // Other /debug/pprof paths not covered above are redirected to pprof.Index. 
- router.PathPrefix("/debug/pprof/").HandlerFunc(pprof.Index) - - ballast := newBallast(s.cfg.MaxBallastObjectSize) - { - err := ballast.SetSize(s.cfg.BallastObjectSize) - if err != nil { - logutil.BgLogger().Error("set initial ballast object size failed", zap.Error(err)) - } - } - router.HandleFunc("/debug/ballast-object-sz", ballast.GenHTTPHandler()) - - router.HandleFunc("/debug/gogc", func(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet: - _, err := w.Write([]byte(strconv.Itoa(util.GetGOGC()))) - terror.Log(err) - case http.MethodPost: - body, err := io.ReadAll(r.Body) - if err != nil { - terror.Log(err) - return - } - - val, err := strconv.Atoi(string(body)) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - if _, err := w.Write([]byte(err.Error())); err != nil { - terror.Log(err) - } - return - } - - util.SetGOGC(val) - } - }) - - router.HandleFunc("/debug/zip", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="tidb_debug"`+time.Now().Format("20060102150405")+".zip")) - - // dump goroutine/heap/mutex - items := []struct { - name string - gc int - debug int - second int - }{ - {name: "goroutine", debug: 2}, - {name: "heap", gc: 1}, - {name: "mutex"}, - } - zw := zip.NewWriter(w) - for _, item := range items { - p := rpprof.Lookup(item.name) - if p == nil { - serveError(w, http.StatusNotFound, "Unknown profile") - return - } - if item.gc > 0 { - runtime.GC() - } - fw, err := zw.Create(item.name) - if err != nil { - serveError(w, http.StatusInternalServerError, fmt.Sprintf("Create zipped %s fail: %v", item.name, err)) - return - } - err = p.WriteTo(fw, item.debug) - terror.Log(err) - } - - // dump profile - fw, err := zw.Create("profile") - if err != nil { - serveError(w, http.StatusInternalServerError, fmt.Sprintf("Create zipped %s fail: %v", "profile", err)) - return - } - pc := cpuprofile.NewCollector() - if err := pc.StartCPUProfile(fw); err != nil { - serveError(w, http.StatusInternalServerError, - fmt.Sprintf("Could not enable CPU profiling: %s", err)) - return - } - sec, err := strconv.ParseInt(r.FormValue("seconds"), 10, 64) - if sec <= 0 || err != nil { - sec = 10 - } - sleepWithCtx(r.Context(), time.Duration(sec)*time.Second) - err = pc.StopCPUProfile() - if err != nil { - serveError(w, http.StatusInternalServerError, - fmt.Sprintf("Could not enable CPU profiling: %s", err)) - return - } - - // dump config - fw, err = zw.Create("config") - if err != nil { - serveError(w, http.StatusInternalServerError, fmt.Sprintf("Create zipped %s fail: %v", "config", err)) - return - } - js, err := json.MarshalIndent(config.GetGlobalConfig(), "", " ") - if err != nil { - serveError(w, http.StatusInternalServerError, fmt.Sprintf("get config info fail%v", err)) - return - } - _, err = fw.Write(js) - terror.Log(err) - - // dump version - fw, err = zw.Create("version") - if err != nil { - serveError(w, http.StatusInternalServerError, fmt.Sprintf("Create zipped %s fail: %v", "version", err)) - return - } - _, err = fw.Write([]byte(printer.GetTiDBInfo())) - terror.Log(err) - - err = zw.Close() - terror.Log(err) - }) - - // failpoint is enabled only for tests so we can add some http APIs here for tests. 
- failpoint.Inject("enableTestAPI", func() { - router.PathPrefix("/fail/").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r.URL.Path = strings.TrimPrefix(r.URL.Path, "/fail") - new(failpoint.HttpHandler).ServeHTTP(w, r) - }) - - router.Handle("/test/{mod}/{op}", tikvhandler.NewTestHandler(tikvHandlerTool, 0)) - }) - - // ddlHook is enabled only for tests so we can substitute the callback in the DDL. - router.Handle("/test/ddl/hook", tikvhandler.DDLHookHandler{}) - - // ttlJobTriggerHandler is enabled only for tests, so we can accelerate the schedule of TTL job - router.Handle("/test/ttl/trigger/{db}/{table}", ttlhandler.NewTTLJobTriggerHandler(tikvHandlerTool.Store.(kv.Storage))) - - var ( - httpRouterPage bytes.Buffer - pathTemplate string - err error - ) - httpRouterPage.WriteString("TiDB Status and Metrics Report

TiDB Status and Metrics Report

") - err = router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error { - pathTemplate, err = route.GetPathTemplate() - if err != nil { - logutil.BgLogger().Error("get HTTP router path failed", zap.Error(err)) - } - name := route.GetName() - // If the name attribute is not set, GetName returns "". - // "traceapp.xxx" are introduced by the traceapp package and are also ignored. - if name != "" && !strings.HasPrefix(name, "traceapp") && err == nil { - httpRouterPage.WriteString("") - } - return nil - }) - if err != nil { - logutil.BgLogger().Error("generate root failed", zap.Error(err)) - } - httpRouterPage.WriteString("") - httpRouterPage.WriteString("
" + name + "
Debug
") - router.HandleFunc("/", func(responseWriter http.ResponseWriter, _ *http.Request) { - _, err = responseWriter.Write(httpRouterPage.Bytes()) - if err != nil { - logutil.BgLogger().Error("write HTTP index page failed", zap.Error(err)) - } - }) - - serverMux := http.NewServeMux() - serverMux.Handle("/", router) - s.startStatusServerAndRPCServer(serverMux) -} - -func (s *Server) startStatusServerAndRPCServer(serverMux *http.ServeMux) { - m := cmux.New(s.statusListener) - // Match connections in order: - // First HTTP, and otherwise grpc. - httpL := m.Match(cmux.HTTP1Fast()) - grpcL := m.Match(cmux.Any()) - - statusServer := &http.Server{Addr: s.statusAddr, Handler: util2.NewCorsHandler(serverMux, s.cfg)} - grpcServer := NewRPCServer(s.cfg, s.dom, s) - service.RegisterChannelzServiceToServer(grpcServer) - if s.cfg.Store == "tikv" { - keyspaceName := config.GetGlobalKeyspaceName() - for { - var fullPath string - if keyspaceName == "" { - fullPath = fmt.Sprintf("%s://%s", s.cfg.Store, s.cfg.Path) - } else { - fullPath = fmt.Sprintf("%s://%s?keyspaceName=%s", s.cfg.Store, s.cfg.Path, keyspaceName) - } - store, err := store.New(fullPath) - if err != nil { - logutil.BgLogger().Error("new tikv store fail", zap.Error(err)) - break - } - ebd, ok := store.(kv.EtcdBackend) - if !ok { - break - } - etcdAddr, err := ebd.EtcdAddrs() - if err != nil { - logutil.BgLogger().Error("tikv store not etcd background", zap.Error(err)) - break - } - selfAddr := net.JoinHostPort(s.cfg.AdvertiseAddress, strconv.Itoa(int(s.cfg.Status.StatusPort))) - service := autoid.New(selfAddr, etcdAddr, store, ebd.TLSConfig()) - logutil.BgLogger().Info("register auto service at", zap.String("addr", selfAddr)) - pb.RegisterAutoIDAllocServer(grpcServer, service) - s.autoIDService = service - break - } - } - - s.statusServer = statusServer - s.grpcServer = grpcServer - - go util.WithRecovery(func() { - err := grpcServer.Serve(grpcL) - logutil.BgLogger().Error("grpc server error", zap.Error(err)) - }, nil) - - go util.WithRecovery(func() { - err := statusServer.Serve(httpL) - logutil.BgLogger().Error("http server error", zap.Error(err)) - }, nil) - - err := m.Serve() - if err != nil { - logutil.BgLogger().Error("start status/rpc server error", zap.Error(err)) - } -} - -// SetCNChecker set the CN checker for server. -func (s *Server) SetCNChecker(tlsConfig *tls.Config) *tls.Config { - if tlsConfig != nil && len(s.cfg.Security.ClusterVerifyCN) != 0 { - checkCN := make(map[string]struct{}) - for _, cn := range s.cfg.Security.ClusterVerifyCN { - cn = strings.TrimSpace(cn) - checkCN[cn] = struct{}{} - } - tlsConfig.VerifyPeerCertificate = func(_ [][]byte, verifiedChains [][]*x509.Certificate) error { - for _, chain := range verifiedChains { - if len(chain) != 0 { - if _, match := checkCN[chain[0].Subject.CommonName]; match { - return nil - } - } - } - return errors.Errorf("client certificate authentication failed. The Common Name from the client certificate was not found in the configuration cluster-verify-cn with value: %s", s.cfg.Security.ClusterVerifyCN) - } - tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert - } - return tlsConfig -} - -// Status of TiDB. -type Status struct { - Connections int `json:"connections"` - Version string `json:"version"` - GitHash string `json:"git_hash"` -} - -func (s *Server) handleStatus(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "application/json") - // If the server is in the process of shutting down, return a non-200 status. 
- // It is important not to return Status{} as acquiring the s.ConnectionCount() - // acquires a lock that may already be held by the shutdown process. - if !s.health.Load() { - w.WriteHeader(http.StatusInternalServerError) - return - } - st := Status{ - Connections: s.ConnectionCount(), - Version: mysql.ServerVersion, - GitHash: versioninfo.TiDBGitHash, - } - js, err := json.Marshal(st) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - logutil.BgLogger().Error("encode json failed", zap.Error(err)) - return - } - _, err = w.Write(js) - terror.Log(errors.Trace(err)) -} - -func (s *Server) newStatsHandler() *optimizor.StatsHandler { - store, ok := s.driver.(*TiDBDriver) - if !ok { - panic("Illegal driver") - } - - do, err := session.GetDomain(store.store) - if err != nil { - panic("Failed to get domain") - } - return optimizor.NewStatsHandler(do) -} - -func (s *Server) newStatsHistoryHandler() *optimizor.StatsHistoryHandler { - store, ok := s.driver.(*TiDBDriver) - if !ok { - panic("Illegal driver") - } - - do, err := session.GetDomain(store.store) - if err != nil { - panic("Failed to get domain") - } - return optimizor.NewStatsHistoryHandler(do) -} diff --git a/pkg/session/binding__failpoint_binding__.go b/pkg/session/binding__failpoint_binding__.go deleted file mode 100644 index 9ef59b452261c..0000000000000 --- a/pkg/session/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package session - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/session/nontransactional.go b/pkg/session/nontransactional.go index fe7eb3680c432..390de607cdaf4 100644 --- a/pkg/session/nontransactional.go +++ b/pkg/session/nontransactional.go @@ -417,11 +417,11 @@ func doOneJob(ctx context.Context, job *job, totalJobCount int, options statemen rs, err := se.ExecuteStmt(ctx, options.stmt.DMLStmt) // collect errors - if val, _err_ := failpoint.Eval(_curpkg_("batchDMLError")); _err_ == nil { + failpoint.Inject("batchDMLError", func(val failpoint.Value) { if val.(bool) { err = errors.New("injected batch(non-transactional) DML error") } - } + }) if err != nil { logutil.Logger(ctx).Error("Non-transactional DML SQL failed", zap.String("job", dmlSQLInLog), zap.Error(err), zap.Int("jobID", job.jobID), zap.Int("jobSize", job.jobSize)) job.err = err diff --git a/pkg/session/nontransactional.go__failpoint_stash__ b/pkg/session/nontransactional.go__failpoint_stash__ deleted file mode 100644 index 390de607cdaf4..0000000000000 --- a/pkg/session/nontransactional.go__failpoint_stash__ +++ /dev/null @@ -1,847 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
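The nontransactional.go hunk above swaps the generated failpoint.Eval(_curpkg_("batchDMLError")) block back to the failpoint.Inject marker that belongs in the source tree. As a reminder of how such a marker is exercised, a small self-contained sketch follows; the failpoint name and the surrounding function are hypothetical, not part of this patch:

	package demo

	import (
		"errors"

		"github.com/pingcap/failpoint"
	)

	// doBatch stands in for one batch job. The closure body is only evaluated
	// when the "demoBatchError" failpoint has been enabled.
	func doBatch() error {
		var err error
		failpoint.Inject("demoBatchError", func(val failpoint.Value) {
			if on, ok := val.(bool); ok && on {
				err = errors.New("injected batch DML error")
			}
		})
		return err
	}

With failpoints activated by failpoint-ctl (the rewritten form is exactly what the *__failpoint_stash__ files preserve), a test switches the behavior on with failpoint.Enable("<package path>/demoBatchError", "return(true)") and off again with failpoint.Disable.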
- -package session - -import ( - "context" - "fmt" - "math" - "strings" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/errno" - "github.com/pingcap/tidb/pkg/parser" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/format" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/opcode" - "github.com/pingcap/tidb/pkg/planner/core" - session_metrics "github.com/pingcap/tidb/pkg/session/metrics" - sessiontypes "github.com/pingcap/tidb/pkg/session/types" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/types" - driver "github.com/pingcap/tidb/pkg/types/parser_driver" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/collate" - "github.com/pingcap/tidb/pkg/util/dbterror" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "github.com/pingcap/tidb/pkg/util/redact" - "github.com/pingcap/tidb/pkg/util/sqlexec" - "go.uber.org/zap" -) - -// ErrNonTransactionalJobFailure is the error when a non-transactional job fails. The error is returned and following jobs are canceled. -var ErrNonTransactionalJobFailure = dbterror.ClassSession.NewStd(errno.ErrNonTransactionalJobFailure) - -// job: handle keys in [start, end] -type job struct { - start types.Datum - end types.Datum - err error - jobID int - jobSize int // it can be inaccurate if there are concurrent writes - sql string -} - -// statementBuildInfo contains information that is needed to build the split statement in a job -type statementBuildInfo struct { - stmt *ast.NonTransactionalDMLStmt - shardColumnType types.FieldType - shardColumnRefer *ast.ResultField - originalCondition ast.ExprNode -} - -func (j job) String(redacted string) string { - return fmt.Sprintf("job id: %d, estimated size: %d, sql: %s", j.jobID, j.jobSize, redact.String(redacted, j.sql)) -} - -// HandleNonTransactionalDML is the entry point for a non-transactional DML statement -func HandleNonTransactionalDML(ctx context.Context, stmt *ast.NonTransactionalDMLStmt, se sessiontypes.Session) (sqlexec.RecordSet, error) { - sessVars := se.GetSessionVars() - originalReadStaleness := se.GetSessionVars().ReadStaleness - // NT-DML is a write operation, and should not be affected by read_staleness that is supposed to affect only SELECT. - sessVars.ReadStaleness = 0 - defer func() { - sessVars.ReadStaleness = originalReadStaleness - }() - err := core.Preprocess(ctx, se, stmt) - if err != nil { - return nil, err - } - if err := checkConstraint(stmt, se); err != nil { - return nil, err - } - - tableName, selectSQL, shardColumnInfo, tableSources, err := buildSelectSQL(stmt, se) - if err != nil { - return nil, err - } - - if err := checkConstraintWithShardColumn(se, stmt, tableName, shardColumnInfo, tableSources); err != nil { - return nil, err - } - - if stmt.DryRun == ast.DryRunQuery { - return buildDryRunResults(stmt.DryRun, []string{selectSQL}, se.GetSessionVars().BatchSize.MaxChunkSize) - } - - // TODO: choose an appropriate quota. - // Use the mem-quota-query as a workaround. As a result, a NT-DML may consume 2x of the memory quota. 
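HandleNonTransactionalDML above charges the job list against the session memory quota by attaching a child memory.Tracker to the session tracker and consuming a rough estimate per job (see appendNewJob further down). A reduced sketch of that parent/child accounting, using the same Tracker calls that appear in the deleted code; the labels, limit, and per-job cost are invented for illustration:

	package demo

	import (
		"fmt"

		"github.com/pingcap/tidb/pkg/util/memory"
	)

	func main() {
		// Parent tracker playing the role of the session-level MemTracker.
		parent := memory.NewTracker(1, 32<<20) // arbitrary 32 MiB budget
		// Child tracker for one statement: no limit of its own (-1), but all
		// consumption propagates to the parent while it is attached.
		child := memory.NewTracker(2, -1)
		child.AttachTo(parent)
		defer child.Detach()

		// Pretend each job costs about 64 bytes of bookkeeping.
		for i := 0; i < 1000; i++ {
			child.Consume(64)
		}
		fmt.Println("parent tracker sees", parent.BytesConsumed(), "bytes")
	}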
- memTracker := memory.NewTracker(memory.LabelForNonTransactionalDML, -1) - memTracker.AttachTo(se.GetSessionVars().MemTracker) - se.GetSessionVars().MemTracker.SetBytesLimit(se.GetSessionVars().MemQuotaQuery) - defer memTracker.Detach() - jobs, err := buildShardJobs(ctx, stmt, se, selectSQL, shardColumnInfo, memTracker) - if err != nil { - return nil, err - } - - splitStmts, err := runJobs(ctx, jobs, stmt, tableName, se, stmt.DMLStmt.WhereExpr()) - if err != nil { - return nil, err - } - if stmt.DryRun == ast.DryRunSplitDml { - return buildDryRunResults(stmt.DryRun, splitStmts, se.GetSessionVars().BatchSize.MaxChunkSize) - } - return buildExecuteResults(ctx, jobs, se.GetSessionVars().BatchSize.MaxChunkSize, se.GetSessionVars().EnableRedactLog) -} - -// we require: -// (1) in an update statement, shard column cannot be updated -// -// Note: this is not a comprehensive check. -// We do this to help user prevent some easy mistakes, at an acceptable maintenance cost. -func checkConstraintWithShardColumn(se sessiontypes.Session, stmt *ast.NonTransactionalDMLStmt, - tableName *ast.TableName, shardColumnInfo *model.ColumnInfo, tableSources []*ast.TableSource) error { - switch s := stmt.DMLStmt.(type) { - case *ast.UpdateStmt: - if err := checkUpdateShardColumn(se, s.List, shardColumnInfo, tableName, tableSources, true); err != nil { - return err - } - case *ast.InsertStmt: - // FIXME: is it possible to happen? - // `insert into t select * from t on duplicate key update id = id + 1` will return an ambiguous column error? - if err := checkUpdateShardColumn(se, s.OnDuplicate, shardColumnInfo, tableName, tableSources, false); err != nil { - return err - } - default: - } - return nil -} - -// shard column should not be updated. -func checkUpdateShardColumn(se sessiontypes.Session, assignments []*ast.Assignment, shardColumnInfo *model.ColumnInfo, - tableName *ast.TableName, tableSources []*ast.TableSource, isUpdate bool) error { - // if the table has alias, the alias is used in assignments, and we should use aliased name to compare - aliasedShardColumnTableName := tableName.Name.L - for _, tableSource := range tableSources { - if tableSource.Source.(*ast.TableName).Name.L == aliasedShardColumnTableName && tableSource.AsName.L != "" { - aliasedShardColumnTableName = tableSource.AsName.L - } - } - - if shardColumnInfo == nil { - return nil - } - for _, assignment := range assignments { - sameDB := (assignment.Column.Schema.L == tableName.Schema.L) || - (assignment.Column.Schema.L == "" && tableName.Schema.L == se.GetSessionVars().CurrentDB) - if !sameDB { - continue - } - sameTable := (assignment.Column.Table.L == aliasedShardColumnTableName) || (isUpdate && len(tableSources) == 1) - if !sameTable { - continue - } - if assignment.Column.Name.L == shardColumnInfo.Name.L { - return errors.New("Non-transactional DML, shard column cannot be updated") - } - } - return nil -} - -func checkConstraint(stmt *ast.NonTransactionalDMLStmt, se sessiontypes.Session) error { - sessVars := se.GetSessionVars() - if !(sessVars.IsAutocommit() && !sessVars.InTxn()) { - return errors.Errorf("non-transactional DML can only run in auto-commit mode. 
auto-commit:%v, inTxn:%v", - se.GetSessionVars().IsAutocommit(), se.GetSessionVars().InTxn()) - } - if variable.EnableBatchDML.Load() && sessVars.DMLBatchSize > 0 && (sessVars.BatchDelete || sessVars.BatchInsert) { - return errors.Errorf("can't run non-transactional DML with batch-dml") - } - - if sessVars.ReadConsistency.IsWeak() { - return errors.New("can't run non-transactional under weak read consistency") - } - if sessVars.SnapshotTS != 0 { - return errors.New("can't do non-transactional DML when tidb_snapshot is set") - } - - switch s := stmt.DMLStmt.(type) { - case *ast.DeleteStmt: - if err := checkTableRef(s.TableRefs, true); err != nil { - return err - } - if err := checkReadClauses(s.Limit, s.Order); err != nil { - return err - } - session_metrics.NonTransactionalDeleteCount.Inc() - case *ast.UpdateStmt: - if err := checkTableRef(s.TableRefs, true); err != nil { - return err - } - if err := checkReadClauses(s.Limit, s.Order); err != nil { - return err - } - session_metrics.NonTransactionalUpdateCount.Inc() - case *ast.InsertStmt: - if s.Select == nil { - return errors.New("Non-transactional insert supports insert select stmt only") - } - selectStmt, ok := s.Select.(*ast.SelectStmt) - if !ok { - return errors.New("Non-transactional insert doesn't support non-select source") - } - if err := checkTableRef(selectStmt.From, true); err != nil { - return err - } - if err := checkReadClauses(selectStmt.Limit, selectStmt.OrderBy); err != nil { - return err - } - session_metrics.NonTransactionalInsertCount.Inc() - default: - return errors.New("Unsupported DML type for non-transactional DML") - } - - return nil -} - -func checkTableRef(t *ast.TableRefsClause, allowMultipleTables bool) error { - if t == nil || t.TableRefs == nil || t.TableRefs.Left == nil { - return errors.New("table reference is nil") - } - if !allowMultipleTables && t.TableRefs.Right != nil { - return errors.New("Non-transactional statements don't support multiple tables") - } - return nil -} - -func checkReadClauses(limit *ast.Limit, order *ast.OrderByClause) error { - if limit != nil { - return errors.New("Non-transactional statements don't support limit") - } - if order != nil { - return errors.New("Non-transactional statements don't support order by") - } - return nil -} - -// single-threaded worker. work on the key range [start, end] -func runJobs(ctx context.Context, jobs []job, stmt *ast.NonTransactionalDMLStmt, - tableName *ast.TableName, se sessiontypes.Session, originalCondition ast.ExprNode) ([]string, error) { - // prepare for the construction of statement - var shardColumnRefer *ast.ResultField - var shardColumnType types.FieldType - for _, col := range tableName.TableInfo.Columns { - if col.Name.L == stmt.ShardColumn.Name.L { - shardColumnRefer = &ast.ResultField{ - Column: col, - Table: tableName.TableInfo, - DBName: tableName.Schema, - } - shardColumnType = col.FieldType - } - } - if shardColumnRefer == nil && stmt.ShardColumn.Name.L != model.ExtraHandleName.L { - return nil, errors.New("Non-transactional DML, shard column not found") - } - - splitStmts := make([]string, 0, len(jobs)) - for i := range jobs { - select { - case <-ctx.Done(): - failedJobs := make([]string, 0) - for _, job := range jobs { - if job.err != nil { - failedJobs = append(failedJobs, fmt.Sprintf("job:%s, error: %s", job.String(se.GetSessionVars().EnableRedactLog), job.err.Error())) - } - } - if len(failedJobs) == 0 { - logutil.Logger(ctx).Warn("Non-transactional DML worker exit because context canceled. 
No errors", - zap.Int("finished", i), zap.Int("total", len(jobs))) - } else { - logutil.Logger(ctx).Warn("Non-transactional DML worker exit because context canceled. Errors found", - zap.Int("finished", i), zap.Int("total", len(jobs)), zap.Strings("errors found", failedJobs)) - } - return nil, ctx.Err() - default: - } - - // _tidb_rowid - if shardColumnRefer == nil { - shardColumnType = *types.NewFieldType(mysql.TypeLonglong) - shardColumnRefer = &ast.ResultField{ - Column: model.NewExtraHandleColInfo(), - Table: tableName.TableInfo, - DBName: tableName.Schema, - } - } - stmtBuildInfo := statementBuildInfo{ - stmt: stmt, - shardColumnType: shardColumnType, - shardColumnRefer: shardColumnRefer, - originalCondition: originalCondition, - } - if stmt.DryRun == ast.DryRunSplitDml { - if i > 0 && i < len(jobs)-1 { - continue - } - splitStmt := doOneJob(ctx, &jobs[i], len(jobs), stmtBuildInfo, se, true) - splitStmts = append(splitStmts, splitStmt) - } else { - doOneJob(ctx, &jobs[i], len(jobs), stmtBuildInfo, se, false) - } - - // if the first job failed, there is a large chance that all jobs will fail. So return early. - if i == 0 && jobs[i].err != nil { - return nil, errors.Annotate(jobs[i].err, "Early return: error occurred in the first job. All jobs are canceled") - } - if jobs[i].err != nil && !se.GetSessionVars().NonTransactionalIgnoreError { - return nil, ErrNonTransactionalJobFailure.GenWithStackByArgs(jobs[i].jobID, len(jobs), jobs[i].start.String(), jobs[i].end.String(), jobs[i].String(se.GetSessionVars().EnableRedactLog), jobs[i].err.Error()) - } - } - return splitStmts, nil -} - -func doOneJob(ctx context.Context, job *job, totalJobCount int, options statementBuildInfo, se sessiontypes.Session, dryRun bool) string { - var whereCondition ast.ExprNode - - if job.start.IsNull() { - isNullCondition := &ast.IsNullExpr{ - Expr: &ast.ColumnNameExpr{ - Name: options.stmt.ShardColumn, - Refer: options.shardColumnRefer, - }, - Not: false, - } - if job.end.IsNull() { - // `where x is null` - whereCondition = isNullCondition - } else { - // `where (x <= job.end) || (x is null)` - right := &driver.ValueExpr{} - right.Type = options.shardColumnType - right.Datum = job.end - leCondition := &ast.BinaryOperationExpr{ - Op: opcode.LE, - L: &ast.ColumnNameExpr{ - Name: options.stmt.ShardColumn, - Refer: options.shardColumnRefer, - }, - R: right, - } - whereCondition = &ast.BinaryOperationExpr{ - Op: opcode.LogicOr, - L: leCondition, - R: isNullCondition, - } - } - } else { - // a normal between condition: `where x between start and end` - left := &driver.ValueExpr{} - left.Type = options.shardColumnType - left.Datum = job.start - right := &driver.ValueExpr{} - right.Type = options.shardColumnType - right.Datum = job.end - whereCondition = &ast.BetweenExpr{ - Expr: &ast.ColumnNameExpr{ - Name: options.stmt.ShardColumn, - Refer: options.shardColumnRefer, - }, - Left: left, - Right: right, - Not: false, - } - } - - if options.originalCondition == nil { - options.stmt.DMLStmt.SetWhereExpr(whereCondition) - } else { - options.stmt.DMLStmt.SetWhereExpr(&ast.BinaryOperationExpr{ - Op: opcode.LogicAnd, - L: whereCondition, - R: options.originalCondition, - }) - } - var sb strings.Builder - err := options.stmt.DMLStmt.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags| - format.RestoreNameBackQuotes| - format.RestoreSpacesAroundBinaryOperation| - format.RestoreBracketAroundBinaryOperation| - format.RestoreStringWithoutCharset, &sb)) - if err != nil { - logutil.Logger(ctx).Error("Non-transactional DML, 
failed to restore the DML statement", zap.Error(err)) - job.err = errors.New("Failed to restore the DML statement, probably because of unsupported type of the shard column") - return "" - } - dmlSQL := sb.String() - - if dryRun { - return dmlSQL - } - - job.sql = dmlSQL - logutil.Logger(ctx).Info("start a Non-transactional DML", - zap.String("job", job.String(se.GetSessionVars().EnableRedactLog)), zap.Int("totalJobCount", totalJobCount)) - dmlSQLInLog := parser.Normalize(dmlSQL, se.GetSessionVars().EnableRedactLog) - - options.stmt.DMLStmt.SetText(nil, fmt.Sprintf("/* job %v/%v */ %s", job.jobID, totalJobCount, dmlSQL)) - rs, err := se.ExecuteStmt(ctx, options.stmt.DMLStmt) - - // collect errors - failpoint.Inject("batchDMLError", func(val failpoint.Value) { - if val.(bool) { - err = errors.New("injected batch(non-transactional) DML error") - } - }) - if err != nil { - logutil.Logger(ctx).Error("Non-transactional DML SQL failed", zap.String("job", dmlSQLInLog), zap.Error(err), zap.Int("jobID", job.jobID), zap.Int("jobSize", job.jobSize)) - job.err = err - } else { - logutil.Logger(ctx).Info("Non-transactional DML SQL finished successfully", zap.Int("jobID", job.jobID), - zap.Int("jobSize", job.jobSize), zap.String("dmlSQL", dmlSQLInLog)) - } - if rs != nil { - _ = rs.Close() - } - return "" -} - -func buildShardJobs(ctx context.Context, stmt *ast.NonTransactionalDMLStmt, se sessiontypes.Session, - selectSQL string, shardColumnInfo *model.ColumnInfo, memTracker *memory.Tracker) ([]job, error) { - var shardColumnCollate string - if shardColumnInfo != nil { - shardColumnCollate = shardColumnInfo.GetCollate() - } else { - shardColumnCollate = "" - } - - // A NT-DML is not a SELECT. We ignore the SelectLimit for selectSQL so that it can read all values. - originalSelectLimit := se.GetSessionVars().SelectLimit - se.GetSessionVars().SelectLimit = math.MaxUint64 - // NT-DML is a write operation, and should not be affected by read_staleness that is supposed to affect only SELECT. 
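The rest of buildShardJobs below reads the sorted shard-column values and cuts them into jobs of roughly LIMIT rows each, stretching a job past the limit while the boundary value repeats so that a BETWEEN start AND end predicate never splits duplicates across two jobs. The same idea in a stripped-down form, on plain int64 values instead of types.Datum and chunks (a sketch of the splitting rule, not the actual implementation):

	package demo

	// jobRange mirrors the [start, end] pair each non-transactional job scans.
	type jobRange struct {
		start, end int64
	}

	// splitJobs groups sorted shard-column values into ranges of at least
	// batchSize values, extending a range while the boundary value repeats.
	func splitJobs(sorted []int64, batchSize int) []jobRange {
		if batchSize <= 0 {
			return nil // the real code rejects non-positive batch sizes
		}
		var jobs []jobRange
		for i := 0; i < len(sorted); {
			start := sorted[i]
			j := i + batchSize
			if j >= len(sorted) {
				jobs = append(jobs, jobRange{start: start, end: sorted[len(sorted)-1]})
				break
			}
			for j < len(sorted) && sorted[j] == sorted[j-1] {
				j++
			}
			jobs = append(jobs, jobRange{start: start, end: sorted[j-1]})
			i = j
		}
		return jobs
	}

For example, splitJobs([]int64{1, 1, 1, 2, 3}, 2) yields the ranges [1, 1] and [2, 3]: the first job absorbs all three 1s rather than handing one of them to the next job.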
- rss, err := se.Execute(ctx, selectSQL) - se.GetSessionVars().SelectLimit = originalSelectLimit - - if err != nil { - return nil, err - } - if len(rss) != 1 { - return nil, errors.Errorf("Non-transactional DML, expecting 1 record set, but got %d", len(rss)) - } - rs := rss[0] - defer func() { - _ = rs.Close() - }() - - batchSize := int(stmt.Limit) - if batchSize <= 0 { - return nil, errors.New("Non-transactional DML, batch size should be positive") - } - jobCount := 0 - jobs := make([]job, 0) - currentSize := 0 - var currentStart, currentEnd types.Datum - - chk := rs.NewChunk(nil) - for { - err = rs.Next(ctx, chk) - if err != nil { - return nil, err - } - - // last chunk - if chk.NumRows() == 0 { - if currentSize > 0 { - // there's remaining work - jobs = appendNewJob(jobs, jobCount+1, currentStart, currentEnd, currentSize, memTracker) - } - break - } - - if len(jobs) > 0 && chk.NumRows()+currentSize < batchSize { - // not enough data for a batch - currentSize += chk.NumRows() - newEnd := chk.GetRow(chk.NumRows()-1).GetDatum(0, &rs.Fields()[0].Column.FieldType) - currentEnd = *newEnd.Clone() - continue - } - - iter := chunk.NewIterator4Chunk(chk) - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - if currentSize == 0 { - newStart := row.GetDatum(0, &rs.Fields()[0].Column.FieldType) - currentStart = *newStart.Clone() - } - newEnd := row.GetDatum(0, &rs.Fields()[0].Column.FieldType) - if currentSize >= batchSize { - cmp, err := newEnd.Compare(se.GetSessionVars().StmtCtx.TypeCtx(), ¤tEnd, collate.GetCollator(shardColumnCollate)) - if err != nil { - return nil, err - } - if cmp != 0 { - jobCount++ - jobs = appendNewJob(jobs, jobCount, *currentStart.Clone(), *currentEnd.Clone(), currentSize, memTracker) - currentSize = 0 - currentStart = newEnd - } - } - currentEnd = newEnd - currentSize++ - } - currentEnd = *currentEnd.Clone() - currentStart = *currentStart.Clone() - } - - return jobs, nil -} - -func appendNewJob(jobs []job, id int, start types.Datum, end types.Datum, size int, tracker *memory.Tracker) []job { - jobs = append(jobs, job{jobID: id, start: start, end: end, jobSize: size}) - tracker.Consume(start.EstimatedMemUsage() + end.EstimatedMemUsage() + 64) - return jobs -} - -func buildSelectSQL(stmt *ast.NonTransactionalDMLStmt, se sessiontypes.Session) ( - *ast.TableName, string, *model.ColumnInfo, []*ast.TableSource, error) { - // only use the first table - join, ok := stmt.DMLStmt.TableRefsJoin() - if !ok { - return nil, "", nil, nil, errors.New("Non-transactional DML, table source not found") - } - tableSources := make([]*ast.TableSource, 0) - tableSources, err := collectTableSourcesInJoin(join, tableSources) - if err != nil { - return nil, "", nil, nil, err - } - if len(tableSources) == 0 { - return nil, "", nil, nil, errors.New("Non-transactional DML, no tables found in table refs") - } - leftMostTableSource := tableSources[0] - leftMostTableName, ok := leftMostTableSource.Source.(*ast.TableName) - if !ok { - return nil, "", nil, nil, errors.New("Non-transactional DML, table name not found") - } - - shardColumnInfo, tableName, err := selectShardColumn(stmt, se, tableSources, leftMostTableName, leftMostTableSource) - if err != nil { - return nil, "", nil, nil, err - } - - var sb strings.Builder - if stmt.DMLStmt.WhereExpr() != nil { - err := stmt.DMLStmt.WhereExpr().Restore(format.NewRestoreCtx(format.DefaultRestoreFlags| - format.RestoreNameBackQuotes| - format.RestoreSpacesAroundBinaryOperation| - format.RestoreBracketAroundBinaryOperation| - 
format.RestoreStringWithoutCharset, &sb), - ) - if err != nil { - return nil, "", nil, nil, errors.Annotate(err, "Failed to restore where clause in non-transactional DML") - } - } else { - sb.WriteString("TRUE") - } - // assure NULL values are placed first - selectSQL := fmt.Sprintf("SELECT `%s` FROM `%s`.`%s` WHERE %s ORDER BY IF(ISNULL(`%s`),0,1),`%s`", - stmt.ShardColumn.Name.O, tableName.DBInfo.Name.O, tableName.Name.O, sb.String(), stmt.ShardColumn.Name.O, stmt.ShardColumn.Name.O) - return tableName, selectSQL, shardColumnInfo, tableSources, nil -} - -func selectShardColumn(stmt *ast.NonTransactionalDMLStmt, se sessiontypes.Session, tableSources []*ast.TableSource, - leftMostTableName *ast.TableName, leftMostTableSource *ast.TableSource) ( - *model.ColumnInfo, *ast.TableName, error) { - var indexed bool - var shardColumnInfo *model.ColumnInfo - var selectedTableName *ast.TableName - - if len(tableSources) == 1 { - // single table - leftMostTable, err := domain.GetDomain(se).InfoSchema().TableByName(context.Background(), leftMostTableName.Schema, leftMostTableName.Name) - if err != nil { - return nil, nil, err - } - selectedTableName = leftMostTableName - indexed, shardColumnInfo, err = selectShardColumnFromTheOnlyTable( - stmt, leftMostTableName, leftMostTableSource.AsName, leftMostTable) - if err != nil { - return nil, nil, err - } - } else { - // multi table join - if stmt.ShardColumn == nil { - leftMostTable, err := domain.GetDomain(se).InfoSchema().TableByName(context.Background(), leftMostTableName.Schema, leftMostTableName.Name) - if err != nil { - return nil, nil, err - } - selectedTableName = leftMostTableName - indexed, shardColumnInfo, err = selectShardColumnAutomatically(stmt, leftMostTable, leftMostTableName, leftMostTableSource.AsName) - if err != nil { - return nil, nil, err - } - } else if stmt.ShardColumn.Schema.L != "" && stmt.ShardColumn.Table.L != "" && stmt.ShardColumn.Name.L != "" { - specifiedDbName := stmt.ShardColumn.Schema - specifiedTableName := stmt.ShardColumn.Table - specifiedColName := stmt.ShardColumn.Name - - // the specified table must be in the join - tableInJoin := false - var chosenTableName model.CIStr - for _, tableSource := range tableSources { - tableSourceName := tableSource.Source.(*ast.TableName) - tableSourceFinalTableName := tableSource.AsName // precedence: alias name, then table name - if tableSourceFinalTableName.O == "" { - tableSourceFinalTableName = tableSourceName.Name - } - if tableSourceName.Schema.L == specifiedDbName.L && tableSourceFinalTableName.L == specifiedTableName.L { - tableInJoin = true - selectedTableName = tableSourceName - chosenTableName = tableSourceName.Name - break - } - } - if !tableInJoin { - return nil, nil, - errors.Errorf( - "Non-transactional DML, shard column %s.%s.%s is not in the tables involved in the join", - specifiedDbName.L, specifiedTableName.L, specifiedColName.L, - ) - } - - tbl, err := domain.GetDomain(se).InfoSchema().TableByName(context.Background(), specifiedDbName, chosenTableName) - if err != nil { - return nil, nil, err - } - indexed, shardColumnInfo, err = selectShardColumnByGivenName(specifiedColName.L, tbl) - if err != nil { - return nil, nil, err - } - } else { - return nil, nil, errors.New( - "Non-transactional DML, shard column must be fully specified (i.e. 
`BATCH ON dbname.tablename.colname`) when multiple tables are involved", - ) - } - } - if !indexed { - return nil, nil, errors.Errorf("Non-transactional DML, shard column %s is not indexed", stmt.ShardColumn.Name.L) - } - return shardColumnInfo, selectedTableName, nil -} - -func collectTableSourcesInJoin(node ast.ResultSetNode, tableSources []*ast.TableSource) ([]*ast.TableSource, error) { - if node == nil { - return tableSources, nil - } - switch x := node.(type) { - case *ast.Join: - var err error - tableSources, err = collectTableSourcesInJoin(x.Left, tableSources) - if err != nil { - return nil, err - } - tableSources, err = collectTableSourcesInJoin(x.Right, tableSources) - if err != nil { - return nil, err - } - case *ast.TableSource: - // assert it's a table name - if _, ok := x.Source.(*ast.TableName); !ok { - return nil, errors.New("Non-transactional DML, table name not found in join") - } - tableSources = append(tableSources, x) - default: - return nil, errors.Errorf("Non-transactional DML, unknown type %T in table refs", node) - } - return tableSources, nil -} - -// it attempts to auto-select a shard column from handle if not specified, and fills back the corresponding info in the stmt, -// making it transparent to following steps -func selectShardColumnFromTheOnlyTable(stmt *ast.NonTransactionalDMLStmt, tableName *ast.TableName, - tableAsName model.CIStr, tbl table.Table) ( - indexed bool, shardColumnInfo *model.ColumnInfo, err error) { - if stmt.ShardColumn == nil { - return selectShardColumnAutomatically(stmt, tbl, tableName, tableAsName) - } - - return selectShardColumnByGivenName(stmt.ShardColumn.Name.L, tbl) -} - -func selectShardColumnByGivenName(shardColumnName string, tbl table.Table) ( - indexed bool, shardColumnInfo *model.ColumnInfo, err error) { - tableInfo := tbl.Meta() - if shardColumnName == model.ExtraHandleName.L && !tableInfo.HasClusteredIndex() { - return true, nil, nil - } - - for _, col := range tbl.Cols() { - if col.Name.L == shardColumnName { - shardColumnInfo = col.ColumnInfo - break - } - } - if shardColumnInfo == nil { - return false, nil, errors.Errorf("shard column %s not found", shardColumnName) - } - // is int handle - if mysql.HasPriKeyFlag(shardColumnInfo.GetFlag()) && tableInfo.PKIsHandle { - return true, shardColumnInfo, nil - } - - for _, index := range tbl.Indices() { - if index.Meta().State != model.StatePublic || index.Meta().Invisible { - continue - } - indexColumns := index.Meta().Columns - // check only the first column - if len(indexColumns) > 0 && indexColumns[0].Name.L == shardColumnName { - indexed = true - break - } - } - return indexed, shardColumnInfo, nil -} - -func selectShardColumnAutomatically(stmt *ast.NonTransactionalDMLStmt, tbl table.Table, - tableName *ast.TableName, tableAsName model.CIStr) (bool, *model.ColumnInfo, error) { - // auto-detect shard column - var shardColumnInfo *model.ColumnInfo - tableInfo := tbl.Meta() - if tbl.Meta().PKIsHandle { - shardColumnInfo = tableInfo.GetPkColInfo() - } else if tableInfo.IsCommonHandle { - for _, index := range tableInfo.Indices { - if index.Primary { - if len(index.Columns) == 1 { - shardColumnInfo = tableInfo.Columns[index.Columns[0].Offset] - break - } - // if the clustered index contains multiple columns, we cannot automatically choose a column as the shard column - return false, nil, errors.New("Non-transactional DML, the clustered index contains multiple columns. 
Please specify a shard column") - } - } - if shardColumnInfo == nil { - return false, nil, errors.New("Non-transactional DML, the clustered index is not found") - } - } - - shardColumnName := model.ExtraHandleName.L - if shardColumnInfo != nil { - shardColumnName = shardColumnInfo.Name.L - } - - outputTableName := tableName.Name - if tableAsName.L != "" { - outputTableName = tableAsName - } - stmt.ShardColumn = &ast.ColumnName{ - Schema: tableName.Schema, - Table: outputTableName, // so that table alias works - Name: model.NewCIStr(shardColumnName), - } - return true, shardColumnInfo, nil -} - -func buildDryRunResults(dryRunOption int, results []string, maxChunkSize int) (sqlexec.RecordSet, error) { - var fieldName string - if dryRunOption == ast.DryRunSplitDml { - fieldName = "split statement examples" - } else { - fieldName = "query statement" - } - - resultFields := []*ast.ResultField{{ - Column: &model.ColumnInfo{ - FieldType: *types.NewFieldType(mysql.TypeString), - }, - ColumnAsName: model.NewCIStr(fieldName), - }} - rows := make([][]any, 0, len(results)) - for _, result := range results { - row := make([]any, 1) - row[0] = result - rows = append(rows, row) - } - return &sqlexec.SimpleRecordSet{ - ResultFields: resultFields, - Rows: rows, - MaxChunkSize: maxChunkSize, - }, nil -} - -func buildExecuteResults(ctx context.Context, jobs []job, maxChunkSize int, redactLog string) (sqlexec.RecordSet, error) { - failedJobs := make([]job, 0) - for _, job := range jobs { - if job.err != nil { - failedJobs = append(failedJobs, job) - } - } - if len(failedJobs) == 0 { - resultFields := []*ast.ResultField{ - { - Column: &model.ColumnInfo{ - FieldType: *types.NewFieldType(mysql.TypeLong), - }, - ColumnAsName: model.NewCIStr("number of jobs"), - }, - { - Column: &model.ColumnInfo{ - FieldType: *types.NewFieldType(mysql.TypeString), - }, - ColumnAsName: model.NewCIStr("job status"), - }, - } - rows := make([][]any, 1) - row := make([]any, 2) - row[0] = len(jobs) - row[1] = "all succeeded" - rows[0] = row - return &sqlexec.SimpleRecordSet{ - ResultFields: resultFields, - Rows: rows, - MaxChunkSize: maxChunkSize, - }, nil - } - - // ignoreError must be set. - var sb strings.Builder - for _, job := range failedJobs { - sb.WriteString(fmt.Sprintf("%s, %s;\n", job.String(redactLog), job.err.Error())) - } - - errStr := sb.String() - // log errors here in case the output is too long. There can be thousands of errors. - logutil.Logger(ctx).Error("Non-transactional DML failed", - zap.Int("num_failed_jobs", len(failedJobs)), zap.String("failed_jobs", errStr)) - - return nil, fmt.Errorf("%d/%d jobs failed in the non-transactional DML: %s, ...(more in logs)", - len(failedJobs), len(jobs), errStr[:min(500, len(errStr)-1)]) -} diff --git a/pkg/session/session.go b/pkg/session/session.go index 4ca9bdc5227ba..e46006fee9389 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -519,13 +519,13 @@ func (s *session) doCommit(ctx context.Context) error { return err } // mockCommitError and mockGetTSErrorInRetry use to test PR #8743. 
- if val, _err_ := failpoint.Eval(_curpkg_("mockCommitError")); _err_ == nil { + failpoint.Inject("mockCommitError", func(val failpoint.Value) { if val.(bool) { if _, err := failpoint.Eval("tikvclient/mockCommitErrorOpt"); err == nil { - return kv.ErrTxnRetryable + failpoint.Return(kv.ErrTxnRetryable) } } - } + }) sessVars := s.GetSessionVars() @@ -661,10 +661,10 @@ func (s *session) commitTxnWithTemporaryData(ctx context.Context, txn kv.Transac sessVars := s.sessionVars txnTempTables := sessVars.TxnCtx.TemporaryTables if len(txnTempTables) == 0 { - if v, _err_ := failpoint.Eval(_curpkg_("mockSleepBeforeTxnCommit")); _err_ == nil { + failpoint.Inject("mockSleepBeforeTxnCommit", func(v failpoint.Value) { ms := v.(int) time.Sleep(time.Millisecond * time.Duration(ms)) - } + }) return txn.Commit(ctx) } @@ -925,11 +925,11 @@ func (s *session) CommitTxn(ctx context.Context) error { // record the TTLInsertRows in the metric metrics.TTLInsertRowsCount.Add(float64(s.sessionVars.TxnCtx.InsertTTLRowsCount)) - if val, _err_ := failpoint.Eval(_curpkg_("keepHistory")); _err_ == nil { + failpoint.Inject("keepHistory", func(val failpoint.Value) { if val.(bool) { - return err + failpoint.Return(err) } - } + }) s.sessionVars.TxnCtx.Cleanup() s.sessionVars.CleanupTxnReadTSIfUsed() return err @@ -1117,12 +1117,12 @@ func (s *session) retry(ctx context.Context, maxCnt uint) (err error) { logutil.Logger(ctx).Warn("transaction association", zap.Uint64("retrying txnStartTS", s.GetSessionVars().TxnCtx.StartTS), zap.Uint64("original txnStartTS", orgStartTS)) - if _, _err_ := failpoint.Eval(_curpkg_("preCommitHook")); _err_ == nil { + failpoint.Inject("preCommitHook", func() { hook, ok := ctx.Value("__preCommitHook").(func()) if ok { hook() } - } + }) if err == nil { err = s.doCommit(ctx) if err == nil { @@ -2068,12 +2068,12 @@ func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlex return nil, err } - if val, _err_ := failpoint.Eval(_curpkg_("mockStmtSlow")); _err_ == nil { + failpoint.Inject("mockStmtSlow", func(val failpoint.Value) { if strings.Contains(stmtNode.Text(), "/* sleep */") { v, _ := val.(int) time.Sleep(time.Duration(v) * time.Millisecond) } - } + }) var stmtLabel string if execStmt, ok := stmtNode.(*ast.ExecuteStmt); ok { @@ -2243,12 +2243,12 @@ func (s *session) hasFileTransInConn() bool { // runStmt executes the sqlexec.Statement and commit or rollback the current transaction. 
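In the session.go hunks above, early exits inside the restored markers use failpoint.Return instead of a bare return: a plain return would only leave the injected closure, while the marker is rewritten by failpoint-ctl into a return from the enclosing function, which is the expanded form the deleted __failpoint_stash__ copies show. A tiny illustration with a hypothetical failpoint name:

	package demo

	import (
		"errors"

		"github.com/pingcap/failpoint"
	)

	var errInjected = errors.New("injected commit error")

	// commit normally returns nil. Once failpoint-ctl has rewritten the marker
	// and the "demoCommitError" failpoint is enabled, failpoint.Return becomes
	// a real `return errInjected` from commit.
	func commit() error {
		failpoint.Inject("demoCommitError", func(val failpoint.Value) {
			if on, ok := val.(bool); ok && on {
				failpoint.Return(errInjected)
			}
		})
		return nil
	}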
func runStmt(ctx context.Context, se *session, s sqlexec.Statement) (rs sqlexec.RecordSet, err error) { - if _, _err_ := failpoint.Eval(_curpkg_("assertTxnManagerInRunStmt")); _err_ == nil { + failpoint.Inject("assertTxnManagerInRunStmt", func() { sessiontxn.RecordAssert(se, "assertTxnManagerInRunStmt", true) if stmt, ok := s.(*executor.ExecStmt); ok { sessiontxn.AssertTxnManagerInfoSchema(se, stmt.InfoSchema) } - } + }) r, ctx := tracing.StartRegionEx(ctx, "session.runStmt") defer r.End() @@ -3946,15 +3946,15 @@ func (s *session) PrepareTSFuture(ctx context.Context, future oracle.Future, sco return errors.New("cannot prepare ts future when txn is valid") } - if _, _err_ := failpoint.Eval(_curpkg_("assertTSONotRequest")); _err_ == nil { + failpoint.Inject("assertTSONotRequest", func() { if _, ok := future.(sessiontxn.ConstantFuture); !ok && !s.isInternal() { panic("tso shouldn't be requested") } - } + }) - if _, _err_ := failpoint.EvalContext(ctx, _curpkg_("mockGetTSFail")); _err_ == nil { + failpoint.InjectContext(ctx, "mockGetTSFail", func() { future = txnFailFuture{} - } + }) s.txn.changeToPending(&txnFuture{ future: future, diff --git a/pkg/session/session.go__failpoint_stash__ b/pkg/session/session.go__failpoint_stash__ deleted file mode 100644 index e46006fee9389..0000000000000 --- a/pkg/session/session.go__failpoint_stash__ +++ /dev/null @@ -1,4611 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2013 The ql Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSES/QL-LICENSE file. 
- -package session - -import ( - "bytes" - "context" - "crypto/tls" - "encoding/hex" - "encoding/json" - stderrs "errors" - "fmt" - "math" - "math/rand" - "runtime/pprof" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/ngaut/pools" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/tidb/pkg/bindinfo" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/ddl" - "github.com/pingcap/tidb/pkg/ddl/placement" - distsqlctx "github.com/pingcap/tidb/pkg/distsql/context" - "github.com/pingcap/tidb/pkg/disttask/framework/proto" - "github.com/pingcap/tidb/pkg/disttask/framework/scheduler" - "github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor" - "github.com/pingcap/tidb/pkg/disttask/importinto" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/domain/infosync" - "github.com/pingcap/tidb/pkg/errno" - "github.com/pingcap/tidb/pkg/executor" - "github.com/pingcap/tidb/pkg/executor/staticrecordset" - "github.com/pingcap/tidb/pkg/expression" - exprctx "github.com/pingcap/tidb/pkg/expression/context" - "github.com/pingcap/tidb/pkg/expression/contextsession" - "github.com/pingcap/tidb/pkg/extension" - "github.com/pingcap/tidb/pkg/extension/extensionimpl" - "github.com/pingcap/tidb/pkg/infoschema" - infoschemactx "github.com/pingcap/tidb/pkg/infoschema/context" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/owner" - "github.com/pingcap/tidb/pkg/param" - "github.com/pingcap/tidb/pkg/parser" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/auth" - "github.com/pingcap/tidb/pkg/parser/charset" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/planner" - planctx "github.com/pingcap/tidb/pkg/planner/context" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/plugin" - "github.com/pingcap/tidb/pkg/privilege" - "github.com/pingcap/tidb/pkg/privilege/conn" - "github.com/pingcap/tidb/pkg/privilege/privileges" - "github.com/pingcap/tidb/pkg/session/cursor" - session_metrics "github.com/pingcap/tidb/pkg/session/metrics" - "github.com/pingcap/tidb/pkg/session/txninfo" - "github.com/pingcap/tidb/pkg/session/types" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/binloginfo" - "github.com/pingcap/tidb/pkg/sessionctx/sessionstates" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/sessiontxn" - "github.com/pingcap/tidb/pkg/statistics/handle/syncload" - "github.com/pingcap/tidb/pkg/statistics/handle/usage" - "github.com/pingcap/tidb/pkg/statistics/handle/usage/indexusage" - storeerr "github.com/pingcap/tidb/pkg/store/driver/error" - "github.com/pingcap/tidb/pkg/store/helper" - "github.com/pingcap/tidb/pkg/table" - tbctx "github.com/pingcap/tidb/pkg/table/context" - tbctximpl "github.com/pingcap/tidb/pkg/table/contextimpl" - "github.com/pingcap/tidb/pkg/table/temptable" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/ttl/ttlworker" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/collate" - "github.com/pingcap/tidb/pkg/util/dbterror" - 
"github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/logutil/consistency" - "github.com/pingcap/tidb/pkg/util/memory" - rangerctx "github.com/pingcap/tidb/pkg/util/ranger/context" - "github.com/pingcap/tidb/pkg/util/redact" - "github.com/pingcap/tidb/pkg/util/sem" - "github.com/pingcap/tidb/pkg/util/sli" - "github.com/pingcap/tidb/pkg/util/sqlescape" - "github.com/pingcap/tidb/pkg/util/sqlexec" - "github.com/pingcap/tidb/pkg/util/timeutil" - "github.com/pingcap/tidb/pkg/util/topsql" - topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" - "github.com/pingcap/tidb/pkg/util/topsql/stmtstats" - "github.com/pingcap/tidb/pkg/util/tracing" - tikverr "github.com/tikv/client-go/v2/error" - "github.com/tikv/client-go/v2/oracle" - tikvutil "github.com/tikv/client-go/v2/util" - "go.uber.org/zap" -) - -func init() { - executor.CreateSession = func(ctx sessionctx.Context) (sessionctx.Context, error) { - return CreateSession(ctx.GetStore()) - } - executor.CloseSession = func(ctx sessionctx.Context) { - if se, ok := ctx.(types.Session); ok { - se.Close() - } - } -} - -var _ types.Session = (*session)(nil) - -type stmtRecord struct { - st sqlexec.Statement - stmtCtx *stmtctx.StatementContext -} - -// StmtHistory holds all histories of statements in a txn. -type StmtHistory struct { - history []*stmtRecord -} - -// Add appends a stmt to history list. -func (h *StmtHistory) Add(st sqlexec.Statement, stmtCtx *stmtctx.StatementContext) { - s := &stmtRecord{ - st: st, - stmtCtx: stmtCtx, - } - h.history = append(h.history, s) -} - -// Count returns the count of the history. -func (h *StmtHistory) Count() int { - return len(h.history) -} - -type session struct { - // processInfo is used by ShowProcess(), and should be modified atomically. - processInfo atomic.Pointer[util.ProcessInfo] - txn LazyTxn - - mu struct { - sync.RWMutex - values map[fmt.Stringer]any - } - - currentCtx context.Context // only use for runtime.trace, Please NEVER use it. - currentPlan base.Plan - - store kv.Storage - - sessionPlanCache sessionctx.SessionPlanCache - - sessionVars *variable.SessionVars - sessionManager util.SessionManager - - pctx *planContextImpl - exprctx *contextsession.SessionExprContext - tblctx *tbctximpl.TableContextImpl - - statsCollector *usage.SessionStatsItem - // ddlOwnerManager is used in `select tidb_is_ddl_owner()` statement; - ddlOwnerManager owner.Manager - // lockedTables use to record the table locks hold by the session. - lockedTables map[int64]model.TableLockTpInfo - - // client shared coprocessor client per session - client kv.Client - - mppClient kv.MPPClient - - // indexUsageCollector collects index usage information. - idxUsageCollector *indexusage.SessionIndexUsageCollector - - // StmtStats is used to count various indicators of each SQL in this session - // at each point in time. These data will be periodically taken away by the - // background goroutine. The background goroutine will continue to aggregate - // all the local data in each session, and finally report them to the remote - // regularly. - stmtStats *stmtstats.StatementStats - - // Used to encode and decode each type of session states. - sessionStatesHandlers map[sessionstates.SessionStateType]sessionctx.SessionStatesHandler - - // Contains a list of sessions used to collect advisory locks. 
- advisoryLocks map[string]*advisoryLock - - extensions *extension.SessionExtensions - - sandBoxMode bool - - cursorTracker cursor.Tracker -} - -var parserPool = &sync.Pool{New: func() any { return parser.New() }} - -// AddTableLock adds table lock to the session lock map. -func (s *session) AddTableLock(locks []model.TableLockTpInfo) { - for _, l := range locks { - // read only lock is session unrelated, skip it when adding lock to session. - if l.Tp != model.TableLockReadOnly { - s.lockedTables[l.TableID] = l - } - } -} - -// ReleaseTableLocks releases table lock in the session lock map. -func (s *session) ReleaseTableLocks(locks []model.TableLockTpInfo) { - for _, l := range locks { - delete(s.lockedTables, l.TableID) - } -} - -// ReleaseTableLockByTableIDs releases table lock in the session lock map by table ID. -func (s *session) ReleaseTableLockByTableIDs(tableIDs []int64) { - for _, tblID := range tableIDs { - delete(s.lockedTables, tblID) - } -} - -// CheckTableLocked checks the table lock. -func (s *session) CheckTableLocked(tblID int64) (bool, model.TableLockType) { - lt, ok := s.lockedTables[tblID] - if !ok { - return false, model.TableLockNone - } - return true, lt.Tp -} - -// GetAllTableLocks gets all table locks table id and db id hold by the session. -func (s *session) GetAllTableLocks() []model.TableLockTpInfo { - lockTpInfo := make([]model.TableLockTpInfo, 0, len(s.lockedTables)) - for _, tl := range s.lockedTables { - lockTpInfo = append(lockTpInfo, tl) - } - return lockTpInfo -} - -// HasLockedTables uses to check whether this session locked any tables. -// If so, the session can only visit the table which locked by self. -func (s *session) HasLockedTables() bool { - b := len(s.lockedTables) > 0 - return b -} - -// ReleaseAllTableLocks releases all table locks hold by the session. -func (s *session) ReleaseAllTableLocks() { - s.lockedTables = make(map[int64]model.TableLockTpInfo) -} - -// IsDDLOwner checks whether this session is DDL owner. 
-func (s *session) IsDDLOwner() bool { - return s.ddlOwnerManager.IsOwner() -} - -func (s *session) cleanRetryInfo() { - if s.sessionVars.RetryInfo.Retrying { - return - } - - retryInfo := s.sessionVars.RetryInfo - defer retryInfo.Clean() - if len(retryInfo.DroppedPreparedStmtIDs) == 0 { - return - } - - planCacheEnabled := s.GetSessionVars().EnablePreparedPlanCache - var cacheKey string - var err error - var preparedObj *plannercore.PlanCacheStmt - if planCacheEnabled { - firstStmtID := retryInfo.DroppedPreparedStmtIDs[0] - if preparedPointer, ok := s.sessionVars.PreparedStmts[firstStmtID]; ok { - preparedObj, ok = preparedPointer.(*plannercore.PlanCacheStmt) - if ok { - cacheKey, _, _, _, err = plannercore.NewPlanCacheKey(s, preparedObj) - if err != nil { - logutil.Logger(s.currentCtx).Warn("clean cached plan failed", zap.Error(err)) - return - } - } - } - } - for i, stmtID := range retryInfo.DroppedPreparedStmtIDs { - if planCacheEnabled { - if i > 0 && preparedObj != nil { - cacheKey, _, _, _, err = plannercore.NewPlanCacheKey(s, preparedObj) - if err != nil { - logutil.Logger(s.currentCtx).Warn("clean cached plan failed", zap.Error(err)) - return - } - } - if !s.sessionVars.IgnorePreparedCacheCloseStmt { // keep the plan in cache - s.GetSessionPlanCache().Delete(cacheKey) - } - } - s.sessionVars.RemovePreparedStmt(stmtID) - } -} - -func (s *session) Status() uint16 { - return s.sessionVars.Status() -} - -func (s *session) LastInsertID() uint64 { - if s.sessionVars.StmtCtx.LastInsertID > 0 { - return s.sessionVars.StmtCtx.LastInsertID - } - return s.sessionVars.StmtCtx.InsertID -} - -func (s *session) LastMessage() string { - return s.sessionVars.StmtCtx.GetMessage() -} - -func (s *session) AffectedRows() uint64 { - return s.sessionVars.StmtCtx.AffectedRows() -} - -func (s *session) SetClientCapability(capability uint32) { - s.sessionVars.ClientCapability = capability -} - -func (s *session) SetConnectionID(connectionID uint64) { - s.sessionVars.ConnectionID = connectionID -} - -func (s *session) SetTLSState(tlsState *tls.ConnectionState) { - // If user is not connected via TLS, then tlsState == nil. - if tlsState != nil { - s.sessionVars.TLSConnectionState = tlsState - } -} - -func (s *session) SetCompressionAlgorithm(ca int) { - s.sessionVars.CompressionAlgorithm = ca -} - -func (s *session) SetCompressionLevel(level int) { - s.sessionVars.CompressionLevel = level -} - -func (s *session) SetCommandValue(command byte) { - atomic.StoreUint32(&s.sessionVars.CommandValue, uint32(command)) -} - -func (s *session) SetCollation(coID int) error { - cs, co, err := charset.GetCharsetInfoByID(coID) - if err != nil { - return err - } - // If new collations are enabled, switch to the default - // collation if this one is not supported. 
- co = collate.SubstituteMissingCollationToDefault(co) - for _, v := range variable.SetNamesVariables { - terror.Log(s.sessionVars.SetSystemVarWithoutValidation(v, cs)) - } - return s.sessionVars.SetSystemVarWithoutValidation(variable.CollationConnection, co) -} - -func (s *session) GetSessionPlanCache() sessionctx.SessionPlanCache { - // use the prepared plan cache - if !s.GetSessionVars().EnablePreparedPlanCache && !s.GetSessionVars().EnableNonPreparedPlanCache { - return nil - } - if s.sessionPlanCache == nil { // lazy construction - s.sessionPlanCache = plannercore.NewLRUPlanCache(uint(s.GetSessionVars().SessionPlanCacheSize), - variable.PreparedPlanCacheMemoryGuardRatio.Load(), plannercore.PreparedPlanCacheMaxMemory.Load(), s, false) - } - return s.sessionPlanCache -} - -func (s *session) SetSessionManager(sm util.SessionManager) { - s.sessionManager = sm -} - -func (s *session) GetSessionManager() util.SessionManager { - return s.sessionManager -} - -func (s *session) UpdateColStatsUsage(predicateColumns []model.TableItemID) { - if s.statsCollector == nil { - return - } - t := time.Now() - colMap := make(map[model.TableItemID]time.Time, len(predicateColumns)) - for _, col := range predicateColumns { - // TODO: Remove this assertion once it has been confirmed to operate correctly over a period of time. - intest.Assert(!col.IsIndex, "predicate column should only be table column") - colMap[col] = t - } - s.statsCollector.UpdateColStatsUsage(colMap) -} - -// FieldList returns fields list of a table. -func (s *session) FieldList(tableName string) ([]*ast.ResultField, error) { - is := s.GetInfoSchema().(infoschema.InfoSchema) - dbName := model.NewCIStr(s.GetSessionVars().CurrentDB) - tName := model.NewCIStr(tableName) - pm := privilege.GetPrivilegeManager(s) - if pm != nil && s.sessionVars.User != nil { - if !pm.RequestVerification(s.sessionVars.ActiveRoles, dbName.O, tName.O, "", mysql.AllPrivMask) { - user := s.sessionVars.User - u := user.Username - h := user.Hostname - if len(user.AuthUsername) > 0 && len(user.AuthHostname) > 0 { - u = user.AuthUsername - h = user.AuthHostname - } - return nil, plannererrors.ErrTableaccessDenied.GenWithStackByArgs("SELECT", u, h, tableName) - } - } - table, err := is.TableByName(context.Background(), dbName, tName) - if err != nil { - return nil, err - } - - cols := table.Cols() - fields := make([]*ast.ResultField, 0, len(cols)) - for _, col := range table.Cols() { - rf := &ast.ResultField{ - ColumnAsName: col.Name, - TableAsName: tName, - DBName: dbName, - Table: table.Meta(), - Column: col.ColumnInfo, - } - fields = append(fields, rf) - } - return fields, nil -} - -// TxnInfo returns a pointer to a *copy* of the internal TxnInfo, thus is *read only* -func (s *session) TxnInfo() *txninfo.TxnInfo { - s.txn.mu.RLock() - // Copy on read to get a snapshot, this API shouldn't be frequently called. 
- txnInfo := s.txn.mu.TxnInfo - s.txn.mu.RUnlock() - - if txnInfo.StartTS == 0 { - return nil - } - - processInfo := s.ShowProcess() - if processInfo == nil { - return nil - } - txnInfo.ConnectionID = processInfo.ID - txnInfo.Username = processInfo.User - txnInfo.CurrentDB = processInfo.DB - txnInfo.RelatedTableIDs = make(map[int64]struct{}) - s.GetSessionVars().GetRelatedTableForMDL().Range(func(key, _ any) bool { - txnInfo.RelatedTableIDs[key.(int64)] = struct{}{} - return true - }) - - return &txnInfo -} - -func (s *session) doCommit(ctx context.Context) error { - if !s.txn.Valid() { - return nil - } - - defer func() { - s.txn.changeToInvalid() - s.sessionVars.SetInTxn(false) - s.sessionVars.ClearDiskFullOpt() - }() - // check if the transaction is read-only - if s.txn.IsReadOnly() { - return nil - } - // check if the cluster is read-only - if !s.sessionVars.InRestrictedSQL && (variable.RestrictedReadOnly.Load() || variable.VarTiDBSuperReadOnly.Load()) { - // It is not internal SQL, and the cluster has one of RestrictedReadOnly or SuperReadOnly - // We need to privilege check again: a privilege check occurred during planning, but we need - // to prevent the case that a long running auto-commit statement is now trying to commit. - pm := privilege.GetPrivilegeManager(s) - roles := s.sessionVars.ActiveRoles - if pm != nil && !pm.HasExplicitlyGrantedDynamicPrivilege(roles, "RESTRICTED_REPLICA_WRITER_ADMIN", false) { - s.RollbackTxn(ctx) - return plannererrors.ErrSQLInReadOnlyMode - } - } - err := s.checkPlacementPolicyBeforeCommit() - if err != nil { - return err - } - // mockCommitError and mockGetTSErrorInRetry use to test PR #8743. - failpoint.Inject("mockCommitError", func(val failpoint.Value) { - if val.(bool) { - if _, err := failpoint.Eval("tikvclient/mockCommitErrorOpt"); err == nil { - failpoint.Return(kv.ErrTxnRetryable) - } - } - }) - - sessVars := s.GetSessionVars() - - var commitTSChecker func(uint64) bool - if tables := sessVars.TxnCtx.CachedTables; len(tables) > 0 { - c := cachedTableRenewLease{tables: tables} - now := time.Now() - err := c.start(ctx) - defer c.stop(ctx) - sessVars.StmtCtx.WaitLockLeaseTime += time.Since(now) - if err != nil { - return errors.Trace(err) - } - commitTSChecker = c.commitTSCheck - } - if err = sessiontxn.GetTxnManager(s).SetOptionsBeforeCommit(s.txn.Transaction, commitTSChecker); err != nil { - return err - } - - err = s.commitTxnWithTemporaryData(tikvutil.SetSessionID(ctx, sessVars.ConnectionID), &s.txn) - if err != nil { - err = s.handleAssertionFailure(ctx, err) - } - return err -} - -type cachedTableRenewLease struct { - tables map[int64]any - lease []uint64 // Lease for each visited cached tables. - exit chan struct{} -} - -func (c *cachedTableRenewLease) start(ctx context.Context) error { - c.exit = make(chan struct{}) - c.lease = make([]uint64, len(c.tables)) - wg := make(chan error, len(c.tables)) - ith := 0 - for _, raw := range c.tables { - tbl := raw.(table.CachedTable) - go tbl.WriteLockAndKeepAlive(ctx, c.exit, &c.lease[ith], wg) - ith++ - } - - // Wait for all LockForWrite() return, this function can return. - var err error - for ; ith > 0; ith-- { - tmp := <-wg - if tmp != nil { - err = tmp - } - } - return err -} - -func (c *cachedTableRenewLease) stop(_ context.Context) { - close(c.exit) -} - -func (c *cachedTableRenewLease) commitTSCheck(commitTS uint64) bool { - for i := 0; i < len(c.lease); i++ { - lease := atomic.LoadUint64(&c.lease[i]) - if commitTS >= lease { - // Txn fails to commit because the write lease is expired. 
- return false - } - } - return true -} - -// handleAssertionFailure extracts the possible underlying assertionFailed error, -// gets the corresponding MVCC history and logs it. -// If it's not an assertion failure, returns the original error. -func (s *session) handleAssertionFailure(ctx context.Context, err error) error { - var assertionFailure *tikverr.ErrAssertionFailed - if !stderrs.As(err, &assertionFailure) { - return err - } - key := assertionFailure.Key - newErr := kv.ErrAssertionFailed.GenWithStackByArgs( - hex.EncodeToString(key), assertionFailure.Assertion.String(), assertionFailure.StartTs, - assertionFailure.ExistingStartTs, assertionFailure.ExistingCommitTs, - ) - - rmode := s.GetSessionVars().EnableRedactLog - if rmode == errors.RedactLogEnable { - return newErr - } - - var decodeFunc func(kv.Key, *kvrpcpb.MvccGetByKeyResponse, map[string]any) - // if it's a record key or an index key, decode it - if infoSchema, ok := s.sessionVars.TxnCtx.InfoSchema.(infoschema.InfoSchema); ok && - infoSchema != nil && (tablecodec.IsRecordKey(key) || tablecodec.IsIndexKey(key)) { - tableOrPartitionID := tablecodec.DecodeTableID(key) - tbl, ok := infoSchema.TableByID(tableOrPartitionID) - if !ok { - tbl, _, _ = infoSchema.FindTableByPartitionID(tableOrPartitionID) - } - if tbl == nil { - logutil.Logger(ctx).Warn("cannot find table by id", zap.Int64("tableID", tableOrPartitionID), zap.String("key", hex.EncodeToString(key))) - return newErr - } - - if tablecodec.IsRecordKey(key) { - decodeFunc = consistency.DecodeRowMvccData(tbl.Meta()) - } else { - tableInfo := tbl.Meta() - _, indexID, _, e := tablecodec.DecodeIndexKey(key) - if e != nil { - logutil.Logger(ctx).Error("assertion failed but cannot decode index key", zap.Error(e)) - return newErr - } - var indexInfo *model.IndexInfo - for _, idx := range tableInfo.Indices { - if idx.ID == indexID { - indexInfo = idx - break - } - } - if indexInfo == nil { - return newErr - } - decodeFunc = consistency.DecodeIndexMvccData(indexInfo) - } - } - if store, ok := s.store.(helper.Storage); ok { - content := consistency.GetMvccByKey(store, key, decodeFunc) - logutil.Logger(ctx).Error("assertion failed", zap.String("message", newErr.Error()), zap.String("mvcc history", redact.String(rmode, content))) - } - return newErr -} - -func (s *session) commitTxnWithTemporaryData(ctx context.Context, txn kv.Transaction) error { - sessVars := s.sessionVars - txnTempTables := sessVars.TxnCtx.TemporaryTables - if len(txnTempTables) == 0 { - failpoint.Inject("mockSleepBeforeTxnCommit", func(v failpoint.Value) { - ms := v.(int) - time.Sleep(time.Millisecond * time.Duration(ms)) - }) - return txn.Commit(ctx) - } - - sessionData := sessVars.TemporaryTableData - var ( - stage kv.StagingHandle - localTempTables *infoschema.SessionTables - ) - - if sessVars.LocalTemporaryTables != nil { - localTempTables = sessVars.LocalTemporaryTables.(*infoschema.SessionTables) - } else { - localTempTables = new(infoschema.SessionTables) - } - - defer func() { - // stage != kv.InvalidStagingHandle means error occurs, we need to cleanup sessionData - if stage != kv.InvalidStagingHandle { - sessionData.Cleanup(stage) - } - }() - - for tblID, tbl := range txnTempTables { - if !tbl.GetModified() { - continue - } - - if tbl.GetMeta().TempTableType != model.TempTableLocal { - continue - } - if _, ok := localTempTables.TableByID(tblID); !ok { - continue - } - - if stage == kv.InvalidStagingHandle { - stage = sessionData.Staging() - } - - tblPrefix := tablecodec.EncodeTablePrefix(tblID) - endKey 
:= tablecodec.EncodeTablePrefix(tblID + 1) - - txnMemBuffer := s.txn.GetMemBuffer() - iter, err := txnMemBuffer.Iter(tblPrefix, endKey) - if err != nil { - return err - } - - for iter.Valid() { - key := iter.Key() - if !bytes.HasPrefix(key, tblPrefix) { - break - } - - value := iter.Value() - if len(value) == 0 { - err = sessionData.DeleteTableKey(tblID, key) - } else { - err = sessionData.SetTableKey(tblID, key, iter.Value()) - } - - if err != nil { - return err - } - - err = iter.Next() - if err != nil { - return err - } - } - } - - err := txn.Commit(ctx) - if err != nil { - return err - } - - if stage != kv.InvalidStagingHandle { - sessionData.Release(stage) - stage = kv.InvalidStagingHandle - } - - return nil -} - -// errIsNoisy is used to filter DUPLICATE KEY errors. -// These can observed by users in INFORMATION_SCHEMA.CLIENT_ERRORS_SUMMARY_GLOBAL instead. -// -// The rationale for filtering these errors is because they are "client generated errors". i.e. -// of the errors defined in kv/error.go, these look to be clearly related to a client-inflicted issue, -// and the server is only responsible for handling the error correctly. It does not need to log. -func errIsNoisy(err error) bool { - if kv.ErrKeyExists.Equal(err) { - return true - } - if storeerr.ErrLockAcquireFailAndNoWaitSet.Equal(err) { - return true - } - return false -} - -func (s *session) doCommitWithRetry(ctx context.Context) error { - defer func() { - s.GetSessionVars().SetTxnIsolationLevelOneShotStateForNextTxn() - s.txn.changeToInvalid() - s.cleanRetryInfo() - sessiontxn.GetTxnManager(s).OnTxnEnd() - }() - if !s.txn.Valid() { - // If the transaction is invalid, maybe it has already been rolled back by the client. - return nil - } - isInternalTxn := false - if internal := s.txn.GetOption(kv.RequestSourceInternal); internal != nil && internal.(bool) { - isInternalTxn = true - } - var err error - txnSize := s.txn.Size() - isPessimistic := s.txn.IsPessimistic() - isPipelined := s.txn.IsPipelined() - r, ctx := tracing.StartRegionEx(ctx, "session.doCommitWithRetry") - defer r.End() - - err = s.doCommit(ctx) - if err != nil { - // polish the Write Conflict error message - newErr := s.tryReplaceWriteConflictError(err) - if newErr != nil { - err = newErr - } - - commitRetryLimit := s.sessionVars.RetryLimit - if !s.sessionVars.TxnCtx.CouldRetry { - commitRetryLimit = 0 - } - // Don't retry in BatchInsert mode. As a counter-example, insert into t1 select * from t2, - // BatchInsert already commit the first batch 1000 rows, then it commit 1000-2000 and retry the statement, - // Finally t1 will have more data than t2, with no errors return to user! - if s.isTxnRetryableError(err) && !s.sessionVars.BatchInsert && commitRetryLimit > 0 && !isPessimistic && !isPipelined { - logutil.Logger(ctx).Warn("sql", - zap.String("label", s.GetSQLLabel()), - zap.Error(err), - zap.String("txn", s.txn.GoString())) - // Transactions will retry 2 ~ commitRetryLimit times. - // We make larger transactions retry less times to prevent cluster resource outage. 
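// A minimal sketch (illustrative values only, not part of the patch) of the retry budget
// computed just below: the budget shrinks linearly with the transaction's share of the
// total size limit, so larger transactions get fewer retries.
func exampleRetryBudget(commitRetryLimit int64, txnSize, txnTotalSizeLimit uint64) int64 {
	txnSizeRate := float64(txnSize) / float64(txnTotalSizeLimit)
	// e.g. commitRetryLimit = 10 and a transaction at 30% of the limit:
	// 10 - int64(9*0.3) = 10 - 2 = 8 retries.
	return commitRetryLimit - int64(float64(commitRetryLimit-1)*txnSizeRate)
}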
- txnSizeRate := float64(txnSize) / float64(kv.TxnTotalSizeLimit.Load()) - maxRetryCount := commitRetryLimit - int64(float64(commitRetryLimit-1)*txnSizeRate) - err = s.retry(ctx, uint(maxRetryCount)) - } else if !errIsNoisy(err) { - logutil.Logger(ctx).Warn("can not retry txn", - zap.String("label", s.GetSQLLabel()), - zap.Error(err), - zap.Bool("IsBatchInsert", s.sessionVars.BatchInsert), - zap.Bool("IsPessimistic", isPessimistic), - zap.Bool("InRestrictedSQL", s.sessionVars.InRestrictedSQL), - zap.Int64("tidb_retry_limit", s.sessionVars.RetryLimit), - zap.Bool("tidb_disable_txn_auto_retry", s.sessionVars.DisableTxnAutoRetry)) - } - } - counter := s.sessionVars.TxnCtx.StatementCount - duration := time.Since(s.GetSessionVars().TxnCtx.CreateTime).Seconds() - s.recordOnTransactionExecution(err, counter, duration, isInternalTxn) - - if err != nil { - if !errIsNoisy(err) { - logutil.Logger(ctx).Warn("commit failed", - zap.String("finished txn", s.txn.GoString()), - zap.Error(err)) - } - return err - } - s.updateStatsDeltaToCollector() - return nil -} - -// adds more information about the table in the error message -// precondition: oldErr is a 9007:WriteConflict Error -func (s *session) tryReplaceWriteConflictError(oldErr error) (newErr error) { - if !kv.ErrWriteConflict.Equal(oldErr) { - return nil - } - if errors.RedactLogEnabled.Load() == errors.RedactLogEnable { - return nil - } - originErr := errors.Cause(oldErr) - inErr, _ := originErr.(*errors.Error) - // we don't want to modify the oldErr, so copy the args list - oldArgs := inErr.Args() - args := make([]any, len(oldArgs)) - copy(args, oldArgs) - is := sessiontxn.GetTxnManager(s).GetTxnInfoSchema() - if is == nil { - return nil - } - newKeyTableField, ok := addTableNameInTableIDField(args[3], is) - if ok { - args[3] = newKeyTableField - } - newPrimaryKeyTableField, ok := addTableNameInTableIDField(args[5], is) - if ok { - args[5] = newPrimaryKeyTableField - } - return kv.ErrWriteConflict.FastGenByArgs(args...) -} - -// precondition: is != nil -func addTableNameInTableIDField(tableIDField any, is infoschema.InfoSchema) (enhancedMsg string, done bool) { - keyTableID, ok := tableIDField.(string) - if !ok { - return "", false - } - stringsInTableIDField := strings.Split(keyTableID, "=") - if len(stringsInTableIDField) == 0 { - return "", false - } - tableIDStr := stringsInTableIDField[len(stringsInTableIDField)-1] - tableID, err := strconv.ParseInt(tableIDStr, 10, 64) - if err != nil { - return "", false - } - var tableName string - tbl, ok := is.TableByID(tableID) - if !ok { - tableName = "unknown" - } else { - dbInfo, ok := infoschema.SchemaByTable(is, tbl.Meta()) - if !ok { - tableName = "unknown." + tbl.Meta().Name.String() - } else { - tableName = dbInfo.Name.String() + "." 
+ tbl.Meta().Name.String() - } - } - enhancedMsg = keyTableID + ", tableName=" + tableName - return enhancedMsg, true -} - -func (s *session) updateStatsDeltaToCollector() { - mapper := s.GetSessionVars().TxnCtx.TableDeltaMap - if s.statsCollector != nil && mapper != nil { - for _, item := range mapper { - if item.TableID > 0 { - s.statsCollector.Update(item.TableID, item.Delta, item.Count, &item.ColSize) - } - } - } -} - -func (s *session) CommitTxn(ctx context.Context) error { - r, ctx := tracing.StartRegionEx(ctx, "session.CommitTxn") - defer r.End() - - var commitDetail *tikvutil.CommitDetails - ctx = context.WithValue(ctx, tikvutil.CommitDetailCtxKey, &commitDetail) - err := s.doCommitWithRetry(ctx) - if commitDetail != nil { - s.sessionVars.StmtCtx.MergeExecDetails(nil, commitDetail) - } - - // record the TTLInsertRows in the metric - metrics.TTLInsertRowsCount.Add(float64(s.sessionVars.TxnCtx.InsertTTLRowsCount)) - - failpoint.Inject("keepHistory", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(err) - } - }) - s.sessionVars.TxnCtx.Cleanup() - s.sessionVars.CleanupTxnReadTSIfUsed() - return err -} - -func (s *session) RollbackTxn(ctx context.Context) { - r, ctx := tracing.StartRegionEx(ctx, "session.RollbackTxn") - defer r.End() - - if s.txn.Valid() { - terror.Log(s.txn.Rollback()) - } - if ctx.Value(inCloseSession{}) == nil { - s.cleanRetryInfo() - } - s.txn.changeToInvalid() - s.sessionVars.TxnCtx.Cleanup() - s.sessionVars.CleanupTxnReadTSIfUsed() - s.sessionVars.SetInTxn(false) - sessiontxn.GetTxnManager(s).OnTxnEnd() -} - -func (s *session) GetClient() kv.Client { - return s.client -} - -func (s *session) GetMPPClient() kv.MPPClient { - return s.mppClient -} - -func (s *session) String() string { - // TODO: how to print binded context in values appropriately? - sessVars := s.sessionVars - data := map[string]any{ - "id": sessVars.ConnectionID, - "user": sessVars.User, - "currDBName": sessVars.CurrentDB, - "status": sessVars.Status(), - "strictMode": sessVars.SQLMode.HasStrictMode(), - } - if s.txn.Valid() { - // if txn is committed or rolled back, txn is nil. - data["txn"] = s.txn.String() - } - if sessVars.SnapshotTS != 0 { - data["snapshotTS"] = sessVars.SnapshotTS - } - if sessVars.StmtCtx.LastInsertID > 0 { - data["lastInsertID"] = sessVars.StmtCtx.LastInsertID - } - if len(sessVars.PreparedStmts) > 0 { - data["preparedStmtCount"] = len(sessVars.PreparedStmts) - } - b, err := json.MarshalIndent(data, "", " ") - terror.Log(errors.Trace(err)) - return string(b) -} - -const sqlLogMaxLen = 1024 - -// SchemaChangedWithoutRetry is used for testing. 
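// A minimal sketch of the parsing done by addTableNameInTableIDField above: the
// conflicting key is reported as a field like "tableID=68" (value hypothetical), so the
// numeric ID is taken from the piece after the last '=' and then resolved to "db.table".
func exampleParseTableIDField(field string) (int64, bool) {
	parts := strings.Split(field, "=")
	if len(parts) == 0 {
		return 0, false
	}
	id, err := strconv.ParseInt(parts[len(parts)-1], 10, 64)
	return id, err == nil
}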
-var SchemaChangedWithoutRetry uint32 - -func (s *session) GetSQLLabel() string { - if s.sessionVars.InRestrictedSQL { - return metrics.LblInternal - } - return metrics.LblGeneral -} - -func (s *session) isInternal() bool { - return s.sessionVars.InRestrictedSQL -} - -func (*session) isTxnRetryableError(err error) bool { - if atomic.LoadUint32(&SchemaChangedWithoutRetry) == 1 { - return kv.IsTxnRetryableError(err) - } - return kv.IsTxnRetryableError(err) || domain.ErrInfoSchemaChanged.Equal(err) -} - -func isEndTxnStmt(stmt ast.StmtNode, vars *variable.SessionVars) (bool, error) { - switch n := stmt.(type) { - case *ast.RollbackStmt, *ast.CommitStmt: - return true, nil - case *ast.ExecuteStmt: - ps, err := plannercore.GetPreparedStmt(n, vars) - if err != nil { - return false, err - } - return isEndTxnStmt(ps.PreparedAst.Stmt, vars) - } - return false, nil -} - -func (s *session) checkTxnAborted(stmt sqlexec.Statement) error { - if atomic.LoadUint32(&s.GetSessionVars().TxnCtx.LockExpire) == 0 { - return nil - } - // If the transaction is aborted, the following statements do not need to execute, except `commit` and `rollback`, - // because they are used to finish the aborted transaction. - if ok, err := isEndTxnStmt(stmt.(*executor.ExecStmt).StmtNode, s.sessionVars); err == nil && ok { - return nil - } else if err != nil { - return err - } - return kv.ErrLockExpire -} - -func (s *session) retry(ctx context.Context, maxCnt uint) (err error) { - var retryCnt uint - defer func() { - s.sessionVars.RetryInfo.Retrying = false - // retryCnt only increments on retryable error, so +1 here. - if s.sessionVars.InRestrictedSQL { - session_metrics.TransactionRetryInternal.Observe(float64(retryCnt + 1)) - } else { - session_metrics.TransactionRetryGeneral.Observe(float64(retryCnt + 1)) - } - s.sessionVars.SetInTxn(false) - if err != nil { - s.RollbackTxn(ctx) - } - s.txn.changeToInvalid() - }() - - connID := s.sessionVars.ConnectionID - s.sessionVars.RetryInfo.Retrying = true - if atomic.LoadUint32(&s.sessionVars.TxnCtx.ForUpdate) == 1 { - err = ErrForUpdateCantRetry.GenWithStackByArgs(connID) - return err - } - - nh := GetHistory(s) - var schemaVersion int64 - sessVars := s.GetSessionVars() - orgStartTS := sessVars.TxnCtx.StartTS - label := s.GetSQLLabel() - for { - if err = s.PrepareTxnCtx(ctx); err != nil { - return err - } - s.sessionVars.RetryInfo.ResetOffset() - for i, sr := range nh.history { - st := sr.st - s.sessionVars.StmtCtx = sr.stmtCtx - s.sessionVars.StmtCtx.CTEStorageMap = map[int]*executor.CTEStorages{} - s.sessionVars.StmtCtx.ResetForRetry() - s.sessionVars.PlanCacheParams.Reset() - schemaVersion, err = st.RebuildPlan(ctx) - if err != nil { - return err - } - - if retryCnt == 0 { - // We do not have to log the query every time. - // We print the queries at the first try only. 
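// A minimal sketch of the overall shape of retry() below: replay the recorded statement
// history, try to commit, and back off before the next attempt until the retry budget is
// used up. replay, commit, and backoff are placeholders for the real steps; the real loop
// also bails out early on non-retryable errors.
func exampleRetryLoop(maxCnt uint, replay, commit func() error, backoff func(attempt uint)) error {
	var err error
	for attempt := uint(0); attempt < maxCnt; attempt++ {
		if err = replay(); err == nil {
			if err = commit(); err == nil {
				return nil
			}
		}
		backoff(attempt)
	}
	return err
}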
- sql := sqlForLog(st.GetTextToLog(false)) - if sessVars.EnableRedactLog != errors.RedactLogEnable { - sql += redact.String(sessVars.EnableRedactLog, sessVars.PlanCacheParams.String()) - } - logutil.Logger(ctx).Warn("retrying", - zap.Int64("schemaVersion", schemaVersion), - zap.Uint("retryCnt", retryCnt), - zap.Int("queryNum", i), - zap.String("sql", sql)) - } else { - logutil.Logger(ctx).Warn("retrying", - zap.Int64("schemaVersion", schemaVersion), - zap.Uint("retryCnt", retryCnt), - zap.Int("queryNum", i)) - } - _, digest := s.sessionVars.StmtCtx.SQLDigest() - s.txn.onStmtStart(digest.String()) - if err = sessiontxn.GetTxnManager(s).OnStmtStart(ctx, st.GetStmtNode()); err == nil { - _, err = st.Exec(ctx) - } - s.txn.onStmtEnd() - if err != nil { - s.StmtRollback(ctx, false) - break - } - s.StmtCommit(ctx) - } - logutil.Logger(ctx).Warn("transaction association", - zap.Uint64("retrying txnStartTS", s.GetSessionVars().TxnCtx.StartTS), - zap.Uint64("original txnStartTS", orgStartTS)) - failpoint.Inject("preCommitHook", func() { - hook, ok := ctx.Value("__preCommitHook").(func()) - if ok { - hook() - } - }) - if err == nil { - err = s.doCommit(ctx) - if err == nil { - break - } - } - if !s.isTxnRetryableError(err) { - logutil.Logger(ctx).Warn("sql", - zap.String("label", label), - zap.Stringer("session", s), - zap.Error(err)) - metrics.SessionRetryErrorCounter.WithLabelValues(label, metrics.LblUnretryable).Inc() - return err - } - retryCnt++ - if retryCnt >= maxCnt { - logutil.Logger(ctx).Warn("sql", - zap.String("label", label), - zap.Uint("retry reached max count", retryCnt)) - metrics.SessionRetryErrorCounter.WithLabelValues(label, metrics.LblReachMax).Inc() - return err - } - logutil.Logger(ctx).Warn("sql", - zap.String("label", label), - zap.Error(err), - zap.String("txn", s.txn.GoString())) - kv.BackOff(retryCnt) - s.txn.changeToInvalid() - s.sessionVars.SetInTxn(false) - } - return err -} - -func sqlForLog(sql string) string { - if len(sql) > sqlLogMaxLen { - sql = sql[:sqlLogMaxLen] + fmt.Sprintf("(len:%d)", len(sql)) - } - return executor.QueryReplacer.Replace(sql) -} - -func (s *session) sysSessionPool() util.SessionPool { - return domain.GetDomain(s).SysSessionPool() -} - -func createSessionFunc(store kv.Storage) pools.Factory { - return func() (pools.Resource, error) { - se, err := createSession(store) - if err != nil { - return nil, err - } - err = se.sessionVars.SetSystemVar(variable.AutoCommit, "1") - if err != nil { - return nil, err - } - err = se.sessionVars.SetSystemVar(variable.MaxExecutionTime, "0") - if err != nil { - return nil, errors.Trace(err) - } - err = se.sessionVars.SetSystemVar(variable.MaxAllowedPacket, strconv.FormatUint(variable.DefMaxAllowedPacket, 10)) - if err != nil { - return nil, errors.Trace(err) - } - err = se.sessionVars.SetSystemVar(variable.TiDBEnableWindowFunction, variable.BoolToOnOff(variable.DefEnableWindowFunction)) - if err != nil { - return nil, errors.Trace(err) - } - err = se.sessionVars.SetSystemVar(variable.TiDBConstraintCheckInPlacePessimistic, variable.On) - if err != nil { - return nil, errors.Trace(err) - } - se.sessionVars.CommonGlobalLoaded = true - se.sessionVars.InRestrictedSQL = true - // Internal session uses default format to prevent memory leak problem. 
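// A minimal sketch of the truncation applied by sqlForLog above: statements longer than
// sqlLogMaxLen are cut and annotated with their original length, e.g. "SELECT ...(len:4096)".
func exampleTruncateForLog(sql string, maxLen int) string {
	if len(sql) > maxLen {
		return sql[:maxLen] + fmt.Sprintf("(len:%d)", len(sql))
	}
	return sql
}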
- se.sessionVars.EnableChunkRPC = false - return se, nil - } -} - -func createSessionWithDomainFunc(store kv.Storage) func(*domain.Domain) (pools.Resource, error) { - return func(dom *domain.Domain) (pools.Resource, error) { - se, err := CreateSessionWithDomain(store, dom) - if err != nil { - return nil, err - } - err = se.sessionVars.SetSystemVar(variable.AutoCommit, "1") - if err != nil { - return nil, err - } - err = se.sessionVars.SetSystemVar(variable.MaxExecutionTime, "0") - if err != nil { - return nil, errors.Trace(err) - } - err = se.sessionVars.SetSystemVar(variable.MaxAllowedPacket, strconv.FormatUint(variable.DefMaxAllowedPacket, 10)) - if err != nil { - return nil, errors.Trace(err) - } - err = se.sessionVars.SetSystemVar(variable.TiDBConstraintCheckInPlacePessimistic, variable.On) - if err != nil { - return nil, errors.Trace(err) - } - se.sessionVars.CommonGlobalLoaded = true - se.sessionVars.InRestrictedSQL = true - // Internal session uses default format to prevent memory leak problem. - se.sessionVars.EnableChunkRPC = false - return se, nil - } -} - -func drainRecordSet(ctx context.Context, se *session, rs sqlexec.RecordSet, alloc chunk.Allocator) ([]chunk.Row, error) { - var rows []chunk.Row - var req *chunk.Chunk - req = rs.NewChunk(alloc) - for { - err := rs.Next(ctx, req) - if err != nil || req.NumRows() == 0 { - return rows, err - } - iter := chunk.NewIterator4Chunk(req) - for r := iter.Begin(); r != iter.End(); r = iter.Next() { - rows = append(rows, r) - } - req = chunk.Renew(req, se.sessionVars.MaxChunkSize) - } -} - -// getTableValue executes restricted sql and the result is one column. -// It returns a string value. -func (s *session) getTableValue(ctx context.Context, tblName string, varName string) (string, error) { - if ctx.Value(kv.RequestSourceKey) == nil { - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnSysVar) - } - rows, fields, err := s.ExecRestrictedSQL(ctx, nil, "SELECT VARIABLE_VALUE FROM %n.%n WHERE VARIABLE_NAME=%?", mysql.SystemDB, tblName, varName) - if err != nil { - return "", err - } - if len(rows) == 0 { - return "", errResultIsEmpty - } - d := rows[0].GetDatum(0, &fields[0].Column.FieldType) - value, err := d.ToString() - if err != nil { - return "", err - } - return value, nil -} - -// replaceGlobalVariablesTableValue executes restricted sql updates the variable value -// It will then notify the etcd channel that the value has changed. -func (s *session) replaceGlobalVariablesTableValue(ctx context.Context, varName, val string, updateLocal bool) error { - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnSysVar) - _, _, err := s.ExecRestrictedSQL(ctx, nil, `REPLACE INTO %n.%n (variable_name, variable_value) VALUES (%?, %?)`, mysql.SystemDB, mysql.GlobalVariablesTable, varName, val) - if err != nil { - return err - } - domain.GetDomain(s).NotifyUpdateSysVarCache(updateLocal) - return err -} - -// GetGlobalSysVar implements GlobalVarAccessor.GetGlobalSysVar interface. -func (s *session) GetGlobalSysVar(name string) (string, error) { - if s.Value(sessionctx.Initing) != nil { - // When running bootstrap or upgrade, we should not access global storage. - return "", nil - } - - sv := variable.GetSysVar(name) - if sv == nil { - // It might be a recently unregistered sysvar. We should return unknown - // since GetSysVar is the canonical version, but we can update the cache - // so the next request doesn't attempt to load this. - logutil.BgLogger().Info("sysvar does not exist. 
sysvar cache may be stale", zap.String("name", name)) - return "", variable.ErrUnknownSystemVar.GenWithStackByArgs(name) - } - - sysVar, err := domain.GetDomain(s).GetGlobalVar(name) - if err != nil { - // The sysvar exists, but there is no cache entry yet. - // This might be because the sysvar was only recently registered. - // In which case it is safe to return the default, but we can also - // update the cache for the future. - logutil.BgLogger().Info("sysvar not in cache yet. sysvar cache may be stale", zap.String("name", name)) - sysVar, err = s.getTableValue(context.TODO(), mysql.GlobalVariablesTable, name) - if err != nil { - return sv.Value, nil - } - } - // It might have been written from an earlier TiDB version, so we should do type validation - // See https://github.com/pingcap/tidb/issues/30255 for why we don't do full validation. - // If validation fails, we should return the default value: - // See: https://github.com/pingcap/tidb/pull/31566 - sysVar, err = sv.ValidateFromType(s.GetSessionVars(), sysVar, variable.ScopeGlobal) - if err != nil { - return sv.Value, nil - } - return sysVar, nil -} - -// SetGlobalSysVar implements GlobalVarAccessor.SetGlobalSysVar interface. -// it is called (but skipped) when setting instance scope -func (s *session) SetGlobalSysVar(ctx context.Context, name string, value string) (err error) { - sv := variable.GetSysVar(name) - if sv == nil { - return variable.ErrUnknownSystemVar.GenWithStackByArgs(name) - } - if value, err = sv.Validate(s.sessionVars, value, variable.ScopeGlobal); err != nil { - return err - } - if err = sv.SetGlobalFromHook(ctx, s.sessionVars, value, false); err != nil { - return err - } - if sv.HasInstanceScope() { // skip for INSTANCE scope - return nil - } - if sv.GlobalConfigName != "" { - domain.GetDomain(s).NotifyGlobalConfigChange(sv.GlobalConfigName, variable.OnOffToTrueFalse(value)) - } - return s.replaceGlobalVariablesTableValue(context.TODO(), sv.Name, value, true) -} - -// SetGlobalSysVarOnly updates the sysvar, but does not call the validation function or update aliases. -// This is helpful to prevent duplicate warnings being appended from aliases, or recursion. -// updateLocal indicates whether to rebuild the local SysVar Cache. This is helpful to prevent recursion. -func (s *session) SetGlobalSysVarOnly(ctx context.Context, name string, value string, updateLocal bool) (err error) { - sv := variable.GetSysVar(name) - if sv == nil { - return variable.ErrUnknownSystemVar.GenWithStackByArgs(name) - } - if err = sv.SetGlobalFromHook(ctx, s.sessionVars, value, true); err != nil { - return err - } - if sv.HasInstanceScope() { // skip for INSTANCE scope - return nil - } - return s.replaceGlobalVariablesTableValue(ctx, sv.Name, value, updateLocal) -} - -// SetTiDBTableValue implements GlobalVarAccessor.SetTiDBTableValue interface. -func (s *session) SetTiDBTableValue(name, value, comment string) error { - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnSysVar) - _, _, err := s.ExecRestrictedSQL(ctx, nil, `REPLACE INTO mysql.tidb (variable_name, variable_value, comment) VALUES (%?, %?, %?)`, name, value, comment) - return err -} - -// GetTiDBTableValue implements GlobalVarAccessor.GetTiDBTableValue interface. 
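// A minimal sketch of the lookup order implemented by GetGlobalSysVar above: the domain's
// sysvar cache first, then the mysql.global_variables table, and finally the built-in
// default when both fail or validation rejects the stored value. The closures below are
// placeholders for the real lookups.
func exampleGlobalVarLookup(fromCache, fromTable func() (string, error), defaultVal string) string {
	if v, err := fromCache(); err == nil {
		return v
	}
	if v, err := fromTable(); err == nil {
		return v
	}
	return defaultVal
}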
-func (s *session) GetTiDBTableValue(name string) (string, error) { - return s.getTableValue(context.TODO(), mysql.TiDBTable, name) -} - -var _ sqlexec.SQLParser = &session{} - -func (s *session) ParseSQL(ctx context.Context, sql string, params ...parser.ParseParam) ([]ast.StmtNode, []error, error) { - defer tracing.StartRegion(ctx, "ParseSQL").End() - - p := parserPool.Get().(*parser.Parser) - defer parserPool.Put(p) - - sqlMode := s.sessionVars.SQLMode - if s.isInternal() { - sqlMode = mysql.DelSQLMode(sqlMode, mysql.ModeNoBackslashEscapes) - } - p.SetSQLMode(sqlMode) - p.SetParserConfig(s.sessionVars.BuildParserConfig()) - tmp, warn, err := p.ParseSQL(sql, params...) - // The []ast.StmtNode is referenced by the parser, to reuse the parser, make a copy of the result. - res := make([]ast.StmtNode, len(tmp)) - copy(res, tmp) - return res, warn, err -} - -func (s *session) SetProcessInfo(sql string, t time.Time, command byte, maxExecutionTime uint64) { - // If command == mysql.ComSleep, it means the SQL execution is finished. The processinfo is reset to SLEEP. - // If the SQL finished and the session is not in transaction, the current start timestamp need to reset to 0. - // Otherwise, it should be set to the transaction start timestamp. - // Why not reset the transaction start timestamp to 0 when transaction committed? - // Because the select statement and other statements need this timestamp to read data, - // after the transaction is committed. e.g. SHOW MASTER STATUS; - var curTxnStartTS uint64 - var curTxnCreateTime time.Time - if command != mysql.ComSleep || s.GetSessionVars().InTxn() { - curTxnStartTS = s.sessionVars.TxnCtx.StartTS - curTxnCreateTime = s.sessionVars.TxnCtx.CreateTime - } - // Set curTxnStartTS to SnapshotTS directly when the session is trying to historic read. - // It will avoid the session meet GC lifetime too short error. - if s.GetSessionVars().SnapshotTS != 0 { - curTxnStartTS = s.GetSessionVars().SnapshotTS - } - p := s.currentPlan - if explain, ok := p.(*plannercore.Explain); ok && explain.Analyze && explain.TargetPlan != nil { - p = explain.TargetPlan - } - - pi := util.ProcessInfo{ - ID: s.sessionVars.ConnectionID, - Port: s.sessionVars.Port, - DB: s.sessionVars.CurrentDB, - Command: command, - Plan: p, - PlanExplainRows: plannercore.GetExplainRowsForPlan(p), - RuntimeStatsColl: s.sessionVars.StmtCtx.RuntimeStatsColl, - Time: t, - State: s.Status(), - Info: sql, - CurTxnStartTS: curTxnStartTS, - CurTxnCreateTime: curTxnCreateTime, - StmtCtx: s.sessionVars.StmtCtx, - RefCountOfStmtCtx: &s.sessionVars.RefCountOfStmtCtx, - MemTracker: s.sessionVars.MemTracker, - DiskTracker: s.sessionVars.DiskTracker, - StatsInfo: plannercore.GetStatsInfo, - OOMAlarmVariablesInfo: s.getOomAlarmVariablesInfo(), - TableIDs: s.sessionVars.StmtCtx.TableIDs, - IndexNames: s.sessionVars.StmtCtx.IndexNames, - MaxExecutionTime: maxExecutionTime, - RedactSQL: s.sessionVars.EnableRedactLog, - ResourceGroupName: s.sessionVars.StmtCtx.ResourceGroupName, - SessionAlias: s.sessionVars.SessionAlias, - CursorTracker: s.cursorTracker, - } - oldPi := s.ShowProcess() - if p == nil { - // Store the last valid plan when the current plan is nil. - // This is for `explain for connection` statement has the ability to query the last valid plan. 
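// A minimal sketch of the parser reuse done by ParseSQL above: borrow a parser from a
// pool, parse, copy the statement slice (the parser owns and reuses it), and return the
// parser. exampleParserPool and parser.New are stand-ins for the package-level pool and
// its constructor.
var exampleParserPool = sync.Pool{New: func() any { return parser.New() }}

func exampleParse(sql string) ([]ast.StmtNode, error) {
	p := exampleParserPool.Get().(*parser.Parser)
	defer exampleParserPool.Put(p)
	stmts, _, err := p.ParseSQL(sql)
	res := make([]ast.StmtNode, len(stmts))
	copy(res, stmts)
	return res, err
}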
- if oldPi != nil && oldPi.Plan != nil && len(oldPi.PlanExplainRows) > 0 { - pi.Plan = oldPi.Plan - pi.PlanExplainRows = oldPi.PlanExplainRows - pi.RuntimeStatsColl = oldPi.RuntimeStatsColl - } - } - // We set process info before building plan, so we extended execution time. - if oldPi != nil && oldPi.Info == pi.Info && oldPi.Command == pi.Command { - pi.Time = oldPi.Time - } - if oldPi != nil && oldPi.CurTxnStartTS != 0 && oldPi.CurTxnStartTS == pi.CurTxnStartTS { - // Keep the last expensive txn log time, avoid print too many expensive txn logs. - pi.ExpensiveTxnLogTime = oldPi.ExpensiveTxnLogTime - } - _, digest := s.sessionVars.StmtCtx.SQLDigest() - pi.Digest = digest.String() - // DO NOT reset the currentPlan to nil until this query finishes execution, otherwise reentrant calls - // of SetProcessInfo would override Plan and PlanExplainRows to nil. - if command == mysql.ComSleep { - s.currentPlan = nil - } - if s.sessionVars.User != nil { - pi.User = s.sessionVars.User.Username - pi.Host = s.sessionVars.User.Hostname - } - s.processInfo.Store(&pi) -} - -// UpdateProcessInfo updates the session's process info for the running statement. -func (s *session) UpdateProcessInfo() { - pi := s.ShowProcess() - if pi == nil || pi.CurTxnStartTS != 0 { - return - } - // do not modify this two fields in place, see issue: issues/50607 - shallowCP := pi.Clone() - // Update the current transaction start timestamp. - shallowCP.CurTxnStartTS = s.sessionVars.TxnCtx.StartTS - shallowCP.CurTxnCreateTime = s.sessionVars.TxnCtx.CreateTime - s.processInfo.Store(shallowCP) -} - -func (s *session) getOomAlarmVariablesInfo() util.OOMAlarmVariablesInfo { - return util.OOMAlarmVariablesInfo{ - SessionAnalyzeVersion: s.sessionVars.AnalyzeVersion, - SessionEnabledRateLimitAction: s.sessionVars.EnabledRateLimitAction, - SessionMemQuotaQuery: s.sessionVars.MemQuotaQuery, - } -} - -func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...any) (rs sqlexec.RecordSet, err error) { - origin := s.sessionVars.InRestrictedSQL - s.sessionVars.InRestrictedSQL = true - defer func() { - s.sessionVars.InRestrictedSQL = origin - // Restore the goroutine label by using the original ctx after execution is finished. - pprof.SetGoroutineLabels(ctx) - }() - - r, ctx := tracing.StartRegionEx(ctx, "session.ExecuteInternal") - defer r.End() - logutil.Eventf(ctx, "execute: %s", sql) - - stmtNode, err := s.ParseWithParams(ctx, sql, args...) - if err != nil { - return nil, err - } - - rs, err = s.ExecuteStmt(ctx, stmtNode) - if err != nil { - s.sessionVars.StmtCtx.AppendError(err) - } - if rs == nil { - return nil, err - } - - return rs, err -} - -// Execute is deprecated, we can remove it as soon as plugins are migrated. -func (s *session) Execute(ctx context.Context, sql string) (recordSets []sqlexec.RecordSet, err error) { - r, ctx := tracing.StartRegionEx(ctx, "session.Execute") - defer r.End() - logutil.Eventf(ctx, "execute: %s", sql) - - stmtNodes, err := s.Parse(ctx, sql) - if err != nil { - return nil, err - } - if len(stmtNodes) != 1 { - return nil, errors.New("Execute() API doesn't support multiple statements any more") - } - - rs, err := s.ExecuteStmt(ctx, stmtNodes[0]) - if err != nil { - s.sessionVars.StmtCtx.AppendError(err) - } - if rs == nil { - return nil, err - } - return []sqlexec.RecordSet{rs}, err -} - -// Parse parses a query string to raw ast.StmtNode. 
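// A minimal sketch of the save/restore pattern used by ExecuteInternal above: mark the
// session as running restricted (internal) SQL for the duration of the call and restore
// the previous value on the way out, whatever happens.
func exampleWithRestrictedSQL(vars *variable.SessionVars, fn func() error) error {
	origin := vars.InRestrictedSQL
	vars.InRestrictedSQL = true
	defer func() { vars.InRestrictedSQL = origin }()
	return fn()
}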
-func (s *session) Parse(ctx context.Context, sql string) ([]ast.StmtNode, error) { - logutil.Logger(ctx).Debug("parse", zap.String("sql", sql)) - parseStartTime := time.Now() - - // Load the session variables to the context. - // This is necessary for the parser to get the current sql_mode. - if err := s.loadCommonGlobalVariablesIfNeeded(); err != nil { - return nil, err - } - - stmts, warns, err := s.ParseSQL(ctx, sql, s.sessionVars.GetParseParams()...) - if err != nil { - s.rollbackOnError(ctx) - err = util.SyntaxError(err) - - // Only print log message when this SQL is from the user. - // Mute the warning for internal SQLs. - if !s.sessionVars.InRestrictedSQL { - logutil.Logger(ctx).Warn("parse SQL failed", zap.Error(err), zap.String("SQL", redact.String(s.sessionVars.EnableRedactLog, sql))) - s.sessionVars.StmtCtx.AppendError(err) - } - return nil, err - } - - durParse := time.Since(parseStartTime) - s.GetSessionVars().DurationParse = durParse - isInternal := s.isInternal() - if isInternal { - session_metrics.SessionExecuteParseDurationInternal.Observe(durParse.Seconds()) - } else { - session_metrics.SessionExecuteParseDurationGeneral.Observe(durParse.Seconds()) - } - for _, warn := range warns { - s.sessionVars.StmtCtx.AppendWarning(util.SyntaxWarn(warn)) - } - return stmts, nil -} - -// ParseWithParams parses a query string, with arguments, to raw ast.StmtNode. -// Note that it will not do escaping if no variable arguments are passed. -func (s *session) ParseWithParams(ctx context.Context, sql string, args ...any) (ast.StmtNode, error) { - var err error - if len(args) > 0 { - sql, err = sqlescape.EscapeSQL(sql, args...) - if err != nil { - return nil, err - } - } - - internal := s.isInternal() - - var stmts []ast.StmtNode - var warns []error - parseStartTime := time.Now() - if internal { - // Do no respect the settings from clients, if it is for internal usage. - // Charsets from clients may give chance injections. - // Refer to https://stackoverflow.com/questions/5741187/sql-injection-that-gets-around-mysql-real-escape-string/12118602. - stmts, warns, err = s.ParseSQL(ctx, sql) - } else { - stmts, warns, err = s.ParseSQL(ctx, sql, s.sessionVars.GetParseParams()...) - } - if len(stmts) != 1 && err == nil { - err = errors.New("run multiple statements internally is not supported") - } - if err != nil { - s.rollbackOnError(ctx) - logSQL := sql[:min(500, len(sql))] - logutil.Logger(ctx).Warn("parse SQL failed", zap.Error(err), zap.String("SQL", redact.String(s.sessionVars.EnableRedactLog, logSQL))) - return nil, util.SyntaxError(err) - } - durParse := time.Since(parseStartTime) - if internal { - session_metrics.SessionExecuteParseDurationInternal.Observe(durParse.Seconds()) - } else { - session_metrics.SessionExecuteParseDurationGeneral.Observe(durParse.Seconds()) - } - for _, warn := range warns { - s.sessionVars.StmtCtx.AppendWarning(util.SyntaxWarn(warn)) - } - if topsqlstate.TopSQLEnabled() { - normalized, digest := parser.NormalizeDigest(sql) - if digest != nil { - // Reset the goroutine label when internal sql execute finish. - // Specifically reset in ExecRestrictedStmt function. - s.sessionVars.StmtCtx.IsSQLRegistered.Store(true) - topsql.AttachAndRegisterSQLInfo(ctx, normalized, digest, s.sessionVars.InRestrictedSQL) - } - } - return stmts[0], nil -} - -// GetAdvisoryLock acquires an advisory lock of lockName. -// Note that a lock can be acquired multiple times by the same session, -// in which case we increment a reference count. 
-// Each lock needs to be held in a unique session because -// we need to be able to ROLLBACK in any arbitrary order -// in order to release the locks. -func (s *session) GetAdvisoryLock(lockName string, timeout int64) error { - if lock, ok := s.advisoryLocks[lockName]; ok { - lock.IncrReferences() - return nil - } - sess, err := createSession(s.store) - if err != nil { - return err - } - infosync.StoreInternalSession(sess) - lock := &advisoryLock{session: sess, ctx: context.TODO(), owner: s.ShowProcess().ID} - err = lock.GetLock(lockName, timeout) - if err != nil { - return err - } - s.advisoryLocks[lockName] = lock - return nil -} - -// IsUsedAdvisoryLock checks if a lockName is already in use -func (s *session) IsUsedAdvisoryLock(lockName string) uint64 { - // Same session - if lock, ok := s.advisoryLocks[lockName]; ok { - return lock.owner - } - - // Check for transaction on advisory_locks table - sess, err := createSession(s.store) - if err != nil { - return 0 - } - lock := &advisoryLock{session: sess, ctx: context.TODO(), owner: s.ShowProcess().ID} - err = lock.IsUsedLock(lockName) - if err != nil { - // TODO: Return actual owner pid - // TODO: Check for mysql.ErrLockWaitTimeout and DeadLock - return 1 - } - return 0 -} - -// ReleaseAdvisoryLock releases an advisory locks held by the session. -// It returns FALSE if no lock by this name was held (by this session), -// and TRUE if a lock was held and "released". -// Note that the lock is not actually released if there are multiple -// references to the same lockName by the session, instead the reference -// count is decremented. -func (s *session) ReleaseAdvisoryLock(lockName string) (released bool) { - if lock, ok := s.advisoryLocks[lockName]; ok { - lock.DecrReferences() - if lock.ReferenceCount() <= 0 { - lock.Close() - delete(s.advisoryLocks, lockName) - infosync.DeleteInternalSession(lock.session) - } - return true - } - return false -} - -// ReleaseAllAdvisoryLocks releases all advisory locks held by the session -// and returns a count of the locks that were released. -// The count is based on unique locks held, so multiple references -// to the same lock do not need to be accounted for. -func (s *session) ReleaseAllAdvisoryLocks() int { - var count int - for lockName, lock := range s.advisoryLocks { - lock.Close() - count += lock.ReferenceCount() - delete(s.advisoryLocks, lockName) - infosync.DeleteInternalSession(lock.session) - } - return count -} - -// GetExtensions returns the `*extension.SessionExtensions` object -func (s *session) GetExtensions() *extension.SessionExtensions { - return s.extensions -} - -// SetExtensions sets the `*extension.SessionExtensions` object -func (s *session) SetExtensions(extensions *extension.SessionExtensions) { - s.extensions = extensions -} - -// InSandBoxMode indicates that this session is in sandbox mode -func (s *session) InSandBoxMode() bool { - return s.sandBoxMode -} - -// EnableSandBoxMode enable the sandbox mode. -func (s *session) EnableSandBoxMode() { - s.sandBoxMode = true -} - -// DisableSandBoxMode enable the sandbox mode. 
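// A minimal sketch of the reference counting used by the advisory-lock methods above:
// re-acquiring a lock the session already holds only bumps a counter, and the backing
// lock session is released only when the counter drops back to zero.
type exampleRefCountedLock struct{ refs int }

func (l *exampleRefCountedLock) acquire() { l.refs++ }

func (l *exampleRefCountedLock) release() (freed bool) {
	l.refs--
	return l.refs <= 0
}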
-func (s *session) DisableSandBoxMode() { - s.sandBoxMode = false -} - -// ParseWithParams4Test wrapper (s *session) ParseWithParams for test -func ParseWithParams4Test(ctx context.Context, s types.Session, - sql string, args ...any) (ast.StmtNode, error) { - return s.(*session).ParseWithParams(ctx, sql, args) -} - -var _ sqlexec.RestrictedSQLExecutor = &session{} -var _ sqlexec.SQLExecutor = &session{} - -// ExecRestrictedStmt implements RestrictedSQLExecutor interface. -func (s *session) ExecRestrictedStmt(ctx context.Context, stmtNode ast.StmtNode, opts ...sqlexec.OptionFuncAlias) ( - []chunk.Row, []*ast.ResultField, error) { - defer pprof.SetGoroutineLabels(ctx) - execOption := sqlexec.GetExecOption(opts) - var se *session - var clean func() - var err error - if execOption.UseCurSession { - se, clean, err = s.useCurrentSession(execOption) - } else { - se, clean, err = s.getInternalSession(execOption) - } - if err != nil { - return nil, nil, err - } - defer clean() - - startTime := time.Now() - metrics.SessionRestrictedSQLCounter.Inc() - ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) - ctx = context.WithValue(ctx, tikvutil.ExecDetailsKey, &tikvutil.ExecDetails{}) - ctx = context.WithValue(ctx, tikvutil.RUDetailsCtxKey, tikvutil.NewRUDetails()) - rs, err := se.ExecuteStmt(ctx, stmtNode) - if err != nil { - se.sessionVars.StmtCtx.AppendError(err) - } - if rs == nil { - return nil, nil, err - } - defer func() { - if closeErr := rs.Close(); closeErr != nil { - err = closeErr - } - }() - var rows []chunk.Row - rows, err = drainRecordSet(ctx, se, rs, nil) - if err != nil { - return nil, nil, err - } - - vars := se.GetSessionVars() - for _, dbName := range GetDBNames(vars) { - metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal, dbName, vars.StmtCtx.ResourceGroupName).Observe(time.Since(startTime).Seconds()) - } - return rows, rs.Fields(), err -} - -// ExecRestrictedStmt4Test wrapper `(s *session) ExecRestrictedStmt` for test. -func ExecRestrictedStmt4Test(ctx context.Context, s types.Session, - stmtNode ast.StmtNode, opts ...sqlexec.OptionFuncAlias) ( - []chunk.Row, []*ast.ResultField, error) { - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnOthers) - return s.(*session).ExecRestrictedStmt(ctx, stmtNode, opts...) 
-} - -// only set and clean session with execOption -func (s *session) useCurrentSession(execOption sqlexec.ExecOption) (*session, func(), error) { - var err error - orgSnapshotInfoSchema, orgSnapshotTS := s.sessionVars.SnapshotInfoschema, s.sessionVars.SnapshotTS - if execOption.SnapshotTS != 0 { - if err = s.sessionVars.SetSystemVar(variable.TiDBSnapshot, strconv.FormatUint(execOption.SnapshotTS, 10)); err != nil { - return nil, nil, err - } - s.sessionVars.SnapshotInfoschema, err = getSnapshotInfoSchema(s, execOption.SnapshotTS) - if err != nil { - return nil, nil, err - } - } - prevStatsVer := s.sessionVars.AnalyzeVersion - if execOption.AnalyzeVer != 0 { - s.sessionVars.AnalyzeVersion = execOption.AnalyzeVer - } - prevAnalyzeSnapshot := s.sessionVars.EnableAnalyzeSnapshot - if execOption.AnalyzeSnapshot != nil { - s.sessionVars.EnableAnalyzeSnapshot = *execOption.AnalyzeSnapshot - } - prePruneMode := s.sessionVars.PartitionPruneMode.Load() - if len(execOption.PartitionPruneMode) > 0 { - s.sessionVars.PartitionPruneMode.Store(execOption.PartitionPruneMode) - } - prevSQL := s.sessionVars.StmtCtx.OriginalSQL - prevStmtType := s.sessionVars.StmtCtx.StmtType - prevTables := s.sessionVars.StmtCtx.Tables - return s, func() { - s.sessionVars.AnalyzeVersion = prevStatsVer - s.sessionVars.EnableAnalyzeSnapshot = prevAnalyzeSnapshot - if err := s.sessionVars.SetSystemVar(variable.TiDBSnapshot, ""); err != nil { - logutil.BgLogger().Error("set tidbSnapshot error", zap.Error(err)) - } - s.sessionVars.SnapshotInfoschema = orgSnapshotInfoSchema - s.sessionVars.SnapshotTS = orgSnapshotTS - s.sessionVars.PartitionPruneMode.Store(prePruneMode) - s.sessionVars.StmtCtx.OriginalSQL = prevSQL - s.sessionVars.StmtCtx.StmtType = prevStmtType - s.sessionVars.StmtCtx.Tables = prevTables - s.sessionVars.MemTracker.Detach() - }, nil -} - -func (s *session) getInternalSession(execOption sqlexec.ExecOption) (*session, func(), error) { - tmp, err := s.sysSessionPool().Get() - if err != nil { - return nil, nil, errors.Trace(err) - } - se := tmp.(*session) - - // The special session will share the `InspectionTableCache` with current session - // if the current session in inspection mode. 
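// A minimal sketch of the borrow/restore pattern around the internal session pool used
// here: take a session from the pool, remember the settings that are about to be
// overridden, and hand back a cleanup closure that restores them and returns the session
// to the pool. The pool interface below is a simplified stand-in.
type exampleSessionPool interface {
	Get() (any, error)
	Put(any)
}

func exampleBorrowSession(pool exampleSessionPool, override func(any) (restore func())) (any, func(), error) {
	res, err := pool.Get()
	if err != nil {
		return nil, nil, err
	}
	restore := override(res)
	cleanup := func() {
		restore()
		pool.Put(res)
	}
	return res, cleanup, nil
}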
- if cache := s.sessionVars.InspectionTableCache; cache != nil { - se.sessionVars.InspectionTableCache = cache - } - se.sessionVars.OptimizerUseInvisibleIndexes = s.sessionVars.OptimizerUseInvisibleIndexes - - preSkipStats := s.sessionVars.SkipMissingPartitionStats - se.sessionVars.SkipMissingPartitionStats = s.sessionVars.SkipMissingPartitionStats - - if execOption.SnapshotTS != 0 { - if err := se.sessionVars.SetSystemVar(variable.TiDBSnapshot, strconv.FormatUint(execOption.SnapshotTS, 10)); err != nil { - return nil, nil, err - } - se.sessionVars.SnapshotInfoschema, err = getSnapshotInfoSchema(s, execOption.SnapshotTS) - if err != nil { - return nil, nil, err - } - } - - prevStatsVer := se.sessionVars.AnalyzeVersion - if execOption.AnalyzeVer != 0 { - se.sessionVars.AnalyzeVersion = execOption.AnalyzeVer - } - - prevAnalyzeSnapshot := se.sessionVars.EnableAnalyzeSnapshot - if execOption.AnalyzeSnapshot != nil { - se.sessionVars.EnableAnalyzeSnapshot = *execOption.AnalyzeSnapshot - } - - prePruneMode := se.sessionVars.PartitionPruneMode.Load() - if len(execOption.PartitionPruneMode) > 0 { - se.sessionVars.PartitionPruneMode.Store(execOption.PartitionPruneMode) - } - - return se, func() { - se.sessionVars.AnalyzeVersion = prevStatsVer - se.sessionVars.EnableAnalyzeSnapshot = prevAnalyzeSnapshot - if err := se.sessionVars.SetSystemVar(variable.TiDBSnapshot, ""); err != nil { - logutil.BgLogger().Error("set tidbSnapshot error", zap.Error(err)) - } - se.sessionVars.SnapshotInfoschema = nil - se.sessionVars.SnapshotTS = 0 - if !execOption.IgnoreWarning { - if se != nil && se.GetSessionVars().StmtCtx.WarningCount() > 0 { - warnings := se.GetSessionVars().StmtCtx.GetWarnings() - s.GetSessionVars().StmtCtx.AppendWarnings(warnings) - } - } - se.sessionVars.PartitionPruneMode.Store(prePruneMode) - se.sessionVars.OptimizerUseInvisibleIndexes = false - se.sessionVars.SkipMissingPartitionStats = preSkipStats - se.sessionVars.InspectionTableCache = nil - se.sessionVars.MemTracker.Detach() - s.sysSessionPool().Put(tmp) - }, nil -} - -func (s *session) withRestrictedSQLExecutor(ctx context.Context, opts []sqlexec.OptionFuncAlias, fn func(context.Context, *session) ([]chunk.Row, []*ast.ResultField, error)) ([]chunk.Row, []*ast.ResultField, error) { - execOption := sqlexec.GetExecOption(opts) - var se *session - var clean func() - var err error - if execOption.UseCurSession { - se, clean, err = s.useCurrentSession(execOption) - } else { - se, clean, err = s.getInternalSession(execOption) - } - if err != nil { - return nil, nil, errors.Trace(err) - } - defer clean() - if execOption.TrackSysProcID > 0 { - err = execOption.TrackSysProc(execOption.TrackSysProcID, se) - if err != nil { - return nil, nil, errors.Trace(err) - } - // unTrack should be called before clean (return sys session) - defer execOption.UnTrackSysProc(execOption.TrackSysProcID) - } - return fn(ctx, se) -} - -func (s *session) ExecRestrictedSQL(ctx context.Context, opts []sqlexec.OptionFuncAlias, sql string, params ...any) ([]chunk.Row, []*ast.ResultField, error) { - return s.withRestrictedSQLExecutor(ctx, opts, func(ctx context.Context, se *session) ([]chunk.Row, []*ast.ResultField, error) { - stmt, err := se.ParseWithParams(ctx, sql, params...) 
- if err != nil { - return nil, nil, errors.Trace(err) - } - defer pprof.SetGoroutineLabels(ctx) - startTime := time.Now() - metrics.SessionRestrictedSQLCounter.Inc() - ctx = context.WithValue(ctx, execdetails.StmtExecDetailKey, &execdetails.StmtExecDetails{}) - ctx = context.WithValue(ctx, tikvutil.ExecDetailsKey, &tikvutil.ExecDetails{}) - rs, err := se.ExecuteInternalStmt(ctx, stmt) - if err != nil { - se.sessionVars.StmtCtx.AppendError(err) - } - if rs == nil { - return nil, nil, err - } - defer func() { - if closeErr := rs.Close(); closeErr != nil { - err = closeErr - } - }() - var rows []chunk.Row - rows, err = drainRecordSet(ctx, se, rs, nil) - if err != nil { - return nil, nil, err - } - - vars := se.GetSessionVars() - for _, dbName := range GetDBNames(vars) { - metrics.QueryDurationHistogram.WithLabelValues(metrics.LblInternal, dbName, vars.StmtCtx.ResourceGroupName).Observe(time.Since(startTime).Seconds()) - } - return rows, rs.Fields(), err - }) -} - -// ExecuteInternalStmt execute internal stmt -func (s *session) ExecuteInternalStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlexec.RecordSet, error) { - origin := s.sessionVars.InRestrictedSQL - s.sessionVars.InRestrictedSQL = true - defer func() { - s.sessionVars.InRestrictedSQL = origin - // Restore the goroutine label by using the original ctx after execution is finished. - pprof.SetGoroutineLabels(ctx) - }() - return s.ExecuteStmt(ctx, stmtNode) -} - -func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlexec.RecordSet, error) { - r, ctx := tracing.StartRegionEx(ctx, "session.ExecuteStmt") - defer r.End() - - if err := s.PrepareTxnCtx(ctx); err != nil { - return nil, err - } - if err := s.loadCommonGlobalVariablesIfNeeded(); err != nil { - return nil, err - } - - sessVars := s.sessionVars - sessVars.StartTime = time.Now() - - // Some executions are done in compile stage, so we reset them before compile. - if err := executor.ResetContextOfStmt(s, stmtNode); err != nil { - return nil, err - } - if execStmt, ok := stmtNode.(*ast.ExecuteStmt); ok { - if binParam, ok := execStmt.BinaryArgs.([]param.BinaryParam); ok { - args, err := param.ExecArgs(s.GetSessionVars().StmtCtx.TypeCtx(), binParam) - if err != nil { - return nil, err - } - execStmt.BinaryArgs = args - } - } - - normalizedSQL, digest := s.sessionVars.StmtCtx.SQLDigest() - cmdByte := byte(atomic.LoadUint32(&s.GetSessionVars().CommandValue)) - if topsqlstate.TopSQLEnabled() { - s.sessionVars.StmtCtx.IsSQLRegistered.Store(true) - ctx = topsql.AttachAndRegisterSQLInfo(ctx, normalizedSQL, digest, s.sessionVars.InRestrictedSQL) - } - if sessVars.InPlanReplayer { - sessVars.StmtCtx.EnableOptimizerDebugTrace = true - } else if dom := domain.GetDomain(s); dom != nil && !sessVars.InRestrictedSQL { - // This is the earliest place we can get the SQL digest for this execution. - // If we find this digest is registered for PLAN REPLAYER CAPTURE, we need to enable optimizer debug trace no matter - // the plan digest will be matched or not. 
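// A minimal sketch of the digest matching described above: when plan-replayer capture is
// active, the statement's SQL digest is compared against the registered capture tasks to
// decide whether optimizer debug tracing should be switched on. The task struct is a
// simplified stand-in for the real task records.
type exampleCaptureTask struct{ SQLDigest string }

func exampleShouldEnableDebugTrace(sqlDigest string, tasks []exampleCaptureTask) bool {
	for _, task := range tasks {
		if task.SQLDigest == sqlDigest {
			return true
		}
	}
	return false
}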
- if planReplayerHandle := dom.GetPlanReplayerHandle(); planReplayerHandle != nil { - tasks := planReplayerHandle.GetTasks() - for _, task := range tasks { - if task.SQLDigest == digest.String() { - sessVars.StmtCtx.EnableOptimizerDebugTrace = true - } - } - } - } - if sessVars.StmtCtx.EnableOptimizerDebugTrace { - plannercore.DebugTraceReceivedCommand(s.pctx, cmdByte, stmtNode) - } - - if err := s.validateStatementInTxn(stmtNode); err != nil { - return nil, err - } - - if err := s.validateStatementReadOnlyInStaleness(stmtNode); err != nil { - return nil, err - } - - // Uncorrelated subqueries will execute once when building plan, so we reset process info before building plan. - s.currentPlan = nil // reset current plan - s.SetProcessInfo(stmtNode.Text(), time.Now(), cmdByte, 0) - s.txn.onStmtStart(digest.String()) - defer sessiontxn.GetTxnManager(s).OnStmtEnd() - defer s.txn.onStmtEnd() - - if err := s.onTxnManagerStmtStartOrRetry(ctx, stmtNode); err != nil { - return nil, err - } - - failpoint.Inject("mockStmtSlow", func(val failpoint.Value) { - if strings.Contains(stmtNode.Text(), "/* sleep */") { - v, _ := val.(int) - time.Sleep(time.Duration(v) * time.Millisecond) - } - }) - - var stmtLabel string - if execStmt, ok := stmtNode.(*ast.ExecuteStmt); ok { - prepareStmt, err := plannercore.GetPreparedStmt(execStmt, s.sessionVars) - if err == nil && prepareStmt.PreparedAst != nil { - stmtLabel = ast.GetStmtLabel(prepareStmt.PreparedAst.Stmt) - } - } - if stmtLabel == "" { - stmtLabel = ast.GetStmtLabel(stmtNode) - } - s.setRequestSource(ctx, stmtLabel, stmtNode) - - // Transform abstract syntax tree to a physical plan(stored in executor.ExecStmt). - compiler := executor.Compiler{Ctx: s} - stmt, err := compiler.Compile(ctx, stmtNode) - // check if resource group hint is valid, can't do this in planner.Optimize because we can access - // infoschema there. - if sessVars.StmtCtx.ResourceGroupName != sessVars.ResourceGroupName { - // if target resource group doesn't exist, fallback to the origin resource group. - if _, ok := domain.GetDomain(s).InfoSchema().ResourceGroupByName(model.NewCIStr(sessVars.StmtCtx.ResourceGroupName)); !ok { - logutil.Logger(ctx).Warn("Unknown resource group from hint", zap.String("name", sessVars.StmtCtx.ResourceGroupName)) - sessVars.StmtCtx.ResourceGroupName = sessVars.ResourceGroupName - if txn, err := s.Txn(false); err == nil && txn != nil && txn.Valid() { - kv.SetTxnResourceGroup(txn, sessVars.ResourceGroupName) - } - } - } - if err != nil { - s.rollbackOnError(ctx) - - // Only print log message when this SQL is from the user. - // Mute the warning for internal SQLs. - if !s.sessionVars.InRestrictedSQL { - if !variable.ErrUnknownSystemVar.Equal(err) { - sql := stmtNode.Text() - sql = parser.Normalize(sql, s.sessionVars.EnableRedactLog) - logutil.Logger(ctx).Warn("compile SQL failed", zap.Error(err), - zap.String("SQL", sql)) - } - } - return nil, err - } - - durCompile := time.Since(s.sessionVars.StartTime) - s.GetSessionVars().DurationCompile = durCompile - if s.isInternal() { - session_metrics.SessionExecuteCompileDurationInternal.Observe(durCompile.Seconds()) - } else { - session_metrics.SessionExecuteCompileDurationGeneral.Observe(durCompile.Seconds()) - } - s.currentPlan = stmt.Plan - if execStmt, ok := stmtNode.(*ast.ExecuteStmt); ok { - if execStmt.Name == "" { - // for exec-stmt on bin-protocol, ignore the plan detail in `show process` to gain performance benefits. - s.currentPlan = nil - } - } - - // Execute the physical plan. 
- logStmt(stmt, s) - - var recordSet sqlexec.RecordSet - if stmt.PsStmt != nil { // point plan short path - recordSet, err = stmt.PointGet(ctx) - s.txn.changeToInvalid() - } else { - recordSet, err = runStmt(ctx, s, stmt) - } - - // Observe the resource group query total counter if the resource control is enabled and the - // current session is attached with a resource group. - resourceGroupName := s.GetSessionVars().StmtCtx.ResourceGroupName - if len(resourceGroupName) > 0 { - metrics.ResourceGroupQueryTotalCounter.WithLabelValues(resourceGroupName, resourceGroupName).Inc() - } - - if err != nil { - if !errIsNoisy(err) { - logutil.Logger(ctx).Warn("run statement failed", - zap.Int64("schemaVersion", s.GetInfoSchema().SchemaMetaVersion()), - zap.Error(err), - zap.String("session", s.String())) - } - return recordSet, err - } - return recordSet, nil -} - -func (s *session) GetSQLExecutor() sqlexec.SQLExecutor { - return s -} - -func (s *session) GetRestrictedSQLExecutor() sqlexec.RestrictedSQLExecutor { - return s -} - -func (s *session) onTxnManagerStmtStartOrRetry(ctx context.Context, node ast.StmtNode) error { - if s.sessionVars.RetryInfo.Retrying { - return sessiontxn.GetTxnManager(s).OnStmtRetry(ctx) - } - return sessiontxn.GetTxnManager(s).OnStmtStart(ctx, node) -} - -func (s *session) validateStatementInTxn(stmtNode ast.StmtNode) error { - vars := s.GetSessionVars() - if _, ok := stmtNode.(*ast.ImportIntoStmt); ok && vars.InTxn() { - return errors.New("cannot run IMPORT INTO in explicit transaction") - } - return nil -} - -func (s *session) validateStatementReadOnlyInStaleness(stmtNode ast.StmtNode) error { - vars := s.GetSessionVars() - if !vars.TxnCtx.IsStaleness && vars.TxnReadTS.PeakTxnReadTS() == 0 && !vars.EnableExternalTSRead || vars.InRestrictedSQL { - return nil - } - errMsg := "only support read-only statement during read-only staleness transactions" - node := stmtNode.(ast.Node) - switch v := node.(type) { - case *ast.SplitRegionStmt: - return nil - case *ast.SelectStmt: - // select lock statement needs start a transaction which will be conflict to stale read, - // we forbid select lock statement in stale read for now. - if v.LockInfo != nil { - return errors.New("select lock hasn't been supported in stale read yet") - } - if !planner.IsReadOnly(stmtNode, vars) { - return errors.New(errMsg) - } - return nil - case *ast.ExplainStmt, *ast.DoStmt, *ast.ShowStmt, *ast.SetOprStmt, *ast.ExecuteStmt, *ast.SetOprSelectList: - if !planner.IsReadOnly(stmtNode, vars) { - return errors.New(errMsg) - } - return nil - default: - } - // covered DeleteStmt/InsertStmt/UpdateStmt/CallStmt/LoadDataStmt - if _, ok := stmtNode.(ast.DMLNode); ok { - return errors.New(errMsg) - } - return nil -} - -// fileTransInConnKeys contains the keys of queries that will be handled by handleFileTransInConn. -var fileTransInConnKeys = []fmt.Stringer{ - executor.LoadDataVarKey, - executor.LoadStatsVarKey, - executor.IndexAdviseVarKey, - executor.PlanReplayerLoadVarKey, -} - -func (s *session) hasFileTransInConn() bool { - s.mu.RLock() - defer s.mu.RUnlock() - - for _, k := range fileTransInConnKeys { - v := s.mu.values[k] - if v != nil { - return true - } - } - return false -} - -// runStmt executes the sqlexec.Statement and commit or rollback the current transaction. 
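// A minimal sketch of the policy enforced by validateStatementReadOnlyInStaleness above:
// during read-only staleness reads, SELECT ... FOR UPDATE and DML are rejected and only
// read-only statements pass. The boolean inputs stand in for the real AST checks.
func exampleStalenessCheck(isSelectForUpdate, isDML, isReadOnly bool) error {
	if isSelectForUpdate {
		return errors.New("select lock hasn't been supported in stale read yet")
	}
	if isDML || !isReadOnly {
		return errors.New("only support read-only statement during read-only staleness transactions")
	}
	return nil
}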
-func runStmt(ctx context.Context, se *session, s sqlexec.Statement) (rs sqlexec.RecordSet, err error) { - failpoint.Inject("assertTxnManagerInRunStmt", func() { - sessiontxn.RecordAssert(se, "assertTxnManagerInRunStmt", true) - if stmt, ok := s.(*executor.ExecStmt); ok { - sessiontxn.AssertTxnManagerInfoSchema(se, stmt.InfoSchema) - } - }) - - r, ctx := tracing.StartRegionEx(ctx, "session.runStmt") - defer r.End() - if r.Span != nil { - r.Span.LogKV("sql", s.OriginText()) - } - - se.SetValue(sessionctx.QueryString, s.OriginText()) - if _, ok := s.(*executor.ExecStmt).StmtNode.(ast.DDLNode); ok { - se.SetValue(sessionctx.LastExecuteDDL, true) - } else { - se.ClearValue(sessionctx.LastExecuteDDL) - } - - sessVars := se.sessionVars - - // Save origTxnCtx here to avoid it reset in the transaction retry. - origTxnCtx := sessVars.TxnCtx - err = se.checkTxnAborted(s) - if err != nil { - return nil, err - } - if sessVars.TxnCtx.CouldRetry && !s.IsReadOnly(sessVars) { - // Only when the txn is could retry and the statement is not read only, need to do stmt-count-limit check, - // otherwise, the stmt won't be add into stmt history, and also don't need check. - // About `stmt-count-limit`, see more in https://docs.pingcap.com/tidb/stable/tidb-configuration-file#stmt-count-limit - if err := checkStmtLimit(ctx, se, false); err != nil { - return nil, err - } - } - - rs, err = s.Exec(ctx) - - if se.txn.Valid() && se.txn.IsPipelined() { - // Pipelined-DMLs can return assertion errors and write conflicts here because they flush - // during execution, handle these errors as we would handle errors after a commit. - if err != nil { - err = se.handleAssertionFailure(ctx, err) - } - newErr := se.tryReplaceWriteConflictError(err) - if newErr != nil { - err = newErr - } - } - - sessVars.TxnCtx.StatementCount++ - if rs != nil { - if se.GetSessionVars().StmtCtx.IsExplainAnalyzeDML { - if !sessVars.InTxn() { - se.StmtCommit(ctx) - if err := se.CommitTxn(ctx); err != nil { - return nil, err - } - } - } - return &execStmtResult{ - RecordSet: rs, - sql: s, - se: se, - }, err - } - - err = finishStmt(ctx, se, err, s) - if se.hasFileTransInConn() { - // The query will be handled later in handleFileTransInConn, - // then should call the ExecStmt.FinishExecuteStmt to finish this statement. - se.SetValue(ExecStmtVarKey, s.(*executor.ExecStmt)) - } else { - // If it is not a select statement or special query, we record its slow log here, - // then it could include the transaction commit time. - s.(*executor.ExecStmt).FinishExecuteStmt(origTxnCtx.StartTS, err, false) - } - return nil, err -} - -// ExecStmtVarKeyType is a dummy type to avoid naming collision in context. -type ExecStmtVarKeyType int - -// String defines a Stringer function for debugging and pretty printing. -func (ExecStmtVarKeyType) String() string { - return "exec_stmt_var_key" -} - -// ExecStmtVarKey is a variable key for ExecStmt. -const ExecStmtVarKey ExecStmtVarKeyType = 0 - -// execStmtResult is the return value of ExecuteStmt and it implements the sqlexec.RecordSet interface. -// Why we need a struct to wrap a RecordSet and provide another RecordSet? -// This is because there are so many session state related things that definitely not belongs to the original -// RecordSet, so this struct exists and RecordSet.Close() is overridden to handle that. 
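// A minimal sketch of the wrapping idea described above: embed the original RecordSet and
// override Close so per-statement cleanup runs exactly once when the caller finishes with
// the rows. finish stands in for the real finishStmt call.
type exampleWrappedRecordSet struct {
	sqlexec.RecordSet
	once   sync.Once
	finish func() error
}

func (rs *exampleWrappedRecordSet) Close() error {
	var finishErr error
	rs.once.Do(func() { finishErr = rs.finish() })
	closeErr := rs.RecordSet.Close()
	if finishErr != nil {
		return finishErr
	}
	return closeErr
}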
-type execStmtResult struct { - sqlexec.RecordSet - se *session - sql sqlexec.Statement - once sync.Once - closed bool -} - -func (rs *execStmtResult) Finish() error { - var err error - rs.once.Do(func() { - var err1 error - if f, ok := rs.RecordSet.(interface{ Finish() error }); ok { - err1 = f.Finish() - } - err2 := finishStmt(context.Background(), rs.se, err, rs.sql) - if err1 != nil { - err = err1 - } else { - err = err2 - } - }) - return err -} - -func (rs *execStmtResult) Close() error { - if rs.closed { - return nil - } - err1 := rs.Finish() - err2 := rs.RecordSet.Close() - rs.closed = true - if err1 != nil { - return err1 - } - return err2 -} - -func (rs *execStmtResult) TryDetach() (sqlexec.RecordSet, bool, error) { - if !rs.sql.IsReadOnly(rs.se.GetSessionVars()) { - return nil, false, nil - } - if !plannercore.IsAutoCommitTxn(rs.se.GetSessionVars()) { - return nil, false, nil - } - - drs, ok := rs.RecordSet.(sqlexec.DetachableRecordSet) - if !ok { - return nil, false, nil - } - detachedRS, ok, err := drs.TryDetach() - if !ok || err != nil { - return nil, ok, err - } - cursorHandle := rs.se.GetCursorTracker().NewCursor( - cursor.State{StartTS: rs.se.GetSessionVars().TxnCtx.StartTS}, - ) - crs := staticrecordset.WrapRecordSetWithCursor(cursorHandle, detachedRS) - - // Now, a transaction is not needed for the detached record set, so we commit the transaction and cleanup - // the session state. - err = finishStmt(context.Background(), rs.se, nil, rs.sql) - if err != nil { - err2 := detachedRS.Close() - if err2 != nil { - logutil.BgLogger().Error("close detached record set failed", zap.Error(err2)) - } - return nil, true, err - } - - return crs, true, nil -} - -// GetExecutor4Test exports the internal executor for test purpose. -func (rs *execStmtResult) GetExecutor4Test() any { - return rs.RecordSet.(interface{ GetExecutor4Test() any }).GetExecutor4Test() -} - -// rollbackOnError makes sure the next statement starts a new transaction with the latest InfoSchema. -func (s *session) rollbackOnError(ctx context.Context) { - if !s.sessionVars.InTxn() { - s.RollbackTxn(ctx) - } -} - -// PrepareStmt is used for executing prepare statement in binary protocol -func (s *session) PrepareStmt(sql string) (stmtID uint32, paramCount int, fields []*ast.ResultField, err error) { - defer func() { - if s.sessionVars.StmtCtx != nil { - s.sessionVars.StmtCtx.DetachMemDiskTracker() - } - }() - if s.sessionVars.TxnCtx.InfoSchema == nil { - // We don't need to create a transaction for prepare statement, just get information schema will do. - s.sessionVars.TxnCtx.InfoSchema = domain.GetDomain(s).InfoSchema() - } - err = s.loadCommonGlobalVariablesIfNeeded() - if err != nil { - return - } - - ctx := context.Background() - // NewPrepareExec may need startTS to build the executor, for example prepare statement has subquery in int. - // So we have to call PrepareTxnCtx here. - if err = s.PrepareTxnCtx(ctx); err != nil { - return - } - - prepareStmt := &ast.PrepareStmt{SQLText: sql} - if err = s.onTxnManagerStmtStartOrRetry(ctx, prepareStmt); err != nil { - return - } - - if err = sessiontxn.GetTxnManager(s).AdviseWarmup(); err != nil { - return - } - prepareExec := executor.NewPrepareExec(s, sql) - err = prepareExec.Next(ctx, nil) - // Rollback even if err is nil. - s.rollbackOnError(ctx) - - if err != nil { - return - } - return prepareExec.ID, prepareExec.ParamCount, prepareExec.Fields, nil -} - -// ExecutePreparedStmt executes a prepared statement. 
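// A minimal sketch of the ID-keyed registry that ExecutePreparedStmt and DropPreparedStmt
// below rely on: EXECUTE resolves a statement ID to its cached statement, DEALLOCATE
// removes it, and an unknown ID yields ErrStmtNotFound. The map is a simplified stand-in
// for SessionVars.PreparedStmts.
type examplePreparedRegistry map[uint32]any

func (r examplePreparedRegistry) lookup(id uint32) (any, error) {
	stmt, ok := r[id]
	if !ok {
		return nil, plannererrors.ErrStmtNotFound
	}
	return stmt, nil
}

func (r examplePreparedRegistry) drop(id uint32) error {
	if _, ok := r[id]; !ok {
		return plannererrors.ErrStmtNotFound
	}
	delete(r, id)
	return nil
}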
-func (s *session) ExecutePreparedStmt(ctx context.Context, stmtID uint32, params []expression.Expression) (sqlexec.RecordSet, error) { - prepStmt, err := s.sessionVars.GetPreparedStmtByID(stmtID) - if err != nil { - err = plannererrors.ErrStmtNotFound - logutil.Logger(ctx).Error("prepared statement not found", zap.Uint32("stmtID", stmtID)) - return nil, err - } - stmt, ok := prepStmt.(*plannercore.PlanCacheStmt) - if !ok { - return nil, errors.Errorf("invalid PlanCacheStmt type") - } - execStmt := &ast.ExecuteStmt{ - BinaryArgs: params, - PrepStmt: stmt, - } - return s.ExecuteStmt(ctx, execStmt) -} - -func (s *session) DropPreparedStmt(stmtID uint32) error { - vars := s.sessionVars - if _, ok := vars.PreparedStmts[stmtID]; !ok { - return plannererrors.ErrStmtNotFound - } - vars.RetryInfo.DroppedPreparedStmtIDs = append(vars.RetryInfo.DroppedPreparedStmtIDs, stmtID) - return nil -} - -func (s *session) Txn(active bool) (kv.Transaction, error) { - if !active { - return &s.txn, nil - } - _, err := sessiontxn.GetTxnManager(s).ActivateTxn() - s.SetMemoryFootprintChangeHook() - return &s.txn, err -} - -func (s *session) SetValue(key fmt.Stringer, value any) { - s.mu.Lock() - s.mu.values[key] = value - s.mu.Unlock() -} - -func (s *session) Value(key fmt.Stringer) any { - s.mu.RLock() - value := s.mu.values[key] - s.mu.RUnlock() - return value -} - -func (s *session) ClearValue(key fmt.Stringer) { - s.mu.Lock() - delete(s.mu.values, key) - s.mu.Unlock() -} - -type inCloseSession struct{} - -// Close function does some clean work when session end. -// Close should release the table locks which hold by the session. -func (s *session) Close() { - // TODO: do clean table locks when session exited without execute Close. - // TODO: do clean table locks when tidb-server was `kill -9`. - if s.HasLockedTables() && config.TableLockEnabled() { - if ds := config.TableLockDelayClean(); ds > 0 { - time.Sleep(time.Duration(ds) * time.Millisecond) - } - lockedTables := s.GetAllTableLocks() - err := domain.GetDomain(s).DDLExecutor().UnlockTables(s, lockedTables) - if err != nil { - logutil.BgLogger().Error("release table lock failed", zap.Uint64("conn", s.sessionVars.ConnectionID)) - } - } - s.ReleaseAllAdvisoryLocks() - if s.statsCollector != nil { - s.statsCollector.Delete() - } - if s.idxUsageCollector != nil { - s.idxUsageCollector.Flush() - } - bindValue := s.Value(bindinfo.SessionBindInfoKeyType) - if bindValue != nil { - bindValue.(bindinfo.SessionBindingHandle).Close() - } - ctx := context.WithValue(context.TODO(), inCloseSession{}, struct{}{}) - s.RollbackTxn(ctx) - if s.sessionVars != nil { - s.sessionVars.WithdrawAllPreparedStmt() - } - if s.stmtStats != nil { - s.stmtStats.SetFinished() - } - s.sessionVars.ClearDiskFullOpt() - if s.sessionPlanCache != nil { - s.sessionPlanCache.Close() - } -} - -// GetSessionVars implements the context.Context interface. -func (s *session) GetSessionVars() *variable.SessionVars { - return s.sessionVars -} - -// GetPlanCtx returns the PlanContext. -func (s *session) GetPlanCtx() planctx.PlanContext { - return s.pctx -} - -// GetExprCtx returns the expression context of the session. 
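// SetValue/Value/ClearValue above form a small mutex-guarded key-value store attached to
// the session, keyed by fmt.Stringer so each subsystem can define its own collision-free
// key type. A standalone sketch of the same shape (simplified; not the session struct itself):
package main

import (
	"fmt"
	"sync"
)

type valueStore struct {
	mu     sync.RWMutex
	values map[fmt.Stringer]any
}

func newValueStore() *valueStore {
	return &valueStore{values: make(map[fmt.Stringer]any)}
}

func (s *valueStore) SetValue(key fmt.Stringer, value any) {
	s.mu.Lock()
	s.values[key] = value
	s.mu.Unlock()
}

func (s *valueStore) Value(key fmt.Stringer) any {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.values[key]
}

func (s *valueStore) ClearValue(key fmt.Stringer) {
	s.mu.Lock()
	delete(s.values, key)
	s.mu.Unlock()
}

// myKeyType mirrors the "dummy type to avoid naming collision" pattern used above.
type myKeyType int

func (myKeyType) String() string { return "my_key" }

const myKey myKeyType = 0

func main() {
	st := newValueStore()
	st.SetValue(myKey, 42)
	fmt.Println(st.Value(myKey)) // 42
	st.ClearValue(myKey)
	fmt.Println(st.Value(myKey)) // <nil>
}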
-func (s *session) GetExprCtx() exprctx.ExprContext { - return s.exprctx -} - -// GetTableCtx returns the table.MutateContext -func (s *session) GetTableCtx() tbctx.MutateContext { - return s.tblctx -} - -// GetDistSQLCtx returns the context used in DistSQL -func (s *session) GetDistSQLCtx() *distsqlctx.DistSQLContext { - vars := s.GetSessionVars() - sc := vars.StmtCtx - - return sc.GetOrInitDistSQLFromCache(func() *distsqlctx.DistSQLContext { - return &distsqlctx.DistSQLContext{ - WarnHandler: sc.WarnHandler, - InRestrictedSQL: sc.InRestrictedSQL, - Client: s.GetClient(), - - EnabledRateLimitAction: vars.EnabledRateLimitAction, - EnableChunkRPC: vars.EnableChunkRPC, - OriginalSQL: sc.OriginalSQL, - KVVars: vars.KVVars, - KvExecCounter: sc.KvExecCounter, - SessionMemTracker: vars.MemTracker, - - Location: sc.TimeZone(), - RuntimeStatsColl: sc.RuntimeStatsColl, - SQLKiller: &vars.SQLKiller, - ErrCtx: sc.ErrCtx(), - - TiFlashReplicaRead: vars.TiFlashReplicaRead, - TiFlashMaxThreads: vars.TiFlashMaxThreads, - TiFlashMaxBytesBeforeExternalJoin: vars.TiFlashMaxBytesBeforeExternalJoin, - TiFlashMaxBytesBeforeExternalGroupBy: vars.TiFlashMaxBytesBeforeExternalGroupBy, - TiFlashMaxBytesBeforeExternalSort: vars.TiFlashMaxBytesBeforeExternalSort, - TiFlashMaxQueryMemoryPerNode: vars.TiFlashMaxQueryMemoryPerNode, - TiFlashQuerySpillRatio: vars.TiFlashQuerySpillRatio, - - DistSQLConcurrency: vars.DistSQLScanConcurrency(), - ReplicaReadType: vars.GetReplicaRead(), - WeakConsistency: sc.WeakConsistency, - RCCheckTS: sc.RCCheckTS, - NotFillCache: sc.NotFillCache, - TaskID: sc.TaskID, - Priority: sc.Priority, - ResourceGroupTagger: sc.GetResourceGroupTagger(), - EnablePaging: vars.EnablePaging, - MinPagingSize: vars.MinPagingSize, - MaxPagingSize: vars.MaxPagingSize, - RequestSourceType: vars.RequestSourceType, - ExplicitRequestSourceType: vars.ExplicitRequestSourceType, - StoreBatchSize: vars.StoreBatchSize, - ResourceGroupName: sc.ResourceGroupName, - LoadBasedReplicaReadThreshold: vars.LoadBasedReplicaReadThreshold, - RunawayChecker: sc.RunawayChecker, - TiKVClientReadTimeout: vars.GetTiKVClientReadTimeout(), - - ReplicaClosestReadThreshold: vars.ReplicaClosestReadThreshold, - ConnectionID: vars.ConnectionID, - SessionAlias: vars.SessionAlias, - - ExecDetails: &sc.SyncExecDetails, - } - }) -} - -// GetRangerCtx returns the context used in `ranger` related functions -func (s *session) GetRangerCtx() *rangerctx.RangerContext { - vars := s.GetSessionVars() - sc := vars.StmtCtx - - rctx := sc.GetOrInitRangerCtxFromCache(func() any { - return &rangerctx.RangerContext{ - ExprCtx: s.GetExprCtx(), - TypeCtx: s.GetSessionVars().StmtCtx.TypeCtx(), - ErrCtx: s.GetSessionVars().StmtCtx.ErrCtx(), - - InPreparedPlanBuilding: s.GetSessionVars().StmtCtx.InPreparedPlanBuilding, - RegardNULLAsPoint: s.GetSessionVars().RegardNULLAsPoint, - OptPrefixIndexSingleScan: s.GetSessionVars().OptPrefixIndexSingleScan, - OptimizerFixControl: s.GetSessionVars().OptimizerFixControl, - - PlanCacheTracker: &s.GetSessionVars().StmtCtx.PlanCacheTracker, - RangeFallbackHandler: &s.GetSessionVars().StmtCtx.RangeFallbackHandler, - } - }) - - return rctx.(*rangerctx.RangerContext) -} - -// GetBuildPBCtx returns the context used in `ToPB` method -func (s *session) GetBuildPBCtx() *planctx.BuildPBContext { - vars := s.GetSessionVars() - sc := vars.StmtCtx - - bctx := sc.GetOrInitBuildPBCtxFromCache(func() any { - return &planctx.BuildPBContext{ - ExprCtx: s.GetExprCtx(), - Client: s.GetClient(), - - TiFlashFastScan: 
s.GetSessionVars().TiFlashFastScan, - TiFlashFineGrainedShuffleBatchSize: s.GetSessionVars().TiFlashFineGrainedShuffleBatchSize, - - // the following fields are used to build `expression.PushDownContext`. - // TODO: it'd be better to embed `expression.PushDownContext` in `BuildPBContext`. But `expression` already - // depends on this package, so we need to move `expression.PushDownContext` to a standalone package first. - GroupConcatMaxLen: s.GetSessionVars().GroupConcatMaxLen, - InExplainStmt: s.GetSessionVars().StmtCtx.InExplainStmt, - WarnHandler: s.GetSessionVars().StmtCtx.WarnHandler, - ExtraWarnghandler: s.GetSessionVars().StmtCtx.ExtraWarnHandler, - } - }) - - return bctx.(*planctx.BuildPBContext) -} - -func (s *session) AuthPluginForUser(user *auth.UserIdentity) (string, error) { - pm := privilege.GetPrivilegeManager(s) - authplugin, err := pm.GetAuthPluginForConnection(user.Username, user.Hostname) - if err != nil { - return "", err - } - return authplugin, nil -} - -// Auth validates a user using an authentication string and salt. -// If the password fails, it will keep trying other users until exhausted. -// This means it can not be refactored to use MatchIdentity yet. -func (s *session) Auth(user *auth.UserIdentity, authentication, salt []byte, authConn conn.AuthConn) error { - hasPassword := "YES" - if len(authentication) == 0 { - hasPassword = "NO" - } - pm := privilege.GetPrivilegeManager(s) - authUser, err := s.MatchIdentity(user.Username, user.Hostname) - if err != nil { - return privileges.ErrAccessDenied.FastGenByArgs(user.Username, user.Hostname, hasPassword) - } - // Check whether continuous login failure is enabled to lock the account. - // If enabled, determine whether to unlock the account and notify TiDB to update the cache. - enableAutoLock := pm.IsAccountAutoLockEnabled(authUser.Username, authUser.Hostname) - if enableAutoLock { - err = failedLoginTrackingBegin(s) - if err != nil { - return err - } - lockStatusChanged, err := verifyAccountAutoLock(s, authUser.Username, authUser.Hostname) - if err != nil { - rollbackErr := failedLoginTrackingRollback(s) - if rollbackErr != nil { - return rollbackErr - } - return err - } - err = failedLoginTrackingCommit(s) - if err != nil { - rollbackErr := failedLoginTrackingRollback(s) - if rollbackErr != nil { - return rollbackErr - } - return err - } - if lockStatusChanged { - // Notification auto unlock. - err = domain.GetDomain(s).NotifyUpdatePrivilege() - if err != nil { - return err - } - } - } - - info, err := pm.ConnectionVerification(user, authUser.Username, authUser.Hostname, authentication, salt, s.sessionVars, authConn) - if err != nil { - if info.FailedDueToWrongPassword { - // when user enables the account locking function for consecutive login failures, - // the system updates the login failure count and determines whether to lock the account when authentication fails. - if enableAutoLock { - err := failedLoginTrackingBegin(s) - if err != nil { - return err - } - lockStatusChanged, passwordLocking, trackingErr := authFailedTracking(s, authUser.Username, authUser.Hostname) - if trackingErr != nil { - if rollBackErr := failedLoginTrackingRollback(s); rollBackErr != nil { - return rollBackErr - } - return trackingErr - } - if err := failedLoginTrackingCommit(s); err != nil { - if rollBackErr := failedLoginTrackingRollback(s); rollBackErr != nil { - return rollBackErr - } - return err - } - if lockStatusChanged { - // Notification auto lock. 
- err := autolockAction(s, passwordLocking, authUser.Username, authUser.Hostname) - if err != nil { - return err - } - } - } - } - return err - } - - if variable.EnableResourceControl.Load() && info.ResourceGroupName != "" { - s.sessionVars.SetResourceGroupName(info.ResourceGroupName) - } - - if info.InSandBoxMode { - // Enter sandbox mode, only execute statement for resetting password. - s.EnableSandBoxMode() - } - if enableAutoLock { - err := failedLoginTrackingBegin(s) - if err != nil { - return err - } - // The password is correct. If the account is not locked, the number of login failure statistics will be cleared. - err = authSuccessClearCount(s, authUser.Username, authUser.Hostname) - if err != nil { - if rollBackErr := failedLoginTrackingRollback(s); rollBackErr != nil { - return rollBackErr - } - return err - } - err = failedLoginTrackingCommit(s) - if err != nil { - if rollBackErr := failedLoginTrackingRollback(s); rollBackErr != nil { - return rollBackErr - } - return err - } - } - pm.AuthSuccess(authUser.Username, authUser.Hostname) - user.AuthUsername = authUser.Username - user.AuthHostname = authUser.Hostname - s.sessionVars.User = user - s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname) - return nil -} - -func authSuccessClearCount(s *session, user string, host string) error { - // Obtain accurate lock status and failure count information. - passwordLocking, err := getFailedLoginUserAttributes(s, user, host) - if err != nil { - return err - } - // If the account is locked, it may be caused by the untimely update of the cache, - // directly report the account lock. - if passwordLocking.AutoAccountLocked { - if passwordLocking.PasswordLockTimeDays == -1 { - return privileges.GenerateAccountAutoLockErr(passwordLocking.FailedLoginAttempts, user, host, - "unlimited", "unlimited") - } - - lds := strconv.FormatInt(passwordLocking.PasswordLockTimeDays, 10) - return privileges.GenerateAccountAutoLockErr(passwordLocking.FailedLoginAttempts, user, host, lds, lds) - } - if passwordLocking.FailedLoginCount != 0 { - // If the number of account login failures is not zero, it will be updated to 0. - passwordLockingJSON := privileges.BuildSuccessPasswordLockingJSON(passwordLocking.FailedLoginAttempts, - passwordLocking.PasswordLockTimeDays) - if passwordLockingJSON != "" { - if err := s.passwordLocking(user, host, passwordLockingJSON); err != nil { - return err - } - } - } - return nil -} - -func verifyAccountAutoLock(s *session, user, host string) (bool, error) { - pm := privilege.GetPrivilegeManager(s) - // Use the cache to determine whether to unlock the account. - // If the account needs to be unlocked, read the database information to determine whether - // the account needs to be unlocked. Otherwise, an error message is displayed. - lockStatusInMemory, err := pm.VerifyAccountAutoLockInMemory(user, host) - if err != nil { - return false, err - } - // If the lock status in the cache is Unlock, the automatic unlock is skipped. - // If memory synchronization is slow and there is a lock in the database, it will be processed upon successful login. - if !lockStatusInMemory { - return false, nil - } - lockStatusChanged := false - var plJSON string - // After checking the cache, obtain the latest data from the database and determine - // whether to automatically unlock the database to prevent repeated unlock errors. 
- pl, err := getFailedLoginUserAttributes(s, user, host) - if err != nil { - return false, err - } - if pl.AutoAccountLocked { - // If it is locked, need to check whether it can be automatically unlocked. - lockTimeDay := pl.PasswordLockTimeDays - if lockTimeDay == -1 { - return false, privileges.GenerateAccountAutoLockErr(pl.FailedLoginAttempts, user, host, "unlimited", "unlimited") - } - lastChanged := pl.AutoLockedLastChanged - d := time.Now().Unix() - lastChanged - if d <= lockTimeDay*24*60*60 { - lds := strconv.FormatInt(lockTimeDay, 10) - rds := strconv.FormatInt(int64(math.Ceil(float64(lockTimeDay)-float64(d)/(24*60*60))), 10) - return false, privileges.GenerateAccountAutoLockErr(pl.FailedLoginAttempts, user, host, lds, rds) - } - // Generate unlock json string. - plJSON = privileges.BuildPasswordLockingJSON(pl.FailedLoginAttempts, - pl.PasswordLockTimeDays, "N", 0, time.Now().Format(time.UnixDate)) - } - if plJSON != "" { - lockStatusChanged = true - if err = s.passwordLocking(user, host, plJSON); err != nil { - return false, err - } - } - return lockStatusChanged, nil -} - -func authFailedTracking(s *session, user string, host string) (bool, *privileges.PasswordLocking, error) { - // Obtain the number of consecutive password login failures. - passwordLocking, err := getFailedLoginUserAttributes(s, user, host) - if err != nil { - return false, nil, err - } - // Consecutive wrong password login failure times +1, - // If the lock condition is satisfied, the lock status is updated and the update cache is notified. - lockStatusChanged, err := userAutoAccountLocked(s, user, host, passwordLocking) - if err != nil { - return false, nil, err - } - return lockStatusChanged, passwordLocking, nil -} - -func autolockAction(s *session, passwordLocking *privileges.PasswordLocking, user, host string) error { - // Don't want to update the cache frequently, and only trigger the update cache when the lock status is updated. - err := domain.GetDomain(s).NotifyUpdatePrivilege() - if err != nil { - return err - } - // The number of failed login attempts reaches FAILED_LOGIN_ATTEMPTS. - // An error message is displayed indicating permission denial and account lock. - if passwordLocking.PasswordLockTimeDays == -1 { - return privileges.GenerateAccountAutoLockErr(passwordLocking.FailedLoginAttempts, user, host, - "unlimited", "unlimited") - } - lds := strconv.FormatInt(passwordLocking.PasswordLockTimeDays, 10) - return privileges.GenerateAccountAutoLockErr(passwordLocking.FailedLoginAttempts, user, host, lds, lds) -} - -func (s *session) passwordLocking(user string, host string, newAttributesStr string) error { - sql := new(strings.Builder) - sqlescape.MustFormatSQL(sql, "UPDATE %n.%n SET ", mysql.SystemDB, mysql.UserTable) - sqlescape.MustFormatSQL(sql, "user_attributes=json_merge_patch(coalesce(user_attributes, '{}'), %?)", newAttributesStr) - sqlescape.MustFormatSQL(sql, " WHERE Host=%? 
and User=%?;", host, user) - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnPrivilege) - _, err := s.ExecuteInternal(ctx, sql.String()) - return err -} - -func failedLoginTrackingBegin(s *session) error { - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnPrivilege) - _, err := s.ExecuteInternal(ctx, "BEGIN PESSIMISTIC") - return err -} - -func failedLoginTrackingCommit(s *session) error { - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnPrivilege) - _, err := s.ExecuteInternal(ctx, "COMMIT") - if err != nil { - _, rollBackErr := s.ExecuteInternal(ctx, "ROLLBACK") - if rollBackErr != nil { - return rollBackErr - } - } - return err -} - -func failedLoginTrackingRollback(s *session) error { - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnPrivilege) - _, err := s.ExecuteInternal(ctx, "ROLLBACK") - return err -} - -// getFailedLoginUserAttributes queries the exact number of consecutive password login failures (concurrency is not allowed). -func getFailedLoginUserAttributes(s *session, user string, host string) (*privileges.PasswordLocking, error) { - passwordLocking := &privileges.PasswordLocking{} - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnPrivilege) - rs, err := s.ExecuteInternal(ctx, `SELECT user_attributes from mysql.user WHERE USER = %? AND HOST = %? for update`, user, host) - if err != nil { - return passwordLocking, err - } - defer func() { - if closeErr := rs.Close(); closeErr != nil { - err = closeErr - } - }() - req := rs.NewChunk(nil) - iter := chunk.NewIterator4Chunk(req) - err = rs.Next(ctx, req) - if err != nil { - return passwordLocking, err - } - if req.NumRows() == 0 { - return passwordLocking, fmt.Errorf("user_attributes by `%s`@`%s` not found", user, host) - } - row := iter.Begin() - if !row.IsNull(0) { - passwordLockingJSON := row.GetJSON(0) - return passwordLocking, passwordLocking.ParseJSON(passwordLockingJSON) - } - return passwordLocking, fmt.Errorf("user_attributes by `%s`@`%s` not found", user, host) -} - -func userAutoAccountLocked(s *session, user string, host string, pl *privileges.PasswordLocking) (bool, error) { - // Indicates whether the user needs to update the lock status change. - lockStatusChanged := false - // The number of consecutive login failures is stored in the database. - // If the current login fails, one is added to the number of consecutive login failures - // stored in the database to determine whether the user needs to be locked and the number of update failures. - failedLoginCount := pl.FailedLoginCount + 1 - // If the cache is not updated, but it is already locked, it will report that the account is locked. 
- if pl.AutoAccountLocked { - if pl.PasswordLockTimeDays == -1 { - return false, privileges.GenerateAccountAutoLockErr(pl.FailedLoginAttempts, user, host, - "unlimited", "unlimited") - } - lds := strconv.FormatInt(pl.PasswordLockTimeDays, 10) - return false, privileges.GenerateAccountAutoLockErr(pl.FailedLoginAttempts, user, host, lds, lds) - } - - autoAccountLocked := "N" - autoLockedLastChanged := "" - if pl.FailedLoginAttempts == 0 || pl.PasswordLockTimeDays == 0 { - return false, nil - } - - if failedLoginCount >= pl.FailedLoginAttempts { - autoLockedLastChanged = time.Now().Format(time.UnixDate) - autoAccountLocked = "Y" - lockStatusChanged = true - } - - newAttributesStr := privileges.BuildPasswordLockingJSON(pl.FailedLoginAttempts, - pl.PasswordLockTimeDays, autoAccountLocked, failedLoginCount, autoLockedLastChanged) - if newAttributesStr != "" { - return lockStatusChanged, s.passwordLocking(user, host, newAttributesStr) - } - return lockStatusChanged, nil -} - -// MatchIdentity finds the matching username + password in the MySQL privilege tables -// for a username + hostname, since MySQL can have wildcards. -func (s *session) MatchIdentity(username, remoteHost string) (*auth.UserIdentity, error) { - pm := privilege.GetPrivilegeManager(s) - var success bool - var skipNameResolve bool - var user = &auth.UserIdentity{} - varVal, err := s.GetSessionVars().GlobalVarsAccessor.GetGlobalSysVar(variable.SkipNameResolve) - if err == nil && variable.TiDBOptOn(varVal) { - skipNameResolve = true - } - user.Username, user.Hostname, success = pm.MatchIdentity(username, remoteHost, skipNameResolve) - if success { - return user, nil - } - // This error will not be returned to the user, access denied will be instead - return nil, fmt.Errorf("could not find matching user in MatchIdentity: %s, %s", username, remoteHost) -} - -// AuthWithoutVerification is required by the ResetConnection RPC -func (s *session) AuthWithoutVerification(user *auth.UserIdentity) bool { - pm := privilege.GetPrivilegeManager(s) - authUser, err := s.MatchIdentity(user.Username, user.Hostname) - if err != nil { - return false - } - if pm.GetAuthWithoutVerification(authUser.Username, authUser.Hostname) { - user.AuthUsername = authUser.Username - user.AuthHostname = authUser.Hostname - s.sessionVars.User = user - s.sessionVars.ActiveRoles = pm.GetDefaultRoles(user.AuthUsername, user.AuthHostname) - return true - } - return false -} - -// SetSessionStatesHandler implements the Session.SetSessionStatesHandler interface. -func (s *session) SetSessionStatesHandler(stateType sessionstates.SessionStateType, handler sessionctx.SessionStatesHandler) { - s.sessionStatesHandlers[stateType] = handler -} - -// ReportUsageStats reports the usage stats -func (s *session) ReportUsageStats() { - if s.idxUsageCollector != nil { - s.idxUsageCollector.Report() - } -} - -// CreateSession4Test creates a new session environment for test. -func CreateSession4Test(store kv.Storage) (types.Session, error) { - se, err := CreateSession4TestWithOpt(store, nil) - if err == nil { - // Cover both chunk rpc encoding and default encoding. - // nolint:gosec - if rand.Intn(2) == 0 { - se.GetSessionVars().EnableChunkRPC = false - } else { - se.GetSessionVars().EnableChunkRPC = true - } - } - return se, err -} - -// Opt describes the option for creating session -type Opt struct { - PreparedPlanCache sessionctx.SessionPlanCache -} - -// CreateSession4TestWithOpt creates a new session environment for test. 
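// The auto-lock check above reduces to simple arithmetic over Unix timestamps: the
// account stays locked while the time elapsed since locking is within PASSWORD_LOCK_TIME
// days, and the error reports how many whole days remain. A standalone sketch of that
// calculation, assuming the same inputs (lock days, lock timestamp); it is illustrative
// only and not the privileges package itself.
package main

import (
	"fmt"
	"math"
	"time"
)

// remainingLockDays reports whether the lock is still in force and, if so, the number of
// days (rounded up) before it expires. lockTimeDays == -1 means the lock never expires
// automatically.
func remainingLockDays(lockTimeDays, lockedAtUnix, nowUnix int64) (locked bool, days int64) {
	if lockTimeDays == -1 {
		return true, -1 // unlimited lock
	}
	elapsed := nowUnix - lockedAtUnix
	if elapsed > lockTimeDays*24*60*60 {
		return false, 0 // lock window has passed; eligible for auto unlock
	}
	remaining := int64(math.Ceil(float64(lockTimeDays) - float64(elapsed)/(24*60*60)))
	return true, remaining
}

func main() {
	lockedAt := time.Now().Add(-36 * time.Hour).Unix()
	locked, days := remainingLockDays(3, lockedAt, time.Now().Unix())
	fmt.Println(locked, days) // true 2 (1.5 days elapsed of a 3-day lock)
}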
-func CreateSession4TestWithOpt(store kv.Storage, opt *Opt) (types.Session, error) { - s, err := CreateSessionWithOpt(store, opt) - if err == nil { - // initialize session variables for test. - s.GetSessionVars().InitChunkSize = 2 - s.GetSessionVars().MaxChunkSize = 32 - s.GetSessionVars().MinPagingSize = variable.DefMinPagingSize - s.GetSessionVars().EnablePaging = variable.DefTiDBEnablePaging - s.GetSessionVars().StmtCtx.SetTimeZone(s.GetSessionVars().Location()) - err = s.GetSessionVars().SetSystemVarWithoutValidation(variable.CharacterSetConnection, "utf8mb4") - } - return s, err -} - -// CreateSession creates a new session environment. -func CreateSession(store kv.Storage) (types.Session, error) { - return CreateSessionWithOpt(store, nil) -} - -// CreateSessionWithOpt creates a new session environment with option. -// Use default option if opt is nil. -func CreateSessionWithOpt(store kv.Storage, opt *Opt) (types.Session, error) { - s, err := createSessionWithOpt(store, opt) - if err != nil { - return nil, err - } - - // Add auth here. - do, err := domap.Get(store) - if err != nil { - return nil, err - } - extensions, err := extension.GetExtensions() - if err != nil { - return nil, err - } - pm := privileges.NewUserPrivileges(do.PrivilegeHandle(), extensions) - privilege.BindPrivilegeManager(s, pm) - - // Add stats collector, and it will be freed by background stats worker - // which periodically updates stats using the collected data. - if do.StatsHandle() != nil && do.StatsUpdating() { - s.statsCollector = do.StatsHandle().NewSessionStatsItem().(*usage.SessionStatsItem) - if config.GetGlobalConfig().Instance.EnableCollectExecutionInfo.Load() { - s.idxUsageCollector = do.StatsHandle().NewSessionIndexUsageCollector() - } - } - - s.cursorTracker = cursor.NewTracker() - - return s, nil -} - -// loadCollationParameter loads collation parameter from mysql.tidb -func loadCollationParameter(ctx context.Context, se *session) (bool, error) { - para, err := se.getTableValue(ctx, mysql.TiDBTable, TidbNewCollationEnabled) - if err != nil { - return false, err - } - if para == varTrue { - return true, nil - } else if para == varFalse { - return false, nil - } - logutil.BgLogger().Warn( - "Unexpected value of 'new_collation_enabled' in 'mysql.tidb', use 'False' instead", - zap.String("value", para)) - return false, nil -} - -type tableBasicInfo struct { - SQL string - id int64 -} - -var ( - errResultIsEmpty = dbterror.ClassExecutor.NewStd(errno.ErrResultIsEmpty) - // DDLJobTables is a list of tables definitions used in concurrent DDL. - DDLJobTables = []tableBasicInfo{ - {ddl.JobTableSQL, ddl.JobTableID}, - {ddl.ReorgTableSQL, ddl.ReorgTableID}, - {ddl.HistoryTableSQL, ddl.HistoryTableID}, - } - // BackfillTables is a list of tables definitions used in dist reorg DDL. 
- BackfillTables = []tableBasicInfo{ - {ddl.BackgroundSubtaskTableSQL, ddl.BackgroundSubtaskTableID}, - {ddl.BackgroundSubtaskHistoryTableSQL, ddl.BackgroundSubtaskHistoryTableID}, - } - mdlTable = `create table mysql.tidb_mdl_info( - job_id BIGINT NOT NULL PRIMARY KEY, - version BIGINT NOT NULL, - table_ids text(65535), - owner_id varchar(64) NOT NULL DEFAULT '' - );` -) - -func splitAndScatterTable(store kv.Storage, tableIDs []int64) { - if s, ok := store.(kv.SplittableStore); ok && atomic.LoadUint32(&ddl.EnableSplitTableRegion) == 1 { - ctxWithTimeout, cancel := context.WithTimeout(context.Background(), variable.DefWaitSplitRegionTimeout*time.Second) - var regionIDs []uint64 - for _, id := range tableIDs { - regionIDs = append(regionIDs, ddl.SplitRecordRegion(ctxWithTimeout, s, id, id, variable.DefTiDBScatterRegion)) - } - if variable.DefTiDBScatterRegion { - ddl.WaitScatterRegionFinish(ctxWithTimeout, s, regionIDs...) - } - cancel() - } -} - -// InitDDLJobTables is to create tidb_ddl_job, tidb_ddl_reorg and tidb_ddl_history, or tidb_background_subtask and tidb_background_subtask_history. -func InitDDLJobTables(store kv.Storage, targetVer meta.DDLTableVersion) error { - targetTables := DDLJobTables - if targetVer == meta.BackfillTableVersion { - targetTables = BackfillTables - } - return kv.RunInNewTxn(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), store, true, func(_ context.Context, txn kv.Transaction) error { - t := meta.NewMeta(txn) - tableVer, err := t.CheckDDLTableVersion() - if err != nil || tableVer >= targetVer { - return errors.Trace(err) - } - dbID, err := t.CreateMySQLDatabaseIfNotExists() - if err != nil { - return err - } - if err = createAndSplitTables(store, t, dbID, targetTables); err != nil { - return err - } - return t.SetDDLTables(targetVer) - }) -} - -func createAndSplitTables(store kv.Storage, t *meta.Meta, dbID int64, tables []tableBasicInfo) error { - tableIDs := make([]int64, 0, len(tables)) - for _, tbl := range tables { - tableIDs = append(tableIDs, tbl.id) - } - splitAndScatterTable(store, tableIDs) - p := parser.New() - for _, tbl := range tables { - stmt, err := p.ParseOneStmt(tbl.SQL, "", "") - if err != nil { - return errors.Trace(err) - } - tblInfo, err := ddl.BuildTableInfoFromAST(stmt.(*ast.CreateTableStmt)) - if err != nil { - return errors.Trace(err) - } - tblInfo.State = model.StatePublic - tblInfo.ID = tbl.id - tblInfo.UpdateTS = t.StartTS - err = t.CreateTableOrView(dbID, tblInfo) - if err != nil { - return errors.Trace(err) - } - } - return nil -} - -// InitMDLTable is to create tidb_mdl_info, which is used for metadata lock. 
-func InitMDLTable(store kv.Storage) error { - return kv.RunInNewTxn(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), store, true, func(_ context.Context, txn kv.Transaction) error { - t := meta.NewMeta(txn) - ver, err := t.CheckDDLTableVersion() - if err != nil || ver >= meta.MDLTableVersion { - return errors.Trace(err) - } - dbID, err := t.CreateMySQLDatabaseIfNotExists() - if err != nil { - return err - } - splitAndScatterTable(store, []int64{ddl.MDLTableID}) - p := parser.New() - stmt, err := p.ParseOneStmt(mdlTable, "", "") - if err != nil { - return errors.Trace(err) - } - tblInfo, err := ddl.BuildTableInfoFromAST(stmt.(*ast.CreateTableStmt)) - if err != nil { - return errors.Trace(err) - } - tblInfo.State = model.StatePublic - tblInfo.ID = ddl.MDLTableID - tblInfo.UpdateTS = t.StartTS - err = t.CreateTableOrView(dbID, tblInfo) - if err != nil { - return errors.Trace(err) - } - - return t.SetDDLTables(meta.MDLTableVersion) - }) -} - -// InitMDLVariableForBootstrap initializes the metadata lock variable. -func InitMDLVariableForBootstrap(store kv.Storage) error { - err := kv.RunInNewTxn(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), store, true, func(_ context.Context, txn kv.Transaction) error { - t := meta.NewMeta(txn) - return t.SetMetadataLock(true) - }) - if err != nil { - return err - } - variable.EnableMDL.Store(true) - return nil -} - -// InitTiDBSchemaCacheSize initializes the tidb schema cache size. -func InitTiDBSchemaCacheSize(store kv.Storage) error { - var ( - isNull bool - size uint64 - err error - ) - err = kv.RunInNewTxn(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), store, true, func(_ context.Context, txn kv.Transaction) error { - t := meta.NewMeta(txn) - size, isNull, err = t.GetSchemaCacheSize() - if err != nil { - return errors.Trace(err) - } - if isNull { - size = variable.DefTiDBSchemaCacheSize - return t.SetSchemaCacheSize(size) - } - return nil - }) - if err != nil { - return errors.Trace(err) - } - variable.SchemaCacheSize.Store(size) - return nil -} - -// InitMDLVariableForUpgrade initializes the metadata lock variable. -func InitMDLVariableForUpgrade(store kv.Storage) (bool, error) { - isNull := false - enable := false - var err error - err = kv.RunInNewTxn(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), store, true, func(_ context.Context, txn kv.Transaction) error { - t := meta.NewMeta(txn) - enable, isNull, err = t.GetMetadataLock() - if err != nil { - return err - } - return nil - }) - if isNull || !enable { - variable.EnableMDL.Store(false) - } else { - variable.EnableMDL.Store(true) - } - return isNull, err -} - -// InitMDLVariable initializes the metadata lock variable. -func InitMDLVariable(store kv.Storage) error { - isNull := false - enable := false - var err error - err = kv.RunInNewTxn(kv.WithInternalSourceType(context.Background(), kv.InternalTxnDDL), store, true, func(_ context.Context, txn kv.Transaction) error { - t := meta.NewMeta(txn) - enable, isNull, err = t.GetMetadataLock() - if err != nil { - return err - } - if isNull { - // Workaround for version: nightly-2022-11-07 to nightly-2022-11-17. - enable = true - logutil.BgLogger().Warn("metadata lock is null") - err = t.SetMetadataLock(true) - if err != nil { - return err - } - } - return nil - }) - variable.EnableMDL.Store(enable) - return err -} - -// BootstrapSession bootstrap session and domain. 
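// InitDDLJobTables and InitMDLTable above follow the same shape: inside one transaction,
// read a stored table version, return early if it is already at or past the target,
// otherwise create the tables and persist the new version, so repeated calls from any
// server are idempotent. A sketch of that version-gated initialization with the storage
// abstracted away (illustrative only, not the meta/kv API):
package main

import "fmt"

// versionedStore abstracts the small amount of meta state the pattern needs.
type versionedStore interface {
	TableVersion() (int, error)
	SetTableVersion(v int) error
	CreateTables(ddl []string) error
}

// initSystemTables creates the given tables only if the stored version is behind
// targetVer, then records targetVer, making repeated calls no-ops.
func initSystemTables(s versionedStore, targetVer int, ddl []string) error {
	cur, err := s.TableVersion()
	if err != nil {
		return err
	}
	if cur >= targetVer {
		return nil // already initialized by this or another server
	}
	if err := s.CreateTables(ddl); err != nil {
		return err
	}
	return s.SetTableVersion(targetVer)
}

// memStore is a toy in-memory implementation for the example.
type memStore struct {
	ver    int
	tables []string
}

func (m *memStore) TableVersion() (int, error)    { return m.ver, nil }
func (m *memStore) SetTableVersion(v int) error   { m.ver = v; return nil }
func (m *memStore) CreateTables(d []string) error { m.tables = append(m.tables, d...); return nil }

func main() {
	s := &memStore{}
	_ = initSystemTables(s, 1, []string{"create table tidb_ddl_job(...)"})
	_ = initSystemTables(s, 1, []string{"create table tidb_ddl_job(...)"}) // second call is a no-op
	fmt.Println(s.ver, len(s.tables))                                      // 1 1
}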
-func BootstrapSession(store kv.Storage) (*domain.Domain, error) { - return bootstrapSessionImpl(store, createSessions) -} - -// BootstrapSession4DistExecution bootstrap session and dom for Distributed execution test, only for unit testing. -func BootstrapSession4DistExecution(store kv.Storage) (*domain.Domain, error) { - return bootstrapSessionImpl(store, createSessions4DistExecution) -} - -// bootstrapSessionImpl bootstraps session and domain. -// the process works as follows: -// - if we haven't bootstrapped to the target version -// - create/init/start domain -// - bootstrap or upgrade, some variables will be initialized and stored to system -// table in the process, such as system time-zone -// - close domain -// -// - create/init another domain -// - initialization global variables from system table that's required to use sessionCtx, -// such as system time zone -// - start domain and other routines. -func bootstrapSessionImpl(store kv.Storage, createSessionsImpl func(store kv.Storage, cnt int) ([]*session, error)) (*domain.Domain, error) { - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnBootstrap) - cfg := config.GetGlobalConfig() - if len(cfg.Instance.PluginLoad) > 0 { - err := plugin.Load(context.Background(), plugin.Config{ - Plugins: strings.Split(cfg.Instance.PluginLoad, ","), - PluginDir: cfg.Instance.PluginDir, - }) - if err != nil { - return nil, err - } - } - err := InitDDLJobTables(store, meta.BaseDDLTableVersion) - if err != nil { - return nil, err - } - err = InitMDLTable(store) - if err != nil { - return nil, err - } - err = InitDDLJobTables(store, meta.BackfillTableVersion) - if err != nil { - return nil, err - } - err = InitTiDBSchemaCacheSize(store) - if err != nil { - return nil, err - } - ver := getStoreBootstrapVersion(store) - if ver == notBootstrapped { - runInBootstrapSession(store, bootstrap) - } else if ver < currentBootstrapVersion { - runInBootstrapSession(store, upgrade) - } else { - err = InitMDLVariable(store) - if err != nil { - return nil, err - } - } - - // initiate disttask framework components which need a store - scheduler.RegisterSchedulerFactory( - proto.ImportInto, - func(ctx context.Context, task *proto.Task, param scheduler.Param) scheduler.Scheduler { - return importinto.NewImportScheduler(ctx, task, param, store.(kv.StorageWithPD)) - }, - ) - taskexecutor.RegisterTaskType( - proto.ImportInto, - func(ctx context.Context, id string, task *proto.Task, table taskexecutor.TaskTable) taskexecutor.TaskExecutor { - return importinto.NewImportExecutor(ctx, id, task, table, store) - }, - ) - - analyzeConcurrencyQuota := int(config.GetGlobalConfig().Performance.AnalyzePartitionConcurrencyQuota) - concurrency := config.GetGlobalConfig().Performance.StatsLoadConcurrency - if concurrency == 0 { - // if concurrency is 0, we will set the concurrency of sync load by CPU. - concurrency = syncload.GetSyncLoadConcurrencyByCPU() - } - if concurrency < 0 { // it is only for test, in the production, negative value is illegal. - concurrency = 0 - } - - ses, err := createSessionsImpl(store, 10) - if err != nil { - return nil, err - } - ses[0].GetSessionVars().InRestrictedSQL = true - - // get system tz from mysql.tidb - tz, err := ses[0].getTableValue(ctx, mysql.TiDBTable, tidbSystemTZ) - if err != nil { - return nil, err - } - timeutil.SetSystemTZ(tz) - - // get the flag from `mysql`.`tidb` which indicating if new collations are enabled. 
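// The version check above is the whole bootstrap decision: a fresh store gets a full
// bootstrap, a store behind the current version gets an upgrade run, and an up-to-date
// store only needs its variables loaded. A compact sketch of that three-way gate; the
// constants and callbacks are illustrative, not the real bootstrap entry points.
package main

import "fmt"

const (
	notBootstrapped  = 0
	currentBootstrap = 7 // stand-in for currentBootstrapVersion
)

func decide(ver int64, bootstrap, upgrade, loadOnly func()) {
	switch {
	case ver == notBootstrapped:
		bootstrap()
	case ver < currentBootstrap:
		upgrade()
	default:
		loadOnly()
	}
}

func main() {
	b := func() { fmt.Println("bootstrap") }
	u := func() { fmt.Println("upgrade") }
	l := func() { fmt.Println("load vars") }
	decide(0, b, u, l) // bootstrap
	decide(3, b, u, l) // upgrade
	decide(7, b, u, l) // load vars
}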
- newCollationEnabled, err := loadCollationParameter(ctx, ses[0]) - if err != nil { - return nil, err - } - collate.SetNewCollationEnabledForTest(newCollationEnabled) - - // only start the domain after we have initialized some global variables. - dom := domain.GetDomain(ses[0]) - err = dom.Start() - if err != nil { - return nil, err - } - - // To deal with the location partition failure caused by inconsistent NewCollationEnabled values(see issue #32416). - rebuildAllPartitionValueMapAndSorted(ses[0]) - - // We should make the load bind-info loop before other loops which has internal SQL. - // Because the internal SQL may access the global bind-info handler. As the result, the data race occurs here as the - // LoadBindInfoLoop inits global bind-info handler. - err = dom.LoadBindInfoLoop(ses[1], ses[2]) - if err != nil { - return nil, err - } - - if !config.GetGlobalConfig().Security.SkipGrantTable { - err = dom.LoadPrivilegeLoop(ses[3]) - if err != nil { - return nil, err - } - } - - // Rebuild sysvar cache in a loop - err = dom.LoadSysVarCacheLoop(ses[4]) - if err != nil { - return nil, err - } - - if config.GetGlobalConfig().DisaggregatedTiFlash && !config.GetGlobalConfig().UseAutoScaler { - // Invalid client-go tiflash_compute store cache if necessary. - err = dom.WatchTiFlashComputeNodeChange() - if err != nil { - return nil, err - } - } - - if err = extensionimpl.Bootstrap(context.Background(), dom); err != nil { - return nil, err - } - - if len(cfg.Instance.PluginLoad) > 0 { - err := plugin.Init(context.Background(), plugin.Config{EtcdClient: dom.GetEtcdClient()}) - if err != nil { - return nil, err - } - } - - err = executor.LoadExprPushdownBlacklist(ses[5]) - if err != nil { - return nil, err - } - err = executor.LoadOptRuleBlacklist(ctx, ses[5]) - if err != nil { - return nil, err - } - - planReplayerWorkerCnt := config.GetGlobalConfig().Performance.PlanReplayerDumpWorkerConcurrency - planReplayerWorkersSctx := make([]sessionctx.Context, planReplayerWorkerCnt) - pworkerSes, err := createSessions(store, int(planReplayerWorkerCnt)) - if err != nil { - return nil, err - } - for i := 0; i < int(planReplayerWorkerCnt); i++ { - planReplayerWorkersSctx[i] = pworkerSes[i] - } - // setup plan replayer handle - dom.SetupPlanReplayerHandle(ses[6], planReplayerWorkersSctx) - dom.StartPlanReplayerHandle() - // setup dumpFileGcChecker - dom.SetupDumpFileGCChecker(ses[7]) - dom.DumpFileGcCheckerLoop() - // setup historical stats worker - dom.SetupHistoricalStatsWorker(ses[8]) - dom.StartHistoricalStatsWorker() - failToLoadOrParseSQLFile := false // only used for unit test - if runBootstrapSQLFile { - pm := &privileges.UserPrivileges{ - Handle: dom.PrivilegeHandle(), - } - privilege.BindPrivilegeManager(ses[9], pm) - if err := doBootstrapSQLFile(ses[9]); err != nil && intest.InTest { - failToLoadOrParseSQLFile = true - } - } - // A sub context for update table stats, and other contexts for concurrent stats loading. 
- cnt := 1 + concurrency - syncStatsCtxs, err := createSessions(store, cnt) - if err != nil { - return nil, err - } - subCtxs := make([]sessionctx.Context, cnt) - for i := 0; i < cnt; i++ { - subCtxs[i] = sessionctx.Context(syncStatsCtxs[i]) - } - - // setup extract Handle - extractWorkers := 1 - sctxs, err := createSessions(store, extractWorkers) - if err != nil { - return nil, err - } - extractWorkerSctxs := make([]sessionctx.Context, 0) - for _, sctx := range sctxs { - extractWorkerSctxs = append(extractWorkerSctxs, sctx) - } - dom.SetupExtractHandle(extractWorkerSctxs) - - // setup init stats loader - initStatsCtx, err := createSession(store) - if err != nil { - return nil, err - } - if err = dom.LoadAndUpdateStatsLoop(subCtxs, initStatsCtx); err != nil { - return nil, err - } - - // init the instance plan cache - dom.InitInstancePlanCache() - - // start TTL job manager after setup stats collector - // because TTL could modify a lot of columns, and need to trigger auto analyze - ttlworker.AttachStatsCollector = func(s sqlexec.SQLExecutor) sqlexec.SQLExecutor { - if s, ok := s.(*session); ok { - return attachStatsCollector(s, dom) - } - return s - } - ttlworker.DetachStatsCollector = func(s sqlexec.SQLExecutor) sqlexec.SQLExecutor { - if s, ok := s.(*session); ok { - return detachStatsCollector(s) - } - return s - } - dom.StartTTLJobManager() - - analyzeCtxs, err := createSessions(store, analyzeConcurrencyQuota) - if err != nil { - return nil, err - } - subCtxs2 := make([]sessionctx.Context, analyzeConcurrencyQuota) - for i := 0; i < analyzeConcurrencyQuota; i++ { - subCtxs2[i] = analyzeCtxs[i] - } - dom.SetupAnalyzeExec(subCtxs2) - dom.LoadSigningCertLoop(cfg.Security.SessionTokenSigningCert, cfg.Security.SessionTokenSigningKey) - - if raw, ok := store.(kv.EtcdBackend); ok { - err = raw.StartGCWorker() - if err != nil { - return nil, err - } - } - - // This only happens in testing, since the failure of loading or parsing sql file - // would panic the bootstrapping. - if intest.InTest && failToLoadOrParseSQLFile { - dom.Close() - return nil, errors.New("Fail to load or parse sql file") - } - err = dom.InitDistTaskLoop() - if err != nil { - return nil, err - } - return dom, err -} - -// GetDomain gets the associated domain for store. -func GetDomain(store kv.Storage) (*domain.Domain, error) { - return domap.Get(store) -} - -// runInBootstrapSession create a special session for bootstrap to run. -// If no bootstrap and storage is remote, we must use a little lease time to -// bootstrap quickly, after bootstrapped, we will reset the lease time. -// TODO: Using a bootstrap tool for doing this may be better later. -func runInBootstrapSession(store kv.Storage, bootstrap func(types.Session)) { - s, err := createSession(store) - if err != nil { - // Bootstrap fail will cause program exit. - logutil.BgLogger().Fatal("createSession error", zap.Error(err)) - } - dom := domain.GetDomain(s) - err = dom.Start() - if err != nil { - // Bootstrap fail will cause program exit. - logutil.BgLogger().Fatal("start domain error", zap.Error(err)) - } - - // For the bootstrap SQLs, the following variables should be compatible with old TiDB versions. 
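// The ttlworker.AttachStatsCollector / DetachStatsCollector assignments above install
// behavior into another package through exported function variables, which is how the
// bootstrap code injects session-specific logic without an import cycle. A standalone
// sketch of that hook-variable pattern; the package layout is collapsed into one file
// purely for illustration.
package main

import "fmt"

// AttachCollector is a hook with a safe no-op default; callers always go through it.
// In the real code the hook and its override live in different packages.
var AttachCollector = func(name string) string { return name }

func runWorker(name string) {
	// The worker never imports the owning package; it only calls the hook.
	fmt.Println("worker using:", AttachCollector(name))
}

func main() {
	runWorker("plain")

	// At bootstrap time, the owning package overrides the hook with real behavior.
	AttachCollector = func(name string) string { return name + "+stats-collector" }
	runWorker("plain")
}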
- s.sessionVars.EnableClusteredIndex = variable.ClusteredIndexDefModeIntOnly - - s.SetValue(sessionctx.Initing, true) - bootstrap(s) - finishBootstrap(store) - s.ClearValue(sessionctx.Initing) - - dom.Close() - if intest.InTest { - infosync.MockGlobalServerInfoManagerEntry.Close() - } - domap.Delete(store) -} - -func createSessions(store kv.Storage, cnt int) ([]*session, error) { - return createSessionsImpl(store, cnt, createSession) -} - -func createSessions4DistExecution(store kv.Storage, cnt int) ([]*session, error) { - domap.Delete(store) - - return createSessionsImpl(store, cnt, createSession4DistExecution) -} - -func createSessionsImpl(store kv.Storage, cnt int, createSessionImpl func(kv.Storage) (*session, error)) ([]*session, error) { - // Then we can create new dom - ses := make([]*session, cnt) - for i := 0; i < cnt; i++ { - se, err := createSessionImpl(store) - if err != nil { - return nil, err - } - ses[i] = se - } - - return ses, nil -} - -// createSession creates a new session. -// Please note that such a session is not tracked by the internal session list. -// This means the min ts reporter is not aware of it and may report a wrong min start ts. -// In most cases you should use a session pool in domain instead. -func createSession(store kv.Storage) (*session, error) { - return createSessionWithOpt(store, nil) -} - -func createSession4DistExecution(store kv.Storage) (*session, error) { - return createSessionWithOpt(store, nil) -} - -func createSessionWithOpt(store kv.Storage, opt *Opt) (*session, error) { - dom, err := domap.Get(store) - if err != nil { - return nil, err - } - s := &session{ - store: store, - ddlOwnerManager: dom.DDL().OwnerManager(), - client: store.GetClient(), - mppClient: store.GetMPPClient(), - stmtStats: stmtstats.CreateStatementStats(), - sessionStatesHandlers: make(map[sessionstates.SessionStateType]sessionctx.SessionStatesHandler), - } - s.sessionVars = variable.NewSessionVars(s) - s.exprctx = contextsession.NewSessionExprContext(s) - s.pctx = newPlanContextImpl(s) - s.tblctx = tbctximpl.NewTableContextImpl(s) - - if opt != nil && opt.PreparedPlanCache != nil { - s.sessionPlanCache = opt.PreparedPlanCache - } - s.mu.values = make(map[fmt.Stringer]any) - s.lockedTables = make(map[int64]model.TableLockTpInfo) - s.advisoryLocks = make(map[string]*advisoryLock) - - domain.BindDomain(s, dom) - // session implements variable.GlobalVarAccessor. Bind it to ctx. 
- s.sessionVars.GlobalVarsAccessor = s - s.sessionVars.BinlogClient = binloginfo.GetPumpsClient() - s.txn.init() - - sessionBindHandle := bindinfo.NewSessionBindingHandle() - s.SetValue(bindinfo.SessionBindInfoKeyType, sessionBindHandle) - s.SetSessionStatesHandler(sessionstates.StateBinding, sessionBindHandle) - return s, nil -} - -// attachStatsCollector attaches the stats collector in the dom for the session -func attachStatsCollector(s *session, dom *domain.Domain) *session { - if dom.StatsHandle() != nil && dom.StatsUpdating() { - if s.statsCollector == nil { - s.statsCollector = dom.StatsHandle().NewSessionStatsItem().(*usage.SessionStatsItem) - } - if s.idxUsageCollector == nil && config.GetGlobalConfig().Instance.EnableCollectExecutionInfo.Load() { - s.idxUsageCollector = dom.StatsHandle().NewSessionIndexUsageCollector() - } - } - - return s -} - -// detachStatsCollector removes the stats collector in the session -func detachStatsCollector(s *session) *session { - if s.statsCollector != nil { - s.statsCollector.Delete() - s.statsCollector = nil - } - if s.idxUsageCollector != nil { - s.idxUsageCollector.Flush() - s.idxUsageCollector = nil - } - return s -} - -// CreateSessionWithDomain creates a new Session and binds it with a Domain. -// We need this because when we start DDL in Domain, the DDL need a session -// to change some system tables. But at that time, we have been already in -// a lock context, which cause we can't call createSession directly. -func CreateSessionWithDomain(store kv.Storage, dom *domain.Domain) (*session, error) { - s := &session{ - store: store, - sessionVars: variable.NewSessionVars(nil), - client: store.GetClient(), - mppClient: store.GetMPPClient(), - stmtStats: stmtstats.CreateStatementStats(), - sessionStatesHandlers: make(map[sessionstates.SessionStateType]sessionctx.SessionStatesHandler), - } - s.exprctx = contextsession.NewSessionExprContext(s) - s.pctx = newPlanContextImpl(s) - s.tblctx = tbctximpl.NewTableContextImpl(s) - s.mu.values = make(map[fmt.Stringer]any) - s.lockedTables = make(map[int64]model.TableLockTpInfo) - domain.BindDomain(s, dom) - // session implements variable.GlobalVarAccessor. Bind it to ctx. 
- s.sessionVars.GlobalVarsAccessor = s - s.txn.init() - return s, nil -} - -const ( - notBootstrapped = 0 -) - -func getStoreBootstrapVersion(store kv.Storage) int64 { - storeBootstrappedLock.Lock() - defer storeBootstrappedLock.Unlock() - // check in memory - _, ok := storeBootstrapped[store.UUID()] - if ok { - return currentBootstrapVersion - } - - var ver int64 - // check in kv store - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnBootstrap) - err := kv.RunInNewTxn(ctx, store, false, func(_ context.Context, txn kv.Transaction) error { - var err error - t := meta.NewMeta(txn) - ver, err = t.GetBootstrapVersion() - return err - }) - if err != nil { - logutil.BgLogger().Fatal("check bootstrapped failed", - zap.Error(err)) - } - - if ver > notBootstrapped { - // here mean memory is not ok, but other server has already finished it - storeBootstrapped[store.UUID()] = true - } - - modifyBootstrapVersionForTest(ver) - return ver -} - -func finishBootstrap(store kv.Storage) { - setStoreBootstrapped(store.UUID()) - - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnBootstrap) - err := kv.RunInNewTxn(ctx, store, true, func(_ context.Context, txn kv.Transaction) error { - t := meta.NewMeta(txn) - err := t.FinishBootstrap(currentBootstrapVersion) - return err - }) - if err != nil { - logutil.BgLogger().Fatal("finish bootstrap failed", - zap.Error(err)) - } -} - -const quoteCommaQuote = "', '" - -// loadCommonGlobalVariablesIfNeeded loads and applies commonly used global variables for the session. -func (s *session) loadCommonGlobalVariablesIfNeeded() error { - vars := s.sessionVars - if vars.CommonGlobalLoaded { - return nil - } - if s.Value(sessionctx.Initing) != nil { - // When running bootstrap or upgrade, we should not access global storage. - // But we need to init max_allowed_packet to use concat function during bootstrap or upgrade. - err := vars.SetSystemVar(variable.MaxAllowedPacket, strconv.FormatUint(variable.DefMaxAllowedPacket, 10)) - if err != nil { - logutil.BgLogger().Error("set system variable max_allowed_packet error", zap.Error(err)) - } - return nil - } - - vars.CommonGlobalLoaded = true - - // Deep copy sessionvar cache - sessionCache, err := domain.GetDomain(s).GetSessionCache() - if err != nil { - return err - } - for varName, varVal := range sessionCache { - if _, ok := vars.GetSystemVar(varName); !ok { - err = vars.SetSystemVarWithRelaxedValidation(varName, varVal) - if err != nil { - if variable.ErrUnknownSystemVar.Equal(err) { - continue // sessionCache is stale; sysvar has likely been unregistered - } - return err - } - } - } - // when client set Capability Flags CLIENT_INTERACTIVE, init wait_timeout with interactive_timeout - if vars.ClientCapability&mysql.ClientInteractive > 0 { - if varVal, ok := vars.GetSystemVar(variable.InteractiveTimeout); ok { - if err := vars.SetSystemVar(variable.WaitTimeout, varVal); err != nil { - return err - } - } - } - return nil -} - -// PrepareTxnCtx begins a transaction, and creates a new transaction context. -// It is called before we execute a sql query. -func (s *session) PrepareTxnCtx(ctx context.Context) error { - s.currentCtx = ctx - if s.txn.validOrPending() { - return nil - } - - txnMode := ast.Optimistic - if !s.sessionVars.IsAutocommit() || (config.GetGlobalConfig().PessimisticTxn. 
- PessimisticAutoCommit.Load() && !s.GetSessionVars().BulkDMLEnabled) { - if s.sessionVars.TxnMode == ast.Pessimistic { - txnMode = ast.Pessimistic - } - } - - if s.sessionVars.RetryInfo.Retrying { - txnMode = ast.Pessimistic - } - - return sessiontxn.GetTxnManager(s).EnterNewTxn(ctx, &sessiontxn.EnterNewTxnRequest{ - Type: sessiontxn.EnterNewTxnBeforeStmt, - TxnMode: txnMode, - }) -} - -// PrepareTSFuture uses to try to get ts future. -func (s *session) PrepareTSFuture(ctx context.Context, future oracle.Future, scope string) error { - if s.txn.Valid() { - return errors.New("cannot prepare ts future when txn is valid") - } - - failpoint.Inject("assertTSONotRequest", func() { - if _, ok := future.(sessiontxn.ConstantFuture); !ok && !s.isInternal() { - panic("tso shouldn't be requested") - } - }) - - failpoint.InjectContext(ctx, "mockGetTSFail", func() { - future = txnFailFuture{} - }) - - s.txn.changeToPending(&txnFuture{ - future: future, - store: s.store, - txnScope: scope, - pipelined: s.usePipelinedDmlOrWarn(ctx), - }) - return nil -} - -// GetPreparedTxnFuture returns the TxnFuture if it is valid or pending. -// It returns nil otherwise. -func (s *session) GetPreparedTxnFuture() sessionctx.TxnFuture { - if !s.txn.validOrPending() { - return nil - } - return &s.txn -} - -// RefreshTxnCtx implements context.RefreshTxnCtx interface. -func (s *session) RefreshTxnCtx(ctx context.Context) error { - var commitDetail *tikvutil.CommitDetails - ctx = context.WithValue(ctx, tikvutil.CommitDetailCtxKey, &commitDetail) - err := s.doCommit(ctx) - if commitDetail != nil { - s.GetSessionVars().StmtCtx.MergeExecDetails(nil, commitDetail) - } - if err != nil { - return err - } - - s.updateStatsDeltaToCollector() - - return sessiontxn.NewTxn(ctx, s) -} - -// GetStore gets the store of session. -func (s *session) GetStore() kv.Storage { - return s.store -} - -func (s *session) ShowProcess() *util.ProcessInfo { - return s.processInfo.Load() -} - -// GetStartTSFromSession returns the startTS in the session `se` -func GetStartTSFromSession(se any) (startTS, processInfoID uint64) { - tmp, ok := se.(*session) - if !ok { - logutil.BgLogger().Error("GetStartTSFromSession failed, can't transform to session struct") - return 0, 0 - } - txnInfo := tmp.TxnInfo() - if txnInfo != nil { - startTS = txnInfo.StartTS - processInfoID = txnInfo.ConnectionID - } - - logutil.BgLogger().Debug( - "GetStartTSFromSession getting startTS of internal session", - zap.Uint64("startTS", startTS), zap.Time("start time", oracle.GetTimeFromTS(startTS))) - - return startTS, processInfoID -} - -// logStmt logs some crucial SQL including: CREATE USER/GRANT PRIVILEGE/CHANGE PASSWORD/DDL etc and normal SQL -// if variable.ProcessGeneralLog is set. 
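// PrepareTxnCtx above picks the mode for the next transaction: it defaults to optimistic,
// honors the session's pessimistic TxnMode when the statement is not a plain auto-commit
// (or when pessimistic auto-commit is configured and bulk DML is off), and always retries
// in pessimistic mode. A condensed sketch of that decision, assuming plain boolean inputs
// for the config and session state (a simplification, not the sessiontxn API):
package main

import "fmt"

func chooseTxnMode(autocommit, pessimisticAutoCommit, bulkDML, sessionPessimistic, retrying bool) string {
	mode := "optimistic"
	if !autocommit || (pessimisticAutoCommit && !bulkDML) {
		if sessionPessimistic {
			mode = "pessimistic"
		}
	}
	if retrying {
		// Statement retry always runs pessimistically.
		mode = "pessimistic"
	}
	return mode
}

func main() {
	fmt.Println(chooseTxnMode(true, false, false, true, false))  // optimistic: plain auto-commit
	fmt.Println(chooseTxnMode(false, false, false, true, false)) // pessimistic: explicit txn, session mode pessimistic
	fmt.Println(chooseTxnMode(true, false, false, false, true))  // pessimistic: retrying
}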
-func logStmt(execStmt *executor.ExecStmt, s *session) { - vars := s.GetSessionVars() - isCrucial := false - switch stmt := execStmt.StmtNode.(type) { - case *ast.DropIndexStmt: - isCrucial = true - if stmt.IsHypo { - isCrucial = false - } - case *ast.CreateIndexStmt: - isCrucial = true - if stmt.IndexOption != nil && stmt.IndexOption.Tp == model.IndexTypeHypo { - isCrucial = false - } - case *ast.CreateUserStmt, *ast.DropUserStmt, *ast.AlterUserStmt, *ast.SetPwdStmt, *ast.GrantStmt, - *ast.RevokeStmt, *ast.AlterTableStmt, *ast.CreateDatabaseStmt, *ast.CreateTableStmt, - *ast.DropDatabaseStmt, *ast.DropTableStmt, *ast.RenameTableStmt, *ast.TruncateTableStmt, - *ast.RenameUserStmt: - isCrucial = true - } - - if isCrucial { - user := vars.User - schemaVersion := s.GetInfoSchema().SchemaMetaVersion() - if ss, ok := execStmt.StmtNode.(ast.SensitiveStmtNode); ok { - logutil.BgLogger().Info("CRUCIAL OPERATION", - zap.Uint64("conn", vars.ConnectionID), - zap.Int64("schemaVersion", schemaVersion), - zap.String("secure text", ss.SecureText()), - zap.Stringer("user", user)) - } else { - logutil.BgLogger().Info("CRUCIAL OPERATION", - zap.Uint64("conn", vars.ConnectionID), - zap.Int64("schemaVersion", schemaVersion), - zap.String("cur_db", vars.CurrentDB), - zap.String("sql", execStmt.StmtNode.Text()), - zap.Stringer("user", user)) - } - } else { - logGeneralQuery(execStmt, s, false) - } -} - -func logGeneralQuery(execStmt *executor.ExecStmt, s *session, isPrepared bool) { - vars := s.GetSessionVars() - if variable.ProcessGeneralLog.Load() && !vars.InRestrictedSQL { - var query string - if isPrepared { - query = execStmt.OriginText() - } else { - query = execStmt.GetTextToLog(false) - } - - query = executor.QueryReplacer.Replace(query) - if vars.EnableRedactLog != errors.RedactLogEnable { - query += redact.String(vars.EnableRedactLog, vars.PlanCacheParams.String()) - } - logutil.GeneralLogger.Info("GENERAL_LOG", - zap.Uint64("conn", vars.ConnectionID), - zap.String("session_alias", vars.SessionAlias), - zap.String("user", vars.User.LoginString()), - zap.Int64("schemaVersion", s.GetInfoSchema().SchemaMetaVersion()), - zap.Uint64("txnStartTS", vars.TxnCtx.StartTS), - zap.Uint64("forUpdateTS", vars.TxnCtx.GetForUpdateTS()), - zap.Bool("isReadConsistency", vars.IsIsolation(ast.ReadCommitted)), - zap.String("currentDB", vars.CurrentDB), - zap.Bool("isPessimistic", vars.TxnCtx.IsPessimistic), - zap.String("sessionTxnMode", vars.GetReadableTxnMode()), - zap.String("sql", query)) - } -} - -func (s *session) recordOnTransactionExecution(err error, counter int, duration float64, isInternal bool) { - if s.sessionVars.TxnCtx.IsPessimistic { - if err != nil { - if isInternal { - session_metrics.TransactionDurationPessimisticAbortInternal.Observe(duration) - session_metrics.StatementPerTransactionPessimisticErrorInternal.Observe(float64(counter)) - } else { - session_metrics.TransactionDurationPessimisticAbortGeneral.Observe(duration) - session_metrics.StatementPerTransactionPessimisticErrorGeneral.Observe(float64(counter)) - } - } else { - if isInternal { - session_metrics.TransactionDurationPessimisticCommitInternal.Observe(duration) - session_metrics.StatementPerTransactionPessimisticOKInternal.Observe(float64(counter)) - } else { - session_metrics.TransactionDurationPessimisticCommitGeneral.Observe(duration) - session_metrics.StatementPerTransactionPessimisticOKGeneral.Observe(float64(counter)) - } - } - } else { - if err != nil { - if isInternal { - 
session_metrics.TransactionDurationOptimisticAbortInternal.Observe(duration) - session_metrics.StatementPerTransactionOptimisticErrorInternal.Observe(float64(counter)) - } else { - session_metrics.TransactionDurationOptimisticAbortGeneral.Observe(duration) - session_metrics.StatementPerTransactionOptimisticErrorGeneral.Observe(float64(counter)) - } - } else { - if isInternal { - session_metrics.TransactionDurationOptimisticCommitInternal.Observe(duration) - session_metrics.StatementPerTransactionOptimisticOKInternal.Observe(float64(counter)) - } else { - session_metrics.TransactionDurationOptimisticCommitGeneral.Observe(duration) - session_metrics.StatementPerTransactionOptimisticOKGeneral.Observe(float64(counter)) - } - } - } -} - -func (s *session) checkPlacementPolicyBeforeCommit() error { - var err error - // Get the txnScope of the transaction we're going to commit. - txnScope := s.GetSessionVars().TxnCtx.TxnScope - if txnScope == "" { - txnScope = kv.GlobalTxnScope - } - if txnScope != kv.GlobalTxnScope { - is := s.GetInfoSchema().(infoschema.InfoSchema) - deltaMap := s.GetSessionVars().TxnCtx.TableDeltaMap - for physicalTableID := range deltaMap { - var tableName string - var partitionName string - tblInfo, _, partInfo := is.FindTableByPartitionID(physicalTableID) - if tblInfo != nil && partInfo != nil { - tableName = tblInfo.Meta().Name.String() - partitionName = partInfo.Name.String() - } else { - tblInfo, _ := is.TableByID(physicalTableID) - tableName = tblInfo.Meta().Name.String() - } - bundle, ok := is.PlacementBundleByPhysicalTableID(physicalTableID) - if !ok { - errMsg := fmt.Sprintf("table %v doesn't have placement policies with txn_scope %v", - tableName, txnScope) - if len(partitionName) > 0 { - errMsg = fmt.Sprintf("table %v's partition %v doesn't have placement policies with txn_scope %v", - tableName, partitionName, txnScope) - } - err = dbterror.ErrInvalidPlacementPolicyCheck.GenWithStackByArgs(errMsg) - break - } - dcLocation, ok := bundle.GetLeaderDC(placement.DCLabelKey) - if !ok { - errMsg := fmt.Sprintf("table %v's leader placement policy is not defined", tableName) - if len(partitionName) > 0 { - errMsg = fmt.Sprintf("table %v's partition %v's leader placement policy is not defined", tableName, partitionName) - } - err = dbterror.ErrInvalidPlacementPolicyCheck.GenWithStackByArgs(errMsg) - break - } - if dcLocation != txnScope { - errMsg := fmt.Sprintf("table %v's leader location %v is out of txn_scope %v", tableName, dcLocation, txnScope) - if len(partitionName) > 0 { - errMsg = fmt.Sprintf("table %v's partition %v's leader location %v is out of txn_scope %v", - tableName, partitionName, dcLocation, txnScope) - } - err = dbterror.ErrInvalidPlacementPolicyCheck.GenWithStackByArgs(errMsg) - break - } - // FIXME: currently we assume the physicalTableID is the partition ID. In future, we should consider the situation - // if the physicalTableID belongs to a Table. 
- partitionID := physicalTableID - tbl, _, partitionDefInfo := is.FindTableByPartitionID(partitionID) - if tbl != nil { - tblInfo := tbl.Meta() - state := tblInfo.Partition.GetStateByID(partitionID) - if state == model.StateGlobalTxnOnly { - err = dbterror.ErrInvalidPlacementPolicyCheck.GenWithStackByArgs( - fmt.Sprintf("partition %s of table %s can not be written by local transactions when its placement policy is being altered", - tblInfo.Name, partitionDefInfo.Name)) - break - } - } - } - } - return err -} - -func (s *session) SetPort(port string) { - s.sessionVars.Port = port -} - -// GetTxnWriteThroughputSLI implements the Context interface. -func (s *session) GetTxnWriteThroughputSLI() *sli.TxnWriteThroughputSLI { - return &s.txn.writeSLI -} - -// GetInfoSchema returns snapshotInfoSchema if snapshot schema is set. -// Transaction infoschema is returned if inside an explicit txn. -// Otherwise the latest infoschema is returned. -func (s *session) GetInfoSchema() infoschemactx.MetaOnlyInfoSchema { - vars := s.GetSessionVars() - var is infoschema.InfoSchema - if snap, ok := vars.SnapshotInfoschema.(infoschema.InfoSchema); ok { - logutil.BgLogger().Info("use snapshot schema", zap.Uint64("conn", vars.ConnectionID), zap.Int64("schemaVersion", snap.SchemaMetaVersion())) - is = snap - } else { - vars.TxnCtxMu.Lock() - if vars.TxnCtx != nil { - if tmp, ok := vars.TxnCtx.InfoSchema.(infoschema.InfoSchema); ok { - is = tmp - } - } - vars.TxnCtxMu.Unlock() - } - - if is == nil { - is = domain.GetDomain(s).InfoSchema() - } - - // Override the infoschema if the session has temporary table. - return temptable.AttachLocalTemporaryTableInfoSchema(s, is) -} - -func (s *session) GetDomainInfoSchema() infoschemactx.MetaOnlyInfoSchema { - is := domain.GetDomain(s).InfoSchema() - extIs := &infoschema.SessionExtendedInfoSchema{InfoSchema: is} - return temptable.AttachLocalTemporaryTableInfoSchema(s, extIs) -} - -func getSnapshotInfoSchema(s sessionctx.Context, snapshotTS uint64) (infoschema.InfoSchema, error) { - is, err := domain.GetDomain(s).GetSnapshotInfoSchema(snapshotTS) - if err != nil { - return nil, err - } - // Set snapshot does not affect the witness of the local temporary table. - // The session always see the latest temporary tables. - return temptable.AttachLocalTemporaryTableInfoSchema(s, is), nil -} - -func (s *session) GetStmtStats() *stmtstats.StatementStats { - return s.stmtStats -} - -// SetMemoryFootprintChangeHook sets the hook that is called when the memdb changes its size. -// Call this after s.txn becomes valid, since TxnInfo is initialized when the txn becomes valid. -func (s *session) SetMemoryFootprintChangeHook() { - if s.txn.MemHookSet() { - return - } - if config.GetGlobalConfig().Performance.TxnTotalSizeLimit != config.DefTxnTotalSizeLimit { - // if the user manually specifies the config, don't involve the new memory tracker mechanism, let the old config - // work as before. - return - } - hook := func(mem uint64) { - if s.sessionVars.MemDBFootprint == nil { - tracker := memory.NewTracker(memory.LabelForMemDB, -1) - tracker.AttachTo(s.sessionVars.MemTracker) - s.sessionVars.MemDBFootprint = tracker - } - s.sessionVars.MemDBFootprint.ReplaceBytesUsed(int64(mem)) - } - s.txn.SetMemoryFootprintChangeHook(hook) -} - -// EncodeSessionStates implements SessionStatesHandler.EncodeSessionStates interface. 
-func (s *session) EncodeSessionStates(ctx context.Context, - _ sessionctx.Context, sessionStates *sessionstates.SessionStates) error { - // Transaction status is hard to encode, so we do not support it. - s.txn.mu.Lock() - valid := s.txn.Valid() - s.txn.mu.Unlock() - if valid { - return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("session has an active transaction") - } - // Data in local temporary tables is hard to encode, so we do not support it. - // Check temporary tables here to avoid circle dependency. - if s.sessionVars.LocalTemporaryTables != nil { - localTempTables := s.sessionVars.LocalTemporaryTables.(*infoschema.SessionTables) - if localTempTables.Count() > 0 { - return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("session has local temporary tables") - } - } - // The advisory locks will be released when the session is closed. - if len(s.advisoryLocks) > 0 { - return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("session has advisory locks") - } - // The TableInfo stores session ID and server ID, so the session cannot be migrated. - if len(s.lockedTables) > 0 { - return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("session has locked tables") - } - // It's insecure to migrate sandBoxMode because users can fake it. - if s.InSandBoxMode() { - return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs("session is in sandbox mode") - } - - if err := s.sessionVars.EncodeSessionStates(ctx, sessionStates); err != nil { - return err - } - - hasRestrictVarPriv := false - checker := privilege.GetPrivilegeManager(s) - if checker == nil || checker.RequestDynamicVerification(s.sessionVars.ActiveRoles, "RESTRICTED_VARIABLES_ADMIN", false) { - hasRestrictVarPriv = true - } - // Encode session variables. We put it here instead of SessionVars to avoid cycle import. - sessionStates.SystemVars = make(map[string]string) - for _, sv := range variable.GetSysVars() { - switch { - case sv.HasNoneScope(), !sv.HasSessionScope(): - // Hidden attribute is deprecated. - // None-scoped variables cannot be modified. - // Noop variables should also be migrated even if they are noop. - continue - case sv.ReadOnly: - // Skip read-only variables here. We encode them into SessionStates manually. - continue - } - // Get all session variables because the default values may change between versions. - val, keep, err := s.sessionVars.GetSessionStatesSystemVar(sv.Name) - switch { - case err != nil: - return err - case !keep: - continue - case !hasRestrictVarPriv && sem.IsEnabled() && sem.IsInvisibleSysVar(sv.Name): - // If the variable has a global scope, it should be the same with the global one. - // Otherwise, it should be the same with the default value. - defaultVal := sv.Value - if sv.HasGlobalScope() { - // If the session value is the same with the global one, skip it. - if defaultVal, err = sv.GetGlobalFromHook(ctx, s.sessionVars); err != nil { - return err - } - } - if val != defaultVal { - // Case 1: the RESTRICTED_VARIABLES_ADMIN is revoked after setting the session variable. - // Case 2: the global variable is updated after the session is created. - // In any case, the variable can't be set in the new session, so give up. - return sessionstates.ErrCannotMigrateSession.GenWithStackByArgs(fmt.Sprintf("session has set invisible variable '%s'", sv.Name)) - } - default: - sessionStates.SystemVars[sv.Name] = val - } - } - - // Encode prepared statements and sql bindings. 
- for _, handler := range s.sessionStatesHandlers { - if err := handler.EncodeSessionStates(ctx, s, sessionStates); err != nil { - return err - } - } - return nil -} - -// DecodeSessionStates implements SessionStatesHandler.DecodeSessionStates interface. -func (s *session) DecodeSessionStates(ctx context.Context, - _ sessionctx.Context, sessionStates *sessionstates.SessionStates) error { - // Decode prepared statements and sql bindings. - for _, handler := range s.sessionStatesHandlers { - if err := handler.DecodeSessionStates(ctx, s, sessionStates); err != nil { - return err - } - } - - // Decode session variables. - names := variable.OrderByDependency(sessionStates.SystemVars) - // Some variables must be set before others, e.g. tidb_enable_noop_functions should be before noop variables. - for _, name := range names { - val := sessionStates.SystemVars[name] - // Experimental system variables may change scope, data types, or even be removed. - // We just ignore the errors and continue. - if err := s.sessionVars.SetSystemVar(name, val); err != nil { - logutil.Logger(ctx).Warn("set session variable during decoding session states error", - zap.String("name", name), zap.String("value", val), zap.Error(err)) - } - } - - // Decoding session vars / prepared statements may override stmt ctx, such as warnings, - // so we decode stmt ctx at last. - return s.sessionVars.DecodeSessionStates(ctx, sessionStates) -} - -func (s *session) setRequestSource(ctx context.Context, stmtLabel string, stmtNode ast.StmtNode) { - if !s.isInternal() { - if txn, _ := s.Txn(false); txn != nil && txn.Valid() { - if txn.IsPipelined() { - stmtLabel = "pdml" - } - txn.SetOption(kv.RequestSourceType, stmtLabel) - } - s.sessionVars.RequestSourceType = stmtLabel - return - } - if source := ctx.Value(kv.RequestSourceKey); source != nil { - requestSource := source.(kv.RequestSource) - if requestSource.RequestSourceType != "" { - s.sessionVars.RequestSourceType = requestSource.RequestSourceType - return - } - } - // panic in test mode in case there are requests without source in the future. - // log warnings in production mode. - if intest.InTest { - panic("unexpected no source type context, if you see this error, " + - "the `RequestSourceTypeKey` is missing in your context") - } - logutil.Logger(ctx).Warn("unexpected no source type context, if you see this warning, "+ - "the `RequestSourceTypeKey` is missing in the context", - zap.Bool("internal", s.isInternal()), - zap.String("sql", stmtNode.Text())) -} - -// NewStmtIndexUsageCollector creates a new `*indexusage.StmtIndexUsageCollector` based on the internal session index -// usage collector -func (s *session) NewStmtIndexUsageCollector() *indexusage.StmtIndexUsageCollector { - if s.idxUsageCollector == nil { - return nil - } - - return indexusage.NewStmtIndexUsageCollector(s.idxUsageCollector) -} - -// usePipelinedDmlOrWarn returns the current statement can be executed as a pipelined DML. -func (s *session) usePipelinedDmlOrWarn(ctx context.Context) bool { - if !s.sessionVars.BulkDMLEnabled { - return false - } - stmtCtx := s.sessionVars.StmtCtx - if stmtCtx == nil { - return false - } - if stmtCtx.IsReadOnly { - return false - } - vars := s.GetSessionVars() - if !vars.TxnCtx.EnableMDL { - stmtCtx.AppendWarning( - errors.New( - "Pipelined DML can not be used without Metadata Lock. 
Fallback to standard mode", - ), - ) - return false - } - if (vars.BatchCommit || vars.BatchInsert || vars.BatchDelete) && vars.DMLBatchSize > 0 && variable.EnableBatchDML.Load() { - stmtCtx.AppendWarning(errors.New("Pipelined DML can not be used with the deprecated Batch DML. Fallback to standard mode")) - return false - } - if vars.BinlogClient != nil { - stmtCtx.AppendWarning(errors.New("Pipelined DML can not be used with Binlog: BinlogClient != nil. Fallback to standard mode")) - return false - } - if !(stmtCtx.InInsertStmt || stmtCtx.InDeleteStmt || stmtCtx.InUpdateStmt) { - if !stmtCtx.IsReadOnly { - stmtCtx.AppendWarning(errors.New("Pipelined DML can only be used for auto-commit INSERT, REPLACE, UPDATE or DELETE. Fallback to standard mode")) - } - return false - } - if s.isInternal() { - stmtCtx.AppendWarning(errors.New("Pipelined DML can not be used for internal SQL. Fallback to standard mode")) - return false - } - if vars.InTxn() { - stmtCtx.AppendWarning(errors.New("Pipelined DML can not be used in transaction. Fallback to standard mode")) - return false - } - if !vars.IsAutocommit() { - stmtCtx.AppendWarning(errors.New("Pipelined DML can only be used in autocommit mode. Fallback to standard mode")) - return false - } - if s.GetSessionVars().ConstraintCheckInPlace { - // we enforce that pipelined DML must lazily check key. - stmtCtx.AppendWarning( - errors.New( - "Pipelined DML can not be used when tidb_constraint_check_in_place=ON. " + - "Fallback to standard mode", - ), - ) - return false - } - is, ok := s.GetDomainInfoSchema().(infoschema.InfoSchema) - if !ok { - stmtCtx.AppendWarning(errors.New("Pipelined DML failed to get latest InfoSchema. Fallback to standard mode")) - return false - } - for _, t := range stmtCtx.Tables { - // get table schema from current infoschema - tbl, err := is.TableByName(ctx, model.NewCIStr(t.DB), model.NewCIStr(t.Table)) - if err != nil { - stmtCtx.AppendWarning(errors.New("Pipelined DML failed to get table schema. Fallback to standard mode")) - return false - } - if tbl.Meta().IsView() { - stmtCtx.AppendWarning(errors.New("Pipelined DML can not be used on view. Fallback to standard mode")) - return false - } - if tbl.Meta().IsSequence() { - stmtCtx.AppendWarning(errors.New("Pipelined DML can not be used on sequence. Fallback to standard mode")) - return false - } - if vars.ForeignKeyChecks && (len(tbl.Meta().ForeignKeys) > 0 || len(is.GetTableReferredForeignKeys(t.DB, t.Table)) > 0) { - stmtCtx.AppendWarning( - errors.New( - "Pipelined DML can not be used on table with foreign keys when foreign_key_checks = ON. Fallback to standard mode", - ), - ) - return false - } - if tbl.Meta().TempTableType != model.TempTableNone { - stmtCtx.AppendWarning( - errors.New( - "Pipelined DML can not be used on temporary tables. " + - "Fallback to standard mode", - ), - ) - return false - } - if tbl.Meta().TableCacheStatusType != model.TableCacheStatusDisable { - stmtCtx.AppendWarning( - errors.New( - "Pipelined DML can not be used on cached tables. " + - "Fallback to standard mode", - ), - ) - return false - } - } - - // tidb_dml_type=bulk will invalidate the config pessimistic-auto-commit. - // The behavior is as if the config is set to false. But we generate a warning for it. 
- if config.GetGlobalConfig().PessimisticTxn.PessimisticAutoCommit.Load() { - stmtCtx.AppendWarning( - errors.New( - "pessimistic-auto-commit config is ignored in favor of Pipelined DML", - ), - ) - } - return true -} - -// RemoveLockDDLJobs removes the DDL jobs which doesn't get the metadata lock from job2ver. -func RemoveLockDDLJobs(s types.Session, job2ver map[int64]int64, job2ids map[int64]string, printLog bool) { - sv := s.GetSessionVars() - if sv.InRestrictedSQL { - return - } - sv.TxnCtxMu.Lock() - defer sv.TxnCtxMu.Unlock() - if sv.TxnCtx == nil { - return - } - sv.GetRelatedTableForMDL().Range(func(tblID, value any) bool { - for jobID, ver := range job2ver { - ids := util.Str2Int64Map(job2ids[jobID]) - if _, ok := ids[tblID.(int64)]; ok && value.(int64) < ver { - delete(job2ver, jobID) - elapsedTime := time.Since(oracle.GetTimeFromTS(sv.TxnCtx.StartTS)) - if elapsedTime > time.Minute && printLog { - logutil.BgLogger().Info("old running transaction block DDL", zap.Int64("table ID", tblID.(int64)), zap.Int64("jobID", jobID), zap.Uint64("connection ID", sv.ConnectionID), zap.Duration("elapsed time", elapsedTime)) - } else { - logutil.BgLogger().Debug("old running transaction block DDL", zap.Int64("table ID", tblID.(int64)), zap.Int64("jobID", jobID), zap.Uint64("connection ID", sv.ConnectionID), zap.Duration("elapsed time", elapsedTime)) - } - } - } - return true - }) -} - -// GetDBNames gets the sql layer database names from the session. -func GetDBNames(seVar *variable.SessionVars) []string { - dbNames := make(map[string]struct{}) - if seVar == nil || !config.GetGlobalConfig().Status.RecordDBLabel { - return []string{""} - } - if seVar.StmtCtx != nil { - for _, t := range seVar.StmtCtx.Tables { - dbNames[t.DB] = struct{}{} - } - } - if len(dbNames) == 0 { - dbNames[seVar.CurrentDB] = struct{}{} - } - ns := make([]string, 0, len(dbNames)) - for n := range dbNames { - ns = append(ns, n) - } - return ns -} - -// GetCursorTracker returns the internal `cursor.Tracker` -func (s *session) GetCursorTracker() cursor.Tracker { - return s.cursorTracker -} diff --git a/pkg/session/sync_upgrade.go b/pkg/session/sync_upgrade.go index 9fb0ae318ed93..52829793f6af7 100644 --- a/pkg/session/sync_upgrade.go +++ b/pkg/session/sync_upgrade.go @@ -81,14 +81,14 @@ func SyncUpgradeState(s sessionctx.Context, timeout time.Duration) error { // SyncNormalRunning syncs normal state to etcd. func SyncNormalRunning(s sessionctx.Context) error { bgCtx := context.Background() - if val, _err_ := failpoint.Eval(_curpkg_("mockResumeAllJobsFailed")); _err_ == nil { + failpoint.Inject("mockResumeAllJobsFailed", func(val failpoint.Value) { if val.(bool) { dom := domain.GetDomain(s) //nolint: errcheck dom.DDL().StateSyncer().UpdateGlobalState(bgCtx, syncer.NewStateInfo(syncer.StateNormalRunning)) - return nil + failpoint.Return(nil) } - } + }) logger := logutil.BgLogger().With(zap.String("category", "upgrading")) jobErrs, err := ddl.ResumeAllJobsBySystem(s) diff --git a/pkg/session/sync_upgrade.go__failpoint_stash__ b/pkg/session/sync_upgrade.go__failpoint_stash__ deleted file mode 100644 index 52829793f6af7..0000000000000 --- a/pkg/session/sync_upgrade.go__failpoint_stash__ +++ /dev/null @@ -1,163 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package session - -import ( - "context" - "time" - - "github.com/pingcap/failpoint" - "github.com/pingcap/log" - "github.com/pingcap/tidb/pkg/ddl" - "github.com/pingcap/tidb/pkg/ddl/syncer" - dist_store "github.com/pingcap/tidb/pkg/disttask/framework/storage" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/owner" - sessiontypes "github.com/pingcap/tidb/pkg/session/types" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/util/logutil" - "go.uber.org/zap" -) - -// isContextDone checks if context is done. -func isContextDone(ctx context.Context) bool { - select { - case <-ctx.Done(): - return true - default: - } - return false -} - -// SyncUpgradeState syncs upgrade state to etcd. -func SyncUpgradeState(s sessionctx.Context, timeout time.Duration) error { - ctx, cancelFunc := context.WithTimeout(context.Background(), timeout) - defer cancelFunc() - dom := domain.GetDomain(s) - err := dom.DDL().StateSyncer().UpdateGlobalState(ctx, syncer.NewStateInfo(syncer.StateUpgrading)) - logger := logutil.BgLogger().With(zap.String("category", "upgrading")) - if err != nil { - logger.Error("update global state failed", zap.String("state", syncer.StateUpgrading), zap.Error(err)) - return err - } - - interval := 200 * time.Millisecond - for i := 0; ; i++ { - if isContextDone(ctx) { - logger.Error("get owner op failed", zap.Duration("timeout", timeout), zap.Error(err)) - return ctx.Err() - } - - var op owner.OpType - childCtx, cancel := context.WithTimeout(ctx, 3*time.Second) - op, err = owner.GetOwnerOpValue(childCtx, dom.EtcdClient(), ddl.DDLOwnerKey, "upgrade bootstrap") - cancel() - if err == nil && op.IsSyncedUpgradingState() { - break - } - if i%10 == 0 { - logger.Warn("get owner op failed", zap.Stringer("op", op), zap.Error(err)) - } - time.Sleep(interval) - } - - logger.Info("update global state to upgrading", zap.String("state", syncer.StateUpgrading)) - return nil -} - -// SyncNormalRunning syncs normal state to etcd. 
-func SyncNormalRunning(s sessionctx.Context) error { - bgCtx := context.Background() - failpoint.Inject("mockResumeAllJobsFailed", func(val failpoint.Value) { - if val.(bool) { - dom := domain.GetDomain(s) - //nolint: errcheck - dom.DDL().StateSyncer().UpdateGlobalState(bgCtx, syncer.NewStateInfo(syncer.StateNormalRunning)) - failpoint.Return(nil) - } - }) - - logger := logutil.BgLogger().With(zap.String("category", "upgrading")) - jobErrs, err := ddl.ResumeAllJobsBySystem(s) - if err != nil { - logger.Warn("resume all paused jobs failed", zap.Error(err)) - } - for _, e := range jobErrs { - logger.Warn("resume the job failed", zap.Error(e)) - } - - if mgr, _ := dist_store.GetTaskManager(); mgr != nil { - ctx := kv.WithInternalSourceType(bgCtx, kv.InternalDistTask) - err := mgr.AdjustTaskOverflowConcurrency(ctx, s) - if err != nil { - log.Warn("cannot adjust task overflow concurrency", zap.Error(err)) - } - } - - ctx, cancelFunc := context.WithTimeout(bgCtx, 3*time.Second) - defer cancelFunc() - dom := domain.GetDomain(s) - err = dom.DDL().StateSyncer().UpdateGlobalState(ctx, syncer.NewStateInfo(syncer.StateNormalRunning)) - if err != nil { - logger.Error("update global state to normal failed", zap.Error(err)) - return err - } - logger.Info("update global state to normal running finished") - return nil -} - -// IsUpgradingClusterState checks whether the global state is upgrading. -func IsUpgradingClusterState(s sessionctx.Context) (bool, error) { - dom := domain.GetDomain(s) - ctx, cancelFunc := context.WithTimeout(context.Background(), 3*time.Second) - defer cancelFunc() - stateInfo, err := dom.DDL().StateSyncer().GetGlobalState(ctx) - if err != nil { - return false, err - } - - return stateInfo.State == syncer.StateUpgrading, nil -} - -func printClusterState(s sessiontypes.Session, ver int64) { - // After SupportUpgradeHTTPOpVer version, the upgrade by paused user DDL can be notified through the HTTP API. - // We check the global state see if we are upgrading by paused the user DDL. 
- if ver >= SupportUpgradeHTTPOpVer { - isUpgradingClusterStateWithRetry(s, ver, currentBootstrapVersion, time.Duration(internalSQLTimeout)*time.Second) - } -} - -func isUpgradingClusterStateWithRetry(s sessionctx.Context, oldVer, newVer int64, timeout time.Duration) { - now := time.Now() - interval := 200 * time.Millisecond - logger := logutil.BgLogger().With(zap.String("category", "upgrading")) - for i := 0; ; i++ { - isUpgrading, err := IsUpgradingClusterState(s) - if err == nil { - logger.Info("get global state", zap.Int64("old version", oldVer), zap.Int64("latest version", newVer), zap.Bool("is upgrading state", isUpgrading)) - return - } - - if time.Since(now) >= timeout { - logger.Error("get global state failed", zap.Int64("old version", oldVer), zap.Int64("latest version", newVer), zap.Error(err)) - return - } - if i%25 == 0 { - logger.Warn("get global state failed", zap.Int64("old version", oldVer), zap.Int64("latest version", newVer), zap.Error(err)) - } - time.Sleep(interval) - } -} diff --git a/pkg/session/tidb.go b/pkg/session/tidb.go index 1bd907e58d8c4..a59529dc4669a 100644 --- a/pkg/session/tidb.go +++ b/pkg/session/tidb.go @@ -229,9 +229,9 @@ func recordAbortTxnDuration(sessVars *variable.SessionVars, isInternal bool) { } func finishStmt(ctx context.Context, se *session, meetsErr error, sql sqlexec.Statement) error { - if _, _err_ := failpoint.Eval(_curpkg_("finishStmtError")); _err_ == nil { - return errors.New("occur an error after finishStmt") - } + failpoint.Inject("finishStmtError", func() { + failpoint.Return(errors.New("occur an error after finishStmt")) + }) sessVars := se.sessionVars if !sql.IsReadOnly(sessVars) { // All the history should be added here. diff --git a/pkg/session/tidb.go__failpoint_stash__ b/pkg/session/tidb.go__failpoint_stash__ deleted file mode 100644 index a59529dc4669a..0000000000000 --- a/pkg/session/tidb.go__failpoint_stash__ +++ /dev/null @@ -1,403 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2013 The ql Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSES/QL-LICENSE file. 
- -package session - -import ( - "context" - "sync" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/ddl" - "github.com/pingcap/tidb/pkg/ddl/schematracker" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/errno" - "github.com/pingcap/tidb/pkg/executor" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser" - "github.com/pingcap/tidb/pkg/parser/ast" - session_metrics "github.com/pingcap/tidb/pkg/session/metrics" - "github.com/pingcap/tidb/pkg/session/types" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/sessiontxn" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/dbterror" - "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/sqlexec" - "github.com/pingcap/tidb/pkg/util/syncutil" - "go.uber.org/zap" -) - -type domainMap struct { - mu syncutil.Mutex - domains map[string]*domain.Domain -} - -// Get or create the domain for store. -// TODO decouple domain create from it, it's more clear to create domain explicitly -// before any usage of it. -func (dm *domainMap) Get(store kv.Storage) (d *domain.Domain, err error) { - dm.mu.Lock() - defer dm.mu.Unlock() - - if store == nil { - for _, d := range dm.domains { - // return available domain if any - return d, nil - } - return nil, errors.New("can not find available domain for a nil store") - } - - key := store.UUID() - - d = dm.domains[key] - if d != nil { - return - } - - ddlLease := time.Duration(atomic.LoadInt64(&schemaLease)) - statisticLease := time.Duration(atomic.LoadInt64(&statsLease)) - planReplayerGCLease := GetPlanReplayerGCLease() - err = util.RunWithRetry(util.DefaultMaxRetries, util.RetryInterval, func() (retry bool, err1 error) { - logutil.BgLogger().Info("new domain", - zap.String("store", store.UUID()), - zap.Stringer("ddl lease", ddlLease), - zap.Stringer("stats lease", statisticLease)) - factory := createSessionFunc(store) - sysFactory := createSessionWithDomainFunc(store) - d = domain.NewDomain(store, ddlLease, statisticLease, planReplayerGCLease, factory) - - var ddlInjector func(ddl.DDL, ddl.Executor, *infoschema.InfoCache) *schematracker.Checker - if injector, ok := store.(schematracker.StorageDDLInjector); ok { - ddlInjector = injector.Injector - } - err1 = d.Init(ddlLease, sysFactory, ddlInjector) - if err1 != nil { - // If we don't clean it, there are some dirty data when retrying the function of Init. - d.Close() - logutil.BgLogger().Error("init domain failed", zap.String("category", "ddl"), - zap.Error(err1)) - } - return true, err1 - }) - if err != nil { - return nil, err - } - dm.domains[key] = d - d.SetOnClose(func() { - dm.Delete(store) - }) - - return -} - -func (dm *domainMap) Delete(store kv.Storage) { - dm.mu.Lock() - delete(dm.domains, store.UUID()) - dm.mu.Unlock() -} - -var ( - domap = &domainMap{ - domains: map[string]*domain.Domain{}, - } - // store.UUID()-> IfBootstrapped - storeBootstrapped = make(map[string]bool) - storeBootstrappedLock sync.Mutex - - // schemaLease is the time for re-updating remote schema. - // In online DDL, we must wait 2 * SchemaLease time to guarantee - // all servers get the neweset schema. 
- // Default schema lease time is 1 second, you can change it with a proper time, - // but you must know that too little may cause badly performance degradation. - // For production, you should set a big schema lease, like 300s+. - schemaLease = int64(1 * time.Second) - - // statsLease is the time for reload stats table. - statsLease = int64(3 * time.Second) - - // planReplayerGCLease is the time for plan replayer gc. - planReplayerGCLease = int64(10 * time.Minute) -) - -// ResetStoreForWithTiKVTest is only used in the test code. -// TODO: Remove domap and storeBootstrapped. Use store.SetOption() to do it. -func ResetStoreForWithTiKVTest(store kv.Storage) { - domap.Delete(store) - unsetStoreBootstrapped(store.UUID()) -} - -func setStoreBootstrapped(storeUUID string) { - storeBootstrappedLock.Lock() - defer storeBootstrappedLock.Unlock() - storeBootstrapped[storeUUID] = true -} - -// unsetStoreBootstrapped delete store uuid from stored bootstrapped map. -// currently this function only used for test. -func unsetStoreBootstrapped(storeUUID string) { - storeBootstrappedLock.Lock() - defer storeBootstrappedLock.Unlock() - delete(storeBootstrapped, storeUUID) -} - -// SetSchemaLease changes the default schema lease time for DDL. -// This function is very dangerous, don't use it if you really know what you do. -// SetSchemaLease only affects not local storage after bootstrapped. -func SetSchemaLease(lease time.Duration) { - atomic.StoreInt64(&schemaLease, int64(lease)) -} - -// SetStatsLease changes the default stats lease time for loading stats info. -func SetStatsLease(lease time.Duration) { - atomic.StoreInt64(&statsLease, int64(lease)) -} - -// SetPlanReplayerGCLease changes the default plan repalyer gc lease time. -func SetPlanReplayerGCLease(lease time.Duration) { - atomic.StoreInt64(&planReplayerGCLease, int64(lease)) -} - -// GetPlanReplayerGCLease returns the plan replayer gc lease time. -func GetPlanReplayerGCLease() time.Duration { - return time.Duration(atomic.LoadInt64(&planReplayerGCLease)) -} - -// DisableStats4Test disables the stats for tests. -func DisableStats4Test() { - SetStatsLease(-1) -} - -// Parse parses a query string to raw ast.StmtNode. -func Parse(ctx sessionctx.Context, src string) ([]ast.StmtNode, error) { - logutil.BgLogger().Debug("compiling", zap.String("source", src)) - sessVars := ctx.GetSessionVars() - p := parser.New() - p.SetParserConfig(sessVars.BuildParserConfig()) - p.SetSQLMode(sessVars.SQLMode) - stmts, warns, err := p.ParseSQL(src, sessVars.GetParseParams()...) 
- for _, warn := range warns { - sessVars.StmtCtx.AppendWarning(warn) - } - if err != nil { - logutil.BgLogger().Warn("compiling", - zap.String("source", src), - zap.Error(err)) - return nil, err - } - return stmts, nil -} - -func recordAbortTxnDuration(sessVars *variable.SessionVars, isInternal bool) { - duration := time.Since(sessVars.TxnCtx.CreateTime).Seconds() - if sessVars.TxnCtx.IsPessimistic { - if isInternal { - session_metrics.TransactionDurationPessimisticAbortInternal.Observe(duration) - } else { - session_metrics.TransactionDurationPessimisticAbortGeneral.Observe(duration) - } - } else { - if isInternal { - session_metrics.TransactionDurationOptimisticAbortInternal.Observe(duration) - } else { - session_metrics.TransactionDurationOptimisticAbortGeneral.Observe(duration) - } - } -} - -func finishStmt(ctx context.Context, se *session, meetsErr error, sql sqlexec.Statement) error { - failpoint.Inject("finishStmtError", func() { - failpoint.Return(errors.New("occur an error after finishStmt")) - }) - sessVars := se.sessionVars - if !sql.IsReadOnly(sessVars) { - // All the history should be added here. - if meetsErr == nil && sessVars.TxnCtx.CouldRetry { - GetHistory(se).Add(sql, sessVars.StmtCtx) - } - - // Handle the stmt commit/rollback. - if se.txn.Valid() { - if meetsErr != nil { - se.StmtRollback(ctx, false) - } else { - se.StmtCommit(ctx) - } - } - } - err := autoCommitAfterStmt(ctx, se, meetsErr, sql) - if se.txn.pending() { - // After run statement finish, txn state is still pending means the - // statement never need a Txn(), such as: - // - // set @@tidb_general_log = 1 - // set @@autocommit = 0 - // select 1 - // - // Reset txn state to invalid to dispose the pending start ts. - se.txn.changeToInvalid() - } - if err != nil { - return err - } - return checkStmtLimit(ctx, se, true) -} - -func autoCommitAfterStmt(ctx context.Context, se *session, meetsErr error, sql sqlexec.Statement) error { - isInternal := false - if internal := se.txn.GetOption(kv.RequestSourceInternal); internal != nil && internal.(bool) { - isInternal = true - } - sessVars := se.sessionVars - if meetsErr != nil { - if !sessVars.InTxn() { - logutil.BgLogger().Info("rollbackTxn called due to ddl/autocommit failure") - se.RollbackTxn(ctx) - recordAbortTxnDuration(sessVars, isInternal) - } else if se.txn.Valid() && se.txn.IsPessimistic() && exeerrors.ErrDeadlock.Equal(meetsErr) { - logutil.BgLogger().Info("rollbackTxn for deadlock", zap.Uint64("txn", se.txn.StartTS())) - se.RollbackTxn(ctx) - recordAbortTxnDuration(sessVars, isInternal) - } - return meetsErr - } - - if !sessVars.InTxn() { - if err := se.CommitTxn(ctx); err != nil { - if _, ok := sql.(*executor.ExecStmt).StmtNode.(*ast.CommitStmt); ok { - err = errors.Annotatef(err, "previous statement: %s", se.GetSessionVars().PrevStmt) - } - return err - } - return nil - } - return nil -} - -func checkStmtLimit(ctx context.Context, se *session, isFinish bool) error { - // If the user insert, insert, insert ... but never commit, TiDB would OOM. - // So we limit the statement count in a transaction here. - var err error - sessVars := se.GetSessionVars() - history := GetHistory(se) - stmtCount := history.Count() - if !isFinish { - // history stmt count + current stmt, since current stmt is not finish, it has not add to history. 
- stmtCount++ - } - if stmtCount > int(config.GetGlobalConfig().Performance.StmtCountLimit) { - if !sessVars.BatchCommit { - se.RollbackTxn(ctx) - return errors.Errorf("statement count %d exceeds the transaction limitation, transaction has been rollback, autocommit = %t", - stmtCount, sessVars.IsAutocommit()) - } - if !isFinish { - // if the stmt is not finish execute, then just return, since some work need to be done such as StmtCommit. - return nil - } - // If the stmt is finish execute, and exceed the StmtCountLimit, and BatchCommit is true, - // then commit the current transaction and create a new transaction. - err = sessiontxn.NewTxn(ctx, se) - // The transaction does not committed yet, we need to keep it in transaction. - // The last history could not be "commit"/"rollback" statement. - // It means it is impossible to start a new transaction at the end of the transaction. - // Because after the server executed "commit"/"rollback" statement, the session is out of the transaction. - sessVars.SetInTxn(true) - } - return err -} - -// GetHistory get all stmtHistory in current txn. Exported only for test. -// If stmtHistory is nil, will create a new one for current txn. -func GetHistory(ctx sessionctx.Context) *StmtHistory { - hist, ok := ctx.GetSessionVars().TxnCtx.History.(*StmtHistory) - if ok { - return hist - } - hist = new(StmtHistory) - ctx.GetSessionVars().TxnCtx.History = hist - return hist -} - -// GetRows4Test gets all the rows from a RecordSet, only used for test. -func GetRows4Test(ctx context.Context, _ sessionctx.Context, rs sqlexec.RecordSet) ([]chunk.Row, error) { - if rs == nil { - return nil, nil - } - var rows []chunk.Row - req := rs.NewChunk(nil) - // Must reuse `req` for imitating server.(*clientConn).writeChunks - for { - err := rs.Next(ctx, req) - if err != nil { - return nil, err - } - if req.NumRows() == 0 { - break - } - - iter := chunk.NewIterator4Chunk(req.CopyConstruct()) - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - rows = append(rows, row) - } - } - return rows, nil -} - -// ResultSetToStringSlice changes the RecordSet to [][]string. -func ResultSetToStringSlice(ctx context.Context, s types.Session, rs sqlexec.RecordSet) ([][]string, error) { - rows, err := GetRows4Test(ctx, s, rs) - if err != nil { - return nil, err - } - err = rs.Close() - if err != nil { - return nil, err - } - sRows := make([][]string, len(rows)) - for i := range rows { - row := rows[i] - iRow := make([]string, row.Len()) - for j := 0; j < row.Len(); j++ { - if row.IsNull(j) { - iRow[j] = "" - } else { - d := row.GetDatum(j, &rs.Fields()[j].Column.FieldType) - iRow[j], err = d.ToString() - if err != nil { - return nil, err - } - } - } - sRows[i] = iRow - } - return sRows, nil -} - -// Session errors. -var ( - ErrForUpdateCantRetry = dbterror.ClassSession.NewStd(errno.ErrForUpdateCantRetry) -) diff --git a/pkg/session/txn.go b/pkg/session/txn.go index c051166713ed7..0ec88c4a5667a 100644 --- a/pkg/session/txn.go +++ b/pkg/session/txn.go @@ -422,29 +422,29 @@ func (txn *LazyTxn) Commit(ctx context.Context) error { txn.updateState(txninfo.TxnCommitting) txn.mu.Unlock() - failpoint.Eval(_curpkg_("mockSlowCommit")) + failpoint.Inject("mockSlowCommit", func(_ failpoint.Value) {}) // mockCommitError8942 is used for PR #8942. 
- if val, _err_ := failpoint.Eval(_curpkg_("mockCommitError8942")); _err_ == nil { + failpoint.Inject("mockCommitError8942", func(val failpoint.Value) { if val.(bool) { - return kv.ErrTxnRetryable + failpoint.Return(kv.ErrTxnRetryable) } - } + }) // mockCommitRetryForAutoIncID is used to mock an commit retry for adjustAutoIncrementDatum. - if val, _err_ := failpoint.Eval(_curpkg_("mockCommitRetryForAutoIncID")); _err_ == nil { + failpoint.Inject("mockCommitRetryForAutoIncID", func(val failpoint.Value) { if val.(bool) && !mockAutoIncIDRetry() { enableMockAutoIncIDRetry() - return kv.ErrTxnRetryable + failpoint.Return(kv.ErrTxnRetryable) } - } + }) - if val, _err_ := failpoint.Eval(_curpkg_("mockCommitRetryForAutoRandID")); _err_ == nil { + failpoint.Inject("mockCommitRetryForAutoRandID", func(val failpoint.Value) { if val.(bool) && needMockAutoRandIDRetry() { decreaseMockAutoRandIDRetryCount() - return kv.ErrTxnRetryable + failpoint.Return(kv.ErrTxnRetryable) } - } + }) return txn.Transaction.Commit(ctx) } @@ -456,7 +456,7 @@ func (txn *LazyTxn) Rollback() error { txn.updateState(txninfo.TxnRollingBack) txn.mu.Unlock() // mockSlowRollback is used to mock a rollback which takes a long time - failpoint.Eval(_curpkg_("mockSlowRollback")) + failpoint.Inject("mockSlowRollback", func(_ failpoint.Value) {}) return txn.Transaction.Rollback() } @@ -474,7 +474,7 @@ func (txn *LazyTxn) LockKeys(ctx context.Context, lockCtx *kv.LockCtx, keys ...k // LockKeysFunc Wrap the inner transaction's `LockKeys` to record the status func (txn *LazyTxn) LockKeysFunc(ctx context.Context, lockCtx *kv.LockCtx, fn func(), keys ...kv.Key) error { - failpoint.Eval(_curpkg_("beforeLockKeys")) + failpoint.Inject("beforeLockKeys", func() {}) t := time.Now() var originState txninfo.TxnRunningState @@ -705,7 +705,7 @@ type txnFuture struct { func (tf *txnFuture) wait() (kv.Transaction, error) { startTS, err := tf.future.Wait() - failpoint.Eval(_curpkg_("txnFutureWait")) + failpoint.Inject("txnFutureWait", func() {}) if err == nil { if tf.pipelined { return tf.store.Begin(tikv.WithTxnScope(tf.txnScope), tikv.WithStartTS(startTS), tikv.WithPipelinedMemDB()) diff --git a/pkg/session/txn.go__failpoint_stash__ b/pkg/session/txn.go__failpoint_stash__ deleted file mode 100644 index 0ec88c4a5667a..0000000000000 --- a/pkg/session/txn.go__failpoint_stash__ +++ /dev/null @@ -1,778 +0,0 @@ -// Copyright 2018 PingCAP, Inc. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package session - -import ( - "bytes" - "context" - "fmt" - "runtime/trace" - "strings" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/session/txninfo" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/binloginfo" - "github.com/pingcap/tidb/pkg/sessiontxn" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/sli" - "github.com/pingcap/tidb/pkg/util/syncutil" - "github.com/pingcap/tipb/go-binlog" - "github.com/tikv/client-go/v2/oracle" - "github.com/tikv/client-go/v2/tikv" - "go.uber.org/zap" -) - -// LazyTxn wraps kv.Transaction to provide a new kv.Transaction. -// 1. It holds all statement related modification in the buffer before flush to the txn, -// so if execute statement meets error, the txn won't be made dirty. -// 2. It's a lazy transaction, that means it's a txnFuture before StartTS() is really need. -type LazyTxn struct { - // States of a LazyTxn should be one of the followings: - // Invalid: kv.Transaction == nil && txnFuture == nil - // Pending: kv.Transaction == nil && txnFuture != nil - // Valid: kv.Transaction != nil && txnFuture == nil - kv.Transaction - txnFuture *txnFuture - - initCnt int - stagingHandle kv.StagingHandle - mutations map[int64]*binlog.TableMutation - writeSLI sli.TxnWriteThroughputSLI - - enterFairLockingOnValid bool - - // TxnInfo is added for the lock view feature, the data is frequent modified but - // rarely read (just in query select * from information_schema.tidb_trx). - // The data in this session would be query by other sessions, so Mutex is necessary. - // Since read is rare, the reader can copy-on-read to get a data snapshot. - mu struct { - syncutil.RWMutex - txninfo.TxnInfo - } - - // mark the txn enables lazy uniqueness check in pessimistic transactions. - lazyUniquenessCheckEnabled bool -} - -// GetTableInfo returns the cached index name. -func (txn *LazyTxn) GetTableInfo(id int64) *model.TableInfo { - return txn.Transaction.GetTableInfo(id) -} - -// CacheTableInfo caches the index name. -func (txn *LazyTxn) CacheTableInfo(id int64, info *model.TableInfo) { - txn.Transaction.CacheTableInfo(id, info) -} - -func (txn *LazyTxn) init() { - txn.mutations = make(map[int64]*binlog.TableMutation) - txn.mu.Lock() - defer txn.mu.Unlock() - txn.mu.TxnInfo = txninfo.TxnInfo{} -} - -// call this under lock! -func (txn *LazyTxn) updateState(state txninfo.TxnRunningState) { - if txn.mu.TxnInfo.State != state { - lastState := txn.mu.TxnInfo.State - lastStateChangeTime := txn.mu.TxnInfo.LastStateChangeTime - txn.mu.TxnInfo.State = state - txn.mu.TxnInfo.LastStateChangeTime = time.Now() - if !lastStateChangeTime.IsZero() { - hasLockLbl := !txn.mu.TxnInfo.BlockStartTime.IsZero() - txninfo.TxnDurationHistogram(lastState, hasLockLbl).Observe(time.Since(lastStateChangeTime).Seconds()) - } - txninfo.TxnStatusEnteringCounter(state).Inc() - } -} - -func (txn *LazyTxn) initStmtBuf() { - if txn.Transaction == nil { - return - } - buf := txn.Transaction.GetMemBuffer() - txn.initCnt = buf.Len() - if !txn.IsPipelined() { - txn.stagingHandle = buf.Staging() - } -} - -// countHint is estimated count of mutations. 
-func (txn *LazyTxn) countHint() int { - if txn.stagingHandle == kv.InvalidStagingHandle { - return 0 - } - return txn.Transaction.GetMemBuffer().Len() - txn.initCnt -} - -func (txn *LazyTxn) flushStmtBuf() { - if txn.stagingHandle == kv.InvalidStagingHandle { - return - } - buf := txn.Transaction.GetMemBuffer() - - if txn.lazyUniquenessCheckEnabled { - keysNeedSetPersistentPNE := kv.FindKeysInStage(buf, txn.stagingHandle, func(_ kv.Key, flags kv.KeyFlags, _ []byte) bool { - return flags.HasPresumeKeyNotExists() - }) - for _, key := range keysNeedSetPersistentPNE { - buf.UpdateFlags(key, kv.SetPreviousPresumeKeyNotExists) - } - } - - if !txn.IsPipelined() { - buf.Release(txn.stagingHandle) - } - txn.initCnt = buf.Len() -} - -func (txn *LazyTxn) cleanupStmtBuf() { - if txn.stagingHandle == kv.InvalidStagingHandle { - return - } - buf := txn.Transaction.GetMemBuffer() - if !txn.IsPipelined() { - buf.Cleanup(txn.stagingHandle) - } - txn.initCnt = buf.Len() - - txn.mu.Lock() - defer txn.mu.Unlock() - txn.mu.TxnInfo.EntriesCount = uint64(txn.Transaction.Len()) -} - -// resetTxnInfo resets the transaction info. -// Note: call it under lock! -func (txn *LazyTxn) resetTxnInfo( - startTS uint64, - state txninfo.TxnRunningState, - entriesCount uint64, - currentSQLDigest string, - allSQLDigests []string, -) { - if !txn.mu.LastStateChangeTime.IsZero() { - lastState := txn.mu.State - hasLockLbl := !txn.mu.BlockStartTime.IsZero() - txninfo.TxnDurationHistogram(lastState, hasLockLbl).Observe(time.Since(txn.mu.TxnInfo.LastStateChangeTime).Seconds()) - } - if txn.mu.TxnInfo.StartTS != 0 { - txninfo.Recorder.OnTrxEnd(&txn.mu.TxnInfo) - } - txn.mu.TxnInfo = txninfo.TxnInfo{} - txn.mu.TxnInfo.StartTS = startTS - txn.mu.TxnInfo.State = state - txninfo.TxnStatusEnteringCounter(state).Inc() - txn.mu.TxnInfo.LastStateChangeTime = time.Now() - txn.mu.TxnInfo.EntriesCount = entriesCount - - txn.mu.TxnInfo.CurrentSQLDigest = currentSQLDigest - txn.mu.TxnInfo.AllSQLDigests = allSQLDigests -} - -// Size implements the MemBuffer interface. -func (txn *LazyTxn) Size() int { - if txn.Transaction == nil { - return 0 - } - return txn.Transaction.Size() -} - -// Mem implements the MemBuffer interface. -func (txn *LazyTxn) Mem() uint64 { - if txn.Transaction == nil { - return 0 - } - return txn.Transaction.Mem() -} - -// SetMemoryFootprintChangeHook sets the hook to be called when the memory footprint of this transaction changes. -func (txn *LazyTxn) SetMemoryFootprintChangeHook(hook func(uint64)) { - if txn.Transaction == nil { - return - } - txn.Transaction.SetMemoryFootprintChangeHook(hook) -} - -// MemHookSet returns whether the memory footprint change hook is set. -func (txn *LazyTxn) MemHookSet() bool { - if txn.Transaction == nil { - return false - } - return txn.Transaction.MemHookSet() -} - -// Valid implements the kv.Transaction interface. -func (txn *LazyTxn) Valid() bool { - return txn.Transaction != nil && txn.Transaction.Valid() -} - -func (txn *LazyTxn) pending() bool { - return txn.Transaction == nil && txn.txnFuture != nil -} - -func (txn *LazyTxn) validOrPending() bool { - return txn.txnFuture != nil || txn.Valid() -} - -func (txn *LazyTxn) String() string { - if txn.Transaction != nil { - return txn.Transaction.String() - } - if txn.txnFuture != nil { - res := "txnFuture" - if txn.enterFairLockingOnValid { - res += " (pending fair locking)" - } - return res - } - return "invalid transaction" -} - -// GoString implements the "%#v" format for fmt.Printf. 
-func (txn *LazyTxn) GoString() string { - var s strings.Builder - s.WriteString("Txn{") - if txn.pending() { - s.WriteString("state=pending") - } else if txn.Valid() { - s.WriteString("state=valid") - fmt.Fprintf(&s, ", txnStartTS=%d", txn.Transaction.StartTS()) - if len(txn.mutations) > 0 { - fmt.Fprintf(&s, ", len(mutations)=%d, %#v", len(txn.mutations), txn.mutations) - } - } else { - s.WriteString("state=invalid") - } - - s.WriteString("}") - return s.String() -} - -// GetOption implements the GetOption -func (txn *LazyTxn) GetOption(opt int) any { - if txn.Transaction == nil { - if opt == kv.TxnScope { - return "" - } - return nil - } - return txn.Transaction.GetOption(opt) -} - -func (txn *LazyTxn) changeToPending(future *txnFuture) { - txn.Transaction = nil - txn.txnFuture = future -} - -func (txn *LazyTxn) changePendingToValid(ctx context.Context, sctx sessionctx.Context) error { - if txn.txnFuture == nil { - return errors.New("transaction future is not set") - } - - future := txn.txnFuture - txn.txnFuture = nil - - defer trace.StartRegion(ctx, "WaitTsoFuture").End() - t, err := future.wait() - if err != nil { - txn.Transaction = nil - return err - } - txn.Transaction = t - txn.initStmtBuf() - - if txn.enterFairLockingOnValid { - txn.enterFairLockingOnValid = false - err = txn.Transaction.StartFairLocking() - if err != nil { - return err - } - } - - // The txnInfo may already recorded the first statement (usually "begin") when it's pending, so keep them. - txn.mu.Lock() - defer txn.mu.Unlock() - txn.resetTxnInfo( - t.StartTS(), - txninfo.TxnIdle, - uint64(txn.Transaction.Len()), - txn.mu.TxnInfo.CurrentSQLDigest, - txn.mu.TxnInfo.AllSQLDigests) - - // set resource group name for kv request such as lock pessimistic keys. - kv.SetTxnResourceGroup(txn, sctx.GetSessionVars().StmtCtx.ResourceGroupName) - // overwrite entry size limit by sys var. - if entrySizeLimit := sctx.GetSessionVars().TxnEntrySizeLimit; entrySizeLimit > 0 { - txn.SetOption(kv.SizeLimits, kv.TxnSizeLimits{ - Entry: entrySizeLimit, - Total: kv.TxnTotalSizeLimit.Load(), - }) - } - - return nil -} - -func (txn *LazyTxn) changeToInvalid() { - if txn.stagingHandle != kv.InvalidStagingHandle && !txn.IsPipelined() { - txn.Transaction.GetMemBuffer().Cleanup(txn.stagingHandle) - } - txn.stagingHandle = kv.InvalidStagingHandle - txn.Transaction = nil - txn.txnFuture = nil - - txn.enterFairLockingOnValid = false - - txn.mu.Lock() - lastState := txn.mu.TxnInfo.State - lastStateChangeTime := txn.mu.TxnInfo.LastStateChangeTime - hasLock := !txn.mu.TxnInfo.BlockStartTime.IsZero() - if txn.mu.TxnInfo.StartTS != 0 { - txninfo.Recorder.OnTrxEnd(&txn.mu.TxnInfo) - } - txn.mu.TxnInfo = txninfo.TxnInfo{} - txn.mu.Unlock() - if !lastStateChangeTime.IsZero() { - txninfo.TxnDurationHistogram(lastState, hasLock).Observe(time.Since(lastStateChangeTime).Seconds()) - } -} - -func (txn *LazyTxn) onStmtStart(currentSQLDigest string) { - if len(currentSQLDigest) == 0 { - return - } - - txn.mu.Lock() - defer txn.mu.Unlock() - txn.updateState(txninfo.TxnRunning) - txn.mu.TxnInfo.CurrentSQLDigest = currentSQLDigest - // Keeps at most 50 history sqls to avoid consuming too much memory. 
- const maxTransactionStmtHistory int = 50 - if len(txn.mu.TxnInfo.AllSQLDigests) < maxTransactionStmtHistory { - txn.mu.TxnInfo.AllSQLDigests = append(txn.mu.TxnInfo.AllSQLDigests, currentSQLDigest) - } -} - -func (txn *LazyTxn) onStmtEnd() { - txn.mu.Lock() - defer txn.mu.Unlock() - txn.mu.TxnInfo.CurrentSQLDigest = "" - txn.updateState(txninfo.TxnIdle) -} - -var hasMockAutoIncIDRetry = int64(0) - -func enableMockAutoIncIDRetry() { - atomic.StoreInt64(&hasMockAutoIncIDRetry, 1) -} - -func mockAutoIncIDRetry() bool { - return atomic.LoadInt64(&hasMockAutoIncIDRetry) == 1 -} - -var mockAutoRandIDRetryCount = int64(0) - -func needMockAutoRandIDRetry() bool { - return atomic.LoadInt64(&mockAutoRandIDRetryCount) > 0 -} - -func decreaseMockAutoRandIDRetryCount() { - atomic.AddInt64(&mockAutoRandIDRetryCount, -1) -} - -// ResetMockAutoRandIDRetryCount set the number of occurrences of -// `kv.ErrTxnRetryable` when calling TxnState.Commit(). -func ResetMockAutoRandIDRetryCount(failTimes int64) { - atomic.StoreInt64(&mockAutoRandIDRetryCount, failTimes) -} - -// Commit overrides the Transaction interface. -func (txn *LazyTxn) Commit(ctx context.Context) error { - defer txn.reset() - if len(txn.mutations) != 0 || txn.countHint() != 0 { - logutil.BgLogger().Error("the code should never run here", - zap.String("TxnState", txn.GoString()), - zap.Int("staging handler", int(txn.stagingHandle)), - zap.Int("mutations", txn.countHint()), - zap.Stack("something must be wrong")) - return errors.Trace(kv.ErrInvalidTxn) - } - - txn.mu.Lock() - txn.updateState(txninfo.TxnCommitting) - txn.mu.Unlock() - - failpoint.Inject("mockSlowCommit", func(_ failpoint.Value) {}) - - // mockCommitError8942 is used for PR #8942. - failpoint.Inject("mockCommitError8942", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(kv.ErrTxnRetryable) - } - }) - - // mockCommitRetryForAutoIncID is used to mock an commit retry for adjustAutoIncrementDatum. - failpoint.Inject("mockCommitRetryForAutoIncID", func(val failpoint.Value) { - if val.(bool) && !mockAutoIncIDRetry() { - enableMockAutoIncIDRetry() - failpoint.Return(kv.ErrTxnRetryable) - } - }) - - failpoint.Inject("mockCommitRetryForAutoRandID", func(val failpoint.Value) { - if val.(bool) && needMockAutoRandIDRetry() { - decreaseMockAutoRandIDRetryCount() - failpoint.Return(kv.ErrTxnRetryable) - } - }) - - return txn.Transaction.Commit(ctx) -} - -// Rollback overrides the Transaction interface. -func (txn *LazyTxn) Rollback() error { - defer txn.reset() - txn.mu.Lock() - txn.updateState(txninfo.TxnRollingBack) - txn.mu.Unlock() - // mockSlowRollback is used to mock a rollback which takes a long time - failpoint.Inject("mockSlowRollback", func(_ failpoint.Value) {}) - return txn.Transaction.Rollback() -} - -// RollbackMemDBToCheckpoint overrides the Transaction interface. -func (txn *LazyTxn) RollbackMemDBToCheckpoint(savepoint *tikv.MemDBCheckpoint) { - txn.flushStmtBuf() - txn.Transaction.RollbackMemDBToCheckpoint(savepoint) - txn.cleanup() -} - -// LockKeys wraps the inner transaction's `LockKeys` to record the status -func (txn *LazyTxn) LockKeys(ctx context.Context, lockCtx *kv.LockCtx, keys ...kv.Key) error { - return txn.LockKeysFunc(ctx, lockCtx, nil, keys...) 
-} - -// LockKeysFunc Wrap the inner transaction's `LockKeys` to record the status -func (txn *LazyTxn) LockKeysFunc(ctx context.Context, lockCtx *kv.LockCtx, fn func(), keys ...kv.Key) error { - failpoint.Inject("beforeLockKeys", func() {}) - t := time.Now() - - var originState txninfo.TxnRunningState - txn.mu.Lock() - originState = txn.mu.TxnInfo.State - txn.updateState(txninfo.TxnLockAcquiring) - txn.mu.TxnInfo.BlockStartTime.Valid = true - txn.mu.TxnInfo.BlockStartTime.Time = t - txn.mu.Unlock() - lockFunc := func() { - if fn != nil { - fn() - } - txn.mu.Lock() - defer txn.mu.Unlock() - txn.updateState(originState) - txn.mu.TxnInfo.BlockStartTime.Valid = false - txn.mu.TxnInfo.EntriesCount = uint64(txn.Transaction.Len()) - } - return txn.Transaction.LockKeysFunc(ctx, lockCtx, lockFunc, keys...) -} - -// StartFairLocking wraps the inner transaction to support using fair locking with lazy initialization. -func (txn *LazyTxn) StartFairLocking() error { - if txn.Valid() { - return txn.Transaction.StartFairLocking() - } else if !txn.pending() { - err := errors.New("trying to start fair locking on a transaction in invalid state") - logutil.BgLogger().Error("unexpected error when starting fair locking", zap.Error(err), zap.Stringer("txn", txn)) - return err - } - txn.enterFairLockingOnValid = true - return nil -} - -// RetryFairLocking wraps the inner transaction to support using fair locking with lazy initialization. -func (txn *LazyTxn) RetryFairLocking(ctx context.Context) error { - if txn.Valid() { - return txn.Transaction.RetryFairLocking(ctx) - } else if !txn.pending() { - err := errors.New("trying to retry fair locking on a transaction in invalid state") - logutil.BgLogger().Error("unexpected error when retrying fair locking", zap.Error(err), zap.Stringer("txnStartTS", txn)) - return err - } - return nil -} - -// CancelFairLocking wraps the inner transaction to support using fair locking with lazy initialization. -func (txn *LazyTxn) CancelFairLocking(ctx context.Context) error { - if txn.Valid() { - return txn.Transaction.CancelFairLocking(ctx) - } else if !txn.pending() { - err := errors.New("trying to cancel fair locking on a transaction in invalid state") - logutil.BgLogger().Error("unexpected error when cancelling fair locking", zap.Error(err), zap.Stringer("txnStartTS", txn)) - return err - } - if !txn.enterFairLockingOnValid { - err := errors.New("trying to cancel fair locking when it's not started") - logutil.BgLogger().Error("unexpected error when cancelling fair locking", zap.Error(err), zap.Stringer("txnStartTS", txn)) - return err - } - txn.enterFairLockingOnValid = false - return nil -} - -// DoneFairLocking wraps the inner transaction to support using fair locking with lazy initialization. -func (txn *LazyTxn) DoneFairLocking(ctx context.Context) error { - if txn.Valid() { - return txn.Transaction.DoneFairLocking(ctx) - } - if !txn.pending() { - err := errors.New("trying to cancel fair locking on a transaction in invalid state") - logutil.BgLogger().Error("unexpected error when finishing fair locking") - return err - } - if !txn.enterFairLockingOnValid { - err := errors.New("trying to finish fair locking when it's not started") - logutil.BgLogger().Error("unexpected error when finishing fair locking") - return err - } - txn.enterFairLockingOnValid = false - return nil -} - -// IsInFairLockingMode wraps the inner transaction to support using fair locking with lazy initialization. 
-func (txn *LazyTxn) IsInFairLockingMode() bool { - if txn.Valid() { - return txn.Transaction.IsInFairLockingMode() - } else if txn.pending() { - return txn.enterFairLockingOnValid - } - return false -} - -func (txn *LazyTxn) reset() { - txn.cleanup() - txn.changeToInvalid() -} - -func (txn *LazyTxn) cleanup() { - txn.cleanupStmtBuf() - txn.initStmtBuf() - for key := range txn.mutations { - delete(txn.mutations, key) - } -} - -// KeysNeedToLock returns the keys need to be locked. -func (txn *LazyTxn) KeysNeedToLock() ([]kv.Key, error) { - if txn.stagingHandle == kv.InvalidStagingHandle { - return nil, nil - } - keys := make([]kv.Key, 0, txn.countHint()) - buf := txn.Transaction.GetMemBuffer() - buf.InspectStage(txn.stagingHandle, func(k kv.Key, flags kv.KeyFlags, v []byte) { - if !KeyNeedToLock(k, v, flags) { - return - } - keys = append(keys, k) - }) - - return keys, nil -} - -// Wait converts pending txn to valid -func (txn *LazyTxn) Wait(ctx context.Context, sctx sessionctx.Context) (kv.Transaction, error) { - if !txn.validOrPending() { - return txn, errors.AddStack(kv.ErrInvalidTxn) - } - if txn.pending() { - defer func(begin time.Time) { - sctx.GetSessionVars().DurationWaitTS = time.Since(begin) - }(time.Now()) - - // Transaction is lazy initialized. - // PrepareTxnCtx is called to get a tso future, makes s.txn a pending txn, - // If Txn() is called later, wait for the future to get a valid txn. - if err := txn.changePendingToValid(ctx, sctx); err != nil { - logutil.BgLogger().Error("active transaction fail", - zap.Error(err)) - txn.cleanup() - sctx.GetSessionVars().TxnCtx.StartTS = 0 - return txn, err - } - txn.lazyUniquenessCheckEnabled = !sctx.GetSessionVars().ConstraintCheckInPlacePessimistic - } - return txn, nil -} - -// KeyNeedToLock returns true if the key need to lock. -func KeyNeedToLock(k, v []byte, flags kv.KeyFlags) bool { - isTableKey := bytes.HasPrefix(k, tablecodec.TablePrefix()) - if !isTableKey { - // meta key always need to lock. - return true - } - - // a pessimistic locking is skipped, perform the conflict check and - // constraint check (more accurately, PresumeKeyNotExist) in prewrite (or later pessimistic locking) - if flags.HasNeedConstraintCheckInPrewrite() { - return false - } - - if flags.HasPresumeKeyNotExists() { - return true - } - - // lock row key, primary key and unique index for delete operation, - if len(v) == 0 { - return flags.HasNeedLocked() || tablecodec.IsRecordKey(k) - } - - if tablecodec.IsUntouchedIndexKValue(k, v) { - return false - } - - if !tablecodec.IsIndexKey(k) { - return true - } - - if tablecodec.IsTempIndexKey(k) { - tmpVal, err := tablecodec.DecodeTempIndexValue(v) - if err != nil { - logutil.BgLogger().Warn("decode temp index value failed", zap.Error(err)) - return false - } - current := tmpVal.Current() - return current.Handle != nil || tablecodec.IndexKVIsUnique(current.Value) - } - - return tablecodec.IndexKVIsUnique(v) -} - -func getBinlogMutation(ctx sessionctx.Context, tableID int64) *binlog.TableMutation { - bin := binloginfo.GetPrewriteValue(ctx, true) - for i := range bin.Mutations { - if bin.Mutations[i].TableId == tableID { - return &bin.Mutations[i] - } - } - idx := len(bin.Mutations) - bin.Mutations = append(bin.Mutations, binlog.TableMutation{TableId: tableID}) - return &bin.Mutations[idx] -} - -func mergeToMutation(m1, m2 *binlog.TableMutation) { - m1.InsertedRows = append(m1.InsertedRows, m2.InsertedRows...) - m1.UpdatedRows = append(m1.UpdatedRows, m2.UpdatedRows...) 
- m1.DeletedIds = append(m1.DeletedIds, m2.DeletedIds...) - m1.DeletedPks = append(m1.DeletedPks, m2.DeletedPks...) - m1.DeletedRows = append(m1.DeletedRows, m2.DeletedRows...) - m1.Sequence = append(m1.Sequence, m2.Sequence...) -} - -type txnFailFuture struct{} - -func (txnFailFuture) Wait() (uint64, error) { - return 0, errors.New("mock get timestamp fail") -} - -// txnFuture is a promise, which promises to return a txn in future. -type txnFuture struct { - future oracle.Future - store kv.Storage - txnScope string - pipelined bool -} - -func (tf *txnFuture) wait() (kv.Transaction, error) { - startTS, err := tf.future.Wait() - failpoint.Inject("txnFutureWait", func() {}) - if err == nil { - if tf.pipelined { - return tf.store.Begin(tikv.WithTxnScope(tf.txnScope), tikv.WithStartTS(startTS), tikv.WithPipelinedMemDB()) - } - return tf.store.Begin(tikv.WithTxnScope(tf.txnScope), tikv.WithStartTS(startTS)) - } else if config.GetGlobalConfig().Store == "unistore" { - return nil, err - } - - logutil.BgLogger().Warn("wait tso failed", zap.Error(err)) - // It would retry get timestamp. - if tf.pipelined { - return tf.store.Begin(tikv.WithTxnScope(tf.txnScope), tikv.WithPipelinedMemDB()) - } - return tf.store.Begin(tikv.WithTxnScope(tf.txnScope)) -} - -// HasDirtyContent checks whether there's dirty update on the given table. -// Put this function here is to avoid cycle import. -func (s *session) HasDirtyContent(tid int64) bool { - // There should not be dirty content in a txn with pipelined memdb, and it also doesn't support Iter function. - if s.txn.Transaction == nil || s.txn.Transaction.IsPipelined() { - return false - } - seekKey := tablecodec.EncodeTablePrefix(tid) - it, err := s.txn.GetMemBuffer().Iter(seekKey, nil) - terror.Log(err) - return it.Valid() && bytes.HasPrefix(it.Key(), seekKey) -} - -// StmtCommit implements the sessionctx.Context interface. -func (s *session) StmtCommit(ctx context.Context) { - defer func() { - s.txn.cleanup() - }() - - txnManager := sessiontxn.GetTxnManager(s) - err := txnManager.OnStmtCommit(ctx) - if err != nil { - logutil.Logger(ctx).Error("txnManager failed to handle OnStmtCommit", zap.Error(err)) - } - - st := &s.txn - st.flushStmtBuf() - - // Need to flush binlog. - for tableID, delta := range st.mutations { - mutation := getBinlogMutation(s, tableID) - mergeToMutation(mutation, delta) - } -} - -// StmtRollback implements the sessionctx.Context interface. -func (s *session) StmtRollback(ctx context.Context, isForPessimisticRetry bool) { - txnManager := sessiontxn.GetTxnManager(s) - err := txnManager.OnStmtRollback(ctx, isForPessimisticRetry) - if err != nil { - logutil.Logger(ctx).Error("txnManager failed to handle OnStmtRollback", zap.Error(err)) - } - s.txn.cleanup() -} - -// StmtGetMutation implements the sessionctx.Context interface. 
-func (s *session) StmtGetMutation(tableID int64) *binlog.TableMutation { - st := &s.txn - if _, ok := st.mutations[tableID]; !ok { - st.mutations[tableID] = &binlog.TableMutation{TableId: tableID} - } - return st.mutations[tableID] -} diff --git a/pkg/session/txnmanager.go b/pkg/session/txnmanager.go index dd77d7f08964e..7b3b512c6acee 100644 --- a/pkg/session/txnmanager.go +++ b/pkg/session/txnmanager.go @@ -108,12 +108,12 @@ func (m *txnManager) GetStmtForUpdateTS() (uint64, error) { return 0, err } - if _, _err_ := failpoint.Eval(_curpkg_("assertTxnManagerForUpdateTSEqual")); _err_ == nil { + failpoint.Inject("assertTxnManagerForUpdateTSEqual", func() { sessVars := m.sctx.GetSessionVars() if txnCtxForUpdateTS := sessVars.TxnCtx.GetForUpdateTS(); sessVars.SnapshotTS == 0 && ts != txnCtxForUpdateTS { panic(fmt.Sprintf("forUpdateTS not equal %d != %d", ts, txnCtxForUpdateTS)) } - } + }) return ts, nil } diff --git a/pkg/session/txnmanager.go__failpoint_stash__ b/pkg/session/txnmanager.go__failpoint_stash__ deleted file mode 100644 index 7b3b512c6acee..0000000000000 --- a/pkg/session/txnmanager.go__failpoint_stash__ +++ /dev/null @@ -1,381 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package session - -import ( - "context" - "fmt" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessiontxn" - "github.com/pingcap/tidb/pkg/sessiontxn/isolation" - "github.com/pingcap/tidb/pkg/sessiontxn/staleread" - "github.com/pingcap/tidb/pkg/util/logutil" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" -) - -func init() { - sessiontxn.GetTxnManager = getTxnManager -} - -func getTxnManager(sctx sessionctx.Context) sessiontxn.TxnManager { - if manager, ok := sctx.GetSessionVars().TxnManager.(sessiontxn.TxnManager); ok { - return manager - } - - manager := newTxnManager(sctx) - sctx.GetSessionVars().TxnManager = manager - return manager -} - -// txnManager implements sessiontxn.TxnManager -type txnManager struct { - sctx sessionctx.Context - - stmtNode ast.StmtNode - ctxProvider sessiontxn.TxnContextProvider - - // We always reuse the same OptimisticTxnContextProvider in one session to reduce memory allocation cost for every new txn. 
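The hunk above rewrites the expanded failpoint.Eval(_curpkg_(...)) form back into the failpoint.Inject marker. A small self-contained sketch of the runtime half of that API, with a made-up failpoint path; Enable, Eval and Disable are the github.com/pingcap/failpoint entry points the expanded form relies on:

package main

import (
	"fmt"

	"github.com/pingcap/failpoint"
)

// mockNow adds an offset only when the failpoint is enabled, mirroring the
// Eval pattern shown in the hunk above. "example/nowOffset" is an invented
// path, not one of the patch's failpoints.
func mockNow() int {
	now := 100
	if val, err := failpoint.Eval("example/nowOffset"); err == nil {
		if off, ok := val.(int); ok {
			now += off
		}
	}
	return now
}

func main() {
	fmt.Println(mockNow()) // 100: the failpoint is disabled, Eval returns an error

	if err := failpoint.Enable("example/nowOffset", "return(5)"); err != nil {
		panic(err)
	}
	defer failpoint.Disable("example/nowOffset")

	fmt.Println(mockNow()) // 105: the return(5) term is evaluated
}

The return(5) term and the path string are illustrative; the only requirement is that Enable and Eval agree on the same path.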
- reservedOptimisticProviders [2]isolation.OptimisticTxnContextProvider - - // used for slow transaction logs - events []event - lastInstant time.Time - enterTxnInstant time.Time -} - -type event struct { - event string - duration time.Duration -} - -func (s event) MarshalLogObject(enc zapcore.ObjectEncoder) error { - enc.AddString("event", s.event) - enc.AddDuration("gap", s.duration) - return nil -} - -func newTxnManager(sctx sessionctx.Context) *txnManager { - return &txnManager{sctx: sctx} -} - -func (m *txnManager) GetTxnInfoSchema() infoschema.InfoSchema { - if m.ctxProvider != nil { - return m.ctxProvider.GetTxnInfoSchema() - } - - if is := m.sctx.GetDomainInfoSchema(); is != nil { - return is.(infoschema.InfoSchema) - } - - return nil -} - -func (m *txnManager) GetStmtReadTS() (uint64, error) { - if m.ctxProvider == nil { - return 0, errors.New("context provider not set") - } - return m.ctxProvider.GetStmtReadTS() -} - -func (m *txnManager) GetStmtForUpdateTS() (uint64, error) { - if m.ctxProvider == nil { - return 0, errors.New("context provider not set") - } - - ts, err := m.ctxProvider.GetStmtForUpdateTS() - if err != nil { - return 0, err - } - - failpoint.Inject("assertTxnManagerForUpdateTSEqual", func() { - sessVars := m.sctx.GetSessionVars() - if txnCtxForUpdateTS := sessVars.TxnCtx.GetForUpdateTS(); sessVars.SnapshotTS == 0 && ts != txnCtxForUpdateTS { - panic(fmt.Sprintf("forUpdateTS not equal %d != %d", ts, txnCtxForUpdateTS)) - } - }) - - return ts, nil -} - -func (m *txnManager) GetTxnScope() string { - if m.ctxProvider == nil { - return kv.GlobalTxnScope - } - return m.ctxProvider.GetTxnScope() -} - -func (m *txnManager) GetReadReplicaScope() string { - if m.ctxProvider == nil { - return kv.GlobalReplicaScope - } - return m.ctxProvider.GetReadReplicaScope() -} - -// GetSnapshotWithStmtReadTS gets snapshot with read ts -func (m *txnManager) GetSnapshotWithStmtReadTS() (kv.Snapshot, error) { - if m.ctxProvider == nil { - return nil, errors.New("context provider not set") - } - return m.ctxProvider.GetSnapshotWithStmtReadTS() -} - -// GetSnapshotWithStmtForUpdateTS gets snapshot with for update ts -func (m *txnManager) GetSnapshotWithStmtForUpdateTS() (kv.Snapshot, error) { - if m.ctxProvider == nil { - return nil, errors.New("context provider not set") - } - return m.ctxProvider.GetSnapshotWithStmtForUpdateTS() -} - -func (m *txnManager) GetContextProvider() sessiontxn.TxnContextProvider { - return m.ctxProvider -} - -func (m *txnManager) EnterNewTxn(ctx context.Context, r *sessiontxn.EnterNewTxnRequest) error { - ctxProvider, err := m.newProviderWithRequest(r) - if err != nil { - return err - } - - if err = ctxProvider.OnInitialize(ctx, r.Type); err != nil { - m.sctx.RollbackTxn(ctx) - return err - } - - m.ctxProvider = ctxProvider - if r.Type == sessiontxn.EnterNewTxnWithBeginStmt { - m.sctx.GetSessionVars().SetInTxn(true) - } - - m.resetEvents() - m.recordEvent("enter txn") - return nil -} - -func (m *txnManager) OnTxnEnd() { - m.ctxProvider = nil - m.stmtNode = nil - - m.events = append(m.events, event{event: "txn end", duration: time.Since(m.lastInstant)}) - - duration := time.Since(m.enterTxnInstant) - threshold := m.sctx.GetSessionVars().SlowTxnThreshold - if threshold > 0 && uint64(duration.Milliseconds()) >= threshold { - logutil.BgLogger().Info( - "slow transaction", zap.Duration("duration", duration), - zap.Uint64("conn", m.sctx.GetSessionVars().ConnectionID), - zap.Uint64("txnStartTS", m.sctx.GetSessionVars().TxnCtx.StartTS), - zap.Objects("events", 
m.events), - ) - } - - m.lastInstant = time.Now() -} - -func (m *txnManager) GetCurrentStmt() ast.StmtNode { - return m.stmtNode -} - -// OnStmtStart is the hook that should be called when a new statement started -func (m *txnManager) OnStmtStart(ctx context.Context, node ast.StmtNode) error { - m.stmtNode = node - - if m.ctxProvider == nil { - return errors.New("context provider not set") - } - - var sql string - if node != nil { - sql = node.OriginalText() - sql = parser.Normalize(sql, m.sctx.GetSessionVars().EnableRedactLog) - } - m.recordEvent(sql) - return m.ctxProvider.OnStmtStart(ctx, m.stmtNode) -} - -// OnStmtEnd implements the TxnManager interface -func (m *txnManager) OnStmtEnd() { - m.recordEvent("stmt end") -} - -// OnPessimisticStmtStart is the hook that should be called when starts handling a pessimistic DML or -// a pessimistic select-for-update statements. -func (m *txnManager) OnPessimisticStmtStart(ctx context.Context) error { - if m.ctxProvider == nil { - return errors.New("context provider not set") - } - return m.ctxProvider.OnPessimisticStmtStart(ctx) -} - -// OnPessimisticStmtEnd is the hook that should be called when finishes handling a pessimistic DML or -// select-for-update statement. -func (m *txnManager) OnPessimisticStmtEnd(ctx context.Context, isSuccessful bool) error { - if m.ctxProvider == nil { - return errors.New("context provider not set") - } - return m.ctxProvider.OnPessimisticStmtEnd(ctx, isSuccessful) -} - -// OnStmtErrorForNextAction is the hook that should be called when a new statement get an error -func (m *txnManager) OnStmtErrorForNextAction(ctx context.Context, point sessiontxn.StmtErrorHandlePoint, err error) (sessiontxn.StmtErrorAction, error) { - if m.ctxProvider == nil { - return sessiontxn.NoIdea() - } - return m.ctxProvider.OnStmtErrorForNextAction(ctx, point, err) -} - -// ActivateTxn decides to activate txn according to the parameter `active` -func (m *txnManager) ActivateTxn() (kv.Transaction, error) { - if m.ctxProvider == nil { - return nil, errors.AddStack(kv.ErrInvalidTxn) - } - return m.ctxProvider.ActivateTxn() -} - -// OnStmtRetry is the hook that should be called when a statement retry -func (m *txnManager) OnStmtRetry(ctx context.Context) error { - if m.ctxProvider == nil { - return errors.New("context provider not set") - } - return m.ctxProvider.OnStmtRetry(ctx) -} - -// OnStmtCommit is the hook that should be called when a statement is executed successfully. -func (m *txnManager) OnStmtCommit(ctx context.Context) error { - if m.ctxProvider == nil { - return errors.New("context provider not set") - } - m.recordEvent("stmt commit") - return m.ctxProvider.OnStmtCommit(ctx) -} - -func (m *txnManager) recordEvent(eventName string) { - if m.events == nil { - m.resetEvents() - } - m.events = append(m.events, event{event: eventName, duration: time.Since(m.lastInstant)}) - m.lastInstant = time.Now() -} - -func (m *txnManager) resetEvents() { - if m.events == nil { - m.events = make([]event, 0, 10) - } else { - m.events = m.events[:0] - } - m.enterTxnInstant = time.Now() -} - -// OnStmtRollback is the hook that should be called when a statement fails to execute. -func (m *txnManager) OnStmtRollback(ctx context.Context, isForPessimisticRetry bool) error { - if m.ctxProvider == nil { - return errors.New("context provider not set") - } - m.recordEvent("stmt rollback") - return m.ctxProvider.OnStmtRollback(ctx, isForPessimisticRetry) -} - -// OnLocalTemporaryTableCreated is the hook that should be called when a temporary table created. 
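The event bookkeeping above appends one (label, gap-since-previous-event) pair per hook and only emits a log entry when the whole transaction runs longer than the configured slow-transaction threshold. A standalone sketch of that recorder, with invented names and plain printing instead of the structured logger:

package main

import (
	"fmt"
	"time"
)

// txnEvent pairs a label with the gap since the previous event.
type txnEvent struct {
	name string
	gap  time.Duration
}

type eventRecorder struct {
	events    []txnEvent
	last      time.Time
	enteredAt time.Time
}

func newEventRecorder() *eventRecorder {
	now := time.Now()
	return &eventRecorder{last: now, enteredAt: now}
}

func (r *eventRecorder) record(name string) {
	r.events = append(r.events, txnEvent{name: name, gap: time.Since(r.last)})
	r.last = time.Now()
}

// reportIfSlow prints the event list only when the transaction exceeded the
// threshold, so fast transactions pay almost nothing beyond the appends.
func (r *eventRecorder) reportIfSlow(threshold time.Duration) {
	if total := time.Since(r.enteredAt); total >= threshold {
		fmt.Printf("slow transaction (%v):\n", total)
		for _, e := range r.events {
			fmt.Printf("  %-12s +%v\n", e.name, e.gap)
		}
	}
}

func main() {
	r := newEventRecorder()
	r.record("enter txn")
	time.Sleep(20 * time.Millisecond)
	r.record("stmt end")
	r.reportIfSlow(10 * time.Millisecond)
}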
-// The provider will update its state then -func (m *txnManager) OnLocalTemporaryTableCreated() { - if m.ctxProvider != nil { - m.ctxProvider.OnLocalTemporaryTableCreated() - } -} - -func (m *txnManager) AdviseWarmup() error { - if m.sctx.GetSessionVars().BulkDMLEnabled { - // We don't want to validate the feasibility of pipelined DML here. - // We'd like to check it later after optimization so that optimizer info can be used. - // And it does not make much sense to save such a little time for pipelined-dml as it's - // for bulk processing. - return nil - } - - if m.ctxProvider != nil { - return m.ctxProvider.AdviseWarmup() - } - return nil -} - -// AdviseOptimizeWithPlan providers optimization according to the plan -func (m *txnManager) AdviseOptimizeWithPlan(plan any) error { - if m.ctxProvider != nil { - return m.ctxProvider.AdviseOptimizeWithPlan(plan) - } - return nil -} - -func (m *txnManager) newProviderWithRequest(r *sessiontxn.EnterNewTxnRequest) (sessiontxn.TxnContextProvider, error) { - if r.Provider != nil { - return r.Provider, nil - } - - if r.StaleReadTS > 0 { - m.sctx.GetSessionVars().TxnCtx.StaleReadTs = r.StaleReadTS - return staleread.NewStalenessTxnContextProvider(m.sctx, r.StaleReadTS, nil), nil - } - - sessVars := m.sctx.GetSessionVars() - - txnMode := r.TxnMode - if txnMode == "" { - txnMode = sessVars.TxnMode - } - - switch txnMode { - case "", ast.Optimistic: - // When txnMode is 'OPTIMISTIC' or '', the transaction should be optimistic - provider := &m.reservedOptimisticProviders[0] - if old, ok := m.ctxProvider.(*isolation.OptimisticTxnContextProvider); ok && old == provider { - // We should make sure the new provider is not the same with the old one - provider = &m.reservedOptimisticProviders[1] - } - provider.ResetForNewTxn(m.sctx, r.CausalConsistencyOnly) - return provider, nil - case ast.Pessimistic: - // When txnMode is 'PESSIMISTIC', the provider should be determined by the isolation level - switch sessVars.IsolationLevelForNewTxn() { - case ast.ReadCommitted: - return isolation.NewPessimisticRCTxnContextProvider(m.sctx, r.CausalConsistencyOnly), nil - case ast.Serializable: - // The Oracle serializable isolation is actually SI in pessimistic mode. - // Do not update ForUpdateTS when the user is using the Serializable isolation level. - // It can be used temporarily on the few occasions when an Oracle-like isolation level is needed. - // Support for this does not mean that TiDB supports serializable isolation of MySQL. - // tidb_skip_isolation_level_check should still be disabled by default. - return isolation.NewPessimisticSerializableTxnContextProvider(m.sctx, r.CausalConsistencyOnly), nil - default: - // We use Repeatable read for all other cases. - return isolation.NewPessimisticRRTxnContextProvider(m.sctx, r.CausalConsistencyOnly), nil - } - default: - return nil, errors.Errorf("Invalid txn mode '%s'", txnMode) - } -} - -// SetOptionsBeforeCommit sets options before commit. 
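newProviderWithRequest above picks the context provider in two steps: the transaction mode first, then the isolation level for pessimistic transactions, with repeatable read as the fallback. A compact sketch of that two-level factory, using invented provider types rather than the patch's isolation package:

package main

import (
	"errors"
	"fmt"
)

// provider is a made-up stand-in for the per-isolation context providers.
type provider interface{ Name() string }

type optimistic struct{}
type pessimisticRC struct{}
type pessimisticRR struct{}

func (optimistic) Name() string    { return "optimistic" }
func (pessimisticRC) Name() string { return "pessimistic-RC" }
func (pessimisticRR) Name() string { return "pessimistic-RR" }

// newProvider mirrors the two-level decision: transaction mode first, then
// isolation level, with repeatable read as the default pessimistic choice.
func newProvider(txnMode, isolation string) (provider, error) {
	switch txnMode {
	case "", "OPTIMISTIC":
		return optimistic{}, nil
	case "PESSIMISTIC":
		if isolation == "READ-COMMITTED" {
			return pessimisticRC{}, nil
		}
		return pessimisticRR{}, nil
	default:
		return nil, errors.New("invalid txn mode " + txnMode)
	}
}

func main() {
	p, _ := newProvider("PESSIMISTIC", "READ-COMMITTED")
	fmt.Println(p.Name()) // pessimistic-RC
}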
-func (m *txnManager) SetOptionsBeforeCommit(txn kv.Transaction, commitTSChecker func(uint64) bool) error { - return m.ctxProvider.SetOptionsBeforeCommit(txn, commitTSChecker) -} diff --git a/pkg/sessionctx/sessionstates/binding__failpoint_binding__.go b/pkg/sessionctx/sessionstates/binding__failpoint_binding__.go deleted file mode 100644 index 46a20b98f439c..0000000000000 --- a/pkg/sessionctx/sessionstates/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package sessionstates - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/sessionctx/sessionstates/session_token.go b/pkg/sessionctx/sessionstates/session_token.go index b141c0c605e8f..5702d4966f118 100644 --- a/pkg/sessionctx/sessionstates/session_token.go +++ b/pkg/sessionctx/sessionstates/session_token.go @@ -327,10 +327,10 @@ func (sc *signingCert) checkSignature(content, signature []byte) error { func getNow() time.Time { now := time.Now() - if val, _err_ := failpoint.Eval(_curpkg_("mockNowOffset")); _err_ == nil { + failpoint.Inject("mockNowOffset", func(val failpoint.Value) { if s := uint64(val.(int)); s != 0 { now = now.Add(time.Duration(s)) } - } + }) return now } diff --git a/pkg/sessionctx/sessionstates/session_token.go__failpoint_stash__ b/pkg/sessionctx/sessionstates/session_token.go__failpoint_stash__ deleted file mode 100644 index 5702d4966f118..0000000000000 --- a/pkg/sessionctx/sessionstates/session_token.go__failpoint_stash__ +++ /dev/null @@ -1,336 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sessionstates - -import ( - "crypto" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/rand" - "crypto/rsa" - "crypto/sha256" - "crypto/sha512" - "crypto/tls" - "crypto/x509" - "encoding/json" - "strings" - "sync" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/util/logutil" - "go.uber.org/zap" -) - -// Token-based authentication is used in session migration. We don't use typical authentication because the proxy -// cannot store the user passwords for security issues. -// -// The process of token-based authentication: -// 1. Before migrating the session, the proxy requires a token from server A. -// 2. Server A generates a token and signs it with a private key defined in the certificate. -// 3. The proxy authenticates with server B and sends the signed token as the password. -// 4. Server B checks the signature with the public key defined in the certificate and then verifies the token. -// -// The highlight is that the certificates on all the servers should be the same all the time. -// However, the certificates should be rotated periodically. 
Just in case of using different certificates to -// sign and check, a server should keep the old certificate for a while. A server will try both -// the 2 certificates to check the signature. -const ( - // A token needs a lifetime to avoid brute force attack. - tokenLifetime = time.Minute - // LoadCertInterval is the interval of reloading the certificate. The certificate should be rotated periodically. - LoadCertInterval = 10 * time.Minute - // After a certificate is replaced, it's still valid for oldCertValidTime. - // oldCertValidTime must be a little longer than LoadCertInterval, because the previous server may - // sign with the old cert but the new server checks with the new cert. - // - server A loads the old cert at 00:00:00. - // - the cert is rotated at 00:00:01 on all servers. - // - server B loads the new cert at 00:00:02. - // - server A signs token with the old cert at 00:10:00. - // - server B reloads the same new cert again at 00:10:01, and it has 3 certs now. - // - server B receives the token at 00:10:02, so the old cert should be valid for more than 10m after replacement. - oldCertValidTime = 15 * time.Minute -) - -// SessionToken represents the token used to authenticate with the new server. -type SessionToken struct { - Username string `json:"username"` - SignTime time.Time `json:"sign-time"` - ExpireTime time.Time `json:"expire-time"` - Signature []byte `json:"signature,omitempty"` -} - -// CreateSessionToken creates a token for the proxy. -func CreateSessionToken(username string) (*SessionToken, error) { - now := getNow() - token := &SessionToken{ - Username: username, - SignTime: now, - ExpireTime: now.Add(tokenLifetime), - } - tokenBytes, err := json.Marshal(token) - if err != nil { - return nil, errors.Trace(err) - } - if token.Signature, err = globalSigningCert.sign(tokenBytes); err != nil { - return nil, ErrCannotMigrateSession.GenWithStackByArgs(err.Error()) - } - return token, nil -} - -// ValidateSessionToken validates the token sent from the proxy. -func ValidateSessionToken(tokenBytes []byte, username string) (err error) { - var token SessionToken - if err = json.Unmarshal(tokenBytes, &token); err != nil { - return errors.Trace(err) - } - signature := token.Signature - // Clear the signature and marshal it again to get the original content. - token.Signature = nil - if tokenBytes, err = json.Marshal(token); err != nil { - return errors.Trace(err) - } - if err = globalSigningCert.checkSignature(tokenBytes, signature); err != nil { - return ErrCannotMigrateSession.GenWithStackByArgs(err.Error()) - } - now := getNow() - if now.After(token.ExpireTime) { - return ErrCannotMigrateSession.GenWithStackByArgs("token expired", token.ExpireTime.String()) - } - // An attacker may forge a very long lifetime to brute force, so we also need to check `SignTime`. - // However, we need to be tolerant of these problems: - // - The `tokenLifetime` may change between TiDB versions, so we can't check `token.SignTime.Add(tokenLifetime).Equal(token.ExpireTime)` - // - There may exist time bias between TiDB instances, so we can't check `now.After(token.SignTime)` - if token.SignTime.Add(tokenLifetime).Before(now) { - return ErrCannotMigrateSession.GenWithStackByArgs("token lifetime is too long", token.SignTime.String()) - } - if !strings.EqualFold(username, token.Username) { - return ErrCannotMigrateSession.GenWithStackByArgs("username does not match", username, token.Username) - } - return nil -} - -// SetKeyPath sets the path of key.pem and force load the certificate again. 
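The token flow above signs a JSON payload whose Signature field has been cleared, and validation re-marshals the token the same way before checking the signature, the expiry time, the maximum lifetime and the user name. A self-contained sketch of the same round trip using an ed25519 key from the standard library (the patch's code dispatches over the certificate's key type instead of fixing one algorithm):

package main

import (
	"crypto/ed25519"
	"crypto/rand"
	"encoding/json"
	"errors"
	"fmt"
	"time"
)

// token carries the same kind of fields checked above; the signature covers
// the JSON encoding produced with the Signature field set to nil.
type token struct {
	Username   string    `json:"username"`
	SignTime   time.Time `json:"sign-time"`
	ExpireTime time.Time `json:"expire-time"`
	Signature  []byte    `json:"signature,omitempty"`
}

const lifetime = time.Minute

func createToken(user string, priv ed25519.PrivateKey) (*token, error) {
	now := time.Now()
	t := &token{Username: user, SignTime: now, ExpireTime: now.Add(lifetime)}
	payload, err := json.Marshal(t)
	if err != nil {
		return nil, err
	}
	t.Signature = ed25519.Sign(priv, payload)
	return t, nil
}

func validateToken(t *token, user string, pub ed25519.PublicKey) error {
	sig := t.Signature
	t.Signature = nil // rebuild exactly the bytes that were signed
	payload, err := json.Marshal(t)
	if err != nil {
		return err
	}
	if !ed25519.Verify(pub, payload, sig) {
		return errors.New("bad signature")
	}
	now := time.Now()
	if now.After(t.ExpireTime) {
		return errors.New("token expired")
	}
	if t.SignTime.Add(lifetime).Before(now) {
		return errors.New("token lifetime is too long")
	}
	if t.Username != user {
		return errors.New("username does not match")
	}
	return nil
}

func main() {
	pub, priv, _ := ed25519.GenerateKey(rand.Reader)
	t, _ := createToken("proxy_user", priv)
	fmt.Println(validateToken(t, "proxy_user", pub)) // <nil>
}

Checking SignTime as well as ExpireTime matters because the expiry is attacker-supplied data; bounding the lifetime from the signing side limits how long a forged-looking expiry could be abused.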
-func SetKeyPath(keyPath string) { - globalSigningCert.setKeyPath(keyPath) -} - -// SetCertPath sets the path of key.pem and force load the certificate again. -func SetCertPath(certPath string) { - globalSigningCert.setCertPath(certPath) -} - -// ReloadSigningCert is used to load the certificate periodically in a separate goroutine. -// It's impossible to know when the old certificate should expire without this goroutine: -// - If the certificate is rotated a minute ago, the old certificate should be still valid for a while. -// - If the certificate is rotated a month ago, the old certificate should expire for safety. -func ReloadSigningCert() { - globalSigningCert.lockAndLoad() -} - -var globalSigningCert signingCert - -// signingCert represents the parsed certificate used for token-based auth. -type signingCert struct { - sync.RWMutex - certPath string - keyPath string - // The cert file may happen to be rotated between signing and checking, so we keep the old cert for a while. - // certs contain all the certificates that are not expired yet. - certs []*certInfo -} - -type certInfo struct { - cert *x509.Certificate - privKey crypto.PrivateKey - expireTime time.Time -} - -func (sc *signingCert) setCertPath(certPath string) { - sc.Lock() - if certPath != sc.certPath { - sc.certPath = certPath - // It may fail expectedly because the key path is not set yet. - sc.checkAndLoadCert() - } - sc.Unlock() -} - -func (sc *signingCert) setKeyPath(keyPath string) { - sc.Lock() - if keyPath != sc.keyPath { - sc.keyPath = keyPath - // It may fail expectedly because the cert path is not set yet. - sc.checkAndLoadCert() - } - sc.Unlock() -} - -func (sc *signingCert) lockAndLoad() { - sc.Lock() - sc.checkAndLoadCert() - sc.Unlock() -} - -func (sc *signingCert) checkAndLoadCert() { - if len(sc.certPath) == 0 || len(sc.keyPath) == 0 { - return - } - if err := sc.loadCert(); err != nil { - logutil.BgLogger().Warn("loading signing cert failed", - zap.String("cert path", sc.certPath), - zap.String("key path", sc.keyPath), - zap.Error(err)) - } else { - logutil.BgLogger().Info("signing cert is loaded successfully", - zap.String("cert path", sc.certPath), - zap.String("key path", sc.keyPath)) - } -} - -// loadCert loads the cert and adds it into the cert list. -func (sc *signingCert) loadCert() error { - tlsCert, err := tls.LoadX509KeyPair(sc.certPath, sc.keyPath) - if err != nil { - return errors.Wrapf(err, "load x509 failed, cert path: %s, key path: %s", sc.certPath, sc.keyPath) - } - var cert *x509.Certificate - if tlsCert.Leaf != nil { - cert = tlsCert.Leaf - } else { - if cert, err = x509.ParseCertificate(tlsCert.Certificate[0]); err != nil { - return errors.Wrapf(err, "parse x509 cert failed, cert path: %s, key path: %s", sc.certPath, sc.keyPath) - } - } - - // Rotate certs. Ensure that the expireTime of certs is in descending order. - now := getNow() - newCerts := make([]*certInfo, 0, len(sc.certs)+1) - newCerts = append(newCerts, &certInfo{ - cert: cert, - privKey: tlsCert.PrivateKey, - expireTime: now.Add(LoadCertInterval + oldCertValidTime), - }) - for i := 0; i < len(sc.certs); i++ { - // Discard the certs that are already expired. - if now.After(sc.certs[i].expireTime) { - break - } - newCerts = append(newCerts, sc.certs[i]) - } - sc.certs = newCerts - return nil -} - -// sign generates a signature with the content and the private key. 
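loadCert above prepends the freshly loaded certificate, gives it an expiry of LoadCertInterval plus oldCertValidTime, and drops entries that have already expired; because the slice stays in descending expiry order, the scan can stop at the first expired entry. A reduced sketch of just that rotation bookkeeping, with string names standing in for parsed certificates:

package main

import (
	"fmt"
	"time"
)

// certEntry is a simplified stand-in for the rotated signing certs: only the
// expiry bookkeeping is kept, newest entry first.
type certEntry struct {
	name       string
	expireTime time.Time
}

const (
	loadInterval = 10 * time.Minute
	oldValidTime = 15 * time.Minute
)

// rotate prepends the newly loaded cert and keeps older certs until their
// grace period runs out, stopping at the first expired entry.
func rotate(certs []certEntry, name string, now time.Time) []certEntry {
	next := []certEntry{{name: name, expireTime: now.Add(loadInterval + oldValidTime)}}
	for _, c := range certs {
		if now.After(c.expireTime) {
			break
		}
		next = append(next, c)
	}
	return next
}

func main() {
	now := time.Now()
	certs := rotate(nil, "cert-v1", now)
	certs = rotate(certs, "cert-v2", now.Add(loadInterval))
	for _, c := range certs {
		fmt.Println(c.name) // cert-v2 then cert-v1: the old cert stays usable for a while
	}
}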
-func (sc *signingCert) sign(content []byte) ([]byte, error) { - var ( - signer crypto.Signer - opts crypto.SignerOpts - ) - sc.RLock() - defer sc.RUnlock() - if len(sc.certs) == 0 { - return nil, errors.New("no certificate or key file to sign the data") - } - // Always sign the token with the latest cert. - certInfo := sc.certs[0] - switch key := certInfo.privKey.(type) { - case ed25519.PrivateKey: - signer = key - opts = crypto.Hash(0) - case *rsa.PrivateKey: - signer = key - var pssHash crypto.Hash - switch certInfo.cert.SignatureAlgorithm { - case x509.SHA256WithRSAPSS: - pssHash = crypto.SHA256 - case x509.SHA384WithRSAPSS: - pssHash = crypto.SHA384 - case x509.SHA512WithRSAPSS: - pssHash = crypto.SHA512 - } - if pssHash != 0 { - h := pssHash.New() - h.Write(content) - content = h.Sum(nil) - opts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: pssHash} - break - } - switch certInfo.cert.SignatureAlgorithm { - case x509.SHA256WithRSA: - hashed := sha256.Sum256(content) - content = hashed[:] - opts = crypto.SHA256 - case x509.SHA384WithRSA: - hashed := sha512.Sum384(content) - content = hashed[:] - opts = crypto.SHA384 - case x509.SHA512WithRSA: - hashed := sha512.Sum512(content) - content = hashed[:] - opts = crypto.SHA512 - default: - return nil, errors.Errorf("not supported private key type '%s' for signing", certInfo.cert.SignatureAlgorithm.String()) - } - case *ecdsa.PrivateKey: - signer = key - default: - return nil, errors.Errorf("not supported private key type '%s' for signing", certInfo.cert.SignatureAlgorithm.String()) - } - return signer.Sign(rand.Reader, content, opts) -} - -// checkSignature checks the signature and the content. -func (sc *signingCert) checkSignature(content, signature []byte) error { - sc.RLock() - defer sc.RUnlock() - now := getNow() - var err error - for _, certInfo := range sc.certs { - // The expireTime is in descending order. So if the first one is expired, we skip the following. - if now.After(certInfo.expireTime) { - break - } - switch certInfo.privKey.(type) { - // ESDSA is special: `PrivateKey.Sign` doesn't match with `Certificate.CheckSignature`. 
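The comment just above points out that ECDSA is the odd one out: crypto.Signer signs a caller-chosen digest, while x509's CheckSignature hashes the content itself, which is why the deleted checkSignature verifies ECDSA directly with ecdsa.VerifyASN1. A standalone illustration of why the digest must match on both sides, using only the standard library:

package main

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/sha256"
	"fmt"
)

func main() {
	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		panic(err)
	}
	msg := []byte("token payload")

	// ECDSA signs a digest chosen by the caller, not the raw message.
	digest := sha256.Sum256(msg)
	sig, err := ecdsa.SignASN1(rand.Reader, priv, digest[:])
	if err != nil {
		panic(err)
	}

	// Verification only succeeds against the very same digest bytes.
	fmt.Println(ecdsa.VerifyASN1(&priv.PublicKey, digest[:], sig)) // true
	fmt.Println(ecdsa.VerifyASN1(&priv.PublicKey, msg, sig))       // false
}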
- case *ecdsa.PrivateKey: - if !ecdsa.VerifyASN1(certInfo.cert.PublicKey.(*ecdsa.PublicKey), content, signature) { - err = errors.New("x509: ECDSA verification failure") - } - default: - err = certInfo.cert.CheckSignature(certInfo.cert.SignatureAlgorithm, content, signature) - } - if err == nil { - return nil - } - } - // no certs (possible) or all certs are expired (impossible) - if err == nil { - return errors.Errorf("no valid certificate to check the signature, cached certificates: %d", len(sc.certs)) - } - return err -} - -func getNow() time.Time { - now := time.Now() - failpoint.Inject("mockNowOffset", func(val failpoint.Value) { - if s := uint64(val.(int)); s != 0 { - now = now.Add(time.Duration(s)) - } - }) - return now -} diff --git a/pkg/sessiontxn/isolation/base.go b/pkg/sessiontxn/isolation/base.go index 6eb34820c5003..7ae7502f00534 100644 --- a/pkg/sessiontxn/isolation/base.go +++ b/pkg/sessiontxn/isolation/base.go @@ -643,9 +643,9 @@ func newOracleFuture(ctx context.Context, sctx sessionctx.Context, scope string) r, ctx := tracing.StartRegionEx(ctx, "isolation.newOracleFuture") defer r.End() - if _, _err_ := failpoint.Eval(_curpkg_("requestTsoFromPD")); _err_ == nil { + failpoint.Inject("requestTsoFromPD", func() { sessiontxn.TsoRequestCountInc(sctx) - } + }) oracleStore := sctx.GetStore().GetOracle() option := &oracle.Option{TxnScope: scope} diff --git a/pkg/sessiontxn/isolation/base.go__failpoint_stash__ b/pkg/sessiontxn/isolation/base.go__failpoint_stash__ deleted file mode 100644 index 7ae7502f00534..0000000000000 --- a/pkg/sessiontxn/isolation/base.go__failpoint_stash__ +++ /dev/null @@ -1,747 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package isolation - -import ( - "context" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/domain" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/binloginfo" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/sessiontxn" - "github.com/pingcap/tidb/pkg/sessiontxn/internal" - "github.com/pingcap/tidb/pkg/sessiontxn/staleread" - "github.com/pingcap/tidb/pkg/store/driver/txn" - "github.com/pingcap/tidb/pkg/table/temptable" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/util/tableutil" - "github.com/pingcap/tidb/pkg/util/tracing" - "github.com/pingcap/tipb/go-binlog" - tikvstore "github.com/tikv/client-go/v2/kv" - "github.com/tikv/client-go/v2/oracle" -) - -// baseTxnContextProvider is a base class for the transaction context providers that implement `TxnContextProvider` in different isolation. -// It provides some common functions below: -// - Provides a default `OnInitialize` method to initialize its inner state. 
-// - Provides some methods like `activateTxn` and `prepareTxn` to manage the inner transaction. -// - Provides default methods `GetTxnInfoSchema`, `GetStmtReadTS` and `GetStmtForUpdateTS` and return the snapshot information schema or ts when `tidb_snapshot` is set. -// - Provides other default methods like `Advise`, `OnStmtStart`, `OnStmtRetry` and `OnStmtErrorForNextAction` -// -// The subclass can set some inner property of `baseTxnContextProvider` when it is constructed. -// For example, `getStmtReadTSFunc` and `getStmtForUpdateTSFunc` should be set, and they will be called when `GetStmtReadTS` -// or `GetStmtForUpdate` to get the timestamp that should be used by the corresponding isolation level. -type baseTxnContextProvider struct { - // States that should be initialized when baseTxnContextProvider is created and should not be changed after that - sctx sessionctx.Context - causalConsistencyOnly bool - onInitializeTxnCtx func(*variable.TransactionContext) - onTxnActiveFunc func(kv.Transaction, sessiontxn.EnterNewTxnType) - getStmtReadTSFunc func() (uint64, error) - getStmtForUpdateTSFunc func() (uint64, error) - - // Runtime states - ctx context.Context - infoSchema infoschema.InfoSchema - txn kv.Transaction - isTxnPrepared bool - enterNewTxnType sessiontxn.EnterNewTxnType - // constStartTS is only used by point get max ts optimization currently. - // When constStartTS != 0, we use constStartTS directly without fetching it from tso. - // To save the cpu cycles `PrepareTSFuture` will also not be called when warmup (postpone to activate txn). - constStartTS uint64 -} - -// OnInitialize is the hook that should be called when enter a new txn with this provider -func (p *baseTxnContextProvider) OnInitialize(ctx context.Context, tp sessiontxn.EnterNewTxnType) (err error) { - if p.getStmtReadTSFunc == nil || p.getStmtForUpdateTSFunc == nil { - return errors.New("ts functions should not be nil") - } - - p.ctx = ctx - sessVars := p.sctx.GetSessionVars() - activeNow := true - switch tp { - case sessiontxn.EnterNewTxnDefault: - // As we will enter a new txn, we need to commit the old txn if it's still valid. - // There are two main steps here to enter a new txn: - // 1. prepareTxnWithOracleTS - // 2. ActivateTxn - if err := internal.CommitBeforeEnterNewTxn(p.ctx, p.sctx); err != nil { - return err - } - if err := p.prepareTxnWithOracleTS(); err != nil { - return err - } - case sessiontxn.EnterNewTxnWithBeginStmt: - if !canReuseTxnWhenExplicitBegin(p.sctx) { - // As we will enter a new txn, we need to commit the old txn if it's still valid. - // There are two main steps here to enter a new txn: - // 1. prepareTxnWithOracleTS - // 2. 
ActivateTxn - if err := internal.CommitBeforeEnterNewTxn(p.ctx, p.sctx); err != nil { - return err - } - if err := p.prepareTxnWithOracleTS(); err != nil { - return err - } - } - sessVars.SetInTxn(true) - case sessiontxn.EnterNewTxnBeforeStmt: - activeNow = false - default: - return errors.Errorf("Unsupported type: %v", tp) - } - - p.enterNewTxnType = tp - p.infoSchema = p.sctx.GetDomainInfoSchema().(infoschema.InfoSchema) - txnCtx := &variable.TransactionContext{ - TxnCtxNoNeedToRestore: variable.TxnCtxNoNeedToRestore{ - CreateTime: time.Now(), - InfoSchema: p.infoSchema, - TxnScope: sessVars.CheckAndGetTxnScope(), - }, - } - if p.onInitializeTxnCtx != nil { - p.onInitializeTxnCtx(txnCtx) - } - sessVars.TxnCtxMu.Lock() - sessVars.TxnCtx = txnCtx - sessVars.TxnCtxMu.Unlock() - if variable.EnableMDL.Load() { - sessVars.TxnCtx.EnableMDL = true - } - - txn, err := p.sctx.Txn(false) - if err != nil { - return err - } - p.isTxnPrepared = txn.Valid() || p.sctx.GetPreparedTxnFuture() != nil - if activeNow { - _, err = p.ActivateTxn() - } - - return err -} - -// GetTxnInfoSchema returns the information schema used by txn -func (p *baseTxnContextProvider) GetTxnInfoSchema() infoschema.InfoSchema { - if is := p.sctx.GetSessionVars().SnapshotInfoschema; is != nil { - return is.(infoschema.InfoSchema) - } - if _, ok := p.infoSchema.(*infoschema.SessionExtendedInfoSchema); !ok { - p.infoSchema = &infoschema.SessionExtendedInfoSchema{ - InfoSchema: p.infoSchema, - } - p.sctx.GetSessionVars().TxnCtx.InfoSchema = p.infoSchema - } - return p.infoSchema -} - -// GetTxnScope returns the current txn scope -func (p *baseTxnContextProvider) GetTxnScope() string { - return p.sctx.GetSessionVars().TxnCtx.TxnScope -} - -// GetReadReplicaScope returns the read replica scope -func (p *baseTxnContextProvider) GetReadReplicaScope() string { - if txnScope := p.GetTxnScope(); txnScope != kv.GlobalTxnScope && txnScope != "" { - // In local txn, we should use txnScope as the readReplicaScope - return txnScope - } - - if p.sctx.GetSessionVars().GetReplicaRead().IsClosestRead() { - // If closest read is set, we should use the scope where instance located. - return config.GetTxnScopeFromConfig() - } - - // When it is not local txn or closet read, we should use global scope - return kv.GlobalReplicaScope -} - -// GetStmtReadTS returns the read timestamp used by select statement (not for select ... for update) -func (p *baseTxnContextProvider) GetStmtReadTS() (uint64, error) { - if _, err := p.ActivateTxn(); err != nil { - return 0, err - } - - if snapshotTS := p.sctx.GetSessionVars().SnapshotTS; snapshotTS != 0 { - return snapshotTS, nil - } - return p.getStmtReadTSFunc() -} - -// GetStmtForUpdateTS returns the read timestamp used by update/insert/delete or select ... for update -func (p *baseTxnContextProvider) GetStmtForUpdateTS() (uint64, error) { - if _, err := p.ActivateTxn(); err != nil { - return 0, err - } - - if snapshotTS := p.sctx.GetSessionVars().SnapshotTS; snapshotTS != 0 { - return snapshotTS, nil - } - return p.getStmtForUpdateTSFunc() -} - -// OnStmtStart is the hook that should be called when a new statement started -func (p *baseTxnContextProvider) OnStmtStart(ctx context.Context, _ ast.StmtNode) error { - p.ctx = ctx - return nil -} - -// OnPessimisticStmtStart is the hook that should be called when starts handling a pessimistic DML or -// a pessimistic select-for-update statements. 
-func (p *baseTxnContextProvider) OnPessimisticStmtStart(_ context.Context) error { - return nil -} - -// OnPessimisticStmtEnd is the hook that should be called when finishes handling a pessimistic DML or -// select-for-update statement. -func (p *baseTxnContextProvider) OnPessimisticStmtEnd(_ context.Context, _ bool) error { - return nil -} - -// OnStmtRetry is the hook that should be called when a statement is retried internally. -func (p *baseTxnContextProvider) OnStmtRetry(ctx context.Context) error { - p.ctx = ctx - p.sctx.GetSessionVars().TxnCtx.CurrentStmtPessimisticLockCache = nil - return nil -} - -// OnStmtCommit is the hook that should be called when a statement is executed successfully. -func (p *baseTxnContextProvider) OnStmtCommit(_ context.Context) error { - return nil -} - -// OnStmtRollback is the hook that should be called when a statement fails to execute. -func (p *baseTxnContextProvider) OnStmtRollback(_ context.Context, _ bool) error { - return nil -} - -// OnLocalTemporaryTableCreated is the hook that should be called when a local temporary table created. -func (p *baseTxnContextProvider) OnLocalTemporaryTableCreated() { - p.infoSchema = temptable.AttachLocalTemporaryTableInfoSchema(p.sctx, p.infoSchema) - p.sctx.GetSessionVars().TxnCtx.InfoSchema = p.infoSchema - if p.txn != nil && p.txn.Valid() { - if interceptor := temptable.SessionSnapshotInterceptor(p.sctx, p.infoSchema); interceptor != nil { - p.txn.SetOption(kv.SnapInterceptor, interceptor) - } - } -} - -// OnStmtErrorForNextAction is the hook that should be called when a new statement get an error -func (p *baseTxnContextProvider) OnStmtErrorForNextAction(ctx context.Context, point sessiontxn.StmtErrorHandlePoint, err error) (sessiontxn.StmtErrorAction, error) { - switch point { - case sessiontxn.StmtErrAfterPessimisticLock: - // for pessimistic lock error, return the error by default - return sessiontxn.ErrorAction(err) - default: - return sessiontxn.NoIdea() - } -} - -func (p *baseTxnContextProvider) getTxnStartTS() (uint64, error) { - txn, err := p.ActivateTxn() - if err != nil { - return 0, err - } - return txn.StartTS(), nil -} - -// ActivateTxn activates the transaction and set the relevant context variables. -func (p *baseTxnContextProvider) ActivateTxn() (kv.Transaction, error) { - if p.txn != nil { - return p.txn, nil - } - - if err := p.prepareTxn(); err != nil { - return nil, err - } - - if p.constStartTS != 0 { - if err := p.replaceTxnTsFuture(sessiontxn.ConstantFuture(p.constStartTS)); err != nil { - return nil, err - } - } - - txnFuture := p.sctx.GetPreparedTxnFuture() - txn, err := txnFuture.Wait(p.ctx, p.sctx) - if err != nil { - return nil, err - } - - sessVars := p.sctx.GetSessionVars() - sessVars.TxnCtxMu.Lock() - sessVars.TxnCtx.StartTS = txn.StartTS() - sessVars.GetRowIDShardGenerator().SetShardStep(int(sessVars.ShardAllocateStep)) - sessVars.TxnCtxMu.Unlock() - if sessVars.MemDBFootprint != nil { - sessVars.MemDBFootprint.Detach() - } - sessVars.MemDBFootprint = nil - - if p.enterNewTxnType == sessiontxn.EnterNewTxnBeforeStmt && !sessVars.IsAutocommit() && sessVars.SnapshotTS == 0 { - sessVars.SetInTxn(true) - } - - txn.SetVars(sessVars.KVVars) - - p.SetOptionsOnTxnActive(txn) - - if p.onTxnActiveFunc != nil { - p.onTxnActiveFunc(txn, p.enterNewTxnType) - } - - p.txn = txn - return txn, nil -} - -// prepareTxn prepares txn with an oracle ts future. If the snapshotTS is set, -// the txn is prepared with it. 
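ActivateTxn above only blocks when it finally waits on the prepared timestamp future, and a constant future (snapshot reads, the point-get start-ts optimization) removes even that wait. A minimal sketch of the future shape involved, with invented types in place of oracle.Future and the TSO client:

package main

import (
	"fmt"
	"time"
)

// tsFuture is a stand-in for the timestamp futures used above: the provider
// stores a future early and only calls Wait when the txn is activated.
type tsFuture interface {
	Wait() (uint64, error)
}

// constFuture resolves immediately, like a snapshot or pre-decided start ts.
type constFuture uint64

func (f constFuture) Wait() (uint64, error) { return uint64(f), nil }

// asyncFuture resolves when a background fetch finishes, standing in for a
// real timestamp-oracle request.
type asyncFuture struct{ ch chan uint64 }

func newAsyncFuture() *asyncFuture {
	f := &asyncFuture{ch: make(chan uint64, 1)}
	go func() {
		time.Sleep(5 * time.Millisecond) // pretend to ask the oracle
		f.ch <- 42
	}()
	return f
}

func (f *asyncFuture) Wait() (uint64, error) { return <-f.ch, nil }

// activate is the single blocking point, mirroring where ActivateTxn waits.
func activate(f tsFuture) {
	ts, err := f.Wait()
	fmt.Println(ts, err)
}

func main() {
	activate(constFuture(7))   // 7 <nil>, no round trip needed
	activate(newAsyncFuture()) // 42 <nil>, waits for the background fetch
}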
-func (p *baseTxnContextProvider) prepareTxn() error { - if p.isTxnPrepared { - return nil - } - - if snapshotTS := p.sctx.GetSessionVars().SnapshotTS; snapshotTS != 0 { - return p.replaceTxnTsFuture(sessiontxn.ConstantFuture(snapshotTS)) - } - - future := newOracleFuture(p.ctx, p.sctx, p.sctx.GetSessionVars().TxnCtx.TxnScope) - return p.replaceTxnTsFuture(future) -} - -// prepareTxnWithOracleTS -// The difference between prepareTxnWithOracleTS and prepareTxn is that prepareTxnWithOracleTS -// does not consider snapshotTS -func (p *baseTxnContextProvider) prepareTxnWithOracleTS() error { - if p.isTxnPrepared { - return nil - } - - future := newOracleFuture(p.ctx, p.sctx, p.sctx.GetSessionVars().TxnCtx.TxnScope) - return p.replaceTxnTsFuture(future) -} - -func (p *baseTxnContextProvider) forcePrepareConstStartTS(ts uint64) error { - if p.txn != nil { - return errors.New("cannot force prepare const start ts because txn is active") - } - p.constStartTS = ts - p.isTxnPrepared = true - return nil -} - -func (p *baseTxnContextProvider) replaceTxnTsFuture(future oracle.Future) error { - txn, err := p.sctx.Txn(false) - if err != nil { - return err - } - - if txn.Valid() { - return nil - } - - txnScope := p.sctx.GetSessionVars().TxnCtx.TxnScope - if err = p.sctx.PrepareTSFuture(p.ctx, future, txnScope); err != nil { - return err - } - - p.isTxnPrepared = true - return nil -} - -func (p *baseTxnContextProvider) isTidbSnapshotEnabled() bool { - return p.sctx.GetSessionVars().SnapshotTS != 0 -} - -// isBeginStmtWithStaleRead indicates whether the current statement is `BeginStmt` type with stale read -// Because stale read will use `staleread.StalenessTxnContextProvider` for query, so if `staleread.IsStmtStaleness()` -// returns true in other providers, it means the current statement is `BeginStmt` with stale read -func (p *baseTxnContextProvider) isBeginStmtWithStaleRead() bool { - return staleread.IsStmtStaleness(p.sctx) -} - -// AdviseWarmup provides warmup for inner state -func (p *baseTxnContextProvider) AdviseWarmup() error { - if p.isTxnPrepared || p.isBeginStmtWithStaleRead() { - // When executing `START TRANSACTION READ ONLY AS OF ...` no need to warmUp - return nil - } - return p.prepareTxn() -} - -// AdviseOptimizeWithPlan providers optimization according to the plan -func (p *baseTxnContextProvider) AdviseOptimizeWithPlan(_ any) error { - return nil -} - -// GetSnapshotWithStmtReadTS gets snapshot with read ts -func (p *baseTxnContextProvider) GetSnapshotWithStmtReadTS() (kv.Snapshot, error) { - ts, err := p.GetStmtReadTS() - if err != nil { - return nil, err - } - - return p.getSnapshotByTS(ts) -} - -// GetSnapshotWithStmtForUpdateTS gets snapshot with for update ts -func (p *baseTxnContextProvider) GetSnapshotWithStmtForUpdateTS() (kv.Snapshot, error) { - ts, err := p.GetStmtForUpdateTS() - if err != nil { - return nil, err - } - - return p.getSnapshotByTS(ts) -} - -// getSnapshotByTS get snapshot from store according to the snapshotTS and set the transaction related -// options before return -func (p *baseTxnContextProvider) getSnapshotByTS(snapshotTS uint64) (kv.Snapshot, error) { - txn, err := p.sctx.Txn(false) - if err != nil { - return nil, err - } - - txnCtx := p.sctx.GetSessionVars().TxnCtx - if txn.Valid() && txnCtx.StartTS == txnCtx.GetForUpdateTS() && txnCtx.StartTS == snapshotTS { - return txn.GetSnapshot(), nil - } - - sessVars := p.sctx.GetSessionVars() - snapshot := internal.GetSnapshotWithTS( - p.sctx, - snapshotTS, - temptable.SessionSnapshotInterceptor(p.sctx, 
p.infoSchema), - ) - - replicaReadType := sessVars.GetReplicaRead() - if replicaReadType.IsFollowerRead() && - !sessVars.StmtCtx.RCCheckTS && - !sessVars.RcWriteCheckTS { - snapshot.SetOption(kv.ReplicaRead, replicaReadType) - } - - return snapshot, nil -} - -func (p *baseTxnContextProvider) SetOptionsOnTxnActive(txn kv.Transaction) { - sessVars := p.sctx.GetSessionVars() - - readReplicaType := sessVars.GetReplicaRead() - if readReplicaType.IsFollowerRead() { - txn.SetOption(kv.ReplicaRead, readReplicaType) - } - - if interceptor := temptable.SessionSnapshotInterceptor( - p.sctx, - p.infoSchema, - ); interceptor != nil { - txn.SetOption(kv.SnapInterceptor, interceptor) - } - - if sessVars.StmtCtx.WeakConsistency { - txn.SetOption(kv.IsolationLevel, kv.RC) - } - - internal.SetTxnAssertionLevel(txn, sessVars.AssertionLevel) - - if p.sctx.GetSessionVars().InRestrictedSQL { - txn.SetOption(kv.RequestSourceInternal, true) - } - - if txn.IsPipelined() { - txn.SetOption(kv.RequestSourceType, "p-dml") - } else if tp := p.sctx.GetSessionVars().RequestSourceType; tp != "" { - txn.SetOption(kv.RequestSourceType, tp) - } - - if sessVars.LoadBasedReplicaReadThreshold > 0 { - txn.SetOption(kv.LoadBasedReplicaReadThreshold, sessVars.LoadBasedReplicaReadThreshold) - } - - txn.SetOption(kv.CommitHook, func(info string, _ error) { sessVars.LastTxnInfo = info }) - txn.SetOption(kv.EnableAsyncCommit, sessVars.EnableAsyncCommit) - txn.SetOption(kv.Enable1PC, sessVars.Enable1PC) - if sessVars.DiskFullOpt != kvrpcpb.DiskFullOpt_NotAllowedOnFull { - txn.SetDiskFullOpt(sessVars.DiskFullOpt) - } - txn.SetOption(kv.InfoSchema, sessVars.TxnCtx.InfoSchema) - if sessVars.StmtCtx.KvExecCounter != nil { - // Bind an interceptor for client-go to count the number of SQL executions of each TiKV. - txn.SetOption(kv.RPCInterceptor, sessVars.StmtCtx.KvExecCounter.RPCInterceptor()) - } - txn.SetOption(kv.ResourceGroupTagger, sessVars.StmtCtx.GetResourceGroupTagger()) - txn.SetOption(kv.ExplicitRequestSourceType, sessVars.ExplicitRequestSourceType) - - if p.causalConsistencyOnly || !sessVars.GuaranteeLinearizability { - // priority of the sysvar is lower than `start transaction with causal consistency only` - txn.SetOption(kv.GuaranteeLinearizability, false) - } else { - // We needn't ask the TiKV client to guarantee linearizability for auto-commit transactions - // because the property is naturally holds: - // We guarantee the commitTS of any transaction must not exceed the next timestamp from the TSO. - // An auto-commit transaction fetches its startTS from the TSO so its commitTS > its startTS > the commitTS - // of any previously committed transactions. - // Additionally, it's required to guarantee linearizability for snapshot read-only transactions though - // it does take effects on read-only transactions now. - txn.SetOption( - kv.GuaranteeLinearizability, - !sessVars.IsAutocommit() || - sessVars.SnapshotTS > 0 || - p.enterNewTxnType == sessiontxn.EnterNewTxnDefault || - p.enterNewTxnType == sessiontxn.EnterNewTxnWithBeginStmt, - ) - } - - txn.SetOption(kv.SessionID, p.sctx.GetSessionVars().ConnectionID) -} - -func (p *baseTxnContextProvider) SetOptionsBeforeCommit( - txn kv.Transaction, commitTSChecker func(uint64) bool, -) error { - sessVars := p.sctx.GetSessionVars() - // Pipelined dml txn already flushed mutations into stores, so we don't need to set options for them. - // Instead, some invariants must be checked to avoid anomalies though are unreachable in designed usages. 
- if p.txn.IsPipelined() { - if p.txn.IsPipelined() && !sessVars.TxnCtx.EnableMDL { - return errors.New("cannot commit pipelined transaction without Metadata Lock: MDL is OFF") - } - if len(sessVars.TxnCtx.TemporaryTables) > 0 { - return errors.New("pipelined dml with temporary tables is not allowed") - } - if sessVars.BinlogClient != nil { - return errors.New("pipelined dml with binlog is not allowed") - } - if sessVars.CDCWriteSource != 0 { - return errors.New("pipelined dml with CDC source is not allowed") - } - if commitTSChecker != nil { - return errors.New("pipelined dml with commitTS checker is not allowed") - } - return nil - } - - // set resource tagger again for internal tasks separated in different transactions - txn.SetOption(kv.ResourceGroupTagger, sessVars.StmtCtx.GetResourceGroupTagger()) - - // Get the related table or partition IDs. - relatedPhysicalTables := sessVars.TxnCtx.TableDeltaMap - // Get accessed temporary tables in the transaction. - temporaryTables := sessVars.TxnCtx.TemporaryTables - physicalTableIDs := make([]int64, 0, len(relatedPhysicalTables)) - for id := range relatedPhysicalTables { - // Schema change on global temporary tables doesn't affect transactions. - if _, ok := temporaryTables[id]; ok { - continue - } - physicalTableIDs = append(physicalTableIDs, id) - } - needCheckSchema := true - // Set this option for 2 phase commit to validate schema lease. - if sessVars.TxnCtx != nil { - needCheckSchema = !sessVars.TxnCtx.EnableMDL - } - - // TODO: refactor SetOption usage to avoid race risk, should detect it in test. - // The pipelined txn will may be flushed in background, not touch the options to avoid races. - // to avoid session set overlap the txn set. - txn.SetOption( - kv.SchemaChecker, - domain.NewSchemaChecker( - domain.GetDomain(p.sctx), - p.GetTxnInfoSchema().SchemaMetaVersion(), - physicalTableIDs, - needCheckSchema, - ), - ) - - if sessVars.StmtCtx.KvExecCounter != nil { - // Bind an interceptor for client-go to count the number of SQL executions of each TiKV. - txn.SetOption(kv.RPCInterceptor, sessVars.StmtCtx.KvExecCounter.RPCInterceptor()) - } - - if tables := sessVars.TxnCtx.TemporaryTables; len(tables) > 0 { - txn.SetOption(kv.KVFilter, temporaryTableKVFilter(tables)) - } - - if sessVars.BinlogClient != nil { - prewriteValue := binloginfo.GetPrewriteValue(p.sctx, false) - if prewriteValue != nil { - prewriteData, err := prewriteValue.Marshal() - if err != nil { - return errors.Trace(err) - } - info := &binloginfo.BinlogInfo{ - Data: &binlog.Binlog{ - Tp: binlog.BinlogType_Prewrite, - PrewriteValue: prewriteData, - }, - Client: sessVars.BinlogClient, - } - txn.SetOption(kv.BinlogInfo, info) - } - } - - var txnSource uint64 - if val := txn.GetOption(kv.TxnSource); val != nil { - txnSource, _ = val.(uint64) - } - // If the transaction is started by CDC, we need to set the CDCWriteSource option. 
- if sessVars.CDCWriteSource != 0 { - err := kv.SetCDCWriteSource(&txnSource, sessVars.CDCWriteSource) - if err != nil { - return errors.Trace(err) - } - - txn.SetOption(kv.TxnSource, txnSource) - } - - if commitTSChecker != nil { - txn.SetOption(kv.CommitTSUpperBoundCheck, commitTSChecker) - } - return nil -} - -// canReuseTxnWhenExplicitBegin returns whether we should reuse the txn when starting a transaction explicitly -func canReuseTxnWhenExplicitBegin(sctx sessionctx.Context) bool { - sessVars := sctx.GetSessionVars() - txnCtx := sessVars.TxnCtx - // If BEGIN is the first statement in TxnCtx, we can reuse the existing transaction, without the - // need to call NewTxn, which commits the existing transaction and begins a new one. - // If the last un-committed/un-rollback transaction is a time-bounded read-only transaction, we should - // always create a new transaction. - // If the variable `tidb_snapshot` is set, we should always create a new transaction because the current txn may be - // initialized with snapshot ts. - return txnCtx.History == nil && !txnCtx.IsStaleness && sessVars.SnapshotTS == 0 -} - -// newOracleFuture creates new future according to the scope and the session context -func newOracleFuture(ctx context.Context, sctx sessionctx.Context, scope string) oracle.Future { - r, ctx := tracing.StartRegionEx(ctx, "isolation.newOracleFuture") - defer r.End() - - failpoint.Inject("requestTsoFromPD", func() { - sessiontxn.TsoRequestCountInc(sctx) - }) - - oracleStore := sctx.GetStore().GetOracle() - option := &oracle.Option{TxnScope: scope} - - if sctx.GetSessionVars().UseLowResolutionTSO() { - return oracleStore.GetLowResolutionTimestampAsync(ctx, option) - } - return oracleStore.GetTimestampAsync(ctx, option) -} - -// funcFuture implements oracle.Future -type funcFuture func() (uint64, error) - -// Wait returns a ts got from the func -func (f funcFuture) Wait() (uint64, error) { - return f() -} - -// basePessimisticTxnContextProvider extends baseTxnContextProvider with some functionalities that are commonly used in -// pessimistic transactions. -type basePessimisticTxnContextProvider struct { - baseTxnContextProvider -} - -// OnPessimisticStmtStart is the hook that should be called when starts handling a pessimistic DML or -// a pessimistic select-for-update statements. -func (p *basePessimisticTxnContextProvider) OnPessimisticStmtStart(ctx context.Context) error { - if err := p.baseTxnContextProvider.OnPessimisticStmtStart(ctx); err != nil { - return err - } - if p.sctx.GetSessionVars().PessimisticTransactionFairLocking && - p.txn != nil && - p.sctx.GetSessionVars().ConnectionID != 0 && - !p.sctx.GetSessionVars().InRestrictedSQL { - if err := p.txn.StartFairLocking(); err != nil { - return err - } - } - return nil -} - -// OnPessimisticStmtEnd is the hook that should be called when finishes handling a pessimistic DML or -// select-for-update statement. 
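The pessimistic hooks around this point start fair locking before a DML or select-for-update statement and then either finish or cancel it depending on whether the statement succeeded. A small sketch of that bracketing, with a made-up locker type instead of the kv.Transaction interface:

package main

import (
	"errors"
	"fmt"
)

// fairLocker is an invented stand-in for the transaction's fair-locking mode.
type fairLocker struct{ inFairLocking bool }

func (f *fairLocker) Start()  { f.inFairLocking = true }
func (f *fairLocker) Done()   { f.inFairLocking = false }
func (f *fairLocker) Cancel() { f.inFairLocking = false }

// runPessimisticStmt brackets a statement the way the hooks above do: enter
// fair locking first, then Done on success or Cancel on failure, so the mode
// never leaks into the next statement.
func runPessimisticStmt(f *fairLocker, stmt func() error) error {
	f.Start()
	if err := stmt(); err != nil {
		f.Cancel()
		return err
	}
	f.Done()
	return nil
}

func main() {
	f := &fairLocker{}
	fmt.Println(runPessimisticStmt(f, func() error { return nil }))             // <nil>
	fmt.Println(runPessimisticStmt(f, func() error { return errors.New("x") })) // x
	fmt.Println(f.inFairLocking)                                                // false on both paths
}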
-func (p *basePessimisticTxnContextProvider) OnPessimisticStmtEnd(ctx context.Context, isSuccessful bool) error { - if err := p.baseTxnContextProvider.OnPessimisticStmtEnd(ctx, isSuccessful); err != nil { - return err - } - if p.txn != nil && p.txn.IsInFairLockingMode() { - if isSuccessful { - if err := p.txn.DoneFairLocking(ctx); err != nil { - return err - } - } else { - if err := p.txn.CancelFairLocking(ctx); err != nil { - return err - } - } - } - - if isSuccessful { - p.sctx.GetSessionVars().TxnCtx.FlushStmtPessimisticLockCache() - } else { - p.sctx.GetSessionVars().TxnCtx.CurrentStmtPessimisticLockCache = nil - } - return nil -} - -func (p *basePessimisticTxnContextProvider) retryFairLockingIfNeeded(ctx context.Context) error { - if p.txn != nil && p.txn.IsInFairLockingMode() { - if err := p.txn.RetryFairLocking(ctx); err != nil { - return err - } - } - return nil -} - -func (p *basePessimisticTxnContextProvider) cancelFairLockingIfNeeded(ctx context.Context) error { - if p.txn != nil && p.txn.IsInFairLockingMode() { - if err := p.txn.CancelFairLocking(ctx); err != nil { - return err - } - } - return nil -} - -type temporaryTableKVFilter map[int64]tableutil.TempTable - -func (m temporaryTableKVFilter) IsUnnecessaryKeyValue( - key, value []byte, flags tikvstore.KeyFlags, -) (bool, error) { - tid := tablecodec.DecodeTableID(key) - if _, ok := m[tid]; ok { - return true, nil - } - - // This is the default filter for all tables. - defaultFilter := txn.TiDBKVFilter{} - return defaultFilter.IsUnnecessaryKeyValue(key, value, flags) -} diff --git a/pkg/sessiontxn/isolation/binding__failpoint_binding__.go b/pkg/sessiontxn/isolation/binding__failpoint_binding__.go deleted file mode 100644 index 825cd58da0304..0000000000000 --- a/pkg/sessiontxn/isolation/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package isolation - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/sessiontxn/isolation/readcommitted.go b/pkg/sessiontxn/isolation/readcommitted.go index aa3fed63bbb5c..7cd946df70ac9 100644 --- a/pkg/sessiontxn/isolation/readcommitted.go +++ b/pkg/sessiontxn/isolation/readcommitted.go @@ -131,9 +131,9 @@ func (p *PessimisticRCTxnContextProvider) OnStmtRetry(ctx context.Context) error if err := p.basePessimisticTxnContextProvider.OnStmtRetry(ctx); err != nil { return err } - if _, _err_ := failpoint.Eval(_curpkg_("CallOnStmtRetry")); _err_ == nil { + failpoint.Inject("CallOnStmtRetry", func() { sessiontxn.OnStmtRetryCountInc(p.sctx) - } + }) p.latestOracleTSValid = false p.checkTSInWriteStmt = false return p.prepareStmt(false) @@ -164,9 +164,9 @@ func (p *PessimisticRCTxnContextProvider) getOracleFuture() funcFuture { if ts, err = future.Wait(); err != nil { return } - if _, _err_ := failpoint.Eval(_curpkg_("waitTsoOfOracleFuture")); _err_ == nil { + failpoint.Inject("waitTsoOfOracleFuture", func() { sessiontxn.TsoWaitCountInc(p.sctx) - } + }) txnCtx.SetForUpdateTS(ts) ts = txnCtx.GetForUpdateTS() p.latestOracleTS = ts @@ -318,9 +318,9 @@ func (p *PessimisticRCTxnContextProvider) AdviseOptimizeWithPlan(val any) (err e } if useLastOracleTS { - if _, _err_ := failpoint.Eval(_curpkg_("tsoUseConstantFuture")); _err_ == nil { + failpoint.Inject("tsoUseConstantFuture", func() { 
sessiontxn.TsoUseConstantCountInc(p.sctx) - } + }) p.checkTSInWriteStmt = true p.stmtTSFuture = sessiontxn.ConstantFuture(p.latestOracleTS) } diff --git a/pkg/sessiontxn/isolation/readcommitted.go__failpoint_stash__ b/pkg/sessiontxn/isolation/readcommitted.go__failpoint_stash__ deleted file mode 100644 index 7cd946df70ac9..0000000000000 --- a/pkg/sessiontxn/isolation/readcommitted.go__failpoint_stash__ +++ /dev/null @@ -1,360 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package isolation - -import ( - "context" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/terror" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/sessiontxn" - isolation_metrics "github.com/pingcap/tidb/pkg/sessiontxn/isolation/metrics" - "github.com/pingcap/tidb/pkg/util/logutil" - tikverr "github.com/tikv/client-go/v2/error" - "github.com/tikv/client-go/v2/oracle" - "go.uber.org/zap" -) - -type stmtState struct { - stmtTS uint64 - stmtTSFuture oracle.Future - stmtUseStartTS bool -} - -func (s *stmtState) prepareStmt(useStartTS bool) error { - *s = stmtState{ - stmtUseStartTS: useStartTS, - } - return nil -} - -// PessimisticRCTxnContextProvider provides txn context for isolation level read-committed -type PessimisticRCTxnContextProvider struct { - basePessimisticTxnContextProvider - stmtState - latestOracleTS uint64 - // latestOracleTSValid shows whether we have already fetched a ts from pd and whether the ts we fetched is still valid. 
- latestOracleTSValid bool - // checkTSInWriteStmt is used to set RCCheckTS isolation for getting value when doing point-write - checkTSInWriteStmt bool -} - -// NewPessimisticRCTxnContextProvider returns a new PessimisticRCTxnContextProvider -func NewPessimisticRCTxnContextProvider(sctx sessionctx.Context, causalConsistencyOnly bool) *PessimisticRCTxnContextProvider { - provider := &PessimisticRCTxnContextProvider{ - basePessimisticTxnContextProvider: basePessimisticTxnContextProvider{ - baseTxnContextProvider: baseTxnContextProvider{ - sctx: sctx, - causalConsistencyOnly: causalConsistencyOnly, - onInitializeTxnCtx: func(txnCtx *variable.TransactionContext) { - txnCtx.IsPessimistic = true - txnCtx.Isolation = ast.ReadCommitted - }, - onTxnActiveFunc: func(txn kv.Transaction, _ sessiontxn.EnterNewTxnType) { - txn.SetOption(kv.Pessimistic, true) - }, - }, - }, - } - - provider.onTxnActiveFunc = func(txn kv.Transaction, _ sessiontxn.EnterNewTxnType) { - txn.SetOption(kv.Pessimistic, true) - provider.latestOracleTS = txn.StartTS() - provider.latestOracleTSValid = true - } - provider.getStmtReadTSFunc = provider.getStmtTS - provider.getStmtForUpdateTSFunc = provider.getStmtTS - return provider -} - -// OnStmtStart is the hook that should be called when a new statement started -func (p *PessimisticRCTxnContextProvider) OnStmtStart(ctx context.Context, node ast.StmtNode) error { - if err := p.basePessimisticTxnContextProvider.OnStmtStart(ctx, node); err != nil { - return err - } - - // Try to mark the `RCCheckTS` flag for the first time execution of in-transaction read requests - // using read-consistency isolation level. - if node != nil && NeedSetRCCheckTSFlag(p.sctx, node) { - p.sctx.GetSessionVars().StmtCtx.RCCheckTS = true - } - p.checkTSInWriteStmt = false - - return p.prepareStmt(!p.isTxnPrepared) -} - -// NeedSetRCCheckTSFlag checks whether it's needed to set `RCCheckTS` flag in current stmtctx. -func NeedSetRCCheckTSFlag(ctx sessionctx.Context, node ast.Node) bool { - sessionVars := ctx.GetSessionVars() - if sessionVars.ConnectionID > 0 && variable.EnableRCReadCheckTS.Load() && - sessionVars.InTxn() && !sessionVars.RetryInfo.Retrying && - plannercore.IsReadOnly(node, sessionVars) { - return true - } - return false -} - -// OnStmtErrorForNextAction is the hook that should be called when a new statement get an error -func (p *PessimisticRCTxnContextProvider) OnStmtErrorForNextAction(ctx context.Context, point sessiontxn.StmtErrorHandlePoint, err error) (sessiontxn.StmtErrorAction, error) { - switch point { - case sessiontxn.StmtErrAfterQuery: - return p.handleAfterQueryError(err) - case sessiontxn.StmtErrAfterPessimisticLock: - return p.handleAfterPessimisticLockError(ctx, err) - default: - return p.basePessimisticTxnContextProvider.OnStmtErrorForNextAction(ctx, point, err) - } -} - -// OnStmtRetry is the hook that should be called when a statement is retried internally. 
-func (p *PessimisticRCTxnContextProvider) OnStmtRetry(ctx context.Context) error { - if err := p.basePessimisticTxnContextProvider.OnStmtRetry(ctx); err != nil { - return err - } - failpoint.Inject("CallOnStmtRetry", func() { - sessiontxn.OnStmtRetryCountInc(p.sctx) - }) - p.latestOracleTSValid = false - p.checkTSInWriteStmt = false - return p.prepareStmt(false) -} - -func (p *PessimisticRCTxnContextProvider) prepareStmtTS() { - if p.stmtTSFuture != nil { - return - } - sessVars := p.sctx.GetSessionVars() - var stmtTSFuture oracle.Future - switch { - case p.stmtUseStartTS: - stmtTSFuture = funcFuture(p.getTxnStartTS) - case p.latestOracleTSValid && sessVars.StmtCtx.RCCheckTS: - stmtTSFuture = sessiontxn.ConstantFuture(p.latestOracleTS) - default: - stmtTSFuture = p.getOracleFuture() - } - - p.stmtTSFuture = stmtTSFuture -} - -func (p *PessimisticRCTxnContextProvider) getOracleFuture() funcFuture { - txnCtx := p.sctx.GetSessionVars().TxnCtx - future := newOracleFuture(p.ctx, p.sctx, txnCtx.TxnScope) - return func() (ts uint64, err error) { - if ts, err = future.Wait(); err != nil { - return - } - failpoint.Inject("waitTsoOfOracleFuture", func() { - sessiontxn.TsoWaitCountInc(p.sctx) - }) - txnCtx.SetForUpdateTS(ts) - ts = txnCtx.GetForUpdateTS() - p.latestOracleTS = ts - p.latestOracleTSValid = true - return - } -} - -func (p *PessimisticRCTxnContextProvider) getStmtTS() (ts uint64, err error) { - if p.stmtTS != 0 { - return p.stmtTS, nil - } - - var txn kv.Transaction - if txn, err = p.ActivateTxn(); err != nil { - return 0, err - } - - p.prepareStmtTS() - start := time.Now() - if ts, err = p.stmtTSFuture.Wait(); err != nil { - return 0, err - } - p.sctx.GetSessionVars().DurationWaitTS += time.Since(start) - - txn.SetOption(kv.SnapshotTS, ts) - p.stmtTS = ts - return -} - -// handleAfterQueryError will be called when the handle point is `StmtErrAfterQuery`. -// At this point the query will be retried from the beginning. -func (p *PessimisticRCTxnContextProvider) handleAfterQueryError(queryErr error) (sessiontxn.StmtErrorAction, error) { - sessVars := p.sctx.GetSessionVars() - if !errors.ErrorEqual(queryErr, kv.ErrWriteConflict) || !sessVars.StmtCtx.RCCheckTS { - return sessiontxn.NoIdea() - } - - isolation_metrics.RcReadCheckTSWriteConfilictCounter.Inc() - - logutil.Logger(p.ctx).Info("RC read with ts checking has failed, retry RC read", - zap.String("sql", sessVars.StmtCtx.OriginalSQL), zap.Error(queryErr)) - return sessiontxn.RetryReady() -} - -func (p *PessimisticRCTxnContextProvider) handleAfterPessimisticLockError(ctx context.Context, lockErr error) (sessiontxn.StmtErrorAction, error) { - txnCtx := p.sctx.GetSessionVars().TxnCtx - retryable := false - if deadlock, ok := errors.Cause(lockErr).(*tikverr.ErrDeadlock); ok && deadlock.IsRetryable { - logutil.Logger(p.ctx).Info("single statement deadlock, retry statement", - zap.Uint64("txn", txnCtx.StartTS), - zap.Uint64("lockTS", deadlock.LockTs), - zap.Stringer("lockKey", kv.Key(deadlock.LockKey)), - zap.Uint64("deadlockKeyHash", deadlock.DeadlockKeyHash)) - retryable = true - - // In fair locking mode, when statement retry happens, `retryFairLockingIfNeeded` should be - // called to make its state ready for retrying. But single-statement deadlock is an exception. We need to exit - // fair locking in single-statement-deadlock case, otherwise the lock this statement has acquired won't be - // released after retrying, so it still blocks another transaction and the deadlock won't be resolved. 
- if err := p.cancelFairLockingIfNeeded(ctx); err != nil { - return sessiontxn.ErrorAction(err) - } - } else if terror.ErrorEqual(kv.ErrWriteConflict, lockErr) { - logutil.Logger(p.ctx).Debug("pessimistic write conflict, retry statement", - zap.Uint64("txn", txnCtx.StartTS), - zap.Uint64("forUpdateTS", txnCtx.GetForUpdateTS()), - zap.String("err", lockErr.Error())) - retryable = true - if p.checkTSInWriteStmt { - isolation_metrics.RcWriteCheckTSWriteConfilictCounter.Inc() - } - } - - if retryable { - if err := p.basePessimisticTxnContextProvider.retryFairLockingIfNeeded(ctx); err != nil { - return sessiontxn.ErrorAction(err) - } - return sessiontxn.RetryReady() - } - return sessiontxn.ErrorAction(lockErr) -} - -// AdviseWarmup provides warmup for inner state -func (p *PessimisticRCTxnContextProvider) AdviseWarmup() error { - if err := p.prepareTxn(); err != nil { - return err - } - - if !p.isTidbSnapshotEnabled() { - p.prepareStmtTS() - } - - return nil -} - -// planSkipGetTsoFromPD identifies the plans which don't need get newest ts from PD. -func planSkipGetTsoFromPD(sctx sessionctx.Context, plan base.Plan, inLockOrWriteStmt bool) bool { - switch v := plan.(type) { - case *plannercore.PointGetPlan: - return sctx.GetSessionVars().RcWriteCheckTS && (v.Lock || inLockOrWriteStmt) - case base.PhysicalPlan: - if len(v.Children()) == 0 { - return false - } - _, isPhysicalLock := v.(*plannercore.PhysicalLock) - for _, p := range v.Children() { - if !planSkipGetTsoFromPD(sctx, p, isPhysicalLock || inLockOrWriteStmt) { - return false - } - } - return true - case *plannercore.Update: - return planSkipGetTsoFromPD(sctx, v.SelectPlan, true) - case *plannercore.Delete: - return planSkipGetTsoFromPD(sctx, v.SelectPlan, true) - case *plannercore.Insert: - return v.SelectPlan == nil && len(v.OnDuplicate) == 0 && !v.IsReplace - } - return false -} - -// AdviseOptimizeWithPlan in read-committed covers as many cases as repeatable-read. -// We do not fetch latest ts immediately for such scenes. -// 1. A query like the form of "SELECT ... FOR UPDATE" whose execution plan is "PointGet". -// 2. An INSERT statement without "SELECT" subquery. -// 3. A UPDATE statement whose sub execution plan is "PointGet". -// 4. A DELETE statement whose sub execution plan is "PointGet". 
-func (p *PessimisticRCTxnContextProvider) AdviseOptimizeWithPlan(val any) (err error) { - if p.isTidbSnapshotEnabled() || p.isBeginStmtWithStaleRead() { - return nil - } - if p.stmtUseStartTS || !p.latestOracleTSValid { - return nil - } - - plan, ok := val.(base.Plan) - if !ok { - return nil - } - - if execute, ok := plan.(*plannercore.Execute); ok { - plan = execute.Plan - } - - useLastOracleTS := false - if !p.sctx.GetSessionVars().RetryInfo.Retrying { - useLastOracleTS = planSkipGetTsoFromPD(p.sctx, plan, false) - } - - if useLastOracleTS { - failpoint.Inject("tsoUseConstantFuture", func() { - sessiontxn.TsoUseConstantCountInc(p.sctx) - }) - p.checkTSInWriteStmt = true - p.stmtTSFuture = sessiontxn.ConstantFuture(p.latestOracleTS) - } - - return nil -} - -// GetSnapshotWithStmtForUpdateTS gets snapshot with for update ts -func (p *PessimisticRCTxnContextProvider) GetSnapshotWithStmtForUpdateTS() (kv.Snapshot, error) { - snapshot, err := p.basePessimisticTxnContextProvider.GetSnapshotWithStmtForUpdateTS() - if err != nil { - return nil, err - } - if p.checkTSInWriteStmt { - snapshot.SetOption(kv.IsolationLevel, kv.RCCheckTS) - } - return snapshot, err -} - -// GetSnapshotWithStmtReadTS gets snapshot with read ts -func (p *PessimisticRCTxnContextProvider) GetSnapshotWithStmtReadTS() (kv.Snapshot, error) { - snapshot, err := p.basePessimisticTxnContextProvider.GetSnapshotWithStmtForUpdateTS() - if err != nil { - return nil, err - } - - if p.sctx.GetSessionVars().StmtCtx.RCCheckTS { - snapshot.SetOption(kv.IsolationLevel, kv.RCCheckTS) - } - - return snapshot, nil -} - -// IsCheckTSInWriteStmtMode is only used for test -func (p *PessimisticRCTxnContextProvider) IsCheckTSInWriteStmtMode() bool { - return p.checkTSInWriteStmt -} diff --git a/pkg/sessiontxn/isolation/repeatable_read.go b/pkg/sessiontxn/isolation/repeatable_read.go index 077815399acbc..55f80568f1f88 100644 --- a/pkg/sessiontxn/isolation/repeatable_read.go +++ b/pkg/sessiontxn/isolation/repeatable_read.go @@ -114,9 +114,9 @@ func (p *PessimisticRRTxnContextProvider) updateForUpdateTS() (err error) { return errors.Trace(kv.ErrInvalidTxn) } - if _, _err_ := failpoint.Eval(_curpkg_("RequestTsoFromPD")); _err_ == nil { + failpoint.Inject("RequestTsoFromPD", func() { sessiontxn.TsoRequestCountInc(sctx) - } + }) // Because the ForUpdateTS is used for the snapshot for reading data in DML. // We can avoid allocating a global TSO here to speed it up by using the local TSO. diff --git a/pkg/sessiontxn/isolation/repeatable_read.go__failpoint_stash__ b/pkg/sessiontxn/isolation/repeatable_read.go__failpoint_stash__ deleted file mode 100644 index 55f80568f1f88..0000000000000 --- a/pkg/sessiontxn/isolation/repeatable_read.go__failpoint_stash__ +++ /dev/null @@ -1,309 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
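The hunks above restore the failpoint.Inject markers that the build-time failpoint rewriter had expanded into failpoint.Eval(_curpkg_(...)) calls, and drop the generated binding and stash files. As a rough sketch, not part of this patch, of how such a marker is typically exercised from a test; the failpoint path and the test body here are assumptions:

package isolation_test

import (
	"testing"

	"github.com/pingcap/failpoint"
)

// Assumed failpoint path: _curpkg_("CallOnStmtRetry") resolves to the package
// import path plus the failpoint name; adjust to the real path before use.
const callOnStmtRetryFP = "github.com/pingcap/tidb/pkg/sessiontxn/isolation/CallOnStmtRetry"

func TestStmtRetryFailpointSketch(t *testing.T) {
	if err := failpoint.Enable(callOnStmtRetryFP, "return(true)"); err != nil {
		t.Fatal(err)
	}
	defer func() {
		if err := failpoint.Disable(callOnStmtRetryFP); err != nil {
			t.Fatal(err)
		}
	}()
	// Run a statement that retries internally here; with the failpoint enabled,
	// the injected callback increments the session's retry counter, which the
	// test can then assert on.
}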
- -package isolation - -import ( - "context" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/terror" - plannercore "github.com/pingcap/tidb/pkg/planner/core" - "github.com/pingcap/tidb/pkg/planner/core/base" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/sessiontxn" - "github.com/pingcap/tidb/pkg/util/logutil" - tikverr "github.com/tikv/client-go/v2/error" - "go.uber.org/zap" -) - -// PessimisticRRTxnContextProvider provides txn context for isolation level repeatable-read -type PessimisticRRTxnContextProvider struct { - basePessimisticTxnContextProvider - - // Used for ForUpdateRead statement - forUpdateTS uint64 - latestForUpdateTS uint64 - // It may decide whether to update forUpdateTs when calling provider's getForUpdateTs - // See more details in the comments of optimizeWithPlan - optimizeForNotFetchingLatestTS bool -} - -// NewPessimisticRRTxnContextProvider returns a new PessimisticRRTxnContextProvider -func NewPessimisticRRTxnContextProvider(sctx sessionctx.Context, causalConsistencyOnly bool) *PessimisticRRTxnContextProvider { - provider := &PessimisticRRTxnContextProvider{ - basePessimisticTxnContextProvider: basePessimisticTxnContextProvider{ - baseTxnContextProvider: baseTxnContextProvider{ - sctx: sctx, - causalConsistencyOnly: causalConsistencyOnly, - onInitializeTxnCtx: func(txnCtx *variable.TransactionContext) { - txnCtx.IsPessimistic = true - txnCtx.Isolation = ast.RepeatableRead - }, - onTxnActiveFunc: func(txn kv.Transaction, _ sessiontxn.EnterNewTxnType) { - txn.SetOption(kv.Pessimistic, true) - }, - }, - }, - } - - provider.getStmtReadTSFunc = provider.getTxnStartTS - provider.getStmtForUpdateTSFunc = provider.getForUpdateTs - - return provider -} - -func (p *PessimisticRRTxnContextProvider) getForUpdateTs() (ts uint64, err error) { - if p.forUpdateTS != 0 { - return p.forUpdateTS, nil - } - - var txn kv.Transaction - if txn, err = p.ActivateTxn(); err != nil { - return 0, err - } - - if p.optimizeForNotFetchingLatestTS { - p.forUpdateTS = p.sctx.GetSessionVars().TxnCtx.GetForUpdateTS() - return p.forUpdateTS, nil - } - - txnCtx := p.sctx.GetSessionVars().TxnCtx - futureTS := newOracleFuture(p.ctx, p.sctx, txnCtx.TxnScope) - - start := time.Now() - if ts, err = futureTS.Wait(); err != nil { - return 0, err - } - p.sctx.GetSessionVars().DurationWaitTS += time.Since(start) - - txnCtx.SetForUpdateTS(ts) - txn.SetOption(kv.SnapshotTS, ts) - - p.forUpdateTS = ts - - return -} - -// updateForUpdateTS acquires the latest TSO and update the TransactionContext and kv.Transaction with it. -func (p *PessimisticRRTxnContextProvider) updateForUpdateTS() (err error) { - sctx := p.sctx - var txn kv.Transaction - - if txn, err = sctx.Txn(false); err != nil { - return err - } - - if !txn.Valid() { - return errors.Trace(kv.ErrInvalidTxn) - } - - failpoint.Inject("RequestTsoFromPD", func() { - sessiontxn.TsoRequestCountInc(sctx) - }) - - // Because the ForUpdateTS is used for the snapshot for reading data in DML. - // We can avoid allocating a global TSO here to speed it up by using the local TSO. 
- version, err := sctx.GetStore().CurrentVersion(sctx.GetSessionVars().TxnCtx.TxnScope) - if err != nil { - return err - } - - sctx.GetSessionVars().TxnCtx.SetForUpdateTS(version.Ver) - p.latestForUpdateTS = version.Ver - txn.SetOption(kv.SnapshotTS, version.Ver) - - return nil -} - -// OnStmtStart is the hook that should be called when a new statement started -func (p *PessimisticRRTxnContextProvider) OnStmtStart(ctx context.Context, node ast.StmtNode) error { - if err := p.basePessimisticTxnContextProvider.OnStmtStart(ctx, node); err != nil { - return err - } - - p.forUpdateTS = 0 - p.optimizeForNotFetchingLatestTS = false - - return nil -} - -// OnStmtRetry is the hook that should be called when a statement is retried internally. -func (p *PessimisticRRTxnContextProvider) OnStmtRetry(ctx context.Context) (err error) { - if err = p.basePessimisticTxnContextProvider.OnStmtRetry(ctx); err != nil { - return err - } - - // If TxnCtx.forUpdateTS is updated in OnStmtErrorForNextAction, we assign the value to the provider - if p.latestForUpdateTS > p.forUpdateTS { - p.forUpdateTS = p.latestForUpdateTS - } else { - p.forUpdateTS = 0 - } - - p.optimizeForNotFetchingLatestTS = false - - return nil -} - -// OnStmtErrorForNextAction is the hook that should be called when a new statement get an error -func (p *PessimisticRRTxnContextProvider) OnStmtErrorForNextAction(ctx context.Context, point sessiontxn.StmtErrorHandlePoint, err error) (sessiontxn.StmtErrorAction, error) { - switch point { - case sessiontxn.StmtErrAfterPessimisticLock: - return p.handleAfterPessimisticLockError(ctx, err) - default: - return sessiontxn.NoIdea() - } -} - -// AdviseOptimizeWithPlan optimizes for update point get related execution. -// Use case: In for update point get related operations, we do not fetch ts from PD but use the last ts we fetched. -// -// We expect that the data that the point get acquires has not been changed. -// -// Benefit: Save the cost of acquiring ts from PD. -// Drawbacks: If the data has been changed since the ts we used, we need to retry. -// One exception is insert operation, when it has no select plan, we do not fetch the latest ts immediately. We only update ts -// if write conflict is incurred. -func (p *PessimisticRRTxnContextProvider) AdviseOptimizeWithPlan(val any) (err error) { - if p.isTidbSnapshotEnabled() || p.isBeginStmtWithStaleRead() { - return nil - } - - plan, ok := val.(base.Plan) - if !ok { - return nil - } - - if execute, ok := plan.(*plannercore.Execute); ok { - plan = execute.Plan - } - - p.optimizeForNotFetchingLatestTS = notNeedGetLatestTSFromPD(plan, false) - - return nil -} - -// notNeedGetLatestTSFromPD searches for optimization condition recursively -// Note: For point get and batch point get (name it plan), if one of the ancestor node is update/delete/physicalLock, -// we should check whether the plan.Lock is true or false. See comments in needNotToBeOptimized. -// inLockOrWriteStmt = true means one of the ancestor node is update/delete/physicalLock. -func notNeedGetLatestTSFromPD(plan base.Plan, inLockOrWriteStmt bool) bool { - switch v := plan.(type) { - case *plannercore.PointGetPlan: - // We do not optimize the point get/ batch point get if plan.lock = false and inLockOrWriteStmt = true. - // Theoretically, the plan.lock should be true if the flag is true. But due to the bug describing in Issue35524, - // the plan.lock can be false in the case of inLockOrWriteStmt being true. 
In this case, optimization here can lead to different results - // which cannot be accepted as AdviseOptimizeWithPlan cannot change results. - return !inLockOrWriteStmt || v.Lock - case *plannercore.BatchPointGetPlan: - return !inLockOrWriteStmt || v.Lock - case base.PhysicalPlan: - if len(v.Children()) == 0 { - return false - } - _, isPhysicalLock := v.(*plannercore.PhysicalLock) - for _, p := range v.Children() { - if !notNeedGetLatestTSFromPD(p, isPhysicalLock || inLockOrWriteStmt) { - return false - } - } - return true - case *plannercore.Update: - return notNeedGetLatestTSFromPD(v.SelectPlan, true) - case *plannercore.Delete: - return notNeedGetLatestTSFromPD(v.SelectPlan, true) - case *plannercore.Insert: - return v.SelectPlan == nil - } - return false -} - -func (p *PessimisticRRTxnContextProvider) handleAfterPessimisticLockError(ctx context.Context, lockErr error) (sessiontxn.StmtErrorAction, error) { - sessVars := p.sctx.GetSessionVars() - txnCtx := sessVars.TxnCtx - - if deadlock, ok := errors.Cause(lockErr).(*tikverr.ErrDeadlock); ok { - if !deadlock.IsRetryable { - return sessiontxn.ErrorAction(lockErr) - } - - logutil.Logger(p.ctx).Info("single statement deadlock, retry statement", - zap.Uint64("txn", txnCtx.StartTS), - zap.Uint64("lockTS", deadlock.LockTs), - zap.Stringer("lockKey", kv.Key(deadlock.LockKey)), - zap.Uint64("deadlockKeyHash", deadlock.DeadlockKeyHash)) - - // In fair locking mode, when statement retry happens, `retryFairLockingIfNeeded` should be - // called to make its state ready for retrying. But single-statement deadlock is an exception. We need to exit - // fair locking in single-statement-deadlock case, otherwise the lock this statement has acquired won't be - // released after retrying, so it still blocks another transaction and the deadlock won't be resolved. - if err := p.cancelFairLockingIfNeeded(ctx); err != nil { - return sessiontxn.ErrorAction(err) - } - } else if terror.ErrorEqual(kv.ErrWriteConflict, lockErr) { - // Always update forUpdateTS by getting a new timestamp from PD. - // If we use the conflict commitTS as the new forUpdateTS and async commit - // is used, the commitTS of this transaction may exceed the max timestamp - // that PD allocates. Then, the change may be invisible to a new transaction, - // which means linearizability is broken. - // suppose the following scenario: - // - Txn1/2/3 get start-ts - // - Txn1/2 all get min-commit-ts as required by async commit from PD in order - // - now max ts on PD is PD-max-ts - // - Txn2 commit with calculated commit-ts = PD-max-ts + 1 - // - Txn3 try lock a key committed by Txn2 and get write conflict and use - // conflict commit-ts as forUpdateTS, lock and read, TiKV will update its - // max-ts to PD-max-ts + 1 - // - Txn1 commit with calculated commit-ts = PD-max-ts + 2 - // - suppose Txn4 after Txn1 on same session, it gets start-ts = PD-max-ts + 1 from PD - // - Txn4 cannot see Txn1's changes because its start-ts is less than Txn1's commit-ts - // which breaks linearizability. - errStr := lockErr.Error() - forUpdateTS := txnCtx.GetForUpdateTS() - - logutil.Logger(p.ctx).Debug("pessimistic write conflict, retry statement", - zap.Uint64("txn", txnCtx.StartTS), - zap.Uint64("forUpdateTS", forUpdateTS), - zap.String("err", errStr)) - } else { - // This branch: if err is not nil, always update forUpdateTS to avoid problem described below. 
- // For nowait, when ErrLock happened, ErrLockAcquireFailAndNoWaitSet will be returned, and in the same txn - // the select for updateTs must be updated, otherwise there maybe rollback problem. - // begin - // select for update key1 (here encounters ErrLocked or other errors (or max_execution_time like util), - // key1 lock has not gotten and async rollback key1 is raised) - // select for update key1 again (this time lock is acquired successfully (maybe lock was released by others)) - // the async rollback operation rollbacks the lock just acquired - if err := p.updateForUpdateTS(); err != nil { - logutil.Logger(p.ctx).Warn("UpdateForUpdateTS failed", zap.Error(err)) - } - - return sessiontxn.ErrorAction(lockErr) - } - - if err := p.updateForUpdateTS(); err != nil { - return sessiontxn.ErrorAction(lockErr) - } - - if err := p.retryFairLockingIfNeeded(ctx); err != nil { - return sessiontxn.ErrorAction(err) - } - return sessiontxn.RetryReady() -} diff --git a/pkg/sessiontxn/staleread/binding__failpoint_binding__.go b/pkg/sessiontxn/staleread/binding__failpoint_binding__.go deleted file mode 100644 index 19c3042b7bdd1..0000000000000 --- a/pkg/sessiontxn/staleread/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package staleread - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/sessiontxn/staleread/util.go b/pkg/sessiontxn/staleread/util.go index 6919733064282..c66daedfa2ab5 100644 --- a/pkg/sessiontxn/staleread/util.go +++ b/pkg/sessiontxn/staleread/util.go @@ -35,9 +35,9 @@ import ( // CalculateAsOfTsExpr calculates the TsExpr of AsOfClause to get a StartTS. func CalculateAsOfTsExpr(ctx context.Context, sctx pctx.PlanContext, tsExpr ast.ExprNode) (uint64, error) { sctx.GetSessionVars().StmtCtx.SetStaleTSOProvider(func() (uint64, error) { - if val, _err_ := failpoint.Eval(_curpkg_("mockStaleReadTSO")); _err_ == nil { + failpoint.Inject("mockStaleReadTSO", func(val failpoint.Value) (uint64, error) { return uint64(val.(int)), nil - } + }) // this function accepts a context, but we don't need it when there is a valid cached ts. // in most cases, the stale read ts can be calculated from `cached ts + time since cache - staleness`, // this can be more accurate than `time.Now() - staleness`, because TiDB's local time can drift. diff --git a/pkg/sessiontxn/staleread/util.go__failpoint_stash__ b/pkg/sessiontxn/staleread/util.go__failpoint_stash__ deleted file mode 100644 index c66daedfa2ab5..0000000000000 --- a/pkg/sessiontxn/staleread/util.go__failpoint_stash__ +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
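The stash file being removed here carries the stale-read timestamp helpers (CalculateAsOfTsExpr, CalculateTsWithReadStaleness). A small standalone illustration, not part of this patch and assuming direct use of the client-go oracle helpers, of the time-to-TSO conversion those helpers rely on:

package main

import (
	"fmt"
	"time"

	"github.com/tikv/client-go/v2/oracle"
)

func main() {
	// A stale-read point "5 seconds ago", mirroring the readStaleness path.
	tsTime := time.Now().Add(-5 * time.Second)

	// A TSO packs the physical time (milliseconds) in the high bits and a
	// logical counter in the low 18 bits; GoTimeToTS leaves the logical part zero.
	ts := oracle.GoTimeToTS(tsTime)

	// Converting back recovers the wall-clock time at millisecond precision.
	fmt.Println(ts, oracle.GetTimeFromTS(ts).Format(time.RFC3339Nano))
}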
- -package staleread - -import ( - "context" - "time" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/mysql" - pctx "github.com/pingcap/tidb/pkg/planner/context" - plannerutil "github.com/pingcap/tidb/pkg/planner/util" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/dbterror/plannererrors" - "github.com/tikv/client-go/v2/oracle" -) - -// CalculateAsOfTsExpr calculates the TsExpr of AsOfClause to get a StartTS. -func CalculateAsOfTsExpr(ctx context.Context, sctx pctx.PlanContext, tsExpr ast.ExprNode) (uint64, error) { - sctx.GetSessionVars().StmtCtx.SetStaleTSOProvider(func() (uint64, error) { - failpoint.Inject("mockStaleReadTSO", func(val failpoint.Value) (uint64, error) { - return uint64(val.(int)), nil - }) - // this function accepts a context, but we don't need it when there is a valid cached ts. - // in most cases, the stale read ts can be calculated from `cached ts + time since cache - staleness`, - // this can be more accurate than `time.Now() - staleness`, because TiDB's local time can drift. - return sctx.GetStore().GetOracle().GetStaleTimestamp(ctx, oracle.GlobalTxnScope, 0) - }) - tsVal, err := plannerutil.EvalAstExprWithPlanCtx(sctx, tsExpr) - if err != nil { - return 0, err - } - - if tsVal.IsNull() { - return 0, plannererrors.ErrAsOf.FastGenWithCause("as of timestamp cannot be NULL") - } - - toTypeTimestamp := types.NewFieldType(mysql.TypeTimestamp) - // We need at least the millionsecond here, so set fsp to 3. - toTypeTimestamp.SetDecimal(3) - tsTimestamp, err := tsVal.ConvertTo(sctx.GetSessionVars().StmtCtx.TypeCtx(), toTypeTimestamp) - if err != nil { - return 0, err - } - tsTime, err := tsTimestamp.GetMysqlTime().GoTime(sctx.GetSessionVars().Location()) - if err != nil { - return 0, err - } - return oracle.GoTimeToTS(tsTime), nil -} - -// CalculateTsWithReadStaleness calculates the TsExpr for readStaleness duration -func CalculateTsWithReadStaleness(sctx sessionctx.Context, readStaleness time.Duration) (uint64, error) { - nowVal, err := expression.GetStmtTimestamp(sctx.GetExprCtx().GetEvalCtx()) - if err != nil { - return 0, err - } - tsVal := nowVal.Add(readStaleness) - sc := sctx.GetSessionVars().StmtCtx - minTsVal := expression.GetStmtMinSafeTime(sc, sctx.GetStore(), sc.TimeZone()) - return oracle.GoTimeToTS(expression.CalAppropriateTime(tsVal, nowVal, minTsVal)), nil -} - -// IsStmtStaleness indicates whether the current statement is staleness or not -func IsStmtStaleness(sctx sessionctx.Context) bool { - return sctx.GetSessionVars().StmtCtx.IsStaleness -} - -// GetExternalTimestamp returns the external timestamp in cache, or get and store it in cache -func GetExternalTimestamp(ctx context.Context, sc *stmtctx.StatementContext) (uint64, error) { - // Try to get from the stmt cache to make sure this function is deterministic. 
- externalTimestamp, err := sc.GetOrEvaluateStmtCache(stmtctx.StmtExternalTSCacheKey, func() (any, error) { - return variable.GetExternalTimestamp(ctx) - }) - - if err != nil { - return 0, plannererrors.ErrAsOf.FastGenWithCause(err.Error()) - } - return externalTimestamp.(uint64), nil -} diff --git a/pkg/statistics/binding__failpoint_binding__.go b/pkg/statistics/binding__failpoint_binding__.go deleted file mode 100644 index d74e174269a9b..0000000000000 --- a/pkg/statistics/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package statistics - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/statistics/cmsketch.go b/pkg/statistics/cmsketch.go index 6fa5eb769eda2..5c013189b55fe 100644 --- a/pkg/statistics/cmsketch.go +++ b/pkg/statistics/cmsketch.go @@ -99,7 +99,7 @@ func newTopNHelper(sample [][]byte, numTop uint32) *topNHelper { } } slices.SortStableFunc(sorted, func(i, j dataCnt) int { return -cmp.Compare(i.cnt, j.cnt) }) - if val, _err_ := failpoint.Eval(_curpkg_("StabilizeV1AnalyzeTopN")); _err_ == nil { + failpoint.Inject("StabilizeV1AnalyzeTopN", func(val failpoint.Value) { if val.(bool) { // The earlier TopN entry will modify the CMSketch, therefore influence later TopN entry's row count. // So we need to make the order here fully deterministic to make the stats from analyze ver1 stable. @@ -109,7 +109,7 @@ func newTopNHelper(sample [][]byte, numTop uint32) *topNHelper { (sorted[i].cnt == sorted[j].cnt && string(sorted[i].data) < string(sorted[j].data)) }) } - } + }) var ( sumTopN uint64 @@ -270,9 +270,9 @@ func QueryValue(sctx context.PlanContext, c *CMSketch, t *TopN, val types.Datum) // QueryBytes is used to query the count of specified bytes. func (c *CMSketch) QueryBytes(d []byte) uint64 { - if val, _err_ := failpoint.Eval(_curpkg_("mockQueryBytesMaxUint64")); _err_ == nil { - return uint64(val.(int)) - } + failpoint.Inject("mockQueryBytesMaxUint64", func(val failpoint.Value) { + failpoint.Return(uint64(val.(int))) + }) h1, h2 := murmur3.Sum128(d) return c.queryHashValue(nil, h1, h2) } diff --git a/pkg/statistics/cmsketch.go__failpoint_stash__ b/pkg/statistics/cmsketch.go__failpoint_stash__ deleted file mode 100644 index 5c013189b55fe..0000000000000 --- a/pkg/statistics/cmsketch.go__failpoint_stash__ +++ /dev/null @@ -1,865 +0,0 @@ -// Copyright 2017 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
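The stash removed here is the full Count-Min sketch implementation. For orientation only, a minimal self-contained sketch of the same estimation idea, using murmur3 hashing like the original; the depth and width values are arbitrary assumptions and this is not the TiDB implementation:

package main

import (
	"fmt"
	"math"

	"github.com/twmb/murmur3"
)

// cmSketch is a toy count-min sketch: depth rows of width counters, each row
// indexed by a different combination of the two murmur3 128-bit hash halves.
type cmSketch struct {
	depth, width uint64
	table        [][]uint32
}

func newCMSketch(depth, width uint64) *cmSketch {
	t := make([][]uint32, depth)
	for i := range t {
		t[i] = make([]uint32, width)
	}
	return &cmSketch{depth: depth, width: width, table: t}
}

func (c *cmSketch) insert(key []byte) {
	h1, h2 := murmur3.Sum128(key)
	for i := uint64(0); i < c.depth; i++ {
		c.table[i][(h1+h2*i)%c.width]++
	}
}

// query returns the minimum counter over all rows, an upper bound on the true count.
func (c *cmSketch) query(key []byte) uint32 {
	h1, h2 := murmur3.Sum128(key)
	res := uint32(math.MaxUint32)
	for i := uint64(0); i < c.depth; i++ {
		if v := c.table[i][(h1+h2*i)%c.width]; v < res {
			res = v
		}
	}
	return res
}

func main() {
	c := newCMSketch(5, 2048) // depth and width chosen arbitrarily for the example
	for i := 0; i < 10; i++ {
		c.insert([]byte("a"))
	}
	c.insert([]byte("b"))
	fmt.Println(c.query([]byte("a")), c.query([]byte("b"))) // expect roughly 10 and 1
}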
- -package statistics - -import ( - "bytes" - "cmp" - "fmt" - "math" - "reflect" - "slices" - "sort" - "strings" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/planner/context" - "github.com/pingcap/tidb/pkg/planner/util/debugtrace" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/dbterror" - "github.com/pingcap/tidb/pkg/util/hack" - "github.com/pingcap/tipb/go-tipb" - "github.com/twmb/murmur3" -) - -// topNThreshold is the minimum ratio of the number of topN elements in CMSketch, 10 means 1 / 10 = 10%. -const topNThreshold = uint64(10) - -var ( - // ErrQueryInterrupted indicates interrupted - ErrQueryInterrupted = dbterror.ClassExecutor.NewStd(mysql.ErrQueryInterrupted) -) - -// CMSketch is used to estimate point queries. -// Refer: https://en.wikipedia.org/wiki/Count-min_sketch -type CMSketch struct { - table [][]uint32 - count uint64 // TopN is not counted in count - defaultValue uint64 // In sampled data, if cmsketch returns a small value (less than avg value / 2), then this will returned. - depth int32 - width int32 -} - -// NewCMSketch returns a new CM sketch. -func NewCMSketch(d, w int32) *CMSketch { - tbl := make([][]uint32, d) - // Background: The Go's memory allocator will ask caller to sweep spans in some scenarios. - // This can cause memory allocation request latency unpredictable, if the list of spans which need sweep is too long. - // For memory allocation large than 32K, the allocator will never allocate memory from spans list. - // - // The memory referenced by the CMSketch will never be freed. - // If the number of table or index is extremely large, there will be a large amount of spans in global list. - // The default value of `d` is 5 and `w` is 2048, if we use a single slice for them the size will be 40K. - // This allocation will be handled by mheap and will never have impact on normal allocations. - arena := make([]uint32, d*w) - for i := range tbl { - tbl[i] = arena[i*int(w) : (i+1)*int(w)] - } - return &CMSketch{depth: d, width: w, table: tbl} -} - -// topNHelper wraps some variables used when building cmsketch with top n. -type topNHelper struct { - sorted []dataCnt - sampleSize uint64 - onlyOnceItems uint64 - sumTopN uint64 - actualNumTop uint32 -} - -func newTopNHelper(sample [][]byte, numTop uint32) *topNHelper { - counter := make(map[hack.MutableString]uint64, len(sample)) - for i := range sample { - counter[hack.String(sample[i])]++ - } - sorted, onlyOnceItems := make([]dataCnt, 0, len(counter)), uint64(0) - for key, cnt := range counter { - sorted = append(sorted, dataCnt{hack.Slice(string(key)), cnt}) - if cnt == 1 { - onlyOnceItems++ - } - } - slices.SortStableFunc(sorted, func(i, j dataCnt) int { return -cmp.Compare(i.cnt, j.cnt) }) - failpoint.Inject("StabilizeV1AnalyzeTopN", func(val failpoint.Value) { - if val.(bool) { - // The earlier TopN entry will modify the CMSketch, therefore influence later TopN entry's row count. - // So we need to make the order here fully deterministic to make the stats from analyze ver1 stable. 
- // See (*SampleCollector).ExtractTopN(), which calls this function, for details - sort.SliceStable(sorted, func(i, j int) bool { - return sorted[i].cnt > sorted[j].cnt || - (sorted[i].cnt == sorted[j].cnt && string(sorted[i].data) < string(sorted[j].data)) - }) - } - }) - - var ( - sumTopN uint64 - sampleNDV = uint32(len(sorted)) - ) - numTop = min(sampleNDV, numTop) // Ensure numTop no larger than sampNDV. - // Only element whose frequency is not smaller than 2/3 multiples the - // frequency of the n-th element are added to the TopN statistics. We chose - // 2/3 as an empirical value because the average cardinality estimation - // error is relatively small compared with 1/2. - var actualNumTop uint32 - for ; actualNumTop < sampleNDV && actualNumTop < numTop*2; actualNumTop++ { - if actualNumTop >= numTop && sorted[actualNumTop].cnt*3 < sorted[numTop-1].cnt*2 { - break - } - if sorted[actualNumTop].cnt == 1 { - break - } - sumTopN += sorted[actualNumTop].cnt - } - - return &topNHelper{sorted, uint64(len(sample)), onlyOnceItems, sumTopN, actualNumTop} -} - -// NewCMSketchAndTopN returns a new CM sketch with TopN elements, the estimate NDV and the scale ratio. -func NewCMSketchAndTopN(d, w int32, sample [][]byte, numTop uint32, rowCount uint64) (*CMSketch, *TopN, uint64, uint64) { - if rowCount == 0 || len(sample) == 0 { - return nil, nil, 0, 0 - } - helper := newTopNHelper(sample, numTop) - // rowCount is not a accurate value when fast analyzing - // In some cases, if user triggers fast analyze when rowCount is close to sampleSize, unexpected bahavior might happen. - rowCount = max(rowCount, uint64(len(sample))) - estimateNDV, scaleRatio := calculateEstimateNDV(helper, rowCount) - defaultVal := calculateDefaultVal(helper, estimateNDV, scaleRatio, rowCount) - c, t := buildCMSAndTopN(helper, d, w, scaleRatio, defaultVal) - return c, t, estimateNDV, scaleRatio -} - -func buildCMSAndTopN(helper *topNHelper, d, w int32, scaleRatio uint64, defaultVal uint64) (c *CMSketch, t *TopN) { - c = NewCMSketch(d, w) - enableTopN := helper.sampleSize/topNThreshold <= helper.sumTopN - if enableTopN { - t = NewTopN(int(helper.actualNumTop)) - for i := uint32(0); i < helper.actualNumTop; i++ { - data, cnt := helper.sorted[i].data, helper.sorted[i].cnt - t.AppendTopN(data, cnt*scaleRatio) - } - t.Sort() - helper.sorted = helper.sorted[helper.actualNumTop:] - } - c.defaultValue = defaultVal - for i := range helper.sorted { - data, cnt := helper.sorted[i].data, helper.sorted[i].cnt - // If the value only occurred once in the sample, we assumes that there is no difference with - // value that does not occurred in the sample. - rowCount := defaultVal - if cnt > 1 { - rowCount = cnt * scaleRatio - } - c.InsertBytesByCount(data, rowCount) - } - return -} - -func calculateDefaultVal(helper *topNHelper, estimateNDV, scaleRatio, rowCount uint64) uint64 { - sampleNDV := uint64(len(helper.sorted)) - if rowCount <= (helper.sampleSize-helper.onlyOnceItems)*scaleRatio { - return 1 - } - estimateRemainingCount := rowCount - (helper.sampleSize-helper.onlyOnceItems)*scaleRatio - return estimateRemainingCount / max(1, estimateNDV-sampleNDV+helper.onlyOnceItems) -} - -// MemoryUsage returns the total memory usage of a CMSketch. -// only calc the hashtable size(CMSketch.table) and the CMSketch.topN -// data are not tracked because size of CMSketch.topN take little influence -// We ignore the size of other metadata in CMSketch. 
-func (c *CMSketch) MemoryUsage() (sum int64) { - if c == nil { - return - } - sum = int64(c.depth * c.width * 4) - return -} - -// InsertBytes inserts the bytes value into the CM Sketch. -func (c *CMSketch) InsertBytes(bytes []byte) { - c.InsertBytesByCount(bytes, 1) -} - -// InsertBytesByCount adds the bytes value into the TopN (if value already in TopN) or CM Sketch by delta, this does not updates c.defaultValue. -func (c *CMSketch) InsertBytesByCount(bytes []byte, count uint64) { - h1, h2 := murmur3.Sum128(bytes) - c.count += count - for i := range c.table { - j := (h1 + h2*uint64(i)) % uint64(c.width) - c.table[i][j] += uint32(count) - } -} - -func (c *CMSketch) considerDefVal(cnt uint64) bool { - return (cnt == 0 || (cnt > c.defaultValue && cnt < 2*(c.count/uint64(c.width)))) && c.defaultValue > 0 -} - -// setValue sets the count for value that hashed into (h1, h2), and update defaultValue if necessary. -func (c *CMSketch) setValue(h1, h2 uint64, count uint64) { - oriCount := c.queryHashValue(nil, h1, h2) - if c.considerDefVal(oriCount) { - // We should update c.defaultValue if we used c.defaultValue when getting the estimate count. - // This should make estimation better, remove this line if it does not work as expected. - c.defaultValue = uint64(float64(c.defaultValue)*0.95 + float64(c.defaultValue)*0.05) - if c.defaultValue == 0 { - // c.defaultValue never guess 0 since we are using a sampled data. - c.defaultValue = 1 - } - } - - c.count += count - oriCount - // let it overflow naturally - deltaCount := uint32(count) - uint32(oriCount) - for i := range c.table { - j := (h1 + h2*uint64(i)) % uint64(c.width) - c.table[i][j] = c.table[i][j] + deltaCount - } -} - -// SubValue remove a value from the CMSketch. -func (c *CMSketch) SubValue(h1, h2 uint64, count uint64) { - c.count -= count - for i := range c.table { - j := (h1 + h2*uint64(i)) % uint64(c.width) - c.table[i][j] = c.table[i][j] - uint32(count) - } -} - -// QueryValue is used to query the count of specified value. -func QueryValue(sctx context.PlanContext, c *CMSketch, t *TopN, val types.Datum) (uint64, error) { - var sc *stmtctx.StatementContext - tz := time.UTC - if sctx != nil { - sc = sctx.GetSessionVars().StmtCtx - tz = sc.TimeZone() - } - rawData, err := tablecodec.EncodeValue(tz, nil, val) - if sc != nil { - err = sc.HandleError(err) - } - if err != nil { - return 0, errors.Trace(err) - } - h1, h2 := murmur3.Sum128(rawData) - if ret, ok := t.QueryTopN(sctx, rawData); ok { - return ret, nil - } - return c.queryHashValue(sctx, h1, h2), nil -} - -// QueryBytes is used to query the count of specified bytes. -func (c *CMSketch) QueryBytes(d []byte) uint64 { - failpoint.Inject("mockQueryBytesMaxUint64", func(val failpoint.Value) { - failpoint.Return(uint64(val.(int))) - }) - h1, h2 := murmur3.Sum128(d) - return c.queryHashValue(nil, h1, h2) -} - -// The input sctx is just for debug trace, you can pass nil safely if that's not needed. 
-func (c *CMSketch) queryHashValue(sctx context.PlanContext, h1, h2 uint64) (result uint64) { - vals := make([]uint32, c.depth) - originVals := make([]uint32, c.depth) - minValue := uint32(math.MaxUint32) - useDefaultValue := false - if sctx != nil && sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { - debugtrace.EnterContextCommon(sctx) - defer func() { - debugtrace.RecordAnyValuesWithNames(sctx, - "Origin Values", originVals, - "Values", vals, - "Use default value", useDefaultValue, - "Result", result, - ) - debugtrace.LeaveContextCommon(sctx) - }() - } - // We want that when res is 0 before the noise is eliminated, the default value is not used. - // So we need a temp value to distinguish before and after eliminating noise. - temp := uint32(1) - for i := range c.table { - j := (h1 + h2*uint64(i)) % uint64(c.width) - originVals[i] = c.table[i][j] - if minValue > c.table[i][j] { - minValue = c.table[i][j] - } - noise := (c.count - uint64(c.table[i][j])) / (uint64(c.width) - 1) - if uint64(c.table[i][j]) == 0 { - vals[i] = 0 - } else if uint64(c.table[i][j]) < noise { - vals[i] = temp - } else { - vals[i] = c.table[i][j] - uint32(noise) + temp - } - } - slices.Sort(vals) - res := vals[(c.depth-1)/2] + (vals[c.depth/2]-vals[(c.depth-1)/2])/2 - if res > minValue+temp { - res = minValue + temp - } - if res == 0 { - return uint64(0) - } - res = res - temp - if c.considerDefVal(uint64(res)) { - useDefaultValue = true - return c.defaultValue - } - return uint64(res) -} - -// MergeTopNAndUpdateCMSketch merges the src TopN into the dst, and spilled values will be inserted into the CMSketch. -func MergeTopNAndUpdateCMSketch(dst, src *TopN, c *CMSketch, numTop uint32) []TopNMeta { - topNs := []*TopN{src, dst} - mergedTopN, popedTopNPair := MergeTopN(topNs, numTop) - if mergedTopN == nil { - // mergedTopN == nil means the total count of the input TopN are equal to zero - return popedTopNPair - } - dst.TopN = mergedTopN.TopN - for _, topNMeta := range popedTopNPair { - c.InsertBytesByCount(topNMeta.Encoded, topNMeta.Count) - } - return popedTopNPair -} - -// MergeCMSketch merges two CM Sketch. -func (c *CMSketch) MergeCMSketch(rc *CMSketch) error { - if c == nil || rc == nil { - return nil - } - if c.depth != rc.depth || c.width != rc.width { - return errors.New("Dimensions of Count-Min Sketch should be the same") - } - c.count += rc.count - for i := range c.table { - for j := range c.table[i] { - c.table[i][j] += rc.table[i][j] - } - } - return nil -} - -// CMSketchToProto converts CMSketch to its protobuf representation. -func CMSketchToProto(c *CMSketch, topn *TopN) *tipb.CMSketch { - protoSketch := &tipb.CMSketch{} - if c != nil { - protoSketch.Rows = make([]*tipb.CMSketchRow, c.depth) - for i := range c.table { - protoSketch.Rows[i] = &tipb.CMSketchRow{Counters: make([]uint32, c.width)} - copy(protoSketch.Rows[i].Counters, c.table[i]) - } - protoSketch.DefaultValue = c.defaultValue - } - if topn != nil { - for _, dataMeta := range topn.TopN { - protoSketch.TopN = append(protoSketch.TopN, &tipb.CMSketchTopN{Data: dataMeta.Encoded, Count: dataMeta.Count}) - } - } - return protoSketch -} - -// CMSketchAndTopNFromProto converts CMSketch and TopN from its protobuf representation. 
-func CMSketchAndTopNFromProto(protoSketch *tipb.CMSketch) (*CMSketch, *TopN) { - if protoSketch == nil { - return nil, nil - } - retTopN := TopNFromProto(protoSketch.TopN) - if len(protoSketch.Rows) == 0 { - return nil, retTopN - } - c := NewCMSketch(int32(len(protoSketch.Rows)), int32(len(protoSketch.Rows[0].Counters))) - for i, row := range protoSketch.Rows { - c.count = 0 - for j, counter := range row.Counters { - c.table[i][j] = counter - c.count = c.count + uint64(counter) - } - } - c.defaultValue = protoSketch.DefaultValue - return c, retTopN -} - -// TopNFromProto converts TopN from its protobuf representation. -func TopNFromProto(protoTopN []*tipb.CMSketchTopN) *TopN { - if len(protoTopN) == 0 { - return nil - } - topN := NewTopN(len(protoTopN)) - for _, e := range protoTopN { - d := make([]byte, len(e.Data)) - copy(d, e.Data) - topN.AppendTopN(d, e.Count) - } - topN.Sort() - return topN -} - -// EncodeCMSketchWithoutTopN encodes the given CMSketch to byte slice. -// Note that it does not include the topN. -func EncodeCMSketchWithoutTopN(c *CMSketch) ([]byte, error) { - if c == nil { - return nil, nil - } - p := CMSketchToProto(c, nil) - p.TopN = nil - protoData, err := p.Marshal() - return protoData, err -} - -// DecodeCMSketchAndTopN decode a CMSketch from the given byte slice. -func DecodeCMSketchAndTopN(data []byte, topNRows []chunk.Row) (*CMSketch, *TopN, error) { - if data == nil && len(topNRows) == 0 { - return nil, nil, nil - } - if len(data) == 0 { - return nil, DecodeTopN(topNRows), nil - } - cm, err := DecodeCMSketch(data) - if err != nil { - return nil, nil, errors.Trace(err) - } - return cm, DecodeTopN(topNRows), nil -} - -// DecodeTopN decodes a TopN from the given byte slice. -func DecodeTopN(topNRows []chunk.Row) *TopN { - if len(topNRows) == 0 { - return nil - } - topN := NewTopN(len(topNRows)) - for _, row := range topNRows { - data := make([]byte, len(row.GetBytes(0))) - copy(data, row.GetBytes(0)) - topN.AppendTopN(data, row.GetUint64(1)) - } - topN.Sort() - return topN -} - -// DecodeCMSketch encodes the given CMSketch to byte slice. -func DecodeCMSketch(data []byte) (*CMSketch, error) { - if len(data) == 0 { - return nil, nil - } - protoSketch := &tipb.CMSketch{} - err := protoSketch.Unmarshal(data) - if err != nil { - return nil, errors.Trace(err) - } - if len(protoSketch.Rows) == 0 { - return nil, nil - } - c := NewCMSketch(int32(len(protoSketch.Rows)), int32(len(protoSketch.Rows[0].Counters))) - for i, row := range protoSketch.Rows { - c.count = 0 - for j, counter := range row.Counters { - c.table[i][j] = counter - c.count = c.count + uint64(counter) - } - } - c.defaultValue = protoSketch.DefaultValue - return c, nil -} - -// TotalCount returns the total count in the sketch, it is only used for test. -func (c *CMSketch) TotalCount() uint64 { - if c == nil { - return 0 - } - return c.count -} - -// Equal tests if two CM Sketch equal, it is only used for test. -func (c *CMSketch) Equal(rc *CMSketch) bool { - return reflect.DeepEqual(c, rc) -} - -// Copy makes a copy for current CMSketch. -func (c *CMSketch) Copy() *CMSketch { - if c == nil { - return nil - } - tbl := make([][]uint32, c.depth) - for i := range tbl { - tbl[i] = make([]uint32, c.width) - copy(tbl[i], c.table[i]) - } - return &CMSketch{count: c.count, width: c.width, depth: c.depth, table: tbl, defaultValue: c.defaultValue} -} - -// GetWidthAndDepth returns the width and depth of CM Sketch. 
-func (c *CMSketch) GetWidthAndDepth() (width, depth int32) { - return c.width, c.depth -} - -// CalcDefaultValForAnalyze calculate the default value for Analyze. -// The value of it is count / NDV in CMSketch. This means count and NDV are not include topN. -func (c *CMSketch) CalcDefaultValForAnalyze(ndv uint64) { - c.defaultValue = c.count / max(1, ndv) -} - -// TopN stores most-common values, which is used to estimate point queries. -type TopN struct { - TopN []TopNMeta -} - -// Scale scales the TopN by the given factor. -func (c *TopN) Scale(scaleFactor float64) { - for i := range c.TopN { - c.TopN[i].Count = uint64(float64(c.TopN[i].Count) * scaleFactor) - } -} - -// AppendTopN appends a topn into the TopN struct. -func (c *TopN) AppendTopN(data []byte, count uint64) { - if c == nil { - return - } - c.TopN = append(c.TopN, TopNMeta{data, count}) -} - -func (c *TopN) String() string { - if c == nil { - return "EmptyTopN" - } - builder := &strings.Builder{} - fmt.Fprintf(builder, "TopN{length: %v, ", len(c.TopN)) - fmt.Fprint(builder, "[") - for i := 0; i < len(c.TopN); i++ { - fmt.Fprintf(builder, "(%v, %v)", c.TopN[i].Encoded, c.TopN[i].Count) - if i+1 != len(c.TopN) { - fmt.Fprint(builder, ", ") - } - } - fmt.Fprint(builder, "]") - fmt.Fprint(builder, "}") - return builder.String() -} - -// Num returns the ndv of the TopN. -// -// TopN is declared directly in Histogram. So the Len is occupied by the Histogram. We use Num instead. -func (c *TopN) Num() int { - if c == nil { - return 0 - } - return len(c.TopN) -} - -// DecodedString returns the value with decoded result. -func (c *TopN) DecodedString(ctx sessionctx.Context, colTypes []byte) (string, error) { - if c == nil { - return "", nil - } - builder := &strings.Builder{} - fmt.Fprintf(builder, "TopN{length: %v, ", len(c.TopN)) - fmt.Fprint(builder, "[") - var tmpDatum types.Datum - for i := 0; i < len(c.TopN); i++ { - tmpDatum.SetBytes(c.TopN[i].Encoded) - valStr, err := ValueToString(ctx.GetSessionVars(), &tmpDatum, len(colTypes), colTypes) - if err != nil { - return "", err - } - fmt.Fprintf(builder, "(%v, %v)", valStr, c.TopN[i].Count) - if i+1 != len(c.TopN) { - fmt.Fprint(builder, ", ") - } - } - fmt.Fprint(builder, "]") - fmt.Fprint(builder, "}") - return builder.String(), nil -} - -// Copy makes a copy for current TopN. -func (c *TopN) Copy() *TopN { - if c == nil { - return nil - } - topN := make([]TopNMeta, len(c.TopN)) - for i, t := range c.TopN { - topN[i].Encoded = make([]byte, len(t.Encoded)) - copy(topN[i].Encoded, t.Encoded) - topN[i].Count = t.Count - } - return &TopN{ - TopN: topN, - } -} - -// TopNMeta stores the unit of the TopN. -type TopNMeta struct { - Encoded []byte - Count uint64 -} - -// QueryTopN returns the results for (h1, h2) in murmur3.Sum128(), if not exists, return (0, false). -// The input sctx is just for debug trace, you can pass nil safely if that's not needed. 
-func (c *TopN) QueryTopN(sctx context.PlanContext, d []byte) (result uint64, found bool) { - if sctx != nil && sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { - debugtrace.EnterContextCommon(sctx) - defer func() { - debugtrace.RecordAnyValuesWithNames(sctx, "Result", result, "Found", found) - debugtrace.LeaveContextCommon(sctx) - }() - } - if c == nil { - return 0, false - } - idx := c.FindTopN(d) - if sctx != nil && sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { - debugtrace.RecordAnyValuesWithNames(sctx, "FindTopN idx", idx) - } - if idx < 0 { - return 0, false - } - return c.TopN[idx].Count, true -} - -// FindTopN finds the index of the given value in the TopN. -func (c *TopN) FindTopN(d []byte) int { - if c == nil { - return -1 - } - if len(c.TopN) == 0 { - return -1 - } - if len(c.TopN) == 1 { - if bytes.Equal(c.TopN[0].Encoded, d) { - return 0 - } - return -1 - } - if bytes.Compare(c.TopN[len(c.TopN)-1].Encoded, d) < 0 { - return -1 - } - if bytes.Compare(c.TopN[0].Encoded, d) > 0 { - return -1 - } - idx, match := slices.BinarySearchFunc(c.TopN, d, func(a TopNMeta, b []byte) int { - return bytes.Compare(a.Encoded, b) - }) - if !match { - return -1 - } - return idx -} - -// LowerBound searches on the sorted top-n items, -// returns the smallest index i such that the value at element i is not less than `d`. -func (c *TopN) LowerBound(d []byte) (idx int, match bool) { - if c == nil { - return 0, false - } - idx, match = slices.BinarySearchFunc(c.TopN, d, func(a TopNMeta, b []byte) int { - return bytes.Compare(a.Encoded, b) - }) - return idx, match -} - -// BetweenCount estimates the row count for interval [l, r). -// The input sctx is just for debug trace, you can pass nil safely if that's not needed. -func (c *TopN) BetweenCount(sctx context.PlanContext, l, r []byte) (result uint64) { - if sctx != nil && sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { - debugtrace.EnterContextCommon(sctx) - defer func() { - debugtrace.RecordAnyValuesWithNames(sctx, "Result", result) - debugtrace.LeaveContextCommon(sctx) - }() - } - if c == nil { - return 0 - } - lIdx, _ := c.LowerBound(l) - rIdx, _ := c.LowerBound(r) - ret := uint64(0) - for i := lIdx; i < rIdx; i++ { - ret += c.TopN[i].Count - } - if sctx != nil && sctx.GetSessionVars().StmtCtx.EnableOptimizerDebugTrace { - debugTraceTopNRange(sctx, c, lIdx, rIdx) - } - return ret -} - -// Sort sorts the topn items. -func (c *TopN) Sort() { - if c == nil { - return - } - slices.SortFunc(c.TopN, func(i, j TopNMeta) int { - return bytes.Compare(i.Encoded, j.Encoded) - }) -} - -// TotalCount returns how many data is stored in TopN. -func (c *TopN) TotalCount() uint64 { - if c == nil { - return 0 - } - total := uint64(0) - for _, t := range c.TopN { - total += t.Count - } - return total -} - -// Equal checks whether the two TopN are equal. -func (c *TopN) Equal(cc *TopN) bool { - if c.TotalCount() == 0 && cc.TotalCount() == 0 { - return true - } else if c.TotalCount() != cc.TotalCount() { - return false - } - if len(c.TopN) != len(cc.TopN) { - return false - } - for i := range c.TopN { - if !bytes.Equal(c.TopN[i].Encoded, cc.TopN[i].Encoded) { - return false - } - if c.TopN[i].Count != cc.TopN[i].Count { - return false - } - } - return true -} - -// RemoveVal remove the val from TopN if it exists. -func (c *TopN) RemoveVal(val []byte) { - if c == nil { - return - } - pos := c.FindTopN(val) - if pos == -1 { - return - } - c.TopN = append(c.TopN[:pos], c.TopN[pos+1:]...) 
-} - -// MemoryUsage returns the total memory usage of a topn. -func (c *TopN) MemoryUsage() (sum int64) { - if c == nil { - return - } - sum = 32 // size of array (24) + reference (8) - for _, meta := range c.TopN { - sum += 32 + int64(cap(meta.Encoded)) // 32 is size of byte array (24) + size of uint64 (8) - } - return -} - -// queryAddTopN TopN adds count to CMSketch.topN if exists, and returns the count of such elements after insert. -// If such elements does not in topn elements, nothing will happen and false will be returned. -func (c *TopN) updateTopNWithDelta(d []byte, delta uint64, increase bool) bool { - if c == nil || c.TopN == nil { - return false - } - idx := c.FindTopN(d) - if idx >= 0 { - if increase { - c.TopN[idx].Count += delta - } else { - c.TopN[idx].Count -= delta - } - return true - } - return false -} - -// NewTopN creates the new TopN struct by the given size. -func NewTopN(n int) *TopN { - return &TopN{TopN: make([]TopNMeta, 0, n)} -} - -// MergeTopN is used to merge more TopN structures to generate a new TopN struct by the given size. -// The input parameters are multiple TopN structures to be merged and the size of the new TopN that will be generated. -// The output parameters are the newly generated TopN structure and the remaining numbers. -// Notice: The n can be 0. So n has no default value, we must explicitly specify this value. -func MergeTopN(topNs []*TopN, n uint32) (*TopN, []TopNMeta) { - if CheckEmptyTopNs(topNs) { - return nil, nil - } - // Different TopN structures may hold the same value, we have to merge them. - counter := make(map[hack.MutableString]uint64) - for _, topN := range topNs { - if topN.TotalCount() == 0 { - continue - } - for _, val := range topN.TopN { - counter[hack.String(val.Encoded)] += val.Count - } - } - numTop := len(counter) - if numTop == 0 { - return nil, nil - } - sorted := make([]TopNMeta, 0, numTop) - for value, cnt := range counter { - data := hack.Slice(string(value)) - sorted = append(sorted, TopNMeta{Encoded: data, Count: cnt}) - } - return GetMergedTopNFromSortedSlice(sorted, n) -} - -// CheckEmptyTopNs checks whether all TopNs are empty. 
-func CheckEmptyTopNs(topNs []*TopN) bool { - for _, topN := range topNs { - if topN.TotalCount() != 0 { - return false - } - } - return true -} - -// SortTopnMeta sort topnMeta -func SortTopnMeta(topnMetas []TopNMeta) { - slices.SortFunc(topnMetas, func(i, j TopNMeta) int { - if i.Count != j.Count { - return cmp.Compare(j.Count, i.Count) - } - return bytes.Compare(i.Encoded, j.Encoded) - }) -} - -// TopnMetaCompare compare topnMeta -func TopnMetaCompare(i, j TopNMeta) int { - c := cmp.Compare(j.Count, i.Count) - if c != 0 { - return c - } - return bytes.Compare(i.Encoded, j.Encoded) -} - -// GetMergedTopNFromSortedSlice returns merged topn -func GetMergedTopNFromSortedSlice(sorted []TopNMeta, n uint32) (*TopN, []TopNMeta) { - SortTopnMeta(sorted) - n = min(uint32(len(sorted)), n) - - var finalTopN TopN - finalTopN.TopN = sorted[:n] - finalTopN.Sort() - return &finalTopN, sorted[n:] -} diff --git a/pkg/statistics/handle/autoanalyze/autoanalyze.go b/pkg/statistics/handle/autoanalyze/autoanalyze.go index 3e4c38ae22852..7ff39511b7865 100644 --- a/pkg/statistics/handle/autoanalyze/autoanalyze.go +++ b/pkg/statistics/handle/autoanalyze/autoanalyze.go @@ -720,7 +720,7 @@ func insertAnalyzeJob(sctx sessionctx.Context, job *statistics.AnalyzeJob, insta } job.ID = new(uint64) *job.ID = rows[0].GetUint64(0) - if val, _err_ := failpoint.Eval(_curpkg_("DebugAnalyzeJobOperations")); _err_ == nil { + failpoint.Inject("DebugAnalyzeJobOperations", func(val failpoint.Value) { if val.(bool) { logutil.BgLogger().Info("InsertAnalyzeJob", zap.String("table_schema", job.DBName), @@ -730,7 +730,7 @@ func insertAnalyzeJob(sctx sessionctx.Context, job *statistics.AnalyzeJob, insta zap.Uint64("job_id", *job.ID), ) } - } + }) return nil } @@ -746,14 +746,14 @@ func startAnalyzeJob(sctx sessionctx.Context, job *statistics.AnalyzeJob) { if err != nil { statslogutil.StatsLogger().Warn("failed to update analyze job", zap.String("update", fmt.Sprintf("%s->%s", statistics.AnalyzePending, statistics.AnalyzeRunning)), zap.Error(err)) } - if val, _err_ := failpoint.Eval(_curpkg_("DebugAnalyzeJobOperations")); _err_ == nil { + failpoint.Inject("DebugAnalyzeJobOperations", func(val failpoint.Value) { if val.(bool) { logutil.BgLogger().Info("StartAnalyzeJob", zap.Time("start_time", job.StartTime), zap.Uint64("job id", *job.ID), ) } - } + }) } // updateAnalyzeJobProgress updates count of the processed rows when increment reaches a threshold. 
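Editor's note (not part of the patch): the hunks above, and the matching ones for bootstrap.go further down, revert generated failpoint code back to its source form — the `*__failpoint_stash__` and `binding__failpoint_binding__.go` deletions later in this patch look like leftovers of a `failpoint-ctl enable` run that was committed by accident. Below is a minimal, self-contained sketch of the two forms, assuming the standard github.com/pingcap/failpoint API; the package path github.com/example/demo and the demoWork/demoTest functions are illustrative only, while the failpoint name DebugAnalyzeJobOperations is taken from the hunks above.

// Minimal illustration of the two failpoint forms seen in these hunks.
// The Inject call is the form kept in source control; `failpoint-ctl enable`
// rewrites it into the Eval/_curpkg_ form shown on the removed lines
// (and creates the stash/binding files this patch deletes).
package demo

import (
	"fmt"

	"github.com/pingcap/failpoint"
)

func demoWork() {
	// Source form: a no-op marker until the package is processed by failpoint-ctl.
	failpoint.Inject("DebugAnalyzeJobOperations", func(val failpoint.Value) {
		if val.(bool) {
			fmt.Println("failpoint DebugAnalyzeJobOperations fired")
		}
	})
}

func demoTest() {
	// In a failpoint-enabled build, a test activates the failpoint by its fully
	// qualified name; the "return(true)" term makes val evaluate to true.
	_ = failpoint.Enable("github.com/example/demo/DebugAnalyzeJobOperations", "return(true)")
	defer func() { _ = failpoint.Disable("github.com/example/demo/DebugAnalyzeJobOperations") }()
	demoWork()
}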
@@ -770,14 +770,14 @@ func updateAnalyzeJobProgress(sctx sessionctx.Context, job *statistics.AnalyzeJo if err != nil { statslogutil.StatsLogger().Warn("failed to update analyze job", zap.String("update", fmt.Sprintf("process %v rows", delta)), zap.Error(err)) } - if val, _err_ := failpoint.Eval(_curpkg_("DebugAnalyzeJobOperations")); _err_ == nil { + failpoint.Inject("DebugAnalyzeJobOperations", func(val failpoint.Value) { if val.(bool) { logutil.BgLogger().Info("UpdateAnalyzeJobProgress", zap.Int64("increase processed_rows", delta), zap.Uint64("job id", *job.ID), ) } - } + }) } // finishAnalyzeJob finishes an analyze or merge job @@ -825,7 +825,7 @@ func finishAnalyzeJob(sctx sessionctx.Context, job *statistics.AnalyzeJob, analy logutil.BgLogger().Warn("failed to update analyze job", zap.String("update", fmt.Sprintf("%s->%s", statistics.AnalyzeRunning, state)), zap.Error(err)) } - if val, _err_ := failpoint.Eval(_curpkg_("DebugAnalyzeJobOperations")); _err_ == nil { + failpoint.Inject("DebugAnalyzeJobOperations", func(val failpoint.Value) { if val.(bool) { logger := logutil.BgLogger().With( zap.Time("end_time", job.EndTime), @@ -839,5 +839,5 @@ func finishAnalyzeJob(sctx sessionctx.Context, job *statistics.AnalyzeJob, analy } logger.Info("FinishAnalyzeJob") } - } + }) } diff --git a/pkg/statistics/handle/autoanalyze/autoanalyze.go__failpoint_stash__ b/pkg/statistics/handle/autoanalyze/autoanalyze.go__failpoint_stash__ deleted file mode 100644 index 7ff39511b7865..0000000000000 --- a/pkg/statistics/handle/autoanalyze/autoanalyze.go__failpoint_stash__ +++ /dev/null @@ -1,843 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package autoanalyze - -import ( - "context" - "fmt" - "math/rand" - "net" - "strconv" - "strings" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/domain/infosync" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/sysproctrack" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/statistics" - "github.com/pingcap/tidb/pkg/statistics/handle/autoanalyze/exec" - "github.com/pingcap/tidb/pkg/statistics/handle/autoanalyze/refresher" - "github.com/pingcap/tidb/pkg/statistics/handle/lockstats" - statslogutil "github.com/pingcap/tidb/pkg/statistics/handle/logutil" - statstypes "github.com/pingcap/tidb/pkg/statistics/handle/types" - statsutil "github.com/pingcap/tidb/pkg/statistics/handle/util" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/sqlescape" - "github.com/pingcap/tidb/pkg/util/timeutil" - "go.uber.org/zap" -) - -// statsAnalyze implements util.StatsAnalyze. -// statsAnalyze is used to handle auto-analyze and manage analyze jobs. 
-type statsAnalyze struct { - statsHandle statstypes.StatsHandle - // sysProcTracker is used to track sys process like analyze - sysProcTracker sysproctrack.Tracker -} - -// NewStatsAnalyze creates a new StatsAnalyze. -func NewStatsAnalyze( - statsHandle statstypes.StatsHandle, - sysProcTracker sysproctrack.Tracker, -) statstypes.StatsAnalyze { - return &statsAnalyze{statsHandle: statsHandle, sysProcTracker: sysProcTracker} -} - -// InsertAnalyzeJob inserts the analyze job to the storage. -func (sa *statsAnalyze) InsertAnalyzeJob(job *statistics.AnalyzeJob, instance string, procID uint64) error { - return statsutil.CallWithSCtx(sa.statsHandle.SPool(), func(sctx sessionctx.Context) error { - return insertAnalyzeJob(sctx, job, instance, procID) - }) -} - -func (sa *statsAnalyze) StartAnalyzeJob(job *statistics.AnalyzeJob) { - err := statsutil.CallWithSCtx(sa.statsHandle.SPool(), func(sctx sessionctx.Context) error { - startAnalyzeJob(sctx, job) - return nil - }) - if err != nil { - statslogutil.StatsLogger().Warn("failed to start analyze job", zap.Error(err)) - } -} - -func (sa *statsAnalyze) UpdateAnalyzeJobProgress(job *statistics.AnalyzeJob, rowCount int64) { - err := statsutil.CallWithSCtx(sa.statsHandle.SPool(), func(sctx sessionctx.Context) error { - updateAnalyzeJobProgress(sctx, job, rowCount) - return nil - }) - if err != nil { - statslogutil.StatsLogger().Warn("failed to update analyze job progress", zap.Error(err)) - } -} - -func (sa *statsAnalyze) FinishAnalyzeJob(job *statistics.AnalyzeJob, failReason error, analyzeType statistics.JobType) { - err := statsutil.CallWithSCtx(sa.statsHandle.SPool(), func(sctx sessionctx.Context) error { - finishAnalyzeJob(sctx, job, failReason, analyzeType) - return nil - }) - if err != nil { - statslogutil.StatsLogger().Warn("failed to finish analyze job", zap.Error(err)) - } -} - -// DeleteAnalyzeJobs deletes the analyze jobs whose update time is earlier than updateTime. -func (sa *statsAnalyze) DeleteAnalyzeJobs(updateTime time.Time) error { - return statsutil.CallWithSCtx(sa.statsHandle.SPool(), func(sctx sessionctx.Context) error { - _, _, err := statsutil.ExecRows(sctx, "DELETE FROM mysql.analyze_jobs WHERE update_time < CONVERT_TZ(%?, '+00:00', @@TIME_ZONE)", updateTime.UTC().Format(types.TimeFormat)) - return err - }) -} - -// CleanupCorruptedAnalyzeJobsOnCurrentInstance cleans up the potentially corrupted analyze job. -// It only cleans up the jobs that are associated with the current instance. -func (sa *statsAnalyze) CleanupCorruptedAnalyzeJobsOnCurrentInstance(currentRunningProcessIDs map[uint64]struct{}) error { - return statsutil.CallWithSCtx(sa.statsHandle.SPool(), func(sctx sessionctx.Context) error { - return CleanupCorruptedAnalyzeJobsOnCurrentInstance(sctx, currentRunningProcessIDs) - }, statsutil.FlagWrapTxn) -} - -// CleanupCorruptedAnalyzeJobsOnDeadInstances removes analyze jobs that may have been corrupted. -// Specifically, it removes jobs associated with instances that no longer exist in the cluster. -func (sa *statsAnalyze) CleanupCorruptedAnalyzeJobsOnDeadInstances() error { - return statsutil.CallWithSCtx(sa.statsHandle.SPool(), func(sctx sessionctx.Context) error { - return CleanupCorruptedAnalyzeJobsOnDeadInstances(sctx) - }, statsutil.FlagWrapTxn) -} - -// SelectAnalyzeJobsOnCurrentInstanceSQL is the SQL to select the analyze jobs whose -// state is `pending` or `running` and the update time is more than 10 minutes ago -// and the instance is current instance. 
-const SelectAnalyzeJobsOnCurrentInstanceSQL = `SELECT id, process_id - FROM mysql.analyze_jobs - WHERE instance = %? - AND state IN ('pending', 'running') - AND update_time < CONVERT_TZ(%?, '+00:00', @@TIME_ZONE)` - -// SelectAnalyzeJobsSQL is the SQL to select the analyze jobs whose -// state is `pending` or `running` and the update time is more than 10 minutes ago. -const SelectAnalyzeJobsSQL = `SELECT id, instance - FROM mysql.analyze_jobs - WHERE state IN ('pending', 'running') - AND update_time < CONVERT_TZ(%?, '+00:00', @@TIME_ZONE)` - -// BatchUpdateAnalyzeJobSQL is the SQL to update the analyze jobs to `failed` state. -const BatchUpdateAnalyzeJobSQL = `UPDATE mysql.analyze_jobs - SET state = 'failed', - fail_reason = 'The TiDB Server has either shut down or the analyze query was terminated during the analyze job execution', - process_id = NULL - WHERE id IN (%?)` - -func tenMinutesAgo() string { - return time.Now().Add(-10 * time.Minute).UTC().Format(types.TimeFormat) -} - -// CleanupCorruptedAnalyzeJobsOnCurrentInstance cleans up the potentially corrupted analyze job from current instance. -// Exported for testing. -func CleanupCorruptedAnalyzeJobsOnCurrentInstance( - sctx sessionctx.Context, - currentRunningProcessIDs map[uint64]struct{}, -) error { - serverInfo, err := infosync.GetServerInfo() - if err != nil { - return errors.Trace(err) - } - instance := net.JoinHostPort(serverInfo.IP, strconv.Itoa(int(serverInfo.Port))) - // Get all the analyze jobs whose state is `pending` or `running` and the update time is more than 10 minutes ago - // and the instance is current instance. - rows, _, err := statsutil.ExecRows( - sctx, - SelectAnalyzeJobsOnCurrentInstanceSQL, - instance, - tenMinutesAgo(), - ) - if err != nil { - return errors.Trace(err) - } - - jobIDs := make([]string, 0, len(rows)) - for _, row := range rows { - // The process ID is typically non-null for running or pending jobs. - // However, in rare cases(I don't which case), it may be null. Therefore, it's necessary to check its value. - if !row.IsNull(1) { - processID := row.GetUint64(1) - // If the process id is not in currentRunningProcessIDs, we need to clean up the job. - // They don't belong to current instance any more. - if _, ok := currentRunningProcessIDs[processID]; !ok { - jobID := row.GetUint64(0) - jobIDs = append(jobIDs, strconv.FormatUint(jobID, 10)) - } - } - } - - // Do a batch update to clean up the jobs. - if len(jobIDs) > 0 { - _, _, err = statsutil.ExecRows( - sctx, - BatchUpdateAnalyzeJobSQL, - jobIDs, - ) - if err != nil { - return errors.Trace(err) - } - statslogutil.StatsLogger().Info( - "clean up the potentially corrupted analyze jobs from current instance", - zap.Strings("jobIDs", jobIDs), - ) - } - - return nil -} - -// CleanupCorruptedAnalyzeJobsOnDeadInstances cleans up the potentially corrupted analyze job from dead instances. -func CleanupCorruptedAnalyzeJobsOnDeadInstances( - sctx sessionctx.Context, -) error { - rows, _, err := statsutil.ExecRows( - sctx, - SelectAnalyzeJobsSQL, - tenMinutesAgo(), - ) - if err != nil { - return errors.Trace(err) - } - if len(rows) == 0 { - return nil - } - - // Get all the instances from etcd. 
- serverInfo, err := infosync.GetAllServerInfo(context.Background()) - if err != nil { - return errors.Trace(err) - } - instances := make(map[string]struct{}, len(serverInfo)) - for _, info := range serverInfo { - instance := net.JoinHostPort(info.IP, strconv.Itoa(int(info.Port))) - instances[instance] = struct{}{} - } - - jobIDs := make([]string, 0, len(rows)) - for _, row := range rows { - // If the instance is not in instances, we need to clean up the job. - // It means the instance is down or the instance is not in the cluster any more. - instance := row.GetString(1) - if _, ok := instances[instance]; !ok { - jobID := row.GetUint64(0) - jobIDs = append(jobIDs, strconv.FormatUint(jobID, 10)) - } - } - - // Do a batch update to clean up the jobs. - if len(jobIDs) > 0 { - _, _, err = statsutil.ExecRows( - sctx, - BatchUpdateAnalyzeJobSQL, - jobIDs, - ) - if err != nil { - return errors.Trace(err) - } - statslogutil.StatsLogger().Info( - "clean up the potentially corrupted analyze jobs from dead instances", - zap.Strings("jobIDs", jobIDs), - ) - } - - return nil -} - -// HandleAutoAnalyze analyzes the outdated tables. (The change percent of the table exceeds the threshold) -// It also analyzes newly created tables and newly added indexes. -func (sa *statsAnalyze) HandleAutoAnalyze() (analyzed bool) { - _ = statsutil.CallWithSCtx(sa.statsHandle.SPool(), func(sctx sessionctx.Context) error { - analyzed = HandleAutoAnalyze(sctx, sa.statsHandle, sa.sysProcTracker) - return nil - }) - return -} - -// CheckAnalyzeVersion checks whether all the statistics versions of this table's columns and indexes are the same. -func (sa *statsAnalyze) CheckAnalyzeVersion(tblInfo *model.TableInfo, physicalIDs []int64, version *int) bool { - // We simply choose one physical id to get its stats. - var tbl *statistics.Table - for _, pid := range physicalIDs { - tbl = sa.statsHandle.GetPartitionStats(tblInfo, pid) - if !tbl.Pseudo { - break - } - } - if tbl == nil || tbl.Pseudo { - return true - } - return statistics.CheckAnalyzeVerOnTable(tbl, version) -} - -// HandleAutoAnalyze analyzes the newly created table or index. -func HandleAutoAnalyze( - sctx sessionctx.Context, - statsHandle statstypes.StatsHandle, - sysProcTracker sysproctrack.Tracker, -) (analyzed bool) { - defer func() { - if r := recover(); r != nil { - statslogutil.StatsLogger().Error( - "HandleAutoAnalyze panicked", - zap.Any("recover", r), - zap.Stack("stack"), - ) - } - }() - if variable.EnableAutoAnalyzePriorityQueue.Load() { - r := refresher.NewRefresher(statsHandle, sysProcTracker) - err := r.RebuildTableAnalysisJobQueue() - if err != nil { - statslogutil.StatsLogger().Error("rebuild table analysis job queue failed", zap.Error(err)) - return false - } - return r.PickOneTableAndAnalyzeByPriority() - } - - parameters := exec.GetAutoAnalyzeParameters(sctx) - autoAnalyzeRatio := exec.ParseAutoAnalyzeRatio(parameters[variable.TiDBAutoAnalyzeRatio]) - // Determine the time window for auto-analysis and verify if the current time falls within this range. 
- start, end, err := exec.ParseAutoAnalysisWindow( - parameters[variable.TiDBAutoAnalyzeStartTime], - parameters[variable.TiDBAutoAnalyzeEndTime], - ) - if err != nil { - statslogutil.StatsLogger().Error( - "parse auto analyze period failed", - zap.Error(err), - ) - return false - } - if !timeutil.WithinDayTimePeriod(start, end, time.Now()) { - return false - } - pruneMode := variable.PartitionPruneMode(sctx.GetSessionVars().PartitionPruneMode.Load()) - - return RandomPickOneTableAndTryAutoAnalyze( - sctx, - statsHandle, - sysProcTracker, - autoAnalyzeRatio, - pruneMode, - start, - end, - ) -} - -// RandomPickOneTableAndTryAutoAnalyze randomly picks one table and tries to analyze it. -// 1. If the table is not analyzed, analyze it. -// 2. If the table is analyzed, analyze it when "tbl.ModifyCount/tbl.Count > autoAnalyzeRatio". -// 3. If the table is analyzed, analyze its indices when the index is not analyzed. -// 4. If the table is locked, skip it. -// Exposed solely for testing. -func RandomPickOneTableAndTryAutoAnalyze( - sctx sessionctx.Context, - statsHandle statstypes.StatsHandle, - sysProcTracker sysproctrack.Tracker, - autoAnalyzeRatio float64, - pruneMode variable.PartitionPruneMode, - start, end time.Time, -) bool { - is := sctx.GetDomainInfoSchema().(infoschema.InfoSchema) - dbs := infoschema.AllSchemaNames(is) - // Shuffle the database and table slice to randomize the order of analyzing tables. - rd := rand.New(rand.NewSource(time.Now().UnixNano())) // #nosec G404 - rd.Shuffle(len(dbs), func(i, j int) { - dbs[i], dbs[j] = dbs[j], dbs[i] - }) - // Query locked tables once to minimize overhead. - // Outdated lock info is acceptable as we verify table lock status pre-analysis. - lockedTables, err := lockstats.QueryLockedTables(sctx) - if err != nil { - statslogutil.StatsLogger().Error( - "check table lock failed", - zap.Error(err), - ) - return false - } - - for _, db := range dbs { - // Ignore the memory and system database. - if util.IsMemOrSysDB(strings.ToLower(db)) { - continue - } - - tbls, err := is.SchemaTableInfos(context.Background(), model.NewCIStr(db)) - terror.Log(err) - // We shuffle dbs and tbls so that the order of iterating tables is random. If the order is fixed and the auto - // analyze job of one table fails for some reason, it may always analyze the same table and fail again and again - // when the HandleAutoAnalyze is triggered. Randomizing the order can avoid the problem. - // TODO: Design a priority queue to place the table which needs analyze most in the front. - rd.Shuffle(len(tbls), func(i, j int) { - tbls[i], tbls[j] = tbls[j], tbls[i] - }) - - // We need to check every partition of every table to see if it needs to be analyzed. - for _, tblInfo := range tbls { - // Sometimes the tables are too many. Auto-analyze will take too much time on it. - // so we need to check the available time. - if !timeutil.WithinDayTimePeriod(start, end, time.Now()) { - return false - } - // If table locked, skip analyze all partitions of the table. - // FIXME: This check is not accurate, because other nodes may change the table lock status at any time. - if _, ok := lockedTables[tblInfo.ID]; ok { - continue - } - - if tblInfo.IsView() { - continue - } - - pi := tblInfo.GetPartitionInfo() - // No partitions, analyze the whole table. 
- if pi == nil { - statsTbl := statsHandle.GetTableStatsForAutoAnalyze(tblInfo) - sql := "analyze table %n.%n" - analyzed := tryAutoAnalyzeTable(sctx, statsHandle, sysProcTracker, tblInfo, statsTbl, autoAnalyzeRatio, sql, db, tblInfo.Name.O) - if analyzed { - // analyze one table at a time to let it get the freshest parameters. - // others will be analyzed next round which is just 3s later. - return true - } - continue - } - // Only analyze the partition that has not been locked. - partitionDefs := make([]model.PartitionDefinition, 0, len(pi.Definitions)) - for _, def := range pi.Definitions { - if _, ok := lockedTables[def.ID]; !ok { - partitionDefs = append(partitionDefs, def) - } - } - partitionStats := getPartitionStats(statsHandle, tblInfo, partitionDefs) - if pruneMode == variable.Dynamic { - analyzed := tryAutoAnalyzePartitionTableInDynamicMode( - sctx, - statsHandle, - sysProcTracker, - tblInfo, - partitionDefs, - partitionStats, - db, - autoAnalyzeRatio, - ) - if analyzed { - return true - } - continue - } - for _, def := range partitionDefs { - sql := "analyze table %n.%n partition %n" - statsTbl := partitionStats[def.ID] - analyzed := tryAutoAnalyzeTable(sctx, statsHandle, sysProcTracker, tblInfo, statsTbl, autoAnalyzeRatio, sql, db, tblInfo.Name.O, def.Name.O) - if analyzed { - return true - } - } - } - } - - return false -} - -func getPartitionStats( - statsHandle statstypes.StatsHandle, - tblInfo *model.TableInfo, - defs []model.PartitionDefinition, -) map[int64]*statistics.Table { - partitionStats := make(map[int64]*statistics.Table, len(defs)) - - for _, def := range defs { - partitionStats[def.ID] = statsHandle.GetPartitionStatsForAutoAnalyze(tblInfo, def.ID) - } - - return partitionStats -} - -// Determine whether the table and index require analysis. -func tryAutoAnalyzeTable( - sctx sessionctx.Context, - statsHandle statstypes.StatsHandle, - sysProcTracker sysproctrack.Tracker, - tblInfo *model.TableInfo, - statsTbl *statistics.Table, - ratio float64, - sql string, - params ...any, -) bool { - // 1. If the statistics are either not loaded or are classified as pseudo, there is no need for analyze - // Pseudo statistics can be created by the optimizer, so we need to double check it. - // 2. If the table is too small, we don't want to waste time to analyze it. - // Leave the opportunity to other bigger tables. - if statsTbl == nil || statsTbl.Pseudo || statsTbl.RealtimeCount < statistics.AutoAnalyzeMinCnt { - return false - } - - // Check if the table needs to analyze. - if needAnalyze, reason := NeedAnalyzeTable( - statsTbl, - ratio, - ); needAnalyze { - escaped, err := sqlescape.EscapeSQL(sql, params...) - if err != nil { - return false - } - statslogutil.StatsLogger().Info( - "auto analyze triggered", - zap.String("sql", escaped), - zap.String("reason", reason), - ) - - tableStatsVer := sctx.GetSessionVars().AnalyzeVersion - statistics.CheckAnalyzeVerOnTable(statsTbl, &tableStatsVer) - exec.AutoAnalyze(sctx, statsHandle, sysProcTracker, tableStatsVer, sql, params...) - - return true - } - - // Whether the table needs to analyze or not, we need to check the indices of the table. - for _, idx := range tblInfo.Indices { - if idxStats := statsTbl.GetIdx(idx.ID); idxStats == nil && !statsTbl.ColAndIdxExistenceMap.HasAnalyzed(idx.ID, true) && idx.State == model.StatePublic { - sqlWithIdx := sql + " index %n" - paramsWithIdx := append(params, idx.Name.O) - escaped, err := sqlescape.EscapeSQL(sqlWithIdx, paramsWithIdx...) 
- if err != nil { - return false - } - - statslogutil.StatsLogger().Info( - "auto analyze for unanalyzed indexes", - zap.String("sql", escaped), - ) - tableStatsVer := sctx.GetSessionVars().AnalyzeVersion - statistics.CheckAnalyzeVerOnTable(statsTbl, &tableStatsVer) - exec.AutoAnalyze(sctx, statsHandle, sysProcTracker, tableStatsVer, sqlWithIdx, paramsWithIdx...) - return true - } - } - return false -} - -// NeedAnalyzeTable checks if we need to analyze the table: -// 1. If the table has never been analyzed, we need to analyze it. -// 2. If the table had been analyzed before, we need to analyze it when -// "tbl.ModifyCount/tbl.Count > autoAnalyzeRatio" and the current time is -// between `start` and `end`. -// -// Exposed for test. -func NeedAnalyzeTable(tbl *statistics.Table, autoAnalyzeRatio float64) (bool, string) { - analyzed := tbl.IsAnalyzed() - if !analyzed { - return true, "table unanalyzed" - } - // Auto analyze is disabled. - if autoAnalyzeRatio == 0 { - return false, "" - } - // No need to analyze it. - tblCnt := float64(tbl.RealtimeCount) - if histCnt := tbl.GetAnalyzeRowCount(); histCnt > 0 { - tblCnt = histCnt - } - if float64(tbl.ModifyCount)/tblCnt <= autoAnalyzeRatio { - return false, "" - } - return true, fmt.Sprintf("too many modifications(%v/%v>%v)", tbl.ModifyCount, tblCnt, autoAnalyzeRatio) -} - -// It is very similar to tryAutoAnalyzeTable, but it commits the analyze job in batch for partitions. -func tryAutoAnalyzePartitionTableInDynamicMode( - sctx sessionctx.Context, - statsHandle statstypes.StatsHandle, - sysProcTracker sysproctrack.Tracker, - tblInfo *model.TableInfo, - partitionDefs []model.PartitionDefinition, - partitionStats map[int64]*statistics.Table, - db string, - ratio float64, -) bool { - tableStatsVer := sctx.GetSessionVars().AnalyzeVersion - analyzePartitionBatchSize := int(variable.AutoAnalyzePartitionBatchSize.Load()) - needAnalyzePartitionNames := make([]any, 0, len(partitionDefs)) - - for _, def := range partitionDefs { - partitionStats := partitionStats[def.ID] - // 1. If the statistics are either not loaded or are classified as pseudo, there is no need for analyze. - // Pseudo statistics can be created by the optimizer, so we need to double check it. - // 2. If the table is too small, we don't want to waste time to analyze it. - // Leave the opportunity to other bigger tables. 
- if partitionStats == nil || partitionStats.Pseudo || partitionStats.RealtimeCount < statistics.AutoAnalyzeMinCnt { - continue - } - if needAnalyze, reason := NeedAnalyzeTable( - partitionStats, - ratio, - ); needAnalyze { - needAnalyzePartitionNames = append(needAnalyzePartitionNames, def.Name.O) - statslogutil.StatsLogger().Info( - "need to auto analyze", - zap.String("database", db), - zap.String("table", tblInfo.Name.String()), - zap.String("partition", def.Name.O), - zap.String("reason", reason), - ) - statistics.CheckAnalyzeVerOnTable(partitionStats, &tableStatsVer) - } - } - - getSQL := func(prefix, suffix string, numPartitions int) string { - var sqlBuilder strings.Builder - sqlBuilder.WriteString(prefix) - for i := 0; i < numPartitions; i++ { - if i != 0 { - sqlBuilder.WriteString(",") - } - sqlBuilder.WriteString(" %n") - } - sqlBuilder.WriteString(suffix) - return sqlBuilder.String() - } - - if len(needAnalyzePartitionNames) > 0 { - statslogutil.StatsLogger().Info("start to auto analyze", - zap.String("database", db), - zap.String("table", tblInfo.Name.String()), - zap.Any("partitions", needAnalyzePartitionNames), - zap.Int("analyze partition batch size", analyzePartitionBatchSize), - ) - - statsTbl := statsHandle.GetTableStats(tblInfo) - statistics.CheckAnalyzeVerOnTable(statsTbl, &tableStatsVer) - for i := 0; i < len(needAnalyzePartitionNames); i += analyzePartitionBatchSize { - start := i - end := start + analyzePartitionBatchSize - if end >= len(needAnalyzePartitionNames) { - end = len(needAnalyzePartitionNames) - } - - // Do batch analyze for partitions. - sql := getSQL("analyze table %n.%n partition", "", end-start) - params := append([]any{db, tblInfo.Name.O}, needAnalyzePartitionNames[start:end]...) - - statslogutil.StatsLogger().Info( - "auto analyze triggered", - zap.String("database", db), - zap.String("table", tblInfo.Name.String()), - zap.Any("partitions", needAnalyzePartitionNames[start:end]), - ) - exec.AutoAnalyze(sctx, statsHandle, sysProcTracker, tableStatsVer, sql, params...) - } - - return true - } - // Check if any index of the table needs to analyze. - for _, idx := range tblInfo.Indices { - if idx.State != model.StatePublic { - continue - } - // Collect all the partition names that need to analyze. - for _, def := range partitionDefs { - partitionStats := partitionStats[def.ID] - // 1. If the statistics are either not loaded or are classified as pseudo, there is no need for analyze. - // Pseudo statistics can be created by the optimizer, so we need to double check it. - if partitionStats == nil || partitionStats.Pseudo { - continue - } - // 2. If the index is not analyzed, we need to analyze it. - if !partitionStats.ColAndIdxExistenceMap.HasAnalyzed(idx.ID, true) { - needAnalyzePartitionNames = append(needAnalyzePartitionNames, def.Name.O) - statistics.CheckAnalyzeVerOnTable(partitionStats, &tableStatsVer) - } - } - if len(needAnalyzePartitionNames) > 0 { - statsTbl := statsHandle.GetTableStats(tblInfo) - statistics.CheckAnalyzeVerOnTable(statsTbl, &tableStatsVer) - - for i := 0; i < len(needAnalyzePartitionNames); i += analyzePartitionBatchSize { - start := i - end := start + analyzePartitionBatchSize - if end >= len(needAnalyzePartitionNames) { - end = len(needAnalyzePartitionNames) - } - - sql := getSQL("analyze table %n.%n partition", " index %n", end-start) - params := append([]any{db, tblInfo.Name.O}, needAnalyzePartitionNames[start:end]...) 
- params = append(params, idx.Name.O) - statslogutil.StatsLogger().Info("auto analyze for unanalyzed", - zap.String("database", db), - zap.String("table", tblInfo.Name.String()), - zap.String("index", idx.Name.String()), - zap.Any("partitions", needAnalyzePartitionNames[start:end]), - ) - exec.AutoAnalyze(sctx, statsHandle, sysProcTracker, tableStatsVer, sql, params...) - } - - return true - } - } - - return false -} - -// insertAnalyzeJob inserts analyze job into mysql.analyze_jobs and gets job ID for further updating job. -func insertAnalyzeJob(sctx sessionctx.Context, job *statistics.AnalyzeJob, instance string, procID uint64) (err error) { - jobInfo := job.JobInfo - const textMaxLength = 65535 - if len(jobInfo) > textMaxLength { - jobInfo = jobInfo[:textMaxLength] - } - const insertJob = "INSERT INTO mysql.analyze_jobs (table_schema, table_name, partition_name, job_info, state, instance, process_id) VALUES (%?, %?, %?, %?, %?, %?, %?)" - _, _, err = statsutil.ExecRows(sctx, insertJob, job.DBName, job.TableName, job.PartitionName, jobInfo, statistics.AnalyzePending, instance, procID) - if err != nil { - return err - } - const getJobID = "SELECT LAST_INSERT_ID()" - rows, _, err := statsutil.ExecRows(sctx, getJobID) - if err != nil { - return err - } - job.ID = new(uint64) - *job.ID = rows[0].GetUint64(0) - failpoint.Inject("DebugAnalyzeJobOperations", func(val failpoint.Value) { - if val.(bool) { - logutil.BgLogger().Info("InsertAnalyzeJob", - zap.String("table_schema", job.DBName), - zap.String("table_name", job.TableName), - zap.String("partition_name", job.PartitionName), - zap.String("job_info", jobInfo), - zap.Uint64("job_id", *job.ID), - ) - } - }) - return nil -} - -// startAnalyzeJob marks the state of the analyze job as running and sets the start time. -func startAnalyzeJob(sctx sessionctx.Context, job *statistics.AnalyzeJob) { - if job == nil || job.ID == nil { - return - } - job.StartTime = time.Now() - job.Progress.SetLastDumpTime(job.StartTime) - const sql = "UPDATE mysql.analyze_jobs SET start_time = CONVERT_TZ(%?, '+00:00', @@TIME_ZONE), state = %? WHERE id = %?" - _, _, err := statsutil.ExecRows(sctx, sql, job.StartTime.UTC().Format(types.TimeFormat), statistics.AnalyzeRunning, *job.ID) - if err != nil { - statslogutil.StatsLogger().Warn("failed to update analyze job", zap.String("update", fmt.Sprintf("%s->%s", statistics.AnalyzePending, statistics.AnalyzeRunning)), zap.Error(err)) - } - failpoint.Inject("DebugAnalyzeJobOperations", func(val failpoint.Value) { - if val.(bool) { - logutil.BgLogger().Info("StartAnalyzeJob", - zap.Time("start_time", job.StartTime), - zap.Uint64("job id", *job.ID), - ) - } - }) -} - -// updateAnalyzeJobProgress updates count of the processed rows when increment reaches a threshold. -func updateAnalyzeJobProgress(sctx sessionctx.Context, job *statistics.AnalyzeJob, rowCount int64) { - if job == nil || job.ID == nil { - return - } - delta := job.Progress.Update(rowCount) - if delta == 0 { - return - } - const sql = "UPDATE mysql.analyze_jobs SET processed_rows = processed_rows + %? WHERE id = %?" 
- _, _, err := statsutil.ExecRows(sctx, sql, delta, *job.ID) - if err != nil { - statslogutil.StatsLogger().Warn("failed to update analyze job", zap.String("update", fmt.Sprintf("process %v rows", delta)), zap.Error(err)) - } - failpoint.Inject("DebugAnalyzeJobOperations", func(val failpoint.Value) { - if val.(bool) { - logutil.BgLogger().Info("UpdateAnalyzeJobProgress", - zap.Int64("increase processed_rows", delta), - zap.Uint64("job id", *job.ID), - ) - } - }) -} - -// finishAnalyzeJob finishes an analyze or merge job -func finishAnalyzeJob(sctx sessionctx.Context, job *statistics.AnalyzeJob, analyzeErr error, analyzeType statistics.JobType) { - if job == nil || job.ID == nil { - return - } - - job.EndTime = time.Now() - var sql string - var args []any - - // process_id is used to see which process is running the analyze job and kill the analyze job. After the analyze job - // is finished(or failed), process_id is useless and we set it to NULL to avoid `kill tidb process_id` wrongly. - if analyzeErr != nil { - failReason := analyzeErr.Error() - const textMaxLength = 65535 - if len(failReason) > textMaxLength { - failReason = failReason[:textMaxLength] - } - - if analyzeType == statistics.TableAnalysisJob { - sql = "UPDATE mysql.analyze_jobs SET processed_rows = processed_rows + %?, end_time = CONVERT_TZ(%?, '+00:00', @@TIME_ZONE), state = %?, fail_reason = %?, process_id = NULL WHERE id = %?" - args = []any{job.Progress.GetDeltaCount(), job.EndTime.UTC().Format(types.TimeFormat), statistics.AnalyzeFailed, failReason, *job.ID} - } else { - sql = "UPDATE mysql.analyze_jobs SET end_time = CONVERT_TZ(%?, '+00:00', @@TIME_ZONE), state = %?, fail_reason = %?, process_id = NULL WHERE id = %?" - args = []any{job.EndTime.UTC().Format(types.TimeFormat), statistics.AnalyzeFailed, failReason, *job.ID} - } - } else { - if analyzeType == statistics.TableAnalysisJob { - sql = "UPDATE mysql.analyze_jobs SET processed_rows = processed_rows + %?, end_time = CONVERT_TZ(%?, '+00:00', @@TIME_ZONE), state = %?, process_id = NULL WHERE id = %?" - args = []any{job.Progress.GetDeltaCount(), job.EndTime.UTC().Format(types.TimeFormat), statistics.AnalyzeFinished, *job.ID} - } else { - sql = "UPDATE mysql.analyze_jobs SET end_time = CONVERT_TZ(%?, '+00:00', @@TIME_ZONE), state = %?, process_id = NULL WHERE id = %?" - args = []any{job.EndTime.UTC().Format(types.TimeFormat), statistics.AnalyzeFinished, *job.ID} - } - } - - _, _, err := statsutil.ExecRows(sctx, sql, args...) 
- if err != nil { - state := statistics.AnalyzeFinished - if analyzeErr != nil { - state = statistics.AnalyzeFailed - } - logutil.BgLogger().Warn("failed to update analyze job", zap.String("update", fmt.Sprintf("%s->%s", statistics.AnalyzeRunning, state)), zap.Error(err)) - } - - failpoint.Inject("DebugAnalyzeJobOperations", func(val failpoint.Value) { - if val.(bool) { - logger := logutil.BgLogger().With( - zap.Time("end_time", job.EndTime), - zap.Uint64("job id", *job.ID), - ) - if analyzeType == statistics.TableAnalysisJob { - logger = logger.With(zap.Int64("increase processed_rows", job.Progress.GetDeltaCount())) - } - if analyzeErr != nil { - logger = logger.With(zap.Error(analyzeErr)) - } - logger.Info("FinishAnalyzeJob") - } - }) -} diff --git a/pkg/statistics/handle/autoanalyze/binding__failpoint_binding__.go b/pkg/statistics/handle/autoanalyze/binding__failpoint_binding__.go deleted file mode 100644 index 258dfa09d159f..0000000000000 --- a/pkg/statistics/handle/autoanalyze/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package autoanalyze - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/statistics/handle/binding__failpoint_binding__.go b/pkg/statistics/handle/binding__failpoint_binding__.go deleted file mode 100644 index 947a787548ae9..0000000000000 --- a/pkg/statistics/handle/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package handle - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/statistics/handle/bootstrap.go b/pkg/statistics/handle/bootstrap.go index 24cd4233acc27..c2157f21eed39 100644 --- a/pkg/statistics/handle/bootstrap.go +++ b/pkg/statistics/handle/bootstrap.go @@ -734,7 +734,7 @@ func (h *Handle) InitStatsLite(ctx context.Context, is infoschema.InfoSchema) (e if err != nil { return err } - failpoint.Eval(_curpkg_("beforeInitStatsLite")) + failpoint.Inject("beforeInitStatsLite", func() {}) cache, err := h.initStatsMeta(ctx, is) if err != nil { return errors.Trace(err) @@ -769,7 +769,7 @@ func (h *Handle) InitStats(ctx context.Context, is infoschema.InfoSchema) (err e if err != nil { return err } - failpoint.Eval(_curpkg_("beforeInitStats")) + failpoint.Inject("beforeInitStats", func() {}) cache, err := h.initStatsMeta(ctx, is) if err != nil { return errors.Trace(err) diff --git a/pkg/statistics/handle/bootstrap.go__failpoint_stash__ b/pkg/statistics/handle/bootstrap.go__failpoint_stash__ deleted file mode 100644 index c2157f21eed39..0000000000000 --- a/pkg/statistics/handle/bootstrap.go__failpoint_stash__ +++ /dev/null @@ -1,815 +0,0 @@ -// Copyright 2017 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package handle - -import ( - "context" - "sync" - "sync/atomic" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/statistics" - "github.com/pingcap/tidb/pkg/statistics/handle/cache" - "github.com/pingcap/tidb/pkg/statistics/handle/initstats" - statslogutil "github.com/pingcap/tidb/pkg/statistics/handle/logutil" - statstypes "github.com/pingcap/tidb/pkg/statistics/handle/types" - "github.com/pingcap/tidb/pkg/statistics/handle/util" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "go.uber.org/zap" -) - -// initStatsStep is the step to load stats by paging. -const initStatsStep = int64(500) - -var maxTidRecord MaxTidRecord - -// MaxTidRecord is to record the max tid. -type MaxTidRecord struct { - mu sync.Mutex - tid atomic.Int64 -} - -func (h *Handle) initStatsMeta4Chunk(ctx context.Context, is infoschema.InfoSchema, cache statstypes.StatsCache, iter *chunk.Iterator4Chunk) { - var physicalID, maxPhysicalID int64 - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - physicalID = row.GetInt64(1) - - // Detect the context cancel signal, since it may take a long time for the loop. - // TODO: add context to TableInfoByID and remove this code block? - if ctx.Err() != nil { - return - } - - // The table is read-only. Please do not modify it. - table, ok := h.TableInfoByID(is, physicalID) - if !ok { - logutil.BgLogger().Debug("unknown physical ID in stats meta table, maybe it has been dropped", zap.Int64("ID", physicalID)) - continue - } - maxPhysicalID = max(physicalID, maxPhysicalID) - tableInfo := table.Meta() - newHistColl := *statistics.NewHistColl(physicalID, true, row.GetInt64(3), row.GetInt64(2), 4, 4) - snapshot := row.GetUint64(4) - tbl := &statistics.Table{ - HistColl: newHistColl, - Version: row.GetUint64(0), - ColAndIdxExistenceMap: statistics.NewColAndIndexExistenceMap(len(tableInfo.Columns), len(tableInfo.Indices)), - IsPkIsHandle: tableInfo.PKIsHandle, - // During the initialization phase, we need to initialize LastAnalyzeVersion with the snapshot, - // which ensures that we don't duplicate the auto-analyze of a particular type of table. - // When the predicate columns feature is turned on, if a table has neither predicate columns nor indexes, - // then auto-analyze will only analyze the _row_id and refresh stats_meta, - // but since we don't have any histograms or topn's created for _row_id at the moment. - // So if we don't initialize LastAnalyzeVersion with the snapshot here, - // it will stay at 0 and auto-analyze won't be able to detect that the table has been analyzed. 
- // But in the future, we maybe will create some records for _row_id, see: - // https://github.com/pingcap/tidb/issues/51098 - LastAnalyzeVersion: snapshot, - } - cache.Put(physicalID, tbl) // put this table again since it is updated - } - maxTidRecord.mu.Lock() - defer maxTidRecord.mu.Unlock() - if maxTidRecord.tid.Load() < maxPhysicalID { - maxTidRecord.tid.Store(physicalID) - } -} - -func (h *Handle) initStatsMeta(ctx context.Context, is infoschema.InfoSchema) (statstypes.StatsCache, error) { - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnStats) - sql := "select HIGH_PRIORITY version, table_id, modify_count, count, snapshot from mysql.stats_meta" - rc, err := util.Exec(h.initStatsCtx, sql) - if err != nil { - return nil, errors.Trace(err) - } - defer terror.Call(rc.Close) - tables, err := cache.NewStatsCacheImpl(h) - if err != nil { - return nil, err - } - req := rc.NewChunk(nil) - iter := chunk.NewIterator4Chunk(req) - for { - err := rc.Next(ctx, req) - if err != nil { - return nil, errors.Trace(err) - } - if req.NumRows() == 0 { - break - } - h.initStatsMeta4Chunk(ctx, is, tables, iter) - } - return tables, nil -} - -func (h *Handle) initStatsHistograms4ChunkLite(is infoschema.InfoSchema, cache statstypes.StatsCache, iter *chunk.Iterator4Chunk) { - var table *statistics.Table - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - tblID := row.GetInt64(0) - if table == nil || table.PhysicalID != tblID { - if table != nil { - cache.Put(table.PhysicalID, table) // put this table in the cache because all statstics of the table have been read. - } - var ok bool - table, ok = cache.Get(tblID) - if !ok { - continue - } - table = table.Copy() - } - isIndex := row.GetInt64(1) - id := row.GetInt64(2) - ndv := row.GetInt64(3) - nullCount := row.GetInt64(5) - statsVer := row.GetInt64(7) - tbl, _ := h.TableInfoByID(is, table.PhysicalID) - // All the objects in the table share the same stats version. - if statsVer != statistics.Version0 { - table.StatsVer = int(statsVer) - } - if isIndex > 0 { - var idxInfo *model.IndexInfo - for _, idx := range tbl.Meta().Indices { - if idx.ID == id { - idxInfo = idx - break - } - } - if idxInfo == nil { - continue - } - table.ColAndIdxExistenceMap.InsertIndex(idxInfo.ID, idxInfo, statsVer != statistics.Version0) - if statsVer != statistics.Version0 { - // The LastAnalyzeVersion is added by ALTER table so its value might be 0. - table.LastAnalyzeVersion = max(table.LastAnalyzeVersion, row.GetUint64(4)) - } - } else { - var colInfo *model.ColumnInfo - for _, col := range tbl.Meta().Columns { - if col.ID == id { - colInfo = col - break - } - } - if colInfo == nil { - continue - } - table.ColAndIdxExistenceMap.InsertCol(colInfo.ID, colInfo, statsVer != statistics.Version0 || ndv > 0 || nullCount > 0) - if statsVer != statistics.Version0 { - // The LastAnalyzeVersion is added by ALTER table so its value might be 0. - table.LastAnalyzeVersion = max(table.LastAnalyzeVersion, row.GetUint64(4)) - } - } - } - if table != nil { - cache.Put(table.PhysicalID, table) // put this table in the cache because all statstics of the table have been read. 
- } -} - -func (h *Handle) initStatsHistograms4Chunk(is infoschema.InfoSchema, cache statstypes.StatsCache, iter *chunk.Iterator4Chunk, isCacheFull bool) { - var table *statistics.Table - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - tblID, statsVer := row.GetInt64(0), row.GetInt64(8) - if table == nil || table.PhysicalID != tblID { - if table != nil { - cache.Put(table.PhysicalID, table) // put this table in the cache because all statstics of the table have been read. - } - var ok bool - table, ok = cache.Get(tblID) - if !ok { - continue - } - table = table.Copy() - } - // All the objects in the table share the same stats version. - if statsVer != statistics.Version0 { - table.StatsVer = int(statsVer) - } - id, ndv, nullCount, version, totColSize := row.GetInt64(2), row.GetInt64(3), row.GetInt64(5), row.GetUint64(4), row.GetInt64(7) - lastAnalyzePos := row.GetDatum(11, types.NewFieldType(mysql.TypeBlob)) - tbl, _ := h.TableInfoByID(is, table.PhysicalID) - if row.GetInt64(1) > 0 { - var idxInfo *model.IndexInfo - for _, idx := range tbl.Meta().Indices { - if idx.ID == id { - idxInfo = idx - break - } - } - if idxInfo == nil { - continue - } - - var cms *statistics.CMSketch - var topN *statistics.TopN - var err error - if !isCacheFull { - // stats cache is full. we should not put it into cache. but we must set LastAnalyzeVersion - cms, topN, err = statistics.DecodeCMSketchAndTopN(row.GetBytes(6), nil) - if err != nil { - cms = nil - terror.Log(errors.Trace(err)) - } - } - hist := statistics.NewHistogram(id, ndv, nullCount, version, types.NewFieldType(mysql.TypeBlob), chunk.InitialCapacity, 0) - index := &statistics.Index{ - Histogram: *hist, - CMSketch: cms, - TopN: topN, - Info: idxInfo, - StatsVer: statsVer, - Flag: row.GetInt64(10), - PhysicalID: tblID, - } - if statsVer != statistics.Version0 { - // We first set the StatsLoadedStatus as AllEvicted. when completing to load bucket, we will set it as ALlLoad. - index.StatsLoadedStatus = statistics.NewStatsAllEvictedStatus() - // The LastAnalyzeVersion is added by ALTER table so its value might be 0. - table.LastAnalyzeVersion = max(table.LastAnalyzeVersion, version) - } - lastAnalyzePos.Copy(&index.LastAnalyzePos) - table.SetIdx(idxInfo.ID, index) - table.ColAndIdxExistenceMap.InsertIndex(idxInfo.ID, idxInfo, statsVer != statistics.Version0) - } else { - var colInfo *model.ColumnInfo - for _, col := range tbl.Meta().Columns { - if col.ID == id { - colInfo = col - break - } - } - if colInfo == nil { - continue - } - hist := statistics.NewHistogram(id, ndv, nullCount, version, &colInfo.FieldType, 0, totColSize) - hist.Correlation = row.GetFloat64(9) - col := &statistics.Column{ - Histogram: *hist, - PhysicalID: table.PhysicalID, - Info: colInfo, - IsHandle: tbl.Meta().PKIsHandle && mysql.HasPriKeyFlag(colInfo.GetFlag()), - Flag: row.GetInt64(10), - StatsVer: statsVer, - } - // primary key column has no stats info, because primary key's is_index is false. so it cannot load the topn - col.StatsLoadedStatus = statistics.NewStatsAllEvictedStatus() - lastAnalyzePos.Copy(&col.LastAnalyzePos) - table.SetCol(hist.ID, col) - table.ColAndIdxExistenceMap.InsertCol(colInfo.ID, colInfo, statsVer != statistics.Version0 || ndv > 0 || nullCount > 0) - if statsVer != statistics.Version0 { - // The LastAnalyzeVersion is added by ALTER table so its value might be 0. 
- table.LastAnalyzeVersion = max(table.LastAnalyzeVersion, version) - } - } - } - if table != nil { - cache.Put(table.PhysicalID, table) // put this table in the cache because all statstics of the table have been read. - } -} - -func (h *Handle) initStatsHistogramsLite(ctx context.Context, is infoschema.InfoSchema, cache statstypes.StatsCache) error { - sql := "select /*+ ORDER_INDEX(mysql.stats_histograms,tbl)*/ HIGH_PRIORITY table_id, is_index, hist_id, distinct_count, version, null_count, tot_col_size, stats_ver, correlation, flag, last_analyze_pos from mysql.stats_histograms order by table_id" - rc, err := util.Exec(h.initStatsCtx, sql) - if err != nil { - return errors.Trace(err) - } - defer terror.Call(rc.Close) - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnStats) - req := rc.NewChunk(nil) - iter := chunk.NewIterator4Chunk(req) - for { - err := rc.Next(ctx, req) - if err != nil { - return errors.Trace(err) - } - if req.NumRows() == 0 { - break - } - h.initStatsHistograms4ChunkLite(is, cache, iter) - } - return nil -} - -func (h *Handle) initStatsHistograms(is infoschema.InfoSchema, cache statstypes.StatsCache) error { - sql := "select /*+ ORDER_INDEX(mysql.stats_histograms,tbl)*/ HIGH_PRIORITY table_id, is_index, hist_id, distinct_count, version, null_count, cm_sketch, tot_col_size, stats_ver, correlation, flag, last_analyze_pos from mysql.stats_histograms order by table_id" - rc, err := util.Exec(h.initStatsCtx, sql) - if err != nil { - return errors.Trace(err) - } - defer terror.Call(rc.Close) - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) - req := rc.NewChunk(nil) - iter := chunk.NewIterator4Chunk(req) - for { - err := rc.Next(ctx, req) - if err != nil { - return errors.Trace(err) - } - if req.NumRows() == 0 { - break - } - h.initStatsHistograms4Chunk(is, cache, iter, false) - } - return nil -} - -func (h *Handle) initStatsHistogramsByPaging(is infoschema.InfoSchema, cache statstypes.StatsCache, task initstats.Task, totalMemory uint64) error { - se, err := h.Pool.SPool().Get() - if err != nil { - return err - } - defer func() { - if err == nil { // only recycle when no error - h.Pool.SPool().Put(se) - } - }() - - sctx := se.(sessionctx.Context) - // Why do we need to add `is_index=1` in the SQL? - // because it is aligned to the `initStatsTopN` function, which only loads the topn of the index too. - // the other will be loaded by sync load. - sql := "select HIGH_PRIORITY table_id, is_index, hist_id, distinct_count, version, null_count, cm_sketch, tot_col_size, stats_ver, correlation, flag, last_analyze_pos from mysql.stats_histograms where table_id >= %? and table_id < %? 
and is_index=1" - rc, err := util.Exec(sctx, sql, task.StartTid, task.EndTid) - if err != nil { - return errors.Trace(err) - } - defer terror.Call(rc.Close) - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) - req := rc.NewChunk(nil) - iter := chunk.NewIterator4Chunk(req) - for { - err := rc.Next(ctx, req) - if err != nil { - return errors.Trace(err) - } - if req.NumRows() == 0 { - break - } - h.initStatsHistograms4Chunk(is, cache, iter, isFullCache(cache, totalMemory)) - } - return nil -} - -func (h *Handle) initStatsHistogramsConcurrency(is infoschema.InfoSchema, cache statstypes.StatsCache, totalMemory uint64) error { - var maxTid = maxTidRecord.tid.Load() - tid := int64(0) - ls := initstats.NewRangeWorker("histogram", func(task initstats.Task) error { - return h.initStatsHistogramsByPaging(is, cache, task, totalMemory) - }, uint64(maxTid), uint64(initStatsStep)) - ls.LoadStats() - for tid <= maxTid { - ls.SendTask(initstats.Task{ - StartTid: tid, - EndTid: tid + initStatsStep, - }) - tid += initStatsStep - } - ls.Wait() - return nil -} - -func (*Handle) initStatsTopN4Chunk(cache statstypes.StatsCache, iter *chunk.Iterator4Chunk, totalMemory uint64) { - if isFullCache(cache, totalMemory) { - return - } - affectedIndexes := make(map[*statistics.Index]struct{}) - var table *statistics.Table - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - tblID := row.GetInt64(0) - if table == nil || table.PhysicalID != tblID { - if table != nil { - cache.Put(table.PhysicalID, table) // put this table in the cache because all statstics of the table have been read. - } - var ok bool - table, ok = cache.Get(tblID) - if !ok { - continue - } - table = table.Copy() - } - idx := table.GetIdx(row.GetInt64(1)) - if idx == nil || (idx.CMSketch == nil && idx.StatsVer <= statistics.Version1) { - continue - } - if idx.TopN == nil { - idx.TopN = statistics.NewTopN(32) - } - affectedIndexes[idx] = struct{}{} - data := make([]byte, len(row.GetBytes(2))) - copy(data, row.GetBytes(2)) - idx.TopN.AppendTopN(data, row.GetUint64(3)) - } - if table != nil { - cache.Put(table.PhysicalID, table) // put this table in the cache because all statstics of the table have been read. - } - for idx := range affectedIndexes { - idx.TopN.Sort() - } -} - -func (h *Handle) initStatsTopN(cache statstypes.StatsCache, totalMemory uint64) error { - sql := "select /*+ ORDER_INDEX(mysql.stats_top_n,tbl)*/ HIGH_PRIORITY table_id, hist_id, value, count from mysql.stats_top_n where is_index = 1 order by table_id" - rc, err := util.Exec(h.initStatsCtx, sql) - if err != nil { - return errors.Trace(err) - } - defer terror.Call(rc.Close) - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) - req := rc.NewChunk(nil) - iter := chunk.NewIterator4Chunk(req) - for { - err := rc.Next(ctx, req) - if err != nil { - return errors.Trace(err) - } - if req.NumRows() == 0 { - break - } - h.initStatsTopN4Chunk(cache, iter, totalMemory) - } - return nil -} - -func (h *Handle) initStatsTopNByPaging(cache statstypes.StatsCache, task initstats.Task, totalMemory uint64) error { - se, err := h.Pool.SPool().Get() - if err != nil { - return err - } - defer func() { - if err == nil { // only recycle when no error - h.Pool.SPool().Put(se) - } - }() - sctx := se.(sessionctx.Context) - sql := "select HIGH_PRIORITY table_id, hist_id, value, count from mysql.stats_top_n where is_index = 1 and table_id >= %? and table_id < %? 
order by table_id" - rc, err := util.Exec(sctx, sql, task.StartTid, task.EndTid) - if err != nil { - return errors.Trace(err) - } - defer terror.Call(rc.Close) - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) - req := rc.NewChunk(nil) - iter := chunk.NewIterator4Chunk(req) - for { - err := rc.Next(ctx, req) - if err != nil { - return errors.Trace(err) - } - if req.NumRows() == 0 { - break - } - h.initStatsTopN4Chunk(cache, iter, totalMemory) - } - return nil -} - -func (h *Handle) initStatsTopNConcurrency(cache statstypes.StatsCache, totalMemory uint64) error { - if isFullCache(cache, totalMemory) { - return nil - } - var maxTid = maxTidRecord.tid.Load() - tid := int64(0) - ls := initstats.NewRangeWorker("TopN", func(task initstats.Task) error { - if isFullCache(cache, totalMemory) { - return nil - } - return h.initStatsTopNByPaging(cache, task, totalMemory) - }, uint64(maxTid), uint64(initStatsStep)) - ls.LoadStats() - for tid <= maxTid { - if isFullCache(cache, totalMemory) { - break - } - ls.SendTask(initstats.Task{ - StartTid: tid, - EndTid: tid + initStatsStep, - }) - tid += initStatsStep - } - ls.Wait() - return nil -} - -func (*Handle) initStatsFMSketch4Chunk(cache statstypes.StatsCache, iter *chunk.Iterator4Chunk) { - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - table, ok := cache.Get(row.GetInt64(0)) - if !ok { - continue - } - fms, err := statistics.DecodeFMSketch(row.GetBytes(3)) - if err != nil { - fms = nil - terror.Log(errors.Trace(err)) - } - - isIndex := row.GetInt64(1) - id := row.GetInt64(2) - if isIndex == 1 { - if idxStats := table.GetIdx(id); idxStats != nil { - idxStats.FMSketch = fms - } - } else { - if colStats := table.GetCol(id); colStats != nil { - colStats.FMSketch = fms - } - } - cache.Put(table.PhysicalID, table) // put this table in the cache because all statstics of the table have been read. - } -} - -func (h *Handle) initStatsFMSketch(cache statstypes.StatsCache) error { - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) - sql := "select HIGH_PRIORITY table_id, is_index, hist_id, value from mysql.stats_fm_sketch" - rc, err := util.Exec(h.initStatsCtx, sql) - if err != nil { - return errors.Trace(err) - } - defer terror.Call(rc.Close) - req := rc.NewChunk(nil) - iter := chunk.NewIterator4Chunk(req) - for { - err := rc.Next(ctx, req) - if err != nil { - return errors.Trace(err) - } - if req.NumRows() == 0 { - break - } - h.initStatsFMSketch4Chunk(cache, iter) - } - return nil -} - -func (*Handle) initStatsBuckets4Chunk(cache statstypes.StatsCache, iter *chunk.Iterator4Chunk) { - var table *statistics.Table - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - tableID, isIndex, histID := row.GetInt64(0), row.GetInt64(1), row.GetInt64(2) - if table == nil || table.PhysicalID != tableID { - if table != nil { - table.SetAllIndexFullLoadForBootstrap() - cache.Put(table.PhysicalID, table) // put this table in the cache because all statstics of the table have been read. 
- } - var ok bool - table, ok = cache.Get(tableID) - if !ok { - continue - } - table = table.Copy() - } - var lower, upper types.Datum - var hist *statistics.Histogram - if isIndex > 0 { - index := table.GetIdx(histID) - if index == nil { - continue - } - hist = &index.Histogram - lower, upper = types.NewBytesDatum(row.GetBytes(5)), types.NewBytesDatum(row.GetBytes(6)) - } else { - column := table.GetCol(histID) - if column == nil { - continue - } - if !mysql.HasPriKeyFlag(column.Info.GetFlag()) { - continue - } - hist = &column.Histogram - d := types.NewBytesDatum(row.GetBytes(5)) - var err error - lower, err = d.ConvertTo(statistics.UTCWithAllowInvalidDateCtx, &column.Info.FieldType) - if err != nil { - logutil.BgLogger().Debug("decode bucket lower bound failed", zap.Error(err)) - table.DelCol(histID) - continue - } - d = types.NewBytesDatum(row.GetBytes(6)) - upper, err = d.ConvertTo(statistics.UTCWithAllowInvalidDateCtx, &column.Info.FieldType) - if err != nil { - logutil.BgLogger().Debug("decode bucket upper bound failed", zap.Error(err)) - table.DelCol(histID) - continue - } - } - hist.AppendBucketWithNDV(&lower, &upper, row.GetInt64(3), row.GetInt64(4), row.GetInt64(7)) - } - if table != nil { - cache.Put(table.PhysicalID, table) // put this table in the cache because all statstics of the table have been read. - } -} - -func (h *Handle) initStatsBuckets(cache statstypes.StatsCache, totalMemory uint64) error { - if isFullCache(cache, totalMemory) { - return nil - } - if config.GetGlobalConfig().Performance.ConcurrentlyInitStats { - err := h.initStatsBucketsConcurrency(cache, totalMemory) - if err != nil { - return errors.Trace(err) - } - } else { - sql := "select /*+ ORDER_INDEX(mysql.stats_buckets,tbl)*/ HIGH_PRIORITY table_id, is_index, hist_id, count, repeats, lower_bound, upper_bound, ndv from mysql.stats_buckets order by table_id, is_index, hist_id, bucket_id" - rc, err := util.Exec(h.initStatsCtx, sql) - if err != nil { - return errors.Trace(err) - } - defer terror.Call(rc.Close) - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) - req := rc.NewChunk(nil) - iter := chunk.NewIterator4Chunk(req) - for { - err := rc.Next(ctx, req) - if err != nil { - return errors.Trace(err) - } - if req.NumRows() == 0 { - break - } - h.initStatsBuckets4Chunk(cache, iter) - } - } - tables := cache.Values() - for _, table := range tables { - table.CalcPreScalar() - cache.Put(table.PhysicalID, table) // put this table in the cache because all statstics of the table have been read. - } - return nil -} - -func (h *Handle) initStatsBucketsByPaging(cache statstypes.StatsCache, task initstats.Task) error { - se, err := h.Pool.SPool().Get() - if err != nil { - return err - } - defer func() { - if err == nil { // only recycle when no error - h.Pool.SPool().Put(se) - } - }() - sctx := se.(sessionctx.Context) - sql := "select HIGH_PRIORITY table_id, is_index, hist_id, count, repeats, lower_bound, upper_bound, ndv from mysql.stats_buckets where table_id >= %? and table_id < %? 
order by table_id, is_index, hist_id, bucket_id" - rc, err := util.Exec(sctx, sql, task.StartTid, task.EndTid) - if err != nil { - return errors.Trace(err) - } - defer terror.Call(rc.Close) - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats) - req := rc.NewChunk(nil) - iter := chunk.NewIterator4Chunk(req) - for { - err := rc.Next(ctx, req) - if err != nil { - return errors.Trace(err) - } - if req.NumRows() == 0 { - break - } - h.initStatsBuckets4Chunk(cache, iter) - } - return nil -} - -func (h *Handle) initStatsBucketsConcurrency(cache statstypes.StatsCache, totalMemory uint64) error { - if isFullCache(cache, totalMemory) { - return nil - } - var maxTid = maxTidRecord.tid.Load() - tid := int64(0) - ls := initstats.NewRangeWorker("bucket", func(task initstats.Task) error { - if isFullCache(cache, totalMemory) { - return nil - } - return h.initStatsBucketsByPaging(cache, task) - }, uint64(maxTid), uint64(initStatsStep)) - ls.LoadStats() - for tid <= maxTid { - ls.SendTask(initstats.Task{ - StartTid: tid, - EndTid: tid + initStatsStep, - }) - tid += initStatsStep - if isFullCache(cache, totalMemory) { - break - } - } - ls.Wait() - return nil -} - -// InitStatsLite initiates the stats cache. The function is liter and faster than InitStats. -// 1. Basic stats meta data is loaded.(count, modify count, etc.) -// 2. Column/index stats are loaded. (only histogram) -// 3. TopN, Bucket, FMSketch are not loaded. -func (h *Handle) InitStatsLite(ctx context.Context, is infoschema.InfoSchema) (err error) { - defer func() { - _, err1 := util.Exec(h.initStatsCtx, "commit") - if err == nil && err1 != nil { - err = err1 - } - }() - _, err = util.Exec(h.initStatsCtx, "begin") - if err != nil { - return err - } - failpoint.Inject("beforeInitStatsLite", func() {}) - cache, err := h.initStatsMeta(ctx, is) - if err != nil { - return errors.Trace(err) - } - statslogutil.StatsLogger().Info("complete to load the meta in the lite mode") - err = h.initStatsHistogramsLite(ctx, is, cache) - if err != nil { - cache.Close() - return errors.Trace(err) - } - statslogutil.StatsLogger().Info("complete to load the histogram in the lite mode") - h.Replace(cache) - return nil -} - -// InitStats initiates the stats cache. -// 1. Basic stats meta data is loaded.(count, modify count, etc.) -// 2. Column/index stats are loaded. 
(histogram, topn, buckets, FMSketch) -func (h *Handle) InitStats(ctx context.Context, is infoschema.InfoSchema) (err error) { - totalMemory, err := memory.MemTotal() - if err != nil { - return err - } - loadFMSketch := config.GetGlobalConfig().Performance.EnableLoadFMSketch - defer func() { - _, err1 := util.Exec(h.initStatsCtx, "commit") - if err == nil && err1 != nil { - err = err1 - } - }() - _, err = util.Exec(h.initStatsCtx, "begin") - if err != nil { - return err - } - failpoint.Inject("beforeInitStats", func() {}) - cache, err := h.initStatsMeta(ctx, is) - if err != nil { - return errors.Trace(err) - } - statslogutil.StatsLogger().Info("complete to load the meta") - if config.GetGlobalConfig().Performance.ConcurrentlyInitStats { - err = h.initStatsHistogramsConcurrency(is, cache, totalMemory) - } else { - err = h.initStatsHistograms(is, cache) - } - statslogutil.StatsLogger().Info("complete to load the histogram") - if err != nil { - return errors.Trace(err) - } - if config.GetGlobalConfig().Performance.ConcurrentlyInitStats { - err = h.initStatsTopNConcurrency(cache, totalMemory) - } else { - err = h.initStatsTopN(cache, totalMemory) - } - statslogutil.StatsLogger().Info("complete to load the topn") - if err != nil { - return err - } - if loadFMSketch { - err = h.initStatsFMSketch(cache) - if err != nil { - return err - } - statslogutil.StatsLogger().Info("complete to load the FM Sketch") - } - err = h.initStatsBuckets(cache, totalMemory) - statslogutil.StatsLogger().Info("complete to load the bucket") - if err != nil { - return errors.Trace(err) - } - h.Replace(cache) - return nil -} - -func isFullCache(cache statstypes.StatsCache, total uint64) bool { - memQuota := variable.StatsCacheMemQuota.Load() - return (uint64(cache.MemConsumed()) >= total/4) || (cache.MemConsumed() >= memQuota && memQuota != 0) -} diff --git a/pkg/statistics/handle/cache/binding__failpoint_binding__.go b/pkg/statistics/handle/cache/binding__failpoint_binding__.go deleted file mode 100644 index 01370bd22e2f3..0000000000000 --- a/pkg/statistics/handle/cache/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package cache - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/statistics/handle/cache/statscache.go b/pkg/statistics/handle/cache/statscache.go index 8a5b10edb2581..66444375ea6e9 100644 --- a/pkg/statistics/handle/cache/statscache.go +++ b/pkg/statistics/handle/cache/statscache.go @@ -223,9 +223,9 @@ func (s *StatsCacheImpl) MemConsumed() (size int64) { // Get returns the specified table's stats. func (s *StatsCacheImpl) Get(tableID int64) (*statistics.Table, bool) { - if _, _err_ := failpoint.Eval(_curpkg_("StatsCacheGetNil")); _err_ == nil { - return nil, false - } + failpoint.Inject("StatsCacheGetNil", func() { + failpoint.Return(nil, false) + }) return s.Load().Get(tableID) } diff --git a/pkg/statistics/handle/cache/statscache.go__failpoint_stash__ b/pkg/statistics/handle/cache/statscache.go__failpoint_stash__ deleted file mode 100644 index 66444375ea6e9..0000000000000 --- a/pkg/statistics/handle/cache/statscache.go__failpoint_stash__ +++ /dev/null @@ -1,287 +0,0 @@ -// Copyright 2023 PingCAP, Inc. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a Copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cache - -import ( - "context" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/infoschema" - tidbmetrics "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/statistics" - "github.com/pingcap/tidb/pkg/statistics/handle/cache/internal/metrics" - statslogutil "github.com/pingcap/tidb/pkg/statistics/handle/logutil" - handle_metrics "github.com/pingcap/tidb/pkg/statistics/handle/metrics" - "github.com/pingcap/tidb/pkg/statistics/handle/types" - "github.com/pingcap/tidb/pkg/statistics/handle/util" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/logutil" - "go.uber.org/zap" -) - -// StatsCacheImpl implements util.StatsCache. -type StatsCacheImpl struct { - atomic.Pointer[StatsCache] - - statsHandle types.StatsHandle -} - -// NewStatsCacheImpl creates a new StatsCache. -func NewStatsCacheImpl(statsHandle types.StatsHandle) (types.StatsCache, error) { - newCache, err := NewStatsCache() - if err != nil { - return nil, err - } - - result := &StatsCacheImpl{ - statsHandle: statsHandle, - } - result.Store(newCache) - - return result, nil -} - -// NewStatsCacheImplForTest creates a new StatsCache for test. -func NewStatsCacheImplForTest() (types.StatsCache, error) { - return NewStatsCacheImpl(nil) -} - -// Update reads stats meta from store and updates the stats map. -func (s *StatsCacheImpl) Update(ctx context.Context, is infoschema.InfoSchema) error { - start := time.Now() - lastVersion := s.getLastVersion() - var ( - rows []chunk.Row - err error - ) - if err := util.CallWithSCtx(s.statsHandle.SPool(), func(sctx sessionctx.Context) error { - rows, _, err = util.ExecRows( - sctx, - "SELECT version, table_id, modify_count, count, snapshot from mysql.stats_meta where version > %? order by version", - lastVersion, - ) - return err - }); err != nil { - return errors.Trace(err) - } - - tables := make([]*statistics.Table, 0, len(rows)) - deletedTableIDs := make([]int64, 0, len(rows)) - - for _, row := range rows { - version := row.GetUint64(0) - physicalID := row.GetInt64(1) - modifyCount := row.GetInt64(2) - count := row.GetInt64(3) - snapshot := row.GetUint64(4) - - // Detect the context cancel signal, since it may take a long time for the loop. - // TODO: add context to TableInfoByID and remove this code block? - if ctx.Err() != nil { - return ctx.Err() - } - - table, ok := s.statsHandle.TableInfoByID(is, physicalID) - if !ok { - logutil.BgLogger().Debug( - "unknown physical ID in stats meta table, maybe it has been dropped", - zap.Int64("ID", physicalID), - ) - deletedTableIDs = append(deletedTableIDs, physicalID) - continue - } - tableInfo := table.Meta() - // If the table is not updated, we can skip it. 
- if oldTbl, ok := s.Get(physicalID); ok && - oldTbl.Version >= version && - tableInfo.UpdateTS == oldTbl.TblInfoUpdateTS { - continue - } - tbl, err := s.statsHandle.TableStatsFromStorage( - tableInfo, - physicalID, - false, - 0, - ) - // Error is not nil may mean that there are some ddl changes on this table, we will not update it. - if err != nil { - statslogutil.StatsLogger().Error( - "error occurred when read table stats", - zap.String("table", tableInfo.Name.O), - zap.Error(err), - ) - continue - } - if tbl == nil { - deletedTableIDs = append(deletedTableIDs, physicalID) - continue - } - tbl.Version = version - tbl.RealtimeCount = count - tbl.ModifyCount = modifyCount - tbl.TblInfoUpdateTS = tableInfo.UpdateTS - // It only occurs in the following situations: - // 1. The table has already been analyzed, - // but because the predicate columns feature is turned on, and it doesn't have any columns or indexes analyzed, - // it only analyzes _row_id and refreshes stats_meta, in which case the snapshot is not zero. - // 2. LastAnalyzeVersion is 0 because it has never been loaded. - // In this case, we can initialize LastAnalyzeVersion to the snapshot, - // otherwise auto-analyze will assume that the table has never been analyzed and try to analyze it again. - if tbl.LastAnalyzeVersion == 0 && snapshot != 0 { - tbl.LastAnalyzeVersion = snapshot - } - tables = append(tables, tbl) - } - - s.UpdateStatsCache(tables, deletedTableIDs) - dur := time.Since(start) - tidbmetrics.StatsDeltaLoadHistogram.Observe(dur.Seconds()) - return nil -} - -func (s *StatsCacheImpl) getLastVersion() uint64 { - // Get the greatest version of the stats meta table. - lastVersion := s.MaxTableStatsVersion() - // We need this because for two tables, the smaller version may write later than the one with larger version. - // Consider the case that there are two tables A and B, their version and commit time is (A0, A1) and (B0, B1), - // and A0 < B0 < B1 < A1. We will first read the stats of B, and update the lastVersion to B0, but we cannot read - // the table stats of A0 if we read stats that greater than lastVersion which is B0. - // We can read the stats if the diff between commit time and version is less than three lease. - offset := util.DurationToTS(3 * s.statsHandle.Lease()) - if s.MaxTableStatsVersion() >= offset { - lastVersion = lastVersion - offset - } else { - lastVersion = 0 - } - - return lastVersion -} - -// Replace replaces this cache. -func (s *StatsCacheImpl) Replace(cache types.StatsCache) { - x := cache.(*StatsCacheImpl) - s.replace(x.Load()) -} - -// replace replaces the cache with the new cache. -func (s *StatsCacheImpl) replace(newCache *StatsCache) { - old := s.Swap(newCache) - if old != nil { - old.Close() - } - metrics.CostGauge.Set(float64(newCache.Cost())) -} - -// UpdateStatsCache updates the cache with the new cache. -func (s *StatsCacheImpl) UpdateStatsCache(tables []*statistics.Table, deletedIDs []int64) { - if enableQuota := config.GetGlobalConfig().Performance.EnableStatsCacheMemQuota; enableQuota { - s.Load().Update(tables, deletedIDs) - } else { - // TODO: remove this branch because we will always enable quota. - newCache := s.Load().CopyAndUpdate(tables, deletedIDs) - s.replace(newCache) - } -} - -// Close closes this cache. -func (s *StatsCacheImpl) Close() { - s.Load().Close() -} - -// Clear clears this cache. -// Create a empty cache and replace the old one. 
-func (s *StatsCacheImpl) Clear() { - cache, err := NewStatsCache() - if err != nil { - logutil.BgLogger().Warn("create stats cache failed", zap.Error(err)) - return - } - s.replace(cache) -} - -// MemConsumed returns its memory usage. -func (s *StatsCacheImpl) MemConsumed() (size int64) { - return s.Load().Cost() -} - -// Get returns the specified table's stats. -func (s *StatsCacheImpl) Get(tableID int64) (*statistics.Table, bool) { - failpoint.Inject("StatsCacheGetNil", func() { - failpoint.Return(nil, false) - }) - return s.Load().Get(tableID) -} - -// Put puts this table stats into the cache. -func (s *StatsCacheImpl) Put(id int64, t *statistics.Table) { - s.Load().put(id, t) -} - -// MaxTableStatsVersion returns the version of the current cache, which is defined as -// the max table stats version the cache has in its lifecycle. -func (s *StatsCacheImpl) MaxTableStatsVersion() uint64 { - return s.Load().Version() -} - -// Values returns all values in this cache. -func (s *StatsCacheImpl) Values() []*statistics.Table { - return s.Load().Values() -} - -// Len returns the length of this cache. -func (s *StatsCacheImpl) Len() int { - return s.Load().Len() -} - -// SetStatsCacheCapacity sets the cache's capacity. -func (s *StatsCacheImpl) SetStatsCacheCapacity(c int64) { - s.Load().SetCapacity(c) -} - -// UpdateStatsHealthyMetrics updates stats healthy distribution metrics according to stats cache. -func (s *StatsCacheImpl) UpdateStatsHealthyMetrics() { - distribution := make([]int64, 5) - uneligibleAnalyze := 0 - for _, tbl := range s.Values() { - distribution[4]++ // total table count - isEligibleForAnalysis := tbl.IsEligibleForAnalysis() - if !isEligibleForAnalysis { - uneligibleAnalyze++ - continue - } - healthy, ok := tbl.GetStatsHealthy() - if !ok { - continue - } - if healthy < 50 { - distribution[0]++ - } else if healthy < 80 { - distribution[1]++ - } else if healthy < 100 { - distribution[2]++ - } else { - distribution[3]++ - } - } - for i, val := range distribution { - handle_metrics.StatsHealthyGauges[i].Set(float64(val)) - } - handle_metrics.StatsHealthyGauges[5].Set(float64(uneligibleAnalyze)) -} diff --git a/pkg/statistics/handle/globalstats/binding__failpoint_binding__.go b/pkg/statistics/handle/globalstats/binding__failpoint_binding__.go deleted file mode 100644 index d91f90cd70082..0000000000000 --- a/pkg/statistics/handle/globalstats/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package globalstats - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/statistics/handle/globalstats/global_stats_async.go b/pkg/statistics/handle/globalstats/global_stats_async.go index 8b2dce81f247a..2ef30c6cdba24 100644 --- a/pkg/statistics/handle/globalstats/global_stats_async.go +++ b/pkg/statistics/handle/globalstats/global_stats_async.go @@ -242,12 +242,12 @@ func (a *AsyncMergePartitionStats2GlobalStats) ioWorker(sctx sessionctx.Context, return err } close(a.cmsketch) - if val, _err_ := failpoint.Eval(_curpkg_("PanicSameTime")); _err_ == nil { + failpoint.Inject("PanicSameTime", func(val failpoint.Value) { if val, _ := val.(bool); val { time.Sleep(1 * time.Second) panic("test for PanicSameTime") } - } + }) err = a.loadHistogramAndTopN(sctx, a.globalTableInfo, isIndex) if err 
!= nil { close(a.ioWorkerExitWhenErrChan) @@ -285,12 +285,12 @@ func (a *AsyncMergePartitionStats2GlobalStats) cpuWorker(stmtCtx *stmtctx.Statem statslogutil.StatsLogger().Warn("dealCMSketch failed", zap.Error(err)) return err } - if val, _err_ := failpoint.Eval(_curpkg_("PanicSameTime")); _err_ == nil { + failpoint.Inject("PanicSameTime", func(val failpoint.Value) { if val, _ := val.(bool); val { time.Sleep(1 * time.Second) panic("test for PanicSameTime") } - } + }) err = a.dealHistogramAndTopN(stmtCtx, sctx, opts, isIndex, tz, analyzeVersion) if err != nil { statslogutil.StatsLogger().Warn("dealHistogramAndTopN failed", zap.Error(err)) @@ -370,7 +370,7 @@ func (a *AsyncMergePartitionStats2GlobalStats) loadFmsketch(sctx sessionctx.Cont } func (a *AsyncMergePartitionStats2GlobalStats) loadCMsketch(sctx sessionctx.Context, isIndex bool) error { - failpoint.Eval(_curpkg_("PanicInIOWorker")) + failpoint.Inject("PanicInIOWorker", nil) for i := 0; i < a.globalStats.Num; i++ { for _, partitionID := range a.partitionIDs { _, ok := a.skipPartition[skipItem{ @@ -401,12 +401,12 @@ func (a *AsyncMergePartitionStats2GlobalStats) loadCMsketch(sctx sessionctx.Cont } func (a *AsyncMergePartitionStats2GlobalStats) loadHistogramAndTopN(sctx sessionctx.Context, tableInfo *model.TableInfo, isIndex bool) error { - if val, _err_ := failpoint.Eval(_curpkg_("ErrorSameTime")); _err_ == nil { + failpoint.Inject("ErrorSameTime", func(val failpoint.Value) { if val, _ := val.(bool); val { time.Sleep(1 * time.Second) - return errors.New("ErrorSameTime returned error") + failpoint.Return(errors.New("ErrorSameTime returned error")) } - } + }) for i := 0; i < a.globalStats.Num; i++ { hists := make([]*statistics.Histogram, 0, a.partitionNum) topn := make([]*statistics.TopN, 0, a.partitionNum) @@ -442,7 +442,7 @@ func (a *AsyncMergePartitionStats2GlobalStats) loadHistogramAndTopN(sctx session } func (a *AsyncMergePartitionStats2GlobalStats) dealFMSketch() { - failpoint.Eval(_curpkg_("PanicInCPUWorker")) + failpoint.Inject("PanicInCPUWorker", nil) for { select { case fms, ok := <-a.fmsketch: @@ -461,11 +461,11 @@ func (a *AsyncMergePartitionStats2GlobalStats) dealFMSketch() { } func (a *AsyncMergePartitionStats2GlobalStats) dealCMSketch() error { - if val, _err_ := failpoint.Eval(_curpkg_("dealCMSketchErr")); _err_ == nil { + failpoint.Inject("dealCMSketchErr", func(val failpoint.Value) { if val, _ := val.(bool); val { - return errors.New("dealCMSketch returned error") + failpoint.Return(errors.New("dealCMSketch returned error")) } - } + }) for { select { case cms, ok := <-a.cmsketch: @@ -487,17 +487,17 @@ func (a *AsyncMergePartitionStats2GlobalStats) dealCMSketch() error { } func (a *AsyncMergePartitionStats2GlobalStats) dealHistogramAndTopN(stmtCtx *stmtctx.StatementContext, sctx sessionctx.Context, opts map[ast.AnalyzeOptionType]uint64, isIndex bool, tz *time.Location, analyzeVersion int) (err error) { - if val, _err_ := failpoint.Eval(_curpkg_("dealHistogramAndTopNErr")); _err_ == nil { + failpoint.Inject("dealHistogramAndTopNErr", func(val failpoint.Value) { if val, _ := val.(bool); val { - return errors.New("dealHistogramAndTopNErr returned error") + failpoint.Return(errors.New("dealHistogramAndTopNErr returned error")) } - } - if val, _err_ := failpoint.Eval(_curpkg_("ErrorSameTime")); _err_ == nil { + }) + failpoint.Inject("ErrorSameTime", func(val failpoint.Value) { if val, _ := val.(bool); val { time.Sleep(1 * time.Second) - return errors.New("ErrorSameTime returned error") + 
failpoint.Return(errors.New("ErrorSameTime returned error")) } - } + }) for { select { case item, ok := <-a.histogramAndTopn: diff --git a/pkg/statistics/handle/globalstats/global_stats_async.go__failpoint_stash__ b/pkg/statistics/handle/globalstats/global_stats_async.go__failpoint_stash__ deleted file mode 100644 index 2ef30c6cdba24..0000000000000 --- a/pkg/statistics/handle/globalstats/global_stats_async.go__failpoint_stash__ +++ /dev/null @@ -1,542 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package globalstats - -import ( - "context" - stderrors "errors" - "fmt" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/infoschema" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - "github.com/pingcap/tidb/pkg/statistics" - statslogutil "github.com/pingcap/tidb/pkg/statistics/handle/logutil" - "github.com/pingcap/tidb/pkg/statistics/handle/storage" - statstypes "github.com/pingcap/tidb/pkg/statistics/handle/types" - "github.com/pingcap/tidb/pkg/statistics/handle/util" - "github.com/pingcap/tidb/pkg/types" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" -) - -type mergeItem[T any] struct { - item T - idx int -} - -type skipItem struct { - histID int64 - partitionID int64 -} - -// toSQLIndex is used to convert bool to int64. -func toSQLIndex(isIndex bool) int { - var index = int(0) - if isIndex { - index = 1 - } - return index -} - -// AsyncMergePartitionStats2GlobalStats is used to merge partition stats to global stats. -// it divides the merge task into two parts. -// - IOWorker: load stats from storage. it will load fmsketch, cmsketch, histogram and topn. and send them to cpuWorker. -// - CPUWorker: merge the stats from IOWorker and generate global stats. -// -// ┌────────────────────────┐ ┌───────────────────────┐ -// │ │ │ │ -// │ │ │ │ -// │ │ │ │ -// │ IOWorker │ │ CPUWorker │ -// │ │ ────► │ │ -// │ │ │ │ -// │ │ │ │ -// │ │ │ │ -// └────────────────────────┘ └───────────────────────┘ -type AsyncMergePartitionStats2GlobalStats struct { - is infoschema.InfoSchema - statsHandle statstypes.StatsHandle - globalStats *GlobalStats - cmsketch chan mergeItem[*statistics.CMSketch] - fmsketch chan mergeItem[*statistics.FMSketch] - histogramAndTopn chan mergeItem[*StatsWrapper] - allPartitionStats map[int64]*statistics.Table - PartitionDefinition map[int64]model.PartitionDefinition - tableInfo map[int64]*model.TableInfo - // key is partition id and histID - skipPartition map[skipItem]struct{} - // ioWorker meet error, it will close this channel to notify cpuWorker. - ioWorkerExitWhenErrChan chan struct{} - // cpuWorker exit, it will close this channel to notify ioWorker. 
- cpuWorkerExitChan chan struct{} - globalTableInfo *model.TableInfo - histIDs []int64 - globalStatsNDV []int64 - partitionIDs []int64 - partitionNum int - skipMissingPartitionStats bool -} - -// NewAsyncMergePartitionStats2GlobalStats creates a new AsyncMergePartitionStats2GlobalStats. -func NewAsyncMergePartitionStats2GlobalStats( - statsHandle statstypes.StatsHandle, - globalTableInfo *model.TableInfo, - histIDs []int64, - is infoschema.InfoSchema) (*AsyncMergePartitionStats2GlobalStats, error) { - partitionNum := len(globalTableInfo.Partition.Definitions) - return &AsyncMergePartitionStats2GlobalStats{ - statsHandle: statsHandle, - cmsketch: make(chan mergeItem[*statistics.CMSketch], 5), - fmsketch: make(chan mergeItem[*statistics.FMSketch], 5), - histogramAndTopn: make(chan mergeItem[*StatsWrapper]), - PartitionDefinition: make(map[int64]model.PartitionDefinition), - tableInfo: make(map[int64]*model.TableInfo), - partitionIDs: make([]int64, 0, partitionNum), - ioWorkerExitWhenErrChan: make(chan struct{}), - cpuWorkerExitChan: make(chan struct{}), - skipPartition: make(map[skipItem]struct{}), - allPartitionStats: make(map[int64]*statistics.Table), - globalTableInfo: globalTableInfo, - histIDs: histIDs, - is: is, - partitionNum: partitionNum, - }, nil -} - -func (a *AsyncMergePartitionStats2GlobalStats) prepare(sctx sessionctx.Context, isIndex bool) (err error) { - if len(a.histIDs) == 0 { - for _, col := range a.globalTableInfo.Columns { - // The virtual generated column stats can not be merged to the global stats. - if col.IsVirtualGenerated() { - continue - } - a.histIDs = append(a.histIDs, col.ID) - } - } - a.globalStats = newGlobalStats(len(a.histIDs)) - a.globalStats.Num = len(a.histIDs) - a.globalStatsNDV = make([]int64, 0, a.globalStats.Num) - // get all partition stats - for _, def := range a.globalTableInfo.Partition.Definitions { - partitionID := def.ID - a.partitionIDs = append(a.partitionIDs, partitionID) - a.PartitionDefinition[partitionID] = def - partitionTable, ok := a.statsHandle.TableInfoByID(a.is, partitionID) - if !ok { - return errors.Errorf("unknown physical ID %d in stats meta table, maybe it has been dropped", partitionID) - } - tableInfo := partitionTable.Meta() - a.tableInfo[partitionID] = tableInfo - realtimeCount, modifyCount, isNull, err := storage.StatsMetaCountAndModifyCount(sctx, partitionID) - if err != nil { - return err - } - if !isNull { - // In a partition, we will only update globalStats.Count once. 
- a.globalStats.Count += realtimeCount - a.globalStats.ModifyCount += modifyCount - } - err1 := skipPartition(sctx, partitionID, isIndex) - if err1 != nil { - // no idx so idx = 0 - err := a.dealWithSkipPartition(partitionID, isIndex, 0, err1) - if err != nil { - return err - } - if types.ErrPartitionStatsMissing.Equal(err1) { - continue - } - } - for idx, hist := range a.histIDs { - err1 := skipColumnPartition(sctx, partitionID, isIndex, hist) - if err1 != nil { - err := a.dealWithSkipPartition(partitionID, isIndex, idx, err1) - if err != nil { - return err - } - if types.ErrPartitionStatsMissing.Equal(err1) { - break - } - } - } - } - return nil -} - -func (a *AsyncMergePartitionStats2GlobalStats) dealWithSkipPartition(partitionID int64, isIndex bool, idx int, err error) error { - switch { - case types.ErrPartitionStatsMissing.Equal(err): - return a.dealErrPartitionStatsMissing(partitionID) - case types.ErrPartitionColumnStatsMissing.Equal(err): - return a.dealErrPartitionColumnStatsMissing(isIndex, partitionID, idx) - default: - return err - } -} - -func (a *AsyncMergePartitionStats2GlobalStats) dealErrPartitionStatsMissing(partitionID int64) error { - missingPart := fmt.Sprintf("partition `%s`", a.PartitionDefinition[partitionID].Name.L) - a.globalStats.MissingPartitionStats = append(a.globalStats.MissingPartitionStats, missingPart) - for _, histID := range a.histIDs { - a.skipPartition[skipItem{ - histID: histID, - partitionID: partitionID, - }] = struct{}{} - } - return nil -} - -func (a *AsyncMergePartitionStats2GlobalStats) dealErrPartitionColumnStatsMissing(isIndex bool, partitionID int64, idx int) error { - var missingPart string - if isIndex { - missingPart = fmt.Sprintf("partition `%s` index `%s`", a.PartitionDefinition[partitionID].Name.L, a.tableInfo[partitionID].FindIndexNameByID(a.histIDs[idx])) - } else { - missingPart = fmt.Sprintf("partition `%s` column `%s`", a.PartitionDefinition[partitionID].Name.L, a.tableInfo[partitionID].FindColumnNameByID(a.histIDs[idx])) - } - if !a.skipMissingPartitionStats { - return types.ErrPartitionColumnStatsMissing.GenWithStackByArgs(fmt.Sprintf("table `%s` %s", a.tableInfo[partitionID].Name.L, missingPart)) - } - a.globalStats.MissingPartitionStats = append(a.globalStats.MissingPartitionStats, missingPart) - a.skipPartition[skipItem{ - histID: a.histIDs[idx], - partitionID: partitionID, - }] = struct{}{} - return nil -} - -func (a *AsyncMergePartitionStats2GlobalStats) ioWorker(sctx sessionctx.Context, isIndex bool) (err error) { - defer func() { - if r := recover(); r != nil { - statslogutil.StatsLogger().Warn("ioWorker panic", zap.Stack("stack"), zap.Any("error", r)) - close(a.ioWorkerExitWhenErrChan) - err = errors.New(fmt.Sprint(r)) - } - }() - err = a.loadFmsketch(sctx, isIndex) - if err != nil { - close(a.ioWorkerExitWhenErrChan) - return err - } - close(a.fmsketch) - err = a.loadCMsketch(sctx, isIndex) - if err != nil { - close(a.ioWorkerExitWhenErrChan) - return err - } - close(a.cmsketch) - failpoint.Inject("PanicSameTime", func(val failpoint.Value) { - if val, _ := val.(bool); val { - time.Sleep(1 * time.Second) - panic("test for PanicSameTime") - } - }) - err = a.loadHistogramAndTopN(sctx, a.globalTableInfo, isIndex) - if err != nil { - close(a.ioWorkerExitWhenErrChan) - return err - } - close(a.histogramAndTopn) - return nil -} - -func (a *AsyncMergePartitionStats2GlobalStats) cpuWorker(stmtCtx *stmtctx.StatementContext, sctx sessionctx.Context, opts map[ast.AnalyzeOptionType]uint64, isIndex bool, tz *time.Location, 
analyzeVersion int) (err error) { - defer func() { - if r := recover(); r != nil { - statslogutil.StatsLogger().Warn("cpuWorker panic", zap.Stack("stack"), zap.Any("error", r)) - err = errors.New(fmt.Sprint(r)) - } - close(a.cpuWorkerExitChan) - }() - a.dealFMSketch() - select { - case <-a.ioWorkerExitWhenErrChan: - return nil - default: - for i := 0; i < a.globalStats.Num; i++ { - // Update the global NDV. - globalStatsNDV := a.globalStats.Fms[i].NDV() - if globalStatsNDV > a.globalStats.Count { - globalStatsNDV = a.globalStats.Count - } - a.globalStatsNDV = append(a.globalStatsNDV, globalStatsNDV) - a.globalStats.Fms[i].DestroyAndPutToPool() - } - } - err = a.dealCMSketch() - if err != nil { - statslogutil.StatsLogger().Warn("dealCMSketch failed", zap.Error(err)) - return err - } - failpoint.Inject("PanicSameTime", func(val failpoint.Value) { - if val, _ := val.(bool); val { - time.Sleep(1 * time.Second) - panic("test for PanicSameTime") - } - }) - err = a.dealHistogramAndTopN(stmtCtx, sctx, opts, isIndex, tz, analyzeVersion) - if err != nil { - statslogutil.StatsLogger().Warn("dealHistogramAndTopN failed", zap.Error(err)) - return err - } - return nil -} - -// Result returns the global stats. -func (a *AsyncMergePartitionStats2GlobalStats) Result() *GlobalStats { - return a.globalStats -} - -// MergePartitionStats2GlobalStats merges partition stats to global stats. -func (a *AsyncMergePartitionStats2GlobalStats) MergePartitionStats2GlobalStats( - sctx sessionctx.Context, - opts map[ast.AnalyzeOptionType]uint64, - isIndex bool, -) error { - a.skipMissingPartitionStats = sctx.GetSessionVars().SkipMissingPartitionStats - tz := sctx.GetSessionVars().StmtCtx.TimeZone() - analyzeVersion := sctx.GetSessionVars().AnalyzeVersion - stmtCtx := sctx.GetSessionVars().StmtCtx - return util.CallWithSCtx(a.statsHandle.SPool(), - func(sctx sessionctx.Context) error { - err := a.prepare(sctx, isIndex) - if err != nil { - return err - } - ctx := context.Background() - metawg, _ := errgroup.WithContext(ctx) - mergeWg, _ := errgroup.WithContext(ctx) - metawg.Go(func() error { - return a.ioWorker(sctx, isIndex) - }) - mergeWg.Go(func() error { - return a.cpuWorker(stmtCtx, sctx, opts, isIndex, tz, analyzeVersion) - }) - err = metawg.Wait() - if err != nil { - if err1 := mergeWg.Wait(); err1 != nil { - err = stderrors.Join(err, err1) - } - return err - } - return mergeWg.Wait() - }, - ) -} - -func (a *AsyncMergePartitionStats2GlobalStats) loadFmsketch(sctx sessionctx.Context, isIndex bool) error { - for i := 0; i < a.globalStats.Num; i++ { - // load fmsketch from tikv - for _, partitionID := range a.partitionIDs { - _, ok := a.skipPartition[skipItem{ - histID: a.histIDs[i], - partitionID: partitionID, - }] - if ok { - continue - } - fmsketch, err := storage.FMSketchFromStorage(sctx, partitionID, int64(toSQLIndex(isIndex)), a.histIDs[i]) - if err != nil { - return err - } - select { - case a.fmsketch <- mergeItem[*statistics.FMSketch]{ - fmsketch, i, - }: - case <-a.cpuWorkerExitChan: - statslogutil.StatsLogger().Warn("ioWorker detects CPUWorker has exited") - return nil - } - } - } - return nil -} - -func (a *AsyncMergePartitionStats2GlobalStats) loadCMsketch(sctx sessionctx.Context, isIndex bool) error { - failpoint.Inject("PanicInIOWorker", nil) - for i := 0; i < a.globalStats.Num; i++ { - for _, partitionID := range a.partitionIDs { - _, ok := a.skipPartition[skipItem{ - histID: a.histIDs[i], - partitionID: partitionID, - }] - if ok { - continue - } - cmsketch, err := storage.CMSketchFromStorage(sctx, 
partitionID, toSQLIndex(isIndex), a.histIDs[i]) - if err != nil { - return err - } - a.cmsketch <- mergeItem[*statistics.CMSketch]{ - cmsketch, i, - } - select { - case a.cmsketch <- mergeItem[*statistics.CMSketch]{ - cmsketch, i, - }: - case <-a.cpuWorkerExitChan: - statslogutil.StatsLogger().Warn("ioWorker detects CPUWorker has exited") - return nil - } - } - } - return nil -} - -func (a *AsyncMergePartitionStats2GlobalStats) loadHistogramAndTopN(sctx sessionctx.Context, tableInfo *model.TableInfo, isIndex bool) error { - failpoint.Inject("ErrorSameTime", func(val failpoint.Value) { - if val, _ := val.(bool); val { - time.Sleep(1 * time.Second) - failpoint.Return(errors.New("ErrorSameTime returned error")) - } - }) - for i := 0; i < a.globalStats.Num; i++ { - hists := make([]*statistics.Histogram, 0, a.partitionNum) - topn := make([]*statistics.TopN, 0, a.partitionNum) - for _, partitionID := range a.partitionIDs { - _, ok := a.skipPartition[skipItem{ - histID: a.histIDs[i], - partitionID: partitionID, - }] - if ok { - continue - } - h, err := storage.LoadHistogram(sctx, partitionID, toSQLIndex(isIndex), a.histIDs[i], tableInfo) - if err != nil { - return err - } - t, err := storage.TopNFromStorage(sctx, partitionID, toSQLIndex(isIndex), a.histIDs[i]) - if err != nil { - return err - } - hists = append(hists, h) - topn = append(topn, t) - } - select { - case a.histogramAndTopn <- mergeItem[*StatsWrapper]{ - NewStatsWrapper(hists, topn), i, - }: - case <-a.cpuWorkerExitChan: - statslogutil.StatsLogger().Warn("ioWorker detects CPUWorker has exited") - return nil - } - } - return nil -} - -func (a *AsyncMergePartitionStats2GlobalStats) dealFMSketch() { - failpoint.Inject("PanicInCPUWorker", nil) - for { - select { - case fms, ok := <-a.fmsketch: - if !ok { - return - } - if a.globalStats.Fms[fms.idx] == nil { - a.globalStats.Fms[fms.idx] = fms.item - } else { - a.globalStats.Fms[fms.idx].MergeFMSketch(fms.item) - } - case <-a.ioWorkerExitWhenErrChan: - return - } - } -} - -func (a *AsyncMergePartitionStats2GlobalStats) dealCMSketch() error { - failpoint.Inject("dealCMSketchErr", func(val failpoint.Value) { - if val, _ := val.(bool); val { - failpoint.Return(errors.New("dealCMSketch returned error")) - } - }) - for { - select { - case cms, ok := <-a.cmsketch: - if !ok { - return nil - } - if a.globalStats.Cms[cms.idx] == nil { - a.globalStats.Cms[cms.idx] = cms.item - } else { - err := a.globalStats.Cms[cms.idx].MergeCMSketch(cms.item) - if err != nil { - return err - } - } - case <-a.ioWorkerExitWhenErrChan: - return nil - } - } -} - -func (a *AsyncMergePartitionStats2GlobalStats) dealHistogramAndTopN(stmtCtx *stmtctx.StatementContext, sctx sessionctx.Context, opts map[ast.AnalyzeOptionType]uint64, isIndex bool, tz *time.Location, analyzeVersion int) (err error) { - failpoint.Inject("dealHistogramAndTopNErr", func(val failpoint.Value) { - if val, _ := val.(bool); val { - failpoint.Return(errors.New("dealHistogramAndTopNErr returned error")) - } - }) - failpoint.Inject("ErrorSameTime", func(val failpoint.Value) { - if val, _ := val.(bool); val { - time.Sleep(1 * time.Second) - failpoint.Return(errors.New("ErrorSameTime returned error")) - } - }) - for { - select { - case item, ok := <-a.histogramAndTopn: - if !ok { - return nil - } - var err error - var poppedTopN []statistics.TopNMeta - var allhg []*statistics.Histogram - wrapper := item.item - a.globalStats.TopN[item.idx], poppedTopN, allhg, err = mergeGlobalStatsTopN(a.statsHandle.GPool(), sctx, wrapper, - tz, analyzeVersion, 
uint32(opts[ast.AnalyzeOptNumTopN]), isIndex) - if err != nil { - return err - } - - // Merge histogram. - globalHg := &(a.globalStats.Hg[item.idx]) - *globalHg, err = statistics.MergePartitionHist2GlobalHist(stmtCtx, allhg, poppedTopN, - int64(opts[ast.AnalyzeOptNumBuckets]), isIndex) - if err != nil { - return err - } - - // NOTICE: after merging bucket NDVs have the trend to be underestimated, so for safe we don't use them. - for j := range (*globalHg).Buckets { - (*globalHg).Buckets[j].NDV = 0 - } - (*globalHg).NDV = a.globalStatsNDV[item.idx] - case <-a.ioWorkerExitWhenErrChan: - return nil - } - } -} - -func skipPartition(sctx sessionctx.Context, partitionID int64, isIndex bool) error { - return storage.CheckSkipPartition(sctx, partitionID, toSQLIndex(isIndex)) -} - -func skipColumnPartition(sctx sessionctx.Context, partitionID int64, isIndex bool, histsID int64) error { - return storage.CheckSkipColumnPartiion(sctx, partitionID, toSQLIndex(isIndex), histsID) -} diff --git a/pkg/statistics/handle/storage/binding__failpoint_binding__.go b/pkg/statistics/handle/storage/binding__failpoint_binding__.go deleted file mode 100644 index a1a747a15d57f..0000000000000 --- a/pkg/statistics/handle/storage/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package storage - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/statistics/handle/storage/read.go b/pkg/statistics/handle/storage/read.go index c5aa829062d10..410f66b53c63a 100644 --- a/pkg/statistics/handle/storage/read.go +++ b/pkg/statistics/handle/storage/read.go @@ -221,9 +221,9 @@ func CheckSkipColumnPartiion(sctx sessionctx.Context, tblID int64, isIndex int, // ExtendedStatsFromStorage reads extended stats from storage. func ExtendedStatsFromStorage(sctx sessionctx.Context, table *statistics.Table, tableID int64, loadAll bool) (*statistics.Table, error) { - if _, _err_ := failpoint.Eval(_curpkg_("injectExtStatsLoadErr")); _err_ == nil { - return nil, errors.New("gofail extendedStatsFromStorage error") - } + failpoint.Inject("injectExtStatsLoadErr", func() { + failpoint.Return(nil, errors.New("gofail extendedStatsFromStorage error")) + }) lastVersion := uint64(0) if table.ExtendedStats != nil && !loadAll { lastVersion = table.ExtendedStats.LastUpdateVersion diff --git a/pkg/statistics/handle/storage/read.go__failpoint_stash__ b/pkg/statistics/handle/storage/read.go__failpoint_stash__ deleted file mode 100644 index 410f66b53c63a..0000000000000 --- a/pkg/statistics/handle/storage/read.go__failpoint_stash__ +++ /dev/null @@ -1,759 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
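// Aside: the hunks above replace the expanded failpoint.Eval(_curpkg_(...)) calls
// with the original failpoint.Inject markers and delete the generated
// binding__failpoint_binding__.go / *__failpoint_stash__ files. A minimal,
// hedged sketch of the marker pattern itself, with hypothetical names
// (Enable/Disable are assumed to behave as in the pingcap/failpoint package):
//
//	package example
//
//	import (
//		"errors"
//
//		"github.com/pingcap/failpoint"
//	)
//
//	func loadStats() error {
//		failpoint.Inject("injectLoadErr", func() {
//			failpoint.Return(errors.New("injected load error"))
//		})
//		return nil // normal path when the failpoint is not enabled
//	}
//
// In a test, something like failpoint.Enable("<pkgpath>/injectLoadErr",
// "return(true)") followed by the matching failpoint.Disable call toggles the
// injected branch on and off.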
- -package storage - -import ( - "encoding/json" - "strconv" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/statistics" - "github.com/pingcap/tidb/pkg/statistics/asyncload" - statslogutil "github.com/pingcap/tidb/pkg/statistics/handle/logutil" - statstypes "github.com/pingcap/tidb/pkg/statistics/handle/types" - "github.com/pingcap/tidb/pkg/statistics/handle/util" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "github.com/pingcap/tidb/pkg/util/sqlexec" - "go.uber.org/zap" -) - -// StatsMetaCountAndModifyCount reads count and modify_count for the given table from mysql.stats_meta. -func StatsMetaCountAndModifyCount(sctx sessionctx.Context, tableID int64) (count, modifyCount int64, isNull bool, err error) { - rows, _, err := util.ExecRows(sctx, "select count, modify_count from mysql.stats_meta where table_id = %?", tableID) - if err != nil { - return 0, 0, false, err - } - if len(rows) == 0 { - return 0, 0, true, nil - } - count = int64(rows[0].GetUint64(0)) - modifyCount = rows[0].GetInt64(1) - return count, modifyCount, false, nil -} - -// HistMetaFromStorageWithHighPriority reads the meta info of the histogram from the storage. -func HistMetaFromStorageWithHighPriority(sctx sessionctx.Context, item *model.TableItemID, possibleColInfo *model.ColumnInfo) (*statistics.Histogram, *types.Datum, int64, int64, error) { - isIndex := 0 - var tp *types.FieldType - if item.IsIndex { - isIndex = 1 - tp = types.NewFieldType(mysql.TypeBlob) - } else { - tp = &possibleColInfo.FieldType - } - rows, _, err := util.ExecRows(sctx, - "select high_priority distinct_count, version, null_count, tot_col_size, stats_ver, correlation, flag, last_analyze_pos from mysql.stats_histograms where table_id = %? and hist_id = %? and is_index = %?", - item.TableID, - item.ID, - isIndex, - ) - if err != nil { - return nil, nil, 0, 0, err - } - if len(rows) == 0 { - return nil, nil, 0, 0, nil - } - hist := statistics.NewHistogram(item.ID, rows[0].GetInt64(0), rows[0].GetInt64(2), rows[0].GetUint64(1), tp, chunk.InitialCapacity, rows[0].GetInt64(3)) - hist.Correlation = rows[0].GetFloat64(5) - lastPos := rows[0].GetDatum(7, types.NewFieldType(mysql.TypeBlob)) - return hist, &lastPos, rows[0].GetInt64(4), rows[0].GetInt64(6), nil -} - -// HistogramFromStorageWithPriority wraps the HistogramFromStorage with the given kv.Priority. -// Sync load and async load will use high priority to get data. -func HistogramFromStorageWithPriority( - sctx sessionctx.Context, - tableID int64, - colID int64, - tp *types.FieldType, - distinct int64, - isIndex int, - ver uint64, - nullCount int64, - totColSize int64, - corr float64, - priority int, -) (*statistics.Histogram, error) { - selectPrefix := "select " - switch priority { - case kv.PriorityHigh: - selectPrefix += "high_priority " - case kv.PriorityLow: - selectPrefix += "low_priority " - } - rows, fields, err := util.ExecRows(sctx, selectPrefix+"count, repeats, lower_bound, upper_bound, ndv from mysql.stats_buckets where table_id = %? and is_index = %? and hist_id = %? 
order by bucket_id", tableID, isIndex, colID) - if err != nil { - return nil, errors.Trace(err) - } - bucketSize := len(rows) - hg := statistics.NewHistogram(colID, distinct, nullCount, ver, tp, bucketSize, totColSize) - hg.Correlation = corr - totalCount := int64(0) - for i := 0; i < bucketSize; i++ { - count := rows[i].GetInt64(0) - repeats := rows[i].GetInt64(1) - var upperBound, lowerBound types.Datum - if isIndex == 1 { - lowerBound = rows[i].GetDatum(2, &fields[2].Column.FieldType) - upperBound = rows[i].GetDatum(3, &fields[3].Column.FieldType) - } else { - d := rows[i].GetDatum(2, &fields[2].Column.FieldType) - // For new collation data, when storing the bounds of the histogram, we store the collate key instead of the - // original value. - // But there's additional conversion logic for new collation data, and the collate key might be longer than - // the FieldType.flen. - // If we use the original FieldType here, there might be errors like "Invalid utf8mb4 character string" - // or "Data too long". - // So we change it to TypeBlob to bypass those logics here. - if tp.EvalType() == types.ETString && tp.GetType() != mysql.TypeEnum && tp.GetType() != mysql.TypeSet { - tp = types.NewFieldType(mysql.TypeBlob) - } - lowerBound, err = d.ConvertTo(statistics.UTCWithAllowInvalidDateCtx, tp) - if err != nil { - return nil, errors.Trace(err) - } - d = rows[i].GetDatum(3, &fields[3].Column.FieldType) - upperBound, err = d.ConvertTo(statistics.UTCWithAllowInvalidDateCtx, tp) - if err != nil { - return nil, errors.Trace(err) - } - } - totalCount += count - hg.AppendBucketWithNDV(&lowerBound, &upperBound, totalCount, repeats, rows[i].GetInt64(4)) - } - hg.PreCalculateScalar() - return hg, nil -} - -// CMSketchAndTopNFromStorageWithHighPriority reads CMSketch and TopN from storage. -func CMSketchAndTopNFromStorageWithHighPriority(sctx sessionctx.Context, tblID int64, isIndex, histID, statsVer int64) (_ *statistics.CMSketch, _ *statistics.TopN, err error) { - topNRows, _, err := util.ExecRows(sctx, "select HIGH_PRIORITY value, count from mysql.stats_top_n where table_id = %? and is_index = %? and hist_id = %?", tblID, isIndex, histID) - if err != nil { - return nil, nil, err - } - // If we are on version higher than 1. Don't read Count-Min Sketch. - if statsVer > statistics.Version1 { - return statistics.DecodeCMSketchAndTopN(nil, topNRows) - } - rows, _, err := util.ExecRows(sctx, "select cm_sketch from mysql.stats_histograms where table_id = %? and is_index = %? and hist_id = %?", tblID, isIndex, histID) - if err != nil { - return nil, nil, err - } - if len(rows) == 0 { - return statistics.DecodeCMSketchAndTopN(nil, topNRows) - } - return statistics.DecodeCMSketchAndTopN(rows[0].GetBytes(0), topNRows) -} - -// CMSketchFromStorage reads CMSketch from storage -func CMSketchFromStorage(sctx sessionctx.Context, tblID int64, isIndex int, histID int64) (_ *statistics.CMSketch, err error) { - rows, _, err := util.ExecRows(sctx, "select cm_sketch from mysql.stats_histograms where table_id = %? and is_index = %? and hist_id = %?", tblID, isIndex, histID) - if err != nil || len(rows) == 0 { - return nil, err - } - return statistics.DecodeCMSketch(rows[0].GetBytes(0)) -} - -// TopNFromStorage reads TopN from storage -func TopNFromStorage(sctx sessionctx.Context, tblID int64, isIndex int, histID int64) (_ *statistics.TopN, err error) { - rows, _, err := util.ExecRows(sctx, "select HIGH_PRIORITY value, count from mysql.stats_top_n where table_id = %? and is_index = %? 
and hist_id = %?", tblID, isIndex, histID) - if err != nil || len(rows) == 0 { - return nil, err - } - return statistics.DecodeTopN(rows), nil -} - -// FMSketchFromStorage reads FMSketch from storage -func FMSketchFromStorage(sctx sessionctx.Context, tblID int64, isIndex, histID int64) (_ *statistics.FMSketch, err error) { - rows, _, err := util.ExecRows(sctx, "select value from mysql.stats_fm_sketch where table_id = %? and is_index = %? and hist_id = %?", tblID, isIndex, histID) - if err != nil || len(rows) == 0 { - return nil, err - } - return statistics.DecodeFMSketch(rows[0].GetBytes(0)) -} - -// CheckSkipPartition checks if we can skip loading the partition. -func CheckSkipPartition(sctx sessionctx.Context, tblID int64, isIndex int) error { - rows, _, err := util.ExecRows(sctx, "select distinct_count from mysql.stats_histograms where table_id =%? and is_index = %?", tblID, isIndex) - if err != nil { - return err - } - if len(rows) == 0 { - return types.ErrPartitionStatsMissing - } - return nil -} - -// CheckSkipColumnPartiion checks if we can skip loading the partition. -func CheckSkipColumnPartiion(sctx sessionctx.Context, tblID int64, isIndex int, histsID int64) error { - rows, _, err := util.ExecRows(sctx, "select distinct_count from mysql.stats_histograms where table_id = %? and is_index = %? and hist_id = %?", tblID, isIndex, histsID) - if err != nil { - return err - } - if len(rows) == 0 { - return types.ErrPartitionColumnStatsMissing - } - return nil -} - -// ExtendedStatsFromStorage reads extended stats from storage. -func ExtendedStatsFromStorage(sctx sessionctx.Context, table *statistics.Table, tableID int64, loadAll bool) (*statistics.Table, error) { - failpoint.Inject("injectExtStatsLoadErr", func() { - failpoint.Return(nil, errors.New("gofail extendedStatsFromStorage error")) - }) - lastVersion := uint64(0) - if table.ExtendedStats != nil && !loadAll { - lastVersion = table.ExtendedStats.LastUpdateVersion - } else { - table.ExtendedStats = statistics.NewExtendedStatsColl() - } - rows, _, err := util.ExecRows(sctx, "select name, status, type, column_ids, stats, version from mysql.stats_extended where table_id = %? and status in (%?, %?, %?) 
and version > %?", - tableID, statistics.ExtendedStatsInited, statistics.ExtendedStatsAnalyzed, statistics.ExtendedStatsDeleted, lastVersion) - if err != nil || len(rows) == 0 { - return table, nil - } - for _, row := range rows { - lastVersion = max(lastVersion, row.GetUint64(5)) - name := row.GetString(0) - status := uint8(row.GetInt64(1)) - if status == statistics.ExtendedStatsDeleted || status == statistics.ExtendedStatsInited { - delete(table.ExtendedStats.Stats, name) - } else { - item := &statistics.ExtendedStatsItem{ - Tp: uint8(row.GetInt64(2)), - } - colIDs := row.GetString(3) - err := json.Unmarshal([]byte(colIDs), &item.ColIDs) - if err != nil { - statslogutil.StatsLogger().Error("decode column IDs failed", zap.String("column_ids", colIDs), zap.Error(err)) - return nil, err - } - statsStr := row.GetString(4) - if item.Tp == ast.StatsTypeCardinality || item.Tp == ast.StatsTypeCorrelation { - if statsStr != "" { - item.ScalarVals, err = strconv.ParseFloat(statsStr, 64) - if err != nil { - statslogutil.StatsLogger().Error("parse scalar stats failed", zap.String("stats", statsStr), zap.Error(err)) - return nil, err - } - } - } else { - item.StringVals = statsStr - } - table.ExtendedStats.Stats[name] = item - } - } - table.ExtendedStats.LastUpdateVersion = lastVersion - return table, nil -} - -func indexStatsFromStorage(sctx sessionctx.Context, row chunk.Row, table *statistics.Table, tableInfo *model.TableInfo, loadAll bool, lease time.Duration, tracker *memory.Tracker) error { - histID := row.GetInt64(2) - distinct := row.GetInt64(3) - histVer := row.GetUint64(4) - nullCount := row.GetInt64(5) - statsVer := row.GetInt64(7) - idx := table.GetIdx(histID) - flag := row.GetInt64(8) - lastAnalyzePos := row.GetDatum(10, types.NewFieldType(mysql.TypeBlob)) - - for _, idxInfo := range tableInfo.Indices { - if histID != idxInfo.ID { - continue - } - table.ColAndIdxExistenceMap.InsertIndex(idxInfo.ID, idxInfo, statsVer != statistics.Version0) - // All the objects in the table shares the same stats version. - // Update here. - if statsVer != statistics.Version0 { - table.StatsVer = int(statsVer) - table.LastAnalyzeVersion = max(table.LastAnalyzeVersion, histVer) - } - // We will not load buckets, topn and cmsketch if: - // 1. lease > 0, and: - // 2. the index doesn't have any of buckets, topn, cmsketch in memory before, and: - // 3. loadAll is false. - // 4. lite-init-stats is true(remove the condition when lite init stats is GA). - notNeedLoad := lease > 0 && - (idx == nil || ((!idx.IsStatsInitialized() || idx.IsAllEvicted()) && idx.LastUpdateVersion < histVer)) && - !loadAll && - config.GetGlobalConfig().Performance.LiteInitStats - if notNeedLoad { - // If we don't have this index in memory, skip it. 
- if idx == nil { - return nil - } - idx = &statistics.Index{ - Histogram: *statistics.NewHistogram(histID, distinct, nullCount, histVer, types.NewFieldType(mysql.TypeBlob), 0, 0), - StatsVer: statsVer, - Info: idxInfo, - Flag: flag, - PhysicalID: table.PhysicalID, - } - if idx.IsAnalyzed() { - idx.StatsLoadedStatus = statistics.NewStatsAllEvictedStatus() - } - lastAnalyzePos.Copy(&idx.LastAnalyzePos) - break - } - if idx == nil || idx.LastUpdateVersion < histVer || loadAll { - hg, err := HistogramFromStorageWithPriority(sctx, table.PhysicalID, histID, types.NewFieldType(mysql.TypeBlob), distinct, 1, histVer, nullCount, 0, 0, kv.PriorityNormal) - if err != nil { - return errors.Trace(err) - } - cms, topN, err := CMSketchAndTopNFromStorageWithHighPriority(sctx, table.PhysicalID, 1, idxInfo.ID, statsVer) - if err != nil { - return errors.Trace(err) - } - var fmSketch *statistics.FMSketch - if loadAll { - // FMSketch is only used when merging partition stats into global stats. When merging partition stats into global stats, - // we load all the statistics, i.e., loadAll is true. - fmSketch, err = FMSketchFromStorage(sctx, table.PhysicalID, 1, histID) - if err != nil { - return errors.Trace(err) - } - } - idx = &statistics.Index{ - Histogram: *hg, - CMSketch: cms, - TopN: topN, - FMSketch: fmSketch, - Info: idxInfo, - StatsVer: statsVer, - Flag: flag, - PhysicalID: table.PhysicalID, - } - if statsVer != statistics.Version0 { - idx.StatsLoadedStatus = statistics.NewStatsFullLoadStatus() - } - lastAnalyzePos.Copy(&idx.LastAnalyzePos) - } - break - } - if idx != nil { - if tracker != nil { - tracker.Consume(idx.MemoryUsage().TotalMemoryUsage()) - } - table.SetIdx(histID, idx) - } else { - logutil.BgLogger().Debug("we cannot find index id in table info. It may be deleted.", zap.Int64("indexID", histID), zap.String("table", tableInfo.Name.O)) - } - return nil -} - -func columnStatsFromStorage(sctx sessionctx.Context, row chunk.Row, table *statistics.Table, tableInfo *model.TableInfo, loadAll bool, lease time.Duration, tracker *memory.Tracker) error { - histID := row.GetInt64(2) - distinct := row.GetInt64(3) - histVer := row.GetUint64(4) - nullCount := row.GetInt64(5) - totColSize := row.GetInt64(6) - statsVer := row.GetInt64(7) - correlation := row.GetFloat64(9) - lastAnalyzePos := row.GetDatum(10, types.NewFieldType(mysql.TypeBlob)) - col := table.GetCol(histID) - flag := row.GetInt64(8) - - for _, colInfo := range tableInfo.Columns { - if histID != colInfo.ID { - continue - } - table.ColAndIdxExistenceMap.InsertCol(histID, colInfo, statsVer != statistics.Version0 || distinct > 0 || nullCount > 0) - // All the objects in the table shares the same stats version. - // Update here. - if statsVer != statistics.Version0 { - table.StatsVer = int(statsVer) - table.LastAnalyzeVersion = max(table.LastAnalyzeVersion, histVer) - } - isHandle := tableInfo.PKIsHandle && mysql.HasPriKeyFlag(colInfo.GetFlag()) - // We will not load buckets, topn and cmsketch if: - // 1. lease > 0, and: - // 2. this column is not handle or lite-init-stats is true(remove the condition when lite init stats is GA), and: - // 3. the column doesn't have any of buckets, topn, cmsketch in memory before, and: - // 4. loadAll is false. - // - // Here is the explanation of the condition `!col.IsStatsInitialized() || col.IsAllEvicted()`. - // For one column: - // 1. If there is no stats for it in the storage(i.e., analyze has never been executed before), then its stats status - // would be `!col.IsStatsInitialized()`. 
In this case we should go the `notNeedLoad` path. - // 2. If there exists stats for it in the storage but its stats status is `col.IsAllEvicted()`, there are two - // sub cases for this case. One is that the column stats have never been used/needed by the optimizer so they have - // never been loaded. The other is that the column stats were loaded and then evicted. For the both sub cases, - // we should go the `notNeedLoad` path. - // 3. If some parts(Histogram/TopN/CMSketch) of stats for it exist in TiDB memory currently, we choose to load all of - // its new stats once we find stats version is updated. - notNeedLoad := lease > 0 && - (!isHandle || config.GetGlobalConfig().Performance.LiteInitStats) && - (col == nil || ((!col.IsStatsInitialized() || col.IsAllEvicted()) && col.LastUpdateVersion < histVer)) && - !loadAll - if notNeedLoad { - // If we don't have the column in memory currently, just skip it. - if col == nil { - return nil - } - col = &statistics.Column{ - PhysicalID: table.PhysicalID, - Histogram: *statistics.NewHistogram(histID, distinct, nullCount, histVer, &colInfo.FieldType, 0, totColSize), - Info: colInfo, - IsHandle: tableInfo.PKIsHandle && mysql.HasPriKeyFlag(colInfo.GetFlag()), - Flag: flag, - StatsVer: statsVer, - } - if col.StatsAvailable() { - col.StatsLoadedStatus = statistics.NewStatsAllEvictedStatus() - } - lastAnalyzePos.Copy(&col.LastAnalyzePos) - col.Histogram.Correlation = correlation - break - } - if col == nil || col.LastUpdateVersion < histVer || loadAll { - hg, err := HistogramFromStorageWithPriority(sctx, table.PhysicalID, histID, &colInfo.FieldType, distinct, 0, histVer, nullCount, totColSize, correlation, kv.PriorityNormal) - if err != nil { - return errors.Trace(err) - } - cms, topN, err := CMSketchAndTopNFromStorageWithHighPriority(sctx, table.PhysicalID, 0, colInfo.ID, statsVer) - if err != nil { - return errors.Trace(err) - } - var fmSketch *statistics.FMSketch - if loadAll { - // FMSketch is only used when merging partition stats into global stats. When merging partition stats into global stats, - // we load all the statistics, i.e., loadAll is true. - fmSketch, err = FMSketchFromStorage(sctx, table.PhysicalID, 0, histID) - if err != nil { - return errors.Trace(err) - } - } - col = &statistics.Column{ - PhysicalID: table.PhysicalID, - Histogram: *hg, - Info: colInfo, - CMSketch: cms, - TopN: topN, - FMSketch: fmSketch, - IsHandle: tableInfo.PKIsHandle && mysql.HasPriKeyFlag(colInfo.GetFlag()), - Flag: flag, - StatsVer: statsVer, - } - if col.StatsAvailable() { - col.StatsLoadedStatus = statistics.NewStatsFullLoadStatus() - } - lastAnalyzePos.Copy(&col.LastAnalyzePos) - break - } - if col.TotColSize != totColSize { - newCol := *col - newCol.TotColSize = totColSize - col = &newCol - } - break - } - if col != nil { - if tracker != nil { - tracker.Consume(col.MemoryUsage().TotalMemoryUsage()) - } - table.SetCol(col.ID, col) - } else { - // If we didn't find a Column or Index in tableInfo, we won't load the histogram for it. - // But don't worry, next lease the ddl will be updated, and we will load a same table for two times to - // avoid error. - logutil.BgLogger().Debug("we cannot find column in table info now. It may be deleted", zap.Int64("colID", histID), zap.String("table", tableInfo.Name.O)) - } - return nil -} - -// TableStatsFromStorage loads table stats info from storage. 
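For reference, the skip-load condition described in the comments above can be read as one small predicate. The sketch below is a standalone restatement under assumed field names; it is not the real statistics.Column API, only the shape of the decision.

package main

import "fmt"

// colState holds only the fields the skip decision needs; the names are
// hypothetical stand-ins for the real statistics.Column bookkeeping.
type colState struct {
	initialized       bool   // analyze has ever produced stats for this column
	allEvicted        bool   // stats were loaded once but fully evicted since
	lastUpdateVersion uint64 // version of whatever is currently in memory
}

// skipHeavyLoad reports whether buckets/TopN/CMSketch can be left unloaded,
// mirroring the notNeedLoad condition explained in the comments above.
func skipHeavyLoad(leaseEnabled, isHandle, liteInitStats, loadAll bool, col *colState, storageVer uint64) bool {
	if !leaseEnabled || loadAll {
		return false // full loads and lease==0 always read everything
	}
	if isHandle && !liteInitStats {
		return false // handle columns stay eagerly loaded unless lite-init-stats is on
	}
	return col == nil ||
		((!col.initialized || col.allEvicted) && col.lastUpdateVersion < storageVer)
}

func main() {
	fmt.Println(skipHeavyLoad(true, false, true, false, nil, 10)) // true: nothing cached yet
}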
-func TableStatsFromStorage(sctx sessionctx.Context, snapshot uint64, tableInfo *model.TableInfo, tableID int64, loadAll bool, lease time.Duration, table *statistics.Table) (_ *statistics.Table, err error) { - tracker := memory.NewTracker(memory.LabelForAnalyzeMemory, -1) - tracker.AttachTo(sctx.GetSessionVars().MemTracker) - defer tracker.Detach() - // If table stats is pseudo, we also need to copy it, since we will use the column stats when - // the average error rate of it is small. - if table == nil || snapshot > 0 { - histColl := *statistics.NewHistColl(tableID, true, 0, 0, 4, 4) - table = &statistics.Table{ - HistColl: histColl, - ColAndIdxExistenceMap: statistics.NewColAndIndexExistenceMap(len(tableInfo.Columns), len(tableInfo.Indices)), - } - } else { - // We copy it before writing to avoid race. - table = table.Copy() - } - table.Pseudo = false - - realtimeCount, modidyCount, isNull, err := StatsMetaCountAndModifyCount(sctx, tableID) - if err != nil || isNull { - return nil, err - } - table.ModifyCount = modidyCount - table.RealtimeCount = realtimeCount - - rows, _, err := util.ExecRows(sctx, "select table_id, is_index, hist_id, distinct_count, version, null_count, tot_col_size, stats_ver, flag, correlation, last_analyze_pos from mysql.stats_histograms where table_id = %?", tableID) - // Check deleted table. - if err != nil || len(rows) == 0 { - return nil, nil - } - for _, row := range rows { - if err := sctx.GetSessionVars().SQLKiller.HandleSignal(); err != nil { - return nil, err - } - if row.GetInt64(1) > 0 { - err = indexStatsFromStorage(sctx, row, table, tableInfo, loadAll, lease, tracker) - } else { - err = columnStatsFromStorage(sctx, row, table, tableInfo, loadAll, lease, tracker) - } - if err != nil { - return nil, err - } - } - return ExtendedStatsFromStorage(sctx, table, tableID, loadAll) -} - -// LoadHistogram will load histogram from storage. -func LoadHistogram(sctx sessionctx.Context, tableID int64, isIndex int, histID int64, tableInfo *model.TableInfo) (*statistics.Histogram, error) { - row, _, err := util.ExecRows(sctx, "select distinct_count, version, null_count, tot_col_size, stats_ver, flag, correlation, last_analyze_pos from mysql.stats_histograms where table_id = %? and is_index = %? and hist_id = %?", tableID, isIndex, histID) - if err != nil || len(row) == 0 { - return nil, err - } - distinct := row[0].GetInt64(0) - histVer := row[0].GetUint64(1) - nullCount := row[0].GetInt64(2) - var totColSize int64 - var corr float64 - var tp types.FieldType - if isIndex == 0 { - totColSize = row[0].GetInt64(3) - corr = row[0].GetFloat64(6) - for _, colInfo := range tableInfo.Columns { - if histID != colInfo.ID { - continue - } - tp = colInfo.FieldType - break - } - return HistogramFromStorageWithPriority(sctx, tableID, histID, &tp, distinct, isIndex, histVer, nullCount, totColSize, corr, kv.PriorityNormal) - } - return HistogramFromStorageWithPriority(sctx, tableID, histID, types.NewFieldType(mysql.TypeBlob), distinct, isIndex, histVer, nullCount, 0, 0, kv.PriorityNormal) -} - -// LoadNeededHistograms will load histograms for those needed columns/indices. -func LoadNeededHistograms(sctx sessionctx.Context, statsCache statstypes.StatsCache, loadFMSketch bool) (err error) { - items := asyncload.AsyncLoadHistogramNeededItems.AllItems() - for _, item := range items { - if !item.IsIndex { - err = loadNeededColumnHistograms(sctx, statsCache, item.TableItemID, loadFMSketch, item.FullLoad) - } else { - // Index is always full load. 
- err = loadNeededIndexHistograms(sctx, statsCache, item.TableItemID, loadFMSketch) - } - if err != nil { - return err - } - } - return nil -} - -// CleanFakeItemsForShowHistInFlights cleans the invalid inserted items. -func CleanFakeItemsForShowHistInFlights(statsCache statstypes.StatsCache) int { - items := asyncload.AsyncLoadHistogramNeededItems.AllItems() - reallyNeeded := 0 - for _, item := range items { - tbl, ok := statsCache.Get(item.TableID) - if !ok { - asyncload.AsyncLoadHistogramNeededItems.Delete(item.TableItemID) - continue - } - loadNeeded := false - if item.IsIndex { - _, loadNeeded = tbl.IndexIsLoadNeeded(item.ID) - } else { - var analyzed bool - _, loadNeeded, analyzed = tbl.ColumnIsLoadNeeded(item.ID, item.FullLoad) - loadNeeded = loadNeeded && analyzed - } - if !loadNeeded { - asyncload.AsyncLoadHistogramNeededItems.Delete(item.TableItemID) - continue - } - reallyNeeded++ - } - return reallyNeeded -} - -func loadNeededColumnHistograms(sctx sessionctx.Context, statsCache statstypes.StatsCache, col model.TableItemID, loadFMSketch bool, fullLoad bool) (err error) { - tbl, ok := statsCache.Get(col.TableID) - if !ok { - return nil - } - var colInfo *model.ColumnInfo - _, loadNeeded, analyzed := tbl.ColumnIsLoadNeeded(col.ID, true) - if !loadNeeded || !analyzed { - asyncload.AsyncLoadHistogramNeededItems.Delete(col) - return nil - } - colInfo = tbl.ColAndIdxExistenceMap.GetCol(col.ID) - hg, _, statsVer, _, err := HistMetaFromStorageWithHighPriority(sctx, &col, colInfo) - if hg == nil || err != nil { - asyncload.AsyncLoadHistogramNeededItems.Delete(col) - return err - } - var ( - cms *statistics.CMSketch - topN *statistics.TopN - fms *statistics.FMSketch - ) - if fullLoad { - hg, err = HistogramFromStorageWithPriority(sctx, col.TableID, col.ID, &colInfo.FieldType, hg.NDV, 0, hg.LastUpdateVersion, hg.NullCount, hg.TotColSize, hg.Correlation, kv.PriorityHigh) - if err != nil { - return errors.Trace(err) - } - cms, topN, err = CMSketchAndTopNFromStorageWithHighPriority(sctx, col.TableID, 0, col.ID, statsVer) - if err != nil { - return errors.Trace(err) - } - if loadFMSketch { - fms, err = FMSketchFromStorage(sctx, col.TableID, 0, col.ID) - if err != nil { - return errors.Trace(err) - } - } - } - colHist := &statistics.Column{ - PhysicalID: col.TableID, - Histogram: *hg, - Info: colInfo, - CMSketch: cms, - TopN: topN, - FMSketch: fms, - IsHandle: tbl.IsPkIsHandle && mysql.HasPriKeyFlag(colInfo.GetFlag()), - StatsVer: statsVer, - } - // Reload the latest stats cache, otherwise the `updateStatsCache` may fail with high probability, because functions - // like `GetPartitionStats` called in `fmSketchFromStorage` would have modified the stats cache already. 
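The reload-then-copy step called out in the comment above is what keeps concurrent cache writers from clobbering each other: the latest table is re-fetched, mutated on a private copy, and only then published back. A minimal sketch of that pattern with hypothetical types; the mutex stands in for whatever synchronization the real stats cache provides.

package main

import (
	"fmt"
	"sync"
)

// statsTable is a hypothetical stand-in for statistics.Table: published
// copies are treated as immutable, so readers never take the lock.
type statsTable struct {
	cols map[int64]string
}

func (t *statsTable) copy() *statsTable {
	nc := make(map[int64]string, len(t.cols)+1)
	for k, v := range t.cols {
		nc[k] = v
	}
	return &statsTable{cols: nc}
}

type statsCache struct {
	mu     sync.Mutex
	tables map[int64]*statsTable
}

// updateCol mirrors the shape of the code above: re-read the latest table,
// mutate a copy, then publish the copy back.
func (c *statsCache) updateCol(tableID, colID int64, hist string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	tbl, ok := c.tables[tableID]
	if !ok {
		return
	}
	newTbl := tbl.copy() // never mutate the shared, already-published value
	newTbl.cols[colID] = hist
	c.tables[tableID] = newTbl
}

func main() {
	c := &statsCache{tables: map[int64]*statsTable{1: {cols: map[int64]string{}}}}
	c.updateCol(1, 7, "full histogram")
	fmt.Println(c.tables[1].cols[7])
}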
- tbl, ok = statsCache.Get(col.TableID) - if !ok { - return nil - } - tbl = tbl.Copy() - if colHist.StatsAvailable() { - if fullLoad { - colHist.StatsLoadedStatus = statistics.NewStatsFullLoadStatus() - } else { - colHist.StatsLoadedStatus = statistics.NewStatsAllEvictedStatus() - } - tbl.LastAnalyzeVersion = max(tbl.LastAnalyzeVersion, colHist.LastUpdateVersion) - if statsVer != statistics.Version0 { - tbl.StatsVer = int(statsVer) - } - } - tbl.SetCol(col.ID, colHist) - statsCache.UpdateStatsCache([]*statistics.Table{tbl}, nil) - asyncload.AsyncLoadHistogramNeededItems.Delete(col) - if col.IsSyncLoadFailed { - logutil.BgLogger().Warn("Hist for column should already be loaded as sync but not found.", - zap.Int64("table_id", colHist.PhysicalID), - zap.Int64("column_id", colHist.Info.ID), - zap.String("column_name", colHist.Info.Name.O)) - } - return nil -} - -func loadNeededIndexHistograms(sctx sessionctx.Context, statsCache statstypes.StatsCache, idx model.TableItemID, loadFMSketch bool) (err error) { - tbl, ok := statsCache.Get(idx.TableID) - if !ok { - return nil - } - _, loadNeeded := tbl.IndexIsLoadNeeded(idx.ID) - if !loadNeeded { - asyncload.AsyncLoadHistogramNeededItems.Delete(idx) - return nil - } - hgMeta, lastAnalyzePos, statsVer, flag, err := HistMetaFromStorageWithHighPriority(sctx, &idx, nil) - if hgMeta == nil || err != nil { - asyncload.AsyncLoadHistogramNeededItems.Delete(idx) - return err - } - idxInfo := tbl.ColAndIdxExistenceMap.GetIndex(idx.ID) - hg, err := HistogramFromStorageWithPriority(sctx, idx.TableID, idx.ID, types.NewFieldType(mysql.TypeBlob), hgMeta.NDV, 1, hgMeta.LastUpdateVersion, hgMeta.NullCount, hgMeta.TotColSize, hgMeta.Correlation, kv.PriorityHigh) - if err != nil { - return errors.Trace(err) - } - cms, topN, err := CMSketchAndTopNFromStorageWithHighPriority(sctx, idx.TableID, 1, idx.ID, statsVer) - if err != nil { - return errors.Trace(err) - } - var fms *statistics.FMSketch - if loadFMSketch { - fms, err = FMSketchFromStorage(sctx, idx.TableID, 1, idx.ID) - if err != nil { - return errors.Trace(err) - } - } - idxHist := &statistics.Index{Histogram: *hg, CMSketch: cms, TopN: topN, FMSketch: fms, - Info: idxInfo, StatsVer: statsVer, - Flag: flag, PhysicalID: idx.TableID, - StatsLoadedStatus: statistics.NewStatsFullLoadStatus()} - lastAnalyzePos.Copy(&idxHist.LastAnalyzePos) - - tbl, ok = statsCache.Get(idx.TableID) - if !ok { - return nil - } - tbl = tbl.Copy() - if idxHist.StatsVer != statistics.Version0 { - tbl.StatsVer = int(idxHist.StatsVer) - } - tbl.SetIdx(idx.ID, idxHist) - tbl.LastAnalyzeVersion = max(tbl.LastAnalyzeVersion, idxHist.LastUpdateVersion) - statsCache.UpdateStatsCache([]*statistics.Table{tbl}, nil) - if idx.IsSyncLoadFailed { - logutil.BgLogger().Warn("Hist for index should already be loaded as sync but not found.", - zap.Int64("table_id", idx.TableID), - zap.Int64("index_id", idxHist.Info.ID), - zap.String("index_name", idxHist.Info.Name.O)) - } - asyncload.AsyncLoadHistogramNeededItems.Delete(idx) - return nil -} - -// StatsMetaByTableIDFromStorage gets the stats meta of a table from storage. -func StatsMetaByTableIDFromStorage(sctx sessionctx.Context, tableID int64, snapshot uint64) (version uint64, modifyCount, count int64, err error) { - var rows []chunk.Row - if snapshot == 0 { - rows, _, err = util.ExecRows(sctx, - "SELECT version, modify_count, count from mysql.stats_meta where table_id = %? 
order by version", tableID) - } else { - rows, _, err = util.ExecWithOpts(sctx, - []sqlexec.OptionFuncAlias{sqlexec.ExecOptionWithSnapshot(snapshot), sqlexec.ExecOptionUseCurSession}, - "SELECT version, modify_count, count from mysql.stats_meta where table_id = %? order by version", tableID) - } - if err != nil || len(rows) == 0 { - return - } - version = rows[0].GetUint64(0) - modifyCount = rows[0].GetInt64(1) - count = rows[0].GetInt64(2) - return -} diff --git a/pkg/statistics/handle/syncload/binding__failpoint_binding__.go b/pkg/statistics/handle/syncload/binding__failpoint_binding__.go deleted file mode 100644 index f2b453264a678..0000000000000 --- a/pkg/statistics/handle/syncload/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package syncload - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/statistics/handle/syncload/stats_syncload.go b/pkg/statistics/handle/syncload/stats_syncload.go index e6972877bfe1d..65a234f71b474 100644 --- a/pkg/statistics/handle/syncload/stats_syncload.go +++ b/pkg/statistics/handle/syncload/stats_syncload.go @@ -83,14 +83,14 @@ type statsWrapper struct { func (s *statsSyncLoad) SendLoadRequests(sc *stmtctx.StatementContext, neededHistItems []model.StatsLoadItem, timeout time.Duration) error { remainedItems := s.removeHistLoadedColumns(neededHistItems) - if val, _err_ := failpoint.Eval(_curpkg_("assertSyncLoadItems")); _err_ == nil { + failpoint.Inject("assertSyncLoadItems", func(val failpoint.Value) { if sc.OptimizeTracer != nil { count := val.(int) if len(remainedItems) != count { panic("remained items count wrong") } } - } + }) if len(remainedItems) <= 0 { return nil @@ -361,12 +361,12 @@ func (s *statsSyncLoad) handleOneItemTask(task *statstypes.NeededItemTask) (err // readStatsForOneItem reads hist for one column/index, TODO load data via kv-get asynchronously func (*statsSyncLoad) readStatsForOneItem(sctx sessionctx.Context, item model.TableItemID, w *statsWrapper, isPkIsHandle bool, fullLoad bool) (*statsWrapper, error) { - failpoint.Eval(_curpkg_("mockReadStatsForOnePanic")) - if val, _err_ := failpoint.Eval(_curpkg_("mockReadStatsForOneFail")); _err_ == nil { + failpoint.Inject("mockReadStatsForOnePanic", nil) + failpoint.Inject("mockReadStatsForOneFail", func(val failpoint.Value) { if val.(bool) { - return nil, errors.New("gofail ReadStatsForOne error") + failpoint.Return(nil, errors.New("gofail ReadStatsForOne error")) } - } + }) loadFMSketch := config.GetGlobalConfig().Performance.EnableLoadFMSketch var hg *statistics.Histogram var err error diff --git a/pkg/statistics/handle/syncload/stats_syncload.go__failpoint_stash__ b/pkg/statistics/handle/syncload/stats_syncload.go__failpoint_stash__ deleted file mode 100644 index 65a234f71b474..0000000000000 --- a/pkg/statistics/handle/syncload/stats_syncload.go__failpoint_stash__ +++ /dev/null @@ -1,574 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package syncload - -import ( - "fmt" - "math/rand" - "runtime" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - "github.com/pingcap/tidb/pkg/statistics" - "github.com/pingcap/tidb/pkg/statistics/handle/storage" - statstypes "github.com/pingcap/tidb/pkg/statistics/handle/types" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tidb/pkg/util/logutil" - "go.uber.org/zap" - "golang.org/x/sync/singleflight" -) - -// RetryCount is the max retry count for a sync load task. -const RetryCount = 3 - -// GetSyncLoadConcurrencyByCPU returns the concurrency of sync load by CPU. -func GetSyncLoadConcurrencyByCPU() int { - core := runtime.GOMAXPROCS(0) - if core <= 8 { - return 5 - } else if core <= 16 { - return 6 - } else if core <= 32 { - return 8 - } - return 10 -} - -type statsSyncLoad struct { - statsHandle statstypes.StatsHandle - StatsLoad statstypes.StatsLoad -} - -var globalStatsSyncLoadSingleFlight singleflight.Group - -// NewStatsSyncLoad creates a new StatsSyncLoad. 
-func NewStatsSyncLoad(statsHandle statstypes.StatsHandle) statstypes.StatsSyncLoad { - s := &statsSyncLoad{statsHandle: statsHandle} - cfg := config.GetGlobalConfig() - s.StatsLoad.NeededItemsCh = make(chan *statstypes.NeededItemTask, cfg.Performance.StatsLoadQueueSize) - s.StatsLoad.TimeoutItemsCh = make(chan *statstypes.NeededItemTask, cfg.Performance.StatsLoadQueueSize) - return s -} - -type statsWrapper struct { - colInfo *model.ColumnInfo - idxInfo *model.IndexInfo - col *statistics.Column - idx *statistics.Index -} - -// SendLoadRequests send neededColumns requests -func (s *statsSyncLoad) SendLoadRequests(sc *stmtctx.StatementContext, neededHistItems []model.StatsLoadItem, timeout time.Duration) error { - remainedItems := s.removeHistLoadedColumns(neededHistItems) - - failpoint.Inject("assertSyncLoadItems", func(val failpoint.Value) { - if sc.OptimizeTracer != nil { - count := val.(int) - if len(remainedItems) != count { - panic("remained items count wrong") - } - } - }) - - if len(remainedItems) <= 0 { - return nil - } - sc.StatsLoad.Timeout = timeout - sc.StatsLoad.NeededItems = remainedItems - sc.StatsLoad.ResultCh = make([]<-chan singleflight.Result, 0, len(remainedItems)) - for _, item := range remainedItems { - localItem := item - resultCh := globalStatsSyncLoadSingleFlight.DoChan(localItem.Key(), func() (any, error) { - timer := time.NewTimer(timeout) - defer timer.Stop() - task := &statstypes.NeededItemTask{ - Item: localItem, - ToTimeout: time.Now().Local().Add(timeout), - ResultCh: make(chan stmtctx.StatsLoadResult, 1), - } - select { - case s.StatsLoad.NeededItemsCh <- task: - metrics.SyncLoadDedupCounter.Inc() - result, ok := <-task.ResultCh - intest.Assert(ok, "task.ResultCh cannot be closed") - return result, nil - case <-timer.C: - return nil, errors.New("sync load stats channel is full and timeout sending task to channel") - } - }) - sc.StatsLoad.ResultCh = append(sc.StatsLoad.ResultCh, resultCh) - } - sc.StatsLoad.LoadStartTime = time.Now() - return nil -} - -// SyncWaitStatsLoad sync waits loading of neededColumns and return false if timeout -func (*statsSyncLoad) SyncWaitStatsLoad(sc *stmtctx.StatementContext) error { - if len(sc.StatsLoad.NeededItems) <= 0 { - return nil - } - var errorMsgs []string - defer func() { - if len(errorMsgs) > 0 { - logutil.BgLogger().Warn("SyncWaitStatsLoad meets error", - zap.Strings("errors", errorMsgs)) - } - sc.StatsLoad.NeededItems = nil - }() - resultCheckMap := map[model.TableItemID]struct{}{} - for _, col := range sc.StatsLoad.NeededItems { - resultCheckMap[col.TableItemID] = struct{}{} - } - timer := time.NewTimer(sc.StatsLoad.Timeout) - defer timer.Stop() - for _, resultCh := range sc.StatsLoad.ResultCh { - select { - case result, ok := <-resultCh: - metrics.SyncLoadCounter.Inc() - if !ok { - return errors.New("sync load stats channel closed unexpectedly") - } - // this error is from statsSyncLoad.SendLoadRequests which start to task and send task into worker, - // not the stats loading error - if result.Err != nil { - errorMsgs = append(errorMsgs, result.Err.Error()) - } else { - val := result.Val.(stmtctx.StatsLoadResult) - // this error is from the stats loading error - if val.HasError() { - errorMsgs = append(errorMsgs, val.ErrorMsg()) - } - delete(resultCheckMap, val.Item) - } - case <-timer.C: - metrics.SyncLoadCounter.Inc() - metrics.SyncLoadTimeoutCounter.Inc() - return errors.New("sync load stats timeout") - } - } - if len(resultCheckMap) == 0 { - 
metrics.SyncLoadHistogram.Observe(float64(time.Since(sc.StatsLoad.LoadStartTime).Milliseconds())) - return nil - } - return nil -} - -// removeHistLoadedColumns removed having-hist columns based on neededColumns and statsCache. -func (s *statsSyncLoad) removeHistLoadedColumns(neededItems []model.StatsLoadItem) []model.StatsLoadItem { - remainedItems := make([]model.StatsLoadItem, 0, len(neededItems)) - for _, item := range neededItems { - tbl, ok := s.statsHandle.Get(item.TableID) - if !ok { - continue - } - if item.IsIndex { - _, loadNeeded := tbl.IndexIsLoadNeeded(item.ID) - if loadNeeded { - remainedItems = append(remainedItems, item) - } - continue - } - _, loadNeeded, _ := tbl.ColumnIsLoadNeeded(item.ID, item.FullLoad) - if loadNeeded { - remainedItems = append(remainedItems, item) - } - } - return remainedItems -} - -// AppendNeededItem appends needed columns/indices to ch, it is only used for test -func (s *statsSyncLoad) AppendNeededItem(task *statstypes.NeededItemTask, timeout time.Duration) error { - timer := time.NewTimer(timeout) - defer timer.Stop() - select { - case s.StatsLoad.NeededItemsCh <- task: - case <-timer.C: - return errors.New("Channel is full and timeout writing to channel") - } - return nil -} - -var errExit = errors.New("Stop loading since domain is closed") - -// SubLoadWorker loads hist data for each column -func (s *statsSyncLoad) SubLoadWorker(sctx sessionctx.Context, exit chan struct{}, exitWg *util.WaitGroupEnhancedWrapper) { - defer func() { - exitWg.Done() - logutil.BgLogger().Info("SubLoadWorker exited.") - }() - // if the last task is not successfully handled in last round for error or panic, pass it to this round to retry - var lastTask *statstypes.NeededItemTask - for { - task, err := s.HandleOneTask(sctx, lastTask, exit) - lastTask = task - if err != nil { - switch err { - case errExit: - return - default: - // To avoid the thundering herd effect - // thundering herd effect: Everyone tries to retry a large number of requests simultaneously when a problem occurs. - r := rand.Intn(500) - time.Sleep(s.statsHandle.Lease()/10 + time.Duration(r)*time.Microsecond) - continue - } - } - } -} - -// HandleOneTask handles last task if not nil, else handle a new task from chan, and return current task if fail somewhere. -// - If the task is handled successfully, return nil, nil. -// - If the task is timeout, return the task and nil. The caller should retry the timeout task without sleep. -// - If the task is failed, return the task, error. The caller should retry the timeout task with sleep. 
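SendLoadRequests above collapses concurrent requests for the same column or index through a package-level singleflight.Group, so one worker does the storage read while every waiting session receives the shared result. A small self-contained sketch of that pattern; the key string and loader below are hypothetical.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"

	"golang.org/x/sync/singleflight"
)

var group singleflight.Group

// loads counts how many times the stand-in loader actually runs.
var loads int64

func loadHistogram(key string) (any, error) {
	atomic.AddInt64(&loads, 1)
	return "histogram for " + key, nil
}

func main() {
	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// DoChan collapses concurrent callers that share a key into a
			// single execution, like globalStatsSyncLoadSingleFlight above.
			ch := group.DoChan("table:42,col:7", func() (any, error) {
				return loadHistogram("table:42,col:7")
			})
			res := <-ch
			_ = res.Val
		}()
	}
	wg.Wait()
	fmt.Println("loader ran", atomic.LoadInt64(&loads), "time(s) for 8 callers")
}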
-func (s *statsSyncLoad) HandleOneTask(sctx sessionctx.Context, lastTask *statstypes.NeededItemTask, exit chan struct{}) (task *statstypes.NeededItemTask, err error) { - defer func() { - // recover for each task, worker keeps working - if r := recover(); r != nil { - logutil.BgLogger().Error("stats loading panicked", zap.Any("error", r), zap.Stack("stack")) - err = errors.Errorf("stats loading panicked: %v", r) - } - }() - if lastTask == nil { - task, err = s.drainColTask(sctx, exit) - if err != nil { - if err != errExit { - logutil.BgLogger().Error("Fail to drain task for stats loading.", zap.Error(err)) - } - return task, err - } - } else { - task = lastTask - } - result := stmtctx.StatsLoadResult{Item: task.Item.TableItemID} - err = s.handleOneItemTask(task) - if err == nil { - task.ResultCh <- result - return nil, nil - } - if !isVaildForRetry(task) { - result.Error = err - task.ResultCh <- result - return nil, nil - } - return task, err -} - -func isVaildForRetry(task *statstypes.NeededItemTask) bool { - task.Retry++ - return task.Retry <= RetryCount -} - -func (s *statsSyncLoad) handleOneItemTask(task *statstypes.NeededItemTask) (err error) { - se, err := s.statsHandle.SPool().Get() - if err != nil { - return err - } - sctx := se.(sessionctx.Context) - sctx.GetSessionVars().StmtCtx.Priority = mysql.HighPriority - defer func() { - // recover for each task, worker keeps working - if r := recover(); r != nil { - logutil.BgLogger().Error("handleOneItemTask panicked", zap.Any("recover", r), zap.Stack("stack")) - err = errors.Errorf("stats loading panicked: %v", r) - } - if err == nil { // only recycle when no error - sctx.GetSessionVars().StmtCtx.Priority = mysql.NoPriority - s.statsHandle.SPool().Put(se) - } - }() - item := task.Item.TableItemID - tbl, ok := s.statsHandle.Get(item.TableID) - if !ok { - return nil - } - wrapper := &statsWrapper{} - if item.IsIndex { - index, loadNeeded := tbl.IndexIsLoadNeeded(item.ID) - if !loadNeeded { - return nil - } - if index != nil { - wrapper.idxInfo = index.Info - } else { - wrapper.idxInfo = tbl.ColAndIdxExistenceMap.GetIndex(item.ID) - } - } else { - col, loadNeeded, analyzed := tbl.ColumnIsLoadNeeded(item.ID, task.Item.FullLoad) - if !loadNeeded { - return nil - } - if col != nil { - wrapper.colInfo = col.Info - } else { - wrapper.colInfo = tbl.ColAndIdxExistenceMap.GetCol(item.ID) - } - // If this column is not analyzed yet and we don't have it in memory. - // We create a fake one for the pseudo estimation. 
- if loadNeeded && !analyzed { - wrapper.col = &statistics.Column{ - PhysicalID: item.TableID, - Info: wrapper.colInfo, - Histogram: *statistics.NewHistogram(item.ID, 0, 0, 0, &wrapper.colInfo.FieldType, 0, 0), - IsHandle: tbl.IsPkIsHandle && mysql.HasPriKeyFlag(wrapper.colInfo.GetFlag()), - } - s.updateCachedItem(item, wrapper.col, wrapper.idx, task.Item.FullLoad) - return nil - } - } - t := time.Now() - needUpdate := false - wrapper, err = s.readStatsForOneItem(sctx, item, wrapper, tbl.IsPkIsHandle, task.Item.FullLoad) - if err != nil { - return err - } - if item.IsIndex { - if wrapper.idxInfo != nil { - needUpdate = true - } - } else { - if wrapper.colInfo != nil { - needUpdate = true - } - } - metrics.ReadStatsHistogram.Observe(float64(time.Since(t).Milliseconds())) - if needUpdate { - s.updateCachedItem(item, wrapper.col, wrapper.idx, task.Item.FullLoad) - } - return nil -} - -// readStatsForOneItem reads hist for one column/index, TODO load data via kv-get asynchronously -func (*statsSyncLoad) readStatsForOneItem(sctx sessionctx.Context, item model.TableItemID, w *statsWrapper, isPkIsHandle bool, fullLoad bool) (*statsWrapper, error) { - failpoint.Inject("mockReadStatsForOnePanic", nil) - failpoint.Inject("mockReadStatsForOneFail", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(nil, errors.New("gofail ReadStatsForOne error")) - } - }) - loadFMSketch := config.GetGlobalConfig().Performance.EnableLoadFMSketch - var hg *statistics.Histogram - var err error - isIndexFlag := int64(0) - hg, lastAnalyzePos, statsVer, flag, err := storage.HistMetaFromStorageWithHighPriority(sctx, &item, w.colInfo) - if err != nil { - return nil, err - } - if hg == nil { - logutil.BgLogger().Error("fail to get hist meta for this histogram, possibly a deleted one", zap.Int64("table_id", item.TableID), - zap.Int64("hist_id", item.ID), zap.Bool("is_index", item.IsIndex)) - return nil, errors.Trace(fmt.Errorf("fail to get hist meta for this histogram, table_id:%v, hist_id:%v, is_index:%v", item.TableID, item.ID, item.IsIndex)) - } - if item.IsIndex { - isIndexFlag = 1 - } - var cms *statistics.CMSketch - var topN *statistics.TopN - var fms *statistics.FMSketch - if fullLoad { - if item.IsIndex { - hg, err = storage.HistogramFromStorageWithPriority(sctx, item.TableID, item.ID, types.NewFieldType(mysql.TypeBlob), hg.NDV, int(isIndexFlag), hg.LastUpdateVersion, hg.NullCount, hg.TotColSize, hg.Correlation, kv.PriorityHigh) - if err != nil { - return nil, errors.Trace(err) - } - } else { - hg, err = storage.HistogramFromStorageWithPriority(sctx, item.TableID, item.ID, &w.colInfo.FieldType, hg.NDV, int(isIndexFlag), hg.LastUpdateVersion, hg.NullCount, hg.TotColSize, hg.Correlation, kv.PriorityHigh) - if err != nil { - return nil, errors.Trace(err) - } - } - cms, topN, err = storage.CMSketchAndTopNFromStorageWithHighPriority(sctx, item.TableID, isIndexFlag, item.ID, statsVer) - if err != nil { - return nil, errors.Trace(err) - } - if loadFMSketch { - fms, err = storage.FMSketchFromStorage(sctx, item.TableID, isIndexFlag, item.ID) - if err != nil { - return nil, errors.Trace(err) - } - } - } - if item.IsIndex { - idxHist := &statistics.Index{ - Histogram: *hg, - CMSketch: cms, - TopN: topN, - FMSketch: fms, - Info: w.idxInfo, - StatsVer: statsVer, - Flag: flag, - PhysicalID: item.TableID, - } - if statsVer != statistics.Version0 { - if fullLoad { - idxHist.StatsLoadedStatus = statistics.NewStatsFullLoadStatus() - } else { - idxHist.StatsLoadedStatus = statistics.NewStatsAllEvictedStatus() - } - } - 
lastAnalyzePos.Copy(&idxHist.LastAnalyzePos) - w.idx = idxHist - } else { - colHist := &statistics.Column{ - PhysicalID: item.TableID, - Histogram: *hg, - Info: w.colInfo, - CMSketch: cms, - TopN: topN, - FMSketch: fms, - IsHandle: isPkIsHandle && mysql.HasPriKeyFlag(w.colInfo.GetFlag()), - StatsVer: statsVer, - } - if colHist.StatsAvailable() { - if fullLoad { - colHist.StatsLoadedStatus = statistics.NewStatsFullLoadStatus() - } else { - colHist.StatsLoadedStatus = statistics.NewStatsAllEvictedStatus() - } - } - w.col = colHist - } - return w, nil -} - -// drainColTask will hang until a column task can return, and either task or error will be returned. -func (s *statsSyncLoad) drainColTask(sctx sessionctx.Context, exit chan struct{}) (*statstypes.NeededItemTask, error) { - // select NeededColumnsCh firstly, if no task, then select TimeoutColumnsCh - for { - select { - case <-exit: - return nil, errExit - case task, ok := <-s.StatsLoad.NeededItemsCh: - if !ok { - return nil, errors.New("drainColTask: cannot read from NeededColumnsCh, maybe the chan is closed") - } - // if the task has already timeout, no sql is sync-waiting for it, - // so do not handle it just now, put it to another channel with lower priority - if time.Now().After(task.ToTimeout) { - task.ToTimeout.Add(time.Duration(sctx.GetSessionVars().StatsLoadSyncWait.Load()) * time.Microsecond) - s.writeToTimeoutChan(s.StatsLoad.TimeoutItemsCh, task) - continue - } - return task, nil - case task, ok := <-s.StatsLoad.TimeoutItemsCh: - select { - case <-exit: - return nil, errExit - case task0, ok0 := <-s.StatsLoad.NeededItemsCh: - if !ok0 { - return nil, errors.New("drainColTask: cannot read from NeededColumnsCh, maybe the chan is closed") - } - // send task back to TimeoutColumnsCh and return the task drained from NeededColumnsCh - s.writeToTimeoutChan(s.StatsLoad.TimeoutItemsCh, task) - return task0, nil - default: - if !ok { - return nil, errors.New("drainColTask: cannot read from TimeoutColumnsCh, maybe the chan is closed") - } - // NeededColumnsCh is empty now, handle task from TimeoutColumnsCh - return task, nil - } - } - } -} - -// writeToTimeoutChan writes in a nonblocking way, and if the channel queue is full, it's ok to drop the task. -func (*statsSyncLoad) writeToTimeoutChan(taskCh chan *statstypes.NeededItemTask, task *statstypes.NeededItemTask) { - select { - case taskCh <- task: - default: - } -} - -// writeToChanWithTimeout writes a task to a channel and blocks until timeout. -func (*statsSyncLoad) writeToChanWithTimeout(taskCh chan *statstypes.NeededItemTask, task *statstypes.NeededItemTask, timeout time.Duration) error { - timer := time.NewTimer(timeout) - defer timer.Stop() - select { - case taskCh <- task: - case <-timer.C: - return errors.New("Channel is full and timeout writing to channel") - } - return nil -} - -// writeToResultChan safe-writes with panic-recover so one write-fail will not have big impact. -func (*statsSyncLoad) writeToResultChan(resultCh chan stmtctx.StatsLoadResult, rs stmtctx.StatsLoadResult) { - defer func() { - if r := recover(); r != nil { - logutil.BgLogger().Error("writeToResultChan panicked", zap.Any("error", r), zap.Stack("stack")) - } - }() - select { - case resultCh <- rs: - default: - } -} - -// updateCachedItem updates the column/index hist to global statsCache. 
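drainColTask above prefers NeededItemsCh over TimeoutItemsCh: it selects on the high-priority channel first, and when a timed-out task is picked up it re-checks the high-priority channel before committing, requeueing the low-priority task if fresher work arrived. A condensed sketch of that two-level priority select; the string task type and channel names are stand-ins.

package main

import "fmt"

// drainWithPriority mirrors the shape of drainColTask above: prefer high,
// and when a low-priority task is drawn, give high one more chance and
// requeue the low task (dropping it if the queue is full).
func drainWithPriority(high, low chan string, exit chan struct{}) (string, bool) {
	for {
		select {
		case <-exit:
			return "", false
		case t := <-high:
			return t, true
		case t := <-low:
			select {
			case t0 := <-high:
				select {
				case low <- t: // put the timed-out task back for later
				default: // queue full: it is acceptable to drop it
				}
				return t0, true
			default:
				return t, true
			}
		}
	}
}

func main() {
	high := make(chan string, 4)
	low := make(chan string, 4)
	exit := make(chan struct{})
	low <- "timed-out task"
	high <- "fresh task"
	t, _ := drainWithPriority(high, low, exit)
	fmt.Println("served first:", t) // always the fresh task, even though the timed-out one was queued earlier
}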
-func (s *statsSyncLoad) updateCachedItem(item model.TableItemID, colHist *statistics.Column, idxHist *statistics.Index, fullLoaded bool) (updated bool) { - s.StatsLoad.Lock() - defer s.StatsLoad.Unlock() - // Reload the latest stats cache, otherwise the `updateStatsCache` may fail with high probability, because functions - // like `GetPartitionStats` called in `fmSketchFromStorage` would have modified the stats cache already. - tbl, ok := s.statsHandle.Get(item.TableID) - if !ok { - return false - } - if !item.IsIndex && colHist != nil { - c := tbl.GetCol(item.ID) - // - If the stats is fully loaded, - // - If the stats is meta-loaded and we also just need the meta. - if c != nil && (c.IsFullLoad() || !fullLoaded) { - return false - } - tbl = tbl.Copy() - tbl.SetCol(item.ID, colHist) - // If the column is analyzed we refresh the map for the possible change. - if colHist.StatsAvailable() { - tbl.ColAndIdxExistenceMap.InsertCol(item.ID, colHist.Info, true) - } - // All the objects shares the same stats version. Update it here. - if colHist.StatsVer != statistics.Version0 { - tbl.StatsVer = statistics.Version0 - } - } else if item.IsIndex && idxHist != nil { - index := tbl.GetIdx(item.ID) - // - If the stats is fully loaded, - // - If the stats is meta-loaded and we also just need the meta. - if index != nil && (index.IsFullLoad() || !fullLoaded) { - return true - } - tbl = tbl.Copy() - tbl.SetIdx(item.ID, idxHist) - // If the index is analyzed we refresh the map for the possible change. - if idxHist.IsAnalyzed() { - tbl.ColAndIdxExistenceMap.InsertIndex(item.ID, idxHist.Info, true) - // All the objects shares the same stats version. Update it here. - tbl.StatsVer = statistics.Version0 - } - } - s.statsHandle.UpdateStatsCache([]*statistics.Table{tbl}, nil) - return true -} diff --git a/pkg/store/copr/batch_coprocessor.go b/pkg/store/copr/batch_coprocessor.go index 9b8d1416f5798..850204fb9b168 100644 --- a/pkg/store/copr/batch_coprocessor.go +++ b/pkg/store/copr/batch_coprocessor.go @@ -747,7 +747,7 @@ func buildBatchCopTasksConsistentHash( } func failpointCheckForConsistentHash(tasks []*batchCopTask) { - if val, _err_ := failpoint.Eval(_curpkg_("checkOnlyDispatchToTiFlashComputeNodes")); _err_ == nil { + failpoint.Inject("checkOnlyDispatchToTiFlashComputeNodes", func(val failpoint.Value) { logutil.BgLogger().Debug("in checkOnlyDispatchToTiFlashComputeNodes") // This failpoint will be tested in test-infra case, because we needs setup a cluster. 
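The hunks in this file follow the same shape as the syncload ones above: the generated `failpoint.Eval(_curpkg_(...))` form goes back to the `failpoint.Inject` marker that failpoint-ctl expands. A toy example of the marker form; the function and failpoint name below are hypothetical.

package main

import (
	"errors"
	"fmt"

	"github.com/pingcap/failpoint"
)

// fetchRows shows the marker form restored by this patch; failpoint-ctl
// rewrites Inject/Return into the Eval/_curpkg_ form seen on the "-" lines.
func fetchRows() ([]int, error) {
	failpoint.Inject("mockFetchRowsFail", func(val failpoint.Value) {
		if val.(bool) {
			// When the failpoint is enabled with a true value, the rewritten
			// code returns this error from fetchRows, as in readStatsForOneItem.
			failpoint.Return(nil, errors.New("injected fetchRows error"))
		}
	})
	return []int{1, 2, 3}, nil
}

func main() {
	rows, err := fetchRows()
	fmt.Println(rows, err)
}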
@@ -768,18 +768,18 @@ func failpointCheckForConsistentHash(tasks []*batchCopTask) { panic(err) } } - } + }) } func failpointCheckWhichPolicy(act tiflashcompute.DispatchPolicy) { - if exp, _err_ := failpoint.Eval(_curpkg_("testWhichDispatchPolicy")); _err_ == nil { + failpoint.Inject("testWhichDispatchPolicy", func(exp failpoint.Value) { expStr := exp.(string) actStr := tiflashcompute.GetDispatchPolicy(act) if actStr != expStr { err := errors.Errorf("tiflash_compute dispatch should be %v, but got %v", expStr, actStr) panic(err) } - } + }) } func filterAllStoresAccordingToTiFlashReplicaRead(allStores []uint64, aliveStores *aliveStoresBundle, policy tiflash.ReplicaRead) (storesMatchedPolicy []uint64, needsCrossZoneAccess bool) { @@ -1174,11 +1174,11 @@ func (b *batchCopIterator) run(ctx context.Context) { for _, task := range b.tasks { b.wg.Add(1) boMaxSleep := CopNextMaxBackoff - if value, _err_ := failpoint.Eval(_curpkg_("ReduceCopNextMaxBackoff")); _err_ == nil { + failpoint.Inject("ReduceCopNextMaxBackoff", func(value failpoint.Value) { if value.(bool) { boMaxSleep = 2 } - } + }) bo := backoff.NewBackofferWithVars(ctx, boMaxSleep, b.vars) go b.handleTask(ctx, bo, task) } diff --git a/pkg/store/copr/batch_coprocessor.go__failpoint_stash__ b/pkg/store/copr/batch_coprocessor.go__failpoint_stash__ deleted file mode 100644 index 850204fb9b168..0000000000000 --- a/pkg/store/copr/batch_coprocessor.go__failpoint_stash__ +++ /dev/null @@ -1,1588 +0,0 @@ -// Copyright 2020 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package copr - -import ( - "bytes" - "cmp" - "context" - "fmt" - "io" - "math" - "math/rand" - "slices" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/coprocessor" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/log" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/ddl/placement" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/store/driver/backoff" - derr "github.com/pingcap/tidb/pkg/store/driver/error" - "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/tiflash" - "github.com/pingcap/tidb/pkg/util/tiflashcompute" - "github.com/tikv/client-go/v2/metrics" - "github.com/tikv/client-go/v2/tikv" - "github.com/tikv/client-go/v2/tikvrpc" - "github.com/twmb/murmur3" - "go.uber.org/zap" -) - -const fetchTopoMaxBackoff = 20000 - -// batchCopTask comprises of multiple copTask that will send to same store. -type batchCopTask struct { - storeAddr string - cmdType tikvrpc.CmdType - ctx *tikv.RPCContext - - regionInfos []RegionInfo // region info for single physical table - // PartitionTableRegions indicates region infos for each partition table, used by scanning partitions in batch. - // Thus, one of `regionInfos` and `PartitionTableRegions` must be nil. 
- PartitionTableRegions []*coprocessor.TableRegions -} - -type batchCopResponse struct { - pbResp *coprocessor.BatchResponse - detail *CopRuntimeStats - - // batch Cop Response is yet to return startKey. So batchCop cannot retry partially. - startKey kv.Key - err error - respSize int64 - respTime time.Duration -} - -// GetData implements the kv.ResultSubset GetData interface. -func (rs *batchCopResponse) GetData() []byte { - return rs.pbResp.Data -} - -// GetStartKey implements the kv.ResultSubset GetStartKey interface. -func (rs *batchCopResponse) GetStartKey() kv.Key { - return rs.startKey -} - -// GetExecDetails is unavailable currently, because TiFlash has not collected exec details for batch cop. -// TODO: Will fix in near future. -func (rs *batchCopResponse) GetCopRuntimeStats() *CopRuntimeStats { - return rs.detail -} - -// MemSize returns how many bytes of memory this response use -func (rs *batchCopResponse) MemSize() int64 { - if rs.respSize != 0 { - return rs.respSize - } - - // ignore rs.err - rs.respSize += int64(cap(rs.startKey)) - if rs.detail != nil { - rs.respSize += int64(sizeofExecDetails) - } - if rs.pbResp != nil { - // Using a approximate size since it's hard to get a accurate value. - rs.respSize += int64(rs.pbResp.Size()) - } - return rs.respSize -} - -func (rs *batchCopResponse) RespTime() time.Duration { - return rs.respTime -} - -func deepCopyStoreTaskMap(storeTaskMap map[uint64]*batchCopTask) map[uint64]*batchCopTask { - storeTasks := make(map[uint64]*batchCopTask) - for storeID, task := range storeTaskMap { - t := batchCopTask{ - storeAddr: task.storeAddr, - cmdType: task.cmdType, - ctx: task.ctx, - } - t.regionInfos = make([]RegionInfo, len(task.regionInfos)) - copy(t.regionInfos, task.regionInfos) - storeTasks[storeID] = &t - } - return storeTasks -} - -func regionTotalCount(storeTasks map[uint64]*batchCopTask, candidateRegionInfos []RegionInfo) int { - count := len(candidateRegionInfos) - for _, task := range storeTasks { - count += len(task.regionInfos) - } - return count -} - -const ( - maxBalanceScore = 100 - balanceScoreThreshold = 85 -) - -// Select at most cnt RegionInfos from candidateRegionInfos that belong to storeID. -// If selected[i] is true, candidateRegionInfos[i] has been selected and should be skip. -// storeID2RegionIndex is a map that key is storeID and value is a region index slice. -// selectRegion use storeID2RegionIndex to find RegionInfos that belong to storeID efficiently. -func selectRegion(storeID uint64, candidateRegionInfos []RegionInfo, selected []bool, storeID2RegionIndex map[uint64][]int, cnt int64) []RegionInfo { - regionIndexes, ok := storeID2RegionIndex[storeID] - if !ok { - logutil.BgLogger().Error("selectRegion: storeID2RegionIndex not found", zap.Uint64("storeID", storeID)) - return nil - } - var regionInfos []RegionInfo - i := 0 - for ; i < len(regionIndexes) && len(regionInfos) < int(cnt); i++ { - idx := regionIndexes[i] - if selected[idx] { - continue - } - selected[idx] = true - regionInfos = append(regionInfos, candidateRegionInfos[idx]) - } - // Remove regions that has been selected. 
- storeID2RegionIndex[storeID] = regionIndexes[i:] - return regionInfos -} - -// Higher scores mean more balance: (100 - unblance percentage) -func balanceScore(maxRegionCount, minRegionCount int, balanceContinuousRegionCount int64) int { - if minRegionCount <= 0 { - return math.MinInt32 - } - unbalanceCount := maxRegionCount - minRegionCount - if unbalanceCount <= int(balanceContinuousRegionCount) { - return maxBalanceScore - } - return maxBalanceScore - unbalanceCount*100/minRegionCount -} - -func isBalance(score int) bool { - return score >= balanceScoreThreshold -} - -func checkBatchCopTaskBalance(storeTasks map[uint64]*batchCopTask, balanceContinuousRegionCount int64) (int, []string) { - if len(storeTasks) == 0 { - return 0, []string{} - } - maxRegionCount := 0 - minRegionCount := math.MaxInt32 - balanceInfos := []string{} - for storeID, task := range storeTasks { - cnt := len(task.regionInfos) - if cnt > maxRegionCount { - maxRegionCount = cnt - } - if cnt < minRegionCount { - minRegionCount = cnt - } - balanceInfos = append(balanceInfos, fmt.Sprintf("storeID %d storeAddr %s regionCount %d", storeID, task.storeAddr, cnt)) - } - return balanceScore(maxRegionCount, minRegionCount, balanceContinuousRegionCount), balanceInfos -} - -// balanceBatchCopTaskWithContinuity try to balance `continuous regions` between TiFlash Stores. -// In fact, not absolutely continuous is required, regions' range are closed to store in a TiFlash segment is enough for internal read optimization. -// -// First, sort candidateRegionInfos by their key ranges. -// Second, build a storeID2RegionIndex data structure to fastly locate regions of a store (avoid scanning candidateRegionInfos repeatedly). -// Third, each store will take balanceContinuousRegionCount from the sorted candidateRegionInfos. These regions are stored very close to each other in TiFlash. -// Fourth, if the region count is not balance between TiFlash, it may fallback to the original balance logic. -func balanceBatchCopTaskWithContinuity(storeTaskMap map[uint64]*batchCopTask, candidateRegionInfos []RegionInfo, balanceContinuousRegionCount int64) ([]*batchCopTask, int) { - if len(candidateRegionInfos) < 500 { - return nil, 0 - } - funcStart := time.Now() - regionCount := regionTotalCount(storeTaskMap, candidateRegionInfos) - storeTasks := deepCopyStoreTaskMap(storeTaskMap) - - // Sort regions by their key ranges. - slices.SortFunc(candidateRegionInfos, func(i, j RegionInfo) int { - // Special case: Sort empty ranges to the end. - if i.Ranges.Len() < 1 || j.Ranges.Len() < 1 { - return cmp.Compare(j.Ranges.Len(), i.Ranges.Len()) - } - // StartKey0 < StartKey1 - return bytes.Compare(i.Ranges.At(0).StartKey, j.Ranges.At(0).StartKey) - }) - - balanceStart := time.Now() - // Build storeID -> region index slice index and we can fastly locate regions of a store. - storeID2RegionIndex := make(map[uint64][]int) - for i, ri := range candidateRegionInfos { - for _, storeID := range ri.AllStores { - if val, ok := storeID2RegionIndex[storeID]; ok { - storeID2RegionIndex[storeID] = append(val, i) - } else { - storeID2RegionIndex[storeID] = []int{i} - } - } - } - - // If selected[i] is true, candidateRegionInfos[i] is selected by a store and should skip it in selectRegion. - selected := make([]bool, len(candidateRegionInfos)) - for { - totalCount := 0 - selectCountThisRound := 0 - for storeID, task := range storeTasks { - // Each store select balanceContinuousRegionCount regions from candidateRegionInfos. 
- // Since candidateRegionInfos is sorted, it is very likely that these regions are close to each other in TiFlash. - regionInfo := selectRegion(storeID, candidateRegionInfos, selected, storeID2RegionIndex, balanceContinuousRegionCount) - task.regionInfos = append(task.regionInfos, regionInfo...) - totalCount += len(task.regionInfos) - selectCountThisRound += len(regionInfo) - } - if totalCount >= regionCount { - break - } - if selectCountThisRound == 0 { - logutil.BgLogger().Error("selectCandidateRegionInfos fail: some region cannot find relevant store.", zap.Int("regionCount", regionCount), zap.Int("candidateCount", len(candidateRegionInfos))) - return nil, 0 - } - } - balanceEnd := time.Now() - - score, balanceInfos := checkBatchCopTaskBalance(storeTasks, balanceContinuousRegionCount) - if !isBalance(score) { - logutil.BgLogger().Warn("balanceBatchCopTaskWithContinuity is not balance", zap.Int("score", score), zap.Strings("balanceInfos", balanceInfos)) - } - - totalCount := 0 - var res []*batchCopTask - for _, task := range storeTasks { - totalCount += len(task.regionInfos) - if len(task.regionInfos) > 0 { - res = append(res, task) - } - } - if totalCount != regionCount { - logutil.BgLogger().Error("balanceBatchCopTaskWithContinuity error", zap.Int("totalCount", totalCount), zap.Int("regionCount", regionCount)) - return nil, 0 - } - - logutil.BgLogger().Debug("balanceBatchCopTaskWithContinuity time", - zap.Int("candidateRegionCount", len(candidateRegionInfos)), - zap.Int64("balanceContinuousRegionCount", balanceContinuousRegionCount), - zap.Int("balanceScore", score), - zap.Duration("balanceTime", balanceEnd.Sub(balanceStart)), - zap.Duration("totalTime", time.Since(funcStart))) - - return res, score -} - -// balanceBatchCopTask balance the regions between available stores, the basic rule is -// 1. the first region of each original batch cop task belongs to its original store because some -// meta data(like the rpc context) in batchCopTask is related to it -// 2. for the remaining regions: -// if there is only 1 available store, then put the region to the related store -// otherwise, these region will be balance between TiFlash stores. -// -// Currently, there are two balance strategies. -// The first balance strategy: use a greedy algorithm to put it into the store with highest weight. This strategy only consider the region count between TiFlash stores. -// -// The second balance strategy: Not only consider the region count between TiFlash stores, but also try to make the regions' range continuous(stored in TiFlash closely). -// If balanceWithContinuity is true, the second balance strategy is enable. -func balanceBatchCopTask(aliveStores []*tikv.Store, originalTasks []*batchCopTask, balanceWithContinuity bool, balanceContinuousRegionCount int64) []*batchCopTask { - if len(originalTasks) == 0 { - log.Info("Batch cop task balancer got an empty task set.") - return originalTasks - } - storeTaskMap := make(map[uint64]*batchCopTask) - // storeCandidateRegionMap stores all the possible store->region map. Its content is - // store id -> region signature -> region info. We can see it as store id -> region lists. 
- storeCandidateRegionMap := make(map[uint64]map[string]RegionInfo) - totalRegionCandidateNum := 0 - totalRemainingRegionNum := 0 - - for _, s := range aliveStores { - storeTaskMap[s.StoreID()] = &batchCopTask{ - storeAddr: s.GetAddr(), - cmdType: originalTasks[0].cmdType, - ctx: &tikv.RPCContext{Addr: s.GetAddr(), Store: s}, - } - } - - var candidateRegionInfos []RegionInfo - for _, task := range originalTasks { - for _, ri := range task.regionInfos { - // for each region, figure out the valid store num - validStoreNum := 0 - var validStoreID uint64 - for _, storeID := range ri.AllStores { - if _, ok := storeTaskMap[storeID]; ok { - validStoreNum++ - // original store id might be invalid, so we have to set it again. - validStoreID = storeID - } - } - if validStoreNum == 0 { - logutil.BgLogger().Warn("Meet regions that don't have an available store. Give up balancing") - return originalTasks - } else if validStoreNum == 1 { - // if only one store is valid, just put it to storeTaskMap - storeTaskMap[validStoreID].regionInfos = append(storeTaskMap[validStoreID].regionInfos, ri) - } else { - // if more than one store is valid, put the region - // to store candidate map - totalRegionCandidateNum += validStoreNum - totalRemainingRegionNum++ - candidateRegionInfos = append(candidateRegionInfos, ri) - taskKey := ri.Region.String() - for _, storeID := range ri.AllStores { - if _, validStore := storeTaskMap[storeID]; !validStore { - continue - } - if _, ok := storeCandidateRegionMap[storeID]; !ok { - candidateMap := make(map[string]RegionInfo) - storeCandidateRegionMap[storeID] = candidateMap - } - if _, duplicateRegion := storeCandidateRegionMap[storeID][taskKey]; duplicateRegion { - // duplicated region, should not happen, just give up balance - logutil.BgLogger().Warn("Meet duplicated region info during when trying to balance batch cop task, give up balancing") - return originalTasks - } - storeCandidateRegionMap[storeID][taskKey] = ri - } - } - } - } - - // If balanceBatchCopTaskWithContinuity failed (not balance or return nil), it will fallback to the original balance logic. - // So storeTaskMap should not be modify. 
- var contiguousTasks []*batchCopTask = nil - contiguousBalanceScore := 0 - if balanceWithContinuity { - contiguousTasks, contiguousBalanceScore = balanceBatchCopTaskWithContinuity(storeTaskMap, candidateRegionInfos, balanceContinuousRegionCount) - if isBalance(contiguousBalanceScore) && contiguousTasks != nil { - return contiguousTasks - } - } - - if totalRemainingRegionNum > 0 { - avgStorePerRegion := float64(totalRegionCandidateNum) / float64(totalRemainingRegionNum) - findNextStore := func(candidateStores []uint64) uint64 { - store := uint64(math.MaxUint64) - weightedRegionNum := math.MaxFloat64 - if candidateStores != nil { - for _, storeID := range candidateStores { - if _, validStore := storeCandidateRegionMap[storeID]; !validStore { - continue - } - num := float64(len(storeCandidateRegionMap[storeID]))/avgStorePerRegion + float64(len(storeTaskMap[storeID].regionInfos)) - if num < weightedRegionNum { - store = storeID - weightedRegionNum = num - } - } - if store != uint64(math.MaxUint64) { - return store - } - } - for storeID := range storeTaskMap { - if _, validStore := storeCandidateRegionMap[storeID]; !validStore { - continue - } - num := float64(len(storeCandidateRegionMap[storeID]))/avgStorePerRegion + float64(len(storeTaskMap[storeID].regionInfos)) - if num < weightedRegionNum { - store = storeID - weightedRegionNum = num - } - } - return store - } - - store := findNextStore(nil) - for totalRemainingRegionNum > 0 { - if store == uint64(math.MaxUint64) { - break - } - var key string - var ri RegionInfo - for key, ri = range storeCandidateRegionMap[store] { - // get the first region - break - } - storeTaskMap[store].regionInfos = append(storeTaskMap[store].regionInfos, ri) - totalRemainingRegionNum-- - for _, id := range ri.AllStores { - if _, ok := storeCandidateRegionMap[id]; ok { - delete(storeCandidateRegionMap[id], key) - totalRegionCandidateNum-- - if len(storeCandidateRegionMap[id]) == 0 { - delete(storeCandidateRegionMap, id) - } - } - } - if totalRemainingRegionNum > 0 { - avgStorePerRegion = float64(totalRegionCandidateNum) / float64(totalRemainingRegionNum) - // it is not optimal because we only check the stores that affected by this region, in fact in order - // to find out the store with the lowest weightedRegionNum, all stores should be checked, but I think - // check only the affected stores is more simple and will get a good enough result - store = findNextStore(ri.AllStores) - } - } - if totalRemainingRegionNum > 0 { - logutil.BgLogger().Warn("Some regions are not used when trying to balance batch cop task, give up balancing") - return originalTasks - } - } - - if contiguousTasks != nil { - score, balanceInfos := checkBatchCopTaskBalance(storeTaskMap, balanceContinuousRegionCount) - if !isBalance(score) { - logutil.BgLogger().Warn("Region count is not balance and use contiguousTasks", zap.Int("contiguousBalanceScore", contiguousBalanceScore), zap.Int("score", score), zap.Strings("balanceInfos", balanceInfos)) - return contiguousTasks - } - } - - var ret []*batchCopTask - for _, task := range storeTaskMap { - if len(task.regionInfos) > 0 { - ret = append(ret, task) - } - } - return ret -} - -func buildBatchCopTasksForNonPartitionedTable( - ctx context.Context, - bo *backoff.Backoffer, - store *kvStore, - ranges *KeyRanges, - storeType kv.StoreType, - isMPP bool, - ttl time.Duration, - balanceWithContinuity bool, - balanceContinuousRegionCount int64, - dispatchPolicy tiflashcompute.DispatchPolicy, - tiflashReplicaReadPolicy tiflash.ReplicaRead, - appendWarning 
func(error)) ([]*batchCopTask, error) { - if config.GetGlobalConfig().DisaggregatedTiFlash { - if config.GetGlobalConfig().UseAutoScaler { - return buildBatchCopTasksConsistentHash(ctx, bo, store, []*KeyRanges{ranges}, storeType, ttl, dispatchPolicy) - } - return buildBatchCopTasksConsistentHashForPD(bo, store, []*KeyRanges{ranges}, storeType, ttl, dispatchPolicy) - } - return buildBatchCopTasksCore(bo, store, []*KeyRanges{ranges}, storeType, isMPP, ttl, balanceWithContinuity, balanceContinuousRegionCount, tiflashReplicaReadPolicy, appendWarning) -} - -func buildBatchCopTasksForPartitionedTable( - ctx context.Context, - bo *backoff.Backoffer, - store *kvStore, - rangesForEachPhysicalTable []*KeyRanges, - storeType kv.StoreType, - isMPP bool, - ttl time.Duration, - balanceWithContinuity bool, - balanceContinuousRegionCount int64, - partitionIDs []int64, - dispatchPolicy tiflashcompute.DispatchPolicy, - tiflashReplicaReadPolicy tiflash.ReplicaRead, - appendWarning func(error)) (batchTasks []*batchCopTask, err error) { - if config.GetGlobalConfig().DisaggregatedTiFlash { - if config.GetGlobalConfig().UseAutoScaler { - batchTasks, err = buildBatchCopTasksConsistentHash(ctx, bo, store, rangesForEachPhysicalTable, storeType, ttl, dispatchPolicy) - } else { - // todo: remove this after AutoScaler is stable. - batchTasks, err = buildBatchCopTasksConsistentHashForPD(bo, store, rangesForEachPhysicalTable, storeType, ttl, dispatchPolicy) - } - } else { - batchTasks, err = buildBatchCopTasksCore(bo, store, rangesForEachPhysicalTable, storeType, isMPP, ttl, balanceWithContinuity, balanceContinuousRegionCount, tiflashReplicaReadPolicy, appendWarning) - } - if err != nil { - return nil, err - } - // generate tableRegions for batchCopTasks - convertRegionInfosToPartitionTableRegions(batchTasks, partitionIDs) - return batchTasks, nil -} - -func filterAliveStoresStr(ctx context.Context, storesStr []string, ttl time.Duration, kvStore *kvStore) (aliveStores []string) { - aliveIdx := filterAliveStoresHelper(ctx, storesStr, ttl, kvStore) - for _, idx := range aliveIdx { - aliveStores = append(aliveStores, storesStr[idx]) - } - return aliveStores -} - -func filterAliveStores(ctx context.Context, stores []*tikv.Store, ttl time.Duration, kvStore *kvStore) (aliveStores []*tikv.Store) { - storesStr := make([]string, 0, len(stores)) - for _, s := range stores { - storesStr = append(storesStr, s.GetAddr()) - } - - aliveIdx := filterAliveStoresHelper(ctx, storesStr, ttl, kvStore) - for _, idx := range aliveIdx { - aliveStores = append(aliveStores, stores[idx]) - } - return aliveStores -} - -func filterAliveStoresHelper(ctx context.Context, stores []string, ttl time.Duration, kvStore *kvStore) (aliveIdx []int) { - var wg sync.WaitGroup - var mu sync.Mutex - wg.Add(len(stores)) - for i := range stores { - go func(idx int) { - defer wg.Done() - s := stores[idx] - - // Check if store is failed already. 
- if ok := GlobalMPPFailedStoreProber.IsRecovery(ctx, s, ttl); !ok { - return - } - - tikvClient := kvStore.GetTiKVClient() - if ok := detectMPPStore(ctx, tikvClient, s, DetectTimeoutLimit); !ok { - GlobalMPPFailedStoreProber.Add(ctx, s, tikvClient) - return - } - - mu.Lock() - defer mu.Unlock() - aliveIdx = append(aliveIdx, idx) - }(i) - } - wg.Wait() - - logutil.BgLogger().Info("detecting available mpp stores", zap.Any("total", len(stores)), zap.Any("alive", len(aliveIdx))) - return aliveIdx -} - -func getTiFlashComputeRPCContextByConsistentHash(ids []tikv.RegionVerID, storesStr []string) (res []*tikv.RPCContext, err error) { - // Use RendezvousHash - for _, id := range ids { - var maxHash uint32 = 0 - var maxHashStore string = "" - for _, store := range storesStr { - h := murmur3.StringSum32(fmt.Sprintf("%s-%d", store, id.GetID())) - if h > maxHash { - maxHash = h - maxHashStore = store - } - } - rpcCtx := &tikv.RPCContext{ - Region: id, - Addr: maxHashStore, - } - res = append(res, rpcCtx) - } - return res, nil -} - -func getTiFlashComputeRPCContextByRoundRobin(ids []tikv.RegionVerID, storesStr []string) (res []*tikv.RPCContext, err error) { - startIdx := rand.Intn(len(storesStr)) - for _, id := range ids { - rpcCtx := &tikv.RPCContext{ - Region: id, - Addr: storesStr[startIdx%len(storesStr)], - } - - startIdx++ - res = append(res, rpcCtx) - } - return res, nil -} - -// 1. Split range by region location to build copTasks. -// 2. For each copTask build its rpcCtx , the target tiflash_compute node will be chosen using consistent hash. -// 3. All copTasks that will be sent to one tiflash_compute node are put in one batchCopTask. -func buildBatchCopTasksConsistentHash( - ctx context.Context, - bo *backoff.Backoffer, - kvStore *kvStore, - rangesForEachPhysicalTable []*KeyRanges, - storeType kv.StoreType, - ttl time.Duration, - dispatchPolicy tiflashcompute.DispatchPolicy) (res []*batchCopTask, err error) { - failpointCheckWhichPolicy(dispatchPolicy) - start := time.Now() - const cmdType = tikvrpc.CmdBatchCop - cache := kvStore.GetRegionCache() - fetchTopoBo := backoff.NewBackofferWithVars(ctx, fetchTopoMaxBackoff, nil) - - var ( - retryNum int - rangesLen int - storesStr []string - ) - - tasks := make([]*copTask, 0) - regionIDs := make([]tikv.RegionVerID, 0) - - for i, ranges := range rangesForEachPhysicalTable { - rangesLen += ranges.Len() - locations, err := cache.SplitKeyRangesByLocations(bo, ranges, UnspecifiedLimit, false, false) - if err != nil { - return nil, errors.Trace(err) - } - for _, lo := range locations { - tasks = append(tasks, &copTask{ - region: lo.Location.Region, - ranges: lo.Ranges, - cmdType: cmdType, - storeType: storeType, - partitionIndex: int64(i), - }) - regionIDs = append(regionIDs, lo.Location.Region) - } - } - splitKeyElapsed := time.Since(start) - - fetchTopoStart := time.Now() - for { - retryNum++ - storesStr, err = tiflashcompute.GetGlobalTopoFetcher().FetchAndGetTopo() - if err != nil { - return nil, err - } - storesBefFilter := len(storesStr) - storesStr = filterAliveStoresStr(ctx, storesStr, ttl, kvStore) - logutil.BgLogger().Info("topo filter alive", zap.Any("topo", storesStr)) - if len(storesStr) == 0 { - errMsg := "Cannot find proper topo to dispatch MPPTask: " - if storesBefFilter == 0 { - errMsg += "topo from AutoScaler is empty" - } else { - errMsg += "detect aliveness failed, no alive ComputeNode" - } - retErr := errors.New(errMsg) - logutil.BgLogger().Info("buildBatchCopTasksConsistentHash retry because FetchAndGetTopo return empty topo", 
zap.Int("retryNum", retryNum)) - if intest.InTest && retryNum > 3 { - return nil, retErr - } - err := fetchTopoBo.Backoff(tikv.BoTiFlashRPC(), retErr) - if err != nil { - return nil, retErr - } - continue - } - break - } - fetchTopoElapsed := time.Since(fetchTopoStart) - - var rpcCtxs []*tikv.RPCContext - if dispatchPolicy == tiflashcompute.DispatchPolicyRR { - rpcCtxs, err = getTiFlashComputeRPCContextByRoundRobin(regionIDs, storesStr) - } else if dispatchPolicy == tiflashcompute.DispatchPolicyConsistentHash { - rpcCtxs, err = getTiFlashComputeRPCContextByConsistentHash(regionIDs, storesStr) - } else { - err = errors.Errorf("unexpected dispatch policy %v", dispatchPolicy) - } - if err != nil { - return nil, err - } - if len(rpcCtxs) != len(tasks) { - return nil, errors.Errorf("length should be equal, len(rpcCtxs): %d, len(tasks): %d", len(rpcCtxs), len(tasks)) - } - taskMap := make(map[string]*batchCopTask) - for i, rpcCtx := range rpcCtxs { - regionInfo := RegionInfo{ - // tasks and rpcCtxs are correspond to each other. - Region: tasks[i].region, - Ranges: tasks[i].ranges, - PartitionIndex: tasks[i].partitionIndex, - } - if batchTask, ok := taskMap[rpcCtx.Addr]; ok { - batchTask.regionInfos = append(batchTask.regionInfos, regionInfo) - } else { - batchTask := &batchCopTask{ - storeAddr: rpcCtx.Addr, - cmdType: cmdType, - ctx: rpcCtx, - regionInfos: []RegionInfo{regionInfo}, - } - taskMap[rpcCtx.Addr] = batchTask - res = append(res, batchTask) - } - } - logutil.BgLogger().Info("buildBatchCopTasksConsistentHash done", - zap.Any("len(tasks)", len(taskMap)), - zap.Any("len(tiflash_compute)", len(storesStr)), - zap.Any("dispatchPolicy", tiflashcompute.GetDispatchPolicy(dispatchPolicy))) - - if log.GetLevel() <= zap.DebugLevel { - debugTaskMap := make(map[string]string, len(taskMap)) - for s, b := range taskMap { - debugTaskMap[s] = fmt.Sprintf("addr: %s; regionInfos: %v", b.storeAddr, b.regionInfos) - } - logutil.BgLogger().Debug("detailed info buildBatchCopTasksConsistentHash", zap.Any("taskMap", debugTaskMap), zap.Any("allStores", storesStr)) - } - - if elapsed := time.Since(start); elapsed > time.Millisecond*500 { - logutil.BgLogger().Warn("buildBatchCopTasksConsistentHash takes too much time", - zap.Duration("total elapsed", elapsed), - zap.Int("retryNum", retryNum), - zap.Duration("splitKeyElapsed", splitKeyElapsed), - zap.Duration("fetchTopoElapsed", fetchTopoElapsed), - zap.Int("range len", rangesLen), - zap.Int("copTaskNum", len(tasks)), - zap.Int("batchCopTaskNum", len(res))) - } - failpointCheckForConsistentHash(res) - return res, nil -} - -func failpointCheckForConsistentHash(tasks []*batchCopTask) { - failpoint.Inject("checkOnlyDispatchToTiFlashComputeNodes", func(val failpoint.Value) { - logutil.BgLogger().Debug("in checkOnlyDispatchToTiFlashComputeNodes") - - // This failpoint will be tested in test-infra case, because we needs setup a cluster. - // All tiflash_compute nodes addrs are stored in val, separated by semicolon. 
- str := val.(string) - addrs := strings.Split(str, ";") - if len(addrs) < 1 { - err := fmt.Sprintf("unexpected length of tiflash_compute node addrs: %v, %s", len(addrs), str) - panic(err) - } - addrMap := make(map[string]struct{}) - for _, addr := range addrs { - addrMap[addr] = struct{}{} - } - for _, batchTask := range tasks { - if _, ok := addrMap[batchTask.storeAddr]; !ok { - err := errors.Errorf("batchCopTask send to node which is not tiflash_compute: %v(tiflash_compute nodes: %s)", batchTask.storeAddr, str) - panic(err) - } - } - }) -} - -func failpointCheckWhichPolicy(act tiflashcompute.DispatchPolicy) { - failpoint.Inject("testWhichDispatchPolicy", func(exp failpoint.Value) { - expStr := exp.(string) - actStr := tiflashcompute.GetDispatchPolicy(act) - if actStr != expStr { - err := errors.Errorf("tiflash_compute dispatch should be %v, but got %v", expStr, actStr) - panic(err) - } - }) -} - -func filterAllStoresAccordingToTiFlashReplicaRead(allStores []uint64, aliveStores *aliveStoresBundle, policy tiflash.ReplicaRead) (storesMatchedPolicy []uint64, needsCrossZoneAccess bool) { - if policy.IsAllReplicas() { - for _, id := range allStores { - if _, ok := aliveStores.storeIDsInAllZones[id]; ok { - storesMatchedPolicy = append(storesMatchedPolicy, id) - } - } - return - } - // Check whether exists available stores in TiDB zone. If so, we only need to access TiFlash stores in TiDB zone. - for _, id := range allStores { - if _, ok := aliveStores.storeIDsInTiDBZone[id]; ok { - storesMatchedPolicy = append(storesMatchedPolicy, id) - } - } - // If no available stores in TiDB zone, we need to access TiFlash stores in other zones. - if len(storesMatchedPolicy) == 0 { - // needsCrossZoneAccess indicates whether we need to access(directly read or remote read) TiFlash stores in other zones. - needsCrossZoneAccess = true - - if policy == tiflash.ClosestAdaptive { - // If the policy is `ClosestAdaptive`, we can dispatch tasks to the TiFlash stores in other zones. - for _, id := range allStores { - if _, ok := aliveStores.storeIDsInAllZones[id]; ok { - storesMatchedPolicy = append(storesMatchedPolicy, id) - } - } - } else if policy == tiflash.ClosestReplicas { - // If the policy is `ClosestReplicas`, we dispatch tasks to the TiFlash stores in TiDB zone and remote read from other zones. - for id := range aliveStores.storeIDsInTiDBZone { - storesMatchedPolicy = append(storesMatchedPolicy, id) - } - } - } - return -} - -func getAllUsedTiFlashStores(allTiFlashStores []*tikv.Store, allUsedTiFlashStoresMap map[uint64]struct{}) []*tikv.Store { - allUsedTiFlashStores := make([]*tikv.Store, 0, len(allUsedTiFlashStoresMap)) - for _, store := range allTiFlashStores { - _, ok := allUsedTiFlashStoresMap[store.StoreID()] - if ok { - allUsedTiFlashStores = append(allUsedTiFlashStores, store) - } - } - return allUsedTiFlashStores -} - -// getAliveStoresAndStoreIDs gets alive TiFlash stores and their IDs. -// If tiflashReplicaReadPolicy is not all_replicas, it will also return the IDs of the alive TiFlash stores in TiDB zone. 
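getTiFlashComputeRPCContextByConsistentHash above implements rendezvous (highest-random-weight) hashing: each region is routed to the tiflash_compute node whose hash of "storeAddr-regionID" is largest, so dropping a node only remaps the regions that node owned. The standalone sketch below reproduces just that selection rule; the function name, store addresses, and the use of hash/fnv in place of the murmur3 hash in the real code are illustrative assumptions, not part of this patch.

// Minimal rendezvous-hash sketch, not the TiDB implementation: each region ID
// is routed to the address with the highest hash of "addr-regionID", mirroring
// the loop in getTiFlashComputeRPCContextByConsistentHash. hash/fnv is used
// here only to keep the example dependency-free.
package main

import (
	"fmt"
	"hash/fnv"
)

func pickByRendezvous(regionID uint64, addrs []string) string {
	var best string
	var bestHash uint32
	for _, addr := range addrs {
		h := fnv.New32a()
		fmt.Fprintf(h, "%s-%d", addr, regionID)
		if sum := h.Sum32(); best == "" || sum > bestHash {
			best, bestHash = addr, sum
		}
	}
	return best
}

func main() {
	stores := []string{"tiflash-0:3930", "tiflash-1:3930", "tiflash-2:3930"}
	for _, id := range []uint64{101, 102, 103} {
		fmt.Println(id, "->", pickByRendezvous(id, stores))
	}
}

Unlike a modulo-style assignment, removing one address from the list only remaps the regions that hashed to it, which keeps region-to-node affinity stable across topology refetches.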
-func getAliveStoresAndStoreIDs(ctx context.Context, cache *RegionCache, allUsedTiFlashStoresMap map[uint64]struct{}, ttl time.Duration, store *kvStore, tiflashReplicaReadPolicy tiflash.ReplicaRead, tidbZone string) (aliveStores *aliveStoresBundle) { - aliveStores = new(aliveStoresBundle) - allTiFlashStores := cache.RegionCache.GetTiFlashStores(tikv.LabelFilterNoTiFlashWriteNode) - allUsedTiFlashStores := getAllUsedTiFlashStores(allTiFlashStores, allUsedTiFlashStoresMap) - aliveStores.storesInAllZones = filterAliveStores(ctx, allUsedTiFlashStores, ttl, store) - - if !tiflashReplicaReadPolicy.IsAllReplicas() { - aliveStores.storeIDsInTiDBZone = make(map[uint64]struct{}, len(aliveStores.storesInAllZones)) - for _, as := range aliveStores.storesInAllZones { - // If the `zone` label of the TiFlash store is not set, we treat it as a TiFlash store in other zones. - if tiflashZone, isSet := as.GetLabelValue(placement.DCLabelKey); isSet && tiflashZone == tidbZone { - aliveStores.storeIDsInTiDBZone[as.StoreID()] = struct{}{} - aliveStores.storesInTiDBZone = append(aliveStores.storesInTiDBZone, as) - } - } - } - if !tiflashReplicaReadPolicy.IsClosestReplicas() { - aliveStores.storeIDsInAllZones = make(map[uint64]struct{}, len(aliveStores.storesInAllZones)) - for _, as := range aliveStores.storesInAllZones { - aliveStores.storeIDsInAllZones[as.StoreID()] = struct{}{} - } - } - return aliveStores -} - -// filterAccessibleStoresAndBuildRegionInfo filters the stores that can be accessed according to: -// 1. tiflash_replica_read policy -// 2. whether the store is alive -// After filtering, it will build the RegionInfo. -func filterAccessibleStoresAndBuildRegionInfo( - cache *RegionCache, - allStores []uint64, - bo *Backoffer, - task *copTask, - rpcCtx *tikv.RPCContext, - aliveStores *aliveStoresBundle, - tiflashReplicaReadPolicy tiflash.ReplicaRead, - regionInfoNeedsReloadOnSendFail []RegionInfo, - regionsInOtherZones []uint64, - maxRemoteReadCountAllowed int, - tidbZone string) (regionInfo RegionInfo, _ []RegionInfo, _ []uint64, err error) { - needCrossZoneAccess := false - allStores, needCrossZoneAccess = filterAllStoresAccordingToTiFlashReplicaRead(allStores, aliveStores, tiflashReplicaReadPolicy) - - regionInfo = RegionInfo{ - Region: task.region, - Meta: rpcCtx.Meta, - Ranges: task.ranges, - AllStores: allStores, - PartitionIndex: task.partitionIndex} - - if needCrossZoneAccess { - regionsInOtherZones = append(regionsInOtherZones, task.region.GetID()) - regionInfoNeedsReloadOnSendFail = append(regionInfoNeedsReloadOnSendFail, regionInfo) - if tiflashReplicaReadPolicy.IsClosestReplicas() && len(regionsInOtherZones) > maxRemoteReadCountAllowed { - regionIDErrMsg := "" - for i := 0; i < 3 && i < len(regionsInOtherZones); i++ { - regionIDErrMsg += fmt.Sprintf("%d, ", regionsInOtherZones[i]) - } - err = errors.Errorf( - "no less than %d region(s) can not be accessed by TiFlash in the zone [%s]: %setc", - len(regionsInOtherZones), tidbZone, regionIDErrMsg) - // We need to reload the region cache here to avoid the failure throughout the region cache refresh TTL. 
- cache.OnSendFailForBatchRegions(bo, rpcCtx.Store, regionInfoNeedsReloadOnSendFail, true, err) - return regionInfo, nil, nil, err - } - } - return regionInfo, regionInfoNeedsReloadOnSendFail, regionsInOtherZones, nil -} - -type aliveStoresBundle struct { - storesInAllZones []*tikv.Store - storeIDsInAllZones map[uint64]struct{} - storesInTiDBZone []*tikv.Store - storeIDsInTiDBZone map[uint64]struct{} -} - -// When `partitionIDs != nil`, it means that buildBatchCopTasksCore is constructing a batch cop tasks for PartitionTableScan. -// At this time, `len(rangesForEachPhysicalTable) == len(partitionIDs)` and `rangesForEachPhysicalTable[i]` is for partition `partitionIDs[i]`. -// Otherwise, `rangesForEachPhysicalTable[0]` indicates the range for the single physical table. -func buildBatchCopTasksCore(bo *backoff.Backoffer, store *kvStore, rangesForEachPhysicalTable []*KeyRanges, storeType kv.StoreType, isMPP bool, ttl time.Duration, balanceWithContinuity bool, balanceContinuousRegionCount int64, tiflashReplicaReadPolicy tiflash.ReplicaRead, appendWarning func(error)) ([]*batchCopTask, error) { - cache := store.GetRegionCache() - start := time.Now() - const cmdType = tikvrpc.CmdBatchCop - rangesLen := 0 - - tidbZone, isTiDBLabelZoneSet := config.GetGlobalConfig().Labels[placement.DCLabelKey] - var ( - aliveStores *aliveStoresBundle - maxRemoteReadCountAllowed int - ) - if !isTiDBLabelZoneSet { - tiflashReplicaReadPolicy = tiflash.AllReplicas - } - - for { - var tasks []*copTask - rangesLen = 0 - for i, ranges := range rangesForEachPhysicalTable { - rangesLen += ranges.Len() - locations, err := cache.SplitKeyRangesByLocations(bo, ranges, UnspecifiedLimit, false, false) - if err != nil { - return nil, errors.Trace(err) - } - for _, lo := range locations { - tasks = append(tasks, &copTask{ - region: lo.Location.Region, - ranges: lo.Ranges, - cmdType: cmdType, - storeType: storeType, - partitionIndex: int64(i), - }) - } - } - - rpcCtxs := make([]*tikv.RPCContext, 0, len(tasks)) - usedTiFlashStores := make([][]uint64, 0, len(tasks)) - usedTiFlashStoresMap := make(map[uint64]struct{}, 0) - needRetry := false - for _, task := range tasks { - rpcCtx, err := cache.GetTiFlashRPCContext(bo.TiKVBackoffer(), task.region, isMPP, tikv.LabelFilterNoTiFlashWriteNode) - if err != nil { - return nil, errors.Trace(err) - } - - // When rpcCtx is nil, it's not only attributed to the miss region, but also - // some TiFlash stores crash and can't be recovered. - // That is not an error that can be easily recovered, so we regard this error - // same as rpc error. - if rpcCtx == nil { - needRetry = true - logutil.BgLogger().Info("retry for TiFlash peer with region missing", zap.Uint64("region id", task.region.GetID())) - // Probably all the regions are invalid. Make the loop continue and mark all the regions invalid. - // Then `splitRegion` will reloads these regions. - continue - } - - allStores, _ := cache.GetAllValidTiFlashStores(task.region, rpcCtx.Store, tikv.LabelFilterNoTiFlashWriteNode) - for _, storeID := range allStores { - usedTiFlashStoresMap[storeID] = struct{}{} - } - rpcCtxs = append(rpcCtxs, rpcCtx) - usedTiFlashStores = append(usedTiFlashStores, allStores) - } - - if needRetry { - // As mentioned above, nil rpcCtx is always attributed to failed stores. - // It's equal to long poll the store but get no response. Here we'd better use - // TiFlash error to trigger the TiKV fallback mechanism. 
- err := bo.Backoff(tikv.BoTiFlashRPC(), errors.New("Cannot find region with TiFlash peer")) - if err != nil { - return nil, errors.Trace(err) - } - continue - } - - aliveStores = getAliveStoresAndStoreIDs(bo.GetCtx(), cache, usedTiFlashStoresMap, ttl, store, tiflashReplicaReadPolicy, tidbZone) - if tiflashReplicaReadPolicy.IsClosestReplicas() { - if len(aliveStores.storeIDsInTiDBZone) == 0 { - return nil, errors.Errorf("There is no region in tidb zone(%s)", tidbZone) - } - maxRemoteReadCountAllowed = len(aliveStores.storeIDsInTiDBZone) * tiflash.MaxRemoteReadCountPerNodeForClosestReplicas - } - - var batchTasks []*batchCopTask - var regionIDsInOtherZones []uint64 - var regionInfosNeedReloadOnSendFail []RegionInfo - storeTaskMap := make(map[string]*batchCopTask) - storeIDsUnionSetForAllTasks := make(map[uint64]struct{}) - for idx, task := range tasks { - var err error - var regionInfo RegionInfo - regionInfo, regionInfosNeedReloadOnSendFail, regionIDsInOtherZones, err = filterAccessibleStoresAndBuildRegionInfo(cache, usedTiFlashStores[idx], bo, task, rpcCtxs[idx], aliveStores, tiflashReplicaReadPolicy, regionInfosNeedReloadOnSendFail, regionIDsInOtherZones, maxRemoteReadCountAllowed, tidbZone) - if err != nil { - return nil, err - } - if batchCop, ok := storeTaskMap[rpcCtxs[idx].Addr]; ok { - batchCop.regionInfos = append(batchCop.regionInfos, regionInfo) - } else { - batchTask := &batchCopTask{ - storeAddr: rpcCtxs[idx].Addr, - cmdType: cmdType, - ctx: rpcCtxs[idx], - regionInfos: []RegionInfo{regionInfo}, - } - storeTaskMap[rpcCtxs[idx].Addr] = batchTask - } - for _, storeID := range regionInfo.AllStores { - storeIDsUnionSetForAllTasks[storeID] = struct{}{} - } - } - - if len(regionIDsInOtherZones) != 0 { - warningMsg := fmt.Sprintf("total %d region(s) can not be accessed by TiFlash in the zone [%s]:", len(regionIDsInOtherZones), tidbZone) - regionIDErrMsg := "" - for i := 0; i < 3 && i < len(regionIDsInOtherZones); i++ { - regionIDErrMsg += fmt.Sprintf("%d, ", regionIDsInOtherZones[i]) - } - warningMsg += regionIDErrMsg + "etc" - appendWarning(errors.NewNoStackErrorf(warningMsg)) - } - - for _, task := range storeTaskMap { - batchTasks = append(batchTasks, task) - } - if log.GetLevel() <= zap.DebugLevel { - msg := "Before region balance:" - for _, task := range batchTasks { - msg += " store " + task.storeAddr + ": " + strconv.Itoa(len(task.regionInfos)) + " regions," - } - logutil.BgLogger().Debug(msg) - } - balanceStart := time.Now() - storesUnionSetForAllTasks := make([]*tikv.Store, 0, len(storeIDsUnionSetForAllTasks)) - for _, store := range aliveStores.storesInAllZones { - if _, ok := storeIDsUnionSetForAllTasks[store.StoreID()]; ok { - storesUnionSetForAllTasks = append(storesUnionSetForAllTasks, store) - } - } - batchTasks = balanceBatchCopTask(storesUnionSetForAllTasks, batchTasks, balanceWithContinuity, balanceContinuousRegionCount) - balanceElapsed := time.Since(balanceStart) - if log.GetLevel() <= zap.DebugLevel { - msg := "After region balance:" - for _, task := range batchTasks { - msg += " store " + task.storeAddr + ": " + strconv.Itoa(len(task.regionInfos)) + " regions," - } - logutil.BgLogger().Debug(msg) - } - - if elapsed := time.Since(start); elapsed > time.Millisecond*500 { - logutil.BgLogger().Warn("buildBatchCopTasksCore takes too much time", - zap.Duration("elapsed", elapsed), - zap.Duration("balanceElapsed", balanceElapsed), - zap.Int("range len", rangesLen), - zap.Int("task len", len(batchTasks))) - } - 
metrics.TxnRegionsNumHistogramWithBatchCoprocessor.Observe(float64(len(batchTasks))) - return batchTasks, nil - } -} - -func convertRegionInfosToPartitionTableRegions(batchTasks []*batchCopTask, partitionIDs []int64) { - for _, copTask := range batchTasks { - tableRegions := make([]*coprocessor.TableRegions, len(partitionIDs)) - // init coprocessor.TableRegions - for j, pid := range partitionIDs { - tableRegions[j] = &coprocessor.TableRegions{ - PhysicalTableId: pid, - } - } - // fill region infos - for _, ri := range copTask.regionInfos { - tableRegions[ri.PartitionIndex].Regions = append(tableRegions[ri.PartitionIndex].Regions, - ri.toCoprocessorRegionInfo()) - } - count := 0 - // clear empty table region - for j := 0; j < len(tableRegions); j++ { - if len(tableRegions[j].Regions) != 0 { - tableRegions[count] = tableRegions[j] - count++ - } - } - copTask.PartitionTableRegions = tableRegions[:count] - copTask.regionInfos = nil - } -} - -func (c *CopClient) sendBatch(ctx context.Context, req *kv.Request, vars *tikv.Variables, option *kv.ClientSendOption) kv.Response { - if req.KeepOrder || req.Desc { - return copErrorResponse{errors.New("batch coprocessor cannot prove keep order or desc property")} - } - ctx = context.WithValue(ctx, tikv.TxnStartKey(), req.StartTs) - bo := backoff.NewBackofferWithVars(ctx, copBuildTaskMaxBackoff, vars) - - var tasks []*batchCopTask - var err error - if req.PartitionIDAndRanges != nil { - // For Partition Table Scan - keyRanges := make([]*KeyRanges, 0, len(req.PartitionIDAndRanges)) - partitionIDs := make([]int64, 0, len(req.PartitionIDAndRanges)) - for _, pi := range req.PartitionIDAndRanges { - keyRanges = append(keyRanges, NewKeyRanges(pi.KeyRanges)) - partitionIDs = append(partitionIDs, pi.ID) - } - tasks, err = buildBatchCopTasksForPartitionedTable(ctx, bo, c.store.kvStore, keyRanges, req.StoreType, false, 0, false, 0, partitionIDs, tiflashcompute.DispatchPolicyInvalid, option.TiFlashReplicaRead, option.AppendWarning) - } else { - // TODO: merge the if branch. - ranges := NewKeyRanges(req.KeyRanges.FirstPartitionRange()) - tasks, err = buildBatchCopTasksForNonPartitionedTable(ctx, bo, c.store.kvStore, ranges, req.StoreType, false, 0, false, 0, tiflashcompute.DispatchPolicyInvalid, option.TiFlashReplicaRead, option.AppendWarning) - } - - if err != nil { - return copErrorResponse{err} - } - it := &batchCopIterator{ - store: c.store.kvStore, - req: req, - finishCh: make(chan struct{}), - vars: vars, - rpcCancel: tikv.NewRPCanceller(), - enableCollectExecutionInfo: option.EnableCollectExecutionInfo, - tiflashReplicaReadPolicy: option.TiFlashReplicaRead, - appendWarning: option.AppendWarning, - } - ctx = context.WithValue(ctx, tikv.RPCCancellerCtxKey{}, it.rpcCancel) - it.tasks = tasks - it.respChan = make(chan *batchCopResponse, 2048) - go it.run(ctx) - return it -} - -type batchCopIterator struct { - store *kvStore - req *kv.Request - finishCh chan struct{} - - tasks []*batchCopTask - - // Batch results are stored in respChan. - respChan chan *batchCopResponse - - vars *tikv.Variables - - rpcCancel *tikv.RPCCanceller - - wg sync.WaitGroup - // closed represents when the Close is called. - // There are two cases we need to close the `finishCh` channel, one is when context is done, the other one is - // when the Close is called. we use atomic.CompareAndSwap `closed` to to make sure the channel is not closed twice. 
- closed uint32 - - enableCollectExecutionInfo bool - tiflashReplicaReadPolicy tiflash.ReplicaRead - appendWarning func(error) -} - -func (b *batchCopIterator) run(ctx context.Context) { - // We run workers for every batch cop. - for _, task := range b.tasks { - b.wg.Add(1) - boMaxSleep := CopNextMaxBackoff - failpoint.Inject("ReduceCopNextMaxBackoff", func(value failpoint.Value) { - if value.(bool) { - boMaxSleep = 2 - } - }) - bo := backoff.NewBackofferWithVars(ctx, boMaxSleep, b.vars) - go b.handleTask(ctx, bo, task) - } - b.wg.Wait() - close(b.respChan) -} - -// Next returns next coprocessor result. -// NOTE: Use nil to indicate finish, so if the returned ResultSubset is not nil, reader should continue to call Next(). -func (b *batchCopIterator) Next(ctx context.Context) (kv.ResultSubset, error) { - var ( - resp *batchCopResponse - ok bool - closed bool - ) - - // Get next fetched resp from chan - resp, ok, closed = b.recvFromRespCh(ctx) - if !ok || closed { - return nil, nil - } - - if resp.err != nil { - return nil, errors.Trace(resp.err) - } - - err := b.store.CheckVisibility(b.req.StartTs) - if err != nil { - return nil, errors.Trace(err) - } - return resp, nil -} - -func (b *batchCopIterator) recvFromRespCh(ctx context.Context) (resp *batchCopResponse, ok bool, exit bool) { - ticker := time.NewTicker(3 * time.Second) - defer ticker.Stop() - for { - select { - case resp, ok = <-b.respChan: - return - case <-ticker.C: - killed := atomic.LoadUint32(b.vars.Killed) - if killed != 0 { - logutil.Logger(ctx).Info( - "a killed signal is received", - zap.Uint32("signal", killed), - ) - resp = &batchCopResponse{err: derr.ErrQueryInterrupted} - ok = true - return - } - case <-b.finishCh: - exit = true - return - case <-ctx.Done(): - // We select the ctx.Done() in the thread of `Next` instead of in the worker to avoid the cost of `WithCancel`. - if atomic.CompareAndSwapUint32(&b.closed, 0, 1) { - close(b.finishCh) - } - exit = true - return - } - } -} - -// Close releases the resource. -func (b *batchCopIterator) Close() error { - if atomic.CompareAndSwapUint32(&b.closed, 0, 1) { - close(b.finishCh) - } - b.rpcCancel.CancelAll() - b.wg.Wait() - return nil -} - -func (b *batchCopIterator) handleTask(ctx context.Context, bo *Backoffer, task *batchCopTask) { - tasks := []*batchCopTask{task} - for idx := 0; idx < len(tasks); idx++ { - ret, err := b.handleTaskOnce(ctx, bo, tasks[idx]) - if err != nil { - resp := &batchCopResponse{err: errors.Trace(err), detail: new(CopRuntimeStats)} - b.sendToRespCh(resp) - break - } - tasks = append(tasks, ret...) - } - b.wg.Done() -} - -// Merge all ranges and request again. 
-func (b *batchCopIterator) retryBatchCopTask(ctx context.Context, bo *backoff.Backoffer, batchTask *batchCopTask) ([]*batchCopTask, error) { - if batchTask.regionInfos != nil { - var ranges []kv.KeyRange - for _, ri := range batchTask.regionInfos { - ri.Ranges.Do(func(ran *kv.KeyRange) { - ranges = append(ranges, *ran) - }) - } - // need to make sure the key ranges is sorted - slices.SortFunc(ranges, func(i, j kv.KeyRange) int { - return bytes.Compare(i.StartKey, j.StartKey) - }) - ret, err := buildBatchCopTasksForNonPartitionedTable(ctx, bo, b.store, NewKeyRanges(ranges), b.req.StoreType, false, 0, false, 0, tiflashcompute.DispatchPolicyInvalid, b.tiflashReplicaReadPolicy, b.appendWarning) - return ret, err - } - // Retry Partition Table Scan - keyRanges := make([]*KeyRanges, 0, len(batchTask.PartitionTableRegions)) - pid := make([]int64, 0, len(batchTask.PartitionTableRegions)) - for _, trs := range batchTask.PartitionTableRegions { - pid = append(pid, trs.PhysicalTableId) - ranges := make([]kv.KeyRange, 0, len(trs.Regions)) - for _, ri := range trs.Regions { - for _, ran := range ri.Ranges { - ranges = append(ranges, kv.KeyRange{ - StartKey: ran.Start, - EndKey: ran.End, - }) - } - } - // need to make sure the key ranges is sorted - slices.SortFunc(ranges, func(i, j kv.KeyRange) int { - return bytes.Compare(i.StartKey, j.StartKey) - }) - keyRanges = append(keyRanges, NewKeyRanges(ranges)) - } - ret, err := buildBatchCopTasksForPartitionedTable(ctx, bo, b.store, keyRanges, b.req.StoreType, false, 0, false, 0, pid, tiflashcompute.DispatchPolicyInvalid, b.tiflashReplicaReadPolicy, b.appendWarning) - return ret, err -} - -// TiFlashReadTimeoutUltraLong represents the max time that tiflash request may take, since it may scan many regions for tiflash. -const TiFlashReadTimeoutUltraLong = 3600 * time.Second - -func (b *batchCopIterator) handleTaskOnce(ctx context.Context, bo *backoff.Backoffer, task *batchCopTask) ([]*batchCopTask, error) { - sender := NewRegionBatchRequestSender(b.store.GetRegionCache(), b.store.GetTiKVClient(), b.enableCollectExecutionInfo) - var regionInfos = make([]*coprocessor.RegionInfo, 0, len(task.regionInfos)) - for _, ri := range task.regionInfos { - regionInfos = append(regionInfos, ri.toCoprocessorRegionInfo()) - } - - copReq := coprocessor.BatchRequest{ - Tp: b.req.Tp, - StartTs: b.req.StartTs, - Data: b.req.Data, - SchemaVer: b.req.SchemaVar, - Regions: regionInfos, - TableRegions: task.PartitionTableRegions, - ConnectionId: b.req.ConnID, - ConnectionAlias: b.req.ConnAlias, - } - - rgName := b.req.ResourceGroupName - if !variable.EnableResourceControl.Load() { - rgName = "" - } - req := tikvrpc.NewRequest(task.cmdType, &copReq, kvrpcpb.Context{ - IsolationLevel: isolationLevelToPB(b.req.IsolationLevel), - Priority: priorityToPB(b.req.Priority), - NotFillCache: b.req.NotFillCache, - RecordTimeStat: true, - RecordScanStat: true, - TaskId: b.req.TaskID, - ResourceControlContext: &kvrpcpb.ResourceControlContext{ - ResourceGroupName: rgName, - }, - }) - if b.req.ResourceGroupTagger != nil { - b.req.ResourceGroupTagger(req) - } - req.StoreTp = getEndPointType(kv.TiFlash) - - logutil.BgLogger().Debug("send batch request to ", zap.String("req info", req.String()), zap.Int("cop task len", len(task.regionInfos))) - resp, retry, cancel, err := sender.SendReqToAddr(bo, task.ctx, task.regionInfos, req, TiFlashReadTimeoutUltraLong) - // If there are store errors, we should retry for all regions. 
- if retry { - return b.retryBatchCopTask(ctx, bo, task) - } - if err != nil { - err = derr.ToTiDBErr(err) - return nil, errors.Trace(err) - } - defer cancel() - return nil, b.handleStreamedBatchCopResponse(ctx, bo, resp.Resp.(*tikvrpc.BatchCopStreamResponse), task) -} - -func (b *batchCopIterator) handleStreamedBatchCopResponse(ctx context.Context, bo *Backoffer, response *tikvrpc.BatchCopStreamResponse, task *batchCopTask) (err error) { - defer response.Close() - resp := response.BatchResponse - if resp == nil { - // streaming request returns io.EOF, so the first Response is nil. - return - } - for { - err = b.handleBatchCopResponse(bo, resp, task) - if err != nil { - return errors.Trace(err) - } - resp, err = response.Recv() - if err != nil { - if errors.Cause(err) == io.EOF { - return nil - } - - if err1 := bo.Backoff(tikv.BoTiKVRPC(), errors.Errorf("recv stream response error: %v, task store addr: %s", err, task.storeAddr)); err1 != nil { - return errors.Trace(err) - } - - // No coprocessor.Response for network error, rebuild task based on the last success one. - if errors.Cause(err) == context.Canceled { - logutil.BgLogger().Info("stream recv timeout", zap.Error(err)) - } else { - logutil.BgLogger().Info("stream unknown error", zap.Error(err)) - } - return derr.ErrTiFlashServerTimeout - } - } -} - -func (b *batchCopIterator) handleBatchCopResponse(bo *Backoffer, response *coprocessor.BatchResponse, task *batchCopTask) (err error) { - if otherErr := response.GetOtherError(); otherErr != "" { - err = errors.Errorf("other error: %s", otherErr) - logutil.BgLogger().Warn("other error", - zap.Uint64("txnStartTS", b.req.StartTs), - zap.String("storeAddr", task.storeAddr), - zap.Error(err)) - return errors.Trace(err) - } - - if len(response.RetryRegions) > 0 { - logutil.BgLogger().Info("multiple regions are stale and need to be refreshed", zap.Int("region size", len(response.RetryRegions))) - for idx, retry := range response.RetryRegions { - id := tikv.NewRegionVerID(retry.Id, retry.RegionEpoch.ConfVer, retry.RegionEpoch.Version) - logutil.BgLogger().Info("invalid region because tiflash detected stale region", zap.String("region id", id.String())) - b.store.GetRegionCache().InvalidateCachedRegionWithReason(id, tikv.EpochNotMatch) - if idx >= 10 { - logutil.BgLogger().Info("stale regions are too many, so we omit the rest ones") - break - } - } - return - } - - resp := &batchCopResponse{ - pbResp: response, - detail: new(CopRuntimeStats), - } - - b.handleCollectExecutionInfo(bo, resp, task) - b.sendToRespCh(resp) - - return -} - -func (b *batchCopIterator) sendToRespCh(resp *batchCopResponse) (exit bool) { - select { - case b.respChan <- resp: - case <-b.finishCh: - exit = true - } - return -} - -func (b *batchCopIterator) handleCollectExecutionInfo(bo *Backoffer, resp *batchCopResponse, task *batchCopTask) { - if !b.enableCollectExecutionInfo { - return - } - backoffTimes := bo.GetBackoffTimes() - resp.detail.BackoffTime = time.Duration(bo.GetTotalSleep()) * time.Millisecond - resp.detail.BackoffSleep = make(map[string]time.Duration, len(backoffTimes)) - resp.detail.BackoffTimes = make(map[string]int, len(backoffTimes)) - for backoff := range backoffTimes { - resp.detail.BackoffTimes[backoff] = backoffTimes[backoff] - resp.detail.BackoffSleep[backoff] = time.Duration(bo.GetBackoffSleepMS()[backoff]) * time.Millisecond - } - resp.detail.CalleeAddress = task.storeAddr -} - -// Only called when UseAutoScaler is false. 
-func buildBatchCopTasksConsistentHashForPD(bo *backoff.Backoffer, - kvStore *kvStore, - rangesForEachPhysicalTable []*KeyRanges, - storeType kv.StoreType, - ttl time.Duration, - dispatchPolicy tiflashcompute.DispatchPolicy) (res []*batchCopTask, err error) { - failpointCheckWhichPolicy(dispatchPolicy) - const cmdType = tikvrpc.CmdBatchCop - var ( - retryNum int - rangesLen int - copTaskNum int - splitKeyElapsed time.Duration - getStoreElapsed time.Duration - ) - cache := kvStore.GetRegionCache() - start := time.Now() - - for { - retryNum++ - rangesLen = 0 - tasks := make([]*copTask, 0) - regionIDs := make([]tikv.RegionVerID, 0) - - splitKeyStart := time.Now() - for i, ranges := range rangesForEachPhysicalTable { - rangesLen += ranges.Len() - locations, err := cache.SplitKeyRangesByLocations(bo, ranges, UnspecifiedLimit, false, false) - if err != nil { - return nil, errors.Trace(err) - } - for _, lo := range locations { - tasks = append(tasks, &copTask{ - region: lo.Location.Region, - ranges: lo.Ranges, - cmdType: cmdType, - storeType: storeType, - partitionIndex: int64(i), - }) - regionIDs = append(regionIDs, lo.Location.Region) - } - } - splitKeyElapsed += time.Since(splitKeyStart) - - getStoreStart := time.Now() - stores, err := cache.GetTiFlashComputeStores(bo.TiKVBackoffer()) - if err != nil { - return nil, err - } - stores = filterAliveStores(bo.GetCtx(), stores, ttl, kvStore) - if len(stores) == 0 { - return nil, errors.New("tiflash_compute node is unavailable") - } - getStoreElapsed = time.Since(getStoreStart) - - storesStr := make([]string, 0, len(stores)) - for _, s := range stores { - storesStr = append(storesStr, s.GetAddr()) - } - var rpcCtxs []*tikv.RPCContext - if dispatchPolicy == tiflashcompute.DispatchPolicyRR { - rpcCtxs, err = getTiFlashComputeRPCContextByRoundRobin(regionIDs, storesStr) - } else if dispatchPolicy == tiflashcompute.DispatchPolicyConsistentHash { - rpcCtxs, err = getTiFlashComputeRPCContextByConsistentHash(regionIDs, storesStr) - } else { - err = errors.Errorf("unexpected dispatch policy %v", dispatchPolicy) - } - if err != nil { - return nil, err - } - if rpcCtxs == nil { - logutil.BgLogger().Info("buildBatchCopTasksConsistentHashForPD retry because rcpCtx is nil", zap.Int("retryNum", retryNum)) - err := bo.Backoff(tikv.BoTiFlashRPC(), errors.New("Cannot find region with TiFlash peer")) - if err != nil { - return nil, errors.Trace(err) - } - continue - } - if len(rpcCtxs) != len(tasks) { - return nil, errors.Errorf("length should be equal, len(rpcCtxs): %d, len(tasks): %d", len(rpcCtxs), len(tasks)) - } - copTaskNum = len(tasks) - taskMap := make(map[string]*batchCopTask) - for i, rpcCtx := range rpcCtxs { - regionInfo := RegionInfo{ - // tasks and rpcCtxs are correspond to each other. 
- Region: tasks[i].region, - Ranges: tasks[i].ranges, - PartitionIndex: tasks[i].partitionIndex, - } - if batchTask, ok := taskMap[rpcCtx.Addr]; ok { - batchTask.regionInfos = append(batchTask.regionInfos, regionInfo) - } else { - batchTask := &batchCopTask{ - storeAddr: rpcCtx.Addr, - cmdType: cmdType, - ctx: rpcCtx, - regionInfos: []RegionInfo{regionInfo}, - } - taskMap[rpcCtx.Addr] = batchTask - res = append(res, batchTask) - } - } - logutil.BgLogger().Info("buildBatchCopTasksConsistentHashForPD done", - zap.Any("len(tasks)", len(taskMap)), - zap.Any("len(tiflash_compute)", len(stores)), - zap.Any("dispatchPolicy", tiflashcompute.GetDispatchPolicy(dispatchPolicy))) - if log.GetLevel() <= zap.DebugLevel { - debugTaskMap := make(map[string]string, len(taskMap)) - for s, b := range taskMap { - debugTaskMap[s] = fmt.Sprintf("addr: %s; regionInfos: %v", b.storeAddr, b.regionInfos) - } - logutil.BgLogger().Debug("detailed info buildBatchCopTasksConsistentHashForPD", zap.Any("taskMap", debugTaskMap), zap.Any("allStores", storesStr)) - } - break - } - - if elapsed := time.Since(start); elapsed > time.Millisecond*500 { - logutil.BgLogger().Warn("buildBatchCopTasksConsistentHashForPD takes too much time", - zap.Duration("total elapsed", elapsed), - zap.Int("retryNum", retryNum), - zap.Duration("splitKeyElapsed", splitKeyElapsed), - zap.Duration("getStoreElapsed", getStoreElapsed), - zap.Int("range len", rangesLen), - zap.Int("copTaskNum", copTaskNum), - zap.Int("batchCopTaskNum", len(res))) - } - failpointCheckForConsistentHash(res) - return res, nil -} diff --git a/pkg/store/copr/binding__failpoint_binding__.go b/pkg/store/copr/binding__failpoint_binding__.go deleted file mode 100644 index 3a49604367f85..0000000000000 --- a/pkg/store/copr/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package copr - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/store/copr/coprocessor.go b/pkg/store/copr/coprocessor.go index 4b831d0c72e3b..6b59d8096784d 100644 --- a/pkg/store/copr/coprocessor.go +++ b/pkg/store/copr/coprocessor.go @@ -111,9 +111,9 @@ func (c *CopClient) Send(ctx context.Context, req *kv.Request, variables any, op // BuildCopIterator builds the iterator without calling `open`. 
func (c *CopClient) BuildCopIterator(ctx context.Context, req *kv.Request, vars *tikv.Variables, option *kv.ClientSendOption) (*copIterator, kv.Response) { eventCb := option.EventCb - if _, _err_ := failpoint.Eval(_curpkg_("DisablePaging")); _err_ == nil { + failpoint.Inject("DisablePaging", func(_ failpoint.Value) { req.Paging.Enable = false - } + }) if req.StoreType == kv.TiDB { // coprocessor on TiDB doesn't support paging req.Paging.Enable = false @@ -122,13 +122,13 @@ func (c *CopClient) BuildCopIterator(ctx context.Context, req *kv.Request, vars // coprocessor request but type is not DAG req.Paging.Enable = false } - if _, _err_ := failpoint.Eval(_curpkg_("checkKeyRangeSortedForPaging")); _err_ == nil { + failpoint.Inject("checkKeyRangeSortedForPaging", func(_ failpoint.Value) { if req.Paging.Enable { if !req.KeyRanges.IsFullySorted() { logutil.BgLogger().Fatal("distsql request key range not sorted!") } } - } + }) if !checkStoreBatchCopr(req) { req.StoreBatchSize = 0 } @@ -327,11 +327,11 @@ func buildCopTasks(bo *Backoffer, ranges *KeyRanges, opt *buildCopTaskOpt) ([]*c } rangesPerTaskLimit := rangesPerTask - if val, _err_ := failpoint.Eval(_curpkg_("setRangesPerTask")); _err_ == nil { + failpoint.Inject("setRangesPerTask", func(val failpoint.Value) { if v, ok := val.(int); ok { rangesPerTaskLimit = v } - } + }) // TODO(youjiali1995): is there any request type that needn't be split by buckets? locs, err := cache.SplitKeyRangesByBuckets(bo, ranges) @@ -789,12 +789,12 @@ func init() { // send the result back. func (worker *copIteratorWorker) run(ctx context.Context) { defer func() { - if val, _err_ := failpoint.Eval(_curpkg_("ticase-4169")); _err_ == nil { + failpoint.Inject("ticase-4169", func(val failpoint.Value) { if val.(bool) { worker.memTracker.Consume(10 * MockResponseSizeForTest) worker.memTracker.Consume(10 * MockResponseSizeForTest) } - } + }) worker.wg.Done() }() // 16KB ballast helps grow the stack to the requirement of copIteratorWorker. 
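The hunks above and below restore the `failpoint.Inject(...)` source form that `failpoint-ctl enable` had rewritten into `failpoint.Eval(_curpkg_(...))` calls (the deleted `binding__failpoint_binding__.go` and `*__failpoint_stash__` files appear to be artifacts of that same instrumentation). As a hedged illustration of the pattern: the failpoint name `mockSlowCop` and the surrounding function below are invented for this example; only `failpoint.Inject`, `failpoint.Value`, and the `Enable`/`Disable`/`GO_FAILPOINTS` activation mechanism come from the real github.com/pingcap/failpoint API.

// Sketch of a failpoint marker. failpoint.Inject is a no-op until the package
// is instrumented with `failpoint-ctl enable`; after instrumentation the
// closure runs only when the named failpoint is activated, for example via
// failpoint.Enable with "return(100)" or the GO_FAILPOINTS environment variable.
package example

import (
	"time"

	"github.com/pingcap/failpoint"
)

func doCopWork() {
	failpoint.Inject("mockSlowCop", func(val failpoint.Value) {
		// With `return(100)`, val is the int 100: sleep to simulate a slow request.
		if ms, ok := val.(int); ok {
			time.Sleep(time.Duration(ms) * time.Millisecond)
		}
	})
	// ... normal coprocessor work continues here ...
}

A test would typically wrap the call with failpoint.Enable("<full package path>/mockSlowCop", "return(100)") and a deferred failpoint.Disable of the same name, mirroring how the sleepCoprRequest-style failpoints in this file are exercised.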
@@ -865,12 +865,12 @@ func (it *copIterator) open(ctx context.Context, enabledRateLimitAction, enableC } taskSender.respChan = it.respChan it.actionOnExceed.setEnabled(enabledRateLimitAction) - if val, _err_ := failpoint.Eval(_curpkg_("ticase-4171")); _err_ == nil { + failpoint.Inject("ticase-4171", func(val failpoint.Value) { if val.(bool) { it.memTracker.Consume(10 * MockResponseSizeForTest) it.memTracker.Consume(10 * MockResponseSizeForTest) } - } + }) go taskSender.run(it.req.ConnID) } @@ -897,7 +897,7 @@ func (sender *copIteratorTaskSender) run(connID uint64) { break } if connID > 0 { - failpoint.Eval(_curpkg_("pauseCopIterTaskSender")) + failpoint.Inject("pauseCopIterTaskSender", func() {}) } } close(sender.taskCh) @@ -911,7 +911,7 @@ func (sender *copIteratorTaskSender) run(connID uint64) { } func (it *copIterator) recvFromRespCh(ctx context.Context, respCh <-chan *copResponse) (resp *copResponse, ok bool, exit bool) { - failpoint.Call(_curpkg_("CtxCancelBeforeReceive"), ctx) + failpoint.InjectCall("CtxCancelBeforeReceive", ctx) ticker := time.NewTicker(3 * time.Second) defer ticker.Stop() for { @@ -919,13 +919,13 @@ func (it *copIterator) recvFromRespCh(ctx context.Context, respCh <-chan *copRes case resp, ok = <-respCh: if it.memTracker != nil && resp != nil { consumed := resp.MemSize() - if val, _err_ := failpoint.Eval(_curpkg_("testRateLimitActionMockConsumeAndAssert")); _err_ == nil { + failpoint.Inject("testRateLimitActionMockConsumeAndAssert", func(val failpoint.Value) { if val.(bool) { if resp != finCopResp { consumed = MockResponseSizeForTest } } - } + }) it.memTracker.Consume(-consumed) } return @@ -991,14 +991,14 @@ func (sender *copIteratorTaskSender) sendToTaskCh(t *copTask, sendTo chan<- *cop func (worker *copIteratorWorker) sendToRespCh(resp *copResponse, respCh chan<- *copResponse, checkOOM bool) (exit bool) { if worker.memTracker != nil && checkOOM { consumed := resp.MemSize() - if val, _err_ := failpoint.Eval(_curpkg_("testRateLimitActionMockConsumeAndAssert")); _err_ == nil { + failpoint.Inject("testRateLimitActionMockConsumeAndAssert", func(val failpoint.Value) { if val.(bool) { if resp != finCopResp { consumed = MockResponseSizeForTest } } - } - failpoint.Eval(_curpkg_("ConsumeRandomPanic")) + }) + failpoint.Inject("ConsumeRandomPanic", nil) worker.memTracker.Consume(consumed) } select { @@ -1022,16 +1022,16 @@ func (it *copIterator) Next(ctx context.Context) (kv.ResultSubset, error) { ) defer func() { if resp == nil { - if val, _err_ := failpoint.Eval(_curpkg_("ticase-4170")); _err_ == nil { + failpoint.Inject("ticase-4170", func(val failpoint.Value) { if val.(bool) { it.memTracker.Consume(10 * MockResponseSizeForTest) it.memTracker.Consume(10 * MockResponseSizeForTest) } - } + }) } }() // wait unit at least 5 copResponse received. - if val, _err_ := failpoint.Eval(_curpkg_("testRateLimitActionMockWaitMax")); _err_ == nil { + failpoint.Inject("testRateLimitActionMockWaitMax", func(val failpoint.Value) { if val.(bool) { // we only need to trigger oom at least once. if len(it.tasks) > 9 { @@ -1040,7 +1040,7 @@ func (it *copIterator) Next(ctx context.Context) (kv.ResultSubset, error) { } } } - } + }) // If data order matters, response should be returned in the same order as copTask slice. // Otherwise all responses are returned from a single channel. 
if it.respChan != nil { @@ -1117,11 +1117,11 @@ func chooseBackoffer(ctx context.Context, backoffermap map[uint64]*Backoffer, ta return bo } boMaxSleep := CopNextMaxBackoff - if value, _err_ := failpoint.Eval(_curpkg_("ReduceCopNextMaxBackoff")); _err_ == nil { + failpoint.Inject("ReduceCopNextMaxBackoff", func(value failpoint.Value) { if value.(bool) { boMaxSleep = 2 } - } + }) newbo := backoff.NewBackofferWithVars(ctx, boMaxSleep, worker.vars) backoffermap[task.region.GetID()] = newbo return newbo @@ -1165,11 +1165,11 @@ func (worker *copIteratorWorker) handleTask(ctx context.Context, task *copTask, // handleTaskOnce handles single copTask, successful results are send to channel. // If error happened, returns error. If region split or meet lock, returns the remain tasks. func (worker *copIteratorWorker) handleTaskOnce(bo *Backoffer, task *copTask, ch chan<- *copResponse) ([]*copTask, error) { - if val, _err_ := failpoint.Eval(_curpkg_("handleTaskOnceError")); _err_ == nil { + failpoint.Inject("handleTaskOnceError", func(val failpoint.Value) { if val.(bool) { - return nil, errors.New("mock handleTaskOnce error") + failpoint.Return(nil, errors.New("mock handleTaskOnce error")) } - } + }) if task.paging { task.pagingTaskIdx = atomic.AddUint32(worker.pagingTaskIdx, 1) @@ -1222,10 +1222,10 @@ func (worker *copIteratorWorker) handleTaskOnce(bo *Backoffer, task *copTask, ch if task.tikvClientReadTimeout > 0 { timeout = time.Duration(task.tikvClientReadTimeout) * time.Millisecond } - if v, _err_ := failpoint.Eval(_curpkg_("sleepCoprRequest")); _err_ == nil { + failpoint.Inject("sleepCoprRequest", func(v failpoint.Value) { //nolint:durationcheck time.Sleep(time.Millisecond * time.Duration(v.(int))) - } + }) if worker.req.RunawayChecker != nil { if err := worker.req.RunawayChecker.BeforeCopRequest(req); err != nil { @@ -1259,14 +1259,14 @@ func (worker *copIteratorWorker) handleTaskOnce(bo *Backoffer, task *copTask, ch timeout, getEndPointType(task.storeType), task.storeAddr, ops...) 
err = derr.ToTiDBErr(err) if worker.req.RunawayChecker != nil { - if v, _err_ := failpoint.Eval(_curpkg_("sleepCoprAfterReq")); _err_ == nil { + failpoint.Inject("sleepCoprAfterReq", func(v failpoint.Value) { //nolint:durationcheck value := v.(int) time.Sleep(time.Millisecond * time.Duration(value)) if value > 50 { err = errors.Errorf("Coprocessor task terminated due to exceeding the deadline") } - } + }) err = worker.req.RunawayChecker.CheckCopRespError(err) } if err != nil { @@ -1540,9 +1540,9 @@ func (worker *copIteratorWorker) handleBatchCopResponse(bo *Backoffer, rpcCtx *t }, } task := batchedTask.task - if _, _err_ := failpoint.Eval(_curpkg_("batchCopRegionError")); _err_ == nil { + failpoint.Inject("batchCopRegionError", func() { batchResp.RegionError = &errorpb.Error{} - } + }) if regionErr := batchResp.GetRegionError(); regionErr != nil { errStr := fmt.Sprintf("region_id:%v, region_ver:%v, store_type:%s, peer_addr:%s, error:%s", task.region.GetID(), task.region.GetVer(), task.storeType.Name(), task.storeAddr, regionErr.String()) @@ -1776,11 +1776,11 @@ func (worker *copIteratorWorker) handleCollectExecutionInfo(bo *Backoffer, rpcCt if !worker.enableCollectExecutionInfo { return } - if val, _err_ := failpoint.Eval(_curpkg_("disable-collect-execution")); _err_ == nil { + failpoint.Inject("disable-collect-execution", func(val failpoint.Value) { if val.(bool) { panic("shouldn't reachable") } - } + }) if resp.detail == nil { resp.detail = new(CopRuntimeStats) } @@ -2023,13 +2023,13 @@ func (e *rateLimitAction) Action(t *memory.Tracker) { } return } - if val, _err_ := failpoint.Eval(_curpkg_("testRateLimitActionMockConsumeAndAssert")); _err_ == nil { + failpoint.Inject("testRateLimitActionMockConsumeAndAssert", func(val failpoint.Value) { if val.(bool) { if e.cond.triggerCountForTest+e.cond.remainingTokenNum != e.totalTokenNum { panic("triggerCount + remainingTokenNum not equal to totalTokenNum") } } - } + }) logutil.BgLogger().Info("memory exceeds quota, destroy one token now.", zap.Int64("consumed", t.BytesConsumed()), zap.Int64("quota", t.GetBytesLimit()), @@ -2143,9 +2143,9 @@ func optRowHint(req *kv.Request) bool { // disable extra concurrency for internal tasks. return false } - if _, _err_ := failpoint.Eval(_curpkg_("disableFixedRowCountHint")); _err_ == nil { + failpoint.Inject("disableFixedRowCountHint", func(_ failpoint.Value) { opt = false - } + }) return opt } diff --git a/pkg/store/copr/coprocessor.go__failpoint_stash__ b/pkg/store/copr/coprocessor.go__failpoint_stash__ deleted file mode 100644 index 6b59d8096784d..0000000000000 --- a/pkg/store/copr/coprocessor.go__failpoint_stash__ +++ /dev/null @@ -1,2170 +0,0 @@ -// Copyright 2016 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package copr - -import ( - "context" - "fmt" - "math" - "net" - "runtime" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - "unsafe" - - "github.com/gogo/protobuf/proto" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/coprocessor" - "github.com/pingcap/kvproto/pkg/errorpb" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/domain/infosync" - "github.com/pingcap/tidb/pkg/domain/resourcegroup" - "github.com/pingcap/tidb/pkg/errno" - "github.com/pingcap/tidb/pkg/kv" - tidbmetrics "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - copr_metrics "github.com/pingcap/tidb/pkg/store/copr/metrics" - "github.com/pingcap/tidb/pkg/store/driver/backoff" - derr "github.com/pingcap/tidb/pkg/store/driver/error" - "github.com/pingcap/tidb/pkg/store/driver/options" - util2 "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "github.com/pingcap/tidb/pkg/util/paging" - "github.com/pingcap/tidb/pkg/util/size" - "github.com/pingcap/tidb/pkg/util/tracing" - "github.com/pingcap/tidb/pkg/util/trxevents" - "github.com/pingcap/tipb/go-tipb" - "github.com/tikv/client-go/v2/metrics" - "github.com/tikv/client-go/v2/tikv" - "github.com/tikv/client-go/v2/tikvrpc" - "github.com/tikv/client-go/v2/tikvrpc/interceptor" - "github.com/tikv/client-go/v2/txnkv/txnlock" - "github.com/tikv/client-go/v2/txnkv/txnsnapshot" - "github.com/tikv/client-go/v2/util" - "go.uber.org/zap" -) - -// Maximum total sleep time(in ms) for kv/cop commands. -const ( - copBuildTaskMaxBackoff = 5000 - CopNextMaxBackoff = 20000 - CopSmallTaskRow = 32 // 32 is the initial batch size of TiKV - smallTaskSigma = 0.5 - smallConcPerCore = 20 -) - -// CopClient is coprocessor client. -type CopClient struct { - kv.RequestTypeSupportedChecker - store *Store - replicaReadSeed uint32 -} - -// Send builds the request and gets the coprocessor iterator response. -func (c *CopClient) Send(ctx context.Context, req *kv.Request, variables any, option *kv.ClientSendOption) kv.Response { - vars, ok := variables.(*tikv.Variables) - if !ok { - return copErrorResponse{errors.Errorf("unsupported variables:%+v", variables)} - } - if req.StoreType == kv.TiFlash && req.BatchCop { - logutil.BgLogger().Debug("send batch requests") - return c.sendBatch(ctx, req, vars, option) - } - ctx = context.WithValue(ctx, tikv.TxnStartKey(), req.StartTs) - ctx = context.WithValue(ctx, util.RequestSourceKey, req.RequestSource) - ctx = interceptor.WithRPCInterceptor(ctx, interceptor.GetRPCInterceptorFromCtx(ctx)) - enabledRateLimitAction := option.EnabledRateLimitAction - sessionMemTracker := option.SessionMemTracker - it, errRes := c.BuildCopIterator(ctx, req, vars, option) - if errRes != nil { - return errRes - } - ctx = context.WithValue(ctx, tikv.RPCCancellerCtxKey{}, it.rpcCancel) - if sessionMemTracker != nil && enabledRateLimitAction { - sessionMemTracker.FallbackOldAndSetNewAction(it.actionOnExceed) - } - it.open(ctx, enabledRateLimitAction, option.EnableCollectExecutionInfo) - return it -} - -// BuildCopIterator builds the iterator without calling `open`. 
-func (c *CopClient) BuildCopIterator(ctx context.Context, req *kv.Request, vars *tikv.Variables, option *kv.ClientSendOption) (*copIterator, kv.Response) { - eventCb := option.EventCb - failpoint.Inject("DisablePaging", func(_ failpoint.Value) { - req.Paging.Enable = false - }) - if req.StoreType == kv.TiDB { - // coprocessor on TiDB doesn't support paging - req.Paging.Enable = false - } - if req.Tp != kv.ReqTypeDAG { - // coprocessor request but type is not DAG - req.Paging.Enable = false - } - failpoint.Inject("checkKeyRangeSortedForPaging", func(_ failpoint.Value) { - if req.Paging.Enable { - if !req.KeyRanges.IsFullySorted() { - logutil.BgLogger().Fatal("distsql request key range not sorted!") - } - } - }) - if !checkStoreBatchCopr(req) { - req.StoreBatchSize = 0 - } - - bo := backoff.NewBackofferWithVars(ctx, copBuildTaskMaxBackoff, vars) - var ( - tasks []*copTask - err error - ) - tryRowHint := optRowHint(req) - elapsed := time.Duration(0) - buildOpt := &buildCopTaskOpt{ - req: req, - cache: c.store.GetRegionCache(), - eventCb: eventCb, - respChan: req.KeepOrder, - elapsed: &elapsed, - } - buildTaskFunc := func(ranges []kv.KeyRange, hints []int) error { - keyRanges := NewKeyRanges(ranges) - if tryRowHint { - buildOpt.rowHints = hints - } - tasksFromRanges, err := buildCopTasks(bo, keyRanges, buildOpt) - if err != nil { - return err - } - if len(tasks) == 0 { - tasks = tasksFromRanges - return nil - } - tasks = append(tasks, tasksFromRanges...) - return nil - } - // Here we build the task by partition, not directly by region. - // This is because it's possible that TiDB merge multiple small partition into one region which break some assumption. - // Keep it split by partition would be more safe. - err = req.KeyRanges.ForEachPartitionWithErr(buildTaskFunc) - // only batch store requests in first build. - req.StoreBatchSize = 0 - reqType := "null" - if req.ClosestReplicaReadAdjuster != nil { - reqType = "miss" - if req.ClosestReplicaReadAdjuster(req, len(tasks)) { - reqType = "hit" - } - } - tidbmetrics.DistSQLCoprClosestReadCounter.WithLabelValues(reqType).Inc() - if err != nil { - return nil, copErrorResponse{err} - } - it := &copIterator{ - store: c.store, - req: req, - concurrency: req.Concurrency, - finishCh: make(chan struct{}), - vars: vars, - memTracker: req.MemTracker, - replicaReadSeed: c.replicaReadSeed, - rpcCancel: tikv.NewRPCanceller(), - buildTaskElapsed: *buildOpt.elapsed, - runawayChecker: req.RunawayChecker, - } - // Pipelined-dml can flush locks when it is still reading. - // The coprocessor of the txn should not be blocked by itself. - // It should be the only case where a coprocessor can read locks of the same ts. - // - // But when start_ts is not obtained from PD, - // the start_ts could conflict with another pipelined-txn's start_ts. - // in which case the locks of same ts cannot be ignored. - // We rely on the assumption: start_ts is not from PD => this is a stale read. - if !req.IsStaleness { - it.resolvedLocks.Put(req.StartTs) - } - it.tasks = tasks - if it.concurrency > len(tasks) { - it.concurrency = len(tasks) - } - if tryRowHint { - var smallTasks int - smallTasks, it.smallTaskConcurrency = smallTaskConcurrency(tasks, c.store.numcpu) - if len(tasks)-smallTasks < it.concurrency { - it.concurrency = len(tasks) - smallTasks - } - } - if it.concurrency < 1 { - // Make sure that there is at least one worker. 
- it.concurrency = 1 - } - - if it.req.KeepOrder { - if it.smallTaskConcurrency > 20 { - it.smallTaskConcurrency = 20 - } - it.sendRate = util.NewRateLimit(2 * (it.concurrency + it.smallTaskConcurrency)) - it.respChan = nil - } else { - it.respChan = make(chan *copResponse) - it.sendRate = util.NewRateLimit(it.concurrency + it.smallTaskConcurrency) - } - it.actionOnExceed = newRateLimitAction(uint(it.sendRate.GetCapacity())) - return it, nil -} - -// copTask contains a related Region and KeyRange for a kv.Request. -type copTask struct { - taskID uint64 - region tikv.RegionVerID - bucketsVer uint64 - ranges *KeyRanges - - respChan chan *copResponse - storeAddr string - cmdType tikvrpc.CmdType - storeType kv.StoreType - - eventCb trxevents.EventCallback - paging bool - pagingSize uint64 - pagingTaskIdx uint32 - - partitionIndex int64 // used by balanceBatchCopTask in PartitionTableScan - requestSource util.RequestSource - RowCountHint int // used for extra concurrency of small tasks, -1 for unknown row count - batchTaskList map[uint64]*batchedCopTask - - // when this task is batched and the leader's wait duration exceeds the load-based threshold, - // we set this field to the target replica store ID and redirect the request to the replica. - redirect2Replica *uint64 - busyThreshold time.Duration - meetLockFallback bool - - // timeout value for one kv readonly request - tikvClientReadTimeout uint64 - // firstReadType is used to indicate the type of first read when retrying. - firstReadType string -} - -type batchedCopTask struct { - task *copTask - region coprocessor.RegionInfo - storeID uint64 - peer *metapb.Peer - loadBasedReplicaRetry bool -} - -func (r *copTask) String() string { - return fmt.Sprintf("region(%d %d %d) ranges(%d) store(%s)", - r.region.GetID(), r.region.GetConfVer(), r.region.GetVer(), r.ranges.Len(), r.storeAddr) -} - -func (r *copTask) ToPBBatchTasks() []*coprocessor.StoreBatchTask { - if len(r.batchTaskList) == 0 { - return nil - } - pbTasks := make([]*coprocessor.StoreBatchTask, 0, len(r.batchTaskList)) - for _, task := range r.batchTaskList { - storeBatchTask := &coprocessor.StoreBatchTask{ - RegionId: task.region.GetRegionId(), - RegionEpoch: task.region.GetRegionEpoch(), - Peer: task.peer, - Ranges: task.region.GetRanges(), - TaskId: task.task.taskID, - } - pbTasks = append(pbTasks, storeBatchTask) - } - return pbTasks -} - -// rangesPerTask limits the length of the ranges slice sent in one copTask. -const rangesPerTask = 25000 - -type buildCopTaskOpt struct { - req *kv.Request - cache *RegionCache - eventCb trxevents.EventCallback - respChan bool - rowHints []int - elapsed *time.Duration - // ignoreTiKVClientReadTimeout is used to ignore tikv_client_read_timeout configuration, use default timeout instead. - ignoreTiKVClientReadTimeout bool -} - -func buildCopTasks(bo *Backoffer, ranges *KeyRanges, opt *buildCopTaskOpt) ([]*copTask, error) { - req, cache, eventCb, hints := opt.req, opt.cache, opt.eventCb, opt.rowHints - start := time.Now() - defer tracing.StartRegion(bo.GetCtx(), "copr.buildCopTasks").End() - cmdType := tikvrpc.CmdCop - if req.StoreType == kv.TiDB { - return buildTiDBMemCopTasks(ranges, req) - } - rangesLen := ranges.Len() - // something went wrong, disable hints to avoid out of range index. 
- if len(hints) != rangesLen { - hints = nil - } - - rangesPerTaskLimit := rangesPerTask - failpoint.Inject("setRangesPerTask", func(val failpoint.Value) { - if v, ok := val.(int); ok { - rangesPerTaskLimit = v - } - }) - - // TODO(youjiali1995): is there any request type that needn't be split by buckets? - locs, err := cache.SplitKeyRangesByBuckets(bo, ranges) - if err != nil { - return nil, errors.Trace(err) - } - // Channel buffer is 2 for handling region split. - // In a common case, two region split tasks will not be blocked. - chanSize := 2 - // in paging request, a request will be returned in multi batches, - // enlarge the channel size to avoid the request blocked by buffer full. - if req.Paging.Enable { - chanSize = 18 - } - - var builder taskBuilder - if req.StoreBatchSize > 0 && hints != nil { - builder = newBatchTaskBuilder(bo, req, cache, req.ReplicaRead) - } else { - builder = newLegacyTaskBuilder(len(locs)) - } - origRangeIdx := 0 - for _, loc := range locs { - // TiKV will return gRPC error if the message is too large. So we need to limit the length of the ranges slice - // to make sure the message can be sent successfully. - rLen := loc.Ranges.Len() - // If this is a paging request, we set the paging size to minPagingSize, - // the size will grow every round. - pagingSize := uint64(0) - if req.Paging.Enable { - pagingSize = req.Paging.MinPagingSize - } - for i := 0; i < rLen; { - nextI := min(i+rangesPerTaskLimit, rLen) - hint := -1 - // calculate the row count hint - if hints != nil { - startKey, endKey := loc.Ranges.RefAt(i).StartKey, loc.Ranges.RefAt(nextI-1).EndKey - // move to the previous range if startKey of current range is lower than endKey of previous location. - // In the following example, task1 will move origRangeIdx to region(i, z). - // When counting the row hint for task2, we need to move origRangeIdx back to region(a, h). - // |<- region(a, h) ->| |<- region(i, z) ->| - // |<- task1 ->| |<- task2 ->| ... - if origRangeIdx > 0 && ranges.RefAt(origRangeIdx-1).EndKey.Cmp(startKey) > 0 { - origRangeIdx-- - } - hint = 0 - for nextOrigRangeIdx := origRangeIdx; nextOrigRangeIdx < ranges.Len(); nextOrigRangeIdx++ { - rangeStart := ranges.RefAt(nextOrigRangeIdx).StartKey - if rangeStart.Cmp(endKey) > 0 { - origRangeIdx = nextOrigRangeIdx - break - } - hint += hints[nextOrigRangeIdx] - } - } - task := &copTask{ - region: loc.Location.Region, - bucketsVer: loc.getBucketVersion(), - ranges: loc.Ranges.Slice(i, nextI), - cmdType: cmdType, - storeType: req.StoreType, - eventCb: eventCb, - paging: req.Paging.Enable, - pagingSize: pagingSize, - requestSource: req.RequestSource, - RowCountHint: hint, - busyThreshold: req.StoreBusyThreshold, - } - if !opt.ignoreTiKVClientReadTimeout { - task.tikvClientReadTimeout = req.TiKVClientReadTimeout - } - // only keep-order need chan inside task. - // tasks by region error will reuse the channel of parent task. - if req.KeepOrder && opt.respChan { - task.respChan = make(chan *copResponse, chanSize) - } - if err = builder.handle(task); err != nil { - return nil, err - } - i = nextI - if req.Paging.Enable { - if req.LimitSize != 0 && req.LimitSize < pagingSize { - // disable paging for small limit. 
- task.paging = false - task.pagingSize = 0 - } else { - pagingSize = paging.GrowPagingSize(pagingSize, req.Paging.MaxPagingSize) - } - } - } - } - - if req.Desc { - builder.reverse() - } - tasks := builder.build() - elapsed := time.Since(start) - if elapsed > time.Millisecond*500 { - logutil.BgLogger().Warn("buildCopTasks takes too much time", - zap.Duration("elapsed", elapsed), - zap.Int("range len", rangesLen), - zap.Int("task len", len(tasks))) - } - if opt.elapsed != nil { - *opt.elapsed = *opt.elapsed + elapsed - } - metrics.TxnRegionsNumHistogramWithCoprocessor.Observe(float64(builder.regionNum())) - return tasks, nil -} - -type taskBuilder interface { - handle(*copTask) error - reverse() - build() []*copTask - regionNum() int -} - -type legacyTaskBuilder struct { - tasks []*copTask -} - -func newLegacyTaskBuilder(hint int) *legacyTaskBuilder { - return &legacyTaskBuilder{ - tasks: make([]*copTask, 0, hint), - } -} - -func (b *legacyTaskBuilder) handle(task *copTask) error { - b.tasks = append(b.tasks, task) - return nil -} - -func (b *legacyTaskBuilder) regionNum() int { - return len(b.tasks) -} - -func (b *legacyTaskBuilder) reverse() { - reverseTasks(b.tasks) -} - -func (b *legacyTaskBuilder) build() []*copTask { - return b.tasks -} - -type storeReplicaKey struct { - storeID uint64 - replicaRead bool -} - -type batchStoreTaskBuilder struct { - bo *Backoffer - req *kv.Request - cache *RegionCache - taskID uint64 - limit int - store2Idx map[storeReplicaKey]int - tasks []*copTask - replicaRead kv.ReplicaReadType -} - -func newBatchTaskBuilder(bo *Backoffer, req *kv.Request, cache *RegionCache, replicaRead kv.ReplicaReadType) *batchStoreTaskBuilder { - return &batchStoreTaskBuilder{ - bo: bo, - req: req, - cache: cache, - taskID: 0, - limit: req.StoreBatchSize, - store2Idx: make(map[storeReplicaKey]int, 16), - tasks: make([]*copTask, 0, 16), - replicaRead: replicaRead, - } -} - -func (b *batchStoreTaskBuilder) handle(task *copTask) (err error) { - b.taskID++ - task.taskID = b.taskID - handled := false - defer func() { - if !handled && err == nil { - // fallback to non-batch way. It's mainly caused by region miss. - b.tasks = append(b.tasks, task) - } - }() - // only batch small tasks for memory control. - if b.limit <= 0 || !isSmallTask(task) { - return nil - } - batchedTask, err := b.cache.BuildBatchTask(b.bo, b.req, task, b.replicaRead) - if err != nil { - return err - } - if batchedTask == nil { - return nil - } - key := storeReplicaKey{ - storeID: batchedTask.storeID, - replicaRead: batchedTask.loadBasedReplicaRetry, - } - if idx, ok := b.store2Idx[key]; !ok || len(b.tasks[idx].batchTaskList) >= b.limit { - if batchedTask.loadBasedReplicaRetry { - // If the task is dispatched to leader because all followers are busy, - // task.redirect2Replica != nil means the busy threshold shouldn't take effect again. - batchedTask.task.redirect2Replica = &batchedTask.storeID - } - b.tasks = append(b.tasks, batchedTask.task) - b.store2Idx[key] = len(b.tasks) - 1 - } else { - if b.tasks[idx].batchTaskList == nil { - b.tasks[idx].batchTaskList = make(map[uint64]*batchedCopTask, b.limit) - // disable paging for batched task. 
- b.tasks[idx].paging = false - b.tasks[idx].pagingSize = 0 - } - if task.RowCountHint > 0 { - b.tasks[idx].RowCountHint += task.RowCountHint - } - b.tasks[idx].batchTaskList[task.taskID] = batchedTask - } - handled = true - return nil -} - -func (b *batchStoreTaskBuilder) regionNum() int { - // we allocate b.taskID for each region task, so the final b.taskID is equal to the related region number. - return int(b.taskID) -} - -func (b *batchStoreTaskBuilder) reverse() { - reverseTasks(b.tasks) -} - -func (b *batchStoreTaskBuilder) build() []*copTask { - return b.tasks -} - -func buildTiDBMemCopTasks(ranges *KeyRanges, req *kv.Request) ([]*copTask, error) { - servers, err := infosync.GetAllServerInfo(context.Background()) - if err != nil { - return nil, err - } - cmdType := tikvrpc.CmdCop - tasks := make([]*copTask, 0, len(servers)) - for _, ser := range servers { - if req.TiDBServerID > 0 && req.TiDBServerID != ser.ServerIDGetter() { - continue - } - - addr := net.JoinHostPort(ser.IP, strconv.FormatUint(uint64(ser.StatusPort), 10)) - tasks = append(tasks, &copTask{ - ranges: ranges, - respChan: make(chan *copResponse, 2), - cmdType: cmdType, - storeType: req.StoreType, - storeAddr: addr, - RowCountHint: -1, - }) - } - return tasks, nil -} - -func reverseTasks(tasks []*copTask) { - for i := 0; i < len(tasks)/2; i++ { - j := len(tasks) - i - 1 - tasks[i], tasks[j] = tasks[j], tasks[i] - } -} - -func isSmallTask(task *copTask) bool { - // strictly, only RowCountHint == -1 stands for unknown task rows, - // but when RowCountHint == 0, it may be caused by initialized value, - // to avoid the future bugs, let the tasks with RowCountHint == 0 be non-small tasks. - return task.RowCountHint > 0 && - (len(task.batchTaskList) == 0 && task.RowCountHint <= CopSmallTaskRow) || - (len(task.batchTaskList) > 0 && task.RowCountHint <= 2*CopSmallTaskRow) -} - -// smallTaskConcurrency counts the small tasks of tasks, -// then returns the task count and extra concurrency for small tasks. -func smallTaskConcurrency(tasks []*copTask, numcpu int) (int, int) { - res := 0 - for _, task := range tasks { - if isSmallTask(task) { - res++ - } - } - if res == 0 { - return 0, 0 - } - // Calculate the extra concurrency for small tasks - // extra concurrency = tasks / (1 + sigma * sqrt(log(tasks ^ 2))) - extraConc := int(float64(res) / (1 + smallTaskSigma*math.Sqrt(2*math.Log(float64(res))))) - if numcpu <= 0 { - numcpu = 1 - } - smallTaskConcurrencyLimit := smallConcPerCore * numcpu - if extraConc > smallTaskConcurrencyLimit { - extraConc = smallTaskConcurrencyLimit - } - return res, extraConc -} - -// CopInfo is used to expose functions of copIterator. -type CopInfo interface { - // GetConcurrency returns the concurrency and small task concurrency. - GetConcurrency() (int, int) - // GetStoreBatchInfo returns the batched and fallback num. - GetStoreBatchInfo() (uint64, uint64) - // GetBuildTaskElapsed returns the duration of building task. - GetBuildTaskElapsed() time.Duration -} - -type copIterator struct { - store *Store - req *kv.Request - concurrency int - smallTaskConcurrency int - finishCh chan struct{} - - // If keepOrder, results are stored in copTask.respChan, read them out one by one. - tasks []*copTask - // curr indicates the curr id of the finished copTask - curr int - - // sendRate controls the sending rate of copIteratorTaskSender - sendRate *util.RateLimit - - // Otherwise, results are stored in respChan. 
- respChan chan *copResponse - - vars *tikv.Variables - - memTracker *memory.Tracker - - replicaReadSeed uint32 - - rpcCancel *tikv.RPCCanceller - - wg sync.WaitGroup - // closed represents when the Close is called. - // There are two cases we need to close the `finishCh` channel, one is when context is done, the other one is - // when the Close is called. we use atomic.CompareAndSwap `closed` to make sure the channel is not closed twice. - closed uint32 - - resolvedLocks util.TSSet - committedLocks util.TSSet - - actionOnExceed *rateLimitAction - pagingTaskIdx uint32 - - buildTaskElapsed time.Duration - storeBatchedNum atomic.Uint64 - storeBatchedFallbackNum atomic.Uint64 - - runawayChecker *resourcegroup.RunawayChecker - unconsumedStats *unconsumedCopRuntimeStats -} - -// copIteratorWorker receives tasks from copIteratorTaskSender, handles tasks and sends the copResponse to respChan. -type copIteratorWorker struct { - taskCh <-chan *copTask - wg *sync.WaitGroup - store *Store - req *kv.Request - respChan chan<- *copResponse - finishCh <-chan struct{} - vars *tikv.Variables - kvclient *txnsnapshot.ClientHelper - - memTracker *memory.Tracker - - replicaReadSeed uint32 - - enableCollectExecutionInfo bool - pagingTaskIdx *uint32 - - storeBatchedNum *atomic.Uint64 - storeBatchedFallbackNum *atomic.Uint64 - unconsumedStats *unconsumedCopRuntimeStats -} - -// copIteratorTaskSender sends tasks to taskCh then wait for the workers to exit. -type copIteratorTaskSender struct { - taskCh chan<- *copTask - smallTaskCh chan<- *copTask - wg *sync.WaitGroup - tasks []*copTask - finishCh <-chan struct{} - respChan chan<- *copResponse - sendRate *util.RateLimit -} - -type copResponse struct { - pbResp *coprocessor.Response - detail *CopRuntimeStats - startKey kv.Key - err error - respSize int64 - respTime time.Duration -} - -const sizeofExecDetails = int(unsafe.Sizeof(execdetails.ExecDetails{})) - -// GetData implements the kv.ResultSubset GetData interface. -func (rs *copResponse) GetData() []byte { - return rs.pbResp.Data -} - -// GetStartKey implements the kv.ResultSubset GetStartKey interface. -func (rs *copResponse) GetStartKey() kv.Key { - return rs.startKey -} - -func (rs *copResponse) GetCopRuntimeStats() *CopRuntimeStats { - return rs.detail -} - -// MemSize returns how many bytes of memory this response use -func (rs *copResponse) MemSize() int64 { - if rs.respSize != 0 { - return rs.respSize - } - if rs == finCopResp { - return 0 - } - - // ignore rs.err - rs.respSize += int64(cap(rs.startKey)) - if rs.detail != nil { - rs.respSize += int64(sizeofExecDetails) - } - if rs.pbResp != nil { - // Using a approximate size since it's hard to get a accurate value. - rs.respSize += int64(rs.pbResp.Size()) - } - return rs.respSize -} - -func (rs *copResponse) RespTime() time.Duration { - return rs.respTime -} - -const minLogCopTaskTime = 300 * time.Millisecond - -// When the worker finished `handleTask`, we need to notify the copIterator that there is one task finished. -// For the non-keep-order case, we send a finCopResp into the respCh after `handleTask`. When copIterator recv -// finCopResp from the respCh, it will be aware that there is one task finished. -var finCopResp *copResponse - -func init() { - finCopResp = &copResponse{} -} - -// run is a worker function that get a copTask from channel, handle it and -// send the result back. 
-func (worker *copIteratorWorker) run(ctx context.Context) { - defer func() { - failpoint.Inject("ticase-4169", func(val failpoint.Value) { - if val.(bool) { - worker.memTracker.Consume(10 * MockResponseSizeForTest) - worker.memTracker.Consume(10 * MockResponseSizeForTest) - } - }) - worker.wg.Done() - }() - // 16KB ballast helps grow the stack to the requirement of copIteratorWorker. - // This reduces the `morestack` call during the execution of `handleTask`, thus improvement the efficiency of TiDB. - // TODO: remove ballast after global pool is applied. - ballast := make([]byte, 16*size.KB) - for task := range worker.taskCh { - respCh := worker.respChan - if respCh == nil { - respCh = task.respChan - } - worker.handleTask(ctx, task, respCh) - if worker.respChan != nil { - // When a task is finished by the worker, send a finCopResp into channel to notify the copIterator that - // there is a task finished. - worker.sendToRespCh(finCopResp, worker.respChan, false) - } - if task.respChan != nil { - close(task.respChan) - } - if worker.finished() { - return - } - } - runtime.KeepAlive(ballast) -} - -// open starts workers and sender goroutines. -func (it *copIterator) open(ctx context.Context, enabledRateLimitAction, enableCollectExecutionInfo bool) { - taskCh := make(chan *copTask, 1) - smallTaskCh := make(chan *copTask, 1) - it.unconsumedStats = &unconsumedCopRuntimeStats{} - it.wg.Add(it.concurrency + it.smallTaskConcurrency) - // Start it.concurrency number of workers to handle cop requests. - for i := 0; i < it.concurrency+it.smallTaskConcurrency; i++ { - var ch chan *copTask - if i < it.concurrency { - ch = taskCh - } else { - ch = smallTaskCh - } - worker := &copIteratorWorker{ - taskCh: ch, - wg: &it.wg, - store: it.store, - req: it.req, - respChan: it.respChan, - finishCh: it.finishCh, - vars: it.vars, - kvclient: txnsnapshot.NewClientHelper(it.store.store, &it.resolvedLocks, &it.committedLocks, false), - memTracker: it.memTracker, - replicaReadSeed: it.replicaReadSeed, - enableCollectExecutionInfo: enableCollectExecutionInfo, - pagingTaskIdx: &it.pagingTaskIdx, - storeBatchedNum: &it.storeBatchedNum, - storeBatchedFallbackNum: &it.storeBatchedFallbackNum, - unconsumedStats: it.unconsumedStats, - } - go worker.run(ctx) - } - taskSender := &copIteratorTaskSender{ - taskCh: taskCh, - smallTaskCh: smallTaskCh, - wg: &it.wg, - tasks: it.tasks, - finishCh: it.finishCh, - sendRate: it.sendRate, - } - taskSender.respChan = it.respChan - it.actionOnExceed.setEnabled(enabledRateLimitAction) - failpoint.Inject("ticase-4171", func(val failpoint.Value) { - if val.(bool) { - it.memTracker.Consume(10 * MockResponseSizeForTest) - it.memTracker.Consume(10 * MockResponseSizeForTest) - } - }) - go taskSender.run(it.req.ConnID) -} - -func (sender *copIteratorTaskSender) run(connID uint64) { - // Send tasks to feed the worker goroutines. - for _, t := range sender.tasks { - // we control the sending rate to prevent all tasks - // being done (aka. all of the responses are buffered) by copIteratorWorker. - // We keep the number of inflight tasks within the number of 2 * concurrency when Keep Order is true. - // If KeepOrder is false, the number equals the concurrency. - // It sends one more task if a task has been finished in copIterator.Next. 
- exit := sender.sendRate.GetToken(sender.finishCh) - if exit { - break - } - var sendTo chan<- *copTask - if isSmallTask(t) { - sendTo = sender.smallTaskCh - } else { - sendTo = sender.taskCh - } - exit = sender.sendToTaskCh(t, sendTo) - if exit { - break - } - if connID > 0 { - failpoint.Inject("pauseCopIterTaskSender", func() {}) - } - } - close(sender.taskCh) - close(sender.smallTaskCh) - - // Wait for worker goroutines to exit. - sender.wg.Wait() - if sender.respChan != nil { - close(sender.respChan) - } -} - -func (it *copIterator) recvFromRespCh(ctx context.Context, respCh <-chan *copResponse) (resp *copResponse, ok bool, exit bool) { - failpoint.InjectCall("CtxCancelBeforeReceive", ctx) - ticker := time.NewTicker(3 * time.Second) - defer ticker.Stop() - for { - select { - case resp, ok = <-respCh: - if it.memTracker != nil && resp != nil { - consumed := resp.MemSize() - failpoint.Inject("testRateLimitActionMockConsumeAndAssert", func(val failpoint.Value) { - if val.(bool) { - if resp != finCopResp { - consumed = MockResponseSizeForTest - } - } - }) - it.memTracker.Consume(-consumed) - } - return - case <-it.finishCh: - exit = true - return - case <-ticker.C: - killed := atomic.LoadUint32(it.vars.Killed) - if killed != 0 { - logutil.Logger(ctx).Info( - "a killed signal is received", - zap.Uint32("signal", killed), - ) - resp = &copResponse{err: derr.ErrQueryInterrupted} - ok = true - return - } - case <-ctx.Done(): - // We select the ctx.Done() in the thread of `Next` instead of in the worker to avoid the cost of `WithCancel`. - if atomic.CompareAndSwapUint32(&it.closed, 0, 1) { - close(it.finishCh) - } - exit = true - return - } - } -} - -// GetConcurrency returns the concurrency and small task concurrency. -func (it *copIterator) GetConcurrency() (int, int) { - return it.concurrency, it.smallTaskConcurrency -} - -// GetStoreBatchInfo returns the batched and fallback num. -func (it *copIterator) GetStoreBatchInfo() (uint64, uint64) { - return it.storeBatchedNum.Load(), it.storeBatchedFallbackNum.Load() -} - -// GetBuildTaskElapsed returns the duration of building task. -func (it *copIterator) GetBuildTaskElapsed() time.Duration { - return it.buildTaskElapsed -} - -// GetSendRate returns the rate-limit object. -func (it *copIterator) GetSendRate() *util.RateLimit { - return it.sendRate -} - -// GetTasks returns the built tasks. -func (it *copIterator) GetTasks() []*copTask { - return it.tasks -} - -func (sender *copIteratorTaskSender) sendToTaskCh(t *copTask, sendTo chan<- *copTask) (exit bool) { - select { - case sendTo <- t: - case <-sender.finishCh: - exit = true - } - return -} - -func (worker *copIteratorWorker) sendToRespCh(resp *copResponse, respCh chan<- *copResponse, checkOOM bool) (exit bool) { - if worker.memTracker != nil && checkOOM { - consumed := resp.MemSize() - failpoint.Inject("testRateLimitActionMockConsumeAndAssert", func(val failpoint.Value) { - if val.(bool) { - if resp != finCopResp { - consumed = MockResponseSizeForTest - } - } - }) - failpoint.Inject("ConsumeRandomPanic", nil) - worker.memTracker.Consume(consumed) - } - select { - case respCh <- resp: - case <-worker.finishCh: - exit = true - } - return -} - -// MockResponseSizeForTest mock the response size -const MockResponseSizeForTest = 100 * 1024 * 1024 - -// Next returns next coprocessor result. -// NOTE: Use nil to indicate finish, so if the returned ResultSubset is not nil, reader should continue to call Next(). 
-func (it *copIterator) Next(ctx context.Context) (kv.ResultSubset, error) { - var ( - resp *copResponse - ok bool - closed bool - ) - defer func() { - if resp == nil { - failpoint.Inject("ticase-4170", func(val failpoint.Value) { - if val.(bool) { - it.memTracker.Consume(10 * MockResponseSizeForTest) - it.memTracker.Consume(10 * MockResponseSizeForTest) - } - }) - } - }() - // wait unit at least 5 copResponse received. - failpoint.Inject("testRateLimitActionMockWaitMax", func(val failpoint.Value) { - if val.(bool) { - // we only need to trigger oom at least once. - if len(it.tasks) > 9 { - for it.memTracker.MaxConsumed() < 5*MockResponseSizeForTest { - time.Sleep(10 * time.Millisecond) - } - } - } - }) - // If data order matters, response should be returned in the same order as copTask slice. - // Otherwise all responses are returned from a single channel. - if it.respChan != nil { - // Get next fetched resp from chan - resp, ok, closed = it.recvFromRespCh(ctx, it.respChan) - if !ok || closed { - it.actionOnExceed.close() - return nil, errors.Trace(ctx.Err()) - } - if resp == finCopResp { - it.actionOnExceed.destroyTokenIfNeeded(func() { - it.sendRate.PutToken() - }) - return it.Next(ctx) - } - } else { - for { - if it.curr >= len(it.tasks) { - // Resp will be nil if iterator is finishCh. - it.actionOnExceed.close() - return nil, nil - } - task := it.tasks[it.curr] - resp, ok, closed = it.recvFromRespCh(ctx, task.respChan) - if closed { - // Close() is called or context cancelled/timeout, so Next() is invalid. - return nil, errors.Trace(ctx.Err()) - } - if ok { - break - } - it.actionOnExceed.destroyTokenIfNeeded(func() { - it.sendRate.PutToken() - }) - // Switch to next task. - it.tasks[it.curr] = nil - it.curr++ - } - } - - if resp.err != nil { - return nil, errors.Trace(resp.err) - } - - err := it.store.CheckVisibility(it.req.StartTs) - if err != nil { - return nil, errors.Trace(err) - } - return resp, nil -} - -// HasUnconsumedCopRuntimeStats indicate whether has unconsumed CopRuntimeStats. -type HasUnconsumedCopRuntimeStats interface { - // CollectUnconsumedCopRuntimeStats returns unconsumed CopRuntimeStats. - CollectUnconsumedCopRuntimeStats() []*CopRuntimeStats -} - -func (it *copIterator) CollectUnconsumedCopRuntimeStats() []*CopRuntimeStats { - if it == nil || it.unconsumedStats == nil { - return nil - } - it.unconsumedStats.Lock() - stats := make([]*CopRuntimeStats, 0, len(it.unconsumedStats.stats)) - stats = append(stats, it.unconsumedStats.stats...) - it.unconsumedStats.Unlock() - return stats -} - -// Associate each region with an independent backoffer. In this way, when multiple regions are -// unavailable, TiDB can execute very quickly without blocking -func chooseBackoffer(ctx context.Context, backoffermap map[uint64]*Backoffer, task *copTask, worker *copIteratorWorker) *Backoffer { - bo, ok := backoffermap[task.region.GetID()] - if ok { - return bo - } - boMaxSleep := CopNextMaxBackoff - failpoint.Inject("ReduceCopNextMaxBackoff", func(value failpoint.Value) { - if value.(bool) { - boMaxSleep = 2 - } - }) - newbo := backoff.NewBackofferWithVars(ctx, boMaxSleep, worker.vars) - backoffermap[task.region.GetID()] = newbo - return newbo -} - -// handleTask handles single copTask, sends the result to channel, retry automatically on error. 
-func (worker *copIteratorWorker) handleTask(ctx context.Context, task *copTask, respCh chan<- *copResponse) { - defer func() { - r := recover() - if r != nil { - logutil.BgLogger().Error("copIteratorWork meet panic", - zap.Any("r", r), - zap.Stack("stack trace")) - resp := &copResponse{err: util2.GetRecoverError(r)} - // if panic has happened, set checkOOM to false to avoid another panic. - worker.sendToRespCh(resp, respCh, false) - } - }() - remainTasks := []*copTask{task} - backoffermap := make(map[uint64]*Backoffer) - for len(remainTasks) > 0 { - curTask := remainTasks[0] - bo := chooseBackoffer(ctx, backoffermap, curTask, worker) - tasks, err := worker.handleTaskOnce(bo, curTask, respCh) - if err != nil { - resp := &copResponse{err: errors.Trace(err)} - worker.sendToRespCh(resp, respCh, true) - return - } - if worker.finished() { - break - } - if len(tasks) > 0 { - remainTasks = append(tasks, remainTasks[1:]...) - } else { - remainTasks = remainTasks[1:] - } - } -} - -// handleTaskOnce handles single copTask, successful results are send to channel. -// If error happened, returns error. If region split or meet lock, returns the remain tasks. -func (worker *copIteratorWorker) handleTaskOnce(bo *Backoffer, task *copTask, ch chan<- *copResponse) ([]*copTask, error) { - failpoint.Inject("handleTaskOnceError", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(nil, errors.New("mock handleTaskOnce error")) - } - }) - - if task.paging { - task.pagingTaskIdx = atomic.AddUint32(worker.pagingTaskIdx, 1) - } - - copReq := coprocessor.Request{ - Tp: worker.req.Tp, - StartTs: worker.req.StartTs, - Data: worker.req.Data, - Ranges: task.ranges.ToPBRanges(), - SchemaVer: worker.req.SchemaVar, - PagingSize: task.pagingSize, - Tasks: task.ToPBBatchTasks(), - ConnectionId: worker.req.ConnID, - ConnectionAlias: worker.req.ConnAlias, - } - - cacheKey, cacheValue := worker.buildCacheKey(task, &copReq) - - replicaRead := worker.req.ReplicaRead - rgName := worker.req.ResourceGroupName - if task.storeType == kv.TiFlash && !variable.EnableResourceControl.Load() { - // By calling variable.EnableGlobalResourceControlFunc() and setting global variables, - // tikv/client-go can sense whether the rg function is enabled - // But for tiflash, it check if rgName is empty to decide if resource control is enabled or not. 
- rgName = "" - } - req := tikvrpc.NewReplicaReadRequest(task.cmdType, &copReq, options.GetTiKVReplicaReadType(replicaRead), &worker.replicaReadSeed, kvrpcpb.Context{ - IsolationLevel: isolationLevelToPB(worker.req.IsolationLevel), - Priority: priorityToPB(worker.req.Priority), - NotFillCache: worker.req.NotFillCache, - RecordTimeStat: true, - RecordScanStat: true, - TaskId: worker.req.TaskID, - ResourceControlContext: &kvrpcpb.ResourceControlContext{ - ResourceGroupName: rgName, - }, - BusyThresholdMs: uint32(task.busyThreshold.Milliseconds()), - BucketsVersion: task.bucketsVer, - }) - req.InputRequestSource = task.requestSource.GetRequestSource() - if task.firstReadType != "" { - req.ReadType = task.firstReadType - req.IsRetryRequest = true - } - if worker.req.ResourceGroupTagger != nil { - worker.req.ResourceGroupTagger(req) - } - timeout := config.GetGlobalConfig().TiKVClient.CoprReqTimeout - if task.tikvClientReadTimeout > 0 { - timeout = time.Duration(task.tikvClientReadTimeout) * time.Millisecond - } - failpoint.Inject("sleepCoprRequest", func(v failpoint.Value) { - //nolint:durationcheck - time.Sleep(time.Millisecond * time.Duration(v.(int))) - }) - - if worker.req.RunawayChecker != nil { - if err := worker.req.RunawayChecker.BeforeCopRequest(req); err != nil { - return nil, err - } - } - req.StoreTp = getEndPointType(task.storeType) - startTime := time.Now() - if worker.kvclient.Stats == nil { - worker.kvclient.Stats = tikv.NewRegionRequestRuntimeStats() - } - // set ReadReplicaScope and TxnScope so that req.IsStaleRead will be true when it's a global scope stale read. - req.ReadReplicaScope = worker.req.ReadReplicaScope - req.TxnScope = worker.req.TxnScope - if task.meetLockFallback { - req.DisableStaleReadMeetLock() - } else if worker.req.IsStaleness { - req.EnableStaleWithMixedReplicaRead() - } - staleRead := req.GetStaleRead() - ops := make([]tikv.StoreSelectorOption, 0, 2) - if len(worker.req.MatchStoreLabels) > 0 { - ops = append(ops, tikv.WithMatchLabels(worker.req.MatchStoreLabels)) - } - if task.redirect2Replica != nil { - req.ReplicaRead = true - req.ReplicaReadType = options.GetTiKVReplicaReadType(kv.ReplicaReadFollower) - ops = append(ops, tikv.WithMatchStores([]uint64{*task.redirect2Replica})) - } - resp, rpcCtx, storeAddr, err := worker.kvclient.SendReqCtx(bo.TiKVBackoffer(), req, task.region, - timeout, getEndPointType(task.storeType), task.storeAddr, ops...) - err = derr.ToTiDBErr(err) - if worker.req.RunawayChecker != nil { - failpoint.Inject("sleepCoprAfterReq", func(v failpoint.Value) { - //nolint:durationcheck - value := v.(int) - time.Sleep(time.Millisecond * time.Duration(value)) - if value > 50 { - err = errors.Errorf("Coprocessor task terminated due to exceeding the deadline") - } - }) - err = worker.req.RunawayChecker.CheckCopRespError(err) - } - if err != nil { - if task.storeType == kv.TiDB { - err = worker.handleTiDBSendReqErr(err, task, ch) - return nil, err - } - worker.collectUnconsumedCopRuntimeStats(bo, rpcCtx) - return nil, errors.Trace(err) - } - - // Set task.storeAddr field so its task.String() method have the store address information. 
- task.storeAddr = storeAddr - - costTime := time.Since(startTime) - copResp := resp.Resp.(*coprocessor.Response) - - if costTime > minLogCopTaskTime { - worker.logTimeCopTask(costTime, task, bo, copResp) - } - - storeID := strconv.FormatUint(req.Context.GetPeer().GetStoreId(), 10) - isInternal := util.IsRequestSourceInternal(&task.requestSource) - scope := metrics.LblGeneral - if isInternal { - scope = metrics.LblInternal - } - metrics.TiKVCoprocessorHistogram.WithLabelValues(storeID, strconv.FormatBool(staleRead), scope).Observe(costTime.Seconds()) - if copResp != nil { - tidbmetrics.DistSQLCoprRespBodySize.WithLabelValues(storeAddr).Observe(float64(len(copResp.Data))) - } - - var remains []*copTask - if worker.req.Paging.Enable { - remains, err = worker.handleCopPagingResult(bo, rpcCtx, &copResponse{pbResp: copResp}, cacheKey, cacheValue, task, ch, costTime) - } else { - // Handles the response for non-paging copTask. - remains, err = worker.handleCopResponse(bo, rpcCtx, &copResponse{pbResp: copResp}, cacheKey, cacheValue, task, ch, nil, costTime) - } - if req.ReadType != "" { - for _, remain := range remains { - remain.firstReadType = req.ReadType - } - } - return remains, err -} - -const ( - minLogBackoffTime = 100 - minLogKVProcessTime = 100 -) - -func (worker *copIteratorWorker) logTimeCopTask(costTime time.Duration, task *copTask, bo *Backoffer, resp *coprocessor.Response) { - logStr := fmt.Sprintf("[TIME_COP_PROCESS] resp_time:%s txnStartTS:%d region_id:%d store_addr:%s", costTime, worker.req.StartTs, task.region.GetID(), task.storeAddr) - if worker.kvclient.Stats != nil { - logStr += fmt.Sprintf(" stats:%s", worker.kvclient.Stats.String()) - } - if bo.GetTotalSleep() > minLogBackoffTime { - backoffTypes := strings.ReplaceAll(fmt.Sprintf("%v", bo.TiKVBackoffer().GetTypes()), " ", ",") - logStr += fmt.Sprintf(" backoff_ms:%d backoff_types:%s", bo.GetTotalSleep(), backoffTypes) - } - if regionErr := resp.GetRegionError(); regionErr != nil { - logStr += fmt.Sprintf(" region_err:%s", regionErr.String()) - } - // resp might be nil, but it is safe to call resp.GetXXX here. 
- detailV2 := resp.GetExecDetailsV2() - detail := resp.GetExecDetails() - var timeDetail *kvrpcpb.TimeDetail - if detailV2 != nil && detailV2.TimeDetail != nil { - timeDetail = detailV2.TimeDetail - } else if detail != nil && detail.TimeDetail != nil { - timeDetail = detail.TimeDetail - } - if timeDetail != nil { - logStr += fmt.Sprintf(" kv_process_ms:%d", timeDetail.ProcessWallTimeMs) - logStr += fmt.Sprintf(" kv_wait_ms:%d", timeDetail.WaitWallTimeMs) - logStr += fmt.Sprintf(" kv_read_ms:%d", timeDetail.KvReadWallTimeMs) - if timeDetail.ProcessWallTimeMs <= minLogKVProcessTime { - logStr = strings.Replace(logStr, "TIME_COP_PROCESS", "TIME_COP_WAIT", 1) - } - } - - if detailV2 != nil && detailV2.ScanDetailV2 != nil { - logStr += fmt.Sprintf(" processed_versions:%d", detailV2.ScanDetailV2.ProcessedVersions) - logStr += fmt.Sprintf(" total_versions:%d", detailV2.ScanDetailV2.TotalVersions) - logStr += fmt.Sprintf(" rocksdb_delete_skipped_count:%d", detailV2.ScanDetailV2.RocksdbDeleteSkippedCount) - logStr += fmt.Sprintf(" rocksdb_key_skipped_count:%d", detailV2.ScanDetailV2.RocksdbKeySkippedCount) - logStr += fmt.Sprintf(" rocksdb_cache_hit_count:%d", detailV2.ScanDetailV2.RocksdbBlockCacheHitCount) - logStr += fmt.Sprintf(" rocksdb_read_count:%d", detailV2.ScanDetailV2.RocksdbBlockReadCount) - logStr += fmt.Sprintf(" rocksdb_read_byte:%d", detailV2.ScanDetailV2.RocksdbBlockReadByte) - } else if detail != nil && detail.ScanDetail != nil { - logStr = appendScanDetail(logStr, "write", detail.ScanDetail.Write) - logStr = appendScanDetail(logStr, "data", detail.ScanDetail.Data) - logStr = appendScanDetail(logStr, "lock", detail.ScanDetail.Lock) - } - logutil.Logger(bo.GetCtx()).Info(logStr) -} - -func appendScanDetail(logStr string, columnFamily string, scanInfo *kvrpcpb.ScanInfo) string { - if scanInfo != nil { - logStr += fmt.Sprintf(" scan_total_%s:%d", columnFamily, scanInfo.Total) - logStr += fmt.Sprintf(" scan_processed_%s:%d", columnFamily, scanInfo.Processed) - } - return logStr -} - -func (worker *copIteratorWorker) handleCopPagingResult(bo *Backoffer, rpcCtx *tikv.RPCContext, resp *copResponse, cacheKey []byte, cacheValue *coprCacheValue, task *copTask, ch chan<- *copResponse, costTime time.Duration) ([]*copTask, error) { - remainedTasks, err := worker.handleCopResponse(bo, rpcCtx, resp, cacheKey, cacheValue, task, ch, nil, costTime) - if err != nil || len(remainedTasks) != 0 { - // If there is region error or lock error, keep the paging size and retry. - for _, remainedTask := range remainedTasks { - remainedTask.pagingSize = task.pagingSize - } - return remainedTasks, errors.Trace(err) - } - pagingRange := resp.pbResp.Range - // only paging requests need to calculate the next ranges - if pagingRange == nil { - // If the storage engine doesn't support paging protocol, it should have return all the region data. - // So we finish here. - return nil, nil - } - - // calculate next ranges and grow the paging size - task.ranges = worker.calculateRemain(task.ranges, pagingRange, worker.req.Desc) - if task.ranges.Len() == 0 { - return nil, nil - } - - task.pagingSize = paging.GrowPagingSize(task.pagingSize, worker.req.Paging.MaxPagingSize) - return []*copTask{task}, nil -} - -// handleCopResponse checks coprocessor Response for region split and lock, -// returns more tasks when that happens, or handles the response if no error. -// if we're handling coprocessor paging response, lastRange is the range of last -// successful response, otherwise it's nil. 
-func (worker *copIteratorWorker) handleCopResponse(bo *Backoffer, rpcCtx *tikv.RPCContext, resp *copResponse, cacheKey []byte, cacheValue *coprCacheValue, task *copTask, ch chan<- *copResponse, lastRange *coprocessor.KeyRange, costTime time.Duration) ([]*copTask, error) { - if ver := resp.pbResp.GetLatestBucketsVersion(); task.bucketsVer < ver { - worker.store.GetRegionCache().UpdateBucketsIfNeeded(task.region, ver) - } - if regionErr := resp.pbResp.GetRegionError(); regionErr != nil { - if rpcCtx != nil && task.storeType == kv.TiDB { - resp.err = errors.Errorf("error: %v", regionErr) - worker.sendToRespCh(resp, ch, true) - return nil, nil - } - errStr := fmt.Sprintf("region_id:%v, region_ver:%v, store_type:%s, peer_addr:%s, error:%s", - task.region.GetID(), task.region.GetVer(), task.storeType.Name(), task.storeAddr, regionErr.String()) - if err := bo.Backoff(tikv.BoRegionMiss(), errors.New(errStr)); err != nil { - return nil, errors.Trace(err) - } - // We may meet RegionError at the first packet, but not during visiting the stream. - remains, err := buildCopTasks(bo, task.ranges, &buildCopTaskOpt{ - req: worker.req, - cache: worker.store.GetRegionCache(), - respChan: false, - eventCb: task.eventCb, - ignoreTiKVClientReadTimeout: true, - }) - if err != nil { - return remains, err - } - return worker.handleBatchRemainsOnErr(bo, rpcCtx, remains, resp.pbResp, task, ch) - } - if lockErr := resp.pbResp.GetLocked(); lockErr != nil { - if err := worker.handleLockErr(bo, lockErr, task); err != nil { - return nil, err - } - task.meetLockFallback = true - return worker.handleBatchRemainsOnErr(bo, rpcCtx, []*copTask{task}, resp.pbResp, task, ch) - } - if otherErr := resp.pbResp.GetOtherError(); otherErr != "" { - err := errors.Errorf("other error: %s", otherErr) - - firstRangeStartKey := task.ranges.At(0).StartKey - lastRangeEndKey := task.ranges.At(task.ranges.Len() - 1).EndKey - - logutil.Logger(bo.GetCtx()).Warn("other error", - zap.Uint64("txnStartTS", worker.req.StartTs), - zap.Uint64("regionID", task.region.GetID()), - zap.Uint64("regionVer", task.region.GetVer()), - zap.Uint64("regionConfVer", task.region.GetConfVer()), - zap.Uint64("bucketsVer", task.bucketsVer), - zap.Uint64("latestBucketsVer", resp.pbResp.GetLatestBucketsVersion()), - zap.Int("rangeNums", task.ranges.Len()), - zap.ByteString("firstRangeStartKey", firstRangeStartKey), - zap.ByteString("lastRangeEndKey", lastRangeEndKey), - zap.String("storeAddr", task.storeAddr), - zap.Error(err)) - if strings.Contains(err.Error(), "write conflict") { - return nil, kv.ErrWriteConflict.FastGen("%s", otherErr) - } - return nil, errors.Trace(err) - } - // When the request is using paging API, the `Range` is not nil. 
- if resp.pbResp.Range != nil { - resp.startKey = resp.pbResp.Range.Start - } else if task.ranges != nil && task.ranges.Len() > 0 { - resp.startKey = task.ranges.At(0).StartKey - } - worker.handleCollectExecutionInfo(bo, rpcCtx, resp) - resp.respTime = costTime - - if err := worker.handleCopCache(task, resp, cacheKey, cacheValue); err != nil { - return nil, err - } - - pbResp := resp.pbResp - worker.sendToRespCh(resp, ch, true) - return worker.handleBatchCopResponse(bo, rpcCtx, pbResp, task.batchTaskList, ch) -} - -func (worker *copIteratorWorker) handleBatchRemainsOnErr(bo *Backoffer, rpcCtx *tikv.RPCContext, remains []*copTask, resp *coprocessor.Response, task *copTask, ch chan<- *copResponse) ([]*copTask, error) { - if len(task.batchTaskList) == 0 { - return remains, nil - } - batchedTasks := task.batchTaskList - task.batchTaskList = nil - batchedRemains, err := worker.handleBatchCopResponse(bo, rpcCtx, resp, batchedTasks, ch) - if err != nil { - return nil, err - } - return append(remains, batchedRemains...), nil -} - -// handle the batched cop response. -// tasks will be changed, so the input tasks should not be used after calling this function. -func (worker *copIteratorWorker) handleBatchCopResponse(bo *Backoffer, rpcCtx *tikv.RPCContext, resp *coprocessor.Response, - tasks map[uint64]*batchedCopTask, ch chan<- *copResponse) (remainTasks []*copTask, err error) { - if len(tasks) == 0 { - return nil, nil - } - batchedNum := len(tasks) - busyThresholdFallback := false - defer func() { - if err != nil { - return - } - if !busyThresholdFallback { - worker.storeBatchedNum.Add(uint64(batchedNum - len(remainTasks))) - worker.storeBatchedFallbackNum.Add(uint64(len(remainTasks))) - } - }() - appendRemainTasks := func(tasks ...*copTask) { - if remainTasks == nil { - // allocate size of remain length - remainTasks = make([]*copTask, 0, len(tasks)) - } - remainTasks = append(remainTasks, tasks...) - } - // need Addr for recording details. - var dummyRPCCtx *tikv.RPCContext - if rpcCtx != nil { - dummyRPCCtx = &tikv.RPCContext{ - Addr: rpcCtx.Addr, - } - } - batchResps := resp.GetBatchResponses() - for _, batchResp := range batchResps { - taskID := batchResp.GetTaskId() - batchedTask, ok := tasks[taskID] - if !ok { - return nil, errors.Errorf("task id %d not found", batchResp.GetTaskId()) - } - delete(tasks, taskID) - resp := &copResponse{ - pbResp: &coprocessor.Response{ - Data: batchResp.Data, - ExecDetailsV2: batchResp.ExecDetailsV2, - }, - } - task := batchedTask.task - failpoint.Inject("batchCopRegionError", func() { - batchResp.RegionError = &errorpb.Error{} - }) - if regionErr := batchResp.GetRegionError(); regionErr != nil { - errStr := fmt.Sprintf("region_id:%v, region_ver:%v, store_type:%s, peer_addr:%s, error:%s", - task.region.GetID(), task.region.GetVer(), task.storeType.Name(), task.storeAddr, regionErr.String()) - if err := bo.Backoff(tikv.BoRegionMiss(), errors.New(errStr)); err != nil { - return nil, errors.Trace(err) - } - remains, err := buildCopTasks(bo, task.ranges, &buildCopTaskOpt{ - req: worker.req, - cache: worker.store.GetRegionCache(), - respChan: false, - eventCb: task.eventCb, - ignoreTiKVClientReadTimeout: true, - }) - if err != nil { - return nil, err - } - appendRemainTasks(remains...) 
- continue - } - //TODO: handle locks in batch - if lockErr := batchResp.GetLocked(); lockErr != nil { - if err := worker.handleLockErr(bo, resp.pbResp.GetLocked(), task); err != nil { - return nil, err - } - task.meetLockFallback = true - appendRemainTasks(task) - continue - } - if otherErr := batchResp.GetOtherError(); otherErr != "" { - err := errors.Errorf("other error: %s", otherErr) - - firstRangeStartKey := task.ranges.At(0).StartKey - lastRangeEndKey := task.ranges.At(task.ranges.Len() - 1).EndKey - - logutil.Logger(bo.GetCtx()).Warn("other error", - zap.Uint64("txnStartTS", worker.req.StartTs), - zap.Uint64("regionID", task.region.GetID()), - zap.Uint64("regionVer", task.region.GetVer()), - zap.Uint64("regionConfVer", task.region.GetConfVer()), - zap.Uint64("bucketsVer", task.bucketsVer), - // TODO: add bucket version in log - //zap.Uint64("latestBucketsVer", batchResp.GetLatestBucketsVersion()), - zap.Int("rangeNums", task.ranges.Len()), - zap.ByteString("firstRangeStartKey", firstRangeStartKey), - zap.ByteString("lastRangeEndKey", lastRangeEndKey), - zap.String("storeAddr", task.storeAddr), - zap.Error(err)) - if strings.Contains(err.Error(), "write conflict") { - return nil, kv.ErrWriteConflict.FastGen("%s", otherErr) - } - return nil, errors.Trace(err) - } - worker.handleCollectExecutionInfo(bo, dummyRPCCtx, resp) - worker.sendToRespCh(resp, ch, true) - } - for _, t := range tasks { - task := t.task - // when the error is generated by client or a load-based server busy, - // response is empty by design, skip warning for this case. - if len(batchResps) != 0 { - firstRangeStartKey := task.ranges.At(0).StartKey - lastRangeEndKey := task.ranges.At(task.ranges.Len() - 1).EndKey - logutil.Logger(bo.GetCtx()).Error("response of batched task missing", - zap.Uint64("id", task.taskID), - zap.Uint64("txnStartTS", worker.req.StartTs), - zap.Uint64("regionID", task.region.GetID()), - zap.Uint64("regionVer", task.region.GetVer()), - zap.Uint64("regionConfVer", task.region.GetConfVer()), - zap.Uint64("bucketsVer", task.bucketsVer), - zap.Int("rangeNums", task.ranges.Len()), - zap.ByteString("firstRangeStartKey", firstRangeStartKey), - zap.ByteString("lastRangeEndKey", lastRangeEndKey), - zap.String("storeAddr", task.storeAddr)) - } - appendRemainTasks(t.task) - } - if regionErr := resp.GetRegionError(); regionErr != nil && regionErr.ServerIsBusy != nil && - regionErr.ServerIsBusy.EstimatedWaitMs > 0 && len(remainTasks) != 0 { - if len(batchResps) != 0 { - return nil, errors.New("store batched coprocessor with server is busy error shouldn't contain responses") - } - busyThresholdFallback = true - handler := newBatchTaskBuilder(bo, worker.req, worker.store.GetRegionCache(), kv.ReplicaReadFollower) - for _, task := range remainTasks { - // do not set busy threshold again. - task.busyThreshold = 0 - if err = handler.handle(task); err != nil { - return nil, err - } - } - remainTasks = handler.build() - } - return remainTasks, nil -} - -func (worker *copIteratorWorker) handleLockErr(bo *Backoffer, lockErr *kvrpcpb.LockInfo, task *copTask) error { - if lockErr == nil { - return nil - } - resolveLockDetail := worker.getLockResolverDetails() - // Be care that we didn't redact the SQL statement because the log is DEBUG level. 
- if task.eventCb != nil { - task.eventCb(trxevents.WrapCopMeetLock(&trxevents.CopMeetLock{ - LockInfo: lockErr, - })) - } else { - logutil.Logger(bo.GetCtx()).Debug("coprocessor encounters lock", - zap.Stringer("lock", lockErr)) - } - resolveLocksOpts := txnlock.ResolveLocksOptions{ - CallerStartTS: worker.req.StartTs, - Locks: []*txnlock.Lock{txnlock.NewLock(lockErr)}, - Detail: resolveLockDetail, - } - resolveLocksRes, err1 := worker.kvclient.ResolveLocksWithOpts(bo.TiKVBackoffer(), resolveLocksOpts) - err1 = derr.ToTiDBErr(err1) - if err1 != nil { - return errors.Trace(err1) - } - msBeforeExpired := resolveLocksRes.TTL - if msBeforeExpired > 0 { - if err := bo.BackoffWithMaxSleepTxnLockFast(int(msBeforeExpired), errors.New(lockErr.String())); err != nil { - return errors.Trace(err) - } - } - return nil -} - -func (worker *copIteratorWorker) buildCacheKey(task *copTask, copReq *coprocessor.Request) (cacheKey []byte, cacheValue *coprCacheValue) { - // If there are many ranges, it is very likely to be a TableLookupRequest. They are not worth to cache since - // computing is not the main cost. Ignore requests with many ranges directly to avoid slowly building the cache key. - if task.cmdType == tikvrpc.CmdCop && worker.store.coprCache != nil && worker.req.Cacheable && worker.store.coprCache.CheckRequestAdmission(len(copReq.Ranges)) { - cKey, err := coprCacheBuildKey(copReq) - if err == nil { - cacheKey = cKey - cValue := worker.store.coprCache.Get(cKey) - copReq.IsCacheEnabled = true - - if cValue != nil && cValue.RegionID == task.region.GetID() && cValue.TimeStamp <= worker.req.StartTs { - // Append cache version to the request to skip Coprocessor computation if possible - // when request result is cached - copReq.CacheIfMatchVersion = cValue.RegionDataVersion - cacheValue = cValue - } else { - copReq.CacheIfMatchVersion = 0 - } - } else { - logutil.BgLogger().Warn("Failed to build copr cache key", zap.Error(err)) - } - } - return -} - -func (worker *copIteratorWorker) handleCopCache(task *copTask, resp *copResponse, cacheKey []byte, cacheValue *coprCacheValue) error { - if resp.pbResp.IsCacheHit { - if cacheValue == nil { - return errors.New("Internal error: received illegal TiKV response") - } - copr_metrics.CoprCacheCounterHit.Add(1) - // Cache hit and is valid: use cached data as response data and we don't update the cache. - data := make([]byte, len(cacheValue.Data)) - copy(data, cacheValue.Data) - resp.pbResp.Data = data - if worker.req.Paging.Enable { - var start, end []byte - if cacheValue.PageStart != nil { - start = make([]byte, len(cacheValue.PageStart)) - copy(start, cacheValue.PageStart) - } - if cacheValue.PageEnd != nil { - end = make([]byte, len(cacheValue.PageEnd)) - copy(end, cacheValue.PageEnd) - } - // When paging protocol is used, the response key range is part of the cache data. - if start != nil || end != nil { - resp.pbResp.Range = &coprocessor.KeyRange{ - Start: start, - End: end, - } - } else { - resp.pbResp.Range = nil - } - } - // `worker.enableCollectExecutionInfo` is loaded from the instance's config. Because it's not related to the request, - // the cache key can be same when `worker.enableCollectExecutionInfo` is true or false. - // When `worker.enableCollectExecutionInfo` is false, the `resp.detail` is nil, and hit cache is still possible. - // Check `resp.detail` to avoid panic. 
- // Details: https://github.com/pingcap/tidb/issues/48212 - if resp.detail != nil { - resp.detail.CoprCacheHit = true - } - return nil - } - copr_metrics.CoprCacheCounterMiss.Add(1) - // Cache not hit or cache hit but not valid: update the cache if the response can be cached. - if cacheKey != nil && resp.pbResp.CanBeCached && resp.pbResp.CacheLastVersion > 0 { - if resp.detail != nil { - if worker.store.coprCache.CheckResponseAdmission(resp.pbResp.Data.Size(), resp.detail.TimeDetail.ProcessTime, task.pagingTaskIdx) { - data := make([]byte, len(resp.pbResp.Data)) - copy(data, resp.pbResp.Data) - - newCacheValue := coprCacheValue{ - Data: data, - TimeStamp: worker.req.StartTs, - RegionID: task.region.GetID(), - RegionDataVersion: resp.pbResp.CacheLastVersion, - } - // When paging protocol is used, the response key range is part of the cache data. - if r := resp.pbResp.GetRange(); r != nil { - newCacheValue.PageStart = append([]byte{}, r.GetStart()...) - newCacheValue.PageEnd = append([]byte{}, r.GetEnd()...) - } - worker.store.coprCache.Set(cacheKey, &newCacheValue) - } - } - } - return nil -} - -func (worker *copIteratorWorker) getLockResolverDetails() *util.ResolveLockDetail { - if !worker.enableCollectExecutionInfo { - return nil - } - return &util.ResolveLockDetail{} -} - -func (worker *copIteratorWorker) handleCollectExecutionInfo(bo *Backoffer, rpcCtx *tikv.RPCContext, resp *copResponse) { - defer func() { - worker.kvclient.Stats = nil - }() - if !worker.enableCollectExecutionInfo { - return - } - failpoint.Inject("disable-collect-execution", func(val failpoint.Value) { - if val.(bool) { - panic("shouldn't reachable") - } - }) - if resp.detail == nil { - resp.detail = new(CopRuntimeStats) - } - worker.collectCopRuntimeStats(resp.detail, bo, rpcCtx, resp) -} - -func (worker *copIteratorWorker) collectCopRuntimeStats(copStats *CopRuntimeStats, bo *Backoffer, rpcCtx *tikv.RPCContext, resp *copResponse) { - copStats.ReqStats = worker.kvclient.Stats - backoffTimes := bo.GetBackoffTimes() - copStats.BackoffTime = time.Duration(bo.GetTotalSleep()) * time.Millisecond - copStats.BackoffSleep = make(map[string]time.Duration, len(backoffTimes)) - copStats.BackoffTimes = make(map[string]int, len(backoffTimes)) - for backoff := range backoffTimes { - copStats.BackoffTimes[backoff] = backoffTimes[backoff] - copStats.BackoffSleep[backoff] = time.Duration(bo.GetBackoffSleepMS()[backoff]) * time.Millisecond - } - if rpcCtx != nil { - copStats.CalleeAddress = rpcCtx.Addr - } - if resp == nil { - return - } - sd := &util.ScanDetail{} - td := util.TimeDetail{} - if pbDetails := resp.pbResp.ExecDetailsV2; pbDetails != nil { - // Take values in `ExecDetailsV2` first. 
- if pbDetails.TimeDetail != nil || pbDetails.TimeDetailV2 != nil { - td.MergeFromTimeDetail(pbDetails.TimeDetailV2, pbDetails.TimeDetail) - } - if scanDetailV2 := pbDetails.ScanDetailV2; scanDetailV2 != nil { - sd.MergeFromScanDetailV2(scanDetailV2) - } - } else if pbDetails := resp.pbResp.ExecDetails; pbDetails != nil { - if timeDetail := pbDetails.TimeDetail; timeDetail != nil { - td.MergeFromTimeDetail(nil, timeDetail) - } - if scanDetail := pbDetails.ScanDetail; scanDetail != nil { - if scanDetail.Write != nil { - sd.ProcessedKeys = scanDetail.Write.Processed - sd.TotalKeys = scanDetail.Write.Total - } - } - } - copStats.ScanDetail = sd - copStats.TimeDetail = td -} - -func (worker *copIteratorWorker) collectUnconsumedCopRuntimeStats(bo *Backoffer, rpcCtx *tikv.RPCContext) { - if worker.kvclient.Stats == nil { - return - } - copStats := &CopRuntimeStats{} - worker.collectCopRuntimeStats(copStats, bo, rpcCtx, nil) - worker.unconsumedStats.Lock() - worker.unconsumedStats.stats = append(worker.unconsumedStats.stats, copStats) - worker.unconsumedStats.Unlock() - worker.kvclient.Stats = nil -} - -// CopRuntimeStats contains execution detail information. -type CopRuntimeStats struct { - execdetails.ExecDetails - ReqStats *tikv.RegionRequestRuntimeStats - - CoprCacheHit bool -} - -type unconsumedCopRuntimeStats struct { - sync.Mutex - stats []*CopRuntimeStats -} - -func (worker *copIteratorWorker) handleTiDBSendReqErr(err error, task *copTask, ch chan<- *copResponse) error { - errCode := errno.ErrUnknown - errMsg := err.Error() - if terror.ErrorEqual(err, derr.ErrTiKVServerTimeout) { - errCode = errno.ErrTiKVServerTimeout - errMsg = "TiDB server timeout, address is " + task.storeAddr - } - if terror.ErrorEqual(err, derr.ErrTiFlashServerTimeout) { - errCode = errno.ErrTiFlashServerTimeout - errMsg = "TiDB server timeout, address is " + task.storeAddr - } - selResp := tipb.SelectResponse{ - Warnings: []*tipb.Error{ - { - Code: int32(errCode), - Msg: errMsg, - }, - }, - } - data, err := proto.Marshal(&selResp) - if err != nil { - return errors.Trace(err) - } - resp := &copResponse{ - pbResp: &coprocessor.Response{ - Data: data, - }, - detail: &CopRuntimeStats{}, - } - worker.sendToRespCh(resp, ch, true) - return nil -} - -// calculateRetry splits the input ranges into two, and take one of them according to desc flag. -// It's used in paging API, to calculate which range is consumed and what needs to be retry. -// For example: -// ranges: [r1 --> r2) [r3 --> r4) -// split: [s1 --> s2) -// In normal scan order, all data before s1 is consumed, so the retry ranges should be [s1 --> r2) [r3 --> r4) -// In reverse scan order, all data after s2 is consumed, so the retry ranges should be [r1 --> r2) [r3 --> s2) -func (worker *copIteratorWorker) calculateRetry(ranges *KeyRanges, split *coprocessor.KeyRange, desc bool) *KeyRanges { - if split == nil { - return ranges - } - if desc { - left, _ := ranges.Split(split.End) - return left - } - _, right := ranges.Split(split.Start) - return right -} - -// calculateRemain calculates the remain ranges to be processed, it's used in paging API. 
-// For example: -// ranges: [r1 --> r2) [r3 --> r4) -// split: [s1 --> s2) -// In normal scan order, all data before s2 is consumed, so the remained ranges should be [s2 --> r4) -// In reverse scan order, all data after s1 is consumed, so the remained ranges should be [r1 --> s1) -func (worker *copIteratorWorker) calculateRemain(ranges *KeyRanges, split *coprocessor.KeyRange, desc bool) *KeyRanges { - if split == nil { - return ranges - } - if desc { - left, _ := ranges.Split(split.Start) - return left - } - _, right := ranges.Split(split.End) - return right -} - -// finished checks the flags and finished channel, it tells whether the worker is finished. -func (worker *copIteratorWorker) finished() bool { - if worker.vars != nil && worker.vars.Killed != nil { - killed := atomic.LoadUint32(worker.vars.Killed) - if killed != 0 { - logutil.BgLogger().Info( - "a killed signal is received in copIteratorWorker", - zap.Uint32("signal", killed), - ) - return true - } - } - select { - case <-worker.finishCh: - return true - default: - return false - } -} - -func (it *copIterator) Close() error { - if atomic.CompareAndSwapUint32(&it.closed, 0, 1) { - close(it.finishCh) - } - it.rpcCancel.CancelAll() - it.actionOnExceed.close() - it.wg.Wait() - return nil -} - -// copErrorResponse returns error when calling Next() -type copErrorResponse struct{ error } - -func (it copErrorResponse) Next(ctx context.Context) (kv.ResultSubset, error) { - return nil, it.error -} - -func (it copErrorResponse) Close() error { - return nil -} - -// rateLimitAction an OOM Action which is used to control the token if OOM triggered. The token number should be -// set on initial. Each time the Action is triggered, one token would be destroyed. If the count of the token is less -// than 2, the action would be delegated to the fallback action. -type rateLimitAction struct { - memory.BaseOOMAction - // enabled indicates whether the rateLimitAction is permitted to Action. 1 means permitted, 0 denied. - enabled uint32 - // totalTokenNum indicates the total token at initial - totalTokenNum uint - cond struct { - sync.Mutex - // exceeded indicates whether have encountered OOM situation. 
- exceeded bool - // remainingTokenNum indicates the count of tokens which still exists - remainingTokenNum uint - once sync.Once - // triggerCountForTest indicates the total count of the rateLimitAction's Action being executed - triggerCountForTest uint - } -} - -func newRateLimitAction(totalTokenNumber uint) *rateLimitAction { - return &rateLimitAction{ - totalTokenNum: totalTokenNumber, - cond: struct { - sync.Mutex - exceeded bool - remainingTokenNum uint - once sync.Once - triggerCountForTest uint - }{ - Mutex: sync.Mutex{}, - exceeded: false, - remainingTokenNum: totalTokenNumber, - once: sync.Once{}, - }, - } -} - -// Action implements ActionOnExceed.Action -func (e *rateLimitAction) Action(t *memory.Tracker) { - if !e.isEnabled() { - if fallback := e.GetFallback(); fallback != nil { - fallback.Action(t) - } - return - } - e.conditionLock() - defer e.conditionUnlock() - e.cond.once.Do(func() { - if e.cond.remainingTokenNum < 2 { - e.setEnabled(false) - logutil.BgLogger().Info("memory exceeds quota, rateLimitAction delegate to fallback action", - zap.Uint("total token count", e.totalTokenNum)) - if fallback := e.GetFallback(); fallback != nil { - fallback.Action(t) - } - return - } - failpoint.Inject("testRateLimitActionMockConsumeAndAssert", func(val failpoint.Value) { - if val.(bool) { - if e.cond.triggerCountForTest+e.cond.remainingTokenNum != e.totalTokenNum { - panic("triggerCount + remainingTokenNum not equal to totalTokenNum") - } - } - }) - logutil.BgLogger().Info("memory exceeds quota, destroy one token now.", - zap.Int64("consumed", t.BytesConsumed()), - zap.Int64("quota", t.GetBytesLimit()), - zap.Uint("total token count", e.totalTokenNum), - zap.Uint("remaining token count", e.cond.remainingTokenNum)) - e.cond.exceeded = true - e.cond.triggerCountForTest++ - }) -} - -// GetPriority get the priority of the Action. -func (e *rateLimitAction) GetPriority() int64 { - return memory.DefRateLimitPriority -} - -// destroyTokenIfNeeded will check the `exceed` flag after copWorker finished one task. -// If the exceed flag is true and there is no token been destroyed before, one token will be destroyed, -// or the token would be return back. -func (e *rateLimitAction) destroyTokenIfNeeded(returnToken func()) { - if !e.isEnabled() { - returnToken() - return - } - e.conditionLock() - defer e.conditionUnlock() - if !e.cond.exceeded { - returnToken() - return - } - // If actionOnExceed has been triggered and there is no token have been destroyed before, - // destroy one token. - e.cond.remainingTokenNum = e.cond.remainingTokenNum - 1 - e.cond.exceeded = false - e.cond.once = sync.Once{} -} - -func (e *rateLimitAction) conditionLock() { - e.cond.Lock() -} - -func (e *rateLimitAction) conditionUnlock() { - e.cond.Unlock() -} - -func (e *rateLimitAction) close() { - if !e.isEnabled() { - return - } - e.setEnabled(false) - e.conditionLock() - defer e.conditionUnlock() - e.cond.exceeded = false - e.SetFinished() -} - -func (e *rateLimitAction) setEnabled(enabled bool) { - newValue := uint32(0) - if enabled { - newValue = uint32(1) - } - atomic.StoreUint32(&e.enabled, newValue) -} - -func (e *rateLimitAction) isEnabled() bool { - return atomic.LoadUint32(&e.enabled) > 0 -} - -// priorityToPB converts priority type to wire type. 
-func priorityToPB(pri int) kvrpcpb.CommandPri { - switch pri { - case kv.PriorityLow: - return kvrpcpb.CommandPri_Low - case kv.PriorityHigh: - return kvrpcpb.CommandPri_High - default: - return kvrpcpb.CommandPri_Normal - } -} - -func isolationLevelToPB(level kv.IsoLevel) kvrpcpb.IsolationLevel { - switch level { - case kv.RC: - return kvrpcpb.IsolationLevel_RC - case kv.SI: - return kvrpcpb.IsolationLevel_SI - case kv.RCCheckTS: - return kvrpcpb.IsolationLevel_RCCheckTS - default: - return kvrpcpb.IsolationLevel_SI - } -} - -// BuildKeyRanges is used for test, quickly build key ranges from paired keys. -func BuildKeyRanges(keys ...string) []kv.KeyRange { - var ranges []kv.KeyRange - for i := 0; i < len(keys); i += 2 { - ranges = append(ranges, kv.KeyRange{ - StartKey: []byte(keys[i]), - EndKey: []byte(keys[i+1]), - }) - } - return ranges -} - -func optRowHint(req *kv.Request) bool { - opt := true - if req.StoreType == kv.TiDB { - return false - } - if req.RequestSource.RequestSourceInternal || req.Tp != kv.ReqTypeDAG { - // disable extra concurrency for internal tasks. - return false - } - failpoint.Inject("disableFixedRowCountHint", func(_ failpoint.Value) { - opt = false - }) - return opt -} - -func checkStoreBatchCopr(req *kv.Request) bool { - if req.Tp != kv.ReqTypeDAG || req.StoreType != kv.TiKV { - return false - } - // TODO: support keep-order batch - if req.ReplicaRead != kv.ReplicaReadLeader || req.KeepOrder { - // Disable batch copr for follower read - return false - } - // Disable batch copr when paging is enabled. - if req.Paging.Enable { - return false - } - // Disable it for internal requests to avoid regression. - if req.RequestSource.RequestSourceInternal { - return false - } - return true -} diff --git a/pkg/store/copr/mpp.go b/pkg/store/copr/mpp.go index 618b04f20abf9..32c098aa35cf2 100644 --- a/pkg/store/copr/mpp.go +++ b/pkg/store/copr/mpp.go @@ -280,10 +280,10 @@ func (c *MPPClient) CheckVisibility(startTime uint64) error { } func (c *mppStoreCnt) getMPPStoreCount(ctx context.Context, pdClient pd.Client, TTL int64) (int, error) { - if value, _err_ := failpoint.Eval(_curpkg_("mppStoreCountSetLastUpdateTime")); _err_ == nil { + failpoint.Inject("mppStoreCountSetLastUpdateTime", func(value failpoint.Value) { v, _ := strconv.ParseInt(value.(string), 10, 0) c.lastUpdate = v - } + }) lastUpdate := atomic.LoadInt64(&c.lastUpdate) now := time.Now().UnixMicro() @@ -295,10 +295,10 @@ func (c *mppStoreCnt) getMPPStoreCount(ctx context.Context, pdClient pd.Client, } } - if value, _err_ := failpoint.Eval(_curpkg_("mppStoreCountSetLastUpdateTimeP2")); _err_ == nil { + failpoint.Inject("mppStoreCountSetLastUpdateTimeP2", func(value failpoint.Value) { v, _ := strconv.ParseInt(value.(string), 10, 0) c.lastUpdate = v - } + }) if !atomic.CompareAndSwapInt64(&c.lastUpdate, lastUpdate, now) { if isInit { @@ -311,11 +311,11 @@ func (c *mppStoreCnt) getMPPStoreCount(ctx context.Context, pdClient pd.Client, cnt := 0 stores, err := pdClient.GetAllStores(ctx, pd.WithExcludeTombstone()) - if value, _err_ := failpoint.Eval(_curpkg_("mppStoreCountPDError")); _err_ == nil { + failpoint.Inject("mppStoreCountPDError", func(value failpoint.Value) { if value.(bool) { err = errors.New("failed to get mpp store count") } - } + }) if err != nil { // always to update cache next time @@ -328,9 +328,9 @@ func (c *mppStoreCnt) getMPPStoreCount(ctx context.Context, pdClient pd.Client, } cnt += 1 } - if value, _err_ := failpoint.Eval(_curpkg_("mppStoreCountSetMPPCnt")); _err_ == nil { + 
failpoint.Inject("mppStoreCountSetMPPCnt", func(value failpoint.Value) { cnt = value.(int) - } + }) if !isInit || atomic.LoadInt64(&c.lastUpdate) == now { atomic.StoreInt32(&c.cnt, int32(cnt)) diff --git a/pkg/store/copr/mpp.go__failpoint_stash__ b/pkg/store/copr/mpp.go__failpoint_stash__ deleted file mode 100644 index 32c098aa35cf2..0000000000000 --- a/pkg/store/copr/mpp.go__failpoint_stash__ +++ /dev/null @@ -1,346 +0,0 @@ -// Copyright 2020 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package copr - -import ( - "context" - "strconv" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/coprocessor" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/kvproto/pkg/mpp" - "github.com/pingcap/log" - "github.com/pingcap/tidb/pkg/config" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/store/driver/backoff" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/tiflash" - "github.com/pingcap/tidb/pkg/util/tiflashcompute" - "github.com/tikv/client-go/v2/tikv" - "github.com/tikv/client-go/v2/tikvrpc" - pd "github.com/tikv/pd/client" - "go.uber.org/zap" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -// MPPClient servers MPP requests. -type MPPClient struct { - store *kvStore -} - -type mppStoreCnt struct { - cnt int32 - lastUpdate int64 - initFlag int32 -} - -// GetAddress returns the network address. -func (c *batchCopTask) GetAddress() string { - return c.storeAddr -} - -// ConstructMPPTasks receives ScheduleRequest, which are actually collects of kv ranges. We allocates MPPTaskMeta for them and returns. 
-func (c *MPPClient) ConstructMPPTasks(ctx context.Context, req *kv.MPPBuildTasksRequest, ttl time.Duration, dispatchPolicy tiflashcompute.DispatchPolicy, tiflashReplicaReadPolicy tiflash.ReplicaRead, appendWarning func(error)) ([]kv.MPPTaskMeta, error) { - ctx = context.WithValue(ctx, tikv.TxnStartKey(), req.StartTS) - bo := backoff.NewBackofferWithVars(ctx, copBuildTaskMaxBackoff, nil) - var tasks []*batchCopTask - var err error - if req.PartitionIDAndRanges != nil { - rangesForEachPartition := make([]*KeyRanges, len(req.PartitionIDAndRanges)) - partitionIDs := make([]int64, len(req.PartitionIDAndRanges)) - for i, p := range req.PartitionIDAndRanges { - rangesForEachPartition[i] = NewKeyRanges(p.KeyRanges) - partitionIDs[i] = p.ID - } - tasks, err = buildBatchCopTasksForPartitionedTable(ctx, bo, c.store, rangesForEachPartition, kv.TiFlash, true, ttl, true, 20, partitionIDs, dispatchPolicy, tiflashReplicaReadPolicy, appendWarning) - } else { - if req.KeyRanges == nil { - return nil, errors.New("KeyRanges in MPPBuildTasksRequest is nil") - } - ranges := NewKeyRanges(req.KeyRanges) - tasks, err = buildBatchCopTasksForNonPartitionedTable(ctx, bo, c.store, ranges, kv.TiFlash, true, ttl, true, 20, dispatchPolicy, tiflashReplicaReadPolicy, appendWarning) - } - - if err != nil { - return nil, errors.Trace(err) - } - mppTasks := make([]kv.MPPTaskMeta, 0, len(tasks)) - for _, copTask := range tasks { - mppTasks = append(mppTasks, copTask) - } - return mppTasks, nil -} - -// DispatchMPPTask dispatch mpp task, and returns valid response when retry = false and err is nil -func (c *MPPClient) DispatchMPPTask(param kv.DispatchMPPTaskParam) (resp *mpp.DispatchTaskResponse, retry bool, err error) { - req := param.Req - var regionInfos []*coprocessor.RegionInfo - originalTask, ok := req.Meta.(*batchCopTask) - if ok { - for _, ri := range originalTask.regionInfos { - regionInfos = append(regionInfos, ri.toCoprocessorRegionInfo()) - } - } - - // meta for current task. - taskMeta := &mpp.TaskMeta{StartTs: req.StartTs, QueryTs: req.MppQueryID.QueryTs, LocalQueryId: req.MppQueryID.LocalQueryID, TaskId: req.ID, ServerId: req.MppQueryID.ServerID, - GatherId: req.GatherID, - Address: req.Meta.GetAddress(), - CoordinatorAddress: req.CoordinatorAddress, - ReportExecutionSummary: req.ReportExecutionSummary, - MppVersion: req.MppVersion.ToInt64(), - ResourceGroupName: req.ResourceGroupName, - ConnectionId: req.ConnectionID, - ConnectionAlias: req.ConnectionAlias, - } - - mppReq := &mpp.DispatchTaskRequest{ - Meta: taskMeta, - EncodedPlan: req.Data, - // TODO: This is only an experience value. It's better to be configurable. - Timeout: 60, - SchemaVer: req.SchemaVar, - Regions: regionInfos, - } - if originalTask != nil { - mppReq.TableRegions = originalTask.PartitionTableRegions - if mppReq.TableRegions != nil { - mppReq.Regions = nil - } - } - - wrappedReq := tikvrpc.NewRequest(tikvrpc.CmdMPPTask, mppReq, kvrpcpb.Context{}) - wrappedReq.StoreTp = getEndPointType(kv.TiFlash) - - // TODO: Handle dispatch task response correctly, including retry logic and cancel logic. - var rpcResp *tikvrpc.Response - invalidPDCache := config.GetGlobalConfig().DisaggregatedTiFlash && !config.GetGlobalConfig().UseAutoScaler - bo := backoff.NewBackofferWithTikvBo(param.Bo) - - // If copTasks is not empty, we should send request according to region distribution. - // Or else it's the task without region, which always happens in high layer task without table. 
- // In that case - if originalTask != nil { - sender := NewRegionBatchRequestSender(c.store.GetRegionCache(), c.store.GetTiKVClient(), param.EnableCollectExecutionInfo) - rpcResp, retry, _, err = sender.SendReqToAddr(bo, originalTask.ctx, originalTask.regionInfos, wrappedReq, tikv.ReadTimeoutMedium) - // No matter what the rpc error is, we won't retry the mpp dispatch tasks. - // TODO: If we want to retry, we must redo the plan fragment cutting and task scheduling. - // That's a hard job but we can try it in the future. - if sender.GetRPCError() != nil { - logutil.BgLogger().Warn("mpp dispatch meet io error", zap.String("error", sender.GetRPCError().Error()), zap.Uint64("timestamp", taskMeta.StartTs), zap.Int64("task", taskMeta.TaskId), zap.Int64("mpp-version", taskMeta.MppVersion)) - if invalidPDCache { - c.store.GetRegionCache().InvalidateTiFlashComputeStores() - } - err = sender.GetRPCError() - } - } else { - rpcResp, err = c.store.GetTiKVClient().SendRequest(param.Ctx, req.Meta.GetAddress(), wrappedReq, tikv.ReadTimeoutMedium) - if errors.Cause(err) == context.Canceled || status.Code(errors.Cause(err)) == codes.Canceled { - retry = false - } else if err != nil { - if invalidPDCache { - c.store.GetRegionCache().InvalidateTiFlashComputeStores() - } - if bo.Backoff(tikv.BoTiFlashRPC(), err) == nil { - retry = true - } - } - } - - if err != nil || retry { - return nil, retry, err - } - - realResp := rpcResp.Resp.(*mpp.DispatchTaskResponse) - if realResp.Error != nil { - return realResp, false, nil - } - - if len(realResp.RetryRegions) > 0 { - logutil.BgLogger().Info("TiFlash found " + strconv.Itoa(len(realResp.RetryRegions)) + " stale regions. Only first " + strconv.Itoa(min(10, len(realResp.RetryRegions))) + " regions will be logged if the log level is higher than Debug") - for index, retry := range realResp.RetryRegions { - id := tikv.NewRegionVerID(retry.Id, retry.RegionEpoch.ConfVer, retry.RegionEpoch.Version) - if index < 10 || log.GetLevel() <= zap.DebugLevel { - logutil.BgLogger().Info("invalid region because tiflash detected stale region", zap.String("region id", id.String())) - } - c.store.GetRegionCache().InvalidateCachedRegionWithReason(id, tikv.EpochNotMatch) - } - } - return realResp, retry, err -} - -// CancelMPPTasks cancels mpp tasks -// NOTE: We do not retry here, because retry is helpless when errors result from TiFlash or Network. If errors occur, the execution on TiFlash will finally stop after some minutes. -// This function is exclusively called, and only the first call succeeds sending tasks and setting all tasks as cancelled, while others will not work. 
-func (c *MPPClient) CancelMPPTasks(param kv.CancelMPPTasksParam) { - usedStoreAddrs := param.StoreAddr - reqs := param.Reqs - if len(usedStoreAddrs) == 0 || len(reqs) == 0 { - return - } - - firstReq := reqs[0] - killReq := &mpp.CancelTaskRequest{ - Meta: &mpp.TaskMeta{StartTs: firstReq.StartTs, GatherId: firstReq.GatherID, QueryTs: firstReq.MppQueryID.QueryTs, LocalQueryId: firstReq.MppQueryID.LocalQueryID, ServerId: firstReq.MppQueryID.ServerID, MppVersion: firstReq.MppVersion.ToInt64(), ResourceGroupName: firstReq.ResourceGroupName}, - } - - wrappedReq := tikvrpc.NewRequest(tikvrpc.CmdMPPCancel, killReq, kvrpcpb.Context{}) - wrappedReq.StoreTp = getEndPointType(kv.TiFlash) - - // send cancel cmd to all stores where tasks run - invalidPDCache := config.GetGlobalConfig().DisaggregatedTiFlash && !config.GetGlobalConfig().UseAutoScaler - wg := util.WaitGroupWrapper{} - gotErr := atomic.Bool{} - for addr := range usedStoreAddrs { - storeAddr := addr - wg.Run(func() { - _, err := c.store.GetTiKVClient().SendRequest(context.Background(), storeAddr, wrappedReq, tikv.ReadTimeoutShort) - logutil.BgLogger().Debug("cancel task", zap.Uint64("query id ", firstReq.StartTs), zap.String("on addr", storeAddr), zap.Int64("mpp-version", firstReq.MppVersion.ToInt64())) - if err != nil { - logutil.BgLogger().Error("cancel task error", zap.Error(err), zap.Uint64("query id", firstReq.StartTs), zap.String("on addr", storeAddr), zap.Int64("mpp-version", firstReq.MppVersion.ToInt64())) - if invalidPDCache { - gotErr.CompareAndSwap(false, true) - } - } - }) - } - wg.Wait() - if invalidPDCache && gotErr.Load() { - c.store.GetRegionCache().InvalidateTiFlashComputeStores() - } -} - -// EstablishMPPConns build a mpp connection to receive data, return valid response when err is nil -func (c *MPPClient) EstablishMPPConns(param kv.EstablishMPPConnsParam) (*tikvrpc.MPPStreamResponse, error) { - req := param.Req - taskMeta := param.TaskMeta - connReq := &mpp.EstablishMPPConnectionRequest{ - SenderMeta: taskMeta, - ReceiverMeta: &mpp.TaskMeta{ - StartTs: req.StartTs, - GatherId: req.GatherID, - QueryTs: req.MppQueryID.QueryTs, - LocalQueryId: req.MppQueryID.LocalQueryID, - ServerId: req.MppQueryID.ServerID, - MppVersion: req.MppVersion.ToInt64(), - TaskId: -1, - ResourceGroupName: req.ResourceGroupName, - }, - } - - var err error - - wrappedReq := tikvrpc.NewRequest(tikvrpc.CmdMPPConn, connReq, kvrpcpb.Context{}) - wrappedReq.StoreTp = getEndPointType(kv.TiFlash) - - // Drain results from root task. - // We don't need to process any special error. When we meet errors, just let it fail. - rpcResp, err := c.store.GetTiKVClient().SendRequest(param.Ctx, req.Meta.GetAddress(), wrappedReq, TiFlashReadTimeoutUltraLong) - - var stream *tikvrpc.MPPStreamResponse - if rpcResp != nil && rpcResp.Resp != nil { - stream = rpcResp.Resp.(*tikvrpc.MPPStreamResponse) - } - - if err != nil { - if stream != nil { - stream.Close() - } - logutil.BgLogger().Warn("establish mpp connection meet error and cannot retry", zap.String("error", err.Error()), zap.Uint64("timestamp", taskMeta.StartTs), zap.Int64("task", taskMeta.TaskId), zap.Int64("mpp-version", taskMeta.MppVersion)) - if config.GetGlobalConfig().DisaggregatedTiFlash && !config.GetGlobalConfig().UseAutoScaler { - c.store.GetRegionCache().InvalidateTiFlashComputeStores() - } - return nil, err - } - - return stream, nil -} - -// CheckVisibility checks if it is safe to read using given ts. 
-func (c *MPPClient) CheckVisibility(startTime uint64) error { - return c.store.CheckVisibility(startTime) -} - -func (c *mppStoreCnt) getMPPStoreCount(ctx context.Context, pdClient pd.Client, TTL int64) (int, error) { - failpoint.Inject("mppStoreCountSetLastUpdateTime", func(value failpoint.Value) { - v, _ := strconv.ParseInt(value.(string), 10, 0) - c.lastUpdate = v - }) - - lastUpdate := atomic.LoadInt64(&c.lastUpdate) - now := time.Now().UnixMicro() - isInit := atomic.LoadInt32(&c.initFlag) != 0 - - if now-lastUpdate < TTL { - if isInit { - return int(atomic.LoadInt32(&c.cnt)), nil - } - } - - failpoint.Inject("mppStoreCountSetLastUpdateTimeP2", func(value failpoint.Value) { - v, _ := strconv.ParseInt(value.(string), 10, 0) - c.lastUpdate = v - }) - - if !atomic.CompareAndSwapInt64(&c.lastUpdate, lastUpdate, now) { - if isInit { - return int(atomic.LoadInt32(&c.cnt)), nil - } - // if has't initialized, always fetch latest mpp store info - } - - // update mpp store cache - cnt := 0 - stores, err := pdClient.GetAllStores(ctx, pd.WithExcludeTombstone()) - - failpoint.Inject("mppStoreCountPDError", func(value failpoint.Value) { - if value.(bool) { - err = errors.New("failed to get mpp store count") - } - }) - - if err != nil { - // always to update cache next time - atomic.StoreInt32(&c.initFlag, 0) - return 0, err - } - for _, s := range stores { - if !tikv.LabelFilterNoTiFlashWriteNode(s.GetLabels()) { - continue - } - cnt += 1 - } - failpoint.Inject("mppStoreCountSetMPPCnt", func(value failpoint.Value) { - cnt = value.(int) - }) - - if !isInit || atomic.LoadInt64(&c.lastUpdate) == now { - atomic.StoreInt32(&c.cnt, int32(cnt)) - atomic.StoreInt32(&c.initFlag, 1) - } - - return cnt, nil -} - -// GetMPPStoreCount returns number of TiFlash stores -func (c *MPPClient) GetMPPStoreCount() (int, error) { - return c.store.mppStoreCnt.getMPPStoreCount(c.store.store.Ctx(), c.store.store.GetPDClient(), 120*1e6 /* TTL 120sec */) -} diff --git a/pkg/store/driver/txn/binding__failpoint_binding__.go b/pkg/store/driver/txn/binding__failpoint_binding__.go deleted file mode 100644 index 5a3dafc415412..0000000000000 --- a/pkg/store/driver/txn/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package txn - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/store/driver/txn/binlog.go b/pkg/store/driver/txn/binlog.go index 96006349db43b..b17bcbaed25fa 100644 --- a/pkg/store/driver/txn/binlog.go +++ b/pkg/store/driver/txn/binlog.go @@ -65,12 +65,12 @@ func (e *binlogExecutor) Commit(ctx context.Context, commitTS int64) { wg := sync.WaitGroup{} mock := false - if val, _err_ := failpoint.Eval(_curpkg_("mockSyncBinlogCommit")); _err_ == nil { + failpoint.Inject("mockSyncBinlogCommit", func(val failpoint.Value) { if val.(bool) { wg.Add(1) mock = true } - } + }) go func() { logutil.Eventf(ctx, "start write finish binlog") binlogWriteResult := e.binInfo.WriteBinlog(e.txn.GetClusterID()) diff --git a/pkg/store/driver/txn/binlog.go__failpoint_stash__ b/pkg/store/driver/txn/binlog.go__failpoint_stash__ deleted file mode 100644 index b17bcbaed25fa..0000000000000 --- a/pkg/store/driver/txn/binlog.go__failpoint_stash__ +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2021 PingCAP, Inc. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package txn - -import ( - "context" - "sync" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/sessionctx/binloginfo" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tipb/go-binlog" - "github.com/tikv/client-go/v2/tikv" - "go.uber.org/zap" -) - -type binlogExecutor struct { - txn *tikv.KVTxn - binInfo *binloginfo.BinlogInfo -} - -func (e *binlogExecutor) Skip() { - binloginfo.RemoveOneSkippedCommitter() -} - -func (e *binlogExecutor) Prewrite(ctx context.Context, primary []byte) <-chan tikv.BinlogWriteResult { - ch := make(chan tikv.BinlogWriteResult, 1) - go func() { - logutil.Eventf(ctx, "start prewrite binlog") - bin := e.binInfo.Data - bin.StartTs = int64(e.txn.StartTS()) - if bin.Tp == binlog.BinlogType_Prewrite { - bin.PrewriteKey = primary - } - wr := e.binInfo.WriteBinlog(e.txn.GetClusterID()) - if wr.Skipped() { - e.binInfo.Data.PrewriteValue = nil - binloginfo.AddOneSkippedCommitter() - } - logutil.Eventf(ctx, "finish prewrite binlog") - ch <- wr - }() - return ch -} - -func (e *binlogExecutor) Commit(ctx context.Context, commitTS int64) { - e.binInfo.Data.Tp = binlog.BinlogType_Commit - if commitTS == 0 { - e.binInfo.Data.Tp = binlog.BinlogType_Rollback - } - e.binInfo.Data.CommitTs = commitTS - e.binInfo.Data.PrewriteValue = nil - - wg := sync.WaitGroup{} - mock := false - failpoint.Inject("mockSyncBinlogCommit", func(val failpoint.Value) { - if val.(bool) { - wg.Add(1) - mock = true - } - }) - go func() { - logutil.Eventf(ctx, "start write finish binlog") - binlogWriteResult := e.binInfo.WriteBinlog(e.txn.GetClusterID()) - err := binlogWriteResult.GetError() - if err != nil { - logutil.BgLogger().Error("failed to write binlog", - zap.Error(err)) - } - logutil.Eventf(ctx, "finish write finish binlog") - if mock { - wg.Done() - } - }() - if mock { - wg.Wait() - } -} diff --git a/pkg/store/driver/txn/txn_driver.go b/pkg/store/driver/txn/txn_driver.go index 91e48f2977c0f..03c288852a70f 100644 --- a/pkg/store/driver/txn/txn_driver.go +++ b/pkg/store/driver/txn/txn_driver.go @@ -401,7 +401,7 @@ func (txn *tikvTxn) UpdateMemBufferFlags(key []byte, flags ...kv.FlagsOp) { func (txn *tikvTxn) generateWriteConflictForLockedWithConflict(lockCtx *kv.LockCtx) error { if lockCtx.MaxLockedWithConflictTS != 0 { - failpoint.Eval(_curpkg_("lockedWithConflictOccurs")) + failpoint.Inject("lockedWithConflictOccurs", func() {}) var bufTableID, bufRest bytes.Buffer foundKey := false for k, v := range lockCtx.Values { diff --git a/pkg/store/driver/txn/txn_driver.go__failpoint_stash__ b/pkg/store/driver/txn/txn_driver.go__failpoint_stash__ deleted file mode 100644 index 03c288852a70f..0000000000000 --- a/pkg/store/driver/txn/txn_driver.go__failpoint_stash__ +++ /dev/null @@ -1,491 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package txn - -import ( - "bytes" - "context" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/sessionctx/binloginfo" - derr "github.com/pingcap/tidb/pkg/store/driver/error" - "github.com/pingcap/tidb/pkg/store/driver/options" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/tracing" - tikverr "github.com/tikv/client-go/v2/error" - tikvstore "github.com/tikv/client-go/v2/kv" - "github.com/tikv/client-go/v2/tikv" - "github.com/tikv/client-go/v2/tikvrpc" - "github.com/tikv/client-go/v2/tikvrpc/interceptor" - "github.com/tikv/client-go/v2/txnkv" - "github.com/tikv/client-go/v2/txnkv/txnsnapshot" - "go.uber.org/zap" -) - -type tikvTxn struct { - *tikv.KVTxn - idxNameCache map[int64]*model.TableInfo - snapshotInterceptor kv.SnapshotInterceptor - // columnMapsCache is a cache used for the mutation checker - columnMapsCache any - isCommitterWorking atomic.Bool - memBuffer *memBuffer -} - -// NewTiKVTxn returns a new Transaction. -func NewTiKVTxn(txn *tikv.KVTxn) kv.Transaction { - txn.SetKVFilter(TiDBKVFilter{}) - - // init default size limits by config - entryLimit := kv.TxnEntrySizeLimit.Load() - totalLimit := kv.TxnTotalSizeLimit.Load() - txn.GetUnionStore().SetEntrySizeLimit(entryLimit, totalLimit) - - return &tikvTxn{ - txn, make(map[int64]*model.TableInfo), nil, nil, atomic.Bool{}, - newMemBuffer(txn.GetMemBuffer(), txn.IsPipelined()), - } -} - -func (txn *tikvTxn) GetTableInfo(id int64) *model.TableInfo { - return txn.idxNameCache[id] -} - -func (txn *tikvTxn) SetDiskFullOpt(level kvrpcpb.DiskFullOpt) { - txn.KVTxn.SetDiskFullOpt(level) -} - -func (txn *tikvTxn) CacheTableInfo(id int64, info *model.TableInfo) { - txn.idxNameCache[id] = info - // For partition table, also cached tblInfo with TableID for global index. - if info != nil && info.ID != id { - txn.idxNameCache[info.ID] = info - } -} - -func (txn *tikvTxn) LockKeys(ctx context.Context, lockCtx *kv.LockCtx, keysInput ...kv.Key) error { - if intest.InTest { - txn.isCommitterWorking.Store(true) - defer txn.isCommitterWorking.Store(false) - } - keys := toTiKVKeys(keysInput) - err := txn.KVTxn.LockKeys(ctx, lockCtx, keys...) - if err != nil { - return txn.extractKeyErr(err) - } - return txn.generateWriteConflictForLockedWithConflict(lockCtx) -} - -func (txn *tikvTxn) LockKeysFunc(ctx context.Context, lockCtx *kv.LockCtx, fn func(), keysInput ...kv.Key) error { - if intest.InTest { - txn.isCommitterWorking.Store(true) - defer txn.isCommitterWorking.Store(false) - } - keys := toTiKVKeys(keysInput) - err := txn.KVTxn.LockKeysFunc(ctx, lockCtx, fn, keys...) 
- if err != nil { - return txn.extractKeyErr(err) - } - return txn.generateWriteConflictForLockedWithConflict(lockCtx) -} - -func (txn *tikvTxn) Commit(ctx context.Context) error { - if intest.InTest { - txn.isCommitterWorking.Store(true) - } - err := txn.KVTxn.Commit(ctx) - return txn.extractKeyErr(err) -} - -func (txn *tikvTxn) GetMemDBCheckpoint() *tikv.MemDBCheckpoint { - buf := txn.KVTxn.GetMemBuffer() - return buf.Checkpoint() -} - -func (txn *tikvTxn) RollbackMemDBToCheckpoint(savepoint *tikv.MemDBCheckpoint) { - buf := txn.KVTxn.GetMemBuffer() - buf.RevertToCheckpoint(savepoint) -} - -// GetSnapshot returns the Snapshot binding to this transaction. -func (txn *tikvTxn) GetSnapshot() kv.Snapshot { - return &tikvSnapshot{txn.KVTxn.GetSnapshot(), txn.snapshotInterceptor} -} - -// Iter creates an Iterator positioned on the first entry that k <= entry's key. -// If such entry is not found, it returns an invalid Iterator with no error. -// It yields only keys that < upperBound. If upperBound is nil, it means the upperBound is unbounded. -// The Iterator must be Closed after use. -func (txn *tikvTxn) Iter(k kv.Key, upperBound kv.Key) (iter kv.Iterator, err error) { - var dirtyIter, snapIter kv.Iterator - if dirtyIter, err = txn.GetMemBuffer().Iter(k, upperBound); err != nil { - return nil, err - } - - if snapIter, err = txn.GetSnapshot().Iter(k, upperBound); err != nil { - dirtyIter.Close() - return nil, err - } - - iter, err = NewUnionIter(dirtyIter, snapIter, false) - if err != nil { - dirtyIter.Close() - snapIter.Close() - } - - return iter, err -} - -// IterReverse creates a reversed Iterator positioned on the first entry which key is less than k. -// The returned iterator will iterate from greater key to smaller key. -// If k is nil, the returned iterator will be positioned at the last key. -func (txn *tikvTxn) IterReverse(k kv.Key, lowerBound kv.Key) (iter kv.Iterator, err error) { - var dirtyIter, snapIter kv.Iterator - - if dirtyIter, err = txn.GetMemBuffer().IterReverse(k, lowerBound); err != nil { - return nil, err - } - - if snapIter, err = txn.GetSnapshot().IterReverse(k, lowerBound); err != nil { - dirtyIter.Close() - return nil, err - } - - iter, err = NewUnionIter(dirtyIter, snapIter, true) - if err != nil { - dirtyIter.Close() - snapIter.Close() - } - - return iter, err -} - -// BatchGet gets kv from the memory buffer of statement and transaction, and the kv storage. -// Do not use len(value) == 0 or value == nil to represent non-exist. -// If a key doesn't exist, there shouldn't be any corresponding entry in the result map. 
-func (txn *tikvTxn) BatchGet(ctx context.Context, keys []kv.Key) (map[string][]byte, error) { - r, ctx := tracing.StartRegionEx(ctx, "tikvTxn.BatchGet") - defer r.End() - return NewBufferBatchGetter(txn.GetMemBuffer(), nil, txn.GetSnapshot()).BatchGet(ctx, keys) -} - -func (txn *tikvTxn) Delete(k kv.Key) error { - err := txn.KVTxn.Delete(k) - return derr.ToTiDBErr(err) -} - -func (txn *tikvTxn) Get(ctx context.Context, k kv.Key) ([]byte, error) { - val, err := txn.GetMemBuffer().Get(ctx, k) - if kv.ErrNotExist.Equal(err) { - val, err = txn.GetSnapshot().Get(ctx, k) - } - - if err == nil && len(val) == 0 { - return nil, kv.ErrNotExist - } - - return val, err -} - -func (txn *tikvTxn) Set(k kv.Key, v []byte) error { - err := txn.KVTxn.Set(k, v) - return derr.ToTiDBErr(err) -} - -func (txn *tikvTxn) GetMemBuffer() kv.MemBuffer { - if txn.memBuffer == nil { - txn.memBuffer = newMemBuffer(txn.KVTxn.GetMemBuffer(), txn.IsPipelined()) - } - return txn.memBuffer -} - -func (txn *tikvTxn) SetOption(opt int, val any) { - if intest.InTest { - txn.assertCommitterNotWorking() - } - switch opt { - case kv.BinlogInfo: - txn.SetBinlogExecutor(&binlogExecutor{ - txn: txn.KVTxn, - binInfo: val.(*binloginfo.BinlogInfo), // val cannot be other type. - }) - case kv.SchemaChecker: - txn.SetSchemaLeaseChecker(val.(tikv.SchemaLeaseChecker)) - case kv.IsolationLevel: - level := getTiKVIsolationLevel(val.(kv.IsoLevel)) - txn.KVTxn.GetSnapshot().SetIsolationLevel(level) - case kv.Priority: - txn.KVTxn.SetPriority(getTiKVPriority(val.(int))) - case kv.NotFillCache: - txn.KVTxn.GetSnapshot().SetNotFillCache(val.(bool)) - case kv.Pessimistic: - txn.SetPessimistic(val.(bool)) - case kv.SnapshotTS: - txn.KVTxn.GetSnapshot().SetSnapshotTS(val.(uint64)) - case kv.ReplicaRead: - t := options.GetTiKVReplicaReadType(val.(kv.ReplicaReadType)) - txn.KVTxn.GetSnapshot().SetReplicaRead(t) - case kv.TaskID: - txn.KVTxn.GetSnapshot().SetTaskID(val.(uint64)) - case kv.InfoSchema: - txn.SetSchemaVer(val.(tikv.SchemaVer)) - case kv.CollectRuntimeStats: - if val == nil { - txn.KVTxn.GetSnapshot().SetRuntimeStats(nil) - } else { - txn.KVTxn.GetSnapshot().SetRuntimeStats(val.(*txnsnapshot.SnapshotRuntimeStats)) - } - case kv.SampleStep: - txn.KVTxn.GetSnapshot().SetSampleStep(val.(uint32)) - case kv.CommitHook: - txn.SetCommitCallback(val.(func(string, error))) - case kv.EnableAsyncCommit: - txn.SetEnableAsyncCommit(val.(bool)) - case kv.Enable1PC: - txn.SetEnable1PC(val.(bool)) - case kv.GuaranteeLinearizability: - txn.SetCausalConsistency(!val.(bool)) - case kv.TxnScope: - txn.SetScope(val.(string)) - case kv.IsStalenessReadOnly: - txn.KVTxn.GetSnapshot().SetIsStalenessReadOnly(val.(bool)) - case kv.MatchStoreLabels: - txn.KVTxn.GetSnapshot().SetMatchStoreLabels(val.([]*metapb.StoreLabel)) - case kv.ResourceGroupTag: - txn.KVTxn.SetResourceGroupTag(val.([]byte)) - case kv.ResourceGroupTagger: - txn.KVTxn.SetResourceGroupTagger(val.(tikvrpc.ResourceGroupTagger)) - case kv.KVFilter: - txn.KVTxn.SetKVFilter(val.(tikv.KVFilter)) - case kv.SnapInterceptor: - txn.snapshotInterceptor = val.(kv.SnapshotInterceptor) - case kv.CommitTSUpperBoundCheck: - txn.KVTxn.SetCommitTSUpperBoundCheck(val.(func(commitTS uint64) bool)) - case kv.RPCInterceptor: - txn.KVTxn.AddRPCInterceptor(val.(interceptor.RPCInterceptor)) - case kv.AssertionLevel: - txn.KVTxn.SetAssertionLevel(val.(kvrpcpb.AssertionLevel)) - case kv.TableToColumnMaps: - txn.columnMapsCache = val - case kv.RequestSourceInternal: - txn.KVTxn.SetRequestSourceInternal(val.(bool)) - case 
kv.RequestSourceType: - txn.KVTxn.SetRequestSourceType(val.(string)) - case kv.ExplicitRequestSourceType: - txn.KVTxn.SetExplicitRequestSourceType(val.(string)) - case kv.ReplicaReadAdjuster: - txn.KVTxn.GetSnapshot().SetReplicaReadAdjuster(val.(txnkv.ReplicaReadAdjuster)) - case kv.TxnSource: - txn.KVTxn.SetTxnSource(val.(uint64)) - case kv.ResourceGroupName: - txn.KVTxn.SetResourceGroupName(val.(string)) - case kv.LoadBasedReplicaReadThreshold: - txn.KVTxn.GetSnapshot().SetLoadBasedReplicaReadThreshold(val.(time.Duration)) - case kv.TiKVClientReadTimeout: - txn.KVTxn.GetSnapshot().SetKVReadTimeout(time.Duration(val.(uint64) * uint64(time.Millisecond))) - case kv.SizeLimits: - limits := val.(kv.TxnSizeLimits) - txn.KVTxn.GetUnionStore().SetEntrySizeLimit(limits.Entry, limits.Total) - case kv.SessionID: - txn.KVTxn.SetSessionID(val.(uint64)) - } -} - -func (txn *tikvTxn) GetOption(opt int) any { - switch opt { - case kv.GuaranteeLinearizability: - return !txn.KVTxn.IsCasualConsistency() - case kv.TxnScope: - return txn.KVTxn.GetScope() - case kv.TableToColumnMaps: - return txn.columnMapsCache - case kv.RequestSourceInternal: - return txn.RequestSourceInternal - case kv.RequestSourceType: - return txn.RequestSourceType - default: - return nil - } -} - -// SetVars sets variables to the transaction. -func (txn *tikvTxn) SetVars(vars any) { - if vs, ok := vars.(*tikv.Variables); ok { - txn.KVTxn.SetVars(vs) - } -} - -func (txn *tikvTxn) GetVars() any { - return txn.KVTxn.GetVars() -} - -func (txn *tikvTxn) extractKeyErr(err error) error { - if e, ok := errors.Cause(err).(*tikverr.ErrKeyExist); ok { - return txn.extractKeyExistsErr(e) - } - return extractKeyErr(err) -} - -func (txn *tikvTxn) extractKeyExistsErr(errExist *tikverr.ErrKeyExist) error { - var key kv.Key = errExist.GetKey() - tableID, indexID, isRecord, err := tablecodec.DecodeKeyHead(key) - if err != nil { - return genKeyExistsError("UNKNOWN", key.String(), err) - } - indexID = tablecodec.IndexIDMask & indexID - - tblInfo := txn.GetTableInfo(tableID) - if tblInfo == nil { - return genKeyExistsError("UNKNOWN", key.String(), errors.New("cannot find table info")) - } - var value []byte - if txn.IsPipelined() { - value = errExist.Value - if len(value) == 0 { - return genKeyExistsError( - "UNKNOWN", - key.String(), - errors.New("The value is empty (a delete)"), - ) - } - } else { - value, err = txn.KVTxn.GetUnionStore().GetMemBuffer().GetMemDB().SelectValueHistory(key, func(value []byte) bool { return len(value) != 0 }) - } - if err != nil { - return genKeyExistsError("UNKNOWN", key.String(), err) - } - - if isRecord { - return ExtractKeyExistsErrFromHandle(key, value, tblInfo) - } - return ExtractKeyExistsErrFromIndex(key, value, tblInfo, indexID) -} - -// SetAssertion sets an assertion for the key operation. -func (txn *tikvTxn) SetAssertion(key []byte, assertion ...kv.FlagsOp) error { - f, err := txn.GetUnionStore().GetMemBuffer().GetFlags(key) - if err != nil && !tikverr.IsErrNotFound(err) { - return err - } - if err == nil && f.HasAssertionFlags() { - return nil - } - txn.UpdateMemBufferFlags(key, assertion...) - return nil -} - -func (txn *tikvTxn) UpdateMemBufferFlags(key []byte, flags ...kv.FlagsOp) { - txn.GetUnionStore().GetMemBuffer().UpdateFlags(key, getTiKVFlagsOps(flags)...) 
-} - -func (txn *tikvTxn) generateWriteConflictForLockedWithConflict(lockCtx *kv.LockCtx) error { - if lockCtx.MaxLockedWithConflictTS != 0 { - failpoint.Inject("lockedWithConflictOccurs", func() {}) - var bufTableID, bufRest bytes.Buffer - foundKey := false - for k, v := range lockCtx.Values { - if v.LockedWithConflictTS >= lockCtx.MaxLockedWithConflictTS { - foundKey = true - prettyWriteKey(&bufTableID, &bufRest, []byte(k)) - break - } - } - if !foundKey { - bufTableID.WriteString("") - } - // TODO: Primary is not exported here. - primary := " primary=" - primaryRest := "" - return kv.ErrWriteConflict.FastGenByArgs(txn.StartTS(), 0, lockCtx.MaxLockedWithConflictTS, bufTableID.String(), bufRest.String(), primary, primaryRest, "LockedWithConflict") - } - return nil -} - -// StartFairLocking adapts the method signature of `KVTxn` to satisfy kv.FairLockingController. -// TODO: Update the methods' signatures in client-go to avoid this adaptor functions. -// TODO: Rename aggressive locking in client-go to fair locking. -func (txn *tikvTxn) StartFairLocking() error { - txn.KVTxn.StartAggressiveLocking() - return nil -} - -// RetryFairLocking adapts the method signature of `KVTxn` to satisfy kv.FairLockingController. -func (txn *tikvTxn) RetryFairLocking(ctx context.Context) error { - txn.KVTxn.RetryAggressiveLocking(ctx) - return nil -} - -// CancelFairLocking adapts the method signature of `KVTxn` to satisfy kv.FairLockingController. -func (txn *tikvTxn) CancelFairLocking(ctx context.Context) error { - txn.KVTxn.CancelAggressiveLocking(ctx) - return nil -} - -// DoneFairLocking adapts the method signature of `KVTxn` to satisfy kv.FairLockingController. -func (txn *tikvTxn) DoneFairLocking(ctx context.Context) error { - txn.KVTxn.DoneAggressiveLocking(ctx) - return nil -} - -// IsInFairLockingMode adapts the method signature of `KVTxn` to satisfy kv.FairLockingController. -func (txn *tikvTxn) IsInFairLockingMode() bool { - return txn.KVTxn.IsInAggressiveLockingMode() -} - -// MayFlush wraps the flush function and extract the error. -func (txn *tikvTxn) MayFlush() error { - if !txn.IsPipelined() { - return nil - } - if intest.InTest { - txn.isCommitterWorking.Store(true) - } - _, err := txn.KVTxn.GetMemBuffer().Flush(false) - return txn.extractKeyErr(err) -} - -// assertCommitterNotWorking asserts that the committer is not working, so it's safe to modify the options for txn and committer. -// It panics when committer is working, only use it when test with --tags=intest tag. -func (txn *tikvTxn) assertCommitterNotWorking() { - if txn.isCommitterWorking.Load() { - panic("committer is working") - } -} - -// TiDBKVFilter is the filter specific to TiDB to filter out KV pairs that needn't be committed. -type TiDBKVFilter struct{} - -// IsUnnecessaryKeyValue defines which kinds of KV pairs from TiDB needn't be committed. 
-func (f TiDBKVFilter) IsUnnecessaryKeyValue(key, value []byte, flags tikvstore.KeyFlags) (bool, error) { - isUntouchedValue := tablecodec.IsUntouchedIndexKValue(key, value) - if isUntouchedValue && flags.HasPresumeKeyNotExists() { - logutil.BgLogger().Error("unexpected path the untouched key value with PresumeKeyNotExists flag", - zap.Stringer("key", kv.Key(key)), zap.Stringer("value", kv.Key(value)), - zap.Uint16("flags", uint16(flags)), zap.Stack("stack")) - return false, errors.Errorf( - "unexpected path the untouched key=%s value=%s contains PresumeKeyNotExists flag keyFlags=%v", - kv.Key(key).String(), kv.Key(value).String(), flags) - } - return isUntouchedValue, nil -} diff --git a/pkg/store/gcworker/binding__failpoint_binding__.go b/pkg/store/gcworker/binding__failpoint_binding__.go deleted file mode 100644 index 158fd645690b5..0000000000000 --- a/pkg/store/gcworker/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package gcworker - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/store/gcworker/gc_worker.go b/pkg/store/gcworker/gc_worker.go index 7f579774703a4..3f5c1d3d166b3 100644 --- a/pkg/store/gcworker/gc_worker.go +++ b/pkg/store/gcworker/gc_worker.go @@ -734,9 +734,9 @@ func (w *GCWorker) setGCWorkerServiceSafePoint(ctx context.Context, safePoint ui } func (w *GCWorker) runGCJob(ctx context.Context, safePoint uint64, concurrency int) error { - if _, _err_ := failpoint.Eval(_curpkg_("mockRunGCJobFail")); _err_ == nil { - return errors.New("mock failure of runGCJoB") - } + failpoint.Inject("mockRunGCJobFail", func() { + failpoint.Return(errors.New("mock failure of runGCJoB")) + }) metrics.GCWorkerCounter.WithLabelValues("run_job").Inc() err := w.resolveLocks(ctx, safePoint, concurrency) @@ -832,9 +832,9 @@ func (w *GCWorker) deleteRanges(ctx context.Context, safePoint uint64, concurren } else { err = w.doUnsafeDestroyRangeRequest(ctx, startKey, endKey, concurrency) } - if _, _err_ := failpoint.Eval(_curpkg_("ignoreDeleteRangeFailed")); _err_ == nil { + failpoint.Inject("ignoreDeleteRangeFailed", func() { err = nil - } + }) if err != nil { logutil.Logger(ctx).Error("delete range failed on range", zap.String("category", "gc worker"), @@ -1132,9 +1132,9 @@ func (w *GCWorker) resolveLocks( handler := func(ctx context.Context, r tikvstore.KeyRange) (rangetask.TaskStat, error) { scanLimit := uint32(tikv.GCScanLockLimit) - if _, _err_ := failpoint.Eval(_curpkg_("lowScanLockLimit")); _err_ == nil { + failpoint.Inject("lowScanLockLimit", func() { scanLimit = 3 - } + }) return tikv.ResolveLocksForRange(ctx, w.regionLockResolver, safePoint, r.StartKey, r.EndKey, tikv.NewGcResolveLockMaxBackoffer, scanLimit) } @@ -1492,7 +1492,7 @@ func (w *GCWorker) saveValueToSysTable(key, value string) error { func (w *GCWorker) doGCPlacementRules(se sessiontypes.Session, _ uint64, dr util.DelRangeTask, gcPlacementRuleCache map[int64]any) (err error) { // Get the job from the job history var historyJob *model.Job - if v, _err_ := failpoint.Eval(_curpkg_("mockHistoryJobForGC")); _err_ == nil { + failpoint.Inject("mockHistoryJobForGC", func(v failpoint.Value) { args, err1 := json.Marshal([]any{kv.Key{}, []int64{int64(v.(int))}}) if err1 != nil { return @@ -1503,7 +1503,7 @@ func (w *GCWorker) 
doGCPlacementRules(se sessiontypes.Session, _ uint64, dr util TableID: int64(v.(int)), RawArgs: args, } - } + }) if historyJob == nil { historyJob, err = ddl.GetHistoryJobByID(se, dr.JobID) if err != nil { @@ -1566,7 +1566,7 @@ func (w *GCWorker) doGCPlacementRules(se sessiontypes.Session, _ uint64, dr util func (w *GCWorker) doGCLabelRules(dr util.DelRangeTask) (err error) { // Get the job from the job history var historyJob *model.Job - if v, _err_ := failpoint.Eval(_curpkg_("mockHistoryJob")); _err_ == nil { + failpoint.Inject("mockHistoryJob", func(v failpoint.Value) { args, err1 := json.Marshal([]any{kv.Key{}, []int64{}, []string{v.(string)}}) if err1 != nil { return @@ -1576,7 +1576,7 @@ func (w *GCWorker) doGCLabelRules(dr util.DelRangeTask) (err error) { Type: model.ActionDropTable, RawArgs: args, } - } + }) if historyJob == nil { se := createSession(w.store) historyJob, err = ddl.GetHistoryJobByID(se, dr.JobID) diff --git a/pkg/store/gcworker/gc_worker.go__failpoint_stash__ b/pkg/store/gcworker/gc_worker.go__failpoint_stash__ deleted file mode 100644 index 3f5c1d3d166b3..0000000000000 --- a/pkg/store/gcworker/gc_worker.go__failpoint_stash__ +++ /dev/null @@ -1,1759 +0,0 @@ -// Copyright 2017 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gcworker - -import ( - "bytes" - "context" - "encoding/hex" - "encoding/json" - "fmt" - "math" - "os" - "strconv" - "strings" - "sync" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/errorpb" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/tidb/pkg/ddl" - "github.com/pingcap/tidb/pkg/ddl/label" - "github.com/pingcap/tidb/pkg/ddl/placement" - "github.com/pingcap/tidb/pkg/ddl/util" - "github.com/pingcap/tidb/pkg/domain/infosync" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/privilege" - "github.com/pingcap/tidb/pkg/session" - sessiontypes "github.com/pingcap/tidb/pkg/session/types" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/dbterror" - "github.com/pingcap/tidb/pkg/util/logutil" - tikverr "github.com/tikv/client-go/v2/error" - tikvstore "github.com/tikv/client-go/v2/kv" - "github.com/tikv/client-go/v2/oracle" - "github.com/tikv/client-go/v2/tikv" - "github.com/tikv/client-go/v2/tikvrpc" - "github.com/tikv/client-go/v2/txnkv/rangetask" - tikvutil "github.com/tikv/client-go/v2/util" - pd "github.com/tikv/pd/client" - "go.uber.org/zap" -) - -// GCWorker periodically triggers GC process on tikv server. 
-type GCWorker struct { - uuid string - desc string - store kv.Storage - tikvStore tikv.Storage - pdClient pd.Client - gcIsRunning bool - lastFinish time.Time - cancel context.CancelFunc - done chan error - regionLockResolver tikv.RegionLockResolver -} - -// NewGCWorker creates a GCWorker instance. -func NewGCWorker(store kv.Storage, pdClient pd.Client) (*GCWorker, error) { - ver, err := store.CurrentVersion(kv.GlobalTxnScope) - if err != nil { - return nil, errors.Trace(err) - } - hostName, err := os.Hostname() - if err != nil { - hostName = "unknown" - } - tikvStore, ok := store.(tikv.Storage) - if !ok { - return nil, errors.New("GC should run against TiKV storage") - } - uuid := strconv.FormatUint(ver.Ver, 16) - resolverIdentifier := fmt.Sprintf("gc-worker-%s", uuid) - worker := &GCWorker{ - uuid: uuid, - desc: fmt.Sprintf("host:%s, pid:%d, start at %s", hostName, os.Getpid(), time.Now()), - store: store, - tikvStore: tikvStore, - pdClient: pdClient, - gcIsRunning: false, - lastFinish: time.Now(), - regionLockResolver: tikv.NewRegionLockResolver(resolverIdentifier, tikvStore), - done: make(chan error), - } - variable.RegisterStatistics(worker) - return worker, nil -} - -// Start starts the worker. -func (w *GCWorker) Start() { - var ctx context.Context - ctx, w.cancel = context.WithCancel(context.Background()) - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnGC) - var wg sync.WaitGroup - wg.Add(1) - go w.start(ctx, &wg) - wg.Wait() // Wait create session finish in worker, some test code depend on this to avoid race. -} - -// Close stops background goroutines. -func (w *GCWorker) Close() { - w.cancel() -} - -const ( - booleanTrue = "true" - booleanFalse = "false" - - gcWorkerTickInterval = time.Minute - gcWorkerLease = time.Minute * 2 - gcLeaderUUIDKey = "tikv_gc_leader_uuid" - gcLeaderDescKey = "tikv_gc_leader_desc" - gcLeaderLeaseKey = "tikv_gc_leader_lease" - - gcLastRunTimeKey = "tikv_gc_last_run_time" - gcRunIntervalKey = "tikv_gc_run_interval" - gcDefaultRunInterval = time.Minute * 10 - gcWaitTime = time.Minute * 1 - gcRedoDeleteRangeDelay = 24 * time.Hour - - gcLifeTimeKey = "tikv_gc_life_time" - gcDefaultLifeTime = time.Minute * 10 - gcMinLifeTime = time.Minute * 10 - gcSafePointKey = "tikv_gc_safe_point" - gcConcurrencyKey = "tikv_gc_concurrency" - gcDefaultConcurrency = 2 - gcMinConcurrency = 1 - gcMaxConcurrency = 128 - - gcEnableKey = "tikv_gc_enable" - gcDefaultEnableValue = true - - gcModeKey = "tikv_gc_mode" - gcModeCentral = "central" - gcModeDistributed = "distributed" - gcModeDefault = gcModeDistributed - - gcScanLockModeKey = "tikv_gc_scan_lock_mode" - - gcAutoConcurrencyKey = "tikv_gc_auto_concurrency" - gcDefaultAutoConcurrency = true - - gcWorkerServiceSafePointID = "gc_worker" - - // Status var names start with tidb_% - tidbGCLastRunTime = "tidb_gc_last_run_time" - tidbGCLeaderDesc = "tidb_gc_leader_desc" - tidbGCLeaderLease = "tidb_gc_leader_lease" - tidbGCLeaderUUID = "tidb_gc_leader_uuid" - tidbGCSafePoint = "tidb_gc_safe_point" -) - -var gcSafePointCacheInterval = tikv.GcSafePointCacheInterval - -var gcVariableComments = map[string]string{ - gcLeaderUUIDKey: "Current GC worker leader UUID. (DO NOT EDIT)", - gcLeaderDescKey: "Host name and pid of current GC leader. (DO NOT EDIT)", - gcLeaderLeaseKey: "Current GC worker leader lease. (DO NOT EDIT)", - gcLastRunTimeKey: "The time when last GC starts. 
(DO NOT EDIT)", - gcRunIntervalKey: "GC run interval, at least 10m, in Go format.", - gcLifeTimeKey: "All versions within life time will not be collected by GC, at least 10m, in Go format.", - gcSafePointKey: "All versions after safe point can be accessed. (DO NOT EDIT)", - gcConcurrencyKey: "How many goroutines used to do GC parallel, [1, 128], default 2", - gcEnableKey: "Current GC enable status", - gcModeKey: "Mode of GC, \"central\" or \"distributed\"", - gcAutoConcurrencyKey: "Let TiDB pick the concurrency automatically. If set false, tikv_gc_concurrency will be used", - gcScanLockModeKey: "Mode of scanning locks, \"physical\" or \"legacy\".(Deprecated)", -} - -const ( - unsafeDestroyRangeTimeout = 5 * time.Minute - gcTimeout = 5 * time.Minute -) - -func (w *GCWorker) start(ctx context.Context, wg *sync.WaitGroup) { - logutil.Logger(ctx).Info("start", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid)) - - w.tick(ctx) // Immediately tick once to initialize configs. - wg.Done() - - ticker := time.NewTicker(gcWorkerTickInterval) - defer ticker.Stop() - defer func() { - r := recover() - if r != nil { - logutil.Logger(ctx).Error("gcWorker", - zap.Any("r", r), - zap.Stack("stack")) - metrics.PanicCounter.WithLabelValues(metrics.LabelGCWorker).Inc() - } - }() - for { - select { - case <-ticker.C: - w.tick(ctx) - case err := <-w.done: - w.gcIsRunning = false - w.lastFinish = time.Now() - if err != nil { - logutil.Logger(ctx).Error("runGCJob", zap.String("category", "gc worker"), zap.Error(err)) - } - case <-ctx.Done(): - logutil.Logger(ctx).Info("quit", zap.String("category", "gc worker"), zap.String("uuid", w.uuid)) - return - } - } -} - -func createSession(store kv.Storage) sessiontypes.Session { - for { - se, err := session.CreateSession(store) - if err != nil { - logutil.BgLogger().Warn("create session", zap.String("category", "gc worker"), zap.Error(err)) - continue - } - // Disable privilege check for gc worker session. - privilege.BindPrivilegeManager(se, nil) - se.GetSessionVars().CommonGlobalLoaded = true - se.GetSessionVars().InRestrictedSQL = true - se.GetSessionVars().SetDiskFullOpt(kvrpcpb.DiskFullOpt_AllowedOnAlmostFull) - return se - } -} - -// GetScope gets the status variables scope. -func (w *GCWorker) GetScope(status string) variable.ScopeFlag { - return variable.DefaultStatusVarScopeFlag -} - -// Stats returns the server statistics. 
-func (w *GCWorker) Stats(vars *variable.SessionVars) (map[string]any, error) { - m := make(map[string]any) - if v, err := w.loadValueFromSysTable(gcLeaderUUIDKey); err == nil { - m[tidbGCLeaderUUID] = v - } - if v, err := w.loadValueFromSysTable(gcLeaderDescKey); err == nil { - m[tidbGCLeaderDesc] = v - } - if v, err := w.loadValueFromSysTable(gcLeaderLeaseKey); err == nil { - m[tidbGCLeaderLease] = v - } - if v, err := w.loadValueFromSysTable(gcLastRunTimeKey); err == nil { - m[tidbGCLastRunTime] = v - } - if v, err := w.loadValueFromSysTable(gcSafePointKey); err == nil { - m[tidbGCSafePoint] = v - } - return m, nil -} - -func (w *GCWorker) tick(ctx context.Context) { - isLeader, err := w.checkLeader(ctx) - if err != nil { - logutil.Logger(ctx).Warn("check leader", zap.String("category", "gc worker"), zap.Error(err)) - metrics.GCJobFailureCounter.WithLabelValues("check_leader").Inc() - return - } - if isLeader { - err = w.leaderTick(ctx) - if err != nil { - logutil.Logger(ctx).Warn("leader tick", zap.String("category", "gc worker"), zap.Error(err)) - } - } else { - // Config metrics should always be updated by leader, set them to 0 when current instance is not leader. - metrics.GCConfigGauge.WithLabelValues(gcRunIntervalKey).Set(0) - metrics.GCConfigGauge.WithLabelValues(gcLifeTimeKey).Set(0) - } -} - -// getGCSafePoint returns the current gc safe point. -func getGCSafePoint(ctx context.Context, pdClient pd.Client) (uint64, error) { - // If there is try to set gc safepoint is 0, the interface will not set gc safepoint to 0, - // it will return current gc safepoint. - safePoint, err := pdClient.UpdateGCSafePoint(ctx, 0) - if err != nil { - return 0, errors.Trace(err) - } - return safePoint, nil -} - -func (w *GCWorker) logIsGCSafePointTooEarly(ctx context.Context, safePoint uint64) error { - now, err := w.getOracleTime() - if err != nil { - return errors.Trace(err) - } - - checkTs := oracle.GoTimeToTS(now.Add(-gcDefaultLifeTime * 2)) - if checkTs > safePoint { - logutil.Logger(ctx).Info("gc safepoint is too early. "+ - "Maybe there is a bit BR/Lightning/CDC task, "+ - "or a long transaction is running "+ - "or need a tidb without setting keyspace-name to calculate and update gc safe point.", - zap.String("category", "gc worker")) - } - return nil -} - -func (w *GCWorker) runKeyspaceDeleteRange(ctx context.Context, concurrency int) error { - // Get safe point from PD. - // The GC safe point is updated only after the global GC have done resolveLocks phase globally. - // So, in the following code, resolveLocks must have been done by the global GC on the ranges to be deleted, - // so its safe to delete the ranges. 
- safePoint, err := getGCSafePoint(ctx, w.pdClient) - if err != nil { - logutil.Logger(ctx).Info("get gc safe point error", zap.String("category", "gc worker"), zap.Error(errors.Trace(err))) - return nil - } - - if safePoint == 0 { - logutil.Logger(ctx).Info("skip keyspace delete range, because gc safe point is 0", zap.String("category", "gc worker")) - return nil - } - - err = w.logIsGCSafePointTooEarly(ctx, safePoint) - if err != nil { - logutil.Logger(ctx).Info("log is gc safe point is too early error", zap.String("category", "gc worker"), zap.Error(errors.Trace(err))) - return nil - } - - keyspaceID := w.store.GetCodec().GetKeyspaceID() - logutil.Logger(ctx).Info("start keyspace delete range", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Int("concurrency", concurrency), - zap.Uint32("keyspaceID", uint32(keyspaceID)), - zap.Uint64("GCSafepoint", safePoint)) - - // Do deleteRanges. - err = w.deleteRanges(ctx, safePoint, concurrency) - if err != nil { - logutil.Logger(ctx).Error("delete range returns an error", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Error(err)) - metrics.GCJobFailureCounter.WithLabelValues("delete_range").Inc() - return errors.Trace(err) - } - - // Do redoDeleteRanges. - err = w.redoDeleteRanges(ctx, safePoint, concurrency) - if err != nil { - logutil.Logger(ctx).Error("redo-delete range returns an error", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Error(err)) - metrics.GCJobFailureCounter.WithLabelValues("redo_delete_range").Inc() - return errors.Trace(err) - } - - return nil -} - -// leaderTick of GC worker checks if it should start a GC job every tick. -func (w *GCWorker) leaderTick(ctx context.Context) error { - if w.gcIsRunning { - logutil.Logger(ctx).Info("there's already a gc job running, skipped", zap.String("category", "gc worker"), - zap.String("leaderTick on", w.uuid)) - return nil - } - - concurrency, err := w.getGCConcurrency(ctx) - if err != nil { - logutil.Logger(ctx).Info("failed to get gc concurrency.", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Error(err)) - return errors.Trace(err) - } - - // Gc safe point is not separated by keyspace now. The whole cluster has only one global gc safe point. - // So at least one TiDB with `keyspace-name` not set is required in the whole cluster to calculate and update gc safe point. - // If `keyspace-name` is set, the TiDB node will only do its own delete range, and will not calculate gc safe point and resolve locks. - // Note that when `keyspace-name` is set, `checkLeader` will be done within the key space. - // Therefore only one TiDB node in each key space will be responsible to do delete range. - if w.store.GetCodec().GetKeyspace() != nil { - err = w.runKeyspaceGCJob(ctx, concurrency) - if err != nil { - return errors.Trace(err) - } - return nil - } - - ok, safePoint, err := w.prepare(ctx) - if err != nil { - metrics.GCJobFailureCounter.WithLabelValues("prepare").Inc() - return errors.Trace(err) - } else if !ok { - return nil - } - // When the worker is just started, or an old GC job has just finished, - // wait a while before starting a new job. 
- if time.Since(w.lastFinish) < gcWaitTime { - logutil.Logger(ctx).Info("another gc job has just finished, skipped.", zap.String("category", "gc worker"), - zap.String("leaderTick on ", w.uuid)) - return nil - } - - w.gcIsRunning = true - logutil.Logger(ctx).Info("starts the whole job", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Uint64("safePoint", safePoint), - zap.Int("concurrency", concurrency)) - go func() { - w.done <- w.runGCJob(ctx, safePoint, concurrency) - }() - return nil -} - -func (w *GCWorker) runKeyspaceGCJob(ctx context.Context, concurrency int) error { - // When the worker is just started, or an old GC job has just finished, - // wait a while before starting a new job. - if time.Since(w.lastFinish) < gcWaitTime { - logutil.Logger(ctx).Info("another keyspace gc job has just finished, skipped.", zap.String("category", "gc worker"), - zap.String("leaderTick on ", w.uuid)) - return nil - } - - now, err := w.getOracleTime() - if err != nil { - return errors.Trace(err) - } - ok, err := w.checkGCInterval(now) - if err != nil || !ok { - return errors.Trace(err) - } - - go func() { - w.done <- w.runKeyspaceDeleteRange(ctx, concurrency) - }() - - err = w.saveTime(gcLastRunTimeKey, now) - if err != nil { - return errors.Trace(err) - } - - return nil -} - -// prepare checks preconditions for starting a GC job. It returns a bool -// that indicates whether the GC job should start and the new safePoint. -func (w *GCWorker) prepare(ctx context.Context) (bool, uint64, error) { - // Add a transaction here is to prevent following situations: - // 1. GC check gcEnable is true, continue to do GC - // 2. The user sets gcEnable to false - // 3. The user gets `tikv_gc_safe_point` value is t1, then the user thinks the data after time t1 won't be clean by GC. - // 4. GC update `tikv_gc_safe_point` value to t2, continue do GC in this round. - // Then the data record that has been dropped between time t1 and t2, will be cleaned by GC, but the user thinks the data after t1 won't be clean by GC. 
- se := createSession(w.store) - defer se.Close() - _, err := se.ExecuteInternal(ctx, "BEGIN") - if err != nil { - return false, 0, errors.Trace(err) - } - doGC, safePoint, err := w.checkPrepare(ctx) - if doGC { - err = se.CommitTxn(ctx) - if err != nil { - return false, 0, errors.Trace(err) - } - } else { - se.RollbackTxn(ctx) - } - return doGC, safePoint, errors.Trace(err) -} - -func (w *GCWorker) checkPrepare(ctx context.Context) (bool, uint64, error) { - enable, err := w.checkGCEnable() - if err != nil { - return false, 0, errors.Trace(err) - } - - if !enable { - logutil.Logger(ctx).Warn("gc status is disabled.", zap.String("category", "gc worker")) - return false, 0, nil - } - now, err := w.getOracleTime() - if err != nil { - return false, 0, errors.Trace(err) - } - ok, err := w.checkGCInterval(now) - if err != nil || !ok { - return false, 0, errors.Trace(err) - } - newSafePoint, newSafePointValue, err := w.calcNewSafePoint(ctx, now) - if err != nil || newSafePoint == nil { - return false, 0, errors.Trace(err) - } - err = w.saveTime(gcLastRunTimeKey, now) - if err != nil { - return false, 0, errors.Trace(err) - } - err = w.saveTime(gcSafePointKey, *newSafePoint) - if err != nil { - return false, 0, errors.Trace(err) - } - return true, newSafePointValue, nil -} - -func (w *GCWorker) calcGlobalMinStartTS(ctx context.Context) (uint64, error) { - kvs, err := w.tikvStore.GetSafePointKV().GetWithPrefix(infosync.ServerMinStartTSPath) - if err != nil { - return 0, err - } - - var globalMinStartTS uint64 = math.MaxUint64 - for _, v := range kvs { - minStartTS, err := strconv.ParseUint(string(v.Value), 10, 64) - if err != nil { - logutil.Logger(ctx).Warn("parse minStartTS failed", zap.Error(err)) - continue - } - if minStartTS < globalMinStartTS { - globalMinStartTS = minStartTS - } - } - return globalMinStartTS, nil -} - -// calcNewSafePoint uses the current global transaction min start timestamp to calculate the new safe point. -func (w *GCWorker) calcSafePointByMinStartTS(ctx context.Context, safePoint uint64) uint64 { - globalMinStartTS, err := w.calcGlobalMinStartTS(ctx) - if err != nil { - logutil.Logger(ctx).Warn("get all minStartTS failed", zap.Error(err)) - return safePoint - } - - // If the lock.ts <= max_ts(safePoint), it will be collected and resolved by the gc worker, - // the locks of ongoing pessimistic transactions could be resolved by the gc worker and then - // the transaction is aborted, decrement the value by 1 to avoid this. 
- globalMinStartAllowedTS := globalMinStartTS - if globalMinStartTS > 0 { - globalMinStartAllowedTS = globalMinStartTS - 1 - } - - if globalMinStartAllowedTS < safePoint { - logutil.Logger(ctx).Info("gc safepoint blocked by a running session", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Uint64("globalMinStartTS", globalMinStartTS), - zap.Uint64("globalMinStartAllowedTS", globalMinStartAllowedTS), - zap.Uint64("safePoint", safePoint)) - safePoint = globalMinStartAllowedTS - } - return safePoint -} - -func (w *GCWorker) getOracleTime() (time.Time, error) { - currentVer, err := w.store.CurrentVersion(kv.GlobalTxnScope) - if err != nil { - return time.Time{}, errors.Trace(err) - } - return oracle.GetTimeFromTS(currentVer.Ver), nil -} - -func (w *GCWorker) checkGCEnable() (bool, error) { - return w.loadBooleanWithDefault(gcEnableKey, gcDefaultEnableValue) -} - -func (w *GCWorker) checkUseAutoConcurrency() (bool, error) { - return w.loadBooleanWithDefault(gcAutoConcurrencyKey, gcDefaultAutoConcurrency) -} - -func (w *GCWorker) loadBooleanWithDefault(key string, defaultValue bool) (bool, error) { - str, err := w.loadValueFromSysTable(key) - if err != nil { - return false, errors.Trace(err) - } - if str == "" { - // Save default value for gc enable key. The default value is always true. - defaultValueStr := booleanFalse - if defaultValue { - defaultValueStr = booleanTrue - } - err = w.saveValueToSysTable(key, defaultValueStr) - if err != nil { - return defaultValue, errors.Trace(err) - } - return defaultValue, nil - } - return strings.EqualFold(str, booleanTrue), nil -} - -func (w *GCWorker) getGCConcurrency(ctx context.Context) (int, error) { - useAutoConcurrency, err := w.checkUseAutoConcurrency() - if err != nil { - logutil.Logger(ctx).Error("failed to load config gc_auto_concurrency. use default value.", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Error(err)) - useAutoConcurrency = gcDefaultAutoConcurrency - } - if !useAutoConcurrency { - return w.loadGCConcurrencyWithDefault() - } - - stores, err := w.getStoresForGC(ctx) - concurrency := len(stores) - if err != nil { - logutil.Logger(ctx).Error("failed to get up stores to calculate concurrency. use config.", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Error(err)) - - concurrency, err = w.loadGCConcurrencyWithDefault() - if err != nil { - logutil.Logger(ctx).Error("failed to load gc concurrency from config. 
use default value.", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Error(err)) - concurrency = gcDefaultConcurrency - } - } - - if concurrency == 0 { - logutil.Logger(ctx).Error("no store is up", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid)) - return 0, errors.New("[gc worker] no store is up") - } - - return concurrency, nil -} - -func (w *GCWorker) checkGCInterval(now time.Time) (bool, error) { - runInterval, err := w.loadDurationWithDefault(gcRunIntervalKey, gcDefaultRunInterval) - if err != nil { - return false, errors.Trace(err) - } - metrics.GCConfigGauge.WithLabelValues(gcRunIntervalKey).Set(runInterval.Seconds()) - lastRun, err := w.loadTime(gcLastRunTimeKey) - if err != nil { - return false, errors.Trace(err) - } - - if lastRun != nil && lastRun.Add(*runInterval).After(now) { - logutil.BgLogger().Debug("skipping garbage collection because gc interval hasn't elapsed since last run", zap.String("category", "gc worker"), - zap.String("leaderTick on", w.uuid), - zap.Duration("interval", *runInterval), - zap.Time("last run", *lastRun)) - return false, nil - } - - return true, nil -} - -// validateGCLifeTime checks whether life time is small than min gc life time. -func (w *GCWorker) validateGCLifeTime(lifeTime time.Duration) (time.Duration, error) { - if lifeTime >= gcMinLifeTime { - return lifeTime, nil - } - - logutil.BgLogger().Info("invalid gc life time", zap.String("category", "gc worker"), - zap.Duration("get gc life time", lifeTime), - zap.Duration("min gc life time", gcMinLifeTime)) - - err := w.saveDuration(gcLifeTimeKey, gcMinLifeTime) - return gcMinLifeTime, err -} - -func (w *GCWorker) calcNewSafePoint(ctx context.Context, now time.Time) (*time.Time, uint64, error) { - lifeTime, err := w.loadDurationWithDefault(gcLifeTimeKey, gcDefaultLifeTime) - if err != nil { - return nil, 0, errors.Trace(err) - } - *lifeTime, err = w.validateGCLifeTime(*lifeTime) - if err != nil { - return nil, 0, err - } - metrics.GCConfigGauge.WithLabelValues(gcLifeTimeKey).Set(lifeTime.Seconds()) - - lastSafePoint, err := w.loadTime(gcSafePointKey) - if err != nil { - return nil, 0, errors.Trace(err) - } - - safePointValue := w.calcSafePointByMinStartTS(ctx, oracle.GoTimeToTS(now.Add(-*lifeTime))) - safePointValue, err = w.setGCWorkerServiceSafePoint(ctx, safePointValue) - if err != nil { - return nil, 0, errors.Trace(err) - } - - // safepoint is recorded in time.Time format which strips the logical part of the timestamp. - // To prevent the GC worker from keeping working due to the loss of logical part when the - // safe point isn't changed, we should compare them in time.Time format. - safePoint := oracle.GetTimeFromTS(safePointValue) - // We should never decrease safePoint. - if lastSafePoint != nil && !safePoint.After(*lastSafePoint) { - logutil.BgLogger().Info("last safe point is later than current one."+ - "No need to gc."+ - "This might be caused by manually enlarging gc lifetime", - zap.String("category", "gc worker"), - zap.String("leaderTick on", w.uuid), - zap.Time("last safe point", *lastSafePoint), - zap.Time("current safe point", safePoint)) - return nil, 0, nil - } - return &safePoint, safePointValue, nil -} - -// setGCWorkerServiceSafePoint sets the given safePoint as TiDB's service safePoint to PD, and returns the current minimal -// service safePoint among all services. -func (w *GCWorker) setGCWorkerServiceSafePoint(ctx context.Context, safePoint uint64) (uint64, error) { - // Sets TTL to MAX to make it permanently valid. 
- minSafePoint, err := w.pdClient.UpdateServiceGCSafePoint(ctx, gcWorkerServiceSafePointID, math.MaxInt64, safePoint) - if err != nil { - logutil.Logger(ctx).Error("failed to update service safe point", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Error(err)) - metrics.GCJobFailureCounter.WithLabelValues("update_service_safe_point").Inc() - return 0, errors.Trace(err) - } - if minSafePoint < safePoint { - logutil.Logger(ctx).Info("there's another service in the cluster requires an earlier safe point. "+ - "gc will continue with the earlier one", - zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Uint64("ourSafePoint", safePoint), - zap.Uint64("minSafePoint", minSafePoint), - ) - safePoint = minSafePoint - } - return safePoint, nil -} - -func (w *GCWorker) runGCJob(ctx context.Context, safePoint uint64, concurrency int) error { - failpoint.Inject("mockRunGCJobFail", func() { - failpoint.Return(errors.New("mock failure of runGCJoB")) - }) - metrics.GCWorkerCounter.WithLabelValues("run_job").Inc() - - err := w.resolveLocks(ctx, safePoint, concurrency) - if err != nil { - logutil.Logger(ctx).Error("resolve locks returns an error", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Error(err)) - metrics.GCJobFailureCounter.WithLabelValues("resolve_lock").Inc() - return errors.Trace(err) - } - - // Save safe point to pd. - err = w.saveSafePoint(w.tikvStore.GetSafePointKV(), safePoint) - if err != nil { - logutil.Logger(ctx).Error("failed to save safe point to PD", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Error(err)) - metrics.GCJobFailureCounter.WithLabelValues("save_safe_point").Inc() - return errors.Trace(err) - } - // Sleep to wait for all other tidb instances update their safepoint cache. - time.Sleep(gcSafePointCacheInterval) - - err = w.deleteRanges(ctx, safePoint, concurrency) - if err != nil { - logutil.Logger(ctx).Error("delete range returns an error", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Error(err)) - metrics.GCJobFailureCounter.WithLabelValues("delete_range").Inc() - return errors.Trace(err) - } - err = w.redoDeleteRanges(ctx, safePoint, concurrency) - if err != nil { - logutil.Logger(ctx).Error("redo-delete range returns an error", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Error(err)) - metrics.GCJobFailureCounter.WithLabelValues("redo_delete_range").Inc() - return errors.Trace(err) - } - - if w.checkUseDistributedGC() { - err = w.uploadSafePointToPD(ctx, safePoint) - if err != nil { - logutil.Logger(ctx).Error("failed to upload safe point to PD", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Error(err)) - metrics.GCJobFailureCounter.WithLabelValues("upload_safe_point").Inc() - return errors.Trace(err) - } - } else { - err = w.doGC(ctx, safePoint, concurrency) - if err != nil { - logutil.Logger(ctx).Error("do GC returns an error", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Error(err)) - metrics.GCJobFailureCounter.WithLabelValues("gc").Inc() - return errors.Trace(err) - } - } - - return nil -} - -// deleteRanges processes all delete range records whose ts < safePoint in table `gc_delete_range` -// `concurrency` specifies the concurrency to send NotifyDeleteRange. 
-func (w *GCWorker) deleteRanges(ctx context.Context, safePoint uint64, concurrency int) error { - metrics.GCWorkerCounter.WithLabelValues("delete_range").Inc() - - se := createSession(w.store) - defer se.Close() - ranges, err := util.LoadDeleteRanges(ctx, se, safePoint) - if err != nil { - return errors.Trace(err) - } - - v2, err := util.IsRaftKv2(ctx, se) - if err != nil { - return errors.Trace(err) - } - // Cache table ids on which placement rules have been GC-ed, to avoid redundantly GC the same table id multiple times. - gcPlacementRuleCache := make(map[int64]any, len(ranges)) - - logutil.Logger(ctx).Info("start delete ranges", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Int("ranges", len(ranges))) - startTime := time.Now() - for _, r := range ranges { - startKey, endKey := r.Range() - if v2 { - // In raftstore-v2, we use delete range instead to avoid deletion omission - task := rangetask.NewDeleteRangeTask(w.tikvStore, startKey, endKey, concurrency) - err = task.Execute(ctx) - } else { - err = w.doUnsafeDestroyRangeRequest(ctx, startKey, endKey, concurrency) - } - failpoint.Inject("ignoreDeleteRangeFailed", func() { - err = nil - }) - - if err != nil { - logutil.Logger(ctx).Error("delete range failed on range", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Stringer("startKey", startKey), - zap.Stringer("endKey", endKey), - zap.Error(err)) - continue - } - - if err := w.doGCPlacementRules(se, safePoint, r, gcPlacementRuleCache); err != nil { - logutil.Logger(ctx).Error("gc placement rules failed on range", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Int64("jobID", r.JobID), - zap.Int64("elementID", r.ElementID), - zap.Error(err)) - continue - } - if err := w.doGCLabelRules(r); err != nil { - logutil.Logger(ctx).Error("gc label rules failed on range", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Int64("jobID", r.JobID), - zap.Int64("elementID", r.ElementID), - zap.Error(err)) - continue - } - - err = util.CompleteDeleteRange(se, r, !v2) - if err != nil { - logutil.Logger(ctx).Error("failed to mark delete range task done", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Stringer("startKey", startKey), - zap.Stringer("endKey", endKey), - zap.Error(err)) - metrics.GCUnsafeDestroyRangeFailuresCounterVec.WithLabelValues("save").Inc() - } - } - logutil.Logger(ctx).Info("finish delete ranges", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Int("num of ranges", len(ranges)), - zap.Duration("cost time", time.Since(startTime))) - metrics.GCHistogram.WithLabelValues("delete_ranges").Observe(time.Since(startTime).Seconds()) - return nil -} - -// redoDeleteRanges checks all deleted ranges whose ts is at least `lifetime + 24h` ago. See TiKV RFC #2. -// `concurrency` specifies the concurrency to send NotifyDeleteRange. -func (w *GCWorker) redoDeleteRanges(ctx context.Context, safePoint uint64, concurrency int) error { - metrics.GCWorkerCounter.WithLabelValues("redo_delete_range").Inc() - - // We check delete range records that are deleted about 24 hours ago. 
- redoDeleteRangesTs := safePoint - oracle.ComposeTS(int64(gcRedoDeleteRangeDelay.Seconds())*1000, 0) - - se := createSession(w.store) - ranges, err := util.LoadDoneDeleteRanges(ctx, se, redoDeleteRangesTs) - se.Close() - if err != nil { - return errors.Trace(err) - } - - logutil.Logger(ctx).Info("start redo-delete ranges", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Int("num of ranges", len(ranges))) - startTime := time.Now() - for _, r := range ranges { - startKey, endKey := r.Range() - - err = w.doUnsafeDestroyRangeRequest(ctx, startKey, endKey, concurrency) - if err != nil { - logutil.Logger(ctx).Error("redo-delete range failed on range", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Stringer("startKey", startKey), - zap.Stringer("endKey", endKey), - zap.Error(err)) - continue - } - - se := createSession(w.store) - err := util.DeleteDoneRecord(se, r) - se.Close() - if err != nil { - logutil.Logger(ctx).Error("failed to remove delete_range_done record", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Stringer("startKey", startKey), - zap.Stringer("endKey", endKey), - zap.Error(err)) - metrics.GCUnsafeDestroyRangeFailuresCounterVec.WithLabelValues("save_redo").Inc() - } - } - logutil.Logger(ctx).Info("finish redo-delete ranges", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Int("num of ranges", len(ranges)), - zap.Duration("cost time", time.Since(startTime))) - metrics.GCHistogram.WithLabelValues("redo_delete_ranges").Observe(time.Since(startTime).Seconds()) - return nil -} - -func (w *GCWorker) doUnsafeDestroyRangeRequest(ctx context.Context, startKey []byte, endKey []byte, _ int) error { - // Get all stores every time deleting a region. So the store list is less probably to be stale. 
- stores, err := w.getStoresForGC(ctx) - if err != nil { - logutil.Logger(ctx).Error("delete ranges: got an error while trying to get store list from PD", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Error(err)) - metrics.GCUnsafeDestroyRangeFailuresCounterVec.WithLabelValues("get_stores").Inc() - return errors.Trace(err) - } - - req := tikvrpc.NewRequest(tikvrpc.CmdUnsafeDestroyRange, &kvrpcpb.UnsafeDestroyRangeRequest{ - StartKey: startKey, - EndKey: endKey, - }, kvrpcpb.Context{DiskFullOpt: kvrpcpb.DiskFullOpt_AllowedOnAlmostFull}) - - var wg sync.WaitGroup - errChan := make(chan error, len(stores)) - - for _, store := range stores { - address := store.Address - storeID := store.Id - wg.Add(1) - go func() { - defer wg.Done() - - resp, err1 := w.tikvStore.GetTiKVClient().SendRequest(ctx, address, req, unsafeDestroyRangeTimeout) - if err1 == nil { - if resp == nil || resp.Resp == nil { - err1 = errors.Errorf("unsafe destroy range returns nil response from store %v", storeID) - } else { - errStr := (resp.Resp.(*kvrpcpb.UnsafeDestroyRangeResponse)).Error - if len(errStr) > 0 { - err1 = errors.Errorf("unsafe destroy range failed on store %v: %s", storeID, errStr) - } - } - } - - if err1 != nil { - metrics.GCUnsafeDestroyRangeFailuresCounterVec.WithLabelValues("send").Inc() - } - errChan <- err1 - }() - } - - var errs []string - for range stores { - err1 := <-errChan - if err1 != nil { - errs = append(errs, err1.Error()) - } - } - - wg.Wait() - - if len(errs) > 0 { - return errors.Errorf("[gc worker] destroy range finished with errors: %v", errs) - } - - return nil -} - -// needsGCOperationForStore checks if the store-level requests related to GC needs to be sent to the store. The store-level -// requests includes UnsafeDestroyRange, PhysicalScanLock, etc. -func needsGCOperationForStore(store *metapb.Store) (bool, error) { - // TombStone means the store has been removed from the cluster and there isn't any peer on the store, so needn't do GC for it. - // Offline means the store is being removed from the cluster and it becomes tombstone after all peers are removed from it, - // so we need to do GC for it. - if store.State == metapb.StoreState_Tombstone { - return false, nil - } - - engineLabel := "" - for _, label := range store.GetLabels() { - if label.GetKey() == placement.EngineLabelKey { - engineLabel = label.GetValue() - break - } - } - - switch engineLabel { - case placement.EngineLabelTiFlash: - // For a TiFlash node, it uses other approach to delete dropped tables, so it's safe to skip sending - // UnsafeDestroyRange requests; it has only learner peers and their data must exist in TiKV, so it's safe to - // skip physical resolve locks for it. - return false, nil - - case placement.EngineLabelTiFlashCompute: - logutil.BgLogger().Debug("will ignore gc tiflash_compute node", zap.String("category", "gc worker")) - return false, nil - - case placement.EngineLabelTiKV, "": - // If no engine label is set, it should be a TiKV node. - return true, nil - - default: - return true, errors.Errorf("unsupported store engine \"%v\" with storeID %v, addr %v", - engineLabel, - store.GetId(), - store.GetAddress()) - } -} - -// getStoresForGC gets the list of stores that needs to be processed during GC. 
-func (w *GCWorker) getStoresForGC(ctx context.Context) ([]*metapb.Store, error) { - stores, err := w.pdClient.GetAllStores(ctx) - if err != nil { - return nil, errors.Trace(err) - } - - upStores := make([]*metapb.Store, 0, len(stores)) - for _, store := range stores { - needsGCOp, err := needsGCOperationForStore(store) - if err != nil { - return nil, errors.Trace(err) - } - if needsGCOp { - upStores = append(upStores, store) - } - } - return upStores, nil -} - -func (w *GCWorker) getStoresMapForGC(ctx context.Context) (map[uint64]*metapb.Store, error) { - stores, err := w.getStoresForGC(ctx) - if err != nil { - return nil, err - } - - storesMap := make(map[uint64]*metapb.Store, len(stores)) - for _, store := range stores { - storesMap[store.Id] = store - } - - return storesMap, nil -} - -func (w *GCWorker) loadGCConcurrencyWithDefault() (int, error) { - str, err := w.loadValueFromSysTable(gcConcurrencyKey) - if err != nil { - return gcDefaultConcurrency, errors.Trace(err) - } - if str == "" { - err = w.saveValueToSysTable(gcConcurrencyKey, strconv.Itoa(gcDefaultConcurrency)) - if err != nil { - return gcDefaultConcurrency, errors.Trace(err) - } - return gcDefaultConcurrency, nil - } - - jobConcurrency, err := strconv.Atoi(str) - if err != nil { - return gcDefaultConcurrency, err - } - - if jobConcurrency < gcMinConcurrency { - jobConcurrency = gcMinConcurrency - } - - if jobConcurrency > gcMaxConcurrency { - jobConcurrency = gcMaxConcurrency - } - - return jobConcurrency, nil -} - -// Central mode is deprecated in v5.0. This function will always return true. -func (w *GCWorker) checkUseDistributedGC() bool { - mode, err := w.loadValueFromSysTable(gcModeKey) - if err == nil && mode == "" { - err = w.saveValueToSysTable(gcModeKey, gcModeDefault) - } - if err != nil { - logutil.BgLogger().Error("failed to load gc mode, fall back to distributed mode", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Error(err)) - metrics.GCJobFailureCounter.WithLabelValues("check_gc_mode").Inc() - } else if strings.EqualFold(mode, gcModeCentral) { - logutil.BgLogger().Warn("distributed mode will be used as central mode is deprecated", zap.String("category", "gc worker")) - } else if !strings.EqualFold(mode, gcModeDistributed) { - logutil.BgLogger().Warn("distributed mode will be used", zap.String("category", "gc worker"), - zap.String("invalid gc mode", mode)) - } - return true -} - -func (w *GCWorker) resolveLocks( - ctx context.Context, - safePoint uint64, - concurrency int, -) error { - metrics.GCWorkerCounter.WithLabelValues("resolve_locks").Inc() - logutil.Logger(ctx).Info("start resolve locks", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Uint64("safePoint", safePoint), - zap.Int("concurrency", concurrency)) - startTime := time.Now() - - handler := func(ctx context.Context, r tikvstore.KeyRange) (rangetask.TaskStat, error) { - scanLimit := uint32(tikv.GCScanLockLimit) - failpoint.Inject("lowScanLockLimit", func() { - scanLimit = 3 - }) - return tikv.ResolveLocksForRange(ctx, w.regionLockResolver, safePoint, r.StartKey, r.EndKey, tikv.NewGcResolveLockMaxBackoffer, scanLimit) - } - - runner := rangetask.NewRangeTaskRunner("resolve-locks-runner", w.tikvStore, concurrency, handler) - // Run resolve lock on the whole TiKV cluster. Empty keys means the range is unbounded. 
- err := runner.RunOnRange(ctx, []byte(""), []byte("")) - if err != nil { - logutil.Logger(ctx).Error("resolve locks failed", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Uint64("safePoint", safePoint), - zap.Error(err)) - return errors.Trace(err) - } - - logutil.Logger(ctx).Info("finish resolve locks", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Uint64("safePoint", safePoint), - zap.Int("regions", runner.CompletedRegions())) - metrics.GCHistogram.WithLabelValues("resolve_locks").Observe(time.Since(startTime).Seconds()) - return nil -} - -const gcOneRegionMaxBackoff = 20000 - -func (w *GCWorker) uploadSafePointToPD(ctx context.Context, safePoint uint64) error { - var newSafePoint uint64 - var err error - - bo := tikv.NewBackofferWithVars(ctx, gcOneRegionMaxBackoff, nil) - for { - newSafePoint, err = w.pdClient.UpdateGCSafePoint(ctx, safePoint) - if err != nil { - if errors.Cause(err) == context.Canceled { - return errors.Trace(err) - } - err = bo.Backoff(tikv.BoPDRPC(), errors.Errorf("failed to upload safe point to PD, err: %v", err)) - if err != nil { - return errors.Trace(err) - } - continue - } - break - } - - if newSafePoint != safePoint { - logutil.Logger(ctx).Warn("PD rejected safe point", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Uint64("our safe point", safePoint), - zap.Uint64("using another safe point", newSafePoint)) - return errors.Errorf("PD rejected our safe point %v but is using another safe point %v", safePoint, newSafePoint) - } - logutil.Logger(ctx).Info("sent safe point to PD", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Uint64("safe point", safePoint)) - return nil -} - -func (w *GCWorker) doGCForRange(ctx context.Context, startKey []byte, endKey []byte, safePoint uint64) (rangetask.TaskStat, error) { - var stat rangetask.TaskStat - defer func() { - metrics.GCActionRegionResultCounter.WithLabelValues("success").Add(float64(stat.CompletedRegions)) - metrics.GCActionRegionResultCounter.WithLabelValues("fail").Add(float64(stat.FailedRegions)) - }() - key := startKey - for { - bo := tikv.NewBackofferWithVars(ctx, gcOneRegionMaxBackoff, nil) - loc, err := w.tikvStore.GetRegionCache().LocateKey(bo, key) - if err != nil { - return stat, errors.Trace(err) - } - - var regionErr *errorpb.Error - regionErr, err = w.doGCForRegion(bo, safePoint, loc.Region) - - // we check regionErr here first, because we know 'regionErr' and 'err' should not return together, to keep it to - // make the process correct. - if regionErr != nil { - err = bo.Backoff(tikv.BoRegionMiss(), errors.New(regionErr.String())) - if err == nil { - continue - } - } - - if err != nil { - logutil.BgLogger().Warn("[gc worker]", - zap.String("uuid", w.uuid), - zap.String("gc for range", fmt.Sprintf("[%d, %d)", startKey, endKey)), - zap.Uint64("safePoint", safePoint), - zap.Error(err)) - stat.FailedRegions++ - } else { - stat.CompletedRegions++ - } - - key = loc.EndKey - if len(key) == 0 || bytes.Compare(key, endKey) >= 0 { - break - } - } - - return stat, nil -} - -// doGCForRegion used for gc for region. 
-// these two errors should not return together, for more, see the func 'doGC' -func (w *GCWorker) doGCForRegion(bo *tikv.Backoffer, safePoint uint64, region tikv.RegionVerID) (*errorpb.Error, error) { - req := tikvrpc.NewRequest(tikvrpc.CmdGC, &kvrpcpb.GCRequest{ - SafePoint: safePoint, - }) - - resp, err := w.tikvStore.SendReq(bo, req, region, gcTimeout) - if err != nil { - return nil, errors.Trace(err) - } - regionErr, err := resp.GetRegionError() - if err != nil { - return nil, errors.Trace(err) - } - if regionErr != nil { - return regionErr, nil - } - - if resp.Resp == nil { - return nil, errors.Trace(tikverr.ErrBodyMissing) - } - gcResp := resp.Resp.(*kvrpcpb.GCResponse) - if gcResp.GetError() != nil { - return nil, errors.Errorf("unexpected gc error: %s", gcResp.GetError()) - } - - return nil, nil -} - -func (w *GCWorker) doGC(ctx context.Context, safePoint uint64, concurrency int) error { - metrics.GCWorkerCounter.WithLabelValues("do_gc").Inc() - logutil.Logger(ctx).Info("start doing gc for all keys", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Int("concurrency", concurrency), - zap.Uint64("safePoint", safePoint)) - startTime := time.Now() - - runner := rangetask.NewRangeTaskRunner( - "gc-runner", - w.tikvStore, - concurrency, - func(ctx context.Context, r tikvstore.KeyRange) (rangetask.TaskStat, error) { - return w.doGCForRange(ctx, r.StartKey, r.EndKey, safePoint) - }) - - err := runner.RunOnRange(ctx, []byte(""), []byte("")) - if err != nil { - logutil.Logger(ctx).Warn("failed to do gc for all keys", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Int("concurrency", concurrency), - zap.Error(err)) - return errors.Trace(err) - } - - successRegions := runner.CompletedRegions() - failedRegions := runner.FailedRegions() - - logutil.Logger(ctx).Info("finished doing gc for all keys", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid), - zap.Uint64("safePoint", safePoint), - zap.Int("successful regions", successRegions), - zap.Int("failed regions", failedRegions), - zap.Duration("total cost time", time.Since(startTime))) - metrics.GCHistogram.WithLabelValues("do_gc").Observe(time.Since(startTime).Seconds()) - - return nil -} - -func (w *GCWorker) checkLeader(ctx context.Context) (bool, error) { - metrics.GCWorkerCounter.WithLabelValues("check_leader").Inc() - se := createSession(w.store) - defer se.Close() - - _, err := se.ExecuteInternal(ctx, "BEGIN") - if err != nil { - return false, errors.Trace(err) - } - leader, err := w.loadValueFromSysTable(gcLeaderUUIDKey) - if err != nil { - se.RollbackTxn(ctx) - return false, errors.Trace(err) - } - logutil.BgLogger().Debug("got leader", zap.String("category", "gc worker"), zap.String("uuid", leader)) - if leader == w.uuid { - err = w.saveTime(gcLeaderLeaseKey, time.Now().Add(gcWorkerLease)) - if err != nil { - se.RollbackTxn(ctx) - return false, errors.Trace(err) - } - err = se.CommitTxn(ctx) - if err != nil { - return false, errors.Trace(err) - } - return true, nil - } - - se.RollbackTxn(ctx) - - _, err = se.ExecuteInternal(ctx, "BEGIN") - if err != nil { - return false, errors.Trace(err) - } - lease, err := w.loadTime(gcLeaderLeaseKey) - if err != nil { - se.RollbackTxn(ctx) - return false, errors.Trace(err) - } - if lease == nil || lease.Before(time.Now()) { - logutil.BgLogger().Debug("register as leader", zap.String("category", "gc worker"), - zap.String("uuid", w.uuid)) - metrics.GCWorkerCounter.WithLabelValues("register_leader").Inc() - - err = 
w.saveValueToSysTable(gcLeaderUUIDKey, w.uuid) - if err != nil { - se.RollbackTxn(ctx) - return false, errors.Trace(err) - } - err = w.saveValueToSysTable(gcLeaderDescKey, w.desc) - if err != nil { - se.RollbackTxn(ctx) - return false, errors.Trace(err) - } - err = w.saveTime(gcLeaderLeaseKey, time.Now().Add(gcWorkerLease)) - if err != nil { - se.RollbackTxn(ctx) - return false, errors.Trace(err) - } - err = se.CommitTxn(ctx) - if err != nil { - return false, errors.Trace(err) - } - return true, nil - } - se.RollbackTxn(ctx) - return false, nil -} - -func (w *GCWorker) saveSafePoint(kv tikv.SafePointKV, t uint64) error { - s := strconv.FormatUint(t, 10) - err := kv.Put(tikv.GcSavedSafePoint, s) - if err != nil { - logutil.BgLogger().Error("save safepoint failed", zap.Error(err)) - return errors.Trace(err) - } - return nil -} - -func (w *GCWorker) saveTime(key string, t time.Time) error { - err := w.saveValueToSysTable(key, t.Format(tikvutil.GCTimeFormat)) - return errors.Trace(err) -} - -func (w *GCWorker) loadTime(key string) (*time.Time, error) { - str, err := w.loadValueFromSysTable(key) - if err != nil { - return nil, errors.Trace(err) - } - if str == "" { - return nil, nil - } - t, err := tikvutil.CompatibleParseGCTime(str) - if err != nil { - return nil, errors.Trace(err) - } - return &t, nil -} - -func (w *GCWorker) saveDuration(key string, d time.Duration) error { - err := w.saveValueToSysTable(key, d.String()) - return errors.Trace(err) -} - -func (w *GCWorker) loadDuration(key string) (*time.Duration, error) { - str, err := w.loadValueFromSysTable(key) - if err != nil { - return nil, errors.Trace(err) - } - if str == "" { - return nil, nil - } - d, err := time.ParseDuration(str) - if err != nil { - return nil, errors.Trace(err) - } - return &d, nil -} - -func (w *GCWorker) loadDurationWithDefault(key string, def time.Duration) (*time.Duration, error) { - d, err := w.loadDuration(key) - if err != nil { - return nil, errors.Trace(err) - } - if d == nil { - err = w.saveDuration(key, def) - if err != nil { - return nil, errors.Trace(err) - } - return &def, nil - } - return d, nil -} - -func (w *GCWorker) loadValueFromSysTable(key string) (string, error) { - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnGC) - se := createSession(w.store) - defer se.Close() - rs, err := se.ExecuteInternal(ctx, `SELECT HIGH_PRIORITY (variable_value) FROM mysql.tidb WHERE variable_name=%? FOR UPDATE`, key) - if rs != nil { - defer terror.Call(rs.Close) - } - if err != nil { - return "", errors.Trace(err) - } - req := rs.NewChunk(nil) - err = rs.Next(ctx, req) - if err != nil { - return "", errors.Trace(err) - } - if req.NumRows() == 0 { - logutil.BgLogger().Debug("load kv", zap.String("category", "gc worker"), - zap.String("key", key)) - return "", nil - } - value := req.GetRow(0).GetString(0) - logutil.BgLogger().Debug("load kv", zap.String("category", "gc worker"), - zap.String("key", key), - zap.String("value", value)) - return value, nil -} - -func (w *GCWorker) saveValueToSysTable(key, value string) error { - const stmt = `INSERT HIGH_PRIORITY INTO mysql.tidb VALUES (%?, %?, %?) 
- ON DUPLICATE KEY - UPDATE variable_value = %?, comment = %?` - se := createSession(w.store) - defer se.Close() - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnGC) - _, err := se.ExecuteInternal(ctx, stmt, - key, value, gcVariableComments[key], - value, gcVariableComments[key]) - logutil.BgLogger().Debug("save kv", zap.String("category", "gc worker"), - zap.String("key", key), - zap.String("value", value), - zap.Error(err)) - return errors.Trace(err) -} - -// GC placement rules when the partitions are removed by the GC worker. -// Placement rules cannot be removed immediately after drop table / truncate table, -// because the tables can be flashed back or recovered. -func (w *GCWorker) doGCPlacementRules(se sessiontypes.Session, _ uint64, dr util.DelRangeTask, gcPlacementRuleCache map[int64]any) (err error) { - // Get the job from the job history - var historyJob *model.Job - failpoint.Inject("mockHistoryJobForGC", func(v failpoint.Value) { - args, err1 := json.Marshal([]any{kv.Key{}, []int64{int64(v.(int))}}) - if err1 != nil { - return - } - historyJob = &model.Job{ - ID: dr.JobID, - Type: model.ActionDropTable, - TableID: int64(v.(int)), - RawArgs: args, - } - }) - if historyJob == nil { - historyJob, err = ddl.GetHistoryJobByID(se, dr.JobID) - if err != nil { - return - } - if historyJob == nil { - return dbterror.ErrDDLJobNotFound.GenWithStackByArgs(dr.JobID) - } - } - - // Notify PD to drop the placement rules of partition-ids and table-id, even if there may be no placement rules. - var physicalTableIDs []int64 - switch historyJob.Type { - case model.ActionDropTable, model.ActionTruncateTable: - var startKey kv.Key - if err = historyJob.DecodeArgs(&startKey, &physicalTableIDs); err != nil { - return - } - physicalTableIDs = append(physicalTableIDs, historyJob.TableID) - case model.ActionDropSchema, model.ActionDropTablePartition, model.ActionTruncateTablePartition, - model.ActionReorganizePartition, model.ActionRemovePartitioning, - model.ActionAlterTablePartitioning: - if err = historyJob.DecodeArgs(&physicalTableIDs); err != nil { - return - } - } - - // Skip table ids that's already successfully handled. - tmp := physicalTableIDs[:0] - for _, id := range physicalTableIDs { - if _, ok := gcPlacementRuleCache[id]; !ok { - tmp = append(tmp, id) - } - } - physicalTableIDs = tmp - - if len(physicalTableIDs) == 0 { - return - } - - if err := infosync.DeleteTiFlashPlacementRules(context.Background(), physicalTableIDs); err != nil { - logutil.BgLogger().Error("delete placement rules failed", zap.Error(err), zap.Int64s("tableIDs", physicalTableIDs)) - } - bundles := make([]*placement.Bundle, 0, len(physicalTableIDs)) - for _, id := range physicalTableIDs { - bundles = append(bundles, placement.NewBundle(id)) - } - err = infosync.PutRuleBundlesWithDefaultRetry(context.TODO(), bundles) - if err != nil { - return - } - - // Cache the table id if its related rule are deleted successfully. 
- for _, id := range physicalTableIDs { - gcPlacementRuleCache[id] = struct{}{} - } - return nil -} - -func (w *GCWorker) doGCLabelRules(dr util.DelRangeTask) (err error) { - // Get the job from the job history - var historyJob *model.Job - failpoint.Inject("mockHistoryJob", func(v failpoint.Value) { - args, err1 := json.Marshal([]any{kv.Key{}, []int64{}, []string{v.(string)}}) - if err1 != nil { - return - } - historyJob = &model.Job{ - ID: dr.JobID, - Type: model.ActionDropTable, - RawArgs: args, - } - }) - if historyJob == nil { - se := createSession(w.store) - historyJob, err = ddl.GetHistoryJobByID(se, dr.JobID) - se.Close() - if err != nil { - return - } - if historyJob == nil { - return dbterror.ErrDDLJobNotFound.GenWithStackByArgs(dr.JobID) - } - } - - if historyJob.Type == model.ActionDropTable { - var ( - startKey kv.Key - physicalTableIDs []int64 - ruleIDs []string - rules map[string]*label.Rule - ) - if err = historyJob.DecodeArgs(&startKey, &physicalTableIDs, &ruleIDs); err != nil { - return - } - - // TODO: Here we need to get rules from PD and filter the rules which is not elegant. We should find a better way. - rules, err = infosync.GetLabelRules(context.TODO(), ruleIDs) - if err != nil { - return - } - - ruleIDs = getGCRules(append(physicalTableIDs, historyJob.TableID), rules) - patch := label.NewRulePatch([]*label.Rule{}, ruleIDs) - err = infosync.UpdateLabelRules(context.TODO(), patch) - } - return -} - -func getGCRules(ids []int64, rules map[string]*label.Rule) []string { - oldRange := make(map[string]struct{}) - for _, id := range ids { - startKey := hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(id))) - endKey := hex.EncodeToString(codec.EncodeBytes(nil, tablecodec.GenTablePrefix(id+1))) - oldRange[startKey+endKey] = struct{}{} - } - - var gcRules []string - for _, rule := range rules { - find := false - for _, d := range rule.Data.([]any) { - if r, ok := d.(map[string]any); ok { - nowRange := fmt.Sprintf("%s%s", r["start_key"], r["end_key"]) - if _, ok := oldRange[nowRange]; ok { - find = true - } - } - } - if find { - gcRules = append(gcRules, rule.ID) - } - } - return gcRules -} - -// RunGCJob sends GC command to KV. It is exported for kv api, do not use it with GCWorker at the same time. -// only use for test -func RunGCJob(ctx context.Context, regionLockResolver tikv.RegionLockResolver, s tikv.Storage, pd pd.Client, safePoint uint64, identifier string, concurrency int) error { - gcWorker := &GCWorker{ - tikvStore: s, - uuid: identifier, - pdClient: pd, - regionLockResolver: regionLockResolver, - } - - if concurrency <= 0 { - return errors.Errorf("[gc worker] gc concurrency should greater than 0, current concurrency: %v", concurrency) - } - - safePoint, err := gcWorker.setGCWorkerServiceSafePoint(ctx, safePoint) - if err != nil { - return errors.Trace(err) - } - - err = gcWorker.resolveLocks(ctx, safePoint, concurrency) - if err != nil { - return errors.Trace(err) - } - - err = gcWorker.saveSafePoint(gcWorker.tikvStore.GetSafePointKV(), safePoint) - if err != nil { - return errors.Trace(err) - } - // Sleep to wait for all other tidb instances update their safepoint cache. - time.Sleep(gcSafePointCacheInterval) - err = gcWorker.doGC(ctx, safePoint, concurrency) - if err != nil { - return errors.Trace(err) - } - return nil -} - -// RunDistributedGCJob notifies TiKVs to do GC. It is exported for kv api, do not use it with GCWorker at the same time. -// This function may not finish immediately because it may take some time to do resolveLocks. 
-// Param concurrency specifies the concurrency of resolveLocks phase. -func RunDistributedGCJob(ctx context.Context, regionLockResolver tikv.RegionLockResolver, s tikv.Storage, pd pd.Client, safePoint uint64, identifier string, concurrency int) error { - gcWorker := &GCWorker{ - tikvStore: s, - uuid: identifier, - pdClient: pd, - regionLockResolver: regionLockResolver, - } - - safePoint, err := gcWorker.setGCWorkerServiceSafePoint(ctx, safePoint) - if err != nil { - return errors.Trace(err) - } - err = gcWorker.resolveLocks(ctx, safePoint, concurrency) - if err != nil { - return errors.Trace(err) - } - - // Save safe point to pd. - err = gcWorker.saveSafePoint(gcWorker.tikvStore.GetSafePointKV(), safePoint) - if err != nil { - return errors.Trace(err) - } - // Sleep to wait for all other tidb instances update their safepoint cache. - time.Sleep(gcSafePointCacheInterval) - - err = gcWorker.uploadSafePointToPD(ctx, safePoint) - if err != nil { - return errors.Trace(err) - } - return nil -} - -// RunResolveLocks resolves all locks before the safePoint. -// It is exported only for test, do not use it in the production environment. -func RunResolveLocks(ctx context.Context, s tikv.Storage, pd pd.Client, safePoint uint64, identifier string, concurrency int) error { - gcWorker := &GCWorker{ - tikvStore: s, - uuid: identifier, - pdClient: pd, - regionLockResolver: tikv.NewRegionLockResolver("test-resolver", s), - } - return gcWorker.resolveLocks(ctx, safePoint, concurrency) -} - -// MockGCWorker is for test. -type MockGCWorker struct { - worker *GCWorker -} - -// NewMockGCWorker creates a MockGCWorker instance ONLY for test. -func NewMockGCWorker(store kv.Storage) (*MockGCWorker, error) { - ver, err := store.CurrentVersion(kv.GlobalTxnScope) - if err != nil { - return nil, errors.Trace(err) - } - hostName, err := os.Hostname() - if err != nil { - hostName = "unknown" - } - worker := &GCWorker{ - uuid: strconv.FormatUint(ver.Ver, 16), - desc: fmt.Sprintf("host:%s, pid:%d, start at %s", hostName, os.Getpid(), time.Now()), - store: store, - tikvStore: store.(tikv.Storage), - gcIsRunning: false, - lastFinish: time.Now(), - done: make(chan error), - pdClient: store.(tikv.Storage).GetRegionCache().PDClient(), - } - return &MockGCWorker{worker: worker}, nil -} - -// DeleteRanges calls deleteRanges internally, just for test. 
-func (w *MockGCWorker) DeleteRanges(ctx context.Context, safePoint uint64) error { - logutil.Logger(ctx).Error("deleteRanges is called") - ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnGC) - return w.worker.deleteRanges(ctx, safePoint, 1) -} diff --git a/pkg/store/mockstore/unistore/binding__failpoint_binding__.go b/pkg/store/mockstore/unistore/binding__failpoint_binding__.go deleted file mode 100644 index 2394b995d2777..0000000000000 --- a/pkg/store/mockstore/unistore/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package unistore - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/store/mockstore/unistore/cophandler/binding__failpoint_binding__.go b/pkg/store/mockstore/unistore/cophandler/binding__failpoint_binding__.go deleted file mode 100644 index bea17e070ec6f..0000000000000 --- a/pkg/store/mockstore/unistore/cophandler/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package cophandler - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/store/mockstore/unistore/cophandler/cop_handler.go b/pkg/store/mockstore/unistore/cophandler/cop_handler.go index b26a3cfb7bff4..7097f07104476 100644 --- a/pkg/store/mockstore/unistore/cophandler/cop_handler.go +++ b/pkg/store/mockstore/unistore/cophandler/cop_handler.go @@ -147,10 +147,10 @@ func ExecutorListsToTree(exec []*tipb.Executor) *tipb.Executor { func handleCopDAGRequest(dbReader *dbreader.DBReader, lockStore *lockstore.MemStore, req *coprocessor.Request) (resp *coprocessor.Response) { startTime := time.Now() resp = &coprocessor.Response{} - if cacheVersion, _err_ := failpoint.Eval(_curpkg_("mockCopCacheInUnistore")); _err_ == nil { + failpoint.Inject("mockCopCacheInUnistore", func(cacheVersion failpoint.Value) { if req.IsCacheEnabled { if uint64(cacheVersion.(int)) == req.CacheIfMatchVersion { - return &coprocessor.Response{IsCacheHit: true, CacheLastVersion: uint64(cacheVersion.(int))} + failpoint.Return(&coprocessor.Response{IsCacheHit: true, CacheLastVersion: uint64(cacheVersion.(int))}) } else { defer func() { resp.CanBeCached = true @@ -165,7 +165,7 @@ func handleCopDAGRequest(dbReader *dbreader.DBReader, lockStore *lockstore.MemSt }() } } - } + }) dagCtx, dagReq, err := buildDAG(dbReader, lockStore, req) if err != nil { resp.OtherError = err.Error() diff --git a/pkg/store/mockstore/unistore/cophandler/cop_handler.go__failpoint_stash__ b/pkg/store/mockstore/unistore/cophandler/cop_handler.go__failpoint_stash__ deleted file mode 100644 index 7097f07104476..0000000000000 --- a/pkg/store/mockstore/unistore/cophandler/cop_handler.go__failpoint_stash__ +++ /dev/null @@ -1,674 +0,0 @@ -// Copyright 2020 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cophandler - -import ( - "bytes" - "context" - "fmt" - "strings" - "sync" - "time" - - "github.com/golang/protobuf/proto" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/coprocessor" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/tidb/pkg/expression" - "github.com/pingcap/tidb/pkg/expression/aggregation" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/charset" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/store/mockstore/unistore/client" - "github.com/pingcap/tidb/pkg/store/mockstore/unistore/lockstore" - "github.com/pingcap/tidb/pkg/store/mockstore/unistore/tikv/dbreader" - "github.com/pingcap/tidb/pkg/store/mockstore/unistore/tikv/kverrors" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/collate" - contextutil "github.com/pingcap/tidb/pkg/util/context" - "github.com/pingcap/tidb/pkg/util/mock" - "github.com/pingcap/tidb/pkg/util/rowcodec" - "github.com/pingcap/tipb/go-tipb" -) - -var globalLocationMap *locationMap = newLocationMap() - -type locationMap struct { - lmap map[string]*time.Location - mu sync.RWMutex -} - -func newLocationMap() *locationMap { - return &locationMap{ - lmap: make(map[string]*time.Location), - } -} - -func (l *locationMap) getLocation(name string) (*time.Location, bool) { - l.mu.RLock() - defer l.mu.RUnlock() - result, ok := l.lmap[name] - return result, ok -} - -func (l *locationMap) setLocation(name string, value *time.Location) { - l.mu.Lock() - defer l.mu.Unlock() - l.lmap[name] = value -} - -// MPPCtx is the mpp execution context -type MPPCtx struct { - RPCClient client.Client - StoreAddr string - TaskHandler *MPPTaskHandler - Ctx context.Context -} - -// HandleCopRequest handles coprocessor request. -func HandleCopRequest(dbReader *dbreader.DBReader, lockStore *lockstore.MemStore, req *coprocessor.Request) *coprocessor.Response { - return HandleCopRequestWithMPPCtx(dbReader, lockStore, req, nil) -} - -// HandleCopRequestWithMPPCtx handles coprocessor request, actually, this is the updated version for -// HandleCopRequest(after mpp test is supported), however, go does not support function overloading, -// I have to rename it to HandleCopRequestWithMPPCtx. 
-func HandleCopRequestWithMPPCtx(dbReader *dbreader.DBReader, lockStore *lockstore.MemStore, req *coprocessor.Request, mppCtx *MPPCtx) *coprocessor.Response { - switch req.Tp { - case kv.ReqTypeDAG: - if mppCtx != nil && mppCtx.TaskHandler != nil { - return HandleMPPDAGReq(dbReader, req, mppCtx) - } - return handleCopDAGRequest(dbReader, lockStore, req) - case kv.ReqTypeAnalyze: - return handleCopAnalyzeRequest(dbReader, req) - case kv.ReqTypeChecksum: - return handleCopChecksumRequest(dbReader, req) - } - return &coprocessor.Response{OtherError: fmt.Sprintf("unsupported request type %d", req.GetTp())} -} - -type dagContext struct { - *evalContext - dbReader *dbreader.DBReader - lockStore *lockstore.MemStore - resolvedLocks []uint64 - dagReq *tipb.DAGRequest - keyRanges []*coprocessor.KeyRange - startTS uint64 -} - -// ExecutorListsToTree converts a list of executors to a tree. -func ExecutorListsToTree(exec []*tipb.Executor) *tipb.Executor { - i := len(exec) - 1 - rootExec := exec[i] - for i--; 0 <= i; i-- { - switch exec[i+1].Tp { - case tipb.ExecType_TypeAggregation: - exec[i+1].Aggregation.Child = exec[i] - case tipb.ExecType_TypeProjection: - exec[i+1].Projection.Child = exec[i] - case tipb.ExecType_TypeTopN: - exec[i+1].TopN.Child = exec[i] - case tipb.ExecType_TypeLimit: - exec[i+1].Limit.Child = exec[i] - case tipb.ExecType_TypeSelection: - exec[i+1].Selection.Child = exec[i] - case tipb.ExecType_TypeStreamAgg: - exec[i+1].Aggregation.Child = exec[i] - default: - panic("unsupported dag executor type") - } - } - return rootExec -} - -// handleCopDAGRequest handles coprocessor DAG request using MPP executors. -func handleCopDAGRequest(dbReader *dbreader.DBReader, lockStore *lockstore.MemStore, req *coprocessor.Request) (resp *coprocessor.Response) { - startTime := time.Now() - resp = &coprocessor.Response{} - failpoint.Inject("mockCopCacheInUnistore", func(cacheVersion failpoint.Value) { - if req.IsCacheEnabled { - if uint64(cacheVersion.(int)) == req.CacheIfMatchVersion { - failpoint.Return(&coprocessor.Response{IsCacheHit: true, CacheLastVersion: uint64(cacheVersion.(int))}) - } else { - defer func() { - resp.CanBeCached = true - resp.CacheLastVersion = uint64(cacheVersion.(int)) - if resp.ExecDetails == nil { - resp.ExecDetails = &kvrpcpb.ExecDetails{TimeDetail: &kvrpcpb.TimeDetail{ProcessWallTimeMs: 500}} - } else if resp.ExecDetails.TimeDetail == nil { - resp.ExecDetails.TimeDetail = &kvrpcpb.TimeDetail{ProcessWallTimeMs: 500} - } else { - resp.ExecDetails.TimeDetail.ProcessWallTimeMs = 500 - } - }() - } - } - }) - dagCtx, dagReq, err := buildDAG(dbReader, lockStore, req) - if err != nil { - resp.OtherError = err.Error() - return resp - } - - exec, chunks, lastRange, counts, ndvs, err := buildAndRunMPPExecutor(dagCtx, dagReq, req.PagingSize) - - sc := dagCtx.sctx.GetSessionVars().StmtCtx - if err != nil { - errMsg := err.Error() - if strings.HasPrefix(errMsg, ErrExecutorNotSupportedMsg) { - resp.OtherError = err.Error() - return resp - } - return genRespWithMPPExec(nil, lastRange, nil, nil, exec, dagReq, err, sc.GetWarnings(), time.Since(startTime)) - } - return genRespWithMPPExec(chunks, lastRange, counts, ndvs, exec, dagReq, err, sc.GetWarnings(), time.Since(startTime)) -} - -func buildAndRunMPPExecutor(dagCtx *dagContext, dagReq *tipb.DAGRequest, pagingSize uint64) (mppExec, []tipb.Chunk, *coprocessor.KeyRange, []int64, []int64, error) { - rootExec := dagReq.RootExecutor - if rootExec == nil { - rootExec = ExecutorListsToTree(dagReq.Executors) - } - - var counts, ndvs []int64 
- - if dagReq.GetCollectRangeCounts() { - counts = make([]int64, len(dagCtx.keyRanges)) - ndvs = make([]int64, len(dagCtx.keyRanges)) - } - builder := &mppExecBuilder{ - sctx: dagCtx.sctx, - dbReader: dagCtx.dbReader, - dagReq: dagReq, - dagCtx: dagCtx, - mppCtx: nil, - counts: counts, - ndvs: ndvs, - } - var lastRange *coprocessor.KeyRange - if pagingSize > 0 { - lastRange = &coprocessor.KeyRange{} - builder.paging = lastRange - builder.pagingSize = pagingSize - } - exec, err := builder.buildMPPExecutor(rootExec) - if err != nil { - return nil, nil, nil, nil, nil, err - } - chunks, err := mppExecute(exec, dagCtx, dagReq, pagingSize) - if lastRange != nil && len(lastRange.Start) == 0 && len(lastRange.End) == 0 { - // When should this happen, something is wrong? - lastRange = nil - } - return exec, chunks, lastRange, counts, ndvs, err -} - -func mppExecute(exec mppExec, dagCtx *dagContext, dagReq *tipb.DAGRequest, pagingSize uint64) (chunks []tipb.Chunk, err error) { - err = exec.open() - defer func() { - err := exec.stop() - if err != nil { - panic(err) - } - }() - if err != nil { - return - } - - var totalRows uint64 - var chk *chunk.Chunk - fields := exec.getFieldTypes() - for { - chk, err = exec.next() - if err != nil || chk == nil || chk.NumRows() == 0 { - return - } - - switch dagReq.EncodeType { - case tipb.EncodeType_TypeDefault: - chunks, err = useDefaultEncoding(chk, dagCtx, dagReq, fields, chunks) - case tipb.EncodeType_TypeChunk: - chunks = useChunkEncoding(chk, dagReq, fields, chunks) - if pagingSize > 0 { - totalRows += uint64(chk.NumRows()) - if totalRows > pagingSize { - return - } - } - default: - err = fmt.Errorf("unsupported DAG request encode type %s", dagReq.EncodeType) - } - if err != nil { - return - } - } -} - -func useDefaultEncoding(chk *chunk.Chunk, dagCtx *dagContext, dagReq *tipb.DAGRequest, - fields []*types.FieldType, chunks []tipb.Chunk) ([]tipb.Chunk, error) { - var buf []byte - var datums []types.Datum - var err error - numRows := chk.NumRows() - sc := dagCtx.sctx.GetSessionVars().StmtCtx - errCtx := sc.ErrCtx() - for i := 0; i < numRows; i++ { - datums = datums[:0] - if dagReq.OutputOffsets != nil { - for _, j := range dagReq.OutputOffsets { - datums = append(datums, chk.GetRow(i).GetDatum(int(j), fields[j])) - } - } else { - for j, ft := range fields { - datums = append(datums, chk.GetRow(i).GetDatum(j, ft)) - } - } - buf, err = codec.EncodeValue(sc.TimeZone(), buf[:0], datums...) 
- err = errCtx.HandleError(err) - if err != nil { - return nil, errors.Trace(err) - } - chunks = appendRow(chunks, buf, i) - } - return chunks, nil -} - -func useChunkEncoding(chk *chunk.Chunk, dagReq *tipb.DAGRequest, fields []*types.FieldType, chunks []tipb.Chunk) []tipb.Chunk { - if dagReq.OutputOffsets != nil { - offsets := make([]int, len(dagReq.OutputOffsets)) - newFields := make([]*types.FieldType, len(dagReq.OutputOffsets)) - for i := 0; i < len(dagReq.OutputOffsets); i++ { - offset := dagReq.OutputOffsets[i] - offsets[i] = int(offset) - newFields[i] = fields[offset] - } - chk = chk.Prune(offsets) - fields = newFields - } - - c := chunk.NewCodec(fields) - buffer := c.Encode(chk) - chunks = append(chunks, tipb.Chunk{ - RowsData: buffer, - }) - return chunks -} - -func buildDAG(reader *dbreader.DBReader, lockStore *lockstore.MemStore, req *coprocessor.Request) (*dagContext, *tipb.DAGRequest, error) { - if len(req.Ranges) == 0 { - return nil, nil, errors.New("request range is null") - } - if req.GetTp() != kv.ReqTypeDAG { - return nil, nil, errors.Errorf("unsupported request type %d", req.GetTp()) - } - - dagReq := new(tipb.DAGRequest) - err := proto.Unmarshal(req.Data, dagReq) - if err != nil { - return nil, nil, errors.Trace(err) - } - var tz *time.Location - switch dagReq.TimeZoneName { - case "": - tz = time.FixedZone("UTC", int(dagReq.TimeZoneOffset)) - case "System": - tz = time.Local - default: - var ok bool - tz, ok = globalLocationMap.getLocation(dagReq.TimeZoneName) - if !ok { - tz, err = time.LoadLocation(dagReq.TimeZoneName) - if err != nil { - return nil, nil, errors.Trace(err) - } - globalLocationMap.setLocation(dagReq.TimeZoneName, tz) - } - } - sctx := flagsAndTzToSessionContext(dagReq.Flags, tz) - if dagReq.DivPrecisionIncrement != nil { - sctx.GetSessionVars().DivPrecisionIncrement = int(*dagReq.DivPrecisionIncrement) - } else { - sctx.GetSessionVars().DivPrecisionIncrement = variable.DefDivPrecisionIncrement - } - ctx := &dagContext{ - evalContext: &evalContext{sctx: sctx}, - dbReader: reader, - lockStore: lockStore, - dagReq: dagReq, - keyRanges: req.Ranges, - startTS: req.StartTs, - resolvedLocks: req.Context.ResolvedLocks, - } - return ctx, dagReq, err -} - -func getAggInfo(ctx *dagContext, pbAgg *tipb.Aggregation) ([]aggregation.Aggregation, []expression.Expression, error) { - length := len(pbAgg.AggFunc) - aggs := make([]aggregation.Aggregation, 0, length) - var err error - for _, expr := range pbAgg.AggFunc { - var aggExpr aggregation.Aggregation - aggExpr, _, err = aggregation.NewDistAggFunc(expr, ctx.fieldTps, ctx.sctx.GetExprCtx()) - if err != nil { - return nil, nil, errors.Trace(err) - } - aggs = append(aggs, aggExpr) - } - groupBys, err := convertToExprs(ctx.sctx, ctx.fieldTps, pbAgg.GetGroupBy()) - if err != nil { - return nil, nil, errors.Trace(err) - } - - return aggs, groupBys, nil -} - -func getTopNInfo(ctx *evalContext, topN *tipb.TopN) (heap *topNHeap, conds []expression.Expression, err error) { - pbConds := make([]*tipb.Expr, len(topN.OrderBy)) - for i, item := range topN.OrderBy { - pbConds[i] = item.Expr - } - heap = &topNHeap{ - totalCount: int(topN.Limit), - topNSorter: topNSorter{ - orderByItems: topN.OrderBy, - sc: ctx.sctx.GetSessionVars().StmtCtx, - }, - } - if conds, err = convertToExprs(ctx.sctx, ctx.fieldTps, pbConds); err != nil { - return nil, nil, errors.Trace(err) - } - - return heap, conds, nil -} - -type evalContext struct { - columnInfos []*tipb.ColumnInfo - fieldTps []*types.FieldType - primaryCols []int64 - sctx 
sessionctx.Context -} - -func (e *evalContext) setColumnInfo(cols []*tipb.ColumnInfo) { - e.columnInfos = make([]*tipb.ColumnInfo, len(cols)) - copy(e.columnInfos, cols) - - e.fieldTps = make([]*types.FieldType, 0, len(e.columnInfos)) - for _, col := range e.columnInfos { - ft := fieldTypeFromPBColumn(col) - e.fieldTps = append(e.fieldTps, ft) - } -} - -func newRowDecoder(columnInfos []*tipb.ColumnInfo, fieldTps []*types.FieldType, primaryCols []int64, timeZone *time.Location) (*rowcodec.ChunkDecoder, error) { - var ( - pkCols []int64 - cols = make([]rowcodec.ColInfo, 0, len(columnInfos)) - ) - for i := range columnInfos { - info := columnInfos[i] - if info.ColumnId == model.ExtraPhysTblID { - // Skip since it needs to be filled in from the key - continue - } - ft := fieldTps[i] - col := rowcodec.ColInfo{ - ID: info.ColumnId, - Ft: ft, - IsPKHandle: info.PkHandle, - } - cols = append(cols, col) - if info.PkHandle { - pkCols = append(pkCols, info.ColumnId) - } - } - if len(pkCols) == 0 { - if primaryCols != nil { - pkCols = primaryCols - } else { - pkCols = []int64{-1} - } - } - def := func(i int, chk *chunk.Chunk) error { - info := columnInfos[i] - if info.PkHandle || len(info.DefaultVal) == 0 { - chk.AppendNull(i) - return nil - } - decoder := codec.NewDecoder(chk, timeZone) - _, err := decoder.DecodeOne(info.DefaultVal, i, fieldTps[i]) - if err != nil { - return err - } - return nil - } - return rowcodec.NewChunkDecoder(cols, pkCols, def, timeZone), nil -} - -// flagsAndTzToSessionContext creates a sessionctx.Context from a `tipb.SelectRequest.Flags`. -func flagsAndTzToSessionContext(flags uint64, tz *time.Location) sessionctx.Context { - sc := stmtctx.NewStmtCtx() - sc.InitFromPBFlagAndTz(flags, tz) - sctx := mock.NewContext() - sctx.GetSessionVars().StmtCtx = sc - sctx.GetSessionVars().TimeZone = tz - return sctx -} - -// ErrLocked is returned when trying to Read/Write on a locked key. Client should -// backoff or cleanup the lock then retry. -type ErrLocked struct { - Key []byte - Primary []byte - StartTS uint64 - TTL uint64 - LockType uint8 -} - -// BuildLockErr generates ErrKeyLocked objects -func BuildLockErr(key []byte, primaryKey []byte, startTS uint64, TTL uint64, lockType uint8) *ErrLocked { - errLocked := &ErrLocked{ - Key: key, - Primary: primaryKey, - StartTS: startTS, - TTL: TTL, - LockType: lockType, - } - return errLocked -} - -// Error formats the lock to a string. 
-func (e *ErrLocked) Error() string { - return fmt.Sprintf("key is locked, key: %q, Type: %v, primary: %q, startTS: %v", e.Key, e.LockType, e.Primary, e.StartTS) -} - -func genRespWithMPPExec(chunks []tipb.Chunk, lastRange *coprocessor.KeyRange, counts, ndvs []int64, exec mppExec, dagReq *tipb.DAGRequest, err error, warnings []contextutil.SQLWarn, dur time.Duration) *coprocessor.Response { - resp := &coprocessor.Response{ - Range: lastRange, - } - selResp := &tipb.SelectResponse{ - Error: toPBError(err), - Chunks: chunks, - OutputCounts: counts, - Ndvs: ndvs, - EncodeType: dagReq.EncodeType, - } - executors := dagReq.Executors - if dagReq.CollectExecutionSummaries != nil && *dagReq.CollectExecutionSummaries { - // for simplicity, we assume all executors to be spending the same amount of time as the request - timeProcessed := uint64(dur / time.Nanosecond) - execSummary := make([]*tipb.ExecutorExecutionSummary, len(executors)) - e := exec - for i := len(executors) - 1; 0 <= i; i-- { - execSummary[i] = e.buildSummary() - execSummary[i].TimeProcessedNs = &timeProcessed - if i != 0 { - e = exec.child() - } - } - selResp.ExecutionSummaries = execSummary - } - if len(warnings) > 0 { - selResp.Warnings = make([]*tipb.Error, 0, len(warnings)) - for i := range warnings { - selResp.Warnings = append(selResp.Warnings, toPBError(warnings[i].Err)) - } - } - if locked, ok := errors.Cause(err).(*ErrLocked); ok { - resp.Locked = &kvrpcpb.LockInfo{ - Key: locked.Key, - PrimaryLock: locked.Primary, - LockVersion: locked.StartTS, - LockTtl: locked.TTL, - } - } - resp.ExecDetails = &kvrpcpb.ExecDetails{ - TimeDetail: &kvrpcpb.TimeDetail{ProcessWallTimeMs: uint64(dur / time.Millisecond)}, - } - resp.ExecDetailsV2 = &kvrpcpb.ExecDetailsV2{ - TimeDetail: resp.ExecDetails.TimeDetail, - } - data, mErr := proto.Marshal(selResp) - if mErr != nil { - resp.OtherError = mErr.Error() - return resp - } - resp.Data = data - if err != nil { - if conflictErr, ok := errors.Cause(err).(*kverrors.ErrConflict); ok { - resp.OtherError = conflictErr.Error() - } - } - return resp -} - -func toPBError(err error) *tipb.Error { - if err == nil { - return nil - } - perr := new(tipb.Error) - e := errors.Cause(err) - switch y := e.(type) { - case *terror.Error: - tmp := terror.ToSQLError(y) - perr.Code = int32(tmp.Code) - perr.Msg = tmp.Message - case *mysql.SQLError: - perr.Code = int32(y.Code) - perr.Msg = y.Message - default: - perr.Code = int32(1) - perr.Msg = err.Error() - } - return perr -} - -// extractKVRanges extracts kv.KeyRanges slice from a SelectRequest. 
-func extractKVRanges(startKey, endKey []byte, keyRanges []*coprocessor.KeyRange, descScan bool) (kvRanges []kv.KeyRange, err error) { - kvRanges = make([]kv.KeyRange, 0, len(keyRanges)) - for _, kran := range keyRanges { - if bytes.Compare(kran.GetStart(), kran.GetEnd()) >= 0 { - err = errors.Errorf("invalid range, start should be smaller than end: %v %v", kran.GetStart(), kran.GetEnd()) - return - } - - upperKey := kran.GetEnd() - if bytes.Compare(upperKey, startKey) <= 0 { - continue - } - lowerKey := kran.GetStart() - if len(endKey) != 0 && bytes.Compare(lowerKey, endKey) >= 0 { - break - } - r := kv.KeyRange{ - StartKey: kv.Key(maxStartKey(lowerKey, startKey)), - EndKey: kv.Key(minEndKey(upperKey, endKey)), - } - kvRanges = append(kvRanges, r) - } - if descScan { - reverseKVRanges(kvRanges) - } - return -} - -func reverseKVRanges(kvRanges []kv.KeyRange) { - for i := 0; i < len(kvRanges)/2; i++ { - j := len(kvRanges) - i - 1 - kvRanges[i], kvRanges[j] = kvRanges[j], kvRanges[i] - } -} - -func maxStartKey(rangeStartKey kv.Key, regionStartKey []byte) []byte { - if bytes.Compare([]byte(rangeStartKey), regionStartKey) > 0 { - return []byte(rangeStartKey) - } - return regionStartKey -} - -func minEndKey(rangeEndKey kv.Key, regionEndKey []byte) []byte { - if len(regionEndKey) == 0 || bytes.Compare([]byte(rangeEndKey), regionEndKey) < 0 { - return []byte(rangeEndKey) - } - return regionEndKey -} - -const rowsPerChunk = 64 - -func appendRow(chunks []tipb.Chunk, data []byte, rowCnt int) []tipb.Chunk { - if rowCnt%rowsPerChunk == 0 { - chunks = append(chunks, tipb.Chunk{}) - } - cur := &chunks[len(chunks)-1] - cur.RowsData = append(cur.RowsData, data...) - return chunks -} - -// fieldTypeFromPBColumn creates a types.FieldType from tipb.ColumnInfo. -func fieldTypeFromPBColumn(col *tipb.ColumnInfo) *types.FieldType { - charsetStr, collationStr, _ := charset.GetCharsetInfoByID(int(collate.RestoreCollationIDIfNeeded(col.GetCollation()))) - ft := &types.FieldType{} - ft.SetType(byte(col.GetTp())) - ft.SetFlag(uint(col.GetFlag())) - ft.SetFlen(int(col.GetColumnLen())) - ft.SetDecimal(int(col.GetDecimal())) - ft.SetElems(col.Elems) - ft.SetCharset(charsetStr) - ft.SetCollate(collationStr) - return ft -} - -// handleCopChecksumRequest handles coprocessor check sum request. -func handleCopChecksumRequest(dbReader *dbreader.DBReader, req *coprocessor.Request) *coprocessor.Response { - resp := &tipb.ChecksumResponse{ - Checksum: 1, - TotalKvs: 1, - TotalBytes: 1, - } - data, err := resp.Marshal() - if err != nil { - return &coprocessor.Response{OtherError: fmt.Sprintf("marshal checksum response error: %v", err)} - } - return &coprocessor.Response{Data: data} -} diff --git a/pkg/store/mockstore/unistore/rpc.go b/pkg/store/mockstore/unistore/rpc.go index d22ab62ccc6f8..b255138dff63b 100644 --- a/pkg/store/mockstore/unistore/rpc.go +++ b/pkg/store/mockstore/unistore/rpc.go @@ -62,41 +62,41 @@ var UnistoreRPCClientSendHook atomic.Pointer[func(*tikvrpc.Request)] // SendRequest sends a request to mock cluster. 
func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error) { - if val, _err_ := failpoint.Eval(_curpkg_("rpcServerBusy")); _err_ == nil { + failpoint.Inject("rpcServerBusy", func(val failpoint.Value) { if val.(bool) { - return tikvrpc.GenRegionErrorResp(req, &errorpb.Error{ServerIsBusy: &errorpb.ServerIsBusy{}}) + failpoint.Return(tikvrpc.GenRegionErrorResp(req, &errorpb.Error{ServerIsBusy: &errorpb.ServerIsBusy{}})) } - } - if val, _err_ := failpoint.Eval(_curpkg_("epochNotMatch")); _err_ == nil { + }) + failpoint.Inject("epochNotMatch", func(val failpoint.Value) { if val.(bool) { - return tikvrpc.GenRegionErrorResp(req, &errorpb.Error{EpochNotMatch: &errorpb.EpochNotMatch{}}) + failpoint.Return(tikvrpc.GenRegionErrorResp(req, &errorpb.Error{EpochNotMatch: &errorpb.EpochNotMatch{}})) } - } + }) - if val, _err_ := failpoint.Eval(_curpkg_("unistoreRPCClientSendHook")); _err_ == nil { + failpoint.Inject("unistoreRPCClientSendHook", func(val failpoint.Value) { if fn := UnistoreRPCClientSendHook.Load(); val.(bool) && fn != nil { (*fn)(req) } - } + }) - if val, _err_ := failpoint.Eval(_curpkg_("rpcTiKVAllowedOnAlmostFull")); _err_ == nil { + failpoint.Inject("rpcTiKVAllowedOnAlmostFull", func(val failpoint.Value) { if val.(bool) { if req.Type == tikvrpc.CmdPrewrite || req.Type == tikvrpc.CmdCommit { if req.Context.DiskFullOpt != kvrpcpb.DiskFullOpt_AllowedOnAlmostFull { - return tikvrpc.GenRegionErrorResp(req, &errorpb.Error{DiskFull: &errorpb.DiskFull{StoreId: []uint64{1}, Reason: "disk full"}}) + failpoint.Return(tikvrpc.GenRegionErrorResp(req, &errorpb.Error{DiskFull: &errorpb.DiskFull{StoreId: []uint64{1}, Reason: "disk full"}})) } } } - } - if val, _err_ := failpoint.Eval(_curpkg_("unistoreRPCDeadlineExceeded")); _err_ == nil { + }) + failpoint.Inject("unistoreRPCDeadlineExceeded", func(val failpoint.Value) { if val.(bool) && timeout < time.Second { - return tikvrpc.GenRegionErrorResp(req, &errorpb.Error{Message: "Deadline is exceeded"}) + failpoint.Return(tikvrpc.GenRegionErrorResp(req, &errorpb.Error{Message: "Deadline is exceeded"})) } - } - if val, _err_ := failpoint.Eval(_curpkg_("unistoreRPCSlowByInjestSleep")); _err_ == nil { + }) + failpoint.Inject("unistoreRPCSlowByInjestSleep", func(val failpoint.Value) { time.Sleep(time.Duration(val.(int) * int(time.Millisecond))) - return tikvrpc.GenRegionErrorResp(req, &errorpb.Error{Message: "Deadline is exceeded"}) - } + failpoint.Return(tikvrpc.GenRegionErrorResp(req, &errorpb.Error{Message: "Deadline is exceeded"})) + }) select { case <-ctx.Done(): @@ -127,10 +127,10 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R resp.Resp, err = c.usSvr.KvGet(ctx, req.Get()) case tikvrpc.CmdScan: kvScanReq := req.Scan() - if val, _err_ := failpoint.Eval(_curpkg_("rpcScanResult")); _err_ == nil { + failpoint.Inject("rpcScanResult", func(val failpoint.Value) { switch val.(string) { case "keyError": - return &tikvrpc.Response{ + failpoint.Return(&tikvrpc.Response{ Resp: &kvrpcpb.ScanResponse{Error: &kvrpcpb.KeyError{ Locked: &kvrpcpb.LockInfo{ PrimaryLock: kvScanReq.StartKey, @@ -141,38 +141,38 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R LockType: kvrpcpb.Op_Put, }, }}, - }, nil + }, nil) } - } + }) resp.Resp, err = c.usSvr.KvScan(ctx, kvScanReq) case tikvrpc.CmdPrewrite: - if val, _err_ := failpoint.Eval(_curpkg_("rpcPrewriteResult")); _err_ == nil { + failpoint.Inject("rpcPrewriteResult", func(val 
failpoint.Value) { if val != nil { switch val.(string) { case "timeout": - return nil, errors.New("timeout") + failpoint.Return(nil, errors.New("timeout")) case "notLeader": - return &tikvrpc.Response{ + failpoint.Return(&tikvrpc.Response{ Resp: &kvrpcpb.PrewriteResponse{RegionError: &errorpb.Error{NotLeader: &errorpb.NotLeader{}}}, - }, nil + }, nil) case "writeConflict": - return &tikvrpc.Response{ + failpoint.Return(&tikvrpc.Response{ Resp: &kvrpcpb.PrewriteResponse{Errors: []*kvrpcpb.KeyError{{Conflict: &kvrpcpb.WriteConflict{}}}}, - }, nil + }, nil) } } - } + }) r := req.Prewrite() c.cluster.handleDelay(r.StartVersion, r.Context.RegionId) resp.Resp, err = c.usSvr.KvPrewrite(ctx, r) - if val, _err_ := failpoint.Eval(_curpkg_("rpcPrewriteTimeout")); _err_ == nil { + failpoint.Inject("rpcPrewriteTimeout", func(val failpoint.Value) { if val.(bool) { - return nil, undeterminedErr + failpoint.Return(nil, undeterminedErr) } - } + }) case tikvrpc.CmdPessimisticLock: r := req.PessimisticLock() c.cluster.handleDelay(r.StartVersion, r.Context.RegionId) @@ -180,28 +180,28 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R case tikvrpc.CmdPessimisticRollback: resp.Resp, err = c.usSvr.KVPessimisticRollback(ctx, req.PessimisticRollback()) case tikvrpc.CmdCommit: - if val, _err_ := failpoint.Eval(_curpkg_("rpcCommitResult")); _err_ == nil { + failpoint.Inject("rpcCommitResult", func(val failpoint.Value) { switch val.(string) { case "timeout": - return nil, errors.New("timeout") + failpoint.Return(nil, errors.New("timeout")) case "notLeader": - return &tikvrpc.Response{ + failpoint.Return(&tikvrpc.Response{ Resp: &kvrpcpb.CommitResponse{RegionError: &errorpb.Error{NotLeader: &errorpb.NotLeader{}}}, - }, nil + }, nil) case "keyError": - return &tikvrpc.Response{ + failpoint.Return(&tikvrpc.Response{ Resp: &kvrpcpb.CommitResponse{Error: &kvrpcpb.KeyError{}}, - }, nil + }, nil) } - } + }) resp.Resp, err = c.usSvr.KvCommit(ctx, req.Commit()) - if val, _err_ := failpoint.Eval(_curpkg_("rpcCommitTimeout")); _err_ == nil { + failpoint.Inject("rpcCommitTimeout", func(val failpoint.Value) { if val.(bool) { - return nil, undeterminedErr + failpoint.Return(nil, undeterminedErr) } - } + }) case tikvrpc.CmdCleanup: resp.Resp, err = c.usSvr.KvCleanup(ctx, req.Cleanup()) case tikvrpc.CmdCheckTxnStatus: @@ -212,10 +212,10 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R resp.Resp, err = c.usSvr.KvTxnHeartBeat(ctx, req.TxnHeartBeat()) case tikvrpc.CmdBatchGet: batchGetReq := req.BatchGet() - if val, _err_ := failpoint.Eval(_curpkg_("rpcBatchGetResult")); _err_ == nil { + failpoint.Inject("rpcBatchGetResult", func(val failpoint.Value) { switch val.(string) { case "keyError": - return &tikvrpc.Response{ + failpoint.Return(&tikvrpc.Response{ Resp: &kvrpcpb.BatchGetResponse{Error: &kvrpcpb.KeyError{ Locked: &kvrpcpb.LockInfo{ PrimaryLock: batchGetReq.Keys[0], @@ -226,9 +226,9 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R LockType: kvrpcpb.Op_Put, }, }}, - }, nil + }, nil) } - } + }) resp.Resp, err = c.usSvr.KvBatchGet(ctx, batchGetReq) case tikvrpc.CmdBatchRollback: @@ -262,41 +262,41 @@ func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.R case tikvrpc.CmdCopStream: resp.Resp, err = c.handleCopStream(ctx, req.Cop()) case tikvrpc.CmdBatchCop: - if value, _err_ := failpoint.Eval(_curpkg_("BatchCopCancelled")); _err_ == nil { + failpoint.Inject("BatchCopCancelled", func(value failpoint.Value) 
{ if value.(bool) { - return nil, context.Canceled + failpoint.Return(nil, context.Canceled) } - } + }) - if value, _err_ := failpoint.Eval(_curpkg_("BatchCopRpcErr" + addr)); _err_ == nil { + failpoint.Inject("BatchCopRpcErr"+addr, func(value failpoint.Value) { if value.(string) == addr { - return nil, errors.New("rpc error") + failpoint.Return(nil, errors.New("rpc error")) } - } + }) resp.Resp, err = c.handleBatchCop(ctx, req.BatchCop(), timeout) case tikvrpc.CmdMPPConn: - if val, _err_ := failpoint.Eval(_curpkg_("mppConnTimeout")); _err_ == nil { + failpoint.Inject("mppConnTimeout", func(val failpoint.Value) { if val.(bool) { - return nil, errors.New("rpc error") + failpoint.Return(nil, errors.New("rpc error")) } - } - if val, _err_ := failpoint.Eval(_curpkg_("MppVersionError")); _err_ == nil { + }) + failpoint.Inject("MppVersionError", func(val failpoint.Value) { if v := int64(val.(int)); v > req.EstablishMPPConn().GetReceiverMeta().GetMppVersion() || v > req.EstablishMPPConn().GetSenderMeta().GetMppVersion() { - return nil, context.Canceled + failpoint.Return(nil, context.Canceled) } - } + }) resp.Resp, err = c.handleEstablishMPPConnection(ctx, req.EstablishMPPConn(), timeout, storeID) case tikvrpc.CmdMPPTask: - if val, _err_ := failpoint.Eval(_curpkg_("mppDispatchTimeout")); _err_ == nil { + failpoint.Inject("mppDispatchTimeout", func(val failpoint.Value) { if val.(bool) { - return nil, errors.New("rpc error") + failpoint.Return(nil, errors.New("rpc error")) } - } - if val, _err_ := failpoint.Eval(_curpkg_("MppVersionError")); _err_ == nil { + }) + failpoint.Inject("MppVersionError", func(val failpoint.Value) { if v := int64(val.(int)); v > req.DispatchMPPTask().GetMeta().GetMppVersion() { - return nil, context.Canceled + failpoint.Return(nil, context.Canceled) } - } + }) resp.Resp, err = c.handleDispatchMPPTask(ctx, req.DispatchMPPTask(), storeID) case tikvrpc.CmdMPPCancel: case tikvrpc.CmdMvccGetByKey: @@ -367,11 +367,11 @@ func (c *RPCClient) handleEstablishMPPConnection(ctx context.Context, r *mpp.Est if err != nil { return nil, err } - if val, _err_ := failpoint.Eval(_curpkg_("establishMppConnectionErr")); _err_ == nil { + failpoint.Inject("establishMppConnectionErr", func(val failpoint.Value) { if val.(bool) { - return nil, errors.New("rpc error") + failpoint.Return(nil, errors.New("rpc error")) } - } + }) var mockClient = mockMPPConnectionClient{mppResponses: mockServer.mppResponses, idx: 0, ctx: ctx, targetTask: r.ReceiverMeta} streamResp := &tikvrpc.MPPStreamResponse{Tikv_EstablishMPPConnectionClient: &mockClient} _, cancel := context.WithCancel(ctx) @@ -510,11 +510,11 @@ func (mock *mockBatchCopClient) Recv() (*coprocessor.BatchResponse, error) { } return ret, err } - if val, _err_ := failpoint.Eval(_curpkg_("batchCopRecvTimeout")); _err_ == nil { + failpoint.Inject("batchCopRecvTimeout", func(val failpoint.Value) { if val.(bool) { - return nil, context.Canceled + failpoint.Return(nil, context.Canceled) } - } + }) return nil, io.EOF } @@ -532,23 +532,23 @@ func (mock *mockMPPConnectionClient) Recv() (*mpp.MPPDataPacket, error) { mock.idx++ return ret, nil } - if val, _err_ := failpoint.Eval(_curpkg_("mppRecvTimeout")); _err_ == nil { + failpoint.Inject("mppRecvTimeout", func(val failpoint.Value) { if int64(val.(int)) == mock.targetTask.TaskId { - return nil, context.Canceled + failpoint.Return(nil, context.Canceled) } - } - if val, _err_ := failpoint.Eval(_curpkg_("mppRecvHang")); _err_ == nil { + }) + failpoint.Inject("mppRecvHang", func(val failpoint.Value) { for 
val.(bool) { select { case <-mock.ctx.Done(): { - return nil, context.Canceled + failpoint.Return(nil, context.Canceled) } default: time.Sleep(1 * time.Second) } } - } + }) return nil, io.EOF } diff --git a/pkg/store/mockstore/unistore/rpc.go__failpoint_stash__ b/pkg/store/mockstore/unistore/rpc.go__failpoint_stash__ deleted file mode 100644 index b255138dff63b..0000000000000 --- a/pkg/store/mockstore/unistore/rpc.go__failpoint_stash__ +++ /dev/null @@ -1,582 +0,0 @@ -// Copyright 2020 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package unistore - -import ( - "context" - "io" - "math" - "os" - "strconv" - "sync/atomic" - "time" - - "github.com/golang/protobuf/proto" - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/kvproto/pkg/coprocessor" - "github.com/pingcap/kvproto/pkg/debugpb" - "github.com/pingcap/kvproto/pkg/errorpb" - "github.com/pingcap/kvproto/pkg/kvrpcpb" - "github.com/pingcap/kvproto/pkg/metapb" - "github.com/pingcap/kvproto/pkg/mpp" - "github.com/pingcap/tidb/pkg/parser/terror" - us "github.com/pingcap/tidb/pkg/store/mockstore/unistore/tikv" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/tikv/client-go/v2/tikv" - "github.com/tikv/client-go/v2/tikvrpc" - "google.golang.org/grpc/metadata" -) - -// For gofail injection. -var undeterminedErr = terror.ErrResultUndetermined - -// RPCClient sends kv RPC calls to mock cluster. RPCClient mocks the behavior of -// a rpc client at tikv's side. -type RPCClient struct { - usSvr *us.Server - cluster *Cluster - path string - rawHandler *rawHandler - persistent bool - closed int32 -} - -// CheckResourceTagForTopSQLInGoTest is used to identify whether check resource tag for TopSQL. -var CheckResourceTagForTopSQLInGoTest bool - -// UnistoreRPCClientSendHook exports for test. -var UnistoreRPCClientSendHook atomic.Pointer[func(*tikvrpc.Request)] - -// SendRequest sends a request to mock cluster. 
-func (c *RPCClient) SendRequest(ctx context.Context, addr string, req *tikvrpc.Request, timeout time.Duration) (*tikvrpc.Response, error) { - failpoint.Inject("rpcServerBusy", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(tikvrpc.GenRegionErrorResp(req, &errorpb.Error{ServerIsBusy: &errorpb.ServerIsBusy{}})) - } - }) - failpoint.Inject("epochNotMatch", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(tikvrpc.GenRegionErrorResp(req, &errorpb.Error{EpochNotMatch: &errorpb.EpochNotMatch{}})) - } - }) - - failpoint.Inject("unistoreRPCClientSendHook", func(val failpoint.Value) { - if fn := UnistoreRPCClientSendHook.Load(); val.(bool) && fn != nil { - (*fn)(req) - } - }) - - failpoint.Inject("rpcTiKVAllowedOnAlmostFull", func(val failpoint.Value) { - if val.(bool) { - if req.Type == tikvrpc.CmdPrewrite || req.Type == tikvrpc.CmdCommit { - if req.Context.DiskFullOpt != kvrpcpb.DiskFullOpt_AllowedOnAlmostFull { - failpoint.Return(tikvrpc.GenRegionErrorResp(req, &errorpb.Error{DiskFull: &errorpb.DiskFull{StoreId: []uint64{1}, Reason: "disk full"}})) - } - } - } - }) - failpoint.Inject("unistoreRPCDeadlineExceeded", func(val failpoint.Value) { - if val.(bool) && timeout < time.Second { - failpoint.Return(tikvrpc.GenRegionErrorResp(req, &errorpb.Error{Message: "Deadline is exceeded"})) - } - }) - failpoint.Inject("unistoreRPCSlowByInjestSleep", func(val failpoint.Value) { - time.Sleep(time.Duration(val.(int) * int(time.Millisecond))) - failpoint.Return(tikvrpc.GenRegionErrorResp(req, &errorpb.Error{Message: "Deadline is exceeded"})) - }) - - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - if atomic.LoadInt32(&c.closed) != 0 { - // Return `context.Canceled` can break Backoff. - return nil, context.Canceled - } - - storeID, err := c.usSvr.GetStoreIDByAddr(addr) - if err != nil { - return nil, err - } - - if CheckResourceTagForTopSQLInGoTest { - err = checkResourceTagForTopSQL(req) - if err != nil { - return nil, err - } - } - - resp := &tikvrpc.Response{} - switch req.Type { - case tikvrpc.CmdGet: - resp.Resp, err = c.usSvr.KvGet(ctx, req.Get()) - case tikvrpc.CmdScan: - kvScanReq := req.Scan() - failpoint.Inject("rpcScanResult", func(val failpoint.Value) { - switch val.(string) { - case "keyError": - failpoint.Return(&tikvrpc.Response{ - Resp: &kvrpcpb.ScanResponse{Error: &kvrpcpb.KeyError{ - Locked: &kvrpcpb.LockInfo{ - PrimaryLock: kvScanReq.StartKey, - LockVersion: kvScanReq.Version - 1, - Key: kvScanReq.StartKey, - LockTtl: 50, - TxnSize: 1, - LockType: kvrpcpb.Op_Put, - }, - }}, - }, nil) - } - }) - - resp.Resp, err = c.usSvr.KvScan(ctx, kvScanReq) - case tikvrpc.CmdPrewrite: - failpoint.Inject("rpcPrewriteResult", func(val failpoint.Value) { - if val != nil { - switch val.(string) { - case "timeout": - failpoint.Return(nil, errors.New("timeout")) - case "notLeader": - failpoint.Return(&tikvrpc.Response{ - Resp: &kvrpcpb.PrewriteResponse{RegionError: &errorpb.Error{NotLeader: &errorpb.NotLeader{}}}, - }, nil) - case "writeConflict": - failpoint.Return(&tikvrpc.Response{ - Resp: &kvrpcpb.PrewriteResponse{Errors: []*kvrpcpb.KeyError{{Conflict: &kvrpcpb.WriteConflict{}}}}, - }, nil) - } - } - }) - - r := req.Prewrite() - c.cluster.handleDelay(r.StartVersion, r.Context.RegionId) - resp.Resp, err = c.usSvr.KvPrewrite(ctx, r) - - failpoint.Inject("rpcPrewriteTimeout", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(nil, undeterminedErr) - } - }) - case tikvrpc.CmdPessimisticLock: - r := req.PessimisticLock() - 
c.cluster.handleDelay(r.StartVersion, r.Context.RegionId) - resp.Resp, err = c.usSvr.KvPessimisticLock(ctx, r) - case tikvrpc.CmdPessimisticRollback: - resp.Resp, err = c.usSvr.KVPessimisticRollback(ctx, req.PessimisticRollback()) - case tikvrpc.CmdCommit: - failpoint.Inject("rpcCommitResult", func(val failpoint.Value) { - switch val.(string) { - case "timeout": - failpoint.Return(nil, errors.New("timeout")) - case "notLeader": - failpoint.Return(&tikvrpc.Response{ - Resp: &kvrpcpb.CommitResponse{RegionError: &errorpb.Error{NotLeader: &errorpb.NotLeader{}}}, - }, nil) - case "keyError": - failpoint.Return(&tikvrpc.Response{ - Resp: &kvrpcpb.CommitResponse{Error: &kvrpcpb.KeyError{}}, - }, nil) - } - }) - - resp.Resp, err = c.usSvr.KvCommit(ctx, req.Commit()) - - failpoint.Inject("rpcCommitTimeout", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(nil, undeterminedErr) - } - }) - case tikvrpc.CmdCleanup: - resp.Resp, err = c.usSvr.KvCleanup(ctx, req.Cleanup()) - case tikvrpc.CmdCheckTxnStatus: - resp.Resp, err = c.usSvr.KvCheckTxnStatus(ctx, req.CheckTxnStatus()) - case tikvrpc.CmdCheckSecondaryLocks: - resp.Resp, err = c.usSvr.KvCheckSecondaryLocks(ctx, req.CheckSecondaryLocks()) - case tikvrpc.CmdTxnHeartBeat: - resp.Resp, err = c.usSvr.KvTxnHeartBeat(ctx, req.TxnHeartBeat()) - case tikvrpc.CmdBatchGet: - batchGetReq := req.BatchGet() - failpoint.Inject("rpcBatchGetResult", func(val failpoint.Value) { - switch val.(string) { - case "keyError": - failpoint.Return(&tikvrpc.Response{ - Resp: &kvrpcpb.BatchGetResponse{Error: &kvrpcpb.KeyError{ - Locked: &kvrpcpb.LockInfo{ - PrimaryLock: batchGetReq.Keys[0], - LockVersion: batchGetReq.Version - 1, - Key: batchGetReq.Keys[0], - LockTtl: 50, - TxnSize: 1, - LockType: kvrpcpb.Op_Put, - }, - }}, - }, nil) - } - }) - - resp.Resp, err = c.usSvr.KvBatchGet(ctx, batchGetReq) - case tikvrpc.CmdBatchRollback: - resp.Resp, err = c.usSvr.KvBatchRollback(ctx, req.BatchRollback()) - case tikvrpc.CmdScanLock: - resp.Resp, err = c.usSvr.KvScanLock(ctx, req.ScanLock()) - case tikvrpc.CmdResolveLock: - resp.Resp, err = c.usSvr.KvResolveLock(ctx, req.ResolveLock()) - case tikvrpc.CmdGC: - resp.Resp, err = c.usSvr.KvGC(ctx, req.GC()) - case tikvrpc.CmdDeleteRange: - resp.Resp, err = c.usSvr.KvDeleteRange(ctx, req.DeleteRange()) - case tikvrpc.CmdRawGet: - resp.Resp, err = c.rawHandler.RawGet(ctx, req.RawGet()) - case tikvrpc.CmdRawBatchGet: - resp.Resp, err = c.rawHandler.RawBatchGet(ctx, req.RawBatchGet()) - case tikvrpc.CmdRawPut: - resp.Resp, err = c.rawHandler.RawPut(ctx, req.RawPut()) - case tikvrpc.CmdRawBatchPut: - resp.Resp, err = c.rawHandler.RawBatchPut(ctx, req.RawBatchPut()) - case tikvrpc.CmdRawDelete: - resp.Resp, err = c.rawHandler.RawDelete(ctx, req.RawDelete()) - case tikvrpc.CmdRawBatchDelete: - resp.Resp, err = c.rawHandler.RawBatchDelete(ctx, req.RawBatchDelete()) - case tikvrpc.CmdRawDeleteRange: - resp.Resp, err = c.rawHandler.RawDeleteRange(ctx, req.RawDeleteRange()) - case tikvrpc.CmdRawScan: - resp.Resp, err = c.rawHandler.RawScan(ctx, req.RawScan()) - case tikvrpc.CmdCop: - resp.Resp, err = c.usSvr.Coprocessor(ctx, req.Cop()) - case tikvrpc.CmdCopStream: - resp.Resp, err = c.handleCopStream(ctx, req.Cop()) - case tikvrpc.CmdBatchCop: - failpoint.Inject("BatchCopCancelled", func(value failpoint.Value) { - if value.(bool) { - failpoint.Return(nil, context.Canceled) - } - }) - - failpoint.Inject("BatchCopRpcErr"+addr, func(value failpoint.Value) { - if value.(string) == addr { - failpoint.Return(nil, errors.New("rpc error")) 
- } - }) - resp.Resp, err = c.handleBatchCop(ctx, req.BatchCop(), timeout) - case tikvrpc.CmdMPPConn: - failpoint.Inject("mppConnTimeout", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(nil, errors.New("rpc error")) - } - }) - failpoint.Inject("MppVersionError", func(val failpoint.Value) { - if v := int64(val.(int)); v > req.EstablishMPPConn().GetReceiverMeta().GetMppVersion() || v > req.EstablishMPPConn().GetSenderMeta().GetMppVersion() { - failpoint.Return(nil, context.Canceled) - } - }) - resp.Resp, err = c.handleEstablishMPPConnection(ctx, req.EstablishMPPConn(), timeout, storeID) - case tikvrpc.CmdMPPTask: - failpoint.Inject("mppDispatchTimeout", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(nil, errors.New("rpc error")) - } - }) - failpoint.Inject("MppVersionError", func(val failpoint.Value) { - if v := int64(val.(int)); v > req.DispatchMPPTask().GetMeta().GetMppVersion() { - failpoint.Return(nil, context.Canceled) - } - }) - resp.Resp, err = c.handleDispatchMPPTask(ctx, req.DispatchMPPTask(), storeID) - case tikvrpc.CmdMPPCancel: - case tikvrpc.CmdMvccGetByKey: - resp.Resp, err = c.usSvr.MvccGetByKey(ctx, req.MvccGetByKey()) - case tikvrpc.CmdMPPAlive: - resp.Resp, err = c.usSvr.IsAlive(ctx, req.IsMPPAlive()) - case tikvrpc.CmdMvccGetByStartTs: - resp.Resp, err = c.usSvr.MvccGetByStartTs(ctx, req.MvccGetByStartTs()) - case tikvrpc.CmdSplitRegion: - resp.Resp, err = c.usSvr.SplitRegion(ctx, req.SplitRegion()) - case tikvrpc.CmdDebugGetRegionProperties: - resp.Resp, err = c.handleDebugGetRegionProperties(ctx, req.DebugGetRegionProperties()) - return resp, err - case tikvrpc.CmdStoreSafeTS: - resp.Resp, err = c.usSvr.GetStoreSafeTS(ctx, req.StoreSafeTS()) - return resp, err - case tikvrpc.CmdUnsafeDestroyRange: - // Pretend it was done. Unistore does not have "destroy", and the - // keys has already been removed one-by-one before through: - // (dr *delRange) startEmulator() - resp.Resp = &kvrpcpb.UnsafeDestroyRangeResponse{} - return resp, nil - case tikvrpc.CmdFlush: - r := req.Flush() - c.cluster.handleDelay(r.StartTs, r.Context.RegionId) - resp.Resp, err = c.usSvr.KvFlush(ctx, r) - case tikvrpc.CmdBufferBatchGet: - r := req.BufferBatchGet() - resp.Resp, err = c.usSvr.KvBufferBatchGet(ctx, r) - default: - err = errors.Errorf("not support this request type %v", req.Type) - } - if err != nil { - return nil, err - } - var regErr *errorpb.Error - if req.Type != tikvrpc.CmdBatchCop && req.Type != tikvrpc.CmdMPPConn && req.Type != tikvrpc.CmdMPPTask && req.Type != tikvrpc.CmdMPPAlive { - regErr, err = resp.GetRegionError() - } - if err != nil { - return nil, err - } - if regErr != nil { - if regErr.EpochNotMatch != nil { - for i, newReg := range regErr.EpochNotMatch.CurrentRegions { - regErr.EpochNotMatch.CurrentRegions[i] = proto.Clone(newReg).(*metapb.Region) - } - } - } - return resp, nil -} - -func (c *RPCClient) handleCopStream(ctx context.Context, req *coprocessor.Request) (*tikvrpc.CopStreamResponse, error) { - copResp, err := c.usSvr.Coprocessor(ctx, req) - if err != nil { - return nil, err - } - return &tikvrpc.CopStreamResponse{ - Tikv_CoprocessorStreamClient: new(mockCopStreamClient), - Response: copResp, - }, nil -} - -// handleEstablishMPPConnection handle the mock mpp collection came from root or peers. 
-func (c *RPCClient) handleEstablishMPPConnection(ctx context.Context, r *mpp.EstablishMPPConnectionRequest, timeout time.Duration, storeID uint64) (*tikvrpc.MPPStreamResponse, error) { - mockServer := new(mockMPPConnectStreamServer) - err := c.usSvr.EstablishMPPConnectionWithStoreID(r, mockServer, storeID) - if err != nil { - return nil, err - } - failpoint.Inject("establishMppConnectionErr", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(nil, errors.New("rpc error")) - } - }) - var mockClient = mockMPPConnectionClient{mppResponses: mockServer.mppResponses, idx: 0, ctx: ctx, targetTask: r.ReceiverMeta} - streamResp := &tikvrpc.MPPStreamResponse{Tikv_EstablishMPPConnectionClient: &mockClient} - _, cancel := context.WithCancel(ctx) - streamResp.Lease.Cancel = cancel - streamResp.Timeout = timeout - // mock the stream resp from the server's resp slice - first, err := streamResp.Recv() - if err != nil { - if errors.Cause(err) != io.EOF { - return nil, errors.Trace(err) - } - } - streamResp.MPPDataPacket = first - return streamResp, nil -} - -func (c *RPCClient) handleDispatchMPPTask(ctx context.Context, r *mpp.DispatchTaskRequest, storeID uint64) (*mpp.DispatchTaskResponse, error) { - return c.usSvr.DispatchMPPTaskWithStoreID(ctx, r, storeID) -} - -func (c *RPCClient) handleBatchCop(ctx context.Context, r *coprocessor.BatchRequest, timeout time.Duration) (*tikvrpc.BatchCopStreamResponse, error) { - mockBatchCopServer := &mockBatchCoprocessorStreamServer{} - err := c.usSvr.BatchCoprocessor(r, mockBatchCopServer) - if err != nil { - return nil, err - } - var mockBatchCopClient = mockBatchCopClient{batchResponses: mockBatchCopServer.batchResponses, idx: 0} - batchResp := &tikvrpc.BatchCopStreamResponse{Tikv_BatchCoprocessorClient: &mockBatchCopClient} - _, cancel := context.WithCancel(ctx) - batchResp.Lease.Cancel = cancel - batchResp.Timeout = timeout - first, err := batchResp.Recv() - if err != nil { - return nil, errors.Trace(err) - } - batchResp.BatchResponse = first - return batchResp, nil -} - -func (c *RPCClient) handleDebugGetRegionProperties(ctx context.Context, req *debugpb.GetRegionPropertiesRequest) (*debugpb.GetRegionPropertiesResponse, error) { - region := c.cluster.GetRegion(req.RegionId) - _, start, err := codec.DecodeBytes(region.StartKey, nil) - if err != nil { - return nil, err - } - _, end, err := codec.DecodeBytes(region.EndKey, nil) - if err != nil { - return nil, err - } - scanResp, err := c.usSvr.KvScan(ctx, &kvrpcpb.ScanRequest{ - Context: &kvrpcpb.Context{ - RegionId: region.Id, - RegionEpoch: region.RegionEpoch, - }, - StartKey: start, - EndKey: end, - Version: math.MaxUint64, - Limit: math.MaxUint32, - }) - if err != nil { - return nil, err - } - if err := scanResp.GetRegionError(); err != nil { - panic(err) - } - return &debugpb.GetRegionPropertiesResponse{ - Props: []*debugpb.Property{{ - Name: "mvcc.num_rows", - Value: strconv.Itoa(len(scanResp.Pairs)), - }}}, nil -} - -// Close closes RPCClient and cleanup temporal resources. -func (c *RPCClient) Close() error { - atomic.StoreInt32(&c.closed, 1) - if c.usSvr != nil { - c.usSvr.Stop() - } - if !c.persistent && c.path != "" { - err := os.RemoveAll(c.path) - _ = err - } - return nil -} - -// CloseAddr implements tikv.Client interface and it does nothing. -func (c *RPCClient) CloseAddr(addr string) error { - return nil -} - -// SetEventListener implements tikv.Client interface. 
-func (c *RPCClient) SetEventListener(listener tikv.ClientEventListener) {} - -type mockClientStream struct{} - -// Header implements grpc.ClientStream interface -func (mockClientStream) Header() (metadata.MD, error) { return nil, nil } - -// Trailer implements grpc.ClientStream interface -func (mockClientStream) Trailer() metadata.MD { return nil } - -// CloseSend implements grpc.ClientStream interface -func (mockClientStream) CloseSend() error { return nil } - -// Context implements grpc.ClientStream interface -func (mockClientStream) Context() context.Context { return nil } - -// SendMsg implements grpc.ClientStream interface -func (mockClientStream) SendMsg(m any) error { return nil } - -// RecvMsg implements grpc.ClientStream interface -func (mockClientStream) RecvMsg(m any) error { return nil } - -type mockCopStreamClient struct { - mockClientStream -} - -func (mock *mockCopStreamClient) Recv() (*coprocessor.Response, error) { - return nil, io.EOF -} - -type mockBatchCopClient struct { - mockClientStream - batchResponses []*coprocessor.BatchResponse - idx int -} - -func (mock *mockBatchCopClient) Recv() (*coprocessor.BatchResponse, error) { - if mock.idx < len(mock.batchResponses) { - ret := mock.batchResponses[mock.idx] - mock.idx++ - var err error - if len(ret.OtherError) > 0 { - err = errors.New(ret.OtherError) - ret = nil - } - return ret, err - } - failpoint.Inject("batchCopRecvTimeout", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(nil, context.Canceled) - } - }) - return nil, io.EOF -} - -type mockMPPConnectionClient struct { - mockClientStream - mppResponses []*mpp.MPPDataPacket - idx int - ctx context.Context - targetTask *mpp.TaskMeta -} - -func (mock *mockMPPConnectionClient) Recv() (*mpp.MPPDataPacket, error) { - if mock.idx < len(mock.mppResponses) { - ret := mock.mppResponses[mock.idx] - mock.idx++ - return ret, nil - } - failpoint.Inject("mppRecvTimeout", func(val failpoint.Value) { - if int64(val.(int)) == mock.targetTask.TaskId { - failpoint.Return(nil, context.Canceled) - } - }) - failpoint.Inject("mppRecvHang", func(val failpoint.Value) { - for val.(bool) { - select { - case <-mock.ctx.Done(): - { - failpoint.Return(nil, context.Canceled) - } - default: - time.Sleep(1 * time.Second) - } - } - }) - return nil, io.EOF -} - -type mockServerStream struct{} - -func (mockServerStream) SetHeader(metadata.MD) error { return nil } -func (mockServerStream) SendHeader(metadata.MD) error { return nil } -func (mockServerStream) SetTrailer(metadata.MD) {} -func (mockServerStream) Context() context.Context { return nil } -func (mockServerStream) SendMsg(any) error { return nil } -func (mockServerStream) RecvMsg(any) error { return nil } - -type mockBatchCoprocessorStreamServer struct { - mockServerStream - batchResponses []*coprocessor.BatchResponse -} - -func (mockBatchCopServer *mockBatchCoprocessorStreamServer) Send(response *coprocessor.BatchResponse) error { - mockBatchCopServer.batchResponses = append(mockBatchCopServer.batchResponses, response) - return nil -} - -type mockMPPConnectStreamServer struct { - mockServerStream - mppResponses []*mpp.MPPDataPacket -} - -func (mockMPPConnectStreamServer *mockMPPConnectStreamServer) Send(mppResponse *mpp.MPPDataPacket) error { - mockMPPConnectStreamServer.mppResponses = append(mockMPPConnectStreamServer.mppResponses, mppResponse) - return nil -} diff --git a/pkg/table/contextimpl/binding__failpoint_binding__.go b/pkg/table/contextimpl/binding__failpoint_binding__.go deleted file mode 100644 index 
beab72aa0be19..0000000000000 --- a/pkg/table/contextimpl/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package contextimpl - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/table/contextimpl/table.go b/pkg/table/contextimpl/table.go index 202a8c0017490..1f701246ab8a8 100644 --- a/pkg/table/contextimpl/table.go +++ b/pkg/table/contextimpl/table.go @@ -120,11 +120,11 @@ func (ctx *TableContextImpl) GetReservedRowIDAlloc() (*stmtctx.ReservedRowIDAllo // GetBinlogSupport implements the MutateContext interface. func (ctx *TableContextImpl) GetBinlogSupport() (context.BinlogSupport, bool) { - if _, _err_ := failpoint.Eval(_curpkg_("forceWriteBinlog")); _err_ == nil { + failpoint.Inject("forceWriteBinlog", func() { // Just to cover binlog related code in this package, since the `BinlogClient` is // still nil, mutations won't be written to pump on commit. - return ctx, true - } + failpoint.Return(ctx, true) + }) if ctx.vars().BinlogClient != nil { return ctx, true } diff --git a/pkg/table/contextimpl/table.go__failpoint_stash__ b/pkg/table/contextimpl/table.go__failpoint_stash__ deleted file mode 100644 index 1f701246ab8a8..0000000000000 --- a/pkg/table/contextimpl/table.go__failpoint_stash__ +++ /dev/null @@ -1,200 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package contextimpl - -import ( - "github.com/pingcap/failpoint" - exprctx "github.com/pingcap/tidb/pkg/expression/context" - "github.com/pingcap/tidb/pkg/meta/autoid" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/table/context" - "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tipb/go-binlog" -) - -var _ context.MutateContext = &TableContextImpl{} -var _ context.AllocatorContext = &TableContextImpl{} - -// TableContextImpl is used to provide context for table operations. -type TableContextImpl struct { - sessionctx.Context - // mutateBuffers is a memory pool for table related memory allocation that aims to reuse memory - // and saves allocation - // The buffers are supposed to be used inside AddRecord/UpdateRecord/RemoveRecord. - mutateBuffers *context.MutateBuffers -} - -// NewTableContextImpl creates a new TableContextImpl. 
-func NewTableContextImpl(sctx sessionctx.Context) *TableContextImpl { - return &TableContextImpl{ - Context: sctx, - mutateBuffers: context.NewMutateBuffers(sctx.GetSessionVars().GetWriteStmtBufs()), - } -} - -// AlternativeAllocators implements the AllocatorContext interface -func (ctx *TableContextImpl) AlternativeAllocators(tbl *model.TableInfo) (allocators autoid.Allocators, ok bool) { - // Use an independent allocator for global temporary tables. - if tbl.TempTableType == model.TempTableGlobal { - if tempTbl := ctx.vars().GetTemporaryTable(tbl); tempTbl != nil { - if alloc := tempTbl.GetAutoIDAllocator(); alloc != nil { - return autoid.NewAllocators(false, alloc), true - } - } - // If the session is not in a txn, for example, in "show create table", use the original allocator. - } - return -} - -// GetExprCtx returns the ExprContext -func (ctx *TableContextImpl) GetExprCtx() exprctx.ExprContext { - return ctx.Context.GetExprCtx() -} - -// ConnectionID implements the MutateContext interface. -func (ctx *TableContextImpl) ConnectionID() uint64 { - return ctx.vars().ConnectionID -} - -// InRestrictedSQL returns whether the current context is used in restricted SQL. -func (ctx *TableContextImpl) InRestrictedSQL() bool { - return ctx.vars().InRestrictedSQL -} - -// TxnAssertionLevel implements the MutateContext interface. -func (ctx *TableContextImpl) TxnAssertionLevel() variable.AssertionLevel { - return ctx.vars().AssertionLevel -} - -// EnableMutationChecker implements the MutateContext interface. -func (ctx *TableContextImpl) EnableMutationChecker() bool { - return ctx.vars().EnableMutationChecker -} - -// GetRowEncodingConfig returns the RowEncodingConfig. -func (ctx *TableContextImpl) GetRowEncodingConfig() context.RowEncodingConfig { - vars := ctx.vars() - return context.RowEncodingConfig{ - IsRowLevelChecksumEnabled: vars.IsRowLevelChecksumEnabled(), - RowEncoder: &vars.RowEncoder, - } -} - -// GetMutateBuffers implements the MutateContext interface. -func (ctx *TableContextImpl) GetMutateBuffers() *context.MutateBuffers { - return ctx.mutateBuffers -} - -// GetRowIDShardGenerator implements the MutateContext interface. -func (ctx *TableContextImpl) GetRowIDShardGenerator() *variable.RowIDShardGenerator { - return ctx.vars().GetRowIDShardGenerator() -} - -// GetReservedRowIDAlloc implements the MutateContext interface. -func (ctx *TableContextImpl) GetReservedRowIDAlloc() (*stmtctx.ReservedRowIDAlloc, bool) { - if sc := ctx.vars().StmtCtx; sc != nil { - return &sc.ReservedRowIDAlloc, true - } - // `StmtCtx` should not be nil in the `variable.SessionVars`. - // We just put an assertion that will panic only if in test here. - // In production code, here returns (nil, false) to make code safe - // because some old code checks `StmtCtx != nil` but we don't know why. - intest.Assert(false, "SessionVars.StmtCtx should not be nil") - return nil, false -} - -// GetBinlogSupport implements the MutateContext interface. -func (ctx *TableContextImpl) GetBinlogSupport() (context.BinlogSupport, bool) { - failpoint.Inject("forceWriteBinlog", func() { - // Just to cover binlog related code in this package, since the `BinlogClient` is - // still nil, mutations won't be written to pump on commit. - failpoint.Return(ctx, true) - }) - if ctx.vars().BinlogClient != nil { - return ctx, true - } - return nil, false -} - -// GetBinlogMutation implements the BinlogSupport interface. 
-func (ctx *TableContextImpl) GetBinlogMutation(tblID int64) *binlog.TableMutation { - return ctx.Context.StmtGetMutation(tblID) -} - -// GetStatisticsSupport implements the MutateContext interface. -func (ctx *TableContextImpl) GetStatisticsSupport() (context.StatisticsSupport, bool) { - if ctx.vars().TxnCtx != nil { - return ctx, true - } - return nil, false -} - -// UpdatePhysicalTableDelta implements the StatisticsSupport interface. -func (ctx *TableContextImpl) UpdatePhysicalTableDelta( - physicalTableID int64, delta int64, count int64, cols variable.DeltaCols, -) { - if txnCtx := ctx.vars().TxnCtx; txnCtx != nil { - txnCtx.UpdateDeltaForTable(physicalTableID, delta, count, cols) - } -} - -// GetCachedTableSupport implements the MutateContext interface. -func (ctx *TableContextImpl) GetCachedTableSupport() (context.CachedTableSupport, bool) { - if ctx.vars().TxnCtx != nil { - return ctx, true - } - return nil, false -} - -// AddCachedTableHandleToTxn implements `CachedTableSupport` interface -func (ctx *TableContextImpl) AddCachedTableHandleToTxn(tableID int64, handle any) { - txnCtx := ctx.vars().TxnCtx - if txnCtx.CachedTables == nil { - txnCtx.CachedTables = make(map[int64]any) - } - if _, ok := txnCtx.CachedTables[tableID]; !ok { - txnCtx.CachedTables[tableID] = handle - } -} - -// GetTemporaryTableSupport implements the MutateContext interface. -func (ctx *TableContextImpl) GetTemporaryTableSupport() (context.TemporaryTableSupport, bool) { - if ctx.vars().TxnCtx == nil { - return nil, false - } - return ctx, true -} - -// GetTemporaryTableSizeLimit implements TemporaryTableSupport interface. -func (ctx *TableContextImpl) GetTemporaryTableSizeLimit() int64 { - return ctx.vars().TMPTableSize -} - -// AddTemporaryTableToTxn implements the TemporaryTableSupport interface. 
-func (ctx *TableContextImpl) AddTemporaryTableToTxn(tblInfo *model.TableInfo) (context.TemporaryTableHandler, bool) { - vars := ctx.vars() - if tbl := vars.GetTemporaryTable(tblInfo); tbl != nil { - tbl.SetModified(true) - return context.NewTemporaryTableHandler(tbl, vars.TemporaryTableData), true - } - return context.TemporaryTableHandler{}, false -} - -func (ctx *TableContextImpl) vars() *variable.SessionVars { - return ctx.Context.GetSessionVars() -} diff --git a/pkg/table/tables/binding__failpoint_binding__.go b/pkg/table/tables/binding__failpoint_binding__.go deleted file mode 100644 index fc93c522f2734..0000000000000 --- a/pkg/table/tables/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package tables - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/table/tables/cache.go b/pkg/table/tables/cache.go index 7554cfd7f6e0d..1051ee8a995f6 100644 --- a/pkg/table/tables/cache.go +++ b/pkg/table/tables/cache.go @@ -100,9 +100,9 @@ func (c *cachedTable) TryReadFromCache(ts uint64, leaseDuration time.Duration) ( distance := leaseTime.Sub(nowTime) var triggerFailpoint bool - if _, _err_ := failpoint.Eval(_curpkg_("mockRenewLeaseABA1")); _err_ == nil { + failpoint.Inject("mockRenewLeaseABA1", func(_ failpoint.Value) { triggerFailpoint = true - } + }) if distance >= 0 && distance <= leaseDuration/2 || triggerFailpoint { if h := c.TakeStateRemoteHandleNoWait(); h != nil { @@ -273,11 +273,11 @@ func (c *cachedTable) RemoveRecord(sctx table.MutateContext, h kv.Handle, r []ty var TestMockRenewLeaseABA2 chan struct{} func (c *cachedTable) renewLease(handle StateRemote, ts uint64, data *cacheData, leaseDuration time.Duration) { - if _, _err_ := failpoint.Eval(_curpkg_("mockRenewLeaseABA2")); _err_ == nil { + failpoint.Inject("mockRenewLeaseABA2", func(_ failpoint.Value) { c.PutStateRemoteHandle(handle) <-TestMockRenewLeaseABA2 c.TakeStateRemoteHandle() - } + }) defer c.PutStateRemoteHandle(handle) @@ -298,9 +298,9 @@ func (c *cachedTable) renewLease(handle StateRemote, ts uint64, data *cacheData, }) } - if _, _err_ := failpoint.Eval(_curpkg_("mockRenewLeaseABA2")); _err_ == nil { + failpoint.Inject("mockRenewLeaseABA2", func(_ failpoint.Value) { TestMockRenewLeaseABA2 <- struct{}{} - } + }) } const cacheTableWriteLease = 5 * time.Second diff --git a/pkg/table/tables/cache.go__failpoint_stash__ b/pkg/table/tables/cache.go__failpoint_stash__ deleted file mode 100644 index 1051ee8a995f6..0000000000000 --- a/pkg/table/tables/cache.go__failpoint_stash__ +++ /dev/null @@ -1,355 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package tables - -import ( - "context" - "sync/atomic" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/log" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/sqlexec" - "github.com/tikv/client-go/v2/oracle" - "github.com/tikv/client-go/v2/tikv" - "go.uber.org/zap" -) - -var ( - _ table.CachedTable = &cachedTable{} -) - -type cachedTable struct { - TableCommon - cacheData atomic.Pointer[cacheData] - totalSize int64 - // StateRemote is not thread-safe, this tokenLimit is used to keep only one visitor. - tokenLimit -} - -type tokenLimit chan StateRemote - -func (t tokenLimit) TakeStateRemoteHandle() StateRemote { - handle := <-t - return handle -} - -func (t tokenLimit) TakeStateRemoteHandleNoWait() StateRemote { - select { - case handle := <-t: - return handle - default: - return nil - } -} - -func (t tokenLimit) PutStateRemoteHandle(handle StateRemote) { - t <- handle -} - -// cacheData pack the cache data and lease. -type cacheData struct { - Start uint64 - Lease uint64 - kv.MemBuffer -} - -func leaseFromTS(ts uint64, leaseDuration time.Duration) uint64 { - physicalTime := oracle.GetTimeFromTS(ts) - lease := oracle.GoTimeToTS(physicalTime.Add(leaseDuration)) - return lease -} - -func newMemBuffer(store kv.Storage) (kv.MemBuffer, error) { - // Here is a trick to get a MemBuffer data, because the internal API is not exposed. - // Create a transaction with start ts 0, and take the MemBuffer out. - buffTxn, err := store.Begin(tikv.WithStartTS(0)) - if err != nil { - return nil, err - } - return buffTxn.GetMemBuffer(), nil -} - -func (c *cachedTable) TryReadFromCache(ts uint64, leaseDuration time.Duration) (kv.MemBuffer, bool /*loading*/) { - data := c.cacheData.Load() - if data == nil { - return nil, false - } - if ts >= data.Start && ts < data.Lease { - leaseTime := oracle.GetTimeFromTS(data.Lease) - nowTime := oracle.GetTimeFromTS(ts) - distance := leaseTime.Sub(nowTime) - - var triggerFailpoint bool - failpoint.Inject("mockRenewLeaseABA1", func(_ failpoint.Value) { - triggerFailpoint = true - }) - - if distance >= 0 && distance <= leaseDuration/2 || triggerFailpoint { - if h := c.TakeStateRemoteHandleNoWait(); h != nil { - go c.renewLease(h, ts, data, leaseDuration) - } - } - // If data is not nil, but data.MemBuffer is nil, it means the data is being - // loading by a background goroutine. - return data.MemBuffer, data.MemBuffer == nil - } - return nil, false -} - -// newCachedTable creates a new CachedTable Instance -func newCachedTable(tbl *TableCommon) (table.Table, error) { - ret := &cachedTable{ - TableCommon: tbl.Copy(), - tokenLimit: make(chan StateRemote, 1), - } - return ret, nil -} - -// Init is an extra operation for cachedTable after TableFromMeta, -// Because cachedTable need some additional parameter that can't be passed in TableFromMeta. 
-func (c *cachedTable) Init(exec sqlexec.SQLExecutor) error { - raw, ok := exec.(sqlExec) - if !ok { - return errors.New("Need sqlExec rather than sqlexec.SQLExecutor") - } - handle := NewStateRemote(raw) - c.PutStateRemoteHandle(handle) - return nil -} - -func (c *cachedTable) loadDataFromOriginalTable(store kv.Storage) (kv.MemBuffer, uint64, int64, error) { - buffer, err := newMemBuffer(store) - if err != nil { - return nil, 0, 0, err - } - var startTS uint64 - totalSize := int64(0) - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnCacheTable) - err = kv.RunInNewTxn(ctx, store, true, func(ctx context.Context, txn kv.Transaction) error { - prefix := tablecodec.GenTablePrefix(c.tableID) - if err != nil { - return errors.Trace(err) - } - startTS = txn.StartTS() - it, err := txn.Iter(prefix, prefix.PrefixNext()) - if err != nil { - return errors.Trace(err) - } - defer it.Close() - - for it.Valid() && it.Key().HasPrefix(prefix) { - key := it.Key() - value := it.Value() - err = buffer.Set(key, value) - if err != nil { - return errors.Trace(err) - } - totalSize += int64(len(key)) - totalSize += int64(len(value)) - err = it.Next() - if err != nil { - return errors.Trace(err) - } - } - return nil - }) - if err != nil { - return nil, 0, totalSize, err - } - - return buffer, startTS, totalSize, nil -} - -func (c *cachedTable) UpdateLockForRead(ctx context.Context, store kv.Storage, ts uint64, leaseDuration time.Duration) { - if h := c.TakeStateRemoteHandleNoWait(); h != nil { - go c.updateLockForRead(ctx, h, store, ts, leaseDuration) - } -} - -func (c *cachedTable) updateLockForRead(ctx context.Context, handle StateRemote, store kv.Storage, ts uint64, leaseDuration time.Duration) { - defer func() { - if r := recover(); r != nil { - log.Error("panic in the recoverable goroutine", - zap.Any("r", r), - zap.Stack("stack trace")) - } - c.PutStateRemoteHandle(handle) - }() - - // Load data from original table and the update lock information. - tid := c.Meta().ID - lease := leaseFromTS(ts, leaseDuration) - succ, err := handle.LockForRead(ctx, tid, lease) - if err != nil { - log.Warn("lock cached table for read", zap.Error(err)) - return - } - if succ { - c.cacheData.Store(&cacheData{ - Start: ts, - Lease: lease, - MemBuffer: nil, // Async loading, this will be set later. - }) - - // Make the load data process async, in case that loading data takes longer the - // lease duration, then the loaded data get staled and that process repeats forever. - go func() { - start := time.Now() - mb, startTS, totalSize, err := c.loadDataFromOriginalTable(store) - metrics.LoadTableCacheDurationHistogram.Observe(time.Since(start).Seconds()) - if err != nil { - log.Info("load data from table fail", zap.Error(err)) - return - } - - tmp := c.cacheData.Load() - if tmp != nil && tmp.Start == ts { - c.cacheData.Store(&cacheData{ - Start: startTS, - Lease: tmp.Lease, - MemBuffer: mb, - }) - atomic.StoreInt64(&c.totalSize, totalSize) - } - }() - } - // Current status is not suitable to cache. -} - -const cachedTableSizeLimit = 64 * (1 << 20) - -// AddRecord implements the AddRecord method for the table.Table interface. -func (c *cachedTable) AddRecord(sctx table.MutateContext, r []types.Datum, opts ...table.AddRecordOption) (recordID kv.Handle, err error) { - if atomic.LoadInt64(&c.totalSize) > cachedTableSizeLimit { - return nil, table.ErrOptOnCacheTable.GenWithStackByArgs("table too large") - } - txnCtxAddCachedTable(sctx, c.Meta().ID, c) - return c.TableCommon.AddRecord(sctx, r, opts...) 
-} - -func txnCtxAddCachedTable(sctx table.MutateContext, tid int64, handle *cachedTable) { - if s, ok := sctx.GetCachedTableSupport(); ok { - s.AddCachedTableHandleToTxn(tid, handle) - } -} - -// UpdateRecord implements table.Table -func (c *cachedTable) UpdateRecord(ctx table.MutateContext, h kv.Handle, oldData, newData []types.Datum, touched []bool, opts ...table.UpdateRecordOption) error { - // Prevent furthur writing when the table is already too large. - if atomic.LoadInt64(&c.totalSize) > cachedTableSizeLimit { - return table.ErrOptOnCacheTable.GenWithStackByArgs("table too large") - } - txnCtxAddCachedTable(ctx, c.Meta().ID, c) - return c.TableCommon.UpdateRecord(ctx, h, oldData, newData, touched, opts...) -} - -// RemoveRecord implements table.Table RemoveRecord interface. -func (c *cachedTable) RemoveRecord(sctx table.MutateContext, h kv.Handle, r []types.Datum) error { - txnCtxAddCachedTable(sctx, c.Meta().ID, c) - return c.TableCommon.RemoveRecord(sctx, h, r) -} - -// TestMockRenewLeaseABA2 is used by test function TestRenewLeaseABAFailPoint. -var TestMockRenewLeaseABA2 chan struct{} - -func (c *cachedTable) renewLease(handle StateRemote, ts uint64, data *cacheData, leaseDuration time.Duration) { - failpoint.Inject("mockRenewLeaseABA2", func(_ failpoint.Value) { - c.PutStateRemoteHandle(handle) - <-TestMockRenewLeaseABA2 - c.TakeStateRemoteHandle() - }) - - defer c.PutStateRemoteHandle(handle) - - tid := c.Meta().ID - lease := leaseFromTS(ts, leaseDuration) - newLease, err := handle.RenewReadLease(context.Background(), tid, data.Lease, lease) - if err != nil { - if !kv.IsTxnRetryableError(err) { - log.Warn("Renew read lease error", zap.Error(err)) - } - return - } - if newLease > 0 { - c.cacheData.Store(&cacheData{ - Start: data.Start, - Lease: newLease, - MemBuffer: data.MemBuffer, - }) - } - - failpoint.Inject("mockRenewLeaseABA2", func(_ failpoint.Value) { - TestMockRenewLeaseABA2 <- struct{}{} - }) -} - -const cacheTableWriteLease = 5 * time.Second - -func (c *cachedTable) WriteLockAndKeepAlive(ctx context.Context, exit chan struct{}, leasePtr *uint64, wg chan error) { - writeLockLease, err := c.lockForWrite(ctx) - atomic.StoreUint64(leasePtr, writeLockLease) - wg <- err - if err != nil { - logutil.Logger(ctx).Warn("lock for write lock fail", zap.String("category", "cached table"), zap.Error(err)) - return - } - - t := time.NewTicker(cacheTableWriteLease / 2) - defer t.Stop() - for { - select { - case <-t.C: - if err := c.renew(ctx, leasePtr); err != nil { - logutil.Logger(ctx).Warn("renew write lock lease fail", zap.String("category", "cached table"), zap.Error(err)) - return - } - case <-exit: - return - } - } -} - -func (c *cachedTable) renew(ctx context.Context, leasePtr *uint64) error { - oldLease := atomic.LoadUint64(leasePtr) - physicalTime := oracle.GetTimeFromTS(oldLease) - newLease := oracle.GoTimeToTS(physicalTime.Add(cacheTableWriteLease)) - - h := c.TakeStateRemoteHandle() - defer c.PutStateRemoteHandle(h) - - succ, err := h.RenewWriteLease(ctx, c.Meta().ID, newLease) - if err != nil { - return errors.Trace(err) - } - if succ { - atomic.StoreUint64(leasePtr, newLease) - } - return nil -} - -func (c *cachedTable) lockForWrite(ctx context.Context) (uint64, error) { - handle := c.TakeStateRemoteHandle() - defer c.PutStateRemoteHandle(handle) - - return handle.LockForWrite(ctx, c.Meta().ID, cacheTableWriteLease) -} diff --git a/pkg/table/tables/mutation_checker.go b/pkg/table/tables/mutation_checker.go index 0fd6fdeb50517..33f18ea37cb95 100644 --- 
a/pkg/table/tables/mutation_checker.go +++ b/pkg/table/tables/mutation_checker.go @@ -549,8 +549,8 @@ func corruptMutations(t *TableCommon, txn kv.Transaction, sh kv.StagingHandle, c } func injectMutationError(t *TableCommon, txn kv.Transaction, sh kv.StagingHandle) error { - if commands, _err_ := failpoint.Eval(_curpkg_("corruptMutations")); _err_ == nil { - return corruptMutations(t, txn, sh, commands.(string)) - } + failpoint.Inject("corruptMutations", func(commands failpoint.Value) { + failpoint.Return(corruptMutations(t, txn, sh, commands.(string))) + }) return nil } diff --git a/pkg/table/tables/mutation_checker.go__failpoint_stash__ b/pkg/table/tables/mutation_checker.go__failpoint_stash__ deleted file mode 100644 index 33f18ea37cb95..0000000000000 --- a/pkg/table/tables/mutation_checker.go__failpoint_stash__ +++ /dev/null @@ -1,556 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tables - -import ( - "fmt" - "strings" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/errno" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/table" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/collate" - "github.com/pingcap/tidb/pkg/util/dbterror" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/rowcodec" - "go.uber.org/zap" -) - -var ( - // ErrInconsistentRowValue is the error when values in a row insertion does not match the expected ones. - ErrInconsistentRowValue = dbterror.ClassTable.NewStd(errno.ErrInconsistentRowValue) - // ErrInconsistentHandle is the error when the handle in the row/index insertions does not match. - ErrInconsistentHandle = dbterror.ClassTable.NewStd(errno.ErrInconsistentHandle) - // ErrInconsistentIndexedValue is the error when decoded values from the index mutation cannot match row value - ErrInconsistentIndexedValue = dbterror.ClassTable.NewStd(errno.ErrInconsistentIndexedValue) -) - -type mutation struct { - key kv.Key - flags kv.KeyFlags - value []byte - indexID int64 // only for index mutations -} - -type columnMaps struct { - ColumnIDToInfo map[int64]*model.ColumnInfo - ColumnIDToFieldType map[int64]*types.FieldType - IndexIDToInfo map[int64]*model.IndexInfo - IndexIDToRowColInfos map[int64][]rowcodec.ColInfo -} - -// CheckDataConsistency checks whether the given set of mutations corresponding to a single row is consistent. -// Namely, assume the database is consistent before, applying the mutations shouldn't break the consistency. -// It aims at reducing bugs that will corrupt data, and preventing mistakes from spreading if possible. 
-// -// 3 conditions are checked: -// (1) row.value is consistent with input data -// (2) the handle is consistent in row and index insertions -// (3) the keys of the indices are consistent with the values of rows -// -// The check doesn't work and just returns nil when: -// (1) the table is partitioned -// (2) new collation is enabled and restored data is needed -// -// The check is performed on almost every write. Its performance matters. -// Let M = the number of mutations, C = the number of columns in the table, -// I = the sum of the number of columns in all indices, -// The time complexity is O(M * C + I) -// The space complexity is O(M + C + I) -func CheckDataConsistency( - txn kv.Transaction, tc types.Context, t *TableCommon, - rowToInsert, rowToRemove []types.Datum, memBuffer kv.MemBuffer, sh kv.StagingHandle, -) error { - if t.Meta().GetPartitionInfo() != nil { - return nil - } - if txn.IsPipelined() { - return nil - } - if sh == 0 { - // some implementations of MemBuffer doesn't support staging, e.g. that in pkg/lightning/backend/kv - return nil - } - indexMutations, rowInsertion, err := collectTableMutationsFromBufferStage(t, memBuffer, sh) - if err != nil { - return errors.Trace(err) - } - - columnMaps := getColumnMaps(txn, t) - - // Row insertion consistency check contributes the least to defending data-index consistency, but costs most CPU resources. - // So we disable it for now. - // - // if rowToInsert != nil { - // if err := checkRowInsertionConsistency( - // sessVars, rowToInsert, rowInsertion, columnMaps.ColumnIDToInfo, columnMaps.ColumnIDToFieldType, t.Meta().Name.O, - // ); err != nil { - // return errors.Trace(err) - // } - // } - - if rowInsertion.key != nil { - if err = checkHandleConsistency(rowInsertion, indexMutations, columnMaps.IndexIDToInfo, t.Meta()); err != nil { - return errors.Trace(err) - } - } - - if err := checkIndexKeys( - tc, t, rowToInsert, rowToRemove, indexMutations, columnMaps.IndexIDToInfo, columnMaps.IndexIDToRowColInfos, - ); err != nil { - return errors.Trace(err) - } - return nil -} - -// checkHandleConsistency checks whether the handles, with regard to a single-row change, -// in row insertions and index insertions are consistent. -// A PUT_index implies a PUT_row with the same handle. -// Deletions are not checked since the values of deletions are unknown -func checkHandleConsistency(rowInsertion mutation, indexMutations []mutation, indexIDToInfo map[int64]*model.IndexInfo, tblInfo *model.TableInfo) error { - var insertionHandle kv.Handle - var err error - - if rowInsertion.key == nil { - return nil - } - insertionHandle, err = tablecodec.DecodeRowKey(rowInsertion.key) - if err != nil { - return errors.Trace(err) - } - - for _, m := range indexMutations { - if len(m.value) == 0 { - continue - } - - // Generate correct index id for check. - idxID := m.indexID & tablecodec.IndexIDMask - indexInfo, ok := indexIDToInfo[idxID] - if !ok { - return errors.New("index not found") - } - - // If this is the temporary index data, need to remove the last byte of index data(version about when it is written). - var ( - value []byte - orgKey []byte - indexHandle kv.Handle - ) - if idxID != m.indexID { - if tablecodec.TempIndexValueIsUntouched(m.value) { - // We never commit the untouched key values to the storage. Skip this check. 
- continue - } - var tempIdxVal tablecodec.TempIndexValue - tempIdxVal, err = tablecodec.DecodeTempIndexValue(m.value) - if err != nil { - return err - } - if !tempIdxVal.IsEmpty() { - value = tempIdxVal.Current().Value - } - if len(value) == 0 { - // Skip the deleted operation values. - continue - } - orgKey = append(orgKey, m.key...) - tablecodec.TempIndexKey2IndexKey(orgKey) - indexHandle, err = tablecodec.DecodeIndexHandle(orgKey, value, len(indexInfo.Columns)) - } else { - indexHandle, err = tablecodec.DecodeIndexHandle(m.key, m.value, len(indexInfo.Columns)) - } - if err != nil { - return errors.Trace(err) - } - // NOTE: handle type can be different, see issue 29520 - if indexHandle.IsInt() == insertionHandle.IsInt() && indexHandle.Compare(insertionHandle) != 0 { - err = ErrInconsistentHandle.GenWithStackByArgs(tblInfo.Name, indexInfo.Name.O, indexHandle, insertionHandle, m, rowInsertion) - logutil.BgLogger().Error("inconsistent handle in index and record insertions", zap.Error(err)) - return err - } - } - - return err -} - -// checkIndexKeys checks whether the decoded data from keys of index mutations are consistent with the expected ones. -// -// How it works: -// -// Assume the set of row values changes from V1 to V2, we check -// (1) V2 - V1 = {added indices} -// (2) V1 - V2 = {deleted indices} -// -// To check (1), we need -// (a) {added indices} is a subset of {needed indices} => each index mutation is consistent with the input/row key/value -// (b) {needed indices} is a subset of {added indices}. The check process would be exactly the same with how we generate the mutations, thus ignored. -func checkIndexKeys( - tc types.Context, t *TableCommon, rowToInsert, rowToRemove []types.Datum, - indexMutations []mutation, indexIDToInfo map[int64]*model.IndexInfo, - indexIDToRowColInfos map[int64][]rowcodec.ColInfo, -) error { - var indexData []types.Datum - for _, m := range indexMutations { - var value []byte - // Generate correct index id for check. - idxID := m.indexID & tablecodec.IndexIDMask - indexInfo, ok := indexIDToInfo[idxID] - if !ok { - return errors.New("index not found") - } - rowColInfos, ok := indexIDToRowColInfos[idxID] - if !ok { - return errors.New("index not found") - } - - var isTmpIdxValAndDeleted bool - // If this is temp index data, need remove last byte of index data. - if idxID != m.indexID { - if tablecodec.TempIndexValueIsUntouched(m.value) { - // We never commit the untouched key values to the storage. Skip this check. - continue - } - tmpVal, err := tablecodec.DecodeTempIndexValue(m.value) - if err != nil { - return err - } - curElem := tmpVal.Current() - isTmpIdxValAndDeleted = curElem.Delete - value = append(value, curElem.Value...) - } else { - value = append(value, m.value...) 
- } - - // when we cannot decode the key to get the original value - if len(value) == 0 && NeedRestoredData(indexInfo.Columns, t.Meta().Columns) { - continue - } - - decodedIndexValues, err := tablecodec.DecodeIndexKV( - m.key, value, len(indexInfo.Columns), tablecodec.HandleNotNeeded, rowColInfos, - ) - if err != nil { - return errors.Trace(err) - } - - // reuse the underlying memory, save an allocation - if indexData == nil { - indexData = make([]types.Datum, 0, len(decodedIndexValues)) - } else { - indexData = indexData[:0] - } - - loc := tc.Location() - for i, v := range decodedIndexValues { - fieldType := t.Columns[indexInfo.Columns[i].Offset].FieldType.ArrayType() - datum, err := tablecodec.DecodeColumnValue(v, fieldType, loc) - if err != nil { - return errors.Trace(err) - } - indexData = append(indexData, datum) - } - - // When it is in add index new backfill state. - if len(value) == 0 || isTmpIdxValAndDeleted { - err = compareIndexData(tc, t.Columns, indexData, rowToRemove, indexInfo, t.Meta()) - } else { - err = compareIndexData(tc, t.Columns, indexData, rowToInsert, indexInfo, t.Meta()) - } - if err != nil { - return errors.Trace(err) - } - } - return nil -} - -// checkRowInsertionConsistency checks whether the values of row mutations are consistent with the expected ones -// We only check data added since a deletion of a row doesn't care about its value (and we cannot know it) -func checkRowInsertionConsistency( - sessVars *variable.SessionVars, rowToInsert []types.Datum, rowInsertion mutation, - columnIDToInfo map[int64]*model.ColumnInfo, columnIDToFieldType map[int64]*types.FieldType, tableName string, -) error { - if rowToInsert == nil { - // it's a deletion - return nil - } - - decodedData, err := tablecodec.DecodeRowToDatumMap(rowInsertion.value, columnIDToFieldType, sessVars.Location()) - if err != nil { - return errors.Trace(err) - } - - // NOTE: we cannot check if the decoded values contain all columns since some columns may be skipped. It can even be empty - // Instead, we check that decoded index values are consistent with the input row. 
- - for columnID, decodedDatum := range decodedData { - inputDatum := rowToInsert[columnIDToInfo[columnID].Offset] - cmp, err := decodedDatum.Compare(sessVars.StmtCtx.TypeCtx(), &inputDatum, collate.GetCollator(decodedDatum.Collation())) - if err != nil { - return errors.Trace(err) - } - if cmp != 0 { - err = ErrInconsistentRowValue.GenWithStackByArgs(tableName, inputDatum.String(), decodedDatum.String()) - logutil.BgLogger().Error("inconsistent row value in row insertion", zap.Error(err)) - return err - } - } - return nil -} - -// collectTableMutationsFromBufferStage collects mutations of the current table from the mem buffer stage -// It returns: (1) all index mutations (2) the only row insertion -// If there are no row insertions, the 2nd returned value is nil -// If there are multiple row insertions, an error is returned -func collectTableMutationsFromBufferStage(t *TableCommon, memBuffer kv.MemBuffer, sh kv.StagingHandle) ( - []mutation, mutation, error, -) { - indexMutations := make([]mutation, 0) - var rowInsertion mutation - var err error - inspector := func(key kv.Key, flags kv.KeyFlags, data []byte) { - // only check the current table - if tablecodec.DecodeTableID(key) == t.physicalTableID { - m := mutation{key, flags, data, 0} - if rowcodec.IsRowKey(key) { - if len(data) > 0 { - if rowInsertion.key == nil { - rowInsertion = m - } else { - err = errors.Errorf( - "multiple row mutations added/mutated, one = %+v, another = %+v", rowInsertion, m, - ) - } - } - } else { - _, m.indexID, _, err = tablecodec.DecodeIndexKey(m.key) - if err != nil { - err = errors.Trace(err) - } - indexMutations = append(indexMutations, m) - } - } - } - memBuffer.InspectStage(sh, inspector) - return indexMutations, rowInsertion, err -} - -// compareIndexData compares the decoded index data with the input data. -// Returns error if the index data is not a subset of the input data. -func compareIndexData( - tc types.Context, cols []*table.Column, indexData, input []types.Datum, indexInfo *model.IndexInfo, - tableInfo *model.TableInfo, -) error { - for i := range indexData { - decodedMutationDatum := indexData[i] - expectedDatum := input[indexInfo.Columns[i].Offset] - - tablecodec.TruncateIndexValue( - &expectedDatum, indexInfo.Columns[i], - cols[indexInfo.Columns[i].Offset].ColumnInfo, - ) - tablecodec.TruncateIndexValue( - &decodedMutationDatum, indexInfo.Columns[i], - cols[indexInfo.Columns[i].Offset].ColumnInfo, - ) - - comparison, err := CompareIndexAndVal(tc, expectedDatum, decodedMutationDatum, - collate.GetCollator(decodedMutationDatum.Collation()), - cols[indexInfo.Columns[i].Offset].ColumnInfo.FieldType.IsArray() && expectedDatum.Kind() == types.KindMysqlJSON) - if err != nil { - return errors.Trace(err) - } - - if comparison != 0 { - err = ErrInconsistentIndexedValue.GenWithStackByArgs( - tableInfo.Name.O, indexInfo.Name.O, cols[indexInfo.Columns[i].Offset].ColumnInfo.Name.O, - decodedMutationDatum.String(), expectedDatum.String(), - ) - logutil.BgLogger().Error("inconsistent indexed value in index insertion", zap.Error(err)) - return err - } - } - return nil -} - -// CompareIndexAndVal compare index valued and row value. -func CompareIndexAndVal(tc types.Context, rowVal types.Datum, idxVal types.Datum, collator collate.Collator, cmpMVIndex bool) (int, error) { - var cmpRes int - var err error - if cmpMVIndex { - // If it is multi-valued index, we should check the JSON contains the indexed value. 
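		// For example (values are illustrative): with a multi-valued index over a JSON
		// array column, a row value of [1, 2, 3] is treated as consistent with an indexed
		// value of 2, because the loop below compares the indexed datum against each
		// array element using the binary collator and stops at the first equal element.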
- bj := rowVal.GetMysqlJSON() - count := bj.GetElemCount() - for elemIdx := 0; elemIdx < count; elemIdx++ { - jsonDatum := types.NewJSONDatum(bj.ArrayGetElem(elemIdx)) - cmpRes, err = jsonDatum.Compare(tc, &idxVal, collate.GetBinaryCollator()) - if err != nil { - return 0, errors.Trace(err) - } - if cmpRes == 0 { - break - } - } - } else { - cmpRes, err = idxVal.Compare(tc, &rowVal, collator) - } - return cmpRes, err -} - -// getColumnMaps tries to get the columnMaps from transaction options. If there isn't one, it builds one and stores it. -// It saves redundant computations of the map. -func getColumnMaps(txn kv.Transaction, t *TableCommon) columnMaps { - getter := func() (map[int64]columnMaps, bool) { - m, ok := txn.GetOption(kv.TableToColumnMaps).(map[int64]columnMaps) - return m, ok - } - setter := func(maps map[int64]columnMaps) { - txn.SetOption(kv.TableToColumnMaps, maps) - } - columnMaps := getOrBuildColumnMaps(getter, setter, t) - return columnMaps -} - -// getOrBuildColumnMaps tries to get the columnMaps from some place. If there isn't one, it builds one and stores it. -// It saves redundant computations of the map. -func getOrBuildColumnMaps( - getter func() (map[int64]columnMaps, bool), setter func(map[int64]columnMaps), t *TableCommon, -) columnMaps { - tableMaps, ok := getter() - if !ok || tableMaps == nil { - tableMaps = make(map[int64]columnMaps) - } - maps, ok := tableMaps[t.tableID] - if !ok { - maps = columnMaps{ - make(map[int64]*model.ColumnInfo, len(t.Meta().Columns)), - make(map[int64]*types.FieldType, len(t.Meta().Columns)), - make(map[int64]*model.IndexInfo, len(t.Indices())), - make(map[int64][]rowcodec.ColInfo, len(t.Indices())), - } - - for _, col := range t.Meta().Columns { - maps.ColumnIDToInfo[col.ID] = col - maps.ColumnIDToFieldType[col.ID] = &(col.FieldType) - } - for _, index := range t.Indices() { - if index.Meta().Primary && t.meta.IsCommonHandle { - continue - } - maps.IndexIDToInfo[index.Meta().ID] = index.Meta() - maps.IndexIDToRowColInfos[index.Meta().ID] = BuildRowcodecColInfoForIndexColumns(index.Meta(), t.Meta()) - } - - tableMaps[t.tableID] = maps - setter(tableMaps) - } - return maps -} - -// only used in tests -// commands is a comma separated string, each representing a type of corruptions to the mutations -// The injection depends on actual encoding rules. -func corruptMutations(t *TableCommon, txn kv.Transaction, sh kv.StagingHandle, cmds string) error { - commands := strings.Split(cmds, ",") - memBuffer := txn.GetMemBuffer() - - indexMutations, _, err := collectTableMutationsFromBufferStage(t, memBuffer, sh) - if err != nil { - return errors.Trace(err) - } - - for _, cmd := range commands { - switch cmd { - case "extraIndex": - // an extra index mutation - { - if len(indexMutations) == 0 { - continue - } - indexMutation := indexMutations[0] - key := make([]byte, len(indexMutation.key)) - copy(key, indexMutation.key) - key[len(key)-1]++ - if len(indexMutation.value) == 0 { - if err := memBuffer.Delete(key); err != nil { - return errors.Trace(err) - } - } else { - if err := memBuffer.Set(key, indexMutation.value); err != nil { - return errors.Trace(err) - } - } - } - case "missingIndex": - // an index mutation is missing - // "missIndex" should be placed in front of "extraIndex"es, - // in case it removes the mutation that was just added - { - indexMutation := indexMutations[0] - memBuffer.RemoveFromBuffer(indexMutation.key) - } - case "corruptIndexKey": - // a corrupted index mutation. 
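			// Concretely, the block below removes the original index key from memBuffer
			// (memBuffer.RemoveFromBuffer) and re-adds the mutation under a key whose last
			// byte is incremented, so the mutation no longer matches any real index entry.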
- // TODO: distinguish which part is corrupted, value or handle - { - indexMutation := indexMutations[0] - key := indexMutation.key - memBuffer.RemoveFromBuffer(key) - key[len(key)-1]++ - if len(indexMutation.value) == 0 { - if err := memBuffer.Delete(key); err != nil { - return errors.Trace(err) - } - } else { - if err := memBuffer.Set(key, indexMutation.value); err != nil { - return errors.Trace(err) - } - } - } - case "corruptIndexValue": - // TODO: distinguish which part to corrupt, int handle, common handle, or restored data? - // It doesn't make much sense to always corrupt the last byte - { - if len(indexMutations) == 0 { - continue - } - indexMutation := indexMutations[0] - value := indexMutation.value - if len(value) > 0 { - value[len(value)-1]++ - if err := memBuffer.Set(indexMutation.key, value); err != nil { - return errors.Trace(err) - } - } - } - default: - return fmt.Errorf("unknown command to corrupt mutation: %s", cmd) - } - } - return nil -} - -func injectMutationError(t *TableCommon, txn kv.Transaction, sh kv.StagingHandle) error { - failpoint.Inject("corruptMutations", func(commands failpoint.Value) { - failpoint.Return(corruptMutations(t, txn, sh, commands.(string))) - }) - return nil -} diff --git a/pkg/table/tables/tables.go b/pkg/table/tables/tables.go index cd0bb4c0e18f6..b44056c4b2613 100644 --- a/pkg/table/tables/tables.go +++ b/pkg/table/tables/tables.go @@ -544,17 +544,17 @@ func (t *TableCommon) updateRecord(sctx table.MutateContext, h kv.Handle, oldDat return err } - if _, _err_ := failpoint.Eval(_curpkg_("updateRecordForceAssertNotExist")); _err_ == nil { + failpoint.Inject("updateRecordForceAssertNotExist", func() { // Assert the key doesn't exist while it actually exists. This is helpful to test if assertion takes effect. // Since only the first assertion takes effect, set the injected assertion before setting the correct one to // override it. if sctx.ConnectionID() != 0 { logutil.BgLogger().Info("force asserting not exist on UpdateRecord", zap.String("category", "failpoint"), zap.Uint64("startTS", txn.StartTS())) if err = txn.SetAssertion(key, kv.SetAssertNotExist); err != nil { - return err + failpoint.Return(err) } } - } + }) if t.shouldAssert(sctx.TxnAssertionLevel()) { err = txn.SetAssertion(key, kv.SetAssertExist) @@ -930,17 +930,17 @@ func (t *TableCommon) addRecord(sctx table.MutateContext, r []types.Datum, opt * return nil, err } - if _, _err_ := failpoint.Eval(_curpkg_("addRecordForceAssertExist")); _err_ == nil { + failpoint.Inject("addRecordForceAssertExist", func() { // Assert the key exists while it actually doesn't. This is helpful to test if assertion takes effect. // Since only the first assertion takes effect, set the injected assertion before setting the correct one to // override it. if sctx.ConnectionID() != 0 { logutil.BgLogger().Info("force asserting exist on AddRecord", zap.String("category", "failpoint"), zap.Uint64("startTS", txn.StartTS())) if err = txn.SetAssertion(key, kv.SetAssertExist); err != nil { - return nil, err + failpoint.Return(nil, err) } } - } + }) if setPresume && !txn.IsPessimistic() { err = txn.SetAssertion(key, kv.SetAssertUnknown) } else { @@ -1364,17 +1364,17 @@ func (t *TableCommon) removeRowData(ctx table.MutateContext, h kv.Handle) error } key := t.RecordKey(h) - if _, _err_ := failpoint.Eval(_curpkg_("removeRecordForceAssertNotExist")); _err_ == nil { + failpoint.Inject("removeRecordForceAssertNotExist", func() { // Assert the key doesn't exist while it actually exists. 
This is helpful to test if assertion takes effect. // Since only the first assertion takes effect, set the injected assertion before setting the correct one to // override it. if ctx.ConnectionID() != 0 { logutil.BgLogger().Info("force asserting not exist on RemoveRecord", zap.String("category", "failpoint"), zap.Uint64("startTS", txn.StartTS())) if err = txn.SetAssertion(key, kv.SetAssertNotExist); err != nil { - return err + failpoint.Return(err) } } - } + }) if t.shouldAssert(ctx.TxnAssertionLevel()) { err = txn.SetAssertion(key, kv.SetAssertExist) } else { diff --git a/pkg/table/tables/tables.go__failpoint_stash__ b/pkg/table/tables/tables.go__failpoint_stash__ deleted file mode 100644 index b44056c4b2613..0000000000000 --- a/pkg/table/tables/tables.go__failpoint_stash__ +++ /dev/null @@ -1,2100 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Copyright 2013 The ql Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSES/QL-LICENSE file. - -package tables - -import ( - "context" - "fmt" - "math" - "strconv" - "strings" - "sync" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/expression" - exprctx "github.com/pingcap/tidb/pkg/expression/context" - "github.com/pingcap/tidb/pkg/kv" - "github.com/pingcap/tidb/pkg/meta" - "github.com/pingcap/tidb/pkg/meta/autoid" - "github.com/pingcap/tidb/pkg/parser/ast" - "github.com/pingcap/tidb/pkg/parser/model" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/binloginfo" - "github.com/pingcap/tidb/pkg/sessionctx/variable" - "github.com/pingcap/tidb/pkg/statistics" - "github.com/pingcap/tidb/pkg/table" - tbctx "github.com/pingcap/tidb/pkg/table/context" - "github.com/pingcap/tidb/pkg/tablecodec" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/codec" - "github.com/pingcap/tidb/pkg/util/collate" - "github.com/pingcap/tidb/pkg/util/generatedexpr" - "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/stringutil" - "github.com/pingcap/tidb/pkg/util/tableutil" - "github.com/pingcap/tidb/pkg/util/tracing" - "github.com/pingcap/tipb/go-binlog" - "github.com/pingcap/tipb/go-tipb" - "go.uber.org/zap" -) - -// TableCommon is shared by both Table and partition. -// NOTE: when copying this struct, use Copy() to clear its columns cache. -type TableCommon struct { - // TODO: Why do we need tableID, when it is already in meta.ID ? - tableID int64 - // physicalTableID is a unique int64 to identify a physical table. 
- physicalTableID int64 - Columns []*table.Column - - // column caches - // They are pointers to support copying TableCommon to CachedTable and PartitionedTable - publicColumns []*table.Column - visibleColumns []*table.Column - hiddenColumns []*table.Column - writableColumns []*table.Column - fullHiddenColsAndVisibleColumns []*table.Column - writableConstraints []*table.Constraint - - indices []table.Index - meta *model.TableInfo - allocs autoid.Allocators - sequence *sequenceCommon - dependencyColumnOffsets []int - Constraints []*table.Constraint - - // recordPrefix and indexPrefix are generated using physicalTableID. - recordPrefix kv.Key - indexPrefix kv.Key -} - -// ResetColumnsCache implements testingKnob interface. -func (t *TableCommon) ResetColumnsCache() { - t.publicColumns = t.getCols(full) - t.visibleColumns = t.getCols(visible) - t.hiddenColumns = t.getCols(hidden) - - t.writableColumns = make([]*table.Column, 0, len(t.Columns)) - for _, col := range t.Columns { - if col.State == model.StateDeleteOnly || col.State == model.StateDeleteReorganization { - continue - } - t.writableColumns = append(t.writableColumns, col) - } - - t.fullHiddenColsAndVisibleColumns = make([]*table.Column, 0, len(t.Columns)) - for _, col := range t.Columns { - if col.Hidden || col.State == model.StatePublic { - t.fullHiddenColsAndVisibleColumns = append(t.fullHiddenColsAndVisibleColumns, col) - } - } - - if t.Constraints != nil { - t.writableConstraints = make([]*table.Constraint, 0, len(t.Constraints)) - for _, con := range t.Constraints { - if !con.Enforced { - continue - } - if con.State == model.StateDeleteOnly || con.State == model.StateDeleteReorganization { - continue - } - t.writableConstraints = append(t.writableConstraints, con) - } - } -} - -// Copy copies a TableCommon struct, and reset its column cache. This is not a deep copy. -func (t *TableCommon) Copy() TableCommon { - newTable := *t - return newTable -} - -// MockTableFromMeta only serves for test. -func MockTableFromMeta(tblInfo *model.TableInfo) table.Table { - columns := make([]*table.Column, 0, len(tblInfo.Columns)) - for _, colInfo := range tblInfo.Columns { - col := table.ToColumn(colInfo) - columns = append(columns, col) - } - - constraints, err := table.LoadCheckConstraint(tblInfo) - if err != nil { - return nil - } - var t TableCommon - initTableCommon(&t, tblInfo, tblInfo.ID, columns, autoid.NewAllocators(false), constraints) - if tblInfo.TableCacheStatusType != model.TableCacheStatusDisable { - ret, err := newCachedTable(&t) - if err != nil { - return nil - } - return ret - } - if tblInfo.GetPartitionInfo() == nil { - if err := initTableIndices(&t); err != nil { - return nil - } - return &t - } - - ret, err := newPartitionedTable(&t, tblInfo) - if err != nil { - return nil - } - return ret -} - -// TableFromMeta creates a Table instance from model.TableInfo. -func TableFromMeta(allocs autoid.Allocators, tblInfo *model.TableInfo) (table.Table, error) { - if tblInfo.State == model.StateNone { - return nil, table.ErrTableStateCantNone.GenWithStackByArgs(tblInfo.Name) - } - - colsLen := len(tblInfo.Columns) - columns := make([]*table.Column, 0, colsLen) - for i, colInfo := range tblInfo.Columns { - if colInfo.State == model.StateNone { - return nil, table.ErrColumnStateCantNone.GenWithStackByArgs(colInfo.Name) - } - - // Print some information when the column's offset isn't equal to i. 
- if colInfo.Offset != i { - logutil.BgLogger().Error("wrong table schema", zap.Any("table", tblInfo), zap.Any("column", colInfo), zap.Int("index", i), zap.Int("offset", colInfo.Offset), zap.Int("columnNumber", colsLen)) - } - - col := table.ToColumn(colInfo) - if col.IsGenerated() { - genStr := colInfo.GeneratedExprString - expr, err := buildGeneratedExpr(tblInfo, genStr) - if err != nil { - return nil, err - } - col.GeneratedExpr = table.NewClonableExprNode(func() ast.ExprNode { - newExpr, err1 := buildGeneratedExpr(tblInfo, genStr) - if err1 != nil { - logutil.BgLogger().Warn("unexpected parse generated string error", - zap.String("generatedStr", genStr), - zap.Error(err1)) - return expr - } - return newExpr - }, expr) - } - // default value is expr. - if col.DefaultIsExpr { - expr, err := generatedexpr.ParseExpression(colInfo.DefaultValue.(string)) - if err != nil { - return nil, err - } - col.DefaultExpr = expr - } - columns = append(columns, col) - } - - constraints, err := table.LoadCheckConstraint(tblInfo) - if err != nil { - return nil, err - } - var t TableCommon - initTableCommon(&t, tblInfo, tblInfo.ID, columns, allocs, constraints) - if tblInfo.GetPartitionInfo() == nil { - if err := initTableIndices(&t); err != nil { - return nil, err - } - if tblInfo.TableCacheStatusType != model.TableCacheStatusDisable { - return newCachedTable(&t) - } - return &t, nil - } - return newPartitionedTable(&t, tblInfo) -} - -func buildGeneratedExpr(tblInfo *model.TableInfo, genExpr string) (ast.ExprNode, error) { - expr, err := generatedexpr.ParseExpression(genExpr) - if err != nil { - return nil, err - } - expr, err = generatedexpr.SimpleResolveName(expr, tblInfo) - if err != nil { - return nil, err - } - return expr, nil -} - -// initTableCommon initializes a TableCommon struct. -func initTableCommon(t *TableCommon, tblInfo *model.TableInfo, physicalTableID int64, cols []*table.Column, allocs autoid.Allocators, constraints []*table.Constraint) { - t.tableID = tblInfo.ID - t.physicalTableID = physicalTableID - t.allocs = allocs - t.meta = tblInfo - t.Columns = cols - t.Constraints = constraints - t.recordPrefix = tablecodec.GenTableRecordPrefix(physicalTableID) - t.indexPrefix = tablecodec.GenTableIndexPrefix(physicalTableID) - if tblInfo.IsSequence() { - t.sequence = &sequenceCommon{meta: tblInfo.Sequence} - } - for _, col := range cols { - if col.ChangeStateInfo != nil { - t.dependencyColumnOffsets = append(t.dependencyColumnOffsets, col.ChangeStateInfo.DependencyColumnOffset) - } - } - t.ResetColumnsCache() -} - -// initTableIndices initializes the indices of the TableCommon. -func initTableIndices(t *TableCommon) error { - tblInfo := t.meta - for _, idxInfo := range tblInfo.Indices { - if idxInfo.State == model.StateNone { - return table.ErrIndexStateCantNone.GenWithStackByArgs(idxInfo.Name) - } - - // Use partition ID for index, because TableCommon may be table or partition. - idx := NewIndex(t.physicalTableID, tblInfo, idxInfo) - intest.AssertFunc(func() bool { - // `TableCommon.indices` is type of `[]table.Index` to implement interface method `Table.Indices`. - // However, we have an assumption that the specific type of each element in it should always be `*index`. - // We have this assumption because some codes access the inner method of `*index`, - // and they use `asIndex` to cast `table.Index` to `*index`. 
- _, ok := idx.(*index) - intest.Assert(ok, "index should be type of `*index`") - return true - }) - t.indices = append(t.indices, idx) - } - return nil -} - -// asIndex casts a table.Index to *index which is the actual type of index in TableCommon. -func asIndex(idx table.Index) *index { - return idx.(*index) -} - -func initTableCommonWithIndices(t *TableCommon, tblInfo *model.TableInfo, physicalTableID int64, cols []*table.Column, allocs autoid.Allocators, constraints []*table.Constraint) error { - initTableCommon(t, tblInfo, physicalTableID, cols, allocs, constraints) - return initTableIndices(t) -} - -// Indices implements table.Table Indices interface. -func (t *TableCommon) Indices() []table.Index { - return t.indices -} - -// GetWritableIndexByName gets the index meta from the table by the index name. -func GetWritableIndexByName(idxName string, t table.Table) table.Index { - for _, idx := range t.Indices() { - if !IsIndexWritable(idx) { - continue - } - if idxName == idx.Meta().Name.L { - return idx - } - } - return nil -} - -// deletableIndices implements table.Table deletableIndices interface. -func (t *TableCommon) deletableIndices() []table.Index { - // All indices are deletable because we don't need to check StateNone. - return t.indices -} - -// Meta implements table.Table Meta interface. -func (t *TableCommon) Meta() *model.TableInfo { - return t.meta -} - -// GetPhysicalID implements table.Table GetPhysicalID interface. -func (t *TableCommon) GetPhysicalID() int64 { - return t.physicalTableID -} - -// GetPartitionedTable implements table.Table GetPhysicalID interface. -func (t *TableCommon) GetPartitionedTable() table.PartitionedTable { - return nil -} - -type getColsMode int64 - -const ( - _ getColsMode = iota - visible - hidden - full -) - -func (t *TableCommon) getCols(mode getColsMode) []*table.Column { - columns := make([]*table.Column, 0, len(t.Columns)) - for _, col := range t.Columns { - if col.State != model.StatePublic { - continue - } - if (mode == visible && col.Hidden) || (mode == hidden && !col.Hidden) { - continue - } - columns = append(columns, col) - } - return columns -} - -// Cols implements table.Table Cols interface. -func (t *TableCommon) Cols() []*table.Column { - return t.publicColumns -} - -// VisibleCols implements table.Table VisibleCols interface. -func (t *TableCommon) VisibleCols() []*table.Column { - return t.visibleColumns -} - -// HiddenCols implements table.Table HiddenCols interface. -func (t *TableCommon) HiddenCols() []*table.Column { - return t.hiddenColumns -} - -// WritableCols implements table WritableCols interface. -func (t *TableCommon) WritableCols() []*table.Column { - return t.writableColumns -} - -// DeletableCols implements table DeletableCols interface. -func (t *TableCommon) DeletableCols() []*table.Column { - return t.Columns -} - -// WritableConstraint returns constraints of the table in writable states. -func (t *TableCommon) WritableConstraint() []*table.Constraint { - if t.Constraints == nil { - return nil - } - return t.writableConstraints -} - -// FullHiddenColsAndVisibleCols implements table FullHiddenColsAndVisibleCols interface. -func (t *TableCommon) FullHiddenColsAndVisibleCols() []*table.Column { - return t.fullHiddenColsAndVisibleColumns -} - -// RecordPrefix implements table.Table interface. -func (t *TableCommon) RecordPrefix() kv.Key { - return t.recordPrefix -} - -// IndexPrefix implements table.Table interface. 
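// Like RecordPrefix, this prefix is derived from physicalTableID in initTableCommon
// (via tablecodec.GenTableIndexPrefix), so for a partition it addresses the
// partition's key range rather than the logical table's.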
-func (t *TableCommon) IndexPrefix() kv.Key { - return t.indexPrefix -} - -// RecordKey implements table.Table interface. -func (t *TableCommon) RecordKey(h kv.Handle) kv.Key { - return tablecodec.EncodeRecordKey(t.recordPrefix, h) -} - -// shouldAssert checks if the partition should be in consistent -// state and can have assertion. -func (t *TableCommon) shouldAssert(level variable.AssertionLevel) bool { - p := t.Meta().Partition - if p != nil { - // This disables asserting during Reorganize Partition. - switch level { - case variable.AssertionLevelFast: - // Fast option, just skip assertion for all partitions. - if p.DDLState != model.StateNone && p.DDLState != model.StatePublic { - return false - } - case variable.AssertionLevelStrict: - // Strict, only disable assertion for intermediate partitions. - // If there were an easy way to get from a TableCommon back to the partitioned table... - for i := range p.AddingDefinitions { - if t.physicalTableID == p.AddingDefinitions[i].ID { - return false - } - } - } - } - return true -} - -// UpdateRecord implements table.Table UpdateRecord interface. -// `touched` means which columns are really modified, used for secondary indices. -// Length of `oldData` and `newData` equals to length of `t.WritableCols()`. -func (t *TableCommon) UpdateRecord(ctx table.MutateContext, h kv.Handle, oldData, newData []types.Datum, touched []bool, opts ...table.UpdateRecordOption) error { - opt := table.NewUpdateRecordOpt(opts...) - return t.updateRecord(ctx, h, oldData, newData, touched, opt) -} - -func (t *TableCommon) updateRecord(sctx table.MutateContext, h kv.Handle, oldData, newData []types.Datum, touched []bool, opt *table.UpdateRecordOpt) error { - txn, err := sctx.Txn(true) - if err != nil { - return err - } - - memBuffer := txn.GetMemBuffer() - sh := memBuffer.Staging() - defer memBuffer.Cleanup(sh) - - if m := t.Meta(); m.TempTableType != model.TempTableNone { - if tmpTable, sizeLimit, ok := addTemporaryTable(sctx, m); ok { - if err = checkTempTableSize(tmpTable, sizeLimit); err != nil { - return err - } - defer handleTempTableSize(tmpTable, txn.Size(), txn) - } - } - - var binlogColIDs []int64 - var binlogOldRow, binlogNewRow []types.Datum - numColsCap := len(newData) + 1 // +1 for the extra handle column that we may need to append. - - // a reusable buffer to save malloc - // Note: The buffer should not be referenced or modified outside this function. - // It can only act as a temporary buffer for the current function call. - mutateBuffers := sctx.GetMutateBuffers() - encodeRowBuffer := mutateBuffers.GetEncodeRowBufferWithCap(numColsCap) - checkRowBuffer := mutateBuffers.GetCheckRowBufferWithCap(numColsCap) - binlogSupport, shouldWriteBinlog := getBinlogSupport(sctx, t.meta) - if shouldWriteBinlog { - binlogColIDs = make([]int64, 0, numColsCap) - binlogOldRow = make([]types.Datum, 0, numColsCap) - binlogNewRow = make([]types.Datum, 0, numColsCap) - } - - for _, col := range t.Columns { - var value types.Datum - if col.State == model.StateDeleteOnly || col.State == model.StateDeleteReorganization { - if col.ChangeStateInfo != nil { - // TODO: Check overflow or ignoreTruncate. 
- value, err = table.CastColumnValue(sctx.GetExprCtx(), oldData[col.DependencyColumnOffset], col.ColumnInfo, false, false) - if err != nil { - logutil.BgLogger().Info("update record cast value failed", zap.Any("col", col), zap.Uint64("txnStartTS", txn.StartTS()), - zap.String("handle", h.String()), zap.Any("val", oldData[col.DependencyColumnOffset]), zap.Error(err)) - return err - } - oldData = append(oldData, value) - touched = append(touched, touched[col.DependencyColumnOffset]) - } - continue - } - if col.State != model.StatePublic { - // If col is in write only or write reorganization state we should keep the oldData. - // Because the oldData must be the original data(it's changed by other TiDBs.) or the original default value. - // TODO: Use newData directly. - value = oldData[col.Offset] - if col.ChangeStateInfo != nil { - // TODO: Check overflow or ignoreTruncate. - value, err = table.CastColumnValue(sctx.GetExprCtx(), newData[col.DependencyColumnOffset], col.ColumnInfo, false, false) - if err != nil { - return err - } - newData[col.Offset] = value - touched[col.Offset] = touched[col.DependencyColumnOffset] - } - } else { - value = newData[col.Offset] - } - if !t.canSkip(col, &value) { - encodeRowBuffer.AddColVal(col.ID, value) - } - checkRowBuffer.AddColVal(value) - if shouldWriteBinlog && !t.canSkipUpdateBinlog(col, value) { - binlogColIDs = append(binlogColIDs, col.ID) - binlogOldRow = append(binlogOldRow, oldData[col.Offset]) - binlogNewRow = append(binlogNewRow, value) - } - } - // check data constraint - evalCtx := sctx.GetExprCtx().GetEvalCtx() - if constraints := t.WritableConstraint(); len(constraints) > 0 { - if err = table.CheckRowConstraint(evalCtx, constraints, checkRowBuffer.GetRowToCheck()); err != nil { - return err - } - } - // rebuild index - err = t.rebuildUpdateRecordIndices(sctx, txn, h, touched, oldData, newData, opt) - if err != nil { - return err - } - - key := t.RecordKey(h) - tc, ec := evalCtx.TypeCtx(), evalCtx.ErrCtx() - err = encodeRowBuffer.WriteMemBufferEncoded(sctx.GetRowEncodingConfig(), tc.Location(), ec, memBuffer, key) - if err != nil { - return err - } - - failpoint.Inject("updateRecordForceAssertNotExist", func() { - // Assert the key doesn't exist while it actually exists. This is helpful to test if assertion takes effect. - // Since only the first assertion takes effect, set the injected assertion before setting the correct one to - // override it. 
- if sctx.ConnectionID() != 0 { - logutil.BgLogger().Info("force asserting not exist on UpdateRecord", zap.String("category", "failpoint"), zap.Uint64("startTS", txn.StartTS())) - if err = txn.SetAssertion(key, kv.SetAssertNotExist); err != nil { - failpoint.Return(err) - } - } - }) - - if t.shouldAssert(sctx.TxnAssertionLevel()) { - err = txn.SetAssertion(key, kv.SetAssertExist) - } else { - err = txn.SetAssertion(key, kv.SetAssertUnknown) - } - if err != nil { - return err - } - - if err = injectMutationError(t, txn, sh); err != nil { - return err - } - if sctx.EnableMutationChecker() { - if err = CheckDataConsistency(txn, tc, t, newData, oldData, memBuffer, sh); err != nil { - return errors.Trace(err) - } - } - - memBuffer.Release(sh) - if shouldWriteBinlog { - if !t.meta.PKIsHandle && !t.meta.IsCommonHandle { - binlogColIDs = append(binlogColIDs, model.ExtraHandleID) - binlogOldRow = append(binlogOldRow, types.NewIntDatum(h.IntValue())) - binlogNewRow = append(binlogNewRow, types.NewIntDatum(h.IntValue())) - } - err = t.addUpdateBinlog(sctx, binlogSupport, binlogOldRow, binlogNewRow, binlogColIDs) - if err != nil { - return err - } - } - - if s, ok := sctx.GetStatisticsSupport(); ok { - colSizeBuffer := mutateBuffers.GetColSizeDeltaBufferWithCap(len(t.Cols())) - for id, col := range t.Cols() { - size, err := codec.EstimateValueSize(tc, newData[id]) - if err != nil { - continue - } - newLen := size - 1 - size, err = codec.EstimateValueSize(tc, oldData[id]) - if err != nil { - continue - } - oldLen := size - 1 - colSizeBuffer.AddColSizeDelta(col.ID, int64(newLen-oldLen)) - } - s.UpdatePhysicalTableDelta(t.physicalTableID, 0, 1, colSizeBuffer) - } - return nil -} - -func (t *TableCommon) rebuildUpdateRecordIndices( - ctx table.MutateContext, txn kv.Transaction, - h kv.Handle, touched []bool, oldData []types.Datum, newData []types.Datum, - opt *table.UpdateRecordOpt, -) error { - for _, idx := range t.deletableIndices() { - if t.meta.IsCommonHandle && idx.Meta().Primary { - continue - } - for _, ic := range idx.Meta().Columns { - if !touched[ic.Offset] { - continue - } - oldVs, err := idx.FetchValues(oldData, nil) - if err != nil { - return err - } - if err = t.removeRowIndex(ctx, h, oldVs, idx, txn); err != nil { - return err - } - break - } - } - createIdxOpt := opt.GetCreateIdxOpt() - for _, idx := range t.Indices() { - if !IsIndexWritable(idx) { - continue - } - if t.meta.IsCommonHandle && idx.Meta().Primary { - continue - } - untouched := true - for _, ic := range idx.Meta().Columns { - if !touched[ic.Offset] { - continue - } - untouched = false - break - } - if untouched && opt.SkipWriteUntouchedIndices { - continue - } - newVs, err := idx.FetchValues(newData, nil) - if err != nil { - return err - } - if err := t.buildIndexForRow(ctx, h, newVs, newData, asIndex(idx), txn, untouched, createIdxOpt); err != nil { - return err - } - } - return nil -} - -// FindPrimaryIndex uses to find primary index in tableInfo. -func FindPrimaryIndex(tblInfo *model.TableInfo) *model.IndexInfo { - var pkIdx *model.IndexInfo - for _, idx := range tblInfo.Indices { - if idx.Primary { - pkIdx = idx - break - } - } - return pkIdx -} - -// TryGetCommonPkColumnIds get the IDs of primary key column if the table has common handle. 
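// It returns nil for tables that do not use a common handle (for example, tables
// whose primary key is the integer row handle).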
-func TryGetCommonPkColumnIds(tbl *model.TableInfo) []int64 { - if !tbl.IsCommonHandle { - return nil - } - pkIdx := FindPrimaryIndex(tbl) - pkColIDs := make([]int64, 0, len(pkIdx.Columns)) - for _, idxCol := range pkIdx.Columns { - pkColIDs = append(pkColIDs, tbl.Columns[idxCol.Offset].ID) - } - return pkColIDs -} - -// PrimaryPrefixColumnIDs get prefix column ids in primary key. -func PrimaryPrefixColumnIDs(tbl *model.TableInfo) (prefixCols []int64) { - for _, idx := range tbl.Indices { - if !idx.Primary { - continue - } - for _, col := range idx.Columns { - if col.Length > 0 && tbl.Columns[col.Offset].GetFlen() > col.Length { - prefixCols = append(prefixCols, tbl.Columns[col.Offset].ID) - } - } - } - return -} - -// TryGetCommonPkColumns get the primary key columns if the table has common handle. -func TryGetCommonPkColumns(tbl table.Table) []*table.Column { - if !tbl.Meta().IsCommonHandle { - return nil - } - pkIdx := FindPrimaryIndex(tbl.Meta()) - cols := tbl.Cols() - pkCols := make([]*table.Column, 0, len(pkIdx.Columns)) - for _, idxCol := range pkIdx.Columns { - pkCols = append(pkCols, cols[idxCol.Offset]) - } - return pkCols -} - -func addTemporaryTable(sctx table.MutateContext, tblInfo *model.TableInfo) (tbctx.TemporaryTableHandler, int64, bool) { - if s, ok := sctx.GetTemporaryTableSupport(); ok { - if h, ok := s.AddTemporaryTableToTxn(tblInfo); ok { - return h, s.GetTemporaryTableSizeLimit(), ok - } - } - return tbctx.TemporaryTableHandler{}, 0, false -} - -// The size of a temporary table is calculated by accumulating the transaction size delta. -func handleTempTableSize(t tbctx.TemporaryTableHandler, txnSizeBefore int, txn kv.Transaction) { - t.UpdateTxnDeltaSize(txn.Size() - txnSizeBefore) -} - -func checkTempTableSize(tmpTable tbctx.TemporaryTableHandler, sizeLimit int64) error { - if tmpTable.GetCommittedSize()+tmpTable.GetDirtySize() > sizeLimit { - return table.ErrTempTableFull.GenWithStackByArgs(tmpTable.Meta().Name.O) - } - return nil -} - -// AddRecord implements table.Table AddRecord interface. -func (t *TableCommon) AddRecord(sctx table.MutateContext, r []types.Datum, opts ...table.AddRecordOption) (recordID kv.Handle, err error) { - // TODO: optimize the allocation (and calculation) of opt. - opt := table.NewAddRecordOpt(opts...) - return t.addRecord(sctx, r, opt) -} - -func (t *TableCommon) addRecord(sctx table.MutateContext, r []types.Datum, opt *table.AddRecordOpt) (recordID kv.Handle, err error) { - txn, err := sctx.Txn(true) - if err != nil { - return nil, err - } - if m := t.Meta(); m.TempTableType != model.TempTableNone { - if tmpTable, sizeLimit, ok := addTemporaryTable(sctx, m); ok { - if err = checkTempTableSize(tmpTable, sizeLimit); err != nil { - return nil, err - } - defer handleTempTableSize(tmpTable, txn.Size(), txn) - } - } - - var ctx context.Context - if opt.Ctx != nil { - ctx = opt.Ctx - var r tracing.Region - r, ctx = tracing.StartRegionEx(ctx, "table.AddRecord") - defer r.End() - } else { - ctx = context.Background() - } - - evalCtx := sctx.GetExprCtx().GetEvalCtx() - tc, ec := evalCtx.TypeCtx(), evalCtx.ErrCtx() - - var hasRecordID bool - cols := t.Cols() - // opt.IsUpdate is a flag for update. - // If handle ID is changed when update, update will remove the old record first, and then call `AddRecord` to add a new record. - // Currently, only insert can set _tidb_rowid, update can not update _tidb_rowid. - if len(r) > len(cols) && !opt.IsUpdate { - // The last value is _tidb_rowid. 
- recordID = kv.IntHandle(r[len(r)-1].GetInt64()) - hasRecordID = true - } else { - tblInfo := t.Meta() - txn.CacheTableInfo(t.physicalTableID, tblInfo) - if tblInfo.PKIsHandle { - recordID = kv.IntHandle(r[tblInfo.GetPkColInfo().Offset].GetInt64()) - hasRecordID = true - } else if tblInfo.IsCommonHandle { - pkIdx := FindPrimaryIndex(tblInfo) - pkDts := make([]types.Datum, 0, len(pkIdx.Columns)) - for _, idxCol := range pkIdx.Columns { - pkDts = append(pkDts, r[idxCol.Offset]) - } - tablecodec.TruncateIndexValues(tblInfo, pkIdx, pkDts) - var handleBytes []byte - handleBytes, err = codec.EncodeKey(tc.Location(), nil, pkDts...) - err = ec.HandleError(err) - if err != nil { - return - } - recordID, err = kv.NewCommonHandle(handleBytes) - if err != nil { - return - } - hasRecordID = true - } - } - if !hasRecordID { - if opt.ReserveAutoID > 0 { - // Reserve a batch of auto ID in the statement context. - // The reserved ID could be used in the future within this statement, by the - // following AddRecord() operation. - // Make the IDs continuous benefit for the performance of TiKV. - if reserved, ok := sctx.GetReservedRowIDAlloc(); ok { - var baseRowID, maxRowID int64 - if baseRowID, maxRowID, err = AllocHandleIDs(ctx, sctx, t, uint64(opt.ReserveAutoID)); err != nil { - return nil, err - } - reserved.Reset(baseRowID, maxRowID) - } - } - - recordID, err = AllocHandle(ctx, sctx, t) - if err != nil { - return nil, err - } - } - - // a reusable buffer to save malloc - // Note: The buffer should not be referenced or modified outside this function. - // It can only act as a temporary buffer for the current function call. - mutateBuffers := sctx.GetMutateBuffers() - encodeRowBuffer := mutateBuffers.GetEncodeRowBufferWithCap(len(r)) - memBuffer := txn.GetMemBuffer() - sh := memBuffer.Staging() - defer memBuffer.Cleanup(sh) - - sessVars := sctx.GetSessionVars() - for _, col := range t.Columns { - var value types.Datum - if col.State == model.StateDeleteOnly || col.State == model.StateDeleteReorganization { - continue - } - // In column type change, since we have set the origin default value for changing col, but - // for the new insert statement, we should use the casted value of relative column to insert. - if col.ChangeStateInfo != nil && col.State != model.StatePublic { - // TODO: Check overflow or ignoreTruncate. - value, err = table.CastColumnValue(sctx.GetExprCtx(), r[col.DependencyColumnOffset], col.ColumnInfo, false, false) - if err != nil { - return nil, err - } - if len(r) < len(t.WritableCols()) { - r = append(r, value) - } else { - r[col.Offset] = value - } - encodeRowBuffer.AddColVal(col.ID, value) - continue - } - if col.State == model.StatePublic { - value = r[col.Offset] - } else { - // col.ChangeStateInfo must be nil here. - // because `col.State != model.StatePublic` is true here, if col.ChangeStateInfo is not nil, the col should - // be handle by the previous if-block. - - if opt.IsUpdate { - // If `AddRecord` is called by an update, the default value should be handled the update. - value = r[col.Offset] - } else { - // If `AddRecord` is called by an insert and the col is in write only or write reorganization state, we must - // add it with its default value. - value, err = table.GetColOriginDefaultValue(sctx.GetExprCtx(), col.ToInfo()) - if err != nil { - return nil, err - } - // add value to `r` for dirty db in transaction. - // Otherwise when update will panic cause by get value of column in write only state from dirty db. 
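			// In other words, keep r padded out to the write-only column's offset, so that a
			// later update in the same transaction, which reads this write-only column back
			// from the dirty buffer, finds the default value instead of panicking.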
- if col.Offset < len(r) { - r[col.Offset] = value - } else { - r = append(r, value) - } - } - } - if !t.canSkip(col, &value) { - encodeRowBuffer.AddColVal(col.ID, value) - } - } - // check data constraint - if err = table.CheckRowConstraintWithDatum(evalCtx, t.WritableConstraint(), r); err != nil { - return nil, err - } - key := t.RecordKey(recordID) - var setPresume bool - if opt.DupKeyCheck != table.DupKeyCheckSkip { - if t.meta.TempTableType != model.TempTableNone { - // Always check key for temporary table because it does not write to TiKV - _, err = txn.Get(ctx, key) - } else if sctx.GetSessionVars().LazyCheckKeyNotExists() || txn.IsPipelined() { - var v []byte - v, err = txn.GetMemBuffer().GetLocal(ctx, key) - if err != nil { - setPresume = true - } - if err == nil && len(v) == 0 { - err = kv.ErrNotExist - } - } else { - _, err = txn.Get(ctx, key) - } - if err == nil { - dupErr := getDuplicateError(t.Meta(), recordID, r) - return recordID, dupErr - } else if !kv.ErrNotExist.Equal(err) { - return recordID, err - } - } - - var flags []kv.FlagsOp - if setPresume { - flags = []kv.FlagsOp{kv.SetPresumeKeyNotExists} - if !sessVars.ConstraintCheckInPlacePessimistic && sessVars.TxnCtx.IsPessimistic && sessVars.InTxn() && - !sctx.InRestrictedSQL() && sctx.ConnectionID() > 0 { - flags = append(flags, kv.SetNeedConstraintCheckInPrewrite) - } - } - - err = encodeRowBuffer.WriteMemBufferEncoded(sctx.GetRowEncodingConfig(), tc.Location(), ec, memBuffer, key, flags...) - if err != nil { - return nil, err - } - - failpoint.Inject("addRecordForceAssertExist", func() { - // Assert the key exists while it actually doesn't. This is helpful to test if assertion takes effect. - // Since only the first assertion takes effect, set the injected assertion before setting the correct one to - // override it. - if sctx.ConnectionID() != 0 { - logutil.BgLogger().Info("force asserting exist on AddRecord", zap.String("category", "failpoint"), zap.Uint64("startTS", txn.StartTS())) - if err = txn.SetAssertion(key, kv.SetAssertExist); err != nil { - failpoint.Return(nil, err) - } - } - }) - if setPresume && !txn.IsPessimistic() { - err = txn.SetAssertion(key, kv.SetAssertUnknown) - } else { - err = txn.SetAssertion(key, kv.SetAssertNotExist) - } - if err != nil { - return nil, err - } - - // Insert new entries into indices. - h, err := t.addIndices(sctx, recordID, r, txn, opt.GetCreateIdxOpt()) - if err != nil { - return h, err - } - - if err = injectMutationError(t, txn, sh); err != nil { - return nil, err - } - if sctx.EnableMutationChecker() { - if err = CheckDataConsistency(txn, tc, t, r, nil, memBuffer, sh); err != nil { - return nil, errors.Trace(err) - } - } - - memBuffer.Release(sh) - - binlogSupport, shouldWriteBinlog := getBinlogSupport(sctx, t.meta) - if shouldWriteBinlog { - // For insert, TiDB and Binlog can use same row and schema. - err = t.addInsertBinlog(sctx, binlogSupport, recordID, encodeRowBuffer) - if err != nil { - return nil, err - } - } - - if s, ok := sctx.GetStatisticsSupport(); ok { - colSizeBuffer := sctx.GetMutateBuffers().GetColSizeDeltaBufferWithCap(len(t.Cols())) - for id, col := range t.Cols() { - size, err := codec.EstimateValueSize(tc, r[id]) - if err != nil { - continue - } - colSizeBuffer.AddColSizeDelta(col.ID, int64(size-1)) - } - s.UpdatePhysicalTableDelta(t.physicalTableID, 1, 1, colSizeBuffer) - } - return recordID, nil -} - -// genIndexKeyStrs generates index content strings representation. 
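// NULL datums are rendered as the literal string "NULL"; the resulting strings are
// only used to build duplicate-key error messages (see the unique-key check in addIndices).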
-func genIndexKeyStrs(colVals []types.Datum) ([]string, error) { - // Pass pre-composed error to txn. - strVals := make([]string, 0, len(colVals)) - for _, cv := range colVals { - cvs := "NULL" - var err error - if !cv.IsNull() { - cvs, err = types.ToString(cv.GetValue()) - if err != nil { - return nil, err - } - } - strVals = append(strVals, cvs) - } - return strVals, nil -} - -// addIndices adds data into indices. If any key is duplicated, returns the original handle. -func (t *TableCommon) addIndices(sctx table.MutateContext, recordID kv.Handle, r []types.Datum, txn kv.Transaction, opt *table.CreateIdxOpt) (kv.Handle, error) { - writeBufs := sctx.GetMutateBuffers().GetWriteStmtBufs() - indexVals := writeBufs.IndexValsBuf - skipCheck := opt.DupKeyCheck == table.DupKeyCheckSkip - for _, v := range t.Indices() { - if !IsIndexWritable(v) { - continue - } - if t.meta.IsCommonHandle && v.Meta().Primary { - continue - } - // We declared `err` here to make sure `indexVals` is assigned with `=` instead of `:=`. - // The latter one will create a new variable that shadows the outside `indexVals` that makes `indexVals` outside - // always nil, and we cannot reuse it. - var err error - indexVals, err = v.FetchValues(r, indexVals) - if err != nil { - return nil, err - } - var dupErr error - if !skipCheck && v.Meta().Unique { - // Make error message consistent with MySQL. - tablecodec.TruncateIndexValues(t.meta, v.Meta(), indexVals) - colStrVals, err := genIndexKeyStrs(indexVals) - if err != nil { - return nil, err - } - dupErr = kv.GenKeyExistsErr(colStrVals, fmt.Sprintf("%s.%s", v.TableMeta().Name.String(), v.Meta().Name.String())) - } - rsData := TryGetHandleRestoredDataWrapper(t.meta, r, nil, v.Meta()) - if dupHandle, err := asIndex(v).create(sctx, txn, indexVals, recordID, rsData, false, opt); err != nil { - if kv.ErrKeyExists.Equal(err) { - return dupHandle, dupErr - } - return nil, err - } - } - // save the buffer, multi rows insert can use it. - writeBufs.IndexValsBuf = indexVals - return nil, nil -} - -// RowWithCols is used to get the corresponding column datum values with the given handle. -func RowWithCols(t table.Table, ctx sessionctx.Context, h kv.Handle, cols []*table.Column) ([]types.Datum, error) { - // Get raw row data from kv. - key := tablecodec.EncodeRecordKey(t.RecordPrefix(), h) - txn, err := ctx.Txn(true) - if err != nil { - return nil, err - } - value, err := txn.Get(context.TODO(), key) - if err != nil { - return nil, err - } - v, _, err := DecodeRawRowData(ctx, t.Meta(), h, cols, value) - if err != nil { - return nil, err - } - return v, nil -} - -func containFullColInHandle(meta *model.TableInfo, col *table.Column) (containFullCol bool, idxInHandle int) { - pkIdx := FindPrimaryIndex(meta) - for i, idxCol := range pkIdx.Columns { - if meta.Columns[idxCol.Offset].ID == col.ID { - idxInHandle = i - containFullCol = idxCol.Length == types.UnspecifiedLength - return - } - } - return -} - -// DecodeRawRowData decodes raw row data into a datum slice and a (columnID:columnValue) map. 
-func DecodeRawRowData(ctx sessionctx.Context, meta *model.TableInfo, h kv.Handle, cols []*table.Column, - value []byte) ([]types.Datum, map[int64]types.Datum, error) { - v := make([]types.Datum, len(cols)) - colTps := make(map[int64]*types.FieldType, len(cols)) - prefixCols := make(map[int64]struct{}) - for i, col := range cols { - if col == nil { - continue - } - if col.IsPKHandleColumn(meta) { - if mysql.HasUnsignedFlag(col.GetFlag()) { - v[i].SetUint64(uint64(h.IntValue())) - } else { - v[i].SetInt64(h.IntValue()) - } - continue - } - if col.IsCommonHandleColumn(meta) && !types.NeedRestoredData(&col.FieldType) { - if containFullCol, idxInHandle := containFullColInHandle(meta, col); containFullCol { - dtBytes := h.EncodedCol(idxInHandle) - _, dt, err := codec.DecodeOne(dtBytes) - if err != nil { - return nil, nil, err - } - dt, err = tablecodec.Unflatten(dt, &col.FieldType, ctx.GetSessionVars().Location()) - if err != nil { - return nil, nil, err - } - v[i] = dt - continue - } - prefixCols[col.ID] = struct{}{} - } - colTps[col.ID] = &col.FieldType - } - rowMap, err := tablecodec.DecodeRowToDatumMap(value, colTps, ctx.GetSessionVars().Location()) - if err != nil { - return nil, rowMap, err - } - defaultVals := make([]types.Datum, len(cols)) - for i, col := range cols { - if col == nil { - continue - } - if col.IsPKHandleColumn(meta) || (col.IsCommonHandleColumn(meta) && !types.NeedRestoredData(&col.FieldType)) { - if _, isPrefix := prefixCols[col.ID]; !isPrefix { - continue - } - } - ri, ok := rowMap[col.ID] - if ok { - v[i] = ri - continue - } - if col.IsVirtualGenerated() { - continue - } - if col.ChangeStateInfo != nil { - v[i], _, err = GetChangingColVal(ctx.GetExprCtx(), cols, col, rowMap, defaultVals) - } else { - v[i], err = GetColDefaultValue(ctx.GetExprCtx(), col, defaultVals) - } - if err != nil { - return nil, rowMap, err - } - } - return v, rowMap, nil -} - -// GetChangingColVal gets the changing column value when executing "modify/change column" statement. -// For statement like update-where, it will fetch the old row out and insert it into kv again. -// Since update statement can see the writable columns, it is responsible for the casting relative column / get the fault value here. -// old row : a-b-[nil] -// new row : a-b-[a'/default] -// Thus the writable new row is corresponding to Write-Only constraints. -func GetChangingColVal(ctx exprctx.BuildContext, cols []*table.Column, col *table.Column, rowMap map[int64]types.Datum, defaultVals []types.Datum) (_ types.Datum, isDefaultVal bool, err error) { - relativeCol := cols[col.ChangeStateInfo.DependencyColumnOffset] - idxColumnVal, ok := rowMap[relativeCol.ID] - if ok { - idxColumnVal, err = table.CastColumnValue(ctx, idxColumnVal, col.ColumnInfo, false, false) - // TODO: Consider sql_mode and the error msg(encounter this error check whether to rollback). - if err != nil { - return idxColumnVal, false, errors.Trace(err) - } - return idxColumnVal, false, nil - } - - idxColumnVal, err = GetColDefaultValue(ctx, col, defaultVals) - if err != nil { - return idxColumnVal, false, errors.Trace(err) - } - - return idxColumnVal, true, nil -} - -// RemoveRecord implements table.Table RemoveRecord interface. 
-func (t *TableCommon) RemoveRecord(ctx table.MutateContext, h kv.Handle, r []types.Datum) error { - txn, err := ctx.Txn(true) - if err != nil { - return err - } - - memBuffer := txn.GetMemBuffer() - sh := memBuffer.Staging() - defer memBuffer.Cleanup(sh) - - err = t.removeRowData(ctx, h) - if err != nil { - return err - } - - if m := t.Meta(); m.TempTableType != model.TempTableNone { - if tmpTable, sizeLimit, ok := addTemporaryTable(ctx, m); ok { - if err = checkTempTableSize(tmpTable, sizeLimit); err != nil { - return err - } - defer handleTempTableSize(tmpTable, txn.Size(), txn) - } - } - - // The table has non-public column and this column is doing the operation of "modify/change column". - if len(t.Columns) > len(r) && t.Columns[len(r)].ChangeStateInfo != nil { - // The changing column datum derived from related column should be casted here. - // Otherwise, the existed changing indexes will not be deleted. - relatedColDatum := r[t.Columns[len(r)].ChangeStateInfo.DependencyColumnOffset] - value, err := table.CastColumnValue(ctx.GetExprCtx(), relatedColDatum, t.Columns[len(r)].ColumnInfo, false, false) - if err != nil { - logutil.BgLogger().Info("remove record cast value failed", zap.Any("col", t.Columns[len(r)]), - zap.String("handle", h.String()), zap.Any("val", relatedColDatum), zap.Error(err)) - return err - } - r = append(r, value) - } - err = t.removeRowIndices(ctx, h, r) - if err != nil { - return err - } - - if err = injectMutationError(t, txn, sh); err != nil { - return err - } - - tc := ctx.GetExprCtx().GetEvalCtx().TypeCtx() - if ctx.EnableMutationChecker() { - if err = CheckDataConsistency(txn, tc, t, nil, r, memBuffer, sh); err != nil { - return errors.Trace(err) - } - } - memBuffer.Release(sh) - - binlogSupport, shouldWriteBinlog := getBinlogSupport(ctx, t.meta) - if shouldWriteBinlog { - cols := t.DeletableCols() - colIDs := make([]int64, 0, len(cols)+1) - for _, col := range cols { - colIDs = append(colIDs, col.ID) - } - var binlogRow []types.Datum - if !t.meta.PKIsHandle && !t.meta.IsCommonHandle { - colIDs = append(colIDs, model.ExtraHandleID) - binlogRow = make([]types.Datum, 0, len(r)+1) - binlogRow = append(binlogRow, r...) - handleData, err := h.Data() - if err != nil { - return err - } - binlogRow = append(binlogRow, handleData...) - } else { - binlogRow = r - } - err = t.addDeleteBinlog(ctx, binlogSupport, binlogRow, colIDs) - } - - if s, ok := ctx.GetStatisticsSupport(); ok { - // a reusable buffer to save malloc - // Note: The buffer should not be referenced or modified outside this function. - // It can only act as a temporary buffer for the current function call. - colSizeBuffer := ctx.GetMutateBuffers().GetColSizeDeltaBufferWithCap(len(t.Cols())) - for id, col := range t.Cols() { - size, err := codec.EstimateValueSize(tc, r[id]) - if err != nil { - continue - } - colSizeBuffer.AddColSizeDelta(col.ID, -int64(size-1)) - } - s.UpdatePhysicalTableDelta( - t.physicalTableID, -1, 1, colSizeBuffer, - ) - } - return err -} - -func (t *TableCommon) addInsertBinlog(ctx table.MutateContext, support tbctx.BinlogSupport, h kv.Handle, encodeRowBuffer *tbctx.EncodeRowBuffer) error { - evalCtx := ctx.GetExprCtx().GetEvalCtx() - loc, ec := evalCtx.Location(), evalCtx.ErrCtx() - handleData, err := h.Data() - if err != nil { - return err - } - pk, err := codec.EncodeValue(loc, nil, handleData...) 
- err = ec.HandleError(err) - if err != nil { - return err - } - value, err := encodeRowBuffer.EncodeBinlogRowData(loc, ec) - if err != nil { - return err - } - bin := append(pk, value...) - mutation := support.GetBinlogMutation(t.tableID) - mutation.InsertedRows = append(mutation.InsertedRows, bin) - mutation.Sequence = append(mutation.Sequence, binlog.MutationType_Insert) - return nil -} - -func (t *TableCommon) addUpdateBinlog(ctx table.MutateContext, support tbctx.BinlogSupport, oldRow, newRow []types.Datum, colIDs []int64) error { - evalCtx := ctx.GetExprCtx().GetEvalCtx() - loc, ec := evalCtx.Location(), evalCtx.ErrCtx() - old, err := tablecodec.EncodeOldRow(loc, oldRow, colIDs, nil, nil) - err = ec.HandleError(err) - if err != nil { - return err - } - newVal, err := tablecodec.EncodeOldRow(loc, newRow, colIDs, nil, nil) - err = ec.HandleError(err) - if err != nil { - return err - } - bin := append(old, newVal...) - mutation := support.GetBinlogMutation(t.tableID) - mutation.UpdatedRows = append(mutation.UpdatedRows, bin) - mutation.Sequence = append(mutation.Sequence, binlog.MutationType_Update) - return nil -} - -func (t *TableCommon) addDeleteBinlog(ctx table.MutateContext, support tbctx.BinlogSupport, r []types.Datum, colIDs []int64) error { - evalCtx := ctx.GetExprCtx().GetEvalCtx() - loc, ec := evalCtx.Location(), evalCtx.ErrCtx() - data, err := tablecodec.EncodeOldRow(loc, r, colIDs, nil, nil) - err = ec.HandleError(err) - if err != nil { - return err - } - mutation := support.GetBinlogMutation(t.tableID) - mutation.DeletedRows = append(mutation.DeletedRows, data) - mutation.Sequence = append(mutation.Sequence, binlog.MutationType_DeleteRow) - return nil -} - -func writeSequenceUpdateValueBinlog(sctx sessionctx.Context, db, sequence string, end int64) error { - // 1: when sequenceCommon update the local cache passively. - // 2: When sequenceCommon setval to the allocator actively. - // Both of this two case means the upper bound the sequence has changed in meta, which need to write the binlog - // to the downstream. - // Sequence sends `select setval(seq, num)` sql string to downstream via `setDDLBinlog`, which is mocked as a DDL binlog. - binlogCli := sctx.GetSessionVars().BinlogClient - sqlMode := sctx.GetSessionVars().SQLMode - sequenceFullName := stringutil.Escape(db, sqlMode) + "." + stringutil.Escape(sequence, sqlMode) - sql := "select setval(" + sequenceFullName + ", " + strconv.FormatInt(end, 10) + ")" - - ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnMeta) - err := kv.RunInNewTxn(ctx, sctx.GetStore(), true, func(ctx context.Context, txn kv.Transaction) error { - m := meta.NewMeta(txn) - mockJobID, err := m.GenGlobalID() - if err != nil { - return err - } - binloginfo.SetDDLBinlog(binlogCli, txn, mockJobID, int32(model.StatePublic), sql) - return nil - }) - return err -} - -func (t *TableCommon) removeRowData(ctx table.MutateContext, h kv.Handle) error { - // Remove row data. - txn, err := ctx.Txn(true) - if err != nil { - return err - } - - key := t.RecordKey(h) - failpoint.Inject("removeRecordForceAssertNotExist", func() { - // Assert the key doesn't exist while it actually exists. This is helpful to test if assertion takes effect. - // Since only the first assertion takes effect, set the injected assertion before setting the correct one to - // override it. 
- if ctx.ConnectionID() != 0 { - logutil.BgLogger().Info("force asserting not exist on RemoveRecord", zap.String("category", "failpoint"), zap.Uint64("startTS", txn.StartTS())) - if err = txn.SetAssertion(key, kv.SetAssertNotExist); err != nil { - failpoint.Return(err) - } - } - }) - if t.shouldAssert(ctx.TxnAssertionLevel()) { - err = txn.SetAssertion(key, kv.SetAssertExist) - } else { - err = txn.SetAssertion(key, kv.SetAssertUnknown) - } - if err != nil { - return err - } - return txn.Delete(key) -} - -// removeRowIndices removes all the indices of a row. -func (t *TableCommon) removeRowIndices(ctx table.MutateContext, h kv.Handle, rec []types.Datum) error { - txn, err := ctx.Txn(true) - if err != nil { - return err - } - for _, v := range t.deletableIndices() { - if v.Meta().Primary && (t.Meta().IsCommonHandle || t.Meta().PKIsHandle) { - continue - } - vals, err := v.FetchValues(rec, nil) - if err != nil { - logutil.BgLogger().Info("remove row index failed", zap.Any("index", v.Meta()), zap.Uint64("txnStartTS", txn.StartTS()), zap.String("handle", h.String()), zap.Any("record", rec), zap.Error(err)) - return err - } - if err = v.Delete(ctx, txn, vals, h); err != nil { - if v.Meta().State != model.StatePublic && kv.ErrNotExist.Equal(err) { - // If the index is not in public state, we may have not created the index, - // or already deleted the index, so skip ErrNotExist error. - logutil.BgLogger().Debug("row index not exists", zap.Any("index", v.Meta()), zap.Uint64("txnStartTS", txn.StartTS()), zap.String("handle", h.String())) - continue - } - return err - } - } - return nil -} - -// removeRowIndex implements table.Table RemoveRowIndex interface. -func (t *TableCommon) removeRowIndex(ctx table.MutateContext, h kv.Handle, vals []types.Datum, idx table.Index, txn kv.Transaction) error { - return idx.Delete(ctx, txn, vals, h) -} - -// buildIndexForRow implements table.Table BuildIndexForRow interface. -func (t *TableCommon) buildIndexForRow(ctx table.MutateContext, h kv.Handle, vals []types.Datum, newData []types.Datum, idx *index, txn kv.Transaction, untouched bool, opt *table.CreateIdxOpt) error { - rsData := TryGetHandleRestoredDataWrapper(t.meta, newData, nil, idx.Meta()) - if _, err := idx.create(ctx, txn, vals, h, rsData, untouched, opt); err != nil { - if kv.ErrKeyExists.Equal(err) { - // Make error message consistent with MySQL. - tablecodec.TruncateIndexValues(t.meta, idx.Meta(), vals) - colStrVals, err1 := genIndexKeyStrs(vals) - if err1 != nil { - // if genIndexKeyStrs failed, return the original error. - return err - } - - return kv.GenKeyExistsErr(colStrVals, fmt.Sprintf("%s.%s", idx.TableMeta().Name.String(), idx.Meta().Name.String())) - } - return err - } - return nil -} - -// IterRecords iterates records in the table and calls fn. 
-func IterRecords(t table.Table, ctx sessionctx.Context, cols []*table.Column, - fn table.RecordIterFunc) error { - prefix := t.RecordPrefix() - txn, err := ctx.Txn(true) - if err != nil { - return err - } - - startKey := tablecodec.EncodeRecordKey(t.RecordPrefix(), kv.IntHandle(math.MinInt64)) - it, err := txn.Iter(startKey, prefix.PrefixNext()) - if err != nil { - return err - } - defer it.Close() - - if !it.Valid() { - return nil - } - - logutil.BgLogger().Debug("iterate records", zap.ByteString("startKey", startKey), zap.ByteString("key", it.Key()), zap.ByteString("value", it.Value())) - - colMap := make(map[int64]*types.FieldType, len(cols)) - for _, col := range cols { - colMap[col.ID] = &col.FieldType - } - defaultVals := make([]types.Datum, len(cols)) - for it.Valid() && it.Key().HasPrefix(prefix) { - // first kv pair is row lock information. - // TODO: check valid lock - // get row handle - handle, err := tablecodec.DecodeRowKey(it.Key()) - if err != nil { - return err - } - rowMap, err := tablecodec.DecodeRowToDatumMap(it.Value(), colMap, ctx.GetSessionVars().Location()) - if err != nil { - return err - } - pkIds, decodeLoc := TryGetCommonPkColumnIds(t.Meta()), ctx.GetSessionVars().Location() - data := make([]types.Datum, len(cols)) - for _, col := range cols { - if col.IsPKHandleColumn(t.Meta()) { - if mysql.HasUnsignedFlag(col.GetFlag()) { - data[col.Offset].SetUint64(uint64(handle.IntValue())) - } else { - data[col.Offset].SetInt64(handle.IntValue()) - } - continue - } else if mysql.HasPriKeyFlag(col.GetFlag()) { - data[col.Offset], err = tryDecodeColumnFromCommonHandle(col, handle, pkIds, decodeLoc) - if err != nil { - return err - } - continue - } - if _, ok := rowMap[col.ID]; ok { - data[col.Offset] = rowMap[col.ID] - continue - } - data[col.Offset], err = GetColDefaultValue(ctx.GetExprCtx(), col, defaultVals) - if err != nil { - return err - } - } - more, err := fn(handle, data, cols) - if !more || err != nil { - return err - } - - rk := tablecodec.EncodeRecordKey(t.RecordPrefix(), handle) - err = kv.NextUntil(it, util.RowKeyPrefixFilter(rk)) - if err != nil { - return err - } - } - - return nil -} - -func tryDecodeColumnFromCommonHandle(col *table.Column, handle kv.Handle, pkIds []int64, decodeLoc *time.Location) (types.Datum, error) { - for i, hid := range pkIds { - if hid != col.ID { - continue - } - _, d, err := codec.DecodeOne(handle.EncodedCol(i)) - if err != nil { - return types.Datum{}, errors.Trace(err) - } - if d, err = tablecodec.Unflatten(d, &col.FieldType, decodeLoc); err != nil { - return types.Datum{}, err - } - return d, nil - } - return types.Datum{}, nil -} - -// GetColDefaultValue gets a column default value. -// The defaultVals is used to avoid calculating the default value multiple times. -func GetColDefaultValue(ctx exprctx.BuildContext, col *table.Column, defaultVals []types.Datum) ( - colVal types.Datum, err error) { - if col.GetOriginDefaultValue() == nil && mysql.HasNotNullFlag(col.GetFlag()) { - return colVal, errors.New("Miss column") - } - if defaultVals[col.Offset].IsNull() { - colVal, err = table.GetColOriginDefaultValue(ctx, col.ToInfo()) - if err != nil { - return colVal, err - } - defaultVals[col.Offset] = colVal - } else { - colVal = defaultVals[col.Offset] - } - - return colVal, nil -} - -// AllocHandle allocate a new handle. -// A statement could reserve some ID in the statement context, try those ones first. 
-func AllocHandle(ctx context.Context, mctx table.MutateContext, t table.Table) (kv.IntHandle, - error) { - if mctx != nil { - if reserved, ok := mctx.GetReservedRowIDAlloc(); ok { - // First try to alloc if the statement has reserved auto ID. - if rowID, ok := reserved.Consume(); ok { - return kv.IntHandle(rowID), nil - } - } - } - - _, rowID, err := AllocHandleIDs(ctx, mctx, t, 1) - return kv.IntHandle(rowID), err -} - -// AllocHandleIDs allocates n handle ids (_tidb_rowid), and caches the range -// in the table.MutateContext. -func AllocHandleIDs(ctx context.Context, mctx table.MutateContext, t table.Table, n uint64) (int64, int64, error) { - meta := t.Meta() - base, maxID, err := t.Allocators(mctx).Get(autoid.RowIDAllocType).Alloc(ctx, n, 1, 1) - if err != nil { - return 0, 0, err - } - if meta.ShardRowIDBits > 0 { - shardFmt := autoid.NewShardIDFormat(types.NewFieldType(mysql.TypeLonglong), meta.ShardRowIDBits, autoid.RowIDBitLength) - // Use max record ShardRowIDBits to check overflow. - if OverflowShardBits(maxID, meta.MaxShardRowIDBits, autoid.RowIDBitLength, true) { - // If overflow, the rowID may be duplicated. For examples, - // t.meta.ShardRowIDBits = 4 - // rowID = 0010111111111111111111111111111111111111111111111111111111111111 - // shard = 0100000000000000000000000000000000000000000000000000000000000000 - // will be duplicated with: - // rowID = 0100111111111111111111111111111111111111111111111111111111111111 - // shard = 0010000000000000000000000000000000000000000000000000000000000000 - return 0, 0, autoid.ErrAutoincReadFailed - } - shard := mctx.GetRowIDShardGenerator().GetCurrentShard(int(n)) - base = shardFmt.Compose(shard, base) - maxID = shardFmt.Compose(shard, maxID) - } - return base, maxID, nil -} - -// OverflowShardBits checks whether the recordID overflow `1<<(typeBitsLength-shardRowIDBits-1) -1`. -func OverflowShardBits(recordID int64, shardRowIDBits uint64, typeBitsLength uint64, reservedSignBit bool) bool { - var signBit uint64 - if reservedSignBit { - signBit = 1 - } - mask := (1<<shardRowIDBits - 1) << (typeBitsLength - shardRowIDBits - signBit) - return recordID&int64(mask) > 0 -} - -// Allocators implements table.Table Allocators interface. -func (t *TableCommon) Allocators(ctx table.AllocatorContext) autoid.Allocators { - if ctx == nil { - return t.allocs - } - if alloc, ok := ctx.AlternativeAllocators(t.meta); ok { - return alloc - } - return t.allocs -} - -// Type implements table.Table Type interface. -func (t *TableCommon) Type() table.Type { - return table.NormalTable -} - -func getBinlogSupport(ctx table.MutateContext, tblInfo *model.TableInfo) (tbctx.BinlogSupport, bool) { - if tblInfo.TempTableType != model.TempTableNone || ctx.InRestrictedSQL() { - return nil, false - } - return ctx.GetBinlogSupport() -} - -func (t *TableCommon) canSkip(col *table.Column, value *types.Datum) bool { - return CanSkip(t.Meta(), col, value) -} - -// CanSkip is for these cases, we can skip the columns in encoded row: -// 1. the column is included in primary key; -// 2. the column's default value is null, and the value equals to that but has no origin default; -// 3. the column is virtual generated.
-func CanSkip(info *model.TableInfo, col *table.Column, value *types.Datum) bool { - if col.IsPKHandleColumn(info) { - return true - } - if col.IsCommonHandleColumn(info) { - pkIdx := FindPrimaryIndex(info) - for _, idxCol := range pkIdx.Columns { - if info.Columns[idxCol.Offset].ID != col.ID { - continue - } - canSkip := idxCol.Length == types.UnspecifiedLength - canSkip = canSkip && !types.NeedRestoredData(&col.FieldType) - return canSkip - } - } - if col.GetDefaultValue() == nil && value.IsNull() && col.GetOriginDefaultValue() == nil { - return true - } - if col.IsVirtualGenerated() { - return true - } - return false -} - -// canSkipUpdateBinlog checks whether the column can be skipped or not. -func (t *TableCommon) canSkipUpdateBinlog(col *table.Column, value types.Datum) bool { - return col.IsVirtualGenerated() -} - -// FindIndexByColName returns a public table index containing only one column named `name`. -func FindIndexByColName(t table.Table, name string) table.Index { - for _, idx := range t.Indices() { - // only public index can be read. - if idx.Meta().State != model.StatePublic { - continue - } - - if len(idx.Meta().Columns) == 1 && strings.EqualFold(idx.Meta().Columns[0].Name.L, name) { - return idx - } - } - return nil -} - -func getDuplicateError(tblInfo *model.TableInfo, handle kv.Handle, row []types.Datum) error { - keyName := tblInfo.Name.String() + ".PRIMARY" - - if handle.IsInt() { - return kv.GenKeyExistsErr([]string{handle.String()}, keyName) - } - pkIdx := FindPrimaryIndex(tblInfo) - if pkIdx == nil { - handleData, err := handle.Data() - if err != nil { - return kv.ErrKeyExists.FastGenByArgs(handle.String(), keyName) - } - colStrVals, err := genIndexKeyStrs(handleData) - if err != nil { - return kv.ErrKeyExists.FastGenByArgs(handle.String(), keyName) - } - return kv.GenKeyExistsErr(colStrVals, keyName) - } - pkDts := make([]types.Datum, 0, len(pkIdx.Columns)) - for _, idxCol := range pkIdx.Columns { - pkDts = append(pkDts, row[idxCol.Offset]) - } - tablecodec.TruncateIndexValues(tblInfo, pkIdx, pkDts) - colStrVals, err := genIndexKeyStrs(pkDts) - if err != nil { - // if genIndexKeyStrs failed, return ErrKeyExists with handle.String(). - return kv.ErrKeyExists.FastGenByArgs(handle.String(), keyName) - } - return kv.GenKeyExistsErr(colStrVals, keyName) -} - -func init() { - table.TableFromMeta = TableFromMeta - table.MockTableFromMeta = MockTableFromMeta - tableutil.TempTableFromMeta = TempTableFromMeta -} - -// sequenceCommon cache the sequence value. -// `alter sequence` will invalidate the cached range. -// `setval` will recompute the start position of cached value. -type sequenceCommon struct { - meta *model.SequenceInfo - // base < end when increment > 0. - // base > end when increment < 0. - end int64 - base int64 - // round is used to count the cycle times. - round int64 - mu sync.RWMutex -} - -// GetSequenceBaseEndRound is used in test. -func (s *sequenceCommon) GetSequenceBaseEndRound() (int64, int64, int64) { - s.mu.RLock() - defer s.mu.RUnlock() - return s.base, s.end, s.round -} - -// GetSequenceNextVal implements util.SequenceTable GetSequenceNextVal interface. -// Caching the sequence value in table, we can easily be notified with the cache empty, -// and write the binlogInfo in table level rather than in allocator. -func (t *TableCommon) GetSequenceNextVal(ctx any, dbName, seqName string) (nextVal int64, err error) { - seq := t.sequence - if seq == nil { - // TODO: refine the error. 
- return 0, errors.New("sequenceCommon is nil") - } - seq.mu.Lock() - defer seq.mu.Unlock() - - err = func() error { - // Check if need to update the cache batch from storage. - // Because seq.base is not always the last allocated value (may be set by setval()). - // So we should try to seek the next value in cache (not just add increment to seq.base). - var ( - updateCache bool - offset int64 - ok bool - ) - if seq.base == seq.end { - // There is no cache yet. - updateCache = true - } else { - // Seek the first valid value in cache. - offset = seq.getOffset() - if seq.meta.Increment > 0 { - nextVal, ok = autoid.SeekToFirstSequenceValue(seq.base, seq.meta.Increment, offset, seq.base, seq.end) - } else { - nextVal, ok = autoid.SeekToFirstSequenceValue(seq.base, seq.meta.Increment, offset, seq.end, seq.base) - } - if !ok { - updateCache = true - } - } - if !updateCache { - return nil - } - // Update batch alloc from kv storage. - sequenceAlloc, err1 := getSequenceAllocator(t.allocs) - if err1 != nil { - return err1 - } - var base, end, round int64 - base, end, round, err1 = sequenceAlloc.AllocSeqCache() - if err1 != nil { - return err1 - } - // Only update local cache when alloc succeed. - seq.base = base - seq.end = end - seq.round = round - // write sequence binlog to the pumpClient. - if ctx.(sessionctx.Context).GetSessionVars().BinlogClient != nil { - err = writeSequenceUpdateValueBinlog(ctx.(sessionctx.Context), dbName, seqName, seq.end) - if err != nil { - return err - } - } - // Seek the first valid value in new cache. - // Offset may have changed cause the round is updated. - offset = seq.getOffset() - if seq.meta.Increment > 0 { - nextVal, ok = autoid.SeekToFirstSequenceValue(seq.base, seq.meta.Increment, offset, seq.base, seq.end) - } else { - nextVal, ok = autoid.SeekToFirstSequenceValue(seq.base, seq.meta.Increment, offset, seq.end, seq.base) - } - if !ok { - return errors.New("can't find the first value in sequence cache") - } - return nil - }() - // Sequence alloc in kv store error. - if err != nil { - if err == autoid.ErrAutoincReadFailed { - return 0, table.ErrSequenceHasRunOut.GenWithStackByArgs(dbName, seqName) - } - return 0, err - } - seq.base = nextVal - return nextVal, nil -} - -// SetSequenceVal implements util.SequenceTable SetSequenceVal interface. -// The returned bool indicates the newVal is already under the base. -func (t *TableCommon) SetSequenceVal(ctx any, newVal int64, dbName, seqName string) (int64, bool, error) { - seq := t.sequence - if seq == nil { - // TODO: refine the error. - return 0, false, errors.New("sequenceCommon is nil") - } - seq.mu.Lock() - defer seq.mu.Unlock() - - if seq.meta.Increment > 0 { - if newVal <= t.sequence.base { - return 0, true, nil - } - if newVal <= t.sequence.end { - t.sequence.base = newVal - return newVal, false, nil - } - } else { - if newVal >= t.sequence.base { - return 0, true, nil - } - if newVal >= t.sequence.end { - t.sequence.base = newVal - return newVal, false, nil - } - } - - // Invalid the current cache. - t.sequence.base = t.sequence.end - - // Rebase from kv storage. - sequenceAlloc, err := getSequenceAllocator(t.allocs) - if err != nil { - return 0, false, err - } - res, alreadySatisfied, err := sequenceAlloc.RebaseSeq(newVal) - if err != nil { - return 0, false, err - } - if !alreadySatisfied { - // Write sequence binlog to the pumpClient. 
- if ctx.(sessionctx.Context).GetSessionVars().BinlogClient != nil { - err = writeSequenceUpdateValueBinlog(ctx.(sessionctx.Context), dbName, seqName, seq.end) - if err != nil { - return 0, false, err - } - } - } - // Record the current end after setval succeed. - // Consider the following case. - // create sequence seq - // setval(seq, 100) setval(seq, 50) - // Because no cache (base, end keep 0), so the second setval won't return NULL. - t.sequence.base, t.sequence.end = newVal, newVal - return res, alreadySatisfied, nil -} - -// getOffset is used in under GetSequenceNextVal & SetSequenceVal, which mu is locked. -func (s *sequenceCommon) getOffset() int64 { - offset := s.meta.Start - if s.meta.Cycle && s.round > 0 { - if s.meta.Increment > 0 { - offset = s.meta.MinValue - } else { - offset = s.meta.MaxValue - } - } - return offset -} - -// GetSequenceID implements util.SequenceTable GetSequenceID interface. -func (t *TableCommon) GetSequenceID() int64 { - return t.tableID -} - -// GetSequenceCommon is used in test to get sequenceCommon. -func (t *TableCommon) GetSequenceCommon() *sequenceCommon { - return t.sequence -} - -// TryGetHandleRestoredDataWrapper tries to get the restored data for handle if needed. The argument can be a slice or a map. -func TryGetHandleRestoredDataWrapper(tblInfo *model.TableInfo, row []types.Datum, rowMap map[int64]types.Datum, idx *model.IndexInfo) []types.Datum { - if !collate.NewCollationEnabled() || !tblInfo.IsCommonHandle || tblInfo.CommonHandleVersion == 0 { - return nil - } - rsData := make([]types.Datum, 0, 4) - pkIdx := FindPrimaryIndex(tblInfo) - for _, pkIdxCol := range pkIdx.Columns { - pkCol := tblInfo.Columns[pkIdxCol.Offset] - if !types.NeedRestoredData(&pkCol.FieldType) { - continue - } - var datum types.Datum - if len(rowMap) > 0 { - datum = rowMap[pkCol.ID] - } else { - datum = row[pkCol.Offset] - } - TryTruncateRestoredData(&datum, pkCol, pkIdxCol, idx) - ConvertDatumToTailSpaceCount(&datum, pkCol) - rsData = append(rsData, datum) - } - return rsData -} - -// TryTruncateRestoredData tries to truncate index values. -// Says that primary key(a (8)), -// For index t(a), don't truncate the value. -// For index t(a(9)), truncate to a(9). -// For index t(a(7)), truncate to a(8). -func TryTruncateRestoredData(datum *types.Datum, pkCol *model.ColumnInfo, - pkIdxCol *model.IndexColumn, idx *model.IndexInfo) { - truncateTargetCol := pkIdxCol - for _, idxCol := range idx.Columns { - if idxCol.Offset == pkIdxCol.Offset { - truncateTargetCol = maxIndexLen(pkIdxCol, idxCol) - break - } - } - tablecodec.TruncateIndexValue(datum, truncateTargetCol, pkCol) -} - -// ConvertDatumToTailSpaceCount converts a string datum to an int datum that represents the tail space count. -func ConvertDatumToTailSpaceCount(datum *types.Datum, col *model.ColumnInfo) { - if collate.IsBinCollation(col.GetCollate()) { - *datum = types.NewIntDatum(stringutil.GetTailSpaceCount(datum.GetString())) - } -} - -func maxIndexLen(idxA, idxB *model.IndexColumn) *model.IndexColumn { - if idxA.Length == types.UnspecifiedLength { - return idxA - } - if idxB.Length == types.UnspecifiedLength { - return idxB - } - if idxA.Length > idxB.Length { - return idxA - } - return idxB -} - -func getSequenceAllocator(allocs autoid.Allocators) (autoid.Allocator, error) { - for _, alloc := range allocs.Allocs { - if alloc.GetType() == autoid.SequenceType { - return alloc, nil - } - } - // TODO: refine the error. 
- return nil, errors.New("sequence allocator is nil") -} - -// BuildTableScanFromInfos build tipb.TableScan with *model.TableInfo and *model.ColumnInfo. -func BuildTableScanFromInfos(tableInfo *model.TableInfo, columnInfos []*model.ColumnInfo) *tipb.TableScan { - pkColIDs := TryGetCommonPkColumnIds(tableInfo) - tsExec := &tipb.TableScan{ - TableId: tableInfo.ID, - Columns: util.ColumnsToProto(columnInfos, tableInfo.PKIsHandle, false), - PrimaryColumnIds: pkColIDs, - } - if tableInfo.IsCommonHandle { - tsExec.PrimaryPrefixColumnIds = PrimaryPrefixColumnIDs(tableInfo) - } - return tsExec -} - -// BuildPartitionTableScanFromInfos build tipb.PartitonTableScan with *model.TableInfo and *model.ColumnInfo. -func BuildPartitionTableScanFromInfos(tableInfo *model.TableInfo, columnInfos []*model.ColumnInfo, fastScan bool) *tipb.PartitionTableScan { - pkColIDs := TryGetCommonPkColumnIds(tableInfo) - tsExec := &tipb.PartitionTableScan{ - TableId: tableInfo.ID, - Columns: util.ColumnsToProto(columnInfos, tableInfo.PKIsHandle, false), - PrimaryColumnIds: pkColIDs, - IsFastScan: &fastScan, - } - if tableInfo.IsCommonHandle { - tsExec.PrimaryPrefixColumnIds = PrimaryPrefixColumnIDs(tableInfo) - } - return tsExec -} - -// SetPBColumnsDefaultValue sets the default values of tipb.ColumnInfo. -func SetPBColumnsDefaultValue(ctx expression.BuildContext, pbColumns []*tipb.ColumnInfo, columns []*model.ColumnInfo) error { - for i, c := range columns { - // For virtual columns, we set their default values to NULL so that TiKV will return NULL properly, - // They real values will be computed later. - if c.IsGenerated() && !c.GeneratedStored { - pbColumns[i].DefaultVal = []byte{codec.NilFlag} - } - if c.GetOriginDefaultValue() == nil { - continue - } - - evalCtx := ctx.GetEvalCtx() - d, err := table.GetColOriginDefaultValueWithoutStrictSQLMode(ctx, c) - if err != nil { - return err - } - - pbColumns[i].DefaultVal, err = tablecodec.EncodeValue(evalCtx.Location(), nil, d) - ec := evalCtx.ErrCtx() - err = ec.HandleError(err) - if err != nil { - return err - } - } - return nil -} - -// TemporaryTable is used to store transaction-specific or session-specific information for global / local temporary tables. -// For example, stats and autoID should have their own copies of data, instead of being shared by all sessions. -type TemporaryTable struct { - // Whether it's modified in this transaction. - modified bool - // The stats of this table. So far it's always pseudo stats. - stats *statistics.Table - // The autoID allocator of this table. - autoIDAllocator autoid.Allocator - // Table size. - size int64 - - meta *model.TableInfo -} - -// TempTableFromMeta builds a TempTable from model.TableInfo. -func TempTableFromMeta(tblInfo *model.TableInfo) tableutil.TempTable { - return &TemporaryTable{ - modified: false, - stats: statistics.PseudoTable(tblInfo, false, false), - autoIDAllocator: autoid.NewAllocatorFromTempTblInfo(tblInfo), - meta: tblInfo, - } -} - -// GetAutoIDAllocator is implemented from TempTable.GetAutoIDAllocator. -func (t *TemporaryTable) GetAutoIDAllocator() autoid.Allocator { - return t.autoIDAllocator -} - -// SetModified is implemented from TempTable.SetModified. -func (t *TemporaryTable) SetModified(modified bool) { - t.modified = modified -} - -// GetModified is implemented from TempTable.GetModified. -func (t *TemporaryTable) GetModified() bool { - return t.modified -} - -// GetStats is implemented from TempTable.GetStats. 
-func (t *TemporaryTable) GetStats() any { - return t.stats -} - -// GetSize gets the table size. -func (t *TemporaryTable) GetSize() int64 { - return t.size -} - -// SetSize sets the table size. -func (t *TemporaryTable) SetSize(v int64) { - t.size = v -} - -// GetMeta gets the table meta. -func (t *TemporaryTable) GetMeta() *model.TableInfo { - return t.meta -} diff --git a/pkg/ttl/ttlworker/binding__failpoint_binding__.go b/pkg/ttl/ttlworker/binding__failpoint_binding__.go deleted file mode 100644 index 74271806c94e4..0000000000000 --- a/pkg/ttl/ttlworker/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package ttlworker - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/ttl/ttlworker/config.go b/pkg/ttl/ttlworker/config.go index a562fd04afd77..b5cef0b0ecc11 100644 --- a/pkg/ttl/ttlworker/config.go +++ b/pkg/ttl/ttlworker/config.go @@ -35,36 +35,36 @@ const ttlTaskHeartBeatTickerInterval = time.Minute const ttlGCInterval = time.Hour func getUpdateInfoSchemaCacheInterval() time.Duration { - if val, _err_ := failpoint.Eval(_curpkg_("update-info-schema-cache-interval")); _err_ == nil { + failpoint.Inject("update-info-schema-cache-interval", func(val failpoint.Value) time.Duration { return time.Duration(val.(int)) - } + }) return updateInfoSchemaCacheInterval } func getUpdateTTLTableStatusCacheInterval() time.Duration { - if val, _err_ := failpoint.Eval(_curpkg_("update-status-table-cache-interval")); _err_ == nil { + failpoint.Inject("update-status-table-cache-interval", func(val failpoint.Value) time.Duration { return time.Duration(val.(int)) - } + }) return updateTTLTableStatusCacheInterval } func getResizeWorkersInterval() time.Duration { - if val, _err_ := failpoint.Eval(_curpkg_("resize-workers-interval")); _err_ == nil { + failpoint.Inject("resize-workers-interval", func(val failpoint.Value) time.Duration { return time.Duration(val.(int)) - } + }) return resizeWorkersInterval } func getTaskManagerLoopTickerInterval() time.Duration { - if val, _err_ := failpoint.Eval(_curpkg_("task-manager-loop-interval")); _err_ == nil { + failpoint.Inject("task-manager-loop-interval", func(val failpoint.Value) time.Duration { return time.Duration(val.(int)) - } + }) return taskManagerLoopTickerInterval } func getTaskManagerHeartBeatExpireInterval() time.Duration { - if val, _err_ := failpoint.Eval(_curpkg_("task-manager-heartbeat-expire-interval")); _err_ == nil { + failpoint.Inject("task-manager-heartbeat-expire-interval", func(val failpoint.Value) time.Duration { return time.Duration(val.(int)) - } + }) return 2 * ttlTaskHeartBeatTickerInterval } diff --git a/pkg/ttl/ttlworker/config.go__failpoint_stash__ b/pkg/ttl/ttlworker/config.go__failpoint_stash__ deleted file mode 100644 index b5cef0b0ecc11..0000000000000 --- a/pkg/ttl/ttlworker/config.go__failpoint_stash__ +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ttlworker - -import ( - "time" - - "github.com/pingcap/failpoint" -) - -const jobManagerLoopTickerInterval = 10 * time.Second - -const updateInfoSchemaCacheInterval = 2 * time.Minute -const updateTTLTableStatusCacheInterval = 2 * time.Minute - -const ttlInternalSQLTimeout = 30 * time.Second -const resizeWorkersInterval = 30 * time.Second -const splitScanCount = 64 -const ttlJobTimeout = 6 * time.Hour - -const taskManagerLoopTickerInterval = time.Minute -const ttlTaskHeartBeatTickerInterval = time.Minute -const ttlGCInterval = time.Hour - -func getUpdateInfoSchemaCacheInterval() time.Duration { - failpoint.Inject("update-info-schema-cache-interval", func(val failpoint.Value) time.Duration { - return time.Duration(val.(int)) - }) - return updateInfoSchemaCacheInterval -} - -func getUpdateTTLTableStatusCacheInterval() time.Duration { - failpoint.Inject("update-status-table-cache-interval", func(val failpoint.Value) time.Duration { - return time.Duration(val.(int)) - }) - return updateTTLTableStatusCacheInterval -} - -func getResizeWorkersInterval() time.Duration { - failpoint.Inject("resize-workers-interval", func(val failpoint.Value) time.Duration { - return time.Duration(val.(int)) - }) - return resizeWorkersInterval -} - -func getTaskManagerLoopTickerInterval() time.Duration { - failpoint.Inject("task-manager-loop-interval", func(val failpoint.Value) time.Duration { - return time.Duration(val.(int)) - }) - return taskManagerLoopTickerInterval -} - -func getTaskManagerHeartBeatExpireInterval() time.Duration { - failpoint.Inject("task-manager-heartbeat-expire-interval", func(val failpoint.Value) time.Duration { - return time.Duration(val.(int)) - }) - return 2 * ttlTaskHeartBeatTickerInterval -} diff --git a/pkg/util/binding__failpoint_binding__.go b/pkg/util/binding__failpoint_binding__.go deleted file mode 100644 index c7fcdb8c0fcf4..0000000000000 --- a/pkg/util/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package util - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/util/breakpoint/binding__failpoint_binding__.go b/pkg/util/breakpoint/binding__failpoint_binding__.go deleted file mode 100644 index 2199b6182c919..0000000000000 --- a/pkg/util/breakpoint/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package breakpoint - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/util/breakpoint/breakpoint.go b/pkg/util/breakpoint/breakpoint.go index 10eeb0f1eb64a..60c0e1f828799 100644 --- a/pkg/util/breakpoint/breakpoint.go +++ 
b/pkg/util/breakpoint/breakpoint.go @@ -25,10 +25,10 @@ const NotifyBreakPointFuncKey = stringutil.StringerStr("breakPointNotifyFunc") // Inject injects a break point to a session func Inject(sctx sessionctx.Context, name string) { - if _, _err_ := failpoint.Eval(_curpkg_(name)); _err_ == nil { + failpoint.Inject(name, func(_ failpoint.Value) { val := sctx.Value(NotifyBreakPointFuncKey) if breakPointNotifyAndWaitContinue, ok := val.(func(string)); ok { breakPointNotifyAndWaitContinue(name) } - } + }) } diff --git a/pkg/util/breakpoint/breakpoint.go__failpoint_stash__ b/pkg/util/breakpoint/breakpoint.go__failpoint_stash__ deleted file mode 100644 index 60c0e1f828799..0000000000000 --- a/pkg/util/breakpoint/breakpoint.go__failpoint_stash__ +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package breakpoint - -import ( - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/util/stringutil" -) - -// NotifyBreakPointFuncKey is the key where break point notify function located -const NotifyBreakPointFuncKey = stringutil.StringerStr("breakPointNotifyFunc") - -// Inject injects a break point to a session -func Inject(sctx sessionctx.Context, name string) { - failpoint.Inject(name, func(_ failpoint.Value) { - val := sctx.Value(NotifyBreakPointFuncKey) - if breakPointNotifyAndWaitContinue, ok := val.(func(string)); ok { - breakPointNotifyAndWaitContinue(name) - } - }) -} diff --git a/pkg/util/cgroup/binding__failpoint_binding__.go b/pkg/util/cgroup/binding__failpoint_binding__.go deleted file mode 100644 index d5cbc65ac07c0..0000000000000 --- a/pkg/util/cgroup/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package cgroup - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/util/cgroup/cgroup_cpu_linux.go b/pkg/util/cgroup/cgroup_cpu_linux.go index fe8512439d1d0..665b81f24e2e9 100644 --- a/pkg/util/cgroup/cgroup_cpu_linux.go +++ b/pkg/util/cgroup/cgroup_cpu_linux.go @@ -28,13 +28,13 @@ import ( // GetCgroupCPU returns the CPU usage and quota for the current cgroup. 
func GetCgroupCPU() (CPUUsage, error) { - if val, _err_ := failpoint.Eval(_curpkg_("GetCgroupCPUErr")); _err_ == nil { + failpoint.Inject("GetCgroupCPUErr", func(val failpoint.Value) { //nolint:forcetypeassert if val.(bool) { var cpuUsage CPUUsage - return cpuUsage, errors.Errorf("mockAddBatchDDLJobsErr") + failpoint.Return(cpuUsage, errors.Errorf("mockAddBatchDDLJobsErr")) } - } + }) cpuusage, err := getCgroupCPU("/") cpuusage.NumCPU = runtime.NumCPU() diff --git a/pkg/util/cgroup/cgroup_cpu_linux.go__failpoint_stash__ b/pkg/util/cgroup/cgroup_cpu_linux.go__failpoint_stash__ deleted file mode 100644 index 665b81f24e2e9..0000000000000 --- a/pkg/util/cgroup/cgroup_cpu_linux.go__failpoint_stash__ +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build linux - -package cgroup - -import ( - "math" - "os" - "runtime" - "strings" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" -) - -// GetCgroupCPU returns the CPU usage and quota for the current cgroup. -func GetCgroupCPU() (CPUUsage, error) { - failpoint.Inject("GetCgroupCPUErr", func(val failpoint.Value) { - //nolint:forcetypeassert - if val.(bool) { - var cpuUsage CPUUsage - failpoint.Return(cpuUsage, errors.Errorf("mockAddBatchDDLJobsErr")) - } - }) - cpuusage, err := getCgroupCPU("/") - - cpuusage.NumCPU = runtime.NumCPU() - return cpuusage, err -} - -// CPUQuotaToGOMAXPROCS converts the CPU quota applied to the calling process -// to a valid GOMAXPROCS value. -func CPUQuotaToGOMAXPROCS(minValue int) (int, CPUQuotaStatus, error) { - quota, err := GetCgroupCPU() - if err != nil { - return -1, CPUQuotaUndefined, err - } - maxProcs := int(math.Ceil(quota.CPUShares())) - if minValue > 0 && maxProcs < minValue { - return minValue, CPUQuotaMinUsed, nil - } - return maxProcs, CPUQuotaUsed, nil -} - -// GetCPUPeriodAndQuota returns CPU period and quota time of cgroup. -func GetCPUPeriodAndQuota() (period int64, quota int64, err error) { - return getCgroupCPUPeriodAndQuota("/") -} - -// InContainer returns true if the process is running in a container. -func InContainer() bool { - // for cgroup V1, check /proc/self/cgroup, for V2, check /proc/self/mountinfo - return inContainer(procPathCGroup) || inContainer(procPathMountInfo) -} - -func inContainer(path string) bool { - v, err := os.ReadFile(path) - if err != nil { - return false - } - - // For cgroup V1, check /proc/self/cgroup - if path == procPathCGroup { - if strings.Contains(string(v), "docker") || - strings.Contains(string(v), "kubepods") || - strings.Contains(string(v), "containerd") { - return true - } - } - - // For cgroup V2, check /proc/self/mountinfo - if path == procPathMountInfo { - lines := strings.Split(string(v), "\n") - for _, line := range lines { - v := strings.Split(line, " ") - // check mount point of root dir is on overlay or not. - // v[4] means `mount point`, v[8] means `filesystem type`. 
- // see details from https://man7.org/linux/man-pages/man5/proc.5.html - // TODO: enhance this check, as overlay is not the only storage driver for container. - if len(v) > 8 && v[4] == "/" && v[8] == "overlay" { - return true - } - } - } - - return false -} diff --git a/pkg/util/cgroup/cgroup_cpu_unsupport.go b/pkg/util/cgroup/cgroup_cpu_unsupport.go index 8c8b7c2c4c3ec..092875c09657b 100644 --- a/pkg/util/cgroup/cgroup_cpu_unsupport.go +++ b/pkg/util/cgroup/cgroup_cpu_unsupport.go @@ -26,12 +26,12 @@ import ( // GetCgroupCPU returns the CPU usage and quota for the current cgroup. func GetCgroupCPU() (CPUUsage, error) { var cpuUsage CPUUsage - if val, _err_ := failpoint.Eval(_curpkg_("GetCgroupCPUErr")); _err_ == nil { + failpoint.Inject("GetCgroupCPUErr", func(val failpoint.Value) { //nolint:forcetypeassert if val.(bool) { - return cpuUsage, errors.Errorf("mockAddBatchDDLJobsErr") + failpoint.Return(cpuUsage, errors.Errorf("mockAddBatchDDLJobsErr")) } - } + }) cpuUsage.NumCPU = runtime.NumCPU() return cpuUsage, nil } diff --git a/pkg/util/cgroup/cgroup_cpu_unsupport.go__failpoint_stash__ b/pkg/util/cgroup/cgroup_cpu_unsupport.go__failpoint_stash__ deleted file mode 100644 index 092875c09657b..0000000000000 --- a/pkg/util/cgroup/cgroup_cpu_unsupport.go__failpoint_stash__ +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//go:build !linux - -package cgroup - -import ( - "runtime" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" -) - -// GetCgroupCPU returns the CPU usage and quota for the current cgroup. -func GetCgroupCPU() (CPUUsage, error) { - var cpuUsage CPUUsage - failpoint.Inject("GetCgroupCPUErr", func(val failpoint.Value) { - //nolint:forcetypeassert - if val.(bool) { - failpoint.Return(cpuUsage, errors.Errorf("mockAddBatchDDLJobsErr")) - } - }) - cpuUsage.NumCPU = runtime.NumCPU() - return cpuUsage, nil -} - -// GetCPUPeriodAndQuota returns CPU period and quota time of cgroup. -// This is Linux-specific and not supported in the current OS. -func GetCPUPeriodAndQuota() (period int64, quota int64, err error) { - return -1, -1, nil -} - -// CPUQuotaToGOMAXPROCS converts the CPU quota applied to the calling process -// to a valid GOMAXPROCS value. This is Linux-specific and not supported in the -// current OS. -func CPUQuotaToGOMAXPROCS(_ int) (int, CPUQuotaStatus, error) { - return -1, CPUQuotaUndefined, nil -} - -// InContainer returns true if the process is running in a container. 
-func InContainer() bool { - return false -} diff --git a/pkg/util/chunk/binding__failpoint_binding__.go b/pkg/util/chunk/binding__failpoint_binding__.go deleted file mode 100644 index ff0fbc387a100..0000000000000 --- a/pkg/util/chunk/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package chunk - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/util/chunk/chunk_in_disk.go b/pkg/util/chunk/chunk_in_disk.go index e8e7177cac03d..dadcddefd4584 100644 --- a/pkg/util/chunk/chunk_in_disk.go +++ b/pkg/util/chunk/chunk_in_disk.go @@ -329,7 +329,7 @@ func (d *DataInDiskByChunks) NumChunks() int { func injectChunkInDiskRandomError() error { var err error - if val, _err_ := failpoint.Eval(_curpkg_("ChunkInDiskError")); _err_ == nil { + failpoint.Inject("ChunkInDiskError", func(val failpoint.Value) { if val.(bool) { randNum := rand.Int31n(10000) if randNum < 3 { @@ -339,6 +339,6 @@ func injectChunkInDiskRandomError() error { time.Sleep(time.Duration(delayTime) * time.Millisecond) } } - } + }) return err } diff --git a/pkg/util/chunk/chunk_in_disk.go__failpoint_stash__ b/pkg/util/chunk/chunk_in_disk.go__failpoint_stash__ deleted file mode 100644 index dadcddefd4584..0000000000000 --- a/pkg/util/chunk/chunk_in_disk.go__failpoint_stash__ +++ /dev/null @@ -1,344 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package chunk - -import ( - "io" - "math/rand" - "os" - "strconv" - "time" - "unsafe" - - errors2 "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/disk" - "github.com/pingcap/tidb/pkg/util/memory" -) - -const byteLen = int64(unsafe.Sizeof(byte(0))) -const intLen = int64(unsafe.Sizeof(int(0))) -const int64Len = int64(unsafe.Sizeof(int64(0))) - -const chkFixedSize = intLen * 4 -const colMetaSize = int64Len * 4 - -const defaultChunkDataInDiskByChunksPath = "defaultChunkDataInDiskByChunksPath" - -// DataInDiskByChunks represents some data stored in temporary disk. -// They can only be restored by chunks. -type DataInDiskByChunks struct { - fieldTypes []*types.FieldType - offsetOfEachChunk []int64 - - totalDataSize int64 - totalRowNum int64 - diskTracker *disk.Tracker // track disk usage. - - dataFile diskFileReaderWriter - - // Write or read data needs this buffer to temporarily store data - buf []byte -} - -// NewDataInDiskByChunks creates a new DataInDiskByChunks with field types. -func NewDataInDiskByChunks(fieldTypes []*types.FieldType) *DataInDiskByChunks { - d := &DataInDiskByChunks{ - fieldTypes: fieldTypes, - totalDataSize: 0, - totalRowNum: 0, - // TODO: set the quota of disk usage. 
- diskTracker: disk.NewTracker(memory.LabelForChunkDataInDiskByChunks, -1), - buf: make([]byte, 0, 4096), - } - return d -} - -func (d *DataInDiskByChunks) initDiskFile() (err error) { - err = disk.CheckAndInitTempDir() - if err != nil { - return - } - err = d.dataFile.initWithFileName(defaultChunkDataInDiskByChunksPath + strconv.Itoa(d.diskTracker.Label())) - return -} - -// GetDiskTracker returns the memory tracker of this List. -func (d *DataInDiskByChunks) GetDiskTracker() *disk.Tracker { - return d.diskTracker -} - -// Add adds a chunk to the DataInDiskByChunks. Caller must make sure the input chk has the same field types. -// Warning: Do not concurrently call this function. -func (d *DataInDiskByChunks) Add(chk *Chunk) (err error) { - if err := injectChunkInDiskRandomError(); err != nil { - return err - } - - if chk.NumRows() == 0 { - return errors2.New("Chunk spilled to disk should have at least 1 row") - } - - if d.dataFile.file == nil { - err = d.initDiskFile() - if err != nil { - return - } - } - - serializedBytesNum := d.serializeDataToBuf(chk) - - var writeNum int - writeNum, err = d.dataFile.write(d.buf) - if err != nil { - return - } - - if int64(writeNum) != serializedBytesNum { - return errors2.New("Some data fail to be spilled to disk") - } - d.offsetOfEachChunk = append(d.offsetOfEachChunk, d.totalDataSize) - d.totalDataSize += serializedBytesNum - d.totalRowNum += int64(chk.NumRows()) - d.dataFile.offWrite += serializedBytesNum - - d.diskTracker.Consume(serializedBytesNum) - return -} - -func (d *DataInDiskByChunks) getChunkSize(chkIdx int) int64 { - totalChunkNum := len(d.offsetOfEachChunk) - if chkIdx == totalChunkNum-1 { - return d.totalDataSize - d.offsetOfEachChunk[chkIdx] - } - return d.offsetOfEachChunk[chkIdx+1] - d.offsetOfEachChunk[chkIdx] -} - -// GetChunk gets a Chunk from the DataInDiskByChunks by chkIdx. -func (d *DataInDiskByChunks) GetChunk(chkIdx int) (*Chunk, error) { - if err := injectChunkInDiskRandomError(); err != nil { - return nil, err - } - - reader := d.dataFile.getSectionReader(d.offsetOfEachChunk[chkIdx]) - chkSize := d.getChunkSize(chkIdx) - - if cap(d.buf) < int(chkSize) { - d.buf = make([]byte, chkSize) - } else { - d.buf = d.buf[:chkSize] - } - - readByteNum, err := io.ReadFull(reader, d.buf) - if err != nil { - return nil, err - } - - if int64(readByteNum) != chkSize { - return nil, errors2.New("Fail to restore the spilled chunk") - } - - chk := NewEmptyChunk(d.fieldTypes) - d.deserializeDataToChunk(chk) - - return chk, nil -} - -// Close releases the disk resource. 
-func (d *DataInDiskByChunks) Close() { - if d.dataFile.file != nil { - d.diskTracker.Consume(-d.diskTracker.BytesConsumed()) - terror.Call(d.dataFile.file.Close) - terror.Log(os.Remove(d.dataFile.file.Name())) - } -} - -func (d *DataInDiskByChunks) serializeColMeta(pos int64, length int64, nullMapSize int64, dataSize int64, offsetSize int64) { - *(*int64)(unsafe.Pointer(&d.buf[pos])) = length - *(*int64)(unsafe.Pointer(&d.buf[pos+int64Len])) = nullMapSize - *(*int64)(unsafe.Pointer(&d.buf[pos+int64Len*2])) = dataSize - *(*int64)(unsafe.Pointer(&d.buf[pos+int64Len*3])) = offsetSize -} - -func (d *DataInDiskByChunks) serializeOffset(pos *int64, offsets []int64, offsetSize int64) { - d.buf = d.buf[:*pos+offsetSize] - for _, offset := range offsets { - *(*int64)(unsafe.Pointer(&d.buf[*pos])) = offset - *pos += int64Len - } -} - -func (d *DataInDiskByChunks) serializeChunkData(pos *int64, chk *Chunk, selSize int64) { - d.buf = d.buf[:chkFixedSize] - *(*int)(unsafe.Pointer(&d.buf[*pos])) = chk.numVirtualRows - *(*int)(unsafe.Pointer(&d.buf[*pos+intLen])) = chk.capacity - *(*int)(unsafe.Pointer(&d.buf[*pos+intLen*2])) = chk.requiredRows - *(*int)(unsafe.Pointer(&d.buf[*pos+intLen*3])) = int(selSize) - *pos += chkFixedSize - - d.buf = d.buf[:*pos+selSize] - - selLen := len(chk.sel) - for i := 0; i < selLen; i++ { - *(*int)(unsafe.Pointer(&d.buf[*pos])) = chk.sel[i] - *pos += intLen - } -} - -func (d *DataInDiskByChunks) serializeColumns(pos *int64, chk *Chunk) { - for _, col := range chk.columns { - d.buf = d.buf[:*pos+colMetaSize] - nullMapSize := int64(len(col.nullBitmap)) * byteLen - dataSize := int64(len(col.data)) * byteLen - offsetSize := int64(len(col.offsets)) * int64Len - d.serializeColMeta(*pos, int64(col.length), nullMapSize, dataSize, offsetSize) - *pos += colMetaSize - - d.buf = append(d.buf, col.nullBitmap...) - d.buf = append(d.buf, col.data...) - *pos += nullMapSize + dataSize - d.serializeOffset(pos, col.offsets, offsetSize) - } -} - -// Serialized format of a chunk: -// chunk data: | numVirtualRows | capacity | requiredRows | selSize | sel... | -// column1 data: | length | nullMapSize | dataSize | offsetSize | nullBitmap... | data... | offsets... | -// column2 data: | length | nullMapSize | dataSize | offsetSize | nullBitmap... | data... | offsets... | -// ... -// columnN data: | length | nullMapSize | dataSize | offsetSize | nullBitmap... | data... | offsets... | -// -// `xxx...` means this is a variable field filled by bytes. 
-func (d *DataInDiskByChunks) serializeDataToBuf(chk *Chunk) int64 { - totalBytes := int64(0) - - // Calculate total memory that buffer needs - selSize := int64(len(chk.sel)) * intLen - totalBytes += chkFixedSize + selSize - for _, col := range chk.columns { - nullMapSize := int64(len(col.nullBitmap)) * byteLen - dataSize := int64(len(col.data)) * byteLen - offsetSize := int64(len(col.offsets)) * int64Len - totalBytes += colMetaSize + nullMapSize + dataSize + offsetSize - } - - if cap(d.buf) < int(totalBytes) { - d.buf = make([]byte, 0, totalBytes) - } - - pos := int64(0) - d.serializeChunkData(&pos, chk, selSize) - d.serializeColumns(&pos, chk) - return totalBytes -} - -func (d *DataInDiskByChunks) deserializeColMeta(pos *int64) (length int64, nullMapSize int64, dataSize int64, offsetSize int64) { - length = *(*int64)(unsafe.Pointer(&d.buf[*pos])) - *pos += int64Len - - nullMapSize = *(*int64)(unsafe.Pointer(&d.buf[*pos])) - *pos += int64Len - - dataSize = *(*int64)(unsafe.Pointer(&d.buf[*pos])) - *pos += int64Len - - offsetSize = *(*int64)(unsafe.Pointer(&d.buf[*pos])) - *pos += int64Len - return -} - -func (d *DataInDiskByChunks) deserializeSel(chk *Chunk, pos *int64, selSize int) { - selLen := int64(selSize) / intLen - chk.sel = make([]int, selLen) - for i := int64(0); i < selLen; i++ { - chk.sel[i] = *(*int)(unsafe.Pointer(&d.buf[*pos])) - *pos += intLen - } -} - -func (d *DataInDiskByChunks) deserializeChunkData(chk *Chunk, pos *int64) { - chk.numVirtualRows = *(*int)(unsafe.Pointer(&d.buf[*pos])) - *pos += intLen - - chk.capacity = *(*int)(unsafe.Pointer(&d.buf[*pos])) - *pos += intLen - - chk.requiredRows = *(*int)(unsafe.Pointer(&d.buf[*pos])) - *pos += intLen - - selSize := *(*int)(unsafe.Pointer(&d.buf[*pos])) - *pos += intLen - if selSize != 0 { - d.deserializeSel(chk, pos, selSize) - } -} - -func (d *DataInDiskByChunks) deserializeOffsets(dst []int64, pos *int64) { - offsetNum := len(dst) - for i := 0; i < offsetNum; i++ { - dst[i] = *(*int64)(unsafe.Pointer(&d.buf[*pos])) - *pos += int64Len - } -} - -func (d *DataInDiskByChunks) deserializeColumns(chk *Chunk, pos *int64) { - for _, col := range chk.columns { - length, nullMapSize, dataSize, offsetSize := d.deserializeColMeta(pos) - col.nullBitmap = make([]byte, nullMapSize) - col.data = make([]byte, dataSize) - col.offsets = make([]int64, offsetSize/int64Len) - - col.length = int(length) - copy(col.nullBitmap, d.buf[*pos:*pos+nullMapSize]) - *pos += nullMapSize - copy(col.data, d.buf[*pos:*pos+dataSize]) - *pos += dataSize - d.deserializeOffsets(col.offsets, pos) - } -} - -func (d *DataInDiskByChunks) deserializeDataToChunk(chk *Chunk) { - pos := int64(0) - d.deserializeChunkData(chk, &pos) - d.deserializeColumns(chk, &pos) -} - -// NumRows returns total spilled row number -func (d *DataInDiskByChunks) NumRows() int64 { - return d.totalRowNum -} - -// NumChunks returns total spilled chunk number -func (d *DataInDiskByChunks) NumChunks() int { - return len(d.offsetOfEachChunk) -} - -func injectChunkInDiskRandomError() error { - var err error - failpoint.Inject("ChunkInDiskError", func(val failpoint.Value) { - if val.(bool) { - randNum := rand.Int31n(10000) - if randNum < 3 { - err = errors2.New("random error is triggered") - } else if randNum < 6 { - delayTime := rand.Int31n(10) + 5 - time.Sleep(time.Duration(delayTime) * time.Millisecond) - } - } - }) - return err -} diff --git a/pkg/util/chunk/row_container.go b/pkg/util/chunk/row_container.go index 572c09b7984ad..9b443569fb05c 100644 --- a/pkg/util/chunk/row_container.go 
+++ b/pkg/util/chunk/row_container.go @@ -164,11 +164,11 @@ func (c *RowContainer) spillToDisk(preSpillError error) { logutil.BgLogger().Error("spill to disk failed", zap.Stack("stack"), zap.Error(err)) } }() - if val, _err_ := failpoint.Eval(_curpkg_("spillToDiskOutOfDiskQuota")); _err_ == nil { + failpoint.Inject("spillToDiskOutOfDiskQuota", func(val failpoint.Value) { if val.(bool) { panic("out of disk quota when spilling") } - } + }) if preSpillError != nil { c.m.records.spillError = preSpillError return @@ -249,11 +249,11 @@ func (c *RowContainer) NumChunks() int { func (c *RowContainer) Add(chk *Chunk) (err error) { c.m.RLock() defer c.m.RUnlock() - if val, _err_ := failpoint.Eval(_curpkg_("testRowContainerDeadLock")); _err_ == nil { + failpoint.Inject("testRowContainerDeadLock", func(val failpoint.Value) { if val.(bool) { time.Sleep(time.Second) } - } + }) if c.alreadySpilled() { if err := c.m.records.spillError; err != nil { return err @@ -559,11 +559,11 @@ func (c *SortedRowContainer) keyColumnsLess(i, j int) bool { c.memTracker.Consume(1) c.timesOfRowCompare = 0 } - if val, _err_ := failpoint.Eval(_curpkg_("SignalCheckpointForSort")); _err_ == nil { + failpoint.Inject("SignalCheckpointForSort", func(val failpoint.Value) { if val.(bool) { c.timesOfRowCompare += 1024 } - } + }) c.timesOfRowCompare++ rowI := c.m.records.inMemory.GetRow(c.ptrM.rowPtrs[i]) rowJ := c.m.records.inMemory.GetRow(c.ptrM.rowPtrs[j]) @@ -598,11 +598,11 @@ func (c *SortedRowContainer) Sort() (ret error) { c.ptrM.rowPtrs = append(c.ptrM.rowPtrs, RowPtr{ChkIdx: uint32(chkIdx), RowIdx: uint32(rowIdx)}) } } - if val, _err_ := failpoint.Eval(_curpkg_("errorDuringSortRowContainer")); _err_ == nil { + failpoint.Inject("errorDuringSortRowContainer", func(val failpoint.Value) { if val.(bool) { panic("sort meet error") } - } + }) sort.Slice(c.ptrM.rowPtrs, c.keyColumnsLess) return } diff --git a/pkg/util/chunk/row_container.go__failpoint_stash__ b/pkg/util/chunk/row_container.go__failpoint_stash__ deleted file mode 100644 index 9b443569fb05c..0000000000000 --- a/pkg/util/chunk/row_container.go__failpoint_stash__ +++ /dev/null @@ -1,691 +0,0 @@ -// Copyright 2018 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package chunk - -import ( - "fmt" - "sort" - "sync" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util/disk" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "go.uber.org/zap" - "golang.org/x/sys/cpu" -) - -// ErrCannotAddBecauseSorted indicate that the SortPartition is sorted and prohibit inserting data. -var ErrCannotAddBecauseSorted = errors.New("can not add because sorted") - -type rowContainerRecord struct { - inMemory *List - inDisk *DataInDiskByRows - // spillError stores the error when spilling. - spillError error -} - -type mutexForRowContainer struct { - // Add cache padding to avoid false sharing issue. 
- _ cpu.CacheLinePad - // RWMutex guarantees spill and get operator for rowContainer is mutually exclusive. - // `rLock` and `wLocks` is introduced to reduce the contention when multiple - // goroutine touch the same rowContainer concurrently. If there are multiple - // goroutines touch the same rowContainer concurrently, it's recommended to - // use RowContainer.ShallowCopyWithNewMutex to build a new RowContainer for - // each goroutine. Thus each goroutine holds its own rLock but share the same - // underlying data, which can reduce the contention on m.rLock remarkably and - // get better performance. - rLock sync.RWMutex - wLocks []*sync.RWMutex - records *rowContainerRecord - _ cpu.CacheLinePad -} - -// Lock locks rw for writing. -func (m *mutexForRowContainer) Lock() { - for _, l := range m.wLocks { - l.Lock() - } -} - -// Unlock unlocks rw for writing. -func (m *mutexForRowContainer) Unlock() { - for _, l := range m.wLocks { - l.Unlock() - } -} - -// RLock locks rw for reading. -func (m *mutexForRowContainer) RLock() { - m.rLock.RLock() -} - -// RUnlock undoes a single RLock call. -func (m *mutexForRowContainer) RUnlock() { - m.rLock.RUnlock() -} - -type spillHelper interface { - SpillToDisk() - hasEnoughDataToSpill(t *memory.Tracker) bool -} - -// RowContainer provides a place for many rows, so many that we might want to spill them into disk. -// nolint:structcheck -type RowContainer struct { - m *mutexForRowContainer - - memTracker *memory.Tracker - diskTracker *disk.Tracker - actionSpill *SpillDiskAction -} - -// NewRowContainer creates a new RowContainer in memory. -func NewRowContainer(fieldType []*types.FieldType, chunkSize int) *RowContainer { - li := NewList(fieldType, chunkSize, chunkSize) - rc := &RowContainer{ - m: &mutexForRowContainer{ - records: &rowContainerRecord{inMemory: li}, - rLock: sync.RWMutex{}, - wLocks: []*sync.RWMutex{}, - }, - memTracker: memory.NewTracker(memory.LabelForRowContainer, -1), - diskTracker: disk.NewTracker(memory.LabelForRowContainer, -1), - } - rc.m.wLocks = append(rc.m.wLocks, &rc.m.rLock) - li.GetMemTracker().AttachTo(rc.GetMemTracker()) - return rc -} - -// ShallowCopyWithNewMutex shallow clones a RowContainer. -// The new RowContainer shares the same underlying data with the old one but -// holds an individual rLock. -func (c *RowContainer) ShallowCopyWithNewMutex() *RowContainer { - newRC := *c - newRC.m = &mutexForRowContainer{ - records: c.m.records, - rLock: sync.RWMutex{}, - wLocks: []*sync.RWMutex{}, - } - c.m.wLocks = append(c.m.wLocks, &newRC.m.rLock) - return &newRC -} - -// SpillToDisk spills data to disk. This function may be called in parallel. -func (c *RowContainer) SpillToDisk() { - c.spillToDisk(nil) -} - -func (*RowContainer) hasEnoughDataToSpill(_ *memory.Tracker) bool { - return true -} - -func (c *RowContainer) spillToDisk(preSpillError error) { - c.m.Lock() - defer c.m.Unlock() - if c.alreadySpilled() { - return - } - // c.actionSpill may be nil when testing SpillToDisk directly. - if c.actionSpill != nil { - if c.actionSpill.getStatus() == spilledYet { - // The rowContainer has been closed. 
- return - } - c.actionSpill.setStatus(spilling) - defer c.actionSpill.cond.Broadcast() - defer c.actionSpill.setStatus(spilledYet) - } - var err error - memory.QueryForceDisk.Add(1) - n := c.m.records.inMemory.NumChunks() - c.m.records.inDisk = NewDataInDiskByRows(c.m.records.inMemory.FieldTypes()) - c.m.records.inDisk.diskTracker.AttachTo(c.diskTracker) - defer func() { - if r := recover(); r != nil { - err := fmt.Errorf("%v", r) - c.m.records.spillError = err - logutil.BgLogger().Error("spill to disk failed", zap.Stack("stack"), zap.Error(err)) - } - }() - failpoint.Inject("spillToDiskOutOfDiskQuota", func(val failpoint.Value) { - if val.(bool) { - panic("out of disk quota when spilling") - } - }) - if preSpillError != nil { - c.m.records.spillError = preSpillError - return - } - for i := 0; i < n; i++ { - chk := c.m.records.inMemory.GetChunk(i) - err = c.m.records.inDisk.Add(chk) - if err != nil { - c.m.records.spillError = err - return - } - c.m.records.inMemory.GetMemTracker().HandleKillSignal() - } - c.m.records.inMemory.Clear() -} - -// Reset resets RowContainer. -func (c *RowContainer) Reset() error { - c.m.Lock() - defer c.m.Unlock() - if c.alreadySpilled() { - err := c.m.records.inDisk.Close() - c.m.records.inDisk = nil - if err != nil { - return err - } - c.actionSpill.Reset() - } else { - c.m.records.inMemory.Reset() - } - return nil -} - -// alreadySpilled indicates that records have spilled out into disk. -func (c *RowContainer) alreadySpilled() bool { - return c.m.records.inDisk != nil -} - -// AlreadySpilledSafeForTest indicates that records have spilled out into disk. It's thread-safe. -// The function is only used for test. -func (c *RowContainer) AlreadySpilledSafeForTest() bool { - c.m.RLock() - defer c.m.RUnlock() - return c.m.records.inDisk != nil -} - -// NumRow returns the number of rows in the container -func (c *RowContainer) NumRow() int { - c.m.RLock() - defer c.m.RUnlock() - if c.alreadySpilled() { - return c.m.records.inDisk.Len() - } - return c.m.records.inMemory.Len() -} - -// NumRowsOfChunk returns the number of rows of a chunk in the DataInDiskByRows. -func (c *RowContainer) NumRowsOfChunk(chkID int) int { - c.m.RLock() - defer c.m.RUnlock() - if c.alreadySpilled() { - return c.m.records.inDisk.NumRowsOfChunk(chkID) - } - return c.m.records.inMemory.NumRowsOfChunk(chkID) -} - -// NumChunks returns the number of chunks in the container. -func (c *RowContainer) NumChunks() int { - c.m.RLock() - defer c.m.RUnlock() - if c.alreadySpilled() { - return c.m.records.inDisk.NumChunks() - } - return c.m.records.inMemory.NumChunks() -} - -// Add appends a chunk into the RowContainer. -func (c *RowContainer) Add(chk *Chunk) (err error) { - c.m.RLock() - defer c.m.RUnlock() - failpoint.Inject("testRowContainerDeadLock", func(val failpoint.Value) { - if val.(bool) { - time.Sleep(time.Second) - } - }) - if c.alreadySpilled() { - if err := c.m.records.spillError; err != nil { - return err - } - err = c.m.records.inDisk.Add(chk) - } else { - c.m.records.inMemory.Add(chk) - } - return -} - -// AllocChunk allocates a new chunk from RowContainer. -func (c *RowContainer) AllocChunk() (chk *Chunk) { - return c.m.records.inMemory.allocChunk() -} - -// GetChunk returns chkIdx th chunk of in memory records. 
-func (c *RowContainer) GetChunk(chkIdx int) (*Chunk, error) { - c.m.RLock() - defer c.m.RUnlock() - if !c.alreadySpilled() { - return c.m.records.inMemory.GetChunk(chkIdx), nil - } - if err := c.m.records.spillError; err != nil { - return nil, err - } - return c.m.records.inDisk.GetChunk(chkIdx) -} - -// GetRow returns the row the ptr pointed to. -func (c *RowContainer) GetRow(ptr RowPtr) (row Row, err error) { - row, _, err = c.GetRowAndAppendToChunkIfInDisk(ptr, nil) - return row, err -} - -// GetRowAndAppendToChunkIfInDisk gets a Row from the RowContainer by RowPtr. If the container has spilled, the row will -// be appended to the chunk. It'll return `nil` chunk if the container hasn't spilled, or it returns an error. -func (c *RowContainer) GetRowAndAppendToChunkIfInDisk(ptr RowPtr, chk *Chunk) (row Row, _ *Chunk, err error) { - c.m.RLock() - defer c.m.RUnlock() - if c.alreadySpilled() { - if err := c.m.records.spillError; err != nil { - return Row{}, nil, err - } - return c.m.records.inDisk.GetRowAndAppendToChunk(ptr, chk) - } - return c.m.records.inMemory.GetRow(ptr), nil, nil -} - -// GetRowAndAlwaysAppendToChunk gets a Row from the RowContainer by RowPtr. Unlike `GetRowAndAppendToChunkIfInDisk`, this -// function always appends the row to the chunk, without considering whether it has spilled. -// It'll return `nil` chunk if it returns an error, or the chunk will be the same with the argument. -func (c *RowContainer) GetRowAndAlwaysAppendToChunk(ptr RowPtr, chk *Chunk) (row Row, _ *Chunk, err error) { - row, retChk, err := c.GetRowAndAppendToChunkIfInDisk(ptr, chk) - if err != nil { - return row, nil, err - } - - if retChk == nil { - // The container hasn't spilled, and the row is not appended to the chunk, so append the chunk explicitly here - chk.AppendRow(row) - } - - return row, chk, nil -} - -// GetMemTracker returns the memory tracker in records, panics if the RowContainer has already spilled. -func (c *RowContainer) GetMemTracker() *memory.Tracker { - return c.memTracker -} - -// GetDiskTracker returns the underlying disk usage tracker in recordsInDisk. -func (c *RowContainer) GetDiskTracker() *disk.Tracker { - return c.diskTracker -} - -// Close close the RowContainer -func (c *RowContainer) Close() (err error) { - c.m.RLock() - defer c.m.RUnlock() - if c.actionSpill != nil { - // Set status to spilledYet to avoid spilling. - c.actionSpill.setStatus(spilledYet) - c.actionSpill.cond.Broadcast() - c.actionSpill.SetFinished() - } - c.memTracker.Detach() - c.diskTracker.Detach() - if c.alreadySpilled() { - err = c.m.records.inDisk.Close() - c.m.records.inDisk = nil - } - c.m.records.inMemory.Clear() - c.m.records.inMemory = nil - return -} - -// ActionSpill returns a SpillDiskAction for spilling over to disk. -func (c *RowContainer) ActionSpill() *SpillDiskAction { - if c.actionSpill == nil { - c.actionSpill = &SpillDiskAction{ - c: c, - baseSpillDiskAction: &baseSpillDiskAction{cond: spillStatusCond{sync.NewCond(new(sync.Mutex)), notSpilled}}, - } - } - return c.actionSpill -} - -// ActionSpillForTest returns a SpillDiskAction for spilling over to disk for test. 
-func (c *RowContainer) ActionSpillForTest() *SpillDiskAction { - c.actionSpill = &SpillDiskAction{ - c: c, - baseSpillDiskAction: &baseSpillDiskAction{ - testSyncInputFunc: func() { - c.actionSpill.testWg.Add(1) - }, - testSyncOutputFunc: func() { - c.actionSpill.testWg.Done() - }, - cond: spillStatusCond{sync.NewCond(new(sync.Mutex)), notSpilled}, - }, - } - return c.actionSpill -} - -type baseSpillDiskAction struct { - memory.BaseOOMAction - m sync.Mutex - once sync.Once - cond spillStatusCond - - // test function only used for test sync. - testSyncInputFunc func() - testSyncOutputFunc func() - testWg sync.WaitGroup -} - -// SpillDiskAction implements memory.ActionOnExceed for chunk.List. If -// the memory quota of a query is exceeded, SpillDiskAction.Action is -// triggered. -type SpillDiskAction struct { - c *RowContainer - *baseSpillDiskAction -} - -// Action sends a signal to trigger spillToDisk method of RowContainer -// and if it is already triggered before, call its fallbackAction. -func (a *SpillDiskAction) Action(t *memory.Tracker) { - a.action(t, a.c) -} - -type spillStatusCond struct { - *sync.Cond - // status indicates different stages for the Action - // notSpilled indicates the rowContainer is not spilled. - // spilling indicates the rowContainer is spilling. - // spilledYet indicates thr rowContainer is spilled. - status spillStatus -} - -type spillStatus uint32 - -const ( - notSpilled spillStatus = iota - spilling - spilledYet -) - -func (a *baseSpillDiskAction) setStatus(status spillStatus) { - a.cond.L.Lock() - defer a.cond.L.Unlock() - a.cond.status = status -} - -func (a *baseSpillDiskAction) getStatus() spillStatus { - a.cond.L.Lock() - defer a.cond.L.Unlock() - return a.cond.status -} - -func (a *baseSpillDiskAction) action(t *memory.Tracker, spillHelper spillHelper) { - a.m.Lock() - defer a.m.Unlock() - - if a.getStatus() == notSpilled && spillHelper.hasEnoughDataToSpill(t) { - a.once.Do(func() { - logutil.BgLogger().Info("memory exceeds quota, spill to disk now.", - zap.Int64("consumed", t.BytesConsumed()), zap.Int64("quota", t.GetBytesLimit())) - if a.testSyncInputFunc != nil { - a.testSyncInputFunc() - go func() { - spillHelper.SpillToDisk() - a.testSyncOutputFunc() - }() - return - } - go spillHelper.SpillToDisk() - }) - return - } - - a.cond.L.Lock() - for a.cond.status == spilling { - a.cond.Wait() - } - a.cond.L.Unlock() - - if !t.CheckExceed() { - return - } - if fallback := a.GetFallback(); fallback != nil { - fallback.Action(t) - } -} - -// Reset resets the status for SpillDiskAction. -func (a *baseSpillDiskAction) Reset() { - a.m.Lock() - defer a.m.Unlock() - a.setStatus(notSpilled) - a.once = sync.Once{} -} - -// GetPriority get the priority of the Action. -func (*baseSpillDiskAction) GetPriority() int64 { - return memory.DefSpillPriority -} - -// WaitForTest waits all goroutine have gone. -func (a *baseSpillDiskAction) WaitForTest() { - a.testWg.Wait() -} - -// SortedRowContainer provides a place for many rows, so many that we might want to sort and spill them into disk. -type SortedRowContainer struct { - *RowContainer - ptrM struct { - sync.RWMutex - // rowPtrs store the chunk index and row index for each row. - // rowPtrs != nil indicates the pointer is initialized and sorted. - // It will get an ErrCannotAddBecauseSorted when trying to insert data if rowPtrs != nil. - rowPtrs []RowPtr - } - - ByItemsDesc []bool - // keyColumns is the column index of the by items. - keyColumns []int - // keyCmpFuncs is used to compare each ByItem. 
- keyCmpFuncs []CompareFunc - - actionSpill *SortAndSpillDiskAction - memTracker *memory.Tracker - - // Sort is a time-consuming operation, we need to set a checkpoint to detect - // the outside signal periodically. - timesOfRowCompare uint -} - -// NewSortedRowContainer creates a new SortedRowContainer in memory. -func NewSortedRowContainer(fieldType []*types.FieldType, chunkSize int, byItemsDesc []bool, - keyColumns []int, keyCmpFuncs []CompareFunc) *SortedRowContainer { - src := SortedRowContainer{RowContainer: NewRowContainer(fieldType, chunkSize), - ByItemsDesc: byItemsDesc, keyColumns: keyColumns, keyCmpFuncs: keyCmpFuncs} - src.memTracker = memory.NewTracker(memory.LabelForRowContainer, -1) - src.RowContainer.GetMemTracker().AttachTo(src.GetMemTracker()) - return &src -} - -// Close close the SortedRowContainer -func (c *SortedRowContainer) Close() error { - c.ptrM.Lock() - defer c.ptrM.Unlock() - c.GetMemTracker().Consume(int64(-8 * c.NumRow())) - c.ptrM.rowPtrs = nil - return c.RowContainer.Close() -} - -func (c *SortedRowContainer) lessRow(rowI, rowJ Row) bool { - for i, colIdx := range c.keyColumns { - cmpFunc := c.keyCmpFuncs[i] - if cmpFunc != nil { - cmp := cmpFunc(rowI, colIdx, rowJ, colIdx) - if c.ByItemsDesc[i] { - cmp = -cmp - } - if cmp < 0 { - return true - } else if cmp > 0 { - return false - } - } - } - return false -} - -// SignalCheckpointForSort indicates the times of row comparation that a signal detection will be triggered. -const SignalCheckpointForSort uint = 10240 - -// keyColumnsLess is the less function for key columns. -func (c *SortedRowContainer) keyColumnsLess(i, j int) bool { - if c.timesOfRowCompare >= SignalCheckpointForSort { - // Trigger Consume for checking the NeedKill signal - c.memTracker.Consume(1) - c.timesOfRowCompare = 0 - } - failpoint.Inject("SignalCheckpointForSort", func(val failpoint.Value) { - if val.(bool) { - c.timesOfRowCompare += 1024 - } - }) - c.timesOfRowCompare++ - rowI := c.m.records.inMemory.GetRow(c.ptrM.rowPtrs[i]) - rowJ := c.m.records.inMemory.GetRow(c.ptrM.rowPtrs[j]) - return c.lessRow(rowI, rowJ) -} - -// Sort inits pointers and sorts the records. -func (c *SortedRowContainer) Sort() (ret error) { - c.ptrM.Lock() - defer c.ptrM.Unlock() - ret = nil - defer func() { - if r := recover(); r != nil { - if err, ok := r.(error); ok { - ret = err - } else { - ret = fmt.Errorf("%v", r) - } - } - }() - if c.ptrM.rowPtrs != nil { - return - } - c.ptrM.rowPtrs = make([]RowPtr, 0, c.NumRow()) // The memory usage has been tracked in SortedRowContainer.Add() function - for chkIdx := 0; chkIdx < c.NumChunks(); chkIdx++ { - rowChk, err := c.GetChunk(chkIdx) - // err must be nil, because the chunk is in memory. - if err != nil { - panic(err) - } - for rowIdx := 0; rowIdx < rowChk.NumRows(); rowIdx++ { - c.ptrM.rowPtrs = append(c.ptrM.rowPtrs, RowPtr{ChkIdx: uint32(chkIdx), RowIdx: uint32(rowIdx)}) - } - } - failpoint.Inject("errorDuringSortRowContainer", func(val failpoint.Value) { - if val.(bool) { - panic("sort meet error") - } - }) - sort.Slice(c.ptrM.rowPtrs, c.keyColumnsLess) - return -} - -// SpillToDisk spills data to disk. This function may be called in parallel. -func (c *SortedRowContainer) SpillToDisk() { - err := c.Sort() - c.RowContainer.spillToDisk(err) -} - -func (c *SortedRowContainer) hasEnoughDataToSpill(t *memory.Tracker) bool { - // Guarantee that each partition size is at least 10% of the threshold, to avoid opening too many files. 
- return c.GetMemTracker().BytesConsumed() > t.GetBytesLimit()/10 -} - -// Add appends a chunk into the SortedRowContainer. -func (c *SortedRowContainer) Add(chk *Chunk) (err error) { - c.ptrM.RLock() - defer c.ptrM.RUnlock() - if c.ptrM.rowPtrs != nil { - return ErrCannotAddBecauseSorted - } - // Consume the memory usage of rowPtrs in advance - c.GetMemTracker().Consume(int64(chk.NumRows() * 8)) - return c.RowContainer.Add(chk) -} - -// GetSortedRow returns the row the idx pointed to. -func (c *SortedRowContainer) GetSortedRow(idx int) (Row, error) { - c.ptrM.RLock() - defer c.ptrM.RUnlock() - ptr := c.ptrM.rowPtrs[idx] - return c.RowContainer.GetRow(ptr) -} - -// GetSortedRowAndAlwaysAppendToChunk returns the row the idx pointed to. -func (c *SortedRowContainer) GetSortedRowAndAlwaysAppendToChunk(idx int, chk *Chunk) (Row, *Chunk, error) { - c.ptrM.RLock() - defer c.ptrM.RUnlock() - ptr := c.ptrM.rowPtrs[idx] - return c.RowContainer.GetRowAndAlwaysAppendToChunk(ptr, chk) -} - -// ActionSpill returns a SortAndSpillDiskAction for sorting and spilling over to disk. -func (c *SortedRowContainer) ActionSpill() *SortAndSpillDiskAction { - if c.actionSpill == nil { - c.actionSpill = &SortAndSpillDiskAction{ - c: c, - baseSpillDiskAction: c.RowContainer.ActionSpill().baseSpillDiskAction, - } - } - return c.actionSpill -} - -// ActionSpillForTest returns a SortAndSpillDiskAction for sorting and spilling over to disk for test. -func (c *SortedRowContainer) ActionSpillForTest() *SortAndSpillDiskAction { - c.actionSpill = &SortAndSpillDiskAction{ - c: c, - baseSpillDiskAction: c.RowContainer.ActionSpillForTest().baseSpillDiskAction, - } - return c.actionSpill -} - -// GetMemTracker return the memory tracker for the sortedRowContainer -func (c *SortedRowContainer) GetMemTracker() *memory.Tracker { - return c.memTracker -} - -// SortAndSpillDiskAction implements memory.ActionOnExceed for chunk.List. If -// the memory quota of a query is exceeded, SortAndSpillDiskAction.Action is -// triggered. -type SortAndSpillDiskAction struct { - c *SortedRowContainer - *baseSpillDiskAction -} - -// Action sends a signal to trigger sortAndSpillToDisk method of RowContainer -// and if it is already triggered before, call its fallbackAction. -func (a *SortAndSpillDiskAction) Action(t *memory.Tracker) { - a.action(t, a.c) -} - -// WaitForTest waits all goroutine have gone. -func (a *SortAndSpillDiskAction) WaitForTest() { - a.testWg.Wait() -} diff --git a/pkg/util/chunk/row_container_reader.go b/pkg/util/chunk/row_container_reader.go index 797934e66e0e7..ca124083079c5 100644 --- a/pkg/util/chunk/row_container_reader.go +++ b/pkg/util/chunk/row_container_reader.go @@ -124,11 +124,11 @@ func (reader *rowContainerReader) startWorker() { for chkIdx := 0; chkIdx < reader.rc.NumChunks(); chkIdx++ { chk, err := reader.rc.GetChunk(chkIdx) - if val, _err_ := failpoint.Eval(_curpkg_("get-chunk-error")); _err_ == nil { + failpoint.Inject("get-chunk-error", func(val failpoint.Value) { if val.(bool) { err = errors.New("fail to get chunk for test") } - } + }) if err != nil { reader.err = err return diff --git a/pkg/util/chunk/row_container_reader.go__failpoint_stash__ b/pkg/util/chunk/row_container_reader.go__failpoint_stash__ deleted file mode 100644 index ca124083079c5..0000000000000 --- a/pkg/util/chunk/row_container_reader.go__failpoint_stash__ +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright 2023 PingCAP, Inc. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package chunk - -import ( - "context" - "runtime" - "sync" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/util/logutil" -) - -// RowContainerReader is a forward-only iterator for the row container. It provides an interface similar to other -// iterators, but it doesn't provide `ReachEnd` function and requires manually closing to release goroutine. -// -// It's recommended to use the following pattern to use it: -// -// for iter := NewRowContainerReader(rc); iter.Current() != iter.End(); iter.Next() { -// ... -// } -// iter.Close() -// if iter.Error() != nil { -// } -type RowContainerReader interface { - // Next returns the next Row. - Next() Row - - // Current returns the current Row. - Current() Row - - // End returns the invalid end Row. - End() Row - - // Error returns none-nil error if anything wrong happens during the iteration. - Error() error - - // Close closes the dumper - Close() -} - -var _ RowContainerReader = &rowContainerReader{} - -// rowContainerReader is a forward-only iterator for the row container -// It will spawn two goroutines for reading chunks from disk, and converting the chunk to rows. The row will only be sent -// to `rowCh` inside only after when the full chunk has been read, to avoid concurrently read/write to the chunk. -// -// TODO: record the memory allocated for the channel and chunks. -type rowContainerReader struct { - // context, cancel and waitgroup are used to stop and wait until all goroutine stops. - ctx context.Context - cancel func() - wg sync.WaitGroup - - rc *RowContainer - - currentRow Row - rowCh chan Row - - // this error will only be set by worker - err error -} - -// Next implements RowContainerReader -func (reader *rowContainerReader) Next() Row { - for row := range reader.rowCh { - reader.currentRow = row - return row - } - reader.currentRow = reader.End() - return reader.End() -} - -// Current implements RowContainerReader -func (reader *rowContainerReader) Current() Row { - return reader.currentRow -} - -// End implements RowContainerReader -func (*rowContainerReader) End() Row { - return Row{} -} - -// Error implements RowContainerReader -func (reader *rowContainerReader) Error() error { - return reader.err -} - -func (reader *rowContainerReader) initializeChannel() { - if reader.rc.NumChunks() == 0 { - reader.rowCh = make(chan Row, 1024) - } else { - assumeChunkSize := reader.rc.NumRowsOfChunk(0) - // To avoid blocking in sending to `rowCh` and don't start reading the next chunk, it'd be better to give it - // a buffer at least larger than a single chunk. Here it's allocated twice the chunk size to leave some margin. 
- reader.rowCh = make(chan Row, 2*assumeChunkSize) - } -} - -// Close implements RowContainerReader -func (reader *rowContainerReader) Close() { - reader.cancel() - reader.wg.Wait() -} - -func (reader *rowContainerReader) startWorker() { - reader.wg.Add(1) - go func() { - defer close(reader.rowCh) - defer reader.wg.Done() - - for chkIdx := 0; chkIdx < reader.rc.NumChunks(); chkIdx++ { - chk, err := reader.rc.GetChunk(chkIdx) - failpoint.Inject("get-chunk-error", func(val failpoint.Value) { - if val.(bool) { - err = errors.New("fail to get chunk for test") - } - }) - if err != nil { - reader.err = err - return - } - - for i := 0; i < chk.NumRows(); i++ { - select { - case reader.rowCh <- chk.GetRow(i): - case <-reader.ctx.Done(): - return - } - } - } - }() -} - -// NewRowContainerReader creates a forward only iterator for row container -func NewRowContainerReader(rc *RowContainer) *rowContainerReader { - ctx, cancel := context.WithCancel(context.Background()) - - reader := &rowContainerReader{ - ctx: ctx, - cancel: cancel, - wg: sync.WaitGroup{}, - - rc: rc, - } - reader.initializeChannel() - reader.startWorker() - reader.Next() - runtime.SetFinalizer(reader, func(reader *rowContainerReader) { - if reader.ctx.Err() == nil { - logutil.BgLogger().Warn("rowContainerReader is closed by finalizer") - reader.Close() - } - }) - - return reader -} diff --git a/pkg/util/codec/binding__failpoint_binding__.go b/pkg/util/codec/binding__failpoint_binding__.go deleted file mode 100644 index 894075781dd35..0000000000000 --- a/pkg/util/codec/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package codec - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/util/codec/decimal.go b/pkg/util/codec/decimal.go index a4a0d80c78f87..7b26feb86b2d8 100644 --- a/pkg/util/codec/decimal.go +++ b/pkg/util/codec/decimal.go @@ -47,11 +47,11 @@ func valueSizeOfDecimal(dec *types.MyDecimal, precision, frac int) (int, error) // DecodeDecimal decodes bytes to decimal. func DecodeDecimal(b []byte) ([]byte, *types.MyDecimal, int, int, error) { - if val, _err_ := failpoint.Eval(_curpkg_("errorInDecodeDecimal")); _err_ == nil { + failpoint.Inject("errorInDecodeDecimal", func(val failpoint.Value) { if val.(bool) { - return b, nil, 0, 0, errors.New("gofail error") + failpoint.Return(b, nil, 0, 0, errors.New("gofail error")) } - } + }) if len(b) < 3 { return b, nil, 0, 0, errors.New("insufficient bytes to decode value") diff --git a/pkg/util/codec/decimal.go__failpoint_stash__ b/pkg/util/codec/decimal.go__failpoint_stash__ deleted file mode 100644 index 7b26feb86b2d8..0000000000000 --- a/pkg/util/codec/decimal.go__failpoint_stash__ +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package codec - -import ( - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/types" -) - -// EncodeDecimal encodes a decimal into a byte slice which can be sorted lexicographically later. -func EncodeDecimal(b []byte, dec *types.MyDecimal, precision, frac int) ([]byte, error) { - if precision == 0 { - precision, frac = dec.PrecisionAndFrac() - } - if frac > mysql.MaxDecimalScale { - frac = mysql.MaxDecimalScale - } - b = append(b, byte(precision), byte(frac)) - b, err := dec.WriteBin(precision, frac, b) - return b, errors.Trace(err) -} - -func valueSizeOfDecimal(dec *types.MyDecimal, precision, frac int) (int, error) { - if precision == 0 { - precision, frac = dec.PrecisionAndFrac() - } - binSize, err := types.DecimalBinSize(precision, frac) - if err != nil { - return 0, err - } - return binSize + 2, nil -} - -// DecodeDecimal decodes bytes to decimal. -func DecodeDecimal(b []byte) ([]byte, *types.MyDecimal, int, int, error) { - failpoint.Inject("errorInDecodeDecimal", func(val failpoint.Value) { - if val.(bool) { - failpoint.Return(b, nil, 0, 0, errors.New("gofail error")) - } - }) - - if len(b) < 3 { - return b, nil, 0, 0, errors.New("insufficient bytes to decode value") - } - precision := int(b[0]) - frac := int(b[1]) - b = b[2:] - dec := new(types.MyDecimal) - binSize, err := dec.FromBin(b, precision, frac) - b = b[binSize:] - if err != nil { - return b, nil, precision, frac, errors.Trace(err) - } - return b, dec, precision, frac, nil -} diff --git a/pkg/util/cpu/binding__failpoint_binding__.go b/pkg/util/cpu/binding__failpoint_binding__.go deleted file mode 100644 index 440594fb462e6..0000000000000 --- a/pkg/util/cpu/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package cpu - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/util/cpu/cpu.go b/pkg/util/cpu/cpu.go index c6a72ae64de22..6d14cbfada58d 100644 --- a/pkg/util/cpu/cpu.go +++ b/pkg/util/cpu/cpu.go @@ -125,8 +125,8 @@ func getCPUTime() (userTimeMillis, sysTimeMillis int64, err error) { // GetCPUCount returns the number of logical CPUs usable by the current process. func GetCPUCount() int { - if val, _err_ := failpoint.Eval(_curpkg_("mockNumCpu")); _err_ == nil { - return val.(int) - } + failpoint.Inject("mockNumCpu", func(val failpoint.Value) { + failpoint.Return(val.(int)) + }) return runtime.GOMAXPROCS(0) } diff --git a/pkg/util/cpu/cpu.go__failpoint_stash__ b/pkg/util/cpu/cpu.go__failpoint_stash__ deleted file mode 100644 index 6d14cbfada58d..0000000000000 --- a/pkg/util/cpu/cpu.go__failpoint_stash__ +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package cpu - -import ( - "os" - "runtime" - "sync" - "time" - - sigar "github.com/cloudfoundry/gosigar" - "github.com/pingcap/failpoint" - "github.com/pingcap/log" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/util/cgroup" - "github.com/pingcap/tidb/pkg/util/mathutil" - "go.uber.org/atomic" - "go.uber.org/zap" -) - -var cpuUsage atomic.Float64 - -// If your kernel is lower than linux 4.7, you cannot get the cpu usage in the container. -var unsupported atomic.Bool - -// GetCPUUsage returns the cpu usage of the current process. -func GetCPUUsage() (float64, bool) { - return cpuUsage.Load(), unsupported.Load() -} - -// Observer is used to observe the cpu usage of the current process. -type Observer struct { - utime int64 - stime int64 - now int64 - exit chan struct{} - cpu mathutil.ExponentialMovingAverage - wg sync.WaitGroup -} - -// NewCPUObserver returns a cpu observer. -func NewCPUObserver() *Observer { - return &Observer{ - exit: make(chan struct{}), - now: time.Now().UnixNano(), - cpu: *mathutil.NewExponentialMovingAverage(0.95, 10), - } -} - -// Start starts the cpu observer. -func (c *Observer) Start() { - _, err := cgroup.GetCgroupCPU() - if err != nil { - unsupported.Store(true) - log.Error("GetCgroupCPU", zap.Error(err)) - return - } - c.wg.Add(1) - go func() { - ticker := time.NewTicker(100 * time.Millisecond) - defer func() { - ticker.Stop() - c.wg.Done() - }() - for { - select { - case <-ticker.C: - curr := c.observe() - c.cpu.Add(curr) - cpuUsage.Store(c.cpu.Get()) - metrics.EMACPUUsageGauge.Set(c.cpu.Get()) - case <-c.exit: - return - } - } - }() -} - -// Stop stops the cpu observer. -func (c *Observer) Stop() { - close(c.exit) - c.wg.Wait() -} - -func (c *Observer) observe() float64 { - user, sys, err := getCPUTime() - if err != nil { - log.Error("getCPUTime", zap.Error(err)) - } - cgroupCPU, _ := cgroup.GetCgroupCPU() - cpuShare := cgroupCPU.CPUShares() - now := time.Now().UnixNano() - dur := float64(now - c.now) - utime := user * 1e6 - stime := sys * 1e6 - urate := float64(utime-c.utime) / dur - srate := float64(stime-c.stime) / dur - c.now = now - c.utime = utime - c.stime = stime - return (srate + urate) / cpuShare -} - -// getCPUTime returns the cumulative user/system time (in ms) since the process start. -func getCPUTime() (userTimeMillis, sysTimeMillis int64, err error) { - pid := os.Getpid() - cpuTime := sigar.ProcTime{} - if err := cpuTime.Get(pid); err != nil { - return 0, 0, err - } - return int64(cpuTime.User), int64(cpuTime.Sys), nil -} - -// GetCPUCount returns the number of logical CPUs usable by the current process. 
-func GetCPUCount() int { - failpoint.Inject("mockNumCpu", func(val failpoint.Value) { - failpoint.Return(val.(int)) - }) - return runtime.GOMAXPROCS(0) -} diff --git a/pkg/util/etcd.go b/pkg/util/etcd.go index 66b28b29ae44f..bff00a8d66428 100644 --- a/pkg/util/etcd.go +++ b/pkg/util/etcd.go @@ -51,21 +51,21 @@ func NewSession(ctx context.Context, logPrefix string, etcdCli *clientv3.Client, return etcdSession, errors.Trace(err) } - if val, _err_ := failpoint.Eval(_curpkg_("closeClient")); _err_ == nil { + failpoint.Inject("closeClient", func(val failpoint.Value) { if val.(bool) { if err := etcdCli.Close(); err != nil { - return etcdSession, errors.Trace(err) + failpoint.Return(etcdSession, errors.Trace(err)) } } - } + }) - if val, _err_ := failpoint.Eval(_curpkg_("closeGrpc")); _err_ == nil { + failpoint.Inject("closeGrpc", func(val failpoint.Value) { if val.(bool) { if err := etcdCli.ActiveConnection().Close(); err != nil { - return etcdSession, errors.Trace(err) + failpoint.Return(etcdSession, errors.Trace(err)) } } - } + }) startTime := time.Now() etcdSession, err = concurrency.NewSession(etcdCli, diff --git a/pkg/util/etcd.go__failpoint_stash__ b/pkg/util/etcd.go__failpoint_stash__ deleted file mode 100644 index bff00a8d66428..0000000000000 --- a/pkg/util/etcd.go__failpoint_stash__ +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright 2020 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package util - -import ( - "context" - "math" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/metrics" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/util/logutil" - clientv3 "go.etcd.io/etcd/client/v3" - "go.etcd.io/etcd/client/v3/concurrency" - "go.uber.org/zap" - "google.golang.org/grpc" -) - -const ( - newSessionRetryInterval = 200 * time.Millisecond - logIntervalCnt = int(3 * time.Second / newSessionRetryInterval) - - // NewSessionDefaultRetryCnt is the default retry times when create new session. - NewSessionDefaultRetryCnt = 3 - // NewSessionRetryUnlimited is the unlimited retry times when create new session. - NewSessionRetryUnlimited = math.MaxInt64 -) - -// NewSession creates a new etcd session. 
-func NewSession(ctx context.Context, logPrefix string, etcdCli *clientv3.Client, retryCnt, ttl int) (*concurrency.Session, error) { - var err error - - var etcdSession *concurrency.Session - failedCnt := 0 - for i := 0; i < retryCnt; i++ { - if err = contextDone(ctx, err); err != nil { - return etcdSession, errors.Trace(err) - } - - failpoint.Inject("closeClient", func(val failpoint.Value) { - if val.(bool) { - if err := etcdCli.Close(); err != nil { - failpoint.Return(etcdSession, errors.Trace(err)) - } - } - }) - - failpoint.Inject("closeGrpc", func(val failpoint.Value) { - if val.(bool) { - if err := etcdCli.ActiveConnection().Close(); err != nil { - failpoint.Return(etcdSession, errors.Trace(err)) - } - } - }) - - startTime := time.Now() - etcdSession, err = concurrency.NewSession(etcdCli, - concurrency.WithTTL(ttl), concurrency.WithContext(ctx)) - metrics.NewSessionHistogram.WithLabelValues(logPrefix, metrics.RetLabel(err)).Observe(time.Since(startTime).Seconds()) - if err == nil { - break - } - if failedCnt%logIntervalCnt == 0 { - logutil.BgLogger().Warn("failed to new session to etcd", zap.String("ownerInfo", logPrefix), zap.Error(err)) - } - - time.Sleep(newSessionRetryInterval) - failedCnt++ - } - return etcdSession, errors.Trace(err) -} - -func contextDone(ctx context.Context, err error) error { - select { - case <-ctx.Done(): - return errors.Trace(ctx.Err()) - default: - } - // Sometime the ctx isn't closed, but the etcd client is closed, - // we need to treat it as if context is done. - // TODO: Make sure ctx is closed with etcd client. - if terror.ErrorEqual(err, context.Canceled) || - terror.ErrorEqual(err, context.DeadlineExceeded) || - terror.ErrorEqual(err, grpc.ErrClientConnClosing) { - return errors.Trace(err) - } - - return nil -} diff --git a/pkg/util/gctuner/binding__failpoint_binding__.go b/pkg/util/gctuner/binding__failpoint_binding__.go deleted file mode 100644 index 58a2d96183692..0000000000000 --- a/pkg/util/gctuner/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package gctuner - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/util/gctuner/memory_limit_tuner.go b/pkg/util/gctuner/memory_limit_tuner.go index 027fa9f5b6ac8..dcf2ccb96ee22 100644 --- a/pkg/util/gctuner/memory_limit_tuner.go +++ b/pkg/util/gctuner/memory_limit_tuner.go @@ -107,17 +107,17 @@ func (t *memoryLimitTuner) tuning() { if intest.InTest { resetInterval = 3 * time.Second } - if val, _err_ := failpoint.Eval(_curpkg_("mockUpdateGlobalVarDuringAdjustPercentage")); _err_ == nil { + failpoint.Inject("mockUpdateGlobalVarDuringAdjustPercentage", func(val failpoint.Value) { if val, ok := val.(bool); val && ok { time.Sleep(300 * time.Millisecond) t.UpdateMemoryLimit() } - } - if val, _err_ := failpoint.Eval(_curpkg_("testMemoryLimitTuner")); _err_ == nil { + }) + failpoint.Inject("testMemoryLimitTuner", func(val failpoint.Value) { if val, ok := val.(bool); val && ok { resetInterval = 1 * time.Second } - } + }) time.Sleep(resetInterval) debug.SetMemoryLimit(t.calcMemoryLimit(t.GetPercentage())) for !t.adjustPercentageInProgress.CompareAndSwap(true, false) { diff --git a/pkg/util/gctuner/memory_limit_tuner.go__failpoint_stash__ b/pkg/util/gctuner/memory_limit_tuner.go__failpoint_stash__ 
deleted file mode 100644 index dcf2ccb96ee22..0000000000000 --- a/pkg/util/gctuner/memory_limit_tuner.go__failpoint_stash__ +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package gctuner - -import ( - "math" - "runtime/debug" - "time" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tidb/pkg/util/memory" - atomicutil "go.uber.org/atomic" -) - -// GlobalMemoryLimitTuner only allow one memory limit tuner in one process -var GlobalMemoryLimitTuner = &memoryLimitTuner{} - -// Go runtime trigger GC when hit memory limit which managed via runtime/debug.SetMemoryLimit. -// So we can change memory limit dynamically to avoid frequent GC when memory usage is greater than the limit. -type memoryLimitTuner struct { - finalizer *finalizer - isValidValueSet atomicutil.Bool - percentage atomicutil.Float64 - adjustPercentageInProgress atomicutil.Bool - serverMemLimitBeforeAdjust atomicutil.Uint64 - percentageBeforeAdjust atomicutil.Float64 - nextGCTriggeredByMemoryLimit atomicutil.Bool - - // The flag to disable memory limit adjust. There might be many tasks need to activate it in future, - // so it is integer type. - adjustDisabled atomicutil.Int64 -} - -// fallbackPercentage indicates the fallback memory limit percentage when turning. -const fallbackPercentage float64 = 1.1 - -var memoryGoroutineCntInTest = *atomicutil.NewInt64(0) - -// WaitMemoryLimitTunerExitInTest is used to wait memory limit tuner exit in test. -func WaitMemoryLimitTunerExitInTest() { - if intest.InTest { - for memoryGoroutineCntInTest.Load() > 0 { - time.Sleep(100 * time.Millisecond) - } - } -} - -// DisableAdjustMemoryLimit makes memoryLimitTuner directly return `initGOMemoryLimitValue` when function `calcMemoryLimit` is called. -func (t *memoryLimitTuner) DisableAdjustMemoryLimit() { - t.adjustDisabled.Add(1) - debug.SetMemoryLimit(initGOMemoryLimitValue) -} - -// EnableAdjustMemoryLimit makes memoryLimitTuner return an adjusted memory limit when function `calcMemoryLimit` is called. -func (t *memoryLimitTuner) EnableAdjustMemoryLimit() { - t.adjustDisabled.Add(-1) - t.UpdateMemoryLimit() -} - -// tuning check the memory nextGC and judge whether this GC is trigger by memory limit. -// Go runtime ensure that it will be called serially. -func (t *memoryLimitTuner) tuning() { - if !t.isValidValueSet.Load() { - return - } - r := memory.ForceReadMemStats() - gogc := util.GetGOGC() - ratio := float64(100+gogc) / 100 - // This `if` checks whether the **last** GC was triggered by MemoryLimit as far as possible. - // If the **last** GC was triggered by MemoryLimit, we'll set MemoryLimit to MAXVALUE to return control back to GOGC - // to avoid frequent GC when memory usage fluctuates above and below MemoryLimit. 
- // The logic we judge whether the **last** GC was triggered by MemoryLimit is as follows: - // suppose `NextGC` = `HeapInUse * (100 + GOGC) / 100)`, - // - If NextGC < MemoryLimit, the **next** GC will **not** be triggered by MemoryLimit thus we do not care about - // why the **last** GC is triggered. And MemoryLimit will not be reset this time. - // - Only if NextGC >= MemoryLimit , the **next** GC will be triggered by MemoryLimit. Thus, we need to reset - // MemoryLimit after the **next** GC happens if needed. - if float64(r.HeapInuse)*ratio > float64(debug.SetMemoryLimit(-1)) { - if t.nextGCTriggeredByMemoryLimit.Load() && t.adjustPercentageInProgress.CompareAndSwap(false, true) { - // It's ok to update `adjustPercentageInProgress`, `serverMemLimitBeforeAdjust` and `percentageBeforeAdjust` not in a transaction. - // The update of memory limit is eventually consistent. - t.serverMemLimitBeforeAdjust.Store(memory.ServerMemoryLimit.Load()) - t.percentageBeforeAdjust.Store(t.GetPercentage()) - go func() { - if intest.InTest { - memoryGoroutineCntInTest.Inc() - defer memoryGoroutineCntInTest.Dec() - } - memory.MemoryLimitGCLast.Store(time.Now()) - memory.MemoryLimitGCTotal.Add(1) - debug.SetMemoryLimit(t.calcMemoryLimit(fallbackPercentage)) - resetInterval := 1 * time.Minute // Wait 1 minute and set back, to avoid frequent GC - if intest.InTest { - resetInterval = 3 * time.Second - } - failpoint.Inject("mockUpdateGlobalVarDuringAdjustPercentage", func(val failpoint.Value) { - if val, ok := val.(bool); val && ok { - time.Sleep(300 * time.Millisecond) - t.UpdateMemoryLimit() - } - }) - failpoint.Inject("testMemoryLimitTuner", func(val failpoint.Value) { - if val, ok := val.(bool); val && ok { - resetInterval = 1 * time.Second - } - }) - time.Sleep(resetInterval) - debug.SetMemoryLimit(t.calcMemoryLimit(t.GetPercentage())) - for !t.adjustPercentageInProgress.CompareAndSwap(true, false) { - continue - } - }() - memory.TriggerMemoryLimitGC.Store(true) - } - t.nextGCTriggeredByMemoryLimit.Store(true) - } else { - t.nextGCTriggeredByMemoryLimit.Store(false) - memory.TriggerMemoryLimitGC.Store(false) - } -} - -// Start starts the memory limit tuner. -func (t *memoryLimitTuner) Start() { - t.finalizer = newFinalizer(t.tuning) // Start tuning -} - -// Stop stops the memory limit tuner. -func (t *memoryLimitTuner) Stop() { - t.finalizer.stop() -} - -// SetPercentage set the percentage for memory limit tuner. -func (t *memoryLimitTuner) SetPercentage(percentage float64) { - t.percentage.Store(percentage) -} - -// GetPercentage get the percentage from memory limit tuner. -func (t *memoryLimitTuner) GetPercentage() float64 { - return t.percentage.Load() -} - -// UpdateMemoryLimit updates the memory limit. -// This function should be called when `tidb_server_memory_limit` or `tidb_server_memory_limit_gc_trigger` is modified. 
-func (t *memoryLimitTuner) UpdateMemoryLimit() { - if t.adjustPercentageInProgress.Load() { - if t.serverMemLimitBeforeAdjust.Load() == memory.ServerMemoryLimit.Load() && t.percentageBeforeAdjust.Load() == t.GetPercentage() { - return - } - } - var memoryLimit = t.calcMemoryLimit(t.GetPercentage()) - if memoryLimit == math.MaxInt64 { - t.isValidValueSet.Store(false) - memoryLimit = initGOMemoryLimitValue - } else { - t.isValidValueSet.Store(true) - } - debug.SetMemoryLimit(memoryLimit) -} - -func (t *memoryLimitTuner) calcMemoryLimit(percentage float64) int64 { - if t.adjustDisabled.Load() > 0 { - return initGOMemoryLimitValue - } - memoryLimit := int64(float64(memory.ServerMemoryLimit.Load()) * percentage) // `tidb_server_memory_limit` * `tidb_server_memory_limit_gc_trigger` - if memoryLimit == 0 { - memoryLimit = math.MaxInt64 - } - return memoryLimit -} - -var initGOMemoryLimitValue int64 - -func init() { - initGOMemoryLimitValue = debug.SetMemoryLimit(-1) - GlobalMemoryLimitTuner.Start() -} diff --git a/pkg/util/memory/binding__failpoint_binding__.go b/pkg/util/memory/binding__failpoint_binding__.go deleted file mode 100644 index 2f327d1d1bb00..0000000000000 --- a/pkg/util/memory/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package memory - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/util/memory/meminfo.go b/pkg/util/memory/meminfo.go index f89304a4629cd..0d7bd68a2bcbe 100644 --- a/pkg/util/memory/meminfo.go +++ b/pkg/util/memory/meminfo.go @@ -36,11 +36,11 @@ var MemUsed func() (uint64, error) // GetMemTotalIgnoreErr returns the total amount of RAM on this system/container. If error occurs, return 0. func GetMemTotalIgnoreErr() uint64 { if memTotal, err := MemTotal(); err == nil { - if val, _err_ := failpoint.Eval(_curpkg_("GetMemTotalError")); _err_ == nil { + failpoint.Inject("GetMemTotalError", func(val failpoint.Value) { if val, ok := val.(bool); val && ok { memTotal = 0 } - } + }) return memTotal } return 0 diff --git a/pkg/util/memory/meminfo.go__failpoint_stash__ b/pkg/util/memory/meminfo.go__failpoint_stash__ deleted file mode 100644 index 0d7bd68a2bcbe..0000000000000 --- a/pkg/util/memory/meminfo.go__failpoint_stash__ +++ /dev/null @@ -1,215 +0,0 @@ -// Copyright 2018 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package memory - -import ( - "sync" - "time" - - "github.com/pingcap/failpoint" - "github.com/pingcap/sysutil" - "github.com/pingcap/tidb/pkg/parser/terror" - "github.com/pingcap/tidb/pkg/util/cgroup" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/shirou/gopsutil/v3/mem" - "go.uber.org/zap" -) - -// MemTotal returns the total amount of RAM on this system -var MemTotal func() (uint64, error) - -// MemUsed returns the total used amount of RAM on this system -var MemUsed func() (uint64, error) - -// GetMemTotalIgnoreErr returns the total amount of RAM on this system/container. If error occurs, return 0. -func GetMemTotalIgnoreErr() uint64 { - if memTotal, err := MemTotal(); err == nil { - failpoint.Inject("GetMemTotalError", func(val failpoint.Value) { - if val, ok := val.(bool); val && ok { - memTotal = 0 - } - }) - return memTotal - } - return 0 -} - -// MemTotalNormal returns the total amount of RAM on this system in non-container environment. -func MemTotalNormal() (uint64, error) { - total, t := memLimit.get() - if time.Since(t) < 60*time.Second { - return total, nil - } - return memTotalNormal() -} - -func memTotalNormal() (uint64, error) { - v, err := mem.VirtualMemory() - if err != nil { - return 0, err - } - memLimit.set(v.Total, time.Now()) - return v.Total, nil -} - -// MemUsedNormal returns the total used amount of RAM on this system in non-container environment. -func MemUsedNormal() (uint64, error) { - used, t := memUsage.get() - if time.Since(t) < 500*time.Millisecond { - return used, nil - } - v, err := mem.VirtualMemory() - if err != nil { - return 0, err - } - memUsage.set(v.Used, time.Now()) - return v.Used, nil -} - -type memInfoCache struct { - updateTime time.Time - mu *sync.RWMutex - mem uint64 -} - -func (c *memInfoCache) get() (memo uint64, t time.Time) { - c.mu.RLock() - defer c.mu.RUnlock() - memo, t = c.mem, c.updateTime - return -} - -func (c *memInfoCache) set(memo uint64, t time.Time) { - c.mu.Lock() - defer c.mu.Unlock() - c.mem, c.updateTime = memo, t -} - -// expiration time is 60s -var memLimit *memInfoCache - -// expiration time is 500ms -var memUsage *memInfoCache - -// expiration time is 500ms -// save the memory usage of the server process -var serverMemUsage *memInfoCache - -// MemTotalCGroup returns the total amount of RAM on this system in container environment. -func MemTotalCGroup() (uint64, error) { - memo, t := memLimit.get() - if time.Since(t) < 60*time.Second { - return memo, nil - } - memo, err := cgroup.GetMemoryLimit() - if err != nil { - return memo, err - } - v, err := mem.VirtualMemory() - if err != nil { - return 0, err - } - memo = min(v.Total, memo) - memLimit.set(memo, time.Now()) - return memo, nil -} - -// MemUsedCGroup returns the total used amount of RAM on this system in container environment. -func MemUsedCGroup() (uint64, error) { - memo, t := memUsage.get() - if time.Since(t) < 500*time.Millisecond { - return memo, nil - } - memo, err := cgroup.GetMemoryUsage() - if err != nil { - return memo, err - } - v, err := mem.VirtualMemory() - if err != nil { - return 0, err - } - memo = min(v.Used, memo) - memUsage.set(memo, time.Now()) - return memo, nil -} - -// it is for test and init. 
-func init() { - if cgroup.InContainer() { - MemTotal = MemTotalCGroup - MemUsed = MemUsedCGroup - sysutil.RegisterGetMemoryCapacity(MemTotalCGroup) - } else { - MemTotal = MemTotalNormal - MemUsed = MemUsedNormal - } - memLimit = &memInfoCache{ - mu: &sync.RWMutex{}, - } - memUsage = &memInfoCache{ - mu: &sync.RWMutex{}, - } - serverMemUsage = &memInfoCache{ - mu: &sync.RWMutex{}, - } - _, err := MemTotal() - terror.MustNil(err) - _, err = MemUsed() - terror.MustNil(err) -} - -// InitMemoryHook initializes the memory hook. -// It is to solve the problem that tidb cannot read cgroup in the systemd. -// so if we are not in the container, we compare the cgroup memory limit and the physical memory, -// the cgroup memory limit is smaller, we use the cgroup memory hook. -func InitMemoryHook() { - if cgroup.InContainer() { - logutil.BgLogger().Info("use cgroup memory hook because TiDB is in the container") - return - } - cgroupValue, err := cgroup.GetMemoryLimit() - if err != nil { - return - } - physicalValue, err := memTotalNormal() - if err != nil { - return - } - if physicalValue > cgroupValue && cgroupValue != 0 { - MemTotal = MemTotalCGroup - MemUsed = MemUsedCGroup - sysutil.RegisterGetMemoryCapacity(MemTotalCGroup) - logutil.BgLogger().Info("use cgroup memory hook", zap.Int64("cgroupMemorySize", int64(cgroupValue)), zap.Int64("physicalMemorySize", int64(physicalValue))) - } else { - logutil.BgLogger().Info("use physical memory hook", zap.Int64("cgroupMemorySize", int64(cgroupValue)), zap.Int64("physicalMemorySize", int64(physicalValue))) - } - _, err = MemTotal() - terror.MustNil(err) - _, err = MemUsed() - terror.MustNil(err) -} - -// InstanceMemUsed returns the memory usage of this TiDB server -func InstanceMemUsed() (uint64, error) { - used, t := serverMemUsage.get() - if time.Since(t) < 500*time.Millisecond { - return used, nil - } - var memoryUsage uint64 - instanceStats := ReadMemStats() - memoryUsage = instanceStats.HeapAlloc - serverMemUsage.set(memoryUsage, time.Now()) - return memoryUsage, nil -} diff --git a/pkg/util/memory/memstats.go b/pkg/util/memory/memstats.go index b0d53f08d59df..4ea192620bee2 100644 --- a/pkg/util/memory/memstats.go +++ b/pkg/util/memory/memstats.go @@ -35,10 +35,10 @@ func ReadMemStats() (memStats *runtime.MemStats) { } else { memStats = ForceReadMemStats() } - if val, _err_ := failpoint.Eval(_curpkg_("ReadMemStats")); _err_ == nil { + failpoint.Inject("ReadMemStats", func(val failpoint.Value) { injectedSize := val.(int) memStats = &runtime.MemStats{HeapInuse: memStats.HeapInuse + uint64(injectedSize)} - } + }) return } diff --git a/pkg/util/memory/memstats.go__failpoint_stash__ b/pkg/util/memory/memstats.go__failpoint_stash__ deleted file mode 100644 index 4ea192620bee2..0000000000000 --- a/pkg/util/memory/memstats.go__failpoint_stash__ +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2018 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package memory - -import ( - "runtime" - "sync/atomic" - "time" - - "github.com/pingcap/failpoint" -) - -var stats atomic.Pointer[globalMstats] - -// ReadMemInterval controls the interval to read memory stats. -const ReadMemInterval = 300 * time.Millisecond - -// ReadMemStats read the mem stats from runtime.ReadMemStats -func ReadMemStats() (memStats *runtime.MemStats) { - s := stats.Load() - if s != nil { - memStats = &s.m - } else { - memStats = ForceReadMemStats() - } - failpoint.Inject("ReadMemStats", func(val failpoint.Value) { - injectedSize := val.(int) - memStats = &runtime.MemStats{HeapInuse: memStats.HeapInuse + uint64(injectedSize)} - }) - return -} - -// ForceReadMemStats is to force read memory stats. -func ForceReadMemStats() *runtime.MemStats { - var g globalMstats - g.ts = time.Now() - runtime.ReadMemStats(&g.m) - stats.Store(&g) - return &g.m -} - -type globalMstats struct { - ts time.Time - m runtime.MemStats -} diff --git a/pkg/util/replayer/binding__failpoint_binding__.go b/pkg/util/replayer/binding__failpoint_binding__.go deleted file mode 100644 index 69f926fdeb093..0000000000000 --- a/pkg/util/replayer/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package replayer - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/util/replayer/replayer.go b/pkg/util/replayer/replayer.go index d4b1c12643a0e..9711479fb6c41 100644 --- a/pkg/util/replayer/replayer.go +++ b/pkg/util/replayer/replayer.go @@ -59,9 +59,9 @@ func GeneratePlanReplayerFileName(isCapture, isContinuesCapture, enableHistorica func generatePlanReplayerFileName(isCapture, isContinuesCapture, enableHistoricalStatsForCapture bool) (string, error) { // Generate key and create zip file time := time.Now().UnixNano() - if val, _err_ := failpoint.Eval(_curpkg_("InjectPlanReplayerFileNameTimeField")); _err_ == nil { + failpoint.Inject("InjectPlanReplayerFileNameTimeField", func(val failpoint.Value) { time = int64(val.(int)) - } + }) b := make([]byte, 16) //nolint: gosec _, err := rand.Read(b) diff --git a/pkg/util/replayer/replayer.go__failpoint_stash__ b/pkg/util/replayer/replayer.go__failpoint_stash__ deleted file mode 100644 index 9711479fb6c41..0000000000000 --- a/pkg/util/replayer/replayer.go__failpoint_stash__ +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package replayer - -import ( - "crypto/rand" - "encoding/base64" - "fmt" - "os" - "path/filepath" - "time" - - "github.com/pingcap/errors" - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/config" -) - -// PlanReplayerTaskKey indicates key of a plan replayer task -type PlanReplayerTaskKey struct { - SQLDigest string - PlanDigest string -} - -// GeneratePlanReplayerFile generates plan replayer file -func GeneratePlanReplayerFile(isCapture, isContinuesCapture, enableHistoricalStatsForCapture bool) (*os.File, string, error) { - path := GetPlanReplayerDirName() - err := os.MkdirAll(path, os.ModePerm) - if err != nil { - return nil, "", errors.AddStack(err) - } - fileName, err := generatePlanReplayerFileName(isCapture, isContinuesCapture, enableHistoricalStatsForCapture) - if err != nil { - return nil, "", errors.AddStack(err) - } - zf, err := os.Create(filepath.Join(path, fileName)) - if err != nil { - return nil, "", errors.AddStack(err) - } - return zf, fileName, err -} - -// GeneratePlanReplayerFileName generates plan replayer capture task name -func GeneratePlanReplayerFileName(isCapture, isContinuesCapture, enableHistoricalStatsForCapture bool) (string, error) { - return generatePlanReplayerFileName(isCapture, isContinuesCapture, enableHistoricalStatsForCapture) -} - -func generatePlanReplayerFileName(isCapture, isContinuesCapture, enableHistoricalStatsForCapture bool) (string, error) { - // Generate key and create zip file - time := time.Now().UnixNano() - failpoint.Inject("InjectPlanReplayerFileNameTimeField", func(val failpoint.Value) { - time = int64(val.(int)) - }) - b := make([]byte, 16) - //nolint: gosec - _, err := rand.Read(b) - if err != nil { - return "", err - } - key := base64.URLEncoding.EncodeToString(b) - // "capture_replayer" in filename has special meaning for the /plan_replayer/dump/ HTTP handler - if isContinuesCapture || isCapture && enableHistoricalStatsForCapture { - return fmt.Sprintf("capture_replayer_%v_%v.zip", key, time), nil - } - if isCapture && !enableHistoricalStatsForCapture { - return fmt.Sprintf("capture_normal_replayer_%v_%v.zip", key, time), nil - } - return fmt.Sprintf("replayer_%v_%v.zip", key, time), nil -} - -// GetPlanReplayerDirName returns plan replayer directory path. -// The path is related to the process id. 
-func GetPlanReplayerDirName() string { - tidbLogDir := filepath.Dir(config.GetGlobalConfig().Log.File.Filename) - return filepath.Join(tidbLogDir, "replayer") -} diff --git a/pkg/util/servermemorylimit/binding__failpoint_binding__.go b/pkg/util/servermemorylimit/binding__failpoint_binding__.go deleted file mode 100644 index a3d497c09b313..0000000000000 --- a/pkg/util/servermemorylimit/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package servermemorylimit - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/util/servermemorylimit/servermemorylimit.go b/pkg/util/servermemorylimit/servermemorylimit.go index b7bb2c01b61fa..47022d23c6635 100644 --- a/pkg/util/servermemorylimit/servermemorylimit.go +++ b/pkg/util/servermemorylimit/servermemorylimit.go @@ -136,11 +136,11 @@ func killSessIfNeeded(s *sessionToBeKilled, bt uint64, sm util.SessionManager) { if bt == 0 { return } - if val, _err_ := failpoint.Eval(_curpkg_("issue42662_2")); _err_ == nil { + failpoint.Inject("issue42662_2", func(val failpoint.Value) { if val.(bool) { bt = 1 } - } + }) instanceStats := memory.ReadMemStats() if instanceStats.HeapInuse > MemoryMaxUsed.Load() { MemoryMaxUsed.Store(instanceStats.HeapInuse) diff --git a/pkg/util/servermemorylimit/servermemorylimit.go__failpoint_stash__ b/pkg/util/servermemorylimit/servermemorylimit.go__failpoint_stash__ deleted file mode 100644 index 47022d23c6635..0000000000000 --- a/pkg/util/servermemorylimit/servermemorylimit.go__failpoint_stash__ +++ /dev/null @@ -1,264 +0,0 @@ -// Copyright 2022 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package servermemorylimit - -import ( - "fmt" - "runtime" - "sync" - "sync/atomic" - "time" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/types" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/memory" - "github.com/pingcap/tidb/pkg/util/sqlkiller" - atomicutil "go.uber.org/atomic" - "go.uber.org/zap" -) - -// Process global Observation indicators for memory limit. -var ( - MemoryMaxUsed = atomicutil.NewUint64(0) - SessionKillLast = atomicutil.NewTime(time.Time{}) - SessionKillTotal = atomicutil.NewInt64(0) - IsKilling = atomicutil.NewBool(false) - GlobalMemoryOpsHistoryManager = &memoryOpsHistoryManager{} -) - -// Handle is the handler for server memory limit. -type Handle struct { - exitCh chan struct{} - sm atomic.Value -} - -// NewServerMemoryLimitHandle builds a new server memory limit handler. 
-func NewServerMemoryLimitHandle(exitCh chan struct{}) *Handle { - return &Handle{exitCh: exitCh} -} - -// SetSessionManager sets the SessionManager which is used to fetching the info -// of all active sessions. -func (smqh *Handle) SetSessionManager(sm util.SessionManager) *Handle { - smqh.sm.Store(sm) - return smqh -} - -// Run starts a server memory limit checker goroutine at the start time of the server. -// This goroutine will obtain the `heapInuse` of Golang runtime periodically and compare it with `tidb_server_memory_limit`. -// When `heapInuse` is greater than `tidb_server_memory_limit`, it will set the `needKill` flag of `MemUsageTop1Tracker`. -// When the corresponding SQL try to acquire more memory(next Tracker.Consume() call), it will trigger panic and exit. -// When this goroutine detects the `needKill` SQL has exited successfully, it will immediately trigger runtime.GC() to release memory resources. -func (smqh *Handle) Run() { - tickInterval := time.Millisecond * time.Duration(100) - ticker := time.NewTicker(tickInterval) - defer ticker.Stop() - sm := smqh.sm.Load().(util.SessionManager) - sessionToBeKilled := &sessionToBeKilled{} - for { - select { - case <-ticker.C: - killSessIfNeeded(sessionToBeKilled, memory.ServerMemoryLimit.Load(), sm) - case <-smqh.exitCh: - return - } - } -} - -type sessionToBeKilled struct { - isKilling bool - sqlStartTime time.Time - sessionID uint64 - sessionTracker *memory.Tracker - - killStartTime time.Time - lastLogTime time.Time -} - -func (s *sessionToBeKilled) reset() { - s.isKilling = false - s.sqlStartTime = time.Time{} - s.sessionID = 0 - s.sessionTracker = nil - s.killStartTime = time.Time{} - s.lastLogTime = time.Time{} -} - -func killSessIfNeeded(s *sessionToBeKilled, bt uint64, sm util.SessionManager) { - if s.isKilling { - if info, ok := sm.GetProcessInfo(s.sessionID); ok { - if info.Time == s.sqlStartTime { - if time.Since(s.lastLogTime) > 5*time.Second { - logutil.BgLogger().Warn(fmt.Sprintf("global memory controller failed to kill the top-consumer in %ds", - time.Since(s.killStartTime)/time.Second), - zap.Uint64("conn", info.ID), - zap.String("sql digest", info.Digest), - zap.String("sql text", fmt.Sprintf("%.100v", info.Info)), - zap.Int64("sql memory usage", info.MemTracker.BytesConsumed())) - s.lastLogTime = time.Now() - - if seconds := time.Since(s.killStartTime) / time.Second; seconds >= 60 { - // If the SQL cannot be terminated after 60 seconds, it may be stuck in the network stack while writing packets to the client, - // encountering some bugs that cause it to hang, or failing to detect the kill signal. - // In this case, the resources can be reclaimed by calling the `Finish` method, and then we can start looking for the next SQL with the largest memory usage. - logutil.BgLogger().Warn(fmt.Sprintf("global memory controller failed to kill the top-consumer in %d seconds. 
Attempting to force close the executors.", seconds)) - s.sessionTracker.Killer.FinishResultSet() - goto Succ - } - } - return - } - } - Succ: - s.reset() - IsKilling.Store(false) - memory.MemUsageTop1Tracker.CompareAndSwap(s.sessionTracker, nil) - //nolint: all_revive,revive - runtime.GC() - logutil.BgLogger().Warn("global memory controller killed the top1 memory consumer successfully") - } - - if bt == 0 { - return - } - failpoint.Inject("issue42662_2", func(val failpoint.Value) { - if val.(bool) { - bt = 1 - } - }) - instanceStats := memory.ReadMemStats() - if instanceStats.HeapInuse > MemoryMaxUsed.Load() { - MemoryMaxUsed.Store(instanceStats.HeapInuse) - } - limitSessMinSize := memory.ServerMemoryLimitSessMinSize.Load() - if instanceStats.HeapInuse > bt { - t := memory.MemUsageTop1Tracker.Load() - if t != nil { - sessionID := t.SessionID.Load() - memUsage := t.BytesConsumed() - // If the memory usage of the top1 session is less than tidb_server_memory_limit_sess_min_size, we do not need to kill it. - if uint64(memUsage) < limitSessMinSize { - memory.MemUsageTop1Tracker.CompareAndSwap(t, nil) - t = nil - } else if info, ok := sm.GetProcessInfo(sessionID); ok { - logutil.BgLogger().Warn("global memory controller tries to kill the top1 memory consumer", - zap.Uint64("conn", info.ID), - zap.String("sql digest", info.Digest), - zap.String("sql text", fmt.Sprintf("%.100v", info.Info)), - zap.Uint64("tidb_server_memory_limit", bt), - zap.Uint64("heap inuse", instanceStats.HeapInuse), - zap.Int64("sql memory usage", info.MemTracker.BytesConsumed()), - ) - s.sessionID = sessionID - s.sqlStartTime = info.Time - s.isKilling = true - s.sessionTracker = t - t.Killer.SendKillSignal(sqlkiller.ServerMemoryExceeded) - - killTime := time.Now() - SessionKillTotal.Add(1) - SessionKillLast.Store(killTime) - IsKilling.Store(true) - GlobalMemoryOpsHistoryManager.recordOne(info, killTime, bt, instanceStats.HeapInuse) - s.lastLogTime = time.Now() - s.killStartTime = time.Now() - } - } - // If no one larger than tidb_server_memory_limit_sess_min_size is found, we will not kill any one. 
- if t == nil { - if s.lastLogTime.IsZero() { - s.lastLogTime = time.Now() - } - if time.Since(s.lastLogTime) < 5*time.Second { - return - } - logutil.BgLogger().Warn("global memory controller tries to kill the top1 memory consumer, but no one larger than tidb_server_memory_limit_sess_min_size is found", zap.Uint64("tidb_server_memory_limit_sess_min_size", limitSessMinSize)) - s.lastLogTime = time.Now() - } - } -} - -type memoryOpsHistoryManager struct { - mu sync.Mutex - infos []memoryOpsHistory - offsets int -} - -type memoryOpsHistory struct { - killTime time.Time - memoryLimit uint64 - memoryCurrent uint64 - processInfoDatum []types.Datum // id,user,host,db,command,time,state,info,digest,mem,disk,txnStart -} - -func (m *memoryOpsHistoryManager) init() { - m.infos = make([]memoryOpsHistory, 50) - m.offsets = 0 -} - -func (m *memoryOpsHistoryManager) recordOne(info *util.ProcessInfo, killTime time.Time, memoryLimit uint64, memoryCurrent uint64) { - m.mu.Lock() - defer m.mu.Unlock() - op := memoryOpsHistory{killTime: killTime, memoryLimit: memoryLimit, memoryCurrent: memoryCurrent, processInfoDatum: types.MakeDatums(info.ToRow(time.UTC)...)} - sqlInfo := op.processInfoDatum[7] - sqlInfo.SetString(fmt.Sprintf("%.256v", sqlInfo.GetString()), mysql.DefaultCollationName) // Truncated - // Only record the last 50 history ops - m.infos[m.offsets] = op - m.offsets++ - if m.offsets >= 50 { - m.offsets = 0 - } -} - -func (m *memoryOpsHistoryManager) GetRows() [][]types.Datum { - m.mu.Lock() - defer m.mu.Unlock() - rows := make([][]types.Datum, 0, len(m.infos)) - getRowFromInfo := func(info memoryOpsHistory) { - killTime := types.NewTime(types.FromGoTime(info.killTime), mysql.TypeDatetime, 0) - op := "SessionKill" - rows = append(rows, []types.Datum{ - types.NewDatum(killTime), // TIME - types.NewDatum(op), // OPS - types.NewDatum(info.memoryLimit), // MEMORY_LIMIT - types.NewDatum(info.memoryCurrent), // MEMORY_CURRENT - info.processInfoDatum[0], // PROCESSID - info.processInfoDatum[9], // MEM - info.processInfoDatum[10], // DISK - info.processInfoDatum[2], // CLIENT - info.processInfoDatum[3], // DB - info.processInfoDatum[1], // USER - info.processInfoDatum[8], // SQL_DIGEST - info.processInfoDatum[7], // SQL_TEXT - }) - } - var zeroTime = time.Time{} - for i := 0; i < len(m.infos); i++ { - pos := (m.offsets + i) % len(m.infos) - info := m.infos[pos] - if info.killTime.Equal(zeroTime) { - continue - } - getRowFromInfo(info) - } - return rows -} - -func init() { - GlobalMemoryOpsHistoryManager.init() -} diff --git a/pkg/util/session_pool.go b/pkg/util/session_pool.go index f233b04d4888d..95f5dd9515b43 100644 --- a/pkg/util/session_pool.go +++ b/pkg/util/session_pool.go @@ -66,9 +66,9 @@ func (p *pool) Get() (resource pools.Resource, err error) { } // Put the internal session to the map of SessionManager - if _, _err_ := failpoint.Eval(_curpkg_("mockSessionPoolReturnError")); _err_ == nil { + failpoint.Inject("mockSessionPoolReturnError", func() { err = errors.New("mockSessionPoolReturnError") - } + }) if err == nil && p.getCallback != nil { p.getCallback(resource) diff --git a/pkg/util/session_pool.go__failpoint_stash__ b/pkg/util/session_pool.go__failpoint_stash__ deleted file mode 100644 index 95f5dd9515b43..0000000000000 --- a/pkg/util/session_pool.go__failpoint_stash__ +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright 2024 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package util - -import ( - "errors" - "sync" - - "github.com/ngaut/pools" - "github.com/pingcap/failpoint" -) - -// SessionPool is a recyclable resource pool for the session. -type SessionPool interface { - Get() (pools.Resource, error) - Put(pools.Resource) - Close() -} - -// resourceCallback is a helper function to be triggered after Get/Put call. -type resourceCallback func(pools.Resource) - -type pool struct { - resources chan pools.Resource - factory pools.Factory - mu struct { - sync.RWMutex - closed bool - } - getCallback resourceCallback - putCallback resourceCallback -} - -// NewSessionPool creates a new session pool with the given capacity and factory function. -func NewSessionPool(capacity int, factory pools.Factory, getCallback, putCallback resourceCallback) SessionPool { - return &pool{ - resources: make(chan pools.Resource, capacity), - factory: factory, - getCallback: getCallback, - putCallback: putCallback, - } -} - -// Get gets a session from the session pool. -func (p *pool) Get() (resource pools.Resource, err error) { - var ok bool - select { - case resource, ok = <-p.resources: - if !ok { - err = errors.New("session pool closed") - } - default: - resource, err = p.factory() - } - - // Put the internal session to the map of SessionManager - failpoint.Inject("mockSessionPoolReturnError", func() { - err = errors.New("mockSessionPoolReturnError") - }) - - if err == nil && p.getCallback != nil { - p.getCallback(resource) - } - - return -} - -// Put puts the session back to the pool. -func (p *pool) Put(resource pools.Resource) { - p.mu.RLock() - defer p.mu.RUnlock() - if p.putCallback != nil { - p.putCallback(resource) - } - if p.mu.closed { - resource.Close() - return - } - - select { - case p.resources <- resource: - default: - resource.Close() - } -} - -// Close closes the pool to release all resources. -func (p *pool) Close() { - p.mu.Lock() - if p.mu.closed { - p.mu.Unlock() - return - } - p.mu.closed = true - close(p.resources) - p.mu.Unlock() - - for r := range p.resources { - r.Close() - } -} diff --git a/pkg/util/sli/binding__failpoint_binding__.go b/pkg/util/sli/binding__failpoint_binding__.go deleted file mode 100644 index c9eff40605613..0000000000000 --- a/pkg/util/sli/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package sli - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/util/sli/sli.go b/pkg/util/sli/sli.go index 1917255c28585..4acdee79e5e09 100644 --- a/pkg/util/sli/sli.go +++ b/pkg/util/sli/sli.go @@ -50,9 +50,9 @@ func (t *TxnWriteThroughputSLI) FinishExecuteStmt(cost time.Duration, affectRow t.reportMetric() // Skip reset for test. 
- if _, _err_ := failpoint.Eval(_curpkg_("CheckTxnWriteThroughput")); _err_ == nil { - return - } + failpoint.Inject("CheckTxnWriteThroughput", func() { + failpoint.Return() + }) // Reset for next transaction. t.Reset() diff --git a/pkg/util/sli/sli.go__failpoint_stash__ b/pkg/util/sli/sli.go__failpoint_stash__ deleted file mode 100644 index 4acdee79e5e09..0000000000000 --- a/pkg/util/sli/sli.go__failpoint_stash__ +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright 2021 PingCAP, Inc. - -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sli - -import ( - "fmt" - "time" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/metrics" -) - -// TxnWriteThroughputSLI uses to report transaction write throughput metrics for SLI. -type TxnWriteThroughputSLI struct { - invalid bool - affectRow uint64 - writeSize int - readKeys int - writeKeys int - writeTime time.Duration -} - -// FinishExecuteStmt records the cost for write statement which affect rows more than 0. -// And report metrics when the transaction is committed. -func (t *TxnWriteThroughputSLI) FinishExecuteStmt(cost time.Duration, affectRow uint64, inTxn bool) { - if affectRow > 0 { - t.writeTime += cost - t.affectRow += affectRow - } - - // Currently not in transaction means the last transaction is finish, should report metrics and reset data. - if !inTxn { - if affectRow == 0 { - // AffectRows is 0 when statement is commit. - t.writeTime += cost - } - // Report metrics after commit this transaction. - t.reportMetric() - - // Skip reset for test. - failpoint.Inject("CheckTxnWriteThroughput", func() { - failpoint.Return() - }) - - // Reset for next transaction. - t.Reset() - } -} - -// AddReadKeys adds the read keys. -func (t *TxnWriteThroughputSLI) AddReadKeys(readKeys int64) { - t.readKeys += int(readKeys) -} - -// AddTxnWriteSize adds the transaction write size and keys. -func (t *TxnWriteThroughputSLI) AddTxnWriteSize(size, keys int) { - t.writeSize += size - t.writeKeys += keys -} - -func (t *TxnWriteThroughputSLI) reportMetric() { - if t.IsInvalid() { - return - } - if t.IsSmallTxn() { - metrics.SmallTxnWriteDuration.Observe(t.writeTime.Seconds()) - } else { - metrics.TxnWriteThroughput.Observe(float64(t.writeSize) / t.writeTime.Seconds()) - } -} - -// SetInvalid marks this transaction is invalid to report SLI metrics. -func (t *TxnWriteThroughputSLI) SetInvalid() { - t.invalid = true -} - -// IsInvalid checks the transaction is valid to report SLI metrics. Currently, the following case will cause invalid: -// 1. The transaction contains `insert|replace into ... select ... from ...` statement. -// 2. The write SQL statement has more read keys than write keys. -func (t *TxnWriteThroughputSLI) IsInvalid() bool { - return t.invalid || t.readKeys > t.writeKeys || t.writeSize == 0 || t.writeTime == 0 -} - -const ( - smallTxnAffectRow = 20 - smallTxnSize = 1 * 1024 * 1024 // 1MB -) - -// IsSmallTxn exports for testing. 
-func (t *TxnWriteThroughputSLI) IsSmallTxn() bool { - return t.affectRow <= smallTxnAffectRow && t.writeSize <= smallTxnSize -} - -// Reset exports for testing. -func (t *TxnWriteThroughputSLI) Reset() { - t.invalid = false - t.affectRow = 0 - t.writeSize = 0 - t.readKeys = 0 - t.writeKeys = 0 - t.writeTime = 0 -} - -// String exports for testing. -func (t *TxnWriteThroughputSLI) String() string { - return fmt.Sprintf("invalid: %v, affectRow: %v, writeSize: %v, readKeys: %v, writeKeys: %v, writeTime: %v", - t.invalid, t.affectRow, t.writeSize, t.readKeys, t.writeKeys, t.writeTime.String()) -} diff --git a/pkg/util/sqlkiller/binding__failpoint_binding__.go b/pkg/util/sqlkiller/binding__failpoint_binding__.go deleted file mode 100644 index 02be0a09bb3d4..0000000000000 --- a/pkg/util/sqlkiller/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package sqlkiller - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/util/sqlkiller/sqlkiller.go b/pkg/util/sqlkiller/sqlkiller.go index bb3553f77fa8d..06782653a5d05 100644 --- a/pkg/util/sqlkiller/sqlkiller.go +++ b/pkg/util/sqlkiller/sqlkiller.go @@ -108,7 +108,7 @@ func (killer *SQLKiller) ClearFinishFunc() { // HandleSignal handles the kill signal and return the error. func (killer *SQLKiller) HandleSignal() error { - if val, _err_ := failpoint.Eval(_curpkg_("randomPanic")); _err_ == nil { + failpoint.Inject("randomPanic", func(val failpoint.Value) { if p, ok := val.(int); ok { if rand.Float64() > (float64)(p)/1000 { if killer.ConnID != 0 { @@ -117,7 +117,7 @@ func (killer *SQLKiller) HandleSignal() error { } } } - } + }) status := atomic.LoadUint32(&killer.Signal) err := killer.getKillError(status) if status == ServerMemoryExceeded { diff --git a/pkg/util/sqlkiller/sqlkiller.go__failpoint_stash__ b/pkg/util/sqlkiller/sqlkiller.go__failpoint_stash__ deleted file mode 100644 index 06782653a5d05..0000000000000 --- a/pkg/util/sqlkiller/sqlkiller.go__failpoint_stash__ +++ /dev/null @@ -1,136 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlkiller - -import ( - "math/rand" - "sync" - "sync/atomic" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/util/dbterror/exeerrors" - "github.com/pingcap/tidb/pkg/util/logutil" - "go.uber.org/zap" -) - -type killSignal = uint32 - -// KillSignal types. -const ( - UnspecifiedKillSignal killSignal = iota - QueryInterrupted - MaxExecTimeExceeded - QueryMemoryExceeded - ServerMemoryExceeded - // When you add a new signal, you should also modify store/driver/error/ToTidbErr, - // so that errors in client can be correctly converted to tidb errors. -) - -// SQLKiller is used to kill a query. 
-type SQLKiller struct { - Signal killSignal - ConnID uint64 - // FinishFuncLock is used to ensure that Finish is not called and modified at the same time. - // An external call to the Finish function only allows when the main goroutine to be in the writeResultSet process. - // When the main goroutine exits the writeResultSet process, the Finish function will be cleared. - FinishFuncLock sync.Mutex - Finish func() - // InWriteResultSet is used to indicate whether the query is currently calling clientConn.writeResultSet(). - // If the query is in writeResultSet and Finish() can acquire rs.finishLock, we can assume the query is waiting for the client to receive data from the server over network I/O. - InWriteResultSet atomic.Bool -} - -// SendKillSignal sends a kill signal to the query. -func (killer *SQLKiller) SendKillSignal(reason killSignal) { - if atomic.CompareAndSwapUint32(&killer.Signal, 0, reason) { - status := atomic.LoadUint32(&killer.Signal) - err := killer.getKillError(status) - logutil.BgLogger().Warn("kill initiated", zap.Uint64("connection ID", killer.ConnID), zap.String("reason", err.Error())) - } -} - -// GetKillSignal gets the kill signal. -func (killer *SQLKiller) GetKillSignal() killSignal { - return atomic.LoadUint32(&killer.Signal) -} - -// getKillError gets the error according to the kill signal. -func (killer *SQLKiller) getKillError(status killSignal) error { - switch status { - case QueryInterrupted: - return exeerrors.ErrQueryInterrupted.GenWithStackByArgs() - case MaxExecTimeExceeded: - return exeerrors.ErrMaxExecTimeExceeded.GenWithStackByArgs() - case QueryMemoryExceeded: - return exeerrors.ErrMemoryExceedForQuery.GenWithStackByArgs(killer.ConnID) - case ServerMemoryExceeded: - return exeerrors.ErrMemoryExceedForInstance.GenWithStackByArgs(killer.ConnID) - } - return nil -} - -// FinishResultSet is used to close the result set. -// If a kill signal is sent but the SQL query is stuck in the network stack while writing packets to the client, -// encountering some bugs that cause it to hang, or failing to detect the kill signal, we can call Finish to release resources used during the SQL execution process. -func (killer *SQLKiller) FinishResultSet() { - killer.FinishFuncLock.Lock() - defer killer.FinishFuncLock.Unlock() - if killer.Finish != nil { - killer.Finish() - } -} - -// SetFinishFunc sets the finish function. -func (killer *SQLKiller) SetFinishFunc(fn func()) { - killer.FinishFuncLock.Lock() - defer killer.FinishFuncLock.Unlock() - killer.Finish = fn -} - -// ClearFinishFunc clears the finish function.1 -func (killer *SQLKiller) ClearFinishFunc() { - killer.FinishFuncLock.Lock() - defer killer.FinishFuncLock.Unlock() - killer.Finish = nil -} - -// HandleSignal handles the kill signal and return the error. -func (killer *SQLKiller) HandleSignal() error { - failpoint.Inject("randomPanic", func(val failpoint.Value) { - if p, ok := val.(int); ok { - if rand.Float64() > (float64)(p)/1000 { - if killer.ConnID != 0 { - targetStatus := rand.Int31n(5) - atomic.StoreUint32(&killer.Signal, uint32(targetStatus)) - } - } - } - }) - status := atomic.LoadUint32(&killer.Signal) - err := killer.getKillError(status) - if status == ServerMemoryExceeded { - logutil.BgLogger().Warn("global memory controller, NeedKill signal is received successfully", - zap.Uint64("conn", killer.ConnID)) - } - return err -} - -// Reset resets the SqlKiller. 
-func (killer *SQLKiller) Reset() { - if atomic.LoadUint32(&killer.Signal) != 0 { - logutil.BgLogger().Warn("kill finished", zap.Uint64("conn", killer.ConnID)) - } - atomic.StoreUint32(&killer.Signal, 0) -} diff --git a/pkg/util/stmtsummary/binding__failpoint_binding__.go b/pkg/util/stmtsummary/binding__failpoint_binding__.go deleted file mode 100644 index bc46da5472193..0000000000000 --- a/pkg/util/stmtsummary/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package stmtsummary - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/util/stmtsummary/statement_summary.go b/pkg/util/stmtsummary/statement_summary.go index 636cf4fa991fd..e9dc133b1c46e 100644 --- a/pkg/util/stmtsummary/statement_summary.go +++ b/pkg/util/stmtsummary/statement_summary.go @@ -297,7 +297,7 @@ func (ssMap *stmtSummaryByDigestMap) AddStatement(sei *StmtExecInfo) { // All times are counted in seconds. now := time.Now().Unix() - if val, _err_ := failpoint.Eval(_curpkg_("mockTimeForStatementsSummary")); _err_ == nil { + failpoint.Inject("mockTimeForStatementsSummary", func(val failpoint.Value) { // mockTimeForStatementsSummary takes string of Unix timestamp if unixTimeStr, ok := val.(string); ok { unixTime, err := strconv.ParseInt(unixTimeStr, 10, 64) @@ -306,7 +306,7 @@ func (ssMap *stmtSummaryByDigestMap) AddStatement(sei *StmtExecInfo) { } now = unixTime } - } + }) intervalSeconds := ssMap.refreshInterval() historySize := ssMap.historySize() diff --git a/pkg/util/stmtsummary/statement_summary.go__failpoint_stash__ b/pkg/util/stmtsummary/statement_summary.go__failpoint_stash__ deleted file mode 100644 index e9dc133b1c46e..0000000000000 --- a/pkg/util/stmtsummary/statement_summary.go__failpoint_stash__ +++ /dev/null @@ -1,1039 +0,0 @@ -// Copyright 2019 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package stmtsummary - -import ( - "bytes" - "cmp" - "container/list" - "fmt" - "math" - "slices" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/sessionctx/stmtctx" - "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/execdetails" - "github.com/pingcap/tidb/pkg/util/hack" - "github.com/pingcap/tidb/pkg/util/kvcache" - "github.com/pingcap/tidb/pkg/util/plancodec" - "github.com/tikv/client-go/v2/util" - atomic2 "go.uber.org/atomic" - "golang.org/x/exp/maps" -) - -// stmtSummaryByDigestKey defines key for stmtSummaryByDigestMap.summaryMap. -type stmtSummaryByDigestKey struct { - // Same statements may appear in different schema, but they refer to different tables. - schemaName string - digest string - // The digest of the previous statement. 
- prevDigest string - // The digest of the plan of this SQL. - planDigest string - // `resourceGroupName` is the resource group's name of this statement is bind to. - resourceGroupName string - // `hash` is the hash value of this object. - hash []byte -} - -// Hash implements SimpleLRUCache.Key. -// Only when current SQL is `commit` do we record `prevSQL`. Otherwise, `prevSQL` is empty. -// `prevSQL` is included in the key To distinguish different transactions. -func (key *stmtSummaryByDigestKey) Hash() []byte { - if len(key.hash) == 0 { - key.hash = make([]byte, 0, len(key.schemaName)+len(key.digest)+len(key.prevDigest)+len(key.planDigest)+len(key.resourceGroupName)) - key.hash = append(key.hash, hack.Slice(key.digest)...) - key.hash = append(key.hash, hack.Slice(key.schemaName)...) - key.hash = append(key.hash, hack.Slice(key.prevDigest)...) - key.hash = append(key.hash, hack.Slice(key.planDigest)...) - key.hash = append(key.hash, hack.Slice(key.resourceGroupName)...) - } - return key.hash -} - -// stmtSummaryByDigestMap is a LRU cache that stores statement summaries. -type stmtSummaryByDigestMap struct { - // It's rare to read concurrently, so RWMutex is not needed. - sync.Mutex - summaryMap *kvcache.SimpleLRUCache - // beginTimeForCurInterval is the begin time for current summary. - beginTimeForCurInterval int64 - - // These options are set by global system variables and are accessed concurrently. - optEnabled *atomic2.Bool - optEnableInternalQuery *atomic2.Bool - optMaxStmtCount *atomic2.Uint32 - optRefreshInterval *atomic2.Int64 - optHistorySize *atomic2.Int32 - optMaxSQLLength *atomic2.Int32 - - // other stores summary of evicted data. - other *stmtSummaryByDigestEvicted -} - -// StmtSummaryByDigestMap is a global map containing all statement summaries. -var StmtSummaryByDigestMap = newStmtSummaryByDigestMap() - -// stmtSummaryByDigest is the summary for each type of statements. -type stmtSummaryByDigest struct { - // It's rare to read concurrently, so RWMutex is not needed. - // Mutex is only used to lock `history`. - sync.Mutex - initialized bool - // Each element in history is a summary in one interval. - history *list.List - // Following fields are common for each summary element. - // They won't change once this object is created, so locking is not needed. - schemaName string - digest string - planDigest string - stmtType string - normalizedSQL string - tableNames string - isInternal bool -} - -// stmtSummaryByDigestElement is the summary for each type of statements in current interval. -type stmtSummaryByDigestElement struct { - sync.Mutex - // Each summary is summarized between [beginTime, endTime). 
- beginTime int64 - endTime int64 - // basic - sampleSQL string - charset string - collation string - prevSQL string - samplePlan string - sampleBinaryPlan string - planHint string - indexNames []string - execCount int64 - sumErrors int - sumWarnings int - // latency - sumLatency time.Duration - maxLatency time.Duration - minLatency time.Duration - sumParseLatency time.Duration - maxParseLatency time.Duration - sumCompileLatency time.Duration - maxCompileLatency time.Duration - // coprocessor - sumNumCopTasks int64 - maxCopProcessTime time.Duration - maxCopProcessAddress string - maxCopWaitTime time.Duration - maxCopWaitAddress string - // TiKV - sumProcessTime time.Duration - maxProcessTime time.Duration - sumWaitTime time.Duration - maxWaitTime time.Duration - sumBackoffTime time.Duration - maxBackoffTime time.Duration - sumTotalKeys int64 - maxTotalKeys int64 - sumProcessedKeys int64 - maxProcessedKeys int64 - sumRocksdbDeleteSkippedCount uint64 - maxRocksdbDeleteSkippedCount uint64 - sumRocksdbKeySkippedCount uint64 - maxRocksdbKeySkippedCount uint64 - sumRocksdbBlockCacheHitCount uint64 - maxRocksdbBlockCacheHitCount uint64 - sumRocksdbBlockReadCount uint64 - maxRocksdbBlockReadCount uint64 - sumRocksdbBlockReadByte uint64 - maxRocksdbBlockReadByte uint64 - // txn - commitCount int64 - sumGetCommitTsTime time.Duration - maxGetCommitTsTime time.Duration - sumPrewriteTime time.Duration - maxPrewriteTime time.Duration - sumCommitTime time.Duration - maxCommitTime time.Duration - sumLocalLatchTime time.Duration - maxLocalLatchTime time.Duration - sumCommitBackoffTime int64 - maxCommitBackoffTime int64 - sumResolveLockTime int64 - maxResolveLockTime int64 - sumWriteKeys int64 - maxWriteKeys int - sumWriteSize int64 - maxWriteSize int - sumPrewriteRegionNum int64 - maxPrewriteRegionNum int32 - sumTxnRetry int64 - maxTxnRetry int - sumBackoffTimes int64 - backoffTypes map[string]int - authUsers map[string]struct{} - // other - sumMem int64 - maxMem int64 - sumDisk int64 - maxDisk int64 - sumAffectedRows uint64 - sumKVTotal time.Duration - sumPDTotal time.Duration - sumBackoffTotal time.Duration - sumWriteSQLRespTotal time.Duration - sumResultRows int64 - maxResultRows int64 - minResultRows int64 - prepared bool - // The first time this type of SQL executes. - firstSeen time.Time - // The last time this type of SQL executes. - lastSeen time.Time - // plan cache - planInCache bool - planCacheHits int64 - planInBinding bool - // pessimistic execution retry information. - execRetryCount uint - execRetryTime time.Duration - // request-units - resourceGroupName string - StmtRUSummary - - planCacheUnqualifiedCount int64 - lastPlanCacheUnqualified string // the reason why this query is unqualified for the plan cache -} - -// StmtExecInfo records execution information of each statement. 
-type StmtExecInfo struct { - SchemaName string - OriginalSQL fmt.Stringer - Charset string - Collation string - NormalizedSQL string - Digest string - PrevSQL string - PrevSQLDigest string - PlanGenerator func() (string, string, any) - BinaryPlanGenerator func() string - PlanDigest string - PlanDigestGen func() string - User string - TotalLatency time.Duration - ParseLatency time.Duration - CompileLatency time.Duration - StmtCtx *stmtctx.StatementContext - CopTasks *execdetails.CopTasksDetails - ExecDetail *execdetails.ExecDetails - MemMax int64 - DiskMax int64 - StartTime time.Time - IsInternal bool - Succeed bool - PlanInCache bool - PlanInBinding bool - ExecRetryCount uint - ExecRetryTime time.Duration - execdetails.StmtExecDetails - ResultRows int64 - TiKVExecDetails util.ExecDetails - Prepared bool - KeyspaceName string - KeyspaceID uint32 - ResourceGroupName string - RUDetail *util.RUDetails - - PlanCacheUnqualified string -} - -// newStmtSummaryByDigestMap creates an empty stmtSummaryByDigestMap. -func newStmtSummaryByDigestMap() *stmtSummaryByDigestMap { - ssbde := newStmtSummaryByDigestEvicted() - - // This initializes the stmtSummaryByDigestMap with "compiled defaults" - // (which are regrettably duplicated from sessionctx/variable/tidb_vars.go). - // Unfortunately we need to do this to avoid circular dependencies, but the correct - // values will be applied on startup as soon as domain.LoadSysVarCacheLoop() is called, - // which in turn calls func domain.checkEnableServerGlobalVar(name, sVal string) for each sysvar. - // Currently this is early enough in the startup sequence. - maxStmtCount := uint(3000) - newSsMap := &stmtSummaryByDigestMap{ - summaryMap: kvcache.NewSimpleLRUCache(maxStmtCount, 0, 0), - optMaxStmtCount: atomic2.NewUint32(uint32(maxStmtCount)), - optEnabled: atomic2.NewBool(true), - optEnableInternalQuery: atomic2.NewBool(false), - optRefreshInterval: atomic2.NewInt64(1800), - optHistorySize: atomic2.NewInt32(24), - optMaxSQLLength: atomic2.NewInt32(4096), - other: ssbde, - } - newSsMap.summaryMap.SetOnEvict(func(k kvcache.Key, v kvcache.Value) { - historySize := newSsMap.historySize() - newSsMap.other.AddEvicted(k.(*stmtSummaryByDigestKey), v.(*stmtSummaryByDigest), historySize) - }) - return newSsMap -} - -// AddStatement adds a statement to StmtSummaryByDigestMap. -func (ssMap *stmtSummaryByDigestMap) AddStatement(sei *StmtExecInfo) { - // All times are counted in seconds. - now := time.Now().Unix() - - failpoint.Inject("mockTimeForStatementsSummary", func(val failpoint.Value) { - // mockTimeForStatementsSummary takes string of Unix timestamp - if unixTimeStr, ok := val.(string); ok { - unixTime, err := strconv.ParseInt(unixTimeStr, 10, 64) - if err != nil { - panic(err.Error()) - } - now = unixTime - } - }) - - intervalSeconds := ssMap.refreshInterval() - historySize := ssMap.historySize() - - key := &stmtSummaryByDigestKey{ - schemaName: sei.SchemaName, - digest: sei.Digest, - prevDigest: sei.PrevSQLDigest, - planDigest: sei.PlanDigest, - resourceGroupName: sei.ResourceGroupName, - } - // Calculate hash value in advance, to reduce the time holding the lock. - key.Hash() - - // Enclose the block in a function to ensure the lock will always be released. - summary, beginTime := func() (*stmtSummaryByDigest, int64) { - ssMap.Lock() - defer ssMap.Unlock() - - // Check again. Statements could be added before disabling the flag and after Clear(). 
- if !ssMap.Enabled() { - return nil, 0 - } - if sei.IsInternal && !ssMap.EnabledInternal() { - return nil, 0 - } - - if ssMap.beginTimeForCurInterval+intervalSeconds <= now { - // `beginTimeForCurInterval` is a multiple of intervalSeconds, so that when the interval is a multiple - // of 60 (or 600, 1800, 3600, etc), begin time shows 'XX:XX:00', not 'XX:XX:01'~'XX:XX:59'. - ssMap.beginTimeForCurInterval = now / intervalSeconds * intervalSeconds - } - - beginTime := ssMap.beginTimeForCurInterval - value, ok := ssMap.summaryMap.Get(key) - var summary *stmtSummaryByDigest - if !ok { - // Lazy initialize it to release ssMap.mutex ASAP. - summary = new(stmtSummaryByDigest) - ssMap.summaryMap.Put(key, summary) - } else { - summary = value.(*stmtSummaryByDigest) - } - summary.isInternal = summary.isInternal && sei.IsInternal - return summary, beginTime - }() - // Lock a single entry, not the whole cache. - if summary != nil { - summary.add(sei, beginTime, intervalSeconds, historySize) - } -} - -// Clear removes all statement summaries. -func (ssMap *stmtSummaryByDigestMap) Clear() { - ssMap.Lock() - defer ssMap.Unlock() - - ssMap.summaryMap.DeleteAll() - ssMap.other.Clear() - ssMap.beginTimeForCurInterval = 0 -} - -// clearInternal removes all statement summaries which are internal summaries. -func (ssMap *stmtSummaryByDigestMap) clearInternal() { - ssMap.Lock() - defer ssMap.Unlock() - - for _, key := range ssMap.summaryMap.Keys() { - summary, ok := ssMap.summaryMap.Get(key) - if !ok { - continue - } - if summary.(*stmtSummaryByDigest).isInternal { - ssMap.summaryMap.Delete(key) - } - } -} - -// BindableStmt is a wrapper struct for a statement that is extracted from statements_summary and can be -// created binding on. -type BindableStmt struct { - Schema string - Query string - PlanHint string - Charset string - Collation string - Users map[string]struct{} // which users have processed this stmt -} - -// GetMoreThanCntBindableStmt gets users' select/update/delete SQLs that occurred more than the specified count. -func (ssMap *stmtSummaryByDigestMap) GetMoreThanCntBindableStmt(cnt int64) []*BindableStmt { - ssMap.Lock() - values := ssMap.summaryMap.Values() - ssMap.Unlock() - - stmts := make([]*BindableStmt, 0, len(values)) - for _, value := range values { - ssbd := value.(*stmtSummaryByDigest) - func() { - ssbd.Lock() - defer ssbd.Unlock() - if ssbd.initialized && (ssbd.stmtType == "Select" || ssbd.stmtType == "Delete" || ssbd.stmtType == "Update" || ssbd.stmtType == "Insert" || ssbd.stmtType == "Replace") { - if ssbd.history.Len() > 0 { - ssElement := ssbd.history.Back().Value.(*stmtSummaryByDigestElement) - ssElement.Lock() - - // Empty auth users means that it is an internal queries. - if len(ssElement.authUsers) > 0 && (int64(ssbd.history.Len()) > cnt || ssElement.execCount > cnt) { - stmt := &BindableStmt{ - Schema: ssbd.schemaName, - Query: ssElement.sampleSQL, - PlanHint: ssElement.planHint, - Charset: ssElement.charset, - Collation: ssElement.collation, - Users: make(map[string]struct{}), - } - maps.Copy(stmt.Users, ssElement.authUsers) - // If it is SQL command prepare / execute, the ssElement.sampleSQL is `execute ...`, we should get the original select query. - // If it is binary protocol prepare / execute, ssbd.normalizedSQL should be same as ssElement.sampleSQL. 
- if ssElement.prepared { - stmt.Query = ssbd.normalizedSQL - } - stmts = append(stmts, stmt) - } - ssElement.Unlock() - } - } - }() - } - return stmts -} - -// SetEnabled enables or disables statement summary -func (ssMap *stmtSummaryByDigestMap) SetEnabled(value bool) error { - // `optEnabled` and `ssMap` don't need to be strictly atomically updated. - ssMap.optEnabled.Store(value) - if !value { - ssMap.Clear() - } - return nil -} - -// Enabled returns whether statement summary is enabled. -func (ssMap *stmtSummaryByDigestMap) Enabled() bool { - return ssMap.optEnabled.Load() -} - -// SetEnabledInternalQuery enables or disables internal statement summary -func (ssMap *stmtSummaryByDigestMap) SetEnabledInternalQuery(value bool) error { - // `optEnableInternalQuery` and `ssMap` don't need to be strictly atomically updated. - ssMap.optEnableInternalQuery.Store(value) - if !value { - ssMap.clearInternal() - } - return nil -} - -// EnabledInternal returns whether internal statement summary is enabled. -func (ssMap *stmtSummaryByDigestMap) EnabledInternal() bool { - return ssMap.optEnableInternalQuery.Load() -} - -// SetRefreshInterval sets refreshing interval in ssMap.sysVars. -func (ssMap *stmtSummaryByDigestMap) SetRefreshInterval(value int64) error { - ssMap.optRefreshInterval.Store(value) - return nil -} - -// refreshInterval gets the refresh interval for summaries. -func (ssMap *stmtSummaryByDigestMap) refreshInterval() int64 { - return ssMap.optRefreshInterval.Load() -} - -// SetHistorySize sets the history size for all summaries. -func (ssMap *stmtSummaryByDigestMap) SetHistorySize(value int) error { - ssMap.optHistorySize.Store(int32(value)) - return nil -} - -// historySize gets the history size for summaries. -func (ssMap *stmtSummaryByDigestMap) historySize() int { - return int(ssMap.optHistorySize.Load()) -} - -// SetHistorySize sets the history size for all summaries. -func (ssMap *stmtSummaryByDigestMap) SetMaxStmtCount(value uint) error { - // `optMaxStmtCount` and `ssMap` don't need to be strictly atomically updated. - ssMap.optMaxStmtCount.Store(uint32(value)) - - ssMap.Lock() - defer ssMap.Unlock() - return ssMap.summaryMap.SetCapacity(value) -} - -// Used by tests -// nolint: unused -func (ssMap *stmtSummaryByDigestMap) maxStmtCount() int { - return int(ssMap.optMaxStmtCount.Load()) -} - -// SetHistorySize sets the history size for all summaries. -func (ssMap *stmtSummaryByDigestMap) SetMaxSQLLength(value int) error { - ssMap.optMaxSQLLength.Store(int32(value)) - return nil -} - -func (ssMap *stmtSummaryByDigestMap) maxSQLLength() int { - return int(ssMap.optMaxSQLLength.Load()) -} - -// GetBindableStmtFromCluster gets users' select/update/delete SQL. -func GetBindableStmtFromCluster(rows []chunk.Row) *BindableStmt { - for _, row := range rows { - user := row.GetString(3) - stmtType := row.GetString(0) - if user != "" && (stmtType == "Select" || stmtType == "Delete" || stmtType == "Update" || stmtType == "Insert" || stmtType == "Replace") { - // Empty auth users means that it is an internal queries. - stmt := &BindableStmt{ - Schema: row.GetString(1), //schemaName - Query: row.GetString(5), //sampleSQL - PlanHint: row.GetString(8), //planHint - Charset: row.GetString(6), //charset - Collation: row.GetString(7), //collation - } - // If it is SQL command prepare / execute, we should remove the arguments - // If it is binary protocol prepare / execute, ssbd.normalizedSQL should be same as ssElement.sampleSQL. 
- if row.GetInt64(4) == 1 { - if idx := strings.LastIndex(stmt.Query, "[arguments:"); idx != -1 { - stmt.Query = stmt.Query[:idx] - } - } - return stmt - } - } - return nil -} - -// newStmtSummaryByDigest creates a stmtSummaryByDigest from StmtExecInfo. -func (ssbd *stmtSummaryByDigest) init(sei *StmtExecInfo, _ int64, _ int64, _ int) { - // Use "," to separate table names to support FIND_IN_SET. - var buffer bytes.Buffer - for i, value := range sei.StmtCtx.Tables { - // In `create database` statement, DB name is not empty but table name is empty. - if len(value.Table) == 0 { - continue - } - buffer.WriteString(strings.ToLower(value.DB)) - buffer.WriteString(".") - buffer.WriteString(strings.ToLower(value.Table)) - if i < len(sei.StmtCtx.Tables)-1 { - buffer.WriteString(",") - } - } - tableNames := buffer.String() - - planDigest := sei.PlanDigest - if sei.PlanDigestGen != nil && len(planDigest) == 0 { - // It comes here only when the plan is 'Point_Get'. - planDigest = sei.PlanDigestGen() - } - ssbd.schemaName = sei.SchemaName - ssbd.digest = sei.Digest - ssbd.planDigest = planDigest - ssbd.stmtType = sei.StmtCtx.StmtType - ssbd.normalizedSQL = formatSQL(sei.NormalizedSQL) - ssbd.tableNames = tableNames - ssbd.history = list.New() - ssbd.initialized = true -} - -func (ssbd *stmtSummaryByDigest) add(sei *StmtExecInfo, beginTime int64, intervalSeconds int64, historySize int) { - // Enclose this block in a function to ensure the lock will always be released. - ssElement, isElementNew := func() (*stmtSummaryByDigestElement, bool) { - ssbd.Lock() - defer ssbd.Unlock() - - if !ssbd.initialized { - ssbd.init(sei, beginTime, intervalSeconds, historySize) - } - - var ssElement *stmtSummaryByDigestElement - isElementNew := true - if ssbd.history.Len() > 0 { - lastElement := ssbd.history.Back().Value.(*stmtSummaryByDigestElement) - if lastElement.beginTime >= beginTime { - ssElement = lastElement - isElementNew = false - } else { - // The last elements expires to the history. - lastElement.onExpire(intervalSeconds) - } - } - if isElementNew { - // If the element is new created, `ssElement.add(sei)` should be done inside the lock of `ssbd`. - ssElement = newStmtSummaryByDigestElement(sei, beginTime, intervalSeconds) - if ssElement == nil { - return nil, isElementNew - } - ssbd.history.PushBack(ssElement) - } - - // `historySize` might be modified anytime, so check expiration every time. - // Even if history is set to 0, current summary is still needed. - for ssbd.history.Len() > historySize && ssbd.history.Len() > 1 { - ssbd.history.Remove(ssbd.history.Front()) - } - - return ssElement, isElementNew - }() - - // Lock a single entry, not the whole `ssbd`. - if !isElementNew { - ssElement.add(sei, intervalSeconds) - } -} - -// collectHistorySummaries puts at most `historySize` summaries to an array. 
-func (ssbd *stmtSummaryByDigest) collectHistorySummaries(checker *stmtSummaryChecker, historySize int) []*stmtSummaryByDigestElement { - ssbd.Lock() - defer ssbd.Unlock() - - if !ssbd.initialized { - return nil - } - if checker != nil && !checker.isDigestValid(ssbd.digest) { - return nil - } - - ssElements := make([]*stmtSummaryByDigestElement, 0, ssbd.history.Len()) - for listElement := ssbd.history.Front(); listElement != nil && len(ssElements) < historySize; listElement = listElement.Next() { - ssElement := listElement.Value.(*stmtSummaryByDigestElement) - ssElements = append(ssElements, ssElement) - } - return ssElements -} - -// MaxEncodedPlanSizeInBytes is the upper limit of the size of the plan and the binary plan in the stmt summary. -var MaxEncodedPlanSizeInBytes = 1024 * 1024 - -func newStmtSummaryByDigestElement(sei *StmtExecInfo, beginTime int64, intervalSeconds int64) *stmtSummaryByDigestElement { - // sampleSQL / authUsers(sampleUser) / samplePlan / prevSQL / indexNames store the values shown at the first time, - // because it compacts performance to update every time. - samplePlan, planHint, e := sei.PlanGenerator() - if e != nil { - return nil - } - if len(samplePlan) > MaxEncodedPlanSizeInBytes { - samplePlan = plancodec.PlanDiscardedEncoded - } - binPlan := "" - if sei.BinaryPlanGenerator != nil { - binPlan = sei.BinaryPlanGenerator() - if len(binPlan) > MaxEncodedPlanSizeInBytes { - binPlan = plancodec.BinaryPlanDiscardedEncoded - } - } - ssElement := &stmtSummaryByDigestElement{ - beginTime: beginTime, - sampleSQL: formatSQL(sei.OriginalSQL.String()), - charset: sei.Charset, - collation: sei.Collation, - // PrevSQL is already truncated to cfg.Log.QueryLogMaxLen. - prevSQL: sei.PrevSQL, - // samplePlan needs to be decoded so it can't be truncated. - samplePlan: samplePlan, - sampleBinaryPlan: binPlan, - planHint: planHint, - indexNames: sei.StmtCtx.IndexNames, - minLatency: sei.TotalLatency, - firstSeen: sei.StartTime, - lastSeen: sei.StartTime, - backoffTypes: make(map[string]int), - authUsers: make(map[string]struct{}), - planInCache: false, - planCacheHits: 0, - planInBinding: false, - prepared: sei.Prepared, - minResultRows: math.MaxInt64, - resourceGroupName: sei.ResourceGroupName, - } - ssElement.add(sei, intervalSeconds) - return ssElement -} - -// onExpire is called when this element expires to history. -func (ssElement *stmtSummaryByDigestElement) onExpire(intervalSeconds int64) { - ssElement.Lock() - defer ssElement.Unlock() - - // refreshInterval may change anytime, so we need to update endTime. - if ssElement.beginTime+intervalSeconds > ssElement.endTime { - // // If interval changes to a bigger value, update endTime to beginTime + interval. - ssElement.endTime = ssElement.beginTime + intervalSeconds - } else if ssElement.beginTime+intervalSeconds < ssElement.endTime { - now := time.Now().Unix() - // If interval changes to a smaller value and now > beginTime + interval, update endTime to current time. - if now > ssElement.beginTime+intervalSeconds { - ssElement.endTime = now - } - } -} - -func (ssElement *stmtSummaryByDigestElement) add(sei *StmtExecInfo, intervalSeconds int64) { - ssElement.Lock() - defer ssElement.Unlock() - - // add user to auth users set - if len(sei.User) > 0 { - ssElement.authUsers[sei.User] = struct{}{} - } - - // refreshInterval may change anytime, update endTime ASAP. 
- ssElement.endTime = ssElement.beginTime + intervalSeconds - ssElement.execCount++ - if !sei.Succeed { - ssElement.sumErrors++ - } - ssElement.sumWarnings += int(sei.StmtCtx.WarningCount()) - - // latency - ssElement.sumLatency += sei.TotalLatency - if sei.TotalLatency > ssElement.maxLatency { - ssElement.maxLatency = sei.TotalLatency - } - if sei.TotalLatency < ssElement.minLatency { - ssElement.minLatency = sei.TotalLatency - } - ssElement.sumParseLatency += sei.ParseLatency - if sei.ParseLatency > ssElement.maxParseLatency { - ssElement.maxParseLatency = sei.ParseLatency - } - ssElement.sumCompileLatency += sei.CompileLatency - if sei.CompileLatency > ssElement.maxCompileLatency { - ssElement.maxCompileLatency = sei.CompileLatency - } - - // coprocessor - numCopTasks := int64(sei.CopTasks.NumCopTasks) - ssElement.sumNumCopTasks += numCopTasks - if sei.CopTasks.MaxProcessTime > ssElement.maxCopProcessTime { - ssElement.maxCopProcessTime = sei.CopTasks.MaxProcessTime - ssElement.maxCopProcessAddress = sei.CopTasks.MaxProcessAddress - } - if sei.CopTasks.MaxWaitTime > ssElement.maxCopWaitTime { - ssElement.maxCopWaitTime = sei.CopTasks.MaxWaitTime - ssElement.maxCopWaitAddress = sei.CopTasks.MaxWaitAddress - } - - // TiKV - ssElement.sumProcessTime += sei.ExecDetail.TimeDetail.ProcessTime - if sei.ExecDetail.TimeDetail.ProcessTime > ssElement.maxProcessTime { - ssElement.maxProcessTime = sei.ExecDetail.TimeDetail.ProcessTime - } - ssElement.sumWaitTime += sei.ExecDetail.TimeDetail.WaitTime - if sei.ExecDetail.TimeDetail.WaitTime > ssElement.maxWaitTime { - ssElement.maxWaitTime = sei.ExecDetail.TimeDetail.WaitTime - } - ssElement.sumBackoffTime += sei.ExecDetail.BackoffTime - if sei.ExecDetail.BackoffTime > ssElement.maxBackoffTime { - ssElement.maxBackoffTime = sei.ExecDetail.BackoffTime - } - - if sei.ExecDetail.ScanDetail != nil { - ssElement.sumTotalKeys += sei.ExecDetail.ScanDetail.TotalKeys - if sei.ExecDetail.ScanDetail.TotalKeys > ssElement.maxTotalKeys { - ssElement.maxTotalKeys = sei.ExecDetail.ScanDetail.TotalKeys - } - ssElement.sumProcessedKeys += sei.ExecDetail.ScanDetail.ProcessedKeys - if sei.ExecDetail.ScanDetail.ProcessedKeys > ssElement.maxProcessedKeys { - ssElement.maxProcessedKeys = sei.ExecDetail.ScanDetail.ProcessedKeys - } - ssElement.sumRocksdbDeleteSkippedCount += sei.ExecDetail.ScanDetail.RocksdbDeleteSkippedCount - if sei.ExecDetail.ScanDetail.RocksdbDeleteSkippedCount > ssElement.maxRocksdbDeleteSkippedCount { - ssElement.maxRocksdbDeleteSkippedCount = sei.ExecDetail.ScanDetail.RocksdbDeleteSkippedCount - } - ssElement.sumRocksdbKeySkippedCount += sei.ExecDetail.ScanDetail.RocksdbKeySkippedCount - if sei.ExecDetail.ScanDetail.RocksdbKeySkippedCount > ssElement.maxRocksdbKeySkippedCount { - ssElement.maxRocksdbKeySkippedCount = sei.ExecDetail.ScanDetail.RocksdbKeySkippedCount - } - ssElement.sumRocksdbBlockCacheHitCount += sei.ExecDetail.ScanDetail.RocksdbBlockCacheHitCount - if sei.ExecDetail.ScanDetail.RocksdbBlockCacheHitCount > ssElement.maxRocksdbBlockCacheHitCount { - ssElement.maxRocksdbBlockCacheHitCount = sei.ExecDetail.ScanDetail.RocksdbBlockCacheHitCount - } - ssElement.sumRocksdbBlockReadCount += sei.ExecDetail.ScanDetail.RocksdbBlockReadCount - if sei.ExecDetail.ScanDetail.RocksdbBlockReadCount > ssElement.maxRocksdbBlockReadCount { - ssElement.maxRocksdbBlockReadCount = sei.ExecDetail.ScanDetail.RocksdbBlockReadCount - } - ssElement.sumRocksdbBlockReadByte += sei.ExecDetail.ScanDetail.RocksdbBlockReadByte - if 
sei.ExecDetail.ScanDetail.RocksdbBlockReadByte > ssElement.maxRocksdbBlockReadByte { - ssElement.maxRocksdbBlockReadByte = sei.ExecDetail.ScanDetail.RocksdbBlockReadByte - } - } - - // txn - commitDetails := sei.ExecDetail.CommitDetail - if commitDetails != nil { - ssElement.commitCount++ - ssElement.sumPrewriteTime += commitDetails.PrewriteTime - if commitDetails.PrewriteTime > ssElement.maxPrewriteTime { - ssElement.maxPrewriteTime = commitDetails.PrewriteTime - } - ssElement.sumCommitTime += commitDetails.CommitTime - if commitDetails.CommitTime > ssElement.maxCommitTime { - ssElement.maxCommitTime = commitDetails.CommitTime - } - ssElement.sumGetCommitTsTime += commitDetails.GetCommitTsTime - if commitDetails.GetCommitTsTime > ssElement.maxGetCommitTsTime { - ssElement.maxGetCommitTsTime = commitDetails.GetCommitTsTime - } - resolveLockTime := atomic.LoadInt64(&commitDetails.ResolveLock.ResolveLockTime) - ssElement.sumResolveLockTime += resolveLockTime - if resolveLockTime > ssElement.maxResolveLockTime { - ssElement.maxResolveLockTime = resolveLockTime - } - ssElement.sumLocalLatchTime += commitDetails.LocalLatchTime - if commitDetails.LocalLatchTime > ssElement.maxLocalLatchTime { - ssElement.maxLocalLatchTime = commitDetails.LocalLatchTime - } - ssElement.sumWriteKeys += int64(commitDetails.WriteKeys) - if commitDetails.WriteKeys > ssElement.maxWriteKeys { - ssElement.maxWriteKeys = commitDetails.WriteKeys - } - ssElement.sumWriteSize += int64(commitDetails.WriteSize) - if commitDetails.WriteSize > ssElement.maxWriteSize { - ssElement.maxWriteSize = commitDetails.WriteSize - } - prewriteRegionNum := atomic.LoadInt32(&commitDetails.PrewriteRegionNum) - ssElement.sumPrewriteRegionNum += int64(prewriteRegionNum) - if prewriteRegionNum > ssElement.maxPrewriteRegionNum { - ssElement.maxPrewriteRegionNum = prewriteRegionNum - } - ssElement.sumTxnRetry += int64(commitDetails.TxnRetry) - if commitDetails.TxnRetry > ssElement.maxTxnRetry { - ssElement.maxTxnRetry = commitDetails.TxnRetry - } - commitDetails.Mu.Lock() - commitBackoffTime := commitDetails.Mu.CommitBackoffTime - ssElement.sumCommitBackoffTime += commitBackoffTime - if commitBackoffTime > ssElement.maxCommitBackoffTime { - ssElement.maxCommitBackoffTime = commitBackoffTime - } - ssElement.sumBackoffTimes += int64(len(commitDetails.Mu.PrewriteBackoffTypes)) - for _, backoffType := range commitDetails.Mu.PrewriteBackoffTypes { - ssElement.backoffTypes[backoffType]++ - } - ssElement.sumBackoffTimes += int64(len(commitDetails.Mu.CommitBackoffTypes)) - for _, backoffType := range commitDetails.Mu.CommitBackoffTypes { - ssElement.backoffTypes[backoffType]++ - } - commitDetails.Mu.Unlock() - } - - // plan cache - if sei.PlanInCache { - ssElement.planInCache = true - ssElement.planCacheHits++ - } else { - ssElement.planInCache = false - } - if sei.PlanCacheUnqualified != "" { - ssElement.planCacheUnqualifiedCount++ - ssElement.lastPlanCacheUnqualified = sei.PlanCacheUnqualified - } - - // SPM - if sei.PlanInBinding { - ssElement.planInBinding = true - } else { - ssElement.planInBinding = false - } - - // other - ssElement.sumAffectedRows += sei.StmtCtx.AffectedRows() - ssElement.sumMem += sei.MemMax - if sei.MemMax > ssElement.maxMem { - ssElement.maxMem = sei.MemMax - } - ssElement.sumDisk += sei.DiskMax - if sei.DiskMax > ssElement.maxDisk { - ssElement.maxDisk = sei.DiskMax - } - if sei.StartTime.Before(ssElement.firstSeen) { - ssElement.firstSeen = sei.StartTime - } - if ssElement.lastSeen.Before(sei.StartTime) { - 
ssElement.lastSeen = sei.StartTime - } - if sei.ExecRetryCount > 0 { - ssElement.execRetryCount += sei.ExecRetryCount - ssElement.execRetryTime += sei.ExecRetryTime - } - if sei.ResultRows > 0 { - ssElement.sumResultRows += sei.ResultRows - if ssElement.maxResultRows < sei.ResultRows { - ssElement.maxResultRows = sei.ResultRows - } - if ssElement.minResultRows > sei.ResultRows { - ssElement.minResultRows = sei.ResultRows - } - } else { - ssElement.minResultRows = 0 - } - ssElement.sumKVTotal += time.Duration(atomic.LoadInt64(&sei.TiKVExecDetails.WaitKVRespDuration)) - ssElement.sumPDTotal += time.Duration(atomic.LoadInt64(&sei.TiKVExecDetails.WaitPDRespDuration)) - ssElement.sumBackoffTotal += time.Duration(atomic.LoadInt64(&sei.TiKVExecDetails.BackoffDuration)) - ssElement.sumWriteSQLRespTotal += sei.StmtExecDetails.WriteSQLRespDuration - - // request-units - ssElement.StmtRUSummary.Add(sei.RUDetail) -} - -// Truncate SQL to maxSQLLength. -func formatSQL(sql string) string { - maxSQLLength := StmtSummaryByDigestMap.maxSQLLength() - length := len(sql) - if length > maxSQLLength { - var result strings.Builder - result.WriteString(sql[:maxSQLLength]) - fmt.Fprintf(&result, "(len:%d)", length) - return result.String() - } - return sql -} - -// Format the backoffType map to a string or nil. -func formatBackoffTypes(backoffMap map[string]int) any { - type backoffStat struct { - backoffType string - count int - } - - size := len(backoffMap) - if size == 0 { - return nil - } - - backoffArray := make([]backoffStat, 0, len(backoffMap)) - for backoffType, count := range backoffMap { - backoffArray = append(backoffArray, backoffStat{backoffType, count}) - } - slices.SortFunc(backoffArray, func(i, j backoffStat) int { - return cmp.Compare(j.count, i.count) - }) - - var buffer bytes.Buffer - for index, stat := range backoffArray { - if _, err := fmt.Fprintf(&buffer, "%v:%d", stat.backoffType, stat.count); err != nil { - return "FORMAT ERROR" - } - if index < len(backoffArray)-1 { - buffer.WriteString(",") - } - } - return buffer.String() -} - -func avgInt(sum int64, count int64) int64 { - if count > 0 { - return sum / count - } - return 0 -} - -func avgFloat(sum int64, count int64) float64 { - if count > 0 { - return float64(sum) / float64(count) - } - return 0 -} - -func avgSumFloat(sum float64, count int64) float64 { - if count > 0 { - return sum / float64(count) - } - return 0 -} - -func convertEmptyToNil(str string) any { - if str == "" { - return nil - } - return str -} - -// StmtRUSummary is the request-units summary for each type of statements. -type StmtRUSummary struct { - SumRRU float64 `json:"sum_rru"` - SumWRU float64 `json:"sum_wru"` - SumRUWaitDuration time.Duration `json:"sum_ru_wait_duration"` - MaxRRU float64 `json:"max_rru"` - MaxWRU float64 `json:"max_wru"` - MaxRUWaitDuration time.Duration `json:"max_ru_wait_duration"` -} - -// Add add a new sample value to the ru summary record. -func (s *StmtRUSummary) Add(info *util.RUDetails) { - if info != nil { - rru := info.RRU() - s.SumRRU += rru - if s.MaxRRU < rru { - s.MaxRRU = rru - } - wru := info.WRU() - s.SumWRU += wru - if s.MaxWRU < wru { - s.MaxWRU = wru - } - ruWaitDur := info.RUWaitDuration() - s.SumRUWaitDuration += ruWaitDur - if s.MaxRUWaitDuration < ruWaitDur { - s.MaxRUWaitDuration = ruWaitDur - } - } -} - -// Merge merges the value of 2 ru summary records. 
-func (s *StmtRUSummary) Merge(other *StmtRUSummary) { - s.SumRRU += other.SumRRU - s.SumWRU += other.SumWRU - s.SumRUWaitDuration += other.SumRUWaitDuration - if s.MaxRRU < other.MaxRRU { - s.MaxRRU = other.MaxRRU - } - if s.MaxWRU < other.MaxWRU { - s.MaxWRU = other.MaxWRU - } - if s.MaxRUWaitDuration < other.MaxRUWaitDuration { - s.MaxRUWaitDuration = other.MaxRUWaitDuration - } -} diff --git a/pkg/util/topsql/binding__failpoint_binding__.go b/pkg/util/topsql/binding__failpoint_binding__.go deleted file mode 100644 index 2baa7cc332089..0000000000000 --- a/pkg/util/topsql/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package topsql - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/util/topsql/reporter/binding__failpoint_binding__.go b/pkg/util/topsql/reporter/binding__failpoint_binding__.go deleted file mode 100644 index 2b1d47d228c18..0000000000000 --- a/pkg/util/topsql/reporter/binding__failpoint_binding__.go +++ /dev/null @@ -1,14 +0,0 @@ - -package reporter - -import "reflect" - -type __failpointBindingType struct {pkgpath string} -var __failpointBindingCache = &__failpointBindingType{} - -func init() { - __failpointBindingCache.pkgpath = reflect.TypeOf(__failpointBindingType{}).PkgPath() -} -func _curpkg_(name string) string { - return __failpointBindingCache.pkgpath + "/" + name -} diff --git a/pkg/util/topsql/reporter/pubsub.go b/pkg/util/topsql/reporter/pubsub.go index ceb6139d5fbea..a9ce1fe7e9b8c 100644 --- a/pkg/util/topsql/reporter/pubsub.go +++ b/pkg/util/topsql/reporter/pubsub.go @@ -138,7 +138,7 @@ func (ds *pubSubDataSink) run() error { return ctx.Err() } - failpoint.Eval(_curpkg_("mockGrpcLogPanic")) + failpoint.Inject("mockGrpcLogPanic", nil) if err != nil { logutil.BgLogger().Warn( "[top-sql] pubsub datasink failed to send data to subscriber", diff --git a/pkg/util/topsql/reporter/pubsub.go__failpoint_stash__ b/pkg/util/topsql/reporter/pubsub.go__failpoint_stash__ deleted file mode 100644 index a9ce1fe7e9b8c..0000000000000 --- a/pkg/util/topsql/reporter/pubsub.go__failpoint_stash__ +++ /dev/null @@ -1,274 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package reporter - -import ( - "context" - "errors" - "time" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/logutil" - reporter_metrics "github.com/pingcap/tidb/pkg/util/topsql/reporter/metrics" - "github.com/pingcap/tipb/go-tipb" - "go.uber.org/zap" -) - -// TopSQLPubSubService implements tipb.TopSQLPubSubServer. -// -// If a client subscribes to TopSQL records, the TopSQLPubSubService is responsible -// for registering an associated DataSink to the reporter. 
Then the DataSink sends -// data to the client periodically. -type TopSQLPubSubService struct { - dataSinkRegisterer DataSinkRegisterer -} - -// NewTopSQLPubSubService creates a new TopSQLPubSubService. -func NewTopSQLPubSubService(dataSinkRegisterer DataSinkRegisterer) *TopSQLPubSubService { - return &TopSQLPubSubService{dataSinkRegisterer: dataSinkRegisterer} -} - -var _ tipb.TopSQLPubSubServer = &TopSQLPubSubService{} - -// Subscribe registers dataSinks to the reporter and redirects data received from reporter -// to subscribers associated with those dataSinks. -func (ps *TopSQLPubSubService) Subscribe(_ *tipb.TopSQLSubRequest, stream tipb.TopSQLPubSub_SubscribeServer) error { - ds := newPubSubDataSink(stream, ps.dataSinkRegisterer) - if err := ps.dataSinkRegisterer.Register(ds); err != nil { - return err - } - return ds.run() -} - -type pubSubDataSink struct { - ctx context.Context - cancel context.CancelFunc - - stream tipb.TopSQLPubSub_SubscribeServer - sendTaskCh chan sendTask - - // for deregister - registerer DataSinkRegisterer -} - -func newPubSubDataSink(stream tipb.TopSQLPubSub_SubscribeServer, registerer DataSinkRegisterer) *pubSubDataSink { - ctx, cancel := context.WithCancel(stream.Context()) - - return &pubSubDataSink{ - ctx: ctx, - cancel: cancel, - - stream: stream, - sendTaskCh: make(chan sendTask, 1), - - registerer: registerer, - } -} - -var _ DataSink = &pubSubDataSink{} - -func (ds *pubSubDataSink) TrySend(data *ReportData, deadline time.Time) error { - select { - case ds.sendTaskCh <- sendTask{data: data, deadline: deadline}: - return nil - case <-ds.ctx.Done(): - return ds.ctx.Err() - default: - reporter_metrics.IgnoreReportChannelFullCounter.Inc() - return errors.New("the channel of pubsub dataSink is full") - } -} - -func (ds *pubSubDataSink) OnReporterClosing() { - ds.cancel() -} - -func (ds *pubSubDataSink) run() error { - defer func() { - if r := recover(); r != nil { - // To catch panic when log grpc error. https://github.com/pingcap/tidb/issues/51301. - logutil.BgLogger().Error("[top-sql] got panic in pub sub data sink, just ignore", zap.Error(util.GetRecoverError(r))) - } - ds.registerer.Deregister(ds) - ds.cancel() - }() - - for { - select { - case task := <-ds.sendTaskCh: - ctx, cancel := context.WithDeadline(ds.ctx, task.deadline) - var err error - - start := time.Now() - go util.WithRecovery(func() { - defer cancel() - err = ds.doSend(ctx, task.data) - - if err != nil { - reporter_metrics.ReportAllDurationFailedHistogram.Observe(time.Since(start).Seconds()) - } else { - reporter_metrics.ReportAllDurationSuccHistogram.Observe(time.Since(start).Seconds()) - } - }, nil) - - // When the deadline is exceeded, the closure inside `go util.WithRecovery` above may not notice that - // immediately because it can be blocked by `stream.Send`. - // In order to clean up resources as quickly as possible, we let that closure run in an individual goroutine, - // and wait for timeout here. 
- <-ctx.Done() - - if errors.Is(ctx.Err(), context.DeadlineExceeded) { - logutil.BgLogger().Warn( - "[top-sql] pubsub datasink failed to send data to subscriber due to deadline exceeded", - zap.Time("deadline", task.deadline), - ) - return ctx.Err() - } - - failpoint.Inject("mockGrpcLogPanic", nil) - if err != nil { - logutil.BgLogger().Warn( - "[top-sql] pubsub datasink failed to send data to subscriber", - zap.Error(err), - ) - return err - } - case <-ds.ctx.Done(): - return ds.ctx.Err() - } - } -} - -func (ds *pubSubDataSink) doSend(ctx context.Context, data *ReportData) error { - if err := ds.sendTopSQLRecords(ctx, data.DataRecords); err != nil { - return err - } - if err := ds.sendSQLMeta(ctx, data.SQLMetas); err != nil { - return err - } - return ds.sendPlanMeta(ctx, data.PlanMetas) -} - -func (ds *pubSubDataSink) sendTopSQLRecords(ctx context.Context, records []tipb.TopSQLRecord) (err error) { - if len(records) == 0 { - return - } - - start := time.Now() - sentCount := 0 - defer func() { - reporter_metrics.TopSQLReportRecordCounterHistogram.Observe(float64(sentCount)) - if err != nil { - reporter_metrics.ReportRecordDurationFailedHistogram.Observe(time.Since(start).Seconds()) - } else { - reporter_metrics.ReportRecordDurationSuccHistogram.Observe(time.Since(start).Seconds()) - } - }() - - topSQLRecord := &tipb.TopSQLSubResponse_Record{} - r := &tipb.TopSQLSubResponse{RespOneof: topSQLRecord} - - for i := range records { - topSQLRecord.Record = &records[i] - if err = ds.stream.Send(r); err != nil { - return - } - sentCount++ - - select { - case <-ctx.Done(): - err = ctx.Err() - return - default: - } - } - - return -} - -func (ds *pubSubDataSink) sendSQLMeta(ctx context.Context, sqlMetas []tipb.SQLMeta) (err error) { - if len(sqlMetas) == 0 { - return - } - - start := time.Now() - sentCount := 0 - defer func() { - reporter_metrics.TopSQLReportSQLCountHistogram.Observe(float64(sentCount)) - if err != nil { - reporter_metrics.ReportSQLDurationFailedHistogram.Observe(time.Since(start).Seconds()) - } else { - reporter_metrics.ReportSQLDurationSuccHistogram.Observe(time.Since(start).Seconds()) - } - }() - - sqlMeta := &tipb.TopSQLSubResponse_SqlMeta{} - r := &tipb.TopSQLSubResponse{RespOneof: sqlMeta} - - for i := range sqlMetas { - sqlMeta.SqlMeta = &sqlMetas[i] - if err = ds.stream.Send(r); err != nil { - return - } - sentCount++ - - select { - case <-ctx.Done(): - err = ctx.Err() - return - default: - } - } - - return -} - -func (ds *pubSubDataSink) sendPlanMeta(ctx context.Context, planMetas []tipb.PlanMeta) (err error) { - if len(planMetas) == 0 { - return - } - - start := time.Now() - sentCount := 0 - defer func() { - reporter_metrics.TopSQLReportPlanCountHistogram.Observe(float64(sentCount)) - if err != nil { - reporter_metrics.ReportPlanDurationFailedHistogram.Observe(time.Since(start).Seconds()) - } else { - reporter_metrics.ReportPlanDurationSuccHistogram.Observe(time.Since(start).Seconds()) - } - }() - - planMeta := &tipb.TopSQLSubResponse_PlanMeta{} - r := &tipb.TopSQLSubResponse{RespOneof: planMeta} - - for i := range planMetas { - planMeta.PlanMeta = &planMetas[i] - if err = ds.stream.Send(r); err != nil { - return - } - sentCount++ - - select { - case <-ctx.Done(): - err = ctx.Err() - return - default: - } - } - - return -} diff --git a/pkg/util/topsql/reporter/reporter.go b/pkg/util/topsql/reporter/reporter.go index 40afca010e75c..ddb2fd4c81c82 100644 --- a/pkg/util/topsql/reporter/reporter.go +++ b/pkg/util/topsql/reporter/reporter.go @@ -287,14 +287,14 @@ func (tsr 
*RemoteTopSQLReporter) doReport(data *ReportData) { return } timeout := reportTimeout - if val, _err_ := failpoint.Eval(_curpkg_("resetTimeoutForTest")); _err_ == nil { + failpoint.Inject("resetTimeoutForTest", func(val failpoint.Value) { if val.(bool) { interval := time.Duration(topsqlstate.GlobalState.ReportIntervalSeconds.Load()) * time.Second if interval < timeout { timeout = interval } } - } + }) _ = tsr.trySend(data, time.Now().Add(timeout)) } diff --git a/pkg/util/topsql/reporter/reporter.go__failpoint_stash__ b/pkg/util/topsql/reporter/reporter.go__failpoint_stash__ deleted file mode 100644 index ddb2fd4c81c82..0000000000000 --- a/pkg/util/topsql/reporter/reporter.go__failpoint_stash__ +++ /dev/null @@ -1,333 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package reporter - -import ( - "context" - "time" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/util" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/topsql/collector" - reporter_metrics "github.com/pingcap/tidb/pkg/util/topsql/reporter/metrics" - topsqlstate "github.com/pingcap/tidb/pkg/util/topsql/state" - "github.com/pingcap/tidb/pkg/util/topsql/stmtstats" - "go.uber.org/zap" -) - -const ( - reportTimeout = 40 * time.Second - collectChanBufferSize = 2 -) - -var nowFunc = time.Now - -// TopSQLReporter collects Top SQL metrics. -type TopSQLReporter interface { - collector.Collector - stmtstats.Collector - - // Start uses to start the reporter. - Start() - - // RegisterSQL registers a normalizedSQL with SQLDigest. - // - // Note that the normalized SQL string can be of >1M long. - // This function should be thread-safe, which means concurrently calling it - // in several goroutines should be fine. It should also return immediately, - // and do any CPU-intensive job asynchronously. - RegisterSQL(sqlDigest []byte, normalizedSQL string, isInternal bool) - - // RegisterPlan like RegisterSQL, but for normalized plan strings. - // isLarge indicates the size of normalizedPlan is big. - RegisterPlan(planDigest []byte, normalizedPlan string, isLarge bool) - - // Close uses to close and release the reporter resource. - Close() -} - -var _ TopSQLReporter = &RemoteTopSQLReporter{} -var _ DataSinkRegisterer = &RemoteTopSQLReporter{} - -// RemoteTopSQLReporter implements TopSQLReporter that sends data to a remote agent. -// This should be called periodically to collect TopSQL resource usage metrics. 
-type RemoteTopSQLReporter struct { - ctx context.Context - reportCollectedDataChan chan collectedData - cancel context.CancelFunc - sqlCPUCollector *collector.SQLCPUCollector - collectCPUTimeChan chan []collector.SQLCPUTimeRecord - collectStmtStatsChan chan stmtstats.StatementStatsMap - collecting *collecting - normalizedSQLMap *normalizedSQLMap - normalizedPlanMap *normalizedPlanMap - stmtStatsBuffer map[uint64]stmtstats.StatementStatsMap // timestamp => stmtstats.StatementStatsMap - // calling decodePlan this can take a while, so should not block critical paths. - decodePlan planBinaryDecodeFunc - // Instead of dropping large plans, we compress it into encoded format and report - compressPlan planBinaryCompressFunc - DefaultDataSinkRegisterer -} - -// NewRemoteTopSQLReporter creates a new RemoteTopSQLReporter. -// -// decodePlan is a decoding function which will be called asynchronously to decode the plan binary to string. -func NewRemoteTopSQLReporter(decodePlan planBinaryDecodeFunc, compressPlan planBinaryCompressFunc) *RemoteTopSQLReporter { - ctx, cancel := context.WithCancel(context.Background()) - tsr := &RemoteTopSQLReporter{ - DefaultDataSinkRegisterer: NewDefaultDataSinkRegisterer(ctx), - ctx: ctx, - cancel: cancel, - collectCPUTimeChan: make(chan []collector.SQLCPUTimeRecord, collectChanBufferSize), - collectStmtStatsChan: make(chan stmtstats.StatementStatsMap, collectChanBufferSize), - reportCollectedDataChan: make(chan collectedData, 1), - collecting: newCollecting(), - normalizedSQLMap: newNormalizedSQLMap(), - normalizedPlanMap: newNormalizedPlanMap(), - stmtStatsBuffer: map[uint64]stmtstats.StatementStatsMap{}, - decodePlan: decodePlan, - compressPlan: compressPlan, - } - tsr.sqlCPUCollector = collector.NewSQLCPUCollector(tsr) - return tsr -} - -// Start implements the TopSQLReporter interface. -func (tsr *RemoteTopSQLReporter) Start() { - tsr.sqlCPUCollector.Start() - go tsr.collectWorker() - go tsr.reportWorker() -} - -// Collect implements tracecpu.Collector. -// -// WARN: It will drop the DataRecords if the processing is not in time. -// This function is thread-safe and efficient. -func (tsr *RemoteTopSQLReporter) Collect(data []collector.SQLCPUTimeRecord) { - if len(data) == 0 { - return - } - select { - case tsr.collectCPUTimeChan <- data: - default: - // ignore if chan blocked - reporter_metrics.IgnoreCollectChannelFullCounter.Inc() - } -} - -// CollectStmtStatsMap implements stmtstats.Collector. -// -// WARN: It will drop the DataRecords if the processing is not in time. -// This function is thread-safe and efficient. -func (tsr *RemoteTopSQLReporter) CollectStmtStatsMap(data stmtstats.StatementStatsMap) { - if len(data) == 0 { - return - } - select { - case tsr.collectStmtStatsChan <- data: - default: - // ignore if chan blocked - reporter_metrics.IgnoreCollectStmtChannelFullCounter.Inc() - } -} - -// RegisterSQL implements TopSQLReporter. -// -// This function is thread-safe and efficient. -func (tsr *RemoteTopSQLReporter) RegisterSQL(sqlDigest []byte, normalizedSQL string, isInternal bool) { - tsr.normalizedSQLMap.register(sqlDigest, normalizedSQL, isInternal) -} - -// RegisterPlan implements TopSQLReporter. -// -// This function is thread-safe and efficient. -func (tsr *RemoteTopSQLReporter) RegisterPlan(planDigest []byte, normalizedPlan string, isLarge bool) { - tsr.normalizedPlanMap.register(planDigest, normalizedPlan, isLarge) -} - -// Close implements TopSQLReporter. 
-func (tsr *RemoteTopSQLReporter) Close() { - tsr.cancel() - tsr.sqlCPUCollector.Stop() - tsr.onReporterClosing() -} - -// collectWorker consumes and collects data from tracecpu.Collector/stmtstats.Collector. -func (tsr *RemoteTopSQLReporter) collectWorker() { - defer util.Recover("top-sql", "collectWorker", nil, false) - - currentReportInterval := topsqlstate.GlobalState.ReportIntervalSeconds.Load() - reportTicker := time.NewTicker(time.Second * time.Duration(currentReportInterval)) - defer reportTicker.Stop() - for { - select { - case <-tsr.ctx.Done(): - return - case data := <-tsr.collectCPUTimeChan: - timestamp := uint64(nowFunc().Unix()) - tsr.processCPUTimeData(timestamp, data) - case data := <-tsr.collectStmtStatsChan: - timestamp := uint64(nowFunc().Unix()) - tsr.stmtStatsBuffer[timestamp] = data - case <-reportTicker.C: - tsr.processStmtStatsData() - tsr.takeDataAndSendToReportChan() - // Update `reportTicker` if report interval changed. - if newInterval := topsqlstate.GlobalState.ReportIntervalSeconds.Load(); newInterval != currentReportInterval { - currentReportInterval = newInterval - reportTicker.Reset(time.Second * time.Duration(currentReportInterval)) - } - } - } -} - -// processCPUTimeData collects top N cpuRecords of each round into tsr.collecting, and evict the -// data that is not in top N. All the evicted cpuRecords will be summary into the others. -func (tsr *RemoteTopSQLReporter) processCPUTimeData(timestamp uint64, data cpuRecords) { - defer util.Recover("top-sql", "processCPUTimeData", nil, false) - - // Get top N cpuRecords of each round cpuRecords. Collect the top N to tsr.collecting - // for each round. SQL meta will not be evicted, since the evicted SQL can be appeared - // on other components (TiKV) TopN DataRecords. - top, evicted := data.topN(int(topsqlstate.GlobalState.MaxStatementCount.Load())) - for _, r := range top { - tsr.collecting.getOrCreateRecord(r.SQLDigest, r.PlanDigest).appendCPUTime(timestamp, r.CPUTimeMs) - } - if len(evicted) == 0 { - return - } - totalEvictedCPUTime := uint32(0) - for _, e := range evicted { - totalEvictedCPUTime += e.CPUTimeMs - // Mark which digests are evicted under each timestamp. - // We will determine whether the corresponding CPUTime has been evicted - // when collecting stmtstats. If so, then we can ignore it directly. - tsr.collecting.markAsEvicted(timestamp, e.SQLDigest, e.PlanDigest) - } - tsr.collecting.appendOthersCPUTime(timestamp, totalEvictedCPUTime) -} - -// processStmtStatsData collects tsr.stmtStatsBuffer into tsr.collecting. -// All the evicted items will be summary into the others. -func (tsr *RemoteTopSQLReporter) processStmtStatsData() { - defer util.Recover("top-sql", "processStmtStatsData", nil, false) - - for timestamp, data := range tsr.stmtStatsBuffer { - for digest, item := range data { - sqlDigest, planDigest := []byte(digest.SQLDigest), []byte(digest.PlanDigest) - if tsr.collecting.hasEvicted(timestamp, sqlDigest, planDigest) { - // This timestamp+sql+plan has been evicted due to low CPUTime. - tsr.collecting.appendOthersStmtStatsItem(timestamp, *item) - continue - } - tsr.collecting.getOrCreateRecord(sqlDigest, planDigest).appendStmtStatsItem(timestamp, *item) - } - } - tsr.stmtStatsBuffer = map[uint64]stmtstats.StatementStatsMap{} -} - -// takeDataAndSendToReportChan takes records data and then send to the report channel for reporting. -func (tsr *RemoteTopSQLReporter) takeDataAndSendToReportChan() { - // Send to report channel. When channel is full, data will be dropped. 
- select { - case tsr.reportCollectedDataChan <- collectedData{ - collected: tsr.collecting.take(), - normalizedSQLMap: tsr.normalizedSQLMap.take(), - normalizedPlanMap: tsr.normalizedPlanMap.take(), - }: - default: - // ignore if chan blocked - reporter_metrics.IgnoreReportChannelFullCounter.Inc() - } -} - -// reportWorker sends data to the gRPC endpoint from the `reportCollectedDataChan` one by one. -func (tsr *RemoteTopSQLReporter) reportWorker() { - defer util.Recover("top-sql", "reportWorker", nil, false) - - for { - select { - case data := <-tsr.reportCollectedDataChan: - // When `reportCollectedDataChan` receives something, there could be ongoing - // `RegisterSQL` and `RegisterPlan` running, who writes to the data structure - // that `data` contains. So we wait for a little while to ensure that writes - // are finished. - time.Sleep(time.Millisecond * 100) - rs := data.collected.getReportRecords() - // Convert to protobuf data and do report. - tsr.doReport(&ReportData{ - DataRecords: rs.toProto(), - SQLMetas: data.normalizedSQLMap.toProto(), - PlanMetas: data.normalizedPlanMap.toProto(tsr.decodePlan, tsr.compressPlan), - }) - case <-tsr.ctx.Done(): - return - } - } -} - -// doReport sends ReportData to DataSinks. -func (tsr *RemoteTopSQLReporter) doReport(data *ReportData) { - defer util.Recover("top-sql", "doReport", nil, false) - - if !data.hasData() { - return - } - timeout := reportTimeout - failpoint.Inject("resetTimeoutForTest", func(val failpoint.Value) { - if val.(bool) { - interval := time.Duration(topsqlstate.GlobalState.ReportIntervalSeconds.Load()) * time.Second - if interval < timeout { - timeout = interval - } - } - }) - _ = tsr.trySend(data, time.Now().Add(timeout)) -} - -// trySend sends ReportData to all internal registered DataSinks. -func (tsr *RemoteTopSQLReporter) trySend(data *ReportData, deadline time.Time) error { - tsr.DefaultDataSinkRegisterer.Lock() - dataSinks := make([]DataSink, 0, len(tsr.dataSinks)) - for ds := range tsr.dataSinks { - dataSinks = append(dataSinks, ds) - } - tsr.DefaultDataSinkRegisterer.Unlock() - for _, ds := range dataSinks { - if err := ds.TrySend(data, deadline); err != nil { - logutil.BgLogger().Warn("failed to send data to datasink", zap.String("category", "top-sql"), zap.Error(err)) - } - } - return nil -} - -// onReporterClosing calls the OnReporterClosing method of all internally registered DataSinks. -func (tsr *RemoteTopSQLReporter) onReporterClosing() { - var m map[DataSink]struct{} - tsr.DefaultDataSinkRegisterer.Lock() - m, tsr.dataSinks = tsr.dataSinks, make(map[DataSink]struct{}) - tsr.DefaultDataSinkRegisterer.Unlock() - for d := range m { - d.OnReporterClosing() - } -} - -// collectedData is used for transmission in the channel. 
-type collectedData struct { - collected *collecting - normalizedSQLMap *normalizedSQLMap - normalizedPlanMap *normalizedPlanMap -} diff --git a/pkg/util/topsql/topsql.go b/pkg/util/topsql/topsql.go index daa237f229fb7..d6818051125c7 100644 --- a/pkg/util/topsql/topsql.go +++ b/pkg/util/topsql/topsql.go @@ -106,7 +106,7 @@ func AttachAndRegisterSQLInfo(ctx context.Context, normalizedSQL string, sqlDige linkSQLTextWithDigest(sqlDigestBytes, normalizedSQL, isInternal) - if val, _err_ := failpoint.Eval(_curpkg_("mockHighLoadForEachSQL")); _err_ == nil { + failpoint.Inject("mockHighLoadForEachSQL", func(val failpoint.Value) { // In integration test, some SQL run very fast that Top SQL pprof profile unable to sample data of those SQL, // So need mock some high cpu load to make sure pprof profile successfully samples the data of those SQL. // Attention: Top SQL pprof profile unable to sample data of those SQL which run very fast, this behavior is expected. @@ -118,7 +118,7 @@ func AttachAndRegisterSQLInfo(ctx context.Context, normalizedSQL string, sqlDige logutil.BgLogger().Info("attach SQL info", zap.String("sql", normalizedSQL)) } } - } + }) return ctx } @@ -135,14 +135,14 @@ func AttachSQLAndPlanInfo(ctx context.Context, sqlDigest *parser.Digest, planDig ctx = collector.CtxWithSQLAndPlanDigest(ctx, sqlDigestStr, planDigestStr) pprof.SetGoroutineLabels(ctx) - if val, _err_ := failpoint.Eval(_curpkg_("mockHighLoadForEachPlan")); _err_ == nil { + failpoint.Inject("mockHighLoadForEachPlan", func(val failpoint.Value) { // Work like mockHighLoadForEachSQL failpoint. if val.(bool) { if MockHighCPULoad("", []string{""}, 1) { logutil.BgLogger().Info("attach SQL info") } } - } + }) return ctx } diff --git a/pkg/util/topsql/topsql.go__failpoint_stash__ b/pkg/util/topsql/topsql.go__failpoint_stash__ deleted file mode 100644 index d6818051125c7..0000000000000 --- a/pkg/util/topsql/topsql.go__failpoint_stash__ +++ /dev/null @@ -1,187 +0,0 @@ -// Copyright 2021 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package topsql - -import ( - "context" - "runtime/pprof" - "strings" - "time" - - "github.com/pingcap/failpoint" - "github.com/pingcap/tidb/pkg/parser" - "github.com/pingcap/tidb/pkg/util/logutil" - "github.com/pingcap/tidb/pkg/util/plancodec" - "github.com/pingcap/tidb/pkg/util/topsql/collector" - "github.com/pingcap/tidb/pkg/util/topsql/reporter" - "github.com/pingcap/tidb/pkg/util/topsql/stmtstats" - "github.com/pingcap/tipb/go-tipb" - "go.uber.org/zap" - "google.golang.org/grpc" -) - -const ( - // MaxSQLTextSize exports for testing. - MaxSQLTextSize = 4 * 1024 - // MaxBinaryPlanSize exports for testing. 
- MaxBinaryPlanSize = 2 * 1024 -) - -var ( - globalTopSQLReport reporter.TopSQLReporter - singleTargetDataSink *reporter.SingleTargetDataSink -) - -func init() { - remoteReporter := reporter.NewRemoteTopSQLReporter(plancodec.DecodeNormalizedPlan, plancodec.Compress) - globalTopSQLReport = remoteReporter - singleTargetDataSink = reporter.NewSingleTargetDataSink(remoteReporter) -} - -// SetupTopSQL sets up the top-sql worker. -func SetupTopSQL() { - globalTopSQLReport.Start() - singleTargetDataSink.Start() - - stmtstats.RegisterCollector(globalTopSQLReport) - stmtstats.SetupAggregator() -} - -// SetupTopSQLForTest sets up the global top-sql reporter, it's exporting for test. -func SetupTopSQLForTest(r reporter.TopSQLReporter) { - globalTopSQLReport = r -} - -// RegisterPubSubServer registers TopSQLPubSubService to the given gRPC server. -func RegisterPubSubServer(s *grpc.Server) { - if register, ok := globalTopSQLReport.(reporter.DataSinkRegisterer); ok { - service := reporter.NewTopSQLPubSubService(register) - tipb.RegisterTopSQLPubSubServer(s, service) - } -} - -// Close uses to close and release the top sql resource. -func Close() { - singleTargetDataSink.Close() - globalTopSQLReport.Close() - stmtstats.CloseAggregator() -} - -// RegisterSQL uses to register SQL information into Top SQL. -func RegisterSQL(normalizedSQL string, sqlDigest *parser.Digest, isInternal bool) { - if sqlDigest != nil { - sqlDigestBytes := sqlDigest.Bytes() - linkSQLTextWithDigest(sqlDigestBytes, normalizedSQL, isInternal) - } -} - -// RegisterPlan uses to register plan information into Top SQL. -func RegisterPlan(normalizedPlan string, planDigest *parser.Digest) { - if planDigest != nil { - planDigestBytes := planDigest.Bytes() - linkPlanTextWithDigest(planDigestBytes, normalizedPlan) - } -} - -// AttachAndRegisterSQLInfo attach the sql information into Top SQL and register the SQL meta information. -func AttachAndRegisterSQLInfo(ctx context.Context, normalizedSQL string, sqlDigest *parser.Digest, isInternal bool) context.Context { - if sqlDigest == nil || len(sqlDigest.String()) == 0 { - return ctx - } - sqlDigestBytes := sqlDigest.Bytes() - ctx = collector.CtxWithSQLDigest(ctx, sqlDigest.String()) - pprof.SetGoroutineLabels(ctx) - - linkSQLTextWithDigest(sqlDigestBytes, normalizedSQL, isInternal) - - failpoint.Inject("mockHighLoadForEachSQL", func(val failpoint.Value) { - // In integration test, some SQL run very fast that Top SQL pprof profile unable to sample data of those SQL, - // So need mock some high cpu load to make sure pprof profile successfully samples the data of those SQL. - // Attention: Top SQL pprof profile unable to sample data of those SQL which run very fast, this behavior is expected. - // The integration test was just want to make sure each type of SQL will be set goroutine labels and and can be collected. 
- if val.(bool) { - sqlPrefixes := []string{"insert", "update", "delete", "load", "replace", "select", "begin", - "commit", "analyze", "explain", "trace", "create", "set global"} - if MockHighCPULoad(normalizedSQL, sqlPrefixes, 1) { - logutil.BgLogger().Info("attach SQL info", zap.String("sql", normalizedSQL)) - } - } - }) - return ctx -} - -// AttachSQLAndPlanInfo attach the sql and plan information into Top SQL -func AttachSQLAndPlanInfo(ctx context.Context, sqlDigest *parser.Digest, planDigest *parser.Digest) context.Context { - if sqlDigest == nil || len(sqlDigest.String()) == 0 { - return ctx - } - var planDigestStr string - sqlDigestStr := sqlDigest.String() - if planDigest != nil { - planDigestStr = planDigest.String() - } - ctx = collector.CtxWithSQLAndPlanDigest(ctx, sqlDigestStr, planDigestStr) - pprof.SetGoroutineLabels(ctx) - - failpoint.Inject("mockHighLoadForEachPlan", func(val failpoint.Value) { - // Work like mockHighLoadForEachSQL failpoint. - if val.(bool) { - if MockHighCPULoad("", []string{""}, 1) { - logutil.BgLogger().Info("attach SQL info") - } - } - }) - return ctx -} - -// MockHighCPULoad mocks high cpu load, only use in failpoint test. -func MockHighCPULoad(sql string, sqlPrefixs []string, load int64) bool { - lowerSQL := strings.ToLower(sql) - if strings.Contains(lowerSQL, "mysql") && !strings.Contains(lowerSQL, "global_variables") { - return false - } - match := false - for _, prefix := range sqlPrefixs { - if strings.HasPrefix(lowerSQL, prefix) { - match = true - break - } - } - if !match { - return false - } - start := time.Now() - for { - if time.Since(start) > 12*time.Millisecond*time.Duration(load) { - break - } - for i := 0; i < 10e5; i++ { - continue - } - } - return true -} - -func linkSQLTextWithDigest(sqlDigest []byte, normalizedSQL string, isInternal bool) { - if len(normalizedSQL) > MaxSQLTextSize { - normalizedSQL = normalizedSQL[:MaxSQLTextSize] - } - - globalTopSQLReport.RegisterSQL(sqlDigest, normalizedSQL, isInternal) -} - -func linkPlanTextWithDigest(planDigest []byte, normalizedBinaryPlan string) { - globalTopSQLReport.RegisterPlan(planDigest, normalizedBinaryPlan, len(normalizedBinaryPlan) > MaxBinaryPlanSize) -} From 308977be236549b8382fca145317bf56db1ca375 Mon Sep 17 00:00:00 2001 From: tpp Date: Wed, 7 Aug 2024 15:25:38 -0500 Subject: [PATCH 08/35] testcase updates7 --- pkg/planner/cardinality/selectivity_test.go | 4 +-- .../partition_with_expression.result | 36 +++++++++---------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/pkg/planner/cardinality/selectivity_test.go b/pkg/planner/cardinality/selectivity_test.go index 20d8abf39e825..9aaed714dce80 100644 --- a/pkg/planner/cardinality/selectivity_test.go +++ b/pkg/planner/cardinality/selectivity_test.go @@ -1212,8 +1212,8 @@ func TestIgnoreRealtimeStats(t *testing.T) { // From the real-time stats, we are able to know the total count is 11. 
testKit.MustExec("set @@tidb_opt_objective = 'moderate'") testKit.MustQuery("explain select * from t where a = 1 and b > 2").Check(testkit.Rows( - "TableReader_7 1.00 root data:Selection_6", - "└─Selection_6 1.00 cop[tikv] eq(test.t.a, 1), gt(test.t.b, 2)", + "TableReader_7 0.00 root data:Selection_6", + "└─Selection_6 0.00 cop[tikv] eq(test.t.a, 1), gt(test.t.b, 2)", " └─TableFullScan_5 11.00 cop[tikv] table:t keep order:false, stats:pseudo", )) diff --git a/tests/integrationtest/r/executor/partition/partition_with_expression.result b/tests/integrationtest/r/executor/partition/partition_with_expression.result index 9ba24d016ecd1..81c168a8d0bc1 100644 --- a/tests/integrationtest/r/executor/partition/partition_with_expression.result +++ b/tests/integrationtest/r/executor/partition/partition_with_expression.result @@ -184,8 +184,8 @@ analyze table tp all columns; analyze table t all columns; explain select * from tp where a < '10'; id estRows task access object operator info -TableReader_7 0.00 root partition:p0 data:Selection_6 -└─Selection_6 0.00 cop[tikv] lt(executor__partition__partition_with_expression.tp.a, "10") +TableReader_7 1.00 root partition:p0 data:Selection_6 +└─Selection_6 1.00 cop[tikv] lt(executor__partition__partition_with_expression.tp.a, "10") └─TableFullScan_5 6.00 cop[tikv] table:tp keep order:false select * from tp where a < '10'; a b @@ -274,15 +274,15 @@ SELECT * from t where a = -1; a b explain format='brief' select * from trange where a = -1; id estRows task access object operator info -TableReader 0.00 root partition:p0 data:Selection -└─Selection 0.00 cop[tikv] eq(executor__partition__partition_with_expression.trange.a, -1) +TableReader 1.00 root partition:p0 data:Selection +└─Selection 1.00 cop[tikv] eq(executor__partition__partition_with_expression.trange.a, -1) └─TableFullScan 13.00 cop[tikv] table:trange keep order:false SELECT * from trange where a = -1; a b explain format='brief' select * from thash where a = -1; id estRows task access object operator info -TableReader 0.00 root partition:p1 data:Selection -└─Selection 0.00 cop[tikv] eq(executor__partition__partition_with_expression.thash.a, -1) +TableReader 1.00 root partition:p1 data:Selection +└─Selection 1.00 cop[tikv] eq(executor__partition__partition_with_expression.thash.a, -1) └─TableFullScan 13.00 cop[tikv] table:thash keep order:false SELECT * from thash where a = -1; a b @@ -411,15 +411,15 @@ SELECT * from t where a > 10; a b explain format='brief' select * from trange where a > 10; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_with_expression.trange.a, 10) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_with_expression.trange.a, 10) └─TableFullScan 13.00 cop[tikv] table:trange keep order:false SELECT * from trange where a > 10; a b explain format='brief' select * from thash where a > 10; id estRows task access object operator info -TableReader 0.00 root partition:all data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_with_expression.thash.a, 10) +TableReader 1.00 root partition:all data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_with_expression.thash.a, 10) └─TableFullScan 13.00 cop[tikv] table:thash keep order:false SELECT * from thash where a > 10; a b @@ -1219,15 +1219,15 @@ SELECT * from t where a > '10'; a b explain format='brief' select * from trange where a 
> '10'; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_with_expression.trange.a, 10) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_with_expression.trange.a, 10) └─TableFullScan 13.00 cop[tikv] table:trange keep order:false SELECT * from trange where a > '10'; a b explain format='brief' select * from thash where a > '10'; id estRows task access object operator info -TableReader 0.00 root partition:all data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_with_expression.thash.a, 10) +TableReader 1.00 root partition:all data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_with_expression.thash.a, 10) └─TableFullScan 13.00 cop[tikv] table:thash keep order:false SELECT * from thash where a > '10'; a b @@ -1235,15 +1235,15 @@ SELECT * from t where a > '10ab'; a b explain format='brief' select * from trange where a > '10ab'; id estRows task access object operator info -TableReader 0.00 root partition:dual data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_with_expression.trange.a, 10) +TableReader 1.00 root partition:dual data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_with_expression.trange.a, 10) └─TableFullScan 13.00 cop[tikv] table:trange keep order:false SELECT * from trange where a > '10ab'; a b explain format='brief' select * from thash where a > '10ab'; id estRows task access object operator info -TableReader 0.00 root partition:all data:Selection -└─Selection 0.00 cop[tikv] gt(executor__partition__partition_with_expression.thash.a, 10) +TableReader 1.00 root partition:all data:Selection +└─Selection 1.00 cop[tikv] gt(executor__partition__partition_with_expression.thash.a, 10) └─TableFullScan 13.00 cop[tikv] table:thash keep order:false SELECT * from thash where a > '10ab'; a b From 33659970d8359dd35b9dd2cf9ab917150fdfbb53 Mon Sep 17 00:00:00 2001 From: tpp Date: Wed, 7 Aug 2024 15:43:48 -0500 Subject: [PATCH 09/35] testcase updates8 --- pkg/planner/cardinality/selectivity_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/planner/cardinality/selectivity_test.go b/pkg/planner/cardinality/selectivity_test.go index 9aaed714dce80..8dcce8ca7ff45 100644 --- a/pkg/planner/cardinality/selectivity_test.go +++ b/pkg/planner/cardinality/selectivity_test.go @@ -235,7 +235,7 @@ func TestEstimationForUnknownValues(t *testing.T) { colID := table.Meta().Columns[0].ID count, err := cardinality.GetRowCountByColumnRanges(sctx, &statsTbl.HistColl, colID, getRange(30, 30)) require.NoError(t, err) - require.Equal(t, 1.2, count) + require.Equal(t, 1.0, count) count, err = cardinality.GetRowCountByColumnRanges(sctx, &statsTbl.HistColl, colID, getRange(9, 30)) require.NoError(t, err) From 38d604d2edb4ea9766ce0041b145db06383f0ec0 Mon Sep 17 00:00:00 2001 From: tpp Date: Wed, 7 Aug 2024 16:14:12 -0500 Subject: [PATCH 10/35] testcase updates9 --- .../casetest/cbotest/testdata/analyze_suite_out.json | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pkg/planner/core/casetest/cbotest/testdata/analyze_suite_out.json b/pkg/planner/core/casetest/cbotest/testdata/analyze_suite_out.json index 2d4ce04bd6001..217ac1bedd0ca 100644 --- a/pkg/planner/core/casetest/cbotest/testdata/analyze_suite_out.json +++ b/pkg/planner/core/casetest/cbotest/testdata/analyze_suite_out.json @@ -364,13 +364,13 
@@ "└─TableRowIDScan(Probe) 2.00 cop[tikv] table:t keep order:false" ], [ - "TableReader 0.00 root data:Selection", - "└─Selection 0.00 cop[tikv] eq(test.t.b, 1)", + "TableReader 1.00 root data:Selection", + "└─Selection 1.00 cop[tikv] eq(test.t.b, 1)", " └─TableFullScan 2.00 cop[tikv] table:t keep order:false" ], [ - "TableReader 0.00 root data:Selection", - "└─Selection 0.00 cop[tikv] lt(test.t.b, 1)", + "TableReader 1.00 root data:Selection", + "└─Selection 1.00 cop[tikv] lt(test.t.b, 1)", " └─TableFullScan 2.00 cop[tikv] table:t keep order:false" ] ] @@ -435,7 +435,7 @@ "Cases": [ "IndexReader(Index(t.e)[[NULL,+inf]]->StreamAgg)->StreamAgg", "IndexReader(Index(t.e)[[-inf,10]]->StreamAgg)->StreamAgg", - "IndexReader(Index(t.e)[[-inf,50]]->StreamAgg)->StreamAgg", + "IndexReader(Index(t.e)[[-inf,50]]->HashAgg)->HashAgg", "IndexReader(Index(t.b_c)[[NULL,+inf]]->Sel([gt(test.t.c, 1)])->StreamAgg)->StreamAgg", "IndexLookUp(Index(t.e)[[1,1]], Table(t))->HashAgg", "TableReader(Table(t)->Sel([gt(test.t.e, 1)])->HashAgg)->HashAgg", @@ -503,7 +503,7 @@ "TopN 1.00 root test.t.b, offset:0, count:1", "└─TableReader 1.00 root data:TopN", " └─TopN 1.00 cop[tikv] test.t.b, offset:0, count:1", - " └─Selection 10000.00 cop[tikv] le(test.t.a, 10000)", + " └─Selection 510000.00 cop[tikv] le(test.t.a, 10000)", " └─TableFullScan 1000000.00 cop[tikv] table:t keep order:false" ] }, From e122ec43554cc68e43fcf2ff63e3f0c23eacdc71 Mon Sep 17 00:00:00 2001 From: tpp Date: Wed, 7 Aug 2024 16:34:59 -0500 Subject: [PATCH 11/35] testcase updates10 --- pkg/planner/cardinality/selectivity_test.go | 2 +- pkg/planner/core/testdata/index_merge_suite_out.json | 10 +++++----- .../testdata/runtime_filter_generator_suite_out.json | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pkg/planner/cardinality/selectivity_test.go b/pkg/planner/cardinality/selectivity_test.go index 8dcce8ca7ff45..b6d9c7534db8c 100644 --- a/pkg/planner/cardinality/selectivity_test.go +++ b/pkg/planner/cardinality/selectivity_test.go @@ -248,7 +248,7 @@ func TestEstimationForUnknownValues(t *testing.T) { idxID := table.Meta().Indices[0].ID count, err = cardinality.GetRowCountByIndexRanges(sctx, &statsTbl.HistColl, idxID, getRange(30, 30)) require.NoError(t, err) - require.Equal(t, 1.0, count) + require.Equal(t, 0.1, count) count, err = cardinality.GetRowCountByIndexRanges(sctx, &statsTbl.HistColl, idxID, getRange(9, 30)) require.NoError(t, err) diff --git a/pkg/planner/core/testdata/index_merge_suite_out.json b/pkg/planner/core/testdata/index_merge_suite_out.json index 0d98061beedd1..66cc33e82a435 100644 --- a/pkg/planner/core/testdata/index_merge_suite_out.json +++ b/pkg/planner/core/testdata/index_merge_suite_out.json @@ -252,7 +252,7 @@ { "SQL": "select * from vh", "Plan": [ - "PartitionUnion 0.50 root ", + "PartitionUnion 1.50 root ", "├─IndexMerge 0.50 root type: intersection", "│ ├─IndexRangeScan(Build) 2.00 cop[tikv] table:t1, partition:p0, index:ia(a) range:[10,10], keep order:false", "│ ├─IndexRangeScan(Build) 1.00 cop[tikv] table:t1, partition:p0, index:ibc(b, c) range:[20 -inf,20 30), keep order:false", @@ -276,7 +276,7 @@ { "SQL": "select /*+ qb_name(v, v), use_index_merge(@v t1, ia, ibc, id) */ * from v", "Plan": [ - "PartitionUnion 0.50 root ", + "PartitionUnion 1.50 root ", "├─IndexMerge 0.50 root type: intersection", "│ ├─IndexRangeScan(Build) 2.00 cop[tikv] table:t1, partition:p0, index:ia(a) range:[10,10], keep order:false", "│ ├─IndexRangeScan(Build) 1.00 cop[tikv] table:t1, partition:p0, index:ibc(b, c) 
range:[20 -inf,20 30), keep order:false", @@ -300,7 +300,7 @@ { "SQL": "select /*+ qb_name(v, v@sel_1), use_index_merge(@v t1, ia, ibc, id) */ * from v", "Plan": [ - "PartitionUnion 0.50 root ", + "PartitionUnion 1.50 root ", "├─IndexMerge 0.50 root type: intersection", "│ ├─IndexRangeScan(Build) 2.00 cop[tikv] table:t1, partition:p0, index:ia(a) range:[10,10], keep order:false", "│ ├─IndexRangeScan(Build) 1.00 cop[tikv] table:t1, partition:p0, index:ibc(b, c) range:[20 -inf,20 30), keep order:false", @@ -324,7 +324,7 @@ { "SQL": "select /*+ qb_name(v, v@sel_1 .@sel_1), use_index_merge(@v t1, ia, ibc, id) */ * from v", "Plan": [ - "PartitionUnion 0.50 root ", + "PartitionUnion 1.50 root ", "├─IndexMerge 0.50 root type: intersection", "│ ├─IndexRangeScan(Build) 2.00 cop[tikv] table:t1, partition:p0, index:ia(a) range:[10,10], keep order:false", "│ ├─IndexRangeScan(Build) 1.00 cop[tikv] table:t1, partition:p0, index:ibc(b, c) range:[20 -inf,20 30), keep order:false", @@ -348,7 +348,7 @@ { "SQL": "select /*+ qb_name(v, v@sel_1 .@sel_1), use_index_merge(@v t1, ia, ibc, id) */ * from v", "Plan": [ - "PartitionUnion 0.50 root ", + "PartitionUnion 1.50 root ", "├─IndexMerge 0.50 root type: intersection", "│ ├─IndexRangeScan(Build) 2.00 cop[tikv] table:t1, partition:p0, index:ia(a) range:[10,10], keep order:false", "│ ├─IndexRangeScan(Build) 1.00 cop[tikv] table:t1, partition:p0, index:ibc(b, c) range:[20 -inf,20 30), keep order:false", diff --git a/pkg/planner/core/testdata/runtime_filter_generator_suite_out.json b/pkg/planner/core/testdata/runtime_filter_generator_suite_out.json index 131c9e9c6c219..158cfc64b6e4b 100644 --- a/pkg/planner/core/testdata/runtime_filter_generator_suite_out.json +++ b/pkg/planner/core/testdata/runtime_filter_generator_suite_out.json @@ -5,14 +5,14 @@ { "SQL": "select /*+ hash_join_build(t1) */ * from t1, t2 where t1.k1=t2.k1 and t2.k2 = 1", "Plan": [ - "TableReader_32 0.00 root MppVersion: 2, data:ExchangeSender_31", - "└─ExchangeSender_31 0.00 mpp[tiflash] ExchangeType: PassThrough", - " └─HashJoin_24 0.00 mpp[tiflash] inner join, equal:[eq(test.t1.k1, test.t2.k1)], runtime filter:0[IN] <- test.t1.k1", + "TableReader_32 1.00 root MppVersion: 2, data:ExchangeSender_31", + "└─ExchangeSender_31 1.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashJoin_24 1.00 mpp[tiflash] inner join, equal:[eq(test.t1.k1, test.t2.k1)], runtime filter:0[IN] <- test.t1.k1", " ├─ExchangeReceiver_28(Build) 1.00 mpp[tiflash] ", " │ └─ExchangeSender_27 1.00 mpp[tiflash] ExchangeType: Broadcast, Compression: FAST", " │ └─Selection_26 1.00 mpp[tiflash] not(isnull(test.t1.k1))", " │ └─TableFullScan_25 1.00 mpp[tiflash] table:t1 pushed down filter:empty, keep order:false", - " └─Selection_30(Probe) 0.00 mpp[tiflash] eq(test.t2.k2, 1), not(isnull(test.t2.k1))", + " └─Selection_30(Probe) 1.00 mpp[tiflash] eq(test.t2.k2, 1), not(isnull(test.t2.k1))", " └─TableFullScan_29 1.00 mpp[tiflash] table:t2 pushed down filter:empty, keep order:false, runtime filter:0[IN] -> test.t2.k1" ] }, From 2a0f635281de7c2e71e05e17e73f8e625d51ad98 Mon Sep 17 00:00:00 2001 From: tpp Date: Wed, 7 Aug 2024 21:14:18 -0500 Subject: [PATCH 12/35] testcase updates11 --- pkg/statistics/handle/handletest/handle_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/statistics/handle/handletest/handle_test.go b/pkg/statistics/handle/handletest/handle_test.go index 51cb295f3adc9..5c83c9e23a8c5 100644 --- a/pkg/statistics/handle/handletest/handle_test.go +++ 
b/pkg/statistics/handle/handletest/handle_test.go @@ -95,7 +95,7 @@ func TestColumnIDs(t *testing.T) { // At that time, we should get c2's stats instead of c1's. count, err = cardinality.GetRowCountByColumnRanges(sctx, &statsTbl.HistColl, tableInfo.Columns[0].ID, []*ranger.Range{ran}) require.NoError(t, err) - require.Equal(t, 0.0, count) + require.Equal(t, 1.0, count) } func TestDurationToTS(t *testing.T) { From 42c9647272350e8b3eb779032d343bd65e843956 Mon Sep 17 00:00:00 2001 From: tpp Date: Wed, 7 Aug 2024 21:33:30 -0500 Subject: [PATCH 13/35] testcase updates12 --- pkg/planner/cardinality/selectivity_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/planner/cardinality/selectivity_test.go b/pkg/planner/cardinality/selectivity_test.go index b6d9c7534db8c..7c101e497e437 100644 --- a/pkg/planner/cardinality/selectivity_test.go +++ b/pkg/planner/cardinality/selectivity_test.go @@ -282,7 +282,7 @@ func TestEstimationForUnknownValues(t *testing.T) { idxID = table.Meta().Indices[0].ID count, err = cardinality.GetRowCountByIndexRanges(sctx, &statsTbl.HistColl, idxID, getRange(2, 2)) require.NoError(t, err) - require.Equal(t, 1.0, count) + require.Equal(t, 0.0, count) } func TestEstimationUniqueKeyEqualConds(t *testing.T) { From 78b246bf4d19e9519fcd7877ff0a5022c300f2a5 Mon Sep 17 00:00:00 2001 From: tpp Date: Wed, 7 Aug 2024 22:01:35 -0500 Subject: [PATCH 14/35] testcase updates13 --- .../testdata/integration_partition_suite_out.json | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pkg/planner/core/casetest/partition/testdata/integration_partition_suite_out.json b/pkg/planner/core/casetest/partition/testdata/integration_partition_suite_out.json index 71456e2d40833..e70e65ab26ccd 100644 --- a/pkg/planner/core/casetest/partition/testdata/integration_partition_suite_out.json +++ b/pkg/planner/core/casetest/partition/testdata/integration_partition_suite_out.json @@ -709,7 +709,7 @@ "└─IndexRangeScan 1.00 cop[tikv] table:t, index:b(b) range:[1,1], keep order:false" ], "StaticPlan": [ - "PartitionUnion 1.00 root ", + "PartitionUnion 3.00 root ", "├─IndexReader 1.00 root index:IndexRangeScan", "│ └─IndexRangeScan 1.00 cop[tikv] table:t, partition:P0, index:b(b) range:[1,1], keep order:false", "├─IndexReader 1.00 root index:IndexRangeScan", @@ -725,7 +725,7 @@ "└─IndexRangeScan 1.00 cop[tikv] table:t, index:b(b) range:[2,2], keep order:false" ], "StaticPlan": [ - "PartitionUnion 1.00 root ", + "PartitionUnion 3.00 root ", "├─IndexReader 1.00 root index:IndexRangeScan", "│ └─IndexRangeScan 1.00 cop[tikv] table:t, partition:P0, index:b(b) range:[2,2], keep order:false", "├─IndexReader 1.00 root index:IndexRangeScan", @@ -741,7 +741,7 @@ "└─IndexRangeScan 2.00 cop[tikv] table:t, index:b(b) range:[1,2], keep order:false" ], "StaticPlan": [ - "PartitionUnion 2.00 root ", + "PartitionUnion 3.00 root ", "├─IndexReader 1.00 root index:IndexRangeScan", "│ └─IndexRangeScan 1.00 cop[tikv] table:t, partition:P0, index:b(b) range:[1,2], keep order:false", "├─IndexReader 1.00 root index:IndexRangeScan", @@ -757,7 +757,7 @@ "└─IndexRangeScan 2.00 cop[tikv] table:t, index:b(b) range:[2,2], [3,3], [4,4], keep order:false" ], "StaticPlan": [ - "PartitionUnion 2.00 root ", + "PartitionUnion 3.00 root ", "├─IndexReader 1.00 root index:IndexRangeScan", "│ └─IndexRangeScan 1.00 cop[tikv] table:t, partition:P0, index:b(b) range:[2,2], [3,3], [4,4], keep order:false", "├─IndexReader 1.00 root index:IndexRangeScan", @@ -773,7 +773,7 @@ "└─IndexRangeScan 2.00 cop[tikv] 
table:t, index:b(b) range:[2,2], [3,3], keep order:false" ], "StaticPlan": [ - "PartitionUnion 2.00 root ", + "PartitionUnion 3.00 root ", "├─IndexReader 1.00 root index:IndexRangeScan", "│ └─IndexRangeScan 1.00 cop[tikv] table:t, partition:P0, index:b(b) range:[2,2], [3,3], keep order:false", "├─IndexReader 1.00 root index:IndexRangeScan", @@ -854,7 +854,7 @@ "└─IndexRangeScan 1.00 cop[tikv] table:t, index:b(b) range:[1,1], keep order:false" ], "StaticPlan": [ - "PartitionUnion 1.00 root ", + "PartitionUnion 2.00 root ", "├─IndexReader 1.00 root index:IndexRangeScan", "│ └─IndexRangeScan 1.00 cop[tikv] table:t, partition:P0, index:b(b) range:[1,1], keep order:false", "└─IndexReader 1.00 root index:IndexRangeScan", From c1e77ec2672ae21ab95426c9e4d972384d3d9386 Mon Sep 17 00:00:00 2001 From: tpp Date: Thu, 8 Aug 2024 09:20:34 -0500 Subject: [PATCH 15/35] testcase updates14 --- pkg/table/tables/test/partition/partition_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/table/tables/test/partition/partition_test.go b/pkg/table/tables/test/partition/partition_test.go index fa6e5b560f044..8d3b8542080a4 100644 --- a/pkg/table/tables/test/partition/partition_test.go +++ b/pkg/table/tables/test/partition/partition_test.go @@ -3244,8 +3244,8 @@ func TestPartitionCoverage(t *testing.T) { " └─TableFullScan 10000.00 cop[tikv] table:t keep order:false, stats:pseudo")) tk.MustExec(`analyze table t all columns`) tk.MustQuery(`explain format='brief' select * from t where a = 10`).Check(testkit.Rows(""+ - `TableReader 0.00 root partition:dual data:Selection`, - `└─Selection 0.00 cop[tikv] eq(test.t.a, 10)`, + `TableReader 1.00 root partition:dual data:Selection`, + `└─Selection 1.00 cop[tikv] eq(test.t.a, 10)`, ` └─TableFullScan 1.00 cop[tikv] table:t keep order:false`)) tk.MustQuery(`select * from t where a = 10`).Check(testkit.Rows()) From d9b31a03f136f4baa19cea706b9ef4dee51c6910 Mon Sep 17 00:00:00 2001 From: tpp Date: Thu, 8 Aug 2024 10:07:06 -0500 Subject: [PATCH 16/35] testcase updates15 --- pkg/planner/cardinality/selectivity_test.go | 8 ++++---- .../cardinality/testdata/cardinality_suite_out.json | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pkg/planner/cardinality/selectivity_test.go b/pkg/planner/cardinality/selectivity_test.go index 7c101e497e437..0284d78d47deb 100644 --- a/pkg/planner/cardinality/selectivity_test.go +++ b/pkg/planner/cardinality/selectivity_test.go @@ -401,7 +401,7 @@ func TestSelectivity(t *testing.T) { }, { exprs: "a >= 1 and b > 1 and a < 2", - selectivity: 0.01783264746, + selectivity: 0.01851851852, selectivityAfterIncrease: 0.01851851852, }, { @@ -421,13 +421,13 @@ func TestSelectivity(t *testing.T) { }, { exprs: "b > 1", - selectivity: 0.96296296296, + selectivity: 1, selectivityAfterIncrease: 1, }, { exprs: "a > 1 and b < 2 and c > 3 and d < 4 and e > 5", - selectivity: 0, - selectivityAfterIncrease: 0, + selectivity: 0.00099451303, + selectivityAfterIncrease: 1.51329827770157e-05, }, { exprs: longExpr, diff --git a/pkg/planner/cardinality/testdata/cardinality_suite_out.json b/pkg/planner/cardinality/testdata/cardinality_suite_out.json index b982afe0a0095..682889e62880e 100644 --- a/pkg/planner/cardinality/testdata/cardinality_suite_out.json +++ b/pkg/planner/cardinality/testdata/cardinality_suite_out.json @@ -24,7 +24,7 @@ { "Start": 800, "End": 900, - "Count": 752.004166655054 + "Count": 764.004166655054 }, { "Start": 900, @@ -79,7 +79,7 @@ { "Start": 800, "End": 1000, - "Count": 1210.196869573942 + "Count": 
1222.196869573942 }, { "Start": 900, @@ -104,7 +104,7 @@ { "Start": 200, "End": 400, - "Count": 1216.5288209899081 + "Count": 1198.5288209899081 }, { "Start": 200, From 1551ec3acaf4ee239e3cf1457350fd69dc8457e4 Mon Sep 17 00:00:00 2001 From: tpp Date: Thu, 8 Aug 2024 12:09:28 -0500 Subject: [PATCH 17/35] testcase updates16 --- pkg/planner/cardinality/selectivity_test.go | 2 +- .../integrationtest/r/clustered_index.result | 90 ++++++------ .../r/explain_complex_stats.result | 33 ++--- tests/integrationtest/r/explain_easy.result | 4 +- .../r/explain_easy_stats.result | 6 +- .../explain_generate_column_substitute.result | 8 +- .../r/explain_indexmerge_stats.result | 138 +++++++----------- tests/integrationtest/r/imdbload.result | 50 +++---- .../r/planner/cardinality/selectivity.result | 16 +- .../r/planner/core/partition_pruner.result | 16 +- .../r/statistics/integration.result | 16 +- tests/integrationtest/r/tpch.result | 30 ++-- tests/integrationtest/r/util/ranger.result | 22 +-- 13 files changed, 202 insertions(+), 229 deletions(-) diff --git a/pkg/planner/cardinality/selectivity_test.go b/pkg/planner/cardinality/selectivity_test.go index 0284d78d47deb..b82760d15315f 100644 --- a/pkg/planner/cardinality/selectivity_test.go +++ b/pkg/planner/cardinality/selectivity_test.go @@ -933,7 +933,7 @@ func TestIssue39593(t *testing.T) { count, err := cardinality.GetRowCountByIndexRanges(sctx.GetPlanCtx(), &statsTbl.HistColl, idxID, getRanges(vals, vals)) require.NoError(t, err) // estimated row count without any changes - require.Equal(t, float64(360), count) + require.Equal(t, float64(540), count) statsTbl.RealtimeCount *= 10 count, err = cardinality.GetRowCountByIndexRanges(sctx.GetPlanCtx(), &statsTbl.HistColl, idxID, getRanges(vals, vals)) require.NoError(t, err) diff --git a/tests/integrationtest/r/clustered_index.result b/tests/integrationtest/r/clustered_index.result index 8e666459932ad..890713221fb6a 100644 --- a/tests/integrationtest/r/clustered_index.result +++ b/tests/integrationtest/r/clustered_index.result @@ -18,35 +18,35 @@ id estRows task access object operator info HashAgg_12 1.00 root funcs:count(Column#7)->Column#6 └─IndexReader_13 1.00 root index:HashAgg_6 └─HashAgg_6 1.00 cop[tikv] funcs:count(1)->Column#7 - └─IndexRangeScan_11 798.87 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,5429), keep order:false + └─IndexRangeScan_11 1920.87 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,5429), keep order:false explain select count(*) from wout_cluster_index.tbl_0 where col_0 < 5429 ; id estRows task access object operator info HashAgg_12 1.00 root funcs:count(Column#8)->Column#7 └─IndexReader_13 1.00 root index:HashAgg_6 └─HashAgg_6 1.00 cop[tikv] funcs:count(1)->Column#8 - └─IndexRangeScan_11 798.87 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,5429), keep order:false + └─IndexRangeScan_11 1920.87 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,5429), keep order:false explain select count(*) from with_cluster_index.tbl_0 where col_0 < 41 ; id estRows task access object operator info -StreamAgg_17 1.00 root funcs:count(Column#8)->Column#6 -└─IndexReader_18 1.00 root index:StreamAgg_9 - └─StreamAgg_9 1.00 cop[tikv] funcs:count(1)->Column#8 - └─IndexRangeScan_16 41.00 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,41), keep order:false +HashAgg_12 1.00 root funcs:count(Column#7)->Column#6 +└─IndexReader_13 1.00 root index:HashAgg_6 + └─HashAgg_6 1.00 cop[tikv] funcs:count(1)->Column#7 + └─IndexRangeScan_11 1163.00 cop[tikv] table:tbl_0, 
index:idx_3(col_0) range:[-inf,41), keep order:false explain select count(*) from wout_cluster_index.tbl_0 where col_0 < 41 ; id estRows task access object operator info -StreamAgg_17 1.00 root funcs:count(Column#9)->Column#7 -└─IndexReader_18 1.00 root index:StreamAgg_9 - └─StreamAgg_9 1.00 cop[tikv] funcs:count(1)->Column#9 - └─IndexRangeScan_16 41.00 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,41), keep order:false +HashAgg_12 1.00 root funcs:count(Column#8)->Column#7 +└─IndexReader_13 1.00 root index:HashAgg_6 + └─HashAgg_6 1.00 cop[tikv] funcs:count(1)->Column#8 + └─IndexRangeScan_11 1163.00 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,41), keep order:false explain select col_14 from with_cluster_index.tbl_2 where col_11 <> '2013-11-01' ; id estRows task access object operator info -IndexReader_9 4509.00 root index:Projection_5 -└─Projection_5 4509.00 cop[tikv] with_cluster_index.tbl_2.col_14 - └─IndexRangeScan_8 4509.00 cop[tikv] table:tbl_2, index:idx_9(col_11) range:[-inf,2013-11-01 00:00:00), (2013-11-01 00:00:00,+inf], keep order:false +IndexReader_9 4673.00 root index:Projection_5 +└─Projection_5 4673.00 cop[tikv] with_cluster_index.tbl_2.col_14 + └─IndexRangeScan_8 4673.00 cop[tikv] table:tbl_2, index:idx_9(col_11) range:[-inf,2013-11-01 00:00:00), (2013-11-01 00:00:00,+inf], keep order:false explain select col_14 from wout_cluster_index.tbl_2 where col_11 <> '2013-11-01' ; id estRows task access object operator info -TableReader_14 4509.00 root data:Projection_5 -└─Projection_5 4509.00 cop[tikv] wout_cluster_index.tbl_2.col_14 - └─Selection_13 4509.00 cop[tikv] ne(wout_cluster_index.tbl_2.col_11, 2013-11-01 00:00:00.000000) +TableReader_14 4673.00 root data:Projection_5 +└─Projection_5 4673.00 cop[tikv] wout_cluster_index.tbl_2.col_14 + └─Selection_13 4673.00 cop[tikv] ne(wout_cluster_index.tbl_2.col_11, 2013-11-01 00:00:00.000000) └─TableFullScan_12 4673.00 cop[tikv] table:tbl_2 keep order:false explain select sum( col_4 ) from with_cluster_index.tbl_0 where col_3 != '1993-12-02' ; id estRows task access object operator info @@ -59,29 +59,29 @@ id estRows task access object operator info HashAgg_13 1.00 root funcs:sum(Column#8)->Column#7 └─TableReader_14 1.00 root data:HashAgg_6 └─HashAgg_6 1.00 cop[tikv] funcs:sum(wout_cluster_index.tbl_0.col_4)->Column#8 - └─Selection_12 2243.00 cop[tikv] ne(wout_cluster_index.tbl_0.col_3, 1993-12-02 00:00:00.000000) + └─Selection_12 2244.00 cop[tikv] ne(wout_cluster_index.tbl_0.col_3, 1993-12-02 00:00:00.000000) └─TableFullScan_11 2244.00 cop[tikv] table:tbl_0 keep order:false explain select col_0 from with_cluster_index.tbl_0 where col_0 <= 0 ; id estRows task access object operator info -IndexReader_6 1.00 root index:IndexRangeScan_5 -└─IndexRangeScan_5 1.00 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,0], keep order:false +IndexReader_6 1123.00 root index:IndexRangeScan_5 +└─IndexRangeScan_5 1123.00 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,0], keep order:false explain select col_0 from wout_cluster_index.tbl_0 where col_0 <= 0 ; id estRows task access object operator info -IndexReader_6 1.00 root index:IndexRangeScan_5 -└─IndexRangeScan_5 1.00 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,0], keep order:false +IndexReader_6 1123.00 root index:IndexRangeScan_5 +└─IndexRangeScan_5 1123.00 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,0], keep order:false explain select col_3 from with_cluster_index.tbl_0 where col_3 >= '1981-09-15' ; id estRows task access object operator info 
-IndexReader_8 1860.39 root index:IndexRangeScan_7 -└─IndexRangeScan_7 1860.39 cop[tikv] table:tbl_0, index:idx_1(col_3) range:[1981-09-15 00:00:00,+inf], keep order:false +IndexReader_8 2244.00 root index:IndexRangeScan_7 +└─IndexRangeScan_7 2244.00 cop[tikv] table:tbl_0, index:idx_1(col_3) range:[1981-09-15 00:00:00,+inf], keep order:false explain select col_3 from wout_cluster_index.tbl_0 where col_3 >= '1981-09-15' ; id estRows task access object operator info -IndexReader_8 1860.39 root index:IndexRangeScan_7 -└─IndexRangeScan_7 1860.39 cop[tikv] table:tbl_0, index:idx_1(col_3) range:[1981-09-15 00:00:00,+inf], keep order:false +IndexReader_8 2244.00 root index:IndexRangeScan_7 +└─IndexRangeScan_7 2244.00 cop[tikv] table:tbl_0, index:idx_1(col_3) range:[1981-09-15 00:00:00,+inf], keep order:false explain select tbl_2.col_14 , tbl_0.col_1 from with_cluster_index.tbl_2 right join with_cluster_index.tbl_0 on col_3 = col_11 ; id estRows task access object operator info MergeJoin_7 2533.51 root right outer join, left key:with_cluster_index.tbl_2.col_11, right key:with_cluster_index.tbl_0.col_3 -├─IndexReader_22(Build) 4509.00 root index:IndexFullScan_21 -│ └─IndexFullScan_21 4509.00 cop[tikv] table:tbl_2, index:idx_9(col_11) keep order:true +├─IndexReader_22(Build) 4673.00 root index:IndexFullScan_21 +│ └─IndexFullScan_21 4673.00 cop[tikv] table:tbl_2, index:idx_9(col_11) keep order:true └─TableReader_24(Probe) 2244.00 root data:TableFullScan_23 └─TableFullScan_23 2244.00 cop[tikv] table:tbl_0 keep order:true explain select tbl_2.col_14 , tbl_0.col_1 from wout_cluster_index.tbl_2 right join wout_cluster_index.tbl_0 on col_3 = col_11 ; @@ -89,31 +89,31 @@ id estRows task access object operator info HashJoin_22 2533.51 root right outer join, equal:[eq(wout_cluster_index.tbl_2.col_11, wout_cluster_index.tbl_0.col_3)] ├─TableReader_41(Build) 2244.00 root data:TableFullScan_40 │ └─TableFullScan_40 2244.00 cop[tikv] table:tbl_0 keep order:false -└─TableReader_44(Probe) 4509.00 root data:Selection_43 - └─Selection_43 4509.00 cop[tikv] not(isnull(wout_cluster_index.tbl_2.col_11)) +└─TableReader_44(Probe) 4673.00 root data:Selection_43 + └─Selection_43 4673.00 cop[tikv] not(isnull(wout_cluster_index.tbl_2.col_11)) └─TableFullScan_42 4673.00 cop[tikv] table:tbl_2 keep order:false explain select count(*) from with_cluster_index.tbl_0 where col_0 <= 0 ; id estRows task access object operator info -StreamAgg_16 1.00 root funcs:count(Column#8)->Column#6 -└─IndexReader_17 1.00 root index:StreamAgg_9 - └─StreamAgg_9 1.00 cop[tikv] funcs:count(1)->Column#8 - └─IndexRangeScan_11 1.00 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,0], keep order:false +HashAgg_12 1.00 root funcs:count(Column#7)->Column#6 +└─IndexReader_13 1.00 root index:HashAgg_6 + └─HashAgg_6 1.00 cop[tikv] funcs:count(1)->Column#7 + └─IndexRangeScan_11 1123.00 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,0], keep order:false explain select count(*) from wout_cluster_index.tbl_0 where col_0 <= 0 ; id estRows task access object operator info -StreamAgg_16 1.00 root funcs:count(Column#9)->Column#7 -└─IndexReader_17 1.00 root index:StreamAgg_9 - └─StreamAgg_9 1.00 cop[tikv] funcs:count(1)->Column#9 - └─IndexRangeScan_11 1.00 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,0], keep order:false +HashAgg_12 1.00 root funcs:count(Column#8)->Column#7 +└─IndexReader_13 1.00 root index:HashAgg_6 + └─HashAgg_6 1.00 cop[tikv] funcs:count(1)->Column#8 + └─IndexRangeScan_11 1123.00 cop[tikv] table:tbl_0, index:idx_3(col_0) 
range:[-inf,0], keep order:false explain select count(*) from with_cluster_index.tbl_0 where col_0 >= 803163 ; id estRows task access object operator info -StreamAgg_17 1.00 root funcs:count(Column#8)->Column#6 -└─IndexReader_18 1.00 root index:StreamAgg_9 - └─StreamAgg_9 1.00 cop[tikv] funcs:count(1)->Column#8 - └─IndexRangeScan_16 133.89 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[803163,+inf], keep order:false +HashAgg_12 1.00 root funcs:count(Column#7)->Column#6 +└─IndexReader_13 1.00 root index:HashAgg_6 + └─HashAgg_6 1.00 cop[tikv] funcs:count(1)->Column#7 + └─IndexRangeScan_11 1229.12 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[803163,+inf], keep order:false explain select count(*) from wout_cluster_index.tbl_0 where col_0 >= 803163 ; id estRows task access object operator info -StreamAgg_17 1.00 root funcs:count(Column#9)->Column#7 -└─IndexReader_18 1.00 root index:StreamAgg_9 - └─StreamAgg_9 1.00 cop[tikv] funcs:count(1)->Column#9 - └─IndexRangeScan_16 133.89 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[803163,+inf], keep order:false +HashAgg_12 1.00 root funcs:count(Column#8)->Column#7 +└─IndexReader_13 1.00 root index:HashAgg_6 + └─HashAgg_6 1.00 cop[tikv] funcs:count(1)->Column#8 + └─IndexRangeScan_11 1229.12 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[803163,+inf], keep order:false set @@tidb_enable_outer_join_reorder=false; diff --git a/tests/integrationtest/r/explain_complex_stats.result b/tests/integrationtest/r/explain_complex_stats.result index 9183896d7aefb..3579d7c463364 100644 --- a/tests/integrationtest/r/explain_complex_stats.result +++ b/tests/integrationtest/r/explain_complex_stats.result @@ -145,15 +145,15 @@ Projection 21.47 root explain_complex_stats.dt.ds, explain_complex_stats.dt.p1, └─TableRowIDScan 128.00 cop[tikv] table:dt keep order:false, stats:partial[cm:missing] explain format = 'brief' select gad.id as gid,sdk.id as sid,gad.aid as aid,gad.cm as cm,sdk.dic as dic,sdk.ip as ip, sdk.t as t, gad.p1 as p1, gad.p2 as p2, gad.p3 as p3, gad.p4 as p4, gad.p5 as p5, gad.p6_md5 as p6, gad.p7_md5 as p7, gad.ext as ext, gad.t as gtime from st gad join (select id, aid, pt, dic, ip, t from dd where pt = 'android' and bm = 0 and t > 1478143908) sdk on gad.aid = sdk.aid and gad.ip = sdk.ip and sdk.t > gad.t where gad.t > 1478143908 and gad.bm = 0 and gad.pt = 'android' group by gad.aid, sdk.dic limit 2500; id estRows task access object operator info -Projection 424.00 root explain_complex_stats.st.id, explain_complex_stats.dd.id, explain_complex_stats.st.aid, explain_complex_stats.st.cm, explain_complex_stats.dd.dic, explain_complex_stats.dd.ip, explain_complex_stats.dd.t, explain_complex_stats.st.p1, explain_complex_stats.st.p2, explain_complex_stats.st.p3, explain_complex_stats.st.p4, explain_complex_stats.st.p5, explain_complex_stats.st.p6_md5, explain_complex_stats.st.p7_md5, explain_complex_stats.st.ext, explain_complex_stats.st.t -└─Limit 424.00 root offset:0, count:2500 - └─HashAgg 424.00 root group by:explain_complex_stats.dd.dic, explain_complex_stats.st.aid, funcs:firstrow(explain_complex_stats.st.id)->explain_complex_stats.st.id, funcs:firstrow(explain_complex_stats.st.aid)->explain_complex_stats.st.aid, funcs:firstrow(explain_complex_stats.st.cm)->explain_complex_stats.st.cm, funcs:firstrow(explain_complex_stats.st.p1)->explain_complex_stats.st.p1, funcs:firstrow(explain_complex_stats.st.p2)->explain_complex_stats.st.p2, funcs:firstrow(explain_complex_stats.st.p3)->explain_complex_stats.st.p3, 
funcs:firstrow(explain_complex_stats.st.p4)->explain_complex_stats.st.p4, funcs:firstrow(explain_complex_stats.st.p5)->explain_complex_stats.st.p5, funcs:firstrow(explain_complex_stats.st.p6_md5)->explain_complex_stats.st.p6_md5, funcs:firstrow(explain_complex_stats.st.p7_md5)->explain_complex_stats.st.p7_md5, funcs:firstrow(explain_complex_stats.st.ext)->explain_complex_stats.st.ext, funcs:firstrow(explain_complex_stats.st.t)->explain_complex_stats.st.t, funcs:firstrow(explain_complex_stats.dd.id)->explain_complex_stats.dd.id, funcs:firstrow(explain_complex_stats.dd.dic)->explain_complex_stats.dd.dic, funcs:firstrow(explain_complex_stats.dd.ip)->explain_complex_stats.dd.ip, funcs:firstrow(explain_complex_stats.dd.t)->explain_complex_stats.dd.t - └─HashJoin 424.00 root inner join, equal:[eq(explain_complex_stats.st.aid, explain_complex_stats.dd.aid) eq(explain_complex_stats.st.ip, explain_complex_stats.dd.ip)], other cond:gt(explain_complex_stats.dd.t, explain_complex_stats.st.t) - ├─TableReader(Build) 424.00 root data:Selection - │ └─Selection 424.00 cop[tikv] eq(explain_complex_stats.st.bm, 0), eq(explain_complex_stats.st.pt, "android"), gt(explain_complex_stats.st.t, 1478143908), not(isnull(explain_complex_stats.st.ip)) +Projection 488.80 root explain_complex_stats.st.id, explain_complex_stats.dd.id, explain_complex_stats.st.aid, explain_complex_stats.st.cm, explain_complex_stats.dd.dic, explain_complex_stats.dd.ip, explain_complex_stats.dd.t, explain_complex_stats.st.p1, explain_complex_stats.st.p2, explain_complex_stats.st.p3, explain_complex_stats.st.p4, explain_complex_stats.st.p5, explain_complex_stats.st.p6_md5, explain_complex_stats.st.p7_md5, explain_complex_stats.st.ext, explain_complex_stats.st.t +└─Limit 488.80 root offset:0, count:2500 + └─HashAgg 488.80 root group by:explain_complex_stats.dd.dic, explain_complex_stats.st.aid, funcs:firstrow(explain_complex_stats.st.id)->explain_complex_stats.st.id, funcs:firstrow(explain_complex_stats.st.aid)->explain_complex_stats.st.aid, funcs:firstrow(explain_complex_stats.st.cm)->explain_complex_stats.st.cm, funcs:firstrow(explain_complex_stats.st.p1)->explain_complex_stats.st.p1, funcs:firstrow(explain_complex_stats.st.p2)->explain_complex_stats.st.p2, funcs:firstrow(explain_complex_stats.st.p3)->explain_complex_stats.st.p3, funcs:firstrow(explain_complex_stats.st.p4)->explain_complex_stats.st.p4, funcs:firstrow(explain_complex_stats.st.p5)->explain_complex_stats.st.p5, funcs:firstrow(explain_complex_stats.st.p6_md5)->explain_complex_stats.st.p6_md5, funcs:firstrow(explain_complex_stats.st.p7_md5)->explain_complex_stats.st.p7_md5, funcs:firstrow(explain_complex_stats.st.ext)->explain_complex_stats.st.ext, funcs:firstrow(explain_complex_stats.st.t)->explain_complex_stats.st.t, funcs:firstrow(explain_complex_stats.dd.id)->explain_complex_stats.dd.id, funcs:firstrow(explain_complex_stats.dd.dic)->explain_complex_stats.dd.dic, funcs:firstrow(explain_complex_stats.dd.ip)->explain_complex_stats.dd.ip, funcs:firstrow(explain_complex_stats.dd.t)->explain_complex_stats.dd.t + └─HashJoin 488.80 root inner join, equal:[eq(explain_complex_stats.st.aid, explain_complex_stats.dd.aid) eq(explain_complex_stats.st.ip, explain_complex_stats.dd.ip)], other cond:gt(explain_complex_stats.dd.t, explain_complex_stats.st.t) + ├─TableReader(Build) 488.80 root data:Selection + │ └─Selection 488.80 cop[tikv] eq(explain_complex_stats.st.bm, 0), eq(explain_complex_stats.st.pt, "android"), gt(explain_complex_stats.st.t, 1478143908), 
not(isnull(explain_complex_stats.st.ip)) │ └─TableFullScan 1999.00 cop[tikv] table:gad keep order:false, stats:partial[t:missing] - └─TableReader(Probe) 450.56 root data:Selection - └─Selection 450.56 cop[tikv] eq(explain_complex_stats.dd.bm, 0), eq(explain_complex_stats.dd.pt, "android"), gt(explain_complex_stats.dd.t, 1478143908), not(isnull(explain_complex_stats.dd.ip)), not(isnull(explain_complex_stats.dd.t)) + └─TableReader(Probe) 501.17 root data:Selection + └─Selection 501.17 cop[tikv] eq(explain_complex_stats.dd.bm, 0), eq(explain_complex_stats.dd.pt, "android"), gt(explain_complex_stats.dd.t, 1478143908), not(isnull(explain_complex_stats.dd.ip)), not(isnull(explain_complex_stats.dd.t)) └─TableFullScan 2000.00 cop[tikv] table:dd keep order:false, stats:partial[ip:missing, t:missing] explain format = 'brief' select gad.id as gid,sdk.id as sid,gad.aid as aid,gad.cm as cm,sdk.dic as dic,sdk.ip as ip, sdk.t as t, gad.p1 as p1, gad.p2 as p2, gad.p3 as p3, gad.p4 as p4, gad.p5 as p5, gad.p6_md5 as p6, gad.p7_md5 as p7, gad.ext as ext from st gad join dd sdk on gad.aid = sdk.aid and gad.dic = sdk.mac and gad.t < sdk.t where gad.t > 1477971479 and gad.bm = 0 and gad.pt = 'ios' and gad.dit = 'mac' and sdk.t > 1477971479 and sdk.bm = 0 and sdk.pt = 'ios' limit 3000; id estRows task access object operator info @@ -177,16 +177,15 @@ Projection 39.28 root explain_complex_stats.st.cm, explain_complex_stats.st.p1, └─TableRowIDScan 160.23 cop[tikv] table:st keep order:false, stats:partial[t:missing] explain format = 'brief' select dt.id as id, dt.aid as aid, dt.pt as pt, dt.dic as dic, dt.cm as cm, rr.gid as gid, rr.acd as acd, rr.t as t,dt.p1 as p1, dt.p2 as p2, dt.p3 as p3, dt.p4 as p4, dt.p5 as p5, dt.p6_md5 as p6, dt.p7_md5 as p7 from dt dt join rr rr on (rr.pt = 'ios' and rr.t > 1478185592 and dt.aid = rr.aid and dt.dic = rr.dic) where dt.pt = 'ios' and dt.t > 1478185592 and dt.bm = 0 limit 2000; id estRows task access object operator info -Projection 428.32 root explain_complex_stats.dt.id, explain_complex_stats.dt.aid, explain_complex_stats.dt.pt, explain_complex_stats.dt.dic, explain_complex_stats.dt.cm, explain_complex_stats.rr.gid, explain_complex_stats.rr.acd, explain_complex_stats.rr.t, explain_complex_stats.dt.p1, explain_complex_stats.dt.p2, explain_complex_stats.dt.p3, explain_complex_stats.dt.p4, explain_complex_stats.dt.p5, explain_complex_stats.dt.p6_md5, explain_complex_stats.dt.p7_md5 -└─Limit 428.32 root offset:0, count:2000 - └─IndexJoin 428.32 root inner join, inner:IndexLookUp, outer key:explain_complex_stats.dt.aid, explain_complex_stats.dt.dic, inner key:explain_complex_stats.rr.aid, explain_complex_stats.rr.dic, equal cond:eq(explain_complex_stats.dt.aid, explain_complex_stats.rr.aid), eq(explain_complex_stats.dt.dic, explain_complex_stats.rr.dic) - ├─TableReader(Build) 428.32 root data:Selection - │ └─Selection 428.32 cop[tikv] eq(explain_complex_stats.dt.bm, 0), eq(explain_complex_stats.dt.pt, "ios"), gt(explain_complex_stats.dt.t, 1478185592), not(isnull(explain_complex_stats.dt.dic)) +Projection 476.61 root explain_complex_stats.dt.id, explain_complex_stats.dt.aid, explain_complex_stats.dt.pt, explain_complex_stats.dt.dic, explain_complex_stats.dt.cm, explain_complex_stats.rr.gid, explain_complex_stats.rr.acd, explain_complex_stats.rr.t, explain_complex_stats.dt.p1, explain_complex_stats.dt.p2, explain_complex_stats.dt.p3, explain_complex_stats.dt.p4, explain_complex_stats.dt.p5, explain_complex_stats.dt.p6_md5, explain_complex_stats.dt.p7_md5 +└─Limit 476.61 root 
offset:0, count:2000 + └─HashJoin 476.61 root inner join, equal:[eq(explain_complex_stats.dt.aid, explain_complex_stats.rr.aid) eq(explain_complex_stats.dt.dic, explain_complex_stats.rr.dic)] + ├─TableReader(Build) 476.61 root data:Selection + │ └─Selection 476.61 cop[tikv] eq(explain_complex_stats.dt.bm, 0), eq(explain_complex_stats.dt.pt, "ios"), gt(explain_complex_stats.dt.t, 1478185592), not(isnull(explain_complex_stats.dt.dic)) │ └─TableFullScan 2000.00 cop[tikv] table:dt keep order:false - └─IndexLookUp(Probe) 428.32 root - ├─IndexRangeScan(Build) 428.32 cop[tikv] table:rr, index:PRIMARY(aid, dic) range: decided by [eq(explain_complex_stats.rr.aid, explain_complex_stats.dt.aid) eq(explain_complex_stats.rr.dic, explain_complex_stats.dt.dic)], keep order:false - └─Selection(Probe) 428.32 cop[tikv] eq(explain_complex_stats.rr.pt, "ios"), gt(explain_complex_stats.rr.t, 1478185592) - └─TableRowIDScan 428.32 cop[tikv] table:rr keep order:false + └─TableReader(Probe) 970.00 root data:Selection + └─Selection 970.00 cop[tikv] eq(explain_complex_stats.rr.pt, "ios"), gt(explain_complex_stats.rr.t, 1478185592) + └─TableFullScan 2000.00 cop[tikv] table:rr keep order:false explain format = 'brief' select pc,cr,count(DISTINCT uid) as pay_users,count(oid) as pay_times,sum(am) as am from pp where ps=2 and ppt>=1478188800 and ppt<1478275200 and pi in ('510017','520017') and uid in ('18089709','18090780') group by pc,cr; id estRows task access object operator info Projection 207.02 root explain_complex_stats.pp.pc, explain_complex_stats.pp.cr, Column#22, Column#23, Column#24 diff --git a/tests/integrationtest/r/explain_easy.result b/tests/integrationtest/r/explain_easy.result index f5a81937ce990..c9e5771dfc0c6 100644 --- a/tests/integrationtest/r/explain_easy.result +++ b/tests/integrationtest/r/explain_easy.result @@ -739,8 +739,8 @@ insert into t values (1),(2),(2),(2),(9),(9),(9),(10); analyze table t all columns with 1 buckets; explain format = 'brief' select * from t where a >= 3 and a <= 8; id estRows task access object operator info -TableReader 0.00 root data:Selection -└─Selection 0.00 cop[tikv] ge(explain_easy.t.a, 3), le(explain_easy.t.a, 8) +TableReader 1.00 root data:Selection +└─Selection 1.00 cop[tikv] ge(explain_easy.t.a, 3), le(explain_easy.t.a, 8) └─TableFullScan 8.00 cop[tikv] table:t keep order:false drop table t; create table t(a int, b int, index idx_ab(a, b)); diff --git a/tests/integrationtest/r/explain_easy_stats.result b/tests/integrationtest/r/explain_easy_stats.result index 8c26bcf04b8be..7be6d7c22aa2e 100644 --- a/tests/integrationtest/r/explain_easy_stats.result +++ b/tests/integrationtest/r/explain_easy_stats.result @@ -51,8 +51,8 @@ HashJoin 2481.25 root left outer join, equal:[eq(explain_easy_stats.t1.c2, expl ├─TableReader(Build) 1985.00 root data:Selection │ └─Selection 1985.00 cop[tikv] not(isnull(explain_easy_stats.t2.c1)) │ └─TableFullScan 1985.00 cop[tikv] table:t2 keep order:false, stats:partial[c1:missing] -└─TableReader(Probe) 1998.00 root data:TableRangeScan - └─TableRangeScan 1998.00 cop[tikv] table:t1 range:(1,+inf], keep order:false +└─TableReader(Probe) 1999.00 root data:TableRangeScan + └─TableRangeScan 1999.00 cop[tikv] table:t1 range:(1,+inf], keep order:false explain format = 'brief' update t1 set t1.c2 = 2 where t1.c1 = 1; id estRows task access object operator info Update N/A root N/A @@ -87,7 +87,7 @@ IndexLookUp 0.00 root └─TableRowIDScan 0.00 cop[tikv] table:t1 keep order:false, stats:partial[c2:missing] explain format = 'brief' select * from t1 
where c1 = 1 and c2 > 1; id estRows task access object operator info -Selection 0.50 root gt(explain_easy_stats.t1.c2, 1) +Selection 1.00 root gt(explain_easy_stats.t1.c2, 1) └─Point_Get 1.00 root table:t1 handle:1 explain format = 'brief' select c1 from t1 where c1 in (select c2 from t2); id estRows task access object operator info diff --git a/tests/integrationtest/r/explain_generate_column_substitute.result b/tests/integrationtest/r/explain_generate_column_substitute.result index 53a02de94d8d5..6732ab8473fa0 100644 --- a/tests/integrationtest/r/explain_generate_column_substitute.result +++ b/tests/integrationtest/r/explain_generate_column_substitute.result @@ -413,10 +413,10 @@ Projection 1.00 root explain_generate_column_substitute.t.a, explain_generate_c └─TableRowIDScan(Probe) 1.00 cop[tikv] table:t keep order:false desc format = 'brief' select * from t where not (lower(b) >= "a"); id estRows task access object operator info -Projection 0.00 root explain_generate_column_substitute.t.a, explain_generate_column_substitute.t.b -└─IndexLookUp 0.00 root - ├─IndexRangeScan(Build) 0.00 cop[tikv] table:t, index:expression_index(lower(`b`), `a` + 1) range:[-inf,"a"), keep order:false - └─TableRowIDScan(Probe) 0.00 cop[tikv] table:t keep order:false +Projection 1.00 root explain_generate_column_substitute.t.a, explain_generate_column_substitute.t.b +└─IndexLookUp 1.00 root + ├─IndexRangeScan(Build) 1.00 cop[tikv] table:t, index:expression_index(lower(`b`), `a` + 1) range:[-inf,"a"), keep order:false + └─TableRowIDScan(Probe) 1.00 cop[tikv] table:t keep order:false desc format = 'brief' select count(upper(b)) from t group by upper(b); id estRows task access object operator info StreamAgg 4.80 root group by:upper(explain_generate_column_substitute.t.b), funcs:count(upper(explain_generate_column_substitute.t.b))->Column#7 diff --git a/tests/integrationtest/r/explain_indexmerge_stats.result b/tests/integrationtest/r/explain_indexmerge_stats.result index bd27123d6a92f..1d7f41261838b 100644 --- a/tests/integrationtest/r/explain_indexmerge_stats.result +++ b/tests/integrationtest/r/explain_indexmerge_stats.result @@ -7,137 +7,111 @@ create index td on t (d); load stats 's/explain_indexmerge_stats_t.json'; explain format = 'brief' select * from t where a < 50 or b < 50; id estRows task access object operator info -IndexMerge 98.00 root type: union -├─TableRangeScan(Build) 49.00 cop[tikv] table:t range:[-inf,50), keep order:false -├─IndexRangeScan(Build) 49.00 cop[tikv] table:t, index:tb(b) range:[-inf,50), keep order:false -└─TableRowIDScan(Probe) 98.00 cop[tikv] table:t keep order:false +TableReader 3750049.00 root data:Selection +└─Selection 3750049.00 cop[tikv] or(lt(explain_indexmerge_stats.t.a, 50), lt(explain_indexmerge_stats.t.b, 50)) + └─TableFullScan 5000000.00 cop[tikv] table:t keep order:false explain format = 'brief' select * from t where (a < 50 or b < 50) and f > 100; id estRows task access object operator info -IndexMerge 98.00 root type: union -├─TableRangeScan(Build) 49.00 cop[tikv] table:t range:[-inf,50), keep order:false -├─IndexRangeScan(Build) 49.00 cop[tikv] table:t, index:tb(b) range:[-inf,50), keep order:false -└─Selection(Probe) 98.00 cop[tikv] gt(explain_indexmerge_stats.t.f, 100) - └─TableRowIDScan 98.00 cop[tikv] table:t keep order:false +TableReader 3750049.00 root data:Selection +└─Selection 3750049.00 cop[tikv] gt(explain_indexmerge_stats.t.f, 100), or(lt(explain_indexmerge_stats.t.a, 50), lt(explain_indexmerge_stats.t.b, 50)) + └─TableFullScan 5000000.00 cop[tikv] 
table:t keep order:false explain format = 'brief' select * from t where b < 50 or c < 50; id estRows task access object operator info -IndexMerge 98.00 root type: union -├─IndexRangeScan(Build) 49.00 cop[tikv] table:t, index:tb(b) range:[-inf,50), keep order:false -├─IndexRangeScan(Build) 49.00 cop[tikv] table:t, index:tc(c) range:[-inf,50), keep order:false -└─TableRowIDScan(Probe) 98.00 cop[tikv] table:t keep order:false +TableReader 3750049.00 root data:Selection +└─Selection 3750049.00 cop[tikv] or(lt(explain_indexmerge_stats.t.b, 50), lt(explain_indexmerge_stats.t.c, 50)) + └─TableFullScan 5000000.00 cop[tikv] table:t keep order:false set session tidb_enable_index_merge = on; explain format = 'brief' select * from t where a < 50 or b < 50; id estRows task access object operator info -IndexMerge 98.00 root type: union -├─TableRangeScan(Build) 49.00 cop[tikv] table:t range:[-inf,50), keep order:false -├─IndexRangeScan(Build) 49.00 cop[tikv] table:t, index:tb(b) range:[-inf,50), keep order:false -└─TableRowIDScan(Probe) 98.00 cop[tikv] table:t keep order:false +TableReader 3750049.00 root data:Selection +└─Selection 3750049.00 cop[tikv] or(lt(explain_indexmerge_stats.t.a, 50), lt(explain_indexmerge_stats.t.b, 50)) + └─TableFullScan 5000000.00 cop[tikv] table:t keep order:false explain format = 'brief' select * from t where (a < 50 or b < 50) and f > 100; id estRows task access object operator info -IndexMerge 98.00 root type: union -├─TableRangeScan(Build) 49.00 cop[tikv] table:t range:[-inf,50), keep order:false -├─IndexRangeScan(Build) 49.00 cop[tikv] table:t, index:tb(b) range:[-inf,50), keep order:false -└─Selection(Probe) 98.00 cop[tikv] gt(explain_indexmerge_stats.t.f, 100) - └─TableRowIDScan 98.00 cop[tikv] table:t keep order:false +TableReader 3750049.00 root data:Selection +└─Selection 3750049.00 cop[tikv] gt(explain_indexmerge_stats.t.f, 100), or(lt(explain_indexmerge_stats.t.a, 50), lt(explain_indexmerge_stats.t.b, 50)) + └─TableFullScan 5000000.00 cop[tikv] table:t keep order:false explain format = 'brief' select * from t where a < 50 or b < 5000000; id estRows task access object operator info -TableReader 4999999.00 root data:Selection -└─Selection 4999999.00 cop[tikv] or(lt(explain_indexmerge_stats.t.a, 50), lt(explain_indexmerge_stats.t.b, 5000000)) +TableReader 5000000.00 root data:Selection +└─Selection 5000000.00 cop[tikv] or(lt(explain_indexmerge_stats.t.a, 50), lt(explain_indexmerge_stats.t.b, 5000000)) └─TableFullScan 5000000.00 cop[tikv] table:t keep order:false explain format = 'brief' select * from t where b < 50 or c < 50; id estRows task access object operator info -IndexMerge 98.00 root type: union -├─IndexRangeScan(Build) 49.00 cop[tikv] table:t, index:tb(b) range:[-inf,50), keep order:false -├─IndexRangeScan(Build) 49.00 cop[tikv] table:t, index:tc(c) range:[-inf,50), keep order:false -└─TableRowIDScan(Probe) 98.00 cop[tikv] table:t keep order:false +TableReader 3750049.00 root data:Selection +└─Selection 3750049.00 cop[tikv] or(lt(explain_indexmerge_stats.t.b, 50), lt(explain_indexmerge_stats.t.c, 50)) + └─TableFullScan 5000000.00 cop[tikv] table:t keep order:false explain format = 'brief' select * from t where b < 50 or c < 5000000; id estRows task access object operator info -TableReader 4999999.00 root data:Selection -└─Selection 4999999.00 cop[tikv] or(lt(explain_indexmerge_stats.t.b, 50), lt(explain_indexmerge_stats.t.c, 5000000)) +TableReader 5000000.00 root data:Selection +└─Selection 5000000.00 cop[tikv] or(lt(explain_indexmerge_stats.t.b, 50), 
lt(explain_indexmerge_stats.t.c, 5000000)) └─TableFullScan 5000000.00 cop[tikv] table:t keep order:false explain format = 'brief' select * from t where a < 50 or b < 50 or c < 50; id estRows task access object operator info -IndexMerge 147.00 root type: union -├─TableRangeScan(Build) 49.00 cop[tikv] table:t range:[-inf,50), keep order:false -├─IndexRangeScan(Build) 49.00 cop[tikv] table:t, index:tb(b) range:[-inf,50), keep order:false -├─IndexRangeScan(Build) 49.00 cop[tikv] table:t, index:tc(c) range:[-inf,50), keep order:false -└─TableRowIDScan(Probe) 147.00 cop[tikv] table:t keep order:false +TableReader 4375036.75 root data:Selection +└─Selection 4375036.75 cop[tikv] or(lt(explain_indexmerge_stats.t.a, 50), or(lt(explain_indexmerge_stats.t.b, 50), lt(explain_indexmerge_stats.t.c, 50))) + └─TableFullScan 5000000.00 cop[tikv] table:t keep order:false explain format = 'brief' select * from t where (b < 10000 or c < 10000) and (a < 10 or d < 10) and f < 10; id estRows task access object operator info -IndexMerge 0.00 root type: union -├─TableRangeScan(Build) 9.00 cop[tikv] table:t range:[-inf,10), keep order:false -├─IndexRangeScan(Build) 9.00 cop[tikv] table:t, index:td(d) range:[-inf,10), keep order:false -└─Selection(Probe) 0.00 cop[tikv] lt(explain_indexmerge_stats.t.f, 10), or(lt(explain_indexmerge_stats.t.b, 10000), lt(explain_indexmerge_stats.t.c, 10000)) - └─TableRowIDScan 18.00 cop[tikv] table:t keep order:false +TableReader 1409331.42 root data:Selection +└─Selection 1409331.42 cop[tikv] lt(explain_indexmerge_stats.t.f, 10), or(lt(explain_indexmerge_stats.t.a, 10), lt(explain_indexmerge_stats.t.d, 10)), or(lt(explain_indexmerge_stats.t.b, 10000), lt(explain_indexmerge_stats.t.c, 10000)) + └─TableFullScan 5000000.00 cop[tikv] table:t keep order:false explain format="dot" select * from t where (a < 50 or b < 50) and f > 100; dot contents -digraph IndexMerge_12 { -subgraph cluster12{ +digraph TableReader_7 { +subgraph cluster7{ node [style=filled, color=lightgrey] color=black label = "root" -"IndexMerge_12" -} -subgraph cluster8{ -node [style=filled, color=lightgrey] -color=black -label = "cop" -"TableRangeScan_8" -} -subgraph cluster9{ -node [style=filled, color=lightgrey] -color=black -label = "cop" -"IndexRangeScan_9" +"TableReader_7" } -subgraph cluster11{ +subgraph cluster6{ node [style=filled, color=lightgrey] color=black label = "cop" -"Selection_11" -> "TableRowIDScan_10" +"Selection_6" -> "TableFullScan_5" } -"IndexMerge_12" -> "TableRangeScan_8" -"IndexMerge_12" -> "IndexRangeScan_9" -"IndexMerge_12" -> "Selection_11" +"TableReader_7" -> "Selection_6" } set session tidb_enable_index_merge = off; explain format = 'brief' select /*+ use_index_merge(t, primary, tb, tc) */ * from t where a <= 500000 or b <= 1000000 or c <= 3000000; id estRows task access object operator info -IndexMerge 3560000.00 root type: union -├─TableRangeScan(Build) 500000.00 cop[tikv] table:t range:[-inf,500000], keep order:false -├─IndexRangeScan(Build) 1000000.00 cop[tikv] table:t, index:tb(b) range:[-inf,1000000], keep order:false -├─IndexRangeScan(Build) 3000000.00 cop[tikv] table:t, index:tc(c) range:[-inf,3000000], keep order:false -└─TableRowIDScan(Probe) 3560000.00 cop[tikv] table:t keep order:false +IndexMerge 5000000.00 root type: union +├─TableRangeScan(Build) 3000000.00 cop[tikv] table:t range:[-inf,500000], keep order:false +├─IndexRangeScan(Build) 3500000.00 cop[tikv] table:t, index:tb(b) range:[-inf,1000000], keep order:false +├─IndexRangeScan(Build) 5000000.00 cop[tikv] table:t, 
index:tc(c) range:[-inf,3000000], keep order:false +└─TableRowIDScan(Probe) 5000000.00 cop[tikv] table:t keep order:false explain format = 'brief' select /*+ use_index_merge(t, tb, tc) */ * from t where b < 50 or c < 5000000; id estRows task access object operator info -IndexMerge 4999999.00 root type: union -├─IndexRangeScan(Build) 49.00 cop[tikv] table:t, index:tb(b) range:[-inf,50), keep order:false -├─IndexRangeScan(Build) 4999999.00 cop[tikv] table:t, index:tc(c) range:[-inf,5000000), keep order:false -└─TableRowIDScan(Probe) 4999999.00 cop[tikv] table:t keep order:false +IndexMerge 5000000.00 root type: union +├─IndexRangeScan(Build) 2500049.00 cop[tikv] table:t, index:tb(b) range:[-inf,50), keep order:false +├─IndexRangeScan(Build) 5000000.00 cop[tikv] table:t, index:tc(c) range:[-inf,5000000), keep order:false +└─TableRowIDScan(Probe) 5000000.00 cop[tikv] table:t keep order:false explain format = 'brief' select /*+ use_index_merge(t, tb, tc) */ * from t where (b < 10000 or c < 10000) and (a < 10 or d < 10) and f < 10; id estRows task access object operator info -IndexMerge 0.00 root type: union -├─IndexRangeScan(Build) 9999.00 cop[tikv] table:t, index:tb(b) range:[-inf,10000), keep order:false -├─IndexRangeScan(Build) 9999.00 cop[tikv] table:t, index:tc(c) range:[-inf,10000), keep order:false -└─Selection(Probe) 0.00 cop[tikv] lt(explain_indexmerge_stats.t.f, 10), or(lt(explain_indexmerge_stats.t.a, 10), lt(explain_indexmerge_stats.t.d, 10)) - └─TableRowIDScan 19978.00 cop[tikv] table:t keep order:false +IndexMerge 1409331.42 root type: union +├─IndexRangeScan(Build) 2509999.00 cop[tikv] table:t, index:tb(b) range:[-inf,10000), keep order:false +├─IndexRangeScan(Build) 2509999.00 cop[tikv] table:t, index:tc(c) range:[-inf,10000), keep order:false +└─Selection(Probe) 1409331.42 cop[tikv] lt(explain_indexmerge_stats.t.f, 10), or(lt(explain_indexmerge_stats.t.a, 10), lt(explain_indexmerge_stats.t.d, 10)) + └─TableRowIDScan 3759979.00 cop[tikv] table:t keep order:false explain format = 'brief' select /*+ use_index_merge(t, tb) */ * from t where b < 50 or c < 5000000; id estRows task access object operator info -TableReader 4999999.00 root data:Selection -└─Selection 4999999.00 cop[tikv] or(lt(explain_indexmerge_stats.t.b, 50), lt(explain_indexmerge_stats.t.c, 5000000)) +TableReader 5000000.00 root data:Selection +└─Selection 5000000.00 cop[tikv] or(lt(explain_indexmerge_stats.t.b, 50), lt(explain_indexmerge_stats.t.c, 5000000)) └─TableFullScan 5000000.00 cop[tikv] table:t keep order:false explain format = 'brief' select /*+ no_index_merge(), use_index_merge(t, tb, tc) */ * from t where b < 50 or c < 5000000; id estRows task access object operator info -TableReader 4999999.00 root data:Selection -└─Selection 4999999.00 cop[tikv] or(lt(explain_indexmerge_stats.t.b, 50), lt(explain_indexmerge_stats.t.c, 5000000)) +TableReader 5000000.00 root data:Selection +└─Selection 5000000.00 cop[tikv] or(lt(explain_indexmerge_stats.t.b, 50), lt(explain_indexmerge_stats.t.c, 5000000)) └─TableFullScan 5000000.00 cop[tikv] table:t keep order:false explain format = 'brief' select /*+ use_index_merge(t, primary, tb) */ * from t where a < 50 or b < 5000000; id estRows task access object operator info -IndexMerge 4999999.00 root type: union -├─TableRangeScan(Build) 49.00 cop[tikv] table:t range:[-inf,50), keep order:false -├─IndexRangeScan(Build) 4999999.00 cop[tikv] table:t, index:tb(b) range:[-inf,5000000), keep order:false -└─TableRowIDScan(Probe) 4999999.00 cop[tikv] table:t keep order:false +IndexMerge 
5000000.00 root type: union +├─TableRangeScan(Build) 2500049.00 cop[tikv] table:t range:[-inf,50), keep order:false +├─IndexRangeScan(Build) 5000000.00 cop[tikv] table:t, index:tb(b) range:[-inf,5000000), keep order:false +└─TableRowIDScan(Probe) 5000000.00 cop[tikv] table:t keep order:false set session tidb_enable_index_merge = on; drop table if exists t; CREATE TABLE t ( diff --git a/tests/integrationtest/r/imdbload.result b/tests/integrationtest/r/imdbload.result index 7e5914344f1bf..53df922885088 100644 --- a/tests/integrationtest/r/imdbload.result +++ b/tests/integrationtest/r/imdbload.result @@ -286,48 +286,48 @@ IndexLookUp_7 1005030.94 root └─TableRowIDScan_6(Probe) 1005030.94 cop[tikv] table:char_name keep order:false trace plan target = 'estimation' select * from char_name where ((imdb_index = 'I') and (surname_pcode < 'E436')) or ((imdb_index = 'L') and (surname_pcode < 'E436')); CE_trace -[{"table_name":"char_name","type":"Column Stats-Point","expr":"((imdb_index = 'I'))","row_count":0},{"table_name":"char_name","type":"Column Stats-Point","expr":"((imdb_index = 'L'))","row_count":0},{"table_name":"char_name","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":4314864},{"table_name":"char_name","type":"Column Stats-Range","expr":"((surname_pcode < 'E436'))","row_count":1005030},{"table_name":"char_name","type":"Index Stats-Range","expr":"((imdb_index = 'I') and (surname_pcode < 'E436')) or ((imdb_index = 'L') and (surname_pcode < 'E436'))","row_count":0},{"table_name":"char_name","type":"Index Stats-Range","expr":"((surname_pcode < 'E436'))","row_count":1005030},{"table_name":"char_name","type":"Table Stats-Expression-CNF","expr":"`or`(`and`(`eq`(imdbload.char_name.imdb_index, 'I'), `lt`(imdbload.char_name.surname_pcode, 'E436')), `and`(`eq`(imdbload.char_name.imdb_index, 'L'), `lt`(imdbload.char_name.surname_pcode, 'E436')))","row_count":804024}] +[{"table_name":"char_name","type":"Column Stats-Point","expr":"((imdb_index = 'I'))","row_count":1},{"table_name":"char_name","type":"Column Stats-Point","expr":"((imdb_index = 'L'))","row_count":1},{"table_name":"char_name","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":4314864},{"table_name":"char_name","type":"Column Stats-Range","expr":"((surname_pcode < 'E436'))","row_count":1780915},{"table_name":"char_name","type":"Index Stats-Range","expr":"((imdb_index = 'I') and (surname_pcode < 'E436')) or ((imdb_index = 'L') and (surname_pcode < 'E436'))","row_count":2},{"table_name":"char_name","type":"Index Stats-Range","expr":"((surname_pcode < 'E436'))","row_count":1005030},{"table_name":"char_name","type":"Table Stats-Expression-CNF","expr":"`or`(`and`(`eq`(imdbload.char_name.imdb_index, 'I'), `lt`(imdbload.char_name.surname_pcode, 'E436')), `and`(`eq`(imdbload.char_name.imdb_index, 'L'), `lt`(imdbload.char_name.surname_pcode, 'E436')))","row_count":804024}] explain select * from char_name where ((imdb_index = 'V') and (surname_pcode < 'L3416')); id estRows task access object operator info -IndexLookUp_10 0.00 root -├─IndexRangeScan_8(Build) 0.00 cop[tikv] table:char_name, index:itest2(imdb_index, surname_pcode, name_pcode_nf) range:["V" -inf,"V" "L3416"), keep order:false -└─TableRowIDScan_9(Probe) 0.00 cop[tikv] table:char_name keep order:false +IndexLookUp_10 1.00 root +├─IndexRangeScan_8(Build) 1.00 cop[tikv] table:char_name, index:itest2(imdb_index, surname_pcode, name_pcode_nf) range:["V" -inf,"V" "L3416"), 
keep order:false +└─TableRowIDScan_9(Probe) 1.00 cop[tikv] table:char_name keep order:false explain select * from char_name where imdb_index > 'V'; id estRows task access object operator info -IndexLookUp_10 0.00 root -├─IndexRangeScan_8(Build) 0.00 cop[tikv] table:char_name, index:itest2(imdb_index, surname_pcode, name_pcode_nf) range:("V",+inf], keep order:false -└─TableRowIDScan_9(Probe) 0.00 cop[tikv] table:char_name keep order:false +IndexLookUp_10 1.00 root +├─IndexRangeScan_8(Build) 1.00 cop[tikv] table:char_name, index:itest2(imdb_index, surname_pcode, name_pcode_nf) range:("V",+inf], keep order:false +└─TableRowIDScan_9(Probe) 1.00 cop[tikv] table:char_name keep order:false trace plan target = 'estimation' select * from char_name where imdb_index > 'V'; CE_trace -[{"table_name":"char_name","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":4314864},{"table_name":"char_name","type":"Column Stats-Range","expr":"((imdb_index > 'V' and true))","row_count":0},{"table_name":"char_name","type":"Index Stats-Range","expr":"((imdb_index > 'V' and true))","row_count":0},{"table_name":"char_name","type":"Table Stats-Expression-CNF","expr":"`gt`(imdbload.char_name.imdb_index, 'V')","row_count":0}] +[{"table_name":"char_name","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":4314864},{"table_name":"char_name","type":"Column Stats-Range","expr":"((imdb_index > 'V' and true))","row_count":1},{"table_name":"char_name","type":"Index Stats-Range","expr":"((imdb_index > 'V' and true))","row_count":1},{"table_name":"char_name","type":"Table Stats-Expression-CNF","expr":"`gt`(imdbload.char_name.imdb_index, 'V')","row_count":1}] explain select * from movie_companies where company_type_id > 2; id estRows task access object operator info -IndexLookUp_10 0.00 root -├─IndexRangeScan_8(Build) 0.00 cop[tikv] table:movie_companies, index:movie_companies_idx_ctypeid(company_type_id) range:(2,+inf], keep order:false -└─TableRowIDScan_9(Probe) 0.00 cop[tikv] table:movie_companies keep order:false +IndexLookUp_10 1.00 root +├─IndexRangeScan_8(Build) 1.00 cop[tikv] table:movie_companies, index:movie_companies_idx_ctypeid(company_type_id) range:(2,+inf], keep order:false +└─TableRowIDScan_9(Probe) 1.00 cop[tikv] table:movie_companies keep order:false trace plan target = 'estimation' select * from movie_companies where company_type_id > 2; CE_trace -[{"table_name":"movie_companies","type":"Column Stats-Range","expr":"((company_type_id > 2 and true))","row_count":0},{"table_name":"movie_companies","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":4958296},{"table_name":"movie_companies","type":"Index Stats-Range","expr":"((company_type_id > 2 and true))","row_count":0},{"table_name":"movie_companies","type":"Table Stats-Expression-CNF","expr":"`gt`(imdbload.movie_companies.company_type_id, 2)","row_count":0}] +[{"table_name":"movie_companies","type":"Column Stats-Range","expr":"((company_type_id > 2 and true))","row_count":2479148},{"table_name":"movie_companies","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":4958296},{"table_name":"movie_companies","type":"Index Stats-Range","expr":"((company_type_id > 2 and true))","row_count":1},{"table_name":"movie_companies","type":"Table Stats-Expression-CNF","expr":"`gt`(imdbload.movie_companies.company_type_id, 2)","row_count":1}] explain 
select * from char_name where imdb_index > 'I' and imdb_index < 'II'; id estRows task access object operator info -IndexLookUp_10 0.00 root -├─IndexRangeScan_8(Build) 0.00 cop[tikv] table:char_name, index:itest2(imdb_index, surname_pcode, name_pcode_nf) range:("I","II"), keep order:false -└─TableRowIDScan_9(Probe) 0.00 cop[tikv] table:char_name keep order:false +IndexLookUp_10 1.00 root +├─IndexRangeScan_8(Build) 1.00 cop[tikv] table:char_name, index:itest2(imdb_index, surname_pcode, name_pcode_nf) range:("I","II"), keep order:false +└─TableRowIDScan_9(Probe) 1.00 cop[tikv] table:char_name keep order:false trace plan target = 'estimation' select * from char_name where imdb_index > 'I' and imdb_index < 'II'; CE_trace -[{"table_name":"char_name","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":4314864},{"table_name":"char_name","type":"Column Stats-Range","expr":"((imdb_index > 'I' and imdb_index < 'II'))","row_count":0},{"table_name":"char_name","type":"Index Stats-Range","expr":"((imdb_index > 'I' and imdb_index < 'II'))","row_count":0},{"table_name":"char_name","type":"Table Stats-Expression-CNF","expr":"`and`(`gt`(imdbload.char_name.imdb_index, 'I'), `lt`(imdbload.char_name.imdb_index, 'II'))","row_count":0}] +[{"table_name":"char_name","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":4314864},{"table_name":"char_name","type":"Column Stats-Range","expr":"((imdb_index > 'I' and imdb_index < 'II'))","row_count":1},{"table_name":"char_name","type":"Index Stats-Range","expr":"((imdb_index > 'I' and imdb_index < 'II'))","row_count":1},{"table_name":"char_name","type":"Table Stats-Expression-CNF","expr":"`and`(`gt`(imdbload.char_name.imdb_index, 'I'), `lt`(imdbload.char_name.imdb_index, 'II'))","row_count":1}] explain select * from char_name where imdb_index > 'I'; id estRows task access object operator info -IndexLookUp_10 0.00 root -├─IndexRangeScan_8(Build) 0.00 cop[tikv] table:char_name, index:itest2(imdb_index, surname_pcode, name_pcode_nf) range:("I",+inf], keep order:false -└─TableRowIDScan_9(Probe) 0.00 cop[tikv] table:char_name keep order:false +IndexLookUp_10 1.00 root +├─IndexRangeScan_8(Build) 1.00 cop[tikv] table:char_name, index:itest2(imdb_index, surname_pcode, name_pcode_nf) range:("I",+inf], keep order:false +└─TableRowIDScan_9(Probe) 1.00 cop[tikv] table:char_name keep order:false trace plan target = 'estimation' select * from char_name where imdb_index > 'I'; CE_trace -[{"table_name":"char_name","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":4314864},{"table_name":"char_name","type":"Column Stats-Range","expr":"((imdb_index > 'I' and true))","row_count":0},{"table_name":"char_name","type":"Index Stats-Range","expr":"((imdb_index > 'I' and true))","row_count":0},{"table_name":"char_name","type":"Table Stats-Expression-CNF","expr":"`gt`(imdbload.char_name.imdb_index, 'I')","row_count":0}] +[{"table_name":"char_name","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":4314864},{"table_name":"char_name","type":"Column Stats-Range","expr":"((imdb_index > 'I' and true))","row_count":1},{"table_name":"char_name","type":"Index Stats-Range","expr":"((imdb_index > 'I' and true))","row_count":1},{"table_name":"char_name","type":"Table Stats-Expression-CNF","expr":"`gt`(imdbload.char_name.imdb_index, 'I')","row_count":1}] explain select * from 
cast_info where nr_order < -2068070866; id estRows task access object operator info @@ -341,12 +341,12 @@ TableReader_7 34260.33 root data:Selection_6 └─TableFullScan_5 528337.00 cop[tikv] table:aka_title keep order:false explain select * from aka_title where kind_id > 7; id estRows task access object operator info -IndexLookUp_10 0.00 root -├─IndexRangeScan_8(Build) 0.00 cop[tikv] table:aka_title, index:aka_title_idx_kindid(kind_id) range:(7,+inf], keep order:false -└─TableRowIDScan_9(Probe) 0.00 cop[tikv] table:aka_title keep order:false +IndexLookUp_10 1.00 root +├─IndexRangeScan_8(Build) 1.00 cop[tikv] table:aka_title, index:aka_title_idx_kindid(kind_id) range:(7,+inf], keep order:false +└─TableRowIDScan_9(Probe) 1.00 cop[tikv] table:aka_title keep order:false trace plan target = 'estimation' select * from aka_title where kind_id > 7; CE_trace -[{"table_name":"aka_title","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":528337},{"table_name":"aka_title","type":"Column Stats-Range","expr":"((kind_id > 7 and true))","row_count":0},{"table_name":"aka_title","type":"Index Stats-Range","expr":"((kind_id > 7 and true))","row_count":0},{"table_name":"aka_title","type":"Table Stats-Expression-CNF","expr":"`gt`(imdbload.aka_title.kind_id, 7)","row_count":0}] +[{"table_name":"aka_title","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":528337},{"table_name":"aka_title","type":"Column Stats-Range","expr":"((kind_id > 7 and true))","row_count":51390},{"table_name":"aka_title","type":"Index Stats-Range","expr":"((kind_id > 7 and true))","row_count":1},{"table_name":"aka_title","type":"Table Stats-Expression-CNF","expr":"`gt`(imdbload.aka_title.kind_id, 7)","row_count":1}] explain select * from keyword where ((phonetic_code = 'R1652') and (keyword > 'ecg-monitor' and keyword < 'killers')); id estRows task access object operator info @@ -366,5 +366,5 @@ IndexLookUp_11 144633.00 root └─TableRowIDScan_9 144633.00 cop[tikv] table:cast_info keep order:false trace plan target = 'estimation' select * from cast_info where (nr_order is null) and (person_role_id = 2) and (note >= '(key set pa: Florida'); CE_trace -[{"table_name":"cast_info","type":"Column Stats-Point","expr":"((nr_order is null))","row_count":45995275},{"table_name":"cast_info","type":"Column Stats-Point","expr":"((person_role_id = 2))","row_count":2089611},{"table_name":"cast_info","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":63475835},{"table_name":"cast_info","type":"Column Stats-Range","expr":"((note >= '(key set pa: Florida' and true))","row_count":14934328},{"table_name":"cast_info","type":"Index Stats-Point","expr":"((person_role_id = 2))","row_count":2089611},{"table_name":"cast_info","type":"Index Stats-Range","expr":"((nr_order is null) and (person_role_id = 2) and (note >= '(key set pa: Florida' and true))","row_count":144633},{"table_name":"cast_info","type":"Table Stats-Expression-CNF","expr":"`and`(`isnull`(imdbload.cast_info.nr_order), `and`(`eq`(imdbload.cast_info.person_role_id, 2), `ge`(imdbload.cast_info.note, '(key set pa: Florida')))","row_count":144633},{"table_name":"cast_info","type":"Table Stats-Expression-CNF","expr":"`eq`(imdbload.cast_info.person_role_id, 2)","row_count":2089611}] +[{"table_name":"cast_info","type":"Column Stats-Point","expr":"((nr_order is null))","row_count":45995275},{"table_name":"cast_info","type":"Column 
Stats-Point","expr":"((person_role_id = 2))","row_count":2089611},{"table_name":"cast_info","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":63475835},{"table_name":"cast_info","type":"Column Stats-Range","expr":"((note >= '(key set pa: Florida' and true))","row_count":18437410},{"table_name":"cast_info","type":"Index Stats-Point","expr":"((person_role_id = 2))","row_count":2089611},{"table_name":"cast_info","type":"Index Stats-Range","expr":"((nr_order is null) and (person_role_id = 2) and (note >= '(key set pa: Florida' and true))","row_count":144633},{"table_name":"cast_info","type":"Table Stats-Expression-CNF","expr":"`and`(`isnull`(imdbload.cast_info.nr_order), `and`(`eq`(imdbload.cast_info.person_role_id, 2), `ge`(imdbload.cast_info.note, '(key set pa: Florida')))","row_count":144633},{"table_name":"cast_info","type":"Table Stats-Expression-CNF","expr":"`eq`(imdbload.cast_info.person_role_id, 2)","row_count":2089611}] diff --git a/tests/integrationtest/r/planner/cardinality/selectivity.result b/tests/integrationtest/r/planner/cardinality/selectivity.result index 9ef9acfdcce63..e166e749e4708 100644 --- a/tests/integrationtest/r/planner/cardinality/selectivity.result +++ b/tests/integrationtest/r/planner/cardinality/selectivity.result @@ -396,8 +396,8 @@ TableReader_7 1.00 root data:Selection_6 └─TableFullScan_5 4.00 cop[tikv] table:tdatetime keep order:false explain select * from tint where b=1; id estRows task access object operator info -TableReader_7 1.00 root data:Selection_6 -└─Selection_6 1.00 cop[tikv] eq(planner__cardinality__selectivity.tint.b, 1) +TableReader_7 1.84 root data:Selection_6 +└─Selection_6 1.84 cop[tikv] eq(planner__cardinality__selectivity.tint.b, 1) └─TableFullScan_5 8.00 cop[tikv] table:tint keep order:false explain select * from tint where b=4; id estRows task access object operator info @@ -406,8 +406,8 @@ TableReader_7 1.00 root data:Selection_6 └─TableFullScan_5 8.00 cop[tikv] table:tint keep order:false explain select * from tint where b=8; id estRows task access object operator info -TableReader_7 1.00 root data:Selection_6 -└─Selection_6 1.00 cop[tikv] eq(planner__cardinality__selectivity.tint.b, 8) +TableReader_7 2.07 root data:Selection_6 +└─Selection_6 2.07 cop[tikv] eq(planner__cardinality__selectivity.tint.b, 8) └─TableFullScan_5 8.00 cop[tikv] table:tint keep order:false explain select * from tdouble where b=1; id estRows task access object operator info @@ -471,8 +471,8 @@ TableReader_7 1.00 root data:Selection_6 └─TableFullScan_5 4.00 cop[tikv] table:tdatetime keep order:false explain select * from ct1 where pk>='1' and pk <='4'; id estRows task access object operator info -TableReader_6 5.00 root data:TableRangeScan_5 -└─TableRangeScan_5 5.00 cop[tikv] table:ct1 range:["1","4"], keep order:false +TableReader_6 7.40 root data:TableRangeScan_5 +└─TableRangeScan_5 7.40 cop[tikv] table:ct1 range:["1","4"], keep order:false explain select * from ct1 where pk>='4' and pk <='6'; id estRows task access object operator info TableReader_6 3.75 root data:TableRangeScan_5 @@ -1225,8 +1225,8 @@ insert into t values ('tw', 0); analyze table t all columns; explain select * from t where a = 'tw' and b < 0; id estRows task access object operator info -IndexReader_6 0.00 root index:IndexRangeScan_5 -└─IndexRangeScan_5 0.00 cop[tikv] table:t, index:idx(a, b) range:["tw" -inf,"tw" 0), keep order:false +IndexReader_6 1.00 root index:IndexRangeScan_5 +└─IndexRangeScan_5 1.00 cop[tikv] table:t, 
index:idx(a, b) range:["tw" -inf,"tw" 0), keep order:false drop table if exists t; create table t(id int auto_increment, kid int, pid int, primary key(id), key(kid, pid)); insert into t (kid, pid) values (1,2), (1,3), (1,4),(1, 11), (1, 12), (1, 13), (1, 14), (2, 2), (2, 3), (2, 4); diff --git a/tests/integrationtest/r/planner/core/partition_pruner.result b/tests/integrationtest/r/planner/core/partition_pruner.result index 699eab20186e6..039d92a3891d8 100644 --- a/tests/integrationtest/r/planner/core/partition_pruner.result +++ b/tests/integrationtest/r/planner/core/partition_pruner.result @@ -3120,8 +3120,8 @@ IndexReader_10 2.00 root partition:all index:Selection_9 └─IndexFullScan_8 2.00 cop[tikv] table:test1, index:PRIMARY(ID, PARTITION_NO, CREATE_TIME) keep order:false explain select * from test1 where partition_no < 200000; id estRows task access object operator info -IndexReader_10 0.00 root partition:2023p1 index:Selection_9 -└─Selection_9 0.00 cop[tikv] lt(issue43459.test1.partition_no, 200000) +IndexReader_10 1.00 root partition:2023p1 index:Selection_9 +└─Selection_9 1.00 cop[tikv] lt(issue43459.test1.partition_no, 200000) └─IndexFullScan_8 2.00 cop[tikv] table:test1, index:PRIMARY(ID, PARTITION_NO, CREATE_TIME) keep order:false explain select * from test1 where partition_no <= 200000; id estRows task access object operator info @@ -3130,8 +3130,8 @@ IndexReader_10 2.00 root partition:all index:Selection_9 └─IndexFullScan_8 2.00 cop[tikv] table:test1, index:PRIMARY(ID, PARTITION_NO, CREATE_TIME) keep order:false explain select * from test1 where partition_no > 200000; id estRows task access object operator info -IndexReader_10 0.00 root partition:2023p2 index:Selection_9 -└─Selection_9 0.00 cop[tikv] gt(issue43459.test1.partition_no, 200000) +IndexReader_10 1.00 root partition:2023p2 index:Selection_9 +└─Selection_9 1.00 cop[tikv] gt(issue43459.test1.partition_no, 200000) └─IndexFullScan_8 2.00 cop[tikv] table:test1, index:PRIMARY(ID, PARTITION_NO, CREATE_TIME) keep order:false select * from test1 partition (2023p1); ID PARTITION_NO CREATE_TIME @@ -3179,8 +3179,8 @@ IndexReader_10 2.00 root partition:all index:Selection_9 └─IndexFullScan_8 2.00 cop[tikv] table:test1, index:PRIMARY(ID, PARTITION_NO, CREATE_TIME) keep order:false explain select * from test1 where partition_no < 200000; id estRows task access object operator info -IndexReader_10 0.00 root partition:2023p1 index:Selection_9 -└─Selection_9 0.00 cop[tikv] lt(issue43459.test1.partition_no, 200000) +IndexReader_10 1.00 root partition:2023p1 index:Selection_9 +└─Selection_9 1.00 cop[tikv] lt(issue43459.test1.partition_no, 200000) └─IndexFullScan_8 2.00 cop[tikv] table:test1, index:PRIMARY(ID, PARTITION_NO, CREATE_TIME) keep order:false explain select * from test1 where partition_no <= 200000; id estRows task access object operator info @@ -3189,8 +3189,8 @@ IndexReader_10 2.00 root partition:all index:Selection_9 └─IndexFullScan_8 2.00 cop[tikv] table:test1, index:PRIMARY(ID, PARTITION_NO, CREATE_TIME) keep order:false explain select * from test1 where partition_no > 200000; id estRows task access object operator info -IndexReader_10 0.00 root partition:2023p2 index:Selection_9 -└─Selection_9 0.00 cop[tikv] gt(issue43459.test1.partition_no, 200000) +IndexReader_10 1.00 root partition:2023p2 index:Selection_9 +└─Selection_9 1.00 cop[tikv] gt(issue43459.test1.partition_no, 200000) └─IndexFullScan_8 2.00 cop[tikv] table:test1, index:PRIMARY(ID, PARTITION_NO, CREATE_TIME) keep order:false select * from test1 partition 
(2023p1); ID PARTITION_NO CREATE_TIME diff --git a/tests/integrationtest/r/statistics/integration.result b/tests/integrationtest/r/statistics/integration.result index aa9bb4ac163b9..4baea6abdeeb1 100644 --- a/tests/integrationtest/r/statistics/integration.result +++ b/tests/integrationtest/r/statistics/integration.result @@ -17,8 +17,8 @@ explain format = 'brief' select * from t1 left join t2 on t1.a=t2.a order by t1. id estRows task access object operator info Sort 4.00 root statistics__integration.t1.a, statistics__integration.t2.a └─HashJoin 4.00 root left outer join, equal:[eq(statistics__integration.t1.a, statistics__integration.t2.a)] - ├─TableReader(Build) 0.00 root data:Selection - │ └─Selection 0.00 cop[tikv] not(isnull(statistics__integration.t2.a)) + ├─TableReader(Build) 1.00 root data:Selection + │ └─Selection 1.00 cop[tikv] not(isnull(statistics__integration.t2.a)) │ └─TableFullScan 2.00 cop[tikv] table:t2 keep order:false └─TableReader(Probe) 4.00 root data:TableFullScan └─TableFullScan 4.00 cop[tikv] table:t1 keep order:false @@ -26,8 +26,8 @@ explain format = 'brief' select * from t2 left join t1 on t1.a=t2.a order by t1. id estRows task access object operator info Sort 2.00 root statistics__integration.t1.a, statistics__integration.t2.a └─HashJoin 2.00 root left outer join, equal:[eq(statistics__integration.t2.a, statistics__integration.t1.a)] - ├─TableReader(Build) 0.00 root data:Selection - │ └─Selection 0.00 cop[tikv] not(isnull(statistics__integration.t1.a)) + ├─TableReader(Build) 1.00 root data:Selection + │ └─Selection 1.00 cop[tikv] not(isnull(statistics__integration.t1.a)) │ └─TableFullScan 4.00 cop[tikv] table:t1 keep order:false └─TableReader(Probe) 2.00 root data:TableFullScan └─TableFullScan 2.00 cop[tikv] table:t2 keep order:false @@ -35,8 +35,8 @@ explain format = 'brief' select * from t1 right join t2 on t1.a=t2.a order by t1 id estRows task access object operator info Sort 2.00 root statistics__integration.t1.a, statistics__integration.t2.a └─HashJoin 2.00 root right outer join, equal:[eq(statistics__integration.t1.a, statistics__integration.t2.a)] - ├─TableReader(Build) 0.00 root data:Selection - │ └─Selection 0.00 cop[tikv] not(isnull(statistics__integration.t1.a)) + ├─TableReader(Build) 1.00 root data:Selection + │ └─Selection 1.00 cop[tikv] not(isnull(statistics__integration.t1.a)) │ └─TableFullScan 4.00 cop[tikv] table:t1 keep order:false └─TableReader(Probe) 2.00 root data:TableFullScan └─TableFullScan 2.00 cop[tikv] table:t2 keep order:false @@ -44,8 +44,8 @@ explain format = 'brief' select * from t2 right join t1 on t1.a=t2.a order by t1 id estRows task access object operator info Sort 4.00 root statistics__integration.t1.a, statistics__integration.t2.a └─HashJoin 4.00 root right outer join, equal:[eq(statistics__integration.t2.a, statistics__integration.t1.a)] - ├─TableReader(Build) 0.00 root data:Selection - │ └─Selection 0.00 cop[tikv] not(isnull(statistics__integration.t2.a)) + ├─TableReader(Build) 1.00 root data:Selection + │ └─Selection 1.00 cop[tikv] not(isnull(statistics__integration.t2.a)) │ └─TableFullScan 2.00 cop[tikv] table:t2 keep order:false └─TableReader(Probe) 4.00 root data:TableFullScan └─TableFullScan 4.00 cop[tikv] table:t1 keep order:false diff --git a/tests/integrationtest/r/tpch.result b/tests/integrationtest/r/tpch.result index cd2f96f31ce7c..8b89428e68194 100644 --- a/tests/integrationtest/r/tpch.result +++ b/tests/integrationtest/r/tpch.result @@ -112,12 +112,12 @@ order by l_returnflag, l_linestatus; id estRows task 
access object operator info -Sort 2.94 root tpch50.lineitem.l_returnflag, tpch50.lineitem.l_linestatus -└─Projection 2.94 root tpch50.lineitem.l_returnflag, tpch50.lineitem.l_linestatus, Column#18, Column#19, Column#20, Column#21, Column#22, Column#23, Column#24, Column#25 - └─HashAgg 2.94 root group by:tpch50.lineitem.l_linestatus, tpch50.lineitem.l_returnflag, funcs:sum(Column#26)->Column#18, funcs:sum(Column#27)->Column#19, funcs:sum(Column#28)->Column#20, funcs:sum(Column#29)->Column#21, funcs:avg(Column#30, Column#31)->Column#22, funcs:avg(Column#32, Column#33)->Column#23, funcs:avg(Column#34, Column#35)->Column#24, funcs:count(Column#36)->Column#25, funcs:firstrow(tpch50.lineitem.l_returnflag)->tpch50.lineitem.l_returnflag, funcs:firstrow(tpch50.lineitem.l_linestatus)->tpch50.lineitem.l_linestatus - └─TableReader 2.94 root data:HashAgg - └─HashAgg 2.94 cop[tikv] group by:tpch50.lineitem.l_linestatus, tpch50.lineitem.l_returnflag, funcs:sum(tpch50.lineitem.l_quantity)->Column#26, funcs:sum(tpch50.lineitem.l_extendedprice)->Column#27, funcs:sum(mul(tpch50.lineitem.l_extendedprice, minus(1, tpch50.lineitem.l_discount)))->Column#28, funcs:sum(mul(mul(tpch50.lineitem.l_extendedprice, minus(1, tpch50.lineitem.l_discount)), plus(1, tpch50.lineitem.l_tax)))->Column#29, funcs:count(tpch50.lineitem.l_quantity)->Column#30, funcs:sum(tpch50.lineitem.l_quantity)->Column#31, funcs:count(tpch50.lineitem.l_extendedprice)->Column#32, funcs:sum(tpch50.lineitem.l_extendedprice)->Column#33, funcs:count(tpch50.lineitem.l_discount)->Column#34, funcs:sum(tpch50.lineitem.l_discount)->Column#35, funcs:count(1)->Column#36 - └─Selection 293797075.24 cop[tikv] le(tpch50.lineitem.l_shipdate, 1998-08-15 00:00:00.000000) +Sort 3.00 root tpch50.lineitem.l_returnflag, tpch50.lineitem.l_linestatus +└─Projection 3.00 root tpch50.lineitem.l_returnflag, tpch50.lineitem.l_linestatus, Column#18, Column#19, Column#20, Column#21, Column#22, Column#23, Column#24, Column#25 + └─HashAgg 3.00 root group by:tpch50.lineitem.l_linestatus, tpch50.lineitem.l_returnflag, funcs:sum(Column#26)->Column#18, funcs:sum(Column#27)->Column#19, funcs:sum(Column#28)->Column#20, funcs:sum(Column#29)->Column#21, funcs:avg(Column#30, Column#31)->Column#22, funcs:avg(Column#32, Column#33)->Column#23, funcs:avg(Column#34, Column#35)->Column#24, funcs:count(Column#36)->Column#25, funcs:firstrow(tpch50.lineitem.l_returnflag)->tpch50.lineitem.l_returnflag, funcs:firstrow(tpch50.lineitem.l_linestatus)->tpch50.lineitem.l_linestatus + └─TableReader 3.00 root data:HashAgg + └─HashAgg 3.00 cop[tikv] group by:tpch50.lineitem.l_linestatus, tpch50.lineitem.l_returnflag, funcs:sum(tpch50.lineitem.l_quantity)->Column#26, funcs:sum(tpch50.lineitem.l_extendedprice)->Column#27, funcs:sum(mul(tpch50.lineitem.l_extendedprice, minus(1, tpch50.lineitem.l_discount)))->Column#28, funcs:sum(mul(mul(tpch50.lineitem.l_extendedprice, minus(1, tpch50.lineitem.l_discount)), plus(1, tpch50.lineitem.l_tax)))->Column#29, funcs:count(tpch50.lineitem.l_quantity)->Column#30, funcs:sum(tpch50.lineitem.l_quantity)->Column#31, funcs:count(tpch50.lineitem.l_extendedprice)->Column#32, funcs:sum(tpch50.lineitem.l_extendedprice)->Column#33, funcs:count(tpch50.lineitem.l_discount)->Column#34, funcs:sum(tpch50.lineitem.l_discount)->Column#35, funcs:count(1)->Column#36 + └─Selection 300005811.00 cop[tikv] le(tpch50.lineitem.l_shipdate, 1998-08-15 00:00:00.000000) └─TableFullScan 300005811.00 cop[tikv] table:lineitem keep order:false /* Q2 Minimum Cost Supplier Query @@ -245,20 +245,20 @@ 
limit 10; id estRows task access object operator info Projection 10.00 root tpch50.lineitem.l_orderkey, Column#35, tpch50.orders.o_orderdate, tpch50.orders.o_shippriority └─TopN 10.00 root Column#35:desc, tpch50.orders.o_orderdate, offset:0, count:10 - └─HashAgg 39877917.22 root group by:Column#45, Column#46, Column#47, funcs:sum(Column#44)->Column#35, funcs:firstrow(Column#45)->tpch50.orders.o_orderdate, funcs:firstrow(Column#46)->tpch50.orders.o_shippriority, funcs:firstrow(Column#47)->tpch50.lineitem.l_orderkey + └─HashAgg 73908224.00 root group by:Column#45, Column#46, Column#47, funcs:sum(Column#44)->Column#35, funcs:firstrow(Column#45)->tpch50.orders.o_orderdate, funcs:firstrow(Column#46)->tpch50.orders.o_shippriority, funcs:firstrow(Column#47)->tpch50.lineitem.l_orderkey └─Projection 93262952.04 root mul(tpch50.lineitem.l_extendedprice, minus(1, tpch50.lineitem.l_discount))->Column#44, tpch50.orders.o_orderdate->Column#45, tpch50.orders.o_shippriority->Column#46, tpch50.lineitem.l_orderkey->Column#47 └─IndexHashJoin 93262952.04 root inner join, inner:IndexLookUp, outer key:tpch50.orders.o_orderkey, inner key:tpch50.lineitem.l_orderkey, equal cond:eq(tpch50.orders.o_orderkey, tpch50.lineitem.l_orderkey) ├─HashJoin(Build) 22975885.46 root inner join, equal:[eq(tpch50.customer.c_custkey, tpch50.orders.o_custkey)] │ ├─TableReader(Build) 1508884.60 root data:Selection │ │ └─Selection 1508884.60 cop[tikv] eq(tpch50.customer.c_mktsegment, "AUTOMOBILE") │ │ └─TableFullScan 7500000.00 cop[tikv] table:customer keep order:false - │ └─TableReader(Probe) 36496376.60 root data:Selection - │ └─Selection 36496376.60 cop[tikv] lt(tpch50.orders.o_orderdate, 1995-03-13 00:00:00.000000) + │ └─TableReader(Probe) 73445478.60 root data:Selection + │ └─Selection 73445478.60 cop[tikv] lt(tpch50.orders.o_orderdate, 1995-03-13 00:00:00.000000) │ └─TableFullScan 75000000.00 cop[tikv] table:orders keep order:false └─IndexLookUp(Probe) 93262952.04 root - ├─IndexRangeScan(Build) 172850029.03 cop[tikv] table:lineitem, index:PRIMARY(L_ORDERKEY, L_LINENUMBER) range: decided by [eq(tpch50.lineitem.l_orderkey, tpch50.orders.o_orderkey)], keep order:false + ├─IndexRangeScan(Build) 93262952.04 cop[tikv] table:lineitem, index:PRIMARY(L_ORDERKEY, L_LINENUMBER) range: decided by [eq(tpch50.lineitem.l_orderkey, tpch50.orders.o_orderkey)], keep order:false └─Selection(Probe) 93262952.04 cop[tikv] gt(tpch50.lineitem.l_shipdate, 1995-03-13 00:00:00.000000) - └─TableRowIDScan 172850029.03 cop[tikv] table:lineitem keep order:false + └─TableRowIDScan 93262952.04 cop[tikv] table:lineitem keep order:false /* Q4 Order Priority Checking Query This query determines how well the order priority system is working and gives an assessment of customer satisfaction. 
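Note on the estRows shifts in the TPC-H plans above: operators that previously carried a 0.00 estimate now carry at least 1.00, and that floor propagates upward into the surrounding join, aggregation, and selection estimates. The sketch below illustrates the lower-bound behaviour with a hypothetical clampRowCount helper; it is a minimal standalone Go illustration, not TiDB's own utility.

    package main

    import "fmt"

    // clampRowCount keeps a cardinality estimate within [1, total]:
    // an estimate never drops below one row and never exceeds the table size.
    func clampRowCount(estimate, total float64) float64 {
        if estimate < 1 {
            return 1
        }
        if estimate > total {
            return total
        }
        return estimate
    }

    func main() {
        fmt.Println(clampRowCount(0, 300005811)) // 1: a "no rows" estimate is floored to one row
        fmt.Println(clampRowCount(750.5, 1000))  // 750.5: in-range estimates pass through unchanged
        fmt.Println(clampRowCount(2500, 1000))   // 1000: never more than the table row count
    }
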
@@ -1292,10 +1292,10 @@ id estRows task access object operator info Sort 1.00 root Column#31 └─Projection 1.00 root Column#31, Column#32, Column#33 └─HashAgg 1.00 root group by:Column#36, funcs:count(1)->Column#32, funcs:sum(Column#35)->Column#33, funcs:firstrow(Column#36)->Column#31 - └─Projection 0.00 root tpch50.customer.c_acctbal->Column#35, substring(tpch50.customer.c_phone, 1, 2)->Column#36 - └─HashJoin 0.00 root anti semi join, equal:[eq(tpch50.customer.c_custkey, tpch50.orders.o_custkey)] + └─Projection 0.64 root tpch50.customer.c_acctbal->Column#35, substring(tpch50.customer.c_phone, 1, 2)->Column#36 + └─HashJoin 0.64 root anti semi join, equal:[eq(tpch50.customer.c_custkey, tpch50.orders.o_custkey)] ├─TableReader(Build) 75000000.00 root data:TableFullScan │ └─TableFullScan 75000000.00 cop[tikv] table:orders keep order:false - └─TableReader(Probe) 0.00 root data:Selection - └─Selection 0.00 cop[tikv] gt(tpch50.customer.c_acctbal, NULL), in(substring(tpch50.customer.c_phone, 1, 2), "20", "40", "22", "30", "39", "42", "21") + └─TableReader(Probe) 0.80 root data:Selection + └─Selection 0.80 cop[tikv] gt(tpch50.customer.c_acctbal, NULL), in(substring(tpch50.customer.c_phone, 1, 2), "20", "40", "22", "30", "39", "42", "21") └─TableFullScan 7500000.00 cop[tikv] table:customer keep order:false diff --git a/tests/integrationtest/r/util/ranger.result b/tests/integrationtest/r/util/ranger.result index 7acee71598f07..c6cfef64a3bc3 100644 --- a/tests/integrationtest/r/util/ranger.result +++ b/tests/integrationtest/r/util/ranger.result @@ -195,12 +195,12 @@ select * from t where a = 3; a b explain format='brief' select * from t where a < 1; id estRows task access object operator info -PartitionUnion 1.00 root +PartitionUnion 2.00 root ├─TableReader 1.00 root data:Selection │ └─Selection 1.00 cop[tikv] lt(util__ranger.t.a, 1) │ └─TableFullScan 1.00 cop[tikv] table:t, partition:p0 keep order:false -└─TableReader 0.00 root data:Selection - └─Selection 0.00 cop[tikv] lt(util__ranger.t.a, 1) +└─TableReader 1.00 root data:Selection + └─Selection 1.00 cop[tikv] lt(util__ranger.t.a, 1) └─TableFullScan 3.00 cop[tikv] table:t, partition:p1 keep order:false select * from t where a < 1; a b @@ -227,9 +227,9 @@ select * from t where a < -1; a b explain format='brief' select * from t where a > 0; id estRows task access object operator info -PartitionUnion 3.00 root -├─TableReader 0.00 root data:Selection -│ └─Selection 0.00 cop[tikv] gt(util__ranger.t.a, 0) +PartitionUnion 4.00 root +├─TableReader 1.00 root data:Selection +│ └─Selection 1.00 cop[tikv] gt(util__ranger.t.a, 0) │ └─TableFullScan 1.00 cop[tikv] table:t, partition:p0 keep order:false └─TableReader 3.00 root data:Selection └─Selection 3.00 cop[tikv] gt(util__ranger.t.a, 0) @@ -256,12 +256,12 @@ a b  3 explain format='brief' select * from t where a > 3; id estRows task access object operator info -PartitionUnion 0.00 root -├─TableReader 0.00 root data:Selection -│ └─Selection 0.00 cop[tikv] gt(util__ranger.t.a, 3) +PartitionUnion 2.00 root +├─TableReader 1.00 root data:Selection +│ └─Selection 1.00 cop[tikv] gt(util__ranger.t.a, 3) │ └─TableFullScan 1.00 cop[tikv] table:t, partition:p0 keep order:false -└─TableReader 0.00 root data:Selection - └─Selection 0.00 cop[tikv] gt(util__ranger.t.a, 3) +└─TableReader 1.00 root data:Selection + └─Selection 1.00 cop[tikv] gt(util__ranger.t.a, 3) └─TableFullScan 3.00 cop[tikv] table:t, partition:p1 keep order:false select * from t where a > 3; a b From 182f23f14056e62fd28ca9ea3baf775f939491e7 Mon Sep 17 
00:00:00 2001 From: tpp Date: Thu, 8 Aug 2024 12:21:32 -0500 Subject: [PATCH 18/35] testcase updates17 --- pkg/planner/cardinality/selectivity_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/planner/cardinality/selectivity_test.go b/pkg/planner/cardinality/selectivity_test.go index b82760d15315f..8c5c431aa5529 100644 --- a/pkg/planner/cardinality/selectivity_test.go +++ b/pkg/planner/cardinality/selectivity_test.go @@ -938,7 +938,7 @@ func TestIssue39593(t *testing.T) { count, err = cardinality.GetRowCountByIndexRanges(sctx.GetPlanCtx(), &statsTbl.HistColl, idxID, getRanges(vals, vals)) require.NoError(t, err) // estimated row count after mock modify on the table - require.Equal(t, float64(3600), count) + require.Equal(t, float64(3870.1135540008545), count) } func TestIndexJoinInnerRowCountUpperBound(t *testing.T) { From bf691aa3227e001e52cabd0ba982358335c14f84 Mon Sep 17 00:00:00 2001 From: tpp Date: Thu, 8 Aug 2024 14:47:22 -0500 Subject: [PATCH 19/35] testcase updates18 --- .../testdata/integration_suite_out.json | 14 +- .../testdata/json_plan_suite_out.json | 158 ++---------------- .../testdata/plan_normalized_suite_out.json | 4 +- 3 files changed, 19 insertions(+), 157 deletions(-) diff --git a/pkg/planner/core/casetest/testdata/integration_suite_out.json b/pkg/planner/core/casetest/testdata/integration_suite_out.json index 56e395c23816e..ffe733cb7fec8 100644 --- a/pkg/planner/core/casetest/testdata/integration_suite_out.json +++ b/pkg/planner/core/casetest/testdata/integration_suite_out.json @@ -93,25 +93,25 @@ "SQL": "explain format = 'verbose' select count(*) from t3 where b = 0", "Plan": [ "StreamAgg_10 1.00 64.98 root funcs:count(1)->Column#4", - "└─IndexReader_15 0.00 15.08 root index:IndexRangeScan_14", - " └─IndexRangeScan_14 0.00 162.80 cop[tikv] table:t3, index:c(b) range:[0,0], keep order:false" + "└─IndexReader_15 1.00 15.08 root index:IndexRangeScan_14", + " └─IndexRangeScan_14 1.00 162.80 cop[tikv] table:t3, index:c(b) range:[0,0], keep order:false" ] }, { "SQL": "explain format = 'verbose' select /*+ use_index(t3, c) */ count(a) from t3 where b = 0", "Plan": [ "StreamAgg_10 1.00 2001.63 root funcs:count(test.t3.a)->Column#4", - "└─IndexLookUp_17 0.00 1951.73 root ", - " ├─IndexRangeScan_15(Build) 0.00 203.50 cop[tikv] table:t3, index:c(b) range:[0,0], keep order:false", - " └─TableRowIDScan_16(Probe) 0.00 227.31 cop[tikv] table:t3 keep order:false" + "└─IndexLookUp_17 1.00 1951.73 root ", + " ├─IndexRangeScan_15(Build) 1.00 203.50 cop[tikv] table:t3, index:c(b) range:[0,0], keep order:false", + " └─TableRowIDScan_16(Probe) 1.00 227.31 cop[tikv] table:t3 keep order:false" ] }, { "SQL": "explain format = 'verbose' select count(*) from t2 where a = 0", "Plan": [ "StreamAgg_12 1.00 109.57 root funcs:count(1)->Column#4", - "└─TableReader_20 0.00 59.67 root data:Selection_19", - " └─Selection_19 0.00 831.62 cop[tikv] eq(test.t2.a, 0)", + "└─TableReader_20 1.00 59.67 root data:Selection_19", + " └─Selection_19 1.00 831.62 cop[tikv] eq(test.t2.a, 0)", " └─TableFullScan_18 3.00 681.92 cop[tikv] table:t2 keep order:false" ] }, diff --git a/pkg/planner/core/casetest/testdata/json_plan_suite_out.json b/pkg/planner/core/casetest/testdata/json_plan_suite_out.json index 10ba3c93ad7fb..b1174dcc03c63 100644 --- a/pkg/planner/core/casetest/testdata/json_plan_suite_out.json +++ b/pkg/planner/core/casetest/testdata/json_plan_suite_out.json @@ -3,162 +3,24 @@ "Name": "TestJSONPlanInExplain", "Cases": [ { - "SQL": "explain format = tidb_json update t2 set id = 1 
where id =2", - "JSONPlan": [ - { - "id": "Update_4", - "estRows": "N/A", - "taskType": "root", - "operatorInfo": "N/A", - "subOperators": [ - { - "id": "IndexReader_7", - "estRows": "10.00", - "taskType": "root", - "operatorInfo": "index:IndexRangeScan_6", - "subOperators": [ - { - "id": "IndexRangeScan_6", - "estRows": "10.00", - "taskType": "cop[tikv]", - "accessObject": "table:t2, index:id(id)", - "operatorInfo": "range:[2,2], keep order:false, stats:pseudo" - } - ] - } - ] - } - ] + "SQL": "", + "JSONPlan": null }, { - "SQL": "explain format = tidb_json insert into t1 values(1)", - "JSONPlan": [ - { - "id": "Insert_1", - "estRows": "N/A", - "taskType": "root", - "operatorInfo": "N/A" - } - ] + "SQL": "", + "JSONPlan": null }, { - "SQL": "explain format = tidb_json select count(*) from t1", - "JSONPlan": [ - { - "id": "HashAgg_12", - "estRows": "1.00", - "taskType": "root", - "operatorInfo": "funcs:count(Column#5)->Column#3", - "subOperators": [ - { - "id": "TableReader_13", - "estRows": "1.00", - "taskType": "root", - "operatorInfo": "data:HashAgg_5", - "subOperators": [ - { - "id": "HashAgg_5", - "estRows": "1.00", - "taskType": "cop[tikv]", - "operatorInfo": "funcs:count(test.t1._tidb_rowid)->Column#5", - "subOperators": [ - { - "id": "TableFullScan_10", - "estRows": "10000.00", - "taskType": "cop[tikv]", - "accessObject": "table:t1", - "operatorInfo": "keep order:false, stats:pseudo" - } - ] - } - ] - } - ] - } - ] + "SQL": "", + "JSONPlan": null }, { - "SQL": "explain format = tidb_json select * from t1", - "JSONPlan": [ - { - "id": "IndexReader_7", - "estRows": "10000.00", - "taskType": "root", - "operatorInfo": "index:IndexFullScan_6", - "subOperators": [ - { - "id": "IndexFullScan_6", - "estRows": "10000.00", - "taskType": "cop[tikv]", - "accessObject": "table:t1, index:id(id)", - "operatorInfo": "keep order:false, stats:pseudo" - } - ] - } - ] + "SQL": "", + "JSONPlan": null }, { - "SQL": "explain analyze format = tidb_json select * from t1, t2 where t1.id = t2.id", - "JSONPlan": [ - { - "id": "MergeJoin_8", - "estRows": "12487.50", - "actRows": "0", - "taskType": "root", - "executeInfo": "time:3.5ms, loops:1", - "operatorInfo": "inner join, left key:test.t1.id, right key:test.t2.id", - "memoryInfo": "760 Bytes", - "diskInfo": "0 Bytes", - "subOperators": [ - { - "id": "IndexReader_36(Build)", - "estRows": "9990.00", - "actRows": "0", - "taskType": "root", - "executeInfo": "time:3.47ms, loops:1, cop_task: {num: 1, max: 3.38ms, proc_keys: 0, tot_proc: 3ms, rpc_num: 1, rpc_time: 3.34ms, copr_cache_hit_ratio: 0.00, distsql_concurrency: 15}", - "operatorInfo": "index:IndexFullScan_35", - "memoryInfo": "171 Bytes", - "diskInfo": "N/A", - "subOperators": [ - { - "id": "IndexFullScan_35", - "estRows": "9990.00", - "actRows": "0", - "taskType": "cop[tikv]", - "accessObject": "table:t2, index:id(id)", - "executeInfo": "tikv_task:{time:3.3ms, loops:0}", - "operatorInfo": "keep order:true, stats:pseudo", - "memoryInfo": "N/A", - "diskInfo": "N/A" - } - ] - }, - { - "id": "IndexReader_34(Probe)", - "estRows": "9990.00", - "actRows": "0", - "taskType": "root", - "executeInfo": "time:14µs, loops:1, cop_task: {num: 1, max: 772.9µs, proc_keys: 0, rpc_num: 1, rpc_time: 735.7µs, copr_cache_hit_ratio: 0.00, distsql_concurrency: 15}", - "operatorInfo": "index:IndexFullScan_33", - "memoryInfo": "166 Bytes", - "diskInfo": "N/A", - "subOperators": [ - { - "id": "IndexFullScan_33", - "estRows": "9990.00", - "actRows": "0", - "taskType": "cop[tikv]", - "accessObject": "table:t1, index:id(id)", - 
"executeInfo": "tikv_task:{time:168.4µs, loops:0}", - "operatorInfo": "keep order:true, stats:pseudo", - "memoryInfo": "N/A", - "diskInfo": "N/A" - } - ] - } - ] - } - ] + "SQL": "", + "JSONPlan": null } ] } diff --git a/pkg/planner/core/casetest/testdata/plan_normalized_suite_out.json b/pkg/planner/core/casetest/testdata/plan_normalized_suite_out.json index bc910fa299f9e..b375ba4ee0560 100644 --- a/pkg/planner/core/casetest/testdata/plan_normalized_suite_out.json +++ b/pkg/planner/core/casetest/testdata/plan_normalized_suite_out.json @@ -443,8 +443,8 @@ "Plan": [ " TableReader root ", " └─ExchangeSender cop[tiflash] ", - " └─Selection cop[tiflash] gt(test.t1.b, ?), or(lt(test.t1.a, ?), lt(test.t1.b, ?))", - " └─TableFullScan cop[tiflash] table:t1, range:[?,?], pushed down filter:gt(test.t1.a, ?), keep order:false" + " └─Selection cop[tiflash] gt(test.t1.a, ?), or(lt(test.t1.a, ?), lt(test.t1.b, ?))", + " └─TableFullScan cop[tiflash] table:t1, range:[?,?], pushed down filter:gt(test.t1.b, ?), keep order:false" ] }, { From cfac163b22690c2114a405e59572f4162747ee35 Mon Sep 17 00:00:00 2001 From: tpp Date: Thu, 8 Aug 2024 15:11:12 -0500 Subject: [PATCH 20/35] testcase updates19 --- .../casetest/planstats/testdata/plan_stats_suite_out.json | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/planner/core/casetest/planstats/testdata/plan_stats_suite_out.json b/pkg/planner/core/casetest/planstats/testdata/plan_stats_suite_out.json index 92f1647bb88d7..e61f53ded7ca2 100644 --- a/pkg/planner/core/casetest/planstats/testdata/plan_stats_suite_out.json +++ b/pkg/planner/core/casetest/planstats/testdata/plan_stats_suite_out.json @@ -116,10 +116,10 @@ { "Query": "explain format = brief select * from t join tp where tp.a = 10 and t.b = tp.c", "Result": [ - "Projection 0.00 root test.t.a, test.t.b, test.t.c, test.tp.a, test.tp.b, test.tp.c", - "└─HashJoin 0.00 root inner join, equal:[eq(test.tp.c, test.t.b)]", - " ├─TableReader(Build) 0.00 root partition:p1 data:Selection", - " │ └─Selection 0.00 cop[tikv] eq(test.tp.a, 10), not(isnull(test.tp.c))", + "Projection 1.00 root test.t.a, test.t.b, test.t.c, test.tp.a, test.tp.b, test.tp.c", + "└─HashJoin 1.00 root inner join, equal:[eq(test.tp.c, test.t.b)]", + " ├─TableReader(Build) 1.00 root partition:p1 data:Selection", + " │ └─Selection 1.00 cop[tikv] eq(test.tp.a, 10), not(isnull(test.tp.c))", " │ └─TableFullScan 6.00 cop[tikv] table:tp keep order:false, stats:partial[c:allEvicted]", " └─TableReader(Probe) 3.00 root data:Selection", " └─Selection 3.00 cop[tikv] not(isnull(test.t.b))", From 3ee52e2f74ae91229a4da4859995a9b4103c3cf6 Mon Sep 17 00:00:00 2001 From: tpp Date: Thu, 8 Aug 2024 15:47:37 -0500 Subject: [PATCH 21/35] testcase updates20 --- .../planner/core/casetest/integration.result | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/integrationtest/r/planner/core/casetest/integration.result b/tests/integrationtest/r/planner/core/casetest/integration.result index 09c6bad54fd27..9946469915370 100644 --- a/tests/integrationtest/r/planner/core/casetest/integration.result +++ b/tests/integrationtest/r/planner/core/casetest/integration.result @@ -1154,17 +1154,17 @@ TableReader_7 3.00 130.42 root data:Selection_6 └─TableFullScan_5 5.00 1136.54 cop[tikv] table:t keep order:false explain format = 'verbose' select * from t where b = 6 order by a limit 1; id estRows estCost task access object operator info -Limit_11 0.00 98.74 root offset:0, count:1 -└─TableReader_24 0.00 98.74 root 
data:Limit_23 - └─Limit_23 0.00 1386.04 cop[tikv] offset:0, count:1 - └─Selection_22 0.00 1386.04 cop[tikv] eq(planner__core__casetest__integration.t.b, 6) +Limit_11 1.00 98.74 root offset:0, count:1 +└─TableReader_24 1.00 98.74 root data:Limit_23 + └─Limit_23 1.00 1386.04 cop[tikv] offset:0, count:1 + └─Selection_22 1.00 1386.04 cop[tikv] eq(planner__core__casetest__integration.t.b, 6) └─TableFullScan_21 5.00 1136.54 cop[tikv] table:t keep order:true explain format = 'verbose' select * from t where b = 6 limit 1; id estRows estCost task access object operator info -Limit_8 0.00 98.74 root offset:0, count:1 -└─TableReader_13 0.00 98.74 root data:Limit_12 - └─Limit_12 0.00 1386.04 cop[tikv] offset:0, count:1 - └─Selection_11 0.00 1386.04 cop[tikv] eq(planner__core__casetest__integration.t.b, 6) +Limit_8 1.00 98.74 root offset:0, count:1 +└─TableReader_13 1.00 98.74 root data:Limit_12 + └─Limit_12 1.00 1386.04 cop[tikv] offset:0, count:1 + └─Selection_11 1.00 1386.04 cop[tikv] eq(planner__core__casetest__integration.t.b, 6) └─TableFullScan_10 5.00 1136.54 cop[tikv] table:t keep order:false set tidb_opt_prefer_range_scan = 1; explain format = 'verbose' select * from t where b > 5; @@ -1176,19 +1176,19 @@ Level Code Message Note 1105 [idx_b] remain after pruning paths for t given Prop{SortItems: [], TaskTp: rootTask} explain format = 'verbose' select * from t where b = 6 order by a limit 1; id estRows estCost task access object operator info -TopN_9 0.00 1956.63 root planner__core__casetest__integration.t.a, offset:0, count:1 -└─IndexLookUp_16 0.00 1951.83 root - ├─TopN_15(Build) 0.00 206.70 cop[tikv] planner__core__casetest__integration.t.a, offset:0, count:1 - │ └─IndexRangeScan_13 0.00 203.50 cop[tikv] table:t, index:idx_b(b) range:[6,6], keep order:false - └─TableRowIDScan_14(Probe) 0.00 186.61 cop[tikv] table:t keep order:false +TopN_9 1.00 1956.63 root planner__core__casetest__integration.t.a, offset:0, count:1 +└─IndexLookUp_16 1.00 1951.83 root + ├─TopN_15(Build) 1.00 206.70 cop[tikv] planner__core__casetest__integration.t.a, offset:0, count:1 + │ └─IndexRangeScan_13 1.00 203.50 cop[tikv] table:t, index:idx_b(b) range:[6,6], keep order:false + └─TableRowIDScan_14(Probe) 1.00 186.61 cop[tikv] table:t keep order:false Level Code Message Note 1105 [idx_b] remain after pruning paths for t given Prop{SortItems: [], TaskTp: copMultiReadTask} explain format = 'verbose' select * from t where b = 6 limit 1; id estRows estCost task access object operator info -IndexLookUp_13 0.00 1170.97 root limit embedded(offset:0, count:1) -├─Limit_12(Build) 0.00 203.50 cop[tikv] offset:0, count:1 -│ └─IndexRangeScan_10 0.00 203.50 cop[tikv] table:t, index:idx_b(b) range:[6,6], keep order:false -└─TableRowIDScan_11(Probe) 0.00 186.61 cop[tikv] table:t keep order:false +IndexLookUp_13 1.00 1170.97 root limit embedded(offset:0, count:1) +├─Limit_12(Build) 1.00 203.50 cop[tikv] offset:0, count:1 +│ └─IndexRangeScan_10 1.00 203.50 cop[tikv] table:t, index:idx_b(b) range:[6,6], keep order:false +└─TableRowIDScan_11(Probe) 1.00 186.61 cop[tikv] table:t keep order:false Level Code Message Note 1105 [idx_b] remain after pruning paths for t given Prop{SortItems: [], TaskTp: copMultiReadTask} set @@tidb_enable_chunk_rpc = default; From 1a97fa78118f8aed0a41244d8c35dc829ae8ed8d Mon Sep 17 00:00:00 2001 From: tpp Date: Sat, 10 Aug 2024 13:14:45 -0500 Subject: [PATCH 22/35] code update1 --- pkg/statistics/histogram.go | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git 
a/pkg/statistics/histogram.go b/pkg/statistics/histogram.go index 4be9a5398b59e..de0322fafd4d4 100644 --- a/pkg/statistics/histogram.go +++ b/pkg/statistics/histogram.go @@ -1047,17 +1047,22 @@ func (hg *Histogram) OutOfRangeRowCount( return min(rowCount, upperBound) } - // If the modifyCount is large (compared to original table rows), then any out of range estimate is unreliable. - // Assume at least 1/NDV is returned - if float64(modifyCount) > hg.NotNullCount() && rowCount < upperBound { - rowCount = upperBound - } else if rowCount < upperBound { - // Adjust by increaseFactor if our estimate is low - rowCount *= increaseFactor + if modifyCount > 0 { + modifyBound := float64(modifyCount) / float64(histNDV) + if modifyBound < upperBound { + upperBound += modifyBound + } else { + upperBound = modifyBound + } } - - // Use modifyCount as a final bound - if modifyCount > 0 && rowCount > float64(modifyCount) { + if rowCount < upperBound { + rowCount *= increaseFactor + if float64(modifyCount) > 0 && rowCount == 0 { + rowCount = math.Min(upperBound, float64(modifyCount)) + } else { + rowCount = math.Min(rowCount, upperBound) + } + } else if modifyCount > 0 && rowCount > float64(modifyCount) { rowCount = float64(modifyCount) } return rowCount From a6132fc72906a720f47a15846fa4bb2aa2fac1b0 Mon Sep 17 00:00:00 2001 From: tpp Date: Sat, 10 Aug 2024 21:23:57 -0700 Subject: [PATCH 23/35] testcase updates after code change1 --- pkg/planner/cardinality/selectivity_test.go | 2 +- .../cardinality/testdata/cardinality_suite_out.json | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/planner/cardinality/selectivity_test.go b/pkg/planner/cardinality/selectivity_test.go index 8c5c431aa5529..1d64d537f4470 100644 --- a/pkg/planner/cardinality/selectivity_test.go +++ b/pkg/planner/cardinality/selectivity_test.go @@ -427,7 +427,7 @@ func TestSelectivity(t *testing.T) { { exprs: "a > 1 and b < 2 and c > 3 and d < 4 and e > 5", selectivity: 0.00099451303, - selectivityAfterIncrease: 1.51329827770157e-05, + selectivityAfterIncrease: 3.772290809327846e-05, }, { exprs: longExpr, diff --git a/pkg/planner/cardinality/testdata/cardinality_suite_out.json b/pkg/planner/cardinality/testdata/cardinality_suite_out.json index 682889e62880e..f654cdafd3870 100644 --- a/pkg/planner/cardinality/testdata/cardinality_suite_out.json +++ b/pkg/planner/cardinality/testdata/cardinality_suite_out.json @@ -24,7 +24,7 @@ { "Start": 800, "End": 900, - "Count": 764.004166655054 + "Count": 750 }, { "Start": 900, @@ -69,7 +69,7 @@ { "Start": 1500, "End": 1600, - "Count": 7.5 + "Count": 15 }, { "Start": 300, @@ -79,7 +79,7 @@ { "Start": 800, "End": 1000, - "Count": 1222.196869573942 + "Count": 1205.696869573942 }, { "Start": 900, @@ -104,7 +104,7 @@ { "Start": 200, "End": 400, - "Count": 1198.5288209899081 + "Count": 1199.2788209899081 }, { "Start": 200, From 9d8a60d4b7a6ad211e8122b3ec5fb51ec238c382 Mon Sep 17 00:00:00 2001 From: tpp Date: Mon, 12 Aug 2024 17:41:04 -0700 Subject: [PATCH 24/35] code correct outofrange --- pkg/statistics/histogram.go | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/pkg/statistics/histogram.go b/pkg/statistics/histogram.go index de0322fafd4d4..6facad33b5cad 100644 --- a/pkg/statistics/histogram.go +++ b/pkg/statistics/histogram.go @@ -1047,25 +1047,16 @@ func (hg *Histogram) OutOfRangeRowCount( return min(rowCount, upperBound) } - if modifyCount > 0 { - modifyBound := float64(modifyCount) / float64(histNDV) - if modifyBound < upperBound { - 
upperBound += modifyBound - } else { - upperBound = modifyBound - } - } - if rowCount < upperBound { + // If the modifyCount is large (compared to original table rows), then any out of range estimate is unreliable. + // Assume at least 1/NDV is returned + if float64(modifyCount) > hg.NotNullCount() && rowCount < upperBound { + rowCount = upperBound + } else if rowCount < upperBound { + // Adjust by increaseFactor if our estimate is low rowCount *= increaseFactor - if float64(modifyCount) > 0 && rowCount == 0 { - rowCount = math.Min(upperBound, float64(modifyCount)) - } else { - rowCount = math.Min(rowCount, upperBound) - } - } else if modifyCount > 0 && rowCount > float64(modifyCount) { - rowCount = float64(modifyCount) } - return rowCount + // Use modifyCount as a final bound + return min(rowCount, float64(modifyCount)) } // Copy deep copies the histogram. From 34ca5330eb6428f18cd3892772d94c3d7ddbfe92 Mon Sep 17 00:00:00 2001 From: tpp Date: Mon, 12 Aug 2024 17:59:04 -0700 Subject: [PATCH 25/35] after modify test1 --- .../testdata/cardinality_suite_out.json | 132 +++++++++--------- pkg/statistics/histogram.go | 1 + 2 files changed, 67 insertions(+), 66 deletions(-) diff --git a/pkg/planner/cardinality/testdata/cardinality_suite_out.json b/pkg/planner/cardinality/testdata/cardinality_suite_out.json index f654cdafd3870..1c9a46185cc38 100644 --- a/pkg/planner/cardinality/testdata/cardinality_suite_out.json +++ b/pkg/planner/cardinality/testdata/cardinality_suite_out.json @@ -24,7 +24,7 @@ { "Start": 800, "End": 900, - "Count": 750 + "Count": 762.504166655054 }, { "Start": 900, @@ -69,7 +69,7 @@ { "Start": 1500, "End": 1600, - "Count": 15 + "Count": 7.5 }, { "Start": 300, @@ -79,7 +79,7 @@ { "Start": 800, "End": 1000, - "Count": 1205.696869573942 + "Count": 1220.696869573942 }, { "Start": 900, @@ -104,7 +104,7 @@ { "Start": 200, "End": 400, - "Count": 1199.2788209899081 + "Count": 1215.0288209899081 }, { "Start": 200, @@ -976,9 +976,9 @@ "Result": [ "Limit 1.00 root offset:0, count:1", "└─IndexLookUp 1.00 root ", - " ├─IndexFullScan(Build) 1.98 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + " ├─IndexFullScan(Build) 200.00 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 9950)", - " └─TableRowIDScan 1.98 cop[tikv] table:t keep order:false, stats:pseudo" + " └─TableRowIDScan 200.00 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { @@ -986,9 +986,9 @@ "Result": [ "Limit 1.00 root offset:0, count:1", "└─IndexLookUp 1.00 root ", - " ├─IndexFullScan(Build) 1.98 cop[tikv] table:t, index:ic(c) keep order:true, desc, stats:pseudo", + " ├─IndexFullScan(Build) 200.00 cop[tikv] table:t, index:ic(c) keep order:true, desc, stats:pseudo", " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 9950)", - " └─TableRowIDScan 1.98 cop[tikv] table:t keep order:false, stats:pseudo" + " └─TableRowIDScan 200.00 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { @@ -996,9 +996,9 @@ "Result": [ "Limit 1.00 root offset:0, count:1", "└─IndexLookUp 1.00 root ", - " ├─IndexFullScan(Build) 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + " ├─IndexFullScan(Build) 9.99 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 8999)", - " └─TableRowIDScan 1.67 cop[tikv] table:t keep order:false, stats:pseudo" + " └─TableRowIDScan 9.99 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { @@ -1006,9 +1006,9 @@ "Result": [ "Limit 1.00 root offset:0, count:1", 
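Note on the OutOfRangeRowCount revision above: the settled logic applies three rules. If the number of modified rows exceeds the analyzed not-null count, the estimate is considered unreliable and the upper bound (roughly a 1/NDV share) is used; otherwise a low estimate is scaled by the increase factor; and the result is always capped by modifyCount, since out-of-range rows can only come from changes made after the last ANALYZE. A simplified standalone Go sketch of that decision flow follows; the names and signature are illustrative, not the real TiDB function.

    package main

    import "fmt"

    // outOfRangeEstimate mirrors the shape of the patched logic: rowCount is the raw
    // out-of-range guess, upperBound and notNullCount come from the histogram, and
    // modifyCount is the number of rows changed since the last ANALYZE.
    func outOfRangeEstimate(rowCount, upperBound, notNullCount, increaseFactor float64, modifyCount int64) float64 {
        if float64(modifyCount) > notNullCount && rowCount < upperBound {
            // Stats are badly outdated: fall back to the upper bound.
            rowCount = upperBound
        } else if rowCount < upperBound {
            // Otherwise inflate a low estimate by the table-growth factor.
            rowCount *= increaseFactor
        }
        // Rows outside the histogram range can only have appeared after ANALYZE,
        // so modifyCount is a hard ceiling (min is the Go 1.21 builtin).
        return min(rowCount, float64(modifyCount))
    }

    func main() {
        fmt.Println(outOfRangeEstimate(40, 100, 1000, 1.2, 500)) // 48: scaled by the increase factor
        fmt.Println(outOfRangeEstimate(40, 100, 1000, 1.2, 30))  // 30: capped by modifyCount
        fmt.Println(outOfRangeEstimate(40, 100, 50, 1.2, 500))   // 100: heavy churn, use the upper bound
    }
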
"└─IndexLookUp 1.00 root ", - " ├─IndexFullScan(Build) 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + " ├─IndexFullScan(Build) 10.00 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 9000)", - " └─TableRowIDScan 1.67 cop[tikv] table:t keep order:false, stats:pseudo" + " └─TableRowIDScan 10.00 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { @@ -1016,9 +1016,9 @@ "Result": [ "Limit 1.00 root offset:0, count:1", "└─IndexLookUp 1.00 root ", - " ├─IndexFullScan(Build) 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + " ├─IndexFullScan(Build) 10.01 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 9001)", - " └─TableRowIDScan 1.67 cop[tikv] table:t keep order:false, stats:pseudo" + " └─TableRowIDScan 10.01 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { @@ -1027,7 +1027,7 @@ "IndexLookUp 1.00 root limit embedded(offset:0, count:1)", "├─Limit(Build) 1.00 cop[tikv] offset:0, count:1", "│ └─Selection 1.00 cop[tikv] lt(test.t.a, 10001)", - "│ └─IndexFullScan 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + "│ └─IndexFullScan 10.00 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", "└─TableRowIDScan(Probe) 1.00 cop[tikv] table:t keep order:false, stats:pseudo" ] }, @@ -1037,7 +1037,7 @@ "IndexLookUp 1.00 root limit embedded(offset:0, count:1)", "├─Limit(Build) 1.00 cop[tikv] offset:0, count:1", "│ └─Selection 1.00 cop[tikv] lt(test.t.a, 10000)", - "│ └─IndexFullScan 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + "│ └─IndexFullScan 10.00 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", "└─TableRowIDScan(Probe) 1.00 cop[tikv] table:t keep order:false, stats:pseudo" ] }, @@ -1047,7 +1047,7 @@ "IndexLookUp 1.00 root limit embedded(offset:0, count:1)", "├─Limit(Build) 1.00 cop[tikv] offset:0, count:1", "│ └─Selection 1.00 cop[tikv] lt(test.t.a, 9999)", - "│ └─IndexFullScan 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + "│ └─IndexFullScan 10.00 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", "└─TableRowIDScan(Probe) 1.00 cop[tikv] table:t keep order:false, stats:pseudo" ] }, @@ -1068,7 +1068,7 @@ "└─TableReader 1.00 root data:Limit", " └─Limit 1.00 cop[tikv] offset:0, count:1", " └─Selection 1.00 cop[tikv] lt(test.t.c, 100)", - " └─TableRangeScan 1.96 cop[tikv] table:t range:[-inf,1000), keep order:false, stats:pseudo" + " └─TableRangeScan 100.00 cop[tikv] table:t range:[-inf,1000), keep order:false, stats:pseudo" ] }, { @@ -1078,21 +1078,21 @@ { "Query": "explain format = 'brief' select * from t where b >= 9950 order by c limit 1", "Result": [ - "Limit 1.00 root offset:0, count:1", + "TopN 1.00 root test.t.c, offset:0, count:1", "└─IndexLookUp 1.00 root ", - " ├─IndexFullScan(Build) 1.98 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", - " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 9950)", - " └─TableRowIDScan 1.98 cop[tikv] table:t keep order:false, stats:pseudo" + " ├─IndexRangeScan(Build) 500.00 cop[tikv] table:t, index:ib(b) range:[9950,+inf], keep order:false, stats:pseudo", + " └─TopN(Probe) 1.00 cop[tikv] test.t.c, offset:0, count:1", + " └─TableRowIDScan 500.00 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { "Query": "explain format = 'brief' select * from t where b >= 9950 order by c desc limit 1", "Result": [ - "Limit 1.00 root offset:0, count:1", + "TopN 1.00 root test.t.c:desc, 
offset:0, count:1", "└─IndexLookUp 1.00 root ", - " ├─IndexFullScan(Build) 1.98 cop[tikv] table:t, index:ic(c) keep order:true, desc, stats:pseudo", - " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 9950)", - " └─TableRowIDScan 1.98 cop[tikv] table:t keep order:false, stats:pseudo" + " ├─IndexRangeScan(Build) 500.00 cop[tikv] table:t, index:ib(b) range:[9950,+inf], keep order:false, stats:pseudo", + " └─TopN(Probe) 1.00 cop[tikv] test.t.c:desc, offset:0, count:1", + " └─TableRowIDScan 500.00 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { @@ -1100,9 +1100,9 @@ "Result": [ "Limit 1.00 root offset:0, count:1", "└─IndexLookUp 1.00 root ", - " ├─IndexFullScan(Build) 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + " ├─IndexFullScan(Build) 9.99 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 8999)", - " └─TableRowIDScan 1.67 cop[tikv] table:t keep order:false, stats:pseudo" + " └─TableRowIDScan 9.99 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { @@ -1110,19 +1110,19 @@ "Result": [ "Limit 1.00 root offset:0, count:1", "└─IndexLookUp 1.00 root ", - " ├─IndexFullScan(Build) 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + " ├─IndexFullScan(Build) 10.00 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 9000)", - " └─TableRowIDScan 1.67 cop[tikv] table:t keep order:false, stats:pseudo" + " └─TableRowIDScan 10.00 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { "Query": "explain format = 'brief' select * from t where b >= 9001 order by c limit 1", "Result": [ - "Limit 1.00 root offset:0, count:1", - "└─IndexLookUp 1.00 root ", - " ├─IndexFullScan(Build) 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", - " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 9001)", - " └─TableRowIDScan 1.67 cop[tikv] table:t keep order:false, stats:pseudo" + "TopN 1.00 root test.t.c, offset:0, count:1", + "└─TableReader 1.00 root data:TopN", + " └─TopN 1.00 cop[tikv] test.t.c, offset:0, count:1", + " └─Selection 9990.00 cop[tikv] ge(test.t.b, 9001)", + " └─TableFullScan 100000.00 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { @@ -1131,7 +1131,7 @@ "IndexLookUp 1.00 root limit embedded(offset:0, count:1)", "├─Limit(Build) 1.00 cop[tikv] offset:0, count:1", "│ └─Selection 1.00 cop[tikv] lt(test.t.a, 10001)", - "│ └─IndexFullScan 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + "│ └─IndexFullScan 10.00 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", "└─TableRowIDScan(Probe) 1.00 cop[tikv] table:t keep order:false, stats:pseudo" ] }, @@ -1141,18 +1141,17 @@ "IndexLookUp 1.00 root limit embedded(offset:0, count:1)", "├─Limit(Build) 1.00 cop[tikv] offset:0, count:1", "│ └─Selection 1.00 cop[tikv] lt(test.t.a, 10000)", - "│ └─IndexFullScan 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", + "│ └─IndexFullScan 10.00 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", "└─TableRowIDScan(Probe) 1.00 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { "Query": "explain format = 'brief' select * from t where a < 9999 order by c limit 1", "Result": [ - "IndexLookUp 1.00 root limit embedded(offset:0, count:1)", - "├─Limit(Build) 1.00 cop[tikv] offset:0, count:1", - "│ └─Selection 1.00 cop[tikv] lt(test.t.a, 9999)", - "│ └─IndexFullScan 1.67 cop[tikv] table:t, index:ic(c) keep order:true, stats:pseudo", - "└─TableRowIDScan(Probe) 1.00 cop[tikv] 
table:t keep order:false, stats:pseudo" + "TopN 1.00 root test.t.c, offset:0, count:1", + "└─TableReader 1.00 root data:TopN", + " └─TopN 1.00 cop[tikv] test.t.c, offset:0, count:1", + " └─TableRangeScan 9999.00 cop[tikv] table:t range:[-inf,9999), keep order:false, stats:pseudo" ] }, { @@ -1171,19 +1170,20 @@ "Result": [ "Limit 1.00 root offset:0, count:1", "└─IndexLookUp 1.00 root ", - " ├─IndexRangeScan(Build) 1.98 cop[tikv] table:t, index:ic(c) range:[9950,+inf], keep order:true, stats:pseudo", + " ├─IndexRangeScan(Build) 500.00 cop[tikv] table:t, index:ic(c) range:[9950,+inf], keep order:true, stats:pseudo", " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 9950)", - " └─TableRowIDScan 1.98 cop[tikv] table:t keep order:false, stats:pseudo" + " └─TableRowIDScan 500.00 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { "Query": "explain format = 'brief' select * from t where b >= 9950 and c >= 9900 order by c limit 1", "Result": [ - "Limit 1.00 root offset:0, count:1", + "TopN 1.00 root test.t.c, offset:0, count:1", "└─IndexLookUp 1.00 root ", - " ├─IndexRangeScan(Build) 1.98 cop[tikv] table:t, index:ic(c) range:[9900,+inf], keep order:true, stats:pseudo", - " └─Selection(Probe) 1.00 cop[tikv] ge(test.t.b, 9950)", - " └─TableRowIDScan 1.98 cop[tikv] table:t keep order:false, stats:pseudo" + " ├─IndexRangeScan(Build) 500.00 cop[tikv] table:t, index:ib(b) range:[9950,+inf], keep order:false, stats:pseudo", + " └─TopN(Probe) 1.00 cop[tikv] test.t.c, offset:0, count:1", + " └─Selection 5.00 cop[tikv] ge(test.t.c, 9900)", + " └─TableRowIDScan 500.00 cop[tikv] table:t keep order:false, stats:pseudo" ] }, { @@ -1193,7 +1193,7 @@ "└─TableReader 1.00 root data:Limit", " └─Limit 1.00 cop[tikv] offset:0, count:1", " └─Selection 1.00 cop[tikv] lt(test.t.c, 100)", - " └─TableRangeScan 1.96 cop[tikv] table:t range:[-inf,1000), keep order:false, stats:pseudo" + " └─TableRangeScan 100.00 cop[tikv] table:t range:[-inf,1000), keep order:false, stats:pseudo" ] } ] @@ -1532,7 +1532,7 @@ { "Histogram NotNull Count": 3080, "Increase Factor": 1, - "TopN total count": 1095 + "TopN total count": 1097 }, { "github.com/pingcap/tidb/pkg/planner/cardinality.GetColumnRowCount": [ @@ -2461,7 +2461,7 @@ { "Histogram NotNull Count": 3080, "Increase Factor": 1, - "TopN total count": 1095 + "TopN total count": 1097 }, { "github.com/pingcap/tidb/pkg/planner/cardinality.GetColumnRowCount": [ @@ -2575,13 +2575,13 @@ "rowCount": 1294.0277777777778 }, { - "Result": 1294.0277777777778 + "Result": 0 } ] }, { "End estimate range": { - "RowCount": 1294.0277777777778, + "RowCount": 0, "Type": "Range" } } @@ -2589,7 +2589,7 @@ }, { "Name": "a", - "Result": 1294.0277777777778 + "Result": 1 } ] }, @@ -2873,13 +2873,13 @@ "rowCount": 1540 }, { - "Result": 1540 + "Result": 0 } ] }, { "End estimate range": { - "RowCount": 1540, + "RowCount": 0, "Type": "Range" } } @@ -2887,7 +2887,7 @@ }, { "End estimate range": { - "RowCount": 1540, + "RowCount": 1, "Type": "Range" } } @@ -2895,7 +2895,7 @@ }, { "Name": "iab", - "Result": 1540 + "Result": 1 } ] }, @@ -3447,11 +3447,11 @@ "Expressions": [ "lt(test.t.a, -1500)" ], - "Selectivity": 0.5, + "Selectivity": 0.0003246753246753247, "partial cover": false }, { - "Result": 0.0003246753246753247 + "Result": 2.1082813290605499e-7 } ] } @@ -3602,7 +3602,7 @@ "histR": 999, "lPercent": 0.5613744375005636, "rPercent": 0, - "rowCount": 100 + "rowCount": 555.760693125558 }, { "Result": 100 @@ -4113,8 +4113,8 @@ "Result": [ "Projection 2000.00 root test.t.a, test.t.b, test.t.a, test.t.b", 
"└─IndexJoin 2000.00 root inner join, inner:IndexLookUp, outer key:test.t.a, inner key:test.t.b, equal cond:eq(test.t.a, test.t.b)", - " ├─TableReader(Build) 251000.00 root data:Selection", - " │ └─Selection 251000.00 cop[tikv] lt(test.t.a, 1), not(isnull(test.t.a))", + " ├─TableReader(Build) 1000.00 root data:Selection", + " │ └─Selection 1000.00 cop[tikv] lt(test.t.a, 1), not(isnull(test.t.a))", " │ └─TableFullScan 500000.00 cop[tikv] table:t keep order:false, stats:pseudo", " └─IndexLookUp(Probe) 2000.00 root ", " ├─Selection(Build) 1000000.00 cop[tikv] lt(test.t.b, 1), not(isnull(test.t.b))", @@ -4132,12 +4132,12 @@ "Result": [ "Projection 2000.00 root test.t.a, test.t.b, test.t.a, test.t.b", "└─IndexJoin 2000.00 root inner join, inner:IndexLookUp, outer key:test.t.a, inner key:test.t.b, equal cond:eq(test.t.a, test.t.b)", - " ├─TableReader(Build) 251000.00 root data:Selection", - " │ └─Selection 251000.00 cop[tikv] lt(test.t.a, 1), not(isnull(test.t.a))", + " ├─TableReader(Build) 1000.00 root data:Selection", + " │ └─Selection 1000.00 cop[tikv] lt(test.t.a, 1), not(isnull(test.t.a))", " │ └─TableFullScan 500000.00 cop[tikv] table:t keep order:false, stats:pseudo", " └─IndexLookUp(Probe) 2000.00 root ", " ├─Selection(Build) 1000000.00 cop[tikv] lt(test.t.b, 1), not(isnull(test.t.b))", - " │ └─IndexRangeScan 251000000.00 cop[tikv] table:t2, index:idx(b) range: decided by [eq(test.t.b, test.t.a)], keep order:false, stats:pseudo", + " │ └─IndexRangeScan 1000000.00 cop[tikv] table:t2, index:idx(b) range: decided by [eq(test.t.b, test.t.a)], keep order:false, stats:pseudo", " └─Selection(Probe) 2000.00 cop[tikv] eq(test.t.a, 0)", " └─TableRowIDScan 1000000.00 cop[tikv] table:t2 keep order:false, stats:pseudo" ] diff --git a/pkg/statistics/histogram.go b/pkg/statistics/histogram.go index 6facad33b5cad..2b3f3225af8d1 100644 --- a/pkg/statistics/histogram.go +++ b/pkg/statistics/histogram.go @@ -1055,6 +1055,7 @@ func (hg *Histogram) OutOfRangeRowCount( // Adjust by increaseFactor if our estimate is low rowCount *= increaseFactor } + // Use modifyCount as a final bound return min(rowCount, float64(modifyCount)) } From e81f5603f57696822931b15ae9a8e34456521373 Mon Sep 17 00:00:00 2001 From: tpp Date: Mon, 12 Aug 2024 18:27:38 -0700 Subject: [PATCH 26/35] after modify test2 --- .../cardinality/testdata/cardinality_suite_out.json | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/planner/cardinality/testdata/cardinality_suite_out.json b/pkg/planner/cardinality/testdata/cardinality_suite_out.json index 1c9a46185cc38..029e9916020ae 100644 --- a/pkg/planner/cardinality/testdata/cardinality_suite_out.json +++ b/pkg/planner/cardinality/testdata/cardinality_suite_out.json @@ -24,7 +24,7 @@ { "Start": 800, "End": 900, - "Count": 762.504166655054 + "Count": 776.004166655054 }, { "Start": 900, @@ -79,7 +79,7 @@ { "Start": 800, "End": 1000, - "Count": 1220.696869573942 + "Count": 1234.196869573942 }, { "Start": 900, @@ -1532,7 +1532,7 @@ { "Histogram NotNull Count": 3080, "Increase Factor": 1, - "TopN total count": 1097 + "TopN total count": 1095 }, { "github.com/pingcap/tidb/pkg/planner/cardinality.GetColumnRowCount": [ @@ -2461,7 +2461,7 @@ { "Histogram NotNull Count": 3080, "Increase Factor": 1, - "TopN total count": 1097 + "TopN total count": 1095 }, { "github.com/pingcap/tidb/pkg/planner/cardinality.GetColumnRowCount": [ From 77a4bbea38676d43bbab74c07920bb1dbfd071dc Mon Sep 17 00:00:00 2001 From: tpp Date: Mon, 12 Aug 2024 18:41:07 -0700 Subject: [PATCH 27/35] after modify 
test3 --- .../integrationtest/r/clustered_index.result | 90 ++++++------ .../r/explain_complex_stats.result | 33 +++-- .../r/explain_easy_stats.result | 6 +- .../r/explain_indexmerge_stats.result | 138 +++++++++++------- tests/integrationtest/r/imdbload.result | 8 +- .../r/planner/cardinality/selectivity.result | 12 +- tests/integrationtest/r/tpch.result | 22 +-- 7 files changed, 168 insertions(+), 141 deletions(-) diff --git a/tests/integrationtest/r/clustered_index.result b/tests/integrationtest/r/clustered_index.result index 890713221fb6a..8e666459932ad 100644 --- a/tests/integrationtest/r/clustered_index.result +++ b/tests/integrationtest/r/clustered_index.result @@ -18,35 +18,35 @@ id estRows task access object operator info HashAgg_12 1.00 root funcs:count(Column#7)->Column#6 └─IndexReader_13 1.00 root index:HashAgg_6 └─HashAgg_6 1.00 cop[tikv] funcs:count(1)->Column#7 - └─IndexRangeScan_11 1920.87 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,5429), keep order:false + └─IndexRangeScan_11 798.87 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,5429), keep order:false explain select count(*) from wout_cluster_index.tbl_0 where col_0 < 5429 ; id estRows task access object operator info HashAgg_12 1.00 root funcs:count(Column#8)->Column#7 └─IndexReader_13 1.00 root index:HashAgg_6 └─HashAgg_6 1.00 cop[tikv] funcs:count(1)->Column#8 - └─IndexRangeScan_11 1920.87 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,5429), keep order:false + └─IndexRangeScan_11 798.87 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,5429), keep order:false explain select count(*) from with_cluster_index.tbl_0 where col_0 < 41 ; id estRows task access object operator info -HashAgg_12 1.00 root funcs:count(Column#7)->Column#6 -└─IndexReader_13 1.00 root index:HashAgg_6 - └─HashAgg_6 1.00 cop[tikv] funcs:count(1)->Column#7 - └─IndexRangeScan_11 1163.00 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,41), keep order:false +StreamAgg_17 1.00 root funcs:count(Column#8)->Column#6 +└─IndexReader_18 1.00 root index:StreamAgg_9 + └─StreamAgg_9 1.00 cop[tikv] funcs:count(1)->Column#8 + └─IndexRangeScan_16 41.00 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,41), keep order:false explain select count(*) from wout_cluster_index.tbl_0 where col_0 < 41 ; id estRows task access object operator info -HashAgg_12 1.00 root funcs:count(Column#8)->Column#7 -└─IndexReader_13 1.00 root index:HashAgg_6 - └─HashAgg_6 1.00 cop[tikv] funcs:count(1)->Column#8 - └─IndexRangeScan_11 1163.00 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,41), keep order:false +StreamAgg_17 1.00 root funcs:count(Column#9)->Column#7 +└─IndexReader_18 1.00 root index:StreamAgg_9 + └─StreamAgg_9 1.00 cop[tikv] funcs:count(1)->Column#9 + └─IndexRangeScan_16 41.00 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,41), keep order:false explain select col_14 from with_cluster_index.tbl_2 where col_11 <> '2013-11-01' ; id estRows task access object operator info -IndexReader_9 4673.00 root index:Projection_5 -└─Projection_5 4673.00 cop[tikv] with_cluster_index.tbl_2.col_14 - └─IndexRangeScan_8 4673.00 cop[tikv] table:tbl_2, index:idx_9(col_11) range:[-inf,2013-11-01 00:00:00), (2013-11-01 00:00:00,+inf], keep order:false +IndexReader_9 4509.00 root index:Projection_5 +└─Projection_5 4509.00 cop[tikv] with_cluster_index.tbl_2.col_14 + └─IndexRangeScan_8 4509.00 cop[tikv] table:tbl_2, index:idx_9(col_11) range:[-inf,2013-11-01 00:00:00), (2013-11-01 00:00:00,+inf], keep order:false explain select col_14 from 
wout_cluster_index.tbl_2 where col_11 <> '2013-11-01' ; id estRows task access object operator info -TableReader_14 4673.00 root data:Projection_5 -└─Projection_5 4673.00 cop[tikv] wout_cluster_index.tbl_2.col_14 - └─Selection_13 4673.00 cop[tikv] ne(wout_cluster_index.tbl_2.col_11, 2013-11-01 00:00:00.000000) +TableReader_14 4509.00 root data:Projection_5 +└─Projection_5 4509.00 cop[tikv] wout_cluster_index.tbl_2.col_14 + └─Selection_13 4509.00 cop[tikv] ne(wout_cluster_index.tbl_2.col_11, 2013-11-01 00:00:00.000000) └─TableFullScan_12 4673.00 cop[tikv] table:tbl_2 keep order:false explain select sum( col_4 ) from with_cluster_index.tbl_0 where col_3 != '1993-12-02' ; id estRows task access object operator info @@ -59,29 +59,29 @@ id estRows task access object operator info HashAgg_13 1.00 root funcs:sum(Column#8)->Column#7 └─TableReader_14 1.00 root data:HashAgg_6 └─HashAgg_6 1.00 cop[tikv] funcs:sum(wout_cluster_index.tbl_0.col_4)->Column#8 - └─Selection_12 2244.00 cop[tikv] ne(wout_cluster_index.tbl_0.col_3, 1993-12-02 00:00:00.000000) + └─Selection_12 2243.00 cop[tikv] ne(wout_cluster_index.tbl_0.col_3, 1993-12-02 00:00:00.000000) └─TableFullScan_11 2244.00 cop[tikv] table:tbl_0 keep order:false explain select col_0 from with_cluster_index.tbl_0 where col_0 <= 0 ; id estRows task access object operator info -IndexReader_6 1123.00 root index:IndexRangeScan_5 -└─IndexRangeScan_5 1123.00 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,0], keep order:false +IndexReader_6 1.00 root index:IndexRangeScan_5 +└─IndexRangeScan_5 1.00 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,0], keep order:false explain select col_0 from wout_cluster_index.tbl_0 where col_0 <= 0 ; id estRows task access object operator info -IndexReader_6 1123.00 root index:IndexRangeScan_5 -└─IndexRangeScan_5 1123.00 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,0], keep order:false +IndexReader_6 1.00 root index:IndexRangeScan_5 +└─IndexRangeScan_5 1.00 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,0], keep order:false explain select col_3 from with_cluster_index.tbl_0 where col_3 >= '1981-09-15' ; id estRows task access object operator info -IndexReader_8 2244.00 root index:IndexRangeScan_7 -└─IndexRangeScan_7 2244.00 cop[tikv] table:tbl_0, index:idx_1(col_3) range:[1981-09-15 00:00:00,+inf], keep order:false +IndexReader_8 1860.39 root index:IndexRangeScan_7 +└─IndexRangeScan_7 1860.39 cop[tikv] table:tbl_0, index:idx_1(col_3) range:[1981-09-15 00:00:00,+inf], keep order:false explain select col_3 from wout_cluster_index.tbl_0 where col_3 >= '1981-09-15' ; id estRows task access object operator info -IndexReader_8 2244.00 root index:IndexRangeScan_7 -└─IndexRangeScan_7 2244.00 cop[tikv] table:tbl_0, index:idx_1(col_3) range:[1981-09-15 00:00:00,+inf], keep order:false +IndexReader_8 1860.39 root index:IndexRangeScan_7 +└─IndexRangeScan_7 1860.39 cop[tikv] table:tbl_0, index:idx_1(col_3) range:[1981-09-15 00:00:00,+inf], keep order:false explain select tbl_2.col_14 , tbl_0.col_1 from with_cluster_index.tbl_2 right join with_cluster_index.tbl_0 on col_3 = col_11 ; id estRows task access object operator info MergeJoin_7 2533.51 root right outer join, left key:with_cluster_index.tbl_2.col_11, right key:with_cluster_index.tbl_0.col_3 -├─IndexReader_22(Build) 4673.00 root index:IndexFullScan_21 -│ └─IndexFullScan_21 4673.00 cop[tikv] table:tbl_2, index:idx_9(col_11) keep order:true +├─IndexReader_22(Build) 4509.00 root index:IndexFullScan_21 +│ └─IndexFullScan_21 4509.00 cop[tikv] table:tbl_2, 
index:idx_9(col_11) keep order:true └─TableReader_24(Probe) 2244.00 root data:TableFullScan_23 └─TableFullScan_23 2244.00 cop[tikv] table:tbl_0 keep order:true explain select tbl_2.col_14 , tbl_0.col_1 from wout_cluster_index.tbl_2 right join wout_cluster_index.tbl_0 on col_3 = col_11 ; @@ -89,31 +89,31 @@ id estRows task access object operator info HashJoin_22 2533.51 root right outer join, equal:[eq(wout_cluster_index.tbl_2.col_11, wout_cluster_index.tbl_0.col_3)] ├─TableReader_41(Build) 2244.00 root data:TableFullScan_40 │ └─TableFullScan_40 2244.00 cop[tikv] table:tbl_0 keep order:false -└─TableReader_44(Probe) 4673.00 root data:Selection_43 - └─Selection_43 4673.00 cop[tikv] not(isnull(wout_cluster_index.tbl_2.col_11)) +└─TableReader_44(Probe) 4509.00 root data:Selection_43 + └─Selection_43 4509.00 cop[tikv] not(isnull(wout_cluster_index.tbl_2.col_11)) └─TableFullScan_42 4673.00 cop[tikv] table:tbl_2 keep order:false explain select count(*) from with_cluster_index.tbl_0 where col_0 <= 0 ; id estRows task access object operator info -HashAgg_12 1.00 root funcs:count(Column#7)->Column#6 -└─IndexReader_13 1.00 root index:HashAgg_6 - └─HashAgg_6 1.00 cop[tikv] funcs:count(1)->Column#7 - └─IndexRangeScan_11 1123.00 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,0], keep order:false +StreamAgg_16 1.00 root funcs:count(Column#8)->Column#6 +└─IndexReader_17 1.00 root index:StreamAgg_9 + └─StreamAgg_9 1.00 cop[tikv] funcs:count(1)->Column#8 + └─IndexRangeScan_11 1.00 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,0], keep order:false explain select count(*) from wout_cluster_index.tbl_0 where col_0 <= 0 ; id estRows task access object operator info -HashAgg_12 1.00 root funcs:count(Column#8)->Column#7 -└─IndexReader_13 1.00 root index:HashAgg_6 - └─HashAgg_6 1.00 cop[tikv] funcs:count(1)->Column#8 - └─IndexRangeScan_11 1123.00 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,0], keep order:false +StreamAgg_16 1.00 root funcs:count(Column#9)->Column#7 +└─IndexReader_17 1.00 root index:StreamAgg_9 + └─StreamAgg_9 1.00 cop[tikv] funcs:count(1)->Column#9 + └─IndexRangeScan_11 1.00 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[-inf,0], keep order:false explain select count(*) from with_cluster_index.tbl_0 where col_0 >= 803163 ; id estRows task access object operator info -HashAgg_12 1.00 root funcs:count(Column#7)->Column#6 -└─IndexReader_13 1.00 root index:HashAgg_6 - └─HashAgg_6 1.00 cop[tikv] funcs:count(1)->Column#7 - └─IndexRangeScan_11 1229.12 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[803163,+inf], keep order:false +StreamAgg_17 1.00 root funcs:count(Column#8)->Column#6 +└─IndexReader_18 1.00 root index:StreamAgg_9 + └─StreamAgg_9 1.00 cop[tikv] funcs:count(1)->Column#8 + └─IndexRangeScan_16 133.89 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[803163,+inf], keep order:false explain select count(*) from wout_cluster_index.tbl_0 where col_0 >= 803163 ; id estRows task access object operator info -HashAgg_12 1.00 root funcs:count(Column#8)->Column#7 -└─IndexReader_13 1.00 root index:HashAgg_6 - └─HashAgg_6 1.00 cop[tikv] funcs:count(1)->Column#8 - └─IndexRangeScan_11 1229.12 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[803163,+inf], keep order:false +StreamAgg_17 1.00 root funcs:count(Column#9)->Column#7 +└─IndexReader_18 1.00 root index:StreamAgg_9 + └─StreamAgg_9 1.00 cop[tikv] funcs:count(1)->Column#9 + └─IndexRangeScan_16 133.89 cop[tikv] table:tbl_0, index:idx_3(col_0) range:[803163,+inf], keep order:false set @@tidb_enable_outer_join_reorder=false; diff 
--git a/tests/integrationtest/r/explain_complex_stats.result b/tests/integrationtest/r/explain_complex_stats.result index 3579d7c463364..9183896d7aefb 100644 --- a/tests/integrationtest/r/explain_complex_stats.result +++ b/tests/integrationtest/r/explain_complex_stats.result @@ -145,15 +145,15 @@ Projection 21.47 root explain_complex_stats.dt.ds, explain_complex_stats.dt.p1, └─TableRowIDScan 128.00 cop[tikv] table:dt keep order:false, stats:partial[cm:missing] explain format = 'brief' select gad.id as gid,sdk.id as sid,gad.aid as aid,gad.cm as cm,sdk.dic as dic,sdk.ip as ip, sdk.t as t, gad.p1 as p1, gad.p2 as p2, gad.p3 as p3, gad.p4 as p4, gad.p5 as p5, gad.p6_md5 as p6, gad.p7_md5 as p7, gad.ext as ext, gad.t as gtime from st gad join (select id, aid, pt, dic, ip, t from dd where pt = 'android' and bm = 0 and t > 1478143908) sdk on gad.aid = sdk.aid and gad.ip = sdk.ip and sdk.t > gad.t where gad.t > 1478143908 and gad.bm = 0 and gad.pt = 'android' group by gad.aid, sdk.dic limit 2500; id estRows task access object operator info -Projection 488.80 root explain_complex_stats.st.id, explain_complex_stats.dd.id, explain_complex_stats.st.aid, explain_complex_stats.st.cm, explain_complex_stats.dd.dic, explain_complex_stats.dd.ip, explain_complex_stats.dd.t, explain_complex_stats.st.p1, explain_complex_stats.st.p2, explain_complex_stats.st.p3, explain_complex_stats.st.p4, explain_complex_stats.st.p5, explain_complex_stats.st.p6_md5, explain_complex_stats.st.p7_md5, explain_complex_stats.st.ext, explain_complex_stats.st.t -└─Limit 488.80 root offset:0, count:2500 - └─HashAgg 488.80 root group by:explain_complex_stats.dd.dic, explain_complex_stats.st.aid, funcs:firstrow(explain_complex_stats.st.id)->explain_complex_stats.st.id, funcs:firstrow(explain_complex_stats.st.aid)->explain_complex_stats.st.aid, funcs:firstrow(explain_complex_stats.st.cm)->explain_complex_stats.st.cm, funcs:firstrow(explain_complex_stats.st.p1)->explain_complex_stats.st.p1, funcs:firstrow(explain_complex_stats.st.p2)->explain_complex_stats.st.p2, funcs:firstrow(explain_complex_stats.st.p3)->explain_complex_stats.st.p3, funcs:firstrow(explain_complex_stats.st.p4)->explain_complex_stats.st.p4, funcs:firstrow(explain_complex_stats.st.p5)->explain_complex_stats.st.p5, funcs:firstrow(explain_complex_stats.st.p6_md5)->explain_complex_stats.st.p6_md5, funcs:firstrow(explain_complex_stats.st.p7_md5)->explain_complex_stats.st.p7_md5, funcs:firstrow(explain_complex_stats.st.ext)->explain_complex_stats.st.ext, funcs:firstrow(explain_complex_stats.st.t)->explain_complex_stats.st.t, funcs:firstrow(explain_complex_stats.dd.id)->explain_complex_stats.dd.id, funcs:firstrow(explain_complex_stats.dd.dic)->explain_complex_stats.dd.dic, funcs:firstrow(explain_complex_stats.dd.ip)->explain_complex_stats.dd.ip, funcs:firstrow(explain_complex_stats.dd.t)->explain_complex_stats.dd.t - └─HashJoin 488.80 root inner join, equal:[eq(explain_complex_stats.st.aid, explain_complex_stats.dd.aid) eq(explain_complex_stats.st.ip, explain_complex_stats.dd.ip)], other cond:gt(explain_complex_stats.dd.t, explain_complex_stats.st.t) - ├─TableReader(Build) 488.80 root data:Selection - │ └─Selection 488.80 cop[tikv] eq(explain_complex_stats.st.bm, 0), eq(explain_complex_stats.st.pt, "android"), gt(explain_complex_stats.st.t, 1478143908), not(isnull(explain_complex_stats.st.ip)) +Projection 424.00 root explain_complex_stats.st.id, explain_complex_stats.dd.id, explain_complex_stats.st.aid, explain_complex_stats.st.cm, explain_complex_stats.dd.dic, 
explain_complex_stats.dd.ip, explain_complex_stats.dd.t, explain_complex_stats.st.p1, explain_complex_stats.st.p2, explain_complex_stats.st.p3, explain_complex_stats.st.p4, explain_complex_stats.st.p5, explain_complex_stats.st.p6_md5, explain_complex_stats.st.p7_md5, explain_complex_stats.st.ext, explain_complex_stats.st.t +└─Limit 424.00 root offset:0, count:2500 + └─HashAgg 424.00 root group by:explain_complex_stats.dd.dic, explain_complex_stats.st.aid, funcs:firstrow(explain_complex_stats.st.id)->explain_complex_stats.st.id, funcs:firstrow(explain_complex_stats.st.aid)->explain_complex_stats.st.aid, funcs:firstrow(explain_complex_stats.st.cm)->explain_complex_stats.st.cm, funcs:firstrow(explain_complex_stats.st.p1)->explain_complex_stats.st.p1, funcs:firstrow(explain_complex_stats.st.p2)->explain_complex_stats.st.p2, funcs:firstrow(explain_complex_stats.st.p3)->explain_complex_stats.st.p3, funcs:firstrow(explain_complex_stats.st.p4)->explain_complex_stats.st.p4, funcs:firstrow(explain_complex_stats.st.p5)->explain_complex_stats.st.p5, funcs:firstrow(explain_complex_stats.st.p6_md5)->explain_complex_stats.st.p6_md5, funcs:firstrow(explain_complex_stats.st.p7_md5)->explain_complex_stats.st.p7_md5, funcs:firstrow(explain_complex_stats.st.ext)->explain_complex_stats.st.ext, funcs:firstrow(explain_complex_stats.st.t)->explain_complex_stats.st.t, funcs:firstrow(explain_complex_stats.dd.id)->explain_complex_stats.dd.id, funcs:firstrow(explain_complex_stats.dd.dic)->explain_complex_stats.dd.dic, funcs:firstrow(explain_complex_stats.dd.ip)->explain_complex_stats.dd.ip, funcs:firstrow(explain_complex_stats.dd.t)->explain_complex_stats.dd.t + └─HashJoin 424.00 root inner join, equal:[eq(explain_complex_stats.st.aid, explain_complex_stats.dd.aid) eq(explain_complex_stats.st.ip, explain_complex_stats.dd.ip)], other cond:gt(explain_complex_stats.dd.t, explain_complex_stats.st.t) + ├─TableReader(Build) 424.00 root data:Selection + │ └─Selection 424.00 cop[tikv] eq(explain_complex_stats.st.bm, 0), eq(explain_complex_stats.st.pt, "android"), gt(explain_complex_stats.st.t, 1478143908), not(isnull(explain_complex_stats.st.ip)) │ └─TableFullScan 1999.00 cop[tikv] table:gad keep order:false, stats:partial[t:missing] - └─TableReader(Probe) 501.17 root data:Selection - └─Selection 501.17 cop[tikv] eq(explain_complex_stats.dd.bm, 0), eq(explain_complex_stats.dd.pt, "android"), gt(explain_complex_stats.dd.t, 1478143908), not(isnull(explain_complex_stats.dd.ip)), not(isnull(explain_complex_stats.dd.t)) + └─TableReader(Probe) 450.56 root data:Selection + └─Selection 450.56 cop[tikv] eq(explain_complex_stats.dd.bm, 0), eq(explain_complex_stats.dd.pt, "android"), gt(explain_complex_stats.dd.t, 1478143908), not(isnull(explain_complex_stats.dd.ip)), not(isnull(explain_complex_stats.dd.t)) └─TableFullScan 2000.00 cop[tikv] table:dd keep order:false, stats:partial[ip:missing, t:missing] explain format = 'brief' select gad.id as gid,sdk.id as sid,gad.aid as aid,gad.cm as cm,sdk.dic as dic,sdk.ip as ip, sdk.t as t, gad.p1 as p1, gad.p2 as p2, gad.p3 as p3, gad.p4 as p4, gad.p5 as p5, gad.p6_md5 as p6, gad.p7_md5 as p7, gad.ext as ext from st gad join dd sdk on gad.aid = sdk.aid and gad.dic = sdk.mac and gad.t < sdk.t where gad.t > 1477971479 and gad.bm = 0 and gad.pt = 'ios' and gad.dit = 'mac' and sdk.t > 1477971479 and sdk.bm = 0 and sdk.pt = 'ios' limit 3000; id estRows task access object operator info @@ -177,15 +177,16 @@ Projection 39.28 root explain_complex_stats.st.cm, explain_complex_stats.st.p1, 
└─TableRowIDScan 160.23 cop[tikv] table:st keep order:false, stats:partial[t:missing] explain format = 'brief' select dt.id as id, dt.aid as aid, dt.pt as pt, dt.dic as dic, dt.cm as cm, rr.gid as gid, rr.acd as acd, rr.t as t,dt.p1 as p1, dt.p2 as p2, dt.p3 as p3, dt.p4 as p4, dt.p5 as p5, dt.p6_md5 as p6, dt.p7_md5 as p7 from dt dt join rr rr on (rr.pt = 'ios' and rr.t > 1478185592 and dt.aid = rr.aid and dt.dic = rr.dic) where dt.pt = 'ios' and dt.t > 1478185592 and dt.bm = 0 limit 2000; id estRows task access object operator info -Projection 476.61 root explain_complex_stats.dt.id, explain_complex_stats.dt.aid, explain_complex_stats.dt.pt, explain_complex_stats.dt.dic, explain_complex_stats.dt.cm, explain_complex_stats.rr.gid, explain_complex_stats.rr.acd, explain_complex_stats.rr.t, explain_complex_stats.dt.p1, explain_complex_stats.dt.p2, explain_complex_stats.dt.p3, explain_complex_stats.dt.p4, explain_complex_stats.dt.p5, explain_complex_stats.dt.p6_md5, explain_complex_stats.dt.p7_md5 -└─Limit 476.61 root offset:0, count:2000 - └─HashJoin 476.61 root inner join, equal:[eq(explain_complex_stats.dt.aid, explain_complex_stats.rr.aid) eq(explain_complex_stats.dt.dic, explain_complex_stats.rr.dic)] - ├─TableReader(Build) 476.61 root data:Selection - │ └─Selection 476.61 cop[tikv] eq(explain_complex_stats.dt.bm, 0), eq(explain_complex_stats.dt.pt, "ios"), gt(explain_complex_stats.dt.t, 1478185592), not(isnull(explain_complex_stats.dt.dic)) +Projection 428.32 root explain_complex_stats.dt.id, explain_complex_stats.dt.aid, explain_complex_stats.dt.pt, explain_complex_stats.dt.dic, explain_complex_stats.dt.cm, explain_complex_stats.rr.gid, explain_complex_stats.rr.acd, explain_complex_stats.rr.t, explain_complex_stats.dt.p1, explain_complex_stats.dt.p2, explain_complex_stats.dt.p3, explain_complex_stats.dt.p4, explain_complex_stats.dt.p5, explain_complex_stats.dt.p6_md5, explain_complex_stats.dt.p7_md5 +└─Limit 428.32 root offset:0, count:2000 + └─IndexJoin 428.32 root inner join, inner:IndexLookUp, outer key:explain_complex_stats.dt.aid, explain_complex_stats.dt.dic, inner key:explain_complex_stats.rr.aid, explain_complex_stats.rr.dic, equal cond:eq(explain_complex_stats.dt.aid, explain_complex_stats.rr.aid), eq(explain_complex_stats.dt.dic, explain_complex_stats.rr.dic) + ├─TableReader(Build) 428.32 root data:Selection + │ └─Selection 428.32 cop[tikv] eq(explain_complex_stats.dt.bm, 0), eq(explain_complex_stats.dt.pt, "ios"), gt(explain_complex_stats.dt.t, 1478185592), not(isnull(explain_complex_stats.dt.dic)) │ └─TableFullScan 2000.00 cop[tikv] table:dt keep order:false - └─TableReader(Probe) 970.00 root data:Selection - └─Selection 970.00 cop[tikv] eq(explain_complex_stats.rr.pt, "ios"), gt(explain_complex_stats.rr.t, 1478185592) - └─TableFullScan 2000.00 cop[tikv] table:rr keep order:false + └─IndexLookUp(Probe) 428.32 root + ├─IndexRangeScan(Build) 428.32 cop[tikv] table:rr, index:PRIMARY(aid, dic) range: decided by [eq(explain_complex_stats.rr.aid, explain_complex_stats.dt.aid) eq(explain_complex_stats.rr.dic, explain_complex_stats.dt.dic)], keep order:false + └─Selection(Probe) 428.32 cop[tikv] eq(explain_complex_stats.rr.pt, "ios"), gt(explain_complex_stats.rr.t, 1478185592) + └─TableRowIDScan 428.32 cop[tikv] table:rr keep order:false explain format = 'brief' select pc,cr,count(DISTINCT uid) as pay_users,count(oid) as pay_times,sum(am) as am from pp where ps=2 and ppt>=1478188800 and ppt<1478275200 and pi in ('510017','520017') and uid in ('18089709','18090780') group by pc,cr; 
id estRows task access object operator info Projection 207.02 root explain_complex_stats.pp.pc, explain_complex_stats.pp.cr, Column#22, Column#23, Column#24 diff --git a/tests/integrationtest/r/explain_easy_stats.result b/tests/integrationtest/r/explain_easy_stats.result index 7be6d7c22aa2e..8c26bcf04b8be 100644 --- a/tests/integrationtest/r/explain_easy_stats.result +++ b/tests/integrationtest/r/explain_easy_stats.result @@ -51,8 +51,8 @@ HashJoin 2481.25 root left outer join, equal:[eq(explain_easy_stats.t1.c2, expl ├─TableReader(Build) 1985.00 root data:Selection │ └─Selection 1985.00 cop[tikv] not(isnull(explain_easy_stats.t2.c1)) │ └─TableFullScan 1985.00 cop[tikv] table:t2 keep order:false, stats:partial[c1:missing] -└─TableReader(Probe) 1999.00 root data:TableRangeScan - └─TableRangeScan 1999.00 cop[tikv] table:t1 range:(1,+inf], keep order:false +└─TableReader(Probe) 1998.00 root data:TableRangeScan + └─TableRangeScan 1998.00 cop[tikv] table:t1 range:(1,+inf], keep order:false explain format = 'brief' update t1 set t1.c2 = 2 where t1.c1 = 1; id estRows task access object operator info Update N/A root N/A @@ -87,7 +87,7 @@ IndexLookUp 0.00 root └─TableRowIDScan 0.00 cop[tikv] table:t1 keep order:false, stats:partial[c2:missing] explain format = 'brief' select * from t1 where c1 = 1 and c2 > 1; id estRows task access object operator info -Selection 1.00 root gt(explain_easy_stats.t1.c2, 1) +Selection 0.50 root gt(explain_easy_stats.t1.c2, 1) └─Point_Get 1.00 root table:t1 handle:1 explain format = 'brief' select c1 from t1 where c1 in (select c2 from t2); id estRows task access object operator info diff --git a/tests/integrationtest/r/explain_indexmerge_stats.result b/tests/integrationtest/r/explain_indexmerge_stats.result index 1d7f41261838b..bd27123d6a92f 100644 --- a/tests/integrationtest/r/explain_indexmerge_stats.result +++ b/tests/integrationtest/r/explain_indexmerge_stats.result @@ -7,111 +7,137 @@ create index td on t (d); load stats 's/explain_indexmerge_stats_t.json'; explain format = 'brief' select * from t where a < 50 or b < 50; id estRows task access object operator info -TableReader 3750049.00 root data:Selection -└─Selection 3750049.00 cop[tikv] or(lt(explain_indexmerge_stats.t.a, 50), lt(explain_indexmerge_stats.t.b, 50)) - └─TableFullScan 5000000.00 cop[tikv] table:t keep order:false +IndexMerge 98.00 root type: union +├─TableRangeScan(Build) 49.00 cop[tikv] table:t range:[-inf,50), keep order:false +├─IndexRangeScan(Build) 49.00 cop[tikv] table:t, index:tb(b) range:[-inf,50), keep order:false +└─TableRowIDScan(Probe) 98.00 cop[tikv] table:t keep order:false explain format = 'brief' select * from t where (a < 50 or b < 50) and f > 100; id estRows task access object operator info -TableReader 3750049.00 root data:Selection -└─Selection 3750049.00 cop[tikv] gt(explain_indexmerge_stats.t.f, 100), or(lt(explain_indexmerge_stats.t.a, 50), lt(explain_indexmerge_stats.t.b, 50)) - └─TableFullScan 5000000.00 cop[tikv] table:t keep order:false +IndexMerge 98.00 root type: union +├─TableRangeScan(Build) 49.00 cop[tikv] table:t range:[-inf,50), keep order:false +├─IndexRangeScan(Build) 49.00 cop[tikv] table:t, index:tb(b) range:[-inf,50), keep order:false +└─Selection(Probe) 98.00 cop[tikv] gt(explain_indexmerge_stats.t.f, 100) + └─TableRowIDScan 98.00 cop[tikv] table:t keep order:false explain format = 'brief' select * from t where b < 50 or c < 50; id estRows task access object operator info -TableReader 3750049.00 root data:Selection -└─Selection 3750049.00 cop[tikv] 
or(lt(explain_indexmerge_stats.t.b, 50), lt(explain_indexmerge_stats.t.c, 50)) - └─TableFullScan 5000000.00 cop[tikv] table:t keep order:false +IndexMerge 98.00 root type: union +├─IndexRangeScan(Build) 49.00 cop[tikv] table:t, index:tb(b) range:[-inf,50), keep order:false +├─IndexRangeScan(Build) 49.00 cop[tikv] table:t, index:tc(c) range:[-inf,50), keep order:false +└─TableRowIDScan(Probe) 98.00 cop[tikv] table:t keep order:false set session tidb_enable_index_merge = on; explain format = 'brief' select * from t where a < 50 or b < 50; id estRows task access object operator info -TableReader 3750049.00 root data:Selection -└─Selection 3750049.00 cop[tikv] or(lt(explain_indexmerge_stats.t.a, 50), lt(explain_indexmerge_stats.t.b, 50)) - └─TableFullScan 5000000.00 cop[tikv] table:t keep order:false +IndexMerge 98.00 root type: union +├─TableRangeScan(Build) 49.00 cop[tikv] table:t range:[-inf,50), keep order:false +├─IndexRangeScan(Build) 49.00 cop[tikv] table:t, index:tb(b) range:[-inf,50), keep order:false +└─TableRowIDScan(Probe) 98.00 cop[tikv] table:t keep order:false explain format = 'brief' select * from t where (a < 50 or b < 50) and f > 100; id estRows task access object operator info -TableReader 3750049.00 root data:Selection -└─Selection 3750049.00 cop[tikv] gt(explain_indexmerge_stats.t.f, 100), or(lt(explain_indexmerge_stats.t.a, 50), lt(explain_indexmerge_stats.t.b, 50)) - └─TableFullScan 5000000.00 cop[tikv] table:t keep order:false +IndexMerge 98.00 root type: union +├─TableRangeScan(Build) 49.00 cop[tikv] table:t range:[-inf,50), keep order:false +├─IndexRangeScan(Build) 49.00 cop[tikv] table:t, index:tb(b) range:[-inf,50), keep order:false +└─Selection(Probe) 98.00 cop[tikv] gt(explain_indexmerge_stats.t.f, 100) + └─TableRowIDScan 98.00 cop[tikv] table:t keep order:false explain format = 'brief' select * from t where a < 50 or b < 5000000; id estRows task access object operator info -TableReader 5000000.00 root data:Selection -└─Selection 5000000.00 cop[tikv] or(lt(explain_indexmerge_stats.t.a, 50), lt(explain_indexmerge_stats.t.b, 5000000)) +TableReader 4999999.00 root data:Selection +└─Selection 4999999.00 cop[tikv] or(lt(explain_indexmerge_stats.t.a, 50), lt(explain_indexmerge_stats.t.b, 5000000)) └─TableFullScan 5000000.00 cop[tikv] table:t keep order:false explain format = 'brief' select * from t where b < 50 or c < 50; id estRows task access object operator info -TableReader 3750049.00 root data:Selection -└─Selection 3750049.00 cop[tikv] or(lt(explain_indexmerge_stats.t.b, 50), lt(explain_indexmerge_stats.t.c, 50)) - └─TableFullScan 5000000.00 cop[tikv] table:t keep order:false +IndexMerge 98.00 root type: union +├─IndexRangeScan(Build) 49.00 cop[tikv] table:t, index:tb(b) range:[-inf,50), keep order:false +├─IndexRangeScan(Build) 49.00 cop[tikv] table:t, index:tc(c) range:[-inf,50), keep order:false +└─TableRowIDScan(Probe) 98.00 cop[tikv] table:t keep order:false explain format = 'brief' select * from t where b < 50 or c < 5000000; id estRows task access object operator info -TableReader 5000000.00 root data:Selection -└─Selection 5000000.00 cop[tikv] or(lt(explain_indexmerge_stats.t.b, 50), lt(explain_indexmerge_stats.t.c, 5000000)) +TableReader 4999999.00 root data:Selection +└─Selection 4999999.00 cop[tikv] or(lt(explain_indexmerge_stats.t.b, 50), lt(explain_indexmerge_stats.t.c, 5000000)) └─TableFullScan 5000000.00 cop[tikv] table:t keep order:false explain format = 'brief' select * from t where a < 50 or b < 50 or c < 50; id estRows task access object 
operator info -TableReader 4375036.75 root data:Selection -└─Selection 4375036.75 cop[tikv] or(lt(explain_indexmerge_stats.t.a, 50), or(lt(explain_indexmerge_stats.t.b, 50), lt(explain_indexmerge_stats.t.c, 50))) - └─TableFullScan 5000000.00 cop[tikv] table:t keep order:false +IndexMerge 147.00 root type: union +├─TableRangeScan(Build) 49.00 cop[tikv] table:t range:[-inf,50), keep order:false +├─IndexRangeScan(Build) 49.00 cop[tikv] table:t, index:tb(b) range:[-inf,50), keep order:false +├─IndexRangeScan(Build) 49.00 cop[tikv] table:t, index:tc(c) range:[-inf,50), keep order:false +└─TableRowIDScan(Probe) 147.00 cop[tikv] table:t keep order:false explain format = 'brief' select * from t where (b < 10000 or c < 10000) and (a < 10 or d < 10) and f < 10; id estRows task access object operator info -TableReader 1409331.42 root data:Selection -└─Selection 1409331.42 cop[tikv] lt(explain_indexmerge_stats.t.f, 10), or(lt(explain_indexmerge_stats.t.a, 10), lt(explain_indexmerge_stats.t.d, 10)), or(lt(explain_indexmerge_stats.t.b, 10000), lt(explain_indexmerge_stats.t.c, 10000)) - └─TableFullScan 5000000.00 cop[tikv] table:t keep order:false +IndexMerge 0.00 root type: union +├─TableRangeScan(Build) 9.00 cop[tikv] table:t range:[-inf,10), keep order:false +├─IndexRangeScan(Build) 9.00 cop[tikv] table:t, index:td(d) range:[-inf,10), keep order:false +└─Selection(Probe) 0.00 cop[tikv] lt(explain_indexmerge_stats.t.f, 10), or(lt(explain_indexmerge_stats.t.b, 10000), lt(explain_indexmerge_stats.t.c, 10000)) + └─TableRowIDScan 18.00 cop[tikv] table:t keep order:false explain format="dot" select * from t where (a < 50 or b < 50) and f > 100; dot contents -digraph TableReader_7 { -subgraph cluster7{ +digraph IndexMerge_12 { +subgraph cluster12{ node [style=filled, color=lightgrey] color=black label = "root" -"TableReader_7" +"IndexMerge_12" +} +subgraph cluster8{ +node [style=filled, color=lightgrey] +color=black +label = "cop" +"TableRangeScan_8" +} +subgraph cluster9{ +node [style=filled, color=lightgrey] +color=black +label = "cop" +"IndexRangeScan_9" } -subgraph cluster6{ +subgraph cluster11{ node [style=filled, color=lightgrey] color=black label = "cop" -"Selection_6" -> "TableFullScan_5" +"Selection_11" -> "TableRowIDScan_10" } -"TableReader_7" -> "Selection_6" +"IndexMerge_12" -> "TableRangeScan_8" +"IndexMerge_12" -> "IndexRangeScan_9" +"IndexMerge_12" -> "Selection_11" } set session tidb_enable_index_merge = off; explain format = 'brief' select /*+ use_index_merge(t, primary, tb, tc) */ * from t where a <= 500000 or b <= 1000000 or c <= 3000000; id estRows task access object operator info -IndexMerge 5000000.00 root type: union -├─TableRangeScan(Build) 3000000.00 cop[tikv] table:t range:[-inf,500000], keep order:false -├─IndexRangeScan(Build) 3500000.00 cop[tikv] table:t, index:tb(b) range:[-inf,1000000], keep order:false -├─IndexRangeScan(Build) 5000000.00 cop[tikv] table:t, index:tc(c) range:[-inf,3000000], keep order:false -└─TableRowIDScan(Probe) 5000000.00 cop[tikv] table:t keep order:false +IndexMerge 3560000.00 root type: union +├─TableRangeScan(Build) 500000.00 cop[tikv] table:t range:[-inf,500000], keep order:false +├─IndexRangeScan(Build) 1000000.00 cop[tikv] table:t, index:tb(b) range:[-inf,1000000], keep order:false +├─IndexRangeScan(Build) 3000000.00 cop[tikv] table:t, index:tc(c) range:[-inf,3000000], keep order:false +└─TableRowIDScan(Probe) 3560000.00 cop[tikv] table:t keep order:false explain format = 'brief' select /*+ use_index_merge(t, tb, tc) */ * from t where b < 50 or c < 
5000000; id estRows task access object operator info -IndexMerge 5000000.00 root type: union -├─IndexRangeScan(Build) 2500049.00 cop[tikv] table:t, index:tb(b) range:[-inf,50), keep order:false -├─IndexRangeScan(Build) 5000000.00 cop[tikv] table:t, index:tc(c) range:[-inf,5000000), keep order:false -└─TableRowIDScan(Probe) 5000000.00 cop[tikv] table:t keep order:false +IndexMerge 4999999.00 root type: union +├─IndexRangeScan(Build) 49.00 cop[tikv] table:t, index:tb(b) range:[-inf,50), keep order:false +├─IndexRangeScan(Build) 4999999.00 cop[tikv] table:t, index:tc(c) range:[-inf,5000000), keep order:false +└─TableRowIDScan(Probe) 4999999.00 cop[tikv] table:t keep order:false explain format = 'brief' select /*+ use_index_merge(t, tb, tc) */ * from t where (b < 10000 or c < 10000) and (a < 10 or d < 10) and f < 10; id estRows task access object operator info -IndexMerge 1409331.42 root type: union -├─IndexRangeScan(Build) 2509999.00 cop[tikv] table:t, index:tb(b) range:[-inf,10000), keep order:false -├─IndexRangeScan(Build) 2509999.00 cop[tikv] table:t, index:tc(c) range:[-inf,10000), keep order:false -└─Selection(Probe) 1409331.42 cop[tikv] lt(explain_indexmerge_stats.t.f, 10), or(lt(explain_indexmerge_stats.t.a, 10), lt(explain_indexmerge_stats.t.d, 10)) - └─TableRowIDScan 3759979.00 cop[tikv] table:t keep order:false +IndexMerge 0.00 root type: union +├─IndexRangeScan(Build) 9999.00 cop[tikv] table:t, index:tb(b) range:[-inf,10000), keep order:false +├─IndexRangeScan(Build) 9999.00 cop[tikv] table:t, index:tc(c) range:[-inf,10000), keep order:false +└─Selection(Probe) 0.00 cop[tikv] lt(explain_indexmerge_stats.t.f, 10), or(lt(explain_indexmerge_stats.t.a, 10), lt(explain_indexmerge_stats.t.d, 10)) + └─TableRowIDScan 19978.00 cop[tikv] table:t keep order:false explain format = 'brief' select /*+ use_index_merge(t, tb) */ * from t where b < 50 or c < 5000000; id estRows task access object operator info -TableReader 5000000.00 root data:Selection -└─Selection 5000000.00 cop[tikv] or(lt(explain_indexmerge_stats.t.b, 50), lt(explain_indexmerge_stats.t.c, 5000000)) +TableReader 4999999.00 root data:Selection +└─Selection 4999999.00 cop[tikv] or(lt(explain_indexmerge_stats.t.b, 50), lt(explain_indexmerge_stats.t.c, 5000000)) └─TableFullScan 5000000.00 cop[tikv] table:t keep order:false explain format = 'brief' select /*+ no_index_merge(), use_index_merge(t, tb, tc) */ * from t where b < 50 or c < 5000000; id estRows task access object operator info -TableReader 5000000.00 root data:Selection -└─Selection 5000000.00 cop[tikv] or(lt(explain_indexmerge_stats.t.b, 50), lt(explain_indexmerge_stats.t.c, 5000000)) +TableReader 4999999.00 root data:Selection +└─Selection 4999999.00 cop[tikv] or(lt(explain_indexmerge_stats.t.b, 50), lt(explain_indexmerge_stats.t.c, 5000000)) └─TableFullScan 5000000.00 cop[tikv] table:t keep order:false explain format = 'brief' select /*+ use_index_merge(t, primary, tb) */ * from t where a < 50 or b < 5000000; id estRows task access object operator info -IndexMerge 5000000.00 root type: union -├─TableRangeScan(Build) 2500049.00 cop[tikv] table:t range:[-inf,50), keep order:false -├─IndexRangeScan(Build) 5000000.00 cop[tikv] table:t, index:tb(b) range:[-inf,5000000), keep order:false -└─TableRowIDScan(Probe) 5000000.00 cop[tikv] table:t keep order:false +IndexMerge 4999999.00 root type: union +├─TableRangeScan(Build) 49.00 cop[tikv] table:t range:[-inf,50), keep order:false +├─IndexRangeScan(Build) 4999999.00 cop[tikv] table:t, index:tb(b) range:[-inf,5000000), keep 
order:false +└─TableRowIDScan(Probe) 4999999.00 cop[tikv] table:t keep order:false set session tidb_enable_index_merge = on; drop table if exists t; CREATE TABLE t ( diff --git a/tests/integrationtest/r/imdbload.result b/tests/integrationtest/r/imdbload.result index 53df922885088..2642b0837296c 100644 --- a/tests/integrationtest/r/imdbload.result +++ b/tests/integrationtest/r/imdbload.result @@ -286,7 +286,7 @@ IndexLookUp_7 1005030.94 root └─TableRowIDScan_6(Probe) 1005030.94 cop[tikv] table:char_name keep order:false trace plan target = 'estimation' select * from char_name where ((imdb_index = 'I') and (surname_pcode < 'E436')) or ((imdb_index = 'L') and (surname_pcode < 'E436')); CE_trace -[{"table_name":"char_name","type":"Column Stats-Point","expr":"((imdb_index = 'I'))","row_count":1},{"table_name":"char_name","type":"Column Stats-Point","expr":"((imdb_index = 'L'))","row_count":1},{"table_name":"char_name","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":4314864},{"table_name":"char_name","type":"Column Stats-Range","expr":"((surname_pcode < 'E436'))","row_count":1780915},{"table_name":"char_name","type":"Index Stats-Range","expr":"((imdb_index = 'I') and (surname_pcode < 'E436')) or ((imdb_index = 'L') and (surname_pcode < 'E436'))","row_count":2},{"table_name":"char_name","type":"Index Stats-Range","expr":"((surname_pcode < 'E436'))","row_count":1005030},{"table_name":"char_name","type":"Table Stats-Expression-CNF","expr":"`or`(`and`(`eq`(imdbload.char_name.imdb_index, 'I'), `lt`(imdbload.char_name.surname_pcode, 'E436')), `and`(`eq`(imdbload.char_name.imdb_index, 'L'), `lt`(imdbload.char_name.surname_pcode, 'E436')))","row_count":804024}] +[{"table_name":"char_name","type":"Column Stats-Point","expr":"((imdb_index = 'I'))","row_count":1},{"table_name":"char_name","type":"Column Stats-Point","expr":"((imdb_index = 'L'))","row_count":1},{"table_name":"char_name","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":4314864},{"table_name":"char_name","type":"Column Stats-Range","expr":"((surname_pcode < 'E436'))","row_count":1005030},{"table_name":"char_name","type":"Index Stats-Range","expr":"((imdb_index = 'I') and (surname_pcode < 'E436')) or ((imdb_index = 'L') and (surname_pcode < 'E436'))","row_count":2},{"table_name":"char_name","type":"Index Stats-Range","expr":"((surname_pcode < 'E436'))","row_count":1005030},{"table_name":"char_name","type":"Table Stats-Expression-CNF","expr":"`or`(`and`(`eq`(imdbload.char_name.imdb_index, 'I'), `lt`(imdbload.char_name.surname_pcode, 'E436')), `and`(`eq`(imdbload.char_name.imdb_index, 'L'), `lt`(imdbload.char_name.surname_pcode, 'E436')))","row_count":804024}] explain select * from char_name where ((imdb_index = 'V') and (surname_pcode < 'L3416')); id estRows task access object operator info @@ -309,7 +309,7 @@ IndexLookUp_10 1.00 root └─TableRowIDScan_9(Probe) 1.00 cop[tikv] table:movie_companies keep order:false trace plan target = 'estimation' select * from movie_companies where company_type_id > 2; CE_trace -[{"table_name":"movie_companies","type":"Column Stats-Range","expr":"((company_type_id > 2 and true))","row_count":2479148},{"table_name":"movie_companies","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":4958296},{"table_name":"movie_companies","type":"Index Stats-Range","expr":"((company_type_id > 2 and 
true))","row_count":1},{"table_name":"movie_companies","type":"Table Stats-Expression-CNF","expr":"`gt`(imdbload.movie_companies.company_type_id, 2)","row_count":1}] +[{"table_name":"movie_companies","type":"Column Stats-Range","expr":"((company_type_id > 2 and true))","row_count":1},{"table_name":"movie_companies","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":4958296},{"table_name":"movie_companies","type":"Index Stats-Range","expr":"((company_type_id > 2 and true))","row_count":1},{"table_name":"movie_companies","type":"Table Stats-Expression-CNF","expr":"`gt`(imdbload.movie_companies.company_type_id, 2)","row_count":1}] explain select * from char_name where imdb_index > 'I' and imdb_index < 'II'; id estRows task access object operator info @@ -346,7 +346,7 @@ IndexLookUp_10 1.00 root └─TableRowIDScan_9(Probe) 1.00 cop[tikv] table:aka_title keep order:false trace plan target = 'estimation' select * from aka_title where kind_id > 7; CE_trace -[{"table_name":"aka_title","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":528337},{"table_name":"aka_title","type":"Column Stats-Range","expr":"((kind_id > 7 and true))","row_count":51390},{"table_name":"aka_title","type":"Index Stats-Range","expr":"((kind_id > 7 and true))","row_count":1},{"table_name":"aka_title","type":"Table Stats-Expression-CNF","expr":"`gt`(imdbload.aka_title.kind_id, 7)","row_count":1}] +[{"table_name":"aka_title","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":528337},{"table_name":"aka_title","type":"Column Stats-Range","expr":"((kind_id > 7 and true))","row_count":1},{"table_name":"aka_title","type":"Index Stats-Range","expr":"((kind_id > 7 and true))","row_count":1},{"table_name":"aka_title","type":"Table Stats-Expression-CNF","expr":"`gt`(imdbload.aka_title.kind_id, 7)","row_count":1}] explain select * from keyword where ((phonetic_code = 'R1652') and (keyword > 'ecg-monitor' and keyword < 'killers')); id estRows task access object operator info @@ -366,5 +366,5 @@ IndexLookUp_11 144633.00 root └─TableRowIDScan_9 144633.00 cop[tikv] table:cast_info keep order:false trace plan target = 'estimation' select * from cast_info where (nr_order is null) and (person_role_id = 2) and (note >= '(key set pa: Florida'); CE_trace -[{"table_name":"cast_info","type":"Column Stats-Point","expr":"((nr_order is null))","row_count":45995275},{"table_name":"cast_info","type":"Column Stats-Point","expr":"((person_role_id = 2))","row_count":2089611},{"table_name":"cast_info","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":63475835},{"table_name":"cast_info","type":"Column Stats-Range","expr":"((note >= '(key set pa: Florida' and true))","row_count":18437410},{"table_name":"cast_info","type":"Index Stats-Point","expr":"((person_role_id = 2))","row_count":2089611},{"table_name":"cast_info","type":"Index Stats-Range","expr":"((nr_order is null) and (person_role_id = 2) and (note >= '(key set pa: Florida' and true))","row_count":144633},{"table_name":"cast_info","type":"Table Stats-Expression-CNF","expr":"`and`(`isnull`(imdbload.cast_info.nr_order), `and`(`eq`(imdbload.cast_info.person_role_id, 2), `ge`(imdbload.cast_info.note, '(key set pa: Florida')))","row_count":144633},{"table_name":"cast_info","type":"Table Stats-Expression-CNF","expr":"`eq`(imdbload.cast_info.person_role_id, 
2)","row_count":2089611}] +[{"table_name":"cast_info","type":"Column Stats-Point","expr":"((nr_order is null))","row_count":45995275},{"table_name":"cast_info","type":"Column Stats-Point","expr":"((person_role_id = 2))","row_count":2089611},{"table_name":"cast_info","type":"Column Stats-Range","expr":"((id >= -9223372036854775808 and id <= 9223372036854775807))","row_count":63475835},{"table_name":"cast_info","type":"Column Stats-Range","expr":"((note >= '(key set pa: Florida' and true))","row_count":14934328},{"table_name":"cast_info","type":"Index Stats-Point","expr":"((person_role_id = 2))","row_count":2089611},{"table_name":"cast_info","type":"Index Stats-Range","expr":"((nr_order is null) and (person_role_id = 2) and (note >= '(key set pa: Florida' and true))","row_count":144633},{"table_name":"cast_info","type":"Table Stats-Expression-CNF","expr":"`and`(`isnull`(imdbload.cast_info.nr_order), `and`(`eq`(imdbload.cast_info.person_role_id, 2), `ge`(imdbload.cast_info.note, '(key set pa: Florida')))","row_count":144633},{"table_name":"cast_info","type":"Table Stats-Expression-CNF","expr":"`eq`(imdbload.cast_info.person_role_id, 2)","row_count":2089611}] diff --git a/tests/integrationtest/r/planner/cardinality/selectivity.result b/tests/integrationtest/r/planner/cardinality/selectivity.result index e166e749e4708..f08265b41d00e 100644 --- a/tests/integrationtest/r/planner/cardinality/selectivity.result +++ b/tests/integrationtest/r/planner/cardinality/selectivity.result @@ -396,8 +396,8 @@ TableReader_7 1.00 root data:Selection_6 └─TableFullScan_5 4.00 cop[tikv] table:tdatetime keep order:false explain select * from tint where b=1; id estRows task access object operator info -TableReader_7 1.84 root data:Selection_6 -└─Selection_6 1.84 cop[tikv] eq(planner__cardinality__selectivity.tint.b, 1) +TableReader_7 1.00 root data:Selection_6 +└─Selection_6 1.00 cop[tikv] eq(planner__cardinality__selectivity.tint.b, 1) └─TableFullScan_5 8.00 cop[tikv] table:tint keep order:false explain select * from tint where b=4; id estRows task access object operator info @@ -406,8 +406,8 @@ TableReader_7 1.00 root data:Selection_6 └─TableFullScan_5 8.00 cop[tikv] table:tint keep order:false explain select * from tint where b=8; id estRows task access object operator info -TableReader_7 2.07 root data:Selection_6 -└─Selection_6 2.07 cop[tikv] eq(planner__cardinality__selectivity.tint.b, 8) +TableReader_7 1.00 root data:Selection_6 +└─Selection_6 1.00 cop[tikv] eq(planner__cardinality__selectivity.tint.b, 8) └─TableFullScan_5 8.00 cop[tikv] table:tint keep order:false explain select * from tdouble where b=1; id estRows task access object operator info @@ -471,8 +471,8 @@ TableReader_7 1.00 root data:Selection_6 └─TableFullScan_5 4.00 cop[tikv] table:tdatetime keep order:false explain select * from ct1 where pk>='1' and pk <='4'; id estRows task access object operator info -TableReader_6 7.40 root data:TableRangeScan_5 -└─TableRangeScan_5 7.40 cop[tikv] table:ct1 range:["1","4"], keep order:false +TableReader_6 5.00 root data:TableRangeScan_5 +└─TableRangeScan_5 5.00 cop[tikv] table:ct1 range:["1","4"], keep order:false explain select * from ct1 where pk>='4' and pk <='6'; id estRows task access object operator info TableReader_6 3.75 root data:TableRangeScan_5 diff --git a/tests/integrationtest/r/tpch.result b/tests/integrationtest/r/tpch.result index 8b89428e68194..bdb7c7f740c77 100644 --- a/tests/integrationtest/r/tpch.result +++ b/tests/integrationtest/r/tpch.result @@ -112,12 +112,12 @@ order by 
l_returnflag, l_linestatus; id estRows task access object operator info -Sort 3.00 root tpch50.lineitem.l_returnflag, tpch50.lineitem.l_linestatus -└─Projection 3.00 root tpch50.lineitem.l_returnflag, tpch50.lineitem.l_linestatus, Column#18, Column#19, Column#20, Column#21, Column#22, Column#23, Column#24, Column#25 - └─HashAgg 3.00 root group by:tpch50.lineitem.l_linestatus, tpch50.lineitem.l_returnflag, funcs:sum(Column#26)->Column#18, funcs:sum(Column#27)->Column#19, funcs:sum(Column#28)->Column#20, funcs:sum(Column#29)->Column#21, funcs:avg(Column#30, Column#31)->Column#22, funcs:avg(Column#32, Column#33)->Column#23, funcs:avg(Column#34, Column#35)->Column#24, funcs:count(Column#36)->Column#25, funcs:firstrow(tpch50.lineitem.l_returnflag)->tpch50.lineitem.l_returnflag, funcs:firstrow(tpch50.lineitem.l_linestatus)->tpch50.lineitem.l_linestatus - └─TableReader 3.00 root data:HashAgg - └─HashAgg 3.00 cop[tikv] group by:tpch50.lineitem.l_linestatus, tpch50.lineitem.l_returnflag, funcs:sum(tpch50.lineitem.l_quantity)->Column#26, funcs:sum(tpch50.lineitem.l_extendedprice)->Column#27, funcs:sum(mul(tpch50.lineitem.l_extendedprice, minus(1, tpch50.lineitem.l_discount)))->Column#28, funcs:sum(mul(mul(tpch50.lineitem.l_extendedprice, minus(1, tpch50.lineitem.l_discount)), plus(1, tpch50.lineitem.l_tax)))->Column#29, funcs:count(tpch50.lineitem.l_quantity)->Column#30, funcs:sum(tpch50.lineitem.l_quantity)->Column#31, funcs:count(tpch50.lineitem.l_extendedprice)->Column#32, funcs:sum(tpch50.lineitem.l_extendedprice)->Column#33, funcs:count(tpch50.lineitem.l_discount)->Column#34, funcs:sum(tpch50.lineitem.l_discount)->Column#35, funcs:count(1)->Column#36 - └─Selection 300005811.00 cop[tikv] le(tpch50.lineitem.l_shipdate, 1998-08-15 00:00:00.000000) +Sort 2.94 root tpch50.lineitem.l_returnflag, tpch50.lineitem.l_linestatus +└─Projection 2.94 root tpch50.lineitem.l_returnflag, tpch50.lineitem.l_linestatus, Column#18, Column#19, Column#20, Column#21, Column#22, Column#23, Column#24, Column#25 + └─HashAgg 2.94 root group by:tpch50.lineitem.l_linestatus, tpch50.lineitem.l_returnflag, funcs:sum(Column#26)->Column#18, funcs:sum(Column#27)->Column#19, funcs:sum(Column#28)->Column#20, funcs:sum(Column#29)->Column#21, funcs:avg(Column#30, Column#31)->Column#22, funcs:avg(Column#32, Column#33)->Column#23, funcs:avg(Column#34, Column#35)->Column#24, funcs:count(Column#36)->Column#25, funcs:firstrow(tpch50.lineitem.l_returnflag)->tpch50.lineitem.l_returnflag, funcs:firstrow(tpch50.lineitem.l_linestatus)->tpch50.lineitem.l_linestatus + └─TableReader 2.94 root data:HashAgg + └─HashAgg 2.94 cop[tikv] group by:tpch50.lineitem.l_linestatus, tpch50.lineitem.l_returnflag, funcs:sum(tpch50.lineitem.l_quantity)->Column#26, funcs:sum(tpch50.lineitem.l_extendedprice)->Column#27, funcs:sum(mul(tpch50.lineitem.l_extendedprice, minus(1, tpch50.lineitem.l_discount)))->Column#28, funcs:sum(mul(mul(tpch50.lineitem.l_extendedprice, minus(1, tpch50.lineitem.l_discount)), plus(1, tpch50.lineitem.l_tax)))->Column#29, funcs:count(tpch50.lineitem.l_quantity)->Column#30, funcs:sum(tpch50.lineitem.l_quantity)->Column#31, funcs:count(tpch50.lineitem.l_extendedprice)->Column#32, funcs:sum(tpch50.lineitem.l_extendedprice)->Column#33, funcs:count(tpch50.lineitem.l_discount)->Column#34, funcs:sum(tpch50.lineitem.l_discount)->Column#35, funcs:count(1)->Column#36 + └─Selection 293797075.24 cop[tikv] le(tpch50.lineitem.l_shipdate, 1998-08-15 00:00:00.000000) └─TableFullScan 300005811.00 cop[tikv] table:lineitem keep order:false /* Q2 Minimum 
Cost Supplier Query @@ -245,20 +245,20 @@ limit 10; id estRows task access object operator info Projection 10.00 root tpch50.lineitem.l_orderkey, Column#35, tpch50.orders.o_orderdate, tpch50.orders.o_shippriority └─TopN 10.00 root Column#35:desc, tpch50.orders.o_orderdate, offset:0, count:10 - └─HashAgg 73908224.00 root group by:Column#45, Column#46, Column#47, funcs:sum(Column#44)->Column#35, funcs:firstrow(Column#45)->tpch50.orders.o_orderdate, funcs:firstrow(Column#46)->tpch50.orders.o_shippriority, funcs:firstrow(Column#47)->tpch50.lineitem.l_orderkey + └─HashAgg 39877917.22 root group by:Column#45, Column#46, Column#47, funcs:sum(Column#44)->Column#35, funcs:firstrow(Column#45)->tpch50.orders.o_orderdate, funcs:firstrow(Column#46)->tpch50.orders.o_shippriority, funcs:firstrow(Column#47)->tpch50.lineitem.l_orderkey └─Projection 93262952.04 root mul(tpch50.lineitem.l_extendedprice, minus(1, tpch50.lineitem.l_discount))->Column#44, tpch50.orders.o_orderdate->Column#45, tpch50.orders.o_shippriority->Column#46, tpch50.lineitem.l_orderkey->Column#47 └─IndexHashJoin 93262952.04 root inner join, inner:IndexLookUp, outer key:tpch50.orders.o_orderkey, inner key:tpch50.lineitem.l_orderkey, equal cond:eq(tpch50.orders.o_orderkey, tpch50.lineitem.l_orderkey) ├─HashJoin(Build) 22975885.46 root inner join, equal:[eq(tpch50.customer.c_custkey, tpch50.orders.o_custkey)] │ ├─TableReader(Build) 1508884.60 root data:Selection │ │ └─Selection 1508884.60 cop[tikv] eq(tpch50.customer.c_mktsegment, "AUTOMOBILE") │ │ └─TableFullScan 7500000.00 cop[tikv] table:customer keep order:false - │ └─TableReader(Probe) 73445478.60 root data:Selection - │ └─Selection 73445478.60 cop[tikv] lt(tpch50.orders.o_orderdate, 1995-03-13 00:00:00.000000) + │ └─TableReader(Probe) 36496376.60 root data:Selection + │ └─Selection 36496376.60 cop[tikv] lt(tpch50.orders.o_orderdate, 1995-03-13 00:00:00.000000) │ └─TableFullScan 75000000.00 cop[tikv] table:orders keep order:false └─IndexLookUp(Probe) 93262952.04 root - ├─IndexRangeScan(Build) 93262952.04 cop[tikv] table:lineitem, index:PRIMARY(L_ORDERKEY, L_LINENUMBER) range: decided by [eq(tpch50.lineitem.l_orderkey, tpch50.orders.o_orderkey)], keep order:false + ├─IndexRangeScan(Build) 172850029.03 cop[tikv] table:lineitem, index:PRIMARY(L_ORDERKEY, L_LINENUMBER) range: decided by [eq(tpch50.lineitem.l_orderkey, tpch50.orders.o_orderkey)], keep order:false └─Selection(Probe) 93262952.04 cop[tikv] gt(tpch50.lineitem.l_shipdate, 1995-03-13 00:00:00.000000) - └─TableRowIDScan 93262952.04 cop[tikv] table:lineitem keep order:false + └─TableRowIDScan 172850029.03 cop[tikv] table:lineitem keep order:false /* Q4 Order Priority Checking Query This query determines how well the order priority system is working and gives an assessment of customer satisfaction. 
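Most of the expected-output churn above traces back to the new lower bound on row estimates: once a range that used to estimate 0 rows is floored at 1 row, the 0.00 operators in the old plans become 1.00, selectivities shift slightly, and the cost model can pick different access paths in the affected .result files. The snippet below is a minimal, self-contained sketch of that clamping step. The clamp helper is a stand-in written for illustration, assuming the same (value, low, high) argument order as the mathutil.Clamp calls in the diffs; it is not the actual TiDB code path.

package main

import "fmt"

// clamp is a stand-in for illustration only; it mirrors the
// (value, low, high) argument order of the mathutil.Clamp calls above.
func clamp(v, low, high float64) float64 {
    if v < low {
        return low
    }
    if v > high {
        return high
    }
    return v
}

func main() {
    realtimeRowCount := 500000.0

    // Old behaviour: an estimate of zero rows survives as zero.
    fmt.Println(clamp(0, 0, realtimeRowCount)) // prints 0

    // New behaviour: the same estimate is floored at one row, which is
    // what turns the 0.00 rows in the old expected plans into 1.00.
    fmt.Println(clamp(0, 1, realtimeRowCount)) // prints 1
}

A later patch in this series adds a fix control (47400) that can switch back to the old zero-row behaviour, so the floor of 1 is the default rather than an unconditional rule.
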
From 967f09df25b726a7ccd21ee49dcabafe94878723 Mon Sep 17 00:00:00 2001 From: tpp Date: Mon, 12 Aug 2024 19:28:59 -0700 Subject: [PATCH 28/35] after modify test4 --- pkg/planner/cardinality/selectivity_test.go | 4 ++-- .../core/casetest/cbotest/testdata/analyze_suite_out.json | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/planner/cardinality/selectivity_test.go b/pkg/planner/cardinality/selectivity_test.go index 1b5ed9fc160d1..926ad5839df9f 100644 --- a/pkg/planner/cardinality/selectivity_test.go +++ b/pkg/planner/cardinality/selectivity_test.go @@ -402,8 +402,8 @@ func TestSelectivity(t *testing.T) { }, { exprs: "a >= 1 and b > 1 and a < 2", - selectivity: 0.01851851852, - selectivityAfterIncrease: 0.01851851852, + selectivity: 0.017832647462277088, + selectivityAfterIncrease: 0.017832647462277088, }, { exprs: "a >= 1 and c > 1 and a < 2", diff --git a/pkg/planner/core/casetest/cbotest/testdata/analyze_suite_out.json b/pkg/planner/core/casetest/cbotest/testdata/analyze_suite_out.json index 9f880ee963a16..41838a6310fd7 100644 --- a/pkg/planner/core/casetest/cbotest/testdata/analyze_suite_out.json +++ b/pkg/planner/core/casetest/cbotest/testdata/analyze_suite_out.json @@ -435,7 +435,7 @@ "Cases": [ "IndexReader(Index(t.e)[[NULL,+inf]]->StreamAgg)->StreamAgg", "IndexReader(Index(t.e)[[-inf,10]]->StreamAgg)->StreamAgg", - "IndexReader(Index(t.e)[[-inf,50]]->HashAgg)->HashAgg", + "IndexReader(Index(t.e)[[-inf,50]]->StreamAgg)->StreamAgg", "IndexReader(Index(t.b_c)[[NULL,+inf]]->Sel([gt(test.t.c, 1)])->StreamAgg)->StreamAgg", "IndexLookUp(Index(t.e)[[1,1]], Table(t))->HashAgg", "TableReader(Table(t)->Sel([gt(test.t.e, 1)])->HashAgg)->HashAgg", @@ -503,7 +503,7 @@ "TopN 1.00 root test.t.b, offset:0, count:1", "└─TableReader 1.00 root data:TopN", " └─TopN 1.00 cop[tikv] test.t.b, offset:0, count:1", - " └─Selection 510000.00 cop[tikv] le(test.t.a, 10000)", + " └─Selection 10000.00 cop[tikv] le(test.t.a, 10000)", " └─TableFullScan 1000000.00 cop[tikv] table:t keep order:false" ] }, From 94f5f3448fa65cdb7ac12367557411030eaef53b Mon Sep 17 00:00:00 2001 From: tpp Date: Mon, 12 Aug 2024 19:40:24 -0700 Subject: [PATCH 29/35] after modify test5 --- pkg/planner/cardinality/selectivity_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/planner/cardinality/selectivity_test.go b/pkg/planner/cardinality/selectivity_test.go index 926ad5839df9f..6e42fd518c154 100644 --- a/pkg/planner/cardinality/selectivity_test.go +++ b/pkg/planner/cardinality/selectivity_test.go @@ -934,7 +934,7 @@ func TestIssue39593(t *testing.T) { count, err := cardinality.GetRowCountByIndexRanges(sctx.GetPlanCtx(), &statsTbl.HistColl, idxID, getRanges(vals, vals)) require.NoError(t, err) // estimated row count without any changes - require.Equal(t, float64(540), count) + require.Equal(t, float64(360), count) statsTbl.RealtimeCount *= 10 count, err = cardinality.GetRowCountByIndexRanges(sctx.GetPlanCtx(), &statsTbl.HistColl, idxID, getRanges(vals, vals)) require.NoError(t, err) From 2ff1a642a83ecbec200327732cfa1d2052b4830b Mon Sep 17 00:00:00 2001 From: tpp Date: Mon, 12 Aug 2024 21:41:30 -0700 Subject: [PATCH 30/35] after modify tes65 --- pkg/planner/cardinality/selectivity_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/planner/cardinality/selectivity_test.go b/pkg/planner/cardinality/selectivity_test.go index 6e42fd518c154..4ae1c8c7ffbfc 100644 --- a/pkg/planner/cardinality/selectivity_test.go +++ b/pkg/planner/cardinality/selectivity_test.go @@ 
-403,7 +403,7 @@ func TestSelectivity(t *testing.T) { { exprs: "a >= 1 and b > 1 and a < 2", selectivity: 0.017832647462277088, - selectivityAfterIncrease: 0.017832647462277088, + selectivityAfterIncrease: 0.018518518518518517, }, { exprs: "a >= 1 and c > 1 and a < 2", @@ -939,7 +939,7 @@ func TestIssue39593(t *testing.T) { count, err = cardinality.GetRowCountByIndexRanges(sctx.GetPlanCtx(), &statsTbl.HistColl, idxID, getRanges(vals, vals)) require.NoError(t, err) // estimated row count after mock modify on the table - require.Equal(t, float64(3870.1135540008545), count) + require.Equal(t, float64(3600), count) } func TestIndexJoinInnerRowCountUpperBound(t *testing.T) { From 748b0a4cd82801b25bf9e093d36f174be8715cf9 Mon Sep 17 00:00:00 2001 From: tpp Date: Tue, 13 Aug 2024 07:40:02 -0700 Subject: [PATCH 31/35] after modify test7 --- pkg/planner/cardinality/selectivity_test.go | 6 +++--- pkg/planner/cardinality/testdata/cardinality_suite_out.json | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pkg/planner/cardinality/selectivity_test.go b/pkg/planner/cardinality/selectivity_test.go index 4ae1c8c7ffbfc..459cfd8cd7f7b 100644 --- a/pkg/planner/cardinality/selectivity_test.go +++ b/pkg/planner/cardinality/selectivity_test.go @@ -422,13 +422,13 @@ func TestSelectivity(t *testing.T) { }, { exprs: "b > 1", - selectivity: 1, + selectivity: 0.9629629629629629, selectivityAfterIncrease: 1, }, { exprs: "a > 1 and b < 2 and c > 3 and d < 4 and e > 5", - selectivity: 0.00099451303, - selectivityAfterIncrease: 3.772290809327846e-05, + selectivity: 5.870830440255832e-05, + selectivityAfterIncrease: 1.51329827770157e-05, }, { exprs: longExpr, diff --git a/pkg/planner/cardinality/testdata/cardinality_suite_out.json b/pkg/planner/cardinality/testdata/cardinality_suite_out.json index 029e9916020ae..cfae3231c3f52 100644 --- a/pkg/planner/cardinality/testdata/cardinality_suite_out.json +++ b/pkg/planner/cardinality/testdata/cardinality_suite_out.json @@ -24,7 +24,7 @@ { "Start": 800, "End": 900, - "Count": 776.004166655054 + "Count": 791.004166655054 }, { "Start": 900, @@ -79,7 +79,7 @@ { "Start": 800, "End": 1000, - "Count": 1234.196869573942 + "Count": 1249.196869573942 }, { "Start": 900, @@ -104,7 +104,7 @@ { "Start": 200, "End": 400, - "Count": 1215.0288209899081 + "Count": 1188.7788209899081 }, { "Start": 200, From 5478f9fc4a02813d875c8c5ca4941b0e0c8a2c30 Mon Sep 17 00:00:00 2001 From: tpp Date: Tue, 13 Aug 2024 12:55:48 -0700 Subject: [PATCH 32/35] add fix control --- pkg/planner/cardinality/row_count_column.go | 14 ++++++++++++-- pkg/planner/cardinality/row_count_index.go | 14 ++++++++++++-- pkg/planner/util/fixcontrol/get.go | 4 ++++ .../r/planner/cardinality/selectivity.result | 6 ++++++ .../t/planner/cardinality/selectivity.test | 3 +++ 5 files changed, 37 insertions(+), 4 deletions(-) diff --git a/pkg/planner/cardinality/row_count_column.go b/pkg/planner/cardinality/row_count_column.go index e38f8da4146c3..1f077e513cc54 100644 --- a/pkg/planner/cardinality/row_count_column.go +++ b/pkg/planner/cardinality/row_count_column.go @@ -18,6 +18,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/planner/context" "github.com/pingcap/tidb/pkg/planner/util/debugtrace" + "github.com/pingcap/tidb/pkg/planner/util/fixcontrol" "github.com/pingcap/tidb/pkg/statistics" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/codec" @@ -312,8 +313,17 @@ func GetColumnRowCount(sctx context.PlanContext, c *statistics.Column, ranges [] } rowCount += cnt } - // 
From 5478f9fc4a02813d875c8c5ca4941b0e0c8a2c30 Mon Sep 17 00:00:00 2001
From: tpp
Date: Tue, 13 Aug 2024 12:55:48 -0700
Subject: [PATCH 32/35] add fix control

---
 pkg/planner/cardinality/row_count_column.go    | 14 ++++++++++++--
 pkg/planner/cardinality/row_count_index.go     | 14 ++++++++++++--
 pkg/planner/util/fixcontrol/get.go             |  4 ++++
 .../r/planner/cardinality/selectivity.result   |  6 ++++++
 .../t/planner/cardinality/selectivity.test     |  3 +++
 5 files changed, 37 insertions(+), 4 deletions(-)

diff --git a/pkg/planner/cardinality/row_count_column.go b/pkg/planner/cardinality/row_count_column.go
index e38f8da4146c3..1f077e513cc54 100644
--- a/pkg/planner/cardinality/row_count_column.go
+++ b/pkg/planner/cardinality/row_count_column.go
@@ -18,6 +18,7 @@ import (
 	"github.com/pingcap/errors"
 	"github.com/pingcap/tidb/pkg/planner/context"
 	"github.com/pingcap/tidb/pkg/planner/util/debugtrace"
+	"github.com/pingcap/tidb/pkg/planner/util/fixcontrol"
 	"github.com/pingcap/tidb/pkg/statistics"
 	"github.com/pingcap/tidb/pkg/types"
 	"github.com/pingcap/tidb/pkg/util/codec"
@@ -312,8 +313,17 @@ func GetColumnRowCount(sctx context.PlanContext, c *statistics.Column, ranges []
 		}
 		rowCount += cnt
 	}
-	// Don't allow the final result to go below 1 row
-	rowCount = mathutil.Clamp(rowCount, 1, float64(realtimeRowCount))
+	allowZeroEst := fixcontrol.GetBoolWithDefault(
+		sctx.GetSessionVars().GetOptimizerFixControlMap(),
+		fixcontrol.Fix47400,
+		false,
+	)
+	if allowZeroEst {
+		rowCount = mathutil.Clamp(rowCount, 0, float64(realtimeRowCount))
+	} else {
+		// Don't allow the final result to go below 1 row
+		rowCount = mathutil.Clamp(rowCount, 1, float64(realtimeRowCount))
+	}
 	return rowCount, nil
 }
 
diff --git a/pkg/planner/cardinality/row_count_index.go b/pkg/planner/cardinality/row_count_index.go
index bc0f3c3226090..1b7065662d047 100644
--- a/pkg/planner/cardinality/row_count_index.go
+++ b/pkg/planner/cardinality/row_count_index.go
@@ -26,6 +26,7 @@ import (
 	"github.com/pingcap/tidb/pkg/kv"
 	"github.com/pingcap/tidb/pkg/planner/context"
 	"github.com/pingcap/tidb/pkg/planner/util/debugtrace"
+	"github.com/pingcap/tidb/pkg/planner/util/fixcontrol"
 	"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
 	"github.com/pingcap/tidb/pkg/statistics"
 	"github.com/pingcap/tidb/pkg/types"
@@ -350,8 +351,17 @@ func getIndexRowCountForStatsV2(sctx context.PlanContext, idx *statistics.Index,
 		}
 		totalCount += count
 	}
-	// Don't allow the final result to go below 1 row
-	totalCount = mathutil.Clamp(totalCount, 1, float64(realtimeRowCount))
+	allowZeroEst := fixcontrol.GetBoolWithDefault(
+		sctx.GetSessionVars().GetOptimizerFixControlMap(),
+		fixcontrol.Fix47400,
+		false,
+	)
+	if allowZeroEst {
+		totalCount = mathutil.Clamp(totalCount, 0, float64(realtimeRowCount))
+	} else {
+		// Don't allow the final result to go below 1 row
+		totalCount = mathutil.Clamp(totalCount, 1, float64(realtimeRowCount))
+	}
 	return totalCount, nil
 }
 
diff --git a/pkg/planner/util/fixcontrol/get.go b/pkg/planner/util/fixcontrol/get.go
index c2bed71c7b47f..6ae4d07065a67 100644
--- a/pkg/planner/util/fixcontrol/get.go
+++ b/pkg/planner/util/fixcontrol/get.go
@@ -20,6 +20,8 @@ import (
 )
 
 const (
+	// NOTE: For assigning new fix control numbers - use the issue number associated with the fix.
+	//
 	// Fix33031 controls whether to disallow plan cache for partitioned
 	// tables (both prepared statments and non-prepared statements)
 	// See #33031
@@ -48,6 +50,8 @@ const (
 	Fix45798 uint64 = 45798
 	// Fix46177 controls whether to explore enforced plans for DataSource if it has already found an unenforced plan.
 	Fix46177 uint64 = 46177
+	// Fix47400 controls whether to allow a rowEst below 1
+	Fix47400 uint64 = 47400
 	// Fix49736 controls whether to force the optimizer to use plan cache even if there is risky optimization.
 	// This fix-control is test-only.
 	Fix49736 uint64 = 49736
diff --git a/tests/integrationtest/r/planner/cardinality/selectivity.result b/tests/integrationtest/r/planner/cardinality/selectivity.result
index f08265b41d00e..89bcbb5ce9da3 100644
--- a/tests/integrationtest/r/planner/cardinality/selectivity.result
+++ b/tests/integrationtest/r/planner/cardinality/selectivity.result
@@ -1227,6 +1227,12 @@ explain select * from t where a = 'tw' and b < 0;
 id estRows task access object operator info
 IndexReader_6 1.00 root index:IndexRangeScan_5
 └─IndexRangeScan_5 1.00 cop[tikv] table:t, index:idx(a, b) range:["tw" -inf,"tw" 0), keep order:false
+set @@tidb_opt_fix_control = '47400:on';
+explain select * from t where a = 'tw' and b < 0;
+id estRows task access object operator info
+IndexReader_6 0.00 root index:IndexRangeScan_5
+└─IndexRangeScan_5 0.00 cop[tikv] table:t, index:idx(a, b) range:["tw" -inf,"tw" 0), keep order:false
+set @@tidb_opt_fix_control = '47400:off';
 drop table if exists t;
 create table t(id int auto_increment, kid int, pid int, primary key(id), key(kid, pid));
 insert into t (kid, pid) values (1,2), (1,3), (1,4),(1, 11), (1, 12), (1, 13), (1, 14), (2, 2), (2, 3), (2, 4);
diff --git a/tests/integrationtest/t/planner/cardinality/selectivity.test b/tests/integrationtest/t/planner/cardinality/selectivity.test
index 37dff39b9ed26..b865738ef56bb 100644
--- a/tests/integrationtest/t/planner/cardinality/selectivity.test
+++ b/tests/integrationtest/t/planner/cardinality/selectivity.test
@@ -647,6 +647,9 @@ insert into t values ('tw', 0);
 insert into t values ('tw', 0);
 analyze table t all columns;
 explain select * from t where a = 'tw' and b < 0;
+set @@tidb_opt_fix_control = '47400:on';
+explain select * from t where a = 'tw' and b < 0;
+set @@tidb_opt_fix_control = '47400:off';
 
 # TestSelectCombinedLowBound
 drop table if exists t;
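The effect of PATCH 32 above is that the final row-count estimate is still capped at the table's realtime row count, but the lower bound is now chosen by the 47400 fix control: a floor of 1 row by default, a floor of 0 when the control is switched on. The following is a minimal, self-contained Go sketch of that gating for readers who want to see the behaviour in isolation; clamp and clampRowEst are illustrative stand-ins for mathutil.Clamp and the surrounding TiDB code, and the plain map stands in for the session fix-control map that the patch reads through fixcontrol.GetBoolWithDefault.

package main

import "fmt"

// clamp bounds v to the closed interval [low, high], mirroring what
// mathutil.Clamp does in the patch.
func clamp(v, low, high float64) float64 {
	if v < low {
		return low
	}
	if v > high {
		return high
	}
	return v
}

// clampRowEst sketches the gating added to GetColumnRowCount and
// getIndexRowCountForStatsV2: the estimate is floored at 1 row unless the
// 47400 fix control is enabled, which restores the previous floor of 0.
func clampRowEst(rowCount, realtimeRowCount float64, fixControl map[uint64]string) float64 {
	// Stand-in for fixcontrol.GetBoolWithDefault(fixControl, fixcontrol.Fix47400, false).
	allowZeroEst := fixControl[47400] == "on"
	if allowZeroEst {
		return clamp(rowCount, 0, realtimeRowCount)
	}
	// Don't allow the final result to go below 1 row.
	return clamp(rowCount, 1, realtimeRowCount)
}

func main() {
	fmt.Println(clampRowEst(0, 1000, nil))                            // 1: default floor of one row
	fmt.Println(clampRowEst(0, 1000, map[uint64]string{47400: "on"})) // 0: fix control 47400 enabled
	fmt.Println(clampRowEst(5000, 1000, nil))                         // 1000: still capped at the realtime row count
}

This is the same behaviour the new selectivity.test cases exercise from SQL: the identical explain statement drops from an estimate of 1.00 rows to 0.00 rows once set @@tidb_opt_fix_control = '47400:on' is applied.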
From 83c9835051b2ea152f4988cd1ce5667e0d0a5ce5 Mon Sep 17 00:00:00 2001
From: tpp
Date: Tue, 13 Aug 2024 13:11:52 -0700
Subject: [PATCH 33/35] add bazel output

---
 pkg/planner/cardinality/BUILD.bazel | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pkg/planner/cardinality/BUILD.bazel b/pkg/planner/cardinality/BUILD.bazel
index 51b059bfec680..4446861b0e7a9 100644
--- a/pkg/planner/cardinality/BUILD.bazel
+++ b/pkg/planner/cardinality/BUILD.bazel
@@ -26,6 +26,7 @@ go_library(
         "//pkg/planner/property",
         "//pkg/planner/util",
         "//pkg/planner/util/debugtrace",
+        "//pkg/planner/util/fixcontrol",
         "//pkg/sessionctx/stmtctx",
         "//pkg/statistics",
         "//pkg/tablecodec",
From 230f1b4eb3457cd2e025005732c6607543197193 Mon Sep 17 00:00:00 2001
From: tpp
Date: Tue, 13 Aug 2024 13:40:48 -0700
Subject: [PATCH 34/35] fix control comment

---
 pkg/planner/util/fixcontrol/get.go | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/pkg/planner/util/fixcontrol/get.go b/pkg/planner/util/fixcontrol/get.go
index 6ae4d07065a67..33d1d083ec37b 100644
--- a/pkg/planner/util/fixcontrol/get.go
+++ b/pkg/planner/util/fixcontrol/get.go
@@ -11,6 +11,9 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+//
+// NOTE: For assigning new fix control numbers - use the issue number associated with the fix.
+//
 
 package fixcontrol
 
@@ -20,8 +23,6 @@ import (
 )
 
 const (
-	// NOTE: For assigning new fix control numbers - use the issue number associated with the fix.
-	//
 	// Fix33031 controls whether to disallow plan cache for partitioned
 	// tables (both prepared statments and non-prepared statements)
 	// See #33031
From 5dd856bffd546710030c01c175bb863223f02511 Mon Sep 17 00:00:00 2001
From: tpp
Date: Tue, 13 Aug 2024 15:50:25 -0700
Subject: [PATCH 35/35] revert json plan change

---
 .../testdata/json_plan_suite_out.json        | 158 ++++++++++++++++--
 .../testdata/plan_normalized_suite_out.json  |   8 +-
 2 files changed, 152 insertions(+), 14 deletions(-)

diff --git a/pkg/planner/core/casetest/testdata/json_plan_suite_out.json b/pkg/planner/core/casetest/testdata/json_plan_suite_out.json
index b1174dcc03c63..10ba3c93ad7fb 100644
--- a/pkg/planner/core/casetest/testdata/json_plan_suite_out.json
+++ b/pkg/planner/core/casetest/testdata/json_plan_suite_out.json
@@ -3,24 +3,162 @@
     "Name": "TestJSONPlanInExplain",
     "Cases": [
       {
-        "SQL": "",
-        "JSONPlan": null
+        "SQL": "explain format = tidb_json update t2 set id = 1 where id =2",
+        "JSONPlan": [
+          {
+            "id": "Update_4",
+            "estRows": "N/A",
+            "taskType": "root",
+            "operatorInfo": "N/A",
+            "subOperators": [
+              {
+                "id": "IndexReader_7",
+                "estRows": "10.00",
+                "taskType": "root",
+                "operatorInfo": "index:IndexRangeScan_6",
+                "subOperators": [
+                  {
+                    "id": "IndexRangeScan_6",
+                    "estRows": "10.00",
+                    "taskType": "cop[tikv]",
+                    "accessObject": "table:t2, index:id(id)",
+                    "operatorInfo": "range:[2,2], keep order:false, stats:pseudo"
+                  }
+                ]
+              }
+            ]
+          }
+        ]
       },
       {
-        "SQL": "",
-        "JSONPlan": null
+        "SQL": "explain format = tidb_json insert into t1 values(1)",
+        "JSONPlan": [
+          {
+            "id": "Insert_1",
+            "estRows": "N/A",
+            "taskType": "root",
+            "operatorInfo": "N/A"
+          }
+        ]
       },
       {
-        "SQL": "",
-        "JSONPlan": null
+        "SQL": "explain format = tidb_json select count(*) from t1",
+        "JSONPlan": [
+          {
+            "id": "HashAgg_12",
+            "estRows": "1.00",
+            "taskType": "root",
+            "operatorInfo": "funcs:count(Column#5)->Column#3",
+            "subOperators": [
+              {
+                "id": "TableReader_13",
+                "estRows": "1.00",
+                "taskType": "root",
+                "operatorInfo": "data:HashAgg_5",
+                "subOperators": [
+                  {
+                    "id": "HashAgg_5",
+                    "estRows": "1.00",
+                    "taskType": "cop[tikv]",
+                    "operatorInfo": "funcs:count(test.t1._tidb_rowid)->Column#5",
+                    "subOperators": [
+                      {
+                        "id": "TableFullScan_10",
+                        "estRows": "10000.00",
+                        "taskType": "cop[tikv]",
+                        "accessObject": "table:t1",
+                        "operatorInfo": "keep order:false, stats:pseudo"
+                      }
+                    ]
+                  }
+                ]
+              }
+            ]
+          }
+        ]
       },
       {
-        "SQL": "",
-        "JSONPlan": null
+        "SQL": "explain format = tidb_json select * from t1",
+        "JSONPlan": [
+          {
+            "id": "IndexReader_7",
+            "estRows": "10000.00",
+            "taskType": "root",
+            "operatorInfo": "index:IndexFullScan_6",
+            "subOperators": [
+              {
+                "id": "IndexFullScan_6",
+                "estRows": "10000.00",
+                "taskType": "cop[tikv]",
+                "accessObject": "table:t1, index:id(id)",
+                "operatorInfo": "keep order:false, stats:pseudo"
+              }
+            ]
+          }
+        ]
       },
       {
-        "SQL": "",
-        "JSONPlan": null
+        "SQL": "explain analyze format = tidb_json select * from t1, t2 where t1.id = t2.id",
+        "JSONPlan": [
+          {
+            "id": "MergeJoin_8",
+            "estRows": "12487.50",
+            "actRows": "0",
+            "taskType": "root",
+            "executeInfo": "time:3.5ms, loops:1",
+            "operatorInfo": "inner join, left key:test.t1.id, right key:test.t2.id",
+            "memoryInfo": "760 Bytes",
+            "diskInfo": "0 Bytes",
+            "subOperators": [
+              {
+                "id": "IndexReader_36(Build)",
+                "estRows": "9990.00",
+                "actRows": "0",
+                "taskType": "root",
+                "executeInfo": "time:3.47ms, loops:1, cop_task: {num: 1, max: 3.38ms, proc_keys: 0, tot_proc: 3ms, rpc_num: 1, rpc_time: 3.34ms, copr_cache_hit_ratio: 0.00, distsql_concurrency: 15}",
"operatorInfo": "index:IndexFullScan_35", + "memoryInfo": "171 Bytes", + "diskInfo": "N/A", + "subOperators": [ + { + "id": "IndexFullScan_35", + "estRows": "9990.00", + "actRows": "0", + "taskType": "cop[tikv]", + "accessObject": "table:t2, index:id(id)", + "executeInfo": "tikv_task:{time:3.3ms, loops:0}", + "operatorInfo": "keep order:true, stats:pseudo", + "memoryInfo": "N/A", + "diskInfo": "N/A" + } + ] + }, + { + "id": "IndexReader_34(Probe)", + "estRows": "9990.00", + "actRows": "0", + "taskType": "root", + "executeInfo": "time:14µs, loops:1, cop_task: {num: 1, max: 772.9µs, proc_keys: 0, rpc_num: 1, rpc_time: 735.7µs, copr_cache_hit_ratio: 0.00, distsql_concurrency: 15}", + "operatorInfo": "index:IndexFullScan_33", + "memoryInfo": "166 Bytes", + "diskInfo": "N/A", + "subOperators": [ + { + "id": "IndexFullScan_33", + "estRows": "9990.00", + "actRows": "0", + "taskType": "cop[tikv]", + "accessObject": "table:t1, index:id(id)", + "executeInfo": "tikv_task:{time:168.4µs, loops:0}", + "operatorInfo": "keep order:true, stats:pseudo", + "memoryInfo": "N/A", + "diskInfo": "N/A" + } + ] + } + ] + } + ] } ] } diff --git a/pkg/planner/core/casetest/testdata/plan_normalized_suite_out.json b/pkg/planner/core/casetest/testdata/plan_normalized_suite_out.json index b375ba4ee0560..d9ad1168af8c3 100644 --- a/pkg/planner/core/casetest/testdata/plan_normalized_suite_out.json +++ b/pkg/planner/core/casetest/testdata/plan_normalized_suite_out.json @@ -443,8 +443,8 @@ "Plan": [ " TableReader root ", " └─ExchangeSender cop[tiflash] ", - " └─Selection cop[tiflash] gt(test.t1.a, ?), or(lt(test.t1.a, ?), lt(test.t1.b, ?))", - " └─TableFullScan cop[tiflash] table:t1, range:[?,?], pushed down filter:gt(test.t1.b, ?), keep order:false" + " └─Selection cop[tiflash] gt(test.t1.b, ?), or(lt(test.t1.a, ?), lt(test.t1.b, ?))", + " └─TableFullScan cop[tiflash] table:t1, range:[?,?], pushed down filter:gt(test.t1.a, ?), keep order:false" ] }, { @@ -461,8 +461,8 @@ "Plan": [ " TableReader root ", " └─ExchangeSender cop[tiflash] ", - " └─Selection cop[tiflash] gt(test.t1.b, ?), gt(test.t1.c, ?), or(gt(test.t1.a, ?), lt(test.t1.b, ?))", - " └─TableFullScan cop[tiflash] table:t1, range:[?,?], pushed down filter:gt(test.t1.a, ?), keep order:false" + " └─Selection cop[tiflash] gt(test.t1.a, ?), gt(test.t1.c, ?), or(gt(test.t1.a, ?), lt(test.t1.b, ?))", + " └─TableFullScan cop[tiflash] table:t1, range:[?,?], pushed down filter:gt(test.t1.b, ?), keep order:false" ] }, {